Mirror of https://github.com/microsoft/nni.git
[Compression] Quantization Preview Doc (#5516)
Co-authored-by: J-shang <shangning128@163.com>
This commit is contained in:
Parent
f6f0f65941
Commit
3344d38b79
|
@ -7,5 +7,4 @@ format for model weights is 32-bit float, or FP32. Many research works have demo
|
|||
can be represented using 8-bit integers without significant loss in accuracy. Quantization to even lower bit-widths, such as 4/2/1 bits,
|
||||
is an active field of research.
|
||||
|
||||
A quantizer is a quantization algorithm implementation in NNI.
|
||||
You can also :doc:`create your own quantizer <../tutorials/quantization_customize>` using NNI model compression interface.
|
||||
A quantizer is a quantization algorithm implementation in NNI.
|
|
@ -89,6 +89,11 @@ As a result, most operations that couldn't be traced in the previous pruning spe
|
|||
In addition to ``concrete_trace``, users who already have a good ``torch.fx.GraphModule`` for their traced model can use it directly.
|
||||
Furthermore, the new pruning speedup supports customized mask propagation logic and module replacement methods to handle the speedup of various customized modules.
|
||||
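For example, a ``torch.fx.GraphModule`` can be produced with the standard ``torch.fx.symbolic_trace`` API. A minimal sketch (the model here is purely illustrative):

.. code-block:: python

    import torch
    import torch.fx

    class TinyModel(torch.nn.Module):  # illustrative model
        def forward(self, x):
            return torch.relu(x)

    # symbolic_trace yields a torch.fx.GraphModule that can be used directly
    graph_module = torch.fx.symbolic_trace(TinyModel())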
|
||||
Model Fusion
|
||||
------------
|
||||
|
||||
Model fusion is supported in NNI 3.0. You can use it easily by setting ``fuse_names`` in each config in the ``config_list``.
|
||||
Please refer to :doc:`Module Fusion <./module_fusion>` for more details.
|
||||
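For example, folding ``batchnorm1`` into ``conv1`` (module names follow the example in :doc:`Module Fusion <./module_fusion>`) can be sketched as:

.. code-block:: python

    config_list = [{
        'op_names': ['conv1'],
        'target_names': ['_input_', 'weight', '_output_'],
        'quant_dtype': 'int8',
        'quant_scheme': 'affine',
        'granularity': 'default',
        'fuse_names': [('conv1', 'batchnorm1')]  # fold batchnorm1 into conv1
    }]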
|
||||
Distillation
|
||||
------------
|
||||
|
|
|
@ -278,6 +278,13 @@ quant_scheme
|
|||
``affine`` or ``symmetric``. If this key is not set, the quantization scheme will be chosen by the quantizer;
|
||||
most quantizers apply ``symmetric`` quantization.
|
||||
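The difference between the two schemes can be sketched as follows (an illustrative int8 computation, not NNI's internal implementation):

.. code-block:: python

    import torch

    x = torch.randn(100)

    # affine: map the asymmetric range [min, max] onto [-128, 127]
    affine_scale = (x.max() - x.min()) / 255
    affine_zero_point = torch.round(-128 - x.min() / affine_scale)

    # symmetric: map [-max|x|, max|x|] onto [-127, 127], zero point fixed at 0
    symmetric_scale = x.abs().max() / 127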
|
||||
fuse_names
|
||||
^^^^^^^^^^
|
||||
|
||||
``List[Tuple[str, ...]]``. Optional parameter; each tuple defines a group of modules that will be fused into the first module of the tuple.
|
||||
Each element in the tuple is a module name in the model.
|
||||
Note that the first module name in each tuple should appear in ``op_names`` or ``op_names_re``.
|
||||
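For instance (module names are illustrative), the fragment below fuses a conv/bn/relu triple into ``conv1`` and a linear/relu pair into ``fc1``:

.. code-block:: python

    fuse_names = [
        ('conv1', 'batchnorm1', 'relu1'),  # 3-element tuple: bn and relu fold into conv1
        ('fc1', 'relu3'),                  # 2-element tuple: relu folds into fc1
    ]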
|
||||
granularity
|
||||
^^^^^^^^^^^
|
||||
|
||||
|
|
|
@ -0,0 +1,53 @@
|
|||
Module Fusion
|
||||
=============
|
||||
|
||||
Module fusion is a new feature of the quantizers in NNI 3.0. This feature can fuse the specified
|
||||
sub-modules during the simulated quantization process to align with the inference stage of model deployment,
|
||||
reducing the error between the simulated quantization and inference stages.
|
||||
|
||||
Users can use this feature by directly defining ``fuse_names`` in each config of the ``config_list``.
|
||||
``fuse_names`` is an optional parameter of type ``List[Tuple[str, ...]]``. Each tuple specifies the names of the modules
|
||||
to be fused for the current config. Each tuple has 2 or 3 elements, and the first module
|
||||
in each tuple is the fused module, which contains all the operations of all the modules in the tuple.
|
||||
The rest of the modules will be replaced by ``Identity`` during the quantization process. Here is an example:
|
||||
|
||||
.. code-block:: python
|
||||
|
||||
import torch
import torch.nn.functional as F

# define the MNIST model
|
||||
class Mnist(torch.nn.Module):
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
self.conv1 = torch.nn.Conv2d(1, 20, 5, 1)
|
||||
self.conv2 = torch.nn.Conv2d(20, 50, 5, 1)
|
||||
self.fc1 = torch.nn.Linear(4 * 4 * 50, 500)
|
||||
self.fc2 = torch.nn.Linear(500, 10)
|
||||
self.relu1 = torch.nn.ReLU6()
|
||||
self.relu2 = torch.nn.ReLU6()
|
||||
self.relu3 = torch.nn.ReLU6()
|
||||
self.max_pool1 = torch.nn.MaxPool2d(2, 2)
|
||||
self.max_pool2 = torch.nn.MaxPool2d(2, 2)
|
||||
self.batchnorm1 = torch.nn.BatchNorm2d(20)
|
||||
|
||||
def forward(self, x):
|
||||
x = self.relu1(self.batchnorm1(self.conv1(x)))
|
||||
x = self.max_pool1(x)
|
||||
x = self.relu2(self.conv2(x))
|
||||
x = self.max_pool2(x)
|
||||
x = x.view(-1, 4 * 4 * 50)
|
||||
x = self.relu3(self.fc1(x))
|
||||
x = self.fc2(x)
|
||||
return F.log_softmax(x, dim=1)
|
||||
|
||||
# define the config list
|
||||
config_list = [
|
||||
{
|
||||
'target_names': ['_input_', 'weight', '_output_'],
|
||||
'op_names': ['conv1'],
|
||||
'quant_dtype': 'int8',
|
||||
'quant_scheme': 'affine',
|
||||
'granularity': 'default',
|
||||
'fuse_names': [("conv1", "batchnorm1")]
|
||||
}]
|
||||
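In this example, the operations of ``batchnorm1`` are folded into ``conv1`` during simulated quantization, and the standalone ``batchnorm1`` module is replaced by ``Identity``.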
|
||||
|
||||
|
|
@ -0,0 +1,11 @@
|
|||
Overview of NNI Model Quantization
|
||||
==================================
|
||||
|
||||
Quantization refers to compressing models by reducing the number of bits required to represent weights or activations,
|
||||
which can reduce the computations and the inference time. In the context of deep neural networks, the major numerical
|
||||
format for model weights is 32-bit float, or FP32. Many research works have demonstrated that weights and activations
|
||||
can be represented using 8-bit integers without significant loss in accuracy. Quantization to even lower bit-widths, such as 4/2/1 bits,
|
||||
is an active field of research.
|
||||
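As a rough illustration of the idea (a minimal sketch, not NNI's implementation), 8-bit affine quantization maps a float tensor onto integers through a scale and a zero point:

.. code-block:: python

    import torch

    x = torch.randn(4)
    scale = (x.max() - x.min()) / 255                 # spread [min, max] over 256 levels
    zero_point = torch.round(-128 - x.min() / scale)

    q = torch.clamp(torch.round(x / scale) + zero_point, -128, 127)  # int8 code
    x_hat = (q - zero_point) * scale                                 # dequantized value
    # the reconstruction error |x - x_hat| is bounded by roughly scale / 2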
|
||||
A quantizer is a quantization algorithm implementation in NNI.
|
||||
You can also :doc:`create your own quantizer <../tutorials/quantization_customize>` using NNI model compression interface.
|
|
@ -0,0 +1,8 @@
|
|||
Quickstart
|
||||
==========
|
||||
|
||||
.. toctree::
|
||||
:hidden:
|
||||
:maxdepth: 2
|
||||
|
||||
Quantization Quickstart </tutorials/quantization_quick_start>
|
|
@ -0,0 +1,22 @@
|
|||
Quantizer in NNI
|
||||
================
|
||||
|
||||
NNI implements the main part of each quantization algorithm as a quantizer. All quantizers are implemented as closely as possible to what is described in the paper (if one exists).
|
||||
The following table provides a brief introduction to the quantizers implemented in NNI; click the links in the table to view more detailed introductions and use cases.
|
||||
|
||||
.. list-table::
|
||||
:header-rows: 1
|
||||
:widths: auto
|
||||
|
||||
* - Name
|
||||
- Brief Introduction of Algorithm
|
||||
* - :ref:`NewQATQuantizer`
|
||||
- Quantization and Training of Neural Networks for Efficient Integer-Arithmetic-Only Inference. `Reference Paper <http://openaccess.thecvf.com/content_cvpr_2018/papers/Jacob_Quantization_and_Training_CVPR_2018_paper.pdf>`__
|
||||
* - :ref:`NewDorefaQuantizer`
|
||||
- DoReFa-Net: Training Low Bitwidth Convolutional Neural Networks with Low Bitwidth Gradients. `Reference Paper <https://arxiv.org/abs/1606.06160>`__
|
||||
* - :ref:`NewBNNQuantizer`
|
||||
- Binarized Neural Networks: Training Deep Neural Networks with Weights and Activations Constrained to +1 or -1. `Reference Paper <https://arxiv.org/abs/1602.02830>`__
|
||||
* - :ref:`NewLsqQuantizer`
|
||||
- Learned step size quantization. `Reference Paper <https://arxiv.org/pdf/1902.08153.pdf>`__
|
||||
* - :ref:`NewPtqQuantizer`
|
||||
- Post-training quantization. Collects quantization information during calibration with observers.
|
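All of these quantizers are constructed in a similar way. For example, mirroring the quickstart tutorial (``model``, ``config_list`` and ``evaluator`` are prepared as shown there):

.. code-block:: python

    from nni.contrib.compression.quantization import QATQuantizer

    # other quantizers follow a similar construction pattern
    quantizer = QATQuantizer(model, config_list, evaluator, len(train_dataloader))
    _, calibration_config = quantizer.compress(None, max_epochs=5)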
|
@ -9,6 +9,9 @@ Compression (Preview)
|
|||
Enhancement <changes>
|
||||
Config Specification <config_list>
|
||||
Pruning <toctree_pruning>
|
||||
Quantization <toctree_quantization>
|
||||
Evaluator <evaluator>
|
||||
Customize Setting <setting>
|
||||
Fusion Compression <fusion_compress>
|
||||
Module Fusion <module_fusion>
|
||||
|
||||
|
|
|
@ -0,0 +1,11 @@
|
|||
Quantization
|
||||
============
|
||||
|
||||
.. toctree::
|
||||
:hidden:
|
||||
:maxdepth: 2
|
||||
|
||||
Overview <quantization>
|
||||
Quickstart <quantization_quick_start>
|
||||
Quantizer <quantizer>
|
||||
SpeedUp </tutorials/quantization_speedup>
|
|
@ -0,0 +1,37 @@
|
|||
Quantizer
|
||||
=========
|
||||
|
||||
.. _NewQATQuantizer:
|
||||
|
||||
QAT Quantizer
|
||||
^^^^^^^^^^^^^
|
||||
|
||||
.. autoclass:: nni.contrib.compression.quantization.QATQuantizer
|
||||
|
||||
.. _NewDorefaQuantizer:
|
||||
|
||||
DoReFa Quantizer
|
||||
^^^^^^^^^^^^^^^^
|
||||
|
||||
.. autoclass:: nni.contrib.compression.quantization.DoReFaQuantizer
|
||||
|
||||
.. _NewBNNQuantizer:
|
||||
|
||||
BNN Quantizer
|
||||
^^^^^^^^^^^^^
|
||||
|
||||
.. autoclass:: nni.contrib.compression.quantization.BNNQuantizer
|
||||
|
||||
.. _NewLsqQuantizer:
|
||||
|
||||
LSQ Quantizer
|
||||
^^^^^^^^^^^^^
|
||||
|
||||
.. autoclass:: nni.contrib.compression.quantization.LsqQuantizer
|
||||
|
||||
.. _NewPtqQuantizer:
|
||||
|
||||
PTQ Quantizer
|
||||
^^^^^^^^^^^^^
|
||||
|
||||
.. autoclass:: nni.contrib.compression.quantization.PtqQuantizer
|
|
@ -9,3 +9,4 @@ Compression API Reference (Preview)
|
|||
Distiller <distiller>
|
||||
Evaluator <evaluator>
|
||||
Compression Utilities <utils>
|
||||
Quantizer <quantizer>
|
||||
|
|
Binary data
docs/source/tutorials/images/thumb/sphx_glr_quantization_quick_start_thumb.png
Normal file
Binary data
docs/source/tutorials/images/thumb/sphx_glr_quantization_quick_start_thumb.png
Normal file
Binary file not shown.
After Width: | Height: | Size: 18 KiB |
|
@ -95,6 +95,23 @@ Tutorials
|
|||
</div>
|
||||
|
||||
|
||||
.. raw:: html
|
||||
|
||||
<div class="sphx-glr-thumbcontainer" tooltip="Here is a four-minute video to get you started with model quantization.">
|
||||
|
||||
.. only:: html
|
||||
|
||||
.. image:: /tutorials/images/thumb/sphx_glr_quantization_quick_start_thumb.png
|
||||
:alt:
|
||||
|
||||
:ref:`sphx_glr_tutorials_quantization_quick_start.py`
|
||||
|
||||
.. raw:: html
|
||||
|
||||
<div class="sphx-glr-thumbnail-title">Quantization Quickstart</div>
|
||||
</div>
|
||||
|
||||
|
||||
.. raw:: html
|
||||
|
||||
<div class="sphx-glr-thumbcontainer" tooltip="Quantization algorithms quantize a deep learning model usually in a simulated way. That is, to ...">
|
||||
|
@ -210,6 +227,7 @@ Tutorials
|
|||
/tutorials/pruning_quick_start_mnist
|
||||
/tutorials/quantization_customize
|
||||
/tutorials/nasbench_as_dataset
|
||||
/tutorials/quantization_quick_start
|
||||
/tutorials/quantization_speedup
|
||||
/tutorials/hello_nas
|
||||
/tutorials/quantization_bert_glue
|
||||
|
|
|
@ -0,0 +1,151 @@
|
|||
{
|
||||
"cells": [
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {
|
||||
"collapsed": false
|
||||
},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"%matplotlib inline"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"\n# Quantization Quickstart\n\nHere is a four-minute video to get you started with model quantization.\n\n.. youtube:: MSfV7AyfiA4\n :align: center\n\nQuantization reduces model size and speeds up inference time by reducing the number of bits required to represent weights or activations.\n\nIn NNI, both post-training quantization algorithms and quantization-aware training algorithms are supported.\nHere we use `QATQuantizer` as an example to show the usage of quantization in NNI.\n"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"## Preparation\n\nIn this tutorial, we use a simple model and pre-train on MNIST dataset.\nIf you are familiar with defining a model and training in pytorch, you can skip directly to `Quantizing Model`_.\n\n"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {
|
||||
"collapsed": false
|
||||
},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"import functools\nimport time\nfrom typing import Callable, Union, List, Dict, Tuple, Union\n\nimport torch\nimport torch.nn.functional as F\nfrom torch.optim import Optimizer, SGD\nfrom torch.utils.data import DataLoader\nfrom torch import Tensor\n\nfrom nni.common.types import SCHEDULER"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"Define the model\n\n"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {
|
||||
"collapsed": false
|
||||
},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"class Mnist(torch.nn.Module):\n def __init__(self):\n super().__init__()\n self.conv1 = torch.nn.Conv2d(1, 20, 5, 1)\n self.conv2 = torch.nn.Conv2d(20, 50, 5, 1)\n self.fc1 = torch.nn.Linear(4 * 4 * 50, 500)\n self.fc2 = torch.nn.Linear(500, 10)\n self.relu1 = torch.nn.ReLU6()\n self.relu2 = torch.nn.ReLU6()\n self.relu3 = torch.nn.ReLU6()\n self.max_pool1 = torch.nn.MaxPool2d(2, 2)\n self.max_pool2 = torch.nn.MaxPool2d(2, 2)\n self.batchnorm1 = torch.nn.BatchNorm2d(20)\n\n def forward(self, x):\n x = self.relu1(self.batchnorm1(self.conv1(x)))\n x = self.max_pool1(x)\n x = self.relu2(self.conv2(x))\n x = self.max_pool2(x)\n x = x.view(-1, 4 * 4 * 50)\n x = self.relu3(self.fc1(x))\n x = self.fc2(x)\n return F.log_softmax(x, dim=1)"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"Create training and evaluation dataloader\n\n"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {
|
||||
"collapsed": false
|
||||
},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"from torch.utils.data import DataLoader\nfrom torchvision import transforms\nfrom torchvision.datasets import MNIST\n\nMNIST(root='data/mnist', train=True, download=True)\nMNIST(root='data/mnist', train=False, download=True)\ntransform = transforms.Compose([transforms.ToTensor(), transforms.Normalize((0.1307,), (0.3081,))])\nmnist_train = MNIST(root='data/mnist', train=True, transform=transform)\ntrain_dataloader = DataLoader(mnist_train, batch_size=64)\nmnist_test = MNIST(root='data/mnist', train=False, transform=transform)\ntest_dataloader = DataLoader(mnist_test, batch_size=1000)"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"Define training and evaluation functions\n\n"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {
|
||||
"collapsed": false
|
||||
},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"device = \"cuda:0\" if torch.cuda.is_available() else \"cpu\"\n\n\ndef training_step(batch, model) -> Tensor:\n x, y = batch[0].to(device), batch[1].to(device)\n logits = model(x)\n loss: torch.Tensor = F.nll_loss(logits, y)\n return loss\n\n\ndef training_model(model: torch.nn.Module, optimizer: Optimizer, training_step: Callable, scheduler: Union[SCHEDULER, None] = None,\n max_steps: Union[int, None] = None, max_epochs: Union[int, None] = None):\n model.train()\n max_epochs = max_epochs if max_epochs else 1 if max_steps is None else 100\n current_steps = 0\n\n # training\n for epoch in range(max_epochs):\n print(f'Epoch {epoch} start!')\n for batch in train_dataloader:\n optimizer.zero_grad()\n loss = training_step(batch, model)\n loss.backward()\n optimizer.step()\n current_steps += 1\n if max_steps and current_steps == max_steps:\n return\n if scheduler is not None:\n scheduler.step()\n\n\ndef evaluating_model(model: torch.nn.Module):\n model.eval()\n # testing\n correct = 0\n with torch.no_grad():\n for x, y in test_dataloader:\n x, y = x.to(device), y.to(device)\n logits = model(x)\n preds = torch.argmax(logits, dim=1)\n correct += preds.eq(y.view_as(preds)).sum().item()\n return correct / len(mnist_test)"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"Pre-train and evaluate the model on MNIST dataset\n\n"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {
|
||||
"collapsed": false
|
||||
},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"model = Mnist().to(device)\noptimizer = SGD(model.parameters(), lr=0.01, momentum=0.9, weight_decay=5e-4)\n\nstart = time.time()\ntraining_model(model, optimizer, training_step, None, None, 5)\nprint(f'pure training 5 epochs: {time.time() - start}s')\nstart = time.time()\nacc = evaluating_model(model)\nprint(f'pure evaluating: {time.time() - start}s Acc.: {acc}')"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"## Quantizing Model\n\nInitialize a `config_list`.\nDetailed about how to write ``config_list`` please refer :doc:`Config Specification <../compression_preview/config_list>`.\n\n"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {
|
||||
"collapsed": false
|
||||
},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"import nni\nfrom nni.contrib.compression.quantization import QATQuantizer\nfrom nni.contrib.compression.utils import TorchEvaluator\n\n\noptimizer = nni.trace(SGD)(model.parameters(), lr=0.01, momentum=0.9, weight_decay=5e-4)\nevaluator = TorchEvaluator(training_model, optimizer, training_step) # type: ignore\n\nconfig_list = [{\n 'op_names': ['conv1', 'conv2', 'fc1', 'fc2'],\n 'target_names': ['_input_', 'weight', '_output_'],\n 'quant_dtype': 'int8',\n 'quant_scheme': 'affine',\n 'granularity': 'default',\n},{\n 'op_names': ['relu1', 'relu2', 'relu3'],\n 'target_names': ['_output_'],\n 'quant_dtype': 'int8',\n 'quant_scheme': 'affine',\n 'granularity': 'default',\n}]\n\nquantizer = QATQuantizer(model, config_list, evaluator, len(train_dataloader))\nreal_input = next(iter(train_dataloader))[0].to(device)\nquantizer.track_forward(real_input)\n\nstart = time.time()\n_, calibration_config = quantizer.compress(None, max_epochs=5)\nprint(f'pure training 5 epochs: {time.time() - start}s')\n\nprint(calibration_config)\nstart = time.time()\nacc = evaluating_model(model)\nprint(f'quantization evaluating: {time.time() - start}s Acc.: {acc}')"
|
||||
]
|
||||
}
|
||||
],
|
||||
"metadata": {
|
||||
"kernelspec": {
|
||||
"display_name": "Python 3",
|
||||
"language": "python",
|
||||
"name": "python3"
|
||||
},
|
||||
"language_info": {
|
||||
"codemirror_mode": {
|
||||
"name": "ipython",
|
||||
"version": 3
|
||||
},
|
||||
"file_extension": ".py",
|
||||
"mimetype": "text/x-python",
|
||||
"name": "python",
|
||||
"nbconvert_exporter": "python",
|
||||
"pygments_lexer": "ipython3",
|
||||
"version": "3.9.16"
|
||||
}
|
||||
},
|
||||
"nbformat": 4,
|
||||
"nbformat_minor": 0
|
||||
}
|
|
@ -0,0 +1,177 @@
|
|||
"""
|
||||
Quantization Quickstart
|
||||
=======================
|
||||
|
||||
Here is a four-minute video to get you started with model quantization.
|
||||
|
||||
.. youtube:: MSfV7AyfiA4
|
||||
:align: center
|
||||
|
||||
Quantization reduces model size and speeds up inference time by reducing the number of bits required to represent weights or activations.
|
||||
|
||||
In NNI, both post-training quantization algorithms and quantization-aware training algorithms are supported.
|
||||
Here we use `QATQuantizer` as an example to show the usage of quantization in NNI.
|
||||
"""
|
||||
|
||||
# %%
|
||||
# Preparation
|
||||
# -----------
|
||||
#
|
||||
# In this tutorial, we use a simple model and pre-train it on the MNIST dataset.
|
||||
# If you are familiar with defining a model and training in PyTorch, you can skip directly to `Quantizing Model`_.
|
||||
|
||||
import functools
|
||||
import time
|
||||
from typing import Callable, Union, List, Dict, Tuple
|
||||
|
||||
import torch
|
||||
import torch.nn.functional as F
|
||||
from torch.optim import Optimizer, SGD
|
||||
from torch.utils.data import DataLoader
|
||||
from torch import Tensor
|
||||
|
||||
from nni.common.types import SCHEDULER
|
||||
|
||||
|
||||
# %%
|
||||
# Define the model
|
||||
class Mnist(torch.nn.Module):
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
self.conv1 = torch.nn.Conv2d(1, 20, 5, 1)
|
||||
self.conv2 = torch.nn.Conv2d(20, 50, 5, 1)
|
||||
self.fc1 = torch.nn.Linear(4 * 4 * 50, 500)
|
||||
self.fc2 = torch.nn.Linear(500, 10)
|
||||
self.relu1 = torch.nn.ReLU6()
|
||||
self.relu2 = torch.nn.ReLU6()
|
||||
self.relu3 = torch.nn.ReLU6()
|
||||
self.max_pool1 = torch.nn.MaxPool2d(2, 2)
|
||||
self.max_pool2 = torch.nn.MaxPool2d(2, 2)
|
||||
self.batchnorm1 = torch.nn.BatchNorm2d(20)
|
||||
|
||||
def forward(self, x):
|
||||
x = self.relu1(self.batchnorm1(self.conv1(x)))
|
||||
x = self.max_pool1(x)
|
||||
x = self.relu2(self.conv2(x))
|
||||
x = self.max_pool2(x)
|
||||
x = x.view(-1, 4 * 4 * 50)
|
||||
x = self.relu3(self.fc1(x))
|
||||
x = self.fc2(x)
|
||||
return F.log_softmax(x, dim=1)
|
||||
|
||||
|
||||
# %%
|
||||
# Create training and evaluation dataloader
|
||||
from torch.utils.data import DataLoader
|
||||
from torchvision import transforms
|
||||
from torchvision.datasets import MNIST
|
||||
|
||||
MNIST(root='data/mnist', train=True, download=True)
|
||||
MNIST(root='data/mnist', train=False, download=True)
|
||||
transform = transforms.Compose([transforms.ToTensor(), transforms.Normalize((0.1307,), (0.3081,))])
|
||||
mnist_train = MNIST(root='data/mnist', train=True, transform=transform)
|
||||
train_dataloader = DataLoader(mnist_train, batch_size=64)
|
||||
mnist_test = MNIST(root='data/mnist', train=False, transform=transform)
|
||||
test_dataloader = DataLoader(mnist_test, batch_size=1000)
|
||||
|
||||
|
||||
# %%
|
||||
# Define training and evaluation functions
|
||||
device = "cuda:0" if torch.cuda.is_available() else "cpu"
|
||||
|
||||
|
||||
def training_step(batch, model) -> Tensor:
|
||||
x, y = batch[0].to(device), batch[1].to(device)
|
||||
logits = model(x)
|
||||
loss: torch.Tensor = F.nll_loss(logits, y)
|
||||
return loss
|
||||
|
||||
|
||||
def training_model(model: torch.nn.Module, optimizer: Optimizer, training_step: Callable, scheduler: Union[SCHEDULER, None] = None,
|
||||
max_steps: Union[int, None] = None, max_epochs: Union[int, None] = None):
|
||||
model.train()
|
||||
max_epochs = max_epochs if max_epochs else 1 if max_steps is None else 100
|
||||
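# note: defaults to 1 epoch when neither limit is given; caps at 100 epochs when only max_steps is set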
current_steps = 0
|
||||
|
||||
# training
|
||||
for epoch in range(max_epochs):
|
||||
print(f'Epoch {epoch} start!')
|
||||
for batch in train_dataloader:
|
||||
optimizer.zero_grad()
|
||||
loss = training_step(batch, model)
|
||||
loss.backward()
|
||||
optimizer.step()
|
||||
current_steps += 1
|
||||
if max_steps and current_steps == max_steps:
|
||||
return
|
||||
if scheduler is not None:
|
||||
scheduler.step()
|
||||
|
||||
|
||||
def evaluating_model(model: torch.nn.Module):
|
||||
model.eval()
|
||||
# testing
|
||||
correct = 0
|
||||
with torch.no_grad():
|
||||
for x, y in test_dataloader:
|
||||
x, y = x.to(device), y.to(device)
|
||||
logits = model(x)
|
||||
preds = torch.argmax(logits, dim=1)
|
||||
correct += preds.eq(y.view_as(preds)).sum().item()
|
||||
return correct / len(mnist_test)
|
||||
|
||||
|
||||
# %%
|
||||
# Pre-train and evaluate the model on MNIST dataset
|
||||
model = Mnist().to(device)
|
||||
optimizer = SGD(model.parameters(), lr=0.01, momentum=0.9, weight_decay=5e-4)
|
||||
|
||||
start = time.time()
|
||||
training_model(model, optimizer, training_step, None, None, 5)
|
||||
print(f'pure training 5 epochs: {time.time() - start}s')
|
||||
start = time.time()
|
||||
acc = evaluating_model(model)
|
||||
print(f'pure evaluating: {time.time() - start}s Acc.: {acc}')
|
||||
|
||||
|
||||
# %%
|
||||
# Quantizing Model
|
||||
# ----------------
|
||||
#
|
||||
# Initialize a `config_list`.
|
||||
# For details about how to write a ``config_list``, please refer to :doc:`Config Specification <../compression_preview/config_list>`.
|
||||
|
||||
import nni
|
||||
from nni.contrib.compression.quantization import QATQuantizer
|
||||
from nni.contrib.compression.utils import TorchEvaluator
|
||||
|
||||
|
||||
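# note: nni.trace records the constructor arguments of SGD so that the evaluator can re-create the optimizer later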
optimizer = nni.trace(SGD)(model.parameters(), lr=0.01, momentum=0.9, weight_decay=5e-4)
|
||||
evaluator = TorchEvaluator(training_model, optimizer, training_step) # type: ignore
|
||||
|
||||
config_list = [{
|
||||
'op_names': ['conv1', 'conv2', 'fc1', 'fc2'],
|
||||
'target_names': ['_input_', 'weight', '_output_'],
|
||||
'quant_dtype': 'int8',
|
||||
'quant_scheme': 'affine',
|
||||
'granularity': 'default',
|
||||
},{
|
||||
'op_names': ['relu1', 'relu2', 'relu3'],
|
||||
'target_names': ['_output_'],
|
||||
'quant_dtype': 'int8',
|
||||
'quant_scheme': 'affine',
|
||||
'granularity': 'default',
|
||||
}]
|
||||
|
||||
quantizer = QATQuantizer(model, config_list, evaluator, len(train_dataloader))
|
||||
real_input = next(iter(train_dataloader))[0].to(device)
|
||||
quantizer.track_forward(real_input)
|
||||
|
||||
start = time.time()
|
||||
_, calibration_config = quantizer.compress(None, max_epochs=5)
|
||||
print(f'pure training 5 epochs: {time.time() - start}s')
|
||||
|
||||
print(calibration_config)
|
||||
start = time.time()
|
||||
acc = evaluating_model(model)
|
||||
print(f'quantization evaluating: {time.time() - start}s Acc.: {acc}')
|
|
@ -0,0 +1 @@
|
|||
f72305e67164ac9f28472df05bd8c53d
|
|
@ -0,0 +1,326 @@
|
|||
|
||||
.. DO NOT EDIT.
|
||||
.. THIS FILE WAS AUTOMATICALLY GENERATED BY SPHINX-GALLERY.
|
||||
.. TO MAKE CHANGES, EDIT THE SOURCE PYTHON FILE:
|
||||
.. "tutorials/quantization_quick_start.py"
|
||||
.. LINE NUMBERS ARE GIVEN BELOW.
|
||||
|
||||
.. only:: html
|
||||
|
||||
.. note::
|
||||
:class: sphx-glr-download-link-note
|
||||
|
||||
:ref:`Go to the end <sphx_glr_download_tutorials_quantization_quick_start.py>`
|
||||
to download the full example code
|
||||
|
||||
.. rst-class:: sphx-glr-example-title
|
||||
|
||||
.. _sphx_glr_tutorials_quantization_quick_start.py:
|
||||
|
||||
|
||||
Quantization Quickstart
|
||||
=======================
|
||||
|
||||
Here is a four-minute video to get you started with model quantization.
|
||||
|
||||
.. youtube:: MSfV7AyfiA4
|
||||
:align: center
|
||||
|
||||
Quantization reduces model size and speeds up inference time by reducing the number of bits required to represent weights or activations.
|
||||
|
||||
In NNI, both post-training quantization algorithms and quantization-aware training algorithms are supported.
|
||||
Here we use `QATQuantizer` as an example to show the usage of quantization in NNI.
|
||||
|
||||
.. GENERATED FROM PYTHON SOURCE LINES 17-22
|
||||
|
||||
Preparation
|
||||
-----------
|
||||
|
||||
In this tutorial, we use a simple model and pre-train it on the MNIST dataset.
|
||||
If you are familiar with defining a model and training in PyTorch, you can skip directly to `Quantizing Model`_.
|
||||
|
||||
.. GENERATED FROM PYTHON SOURCE LINES 22-36
|
||||
|
||||
.. code-block:: default
|
||||
|
||||
|
||||
import functools
|
||||
import time
|
||||
from typing import Callable, Union, List, Dict, Tuple
|
||||
|
||||
import torch
|
||||
import torch.nn.functional as F
|
||||
from torch.optim import Optimizer, SGD
|
||||
from torch.utils.data import DataLoader
|
||||
from torch import Tensor
|
||||
|
||||
from nni.common.types import SCHEDULER
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
.. GENERATED FROM PYTHON SOURCE LINES 37-38
|
||||
|
||||
Define the model
|
||||
|
||||
.. GENERATED FROM PYTHON SOURCE LINES 38-63
|
||||
|
||||
.. code-block:: default
|
||||
|
||||
class Mnist(torch.nn.Module):
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
self.conv1 = torch.nn.Conv2d(1, 20, 5, 1)
|
||||
self.conv2 = torch.nn.Conv2d(20, 50, 5, 1)
|
||||
self.fc1 = torch.nn.Linear(4 * 4 * 50, 500)
|
||||
self.fc2 = torch.nn.Linear(500, 10)
|
||||
self.relu1 = torch.nn.ReLU6()
|
||||
self.relu2 = torch.nn.ReLU6()
|
||||
self.relu3 = torch.nn.ReLU6()
|
||||
self.max_pool1 = torch.nn.MaxPool2d(2, 2)
|
||||
self.max_pool2 = torch.nn.MaxPool2d(2, 2)
|
||||
self.batchnorm1 = torch.nn.BatchNorm2d(20)
|
||||
|
||||
def forward(self, x):
|
||||
x = self.relu1(self.batchnorm1(self.conv1(x)))
|
||||
x = self.max_pool1(x)
|
||||
x = self.relu2(self.conv2(x))
|
||||
x = self.max_pool2(x)
|
||||
x = x.view(-1, 4 * 4 * 50)
|
||||
x = self.relu3(self.fc1(x))
|
||||
x = self.fc2(x)
|
||||
return F.log_softmax(x, dim=1)
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
.. GENERATED FROM PYTHON SOURCE LINES 64-65
|
||||
|
||||
Create training and evaluation dataloader
|
||||
|
||||
.. GENERATED FROM PYTHON SOURCE LINES 65-78
|
||||
|
||||
.. code-block:: default
|
||||
|
||||
from torch.utils.data import DataLoader
|
||||
from torchvision import transforms
|
||||
from torchvision.datasets import MNIST
|
||||
|
||||
MNIST(root='data/mnist', train=True, download=True)
|
||||
MNIST(root='data/mnist', train=False, download=True)
|
||||
transform = transforms.Compose([transforms.ToTensor(), transforms.Normalize((0.1307,), (0.3081,))])
|
||||
mnist_train = MNIST(root='data/mnist', train=True, transform=transform)
|
||||
train_dataloader = DataLoader(mnist_train, batch_size=64)
|
||||
mnist_test = MNIST(root='data/mnist', train=False, transform=transform)
|
||||
test_dataloader = DataLoader(mnist_test, batch_size=1000)
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
.. GENERATED FROM PYTHON SOURCE LINES 79-80
|
||||
|
||||
Define training and evaluation functions
|
||||
|
||||
.. GENERATED FROM PYTHON SOURCE LINES 80-124
|
||||
|
||||
.. code-block:: default
|
||||
|
||||
device = "cuda:0" if torch.cuda.is_available() else "cpu"
|
||||
|
||||
|
||||
def training_step(batch, model) -> Tensor:
|
||||
x, y = batch[0].to(device), batch[1].to(device)
|
||||
logits = model(x)
|
||||
loss: torch.Tensor = F.nll_loss(logits, y)
|
||||
return loss
|
||||
|
||||
|
||||
def training_model(model: torch.nn.Module, optimizer: Optimizer, training_step: Callable, scheduler: Union[SCHEDULER, None] = None,
|
||||
max_steps: Union[int, None] = None, max_epochs: Union[int, None] = None):
|
||||
model.train()
|
||||
max_epochs = max_epochs if max_epochs else 1 if max_steps is None else 100
|
||||
current_steps = 0
|
||||
|
||||
# training
|
||||
for epoch in range(max_epochs):
|
||||
print(f'Epoch {epoch} start!')
|
||||
for batch in train_dataloader:
|
||||
optimizer.zero_grad()
|
||||
loss = training_step(batch, model)
|
||||
loss.backward()
|
||||
optimizer.step()
|
||||
current_steps += 1
|
||||
if max_steps and current_steps == max_steps:
|
||||
return
|
||||
if scheduler is not None:
|
||||
scheduler.step()
|
||||
|
||||
|
||||
def evaluating_model(model: torch.nn.Module):
|
||||
model.eval()
|
||||
# testing
|
||||
correct = 0
|
||||
with torch.no_grad():
|
||||
for x, y in test_dataloader:
|
||||
x, y = x.to(device), y.to(device)
|
||||
logits = model(x)
|
||||
preds = torch.argmax(logits, dim=1)
|
||||
correct += preds.eq(y.view_as(preds)).sum().item()
|
||||
return correct / len(mnist_test)
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
.. GENERATED FROM PYTHON SOURCE LINES 125-126
|
||||
|
||||
Pre-train and evaluate the model on MNIST dataset
|
||||
|
||||
.. GENERATED FROM PYTHON SOURCE LINES 126-137
|
||||
|
||||
.. code-block:: default
|
||||
|
||||
model = Mnist().to(device)
|
||||
optimizer = SGD(model.parameters(), lr=0.01, momentum=0.9, weight_decay=5e-4)
|
||||
|
||||
start = time.time()
|
||||
training_model(model, optimizer, training_step, None, None, 5)
|
||||
print(f'pure training 5 epochs: {time.time() - start}s')
|
||||
start = time.time()
|
||||
acc = evaluating_model(model)
|
||||
print(f'pure evaluating: {time.time() - start}s Acc.: {acc}')
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
.. rst-class:: sphx-glr-script-out
|
||||
|
||||
.. code-block:: none
|
||||
|
||||
Epoch 0 start!
|
||||
Epoch 1 start!
|
||||
Epoch 2 start!
|
||||
Epoch 3 start!
|
||||
Epoch 4 start!
|
||||
pure training 5 epochs: 47.914021015167236s
|
||||
pure evaluating: 1.2639274597167969s Acc.: 0.9897
|
||||
|
||||
|
||||
|
||||
|
||||
.. GENERATED FROM PYTHON SOURCE LINES 138-143
|
||||
|
||||
Quantizing Model
|
||||
----------------
|
||||
|
||||
Initialize a `config_list`.
|
||||
For details about how to write a ``config_list``, please refer to :doc:`Config Specification <../compression_preview/config_list>`.
|
||||
|
||||
.. GENERATED FROM PYTHON SOURCE LINES 143-177
|
||||
|
||||
.. code-block:: default
|
||||
|
||||
|
||||
import nni
|
||||
from nni.contrib.compression.quantization import QATQuantizer
|
||||
from nni.contrib.compression.utils import TorchEvaluator
|
||||
|
||||
|
||||
optimizer = nni.trace(SGD)(model.parameters(), lr=0.01, momentum=0.9, weight_decay=5e-4)
|
||||
evaluator = TorchEvaluator(training_model, optimizer, training_step) # type: ignore
|
||||
|
||||
config_list = [{
|
||||
'op_names': ['conv1', 'conv2', 'fc1', 'fc2'],
|
||||
'target_names': ['_input_', 'weight', '_output_'],
|
||||
'quant_dtype': 'int8',
|
||||
'quant_scheme': 'affine',
|
||||
'granularity': 'default',
|
||||
},{
|
||||
'op_names': ['relu1', 'relu2', 'relu3'],
|
||||
'target_names': ['_output_'],
|
||||
'quant_dtype': 'int8',
|
||||
'quant_scheme': 'affine',
|
||||
'granularity': 'default',
|
||||
}]
|
||||
|
||||
quantizer = QATQuantizer(model, config_list, evaluator, len(train_dataloader))
|
||||
real_input = next(iter(train_dataloader))[0].to(device)
|
||||
quantizer.track_forward(real_input)
|
||||
|
||||
start = time.time()
|
||||
_, calibration_config = quantizer.compress(None, max_epochs=5)
|
||||
print(f'pure training 5 epochs: {time.time() - start}s')
|
||||
|
||||
print(calibration_config)
|
||||
start = time.time()
|
||||
acc = evaluating_model(model)
|
||||
print(f'quantization evaluating: {time.time() - start}s Acc.: {acc}')
|
||||
|
||||
|
||||
|
||||
.. rst-class:: sphx-glr-script-out
|
||||
|
||||
.. code-block:: none
|
||||
|
||||
Epoch 0 start!
|
||||
Epoch 1 start!
|
||||
Epoch 2 start!
|
||||
Epoch 3 start!
|
||||
Epoch 4 start!
|
||||
pure training 5 epochs: 78.95339393615723s
|
||||
defaultdict(<class 'dict'>, {'fc2': {'weight': {'scale': tensor(0.0017), 'zero_point': tensor(-5.), 'quant_dtype': 'int8', 'quant_scheme': 'affine', 'quant_bits': 8, 'tracked_max': tensor(0.2286), 'tracked_min': tensor(-0.2105)}, '_input_0': {'scale': tensor(0.0236), 'zero_point': tensor(-127.), 'quant_dtype': 'int8', 'quant_scheme': 'affine', 'quant_bits': 8, 'tracked_max': tensor(6.), 'tracked_min': tensor(0.)}, '_output_0': {'scale': tensor(0.1543), 'zero_point': tensor(-35.), 'quant_dtype': 'int8', 'quant_scheme': 'affine', 'quant_bits': 8, 'tracked_max': tensor(25.0385), 'tracked_min': tensor(-14.1545)}}, 'conv2': {'weight': {'scale': tensor(0.0011), 'zero_point': tensor(-19.), 'quant_dtype': 'int8', 'quant_scheme': 'affine', 'quant_bits': 8, 'tracked_max': tensor(0.1659), 'tracked_min': tensor(-0.1226)}, '_input_0': {'scale': tensor(0.0230), 'zero_point': tensor(-127.), 'quant_dtype': 'int8', 'quant_scheme': 'affine', 'quant_bits': 8, 'tracked_max': tensor(5.8373), 'tracked_min': tensor(0.)}, '_output_0': {'scale': tensor(0.0971), 'zero_point': tensor(-6.), 'quant_dtype': 'int8', 'quant_scheme': 'affine', 'quant_bits': 8, 'tracked_max': tensor(12.9122), 'tracked_min': tensor(-11.7522)}}, 'fc1': {'weight': {'scale': tensor(0.0007), 'zero_point': tensor(-3.), 'quant_dtype': 'int8', 'quant_scheme': 'affine', 'quant_bits': 8, 'tracked_max': tensor(0.0885), 'tracked_min': tensor(-0.0844)}, '_input_0': {'scale': tensor(0.0236), 'zero_point': tensor(-127.), 'quant_dtype': 'int8', 'quant_scheme': 'affine', 'quant_bits': 8, 'tracked_max': tensor(6.), 'tracked_min': tensor(0.)}, '_output_0': {'scale': tensor(0.0611), 'zero_point': tensor(-7.), 'quant_dtype': 'int8', 'quant_scheme': 'affine', 'quant_bits': 8, 'tracked_max': tensor(8.2104), 'tracked_min': tensor(-7.3205)}}, 'conv1': {'weight': {'scale': tensor(0.0021), 'zero_point': tensor(-19.), 'quant_dtype': 'int8', 'quant_scheme': 'affine', 'quant_bits': 8, 'tracked_max': tensor(0.3130), 'tracked_min': tensor(-0.2318)}, '_input_0': {'scale': tensor(0.0128), 'zero_point': tensor(-94.), 'quant_dtype': 'int8', 'quant_scheme': 'affine', 'quant_bits': 8, 'tracked_max': tensor(2.8215), 'tracked_min': tensor(-0.4242)}, '_output_0': {'scale': tensor(0.0311), 'zero_point': tensor(13.), 'quant_dtype': 'int8', 'quant_scheme': 'affine', 'quant_bits': 8, 'tracked_max': tensor(3.5516), 'tracked_min': tensor(-4.3537)}}, 'relu3': {'_output_0': {'scale': tensor(0.0236), 'zero_point': tensor(-127.), 'quant_dtype': 'int8', 'quant_scheme': 'affine', 'quant_bits': 8, 'tracked_max': tensor(6.), 'tracked_min': tensor(0.)}}, 'relu1': {'_output_0': {'scale': tensor(0.0232), 'zero_point': tensor(-127.), 'quant_dtype': 'int8', 'quant_scheme': 'affine', 'quant_bits': 8, 'tracked_max': tensor(5.8952), 'tracked_min': tensor(0.)}}, 'relu2': {'_output_0': {'scale': tensor(0.0236), 'zero_point': tensor(-127.), 'quant_dtype': 'int8', 'quant_scheme': 'affine', 'quant_bits': 8, 'tracked_max': tensor(6.), 'tracked_min': tensor(0.)}}})
|
||||
quantization evaluating: 1.2496261596679688s Acc.: 0.9902
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
.. rst-class:: sphx-glr-timing
|
||||
|
||||
**Total running time of the script:** ( 2 minutes 14.073 seconds)
|
||||
|
||||
|
||||
.. _sphx_glr_download_tutorials_quantization_quick_start.py:
|
||||
|
||||
.. only:: html
|
||||
|
||||
.. container:: sphx-glr-footer sphx-glr-footer-example
|
||||
|
||||
|
||||
|
||||
|
||||
.. container:: sphx-glr-download sphx-glr-download-python
|
||||
|
||||
:download:`Download Python source code: quantization_quick_start.py <quantization_quick_start.py>`
|
||||
|
||||
.. container:: sphx-glr-download sphx-glr-download-jupyter
|
||||
|
||||
:download:`Download Jupyter notebook: quantization_quick_start.ipynb <quantization_quick_start.ipynb>`
|
||||
|
||||
|
||||
.. only:: html
|
||||
|
||||
.. rst-class:: sphx-glr-signature
|
||||
|
||||
`Gallery generated by Sphinx-Gallery <https://sphinx-gallery.github.io>`_
|
Binary data
docs/source/tutorials/quantization_quick_start_codeobj.pickle
generated
Normal file
Binary data
docs/source/tutorials/quantization_quick_start_codeobj.pickle
generated
Normal file
Binary file not shown.
|
@ -6,8 +6,12 @@
|
|||
|
||||
Computation times
|
||||
=================
|
||||
**09:40.647** total execution time for **tutorials** files:
|
||||
**02:14.073** total execution time for **tutorials** files:
|
||||
|
||||
+-----------------------------------------------------------------------------------------------------+-----------+--------+
|
||||
| :ref:`sphx_glr_tutorials_quantization_quick_start.py` (``quantization_quick_start.py``) | 02:14.073 | 0.0 MB |
|
||||
+-----------------------------------------------------------------------------------------------------+-----------+--------+
|
||||
| :ref:`sphx_glr_tutorials_nasbench_as_dataset.py` (``nasbench_as_dataset.py``) | 01:51.444 | 0.0 MB |
|
||||
+-----------------------------------------------------------------------------------------------------+-----------+--------+
|
||||
| :ref:`sphx_glr_tutorials_quantization_bert_glue.py` (``quantization_bert_glue.py``) | 09:40.647 | 0.0 MB |
|
||||
+-----------------------------------------------------------------------------------------------------+-----------+--------+
|
||||
|
@ -21,6 +25,10 @@ Computation times
|
|||
+-----------------------------------------------------------------------------------------------------+-----------+--------+
|
||||
| :ref:`sphx_glr_tutorials_pruning_bert_glue.py` (``pruning_bert_glue.py``) | 00:00.000 | 0.0 MB |
|
||||
+-----------------------------------------------------------------------------------------------------+-----------+--------+
|
||||
| :ref:`sphx_glr_tutorials_new_pruning_bert_glue.py` (``new_pruning_bert_glue.py``) | 00:00.000 | 0.0 MB |
|
||||
+-----------------------------------------------------------------------------------------------------+-----------+--------+
|
||||
| :ref:`sphx_glr_tutorials_pruning_bert_glue.py` (``pruning_bert_glue.py``) | 00:00.000 | 0.0 MB |
|
||||
+-----------------------------------------------------------------------------------------------------+-----------+--------+
|
||||
| :ref:`sphx_glr_tutorials_pruning_quick_start_mnist.py` (``pruning_quick_start_mnist.py``) | 00:00.000 | 0.0 MB |
|
||||
+-----------------------------------------------------------------------------------------------------+-----------+--------+
|
||||
| :ref:`sphx_glr_tutorials_pruning_speedup.py` (``pruning_speedup.py``) | 00:00.000 | 0.0 MB |
|
||||
|
|
|
@ -20,7 +20,7 @@ from nni.common.types import SCHEDULER
|
|||
|
||||
|
||||
torch.manual_seed(1024)
|
||||
device = 'cuda'
|
||||
device = 'cuda:0'
|
||||
|
||||
|
||||
MNIST(root='data/mnist', train=True, download=True)
|
||||
|
|
|
@ -0,0 +1,177 @@
|
|||
"""
|
||||
Quantization Quickstart
|
||||
=======================
|
||||
|
||||
Here is a four-minute video to get you started with model quantization.
|
||||
|
||||
.. youtube:: MSfV7AyfiA4
|
||||
:align: center
|
||||
|
||||
Quantization reduces model size and speeds up inference time by reducing the number of bits required to represent weights or activations.
|
||||
|
||||
In NNI, both post-training quantization algorithms and quantization-aware training algorithms are supported.
|
||||
Here we use `QATQuantizer` as an example to show the usage of quantization in NNI.
|
||||
"""
|
||||
|
||||
# %%
|
||||
# Preparation
|
||||
# -----------
|
||||
#
|
||||
# In this tutorial, we use a simple model and pre-train it on the MNIST dataset.
|
||||
# If you are familiar with defining a model and training in PyTorch, you can skip directly to `Quantizing Model`_.
|
||||
|
||||
import functools
|
||||
import time
|
||||
from typing import Callable, Union, List, Dict, Tuple
|
||||
|
||||
import torch
|
||||
import torch.nn.functional as F
|
||||
from torch.optim import Optimizer, SGD
|
||||
from torch.utils.data import DataLoader
|
||||
from torch import Tensor
|
||||
|
||||
from nni.common.types import SCHEDULER
|
||||
|
||||
|
||||
# %%
|
||||
# Define the model
|
||||
class Mnist(torch.nn.Module):
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
self.conv1 = torch.nn.Conv2d(1, 20, 5, 1)
|
||||
self.conv2 = torch.nn.Conv2d(20, 50, 5, 1)
|
||||
self.fc1 = torch.nn.Linear(4 * 4 * 50, 500)
|
||||
self.fc2 = torch.nn.Linear(500, 10)
|
||||
self.relu1 = torch.nn.ReLU6()
|
||||
self.relu2 = torch.nn.ReLU6()
|
||||
self.relu3 = torch.nn.ReLU6()
|
||||
self.max_pool1 = torch.nn.MaxPool2d(2, 2)
|
||||
self.max_pool2 = torch.nn.MaxPool2d(2, 2)
|
||||
self.batchnorm1 = torch.nn.BatchNorm2d(20)
|
||||
|
||||
def forward(self, x):
|
||||
x = self.relu1(self.batchnorm1(self.conv1(x)))
|
||||
x = self.max_pool1(x)
|
||||
x = self.relu2(self.conv2(x))
|
||||
x = self.max_pool2(x)
|
||||
x = x.view(-1, 4 * 4 * 50)
|
||||
x = self.relu3(self.fc1(x))
|
||||
x = self.fc2(x)
|
||||
return F.log_softmax(x, dim=1)
|
||||
|
||||
|
||||
# %%
|
||||
# Create training and evaluation dataloader
|
||||
from torch.utils.data import DataLoader
|
||||
from torchvision import transforms
|
||||
from torchvision.datasets import MNIST
|
||||
|
||||
MNIST(root='data/mnist', train=True, download=True)
|
||||
MNIST(root='data/mnist', train=False, download=True)
|
||||
transform = transforms.Compose([transforms.ToTensor(), transforms.Normalize((0.1307,), (0.3081,))])
|
||||
mnist_train = MNIST(root='data/mnist', train=True, transform=transform)
|
||||
train_dataloader = DataLoader(mnist_train, batch_size=64)
|
||||
mnist_test = MNIST(root='data/mnist', train=False, transform=transform)
|
||||
test_dataloader = DataLoader(mnist_test, batch_size=1000)
|
||||
|
||||
|
||||
# %%
|
||||
# Define training and evaluation functions
|
||||
device = "cuda:0" if torch.cuda.is_available() else "cpu"
|
||||
|
||||
|
||||
def training_step(batch, model) -> Tensor:
|
||||
x, y = batch[0].to(device), batch[1].to(device)
|
||||
logits = model(x)
|
||||
loss: torch.Tensor = F.nll_loss(logits, y)
|
||||
return loss
|
||||
|
||||
|
||||
def training_model(model: torch.nn.Module, optimizer: Optimizer, training_step: Callable, scheduler: Union[SCHEDULER, None] = None,
|
||||
max_steps: Union[int, None] = None, max_epochs: Union[int, None] = None):
|
||||
model.train()
|
||||
max_epochs = max_epochs if max_epochs else 1 if max_steps is None else 100
|
||||
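# note: defaults to 1 epoch when neither limit is given; caps at 100 epochs when only max_steps is set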
current_steps = 0
|
||||
|
||||
# training
|
||||
for epoch in range(max_epochs):
|
||||
print(f'Epoch {epoch} start!')
|
||||
for batch in train_dataloader:
|
||||
optimizer.zero_grad()
|
||||
loss = training_step(batch, model)
|
||||
loss.backward()
|
||||
optimizer.step()
|
||||
current_steps += 1
|
||||
if max_steps and current_steps == max_steps:
|
||||
return
|
||||
if scheduler is not None:
|
||||
scheduler.step()
|
||||
|
||||
|
||||
def evaluating_model(model: torch.nn.Module):
|
||||
model.eval()
|
||||
# testing
|
||||
correct = 0
|
||||
with torch.no_grad():
|
||||
for x, y in test_dataloader:
|
||||
x, y = x.to(device), y.to(device)
|
||||
logits = model(x)
|
||||
preds = torch.argmax(logits, dim=1)
|
||||
correct += preds.eq(y.view_as(preds)).sum().item()
|
||||
return correct / len(mnist_test)
|
||||
|
||||
|
||||
# %%
|
||||
# Pre-train and evaluate the model on MNIST dataset
|
||||
model = Mnist().to(device)
|
||||
optimizer = SGD(model.parameters(), lr=0.01, momentum=0.9, weight_decay=5e-4)
|
||||
|
||||
start = time.time()
|
||||
training_model(model, optimizer, training_step, None, None, 5)
|
||||
print(f'pure training 5 epochs: {time.time() - start}s')
|
||||
start = time.time()
|
||||
acc = evaluating_model(model)
|
||||
print(f'pure evaluating: {time.time() - start}s Acc.: {acc}')
|
||||
|
||||
|
||||
# %%
|
||||
# Quantizing Model
|
||||
# ----------------
|
||||
#
|
||||
# Initialize a `config_list`.
|
||||
# For details about how to write a ``config_list``, please refer to :doc:`Config Specification <../compression_preview/config_list>`.
|
||||
|
||||
import nni
|
||||
from nni.contrib.compression.quantization import QATQuantizer
|
||||
from nni.contrib.compression.utils import TorchEvaluator
|
||||
|
||||
|
||||
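# note: nni.trace records the constructor arguments of SGD so that the evaluator can re-create the optimizer later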
optimizer = nni.trace(SGD)(model.parameters(), lr=0.01, momentum=0.9, weight_decay=5e-4)
|
||||
evaluator = TorchEvaluator(training_model, optimizer, training_step) # type: ignore
|
||||
|
||||
config_list = [{
|
||||
'op_names': ['conv1', 'conv2', 'fc1', 'fc2'],
|
||||
'target_names': ['_input_', 'weight', '_output_'],
|
||||
'quant_dtype': 'int8',
|
||||
'quant_scheme': 'affine',
|
||||
'granularity': 'default',
|
||||
},{
|
||||
'op_names': ['relu1', 'relu2', 'relu3'],
|
||||
'target_names': ['_output_'],
|
||||
'quant_dtype': 'int8',
|
||||
'quant_scheme': 'affine',
|
||||
'granularity': 'default',
|
||||
}]
|
||||
|
||||
quantizer = QATQuantizer(model, config_list, evaluator, len(train_dataloader))
|
||||
real_input = next(iter(train_dataloader))[0].to(device)
|
||||
quantizer.track_forward(real_input)
|
||||
|
||||
start = time.time()
|
||||
_, calibration_config = quantizer.compress(None, max_epochs=5)
|
||||
print(f'pure training 5 epochs: {time.time() - start}s')
|
||||
|
||||
print(calibration_config)
|
||||
start = time.time()
|
||||
acc = evaluating_model(model)
|
||||
print(f'quantization evaluating: {time.time() - start}s Acc.: {acc}')
|