update `mvmt-pruning/saving_prunebert` (updating torch to 1.5)
This commit is contained in:
Parent: caf3746678
Commit: 473808da0d
@@ -18,7 +18,9 @@
    "\n",
    "We experiment with a question answering model with only 6% of total remaining weights in the encoder (previously obtained with movement pruning). **We are able to reduce the memory size of the encoder from 340MB (original dense BERT) to 11MB**, which fits on a [91' floppy disk](https://en.wikipedia.org/wiki/Floptical)!\n",
    "\n",
-   "<img src=\"https://upload.wikimedia.org/wikipedia/commons/thumb/0/00/Floptical_disk_21MB.jpg/440px-Floptical_disk_21MB.jpg\" width=\"200\">"
+   "<img src=\"https://upload.wikimedia.org/wikipedia/commons/thumb/0/00/Floptical_disk_21MB.jpg/440px-Floptical_disk_21MB.jpg\" width=\"200\">\n",
+   "\n",
+   "*Note: this notebook is compatible with `torch>=1.5.0`. If you are using `torch==1.4.0`, please refer to [this previous version of the notebook](https://github.com/huggingface/transformers/commit/b11386e158e86e62d4041eabd86d044cd1695737).*"
   ]
  },
  {
@@ -67,10 +69,7 @@
   "source": [
    "# Load fine-pruned model and quantize the model\n",
    "\n",
-   "model_path = \"serialization_dir/bert-base-uncased/92/squad/l1\"\n",
-   "model_name = \"bertarized_l1_with_distil_0._0.1_1_2_l1_1100._3e-5_1e-2_sigmoied_threshold_constant_0._10_epochs\"\n",
-   "\n",
-   "model = BertForQuestionAnswering.from_pretrained(os.path.join(model_path, model_name))\n",
+   "model = BertForQuestionAnswering.from_pretrained(\"huggingface/prunebert-base-uncased-6-finepruned-w-distil-squad\")\n",
    "model.to('cpu')\n",
    "\n",
    "quantized_model = torch.quantization.quantize_dynamic(\n",
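For context on the cell this hunk edits: the update drops the local `serialization_dir` checkpoint in favour of the `huggingface/prunebert-base-uncased-6-finepruned-w-distil-squad` Hub model, which then goes through dynamic quantization. Below is a minimal sketch of the full cell; the diff truncates the `quantize_dynamic(` call, so the `{torch.nn.Linear}` module set and `dtype=torch.qint8` arguments are assumptions, not taken from the diff:

```python
import torch
from transformers import BertForQuestionAnswering

# Load the fine-pruned checkpoint from the Hub (this line is in the diff).
model = BertForQuestionAnswering.from_pretrained(
    "huggingface/prunebert-base-uncased-6-finepruned-w-distil-squad"
)
model.to("cpu")

# Assumed arguments: dynamic quantization replaces every nn.Linear with a
# module that stores int8 weights and dequantizes them on the fly.
quantized_model = torch.quantization.quantize_dynamic(
    model, {torch.nn.Linear}, dtype=torch.qint8
)
```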
@@ -196,7 +195,7 @@
    "\n",
    "elementary_qtz_st = {}\n",
    "for name, param in qtz_st.items():\n",
-   "    if param.is_quantized:\n",
+   "    if \"dtype\" not in name and param.is_quantized:\n",
    "        print(\"Decompose quantization for\", name)\n",
    "        # We need to extract the scale, the zero_point and the int_repr for the quantized tensor and modules\n",
    "        scale = param.q_scale() # torch.tensor(1,) - float32\n",
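The new `"dtype" not in name` guard skips the `_packed_params.dtype` entries that the torch 1.5 state dict exposes; those are plain `torch.dtype` objects, not quantized tensors, so accessors like `q_scale()` would fail on them. A self-contained sketch of the decomposition this loop performs, with a toy tensor and hypothetical names:

```python
import torch

# A per-tensor affine quantized tensor is fully described by three pieces.
q = torch.quantize_per_tensor(torch.randn(4, 4), scale=0.1, zero_point=0, dtype=torch.qint8)

scale = q.q_scale()            # Python float
zero_point = q.q_zero_point()  # Python int
int_repr = q.int_repr()        # dense torch.int8 tensor

# Rebuilding the original values (up to quantization error):
approx = (int_repr.float() - zero_point) * scale
```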
@@ -221,6 +220,17 @@
+ {
+  "cell_type": "code",
+  "execution_count": 5,
+  "metadata": {},
+  "outputs": [],
+  "source": [
+   "# Create mapping from torch.dtype to string description (we could also use an int8 instead of a string)\n",
+   "str_2_dtype = {\"qint8\": torch.qint8}\n",
+   "dtype_2_str = {torch.qint8: \"qint8\"}\n"
+  ]
+ },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {
    "scrolled": true
   },
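The inserted cell only defines two one-entry lookup tables. They exist because a `torch.dtype` cannot be written to HDF5 attributes or JSON directly, so serialization stores its string name and loading maps the name back. A quick round-trip check:

```python
import torch

str_2_dtype = {"qint8": torch.qint8}
dtype_2_str = {torch.qint8: "qint8"}

# The string form survives any serialization format; the lookup restores
# the torch.dtype singleton on load.
assert str_2_dtype[dtype_2_str[torch.qint8]] is torch.qint8
```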
@@ -245,7 +255,7 @@
  },
  {
   "cell_type": "code",
-  "execution_count": 6,
+  "execution_count": 7,
   "metadata": {},
   "outputs": [
    {
@@ -266,9 +276,10 @@
      "Skip bert.pooler.dense._packed_params.weight.int_repr.indices\n",
      "Skip bert.pooler.dense._packed_params.weight.int_repr.shape\n",
      "Skip bert.pooler.dense._packed_params.bias\n",
+     "Skip bert.pooler.dense._packed_params.dtype\n",
      "\n",
-     "Encoder Size (MB) - Dense: 340.25\n",
-     "Encoder Size (MB) - Sparse & Quantized: 11.27\n"
+     "Encoder Size (MB) - Dense: 340.26\n",
+     "Encoder Size (MB) - Sparse & Quantized: 11.28\n"
     ]
    }
   ],
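Both reported sizes drift by 0.01 MB, consistent with the extra `dtype` entries now present in the torch 1.5 state dict (note the added `Skip ..._packed_params.dtype` line above). For reference, a hedged sketch of one common way such sizes are measured; `size_mb` is a hypothetical helper, not the notebook's code:

```python
import io
import torch

def size_mb(state_dict) -> float:
    # Serialize to an in-memory buffer and report megabytes: a rough but
    # reproducible proxy for on-disk size.
    buf = io.BytesIO()
    torch.save(state_dict, buf)
    return buf.getbuffer().nbytes / 1e6

print(f"Size (MB): {size_mb(torch.nn.Linear(256, 256).state_dict()):.2f}")
```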
@@ -300,10 +311,14 @@
    "\n",
    "    elif type(param) == float or type(param) == int or type(param) == tuple:\n",
    "        # float - tensor _packed_params.weight.scale\n",
-   "        # int - tensor_packed_params.weight.zero_point\n",
+   "        # int - tensor _packed_params.weight.zero_point\n",
    "        # tuple - tensor _packed_params.weight.shape\n",
    "        hf.attrs[name] = param\n",
    "\n",
+   "    elif type(param) == torch.dtype:\n",
+   "        # dtype - tensor _packed_params.dtype\n",
+   "        hf.attrs[name] = dtype_2_str[param]\n",
+   "        \n",
    "    else:\n",
    "        hf.create_dataset(name, data=param, compression=\"gzip\", compression_opts=9)\n",
    "\n",
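The added `elif` stores `torch.dtype` values under their string names in HDF5 attributes, since h5py cannot serialize dtype objects natively. A self-contained sketch of the whole dispatch; the toy tensor and file name are illustrative only:

```python
import h5py
import torch

dtype_2_str = {torch.qint8: "qint8"}

# Toy stand-in for one decomposed quantized weight.
q = torch.quantize_per_tensor(torch.randn(4, 4), 0.1, 0, torch.qint8)
elementary = {
    "w.scale": q.q_scale(),              # float  -> HDF5 attribute
    "w.zero_point": q.q_zero_point(),    # int    -> HDF5 attribute
    "w.shape": tuple(q.shape),           # tuple  -> HDF5 attribute
    "w.dtype": q.dtype,                  # dtype  -> stored as its string name
    "w.int_repr": q.int_repr().numpy(),  # array  -> gzip-compressed dataset
}

with h5py.File("sketch.h5", "w") as hf:
    for name, param in elementary.items():
        if isinstance(param, (float, int, tuple)):
            hf.attrs[name] = param
        elif isinstance(param, torch.dtype):
            hf.attrs[name] = dtype_2_str[param]
        else:
            hf.create_dataset(name, data=param, compression="gzip", compression_opts=9)
```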
@@ -319,7 +334,7 @@
  },
  {
   "cell_type": "code",
-  "execution_count": 7,
+  "execution_count": 8,
   "metadata": {},
   "outputs": [
    {
@@ -327,7 +342,7 @@
     "output_type": "stream",
     "text": [
      "\n",
-     "Size (MB): 99.39\n"
+     "Size (MB): 99.41\n"
     ]
    }
   ],
@@ -363,10 +378,15 @@
    "        # tuple - tensor _packed_params.weight.shape\n",
    "        hf.attrs[name] = param\n",
    "\n",
+   "    elif type(param) == torch.dtype:\n",
+   "        # dtype - tensor _packed_params.dtype\n",
+   "        hf.attrs[name] = dtype_2_str[param]\n",
+   "        \n",
    "    else:\n",
    "        hf.create_dataset(name, data=param, compression=\"gzip\", compression_opts=9)\n",
    "\n",
+   "\n",
    "\n",
    "with open('dbg/metadata.json', 'w') as f:\n",
    "    f.write(json.dumps(qtz_st._metadata)) \n",
    "\n",
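This hunk repeats the `torch.dtype` branch in the second saving cell, whose tail also dumps `qtz_st._metadata` to JSON: the `OrderedDict` returned by `state_dict()` carries a private `_metadata` attribute (per-module version info consulted by `load_state_dict`), and the HDF5 file alone would lose it. A minimal, self-contained illustration:

```python
import json
import torch

# Any state_dict() result carries the private `_metadata` attribute.
sd = torch.nn.Linear(2, 2).state_dict()
print(sd._metadata)  # e.g. {'': {'version': 1}}

# It is plain nested dicts of ints, so JSON round-trips it losslessly.
restored = json.loads(json.dumps(dict(sd._metadata)))
```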
@@ -383,7 +403,7 @@
  },
  {
   "cell_type": "code",
-  "execution_count": 8,
+  "execution_count": 9,
   "metadata": {},
   "outputs": [],
   "source": [
@@ -406,6 +426,8 @@
    "            attr_param = int(attr_param)\n",
    "        else:\n",
    "            attr_param = torch.tensor(attr_param)\n",
+   "    elif \".dtype\" in attr_name:\n",
+   "        attr_param = str_2_dtype[attr_param]\n",
    "    reconstructed_elementary_qtz_st[attr_name] = attr_param\n",
    "    # print(f\"Unpack {attr_name}\")\n",
    "    \n",
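On the loading side, the added `elif` mirrors the saving change: `.dtype` attributes come back from HDF5 as strings and must go through `str_2_dtype` before the quantized tensors can be repacked. A sketch that reads back the file written in the saving sketch above (names are hypothetical):

```python
import h5py
import torch

str_2_dtype = {"qint8": torch.qint8}

reconstructed = {}
with h5py.File("sketch.h5", "r") as hf:
    for attr_name, attr_param in hf.attrs.items():
        if ".shape" in attr_name:
            attr_param = tuple(attr_param)        # stored as an int array
        elif ".dtype" in attr_name:
            attr_param = str_2_dtype[attr_param]  # string -> torch.dtype
        reconstructed[attr_name] = attr_param
    for name in hf.keys():
        # Datasets hold the dense int8 payloads (and any other arrays).
        reconstructed[name] = torch.from_numpy(hf[name][()])
```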
@@ -428,7 +450,7 @@
  },
  {
   "cell_type": "code",
-  "execution_count": 9,
+  "execution_count": 10,
   "metadata": {},
   "outputs": [],
   "source": [
@@ -451,7 +473,7 @@
  },
  {
   "cell_type": "code",
-  "execution_count": 10,
+  "execution_count": 11,
   "metadata": {},
   "outputs": [],
   "source": [
@@ -487,7 +509,7 @@
  },
  {
   "cell_type": "code",
-  "execution_count": 11,
+  "execution_count": 12,
   "metadata": {},
   "outputs": [],
   "source": [
@@ -517,7 +539,7 @@
  },
  {
   "cell_type": "code",
-  "execution_count": 12,
+  "execution_count": 13,
   "metadata": {},
   "outputs": [
    {
@@ -526,7 +548,7 @@
      "<All keys matched successfully>"
     ]
    },
-   "execution_count": 12,
+   "execution_count": 13,
    "metadata": {},
    "output_type": "execute_result"
   }
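`<All keys matched successfully>` in this output is the repr of the named tuple returned by `load_state_dict` when both `missing_keys` and `unexpected_keys` are empty, confirming that the reconstructed state dict covers the quantized model exactly. A self-contained illustration:

```python
import torch

layer = torch.nn.Linear(2, 2)
# strict=True (the default) checks the key sets match in both directions.
result = layer.load_state_dict(layer.state_dict(), strict=True)
print(result)  # <All keys matched successfully>
```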
@@ -553,7 +575,7 @@
  },
  {
   "cell_type": "code",
-  "execution_count": 13,
+  "execution_count": 14,
   "metadata": {},
   "outputs": [
    {