[Compression] v2.5 pruning tutorial (#5476)

This commit is contained in:
J-shang 2023-04-13 12:55:45 +08:00 committed by GitHub
Parent a68cb047d4
Commit cdf3bfb3e5
No key found matching this signature
GPG key ID: 4AEE18F83AFDEB23
22 changed files with 2732 additions and 79 deletions


@ -0,0 +1,8 @@
Best Practices
==============
.. toctree::
:hidden:
:maxdepth: 2
Pruning Transformer </tutorials/new_pruning_bert_glue>


@ -0,0 +1,8 @@
Compression (Preview)
=====================
.. toctree::
:hidden:
:maxdepth: 2
Pruning <toctree_pruning>


@ -0,0 +1,8 @@
Pruning
=======
.. toctree::
:hidden:
:maxdepth: 2
Best Practices <best_practices>


@ -16,7 +16,8 @@ NNI Documentation
hpo/toctree
nas/toctree
Model Compression <compression/toctree>
compression/toctree
compression_preview/toctree
feature_engineering/toctree
experiment/toctree


@ -20,6 +20,7 @@ NNI 文档
超参调优 <hpo/toctree>
架构搜索 <nas/toctree>
模型压缩 <compression/toctree>
模型压缩(预览) <compression_preview/toctree>
特征工程 <feature_engineering/toctree>
实验管理 <experiment/toctree>


@ -8,36 +8,32 @@ msgid ""
msgstr ""
"Project-Id-Version: NNI \n"
"Report-Msgid-Bugs-To: \n"
"POT-Creation-Date: 2022-04-20 05:50+0000\n"
"POT-Creation-Date: 2023-03-27 02:44+0000\n"
"PO-Revision-Date: YEAR-MO-DA HO:MI+ZONE\n"
"Last-Translator: FULL NAME <EMAIL@ADDRESS>\n"
"Language-Team: LANGUAGE <LL@li.org>\n"
"MIME-Version: 1.0\n"
"Content-Type: text/plain; charset=utf-8\n"
"Content-Transfer-Encoding: 8bit\n"
"Generated-By: Babel 2.9.1\n"
"Generated-By: Babel 2.11.0\n"
#: ../../source/index.rst:4 ../../source/index.rst:52
#: ../../source/index.rst:4 ../../source/index.rst:53
msgid "Get Started"
msgstr ""
#: ../../source/index.rst:12
msgid "Model Compression"
msgstr ""
#: ../../source/index.rst:12
msgid "User Guide"
msgstr ""
#: ../../source/index.rst:23
#: ../../source/index.rst:24
msgid "Python API"
msgstr ""
#: ../../source/index.rst:23
#: ../../source/index.rst:24
msgid "References"
msgstr ""
#: ../../source/index.rst:32
#: ../../source/index.rst:33
msgid "Misc"
msgstr ""
@ -45,68 +41,68 @@ msgstr ""
msgid "NNI Documentation"
msgstr ""
#: ../../source/index.rst:44
#: ../../source/index.rst:45
msgid ""
"**NNI (Neural Network Intelligence)** is a lightweight but powerful "
"toolkit to help users **automate**:"
msgstr ""
#: ../../source/index.rst:46
#: ../../source/index.rst:47
msgid ":doc:`Hyperparameter Optimization </hpo/overview>`"
msgstr ""
#: ../../source/index.rst:47
#: ../../source/index.rst:48
msgid ":doc:`Neural Architecture Search </nas/overview>`"
msgstr ""
#: ../../source/index.rst:48
#: ../../source/index.rst:49
msgid ":doc:`Model Compression </compression/overview>`"
msgstr ""
#: ../../source/index.rst:49
#: ../../source/index.rst:50
msgid ":doc:`Feature Engineering </feature_engineering/overview>`"
msgstr ""
#: ../../source/index.rst:54
#: ../../source/index.rst:55
msgid "To install the current release:"
msgstr ""
#: ../../source/index.rst:60
#: ../../source/index.rst:61
msgid ""
"See the :doc:`installation guide </installation>` if you need additional "
"help on installation."
msgstr ""
#: ../../source/index.rst:63
#: ../../source/index.rst:64
msgid "Try your first NNI experiment"
msgstr ""
#: ../../source/index.rst:69
#: ../../source/index.rst:70
msgid ""
"You need to have `PyTorch <https://pytorch.org/>`_ (as well as "
"`torchvision <https://pytorch.org/vision/stable/index.html>`_) installed "
"to run this experiment."
msgstr ""
#: ../../source/index.rst:71
#: ../../source/index.rst:72
msgid ""
"To start your journey now, please follow the :doc:`absolute quickstart of"
" NNI <quickstart>`!"
msgstr ""
#: ../../source/index.rst:74
#: ../../source/index.rst:75
msgid "Why choose NNI?"
msgstr ""
#: ../../source/index.rst:77
#: ../../source/index.rst:78
msgid "NNI makes AutoML techniques plug-and-play"
msgstr ""
#: ../../source/index.rst:221
#: ../../source/index.rst:222
msgid "NNI eases the effort to scale and manage AutoML experiments"
msgstr ""
#: ../../source/index.rst:229
#: ../../source/index.rst:230
msgid ""
"An AutoML experiment requires many trials to explore feasible and "
"potentially good-performing models. **Training service** aims to make the"
@ -116,13 +112,13 @@ msgid ""
"kinds of training services."
msgstr ""
#: ../../source/index.rst:240
#: ../../source/index.rst:241
msgid ""
"Web portal visualizes the tuning process, exposing the ability to "
"inspect, monitor and control the experiment."
msgstr ""
#: ../../source/index.rst:251
#: ../../source/index.rst:252
msgid ""
"The DNN model tuning often requires more than one experiment. Users might"
" try different tuning algorithms, fine-tune their search space, or switch"
@ -131,66 +127,66 @@ msgid ""
"so that the tuning workflow becomes clean and organized."
msgstr ""
#: ../../source/index.rst:257
#: ../../source/index.rst:258
msgid "Get Support and Contribute Back"
msgstr ""
#: ../../source/index.rst:259
#: ../../source/index.rst:260
msgid ""
"NNI is maintained on the `NNI GitHub repository "
"<https://github.com/microsoft/nni>`_. We collect feedbacks and new "
"proposals/ideas on GitHub. You can:"
msgstr ""
#: ../../source/index.rst:261
#: ../../source/index.rst:262
msgid ""
"Open a `GitHub issue <https://github.com/microsoft/nni/issues>`_ for bugs"
" and feature requests."
msgstr ""
#: ../../source/index.rst:262
#: ../../source/index.rst:263
msgid ""
"Open a `pull request <https://github.com/microsoft/nni/pulls>`_ to "
"contribute code (make sure to read the :doc:`contribution guide "
"<notes/contributing>` before doing this)."
msgstr ""
#: ../../source/index.rst:263
#: ../../source/index.rst:264
msgid ""
"Participate in `NNI Discussion "
"<https://github.com/microsoft/nni/discussions>`_ for general questions "
"and new ideas."
msgstr ""
#: ../../source/index.rst:264
#: ../../source/index.rst:265
msgid "Join the following IM groups."
msgstr ""
#: ../../source/index.rst:270
#: ../../source/index.rst:271
msgid "Gitter"
msgstr ""
#: ../../source/index.rst:271
#: ../../source/index.rst:272
msgid "WeChat"
msgstr ""
#: ../../source/index.rst:278
#: ../../source/index.rst:279
msgid "Citing NNI"
msgstr ""
#: ../../source/index.rst:280
#: ../../source/index.rst:281
msgid ""
"If you use NNI in a scientific publication, please consider citing NNI in"
" your references."
msgstr ""
#: ../../source/index.rst:282
#: ../../source/index.rst:283
msgid ""
"Microsoft. Neural Network Intelligence (version |release|). "
"https://github.com/microsoft/nni"
msgstr ""
#: ../../source/index.rst:284
#: ../../source/index.rst:285
msgid ""
"Bibtex entry (please replace the version with the particular version you "
"are using): ::"
@ -217,3 +213,6 @@ msgstr ""
#~ "doing this)."
#~ msgstr ""
#~ msgid "Model Compression"
#~ msgstr ""

Binary file not shown.

After

Width:  |  Height:  |  Size: 18 KiB

45
docs/source/tutorials/index.rst generated

@ -17,7 +17,7 @@ Tutorials
.. only:: html
.. image:: /tutorials/images/thumb/sphx_glr_pruning_speedup_thumb.png
:alt:
:alt: Speedup Model with Mask
:ref:`sphx_glr_tutorials_pruning_speedup.py`
@ -34,7 +34,7 @@ Tutorials
.. only:: html
.. image:: /tutorials/images/thumb/sphx_glr_quantization_quick_start_mnist_thumb.png
:alt:
:alt: Quantization Quickstart
:ref:`sphx_glr_tutorials_quantization_quick_start_mnist.py`
@ -51,7 +51,7 @@ Tutorials
.. only:: html
.. image:: /tutorials/images/thumb/sphx_glr_pruning_quick_start_mnist_thumb.png
:alt:
:alt: Pruning Quickstart
:ref:`sphx_glr_tutorials_pruning_quick_start_mnist.py`
@ -68,7 +68,7 @@ Tutorials
.. only:: html
.. image:: /tutorials/images/thumb/sphx_glr_quantization_customize_thumb.png
:alt:
:alt: Customize a new quantization algorithm
:ref:`sphx_glr_tutorials_quantization_customize.py`
@ -85,7 +85,7 @@ Tutorials
.. only:: html
.. image:: /tutorials/images/thumb/sphx_glr_nasbench_as_dataset_thumb.png
:alt:
:alt: Use NAS Benchmarks as Datasets
:ref:`sphx_glr_tutorials_nasbench_as_dataset.py`
@ -102,7 +102,7 @@ Tutorials
.. only:: html
.. image:: /tutorials/images/thumb/sphx_glr_quantization_speedup_thumb.png
:alt:
:alt: Speed Up Quantized Model with TensorRT
:ref:`sphx_glr_tutorials_quantization_speedup.py`
@ -119,7 +119,7 @@ Tutorials
.. only:: html
.. image:: /tutorials/images/thumb/sphx_glr_hello_nas_thumb.png
:alt:
:alt: Hello, NAS!
:ref:`sphx_glr_tutorials_hello_nas.py`
@ -136,7 +136,7 @@ Tutorials
.. only:: html
.. image:: /tutorials/images/thumb/sphx_glr_darts_thumb.png
:alt:
:alt: Searching in DARTS search space
:ref:`sphx_glr_tutorials_darts.py`
@ -146,6 +146,23 @@ Tutorials
</div>
.. raw:: html
<div class="sphx-glr-thumbcontainer" tooltip="This is a new tutorial on pruning transformer in nni v3.0 (`old tutorial &lt;https://nni.readthedo...">
.. only:: html
.. image:: /tutorials/images/thumb/sphx_glr_new_pruning_bert_glue_thumb.png
:alt: Pruning Bert on Task MNLI
:ref:`sphx_glr_tutorials_new_pruning_bert_glue.py`
.. raw:: html
<div class="sphx-glr-thumbnail-title">Pruning Bert on Task MNLI</div>
</div>
.. raw:: html
<div class="sphx-glr-thumbcontainer" tooltip="Workable Pruning Process ------------------------">
@ -153,7 +170,7 @@ Tutorials
.. only:: html
.. image:: /tutorials/images/thumb/sphx_glr_pruning_bert_glue_thumb.png
:alt:
:alt: Pruning Bert on Task MNLI
:ref:`sphx_glr_tutorials_pruning_bert_glue.py`
@ -179,6 +196,7 @@ Tutorials
/tutorials/quantization_speedup
/tutorials/hello_nas
/tutorials/darts
/tutorials/new_pruning_bert_glue
/tutorials/pruning_bert_glue
@ -196,7 +214,7 @@ Tutorials
.. only:: html
.. image:: /tutorials/hpo_quickstart_pytorch/images/thumb/sphx_glr_main_thumb.png
:alt:
:alt: HPO Quickstart with PyTorch
:ref:`sphx_glr_tutorials_hpo_quickstart_pytorch_main.py`
@ -213,7 +231,7 @@ Tutorials
.. only:: html
.. image:: /tutorials/hpo_quickstart_pytorch/images/thumb/sphx_glr_model_thumb.png
:alt:
:alt: Port PyTorch Quickstart to NNI
:ref:`sphx_glr_tutorials_hpo_quickstart_pytorch_model.py`
@ -242,7 +260,7 @@ Tutorials
.. only:: html
.. image:: /tutorials/hpo_quickstart_tensorflow/images/thumb/sphx_glr_main_thumb.png
:alt:
:alt: HPO Quickstart with TensorFlow
:ref:`sphx_glr_tutorials_hpo_quickstart_tensorflow_main.py`
@ -259,7 +277,7 @@ Tutorials
.. only:: html
.. image:: /tutorials/hpo_quickstart_tensorflow/images/thumb/sphx_glr_model_thumb.png
:alt:
:alt: Port TensorFlow Quickstart to NNI
:ref:`sphx_glr_tutorials_hpo_quickstart_tensorflow_model.py`
@ -278,7 +296,6 @@ Tutorials
:hidden:
:includehidden:
/tutorials/hpo_quickstart_pytorch/index.rst
/tutorials/hpo_quickstart_tensorflow/index.rst

349
docs/source/tutorials/new_pruning_bert_glue.ipynb generated Normal file

@ -0,0 +1,349 @@
{
"cells": [
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"collapsed": false
},
"outputs": [],
"source": [
"%matplotlib inline"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"\n# Pruning Bert on Task MNLI\n\nThis is a new tutorial on pruning transformer in nni v3.0 ([old tutorial](https://nni.readthedocs.io/en/v2.9/tutorials/pruning_bert_glue.html)_).\nThe main difference between this tutorial and the previous is that it integrates the feature of fusion compression (pruning + distillation) in nni,\nuses a new more powerful and stable pruning speedup tool,\nand additionally prunes the whole model hidden dimensions which greatly reduces the model size (pruning embedding layers).\n\nAt the same time, the huggingface [transformers.Trainer](https://huggingface.co/docs/transformers/main_classes/trainer)_ is used in this tutorial\nto reduce the burden of user writing training and evaluation logic.\n\n## Workable Pruning Process\n\nThe whole pruning process is divided into three steps:\n\n1. pruning attention layers,\n2. pruning feed forward layers,\n3. pruning embedding layers.\n\nIn each step, the pruner is first used for simulated pruning to generate masks corresponding to the module pruning targets (weight, input, output).\nAfter that comes the speedup stage, sparsity propagation is used to explore the global redundancy due to the local masks,\nthen modify the original model into a smaller one by replacing the sub module in the model.\n\nThe compression of the model naturally applies the distillation method,\nso in this tutorial, distillers will also be used to help restore the model accuracy.\n\n## Experiment\n\n### Preparations\n\nThe preparations mainly includes preparing the transformers trainer and model.\n\nThis is generally consistent with the preparations required to normally train a Bert model.\nThe only difference is that the ``transformers.Trainer`` is needed to wrap by ``nni.trace`` to trace the init arguments,\nthis is because nni need re-create trainer during training aware pruning and distilling.\n\n<div class=\"alert alert-info\"><h4>Note</h4><p>Please set ``skip_exec`` to ``False`` to run this tutorial. Here ``skip_exec`` is ``True`` by default is for generating documents.</p></div>\n"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"collapsed": false
},
"outputs": [],
"source": [
"from __future__ import annotations\n\nfrom pathlib import Path\n\nimport numpy as np\n\nimport torch\nfrom torch.utils.data import ConcatDataset\n\nimport nni\n\nfrom datasets import load_dataset, load_metric\nfrom transformers import BertTokenizerFast, DataCollatorWithPadding, BertForSequenceClassification, EvalPrediction\nfrom transformers.trainer import Trainer\nfrom transformers.training_args import TrainingArguments\n\nskip_exec = True"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Set the downstream task name here, you could replace the task with the task in GLUE.\n\n"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"collapsed": false
},
"outputs": [],
"source": [
"task_name = 'mnli'"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Here using BertForSequenceClassification as the base model for show case.\nIf you want to prune other kind of transformer model, you could replace the base model here.\n\n"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"collapsed": false
},
"outputs": [],
"source": [
"def build_model(pretrained_model_name_or_path: str, task_name: str):\n is_regression = task_name == 'stsb'\n num_labels = 1 if is_regression else (3 if task_name == 'mnli' else 2)\n model = BertForSequenceClassification.from_pretrained(pretrained_model_name_or_path, num_labels=num_labels)\n return model"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Prepare the GLUE train & validation datasets, if the task has multi validation datasets, concat the datasets by ``ConcatDataset``.\n\n"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"collapsed": false
},
"outputs": [],
"source": [
"def prepare_datasets(task_name: str, tokenizer: BertTokenizerFast, cache_dir: str):\n task_to_keys = {\n 'cola': ('sentence', None),\n 'mnli': ('premise', 'hypothesis'),\n 'mrpc': ('sentence1', 'sentence2'),\n 'qnli': ('question', 'sentence'),\n 'qqp': ('question1', 'question2'),\n 'rte': ('sentence1', 'sentence2'),\n 'sst2': ('sentence', None),\n 'stsb': ('sentence1', 'sentence2'),\n 'wnli': ('sentence1', 'sentence2'),\n }\n sentence1_key, sentence2_key = task_to_keys[task_name]\n\n # used to preprocess the raw data\n def preprocess_function(examples):\n # Tokenize the texts\n args = (\n (examples[sentence1_key],) if sentence2_key is None else (examples[sentence1_key], examples[sentence2_key])\n )\n result = tokenizer(*args, padding=False, max_length=128, truncation=True)\n\n if 'label' in examples:\n # In all cases, rename the column to labels because the model will expect that.\n result['labels'] = examples['label']\n return result\n\n raw_datasets = load_dataset('glue', task_name, cache_dir=cache_dir)\n for key in list(raw_datasets.keys()):\n if 'test' in key:\n raw_datasets.pop(key)\n\n processed_datasets = raw_datasets.map(preprocess_function, batched=True,\n remove_columns=raw_datasets['train'].column_names)\n\n train_dataset = processed_datasets['train']\n if task_name == 'mnli':\n validation_datasets = {\n 'validation_matched': processed_datasets['validation_matched'],\n 'validation_mismatched': processed_datasets['validation_mismatched']\n }\n else:\n validation_datasets = {\n 'validation': processed_datasets['validation']\n }\n\n return train_dataset, validation_datasets"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Prepare the trainer, note that the ``Trainer`` class is wrapped by ``nni.trace``.\n\n"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"collapsed": false
},
"outputs": [],
"source": [
"def prepare_traced_trainer(model, task_name, load_best_model_at_end=False):\n is_regression = task_name == 'stsb'\n metric = load_metric('glue', task_name)\n\n def compute_metrics(p: EvalPrediction):\n preds = p.predictions[0] if isinstance(p.predictions, tuple) else p.predictions\n preds = np.squeeze(preds) if is_regression else np.argmax(preds, axis=1)\n result = metric.compute(predictions=preds, references=p.label_ids)\n result['default'] = result.get('f1', result.get('accuracy', 0.))\n return result\n\n tokenizer = BertTokenizerFast.from_pretrained('bert-base-uncased')\n train_dataset, validation_datasets = prepare_datasets(task_name, tokenizer, None)\n merged_validation_dataset = ConcatDataset([d for d in validation_datasets.values()])\n data_collator = DataCollatorWithPadding(tokenizer)\n training_args = TrainingArguments(output_dir='./output/trainer',\n do_train=True,\n do_eval=True,\n evaluation_strategy='steps',\n per_device_train_batch_size=32,\n per_device_eval_batch_size=32,\n num_train_epochs=3,\n dataloader_num_workers=12,\n learning_rate=3e-5,\n save_strategy='steps',\n save_total_limit=1,\n metric_for_best_model='default',\n load_best_model_at_end=load_best_model_at_end,\n disable_tqdm=True,\n optim='adamw_torch',\n seed=1024)\n trainer = nni.trace(Trainer)(model=model,\n args=training_args,\n data_collator=data_collator,\n train_dataset=train_dataset,\n eval_dataset=merged_validation_dataset,\n tokenizer=tokenizer,\n compute_metrics=compute_metrics,)\n return trainer"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"If the finetuned model is existed, directly load it.\nIf the finetuned model is not existed, train the pretrained model with the trainer.\n\n"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"collapsed": false
},
"outputs": [],
"source": [
"def build_finetuning_model(task_name: str, state_dict_path: str):\n model = build_model('bert-base-uncased', task_name)\n if Path(state_dict_path).exists():\n model.load_state_dict(torch.load(state_dict_path))\n else:\n trainer = prepare_traced_trainer(model, task_name, True)\n trainer.train()\n torch.save(model.state_dict(), state_dict_path)\n return model\n\n\nif not skip_exec:\n Path('./output/bert_finetuned').mkdir(exist_ok=True, parents=True)\n build_finetuning_model(task_name, f'./output/bert_finetuned/{task_name}.bin')"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"The following code creates distillers for distillation.\n\n"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"collapsed": false
},
"outputs": [],
"source": [
"from nni.contrib.compression.distillation import DynamicLayerwiseDistiller, Adaptive1dLayerwiseDistiller\nfrom nni.contrib.compression.utils import TransformersEvaluator"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Dynamic distillation is suitable for the situation where the distillation states dimension of the student and the teacher match.\nA student state can try to distill on multiple teacher states, and finally select the teacher state with the smallest distillation loss as the target for distillation.\n\nIn this tutorial, dynamic distillation is applied before speedup the embedding pruning.\n\n"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"collapsed": false
},
"outputs": [],
"source": [
"def dynamic_distiller(student_model: BertForSequenceClassification, teacher_model: BertForSequenceClassification,\n student_trainer: Trainer):\n layer_num = len(student_model.bert.encoder.layer)\n config_list = [{\n 'op_names': [f'bert.encoder.layer.{i}'],\n 'link': [f'bert.encoder.layer.{j}' for j in range(i, layer_num)],\n 'lambda': 0.9,\n 'apply_method': 'mse',\n } for i in range(layer_num)]\n config_list.append({\n 'op_names': ['classifier'],\n 'link': ['classifier'],\n 'lambda': 0.9,\n 'apply_method': 'kl',\n })\n\n evaluator = TransformersEvaluator(student_trainer)\n\n def teacher_predict(batch, teacher_model):\n return teacher_model(**batch)\n\n return DynamicLayerwiseDistiller(student_model, config_list, evaluator, teacher_model, teacher_predict, origin_loss_lambda=0.1)\n\n\ndef dynamic_distillation(student_model: BertForSequenceClassification, teacher_model: BertForSequenceClassification,\n max_steps: int | None, max_epochs: int | None):\n student_trainer = prepare_traced_trainer(student_model, task_name, True)\n\n ori_teacher_device = teacher_model.device\n training = teacher_model.training\n teacher_model.to(student_trainer.args.device).eval()\n\n distiller = dynamic_distiller(student_model, teacher_model, student_trainer)\n distiller.compress(max_steps, max_epochs)\n distiller.unwrap_model()\n\n teacher_model.to(ori_teacher_device).train(training)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Adapt distillation is applied after pruning embedding layers.\nThe hidden states dimension will mismatch between student model and teacher model after pruning embedding layers,\nthen adapt distiller will add a linear layer for each distillation module pair to align dimension.\nFor example, pruning hidden dimension from 768 to 384, then for each student transformer block,\nwill add a ``Linear(in_features=384, out_features=768)`` for shifting dimention 384 to 768,\naligned with the teacher model transformer block output.\n\n"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"collapsed": false
},
"outputs": [],
"source": [
"def adapt_distiller(student_model: BertForSequenceClassification, teacher_model: BertForSequenceClassification,\n student_trainer: Trainer):\n layer_num = len(student_model.bert.encoder.layer)\n config_list = [{\n 'op_names': [f'bert.encoder.layer.{i}'],\n 'lambda': 0.9,\n 'apply_method': 'mse',\n } for i in range(layer_num)]\n config_list.append({\n 'op_names': ['classifier'],\n 'link': ['classifier'],\n 'lambda': 0.9,\n 'apply_method': 'kl',\n })\n\n evaluator = TransformersEvaluator(student_trainer)\n\n def teacher_predict(batch, teacher_model):\n return teacher_model(**batch)\n\n return Adaptive1dLayerwiseDistiller(student_model, config_list, evaluator, teacher_model, teacher_predict, origin_loss_lambda=0.1)\n\n\ndef adapt_distillation(student_model: BertForSequenceClassification, teacher_model: BertForSequenceClassification,\n max_steps: int | None, max_epochs: int | None):\n student_trainer = prepare_traced_trainer(student_model, task_name, True)\n\n ori_teacher_device = teacher_model.device\n training = teacher_model.training\n teacher_model.to(student_trainer.args.device).eval()\n\n distiller = adapt_distiller(student_model, teacher_model, student_trainer)\n dummy_input = (torch.randint(0, 10000, [8, 128]), torch.randint(0, 2, [8, 128]), torch.randint(0, 2, [8, 128]))\n dummy_input = [_.to(student_trainer.args.device) for _ in dummy_input]\n distiller.track_forward(*dummy_input)\n\n distiller.compress(max_steps, max_epochs)\n distiller.unwrap_model()\n\n teacher_model.to(ori_teacher_device).train(training)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### Pruning Attention Layers\n\nHere using ``MovementPruner`` to generate block sparse masks. Choosing ``64 x 64`` block is because the head width is 64,\nthis is a kind of coarse grained between head pruning and finegrained pruning, also you can have a try with ``64 x 32``,\n``32 x 32`` or any other granularity here.\n\nWe use ``sparse_threshold`` instead of ``sparse_ratio`` here to apply adaptive sparse allocation.\n``sparse_threshold`` here is a float number between 0. and 1., but its value has little effect on the final sparse ratio.\nIf you want a more sparse model, you could set a larger ``regular_scale`` in ``MovementPruner``.\nYou could refer to the experiment results to choose a appropriate ``regular_scale`` you like.\n\n"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"collapsed": false
},
"outputs": [],
"source": [
"from nni.contrib.compression.pruning import MovementPruner\nfrom nni.compression.pytorch.speedup.v2 import ModelSpeedup\nfrom nni.compression.pytorch.speedup.v2.external_replacer import TransformersAttentionReplacer\n\n\ndef pruning_attn():\n Path('./output/bert_finetuned/').mkdir(parents=True, exist_ok=True)\n model = build_finetuning_model(task_name, f'./output/bert_finetuned/{task_name}.bin')\n trainer = prepare_traced_trainer(model, task_name)\n evaluator = TransformersEvaluator(trainer)\n\n config_list = [{\n 'op_types': ['Linear'],\n 'op_names_re': ['bert\\.encoder\\.layer\\.[0-9]*\\.attention\\.*'],\n 'sparse_threshold': 0.1,\n 'granularity': [64, 64]\n }]\n\n pruner = MovementPruner(model, config_list, evaluator, warmup_step=9000, cooldown_begin_step=36000, regular_scale=10)\n pruner.compress(None, 4)\n pruner.unwrap_model()\n\n masks = pruner.get_masks()\n Path('./output/pruning/').mkdir(parents=True, exist_ok=True)\n torch.save(masks, './output/pruning/attn_masks.pth')\n torch.save(model, './output/pruning/attn_masked_model.pth')\n\n\nif not skip_exec:\n pruning_attn()"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"We apply head pruning during the speedup stage, if the head is fully masked it will be pruned,\nif the header is partially masked, it will be restored.\n\n"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"collapsed": false
},
"outputs": [],
"source": [
"def speedup_attn():\n model = torch.load('./output/pruning/attn_masked_model.pth', map_location='cpu')\n masks = torch.load('./output/pruning/attn_masks.pth', map_location='cpu')\n dummy_input = (torch.randint(0, 10000, [8, 128]), torch.randint(0, 2, [8, 128]), torch.randint(0, 2, [8, 128]))\n replacer = TransformersAttentionReplacer(model)\n ModelSpeedup(model, dummy_input, masks, customized_replacers=[replacer]).speedup_model()\n\n # finetuning\n teacher_model = build_finetuning_model('mnli', f'./output/bert_finetuned/{task_name}.bin')\n dynamic_distillation(model, teacher_model, None, 3)\n torch.save(model, './output/pruning/attn_pruned_model.pth')\n\n\nif not skip_exec:\n speedup_attn()"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### Pruning Feed Forward Layers\n\nHere using ``TaylorPruner`` for pruning feed forward layers,\nand the sparse ratio related to the pruned head number in the same transformer block.\nThe more heads are pruned, the higher the sparse ratio is set for feed forward layers.\n\nNote that ``TaylorPruner`` has no schedule sparse ratio function,\nso we use ``AGPPruner`` to schedule the sparse ratio to achieve better pruning performance.\n\n"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"collapsed": false
},
"outputs": [],
"source": [
"from nni.contrib.compression.pruning import TaylorPruner, AGPPruner\nfrom transformers.models.bert.modeling_bert import BertLayer\n\n\ndef pruning_ffn():\n model: BertForSequenceClassification = torch.load('./output/pruning/attn_pruned_model.pth')\n teacher_model: BertForSequenceClassification = build_finetuning_model('mnli', f'./output/bert_finetuned/{task_name}.bin')\n # create ffn config list, here simply use a linear function related to the number of retained heads to determine the sparse ratio\n config_list = []\n for name, module in model.named_modules():\n if isinstance(module, BertLayer):\n retained_head_num = module.attention.self.num_attention_heads\n ori_head_num = len(module.attention.pruned_heads) + retained_head_num\n ffn_sparse_ratio = 1 - retained_head_num / ori_head_num / 2\n config_list.append({'op_names': [f'{name}.intermediate.dense'], 'sparse_ratio': ffn_sparse_ratio})\n\n trainer = prepare_traced_trainer(model, task_name)\n teacher_model.eval().to(trainer.args.device)\n # create a distiller for restoring the accuracy\n distiller = dynamic_distiller(model, teacher_model, trainer)\n # fusion compress: TaylorPruner + DynamicLayerwiseDistiller\n taylor_pruner = TaylorPruner.from_compressor(distiller, config_list, 1000)\n # fusion compress: AGPPruner(TaylorPruner) + DynamicLayerwiseDistiller\n agp_pruner = AGPPruner(taylor_pruner, 1000, 36)\n agp_pruner.compress(None, 3)\n agp_pruner.unwrap_model()\n distiller.unwrap_teacher_model()\n\n masks = agp_pruner.get_masks()\n Path('./output/pruning/').mkdir(parents=True, exist_ok=True)\n torch.save(masks, './output/pruning/ffn_masks.pth')\n torch.save(model, './output/pruning/ffn_masked_model.pth')\n\n\nif not skip_exec:\n pruning_ffn()"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Speedup the feed forward layers.\n\n"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"collapsed": false
},
"outputs": [],
"source": [
"def speedup_ffn():\n model = torch.load('./output/pruning/ffn_masked_model.pth', map_location='cpu')\n masks = torch.load('./output/pruning/ffn_masks.pth', map_location='cpu')\n dummy_input = (torch.randint(0, 10000, [8, 128]), torch.randint(0, 2, [8, 128]), torch.randint(0, 2, [8, 128]))\n ModelSpeedup(model, dummy_input, masks).speedup_model()\n\n # finetuning\n teacher_model = build_finetuning_model('mnli', f'./output/bert_finetuned/{task_name}.bin')\n dynamic_distillation(model, teacher_model, None, 3)\n torch.save(model, './output/pruning/ffn_pruned_model.pth')\n\n\nif not skip_exec:\n speedup_ffn()"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### Pruning Embedding Layers\n\nWe want to simulate the pruning effect better, so we register the output mask setting for ``BertAttention`` and ``BertOutput``.\nThe output masks can be generated and applied after register the setting template for them.\n\n"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"collapsed": false
},
"outputs": [],
"source": [
"from nni.contrib.compression.base.setting import PruningSetting\n\noutput_align_setting = {\n '_output_': {\n 'align': {\n 'module_name': None,\n 'target_name': 'weight',\n 'dims': [0],\n },\n 'apply_method': 'mul',\n }\n}\nPruningSetting.register('BertAttention', output_align_setting)\nPruningSetting.register('BertOutput', output_align_setting)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Similar to prune feed forward layers, we also use ``AGPPruner + TaylorPruner + DynamicLayerwiseDistiller`` here.\nFor the better pruning effect simulation, set output ``align`` mask generation in ``config_list``,\nthen the relevant layers will generate its own output masks according to the embedding masks.\n\n"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"collapsed": false
},
"outputs": [],
"source": [
"def pruning_embedding():\n model: BertForSequenceClassification = torch.load('./output/pruning/ffn_pruned_model.pth')\n teacher_model: BertForSequenceClassification = build_finetuning_model('mnli', f'./output/bert_finetuned/{task_name}.bin')\n\n sparse_ratio = 0.5\n config_list = [{\n 'op_types': ['Embedding'],\n 'op_names_re': ['bert\\.embeddings.*'],\n 'sparse_ratio': sparse_ratio,\n 'dependency_group_id': 1,\n 'granularity': [-1, 1],\n }, {\n 'op_names_re': ['bert\\.encoder\\.layer\\.[0-9]*\\.attention$',\n 'bert\\.encoder\\.layer\\.[0-9]*\\.output$'],\n 'target_names': ['_output_'],\n 'target_settings': {\n '_output_': {\n 'align': {\n 'module_name': 'bert.embeddings.word_embeddings',\n 'target_name': 'weight',\n 'dims': [1],\n }\n }\n }\n }, {\n 'op_names_re': ['bert\\.encoder\\.layer\\.[0-9]*\\.attention.output.dense',\n 'bert\\.encoder\\.layer\\.[0-9]*\\.output.dense'],\n 'target_names': ['weight'],\n 'target_settings': {\n 'weight': {\n 'granularity': 'out_channel',\n 'align': {\n 'module_name': 'bert.embeddings.word_embeddings',\n 'target_name': 'weight',\n 'dims': [1],\n }\n }\n }\n }]\n\n trainer = prepare_traced_trainer(model, task_name)\n teacher_model.eval().to(trainer.args.device)\n distiller = dynamic_distiller(model, teacher_model, trainer)\n taylor_pruner = TaylorPruner.from_compressor(distiller, config_list, 1000)\n agp_pruner = AGPPruner(taylor_pruner, 1000, 36)\n agp_pruner.compress(None, 3)\n agp_pruner.unwrap_model()\n distiller.unwrap_teacher_model()\n\n masks = agp_pruner.get_masks()\n Path('./output/pruning/').mkdir(parents=True, exist_ok=True)\n torch.save(masks, './output/pruning/embedding_masks.pth')\n torch.save(model, './output/pruning/embedding_masked_model.pth')\n\n\nif not skip_exec:\n pruning_embedding()"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Speedup the embedding layers.\n\n"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"collapsed": false
},
"outputs": [],
"source": [
"def speedup_embedding():\n model = torch.load('./output/pruning/embedding_masked_model.pth', map_location='cpu')\n masks = torch.load('./output/pruning/embedding_masks.pth', map_location='cpu')\n dummy_input = (torch.randint(0, 10000, [8, 128]), torch.randint(0, 2, [8, 128]), torch.randint(0, 2, [8, 128]))\n ModelSpeedup(model, dummy_input, masks).speedup_model()\n\n # finetuning\n teacher_model = build_finetuning_model('mnli', f'./output/bert_finetuned/{task_name}.bin')\n adapt_distillation(model, teacher_model, None, 4)\n torch.save(model, './output/pruning/embedding_pruned_model.pth')\n\n\nif not skip_exec:\n speedup_embedding()"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### Evaluation\n\nEvaluate the pruned model size and accuracy.\n\n"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"collapsed": false
},
"outputs": [],
"source": [
"def evaluate_pruned_model():\n model: BertForSequenceClassification = torch.load('./output/pruning/embedding_pruned_model.pth')\n trainer = prepare_traced_trainer(model, task_name)\n metric = trainer.evaluate()\n pruned_num_params = sum(param.numel() for param in model.parameters()) + sum(buffer.numel() for buffer in model.buffers())\n\n model = build_finetuning_model(task_name, f'./output/bert_finetuned/{task_name}.bin')\n ori_num_params = sum(param.numel() for param in model.parameters()) + sum(buffer.numel() for buffer in model.buffers())\n print(f'Metric: {metric}\\nSparsity: {1 - pruned_num_params / ori_num_params}')\n\n\nif not skip_exec:\n evaluate_pruned_model()"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Results\n\n.. list-table:: Prune Bert-base-uncased on MNLI\n :header-rows: 1\n :widths: auto\n\n * - Total Sparsity\n - Embedding Sparsity\n - Encoder Sparsity\n - Pooler Sparsity\n - Acc. (m/mm avg.)\n * - 0.%\n - 0.%\n - 0.%\n - 0.%\n - 84.95%\n * - 57.76%\n - 33.33% (15.89M)\n - 64.78% (29.96M)\n - 33.33% (0.39M)\n - 84.42%\n * - 68.31% (34.70M)\n - 50.00% (11.92M)\n - 73.57% (22.48M)\n - 50.00% (0.30M)\n - 83.33%\n * - 70.95% (31.81M)\n - 33.33% (15.89M)\n - 81.75% (15.52M)\n - 33.33% (0.39M)\n - 83.79%\n * - 78.20% (23.86M)\n - 50.00% (11.92M)\n - 86.31% (11.65M)\n - 50.00% (0.30M)\n - 82.53%\n * - 81.65% (20.12M)\n - 50.00% (11.92M)\n - 90.71% (7.90M)\n - 50.00% (0.30M)\n - 82.08%\n * - 84.32% (17.17M)\n - 50.00% (11.92M)\n - 94.18% (4.95M)\n - 50.00% (0.30M)\n - 81.35%\n\n"
]
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.9.16"
}
},
"nbformat": 4,
"nbformat_minor": 0
}

617
docs/source/tutorials/new_pruning_bert_glue.py generated Normal file

@ -0,0 +1,617 @@
"""
Pruning Bert on Task MNLI
=========================
This is a new tutorial on pruning transformers in NNI v3.0 (`old tutorial <https://nni.readthedocs.io/en/v2.9/tutorials/pruning_bert_glue.html>`__).
The main differences from the previous tutorial are that this one integrates NNI's fusion compression feature (pruning + distillation),
uses a new, more powerful and stable pruning speedup tool,
and additionally prunes the model's hidden dimensions, including the embedding layers, which greatly reduces the model size.
At the same time, the huggingface `transformers.Trainer <https://huggingface.co/docs/transformers/main_classes/trainer>`__ is used in this tutorial
to reduce the burden of writing training and evaluation logic.
Workable Pruning Process
------------------------
The whole pruning process is divided into three steps:
1. pruning attention layers,
2. pruning feed forward layers,
3. pruning embedding layers.
In each step, the pruner is first used for simulated pruning to generate masks corresponding to the module pruning targets (weight, input, output).
After that comes the speedup stage: sparsity propagation explores the global redundancy implied by the local masks,
and the original model is then rewritten into a smaller one by replacing its sub-modules.
Model compression pairs naturally with distillation,
so distillers are also used in this tutorial to help restore the model accuracy.
Experiment
----------
Preparations
^^^^^^^^^^^^
The preparation mainly consists of setting up the transformers trainer and the model.
This is generally consistent with the preparation required to train a Bert model normally.
The only difference is that the ``transformers.Trainer`` needs to be wrapped by ``nni.trace`` to trace its init arguments,
because NNI needs to re-create the trainer during training-aware pruning and distillation.
.. note::
Please set ``skip_exec`` to ``False`` to run this tutorial. ``skip_exec`` is ``True`` by default so that the documentation can be generated without executing the code.
"""
from __future__ import annotations
from pathlib import Path
import numpy as np
import torch
from torch.utils.data import ConcatDataset
import nni
from datasets import load_dataset, load_metric
from transformers import BertTokenizerFast, DataCollatorWithPadding, BertForSequenceClassification, EvalPrediction
from transformers.trainer import Trainer
from transformers.training_args import TrainingArguments
skip_exec = True
# %%
# Set the downstream task name here; you can replace it with any other task in GLUE.
task_name = 'mnli'
# %%
# Here we use BertForSequenceClassification as the base model for demonstration.
# If you want to prune another kind of transformer model, replace the base model here.
def build_model(pretrained_model_name_or_path: str, task_name: str):
is_regression = task_name == 'stsb'
num_labels = 1 if is_regression else (3 if task_name == 'mnli' else 2)
model = BertForSequenceClassification.from_pretrained(pretrained_model_name_or_path, num_labels=num_labels)
return model
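# %%
# As a quick sanity check (a hedged sketch, not part of the original tutorial flow), you can build the
# MNLI model and confirm it exposes 3 output labels; ``num_parameters`` is a standard ``transformers`` helper.
if not skip_exec:
    check_model = build_model('bert-base-uncased', task_name)
    print(check_model.config.num_labels)                       # expected: 3 for MNLI
    print(f'{check_model.num_parameters() / 1e6:.2f} M parameters')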
# %%
# Prepare the GLUE train & validation datasets. If the task has multiple validation datasets, concatenate them with ``ConcatDataset``.
def prepare_datasets(task_name: str, tokenizer: BertTokenizerFast, cache_dir: str):
task_to_keys = {
'cola': ('sentence', None),
'mnli': ('premise', 'hypothesis'),
'mrpc': ('sentence1', 'sentence2'),
'qnli': ('question', 'sentence'),
'qqp': ('question1', 'question2'),
'rte': ('sentence1', 'sentence2'),
'sst2': ('sentence', None),
'stsb': ('sentence1', 'sentence2'),
'wnli': ('sentence1', 'sentence2'),
}
sentence1_key, sentence2_key = task_to_keys[task_name]
# used to preprocess the raw data
def preprocess_function(examples):
# Tokenize the texts
args = (
(examples[sentence1_key],) if sentence2_key is None else (examples[sentence1_key], examples[sentence2_key])
)
result = tokenizer(*args, padding=False, max_length=128, truncation=True)
if 'label' in examples:
# In all cases, rename the column to labels because the model will expect that.
result['labels'] = examples['label']
return result
raw_datasets = load_dataset('glue', task_name, cache_dir=cache_dir)
for key in list(raw_datasets.keys()):
if 'test' in key:
raw_datasets.pop(key)
processed_datasets = raw_datasets.map(preprocess_function, batched=True,
remove_columns=raw_datasets['train'].column_names)
train_dataset = processed_datasets['train']
if task_name == 'mnli':
validation_datasets = {
'validation_matched': processed_datasets['validation_matched'],
'validation_mismatched': processed_datasets['validation_mismatched']
}
else:
validation_datasets = {
'validation': processed_datasets['validation']
}
return train_dataset, validation_datasets
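# %%
# For example (a hedged usage sketch, assuming the GLUE data can be downloaded in your environment),
# the helper above can be exercised directly to inspect the dataset sizes before any pruning happens.
if not skip_exec:
    _tokenizer = BertTokenizerFast.from_pretrained('bert-base-uncased')
    _train, _vals = prepare_datasets(task_name, _tokenizer, None)
    print(len(_train), {k: len(v) for k, v in _vals.items()})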
# %%
# Prepare the trainer; note that the ``Trainer`` class is wrapped by ``nni.trace``.
def prepare_traced_trainer(model, task_name, load_best_model_at_end=False):
is_regression = task_name == 'stsb'
metric = load_metric('glue', task_name)
def compute_metrics(p: EvalPrediction):
preds = p.predictions[0] if isinstance(p.predictions, tuple) else p.predictions
preds = np.squeeze(preds) if is_regression else np.argmax(preds, axis=1)
result = metric.compute(predictions=preds, references=p.label_ids)
result['default'] = result.get('f1', result.get('accuracy', 0.))
return result
tokenizer = BertTokenizerFast.from_pretrained('bert-base-uncased')
train_dataset, validation_datasets = prepare_datasets(task_name, tokenizer, None)
merged_validation_dataset = ConcatDataset([d for d in validation_datasets.values()])
data_collator = DataCollatorWithPadding(tokenizer)
training_args = TrainingArguments(output_dir='./output/trainer',
do_train=True,
do_eval=True,
evaluation_strategy='steps',
per_device_train_batch_size=32,
per_device_eval_batch_size=32,
num_train_epochs=3,
dataloader_num_workers=12,
learning_rate=3e-5,
save_strategy='steps',
save_total_limit=1,
metric_for_best_model='default',
load_best_model_at_end=load_best_model_at_end,
disable_tqdm=True,
optim='adamw_torch',
seed=1024)
trainer = nni.trace(Trainer)(model=model,
args=training_args,
data_collator=data_collator,
train_dataset=train_dataset,
eval_dataset=merged_validation_dataset,
tokenizer=tokenizer,
compute_metrics=compute_metrics,)
return trainer
# %%
# If the finetuned model already exists, load it directly.
# Otherwise, train the pretrained model with the trainer.
def build_finetuning_model(task_name: str, state_dict_path: str):
model = build_model('bert-base-uncased', task_name)
if Path(state_dict_path).exists():
model.load_state_dict(torch.load(state_dict_path))
else:
trainer = prepare_traced_trainer(model, task_name, True)
trainer.train()
torch.save(model.state_dict(), state_dict_path)
return model
if not skip_exec:
Path('./output/bert_finetuned').mkdir(exist_ok=True, parents=True)
build_finetuning_model(task_name, f'./output/bert_finetuned/{task_name}.bin')
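# %%
# Optionally (a hedged sketch, assuming the finetuned checkpoint above exists), you can record the
# baseline accuracy before pruning; ``Trainer.evaluate`` returns a dict of metrics on the merged validation set.
if not skip_exec:
    baseline_model = build_finetuning_model(task_name, f'./output/bert_finetuned/{task_name}.bin')
    baseline_trainer = prepare_traced_trainer(baseline_model, task_name)
    print(baseline_trainer.evaluate())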
# %%
# The following code creates distillers for distillation.
from nni.contrib.compression.distillation import DynamicLayerwiseDistiller, Adaptive1dLayerwiseDistiller
from nni.contrib.compression.utils import TransformersEvaluator
# %%
# Dynamic distillation is suitable when the distillation state dimensions of the student and the teacher match.
# A student state can try to distill on multiple teacher states, and the teacher state with the smallest distillation loss is finally selected as the distillation target.
#
# In this tutorial, dynamic distillation is applied before the embedding-pruning speedup.
def dynamic_distiller(student_model: BertForSequenceClassification, teacher_model: BertForSequenceClassification,
student_trainer: Trainer):
layer_num = len(student_model.bert.encoder.layer)
config_list = [{
'op_names': [f'bert.encoder.layer.{i}'],
'link': [f'bert.encoder.layer.{j}' for j in range(i, layer_num)],
'lambda': 0.9,
'apply_method': 'mse',
} for i in range(layer_num)]
config_list.append({
'op_names': ['classifier'],
'link': ['classifier'],
'lambda': 0.9,
'apply_method': 'kl',
})
evaluator = TransformersEvaluator(student_trainer)
def teacher_predict(batch, teacher_model):
return teacher_model(**batch)
return DynamicLayerwiseDistiller(student_model, config_list, evaluator, teacher_model, teacher_predict, origin_loss_lambda=0.1)
def dynamic_distillation(student_model: BertForSequenceClassification, teacher_model: BertForSequenceClassification,
max_steps: int | None, max_epochs: int | None):
student_trainer = prepare_traced_trainer(student_model, task_name, True)
ori_teacher_device = teacher_model.device
training = teacher_model.training
teacher_model.to(student_trainer.args.device).eval()
distiller = dynamic_distiller(student_model, teacher_model, student_trainer)
distiller.compress(max_steps, max_epochs)
distiller.unwrap_model()
teacher_model.to(ori_teacher_device).train(training)
# %%
# Adaptive distillation is applied after pruning the embedding layers.
# The hidden state dimensions of the student and teacher models no longer match once the embedding layers are pruned,
# so the adaptive distiller adds a linear layer for each distillation module pair to align the dimensions.
# For example, if the hidden dimension is pruned from 768 to 384, a ``Linear(in_features=384, out_features=768)`` is added
# to each student transformer block to project dimension 384 back to 768,
# aligning it with the corresponding teacher transformer block output.
def adapt_distiller(student_model: BertForSequenceClassification, teacher_model: BertForSequenceClassification,
student_trainer: Trainer):
layer_num = len(student_model.bert.encoder.layer)
config_list = [{
'op_names': [f'bert.encoder.layer.{i}'],
'lambda': 0.9,
'apply_method': 'mse',
} for i in range(layer_num)]
config_list.append({
'op_names': ['classifier'],
'link': ['classifier'],
'lambda': 0.9,
'apply_method': 'kl',
})
evaluator = TransformersEvaluator(student_trainer)
def teacher_predict(batch, teacher_model):
return teacher_model(**batch)
return Adaptive1dLayerwiseDistiller(student_model, config_list, evaluator, teacher_model, teacher_predict, origin_loss_lambda=0.1)
def adapt_distillation(student_model: BertForSequenceClassification, teacher_model: BertForSequenceClassification,
max_steps: int | None, max_epochs: int | None):
student_trainer = prepare_traced_trainer(student_model, task_name, True)
ori_teacher_device = teacher_model.device
training = teacher_model.training
teacher_model.to(student_trainer.args.device).eval()
distiller = adapt_distiller(student_model, teacher_model, student_trainer)
dummy_input = (torch.randint(0, 10000, [8, 128]), torch.randint(0, 2, [8, 128]), torch.randint(0, 2, [8, 128]))
dummy_input = [_.to(student_trainer.args.device) for _ in dummy_input]
distiller.track_forward(*dummy_input)
distiller.compress(max_steps, max_epochs)
distiller.unwrap_model()
teacher_model.to(ori_teacher_device).train(training)
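# %%
# To make the dimension alignment above concrete, here is a minimal standalone sketch (plain PyTorch,
# not the NNI API): a student hidden state of width 384 is projected back to the teacher width 768
# before a layer-wise distillation loss is computed.
if not skip_exec:
    student_hidden = torch.randn(8, 128, 384)      # pruned student block output
    teacher_hidden = torch.randn(8, 128, 768)      # teacher block output
    align = torch.nn.Linear(in_features=384, out_features=768)
    loss = torch.nn.functional.mse_loss(align(student_hidden), teacher_hidden)
    print(loss.item())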
# %%
# Pruning Attention Layers
# ^^^^^^^^^^^^^^^^^^^^^^^^
#
# Here we use ``MovementPruner`` to generate block-sparse masks. A ``64 x 64`` block is chosen because the head width is 64;
# this granularity sits between head pruning and fine-grained pruning. You can also try ``64 x 32``,
# ``32 x 32`` or any other granularity here.
#
# We use ``sparse_threshold`` instead of ``sparse_ratio`` here to apply adaptive sparse allocation.
# ``sparse_threshold`` is a float between 0 and 1, but its value has little effect on the final sparse ratio.
# If you want a sparser model, set a larger ``regular_scale`` in ``MovementPruner``.
# You can refer to the experiment results to choose an appropriate ``regular_scale``.
from nni.contrib.compression.pruning import MovementPruner
from nni.compression.pytorch.speedup.v2 import ModelSpeedup
from nni.compression.pytorch.speedup.v2.external_replacer import TransformersAttentionReplacer
def pruning_attn():
Path('./output/bert_finetuned/').mkdir(parents=True, exist_ok=True)
model = build_finetuning_model(task_name, f'./output/bert_finetuned/{task_name}.bin')
trainer = prepare_traced_trainer(model, task_name)
evaluator = TransformersEvaluator(trainer)
config_list = [{
'op_types': ['Linear'],
'op_names_re': ['bert\.encoder\.layer\.[0-9]*\.attention\.*'],
'sparse_threshold': 0.1,
'granularity': [64, 64]
}]
pruner = MovementPruner(model, config_list, evaluator, warmup_step=9000, cooldown_begin_step=36000, regular_scale=10)
pruner.compress(None, 4)
pruner.unwrap_model()
masks = pruner.get_masks()
Path('./output/pruning/').mkdir(parents=True, exist_ok=True)
torch.save(masks, './output/pruning/attn_masks.pth')
torch.save(model, './output/pruning/attn_masked_model.pth')
if not skip_exec:
pruning_attn()
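# %%
# A small optional sketch (not part of the original tutorial) for inspecting how sparse the generated
# attention masks are. It assumes the saved masks form a nested dict of module name -> target name -> mask tensor,
# which is the structure the code above writes to ``attn_masks.pth``.
def report_mask_sparsity(mask_path: str):
    masks = torch.load(mask_path, map_location='cpu')
    for module_name, target_masks in masks.items():
        for target_name, mask in target_masks.items():
            zero_ratio = (mask == 0).float().mean().item()
            print(f'{module_name}.{target_name}: {zero_ratio:.2%} masked')
if not skip_exec:
    report_mask_sparsity('./output/pruning/attn_masks.pth')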
# %%
# We apply head pruning during the speedup stage: if a head is fully masked it will be pruned,
# and if a head is only partially masked, it will be restored.
def speedup_attn():
model = torch.load('./output/pruning/attn_masked_model.pth', map_location='cpu')
masks = torch.load('./output/pruning/attn_masks.pth', map_location='cpu')
dummy_input = (torch.randint(0, 10000, [8, 128]), torch.randint(0, 2, [8, 128]), torch.randint(0, 2, [8, 128]))
replacer = TransformersAttentionReplacer(model)
ModelSpeedup(model, dummy_input, masks, customized_replacers=[replacer]).speedup_model()
# finetuning
teacher_model = build_finetuning_model('mnli', f'./output/bert_finetuned/{task_name}.bin')
dynamic_distillation(model, teacher_model, None, 3)
torch.save(model, './output/pruning/attn_pruned_model.pth')
if not skip_exec:
speedup_attn()
# %%
# Pruning Feed Forward Layers
# ^^^^^^^^^^^^^^^^^^^^^^^^^^^
#
# Here we use ``TaylorPruner`` to prune the feed forward layers,
# with the sparse ratio tied to the number of pruned heads in the same transformer block:
# the more heads are pruned, the higher the sparse ratio set for the feed forward layers.
#
# Note that ``TaylorPruner`` cannot schedule the sparse ratio by itself,
# so we use ``AGPPruner`` to schedule the sparse ratio and achieve better pruning performance.
from nni.contrib.compression.pruning import TaylorPruner, AGPPruner
from transformers.models.bert.modeling_bert import BertLayer
def pruning_ffn():
model: BertForSequenceClassification = torch.load('./output/pruning/attn_pruned_model.pth')
teacher_model: BertForSequenceClassification = build_finetuning_model('mnli', f'./output/bert_finetuned/{task_name}.bin')
# create the ffn config list; here we simply use a linear function of the number of retained heads to determine the sparse ratio
config_list = []
for name, module in model.named_modules():
if isinstance(module, BertLayer):
retained_head_num = module.attention.self.num_attention_heads
ori_head_num = len(module.attention.pruned_heads) + retained_head_num
ffn_sparse_ratio = 1 - retained_head_num / ori_head_num / 2
config_list.append({'op_names': [f'{name}.intermediate.dense'], 'sparse_ratio': ffn_sparse_ratio})
trainer = prepare_traced_trainer(model, task_name)
teacher_model.eval().to(trainer.args.device)
# create a distiller for restoring the accuracy
distiller = dynamic_distiller(model, teacher_model, trainer)
# fusion compress: TaylorPruner + DynamicLayerwiseDistiller
taylor_pruner = TaylorPruner.from_compressor(distiller, config_list, 1000)
# fusion compress: AGPPruner(TaylorPruner) + DynamicLayerwiseDistiller
agp_pruner = AGPPruner(taylor_pruner, 1000, 36)
agp_pruner.compress(None, 3)
agp_pruner.unwrap_model()
distiller.unwrap_teacher_model()
masks = agp_pruner.get_masks()
Path('./output/pruning/').mkdir(parents=True, exist_ok=True)
torch.save(masks, './output/pruning/ffn_masks.pth')
torch.save(model, './output/pruning/ffn_masked_model.pth')
if not skip_exec:
pruning_ffn()
# %%
# Speedup the feed forward layers.
def speedup_ffn():
model = torch.load('./output/pruning/ffn_masked_model.pth', map_location='cpu')
masks = torch.load('./output/pruning/ffn_masks.pth', map_location='cpu')
dummy_input = (torch.randint(0, 10000, [8, 128]), torch.randint(0, 2, [8, 128]), torch.randint(0, 2, [8, 128]))
ModelSpeedup(model, dummy_input, masks).speedup_model()
# finetuning
teacher_model = build_finetuning_model('mnli', f'./output/bert_finetuned/{task_name}.bin')
dynamic_distillation(model, teacher_model, None, 3)
torch.save(model, './output/pruning/ffn_pruned_model.pth')
if not skip_exec:
speedup_ffn()
# %%
# Pruning Embedding Layers
# ^^^^^^^^^^^^^^^^^^^^^^^^
#
# To simulate the pruning effect more faithfully, we register an output mask setting for ``BertAttention`` and ``BertOutput``.
# Once the setting template is registered, output masks can be generated and applied for these modules.
from nni.contrib.compression.base.setting import PruningSetting
output_align_setting = {
'_output_': {
'align': {
'module_name': None,
'target_name': 'weight',
'dims': [0],
},
'apply_method': 'mul',
}
}
PruningSetting.register('BertAttention', output_align_setting)
PruningSetting.register('BertOutput', output_align_setting)
# %%
# Similar to pruning the feed forward layers, we also use ``AGPPruner + TaylorPruner + DynamicLayerwiseDistiller`` here.
# For a better simulation of the pruning effect, set output ``align`` mask generation in the ``config_list``;
# the relevant layers will then generate their own output masks according to the embedding masks.
def pruning_embedding():
model: BertForSequenceClassification = torch.load('./output/pruning/ffn_pruned_model.pth')
teacher_model: BertForSequenceClassification = build_finetuning_model('mnli', f'./output/bert_finetuned/{task_name}.bin')
sparse_ratio = 0.5
config_list = [{
'op_types': ['Embedding'],
'op_names_re': ['bert\.embeddings.*'],
'sparse_ratio': sparse_ratio,
'dependency_group_id': 1,
'granularity': [-1, 1],
}, {
'op_names_re': ['bert\.encoder\.layer\.[0-9]*\.attention$',
'bert\.encoder\.layer\.[0-9]*\.output$'],
'target_names': ['_output_'],
'target_settings': {
'_output_': {
'align': {
'module_name': 'bert.embeddings.word_embeddings',
'target_name': 'weight',
'dims': [1],
}
}
}
}, {
'op_names_re': ['bert\.encoder\.layer\.[0-9]*\.attention.output.dense',
'bert\.encoder\.layer\.[0-9]*\.output.dense'],
'target_names': ['weight'],
'target_settings': {
'weight': {
'granularity': 'out_channel',
'align': {
'module_name': 'bert.embeddings.word_embeddings',
'target_name': 'weight',
'dims': [1],
}
}
}
}]
trainer = prepare_traced_trainer(model, task_name)
teacher_model.eval().to(trainer.args.device)
distiller = dynamic_distiller(model, teacher_model, trainer)
taylor_pruner = TaylorPruner.from_compressor(distiller, config_list, 1000)
agp_pruner = AGPPruner(taylor_pruner, 1000, 36)
agp_pruner.compress(None, 3)
agp_pruner.unwrap_model()
distiller.unwrap_teacher_model()
masks = agp_pruner.get_masks()
Path('./output/pruning/').mkdir(parents=True, exist_ok=True)
torch.save(masks, './output/pruning/embedding_masks.pth')
torch.save(model, './output/pruning/embedding_masked_model.pth')
if not skip_exec:
pruning_embedding()
# %%
# Speedup the embedding layers.
def speedup_embedding():
model = torch.load('./output/pruning/embedding_masked_model.pth', map_location='cpu')
masks = torch.load('./output/pruning/embedding_masks.pth', map_location='cpu')
dummy_input = (torch.randint(0, 10000, [8, 128]), torch.randint(0, 2, [8, 128]), torch.randint(0, 2, [8, 128]))
ModelSpeedup(model, dummy_input, masks).speedup_model()
# finetuning
teacher_model = build_finetuning_model('mnli', f'./output/bert_finetuned/{task_name}.bin')
adapt_distillation(model, teacher_model, None, 4)
torch.save(model, './output/pruning/embedding_pruned_model.pth')
if not skip_exec:
speedup_embedding()
# %%
# Evaluation
# ^^^^^^^^^^
#
# Evaluate the pruned model size and accuracy.
def evaluate_pruned_model():
model: BertForSequenceClassification = torch.load('./output/pruning/embedding_pruned_model.pth')
trainer = prepare_traced_trainer(model, task_name)
metric = trainer.evaluate()
pruned_num_params = sum(param.numel() for param in model.parameters()) + sum(buffer.numel() for buffer in model.buffers())
model = build_finetuning_model(task_name, f'./output/bert_finetuned/{task_name}.bin')
ori_num_params = sum(param.numel() for param in model.parameters()) + sum(buffer.numel() for buffer in model.buffers())
print(f'Metric: {metric}\nSparsity: {1 - pruned_num_params / ori_num_params}')
if not skip_exec:
evaluate_pruned_model()
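# %%
# As an optional extra check (a hedged sketch, not part of the reported results), you can compare the
# wall-clock CPU inference time of the pruned model with the original finetuned model on a dummy batch.
import time
def compare_latency():
    dummy_input = (torch.randint(0, 10000, [8, 128]), torch.randint(0, 2, [8, 128]), torch.randint(0, 2, [8, 128]))
    pruned = torch.load('./output/pruning/embedding_pruned_model.pth', map_location='cpu').eval()
    original = build_finetuning_model(task_name, f'./output/bert_finetuned/{task_name}.bin').cpu().eval()
    with torch.no_grad():
        for tag, model in (('original', original), ('pruned', pruned)):
            start = time.time()
            for _ in range(10):
                model(*dummy_input)
            print(f'{tag}: {(time.time() - start) / 10:.3f} s / batch')
if not skip_exec:
    compare_latency()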
# %%
# Results
# -------
#
# .. list-table:: Prune Bert-base-uncased on MNLI
# :header-rows: 1
# :widths: auto
#
# * - Total Sparsity
# - Embedding Sparsity
# - Encoder Sparsity
# - Pooler Sparsity
# - Acc. (m/mm avg.)
# * - 0.%
# - 0.%
# - 0.%
# - 0.%
# - 84.95%
# * - 57.76%
# - 33.33% (15.89M)
# - 64.78% (29.96M)
# - 33.33% (0.39M)
# - 84.42%
# * - 68.31% (34.70M)
# - 50.00% (11.92M)
# - 73.57% (22.48M)
# - 50.00% (0.30M)
# - 83.33%
# * - 70.95% (31.81M)
# - 33.33% (15.89M)
# - 81.75% (15.52M)
# - 33.33% (0.39M)
# - 83.79%
# * - 78.20% (23.86M)
# - 50.00% (11.92M)
# - 86.31% (11.65M)
# - 50.00% (0.30M)
# - 82.53%
# * - 81.65% (20.12M)
# - 50.00% (11.92M)
# - 90.71% (7.90M)
# - 50.00% (0.30M)
# - 82.08%
# * - 84.32% (17.17M)
# - 50.00% (11.92M)
# - 94.18% (4.95M)
# - 50.00% (0.30M)
# - 81.35%
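#
# The numbers in parentheses appear to be the parameters remaining in each part. As a rough
# consistency check (assuming standard bert-base-uncased component sizes of about 23.84M embedding,
# 85.05M encoder and 0.59M pooler parameters, roughly 109.5M in total): for the 78.20% row,
# 11.92M + 11.65M + 0.30M gives about 23.87M remaining parameters, and 1 - 23.87M / 109.5M is about
# 0.782, which matches the reported total sparsity.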

docs/source/tutorials/new_pruning_bert_glue.py.md5 generated Normal file
View file

@ -0,0 +1 @@
3e81f00f13fab8cfc204a0baef7d075e

docs/source/tutorials/new_pruning_bert_glue.rst generated Normal file
View file

@ -0,0 +1,884 @@
.. DO NOT EDIT.
.. THIS FILE WAS AUTOMATICALLY GENERATED BY SPHINX-GALLERY.
.. TO MAKE CHANGES, EDIT THE SOURCE PYTHON FILE:
.. "tutorials/new_pruning_bert_glue.py"
.. LINE NUMBERS ARE GIVEN BELOW.
.. only:: html
.. note::
:class: sphx-glr-download-link-note
Click :ref:`here <sphx_glr_download_tutorials_new_pruning_bert_glue.py>`
to download the full example code
.. rst-class:: sphx-glr-example-title
.. _sphx_glr_tutorials_new_pruning_bert_glue.py:
Pruning Bert on Task MNLI
=========================
This is a new tutorial on pruning transformers in NNI v3.0 (`old tutorial <https://nni.readthedocs.io/en/v2.9/tutorials/pruning_bert_glue.html>`__).
The main differences from the previous tutorial are that this one integrates NNI's fusion compression feature (pruning + distillation),
uses a new, more powerful and stable pruning speedup tool,
and additionally prunes the whole model hidden dimension, which greatly reduces the model size (by pruning the embedding layers).
At the same time, the huggingface `transformers.Trainer <https://huggingface.co/docs/transformers/main_classes/trainer>`__ is used in this tutorial
to reduce the burden of writing training and evaluation logic.
Workable Pruning Process
------------------------
The whole pruning process is divided into three steps:
1. pruning attention layers,
2. pruning feed forward layers,
3. pruning embedding layers.
In each step, the pruner is first used for simulated pruning to generate masks corresponding to the module pruning targets (weight, input, output).
After that comes the speedup stage: sparsity propagation is used to explore the global redundancy caused by the local masks,
and the original model is then turned into a smaller one by replacing its sub-modules.
Compressing a model combines naturally with distillation,
so in this tutorial distillers are also used to help restore the model accuracy.
Experiment
----------
Preparations
^^^^^^^^^^^^
The preparation mainly includes preparing the transformers trainer and the model.
This is generally consistent with the preparation required to train a Bert model normally.
The only difference is that the ``transformers.Trainer`` needs to be wrapped by ``nni.trace`` to trace its init arguments,
because NNI needs to re-create the trainer during training-aware pruning and distillation.
.. note::
Please set ``skip_exec`` to ``False`` to run this tutorial. ``skip_exec`` is ``True`` by default only for generating the documentation.
.. GENERATED FROM PYTHON SOURCE LINES 45-64
.. code-block:: default
from __future__ import annotations
from pathlib import Path
import numpy as np
import torch
from torch.utils.data import ConcatDataset
import nni
from datasets import load_dataset, load_metric
from transformers import BertTokenizerFast, DataCollatorWithPadding, BertForSequenceClassification, EvalPrediction
from transformers.trainer import Trainer
from transformers.training_args import TrainingArguments
skip_exec = True
.. GENERATED FROM PYTHON SOURCE LINES 65-66
Set the downstream task name here; you could replace it with any other GLUE task.
.. GENERATED FROM PYTHON SOURCE LINES 66-69
.. code-block:: default
task_name = 'mnli'
.. GENERATED FROM PYTHON SOURCE LINES 70-72
Here BertForSequenceClassification is used as the base model for demonstration.
If you want to prune another kind of transformer model, you could replace the base model here.
.. GENERATED FROM PYTHON SOURCE LINES 72-80
.. code-block:: default
def build_model(pretrained_model_name_or_path: str, task_name: str):
is_regression = task_name == 'stsb'
num_labels = 1 if is_regression else (3 if task_name == 'mnli' else 2)
model = BertForSequenceClassification.from_pretrained(pretrained_model_name_or_path, num_labels=num_labels)
return model
.. GENERATED FROM PYTHON SOURCE LINES 81-82
Prepare the GLUE train & validation datasets; if the task has multiple validation datasets, concatenate them with ``ConcatDataset``.
.. GENERATED FROM PYTHON SOURCE LINES 82-132
.. code-block:: default
def prepare_datasets(task_name: str, tokenizer: BertTokenizerFast, cache_dir: str):
task_to_keys = {
'cola': ('sentence', None),
'mnli': ('premise', 'hypothesis'),
'mrpc': ('sentence1', 'sentence2'),
'qnli': ('question', 'sentence'),
'qqp': ('question1', 'question2'),
'rte': ('sentence1', 'sentence2'),
'sst2': ('sentence', None),
'stsb': ('sentence1', 'sentence2'),
'wnli': ('sentence1', 'sentence2'),
}
sentence1_key, sentence2_key = task_to_keys[task_name]
# used to preprocess the raw data
def preprocess_function(examples):
# Tokenize the texts
args = (
(examples[sentence1_key],) if sentence2_key is None else (examples[sentence1_key], examples[sentence2_key])
)
result = tokenizer(*args, padding=False, max_length=128, truncation=True)
if 'label' in examples:
# In all cases, rename the column to labels because the model will expect that.
result['labels'] = examples['label']
return result
raw_datasets = load_dataset('glue', task_name, cache_dir=cache_dir)
for key in list(raw_datasets.keys()):
if 'test' in key:
raw_datasets.pop(key)
processed_datasets = raw_datasets.map(preprocess_function, batched=True,
remove_columns=raw_datasets['train'].column_names)
train_dataset = processed_datasets['train']
if task_name == 'mnli':
validation_datasets = {
'validation_matched': processed_datasets['validation_matched'],
'validation_mismatched': processed_datasets['validation_mismatched']
}
else:
validation_datasets = {
'validation': processed_datasets['validation']
}
return train_dataset, validation_datasets
.. GENERATED FROM PYTHON SOURCE LINES 133-134
Prepare the trainer; note that the ``Trainer`` class is wrapped by ``nni.trace``.
.. GENERATED FROM PYTHON SOURCE LINES 134-177
.. code-block:: default
def prepare_traced_trainer(model, task_name, load_best_model_at_end=False):
is_regression = task_name == 'stsb'
metric = load_metric('glue', task_name)
def compute_metrics(p: EvalPrediction):
preds = p.predictions[0] if isinstance(p.predictions, tuple) else p.predictions
preds = np.squeeze(preds) if is_regression else np.argmax(preds, axis=1)
result = metric.compute(predictions=preds, references=p.label_ids)
result['default'] = result.get('f1', result.get('accuracy', 0.))
return result
tokenizer = BertTokenizerFast.from_pretrained('bert-base-uncased')
train_dataset, validation_datasets = prepare_datasets(task_name, tokenizer, None)
merged_validation_dataset = ConcatDataset([d for d in validation_datasets.values()])
data_collator = DataCollatorWithPadding(tokenizer)
training_args = TrainingArguments(output_dir='./output/trainer',
do_train=True,
do_eval=True,
evaluation_strategy='steps',
per_device_train_batch_size=32,
per_device_eval_batch_size=32,
num_train_epochs=3,
dataloader_num_workers=12,
learning_rate=3e-5,
save_strategy='steps',
save_total_limit=1,
metric_for_best_model='default',
load_best_model_at_end=load_best_model_at_end,
disable_tqdm=True,
optim='adamw_torch',
seed=1024)
trainer = nni.trace(Trainer)(model=model,
args=training_args,
data_collator=data_collator,
train_dataset=train_dataset,
eval_dataset=merged_validation_dataset,
tokenizer=tokenizer,
compute_metrics=compute_metrics,)
return trainer
.. GENERATED FROM PYTHON SOURCE LINES 178-180
If the finetuned model already exists, load it directly.
Otherwise, finetune the pretrained model with the trainer.
.. GENERATED FROM PYTHON SOURCE LINES 180-198
.. code-block:: default
def build_finetuning_model(task_name: str, state_dict_path: str):
model = build_model('bert-base-uncased', task_name)
if Path(state_dict_path).exists():
model.load_state_dict(torch.load(state_dict_path))
else:
trainer = prepare_traced_trainer(model, task_name, True)
trainer.train()
torch.save(model.state_dict(), state_dict_path)
return model
if not skip_exec:
Path('./output/bert_finetuned').mkdir(exist_ok=True, parents=True)
build_finetuning_model(task_name, f'./output/bert_finetuned/{task_name}.bin')
.. GENERATED FROM PYTHON SOURCE LINES 199-200
The following code creates distillers for distillation.
.. GENERATED FROM PYTHON SOURCE LINES 200-205
.. code-block:: default
from nni.contrib.compression.distillation import DynamicLayerwiseDistiller, Adaptive1dLayerwiseDistiller
from nni.contrib.compression.utils import TransformersEvaluator
.. GENERATED FROM PYTHON SOURCE LINES 206-210
Dynamic distillation is suitable when the distillation state dimensions of the student and the teacher match.
A student state can try to distill on multiple teacher states, and the teacher state with the smallest distillation loss is finally selected as the distillation target.
In this tutorial, dynamic distillation is applied before speeding up the embedding pruning.
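For example, with the ``link`` setting used below, the hidden states of student layer ``bert.encoder.layer.0`` can be compared with teacher layers 0 through 11, and the teacher layer giving the smallest loss is taken as its distillation target at each step.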
.. GENERATED FROM PYTHON SOURCE LINES 210-250
.. code-block:: default
def dynamic_distiller(student_model: BertForSequenceClassification, teacher_model: BertForSequenceClassification,
student_trainer: Trainer):
layer_num = len(student_model.bert.encoder.layer)
config_list = [{
'op_names': [f'bert.encoder.layer.{i}'],
'link': [f'bert.encoder.layer.{j}' for j in range(i, layer_num)],
'lambda': 0.9,
'apply_method': 'mse',
} for i in range(layer_num)]
config_list.append({
'op_names': ['classifier'],
'link': ['classifier'],
'lambda': 0.9,
'apply_method': 'kl',
})
evaluator = TransformersEvaluator(student_trainer)
def teacher_predict(batch, teacher_model):
return teacher_model(**batch)
return DynamicLayerwiseDistiller(student_model, config_list, evaluator, teacher_model, teacher_predict, origin_loss_lambda=0.1)
def dynamic_distillation(student_model: BertForSequenceClassification, teacher_model: BertForSequenceClassification,
max_steps: int | None, max_epochs: int | None):
student_trainer = prepare_traced_trainer(student_model, task_name, True)
ori_teacher_device = teacher_model.device
training = teacher_model.training
teacher_model.to(student_trainer.args.device).eval()
distiller = dynamic_distiller(student_model, teacher_model, student_trainer)
distiller.compress(max_steps, max_epochs)
distiller.unwrap_model()
teacher_model.to(ori_teacher_device).train(training)
.. GENERATED FROM PYTHON SOURCE LINES 251-257
Adaptive distillation is applied after pruning the embedding layers.
The hidden state dimensions of the student model and the teacher model will mismatch after pruning the embedding layers,
so the adaptive distiller adds a linear layer for each distillation module pair to align the dimensions.
For example, if the hidden dimension is pruned from 768 to 384, a ``Linear(in_features=384, out_features=768)``
is added for each student transformer block to project dimension 384 back to 768,
aligned with the output of the corresponding teacher transformer block.
.. GENERATED FROM PYTHON SOURCE LINES 257-301
.. code-block:: default
def adapt_distiller(student_model: BertForSequenceClassification, teacher_model: BertForSequenceClassification,
student_trainer: Trainer):
layer_num = len(student_model.bert.encoder.layer)
config_list = [{
'op_names': [f'bert.encoder.layer.{i}'],
'lambda': 0.9,
'apply_method': 'mse',
} for i in range(layer_num)]
config_list.append({
'op_names': ['classifier'],
'link': ['classifier'],
'lambda': 0.9,
'apply_method': 'kl',
})
evaluator = TransformersEvaluator(student_trainer)
def teacher_predict(batch, teacher_model):
return teacher_model(**batch)
return Adaptive1dLayerwiseDistiller(student_model, config_list, evaluator, teacher_model, teacher_predict, origin_loss_lambda=0.1)
def adapt_distillation(student_model: BertForSequenceClassification, teacher_model: BertForSequenceClassification,
max_steps: int | None, max_epochs: int | None):
student_trainer = prepare_traced_trainer(student_model, task_name, True)
ori_teacher_device = teacher_model.device
training = teacher_model.training
teacher_model.to(student_trainer.args.device).eval()
distiller = adapt_distiller(student_model, teacher_model, student_trainer)
dummy_input = (torch.randint(0, 10000, [8, 128]), torch.randint(0, 2, [8, 128]), torch.randint(0, 2, [8, 128]))
dummy_input = [_.to(student_trainer.args.device) for _ in dummy_input]
distiller.track_forward(*dummy_input)
distiller.compress(max_steps, max_epochs)
distiller.unwrap_model()
teacher_model.to(ori_teacher_device).train(training)
.. GENERATED FROM PYTHON SOURCE LINES 302-313
Pruning Attention Layers
^^^^^^^^^^^^^^^^^^^^^^^^
Here ``MovementPruner`` is used to generate block sparse masks. A ``64 x 64`` block is chosen because the head width is 64;
this granularity sits between head pruning and fine-grained pruning. You can also try ``64 x 32``,
``32 x 32`` or any other granularity here.
We use ``sparse_threshold`` instead of ``sparse_ratio`` here to apply adaptive sparsity allocation.
``sparse_threshold`` is a float between 0. and 1., but its value has little effect on the final sparse ratio.
If you want a sparser model, you could set a larger ``regular_scale`` in ``MovementPruner``.
You could refer to the experiment results to choose an appropriate ``regular_scale``.
.. GENERATED FROM PYTHON SOURCE LINES 313-347
.. code-block:: default
from nni.contrib.compression.pruning import MovementPruner
from nni.compression.pytorch.speedup.v2 import ModelSpeedup
from nni.compression.pytorch.speedup.v2.external_replacer import TransformersAttentionReplacer
def pruning_attn():
Path('./output/bert_finetuned/').mkdir(parents=True, exist_ok=True)
model = build_finetuning_model(task_name, f'./output/bert_finetuned/{task_name}.bin')
trainer = prepare_traced_trainer(model, task_name)
evaluator = TransformersEvaluator(trainer)
config_list = [{
'op_types': ['Linear'],
'op_names_re': ['bert\.encoder\.layer\.[0-9]*\.attention\.*'],
'sparse_threshold': 0.1,
'granularity': [64, 64]
}]
pruner = MovementPruner(model, config_list, evaluator, warmup_step=9000, cooldown_begin_step=36000, regular_scale=10)
pruner.compress(None, 4)
pruner.unwrap_model()
masks = pruner.get_masks()
Path('./output/pruning/').mkdir(parents=True, exist_ok=True)
torch.save(masks, './output/pruning/attn_masks.pth')
torch.save(model, './output/pruning/attn_masked_model.pth')
if not skip_exec:
pruning_attn()
.. GENERATED FROM PYTHON SOURCE LINES 348-350
We apply head pruning during the speedup stage: if a head is fully masked, it will be pruned;
if a head is only partially masked, it will be restored.
.. GENERATED FROM PYTHON SOURCE LINES 350-369
.. code-block:: default
def speedup_attn():
model = torch.load('./output/pruning/attn_masked_model.pth', map_location='cpu')
masks = torch.load('./output/pruning/attn_masks.pth', map_location='cpu')
dummy_input = (torch.randint(0, 10000, [8, 128]), torch.randint(0, 2, [8, 128]), torch.randint(0, 2, [8, 128]))
replacer = TransformersAttentionReplacer(model)
ModelSpeedup(model, dummy_input, masks, customized_replacers=[replacer]).speedup_model()
# finetuning
teacher_model = build_finetuning_model('mnli', f'./output/bert_finetuned/{task_name}.bin')
dynamic_distillation(model, teacher_model, None, 3)
torch.save(model, './output/pruning/attn_pruned_model.pth')
if not skip_exec:
speedup_attn()
.. GENERATED FROM PYTHON SOURCE LINES 370-379
Pruning Feed Forward Layers
^^^^^^^^^^^^^^^^^^^^^^^^^^^
Here ``TaylorPruner`` is used for pruning the feed forward layers,
and the sparse ratio is related to the number of pruned heads in the same transformer block:
the more heads are pruned, the higher the sparse ratio set for the feed forward layers.
Note that ``TaylorPruner`` cannot schedule the sparse ratio on its own,
so we wrap it with ``AGPPruner`` to schedule the sparse ratio and achieve better pruning performance.
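For example, with the formula used below (``1 - retained_head_num / ori_head_num / 2``), a block that keeps 6 of its 12 heads gets a feed forward sparse ratio of ``1 - 6 / 12 / 2 = 0.75``, while a block that keeps all 12 heads gets 0.5.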
.. GENERATED FROM PYTHON SOURCE LINES 379-419
.. code-block:: default
from nni.contrib.compression.pruning import TaylorPruner, AGPPruner
from transformers.models.bert.modeling_bert import BertLayer
def pruning_ffn():
model: BertForSequenceClassification = torch.load('./output/pruning/attn_pruned_model.pth')
teacher_model: BertForSequenceClassification = build_finetuning_model('mnli', f'./output/bert_finetuned/{task_name}.bin')
# create ffn config list, here simply use a linear function related to the number of retained heads to determine the sparse ratio
config_list = []
for name, module in model.named_modules():
if isinstance(module, BertLayer):
retained_head_num = module.attention.self.num_attention_heads
ori_head_num = len(module.attention.pruned_heads) + retained_head_num
ffn_sparse_ratio = 1 - retained_head_num / ori_head_num / 2
config_list.append({'op_names': [f'{name}.intermediate.dense'], 'sparse_ratio': ffn_sparse_ratio})
trainer = prepare_traced_trainer(model, task_name)
teacher_model.eval().to(trainer.args.device)
# create a distiller for restoring the accuracy
distiller = dynamic_distiller(model, teacher_model, trainer)
# fusion compress: TaylorPruner + DynamicLayerwiseDistiller
taylor_pruner = TaylorPruner.from_compressor(distiller, config_list, 1000)
# fusion compress: AGPPruner(TaylorPruner) + DynamicLayerwiseDistiller
agp_pruner = AGPPruner(taylor_pruner, 1000, 36)
agp_pruner.compress(None, 3)
agp_pruner.unwrap_model()
distiller.unwrap_teacher_model()
masks = agp_pruner.get_masks()
Path('./output/pruning/').mkdir(parents=True, exist_ok=True)
torch.save(masks, './output/pruning/ffn_masks.pth')
torch.save(model, './output/pruning/ffn_masked_model.pth')
if not skip_exec:
pruning_ffn()
.. GENERATED FROM PYTHON SOURCE LINES 420-421
Speed up the feed forward layers.
.. GENERATED FROM PYTHON SOURCE LINES 421-439
.. code-block:: default
def speedup_ffn():
model = torch.load('./output/pruning/ffn_masked_model.pth', map_location='cpu')
masks = torch.load('./output/pruning/ffn_masks.pth', map_location='cpu')
dummy_input = (torch.randint(0, 10000, [8, 128]), torch.randint(0, 2, [8, 128]), torch.randint(0, 2, [8, 128]))
ModelSpeedup(model, dummy_input, masks).speedup_model()
# finetuning
teacher_model = build_finetuning_model('mnli', f'./output/bert_finetuned/{task_name}.bin')
dynamic_distillation(model, teacher_model, None, 3)
torch.save(model, './output/pruning/ffn_pruned_model.pth')
if not skip_exec:
speedup_ffn()
.. GENERATED FROM PYTHON SOURCE LINES 440-445
Pruning Embedding Layers
^^^^^^^^^^^^^^^^^^^^^^^^
To simulate the pruning effect better, we register the output mask setting for ``BertAttention`` and ``BertOutput``.
The output masks can be generated and applied after registering the setting template for them.
.. GENERATED FROM PYTHON SOURCE LINES 445-463
.. code-block:: default
from nni.contrib.compression.base.setting import PruningSetting
output_align_setting = {
'_output_': {
'align': {
'module_name': None,
'target_name': 'weight',
'dims': [0],
},
'apply_method': 'mul',
}
}
PruningSetting.register('BertAttention', output_align_setting)
PruningSetting.register('BertOutput', output_align_setting)
.. GENERATED FROM PYTHON SOURCE LINES 464-467
Similar to pruning the feed forward layers, we also use ``AGPPruner + TaylorPruner + DynamicLayerwiseDistiller`` here.
To better simulate the pruning effect, output ``align`` mask generation is set in the ``config_list``,
so the relevant layers will generate their own output masks according to the embedding masks.
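For reference, in the ``config_list`` below the embedding weights are pruned with ``granularity`` ``[-1, 1]``, which presumably corresponds to column-wise blocks, i.e. pruning whole hidden dimensions of the embedding weight.
The attention/output projection weights are pruned ``out_channel``-wise and aligned to the word embedding weight mask along ``dims: [1]``, so every layer that produces hidden states stays consistent with the pruned embedding width.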
.. GENERATED FROM PYTHON SOURCE LINES 467-528
.. code-block:: default
def pruning_embedding():
model: BertForSequenceClassification = torch.load('./output/pruning/ffn_pruned_model.pth')
teacher_model: BertForSequenceClassification = build_finetuning_model('mnli', f'./output/bert_finetuned/{task_name}.bin')
sparse_ratio = 0.5
config_list = [{
'op_types': ['Embedding'],
'op_names_re': ['bert\.embeddings.*'],
'sparse_ratio': sparse_ratio,
'dependency_group_id': 1,
'granularity': [-1, 1],
}, {
'op_names_re': ['bert\.encoder\.layer\.[0-9]*\.attention$',
'bert\.encoder\.layer\.[0-9]*\.output$'],
'target_names': ['_output_'],
'target_settings': {
'_output_': {
'align': {
'module_name': 'bert.embeddings.word_embeddings',
'target_name': 'weight',
'dims': [1],
}
}
}
}, {
'op_names_re': ['bert\.encoder\.layer\.[0-9]*\.attention.output.dense',
'bert\.encoder\.layer\.[0-9]*\.output.dense'],
'target_names': ['weight'],
'target_settings': {
'weight': {
'granularity': 'out_channel',
'align': {
'module_name': 'bert.embeddings.word_embeddings',
'target_name': 'weight',
'dims': [1],
}
}
}
}]
trainer = prepare_traced_trainer(model, task_name)
teacher_model.eval().to(trainer.args.device)
distiller = dynamic_distiller(model, teacher_model, trainer)
taylor_pruner = TaylorPruner.from_compressor(distiller, config_list, 1000)
agp_pruner = AGPPruner(taylor_pruner, 1000, 36)
agp_pruner.compress(None, 3)
agp_pruner.unwrap_model()
distiller.unwrap_teacher_model()
masks = agp_pruner.get_masks()
Path('./output/pruning/').mkdir(parents=True, exist_ok=True)
torch.save(masks, './output/pruning/embedding_masks.pth')
torch.save(model, './output/pruning/embedding_masked_model.pth')
if not skip_exec:
pruning_embedding()
.. GENERATED FROM PYTHON SOURCE LINES 529-530
Speed up the embedding layers.
.. GENERATED FROM PYTHON SOURCE LINES 530-548
.. code-block:: default
def speedup_embedding():
model = torch.load('./output/pruning/embedding_masked_model.pth', map_location='cpu')
masks = torch.load('./output/pruning/embedding_masks.pth', map_location='cpu')
dummy_input = (torch.randint(0, 10000, [8, 128]), torch.randint(0, 2, [8, 128]), torch.randint(0, 2, [8, 128]))
ModelSpeedup(model, dummy_input, masks).speedup_model()
# finetuning
teacher_model = build_finetuning_model('mnli', f'./output/bert_finetuned/{task_name}.bin')
adapt_distillation(model, teacher_model, None, 4)
torch.save(model, './output/pruning/embedding_pruned_model.pth')
if not skip_exec:
speedup_embedding()
.. GENERATED FROM PYTHON SOURCE LINES 549-553
Evaluation
^^^^^^^^^^
Evaluate the pruned model size and accuracy.
.. GENERATED FROM PYTHON SOURCE LINES 553-570
.. code-block:: default
def evaluate_pruned_model():
model: BertForSequenceClassification = torch.load('./output/pruning/embedding_pruned_model.pth')
trainer = prepare_traced_trainer(model, task_name)
metric = trainer.evaluate()
pruned_num_params = sum(param.numel() for param in model.parameters()) + sum(buffer.numel() for buffer in model.buffers())
model = build_finetuning_model(task_name, f'./output/bert_finetuned/{task_name}.bin')
ori_num_params = sum(param.numel() for param in model.parameters()) + sum(buffer.numel() for buffer in model.buffers())
print(f'Metric: {metric}\nSparsity: {1 - pruned_num_params / ori_num_params}')
if not skip_exec:
evaluate_pruned_model()
.. GENERATED FROM PYTHON SOURCE LINES 571-618
Results
-------
.. list-table:: Prune Bert-base-uncased on MNLI
:header-rows: 1
:widths: auto
* - Total Sparsity
- Embedding Sparsity
- Encoder Sparsity
- Pooler Sparsity
- Acc. (m/mm avg.)
* - 0.%
- 0.%
- 0.%
- 0.%
- 84.95%
* - 57.76%
- 33.33% (15.89M)
- 64.78% (29.96M)
- 33.33% (0.39M)
- 84.42%
* - 68.31% (34.70M)
- 50.00% (11.92M)
- 73.57% (22.48M)
- 50.00% (0.30M)
- 83.33%
* - 70.95% (31.81M)
- 33.33% (15.89M)
- 81.75% (15.52M)
- 33.33% (0.39M)
- 83.79%
* - 78.20% (23.86M)
- 50.00% (11.92M)
- 86.31% (11.65M)
- 50.00% (0.30M)
- 82.53%
* - 81.65% (20.12M)
- 50.00% (11.92M)
- 90.71% (7.90M)
- 50.00% (0.30M)
- 82.08%
* - 84.32% (17.17M)
- 50.00% (11.92M)
- 94.18% (4.95M)
- 50.00% (0.30M)
- 81.35%
.. rst-class:: sphx-glr-timing
**Total running time of the script:** ( 0 minutes 1.990 seconds)
.. _sphx_glr_download_tutorials_new_pruning_bert_glue.py:
.. only:: html
.. container:: sphx-glr-footer sphx-glr-footer-example
.. container:: sphx-glr-download sphx-glr-download-python
:download:`Download Python source code: new_pruning_bert_glue.py <new_pruning_bert_glue.py>`
.. container:: sphx-glr-download sphx-glr-download-jupyter
:download:`Download Jupyter notebook: new_pruning_bert_glue.ipynb <new_pruning_bert_glue.ipynb>`
.. only:: html
.. rst-class:: sphx-glr-signature
`Gallery generated by Sphinx-Gallery <https://sphinx-gallery.github.io>`_

Binary data
docs/source/tutorials/new_pruning_bert_glue_codeobj.pickle generated Normal file

Binary file not shown.

docs/source/tutorials/quantization_speedup.py generated
View file

@ -33,11 +33,11 @@ As TensorRT has supported post-training quantization, directly leveraging this f
import torch
import torchvision
import torchvision.transforms as transforms
def prepare_data_loaders(data_path, batch_size, datatype='train'):
def prepare_data_loaders(data_path, batch_size):
normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406],
std=[0.229, 0.224, 0.225])
dataset = torchvision.datasets.ImageNet(
data_path, split=datatype,
data_path, split="train",
transform=transforms.Compose([
transforms.Resize(256),
transforms.CenterCrop(224),
@ -90,20 +90,26 @@ def test_accelerated_model(engine, data_loader, neval_batches):
cnt = 0
total_time = 0
for image, target in data_loader:
start_time = time.time()
output, time_span = engine.inference(image)
print('time: ', time_span)
infer_time = time.time() - start_time
print('time: ', time_span, infer_time)
total_time += time_span
start_time = time.time()
output = output.view(-1, 1000)
cnt += 1
acc1, acc5 = accuracy(output, target, topk=(1, 5))
top1.update(acc1[0], image.size(0))
top5.update(acc5[0], image.size(0))
rest_time = time.time() - start_time
print('rest time: ', rest_time)
if cnt >= neval_batches:
break
print('inference time: ', total_time / neval_batches)
return top1, top5
data_loader = prepare_data_loaders(data_path, batch_size=64, datatype='test')
data_loader = prepare_data_loaders(data_path, batch_size=64)
top1, top5 = test_accelerated_model(engine, data_loader, neval_batches=32)
print('Accuracy of mode #1: ', top1, top5)
@ -167,6 +173,6 @@ model.eval()
engine = ModelSpeedupTensorRT(model, input_shape=(64, 3, 224, 224), config=calibration_config)
engine.compress()
data_loader = prepare_data_loaders(data_path, batch_size=64, datatype='test')
data_loader = prepare_data_loaders(data_path, batch_size=64)
top1, top5 = test_accelerated_model(engine, data_loader, neval_batches=32)
print('Accuracy of mode #2: ', top1, top5)

docs/source/tutorials/sg_execution_times.rst generated
View file

@ -6,10 +6,10 @@
Computation times
=================
**00:29.846** total execution time for **tutorials** files:
**00:01.990** total execution time for **tutorials** files:
+-----------------------------------------------------------------------------------------------------+-----------+--------+
| :ref:`sphx_glr_tutorials_pruning_bert_glue.py` (``pruning_bert_glue.py``) | 00:29.846 | 0.0 MB |
| :ref:`sphx_glr_tutorials_new_pruning_bert_glue.py` (``new_pruning_bert_glue.py``) | 00:01.990 | 0.0 MB |
+-----------------------------------------------------------------------------------------------------+-----------+--------+
| :ref:`sphx_glr_tutorials_darts.py` (``darts.py``) | 00:00.000 | 0.0 MB |
+-----------------------------------------------------------------------------------------------------+-----------+--------+

examples/tutorials/.gitignore vendored
View file

@ -4,3 +4,4 @@ log/
lightning_logs
models/
pruning_log/
output*/

View file

@ -0,0 +1,617 @@
"""
Pruning Bert on Task MNLI
=========================
This is a new tutorial on pruning transformers in NNI v3.0 (`old tutorial <https://nni.readthedocs.io/en/v2.9/tutorials/pruning_bert_glue.html>`__).
The main differences from the previous tutorial are that this one integrates NNI's fusion compression feature (pruning + distillation),
uses a new, more powerful and stable pruning speedup tool,
and additionally prunes the whole model hidden dimension, which greatly reduces the model size (by pruning the embedding layers).
At the same time, the huggingface `transformers.Trainer <https://huggingface.co/docs/transformers/main_classes/trainer>`__ is used in this tutorial
to reduce the burden of writing training and evaluation logic.
Workable Pruning Process
------------------------
The whole pruning process is divided into three steps:
1. pruning attention layers,
2. pruning feed forward layers,
3. pruning embedding layers.
In each step, the pruner is first used for simulated pruning to generate masks corresponding to the module pruning targets (weight, input, output).
After that comes the speedup stage: sparsity propagation is used to explore the global redundancy caused by the local masks,
and the original model is then turned into a smaller one by replacing its sub-modules.
Compressing a model combines naturally with distillation,
so in this tutorial distillers are also used to help restore the model accuracy.
Experiment
----------
Preparations
^^^^^^^^^^^^
The preparation mainly includes preparing the transformers trainer and the model.
This is generally consistent with the preparation required to train a Bert model normally.
The only difference is that the ``transformers.Trainer`` needs to be wrapped by ``nni.trace`` to trace its init arguments,
because NNI needs to re-create the trainer during training-aware pruning and distillation.
.. note::
Please set ``skip_exec`` to ``False`` to run this tutorial. ``skip_exec`` is ``True`` by default only for generating the documentation.
"""
from __future__ import annotations
from pathlib import Path
import numpy as np
import torch
from torch.utils.data import ConcatDataset
import nni
from datasets import load_dataset, load_metric
from transformers import BertTokenizerFast, DataCollatorWithPadding, BertForSequenceClassification, EvalPrediction
from transformers.trainer import Trainer
from transformers.training_args import TrainingArguments
skip_exec = True
# %%
# Set the downstream task name here; you could replace it with any other GLUE task.
task_name = 'mnli'
# %%
# Here BertForSequenceClassification is used as the base model for demonstration.
# If you want to prune another kind of transformer model, you could replace the base model here.
def build_model(pretrained_model_name_or_path: str, task_name: str):
is_regression = task_name == 'stsb'
num_labels = 1 if is_regression else (3 if task_name == 'mnli' else 2)
model = BertForSequenceClassification.from_pretrained(pretrained_model_name_or_path, num_labels=num_labels)
return model
# %%
# Prepare the GLUE train & validation datasets; if the task has multiple validation datasets, concatenate them with ``ConcatDataset``.
def prepare_datasets(task_name: str, tokenizer: BertTokenizerFast, cache_dir: str):
task_to_keys = {
'cola': ('sentence', None),
'mnli': ('premise', 'hypothesis'),
'mrpc': ('sentence1', 'sentence2'),
'qnli': ('question', 'sentence'),
'qqp': ('question1', 'question2'),
'rte': ('sentence1', 'sentence2'),
'sst2': ('sentence', None),
'stsb': ('sentence1', 'sentence2'),
'wnli': ('sentence1', 'sentence2'),
}
sentence1_key, sentence2_key = task_to_keys[task_name]
# used to preprocess the raw data
def preprocess_function(examples):
# Tokenize the texts
args = (
(examples[sentence1_key],) if sentence2_key is None else (examples[sentence1_key], examples[sentence2_key])
)
result = tokenizer(*args, padding=False, max_length=128, truncation=True)
if 'label' in examples:
# In all cases, rename the column to labels because the model will expect that.
result['labels'] = examples['label']
return result
raw_datasets = load_dataset('glue', task_name, cache_dir=cache_dir)
for key in list(raw_datasets.keys()):
if 'test' in key:
raw_datasets.pop(key)
processed_datasets = raw_datasets.map(preprocess_function, batched=True,
remove_columns=raw_datasets['train'].column_names)
train_dataset = processed_datasets['train']
if task_name == 'mnli':
validation_datasets = {
'validation_matched': processed_datasets['validation_matched'],
'validation_mismatched': processed_datasets['validation_mismatched']
}
else:
validation_datasets = {
'validation': processed_datasets['validation']
}
return train_dataset, validation_datasets
# %%
# Prepare the trainer; note that the ``Trainer`` class is wrapped by ``nni.trace``.
def prepare_traced_trainer(model, task_name, load_best_model_at_end=False):
is_regression = task_name == 'stsb'
metric = load_metric('glue', task_name)
def compute_metrics(p: EvalPrediction):
preds = p.predictions[0] if isinstance(p.predictions, tuple) else p.predictions
preds = np.squeeze(preds) if is_regression else np.argmax(preds, axis=1)
result = metric.compute(predictions=preds, references=p.label_ids)
result['default'] = result.get('f1', result.get('accuracy', 0.))
return result
tokenizer = BertTokenizerFast.from_pretrained('bert-base-uncased')
train_dataset, validation_datasets = prepare_datasets(task_name, tokenizer, None)
merged_validation_dataset = ConcatDataset([d for d in validation_datasets.values()])
data_collator = DataCollatorWithPadding(tokenizer)
training_args = TrainingArguments(output_dir='./output/trainer',
do_train=True,
do_eval=True,
evaluation_strategy='steps',
per_device_train_batch_size=32,
per_device_eval_batch_size=32,
num_train_epochs=3,
dataloader_num_workers=12,
learning_rate=3e-5,
save_strategy='steps',
save_total_limit=1,
metric_for_best_model='default',
load_best_model_at_end=load_best_model_at_end,
disable_tqdm=True,
optim='adamw_torch',
seed=1024)
trainer = nni.trace(Trainer)(model=model,
args=training_args,
data_collator=data_collator,
train_dataset=train_dataset,
eval_dataset=merged_validation_dataset,
tokenizer=tokenizer,
compute_metrics=compute_metrics,)
return trainer
# %%
# If the finetuned model already exists, load it directly.
# Otherwise, finetune the pretrained model with the trainer.
def build_finetuning_model(task_name: str, state_dict_path: str):
model = build_model('bert-base-uncased', task_name)
if Path(state_dict_path).exists():
model.load_state_dict(torch.load(state_dict_path))
else:
trainer = prepare_traced_trainer(model, task_name, True)
trainer.train()
torch.save(model.state_dict(), state_dict_path)
return model
if not skip_exec:
Path('./output/bert_finetuned').mkdir(exist_ok=True, parents=True)
build_finetuning_model(task_name, f'./output/bert_finetuned/{task_name}.bin')
# %%
# The following code creates distillers for distillation.
from nni.contrib.compression.distillation import DynamicLayerwiseDistiller, Adaptive1dLayerwiseDistiller
from nni.contrib.compression.utils import TransformersEvaluator
# %%
# Dynamic distillation is suitable when the distillation state dimensions of the student and the teacher match.
# A student state can try to distill on multiple teacher states, and the teacher state with the smallest distillation loss is finally selected as the distillation target.
#
# In this tutorial, dynamic distillation is applied before speeding up the embedding pruning.
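# For example, with the ``link`` setting used below, the hidden states of student layer ``bert.encoder.layer.0`` can be
# compared with teacher layers 0 through 11, and the teacher layer giving the smallest loss is taken as its distillation
# target at each step.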
def dynamic_distiller(student_model: BertForSequenceClassification, teacher_model: BertForSequenceClassification,
student_trainer: Trainer):
layer_num = len(student_model.bert.encoder.layer)
config_list = [{
'op_names': [f'bert.encoder.layer.{i}'],
'link': [f'bert.encoder.layer.{j}' for j in range(i, layer_num)],
'lambda': 0.9,
'apply_method': 'mse',
} for i in range(layer_num)]
config_list.append({
'op_names': ['classifier'],
'link': ['classifier'],
'lambda': 0.9,
'apply_method': 'kl',
})
evaluator = TransformersEvaluator(student_trainer)
def teacher_predict(batch, teacher_model):
return teacher_model(**batch)
return DynamicLayerwiseDistiller(student_model, config_list, evaluator, teacher_model, teacher_predict, origin_loss_lambda=0.1)
def dynamic_distillation(student_model: BertForSequenceClassification, teacher_model: BertForSequenceClassification,
max_steps: int | None, max_epochs: int | None):
student_trainer = prepare_traced_trainer(student_model, task_name, True)
ori_teacher_device = teacher_model.device
training = teacher_model.training
teacher_model.to(student_trainer.args.device).eval()
distiller = dynamic_distiller(student_model, teacher_model, student_trainer)
distiller.compress(max_steps, max_epochs)
distiller.unwrap_model()
teacher_model.to(ori_teacher_device).train(training)
# %%
# Adaptive distillation is applied after pruning the embedding layers.
# The hidden state dimensions of the student model and the teacher model will mismatch after pruning the embedding layers,
# so the adaptive distiller adds a linear layer for each distillation module pair to align the dimensions.
# For example, if the hidden dimension is pruned from 768 to 384, a ``Linear(in_features=384, out_features=768)``
# is added for each student transformer block to project dimension 384 back to 768,
# aligned with the output of the corresponding teacher transformer block.
def adapt_distiller(student_model: BertForSequenceClassification, teacher_model: BertForSequenceClassification,
student_trainer: Trainer):
layer_num = len(student_model.bert.encoder.layer)
config_list = [{
'op_names': [f'bert.encoder.layer.{i}'],
'lambda': 0.9,
'apply_method': 'mse',
} for i in range(layer_num)]
config_list.append({
'op_names': ['classifier'],
'link': ['classifier'],
'lambda': 0.9,
'apply_method': 'kl',
})
evaluator = TransformersEvaluator(student_trainer)
def teacher_predict(batch, teacher_model):
return teacher_model(**batch)
return Adaptive1dLayerwiseDistiller(student_model, config_list, evaluator, teacher_model, teacher_predict, origin_loss_lambda=0.1)
def adapt_distillation(student_model: BertForSequenceClassification, teacher_model: BertForSequenceClassification,
max_steps: int | None, max_epochs: int | None):
student_trainer = prepare_traced_trainer(student_model, task_name, True)
ori_teacher_device = teacher_model.device
training = teacher_model.training
teacher_model.to(student_trainer.args.device).eval()
distiller = adapt_distiller(student_model, teacher_model, student_trainer)
dummy_input = (torch.randint(0, 10000, [8, 128]), torch.randint(0, 2, [8, 128]), torch.randint(0, 2, [8, 128]))
dummy_input = [_.to(student_trainer.args.device) for _ in dummy_input]
distiller.track_forward(*dummy_input)
distiller.compress(max_steps, max_epochs)
distiller.unwrap_model()
teacher_model.to(ori_teacher_device).train(training)
# %%
# Pruning Attention Layers
# ^^^^^^^^^^^^^^^^^^^^^^^^
#
# Here ``MovementPruner`` is used to generate block sparse masks. A ``64 x 64`` block is chosen because the head width is 64;
# this granularity sits between head pruning and fine-grained pruning. You can also try ``64 x 32``,
# ``32 x 32`` or any other granularity here.
#
# We use ``sparse_threshold`` instead of ``sparse_ratio`` here to apply adaptive sparsity allocation.
# ``sparse_threshold`` is a float between 0. and 1., but its value has little effect on the final sparse ratio.
# If you want a sparser model, you could set a larger ``regular_scale`` in ``MovementPruner``.
# You could refer to the experiment results to choose an appropriate ``regular_scale``.
from nni.contrib.compression.pruning import MovementPruner
from nni.compression.pytorch.speedup.v2 import ModelSpeedup
from nni.compression.pytorch.speedup.v2.external_replacer import TransformersAttentionReplacer
def pruning_attn():
Path('./output/bert_finetuned/').mkdir(parents=True, exist_ok=True)
model = build_finetuning_model(task_name, f'./output/bert_finetuned/{task_name}.bin')
trainer = prepare_traced_trainer(model, task_name)
evaluator = TransformersEvaluator(trainer)
config_list = [{
'op_types': ['Linear'],
'op_names_re': ['bert\.encoder\.layer\.[0-9]*\.attention\.*'],
'sparse_threshold': 0.1,
'granularity': [64, 64]
}]
pruner = MovementPruner(model, config_list, evaluator, warmup_step=9000, cooldown_begin_step=36000, regular_scale=10)
pruner.compress(None, 4)
pruner.unwrap_model()
masks = pruner.get_masks()
Path('./output/pruning/').mkdir(parents=True, exist_ok=True)
torch.save(masks, './output/pruning/attn_masks.pth')
torch.save(model, './output/pruning/attn_masked_model.pth')
if not skip_exec:
pruning_attn()
# %%
# We apply head pruning during the speedup stage: if a head is fully masked, it will be pruned;
# if a head is only partially masked, it will be restored.
def speedup_attn():
model = torch.load('./output/pruning/attn_masked_model.pth', map_location='cpu')
masks = torch.load('./output/pruning/attn_masks.pth', map_location='cpu')
dummy_input = (torch.randint(0, 10000, [8, 128]), torch.randint(0, 2, [8, 128]), torch.randint(0, 2, [8, 128]))
replacer = TransformersAttentionReplacer(model)
ModelSpeedup(model, dummy_input, masks, customized_replacers=[replacer]).speedup_model()
# finetuning
teacher_model = build_finetuning_model('mnli', f'./output/bert_finetuned/{task_name}.bin')
dynamic_distillation(model, teacher_model, None, 3)
torch.save(model, './output/pruning/attn_pruned_model.pth')
if not skip_exec:
speedup_attn()
# %%
# Pruning Feed Forward Layers
# ^^^^^^^^^^^^^^^^^^^^^^^^^^^
#
# Here ``TaylorPruner`` is used for pruning the feed forward layers,
# and the sparse ratio is related to the number of pruned heads in the same transformer block:
# the more heads are pruned, the higher the sparse ratio set for the feed forward layers.
#
# Note that ``TaylorPruner`` cannot schedule the sparse ratio on its own,
# so we wrap it with ``AGPPruner`` to schedule the sparse ratio and achieve better pruning performance.
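# For example, with the formula used below (``1 - retained_head_num / ori_head_num / 2``), a block that keeps 6 of its
# 12 heads gets a feed forward sparse ratio of ``1 - 6 / 12 / 2 = 0.75``, while a block that keeps all 12 heads gets 0.5.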
from nni.contrib.compression.pruning import TaylorPruner, AGPPruner
from transformers.models.bert.modeling_bert import BertLayer
def pruning_ffn():
model: BertForSequenceClassification = torch.load('./output/pruning/attn_pruned_model.pth')
teacher_model: BertForSequenceClassification = build_finetuning_model('mnli', f'./output/bert_finetuned/{task_name}.bin')
# create ffn config list, here simply use a linear function related to the number of retained heads to determine the sparse ratio
config_list = []
for name, module in model.named_modules():
if isinstance(module, BertLayer):
retained_head_num = module.attention.self.num_attention_heads
ori_head_num = len(module.attention.pruned_heads) + retained_head_num
ffn_sparse_ratio = 1 - retained_head_num / ori_head_num / 2
config_list.append({'op_names': [f'{name}.intermediate.dense'], 'sparse_ratio': ffn_sparse_ratio})
trainer = prepare_traced_trainer(model, task_name)
teacher_model.eval().to(trainer.args.device)
# create a distiller for restoring the accuracy
distiller = dynamic_distiller(model, teacher_model, trainer)
# fusion compress: TaylorPruner + DynamicLayerwiseDistiller
taylor_pruner = TaylorPruner.from_compressor(distiller, config_list, 1000)
# fusion compress: AGPPruner(TaylorPruner) + DynamicLayerwiseDistiller
agp_pruner = AGPPruner(taylor_pruner, 1000, 36)
agp_pruner.compress(None, 3)
agp_pruner.unwrap_model()
distiller.unwrap_teacher_model()
masks = agp_pruner.get_masks()
Path('./output/pruning/').mkdir(parents=True, exist_ok=True)
torch.save(masks, './output/pruning/ffn_masks.pth')
torch.save(model, './output/pruning/ffn_masked_model.pth')
if not skip_exec:
pruning_ffn()
# %%
# Speed up the feed forward layers.
def speedup_ffn():
model = torch.load('./output/pruning/ffn_masked_model.pth', map_location='cpu')
masks = torch.load('./output/pruning/ffn_masks.pth', map_location='cpu')
dummy_input = (torch.randint(0, 10000, [8, 128]), torch.randint(0, 2, [8, 128]), torch.randint(0, 2, [8, 128]))
ModelSpeedup(model, dummy_input, masks).speedup_model()
# finetuning
teacher_model = build_finetuning_model('mnli', f'./output/bert_finetuned/{task_name}.bin')
dynamic_distillation(model, teacher_model, None, 3)
torch.save(model, './output/pruning/ffn_pruned_model.pth')
if not skip_exec:
speedup_ffn()
# %%
# Pruning Embedding Layers
# ^^^^^^^^^^^^^^^^^^^^^^^^
#
# To simulate the pruning effect better, we register the output mask setting for ``BertAttention`` and ``BertOutput``.
# The output masks can be generated and applied after registering the setting template for them.
from nni.contrib.compression.base.setting import PruningSetting
output_align_setting = {
'_output_': {
'align': {
'module_name': None,
'target_name': 'weight',
'dims': [0],
},
'apply_method': 'mul',
}
}
PruningSetting.register('BertAttention', output_align_setting)
PruningSetting.register('BertOutput', output_align_setting)
# %%
# Similar to pruning the feed forward layers, we also use ``AGPPruner + TaylorPruner + DynamicLayerwiseDistiller`` here.
# To better simulate the pruning effect, output ``align`` mask generation is set in the ``config_list``,
# so the relevant layers will generate their own output masks according to the embedding masks.
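# For reference, in the ``config_list`` below the embedding weights are pruned with ``granularity`` ``[-1, 1]``, which
# presumably corresponds to column-wise blocks, i.e. pruning whole hidden dimensions of the embedding weight.
# The attention/output projection weights are pruned ``out_channel``-wise and aligned to the word embedding weight mask
# along ``dims: [1]``, so every layer that produces hidden states stays consistent with the pruned embedding width.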
def pruning_embedding():
model: BertForSequenceClassification = torch.load('./output/pruning/ffn_pruned_model.pth')
teacher_model: BertForSequenceClassification = build_finetuning_model('mnli', f'./output/bert_finetuned/{task_name}.bin')
sparse_ratio = 0.5
config_list = [{
'op_types': ['Embedding'],
'op_names_re': ['bert\.embeddings.*'],
'sparse_ratio': sparse_ratio,
'dependency_group_id': 1,
'granularity': [-1, 1],
}, {
'op_names_re': ['bert\.encoder\.layer\.[0-9]*\.attention$',
'bert\.encoder\.layer\.[0-9]*\.output$'],
'target_names': ['_output_'],
'target_settings': {
'_output_': {
'align': {
'module_name': 'bert.embeddings.word_embeddings',
'target_name': 'weight',
'dims': [1],
}
}
}
}, {
'op_names_re': ['bert\.encoder\.layer\.[0-9]*\.attention.output.dense',
'bert\.encoder\.layer\.[0-9]*\.output.dense'],
'target_names': ['weight'],
'target_settings': {
'weight': {
'granularity': 'out_channel',
'align': {
'module_name': 'bert.embeddings.word_embeddings',
'target_name': 'weight',
'dims': [1],
}
}
}
}]
trainer = prepare_traced_trainer(model, task_name)
teacher_model.eval().to(trainer.args.device)
distiller = dynamic_distiller(model, teacher_model, trainer)
taylor_pruner = TaylorPruner.from_compressor(distiller, config_list, 1000)
agp_pruner = AGPPruner(taylor_pruner, 1000, 36)
agp_pruner.compress(None, 3)
agp_pruner.unwrap_model()
distiller.unwrap_teacher_model()
masks = agp_pruner.get_masks()
Path('./output/pruning/').mkdir(parents=True, exist_ok=True)
torch.save(masks, './output/pruning/embedding_masks.pth')
torch.save(model, './output/pruning/embedding_masked_model.pth')
if not skip_exec:
pruning_embedding()
# %%
# Speed up the embedding layers.
def speedup_embedding():
model = torch.load('./output/pruning/embedding_masked_model.pth', map_location='cpu')
masks = torch.load('./output/pruning/embedding_masks.pth', map_location='cpu')
dummy_input = (torch.randint(0, 10000, [8, 128]), torch.randint(0, 2, [8, 128]), torch.randint(0, 2, [8, 128]))
ModelSpeedup(model, dummy_input, masks).speedup_model()
# finetuning
teacher_model = build_finetuning_model('mnli', f'./output/bert_finetuned/{task_name}.bin')
adapt_distillation(model, teacher_model, None, 4)
torch.save(model, './output/pruning/embedding_pruned_model.pth')
if not skip_exec:
speedup_embedding()
# %%
# Evaluation
# ^^^^^^^^^^
#
# Evaluate the pruned model size and accuracy.
def evaluate_pruned_model():
model: BertForSequenceClassification = torch.load('./output/pruning/embedding_pruned_model.pth')
trainer = prepare_traced_trainer(model, task_name)
metric = trainer.evaluate()
pruned_num_params = sum(param.numel() for param in model.parameters()) + sum(buffer.numel() for buffer in model.buffers())
model = build_finetuning_model(task_name, f'./output/bert_finetuned/{task_name}.bin')
ori_num_params = sum(param.numel() for param in model.parameters()) + sum(buffer.numel() for buffer in model.buffers())
print(f'Metric: {metric}\nSparsity: {1 - pruned_num_params / ori_num_params}')
if not skip_exec:
evaluate_pruned_model()
# %%
# Results
# -------
#
# .. list-table:: Prune Bert-base-uncased on MNLI
# :header-rows: 1
# :widths: auto
#
# * - Total Sparsity
# - Embedding Sparsity
# - Encoder Sparsity
# - Pooler Sparsity
# - Acc. (m/mm avg.)
# * - 0.%
# - 0.%
# - 0.%
# - 0.%
# - 84.95%
# * - 57.76%
# - 33.33% (15.89M)
# - 64.78% (29.96M)
# - 33.33% (0.39M)
# - 84.42%
# * - 68.31% (34.70M)
# - 50.00% (11.92M)
# - 73.57% (22.48M)
# - 50.00% (0.30M)
# - 83.33%
# * - 70.95% (31.81M)
# - 33.33% (15.89M)
# - 81.75% (15.52M)
# - 33.33% (0.39M)
# - 83.79%
# * - 78.20% (23.86M)
# - 50.00% (11.92M)
# - 86.31% (11.65M)
# - 50.00% (0.30M)
# - 82.53%
# * - 81.65% (20.12M)
# - 50.00% (11.92M)
# - 90.71% (7.90M)
# - 50.00% (0.30M)
# - 82.08%
# * - 84.32% (17.17M)
# - 50.00% (11.92M)
# - 94.18% (4.95M)
# - 50.00% (0.30M)
# - 81.35%

View file

@ -0,0 +1,133 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.
from __future__ import annotations
from collections import defaultdict
import logging
import re
from typing import TYPE_CHECKING, List
import torch
from torch.utils._pytree import tree_map
from nni.compression.pytorch.speedup.v2.replacer import Replacer
from nni.compression.pytorch.utils.attr import get_nested_attr
from nni.compression.pytorch.utils.external.huggingface import parser_factory, HuggingfaceModelParser
if TYPE_CHECKING:
from .model_speedup import ModelSpeedup
_logger = logging.getLogger(__name__)
def _endwith(s: str, suffixes: List[str]):
return any(s.endswith(suffix) for suffix in suffixes)
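# Return the indices of heads whose mask block is entirely zero, i.e. the heads that can be removed.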
def _prune_head_idxs(mask: torch.Tensor, num_heads: int) -> List[int]:
head_mask = (mask.reshape([num_heads, -1]).sum(-1) == 0.)
return torch.arange(len(head_mask))[head_mask].long().tolist()
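# Return the flat indices along the masked dimension that belong to heads with at least one unmasked element.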
def _remained_idxs(mask: torch.Tensor, num_heads: int) -> List[int]:
repeats = mask.shape[0] // num_heads
remained = (mask.reshape([num_heads, -1]).sum(-1) != 0.).repeat_interleave(repeats)
return torch.arange(len(mask))[remained].long().tolist()
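# Return a mask that keeps zeros only for slices that are fully masked along the axes not listed in ``dims`` and fills
# everything else with ones, e.g. with ``dims=[0]`` a column of a 2-d mask stays zero only if the whole column is zero.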
def _fill_one_on_dims(mask: torch.Tensor, dims: int | List[int]) -> torch.Tensor:
dims = dims if isinstance(dims, list) else [dims]
dims = [d if d >= 0 else d + len(mask.shape) for d in dims]
new_mask = torch.ones_like(mask)
for i in range(len(mask.shape)):
if i in dims:
continue
dim_mask = (mask.sum([_ for _ in range(len(mask.shape)) if _ != i]) == 0.)
new_mask = new_mask.transpose(0, i)
new_mask[torch.arange(len(dim_mask), device=new_mask.device)[dim_mask].long().tolist()] = 0.
new_mask = new_mask.transpose(0, i)
return new_mask
class TransformersAttentionReplacer(Replacer):
"""
This replacer is used to prune huggingface transformers attention heads.
It relies on ``HuggingfaceModelParser`` to find the attention modules,
and prunes heads with the attention module's built-in ``prune_heads`` interface.
Parameters
----------
model
The transformer model. Currently NNI officially supports bert, bart, t5 and vit.
parser
The model parser used to find the attention module.
If the model passed in is not bert, bart, t5 or vit,
please inherit ``nni.compression.pytorch.utils.external.huggingface.HuggingfaceModelParser``
to customize a new model parser and pass it in.
"""
def __init__(self, model: torch.nn.Module, parser: HuggingfaceModelParser | None = None):
self.parser = parser_factory(model) if parser is None else parser
if self.parser is None:
err_msg = f'Can not get the model parser of {type(model)}'
raise RuntimeError(err_msg)
def replace_modules(self, speedup: 'ModelSpeedup'):
# Note: this replace function relies on the prune_heads interface in Huggingface transformers.
attention_name_dict = defaultdict(list)
attention_patterns = [self.parser.TRANSFORMER_PREFIX + att_p for att_p in self.parser.ATTENTION]
# find layers which have the attention layer name prefix
target2node = {}
for node, node_info in speedup.node_infos.items():
if node.op == 'call_module' and self.parser.is_attention(node.target):
target2node[node.target] = node
for attention_pattern in attention_patterns:
attention_layer_name = re.findall(attention_pattern, node.target)[0]
attention_name_dict[attention_layer_name].append(node.target)
# prune heads
for attention_layer_name, qkvo_names in attention_name_dict.items():
# qkvo_flatten_head_mask is the element-wise product of the qkv output masks and the o input mask
qkvo_flatten_head_mask: torch.Tensor | None = None
for name in qkvo_names:
if _endwith(name, self.parser.QKVO):
info_msg = f'Find QKVO layer `{name}`, try to prune head.'
_logger.info(info_msg)
node = target2node[name]
node_info = speedup.node_infos[node]
if _endwith(name, self.parser.QKV):
out_masks = node_info.output_masks
flatten_head_mask = \
(torch.sum(out_masks, dim=[_ for _ in range(len(out_masks.shape) - 1)]).detach() > 0.).float()
else:
in_masks = tree_map(lambda n: speedup.node_infos[n].output_masks, node.args)
flatten_head_mask = \
(torch.sum(in_masks[0], dim=[_ for _ in range(len(in_masks[0].shape) - 1)]).detach() > 0.).float()
if qkvo_flatten_head_mask is not None:
qkvo_flatten_head_mask *= flatten_head_mask
else:
qkvo_flatten_head_mask = flatten_head_mask
if qkvo_flatten_head_mask is not None:
original_num_heads = self.parser.get_num_heads(attention_layer_name, speedup.bound_model)
head_idxs = _prune_head_idxs(qkvo_flatten_head_mask, original_num_heads)
info_msg = f'Prune {attention_layer_name} head {head_idxs}'
_logger.info(info_msg)
attention_layer = get_nested_attr(speedup.bound_model, attention_layer_name)
attention_layer.prune_heads(head_idxs) # type: ignore
# replace autoinfer masks with ones, assume QKVO are all Linear
remained_idxs = _remained_idxs(qkvo_flatten_head_mask, original_num_heads)
for name in qkvo_names:
if _endwith(name, self.parser.QKVO):
node = target2node[name]
node_info = speedup.node_infos[node]
if _endwith(name, self.parser.QKV):
mask = node_info.param_masks['weight'][remained_idxs]
node_info.param_masks['weight'] = _fill_one_on_dims(mask, 0)
mask = node_info.output_masks.transpose(0, -1)[remained_idxs].transpose(0, -1)
node_info.output_masks = _fill_one_on_dims(mask, -1)
else:
mask = node_info.param_masks['weight'][:, remained_idxs]
node_info.param_masks['weight'] = _fill_one_on_dims(mask, 1)
masks = tree_map(lambda n: speedup.node_infos[n].output_masks, node.args)
mask = masks[0].transpose(0, -1)[remained_idxs].transpose(0, -1)
for n in node.args:
speedup.node_infos[n].output_masks = _fill_one_on_dims(mask, -1)

View file

@ -14,6 +14,7 @@ try:
PreTrainedModel,
BartConfig,
BertConfig,
DistilBertConfig,
T5Config,
ViTConfig
)
@ -106,6 +107,15 @@ class HuggingfaceBertParser(HuggingfaceModelParser):
ATTENTION = ('attention',)
class HuggingfaceDistilBertParser(HuggingfaceModelParser):
TRANSFORMER_PREFIX = r'distilbert\.transformer\.layer\.[0-9]+\.'
QKV = ('attention.q_lin', 'attention.k_lin', 'attention.v_lin')
QKVO = QKV + ('attention.out_lin',)
FFN1 = ('ffn.lin1',)
FFN2 = ('ffn.lin2',)
ATTENTION = ('attention',)
class HuggingfaceBartParser(HuggingfaceModelParser):
TRANSFORMER_PREFIX = r'(en|de)coder\.layer\.[0-9]+\.'
QKV = ('self_attn.q_proj', 'self_attn.k_proj', 'self_attn.v_proj', 'encoder_attn.q_proj', 'encoder_attn.k_proj', 'encoder_attn.v_proj')
@ -139,12 +149,14 @@ def parser_factory(model: Module) -> HuggingfaceModelParser | None:
cls2parser = {
BartConfig: HuggingfaceBartParser,
BertConfig: HuggingfaceBertParser,
DistilBertConfig: HuggingfaceDistilBertParser,
T5Config: HuggingfaceT5Parser,
ViTConfig: HuggingfaceViTParser
}
type2parser = {
'bart': HuggingfaceBartParser,
'bert': HuggingfaceBertParser,
'distilbert': HuggingfaceDistilBertParser,
't5': HuggingfaceT5Parser,
'vit': HuggingfaceViTParser
}

View file

@ -478,8 +478,8 @@ def register_wrappers(model: torch.nn.Module, config_list: List[Dict[str, Any]],
return module_wrappers, configured_target_spaces
def create_module_wrapper(model:nn.Module, module: nn.Module, module_name: str, mode: Literal['pruning', 'quantization', 'distillation'], \
config: Dict[str, Any], wrapper: ModuleWrapper | None = None, fused_modules_pair: List[str] | None = None):
def create_module_wrapper(model: nn.Module, module: nn.Module, module_name: str, mode: Literal['pruning', 'quantization', 'distillation'],
config: Dict[str, Any], wrapper: ModuleWrapper | None = None, fused_modules_pair: List[str] | None = None):
fused_modules_pair = fused_modules_pair if fused_modules_pair is not None else []
if mode != 'quantization' and len(fused_modules_pair) > 0:
raise ValueError(f"Only quantization supports model fusion, but got {mode} and {fused_modules_pair}")

View file

@ -9,7 +9,6 @@ from typing import Any, Callable, Dict, List, overload
import torch
import torch.nn.functional as F
from torch.optim import Adam
from torch.utils._pytree import tree_map
from ..base.compressor import Compressor, Distiller, _DISTILLATION_TARGET_SPACES
@ -214,7 +213,7 @@ class DynamicLayerwiseDistiller(TeacherModelBasedDistiller):
loss_list.append(target_space.lambda_ * \
F.kl_div((stu_hs / 2).log_softmax(dim=-1), (tea_hs / 2).softmax(dim=-1), reduction='batchmean') * (2 ** 2))
if loss_list:
distill_loss += min(loss_list)
distill_loss = distill_loss + min(loss_list)
for _, ts in self._target_spaces.items():
for _, target_space in ts.items():
target_space.clean()
@ -296,23 +295,16 @@ class Adaptive1dLayerwiseDistiller(TeacherModelBasedDistiller):
self.trans_linears[module_name][target_name] = torch.nn.Linear(stu_hs.shape[-1], tea_hs.shape[-1]).to(stu_hs.device)
def _register_linears_optimization(self, evaluator: Evaluator):
linear_params = []
for _, linears in self.trans_linears.items():
linear_params = {}
for module_name, linears in self.trans_linears.items():
for _, linear in linears.items():
if linear is not None:
linear_params.extend(linear.parameters())
linear_params[module_name] = list(linear.parameters())
if not linear_params:
return
params = [{"params": linear_params}]
optimizer = Adam(params, 1e-2)
def optimizer_task():
optimizer.step()
optimizer.zero_grad()
evaluator.patch_optimizer_step(before_step_tasks=[optimizer_task], after_step_tasks=[])
evaluator.patch_optim_param_group(linear_params)
def compute_distill_loss(self):
distill_loss = 0

View file

@ -187,7 +187,7 @@ class Evaluator:
raise NotImplementedError
def _optimizer_add_param_group(self, model: Union[torch.nn.Module, pl.LightningModule],
module_name_param_dict: Dict[str, List[Tensor]], optimizers: Optimizer | List[Optimizer]):
module_name_param_dict: Dict[str, List[Tensor]], optimizers: Optimizer | List[Optimizer]):
# used in the bind_model process
def find_param_group(param_groups: List[Dict], module_name: str):
for i, param_group in enumerate(param_groups):
@ -217,13 +217,13 @@ class Evaluator:
# copyed from torch.optim to check the validation of param
if not isinstance(param, torch.Tensor):
raise TypeError("optimizer can only optimize Tensors, "
"but one of the params is " + torch.typename(param))
"but one of the params is " + torch.typename(param))
if not optimizer.defaults.get('differentiable', None) \
and not (param.is_leaf or param.retains_grad): # type: ignore
raise ValueError("can't optimize a non-leaf Tensor")
target_param_group['params'].append(param)
assert isinstance(model, (Module, pl.LightningModule))
assert isinstance(model, Module)
param2name_dict = {id(p): name for name, p in model.named_parameters()}
assert optimizers is not None, "Please provide optimizers for adding param_groups in optimizers"
optimizers = optimizers if isinstance(optimizers, (list, tuple)) else [optimizers]
@ -1024,7 +1024,6 @@ class TransformersEvaluator(Evaluator):
def patch_optim_param_group(self, module_name_param_dict: Dict[str, List[Tensor]]):
assert isinstance(self.model, Module)
assert module_name_param_dict is not None
self._optimizer_add_param_group(self.model, module_name_param_dict, self.trainer.optimizer)
def unbind_model(self):