зеркало из https://github.com/microsoft/nni.git
[Compression] Transformer pruning example (#5017)
This commit is contained in:
Родитель
3eca23d519
Коммит
b2c31ca27b
|
@ -0,0 +1,8 @@
|
|||
Best Practices
|
||||
==============
|
||||
|
||||
.. toctree::
|
||||
:hidden:
|
||||
:maxdepth: 2
|
||||
|
||||
Pruning Transformer </tutorials/pruning_bert_glue>
|
|
@ -9,3 +9,4 @@ Pruning
|
|||
Quickstart </tutorials/pruning_quick_start_mnist>
|
||||
Pruner <pruner>
|
||||
Speedup </tutorials/pruning_speedup>
|
||||
Best Practices <best_practices>
|
||||
|
|
|
@ -74,3 +74,11 @@ More examples can be found in our :githublink:`GitHub repository <examples>`.
|
|||
:image: ../img/thumbnails/quantization-speed-up.svg
|
||||
:background: indigo
|
||||
:tags: Compression
|
||||
|
||||
.. cardlinkitem::
|
||||
:header: Pruning Bert on Task MNLI
|
||||
:description: An end to end example for how to using NNI pruning transformer and show the real speedup number
|
||||
:link: tutorials/pruning_bert_glue
|
||||
:image: ../img/thumbnails/pruning-tutorial.svg
|
||||
:background: indigo
|
||||
:tags: Compression
|
||||
|
|
|
@ -0,0 +1,57 @@
|
|||
|
||||
|
||||
.. _sphx_glr_tutorials_hpo_quickstart_pytorch:
|
||||
|
||||
|
||||
|
||||
|
||||
.. raw:: html
|
||||
|
||||
<div class="sphx-glr-thumbnails">
|
||||
|
||||
|
||||
.. raw:: html
|
||||
|
||||
<div class="sphx-glr-thumbcontainer" tooltip="The tutorial consists of 4 steps: ">
|
||||
|
||||
.. only:: html
|
||||
|
||||
.. image:: /tutorials/hpo_quickstart_pytorch/images/thumb/sphx_glr_main_thumb.png
|
||||
:alt: HPO Quickstart with PyTorch
|
||||
|
||||
:ref:`sphx_glr_tutorials_hpo_quickstart_pytorch_main.py`
|
||||
|
||||
.. raw:: html
|
||||
|
||||
<div class="sphx-glr-thumbnail-title">HPO Quickstart with PyTorch</div>
|
||||
</div>
|
||||
|
||||
|
||||
.. raw:: html
|
||||
|
||||
<div class="sphx-glr-thumbcontainer" tooltip="It can be run directly and will have the exact same result as original version.">
|
||||
|
||||
.. only:: html
|
||||
|
||||
.. image:: /tutorials/hpo_quickstart_pytorch/images/thumb/sphx_glr_model_thumb.png
|
||||
:alt: Port PyTorch Quickstart to NNI
|
||||
|
||||
:ref:`sphx_glr_tutorials_hpo_quickstart_pytorch_model.py`
|
||||
|
||||
.. raw:: html
|
||||
|
||||
<div class="sphx-glr-thumbnail-title">Port PyTorch Quickstart to NNI</div>
|
||||
</div>
|
||||
|
||||
|
||||
.. raw:: html
|
||||
|
||||
</div>
|
||||
|
||||
|
||||
.. toctree::
|
||||
:hidden:
|
||||
|
||||
/tutorials/hpo_quickstart_pytorch/main
|
||||
/tutorials/hpo_quickstart_pytorch/model
|
||||
|
|
@ -0,0 +1,57 @@
|
|||
|
||||
|
||||
.. _sphx_glr_tutorials_hpo_quickstart_tensorflow:
|
||||
|
||||
|
||||
|
||||
|
||||
.. raw:: html
|
||||
|
||||
<div class="sphx-glr-thumbnails">
|
||||
|
||||
|
||||
.. raw:: html
|
||||
|
||||
<div class="sphx-glr-thumbcontainer" tooltip="The tutorial consists of 4 steps: ">
|
||||
|
||||
.. only:: html
|
||||
|
||||
.. image:: /tutorials/hpo_quickstart_tensorflow/images/thumb/sphx_glr_main_thumb.png
|
||||
:alt: HPO Quickstart with TensorFlow
|
||||
|
||||
:ref:`sphx_glr_tutorials_hpo_quickstart_tensorflow_main.py`
|
||||
|
||||
.. raw:: html
|
||||
|
||||
<div class="sphx-glr-thumbnail-title">HPO Quickstart with TensorFlow</div>
|
||||
</div>
|
||||
|
||||
|
||||
.. raw:: html
|
||||
|
||||
<div class="sphx-glr-thumbcontainer" tooltip="It can be run directly and will have the exact same result as original version.">
|
||||
|
||||
.. only:: html
|
||||
|
||||
.. image:: /tutorials/hpo_quickstart_tensorflow/images/thumb/sphx_glr_model_thumb.png
|
||||
:alt: Port TensorFlow Quickstart to NNI
|
||||
|
||||
:ref:`sphx_glr_tutorials_hpo_quickstart_tensorflow_model.py`
|
||||
|
||||
.. raw:: html
|
||||
|
||||
<div class="sphx-glr-thumbnail-title">Port TensorFlow Quickstart to NNI</div>
|
||||
</div>
|
||||
|
||||
|
||||
.. raw:: html
|
||||
|
||||
</div>
|
||||
|
||||
|
||||
.. toctree::
|
||||
:hidden:
|
||||
|
||||
/tutorials/hpo_quickstart_tensorflow/main
|
||||
/tutorials/hpo_quickstart_tensorflow/model
|
||||
|
Двоичный файл не отображается.
После Ширина: | Высота: | Размер: 35 KiB |
|
@ -1,24 +1,167 @@
|
|||
:orphan:
|
||||
|
||||
|
||||
|
||||
.. _sphx_glr_tutorials:
|
||||
|
||||
Tutorials
|
||||
=========
|
||||
|
||||
|
||||
|
||||
.. raw:: html
|
||||
|
||||
<div class="sphx-glr-thumbnails">
|
||||
|
||||
|
||||
.. raw:: html
|
||||
|
||||
<div class="sphx-glr-thumbcontainer" tooltip="Introduction ------------">
|
||||
|
||||
.. only:: html
|
||||
|
||||
.. figure:: /tutorials/images/thumb/sphx_glr_pruning_speedup_thumb.png
|
||||
:alt: Speedup Model with Mask
|
||||
.. image:: /tutorials/images/thumb/sphx_glr_pruning_speedup_thumb.png
|
||||
:alt: Speedup Model with Mask
|
||||
|
||||
:ref:`sphx_glr_tutorials_pruning_speedup.py`
|
||||
|
||||
.. raw:: html
|
||||
|
||||
<div class="sphx-glr-thumbnail-title">Speedup Model with Mask</div>
|
||||
</div>
|
||||
|
||||
|
||||
.. raw:: html
|
||||
|
||||
<div class="sphx-glr-thumbcontainer" tooltip=" Introduction ------------">
|
||||
|
||||
.. only:: html
|
||||
|
||||
.. image:: /tutorials/images/thumb/sphx_glr_quantization_speedup_thumb.png
|
||||
:alt: SpeedUp Model with Calibration Config
|
||||
|
||||
:ref:`sphx_glr_tutorials_quantization_speedup.py`
|
||||
|
||||
.. raw:: html
|
||||
|
||||
<div class="sphx-glr-thumbnail-title">SpeedUp Model with Calibration Config</div>
|
||||
</div>
|
||||
|
||||
|
||||
.. raw:: html
|
||||
|
||||
<div class="sphx-glr-thumbcontainer" tooltip="Here is a four-minute video to get you started with model quantization.">
|
||||
|
||||
.. only:: html
|
||||
|
||||
.. image:: /tutorials/images/thumb/sphx_glr_quantization_quick_start_mnist_thumb.png
|
||||
:alt: Quantization Quickstart
|
||||
|
||||
:ref:`sphx_glr_tutorials_quantization_quick_start_mnist.py`
|
||||
|
||||
.. raw:: html
|
||||
|
||||
<div class="sphx-glr-thumbnail-title">Quantization Quickstart</div>
|
||||
</div>
|
||||
|
||||
|
||||
.. raw:: html
|
||||
|
||||
<div class="sphx-glr-thumbcontainer" tooltip="Here is a three-minute video to get you started with model pruning.">
|
||||
|
||||
.. only:: html
|
||||
|
||||
.. image:: /tutorials/images/thumb/sphx_glr_pruning_quick_start_mnist_thumb.png
|
||||
:alt: Pruning Quickstart
|
||||
|
||||
:ref:`sphx_glr_tutorials_pruning_quick_start_mnist.py`
|
||||
|
||||
.. raw:: html
|
||||
|
||||
<div class="sphx-glr-thumbnail-title">Pruning Quickstart</div>
|
||||
</div>
|
||||
|
||||
|
||||
.. raw:: html
|
||||
|
||||
<div class="sphx-glr-thumbcontainer" tooltip="To write a new quantization algorithm, you can write a class that inherits nni.compression.pyto...">
|
||||
|
||||
.. only:: html
|
||||
|
||||
.. image:: /tutorials/images/thumb/sphx_glr_quantization_customize_thumb.png
|
||||
:alt: Customize a new quantization algorithm
|
||||
|
||||
:ref:`sphx_glr_tutorials_quantization_customize.py`
|
||||
|
||||
.. raw:: html
|
||||
|
||||
<div class="sphx-glr-thumbnail-title">Customize a new quantization algorithm</div>
|
||||
</div>
|
||||
|
||||
|
||||
.. raw:: html
|
||||
|
||||
<div class="sphx-glr-thumbcontainer" tooltip="In this tutorial, we show how to use NAS Benchmarks as datasets. For research purposes we somet...">
|
||||
|
||||
.. only:: html
|
||||
|
||||
.. image:: /tutorials/images/thumb/sphx_glr_nasbench_as_dataset_thumb.png
|
||||
:alt: Use NAS Benchmarks as Datasets
|
||||
|
||||
:ref:`sphx_glr_tutorials_nasbench_as_dataset.py`
|
||||
|
||||
.. raw:: html
|
||||
|
||||
<div class="sphx-glr-thumbnail-title">Use NAS Benchmarks as Datasets</div>
|
||||
</div>
|
||||
|
||||
|
||||
.. raw:: html
|
||||
|
||||
<div class="sphx-glr-thumbcontainer" tooltip="Users can easily customize a basic pruner in NNI. A large number of basic modules have been pro...">
|
||||
|
||||
.. only:: html
|
||||
|
||||
.. image:: /tutorials/images/thumb/sphx_glr_pruning_customize_thumb.png
|
||||
:alt: Customize Basic Pruner
|
||||
|
||||
:ref:`sphx_glr_tutorials_pruning_customize.py`
|
||||
|
||||
.. raw:: html
|
||||
|
||||
<div class="sphx-glr-thumbnail-title">Customize Basic Pruner</div>
|
||||
</div>
|
||||
|
||||
|
||||
.. raw:: html
|
||||
|
||||
<div class="sphx-glr-thumbcontainer" tooltip="This is the 101 tutorial of Neural Architecture Search (NAS) on NNI. In this tutorial, we will ...">
|
||||
|
||||
.. only:: html
|
||||
|
||||
.. image:: /tutorials/images/thumb/sphx_glr_hello_nas_thumb.png
|
||||
:alt: Hello, NAS!
|
||||
|
||||
:ref:`sphx_glr_tutorials_hello_nas.py`
|
||||
|
||||
.. raw:: html
|
||||
|
||||
<div class="sphx-glr-thumbnail-title">Hello, NAS!</div>
|
||||
</div>
|
||||
|
||||
|
||||
.. raw:: html
|
||||
|
||||
<div class="sphx-glr-thumbcontainer" tooltip="Workable Pruning Process ------------------------">
|
||||
|
||||
.. only:: html
|
||||
|
||||
.. image:: /tutorials/images/thumb/sphx_glr_pruning_bert_glue_thumb.png
|
||||
:alt: Pruning Transformer with NNI
|
||||
|
||||
:ref:`sphx_glr_tutorials_pruning_bert_glue.py`
|
||||
|
||||
.. raw:: html
|
||||
|
||||
<div class="sphx-glr-thumbnail-title">Pruning Transformer with NNI</div>
|
||||
</div>
|
||||
|
||||
:ref:`sphx_glr_tutorials_pruning_speedup.py`
|
||||
|
||||
.. raw:: html
|
||||
|
||||
|
@ -29,162 +172,21 @@ Tutorials
|
|||
:hidden:
|
||||
|
||||
/tutorials/pruning_speedup
|
||||
|
||||
.. raw:: html
|
||||
|
||||
<div class="sphx-glr-thumbcontainer" tooltip=" Introduction ------------">
|
||||
|
||||
.. only:: html
|
||||
|
||||
.. figure:: /tutorials/images/thumb/sphx_glr_quantization_speedup_thumb.png
|
||||
:alt: SpeedUp Model with Calibration Config
|
||||
|
||||
:ref:`sphx_glr_tutorials_quantization_speedup.py`
|
||||
|
||||
.. raw:: html
|
||||
|
||||
</div>
|
||||
|
||||
|
||||
.. toctree::
|
||||
:hidden:
|
||||
|
||||
/tutorials/quantization_speedup
|
||||
|
||||
.. raw:: html
|
||||
|
||||
<div class="sphx-glr-thumbcontainer" tooltip="Here is a four-minute video to get you started with model quantization.">
|
||||
|
||||
.. only:: html
|
||||
|
||||
.. figure:: /tutorials/images/thumb/sphx_glr_quantization_quick_start_mnist_thumb.png
|
||||
:alt: Quantization Quickstart
|
||||
|
||||
:ref:`sphx_glr_tutorials_quantization_quick_start_mnist.py`
|
||||
|
||||
.. raw:: html
|
||||
|
||||
</div>
|
||||
|
||||
|
||||
.. toctree::
|
||||
:hidden:
|
||||
|
||||
/tutorials/quantization_quick_start_mnist
|
||||
|
||||
.. raw:: html
|
||||
|
||||
<div class="sphx-glr-thumbcontainer" tooltip="Here is a three-minute video to get you started with model pruning.">
|
||||
|
||||
.. only:: html
|
||||
|
||||
.. figure:: /tutorials/images/thumb/sphx_glr_pruning_quick_start_mnist_thumb.png
|
||||
:alt: Pruning Quickstart
|
||||
|
||||
:ref:`sphx_glr_tutorials_pruning_quick_start_mnist.py`
|
||||
|
||||
.. raw:: html
|
||||
|
||||
</div>
|
||||
|
||||
|
||||
.. toctree::
|
||||
:hidden:
|
||||
|
||||
/tutorials/pruning_quick_start_mnist
|
||||
|
||||
.. raw:: html
|
||||
|
||||
<div class="sphx-glr-thumbcontainer" tooltip="To write a new quantization algorithm, you can write a class that inherits nni.compression.pyto...">
|
||||
|
||||
.. only:: html
|
||||
|
||||
.. figure:: /tutorials/images/thumb/sphx_glr_quantization_customize_thumb.png
|
||||
:alt: Customize a new quantization algorithm
|
||||
|
||||
:ref:`sphx_glr_tutorials_quantization_customize.py`
|
||||
|
||||
.. raw:: html
|
||||
|
||||
</div>
|
||||
|
||||
|
||||
.. toctree::
|
||||
:hidden:
|
||||
|
||||
/tutorials/quantization_customize
|
||||
|
||||
.. raw:: html
|
||||
|
||||
<div class="sphx-glr-thumbcontainer" tooltip="In this tutorial, we show how to use NAS Benchmarks as datasets. For research purposes we somet...">
|
||||
|
||||
.. only:: html
|
||||
|
||||
.. figure:: /tutorials/images/thumb/sphx_glr_nasbench_as_dataset_thumb.png
|
||||
:alt: Use NAS Benchmarks as Datasets
|
||||
|
||||
:ref:`sphx_glr_tutorials_nasbench_as_dataset.py`
|
||||
|
||||
.. raw:: html
|
||||
|
||||
</div>
|
||||
|
||||
|
||||
.. toctree::
|
||||
:hidden:
|
||||
|
||||
/tutorials/nasbench_as_dataset
|
||||
|
||||
.. raw:: html
|
||||
|
||||
<div class="sphx-glr-thumbcontainer" tooltip="Users can easily customize a basic pruner in NNI. A large number of basic modules have been pro...">
|
||||
|
||||
.. only:: html
|
||||
|
||||
.. figure:: /tutorials/images/thumb/sphx_glr_pruning_customize_thumb.png
|
||||
:alt: Customize Basic Pruner
|
||||
|
||||
:ref:`sphx_glr_tutorials_pruning_customize.py`
|
||||
|
||||
.. raw:: html
|
||||
|
||||
</div>
|
||||
|
||||
|
||||
.. toctree::
|
||||
:hidden:
|
||||
|
||||
/tutorials/pruning_customize
|
||||
|
||||
.. raw:: html
|
||||
|
||||
<div class="sphx-glr-thumbcontainer" tooltip="This is the 101 tutorial of Neural Architecture Search (NAS) on NNI. In this tutorial, we will ...">
|
||||
|
||||
.. only:: html
|
||||
|
||||
.. figure:: /tutorials/images/thumb/sphx_glr_hello_nas_thumb.png
|
||||
:alt: Hello, NAS!
|
||||
|
||||
:ref:`sphx_glr_tutorials_hello_nas.py`
|
||||
|
||||
.. raw:: html
|
||||
|
||||
</div>
|
||||
|
||||
|
||||
.. toctree::
|
||||
:hidden:
|
||||
|
||||
/tutorials/hello_nas
|
||||
/tutorials/pruning_bert_glue
|
||||
|
||||
|
||||
|
||||
|
||||
.. raw:: html
|
||||
|
||||
<div class="sphx-glr-clear"></div>
|
||||
|
||||
|
||||
|
||||
.. _sphx_glr_tutorials_hpo_quickstart_pytorch:
|
||||
|
||||
|
||||
<div class="sphx-glr-thumbnails">
|
||||
|
||||
|
||||
.. raw:: html
|
||||
|
@ -193,50 +195,44 @@ Tutorials
|
|||
|
||||
.. only:: html
|
||||
|
||||
.. figure:: /tutorials/hpo_quickstart_pytorch/images/thumb/sphx_glr_main_thumb.png
|
||||
:alt: HPO Quickstart with PyTorch
|
||||
.. image:: /tutorials/hpo_quickstart_pytorch/images/thumb/sphx_glr_main_thumb.png
|
||||
:alt: HPO Quickstart with PyTorch
|
||||
|
||||
:ref:`sphx_glr_tutorials_hpo_quickstart_pytorch_main.py`
|
||||
:ref:`sphx_glr_tutorials_hpo_quickstart_pytorch_main.py`
|
||||
|
||||
.. raw:: html
|
||||
|
||||
<div class="sphx-glr-thumbnail-title">HPO Quickstart with PyTorch</div>
|
||||
</div>
|
||||
|
||||
|
||||
.. toctree::
|
||||
:hidden:
|
||||
|
||||
/tutorials/hpo_quickstart_pytorch/main
|
||||
|
||||
.. raw:: html
|
||||
|
||||
<div class="sphx-glr-thumbcontainer" tooltip="It can be run directly and will have the exact same result as original version.">
|
||||
|
||||
.. only:: html
|
||||
|
||||
.. figure:: /tutorials/hpo_quickstart_pytorch/images/thumb/sphx_glr_model_thumb.png
|
||||
:alt: Port PyTorch Quickstart to NNI
|
||||
.. image:: /tutorials/hpo_quickstart_pytorch/images/thumb/sphx_glr_model_thumb.png
|
||||
:alt: Port PyTorch Quickstart to NNI
|
||||
|
||||
:ref:`sphx_glr_tutorials_hpo_quickstart_pytorch_model.py`
|
||||
|
||||
.. raw:: html
|
||||
|
||||
<div class="sphx-glr-thumbnail-title">Port PyTorch Quickstart to NNI</div>
|
||||
</div>
|
||||
|
||||
:ref:`sphx_glr_tutorials_hpo_quickstart_pytorch_model.py`
|
||||
|
||||
.. raw:: html
|
||||
|
||||
</div>
|
||||
|
||||
|
||||
.. toctree::
|
||||
:hidden:
|
||||
|
||||
/tutorials/hpo_quickstart_pytorch/model
|
||||
|
||||
.. raw:: html
|
||||
|
||||
<div class="sphx-glr-clear"></div>
|
||||
|
||||
|
||||
|
||||
.. _sphx_glr_tutorials_hpo_quickstart_tensorflow:
|
||||
|
||||
|
||||
<div class="sphx-glr-thumbnails">
|
||||
|
||||
|
||||
.. raw:: html
|
||||
|
@ -245,31 +241,33 @@ Tutorials
|
|||
|
||||
.. only:: html
|
||||
|
||||
.. figure:: /tutorials/hpo_quickstart_tensorflow/images/thumb/sphx_glr_main_thumb.png
|
||||
:alt: HPO Quickstart with TensorFlow
|
||||
.. image:: /tutorials/hpo_quickstart_tensorflow/images/thumb/sphx_glr_main_thumb.png
|
||||
:alt: HPO Quickstart with TensorFlow
|
||||
|
||||
:ref:`sphx_glr_tutorials_hpo_quickstart_tensorflow_main.py`
|
||||
:ref:`sphx_glr_tutorials_hpo_quickstart_tensorflow_main.py`
|
||||
|
||||
.. raw:: html
|
||||
|
||||
<div class="sphx-glr-thumbnail-title">HPO Quickstart with TensorFlow</div>
|
||||
</div>
|
||||
|
||||
|
||||
.. toctree::
|
||||
:hidden:
|
||||
|
||||
/tutorials/hpo_quickstart_tensorflow/main
|
||||
|
||||
.. raw:: html
|
||||
|
||||
<div class="sphx-glr-thumbcontainer" tooltip="It can be run directly and will have the exact same result as original version.">
|
||||
|
||||
.. only:: html
|
||||
|
||||
.. figure:: /tutorials/hpo_quickstart_tensorflow/images/thumb/sphx_glr_model_thumb.png
|
||||
:alt: Port TensorFlow Quickstart to NNI
|
||||
.. image:: /tutorials/hpo_quickstart_tensorflow/images/thumb/sphx_glr_model_thumb.png
|
||||
:alt: Port TensorFlow Quickstart to NNI
|
||||
|
||||
:ref:`sphx_glr_tutorials_hpo_quickstart_tensorflow_model.py`
|
||||
|
||||
.. raw:: html
|
||||
|
||||
<div class="sphx-glr-thumbnail-title">Port TensorFlow Quickstart to NNI</div>
|
||||
</div>
|
||||
|
||||
:ref:`sphx_glr_tutorials_hpo_quickstart_tensorflow_model.py`
|
||||
|
||||
.. raw:: html
|
||||
|
||||
|
@ -278,11 +276,10 @@ Tutorials
|
|||
|
||||
.. toctree::
|
||||
:hidden:
|
||||
:includehidden:
|
||||
|
||||
/tutorials/hpo_quickstart_tensorflow/model
|
||||
.. raw:: html
|
||||
|
||||
<div class="sphx-glr-clear"></div>
|
||||
/tutorials/hpo_quickstart_pytorch/index.rst
|
||||
/tutorials/hpo_quickstart_tensorflow/index.rst
|
||||
|
||||
|
||||
|
||||
|
|
|
@ -0,0 +1,223 @@
|
|||
{
|
||||
"cells": [
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {
|
||||
"collapsed": false
|
||||
},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"%matplotlib inline"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"\n# Pruning Transformer with NNI\n\n## Workable Pruning Process\n\nHere we show an effective transformer pruning process that NNI team has tried, and users can use NNI to discover better processes.\n\nThe entire pruning process can be divided into the following steps:\n\n1. Finetune the pre-trained model on the downstream task. From our experience,\n the final performance of pruning on the finetuned model is better than pruning directly on the pre-trained model.\n At the same time, the finetuned model obtained in this step will also be used as the teacher model for the following\n distillation training.\n2. Pruning the attention layer at first. Here we apply block-sparse on attention layer weight,\n and directly prune the head (condense the weight) if the head was fully masked.\n If the head was partially masked, we will not prune it and recover its weight.\n3. Retrain the head-pruned model with distillation. Recover the model precision before pruning FFN layer.\n4. Pruning the FFN layer. Here we apply the output channels pruning on the 1st FFN layer,\n and the 2nd FFN layer input channels will be pruned due to the pruning of 1st layer output channels.\n5. Retrain the final pruned model with distillation.\n\nDuring the process of pruning transformer, we gained some of the following experiences:\n\n* We using `movement-pruner` in step 2 and `taylor-fo-weight-pruner` in step 4. `movement-pruner` has good performance on attention layers,\n and `taylor-fo-weight-pruner` method has good performance on FFN layers. These two pruners are all some kinds of gradient-based pruning algorithms,\n we also try weight-based pruning algorithms like `l1-norm-pruner`, but it doesn't seem to work well in this scenario.\n* Distillation is a good way to recover model precision. In terms of results, usually 1~2% improvement in accuracy can be achieved when we prune bert on mnli task.\n* It is necessary to gradually increase the sparsity rather than reaching a very high sparsity all at once.\n\n## Experiment\n\n### Preparation\nPlease set ``dev_mode`` to ``False`` to run this tutorial. Here ``dev_mode`` is ``True`` by default is for generating documents.\n\nThe complete pruning process takes about 8 hours on one A100.\n"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {
|
||||
"collapsed": false
|
||||
},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"dev_mode = True"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"Some basic setting.\n\n"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {
|
||||
"collapsed": false
|
||||
},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"from pathlib import Path\nfrom typing import Callable\n\npretrained_model_name_or_path = 'bert-base-uncased'\ntask_name = 'mnli'\nexperiment_id = 'pruning_bert'\n\n# heads_num and layers_num should align with pretrained_model_name_or_path\nheads_num = 12\nlayers_num = 12\n\n# used to save the experiment log\nlog_dir = Path(f'./pruning_log/{pretrained_model_name_or_path}/{task_name}/{experiment_id}')\nlog_dir.mkdir(parents=True, exist_ok=True)\n\n# used to save the finetuned model and share between different experiemnts with same pretrained_model_name_or_path and task_name\nmodel_dir = Path(f'./models/{pretrained_model_name_or_path}/{task_name}')\nmodel_dir.mkdir(parents=True, exist_ok=True)\n\nfrom transformers import set_seed\nset_seed(1024)\n\nimport torch\ndevice = torch.device('cuda' if torch.cuda.is_available() else 'cpu')"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"The function used to create dataloaders, note that 'mnli' has two evaluation dataset.\nIf teacher_model is set, will run all dataset on teacher model to get the 'teacher_logits' for distillation.\n\n"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {
|
||||
"collapsed": false
|
||||
},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"from torch.utils.data import DataLoader\n\nfrom datasets import load_dataset\nfrom transformers import BertTokenizerFast, DataCollatorWithPadding\n\ntask_to_keys = {\n 'cola': ('sentence', None),\n 'mnli': ('premise', 'hypothesis'),\n 'mrpc': ('sentence1', 'sentence2'),\n 'qnli': ('question', 'sentence'),\n 'qqp': ('question1', 'question2'),\n 'rte': ('sentence1', 'sentence2'),\n 'sst2': ('sentence', None),\n 'stsb': ('sentence1', 'sentence2'),\n 'wnli': ('sentence1', 'sentence2'),\n}\n\ndef prepare_data(cache_dir='./data', train_batch_size=32, eval_batch_size=32,\n teacher_model: torch.nn.Module = None):\n tokenizer = BertTokenizerFast.from_pretrained(pretrained_model_name_or_path)\n sentence1_key, sentence2_key = task_to_keys[task_name]\n data_collator = DataCollatorWithPadding(tokenizer)\n\n # used to preprocess the raw data\n def preprocess_function(examples):\n # Tokenize the texts\n args = (\n (examples[sentence1_key],) if sentence2_key is None else (examples[sentence1_key], examples[sentence2_key])\n )\n result = tokenizer(*args, padding=False, max_length=128, truncation=True)\n\n if 'label' in examples:\n # In all cases, rename the column to labels because the model will expect that.\n result['labels'] = examples['label']\n return result\n\n raw_datasets = load_dataset('glue', task_name, cache_dir=cache_dir)\n for key in list(raw_datasets.keys()):\n if 'test' in key:\n raw_datasets.pop(key)\n\n processed_datasets = raw_datasets.map(preprocess_function, batched=True,\n remove_columns=raw_datasets['train'].column_names)\n\n # if has teacher model, add 'teacher_logits' to datasets who has 'labels'.\n # 'teacher_logits' is used for distillation and avoid the double counting.\n if teacher_model:\n teacher_model_training = teacher_model.training\n teacher_model.eval()\n model_device = next(teacher_model.parameters()).device\n\n def add_teacher_logits(examples):\n result = {k: v for k, v in examples.items()}\n samples = data_collator(result).to(model_device)\n if 'labels' in samples:\n with torch.no_grad():\n logits = teacher_model(**samples).logits.tolist()\n result['teacher_logits'] = logits\n return result\n\n processed_datasets = processed_datasets.map(add_teacher_logits, batched=True,\n batch_size=train_batch_size)\n teacher_model.train(teacher_model_training)\n\n train_dataset = processed_datasets['train']\n validation_dataset = processed_datasets['validation_matched' if task_name == 'mnli' else 'validation']\n validation_dataset2 = processed_datasets['validation_mismatched'] if task_name == 'mnli' else None\n\n train_dataloader = DataLoader(train_dataset,\n shuffle=True,\n collate_fn=data_collator,\n batch_size=train_batch_size)\n validation_dataloader = DataLoader(validation_dataset,\n collate_fn=data_collator,\n batch_size=eval_batch_size)\n validation_dataloader2 = DataLoader(validation_dataset2,\n collate_fn=data_collator,\n batch_size=eval_batch_size) if task_name == 'mnli' else None\n\n return train_dataloader, validation_dataloader, validation_dataloader2"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"Training function & evaluation function.\n\n"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {
|
||||
"collapsed": false
|
||||
},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"import time\nimport torch.nn.functional as F\nfrom datasets import load_metric\n\ndef training(train_dataloader: DataLoader,\n model: torch.nn.Module,\n optimizer: torch.optim.Optimizer,\n criterion: Callable[[torch.Tensor, torch.Tensor], torch.Tensor],\n lr_scheduler: torch.optim.lr_scheduler._LRScheduler = None,\n max_steps: int = None, max_epochs: int = None,\n save_best_model: bool = False, save_path: str = None,\n log_path: str = Path(log_dir) / 'training.log',\n distillation: bool = False,\n evaluation_func=None):\n model.train()\n current_step = 0\n best_result = 0\n\n for current_epoch in range(max_epochs if max_epochs else 1):\n for batch in train_dataloader:\n batch.to(device)\n teacher_logits = batch.pop('teacher_logits', None)\n optimizer.zero_grad()\n outputs = model(**batch)\n loss = outputs.loss\n\n if distillation:\n assert teacher_logits is not None\n distil_loss = F.kl_div(F.log_softmax(outputs.logits / 2, dim=-1),\n F.softmax(teacher_logits / 2, dim=-1), reduction='batchmean') * (2 ** 2)\n loss = 0.1 * loss + 0.9 * distil_loss\n\n loss = criterion(loss, None)\n loss.backward()\n optimizer.step()\n\n if lr_scheduler:\n lr_scheduler.step()\n\n current_step += 1\n\n # evaluation for every 1000 steps\n if current_step % 1000 == 0 or current_step % len(train_dataloader) == 0:\n result = evaluation_func(model) if evaluation_func else None\n with (log_path).open('a+') as f:\n msg = '[{}] Epoch {}, Step {}: {}\\n'.format(time.asctime(time.localtime(time.time())), current_epoch, current_step, result)\n f.write(msg)\n # if it's the best model, save it.\n if save_best_model and best_result < result['default']:\n assert save_path is not None\n torch.save(model.state_dict(), save_path)\n best_result = result['default']\n\n if max_steps and current_step >= max_steps:\n return\n\ndef evaluation(validation_dataloader: DataLoader,\n validation_dataloader2: DataLoader,\n model: torch.nn.Module):\n training = model.training\n model.eval()\n is_regression = task_name == 'stsb'\n metric = load_metric('glue', task_name)\n\n for batch in validation_dataloader:\n batch.pop('teacher_logits', None)\n batch.to(device)\n outputs = model(**batch)\n predictions = outputs.logits.argmax(dim=-1) if not is_regression else outputs.logits.squeeze()\n metric.add_batch(\n predictions=predictions,\n references=batch['labels'],\n )\n result = metric.compute()\n\n if validation_dataloader2:\n for batch in validation_dataloader2:\n batch.pop('teacher_logits', None)\n batch.to(device)\n outputs = model(**batch)\n predictions = outputs.logits.argmax(dim=-1) if not is_regression else outputs.logits.squeeze()\n metric.add_batch(\n predictions=predictions,\n references=batch['labels'],\n )\n result = {'matched': result, 'mismatched': metric.compute()}\n result['default'] = (result['matched']['accuracy'] + result['mismatched']['accuracy']) / 2\n else:\n result['default'] = result.get('f1', result.get('accuracy', None))\n\n model.train(training)\n return result\n\n# using huggingface native loss\ndef fake_criterion(outputs, targets):\n return outputs"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"Prepare pre-trained model and finetuning on downstream task.\n\n"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {
|
||||
"collapsed": false
|
||||
},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"import functools\n\nfrom torch.optim import Adam\nfrom torch.optim.lr_scheduler import LambdaLR\nfrom transformers import BertForSequenceClassification\n\ndef create_pretrained_model():\n is_regression = task_name == 'stsb'\n num_labels = 1 if is_regression else (3 if task_name == 'mnli' else 2)\n return BertForSequenceClassification.from_pretrained(pretrained_model_name_or_path, num_labels=num_labels)\n\ndef create_finetuned_model():\n pretrained_model = create_pretrained_model().to(device)\n\n train_dataloader, validation_dataloader, validation_dataloader2 = prepare_data()\n evaluation_func = functools.partial(evaluation, validation_dataloader, validation_dataloader2)\n steps_per_epoch = len(train_dataloader)\n training_epochs = 3\n\n finetuned_model_state_path = Path(model_dir) / 'finetuned_model_state.pth'\n\n if finetuned_model_state_path.exists():\n pretrained_model.load_state_dict(torch.load(finetuned_model_state_path))\n elif dev_mode:\n pass\n else:\n optimizer = Adam(pretrained_model.parameters(), lr=3e-5, eps=1e-8)\n\n def lr_lambda(current_step: int):\n return max(0.0, float(training_epochs * steps_per_epoch - current_step) / float(training_epochs * steps_per_epoch))\n\n lr_scheduler = LambdaLR(optimizer, lr_lambda)\n training(train_dataloader, pretrained_model, optimizer, fake_criterion, lr_scheduler=lr_scheduler, max_epochs=training_epochs,\n save_best_model=True, save_path=finetuned_model_state_path, evaluation_func=evaluation_func)\n return pretrained_model\n\nfinetuned_model = create_finetuned_model()"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"Using finetuned model as teacher model to create dataloader.\nAdd 'teacher_logits' to dataset, it is used to do the distillation, it can be seen as a kind of data label.\n\n"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {
|
||||
"collapsed": false
|
||||
},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"if not dev_mode:\n train_dataloader, validation_dataloader, validation_dataloader2 = prepare_data(teacher_model=finetuned_model)\nelse:\n train_dataloader, validation_dataloader, validation_dataloader2 = prepare_data()\n\nevaluation_func = functools.partial(evaluation, validation_dataloader, validation_dataloader2)"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"### Pruning\nFirst, using MovementPruner to prune attention head.\n\n"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {
|
||||
"collapsed": false
|
||||
},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"steps_per_epoch = len(train_dataloader)\n\n# Set training steps/epochs for pruning.\n\nif not dev_mode:\n total_epochs = 4\n total_steps = total_epochs * steps_per_epoch\n warmup_steps = 1 * steps_per_epoch\n cooldown_steps = 1 * steps_per_epoch\nelse:\n total_epochs = 1\n total_steps = 3\n warmup_steps = 1\n cooldown_steps = 1\n\n# Initialize evaluator used by MovementPruner.\n\nimport nni\nfrom nni.algorithms.compression.v2.pytorch import TorchEvaluator\n\nmovement_training = functools.partial(training, train_dataloader, log_path=log_dir / 'movement_pruning.log',\n evaluation_func=evaluation_func)\ntraced_optimizer = nni.trace(Adam)(finetuned_model.parameters(), lr=3e-5, eps=1e-8)\n\ndef lr_lambda(current_step: int):\n if current_step < warmup_steps:\n return float(current_step) / warmup_steps\n return max(0.0, float(total_steps - current_step) / float(total_steps - warmup_steps))\n\ntraced_scheduler = nni.trace(LambdaLR)(traced_optimizer, lr_lambda)\nevaluator = TorchEvaluator(movement_training, traced_optimizer, fake_criterion, traced_scheduler)\n\n# Apply block-soft-movement pruning on attention layers.\n\nfrom nni.compression.pytorch.pruning import MovementPruner\n\nconfig_list = [{'op_types': ['Linear'], 'op_partial_names': ['bert.encoder.layer.{}.'.format(i) for i in range(layers_num)], 'sparsity': 0.1}]\npruner = MovementPruner(model=finetuned_model,\n config_list=config_list,\n evaluator=evaluator,\n training_epochs=total_epochs,\n training_steps=total_steps,\n warm_up_step=warmup_steps,\n cool_down_beginning_step=total_steps - cooldown_steps,\n regular_scale=10,\n movement_mode='soft',\n sparse_granularity='auto')\n_, attention_masks = pruner.compress()\npruner.show_pruned_weights()\n\ntorch.save(attention_masks, Path(log_dir) / 'attention_masks.pth')"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"Load a new finetuned model to do the speedup.\nNote that nni speedup don't support replace attention module, so here we manully replace the attention module.\n\nIf the head is entire masked, physically prune it and create config_list for FFN pruning.\n\n"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {
|
||||
"collapsed": false
|
||||
},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"attention_pruned_model = create_finetuned_model().to(device)\nattention_masks = torch.load(Path(log_dir) / 'attention_masks.pth')\n\nffn_config_list = []\nlayer_count = 0\nmodule_list = []\nfor i in range(0, layers_num):\n prefix = f'bert.encoder.layer.{i}.'\n value_mask: torch.Tensor = attention_masks[prefix + 'attention.self.value']['weight']\n head_mask = (value_mask.reshape(heads_num, -1).sum(-1) == 0.)\n head_idx = torch.arange(len(head_mask))[head_mask].long().tolist()\n print(f'layer {i} pruner {len(head_idx)} head: {head_idx}')\n if len(head_idx) != heads_num:\n attention_pruned_model.bert.encoder.layer[i].attention.prune_heads(head_idx)\n module_list.append(attention_pruned_model.bert.encoder.layer[i])\n # The final ffn weight remaining ratio is the half of the attention weight remaining ratio.\n # This is just an empirical configuration, you can use any other method to determine this sparsity.\n sparsity = 1 - (1 - len(head_idx) / heads_num) * 0.5\n # here we use a simple sparsity schedule, we will prune ffn in 12 iterations, each iteration prune `sparsity_per_iter`.\n sparsity_per_iter = 1 - (1 - sparsity) ** (1 / heads_num)\n ffn_config_list.append({'op_names': [f'bert.encoder.layer.{layer_count}.intermediate.dense'], 'sparsity': sparsity_per_iter})\n layer_count += 1\n\nattention_pruned_model.bert.encoder.layer = torch.nn.ModuleList(module_list)"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"Retrain the attention pruned model with distillation.\n\n"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {
|
||||
"collapsed": false
|
||||
},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"if not dev_mode:\n total_epochs = 5\n total_steps = None\n distillation = True\nelse:\n total_epochs = 1\n total_steps = 1\n distillation = False\n\noptimizer = Adam(attention_pruned_model.parameters(), lr=3e-5, eps=1e-8)\n\ndef lr_lambda(current_step: int):\n return max(0.0, float(total_epochs * steps_per_epoch - current_step) / float(total_epochs * steps_per_epoch))\n\nlr_scheduler = LambdaLR(optimizer, lr_lambda)\nat_model_save_path = log_dir / 'attention_pruned_model_state.pth'\ntraining(train_dataloader, attention_pruned_model, optimizer, fake_criterion, lr_scheduler=lr_scheduler,\n max_epochs=total_epochs, max_steps=total_steps, save_best_model=True, save_path=at_model_save_path,\n distillation=distillation, evaluation_func=evaluation_func)\n\nif not dev_mode:\n attention_pruned_model.load_state_dict(torch.load(at_model_save_path))"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"Iterative pruning FFN with TaylorFOWeightPruner in 12 iterations.\nFinetuning 2000 steps after each iteration, then finetuning 2 epochs after pruning finished.\n\nNNI will support per-step-pruning-schedule in the future, then can use an pruner to replace the following code.\n\n"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {
|
||||
"collapsed": false
|
||||
},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"if not dev_mode:\n total_epochs = 4\n total_steps = None\n taylor_pruner_steps = 1000\n steps_per_iteration = 2000\n total_pruning_steps = 24000\n distillation = True\nelse:\n total_epochs = 1\n total_steps = 6\n taylor_pruner_steps = 2\n steps_per_iteration = 2\n total_pruning_steps = 4\n distillation = False\n\nfrom nni.compression.pytorch.pruning import TaylorFOWeightPruner\nfrom nni.compression.pytorch.speedup import ModelSpeedup\n\ndistil_training = functools.partial(training, train_dataloader, log_path=log_dir / 'taylor_pruning.log',\n distillation=distillation, evaluation_func=evaluation_func)\ntraced_optimizer = nni.trace(Adam)(attention_pruned_model.parameters(), lr=3e-5, eps=1e-8)\nevaluator = TorchEvaluator(distil_training, traced_optimizer, fake_criterion)\n\ncurrent_step = 0\nbest_result = 0\ninit_lr = 3e-5\n\ndummy_input = torch.rand(8, 128, 768).to(device)\n\nattention_pruned_model.train()\nfor current_epoch in range(total_epochs):\n for batch in train_dataloader:\n if total_steps and current_step >= total_steps:\n break\n # pruning 12 times\n if current_step % steps_per_iteration == 0 and current_step < total_pruning_steps:\n check_point = attention_pruned_model.state_dict()\n pruner = TaylorFOWeightPruner(attention_pruned_model, ffn_config_list, evaluator, taylor_pruner_steps)\n _, ffn_masks = pruner.compress()\n renamed_ffn_masks = {}\n # rename the masks keys, because we only speedup the bert.encoder\n for model_name, targets_mask in ffn_masks.items():\n renamed_ffn_masks[model_name.split('bert.encoder.')[1]] = targets_mask\n pruner._unwrap_model()\n attention_pruned_model.load_state_dict(check_point)\n ModelSpeedup(attention_pruned_model.bert.encoder, dummy_input, renamed_ffn_masks).speedup_model()\n optimizer = Adam(attention_pruned_model.parameters(), lr=init_lr)\n\n batch.to(device)\n teacher_logits = batch.pop('teacher_logits', None)\n optimizer.zero_grad()\n\n # manually schedule lr\n for params_group in optimizer.param_groups:\n params_group['lr'] = (1 - current_step / (total_epochs * steps_per_epoch)) * init_lr\n\n outputs = attention_pruned_model(**batch)\n loss = outputs.loss\n\n # distillation\n if teacher_logits is not None:\n distil_loss = F.kl_div(F.log_softmax(outputs.logits / 2, dim=-1),\n F.softmax(teacher_logits / 2, dim=-1), reduction='batchmean') * (2 ** 2)\n loss = 0.1 * loss + 0.9 * distil_loss\n loss.backward()\n optimizer.step()\n\n current_step += 1\n if current_step % 1000 == 0 or current_step % len(train_dataloader) == 0:\n result = evaluation_func(attention_pruned_model)\n with (log_dir / 'ffn_pruning.log').open('a+') as f:\n msg = '[{}] Epoch {}, Step {}: {}\\n'.format(time.asctime(time.localtime(time.time())),\n current_epoch, current_step, result)\n f.write(msg)\n if current_step >= total_pruning_steps and best_result < result['default']:\n torch.save(attention_pruned_model, log_dir / 'best_model.pth')\n best_result = result['default']"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"## Result\nThe speedup is test on the entire validation dataset with batch size 32 on A100.\nWe test under two pytorch version and found the latency varying widely.\n\nSetting 1: pytorch 1.12.1\n\nSetting 2: pytorch 1.10.0\n\n.. list-table:: Prune Bert-base-uncased on MNLI\n :header-rows: 1\n :widths: auto\n\n * - Attention Pruning Method\n - FFN Pruning Method\n - Total Sparsity\n - Accuracy\n - Acc. Drop\n - Speedup (S1)\n - Speedup (S2)\n * -\n -\n - 0%\n - 84.73 / 84.63\n - +0.0 / +0.0\n - 12.56s (x1.00)\n - 4.05s (x1.00)\n * - `movement-pruner` (soft, th=0.1, lambda=5)\n - `taylor-fo-weight-pruner`\n - 51.39%\n - 84.25 / 84.96\n - -0.48 / +0.33\n - 6.85s (x1.83)\n - 2.7s (x1.50)\n * - `movement-pruner` (soft, th=0.1, lambda=10)\n - `taylor-fo-weight-pruner`\n - 66.67%\n - 83.98 / 83.75\n - -0.75 / -0.88\n - 4.73s (x2.66)\n - 2.16s (x1.86)\n * - `movement-pruner` (soft, th=0.1, lambda=20)\n - `taylor-fo-weight-pruner`\n - 77.78%\n - 83.02 / 83.06\n - -1.71 / -1.57\n - 3.35s (x3.75)\n - 1.72s (x2.35)\n * - `movement-pruner` (soft, th=0.1, lambda=30)\n - `taylor-fo-weight-pruner`\n - 87.04%\n - 81.24 / 80.99\n - -3.49 / -3.64\n - 2.19s (x5.74)\n - 1.31s (x3.09)\n\n"
|
||||
]
|
||||
}
|
||||
],
|
||||
"metadata": {
|
||||
"kernelspec": {
|
||||
"display_name": "Python 3",
|
||||
"language": "python",
|
||||
"name": "python3"
|
||||
},
|
||||
"language_info": {
|
||||
"codemirror_mode": {
|
||||
"name": "ipython",
|
||||
"version": 3
|
||||
},
|
||||
"file_extension": ".py",
|
||||
"mimetype": "text/x-python",
|
||||
"name": "python",
|
||||
"nbconvert_exporter": "python",
|
||||
"pygments_lexer": "ipython3",
|
||||
"version": "3.7.13"
|
||||
}
|
||||
},
|
||||
"nbformat": 4,
|
||||
"nbformat_minor": 0
|
||||
}
|
|
@ -0,0 +1,563 @@
|
|||
"""
|
||||
Pruning Transformer with NNI
|
||||
============================
|
||||
|
||||
Workable Pruning Process
|
||||
------------------------
|
||||
|
||||
Here we show an effective transformer pruning process that NNI team has tried, and users can use NNI to discover better processes.
|
||||
|
||||
The entire pruning process can be divided into the following steps:
|
||||
|
||||
1. Finetune the pre-trained model on the downstream task. From our experience,
|
||||
the final performance of pruning on the finetuned model is better than pruning directly on the pre-trained model.
|
||||
At the same time, the finetuned model obtained in this step will also be used as the teacher model for the following
|
||||
distillation training.
|
||||
2. Pruning the attention layer at first. Here we apply block-sparse on attention layer weight,
|
||||
and directly prune the head (condense the weight) if the head was fully masked.
|
||||
If the head was partially masked, we will not prune it and recover its weight.
|
||||
3. Retrain the head-pruned model with distillation. Recover the model precision before pruning FFN layer.
|
||||
4. Pruning the FFN layer. Here we apply the output channels pruning on the 1st FFN layer,
|
||||
and the 2nd FFN layer input channels will be pruned due to the pruning of 1st layer output channels.
|
||||
5. Retrain the final pruned model with distillation.
|
||||
|
||||
During the process of pruning transformer, we gained some of the following experiences:
|
||||
|
||||
* We using :ref:`movement-pruner` in step 2 and :ref:`taylor-fo-weight-pruner` in step 4. :ref:`movement-pruner` has good performance on attention layers,
|
||||
and :ref:`taylor-fo-weight-pruner` method has good performance on FFN layers. These two pruners are all some kinds of gradient-based pruning algorithms,
|
||||
we also try weight-based pruning algorithms like :ref:`l1-norm-pruner`, but it doesn't seem to work well in this scenario.
|
||||
* Distillation is a good way to recover model precision. In terms of results, usually 1~2% improvement in accuracy can be achieved when we prune bert on mnli task.
|
||||
* It is necessary to gradually increase the sparsity rather than reaching a very high sparsity all at once.
|
||||
|
||||
Experiment
|
||||
----------
|
||||
|
||||
Preparation
|
||||
^^^^^^^^^^^
|
||||
Please set ``dev_mode`` to ``False`` to run this tutorial. Here ``dev_mode`` is ``True`` by default is for generating documents.
|
||||
|
||||
The complete pruning process takes about 8 hours on one A100.
|
||||
"""
|
||||
|
||||
dev_mode = True
|
||||
|
||||
# %%
|
||||
# Some basic setting.
|
||||
|
||||
from pathlib import Path
|
||||
from typing import Callable
|
||||
|
||||
pretrained_model_name_or_path = 'bert-base-uncased'
|
||||
task_name = 'mnli'
|
||||
experiment_id = 'pruning_bert'
|
||||
|
||||
# heads_num and layers_num should align with pretrained_model_name_or_path
|
||||
heads_num = 12
|
||||
layers_num = 12
|
||||
|
||||
# used to save the experiment log
|
||||
log_dir = Path(f'./pruning_log/{pretrained_model_name_or_path}/{task_name}/{experiment_id}')
|
||||
log_dir.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
# used to save the finetuned model and share between different experiemnts with same pretrained_model_name_or_path and task_name
|
||||
model_dir = Path(f'./models/{pretrained_model_name_or_path}/{task_name}')
|
||||
model_dir.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
from transformers import set_seed
|
||||
set_seed(1024)
|
||||
|
||||
import torch
|
||||
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
|
||||
|
||||
# %%
|
||||
# The function used to create dataloaders, note that 'mnli' has two evaluation dataset.
|
||||
# If teacher_model is set, will run all dataset on teacher model to get the 'teacher_logits' for distillation.
|
||||
|
||||
from torch.utils.data import DataLoader
|
||||
|
||||
from datasets import load_dataset
|
||||
from transformers import BertTokenizerFast, DataCollatorWithPadding
|
||||
|
||||
task_to_keys = {
|
||||
'cola': ('sentence', None),
|
||||
'mnli': ('premise', 'hypothesis'),
|
||||
'mrpc': ('sentence1', 'sentence2'),
|
||||
'qnli': ('question', 'sentence'),
|
||||
'qqp': ('question1', 'question2'),
|
||||
'rte': ('sentence1', 'sentence2'),
|
||||
'sst2': ('sentence', None),
|
||||
'stsb': ('sentence1', 'sentence2'),
|
||||
'wnli': ('sentence1', 'sentence2'),
|
||||
}
|
||||
|
||||
def prepare_data(cache_dir='./data', train_batch_size=32, eval_batch_size=32,
|
||||
teacher_model: torch.nn.Module = None):
|
||||
tokenizer = BertTokenizerFast.from_pretrained(pretrained_model_name_or_path)
|
||||
sentence1_key, sentence2_key = task_to_keys[task_name]
|
||||
data_collator = DataCollatorWithPadding(tokenizer)
|
||||
|
||||
# used to preprocess the raw data
|
||||
def preprocess_function(examples):
|
||||
# Tokenize the texts
|
||||
args = (
|
||||
(examples[sentence1_key],) if sentence2_key is None else (examples[sentence1_key], examples[sentence2_key])
|
||||
)
|
||||
result = tokenizer(*args, padding=False, max_length=128, truncation=True)
|
||||
|
||||
if 'label' in examples:
|
||||
# In all cases, rename the column to labels because the model will expect that.
|
||||
result['labels'] = examples['label']
|
||||
return result
|
||||
|
||||
raw_datasets = load_dataset('glue', task_name, cache_dir=cache_dir)
|
||||
for key in list(raw_datasets.keys()):
|
||||
if 'test' in key:
|
||||
raw_datasets.pop(key)
|
||||
|
||||
processed_datasets = raw_datasets.map(preprocess_function, batched=True,
|
||||
remove_columns=raw_datasets['train'].column_names)
|
||||
|
||||
# if has teacher model, add 'teacher_logits' to datasets who has 'labels'.
|
||||
# 'teacher_logits' is used for distillation and avoid the double counting.
|
||||
if teacher_model:
|
||||
teacher_model_training = teacher_model.training
|
||||
teacher_model.eval()
|
||||
model_device = next(teacher_model.parameters()).device
|
||||
|
||||
def add_teacher_logits(examples):
|
||||
result = {k: v for k, v in examples.items()}
|
||||
samples = data_collator(result).to(model_device)
|
||||
if 'labels' in samples:
|
||||
with torch.no_grad():
|
||||
logits = teacher_model(**samples).logits.tolist()
|
||||
result['teacher_logits'] = logits
|
||||
return result
|
||||
|
||||
processed_datasets = processed_datasets.map(add_teacher_logits, batched=True,
|
||||
batch_size=train_batch_size)
|
||||
teacher_model.train(teacher_model_training)
|
||||
|
||||
train_dataset = processed_datasets['train']
|
||||
validation_dataset = processed_datasets['validation_matched' if task_name == 'mnli' else 'validation']
|
||||
validation_dataset2 = processed_datasets['validation_mismatched'] if task_name == 'mnli' else None
|
||||
|
||||
train_dataloader = DataLoader(train_dataset,
|
||||
shuffle=True,
|
||||
collate_fn=data_collator,
|
||||
batch_size=train_batch_size)
|
||||
validation_dataloader = DataLoader(validation_dataset,
|
||||
collate_fn=data_collator,
|
||||
batch_size=eval_batch_size)
|
||||
validation_dataloader2 = DataLoader(validation_dataset2,
|
||||
collate_fn=data_collator,
|
||||
batch_size=eval_batch_size) if task_name == 'mnli' else None
|
||||
|
||||
return train_dataloader, validation_dataloader, validation_dataloader2
|
||||
|
||||
# %%
|
||||
# Training function & evaluation function.
|
||||
|
||||
import time
|
||||
import torch.nn.functional as F
|
||||
from datasets import load_metric
|
||||
|
||||
def training(train_dataloader: DataLoader,
|
||||
model: torch.nn.Module,
|
||||
optimizer: torch.optim.Optimizer,
|
||||
criterion: Callable[[torch.Tensor, torch.Tensor], torch.Tensor],
|
||||
lr_scheduler: torch.optim.lr_scheduler._LRScheduler = None,
|
||||
max_steps: int = None, max_epochs: int = None,
|
||||
save_best_model: bool = False, save_path: str = None,
|
||||
log_path: str = Path(log_dir) / 'training.log',
|
||||
distillation: bool = False,
|
||||
evaluation_func=None):
|
||||
model.train()
|
||||
current_step = 0
|
||||
best_result = 0
|
||||
|
||||
for current_epoch in range(max_epochs if max_epochs else 1):
|
||||
for batch in train_dataloader:
|
||||
batch.to(device)
|
||||
teacher_logits = batch.pop('teacher_logits', None)
|
||||
optimizer.zero_grad()
|
||||
outputs = model(**batch)
|
||||
loss = outputs.loss
|
||||
|
||||
if distillation:
|
||||
assert teacher_logits is not None
|
||||
distil_loss = F.kl_div(F.log_softmax(outputs.logits / 2, dim=-1),
|
||||
F.softmax(teacher_logits / 2, dim=-1), reduction='batchmean') * (2 ** 2)
|
||||
loss = 0.1 * loss + 0.9 * distil_loss
|
||||
|
||||
loss = criterion(loss, None)
|
||||
loss.backward()
|
||||
optimizer.step()
|
||||
|
||||
if lr_scheduler:
|
||||
lr_scheduler.step()
|
||||
|
||||
current_step += 1
|
||||
|
||||
# evaluation for every 1000 steps
|
||||
if current_step % 1000 == 0 or current_step % len(train_dataloader) == 0:
|
||||
result = evaluation_func(model) if evaluation_func else None
|
||||
with (log_path).open('a+') as f:
|
||||
msg = '[{}] Epoch {}, Step {}: {}\n'.format(time.asctime(time.localtime(time.time())), current_epoch, current_step, result)
|
||||
f.write(msg)
|
||||
# if it's the best model, save it.
|
||||
if save_best_model and best_result < result['default']:
|
||||
assert save_path is not None
|
||||
torch.save(model.state_dict(), save_path)
|
||||
best_result = result['default']
|
||||
|
||||
if max_steps and current_step >= max_steps:
|
||||
return
|
||||
|
||||
def evaluation(validation_dataloader: DataLoader,
|
||||
validation_dataloader2: DataLoader,
|
||||
model: torch.nn.Module):
|
||||
training = model.training
|
||||
model.eval()
|
||||
is_regression = task_name == 'stsb'
|
||||
metric = load_metric('glue', task_name)
|
||||
|
||||
for batch in validation_dataloader:
|
||||
batch.pop('teacher_logits', None)
|
||||
batch.to(device)
|
||||
outputs = model(**batch)
|
||||
predictions = outputs.logits.argmax(dim=-1) if not is_regression else outputs.logits.squeeze()
|
||||
metric.add_batch(
|
||||
predictions=predictions,
|
||||
references=batch['labels'],
|
||||
)
|
||||
result = metric.compute()
|
||||
|
||||
if validation_dataloader2:
|
||||
for batch in validation_dataloader2:
|
||||
batch.pop('teacher_logits', None)
|
||||
batch.to(device)
|
||||
outputs = model(**batch)
|
||||
predictions = outputs.logits.argmax(dim=-1) if not is_regression else outputs.logits.squeeze()
|
||||
metric.add_batch(
|
||||
predictions=predictions,
|
||||
references=batch['labels'],
|
||||
)
|
||||
result = {'matched': result, 'mismatched': metric.compute()}
|
||||
result['default'] = (result['matched']['accuracy'] + result['mismatched']['accuracy']) / 2
|
||||
else:
|
||||
result['default'] = result.get('f1', result.get('accuracy', None))
|
||||
|
||||
model.train(training)
|
||||
return result
|
||||
|
||||
# using huggingface native loss
|
||||
def fake_criterion(outputs, targets):
|
||||
return outputs
|
||||
|
||||
|
||||
# %%
|
||||
# Prepare pre-trained model and finetuning on downstream task.
|
||||
|
||||
import functools
|
||||
|
||||
from torch.optim import Adam
|
||||
from torch.optim.lr_scheduler import LambdaLR
|
||||
from transformers import BertForSequenceClassification
|
||||
|
||||
def create_pretrained_model():
|
||||
is_regression = task_name == 'stsb'
|
||||
num_labels = 1 if is_regression else (3 if task_name == 'mnli' else 2)
|
||||
return BertForSequenceClassification.from_pretrained(pretrained_model_name_or_path, num_labels=num_labels)
|
||||
|
||||
def create_finetuned_model():
|
||||
pretrained_model = create_pretrained_model().to(device)
|
||||
|
||||
train_dataloader, validation_dataloader, validation_dataloader2 = prepare_data()
|
||||
evaluation_func = functools.partial(evaluation, validation_dataloader, validation_dataloader2)
|
||||
steps_per_epoch = len(train_dataloader)
|
||||
training_epochs = 3
|
||||
|
||||
finetuned_model_state_path = Path(model_dir) / 'finetuned_model_state.pth'
|
||||
|
||||
if finetuned_model_state_path.exists():
|
||||
pretrained_model.load_state_dict(torch.load(finetuned_model_state_path))
|
||||
elif dev_mode:
|
||||
pass
|
||||
else:
|
||||
optimizer = Adam(pretrained_model.parameters(), lr=3e-5, eps=1e-8)
|
||||
|
||||
def lr_lambda(current_step: int):
|
||||
return max(0.0, float(training_epochs * steps_per_epoch - current_step) / float(training_epochs * steps_per_epoch))
|
||||
|
||||
lr_scheduler = LambdaLR(optimizer, lr_lambda)
|
||||
training(train_dataloader, pretrained_model, optimizer, fake_criterion, lr_scheduler=lr_scheduler, max_epochs=training_epochs,
|
||||
save_best_model=True, save_path=finetuned_model_state_path, evaluation_func=evaluation_func)
|
||||
return pretrained_model
|
||||
|
||||
finetuned_model = create_finetuned_model()
|
||||
|
||||
# %%
|
||||
# Using finetuned model as teacher model to create dataloader.
|
||||
# Add 'teacher_logits' to dataset, it is used to do the distillation, it can be seen as a kind of data label.
|
||||
|
||||
if not dev_mode:
|
||||
train_dataloader, validation_dataloader, validation_dataloader2 = prepare_data(teacher_model=finetuned_model)
|
||||
else:
|
||||
train_dataloader, validation_dataloader, validation_dataloader2 = prepare_data()
|
||||
|
||||
evaluation_func = functools.partial(evaluation, validation_dataloader, validation_dataloader2)
|
||||
|
||||
# %%
|
||||
# Pruning
|
||||
# ^^^^^^^
|
||||
# First, using MovementPruner to prune attention head.
|
||||
|
||||
steps_per_epoch = len(train_dataloader)
|
||||
|
||||
# Set training steps/epochs for pruning.
|
||||
|
||||
if not dev_mode:
|
||||
total_epochs = 4
|
||||
total_steps = total_epochs * steps_per_epoch
|
||||
warmup_steps = 1 * steps_per_epoch
|
||||
cooldown_steps = 1 * steps_per_epoch
|
||||
else:
|
||||
total_epochs = 1
|
||||
total_steps = 3
|
||||
warmup_steps = 1
|
||||
cooldown_steps = 1
|
||||
|
||||
# Initialize evaluator used by MovementPruner.
|
||||
|
||||
import nni
|
||||
from nni.algorithms.compression.v2.pytorch import TorchEvaluator
|
||||
|
||||
movement_training = functools.partial(training, train_dataloader, log_path=log_dir / 'movement_pruning.log',
|
||||
evaluation_func=evaluation_func)
|
||||
traced_optimizer = nni.trace(Adam)(finetuned_model.parameters(), lr=3e-5, eps=1e-8)
|
||||
|
||||
def lr_lambda(current_step: int):
|
||||
if current_step < warmup_steps:
|
||||
return float(current_step) / warmup_steps
|
||||
return max(0.0, float(total_steps - current_step) / float(total_steps - warmup_steps))
|
||||
|
||||
traced_scheduler = nni.trace(LambdaLR)(traced_optimizer, lr_lambda)
|
||||
evaluator = TorchEvaluator(movement_training, traced_optimizer, fake_criterion, traced_scheduler)
|
||||
|
||||
# Apply block-soft-movement pruning on attention layers.
|
||||
|
||||
from nni.compression.pytorch.pruning import MovementPruner
|
||||
|
||||
config_list = [{'op_types': ['Linear'], 'op_partial_names': ['bert.encoder.layer.{}.'.format(i) for i in range(layers_num)], 'sparsity': 0.1}]
|
||||
pruner = MovementPruner(model=finetuned_model,
|
||||
config_list=config_list,
|
||||
evaluator=evaluator,
|
||||
training_epochs=total_epochs,
|
||||
training_steps=total_steps,
|
||||
warm_up_step=warmup_steps,
|
||||
cool_down_beginning_step=total_steps - cooldown_steps,
|
||||
regular_scale=10,
|
||||
movement_mode='soft',
|
||||
sparse_granularity='auto')
|
||||
_, attention_masks = pruner.compress()
|
||||
pruner.show_pruned_weights()
|
||||
|
||||
torch.save(attention_masks, Path(log_dir) / 'attention_masks.pth')
|
||||
|
||||
# %%
|
||||
# Load a new finetuned model to do the speedup.
|
||||
# Note that nni speedup don't support replace attention module, so here we manully replace the attention module.
|
||||
#
|
||||
# If the head is entire masked, physically prune it and create config_list for FFN pruning.
|
||||
|
||||
attention_pruned_model = create_finetuned_model().to(device)
|
||||
attention_masks = torch.load(Path(log_dir) / 'attention_masks.pth')
|
||||
|
||||
ffn_config_list = []
|
||||
layer_count = 0
|
||||
module_list = []
|
||||
for i in range(0, layers_num):
|
||||
prefix = f'bert.encoder.layer.{i}.'
|
||||
value_mask: torch.Tensor = attention_masks[prefix + 'attention.self.value']['weight']
|
||||
head_mask = (value_mask.reshape(heads_num, -1).sum(-1) == 0.)
|
||||
head_idx = torch.arange(len(head_mask))[head_mask].long().tolist()
|
||||
print(f'layer {i} pruner {len(head_idx)} head: {head_idx}')
|
||||
if len(head_idx) != heads_num:
|
||||
attention_pruned_model.bert.encoder.layer[i].attention.prune_heads(head_idx)
|
||||
module_list.append(attention_pruned_model.bert.encoder.layer[i])
|
||||
# The final ffn weight remaining ratio is the half of the attention weight remaining ratio.
|
||||
# This is just an empirical configuration, you can use any other method to determine this sparsity.
|
||||
sparsity = 1 - (1 - len(head_idx) / heads_num) * 0.5
|
||||
# here we use a simple sparsity schedule, we will prune ffn in 12 iterations, each iteration prune `sparsity_per_iter`.
|
||||
sparsity_per_iter = 1 - (1 - sparsity) ** (1 / heads_num)
|
||||
ffn_config_list.append({'op_names': [f'bert.encoder.layer.{layer_count}.intermediate.dense'], 'sparsity': sparsity_per_iter})
|
||||
layer_count += 1
|
||||
|
||||
attention_pruned_model.bert.encoder.layer = torch.nn.ModuleList(module_list)
|
||||
|
||||
# %%
|
||||
# Retrain the attention pruned model with distillation.
|
||||
|
||||
if not dev_mode:
|
||||
total_epochs = 5
|
||||
total_steps = None
|
||||
distillation = True
|
||||
else:
|
||||
total_epochs = 1
|
||||
total_steps = 1
|
||||
distillation = False
|
||||
|
||||
optimizer = Adam(attention_pruned_model.parameters(), lr=3e-5, eps=1e-8)
|
||||
|
||||
def lr_lambda(current_step: int):
|
||||
return max(0.0, float(total_epochs * steps_per_epoch - current_step) / float(total_epochs * steps_per_epoch))
|
||||
|
||||
lr_scheduler = LambdaLR(optimizer, lr_lambda)
|
||||
at_model_save_path = log_dir / 'attention_pruned_model_state.pth'
|
||||
training(train_dataloader, attention_pruned_model, optimizer, fake_criterion, lr_scheduler=lr_scheduler,
|
||||
max_epochs=total_epochs, max_steps=total_steps, save_best_model=True, save_path=at_model_save_path,
|
||||
distillation=distillation, evaluation_func=evaluation_func)
|
||||
|
||||
if not dev_mode:
|
||||
attention_pruned_model.load_state_dict(torch.load(at_model_save_path))
|
||||
|
||||
# %%
|
||||
# Iterative pruning FFN with TaylorFOWeightPruner in 12 iterations.
|
||||
# Finetuning 2000 steps after each iteration, then finetuning 2 epochs after pruning finished.
|
||||
#
|
||||
# NNI will support per-step-pruning-schedule in the future, then can use an pruner to replace the following code.
|
||||
|
||||
if not dev_mode:
|
||||
total_epochs = 4
|
||||
total_steps = None
|
||||
taylor_pruner_steps = 1000
|
||||
steps_per_iteration = 2000
|
||||
total_pruning_steps = 24000
|
||||
distillation = True
|
||||
else:
|
||||
total_epochs = 1
|
||||
total_steps = 6
|
||||
taylor_pruner_steps = 2
|
||||
steps_per_iteration = 2
|
||||
total_pruning_steps = 4
|
||||
distillation = False
|
||||
|
||||
from nni.compression.pytorch.pruning import TaylorFOWeightPruner
|
||||
from nni.compression.pytorch.speedup import ModelSpeedup
|
||||
|
||||
distil_training = functools.partial(training, train_dataloader, log_path=log_dir / 'taylor_pruning.log',
|
||||
distillation=distillation, evaluation_func=evaluation_func)
|
||||
traced_optimizer = nni.trace(Adam)(attention_pruned_model.parameters(), lr=3e-5, eps=1e-8)
|
||||
evaluator = TorchEvaluator(distil_training, traced_optimizer, fake_criterion)
|
||||
|
||||
current_step = 0
|
||||
best_result = 0
|
||||
init_lr = 3e-5
|
||||
|
||||
dummy_input = torch.rand(8, 128, 768).to(device)
|
||||
|
||||
attention_pruned_model.train()
|
||||
for current_epoch in range(total_epochs):
|
||||
for batch in train_dataloader:
|
||||
if total_steps and current_step >= total_steps:
|
||||
break
|
||||
# pruning 12 times
|
||||
if current_step % steps_per_iteration == 0 and current_step < total_pruning_steps:
|
||||
check_point = attention_pruned_model.state_dict()
|
||||
pruner = TaylorFOWeightPruner(attention_pruned_model, ffn_config_list, evaluator, taylor_pruner_steps)
|
||||
_, ffn_masks = pruner.compress()
|
||||
renamed_ffn_masks = {}
|
||||
# rename the masks keys, because we only speedup the bert.encoder
|
||||
for model_name, targets_mask in ffn_masks.items():
|
||||
renamed_ffn_masks[model_name.split('bert.encoder.')[1]] = targets_mask
|
||||
pruner._unwrap_model()
|
||||
attention_pruned_model.load_state_dict(check_point)
|
||||
ModelSpeedup(attention_pruned_model.bert.encoder, dummy_input, renamed_ffn_masks).speedup_model()
|
||||
optimizer = Adam(attention_pruned_model.parameters(), lr=init_lr)
|
||||
|
||||
batch.to(device)
|
||||
teacher_logits = batch.pop('teacher_logits', None)
|
||||
optimizer.zero_grad()
|
||||
|
||||
# manually schedule lr
|
||||
for params_group in optimizer.param_groups:
|
||||
params_group['lr'] = (1 - current_step / (total_epochs * steps_per_epoch)) * init_lr
|
||||
|
||||
outputs = attention_pruned_model(**batch)
|
||||
loss = outputs.loss
|
||||
|
||||
# distillation
|
||||
if teacher_logits is not None:
|
||||
distil_loss = F.kl_div(F.log_softmax(outputs.logits / 2, dim=-1),
|
||||
F.softmax(teacher_logits / 2, dim=-1), reduction='batchmean') * (2 ** 2)
|
||||
loss = 0.1 * loss + 0.9 * distil_loss
|
||||
loss.backward()
|
||||
optimizer.step()
|
||||
|
||||
current_step += 1
|
||||
if current_step % 1000 == 0 or current_step % len(train_dataloader) == 0:
|
||||
result = evaluation_func(attention_pruned_model)
|
||||
with (log_dir / 'ffn_pruning.log').open('a+') as f:
|
||||
msg = '[{}] Epoch {}, Step {}: {}\n'.format(time.asctime(time.localtime(time.time())),
|
||||
current_epoch, current_step, result)
|
||||
f.write(msg)
|
||||
if current_step >= total_pruning_steps and best_result < result['default']:
|
||||
torch.save(attention_pruned_model, log_dir / 'best_model.pth')
|
||||
best_result = result['default']
|
||||
|
||||
# %%
|
||||
# Result
|
||||
# ------
|
||||
# The speedup is test on the entire validation dataset with batch size 32 on A100.
|
||||
# We test under two pytorch version and found the latency varying widely.
|
||||
#
|
||||
# Setting 1: pytorch 1.12.1
|
||||
#
|
||||
# Setting 2: pytorch 1.10.0
|
||||
#
|
||||
# .. list-table:: Prune Bert-base-uncased on MNLI
|
||||
# :header-rows: 1
|
||||
# :widths: auto
|
||||
#
|
||||
# * - Attention Pruning Method
|
||||
# - FFN Pruning Method
|
||||
# - Total Sparsity
|
||||
# - Accuracy
|
||||
# - Acc. Drop
|
||||
# - Speedup (S1)
|
||||
# - Speedup (S2)
|
||||
# * -
|
||||
# -
|
||||
# - 0%
|
||||
# - 84.73 / 84.63
|
||||
# - +0.0 / +0.0
|
||||
# - 12.56s (x1.00)
|
||||
# - 4.05s (x1.00)
|
||||
# * - :ref:`movement-pruner` (soft, th=0.1, lambda=5)
|
||||
# - :ref:`taylor-fo-weight-pruner`
|
||||
# - 51.39%
|
||||
# - 84.25 / 84.96
|
||||
# - -0.48 / +0.33
|
||||
# - 6.85s (x1.83)
|
||||
# - 2.7s (x1.50)
|
||||
# * - :ref:`movement-pruner` (soft, th=0.1, lambda=10)
|
||||
# - :ref:`taylor-fo-weight-pruner`
|
||||
# - 66.67%
|
||||
# - 83.98 / 83.75
|
||||
# - -0.75 / -0.88
|
||||
# - 4.73s (x2.66)
|
||||
# - 2.16s (x1.86)
|
||||
# * - :ref:`movement-pruner` (soft, th=0.1, lambda=20)
|
||||
# - :ref:`taylor-fo-weight-pruner`
|
||||
# - 77.78%
|
||||
# - 83.02 / 83.06
|
||||
# - -1.71 / -1.57
|
||||
# - 3.35s (x3.75)
|
||||
# - 1.72s (x2.35)
|
||||
# * - :ref:`movement-pruner` (soft, th=0.1, lambda=30)
|
||||
# - :ref:`taylor-fo-weight-pruner`
|
||||
# - 87.04%
|
||||
# - 81.24 / 80.99
|
||||
# - -3.49 / -3.64
|
||||
# - 2.19s (x5.74)
|
||||
# - 1.31s (x3.09)
|
|
@ -0,0 +1 @@
|
|||
7d8ff24fe5a88d208ad2ad051f060df4
|
|
@ -0,0 +1,809 @@
|
|||
|
||||
.. DO NOT EDIT.
|
||||
.. THIS FILE WAS AUTOMATICALLY GENERATED BY SPHINX-GALLERY.
|
||||
.. TO MAKE CHANGES, EDIT THE SOURCE PYTHON FILE:
|
||||
.. "tutorials/pruning_bert_glue.py"
|
||||
.. LINE NUMBERS ARE GIVEN BELOW.
|
||||
|
||||
.. only:: html
|
||||
|
||||
.. note::
|
||||
:class: sphx-glr-download-link-note
|
||||
|
||||
Click :ref:`here <sphx_glr_download_tutorials_pruning_bert_glue.py>`
|
||||
to download the full example code
|
||||
|
||||
.. rst-class:: sphx-glr-example-title
|
||||
|
||||
.. _sphx_glr_tutorials_pruning_bert_glue.py:
|
||||
|
||||
|
||||
Pruning Transformer with NNI
|
||||
============================
|
||||
|
||||
Workable Pruning Process
|
||||
------------------------
|
||||
|
||||
Here we show an effective transformer pruning process that NNI team has tried, and users can use NNI to discover better processes.
|
||||
|
||||
The entire pruning process can be divided into the following steps:
|
||||
|
||||
1. Finetune the pre-trained model on the downstream task. From our experience,
|
||||
the final performance of pruning on the finetuned model is better than pruning directly on the pre-trained model.
|
||||
At the same time, the finetuned model obtained in this step will also be used as the teacher model for the following
|
||||
distillation training.
|
||||
2. Pruning the attention layer at first. Here we apply block-sparse on attention layer weight,
|
||||
and directly prune the head (condense the weight) if the head was fully masked.
|
||||
If the head was partially masked, we will not prune it and recover its weight.
|
||||
3. Retrain the head-pruned model with distillation. Recover the model precision before pruning FFN layer.
|
||||
4. Pruning the FFN layer. Here we apply the output channels pruning on the 1st FFN layer,
|
||||
and the 2nd FFN layer input channels will be pruned due to the pruning of 1st layer output channels.
|
||||
5. Retrain the final pruned model with distillation.
|
||||
|
||||
During the process of pruning transformer, we gained some of the following experiences:
|
||||
|
||||
* We using :ref:`movement-pruner` in step 2 and :ref:`taylor-fo-weight-pruner` in step 4. :ref:`movement-pruner` has good performance on attention layers,
|
||||
and :ref:`taylor-fo-weight-pruner` method has good performance on FFN layers. These two pruners are all some kinds of gradient-based pruning algorithms,
|
||||
we also try weight-based pruning algorithms like :ref:`l1-norm-pruner`, but it doesn't seem to work well in this scenario.
|
||||
* Distillation is a good way to recover model precision. In terms of results, usually 1~2% improvement in accuracy can be achieved when we prune bert on mnli task.
|
||||
* It is necessary to gradually increase the sparsity rather than reaching a very high sparsity all at once.
|
||||
|
||||
Experiment
|
||||
----------
|
||||
|
||||
Preparation
|
||||
^^^^^^^^^^^
|
||||
Please set ``dev_mode`` to ``False`` to run this tutorial. Here ``dev_mode`` is ``True`` by default is for generating documents.
|
||||
|
||||
The complete pruning process takes about 8 hours on one A100.
|
||||
|
||||
.. GENERATED FROM PYTHON SOURCE LINES 41-44
|
||||
|
||||
.. code-block:: default
|
||||
|
||||
|
||||
dev_mode = True
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
.. GENERATED FROM PYTHON SOURCE LINES 45-46
|
||||
|
||||
Some basic setting.
|
||||
|
||||
.. GENERATED FROM PYTHON SOURCE LINES 46-72
|
||||
|
||||
.. code-block:: default
|
||||
|
||||
|
||||
from pathlib import Path
|
||||
from typing import Callable
|
||||
|
||||
pretrained_model_name_or_path = 'bert-base-uncased'
|
||||
task_name = 'mnli'
|
||||
experiment_id = 'pruning_bert'
|
||||
|
||||
# heads_num and layers_num should align with pretrained_model_name_or_path
|
||||
heads_num = 12
|
||||
layers_num = 12
|
||||
|
||||
# used to save the experiment log
|
||||
log_dir = Path(f'./pruning_log/{pretrained_model_name_or_path}/{task_name}/{experiment_id}')
|
||||
log_dir.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
# used to save the finetuned model and share between different experiemnts with same pretrained_model_name_or_path and task_name
|
||||
model_dir = Path(f'./models/{pretrained_model_name_or_path}/{task_name}')
|
||||
model_dir.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
from transformers import set_seed
|
||||
set_seed(1024)
|
||||
|
||||
import torch
|
||||
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
.. GENERATED FROM PYTHON SOURCE LINES 73-75
|
||||
|
||||
The function used to create dataloaders, note that 'mnli' has two evaluation dataset.
|
||||
If teacher_model is set, will run all dataset on teacher model to get the 'teacher_logits' for distillation.
|
||||
|
||||
.. GENERATED FROM PYTHON SOURCE LINES 75-157
|
||||
|
||||
.. code-block:: default
|
||||
|
||||
|
||||
from torch.utils.data import DataLoader
|
||||
|
||||
from datasets import load_dataset
|
||||
from transformers import BertTokenizerFast, DataCollatorWithPadding
|
||||
|
||||
task_to_keys = {
|
||||
'cola': ('sentence', None),
|
||||
'mnli': ('premise', 'hypothesis'),
|
||||
'mrpc': ('sentence1', 'sentence2'),
|
||||
'qnli': ('question', 'sentence'),
|
||||
'qqp': ('question1', 'question2'),
|
||||
'rte': ('sentence1', 'sentence2'),
|
||||
'sst2': ('sentence', None),
|
||||
'stsb': ('sentence1', 'sentence2'),
|
||||
'wnli': ('sentence1', 'sentence2'),
|
||||
}
|
||||
|
||||
def prepare_data(cache_dir='./data', train_batch_size=32, eval_batch_size=32,
|
||||
teacher_model: torch.nn.Module = None):
|
||||
tokenizer = BertTokenizerFast.from_pretrained(pretrained_model_name_or_path)
|
||||
sentence1_key, sentence2_key = task_to_keys[task_name]
|
||||
data_collator = DataCollatorWithPadding(tokenizer)
|
||||
|
||||
# used to preprocess the raw data
|
||||
def preprocess_function(examples):
|
||||
# Tokenize the texts
|
||||
args = (
|
||||
(examples[sentence1_key],) if sentence2_key is None else (examples[sentence1_key], examples[sentence2_key])
|
||||
)
|
||||
result = tokenizer(*args, padding=False, max_length=128, truncation=True)
|
||||
|
||||
if 'label' in examples:
|
||||
# In all cases, rename the column to labels because the model will expect that.
|
||||
result['labels'] = examples['label']
|
||||
return result
|
||||
|
||||
raw_datasets = load_dataset('glue', task_name, cache_dir=cache_dir)
|
||||
for key in list(raw_datasets.keys()):
|
||||
if 'test' in key:
|
||||
raw_datasets.pop(key)
|
||||
|
||||
processed_datasets = raw_datasets.map(preprocess_function, batched=True,
|
||||
remove_columns=raw_datasets['train'].column_names)
|
||||
|
||||
# if has teacher model, add 'teacher_logits' to datasets who has 'labels'.
|
||||
# 'teacher_logits' is used for distillation and avoid the double counting.
|
||||
if teacher_model:
|
||||
teacher_model_training = teacher_model.training
|
||||
teacher_model.eval()
|
||||
model_device = next(teacher_model.parameters()).device
|
||||
|
||||
def add_teacher_logits(examples):
|
||||
result = {k: v for k, v in examples.items()}
|
||||
samples = data_collator(result).to(model_device)
|
||||
if 'labels' in samples:
|
||||
with torch.no_grad():
|
||||
logits = teacher_model(**samples).logits.tolist()
|
||||
result['teacher_logits'] = logits
|
||||
return result
|
||||
|
||||
processed_datasets = processed_datasets.map(add_teacher_logits, batched=True,
|
||||
batch_size=train_batch_size)
|
||||
teacher_model.train(teacher_model_training)
|
||||
|
||||
train_dataset = processed_datasets['train']
|
||||
validation_dataset = processed_datasets['validation_matched' if task_name == 'mnli' else 'validation']
|
||||
validation_dataset2 = processed_datasets['validation_mismatched'] if task_name == 'mnli' else None
|
||||
|
||||
train_dataloader = DataLoader(train_dataset,
|
||||
shuffle=True,
|
||||
collate_fn=data_collator,
|
||||
batch_size=train_batch_size)
|
||||
validation_dataloader = DataLoader(validation_dataset,
|
||||
collate_fn=data_collator,
|
||||
batch_size=eval_batch_size)
|
||||
validation_dataloader2 = DataLoader(validation_dataset2,
|
||||
collate_fn=data_collator,
|
||||
batch_size=eval_batch_size) if task_name == 'mnli' else None
|
||||
|
||||
return train_dataloader, validation_dataloader, validation_dataloader2
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
.. GENERATED FROM PYTHON SOURCE LINES 158-159
|
||||
|
||||
Training function & evaluation function.
|
||||
|
||||
.. GENERATED FROM PYTHON SOURCE LINES 159-258
|
||||
|
||||
.. code-block:: default
|
||||
|
||||
|
||||
import time
|
||||
import torch.nn.functional as F
|
||||
from datasets import load_metric
|
||||
|
||||
def training(train_dataloader: DataLoader,
|
||||
model: torch.nn.Module,
|
||||
optimizer: torch.optim.Optimizer,
|
||||
criterion: Callable[[torch.Tensor, torch.Tensor], torch.Tensor],
|
||||
lr_scheduler: torch.optim.lr_scheduler._LRScheduler = None,
|
||||
max_steps: int = None, max_epochs: int = None,
|
||||
save_best_model: bool = False, save_path: str = None,
|
||||
log_path: str = Path(log_dir) / 'training.log',
|
||||
distillation: bool = False,
|
||||
evaluation_func=None):
|
||||
model.train()
|
||||
current_step = 0
|
||||
best_result = 0
|
||||
|
||||
for current_epoch in range(max_epochs if max_epochs else 1):
|
||||
for batch in train_dataloader:
|
||||
batch.to(device)
|
||||
teacher_logits = batch.pop('teacher_logits', None)
|
||||
optimizer.zero_grad()
|
||||
outputs = model(**batch)
|
||||
loss = outputs.loss
|
||||
|
||||
if distillation:
|
||||
assert teacher_logits is not None
|
||||
distil_loss = F.kl_div(F.log_softmax(outputs.logits / 2, dim=-1),
|
||||
F.softmax(teacher_logits / 2, dim=-1), reduction='batchmean') * (2 ** 2)
|
||||
loss = 0.1 * loss + 0.9 * distil_loss
|
||||
|
||||
loss = criterion(loss, None)
|
||||
loss.backward()
|
||||
optimizer.step()
|
||||
|
||||
if lr_scheduler:
|
||||
lr_scheduler.step()
|
||||
|
||||
current_step += 1
|
||||
|
||||
# evaluation for every 1000 steps
|
||||
if current_step % 1000 == 0 or current_step % len(train_dataloader) == 0:
|
||||
result = evaluation_func(model) if evaluation_func else None
|
||||
with (log_path).open('a+') as f:
|
||||
msg = '[{}] Epoch {}, Step {}: {}\n'.format(time.asctime(time.localtime(time.time())), current_epoch, current_step, result)
|
||||
f.write(msg)
|
||||
# if it's the best model, save it.
|
||||
if save_best_model and best_result < result['default']:
|
||||
assert save_path is not None
|
||||
torch.save(model.state_dict(), save_path)
|
||||
best_result = result['default']
|
||||
|
||||
if max_steps and current_step >= max_steps:
|
||||
return
|
||||
|
||||
def evaluation(validation_dataloader: DataLoader,
|
||||
validation_dataloader2: DataLoader,
|
||||
model: torch.nn.Module):
|
||||
training = model.training
|
||||
model.eval()
|
||||
is_regression = task_name == 'stsb'
|
||||
metric = load_metric('glue', task_name)
|
||||
|
||||
for batch in validation_dataloader:
|
||||
batch.pop('teacher_logits', None)
|
||||
batch.to(device)
|
||||
outputs = model(**batch)
|
||||
predictions = outputs.logits.argmax(dim=-1) if not is_regression else outputs.logits.squeeze()
|
||||
metric.add_batch(
|
||||
predictions=predictions,
|
||||
references=batch['labels'],
|
||||
)
|
||||
result = metric.compute()
|
||||
|
||||
if validation_dataloader2:
|
||||
for batch in validation_dataloader2:
|
||||
batch.pop('teacher_logits', None)
|
||||
batch.to(device)
|
||||
outputs = model(**batch)
|
||||
predictions = outputs.logits.argmax(dim=-1) if not is_regression else outputs.logits.squeeze()
|
||||
metric.add_batch(
|
||||
predictions=predictions,
|
||||
references=batch['labels'],
|
||||
)
|
||||
result = {'matched': result, 'mismatched': metric.compute()}
|
||||
result['default'] = (result['matched']['accuracy'] + result['mismatched']['accuracy']) / 2
|
||||
else:
|
||||
result['default'] = result.get('f1', result.get('accuracy', None))
|
||||
|
||||
model.train(training)
|
||||
return result
|
||||
|
||||
# using huggingface native loss
|
||||
def fake_criterion(outputs, targets):
|
||||
return outputs
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
.. GENERATED FROM PYTHON SOURCE LINES 259-260
|
||||
|
||||
Prepare pre-trained model and finetuning on downstream task.
|
||||
|
||||
.. GENERATED FROM PYTHON SOURCE LINES 260-299
|
||||
|
||||
.. code-block:: default
|
||||
|
||||
|
||||
import functools
|
||||
|
||||
from torch.optim import Adam
|
||||
from torch.optim.lr_scheduler import LambdaLR
|
||||
from transformers import BertForSequenceClassification
|
||||
|
||||
def create_pretrained_model():
|
||||
is_regression = task_name == 'stsb'
|
||||
num_labels = 1 if is_regression else (3 if task_name == 'mnli' else 2)
|
||||
return BertForSequenceClassification.from_pretrained(pretrained_model_name_or_path, num_labels=num_labels)
|
||||
|
||||
def create_finetuned_model():
|
||||
pretrained_model = create_pretrained_model().to(device)
|
||||
|
||||
train_dataloader, validation_dataloader, validation_dataloader2 = prepare_data()
|
||||
evaluation_func = functools.partial(evaluation, validation_dataloader, validation_dataloader2)
|
||||
steps_per_epoch = len(train_dataloader)
|
||||
training_epochs = 3
|
||||
|
||||
finetuned_model_state_path = Path(model_dir) / 'finetuned_model_state.pth'
|
||||
|
||||
if finetuned_model_state_path.exists():
|
||||
pretrained_model.load_state_dict(torch.load(finetuned_model_state_path))
|
||||
elif dev_mode:
|
||||
pass
|
||||
else:
|
||||
optimizer = Adam(pretrained_model.parameters(), lr=3e-5, eps=1e-8)
|
||||
|
||||
def lr_lambda(current_step: int):
|
||||
return max(0.0, float(training_epochs * steps_per_epoch - current_step) / float(training_epochs * steps_per_epoch))
|
||||
|
||||
lr_scheduler = LambdaLR(optimizer, lr_lambda)
|
||||
training(train_dataloader, pretrained_model, optimizer, fake_criterion, lr_scheduler=lr_scheduler, max_epochs=training_epochs,
|
||||
save_best_model=True, save_path=finetuned_model_state_path, evaluation_func=evaluation_func)
|
||||
return pretrained_model
|
||||
|
||||
finetuned_model = create_finetuned_model()
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
.. rst-class:: sphx-glr-script-out
|
||||
|
||||
.. code-block:: none
|
||||
|
||||
Some weights of the model checkpoint at bert-base-uncased were not used when initializing BertForSequenceClassification: ['cls.seq_relationship.bias', 'cls.predictions.transform.dense.bias', 'cls.predictions.transform.LayerNorm.weight', 'cls.seq_relationship.weight', 'cls.predictions.decoder.weight', 'cls.predictions.transform.LayerNorm.bias', 'cls.predictions.bias', 'cls.predictions.transform.dense.weight']
|
||||
- This IS expected if you are initializing BertForSequenceClassification from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
|
||||
- This IS NOT expected if you are initializing BertForSequenceClassification from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).
|
||||
Some weights of BertForSequenceClassification were not initialized from the model checkpoint at bert-base-uncased and are newly initialized: ['classifier.weight', 'classifier.bias']
|
||||
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.
|
||||
Reusing dataset glue (./data/glue/mnli/1.0.0/dacbe3125aa31d7f70367a07a8a9e72a5a0bfeb5fc42e75c9db75b96da6053ad)
|
||||
0%| | 0/5 [00:00<?, ?it/s]
100%|##########| 5/5 [00:00<00:00, 1213.84it/s]
|
||||
Loading cached processed dataset at ./data/glue/mnli/1.0.0/dacbe3125aa31d7f70367a07a8a9e72a5a0bfeb5fc42e75c9db75b96da6053ad/cache-9c32a3d5eca55607.arrow
|
||||
Loading cached processed dataset at ./data/glue/mnli/1.0.0/dacbe3125aa31d7f70367a07a8a9e72a5a0bfeb5fc42e75c9db75b96da6053ad/cache-6f0849c5f6325016.arrow
|
||||
0%| | 0/10 [00:00<?, ?ba/s]
40%|#### | 4/10 [00:00<00:00, 34.52ba/s]
90%|######### | 9/10 [00:00<00:00, 38.77ba/s]
100%|##########| 10/10 [00:00<00:00, 38.78ba/s]
|
||||
|
||||
|
||||
|
||||
|
||||
.. GENERATED FROM PYTHON SOURCE LINES 300-302
|
||||
|
||||
Using finetuned model as teacher model to create dataloader.
|
||||
Add 'teacher_logits' to dataset, it is used to do the distillation, it can be seen as a kind of data label.
|
||||
|
||||
.. GENERATED FROM PYTHON SOURCE LINES 302-310
|
||||
|
||||
.. code-block:: default
|
||||
|
||||
|
||||
if not dev_mode:
|
||||
train_dataloader, validation_dataloader, validation_dataloader2 = prepare_data(teacher_model=finetuned_model)
|
||||
else:
|
||||
train_dataloader, validation_dataloader, validation_dataloader2 = prepare_data()
|
||||
|
||||
evaluation_func = functools.partial(evaluation, validation_dataloader, validation_dataloader2)
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
.. rst-class:: sphx-glr-script-out
|
||||
|
||||
.. code-block:: none
|
||||
|
||||
Reusing dataset glue (./data/glue/mnli/1.0.0/dacbe3125aa31d7f70367a07a8a9e72a5a0bfeb5fc42e75c9db75b96da6053ad)
|
||||
0%| | 0/5 [00:00<?, ?it/s]
100%|##########| 5/5 [00:00<00:00, 1249.79it/s]
|
||||
Loading cached processed dataset at ./data/glue/mnli/1.0.0/dacbe3125aa31d7f70367a07a8a9e72a5a0bfeb5fc42e75c9db75b96da6053ad/cache-9c32a3d5eca55607.arrow
|
||||
Loading cached processed dataset at ./data/glue/mnli/1.0.0/dacbe3125aa31d7f70367a07a8a9e72a5a0bfeb5fc42e75c9db75b96da6053ad/cache-6f0849c5f6325016.arrow
|
||||
Loading cached processed dataset at ./data/glue/mnli/1.0.0/dacbe3125aa31d7f70367a07a8a9e72a5a0bfeb5fc42e75c9db75b96da6053ad/cache-5db72911f5dfb448.arrow
|
||||
|
||||
|
||||
|
||||
|
||||
.. GENERATED FROM PYTHON SOURCE LINES 311-314
|
||||
|
||||
Pruning
|
||||
^^^^^^^
|
||||
First, using MovementPruner to prune attention head.
|
||||
|
||||
.. GENERATED FROM PYTHON SOURCE LINES 314-367
|
||||
|
||||
.. code-block:: default
|
||||
|
||||
|
||||
steps_per_epoch = len(train_dataloader)
|
||||
|
||||
# Set training steps/epochs for pruning.
|
||||
|
||||
if not dev_mode:
|
||||
total_epochs = 4
|
||||
total_steps = total_epochs * steps_per_epoch
|
||||
warmup_steps = 1 * steps_per_epoch
|
||||
cooldown_steps = 1 * steps_per_epoch
|
||||
else:
|
||||
total_epochs = 1
|
||||
total_steps = 3
|
||||
warmup_steps = 1
|
||||
cooldown_steps = 1
|
||||
|
||||
# Initialize evaluator used by MovementPruner.
|
||||
|
||||
import nni
|
||||
from nni.algorithms.compression.v2.pytorch import TorchEvaluator
|
||||
|
||||
movement_training = functools.partial(training, train_dataloader, log_path=log_dir / 'movement_pruning.log',
|
||||
evaluation_func=evaluation_func)
|
||||
traced_optimizer = nni.trace(Adam)(finetuned_model.parameters(), lr=3e-5, eps=1e-8)
|
||||
|
||||
def lr_lambda(current_step: int):
|
||||
if current_step < warmup_steps:
|
||||
return float(current_step) / warmup_steps
|
||||
return max(0.0, float(total_steps - current_step) / float(total_steps - warmup_steps))
|
||||
|
||||
traced_scheduler = nni.trace(LambdaLR)(traced_optimizer, lr_lambda)
|
||||
evaluator = TorchEvaluator(movement_training, traced_optimizer, fake_criterion, traced_scheduler)
|
||||
|
||||
# Apply block-soft-movement pruning on attention layers.
|
||||
|
||||
from nni.compression.pytorch.pruning import MovementPruner
|
||||
|
||||
config_list = [{'op_types': ['Linear'], 'op_partial_names': ['bert.encoder.layer.{}.'.format(i) for i in range(layers_num)], 'sparsity': 0.1}]
|
||||
pruner = MovementPruner(model=finetuned_model,
|
||||
config_list=config_list,
|
||||
evaluator=evaluator,
|
||||
training_epochs=total_epochs,
|
||||
training_steps=total_steps,
|
||||
warm_up_step=warmup_steps,
|
||||
cool_down_beginning_step=total_steps - cooldown_steps,
|
||||
regular_scale=10,
|
||||
movement_mode='soft',
|
||||
sparse_granularity='auto')
|
||||
_, attention_masks = pruner.compress()
|
||||
pruner.show_pruned_weights()
|
||||
|
||||
torch.save(attention_masks, Path(log_dir) / 'attention_masks.pth')
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
.. rst-class:: sphx-glr-script-out
|
||||
|
||||
.. code-block:: none
|
||||
|
||||
Did not bind any model, no need to unbind model.
|
||||
Did not bind any model, no need to unbind model.
|
||||
|
||||
|
||||
|
||||
|
||||
.. GENERATED FROM PYTHON SOURCE LINES 368-372
|
||||
|
||||
Load a new finetuned model to do the speedup.
|
||||
Note that nni speedup don't support replace attention module, so here we manully replace the attention module.
|
||||
|
||||
If the head is entire masked, physically prune it and create config_list for FFN pruning.
|
||||
|
||||
.. GENERATED FROM PYTHON SOURCE LINES 372-398
|
||||
|
||||
.. code-block:: default
|
||||
|
||||
|
||||
attention_pruned_model = create_finetuned_model().to(device)
|
||||
attention_masks = torch.load(Path(log_dir) / 'attention_masks.pth')
|
||||
|
||||
ffn_config_list = []
|
||||
layer_count = 0
|
||||
module_list = []
|
||||
for i in range(0, layers_num):
|
||||
prefix = f'bert.encoder.layer.{i}.'
|
||||
value_mask: torch.Tensor = attention_masks[prefix + 'attention.self.value']['weight']
|
||||
head_mask = (value_mask.reshape(heads_num, -1).sum(-1) == 0.)
|
||||
head_idx = torch.arange(len(head_mask))[head_mask].long().tolist()
|
||||
print(f'layer {i} pruner {len(head_idx)} head: {head_idx}')
|
||||
if len(head_idx) != heads_num:
|
||||
attention_pruned_model.bert.encoder.layer[i].attention.prune_heads(head_idx)
|
||||
module_list.append(attention_pruned_model.bert.encoder.layer[i])
|
||||
# The final ffn weight remaining ratio is the half of the attention weight remaining ratio.
|
||||
# This is just an empirical configuration, you can use any other method to determine this sparsity.
|
||||
sparsity = 1 - (1 - len(head_idx) / heads_num) * 0.5
|
||||
# here we use a simple sparsity schedule, we will prune ffn in 12 iterations, each iteration prune `sparsity_per_iter`.
|
||||
sparsity_per_iter = 1 - (1 - sparsity) ** (1 / heads_num)
|
||||
ffn_config_list.append({'op_names': [f'bert.encoder.layer.{layer_count}.intermediate.dense'], 'sparsity': sparsity_per_iter})
|
||||
layer_count += 1
|
||||
|
||||
attention_pruned_model.bert.encoder.layer = torch.nn.ModuleList(module_list)
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
.. rst-class:: sphx-glr-script-out
|
||||
|
||||
.. code-block:: none
|
||||
|
||||
Some weights of the model checkpoint at bert-base-uncased were not used when initializing BertForSequenceClassification: ['cls.seq_relationship.bias', 'cls.predictions.transform.dense.bias', 'cls.predictions.transform.LayerNorm.weight', 'cls.seq_relationship.weight', 'cls.predictions.decoder.weight', 'cls.predictions.transform.LayerNorm.bias', 'cls.predictions.bias', 'cls.predictions.transform.dense.weight']
|
||||
- This IS expected if you are initializing BertForSequenceClassification from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
|
||||
- This IS NOT expected if you are initializing BertForSequenceClassification from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).
|
||||
Some weights of BertForSequenceClassification were not initialized from the model checkpoint at bert-base-uncased and are newly initialized: ['classifier.weight', 'classifier.bias']
|
||||
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.
|
||||
Reusing dataset glue (./data/glue/mnli/1.0.0/dacbe3125aa31d7f70367a07a8a9e72a5a0bfeb5fc42e75c9db75b96da6053ad)
|
||||
0%| | 0/5 [00:00<?, ?it/s]
100%|##########| 5/5 [00:00<00:00, 1141.12it/s]
|
||||
Loading cached processed dataset at ./data/glue/mnli/1.0.0/dacbe3125aa31d7f70367a07a8a9e72a5a0bfeb5fc42e75c9db75b96da6053ad/cache-9c32a3d5eca55607.arrow
|
||||
Loading cached processed dataset at ./data/glue/mnli/1.0.0/dacbe3125aa31d7f70367a07a8a9e72a5a0bfeb5fc42e75c9db75b96da6053ad/cache-6f0849c5f6325016.arrow
|
||||
Loading cached processed dataset at ./data/glue/mnli/1.0.0/dacbe3125aa31d7f70367a07a8a9e72a5a0bfeb5fc42e75c9db75b96da6053ad/cache-5db72911f5dfb448.arrow
|
||||
layer 0 pruner 0 head: []
|
||||
layer 1 pruner 0 head: []
|
||||
layer 2 pruner 0 head: []
|
||||
layer 3 pruner 0 head: []
|
||||
layer 4 pruner 0 head: []
|
||||
layer 5 pruner 0 head: []
|
||||
layer 6 pruner 0 head: []
|
||||
layer 7 pruner 0 head: []
|
||||
layer 8 pruner 0 head: []
|
||||
layer 9 pruner 0 head: []
|
||||
layer 10 pruner 0 head: []
|
||||
layer 11 pruner 0 head: []
|
||||
|
||||
|
||||
|
||||
|
||||
.. GENERATED FROM PYTHON SOURCE LINES 399-400
|
||||
|
||||
Retrain the attention pruned model with distillation.
|
||||
|
||||
.. GENERATED FROM PYTHON SOURCE LINES 400-424
|
||||
|
||||
.. code-block:: default
|
||||
|
||||
|
||||
if not dev_mode:
|
||||
total_epochs = 5
|
||||
total_steps = None
|
||||
distillation = True
|
||||
else:
|
||||
total_epochs = 1
|
||||
total_steps = 1
|
||||
distillation = False
|
||||
|
||||
optimizer = Adam(attention_pruned_model.parameters(), lr=3e-5, eps=1e-8)
|
||||
|
||||
def lr_lambda(current_step: int):
|
||||
return max(0.0, float(total_epochs * steps_per_epoch - current_step) / float(total_epochs * steps_per_epoch))
|
||||
|
||||
lr_scheduler = LambdaLR(optimizer, lr_lambda)
|
||||
at_model_save_path = log_dir / 'attention_pruned_model_state.pth'
|
||||
training(train_dataloader, attention_pruned_model, optimizer, fake_criterion, lr_scheduler=lr_scheduler,
|
||||
max_epochs=total_epochs, max_steps=total_steps, save_best_model=True, save_path=at_model_save_path,
|
||||
distillation=distillation, evaluation_func=evaluation_func)
|
||||
|
||||
if not dev_mode:
|
||||
attention_pruned_model.load_state_dict(torch.load(at_model_save_path))
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
.. GENERATED FROM PYTHON SOURCE LINES 425-429
|
||||
|
||||
Iterative pruning FFN with TaylorFOWeightPruner in 12 iterations.
|
||||
Finetuning 2000 steps after each iteration, then finetuning 2 epochs after pruning finished.
|
||||
|
||||
NNI will support per-step-pruning-schedule in the future, then can use an pruner to replace the following code.
|
||||
|
||||
.. GENERATED FROM PYTHON SOURCE LINES 429-508
|
||||
|
||||
.. code-block:: default
|
||||
|
||||
|
||||
if not dev_mode:
|
||||
total_epochs = 4
|
||||
total_steps = None
|
||||
taylor_pruner_steps = 1000
|
||||
steps_per_iteration = 2000
|
||||
total_pruning_steps = 24000
|
||||
distillation = True
|
||||
else:
|
||||
total_epochs = 1
|
||||
total_steps = 6
|
||||
taylor_pruner_steps = 2
|
||||
steps_per_iteration = 2
|
||||
total_pruning_steps = 4
|
||||
distillation = False
|
||||
|
||||
from nni.compression.pytorch.pruning import TaylorFOWeightPruner
|
||||
from nni.compression.pytorch.speedup import ModelSpeedup
|
||||
|
||||
distil_training = functools.partial(training, train_dataloader, log_path=log_dir / 'taylor_pruning.log',
|
||||
distillation=distillation, evaluation_func=evaluation_func)
|
||||
traced_optimizer = nni.trace(Adam)(attention_pruned_model.parameters(), lr=3e-5, eps=1e-8)
|
||||
evaluator = TorchEvaluator(distil_training, traced_optimizer, fake_criterion)
|
||||
|
||||
current_step = 0
|
||||
best_result = 0
|
||||
init_lr = 3e-5
|
||||
|
||||
dummy_input = torch.rand(8, 128, 768).to(device)
|
||||
|
||||
attention_pruned_model.train()
|
||||
for current_epoch in range(total_epochs):
|
||||
for batch in train_dataloader:
|
||||
if total_steps and current_step >= total_steps:
|
||||
break
|
||||
# pruning 12 times
|
||||
if current_step % steps_per_iteration == 0 and current_step < total_pruning_steps:
|
||||
check_point = attention_pruned_model.state_dict()
|
||||
pruner = TaylorFOWeightPruner(attention_pruned_model, ffn_config_list, evaluator, taylor_pruner_steps)
|
||||
_, ffn_masks = pruner.compress()
|
||||
renamed_ffn_masks = {}
|
||||
# rename the masks keys, because we only speedup the bert.encoder
|
||||
for model_name, targets_mask in ffn_masks.items():
|
||||
renamed_ffn_masks[model_name.split('bert.encoder.')[1]] = targets_mask
|
||||
pruner._unwrap_model()
|
||||
attention_pruned_model.load_state_dict(check_point)
|
||||
ModelSpeedup(attention_pruned_model.bert.encoder, dummy_input, renamed_ffn_masks).speedup_model()
|
||||
optimizer = Adam(attention_pruned_model.parameters(), lr=init_lr)
|
||||
|
||||
batch.to(device)
|
||||
teacher_logits = batch.pop('teacher_logits', None)
|
||||
optimizer.zero_grad()
|
||||
|
||||
# manually schedule lr
|
||||
for params_group in optimizer.param_groups:
|
||||
params_group['lr'] = (1 - current_step / (total_epochs * steps_per_epoch)) * init_lr
|
||||
|
||||
outputs = attention_pruned_model(**batch)
|
||||
loss = outputs.loss
|
||||
|
||||
# distillation
|
||||
if teacher_logits is not None:
|
||||
distil_loss = F.kl_div(F.log_softmax(outputs.logits / 2, dim=-1),
|
||||
F.softmax(teacher_logits / 2, dim=-1), reduction='batchmean') * (2 ** 2)
|
||||
loss = 0.1 * loss + 0.9 * distil_loss
|
||||
loss.backward()
|
||||
optimizer.step()
|
||||
|
||||
current_step += 1
|
||||
if current_step % 1000 == 0 or current_step % len(train_dataloader) == 0:
|
||||
result = evaluation_func(attention_pruned_model)
|
||||
with (log_dir / 'ffn_pruning.log').open('a+') as f:
|
||||
msg = '[{}] Epoch {}, Step {}: {}\n'.format(time.asctime(time.localtime(time.time())),
|
||||
current_epoch, current_step, result)
|
||||
f.write(msg)
|
||||
if current_step >= total_pruning_steps and best_result < result['default']:
|
||||
torch.save(attention_pruned_model, log_dir / 'best_model.pth')
|
||||
best_result = result['default']
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
.. rst-class:: sphx-glr-script-out
|
||||
|
||||
.. code-block:: none
|
||||
|
||||
Did not bind any model, no need to unbind model.
|
||||
no multi-dimension masks found.
|
||||
/home/nishang/anaconda3/envs/nni-dev/lib/python3.7/site-packages/torch/_tensor.py:1083: UserWarning: The .grad attribute of a Tensor that is not a leaf Tensor is being accessed. Its .grad attribute won't be populated during autograd.backward(). If you indeed want the .grad field to be populated for a non-leaf Tensor, use .retain_grad() on the non-leaf Tensor. If you access the non-leaf Tensor by mistake, make sure you access the leaf Tensor instead. See github.com/pytorch/pytorch/pull/30531 for more informations. (Triggered internally at aten/src/ATen/core/TensorBody.h:477.)
|
||||
return self._grad
|
||||
Did not bind any model, no need to unbind model.
|
||||
no multi-dimension masks found.
|
||||
|
||||
|
||||
|
||||
|
||||
.. GENERATED FROM PYTHON SOURCE LINES 509-564
|
||||
|
||||
Result
|
||||
------
|
||||
The speedup is test on the entire validation dataset with batch size 32 on A100.
|
||||
We test under two pytorch version and found the latency varying widely.
|
||||
|
||||
Setting 1: pytorch 1.12.1
|
||||
|
||||
Setting 2: pytorch 1.10.0
|
||||
|
||||
.. list-table:: Prune Bert-base-uncased on MNLI
|
||||
:header-rows: 1
|
||||
:widths: auto
|
||||
|
||||
* - Attention Pruning Method
|
||||
- FFN Pruning Method
|
||||
- Total Sparsity
|
||||
- Accuracy
|
||||
- Acc. Drop
|
||||
- Speedup (S1)
|
||||
- Speedup (S2)
|
||||
* -
|
||||
-
|
||||
- 0%
|
||||
- 84.73 / 84.63
|
||||
- +0.0 / +0.0
|
||||
- 12.56s (x1.00)
|
||||
- 4.05s (x1.00)
|
||||
* - :ref:`movement-pruner` (soft, th=0.1, lambda=5)
|
||||
- :ref:`taylor-fo-weight-pruner`
|
||||
- 51.39%
|
||||
- 84.25 / 84.96
|
||||
- -0.48 / +0.33
|
||||
- 6.85s (x1.83)
|
||||
- 2.7s (x1.50)
|
||||
* - :ref:`movement-pruner` (soft, th=0.1, lambda=10)
|
||||
- :ref:`taylor-fo-weight-pruner`
|
||||
- 66.67%
|
||||
- 83.98 / 83.75
|
||||
- -0.75 / -0.88
|
||||
- 4.73s (x2.66)
|
||||
- 2.16s (x1.86)
|
||||
* - :ref:`movement-pruner` (soft, th=0.1, lambda=20)
|
||||
- :ref:`taylor-fo-weight-pruner`
|
||||
- 77.78%
|
||||
- 83.02 / 83.06
|
||||
- -1.71 / -1.57
|
||||
- 3.35s (x3.75)
|
||||
- 1.72s (x2.35)
|
||||
* - :ref:`movement-pruner` (soft, th=0.1, lambda=30)
|
||||
- :ref:`taylor-fo-weight-pruner`
|
||||
- 87.04%
|
||||
- 81.24 / 80.99
|
||||
- -3.49 / -3.64
|
||||
- 2.19s (x5.74)
|
||||
- 1.31s (x3.09)
|
||||
|
||||
|
||||
.. rst-class:: sphx-glr-timing
|
||||
|
||||
**Total running time of the script:** ( 0 minutes 27.206 seconds)
|
||||
|
||||
|
||||
.. _sphx_glr_download_tutorials_pruning_bert_glue.py:
|
||||
|
||||
.. only:: html
|
||||
|
||||
.. container:: sphx-glr-footer sphx-glr-footer-example
|
||||
|
||||
|
||||
.. container:: sphx-glr-download sphx-glr-download-python
|
||||
|
||||
:download:`Download Python source code: pruning_bert_glue.py <pruning_bert_glue.py>`
|
||||
|
||||
.. container:: sphx-glr-download sphx-glr-download-jupyter
|
||||
|
||||
:download:`Download Jupyter notebook: pruning_bert_glue.ipynb <pruning_bert_glue.ipynb>`
|
||||
|
||||
|
||||
.. only:: html
|
||||
|
||||
.. rst-class:: sphx-glr-signature
|
||||
|
||||
`Gallery generated by Sphinx-Gallery <https://sphinx-gallery.github.io>`_
|
Двоичный файл не отображается.
|
@ -5,10 +5,10 @@
|
|||
|
||||
Computation times
|
||||
=================
|
||||
**01:45.743** total execution time for **tutorials** files:
|
||||
**00:27.206** total execution time for **tutorials** files:
|
||||
|
||||
+-----------------------------------------------------------------------------------------------------+-----------+--------+
|
||||
| :ref:`sphx_glr_tutorials_quantization_quick_start_mnist.py` (``quantization_quick_start_mnist.py``) | 01:45.743 | 0.0 MB |
|
||||
| :ref:`sphx_glr_tutorials_pruning_bert_glue.py` (``pruning_bert_glue.py``) | 00:27.206 | 0.0 MB |
|
||||
+-----------------------------------------------------------------------------------------------------+-----------+--------+
|
||||
| :ref:`sphx_glr_tutorials_hello_nas.py` (``hello_nas.py``) | 00:00.000 | 0.0 MB |
|
||||
+-----------------------------------------------------------------------------------------------------+-----------+--------+
|
||||
|
@ -22,5 +22,7 @@ Computation times
|
|||
+-----------------------------------------------------------------------------------------------------+-----------+--------+
|
||||
| :ref:`sphx_glr_tutorials_quantization_customize.py` (``quantization_customize.py``) | 00:00.000 | 0.0 MB |
|
||||
+-----------------------------------------------------------------------------------------------------+-----------+--------+
|
||||
| :ref:`sphx_glr_tutorials_quantization_quick_start_mnist.py` (``quantization_quick_start_mnist.py``) | 00:00.000 | 0.0 MB |
|
||||
+-----------------------------------------------------------------------------------------------------+-----------+--------+
|
||||
| :ref:`sphx_glr_tutorials_quantization_speedup.py` (``quantization_speedup.py``) | 00:00.000 | 0.0 MB |
|
||||
+-----------------------------------------------------------------------------------------------------+-----------+--------+
|
||||
|
|
|
@ -3,4 +3,6 @@
|
|||
data/
|
||||
MNIST/
|
||||
cifar-10-batches-py/
|
||||
experiment_data/
|
||||
experiment_data/
|
||||
pruning/models
|
||||
pruning/pruning_log
|
|
@ -1,3 +1,5 @@
|
|||
data/
|
||||
log/
|
||||
*.onnx
|
||||
*.onnx
|
||||
models/
|
||||
pruning_log/
|
|
@ -0,0 +1,563 @@
|
|||
"""
|
||||
Pruning Transformer with NNI
|
||||
============================
|
||||
|
||||
Workable Pruning Process
|
||||
------------------------
|
||||
|
||||
Here we show an effective transformer pruning process that NNI team has tried, and users can use NNI to discover better processes.
|
||||
|
||||
The entire pruning process can be divided into the following steps:
|
||||
|
||||
1. Finetune the pre-trained model on the downstream task. From our experience,
|
||||
the final performance of pruning on the finetuned model is better than pruning directly on the pre-trained model.
|
||||
At the same time, the finetuned model obtained in this step will also be used as the teacher model for the following
|
||||
distillation training.
|
||||
2. Pruning the attention layer at first. Here we apply block-sparse on attention layer weight,
|
||||
and directly prune the head (condense the weight) if the head was fully masked.
|
||||
If the head was partially masked, we will not prune it and recover its weight.
|
||||
3. Retrain the head-pruned model with distillation. Recover the model precision before pruning FFN layer.
|
||||
4. Pruning the FFN layer. Here we apply the output channels pruning on the 1st FFN layer,
|
||||
and the 2nd FFN layer input channels will be pruned due to the pruning of 1st layer output channels.
|
||||
5. Retrain the final pruned model with distillation.
|
||||
|
||||
During the process of pruning transformer, we gained some of the following experiences:
|
||||
|
||||
* We using :ref:`movement-pruner` in step 2 and :ref:`taylor-fo-weight-pruner` in step 4. :ref:`movement-pruner` has good performance on attention layers,
|
||||
and :ref:`taylor-fo-weight-pruner` method has good performance on FFN layers. These two pruners are all some kinds of gradient-based pruning algorithms,
|
||||
we also try weight-based pruning algorithms like :ref:`l1-norm-pruner`, but it doesn't seem to work well in this scenario.
|
||||
* Distillation is a good way to recover model precision. In terms of results, usually 1~2% improvement in accuracy can be achieved when we prune bert on mnli task.
|
||||
* It is necessary to gradually increase the sparsity rather than reaching a very high sparsity all at once.
|
||||
|
||||
Experiment
|
||||
----------
|
||||
|
||||
Preparation
|
||||
^^^^^^^^^^^
|
||||
Please set ``dev_mode`` to ``False`` to run this tutorial. Here ``dev_mode`` is ``True`` by default is for generating documents.
|
||||
|
||||
The complete pruning process takes about 8 hours on one A100.
|
||||
"""
|
||||
|
||||
dev_mode = True
|
||||
|
||||
# %%
|
||||
# Some basic setting.
|
||||
|
||||
from pathlib import Path
|
||||
from typing import Callable
|
||||
|
||||
pretrained_model_name_or_path = 'bert-base-uncased'
|
||||
task_name = 'mnli'
|
||||
experiment_id = 'pruning_bert'
|
||||
|
||||
# heads_num and layers_num should align with pretrained_model_name_or_path
|
||||
heads_num = 12
|
||||
layers_num = 12
|
||||
|
||||
# used to save the experiment log
|
||||
log_dir = Path(f'./pruning_log/{pretrained_model_name_or_path}/{task_name}/{experiment_id}')
|
||||
log_dir.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
# used to save the finetuned model and share between different experiemnts with same pretrained_model_name_or_path and task_name
|
||||
model_dir = Path(f'./models/{pretrained_model_name_or_path}/{task_name}')
|
||||
model_dir.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
from transformers import set_seed
|
||||
set_seed(1024)
|
||||
|
||||
import torch
|
||||
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
|
||||
|
||||
# %%
|
||||
# The function used to create dataloaders, note that 'mnli' has two evaluation dataset.
|
||||
# If teacher_model is set, will run all dataset on teacher model to get the 'teacher_logits' for distillation.
|
||||
|
||||
from torch.utils.data import DataLoader
|
||||
|
||||
from datasets import load_dataset
|
||||
from transformers import BertTokenizerFast, DataCollatorWithPadding
|
||||
|
||||
task_to_keys = {
|
||||
'cola': ('sentence', None),
|
||||
'mnli': ('premise', 'hypothesis'),
|
||||
'mrpc': ('sentence1', 'sentence2'),
|
||||
'qnli': ('question', 'sentence'),
|
||||
'qqp': ('question1', 'question2'),
|
||||
'rte': ('sentence1', 'sentence2'),
|
||||
'sst2': ('sentence', None),
|
||||
'stsb': ('sentence1', 'sentence2'),
|
||||
'wnli': ('sentence1', 'sentence2'),
|
||||
}
|
||||
|
||||
def prepare_data(cache_dir='./data', train_batch_size=32, eval_batch_size=32,
|
||||
teacher_model: torch.nn.Module = None):
|
||||
tokenizer = BertTokenizerFast.from_pretrained(pretrained_model_name_or_path)
|
||||
sentence1_key, sentence2_key = task_to_keys[task_name]
|
||||
data_collator = DataCollatorWithPadding(tokenizer)
|
||||
|
||||
# used to preprocess the raw data
|
||||
def preprocess_function(examples):
|
||||
# Tokenize the texts
|
||||
args = (
|
||||
(examples[sentence1_key],) if sentence2_key is None else (examples[sentence1_key], examples[sentence2_key])
|
||||
)
|
||||
result = tokenizer(*args, padding=False, max_length=128, truncation=True)
|
||||
|
||||
if 'label' in examples:
|
||||
# In all cases, rename the column to labels because the model will expect that.
|
||||
result['labels'] = examples['label']
|
||||
return result
|
||||
|
||||
raw_datasets = load_dataset('glue', task_name, cache_dir=cache_dir)
|
||||
for key in list(raw_datasets.keys()):
|
||||
if 'test' in key:
|
||||
raw_datasets.pop(key)
|
||||
|
||||
processed_datasets = raw_datasets.map(preprocess_function, batched=True,
|
||||
remove_columns=raw_datasets['train'].column_names)
|
||||
|
||||
# if has teacher model, add 'teacher_logits' to datasets who has 'labels'.
|
||||
# 'teacher_logits' is used for distillation and avoid the double counting.
|
||||
if teacher_model:
|
||||
teacher_model_training = teacher_model.training
|
||||
teacher_model.eval()
|
||||
model_device = next(teacher_model.parameters()).device
|
||||
|
||||
def add_teacher_logits(examples):
|
||||
result = {k: v for k, v in examples.items()}
|
||||
samples = data_collator(result).to(model_device)
|
||||
if 'labels' in samples:
|
||||
with torch.no_grad():
|
||||
logits = teacher_model(**samples).logits.tolist()
|
||||
result['teacher_logits'] = logits
|
||||
return result
|
||||
|
||||
processed_datasets = processed_datasets.map(add_teacher_logits, batched=True,
|
||||
batch_size=train_batch_size)
|
||||
teacher_model.train(teacher_model_training)
|
||||
|
||||
train_dataset = processed_datasets['train']
|
||||
validation_dataset = processed_datasets['validation_matched' if task_name == 'mnli' else 'validation']
|
||||
validation_dataset2 = processed_datasets['validation_mismatched'] if task_name == 'mnli' else None
|
||||
|
||||
train_dataloader = DataLoader(train_dataset,
|
||||
shuffle=True,
|
||||
collate_fn=data_collator,
|
||||
batch_size=train_batch_size)
|
||||
validation_dataloader = DataLoader(validation_dataset,
|
||||
collate_fn=data_collator,
|
||||
batch_size=eval_batch_size)
|
||||
validation_dataloader2 = DataLoader(validation_dataset2,
|
||||
collate_fn=data_collator,
|
||||
batch_size=eval_batch_size) if task_name == 'mnli' else None
|
||||
|
||||
return train_dataloader, validation_dataloader, validation_dataloader2
|
||||
|
||||
# %%
|
||||
# Training function & evaluation function.
|
||||
|
||||
import time
|
||||
import torch.nn.functional as F
|
||||
from datasets import load_metric
|
||||
|
||||
def training(train_dataloader: DataLoader,
|
||||
model: torch.nn.Module,
|
||||
optimizer: torch.optim.Optimizer,
|
||||
criterion: Callable[[torch.Tensor, torch.Tensor], torch.Tensor],
|
||||
lr_scheduler: torch.optim.lr_scheduler._LRScheduler = None,
|
||||
max_steps: int = None, max_epochs: int = None,
|
||||
save_best_model: bool = False, save_path: str = None,
|
||||
log_path: str = Path(log_dir) / 'training.log',
|
||||
distillation: bool = False,
|
||||
evaluation_func=None):
|
||||
model.train()
|
||||
current_step = 0
|
||||
best_result = 0
|
||||
|
||||
for current_epoch in range(max_epochs if max_epochs else 1):
|
||||
for batch in train_dataloader:
|
||||
batch.to(device)
|
||||
teacher_logits = batch.pop('teacher_logits', None)
|
||||
optimizer.zero_grad()
|
||||
outputs = model(**batch)
|
||||
loss = outputs.loss
|
||||
|
||||
if distillation:
|
||||
assert teacher_logits is not None
|
||||
distil_loss = F.kl_div(F.log_softmax(outputs.logits / 2, dim=-1),
|
||||
F.softmax(teacher_logits / 2, dim=-1), reduction='batchmean') * (2 ** 2)
|
||||
loss = 0.1 * loss + 0.9 * distil_loss
|
||||
|
||||
loss = criterion(loss, None)
|
||||
loss.backward()
|
||||
optimizer.step()
|
||||
|
||||
if lr_scheduler:
|
||||
lr_scheduler.step()
|
||||
|
||||
current_step += 1
|
||||
|
||||
# evaluation for every 1000 steps
|
||||
if current_step % 1000 == 0 or current_step % len(train_dataloader) == 0:
|
||||
result = evaluation_func(model) if evaluation_func else None
|
||||
with (log_path).open('a+') as f:
|
||||
msg = '[{}] Epoch {}, Step {}: {}\n'.format(time.asctime(time.localtime(time.time())), current_epoch, current_step, result)
|
||||
f.write(msg)
|
||||
# if it's the best model, save it.
|
||||
if save_best_model and best_result < result['default']:
|
||||
assert save_path is not None
|
||||
torch.save(model.state_dict(), save_path)
|
||||
best_result = result['default']
|
||||
|
||||
if max_steps and current_step >= max_steps:
|
||||
return
|
||||
|
||||
def evaluation(validation_dataloader: DataLoader,
|
||||
validation_dataloader2: DataLoader,
|
||||
model: torch.nn.Module):
|
||||
training = model.training
|
||||
model.eval()
|
||||
is_regression = task_name == 'stsb'
|
||||
metric = load_metric('glue', task_name)
|
||||
|
||||
for batch in validation_dataloader:
|
||||
batch.pop('teacher_logits', None)
|
||||
batch.to(device)
|
||||
outputs = model(**batch)
|
||||
predictions = outputs.logits.argmax(dim=-1) if not is_regression else outputs.logits.squeeze()
|
||||
metric.add_batch(
|
||||
predictions=predictions,
|
||||
references=batch['labels'],
|
||||
)
|
||||
result = metric.compute()
|
||||
|
||||
if validation_dataloader2:
|
||||
for batch in validation_dataloader2:
|
||||
batch.pop('teacher_logits', None)
|
||||
batch.to(device)
|
||||
outputs = model(**batch)
|
||||
predictions = outputs.logits.argmax(dim=-1) if not is_regression else outputs.logits.squeeze()
|
||||
metric.add_batch(
|
||||
predictions=predictions,
|
||||
references=batch['labels'],
|
||||
)
|
||||
result = {'matched': result, 'mismatched': metric.compute()}
|
||||
result['default'] = (result['matched']['accuracy'] + result['mismatched']['accuracy']) / 2
|
||||
else:
|
||||
result['default'] = result.get('f1', result.get('accuracy', None))
|
||||
|
||||
model.train(training)
|
||||
return result
|
||||
|
||||
# using huggingface native loss
|
||||
def fake_criterion(outputs, targets):
|
||||
return outputs
|
||||
|
||||
|
||||
# %%
|
||||
# Prepare pre-trained model and finetuning on downstream task.
|
||||
|
||||
import functools
|
||||
|
||||
from torch.optim import Adam
|
||||
from torch.optim.lr_scheduler import LambdaLR
|
||||
from transformers import BertForSequenceClassification
|
||||
|
||||
def create_pretrained_model():
|
||||
is_regression = task_name == 'stsb'
|
||||
num_labels = 1 if is_regression else (3 if task_name == 'mnli' else 2)
|
||||
return BertForSequenceClassification.from_pretrained(pretrained_model_name_or_path, num_labels=num_labels)
|
||||
|
||||
def create_finetuned_model():
|
||||
pretrained_model = create_pretrained_model().to(device)
|
||||
|
||||
train_dataloader, validation_dataloader, validation_dataloader2 = prepare_data()
|
||||
evaluation_func = functools.partial(evaluation, validation_dataloader, validation_dataloader2)
|
||||
steps_per_epoch = len(train_dataloader)
|
||||
training_epochs = 3
|
||||
|
||||
finetuned_model_state_path = Path(model_dir) / 'finetuned_model_state.pth'
|
||||
|
||||
if finetuned_model_state_path.exists():
|
||||
pretrained_model.load_state_dict(torch.load(finetuned_model_state_path))
|
||||
elif dev_mode:
|
||||
pass
|
||||
else:
|
||||
optimizer = Adam(pretrained_model.parameters(), lr=3e-5, eps=1e-8)
|
||||
|
||||
def lr_lambda(current_step: int):
|
||||
return max(0.0, float(training_epochs * steps_per_epoch - current_step) / float(training_epochs * steps_per_epoch))
|
||||
|
||||
lr_scheduler = LambdaLR(optimizer, lr_lambda)
|
||||
training(train_dataloader, pretrained_model, optimizer, fake_criterion, lr_scheduler=lr_scheduler, max_epochs=training_epochs,
|
||||
save_best_model=True, save_path=finetuned_model_state_path, evaluation_func=evaluation_func)
|
||||
return pretrained_model
|
||||
|
||||
finetuned_model = create_finetuned_model()
|
||||
|
||||
# %%
|
||||
# Using finetuned model as teacher model to create dataloader.
|
||||
# Add 'teacher_logits' to dataset, it is used to do the distillation, it can be seen as a kind of data label.
|
||||
|
||||
if not dev_mode:
|
||||
train_dataloader, validation_dataloader, validation_dataloader2 = prepare_data(teacher_model=finetuned_model)
|
||||
else:
|
||||
train_dataloader, validation_dataloader, validation_dataloader2 = prepare_data()
|
||||
|
||||
evaluation_func = functools.partial(evaluation, validation_dataloader, validation_dataloader2)
|
||||
|
||||
# %%
|
||||
# Pruning
|
||||
# ^^^^^^^
|
||||
# First, using MovementPruner to prune attention head.
|
||||
|
||||
steps_per_epoch = len(train_dataloader)
|
||||
|
||||
# Set training steps/epochs for pruning.
|
||||
|
||||
if not dev_mode:
|
||||
total_epochs = 4
|
||||
total_steps = total_epochs * steps_per_epoch
|
||||
warmup_steps = 1 * steps_per_epoch
|
||||
cooldown_steps = 1 * steps_per_epoch
|
||||
else:
|
||||
total_epochs = 1
|
||||
total_steps = 3
|
||||
warmup_steps = 1
|
||||
cooldown_steps = 1
|
||||
|
||||
# Initialize evaluator used by MovementPruner.
|
||||
|
||||
import nni
|
||||
from nni.algorithms.compression.v2.pytorch import TorchEvaluator
|
||||
|
||||
movement_training = functools.partial(training, train_dataloader, log_path=log_dir / 'movement_pruning.log',
|
||||
evaluation_func=evaluation_func)
|
||||
traced_optimizer = nni.trace(Adam)(finetuned_model.parameters(), lr=3e-5, eps=1e-8)
|
||||
|
||||
def lr_lambda(current_step: int):
|
||||
if current_step < warmup_steps:
|
||||
return float(current_step) / warmup_steps
|
||||
return max(0.0, float(total_steps - current_step) / float(total_steps - warmup_steps))
|
||||
|
||||
traced_scheduler = nni.trace(LambdaLR)(traced_optimizer, lr_lambda)
|
||||
evaluator = TorchEvaluator(movement_training, traced_optimizer, fake_criterion, traced_scheduler)
|
||||
|
||||
# Apply block-soft-movement pruning on attention layers.
|
||||
|
||||
from nni.compression.pytorch.pruning import MovementPruner
|
||||
|
||||
config_list = [{'op_types': ['Linear'], 'op_partial_names': ['bert.encoder.layer.{}.'.format(i) for i in range(layers_num)], 'sparsity': 0.1}]
|
||||
pruner = MovementPruner(model=finetuned_model,
|
||||
config_list=config_list,
|
||||
evaluator=evaluator,
|
||||
training_epochs=total_epochs,
|
||||
training_steps=total_steps,
|
||||
warm_up_step=warmup_steps,
|
||||
cool_down_beginning_step=total_steps - cooldown_steps,
|
||||
regular_scale=10,
|
||||
movement_mode='soft',
|
||||
sparse_granularity='auto')
|
||||
_, attention_masks = pruner.compress()
|
||||
pruner.show_pruned_weights()
|
||||
|
||||
torch.save(attention_masks, Path(log_dir) / 'attention_masks.pth')
|
||||
|
||||
# %%
|
||||
# Load a new finetuned model to do the speedup.
|
||||
# Note that nni speedup don't support replace attention module, so here we manully replace the attention module.
|
||||
#
|
||||
# If the head is entire masked, physically prune it and create config_list for FFN pruning.
|
||||
|
||||
attention_pruned_model = create_finetuned_model().to(device)
|
||||
attention_masks = torch.load(Path(log_dir) / 'attention_masks.pth')
|
||||
|
||||
ffn_config_list = []
|
||||
layer_count = 0
|
||||
module_list = []
|
||||
for i in range(0, layers_num):
|
||||
prefix = f'bert.encoder.layer.{i}.'
|
||||
value_mask: torch.Tensor = attention_masks[prefix + 'attention.self.value']['weight']
|
||||
head_mask = (value_mask.reshape(heads_num, -1).sum(-1) == 0.)
|
||||
head_idx = torch.arange(len(head_mask))[head_mask].long().tolist()
|
||||
print(f'layer {i} pruner {len(head_idx)} head: {head_idx}')
|
||||
if len(head_idx) != heads_num:
|
||||
attention_pruned_model.bert.encoder.layer[i].attention.prune_heads(head_idx)
|
||||
module_list.append(attention_pruned_model.bert.encoder.layer[i])
|
||||
# The final ffn weight remaining ratio is the half of the attention weight remaining ratio.
|
||||
# This is just an empirical configuration, you can use any other method to determine this sparsity.
|
||||
sparsity = 1 - (1 - len(head_idx) / heads_num) * 0.5
|
||||
# here we use a simple sparsity schedule, we will prune ffn in 12 iterations, each iteration prune `sparsity_per_iter`.
|
||||
sparsity_per_iter = 1 - (1 - sparsity) ** (1 / heads_num)
|
||||
ffn_config_list.append({'op_names': [f'bert.encoder.layer.{layer_count}.intermediate.dense'], 'sparsity': sparsity_per_iter})
|
||||
layer_count += 1
|
||||
|
||||
attention_pruned_model.bert.encoder.layer = torch.nn.ModuleList(module_list)
|
||||
|
||||
# %%
|
||||
# Retrain the attention pruned model with distillation.
|
||||
|
||||
if not dev_mode:
|
||||
total_epochs = 5
|
||||
total_steps = None
|
||||
distillation = True
|
||||
else:
|
||||
total_epochs = 1
|
||||
total_steps = 1
|
||||
distillation = False
|
||||
|
||||
optimizer = Adam(attention_pruned_model.parameters(), lr=3e-5, eps=1e-8)
|
||||
|
||||
def lr_lambda(current_step: int):
|
||||
return max(0.0, float(total_epochs * steps_per_epoch - current_step) / float(total_epochs * steps_per_epoch))
|
||||
|
||||
lr_scheduler = LambdaLR(optimizer, lr_lambda)
|
||||
at_model_save_path = log_dir / 'attention_pruned_model_state.pth'
|
||||
training(train_dataloader, attention_pruned_model, optimizer, fake_criterion, lr_scheduler=lr_scheduler,
|
||||
max_epochs=total_epochs, max_steps=total_steps, save_best_model=True, save_path=at_model_save_path,
|
||||
distillation=distillation, evaluation_func=evaluation_func)
|
||||
|
||||
if not dev_mode:
|
||||
attention_pruned_model.load_state_dict(torch.load(at_model_save_path))
|
||||
|
||||
# %%
|
||||
# Iterative pruning FFN with TaylorFOWeightPruner in 12 iterations.
|
||||
# Finetuning 2000 steps after each iteration, then finetuning 2 epochs after pruning finished.
|
||||
#
|
||||
# NNI will support per-step-pruning-schedule in the future, then can use an pruner to replace the following code.
|
||||
|
||||
if not dev_mode:
|
||||
total_epochs = 4
|
||||
total_steps = None
|
||||
taylor_pruner_steps = 1000
|
||||
steps_per_iteration = 2000
|
||||
total_pruning_steps = 24000
|
||||
distillation = True
|
||||
else:
|
||||
total_epochs = 1
|
||||
total_steps = 6
|
||||
taylor_pruner_steps = 2
|
||||
steps_per_iteration = 2
|
||||
total_pruning_steps = 4
|
||||
distillation = False
|
||||
|
||||
from nni.compression.pytorch.pruning import TaylorFOWeightPruner
|
||||
from nni.compression.pytorch.speedup import ModelSpeedup
|
||||
|
||||
distil_training = functools.partial(training, train_dataloader, log_path=log_dir / 'taylor_pruning.log',
|
||||
distillation=distillation, evaluation_func=evaluation_func)
|
||||
traced_optimizer = nni.trace(Adam)(attention_pruned_model.parameters(), lr=3e-5, eps=1e-8)
|
||||
evaluator = TorchEvaluator(distil_training, traced_optimizer, fake_criterion)
|
||||
|
||||
current_step = 0
|
||||
best_result = 0
|
||||
init_lr = 3e-5
|
||||
|
||||
dummy_input = torch.rand(8, 128, 768).to(device)
|
||||
|
||||
attention_pruned_model.train()
|
||||
for current_epoch in range(total_epochs):
|
||||
for batch in train_dataloader:
|
||||
if total_steps and current_step >= total_steps:
|
||||
break
|
||||
# pruning 12 times
|
||||
if current_step % steps_per_iteration == 0 and current_step < total_pruning_steps:
|
||||
check_point = attention_pruned_model.state_dict()
|
||||
pruner = TaylorFOWeightPruner(attention_pruned_model, ffn_config_list, evaluator, taylor_pruner_steps)
|
||||
_, ffn_masks = pruner.compress()
|
||||
renamed_ffn_masks = {}
|
||||
# rename the masks keys, because we only speedup the bert.encoder
|
||||
for model_name, targets_mask in ffn_masks.items():
|
||||
renamed_ffn_masks[model_name.split('bert.encoder.')[1]] = targets_mask
|
||||
pruner._unwrap_model()
|
||||
attention_pruned_model.load_state_dict(check_point)
|
||||
ModelSpeedup(attention_pruned_model.bert.encoder, dummy_input, renamed_ffn_masks).speedup_model()
|
||||
optimizer = Adam(attention_pruned_model.parameters(), lr=init_lr)
|
||||
|
||||
batch.to(device)
|
||||
teacher_logits = batch.pop('teacher_logits', None)
|
||||
optimizer.zero_grad()
|
||||
|
||||
# manually schedule lr
|
||||
for params_group in optimizer.param_groups:
|
||||
params_group['lr'] = (1 - current_step / (total_epochs * steps_per_epoch)) * init_lr
|
||||
|
||||
outputs = attention_pruned_model(**batch)
|
||||
loss = outputs.loss
|
||||
|
||||
# distillation
|
||||
if teacher_logits is not None:
|
||||
distil_loss = F.kl_div(F.log_softmax(outputs.logits / 2, dim=-1),
|
||||
F.softmax(teacher_logits / 2, dim=-1), reduction='batchmean') * (2 ** 2)
|
||||
loss = 0.1 * loss + 0.9 * distil_loss
|
||||
loss.backward()
|
||||
optimizer.step()
|
||||
|
||||
current_step += 1
|
||||
if current_step % 1000 == 0 or current_step % len(train_dataloader) == 0:
|
||||
result = evaluation_func(attention_pruned_model)
|
||||
with (log_dir / 'ffn_pruning.log').open('a+') as f:
|
||||
msg = '[{}] Epoch {}, Step {}: {}\n'.format(time.asctime(time.localtime(time.time())),
|
||||
current_epoch, current_step, result)
|
||||
f.write(msg)
|
||||
if current_step >= total_pruning_steps and best_result < result['default']:
|
||||
torch.save(attention_pruned_model, log_dir / 'best_model.pth')
|
||||
best_result = result['default']
|
||||
|
||||
# %%
|
||||
# Result
|
||||
# ------
|
||||
# The speedup is test on the entire validation dataset with batch size 32 on A100.
|
||||
# We test under two pytorch version and found the latency varying widely.
|
||||
#
|
||||
# Setting 1: pytorch 1.12.1
|
||||
#
|
||||
# Setting 2: pytorch 1.10.0
|
||||
#
|
||||
# .. list-table:: Prune Bert-base-uncased on MNLI
|
||||
# :header-rows: 1
|
||||
# :widths: auto
|
||||
#
|
||||
# * - Attention Pruning Method
|
||||
# - FFN Pruning Method
|
||||
# - Total Sparsity
|
||||
# - Accuracy
|
||||
# - Acc. Drop
|
||||
# - Speedup (S1)
|
||||
# - Speedup (S2)
|
||||
# * -
|
||||
# -
|
||||
# - 0%
|
||||
# - 84.73 / 84.63
|
||||
# - +0.0 / +0.0
|
||||
# - 12.56s (x1.00)
|
||||
# - 4.05s (x1.00)
|
||||
# * - :ref:`movement-pruner` (soft, th=0.1, lambda=5)
|
||||
# - :ref:`taylor-fo-weight-pruner`
|
||||
# - 51.39%
|
||||
# - 84.25 / 84.96
|
||||
# - -0.48 / +0.33
|
||||
# - 6.85s (x1.83)
|
||||
# - 2.7s (x1.50)
|
||||
# * - :ref:`movement-pruner` (soft, th=0.1, lambda=10)
|
||||
# - :ref:`taylor-fo-weight-pruner`
|
||||
# - 66.67%
|
||||
# - 83.98 / 83.75
|
||||
# - -0.75 / -0.88
|
||||
# - 4.73s (x2.66)
|
||||
# - 2.16s (x1.86)
|
||||
# * - :ref:`movement-pruner` (soft, th=0.1, lambda=20)
|
||||
# - :ref:`taylor-fo-weight-pruner`
|
||||
# - 77.78%
|
||||
# - 83.02 / 83.06
|
||||
# - -1.71 / -1.57
|
||||
# - 3.35s (x3.75)
|
||||
# - 1.72s (x2.35)
|
||||
# * - :ref:`movement-pruner` (soft, th=0.1, lambda=30)
|
||||
# - :ref:`taylor-fo-weight-pruner`
|
||||
# - 87.04%
|
||||
# - 81.24 / 80.99
|
||||
# - -3.49 / -3.64
|
||||
# - 2.19s (x5.74)
|
||||
# - 1.31s (x3.09)
|
|
@ -189,7 +189,7 @@ class EvaluatorBasedPruner(BasicPruner):
|
|||
raise TypeError(f"{self.__class__.__name__}.__init__() got multiple values for argument '{key}'")
|
||||
merged_kwargs[key] = value
|
||||
for key, value in def_kwargs.items():
|
||||
if key not in merged_kwargs:
|
||||
if key not in merged_kwargs and key in arg_names:
|
||||
merged_kwargs[key] = value
|
||||
diff = set(arg_names).difference(merged_kwargs.keys())
|
||||
if diff:
|
||||
|
@ -734,6 +734,8 @@ class ActivationPruner(EvaluatorBasedPruner):
|
|||
def _choose_activation(self, activation: str = 'relu') -> Callable:
|
||||
if activation == 'relu':
|
||||
return F.relu
|
||||
elif activation == 'gelu':
|
||||
return F.gelu
|
||||
elif activation == 'relu6':
|
||||
return F.relu6
|
||||
else:
|
||||
|
|
|
@ -60,7 +60,7 @@ class EvaluatorBasedPruningScheduler(BasePruningScheduler):
|
|||
raise TypeError(f"{self.__class__.__name__}.__init__() got multiple values for argument '{key}'")
|
||||
merged_kwargs[key] = value
|
||||
for key, value in def_kwargs.items():
|
||||
if key not in merged_kwargs:
|
||||
if key not in merged_kwargs and key in arg_names:
|
||||
merged_kwargs[key] = value
|
||||
diff = set(arg_names).difference(merged_kwargs.keys())
|
||||
if diff:
|
||||
|
|
|
@ -6,6 +6,7 @@ from __future__ import annotations
|
|||
from copy import deepcopy
|
||||
import logging
|
||||
from typing import Dict, List, Tuple, Callable, overload
|
||||
from typing_extensions import Literal
|
||||
|
||||
import torch
|
||||
from torch import autograd, Tensor
|
||||
|
@ -21,15 +22,18 @@ from .tools.base import EvaluatorBasedDataCollector, TrainerBasedDataCollector
|
|||
|
||||
from .tools import (
|
||||
NormalSparsityAllocator,
|
||||
ThresholdSparsityAllocator,
|
||||
StraightMetricsCalculator
|
||||
)
|
||||
|
||||
from ..utils import (
|
||||
LightningEvaluator,
|
||||
TorchEvaluator
|
||||
TorchEvaluator,
|
||||
Scaling
|
||||
)
|
||||
|
||||
from ..utils.docstring import _EVALUATOR_DOCSTRING
|
||||
from ..utils.external.huggingface import parser_factory
|
||||
|
||||
_logger = logging.getLogger(__name__)
|
||||
|
||||
|
@ -48,14 +52,18 @@ class PrunerScoredModuleWrapper(PrunerModuleWrapper):
|
|||
module_name
|
||||
The name of the module to compress, wrapper module shares same name.
|
||||
"""
|
||||
def __init__(self, module: Module, module_name: str, config: Dict):
|
||||
def __init__(self, module: Module, module_name: str, config: Dict, score_size: List[int] | None = None):
|
||||
super().__init__(module, module_name, config)
|
||||
self.weight_score = Parameter(torch.empty(self.weight.size())) # type: ignore
|
||||
self.weight_score = Parameter(torch.empty(score_size)) \
|
||||
if score_size is not None else Parameter(torch.empty_like(module.weight)) # type: ignore
|
||||
torch.nn.init.constant_(self.weight_score, val=0.0)
|
||||
|
||||
def forward(self, *inputs):
|
||||
# apply mask to weight, bias
|
||||
self.module.weight = torch.mul(self.weight, _StraightThrough.apply(self.weight_score, self.weight_mask)) # type: ignore
|
||||
repeat = [a // b for a, b in zip(self.weight.shape, self.weight_score.shape)] # type: ignore
|
||||
weight_score = self.weight_score
|
||||
for dim, num in enumerate(repeat):
|
||||
weight_score = weight_score.repeat_interleave(num, dim=dim)
|
||||
self.module.weight = torch.mul(self.weight, _StraightThrough.apply(weight_score, self.weight_mask)) # type: ignore
|
||||
if hasattr(self.module, 'bias') and self.module.bias is not None:
|
||||
self.module.bias = torch.mul(self.bias, self.bias_mask) # type: ignore
|
||||
return self.module(*inputs)
|
||||
|
@ -124,9 +132,9 @@ class MovementPruner(EvaluatorBasedPruner):
|
|||
|
||||
Parameters
|
||||
----------
|
||||
model : torch.nn.Module
|
||||
model
|
||||
Model to be pruned.
|
||||
config_list : List[Dict]
|
||||
config_list
|
||||
Supported keys:
|
||||
- sparsity : This is to specify the sparsity for each layer in this config to be compressed.
|
||||
- sparsity_per_layer : Equals to sparsity.
|
||||
|
@ -140,16 +148,39 @@ class MovementPruner(EvaluatorBasedPruner):
|
|||
{evaluator_docstring}
|
||||
The old API (``trainer``, ``traced_optimizer`` and ``criterion``) is still supported and will be deprecated in v3.0.
|
||||
If you want to consult the old API, please refer to `v2.8 pruner API <https://nni.readthedocs.io/en/v2.8/reference/compression/pruner.html>`__.
|
||||
training_epochs : int
|
||||
The total epoch number for training the model.
|
||||
Make sure the total `optimizer.step()` in `training_epochs` is bigger than `cool_down_beginning_step`.
|
||||
warm_up_step : int
|
||||
warm_up_step
|
||||
The total `optimizer.step()` number before start pruning for warm up.
|
||||
Make sure `warm_up_step` is smaller than `cool_down_beginning_step`.
|
||||
cool_down_beginning_step: int
|
||||
Make sure ``warm_up_step`` is smaller than ``cool_down_beginning_step``.
|
||||
cool_down_beginning_step
|
||||
The number of steps at which sparsity stops growing, note that the sparsity stop growing doesn't mean masks not changed.
|
||||
The sparsity after each `optimizer.step()` is:
|
||||
total_sparsity * (1 - (1 - (current_step - warm_up_step) / (cool_down_beginning_step - warm_up_step)) ** 3).
|
||||
training_epochs
|
||||
The total epoch number for training the model.
|
||||
Make sure the total `optimizer.step()` in ``training_epochs`` is bigger than `cool_down_beginning_step`.
|
||||
If both ``training_epochs`` and ``training_steps`` are set, pruning will stop when either is reached.
|
||||
training_steps
|
||||
The total step number for training the model.
|
||||
Make sure ``training_epochs`` is bigger than ``cool_down_beginning_step``.
|
||||
If both ``training_epochs`` and ``training_steps`` are set, pruning will stop when either is reached.
|
||||
regular_scale
|
||||
Use to scale the movement score regular loss. In 'soft' mode, higher regular scale means higher final sparsity.
|
||||
The recommended range is 1 ~ 30.
|
||||
movement_mode
|
||||
'hard' or 'soft'. Note that in 'soft' mode, ``sparsity`` set in the ``config_list`` means the sparsify threshold,
|
||||
'soft' mode cannot precisely control the sparsity rate, but usually has higher performance compared with 'hard' mode.
|
||||
``sparsity`` in 'soft' mode usually set to ``0.1``, and using ``regular_scale`` to control the final relative sparsity.
|
||||
|
||||
For detailed differences between 'hard' and 'soft', please refer to the paper.
|
||||
In short, 'hard' means that the corresponding layer is pruned to a fixed ratio by the topk method according to the movement score,
|
||||
which is the sparsity ratio set in config_list.
|
||||
'soft' means that the final sparsity size will not be fixed, but the generation of the mask will be controlled by a threshold,
|
||||
and the positions corresponding to scores below the threshold will be masked during the movement training process.
|
||||
sparse_granularity
|
||||
This is an experimental interface, by default, apply 'finegrained' pruning. If 'auto' is set, will try to apply structure pruning.
|
||||
For the attention layer, will apply block sparse with size [head_width, head_width]. For the following two linear layers (FFN),
|
||||
will apply output channel pruning for the first linear, and the input channel pruning for the second one.
|
||||
'auto' only support partial hugingface transformers right now (bart, bert, t5).
|
||||
|
||||
Notes
|
||||
-----
|
||||
|
@ -157,8 +188,10 @@ class MovementPruner(EvaluatorBasedPruner):
|
|||
""".format(evaluator_docstring=_EVALUATOR_DOCSTRING)
|
||||
|
||||
@overload
|
||||
def __init__(self, model: Module, config_list: List[Dict], evaluator: LightningEvaluator | TorchEvaluator, training_epochs: int,
|
||||
warm_up_step: int, cool_down_beginning_step: int):
|
||||
def __init__(self, model: Module, config_list: List[Dict], evaluator: LightningEvaluator | TorchEvaluator, warm_up_step: int,
|
||||
cool_down_beginning_step: int, training_epochs: int | None = None, training_steps: int | None = None,
|
||||
regular_scale: float | None = None, movement_mode: Literal['hard', 'soft'] = 'hard',
|
||||
sparse_granularity: Literal['auto', 'finegrained'] = 'finegrained'):
|
||||
...
|
||||
|
||||
@overload
|
||||
|
@ -169,14 +202,23 @@ class MovementPruner(EvaluatorBasedPruner):
|
|||
|
||||
def __init__(self, model: Module, config_list: List[Dict], *args, **kwargs):
|
||||
# TODO: remove in nni v3.0. Fake overload.
|
||||
new_api = ['evaluator', 'training_epochs', 'warm_up_step', 'cool_down_beginning_step']
|
||||
new_api = ['evaluator', 'warm_up_step', 'cool_down_beginning_step', 'training_epochs', 'training_steps', 'regular_scale',
|
||||
'movement_mode', 'sparse_granularity']
|
||||
old_api = ['trainer', 'traced_optimizer', 'criterion', 'training_epochs', 'warm_up_step', 'cool_down_beginning_step']
|
||||
init_kwargs = self._init_evaluator(model, new_api, old_api, {}, args, kwargs)
|
||||
init_kwargs = {'training_epochs': None, 'training_steps': None, 'regular_scale': None, 'movement_mode': 'hard',
|
||||
'sparse_granularity': 'finegrained'}
|
||||
init_kwargs = self._init_evaluator(model, new_api, old_api, init_kwargs, args, kwargs)
|
||||
|
||||
self.training_epochs: int = init_kwargs['training_epochs']
|
||||
self.training_steps: int | None = init_kwargs['training_steps'] if self.using_evaluator else None
|
||||
self.warm_up_step: int = init_kwargs['warm_up_step']
|
||||
self.cool_down_beginning_step: int = init_kwargs['cool_down_beginning_step']
|
||||
self.regular_scale: int | None = init_kwargs['regular_scale'] if self.using_evaluator else None
|
||||
self.movement_mode: Literal['hard', 'soft'] | None = init_kwargs['movement_mode'] if self.using_evaluator else None
|
||||
self.sparse_granularity = init_kwargs['sparse_granularity'] if self.using_evaluator else None
|
||||
assert self.warm_up_step < self.cool_down_beginning_step, '`warm_up_step` should smaller than `cool_down_beginning_step`'
|
||||
|
||||
self._model_parser = parser_factory(model)
|
||||
super().__init__(model, config_list)
|
||||
|
||||
def _validate_config_before_canonical(self, model: Module, config_list: List[Dict]):
|
||||
|
@ -185,20 +227,61 @@ class MovementPruner(EvaluatorBasedPruner):
|
|||
schema.validate(config_list)
|
||||
|
||||
def cubic_schedule(self, current_step: int):
|
||||
if self.warm_up_step < current_step <= self.cool_down_beginning_step:
|
||||
wrapper_dict = self.get_modules_wrapper()
|
||||
for config in self.config_list:
|
||||
scale = 1 - (1 - (current_step - self.warm_up_step) / (self.cool_down_beginning_step - self.warm_up_step)) ** 3
|
||||
current_sparsity = config['total_sparsity'] * scale
|
||||
for op_name in config['op_names']:
|
||||
wrapper = wrapper_dict[op_name]
|
||||
wrapper.config['total_sparsity'] = current_sparsity
|
||||
wrapper_dict = self.get_modules_wrapper()
|
||||
for config in self.config_list:
|
||||
current_sparsity = config['total_sparsity'] * self._cubic_scale(current_step)
|
||||
for op_name in config['op_names']:
|
||||
# There is an unreachable pyright error if `wrapper_dict[op_name].config['total_sparsity'] = current_sparsity`,
|
||||
# seems a pyright bug...
|
||||
wrapper_config = wrapper_dict[op_name].config
|
||||
wrapper_config['total_sparsity'] = current_sparsity
|
||||
|
||||
def _cubic_scale(self, current_step: int):
|
||||
if self.warm_up_step > current_step:
|
||||
return 0
|
||||
elif current_step > self.cool_down_beginning_step:
|
||||
return 1
|
||||
else:
|
||||
return 1 - (1 - (current_step - self.warm_up_step) / (self.cool_down_beginning_step - self.warm_up_step)) ** 3
|
||||
|
||||
def _create_scalers(self) -> Scaling | Dict[str, Dict[str, Scaling]]:
|
||||
assert self.bound_model is not None
|
||||
if self.sparse_granularity and self.sparse_granularity == 'auto' and self._model_parser:
|
||||
scalers = {}
|
||||
for module_name, wrapper in self.get_modules_wrapper().items():
|
||||
if self._model_parser.is_attention(module_name):
|
||||
num_heads = self._model_parser.get_num_heads(module_name, self.bound_model)
|
||||
if num_heads <= 0:
|
||||
scalers[module_name] = {'_default': Scaling([1])}
|
||||
else:
|
||||
# assume attention layer weights are 2D
|
||||
weight_h: int = wrapper.module.weight.shape[0] # type: ignore
|
||||
weight_w: int = wrapper.module.weight.shape[1] # type: ignore
|
||||
if weight_h % num_heads != 0 or weight_w % num_heads != 0:
|
||||
scalers[module_name] = {'_default': Scaling([1])}
|
||||
else:
|
||||
block_h = weight_h // num_heads
|
||||
block_w = weight_w // num_heads
|
||||
scalers[module_name] = {'_default': Scaling([block_h, block_w])}
|
||||
elif self._model_parser.is_ffn(module_name, ffn_num=1):
|
||||
scalers[module_name] = {'_default': Scaling([1, wrapper.module.weight.shape[1]])} # type: ignore
|
||||
elif self._model_parser.is_ffn(module_name, ffn_num=2):
|
||||
scalers[module_name] = {'_default': Scaling([wrapper.module.weight.shape[0], 1])} # type: ignore
|
||||
else:
|
||||
scalers[module_name] = {'_default': Scaling([1])}
|
||||
else:
|
||||
scalers = Scaling([1])
|
||||
return scalers
|
||||
|
||||
def reset_tools(self):
|
||||
scalers = self._create_scalers()
|
||||
if not hasattr(self, 'metrics_calculator'):
|
||||
self.metrics_calculator = StraightMetricsCalculator()
|
||||
if not hasattr(self, 'sparsity_allocator'):
|
||||
self.sparsity_allocator = NormalSparsityAllocator(self, continuous_mask=False)
|
||||
if self.movement_mode == 'soft':
|
||||
self.sparsity_allocator = ThresholdSparsityAllocator(self, scalers=scalers, continuous_mask=False)
|
||||
else:
|
||||
self.sparsity_allocator = NormalSparsityAllocator(self, scalers=scalers, continuous_mask=False)
|
||||
|
||||
# use Adam to update the weight_score
|
||||
assert self.bound_model is not None
|
||||
|
@ -206,6 +289,14 @@ class MovementPruner(EvaluatorBasedPruner):
|
|||
optimizer = Adam(params, 1e-2)
|
||||
self.step_counter = 0
|
||||
|
||||
# TODO: waiting for api stable and experiemnts to prove this scheduler is needed.
|
||||
# def lr_lambda(current_step: int):
|
||||
# if current_step < self.warm_up_step:
|
||||
# return float(current_step) / self.warm_up_step
|
||||
# return max(0.0, float(147264 - current_step) / float(147264 - self.warm_up_step))
|
||||
|
||||
# lr_scheduler = LambdaLR(optimizer, lr_lambda)
|
||||
|
||||
# update the masks after each optimzier step
|
||||
def _optimizer_patch():
|
||||
optimizer.step()
|
||||
|
@ -221,6 +312,17 @@ class MovementPruner(EvaluatorBasedPruner):
|
|||
masks = self.sparsity_allocator.generate_sparsity(metrics) # type: ignore
|
||||
self.load_masks(masks)
|
||||
|
||||
def _loss_patch(origin_loss: Tensor):
|
||||
if self.regular_scale is not None:
|
||||
l1_reg = 0
|
||||
count = 0
|
||||
for wrapper in self.get_modules_wrapper().values():
|
||||
l1_reg += torch.norm(torch.sigmoid(wrapper.weight_score), p=1) / wrapper.weight_score.numel() # type: ignore
|
||||
count += 1
|
||||
return origin_loss + self.regular_scale * self._cubic_scale(self.step_counter) * l1_reg / count
|
||||
else:
|
||||
return origin_loss
|
||||
|
||||
if self.using_evaluator:
|
||||
# TODO: move to other place in nni v3.0
|
||||
self.evaluator.unbind_model()
|
||||
|
@ -228,7 +330,9 @@ class MovementPruner(EvaluatorBasedPruner):
|
|||
if not hasattr(self, 'data_collector'):
|
||||
self.data_collector = EvaluatorBasedScoreDataCollector(self, self.evaluator,
|
||||
after_opt_step_tasks=[_optimizer_patch],
|
||||
max_epochs=self.training_epochs)
|
||||
max_epochs=self.training_epochs,
|
||||
max_steps=self.training_steps,
|
||||
loss_patch=_loss_patch)
|
||||
else:
|
||||
self.data_collector.reset(after_opt_step_tasks=[_optimizer_patch])
|
||||
else:
|
||||
|
@ -252,7 +356,27 @@ class MovementPruner(EvaluatorBasedPruner):
|
|||
The configuration for generating the mask.
|
||||
"""
|
||||
_logger.debug("Module detected to compress : %s.", layer.name)
|
||||
wrapper = PrunerScoredModuleWrapper(layer.module, layer.name, config)
|
||||
assert self.bound_model is not None
|
||||
# TODO: merge with _create_scalers after nni v3.0
|
||||
if self.sparse_granularity and self.sparse_granularity == 'auto' and self._model_parser:
|
||||
if self._model_parser.is_attention(layer.name):
|
||||
num_heads = self._model_parser.get_num_heads(layer.name, self.bound_model)
|
||||
if num_heads <= 0:
|
||||
score_size = None
|
||||
else:
|
||||
if layer.module.weight.shape[0] % num_heads != 0 or layer.module.weight.shape[1] % num_heads != 0: # type: ignore
|
||||
score_size = None
|
||||
else:
|
||||
score_size = [num_heads, num_heads]
|
||||
elif self._model_parser.is_ffn(layer.name, ffn_num=1):
|
||||
score_size = [layer.module.weight.shape[0], 1] # type: ignore
|
||||
elif self._model_parser.is_ffn(layer.name, ffn_num=2):
|
||||
score_size = [1, layer.module.weight.shape[1]] # type: ignore
|
||||
else:
|
||||
score_size = None
|
||||
else:
|
||||
score_size = None
|
||||
wrapper = PrunerScoredModuleWrapper(layer.module, layer.name, config, score_size)
|
||||
assert hasattr(layer.module, 'weight'), "module %s does not have 'weight' attribute" % layer.name
|
||||
# move newly registered buffers to the same device of weight
|
||||
wrapper.to(layer.module.weight.device) # type: ignore
|
||||
|
|
|
@ -29,6 +29,7 @@ from .metrics_calculator import (
|
|||
)
|
||||
from .sparsity_allocator import (
|
||||
NormalSparsityAllocator,
|
||||
ThresholdSparsityAllocator,
|
||||
BankSparsityAllocator,
|
||||
GlobalSparsityAllocator,
|
||||
DependencyAwareAllocator
|
||||
|
|
|
@ -6,7 +6,8 @@ from datetime import datetime
|
|||
import logging
|
||||
from pathlib import Path
|
||||
import types
|
||||
from typing import List, Dict, Literal, Tuple, Optional, Callable, Union
|
||||
from typing import List, Dict, Tuple, Optional, Callable, Union
|
||||
from typing_extensions import Literal
|
||||
|
||||
import json_tricks
|
||||
import torch
|
||||
|
|
|
@ -24,7 +24,7 @@ class StraightMetricsCalculator(MetricsCalculator):
|
|||
for module_name, targets_data in data.items():
|
||||
metrics[module_name] = {}
|
||||
for target_name, target_data in targets_data.items():
|
||||
metrics[module_name][target_name] = target_data.clone().detach()
|
||||
metrics[module_name][target_name] = self._get_scaler(module_name, target_name).shrink(target_data)
|
||||
return metrics
|
||||
|
||||
|
||||
|
|
|
@ -31,13 +31,28 @@ class NormalSparsityAllocator(SparsityAllocator):
|
|||
wrapper = self.pruner.get_modules_wrapper()[module_name]
|
||||
for target_name, target_metric in targets_metric.items():
|
||||
sparsity_rate = wrapper.config['total_sparsity']
|
||||
prune_num = int(sparsity_rate * target_metric.numel())
|
||||
if prune_num != 0:
|
||||
threshold = torch.topk(target_metric.reshape(-1), prune_num, largest=False)[0].max()
|
||||
shrinked_mask = torch.gt(target_metric, threshold).type_as(target_metric)
|
||||
else:
|
||||
# target_metric should have the same size as shrinked_mask
|
||||
shrinked_mask = torch.ones_like(target_metric)
|
||||
flatten_metric = target_metric.reshape(-1)
|
||||
kept_num = flatten_metric.numel() - int(sparsity_rate * flatten_metric.numel())
|
||||
kept_indices = torch.topk(flatten_metric, kept_num).indices
|
||||
shrinked_mask = torch.zeros_like(flatten_metric).scatter(0, kept_indices, 1.0).reshape_as(target_metric)
|
||||
masks[module_name][target_name] = self._expand_mask(module_name, target_name, shrinked_mask)
|
||||
return masks
|
||||
|
||||
|
||||
class ThresholdSparsityAllocator(SparsityAllocator):
|
||||
"""
|
||||
Note: This allocator is an experimental allocator.
|
||||
It takes 'total_sparsity' as threshold to mask the pruning target where metric is lower then threshold.
|
||||
"""
|
||||
def common_target_masks_generation(self, metrics: Dict[str, Dict[str, Tensor]]) -> Dict[str, Dict[str, Tensor]]:
|
||||
masks = {}
|
||||
# TODO: Support more target type in wrapper & config list refactor
|
||||
for module_name, targets_metric in metrics.items():
|
||||
masks[module_name] = {}
|
||||
wrapper = self.pruner.get_modules_wrapper()[module_name]
|
||||
for target_name, target_metric in targets_metric.items():
|
||||
threshold = wrapper.config['total_sparsity']
|
||||
shrinked_mask = torch.gt(torch.sigmoid(target_metric), threshold).type_as(target_metric)
|
||||
masks[module_name][target_name] = self._expand_mask(module_name, target_name, shrinked_mask)
|
||||
return masks
|
||||
|
||||
|
@ -115,10 +130,10 @@ class GlobalSparsityAllocator(SparsityAllocator):
|
|||
assert global_sparsity_rate == wrapper.config['total_sparsity']
|
||||
|
||||
# find the largest metric value among all metrics
|
||||
max_metric_value = list(list(metrics.values())[0].values())[0].max()
|
||||
max_metric_value = list(list(metrics.values())[0].values())[0].max().item()
|
||||
for targets_metric in metrics.values():
|
||||
for target_metric in targets_metric.values():
|
||||
max_metric_value = max_metric_value if max_metric_value >= target_metric.max() else target_metric.max()
|
||||
max_metric_value = max_metric_value if max_metric_value >= target_metric.max().item() else target_metric.max().item()
|
||||
|
||||
# prevent each module from being over-pruned, prevent ratio is 'max_sparsity_per_layer'
|
||||
for module_name, targets_metric in metrics.items():
|
||||
|
@ -127,10 +142,10 @@ class GlobalSparsityAllocator(SparsityAllocator):
|
|||
max_sparsity = wrapper.config.get('max_sparsity_per_layer', {}).get(module_name, 0.99)
|
||||
assert 0 <= max_sparsity <= 1
|
||||
old_target_mask: Tensor = getattr(wrapper, f'{target_name}_mask')
|
||||
expand_times = old_target_mask.numel() // target_metric.numel()
|
||||
max_pruning_numel = int(max_sparsity * target_metric.numel()) * expand_times
|
||||
threshold = torch.topk(target_metric.reshape(-1), max_pruning_numel, largest=False)[0].max()
|
||||
metrics[module_name][target_name] = torch.where(target_metric <= threshold, target_metric, max_metric_value)
|
||||
flatten_metric = target_metric.reshape(-1)
|
||||
protected_pruning_numel = target_metric.numel() - int(max_sparsity * target_metric.numel())
|
||||
protected_indices = torch.topk(flatten_metric, protected_pruning_numel).indices
|
||||
metrics[module_name][target_name] = flatten_metric.scatter(0, protected_indices, max_metric_value).reshape_as(target_metric)
|
||||
|
||||
# build the global_matric & calculate global threshold
|
||||
metric_list = []
|
||||
|
@ -207,7 +222,7 @@ class DependencyAwareAllocator(SparsityAllocator):
|
|||
fused_metrics = self._metric_fuse(sub_metrics)
|
||||
|
||||
for target_name, fused_metric in fused_metrics.items():
|
||||
sparsity_rates = {module_name: self.pruner.get_modules_wrapper()[module_name].config['total_sparsity'] \
|
||||
sparsity_rates = {module_name: self.pruner.get_modules_wrapper()[module_name].config['total_sparsity']
|
||||
for module_name in sub_metrics.keys()}
|
||||
min_sparsity_rate = min(sparsity_rates.values())
|
||||
|
||||
|
|
|
@ -14,8 +14,13 @@ from torch.optim import Optimizer
|
|||
from torch.optim.lr_scheduler import _LRScheduler
|
||||
from torch.utils.hooks import RemovableHandle
|
||||
|
||||
import pytorch_lightning as pl
|
||||
from pytorch_lightning.callbacks import Callback
|
||||
try:
|
||||
import pytorch_lightning as pl
|
||||
from pytorch_lightning.callbacks import Callback
|
||||
except ImportError:
|
||||
LightingInstalled = False
|
||||
else:
|
||||
LightingInstalled = True
|
||||
|
||||
from nni.common import is_traceable
|
||||
from .constructor_helper import OptimizerConstructHelper, LRSchedulerConstructHelper
|
||||
|
@ -292,6 +297,7 @@ class LightningEvaluator(Evaluator):
|
|||
|
||||
def __init__(self, trainer: pl.Trainer, data_module: pl.LightningDataModule,
|
||||
dummy_input: Any | None = None):
|
||||
assert LightingInstalled, 'pytorch_lightning is not installed.'
|
||||
err_msg_p = 'Only support traced {}, please use nni.trace({}) to initialize the trainer.'
|
||||
err_msg = err_msg_p.format('pytorch_lightning.Trainer', 'pytorch_lightning.Trainer')
|
||||
assert isinstance(trainer, pl.Trainer) and is_traceable(trainer), err_msg
|
||||
|
|
|
@ -0,0 +1,141 @@
|
|||
# Copyright (c) Microsoft Corporation.
|
||||
# Licensed under the MIT license.
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import logging
|
||||
import re
|
||||
from typing import Tuple
|
||||
|
||||
from torch.nn import Module
|
||||
|
||||
try:
|
||||
from transformers import (
|
||||
PreTrainedModel,
|
||||
BartConfig,
|
||||
BertConfig,
|
||||
T5Config
|
||||
)
|
||||
except ImportError:
|
||||
TRANSFORMERS_INSTALLED = False
|
||||
else:
|
||||
TRANSFORMERS_INSTALLED = True
|
||||
|
||||
from nni.algorithms.compression.v2.pytorch.utils.attr import get_nested_attr
|
||||
|
||||
_logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
# huggingface transformers pretrained model parser supported: bart, bert, t5
|
||||
def parser_factory(model: Module) -> HuggingfaceModelParser | None:
|
||||
if TRANSFORMERS_INSTALLED and isinstance(model, PreTrainedModel):
|
||||
cls2parser = {
|
||||
BartConfig: HuggingfaceBartParser,
|
||||
BertConfig: HuggingfaceBertParser,
|
||||
T5Config: HuggingfaceT5Parser
|
||||
}
|
||||
type2parser = {
|
||||
'bart': HuggingfaceBartParser,
|
||||
'bert': HuggingfaceBertParser,
|
||||
't5': HuggingfaceT5Parser
|
||||
}
|
||||
|
||||
if hasattr(model, 'config_class'):
|
||||
parser = cls2parser.get(getattr(model, 'config_class'))
|
||||
elif hasattr(model, 'model_type'):
|
||||
parser = type2parser.get(getattr(model, 'model_type'))
|
||||
else:
|
||||
parser = None
|
||||
|
||||
return parser
|
||||
else:
|
||||
return None
|
||||
|
||||
|
||||
class HuggingfaceModelParser:
|
||||
# This class is used to verify that a module name belongs to a specific huggingface transformers pretrained model.
|
||||
# Further, verify that the module with this name is some kind of special layer (QKVO or FFN).
|
||||
TRANSFORMER_PREFIX: str
|
||||
QKV: Tuple[str, ...]
|
||||
QKVO: Tuple[str, ...]
|
||||
FFN1: Tuple[str, ...]
|
||||
FFN2: Tuple[str, ...]
|
||||
ATTENTION: Tuple[str, ...]
|
||||
|
||||
@classmethod
|
||||
def is_huggingface_model(cls, model: Module):
|
||||
return model.__module__.split('.')[0] == 'transformers'
|
||||
|
||||
@classmethod
|
||||
def is_attention(cls, module_name: str, include_output: bool = True) -> bool:
|
||||
patterns = cls.QKVO if include_output else cls.QKV
|
||||
for pattern in patterns:
|
||||
if pattern in module_name:
|
||||
return True
|
||||
return False
|
||||
|
||||
@classmethod
|
||||
def is_ffn(cls, module_name: str, ffn_num: int = 1) -> bool:
|
||||
if cls.is_attention(module_name):
|
||||
return False
|
||||
if ffn_num == 1:
|
||||
for pattern in cls.FFN1:
|
||||
if pattern in module_name:
|
||||
return True
|
||||
if ffn_num == 2:
|
||||
for pattern in cls.FFN2:
|
||||
if pattern in module_name:
|
||||
return True
|
||||
return False
|
||||
|
||||
@classmethod
|
||||
def get_num_heads(cls, module_name: str, model: Module) -> int:
|
||||
if cls.is_attention(module_name, include_output=True):
|
||||
for pattern in cls.ATTENTION:
|
||||
match = re.search(pattern, module_name)
|
||||
if match:
|
||||
attention_module_name = module_name[0: match.span()[1]]
|
||||
module = get_nested_attr(model, attention_module_name)
|
||||
if hasattr(module, 'num_attention_heads'):
|
||||
num_heads = module.num_attention_heads
|
||||
elif hasattr(module, 'num_heads'):
|
||||
num_heads = module.num_heads
|
||||
elif hasattr(module, 'n_heads'):
|
||||
num_heads = module.n_heads
|
||||
else:
|
||||
warn_msg = f'Can not get the heads number of attention layer : {attention_module_name}.'
|
||||
_logger.warning(warn_msg)
|
||||
num_heads = 0
|
||||
return num_heads
|
||||
return 0
|
||||
else:
|
||||
warn_msg = f'The layer `{module_name}` might not an (Q|K|V) attention layer.'
|
||||
_logger.warning(warn_msg)
|
||||
return 0
|
||||
|
||||
|
||||
class HuggingfaceBertParser(HuggingfaceModelParser):
|
||||
TRANSFORMER_PREFIX = r'bert\.encoder\.layer\.[0-9]+\.'
|
||||
QKV = ('attention.self.query', 'attention.self.key', 'attention.self.value')
|
||||
QKVO = QKV + ('attention.output.dense',)
|
||||
FFN1 = ('intermediate.dense',)
|
||||
FFN2 = ('output.dense',)
|
||||
ATTENTION = ('attention.self',)
|
||||
|
||||
|
||||
class HuggingfaceBartParser(HuggingfaceModelParser):
|
||||
TRANSFORMER_PREFIX = r'(en|de)coder\.layer\.[0-9]+\.'
|
||||
QKV = ('self_attn.q_proj', 'self_attn.k_proj', 'self_attn.v_proj', 'encoder_attn.q_proj', 'encoder_attn.k_proj', 'encoder_attn.v_proj')
|
||||
QKVO = QKV + ('self_attn.out_proj', 'encoder_attn.out_proj')
|
||||
FFN1 = ('fc1',)
|
||||
FFN2 = ('fc2',)
|
||||
ATTENTION = ('self_attn', 'encoder_attn')
|
||||
|
||||
|
||||
class HuggingfaceT5Parser(HuggingfaceModelParser):
|
||||
TRANSFORMER_PREFIX = r'(en|de)coder\.block\.[0-9]+\.layer\.[0-9]+.'
|
||||
QKV = ('SelfAttention.q', 'SelfAttention.k', 'SelfAttention.v', 'EncDecAttention.q', 'EncDecAttention.k', 'EncDecAttention.v')
|
||||
QKVO = QKV + ('SelfAttention.o', 'EncDecAttention.o')
|
||||
FFN1 = ('DenseReluDense.wi',)
|
||||
FFN2 = ('DenseReluDense.wo',)
|
||||
ATTENTION = ('SelfAttention', 'EncDecAttention')
|
|
@ -122,8 +122,9 @@ class Scaling:
|
|||
permute_dims = [2 * _ for _ in range(len(kernel_size))] + [2 * _ + 1 for _ in range(len(kernel_size))]
|
||||
converted_target = target.reshape(reshape_size).permute(permute_dims).reshape(final_size + [-1])
|
||||
|
||||
# step 2: reduce the converted_target last dim with a certain way, by default is converted_target.sum(-1).
|
||||
result = reduce_func(converted_target) if reduce_func else converted_target.sum(-1)
|
||||
# step 2: reduce the converted_target last dim with a certain way, by default is converted_target.mean(-1).
|
||||
# `sum` does not take into account the metric scale problem, it is better to use `mean` here.
|
||||
result = reduce_func(converted_target) if reduce_func else converted_target.mean(-1)
|
||||
|
||||
# step 3: reduce the dims where kernel_size is -1.
|
||||
# e.g., target size is [10, 40], kernel_size is [-1, 4], result size is [1, 10], then reduce result to size [10].
|
||||
|
|
|
@ -75,7 +75,19 @@ class TorchGraph:
|
|||
if torch.__version__ >= '1.6.0':
|
||||
# only pytorch with version greater than 1.6.0 has the strict option
|
||||
kw_args['strict'] = False
|
||||
self.trace = torch.jit.trace(model, dummy_input, **kw_args)
|
||||
try:
|
||||
import pytorch_lightning as pl
|
||||
except ImportError:
|
||||
is_lightning_module = False
|
||||
else:
|
||||
if isinstance(model, pl.LightningModule):
|
||||
is_lightning_module = True
|
||||
else:
|
||||
is_lightning_module = False
|
||||
if is_lightning_module:
|
||||
self.trace = model.to_torchscript(method="trace", example_inputs=dummy_input, **kw_args)
|
||||
else:
|
||||
self.trace = torch.jit.trace(model, dummy_input, **kw_args)
|
||||
torch._C._jit_pass_inline(self.trace.graph)
|
||||
model.train(training)
|
||||
|
||||
|
|
|
@ -31,6 +31,7 @@ replace_module = {
|
|||
'SELU': lambda module, masks: no_replace(module, masks),
|
||||
'CELU': lambda module, masks: no_replace(module, masks),
|
||||
'GELU': lambda module, masks: no_replace(module, masks),
|
||||
'GELUActivation': lambda module, masks: no_replace(module, masks),
|
||||
'Sigmoid': lambda module, masks: no_replace(module, masks),
|
||||
'SiLU': lambda module, masks: no_replace(module, masks),
|
||||
'Mish': lambda module, masks: no_replace(module, masks),
|
||||
|
@ -74,6 +75,7 @@ def convert_to_coarse_mask(t_mask, dim):
|
|||
n_dims = len(shape)
|
||||
dim_list = list(range(n_dims))
|
||||
# try to reduce the mask from the dim-th dimension
|
||||
dim = dim if dim >= 0 else n_dims + dim
|
||||
dim_list.remove(dim)
|
||||
|
||||
t_merged = torch.sum(t_mask, dim_list)
|
||||
|
@ -190,12 +192,9 @@ def replace_linear(linear, masks):
|
|||
in_mask = in_masks[0]
|
||||
|
||||
weight_mask = weight_mask['weight']
|
||||
# the input of the linear may have two dimensions(CV models) or three
|
||||
# dimensions(Bert, for example)
|
||||
n_dim = len(in_mask.size())
|
||||
# N C K
|
||||
pruned_in, remained_in = convert_to_coarse_mask(in_mask, n_dim-1)
|
||||
pruned_out, remained_out = convert_to_coarse_mask(output_mask, n_dim-1)
|
||||
pruned_in, remained_in = convert_to_coarse_mask(in_mask, -1)
|
||||
pruned_out, remained_out = convert_to_coarse_mask(output_mask, -1)
|
||||
n_remained_in = weight_mask.size(1) - pruned_in.size(0)
|
||||
n_remained_out = weight_mask.size(0) - pruned_out.size(0)
|
||||
remained_in, remained_out = remained_in.to(
|
||||
|
@ -610,11 +609,29 @@ def replace_layernorm(layernorm, masks):
|
|||
if len(in_masks) != 1:
|
||||
raise InputsNumberError()
|
||||
in_mask = in_masks[0]
|
||||
dense_shape = convert_dense_shape(in_mask)
|
||||
norm_shape = layernorm.normalized_shape
|
||||
dim_n = len(dense_shape) - len(norm_shape)
|
||||
return nn.LayerNorm(dense_shape[dim_n:], layernorm.eps, layernorm.elementwise_affine)
|
||||
|
||||
old_normalized_shape = layernorm.normalized_shape
|
||||
new_normalized_shape = []
|
||||
remained_list = []
|
||||
for i in range(-len(old_normalized_shape), 0):
|
||||
pruned, remained = convert_to_coarse_mask(in_mask, i)
|
||||
new_normalized_shape.append(old_normalized_shape[i] - pruned.size()[0])
|
||||
remained_list.append(remained)
|
||||
|
||||
new_layernorm = nn.LayerNorm(tuple(new_normalized_shape), layernorm.eps, layernorm.elementwise_affine)
|
||||
|
||||
if new_layernorm.elementwise_affine:
|
||||
new_layernorm.to(layernorm.weight.device)
|
||||
# NOTE: should we keep the weight & bias?
|
||||
with torch.no_grad():
|
||||
tmp_weight_data = layernorm.weight.data
|
||||
tmp_bias_data = layernorm.bias.data
|
||||
for i, remained in enumerate(remained_list):
|
||||
tmp_weight_data = torch.index_select(tmp_weight_data, i, remained)
|
||||
tmp_bias_data = torch.index_select(tmp_bias_data, i, remained)
|
||||
new_layernorm.weight.data = tmp_weight_data
|
||||
new_layernorm.bias.data = tmp_bias_data
|
||||
return new_layernorm
|
||||
|
||||
def replace_embedding(embedding, masks):
|
||||
"""
|
||||
|
|
|
@ -45,7 +45,19 @@ def fix_mask_conflict(masks, model, dummy_input, traced=None):
|
|||
if torch.__version__ >= '1.6.0':
|
||||
# only pytorch with version greater than 1.6.0 has the strict option
|
||||
kw_args['strict'] = False
|
||||
traced = torch.jit.trace(model, dummy_input, **kw_args)
|
||||
try:
|
||||
import pytorch_lightning as pl
|
||||
except ImportError:
|
||||
is_lightning_module = False
|
||||
else:
|
||||
if isinstance(model, pl.LightningModule):
|
||||
is_lightning_module = True
|
||||
else:
|
||||
is_lightning_module = False
|
||||
if is_lightning_module:
|
||||
traced = model.to_torchscript(method="trace", example_inputs=dummy_input, **kw_args)
|
||||
else:
|
||||
traced = torch.jit.trace(model, dummy_input, **kw_args)
|
||||
model.train(training)
|
||||
|
||||
fix_group_mask = GroupMaskConflict(masks, model, dummy_input, traced)
|
||||
|
|
|
@ -42,10 +42,6 @@ stages:
|
|||
platform: ubuntu-latest-gpu
|
||||
python_env: venv
|
||||
|
||||
- script: |
|
||||
python -m pip install "pytorch-lightning<1.7"
|
||||
displayName: Pin PytorchLightning version
|
||||
|
||||
- template: templates/install-nni.yml
|
||||
|
||||
- template: templates/download-test-data.yml
|
||||
|
|
|
@ -8,7 +8,7 @@ from nni.algorithms.compression.v2.pytorch.utils.scaling import Scaling
|
|||
|
||||
|
||||
def test_scaling():
|
||||
data = torch.tensor([_ for _ in range(100)]).reshape(10, 10)
|
||||
data = torch.tensor([_ for _ in range(100)], dtype=torch.float32).reshape(10, 10)
|
||||
|
||||
scaler = Scaling([5], kernel_padding_mode='front')
|
||||
shrinked_data = scaler.shrink(data)
|
||||
|
|
Загрузка…
Ссылка в новой задаче