зеркало из https://github.com/microsoft/nni.git
NAS oneshot (stage 1) - Expression and profiler utils (#5366)
This commit is contained in:
Родитель
7882c628f1
Коммит
e94a81edd0
|
@ -0,0 +1,221 @@
|
|||
# Copyright (c) Microsoft Corporation.
|
||||
# Licensed under the MIT license.
|
||||
|
||||
"""Guide the one-shot strategy to sample architecture within a target latency.
|
||||
|
||||
This module converts the profiling results returned by profiler to something
|
||||
that one-shot strategies can understand. For example, a loss or some penalty to the reward.
|
||||
|
||||
This file is experimentally placed in the oneshot package.
|
||||
It might be moved to a more general place in the future.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import logging
|
||||
from typing_extensions import Literal
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
from torch import nn
|
||||
|
||||
from nni.mutable import Sample
|
||||
from nni.nas.profiler import Profiler, ExpressionProfiler
|
||||
|
||||
from .supermodule._expression_utils import expression_expectation
|
||||
|
||||
# Module-level logger, named after this module.
_logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class ProfilerFilter:
    """Decide whether a sample is acceptable, based on a profiler.

    This base class only stores the profiler. Concrete filters override
    :meth:`filter` and return ``True`` for a valid sample, ``False`` otherwise.

    Instances are callable: calling one forwards the sample to :meth:`filter`.
    """

    def __init__(self, profiler: Profiler):
        self.profiler = profiler

    def __call__(self, sample: Sample) -> bool:
        verdict = self.filter(sample)
        return verdict

    def filter(self, sample: Sample) -> bool:
        # Must be provided by subclasses.
        raise NotImplementedError()
|
||||
|
||||
|
||||
class RangeProfilerFilter(ProfilerFilter):
    """Give up the sample if the result of the profiler is out of range.

    ``min`` and ``max`` can't be both None.

    Parameters
    ----------
    profiler
        The profiler which is used to profile the sample.
    min
        The lower bound of the profiler result. None means no minimum.
    max
        The upper bound of the profiler result. None means no maximum.

    Raises
    ------
    ValueError
        If both ``min`` and ``max`` are None.
    """

    def __init__(self, profiler: Profiler, min: float | None = None, max: float | None = None):
        super().__init__(profiler)
        self.min_value = min
        self.max_value = max
        if self.min_value is None and self.max_value is None:
            raise ValueError('min and max can\'t be both None')

    def filter(self, sample: Sample) -> bool:
        """Return True if the profiled value of ``sample`` lies within the configured bounds.

        Both bounds are inclusive: only values strictly below ``min`` or strictly
        above ``max`` are rejected. Rejections are logged at debug level.
        """
        value = self.profiler.profile(sample)
        if self.min_value is not None and value < self.min_value:
            _logger.debug('Profiler returns %f (smaller than %f) for sample: %s', value, self.min_value, sample)
            return False
        if self.max_value is not None and value > self.max_value:
            _logger.debug('Profiler returns %f (larger than %f) for sample: %s', value, self.max_value, sample)
            return False
        return True
|
||||
|
||||
|
||||
class ProfilerPenalty(nn.Module):
    r"""
    Give the loss a penalty with the result on the profiler.

    Latency losses in `TuNAS <https://arxiv.org/pdf/2008.06120.pdf>`__ and `ProxylessNAS <https://arxiv.org/pdf/1812.00332.pdf>`__
    are its special cases.

    The computation formula is divided into two steps,
    where we first compute a ``normalized_penalty``, whose zero point is when the penalty meets the baseline,
    and then we aggregate it with the original loss.

    .. math::

        \begin{aligned}
            \text{normalized_penalty} ={} & \text{nonlinear}(\frac{\text{penalty}}{\text{baseline}} - 1) \\
            \text{loss} ={} & \text{aggregate}(\text{original_loss}, \text{normalized_penalty})
        \end{aligned}

    where ``penalty`` here is the result returned by the profiler.

    For example, when ``nonlinear`` is ``positive`` and ``aggregate`` is ``add``, the computation formula is:

    .. math::

        \text{loss} = \text{original_loss} + \text{scale} * \max(\frac{\text{penalty}}{\text{baseline}} - 1, 0)

    Parameters
    ----------
    profiler
        The profiler which is used to profile the sample.
    baseline
        The baseline of the penalty.
    scale
        The scale of the penalty.
    nonlinear
        The nonlinear function to apply to :math:`\frac{\text{penalty}}{\text{baseline}} - 1`.
        The result is called ``normalized_penalty``.
        If ``linear``, then keep the original value.
        If ``positive``, then apply the function :math:`max(0, \cdot)`.
        If ``negative``, then apply the function :math:`min(0, \cdot)`.
        If ``absolute``, then apply the function :math:`abs(\cdot)`.
    aggregate
        The aggregate function to merge the original loss with the penalty.
        If ``add``, then the final loss is :math:`\text{original_loss} + \text{scale} * \text{normalized_penalty}`.
        If ``mul``, then the final loss is :math:`\text{original_loss} * (1 + \text{normalized_penalty})^{\text{scale}}`.
    """

    def __init__(self,
                 profiler: Profiler,
                 baseline: float,
                 scale: float = 1.,
                 *,
                 nonlinear: Literal['linear', 'positive', 'negative', 'absolute'] = 'linear',
                 aggregate: Literal['add', 'mul'] = 'add'):
        super().__init__()
        self.profiler = profiler
        self.scale = scale
        self.baseline = baseline
        self.nonlinear = nonlinear
        self.aggregate = aggregate

    def forward(self, loss: torch.Tensor, sample: Sample) -> tuple[torch.Tensor, dict]:
        """Return the penalized loss together with a dict of intermediate values.

        The dict has the keys ``loss_original``, ``penalty``,
        ``normalized_penalty`` and ``loss_final``.
        """
        profiler_result = self.profile(sample)
        # Shift by one so that the penalty is zero exactly at the baseline.
        normalized_penalty = self.nonlinear_fn(profiler_result / self.baseline - 1)
        loss_new = self.aggregate_fn(loss, normalized_penalty)

        details = {
            'loss_original': loss,
            'penalty': profiler_result,
            'normalized_penalty': normalized_penalty,
            'loss_final': loss_new,
        }

        return loss_new, details

    def profile(self, sample: Sample) -> float:
        """Subclass overrides this to profile the sample."""
        raise NotImplementedError()

    def aggregate_fn(self, loss: torch.Tensor, normalized_penalty: float) -> torch.Tensor:
        """Merge ``loss`` with ``normalized_penalty`` according to ``self.aggregate``.

        Raises ``ValueError`` if ``self.aggregate`` is neither ``add`` nor ``mul``.
        """
        if self.aggregate == 'add':
            return loss + self.scale * normalized_penalty
        if self.aggregate == 'mul':
            return loss * _pow(normalized_penalty + 1, self.scale)
        raise ValueError(f'Invalid aggregate: {self.aggregate}')

    def nonlinear_fn(self, normalized_penalty: float) -> float:
        """Apply the nonlinearity selected by ``self.nonlinear``.

        Raises ``ValueError`` if ``self.nonlinear`` is not one of the supported names.
        """
        if self.nonlinear == 'linear':
            return normalized_penalty
        if self.nonlinear == 'positive':
            return _relu(normalized_penalty)
        if self.nonlinear == 'negative':
            # min(0, x) expressed through relu: -max(0, -x).
            return -_relu(-normalized_penalty)
        if self.nonlinear == 'absolute':
            return _abs(normalized_penalty)
        raise ValueError(f'Invalid nonlinear: {self.nonlinear}')
|
||||
|
||||
|
||||
class ExpectationProfilerPenalty(ProfilerPenalty):
    """Penalty computed from the expectation of the profiler's expression over a distribution of samples."""

    def profile(self, sample: Sample) -> float:
        """Profile based on a distribution of samples.

        Each value in the sample must be a dict representing a categorical distribution.

        Raises
        ------
        TypeError
            If the profiler is not an ``ExpressionProfiler``,
            or if any value in ``sample`` is not a dict.
        """
        if not isinstance(self.profiler, ExpressionProfiler):
            # The message names this class; it previously named a class that does not exist.
            raise TypeError('ExpectationProfilerPenalty only supports ExpressionProfiler.')
        for key, value in sample.items():
            if not isinstance(value, dict):
                raise TypeError('Each value must be a dict representation a categorical distribution, '
                                f'but found {type(value)} for key {key}: {value}')
        return expression_expectation(self.profiler.expression, sample)
|
||||
|
||||
|
||||
class SampleProfilerPenalty(ProfilerPenalty):
    """Penalty that uses the profiler's result on one concrete sample."""

    def profile(self, sample: Sample) -> float:
        """Profile based on a single sample."""
        # Delegate directly to the wrapped profiler.
        measured = self.profiler.profile(sample)
        return measured
|
||||
|
||||
|
||||
# Operators that work for both simple numbers and tensors
|
||||
|
||||
def _pow(x: float, y: float) -> float:
|
||||
if isinstance(x, torch.Tensor) or isinstance(y, torch.Tensor):
|
||||
return torch.pow(x, y)
|
||||
else:
|
||||
return np.power(x, y)
|
||||
|
||||
def _abs(x: float) -> float:
|
||||
if isinstance(x, torch.Tensor):
|
||||
return torch.abs(x)
|
||||
else:
|
||||
return np.abs(x)
|
||||
|
||||
def _relu(x: float) -> float:
|
||||
if isinstance(x, torch.Tensor):
|
||||
return nn.functional.relu(x)
|
||||
else:
|
||||
return np.maximum(x, 0)
|
|
@ -0,0 +1,222 @@
|
|||
# Copyright (c) Microsoft Corporation.
|
||||
# Licensed under the MIT license.
|
||||
|
||||
"""Utilities to process the value choice compositions,
|
||||
in the way that is most convenient to one-shot algorithms."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import itertools
|
||||
import operator
|
||||
from typing import Any, TypeVar, List, cast, Mapping, Sequence, Optional, Iterable
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
|
||||
from nni.mutable import MutableExpression, Categorical
|
||||
|
||||
|
||||
# Type alias for a choice value; it is a plain alias of ``Any``.
# NOTE(review): not referenced in the visible part of this module — confirm it is still needed.
Choice = Any
|
||||
|
||||
# Generic type variable used in this module's annotations and casts
# (e.g. the result type of an expression, or the item type of ``weighted_sum``).
T = TypeVar('T')
|
||||
|
||||
# Public API of this module; names starting with an underscore are intentionally omitted.
__all__ = [
    'expression_expectation',
    'traverse_all_options',
    'weighted_sum',
    'evaluate_constant',
]
|
||||
|
||||
|
||||
def expression_expectation(mutable_expr: MutableExpression[T] | Any, weights: dict[str, list[float]]) -> float:
    """Compute the expected value of a mutable expression.

    Parameters
    ----------
    mutable_expr
        The expression whose expectation is computed.
        Anything that is not a ``MutableExpression`` is returned unchanged.
    weights
        The weights of each leaf node, keyed by label.

    Returns
    -------
    float
        The expectation.
    """
    if not isinstance(mutable_expr, MutableExpression):
        return mutable_expr

    func = getattr(mutable_expr, 'function', None)

    # Linearity of expectation lets sums and differences be handled term by term,
    # which avoids enumerating the joint outcomes of all operands.
    if func == operator.add:
        return sum(expression_expectation(arg, weights) for arg in mutable_expr.arguments)
    if func == operator.sub:
        minuend, subtrahend = mutable_expr.arguments[0], mutable_expr.arguments[1]
        return expression_expectation(minuend, weights) - expression_expectation(subtrahend, weights)

    # General case: enumerate every (outcome, probability) pair and average.
    outcomes, probabilities = zip(*traverse_all_options(mutable_expr, weights))
    return weighted_sum(outcomes, probabilities)
|
||||
|
||||
|
||||
def traverse_all_options(
    mutable_expr: MutableExpression[T],
    weights: dict[str, dict[Any, float]] | dict[str, list[float]] | dict[str, np.ndarray] | dict[str, torch.Tensor] | None = None
) -> list[tuple[T, float]] | list[T]:
    """Traverse all possible computation outcome of a value choice.
    If ``weights`` is not None, it will also compute the probability of each possible outcome.

    NOTE: This function is very similar to ``MutableExpression.grid``,
    but it supports specifying weights for each leaf node.

    Parameters
    ----------
    mutable_expr
        The value choice to traverse.
    weights
        If there's a prior on leaf nodes, and we intend to know the (joint) prior on results,
        weights can be provided. The key is label. The value is either a dict mapping each
        choice to its probability, or a list-like of probabilities in the order of the choices.
        Normally, they should sum up to 1, but we will not check them in this function.

    Returns
    -------
    Results will be sorted and duplicates will be eliminated.
    If weights is provided, the return value will be a list of tuple, with option and its weight.
    Otherwise, it will be a list of options.

    Raises
    ------
    TypeError
        If a leaf of the expression is not a ``Categorical``.
    KeyError
        If ``weights`` misses a label, or the number of weights does not match the number of choices.
    RuntimeError
        If a value produced during traversal is not among the values of its ``Categorical``.
    """

    # Validation
    simplified = mutable_expr.simplify()
    for label, param in simplified.items():
        if not isinstance(param, Categorical):
            raise TypeError(f'{param!r} is not a categorical distribution')
        if weights is not None:
            if label not in weights:
                raise KeyError(f'{mutable_expr} depends on a weight with key {label}, but not found in {weights}')
            if len(param) != len(weights[label]):
                raise KeyError(f'Expect weights with {label} to be of length {len(param)}, but {len(weights[label])} found')

    # result is a dict from a option to its weight
    result: dict[T, float] = {}

    # ``sample`` is passed as ``memo`` and read back below on every iteration;
    # presumably ``grid`` fills it with the current choice per label — confirm against ``grid``.
    sample = {}
    for sample_res in mutable_expr.grid(memo=sample):
        probability = 1.
        if weights is not None:
            for label, chosen in sample.items():
                label_weights = weights[label]
                if isinstance(label_weights, dict):
                    # weights[label] is a dict. Choices are used as keys.
                    probability = probability * label_weights[chosen]
                else:
                    # weights[label] is a list. We need to find the index of currently chosen value.
                    # ``index`` raises ValueError on a miss (it never returns -1),
                    # so the miss is translated into the intended RuntimeError here.
                    try:
                        chosen_idx = cast(Categorical, simplified[label]).values.index(chosen)
                    except ValueError:
                        raise RuntimeError(f'{chosen} is not a valid value for {label}: {simplified[label]!r}') from None
                    probability = probability * label_weights[chosen_idx]

        # Different samples may evaluate to the same outcome; their probabilities add up.
        if sample_res in result:
            result[sample_res] = result[sample_res] + cast(float, probability)
        else:
            result[sample_res] = cast(float, probability)

    if weights is None:
        return sorted(result.keys())  # type: ignore
    else:
        return sorted(result.items())  # type: ignore
|
||||
|
||||
|
||||
def evaluate_constant(expr: Any) -> Any:
    """Evaluate a value choice expression to a constant. Raise ValueError if it's not a constant."""
    all_options = traverse_all_options(expr)
    # Exactly one distinct outcome means the expression is constant.
    if len(all_options) <= 1:
        return all_options[0]
    raise ValueError(f'{expr} is not evaluated to a constant. All possible values are: {all_options}')
|
||||
|
||||
|
||||
def weighted_sum(items: list[T], weights: Sequence[float | None] | None = None) -> T:
    """Return a weighted sum of items.

    Items can be list of tensors, numpy arrays, numbers, or nested lists / dicts of those.
    Nested containers are summed element-wise (per key for mappings, per position for sequences).

    If ``weights`` is None, this is simply an unweighted sum.

    Parameters
    ----------
    items
        The non-empty list of items to sum. All items must share the same structure.
    weights
        One weight per item. A ``None`` weight means the item is added as is.

    Raises
    ------
    TypeError
        If the first item is a string or of any other unsupported type.
    ValueError
        If the items disagree in type, keys, length or shape.
        The message lists a summary of every item.
    """

    if weights is None:
        weights = [None] * len(items)

    assert len(items) == len(weights) > 0
    elem = items[0]
    unsupported_msg = 'Unsupported element type in weighted sum: {}. Value is: {}'

    if isinstance(elem, str):
        # Need to check this first. Otherwise it goes into sequence and causes infinite recursion.
        raise TypeError(unsupported_msg.format(type(elem), elem))

    try:
        if isinstance(elem, (torch.Tensor, np.ndarray, float, int, np.number)):
            res = elem if weights[0] is None else elem * weights[0]
            for it, weight in zip(items[1:], weights[1:]):
                if type(it) != type(elem):
                    raise TypeError(f'Expect type {type(elem)} but found {type(it)}. Can not be summed')
                res = res + (it if weight is None else it * weight)  # type: ignore
            return cast('T', res)

        if isinstance(elem, Mapping):
            for item in items:
                if not isinstance(item, Mapping):
                    raise TypeError(f'Expect type {type(elem)} but found {type(item)}')
                if set(item) != set(elem):
                    raise KeyError(f'Expect keys {list(elem)} but found {list(item)}')
            # Sum the values found under each key across all items.
            return cast('T', {
                key: weighted_sum([cast(Mapping, d)[key] for d in items], weights) for key in elem
            })

        if isinstance(elem, Sequence):
            for item in items:
                if not isinstance(item, Sequence):
                    raise TypeError(f'Expect type {type(elem)} but found {type(item)}')
                if len(item) != len(elem):
                    # The first item defines the expected length.
                    raise ValueError(f'Expect length {len(elem)} but found {len(item)}')
            # Transpose so that each column holds the i-th entry of every item.
            transposed = cast(Iterable[list], zip(*items))  # type: ignore
            return cast('T', [weighted_sum(column, weights) for column in transposed])
    except (TypeError, ValueError, RuntimeError, KeyError) as err:
        raise ValueError(
            'Error when summing items. Value format / shape does not match. See full traceback for details.' +
            ''.join([
                f'\n  {idx}: {_summarize_elem_format(it)}' for idx, it in enumerate(items)
            ])
        ) from err

    # Dealing with all unexpected types.
    raise TypeError(unsupported_msg.format(type(elem), elem))
|
||||
|
||||
|
||||
def _summarize_elem_format(elem: Any) -> Any:
|
||||
# Get a summary of one elem
|
||||
# Helps generate human-readable error messages
|
||||
|
||||
class _repr_object:
|
||||
# empty object is only repr
|
||||
def __init__(self, representation):
|
||||
self.representation = representation
|
||||
|
||||
def __repr__(self):
|
||||
return self.representation
|
||||
|
||||
if isinstance(elem, torch.Tensor):
|
||||
return _repr_object('torch.Tensor(' + ', '.join(map(str, elem.shape)) + ')')
|
||||
if isinstance(elem, np.ndarray):
|
||||
return _repr_object('np.array(' + ', '.join(map(str, elem.shape)) + ')')
|
||||
if isinstance(elem, Mapping):
|
||||
return {key: _summarize_elem_format(value) for key, value in elem.items()}
|
||||
if isinstance(elem, Sequence):
|
||||
return [_summarize_elem_format(value) for value in elem]
|
||||
|
||||
# fallback to original, for cases like float, int, ...
|
||||
return elem
|
|
@ -0,0 +1,94 @@
|
|||
import pytest
|
||||
|
||||
import torch
|
||||
from torch.nn import Conv2d
|
||||
|
||||
from nni.nas.nn import ModelSpace, LayerChoice
|
||||
from nni.nas.oneshot.pytorch.profiler import RangeProfilerFilter, ExpectationProfilerPenalty, SampleProfilerPenalty
|
||||
from nni.nas.profiler.pytorch.flops import FlopsProfiler
|
||||
|
||||
|
||||
class Net(ModelSpace):
    """Tiny model space: one layer choice among three convolutions with kernel size 3, 5 and 7."""

    def __init__(self):
        super().__init__()
        # Padding of ``kernel // 2`` gives 1, 2 and 3 for the three kernel sizes.
        candidates = [
            Conv2d(1, 10, kernel, padding=kernel // 2, bias=False)
            for kernel in (3, 5, 7)
        ]
        self.conv = LayerChoice(candidates, label='conv')

    def forward(self, x):
        out = self.conv(x)
        return out
|
||||
|
||||
|
||||
@pytest.fixture
def profiler():
    """FLOPs profiler over ``Net``, profiled with a random 1x1x10x10 input."""
    net = Net()
    dummy_input = torch.randn(1, 1, 10, 10)
    # (9k, 25k, 49k)
    return FlopsProfiler(net, dummy_input)
|
||||
|
||||
|
||||
def test_range_filter(profiler):
    """Check lower-bound, upper-bound and two-sided filtering, plus the no-bound error."""
    # Each scenario: filter bounds, then (choice index, expected acceptance) pairs.
    scenarios = [
        ({'min': 10000}, [(1, True), (2, True), (0, False)]),
        ({'max': 30000}, [(0, True), (1, True), (2, False)]),
        ({'min': 10000, 'max': 30000}, [(0, False), (1, True), (2, False)]),
    ]
    for bounds, expectations in scenarios:
        range_filter = RangeProfilerFilter(profiler, **bounds)
        for choice, accepted in expectations:
            assert bool(range_filter({'conv': choice})) == accepted

    with pytest.raises(ValueError, match='both None'):
        RangeProfilerFilter(profiler, min=None, max=None)
|
||||
|
||||
|
||||
def test_expectation_penalty(profiler):
    """Check the expectation-based penalty for every nonlinear / aggregate combination used here.

    Values produced by floating-point arithmetic are compared with ``pytest.approx``
    because exact equality on them depends on rounding. Values that are exactly zero
    by construction (a clipped nonlinearity) are still compared exactly.
    """
    penalty = ExpectationProfilerPenalty(profiler, 20000, 2.)
    loss, details = penalty(42., {'conv': {0: 0.2, 1: 0.3, 2: 0.5}})
    assert details['loss_original'] == 42.
    assert details['penalty'] == pytest.approx(33800)
    assert details['normalized_penalty'] == pytest.approx(0.69)
    assert loss == pytest.approx(42. + 0.69 * 2.)

    # Same distribution, but given as tensors.
    prob = torch.tensor([0.2, 0.3, 0.5])
    loss, details = penalty(torch.tensor(42.), {'conv': {i: prob[i] for i in range(3)}})
    assert abs(loss.item() - (42. + 0.69 * 2.)) < 1e-4

    penalty = ExpectationProfilerPenalty(profiler, 40000, -2., nonlinear='negative')
    loss, details = penalty(42., {'conv': {0: 0.2, 1: 0.3, 2: 0.5}})
    assert details['normalized_penalty'] == pytest.approx(33800 / 40000 - 1)
    assert loss == pytest.approx(42. + 0.155 * 2.)

    # Above the baseline, the ``negative`` nonlinearity clips the penalty to zero.
    loss, details = penalty(42., {'conv': {0: 0., 1: 0., 2: 1.}})
    assert details['normalized_penalty'] == 0
    assert loss == 42.

    penalty = ExpectationProfilerPenalty(profiler, 40000, 2., nonlinear='absolute', aggregate='mul')
    loss, details = penalty(42., {'conv': {0: 0.2, 1: 0.3, 2: 0.5}})
    assert details['normalized_penalty'] == pytest.approx(abs(33800 / 40000 - 1))
    assert loss == pytest.approx(42. * (1 + 0.155) ** 2)

    penalty = ExpectationProfilerPenalty(profiler, 30000, 2., nonlinear='positive', aggregate='mul')
    loss, details = penalty(42., {'conv': {0: 0.2, 1: 0.3, 2: 0.5}})
    assert details['normalized_penalty'] == pytest.approx(33800 / 30000 - 1)
    assert loss == pytest.approx(42. * (33800 / 30000) ** 2)

    # Below the baseline, the ``positive`` nonlinearity clips the penalty to zero.
    loss, details = penalty(42., {'conv': {0: 1., 1: 0., 2: 0.}})
    assert details['normalized_penalty'] == 0
    assert loss == 42.
|
||||
|
||||
|
||||
def test_sample_penalty(profiler):
    """Check the single-sample penalty with the default ``linear`` / ``add`` settings."""
    penalty = SampleProfilerPenalty(profiler, 20000, 2.)
    loss, details = penalty(42., {'conv': 1})

    expected_details = {
        'loss_original': 42.,
        'penalty': 25000,
        'normalized_penalty': 0.25,
    }
    for key, expected in expected_details.items():
        assert details[key] == expected
    assert loss == 42. + 0.25 * 2.
|
|
@ -0,0 +1,130 @@
|
|||
import math
|
||||
from typing import Union
|
||||
|
||||
import pytest
|
||||
import torch
|
||||
import pytorch_lightning
|
||||
from pytorch_lightning import LightningModule, Trainer
|
||||
from torch.utils.data import DataLoader, Dataset
|
||||
|
||||
# Skip every test in this module on pytorch-lightning versions older than 1.0.
# NOTE(review): this compares version *strings*, so the ordering is lexicographic
# rather than numeric — confirm this is acceptable for the versions supported.
pytestmark = pytest.mark.skipif(pytorch_lightning.__version__ < '1.0', reason='Incompatible APIs')
|
||||
|
||||
|
||||
class RandomDataset(Dataset):
    """Dataset of ``length`` random vectors, each with ``size`` entries."""

    def __init__(self, size, length):
        self.len = length
        # One row per sample.
        self.data = torch.randn(length, size)

    def __len__(self):
        return self.len

    def __getitem__(self, index):
        row = self.data[index]
        return row
|
||||
|
||||
|
||||
class BoringModel(LightningModule):
    """Minimal lightning module: a single linear layer whose summed output is the loss."""

    def __init__(self):
        super().__init__()
        self.layer = torch.nn.Linear(32, 2)

    def forward(self, x):
        return self.layer(x)

    def _log_summed_output(self, batch, name):
        # Shared by all step methods: sum the model output and log it under ``name``.
        loss = self(batch).sum()
        self.log(name, loss)
        return loss

    def training_step(self, batch, batch_idx):
        return {'loss': self._log_summed_output(batch, 'train_loss')}

    def validation_step(self, batch, batch_idx):
        self._log_summed_output(batch, 'valid_loss')

    def test_step(self, batch, batch_idx):
        self._log_summed_output(batch, 'test_loss')

    def configure_optimizers(self):
        return torch.optim.SGD(self.layer.parameters(), lr=0.1)
|
||||
|
||||
|
||||
def test_concat_loader():
    """Batches of loader ``a`` come first (3 of them), followed by the batches of loader ``b``."""
    from nni.nas.oneshot.pytorch.dataloader import ConcatLoader

    dataloader = ConcatLoader({
        'a': DataLoader(range(10), batch_size=4),
        'b': DataLoader(range(20), batch_size=5),
    })
    assert len(dataloader) == 7
    for i, (data, label) in enumerate(dataloader):
        # The first three batches belong to 'a' (batch size 4), the rest to 'b' (batch size 5).
        max_batch, expected_label = (4, 'a') if i < 3 else (5, 'b')
        assert len(data) <= max_batch
        assert label == expected_label
|
||||
|
||||
|
||||
def test_concat_loader_nested():
    """A list of loaders under one key yields list-valued batches for that key."""
    from nni.nas.oneshot.pytorch.dataloader import ConcatLoader

    nested_a = [DataLoader(range(10), batch_size=4), DataLoader(range(20), batch_size=6)]
    dataloader = ConcatLoader({
        'a': nested_a,
        'b': DataLoader(range(20), batch_size=5),
    })
    assert len(dataloader) == 7
    for i, (data, label) in enumerate(dataloader):
        if i >= 3:
            assert label == 'b'
            continue
        # Batches of 'a' carry one entry per nested loader.
        assert isinstance(data, list) and len(data) == 2
        assert label == 'a'
|
||||
|
||||
|
||||
@pytest.mark.parametrize('replace_sampler_ddp', [False, True])
@pytest.mark.parametrize('is_min_size_mode', [True])
@pytest.mark.parametrize('num_devices', ['auto', 1, 3, 10])
def test_concat_loader_with_ddp(
    replace_sampler_ddp: bool, is_min_size_mode: bool, num_devices: Union[int, str]
):
    """Inspired by tests/trainer/test_supporters.py in lightning."""
    from nni.nas.oneshot.pytorch.dataloader import ConcatLoader

    # The mode decides how the two loaders nested under 'a' are combined.
    if is_min_size_mode:
        mode, combine = 'min_size', min
    else:
        mode, combine = 'max_size_cycle', max

    dim = 3
    n1, n2, n3 = 8, 6, 9
    group_a = {
        'a1': DataLoader(RandomDataset(dim, n1), batch_size=1),
        'a2': DataLoader(RandomDataset(dim, n2), batch_size=1),
    }
    dataloader = ConcatLoader({
        'a': group_a,
        'b': DataLoader(RandomDataset(dim, n3), batch_size=1),
    }, mode=mode)

    length_a = combine(n1, n2)
    expected_length_before_ddp = n3 + length_a
    print(len(dataloader))
    assert len(dataloader) == expected_length_before_ddp

    model = BoringModel()
    trainer = Trainer(
        strategy='ddp',
        accelerator='cpu',
        devices=num_devices,
        replace_sampler_ddp=replace_sampler_ddp,
    )
    trainer._data_connector.attach_data(
        model=model, train_dataloaders=dataloader, val_dataloaders=None, datamodule=None
    )

    # With sampler replacement the test expects each part to be split across devices, rounding up.
    if replace_sampler_ddp:
        expected_length_after_ddp = (
            math.ceil(n3 / trainer.num_devices) + math.ceil(length_a / trainer.num_devices)
        )
    else:
        expected_length_after_ddp = expected_length_before_ddp

    print('Num devices =', trainer.num_devices)
    trainer.reset_train_dataloader(model=model)
    assert trainer.train_dataloader is not None
    assert trainer.train_dataloader.mode == mode

    assert trainer.num_training_batches == expected_length_after_ddp
|
Загрузка…
Ссылка в новой задаче