update `mvmt-pruning/saving_prunebert` (updating torch to 1.5)
Parent: caf3746678
Commit: 473808da0d
@@ -18,7 +18,9 @@
 "\n",
 "We experiment with a question answering model with only 6% of total remaining weights in the encoder (previously obtained with movement pruning). **We are able to reduce the memory size of the encoder from 340MB (original dense BERT) to 11MB**, which fits on a ['91 floppy disk](https://en.wikipedia.org/wiki/Floptical)!\n",
 "\n",
-"<img src=\"https://upload.wikimedia.org/wikipedia/commons/thumb/0/00/Floptical_disk_21MB.jpg/440px-Floptical_disk_21MB.jpg\" width=\"200\">"
+"<img src=\"https://upload.wikimedia.org/wikipedia/commons/thumb/0/00/Floptical_disk_21MB.jpg/440px-Floptical_disk_21MB.jpg\" width=\"200\">\n",
+"\n",
+"*Note: this notebook is compatible with `torch>=1.5.0`. If you are using `torch==1.4.0`, please refer to [this previous version of the notebook](https://github.com/huggingface/transformers/commit/b11386e158e86e62d4041eabd86d044cd1695737).*"
 ]
 },
 {
@@ -67,10 +69,7 @@
 "source": [
 "# Load fine-pruned model and quantize the model\n",
 "\n",
-"model_path = \"serialization_dir/bert-base-uncased/92/squad/l1\"\n",
-"model_name = \"bertarized_l1_with_distil_0._0.1_1_2_l1_1100._3e-5_1e-2_sigmoied_threshold_constant_0._10_epochs\"\n",
-"\n",
-"model = BertForQuestionAnswering.from_pretrained(os.path.join(model_path, model_name))\n",
+"model = BertForQuestionAnswering.from_pretrained(\"huggingface/prunebert-base-uncased-6-finepruned-w-distil-squad\")\n",
 "model.to('cpu')\n",
 "\n",
 "quantized_model = torch.quantization.quantize_dynamic(\n",
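The updated cell pulls the fine-pruned checkpoint straight from the model hub instead of a local serialization directory, then quantizes it. The `quantize_dynamic` call is truncated in this hunk; a minimal sketch of the flow, assuming the standard PyTorch dynamic-quantization API (the target module set and dtype below are assumptions, not shown in the diff):

```python
import torch
from transformers import BertForQuestionAnswering

# Load the fine-pruned PruneBERT checkpoint directly from the hub (new behavior)
model = BertForQuestionAnswering.from_pretrained(
    "huggingface/prunebert-base-uncased-6-finepruned-w-distil-squad"
)
model.to("cpu")

# Dynamic quantization: nn.Linear weights are stored as int8; activations are
# quantized on the fly at inference time.
quantized_model = torch.quantization.quantize_dynamic(
    model, {torch.nn.Linear}, dtype=torch.qint8
)
```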
@@ -196,7 +195,7 @@
 "\n",
 "elementary_qtz_st = {}\n",
 "for name, param in qtz_st.items():\n",
-" if param.is_quantized:\n",
+" if \"dtype\" not in name and param.is_quantized:\n",
 " print(\"Decompose quantization for\", name)\n",
 " # We need to extract the scale, the zero_point and the int_repr for the quantized tensor and modules\n",
 " scale = param.q_scale() # torch.tensor(1,) - float32\n",
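The added `"dtype" not in name` guard is needed because in torch 1.5 the quantized state dict gains `_packed_params.dtype` entries, which are `torch.dtype` objects rather than tensors and so cannot be decomposed. A self-contained sketch of what the decomposition itself extracts (per-tensor affine quantization assumed):

```python
import torch

# A per-tensor quantized tensor carries three recoverable parts.
x = torch.randn(4, 4)
qx = torch.quantize_per_tensor(x, scale=0.1, zero_point=0, dtype=torch.qint8)

scale = qx.q_scale()            # Python float
zero_point = qx.q_zero_point()  # Python int
int_repr = qx.int_repr()        # plain int8 tensor, cheap to store sparsely
```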
@@ -221,6 +220,17 @@
 {
 "cell_type": "code",
 "execution_count": 5,
+"metadata": {},
+"outputs": [],
+"source": [
+"# Create mapping from torch.dtype to string description (we could also use an int8 instead of a string)\n",
+"str_2_dtype = {\"qint8\": torch.qint8}\n",
+"dtype_2_str = {torch.qint8: \"qint8\"}\n"
+]
+},
+{
+"cell_type": "code",
+"execution_count": 6,
 "metadata": {
 "scrolled": true
 },
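This new cell exists because the HDF5 attributes used later cannot store a `torch.dtype` object directly, so the dtype entries introduced by torch 1.5 are round-tripped through strings. A tiny sanity check of the mapping:

```python
import torch

str_2_dtype = {"qint8": torch.qint8}
dtype_2_str = {torch.qint8: "qint8"}

# The two dicts are inverses, so the string round-trip is lossless.
assert str_2_dtype[dtype_2_str[torch.qint8]] is torch.qint8
```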
@@ -245,7 +255,7 @@
 },
 {
 "cell_type": "code",
-"execution_count": 6,
+"execution_count": 7,
 "metadata": {},
 "outputs": [
 {
@@ -266,9 +276,10 @@
 "Skip bert.pooler.dense._packed_params.weight.int_repr.indices\n",
 "Skip bert.pooler.dense._packed_params.weight.int_repr.shape\n",
 "Skip bert.pooler.dense._packed_params.bias\n",
+"Skip bert.pooler.dense._packed_params.dtype\n",
 "\n",
-"Encoder Size (MB) - Dense: 340.25\n",
-"Encoder Size (MB) - Sparse & Quantized: 11.27\n"
+"Encoder Size (MB) - Dense: 340.26\n",
+"Encoder Size (MB) - Sparse & Quantized: 11.28\n"
 ]
 }
 ],
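The reported sizes shift by a hundredth of a megabyte (340.25 to 340.26, 11.27 to 11.28) because the torch 1.5 state dict carries the extra `_packed_params.dtype` entries. How the notebook measures size is not visible in this hunk; one plausible probe, purely as an assumption, is to serialize the state dict and count bytes:

```python
import os
import torch

def size_mb(state_dict, tmp_path="tmp.pt"):
    # Hypothetical helper: measure a state dict by its serialized footprint.
    torch.save(state_dict, tmp_path)
    mb = os.path.getsize(tmp_path) / 1e6
    os.remove(tmp_path)
    return mb
```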
@@ -300,10 +311,14 @@
 "\n",
 " elif type(param) == float or type(param) == int or type(param) == tuple:\n",
 " # float - tensor _packed_params.weight.scale\n",
-" # int - tensor_packed_params.weight.zero_point\n",
+" # int - tensor _packed_params.weight.zero_point\n",
 " # tuple - tensor _packed_params.weight.shape\n",
 " hf.attrs[name] = param\n",
 "\n",
+" elif type(param) == torch.dtype:\n",
+" # dtype - tensor _packed_params.dtype\n",
+" hf.attrs[name] = dtype_2_str[param]\n",
+" \n",
 " else:\n",
 " hf.create_dataset(name, data=param, compression=\"gzip\", compression_opts=9)\n",
 "\n",
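Putting the serialization branches together: floats, ints, and tuples go into HDF5 attributes, the new `torch.dtype` entries are stringified into attributes, and everything else (the int_repr tensors) becomes a gzip-compressed dataset. A condensed, self-contained sketch mirroring the cell (the file name and the toy `elementary_qtz_st` contents are assumptions standing in for the earlier cells):

```python
import h5py
import numpy as np
import torch

dtype_2_str = {torch.qint8: "qint8"}
elementary_qtz_st = {
    "w.scale": 0.1, "w.zero_point": 0, "w.shape": (4, 4),
    "w.dtype": torch.qint8, "w.int_repr": np.zeros((4, 4), dtype=np.int8),
}

with h5py.File("elementary_qtz_st.h5", "w") as hf:  # hypothetical file name
    for name, param in elementary_qtz_st.items():
        if type(param) in (float, int, tuple):
            hf.attrs[name] = param               # scales, zero_points, shapes
        elif type(param) == torch.dtype:
            hf.attrs[name] = dtype_2_str[param]  # e.g. torch.qint8 -> "qint8"
        else:
            hf.create_dataset(name, data=param, compression="gzip", compression_opts=9)
```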
@@ -319,7 +334,7 @@
 },
 {
 "cell_type": "code",
-"execution_count": 7,
+"execution_count": 8,
 "metadata": {},
 "outputs": [
 {
@@ -327,7 +342,7 @@
 "output_type": "stream",
 "text": [
 "\n",
-"Size (MB): 99.39\n"
+"Size (MB): 99.41\n"
 ]
 }
 ],
@@ -363,10 +378,15 @@
 " # tuple - tensor _packed_params.weight.shape\n",
 " hf.attrs[name] = param\n",
 "\n",
+" elif type(param) == torch.dtype:\n",
+" # dtype - tensor _packed_params.dtype\n",
+" hf.attrs[name] = dtype_2_str[param]\n",
+" \n",
 " else:\n",
 " hf.create_dataset(name, data=param, compression=\"gzip\", compression_opts=9)\n",
 "\n",
 "\n",
+"\n",
 "with open('dbg/metadata.json', 'w') as f:\n",
 " f.write(json.dumps(qtz_st._metadata)) \n",
 "\n",
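The `_metadata` dumped to `dbg/metadata.json` is the hidden per-module versioning info that `load_state_dict` consults; it rides along with the state dict's `OrderedDict` but is not returned by `.items()`, so it has to be persisted separately. A self-contained illustration of the round-trip:

```python
import json
import torch

st = torch.nn.Linear(2, 2).state_dict()

# _metadata maps module prefixes to {'version': N}; invisible to .items(),
# hence the separate JSON file in the notebook.
with open("metadata.json", "w") as f:
    json.dump(dict(st._metadata), f)

with open("metadata.json") as f:
    metadata = json.load(f)
```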
@@ -383,7 +403,7 @@
 },
 {
 "cell_type": "code",
-"execution_count": 8,
+"execution_count": 9,
 "metadata": {},
 "outputs": [],
 "source": [
@@ -406,6 +426,8 @@
 " attr_param = int(attr_param)\n",
 " else:\n",
 " attr_param = torch.tensor(attr_param)\n",
+" elif \".dtype\" in attr_name:\n",
+" attr_param = str_2_dtype[attr_param]\n",
 " reconstructed_elementary_qtz_st[attr_name] = attr_param\n",
 " # print(f\"Unpack {attr_name}\")\n",
 " \n",
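On the loading side, the new `".dtype"` branch restores `torch.qint8` from its string form. The end goal is rebuilding real quantized tensors from the unpacked scale, zero_point, and int_repr; a round-trip sketch using a private PyTorch helper that can do this in torch 1.5 (whether the notebook uses exactly this call is not shown in the diff):

```python
import torch

x = torch.randn(4, 4)
qx = torch.quantize_per_tensor(x, scale=0.1, zero_point=0, dtype=torch.qint8)

# Decompose into serializable parts ...
scale, zero_point, int_repr = qx.q_scale(), qx.q_zero_point(), qx.int_repr()

# ... and rebuild the quantized tensor (private API, present in torch>=1.5).
qx2 = torch._make_per_tensor_quantized_tensor(int_repr, scale, zero_point)
assert torch.equal(qx.dequantize(), qx2.dequantize())
```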
@@ -428,7 +450,7 @@
 },
 {
 "cell_type": "code",
-"execution_count": 9,
+"execution_count": 10,
 "metadata": {},
 "outputs": [],
 "source": [
@@ -451,7 +473,7 @@
 },
 {
 "cell_type": "code",
-"execution_count": 10,
+"execution_count": 11,
 "metadata": {},
 "outputs": [],
 "source": [
@@ -487,7 +509,7 @@
 },
 {
 "cell_type": "code",
-"execution_count": 11,
+"execution_count": 12,
 "metadata": {},
 "outputs": [],
 "source": [
@@ -517,7 +539,7 @@
 },
 {
 "cell_type": "code",
-"execution_count": 12,
+"execution_count": 13,
 "metadata": {},
 "outputs": [
 {
@@ -526,7 +548,7 @@
 "<All keys matched successfully>"
 ]
 },
-"execution_count": 12,
+"execution_count": 13,
 "metadata": {},
 "output_type": "execute_result"
 }
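`<All keys matched successfully>` is what `load_state_dict` reports when it finds neither missing nor unexpected keys, i.e. the reconstructed state dict is structurally identical to the quantized model's own. A minimal reproduction with a toy dynamically quantized module (shapes assumed):

```python
import torch

m = torch.quantization.quantize_dynamic(
    torch.nn.Linear(4, 4), {torch.nn.Linear}, dtype=torch.qint8
)
# Round-tripping a module's own state dict reports no missing/unexpected keys.
print(m.load_state_dict(m.state_dict()))  # <All keys matched successfully>
```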
@@ -553,7 +575,7 @@
 },
 {
 "cell_type": "code",
-"execution_count": 13,
+"execution_count": 14,
 "metadata": {},
 "outputs": [
 {