зеркало из https://github.com/microsoft/nni.git
[BugFix] fix compression bugs (#5140)
This commit is contained in:
Родитель
56c6cfeaed
Коммит
a966834755
|
@ -1,4 +1,4 @@
|
|||
.. a6a9f0292afa81c7796304ae7da5afcd
|
||||
.. e6c000f46f269ea88861ca2cd3b597ae
|
||||
|
||||
Web 界面
|
||||
========
|
||||
|
|
|
@ -8,7 +8,7 @@ msgid ""
|
|||
msgstr ""
|
||||
"Project-Id-Version: NNI \n"
|
||||
"Report-Msgid-Bugs-To: \n"
|
||||
"POT-Creation-Date: 2022-05-27 16:52+0800\n"
|
||||
"POT-Creation-Date: 2022-10-18 19:27+0800\n"
|
||||
"PO-Revision-Date: YEAR-MO-DA HO:MI+ZONE\n"
|
||||
"Last-Translator: FULL NAME <EMAIL@ADDRESS>\n"
|
||||
"Language-Team: LANGUAGE <LL@li.org>\n"
|
||||
|
@ -127,12 +127,6 @@ msgstr ""
|
|||
#: ../../source/tutorials/hello_nas.rst:564
|
||||
#: ../../source/tutorials/hpo_quickstart_pytorch/main.rst:244
|
||||
#: ../../source/tutorials/hpo_quickstart_pytorch/main.rst:281
|
||||
#: ../../source/tutorials/pruning_quick_start_mnist.rst:70
|
||||
#: ../../source/tutorials/pruning_quick_start_mnist.rst:112
|
||||
#: ../../source/tutorials/pruning_quick_start_mnist.rst:177
|
||||
#: ../../source/tutorials/pruning_quick_start_mnist.rst:223
|
||||
#: ../../source/tutorials/pruning_quick_start_mnist.rst:260
|
||||
#: ../../source/tutorials/pruning_quick_start_mnist.rst:288
|
||||
msgid "Out:"
|
||||
msgstr ""
|
||||
|
||||
|
@ -347,7 +341,7 @@ msgstr ""
|
|||
|
||||
#: ../../source/tutorials/hello_nas.rst:625
|
||||
#: ../../source/tutorials/hpo_quickstart_pytorch/main.rst:335
|
||||
#: ../../source/tutorials/pruning_quick_start_mnist.rst:362
|
||||
#: ../../source/tutorials/pruning_quick_start_mnist.rst:343
|
||||
msgid "`Gallery generated by Sphinx-Gallery <https://sphinx-gallery.github.io>`_"
|
||||
msgstr ""
|
||||
|
||||
|
@ -690,11 +684,11 @@ msgid ""
|
|||
"can skip directly to `Pruning Model`_."
|
||||
msgstr ""
|
||||
|
||||
#: ../../source/tutorials/pruning_quick_start_mnist.rst:126
|
||||
#: ../../source/tutorials/pruning_quick_start_mnist.rst:122
|
||||
msgid "Pruning Model"
|
||||
msgstr ""
|
||||
|
||||
#: ../../source/tutorials/pruning_quick_start_mnist.rst:128
|
||||
#: ../../source/tutorials/pruning_quick_start_mnist.rst:124
|
||||
msgid ""
|
||||
"Using L1NormPruner to prune the model and generate the masks. Usually, a "
|
||||
"pruner requires original model and ``config_list`` as its inputs. "
|
||||
|
@ -703,7 +697,7 @@ msgid ""
|
|||
"<../compression/compression_config_list>`."
|
||||
msgstr ""
|
||||
|
||||
#: ../../source/tutorials/pruning_quick_start_mnist.rst:132
|
||||
#: ../../source/tutorials/pruning_quick_start_mnist.rst:128
|
||||
msgid ""
|
||||
"The following `config_list` means all layers whose type is `Linear` or "
|
||||
"`Conv2d` will be pruned, except the layer named `fc3`, because `fc3` is "
|
||||
|
@ -711,11 +705,11 @@ msgid ""
|
|||
"named `fc3` will not be pruned."
|
||||
msgstr ""
|
||||
|
||||
#: ../../source/tutorials/pruning_quick_start_mnist.rst:158
|
||||
#: ../../source/tutorials/pruning_quick_start_mnist.rst:154
|
||||
msgid "Pruners usually require `model` and `config_list` as input arguments."
|
||||
msgstr ""
|
||||
|
||||
#: ../../source/tutorials/pruning_quick_start_mnist.rst:237
|
||||
#: ../../source/tutorials/pruning_quick_start_mnist.rst:229
|
||||
msgid ""
|
||||
"Speedup the original model with masks, note that `ModelSpeedup` requires "
|
||||
"an unwrapped model. The model becomes smaller after speedup, and reaches "
|
||||
|
@ -723,32 +717,32 @@ msgid ""
|
|||
"across layers."
|
||||
msgstr ""
|
||||
|
||||
#: ../../source/tutorials/pruning_quick_start_mnist.rst:274
|
||||
#: ../../source/tutorials/pruning_quick_start_mnist.rst:262
|
||||
msgid "the model will become real smaller after speedup"
|
||||
msgstr ""
|
||||
|
||||
#: ../../source/tutorials/pruning_quick_start_mnist.rst:312
|
||||
#: ../../source/tutorials/pruning_quick_start_mnist.rst:298
|
||||
msgid "Fine-tuning Compacted Model"
|
||||
msgstr ""
|
||||
|
||||
#: ../../source/tutorials/pruning_quick_start_mnist.rst:313
|
||||
#: ../../source/tutorials/pruning_quick_start_mnist.rst:299
|
||||
msgid ""
|
||||
"Note that if the model has been sped up, you need to re-initialize a new "
|
||||
"optimizer for fine-tuning. Because speedup will replace the masked big "
|
||||
"layers with dense small ones."
|
||||
msgstr ""
|
||||
|
||||
#: ../../source/tutorials/pruning_quick_start_mnist.rst:334
|
||||
msgid "**Total running time of the script:** ( 1 minutes 30.730 seconds)"
|
||||
#: ../../source/tutorials/pruning_quick_start_mnist.rst:320
|
||||
msgid "**Total running time of the script:** ( 1 minutes 0.810 seconds)"
|
||||
msgstr ""
|
||||
|
||||
#: ../../source/tutorials/pruning_quick_start_mnist.rst:349
|
||||
#: ../../source/tutorials/pruning_quick_start_mnist.rst:332
|
||||
msgid ""
|
||||
":download:`Download Python source code: pruning_quick_start_mnist.py "
|
||||
"<pruning_quick_start_mnist.py>`"
|
||||
msgstr ""
|
||||
|
||||
#: ../../source/tutorials/pruning_quick_start_mnist.rst:355
|
||||
#: ../../source/tutorials/pruning_quick_start_mnist.rst:336
|
||||
msgid ""
|
||||
":download:`Download Jupyter notebook: pruning_quick_start_mnist.ipynb "
|
||||
"<pruning_quick_start_mnist.ipynb>`"
|
||||
|
@ -778,3 +772,6 @@ msgstr ""
|
|||
#~ msgid "**Total running time of the script:** ( 0 minutes 58.337 seconds)"
|
||||
#~ msgstr ""
|
||||
|
||||
#~ msgid "**Total running time of the script:** ( 1 minutes 30.730 seconds)"
|
||||
#~ msgstr ""
|
||||
|
||||
|
|
|
@ -1,24 +1,184 @@
|
|||
:orphan:
|
||||
|
||||
|
||||
|
||||
.. _sphx_glr_tutorials:
|
||||
|
||||
Tutorials
|
||||
=========
|
||||
|
||||
|
||||
|
||||
.. raw:: html
|
||||
|
||||
<div class="sphx-glr-thumbnails">
|
||||
|
||||
|
||||
.. raw:: html
|
||||
|
||||
<div class="sphx-glr-thumbcontainer" tooltip="Introduction ------------">
|
||||
|
||||
.. only:: html
|
||||
|
||||
.. figure:: /tutorials/images/thumb/sphx_glr_pruning_speedup_thumb.png
|
||||
:alt: Speedup Model with Mask
|
||||
.. image:: /tutorials/images/thumb/sphx_glr_pruning_speedup_thumb.png
|
||||
:alt: Speedup Model with Mask
|
||||
|
||||
:ref:`sphx_glr_tutorials_pruning_speedup.py`
|
||||
|
||||
.. raw:: html
|
||||
|
||||
<div class="sphx-glr-thumbnail-title">Speedup Model with Mask</div>
|
||||
</div>
|
||||
|
||||
|
||||
.. raw:: html
|
||||
|
||||
<div class="sphx-glr-thumbcontainer" tooltip=" Introduction ------------">
|
||||
|
||||
.. only:: html
|
||||
|
||||
.. image:: /tutorials/images/thumb/sphx_glr_quantization_speedup_thumb.png
|
||||
:alt: SpeedUp Model with Calibration Config
|
||||
|
||||
:ref:`sphx_glr_tutorials_quantization_speedup.py`
|
||||
|
||||
.. raw:: html
|
||||
|
||||
<div class="sphx-glr-thumbnail-title">SpeedUp Model with Calibration Config</div>
|
||||
</div>
|
||||
|
||||
|
||||
.. raw:: html
|
||||
|
||||
<div class="sphx-glr-thumbcontainer" tooltip="Here is a four-minute video to get you started with model quantization.">
|
||||
|
||||
.. only:: html
|
||||
|
||||
.. image:: /tutorials/images/thumb/sphx_glr_quantization_quick_start_mnist_thumb.png
|
||||
:alt: Quantization Quickstart
|
||||
|
||||
:ref:`sphx_glr_tutorials_quantization_quick_start_mnist.py`
|
||||
|
||||
.. raw:: html
|
||||
|
||||
<div class="sphx-glr-thumbnail-title">Quantization Quickstart</div>
|
||||
</div>
|
||||
|
||||
|
||||
.. raw:: html
|
||||
|
||||
<div class="sphx-glr-thumbcontainer" tooltip="Here is a three-minute video to get you started with model pruning.">
|
||||
|
||||
.. only:: html
|
||||
|
||||
.. image:: /tutorials/images/thumb/sphx_glr_pruning_quick_start_mnist_thumb.png
|
||||
:alt: Pruning Quickstart
|
||||
|
||||
:ref:`sphx_glr_tutorials_pruning_quick_start_mnist.py`
|
||||
|
||||
.. raw:: html
|
||||
|
||||
<div class="sphx-glr-thumbnail-title">Pruning Quickstart</div>
|
||||
</div>
|
||||
|
||||
|
||||
.. raw:: html
|
||||
|
||||
<div class="sphx-glr-thumbcontainer" tooltip="To write a new quantization algorithm, you can write a class that inherits nni.compression.pyto...">
|
||||
|
||||
.. only:: html
|
||||
|
||||
.. image:: /tutorials/images/thumb/sphx_glr_quantization_customize_thumb.png
|
||||
:alt: Customize a new quantization algorithm
|
||||
|
||||
:ref:`sphx_glr_tutorials_quantization_customize.py`
|
||||
|
||||
.. raw:: html
|
||||
|
||||
<div class="sphx-glr-thumbnail-title">Customize a new quantization algorithm</div>
|
||||
</div>
|
||||
|
||||
|
||||
.. raw:: html
|
||||
|
||||
<div class="sphx-glr-thumbcontainer" tooltip="In this tutorial, we show how to use NAS Benchmarks as datasets. For research purposes we somet...">
|
||||
|
||||
.. only:: html
|
||||
|
||||
.. image:: /tutorials/images/thumb/sphx_glr_nasbench_as_dataset_thumb.png
|
||||
:alt: Use NAS Benchmarks as Datasets
|
||||
|
||||
:ref:`sphx_glr_tutorials_nasbench_as_dataset.py`
|
||||
|
||||
.. raw:: html
|
||||
|
||||
<div class="sphx-glr-thumbnail-title">Use NAS Benchmarks as Datasets</div>
|
||||
</div>
|
||||
|
||||
|
||||
.. raw:: html
|
||||
|
||||
<div class="sphx-glr-thumbcontainer" tooltip="Users can easily customize a basic pruner in NNI. A large number of basic modules have been pro...">
|
||||
|
||||
.. only:: html
|
||||
|
||||
.. image:: /tutorials/images/thumb/sphx_glr_pruning_customize_thumb.png
|
||||
:alt: Customize Basic Pruner
|
||||
|
||||
:ref:`sphx_glr_tutorials_pruning_customize.py`
|
||||
|
||||
.. raw:: html
|
||||
|
||||
<div class="sphx-glr-thumbnail-title">Customize Basic Pruner</div>
|
||||
</div>
|
||||
|
||||
|
||||
.. raw:: html
|
||||
|
||||
<div class="sphx-glr-thumbcontainer" tooltip="This is the 101 tutorial of Neural Architecture Search (NAS) on NNI. In this tutorial, we will ...">
|
||||
|
||||
.. only:: html
|
||||
|
||||
.. image:: /tutorials/images/thumb/sphx_glr_hello_nas_thumb.png
|
||||
:alt: Hello, NAS!
|
||||
|
||||
:ref:`sphx_glr_tutorials_hello_nas.py`
|
||||
|
||||
.. raw:: html
|
||||
|
||||
<div class="sphx-glr-thumbnail-title">Hello, NAS!</div>
|
||||
</div>
|
||||
|
||||
|
||||
.. raw:: html
|
||||
|
||||
<div class="sphx-glr-thumbcontainer" tooltip="In this tutorial, we demonstrate how to search in the famous model space proposed in `DARTS`_.">
|
||||
|
||||
.. only:: html
|
||||
|
||||
.. image:: /tutorials/images/thumb/sphx_glr_darts_thumb.png
|
||||
:alt: Searching in DARTS search space
|
||||
|
||||
:ref:`sphx_glr_tutorials_darts.py`
|
||||
|
||||
.. raw:: html
|
||||
|
||||
<div class="sphx-glr-thumbnail-title">Searching in DARTS search space</div>
|
||||
</div>
|
||||
|
||||
|
||||
.. raw:: html
|
||||
|
||||
<div class="sphx-glr-thumbcontainer" tooltip="Workable Pruning Process ------------------------">
|
||||
|
||||
.. only:: html
|
||||
|
||||
.. image:: /tutorials/images/thumb/sphx_glr_pruning_bert_glue_thumb.png
|
||||
:alt: Pruning Bert on Task MNLI
|
||||
|
||||
:ref:`sphx_glr_tutorials_pruning_bert_glue.py`
|
||||
|
||||
.. raw:: html
|
||||
|
||||
<div class="sphx-glr-thumbnail-title">Pruning Bert on Task MNLI</div>
|
||||
</div>
|
||||
|
||||
:ref:`sphx_glr_tutorials_pruning_speedup.py`
|
||||
|
||||
.. raw:: html
|
||||
|
||||
|
@ -29,204 +189,22 @@ Tutorials
|
|||
:hidden:
|
||||
|
||||
/tutorials/pruning_speedup
|
||||
|
||||
.. raw:: html
|
||||
|
||||
<div class="sphx-glr-thumbcontainer" tooltip=" Introduction ------------">
|
||||
|
||||
.. only:: html
|
||||
|
||||
.. figure:: /tutorials/images/thumb/sphx_glr_quantization_speedup_thumb.png
|
||||
:alt: SpeedUp Model with Calibration Config
|
||||
|
||||
:ref:`sphx_glr_tutorials_quantization_speedup.py`
|
||||
|
||||
.. raw:: html
|
||||
|
||||
</div>
|
||||
|
||||
|
||||
.. toctree::
|
||||
:hidden:
|
||||
|
||||
/tutorials/quantization_speedup
|
||||
|
||||
.. raw:: html
|
||||
|
||||
<div class="sphx-glr-thumbcontainer" tooltip="Here is a four-minute video to get you started with model quantization.">
|
||||
|
||||
.. only:: html
|
||||
|
||||
.. figure:: /tutorials/images/thumb/sphx_glr_quantization_quick_start_mnist_thumb.png
|
||||
:alt: Quantization Quickstart
|
||||
|
||||
:ref:`sphx_glr_tutorials_quantization_quick_start_mnist.py`
|
||||
|
||||
.. raw:: html
|
||||
|
||||
</div>
|
||||
|
||||
|
||||
.. toctree::
|
||||
:hidden:
|
||||
|
||||
/tutorials/quantization_quick_start_mnist
|
||||
|
||||
.. raw:: html
|
||||
|
||||
<div class="sphx-glr-thumbcontainer" tooltip="Here is a three-minute video to get you started with model pruning.">
|
||||
|
||||
.. only:: html
|
||||
|
||||
.. figure:: /tutorials/images/thumb/sphx_glr_pruning_quick_start_mnist_thumb.png
|
||||
:alt: Pruning Quickstart
|
||||
|
||||
:ref:`sphx_glr_tutorials_pruning_quick_start_mnist.py`
|
||||
|
||||
.. raw:: html
|
||||
|
||||
</div>
|
||||
|
||||
|
||||
.. toctree::
|
||||
:hidden:
|
||||
|
||||
/tutorials/pruning_quick_start_mnist
|
||||
|
||||
.. raw:: html
|
||||
|
||||
<div class="sphx-glr-thumbcontainer" tooltip="To write a new quantization algorithm, you can write a class that inherits nni.compression.pyto...">
|
||||
|
||||
.. only:: html
|
||||
|
||||
.. figure:: /tutorials/images/thumb/sphx_glr_quantization_customize_thumb.png
|
||||
:alt: Customize a new quantization algorithm
|
||||
|
||||
:ref:`sphx_glr_tutorials_quantization_customize.py`
|
||||
|
||||
.. raw:: html
|
||||
|
||||
</div>
|
||||
|
||||
|
||||
.. toctree::
|
||||
:hidden:
|
||||
|
||||
/tutorials/quantization_customize
|
||||
|
||||
.. raw:: html
|
||||
|
||||
<div class="sphx-glr-thumbcontainer" tooltip="In this tutorial, we show how to use NAS Benchmarks as datasets. For research purposes we somet...">
|
||||
|
||||
.. only:: html
|
||||
|
||||
.. figure:: /tutorials/images/thumb/sphx_glr_nasbench_as_dataset_thumb.png
|
||||
:alt: Use NAS Benchmarks as Datasets
|
||||
|
||||
:ref:`sphx_glr_tutorials_nasbench_as_dataset.py`
|
||||
|
||||
.. raw:: html
|
||||
|
||||
</div>
|
||||
|
||||
|
||||
.. toctree::
|
||||
:hidden:
|
||||
|
||||
/tutorials/nasbench_as_dataset
|
||||
|
||||
.. raw:: html
|
||||
|
||||
<div class="sphx-glr-thumbcontainer" tooltip="Users can easily customize a basic pruner in NNI. A large number of basic modules have been pro...">
|
||||
|
||||
.. only:: html
|
||||
|
||||
.. figure:: /tutorials/images/thumb/sphx_glr_pruning_customize_thumb.png
|
||||
:alt: Customize Basic Pruner
|
||||
|
||||
:ref:`sphx_glr_tutorials_pruning_customize.py`
|
||||
|
||||
.. raw:: html
|
||||
|
||||
</div>
|
||||
|
||||
|
||||
.. toctree::
|
||||
:hidden:
|
||||
|
||||
/tutorials/pruning_customize
|
||||
|
||||
.. raw:: html
|
||||
|
||||
<div class="sphx-glr-thumbcontainer" tooltip="This is the 101 tutorial of Neural Architecture Search (NAS) on NNI. In this tutorial, we will ...">
|
||||
|
||||
.. only:: html
|
||||
|
||||
.. figure:: /tutorials/images/thumb/sphx_glr_hello_nas_thumb.png
|
||||
:alt: Hello, NAS!
|
||||
|
||||
:ref:`sphx_glr_tutorials_hello_nas.py`
|
||||
|
||||
.. raw:: html
|
||||
|
||||
</div>
|
||||
|
||||
|
||||
.. toctree::
|
||||
:hidden:
|
||||
|
||||
/tutorials/hello_nas
|
||||
|
||||
.. raw:: html
|
||||
|
||||
<div class="sphx-glr-thumbcontainer" tooltip="In this tutorial, we demonstrate how to search in the famous model space proposed in `DARTS`_.">
|
||||
|
||||
.. only:: html
|
||||
|
||||
.. figure:: /tutorials/images/thumb/sphx_glr_darts_thumb.png
|
||||
:alt: Searching in DARTS search space
|
||||
|
||||
:ref:`sphx_glr_tutorials_darts.py`
|
||||
|
||||
.. raw:: html
|
||||
|
||||
</div>
|
||||
|
||||
|
||||
.. toctree::
|
||||
:hidden:
|
||||
|
||||
/tutorials/darts
|
||||
|
||||
.. raw:: html
|
||||
|
||||
<div class="sphx-glr-thumbcontainer" tooltip="Workable Pruning Process ------------------------">
|
||||
|
||||
.. only:: html
|
||||
|
||||
.. figure:: /tutorials/images/thumb/sphx_glr_pruning_bert_glue_thumb.png
|
||||
:alt: Pruning Bert on Task MNLI
|
||||
|
||||
:ref:`sphx_glr_tutorials_pruning_bert_glue.py`
|
||||
|
||||
.. raw:: html
|
||||
|
||||
</div>
|
||||
|
||||
|
||||
.. toctree::
|
||||
:hidden:
|
||||
|
||||
/tutorials/pruning_bert_glue
|
||||
|
||||
|
||||
|
||||
|
||||
.. raw:: html
|
||||
|
||||
<div class="sphx-glr-clear"></div>
|
||||
|
||||
|
||||
|
||||
.. _sphx_glr_tutorials_hpo_quickstart_pytorch:
|
||||
|
||||
|
||||
<div class="sphx-glr-thumbnails">
|
||||
|
||||
|
||||
.. raw:: html
|
||||
|
@ -235,50 +213,44 @@ Tutorials
|
|||
|
||||
.. only:: html
|
||||
|
||||
.. figure:: /tutorials/hpo_quickstart_pytorch/images/thumb/sphx_glr_main_thumb.png
|
||||
:alt: HPO Quickstart with PyTorch
|
||||
.. image:: /tutorials/hpo_quickstart_pytorch/images/thumb/sphx_glr_main_thumb.png
|
||||
:alt: HPO Quickstart with PyTorch
|
||||
|
||||
:ref:`sphx_glr_tutorials_hpo_quickstart_pytorch_main.py`
|
||||
:ref:`sphx_glr_tutorials_hpo_quickstart_pytorch_main.py`
|
||||
|
||||
.. raw:: html
|
||||
|
||||
<div class="sphx-glr-thumbnail-title">HPO Quickstart with PyTorch</div>
|
||||
</div>
|
||||
|
||||
|
||||
.. toctree::
|
||||
:hidden:
|
||||
|
||||
/tutorials/hpo_quickstart_pytorch/main
|
||||
|
||||
.. raw:: html
|
||||
|
||||
<div class="sphx-glr-thumbcontainer" tooltip="It can be run directly and will have the exact same result as original version.">
|
||||
|
||||
.. only:: html
|
||||
|
||||
.. figure:: /tutorials/hpo_quickstart_pytorch/images/thumb/sphx_glr_model_thumb.png
|
||||
:alt: Port PyTorch Quickstart to NNI
|
||||
.. image:: /tutorials/hpo_quickstart_pytorch/images/thumb/sphx_glr_model_thumb.png
|
||||
:alt: Port PyTorch Quickstart to NNI
|
||||
|
||||
:ref:`sphx_glr_tutorials_hpo_quickstart_pytorch_model.py`
|
||||
|
||||
.. raw:: html
|
||||
|
||||
<div class="sphx-glr-thumbnail-title">Port PyTorch Quickstart to NNI</div>
|
||||
</div>
|
||||
|
||||
:ref:`sphx_glr_tutorials_hpo_quickstart_pytorch_model.py`
|
||||
|
||||
.. raw:: html
|
||||
|
||||
</div>
|
||||
|
||||
|
||||
.. toctree::
|
||||
:hidden:
|
||||
|
||||
/tutorials/hpo_quickstart_pytorch/model
|
||||
|
||||
.. raw:: html
|
||||
|
||||
<div class="sphx-glr-clear"></div>
|
||||
|
||||
|
||||
|
||||
.. _sphx_glr_tutorials_hpo_quickstart_tensorflow:
|
||||
|
||||
|
||||
<div class="sphx-glr-thumbnails">
|
||||
|
||||
|
||||
.. raw:: html
|
||||
|
@ -287,31 +259,33 @@ Tutorials
|
|||
|
||||
.. only:: html
|
||||
|
||||
.. figure:: /tutorials/hpo_quickstart_tensorflow/images/thumb/sphx_glr_main_thumb.png
|
||||
:alt: HPO Quickstart with TensorFlow
|
||||
.. image:: /tutorials/hpo_quickstart_tensorflow/images/thumb/sphx_glr_main_thumb.png
|
||||
:alt: HPO Quickstart with TensorFlow
|
||||
|
||||
:ref:`sphx_glr_tutorials_hpo_quickstart_tensorflow_main.py`
|
||||
:ref:`sphx_glr_tutorials_hpo_quickstart_tensorflow_main.py`
|
||||
|
||||
.. raw:: html
|
||||
|
||||
<div class="sphx-glr-thumbnail-title">HPO Quickstart with TensorFlow</div>
|
||||
</div>
|
||||
|
||||
|
||||
.. toctree::
|
||||
:hidden:
|
||||
|
||||
/tutorials/hpo_quickstart_tensorflow/main
|
||||
|
||||
.. raw:: html
|
||||
|
||||
<div class="sphx-glr-thumbcontainer" tooltip="It can be run directly and will have the exact same result as original version.">
|
||||
|
||||
.. only:: html
|
||||
|
||||
.. figure:: /tutorials/hpo_quickstart_tensorflow/images/thumb/sphx_glr_model_thumb.png
|
||||
:alt: Port TensorFlow Quickstart to NNI
|
||||
.. image:: /tutorials/hpo_quickstart_tensorflow/images/thumb/sphx_glr_model_thumb.png
|
||||
:alt: Port TensorFlow Quickstart to NNI
|
||||
|
||||
:ref:`sphx_glr_tutorials_hpo_quickstart_tensorflow_model.py`
|
||||
|
||||
.. raw:: html
|
||||
|
||||
<div class="sphx-glr-thumbnail-title">Port TensorFlow Quickstart to NNI</div>
|
||||
</div>
|
||||
|
||||
:ref:`sphx_glr_tutorials_hpo_quickstart_tensorflow_model.py`
|
||||
|
||||
.. raw:: html
|
||||
|
||||
|
@ -320,11 +294,10 @@ Tutorials
|
|||
|
||||
.. toctree::
|
||||
:hidden:
|
||||
:includehidden:
|
||||
|
||||
/tutorials/hpo_quickstart_tensorflow/model
|
||||
.. raw:: html
|
||||
|
||||
<div class="sphx-glr-clear"></div>
|
||||
/tutorials/hpo_quickstart_pytorch/index.rst
|
||||
/tutorials/hpo_quickstart_tensorflow/index.rst
|
||||
|
||||
|
||||
|
||||
|
|
|
@ -33,7 +33,7 @@
|
|||
},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"import torch\nimport torch.nn.functional as F\nfrom torch.optim import SGD\n\nfrom scripts.compression_mnist_model import TorchModel, trainer, evaluator, device\n\n# define the model\nmodel = TorchModel().to(device)\n\n# show the model structure, note that pruner will wrap the model layer.\nprint(model)"
|
||||
"import torch\nimport torch.nn.functional as F\nfrom torch.optim import SGD\n\nfrom nni_assets.compression.mnist_model import TorchModel, trainer, evaluator, device\n\n# define the model\nmodel = TorchModel().to(device)\n\n# show the model structure, note that pruner will wrap the model layer.\nprint(model)"
|
||||
]
|
||||
},
|
||||
{
|
||||
|
|
|
@ -29,7 +29,7 @@ import torch
|
|||
import torch.nn.functional as F
|
||||
from torch.optim import SGD
|
||||
|
||||
from scripts.compression_mnist_model import TorchModel, trainer, evaluator, device
|
||||
from nni_assets.compression.mnist_model import TorchModel, trainer, evaluator, device
|
||||
|
||||
# define the model
|
||||
model = TorchModel().to(device)
|
||||
|
|
|
@ -1 +1 @@
|
|||
33781311d6344b4aebb94db94a96dfd3
|
||||
e7c8d40b9d497d59db95ffcedfc1c450
|
|
@ -53,7 +53,7 @@ If you are familiar with defining a model and training in pytorch, you can skip
|
|||
import torch.nn.functional as F
|
||||
from torch.optim import SGD
|
||||
|
||||
from scripts.compression_mnist_model import TorchModel, trainer, evaluator, device
|
||||
from nni_assets.compression.mnist_model import TorchModel, trainer, evaluator, device
|
||||
|
||||
# define the model
|
||||
model = TorchModel().to(device)
|
||||
|
@ -67,8 +67,6 @@ If you are familiar with defining a model and training in pytorch, you can skip
|
|||
|
||||
.. rst-class:: sphx-glr-script-out
|
||||
|
||||
Out:
|
||||
|
||||
.. code-block:: none
|
||||
|
||||
TorchModel(
|
||||
|
@ -109,13 +107,11 @@ If you are familiar with defining a model and training in pytorch, you can skip
|
|||
|
||||
.. rst-class:: sphx-glr-script-out
|
||||
|
||||
Out:
|
||||
|
||||
.. code-block:: none
|
||||
|
||||
Average test loss: 0.4925, Accuracy: 8414/10000 (84%)
|
||||
Average test loss: 0.2626, Accuracy: 9214/10000 (92%)
|
||||
Average test loss: 0.2006, Accuracy: 9369/10000 (94%)
|
||||
Average test loss: 1.3409, Accuracy: 6494/10000 (65%)
|
||||
Average test loss: 0.3263, Accuracy: 9003/10000 (90%)
|
||||
Average test loss: 0.2029, Accuracy: 9388/10000 (94%)
|
||||
|
||||
|
||||
|
||||
|
@ -174,8 +170,6 @@ Pruners usually require `model` and `config_list` as input arguments.
|
|||
|
||||
.. rst-class:: sphx-glr-script-out
|
||||
|
||||
Out:
|
||||
|
||||
.. code-block:: none
|
||||
|
||||
TorchModel(
|
||||
|
@ -220,8 +214,6 @@ Pruners usually require `model` and `config_list` as input arguments.
|
|||
|
||||
.. rst-class:: sphx-glr-script-out
|
||||
|
||||
Out:
|
||||
|
||||
.. code-block:: none
|
||||
|
||||
conv1 sparsity : 0.5
|
||||
|
@ -257,12 +249,8 @@ and reaches a higher sparsity ratio because `ModelSpeedup` will propagate the ma
|
|||
|
||||
.. rst-class:: sphx-glr-script-out
|
||||
|
||||
Out:
|
||||
|
||||
.. code-block:: none
|
||||
|
||||
aten::log_softmax is not Supported! Please report an issue at https://github.com/microsoft/nni. Thanks~
|
||||
Note: .aten::log_softmax.12 does not have corresponding mask inference object
|
||||
/home/ningshang/anaconda3/envs/nni-dev/lib/python3.8/site-packages/torch/_tensor.py:1013: UserWarning: The .grad attribute of a Tensor that is not a leaf Tensor is being accessed. Its .grad attribute won't be populated during autograd.backward(). If you indeed want the .grad field to be populated for a non-leaf Tensor, use .retain_grad() on the non-leaf Tensor. If you access the non-leaf Tensor by mistake, make sure you access the leaf Tensor instead. See github.com/pytorch/pytorch/pull/30531 for more informations. (Triggered internally at aten/src/ATen/core/TensorBody.h:417.)
|
||||
return self._grad
|
||||
|
||||
|
@ -285,8 +273,6 @@ the model will become real smaller after speedup
|
|||
|
||||
.. rst-class:: sphx-glr-script-out
|
||||
|
||||
Out:
|
||||
|
||||
.. code-block:: none
|
||||
|
||||
TorchModel(
|
||||
|
@ -331,28 +317,23 @@ Because speedup will replace the masked big layers with dense small ones.
|
|||
|
||||
.. rst-class:: sphx-glr-timing
|
||||
|
||||
**Total running time of the script:** ( 1 minutes 30.730 seconds)
|
||||
**Total running time of the script:** ( 1 minutes 0.810 seconds)
|
||||
|
||||
|
||||
.. _sphx_glr_download_tutorials_pruning_quick_start_mnist.py:
|
||||
|
||||
.. only:: html
|
||||
|
||||
.. only :: html
|
||||
|
||||
.. container:: sphx-glr-footer
|
||||
:class: sphx-glr-footer-example
|
||||
.. container:: sphx-glr-footer sphx-glr-footer-example
|
||||
|
||||
|
||||
.. container:: sphx-glr-download sphx-glr-download-python
|
||||
|
||||
.. container:: sphx-glr-download sphx-glr-download-python
|
||||
:download:`Download Python source code: pruning_quick_start_mnist.py <pruning_quick_start_mnist.py>`
|
||||
|
||||
:download:`Download Python source code: pruning_quick_start_mnist.py <pruning_quick_start_mnist.py>`
|
||||
.. container:: sphx-glr-download sphx-glr-download-jupyter
|
||||
|
||||
|
||||
|
||||
.. container:: sphx-glr-download sphx-glr-download-jupyter
|
||||
|
||||
:download:`Download Jupyter notebook: pruning_quick_start_mnist.ipynb <pruning_quick_start_mnist.ipynb>`
|
||||
:download:`Download Jupyter notebook: pruning_quick_start_mnist.ipynb <pruning_quick_start_mnist.ipynb>`
|
||||
|
||||
|
||||
.. only:: html
|
||||
|
|
Двоичный файл не отображается.
|
@ -1,4 +1,4 @@
|
|||
.. b743ab67f64dd0a0688a8cb184e0e947
|
||||
.. f2006d635ba8b91cd9cd311c1bd844f3
|
||||
|
||||
.. note::
|
||||
:class: sphx-glr-download-link-note
|
||||
|
|
|
@ -33,7 +33,7 @@
|
|||
},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"import torch\nfrom scripts.compression_mnist_model import TorchModel, device\n\nmodel = TorchModel().to(device)\n# masks = {layer_name: {'weight': weight_mask, 'bias': bias_mask}}\nconv1_mask = torch.ones_like(model.conv1.weight.data)\n# mask the first three output channels in conv1\nconv1_mask[0: 3] = 0\nmasks = {'conv1': {'weight': conv1_mask}}"
|
||||
"import torch\nfrom nni_assets.compression.mnist_model import TorchModel, device\n\nmodel = TorchModel().to(device)\n# masks = {layer_name: {'weight': weight_mask, 'bias': bias_mask}}\nconv1_mask = torch.ones_like(model.conv1.weight.data)\n# mask the first three output channels in conv1\nconv1_mask[0: 3] = 0\nmasks = {'conv1': {'weight': conv1_mask}}"
|
||||
]
|
||||
},
|
||||
{
|
||||
|
@ -112,7 +112,7 @@
|
|||
"cell_type": "markdown",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"For combining usage of ``Pruner`` masks generation with ``ModelSpeedup``,\nplease refer to :doc:`Pruning Quick Start <pruning_quick_start_mnist>`.\n\nNOTE: The current implementation supports PyTorch 1.3.1 or newer.\n\n## Limitations\n\nFor PyTorch we can only replace modules, if functions in ``forward`` should be replaced,\nour current implementation does not work. One workaround is make the function a PyTorch module.\n\nIf you want to speedup your own model which cannot supported by the current implementation,\nyou need implement the replace function for module replacement, welcome to contribute.\n\n## Speedup Results of Examples\n\nThe code of these experiments can be found :githublink:`here <examples/model_compress/pruning/legacy/speedup/model_speedup.py>`.\n\nThese result are tested on the `legacy pruning framework <https://nni.readthedocs.io/en/v2.6/Compression/pruning.html>`_, new results will coming soon.\n\n### slim pruner example\n\non one V100 GPU,\ninput tensor: ``torch.randn(64, 3, 32, 32)``\n\n.. list-table::\n :header-rows: 1\n :widths: auto\n\n * - Times\n - Mask Latency\n - Speedup Latency\n * - 1\n - 0.01197\n - 0.005107\n * - 2\n - 0.02019\n - 0.008769\n * - 4\n - 0.02733\n - 0.014809\n * - 8\n - 0.04310\n - 0.027441\n * - 16\n - 0.07731\n - 0.05008\n * - 32\n - 0.14464\n - 0.10027\n\n### fpgm pruner example\n\non cpu,\ninput tensor: ``torch.randn(64, 1, 28, 28)``\\ ,\ntoo large variance\n\n.. list-table::\n :header-rows: 1\n :widths: auto\n\n * - Times\n - Mask Latency\n - Speedup Latency\n * - 1\n - 0.01383\n - 0.01839\n * - 2\n - 0.01167\n - 0.003558\n * - 4\n - 0.01636\n - 0.01088\n * - 40\n - 0.14412\n - 0.08268\n * - 40\n - 1.29385\n - 0.14408\n * - 40\n - 0.41035\n - 0.46162\n * - 400\n - 6.29020\n - 5.82143\n\n### l1filter pruner example\n\non one V100 GPU,\ninput tensor: ``torch.randn(64, 3, 32, 32)``\n\n.. list-table::\n :header-rows: 1\n :widths: auto\n\n * - Times\n - Mask Latency\n - Speedup Latency\n * - 1\n - 0.01026\n - 0.003677\n * - 2\n - 0.01657\n - 0.008161\n * - 4\n - 0.02458\n - 0.020018\n * - 8\n - 0.03498\n - 0.025504\n * - 16\n - 0.06757\n - 0.047523\n * - 32\n - 0.10487\n - 0.086442\n\n### APoZ pruner example\n\non one V100 GPU,\ninput tensor: ``torch.randn(64, 3, 32, 32)``\n\n.. list-table::\n :header-rows: 1\n :widths: auto\n\n * - Times\n - Mask Latency\n - Speedup Latency\n * - 1\n - 0.01389\n - 0.004208\n * - 2\n - 0.01628\n - 0.008310\n * - 4\n - 0.02521\n - 0.014008\n * - 8\n - 0.03386\n - 0.023923\n * - 16\n - 0.06042\n - 0.046183\n * - 32\n - 0.12421\n - 0.087113\n\n### SimulatedAnnealing pruner example\n\nIn this experiment, we use SimulatedAnnealing pruner to prune the resnet18 on the cifar10 dataset.\nWe measure the latencies and accuracies of the pruned model under different sparsity ratios, as shown in the following figure.\nThe latency is measured on one V100 GPU and the input tensor is ``torch.randn(128, 3, 32, 32)``.\n\n<img src=\"file://../../img/SA_latency_accuracy.png\">\n\n"
|
||||
"For combining usage of ``Pruner`` masks generation with ``ModelSpeedup``,\nplease refer to :doc:`Pruning Quick Start <pruning_quick_start_mnist>`.\n\nNOTE: The current implementation supports PyTorch 1.3.1 or newer.\n\n## Limitations\n\nFor PyTorch we can only replace modules, if functions in ``forward`` should be replaced,\nour current implementation does not work. One workaround is make the function a PyTorch module.\n\nIf you want to speedup your own model which cannot supported by the current implementation,\nyou need implement the replace function for module replacement, welcome to contribute.\n\n## Speedup Results of Examples\n\nThe code of these experiments can be found :githublink:`here <examples/model_compress/pruning/legacy/speedup/model_speedup.py>`.\n\nThese result are tested on the [legacy pruning framework](https://nni.readthedocs.io/en/v2.6/Compression/pruning.html), new results will coming soon.\n\n### slim pruner example\n\non one V100 GPU,\ninput tensor: ``torch.randn(64, 3, 32, 32)``\n\n.. list-table::\n :header-rows: 1\n :widths: auto\n\n * - Times\n - Mask Latency\n - Speedup Latency\n * - 1\n - 0.01197\n - 0.005107\n * - 2\n - 0.02019\n - 0.008769\n * - 4\n - 0.02733\n - 0.014809\n * - 8\n - 0.04310\n - 0.027441\n * - 16\n - 0.07731\n - 0.05008\n * - 32\n - 0.14464\n - 0.10027\n\n### fpgm pruner example\n\non cpu,\ninput tensor: ``torch.randn(64, 1, 28, 28)``\\ ,\ntoo large variance\n\n.. list-table::\n :header-rows: 1\n :widths: auto\n\n * - Times\n - Mask Latency\n - Speedup Latency\n * - 1\n - 0.01383\n - 0.01839\n * - 2\n - 0.01167\n - 0.003558\n * - 4\n - 0.01636\n - 0.01088\n * - 40\n - 0.14412\n - 0.08268\n * - 40\n - 1.29385\n - 0.14408\n * - 40\n - 0.41035\n - 0.46162\n * - 400\n - 6.29020\n - 5.82143\n\n### l1filter pruner example\n\non one V100 GPU,\ninput tensor: ``torch.randn(64, 3, 32, 32)``\n\n.. list-table::\n :header-rows: 1\n :widths: auto\n\n * - Times\n - Mask Latency\n - Speedup Latency\n * - 1\n - 0.01026\n - 0.003677\n * - 2\n - 0.01657\n - 0.008161\n * - 4\n - 0.02458\n - 0.020018\n * - 8\n - 0.03498\n - 0.025504\n * - 16\n - 0.06757\n - 0.047523\n * - 32\n - 0.10487\n - 0.086442\n\n### APoZ pruner example\n\non one V100 GPU,\ninput tensor: ``torch.randn(64, 3, 32, 32)``\n\n.. list-table::\n :header-rows: 1\n :widths: auto\n\n * - Times\n - Mask Latency\n - Speedup Latency\n * - 1\n - 0.01389\n - 0.004208\n * - 2\n - 0.01628\n - 0.008310\n * - 4\n - 0.02521\n - 0.014008\n * - 8\n - 0.03386\n - 0.023923\n * - 16\n - 0.06042\n - 0.046183\n * - 32\n - 0.12421\n - 0.087113\n\n### SimulatedAnnealing pruner example\n\nIn this experiment, we use SimulatedAnnealing pruner to prune the resnet18 on the cifar10 dataset.\nWe measure the latencies and accuracies of the pruned model under different sparsity ratios, as shown in the following figure.\nThe latency is measured on one V100 GPU and the input tensor is ``torch.randn(128, 3, 32, 32)``.\n\n<img src=\"file://../../img/SA_latency_accuracy.png\">\n\n"
|
||||
]
|
||||
}
|
||||
],
|
||||
|
@ -132,7 +132,7 @@
|
|||
"name": "python",
|
||||
"nbconvert_exporter": "python",
|
||||
"pygments_lexer": "ipython3",
|
||||
"version": "3.9.7"
|
||||
"version": "3.8.8"
|
||||
}
|
||||
},
|
||||
"nbformat": 4,
|
||||
|
|
|
@ -43,7 +43,7 @@ Usage
|
|||
# But in fact ``ModelSpeedup`` is a relatively independent tool, so you can use it independently.
|
||||
|
||||
import torch
|
||||
from scripts.compression_mnist_model import TorchModel, device
|
||||
from nni_assets.compression.mnist_model import TorchModel, device
|
||||
|
||||
model = TorchModel().to(device)
|
||||
# masks = {layer_name: {'weight': weight_mask, 'bias': bias_mask}}
|
||||
|
|
|
@ -1 +1 @@
|
|||
dc5c2369666206591238118f0f746e46
|
||||
db22d0e9ae78d8e7910b77e3b4541dd5
|
|
@ -66,7 +66,7 @@ But in fact ``ModelSpeedup`` is a relatively independent tool, so you can use it
|
|||
|
||||
|
||||
import torch
|
||||
from scripts.compression_mnist_model import TorchModel, device
|
||||
from nni_assets.compression.mnist_model import TorchModel, device
|
||||
|
||||
model = TorchModel().to(device)
|
||||
# masks = {layer_name: {'weight': weight_mask, 'bias': bias_mask}}
|
||||
|
@ -98,8 +98,6 @@ Show the original model structure.
|
|||
|
||||
.. rst-class:: sphx-glr-script-out
|
||||
|
||||
Out:
|
||||
|
||||
.. code-block:: none
|
||||
|
||||
TorchModel(
|
||||
|
@ -138,11 +136,9 @@ Roughly test the original model inference speed.
|
|||
|
||||
.. rst-class:: sphx-glr-script-out
|
||||
|
||||
Out:
|
||||
|
||||
.. code-block:: none
|
||||
|
||||
Original Model - Elapsed Time : 0.5094916820526123
|
||||
Original Model - Elapsed Time : 0.1178426742553711
|
||||
|
||||
|
||||
|
||||
|
@ -165,13 +161,9 @@ Speedup the model and show the model structure after speedup.
|
|||
|
||||
.. rst-class:: sphx-glr-script-out
|
||||
|
||||
Out:
|
||||
|
||||
.. code-block:: none
|
||||
|
||||
aten::log_softmax is not Supported! Please report an issue at https://github.com/microsoft/nni. Thanks~
|
||||
Note: .aten::log_softmax.12 does not have corresponding mask inference object
|
||||
/home/nishang/anaconda3/envs/MCM/lib/python3.9/site-packages/torch/_tensor.py:1013: UserWarning: The .grad attribute of a Tensor that is not a leaf Tensor is being accessed. Its .grad attribute won't be populated during autograd.backward(). If you indeed want the .grad field to be populated for a non-leaf Tensor, use .retain_grad() on the non-leaf Tensor. If you access the non-leaf Tensor by mistake, make sure you access the leaf Tensor instead. See github.com/pytorch/pytorch/pull/30531 for more informations. (Triggered internally at /opt/conda/conda-bld/pytorch_1640811803361/work/build/aten/src/ATen/core/TensorBody.h:417.)
|
||||
/home/ningshang/anaconda3/envs/nni-dev/lib/python3.8/site-packages/torch/_tensor.py:1013: UserWarning: The .grad attribute of a Tensor that is not a leaf Tensor is being accessed. Its .grad attribute won't be populated during autograd.backward(). If you indeed want the .grad field to be populated for a non-leaf Tensor, use .retain_grad() on the non-leaf Tensor. If you access the non-leaf Tensor by mistake, make sure you access the leaf Tensor instead. See github.com/pytorch/pytorch/pull/30531 for more informations. (Triggered internally at aten/src/ATen/core/TensorBody.h:417.)
|
||||
return self._grad
|
||||
TorchModel(
|
||||
(conv1): Conv2d(1, 3, kernel_size=(5, 5), stride=(1, 1))
|
||||
|
@ -208,11 +200,9 @@ Roughly test the model after speedup inference speed.
|
|||
|
||||
.. rst-class:: sphx-glr-script-out
|
||||
|
||||
Out:
|
||||
|
||||
.. code-block:: none
|
||||
|
||||
Speedup Model - Elapsed Time : 0.006000041961669922
|
||||
Speedup Model - Elapsed Time : 0.003069639205932617
|
||||
|
||||
|
||||
|
||||
|
@ -384,28 +374,23 @@ The latency is measured on one V100 GPU and the input tensor is ``torch.randn(1
|
|||
|
||||
.. rst-class:: sphx-glr-timing
|
||||
|
||||
**Total running time of the script:** ( 0 minutes 4.528 seconds)
|
||||
**Total running time of the script:** ( 0 minutes 15.253 seconds)
|
||||
|
||||
|
||||
.. _sphx_glr_download_tutorials_pruning_speedup.py:
|
||||
|
||||
.. only:: html
|
||||
|
||||
.. only :: html
|
||||
|
||||
.. container:: sphx-glr-footer
|
||||
:class: sphx-glr-footer-example
|
||||
.. container:: sphx-glr-footer sphx-glr-footer-example
|
||||
|
||||
|
||||
.. container:: sphx-glr-download sphx-glr-download-python
|
||||
|
||||
.. container:: sphx-glr-download sphx-glr-download-python
|
||||
:download:`Download Python source code: pruning_speedup.py <pruning_speedup.py>`
|
||||
|
||||
:download:`Download Python source code: pruning_speedup.py <pruning_speedup.py>`
|
||||
.. container:: sphx-glr-download sphx-glr-download-jupyter
|
||||
|
||||
|
||||
|
||||
.. container:: sphx-glr-download sphx-glr-download-jupyter
|
||||
|
||||
:download:`Download Jupyter notebook: pruning_speedup.ipynb <pruning_speedup.ipynb>`
|
||||
:download:`Download Jupyter notebook: pruning_speedup.ipynb <pruning_speedup.ipynb>`
|
||||
|
||||
|
||||
.. only:: html
|
||||
|
|
Двоичный файл не отображается.
|
@ -33,7 +33,7 @@
|
|||
},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"import torch\nimport torch.nn.functional as F\nfrom torch.optim import SGD\n\nfrom scripts.compression_mnist_model import TorchModel, trainer, evaluator, device, test_trt\n\n# define the model\nmodel = TorchModel().to(device)\n\n# define the optimizer and criterion for pre-training\n\noptimizer = SGD(model.parameters(), 1e-2)\ncriterion = F.nll_loss\n\n# pre-train and evaluate the model on MNIST dataset\nfor epoch in range(3):\n trainer(model, optimizer, criterion)\n evaluator(model)"
|
||||
"import torch\nimport torch.nn.functional as F\nfrom torch.optim import SGD\n\nfrom nni_assets.compression.mnist_model import TorchModel, trainer, evaluator, device, test_trt\n\n# define the model\nmodel = TorchModel().to(device)\n\n# define the optimizer and criterion for pre-training\n\noptimizer = SGD(model.parameters(), 1e-2)\ncriterion = F.nll_loss\n\n# pre-train and evaluate the model on MNIST dataset\nfor epoch in range(3):\n trainer(model, optimizer, criterion)\n evaluator(model)"
|
||||
]
|
||||
},
|
||||
{
|
||||
|
|
|
@ -24,7 +24,7 @@ import torch
|
|||
import torch.nn.functional as F
|
||||
from torch.optim import SGD
|
||||
|
||||
from scripts.compression_mnist_model import TorchModel, trainer, evaluator, device, test_trt
|
||||
from nni_assets.compression.mnist_model import TorchModel, trainer, evaluator, device, test_trt
|
||||
|
||||
# define the model
|
||||
model = TorchModel().to(device)
|
||||
|
|
|
@ -1 +1 @@
|
|||
2995cef94c5c6c66a6dfa4b5ff28baea
|
||||
e502fe5ce56078aeca7926e830bb9cae
|
|
@ -48,7 +48,7 @@ If you are familiar with defining a model and training in pytorch, you can skip
|
|||
import torch.nn.functional as F
|
||||
from torch.optim import SGD
|
||||
|
||||
from scripts.compression_mnist_model import TorchModel, trainer, evaluator, device, test_trt
|
||||
from nni_assets.compression.mnist_model import TorchModel, trainer, evaluator, device, test_trt
|
||||
|
||||
# define the model
|
||||
model = TorchModel().to(device)
|
||||
|
@ -73,9 +73,9 @@ If you are familiar with defining a model and training in pytorch, you can skip
|
|||
|
||||
.. code-block:: none
|
||||
|
||||
Average test loss: 0.5901, Accuracy: 8293/10000 (83%)
|
||||
Average test loss: 0.2469, Accuracy: 9245/10000 (92%)
|
||||
Average test loss: 0.1586, Accuracy: 9531/10000 (95%)
|
||||
Average test loss: 0.8954, Accuracy: 6995/10000 (70%)
|
||||
Average test loss: 0.3259, Accuracy: 9046/10000 (90%)
|
||||
Average test loss: 0.2125, Accuracy: 9354/10000 (94%)
|
||||
|
||||
|
||||
|
||||
|
@ -195,9 +195,9 @@ QAT is a training-aware quantizer, it will update scale and zero point during tr
|
|||
|
||||
.. code-block:: none
|
||||
|
||||
Average test loss: 0.1333, Accuracy: 9587/10000 (96%)
|
||||
Average test loss: 0.1076, Accuracy: 9660/10000 (97%)
|
||||
Average test loss: 0.0957, Accuracy: 9702/10000 (97%)
|
||||
Average test loss: 0.1858, Accuracy: 9438/10000 (94%)
|
||||
Average test loss: 0.1420, Accuracy: 9564/10000 (96%)
|
||||
Average test loss: 0.1213, Accuracy: 9632/10000 (96%)
|
||||
|
||||
|
||||
|
||||
|
@ -226,9 +226,7 @@ export model and get calibration_config
|
|||
|
||||
.. code-block:: none
|
||||
|
||||
INFO:nni.compression.pytorch.compressor:Model state_dict saved to ./log/mnist_model.pth
|
||||
INFO:nni.compression.pytorch.compressor:Mask dict saved to ./log/mnist_calibration.pth
|
||||
calibration_config: {'conv1': {'weight_bits': 8, 'weight_scale': tensor([0.0029], device='cuda:0'), 'weight_zero_point': tensor([96.], device='cuda:0'), 'input_bits': 8, 'tracked_min_input': -0.4242129623889923, 'tracked_max_input': 2.821486711502075}, 'conv2': {'weight_bits': 8, 'weight_scale': tensor([0.0017], device='cuda:0'), 'weight_zero_point': tensor([101.], device='cuda:0'), 'input_bits': 8, 'tracked_min_input': 0.0, 'tracked_max_input': 10.014460563659668}, 'fc1': {'weight_bits': 8, 'weight_scale': tensor([0.0012], device='cuda:0'), 'weight_zero_point': tensor([118.], device='cuda:0'), 'input_bits': 8, 'tracked_min_input': 0.0, 'tracked_max_input': 25.994585037231445}, 'fc2': {'weight_bits': 8, 'weight_scale': tensor([0.0012], device='cuda:0'), 'weight_zero_point': tensor([120.], device='cuda:0'), 'input_bits': 8, 'tracked_min_input': 0.0, 'tracked_max_input': 21.589195251464844}, 'relu1': {'output_bits': 8, 'tracked_min_output': 0.0, 'tracked_max_output': 10.066218376159668}, 'relu2': {'output_bits': 8, 'tracked_min_output': 0.0, 'tracked_max_output': 26.317869186401367}, 'relu3': {'output_bits': 8, 'tracked_min_output': 0.0, 'tracked_max_output': 21.97711944580078}, 'relu4': {'output_bits': 8, 'tracked_min_output': 0.0, 'tracked_max_output': 17.56885528564453}}
|
||||
calibration_config: {'conv1': {'weight_bits': 8, 'weight_scale': tensor([0.0039], device='cuda:0'), 'weight_zero_point': tensor([82.], device='cuda:0'), 'input_bits': 8, 'tracked_min_input': -0.4242129623889923, 'tracked_max_input': 2.821486711502075}, 'conv2': {'weight_bits': 8, 'weight_scale': tensor([0.0019], device='cuda:0'), 'weight_zero_point': tensor([127.], device='cuda:0'), 'input_bits': 8, 'tracked_min_input': 0.0, 'tracked_max_input': 18.87591552734375}, 'fc1': {'weight_bits': 8, 'weight_scale': tensor([0.0010], device='cuda:0'), 'weight_zero_point': tensor([123.], device='cuda:0'), 'input_bits': 8, 'tracked_min_input': 0.0, 'tracked_max_input': 26.67470932006836}, 'fc2': {'weight_bits': 8, 'weight_scale': tensor([0.0012], device='cuda:0'), 'weight_zero_point': tensor([129.], device='cuda:0'), 'input_bits': 8, 'tracked_min_input': 0.0, 'tracked_max_input': 21.60409164428711}, 'relu1': {'output_bits': 8, 'tracked_min_output': 0.0, 'tracked_max_output': 18.998125076293945}, 'relu2': {'output_bits': 8, 'tracked_min_output': 0.0, 'tracked_max_output': 27.000442504882812}, 'relu3': {'output_bits': 8, 'tracked_min_output': 0.0, 'tracked_max_output': 22.2519588470459}, 'relu4': {'output_bits': 8, 'tracked_min_output': 0.0, 'tracked_max_output': 17.8553524017334}}
|
||||
|
||||
|
||||
|
||||
|
@ -257,8 +255,8 @@ build tensorRT engine to make a real speedup, for more information about speedup
|
|||
|
||||
.. code-block:: none
|
||||
|
||||
Loss: 0.09545102081298829 Accuracy: 96.98%
|
||||
Inference elapsed_time (whole dataset): 0.03549933433532715s
|
||||
Loss: 0.12193695755004882 Accuracy: 96.38%
|
||||
Inference elapsed_time (whole dataset): 0.036092281341552734s
|
||||
|
||||
|
||||
|
||||
|
@ -266,7 +264,7 @@ build tensorRT engine to make a real speedup, for more information about speedup
|
|||
|
||||
.. rst-class:: sphx-glr-timing
|
||||
|
||||
**Total running time of the script:** ( 1 minutes 45.743 seconds)
|
||||
**Total running time of the script:** ( 1 minutes 39.686 seconds)
|
||||
|
||||
|
||||
.. _sphx_glr_download_tutorials_quantization_quick_start_mnist.py:
|
||||
|
|
Двоичный файл не отображается.
|
@ -26,7 +26,7 @@
|
|||
},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"import torch\nimport torch.nn.functional as F\nfrom torch.optim import SGD\nfrom scripts.compression_mnist_model import TorchModel, device, trainer, evaluator, test_trt\n\nconfig_list = [{\n 'quant_types': ['input', 'weight'],\n 'quant_bits': {'input': 8, 'weight': 8},\n 'op_types': ['Conv2d']\n}, {\n 'quant_types': ['output'],\n 'quant_bits': {'output': 8},\n 'op_types': ['ReLU']\n}, {\n 'quant_types': ['input', 'weight'],\n 'quant_bits': {'input': 8, 'weight': 8},\n 'op_names': ['fc1', 'fc2']\n}]\n\nmodel = TorchModel().to(device)\noptimizer = SGD(model.parameters(), lr=0.01, momentum=0.5)\ncriterion = F.nll_loss\ndummy_input = torch.rand(32, 1, 28, 28).to(device)\n\nfrom nni.algorithms.compression.pytorch.quantization import QAT_Quantizer\nquantizer = QAT_Quantizer(model, config_list, optimizer, dummy_input)\nquantizer.compress()"
|
||||
"import torch\nimport torch.nn.functional as F\nfrom torch.optim import SGD\nfrom nni_assets.compression.mnist_model import TorchModel, device, trainer, evaluator, test_trt\n\nconfig_list = [{\n 'quant_types': ['input', 'weight'],\n 'quant_bits': {'input': 8, 'weight': 8},\n 'op_types': ['Conv2d']\n}, {\n 'quant_types': ['output'],\n 'quant_bits': {'output': 8},\n 'op_types': ['ReLU']\n}, {\n 'quant_types': ['input', 'weight'],\n 'quant_bits': {'input': 8, 'weight': 8},\n 'op_names': ['fc1', 'fc2']\n}]\n\nmodel = TorchModel().to(device)\noptimizer = SGD(model.parameters(), lr=0.01, momentum=0.5)\ncriterion = F.nll_loss\ndummy_input = torch.rand(32, 1, 28, 28).to(device)\n\nfrom nni.algorithms.compression.pytorch.quantization import QAT_Quantizer\nquantizer = QAT_Quantizer(model, config_list, optimizer, dummy_input)\nquantizer.compress()"
|
||||
]
|
||||
},
|
||||
{
|
||||
|
@ -107,7 +107,7 @@
|
|||
"name": "python",
|
||||
"nbconvert_exporter": "python",
|
||||
"pygments_lexer": "ipython3",
|
||||
"version": "3.9.7"
|
||||
"version": "3.8.8"
|
||||
}
|
||||
},
|
||||
"nbformat": 4,
|
||||
|
|
|
@ -64,7 +64,7 @@ Usage
|
|||
import torch
|
||||
import torch.nn.functional as F
|
||||
from torch.optim import SGD
|
||||
from scripts.compression_mnist_model import TorchModel, device, trainer, evaluator, test_trt
|
||||
from nni_assets.compression.mnist_model import TorchModel, device, trainer, evaluator, test_trt
|
||||
|
||||
config_list = [{
|
||||
'quant_types': ['input', 'weight'],
|
||||
|
|
|
@ -1 +1 @@
|
|||
2404b8d0c3958a0191b77bbe882456e4
|
||||
06c37bd5c886478ae20a1fc552af729a
|
|
@ -84,7 +84,7 @@ Usage
|
|||
import torch
|
||||
import torch.nn.functional as F
|
||||
from torch.optim import SGD
|
||||
from scripts.compression_mnist_model import TorchModel, device, trainer, evaluator, test_trt
|
||||
from nni_assets.compression.mnist_model import TorchModel, device, trainer, evaluator, test_trt
|
||||
|
||||
config_list = [{
|
||||
'quant_types': ['input', 'weight'],
|
||||
|
@ -174,9 +174,9 @@ finetuning the model by using QAT
|
|||
|
||||
.. code-block:: none
|
||||
|
||||
Average test loss: 0.5386, Accuracy: 8619/10000 (86%)
|
||||
Average test loss: 0.1553, Accuracy: 9521/10000 (95%)
|
||||
Average test loss: 0.1001, Accuracy: 9686/10000 (97%)
|
||||
Average test loss: 0.6058, Accuracy: 8534/10000 (85%)
|
||||
Average test loss: 0.1585, Accuracy: 9508/10000 (95%)
|
||||
Average test loss: 0.0920, Accuracy: 9717/10000 (97%)
|
||||
|
||||
|
||||
|
||||
|
@ -207,7 +207,7 @@ export model and get calibration_config
|
|||
|
||||
.. code-block:: none
|
||||
|
||||
calibration_config: {'conv1': {'weight_bits': 8, 'weight_scale': tensor([0.0029], device='cuda:0'), 'weight_zero_point': tensor([98.], device='cuda:0'), 'input_bits': 8, 'tracked_min_input': -0.4242129623889923, 'tracked_max_input': 2.821486711502075}, 'conv2': {'weight_bits': 8, 'weight_scale': tensor([0.0017], device='cuda:0'), 'weight_zero_point': tensor([124.], device='cuda:0'), 'input_bits': 8, 'tracked_min_input': 0.0, 'tracked_max_input': 8.848002433776855}, 'fc1': {'weight_bits': 8, 'weight_scale': tensor([0.0010], device='cuda:0'), 'weight_zero_point': tensor([134.], device='cuda:0'), 'input_bits': 8, 'tracked_min_input': 0.0, 'tracked_max_input': 14.64758586883545}, 'fc2': {'weight_bits': 8, 'weight_scale': tensor([0.0013], device='cuda:0'), 'weight_zero_point': tensor([121.], device='cuda:0'), 'input_bits': 8, 'tracked_min_input': 0.0, 'tracked_max_input': 15.807988166809082}, 'relu1': {'output_bits': 8, 'tracked_min_output': 0.0, 'tracked_max_output': 9.041301727294922}, 'relu2': {'output_bits': 8, 'tracked_min_output': 0.0, 'tracked_max_output': 15.143928527832031}, 'relu3': {'output_bits': 8, 'tracked_min_output': 0.0, 'tracked_max_output': 16.151935577392578}, 'relu4': {'output_bits': 8, 'tracked_min_output': 0.0, 'tracked_max_output': 11.749024391174316}}
|
||||
calibration_config: {'conv1': {'weight_bits': 8, 'weight_scale': tensor([0.0029], device='cuda:0'), 'weight_zero_point': tensor([97.], device='cuda:0'), 'input_bits': 8, 'tracked_min_input': -0.4242129623889923, 'tracked_max_input': 2.821486711502075}, 'conv2': {'weight_bits': 8, 'weight_scale': tensor([0.0017], device='cuda:0'), 'weight_zero_point': tensor([115.], device='cuda:0'), 'input_bits': 8, 'tracked_min_input': 0.0, 'tracked_max_input': 7.800363063812256}, 'fc1': {'weight_bits': 8, 'weight_scale': tensor([0.0010], device='cuda:0'), 'weight_zero_point': tensor([121.], device='cuda:0'), 'input_bits': 8, 'tracked_min_input': 0.0, 'tracked_max_input': 13.914573669433594}, 'fc2': {'weight_bits': 8, 'weight_scale': tensor([0.0012], device='cuda:0'), 'weight_zero_point': tensor([125.], device='cuda:0'), 'input_bits': 8, 'tracked_min_input': 0.0, 'tracked_max_input': 11.657418251037598}, 'relu1': {'output_bits': 8, 'tracked_min_output': 0.0, 'tracked_max_output': 7.897384166717529}, 'relu2': {'output_bits': 8, 'tracked_min_output': 0.0, 'tracked_max_output': 14.337020874023438}, 'relu3': {'output_bits': 8, 'tracked_min_output': 0.0, 'tracked_max_output': 11.884227752685547}, 'relu4': {'output_bits': 8, 'tracked_min_output': 0.0, 'tracked_max_output': 9.330422401428223}}
|
||||
|
||||
|
||||
|
||||
|
@ -237,8 +237,8 @@ build tensorRT engine to make a real speedup
|
|||
|
||||
.. code-block:: none
|
||||
|
||||
Loss: 0.10061546401977539 Accuracy: 96.83%
|
||||
Inference elapsed_time (whole dataset): 0.04322671890258789s
|
||||
Loss: 0.09235906448364258 Accuracy: 97.19%
|
||||
Inference elapsed_time (whole dataset): 0.03632998466491699s
|
||||
|
||||
|
||||
|
||||
|
@ -300,7 +300,7 @@ input tensor: ``torch.randn(128, 3, 32, 32)``
|
|||
|
||||
.. rst-class:: sphx-glr-timing
|
||||
|
||||
**Total running time of the script:** ( 1 minutes 4.509 seconds)
|
||||
**Total running time of the script:** ( 1 minutes 13.658 seconds)
|
||||
|
||||
|
||||
.. _sphx_glr_download_tutorials_quantization_speedup.py:
|
||||
|
|
Двоичный файл не отображается.
|
@ -5,17 +5,19 @@
|
|||
|
||||
Computation times
|
||||
=================
|
||||
**00:20.822** total execution time for **tutorials** files:
|
||||
**01:39.686** total execution time for **tutorials** files:
|
||||
|
||||
+-----------------------------------------------------------------------------------------------------+-----------+--------+
|
||||
| :ref:`sphx_glr_tutorials_pruning_bert_glue.py` (``pruning_bert_glue.py``) | 00:20.822 | 0.0 MB |
|
||||
| :ref:`sphx_glr_tutorials_quantization_quick_start_mnist.py` (``quantization_quick_start_mnist.py``) | 01:39.686 | 0.0 MB |
|
||||
+-----------------------------------------------------------------------------------------------------+-----------+--------+
|
||||
| :ref:`sphx_glr_tutorials_darts.py` (``darts.py``) | 01:51.710 | 0.0 MB |
|
||||
| :ref:`sphx_glr_tutorials_darts.py` (``darts.py``) | 00:00.000 | 0.0 MB |
|
||||
+-----------------------------------------------------------------------------------------------------+-----------+--------+
|
||||
| :ref:`sphx_glr_tutorials_hello_nas.py` (``hello_nas.py``) | 00:00.000 | 0.0 MB |
|
||||
+-----------------------------------------------------------------------------------------------------+-----------+--------+
|
||||
| :ref:`sphx_glr_tutorials_nasbench_as_dataset.py` (``nasbench_as_dataset.py``) | 00:00.000 | 0.0 MB |
|
||||
+-----------------------------------------------------------------------------------------------------+-----------+--------+
|
||||
| :ref:`sphx_glr_tutorials_pruning_bert_glue.py` (``pruning_bert_glue.py``) | 00:00.000 | 0.0 MB |
|
||||
+-----------------------------------------------------------------------------------------------------+-----------+--------+
|
||||
| :ref:`sphx_glr_tutorials_pruning_customize.py` (``pruning_customize.py``) | 00:00.000 | 0.0 MB |
|
||||
+-----------------------------------------------------------------------------------------------------+-----------+--------+
|
||||
| :ref:`sphx_glr_tutorials_pruning_quick_start_mnist.py` (``pruning_quick_start_mnist.py``) | 00:00.000 | 0.0 MB |
|
||||
|
@ -24,7 +26,5 @@ Computation times
|
|||
+-----------------------------------------------------------------------------------------------------+-----------+--------+
|
||||
| :ref:`sphx_glr_tutorials_quantization_customize.py` (``quantization_customize.py``) | 00:00.000 | 0.0 MB |
|
||||
+-----------------------------------------------------------------------------------------------------+-----------+--------+
|
||||
| :ref:`sphx_glr_tutorials_quantization_quick_start_mnist.py` (``quantization_quick_start_mnist.py``) | 00:00.000 | 0.0 MB |
|
||||
+-----------------------------------------------------------------------------------------------------+-----------+--------+
|
||||
| :ref:`sphx_glr_tutorials_quantization_speedup.py` (``quantization_speedup.py``) | 00:00.000 | 0.0 MB |
|
||||
+-----------------------------------------------------------------------------------------------------+-----------+--------+
|
||||
|
|
|
@ -217,7 +217,7 @@ def main(args):
|
|||
}]
|
||||
|
||||
optimizer = torch.optim.SGD(model.parameters(), lr=0.01, momentum=0.5)
|
||||
quantizer = QAT_Quantizer(model, config_list, optimizer)
|
||||
quantizer = QAT_Quantizer(model, config_list, optimizer, dummy_input)
|
||||
quantizer.compress()
|
||||
|
||||
# Step6. Quantization Aware Training
|
||||
|
|
|
@ -134,11 +134,11 @@ def main():
|
|||
'op_names': ['features.6', 'features.9', 'features.13', 'features.16', 'features.20', 'classifier.2', 'classifier.5']
|
||||
}]
|
||||
|
||||
quantizer = BNNQuantizer(model, configure_list)
|
||||
optimizer = torch.optim.Adam(model.parameters(), lr=1e-2)
|
||||
quantizer = BNNQuantizer(model, configure_list, optimizer)
|
||||
model = quantizer.compress()
|
||||
|
||||
print('=' * 10 + 'train' + '=' * 10)
|
||||
optimizer = torch.optim.Adam(model.parameters(), lr=1e-2)
|
||||
best_top1 = 0
|
||||
for epoch in range(400):
|
||||
print('# Epoch {} #'.format(epoch))
|
||||
|
|
|
@ -29,7 +29,7 @@ import torch
|
|||
import torch.nn.functional as F
|
||||
from torch.optim import SGD
|
||||
|
||||
from scripts.compression_mnist_model import TorchModel, trainer, evaluator, device
|
||||
from nni_assets.compression.mnist_model import TorchModel, trainer, evaluator, device
|
||||
|
||||
# define the model
|
||||
model = TorchModel().to(device)
|
||||
|
|
|
@ -43,7 +43,7 @@ Usage
|
|||
# But in fact ``ModelSpeedup`` is a relatively independent tool, so you can use it independently.
|
||||
|
||||
import torch
|
||||
from scripts.compression_mnist_model import TorchModel, device
|
||||
from nni_assets.compression.mnist_model import TorchModel, device
|
||||
|
||||
model = TorchModel().to(device)
|
||||
# masks = {layer_name: {'weight': weight_mask, 'bias': bias_mask}}
|
||||
|
|
|
@ -24,7 +24,7 @@ import torch
|
|||
import torch.nn.functional as F
|
||||
from torch.optim import SGD
|
||||
|
||||
from scripts.compression_mnist_model import TorchModel, trainer, evaluator, device, test_trt
|
||||
from nni_assets.compression.mnist_model import TorchModel, trainer, evaluator, device, test_trt
|
||||
|
||||
# define the model
|
||||
model = TorchModel().to(device)
|
||||
|
|
|
@ -64,7 +64,7 @@ Usage
|
|||
import torch
|
||||
import torch.nn.functional as F
|
||||
from torch.optim import SGD
|
||||
from scripts.compression_mnist_model import TorchModel, device, trainer, evaluator, test_trt
|
||||
from nni_assets.compression.mnist_model import TorchModel, device, trainer, evaluator, test_trt
|
||||
|
||||
config_list = [{
|
||||
'quant_types': ['input', 'weight'],
|
||||
|
|
|
@ -152,8 +152,12 @@ class ChannelDependency(Dependency):
|
|||
parent_layers = []
|
||||
queue = []
|
||||
queue.append(node)
|
||||
visited_set = set()
|
||||
while queue:
|
||||
curnode = queue.pop(0)
|
||||
if curnode in visited_set:
|
||||
continue
|
||||
visited_set.add(curnode)
|
||||
if curnode.op_type in self.target_types:
|
||||
# find the first met conv
|
||||
parent_layers.append(curnode.name)
|
||||
|
@ -164,6 +168,8 @@ class ChannelDependency(Dependency):
|
|||
parents = self.graph.find_predecessors(curnode.unique_name)
|
||||
parents = [self.graph.name_to_node[name] for name in parents]
|
||||
for parent in parents:
|
||||
if parent in visited_set:
|
||||
continue
|
||||
queue.append(parent)
|
||||
|
||||
return parent_layers
|
||||
|
|
|
@ -56,7 +56,7 @@ def rand_like_with_shape(shape, ori_t):
|
|||
higher_bound = torch.max(ori_t)
|
||||
|
||||
if dtype in [torch.uint8, torch.int16, torch.short, torch.int16, torch.long, torch.bool]:
|
||||
return torch.randint(lower_bound, higher_bound+1, shape, dtype=dtype, device=device)
|
||||
return torch.randint(lower_bound.long(), higher_bound.long() + 1, shape, dtype=dtype, device=device)
|
||||
else:
|
||||
return torch.rand(shape, dtype=dtype, device=device, requires_grad=require_grad)
|
||||
|
||||
|
|
|
@ -1,7 +1,5 @@
|
|||
from pathlib import Path
|
||||
|
||||
root_path = Path(__file__).parent.parent
|
||||
|
||||
# define the model
|
||||
import torch
|
||||
from torch import nn
|
||||
|
@ -38,13 +36,13 @@ device = torch.device("cuda" if use_cuda else "cpu")
|
|||
from torchvision import datasets, transforms
|
||||
|
||||
train_loader = torch.utils.data.DataLoader(
|
||||
datasets.MNIST(root_path / 'data', train=True, download=True, transform=transforms.Compose([
|
||||
datasets.MNIST('data', train=True, download=True, transform=transforms.Compose([
|
||||
transforms.ToTensor(),
|
||||
transforms.Normalize((0.1307,), (0.3081,))
|
||||
])), batch_size=128, shuffle=True)
|
||||
|
||||
test_loader = torch.utils.data.DataLoader(
|
||||
datasets.MNIST(root_path / 'data', train=False, transform=transforms.Compose([
|
||||
datasets.MNIST('data', train=False, transform=transforms.Compose([
|
||||
transforms.ToTensor(),
|
||||
transforms.Normalize((0.1307,), (0.3081,))
|
||||
])), batch_size=1000, shuffle=True)
|
Загрузка…
Ссылка в новой задаче