NAS oneshot (stage 2) - Supernet modules (#5372)

Yuge Zhang 2023-03-03 15:33:32 +08:00 committed by GitHub
Parent 4bd3f33a3a
Commit d40f408a0a
No key found matching this signature
GPG key ID: 4AEE18F83AFDEB23
5 changed files with 925 additions and 511 deletions

View file

@ -9,59 +9,12 @@ from typing import Any, Dict
import torch.nn as nn
from nni.common.hpo_utils import ParameterSpec
from nni.nas.nn.pytorch import MutableModule
__all__ = ['BaseSuperNetModule', 'sub_state_dict']
__all__ = ['BaseSuperNetModule']
def sub_state_dict(module: Any, destination: Any=None, prefix: str='', keep_vars: bool=False) -> Dict[str, Any]:
"""Returns a dictionary containing a whole state of the BaseSuperNetModule.
Both parameters and persistent buffers (e.g. running averages) are
included. Keys are corresponding parameter and buffer names.
Parameters and buffers set to ``None`` are not included.
Parameters
----------
arch : dict[str, Any]
subnet architecture dict.
destination (dict, optional):
If provided, the state of module will be updated into the dict
and the same object is returned. Otherwise, an ``OrderedDict``
will be created and returned. Default: ``None``.
prefix (str, optional):
a prefix added to parameter and buffer names to compose the keys in state_dict.
Default: ``''``.
keep_vars (bool, optional):
by default the :class:`~torch.Tensor` s returned in the state dict are
detached from autograd. If it's set to ``True``, detaching will not be performed.
Default: ``False``.
Returns
-------
dict
Subnet state dictionary.
"""
if destination is None:
destination = OrderedDict()
destination._metadata = OrderedDict()
local_metadata = dict(version=module._version)
if hasattr(destination, "_metadata"):
destination._metadata[prefix[:-1]] = local_metadata
if isinstance(module, BaseSuperNetModule):
module._save_to_sub_state_dict(destination, prefix, keep_vars)
else:
module._save_to_state_dict(destination, prefix, keep_vars)
for name, m in module._modules.items():
if m is not None:
sub_state_dict(m, destination=destination, prefix=prefix + name + '.', keep_vars=keep_vars)
return destination
class BaseSuperNetModule(nn.Module):
class BaseSuperNetModule(MutableModule):
"""
Mutated module in super-net.
Usually, the feed-forward of the module itself is undefined.
@ -116,17 +69,6 @@ class BaseSuperNetModule(nn.Module):
"""
raise NotImplementedError()
def search_space_spec(self) -> dict[str, ParameterSpec]:
"""
Space specification (sample points).
Mapping from spec name to ParameterSpec. The names in choices should be in the same format of export.
For example: ::
{"layer1": ParameterSpec(values=["conv", "pool"])}
"""
raise NotImplementedError()
@classmethod
def mutate(cls, module: nn.Module, name: str, memo: dict[str, Any], mutate_kwargs: dict[str, Any]) -> \
'BaseSuperNetModule' | bool | tuple['BaseSuperNetModule', bool]:
@ -153,22 +95,3 @@ class BaseSuperNetModule(nn.Module):
See :class:`BaseOneShotLightningModule <nni.retiarii.oneshot.pytorch.base_lightning.BaseOneShotLightningModule>` for details.
"""
raise NotImplementedError()
def _save_param_buff_to_state_dict(self, destination, prefix, keep_vars):
"""Save the params and buffers of the current module to state dict."""
for name, value in itertools.chain(self._parameters.items(), self._buffers.items()): # direct children
if value is None or name in self._non_persistent_buffers_set:
# it won't appear in state dict
continue
destination[prefix + name] = value if keep_vars else value.detach()
def _save_module_to_state_dict(self, destination, prefix, keep_vars):
"""Save the sub-module to state dict."""
for name, module in self._modules.items():
if module is not None:
sub_state_dict(module, destination=destination, prefix=prefix + name + '.', keep_vars=keep_vars)
def _save_to_sub_state_dict(self, destination, prefix, keep_vars):
"""Save to state dict."""
self._save_param_buff_to_state_dict(destination, prefix, keep_vars)
self._save_module_to_state_dict(destination, prefix, keep_vars)
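
For orientation, here is a minimal sketch of a module following this contract: ``resample``/``export`` as used by the subclasses later in this commit, plus a plain ``forward``. The ``ToyMixedScale`` name, its ``values`` argument, and the assumption that ``BaseSuperNetModule`` needs no constructor arguments are illustrative, not NNI API::

    from __future__ import annotations

    import random

    from nni.nas.oneshot.pytorch.supermodule.base import BaseSuperNetModule


    class ToyMixedScale(BaseSuperNetModule):
        """Scale the input by one of several candidate factors."""

        def __init__(self, values: list[float], label: str):
            super().__init__()
            self.values = values
            self.label = label
            self._sampled: float | None = None

        def resample(self, memo):
            # Reuse the value from memo if this label was sampled elsewhere.
            self._sampled = memo[self.label] if self.label in memo else random.choice(self.values)
            return {self.label: self._sampled}

        def export(self, memo):
            # Nothing new to export if the label is already in memo.
            return {} if self.label in memo else {self.label: random.choice(self.values)}

        def forward(self, x):
            assert self._sampled is not None, 'Call resample() before forward.'
            return x * self._sampled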

View file

@ -13,15 +13,14 @@ import torch
import torch.nn as nn
import torch.nn.functional as F
from nni.common.hpo_utils import ParameterSpec
from nni.nas.nn.pytorch import LayerChoice, InputChoice, ChoiceOf, Repeat
from nni.nas.nn.pytorch.choice import ValueChoiceX
from nni.mutable import MutableExpression, Mutable, Categorical
from nni.nas.nn.pytorch import LayerChoice, InputChoice, Repeat
from nni.nas.nn.pytorch.cell import preprocess_cell_inputs
from .base import BaseSuperNetModule
from .operation import MixedOperation, MixedOperationSamplingPolicy
from .sampling import PathSamplingCell
from ._valuechoice_utils import traverse_all_options, dedup_inner_choices, weighted_sum
from ._expression_utils import traverse_all_options, weighted_sum
_logger = logging.getLogger(__name__)
@ -46,7 +45,7 @@ class GumbelSoftmax(nn.Softmax):
return F.gumbel_softmax(inputs, tau=self.tau, hard=self.hard, dim=self.dim)
class DifferentiableMixedLayer(BaseSuperNetModule):
class DifferentiableMixedLayer(LayerChoice, BaseSuperNetModule):
"""
Mixed layer, in which fprop is decided by a weighted sum of several layers.
Proposed in `DARTS: Differentiable Architecture Search <https://arxiv.org/abs/1806.09055>`__.
@ -66,31 +65,18 @@ class DifferentiableMixedLayer(BaseSuperNetModule):
Customizable softmax function. Usually ``nn.Softmax(-1)``.
label : str
Name of the choice.
Attributes
----------
op_names : str
Operator names.
label : str
Name of the choice.
"""
_arch_parameter_names: list[str] = ['_arch_alpha']
def __init__(self,
paths: list[tuple[str, nn.Module]],
paths: list[nn.Module] | dict[str, nn.Module],
alpha: torch.Tensor,
softmax: nn.Module,
label: str):
super().__init__()
self.op_names = []
super().__init__(paths, label=label)
if len(alpha) != len(paths):
raise ValueError(f'The size of alpha ({len(alpha)}) must match number of candidates ({len(paths)}).')
for name, module in paths:
self.add_module(name, module)
self.op_names.append(name)
assert self.op_names, 'There has to be at least one op to choose from.'
self.label = label
self._arch_alpha = alpha
self._softmax = softmax
@ -102,62 +88,41 @@ class DifferentiableMixedLayer(BaseSuperNetModule):
"""Choose the operator with the maximum logit."""
if self.label in memo:
return {} # nothing new to export
return {self.label: self.op_names[int(torch.argmax(self._arch_alpha).item())]}
return {self.label: self.names[int(torch.argmax(self._arch_alpha).item())]}
def export_probs(self, memo):
if any(k.startswith(self.label + '/') for k in memo):
return {} # nothing new
if self.label in memo:
return {}
weights = self._softmax(self._arch_alpha).cpu().tolist()
ret = {f'{self.label}/{name}': value for name, value in zip(self.op_names, weights)}
return ret
def search_space_spec(self):
return {self.label: ParameterSpec(self.label, 'choice', self.op_names, (self.label, ),
True, size=len(self.op_names))}
return {self.label: dict(zip(self.names, weights))}
@classmethod
def mutate(cls, module, name, memo, mutate_kwargs):
if isinstance(module, LayerChoice):
size = len(module)
if module.label in memo:
alpha = memo[module.label]
if len(alpha) != size:
raise ValueError(f'Architecture parameter size of same label {module.label} conflict: {len(alpha)} vs. {size}')
else:
alpha = nn.Parameter(torch.randn(size) * 1E-3) # the numbers in the parameter can be reinitialized later
memo[module.label] = alpha
if type(module) is LayerChoice: # must be exactly LayerChoice
if module.label not in memo:
raise KeyError(f'LayerChoice {module.label} not found in memo.')
alpha = memo[module.label]
softmax = mutate_kwargs.get('softmax', nn.Softmax(-1))
return cls(list(module.named_children()), alpha, softmax, module.label)
return cls(module.candidates, alpha, softmax, module.label)
def reduction(self, items: list[Any], weights: list[float]) -> Any:
def _reduction(self, items: list[Any], weights: list[float]) -> Any:
"""Override this for customized reduction."""
# Use weighted_sum to handle complex cases where sequential output is not a single tensor
return weighted_sum(items, weights)
def forward(self, *args, **kwargs):
"""The forward of mixed layer accepts same arguments as its sub-layer."""
all_op_results = [getattr(self, op)(*args, **kwargs) for op in self.op_names]
return self.reduction(all_op_results, self._softmax(self._arch_alpha))
all_op_results = [self[op](*args, **kwargs) for op in self.names]
return self._reduction(all_op_results, self._softmax(self._arch_alpha))
def parameters(self, *args, **kwargs):
"""Parameters excluding architecture parameters."""
for _, p in self.named_parameters(*args, **kwargs):
yield p
def named_parameters(self, *args, **kwargs):
"""Named parameters excluding architecture parameters."""
arch = kwargs.pop('arch', False)
for name, p in super().named_parameters(*args, **kwargs):
def arch_parameters(self):
"""Iterate over architecture parameters. Not recursive."""
for name, p in self.named_parameters():
if any(name == par_name for par_name in self._arch_parameter_names):
if arch:
yield name, p
else:
if not arch:
yield name, p
yield p
class DifferentiableMixedInput(BaseSuperNetModule):
class DifferentiableMixedInput(InputChoice, BaseSuperNetModule):
"""
Mixed input. Forward returns a weighted sum of candidates.
Implementation is very similar to :class:`DifferentiableMixedLayer`.
@ -174,11 +139,6 @@ class DifferentiableMixedInput(BaseSuperNetModule):
Customizable softmax function. Usually ``nn.Softmax(-1)``.
label : str
Name of the choice.
Attributes
----------
label : str
Name of the choice.
"""
_arch_parameter_names: list[str] = ['_arch_alpha']
@ -189,18 +149,14 @@ class DifferentiableMixedInput(BaseSuperNetModule):
alpha: torch.Tensor,
softmax: nn.Module,
label: str):
super().__init__()
self.n_candidates = n_candidates
if len(alpha) != n_candidates:
raise ValueError(f'The size of alpha ({len(alpha)}) must match number of candidates ({n_candidates}).')
if n_chosen is None:
warnings.warn('Differentiable architecture search does not support choosing multiple inputs. Assuming one.',
RuntimeWarning)
self.n_chosen = 1
self.n_chosen = n_chosen
self.label = label
n_chosen = 1
super().__init__(n_candidates, n_chosen=n_chosen, label=label)
if len(alpha) != n_candidates:
raise ValueError(f'The size of alpha ({len(alpha)}) must match number of candidates ({n_candidates}).')
self._softmax = softmax
self._arch_alpha = alpha
def resample(self, memo):
@ -212,68 +168,44 @@ class DifferentiableMixedInput(BaseSuperNetModule):
if self.label in memo:
return {} # nothing new to export
chosen = sorted(torch.argsort(-self._arch_alpha).cpu().numpy().tolist()[:self.n_chosen])
if len(chosen) == 1:
chosen = chosen[0]
return {self.label: chosen}
def export_probs(self, memo):
if any(k.startswith(self.label + '/') for k in memo):
return {} # nothing new
if self.label in memo:
return {}
weights = self._softmax(self._arch_alpha).cpu().tolist()
ret = {f'{self.label}/{index}': value for index, value in enumerate(weights)}
return ret
def search_space_spec(self):
return {
self.label: ParameterSpec(self.label, 'choice', list(range(self.n_candidates)),
(self.label, ), True, size=self.n_candidates, chosen_size=self.n_chosen)
}
return {self.label: dict(enumerate(weights))}
@classmethod
def mutate(cls, module, name, memo, mutate_kwargs):
if isinstance(module, InputChoice):
if type(module) == InputChoice: # must be exactly InputChoice
module = cast(InputChoice, module)
if module.reduction not in ['sum', 'mean']:
raise ValueError('Only input choice of sum/mean reduction is supported.')
size = module.n_candidates
if module.label in memo:
alpha = memo[module.label]
if len(alpha) != size:
raise ValueError(f'Architecture parameter size of same label {module.label} conflict: {len(alpha)} vs. {size}')
else:
alpha = nn.Parameter(torch.randn(size) * 1E-3) # the numbers in the parameter can be reinitialized later
memo[module.label] = alpha
if module.label not in memo:
raise KeyError(f'InputChoice {module.label} not found in memo.')
alpha = memo[module.label]
softmax = mutate_kwargs.get('softmax', nn.Softmax(-1))
return cls(module.n_candidates, module.n_chosen, alpha, softmax, module.label)
def reduction(self, items: list[Any], weights: list[float]) -> Any:
def _reduction(self, items: list[Any], weights: list[float]) -> Any:
"""Override this for customized reduction."""
# Use weighted_sum to handle complex cases where sequential output is not a single tensor
return weighted_sum(items, weights)
def forward(self, inputs):
"""Forward takes a list of input candidates."""
return self.reduction(inputs, self._softmax(self._arch_alpha))
return self._reduction(inputs, self._softmax(self._arch_alpha))
def parameters(self, *args, **kwargs):
"""Parameters excluding architecture parameters."""
for _, p in self.named_parameters(*args, **kwargs):
yield p
def named_parameters(self, *args, **kwargs):
"""Named parameters excluding architecture parameters."""
arch = kwargs.pop('arch', False)
for name, p in super().named_parameters(*args, **kwargs):
def arch_parameters(self):
"""Iterate over architecture parameters. Not recursive."""
for name, p in self.named_parameters():
if any(name == par_name for par_name in self._arch_parameter_names):
if arch:
yield name, p
else:
if not arch:
yield name, p
yield p
class MixedOpDifferentiablePolicy(MixedOperationSamplingPolicy):
"""Implementes the differentiable sampling in mixed operation.
"""Implements the differentiable sampling in mixed operation.
One mixed operation can have multiple value choices in its arguments.
Thus the ``_arch_alpha`` here is a parameter dict, and ``named_parameters``
@ -293,36 +225,21 @@ class MixedOpDifferentiablePolicy(MixedOperationSamplingPolicy):
def __init__(self, operation: MixedOperation, memo: dict[str, Any], mutate_kwargs: dict[str, Any]) -> None:
# Sampling arguments. This should have the same keys with `operation.mutable_arguments`
operation._arch_alpha = nn.ParameterDict()
for name, spec in operation.search_space_spec().items():
if name in memo:
alpha = memo[name]
if len(alpha) != spec.size:
raise ValueError(f'Architecture parameter size of same label {name} conflict: {len(alpha)} vs. {spec.size}')
else:
alpha = nn.Parameter(torch.randn(spec.size) * 1E-3)
memo[name] = alpha
operation._arch_alpha[name] = alpha
for name in operation.simplify():
if name not in memo:
raise KeyError(f'Argument {name} not found in memo.')
operation._arch_alpha[str(name)] = memo[name]
operation.parameters = functools.partial(self.parameters, module=operation) # bind self
operation.named_parameters = functools.partial(self.named_parameters, module=operation)
operation.arch_parameters = functools.partial(self.arch_parameters, module=operation)
operation._softmax = mutate_kwargs.get('softmax', nn.Softmax(-1))
@staticmethod
def parameters(module, *args, **kwargs):
for _, p in module.named_parameters(*args, **kwargs):
yield p
@staticmethod
def named_parameters(module, *args, **kwargs):
arch = kwargs.pop('arch', False)
for name, p in super(module.__class__, module).named_parameters(*args, **kwargs): # pylint: disable=bad-super-call
def arch_parameters(module):
"""Iterate over architecture parameters. Not recursive."""
for name, p in module.named_parameters():
if any(name.startswith(par_name) for par_name in MixedOpDifferentiablePolicy._arch_parameter_names):
if arch:
yield name, p
else:
if not arch:
yield name, p
yield p
def resample(self, operation: MixedOperation, memo: dict[str, Any]) -> dict[str, Any]:
"""Differentiable. Do nothing in resample."""
@ -331,21 +248,21 @@ class MixedOpDifferentiablePolicy(MixedOperationSamplingPolicy):
def export(self, operation: MixedOperation, memo: dict[str, Any]) -> dict[str, Any]:
"""Export is argmax for each leaf value choice."""
result = {}
for name, spec in operation.search_space_spec().items():
for name, spec in operation.simplify().items():
if name in memo:
continue
chosen_index = int(torch.argmax(cast(dict, operation._arch_alpha)[name]).item())
result[name] = spec.values[chosen_index]
result[name] = cast(Categorical, spec).values[chosen_index]
return result
def export_probs(self, operation: MixedOperation, memo: dict[str, Any]):
"""Export the weight for every leaf value choice."""
ret = {}
for name, spec in operation.search_space_spec().items():
if any(k.startswith(name + '/') for k in memo):
for name, spec in operation.simplify().items():
if name in memo:
continue
weights = operation._softmax(operation._arch_alpha[name]).cpu().tolist() # type: ignore
ret.update({f'{name}/{value}': weight for value, weight in zip(spec.values, weights)})
ret.update({name: dict(zip(cast(Categorical, spec).values, weights))})
return ret
def forward_argument(self, operation: MixedOperation, name: str) -> dict[Any, float] | Any:
@ -357,9 +274,9 @@ class MixedOpDifferentiablePolicy(MixedOperationSamplingPolicy):
return operation.init_arguments[name]
class DifferentiableMixedRepeat(BaseSuperNetModule):
class DifferentiableMixedRepeat(Repeat, BaseSuperNetModule):
"""
Implementaion of Repeat in a differentiable supernet.
Implementation of Repeat in a differentiable supernet.
Result is a weighted sum of possible prefixes, sliced by possible depths.
If the output is not a single tensor, it will be summed at every independent dimension.
@ -368,27 +285,17 @@ class DifferentiableMixedRepeat(BaseSuperNetModule):
_arch_parameter_names: list[str] = ['_arch_alpha']
depth_choice: MutableExpression[int]
def __init__(self,
blocks: list[nn.Module],
depth: ChoiceOf[int],
depth: MutableExpression[int],
softmax: nn.Module,
memo: dict[str, Any]):
super().__init__()
self.blocks = blocks
self.depth = depth
alphas: dict[str, Any]):
assert isinstance(depth, Mutable)
super().__init__(blocks, depth)
self._softmax = softmax
self._space_spec: dict[str, ParameterSpec] = dedup_inner_choices([depth])
self._arch_alpha = nn.ParameterDict()
for name, spec in self._space_spec.items():
if name in memo:
alpha = memo[name]
if len(alpha) != spec.size:
raise ValueError(f'Architecture parameter size of same label {name} conflict: {len(alpha)} vs. {spec.size}')
else:
alpha = nn.Parameter(torch.randn(spec.size) * 1E-3)
memo[name] = alpha
self._arch_alpha[name] = alpha
self._arch_alpha = nn.ParameterDict(alphas)
def resample(self, memo):
"""Do nothing."""
@ -397,48 +304,43 @@ class DifferentiableMixedRepeat(BaseSuperNetModule):
def export(self, memo):
"""Choose argmax for each leaf value choice."""
result = {}
for name, spec in self._space_spec.items():
for name, spec in self.depth_choice.simplify().items():
if name in memo:
continue
chosen_index = int(torch.argmax(self._arch_alpha[name]).item())
result[name] = spec.values[chosen_index]
result[name] = cast(Categorical, spec).values[chosen_index]
return result
def export_probs(self, memo):
"""Export the weight for every leaf value choice."""
ret = {}
for name, spec in self.search_space_spec().items():
if any(k.startswith(name + '/') for k in memo):
for name, spec in self.depth_choice.simplify().items():
if name in memo:
continue
weights = self._softmax(self._arch_alpha[name]).cpu().tolist()
ret.update({f'{name}/{value}': weight for value, weight in zip(spec.values, weights)})
ret.update({name: dict(zip(cast(Categorical, spec).values, weights))})
return ret
def search_space_spec(self):
return self._space_spec
@classmethod
def mutate(cls, module, name, memo, mutate_kwargs):
if isinstance(module, Repeat) and isinstance(module.depth_choice, ValueChoiceX):
if type(module) == Repeat and isinstance(module.depth_choice, Mutable): # Repeat and depth is mutable
# Only interesting when depth is mutable
module = cast(Repeat, module)
alphas = {}
for name in cast(Mutable, module.depth_choice).simplify():
if name not in memo:
raise KeyError(f'Mutable depth "{name}" not found in memo')
alphas[name] = memo[name]
softmax = mutate_kwargs.get('softmax', nn.Softmax(-1))
return cls(cast(List[nn.Module], module.blocks), module.depth_choice, softmax, memo)
return cls(list(module.blocks), cast(MutableExpression[int], module.depth_choice), softmax, alphas)
def parameters(self, *args, **kwargs):
for _, p in self.named_parameters(*args, **kwargs):
yield p
def arch_parameters(self):
"""Iterate over architecture parameters. Not recursive."""
for name, p in self.named_parameters():
if any(name.startswith(par_name) for par_name in self._arch_parameter_names):
yield p
def named_parameters(self, *args, **kwargs):
arch = kwargs.pop('arch', False)
for name, p in super().named_parameters(*args, **kwargs):
if any(name.startswith(par_name) for par_name in MixedOpDifferentiablePolicy._arch_parameter_names):
if arch:
yield name, p
else:
if not arch:
yield name, p
def reduction(self, items: list[Any], weights: list[float], depths: list[int]) -> Any:
def _reduction(self, items: list[Any], weights: list[float], depths: list[int]) -> Any:
"""Override this for customized reduction."""
# Use weighted_sum to handle complex cases where sequential output is not a single tensor
return weighted_sum(items, weights)
@ -447,7 +349,7 @@ class DifferentiableMixedRepeat(BaseSuperNetModule):
weights: dict[str, torch.Tensor] = {
label: self._softmax(alpha) for label, alpha in self._arch_alpha.items()
}
depth_weights = dict(cast(List[Tuple[int, float]], traverse_all_options(self.depth, weights=weights)))
depth_weights = dict(cast(List[Tuple[int, float]], traverse_all_options(self.depth_choice, weights=weights)))
res: list[torch.Tensor] = []
weight_list: list[float] = []
@ -459,7 +361,7 @@ class DifferentiableMixedRepeat(BaseSuperNetModule):
res.append(x)
depths.append(i)
return self.reduction(res, weight_list, depths)
return self._reduction(res, weight_list, depths)
class DifferentiableMixedCell(PathSamplingCell):
@ -473,6 +375,8 @@ class DifferentiableMixedCell(PathSamplingCell):
# TODO: It inherits :class:`PathSamplingCell` to reduce some duplicated code.
# Possibly need another refactor here.
_arch_parameter_names: list[str] = ['_arch_alpha']
def __init__(
self, op_factory, num_nodes, num_ops_per_node,
num_predecessors, preprocessor, postprocessor, concat_dim,
@ -486,10 +390,13 @@ class DifferentiableMixedCell(PathSamplingCell):
self._arch_alpha = nn.ParameterDict()
for i in range(self.num_predecessors, self.num_nodes + self.num_predecessors):
for j in range(i):
edge_label = f'{label}/{i}_{j}'
edge_label = f'{self.label}/{i}_{j}'
# Some parameters still need to be created here inside.
# We should avoid conflict with "outside parameters".
memo_label = edge_label + '/in_cell'
op = cast(List[Dict[str, nn.Module]], self.ops[i - self.num_predecessors])[j]
if edge_label in memo:
alpha = memo[edge_label]
if memo_label in memo:
alpha = memo[memo_label]
if len(alpha) != len(op) + 1:
if len(alpha) != len(op):
raise ValueError(
@ -504,7 +411,7 @@ class DifferentiableMixedCell(PathSamplingCell):
else:
# +1 to emulate the input choice.
alpha = nn.Parameter(torch.randn(len(op) + 1) * 1E-3)
memo[edge_label] = alpha
memo[memo_label] = alpha
self._arch_alpha[edge_label] = alpha
self._softmax = mutate_kwargs.get('softmax', nn.Softmax(-1))
@ -517,10 +424,10 @@ class DifferentiableMixedCell(PathSamplingCell):
"""When export probability, we follow the structure in arch alpha."""
ret = {}
for name, parameter in self._arch_alpha.items():
if any(k.startswith(name + '/') for k in memo):
if name in memo:
continue
weights = self._softmax(parameter).cpu().tolist()
ret.update({f'{name}/{value}': weight for value, weight in zip(self.op_names, weights)})
ret.update({name: dict(zip(self.op_names, weights))})
return ret
def export(self, memo):
@ -565,7 +472,7 @@ class DifferentiableMixedCell(PathSamplingCell):
# all_weights could be too short in case ``num_ops_per_node`` is too large.
_, j, op_name = all_weights[k % len(all_weights)]
exported[f'{self.label}/op_{i}_{k}'] = op_name
exported[f'{self.label}/input_{i}_{k}'] = j
exported[f'{self.label}/input_{i}_{k}'] = [j]
return exported
@ -579,10 +486,10 @@ class DifferentiableMixedCell(PathSamplingCell):
op_results = torch.stack([op(states[j]) for op in ops[j].values()])
alpha_shape = [-1] + [1] * (len(op_results.size()) - 1) # (-1, 1, 1, 1, 1, ...)
op_weights = self._softmax(self._arch_alpha[f'{self.label}/{i}_{j}'])
if len(op_weights) == len(op_results) + 1:
if op_weights.size(0) == op_results.size(0) + 1:
# concatenate with a zero operation, indicating this path is not chosen at all.
op_results = torch.cat((op_results, torch.zeros_like(op_results[:1])), 0)
edge_sum = torch.sum(op_results * self._softmax(self._arch_alpha[f'{self.label}/{i}_{j}']).view(*alpha_shape), 0)
edge_sum = torch.sum(op_results * op_weights.view(*alpha_shape), 0)
current_state.append(edge_sum)
states.append(sum(current_state)) # type: ignore
@ -591,16 +498,8 @@ class DifferentiableMixedCell(PathSamplingCell):
this_cell = torch.cat(states[self.num_predecessors:], self.concat_dim)
return self.postprocessor(this_cell, processed_inputs)
def parameters(self, *args, **kwargs):
for _, p in self.named_parameters(*args, **kwargs):
yield p
def named_parameters(self, *args, **kwargs):
arch = kwargs.pop('arch', False)
for name, p in super().named_parameters(*args, **kwargs):
if any(name.startswith(par_name) for par_name in MixedOpDifferentiablePolicy._arch_parameter_names):
if arch:
yield name, p
else:
if not arch:
yield name, p
def arch_parameters(self):
"""Iterate over architecture parameters. Not recursive."""
for name, p in self.named_parameters():
if any(name.startswith(par_name) for par_name in self._arch_parameter_names):
yield p
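
The weighted-sum dispatch that ``DifferentiableMixedLayer`` and its relatives perform can be illustrated with plain PyTorch. The module below is a stand-alone sketch of the idea (softmax over an ``_arch_alpha`` logit vector, then a convex combination of the candidate outputs), not the NNI implementation; all names in it are illustrative::

    import torch
    import torch.nn as nn


    class WeightedSumLayer(nn.Module):
        """DARTS-style mixed layer: output is a softmax-weighted sum of all candidates."""

        def __init__(self, candidates: dict):
            super().__init__()
            self.ops = nn.ModuleDict(candidates)
            # One architecture logit per candidate, initialized close to zero.
            self._arch_alpha = nn.Parameter(torch.randn(len(candidates)) * 1e-3)

        def forward(self, x):
            weights = torch.softmax(self._arch_alpha, dim=-1)
            return sum(w * op(x) for w, op in zip(weights, self.ops.values()))


    layer = WeightedSumLayer({'conv': nn.Conv2d(3, 3, 3, padding=1), 'skip': nn.Identity()})
    out = layer(torch.randn(1, 3, 8, 8))   # same shape as each candidate's output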

View file

@ -18,13 +18,15 @@ import torch.nn as nn
import torch.nn.functional as F
from torch import Tensor
import nni.nas.nn.pytorch as nas_nn
from nni.common.hpo_utils import ParameterSpec
from nni.common.serializer import is_traceable
from nni.nas.nn.pytorch.choice import ValueChoiceX
from nni.mutable import MutableExpression
from nni.nas.nn.pytorch import (
ParametrizedModule,
MutableConv2d, MutableLinear, MutableBatchNorm2d, MutableLayerNorm, MutableMultiheadAttention
)
from .base import BaseSuperNetModule, sub_state_dict
from ._valuechoice_utils import traverse_all_options, dedup_inner_choices, evaluate_constant
from .base import BaseSuperNetModule
from ._expression_utils import traverse_all_options, evaluate_constant
from ._operation_utils import Slicable as _S, MaybeWeighted as _W, int_or_int_dict, scalar_or_scalar_dict
T = TypeVar('T')
@ -40,7 +42,7 @@ __all__ = [
'NATIVE_MIXED_OPERATIONS',
]
_diff_not_compatible_error = 'To be compatible with differentiable one-shot strategy, {} in {} must not be ValueChoice.'
_diff_not_compatible_error = 'To be compatible with differentiable one-shot strategy, {} in {} must not be mutable.'
class MixedOperationSamplingPolicy:
@ -84,10 +86,10 @@ class MixedOperationSamplingPolicy:
class MixedOperation(BaseSuperNetModule):
"""This is the base class for all mixed operations.
It's what you should inherit to support a new operation with ValueChoice.
It's what you should inherit to support a new operation with mutable.
It contains commonly used utilities that will ease the effort to write customized mixed oeprations,
i.e., operations with ValueChoice in its arguments.
It contains commonly used utilities that will ease the effort to write customized mixed operations,
i.e., operations with mutable in its arguments.
To customize, please write your own mixed operation, and add the hook into ``mutation_hooks`` parameter when using the strategy.
By design, for a mixed operation to work in a specific algorithm,
@ -116,14 +118,14 @@ class MixedOperation(BaseSuperNetModule):
sampling_policy: MixedOperationSamplingPolicy
def super_init_argument(self, name: str, value_choice: ValueChoiceX) -> Any:
def super_init_argument(self, name: str, value_choice: MutableExpression) -> Any:
"""Get the initialization argument when constructing super-kernel, i.e., calling ``super().__init__()``.
This is often related to the specific operator, rather than the algorithm.
For example::
def super_init_argument(self, name, value_choice):
return max(value_choice.candidates)
return max(value_choice.grid())
"""
raise NotImplementedError()
@ -138,8 +140,8 @@ class MixedOperation(BaseSuperNetModule):
def __init__(self, module_kwargs: dict[str, Any]) -> None:
# Concerned arguments
self.mutable_arguments: dict[str, ValueChoiceX] = {}
# Useful when retrieving arguments without ValueChoice
self.mutable_arguments: dict[str, MutableExpression] = {}
# Useful when retrieving arguments without mutable
self.init_arguments: dict[str, Any] = {**module_kwargs}
self._fill_missing_init_arguments()
@ -147,21 +149,37 @@ class MixedOperation(BaseSuperNetModule):
super_init_kwargs = {}
for key, value in module_kwargs.items():
if isinstance(value, ValueChoiceX):
if isinstance(value, MutableExpression):
if key not in self.argument_list:
raise TypeError(f'Unsupported value choice on argument of {self.bound_type}: {key}')
raise TypeError(f'Unsupported mutable argument of "{self.bound_type}": {key}')
super_init_kwargs[key] = self.super_init_argument(key, value)
self.mutable_arguments[key] = value
else:
super_init_kwargs[key] = value
# get all inner leaf value choices
self._space_spec: dict[str, ParameterSpec] = dedup_inner_choices(list(self.mutable_arguments.values()))
super().__init__(**super_init_kwargs)
for mutable in self.mutable_arguments.values():
self.add_mutable(mutable)
self.__post_init__()
def freeze(self, sample) -> Any:
"""Freeze the mixed operation to a specific operation.
Weights will be copied from the mixed operation to the frozen operation.
The returned operation will be of the ``bound_type``.
"""
arguments = {**self.init_arguments}
for name, mutable in self.mutable_arguments.items():
arguments[name] = mutable.freeze(sample)
operation = self.bound_type(**arguments)
# copy weights
state_dict = self.freeze_weight(**arguments)
operation.load_state_dict(state_dict)
return operation
def resample(self, memo):
"""Delegates to :meth:`MixedOperationSamplingPolicy.resample`."""
return self.sampling_policy.resample(self, memo)
@ -174,21 +192,12 @@ class MixedOperation(BaseSuperNetModule):
"""Delegates to :meth:`MixedOperationSamplingPolicy.export`."""
return self.sampling_policy.export(self, memo)
def search_space_spec(self):
return self._space_spec
@classmethod
def mutate(cls, module, name, memo, mutate_kwargs):
"""Find value choice in module's arguments and replace the whole module"""
has_valuechoice = False
if isinstance(module, cls.bound_type) and is_traceable(module):
for arg in itertools.chain(cast(list, module.trace_args), cast(dict, module.trace_kwargs).values()):
if isinstance(arg, ValueChoiceX):
has_valuechoice = True
if has_valuechoice:
if isinstance(module, cls.bound_type) and isinstance(module, ParametrizedModule):
if module.trace_args:
raise ValueError('ValueChoice on class arguments cannot appear together with ``trace_args``. '
raise ValueError('Mutable on class arguments cannot appear together with ``trace_args``. '
'Please enable ``kw_only`` on nni.trace.')
# save type and kwargs
@ -232,21 +241,12 @@ class MixedOperation(BaseSuperNetModule):
if param.default is not param.empty and param.name not in self.init_arguments:
self.init_arguments[param.name] = param.default
def slice_param(self, **kwargs):
def freeze_weight(self, **kwargs):
"""Slice the params and buffers for subnet forward and state dict.
When there is a `mapping=True` in kwargs, the return result will be wrapped in dict.
"""
raise NotImplementedError()
def _save_param_buff_to_state_dict(self, destination, prefix, keep_vars):
kwargs = {name: self.forward_argument(name) for name in self.argument_list}
params_mapping: dict[str, Any] = self.slice_param(**kwargs)
for name, value in itertools.chain(self._parameters.items(), self._buffers.items()): # direct children
if value is None or name in self._non_persistent_buffers_set:
# it won't appear in state dict
continue
value = params_mapping.get(name, value)
destination[prefix + name] = value if keep_vars else value.detach()
The arguments are the same as those passed to ``__init__``.
"""
raise NotImplementedError('freeze_weight is not implemented.')
class MixedLinear(MixedOperation, nn.Linear):
@ -260,13 +260,13 @@ class MixedLinear(MixedOperation, nn.Linear):
Prefix of weight and bias will be sliced.
"""
bound_type = nas_nn.Linear
bound_type = MutableLinear
argument_list = ['in_features', 'out_features']
def super_init_argument(self, name: str, value_choice: ValueChoiceX):
return max(traverse_all_options(value_choice))
def super_init_argument(self, name: str, value_choice: MutableExpression):
return max(value_choice.grid())
def slice_param(self, in_features: int_or_int_dict, out_features: int_or_int_dict, **kwargs) -> Any:
def freeze_weight(self, in_features: int_or_int_dict, out_features: int_or_int_dict, **kwargs) -> Any:
in_features_ = _W(in_features)
out_features_ = _W(out_features)
@ -281,7 +281,7 @@ class MixedLinear(MixedOperation, nn.Linear):
out_features: int_or_int_dict,
inputs: torch.Tensor) -> torch.Tensor:
params_mapping = self.slice_param(in_features, out_features)
params_mapping = self.freeze_weight(in_features, out_features)
weight, bias = [params_mapping.get(name) for name in ['weight', 'bias']]
return F.linear(inputs, weight, bias)
@ -321,7 +321,7 @@ class MixedConv2d(MixedOperation, nn.Conv2d):
"""
bound_type = nas_nn.Conv2d
bound_type = MutableConv2d
argument_list = [
'in_channels', 'out_channels', 'kernel_size', 'stride', 'padding', 'dilation', 'groups'
]
@ -332,12 +332,12 @@ class MixedConv2d(MixedOperation, nn.Conv2d):
return (value, value)
return value
def super_init_argument(self, name: str, value_choice: ValueChoiceX):
def super_init_argument(self, name: str, mutable_expr: MutableExpression):
if name not in ['in_channels', 'out_channels', 'groups', 'stride', 'kernel_size', 'padding', 'dilation']:
raise NotImplementedError(f'Unsupported value choice on argument: {name}')
if name in ['kernel_size', 'padding']:
all_sizes = set(traverse_all_options(value_choice))
all_sizes = set(traverse_all_options(mutable_expr))
if any(isinstance(sz, tuple) for sz in all_sizes):
# maximum kernel should be calculated on every dimension
return (
@ -351,28 +351,37 @@ class MixedConv2d(MixedOperation, nn.Conv2d):
if 'in_channels' in self.mutable_arguments:
# If the ratio is constant, we don't need to try the maximum groups.
try:
constant = evaluate_constant(self.mutable_arguments['in_channels'] / value_choice)
return max(cast(List[float], traverse_all_options(value_choice))) // int(constant)
constant = evaluate_constant(self.mutable_arguments['in_channels'] / mutable_expr)
return max(cast(List[float], traverse_all_options(mutable_expr))) // int(constant)
except ValueError:
warnings.warn(
'Both input channels and groups are ValueChoice in a convolution, and their relative ratio is not a constant. '
'Both input channels and groups are mutable in a convolution, and their relative ratio is not a constant. '
'This can be problematic for most one-shot algorithms. Please check whether this is your intention.',
RuntimeWarning
)
# minimum groups, maximum kernel
return min(traverse_all_options(value_choice))
return min(traverse_all_options(mutable_expr))
else:
return max(traverse_all_options(value_choice))
return max(traverse_all_options(mutable_expr))
def slice_param(self,
def freeze_weight(self,
in_channels: int_or_int_dict,
out_channels: int_or_int_dict,
kernel_size: scalar_or_scalar_dict[_int_or_tuple],
groups: int_or_int_dict,
**kwargs
) -> Any:
**kwargs) -> Any:
rv = self._freeze_weight_impl(in_channels, out_channels, kernel_size, groups)
rv.pop('in_channels_per_group', None)
return rv
def _freeze_weight_impl(self,
in_channels: int_or_int_dict,
out_channels: int_or_int_dict,
kernel_size: scalar_or_scalar_dict[_int_or_tuple],
groups: int_or_int_dict,
**kwargs) -> Any:
in_channels_ = _W(in_channels)
out_channels_ = _W(out_channels)
@ -386,8 +395,8 @@ class MixedConv2d(MixedOperation, nn.Conv2d):
in_channels_per_group = None
else:
assert 'groups' in self.mutable_arguments
err_message = 'For differentiable one-shot strategy, when groups is a ValueChoice, ' \
'in_channels and out_channels should also be a ValueChoice. ' \
err_message = 'For differentiable one-shot strategy, when groups is a mutable, ' \
'in_channels and out_channels should also be a mutable. ' \
'Also, the ratios of in_channels divided by groups, and out_channels divided by groups ' \
'should be constants.'
if 'in_channels' not in self.mutable_arguments or 'out_channels' not in self.mutable_arguments:
@ -412,7 +421,6 @@ class MixedConv2d(MixedOperation, nn.Conv2d):
return {'weight': weight, 'bias': bias, 'in_channels_per_group': in_channels_per_group}
def forward_with_args(self,
in_channels: int_or_int_dict,
out_channels: int_or_int_dict,
@ -426,7 +434,7 @@ class MixedConv2d(MixedOperation, nn.Conv2d):
if any(isinstance(arg, dict) for arg in [stride, dilation]):
raise ValueError(_diff_not_compatible_error.format('stride, dilation', 'Conv2d'))
params_mapping = self.slice_param(in_channels, out_channels, kernel_size, groups)
params_mapping = self._freeze_weight_impl(in_channels, out_channels, kernel_size, groups)
weight, bias, in_channels_per_group = [
params_mapping.get(name)
for name in ['weight', 'bias', 'in_channels_per_group']
@ -478,13 +486,13 @@ class MixedBatchNorm2d(MixedOperation, nn.BatchNorm2d):
PyTorch BatchNorm supports a case where momentum can be none, which is not supported here.
"""
bound_type = nas_nn.BatchNorm2d
bound_type = MutableBatchNorm2d
argument_list = ['num_features', 'eps', 'momentum']
def super_init_argument(self, name: str, value_choice: ValueChoiceX):
return max(traverse_all_options(value_choice))
def super_init_argument(self, name: str, mutable_expr: MutableExpression):
return max(traverse_all_options(mutable_expr))
def slice_param(self, num_features: int_or_int_dict, **kwargs) -> Any:
def freeze_weight(self, num_features: int_or_int_dict, **kwargs) -> Any:
if isinstance(num_features, dict):
num_features = self.num_features
weight, bias = self.weight, self.bias
@ -508,7 +516,7 @@ class MixedBatchNorm2d(MixedOperation, nn.BatchNorm2d):
if any(isinstance(arg, dict) for arg in [eps, momentum]):
raise ValueError(_diff_not_compatible_error.format('eps and momentum', 'BatchNorm2d'))
params_mapping = self.slice_param(num_features)
params_mapping = self.freeze_weight(num_features)
weight, bias, running_mean, running_var = [
params_mapping.get(name)
for name in ['weight', 'bias', 'running_mean', 'running_var']
@ -547,7 +555,7 @@ class MixedLayerNorm(MixedOperation, nn.LayerNorm):
eps is required to be float.
"""
bound_type = nas_nn.LayerNorm
bound_type = MutableLayerNorm
argument_list = ['normalized_shape', 'eps']
@staticmethod
@ -556,10 +564,10 @@ class MixedLayerNorm(MixedOperation, nn.LayerNorm):
return (value, value)
return value
def super_init_argument(self, name: str, value_choice: ValueChoiceX):
def super_init_argument(self, name: str, mutable_expr: MutableExpression):
if name not in ['normalized_shape']:
raise NotImplementedError(f'Unsupported value choice on argument: {name}')
all_sizes = set(traverse_all_options(value_choice))
all_sizes = set(traverse_all_options(mutable_expr))
if any(isinstance(sz, (tuple, list)) for sz in all_sizes):
# transpose
all_sizes = list(zip(*all_sizes))
@ -568,7 +576,12 @@ class MixedLayerNorm(MixedOperation, nn.LayerNorm):
else:
return max(all_sizes)
def slice_param(self, normalized_shape, **kwargs) -> Any:
def freeze_weight(self, normalized_shape, **kwargs) -> Any:
rv = self._freeze_weight_impl(normalized_shape)
rv.pop('normalized_shape')
return rv
def _freeze_weight_impl(self, normalized_shape, **kwargs) -> Any:
if isinstance(normalized_shape, dict):
normalized_shape = self.normalized_shape
@ -595,7 +608,7 @@ class MixedLayerNorm(MixedOperation, nn.LayerNorm):
if any(isinstance(arg, dict) for arg in [eps]):
raise ValueError(_diff_not_compatible_error.format('eps', 'LayerNorm'))
params_mapping = self.slice_param(normalized_shape)
params_mapping = self._freeze_weight_impl(normalized_shape)
weight, bias, normalized_shape = [
params_mapping.get(name)
for name in ['weight', 'bias', 'normalized_shape']
@ -633,7 +646,7 @@ class MixedMultiHeadAttention(MixedOperation, nn.MultiheadAttention):
All candidates of ``embed_dim`` should be divisible by all candidates of ``num_heads``.
"""
bound_type = nas_nn.MultiheadAttention
bound_type = MutableMultiheadAttention
argument_list = ['embed_dim', 'num_heads', 'kdim', 'vdim', 'dropout']
def __post_init__(self):
@ -671,8 +684,17 @@ class MixedMultiHeadAttention(MixedOperation, nn.MultiheadAttention):
nn.init.xavier_uniform_(self.k_proj_weight)
nn.init.xavier_uniform_(self.v_proj_weight)
def super_init_argument(self, name: str, value_choice: ValueChoiceX):
return max(traverse_all_options(value_choice))
def super_init_argument(self, name: str, mutable_expr: MutableExpression):
return max(traverse_all_options(mutable_expr))
def freeze_weight(self, embed_dim, kdim, vdim, **kwargs):
rv = self._freeze_weight_impl(embed_dim, kdim, vdim, **kwargs)
# pop flags and nones, as they won't show in state dict
rv.pop('qkv_same_embed_dim')
for k in list(rv):
if rv[k] is None:
rv.pop(k)
return rv
def _to_proj_slice(self, embed_dim: _W) -> list[slice]:
# slice three parts, corresponding to q, k, v respectively
@ -682,7 +704,7 @@ class MixedMultiHeadAttention(MixedOperation, nn.MultiheadAttention):
slice(self.embed_dim * 2, self.embed_dim * 2 + embed_dim)
]
def slice_param(self, embed_dim, kdim, vdim, **kwargs):
def _freeze_weight_impl(self, embed_dim, kdim, vdim, **kwargs):
# by default, kdim, vdim can be none
if kdim is None:
kdim = embed_dim
@ -723,31 +745,6 @@ class MixedMultiHeadAttention(MixedOperation, nn.MultiheadAttention):
'qkv_same_embed_dim': qkv_same_embed_dim
}
def _save_param_buff_to_state_dict(self, destination, prefix, keep_vars):
kwargs = {name: self.forward_argument(name) for name in self.argument_list}
params_mapping = self.slice_param(**kwargs, mapping=True)
for name, value in itertools.chain(self._parameters.items(), self._buffers.items()):
if value is None or name in self._non_persistent_buffers_set:
continue
value = params_mapping.get(name, value)
destination[prefix + name] = value if keep_vars else value.detach()
# params of out_proj is handled in ``MixedMultiHeadAttention`` rather than
# ``NonDynamicallyQuantizableLinear`` sub-module. We also convert it to state dict here.
for name in ["out_proj.weight", "out_proj.bias"]:
value = params_mapping.get(name, None)
if value is None:
continue
destination[prefix + name] = value if keep_vars else value.detach()
def _save_module_to_state_dict(self, destination, prefix, keep_vars):
for name, module in self._modules.items():
# the weights of ``NonDynamicallyQuantizableLinear`` has been handled in `_save_param_buff_to_state_dict`.
if isinstance(module, nn.modules.linear.NonDynamicallyQuantizableLinear):
continue
if module is not None:
sub_state_dict(module, destination=destination, prefix=prefix + name + '.', keep_vars=keep_vars)
def forward_with_args(
self,
embed_dim: int_or_int_dict, num_heads: int,
@ -770,7 +767,7 @@ class MixedMultiHeadAttention(MixedOperation, nn.MultiheadAttention):
else:
used_embed_dim = embed_dim
params_mapping = self.slice_param(embed_dim, kdim, vdim)
params_mapping = self._freeze_weight_impl(embed_dim, kdim, vdim)
in_proj_bias, in_proj_weight, bias_k, bias_v, \
out_proj_weight, out_proj_bias, q_proj, k_proj, v_proj, qkv_same_embed_dim = [
params_mapping.get(name)
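
The mixed operations above share one pattern: construct the super-kernel with the largest candidate value of each argument (``super_init_argument``), then slice the weights for the currently sampled arguments (``freeze_weight`` / ``_freeze_weight_impl``). A stand-alone sketch of that pattern for a linear layer, in plain PyTorch with illustrative dimensions, not the NNI implementation::

    import torch
    import torch.nn as nn
    import torch.nn.functional as F

    MAX_IN, MAX_OUT = 64, 128
    super_linear = nn.Linear(MAX_IN, MAX_OUT)   # super-kernel sized for the largest candidates


    def forward_subnet(x, in_features, out_features):
        """Run a sub-network that uses only a prefix slice of the super-kernel."""
        weight = super_linear.weight[:out_features, :in_features]
        bias = super_linear.bias[:out_features]
        return F.linear(x, weight, bias)


    y = forward_subnet(torch.randn(2, 32), in_features=32, out_features=96)   # shape (2, 96)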

View file

@ -6,19 +6,19 @@
from __future__ import annotations
import copy
import functools
import random
from typing import Any, List, Dict, Sequence, cast
import torch
import torch.nn as nn
from nni.common.hpo_utils import ParameterSpec
from nni.nas.nn.pytorch import LayerChoice, InputChoice, Repeat, ChoiceOf, Cell
from nni.nas.nn.pytorch.choice import ValueChoiceX
from nni.mutable import MutableExpression, label_scope, Mutable, Categorical, CategoricalMultiple
from nni.nas.nn.pytorch import LayerChoice, InputChoice, Repeat, Cell
from nni.nas.nn.pytorch.cell import CellOpFactory, create_cell_op_candidates, preprocess_cell_inputs
from .base import BaseSuperNetModule, sub_state_dict
from ._valuechoice_utils import evaluate_value_choice_with_dict, dedup_inner_choices, weighted_sum
from .base import BaseSuperNetModule
from ._expression_utils import weighted_sum
from .operation import MixedOperationSamplingPolicy, MixedOperation
__all__ = [
@ -28,7 +28,7 @@ __all__ = [
]
class PathSamplingLayer(BaseSuperNetModule):
class PathSamplingLayer(LayerChoice, BaseSuperNetModule):
"""
Mixed layer, in which fprop is decided by exactly one inner layer or sum of multiple (sampled) layers.
If multiple modules are selected, the result will be summed and returned.
@ -41,62 +41,43 @@ class PathSamplingLayer(BaseSuperNetModule):
Name of the choice.
"""
def __init__(self, paths: list[tuple[str, nn.Module]], label: str):
super().__init__()
self.op_names = []
for name, module in paths:
self.add_module(name, module)
self.op_names.append(name)
assert self.op_names, 'There has to be at least one op to choose from.'
def __init__(self, paths: dict[str, nn.Module] | list[nn.Module], label: str):
super().__init__(paths, label=label)
self._sampled: list[str] | str | None = None # sampled can be either a list of indices or an index
self.label = label
def resample(self, memo):
"""Random choose one path if label is not found in memo."""
if self.label in memo:
self._sampled = memo[self.label]
else:
self._sampled = random.choice(self.op_names)
self._sampled = self.choice.random()
return {self.label: self._sampled}
def export(self, memo):
"""Random choose one name if label isn't found in memo."""
if self.label in memo:
return {} # nothing new to export
return {self.label: random.choice(self.op_names)}
def search_space_spec(self):
return {self.label: ParameterSpec(self.label, 'choice', self.op_names, (self.label, ),
True, size=len(self.op_names))}
return {self.label: self.choice.random()}
@classmethod
def mutate(cls, module, name, memo, mutate_kwargs):
if isinstance(module, LayerChoice):
return cls(list(module.named_children()), module.label)
if type(module) is LayerChoice:
return cls(module.candidates, module.label)
def reduction(self, items: list[Any], sampled: list[Any]):
def _reduction(self, items: list[Any], sampled: list[Any]):
"""Override this to implement customized reduction."""
return weighted_sum(items)
def _save_module_to_state_dict(self, destination, prefix, keep_vars):
sampled = [self._sampled] if not isinstance(self._sampled, list) else self._sampled
for samp in sampled:
module = getattr(self, str(samp))
if module is not None:
sub_state_dict(module, destination=destination, prefix=prefix, keep_vars=keep_vars)
def forward(self, *args, **kwargs):
if self._sampled is None:
raise RuntimeError('At least one path needs to be sampled before fprop.')
sampled = [self._sampled] if not isinstance(self._sampled, list) else self._sampled
# str(samp) is needed here because samp can sometimes be integers, but attr are always str
res = [getattr(self, str(samp))(*args, **kwargs) for samp in sampled]
return self.reduction(res, sampled)
res = [self[samp](*args, **kwargs) for samp in sampled]
return self._reduction(res, sampled)
class PathSamplingInput(BaseSuperNetModule):
class PathSamplingInput(InputChoice, BaseSuperNetModule):
"""
Mixed input. Take a list of tensors as input, select some of them and return the sum.
@ -106,22 +87,9 @@ class PathSamplingInput(BaseSuperNetModule):
Sampled input indices.
"""
def __init__(self, n_candidates: int, n_chosen: int, reduction_type: str, label: str):
super().__init__()
self.n_candidates = n_candidates
self.n_chosen = n_chosen
self.reduction_type = reduction_type
def __init__(self, n_candidates: int, n_chosen: int, reduction: str, label: str):
super().__init__(n_candidates, n_chosen=n_chosen, reduction=reduction, label=label)
self._sampled: list[int] | int | None = None
self.label = label
def _random_choose_n(self):
sampling = list(range(self.n_candidates))
random.shuffle(sampling)
sampling = sorted(sampling[:self.n_chosen])
if len(sampling) == 1:
return sampling[0]
else:
return sampling
def resample(self, memo):
"""Random choose one path / multiple paths if label is not found in memo.
@ -131,42 +99,36 @@ class PathSamplingInput(BaseSuperNetModule):
if self.label in memo:
self._sampled = memo[self.label]
else:
self._sampled = self._random_choose_n()
self._sampled = self.choice.random()
return {self.label: self._sampled}
def export(self, memo):
"""Random choose one name if label isn't found in memo."""
if self.label in memo:
return {} # nothing new to export
return {self.label: self._random_choose_n()}
def search_space_spec(self):
return {
self.label: ParameterSpec(self.label, 'choice', list(range(self.n_candidates)),
(self.label, ), True, size=self.n_candidates, chosen_size=self.n_chosen)
}
return {self.label: self.choice.random()}
@classmethod
def mutate(cls, module, name, memo, mutate_kwargs):
if isinstance(module, InputChoice):
if type(module) is InputChoice:
if module.reduction not in ['sum', 'mean', 'concat']:
raise ValueError('Only input choice of sum/mean/concat reduction is supported.')
if module.n_chosen is None:
raise ValueError('n_chosen is None is not supported yet.')
return cls(module.n_candidates, module.n_chosen, module.reduction, module.label)
def reduction(self, items: list[Any], sampled: list[Any]) -> Any:
def _reduction(self, items: list[Any], sampled: list[Any]) -> Any:
"""Override this to implement customized reduction."""
if len(items) == 1:
return items[0]
else:
if self.reduction_type == 'sum':
if self.reduction == 'sum':
return sum(items)
elif self.reduction_type == 'mean':
elif self.reduction == 'mean':
return sum(items) / len(items)
elif self.reduction_type == 'concat':
elif self.reduction == 'concat':
return torch.cat(items, 1)
raise ValueError(f'Unsupported reduction type: {self.reduction_type}')
raise ValueError(f'Unsupported reduction type: {self.reduction}')
def forward(self, input_tensors):
if self._sampled is None:
@ -175,7 +137,7 @@ class PathSamplingInput(BaseSuperNetModule):
raise ValueError(f'Expect {self.n_candidates} input tensors, found {len(input_tensors)}.')
sampled = [self._sampled] if not isinstance(self._sampled, list) else self._sampled
res = [input_tensors[samp] for samp in sampled]
return self.reduction(res, sampled)
return self._reduction(res, sampled)
class MixedOpPathSamplingPolicy(MixedOperationSamplingPolicy):
@ -193,28 +155,28 @@ class MixedOpPathSamplingPolicy(MixedOperationSamplingPolicy):
def resample(self, operation: MixedOperation, memo: dict[str, Any]) -> dict[str, Any]:
"""Random sample for each leaf value choice."""
result = {}
space_spec = operation.search_space_spec()
for label in space_spec:
space_spec = operation.simplify()
for label, mutable in space_spec.items():
if label in memo:
result[label] = memo[label]
else:
result[label] = random.choice(space_spec[label].values)
result[label] = mutable.random()
# composits to kwargs
# composites to kwargs
# example: result = {"exp_ratio": 3}, self._sampled = {"in_channels": 48, "out_channels": 96}
self._sampled = {}
for key, value in operation.mutable_arguments.items():
self._sampled[key] = evaluate_value_choice_with_dict(value, result)
self._sampled[key] = value.freeze(result)
return result
def export(self, operation: MixedOperation, memo: dict[str, Any]) -> dict[str, Any]:
"""Export is also random for each leaf value choice."""
result = {}
space_spec = operation.search_space_spec()
for label in space_spec:
space_spec = operation.simplify()
for label, mutable in space_spec.items():
if label not in memo:
result[label] = random.choice(space_spec[label].values)
result[label] = mutable.random()
return result
def forward_argument(self, operation: MixedOperation, name: str) -> Any:
@ -226,9 +188,9 @@ class MixedOpPathSamplingPolicy(MixedOperationSamplingPolicy):
return operation.init_arguments[name]
class PathSamplingRepeat(BaseSuperNetModule):
class PathSamplingRepeat(Repeat, BaseSuperNetModule):
"""
Implementaion of Repeat in a path-sampling supernet.
Implementation of Repeat in a path-sampling supernet.
Samples one / some of the prefixes of the repeated blocks.
Attributes
@ -237,56 +199,43 @@ class PathSamplingRepeat(BaseSuperNetModule):
Sampled depth.
"""
def __init__(self, blocks: list[nn.Module], depth: ChoiceOf[int]):
super().__init__()
self.blocks: Any = blocks
self.depth = depth
self._space_spec: dict[str, ParameterSpec] = dedup_inner_choices([depth])
def __init__(self, blocks: list[nn.Module], depth: MutableExpression[int]):
super().__init__(blocks, depth)
self._sampled: list[int] | int | None = None
def resample(self, memo):
"""Since depth is based on ValueChoice, we only need to randomly sample every leaf value choices."""
result = {}
for label in self._space_spec:
assert isinstance(self.depth_choice, Mutable)
for label, mutable in self.depth_choice.simplify().items():
if label in memo:
result[label] = memo[label]
else:
result[label] = random.choice(self._space_spec[label].values)
result[label] = mutable.random()
self._sampled = evaluate_value_choice_with_dict(self.depth, result)
self._sampled = self.depth_choice.freeze(result)
return result
def export(self, memo):
"""Random choose one if every choice not in memo."""
result = {}
for label in self._space_spec:
assert isinstance(self.depth_choice, Mutable)
for label, mutable in self.depth_choice.simplify().items():
if label not in memo:
result[label] = random.choice(self._space_spec[label].values)
result[label] = mutable.random()
return result
def search_space_spec(self):
return self._space_spec
@classmethod
def mutate(cls, module, name, memo, mutate_kwargs):
if isinstance(module, Repeat) and isinstance(module.depth_choice, ValueChoiceX):
if type(module) == Repeat and isinstance(module.depth_choice, MutableExpression):
# Only interesting when depth is mutable
return cls(cast(List[nn.Module], module.blocks), module.depth_choice)
return cls(list(module.blocks), module.depth_choice)
def reduction(self, items: list[Any], sampled: list[Any]):
def _reduction(self, items: list[Any], sampled: list[Any]):
"""Override this to implement customized reduction."""
return weighted_sum(items)
def _save_module_to_state_dict(self, destination, prefix, keep_vars):
sampled: Any = [self._sampled] if not isinstance(self._sampled, list) else self._sampled
for cur_depth, (name, module) in enumerate(self.blocks.named_children(), start=1):
if module is not None:
sub_state_dict(module, destination=destination, prefix=prefix + name + '.', keep_vars=keep_vars)
if not any(d > cur_depth for d in sampled):
break
def forward(self, x):
if self._sampled is None:
raise RuntimeError('At least one depth needs to be sampled before fprop.')
@ -299,7 +248,7 @@ class PathSamplingRepeat(BaseSuperNetModule):
res.append(x)
if not any(d > cur_depth for d in sampled):
break
return self.reduction(res, sampled)
return self._reduction(res, sampled)
class PathSamplingCell(BaseSuperNetModule):
@ -345,39 +294,46 @@ class PathSamplingCell(BaseSuperNetModule):
# InputChoice is implicit in this graph.
for i in self.output_node_indices:
self.ops.append(nn.ModuleList())
for k in range(i + self.num_predecessors):
for k in range(i):
# Second argument in (i, **0**, k) is always 0.
# One-shot strategy can't handle the cases where op spec is dependent on `op_index`.
ops, _ = create_cell_op_candidates(op_factory, i, 0, k)
self.op_names = list(ops.keys())
cast(nn.ModuleList, self.ops[-1]).append(nn.ModuleDict(ops))
self.label = label
with label_scope(label) as self.label_scope:
for i in range(self.num_predecessors, self.num_nodes + self.num_predecessors):
for k in range(self.num_ops_per_node):
op_label = f'op_{i}_{k}'
input_label = f'input_{i}_{k}'
self.add_mutable(Categorical(self.op_names, label=op_label))
# Need multiple here to align with the original cell.
self.add_mutable(CategoricalMultiple(range(i), n_chosen=1, label=input_label))
self._sampled: dict[str, str | int] = {}
def search_space_spec(self) -> dict[str, ParameterSpec]:
# TODO: Recreating the space here.
# The spec should be moved to definition of Cell itself.
space_spec = {}
for i in range(self.num_predecessors, self.num_nodes + self.num_predecessors):
for k in range(self.num_ops_per_node):
op_label = f'{self.label}/op_{i}_{k}'
input_label = f'{self.label}/input_{i}_{k}'
space_spec[op_label] = ParameterSpec(op_label, 'choice', self.op_names, (op_label,), True, size=len(self.op_names))
space_spec[input_label] = ParameterSpec(input_label, 'choice', list(range(i)), (input_label, ), True, size=i)
return space_spec
@property
def label(self) -> str:
return self.label_scope.name
def freeze(self, sample):
raise NotImplementedError('PathSamplingCell does not support freeze.')
def resample(self, memo):
"""Random choose one path if label is not found in memo."""
self._sampled = {}
new_sampled = {}
for label, param_spec in self.search_space_spec().items():
for label, param_spec in self.simplify().items():
if label in memo:
assert not isinstance(memo[label], list), 'Multi-path sampling is currently unsupported on cell.'
if isinstance(memo[label], list) and len(memo[label]) > 1:
raise ValueError(f'Multi-path sampling is currently unsupported on cell: {memo[label]}')
self._sampled[label] = memo[label]
else:
self._sampled[label] = new_sampled[label] = random.choice(param_spec.values)
if isinstance(param_spec, Categorical):
self._sampled[label] = new_sampled[label] = random.choice(param_spec.values)
elif isinstance(param_spec, CategoricalMultiple):
assert param_spec.n_chosen == 1
self._sampled[label] = new_sampled[label] = [random.choice(param_spec.values)]
return new_sampled
def export(self, memo):
@@ -392,7 +348,7 @@ class PathSamplingCell(BaseSuperNetModule):
for k in range(self.num_ops_per_node):
# Select op list based on the input chosen
input_index = self._sampled[f'{self.label}/input_{i}_{k}']
input_index = self._sampled[f'{self.label}/input_{i}_{k}'][0] # [0] because it's a list and n_chosen=1
op_candidates = ops[cast(int, input_index)]
# Select op from op list based on the op chosen
op_index = self._sampled[f'{self.label}/op_{i}_{k}']
@@ -411,7 +367,7 @@ class PathSamplingCell(BaseSuperNetModule):
Mutate only handles cells of specific configurations (e.g., with the loose-end merge op).
Fall back to the default mutate if the cell is not handled here.
"""
if isinstance(module, Cell):
if type(module) is Cell:
op_factory = None # not all the cells need to be replaced
if module.op_candidates_factory is not None:
op_factory = module.op_candidates_factory
@@ -420,10 +376,16 @@ class PathSamplingCell(BaseSuperNetModule):
elif module.merge_op == 'loose_end':
op_candidates_lc = module.ops[-1][-1] # type: ignore
assert isinstance(op_candidates_lc, LayerChoice)
op_factory = { # create a factory
name: lambda _, __, ___: copy.deepcopy(op_candidates_lc[name])
for name in op_candidates_lc.names
}
candidates = op_candidates_lc.candidates
def _copy(_, __, ___, op):
return copy.deepcopy(op)
if isinstance(candidates, list):
op_factory = [functools.partial(_copy, op=op) for op in candidates]
elif isinstance(candidates, dict):
op_factory = {name: functools.partial(_copy, op=op) for name, op in candidates.items()}
else:
raise ValueError(f'Unsupported type of candidates: {type(candidates)}')
if op_factory is not None:
return cls(
op_factory,

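A side note on the op_factory change above: the old dict comprehension built lambdas that all closed over the loop variable `name`, so every factory would have produced a copy of the last candidate; binding the op eagerly with functools.partial avoids that. A standalone sketch of the difference (illustrative only, not part of this file):

import functools

ops = {'conv': 'CONV', 'pool': 'POOL'}

# Late binding: each lambda looks up `name` when called, after the loop has finished.
broken = {name: (lambda: ops[name]) for name in ops}
assert broken['conv']() == 'POOL'   # every entry resolves to the last key

# functools.partial captures the value at creation time, so each factory is distinct.
fixed = {name: functools.partial(lambda op: op, op=op) for name, op in ops.items()}
assert fixed['conv']() == 'CONV'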
View file

@@ -0,0 +1,633 @@
import pytest
import numpy as np
import torch
import torch.nn as nn
import nni
from nni.mutable import MutableList, frozen
from nni.nas.nn.pytorch import LayerChoice, ModelSpace, Cell, MutableConv2d, MutableBatchNorm2d, MutableLayerNorm, MutableLinear, MutableMultiheadAttention
from nni.nas.oneshot.pytorch.differentiable import DartsLightningModule
from nni.nas.oneshot.pytorch.strategy import RandomOneShot, DARTS
from nni.nas.oneshot.pytorch.supermodule.base import BaseSuperNetModule
from nni.nas.oneshot.pytorch.supermodule.differentiable import (
MixedOpDifferentiablePolicy, DifferentiableMixedLayer, DifferentiableMixedInput, GumbelSoftmax,
DifferentiableMixedRepeat, DifferentiableMixedCell
)
from nni.nas.oneshot.pytorch.supermodule.sampling import (
MixedOpPathSamplingPolicy, PathSamplingLayer, PathSamplingInput, PathSamplingRepeat, PathSamplingCell
)
from nni.nas.oneshot.pytorch.supermodule.operation import MixedConv2d, NATIVE_MIXED_OPERATIONS
from nni.nas.oneshot.pytorch.supermodule.proxyless import ProxylessMixedLayer, ProxylessMixedInput
from nni.nas.oneshot.pytorch.supermodule._operation_utils import Slicable as S, MaybeWeighted as W
from nni.nas.oneshot.pytorch.supermodule._expression_utils import *
from ut.nas.nn.models import (
CellSimple, CellDefaultArgs, CellCustomProcessor, CellLooseEnd, CellOpFactory
)
@pytest.fixture(autouse=True)
def context():
frozen._ENSURE_FROZEN_STRICT = False
yield
frozen._ENSURE_FROZEN_STRICT = True
def test_slice():
weight = np.ones((3, 7, 24, 23))
assert S(weight)[:, 1:3, :, 9:13].shape == (3, 2, 24, 4)
assert S(weight)[:, 1:W(3)*2+1, :, 9:13].shape == (3, 6, 24, 4)
assert S(weight)[:, 1:W(3)*2+1].shape == (3, 6, 24, 23)
# Ellipsis
assert S(weight)[..., 9:13].shape == (3, 7, 24, 4)
assert S(weight)[:2, ..., 1:W(3)+1].shape == (2, 7, 24, 3)
assert S(weight)[..., 1:W(3)*2+1].shape == (3, 7, 24, 6)
assert S(weight)[..., :10, 1:W(3)*2+1].shape == (3, 7, 10, 6)
# no effect
assert S(weight)[:] is weight
# list
assert S(weight)[[slice(1), slice(2, 3)]].shape == (2, 7, 24, 23)
assert S(weight)[[slice(1), slice(2, W(2) + 1)], W(2):].shape == (2, 5, 24, 23)
# weighted
weight = S(weight)[:W({1: 0.5, 2: 0.3, 3: 0.2})]
weight = weight[:, 0, 0, 0]
assert weight[0] == 1 and weight[1] == 0.5 and weight[2] == 0.2
weight = np.ones((3, 6, 6))
value = W({1: 0.5, 3: 0.5})
weight = S(weight)[:, 3 - value:3 + value, 3 - value:3 + value]
for i in range(0, 6):
for j in range(0, 6):
if 2 <= i <= 3 and 2 <= j <= 3:
assert weight[0, i, j] == 1
else:
assert weight[1, i, j] == 0.5
# weighted + list
value = W({1: 0.5, 3: 0.5})
weight = np.ones((8, 4))
weight = S(weight)[[slice(value), slice(4, value + 4)]]
assert weight.sum(1).tolist() == [4, 2, 2, 0, 4, 2, 2, 0]
with pytest.raises(ValueError, match='one distinct'):
# has to be exactly the same instance, equal is not enough
weight = S(weight)[:W({1: 0.5}), : W({1: 0.5})]
def test_valuechoice_utils():
chosen = {"exp": 3, "add": 1}
vc0 = nni.choice('exp', [3, 4, 6]) * 2 + nni.choice('add', [0, 1])
assert vc0.freeze(chosen) == 7
vc = vc0 + nni.choice('exp', [3, 4, 6])
assert vc.freeze(chosen) == 10
assert list(MutableList([vc0, vc]).simplify().keys()) == ['exp', 'add']
assert traverse_all_options(vc) == [9, 10, 12, 13, 18, 19]
weights = dict(traverse_all_options(vc, weights={'exp': [0.5, 0.3, 0.2], 'add': [0.4, 0.6]}))
ans = dict([(9, 0.2), (10, 0.3), (12, 0.12), (13, 0.18), (18, 0.08), (19, 0.12)])
assert len(weights) == len(ans)
for value, weight in ans.items():
assert abs(weight - weights[value]) < 1e-6
assert evaluate_constant(nni.choice('x', [3, 4, 6]) - nni.choice('x', [3, 4, 6])) == 0
with pytest.raises(ValueError):
evaluate_constant(nni.choice('x', [3, 4, 6]) - nni.choice('y', [3, 4, 6]))
assert evaluate_constant(nni.choice('x', [3, 4, 6]) * 2 / nni.choice('x', [3, 4, 6])) == 2
def test_expectation():
vc = nni.choice('exp', [3, 4, 6]) * 2 + nni.choice('add', [0, 1])
assert expression_expectation(vc, {'exp': [0.5, 0.3, 0.2], 'add': [0.4, 0.6]}) == 8.4
vc = sum([nni.choice(f'e{i}', [0, 1]) for i in range(100)])
assert expression_expectation(vc, {f'e{i}': [0.5] * 2 for i in range(100)}) == 50
vc = nni.choice('a', [1, 2, 3]) * nni.choice('b', [1, 2, 3]) - nni.choice('c', [1, 2, 3])
probs1 = [0.2, 0.3, 0.5]
probs2 = [0.1, 0.2, 0.7]
probs3 = [0.3, 0.4, 0.3]
expect = sum(
(i * j - k) * p1 * p2 * p3
for i, p1 in enumerate(probs1, 1)
for j, p2 in enumerate(probs2, 1)
for k, p3 in enumerate(probs3, 1)
)
assert abs(expression_expectation(vc, {'a': probs1, 'b': probs2, 'c': probs3}) - expect) < 1e-12
vc = nni.choice('a', [1, 2, 3]) + 1
assert expression_expectation(vc, {'a': [0.2, 0.3, 0.5]}) == 3.3
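The expected values asserted above follow from linearity of expectation; the first and last cases written out as a check:

# E[2*exp + add] = 2*(3*0.5 + 4*0.3 + 6*0.2) + (0*0.4 + 1*0.6) = 2*3.9 + 0.6 = 8.4
# E[a + 1]       = (1*0.2 + 2*0.3 + 3*0.5) + 1 = 2.3 + 1 = 3.3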
def test_weighted_sum():
weights = [0.1, 0.2, 0.7]
items = [1, 2, 3]
assert abs(weighted_sum(items, weights) - 2.6) < 1e-6
assert weighted_sum(items) == 6
with pytest.raises(TypeError, match='Unsupported'):
weighted_sum(['a', 'b', 'c'], weights)
assert abs(weighted_sum(np.arange(3), weights).item() - 1.6) < 1e-6
items = [torch.full((2, 3, 5), i) for i in items]
assert abs(weighted_sum(items, weights).flatten()[0].item() - 2.6) < 1e-6
items = [torch.randn(2, 3, i) for i in [1, 2, 3]]
with pytest.raises(ValueError, match=r'does not match.*\n.*torch\.Tensor\(2, 3, 1\)'):
weighted_sum(items, weights)
items = [(1, 2), (3, 4), (5, 6)]
res = weighted_sum(items, weights)
assert len(res) == 2 and abs(res[0] - 4.2) < 1e-6 and abs(res[1] - 5.2) < 1e-6
items = [(1, 2), (3, 4), (5, 6, 7)]
with pytest.raises(ValueError):
weighted_sum(items, weights)
items = [{"a": i, "b": np.full((2, 3, 5), i)} for i in [1, 2, 3]]
res = weighted_sum(items, weights)
assert res['b'].shape == (2, 3, 5)
assert abs(res['b'][0][0][0] - res['a']) < 1e-6
assert abs(res['a'] - 2.6) < 1e-6
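The scalar and tuple cases above are plain dot products; written out:

# weighted_sum([1, 2, 3], [0.1, 0.2, 0.7]) = 1*0.1 + 2*0.2 + 3*0.7 = 2.6
# weighted_sum([(1, 2), (3, 4), (5, 6)], [0.1, 0.2, 0.7])
#   = (1*0.1 + 3*0.2 + 5*0.7, 2*0.1 + 4*0.2 + 6*0.7) = (4.2, 5.2)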
def test_pathsampling_valuechoice():
orig_conv = MutableConv2d(3, nni.choice('123', [3, 5, 7]), kernel_size=3)
conv = MixedConv2d.mutate(orig_conv, 'dummy', {}, {'mixed_op_sampling': MixedOpPathSamplingPolicy})
conv.resample(memo={'123': 5})
assert conv(torch.zeros((1, 3, 5, 5))).size(1) == 5
conv.resample(memo={'123': 7})
assert conv(torch.zeros((1, 3, 5, 5))).size(1) == 7
assert conv.export({})['123'] in [3, 5, 7]
def test_differentiable_valuechoice():
orig_conv = MutableConv2d(3, nni.choice('456', [3, 5, 7]),
kernel_size=nni.choice('123', [3, 5, 7]),
padding=nni.choice('123', [3, 5, 7]) // 2
)
memo = {
'123': nn.Parameter(torch.zeros(3)),
'456': nn.Parameter(torch.zeros(3)),
}
conv = MixedConv2d.mutate(orig_conv, 'dummy', memo, {'mixed_op_sampling': MixedOpDifferentiablePolicy})
assert conv(torch.zeros((1, 3, 7, 7))).size(2) == 7
assert set(conv.export({}).keys()) == {'123', '456'}
def test_differentiable_layerchoice_dedup():
layerchoice1 = LayerChoice([MutableConv2d(3, 3, 3), MutableConv2d(3, 3, 3)], label='a')
layerchoice2 = LayerChoice([MutableConv2d(3, 3, 3), MutableConv2d(3, 3, 3)], label='a')
memo = {'a': nn.Parameter(torch.zeros(2))}
DifferentiableMixedLayer.mutate(layerchoice1, 'x', memo, {})
DifferentiableMixedLayer.mutate(layerchoice2, 'x', memo, {})
assert len(memo) == 1 and 'a' in memo
def _mutate_op_path_sampling_policy(operation):
for native_op in NATIVE_MIXED_OPERATIONS:
if native_op.bound_type == type(operation):
mutate_op = native_op.mutate(operation, 'dummy', {}, {'mixed_op_sampling': MixedOpPathSamplingPolicy})
break
return mutate_op
def _mixed_operation_sampling_sanity_check(operation, memo, *input):
mutate_op = _mutate_op_path_sampling_policy(operation)
mutate_op.resample(memo=memo)
return mutate_op(*input)
def _mixed_operation_state_dict_sanity_check(operation, model, memo, *input):
mutate_op = _mutate_op_path_sampling_policy(operation)
mutate_op.resample(memo=memo)
frozen_op = mutate_op.freeze(memo)
return frozen_op(*input), mutate_op(*input)
def _mixed_operation_differentiable_sanity_check(operation, *input):
memo = {k: nn.Parameter(torch.zeros(len(v))) for k, v in operation.simplify().items()}
for native_op in NATIVE_MIXED_OPERATIONS:
if native_op.bound_type == type(operation):
mutate_op = native_op.mutate(operation, 'dummy', memo, {'mixed_op_sampling': MixedOpDifferentiablePolicy})
break
mutate_op(*input)
mutate_op.export({})
mutate_op.export_probs({})
def test_mixed_linear():
linear = MutableLinear(nni.choice('shared', [3, 6, 9]), nni.choice('xx', [2, 4, 8]))
_mixed_operation_sampling_sanity_check(linear, {'shared': 3}, torch.randn(2, 3))
_mixed_operation_sampling_sanity_check(linear, {'shared': 9}, torch.randn(2, 9))
_mixed_operation_differentiable_sanity_check(linear, torch.randn(2, 9))
linear = MutableLinear(nni.choice('shared', [3, 6, 9]), nni.choice('xx', [2, 4, 8]), bias=False)
_mixed_operation_sampling_sanity_check(linear, {'shared': 3}, torch.randn(2, 3))
with pytest.raises(TypeError):
linear = MutableLinear(nni.choice('shared', [3, 6, 9]), nni.choice('xx', [2, 4, 8]), bias=nni.choice('yy', [False, True]))
_mixed_operation_sampling_sanity_check(linear, {'shared': 3}, torch.randn(2, 3))
linear = MutableLinear(nni.choice('in_features', [3, 6, 9]), nni.choice('out_features', [2, 4, 8]), bias=True)
kwargs = {'in_features': 6, 'out_features': 4}
out1, out2 = _mixed_operation_state_dict_sanity_check(linear, MutableLinear(**kwargs), kwargs, torch.randn(2, 6))
assert torch.allclose(out1, out2)
def test_mixed_conv2d():
conv = MutableConv2d(nni.choice('in', [3, 6, 9]), nni.choice('out', [2, 4, 8]) * 2, 1)
assert _mixed_operation_sampling_sanity_check(conv, {'in': 3, 'out': 4}, torch.randn(2, 3, 9, 9)).size(1) == 8
_mixed_operation_differentiable_sanity_check(conv, torch.randn(2, 9, 3, 3))
# stride
conv = MutableConv2d(nni.choice('in', [3, 6, 9]), nni.choice('out', [2, 4, 8]), 1, stride=nni.choice('stride', [1, 2]))
assert _mixed_operation_sampling_sanity_check(conv, {'in': 3, 'stride': 2}, torch.randn(2, 3, 10, 10)).size(2) == 5
assert _mixed_operation_sampling_sanity_check(conv, {'in': 3, 'stride': 1}, torch.randn(2, 3, 10, 10)).size(2) == 10
with pytest.raises(ValueError, match='must not be mutable'):
_mixed_operation_differentiable_sanity_check(conv, torch.randn(2, 9, 10, 10))
# groups, dw conv
conv = MutableConv2d(nni.choice('in', [3, 6, 9]), nni.choice('in', [3, 6, 9]), 1, groups=nni.choice('in', [3, 6, 9]))
assert _mixed_operation_sampling_sanity_check(conv, {'in': 6}, torch.randn(2, 6, 10, 10)).size() == torch.Size([2, 6, 10, 10])
# groups, invalid case
conv = MutableConv2d(nni.choice('in', [9, 6, 3]), nni.choice('in', [9, 6, 3]), 1, groups=9)
with pytest.raises(RuntimeError):
assert _mixed_operation_sampling_sanity_check(conv, {'in': 6}, torch.randn(2, 6, 10, 10))
# groups, differentiable
conv = MutableConv2d(nni.choice('in', [3, 6, 9]), nni.choice('out', [3, 6, 9]), 1, groups=nni.choice('in', [3, 6, 9]))
_mixed_operation_differentiable_sanity_check(conv, torch.randn(2, 9, 3, 3))
conv = MutableConv2d(nni.choice('in', [3, 6, 9]), nni.choice('in', [3, 6, 9]), 1, groups=nni.choice('in', [3, 6, 9]))
_mixed_operation_differentiable_sanity_check(conv, torch.randn(2, 9, 3, 3))
with pytest.raises(ValueError):
conv = MutableConv2d(nni.choice('in', [3, 6, 9]), nni.choice('in', [3, 6, 9]), 1, groups=nni.choice('groups', [3, 9]))
_mixed_operation_differentiable_sanity_check(conv, torch.randn(2, 9, 3, 3))
with pytest.raises(RuntimeError):
conv = MutableConv2d(nni.choice('in', [3, 6, 9]), nni.choice('in', [3, 6, 9]), 1, groups=nni.choice('in', [3, 6, 9]) // 3)
_mixed_operation_differentiable_sanity_check(conv, torch.randn(2, 10, 3, 3))
# make sure kernel is sliced correctly
conv = MutableConv2d(1, 1, nni.choice('k', [1, 3]), bias=False)
conv = MixedConv2d.mutate(conv, 'dummy', {}, {'mixed_op_sampling': MixedOpPathSamplingPolicy})
with torch.no_grad():
conv.weight.zero_()
# only center is 1, must pick center to pass this test
conv.weight[0, 0, 1, 1] = 1
conv.resample({'k': 1})
assert conv(torch.ones((1, 1, 3, 3))).sum().item() == 9
# only `in_channels`, `out_channels`, `kernel_size`, and `groups` influence state_dict
conv = MutableConv2d(
nni.choice('in_channels', [2, 4, 8]), nni.choice('out_channels', [6, 12, 24]),
kernel_size=nni.choice('kernel_size', [3, 5, 7]), groups=nni.choice('groups', [1, 2])
)
kwargs = {
'in_channels': 8, 'out_channels': 12,
'kernel_size': 5, 'groups': 2
}
out1, out2 = _mixed_operation_state_dict_sanity_check(conv, MutableConv2d(**kwargs), kwargs, torch.randn(2, 8, 16, 16))
assert torch.allclose(out1, out2)
def test_mixed_batchnorm2d():
bn = MutableBatchNorm2d(nni.choice('dim', [32, 64]))
assert _mixed_operation_sampling_sanity_check(bn, {'dim': 32}, torch.randn(2, 32, 3, 3)).size(1) == 32
assert _mixed_operation_sampling_sanity_check(bn, {'dim': 64}, torch.randn(2, 64, 3, 3)).size(1) == 64
_mixed_operation_differentiable_sanity_check(bn, torch.randn(2, 64, 3, 3))
bn = MutableBatchNorm2d(nni.choice('num_features', [32, 48, 64]))
kwargs = {'num_features': 48}
out1, out2 = _mixed_operation_state_dict_sanity_check(bn, MutableBatchNorm2d(**kwargs), kwargs, torch.randn(2, 48, 3, 3))
assert torch.allclose(out1, out2)
def test_mixed_layernorm():
ln = MutableLayerNorm(nni.choice('normalized_shape', [32, 64]), elementwise_affine=True)
assert _mixed_operation_sampling_sanity_check(ln, {'normalized_shape': 32}, torch.randn(2, 16, 32)).size(-1) == 32
assert _mixed_operation_sampling_sanity_check(ln, {'normalized_shape': 64}, torch.randn(2, 16, 64)).size(-1) == 64
_mixed_operation_differentiable_sanity_check(ln, torch.randn(2, 16, 64))
import itertools
ln = MutableLayerNorm(nni.choice('normalized_shape', list(itertools.product([16, 32, 64], [8, 16]))))
assert list(_mixed_operation_sampling_sanity_check(ln, {'normalized_shape': (16, 8)}, torch.randn(2, 16, 8)).shape[-2:]) == [16, 8]
assert list(_mixed_operation_sampling_sanity_check(ln, {'normalized_shape': (64, 16)}, torch.randn(2, 64, 16)).shape[-2:]) == [64, 16]
_mixed_operation_differentiable_sanity_check(ln, torch.randn(2, 64, 16))
ln = MutableLayerNorm(nni.choice('normalized_shape', [32, 48, 64]))
kwargs = {'normalized_shape': 48}
out1, out2 = _mixed_operation_state_dict_sanity_check(ln, MutableLayerNorm(**kwargs), kwargs, torch.randn(2, 8, 48))
assert torch.allclose(out1, out2)
def test_mixed_mhattn():
mhattn = MutableMultiheadAttention(nni.choice('emb', [4, 8]), 4)
assert _mixed_operation_sampling_sanity_check(mhattn, {'emb': 4},
torch.randn(7, 2, 4), torch.randn(7, 2, 4), torch.randn(7, 2, 4))[0].size(-1) == 4
assert _mixed_operation_sampling_sanity_check(mhattn, {'emb': 8},
torch.randn(7, 2, 8), torch.randn(7, 2, 8), torch.randn(7, 2, 8))[0].size(-1) == 8
_mixed_operation_differentiable_sanity_check(mhattn, torch.randn(7, 2, 8), torch.randn(7, 2, 8), torch.randn(7, 2, 8))
mhattn = MutableMultiheadAttention(nni.choice('emb', [4, 8]), nni.choice('heads', [2, 3, 4]))
assert _mixed_operation_sampling_sanity_check(mhattn, {'emb': 4, 'heads': 2},
torch.randn(7, 2, 4), torch.randn(7, 2, 4), torch.randn(7, 2, 4))[0].size(-1) == 4
with pytest.raises(AssertionError, match='divisible'):
assert _mixed_operation_sampling_sanity_check(mhattn, {'emb': 4, 'heads': 3},
torch.randn(7, 2, 4), torch.randn(7, 2, 4), torch.randn(7, 2, 4))[0].size(-1) == 4
mhattn = MutableMultiheadAttention(nni.choice('emb', [4, 8]), 4, kdim=nni.choice('kdim', [5, 7]))
assert _mixed_operation_sampling_sanity_check(mhattn, {'emb': 4, 'kdim': 7},
torch.randn(7, 2, 4), torch.randn(7, 2, 7), torch.randn(7, 2, 4))[0].size(-1) == 4
assert _mixed_operation_sampling_sanity_check(mhattn, {'emb': 8, 'kdim': 5},
torch.randn(7, 2, 8), torch.randn(7, 2, 5), torch.randn(7, 2, 8))[0].size(-1) == 8
mhattn = MutableMultiheadAttention(nni.choice('emb', [4, 8]), 4, vdim=nni.choice('vdim', [5, 8]))
assert _mixed_operation_sampling_sanity_check(mhattn, {'emb': 4, 'vdim': 8},
torch.randn(7, 2, 4), torch.randn(7, 2, 4), torch.randn(7, 2, 8))[0].size(-1) == 4
assert _mixed_operation_sampling_sanity_check(mhattn, {'emb': 8, 'vdim': 5},
torch.randn(7, 2, 8), torch.randn(7, 2, 8), torch.randn(7, 2, 5))[0].size(-1) == 8
_mixed_operation_differentiable_sanity_check(mhattn, torch.randn(5, 3, 8), torch.randn(5, 3, 8), torch.randn(5, 3, 8))
mhattn = MutableMultiheadAttention(embed_dim=nni.choice('embed_dim', [4, 8, 16]), num_heads=nni.choice('num_heads', [1, 2, 4]),
kdim=nni.choice('kdim', [4, 8, 16]), vdim=nni.choice('vdim', [4, 8, 16]))
kwargs = {'embed_dim': 16, 'num_heads': 2, 'kdim': 4, 'vdim': 8}
(out1, _), (out2, _) = _mixed_operation_state_dict_sanity_check(mhattn, MutableMultiheadAttention(**kwargs), kwargs, torch.randn(7, 2, 16), torch.randn(7, 2, 4), torch.randn(7, 2, 8))
assert torch.allclose(out1, out2)
@pytest.mark.skipif(torch.__version__.startswith('1.7'), reason='batch_first is not supported for legacy PyTorch')
def test_mixed_mhattn_batch_first():
# batch_first is not supported on legacy PyTorch versions
# mark 1.7 because 1.7 is what the legacy pipeline runs
mhattn = MutableMultiheadAttention(nni.choice('emb', [4, 8]), 2, kdim=(nni.choice('kdim', [3, 7])), vdim=nni.choice('vdim', [5, 8]),
bias=False, add_bias_kv=True, batch_first=True)
assert _mixed_operation_sampling_sanity_check(mhattn, {'emb': 4, 'kdim': 7, 'vdim': 8},
torch.randn(2, 7, 4), torch.randn(2, 7, 7), torch.randn(2, 7, 8))[0].size(-1) == 4
assert _mixed_operation_sampling_sanity_check(mhattn, {'emb': 8, 'kdim': 3, 'vdim': 5},
torch.randn(2, 7, 8), torch.randn(2, 7, 3), torch.randn(2, 7, 5))[0].size(-1) == 8
_mixed_operation_differentiable_sanity_check(mhattn, torch.randn(1, 7, 8), torch.randn(1, 7, 7), torch.randn(1, 7, 8))
def test_pathsampling_layer_input():
op = PathSamplingLayer({'a': MutableLinear(2, 3, bias=False), 'b': MutableLinear(2, 3, bias=True)}, label='ccc')
with pytest.raises(RuntimeError, match='sample'):
op(torch.randn(4, 2))
op.resample({})
assert op(torch.randn(4, 2)).size(-1) == 3
assert op.simplify()['ccc'].values == ['a', 'b']
assert op.export({})['ccc'] in ['a', 'b']
input = PathSamplingInput(5, 2, 'concat', 'ddd')
sample = input.resample({})
assert 'ddd' in sample
assert len(sample['ddd']) == 2
assert input([torch.randn(4, 2) for _ in range(5)]).size(-1) == 4
assert len(input.export({})['ddd']) == 2
def test_differentiable_layer_input():
op = DifferentiableMixedLayer({'a': MutableLinear(2, 3, bias=False), 'b': MutableLinear(2, 3, bias=True)}, nn.Parameter(torch.randn(2)), nn.Softmax(-1), 'eee')
assert op(torch.randn(4, 2)).size(-1) == 3
assert op.export({})['eee'] in ['a', 'b']
probs = op.export_probs({})
assert len(probs) == 1
assert len(probs['eee']) == 2
assert abs(probs['eee']['a'] + probs['eee']['b'] - 1) < 1e-4
assert len(list(op.parameters())) == 4
assert len(list(op.arch_parameters())) == 1
with pytest.raises(ValueError):
op = DifferentiableMixedLayer({'a': MutableLinear(2, 3), 'b': MutableLinear(2, 4)}, nn.Parameter(torch.randn(2)), nn.Softmax(-1), 'eee')
op(torch.randn(4, 2))
input = DifferentiableMixedInput(5, 2, nn.Parameter(torch.zeros(5)), GumbelSoftmax(-1), 'ddd')
assert input([torch.randn(4, 2) for _ in range(5)]).size(-1) == 2
assert len(input.export({})['ddd']) == 2
assert len(input.export_probs({})) == 1
assert len(input.export_probs({})['ddd']) == 5
assert 3 in input.export_probs({})['ddd']
def test_proxyless_layer_input():
op = ProxylessMixedLayer({'a': MutableLinear(2, 3, bias=False), 'b': MutableLinear(2, 3, bias=True)}, nn.Parameter(torch.randn(2)),
nn.Softmax(-1), 'eee')
assert op.resample({})['eee'] in ['a', 'b']
assert op(torch.randn(4, 2)).size(-1) == 3
assert op.export({})['eee'] in ['a', 'b']
assert len(list(op.parameters())) == 4
assert len(list(op.arch_parameters())) == 1
input = ProxylessMixedInput(5, 2, nn.Parameter(torch.zeros(5)), GumbelSoftmax(-1), 'ddd')
assert all(x in list(range(5)) for x in input.resample({})['ddd'])
assert input([torch.randn(4, 2) for _ in range(5)]).size() == torch.Size([4, 2])
exported = input.export({})['ddd']
assert len(exported) == 2 and all(e in list(range(5)) for e in exported)
def test_pathsampling_repeat():
op = PathSamplingRepeat([MutableLinear(16, 16), MutableLinear(16, 8), MutableLinear(8, 4)], nni.choice('ccc', [1, 2, 3]))
sample = op.resample({})
assert sample['ccc'] in [1, 2, 3]
for i in range(1, 4):
op.resample({'ccc': i})
out = op(torch.randn(2, 16))
assert out.shape[1] == [16, 8, 4][i - 1]
op = PathSamplingRepeat([MutableLinear(i + 1, i + 2) for i in range(7)], 2 * nni.choice('ddd', [1, 2, 3]) + 1)
sample = op.resample({})
assert sample['ddd'] in [1, 2, 3]
for i in range(1, 4):
op.resample({'ddd': i})
out = op(torch.randn(2, 1))
assert out.shape[1] == (2 * i + 1) + 1
def test_differentiable_repeat():
op = DifferentiableMixedRepeat(
[MutableLinear(8 if i == 0 else 16, 16) for i in range(4)],
nni.choice('ccc', [0, 1]) * 2 + 1,
GumbelSoftmax(-1),
{'ccc': nn.Parameter(torch.randn(2))},
)
op.resample({})
assert op(torch.randn(2, 8)).size() == torch.Size([2, 16])
sample = op.export({})
assert 'ccc' in sample and sample['ccc'] in [0, 1]
assert sorted(op.export_probs({})['ccc'].keys()) == [0, 1]
class TupleModule(nn.Module):
def __init__(self, num):
super().__init__()
self.num = num
def forward(self, *args, **kwargs):
return torch.full((2, 3), self.num), torch.full((3, 5), self.num), {'a': 7, 'b': [self.num] * 11}
class CustomSoftmax(nn.Softmax):
def forward(self, *args, **kwargs):
return [0.3, 0.3, 0.4]
op = DifferentiableMixedRepeat(
[TupleModule(i + 1) for i in range(4)],
nni.choice('ccc', [1, 2, 4]),
CustomSoftmax(),
{'ccc': nn.Parameter(torch.randn(3))},
)
op.resample({})
res = op(None)
assert len(res) == 3
assert res[0].shape == (2, 3) and res[0][0][0].item() == 2.5
assert res[2]['a'] == 7
assert len(res[2]['b']) == 11 and res[2]['b'][-1] == 2.5
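A worked check of the 2.5 asserted above: CustomSoftmax returns [0.3, 0.3, 0.4] over depths [1, 2, 4], and the depth-d output of the TupleModule chain is filled with d, so the mixture is

# 0.3 * 1 + 0.3 * 2 + 0.4 * 4 = 0.3 + 0.6 + 1.6 = 2.5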
def test_pathsampling_cell():
for cell_cls in [CellSimple, CellDefaultArgs, CellCustomProcessor, CellLooseEnd, CellOpFactory]:
model = cell_cls()
strategy = RandomOneShot()
model = strategy.mutate_model(model)
nas_modules = [m for m in model.modules() if isinstance(m, BaseSuperNetModule)]
result = {}
for module in nas_modules:
result.update(module.resample(memo=result))
assert len(result) == model.cell.num_nodes * model.cell.num_ops_per_node * 2
result = {}
for module in nas_modules:
result.update(module.export(memo=result))
assert len(result) == model.cell.num_nodes * model.cell.num_ops_per_node * 2
if cell_cls in [CellLooseEnd, CellOpFactory]:
assert isinstance(model.cell, PathSamplingCell)
else:
assert not isinstance(model.cell, PathSamplingCell)
inputs = {
CellSimple: (torch.randn(2, 16), torch.randn(2, 16)),
CellDefaultArgs: (torch.randn(2, 16),),
CellCustomProcessor: (torch.randn(2, 3), torch.randn(2, 16)),
CellLooseEnd: (torch.randn(2, 16), torch.randn(2, 16)),
CellOpFactory: (torch.randn(2, 3), torch.randn(2, 16)),
}[cell_cls]
output = model(*inputs)
if cell_cls == CellCustomProcessor:
assert isinstance(output, tuple) and len(output) == 2 and \
output[1].shape == torch.Size([2, 16 * model.cell.num_nodes])
else:
# no loose-end support for now
assert output.shape == torch.Size([2, 16 * model.cell.num_nodes])
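The factor of 2 in num_nodes * num_ops_per_node * 2 above counts the two decisions made per op slot, matching the mutables registered by PathSamplingCell earlier in this commit:

# per (node, op) slot: one Categorical over op names   ('<label>/op_i_k')
#                    + one CategoricalMultiple over predecessor indices ('<label>/input_i_k')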
def test_differentiable_cell():
for cell_cls in [CellSimple, CellDefaultArgs, CellCustomProcessor, CellLooseEnd, CellOpFactory]:
model = cell_cls()
strategy = DARTS()
model = strategy.mutate_model(model)
nas_modules = [m for m in model.modules() if isinstance(m, BaseSuperNetModule)]
result = {}
for module in nas_modules:
result.update(module.export(memo=result))
assert len(result) == model.cell.num_nodes * model.cell.num_ops_per_node * 2
for k, v in result.items():
if 'input' in k:
assert isinstance(v, list) and len(v) == 1
result_prob = {}
for module in nas_modules:
result_prob.update(module.export_probs(memo=result_prob))
ctrl_params = []
for m in nas_modules:
ctrl_params += list(m.arch_parameters())
if cell_cls in [CellLooseEnd, CellOpFactory]:
assert len(ctrl_params) == model.cell.num_nodes * (model.cell.num_nodes + 3) // 2
assert len(result_prob) == len(ctrl_params) # len(op_names) == 2
for v in result_prob.values():
assert len(v) == 2
assert isinstance(model.cell, DifferentiableMixedCell)
else:
assert not isinstance(model.cell, DifferentiableMixedCell)
inputs = {
CellSimple: (torch.randn(2, 16), torch.randn(2, 16)),
CellDefaultArgs: (torch.randn(2, 16),),
CellCustomProcessor: (torch.randn(2, 3), torch.randn(2, 16)),
CellLooseEnd: (torch.randn(2, 16), torch.randn(2, 16)),
CellOpFactory: (torch.randn(2, 3), torch.randn(2, 16)),
}[cell_cls]
output = model(*inputs)
if cell_cls == CellCustomProcessor:
assert isinstance(output, tuple) and len(output) == 2 and \
output[1].shape == torch.Size([2, 16 * model.cell.num_nodes])
else:
# no loose-end support for now
assert output.shape == torch.Size([2, 16 * model.cell.num_nodes])
def test_memo_sharing():
class TestModelSpace(ModelSpace):
def __init__(self):
super().__init__()
self.linear1 = Cell(
[nn.Linear(16, 16), nn.Linear(16, 16, bias=False)],
num_nodes=3, num_ops_per_node=2, num_predecessors=2, merge_op='loose_end',
label='cell'
)
self.linear2 = Cell(
[nn.Linear(16, 16), nn.Linear(16, 16, bias=False)],
num_nodes=3, num_ops_per_node=2, num_predecessors=2, merge_op='loose_end',
label='cell'
)
strategy = DARTS()
model = strategy.mutate_model(TestModelSpace())
assert model.linear1._arch_alpha['cell/2_0'] is model.linear2._arch_alpha['cell/2_0']
def test_parameters():
class Model(ModelSpace):
def __init__(self):
super().__init__()
self.op = DifferentiableMixedLayer(
{
'a': MutableLinear(2, 3, bias=False),
'b': MutableLinear(2, 3, bias=True)
},
nn.Parameter(torch.randn(2)), nn.Softmax(-1), 'abc'
)
def forward(self, x):
return self.op(x)
model = Model()
assert len(list(model.parameters())) == 4
assert len(list(model.op.arch_parameters())) == 1
optimizer = torch.optim.SGD(model.parameters(), 0.1)
assert len(DartsLightningModule(model).arch_parameters()) == 1
optimizer = DartsLightningModule(model).postprocess_weight_optimizers(optimizer)
assert len(optimizer.param_groups[0]['params']) == 3
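A breakdown of the counts in test_parameters, inferred from the assertions:

# model.parameters(): a.weight, b.weight, b.bias, and the architecture alpha -> 4
# op.arch_parameters(): alpha only                                           -> 1
# postprocess_weight_optimizers evidently drops alpha from the SGD group     -> 3 weights remain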