[Compression] Add support for torch 2.0 (#5492)

Bonytu 2023-04-07 13:04:08 +08:00 committed by GitHub
Parent 3f67d92b67
Commit b2e2a4d840
No key matching this signature was found
GPG key ID: 4AEE18F83AFDEB23
27 changed files with 145 additions and 125 deletions

27
docs/source/tutorials/index.rst (generated)

@ -17,7 +17,7 @@ Tutorials
.. only:: html
.. image:: /tutorials/images/thumb/sphx_glr_pruning_speedup_thumb.png
:alt: Speedup Model with Mask
:alt:
:ref:`sphx_glr_tutorials_pruning_speedup.py`
@ -34,7 +34,7 @@ Tutorials
.. only:: html
.. image:: /tutorials/images/thumb/sphx_glr_quantization_quick_start_mnist_thumb.png
:alt: Quantization Quickstart
:alt:
:ref:`sphx_glr_tutorials_quantization_quick_start_mnist.py`
@ -51,7 +51,7 @@ Tutorials
.. only:: html
.. image:: /tutorials/images/thumb/sphx_glr_pruning_quick_start_mnist_thumb.png
:alt: Pruning Quickstart
:alt:
:ref:`sphx_glr_tutorials_pruning_quick_start_mnist.py`
@ -68,7 +68,7 @@ Tutorials
.. only:: html
.. image:: /tutorials/images/thumb/sphx_glr_quantization_customize_thumb.png
:alt: Customize a new quantization algorithm
:alt:
:ref:`sphx_glr_tutorials_quantization_customize.py`
@ -85,7 +85,7 @@ Tutorials
.. only:: html
.. image:: /tutorials/images/thumb/sphx_glr_nasbench_as_dataset_thumb.png
:alt: Use NAS Benchmarks as Datasets
:alt:
:ref:`sphx_glr_tutorials_nasbench_as_dataset.py`
@ -102,7 +102,7 @@ Tutorials
.. only:: html
.. image:: /tutorials/images/thumb/sphx_glr_quantization_speedup_thumb.png
:alt: Speed Up Quantized Model with TensorRT
:alt:
:ref:`sphx_glr_tutorials_quantization_speedup.py`
@ -119,7 +119,7 @@ Tutorials
.. only:: html
.. image:: /tutorials/images/thumb/sphx_glr_hello_nas_thumb.png
:alt: Hello, NAS!
:alt:
:ref:`sphx_glr_tutorials_hello_nas.py`
@ -136,7 +136,7 @@ Tutorials
.. only:: html
.. image:: /tutorials/images/thumb/sphx_glr_darts_thumb.png
:alt: Searching in DARTS search space
:alt:
:ref:`sphx_glr_tutorials_darts.py`
@ -153,7 +153,7 @@ Tutorials
.. only:: html
.. image:: /tutorials/images/thumb/sphx_glr_pruning_bert_glue_thumb.png
:alt: Pruning Bert on Task MNLI
:alt:
:ref:`sphx_glr_tutorials_pruning_bert_glue.py`
@ -196,7 +196,7 @@ Tutorials
.. only:: html
.. image:: /tutorials/hpo_quickstart_pytorch/images/thumb/sphx_glr_main_thumb.png
:alt: HPO Quickstart with PyTorch
:alt:
:ref:`sphx_glr_tutorials_hpo_quickstart_pytorch_main.py`
@ -213,7 +213,7 @@ Tutorials
.. only:: html
.. image:: /tutorials/hpo_quickstart_pytorch/images/thumb/sphx_glr_model_thumb.png
:alt: Port PyTorch Quickstart to NNI
:alt:
:ref:`sphx_glr_tutorials_hpo_quickstart_pytorch_model.py`
@ -242,7 +242,7 @@ Tutorials
.. only:: html
.. image:: /tutorials/hpo_quickstart_tensorflow/images/thumb/sphx_glr_main_thumb.png
:alt: HPO Quickstart with TensorFlow
:alt:
:ref:`sphx_glr_tutorials_hpo_quickstart_tensorflow_main.py`
@ -259,7 +259,7 @@ Tutorials
.. only:: html
.. image:: /tutorials/hpo_quickstart_tensorflow/images/thumb/sphx_glr_model_thumb.png
:alt: Port TensorFlow Quickstart to NNI
:alt:
:ref:`sphx_glr_tutorials_hpo_quickstart_tensorflow_model.py`
@ -278,6 +278,7 @@ Tutorials
:hidden:
:includehidden:
/tutorials/hpo_quickstart_pytorch/index.rst
/tutorials/hpo_quickstart_tensorflow/index.rst

6
docs/source/tutorials/pruning_bert_glue.ipynb (generated)

@ -80,7 +80,7 @@
},
"outputs": [],
"source": [
"import functools\nimport time\n\nimport torch.nn.functional as F\nfrom datasets import load_metric\nfrom transformers.modeling_outputs import SequenceClassifierOutput\n\n\ndef training(model: torch.nn.Module,\n optimizer: torch.optim.Optimizer,\n criterion: Callable[[torch.Tensor, torch.Tensor], torch.Tensor],\n lr_scheduler: torch.optim.lr_scheduler._LRScheduler = None,\n max_steps: int = None,\n max_epochs: int = None,\n train_dataloader: DataLoader = None,\n distillation: bool = False,\n teacher_model: torch.nn.Module = None,\n distil_func: Callable = None,\n log_path: str = Path(log_dir) / 'training.log',\n save_best_model: bool = False,\n save_path: str = None,\n evaluation_func: Callable = None,\n eval_per_steps: int = 1000,\n device=None):\n\n assert train_dataloader is not None\n\n model.train()\n if teacher_model is not None:\n teacher_model.eval()\n current_step = 0\n best_result = 0\n\n total_epochs = max_steps // len(train_dataloader) + 1 if max_steps else max_epochs if max_epochs else 3\n total_steps = max_steps if max_steps else total_epochs * len(train_dataloader)\n\n print(f'Training {total_epochs} epochs, {total_steps} steps...')\n\n for current_epoch in range(total_epochs):\n for batch in train_dataloader:\n if current_step >= total_steps:\n return\n batch.to(device)\n outputs = model(**batch)\n loss = outputs.loss\n\n if distillation:\n assert teacher_model is not None\n with torch.no_grad():\n teacher_outputs = teacher_model(**batch)\n distil_loss = distil_func(outputs, teacher_outputs)\n loss = 0.1 * loss + 0.9 * distil_loss\n\n loss = criterion(loss, None)\n optimizer.zero_grad()\n loss.backward()\n optimizer.step()\n\n # per step schedule\n if lr_scheduler:\n lr_scheduler.step()\n\n current_step += 1\n\n if current_step % eval_per_steps == 0 or current_step % len(train_dataloader) == 0:\n result = evaluation_func(model) if evaluation_func else None\n with (log_path).open('a+') as f:\n msg = '[{}] Epoch {}, Step {}: {}\\n'.format(time.asctime(time.localtime(time.time())), current_epoch, current_step, result)\n f.write(msg)\n # if it's the best model, save it.\n if save_best_model and (result is None or best_result < result['default']):\n assert save_path is not None\n torch.save(model.state_dict(), save_path)\n best_result = None if result is None else result['default']\n\n\ndef distil_loss_func(stu_outputs: SequenceClassifierOutput, tea_outputs: SequenceClassifierOutput, encoder_layer_idxs=[]):\n encoder_hidden_state_loss = []\n for i, idx in enumerate(encoder_layer_idxs[:-1]):\n encoder_hidden_state_loss.append(F.mse_loss(stu_outputs.hidden_states[i], tea_outputs.hidden_states[idx]))\n logits_loss = F.kl_div(F.log_softmax(stu_outputs.logits / 2, dim=-1), F.softmax(tea_outputs.logits / 2, dim=-1), reduction='batchmean') * (2 ** 2)\n\n distil_loss = 0\n for loss in encoder_hidden_state_loss:\n distil_loss += loss\n distil_loss += logits_loss\n return distil_loss\n\n\ndef evaluation(model: torch.nn.Module, validation_dataloaders: Dict[str, DataLoader] = None, device=None):\n assert validation_dataloaders is not None\n training = model.training\n model.eval()\n\n is_regression = task_name == 'stsb'\n metric = load_metric('glue', task_name)\n\n result = {}\n default_result = 0\n for val_name, validation_dataloader in validation_dataloaders.items():\n for batch in validation_dataloader:\n batch.to(device)\n outputs = model(**batch)\n predictions = outputs.logits.argmax(dim=-1) if not is_regression else outputs.logits.squeeze()\n metric.add_batch(\n 
predictions=predictions,\n references=batch['labels'],\n )\n result[val_name] = metric.compute()\n default_result += result[val_name].get('f1', result[val_name].get('accuracy', 0))\n result['default'] = default_result / len(result)\n\n model.train(training)\n return result\n\n\nevaluation_func = functools.partial(evaluation, validation_dataloaders=validation_dataloaders, device=device)\n\n\ndef fake_criterion(loss, _):\n return loss"
"import functools\nimport time\n\nimport torch.nn.functional as F\nfrom datasets import load_metric\nfrom transformers.modeling_outputs import SequenceClassifierOutput\n\nfrom nni.common.types import SCHEDULER\n\n\ndef training(model: torch.nn.Module,\n optimizer: torch.optim.Optimizer,\n criterion: Callable[[torch.Tensor, torch.Tensor], torch.Tensor],\n lr_scheduler: SCHEDULER = None,\n max_steps: int = None,\n max_epochs: int = None,\n train_dataloader: DataLoader = None,\n distillation: bool = False,\n teacher_model: torch.nn.Module = None,\n distil_func: Callable = None,\n log_path: str = Path(log_dir) / 'training.log',\n save_best_model: bool = False,\n save_path: str = None,\n evaluation_func: Callable = None,\n eval_per_steps: int = 1000,\n device=None):\n\n assert train_dataloader is not None\n\n model.train()\n if teacher_model is not None:\n teacher_model.eval()\n current_step = 0\n best_result = 0\n\n total_epochs = max_steps // len(train_dataloader) + 1 if max_steps else max_epochs if max_epochs else 3\n total_steps = max_steps if max_steps else total_epochs * len(train_dataloader)\n\n print(f'Training {total_epochs} epochs, {total_steps} steps...')\n\n for current_epoch in range(total_epochs):\n for batch in train_dataloader:\n if current_step >= total_steps:\n return\n batch.to(device)\n outputs = model(**batch)\n loss = outputs.loss\n\n if distillation:\n assert teacher_model is not None\n with torch.no_grad():\n teacher_outputs = teacher_model(**batch)\n distil_loss = distil_func(outputs, teacher_outputs)\n loss = 0.1 * loss + 0.9 * distil_loss\n\n loss = criterion(loss, None)\n optimizer.zero_grad()\n loss.backward()\n optimizer.step()\n\n # per step schedule\n if lr_scheduler:\n lr_scheduler.step()\n\n current_step += 1\n\n if current_step % eval_per_steps == 0 or current_step % len(train_dataloader) == 0:\n result = evaluation_func(model) if evaluation_func else None\n with (log_path).open('a+') as f:\n msg = '[{}] Epoch {}, Step {}: {}\\n'.format(time.asctime(time.localtime(time.time())), current_epoch, current_step, result)\n f.write(msg)\n # if it's the best model, save it.\n if save_best_model and (result is None or best_result < result['default']):\n assert save_path is not None\n torch.save(model.state_dict(), save_path)\n best_result = None if result is None else result['default']\n\n\ndef distil_loss_func(stu_outputs: SequenceClassifierOutput, tea_outputs: SequenceClassifierOutput, encoder_layer_idxs=[]):\n encoder_hidden_state_loss = []\n for i, idx in enumerate(encoder_layer_idxs[:-1]):\n encoder_hidden_state_loss.append(F.mse_loss(stu_outputs.hidden_states[i], tea_outputs.hidden_states[idx]))\n logits_loss = F.kl_div(F.log_softmax(stu_outputs.logits / 2, dim=-1), F.softmax(tea_outputs.logits / 2, dim=-1), reduction='batchmean') * (2 ** 2)\n\n distil_loss = 0\n for loss in encoder_hidden_state_loss:\n distil_loss += loss\n distil_loss += logits_loss\n return distil_loss\n\n\ndef evaluation(model: torch.nn.Module, validation_dataloaders: Dict[str, DataLoader] = None, device=None):\n assert validation_dataloaders is not None\n training = model.training\n model.eval()\n\n is_regression = task_name == 'stsb'\n metric = load_metric('glue', task_name)\n\n result = {}\n default_result = 0\n for val_name, validation_dataloader in validation_dataloaders.items():\n for batch in validation_dataloader:\n batch.to(device)\n outputs = model(**batch)\n predictions = outputs.logits.argmax(dim=-1) if not is_regression else outputs.logits.squeeze()\n metric.add_batch(\n 
predictions=predictions,\n references=batch['labels'],\n )\n result[val_name] = metric.compute()\n default_result += result[val_name].get('f1', result[val_name].get('accuracy', 0))\n result['default'] = default_result / len(result)\n\n model.train(training)\n return result\n\n\nevaluation_func = functools.partial(evaluation, validation_dataloaders=validation_dataloaders, device=device)\n\n\ndef fake_criterion(loss, _):\n return loss"
]
},
{
@ -134,7 +134,7 @@
},
"outputs": [],
"source": [
"attention_pruned_model = create_finetuned_model().to(device)\nattention_masks = torch.load(Path(log_dir) / 'attention_masks.pth')\n\nffn_config_list = []\nlayer_remained_idxs = []\nmodule_list = []\nfor i in range(0, layers_num):\n prefix = f'bert.encoder.layer.{i}.'\n value_mask: torch.Tensor = attention_masks[prefix + 'attention.self.value']['weight']\n head_mask = (value_mask.reshape(heads_num, -1).sum(-1) == 0.)\n head_idxs = torch.arange(len(head_mask))[head_mask].long().tolist()\n print(f'layer {i} prune {len(head_idxs)} head: {head_idxs}')\n if len(head_idxs) != heads_num:\n attention_pruned_model.bert.encoder.layer[i].attention.prune_heads(head_idxs)\n module_list.append(attention_pruned_model.bert.encoder.layer[i])\n # The final ffn weight remaining ratio is the half of the attention weight remaining ratio.\n # This is just an empirical configuration, you can use any other method to determine this sparsity.\n sparsity = 1 - (1 - len(head_idxs) / heads_num) * 0.5\n # here we use a simple sparsity schedule, we will prune ffn in 12 iterations, each iteration prune `sparsity_per_iter`.\n sparsity_per_iter = 1 - (1 - sparsity) ** (1 / 12)\n ffn_config_list.append({\n 'op_names': [f'bert.encoder.layer.{len(layer_remained_idxs)}.intermediate.dense'],\n 'sparsity': sparsity_per_iter\n })\n layer_remained_idxs.append(i)\n\nattention_pruned_model.bert.encoder.layer = torch.nn.ModuleList(module_list)\ndistil_func = functools.partial(distil_loss_func, encoder_layer_idxs=layer_remained_idxs)"
"attention_pruned_model = create_finetuned_model().to(device)\nattention_masks = torch.load(Path(log_dir) / 'attention_masks.pth')\n\nffn_config_list = []\nlayer_remained_idxs = []\nmodule_list = []\nfor i in range(0, layers_num):\n prefix = f'bert.encoder.layer.{i}.'\n value_mask: torch.Tensor = attention_masks[prefix + 'attention.self.value']['weight']\n head_mask = (value_mask.reshape(heads_num, -1).sum(-1) == 0.).to(\"cpu\")\n head_idxs = torch.arange(len(head_mask))[head_mask].long().tolist()\n print(f'layer {i} prune {len(head_idxs)} head: {head_idxs}')\n if len(head_idxs) != heads_num:\n attention_pruned_model.bert.encoder.layer[i].attention.prune_heads(head_idxs)\n module_list.append(attention_pruned_model.bert.encoder.layer[i])\n # The final ffn weight remaining ratio is the half of the attention weight remaining ratio.\n # This is just an empirical configuration, you can use any other method to determine this sparsity.\n sparsity = 1 - (1 - len(head_idxs) / heads_num) * 0.5\n # here we use a simple sparsity schedule, we will prune ffn in 12 iterations, each iteration prune `sparsity_per_iter`.\n sparsity_per_iter = 1 - (1 - sparsity) ** (1 / 12)\n ffn_config_list.append({\n 'op_names': [f'bert.encoder.layer.{len(layer_remained_idxs)}.intermediate.dense'],\n 'sparsity': sparsity_per_iter\n })\n layer_remained_idxs.append(i)\n\nattention_pruned_model.bert.encoder.layer = torch.nn.ModuleList(module_list)\ndistil_func = functools.partial(distil_loss_func, encoder_layer_idxs=layer_remained_idxs)"
]
},
{
@ -197,7 +197,7 @@
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.9.12"
"version": "3.9.16"
}
},
"nbformat": 4,

6
docs/source/tutorials/pruning_bert_glue.py (generated)

@ -159,11 +159,13 @@ import torch.nn.functional as F
from datasets import load_metric
from transformers.modeling_outputs import SequenceClassifierOutput
from nni.common.types import SCHEDULER
def training(model: torch.nn.Module,
optimizer: torch.optim.Optimizer,
criterion: Callable[[torch.Tensor, torch.Tensor], torch.Tensor],
lr_scheduler: torch.optim.lr_scheduler._LRScheduler = None,
lr_scheduler: SCHEDULER = None,
max_steps: int = None,
max_epochs: int = None,
train_dataloader: DataLoader = None,
@ -399,7 +401,7 @@ module_list = []
for i in range(0, layers_num):
prefix = f'bert.encoder.layer.{i}.'
value_mask: torch.Tensor = attention_masks[prefix + 'attention.self.value']['weight']
head_mask = (value_mask.reshape(heads_num, -1).sum(-1) == 0.)
head_mask = (value_mask.reshape(heads_num, -1).sum(-1) == 0.).to("cpu")
head_idxs = torch.arange(len(head_mask))[head_mask].long().tolist()
print(f'layer {i} prune {len(head_idxs)} head: {head_idxs}')
if len(head_idxs) != heads_num:
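
The `.to("cpu")` added above presumably avoids a device mismatch under torch 2.0: `torch.arange(len(head_mask))` is created on the CPU, and indexing it with a CUDA boolean mask fails, so the mask is moved to the CPU first. A minimal sketch of the pattern (the shapes and head count below are illustrative, not taken from the tutorial):

    import torch

    heads_num = 12  # illustrative BERT-base head count
    device = "cuda" if torch.cuda.is_available() else "cpu"
    value_mask = torch.ones(heads_num, 64, device=device)

    # Move the boolean mask to CPU before using it to index the CPU arange tensor.
    head_mask = (value_mask.reshape(heads_num, -1).sum(-1) == 0.).to("cpu")
    head_idxs = torch.arange(len(head_mask))[head_mask].long().tolist()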

2
docs/source/tutorials/pruning_bert_glue.py.md5 (generated)

@ -1 +1 @@
099d745e7809d57227bb42086ecd581c
822d1933bb3b99080589c0cdf89cf89e

71
docs/source/tutorials/pruning_bert_glue.rst (generated)

File diff suppressed because one or more lines are too long

Binary data
docs/source/tutorials/pruning_bert_glue_codeobj.pickle (generated)

Binary file not shown

9
docs/source/tutorials/sg_execution_times.rst (generated)

@ -3,12 +3,13 @@
.. _sphx_glr_tutorials_sg_execution_times:
Computation times
=================
**06:12.568** total execution time for **tutorials** files:
**00:29.846** total execution time for **tutorials** files:
+-----------------------------------------------------------------------------------------------------+-----------+--------+
| :ref:`sphx_glr_tutorials_quantization_speedup.py` (``quantization_speedup.py``) | 06:12.568 | 0.0 MB |
| :ref:`sphx_glr_tutorials_pruning_bert_glue.py` (``pruning_bert_glue.py``) | 00:29.846 | 0.0 MB |
+-----------------------------------------------------------------------------------------------------+-----------+--------+
| :ref:`sphx_glr_tutorials_darts.py` (``darts.py``) | 00:00.000 | 0.0 MB |
+-----------------------------------------------------------------------------------------------------+-----------+--------+
@ -16,8 +17,6 @@ Computation times
+-----------------------------------------------------------------------------------------------------+-----------+--------+
| :ref:`sphx_glr_tutorials_nasbench_as_dataset.py` (``nasbench_as_dataset.py``) | 00:00.000 | 0.0 MB |
+-----------------------------------------------------------------------------------------------------+-----------+--------+
| :ref:`sphx_glr_tutorials_pruning_bert_glue.py` (``pruning_bert_glue.py``) | 00:00.000 | 0.0 MB |
+-----------------------------------------------------------------------------------------------------+-----------+--------+
| :ref:`sphx_glr_tutorials_pruning_quick_start_mnist.py` (``pruning_quick_start_mnist.py``) | 00:00.000 | 0.0 MB |
+-----------------------------------------------------------------------------------------------------+-----------+--------+
| :ref:`sphx_glr_tutorials_pruning_speedup.py` (``pruning_speedup.py``) | 00:00.000 | 0.0 MB |
@ -26,3 +25,5 @@ Computation times
+-----------------------------------------------------------------------------------------------------+-----------+--------+
| :ref:`sphx_glr_tutorials_quantization_quick_start_mnist.py` (``quantization_quick_start_mnist.py``) | 00:00.000 | 0.0 MB |
+-----------------------------------------------------------------------------------------------------+-----------+--------+
| :ref:`sphx_glr_tutorials_quantization_speedup.py` (``quantization_speedup.py``) | 00:00.000 | 0.0 MB |
+-----------------------------------------------------------------------------------------------------+-----------+--------+

View File

@ -25,10 +25,11 @@ from nni.compression.pytorch import ModelSpeedup
from nni.compression.pytorch.utils import count_flops_params
from nni.compression.pytorch.pruning import TaylorFOWeightPruner
from nni.compression.pytorch.utils import TorchEvaluator
from nni.common.types import SCHEDULER
############# Create dataloaders, optimizer, training and evaluation function ############
class Mnist(torch.nn.Module):
def __init__(self):
super().__init__()
@ -74,7 +75,7 @@ def training(
model: nn.Module,
optimizer: torch.optim.Optimizer,
criterion: Callable[[torch.Tensor, torch.Tensor], torch.Tensor],
lr_scheduler: torch.optim.lr_scheduler._LRScheduler = None,
lr_scheduler: SCHEDULER = None,
max_steps: int = None, max_epochs: int = None,
local_rank: int = -1,
save_best_model: bool = False, save_path: str = None,

View File

@ -2,12 +2,13 @@ from __future__ import annotations
from typing import Callable, Any
import torch
from torch.optim.lr_scheduler import StepLR, _LRScheduler
from torch.optim.lr_scheduler import StepLR
from torch.utils.data import DataLoader
from torchvision import datasets, transforms
import nni
from nni.compression.pytorch import TorchEvaluator
from nni.common.types import SCHEDULER
import sys
from pathlib import Path
@ -21,7 +22,7 @@ model: torch.nn.Module = VGG().to(device)
def training_func(model: torch.nn.Module, optimizers: torch.optim.Optimizer,
criterion: Callable[[Any, Any], torch.Tensor],
lr_schedulers: _LRScheduler | None = None, max_steps: int | None = None,
lr_schedulers: SCHEDULER | None = None, max_steps: int | None = None,
max_epochs: int | None = None, *args, **kwargs):
model.train()
# prepare data

View File

@ -8,12 +8,12 @@ import torch.nn as nn
import torch.nn.functional as F
from torch.optim import Optimizer
from torch.utils.data import DataLoader
from torch.optim.lr_scheduler import _LRScheduler
from torchvision import datasets, transforms
import nni
from nni.contrib.compression.quantization import BNNQuantizer
from nni.contrib.compression.utils import TorchEvaluator
from nni.common.types import SCHEDULER
torch.manual_seed(0)
@ -108,7 +108,7 @@ def test(model: nn.Module):
return acc
def train(model: torch.nn.Module, optimizer: Optimizer, training_step: Callable, scheduler: Union[_LRScheduler, None] = None,
def train(model: torch.nn.Module, optimizer: Optimizer, training_step: Callable, scheduler: Union[SCHEDULER, None] = None,
max_steps: Union[int, None] = None, max_epochs: Union[int, None] = 400):
best_top1 = 0
max_epochs = max_epochs or (40 if max_steps is None else 400)
@ -120,7 +120,7 @@ def train(model: torch.nn.Module, optimizer: Optimizer, training_step: Callable,
loss = training_step(batch, model)
loss.backward()
optimizer.step()
if isinstance(scheduler, _LRScheduler):
if isinstance(scheduler, SCHEDULER):
scheduler.step()
if batch_idx % 100 == 0:
print('{:2.0f}% Loss {}'.format(100 * batch_idx / len(train_loader), loss.item()))

View File

@ -7,7 +7,6 @@ from typing import Callable, Union, List, Dict, Tuple
import torch
import torch.nn.functional as F
from torch.optim import Optimizer
from torch.optim.lr_scheduler import _LRScheduler
from torch.utils.data import DataLoader
from torch import Tensor
@ -16,6 +15,8 @@ from torchvision import datasets, transforms
import nni
from nni.contrib.compression.quantization import DoReFaQuantizer
from nni.contrib.compression.utils import TorchEvaluator
from nni.common.types import SCHEDULER
torch.manual_seed(0)
device = 'cuda'
@ -61,7 +62,7 @@ def training_step(batch, model) -> Tensor:
def training_model(model: torch.nn.Module, optimizer: Union[Optimizer, List[Optimizer]], \
training_step: _TRAINING_STEP, scheduler: Union[None, _LRScheduler, List[_LRScheduler]] = None,
training_step: _TRAINING_STEP, scheduler: Union[None, SCHEDULER, List[SCHEDULER]] = None,
max_steps: Union[int, None] = None, max_epochs: Union[int, None] = None):
model.train()
max_epochs = max_epochs or (10 if max_steps is None else 100)
@ -85,9 +86,9 @@ def training_model(model: torch.nn.Module, optimizer: Union[Optimizer, List[Opti
elif isinstance(optimizer, List) and all(isinstance(_, Optimizer) for _ in optimizer):
for opt in optimizer:
opt.step()
if isinstance(scheduler, _LRScheduler):
if isinstance(scheduler, SCHEDULER):
scheduler.step()
if isinstance(scheduler, List) and all(isinstance(_, _LRScheduler) for _ in scheduler):
if isinstance(scheduler, List) and all(isinstance(_, SCHEDULER) for _ in scheduler):
for sch in scheduler:
sch.step()
current_steps += 1

View File

@ -7,7 +7,6 @@ from typing import Callable, Union, List, Dict, Tuple
import torch
import torch.nn.functional as F
from torch.optim import Optimizer
from torch.optim.lr_scheduler import _LRScheduler
from torch.utils.data import DataLoader
from torch import Tensor
@ -16,6 +15,7 @@ from torchvision import datasets, transforms
import nni
from nni.contrib.compression.quantization import LsqQuantizer
from nni.contrib.compression.utils import TorchEvaluator
from nni.common.types import SCHEDULER
torch.manual_seed(0)
device = 'cuda'
@ -62,7 +62,7 @@ def training_step(batch, model) -> Tensor:
def training_model(model: torch.nn.Module, optimizer: Union[Optimizer, List[Optimizer]], \
training_step: _TRAINING_STEP, scheduler: Union[None, _LRScheduler, List[_LRScheduler]] = None,
training_step: _TRAINING_STEP, scheduler: Union[None, SCHEDULER, List[SCHEDULER]] = None,
max_steps: Union[int, None] = None, max_epochs: Union[int, None] = None):
model.train()
max_epochs = max_epochs or (40 if max_steps is None else 100)
@ -86,9 +86,9 @@ def training_model(model: torch.nn.Module, optimizer: Union[Optimizer, List[Opti
elif isinstance(optimizer, List) and all(isinstance(_, Optimizer) for _ in optimizer):
for opt in optimizer:
opt.step()
if isinstance(scheduler, _LRScheduler):
if isinstance(scheduler, SCHEDULER):
scheduler.step()
if isinstance(scheduler, List) and all(isinstance(_, _LRScheduler) for _ in scheduler):
if isinstance(scheduler, List) and all(isinstance(_, SCHEDULER) for _ in scheduler):
for sch in scheduler:
sch.step()
current_steps += 1

View File

@ -7,7 +7,6 @@ from typing import Callable, Union, List, Tuple, Any, Dict
import torch
import torch.nn.functional as F
from torch.optim import Optimizer
from torch.optim.lr_scheduler import _LRScheduler
from torch.utils.data import DataLoader
from torch import Tensor
from torchvision import datasets, transforms
@ -15,6 +14,7 @@ from torchvision import datasets, transforms
import nni
from nni.contrib.compression.quantization import PtqQuantizer
from nni.contrib.compression.utils import TorchEvaluator
from nni.common.types import SCHEDULER
torch.manual_seed(0)
@ -63,7 +63,7 @@ def training_step(batch, model) -> Tensor:
def training_model(model: torch.nn.Module, optimizer: Union[Optimizer, List[Optimizer]], \
training_step: _TRAINING_STEP, scheduler: Union[None, _LRScheduler, List[_LRScheduler]] = None,
training_step: _TRAINING_STEP, scheduler: Union[None, SCHEDULER, List[SCHEDULER]] = None,
max_steps: Union[int, None] = None, max_epochs: Union[int, None] = None):
model.train()
max_epochs = max_epochs if max_epochs else 10 if max_steps is None else 100
@ -87,9 +87,9 @@ def training_model(model: torch.nn.Module, optimizer: Union[Optimizer, List[Opti
elif isinstance(optimizer, List) and all(isinstance(_, Optimizer) for _ in optimizer):
for opt in optimizer:
opt.step()
if isinstance(scheduler, _LRScheduler):
if isinstance(scheduler, SCHEDULER):
scheduler.step()
if isinstance(scheduler, List) and all(isinstance(_, _LRScheduler) for _ in scheduler):
if isinstance(scheduler, List) and all(isinstance(_, SCHEDULER) for _ in scheduler):
for sch in scheduler:
sch.step()
current_steps += 1

View File

@ -9,7 +9,6 @@ from typing import Callable
import torch
import torch.nn.functional as F
from torch.optim import Optimizer, SGD
from torch.optim.lr_scheduler import _LRScheduler
from torch.utils.data import DataLoader
from torchvision import transforms
from torchvision.datasets import MNIST
@ -17,6 +16,8 @@ from torchvision.datasets import MNIST
import nni
from nni.contrib.compression.quantization import QATQuantizer
from nni.contrib.compression.utils import TorchEvaluator
from nni.common.types import SCHEDULER
torch.manual_seed(1024)
device = 'cuda'
@ -62,7 +63,7 @@ def training_step(batch, model):
return loss
def training_model(model: torch.nn.Module, optimizer: Optimizer, training_step: Callable, scheduler: _LRScheduler | None = None,
def training_model(model: torch.nn.Module, optimizer: Optimizer, training_step: Callable, scheduler: SCHEDULER | None = None,
max_steps: int | None = None, max_epochs: int | None = None):
model.train()
max_epochs = max_epochs if max_epochs else 1 if max_steps is None else 100

View File

@ -159,11 +159,13 @@ import torch.nn.functional as F
from datasets import load_metric
from transformers.modeling_outputs import SequenceClassifierOutput
from nni.common.types import SCHEDULER
def training(model: torch.nn.Module,
optimizer: torch.optim.Optimizer,
criterion: Callable[[torch.Tensor, torch.Tensor], torch.Tensor],
lr_scheduler: torch.optim.lr_scheduler._LRScheduler = None,
lr_scheduler: SCHEDULER = None,
max_steps: int = None,
max_epochs: int = None,
train_dataloader: DataLoader = None,
@ -399,7 +401,7 @@ module_list = []
for i in range(0, layers_num):
prefix = f'bert.encoder.layer.{i}.'
value_mask: torch.Tensor = attention_masks[prefix + 'attention.self.value']['weight']
head_mask = (value_mask.reshape(heads_num, -1).sum(-1) == 0.)
head_mask = (value_mask.reshape(heads_num, -1).sum(-1) == 0.).to("cpu")
head_idxs = torch.arange(len(head_mask))[head_mask].long().tolist()
print(f'layer {i} prune {len(head_idxs)} head: {head_idxs}')
if len(head_idxs) != heads_num:

11
nni/common/types.py (new file)

@ -0,0 +1,11 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.
from nni.common.version import torch_version_is_2
if torch_version_is_2():
from torch.optim.lr_scheduler import LRScheduler # type: ignore
SCHEDULER = LRScheduler
else:
from torch.optim.lr_scheduler import _LRScheduler
SCHEDULER = _LRScheduler
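
A minimal usage sketch (not part of this commit) showing how `SCHEDULER` keeps type annotations and runtime checks working on both torch 1.x (`_LRScheduler`) and torch 2.x (`LRScheduler`); the model and scheduler below are illustrative:

    from typing import Optional

    import torch
    from nni.common.types import SCHEDULER

    def maybe_step(scheduler: Optional[SCHEDULER]) -> None:
        # SCHEDULER resolves to the correct base class for the installed torch version.
        if isinstance(scheduler, SCHEDULER):
            scheduler.step()

    model = torch.nn.Linear(4, 2)
    optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
    scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=1)
    maybe_step(scheduler)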

View File

@ -9,9 +9,9 @@ from typing import Callable, Dict, List, Type
from torch import Tensor
from torch.nn import Module
from torch.optim import Optimizer
from torch.optim.lr_scheduler import _LRScheduler
from nni.common.serializer import is_traceable
from nni.common.types import SCHEDULER
__all__ = ['OptimizerConstructHelper', 'LRSchedulerConstructHelper']
@ -91,7 +91,7 @@ class OptimizerConstructHelper(ConstructHelper):
class LRSchedulerConstructHelper(ConstructHelper):
def __init__(self, lr_scheduler_class: Type[_LRScheduler], *args, **kwargs):
def __init__(self, lr_scheduler_class: Type[SCHEDULER], *args, **kwargs):
args = list(args)
if 'optimizer' in kwargs:
kwargs['optimizer'] = None
@ -99,7 +99,7 @@ class LRSchedulerConstructHelper(ConstructHelper):
args[0] = None
super().__init__(lr_scheduler_class, *args, **kwargs)
def call(self, optimizer: Optimizer) -> _LRScheduler:
def call(self, optimizer: Optimizer) -> SCHEDULER:
args = deepcopy(self.args)
kwargs = deepcopy(self.kwargs)
@ -111,10 +111,10 @@ class LRSchedulerConstructHelper(ConstructHelper):
return self.callable_obj(*args, **kwargs)
@staticmethod
def from_trace(lr_scheduler_trace: _LRScheduler):
def from_trace(lr_scheduler_trace: SCHEDULER):
assert is_traceable(lr_scheduler_trace), \
'Please use nni.trace to wrap the lr scheduler class before initialize the scheduler.'
assert isinstance(lr_scheduler_trace, _LRScheduler), \
'It is not an instance of torch.nn.lr_scheduler._LRScheduler.'
assert isinstance(lr_scheduler_trace, SCHEDULER), \
f'It is not an instance of torch.nn.lr_scheduler.{SCHEDULER}.'
return LRSchedulerConstructHelper(lr_scheduler_trace.trace_symbol, *lr_scheduler_trace.trace_args, # type: ignore
**lr_scheduler_trace.trace_kwargs) # type: ignore

View File

@ -11,7 +11,6 @@ from typing import Dict, List, Tuple, Union, Any, Callable, Optional
from torch import Tensor
from torch.nn import Module
from torch.optim import Optimizer
from torch.optim.lr_scheduler import _LRScheduler
from torch.utils.hooks import RemovableHandle
try:
@ -31,9 +30,11 @@ else:
import nni
from nni.common import is_traceable
from nni.common.types import SCHEDULER
from .constructor_helper import OptimizerConstructHelper, LRSchedulerConstructHelper
from .check_ddp import check_ddp_model, reset_ddp_model
_logger = logging.getLogger(__name__)
@ -553,7 +554,7 @@ class LightningEvaluator(Evaluator):
_OPTIMIZERS = Union[Optimizer, List[Optimizer]]
_CRITERION = Callable[[Any, Any], Any]
_SCHEDULERS = Union[None, _LRScheduler, List[_LRScheduler]]
_SCHEDULERS = Union[None, SCHEDULER, List[SCHEDULER]]
_EVALUATING_FUNC = Callable[[Module], Union[float, Dict]]
_TRAINING_FUNC = Callable[[Module, _OPTIMIZERS, _CRITERION, _SCHEDULERS, Optional[int], Optional[int]], None]
@ -656,7 +657,7 @@ class TorchEvaluator(Evaluator):
"""
def __init__(self, training_func: _TRAINING_FUNC, optimizers: Optimizer | List[Optimizer], criterion: _CRITERION,
lr_schedulers: _LRScheduler | List[_LRScheduler] | None = None, dummy_input: Any | None = None,
lr_schedulers: SCHEDULER | List[SCHEDULER] | None = None, dummy_input: Any | None = None,
evaluating_func: _EVALUATING_FUNC | None = None):
self.training_func = training_func
self._ori_criterion = criterion
@ -665,11 +666,11 @@ class TorchEvaluator(Evaluator):
self.evaluating_func = evaluating_func
self._train_with_single_optimizer = isinstance(optimizers, Optimizer)
self._train_with_single_scheduler = isinstance(lr_schedulers, _LRScheduler)
self._train_with_single_scheduler = isinstance(lr_schedulers, SCHEDULER)
self.model: Module | None = None
self._optimizers: List[Optimizer] | None = None
self._lr_schedulers: List[_LRScheduler] | None = None
self._lr_schedulers: List[SCHEDULER] | None = None
self._first_optimizer_step: Callable | None = None
self._param_names_map: Dict[str, str] | None = None
@ -677,7 +678,7 @@ class TorchEvaluator(Evaluator):
self._tmp_optimizers = optimizers if isinstance(optimizers, (list, tuple)) else [optimizers]
assert all(isinstance(optimizer, Optimizer) and is_traceable(optimizer) for optimizer in self._tmp_optimizers)
self._tmp_lr_schedulers = lr_schedulers if isinstance(lr_schedulers, (list, tuple)) else [lr_schedulers] if lr_schedulers else []
assert all(isinstance(lr_scheduler, _LRScheduler) and is_traceable(lr_scheduler) for lr_scheduler in self._tmp_lr_schedulers)
assert all(isinstance(lr_scheduler, SCHEDULER) and is_traceable(lr_scheduler) for lr_scheduler in self._tmp_lr_schedulers)
self._initialization_complete = False
def _init_optimizer_helpers(self, pure_model: Module):
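
The assertions above only pass when the optimizer and lr scheduler are created through `nni.trace`. A hedged sketch of wiring up this `TorchEvaluator` (import path taken from the tutorial diff above; the training function body is a placeholder):

    import nni
    import torch
    import torch.nn.functional as F
    from nni.compression.pytorch.utils import TorchEvaluator

    model = torch.nn.Linear(4, 2)
    # Wrap the classes with nni.trace so the evaluator can rebuild them later.
    optimizer = nni.trace(torch.optim.SGD)(model.parameters(), lr=0.1)
    scheduler = nni.trace(torch.optim.lr_scheduler.StepLR)(optimizer, step_size=1)

    def training_func(model, optimizer, criterion, lr_scheduler=None, max_steps=None, max_epochs=None):
        pass  # placeholder: a real function would loop over a DataLoader here

    evaluator = TorchEvaluator(training_func, optimizer, F.nll_loss, lr_schedulers=scheduler)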

View File

@ -9,9 +9,10 @@ from typing import Callable, Dict, List, Type
from torch import Tensor
from torch.nn import Module
from torch.optim import Optimizer
from torch.optim.lr_scheduler import _LRScheduler
from nni.common.serializer import is_traceable
from nni.common.types import SCHEDULER
__all__ = ['OptimizerConstructHelper', 'LRSchedulerConstructHelper']
@ -91,7 +92,7 @@ class OptimizerConstructHelper(ConstructHelper):
class LRSchedulerConstructHelper(ConstructHelper):
def __init__(self, lr_scheduler_class: Type[_LRScheduler], *args, **kwargs):
def __init__(self, lr_scheduler_class: Type[SCHEDULER], *args, **kwargs): # type: ignore
args = list(args)
if 'optimizer' in kwargs:
kwargs['optimizer'] = None
@ -99,7 +100,7 @@ class LRSchedulerConstructHelper(ConstructHelper):
args[0] = None
super().__init__(lr_scheduler_class, *args, **kwargs)
def call(self, optimizer: Optimizer) -> _LRScheduler:
def call(self, optimizer: Optimizer) -> SCHEDULER: # type: ignore
args = deepcopy(self.args)
kwargs = deepcopy(self.kwargs)
@ -111,10 +112,10 @@ class LRSchedulerConstructHelper(ConstructHelper):
return self.callable_obj(*args, **kwargs)
@staticmethod
def from_trace(lr_scheduler_trace: _LRScheduler):
def from_trace(lr_scheduler_trace: SCHEDULER): # type: ignore
assert is_traceable(lr_scheduler_trace), \
'Please use nni.trace to wrap the lr scheduler class before initialize the scheduler.'
assert isinstance(lr_scheduler_trace, _LRScheduler), \
'It is not an instance of torch.nn.lr_scheduler._LRScheduler.'
return LRSchedulerConstructHelper(lr_scheduler_trace.trace_symbol, *lr_scheduler_trace.trace_args, # type: ignore
**lr_scheduler_trace.trace_kwargs) # type: ignore
assert isinstance(lr_scheduler_trace, SCHEDULER), \
f'It is not an instance of torch.nn.lr_scheduler.{SCHEDULER}.'
return LRSchedulerConstructHelper(lr_scheduler_trace.trace_symbol, *lr_scheduler_trace.trace_args, # type: ignore
**lr_scheduler_trace.trace_kwargs) # type: ignore

View File

@ -13,7 +13,6 @@ import torch
from torch import Tensor
from torch.nn import Module
from torch.optim import Optimizer
from torch.optim.lr_scheduler import _LRScheduler
from torch.utils.hooks import RemovableHandle
try:
@ -32,9 +31,11 @@ else:
import nni
from nni.common import is_traceable
from nni.common.types import SCHEDULER
from .constructor_helper import OptimizerConstructHelper, LRSchedulerConstructHelper
from .check_ddp import check_ddp_model, reset_ddp_model
_logger = logging.getLogger(__name__)
@ -218,7 +219,7 @@ class Evaluator:
raise TypeError("optimizer can only optimize Tensors, "
"but one of the params is " + torch.typename(param))
if not optimizer.defaults.get('differentiable', None) \
and not (param.is_leaf or param.retains_grad): # type: ignore
and not (param.is_leaf or param.retains_grad): # type: ignore
raise ValueError("can't optimize a non-leaf Tensor")
target_param_group['params'].append(param)
@ -486,27 +487,27 @@ class LightningEvaluator(Evaluator):
if self._opt_returned_dicts:
def new_configure_optimizers(_): # type: ignore
optimizers_lr_schedulers: Any = old_configure_optimizers() # type: ignore
optimizers_lr_schedulers: Any = old_configure_optimizers() # type: ignore
optimizers = [opt_lrs_dict['optimizer'] for opt_lrs_dict in optimizers_lr_schedulers]
# add param group
self._optimizer_add_param_group(self.model, module_name_param_dict, optimizers) # type: ignore
self._optimizer_add_param_group(self.model, module_name_param_dict, optimizers) # type: ignore
return optimizers_lr_schedulers
elif self._lr_scheduler_helpers:
def new_configure_optimizers(_): # type: ignore
optimizers_lr_schedulers: Any = old_configure_optimizers() # type: ignore
optimizers_lr_schedulers: Any = old_configure_optimizers() # type: ignore
optimizers, lr_schedulers = optimizers_lr_schedulers
# add param_group
self._optimizer_add_param_group(self.model, module_name_param_dict, optimizers) # type: ignore
self._optimizer_add_param_group(self.model, module_name_param_dict, optimizers) # type: ignore
return optimizers, lr_schedulers
else:
def new_configure_optimizers(_):
optimizers_lr_schedulers: Any = old_configure_optimizers() # type: ignore
optimizers_lr_schedulers: Any = old_configure_optimizers() # type: ignore
# add param_group
self._optimizer_add_param_group(self.model, module_name_param_dict, optimizers_lr_schedulers) # type: ignore
self._optimizer_add_param_group(self.model, module_name_param_dict, optimizers_lr_schedulers) # type: ignore
return optimizers_lr_schedulers
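
For reference, the dict-returning `configure_optimizers` form that the `_opt_returned_dicts` branch above unpacks looks roughly like this (standard PyTorch Lightning convention, not code from this commit; the module is illustrative):

    import torch
    import pytorch_lightning as pl

    class MyModule(pl.LightningModule):
        def __init__(self):
            super().__init__()
            self.layer = torch.nn.Linear(4, 2)

        def configure_optimizers(self):
            optimizer = torch.optim.SGD(self.parameters(), lr=0.01)
            scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=1)
            # One dict per optimizer; the patched hook reads the 'optimizer' entry of each dict.
            return [{"optimizer": optimizer, "lr_scheduler": scheduler}]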
@ -656,7 +657,7 @@ class LightningEvaluator(Evaluator):
_OPTIMIZERS = Union[Optimizer, List[Optimizer]]
_TRAINING_STEP = Callable[..., Union[Tensor, Tuple[Tensor], Dict[str, Tensor]]]
_SCHEDULERS = Union[None, _LRScheduler, List[_LRScheduler]]
_SCHEDULERS = Union[None, SCHEDULER, List[SCHEDULER]]
_EVALUATING_FUNC = Callable[[Module], Union[float, Dict]]
_TRAINING_FUNC = Callable[[Module, _OPTIMIZERS, _TRAINING_STEP, Optional[_SCHEDULERS], Optional[int], Optional[int]], None]
@ -765,7 +766,7 @@ class TorchEvaluator(Evaluator):
"""
def __init__(self, training_func: _TRAINING_FUNC, optimizers: Optimizer | List[Optimizer], training_step: _TRAINING_STEP,
lr_schedulers: _LRScheduler | List[_LRScheduler] | None = None, dummy_input: Any | None = None,
lr_schedulers: SCHEDULER | List[SCHEDULER] | None = None, dummy_input: Any | None = None, # type: ignore
evaluating_func: _EVALUATING_FUNC | None = None):
self.training_func = training_func
self._ori_training_step = training_step
@ -774,11 +775,11 @@ class TorchEvaluator(Evaluator):
self.evaluating_func = evaluating_func
self._train_with_single_optimizer = isinstance(optimizers, Optimizer)
self._train_with_single_scheduler = isinstance(lr_schedulers, _LRScheduler)
self._train_with_single_scheduler = isinstance(lr_schedulers, SCHEDULER)
self.model: Module | None = None
self._optimizers: List[Optimizer] | None = None
self._lr_schedulers: List[_LRScheduler] | None = None
self._lr_schedulers: List[SCHEDULER] | None = None # type: ignore
self._first_optimizer_step: Callable | None = None
self._param_names_map: Dict[str, str] | None = None
@ -786,7 +787,7 @@ class TorchEvaluator(Evaluator):
self._tmp_optimizers = optimizers if isinstance(optimizers, (list, tuple)) else [optimizers]
assert all(isinstance(optimizer, Optimizer) and is_traceable(optimizer) for optimizer in self._tmp_optimizers)
self._tmp_lr_schedulers = lr_schedulers if isinstance(lr_schedulers, (list, tuple)) else [lr_schedulers] if lr_schedulers else []
assert all(isinstance(lr_scheduler, _LRScheduler) and is_traceable(lr_scheduler) for lr_scheduler in self._tmp_lr_schedulers)
assert all(isinstance(lr_scheduler, SCHEDULER) and is_traceable(lr_scheduler) for lr_scheduler in self._tmp_lr_schedulers)
self._initialization_complete = False
def _init_optimizer_helpers(self, pure_model: Module):
@ -829,7 +830,7 @@ class TorchEvaluator(Evaluator):
def patch_optim_param_group(self, module_name_param_dict: Dict[str, List[Tensor]]):
assert isinstance(self.model, Module)
assert module_name_param_dict is not None
self._optimizer_add_param_group(self.model, module_name_param_dict, self._optimizers) # type: ignore
self._optimizer_add_param_group(self.model, module_name_param_dict, self._optimizers) # type: ignore
def unbind_model(self):
if self.model:

View File

@ -10,13 +10,14 @@ import torch
from torch.nn import Module
import torch.nn.functional as F
from torch.optim import Optimizer
from torch.optim.lr_scheduler import _LRScheduler
from torch.utils.data import DataLoader
from torchvision.datasets import MNIST
from torchvision import transforms
from ..device import device
from nni.common.types import SCHEDULER
class SimpleTorchModel(torch.nn.Module):
def __init__(self):
@ -37,7 +38,7 @@ class SimpleTorchModel(torch.nn.Module):
return F.log_softmax(x, -1)
def training_model(model: Module, optimizer: Optimizer, criterion: Callable, scheduler: _LRScheduler = None,
def training_model(model: Module, optimizer: Optimizer, criterion: Callable, scheduler: SCHEDULER = None,
max_steps: int | None = None, max_epochs: int | None = None, device: torch.device = device):
model.train()

View File

@ -9,13 +9,14 @@ import torch
from torch.nn import Module
import torch.nn.functional as F
from torch.optim import Optimizer
from torch.optim.lr_scheduler import _LRScheduler
from torch.utils.data import DataLoader
from torchvision.datasets import MNIST
from torchvision import transforms
from ..device import device
from nni.common.types import SCHEDULER
class SimpleTorchModel(torch.nn.Module):
def __init__(self):
@ -43,7 +44,7 @@ def training_step(batch: Tuple, model: Module, device: torch.device = device):
return loss
def training_model(model: Module, optimizer: Optimizer, training_step: Callable, scheduler: _LRScheduler = None,
def training_model(model: Module, optimizer: Optimizer, training_step: Callable, scheduler: SCHEDULER = None,
max_steps: int | None = None, max_epochs: int | None = None, device: torch.device = device):
model.train()

View File

@ -11,7 +11,6 @@ from typing import Callable, Union, List, Dict, Tuple
import torch
import torch.nn.functional as F
from torch.optim import Optimizer
from torch.optim.lr_scheduler import _LRScheduler
from torch.utils.data import DataLoader
from torch import Tensor

View File

@ -11,7 +11,6 @@ from typing import Callable, Union, List, Dict, Tuple
import torch
import torch.nn.functional as F
from torch.optim import Optimizer
from torch.optim.lr_scheduler import _LRScheduler
from torch.utils.data import DataLoader
from torch import Tensor

View File

@ -11,7 +11,6 @@ from typing import Callable, Union, List, Dict, Tuple
import torch
import torch.nn.functional as F
from torch.optim import Optimizer
from torch.optim.lr_scheduler import _LRScheduler
from torch.utils.data import DataLoader
from torch import Tensor

View File

@ -11,7 +11,6 @@ from typing import Callable, Union, List, Dict, Tuple
import torch
import torch.nn.functional as F
from torch.optim import Optimizer
from torch.optim.lr_scheduler import _LRScheduler
from torch.utils.data import DataLoader
from torch import Tensor

View File

@ -11,7 +11,6 @@ from typing import Callable, Union, List, Dict, Tuple
import torch
import torch.nn.functional as F
from torch.optim import Optimizer
from torch.optim.lr_scheduler import _LRScheduler
from torch.utils.data import DataLoader
from torch import Tensor