зеркало из https://github.com/microsoft/nni.git
[Compression] v2.5 pruning tutorial (#5476)
This commit is contained in:
Родитель
a68cb047d4
Коммит
cdf3bfb3e5
|
@ -0,0 +1,8 @@
|
|||
Best Practices
|
||||
==============
|
||||
|
||||
.. toctree::
|
||||
:hidden:
|
||||
:maxdepth: 2
|
||||
|
||||
Pruning Transformer </tutorials/new_pruning_bert_glue>
|
|
@ -0,0 +1,8 @@
|
|||
Compression (Preview)
|
||||
=====================
|
||||
|
||||
.. toctree::
|
||||
:hidden:
|
||||
:maxdepth: 2
|
||||
|
||||
Pruning <toctree_pruning>
|
|
@ -0,0 +1,8 @@
|
|||
Pruning
|
||||
=======
|
||||
|
||||
.. toctree::
|
||||
:hidden:
|
||||
:maxdepth: 2
|
||||
|
||||
Best Practices <best_practices>
|
|
@ -16,7 +16,8 @@ NNI Documentation
|
|||
|
||||
hpo/toctree
|
||||
nas/toctree
|
||||
Model Compression <compression/toctree>
|
||||
compression/toctree
|
||||
compression_preview/toctree
|
||||
feature_engineering/toctree
|
||||
experiment/toctree
|
||||
|
||||
|
|
|
@ -20,6 +20,7 @@ NNI 文档
|
|||
超参调优 <hpo/toctree>
|
||||
架构搜索 <nas/toctree>
|
||||
模型压缩 <compression/toctree>
|
||||
模型压缩(预览) <compression_preview/toctree>
|
||||
特征工程 <feature_engineering/toctree>
|
||||
实验管理 <experiment/toctree>
|
||||
|
||||
|
|
|
@ -8,36 +8,32 @@ msgid ""
|
|||
msgstr ""
|
||||
"Project-Id-Version: NNI \n"
|
||||
"Report-Msgid-Bugs-To: \n"
|
||||
"POT-Creation-Date: 2022-04-20 05:50+0000\n"
|
||||
"POT-Creation-Date: 2023-03-27 02:44+0000\n"
|
||||
"PO-Revision-Date: YEAR-MO-DA HO:MI+ZONE\n"
|
||||
"Last-Translator: FULL NAME <EMAIL@ADDRESS>\n"
|
||||
"Language-Team: LANGUAGE <LL@li.org>\n"
|
||||
"MIME-Version: 1.0\n"
|
||||
"Content-Type: text/plain; charset=utf-8\n"
|
||||
"Content-Transfer-Encoding: 8bit\n"
|
||||
"Generated-By: Babel 2.9.1\n"
|
||||
"Generated-By: Babel 2.11.0\n"
|
||||
|
||||
#: ../../source/index.rst:4 ../../source/index.rst:52
|
||||
#: ../../source/index.rst:4 ../../source/index.rst:53
|
||||
msgid "Get Started"
|
||||
msgstr ""
|
||||
|
||||
#: ../../source/index.rst:12
|
||||
msgid "Model Compression"
|
||||
msgstr ""
|
||||
|
||||
#: ../../source/index.rst:12
|
||||
msgid "User Guide"
|
||||
msgstr ""
|
||||
|
||||
#: ../../source/index.rst:23
|
||||
#: ../../source/index.rst:24
|
||||
msgid "Python API"
|
||||
msgstr ""
|
||||
|
||||
#: ../../source/index.rst:23
|
||||
#: ../../source/index.rst:24
|
||||
msgid "References"
|
||||
msgstr ""
|
||||
|
||||
#: ../../source/index.rst:32
|
||||
#: ../../source/index.rst:33
|
||||
msgid "Misc"
|
||||
msgstr ""
|
||||
|
||||
|
@ -45,68 +41,68 @@ msgstr ""
|
|||
msgid "NNI Documentation"
|
||||
msgstr ""
|
||||
|
||||
#: ../../source/index.rst:44
|
||||
#: ../../source/index.rst:45
|
||||
msgid ""
|
||||
"**NNI (Neural Network Intelligence)** is a lightweight but powerful "
|
||||
"toolkit to help users **automate**:"
|
||||
msgstr ""
|
||||
|
||||
#: ../../source/index.rst:46
|
||||
#: ../../source/index.rst:47
|
||||
msgid ":doc:`Hyperparameter Optimization </hpo/overview>`"
|
||||
msgstr ""
|
||||
|
||||
#: ../../source/index.rst:47
|
||||
#: ../../source/index.rst:48
|
||||
msgid ":doc:`Neural Architecture Search </nas/overview>`"
|
||||
msgstr ""
|
||||
|
||||
#: ../../source/index.rst:48
|
||||
#: ../../source/index.rst:49
|
||||
msgid ":doc:`Model Compression </compression/overview>`"
|
||||
msgstr ""
|
||||
|
||||
#: ../../source/index.rst:49
|
||||
#: ../../source/index.rst:50
|
||||
msgid ":doc:`Feature Engineering </feature_engineering/overview>`"
|
||||
msgstr ""
|
||||
|
||||
#: ../../source/index.rst:54
|
||||
#: ../../source/index.rst:55
|
||||
msgid "To install the current release:"
|
||||
msgstr ""
|
||||
|
||||
#: ../../source/index.rst:60
|
||||
#: ../../source/index.rst:61
|
||||
msgid ""
|
||||
"See the :doc:`installation guide </installation>` if you need additional "
|
||||
"help on installation."
|
||||
msgstr ""
|
||||
|
||||
#: ../../source/index.rst:63
|
||||
#: ../../source/index.rst:64
|
||||
msgid "Try your first NNI experiment"
|
||||
msgstr ""
|
||||
|
||||
#: ../../source/index.rst:69
|
||||
#: ../../source/index.rst:70
|
||||
msgid ""
|
||||
"You need to have `PyTorch <https://pytorch.org/>`_ (as well as "
|
||||
"`torchvision <https://pytorch.org/vision/stable/index.html>`_) installed "
|
||||
"to run this experiment."
|
||||
msgstr ""
|
||||
|
||||
#: ../../source/index.rst:71
|
||||
#: ../../source/index.rst:72
|
||||
msgid ""
|
||||
"To start your journey now, please follow the :doc:`absolute quickstart of"
|
||||
" NNI <quickstart>`!"
|
||||
msgstr ""
|
||||
|
||||
#: ../../source/index.rst:74
|
||||
#: ../../source/index.rst:75
|
||||
msgid "Why choose NNI?"
|
||||
msgstr ""
|
||||
|
||||
#: ../../source/index.rst:77
|
||||
#: ../../source/index.rst:78
|
||||
msgid "NNI makes AutoML techniques plug-and-play"
|
||||
msgstr ""
|
||||
|
||||
#: ../../source/index.rst:221
|
||||
#: ../../source/index.rst:222
|
||||
msgid "NNI eases the effort to scale and manage AutoML experiments"
|
||||
msgstr ""
|
||||
|
||||
#: ../../source/index.rst:229
|
||||
#: ../../source/index.rst:230
|
||||
msgid ""
|
||||
"An AutoML experiment requires many trials to explore feasible and "
|
||||
"potentially good-performing models. **Training service** aims to make the"
|
||||
|
@ -116,13 +112,13 @@ msgid ""
|
|||
"kinds of training services."
|
||||
msgstr ""
|
||||
|
||||
#: ../../source/index.rst:240
|
||||
#: ../../source/index.rst:241
|
||||
msgid ""
|
||||
"Web portal visualizes the tuning process, exposing the ability to "
|
||||
"inspect, monitor and control the experiment."
|
||||
msgstr ""
|
||||
|
||||
#: ../../source/index.rst:251
|
||||
#: ../../source/index.rst:252
|
||||
msgid ""
|
||||
"The DNN model tuning often requires more than one experiment. Users might"
|
||||
" try different tuning algorithms, fine-tune their search space, or switch"
|
||||
|
@ -131,66 +127,66 @@ msgid ""
|
|||
"so that the tuning workflow becomes clean and organized."
|
||||
msgstr ""
|
||||
|
||||
#: ../../source/index.rst:257
|
||||
#: ../../source/index.rst:258
|
||||
msgid "Get Support and Contribute Back"
|
||||
msgstr ""
|
||||
|
||||
#: ../../source/index.rst:259
|
||||
#: ../../source/index.rst:260
|
||||
msgid ""
|
||||
"NNI is maintained on the `NNI GitHub repository "
|
||||
"<https://github.com/microsoft/nni>`_. We collect feedbacks and new "
|
||||
"proposals/ideas on GitHub. You can:"
|
||||
msgstr ""
|
||||
|
||||
#: ../../source/index.rst:261
|
||||
#: ../../source/index.rst:262
|
||||
msgid ""
|
||||
"Open a `GitHub issue <https://github.com/microsoft/nni/issues>`_ for bugs"
|
||||
" and feature requests."
|
||||
msgstr ""
|
||||
|
||||
#: ../../source/index.rst:262
|
||||
#: ../../source/index.rst:263
|
||||
msgid ""
|
||||
"Open a `pull request <https://github.com/microsoft/nni/pulls>`_ to "
|
||||
"contribute code (make sure to read the :doc:`contribution guide "
|
||||
"<notes/contributing>` before doing this)."
|
||||
msgstr ""
|
||||
|
||||
#: ../../source/index.rst:263
|
||||
#: ../../source/index.rst:264
|
||||
msgid ""
|
||||
"Participate in `NNI Discussion "
|
||||
"<https://github.com/microsoft/nni/discussions>`_ for general questions "
|
||||
"and new ideas."
|
||||
msgstr ""
|
||||
|
||||
#: ../../source/index.rst:264
|
||||
#: ../../source/index.rst:265
|
||||
msgid "Join the following IM groups."
|
||||
msgstr ""
|
||||
|
||||
#: ../../source/index.rst:270
|
||||
#: ../../source/index.rst:271
|
||||
msgid "Gitter"
|
||||
msgstr ""
|
||||
|
||||
#: ../../source/index.rst:271
|
||||
#: ../../source/index.rst:272
|
||||
msgid "WeChat"
|
||||
msgstr ""
|
||||
|
||||
#: ../../source/index.rst:278
|
||||
#: ../../source/index.rst:279
|
||||
msgid "Citing NNI"
|
||||
msgstr ""
|
||||
|
||||
#: ../../source/index.rst:280
|
||||
#: ../../source/index.rst:281
|
||||
msgid ""
|
||||
"If you use NNI in a scientific publication, please consider citing NNI in"
|
||||
" your references."
|
||||
msgstr ""
|
||||
|
||||
#: ../../source/index.rst:282
|
||||
#: ../../source/index.rst:283
|
||||
msgid ""
|
||||
"Microsoft. Neural Network Intelligence (version |release|). "
|
||||
"https://github.com/microsoft/nni"
|
||||
msgstr ""
|
||||
|
||||
#: ../../source/index.rst:284
|
||||
#: ../../source/index.rst:285
|
||||
msgid ""
|
||||
"Bibtex entry (please replace the version with the particular version you "
|
||||
"are using): ::"
|
||||
|
@ -217,3 +213,6 @@ msgstr ""
|
|||
#~ "doing this)."
|
||||
#~ msgstr ""
|
||||
|
||||
#~ msgid "Model Compression"
|
||||
#~ msgstr ""
|
||||
|
||||
|
|
Двоичные данные
docs/source/tutorials/images/thumb/sphx_glr_new_pruning_bert_glue_thumb.png
Normal file
Двоичные данные
docs/source/tutorials/images/thumb/sphx_glr_new_pruning_bert_glue_thumb.png
Normal file
Двоичный файл не отображается.
После Ширина: | Высота: | Размер: 18 KiB |
|
@ -17,7 +17,7 @@ Tutorials
|
|||
.. only:: html
|
||||
|
||||
.. image:: /tutorials/images/thumb/sphx_glr_pruning_speedup_thumb.png
|
||||
:alt:
|
||||
:alt: Speedup Model with Mask
|
||||
|
||||
:ref:`sphx_glr_tutorials_pruning_speedup.py`
|
||||
|
||||
|
@ -34,7 +34,7 @@ Tutorials
|
|||
.. only:: html
|
||||
|
||||
.. image:: /tutorials/images/thumb/sphx_glr_quantization_quick_start_mnist_thumb.png
|
||||
:alt:
|
||||
:alt: Quantization Quickstart
|
||||
|
||||
:ref:`sphx_glr_tutorials_quantization_quick_start_mnist.py`
|
||||
|
||||
|
@ -51,7 +51,7 @@ Tutorials
|
|||
.. only:: html
|
||||
|
||||
.. image:: /tutorials/images/thumb/sphx_glr_pruning_quick_start_mnist_thumb.png
|
||||
:alt:
|
||||
:alt: Pruning Quickstart
|
||||
|
||||
:ref:`sphx_glr_tutorials_pruning_quick_start_mnist.py`
|
||||
|
||||
|
@ -68,7 +68,7 @@ Tutorials
|
|||
.. only:: html
|
||||
|
||||
.. image:: /tutorials/images/thumb/sphx_glr_quantization_customize_thumb.png
|
||||
:alt:
|
||||
:alt: Customize a new quantization algorithm
|
||||
|
||||
:ref:`sphx_glr_tutorials_quantization_customize.py`
|
||||
|
||||
|
@ -85,7 +85,7 @@ Tutorials
|
|||
.. only:: html
|
||||
|
||||
.. image:: /tutorials/images/thumb/sphx_glr_nasbench_as_dataset_thumb.png
|
||||
:alt:
|
||||
:alt: Use NAS Benchmarks as Datasets
|
||||
|
||||
:ref:`sphx_glr_tutorials_nasbench_as_dataset.py`
|
||||
|
||||
|
@ -102,7 +102,7 @@ Tutorials
|
|||
.. only:: html
|
||||
|
||||
.. image:: /tutorials/images/thumb/sphx_glr_quantization_speedup_thumb.png
|
||||
:alt:
|
||||
:alt: Speed Up Quantized Model with TensorRT
|
||||
|
||||
:ref:`sphx_glr_tutorials_quantization_speedup.py`
|
||||
|
||||
|
@ -119,7 +119,7 @@ Tutorials
|
|||
.. only:: html
|
||||
|
||||
.. image:: /tutorials/images/thumb/sphx_glr_hello_nas_thumb.png
|
||||
:alt:
|
||||
:alt: Hello, NAS!
|
||||
|
||||
:ref:`sphx_glr_tutorials_hello_nas.py`
|
||||
|
||||
|
@ -136,7 +136,7 @@ Tutorials
|
|||
.. only:: html
|
||||
|
||||
.. image:: /tutorials/images/thumb/sphx_glr_darts_thumb.png
|
||||
:alt:
|
||||
:alt: Searching in DARTS search space
|
||||
|
||||
:ref:`sphx_glr_tutorials_darts.py`
|
||||
|
||||
|
@ -146,6 +146,23 @@ Tutorials
|
|||
</div>
|
||||
|
||||
|
||||
.. raw:: html
|
||||
|
||||
<div class="sphx-glr-thumbcontainer" tooltip="This is a new tutorial on pruning transformer in nni v3.0 (`old tutorial <https://nni.readthedo...">
|
||||
|
||||
.. only:: html
|
||||
|
||||
.. image:: /tutorials/images/thumb/sphx_glr_new_pruning_bert_glue_thumb.png
|
||||
:alt: Pruning Bert on Task MNLI
|
||||
|
||||
:ref:`sphx_glr_tutorials_new_pruning_bert_glue.py`
|
||||
|
||||
.. raw:: html
|
||||
|
||||
<div class="sphx-glr-thumbnail-title">Pruning Bert on Task MNLI</div>
|
||||
</div>
|
||||
|
||||
|
||||
.. raw:: html
|
||||
|
||||
<div class="sphx-glr-thumbcontainer" tooltip="Workable Pruning Process ------------------------">
|
||||
|
@ -153,7 +170,7 @@ Tutorials
|
|||
.. only:: html
|
||||
|
||||
.. image:: /tutorials/images/thumb/sphx_glr_pruning_bert_glue_thumb.png
|
||||
:alt:
|
||||
:alt: Pruning Bert on Task MNLI
|
||||
|
||||
:ref:`sphx_glr_tutorials_pruning_bert_glue.py`
|
||||
|
||||
|
@ -179,6 +196,7 @@ Tutorials
|
|||
/tutorials/quantization_speedup
|
||||
/tutorials/hello_nas
|
||||
/tutorials/darts
|
||||
/tutorials/new_pruning_bert_glue
|
||||
/tutorials/pruning_bert_glue
|
||||
|
||||
|
||||
|
@ -196,7 +214,7 @@ Tutorials
|
|||
.. only:: html
|
||||
|
||||
.. image:: /tutorials/hpo_quickstart_pytorch/images/thumb/sphx_glr_main_thumb.png
|
||||
:alt:
|
||||
:alt: HPO Quickstart with PyTorch
|
||||
|
||||
:ref:`sphx_glr_tutorials_hpo_quickstart_pytorch_main.py`
|
||||
|
||||
|
@ -213,7 +231,7 @@ Tutorials
|
|||
.. only:: html
|
||||
|
||||
.. image:: /tutorials/hpo_quickstart_pytorch/images/thumb/sphx_glr_model_thumb.png
|
||||
:alt:
|
||||
:alt: Port PyTorch Quickstart to NNI
|
||||
|
||||
:ref:`sphx_glr_tutorials_hpo_quickstart_pytorch_model.py`
|
||||
|
||||
|
@ -242,7 +260,7 @@ Tutorials
|
|||
.. only:: html
|
||||
|
||||
.. image:: /tutorials/hpo_quickstart_tensorflow/images/thumb/sphx_glr_main_thumb.png
|
||||
:alt:
|
||||
:alt: HPO Quickstart with TensorFlow
|
||||
|
||||
:ref:`sphx_glr_tutorials_hpo_quickstart_tensorflow_main.py`
|
||||
|
||||
|
@ -259,7 +277,7 @@ Tutorials
|
|||
.. only:: html
|
||||
|
||||
.. image:: /tutorials/hpo_quickstart_tensorflow/images/thumb/sphx_glr_model_thumb.png
|
||||
:alt:
|
||||
:alt: Port TensorFlow Quickstart to NNI
|
||||
|
||||
:ref:`sphx_glr_tutorials_hpo_quickstart_tensorflow_model.py`
|
||||
|
||||
|
@ -278,7 +296,6 @@ Tutorials
|
|||
:hidden:
|
||||
:includehidden:
|
||||
|
||||
|
||||
/tutorials/hpo_quickstart_pytorch/index.rst
|
||||
/tutorials/hpo_quickstart_tensorflow/index.rst
|
||||
|
||||
|
|
|
@ -0,0 +1,349 @@
|
|||
{
|
||||
"cells": [
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {
|
||||
"collapsed": false
|
||||
},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"%matplotlib inline"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"\n# Pruning Bert on Task MNLI\n\nThis is a new tutorial on pruning transformer in nni v3.0 ([old tutorial](https://nni.readthedocs.io/en/v2.9/tutorials/pruning_bert_glue.html)_).\nThe main difference between this tutorial and the previous is that it integrates the feature of fusion compression (pruning + distillation) in nni,\nuses a new more powerful and stable pruning speedup tool,\nand additionally prunes the whole model hidden dimensions which greatly reduces the model size (pruning embedding layers).\n\nAt the same time, the huggingface [transformers.Trainer](https://huggingface.co/docs/transformers/main_classes/trainer)_ is used in this tutorial\nto reduce the burden of user writing training and evaluation logic.\n\n## Workable Pruning Process\n\nThe whole pruning process is divided into three steps:\n\n1. pruning attention layers,\n2. pruning feed forward layers,\n3. pruning embedding layers.\n\nIn each step, the pruner is first used for simulated pruning to generate masks corresponding to the module pruning targets (weight, input, output).\nAfter that comes the speedup stage, sparsity propagation is used to explore the global redundancy due to the local masks,\nthen modify the original model into a smaller one by replacing the sub module in the model.\n\nThe compression of the model naturally applies the distillation method,\nso in this tutorial, distillers will also be used to help restore the model accuracy.\n\n## Experiment\n\n### Preparations\n\nThe preparations mainly includes preparing the transformers trainer and model.\n\nThis is generally consistent with the preparations required to normally train a Bert model.\nThe only difference is that the ``transformers.Trainer`` is needed to wrap by ``nni.trace`` to trace the init arguments,\nthis is because nni need re-create trainer during training aware pruning and distilling.\n\n<div class=\"alert alert-info\"><h4>Note</h4><p>Please set ``skip_exec`` to ``False`` to run this tutorial. Here ``skip_exec`` is ``True`` by default is for generating documents.</p></div>\n"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {
|
||||
"collapsed": false
|
||||
},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"from __future__ import annotations\n\nfrom pathlib import Path\n\nimport numpy as np\n\nimport torch\nfrom torch.utils.data import ConcatDataset\n\nimport nni\n\nfrom datasets import load_dataset, load_metric\nfrom transformers import BertTokenizerFast, DataCollatorWithPadding, BertForSequenceClassification, EvalPrediction\nfrom transformers.trainer import Trainer\nfrom transformers.training_args import TrainingArguments\n\nskip_exec = True"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"Set the downstream task name here, you could replace the task with the task in GLUE.\n\n"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {
|
||||
"collapsed": false
|
||||
},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"task_name = 'mnli'"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"Here using BertForSequenceClassification as the base model for show case.\nIf you want to prune other kind of transformer model, you could replace the base model here.\n\n"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {
|
||||
"collapsed": false
|
||||
},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"def build_model(pretrained_model_name_or_path: str, task_name: str):\n is_regression = task_name == 'stsb'\n num_labels = 1 if is_regression else (3 if task_name == 'mnli' else 2)\n model = BertForSequenceClassification.from_pretrained(pretrained_model_name_or_path, num_labels=num_labels)\n return model"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"Prepare the GLUE train & validation datasets, if the task has multi validation datasets, concat the datasets by ``ConcatDataset``.\n\n"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {
|
||||
"collapsed": false
|
||||
},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"def prepare_datasets(task_name: str, tokenizer: BertTokenizerFast, cache_dir: str):\n task_to_keys = {\n 'cola': ('sentence', None),\n 'mnli': ('premise', 'hypothesis'),\n 'mrpc': ('sentence1', 'sentence2'),\n 'qnli': ('question', 'sentence'),\n 'qqp': ('question1', 'question2'),\n 'rte': ('sentence1', 'sentence2'),\n 'sst2': ('sentence', None),\n 'stsb': ('sentence1', 'sentence2'),\n 'wnli': ('sentence1', 'sentence2'),\n }\n sentence1_key, sentence2_key = task_to_keys[task_name]\n\n # used to preprocess the raw data\n def preprocess_function(examples):\n # Tokenize the texts\n args = (\n (examples[sentence1_key],) if sentence2_key is None else (examples[sentence1_key], examples[sentence2_key])\n )\n result = tokenizer(*args, padding=False, max_length=128, truncation=True)\n\n if 'label' in examples:\n # In all cases, rename the column to labels because the model will expect that.\n result['labels'] = examples['label']\n return result\n\n raw_datasets = load_dataset('glue', task_name, cache_dir=cache_dir)\n for key in list(raw_datasets.keys()):\n if 'test' in key:\n raw_datasets.pop(key)\n\n processed_datasets = raw_datasets.map(preprocess_function, batched=True,\n remove_columns=raw_datasets['train'].column_names)\n\n train_dataset = processed_datasets['train']\n if task_name == 'mnli':\n validation_datasets = {\n 'validation_matched': processed_datasets['validation_matched'],\n 'validation_mismatched': processed_datasets['validation_mismatched']\n }\n else:\n validation_datasets = {\n 'validation': processed_datasets['validation']\n }\n\n return train_dataset, validation_datasets"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"Prepare the trainer, note that the ``Trainer`` class is wrapped by ``nni.trace``.\n\n"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {
|
||||
"collapsed": false
|
||||
},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"def prepare_traced_trainer(model, task_name, load_best_model_at_end=False):\n is_regression = task_name == 'stsb'\n metric = load_metric('glue', task_name)\n\n def compute_metrics(p: EvalPrediction):\n preds = p.predictions[0] if isinstance(p.predictions, tuple) else p.predictions\n preds = np.squeeze(preds) if is_regression else np.argmax(preds, axis=1)\n result = metric.compute(predictions=preds, references=p.label_ids)\n result['default'] = result.get('f1', result.get('accuracy', 0.))\n return result\n\n tokenizer = BertTokenizerFast.from_pretrained('bert-base-uncased')\n train_dataset, validation_datasets = prepare_datasets(task_name, tokenizer, None)\n merged_validation_dataset = ConcatDataset([d for d in validation_datasets.values()])\n data_collator = DataCollatorWithPadding(tokenizer)\n training_args = TrainingArguments(output_dir='./output/trainer',\n do_train=True,\n do_eval=True,\n evaluation_strategy='steps',\n per_device_train_batch_size=32,\n per_device_eval_batch_size=32,\n num_train_epochs=3,\n dataloader_num_workers=12,\n learning_rate=3e-5,\n save_strategy='steps',\n save_total_limit=1,\n metric_for_best_model='default',\n load_best_model_at_end=load_best_model_at_end,\n disable_tqdm=True,\n optim='adamw_torch',\n seed=1024)\n trainer = nni.trace(Trainer)(model=model,\n args=training_args,\n data_collator=data_collator,\n train_dataset=train_dataset,\n eval_dataset=merged_validation_dataset,\n tokenizer=tokenizer,\n compute_metrics=compute_metrics,)\n return trainer"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"If the finetuned model is existed, directly load it.\nIf the finetuned model is not existed, train the pretrained model with the trainer.\n\n"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {
|
||||
"collapsed": false
|
||||
},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"def build_finetuning_model(task_name: str, state_dict_path: str):\n model = build_model('bert-base-uncased', task_name)\n if Path(state_dict_path).exists():\n model.load_state_dict(torch.load(state_dict_path))\n else:\n trainer = prepare_traced_trainer(model, task_name, True)\n trainer.train()\n torch.save(model.state_dict(), state_dict_path)\n return model\n\n\nif not skip_exec:\n Path('./output/bert_finetuned').mkdir(exist_ok=True, parents=True)\n build_finetuning_model(task_name, f'./output/bert_finetuned/{task_name}.bin')"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"The following code creates distillers for distillation.\n\n"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {
|
||||
"collapsed": false
|
||||
},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"from nni.contrib.compression.distillation import DynamicLayerwiseDistiller, Adaptive1dLayerwiseDistiller\nfrom nni.contrib.compression.utils import TransformersEvaluator"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"Dynamic distillation is suitable for the situation where the distillation states dimension of the student and the teacher match.\nA student state can try to distill on multiple teacher states, and finally select the teacher state with the smallest distillation loss as the target for distillation.\n\nIn this tutorial, dynamic distillation is applied before speedup the embedding pruning.\n\n"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {
|
||||
"collapsed": false
|
||||
},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"def dynamic_distiller(student_model: BertForSequenceClassification, teacher_model: BertForSequenceClassification,\n student_trainer: Trainer):\n layer_num = len(student_model.bert.encoder.layer)\n config_list = [{\n 'op_names': [f'bert.encoder.layer.{i}'],\n 'link': [f'bert.encoder.layer.{j}' for j in range(i, layer_num)],\n 'lambda': 0.9,\n 'apply_method': 'mse',\n } for i in range(layer_num)]\n config_list.append({\n 'op_names': ['classifier'],\n 'link': ['classifier'],\n 'lambda': 0.9,\n 'apply_method': 'kl',\n })\n\n evaluator = TransformersEvaluator(student_trainer)\n\n def teacher_predict(batch, teacher_model):\n return teacher_model(**batch)\n\n return DynamicLayerwiseDistiller(student_model, config_list, evaluator, teacher_model, teacher_predict, origin_loss_lambda=0.1)\n\n\ndef dynamic_distillation(student_model: BertForSequenceClassification, teacher_model: BertForSequenceClassification,\n max_steps: int | None, max_epochs: int | None):\n student_trainer = prepare_traced_trainer(student_model, task_name, True)\n\n ori_teacher_device = teacher_model.device\n training = teacher_model.training\n teacher_model.to(student_trainer.args.device).eval()\n\n distiller = dynamic_distiller(student_model, teacher_model, student_trainer)\n distiller.compress(max_steps, max_epochs)\n distiller.unwrap_model()\n\n teacher_model.to(ori_teacher_device).train(training)"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"Adapt distillation is applied after pruning embedding layers.\nThe hidden states dimension will mismatch between student model and teacher model after pruning embedding layers,\nthen adapt distiller will add a linear layer for each distillation module pair to align dimension.\nFor example, pruning hidden dimension from 768 to 384, then for each student transformer block,\nwill add a ``Linear(in_features=384, out_features=768)`` for shifting dimention 384 to 768,\naligned with the teacher model transformer block output.\n\n"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {
|
||||
"collapsed": false
|
||||
},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"def adapt_distiller(student_model: BertForSequenceClassification, teacher_model: BertForSequenceClassification,\n student_trainer: Trainer):\n layer_num = len(student_model.bert.encoder.layer)\n config_list = [{\n 'op_names': [f'bert.encoder.layer.{i}'],\n 'lambda': 0.9,\n 'apply_method': 'mse',\n } for i in range(layer_num)]\n config_list.append({\n 'op_names': ['classifier'],\n 'link': ['classifier'],\n 'lambda': 0.9,\n 'apply_method': 'kl',\n })\n\n evaluator = TransformersEvaluator(student_trainer)\n\n def teacher_predict(batch, teacher_model):\n return teacher_model(**batch)\n\n return Adaptive1dLayerwiseDistiller(student_model, config_list, evaluator, teacher_model, teacher_predict, origin_loss_lambda=0.1)\n\n\ndef adapt_distillation(student_model: BertForSequenceClassification, teacher_model: BertForSequenceClassification,\n max_steps: int | None, max_epochs: int | None):\n student_trainer = prepare_traced_trainer(student_model, task_name, True)\n\n ori_teacher_device = teacher_model.device\n training = teacher_model.training\n teacher_model.to(student_trainer.args.device).eval()\n\n distiller = adapt_distiller(student_model, teacher_model, student_trainer)\n dummy_input = (torch.randint(0, 10000, [8, 128]), torch.randint(0, 2, [8, 128]), torch.randint(0, 2, [8, 128]))\n dummy_input = [_.to(student_trainer.args.device) for _ in dummy_input]\n distiller.track_forward(*dummy_input)\n\n distiller.compress(max_steps, max_epochs)\n distiller.unwrap_model()\n\n teacher_model.to(ori_teacher_device).train(training)"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"### Pruning Attention Layers\n\nHere using ``MovementPruner`` to generate block sparse masks. Choosing ``64 x 64`` block is because the head width is 64,\nthis is a kind of coarse grained between head pruning and finegrained pruning, also you can have a try with ``64 x 32``,\n``32 x 32`` or any other granularity here.\n\nWe use ``sparse_threshold`` instead of ``sparse_ratio`` here to apply adaptive sparse allocation.\n``sparse_threshold`` here is a float number between 0. and 1., but its value has little effect on the final sparse ratio.\nIf you want a more sparse model, you could set a larger ``regular_scale`` in ``MovementPruner``.\nYou could refer to the experiment results to choose a appropriate ``regular_scale`` you like.\n\n"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {
|
||||
"collapsed": false
|
||||
},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"from nni.contrib.compression.pruning import MovementPruner\nfrom nni.compression.pytorch.speedup.v2 import ModelSpeedup\nfrom nni.compression.pytorch.speedup.v2.external_replacer import TransformersAttentionReplacer\n\n\ndef pruning_attn():\n Path('./output/bert_finetuned/').mkdir(parents=True, exist_ok=True)\n model = build_finetuning_model(task_name, f'./output/bert_finetuned/{task_name}.bin')\n trainer = prepare_traced_trainer(model, task_name)\n evaluator = TransformersEvaluator(trainer)\n\n config_list = [{\n 'op_types': ['Linear'],\n 'op_names_re': ['bert\\.encoder\\.layer\\.[0-9]*\\.attention\\.*'],\n 'sparse_threshold': 0.1,\n 'granularity': [64, 64]\n }]\n\n pruner = MovementPruner(model, config_list, evaluator, warmup_step=9000, cooldown_begin_step=36000, regular_scale=10)\n pruner.compress(None, 4)\n pruner.unwrap_model()\n\n masks = pruner.get_masks()\n Path('./output/pruning/').mkdir(parents=True, exist_ok=True)\n torch.save(masks, './output/pruning/attn_masks.pth')\n torch.save(model, './output/pruning/attn_masked_model.pth')\n\n\nif not skip_exec:\n pruning_attn()"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"We apply head pruning during the speedup stage, if the head is fully masked it will be pruned,\nif the header is partially masked, it will be restored.\n\n"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {
|
||||
"collapsed": false
|
||||
},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"def speedup_attn():\n model = torch.load('./output/pruning/attn_masked_model.pth', map_location='cpu')\n masks = torch.load('./output/pruning/attn_masks.pth', map_location='cpu')\n dummy_input = (torch.randint(0, 10000, [8, 128]), torch.randint(0, 2, [8, 128]), torch.randint(0, 2, [8, 128]))\n replacer = TransformersAttentionReplacer(model)\n ModelSpeedup(model, dummy_input, masks, customized_replacers=[replacer]).speedup_model()\n\n # finetuning\n teacher_model = build_finetuning_model('mnli', f'./output/bert_finetuned/{task_name}.bin')\n dynamic_distillation(model, teacher_model, None, 3)\n torch.save(model, './output/pruning/attn_pruned_model.pth')\n\n\nif not skip_exec:\n speedup_attn()"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"### Pruning Feed Forward Layers\n\nHere using ``TaylorPruner`` for pruning feed forward layers,\nand the sparse ratio related to the pruned head number in the same transformer block.\nThe more heads are pruned, the higher the sparse ratio is set for feed forward layers.\n\nNote that ``TaylorPruner`` has no schedule sparse ratio function,\nso we use ``AGPPruner`` to schedule the sparse ratio to achieve better pruning performance.\n\n"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {
|
||||
"collapsed": false
|
||||
},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"from nni.contrib.compression.pruning import TaylorPruner, AGPPruner\nfrom transformers.models.bert.modeling_bert import BertLayer\n\n\ndef pruning_ffn():\n model: BertForSequenceClassification = torch.load('./output/pruning/attn_pruned_model.pth')\n teacher_model: BertForSequenceClassification = build_finetuning_model('mnli', f'./output/bert_finetuned/{task_name}.bin')\n # create ffn config list, here simply use a linear function related to the number of retained heads to determine the sparse ratio\n config_list = []\n for name, module in model.named_modules():\n if isinstance(module, BertLayer):\n retained_head_num = module.attention.self.num_attention_heads\n ori_head_num = len(module.attention.pruned_heads) + retained_head_num\n ffn_sparse_ratio = 1 - retained_head_num / ori_head_num / 2\n config_list.append({'op_names': [f'{name}.intermediate.dense'], 'sparse_ratio': ffn_sparse_ratio})\n\n trainer = prepare_traced_trainer(model, task_name)\n teacher_model.eval().to(trainer.args.device)\n # create a distiller for restoring the accuracy\n distiller = dynamic_distiller(model, teacher_model, trainer)\n # fusion compress: TaylorPruner + DynamicLayerwiseDistiller\n taylor_pruner = TaylorPruner.from_compressor(distiller, config_list, 1000)\n # fusion compress: AGPPruner(TaylorPruner) + DynamicLayerwiseDistiller\n agp_pruner = AGPPruner(taylor_pruner, 1000, 36)\n agp_pruner.compress(None, 3)\n agp_pruner.unwrap_model()\n distiller.unwrap_teacher_model()\n\n masks = agp_pruner.get_masks()\n Path('./output/pruning/').mkdir(parents=True, exist_ok=True)\n torch.save(masks, './output/pruning/ffn_masks.pth')\n torch.save(model, './output/pruning/ffn_masked_model.pth')\n\n\nif not skip_exec:\n pruning_ffn()"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"Speedup the feed forward layers.\n\n"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {
|
||||
"collapsed": false
|
||||
},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"def speedup_ffn():\n model = torch.load('./output/pruning/ffn_masked_model.pth', map_location='cpu')\n masks = torch.load('./output/pruning/ffn_masks.pth', map_location='cpu')\n dummy_input = (torch.randint(0, 10000, [8, 128]), torch.randint(0, 2, [8, 128]), torch.randint(0, 2, [8, 128]))\n ModelSpeedup(model, dummy_input, masks).speedup_model()\n\n # finetuning\n teacher_model = build_finetuning_model('mnli', f'./output/bert_finetuned/{task_name}.bin')\n dynamic_distillation(model, teacher_model, None, 3)\n torch.save(model, './output/pruning/ffn_pruned_model.pth')\n\n\nif not skip_exec:\n speedup_ffn()"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"### Pruning Embedding Layers\n\nWe want to simulate the pruning effect better, so we register the output mask setting for ``BertAttention`` and ``BertOutput``.\nThe output masks can be generated and applied after register the setting template for them.\n\n"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {
|
||||
"collapsed": false
|
||||
},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"from nni.contrib.compression.base.setting import PruningSetting\n\noutput_align_setting = {\n '_output_': {\n 'align': {\n 'module_name': None,\n 'target_name': 'weight',\n 'dims': [0],\n },\n 'apply_method': 'mul',\n }\n}\nPruningSetting.register('BertAttention', output_align_setting)\nPruningSetting.register('BertOutput', output_align_setting)"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"Similar to prune feed forward layers, we also use ``AGPPruner + TaylorPruner + DynamicLayerwiseDistiller`` here.\nFor the better pruning effect simulation, set output ``align`` mask generation in ``config_list``,\nthen the relevant layers will generate its own output masks according to the embedding masks.\n\n"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {
|
||||
"collapsed": false
|
||||
},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"def pruning_embedding():\n model: BertForSequenceClassification = torch.load('./output/pruning/ffn_pruned_model.pth')\n teacher_model: BertForSequenceClassification = build_finetuning_model('mnli', f'./output/bert_finetuned/{task_name}.bin')\n\n sparse_ratio = 0.5\n config_list = [{\n 'op_types': ['Embedding'],\n 'op_names_re': ['bert\\.embeddings.*'],\n 'sparse_ratio': sparse_ratio,\n 'dependency_group_id': 1,\n 'granularity': [-1, 1],\n }, {\n 'op_names_re': ['bert\\.encoder\\.layer\\.[0-9]*\\.attention$',\n 'bert\\.encoder\\.layer\\.[0-9]*\\.output$'],\n 'target_names': ['_output_'],\n 'target_settings': {\n '_output_': {\n 'align': {\n 'module_name': 'bert.embeddings.word_embeddings',\n 'target_name': 'weight',\n 'dims': [1],\n }\n }\n }\n }, {\n 'op_names_re': ['bert\\.encoder\\.layer\\.[0-9]*\\.attention.output.dense',\n 'bert\\.encoder\\.layer\\.[0-9]*\\.output.dense'],\n 'target_names': ['weight'],\n 'target_settings': {\n 'weight': {\n 'granularity': 'out_channel',\n 'align': {\n 'module_name': 'bert.embeddings.word_embeddings',\n 'target_name': 'weight',\n 'dims': [1],\n }\n }\n }\n }]\n\n trainer = prepare_traced_trainer(model, task_name)\n teacher_model.eval().to(trainer.args.device)\n distiller = dynamic_distiller(model, teacher_model, trainer)\n taylor_pruner = TaylorPruner.from_compressor(distiller, config_list, 1000)\n agp_pruner = AGPPruner(taylor_pruner, 1000, 36)\n agp_pruner.compress(None, 3)\n agp_pruner.unwrap_model()\n distiller.unwrap_teacher_model()\n\n masks = agp_pruner.get_masks()\n Path('./output/pruning/').mkdir(parents=True, exist_ok=True)\n torch.save(masks, './output/pruning/embedding_masks.pth')\n torch.save(model, './output/pruning/embedding_masked_model.pth')\n\n\nif not skip_exec:\n pruning_embedding()"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"Speedup the embedding layers.\n\n"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {
|
||||
"collapsed": false
|
||||
},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"def speedup_embedding():\n model = torch.load('./output/pruning/embedding_masked_model.pth', map_location='cpu')\n masks = torch.load('./output/pruning/embedding_masks.pth', map_location='cpu')\n dummy_input = (torch.randint(0, 10000, [8, 128]), torch.randint(0, 2, [8, 128]), torch.randint(0, 2, [8, 128]))\n ModelSpeedup(model, dummy_input, masks).speedup_model()\n\n # finetuning\n teacher_model = build_finetuning_model('mnli', f'./output/bert_finetuned/{task_name}.bin')\n adapt_distillation(model, teacher_model, None, 4)\n torch.save(model, './output/pruning/embedding_pruned_model.pth')\n\n\nif not skip_exec:\n speedup_embedding()"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"### Evaluation\n\nEvaluate the pruned model size and accuracy.\n\n"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {
|
||||
"collapsed": false
|
||||
},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"def evaluate_pruned_model():\n model: BertForSequenceClassification = torch.load('./output/pruning/embedding_pruned_model.pth')\n trainer = prepare_traced_trainer(model, task_name)\n metric = trainer.evaluate()\n pruned_num_params = sum(param.numel() for param in model.parameters()) + sum(buffer.numel() for buffer in model.buffers())\n\n model = build_finetuning_model(task_name, f'./output/bert_finetuned/{task_name}.bin')\n ori_num_params = sum(param.numel() for param in model.parameters()) + sum(buffer.numel() for buffer in model.buffers())\n print(f'Metric: {metric}\\nSparsity: {1 - pruned_num_params / ori_num_params}')\n\n\nif not skip_exec:\n evaluate_pruned_model()"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"## Results\n\n.. list-table:: Prune Bert-base-uncased on MNLI\n :header-rows: 1\n :widths: auto\n\n * - Total Sparsity\n - Embedding Sparsity\n - Encoder Sparsity\n - Pooler Sparsity\n - Acc. (m/mm avg.)\n * - 0.%\n - 0.%\n - 0.%\n - 0.%\n - 84.95%\n * - 57.76%\n - 33.33% (15.89M)\n - 64.78% (29.96M)\n - 33.33% (0.39M)\n - 84.42%\n * - 68.31% (34.70M)\n - 50.00% (11.92M)\n - 73.57% (22.48M)\n - 50.00% (0.30M)\n - 83.33%\n * - 70.95% (31.81M)\n - 33.33% (15.89M)\n - 81.75% (15.52M)\n - 33.33% (0.39M)\n - 83.79%\n * - 78.20% (23.86M)\n - 50.00% (11.92M)\n - 86.31% (11.65M)\n - 50.00% (0.30M)\n - 82.53%\n * - 81.65% (20.12M)\n - 50.00% (11.92M)\n - 90.71% (7.90M)\n - 50.00% (0.30M)\n - 82.08%\n * - 84.32% (17.17M)\n - 50.00% (11.92M)\n - 94.18% (4.95M)\n - 50.00% (0.30M)\n - 81.35%\n\n"
|
||||
]
|
||||
}
|
||||
],
|
||||
"metadata": {
|
||||
"kernelspec": {
|
||||
"display_name": "Python 3",
|
||||
"language": "python",
|
||||
"name": "python3"
|
||||
},
|
||||
"language_info": {
|
||||
"codemirror_mode": {
|
||||
"name": "ipython",
|
||||
"version": 3
|
||||
},
|
||||
"file_extension": ".py",
|
||||
"mimetype": "text/x-python",
|
||||
"name": "python",
|
||||
"nbconvert_exporter": "python",
|
||||
"pygments_lexer": "ipython3",
|
||||
"version": "3.9.16"
|
||||
}
|
||||
},
|
||||
"nbformat": 4,
|
||||
"nbformat_minor": 0
|
||||
}
|
|
@ -0,0 +1,617 @@
|
|||
"""
|
||||
Pruning Bert on Task MNLI
|
||||
=========================
|
||||
|
||||
This is a new tutorial on pruning transformer in nni v3.0 (`old tutorial <https://nni.readthedocs.io/en/v2.9/tutorials/pruning_bert_glue.html>`__).
|
||||
The main difference between this tutorial and the previous is that it integrates the feature of fusion compression (pruning + distillation) in nni,
|
||||
uses a new more powerful and stable pruning speedup tool,
|
||||
and additionally prunes the whole model hidden dimensions which greatly reduces the model size (pruning embedding layers).
|
||||
|
||||
At the same time, the huggingface `transformers.Trainer <https://huggingface.co/docs/transformers/main_classes/trainer>`__ is used in this tutorial
|
||||
to reduce the burden of user writing training and evaluation logic.
|
||||
|
||||
Workable Pruning Process
|
||||
------------------------
|
||||
|
||||
The whole pruning process is divided into three steps:
|
||||
|
||||
1. pruning attention layers,
|
||||
2. pruning feed forward layers,
|
||||
3. pruning embedding layers.
|
||||
|
||||
In each step, the pruner is first used for simulated pruning to generate masks corresponding to the module pruning targets (weight, input, output).
|
||||
After that comes the speedup stage, sparsity propagation is used to explore the global redundancy due to the local masks,
|
||||
then modify the original model into a smaller one by replacing the sub module in the model.
|
||||
|
||||
The compression of the model naturally applies the distillation method,
|
||||
so in this tutorial, distillers will also be used to help restore the model accuracy.
|
||||
|
||||
Experiment
|
||||
----------
|
||||
|
||||
Preparations
|
||||
^^^^^^^^^^^^
|
||||
|
||||
The preparations mainly includes preparing the transformers trainer and model.
|
||||
|
||||
This is generally consistent with the preparations required to normally train a Bert model.
|
||||
The only difference is that the ``transformers.Trainer`` is needed to wrap by ``nni.trace`` to trace the init arguments,
|
||||
this is because nni need re-create trainer during training aware pruning and distilling.
|
||||
|
||||
.. note::
|
||||
|
||||
Please set ``skip_exec`` to ``False`` to run this tutorial. Here ``skip_exec`` is ``True`` by default is for generating documents.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from pathlib import Path
|
||||
|
||||
import numpy as np
|
||||
|
||||
import torch
|
||||
from torch.utils.data import ConcatDataset
|
||||
|
||||
import nni
|
||||
|
||||
from datasets import load_dataset, load_metric
|
||||
from transformers import BertTokenizerFast, DataCollatorWithPadding, BertForSequenceClassification, EvalPrediction
|
||||
from transformers.trainer import Trainer
|
||||
from transformers.training_args import TrainingArguments
|
||||
|
||||
skip_exec = True
|
||||
|
||||
# %%
|
||||
# Set the downstream task name here, you could replace the task with the task in GLUE.
|
||||
|
||||
task_name = 'mnli'
|
||||
|
||||
# %%
|
||||
# Here using BertForSequenceClassification as the base model for show case.
|
||||
# If you want to prune other kind of transformer model, you could replace the base model here.
|
||||
|
||||
def build_model(pretrained_model_name_or_path: str, task_name: str):
|
||||
is_regression = task_name == 'stsb'
|
||||
num_labels = 1 if is_regression else (3 if task_name == 'mnli' else 2)
|
||||
model = BertForSequenceClassification.from_pretrained(pretrained_model_name_or_path, num_labels=num_labels)
|
||||
return model
|
||||
|
||||
|
||||
# %%
|
||||
# Prepare the GLUE train & validation datasets, if the task has multi validation datasets, concat the datasets by ``ConcatDataset``.
|
||||
|
||||
def prepare_datasets(task_name: str, tokenizer: BertTokenizerFast, cache_dir: str):
|
||||
task_to_keys = {
|
||||
'cola': ('sentence', None),
|
||||
'mnli': ('premise', 'hypothesis'),
|
||||
'mrpc': ('sentence1', 'sentence2'),
|
||||
'qnli': ('question', 'sentence'),
|
||||
'qqp': ('question1', 'question2'),
|
||||
'rte': ('sentence1', 'sentence2'),
|
||||
'sst2': ('sentence', None),
|
||||
'stsb': ('sentence1', 'sentence2'),
|
||||
'wnli': ('sentence1', 'sentence2'),
|
||||
}
|
||||
sentence1_key, sentence2_key = task_to_keys[task_name]
|
||||
|
||||
# used to preprocess the raw data
|
||||
def preprocess_function(examples):
|
||||
# Tokenize the texts
|
||||
args = (
|
||||
(examples[sentence1_key],) if sentence2_key is None else (examples[sentence1_key], examples[sentence2_key])
|
||||
)
|
||||
result = tokenizer(*args, padding=False, max_length=128, truncation=True)
|
||||
|
||||
if 'label' in examples:
|
||||
# In all cases, rename the column to labels because the model will expect that.
|
||||
result['labels'] = examples['label']
|
||||
return result
|
||||
|
||||
raw_datasets = load_dataset('glue', task_name, cache_dir=cache_dir)
|
||||
for key in list(raw_datasets.keys()):
|
||||
if 'test' in key:
|
||||
raw_datasets.pop(key)
|
||||
|
||||
processed_datasets = raw_datasets.map(preprocess_function, batched=True,
|
||||
remove_columns=raw_datasets['train'].column_names)
|
||||
|
||||
train_dataset = processed_datasets['train']
|
||||
if task_name == 'mnli':
|
||||
validation_datasets = {
|
||||
'validation_matched': processed_datasets['validation_matched'],
|
||||
'validation_mismatched': processed_datasets['validation_mismatched']
|
||||
}
|
||||
else:
|
||||
validation_datasets = {
|
||||
'validation': processed_datasets['validation']
|
||||
}
|
||||
|
||||
return train_dataset, validation_datasets
|
||||
|
||||
|
||||
# %%
|
||||
# Prepare the trainer, note that the ``Trainer`` class is wrapped by ``nni.trace``.
|
||||
|
||||
|
||||
def prepare_traced_trainer(model, task_name, load_best_model_at_end=False):
|
||||
is_regression = task_name == 'stsb'
|
||||
metric = load_metric('glue', task_name)
|
||||
|
||||
def compute_metrics(p: EvalPrediction):
|
||||
preds = p.predictions[0] if isinstance(p.predictions, tuple) else p.predictions
|
||||
preds = np.squeeze(preds) if is_regression else np.argmax(preds, axis=1)
|
||||
result = metric.compute(predictions=preds, references=p.label_ids)
|
||||
result['default'] = result.get('f1', result.get('accuracy', 0.))
|
||||
return result
|
||||
|
||||
tokenizer = BertTokenizerFast.from_pretrained('bert-base-uncased')
|
||||
train_dataset, validation_datasets = prepare_datasets(task_name, tokenizer, None)
|
||||
merged_validation_dataset = ConcatDataset([d for d in validation_datasets.values()])
|
||||
data_collator = DataCollatorWithPadding(tokenizer)
|
||||
training_args = TrainingArguments(output_dir='./output/trainer',
|
||||
do_train=True,
|
||||
do_eval=True,
|
||||
evaluation_strategy='steps',
|
||||
per_device_train_batch_size=32,
|
||||
per_device_eval_batch_size=32,
|
||||
num_train_epochs=3,
|
||||
dataloader_num_workers=12,
|
||||
learning_rate=3e-5,
|
||||
save_strategy='steps',
|
||||
save_total_limit=1,
|
||||
metric_for_best_model='default',
|
||||
load_best_model_at_end=load_best_model_at_end,
|
||||
disable_tqdm=True,
|
||||
optim='adamw_torch',
|
||||
seed=1024)
|
||||
trainer = nni.trace(Trainer)(model=model,
|
||||
args=training_args,
|
||||
data_collator=data_collator,
|
||||
train_dataset=train_dataset,
|
||||
eval_dataset=merged_validation_dataset,
|
||||
tokenizer=tokenizer,
|
||||
compute_metrics=compute_metrics,)
|
||||
return trainer
|
||||
|
||||
|
||||
# %%
|
||||
# If the finetuned model is existed, directly load it.
|
||||
# If the finetuned model is not existed, train the pretrained model with the trainer.
|
||||
|
||||
|
||||
def build_finetuning_model(task_name: str, state_dict_path: str):
|
||||
model = build_model('bert-base-uncased', task_name)
|
||||
if Path(state_dict_path).exists():
|
||||
model.load_state_dict(torch.load(state_dict_path))
|
||||
else:
|
||||
trainer = prepare_traced_trainer(model, task_name, True)
|
||||
trainer.train()
|
||||
torch.save(model.state_dict(), state_dict_path)
|
||||
return model
|
||||
|
||||
|
||||
if not skip_exec:
|
||||
Path('./output/bert_finetuned').mkdir(exist_ok=True, parents=True)
|
||||
build_finetuning_model(task_name, f'./output/bert_finetuned/{task_name}.bin')
|
||||
|
||||
|
||||
# %%
|
||||
# The following code creates distillers for distillation.
|
||||
|
||||
|
||||
from nni.contrib.compression.distillation import DynamicLayerwiseDistiller, Adaptive1dLayerwiseDistiller
|
||||
from nni.contrib.compression.utils import TransformersEvaluator
|
||||
|
||||
# %%
|
||||
# Dynamic distillation is suitable for the situation where the distillation states dimension of the student and the teacher match.
|
||||
# A student state can try to distill on multiple teacher states, and finally select the teacher state with the smallest distillation loss as the target for distillation.
|
||||
#
|
||||
# In this tutorial, dynamic distillation is applied before speedup the embedding pruning.
|
||||
|
||||
def dynamic_distiller(student_model: BertForSequenceClassification, teacher_model: BertForSequenceClassification,
|
||||
student_trainer: Trainer):
|
||||
layer_num = len(student_model.bert.encoder.layer)
|
||||
config_list = [{
|
||||
'op_names': [f'bert.encoder.layer.{i}'],
|
||||
'link': [f'bert.encoder.layer.{j}' for j in range(i, layer_num)],
|
||||
'lambda': 0.9,
|
||||
'apply_method': 'mse',
|
||||
} for i in range(layer_num)]
|
||||
config_list.append({
|
||||
'op_names': ['classifier'],
|
||||
'link': ['classifier'],
|
||||
'lambda': 0.9,
|
||||
'apply_method': 'kl',
|
||||
})
|
||||
|
||||
evaluator = TransformersEvaluator(student_trainer)
|
||||
|
||||
def teacher_predict(batch, teacher_model):
|
||||
return teacher_model(**batch)
|
||||
|
||||
return DynamicLayerwiseDistiller(student_model, config_list, evaluator, teacher_model, teacher_predict, origin_loss_lambda=0.1)
|
||||
|
||||
|
||||
def dynamic_distillation(student_model: BertForSequenceClassification, teacher_model: BertForSequenceClassification,
|
||||
max_steps: int | None, max_epochs: int | None):
|
||||
student_trainer = prepare_traced_trainer(student_model, task_name, True)
|
||||
|
||||
ori_teacher_device = teacher_model.device
|
||||
training = teacher_model.training
|
||||
teacher_model.to(student_trainer.args.device).eval()
|
||||
|
||||
distiller = dynamic_distiller(student_model, teacher_model, student_trainer)
|
||||
distiller.compress(max_steps, max_epochs)
|
||||
distiller.unwrap_model()
|
||||
|
||||
teacher_model.to(ori_teacher_device).train(training)
|
||||
|
||||
|
||||
# %%
|
||||
# Adapt distillation is applied after pruning embedding layers.
|
||||
# The hidden states dimension will mismatch between student model and teacher model after pruning embedding layers,
|
||||
# then adapt distiller will add a linear layer for each distillation module pair to align dimension.
|
||||
# For example, pruning hidden dimension from 768 to 384, then for each student transformer block,
|
||||
# will add a ``Linear(in_features=384, out_features=768)`` for shifting dimention 384 to 768,
|
||||
# aligned with the teacher model transformer block output.
|
||||
|
||||
|
||||
def adapt_distiller(student_model: BertForSequenceClassification, teacher_model: BertForSequenceClassification,
|
||||
student_trainer: Trainer):
|
||||
layer_num = len(student_model.bert.encoder.layer)
|
||||
config_list = [{
|
||||
'op_names': [f'bert.encoder.layer.{i}'],
|
||||
'lambda': 0.9,
|
||||
'apply_method': 'mse',
|
||||
} for i in range(layer_num)]
|
||||
config_list.append({
|
||||
'op_names': ['classifier'],
|
||||
'link': ['classifier'],
|
||||
'lambda': 0.9,
|
||||
'apply_method': 'kl',
|
||||
})
|
||||
|
||||
evaluator = TransformersEvaluator(student_trainer)
|
||||
|
||||
def teacher_predict(batch, teacher_model):
|
||||
return teacher_model(**batch)
|
||||
|
||||
return Adaptive1dLayerwiseDistiller(student_model, config_list, evaluator, teacher_model, teacher_predict, origin_loss_lambda=0.1)
|
||||
|
||||
|
||||
def adapt_distillation(student_model: BertForSequenceClassification, teacher_model: BertForSequenceClassification,
|
||||
max_steps: int | None, max_epochs: int | None):
|
||||
student_trainer = prepare_traced_trainer(student_model, task_name, True)
|
||||
|
||||
ori_teacher_device = teacher_model.device
|
||||
training = teacher_model.training
|
||||
teacher_model.to(student_trainer.args.device).eval()
|
||||
|
||||
distiller = adapt_distiller(student_model, teacher_model, student_trainer)
|
||||
dummy_input = (torch.randint(0, 10000, [8, 128]), torch.randint(0, 2, [8, 128]), torch.randint(0, 2, [8, 128]))
|
||||
dummy_input = [_.to(student_trainer.args.device) for _ in dummy_input]
|
||||
distiller.track_forward(*dummy_input)
|
||||
|
||||
distiller.compress(max_steps, max_epochs)
|
||||
distiller.unwrap_model()
|
||||
|
||||
teacher_model.to(ori_teacher_device).train(training)
|
||||
|
||||
|
||||
# %%
|
||||
# Pruning Attention Layers
|
||||
# ^^^^^^^^^^^^^^^^^^^^^^^^
|
||||
#
|
||||
# Here using ``MovementPruner`` to generate block sparse masks. Choosing ``64 x 64`` block is because the head width is 64,
|
||||
# this is a kind of coarse grained between head pruning and finegrained pruning, also you can have a try with ``64 x 32``,
|
||||
# ``32 x 32`` or any other granularity here.
|
||||
#
|
||||
# We use ``sparse_threshold`` instead of ``sparse_ratio`` here to apply adaptive sparse allocation.
|
||||
# ``sparse_threshold`` here is a float number between 0. and 1., but its value has little effect on the final sparse ratio.
|
||||
# If you want a more sparse model, you could set a larger ``regular_scale`` in ``MovementPruner``.
|
||||
# You could refer to the experiment results to choose a appropriate ``regular_scale`` you like.
|
||||
|
||||
|
||||
from nni.contrib.compression.pruning import MovementPruner
|
||||
from nni.compression.pytorch.speedup.v2 import ModelSpeedup
|
||||
from nni.compression.pytorch.speedup.v2.external_replacer import TransformersAttentionReplacer
|
||||
|
||||
|
||||
def pruning_attn():
|
||||
Path('./output/bert_finetuned/').mkdir(parents=True, exist_ok=True)
|
||||
model = build_finetuning_model(task_name, f'./output/bert_finetuned/{task_name}.bin')
|
||||
trainer = prepare_traced_trainer(model, task_name)
|
||||
evaluator = TransformersEvaluator(trainer)
|
||||
|
||||
config_list = [{
|
||||
'op_types': ['Linear'],
|
||||
'op_names_re': ['bert\.encoder\.layer\.[0-9]*\.attention\.*'],
|
||||
'sparse_threshold': 0.1,
|
||||
'granularity': [64, 64]
|
||||
}]
|
||||
|
||||
pruner = MovementPruner(model, config_list, evaluator, warmup_step=9000, cooldown_begin_step=36000, regular_scale=10)
|
||||
pruner.compress(None, 4)
|
||||
pruner.unwrap_model()
|
||||
|
||||
masks = pruner.get_masks()
|
||||
Path('./output/pruning/').mkdir(parents=True, exist_ok=True)
|
||||
torch.save(masks, './output/pruning/attn_masks.pth')
|
||||
torch.save(model, './output/pruning/attn_masked_model.pth')
|
||||
|
||||
|
||||
if not skip_exec:
|
||||
pruning_attn()
|
||||
|
||||
|
||||
# %%
|
||||
# We apply head pruning during the speedup stage, if the head is fully masked it will be pruned,
|
||||
# if the header is partially masked, it will be restored.
|
||||
|
||||
|
||||
def speedup_attn():
|
||||
model = torch.load('./output/pruning/attn_masked_model.pth', map_location='cpu')
|
||||
masks = torch.load('./output/pruning/attn_masks.pth', map_location='cpu')
|
||||
dummy_input = (torch.randint(0, 10000, [8, 128]), torch.randint(0, 2, [8, 128]), torch.randint(0, 2, [8, 128]))
|
||||
replacer = TransformersAttentionReplacer(model)
|
||||
ModelSpeedup(model, dummy_input, masks, customized_replacers=[replacer]).speedup_model()
|
||||
|
||||
# finetuning
|
||||
teacher_model = build_finetuning_model('mnli', f'./output/bert_finetuned/{task_name}.bin')
|
||||
dynamic_distillation(model, teacher_model, None, 3)
|
||||
torch.save(model, './output/pruning/attn_pruned_model.pth')
|
||||
|
||||
|
||||
if not skip_exec:
|
||||
speedup_attn()
|
||||
|
||||
|
||||
# %%
|
||||
# Pruning Feed Forward Layers
|
||||
# ^^^^^^^^^^^^^^^^^^^^^^^^^^^
|
||||
#
|
||||
# Here using ``TaylorPruner`` for pruning feed forward layers,
|
||||
# and the sparse ratio related to the pruned head number in the same transformer block.
|
||||
# The more heads are pruned, the higher the sparse ratio is set for feed forward layers.
|
||||
#
|
||||
# Note that ``TaylorPruner`` has no schedule sparse ratio function,
|
||||
# so we use ``AGPPruner`` to schedule the sparse ratio to achieve better pruning performance.
|
||||
|
||||
|
||||
from nni.contrib.compression.pruning import TaylorPruner, AGPPruner
|
||||
from transformers.models.bert.modeling_bert import BertLayer
|
||||
|
||||
|
||||
def pruning_ffn():
|
||||
model: BertForSequenceClassification = torch.load('./output/pruning/attn_pruned_model.pth')
|
||||
teacher_model: BertForSequenceClassification = build_finetuning_model('mnli', f'./output/bert_finetuned/{task_name}.bin')
|
||||
# create ffn config list, here simply use a linear function related to the number of retained heads to determine the sparse ratio
|
||||
config_list = []
|
||||
for name, module in model.named_modules():
|
||||
if isinstance(module, BertLayer):
|
||||
retained_head_num = module.attention.self.num_attention_heads
|
||||
ori_head_num = len(module.attention.pruned_heads) + retained_head_num
|
||||
ffn_sparse_ratio = 1 - retained_head_num / ori_head_num / 2
|
||||
config_list.append({'op_names': [f'{name}.intermediate.dense'], 'sparse_ratio': ffn_sparse_ratio})
|
||||
|
||||
trainer = prepare_traced_trainer(model, task_name)
|
||||
teacher_model.eval().to(trainer.args.device)
|
||||
# create a distiller for restoring the accuracy
|
||||
distiller = dynamic_distiller(model, teacher_model, trainer)
|
||||
# fusion compress: TaylorPruner + DynamicLayerwiseDistiller
|
||||
taylor_pruner = TaylorPruner.from_compressor(distiller, config_list, 1000)
|
||||
# fusion compress: AGPPruner(TaylorPruner) + DynamicLayerwiseDistiller
|
||||
agp_pruner = AGPPruner(taylor_pruner, 1000, 36)
|
||||
agp_pruner.compress(None, 3)
|
||||
agp_pruner.unwrap_model()
|
||||
distiller.unwrap_teacher_model()
|
||||
|
||||
masks = agp_pruner.get_masks()
|
||||
Path('./output/pruning/').mkdir(parents=True, exist_ok=True)
|
||||
torch.save(masks, './output/pruning/ffn_masks.pth')
|
||||
torch.save(model, './output/pruning/ffn_masked_model.pth')
|
||||
|
||||
|
||||
if not skip_exec:
|
||||
pruning_ffn()
|
||||
|
||||
|
||||
# %%
|
||||
# Speedup the feed forward layers.
|
||||
|
||||
|
||||
def speedup_ffn():
|
||||
model = torch.load('./output/pruning/ffn_masked_model.pth', map_location='cpu')
|
||||
masks = torch.load('./output/pruning/ffn_masks.pth', map_location='cpu')
|
||||
dummy_input = (torch.randint(0, 10000, [8, 128]), torch.randint(0, 2, [8, 128]), torch.randint(0, 2, [8, 128]))
|
||||
ModelSpeedup(model, dummy_input, masks).speedup_model()
|
||||
|
||||
# finetuning
|
||||
teacher_model = build_finetuning_model('mnli', f'./output/bert_finetuned/{task_name}.bin')
|
||||
dynamic_distillation(model, teacher_model, None, 3)
|
||||
torch.save(model, './output/pruning/ffn_pruned_model.pth')
|
||||
|
||||
|
||||
if not skip_exec:
|
||||
speedup_ffn()
|
||||
|
||||
|
||||
# %%
|
||||
# Pruning Embedding Layers
|
||||
# ^^^^^^^^^^^^^^^^^^^^^^^^
|
||||
#
|
||||
# We want to simulate the pruning effect better, so we register the output mask setting for ``BertAttention`` and ``BertOutput``.
|
||||
# The output masks can be generated and applied after register the setting template for them.
|
||||
|
||||
|
||||
from nni.contrib.compression.base.setting import PruningSetting
|
||||
|
||||
output_align_setting = {
|
||||
'_output_': {
|
||||
'align': {
|
||||
'module_name': None,
|
||||
'target_name': 'weight',
|
||||
'dims': [0],
|
||||
},
|
||||
'apply_method': 'mul',
|
||||
}
|
||||
}
|
||||
PruningSetting.register('BertAttention', output_align_setting)
|
||||
PruningSetting.register('BertOutput', output_align_setting)
|
||||
|
||||
|
||||
# %%
|
||||
# Similar to prune feed forward layers, we also use ``AGPPruner + TaylorPruner + DynamicLayerwiseDistiller`` here.
|
||||
# For the better pruning effect simulation, set output ``align`` mask generation in ``config_list``,
|
||||
# then the relevant layers will generate its own output masks according to the embedding masks.
|
||||
|
||||
|
||||
def pruning_embedding():
|
||||
model: BertForSequenceClassification = torch.load('./output/pruning/ffn_pruned_model.pth')
|
||||
teacher_model: BertForSequenceClassification = build_finetuning_model('mnli', f'./output/bert_finetuned/{task_name}.bin')
|
||||
|
||||
sparse_ratio = 0.5
|
||||
config_list = [{
|
||||
'op_types': ['Embedding'],
|
||||
'op_names_re': ['bert\.embeddings.*'],
|
||||
'sparse_ratio': sparse_ratio,
|
||||
'dependency_group_id': 1,
|
||||
'granularity': [-1, 1],
|
||||
}, {
|
||||
'op_names_re': ['bert\.encoder\.layer\.[0-9]*\.attention$',
|
||||
'bert\.encoder\.layer\.[0-9]*\.output$'],
|
||||
'target_names': ['_output_'],
|
||||
'target_settings': {
|
||||
'_output_': {
|
||||
'align': {
|
||||
'module_name': 'bert.embeddings.word_embeddings',
|
||||
'target_name': 'weight',
|
||||
'dims': [1],
|
||||
}
|
||||
}
|
||||
}
|
||||
}, {
|
||||
'op_names_re': ['bert\.encoder\.layer\.[0-9]*\.attention.output.dense',
|
||||
'bert\.encoder\.layer\.[0-9]*\.output.dense'],
|
||||
'target_names': ['weight'],
|
||||
'target_settings': {
|
||||
'weight': {
|
||||
'granularity': 'out_channel',
|
||||
'align': {
|
||||
'module_name': 'bert.embeddings.word_embeddings',
|
||||
'target_name': 'weight',
|
||||
'dims': [1],
|
||||
}
|
||||
}
|
||||
}
|
||||
}]
|
||||
|
||||
trainer = prepare_traced_trainer(model, task_name)
|
||||
teacher_model.eval().to(trainer.args.device)
|
||||
distiller = dynamic_distiller(model, teacher_model, trainer)
|
||||
taylor_pruner = TaylorPruner.from_compressor(distiller, config_list, 1000)
|
||||
agp_pruner = AGPPruner(taylor_pruner, 1000, 36)
|
||||
agp_pruner.compress(None, 3)
|
||||
agp_pruner.unwrap_model()
|
||||
distiller.unwrap_teacher_model()
|
||||
|
||||
masks = agp_pruner.get_masks()
|
||||
Path('./output/pruning/').mkdir(parents=True, exist_ok=True)
|
||||
torch.save(masks, './output/pruning/embedding_masks.pth')
|
||||
torch.save(model, './output/pruning/embedding_masked_model.pth')
|
||||
|
||||
|
||||
if not skip_exec:
|
||||
pruning_embedding()
|
||||
|
||||
|
||||
# %%
|
||||
# Speedup the embedding layers.
|
||||
|
||||
|
||||
def speedup_embedding():
|
||||
model = torch.load('./output/pruning/embedding_masked_model.pth', map_location='cpu')
|
||||
masks = torch.load('./output/pruning/embedding_masks.pth', map_location='cpu')
|
||||
dummy_input = (torch.randint(0, 10000, [8, 128]), torch.randint(0, 2, [8, 128]), torch.randint(0, 2, [8, 128]))
|
||||
ModelSpeedup(model, dummy_input, masks).speedup_model()
|
||||
|
||||
# finetuning
|
||||
teacher_model = build_finetuning_model('mnli', f'./output/bert_finetuned/{task_name}.bin')
|
||||
adapt_distillation(model, teacher_model, None, 4)
|
||||
torch.save(model, './output/pruning/embedding_pruned_model.pth')
|
||||
|
||||
|
||||
if not skip_exec:
|
||||
speedup_embedding()
|
||||
|
||||
|
||||
# %%
|
||||
# Evaluation
|
||||
# ^^^^^^^^^^
|
||||
#
|
||||
# Evaluate the pruned model size and accuracy.
|
||||
|
||||
|
||||
def evaluate_pruned_model():
|
||||
model: BertForSequenceClassification = torch.load('./output/pruning/embedding_pruned_model.pth')
|
||||
trainer = prepare_traced_trainer(model, task_name)
|
||||
metric = trainer.evaluate()
|
||||
pruned_num_params = sum(param.numel() for param in model.parameters()) + sum(buffer.numel() for buffer in model.buffers())
|
||||
|
||||
model = build_finetuning_model(task_name, f'./output/bert_finetuned/{task_name}.bin')
|
||||
ori_num_params = sum(param.numel() for param in model.parameters()) + sum(buffer.numel() for buffer in model.buffers())
|
||||
print(f'Metric: {metric}\nSparsity: {1 - pruned_num_params / ori_num_params}')
|
||||
|
||||
|
||||
if not skip_exec:
|
||||
evaluate_pruned_model()
|
||||
|
||||
|
||||
# %%
|
||||
# Results
|
||||
# -------
|
||||
#
|
||||
# .. list-table:: Prune Bert-base-uncased on MNLI
|
||||
# :header-rows: 1
|
||||
# :widths: auto
|
||||
#
|
||||
# * - Total Sparsity
|
||||
# - Embedding Sparsity
|
||||
# - Encoder Sparsity
|
||||
# - Pooler Sparsity
|
||||
# - Acc. (m/mm avg.)
|
||||
# * - 0.%
|
||||
# - 0.%
|
||||
# - 0.%
|
||||
# - 0.%
|
||||
# - 84.95%
|
||||
# * - 57.76%
|
||||
# - 33.33% (15.89M)
|
||||
# - 64.78% (29.96M)
|
||||
# - 33.33% (0.39M)
|
||||
# - 84.42%
|
||||
# * - 68.31% (34.70M)
|
||||
# - 50.00% (11.92M)
|
||||
# - 73.57% (22.48M)
|
||||
# - 50.00% (0.30M)
|
||||
# - 83.33%
|
||||
# * - 70.95% (31.81M)
|
||||
# - 33.33% (15.89M)
|
||||
# - 81.75% (15.52M)
|
||||
# - 33.33% (0.39M)
|
||||
# - 83.79%
|
||||
# * - 78.20% (23.86M)
|
||||
# - 50.00% (11.92M)
|
||||
# - 86.31% (11.65M)
|
||||
# - 50.00% (0.30M)
|
||||
# - 82.53%
|
||||
# * - 81.65% (20.12M)
|
||||
# - 50.00% (11.92M)
|
||||
# - 90.71% (7.90M)
|
||||
# - 50.00% (0.30M)
|
||||
# - 82.08%
|
||||
# * - 84.32% (17.17M)
|
||||
# - 50.00% (11.92M)
|
||||
# - 94.18% (4.95M)
|
||||
# - 50.00% (0.30M)
|
||||
# - 81.35%
|
|
@ -0,0 +1 @@
|
|||
3e81f00f13fab8cfc204a0baef7d075e
|
|
@ -0,0 +1,884 @@
|
|||
|
||||
.. DO NOT EDIT.
|
||||
.. THIS FILE WAS AUTOMATICALLY GENERATED BY SPHINX-GALLERY.
|
||||
.. TO MAKE CHANGES, EDIT THE SOURCE PYTHON FILE:
|
||||
.. "tutorials/new_pruning_bert_glue.py"
|
||||
.. LINE NUMBERS ARE GIVEN BELOW.
|
||||
|
||||
.. only:: html
|
||||
|
||||
.. note::
|
||||
:class: sphx-glr-download-link-note
|
||||
|
||||
Click :ref:`here <sphx_glr_download_tutorials_new_pruning_bert_glue.py>`
|
||||
to download the full example code
|
||||
|
||||
.. rst-class:: sphx-glr-example-title
|
||||
|
||||
.. _sphx_glr_tutorials_new_pruning_bert_glue.py:
|
||||
|
||||
|
||||
Pruning Bert on Task MNLI
|
||||
=========================
|
||||
|
||||
This is a new tutorial on pruning transformer in nni v3.0 (`old tutorial <https://nni.readthedocs.io/en/v2.9/tutorials/pruning_bert_glue.html>`__).
|
||||
The main difference between this tutorial and the previous is that it integrates the feature of fusion compression (pruning + distillation) in nni,
|
||||
uses a new more powerful and stable pruning speedup tool,
|
||||
and additionally prunes the whole model hidden dimensions which greatly reduces the model size (pruning embedding layers).
|
||||
|
||||
At the same time, the huggingface `transformers.Trainer <https://huggingface.co/docs/transformers/main_classes/trainer>`__ is used in this tutorial
|
||||
to reduce the burden of user writing training and evaluation logic.
|
||||
|
||||
Workable Pruning Process
|
||||
------------------------
|
||||
|
||||
The whole pruning process is divided into three steps:
|
||||
|
||||
1. pruning attention layers,
|
||||
2. pruning feed forward layers,
|
||||
3. pruning embedding layers.
|
||||
|
||||
In each step, the pruner is first used for simulated pruning to generate masks corresponding to the module pruning targets (weight, input, output).
|
||||
After that comes the speedup stage, sparsity propagation is used to explore the global redundancy due to the local masks,
|
||||
then modify the original model into a smaller one by replacing the sub module in the model.
|
||||
|
||||
The compression of the model naturally applies the distillation method,
|
||||
so in this tutorial, distillers will also be used to help restore the model accuracy.
|
||||
|
||||
Experiment
|
||||
----------
|
||||
|
||||
Preparations
|
||||
^^^^^^^^^^^^
|
||||
|
||||
The preparations mainly includes preparing the transformers trainer and model.
|
||||
|
||||
This is generally consistent with the preparations required to normally train a Bert model.
|
||||
The only difference is that the ``transformers.Trainer`` is needed to wrap by ``nni.trace`` to trace the init arguments,
|
||||
this is because nni need re-create trainer during training aware pruning and distilling.
|
||||
|
||||
.. note::
|
||||
|
||||
Please set ``skip_exec`` to ``False`` to run this tutorial. Here ``skip_exec`` is ``True`` by default is for generating documents.
|
||||
|
||||
.. GENERATED FROM PYTHON SOURCE LINES 45-64
|
||||
|
||||
.. code-block:: default
|
||||
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from pathlib import Path
|
||||
|
||||
import numpy as np
|
||||
|
||||
import torch
|
||||
from torch.utils.data import ConcatDataset
|
||||
|
||||
import nni
|
||||
|
||||
from datasets import load_dataset, load_metric
|
||||
from transformers import BertTokenizerFast, DataCollatorWithPadding, BertForSequenceClassification, EvalPrediction
|
||||
from transformers.trainer import Trainer
|
||||
from transformers.training_args import TrainingArguments
|
||||
|
||||
skip_exec = True
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
.. GENERATED FROM PYTHON SOURCE LINES 65-66
|
||||
|
||||
Set the downstream task name here, you could replace the task with the task in GLUE.
|
||||
|
||||
.. GENERATED FROM PYTHON SOURCE LINES 66-69
|
||||
|
||||
.. code-block:: default
|
||||
|
||||
|
||||
task_name = 'mnli'
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
.. GENERATED FROM PYTHON SOURCE LINES 70-72
|
||||
|
||||
Here using BertForSequenceClassification as the base model for show case.
|
||||
If you want to prune other kind of transformer model, you could replace the base model here.
|
||||
|
||||
.. GENERATED FROM PYTHON SOURCE LINES 72-80
|
||||
|
||||
.. code-block:: default
|
||||
|
||||
|
||||
def build_model(pretrained_model_name_or_path: str, task_name: str):
|
||||
is_regression = task_name == 'stsb'
|
||||
num_labels = 1 if is_regression else (3 if task_name == 'mnli' else 2)
|
||||
model = BertForSequenceClassification.from_pretrained(pretrained_model_name_or_path, num_labels=num_labels)
|
||||
return model
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
.. GENERATED FROM PYTHON SOURCE LINES 81-82
|
||||
|
||||
Prepare the GLUE train & validation datasets, if the task has multi validation datasets, concat the datasets by ``ConcatDataset``.
|
||||
|
||||
.. GENERATED FROM PYTHON SOURCE LINES 82-132
|
||||
|
||||
.. code-block:: default
|
||||
|
||||
|
||||
def prepare_datasets(task_name: str, tokenizer: BertTokenizerFast, cache_dir: str):
|
||||
task_to_keys = {
|
||||
'cola': ('sentence', None),
|
||||
'mnli': ('premise', 'hypothesis'),
|
||||
'mrpc': ('sentence1', 'sentence2'),
|
||||
'qnli': ('question', 'sentence'),
|
||||
'qqp': ('question1', 'question2'),
|
||||
'rte': ('sentence1', 'sentence2'),
|
||||
'sst2': ('sentence', None),
|
||||
'stsb': ('sentence1', 'sentence2'),
|
||||
'wnli': ('sentence1', 'sentence2'),
|
||||
}
|
||||
sentence1_key, sentence2_key = task_to_keys[task_name]
|
||||
|
||||
# used to preprocess the raw data
|
||||
def preprocess_function(examples):
|
||||
# Tokenize the texts
|
||||
args = (
|
||||
(examples[sentence1_key],) if sentence2_key is None else (examples[sentence1_key], examples[sentence2_key])
|
||||
)
|
||||
result = tokenizer(*args, padding=False, max_length=128, truncation=True)
|
||||
|
||||
if 'label' in examples:
|
||||
# In all cases, rename the column to labels because the model will expect that.
|
||||
result['labels'] = examples['label']
|
||||
return result
|
||||
|
||||
raw_datasets = load_dataset('glue', task_name, cache_dir=cache_dir)
|
||||
for key in list(raw_datasets.keys()):
|
||||
if 'test' in key:
|
||||
raw_datasets.pop(key)
|
||||
|
||||
processed_datasets = raw_datasets.map(preprocess_function, batched=True,
|
||||
remove_columns=raw_datasets['train'].column_names)
|
||||
|
||||
train_dataset = processed_datasets['train']
|
||||
if task_name == 'mnli':
|
||||
validation_datasets = {
|
||||
'validation_matched': processed_datasets['validation_matched'],
|
||||
'validation_mismatched': processed_datasets['validation_mismatched']
|
||||
}
|
||||
else:
|
||||
validation_datasets = {
|
||||
'validation': processed_datasets['validation']
|
||||
}
|
||||
|
||||
return train_dataset, validation_datasets
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
.. GENERATED FROM PYTHON SOURCE LINES 133-134
|
||||
|
||||
Prepare the trainer, note that the ``Trainer`` class is wrapped by ``nni.trace``.
|
||||
|
||||
.. GENERATED FROM PYTHON SOURCE LINES 134-177
|
||||
|
||||
.. code-block:: default
|
||||
|
||||
|
||||
|
||||
def prepare_traced_trainer(model, task_name, load_best_model_at_end=False):
|
||||
is_regression = task_name == 'stsb'
|
||||
metric = load_metric('glue', task_name)
|
||||
|
||||
def compute_metrics(p: EvalPrediction):
|
||||
preds = p.predictions[0] if isinstance(p.predictions, tuple) else p.predictions
|
||||
preds = np.squeeze(preds) if is_regression else np.argmax(preds, axis=1)
|
||||
result = metric.compute(predictions=preds, references=p.label_ids)
|
||||
result['default'] = result.get('f1', result.get('accuracy', 0.))
|
||||
return result
|
||||
|
||||
tokenizer = BertTokenizerFast.from_pretrained('bert-base-uncased')
|
||||
train_dataset, validation_datasets = prepare_datasets(task_name, tokenizer, None)
|
||||
merged_validation_dataset = ConcatDataset([d for d in validation_datasets.values()])
|
||||
data_collator = DataCollatorWithPadding(tokenizer)
|
||||
training_args = TrainingArguments(output_dir='./output/trainer',
|
||||
do_train=True,
|
||||
do_eval=True,
|
||||
evaluation_strategy='steps',
|
||||
per_device_train_batch_size=32,
|
||||
per_device_eval_batch_size=32,
|
||||
num_train_epochs=3,
|
||||
dataloader_num_workers=12,
|
||||
learning_rate=3e-5,
|
||||
save_strategy='steps',
|
||||
save_total_limit=1,
|
||||
metric_for_best_model='default',
|
||||
load_best_model_at_end=load_best_model_at_end,
|
||||
disable_tqdm=True,
|
||||
optim='adamw_torch',
|
||||
seed=1024)
|
||||
trainer = nni.trace(Trainer)(model=model,
|
||||
args=training_args,
|
||||
data_collator=data_collator,
|
||||
train_dataset=train_dataset,
|
||||
eval_dataset=merged_validation_dataset,
|
||||
tokenizer=tokenizer,
|
||||
compute_metrics=compute_metrics,)
|
||||
return trainer
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
.. GENERATED FROM PYTHON SOURCE LINES 178-180
|
||||
|
||||
If the finetuned model is existed, directly load it.
|
||||
If the finetuned model is not existed, train the pretrained model with the trainer.
|
||||
|
||||
.. GENERATED FROM PYTHON SOURCE LINES 180-198
|
||||
|
||||
.. code-block:: default
|
||||
|
||||
|
||||
|
||||
def build_finetuning_model(task_name: str, state_dict_path: str):
|
||||
model = build_model('bert-base-uncased', task_name)
|
||||
if Path(state_dict_path).exists():
|
||||
model.load_state_dict(torch.load(state_dict_path))
|
||||
else:
|
||||
trainer = prepare_traced_trainer(model, task_name, True)
|
||||
trainer.train()
|
||||
torch.save(model.state_dict(), state_dict_path)
|
||||
return model
|
||||
|
||||
|
||||
if not skip_exec:
|
||||
Path('./output/bert_finetuned').mkdir(exist_ok=True, parents=True)
|
||||
build_finetuning_model(task_name, f'./output/bert_finetuned/{task_name}.bin')
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
.. GENERATED FROM PYTHON SOURCE LINES 199-200
|
||||
|
||||
The following code creates distillers for distillation.
|
||||
|
||||
.. GENERATED FROM PYTHON SOURCE LINES 200-205
|
||||
|
||||
.. code-block:: default
|
||||
|
||||
|
||||
|
||||
from nni.contrib.compression.distillation import DynamicLayerwiseDistiller, Adaptive1dLayerwiseDistiller
|
||||
from nni.contrib.compression.utils import TransformersEvaluator
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
.. GENERATED FROM PYTHON SOURCE LINES 206-210
|
||||
|
||||
Dynamic distillation is suitable for the situation where the distillation states dimension of the student and the teacher match.
|
||||
A student state can try to distill on multiple teacher states, and finally select the teacher state with the smallest distillation loss as the target for distillation.
|
||||
|
||||
In this tutorial, dynamic distillation is applied before speedup the embedding pruning.
|
||||
|
||||
.. GENERATED FROM PYTHON SOURCE LINES 210-250
|
||||
|
||||
.. code-block:: default
|
||||
|
||||
|
||||
def dynamic_distiller(student_model: BertForSequenceClassification, teacher_model: BertForSequenceClassification,
|
||||
student_trainer: Trainer):
|
||||
layer_num = len(student_model.bert.encoder.layer)
|
||||
config_list = [{
|
||||
'op_names': [f'bert.encoder.layer.{i}'],
|
||||
'link': [f'bert.encoder.layer.{j}' for j in range(i, layer_num)],
|
||||
'lambda': 0.9,
|
||||
'apply_method': 'mse',
|
||||
} for i in range(layer_num)]
|
||||
config_list.append({
|
||||
'op_names': ['classifier'],
|
||||
'link': ['classifier'],
|
||||
'lambda': 0.9,
|
||||
'apply_method': 'kl',
|
||||
})
|
||||
|
||||
evaluator = TransformersEvaluator(student_trainer)
|
||||
|
||||
def teacher_predict(batch, teacher_model):
|
||||
return teacher_model(**batch)
|
||||
|
||||
return DynamicLayerwiseDistiller(student_model, config_list, evaluator, teacher_model, teacher_predict, origin_loss_lambda=0.1)
|
||||
|
||||
|
||||
def dynamic_distillation(student_model: BertForSequenceClassification, teacher_model: BertForSequenceClassification,
|
||||
max_steps: int | None, max_epochs: int | None):
|
||||
student_trainer = prepare_traced_trainer(student_model, task_name, True)
|
||||
|
||||
ori_teacher_device = teacher_model.device
|
||||
training = teacher_model.training
|
||||
teacher_model.to(student_trainer.args.device).eval()
|
||||
|
||||
distiller = dynamic_distiller(student_model, teacher_model, student_trainer)
|
||||
distiller.compress(max_steps, max_epochs)
|
||||
distiller.unwrap_model()
|
||||
|
||||
teacher_model.to(ori_teacher_device).train(training)
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
.. GENERATED FROM PYTHON SOURCE LINES 251-257
|
||||
|
||||
Adapt distillation is applied after pruning embedding layers.
|
||||
The hidden states dimension will mismatch between student model and teacher model after pruning embedding layers,
|
||||
then adapt distiller will add a linear layer for each distillation module pair to align dimension.
|
||||
For example, pruning hidden dimension from 768 to 384, then for each student transformer block,
|
||||
will add a ``Linear(in_features=384, out_features=768)`` for shifting dimention 384 to 768,
|
||||
aligned with the teacher model transformer block output.
|
||||
|
||||
.. GENERATED FROM PYTHON SOURCE LINES 257-301
|
||||
|
||||
.. code-block:: default
|
||||
|
||||
|
||||
|
||||
def adapt_distiller(student_model: BertForSequenceClassification, teacher_model: BertForSequenceClassification,
|
||||
student_trainer: Trainer):
|
||||
layer_num = len(student_model.bert.encoder.layer)
|
||||
config_list = [{
|
||||
'op_names': [f'bert.encoder.layer.{i}'],
|
||||
'lambda': 0.9,
|
||||
'apply_method': 'mse',
|
||||
} for i in range(layer_num)]
|
||||
config_list.append({
|
||||
'op_names': ['classifier'],
|
||||
'link': ['classifier'],
|
||||
'lambda': 0.9,
|
||||
'apply_method': 'kl',
|
||||
})
|
||||
|
||||
evaluator = TransformersEvaluator(student_trainer)
|
||||
|
||||
def teacher_predict(batch, teacher_model):
|
||||
return teacher_model(**batch)
|
||||
|
||||
return Adaptive1dLayerwiseDistiller(student_model, config_list, evaluator, teacher_model, teacher_predict, origin_loss_lambda=0.1)
|
||||
|
||||
|
||||
def adapt_distillation(student_model: BertForSequenceClassification, teacher_model: BertForSequenceClassification,
|
||||
max_steps: int | None, max_epochs: int | None):
|
||||
student_trainer = prepare_traced_trainer(student_model, task_name, True)
|
||||
|
||||
ori_teacher_device = teacher_model.device
|
||||
training = teacher_model.training
|
||||
teacher_model.to(student_trainer.args.device).eval()
|
||||
|
||||
distiller = adapt_distiller(student_model, teacher_model, student_trainer)
|
||||
dummy_input = (torch.randint(0, 10000, [8, 128]), torch.randint(0, 2, [8, 128]), torch.randint(0, 2, [8, 128]))
|
||||
dummy_input = [_.to(student_trainer.args.device) for _ in dummy_input]
|
||||
distiller.track_forward(*dummy_input)
|
||||
|
||||
distiller.compress(max_steps, max_epochs)
|
||||
distiller.unwrap_model()
|
||||
|
||||
teacher_model.to(ori_teacher_device).train(training)
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
.. GENERATED FROM PYTHON SOURCE LINES 302-313
|
||||
|
||||
Pruning Attention Layers
|
||||
^^^^^^^^^^^^^^^^^^^^^^^^
|
||||
|
||||
Here using ``MovementPruner`` to generate block sparse masks. Choosing ``64 x 64`` block is because the head width is 64,
|
||||
this is a kind of coarse grained between head pruning and finegrained pruning, also you can have a try with ``64 x 32``,
|
||||
``32 x 32`` or any other granularity here.
|
||||
|
||||
We use ``sparse_threshold`` instead of ``sparse_ratio`` here to apply adaptive sparse allocation.
|
||||
``sparse_threshold`` here is a float number between 0. and 1., but its value has little effect on the final sparse ratio.
|
||||
If you want a more sparse model, you could set a larger ``regular_scale`` in ``MovementPruner``.
|
||||
You could refer to the experiment results to choose a appropriate ``regular_scale`` you like.
|
||||
|
||||
.. GENERATED FROM PYTHON SOURCE LINES 313-347
|
||||
|
||||
.. code-block:: default
|
||||
|
||||
|
||||
|
||||
from nni.contrib.compression.pruning import MovementPruner
|
||||
from nni.compression.pytorch.speedup.v2 import ModelSpeedup
|
||||
from nni.compression.pytorch.speedup.v2.external_replacer import TransformersAttentionReplacer
|
||||
|
||||
|
||||
def pruning_attn():
|
||||
Path('./output/bert_finetuned/').mkdir(parents=True, exist_ok=True)
|
||||
model = build_finetuning_model(task_name, f'./output/bert_finetuned/{task_name}.bin')
|
||||
trainer = prepare_traced_trainer(model, task_name)
|
||||
evaluator = TransformersEvaluator(trainer)
|
||||
|
||||
config_list = [{
|
||||
'op_types': ['Linear'],
|
||||
'op_names_re': ['bert\.encoder\.layer\.[0-9]*\.attention\.*'],
|
||||
'sparse_threshold': 0.1,
|
||||
'granularity': [64, 64]
|
||||
}]
|
||||
|
||||
pruner = MovementPruner(model, config_list, evaluator, warmup_step=9000, cooldown_begin_step=36000, regular_scale=10)
|
||||
pruner.compress(None, 4)
|
||||
pruner.unwrap_model()
|
||||
|
||||
masks = pruner.get_masks()
|
||||
Path('./output/pruning/').mkdir(parents=True, exist_ok=True)
|
||||
torch.save(masks, './output/pruning/attn_masks.pth')
|
||||
torch.save(model, './output/pruning/attn_masked_model.pth')
|
||||
|
||||
|
||||
if not skip_exec:
|
||||
pruning_attn()
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
.. GENERATED FROM PYTHON SOURCE LINES 348-350
|
||||
|
||||
We apply head pruning during the speedup stage, if the head is fully masked it will be pruned,
|
||||
if the header is partially masked, it will be restored.
|
||||
|
||||
.. GENERATED FROM PYTHON SOURCE LINES 350-369
|
||||
|
||||
.. code-block:: default
|
||||
|
||||
|
||||
|
||||
def speedup_attn():
|
||||
model = torch.load('./output/pruning/attn_masked_model.pth', map_location='cpu')
|
||||
masks = torch.load('./output/pruning/attn_masks.pth', map_location='cpu')
|
||||
dummy_input = (torch.randint(0, 10000, [8, 128]), torch.randint(0, 2, [8, 128]), torch.randint(0, 2, [8, 128]))
|
||||
replacer = TransformersAttentionReplacer(model)
|
||||
ModelSpeedup(model, dummy_input, masks, customized_replacers=[replacer]).speedup_model()
|
||||
|
||||
# finetuning
|
||||
teacher_model = build_finetuning_model('mnli', f'./output/bert_finetuned/{task_name}.bin')
|
||||
dynamic_distillation(model, teacher_model, None, 3)
|
||||
torch.save(model, './output/pruning/attn_pruned_model.pth')
|
||||
|
||||
|
||||
if not skip_exec:
|
||||
speedup_attn()
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
.. GENERATED FROM PYTHON SOURCE LINES 370-379
|
||||
|
||||
Pruning Feed Forward Layers
|
||||
^^^^^^^^^^^^^^^^^^^^^^^^^^^
|
||||
|
||||
Here using ``TaylorPruner`` for pruning feed forward layers,
|
||||
and the sparse ratio related to the pruned head number in the same transformer block.
|
||||
The more heads are pruned, the higher the sparse ratio is set for feed forward layers.
|
||||
|
||||
Note that ``TaylorPruner`` has no schedule sparse ratio function,
|
||||
so we use ``AGPPruner`` to schedule the sparse ratio to achieve better pruning performance.
|
||||
|
||||
.. GENERATED FROM PYTHON SOURCE LINES 379-419
|
||||
|
||||
.. code-block:: default
|
||||
|
||||
|
||||
|
||||
from nni.contrib.compression.pruning import TaylorPruner, AGPPruner
|
||||
from transformers.models.bert.modeling_bert import BertLayer
|
||||
|
||||
|
||||
def pruning_ffn():
|
||||
model: BertForSequenceClassification = torch.load('./output/pruning/attn_pruned_model.pth')
|
||||
teacher_model: BertForSequenceClassification = build_finetuning_model('mnli', f'./output/bert_finetuned/{task_name}.bin')
|
||||
# create ffn config list, here simply use a linear function related to the number of retained heads to determine the sparse ratio
|
||||
config_list = []
|
||||
for name, module in model.named_modules():
|
||||
if isinstance(module, BertLayer):
|
||||
retained_head_num = module.attention.self.num_attention_heads
|
||||
ori_head_num = len(module.attention.pruned_heads) + retained_head_num
|
||||
ffn_sparse_ratio = 1 - retained_head_num / ori_head_num / 2
|
||||
config_list.append({'op_names': [f'{name}.intermediate.dense'], 'sparse_ratio': ffn_sparse_ratio})
|
||||
|
||||
trainer = prepare_traced_trainer(model, task_name)
|
||||
teacher_model.eval().to(trainer.args.device)
|
||||
# create a distiller for restoring the accuracy
|
||||
distiller = dynamic_distiller(model, teacher_model, trainer)
|
||||
# fusion compress: TaylorPruner + DynamicLayerwiseDistiller
|
||||
taylor_pruner = TaylorPruner.from_compressor(distiller, config_list, 1000)
|
||||
# fusion compress: AGPPruner(TaylorPruner) + DynamicLayerwiseDistiller
|
||||
agp_pruner = AGPPruner(taylor_pruner, 1000, 36)
|
||||
agp_pruner.compress(None, 3)
|
||||
agp_pruner.unwrap_model()
|
||||
distiller.unwrap_teacher_model()
|
||||
|
||||
masks = agp_pruner.get_masks()
|
||||
Path('./output/pruning/').mkdir(parents=True, exist_ok=True)
|
||||
torch.save(masks, './output/pruning/ffn_masks.pth')
|
||||
torch.save(model, './output/pruning/ffn_masked_model.pth')
|
||||
|
||||
|
||||
if not skip_exec:
|
||||
pruning_ffn()
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
.. GENERATED FROM PYTHON SOURCE LINES 420-421
|
||||
|
||||
Speedup the feed forward layers.
|
||||
|
||||
.. GENERATED FROM PYTHON SOURCE LINES 421-439
|
||||
|
||||
.. code-block:: default
|
||||
|
||||
|
||||
|
||||
def speedup_ffn():
|
||||
model = torch.load('./output/pruning/ffn_masked_model.pth', map_location='cpu')
|
||||
masks = torch.load('./output/pruning/ffn_masks.pth', map_location='cpu')
|
||||
dummy_input = (torch.randint(0, 10000, [8, 128]), torch.randint(0, 2, [8, 128]), torch.randint(0, 2, [8, 128]))
|
||||
ModelSpeedup(model, dummy_input, masks).speedup_model()
|
||||
|
||||
# finetuning
|
||||
teacher_model = build_finetuning_model('mnli', f'./output/bert_finetuned/{task_name}.bin')
|
||||
dynamic_distillation(model, teacher_model, None, 3)
|
||||
torch.save(model, './output/pruning/ffn_pruned_model.pth')
|
||||
|
||||
|
||||
if not skip_exec:
|
||||
speedup_ffn()
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
.. GENERATED FROM PYTHON SOURCE LINES 440-445
|
||||
|
||||
Pruning Embedding Layers
|
||||
^^^^^^^^^^^^^^^^^^^^^^^^
|
||||
|
||||
We want to simulate the pruning effect better, so we register the output mask setting for ``BertAttention`` and ``BertOutput``.
|
||||
The output masks can be generated and applied after register the setting template for them.
|
||||
|
||||
.. GENERATED FROM PYTHON SOURCE LINES 445-463
|
||||
|
||||
.. code-block:: default
|
||||
|
||||
|
||||
|
||||
from nni.contrib.compression.base.setting import PruningSetting
|
||||
|
||||
output_align_setting = {
|
||||
'_output_': {
|
||||
'align': {
|
||||
'module_name': None,
|
||||
'target_name': 'weight',
|
||||
'dims': [0],
|
||||
},
|
||||
'apply_method': 'mul',
|
||||
}
|
||||
}
|
||||
PruningSetting.register('BertAttention', output_align_setting)
|
||||
PruningSetting.register('BertOutput', output_align_setting)
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
.. GENERATED FROM PYTHON SOURCE LINES 464-467
|
||||
|
||||
Similar to prune feed forward layers, we also use ``AGPPruner + TaylorPruner + DynamicLayerwiseDistiller`` here.
|
||||
For the better pruning effect simulation, set output ``align`` mask generation in ``config_list``,
|
||||
then the relevant layers will generate its own output masks according to the embedding masks.
|
||||
|
||||
.. GENERATED FROM PYTHON SOURCE LINES 467-528
|
||||
|
||||
.. code-block:: default
|
||||
|
||||
|
||||
|
||||
def pruning_embedding():
|
||||
model: BertForSequenceClassification = torch.load('./output/pruning/ffn_pruned_model.pth')
|
||||
teacher_model: BertForSequenceClassification = build_finetuning_model('mnli', f'./output/bert_finetuned/{task_name}.bin')
|
||||
|
||||
sparse_ratio = 0.5
|
||||
config_list = [{
|
||||
'op_types': ['Embedding'],
|
||||
'op_names_re': ['bert\.embeddings.*'],
|
||||
'sparse_ratio': sparse_ratio,
|
||||
'dependency_group_id': 1,
|
||||
'granularity': [-1, 1],
|
||||
}, {
|
||||
'op_names_re': ['bert\.encoder\.layer\.[0-9]*\.attention$',
|
||||
'bert\.encoder\.layer\.[0-9]*\.output$'],
|
||||
'target_names': ['_output_'],
|
||||
'target_settings': {
|
||||
'_output_': {
|
||||
'align': {
|
||||
'module_name': 'bert.embeddings.word_embeddings',
|
||||
'target_name': 'weight',
|
||||
'dims': [1],
|
||||
}
|
||||
}
|
||||
}
|
||||
}, {
|
||||
'op_names_re': ['bert\.encoder\.layer\.[0-9]*\.attention.output.dense',
|
||||
'bert\.encoder\.layer\.[0-9]*\.output.dense'],
|
||||
'target_names': ['weight'],
|
||||
'target_settings': {
|
||||
'weight': {
|
||||
'granularity': 'out_channel',
|
||||
'align': {
|
||||
'module_name': 'bert.embeddings.word_embeddings',
|
||||
'target_name': 'weight',
|
||||
'dims': [1],
|
||||
}
|
||||
}
|
||||
}
|
||||
}]
|
||||
|
||||
trainer = prepare_traced_trainer(model, task_name)
|
||||
teacher_model.eval().to(trainer.args.device)
|
||||
distiller = dynamic_distiller(model, teacher_model, trainer)
|
||||
taylor_pruner = TaylorPruner.from_compressor(distiller, config_list, 1000)
|
||||
agp_pruner = AGPPruner(taylor_pruner, 1000, 36)
|
||||
agp_pruner.compress(None, 3)
|
||||
agp_pruner.unwrap_model()
|
||||
distiller.unwrap_teacher_model()
|
||||
|
||||
masks = agp_pruner.get_masks()
|
||||
Path('./output/pruning/').mkdir(parents=True, exist_ok=True)
|
||||
torch.save(masks, './output/pruning/embedding_masks.pth')
|
||||
torch.save(model, './output/pruning/embedding_masked_model.pth')
|
||||
|
||||
|
||||
if not skip_exec:
|
||||
pruning_embedding()
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
.. GENERATED FROM PYTHON SOURCE LINES 529-530
|
||||
|
||||
Speedup the embedding layers.
|
||||
|
||||
.. GENERATED FROM PYTHON SOURCE LINES 530-548
|
||||
|
||||
.. code-block:: default
|
||||
|
||||
|
||||
|
||||
def speedup_embedding():
|
||||
model = torch.load('./output/pruning/embedding_masked_model.pth', map_location='cpu')
|
||||
masks = torch.load('./output/pruning/embedding_masks.pth', map_location='cpu')
|
||||
dummy_input = (torch.randint(0, 10000, [8, 128]), torch.randint(0, 2, [8, 128]), torch.randint(0, 2, [8, 128]))
|
||||
ModelSpeedup(model, dummy_input, masks).speedup_model()
|
||||
|
||||
# finetuning
|
||||
teacher_model = build_finetuning_model('mnli', f'./output/bert_finetuned/{task_name}.bin')
|
||||
adapt_distillation(model, teacher_model, None, 4)
|
||||
torch.save(model, './output/pruning/embedding_pruned_model.pth')
|
||||
|
||||
|
||||
if not skip_exec:
|
||||
speedup_embedding()
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
.. GENERATED FROM PYTHON SOURCE LINES 549-553
|
||||
|
||||
Evaluation
|
||||
^^^^^^^^^^
|
||||
|
||||
Evaluate the pruned model size and accuracy.
|
||||
|
||||
.. GENERATED FROM PYTHON SOURCE LINES 553-570
|
||||
|
||||
.. code-block:: default
|
||||
|
||||
|
||||
|
||||
def evaluate_pruned_model():
|
||||
model: BertForSequenceClassification = torch.load('./output/pruning/embedding_pruned_model.pth')
|
||||
trainer = prepare_traced_trainer(model, task_name)
|
||||
metric = trainer.evaluate()
|
||||
pruned_num_params = sum(param.numel() for param in model.parameters()) + sum(buffer.numel() for buffer in model.buffers())
|
||||
|
||||
model = build_finetuning_model(task_name, f'./output/bert_finetuned/{task_name}.bin')
|
||||
ori_num_params = sum(param.numel() for param in model.parameters()) + sum(buffer.numel() for buffer in model.buffers())
|
||||
print(f'Metric: {metric}\nSparsity: {1 - pruned_num_params / ori_num_params}')
|
||||
|
||||
|
||||
if not skip_exec:
|
||||
evaluate_pruned_model()
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
.. GENERATED FROM PYTHON SOURCE LINES 571-618
|
||||
|
||||
Results
|
||||
-------
|
||||
|
||||
.. list-table:: Prune Bert-base-uncased on MNLI
|
||||
:header-rows: 1
|
||||
:widths: auto
|
||||
|
||||
* - Total Sparsity
|
||||
- Embedding Sparsity
|
||||
- Encoder Sparsity
|
||||
- Pooler Sparsity
|
||||
- Acc. (m/mm avg.)
|
||||
* - 0.%
|
||||
- 0.%
|
||||
- 0.%
|
||||
- 0.%
|
||||
- 84.95%
|
||||
* - 57.76%
|
||||
- 33.33% (15.89M)
|
||||
- 64.78% (29.96M)
|
||||
- 33.33% (0.39M)
|
||||
- 84.42%
|
||||
* - 68.31% (34.70M)
|
||||
- 50.00% (11.92M)
|
||||
- 73.57% (22.48M)
|
||||
- 50.00% (0.30M)
|
||||
- 83.33%
|
||||
* - 70.95% (31.81M)
|
||||
- 33.33% (15.89M)
|
||||
- 81.75% (15.52M)
|
||||
- 33.33% (0.39M)
|
||||
- 83.79%
|
||||
* - 78.20% (23.86M)
|
||||
- 50.00% (11.92M)
|
||||
- 86.31% (11.65M)
|
||||
- 50.00% (0.30M)
|
||||
- 82.53%
|
||||
* - 81.65% (20.12M)
|
||||
- 50.00% (11.92M)
|
||||
- 90.71% (7.90M)
|
||||
- 50.00% (0.30M)
|
||||
- 82.08%
|
||||
* - 84.32% (17.17M)
|
||||
- 50.00% (11.92M)
|
||||
- 94.18% (4.95M)
|
||||
- 50.00% (0.30M)
|
||||
- 81.35%
|
||||
|
||||
|
||||
.. rst-class:: sphx-glr-timing
|
||||
|
||||
**Total running time of the script:** ( 0 minutes 1.990 seconds)
|
||||
|
||||
|
||||
.. _sphx_glr_download_tutorials_new_pruning_bert_glue.py:
|
||||
|
||||
.. only:: html
|
||||
|
||||
.. container:: sphx-glr-footer sphx-glr-footer-example
|
||||
|
||||
|
||||
.. container:: sphx-glr-download sphx-glr-download-python
|
||||
|
||||
:download:`Download Python source code: new_pruning_bert_glue.py <new_pruning_bert_glue.py>`
|
||||
|
||||
.. container:: sphx-glr-download sphx-glr-download-jupyter
|
||||
|
||||
:download:`Download Jupyter notebook: new_pruning_bert_glue.ipynb <new_pruning_bert_glue.ipynb>`
|
||||
|
||||
|
||||
.. only:: html
|
||||
|
||||
.. rst-class:: sphx-glr-signature
|
||||
|
||||
`Gallery generated by Sphinx-Gallery <https://sphinx-gallery.github.io>`_
|
Двоичные данные
docs/source/tutorials/new_pruning_bert_glue_codeobj.pickle
сгенерированный
Normal file
Двоичные данные
docs/source/tutorials/new_pruning_bert_glue_codeobj.pickle
сгенерированный
Normal file
Двоичный файл не отображается.
|
@ -33,11 +33,11 @@ As TensorRT has supported post-training quantization, directly leveraging this f
|
|||
import torch
|
||||
import torchvision
|
||||
import torchvision.transforms as transforms
|
||||
def prepare_data_loaders(data_path, batch_size, datatype='train'):
|
||||
def prepare_data_loaders(data_path, batch_size):
|
||||
normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406],
|
||||
std=[0.229, 0.224, 0.225])
|
||||
dataset = torchvision.datasets.ImageNet(
|
||||
data_path, split=datatype,
|
||||
data_path, split="train",
|
||||
transform=transforms.Compose([
|
||||
transforms.Resize(256),
|
||||
transforms.CenterCrop(224),
|
||||
|
@ -90,20 +90,26 @@ def test_accelerated_model(engine, data_loader, neval_batches):
|
|||
cnt = 0
|
||||
total_time = 0
|
||||
for image, target in data_loader:
|
||||
start_time = time.time()
|
||||
output, time_span = engine.inference(image)
|
||||
print('time: ', time_span)
|
||||
infer_time = time.time() - start_time
|
||||
print('time: ', time_span, infer_time)
|
||||
total_time += time_span
|
||||
|
||||
start_time = time.time()
|
||||
output = output.view(-1, 1000)
|
||||
cnt += 1
|
||||
acc1, acc5 = accuracy(output, target, topk=(1, 5))
|
||||
top1.update(acc1[0], image.size(0))
|
||||
top5.update(acc5[0], image.size(0))
|
||||
rest_time = time.time() - start_time
|
||||
print('rest time: ', rest_time)
|
||||
if cnt >= neval_batches:
|
||||
break
|
||||
print('inference time: ', total_time / neval_batches)
|
||||
return top1, top5
|
||||
|
||||
data_loader = prepare_data_loaders(data_path, batch_size=64, datatype='test')
|
||||
data_loader = prepare_data_loaders(data_path, batch_size=64)
|
||||
top1, top5 = test_accelerated_model(engine, data_loader, neval_batches=32)
|
||||
print('Accuracy of mode #1: ', top1, top5)
|
||||
|
||||
|
@ -167,6 +173,6 @@ model.eval()
|
|||
|
||||
engine = ModelSpeedupTensorRT(model, input_shape=(64, 3, 224, 224), config=calibration_config)
|
||||
engine.compress()
|
||||
data_loader = prepare_data_loaders(data_path, batch_size=64, datatype='test')
|
||||
data_loader = prepare_data_loaders(data_path, batch_size=64)
|
||||
top1, top5 = test_accelerated_model(engine, data_loader, neval_batches=32)
|
||||
print('Accuracy of mode #2: ', top1, top5)
|
||||
|
|
|
@ -6,10 +6,10 @@
|
|||
|
||||
Computation times
|
||||
=================
|
||||
**00:29.846** total execution time for **tutorials** files:
|
||||
**00:01.990** total execution time for **tutorials** files:
|
||||
|
||||
+-----------------------------------------------------------------------------------------------------+-----------+--------+
|
||||
| :ref:`sphx_glr_tutorials_pruning_bert_glue.py` (``pruning_bert_glue.py``) | 00:29.846 | 0.0 MB |
|
||||
| :ref:`sphx_glr_tutorials_new_pruning_bert_glue.py` (``new_pruning_bert_glue.py``) | 00:01.990 | 0.0 MB |
|
||||
+-----------------------------------------------------------------------------------------------------+-----------+--------+
|
||||
| :ref:`sphx_glr_tutorials_darts.py` (``darts.py``) | 00:00.000 | 0.0 MB |
|
||||
+-----------------------------------------------------------------------------------------------------+-----------+--------+
|
||||
|
|
|
@ -4,3 +4,4 @@ log/
|
|||
lightning_logs
|
||||
models/
|
||||
pruning_log/
|
||||
output*/
|
|
@ -0,0 +1,617 @@
|
|||
"""
|
||||
Pruning Bert on Task MNLI
|
||||
=========================
|
||||
|
||||
This is a new tutorial on pruning transformer in nni v3.0 (`old tutorial <https://nni.readthedocs.io/en/v2.9/tutorials/pruning_bert_glue.html>`__).
|
||||
The main difference between this tutorial and the previous is that it integrates the feature of fusion compression (pruning + distillation) in nni,
|
||||
uses a new more powerful and stable pruning speedup tool,
|
||||
and additionally prunes the whole model hidden dimensions which greatly reduces the model size (pruning embedding layers).
|
||||
|
||||
At the same time, the huggingface `transformers.Trainer <https://huggingface.co/docs/transformers/main_classes/trainer>`__ is used in this tutorial
|
||||
to reduce the burden of user writing training and evaluation logic.
|
||||
|
||||
Workable Pruning Process
|
||||
------------------------
|
||||
|
||||
The whole pruning process is divided into three steps:
|
||||
|
||||
1. pruning attention layers,
|
||||
2. pruning feed forward layers,
|
||||
3. pruning embedding layers.
|
||||
|
||||
In each step, the pruner is first used for simulated pruning to generate masks corresponding to the module pruning targets (weight, input, output).
|
||||
After that comes the speedup stage, sparsity propagation is used to explore the global redundancy due to the local masks,
|
||||
then modify the original model into a smaller one by replacing the sub module in the model.
|
||||
|
||||
The compression of the model naturally applies the distillation method,
|
||||
so in this tutorial, distillers will also be used to help restore the model accuracy.
|
||||
|
||||
Experiment
|
||||
----------
|
||||
|
||||
Preparations
|
||||
^^^^^^^^^^^^
|
||||
|
||||
The preparations mainly includes preparing the transformers trainer and model.
|
||||
|
||||
This is generally consistent with the preparations required to normally train a Bert model.
|
||||
The only difference is that the ``transformers.Trainer`` is needed to wrap by ``nni.trace`` to trace the init arguments,
|
||||
this is because nni need re-create trainer during training aware pruning and distilling.
|
||||
|
||||
.. note::
|
||||
|
||||
Please set ``skip_exec`` to ``False`` to run this tutorial. Here ``skip_exec`` is ``True`` by default is for generating documents.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from pathlib import Path
|
||||
|
||||
import numpy as np
|
||||
|
||||
import torch
|
||||
from torch.utils.data import ConcatDataset
|
||||
|
||||
import nni
|
||||
|
||||
from datasets import load_dataset, load_metric
|
||||
from transformers import BertTokenizerFast, DataCollatorWithPadding, BertForSequenceClassification, EvalPrediction
|
||||
from transformers.trainer import Trainer
|
||||
from transformers.training_args import TrainingArguments
|
||||
|
||||
skip_exec = True
|
||||
|
||||
# %%
|
||||
# Set the downstream task name here, you could replace the task with the task in GLUE.
|
||||
|
||||
task_name = 'mnli'
|
||||
|
||||
# %%
|
||||
# Here using BertForSequenceClassification as the base model for show case.
|
||||
# If you want to prune other kind of transformer model, you could replace the base model here.
|
||||
|
||||
def build_model(pretrained_model_name_or_path: str, task_name: str):
|
||||
is_regression = task_name == 'stsb'
|
||||
num_labels = 1 if is_regression else (3 if task_name == 'mnli' else 2)
|
||||
model = BertForSequenceClassification.from_pretrained(pretrained_model_name_or_path, num_labels=num_labels)
|
||||
return model
|
||||
|
||||
|
||||
# %%
|
||||
# Prepare the GLUE train & validation datasets, if the task has multi validation datasets, concat the datasets by ``ConcatDataset``.
|
||||
|
||||
def prepare_datasets(task_name: str, tokenizer: BertTokenizerFast, cache_dir: str):
|
||||
task_to_keys = {
|
||||
'cola': ('sentence', None),
|
||||
'mnli': ('premise', 'hypothesis'),
|
||||
'mrpc': ('sentence1', 'sentence2'),
|
||||
'qnli': ('question', 'sentence'),
|
||||
'qqp': ('question1', 'question2'),
|
||||
'rte': ('sentence1', 'sentence2'),
|
||||
'sst2': ('sentence', None),
|
||||
'stsb': ('sentence1', 'sentence2'),
|
||||
'wnli': ('sentence1', 'sentence2'),
|
||||
}
|
||||
sentence1_key, sentence2_key = task_to_keys[task_name]
|
||||
|
||||
# used to preprocess the raw data
|
||||
def preprocess_function(examples):
|
||||
# Tokenize the texts
|
||||
args = (
|
||||
(examples[sentence1_key],) if sentence2_key is None else (examples[sentence1_key], examples[sentence2_key])
|
||||
)
|
||||
result = tokenizer(*args, padding=False, max_length=128, truncation=True)
|
||||
|
||||
if 'label' in examples:
|
||||
# In all cases, rename the column to labels because the model will expect that.
|
||||
result['labels'] = examples['label']
|
||||
return result
|
||||
|
||||
raw_datasets = load_dataset('glue', task_name, cache_dir=cache_dir)
|
||||
for key in list(raw_datasets.keys()):
|
||||
if 'test' in key:
|
||||
raw_datasets.pop(key)
|
||||
|
||||
processed_datasets = raw_datasets.map(preprocess_function, batched=True,
|
||||
remove_columns=raw_datasets['train'].column_names)
|
||||
|
||||
train_dataset = processed_datasets['train']
|
||||
if task_name == 'mnli':
|
||||
validation_datasets = {
|
||||
'validation_matched': processed_datasets['validation_matched'],
|
||||
'validation_mismatched': processed_datasets['validation_mismatched']
|
||||
}
|
||||
else:
|
||||
validation_datasets = {
|
||||
'validation': processed_datasets['validation']
|
||||
}
|
||||
|
||||
return train_dataset, validation_datasets
|
||||
|
||||
|
||||
# %%
|
||||
# Prepare the trainer, note that the ``Trainer`` class is wrapped by ``nni.trace``.
|
||||
|
||||
|
||||
def prepare_traced_trainer(model, task_name, load_best_model_at_end=False):
|
||||
is_regression = task_name == 'stsb'
|
||||
metric = load_metric('glue', task_name)
|
||||
|
||||
def compute_metrics(p: EvalPrediction):
|
||||
preds = p.predictions[0] if isinstance(p.predictions, tuple) else p.predictions
|
||||
preds = np.squeeze(preds) if is_regression else np.argmax(preds, axis=1)
|
||||
result = metric.compute(predictions=preds, references=p.label_ids)
|
||||
result['default'] = result.get('f1', result.get('accuracy', 0.))
|
||||
return result
|
||||
|
||||
tokenizer = BertTokenizerFast.from_pretrained('bert-base-uncased')
|
||||
train_dataset, validation_datasets = prepare_datasets(task_name, tokenizer, None)
|
||||
merged_validation_dataset = ConcatDataset([d for d in validation_datasets.values()])
|
||||
data_collator = DataCollatorWithPadding(tokenizer)
|
||||
training_args = TrainingArguments(output_dir='./output/trainer',
|
||||
do_train=True,
|
||||
do_eval=True,
|
||||
evaluation_strategy='steps',
|
||||
per_device_train_batch_size=32,
|
||||
per_device_eval_batch_size=32,
|
||||
num_train_epochs=3,
|
||||
dataloader_num_workers=12,
|
||||
learning_rate=3e-5,
|
||||
save_strategy='steps',
|
||||
save_total_limit=1,
|
||||
metric_for_best_model='default',
|
||||
load_best_model_at_end=load_best_model_at_end,
|
||||
disable_tqdm=True,
|
||||
optim='adamw_torch',
|
||||
seed=1024)
|
||||
trainer = nni.trace(Trainer)(model=model,
|
||||
args=training_args,
|
||||
data_collator=data_collator,
|
||||
train_dataset=train_dataset,
|
||||
eval_dataset=merged_validation_dataset,
|
||||
tokenizer=tokenizer,
|
||||
compute_metrics=compute_metrics,)
|
||||
return trainer
|
||||
|
||||
|
||||
# %%
|
||||
# If the finetuned model is existed, directly load it.
|
||||
# If the finetuned model is not existed, train the pretrained model with the trainer.
|
||||
|
||||
|
||||
def build_finetuning_model(task_name: str, state_dict_path: str):
|
||||
model = build_model('bert-base-uncased', task_name)
|
||||
if Path(state_dict_path).exists():
|
||||
model.load_state_dict(torch.load(state_dict_path))
|
||||
else:
|
||||
trainer = prepare_traced_trainer(model, task_name, True)
|
||||
trainer.train()
|
||||
torch.save(model.state_dict(), state_dict_path)
|
||||
return model
|
||||
|
||||
|
||||
if not skip_exec:
|
||||
Path('./output/bert_finetuned').mkdir(exist_ok=True, parents=True)
|
||||
build_finetuning_model(task_name, f'./output/bert_finetuned/{task_name}.bin')
|
||||
|
||||
|
||||
# %%
|
||||
# The following code creates distillers for distillation.
|
||||
|
||||
|
||||
from nni.contrib.compression.distillation import DynamicLayerwiseDistiller, Adaptive1dLayerwiseDistiller
|
||||
from nni.contrib.compression.utils import TransformersEvaluator
|
||||
|
||||
# %%
|
||||
# Dynamic distillation is suitable for the situation where the distillation states dimension of the student and the teacher match.
|
||||
# A student state can try to distill on multiple teacher states, and finally select the teacher state with the smallest distillation loss as the target for distillation.
|
||||
#
|
||||
# In this tutorial, dynamic distillation is applied before speedup the embedding pruning.
|
||||
|
||||
def dynamic_distiller(student_model: BertForSequenceClassification, teacher_model: BertForSequenceClassification,
|
||||
student_trainer: Trainer):
|
||||
layer_num = len(student_model.bert.encoder.layer)
|
||||
config_list = [{
|
||||
'op_names': [f'bert.encoder.layer.{i}'],
|
||||
'link': [f'bert.encoder.layer.{j}' for j in range(i, layer_num)],
|
||||
'lambda': 0.9,
|
||||
'apply_method': 'mse',
|
||||
} for i in range(layer_num)]
|
||||
config_list.append({
|
||||
'op_names': ['classifier'],
|
||||
'link': ['classifier'],
|
||||
'lambda': 0.9,
|
||||
'apply_method': 'kl',
|
||||
})
|
||||
|
||||
evaluator = TransformersEvaluator(student_trainer)
|
||||
|
||||
def teacher_predict(batch, teacher_model):
|
||||
return teacher_model(**batch)
|
||||
|
||||
return DynamicLayerwiseDistiller(student_model, config_list, evaluator, teacher_model, teacher_predict, origin_loss_lambda=0.1)
|
||||
|
||||
|
||||
def dynamic_distillation(student_model: BertForSequenceClassification, teacher_model: BertForSequenceClassification,
|
||||
max_steps: int | None, max_epochs: int | None):
|
||||
student_trainer = prepare_traced_trainer(student_model, task_name, True)
|
||||
|
||||
ori_teacher_device = teacher_model.device
|
||||
training = teacher_model.training
|
||||
teacher_model.to(student_trainer.args.device).eval()
|
||||
|
||||
distiller = dynamic_distiller(student_model, teacher_model, student_trainer)
|
||||
distiller.compress(max_steps, max_epochs)
|
||||
distiller.unwrap_model()
|
||||
|
||||
teacher_model.to(ori_teacher_device).train(training)
|
||||
|
||||
|
||||
# %%
|
||||
# Adapt distillation is applied after pruning embedding layers.
|
||||
# The hidden states dimension will mismatch between student model and teacher model after pruning embedding layers,
|
||||
# then adapt distiller will add a linear layer for each distillation module pair to align dimension.
|
||||
# For example, pruning hidden dimension from 768 to 384, then for each student transformer block,
|
||||
# will add a ``Linear(in_features=384, out_features=768)`` for shifting dimention 384 to 768,
|
||||
# aligned with the teacher model transformer block output.
|
||||
|
||||
|
||||
def adapt_distiller(student_model: BertForSequenceClassification, teacher_model: BertForSequenceClassification,
|
||||
student_trainer: Trainer):
|
||||
layer_num = len(student_model.bert.encoder.layer)
|
||||
config_list = [{
|
||||
'op_names': [f'bert.encoder.layer.{i}'],
|
||||
'lambda': 0.9,
|
||||
'apply_method': 'mse',
|
||||
} for i in range(layer_num)]
|
||||
config_list.append({
|
||||
'op_names': ['classifier'],
|
||||
'link': ['classifier'],
|
||||
'lambda': 0.9,
|
||||
'apply_method': 'kl',
|
||||
})
|
||||
|
||||
evaluator = TransformersEvaluator(student_trainer)
|
||||
|
||||
def teacher_predict(batch, teacher_model):
|
||||
return teacher_model(**batch)
|
||||
|
||||
return Adaptive1dLayerwiseDistiller(student_model, config_list, evaluator, teacher_model, teacher_predict, origin_loss_lambda=0.1)
|
||||
|
||||
|
||||
def adapt_distillation(student_model: BertForSequenceClassification, teacher_model: BertForSequenceClassification,
|
||||
max_steps: int | None, max_epochs: int | None):
|
||||
student_trainer = prepare_traced_trainer(student_model, task_name, True)
|
||||
|
||||
ori_teacher_device = teacher_model.device
|
||||
training = teacher_model.training
|
||||
teacher_model.to(student_trainer.args.device).eval()
|
||||
|
||||
distiller = adapt_distiller(student_model, teacher_model, student_trainer)
|
||||
dummy_input = (torch.randint(0, 10000, [8, 128]), torch.randint(0, 2, [8, 128]), torch.randint(0, 2, [8, 128]))
|
||||
dummy_input = [_.to(student_trainer.args.device) for _ in dummy_input]
|
||||
distiller.track_forward(*dummy_input)
|
||||
|
||||
distiller.compress(max_steps, max_epochs)
|
||||
distiller.unwrap_model()
|
||||
|
||||
teacher_model.to(ori_teacher_device).train(training)
|
||||
|
||||
|
||||
# %%
|
||||
# Pruning Attention Layers
|
||||
# ^^^^^^^^^^^^^^^^^^^^^^^^
|
||||
#
|
||||
# Here using ``MovementPruner`` to generate block sparse masks. Choosing ``64 x 64`` block is because the head width is 64,
|
||||
# this is a kind of coarse grained between head pruning and finegrained pruning, also you can have a try with ``64 x 32``,
|
||||
# ``32 x 32`` or any other granularity here.
|
||||
#
|
||||
# We use ``sparse_threshold`` instead of ``sparse_ratio`` here to apply adaptive sparse allocation.
|
||||
# ``sparse_threshold`` here is a float number between 0. and 1., but its value has little effect on the final sparse ratio.
|
||||
# If you want a more sparse model, you could set a larger ``regular_scale`` in ``MovementPruner``.
|
||||
# You could refer to the experiment results to choose a appropriate ``regular_scale`` you like.
|
||||
|
||||
|
||||
from nni.contrib.compression.pruning import MovementPruner
|
||||
from nni.compression.pytorch.speedup.v2 import ModelSpeedup
|
||||
from nni.compression.pytorch.speedup.v2.external_replacer import TransformersAttentionReplacer
|
||||
|
||||
|
||||
def pruning_attn():
|
||||
Path('./output/bert_finetuned/').mkdir(parents=True, exist_ok=True)
|
||||
model = build_finetuning_model(task_name, f'./output/bert_finetuned/{task_name}.bin')
|
||||
trainer = prepare_traced_trainer(model, task_name)
|
||||
evaluator = TransformersEvaluator(trainer)
|
||||
|
||||
config_list = [{
|
||||
'op_types': ['Linear'],
|
||||
'op_names_re': ['bert\.encoder\.layer\.[0-9]*\.attention\.*'],
|
||||
'sparse_threshold': 0.1,
|
||||
'granularity': [64, 64]
|
||||
}]
|
||||
|
||||
pruner = MovementPruner(model, config_list, evaluator, warmup_step=9000, cooldown_begin_step=36000, regular_scale=10)
|
||||
pruner.compress(None, 4)
|
||||
pruner.unwrap_model()
|
||||
|
||||
masks = pruner.get_masks()
|
||||
Path('./output/pruning/').mkdir(parents=True, exist_ok=True)
|
||||
torch.save(masks, './output/pruning/attn_masks.pth')
|
||||
torch.save(model, './output/pruning/attn_masked_model.pth')
|
||||
|
||||
|
||||
if not skip_exec:
|
||||
pruning_attn()
|
||||
|
||||
|
||||
# %%
|
||||
# We apply head pruning during the speedup stage, if the head is fully masked it will be pruned,
|
||||
# if the header is partially masked, it will be restored.
|
||||
|
||||
|
||||
def speedup_attn():
|
||||
model = torch.load('./output/pruning/attn_masked_model.pth', map_location='cpu')
|
||||
masks = torch.load('./output/pruning/attn_masks.pth', map_location='cpu')
|
||||
dummy_input = (torch.randint(0, 10000, [8, 128]), torch.randint(0, 2, [8, 128]), torch.randint(0, 2, [8, 128]))
|
||||
replacer = TransformersAttentionReplacer(model)
|
||||
ModelSpeedup(model, dummy_input, masks, customized_replacers=[replacer]).speedup_model()
|
||||
|
||||
# finetuning
|
||||
teacher_model = build_finetuning_model('mnli', f'./output/bert_finetuned/{task_name}.bin')
|
||||
dynamic_distillation(model, teacher_model, None, 3)
|
||||
torch.save(model, './output/pruning/attn_pruned_model.pth')
|
||||
|
||||
|
||||
if not skip_exec:
|
||||
speedup_attn()
|
||||
|
||||
|
||||
# %%
|
||||
# Pruning Feed Forward Layers
|
||||
# ^^^^^^^^^^^^^^^^^^^^^^^^^^^
|
||||
#
|
||||
# Here using ``TaylorPruner`` for pruning feed forward layers,
|
||||
# and the sparse ratio related to the pruned head number in the same transformer block.
|
||||
# The more heads are pruned, the higher the sparse ratio is set for feed forward layers.
|
||||
#
|
||||
# Note that ``TaylorPruner`` has no schedule sparse ratio function,
|
||||
# so we use ``AGPPruner`` to schedule the sparse ratio to achieve better pruning performance.
|
||||
|
||||
|
||||
from nni.contrib.compression.pruning import TaylorPruner, AGPPruner
|
||||
from transformers.models.bert.modeling_bert import BertLayer
|
||||
|
||||
|
||||
def pruning_ffn():
|
||||
model: BertForSequenceClassification = torch.load('./output/pruning/attn_pruned_model.pth')
|
||||
teacher_model: BertForSequenceClassification = build_finetuning_model('mnli', f'./output/bert_finetuned/{task_name}.bin')
|
||||
# create ffn config list, here simply use a linear function related to the number of retained heads to determine the sparse ratio
|
||||
config_list = []
|
||||
for name, module in model.named_modules():
|
||||
if isinstance(module, BertLayer):
|
||||
retained_head_num = module.attention.self.num_attention_heads
|
||||
ori_head_num = len(module.attention.pruned_heads) + retained_head_num
|
||||
ffn_sparse_ratio = 1 - retained_head_num / ori_head_num / 2
|
||||
config_list.append({'op_names': [f'{name}.intermediate.dense'], 'sparse_ratio': ffn_sparse_ratio})
|
||||
|
||||
trainer = prepare_traced_trainer(model, task_name)
|
||||
teacher_model.eval().to(trainer.args.device)
|
||||
# create a distiller for restoring the accuracy
|
||||
distiller = dynamic_distiller(model, teacher_model, trainer)
|
||||
# fusion compress: TaylorPruner + DynamicLayerwiseDistiller
|
||||
taylor_pruner = TaylorPruner.from_compressor(distiller, config_list, 1000)
|
||||
# fusion compress: AGPPruner(TaylorPruner) + DynamicLayerwiseDistiller
|
||||
agp_pruner = AGPPruner(taylor_pruner, 1000, 36)
|
||||
agp_pruner.compress(None, 3)
|
||||
agp_pruner.unwrap_model()
|
||||
distiller.unwrap_teacher_model()
|
||||
|
||||
masks = agp_pruner.get_masks()
|
||||
Path('./output/pruning/').mkdir(parents=True, exist_ok=True)
|
||||
torch.save(masks, './output/pruning/ffn_masks.pth')
|
||||
torch.save(model, './output/pruning/ffn_masked_model.pth')
|
||||
|
||||
|
||||
if not skip_exec:
|
||||
pruning_ffn()
|
||||
|
||||
|
||||
# %%
|
||||
# Speedup the feed forward layers.
|
||||
|
||||
|
||||
def speedup_ffn():
|
||||
model = torch.load('./output/pruning/ffn_masked_model.pth', map_location='cpu')
|
||||
masks = torch.load('./output/pruning/ffn_masks.pth', map_location='cpu')
|
||||
dummy_input = (torch.randint(0, 10000, [8, 128]), torch.randint(0, 2, [8, 128]), torch.randint(0, 2, [8, 128]))
|
||||
ModelSpeedup(model, dummy_input, masks).speedup_model()
|
||||
|
||||
# finetuning
|
||||
teacher_model = build_finetuning_model('mnli', f'./output/bert_finetuned/{task_name}.bin')
|
||||
dynamic_distillation(model, teacher_model, None, 3)
|
||||
torch.save(model, './output/pruning/ffn_pruned_model.pth')
|
||||
|
||||
|
||||
if not skip_exec:
|
||||
speedup_ffn()
|
||||
|
||||
|
||||
# %%
|
||||
# Pruning Embedding Layers
|
||||
# ^^^^^^^^^^^^^^^^^^^^^^^^
|
||||
#
|
||||
# We want to simulate the pruning effect better, so we register the output mask setting for ``BertAttention`` and ``BertOutput``.
|
||||
# The output masks can be generated and applied after register the setting template for them.
|
||||
|
||||
|
||||
from nni.contrib.compression.base.setting import PruningSetting
|
||||
|
||||
output_align_setting = {
|
||||
'_output_': {
|
||||
'align': {
|
||||
'module_name': None,
|
||||
'target_name': 'weight',
|
||||
'dims': [0],
|
||||
},
|
||||
'apply_method': 'mul',
|
||||
}
|
||||
}
|
||||
PruningSetting.register('BertAttention', output_align_setting)
|
||||
PruningSetting.register('BertOutput', output_align_setting)
|
||||
|
||||
|
||||
# %%
|
||||
# Similar to prune feed forward layers, we also use ``AGPPruner + TaylorPruner + DynamicLayerwiseDistiller`` here.
|
||||
# For the better pruning effect simulation, set output ``align`` mask generation in ``config_list``,
|
||||
# then the relevant layers will generate its own output masks according to the embedding masks.
|
||||
|
||||
|
||||
def pruning_embedding():
|
||||
model: BertForSequenceClassification = torch.load('./output/pruning/ffn_pruned_model.pth')
|
||||
teacher_model: BertForSequenceClassification = build_finetuning_model('mnli', f'./output/bert_finetuned/{task_name}.bin')
|
||||
|
||||
sparse_ratio = 0.5
|
||||
config_list = [{
|
||||
'op_types': ['Embedding'],
|
||||
'op_names_re': ['bert\.embeddings.*'],
|
||||
'sparse_ratio': sparse_ratio,
|
||||
'dependency_group_id': 1,
|
||||
'granularity': [-1, 1],
|
||||
}, {
|
||||
'op_names_re': ['bert\.encoder\.layer\.[0-9]*\.attention$',
|
||||
'bert\.encoder\.layer\.[0-9]*\.output$'],
|
||||
'target_names': ['_output_'],
|
||||
'target_settings': {
|
||||
'_output_': {
|
||||
'align': {
|
||||
'module_name': 'bert.embeddings.word_embeddings',
|
||||
'target_name': 'weight',
|
||||
'dims': [1],
|
||||
}
|
||||
}
|
||||
}
|
||||
}, {
|
||||
'op_names_re': ['bert\.encoder\.layer\.[0-9]*\.attention.output.dense',
|
||||
'bert\.encoder\.layer\.[0-9]*\.output.dense'],
|
||||
'target_names': ['weight'],
|
||||
'target_settings': {
|
||||
'weight': {
|
||||
'granularity': 'out_channel',
|
||||
'align': {
|
||||
'module_name': 'bert.embeddings.word_embeddings',
|
||||
'target_name': 'weight',
|
||||
'dims': [1],
|
||||
}
|
||||
}
|
||||
}
|
||||
}]
|
||||
|
||||
trainer = prepare_traced_trainer(model, task_name)
|
||||
teacher_model.eval().to(trainer.args.device)
|
||||
distiller = dynamic_distiller(model, teacher_model, trainer)
|
||||
taylor_pruner = TaylorPruner.from_compressor(distiller, config_list, 1000)
|
||||
agp_pruner = AGPPruner(taylor_pruner, 1000, 36)
|
||||
agp_pruner.compress(None, 3)
|
||||
agp_pruner.unwrap_model()
|
||||
distiller.unwrap_teacher_model()
|
||||
|
||||
masks = agp_pruner.get_masks()
|
||||
Path('./output/pruning/').mkdir(parents=True, exist_ok=True)
|
||||
torch.save(masks, './output/pruning/embedding_masks.pth')
|
||||
torch.save(model, './output/pruning/embedding_masked_model.pth')
|
||||
|
||||
|
||||
if not skip_exec:
|
||||
pruning_embedding()
|
||||
|
||||
|
||||
# %%
|
||||
# Speedup the embedding layers.
|
||||
|
||||
|
||||
def speedup_embedding():
|
||||
model = torch.load('./output/pruning/embedding_masked_model.pth', map_location='cpu')
|
||||
masks = torch.load('./output/pruning/embedding_masks.pth', map_location='cpu')
|
||||
dummy_input = (torch.randint(0, 10000, [8, 128]), torch.randint(0, 2, [8, 128]), torch.randint(0, 2, [8, 128]))
|
||||
ModelSpeedup(model, dummy_input, masks).speedup_model()
|
||||
|
||||
# finetuning
|
||||
teacher_model = build_finetuning_model('mnli', f'./output/bert_finetuned/{task_name}.bin')
|
||||
adapt_distillation(model, teacher_model, None, 4)
|
||||
torch.save(model, './output/pruning/embedding_pruned_model.pth')
|
||||
|
||||
|
||||
if not skip_exec:
|
||||
speedup_embedding()
|
||||
|
||||
|
||||
# %%
|
||||
# Evaluation
|
||||
# ^^^^^^^^^^
|
||||
#
|
||||
# Evaluate the pruned model size and accuracy.
|
||||
|
||||
|
||||
def evaluate_pruned_model():
|
||||
model: BertForSequenceClassification = torch.load('./output/pruning/embedding_pruned_model.pth')
|
||||
trainer = prepare_traced_trainer(model, task_name)
|
||||
metric = trainer.evaluate()
|
||||
pruned_num_params = sum(param.numel() for param in model.parameters()) + sum(buffer.numel() for buffer in model.buffers())
|
||||
|
||||
model = build_finetuning_model(task_name, f'./output/bert_finetuned/{task_name}.bin')
|
||||
ori_num_params = sum(param.numel() for param in model.parameters()) + sum(buffer.numel() for buffer in model.buffers())
|
||||
print(f'Metric: {metric}\nSparsity: {1 - pruned_num_params / ori_num_params}')
|
||||
|
||||
|
||||
if not skip_exec:
|
||||
evaluate_pruned_model()
|
||||
|
||||
|
||||
# %%
|
||||
# Results
|
||||
# -------
|
||||
#
|
||||
# .. list-table:: Prune Bert-base-uncased on MNLI
|
||||
# :header-rows: 1
|
||||
# :widths: auto
|
||||
#
|
||||
# * - Total Sparsity
|
||||
# - Embedding Sparsity
|
||||
# - Encoder Sparsity
|
||||
# - Pooler Sparsity
|
||||
# - Acc. (m/mm avg.)
|
||||
# * - 0.%
|
||||
# - 0.%
|
||||
# - 0.%
|
||||
# - 0.%
|
||||
# - 84.95%
|
||||
# * - 57.76%
|
||||
# - 33.33% (15.89M)
|
||||
# - 64.78% (29.96M)
|
||||
# - 33.33% (0.39M)
|
||||
# - 84.42%
|
||||
# * - 68.31% (34.70M)
|
||||
# - 50.00% (11.92M)
|
||||
# - 73.57% (22.48M)
|
||||
# - 50.00% (0.30M)
|
||||
# - 83.33%
|
||||
# * - 70.95% (31.81M)
|
||||
# - 33.33% (15.89M)
|
||||
# - 81.75% (15.52M)
|
||||
# - 33.33% (0.39M)
|
||||
# - 83.79%
|
||||
# * - 78.20% (23.86M)
|
||||
# - 50.00% (11.92M)
|
||||
# - 86.31% (11.65M)
|
||||
# - 50.00% (0.30M)
|
||||
# - 82.53%
|
||||
# * - 81.65% (20.12M)
|
||||
# - 50.00% (11.92M)
|
||||
# - 90.71% (7.90M)
|
||||
# - 50.00% (0.30M)
|
||||
# - 82.08%
|
||||
# * - 84.32% (17.17M)
|
||||
# - 50.00% (11.92M)
|
||||
# - 94.18% (4.95M)
|
||||
# - 50.00% (0.30M)
|
||||
# - 81.35%
|
|
@ -0,0 +1,133 @@
|
|||
# Copyright (c) Microsoft Corporation.
|
||||
# Licensed under the MIT license.
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from collections import defaultdict
|
||||
import logging
|
||||
import re
|
||||
from typing import TYPE_CHECKING, List
|
||||
|
||||
import torch
|
||||
from torch.utils._pytree import tree_map
|
||||
|
||||
from nni.compression.pytorch.speedup.v2.replacer import Replacer
|
||||
from nni.compression.pytorch.utils.attr import get_nested_attr
|
||||
from nni.compression.pytorch.utils.external.huggingface import parser_factory, HuggingfaceModelParser
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from .model_speedup import ModelSpeedup
|
||||
|
||||
_logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def _endwith(s: str, suffixes: List[str]):
|
||||
return any(s.endswith(suffix) for suffix in suffixes)
|
||||
|
||||
|
||||
def _prune_head_idxs(mask: torch.Tensor, num_heads: int) -> List[int]:
|
||||
head_mask = (mask.reshape([num_heads, -1]).sum(-1) == 0.)
|
||||
return torch.arange(len(head_mask))[head_mask].long().tolist()
|
||||
|
||||
|
||||
def _remained_idxs(mask: torch.Tensor, num_heads: int) -> List[int]:
|
||||
repeats = mask.shape[0] // num_heads
|
||||
remained = (mask.reshape([num_heads, -1]).sum(-1) != 0.).repeat_interleave(repeats)
|
||||
return torch.arange(len(mask))[remained].long().tolist()
|
||||
|
||||
|
||||
def _fill_one_on_dims(mask: torch.Tensor, dims: int | List[int]) -> torch.Tensor:
|
||||
dims = dims if isinstance(dims, list) else [dims]
|
||||
dims = [d if d >= 0 else d + len(mask.shape) for d in dims]
|
||||
new_mask = torch.ones_like(mask)
|
||||
for i in range(len(mask.shape)):
|
||||
if i in dims:
|
||||
continue
|
||||
dim_mask = (mask.sum([_ for _ in range(len(mask.shape)) if _ != i]) == 0.)
|
||||
new_mask = new_mask.transpose(0, i)
|
||||
new_mask[torch.arange(len(dim_mask), device=new_mask.device)[dim_mask].long().tolist()] = 0.
|
||||
new_mask = new_mask.transpose(0, i)
|
||||
return new_mask
|
||||
|
||||
|
||||
class TransformersAttentionReplacer(Replacer):
|
||||
"""
|
||||
This replacer is used to prune huggingface transformers attention heads,
|
||||
it base on ``HuggingfaceModelParser`` to find the attention module,
|
||||
and prune heads with attention module built-in ``prune_heads`` interface.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
model
|
||||
The transformer model, now nni officially support bert, bart, t5, vit.
|
||||
parser
|
||||
The model parser used to find the attention module.
|
||||
If the model passed in is not bert, bart, t5 or vit,
|
||||
please inherit ``nni.compression.pytorch.utils.external.huggingface.HuggingfaceModelParser``
|
||||
to customize a new model parser and pass in.
|
||||
"""
|
||||
def __init__(self, model: torch.nn.Module, parser: HuggingfaceModelParser | None = None):
|
||||
self.parser = parser_factory(model) if parser is None else parser
|
||||
if self.parser is None:
|
||||
err_msg = f'Can not get the model parser of {type(model)}'
|
||||
raise RuntimeError(err_msg)
|
||||
|
||||
def replace_modules(self, speedup: 'ModelSpeedup'):
|
||||
# Note: This replace function base on prune_heads interface in Huggingface transformers.
|
||||
attention_name_dict = defaultdict(list)
|
||||
attention_patterns = [self.parser.TRANSFORMER_PREFIX + att_p for att_p in self.parser.ATTENTION]
|
||||
# find layers which has attention layer name prefix
|
||||
target2node = {}
|
||||
for node, node_info in speedup.node_infos.items():
|
||||
if node.op == 'call_module' and self.parser.is_attention(node.target):
|
||||
target2node[node.target] = node
|
||||
for attention_pattern in attention_patterns:
|
||||
attention_layer_name = re.findall(attention_pattern, node.target)[0]
|
||||
attention_name_dict[attention_layer_name].append(node.target)
|
||||
# prune heads
|
||||
for attention_layer_name, qkvo_names in attention_name_dict.items():
|
||||
# qkvo_flatten_head_mask is the sum of qkv output mask and o input mask
|
||||
qkvo_flatten_head_mask: torch.Tensor | None = None
|
||||
for name in qkvo_names:
|
||||
if _endwith(name, self.parser.QKVO):
|
||||
info_msg = f'Find QKVO layer `{name}`, try to prune head.'
|
||||
_logger.info(info_msg)
|
||||
node = target2node[name]
|
||||
node_info = speedup.node_infos[node]
|
||||
if _endwith(name, self.parser.QKV):
|
||||
out_masks = node_info.output_masks
|
||||
flatten_head_mask = \
|
||||
(torch.sum(out_masks, dim=[_ for _ in range(len(out_masks.shape) - 1)]).detach() > 0.).float()
|
||||
else:
|
||||
in_masks = tree_map(lambda n: speedup.node_infos[n].output_masks, node.args)
|
||||
flatten_head_mask = \
|
||||
(torch.sum(in_masks[0], dim=[_ for _ in range(len(in_masks[0].shape) - 1)]).detach() > 0.).float()
|
||||
if qkvo_flatten_head_mask is not None:
|
||||
qkvo_flatten_head_mask *= flatten_head_mask
|
||||
else:
|
||||
qkvo_flatten_head_mask = flatten_head_mask
|
||||
if qkvo_flatten_head_mask is not None:
|
||||
original_num_heads = self.parser.get_num_heads(attention_layer_name, speedup.bound_model)
|
||||
head_idxs = _prune_head_idxs(qkvo_flatten_head_mask, original_num_heads)
|
||||
info_msg = f'Prune {attention_layer_name} head {head_idxs}'
|
||||
_logger.info(info_msg)
|
||||
attention_layer = get_nested_attr(speedup.bound_model, attention_layer_name)
|
||||
attention_layer.prune_heads(head_idxs) # type: ignore
|
||||
# replace autoinfer masks with ones, assume QKVO are all Linear
|
||||
remained_idxs = _remained_idxs(qkvo_flatten_head_mask, original_num_heads)
|
||||
for name in qkvo_names:
|
||||
if _endwith(name, self.parser.QKVO):
|
||||
node = target2node[name]
|
||||
node_info = speedup.node_infos[node]
|
||||
if _endwith(name, self.parser.QKV):
|
||||
mask = node_info.param_masks['weight'][remained_idxs]
|
||||
node_info.param_masks['weight'] = _fill_one_on_dims(mask, 0)
|
||||
mask = node_info.output_masks.transpose(0, -1)[remained_idxs].transpose(0, -1)
|
||||
node_info.output_masks = _fill_one_on_dims(mask, -1)
|
||||
else:
|
||||
mask = node_info.param_masks['weight'][:, remained_idxs]
|
||||
node_info.param_masks['weight'] = _fill_one_on_dims(mask, 1)
|
||||
masks = tree_map(lambda n: speedup.node_infos[n].output_masks, node.args)
|
||||
mask = masks[0].transpose(0, -1)[remained_idxs].transpose(0, -1)
|
||||
for n in node.args:
|
||||
speedup.node_infos[n].output_masks = _fill_one_on_dims(mask, -1)
|
|
@ -14,6 +14,7 @@ try:
|
|||
PreTrainedModel,
|
||||
BartConfig,
|
||||
BertConfig,
|
||||
DistilBertConfig,
|
||||
T5Config,
|
||||
ViTConfig
|
||||
)
|
||||
|
@ -106,6 +107,15 @@ class HuggingfaceBertParser(HuggingfaceModelParser):
|
|||
ATTENTION = ('attention',)
|
||||
|
||||
|
||||
class HuggingfaceDistilBertParser(HuggingfaceModelParser):
|
||||
TRANSFORMER_PREFIX = r'distilbert\.transformer\.layer\.[0-9]+\.'
|
||||
QKV = ('attention.q_lin', 'attention.k_lin', 'attention.v_lin')
|
||||
QKVO = QKV + ('attention.out_lin',)
|
||||
FFN1 = ('ffn.lin1',)
|
||||
FFN2 = ('ffn.lin2',)
|
||||
ATTENTION = ('attention',)
|
||||
|
||||
|
||||
class HuggingfaceBartParser(HuggingfaceModelParser):
|
||||
TRANSFORMER_PREFIX = r'(en|de)coder\.layer\.[0-9]+\.'
|
||||
QKV = ('self_attn.q_proj', 'self_attn.k_proj', 'self_attn.v_proj', 'encoder_attn.q_proj', 'encoder_attn.k_proj', 'encoder_attn.v_proj')
|
||||
|
@ -139,12 +149,14 @@ def parser_factory(model: Module) -> HuggingfaceModelParser | None:
|
|||
cls2parser = {
|
||||
BartConfig: HuggingfaceBartParser,
|
||||
BertConfig: HuggingfaceBertParser,
|
||||
DistilBertConfig: HuggingfaceDistilBertParser,
|
||||
T5Config: HuggingfaceT5Parser,
|
||||
ViTConfig: HuggingfaceViTParser
|
||||
}
|
||||
type2parser = {
|
||||
'bart': HuggingfaceBartParser,
|
||||
'bert': HuggingfaceBertParser,
|
||||
'distilbert': HuggingfaceDistilBertParser,
|
||||
't5': HuggingfaceT5Parser,
|
||||
'vit': HuggingfaceViTParser
|
||||
}
|
||||
|
|
|
@ -478,8 +478,8 @@ def register_wrappers(model: torch.nn.Module, config_list: List[Dict[str, Any]],
|
|||
return module_wrappers, configured_target_spaces
|
||||
|
||||
|
||||
def create_module_wrapper(model:nn.Module, module: nn.Module, module_name: str, mode: Literal['pruning', 'quantization', 'distillation'], \
|
||||
config: Dict[str, Any], wrapper: ModuleWrapper | None = None, fused_modules_pair: List[str] | None = None):
|
||||
def create_module_wrapper(model: nn.Module, module: nn.Module, module_name: str, mode: Literal['pruning', 'quantization', 'distillation'],
|
||||
config: Dict[str, Any], wrapper: ModuleWrapper | None = None, fused_modules_pair: List[str] | None = None):
|
||||
fused_modules_pair = fused_modules_pair if fused_modules_pair is not None else []
|
||||
if mode != 'quantization' and len(fused_modules_pair) > 0:
|
||||
raise ValueError(f"Only quantization supports model fusion, but got {mode} and {fused_modules_pair}")
|
||||
|
|
|
@ -9,7 +9,6 @@ from typing import Any, Callable, Dict, List, overload
|
|||
|
||||
import torch
|
||||
import torch.nn.functional as F
|
||||
from torch.optim import Adam
|
||||
from torch.utils._pytree import tree_map
|
||||
|
||||
from ..base.compressor import Compressor, Distiller, _DISTILLATION_TARGET_SPACES
|
||||
|
@ -214,7 +213,7 @@ class DynamicLayerwiseDistiller(TeacherModelBasedDistiller):
|
|||
loss_list.append(target_space.lambda_ * \
|
||||
F.kl_div((stu_hs / 2).log_softmax(dim=-1), (tea_hs / 2).softmax(dim=-1), reduction='batchmean') * (2 ** 2))
|
||||
if loss_list:
|
||||
distill_loss += min(loss_list)
|
||||
distill_loss = distill_loss + min(loss_list)
|
||||
for _, ts in self._target_spaces.items():
|
||||
for _, target_space in ts.items():
|
||||
target_space.clean()
|
||||
|
@ -296,23 +295,16 @@ class Adaptive1dLayerwiseDistiller(TeacherModelBasedDistiller):
|
|||
self.trans_linears[module_name][target_name] = torch.nn.Linear(stu_hs.shape[-1], tea_hs.shape[-1]).to(stu_hs.device)
|
||||
|
||||
def _register_linears_optimization(self, evaluator: Evaluator):
|
||||
linear_params = []
|
||||
for _, linears in self.trans_linears.items():
|
||||
linear_params = {}
|
||||
for module_name, linears in self.trans_linears.items():
|
||||
for _, linear in linears.items():
|
||||
if linear is not None:
|
||||
linear_params.extend(linear.parameters())
|
||||
linear_params[module_name] = list(linear.parameters())
|
||||
|
||||
if not linear_params:
|
||||
return
|
||||
|
||||
params = [{"params": linear_params}]
|
||||
optimizer = Adam(params, 1e-2)
|
||||
|
||||
def optimizer_task():
|
||||
optimizer.step()
|
||||
optimizer.zero_grad()
|
||||
|
||||
evaluator.patch_optimizer_step(before_step_tasks=[optimizer_task], after_step_tasks=[])
|
||||
evaluator.patch_optim_param_group(linear_params)
|
||||
|
||||
def compute_distill_loss(self):
|
||||
distill_loss = 0
|
||||
|
|
|
@ -187,7 +187,7 @@ class Evaluator:
|
|||
raise NotImplementedError
|
||||
|
||||
def _optimizer_add_param_group(self, model: Union[torch.nn.Module, pl.LightningModule],
|
||||
module_name_param_dict: Dict[str, List[Tensor]], optimizers: Optimizer | List[Optimizer]):
|
||||
module_name_param_dict: Dict[str, List[Tensor]], optimizers: Optimizer | List[Optimizer]):
|
||||
# used in the bind_model process
|
||||
def find_param_group(param_groups: List[Dict], module_name: str):
|
||||
for i, param_group in enumerate(param_groups):
|
||||
|
@ -217,13 +217,13 @@ class Evaluator:
|
|||
# copyed from torch.optim to check the validation of param
|
||||
if not isinstance(param, torch.Tensor):
|
||||
raise TypeError("optimizer can only optimize Tensors, "
|
||||
"but one of the params is " + torch.typename(param))
|
||||
"but one of the params is " + torch.typename(param))
|
||||
if not optimizer.defaults.get('differentiable', None) \
|
||||
and not (param.is_leaf or param.retains_grad): # type: ignore
|
||||
raise ValueError("can't optimize a non-leaf Tensor")
|
||||
target_param_group['params'].append(param)
|
||||
|
||||
assert isinstance(model, (Module, pl.LightningModule))
|
||||
assert isinstance(model, Module)
|
||||
param2name_dict = {id(p): name for name, p in model.named_parameters()}
|
||||
assert optimizers is not None, "Please provide optimizers for adding param_groups in optimizers"
|
||||
optimizers = optimizers if isinstance(optimizers, (list, tuple)) else [optimizers]
|
||||
|
@ -1024,7 +1024,6 @@ class TransformersEvaluator(Evaluator):
|
|||
def patch_optim_param_group(self, module_name_param_dict: Dict[str, List[Tensor]]):
|
||||
assert isinstance(self.model, Module)
|
||||
assert module_name_param_dict is not None
|
||||
|
||||
self._optimizer_add_param_group(self.model, module_name_param_dict, self.trainer.optimizer)
|
||||
|
||||
def unbind_model(self):
|
||||
|
|
Загрузка…
Ссылка в новой задаче