This commit is contained in:
J-shang 2023-04-25 16:46:15 +08:00 committed by GitHub
Parent 0f0d145cbe
Commit f6f0f65941
No key found matching this signature
GPG key ID: 4AEE18F83AFDEB23
38 changed files: 1621 additions and 61 deletions


@ -0,0 +1,107 @@
Major Enhancement of Compression in NNI 3.0
===========================================
To support additional compression scenarios and more fine-grained compression configurations,
we have revised the compression application programming interface (API) in NNI 3.0.
If you are new to NNI Compression, you can skip this document.
However, if you have used NNI Compression before and want to try the latest Compression version,
this document will help you understand the notable interface changes in 3.0.
New compression version import path:
.. code-block:: python
# most of the new compression modules, including pruners, quantizers, and distillers (except the new pruning speedup)
from nni.contrib.compression.xxx import xxx
# new pruning speedup
from nni.compression.pytorch.speedup.v2 import ModelSpeedup
Old compression version import path:
.. code-block:: python
from nni.compression.pytorch.xxx import xxx
Compression Target
------------------
The notion of a ``compression target`` is a new concept introduced in NNI 3.0.
It refers to the specific parts of a module that should be compressed, such as its input, output, or parameters.
In previous versions, NNI assumed that all module types had parameters named ``weight`` and ``bias``,
and only produced masks for these parameters.
This assumption was adequate for a large portion of simulated compression.
However, it is undeniable that there are still many modules that do not fit into this assumption,
particularly for customized modules.
Therefore, in NNI 3.0, model compression can be configured specifically at the level of a module's input, output, and parameters.
By means of fine-grained configuration, NNI can not only compress module types that were previously uncompressible,
but also achieve better simulation compression.
As a result, the gap in accuracy between simulation compression and real speedup becomes extremely small.
For instance, in previous versions, a ``softmax`` operation would significantly diminish the effect of simulated pruning,
since a 0 input is still meaningful for ``softmax``.
In NNI 3.0, this can be avoided by setting the input and output masks and the ``apply_method``
to ensure that ``softmax`` obtains the correct simulated pruning result, as sketched below.
Please consult the sections on :ref:`target_names` and :ref:`target_settings` for further details.
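The sketch below illustrates what such a configuration might look like. The module name ``attn_softmax``, the choice of ``per_channel`` granularity, and the mix of ``add``/``mul`` apply methods are illustrative assumptions, not values prescribed by this document.

.. code-block:: python

    # hypothetical config: masked input positions receive a large negative value (``add``)
    # before softmax, and the softmax output is then zeroed out (``mul``)
    config_list = [{
        'op_names': ['attn_softmax'],                  # placeholder module name
        'target_names': ['_input_0', '_output_0'],
        'target_settings': {
            '_input_0': {'apply_method': 'add', 'granularity': 'per_channel'},
            '_output_0': {'apply_method': 'mul', 'granularity': 'per_channel'},
        }
    }]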
Compression Mode
----------------
In the previous version of NNI (lower than 3.0), three pruning modes were supported: ``normal``, ``global``, and ``dependency-aware``.
In the ``normal`` mode, each module was required to be assigned a sparse ratio, and the pruner generated masks directly on the weight elements according to this ratio.
In the ``global`` mode, a sparse ratio was set for a group of modules, and the pruner generated masks whose overall sparse ratio conformed to the setting,
but the sparsity of each module in the group may differ.
The ``dependency-aware`` mode constrained modules with operational dependencies to generate related masks.
For instance, if the outputs of two modules had an ``add`` relationship, then the two modules would have the same masks in the output dimension.
Different modes were better suited to different compression scenarios to achieve improved compression effects.
Nevertheless, we believe that more flexible combinations should be allowed.
For example, in a compression process, certain modules of similar levels could apply the overall sparse ratio,
while other modules with operational dependencies could generate similar masks at the same time.
Right now in NNI 3.0, users can directly set :ref:`global_group_id` and :ref:`dependency_group_id` to implement ``global`` and ``dependency-aware`` modes.
Additionally, :ref:`align` is supported to generate a mask from another module mask, such as generating a batch normalization mask from a convolution mask.
You can achieve improved performance and broader exploration by combining these modes through the appropriate keys in the configuration list, as sketched below.
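As a sketch (the module names below are placeholders, not taken from this document), such a combined configuration could look like this:

.. code-block:: python

    config_list = [{
        # these two hypothetical layers share one overall sparse ratio
        'op_names': ['fc1', 'fc2'],
        'sparse_ratio': 0.5,
        'global_group_id': 'fc_group',
    }, {
        # these two hypothetical layers feed the same ``add``, so their masks are kept identical
        'op_names': ['conv1', 'conv2'],
        'sparse_ratio': 0.5,
        'dependency_group_id': 'skip_group',
    }]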
Pruning Speedup
---------------
The previous method of pruning speedup relied on ``torch.jit.trace`` to trace the model graph.
However, this method had several limitations and required additional support to perform certain operations.
These limitations resulted in excessive maintenance costs, making it difficult to continue development.
To address these issues, in NNI 3.0, we refactored the pruning speedup based on ``concrete_trace``.
This is a useful utility for tracing a model graph, based on ``torch.fx``.
Unlike ``torch.fx.symbolic_trace``, ``concrete_trace`` executes the entire model, resulting in a more complete graph.
As a result, most operations that couldn't be traced in the previous pruning speedup can now be traced.
In addition to ``concrete_trace``, users who have a good ``torch.fx.GraphModule`` for their traced model can also use the ``torch.fx.GraphModule`` directly.
Furthermore, the new pruning speedup supports customized masks propagation logic and module replacement methods to cope with the speedup of various customized modules.
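A minimal sketch of the new speedup entry point follows; it assumes ``model`` and a finished ``pruner`` already exist, and the input shape is a placeholder.

.. code-block:: python

    import torch
    from nni.compression.pytorch.speedup.v2 import ModelSpeedup

    # ``pruner`` is assumed to be a compressor that has finished ``compress(...)`` on ``model``
    pruner.unwrap_model()
    masks = pruner.get_masks()
    dummy_input = torch.rand(8, 3, 224, 224)  # placeholder input shape
    model = ModelSpeedup(model, dummy_input, masks).speedup_model()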
Distillation
------------
Two distillers are supported in NNI 3.0. By fusing distillation with pruning or quantization, you can obtain better compression results and higher accuracy.
Please refer to :doc:`Distiller <../reference/compression_preview/distiller>` for more details.
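As a rough sketch, a distiller can be used like the snippet below; the config values are placeholders, and ``model``, ``teacher_model`` and ``evaluator`` are assumed to be prepared as described in the evaluator documentation.

.. code-block:: python

    from nni.contrib.compression.distillation import DynamicLayerwiseDistiller

    def teacher_predict(batch, teacher_model):
        return teacher_model(batch[0])

    config_list = [{
        'op_types': ['Conv2d'],  # which student modules produce distillation targets
        'lambda': 0.1,
        'apply_method': 'mse',
    }]
    distiller = DynamicLayerwiseDistiller(model, config_list, evaluator, teacher_model, teacher_predict)
    distiller.compress(max_steps=1000, max_epochs=None)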
Fusion Compression
------------------
Thanks to the new unified compression framework, it is now possible to perform pruning, quantization, and distillation simultaneously,
without having to apply them one by one.
Please refer to :doc:`fusion compression <./fusion_compress>` for more details.


@ -0,0 +1,318 @@
Compression Config Specification
================================
Common Keys in Config
---------------------
op_names
^^^^^^^^
A list of fully-qualified module names (e.g., ``['backbone.layers.0.ffn', ...]``) that will be compressed.
If a referenced module does not exist in the model, it will be ignored.
op_names_re
^^^^^^^^^^^
A list of regular expressions for matching module names by python standard library ``re``.
The matched modules will be selected to be compressed.
op_types
^^^^^^^^
A list of type names of classes that inherit from ``torch.nn.Module``.
Only module types in this list can be selected to be compressed.
If this key is not set, all module types can be selected.
If neither ``op_names`` nor ``op_names_re`` is set, all modules satisfying ``op_types`` are selected.
exclude_op_names
^^^^^^^^^^^^^^^^
A list of fully-qualified names of modules that are excluded from compression.
exclude_op_names_re
^^^^^^^^^^^^^^^^^^^
A list of regular expressions for matching module names.
The matched modules will be removed from the modules that need to be compressed.
exclude_op_types
^^^^^^^^^^^^^^^^
A list of type names of classes that inherit from ``torch.nn.Module``.
The module types in this list are excluded from compression.
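Putting the selection and exclusion keys together, a config might look like the following sketch; the module names are placeholders.

.. code-block:: python

    config_list = [{
        'op_types': ['Linear'],
        'op_names_re': [r'encoder\.layers\..*'],      # select modules by regular expression
        'exclude_op_names': ['encoder.layers.0.fc'],  # drop one of the matched modules again
        'sparse_ratio': 0.5,
    }]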
.. _target_names:
target_names
^^^^^^^^^^^^
A list of legal compression target names; usually ``_input_``, ``weight``, ``bias``, and ``_output_`` are supported as compression targets.
Two kinds of targets are supported by design: module inputs/outputs (which should be tensors) and module parameters (see the example after this list):
- Inputs/Outputs: If the module input or output is a single tensor, directly set ``_input_`` for the input and ``_output_`` for the output.
``_input_{position_index}`` or ``_input_{arg_name}`` can be used to specify the input target,
e.g., for a forward function ``def forward(self, x: Tensor, y: Tensor, z: Any): ...``, ``_input_0`` or ``_input_x`` can be used to specify that ``x`` should be compressed;
note that ``self`` is ignored when counting the position index.
Similarly, ``_output_{position_index}`` can be used to specify the output target if the output is a ``list/tuple``,
and ``_output_{dict_key}`` can be used to specify the output target if the output is a ``dict``.
- Parameters/Buffers: Directly use the attribute name to specify the target, e.g., ``weight``, ``bias``.
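For example, a sketch of a config that compresses a specific positional input, the output, and the ``weight`` parameter of a hypothetical module could look like this:

.. code-block:: python

    config_list = [{
        'op_names': ['my_module'],  # assumed module with forward(self, x, y, z)
        'target_names': ['_input_x', '_output_', 'weight'],
        'sparse_ratio': 0.5,
    }]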
.. _target_settings:
target_settings
^^^^^^^^^^^^^^^
A ``dict`` of target settings in the format ``{target_name: setting}``. A target setting usually configures how to compress the target.
All other keys in a config (except these eight common keys) are treated as shortcuts for target setting keys, and will apply to all targets selected in this config.
For example, consider a model with two ``Linear`` modules (named ``'fc1'`` and ``'fc2'``); the following configs have the same effect for pruning.
.. code-block:: python
shorthand_config = {
'op_types': ['Linear'],
'sparse_ratio': 0.8
}
standard_config = {
'op_names': ['fc1', 'fc2'],
'target_names': ['weight', 'bias'],
'target_settings': {
'weight': {
'sparse_ratio': 0.8,
'max_sparse_ratio': None,
'min_sparse_ratio': None,
'sparse_threshold': None,
'global_group_id': None,
'dependency_group_id': None,
'granularity': 'default',
'internal_metric_block': None,
'apply_method': 'mul',
},
'bias': {
'align': {
'target_name': 'weight',
'dims': [0],
},
'apply_method': 'mul',
}
}
}
.. Note:: Each compression target can only be configured once; re-configuration will not take effect.
Pruning Specific Configuration Keys
-----------------------------------
sparse_ratio
^^^^^^^^^^^^
A float number between 0 and 1: the sparse ratio of the pruning target, or the total sparse ratio of a group of pruning targets.
For example, if the sparse ratio is 0.8 and the pruning target is a Linear module weight, 80% of the weight values will be masked after pruning.
max_sparse_ratio
^^^^^^^^^^^^^^^^
This key is usually used in combination with ``sparse_threshold`` and ``global_group_id`` to limit the maximum sparse ratio of each target.
A float number between 0 and 1: for each single pruning target, the sparse ratio after pruning will not be larger than this number,
i.e., at most ``max_sparse_ratio`` of the pruning target values will be masked.
min_sparse_ratio
^^^^^^^^^^^^^^^^
This key is usually used in combination with ``sparse_threshold`` and ``global_group_id`` to limit the minimum sparse ratio of each target.
A float number between 0 and 1: for each single pruning target, the sparse ratio after pruning will not be lower than this number,
i.e., at least ``min_sparse_ratio`` of the pruning target values will be masked.
sparse_threshold
^^^^^^^^^^^^^^^^
A float number. Different from ``sparse_ratio``, which configures a specific sparsity, ``sparse_threshold`` is usually used in adaptive sparsity cases.
``sparse_threshold`` is directly compared to the pruning metrics (which differ between algorithms), and the positions whose metric is smaller than the threshold are masked.
The value range differs between pruning algorithms; please refer to the pruner document to see how to configure it.
In general, the higher the threshold, the higher the final sparsity.
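A hedged sketch of a threshold-driven config, with bounds on the resulting sparsity, is shown below; the threshold value is a placeholder whose scale depends on the pruner's metric.

.. code-block:: python

    config_list = [{
        'op_types': ['Linear'],
        'sparse_threshold': 0.1,  # placeholder value; its scale depends on the pruning metric
        'min_sparse_ratio': 0.1,  # never mask less than 10% of each target
        'max_sparse_ratio': 0.9,  # never mask more than 90% of each target
    }]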
.. _global_group_id:
global_group_id
^^^^^^^^^^^^^^^
``global_group_id`` should be used jointly with ``sparse_ratio``.
All pruning targets that have the same ``global_group_id`` will be treated as a whole, and the ``sparse_ratio`` will be distributed across these pruning targets.
That means each pruning target might have a different sparse ratio after pruning, but the overall group sparse ratio will be the configured ``sparse_ratio``.
Note that the ``sparse_ratio`` within the same global group should be the same.
For example, a model has three ``Linear`` modules (``'fc1'``, ``'fc2'``, ``'fc3'``),
and the expected total sparse ratio of these three modules is 0.5, then the config can be:
.. code-block:: python
config_list = [{
'op_names': ['fc1', 'fc2'],
'sparse_ratio': 0.5,
'global_group_id': 'linear_group_1'
}, {
'op_names': ['fc3'],
'sparse_ratio': 0.5,
'global_group_id': 'linear_group_1'
}]
.. _dependency_group_id:
dependency_group_id
^^^^^^^^^^^^^^^^^^^
All pruning targets that have the same ``dependency_group_id`` will be treated as a whole, and the pruned positions of these targets will be the same.
For example, if layer A and layer B have the same ``dependency_group_id`` and both prune output channels, then A and B will prune the same channel indexes.
Note that the ``sparse_ratio`` within the same dependency group should be the same, and the prunable positions (after reduction by ``granularity``) should also be the same;
for example, pruning targets should have the same number of output channels when pruning output channels.
This key is usually used on modules joined by add or mul operations, e.g., skip connections.
If you don't know your model structure well, you could use :ref:`auto_set_denpendency_group_ids` to automatically detect the dependent operations and set their ``dependency_group_id``, as shown in the sketch below.
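The following sketch, adapted from the fusion example shipped with this release, shows the typical call; ``model`` is assumed to be prepared by the user and the input shape is a placeholder.

.. code-block:: python

    import torch
    from nni.contrib.compression.utils import auto_set_denpendency_group_ids

    config_list = [{'op_types': ['Conv2d'], 'sparse_ratio': 0.5}]
    dummy_input = torch.rand(8, 3, 224, 224)  # any input the model accepts
    # trace the model and fill in ``dependency_group_id`` for dependent layers
    config_list = auto_set_denpendency_group_ids(model, config_list, dummy_input)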
granularity
^^^^^^^^^^^
Controls the granularity of the generated masks.
``default``, ``in_channel``, ``out_channel``, ``per_channel`` and a list of integers are supported:
- default: The pruner will automatically determine which granularity to use, usually consistent with the original paper.
- in_channel: The pruner will prune on dimension 1 of the weight parameter.
- out_channel: The pruner will prune on dimension 0 of the weight parameter.
- per_channel: The pruner will prune on dimension -1 of the input/output.
- list of integers: Block sparsity will be applied. For example, ``[4, 4]`` will apply 4x4 block sparsity on the last two dimensions of the weight parameter.
Note that ``in_channel`` and ``out_channel`` are not supported for input/output targets; please use ``per_channel`` instead.
``torch.nn.Embedding`` is special: its output dimension on the weight is dimension 1, so if you want to prune the Embedding output channels, please set its granularity to ``in_channel`` as a workaround.
The following is an example for output channel pruning:
.. code-block:: python
config = {
'op_types': ['Conv2d'],
'sparse_ratio': 0.5,
'granularity': 'out_channel' # same as [1, -1, -1, -1]
}
.. _align:
align
^^^^^
``align`` refers to the process where the target mask will not be generated by the pruning algorithm but is created in accordance with another pruning target mask.
A typical scenario occurs in most PyTorch native modules with ``weight`` and ``bias`` attributes.
In this case, the generation of the ``bias`` mask is aligned with the ``weight`` mask generation,
meaning that a ``bias`` value is masked or not depending on whether the related ``weight`` values are all masked or not.
For example, in all pruners, the ``i``-th position of a ``Linear`` layer's ``bias`` is masked when all values in the ``i``-th row of its ``weight`` are masked.
This can also prove useful for generating activation masks (the output of activation modules).
For instance, consider the common ``conv-bn-relu`` pattern:
Here, the ``bn`` weight or output can be masked in alignment with the convolution weight for a more effective pruning simulation.
.. code-block:: python
config = {
'op_types': ['BatchNorm2d'],
'target_names': ['weight'],
'target_settings': {
'weight': {
'align': {
'module_name': 'conv',
'target_name': 'weight',
'dims': [0],
}
}
}
}
The above configuration means that the batch normalization layer's mask is aligned with the ``weight`` of the ``conv`` layer along dimension ``0``.
module_name
"""""""""""
The name of the module to align with. By default, the currently configured module.
target_name
"""""""""""
Align with the specified target mask of the specified module.
dims
""""
The dimensions of the specified target along which the mask is aligned.
apply_method
^^^^^^^^^^^^
By default, ``mul``. ``mul`` and ``add`` are supported for applying masks to pruning targets.
``mul`` means the pruning target will be masked by multiplying it with a mask matrix containing 0 and 1, where 0 represents a masked position and 1 an unmasked position.
``add`` means the pruning target will be masked by adding a mask matrix containing -1000 and 0, where -1000 represents a masked position and 0 an unmasked position.
Note that the value -1000 may become configurable in the future. ``add`` is usually used to mask activation modules such as Softmax.
Quantization Specific Configuration Keys
----------------------------------------
quant_dtype
^^^^^^^^^^^
By default, ``int8``. ``int`` or ``uint`` plus the number of quantization bits is supported, e.g., ``int8``, ``uint8``.
quant_scheme
^^^^^^^^^^^^
``affine`` or ``symmetric``. If this key is not set, the quantization scheme will be chosen by the quantizer;
most quantizers apply ``symmetric`` quantization.
granularity
^^^^^^^^^^^
Used to control the granularity of the quantization target; by default, the whole tensor uses the same scale and zero point.
``per_channel`` and a list of integers are supported:
- ``per_channel``: Each (output) channel will have its own scale and zero point.
- list of integers: The integer list is the block size. Each block will have its own scale and zero point.
Each sub-config in the config list is a dict, and the scope of each setting (key) is only internal to each sub-config.
If multiple sub-configs are configured for the same layer, the later ones will overwrite the previous ones.
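Combining these keys, a quantization config similar to the one used in the fusion example of this release might look like this:

.. code-block:: python

    config_list = [{
        'op_types': ['Conv2d'],
        'quant_dtype': 'int8',
        'target_names': ['weight'],
        'granularity': 'out_channel',
    }, {
        'op_types': ['Conv2d'],
        'quant_dtype': 'int8',
        'target_names': ['_input_'],
        'granularity': 'per_channel',
    }]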
Distillation Specific Configuration Keys
----------------------------------------
lambda
^^^^^^
A float number. The scale factor of the distillation loss.
The final distil loss for the specific target is ``lambda * distil_loss_func(student_target, teacher_target)``.
link
^^^^
A teacher module name, or a list of teacher module names, that the student module links to.
apply_method
^^^^^^^^^^^^
``mse`` or ``kl``.
``mse`` means the MSE loss, usually used to distill hidden states.
Please refer to `mse_loss <https://pytorch.org/docs/stable/generated/torch.nn.functional.mse_loss.html>`__.
``kl`` means the KL loss, usually used to distill logits.
The implementation is ``kl_div((stu_hs / 2).log_softmax(dim=-1), (tea_hs / 2).softmax(dim=-1), reduction='batchmean') * (2 ** 2)``,
please refer to `kl_div <https://pytorch.org/docs/stable/generated/torch.nn.functional.kl_div.html>`__.
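A sketch of a distillation config using these keys might look like the following; the module name is a placeholder for your own student/teacher layers.

.. code-block:: python

    config_list = [{
        'op_names': ['encoder.layer.0'],  # assumed student module
        'link': ['encoder.layer.0'],      # assumed teacher module whose output it matches
        'lambda': 0.1,
        'apply_method': 'mse',
    }]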


@ -0,0 +1,160 @@
Compression Evaluator
=====================
The ``Evaluator`` is used to package the training and evaluation process for a targeted model.
To explain why NNI needs an ``Evaluator``, let's first look at the general process of model compression in NNI.
In model pruning, some algorithms need to prune according to some intermediate variables (gradients, activations, etc.) generated during the training process,
and some algorithms need to gradually increase or adjust the sparsity of different layers during the training process,
or adjust the pruning strategy according to the performance changes of the model during the pruning process.
In model quantization, NNI has a quantization-aware training algorithm;
it can adjust the scale and zero point required for model quantization from time to time during the training process,
and may achieve better performance compared to post-training quantization.
In order to better support the above algorithms' needs and maintain the consistency of the interface,
NNI introduces the ``Evaluator`` as the carrier of the training and evaluation process.
.. note::
For users prior to NNI v2.8: NNI previously provided APIs like ``trainer``, ``traced_optimizer``, ``criterion``, and ``finetuner``.
These APIs could be tedious in terms of user experience: users needed to swap in the corresponding API whenever they wanted to switch compression algorithms.
``Evaluator`` is an alternative to those interfaces; users only need to create the evaluator once, and it can be used with all compressors.
For users of native PyTorch, :class:`TorchEvaluator <nni.contrib.compression.TorchEvaluator>` requires the user to encapsulate the training process as a function that exposes a specified interface,
which brings some complexity. But don't worry, in most cases this does not require changing much code.
For users of `PyTorchLightning <https://www.pytorchlightning.ai/>`__, :class:`LightningEvaluator <nni.contrib.compression.LightningEvaluator>` can be created with only a few lines of code based on your original Lightning code.
For users of `Transformers Trainer <https://huggingface.co/docs/transformers/main_classes/trainer>`__, :class:`TransformersEvaluator <nni.contrib.compression.TransformersEvaluator>` can be created with only a few lines of code.
Here we give three examples of how to create an ``Evaluator`` for native PyTorch users, PyTorchLightning users and Huggingface Transformers users.
TorchEvaluator
--------------
:class:`TorchEvaluator <nni.contrib.compression.TorchEvaluator>` is for the users who work in a native PyTorch environment (If you are using PyTorchLightning, please refer `LightningEvaluator`_).
:class:`TorchEvaluator <nni.contrib.compression.TorchEvaluator>` has six initialization parameters ``training_func``, ``optimizers``, ``training_step``, ``lr_schedulers``,
``dummy_input``, ``evaluating_func``.
* ``training_func`` is the training loop to train the compressed model.
It is a callable function with six input parameters ``model``, ``optimizers``,
``training_step``, ``lr_schedulers``, ``max_steps``, ``max_epochs``.
Please make sure each input argument of ``training_func`` is actually used,
and especially that ``max_steps`` and ``max_epochs`` correctly control the duration of training.
* ``optimizers`` is a single traced optimizer or a list of traced optimizers;
please make sure to wrap the ``Optimizer`` class with ``nni.trace`` before initializing it.
* ``training_step`` is a callable function whose first input argument should be ``batch`` and whose output should contain the loss.
Three kinds of outputs are supported: a single loss, a tuple whose first element is the loss, or a dict containing the key ``loss``.
* ``lr_schedulers`` is a single traced scheduler or a list of traced schedulers; same as ``optimizers``,
please make sure to wrap the ``_LRScheduler`` class with ``nni.trace`` before initializing it.
* ``dummy_input`` is used to trace the model, same as ``example_inputs``
in `torch.jit.trace <https://pytorch.org/docs/stable/generated/torch.jit.trace.html?highlight=torch%20jit%20trace#torch.jit.trace>`_.
* ``evaluating_func`` is a callable function to evaluate the compressed model's performance. Its input is a compressed model and its output is a metric.
The metric should be a float number or a dict with the key ``default``.
Please refer to :class:`TorchEvaluator <nni.contrib.compression.TorchEvaluator>` for more details.
Here is an example of how to initialize a :class:`TorchEvaluator <nni.contrib.compression.TorchEvaluator>`.
.. code-block:: python
def training_step(batch, model, *args, **kwargs):
output = model(batch[0])
loss = F.cross_entropy(output, batch[1])
return loss
def training_func(model, optimizer, training_step, lr_scheduler, max_steps, max_epochs):
assert max_steps is not None or max_epochs is not None
total_steps = max_steps if max_steps else max_epochs * len(train_dataloader)
total_epochs = total_steps // len(train_dataloader) + (0 if total_steps % len(train_dataloader) == 0 else 1)
current_step = 0
for _ in range(total_epochs):
for batch in train_dataloader:
loss = training_step(batch, model)
loss.backward()
optimizer.step()
# if the total steps are reached, exit the training loop
current_step = current_step + 1
if current_step >= total_steps:
return
# if you are using an epoch-wise scheduler, call it here
lr_scheduler.step()
optimizer = nni.trace(torch.optim.Adam)(model.parameters(), lr=0.001)
lr_scheduler = nni.trace(torch.optim.lr_scheduler.LambdaLR)(optimizer, lr_lambda=lambda epoch: 1 / (epoch + 1))
evaluator = TorchEvaluator(training_func, optimizer, training_step, lr_scheduler)
.. note::
It is also worth noting that not all arguments of :class:`TorchEvaluator <nni.contrib.compression.TorchEvaluator>` must be provided.
Some compressors only require ``evaluating_func`` as they do not train the model, while some compressors only require ``training_func``.
Please refer to each compressor's documentation to check the required arguments.
It is fine to provide more arguments than a compressor needs.
A complete example can be found :githublink:`here <examples/compression/evaluator/torch_evaluator.py>`.
LightningEvaluator
------------------
:class:`LightningEvaluator <nni.contrib.compression.LightningEvaluator>` is for the users who work with PyTorchLightning.
Compared with the original PyTorch Lightning code, users only need to modify three parts:
1. Wrap the ``Optimizer`` and ``LRScheduler`` class with ``nni.trace``.
2. Wrap the ``LightningModule`` class with ``nni.trace``.
3. Wrap the ``LightningDataModule`` class with ``nni.trace``.
Please refer :class:`LightningEvaluator <nni.contrib.compression.LightningEvaluator>` for more details.
Here is an example of how to initialize a :class:`LightningEvaluator <nni.contrib.compression.LightningEvaluator>`.
.. code-block:: python
pl_trainer = nni.trace(pl.Trainer)(...)
pl_data = nni.trace(MyDataModule)(...)
evaluator = LightningEvaluator(pl_trainer, pl_data)
.. note::
In ``LightningModule.configure_optimizers``, users should use traced ``torch.optim.Optimizer`` and traced ``torch.optim._LRScheduler`` classes.
This is so that NNI can get the initialization parameters of the optimizers and lr schedulers.
.. code-block:: python
class SimpleModel(pl.LightningModule):
...
def configure_optimizers(self):
optimizer = nni.trace(torch.optim.SGD)(self.parameters(), lr=0.001)
lr_scheduler = nni.trace(ExponentialLR)(optimizer=optimizer, gamma=0.1)
return [optimizer], [lr_scheduler]
A complete example can be found :githublink:`here <examples/compression/evaluator/lightning_evaluator.py>`.
TransformersEvaluator
---------------------
:class:`TransformersEvaluator <nni.contrib.compression.TransformersEvaluator>` is for the users who work with Huggingface Transformers Trainer.
All you need to do is wrap the ``Trainer`` class with ``nni.trace``.
.. code-block:: python
import nni
from transformers.trainer import Trainer
trainer = nni.trace(Trainer)(model, training_args, ...)
from nni.contrib.compression.utils import TransformersEvaluator
evaluator = TransformersEvaluator(trainer)
Moreover, if you are using a customized optimizer or learning rate scheduler, please wrap their classes with ``nni.trace`` as well.
.. code-block:: python
optimizer = nni.trace(torch.optim.Adam)(model.parameters(), lr=0.001)
lr_scheduler = nni.trace(torch.optim.lr_scheduler.LambdaLR)(optimizer, lr_lambda=lambda epoch: 1 / (epoch + 1))


@ -0,0 +1,89 @@
Fusion Compression
==================
Fusion compression is a novel experimental feature incorporated into NNI 3.0.
As of now, NNI compressors are principally classified into three categories: pruner, quantizer, and distiller.
This new feature enables the compression of a single model by multiple compressors simultaneously.
For instance, users can apply varied pruning algorithms to different modules within the model,
along with training-aware quantization for model quantization.
Additionally, to maintain accuracy, relevant distillation techniques can be introduced.
.. Note::
NNI strives to ensure maximum compatibility among different compressors in fusion compression.
Nevertheless, in some individual scenarios it is impossible to avoid mutual interference between different compression algorithms when they modify the model.
We encourage users to integrate algorithms after acquiring a comprehensive understanding of the fundamental principles of compression methods.
If you encounter any problems or doubts that cannot be resolved while using fusion compression, you are welcome to raise an issue for discussion.
Main API
--------
To explain how fusion compression works, note that each module in the model has a corresponding wrapper in the compressor.
The wrapper stores the data required for compression.
After the original module is wrapped, whenever ``module.forward`` needs to be executed,
the compressor will instead execute ``Wrapper.forward``, which contains the simulated compression logic.
All compressors implement the class method ``from_compressor`` that can initialize a new compressor from the old ones.
The compressor initialized using this API will reuse the existing wrappers and record the preceding compression logic.
Multiple compressors can be initialized sequentially in the following format:
``fusion_compressor = Pruner.from_compressor(Quantizer.from_compressor(Distiller.from_compressor))``.
In general, the arguments of ``Compressor.from_compressor`` are mostly identical to the initialization arguments of the compressor.
The only difference is that the first argument of the initialization function is generally the model,
while the first parameter of ``from_compressor`` is typically one compressor object.
Additionally, if the fused compressor has no configured evaluator, an evaluator must be passed to ``from_compressor``.
However, if the fused compressor already has an evaluator, there is no need to pass one in again (a duplicate will be ignored).
Example
-------
Pruning + Distillation
^^^^^^^^^^^^^^^^^^^^^^
The full example can be found `here <https://github.com/microsoft/nni/tree/master/examples/compression/pqd_fuse.py>`__.
The following code is a common pipeline with pruning first and then distillation.
.. code-block:: python
...
pruner = Pruner(model, config_list, evaluator, ...)
pruner.compress(max_steps, max_epochs)
pruner.unwrap_model()
masks = pruner.get_masks()
model = ModelSpeedup(model, dummy_input, masks).speedup_model()
...
distiller = Distiller(model, config_list, evaluator, teacher_model, teacher_predict, ...)
distiller.compress(max_steps, max_epochs)
When targeting a large sparsity, the accuracy drop after pruning can become more pronounced,
requiring more effort during the fine-tuning phase. Fusing distillation with pruning can significantly mitigate this issue.
The following code combines the pruner and distiller, resulting in a fusion compression.
.. code-block:: python
...
pruner = Pruner(model, pruning_config_list, evaluator, ...)
distiller = Distiller.from_compressor(pruner, distillation_config_list, teacher_model, teacher_predict, ...)
distiller.compress(max_steps, max_epochs)
masks = pruner.get_masks()
model = ModelSpeedup(model, dummy_input, masks).speedup_model()
You can also fuse any compressors you like via ``from_compressor``.
.. code-block:: python
...
pruner_a = PrunerA(model, pruning_config_list_a, evaluator, ...)
pruner_b = PrunerB.from_compressor(pruner_a, pruning_config_list_b, ...)
pruner_c = PrunerC.from_compressor(pruner_b, pruning_config_list_c, ...)
distiller_a = DistillerA.from_compressor(pruner_c, distillation_config_list_a, teacher_model, teacher_predict, ...)
distiller_b = DistillerB.from_compressor(distiller_a, distillation_config_list_b, teacher_model, teacher_predict, ...)
distiller_b.compress(max_steps, max_epochs)
masks = pruner_c.get_masks()
model = ModelSpeedup(model, dummy_input, masks).speedup_model()


@ -0,0 +1,16 @@
Overview of NNI Model Compression (Preview)
===========================================
The NNI model compression has undergone a completely new framework design in version 3.0,
seamlessly integrating pruning, quantization, and distillation methods.
Additionally, it provides a more granular model compression configuration,
including compression granularity configuration, input/output compression configuration, and custom module compression.
Furthermore, the model speedup part of pruning uses the graph analysis scheme based on torch.fx,
which supports more op types of sparsity propagation,
as well as custom special op sparsity propagation methods and replacement logic,
further enhancing the generality and robustness of model acceleration.
The current documentation for the new version of compression may not be complete, but there is no need to worry.
The optimizations in the new version are mostly focused on the underlying framework and implementation,
and there are no significant changes to the user interface.
Instead, there are more extensions and compatibility with the configuration of the previous version.


@ -0,0 +1,33 @@
Pruning Algorithm Supported in NNI
==================================
Note that not all pruners from the previous version have been migrated to the new framework yet.
NNI plans to migrate all previously implemented pruners in NNI 3.2.
If you find that a certain old pruner has not been migrated yet, or that another pruning algorithm would be valuable,
please feel free to contact us. We will prioritize and expedite support accordingly.
.. list-table::
:header-rows: 1
:widths: auto
* - Name
- Brief Introduction of Algorithm
* - :ref:`new-level-pruner`
- Prunes the specified ratio of weight elements based on their absolute values
* - :ref:`new-l1-norm-pruner`
- Pruning output channels with the smallest L1 norm of weights (Pruning Filters for Efficient Convnets) `Reference Paper <https://arxiv.org/abs/1608.08710>`__
* - :ref:`new-l2-norm-pruner`
- Pruning output channels with the smallest L2 norm of weights
* - :ref:`new-fpgm-pruner`
- Filter Pruning via Geometric Median for Deep Convolutional Neural Networks Acceleration `Reference Paper <https://arxiv.org/abs/1811.00250>`__
* - :ref:`new-slim-pruner`
- Pruning output channels by pruning the scaling factors in BN layers (Learning Efficient Convolutional Networks through Network Slimming) `Reference Paper <https://arxiv.org/abs/1708.06519>`__
* - :ref:`new-taylor-pruner`
- Pruning filters based on the first-order Taylor expansion of the weights (Importance Estimation for Neural Network Pruning) `Reference Paper <http://jankautz.com/publications/Importance4NNPruning_CVPR19.pdf>`__
* - :ref:`new-linear-pruner`
- The sparsity ratio increases linearly across pruning rounds; in each round, a basic pruner is used to prune the model.
* - :ref:`new-agp-pruner`
- Automated gradual pruning (To prune, or not to prune: exploring the efficacy of pruning for model compression) `Reference Paper <https://arxiv.org/abs/1710.01878>`__
* - :ref:`new-movement-pruner`
- Movement Pruning: Adaptive Sparsity by Fine-Tuning `Reference Paper <https://arxiv.org/abs/2005.07683>`__


@ -0,0 +1,211 @@
Compression Setting
===================
To enhance the compression compatibility of any type of module, NNI 3.0 introduces the notion of ``setting``,
which is a pre-established template utilized to depict compression information.
The primary objective of this pre-established template is to support shorthand writing in the ``config_list``.
The list of NNI default supported module types can be accessed via this `link <https://github.com/microsoft/nni/tree/master/nni/contrib/compression/base/setting.py>`__.
Please review the ``registry`` in ``PruningSetting``, ``QuantizationSetting`` and ``DistillationSetting`` to see the supported module types and their default settings.
It should be noted that ``DistillationSetting`` automatically registers a default output setting for all module types,
which means that distilling any module's output is available by design.
Register Setting
----------------
If you discover that the module type you intend to compress is unavailable, you may register it into the appropriate compression setting.
However, if the registered module type is already present in the registry, the new setting will directly replace the old one.
In the subsequent sections, you will learn how to register settings in various compression scenarios.
You can find all the supported compression setting keys and their explanations :doc:`here <config_list>`.
If you only want to make temporary settings for a specific layer without affecting other default templates of the same module type,
please refer to the `Temporarily Setting Update`_ section below.
Assume we have a customized module type that we want to compress:
.. code-block:: python
class CustomizedModule(torch.nn.Module):
def __init__(self):
super().__init__()
self.p1 = torch.nn.Parameter(torch.rand(200, 100))
self.p2 = torch.nn.Parameter(torch.rand(200, 50))
def forward(self, x, y):
return torch.matmul(self.p1, x) + torch.matmul(self.p2, y)
Pruning Setting
^^^^^^^^^^^^^^^
Here is a default setting for pruning module weight and bias.
.. code-block:: python
default_setting = {
'weight': {
'sparse_ratio': None,
'max_sparse_ratio': None,
'min_sparse_ratio': None,
'sparse_threshold': None,
'global_group_id': None,
'dependency_group_id': None,
'granularity': 'default',
'internal_metric_block': None,
'apply_method': 'mul',
},
'bias': {
'align': {
'module_name': None,
'target_name': 'weight',
'dims': [0],
},
'apply_method': 'mul',
}
}
We can create a setting for ``CustomizedModule`` by following the above default setting.
.. code-block:: python
customized_setting = {
'p1': {
'sparse_ratio': None,
'max_sparse_ratio': None,
'min_sparse_ratio': None,
'sparse_threshold': None,
'global_group_id': None,
'dependency_group_id': None,
'granularity': [1, -1],
'internal_metric_block': None,
'apply_method': 'mul',
},
'p2': {
'sparse_ratio': None,
'max_sparse_ratio': None,
'min_sparse_ratio': None,
'sparse_threshold': None,
'global_group_id': None,
'dependency_group_id': None,
'granularity': [1, -1],
'internal_metric_block': None,
'apply_method': 'mul',
},
'_output_': {
'align': {
'module_name': None,
'target_name': 'p1',
'dims': [0],
},
'apply_method': 'mul',
'granularity': [-1, 1]
}
}
PruningSetting.register('CustomizedModule', customized_setting)
The customized setting means that ``p1`` and ``p2`` will have channel-wise masks applied on the first dimension of the parameters,
and ``_output_`` will have channel-wise masks applied on the second dimension of the output.
Instead of being generated by the pruning algorithm, the output masks are generated by aligning with the ``p1`` masks on the first dimension.
Quantization Setting
^^^^^^^^^^^^^^^^^^^^
Here is a default setting for quantizing module inputs, outputs and weight.
.. code-block:: python
default_setting = {
'_input_': {
'quant_dtype': None,
'quant_scheme': None,
'granularity': 'default',
'apply_method': 'clamp_round',
},
'weight': {
'quant_dtype': None,
'quant_scheme': None,
'granularity': 'default',
'apply_method': 'clamp_round',
},
'_output_': {
'quant_dtype': None,
'quant_scheme': None,
'granularity': 'default',
'apply_method': 'clamp_round',
}
}
Modifying the keys and registering the setting to ``QuantizationSetting`` is all you need to quantize this module.
.. code-block:: python
customized_setting = {
'_input_': {
'quant_dtype': None,
'quant_scheme': None,
'granularity': [-1, 1],
'apply_method': 'clamp_round',
},
'p1': {
'quant_dtype': None,
'quant_scheme': None,
'granularity': [1, -1],
'apply_method': 'clamp_round',
},
'p2': {
'quant_dtype': None,
'quant_scheme': None,
'granularity': [1, -1],
'apply_method': 'clamp_round',
},
'_output_': {
'quant_dtype': None,
'quant_scheme': None,
'granularity': [-1, 1],
'apply_method': 'clamp_round',
}
}
QuantizationSetting.register('CustomizedModule', customized_setting)
Temporarily Setting Update
--------------------------
Sometimes we just want to temporarily modify the setting template and don't want to make a global change.
Then we could directly write the full setting in the ``config_list`` to achieve this.
For example, if the compressed model has a conv-bn-relu pattern, then for a better pruning simulation and performance,
we want to mask the batch normalization layer on the same channels that are masked in the convolution layer.
We could temporarily make the batchnorm masks align with the convolution layer's weight masks.
.. code-block:: python
config_list = [{
'op_names': ['conv1', 'conv2'],
'granularity': 'out_channel',
}, {
'op_names': ['bn1'],
'target_settings': {
'weight': {
'align': {
'module_name': 'conv1',
'target_name': 'weight',
'dims': [0],
},
'granularity': 'out_channel',
}
}
}, {
'op_names': ['bn2'],
'target_settings': {
'weight': {
'align': {
'module_name': 'conv2',
'target_name': 'weight',
'dims': [0],
},
'granularity': 'out_channel',
}
}
}]


@ -5,4 +5,10 @@ Compression (Preview)
:hidden:
:maxdepth: 2
Overview <overview>
Enhancement <changes>
Config Specification <config_list>
Pruning <toctree_pruning>
Evaluator <evaluator>
Customize Setting <setting>
Fusion Compression <fusion_compress>


@ -5,4 +5,5 @@ Pruning
:hidden:
:maxdepth: 2
Pruner <pruner>
Best Practices <best_practices>


@ -116,6 +116,9 @@ linkcheck_ignore = [
r'https://docs\.nvidia\.com/deeplearning/',
r'https://cla\.opensource\.microsoft\.com',
r'https://www\.docker\.com/',
# remove after #5491 merged
r'https://github\.com/microsoft/nni/tree/master/examples/compression/pqd_fuse\.py',
]
# Ignore all links located in release.rst


@ -0,0 +1,12 @@
Distiller (Preview)
===================
DynamicLayerwiseDistiller
-------------------------
.. autoclass:: nni.contrib.compression.distillation.DynamicLayerwiseDistiller
Adaptive1dLayerwiseDistiller
----------------------------
.. autoclass:: nni.contrib.compression.distillation.Adaptive1dLayerwiseDistiller


@ -0,0 +1,23 @@
Evaluator
=========
.. _new-torch-evaluator:
TorchEvaluator
--------------
.. autoclass:: nni.contrib.compression.TorchEvaluator
.. _new-lightning-evaluator:
LightningEvaluator
------------------
.. autoclass:: nni.contrib.compression.LightningEvaluator
.. _new-transformers-evaluator:
TransformersEvaluator
---------------------
.. autoclass:: nni.contrib.compression.TransformersEvaluator


@ -0,0 +1,74 @@
Pruner (Preview)
================
Basic Pruner
------------
.. _new-level-pruner:
Level Pruner
^^^^^^^^^^^^
.. autoclass:: nni.contrib.compression.pruning.LevelPruner
.. _new-l1-norm-pruner:
L1 Norm Pruner
^^^^^^^^^^^^^^
.. autoclass:: nni.contrib.compression.pruning.L1NormPruner
.. _new-l2-norm-pruner:
L2 Norm Pruner
^^^^^^^^^^^^^^
.. autoclass:: nni.contrib.compression.pruning.L2NormPruner
.. _new-fpgm-pruner:
FPGM Pruner
^^^^^^^^^^^
.. autoclass:: nni.contrib.compression.pruning.FPGMPruner
.. _new-slim-pruner:
Slim Pruner
^^^^^^^^^^^
.. autoclass:: nni.contrib.compression.pruning.SlimPruner
.. _new-taylor-pruner:
Taylor FO Weight Pruner
^^^^^^^^^^^^^^^^^^^^^^^
.. autoclass:: nni.contrib.compression.pruning.TaylorPruner
Scheduled Pruners
-----------------
.. _new-linear-pruner:
Linear Pruner
^^^^^^^^^^^^^
.. autoclass:: nni.contrib.compression.pruning.LinearPruner
.. _new-agp-pruner:
AGP Pruner
^^^^^^^^^^
.. autoclass:: nni.contrib.compression.pruning.AGPPruner
Other Pruner
------------
.. _new-movement-pruner:
Movement Pruner
^^^^^^^^^^^^^^^
.. autoclass:: nni.contrib.compression.pruning.MovementPruner


@ -0,0 +1,5 @@
Pruning Speedup
===============
.. autoclass:: nni.compression.pytorch.speedup.v2.ModelSpeedup
:members:


@ -0,0 +1,11 @@
Compression API Reference (Preview)
===================================
.. toctree::
:maxdepth: 1
Pruner <pruner>
Pruning Speedup <pruning_speedup>
Distiller <distiller>
Evaluator <evaluator>
Compression Utilities <utils>


@ -0,0 +1,10 @@
Compression Utilities
=====================
.. _auto_set_denpendency_group_ids:
auto_set_denpendency_group_ids
------------------------------
.. autoclass:: nni.contrib.compression.utils.auto_set_denpendency_group_ids
:members:


@ -7,6 +7,7 @@ Python API Reference
Hyperparameter Optimization <hpo>
Neural Architecture Search <nas>
Model Compression <compression/toctree>
Model Compression (Preview) <compression_preview/toctree>
Experiment <experiment>
Mutable <mutable>
Others <others>


@ -0,0 +1,7 @@
# Examples
This folder contains examples for the new compression version (NNI 3.0).
## ./evaluator
If you want to see how to initialize the evaluator you need, please refer to this example folder.


@ -0,0 +1,87 @@
"""
Create Lightning Evaluator
==========================
To create a lightning evaluator in NNI,
you only need to make minor modifications to your existing code.
Modification in LightningModule
-------------------------------
In ``configure_optimizers``, please use ``nni.trace`` to trace the optimizer and lr scheduler classes.
Please log a ``default`` metric in ``validation_step`` or ``test_step`` if needed;
NNI may use this metric to compare which model is better.
"""
import pytorch_lightning as pl
from pytorch_lightning.loggers import TensorBoardLogger
import torch
from torchmetrics.functional import accuracy
import nni
from examples.compression.models import build_resnet18
class MyModule(pl.LightningModule):
def __init__(self) -> None:
super().__init__()
self.model = build_resnet18()
self.criterion = torch.nn.CrossEntropyLoss()
def training_step(self, batch, batch_idx):
x, y = batch
logits = self(x)
loss = self.criterion(logits, y)
self.log("train_loss", loss)
return loss
def evaluate(self, batch, stage=None):
x, y = batch
logits = self(x)
loss = self.criterion(logits, y)
preds = torch.argmax(logits, dim=1)
acc = accuracy(preds, y, 'multiclass', num_classes=10)
# If NNI need to evaluate the model, "default" metric will be used.
if stage:
self.log("default", loss, prog_bar=False)
self.log(f"{stage}_loss", loss, prog_bar=True)
self.log(f"{stage}_acc", acc, prog_bar=True)
def validation_step(self, batch, batch_idx):
self.evaluate(batch, "val")
def test_step(self, batch, batch_idx):
self.evaluate(batch, "test")
def configure_optimizers(self):
optimizer = nni.trace(torch.optim.Adam)(self.parameters(), lr=0.001)
scheduler_dict = {
"scheduler": nni.trace(torch.optim.lr_scheduler.LambdaLR)(optimizer, lr_lambda=lambda epoch: 1 / (epoch + 1)),
"interval": "epoch",
}
return {"optimizer": optimizer, "lr_scheduler": scheduler_dict}
# %%
# Init ``LightningEvaluator``
# ---------------------------
#
# Remember to use ``nni.trace`` to trace ``Trainer`` and your customized ``LightningDataModule``.
# ``MyDataModule`` below is only a placeholder; use your original LightningDataModule directly.
class MyDataModule(pl.LightningDataModule):
    def __init__(self, data_dir: str = './data'):
        super().__init__()
        self.data_dir = data_dir
from nni.contrib.compression import LightningEvaluator
pl_trainer = nni.trace(pl.Trainer)(
accelerator='auto',
devices=1,
max_epochs=3,
logger=TensorBoardLogger('./lightning_logs', name="vgg"),
)
pl_data = nni.trace(MyDataModule)(data_dir='./data')
evaluator = LightningEvaluator(pl_trainer, pl_data)


@ -0,0 +1,95 @@
"""
Create Pytorch Native Evaluator
===============================
If you are using a native PyTorch training loop to train your model, this example can help you get started quickly.
In this example, you will learn how to create a pytorch native evaluator step by step.
Prepare ``training_func``
-------------------------
``training_func`` has six required parameters.
Maybe you don't need some parameters such as ``lr_scheduler``, but you still need to keep all six parameters in the function signature.
By design, ``dataloader`` is not exposed as part of the ``training_func`` interface,
so it is necessary to create or reference a dataloader directly inside ``training_func``.
Here is a simple ``training_func``.
"""
from typing import Any, Callable
import torch
from examples.compression.models import prepare_dataloader
def training_func(model: torch.nn.Module, optimizer: torch.optim.Optimizer, training_step: Callable[[Any, torch.nn.Module], torch.Tensor],
lr_scheduler: torch.optim.lr_scheduler._LRScheduler, max_steps: int, max_epochs: int):
# create a train dataloader (and test dataloader if needs)
train_dataloader, test_dataloader = prepare_dataloader()
# deal with training duration, NNI prefers to prioritize the largest number of steps
# at least `max_steps` or `max_epochs` will be given
assert max_steps is not None or max_epochs is not None
total_steps = max_steps if max_steps else max_epochs * len(train_dataloader)
total_epochs = total_steps // len(train_dataloader) + (0 if total_steps % len(train_dataloader) == 0 else 1)
# here is a common training loop
current_step = 0
for _ in range(total_epochs):
for batch in train_dataloader:
loss = training_step(batch, model)
loss.backward()
optimizer.step()
# if the total steps are reached, exit the training loop
current_step = current_step + 1
if current_step >= total_steps:
return
# if you are using an epoch-wise scheduler, call it here
lr_scheduler.step()
# %%
# Now we have a basic training function that can generate loss by ``model`` and ``training_step``,
# optimize the model by ``optimizer`` and ``lr_scheduler``, terminate the training loop by ``max_steps`` and ``max_epochs``.
#
# Prepare ``optimizers`` and ``lr_schedulers``
# --------------------------------------------
#
# ``optimizers`` is a required parameter and ``lr_schedulers`` is an optional parameter.
# ``optimizers`` can be an optimizer instance or a list of optimizers, and ``lr_schedulers`` can be an lr scheduler instance, a list of lr schedulers, or ``None``.
#
# Note that each ``optimizer`` and ``lr_scheduler`` should be a subclass of ``torch.optim.Optimizer`` or ``torch.optim.lr_scheduler._LRScheduler``
# (``torch.optim.lr_scheduler.LRScheduler`` in ``torch >= 2.0``), and the class should be wrapped by ``nni.trace``.
# ``nni.trace`` is important because NNI needs to refresh the optimizer after compression registers new module parameters that need to be optimized.
import nni
from examples.compression.models import build_resnet18
# create a resnet18 model as an example
model = build_resnet18()
optimizer = nni.trace(torch.optim.Adam)(model.parameters(), lr=0.001)
lr_scheduler = nni.trace(torch.optim.lr_scheduler.LambdaLR)(optimizer, lr_lambda=lambda epoch: 1 / (epoch + 1))
# %%
# Now we have a traced optimizer and a traced lr scheduler.
#
# Prepare ``training_step``
# -------------------------
#
# The training step should have two required parameters, ``batch`` and ``model``;
# the return value should be a loss tensor, a list whose first element is the loss, or a dict with the key ``loss``.
import torch.nn.functional as F
def training_step(batch: Any, model: torch.nn.Module, *args, **kwargs):
output = model(batch[0])
loss = F.cross_entropy(output, batch[1])
return loss
# Init ``TorchEvaluator``
# -----------------------
from nni.contrib.compression import TorchEvaluator
evaluator = TorchEvaluator(training_func, optimizer, training_step, lr_scheduler)


@ -0,0 +1,90 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.
"""
This script is an example of how to fuse pruning and distillation.
"""
import pickle
import torch
from examples.compression.models import (
build_resnet18,
prepare_dataloader,
prepare_optimizer,
train,
training_step,
evaluate,
device
)
from nni.contrib.compression import TorchEvaluator
from nni.contrib.compression.distillation import DynamicLayerwiseDistiller
from nni.contrib.compression.pruning import TaylorPruner, AGPPruner
from nni.contrib.compression.utils import auto_set_denpendency_group_ids
from nni.compression.pytorch.speedup.v2 import ModelSpeedup
if __name__ == '__main__':
# finetuning resnet18 on Cifar10
model = build_resnet18()
optimizer = prepare_optimizer(model)
train(model, optimizer, training_step, lr_scheduler=None, max_steps=None, max_epochs=30)
_, test_loader = prepare_dataloader()
print('Original model parameter number: ', sum([param.numel() for param in model.parameters()]))
print('Original model after 30 epochs finetuning acc: ', evaluate(model, test_loader), '%')
# build a teacher model
teacher_model = build_resnet18()
teacher_model.load_state_dict(pickle.loads(pickle.dumps(model.state_dict())))
# create pruner
bn_list = [module_name for module_name, module in model.named_modules() if isinstance(module, torch.nn.BatchNorm2d)]
config_list = [{
'op_types': ['Conv2d'],
'sparse_ratio': 0.5
}, *[{
'op_names': [name],
'target_names': ['_output_'],
'target_settings': {
'_output_': {
'align': {
'module_name': name.replace('bn', 'conv') if 'bn' in name else name.replace('downsample.1', 'downsample.0'),
'target_name': 'weight',
'dims': [0],
},
'granularity': 'per_channel'
}
}
} for name in bn_list]]
dummy_input = torch.rand(8, 3, 224, 224).to(device)
config_list = auto_set_denpendency_group_ids(model, config_list, dummy_input)
optimizer = prepare_optimizer(model)
evaluator = TorchEvaluator(train, optimizer, training_step)
sub_pruner = TaylorPruner(model, config_list, evaluator, training_steps=100)
scheduled_pruner = AGPPruner(sub_pruner, interval_steps=100, total_times=30)
# create distiller
def teacher_predict(batch, teacher_model):
return teacher_model(batch[0])
config_list = [{
'op_types': ['Conv2d'],
'op_names_re': ['features.*'],
'lambda': 0.1,
'apply_method': 'mse',
}]
distiller = DynamicLayerwiseDistiller.from_compressor(scheduled_pruner, config_list, teacher_model, teacher_predict, 0.1)
# max_steps = 100 * 60: 30 rounds of AGP + Taylor pruning with 100 steps each (3000 steps), plus 3000 steps of finetuning
distiller.compress(max_steps=100 * 60, max_epochs=None)
distiller.unwrap_model()
distiller.unwrap_teacher_model()
# speed up model
masks = scheduled_pruner.get_masks()
model = ModelSpeedup(model, dummy_input, masks).speedup_model()
print('Pruned model parameter number: ', sum([param.numel() for param in model.parameters()]))
print('Pruned model without finetuning acc: ', evaluate(model, test_loader), '%')


@ -0,0 +1,157 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.
"""
This script is an example of how to fuse pruning, quantization and distillation.
"""
import pickle
import torch
from examples.compression.models import (
build_resnet18,
prepare_dataloader,
prepare_optimizer,
train,
training_step,
evaluate,
device
)
from nni.contrib.compression import TorchEvaluator
from nni.contrib.compression.base.compressor import Quantizer
from nni.contrib.compression.distillation import DynamicLayerwiseDistiller
from nni.contrib.compression.pruning import TaylorPruner, AGPPruner
from nni.contrib.compression.quantization import QATQuantizer
from nni.contrib.compression.utils import auto_set_denpendency_group_ids
from nni.compression.pytorch.speedup.v2 import ModelSpeedup
if __name__ == '__main__':
# finetuning resnet18 on Cifar10
model = build_resnet18()
optimizer = prepare_optimizer(model)
train(model, optimizer, training_step, lr_scheduler=None, max_steps=None, max_epochs=30)
_, test_loader = prepare_dataloader()
print('Original model parameter number: ', sum([param.numel() for param in model.parameters()]))
print('Original model after 30 epochs finetuning acc: ', evaluate(model, test_loader), '%')
# build a teacher model
teacher_model = build_resnet18()
teacher_model.load_state_dict(pickle.loads(pickle.dumps(model.state_dict())))
# create pruner
bn_list = [module_name for module_name, module in model.named_modules() if isinstance(module, torch.nn.BatchNorm2d)]
p_config_list = [{
'op_types': ['Conv2d'],
'sparse_ratio': 0.5
}, *[{
'op_names': [name],
'target_names': ['_output_'],
'target_settings': {
'_output_': {
'align': {
'module_name': name.replace('bn', 'conv') if 'bn' in name else name.replace('downsample.1', 'downsample.0'),
'target_name': 'weight',
'dims': [0],
},
'granularity': 'per_channel'
}
}
} for name in bn_list]]
dummy_input = torch.rand(8, 3, 224, 224).to(device)
p_config_list = auto_set_denpendency_group_ids(model, p_config_list, dummy_input)
optimizer = prepare_optimizer(model)
evaluator = TorchEvaluator(train, optimizer, training_step)
sub_pruner = TaylorPruner(model, p_config_list, evaluator, training_steps=100)
scheduled_pruner = AGPPruner(sub_pruner, interval_steps=100, total_times=30)
# create quantizer
q_config_list = [{
'op_types': ['Conv2d'],
'quant_dtype': 'int8',
'target_names': ['_input_'],
'granularity': 'per_channel'
}, {
'op_types': ['Conv2d'],
'quant_dtype': 'int8',
'target_names': ['weight'],
'granularity': 'out_channel'
}, {
'op_types': ['BatchNorm2d'],
'quant_dtype': 'int8',
'target_names': ['_output_'],
'granularity': 'per_channel'
}]
quantizer = QATQuantizer.from_compressor(scheduled_pruner, q_config_list, quant_start_step=100)
# create distiller
def teacher_predict(batch, teacher_model):
return teacher_model(batch[0])
d_config_list = [{
'op_types': ['Conv2d'],
'op_names_re': ['features.*'],
'lambda': 0.1,
'apply_method': 'mse',
}]
distiller = DynamicLayerwiseDistiller.from_compressor(quantizer, d_config_list, teacher_model, teacher_predict, 0.1)
# max_steps = 100 * 60: 30 rounds of AGP + Taylor pruning with 100 steps each (3000 steps), plus 3000 steps of finetuning
distiller.compress(max_steps=100 * 60, max_epochs=None)
distiller.unwrap_model()
distiller.unwrap_teacher_model()
# speed up model
masks = scheduled_pruner.get_masks()
speedup = ModelSpeedup(model, dummy_input, masks)
model = speedup.speedup_model()
print('Compressed model parameter number: ', sum([param.numel() for param in model.parameters()]))
print('Compressed model without finetuning & qsim acc: ', evaluate(model, test_loader), '%')
# simulate quantization
calibration_config = quantizer.get_calibration_config()
def trans(calibration_config, speedup: ModelSpeedup):
for node, node_info in speedup.node_infos.items():
if node.op == 'call_module' and node.target in calibration_config:
# assume the module only has one input and one output
input_mask = speedup.node_infos[node.args[0]].output_masks
param_mask = node_info.param_masks
output_mask = node_info.output_masks
module_cali_config = calibration_config[node.target]
if '_input_0' in module_cali_config:
reduce_dims = list(range(len(input_mask.shape)))
reduce_dims.remove(1)
idxs = torch.nonzero(input_mask.sum(reduce_dims), as_tuple=True)[0].cpu()
module_cali_config['_input_0']['scale'] = module_cali_config['_input_0']['scale'].index_select(1, idxs)
module_cali_config['_input_0']['zero_point'] = module_cali_config['_input_0']['zero_point'].index_select(1, idxs)
if '_output_0' in module_cali_config:
reduce_dims = list(range(len(output_mask.shape)))
reduce_dims.remove(1)
idxs = torch.nonzero(output_mask.sum(reduce_dims), as_tuple=True)[0].cpu()
module_cali_config['_output_0']['scale'] = module_cali_config['_output_0']['scale'].index_select(1, idxs)
module_cali_config['_output_0']['zero_point'] = module_cali_config['_output_0']['zero_point'].index_select(1, idxs)
if 'weight' in module_cali_config:
reduce_dims = list(range(len(param_mask['weight'].shape)))
reduce_dims.remove(0)
idxs = torch.nonzero(param_mask['weight'].sum(reduce_dims), as_tuple=True)[0].cpu()
module_cali_config['weight']['scale'] = module_cali_config['weight']['scale'].index_select(0, idxs)
module_cali_config['weight']['zero_point'] = module_cali_config['weight']['zero_point'].index_select(0, idxs)
if 'bias' in module_cali_config:
idxs = torch.nonzero(param_mask['bias'], as_tuple=True)[0].cpu()
module_cali_config['bias']['scale'] = module_cali_config['bias']['scale'].index_select(0, idxs)
module_cali_config['bias']['zero_point'] = module_cali_config['bias']['zero_point'].index_select(0, idxs)
return calibration_config
calibration_config = trans(calibration_config, speedup)
sim_quantizer = Quantizer(model, q_config_list)
sim_quantizer.update_calibration_config(calibration_config)
print('Compressed model acc (no finetuning, with quantization simulation): ', evaluate(model, test_loader), '%')
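# Optionally, one could finish with a short quantization-aware fine-tuning pass using the helpers
# defined above to recover accuracy; this is only an illustrative sketch and the epoch count is arbitrary.
optimizer = prepare_optimizer(model)
train(model, optimizer, training_step, lr_scheduler=None, max_steps=None, max_epochs=5)
print('Compressed model after finetuning acc: ', evaluate(model, test_loader), '%')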

View file

@ -1,6 +1,8 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.
from pathlib import Path
import torch
import torch.nn.functional as F
from torch.optim import Adam
@ -28,23 +30,23 @@ def build_resnet18():
return model.to(device)
def prepare_dataloader():
def prepare_dataloader(batch_size: int = 128):
normalize = transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010))
train_loader = DataLoader(
datasets.CIFAR10('./data', train=True, transform=transforms.Compose([
datasets.CIFAR10(Path(__file__).parent / 'data', train=True, transform=transforms.Compose([
transforms.RandomHorizontalFlip(),
transforms.RandomCrop(32, 4),
transforms.ToTensor(),
normalize,
]), download=True),
batch_size=128, shuffle=True, num_workers=8)
batch_size=batch_size, shuffle=True, num_workers=8)
test_loader = DataLoader(
datasets.CIFAR10('./data', train=False, transform=transforms.Compose([
datasets.CIFAR10(Path(__file__).parent / 'data', train=False, transform=transforms.Compose([
transforms.ToTensor(),
normalize,
])),
batch_size=128, shuffle=False, num_workers=8)
batch_size=batch_size, shuffle=False, num_workers=8)
return train_loader, test_loader
@ -67,7 +69,7 @@ def train(model: torch.nn.Module, optimizer: torch.optim.Optimizer, training_ste
for data, target in train_loader:
data, target = data.to(device), target.to(device)
optimizer.zero_grad()
loss = training_step(model, (data, target))
loss = training_step((data, target), model)
loss.backward()
optimizer.step()
count_steps += 1
@ -91,7 +93,7 @@ def evaluate(model: torch.nn.Module, test_loader):
return 100 * correct / len(test_loader.dataset)
def training_step(model: torch.nn.Module, batch):
def training_step(batch, model: torch.nn.Module):
output = model(batch[0])
loss = F.cross_entropy(output, batch[1])
return loss

View file

@ -3,7 +3,7 @@
import torch
from examples.compression.pruning.models import (
from examples.compression.models import (
build_resnet18,
prepare_dataloader,
prepare_optimizer,

View file

@ -3,7 +3,7 @@
import torch
from examples.compression.pruning.models import (
from examples.compression.models import (
build_resnet18,
prepare_dataloader,
prepare_optimizer,

View file

@ -3,7 +3,7 @@
import torch
from examples.compression.pruning.models import (
from examples.compression.models import (
build_resnet18,
prepare_dataloader,
prepare_optimizer,

View file

@ -3,7 +3,7 @@
import torch
from examples.compression.pruning.models import (
from examples.compression.models import (
build_resnet18,
prepare_dataloader,
prepare_optimizer,

View file

@ -0,0 +1,7 @@
# Examples
This folder contains a large number of examples for the old compression version.
If you find that some examples no longer work, please contact us.
This folder will be deleted around NNI 3.2.
The examples for the new compression version are under `examples/compression`.

View file

@ -209,9 +209,8 @@ class ModelSpeedup(torch.fx.Interpreter):
def propagate_originally(self):
"""
Propagate normally to get information about intermediate variables such as the shape and dtype of tensors.
Default action:
execute and store output to node_info.output_origin(intermediate variables when assigned),
and node_info.output_inplace(intermediate variables after in-place ops)
Default action: execute and store output to node_info.output_origin(intermediate variables when assigned),
and node_info.output_inplace(intermediate variables after in-place ops).
"""
self.logger.info("Propagate original variables")
for node in self.graph_module.graph.nodes:

View file

@ -350,6 +350,31 @@ class Quantizer(Compressor):
return calibration_config
def update_calibration_config(self, calibration_config: Dict[str, Dict[str, Dict[str, torch.Tensor | Any]]]):
for module_name, target_configs in calibration_config.items():
for target_name, config in target_configs.items():
assert module_name in self._module_wrappers and target_name in self._module_wrappers[module_name].quantization_target_spaces
wrapper = self._module_wrappers[module_name]
target_space = wrapper.quantization_target_spaces[target_name]
# NOTE: try to auto get the device of the current module
try:
device = next(wrapper.parameters()).device
except StopIteration:
try:
device = next(wrapper.buffers()).device
except StopIteration:
if target_space.scale is not None:
device = target_space.scale.device
else:
# NOTE: this will have risk in model parallel
device = next(self.bound_model.parameters()).device
config = tree_map(lambda t: t.to(device) if isinstance(t, torch.Tensor) else t, config)
target_space.scale = config['scale']
target_space.zero_point = config['zero_point']
assert target_space.quant_bits == config['quant_bits']
assert target_space.quant_dtype == config['quant_dtype']
assert target_space.quant_scheme == config['quant_scheme']
def patch_optimizer_param_group(self):
module_name_param_dict = {}
for module_name, _ in self._target_spaces.items():
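# A minimal usage sketch for the new update_calibration_config (the variable names below are
# illustrative, not part of this file): transfer calibration statistics from a trained quantizer
# into a freshly constructed one on the same model.
calibration_config = trained_quantizer.get_calibration_config()
new_quantizer = Quantizer(model, config_list)
new_quantizer.update_calibration_config(calibration_config)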

View file

@ -166,20 +166,23 @@ class DynamicLayerwiseDistiller(TeacherModelBasedDistiller):
config_list
Config list to configure how to distill.
Common keys please refer :doc:`Compression Config Specification </compression/compression_config_list>`.
Specific keys:
- 'lambda': By default, 1.
This is a scaling factor to control the loss scale, the final loss used during training is
``(origin_loss_lambda * origin_loss + sum(lambda_i * distill_loss_i))``.
Here ``i`` represents the ``i-th`` distillation target.
The higher the value of lambda, the greater the contribution of the corresponding distillation target to the loss.
- 'link': By default, 'auto'.
'auto' or a teacher module name or a list of teacher module names,
the module name(s) of teacher module(s) will align with student module(s) configured in this config.
If 'auto' is set, will use student module name as the link,
usually requires the teacher model and the student model to be isomorphic.
- 'apply_method': By default, 'mse'.
'mse' and 'kl' are supported right now. 'mse' means the MSE loss, usually used to distill hidden states.
'kl' means the KL loss, usually used to distill logits.
* 'lambda': By default, 1.
This is a scaling factor to control the loss scale, the final loss used during training is
``(origin_loss_lambda * origin_loss + sum(lambda_i * distill_loss_i))``.
Here ``i`` represents the ``i-th`` distillation target.
The higher the value of lambda, the greater the contribution of the corresponding distillation target to the loss.
* 'link': By default, 'auto'.
'auto' or a teacher module name or a list of teacher module names,
the module name(s) of teacher module(s) will align with student module(s) configured in this config.
If 'auto' is set, the student module name will be used as the link,
which usually requires the teacher model and the student model to be isomorphic.
* 'apply_method': By default, 'mse'.
'mse' and 'kl' are supported right now. 'mse' means the MSE loss, usually used to distill hidden states.
'kl' means the KL loss, usually used to distill logits.
evaluator
{evaluator_docstring}
teacher_model
@ -239,20 +242,23 @@ class Adaptive1dLayerwiseDistiller(TeacherModelBasedDistiller):
config_list
Config list to configure how to distill.
Common keys please refer :doc:`Compression Config Specification </compression/compression_config_list>`.
Specific keys:
- 'lambda': By default, 1.
This is a scaling factor to control the loss scale, the final loss used during training is
``(origin_loss_lambda * origin_loss + sum(lambda_i * distill_loss_i))``.
Here ``i`` represents the ``i-th`` distillation target.
The higher the value of lambda, the greater the contribution of the corresponding distillation target to the loss.
- 'link': By default, 'auto'.
'auto' or a teacher module name or a list of teacher module names,
the module name(s) of teacher module(s) will align with student module(s) configured in this config.
If 'auto' is set, will use student module name as the link,
usually requires the teacher model and the student model to be isomorphic.
- 'apply_method': By default, 'mse'.
'mse' and 'kl' are supported right now. 'mse' means the MSE loss, usually used to distill hidden states.
'kl' means the KL loss, usually used to distill logits.
* 'lambda': By default, 1.
This is a scaling factor to control the loss scale, the final loss used during training is
``(origin_loss_lambda * origin_loss + sum(lambda_i * distill_loss_i))``.
Here ``i`` represents the ``i-th`` distillation target.
The higher the value of lambda, the greater the contribution of the corresponding distillation target to the loss.
* 'link': By default, 'auto'.
'auto' or a teacher module name or a list of teacher module names,
the module name(s) of teacher module(s) will align with student module(s) configured in this config.
If 'auto' is set, the student module name will be used as the link,
which usually requires the teacher model and the student model to be isomorphic.
* 'apply_method': By default, 'mse'.
'mse' and 'kl' are supported right now. 'mse' means the MSE loss, usually used to distill hidden states.
'kl' means the KL loss, usually used to distill logits.
evaluator
{evaluator_docstring}
teacher_model
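# A hedged illustration of a distillation config_list using the keys documented above
# (the module names and values are made up for illustration):
d_config_list = [{
    'op_types': ['Linear'],
    'lambda': 0.1,
    'link': 'auto',          # align each student module with the teacher module of the same name
    'apply_method': 'mse',   # MSE loss, typically used for hidden states
}, {
    'op_names': ['classifier'],
    'lambda': 1.0,
    'apply_method': 'kl',    # KL loss, typically used for logits
}]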

View file

@ -114,10 +114,6 @@ class LevelPruner(_NormPruner):
config_list
A list of dict, each dict configure which module need to be pruned, and how to prune.
Please refer :doc:`Compression Config Specification </compression/compression_config_list>` for more information.
Examples
--------
TODO
"""
p = 1
@ -144,7 +140,8 @@ class L1NormPruner(_NormPruner):
Examples
--------
TODO
Please refer to
:githublink:`examples/compression/pruning/norm_pruning.py <examples/compression/pruning/norm_pruning.py>`.
"""
p = 1
@ -165,7 +162,8 @@ class L2NormPruner(_NormPruner):
Examples
--------
TODO
Please refer to
:githublink:`examples/compression/pruning/norm_pruning.py <examples/compression/pruning/norm_pruning.py>`.
"""
p = 2
@ -188,7 +186,8 @@ class FPGMPruner(_NormPruner):
Examples
--------
TODO
Please refer to
:githublink:`examples/compression/pruning/norm_pruning.py <examples/compression/pruning/norm_pruning.py>`.
"""
p = 2

View file

@ -62,7 +62,8 @@ class MovementPruner(ScheduledPruner):
Examples
--------
TODO
Please refer to
:githublink:`examples/tutorials/new_pruning_bert_glue.py <examples/tutorials/new_pruning_bert_glue.py>`.
""".format(evaluator_docstring=_EVALUATOR_DOCSTRING)
@overload

View file

@ -149,7 +149,8 @@ class LinearPruner(_ComboPruner):
Examples
--------
TODO
Please refer to
:githublink:`examples/compression/pruning/scheduled_pruning.py <examples/compression/pruning/scheduled_pruning.py>`.
""".format(evaluator_docstring=_EVALUATOR_DOCSTRING)
def update_sparse_goals(self, current_times: int):
@ -180,7 +181,8 @@ class AGPPruner(_ComboPruner):
Examples
--------
TODO
Please refer to
:githublink:`examples/compression/pruning/scheduled_pruning.py <examples/compression/pruning/scheduled_pruning.py>`.
""".format(evaluator_docstring=_EVALUATOR_DOCSTRING)
def update_sparse_goals(self, current_times: int):

View file

@ -43,6 +43,11 @@ class SlimPruner(Pruner):
An integer to control steps of training the model and scale factors. Masks will be generated after ``training_steps``.
regular_scale
``regular_scale`` controls the scale factors' penalty.
Examples
--------
Please refer to
:githublink:`examples/compression/pruning/slim_pruning.py <examples/compression/pruning/slim_pruning.py>`.
""".format(evaluator_docstring=_EVALUATOR_DOCSTRING)
@overload

View file

@ -43,7 +43,8 @@ class TaylorPruner(Pruner):
Examples
--------
TODO
Please refer to
:githublink:`examples/compression/pruning/taylor_pruning.py <examples/compression/pruning/taylor_pruning.py>`.
""".format(evaluator_docstring=_EVALUATOR_DOCSTRING)
@overload

View file

@ -6,11 +6,11 @@ _EVALUATOR_DOCSTRING = r"""NNI will use the evaluator to intervene in the model
so as to perform training-aware model compression.
All training-aware model compression will use the evaluator as the entry for intervention training in the future.
Usually you just need to wrap some classes with ``nni.trace`` or package the training process as a function to initialize the evaluator.
Please refer :doc:`/compression/compression_evaluator` for a full tutorial on how to initialize a ``evaluator``.
Please refer :doc:`/compression_preview/evaluator` for a full tutorial on how to initialize an ``evaluator``.
The following are two simple examples, if you use native pytorch, please refer to :class:`nni.contrib.compression.TorchEvaluator`,
if you use pytorch_lightning, please refer to :class:`nni.contrib.compression.LightningEvaluator`,
if you use huggingface transformer trainer, please refer to :class:`nni.contrib.compression.TransformersEvaluator`::
The following are two simple examples; if you use native pytorch, please refer to :ref:`new-torch-evaluator`;
if you use pytorch_lightning, please refer to :ref:`new-lightning-evaluator`;
if you use huggingface transformer trainer, please refer to :ref:`new-transformers-evaluator`::
# LightningEvaluator example
import pytorch_lightning

View file

@ -701,19 +701,13 @@ class TorchEvaluator(Evaluator):
training_step: Callable[[Any, Any], torch.Tensor],
lr_schedulers: _LRScheduler | None = None, max_steps: int | None = None,
max_epochs: int | None = None, *args, **kwargs):
...
total_epochs = max_epochs if max_epochs else 20
total_steps = max_steps if max_steps else 1000000
current_steps = 0
...
for epoch in range(total_epochs):
...
if current_steps >= total_steps:
return
@ -732,13 +726,16 @@ class TorchEvaluator(Evaluator):
training_step
A callable function whose first argument should be ``batch`` and whose output should contain the loss.
Three kinds of outputs are supported: a single loss, a tuple whose first element is the loss, or a dict containing the key ``loss``.
.. code-block:: python
def training_step(batch, model, ...):
inputs, labels = batch
output = model(inputs)
...
loss = loss_func(output, labels)
return loss
lr_schedulers
Optional. A single traced lr_scheduler instance or a list of traced lr_schedulers by ``nni.trace``.
For the same reason as with ``optimizers``, NNI needs the traced lr_scheduler to re-initialize it.