[NAS] Support of PyTorch (and Lightning) 2.0 (#5466)

Yuge Zhang 2023-05-29 10:45:42 +08:00 committed by GitHub
Parent f9440cce54
Commit a58cadda48
No key found matching this signature
GPG key ID: 4AEE18F83AFDEB23
23 changed files: 187 additions, 224 deletions

dependencies/recommended.txt (vendored, 10 changes)

@ -3,11 +3,11 @@
-f https://download.pytorch.org/whl/torch_stable.html
tensorflow >= 2.7.0
tensorboard >= 2.7.0
torch == 1.13.1+cpu ; sys_platform != "darwin"
torch == 1.13.1 ; sys_platform == "darwin"
torchvision == 0.14.1+cpu ; sys_platform != "darwin"
torchvision == 0.14.1 ; sys_platform == "darwin"
pytorch-lightning >= 1.6.1, < 2.0
torch == 2.0.0+cpu ; sys_platform != "darwin"
torch == 2.0.0 ; sys_platform == "darwin"
torchvision == 0.15.0+cpu ; sys_platform != "darwin"
torchvision == 0.15.0 ; sys_platform == "darwin"
pytorch-lightning >= 2.0
torchmetrics
lightgbm
onnx

dependencies/recommended_gpu.txt (vendored, 8 changes)

@ -2,12 +2,12 @@
-f https://download.pytorch.org/whl/torch_stable.html
tensorflow
torch == 1.13.1+cu117
torchvision == 0.14.1+cu117
pytorch-lightning >= 1.6.1, < 2.0
torch == 2.0.0+cu117
torchvision == 0.15.0+cu117
pytorch-lightning >= 2.0
# for full-test-compression
-f https://download.openmmlab.com/mmcv/dist/cu117/torch1.13/index.html
-f https://download.openmmlab.com/mmcv/dist/cu117/torch2.0/index.html
mmcv >= 2.0.0rc4, < 2.1.0
mmdet >= 3.0
mmengine

dependencies/recommended_gpu_legacy.txt (vendored, new file, 26 changes)

@ -0,0 +1,26 @@
# Temporary dependency pins for compression.
# Delete this file once the compression tests are compatible with PyTorch 2.0.
-f https://download.pytorch.org/whl/torch_stable.html
tensorflow
torch == 1.13.1+cu117
torchvision == 0.14.1+cu117
pytorch-lightning >= 1.6.1, < 2.0
# for full-test-compression
-f https://download.openmmlab.com/mmcv/dist/cu117/torch1.13/index.html
mmcv >= 2.0.0rc4, < 2.1.0
mmdet >= 3.0
mmengine
git+https://github.com/microsoft/nn-Meter.git#egg=nn_meter
lightgbm
onnx
onnxsim
onnxruntime-gpu
peewee
graphviz
gym
sympy
tianshou >= 0.4.1
timm >= 0.5.4


@ -152,7 +152,7 @@ def movement_mul_mask(target: torch.Tensor, target_space: PruningTargetSpace):
assert target_space.mask is not None and target_space.shape is not None
if target_space._scaler is not None:
score = target_space._scaler.expand(score, target_space.shape, keepdim=True, full_expand=False)
return torch.mul(target, _StraightThrough.apply(score, target_space.mask))
return torch.mul(target, _StraightThrough.apply(score, target_space.mask)) # type: ignore
def movement_add_mask(target: torch.Tensor, target_space: PruningTargetSpace):
@ -164,7 +164,7 @@ def movement_add_mask(target: torch.Tensor, target_space: PruningTargetSpace):
trans_mask = torch.where(target_space.mask == 1, torch.zeros_like(target_space.mask), SMALL_MASK_VALUE)
if target_space._scaler is not None:
score = target_space._scaler.expand(score, target_space.shape, keepdim=True, full_expand=False)
return torch.add(target, _StraightThrough.apply(score, trans_mask))
return torch.add(target, _StraightThrough.apply(score, trans_mask)) # type: ignore
def slim_mul_mask(target: torch.Tensor, target_space: PruningTargetSpace):


@ -612,7 +612,7 @@ class LightningEvaluator(Evaluator):
trainer.num_sanity_val_steps = 0
if max_steps:
trainer.fit_loop.max_steps = max_steps
trainer.fit_loop.max_steps = max_steps # type: ignore
if max_epochs:
trainer.fit_loop.max_epochs = max_epochs
trainer.fit(self.model, self.data_module)


@ -137,8 +137,9 @@ class MultiModelTrainer(pytorch_lightning.Trainer):
def __init__(self, use_cgo: bool = True, **trainer_kwargs):
if use_cgo:
if "accelerator" in trainer_kwargs:
raise ValueError("accelerator should not be set when cross-graph optimization is enabled.")
# Accelerator and strategy can both be set in Lightning 2.0.
# if "accelerator" in trainer_kwargs:
# raise ValueError("accelerator should not be set when cross-graph optimization is enabled.")
if 'strategy' in trainer_kwargs:
raise ValueError("MultiModelTrainer does not support specifying strategy")


@ -0,0 +1,26 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.
from __future__ import annotations
from typing import Any, Optional
from pytorch_lightning.utilities.combined_loader import (
CombinedLoader, _CombinationMode, _SUPPORTED_MODES, _Sequential,
_ModeIterator, _tree_flatten
)
_SUPPORTED_MODES['_nni_concat'] = _CombinationMode(fn=sum, iterator=_Sequential)
__all__ = ['ConcatLoader']
class ConcatLoader(CombinedLoader):
"""This is trying to bypass the supported mode checker in PyTorch-Lightning FitLoop.
"""
def __init__(self, iterables: Any) -> None:
self._iterables = iterables
self._flattened, self._spec = _tree_flatten(iterables)
self._mode = '_nni_concat'
self._iterator: Optional[_ModeIterator] = None
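
The registered '_nni_concat' mode drives the underlying loaders sequentially, so iterating a ConcatLoader yields (batch, batch_idx, dataloader_idx) triples instead of the (batch, "train" | "val") pairs produced by the old 1.x loader. A small sketch of what a consumer sees, mirroring the unit tests further below (loader sizes are illustrative):

from torch.utils.data import DataLoader

loader = ConcatLoader({
    'train': DataLoader(range(10), batch_size=4),  # 3 batches, dataloader_idx 0
    'val': DataLoader(range(20), batch_size=5),    # 4 batches, dataloader_idx 1
})
assert len(loader) == 7
for batch, batch_idx, dataloader_idx in loader:
    # batch_idx restarts from 0 for each underlying loader;
    # dataloader_idx identifies which loader produced the batch.
    pass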


@ -1,6 +1,9 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.
# pylint: skip-file
# type: ignore
from __future__ import annotations
from typing import Any


@ -4,7 +4,7 @@
from __future__ import annotations
import warnings
from typing import Any, Iterable, cast, TYPE_CHECKING
from typing import Any, Iterable, List, cast, TYPE_CHECKING
import torch.optim as optim
import torch.nn as nn
@ -279,9 +279,6 @@ class BaseOneShotLightningModule(LightningModule):
if self.automatic_optimization:
raise ValueError('This method should not be used when automatic optimization is turned on.')
if self.trainer.optimizer_frequencies:
warnings.warn('optimizer_frequencies is not supported in NAS. It will be ignored.', UserWarning)
# Has to be optimizers() here (to get LightningOptimizer)
# instead of trainer.optimizers (raw optimizers),
# because otherwise optim_progress is incorrect.
@ -289,22 +286,54 @@ class BaseOneShotLightningModule(LightningModule):
if not isinstance(optimizers, list):
optimizers = [optimizers]
# Filter out optimizers for architecture parameters.
optimizers = [opt for opt in optimizers if not getattr(opt, 'is_arch_optimizer', False)]
optimizers = cast(List[Optimizer], [opt for opt in optimizers if not getattr(opt, 'is_arch_optimizer', False)])
if hasattr(self.trainer, 'optimizer_frequencies'): # lightning < 2
self._legacy_advance_optimization(loss, batch_idx, optimizers, gradient_clip_val, gradient_clip_algorithm)
else:
if not self.training_module.automatic_optimization:
raise ValueError('Evaluator module with manual optimization is not compatible with one-shot algorithms.')
if len(optimizers) != 1:
raise ValueError('More than one optimizer returned by evaluator. This is not supported in NAS.')
optimizer = optimizers[0]
# There should be many before/after hooks called here, but they are omitted in this implementation.
# 1. zero gradient
self.training_module.optimizer_zero_grad(self.trainer.current_epoch, batch_idx, optimizer)
# 2. backward
self.manual_backward(loss)
# 3. grad clip
self.training_module.configure_gradient_clipping(optimizer, gradient_clip_val, gradient_clip_algorithm)
# 4. optimizer step
self.training_module.optimizer_step(self.trainer.current_epoch, batch_idx, optimizer)
self._optimizer_progress += 1
def _legacy_advance_optimization(
self,
loss: Any,
batch_idx: int,
optimizers: list[Optimizer],
gradient_clip_val: int | float | None = None,
gradient_clip_algorithm: str | None = None
):
""":meth:`advance_optimization` for Lightning 1.x."""
if self.trainer.optimizer_frequencies: # type: ignore
warnings.warn('optimizer_frequencies is not supported in NAS. It will be ignored.', UserWarning)
opt_idx = self._optimizer_progress % len(optimizers)
optimizer = cast(Optimizer, optimizers[opt_idx]) # LightningOptimizer has the same interface as Optimizer.
# There should be many before/after hooks called here, but they are omitted in this implementation.
# 1. zero gradient
self.training_module.optimizer_zero_grad(self.trainer.current_epoch, batch_idx, optimizer, opt_idx)
self.training_module.optimizer_zero_grad(self.trainer.current_epoch, batch_idx, optimizer, opt_idx) # type: ignore
# 2. backward
self.manual_backward(loss)
# 3. grad clip
self.training_module.configure_gradient_clipping(optimizer, opt_idx, gradient_clip_val, gradient_clip_algorithm)
self.training_module.configure_gradient_clipping(optimizer, opt_idx, gradient_clip_val, gradient_clip_algorithm) # type: ignore
# 4. optimizer step
self.training_module.optimizer_step(self.trainer.current_epoch, batch_idx, optimizer, opt_idx)
self._optimizer_progress += 1
self.training_module.optimizer_step(self.trainer.current_epoch, batch_idx, optimizer, opt_idx) # type: ignore
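
For reference, a hedged sketch of how a one-shot subclass typically drives this manual-optimization path from its training step (the loss computation is illustrative; method names follow this module):

class ExampleOneShotModule(BaseOneShotLightningModule):
    def training_step(self, batch, batch_idx):
        # Delegate the forward pass and loss to the user's evaluator module,
        # then step the filtered (non-architecture) optimizers manually.
        loss = self.training_module.training_step(batch, batch_idx)
        self.advance_optimization(loss, batch_idx)
        self.advance_lr_schedulers(batch_idx)
        return loss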
def advance_lr_schedulers(self, batch_idx: int):
"""
@ -329,11 +358,18 @@ class BaseOneShotLightningModule(LightningModule):
try:
# lightning >= 1.6
for config in self.trainer.lr_scheduler_configs:
scheduler, opt_idx = config.scheduler, config.opt_idx
if hasattr(config, 'opt_idx'):
# lightning < 2.0
scheduler, opt_idx = config.scheduler, config.opt_idx # type: ignore
else:
scheduler, opt_idx = config.scheduler, None
if config.reduce_on_plateau:
warnings.warn('Reduce-lr-on-plateau is not supported in NAS. It will be ignored.', UserWarning)
if config.interval == interval and current_idx % config.frequency == 0:
self.training_module.lr_scheduler_step(cast(Any, scheduler), cast(int, opt_idx), None)
if opt_idx is not None:
self.training_module.lr_scheduler_step(cast(Any, scheduler), cast(int, opt_idx), None) # type: ignore
else:
self.training_module.lr_scheduler_step(cast(Any, scheduler), None)
except AttributeError:
# lightning < 1.6
for lr_scheduler in self.trainer.lr_schedulers: # type: ignore
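
The branching above exists because the lr_scheduler_step hook lost its optimizer_idx argument in Lightning 2.0. A hedged sketch of the 2.x-style override for comparison (the metric handling mirrors Lightning's default behaviour):

from pytorch_lightning import LightningModule

class SchedulerModule(LightningModule):
    # Lightning >= 2.0 signature: no optimizer_idx argument.
    def lr_scheduler_step(self, scheduler, metric):
        if metric is None:
            scheduler.step()
        else:
            scheduler.step(metric)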


@ -1,155 +0,0 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.
from typing import cast
import torch
import torch.nn as nn
import torch.nn.functional as F
class StackedLSTMCell(nn.Module):
def __init__(self, layers, size, bias):
super().__init__()
self.lstm_num_layers = layers
self.lstm_modules = nn.ModuleList([nn.LSTMCell(size, size, bias=bias)
for _ in range(self.lstm_num_layers)])
def forward(self, inputs, hidden):
prev_h, prev_c = hidden
next_h, next_c = [], []
for i, m in enumerate(self.lstm_modules):
curr_h, curr_c = m(inputs, (prev_h[i], prev_c[i]))
next_c.append(curr_c)
next_h.append(curr_h)
# current implementation only supports batch size equals 1,
# but the algorithm does not necessarily have this limitation
inputs = curr_h[-1].view(1, -1)
return next_h, next_c
class ReinforceField:
"""
A field with ``name``, with ``total`` choices. ``choose_one`` is true if one and only one is meant to be
selected. Otherwise, any number of choices can be chosen.
"""
def __init__(self, name, total, choose_one):
self.name = name
self.total = total
self.choose_one = choose_one
def __repr__(self):
return f'ReinforceField(name={self.name}, total={self.total}, choose_one={self.choose_one})'
class ReinforceController(nn.Module):
"""
A controller that mutates the graph with RL.
Parameters
----------
fields : list of ReinforceField
List of fields to choose.
lstm_size : int
Controller LSTM hidden units.
lstm_num_layers : int
Number of layers for stacked LSTM.
tanh_constant : float
Logits will be equal to ``tanh_constant * tanh(logits)``. Don't use ``tanh`` if this value is ``None``.
skip_target : float
Target probability that skipconnect (chosen by InputChoice) will appear.
If the chosen number of inputs is away from the ``skip_connect``, there will be
a sample skip penalty which is a KL divergence added.
temperature : float
Temperature constant that divides the logits.
entropy_reduction : str
Can be one of ``sum`` and ``mean``. How the entropy of multi-input-choice is reduced.
"""
def __init__(self, fields, lstm_size=64, lstm_num_layers=1, tanh_constant=1.5,
skip_target=0.4, temperature=None, entropy_reduction='sum'):
super(ReinforceController, self).__init__()
self.fields = fields
self.lstm_size = lstm_size
self.lstm_num_layers = lstm_num_layers
self.tanh_constant = tanh_constant
self.temperature = temperature
self.skip_target = skip_target
self.lstm = StackedLSTMCell(self.lstm_num_layers, self.lstm_size, False)
self.attn_anchor = nn.Linear(self.lstm_size, self.lstm_size, bias=False)
self.attn_query = nn.Linear(self.lstm_size, self.lstm_size, bias=False)
self.v_attn = nn.Linear(self.lstm_size, 1, bias=False)
self.g_emb = nn.Parameter(torch.randn(1, self.lstm_size) * 0.1)
self.skip_targets = nn.Parameter(torch.tensor([1.0 - self.skip_target, self.skip_target]), # pylint: disable=not-callable
requires_grad=False)
assert entropy_reduction in ['sum', 'mean'], 'Entropy reduction must be one of sum and mean.'
self.entropy_reduction = torch.sum if entropy_reduction == 'sum' else torch.mean
self.cross_entropy_loss = nn.CrossEntropyLoss(reduction='none')
self.soft = nn.ModuleDict({
field.name: nn.Linear(self.lstm_size, field.total, bias=False) for field in fields
})
self.embedding = nn.ModuleDict({
field.name: nn.Embedding(field.total, self.lstm_size) for field in fields
})
def resample(self, return_prob=False):
self._initialize()
result = dict()
for field in self.fields:
result[field.name] = self._sample_single(field, return_prob=return_prob)
return result
def _initialize(self):
self._inputs = self.g_emb.data
self._c = [torch.zeros((1, self.lstm_size),
dtype=self._inputs.dtype,
device=self._inputs.device) for _ in range(self.lstm_num_layers)]
self._h = [torch.zeros((1, self.lstm_size),
dtype=self._inputs.dtype,
device=self._inputs.device) for _ in range(self.lstm_num_layers)]
self.sample_log_prob: torch.Tensor = cast(torch.Tensor, 0)
self.sample_entropy: torch.Tensor = cast(torch.Tensor, 0)
self.sample_skip_penalty: torch.Tensor = cast(torch.Tensor, 0)
def _lstm_next_step(self):
self._h, self._c = self.lstm(self._inputs, (self._h, self._c))
def _sample_single(self, field, return_prob):
self._lstm_next_step()
logit = self.soft[field.name](self._h[-1])
if self.temperature is not None:
logit /= self.temperature
if self.tanh_constant is not None:
logit = self.tanh_constant * torch.tanh(logit)
if field.choose_one:
sampled_dist = F.softmax(logit, dim=-1)
sampled = torch.multinomial(F.softmax(logit, dim=-1), 1).view(-1)
log_prob = self.cross_entropy_loss(logit, sampled)
self._inputs = self.embedding[field.name](sampled)
else:
sampled_dist = torch.sigmoid(logit)
logit = logit.view(-1, 1)
logit = torch.cat([-logit, logit], 1) # pylint: disable=invalid-unary-operand-type
sampled = torch.multinomial(F.softmax(logit, dim=-1), 1).view(-1)
skip_prob = torch.sigmoid(logit)
kl = torch.sum(skip_prob * torch.log(skip_prob / self.skip_targets))
self.sample_skip_penalty += kl
log_prob = self.cross_entropy_loss(logit, sampled)
sampled = sampled.nonzero().view(-1)
if sampled.sum().item():
self._inputs = (torch.sum(self.embedding[field.name](sampled.view(-1)), 0) / (1. + torch.sum(sampled))).unsqueeze(0)
else:
self._inputs = torch.zeros(1, self.lstm_size, device=self.embedding[field.name].weight.device) # type: ignore
sampled = sampled.detach().cpu().numpy().tolist()
self.sample_log_prob += self.entropy_reduction(log_prob)
entropy = (log_prob * torch.exp(-log_prob)).detach() # pylint: disable=invalid-unary-operand-type
self.sample_entropy += self.entropy_reduction(entropy)
if len(sampled) == 1:
sampled = sampled[0]
if return_prob:
return sampled_dist.flatten().detach().cpu().numpy().tolist()
return sampled


@ -168,8 +168,13 @@ class EnasLightningModule(BaseOneShotLightningModule):
return super().set_model(model)
def training_step(self, batch_packed, batch_idx):
# The received batch is a tuple of (data, "train" | "val")
batch, mode = batch_packed
if len(batch_packed) == 2:
# Legacy (pytorch-lightning 1.x): The received batch is a tuple of (data, "train" | "val")
batch, mode = batch_packed
else:
# New (pytorch-lightning 2.0+): a tuple of data, batch_idx, and dataloader_idx
batch, _, dataloader_idx = batch_packed
mode = 'train' if dataloader_idx == 0 else 'val'
assert self._replay_buffer is not None
@ -223,6 +228,9 @@ class EnasLightningModule(BaseOneShotLightningModule):
self.log_probs(self.export_probs())
if self._trajectory_counter > 0 and self._trajectory_counter % self.batches_per_update == 0:
# Export could be just called.
# The policy must be in train mode to make update work.
self.policy.train()
update_times = self.update_kwargs.get('update_times', 1)
for _ in range(update_times):
self.policy.update(0, self._replay_buffer, **self.update_kwargs)


@ -942,8 +942,11 @@ class ENAS(RandomOneShot):
)
def train_dataloader(self, train_dataloader_fn, val_dataloader_fn):
# Import locally to avoid import error on legacy PL version
from .dataloader import ConcatLoader
import pytorch_lightning
if pytorch_lightning.__version__.startswith('1.'):
from ._dataloader_legacy import ConcatLoader
else:
from ._dataloader import ConcatLoader
return ConcatLoader({
'train': train_dataloader_fn(),
'val': val_dataloader_fn()


@ -280,7 +280,7 @@ def _canonicalize_dims(dims: list[int], n_dims: int, fn: Any) -> list[int]:
return [d if d >= 0 else d + n_dims for d in dims]
def aten_reshape_alias_formula(fn: Any, input: ShapeTensor, size: list[int], stride: list[int]) -> MutableShape:
def aten_reshape_alias_formula(fn: Any, input: ShapeTensor, size: list[int], stride: list[int] | None = None) -> MutableShape:
input_shape = ensure_shape(input)
if input_shape.is_mutable():
raise RuntimeError(f'Cannot infer the shape of {fn} because the input shape is not determined: {input_shape}, '
@ -395,6 +395,7 @@ _safe_register_aten_formula('cat.default', aten_cat_formula)
_safe_register_aten_formula('mean.dim', aten_mean_dim)
_safe_register_aten_formula('_log_softmax.default', keep_first_shape_formula)
_safe_register_aten_formula('_reshape_alias.default', aten_reshape_alias_formula)
_safe_register_aten_formula('view.default', aten_reshape_alias_formula)
_safe_register_aten_formula('add.Tensor', aten_shape_broadcast)
_safe_register_aten_formula('mul.Tensor', aten_shape_broadcast)
_safe_register_aten_formula('slice.Tensor', aten_slice_formula)
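
Registering view.default alongside _reshape_alias.default is needed because, when traced with PyTorch 2.0, Tensor.view / Tensor.reshape may be lowered to aten::view. A tiny trace illustrating the lowering (illustrative; the exact aten target can vary by PyTorch version and tracing mode):

import torch
from torch.fx.experimental.proxy_tensor import make_fx

def f(x):
    return x.view(6, 4)

gm = make_fx(f)(torch.randn(2, 3, 4))
print(gm.graph)  # the graph contains a call to torch.ops.aten.view.default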


@ -47,7 +47,7 @@ stages:
- template: templates/install-dependencies.yml
parameters:
platform: ubuntu-latest-gpu
platform: ubuntu-latest-gpu-torch1.x
python_env: venv
- template: templates/install-nni.yml


@ -62,6 +62,11 @@ steps:
displayName: (GPU) Activate CUDA dependencies
condition: and(succeeded(), contains('${{ parameters.platform }}', 'gpu'))
- script: |
mv dependencies/recommended_gpu_legacy.txt dependencies/recommended.txt
displayName: (GPU) Activate CUDA legacy dependencies
condition: and(succeeded(), contains('${{ parameters.platform }}', 'gpu-torch1.x'))
- script: |
echo '===== develop ====='
python -m pip install -r dependencies/develop.txt


@ -1,3 +1,5 @@
import pytest
from nni.nas.benchmark import *
from nni.nas.execution import SequentialExecutionEngine
from nni.nas.strategy import *
@ -30,6 +32,7 @@ def test_nasbench201_with_rl():
assert list(strategy.list_models(sort=True, limit=1))[0].metric > 0.7
@pytest.mark.flaky(reruns=2)
def test_nasbench101_with_evo():
pytorch_space = NasBench101()
benchmark = NasBench101Benchmark()


@ -6,7 +6,7 @@ import torch.nn as nn
import torchmetrics
from torchvision.datasets import MNIST
from torchvision import transforms
from pytorch_lightning.utilities.seed import seed_everything
from pytorch_lightning import seed_everything
import nni
from nni.experiment.config import RemoteConfig, RemoteMachineConfig
@ -114,7 +114,7 @@ class M_2_stem(nn.Module):
return pool2
def create_evaluator(n_models=None):
def create_evaluator(n_models=None, accelerator='gpu'):
transform = transforms.Compose([transforms.ToTensor(), transforms.Normalize((0.1307,), (0.3081,))])
train_dataset = nni.trace(MNIST)(root='data/mnist', train=True, download=False, transform=transform)
test_dataset = nni.trace(MNIST)(root='data/mnist', train=False, download=False, transform=transform)
@ -127,7 +127,7 @@ def create_evaluator(n_models=None):
lightning = Lightning(
multi_module,
MultiModelTrainer(max_epochs=1, limit_train_batches=0.25, enable_progress_bar=True),
MultiModelTrainer(max_epochs=1, limit_train_batches=0.25, enable_progress_bar=True, accelerator=accelerator),
train_dataloaders=DataLoader(train_dataset, batch_size=100),
val_dataloaders=DataLoader(test_dataset, batch_size=100)
)
@ -191,7 +191,7 @@ def cgo(request):
def test_multi_model_trainer_cpu(trial_command_channel):
evaluator = create_evaluator(n_models=2)
evaluator = create_evaluator(n_models=2, accelerator='cpu')
evaluator.evaluate(_model_cpu())
result = trial_command_channel.final


@ -257,7 +257,7 @@ class TestOperators(unittest.TestCase, ConvertMixin):
x = torch.randn(3, 4, requires_grad=True)
self.checkExportImport(SimpleOp(), (x, ))
@unittest.skip('No longer works for pytorch 2.0')
def test_basic_max(self):
class SimpleOp(nn.Module):
def forward(self, x, y):


@ -223,7 +223,8 @@ def test_hub_oneshot(space_type, strategy_type):
val_dataloaders=valid_loader,
max_epochs=1,
export_onnx=False,
gpus=1 if torch.cuda.is_available() else 0, # 0 for my debug
accelerator='auto',
devices=1,
logger=False, # disable logging and checkpoint to avoid too much log
enable_checkpointing=False,
enable_model_summary=False,
@ -267,7 +268,8 @@ def test_expectation_profiler():
train_dataloaders=train_loader,
val_dataloaders=valid_loader,
max_epochs=1,
gpus=1 if torch.cuda.is_available() else 0,
accelerator='auto',
devices=1,
num_classes=1000
)
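
Lightning 2.0 removed the gpus argument from the Trainer, so device selection now goes through accelerator and devices, as in the two hunks above. A minimal illustration of the mapping (values are illustrative):

from pytorch_lightning import Trainer

# Old (Lightning 1.x): Trainer(gpus=1 if torch.cuda.is_available() else 0)
# New (Lightning 2.x): let Lightning pick the accelerator, request one device.
trainer = Trainer(accelerator='auto', devices=1)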


@ -69,8 +69,8 @@ def test_proxyless_bp_hook_once():
trainer = Trainer(
max_epochs=1,
accelerator='cpu', devices=1, num_nodes=1, strategy='ddp',
replace_sampler_ddp=False,
accelerator='cpu', devices=1, num_nodes=1, strategy='ddp_find_unused_parameters_true',
use_distributed_sampler=False,
)
trainer.fit(DistributedModule(), dataloader)
@ -197,7 +197,7 @@ def test_proxyless_repeat_nested():
memo.update(module.resample(memo))
y = repeat(x).sum()
optimizer.zero_grad()
optimizer.zero_grad(set_to_none=False)
y.backward()
optimizer.step()
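
The explicit set_to_none=False compensates for PyTorch 2.0 flipping the default of Optimizer.zero_grad to set_to_none=True, which leaves .grad as None rather than a zero tensor. A small self-contained check of that behaviour (illustrative):

import torch

param = torch.nn.Parameter(torch.ones(3))
opt = torch.optim.SGD([param], lr=0.1)

param.sum().backward()
opt.zero_grad()                    # torch >= 2.0 default: grads become None
assert param.grad is None

param.sum().backward()
opt.zero_grad(set_to_none=False)   # keep grads as zero tensors instead
assert torch.equal(param.grad, torch.zeros(3))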


@ -18,7 +18,6 @@ from nni.nas.nn.pytorch import LayerChoice, ModelSpace
from nni.nas.oneshot.pytorch import DartsLightningModule
from ut.nas.nn.models import MODELS
from .test_utils import RandomDataset
@ -217,8 +216,9 @@ def test_optimizer_lr_scheduler():
def configure_optimizers(self):
opt1 = torch.optim.SGD(self.net.layer1.parameters(), lr=0.1)
opt2 = torch.optim.Adam(self.net.layer2.parameters(), lr=0.2)
return [opt1, opt2], [torch.optim.lr_scheduler.StepLR(opt1, step_size=2, gamma=0.1)]
# no longer supported in lightning 2.x
# opt2 = torch.optim.Adam(self.net.layer2.parameters(), lr=0.2)
return [opt1], [torch.optim.lr_scheduler.StepLR(opt1, step_size=2, gamma=0.1)]
def training_step(self, batch, batch_idx):
loss = self(batch).sum()
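
The single-optimizer return above reflects that Lightning 2.x no longer supports multiple optimizers under automatic optimization. A hedged sketch of what returning two optimizers now requires (toy layers, illustrative only):

import torch
import torch.nn as nn
from pytorch_lightning import LightningModule

class TwoOptimizerModule(LightningModule):
    def __init__(self):
        super().__init__()
        self.layer1 = nn.Linear(4, 4)
        self.layer2 = nn.Linear(4, 4)
        # Lightning 2.x: more than one optimizer requires manual optimization.
        self.automatic_optimization = False

    def configure_optimizers(self):
        return (
            torch.optim.SGD(self.layer1.parameters(), lr=0.1),
            torch.optim.Adam(self.layer2.parameters(), lr=0.2),
        )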


@ -5,6 +5,7 @@ import pytest
import torch
import pytorch_lightning
from pytorch_lightning import LightningModule, Trainer
from pytorch_lightning.trainer.states import RunningStage
from torch.utils.data import DataLoader, Dataset
pytestmark = pytest.mark.skipif(pytorch_lightning.__version__ < '1.0', reason='Incompatible APIs')
@ -48,7 +49,7 @@ class BoringModel(LightningModule):
def test_concat_loader():
from nni.nas.oneshot.pytorch.dataloader import ConcatLoader
from nni.nas.oneshot.pytorch._dataloader import ConcatLoader
loaders = {
'a': DataLoader(range(10), batch_size=4),
@ -56,40 +57,43 @@ def test_concat_loader():
}
dataloader = ConcatLoader(loaders)
assert len(dataloader) == 7
for i, (data, label) in enumerate(dataloader):
for i, (data, batch_index, loader_index) in enumerate(dataloader):
if i < 3:
assert len(data) <= 4
assert label == 'a'
assert loader_index == 0
assert batch_index == i
else:
assert len(data) <= 5
assert label == 'b'
assert loader_index == 1
assert batch_index == i - 3
def test_concat_loader_nested():
from nni.nas.oneshot.pytorch.dataloader import ConcatLoader
from nni.nas.oneshot.pytorch._dataloader import ConcatLoader
loaders = {
'a': [DataLoader(range(10), batch_size=4), DataLoader(range(20), batch_size=6)],
'b': DataLoader(range(20), batch_size=5),
}
dataloader = ConcatLoader(loaders)
assert len(dataloader) == 7
for i, (data, label) in enumerate(dataloader):
assert len(dataloader) == 11
for i, (data, batch_index, loader_index) in enumerate(dataloader):
if i < 3:
assert isinstance(data, list) and len(data) == 2
assert label == 'a'
assert len(data) in [2, 4]
assert loader_index == 0
assert batch_index == i
else:
assert label == 'b'
assert 1 <= loader_index <= 2
@pytest.mark.parametrize('replace_sampler_ddp', [False, True])
@pytest.mark.parametrize('use_distributed_sampler', [False, True])
@pytest.mark.parametrize('is_min_size_mode', [True])
@pytest.mark.parametrize('num_devices', ['auto', 1, 3, 10])
def test_concat_loader_with_ddp(
replace_sampler_ddp: bool, is_min_size_mode: bool, num_devices: Union[int, str]
use_distributed_sampler: bool, is_min_size_mode: bool, num_devices: Union[int, str]
):
"""Inspired by tests/trainer/test_supporters.py in lightning."""
from nni.nas.oneshot.pytorch.dataloader import ConcatLoader
from nni.nas.oneshot.pytorch._dataloader import ConcatLoader
mode = 'min_size' if is_min_size_mode else 'max_size_cycle'
dim = 3
@ -102,29 +106,29 @@ def test_concat_loader_with_ddp(
'a2': DataLoader(RandomDataset(dim, n2), batch_size=1),
},
'b': DataLoader(RandomDataset(dim, n3), batch_size=1),
}, mode=mode)
expected_length_before_ddp = n3 + (min(n1, n2) if is_min_size_mode else max(n1, n2))
})
print(len(dataloader))
expected_length_before_ddp = n1 + n2 + n3
assert len(dataloader) == expected_length_before_ddp
model = BoringModel()
trainer = Trainer(
strategy='ddp',
accelerator='cpu',
devices=num_devices,
replace_sampler_ddp=replace_sampler_ddp,
use_distributed_sampler=use_distributed_sampler,
)
trainer.strategy.connect(model)
trainer._data_connector.attach_data(
model=model, train_dataloaders=dataloader, val_dataloaders=None, datamodule=None
model=model, train_dataloaders=dataloader
)
expected_length_after_ddp = (
math.ceil(n3 / trainer.num_devices) + \
math.ceil((min(n1, n2) if is_min_size_mode else max(n1, n2)) / trainer.num_devices)
if replace_sampler_ddp
math.ceil(n3 / trainer.num_devices) + math.ceil(n1 / trainer.num_devices) + math.ceil(n2 / trainer.num_devices)
if use_distributed_sampler
else expected_length_before_ddp
)
print('Num devices =', trainer.num_devices)
trainer.reset_train_dataloader(model=model)
trainer.state.fn = "fit"
trainer.state.stage = RunningStage.TRAINING
trainer.fit_loop.setup_data()
assert trainer.train_dataloader is not None
assert trainer.train_dataloader.mode == mode
assert trainer.num_training_batches == expected_length_after_ddp


@ -130,7 +130,7 @@ def test_int_proxy():
def test_error_message(caplog):
class Net(nn.Module):
def forward(self, x):
return torch.stft(x, 4)
return torch.stft(x, 4, return_complex=True)
input = ShapeTensor(torch.randn(10, 8), True)
with pytest.raises(RuntimeError, match='Shape inference failed because no shape inference formula'):