NAS oneshot (stage 1) - Expression and profiler utils (#5366)

This commit is contained in:
Yuge Zhang 2023-03-02 11:05:04 +08:00 коммит произвёл GitHub
Родитель 7882c628f1
Коммит e94a81edd0
Не найден ключ, соответствующий данной подписи
Идентификатор ключа GPG: 4AEE18F83AFDEB23
4 изменённых файлов: 667 добавлений и 0 удалений

Просмотреть файл

@ -0,0 +1,221 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.
"""Guide the one-shot strategy to sample architecture within a target latency.
This module converts the profiling results returned by profiler to something
that one-shot strategies can understand. For example, a loss or some penalty to the reward.
This file is experimentally placed in the oneshot package.
It might be moved to a more general place in the future.
"""
from __future__ import annotations
import logging
from typing_extensions import Literal
import numpy as np
import torch
from torch import nn
from nni.mutable import Sample
from nni.nas.profiler import Profiler, ExpressionProfiler
from .supermodule._expression_utils import expression_expectation
# Module-level logger, named after this module.
_logger = logging.getLogger(__name__)
class ProfilerFilter:
    """Base class of filters that accept or reject a sample using a profiler.

    Concrete filters override :meth:`filter` and return ``True`` when the
    sample is valid and ``False`` otherwise.
    Calling the instance is equivalent to calling :meth:`filter`.

    Parameters
    ----------
    profiler
        The profiler, stored as ``self.profiler`` for subclasses to use.
    """

    def __init__(self, profiler: Profiler):
        self.profiler = profiler

    def __call__(self, sample: Sample) -> bool:
        # Plain delegation, so that a filter instance can be used as a callable.
        verdict = self.filter(sample)
        return verdict

    def filter(self, sample: Sample) -> bool:
        """Tell whether ``sample`` is valid. Subclasses must override this."""
        raise NotImplementedError()
class RangeProfilerFilter(ProfilerFilter):
    """Give up the sample if the result of the profiler is out of range.

    Both bounds are inclusive: a result equal to ``min`` or ``max`` is accepted.
    ``min`` and ``max`` can't be both None.

    Parameters
    ----------
    profiler
        The profiler which is used to profile the sample.
    min
        The lower bound of the profiler result. None means no minimum.
    max
        The upper bound of the profiler result. None means no maximum.

    Raises
    ------
    ValueError
        If both ``min`` and ``max`` are None.
    """

    def __init__(self, profiler: Profiler, min: float | None = None, max: float | None = None):
        super().__init__(profiler)
        self.min_value = min
        self.max_value = max
        if self.min_value is None and self.max_value is None:
            raise ValueError('min and max can\'t be both None')

    def filter(self, sample: Sample) -> bool:
        """Profile ``sample`` and return whether the result lies within the configured range."""
        value = self.profiler.profile(sample)
        if self.min_value is not None and value < self.min_value:
            _logger.debug('Profiler returns %f (smaller than %f) for sample: %s', value, self.min_value, sample)
            return False
        if self.max_value is not None and value > self.max_value:
            _logger.debug('Profiler returns %f (larger than %f) for sample: %s', value, self.max_value, sample)
            return False
        return True
class ProfilerPenalty(nn.Module):
    r"""
    Give the loss a penalty with the result on the profiler.

    Latency losses in `TuNAS <https://arxiv.org/pdf/2008.06120.pdf>`__ and `ProxylessNAS <https://arxiv.org/pdf/1812.00332.pdf>`__
    are its special cases.

    The computation is done in two steps.
    First, a ``normalized_penalty`` is computed; it is zero when the penalty equals the baseline.
    Then, it is aggregated with the original loss.

    .. math::

        \begin{aligned}
            \text{normalized_penalty} ={} & \text{nonlinear}(\frac{\text{penalty}}{\text{baseline}} - 1) \\
            \text{loss} ={} & \text{aggregate}(\text{original_loss}, \text{normalized_penalty})
        \end{aligned}

    where ``penalty`` here is the result returned by :meth:`profile`.

    For example, when ``nonlinear`` is ``positive`` and ``aggregate`` is ``add``, the computation formula is:

    .. math::

        \text{loss} = \text{original_loss} + \text{scale} * max(\frac{\text{penalty}}{\text{baseline}} - 1, 0)

    Parameters
    ----------
    profiler
        The profiler which is used to profile the sample.
    scale
        The scale of the penalty.
    baseline
        The baseline of the penalty.
    nonlinear
        The nonlinear function to apply to :math:`\frac{\text{penalty}}{\text{baseline}} - 1`.
        The result is called ``normalized_penalty``.
        If ``linear``, then keep the original value.
        If ``positive``, then apply the function :math:`max(0, \cdot)`.
        If ``negative``, then apply the function :math:`min(0, \cdot)`.
        If ``absolute``, then apply the function :math:`abs(\cdot)`.
    aggregate
        The aggregate function to merge the original loss with the penalty.
        If ``add``, then the final loss is :math:`\text{original_loss} + \text{scale} * \text{normalized_penalty}`.
        If ``mul``, then the final loss is :math:`\text{original_loss} * (1 + \text{normalized_penalty})^{\text{scale}}`.
    """

    def __init__(self,
                 profiler: Profiler,
                 baseline: float,
                 scale: float = 1.,
                 *,
                 nonlinear: Literal['linear', 'positive', 'negative', 'absolute'] = 'linear',
                 aggregate: Literal['add', 'mul'] = 'add'):
        super().__init__()
        self.profiler = profiler
        self.scale = scale
        self.baseline = baseline
        self.nonlinear = nonlinear
        self.aggregate = aggregate

    def forward(self, loss: torch.Tensor, sample: Sample) -> tuple[torch.Tensor, dict]:
        """Return the penalized loss together with a dict of the intermediate values."""
        raw_penalty = self.profile(sample)
        # Zero exactly when the profiler result equals the baseline.
        relative = raw_penalty / self.baseline - 1
        shaped = self.nonlinear_fn(relative)
        penalized = self.aggregate_fn(loss, shaped)
        return penalized, {
            'loss_original': loss,
            'penalty': raw_penalty,
            'normalized_penalty': shaped,
            'loss_final': penalized,
        }

    def profile(self, sample: Sample) -> float:
        """Subclass overrides this to profile the sample."""
        raise NotImplementedError()

    def aggregate_fn(self, loss: torch.Tensor, normalized_penalty: float) -> torch.Tensor:
        """Merge ``normalized_penalty`` into ``loss`` according to ``self.aggregate``."""
        mode = self.aggregate
        if mode not in ('add', 'mul'):
            raise ValueError(f'Invalid aggregate: {self.aggregate}')
        if mode == 'mul':
            return loss * _pow(normalized_penalty + 1, self.scale)
        return loss + self.scale * normalized_penalty

    def nonlinear_fn(self, normalized_penalty: float) -> float:
        """Apply the nonlinearity selected by ``self.nonlinear``."""
        kind = self.nonlinear
        if kind == 'absolute':
            return _abs(normalized_penalty)
        if kind == 'negative':
            # min(0, x) written as -max(0, -x).
            return -_relu(-normalized_penalty)
        if kind == 'positive':
            return _relu(normalized_penalty)
        if kind == 'linear':
            return normalized_penalty
        raise ValueError(f'Invalid nonlinear: {self.nonlinear}')
class ExpectationProfilerPenalty(ProfilerPenalty):
    """Penalty whose profiler result is the expectation over a distribution of samples."""

    def profile(self, sample: Sample) -> float:
        """Profile based on a distribution of samples.

        Each value in the sample must be a dict representing a categorical distribution.

        Raises
        ------
        TypeError
            If the profiler is not an ``ExpressionProfiler``,
            or if any value in ``sample`` is not a dict.
        """
        if not isinstance(self.profiler, ExpressionProfiler):
            # The message names this class (it used to mention a class that does not exist).
            raise TypeError('ExpectationProfilerPenalty only supports ExpressionProfiler.')
        for key, value in sample.items():
            if not isinstance(value, dict):
                raise TypeError('Each value must be a dict representation a categorical distribution, '
                                f'but found {type(value)} for key {key}: {value}')
        return expression_expectation(self.profiler.expression, sample)
class SampleProfilerPenalty(ProfilerPenalty):
    """Penalty whose profiler result comes from one concrete sample."""

    def profile(self, sample: Sample) -> float:
        """Return what the wrapped profiler reports for ``sample``."""
        measured = self.profiler.profile(sample)
        return measured
# Operators that work for both simple numbers and tensors
def _pow(x: float, y: float) -> float:
if isinstance(x, torch.Tensor) or isinstance(y, torch.Tensor):
return torch.pow(x, y)
else:
return np.power(x, y)
def _abs(x: float) -> float:
if isinstance(x, torch.Tensor):
return torch.abs(x)
else:
return np.abs(x)
def _relu(x: float) -> float:
if isinstance(x, torch.Tensor):
return nn.functional.relu(x)
else:
return np.maximum(x, 0)

Просмотреть файл

@ -0,0 +1,222 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.
"""Utilities to process the value choice compositions,
in the way that is most convenient to one-shot algorithms."""
from __future__ import annotations
import itertools
import operator
from typing import Any, TypeVar, List, cast, Mapping, Sequence, Optional, Iterable
import numpy as np
import torch
from nni.mutable import MutableExpression, Categorical
# Alias of ``Any``.
Choice = Any
# Type variable used in the signatures below for the value an expression / item evaluates to.
T = TypeVar('T')
# Names exported by this module.
__all__ = [
    'expression_expectation',
    'traverse_all_options',
    'weighted_sum',
    'evaluate_constant',
]
def expression_expectation(mutable_expr: MutableExpression[T] | Any, weights: dict[str, list[float]]) -> float:
    """Compute the expectation of a value choice.

    Anything that is not a ``MutableExpression`` is treated as a constant and returned as is.
    Additions and subtractions are decomposed by linearity of expectation;
    any other expression is expanded into all its possible outcomes.

    Parameters
    ----------
    mutable_expr
        The value choice to compute expectation.
    weights
        The weights of each leaf node.

    Returns
    -------
    float
        The expectation.
    """
    if not isinstance(mutable_expr, MutableExpression):
        return mutable_expr

    # Not every expression carries a ``function`` attribute.
    function = getattr(mutable_expr, 'function', None)

    # Optimization: E(a + b) = E(a) + E(b)
    if function == operator.add:
        return sum(expression_expectation(argument, weights) for argument in mutable_expr.arguments)

    # E(a - b) = E(a) - E(b)
    if function == operator.sub:
        minuend, subtrahend = mutable_expr.arguments[0], mutable_expr.arguments[1]
        return expression_expectation(minuend, weights) - expression_expectation(subtrahend, weights)

    # General case: [(option, probability), ...] -> ([option, ...], [probability, ...])
    outcomes = traverse_all_options(mutable_expr, weights)
    values, probabilities = zip(*outcomes)
    return weighted_sum(values, probabilities)
def traverse_all_options(
    mutable_expr: MutableExpression[T],
    weights: dict[str, dict[Any, float]] | dict[str, list[float]] | dict[str, np.ndarray] | dict[str, torch.Tensor] | None = None
) -> list[tuple[T, float]] | list[T]:
    """Traverse all possible computation outcome of a value choice.

    If ``weights`` is not None, it will also compute the probability of each possible outcome.

    NOTE: This function is very similar to ``MutableExpression.grid``,
    but it supports specifying weights for each leaf node.

    Parameters
    ----------
    mutable_expr
        The value choice to traverse.
    weights
        If there's a prior on leaf nodes, and we intend to know the (joint) prior on results,
        weights can be provided. The key is label. The value is either a dict mapping each
        candidate value to its probability, or a list-like of probabilities in the same order
        as the candidate values.
        Normally, they should sum up to 1, but we will not check them in this function.

    Returns
    -------
    Results will be sorted and duplicates will be eliminated.
    If weights is provided, the return value will be a list of tuple, with option and its weight.
    Otherwise, it will be a list of options.

    Raises
    ------
    TypeError
        If a leaf node of the expression is not a ``Categorical``.
    KeyError
        If ``weights`` misses a label, or has a wrong length for a label.
    RuntimeError
        If a value sampled on the grid is not among the candidates of its label
        (only checked when the weights of that label are list-like).
    """
    # Validation
    simplified = mutable_expr.simplify()
    for label, param in simplified.items():
        if not isinstance(param, Categorical):
            raise TypeError(f'{param!r} is not a categorical distribution')
        if weights is not None:
            if label not in weights:
                raise KeyError(f'{mutable_expr} depends on a weight with key {label}, but not found in {weights}')
            if len(param) != len(weights[label]):
                raise KeyError(f'Expect weights with {label} to be of length {len(param)}, but {len(weights[label])} found')

    # result is a dict from a option to its weight
    result: dict[T, float] = {}

    # ``grid`` fills ``sample`` with the values chosen for the outcome it is currently yielding.
    sample = {}
    for sample_res in mutable_expr.grid(memo=sample):
        probability = 1.
        if weights is not None:
            for label, chosen in sample.items():
                label_weights = weights[label]
                if isinstance(label_weights, dict):
                    # label_weights is a dict. Choices are used as keys.
                    probability = probability * label_weights[chosen]
                else:
                    # label_weights is list-like. We need to find the index of currently chosen value.
                    # ``list.index`` raises ValueError (it never returns -1) when the value is absent.
                    try:
                        chosen_idx = cast(Categorical, simplified[label]).values.index(chosen)
                    except ValueError:
                        raise RuntimeError(f'{chosen} is not a valid value for {label}: {simplified[label]!r}') from None
                    probability = probability * label_weights[chosen_idx]

        # Different samples can lead to the same outcome; their probabilities add up.
        if sample_res in result:
            result[sample_res] = result[sample_res] + cast(float, probability)
        else:
            result[sample_res] = cast(float, probability)

    if weights is None:
        return sorted(result.keys())  # type: ignore
    else:
        return sorted(result.items())  # type: ignore
def evaluate_constant(expr: Any) -> Any:
    """Evaluate a value choice expression to a constant. Raise ValueError if it's not a constant."""
    all_options = traverse_all_options(expr)
    if len(all_options) <= 1:
        # A single possible outcome: that outcome is the constant.
        return all_options[0]
    raise ValueError(f'{expr} is not evaluated to a constant. All possible values are: {all_options}')
def weighted_sum(items: list[T], weights: Sequence[float | None] = cast(Sequence[Optional[float]], None)) -> T:
    """Return a weighted sum of items.

    Items can be list of tensors, numpy arrays, numbers, or nested lists / dicts of them.
    Nested containers are summed element-wise (key-wise for dicts).

    If ``weights`` is None, this is simply an unweighted sum.
    A weight that is None means the corresponding item is added without scaling.

    Raises
    ------
    TypeError
        If the first item is a string or of any other unsupported type.
    ValueError
        If the items do not share the same type / keys / length, or can not be summed.
    """
    if weights is None:
        weights = [None] * len(items)

    assert len(items) == len(weights) > 0
    elem = items[0]
    unsupported_msg = 'Unsupported element type in weighted sum: {}. Value is: {}'

    if isinstance(elem, str):
        # Need to check this first. Otherwise it goes into sequence and causes infinite recursion.
        raise TypeError(unsupported_msg.format(type(elem), elem))

    try:
        if isinstance(elem, (torch.Tensor, np.ndarray, float, int, np.number)):
            res = elem if weights[0] is None else elem * weights[0]
            for it, weight in zip(items[1:], weights[1:]):
                # Exact type match is required on purpose (e.g., int and float are not mixed).
                if type(it) != type(elem):
                    raise TypeError(f'Expect type {type(elem)} but found {type(it)}. Can not be summed')

                if weight is None:
                    res = res + it  # type: ignore
                else:
                    res = res + it * weight  # type: ignore
            return cast('T', res)

        if isinstance(elem, Mapping):
            for item in items:
                if not isinstance(item, Mapping):
                    raise TypeError(f'Expect type {type(elem)} but found {type(item)}')
                if set(item) != set(elem):
                    raise KeyError(f'Expect keys {list(elem)} but found {list(item)}')
            # Sum key by key.
            return cast('T', {
                key: weighted_sum(cast(List[dict], [cast(Mapping, d)[key] for d in items]), weights) for key in elem
            })

        if isinstance(elem, Sequence):
            for item in items:
                if not isinstance(item, Sequence):
                    raise TypeError(f'Expect type {type(elem)} but found {type(item)}')
                if len(item) != len(elem):
                    # The first item defines the expected length.
                    raise ValueError(f'Expect length {len(elem)} but found {len(item)}')
            # Sum position by position.
            transposed = cast(Iterable[list], zip(*items))  # type: ignore
            return cast('T', [weighted_sum(column, weights) for column in transposed])
    except (TypeError, ValueError, RuntimeError, KeyError) as err:
        raise ValueError(
            'Error when summing items. Value format / shape does not match. See full traceback for details.' +
            ''.join([
                f'\n  {idx}: {_summarize_elem_format(it)}' for idx, it in enumerate(items)
            ])
        ) from err

    # Dealing with all unexpected types.
    raise TypeError(unsupported_msg.format(type(elem), elem))
def _summarize_elem_format(elem: Any) -> Any:
# Get a summary of one elem
# Helps generate human-readable error messages
class _repr_object:
# empty object is only repr
def __init__(self, representation):
self.representation = representation
def __repr__(self):
return self.representation
if isinstance(elem, torch.Tensor):
return _repr_object('torch.Tensor(' + ', '.join(map(str, elem.shape)) + ')')
if isinstance(elem, np.ndarray):
return _repr_object('np.array(' + ', '.join(map(str, elem.shape)) + ')')
if isinstance(elem, Mapping):
return {key: _summarize_elem_format(value) for key, value in elem.items()}
if isinstance(elem, Sequence):
return [_summarize_elem_format(value) for value in elem]
# fallback to original, for cases like float, int, ...
return elem

Просмотреть файл

@ -0,0 +1,94 @@
import pytest
import torch
from torch.nn import Conv2d
from nni.nas.nn import ModelSpace, LayerChoice
from nni.nas.oneshot.pytorch.profiler import RangeProfilerFilter, ExpectationProfilerPenalty, SampleProfilerPenalty
from nni.nas.profiler.pytorch.flops import FlopsProfiler
class Net(ModelSpace):
    """Model space with a single layer choice among three convolutions (kernel size 3, 5 and 7)."""

    def __init__(self):
        super().__init__()
        # Padding is chosen so that every candidate keeps the spatial size.
        candidates = [
            Conv2d(1, 10, kernel_size, padding=kernel_size // 2, bias=False)
            for kernel_size in (3, 5, 7)
        ]
        self.conv = LayerChoice(candidates, label='conv')

    def forward(self, x):
        output = self.conv(x)
        return output
@pytest.fixture
def profiler():
    """FLOPs profiler of ``Net`` for a 1x1x10x10 input."""
    model_space = Net()
    dummy_input = torch.randn(1, 1, 10, 10)
    # The three candidates cost (9k, 25k, 49k)
    return FlopsProfiler(model_space, dummy_input)
def test_range_filter(profiler):
    """Check lower-bound-only, upper-bound-only and two-sided range filters."""
    at_least = RangeProfilerFilter(profiler, min=10000)
    assert at_least({'conv': 1})
    assert at_least({'conv': 2})
    assert not at_least({'conv': 0})

    at_most = RangeProfilerFilter(profiler, max=30000)
    assert at_most({'conv': 0})
    assert at_most({'conv': 1})
    assert not at_most({'conv': 2})

    within = RangeProfilerFilter(profiler, min=10000, max=30000)
    assert not within({'conv': 0})
    assert within({'conv': 1})
    assert not within({'conv': 2})

    # At least one bound is mandatory.
    with pytest.raises(ValueError, match='both None'):
        RangeProfilerFilter(profiler, min=None, max=None)
def test_expectation_penalty(profiler):
    """Check ``ExpectationProfilerPenalty`` for several ``nonlinear`` / ``aggregate`` settings.

    The samples are categorical distributions over the three candidates of ``conv``.
    The expected values below assume the candidates cost 9000, 25000 and 49000,
    so that the distribution (0.2, 0.3, 0.5) has an expectation of 33800.
    """
    # Default setting (linear + add): 33800 / 20000 - 1 = 0.69
    penalty = ExpectationProfilerPenalty(profiler, 20000, 2.)
    loss, details = penalty(42., {'conv': {0: 0.2, 1: 0.3, 2: 0.5}})
    assert details['loss_original'] == 42.
    assert details['penalty'] == 33800
    assert details['normalized_penalty'] == 0.69
    assert loss == 42. + 0.69 * 2.

    # Same computation with tensors as loss and probabilities; compared with a tolerance.
    prob = torch.tensor([0.2, 0.3, 0.5])
    loss, details = penalty(torch.tensor(42.), {'conv': {i: prob[i] for i in range(3)}})
    assert abs(loss.item() - (42. + 0.69 * 2.)) < 1e-4

    # "negative" keeps only values below the baseline; with scale -2 the loss increases.
    penalty = ExpectationProfilerPenalty(profiler, 40000, -2., nonlinear='negative')
    loss, details = penalty(42., {'conv': {0: 0.2, 1: 0.3, 2: 0.5}})
    assert details['normalized_penalty'] == 33800 / 40000 - 1
    assert loss == 42. + 0.155 * 2.
    # Always the largest candidate: above the baseline, so "negative" clamps it to 0.
    loss, details = penalty(42., {'conv': {0: 0., 1: 0., 2: 1.}})
    assert details['normalized_penalty'] == 0
    assert loss == 42.

    # "absolute" + "mul": loss * (1 + |normalized_penalty|) ** scale
    penalty = ExpectationProfilerPenalty(profiler, 40000, 2., nonlinear='absolute', aggregate='mul')
    loss, details = penalty(42., {'conv': {0: 0.2, 1: 0.3, 2: 0.5}})
    assert details['normalized_penalty'] == abs(33800 / 40000 - 1)
    assert loss == 42. * (1 + 0.155) ** 2

    # "positive" + "mul": only values above the baseline are penalized.
    penalty = ExpectationProfilerPenalty(profiler, 30000, 2., nonlinear='positive', aggregate='mul')
    loss, details = penalty(42., {'conv': {0: 0.2, 1: 0.3, 2: 0.5}})
    assert details['normalized_penalty'] == 33800 / 30000 - 1
    assert loss == 42. * (33800 / 30000) ** 2
    # Always the smallest candidate: below the baseline, so "positive" clamps it to 0.
    loss, details = penalty(42., {'conv': {0: 1., 1: 0., 2: 0.}})
    assert details['normalized_penalty'] == 0
    assert loss == 42.
def test_sample_penalty(profiler):
    """Check ``SampleProfilerPenalty`` on one concrete sample with the default linear + add setting."""
    baseline, scale = 20000, 2.
    penalty = SampleProfilerPenalty(profiler, baseline, scale)

    final_loss, info = penalty(42., {'conv': 1})

    assert info['loss_original'] == 42.
    assert info['penalty'] == 25000
    assert info['normalized_penalty'] == 0.25
    assert final_loss == 42. + 0.25 * 2.

Просмотреть файл

@ -0,0 +1,130 @@
import math
from typing import Union
import pytest
import torch
import pytorch_lightning
from pytorch_lightning import LightningModule, Trainer
from torch.utils.data import DataLoader, Dataset
# Skip every test in this module when the installed pytorch_lightning is too old.
# NOTE(review): the versions are compared as strings (lexicographically), not numerically.
pytestmark = pytest.mark.skipif(pytorch_lightning.__version__ < '1.0', reason='Incompatible APIs')
class RandomDataset(Dataset):
    """Dataset of ``length`` random vectors, each with ``size`` elements."""

    def __init__(self, size, length):
        self.data = torch.randn(length, size)
        self.len = length

    def __len__(self):
        return self.len

    def __getitem__(self, index):
        row = self.data[index]
        return row
class BoringModel(LightningModule):
    """Minimal lightning module: one linear layer, with the sum of its outputs used as the loss."""

    def __init__(self):
        super().__init__()
        self.layer = torch.nn.Linear(32, 2)

    def forward(self, x):
        prediction = self.layer(x)
        return prediction

    def configure_optimizers(self):
        optimizer = torch.optim.SGD(self.layer.parameters(), lr=0.1)
        return optimizer

    def training_step(self, batch, batch_idx):
        outputs = self(batch)
        loss = outputs.sum()
        self.log('train_loss', loss)
        return dict(loss=loss)

    def validation_step(self, batch, batch_idx):
        outputs = self(batch)
        loss = outputs.sum()
        self.log('valid_loss', loss)

    def test_step(self, batch, batch_idx):
        outputs = self(batch)
        loss = outputs.sum()
        self.log('test_loss', loss)
def test_concat_loader():
    """``ConcatLoader`` yields the batches of ``a`` first, then those of ``b``, each tagged with its key."""
    from nni.nas.oneshot.pytorch.dataloader import ConcatLoader

    loader_a = DataLoader(range(10), batch_size=4)
    loader_b = DataLoader(range(20), batch_size=5)
    dataloader = ConcatLoader({'a': loader_a, 'b': loader_b})

    # 3 batches from ``a`` + 4 batches from ``b``
    assert len(dataloader) == 7

    for step, (data, label) in enumerate(dataloader):
        expected_label, max_batch_size = ('a', 4) if step < 3 else ('b', 5)
        assert len(data) <= max_batch_size
        assert label == expected_label
def test_concat_loader_nested():
    """``ConcatLoader`` also accepts a list of loaders as one entry; its batches are then lists."""
    from nni.nas.oneshot.pytorch.dataloader import ConcatLoader

    nested_a = [DataLoader(range(10), batch_size=4), DataLoader(range(20), batch_size=6)]
    plain_b = DataLoader(range(20), batch_size=5)
    dataloader = ConcatLoader({'a': nested_a, 'b': plain_b})

    assert len(dataloader) == 7

    for step, (data, label) in enumerate(dataloader):
        if step >= 3:
            assert label == 'b'
            continue
        # The first three batches come from ``a``: one element per nested loader.
        assert isinstance(data, list) and len(data) == 2
        assert label == 'a'
@pytest.mark.parametrize('replace_sampler_ddp', [False, True])
@pytest.mark.parametrize('is_min_size_mode', [True])
@pytest.mark.parametrize('num_devices', ['auto', 1, 3, 10])
def test_concat_loader_with_ddp(
    replace_sampler_ddp: bool, is_min_size_mode: bool, num_devices: Union[int, str]
):
    """Check the length of a ``ConcatLoader`` once a DDP trainer has set up its training dataloader.

    Inspired by tests/trainer/test_supporters.py in lightning.
    """
    from nni.nas.oneshot.pytorch.dataloader import ConcatLoader

    # Only 'min_size' is exercised, because ``is_min_size_mode`` is parametrized with ``[True]`` only.
    mode = 'min_size' if is_min_size_mode else 'max_size_cycle'
    dim = 3
    n1 = 8
    n2 = 6
    n3 = 9
    dataloader = ConcatLoader({
        'a': {
            'a1': DataLoader(RandomDataset(dim, n1), batch_size=1),
            'a2': DataLoader(RandomDataset(dim, n2), batch_size=1),
        },
        'b': DataLoader(RandomDataset(dim, n3), batch_size=1),
    }, mode=mode)
    # Entry 'a' counts as the shortest (or longest) of its nested loaders; entry 'b' counts in full.
    expected_length_before_ddp = n3 + (min(n1, n2) if is_min_size_mode else max(n1, n2))
    print(len(dataloader))
    assert len(dataloader) == expected_length_before_ddp
    model = BoringModel()
    trainer = Trainer(
        strategy='ddp',
        accelerator='cpu',
        devices=num_devices,
        replace_sampler_ddp=replace_sampler_ddp,
    )
    # NOTE(review): relies on a private attribute of the trainer (``_data_connector``).
    trainer._data_connector.attach_data(
        model=model, train_dataloaders=dataloader, val_dataloaders=None, datamodule=None
    )
    # With ``replace_sampler_ddp``, each part is expected to be divided by the number of devices
    # (rounded up); otherwise the length is expected to stay unchanged.
    expected_length_after_ddp = (
        math.ceil(n3 / trainer.num_devices) + \
            math.ceil((min(n1, n2) if is_min_size_mode else max(n1, n2)) / trainer.num_devices)
        if replace_sampler_ddp
        else expected_length_before_ddp
    )
    print('Num devices =', trainer.num_devices)
    trainer.reset_train_dataloader(model=model)
    assert trainer.train_dataloader is not None
    assert trainer.train_dataloader.mode == mode
    assert trainer.num_training_batches == expected_length_after_ddp