Mirror of https://github.com/microsoft/nni.git
[Compression] Add support for torch 2.0 (#5492)
This commit is contained in:
Parent: 3f67d92b67
Commit: b2e2a4d840
@@ -17,7 +17,7 @@ Tutorials
     .. only:: html

       .. image:: /tutorials/images/thumb/sphx_glr_pruning_speedup_thumb.png
-        :alt: Speedup Model with Mask
+        :alt:

       :ref:`sphx_glr_tutorials_pruning_speedup.py`

@@ -34,7 +34,7 @@ Tutorials
     .. only:: html

       .. image:: /tutorials/images/thumb/sphx_glr_quantization_quick_start_mnist_thumb.png
-        :alt: Quantization Quickstart
+        :alt:

       :ref:`sphx_glr_tutorials_quantization_quick_start_mnist.py`

@@ -51,7 +51,7 @@ Tutorials
     .. only:: html

       .. image:: /tutorials/images/thumb/sphx_glr_pruning_quick_start_mnist_thumb.png
-        :alt: Pruning Quickstart
+        :alt:

       :ref:`sphx_glr_tutorials_pruning_quick_start_mnist.py`

@@ -68,7 +68,7 @@ Tutorials
     .. only:: html

       .. image:: /tutorials/images/thumb/sphx_glr_quantization_customize_thumb.png
-        :alt: Customize a new quantization algorithm
+        :alt:

       :ref:`sphx_glr_tutorials_quantization_customize.py`

@@ -85,7 +85,7 @@ Tutorials
     .. only:: html

       .. image:: /tutorials/images/thumb/sphx_glr_nasbench_as_dataset_thumb.png
-        :alt: Use NAS Benchmarks as Datasets
+        :alt:

       :ref:`sphx_glr_tutorials_nasbench_as_dataset.py`

@@ -102,7 +102,7 @@ Tutorials
     .. only:: html

       .. image:: /tutorials/images/thumb/sphx_glr_quantization_speedup_thumb.png
-        :alt: Speed Up Quantized Model with TensorRT
+        :alt:

       :ref:`sphx_glr_tutorials_quantization_speedup.py`

@@ -119,7 +119,7 @@ Tutorials
     .. only:: html

       .. image:: /tutorials/images/thumb/sphx_glr_hello_nas_thumb.png
-        :alt: Hello, NAS!
+        :alt:

       :ref:`sphx_glr_tutorials_hello_nas.py`

@@ -136,7 +136,7 @@ Tutorials
     .. only:: html

       .. image:: /tutorials/images/thumb/sphx_glr_darts_thumb.png
-        :alt: Searching in DARTS search space
+        :alt:

       :ref:`sphx_glr_tutorials_darts.py`

@@ -153,7 +153,7 @@ Tutorials
     .. only:: html

       .. image:: /tutorials/images/thumb/sphx_glr_pruning_bert_glue_thumb.png
-        :alt: Pruning Bert on Task MNLI
+        :alt:

       :ref:`sphx_glr_tutorials_pruning_bert_glue.py`

@@ -196,7 +196,7 @@ Tutorials
     .. only:: html

       .. image:: /tutorials/hpo_quickstart_pytorch/images/thumb/sphx_glr_main_thumb.png
-        :alt: HPO Quickstart with PyTorch
+        :alt:

       :ref:`sphx_glr_tutorials_hpo_quickstart_pytorch_main.py`

@@ -213,7 +213,7 @@ Tutorials
     .. only:: html

       .. image:: /tutorials/hpo_quickstart_pytorch/images/thumb/sphx_glr_model_thumb.png
-        :alt: Port PyTorch Quickstart to NNI
+        :alt:

       :ref:`sphx_glr_tutorials_hpo_quickstart_pytorch_model.py`

@@ -242,7 +242,7 @@ Tutorials
     .. only:: html

       .. image:: /tutorials/hpo_quickstart_tensorflow/images/thumb/sphx_glr_main_thumb.png
-        :alt: HPO Quickstart with TensorFlow
+        :alt:

       :ref:`sphx_glr_tutorials_hpo_quickstart_tensorflow_main.py`

@@ -259,7 +259,7 @@ Tutorials
     .. only:: html

       .. image:: /tutorials/hpo_quickstart_tensorflow/images/thumb/sphx_glr_model_thumb.png
-        :alt: Port TensorFlow Quickstart to NNI
+        :alt:

       :ref:`sphx_glr_tutorials_hpo_quickstart_tensorflow_model.py`

@@ -278,6 +278,7 @@ Tutorials
     :hidden:
     :includehidden:

     /tutorials/hpo_quickstart_pytorch/index.rst
     /tutorials/hpo_quickstart_tensorflow/index.rst
@@ -80,7 +80,7 @@
   },
   "outputs": [],
   "source": [
"import functools\nimport time\n\nimport torch.nn.functional as F\nfrom datasets import load_metric\nfrom transformers.modeling_outputs import SequenceClassifierOutput\n\n\ndef training(model: torch.nn.Module,\n optimizer: torch.optim.Optimizer,\n criterion: Callable[[torch.Tensor, torch.Tensor], torch.Tensor],\n lr_scheduler: torch.optim.lr_scheduler._LRScheduler = None,\n max_steps: int = None,\n max_epochs: int = None,\n train_dataloader: DataLoader = None,\n distillation: bool = False,\n teacher_model: torch.nn.Module = None,\n distil_func: Callable = None,\n log_path: str = Path(log_dir) / 'training.log',\n save_best_model: bool = False,\n save_path: str = None,\n evaluation_func: Callable = None,\n eval_per_steps: int = 1000,\n device=None):\n\n assert train_dataloader is not None\n\n model.train()\n if teacher_model is not None:\n teacher_model.eval()\n current_step = 0\n best_result = 0\n\n total_epochs = max_steps // len(train_dataloader) + 1 if max_steps else max_epochs if max_epochs else 3\n total_steps = max_steps if max_steps else total_epochs * len(train_dataloader)\n\n print(f'Training {total_epochs} epochs, {total_steps} steps...')\n\n for current_epoch in range(total_epochs):\n for batch in train_dataloader:\n if current_step >= total_steps:\n return\n batch.to(device)\n outputs = model(**batch)\n loss = outputs.loss\n\n if distillation:\n assert teacher_model is not None\n with torch.no_grad():\n teacher_outputs = teacher_model(**batch)\n distil_loss = distil_func(outputs, teacher_outputs)\n loss = 0.1 * loss + 0.9 * distil_loss\n\n loss = criterion(loss, None)\n optimizer.zero_grad()\n loss.backward()\n optimizer.step()\n\n # per step schedule\n if lr_scheduler:\n lr_scheduler.step()\n\n current_step += 1\n\n if current_step % eval_per_steps == 0 or current_step % len(train_dataloader) == 0:\n result = evaluation_func(model) if evaluation_func else None\n with (log_path).open('a+') as f:\n msg = '[{}] Epoch {}, Step {}: {}\\n'.format(time.asctime(time.localtime(time.time())), current_epoch, current_step, result)\n f.write(msg)\n # if it's the best model, save it.\n if save_best_model and (result is None or best_result < result['default']):\n assert save_path is not None\n torch.save(model.state_dict(), save_path)\n best_result = None if result is None else result['default']\n\n\ndef distil_loss_func(stu_outputs: SequenceClassifierOutput, tea_outputs: SequenceClassifierOutput, encoder_layer_idxs=[]):\n encoder_hidden_state_loss = []\n for i, idx in enumerate(encoder_layer_idxs[:-1]):\n encoder_hidden_state_loss.append(F.mse_loss(stu_outputs.hidden_states[i], tea_outputs.hidden_states[idx]))\n logits_loss = F.kl_div(F.log_softmax(stu_outputs.logits / 2, dim=-1), F.softmax(tea_outputs.logits / 2, dim=-1), reduction='batchmean') * (2 ** 2)\n\n distil_loss = 0\n for loss in encoder_hidden_state_loss:\n distil_loss += loss\n distil_loss += logits_loss\n return distil_loss\n\n\ndef evaluation(model: torch.nn.Module, validation_dataloaders: Dict[str, DataLoader] = None, device=None):\n assert validation_dataloaders is not None\n training = model.training\n model.eval()\n\n is_regression = task_name == 'stsb'\n metric = load_metric('glue', task_name)\n\n result = {}\n default_result = 0\n for val_name, validation_dataloader in validation_dataloaders.items():\n for batch in validation_dataloader:\n batch.to(device)\n outputs = model(**batch)\n predictions = outputs.logits.argmax(dim=-1) if not is_regression else outputs.logits.squeeze()\n metric.add_batch(\n 
predictions=predictions,\n references=batch['labels'],\n )\n result[val_name] = metric.compute()\n default_result += result[val_name].get('f1', result[val_name].get('accuracy', 0))\n result['default'] = default_result / len(result)\n\n model.train(training)\n return result\n\n\nevaluation_func = functools.partial(evaluation, validation_dataloaders=validation_dataloaders, device=device)\n\n\ndef fake_criterion(loss, _):\n return loss"
"import functools\nimport time\n\nimport torch.nn.functional as F\nfrom datasets import load_metric\nfrom transformers.modeling_outputs import SequenceClassifierOutput\n\nfrom nni.common.types import SCHEDULER\n\n\ndef training(model: torch.nn.Module,\n optimizer: torch.optim.Optimizer,\n criterion: Callable[[torch.Tensor, torch.Tensor], torch.Tensor],\n lr_scheduler: SCHEDULER = None,\n max_steps: int = None,\n max_epochs: int = None,\n train_dataloader: DataLoader = None,\n distillation: bool = False,\n teacher_model: torch.nn.Module = None,\n distil_func: Callable = None,\n log_path: str = Path(log_dir) / 'training.log',\n save_best_model: bool = False,\n save_path: str = None,\n evaluation_func: Callable = None,\n eval_per_steps: int = 1000,\n device=None):\n\n assert train_dataloader is not None\n\n model.train()\n if teacher_model is not None:\n teacher_model.eval()\n current_step = 0\n best_result = 0\n\n total_epochs = max_steps // len(train_dataloader) + 1 if max_steps else max_epochs if max_epochs else 3\n total_steps = max_steps if max_steps else total_epochs * len(train_dataloader)\n\n print(f'Training {total_epochs} epochs, {total_steps} steps...')\n\n for current_epoch in range(total_epochs):\n for batch in train_dataloader:\n if current_step >= total_steps:\n return\n batch.to(device)\n outputs = model(**batch)\n loss = outputs.loss\n\n if distillation:\n assert teacher_model is not None\n with torch.no_grad():\n teacher_outputs = teacher_model(**batch)\n distil_loss = distil_func(outputs, teacher_outputs)\n loss = 0.1 * loss + 0.9 * distil_loss\n\n loss = criterion(loss, None)\n optimizer.zero_grad()\n loss.backward()\n optimizer.step()\n\n # per step schedule\n if lr_scheduler:\n lr_scheduler.step()\n\n current_step += 1\n\n if current_step % eval_per_steps == 0 or current_step % len(train_dataloader) == 0:\n result = evaluation_func(model) if evaluation_func else None\n with (log_path).open('a+') as f:\n msg = '[{}] Epoch {}, Step {}: {}\\n'.format(time.asctime(time.localtime(time.time())), current_epoch, current_step, result)\n f.write(msg)\n # if it's the best model, save it.\n if save_best_model and (result is None or best_result < result['default']):\n assert save_path is not None\n torch.save(model.state_dict(), save_path)\n best_result = None if result is None else result['default']\n\n\ndef distil_loss_func(stu_outputs: SequenceClassifierOutput, tea_outputs: SequenceClassifierOutput, encoder_layer_idxs=[]):\n encoder_hidden_state_loss = []\n for i, idx in enumerate(encoder_layer_idxs[:-1]):\n encoder_hidden_state_loss.append(F.mse_loss(stu_outputs.hidden_states[i], tea_outputs.hidden_states[idx]))\n logits_loss = F.kl_div(F.log_softmax(stu_outputs.logits / 2, dim=-1), F.softmax(tea_outputs.logits / 2, dim=-1), reduction='batchmean') * (2 ** 2)\n\n distil_loss = 0\n for loss in encoder_hidden_state_loss:\n distil_loss += loss\n distil_loss += logits_loss\n return distil_loss\n\n\ndef evaluation(model: torch.nn.Module, validation_dataloaders: Dict[str, DataLoader] = None, device=None):\n assert validation_dataloaders is not None\n training = model.training\n model.eval()\n\n is_regression = task_name == 'stsb'\n metric = load_metric('glue', task_name)\n\n result = {}\n default_result = 0\n for val_name, validation_dataloader in validation_dataloaders.items():\n for batch in validation_dataloader:\n batch.to(device)\n outputs = model(**batch)\n predictions = outputs.logits.argmax(dim=-1) if not is_regression else outputs.logits.squeeze()\n metric.add_batch(\n 
predictions=predictions,\n references=batch['labels'],\n )\n result[val_name] = metric.compute()\n default_result += result[val_name].get('f1', result[val_name].get('accuracy', 0))\n result['default'] = default_result / len(result)\n\n model.train(training)\n return result\n\n\nevaluation_func = functools.partial(evaluation, validation_dataloaders=validation_dataloaders, device=device)\n\n\ndef fake_criterion(loss, _):\n return loss"
   ]
  },
  {
@@ -134,7 +134,7 @@
   },
   "outputs": [],
   "source": [
"attention_pruned_model = create_finetuned_model().to(device)\nattention_masks = torch.load(Path(log_dir) / 'attention_masks.pth')\n\nffn_config_list = []\nlayer_remained_idxs = []\nmodule_list = []\nfor i in range(0, layers_num):\n prefix = f'bert.encoder.layer.{i}.'\n value_mask: torch.Tensor = attention_masks[prefix + 'attention.self.value']['weight']\n head_mask = (value_mask.reshape(heads_num, -1).sum(-1) == 0.)\n head_idxs = torch.arange(len(head_mask))[head_mask].long().tolist()\n print(f'layer {i} prune {len(head_idxs)} head: {head_idxs}')\n if len(head_idxs) != heads_num:\n attention_pruned_model.bert.encoder.layer[i].attention.prune_heads(head_idxs)\n module_list.append(attention_pruned_model.bert.encoder.layer[i])\n # The final ffn weight remaining ratio is the half of the attention weight remaining ratio.\n # This is just an empirical configuration, you can use any other method to determine this sparsity.\n sparsity = 1 - (1 - len(head_idxs) / heads_num) * 0.5\n # here we use a simple sparsity schedule, we will prune ffn in 12 iterations, each iteration prune `sparsity_per_iter`.\n sparsity_per_iter = 1 - (1 - sparsity) ** (1 / 12)\n ffn_config_list.append({\n 'op_names': [f'bert.encoder.layer.{len(layer_remained_idxs)}.intermediate.dense'],\n 'sparsity': sparsity_per_iter\n })\n layer_remained_idxs.append(i)\n\nattention_pruned_model.bert.encoder.layer = torch.nn.ModuleList(module_list)\ndistil_func = functools.partial(distil_loss_func, encoder_layer_idxs=layer_remained_idxs)"
"attention_pruned_model = create_finetuned_model().to(device)\nattention_masks = torch.load(Path(log_dir) / 'attention_masks.pth')\n\nffn_config_list = []\nlayer_remained_idxs = []\nmodule_list = []\nfor i in range(0, layers_num):\n prefix = f'bert.encoder.layer.{i}.'\n value_mask: torch.Tensor = attention_masks[prefix + 'attention.self.value']['weight']\n head_mask = (value_mask.reshape(heads_num, -1).sum(-1) == 0.).to(\"cpu\")\n head_idxs = torch.arange(len(head_mask))[head_mask].long().tolist()\n print(f'layer {i} prune {len(head_idxs)} head: {head_idxs}')\n if len(head_idxs) != heads_num:\n attention_pruned_model.bert.encoder.layer[i].attention.prune_heads(head_idxs)\n module_list.append(attention_pruned_model.bert.encoder.layer[i])\n # The final ffn weight remaining ratio is the half of the attention weight remaining ratio.\n # This is just an empirical configuration, you can use any other method to determine this sparsity.\n sparsity = 1 - (1 - len(head_idxs) / heads_num) * 0.5\n # here we use a simple sparsity schedule, we will prune ffn in 12 iterations, each iteration prune `sparsity_per_iter`.\n sparsity_per_iter = 1 - (1 - sparsity) ** (1 / 12)\n ffn_config_list.append({\n 'op_names': [f'bert.encoder.layer.{len(layer_remained_idxs)}.intermediate.dense'],\n 'sparsity': sparsity_per_iter\n })\n layer_remained_idxs.append(i)\n\nattention_pruned_model.bert.encoder.layer = torch.nn.ModuleList(module_list)\ndistil_func = functools.partial(distil_loss_func, encoder_layer_idxs=layer_remained_idxs)"
   ]
  },
  {
@@ -197,7 +197,7 @@
    "name": "python",
    "nbconvert_exporter": "python",
    "pygments_lexer": "ipython3",
-   "version": "3.9.12"
+   "version": "3.9.16"
   }
  },
 "nbformat": 4,
@@ -159,11 +159,13 @@ import torch.nn.functional as F
 from datasets import load_metric
 from transformers.modeling_outputs import SequenceClassifierOutput

+from nni.common.types import SCHEDULER
+

 def training(model: torch.nn.Module,
              optimizer: torch.optim.Optimizer,
              criterion: Callable[[torch.Tensor, torch.Tensor], torch.Tensor],
-             lr_scheduler: torch.optim.lr_scheduler._LRScheduler = None,
+             lr_scheduler: SCHEDULER = None,
              max_steps: int = None,
              max_epochs: int = None,
              train_dataloader: DataLoader = None,

@@ -399,7 +401,7 @@ module_list = []
 for i in range(0, layers_num):
     prefix = f'bert.encoder.layer.{i}.'
     value_mask: torch.Tensor = attention_masks[prefix + 'attention.self.value']['weight']
-    head_mask = (value_mask.reshape(heads_num, -1).sum(-1) == 0.)
+    head_mask = (value_mask.reshape(heads_num, -1).sum(-1) == 0.).to("cpu")
     head_idxs = torch.arange(len(head_mask))[head_mask].long().tolist()
     print(f'layer {i} prune {len(head_idxs)} head: {head_idxs}')
     if len(head_idxs) != heads_num:
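The `.to("cpu")` added above is needed because `torch.arange(len(head_mask))` is created on the CPU while the mask, computed from a CUDA weight mask, stays on the GPU; recent PyTorch releases refuse boolean indexing across devices. A minimal sketch of the failure mode and the fix (shapes and tensor names are illustrative, not taken from the tutorial):

import torch

if torch.cuda.is_available():
    heads_num = 12
    value_mask = torch.zeros(heads_num, 64, device='cuda')         # stand-in for a pruner-produced mask
    head_mask = (value_mask.reshape(heads_num, -1).sum(-1) == 0.)  # bool tensor, still on CUDA

    idxs = torch.arange(len(head_mask))                            # int tensor on CPU
    # idxs[head_mask] would index a CPU tensor with a CUDA mask; move the mask first.
    head_idxs = idxs[head_mask.to('cpu')].long().tolist()
    print(head_idxs)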
@@ -1 +1 @@
-099d745e7809d57227bb42086ecd581c
+822d1933bb3b99080589c0cdf89cf89e
File diffs are hidden because one or more lines are too long.
Binary file not shown.
@@ -3,12 +3,13 @@

 .. _sphx_glr_tutorials_sg_execution_times:


 Computation times
 =================
-**06:12.568** total execution time for **tutorials** files:
+**00:29.846** total execution time for **tutorials** files:

 +-----------------------------------------------------------------------------------------------------+-----------+--------+
-| :ref:`sphx_glr_tutorials_quantization_speedup.py` (``quantization_speedup.py``) | 06:12.568 | 0.0 MB |
+| :ref:`sphx_glr_tutorials_pruning_bert_glue.py` (``pruning_bert_glue.py``) | 00:29.846 | 0.0 MB |
 +-----------------------------------------------------------------------------------------------------+-----------+--------+
 | :ref:`sphx_glr_tutorials_darts.py` (``darts.py``) | 00:00.000 | 0.0 MB |
 +-----------------------------------------------------------------------------------------------------+-----------+--------+
@@ -16,8 +17,6 @@ Computation times
 +-----------------------------------------------------------------------------------------------------+-----------+--------+
 | :ref:`sphx_glr_tutorials_nasbench_as_dataset.py` (``nasbench_as_dataset.py``) | 00:00.000 | 0.0 MB |
 +-----------------------------------------------------------------------------------------------------+-----------+--------+
-| :ref:`sphx_glr_tutorials_pruning_bert_glue.py` (``pruning_bert_glue.py``) | 00:00.000 | 0.0 MB |
-+-----------------------------------------------------------------------------------------------------+-----------+--------+
 | :ref:`sphx_glr_tutorials_pruning_quick_start_mnist.py` (``pruning_quick_start_mnist.py``) | 00:00.000 | 0.0 MB |
 +-----------------------------------------------------------------------------------------------------+-----------+--------+
 | :ref:`sphx_glr_tutorials_pruning_speedup.py` (``pruning_speedup.py``) | 00:00.000 | 0.0 MB |
@@ -26,3 +25,5 @@ Computation times
 +-----------------------------------------------------------------------------------------------------+-----------+--------+
 | :ref:`sphx_glr_tutorials_quantization_quick_start_mnist.py` (``quantization_quick_start_mnist.py``) | 00:00.000 | 0.0 MB |
 +-----------------------------------------------------------------------------------------------------+-----------+--------+
+| :ref:`sphx_glr_tutorials_quantization_speedup.py` (``quantization_speedup.py``) | 00:00.000 | 0.0 MB |
++-----------------------------------------------------------------------------------------------------+-----------+--------+
@@ -25,10 +25,11 @@ from nni.compression.pytorch import ModelSpeedup
 from nni.compression.pytorch.utils import count_flops_params
 from nni.compression.pytorch.pruning import TaylorFOWeightPruner
 from nni.compression.pytorch.utils import TorchEvaluator
+from nni.common.types import SCHEDULER

 ############# Create dataloaders, optimizer, training and evaluation function ############


 class Mnist(torch.nn.Module):
     def __init__(self):
         super().__init__()

@@ -74,7 +75,7 @@ def training(
         model: nn.Module,
         optimizer: torch.optim.Optimizer,
         criterion: Callable[[torch.Tensor, torch.Tensor], torch.Tensor],
-        lr_scheduler: torch.optim.lr_scheduler._LRScheduler = None,
+        lr_scheduler: SCHEDULER = None,
         max_steps: int = None, max_epochs: int = None,
         local_rank: int = -1,
         save_best_model: bool = False, save_path: str = None,
@@ -2,12 +2,13 @@ from __future__ import annotations
 from typing import Callable, Any

 import torch
-from torch.optim.lr_scheduler import StepLR, _LRScheduler
+from torch.optim.lr_scheduler import StepLR
 from torch.utils.data import DataLoader
 from torchvision import datasets, transforms

 import nni
 from nni.compression.pytorch import TorchEvaluator
+from nni.common.types import SCHEDULER

 import sys
 from pathlib import Path

@@ -21,7 +22,7 @@ model: torch.nn.Module = VGG().to(device)

 def training_func(model: torch.nn.Module, optimizers: torch.optim.Optimizer,
                   criterion: Callable[[Any, Any], torch.Tensor],
-                  lr_schedulers: _LRScheduler | None = None, max_steps: int | None = None,
+                  lr_schedulers: SCHEDULER | None = None, max_steps: int | None = None,
                   max_epochs: int | None = None, *args, **kwargs):
     model.train()
     # prepare data
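A training function with this signature is normally handed to `TorchEvaluator` together with an `nni.trace`-wrapped optimizer and scheduler. A hedged sketch of that wiring, reusing `training_func` from the diff above (the stand-in model, loss and hyperparameters are illustrative, not from the PR):

import nni
import torch
import torch.nn.functional as F
from torch.optim import SGD
from torch.optim.lr_scheduler import StepLR

from nni.compression.pytorch import TorchEvaluator
from nni.common.types import SCHEDULER  # LRScheduler on torch>=2.0, _LRScheduler before

model = torch.nn.Linear(32, 10)  # stand-in for the VGG model used in the example
traced_optimizer = nni.trace(SGD)(model.parameters(), lr=0.01)
traced_scheduler = nni.trace(StepLR)(traced_optimizer, step_size=1, gamma=0.5)

# The annotation change above keeps this check true on both torch 1.x and 2.x.
assert isinstance(traced_scheduler, SCHEDULER)

evaluator = TorchEvaluator(training_func, traced_optimizer, F.cross_entropy,
                           lr_schedulers=traced_scheduler)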
@@ -8,12 +8,12 @@ import torch.nn as nn
 import torch.nn.functional as F
 from torch.optim import Optimizer
 from torch.utils.data import DataLoader
-from torch.optim.lr_scheduler import _LRScheduler
 from torchvision import datasets, transforms

 import nni
 from nni.contrib.compression.quantization import BNNQuantizer
 from nni.contrib.compression.utils import TorchEvaluator
+from nni.common.types import SCHEDULER


 torch.manual_seed(0)

@@ -108,7 +108,7 @@ def test(model: nn.Module):
     return acc


-def train(model: torch.nn.Module, optimizer: Optimizer, training_step: Callable, scheduler: Union[_LRScheduler, None] = None,
+def train(model: torch.nn.Module, optimizer: Optimizer, training_step: Callable, scheduler: Union[SCHEDULER, None] = None,
           max_steps: Union[int, None] = None, max_epochs: Union[int, None] = 400):
     best_top1 = 0
     max_epochs = max_epochs or (40 if max_steps is None else 400)

@@ -120,7 +120,7 @@ def train(model: torch.nn.Module, optimizer: Optimizer, training_step: Callable,
             loss = training_step(batch, model)
             loss.backward()
             optimizer.step()
-            if isinstance(scheduler, _LRScheduler):
+            if isinstance(scheduler, SCHEDULER):
                 scheduler.step()
             if batch_idx % 100 == 0:
                 print('{:2.0f}% Loss {}'.format(100 * batch_idx / len(train_loader), loss.item()))
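These `isinstance` rewrites are the behavioural core of the torch 2.0 support: PyTorch 2.0 renamed the scheduler base class to `torch.optim.lr_scheduler.LRScheduler`, and the built-in schedulers now derive from the new name, so a check against `_LRScheduler` can evaluate to False and silently skip `scheduler.step()`. A small, hedged sketch of the version-agnostic check:

import torch
from torch.optim import SGD
from torch.optim.lr_scheduler import StepLR

from nni.common.types import SCHEDULER  # resolves to the base class of the installed torch version

optimizer = SGD([torch.nn.Parameter(torch.zeros(1))], lr=0.1)
scheduler = StepLR(optimizer, step_size=1)

# Matches on torch 1.x (_LRScheduler) and torch 2.x (LRScheduler) alike,
# so the learning-rate step is never skipped.
if isinstance(scheduler, SCHEDULER):
    scheduler.step()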
@@ -7,7 +7,6 @@ from typing import Callable, Union, List, Dict, Tuple
 import torch
 import torch.nn.functional as F
 from torch.optim import Optimizer
-from torch.optim.lr_scheduler import _LRScheduler
 from torch.utils.data import DataLoader
 from torch import Tensor

@@ -16,6 +15,8 @@ from torchvision import datasets, transforms
 import nni
 from nni.contrib.compression.quantization import DoReFaQuantizer
 from nni.contrib.compression.utils import TorchEvaluator
+from nni.common.types import SCHEDULER
+

 torch.manual_seed(0)
 device = 'cuda'

@@ -61,7 +62,7 @@ def training_step(batch, model) -> Tensor:


 def training_model(model: torch.nn.Module, optimizer: Union[Optimizer, List[Optimizer]], \
-                   training_step: _TRAINING_STEP, scheduler: Union[None, _LRScheduler, List[_LRScheduler]] = None,
+                   training_step: _TRAINING_STEP, scheduler: Union[None, SCHEDULER, List[SCHEDULER]] = None,
                    max_steps: Union[int, None] = None, max_epochs: Union[int, None] = None):
     model.train()
     max_epochs = max_epochs or (10 if max_steps is None else 100)

@@ -85,9 +86,9 @@ def training_model(model: torch.nn.Module, optimizer: Union[Optimizer, List[Opti
         elif isinstance(optimizer, List) and all(isinstance(_, Optimizer) for _ in optimizer):
             for opt in optimizer:
                 opt.step()
-        if isinstance(scheduler, _LRScheduler):
+        if isinstance(scheduler, SCHEDULER):
             scheduler.step()
-        if isinstance(scheduler, List) and all(isinstance(_, _LRScheduler) for _ in scheduler):
+        if isinstance(scheduler, List) and all(isinstance(_, SCHEDULER) for _ in scheduler):
             for sch in scheduler:
                 sch.step()
         current_steps += 1
@@ -7,7 +7,6 @@ from typing import Callable, Union, List, Dict, Tuple
 import torch
 import torch.nn.functional as F
 from torch.optim import Optimizer
-from torch.optim.lr_scheduler import _LRScheduler
 from torch.utils.data import DataLoader
 from torch import Tensor

@@ -16,6 +15,7 @@ from torchvision import datasets, transforms
 import nni
 from nni.contrib.compression.quantization import LsqQuantizer
 from nni.contrib.compression.utils import TorchEvaluator
+from nni.common.types import SCHEDULER

 torch.manual_seed(0)
 device = 'cuda'

@@ -62,7 +62,7 @@ def training_step(batch, model) -> Tensor:


 def training_model(model: torch.nn.Module, optimizer: Union[Optimizer, List[Optimizer]], \
-                   training_step: _TRAINING_STEP, scheduler: Union[None, _LRScheduler, List[_LRScheduler]] = None,
+                   training_step: _TRAINING_STEP, scheduler: Union[None, SCHEDULER, List[SCHEDULER]] = None,
                    max_steps: Union[int, None] = None, max_epochs: Union[int, None] = None):
     model.train()
     max_epochs = max_epochs or (40 if max_steps is None else 100)

@@ -86,9 +86,9 @@ def training_model(model: torch.nn.Module, optimizer: Union[Optimizer, List[Opti
         elif isinstance(optimizer, List) and all(isinstance(_, Optimizer) for _ in optimizer):
             for opt in optimizer:
                 opt.step()
-        if isinstance(scheduler, _LRScheduler):
+        if isinstance(scheduler, SCHEDULER):
             scheduler.step()
-        if isinstance(scheduler, List) and all(isinstance(_, _LRScheduler) for _ in scheduler):
+        if isinstance(scheduler, List) and all(isinstance(_, SCHEDULER) for _ in scheduler):
             for sch in scheduler:
                 sch.step()
         current_steps += 1
@@ -7,7 +7,6 @@ from typing import Callable, Union, List, Tuple, Any, Dict
 import torch
 import torch.nn.functional as F
 from torch.optim import Optimizer
-from torch.optim.lr_scheduler import _LRScheduler
 from torch.utils.data import DataLoader
 from torch import Tensor
 from torchvision import datasets, transforms

@@ -15,6 +14,7 @@ from torchvision import datasets, transforms
 import nni
 from nni.contrib.compression.quantization import PtqQuantizer
 from nni.contrib.compression.utils import TorchEvaluator
+from nni.common.types import SCHEDULER


 torch.manual_seed(0)

@@ -63,7 +63,7 @@ def training_step(batch, model) -> Tensor:


 def training_model(model: torch.nn.Module, optimizer: Union[Optimizer, List[Optimizer]], \
-                   training_step: _TRAINING_STEP, scheduler: Union[None, _LRScheduler, List[_LRScheduler]] = None,
+                   training_step: _TRAINING_STEP, scheduler: Union[None, SCHEDULER, List[SCHEDULER]] = None,
                    max_steps: Union[int, None] = None, max_epochs: Union[int, None] = None):
     model.train()
     max_epochs = max_epochs if max_epochs else 10 if max_steps is None else 100

@@ -87,9 +87,9 @@ def training_model(model: torch.nn.Module, optimizer: Union[Optimizer, List[Opti
         elif isinstance(optimizer, List) and all(isinstance(_, Optimizer) for _ in optimizer):
             for opt in optimizer:
                 opt.step()
-        if isinstance(scheduler, _LRScheduler):
+        if isinstance(scheduler, SCHEDULER):
             scheduler.step()
-        if isinstance(scheduler, List) and all(isinstance(_, _LRScheduler) for _ in scheduler):
+        if isinstance(scheduler, List) and all(isinstance(_, SCHEDULER) for _ in scheduler):
             for sch in scheduler:
                 sch.step()
         current_steps += 1
@@ -9,7 +9,6 @@ from typing import Callable
 import torch
 import torch.nn.functional as F
 from torch.optim import Optimizer, SGD
-from torch.optim.lr_scheduler import _LRScheduler
 from torch.utils.data import DataLoader
 from torchvision import transforms
 from torchvision.datasets import MNIST

@@ -17,6 +16,8 @@ from torchvision.datasets import MNIST
 import nni
 from nni.contrib.compression.quantization import QATQuantizer
 from nni.contrib.compression.utils import TorchEvaluator
+from nni.common.types import SCHEDULER
+

 torch.manual_seed(1024)
 device = 'cuda'

@@ -62,7 +63,7 @@ def training_step(batch, model):
     return loss


-def training_model(model: torch.nn.Module, optimizer: Optimizer, training_step: Callable, scheduler: _LRScheduler | None = None,
+def training_model(model: torch.nn.Module, optimizer: Optimizer, training_step: Callable, scheduler: SCHEDULER | None = None,
                    max_steps: int | None = None, max_epochs: int | None = None):
     model.train()
     max_epochs = max_epochs if max_epochs else 1 if max_steps is None else 100
@@ -159,11 +159,13 @@ import torch.nn.functional as F
 from datasets import load_metric
 from transformers.modeling_outputs import SequenceClassifierOutput

+from nni.common.types import SCHEDULER
+

 def training(model: torch.nn.Module,
              optimizer: torch.optim.Optimizer,
              criterion: Callable[[torch.Tensor, torch.Tensor], torch.Tensor],
-             lr_scheduler: torch.optim.lr_scheduler._LRScheduler = None,
+             lr_scheduler: SCHEDULER = None,
              max_steps: int = None,
              max_epochs: int = None,
              train_dataloader: DataLoader = None,

@@ -399,7 +401,7 @@ module_list = []
 for i in range(0, layers_num):
     prefix = f'bert.encoder.layer.{i}.'
     value_mask: torch.Tensor = attention_masks[prefix + 'attention.self.value']['weight']
-    head_mask = (value_mask.reshape(heads_num, -1).sum(-1) == 0.)
+    head_mask = (value_mask.reshape(heads_num, -1).sum(-1) == 0.).to("cpu")
     head_idxs = torch.arange(len(head_mask))[head_mask].long().tolist()
     print(f'layer {i} prune {len(head_idxs)} head: {head_idxs}')
     if len(head_idxs) != heads_num:
@@ -0,0 +1,11 @@
+# Copyright (c) Microsoft Corporation.
+# Licensed under the MIT license.
+
+from nni.common.version import torch_version_is_2
+
+if torch_version_is_2():
+    from torch.optim.lr_scheduler import LRScheduler  # type: ignore
+    SCHEDULER = LRScheduler
+else:
+    from torch.optim.lr_scheduler import _LRScheduler
+    SCHEDULER = _LRScheduler
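This new `nni.common.types` module is the pivot of the whole change: every compression component that previously imported `_LRScheduler` now imports `SCHEDULER`, which resolves to the public `LRScheduler` base class on torch 2.0 and to `_LRScheduler` on older releases (`torch_version_is_2` is assumed to inspect the installed torch version). A hedged sketch of how downstream code can use the alias for annotations and runtime checks:

from __future__ import annotations  # lets `SCHEDULER | None` work on Python < 3.10

import torch
from torch.optim import SGD
from torch.optim.lr_scheduler import CosineAnnealingLR

from nni.common.types import SCHEDULER

optimizer = SGD([torch.nn.Parameter(torch.zeros(1))], lr=0.1)
scheduler = CosineAnnealingLR(optimizer, T_max=10)


def maybe_step(sched: SCHEDULER | None) -> None:
    # One spelling covers torch 1.x and 2.x schedulers.
    if isinstance(sched, SCHEDULER):
        sched.step()


maybe_step(scheduler)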
@@ -9,9 +9,9 @@ from typing import Callable, Dict, List, Type
 from torch import Tensor
 from torch.nn import Module
 from torch.optim import Optimizer
-from torch.optim.lr_scheduler import _LRScheduler

 from nni.common.serializer import is_traceable
+from nni.common.types import SCHEDULER

 __all__ = ['OptimizerConstructHelper', 'LRSchedulerConstructHelper']

@@ -91,7 +91,7 @@ class OptimizerConstructHelper(ConstructHelper):


 class LRSchedulerConstructHelper(ConstructHelper):
-    def __init__(self, lr_scheduler_class: Type[_LRScheduler], *args, **kwargs):
+    def __init__(self, lr_scheduler_class: Type[SCHEDULER], *args, **kwargs):
         args = list(args)
         if 'optimizer' in kwargs:
             kwargs['optimizer'] = None

@@ -99,7 +99,7 @@ class LRSchedulerConstructHelper(ConstructHelper):
             args[0] = None
         super().__init__(lr_scheduler_class, *args, **kwargs)

-    def call(self, optimizer: Optimizer) -> _LRScheduler:
+    def call(self, optimizer: Optimizer) -> SCHEDULER:
         args = deepcopy(self.args)
         kwargs = deepcopy(self.kwargs)

@@ -111,10 +111,10 @@ class LRSchedulerConstructHelper(ConstructHelper):
         return self.callable_obj(*args, **kwargs)

     @staticmethod
-    def from_trace(lr_scheduler_trace: _LRScheduler):
+    def from_trace(lr_scheduler_trace: SCHEDULER):
         assert is_traceable(lr_scheduler_trace), \
             'Please use nni.trace to wrap the lr scheduler class before initialize the scheduler.'
-        assert isinstance(lr_scheduler_trace, _LRScheduler), \
-            'It is not an instance of torch.nn.lr_scheduler._LRScheduler.'
+        assert isinstance(lr_scheduler_trace, SCHEDULER), \
+            f'It is not an instance of torch.nn.lr_scheduler.{SCHEDULER}.'
         return LRSchedulerConstructHelper(lr_scheduler_trace.trace_symbol, *lr_scheduler_trace.trace_args,  # type: ignore
                                           **lr_scheduler_trace.trace_kwargs)  # type: ignore
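`from_trace` only accepts schedulers that were created through `nni.trace`, as the first assertion's message says; wrapping records the constructor symbol and arguments so the helper can rebuild the scheduler against a new optimizer later. A minimal sketch of preparing such a scheduler (the optimizer and its arguments are placeholders):

import nni
import torch
from torch.optim import SGD
from torch.optim.lr_scheduler import StepLR

# nni.trace records trace_symbol / trace_args / trace_kwargs on the instance,
# which is exactly what LRSchedulerConstructHelper.from_trace reads back.
traced_optimizer = nni.trace(SGD)([torch.nn.Parameter(torch.zeros(1))], lr=0.1)
traced_scheduler = nni.trace(StepLR)(traced_optimizer, step_size=1, gamma=0.9)

print(traced_scheduler.trace_symbol, traced_scheduler.trace_kwargs)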
@@ -11,7 +11,6 @@ from typing import Dict, List, Tuple, Union, Any, Callable, Optional
 from torch import Tensor
 from torch.nn import Module
 from torch.optim import Optimizer
-from torch.optim.lr_scheduler import _LRScheduler
 from torch.utils.hooks import RemovableHandle

 try:

@@ -31,9 +30,11 @@ else:

 import nni
 from nni.common import is_traceable
+from nni.common.types import SCHEDULER
+
 from .constructor_helper import OptimizerConstructHelper, LRSchedulerConstructHelper
 from .check_ddp import check_ddp_model, reset_ddp_model

 _logger = logging.getLogger(__name__)

@@ -553,7 +554,7 @@ class LightningEvaluator(Evaluator):

 _OPTIMIZERS = Union[Optimizer, List[Optimizer]]
 _CRITERION = Callable[[Any, Any], Any]
-_SCHEDULERS = Union[None, _LRScheduler, List[_LRScheduler]]
+_SCHEDULERS = Union[None, SCHEDULER, List[SCHEDULER]]
 _EVALUATING_FUNC = Callable[[Module], Union[float, Dict]]
 _TRAINING_FUNC = Callable[[Module, _OPTIMIZERS, _CRITERION, _SCHEDULERS, Optional[int], Optional[int]], None]

@@ -656,7 +657,7 @@ class TorchEvaluator(Evaluator):
     """

     def __init__(self, training_func: _TRAINING_FUNC, optimizers: Optimizer | List[Optimizer], criterion: _CRITERION,
-                 lr_schedulers: _LRScheduler | List[_LRScheduler] | None = None, dummy_input: Any | None = None,
+                 lr_schedulers: SCHEDULER | List[SCHEDULER] | None = None, dummy_input: Any | None = None,
                  evaluating_func: _EVALUATING_FUNC | None = None):
         self.training_func = training_func
         self._ori_criterion = criterion

@@ -665,11 +666,11 @@ class TorchEvaluator(Evaluator):
         self.evaluating_func = evaluating_func

         self._train_with_single_optimizer = isinstance(optimizers, Optimizer)
-        self._train_with_single_scheduler = isinstance(lr_schedulers, _LRScheduler)
+        self._train_with_single_scheduler = isinstance(lr_schedulers, SCHEDULER)

         self.model: Module | None = None
         self._optimizers: List[Optimizer] | None = None
-        self._lr_schedulers: List[_LRScheduler] | None = None
+        self._lr_schedulers: List[SCHEDULER] | None = None
         self._first_optimizer_step: Callable | None = None
         self._param_names_map: Dict[str, str] | None = None

@@ -677,7 +678,7 @@ class TorchEvaluator(Evaluator):
         self._tmp_optimizers = optimizers if isinstance(optimizers, (list, tuple)) else [optimizers]
         assert all(isinstance(optimizer, Optimizer) and is_traceable(optimizer) for optimizer in self._tmp_optimizers)
         self._tmp_lr_schedulers = lr_schedulers if isinstance(lr_schedulers, (list, tuple)) else [lr_schedulers] if lr_schedulers else []
-        assert all(isinstance(lr_scheduler, _LRScheduler) and is_traceable(lr_scheduler) for lr_scheduler in self._tmp_lr_schedulers)
+        assert all(isinstance(lr_scheduler, SCHEDULER) and is_traceable(lr_scheduler) for lr_scheduler in self._tmp_lr_schedulers)
         self._initialization_complete = False

     def _init_optimizer_helpers(self, pure_model: Module):
@@ -9,9 +9,10 @@ from typing import Callable, Dict, List, Type
 from torch import Tensor
 from torch.nn import Module
 from torch.optim import Optimizer
-from torch.optim.lr_scheduler import _LRScheduler

 from nni.common.serializer import is_traceable
+from nni.common.types import SCHEDULER
+

 __all__ = ['OptimizerConstructHelper', 'LRSchedulerConstructHelper']

@@ -91,7 +92,7 @@ class OptimizerConstructHelper(ConstructHelper):


 class LRSchedulerConstructHelper(ConstructHelper):
-    def __init__(self, lr_scheduler_class: Type[_LRScheduler], *args, **kwargs):
+    def __init__(self, lr_scheduler_class: Type[SCHEDULER], *args, **kwargs):  # type: ignore
         args = list(args)
         if 'optimizer' in kwargs:
             kwargs['optimizer'] = None

@@ -99,7 +100,7 @@ class LRSchedulerConstructHelper(ConstructHelper):
             args[0] = None
         super().__init__(lr_scheduler_class, *args, **kwargs)

-    def call(self, optimizer: Optimizer) -> _LRScheduler:
+    def call(self, optimizer: Optimizer) -> SCHEDULER:  # type: ignore
         args = deepcopy(self.args)
         kwargs = deepcopy(self.kwargs)

@@ -111,10 +112,10 @@ class LRSchedulerConstructHelper(ConstructHelper):
         return self.callable_obj(*args, **kwargs)

     @staticmethod
-    def from_trace(lr_scheduler_trace: _LRScheduler):
+    def from_trace(lr_scheduler_trace: SCHEDULER):  # type: ignore
         assert is_traceable(lr_scheduler_trace), \
             'Please use nni.trace to wrap the lr scheduler class before initialize the scheduler.'
-        assert isinstance(lr_scheduler_trace, _LRScheduler), \
-            'It is not an instance of torch.nn.lr_scheduler._LRScheduler.'
-        return LRSchedulerConstructHelper(lr_scheduler_trace.trace_symbol, *lr_scheduler_trace.trace_args,  # type: ignore
-                                          **lr_scheduler_trace.trace_kwargs)  # type: ignore
+        assert isinstance(lr_scheduler_trace, SCHEDULER), \
+            f'It is not an instance of torch.nn.lr_scheduler.{SCHEDULER}.'
+        return LRSchedulerConstructHelper(lr_scheduler_trace.trace_symbol, *lr_scheduler_trace.trace_args,  # type: ignore
+                                          **lr_scheduler_trace.trace_kwargs)  # type: ignore
@@ -13,7 +13,6 @@ import torch
 from torch import Tensor
 from torch.nn import Module
 from torch.optim import Optimizer
-from torch.optim.lr_scheduler import _LRScheduler
 from torch.utils.hooks import RemovableHandle

 try:

@@ -32,9 +31,11 @@ else:

 import nni
 from nni.common import is_traceable
+from nni.common.types import SCHEDULER
+
 from .constructor_helper import OptimizerConstructHelper, LRSchedulerConstructHelper
 from .check_ddp import check_ddp_model, reset_ddp_model

 _logger = logging.getLogger(__name__)

@@ -218,7 +219,7 @@ class Evaluator:
                     raise TypeError("optimizer can only optimize Tensors, "
                                     "but one of the params is " + torch.typename(param))
                 if not optimizer.defaults.get('differentiable', None) \
-                    and not (param.is_leaf or param.retains_grad): # type: ignore
+                    and not (param.is_leaf or param.retains_grad):  # type: ignore
                     raise ValueError("can't optimize a non-leaf Tensor")
                 target_param_group['params'].append(param)

@@ -486,27 +487,27 @@ class LightningEvaluator(Evaluator):

         if self._opt_returned_dicts:
             def new_configure_optimizers(_):  # type: ignore
-                optimizers_lr_schedulers: Any = old_configure_optimizers() # type: ignore
+                optimizers_lr_schedulers: Any = old_configure_optimizers()  # type: ignore
                 optimizers = [opt_lrs_dict['optimizer'] for opt_lrs_dict in optimizers_lr_schedulers]
                 # add param group
-                self._optimizer_add_param_group(self.model, module_name_param_dict, optimizers) # type: ignore
+                self._optimizer_add_param_group(self.model, module_name_param_dict, optimizers)  # type: ignore

                 return optimizers_lr_schedulers

         elif self._lr_scheduler_helpers:
             def new_configure_optimizers(_):  # type: ignore
-                optimizers_lr_schedulers: Any = old_configure_optimizers() # type: ignore
+                optimizers_lr_schedulers: Any = old_configure_optimizers()  # type: ignore
                 optimizers, lr_schedulers = optimizers_lr_schedulers
                 # add param_group
-                self._optimizer_add_param_group(self.model, module_name_param_dict, optimizers) # type: ignore
+                self._optimizer_add_param_group(self.model, module_name_param_dict, optimizers)  # type: ignore

                 return optimizers, lr_schedulers

         else:
             def new_configure_optimizers(_):
-                optimizers_lr_schedulers: Any = old_configure_optimizers() # type: ignore
+                optimizers_lr_schedulers: Any = old_configure_optimizers()  # type: ignore
                 # add param_group
-                self._optimizer_add_param_group(self.model, module_name_param_dict, optimizers_lr_schedulers) # type: ignore
+                self._optimizer_add_param_group(self.model, module_name_param_dict, optimizers_lr_schedulers)  # type: ignore

                 return optimizers_lr_schedulers

@@ -656,7 +657,7 @@ class LightningEvaluator(Evaluator):

 _OPTIMIZERS = Union[Optimizer, List[Optimizer]]
 _TRAINING_STEP = Callable[..., Union[Tensor, Tuple[Tensor], Dict[str, Tensor]]]
-_SCHEDULERS = Union[None, _LRScheduler, List[_LRScheduler]]
+_SCHEDULERS = Union[None, SCHEDULER, List[SCHEDULER]]
 _EVALUATING_FUNC = Callable[[Module], Union[float, Dict]]
 _TRAINING_FUNC = Callable[[Module, _OPTIMIZERS, _TRAINING_STEP, Optional[_SCHEDULERS], Optional[int], Optional[int]], None]

@@ -765,7 +766,7 @@ class TorchEvaluator(Evaluator):
     """

     def __init__(self, training_func: _TRAINING_FUNC, optimizers: Optimizer | List[Optimizer], training_step: _TRAINING_STEP,
-                 lr_schedulers: _LRScheduler | List[_LRScheduler] | None = None, dummy_input: Any | None = None,
+                 lr_schedulers: SCHEDULER | List[SCHEDULER] | None = None, dummy_input: Any | None = None,  # type: ignore
                  evaluating_func: _EVALUATING_FUNC | None = None):
         self.training_func = training_func
         self._ori_training_step = training_step

@@ -774,11 +775,11 @@ class TorchEvaluator(Evaluator):
         self.evaluating_func = evaluating_func

         self._train_with_single_optimizer = isinstance(optimizers, Optimizer)
-        self._train_with_single_scheduler = isinstance(lr_schedulers, _LRScheduler)
+        self._train_with_single_scheduler = isinstance(lr_schedulers, SCHEDULER)

         self.model: Module | None = None
         self._optimizers: List[Optimizer] | None = None
-        self._lr_schedulers: List[_LRScheduler] | None = None
+        self._lr_schedulers: List[SCHEDULER] | None = None  # type: ignore
         self._first_optimizer_step: Callable | None = None
         self._param_names_map: Dict[str, str] | None = None

@@ -786,7 +787,7 @@ class TorchEvaluator(Evaluator):
         self._tmp_optimizers = optimizers if isinstance(optimizers, (list, tuple)) else [optimizers]
         assert all(isinstance(optimizer, Optimizer) and is_traceable(optimizer) for optimizer in self._tmp_optimizers)
         self._tmp_lr_schedulers = lr_schedulers if isinstance(lr_schedulers, (list, tuple)) else [lr_schedulers] if lr_schedulers else []
-        assert all(isinstance(lr_scheduler, _LRScheduler) and is_traceable(lr_scheduler) for lr_scheduler in self._tmp_lr_schedulers)
+        assert all(isinstance(lr_scheduler, SCHEDULER) and is_traceable(lr_scheduler) for lr_scheduler in self._tmp_lr_schedulers)
         self._initialization_complete = False

     def _init_optimizer_helpers(self, pure_model: Module):

@@ -829,7 +830,7 @@ class TorchEvaluator(Evaluator):
     def patch_optim_param_group(self, module_name_param_dict: Dict[str, List[Tensor]]):
         assert isinstance(self.model, Module)
         assert module_name_param_dict is not None
-        self._optimizer_add_param_group(self.model, module_name_param_dict, self._optimizers) # type: ignore
+        self._optimizer_add_param_group(self.model, module_name_param_dict, self._optimizers)  # type: ignore

     def unbind_model(self):
         if self.model:
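The `nni.contrib.compression` evaluator differs from the older one in that it takes a `training_step` callable rather than a criterion. A hedged usage sketch with traced optimizer and scheduler (the stand-in model, loop body and hyperparameters are illustrative, not from the PR):

import nni
import torch
import torch.nn.functional as F
from torch.optim import Adam
from torch.optim.lr_scheduler import ExponentialLR

from nni.contrib.compression.utils import TorchEvaluator


def training_step(batch, model):
    # Must return the loss tensor, matching the _TRAINING_STEP alias above.
    x, y = batch
    return F.cross_entropy(model(x), y)


def training_func(model, optimizer, training_step, scheduler=None, max_steps=None, max_epochs=None):
    # Placeholder loop; real code iterates a DataLoader, calls loss.backward(),
    # optimizer.step() and, when isinstance(scheduler, SCHEDULER), scheduler.step().
    pass


model = torch.nn.Linear(8, 2)  # stand-in model
traced_optimizer = nni.trace(Adam)(model.parameters(), lr=1e-3)
traced_scheduler = nni.trace(ExponentialLR)(traced_optimizer, gamma=0.9)

evaluator = TorchEvaluator(training_func, traced_optimizer, training_step,
                           lr_schedulers=traced_scheduler)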
@@ -10,13 +10,14 @@ import torch
 from torch.nn import Module
 import torch.nn.functional as F
 from torch.optim import Optimizer
-from torch.optim.lr_scheduler import _LRScheduler
 from torch.utils.data import DataLoader
 from torchvision.datasets import MNIST
 from torchvision import transforms

 from ..device import device

+from nni.common.types import SCHEDULER
+

 class SimpleTorchModel(torch.nn.Module):
     def __init__(self):

@@ -37,7 +38,7 @@ class SimpleTorchModel(torch.nn.Module):
         return F.log_softmax(x, -1)


-def training_model(model: Module, optimizer: Optimizer, criterion: Callable, scheduler: _LRScheduler = None,
+def training_model(model: Module, optimizer: Optimizer, criterion: Callable, scheduler: SCHEDULER = None,
                    max_steps: int | None = None, max_epochs: int | None = None, device: torch.device = device):
     model.train()
@@ -9,13 +9,14 @@ import torch
 from torch.nn import Module
 import torch.nn.functional as F
 from torch.optim import Optimizer
-from torch.optim.lr_scheduler import _LRScheduler
 from torch.utils.data import DataLoader
 from torchvision.datasets import MNIST
 from torchvision import transforms

 from ..device import device

+from nni.common.types import SCHEDULER
+

 class SimpleTorchModel(torch.nn.Module):
     def __init__(self):

@@ -43,7 +44,7 @@ def training_step(batch: Tuple, model: Module, device: torch.device = device):
     return loss


-def training_model(model: Module, optimizer: Optimizer, training_step: Callable, scheduler: _LRScheduler = None,
+def training_model(model: Module, optimizer: Optimizer, training_step: Callable, scheduler: SCHEDULER = None,
                    max_steps: int | None = None, max_epochs: int | None = None, device: torch.device = device):
     model.train()
@@ -11,7 +11,6 @@ from typing import Callable, Union, List, Dict, Tuple
 import torch
 import torch.nn.functional as F
 from torch.optim import Optimizer
-from torch.optim.lr_scheduler import _LRScheduler
 from torch.utils.data import DataLoader
 from torch import Tensor

@@ -11,7 +11,6 @@ from typing import Callable, Union, List, Dict, Tuple
 import torch
 import torch.nn.functional as F
 from torch.optim import Optimizer
-from torch.optim.lr_scheduler import _LRScheduler
 from torch.utils.data import DataLoader
 from torch import Tensor

@@ -11,7 +11,6 @@ from typing import Callable, Union, List, Dict, Tuple
 import torch
 import torch.nn.functional as F
 from torch.optim import Optimizer
-from torch.optim.lr_scheduler import _LRScheduler
 from torch.utils.data import DataLoader
 from torch import Tensor

@@ -11,7 +11,6 @@ from typing import Callable, Union, List, Dict, Tuple
 import torch
 import torch.nn.functional as F
 from torch.optim import Optimizer
-from torch.optim.lr_scheduler import _LRScheduler
 from torch.utils.data import DataLoader
 from torch import Tensor

@@ -11,7 +11,6 @@ from typing import Callable, Union, List, Dict, Tuple
 import torch
 import torch.nn.functional as F
 from torch.optim import Optimizer
-from torch.optim.lr_scheduler import _LRScheduler
 from torch.utils.data import DataLoader
 from torch import Tensor