From d40f408a0a4fb57d758bd070ab3fa3c1a6ca59f9 Mon Sep 17 00:00:00 2001 From: Yuge Zhang Date: Fri, 3 Mar 2023 15:33:32 +0800 Subject: [PATCH] NAS oneshot (stage 2) - Supernet modules (#5372) --- nni/nas/oneshot/pytorch/supermodule/base.py | 83 +-- .../pytorch/supermodule/differentiable.py | 311 +++------ .../oneshot/pytorch/supermodule/operation.py | 201 +++--- .../oneshot/pytorch/supermodule/sampling.py | 208 +++--- test/algo/nas/oneshot/test_supermodules.py | 633 ++++++++++++++++++ 5 files changed, 925 insertions(+), 511 deletions(-) create mode 100644 test/algo/nas/oneshot/test_supermodules.py diff --git a/nni/nas/oneshot/pytorch/supermodule/base.py b/nni/nas/oneshot/pytorch/supermodule/base.py index c688cc19d..8f807e61d 100644 --- a/nni/nas/oneshot/pytorch/supermodule/base.py +++ b/nni/nas/oneshot/pytorch/supermodule/base.py @@ -9,59 +9,12 @@ from typing import Any, Dict import torch.nn as nn -from nni.common.hpo_utils import ParameterSpec +from nni.nas.nn.pytorch import MutableModule -__all__ = ['BaseSuperNetModule', 'sub_state_dict'] +__all__ = ['BaseSuperNetModule'] -def sub_state_dict(module: Any, destination: Any=None, prefix: str='', keep_vars: bool=False) -> Dict[str, Any]: - """Returns a dictionary containing a whole state of the BaseSuperNetModule. - - Both parameters and persistent buffers (e.g. running averages) are - included. Keys are corresponding parameter and buffer names. - Parameters and buffers set to ``None`` are not included. - - Parameters - ---------- - arch : dict[str, Any] - subnet architecture dict. - destination (dict, optional): - If provided, the state of module will be updated into the dict - and the same object is returned. Otherwise, an ``OrderedDict`` - will be created and returned. Default: ``None``. - prefix (str, optional): - a prefix added to parameter and buffer names to compose the keys in state_dict. - Default: ``''``. - keep_vars (bool, optional): - by default the :class:`~torch.Tensor` s returned in the state dict are - detached from autograd. If it's set to ``True``, detaching will not be performed. - Default: ``False``. - - Returns - ------- - dict - Subnet state dictionary. - """ - if destination is None: - destination = OrderedDict() - destination._metadata = OrderedDict() - - local_metadata = dict(version=module._version) - if hasattr(destination, "_metadata"): - destination._metadata[prefix[:-1]] = local_metadata - - if isinstance(module, BaseSuperNetModule): - module._save_to_sub_state_dict(destination, prefix, keep_vars) - else: - module._save_to_state_dict(destination, prefix, keep_vars) - for name, m in module._modules.items(): - if m is not None: - sub_state_dict(m, destination=destination, prefix=prefix + name + '.', keep_vars=keep_vars) - - return destination - - -class BaseSuperNetModule(nn.Module): +class BaseSuperNetModule(MutableModule): """ Mutated module in super-net. Usually, the feed-forward of the module itself is undefined. @@ -116,17 +69,6 @@ class BaseSuperNetModule(nn.Module): """ raise NotImplementedError() - def search_space_spec(self) -> dict[str, ParameterSpec]: - """ - Space specification (sample points). - Mapping from spec name to ParameterSpec. The names in choices should be in the same format of export. 
- - For example: :: - - {"layer1": ParameterSpec(values=["conv", "pool"])} - """ - raise NotImplementedError() - @classmethod def mutate(cls, module: nn.Module, name: str, memo: dict[str, Any], mutate_kwargs: dict[str, Any]) -> \ 'BaseSuperNetModule' | bool | tuple['BaseSuperNetModule', bool]: @@ -153,22 +95,3 @@ class BaseSuperNetModule(nn.Module): See :class:`BaseOneShotLightningModule ` for details. """ raise NotImplementedError() - - def _save_param_buff_to_state_dict(self, destination, prefix, keep_vars): - """Save the params and buffers of the current module to state dict.""" - for name, value in itertools.chain(self._parameters.items(), self._buffers.items()): # direct children - if value is None or name in self._non_persistent_buffers_set: - # it won't appear in state dict - continue - destination[prefix + name] = value if keep_vars else value.detach() - - def _save_module_to_state_dict(self, destination, prefix, keep_vars): - """Save the sub-module to state dict.""" - for name, module in self._modules.items(): - if module is not None: - sub_state_dict(module, destination=destination, prefix=prefix + name + '.', keep_vars=keep_vars) - - def _save_to_sub_state_dict(self, destination, prefix, keep_vars): - """Save to state dict.""" - self._save_param_buff_to_state_dict(destination, prefix, keep_vars) - self._save_module_to_state_dict(destination, prefix, keep_vars) diff --git a/nni/nas/oneshot/pytorch/supermodule/differentiable.py b/nni/nas/oneshot/pytorch/supermodule/differentiable.py index 9bf939a26..efa2599ef 100644 --- a/nni/nas/oneshot/pytorch/supermodule/differentiable.py +++ b/nni/nas/oneshot/pytorch/supermodule/differentiable.py @@ -13,15 +13,14 @@ import torch import torch.nn as nn import torch.nn.functional as F -from nni.common.hpo_utils import ParameterSpec -from nni.nas.nn.pytorch import LayerChoice, InputChoice, ChoiceOf, Repeat -from nni.nas.nn.pytorch.choice import ValueChoiceX +from nni.mutable import MutableExpression, Mutable, Categorical +from nni.nas.nn.pytorch import LayerChoice, InputChoice, Repeat from nni.nas.nn.pytorch.cell import preprocess_cell_inputs from .base import BaseSuperNetModule from .operation import MixedOperation, MixedOperationSamplingPolicy from .sampling import PathSamplingCell -from ._valuechoice_utils import traverse_all_options, dedup_inner_choices, weighted_sum +from ._expression_utils import traverse_all_options, weighted_sum _logger = logging.getLogger(__name__) @@ -46,7 +45,7 @@ class GumbelSoftmax(nn.Softmax): return F.gumbel_softmax(inputs, tau=self.tau, hard=self.hard, dim=self.dim) -class DifferentiableMixedLayer(BaseSuperNetModule): +class DifferentiableMixedLayer(LayerChoice, BaseSuperNetModule): """ Mixed layer, in which fprop is decided by a weighted sum of several layers. Proposed in `DARTS: Differentiable Architecture Search `__. @@ -66,31 +65,18 @@ class DifferentiableMixedLayer(BaseSuperNetModule): Customizable softmax function. Usually ``nn.Softmax(-1)``. label : str Name of the choice. - - Attributes - ---------- - op_names : str - Operator names. - label : str - Name of the choice. 
""" _arch_parameter_names: list[str] = ['_arch_alpha'] def __init__(self, - paths: list[tuple[str, nn.Module]], + paths: list[nn.Module] | dict[str, nn.Module], alpha: torch.Tensor, softmax: nn.Module, label: str): - super().__init__() - self.op_names = [] + super().__init__(paths, label=label) if len(alpha) != len(paths): raise ValueError(f'The size of alpha ({len(alpha)}) must match number of candidates ({len(paths)}).') - for name, module in paths: - self.add_module(name, module) - self.op_names.append(name) - assert self.op_names, 'There has to be at least one op to choose from.' - self.label = label self._arch_alpha = alpha self._softmax = softmax @@ -102,62 +88,41 @@ class DifferentiableMixedLayer(BaseSuperNetModule): """Choose the operator with the maximum logit.""" if self.label in memo: return {} # nothing new to export - return {self.label: self.op_names[int(torch.argmax(self._arch_alpha).item())]} + return {self.label: self.names[int(torch.argmax(self._arch_alpha).item())]} def export_probs(self, memo): - if any(k.startswith(self.label + '/') for k in memo): - return {} # nothing new + if self.label in memo: + return {} weights = self._softmax(self._arch_alpha).cpu().tolist() - ret = {f'{self.label}/{name}': value for name, value in zip(self.op_names, weights)} - return ret - - def search_space_spec(self): - return {self.label: ParameterSpec(self.label, 'choice', self.op_names, (self.label, ), - True, size=len(self.op_names))} + return {self.label: dict(zip(self.names, weights))} @classmethod def mutate(cls, module, name, memo, mutate_kwargs): - if isinstance(module, LayerChoice): - size = len(module) - if module.label in memo: - alpha = memo[module.label] - if len(alpha) != size: - raise ValueError(f'Architecture parameter size of same label {module.label} conflict: {len(alpha)} vs. {size}') - else: - alpha = nn.Parameter(torch.randn(size) * 1E-3) # the numbers in the parameter can be reinitialized later - memo[module.label] = alpha - + if type(module) is LayerChoice: # must be exactly LayerChoice + if module.label not in memo: + raise KeyError(f'LayerChoice {module.label} not found in memo.') + alpha = memo[module.label] softmax = mutate_kwargs.get('softmax', nn.Softmax(-1)) - return cls(list(module.named_children()), alpha, softmax, module.label) + return cls(module.candidates, alpha, softmax, module.label) - def reduction(self, items: list[Any], weights: list[float]) -> Any: + def _reduction(self, items: list[Any], weights: list[float]) -> Any: """Override this for customized reduction.""" # Use weighted_sum to handle complex cases where sequential output is not a single tensor return weighted_sum(items, weights) def forward(self, *args, **kwargs): """The forward of mixed layer accepts same arguments as its sub-layer.""" - all_op_results = [getattr(self, op)(*args, **kwargs) for op in self.op_names] - return self.reduction(all_op_results, self._softmax(self._arch_alpha)) + all_op_results = [self[op](*args, **kwargs) for op in self.names] + return self._reduction(all_op_results, self._softmax(self._arch_alpha)) - def parameters(self, *args, **kwargs): - """Parameters excluding architecture parameters.""" - for _, p in self.named_parameters(*args, **kwargs): - yield p - - def named_parameters(self, *args, **kwargs): - """Named parameters excluding architecture parameters.""" - arch = kwargs.pop('arch', False) - for name, p in super().named_parameters(*args, **kwargs): + def arch_parameters(self): + """Iterate over architecture parameters. 
Not recursive.""" + for name, p in self.named_parameters(): if any(name == par_name for par_name in self._arch_parameter_names): - if arch: - yield name, p - else: - if not arch: - yield name, p + yield p -class DifferentiableMixedInput(BaseSuperNetModule): +class DifferentiableMixedInput(InputChoice, BaseSuperNetModule): """ Mixed input. Forward returns a weighted sum of candidates. Implementation is very similar to :class:`DifferentiableMixedLayer`. @@ -174,11 +139,6 @@ class DifferentiableMixedInput(BaseSuperNetModule): Customizable softmax function. Usually ``nn.Softmax(-1)``. label : str Name of the choice. - - Attributes - ---------- - label : str - Name of the choice. """ _arch_parameter_names: list[str] = ['_arch_alpha'] @@ -189,18 +149,14 @@ class DifferentiableMixedInput(BaseSuperNetModule): alpha: torch.Tensor, softmax: nn.Module, label: str): - super().__init__() - self.n_candidates = n_candidates - if len(alpha) != n_candidates: - raise ValueError(f'The size of alpha ({len(alpha)}) must match number of candidates ({n_candidates}).') if n_chosen is None: warnings.warn('Differentiable architecture search does not support choosing multiple inputs. Assuming one.', RuntimeWarning) - self.n_chosen = 1 - self.n_chosen = n_chosen - self.label = label + n_chosen = 1 + super().__init__(n_candidates, n_chosen=n_chosen, label=label) + if len(alpha) != n_candidates: + raise ValueError(f'The size of alpha ({len(alpha)}) must match number of candidates ({n_candidates}).') self._softmax = softmax - self._arch_alpha = alpha def resample(self, memo): @@ -212,68 +168,44 @@ class DifferentiableMixedInput(BaseSuperNetModule): if self.label in memo: return {} # nothing new to export chosen = sorted(torch.argsort(-self._arch_alpha).cpu().numpy().tolist()[:self.n_chosen]) - if len(chosen) == 1: - chosen = chosen[0] return {self.label: chosen} def export_probs(self, memo): - if any(k.startswith(self.label + '/') for k in memo): - return {} # nothing new + if self.label in memo: + return {} weights = self._softmax(self._arch_alpha).cpu().tolist() - ret = {f'{self.label}/{index}': value for index, value in enumerate(weights)} - return ret - - def search_space_spec(self): - return { - self.label: ParameterSpec(self.label, 'choice', list(range(self.n_candidates)), - (self.label, ), True, size=self.n_candidates, chosen_size=self.n_chosen) - } + return {self.label: dict(enumerate(weights))} @classmethod def mutate(cls, module, name, memo, mutate_kwargs): - if isinstance(module, InputChoice): + if type(module) == InputChoice: # must be exactly InputChoice + module = cast(InputChoice, module) if module.reduction not in ['sum', 'mean']: raise ValueError('Only input choice of sum/mean reduction is supported.') - size = module.n_candidates - if module.label in memo: - alpha = memo[module.label] - if len(alpha) != size: - raise ValueError(f'Architecture parameter size of same label {module.label} conflict: {len(alpha)} vs. 
{size}') - else: - alpha = nn.Parameter(torch.randn(size) * 1E-3) # the numbers in the parameter can be reinitialized later - memo[module.label] = alpha - + if module.label not in memo: + raise KeyError(f'InputChoice {module.label} not found in memo.') + alpha = memo[module.label] softmax = mutate_kwargs.get('softmax', nn.Softmax(-1)) return cls(module.n_candidates, module.n_chosen, alpha, softmax, module.label) - def reduction(self, items: list[Any], weights: list[float]) -> Any: + def _reduction(self, items: list[Any], weights: list[float]) -> Any: """Override this for customized reduction.""" # Use weighted_sum to handle complex cases where sequential output is not a single tensor return weighted_sum(items, weights) def forward(self, inputs): """Forward takes a list of input candidates.""" - return self.reduction(inputs, self._softmax(self._arch_alpha)) + return self._reduction(inputs, self._softmax(self._arch_alpha)) - def parameters(self, *args, **kwargs): - """Parameters excluding architecture parameters.""" - for _, p in self.named_parameters(*args, **kwargs): - yield p - - def named_parameters(self, *args, **kwargs): - """Named parameters excluding architecture parameters.""" - arch = kwargs.pop('arch', False) - for name, p in super().named_parameters(*args, **kwargs): + def arch_parameters(self): + """Iterate over architecture parameters. Not recursive.""" + for name, p in self.named_parameters(): if any(name == par_name for par_name in self._arch_parameter_names): - if arch: - yield name, p - else: - if not arch: - yield name, p + yield p class MixedOpDifferentiablePolicy(MixedOperationSamplingPolicy): - """Implementes the differentiable sampling in mixed operation. + """Implements the differentiable sampling in mixed operation. One mixed operation can have multiple value choices in its arguments. Thus the ``_arch_alpha`` here is a parameter dict, and ``named_parameters`` @@ -293,36 +225,21 @@ class MixedOpDifferentiablePolicy(MixedOperationSamplingPolicy): def __init__(self, operation: MixedOperation, memo: dict[str, Any], mutate_kwargs: dict[str, Any]) -> None: # Sampling arguments. This should have the same keys with `operation.mutable_arguments` operation._arch_alpha = nn.ParameterDict() - for name, spec in operation.search_space_spec().items(): - if name in memo: - alpha = memo[name] - if len(alpha) != spec.size: - raise ValueError(f'Architecture parameter size of same label {name} conflict: {len(alpha)} vs. {spec.size}') - else: - alpha = nn.Parameter(torch.randn(spec.size) * 1E-3) - memo[name] = alpha - operation._arch_alpha[name] = alpha + for name in operation.simplify(): + if name not in memo: + raise KeyError(f'Argument {name} not found in memo.') + operation._arch_alpha[str(name)] = memo[name] - operation.parameters = functools.partial(self.parameters, module=operation) # bind self - operation.named_parameters = functools.partial(self.named_parameters, module=operation) + operation.arch_parameters = functools.partial(self.arch_parameters, module=operation) operation._softmax = mutate_kwargs.get('softmax', nn.Softmax(-1)) @staticmethod - def parameters(module, *args, **kwargs): - for _, p in module.named_parameters(*args, **kwargs): - yield p - - @staticmethod - def named_parameters(module, *args, **kwargs): - arch = kwargs.pop('arch', False) - for name, p in super(module.__class__, module).named_parameters(*args, **kwargs): # pylint: disable=bad-super-call + def arch_parameters(module): + """Iterate over architecture parameters. 
Not recursive.""" + for name, p in module.named_parameters(): if any(name.startswith(par_name) for par_name in MixedOpDifferentiablePolicy._arch_parameter_names): - if arch: - yield name, p - else: - if not arch: - yield name, p + yield p def resample(self, operation: MixedOperation, memo: dict[str, Any]) -> dict[str, Any]: """Differentiable. Do nothing in resample.""" @@ -331,21 +248,21 @@ class MixedOpDifferentiablePolicy(MixedOperationSamplingPolicy): def export(self, operation: MixedOperation, memo: dict[str, Any]) -> dict[str, Any]: """Export is argmax for each leaf value choice.""" result = {} - for name, spec in operation.search_space_spec().items(): + for name, spec in operation.simplify().items(): if name in memo: continue chosen_index = int(torch.argmax(cast(dict, operation._arch_alpha)[name]).item()) - result[name] = spec.values[chosen_index] + result[name] = cast(Categorical, spec).values[chosen_index] return result def export_probs(self, operation: MixedOperation, memo: dict[str, Any]): """Export the weight for every leaf value choice.""" ret = {} - for name, spec in operation.search_space_spec().items(): - if any(k.startswith(name + '/') for k in memo): + for name, spec in operation.simplify().items(): + if name in memo: continue weights = operation._softmax(operation._arch_alpha[name]).cpu().tolist() # type: ignore - ret.update({f'{name}/{value}': weight for value, weight in zip(spec.values, weights)}) + ret.update({name: dict(zip(cast(Categorical, spec).values, weights))}) return ret def forward_argument(self, operation: MixedOperation, name: str) -> dict[Any, float] | Any: @@ -357,9 +274,9 @@ class MixedOpDifferentiablePolicy(MixedOperationSamplingPolicy): return operation.init_arguments[name] -class DifferentiableMixedRepeat(BaseSuperNetModule): +class DifferentiableMixedRepeat(Repeat, BaseSuperNetModule): """ - Implementaion of Repeat in a differentiable supernet. + Implementation of Repeat in a differentiable supernet. Result is a weighted sum of possible prefixes, sliced by possible depths. If the output is not a single tensor, it will be summed at every independant dimension. @@ -368,27 +285,17 @@ class DifferentiableMixedRepeat(BaseSuperNetModule): _arch_parameter_names: list[str] = ['_arch_alpha'] + depth_choice: MutableExpression[int] + def __init__(self, blocks: list[nn.Module], - depth: ChoiceOf[int], + depth: MutableExpression[int], softmax: nn.Module, - memo: dict[str, Any]): - super().__init__() - self.blocks = blocks - self.depth = depth + alphas: dict[str, Any]): + assert isinstance(depth, Mutable) + super().__init__(blocks, depth) self._softmax = softmax - self._space_spec: dict[str, ParameterSpec] = dedup_inner_choices([depth]) - self._arch_alpha = nn.ParameterDict() - - for name, spec in self._space_spec.items(): - if name in memo: - alpha = memo[name] - if len(alpha) != spec.size: - raise ValueError(f'Architecture parameter size of same label {name} conflict: {len(alpha)} vs. 
{spec.size}') - else: - alpha = nn.Parameter(torch.randn(spec.size) * 1E-3) - memo[name] = alpha - self._arch_alpha[name] = alpha + self._arch_alpha = nn.ParameterDict(alphas) def resample(self, memo): """Do nothing.""" @@ -397,48 +304,43 @@ class DifferentiableMixedRepeat(BaseSuperNetModule): def export(self, memo): """Choose argmax for each leaf value choice.""" result = {} - for name, spec in self._space_spec.items(): + for name, spec in self.depth_choice.simplify().items(): if name in memo: continue chosen_index = int(torch.argmax(self._arch_alpha[name]).item()) - result[name] = spec.values[chosen_index] + result[name] = cast(Categorical, spec).values[chosen_index] return result def export_probs(self, memo): """Export the weight for every leaf value choice.""" ret = {} - for name, spec in self.search_space_spec().items(): - if any(k.startswith(name + '/') for k in memo): + for name, spec in self.depth_choice.simplify().items(): + if name in memo: continue weights = self._softmax(self._arch_alpha[name]).cpu().tolist() - ret.update({f'{name}/{value}': weight for value, weight in zip(spec.values, weights)}) + ret.update({name: dict(zip(cast(Categorical, spec).values, weights))}) return ret - def search_space_spec(self): - return self._space_spec - @classmethod def mutate(cls, module, name, memo, mutate_kwargs): - if isinstance(module, Repeat) and isinstance(module.depth_choice, ValueChoiceX): + if type(module) == Repeat and isinstance(module.depth_choice, Mutable): # Repeat and depth is mutable # Only interesting when depth is mutable + module = cast(Repeat, module) + alphas = {} + for name in cast(Mutable, module.depth_choice).simplify(): + if name not in memo: + raise KeyError(f'Mutable depth "{name}" not found in memo') + alphas[name] = memo[name] softmax = mutate_kwargs.get('softmax', nn.Softmax(-1)) - return cls(cast(List[nn.Module], module.blocks), module.depth_choice, softmax, memo) + return cls(list(module.blocks), cast(MutableExpression[int], module.depth_choice), softmax, alphas) - def parameters(self, *args, **kwargs): - for _, p in self.named_parameters(*args, **kwargs): - yield p + def arch_parameters(self): + """Iterate over architecture parameters. 
Not recursive.""" + for name, p in self.named_parameters(): + if any(name.startswith(par_name) for par_name in self._arch_parameter_names): + yield p - def named_parameters(self, *args, **kwargs): - arch = kwargs.pop('arch', False) - for name, p in super().named_parameters(*args, **kwargs): - if any(name.startswith(par_name) for par_name in MixedOpDifferentiablePolicy._arch_parameter_names): - if arch: - yield name, p - else: - if not arch: - yield name, p - - def reduction(self, items: list[Any], weights: list[float], depths: list[int]) -> Any: + def _reduction(self, items: list[Any], weights: list[float], depths: list[int]) -> Any: """Override this for customized reduction.""" # Use weighted_sum to handle complex cases where sequential output is not a single tensor return weighted_sum(items, weights) @@ -447,7 +349,7 @@ class DifferentiableMixedRepeat(BaseSuperNetModule): weights: dict[str, torch.Tensor] = { label: self._softmax(alpha) for label, alpha in self._arch_alpha.items() } - depth_weights = dict(cast(List[Tuple[int, float]], traverse_all_options(self.depth, weights=weights))) + depth_weights = dict(cast(List[Tuple[int, float]], traverse_all_options(self.depth_choice, weights=weights))) res: list[torch.Tensor] = [] weight_list: list[float] = [] @@ -459,7 +361,7 @@ class DifferentiableMixedRepeat(BaseSuperNetModule): res.append(x) depths.append(i) - return self.reduction(res, weight_list, depths) + return self._reduction(res, weight_list, depths) class DifferentiableMixedCell(PathSamplingCell): @@ -473,6 +375,8 @@ class DifferentiableMixedCell(PathSamplingCell): # TODO: It inherits :class:`PathSamplingCell` to reduce some duplicated code. # Possibly need another refactor here. + _arch_parameter_names: list[str] = ['_arch_alpha'] + def __init__( self, op_factory, num_nodes, num_ops_per_node, num_predecessors, preprocessor, postprocessor, concat_dim, @@ -486,10 +390,13 @@ class DifferentiableMixedCell(PathSamplingCell): self._arch_alpha = nn.ParameterDict() for i in range(self.num_predecessors, self.num_nodes + self.num_predecessors): for j in range(i): - edge_label = f'{label}/{i}_{j}' + edge_label = f'{self.label}/{i}_{j}' + # Some parameters still need to be created here inside. + # We should avoid conflict with "outside parameters". + memo_label = edge_label + '/in_cell' op = cast(List[Dict[str, nn.Module]], self.ops[i - self.num_predecessors])[j] - if edge_label in memo: - alpha = memo[edge_label] + if memo_label in memo: + alpha = memo[memo_label] if len(alpha) != len(op) + 1: if len(alpha) != len(op): raise ValueError( @@ -504,7 +411,7 @@ class DifferentiableMixedCell(PathSamplingCell): else: # +1 to emulate the input choice. alpha = nn.Parameter(torch.randn(len(op) + 1) * 1E-3) - memo[edge_label] = alpha + memo[memo_label] = alpha self._arch_alpha[edge_label] = alpha self._softmax = mutate_kwargs.get('softmax', nn.Softmax(-1)) @@ -517,10 +424,10 @@ class DifferentiableMixedCell(PathSamplingCell): """When export probability, we follow the structure in arch alpha.""" ret = {} for name, parameter in self._arch_alpha.items(): - if any(k.startswith(name + '/') for k in memo): + if name in memo: continue weights = self._softmax(parameter).cpu().tolist() - ret.update({f'{name}/{value}': weight for value, weight in zip(self.op_names, weights)}) + ret.update({name: dict(zip(self.op_names, weights))}) return ret def export(self, memo): @@ -565,7 +472,7 @@ class DifferentiableMixedCell(PathSamplingCell): # all_weights could be too short in case ``num_ops_per_node`` is too large. 
_, j, op_name = all_weights[k % len(all_weights)] exported[f'{self.label}/op_{i}_{k}'] = op_name - exported[f'{self.label}/input_{i}_{k}'] = j + exported[f'{self.label}/input_{i}_{k}'] = [j] return exported @@ -579,10 +486,10 @@ class DifferentiableMixedCell(PathSamplingCell): op_results = torch.stack([op(states[j]) for op in ops[j].values()]) alpha_shape = [-1] + [1] * (len(op_results.size()) - 1) # (-1, 1, 1, 1, 1, ...) op_weights = self._softmax(self._arch_alpha[f'{self.label}/{i}_{j}']) - if len(op_weights) == len(op_results) + 1: + if op_weights.size(0) == op_results.size(0) + 1: # concatenate with a zero operation, indicating this path is not chosen at all. op_results = torch.cat((op_results, torch.zeros_like(op_results[:1])), 0) - edge_sum = torch.sum(op_results * self._softmax(self._arch_alpha[f'{self.label}/{i}_{j}']).view(*alpha_shape), 0) + edge_sum = torch.sum(op_results * op_weights.view(*alpha_shape), 0) current_state.append(edge_sum) states.append(sum(current_state)) # type: ignore @@ -591,16 +498,8 @@ class DifferentiableMixedCell(PathSamplingCell): this_cell = torch.cat(states[self.num_predecessors:], self.concat_dim) return self.postprocessor(this_cell, processed_inputs) - def parameters(self, *args, **kwargs): - for _, p in self.named_parameters(*args, **kwargs): - yield p - - def named_parameters(self, *args, **kwargs): - arch = kwargs.pop('arch', False) - for name, p in super().named_parameters(*args, **kwargs): - if any(name.startswith(par_name) for par_name in MixedOpDifferentiablePolicy._arch_parameter_names): - if arch: - yield name, p - else: - if not arch: - yield name, p + def arch_parameters(self): + """Iterate over architecture parameters. Not recursive.""" + for name, p in self.named_parameters(): + if any(name.startswith(par_name) for par_name in self._arch_parameter_names): + yield p diff --git a/nni/nas/oneshot/pytorch/supermodule/operation.py b/nni/nas/oneshot/pytorch/supermodule/operation.py index d6dddcf66..b79948ad7 100644 --- a/nni/nas/oneshot/pytorch/supermodule/operation.py +++ b/nni/nas/oneshot/pytorch/supermodule/operation.py @@ -18,13 +18,15 @@ import torch.nn as nn import torch.nn.functional as F from torch import Tensor -import nni.nas.nn.pytorch as nas_nn -from nni.common.hpo_utils import ParameterSpec from nni.common.serializer import is_traceable -from nni.nas.nn.pytorch.choice import ValueChoiceX +from nni.mutable import MutableExpression +from nni.nas.nn.pytorch import ( + ParametrizedModule, + MutableConv2d, MutableLinear, MutableBatchNorm2d, MutableLayerNorm, MutableMultiheadAttention +) -from .base import BaseSuperNetModule, sub_state_dict -from ._valuechoice_utils import traverse_all_options, dedup_inner_choices, evaluate_constant +from .base import BaseSuperNetModule +from ._expression_utils import traverse_all_options, evaluate_constant from ._operation_utils import Slicable as _S, MaybeWeighted as _W, int_or_int_dict, scalar_or_scalar_dict T = TypeVar('T') @@ -40,7 +42,7 @@ __all__ = [ 'NATIVE_MIXED_OPERATIONS', ] -_diff_not_compatible_error = 'To be compatible with differentiable one-shot strategy, {} in {} must not be ValueChoice.' +_diff_not_compatible_error = 'To be compatible with differentiable one-shot strategy, {} in {} must not be mutable.' class MixedOperationSamplingPolicy: @@ -84,10 +86,10 @@ class MixedOperationSamplingPolicy: class MixedOperation(BaseSuperNetModule): """This is the base class for all mixed operations. - It's what you should inherit to support a new operation with ValueChoice. 
+ It's what you should inherit to support a new operation with mutable. - It contains commonly used utilities that will ease the effort to write customized mixed oeprations, - i.e., operations with ValueChoice in its arguments. + It contains commonly used utilities that will ease the effort to write customized mixed operations, + i.e., operations with mutable in its arguments. To customize, please write your own mixed operation, and add the hook into ``mutation_hooks`` parameter when using the strategy. By design, for a mixed operation to work in a specific algorithm, @@ -116,14 +118,14 @@ class MixedOperation(BaseSuperNetModule): sampling_policy: MixedOperationSamplingPolicy - def super_init_argument(self, name: str, value_choice: ValueChoiceX) -> Any: + def super_init_argument(self, name: str, value_choice: MutableExpression) -> Any: """Get the initialization argument when constructing super-kernel, i.e., calling ``super().__init__()``. This is often related to specific operator, rather than algo. For example:: def super_init_argument(self, name, value_choice): - return max(value_choice.candidates) + return max(value_choice.grid()) """ raise NotImplementedError() @@ -138,8 +140,8 @@ class MixedOperation(BaseSuperNetModule): def __init__(self, module_kwargs: dict[str, Any]) -> None: # Concerned arguments - self.mutable_arguments: dict[str, ValueChoiceX] = {} - # Useful when retrieving arguments without ValueChoice + self.mutable_arguments: dict[str, MutableExpression] = {} + # Useful when retrieving arguments without mutable self.init_arguments: dict[str, Any] = {**module_kwargs} self._fill_missing_init_arguments() @@ -147,21 +149,37 @@ class MixedOperation(BaseSuperNetModule): super_init_kwargs = {} for key, value in module_kwargs.items(): - if isinstance(value, ValueChoiceX): + if isinstance(value, MutableExpression): if key not in self.argument_list: - raise TypeError(f'Unsupported value choice on argument of {self.bound_type}: {key}') + raise TypeError(f'Unsupported mutable argument of "{self.bound_type}": {key}') super_init_kwargs[key] = self.super_init_argument(key, value) self.mutable_arguments[key] = value else: super_init_kwargs[key] = value - # get all inner leaf value choices - self._space_spec: dict[str, ParameterSpec] = dedup_inner_choices(list(self.mutable_arguments.values())) - super().__init__(**super_init_kwargs) + for mutable in self.mutable_arguments.values(): + self.add_mutable(mutable) + self.__post_init__() + def freeze(self, sample) -> Any: + """Freeze the mixed operation to a specific operation. + Weights will be copied from the mixed operation to the frozen operation. + + The returned operation will be of the ``bound_type``. 
+ """ + arguments = {**self.init_arguments} + for name, mutable in self.mutable_arguments.items(): + arguments[name] = mutable.freeze(sample) + operation = self.bound_type(**arguments) + + # copy weights + state_dict = self.freeze_weight(**arguments) + operation.load_state_dict(state_dict) + return operation + def resample(self, memo): """Delegates to :meth:`MixedOperationSamplingPolicy.resample`.""" return self.sampling_policy.resample(self, memo) @@ -174,21 +192,12 @@ class MixedOperation(BaseSuperNetModule): """Delegates to :meth:`MixedOperationSamplingPolicy.export`.""" return self.sampling_policy.export(self, memo) - def search_space_spec(self): - return self._space_spec - @classmethod def mutate(cls, module, name, memo, mutate_kwargs): """Find value choice in module's arguments and replace the whole module""" - has_valuechoice = False - if isinstance(module, cls.bound_type) and is_traceable(module): - for arg in itertools.chain(cast(list, module.trace_args), cast(dict, module.trace_kwargs).values()): - if isinstance(arg, ValueChoiceX): - has_valuechoice = True - - if has_valuechoice: + if isinstance(module, cls.bound_type) and isinstance(module, ParametrizedModule): if module.trace_args: - raise ValueError('ValueChoice on class arguments cannot appear together with ``trace_args``. ' + raise ValueError('Mutable on class arguments cannot appear together with ``trace_args``. ' 'Please enable ``kw_only`` on nni.trace.') # save type and kwargs @@ -232,21 +241,12 @@ class MixedOperation(BaseSuperNetModule): if param.default is not param.empty and param.name not in self.init_arguments: self.init_arguments[param.name] = param.default - def slice_param(self, **kwargs): + def freeze_weight(self, **kwargs): """Slice the params and buffers for subnet forward and state dict. - When there is a `mapping=True` in kwargs, the return result will be wrapped in dict. - """ - raise NotImplementedError() - def _save_param_buff_to_state_dict(self, destination, prefix, keep_vars): - kwargs = {name: self.forward_argument(name) for name in self.argument_list} - params_mapping: dict[str, Any] = self.slice_param(**kwargs) - for name, value in itertools.chain(self._parameters.items(), self._buffers.items()): # direct children - if value is None or name in self._non_persistent_buffers_set: - # it won't appear in state dict - continue - value = params_mapping.get(name, value) - destination[prefix + name] = value if keep_vars else value.detach() + The arguments are same as the arguments passed to ``__init__``. + """ + raise NotImplementedError('freeze_weight is not implemented.') class MixedLinear(MixedOperation, nn.Linear): @@ -260,13 +260,13 @@ class MixedLinear(MixedOperation, nn.Linear): Prefix of weight and bias will be sliced. 
""" - bound_type = nas_nn.Linear + bound_type = MutableLinear argument_list = ['in_features', 'out_features'] - def super_init_argument(self, name: str, value_choice: ValueChoiceX): - return max(traverse_all_options(value_choice)) + def super_init_argument(self, name: str, value_choice: MutableExpression): + return max(value_choice.grid()) - def slice_param(self, in_features: int_or_int_dict, out_features: int_or_int_dict, **kwargs) -> Any: + def freeze_weight(self, in_features: int_or_int_dict, out_features: int_or_int_dict, **kwargs) -> Any: in_features_ = _W(in_features) out_features_ = _W(out_features) @@ -281,7 +281,7 @@ class MixedLinear(MixedOperation, nn.Linear): out_features: int_or_int_dict, inputs: torch.Tensor) -> torch.Tensor: - params_mapping = self.slice_param(in_features, out_features) + params_mapping = self.freeze_weight(in_features, out_features) weight, bias = [params_mapping.get(name) for name in ['weight', 'bias']] return F.linear(inputs, weight, bias) @@ -321,7 +321,7 @@ class MixedConv2d(MixedOperation, nn.Conv2d): □ □ □ □ □ □ □ □ □ □ """ - bound_type = nas_nn.Conv2d + bound_type = MutableConv2d argument_list = [ 'in_channels', 'out_channels', 'kernel_size', 'stride', 'padding', 'dilation', 'groups' ] @@ -332,12 +332,12 @@ class MixedConv2d(MixedOperation, nn.Conv2d): return (value, value) return value - def super_init_argument(self, name: str, value_choice: ValueChoiceX): + def super_init_argument(self, name: str, mutable_expr: MutableExpression): if name not in ['in_channels', 'out_channels', 'groups', 'stride', 'kernel_size', 'padding', 'dilation']: raise NotImplementedError(f'Unsupported value choice on argument: {name}') if name == ['kernel_size', 'padding']: - all_sizes = set(traverse_all_options(value_choice)) + all_sizes = set(traverse_all_options(mutable_expr)) if any(isinstance(sz, tuple) for sz in all_sizes): # maximum kernel should be calculated on every dimension return ( @@ -351,28 +351,37 @@ class MixedConv2d(MixedOperation, nn.Conv2d): if 'in_channels' in self.mutable_arguments: # If the ratio is constant, we don't need to try the maximum groups. try: - constant = evaluate_constant(self.mutable_arguments['in_channels'] / value_choice) - return max(cast(List[float], traverse_all_options(value_choice))) // int(constant) + constant = evaluate_constant(self.mutable_arguments['in_channels'] / mutable_expr) + return max(cast(List[float], traverse_all_options(mutable_expr))) // int(constant) except ValueError: warnings.warn( - 'Both input channels and groups are ValueChoice in a convolution, and their relative ratio is not a constant. ' + 'Both input channels and groups are mutable in a convolution, and their relative ratio is not a constant. ' 'This can be problematic for most one-shot algorithms. 
Please check whether this is your intention.', RuntimeWarning ) # minimum groups, maximum kernel - return min(traverse_all_options(value_choice)) + return min(traverse_all_options(mutable_expr)) else: - return max(traverse_all_options(value_choice)) + return max(traverse_all_options(mutable_expr)) - def slice_param(self, + def freeze_weight(self, in_channels: int_or_int_dict, out_channels: int_or_int_dict, kernel_size: scalar_or_scalar_dict[_int_or_tuple], groups: int_or_int_dict, - **kwargs - ) -> Any: + **kwargs) -> Any: + rv = self._freeze_weight_impl(in_channels, out_channels, kernel_size, groups) + rv.pop('in_channels_per_group', None) + return rv + + def _freeze_weight_impl(self, + in_channels: int_or_int_dict, + out_channels: int_or_int_dict, + kernel_size: scalar_or_scalar_dict[_int_or_tuple], + groups: int_or_int_dict, + **kwargs) -> Any: in_channels_ = _W(in_channels) out_channels_ = _W(out_channels) @@ -386,8 +395,8 @@ class MixedConv2d(MixedOperation, nn.Conv2d): in_channels_per_group = None else: assert 'groups' in self.mutable_arguments - err_message = 'For differentiable one-shot strategy, when groups is a ValueChoice, ' \ - 'in_channels and out_channels should also be a ValueChoice. ' \ + err_message = 'For differentiable one-shot strategy, when groups is a mutable, ' \ + 'in_channels and out_channels should also be a mutable. ' \ 'Also, the ratios of in_channels divided by groups, and out_channels divided by groups ' \ 'should be constants.' if 'in_channels' not in self.mutable_arguments or 'out_channels' not in self.mutable_arguments: @@ -412,7 +421,6 @@ class MixedConv2d(MixedOperation, nn.Conv2d): return {'weight': weight, 'bias': bias, 'in_channels_per_group': in_channels_per_group} - def forward_with_args(self, in_channels: int_or_int_dict, out_channels: int_or_int_dict, @@ -426,7 +434,7 @@ class MixedConv2d(MixedOperation, nn.Conv2d): if any(isinstance(arg, dict) for arg in [stride, dilation]): raise ValueError(_diff_not_compatible_error.format('stride, dilation', 'Conv2d')) - params_mapping = self.slice_param(in_channels, out_channels, kernel_size, groups) + params_mapping = self._freeze_weight_impl(in_channels, out_channels, kernel_size, groups) weight, bias, in_channels_per_group = [ params_mapping.get(name) for name in ['weight', 'bias', 'in_channels_per_group'] @@ -478,13 +486,13 @@ class MixedBatchNorm2d(MixedOperation, nn.BatchNorm2d): PyTorch BatchNorm supports a case where momentum can be none, which is not supported here. 
""" - bound_type = nas_nn.BatchNorm2d + bound_type = MutableBatchNorm2d argument_list = ['num_features', 'eps', 'momentum'] - def super_init_argument(self, name: str, value_choice: ValueChoiceX): - return max(traverse_all_options(value_choice)) + def super_init_argument(self, name: str, mutable_expr: MutableExpression): + return max(traverse_all_options(mutable_expr)) - def slice_param(self, num_features: int_or_int_dict, **kwargs) -> Any: + def freeze_weight(self, num_features: int_or_int_dict, **kwargs) -> Any: if isinstance(num_features, dict): num_features = self.num_features weight, bias = self.weight, self.bias @@ -508,7 +516,7 @@ class MixedBatchNorm2d(MixedOperation, nn.BatchNorm2d): if any(isinstance(arg, dict) for arg in [eps, momentum]): raise ValueError(_diff_not_compatible_error.format('eps and momentum', 'BatchNorm2d')) - params_mapping = self.slice_param(num_features) + params_mapping = self.freeze_weight(num_features) weight, bias, running_mean, running_var = [ params_mapping.get(name) for name in ['weight', 'bias', 'running_mean', 'running_var'] @@ -547,7 +555,7 @@ class MixedLayerNorm(MixedOperation, nn.LayerNorm): eps is required to be float. """ - bound_type = nas_nn.LayerNorm + bound_type = MutableLayerNorm argument_list = ['normalized_shape', 'eps'] @staticmethod @@ -556,10 +564,10 @@ class MixedLayerNorm(MixedOperation, nn.LayerNorm): return (value, value) return value - def super_init_argument(self, name: str, value_choice: ValueChoiceX): + def super_init_argument(self, name: str, mutable_expr: MutableExpression): if name not in ['normalized_shape']: raise NotImplementedError(f'Unsupported value choice on argument: {name}') - all_sizes = set(traverse_all_options(value_choice)) + all_sizes = set(traverse_all_options(mutable_expr)) if any(isinstance(sz, (tuple, list)) for sz in all_sizes): # transpose all_sizes = list(zip(*all_sizes)) @@ -568,7 +576,12 @@ class MixedLayerNorm(MixedOperation, nn.LayerNorm): else: return max(all_sizes) - def slice_param(self, normalized_shape, **kwargs) -> Any: + def freeze_weight(self, normalized_shape, **kwargs) -> Any: + rv = self._freeze_weight_impl(normalized_shape) + rv.pop('normalized_shape') + return rv + + def _freeze_weight_impl(self, normalized_shape, **kwargs) -> Any: if isinstance(normalized_shape, dict): normalized_shape = self.normalized_shape @@ -595,7 +608,7 @@ class MixedLayerNorm(MixedOperation, nn.LayerNorm): if any(isinstance(arg, dict) for arg in [eps]): raise ValueError(_diff_not_compatible_error.format('eps', 'LayerNorm')) - params_mapping = self.slice_param(normalized_shape) + params_mapping = self._freeze_weight_impl(normalized_shape) weight, bias, normalized_shape = [ params_mapping.get(name) for name in ['weight', 'bias', 'normalized_shape'] @@ -633,7 +646,7 @@ class MixedMultiHeadAttention(MixedOperation, nn.MultiheadAttention): All candidates of ``embed_dim`` should be divisible by all candidates of ``num_heads``. 
""" - bound_type = nas_nn.MultiheadAttention + bound_type = MutableMultiheadAttention argument_list = ['embed_dim', 'num_heads', 'kdim', 'vdim', 'dropout'] def __post_init__(self): @@ -671,8 +684,17 @@ class MixedMultiHeadAttention(MixedOperation, nn.MultiheadAttention): nn.init.xavier_uniform_(self.k_proj_weight) nn.init.xavier_uniform_(self.v_proj_weight) - def super_init_argument(self, name: str, value_choice: ValueChoiceX): - return max(traverse_all_options(value_choice)) + def super_init_argument(self, name: str, mutable_expr: MutableExpression): + return max(traverse_all_options(mutable_expr)) + + def freeze_weight(self, embed_dim, kdim, vdim, **kwargs): + rv = self._freeze_weight_impl(embed_dim, kdim, vdim, **kwargs) + # pop flags and nones, as they won't show in state dict + rv.pop('qkv_same_embed_dim') + for k in list(rv): + if rv[k] is None: + rv.pop(k) + return rv def _to_proj_slice(self, embed_dim: _W) -> list[slice]: # slice three parts, corresponding to q, k, v respectively @@ -682,7 +704,7 @@ class MixedMultiHeadAttention(MixedOperation, nn.MultiheadAttention): slice(self.embed_dim * 2, self.embed_dim * 2 + embed_dim) ] - def slice_param(self, embed_dim, kdim, vdim, **kwargs): + def _freeze_weight_impl(self, embed_dim, kdim, vdim, **kwargs): # by default, kdim, vdim can be none if kdim is None: kdim = embed_dim @@ -723,31 +745,6 @@ class MixedMultiHeadAttention(MixedOperation, nn.MultiheadAttention): 'qkv_same_embed_dim': qkv_same_embed_dim } - def _save_param_buff_to_state_dict(self, destination, prefix, keep_vars): - kwargs = {name: self.forward_argument(name) for name in self.argument_list} - params_mapping = self.slice_param(**kwargs, mapping=True) - for name, value in itertools.chain(self._parameters.items(), self._buffers.items()): - if value is None or name in self._non_persistent_buffers_set: - continue - value = params_mapping.get(name, value) - destination[prefix + name] = value if keep_vars else value.detach() - - # params of out_proj is handled in ``MixedMultiHeadAttention`` rather than - # ``NonDynamicallyQuantizableLinear`` sub-module. We also convert it to state dict here. - for name in ["out_proj.weight", "out_proj.bias"]: - value = params_mapping.get(name, None) - if value is None: - continue - destination[prefix + name] = value if keep_vars else value.detach() - - def _save_module_to_state_dict(self, destination, prefix, keep_vars): - for name, module in self._modules.items(): - # the weights of ``NonDynamicallyQuantizableLinear`` has been handled in `_save_param_buff_to_state_dict`. 
- if isinstance(module, nn.modules.linear.NonDynamicallyQuantizableLinear): - continue - if module is not None: - sub_state_dict(module, destination=destination, prefix=prefix + name + '.', keep_vars=keep_vars) - def forward_with_args( self, embed_dim: int_or_int_dict, num_heads: int, @@ -770,7 +767,7 @@ class MixedMultiHeadAttention(MixedOperation, nn.MultiheadAttention): else: used_embed_dim = embed_dim - params_mapping = self.slice_param(embed_dim, kdim, vdim) + params_mapping = self._freeze_weight_impl(embed_dim, kdim, vdim) in_proj_bias, in_proj_weight, bias_k, bias_v, \ out_proj_weight, out_proj_bias, q_proj, k_proj, v_proj, qkv_same_embed_dim = [ params_mapping.get(name) diff --git a/nni/nas/oneshot/pytorch/supermodule/sampling.py b/nni/nas/oneshot/pytorch/supermodule/sampling.py index 29a3592cf..d8d0512c3 100644 --- a/nni/nas/oneshot/pytorch/supermodule/sampling.py +++ b/nni/nas/oneshot/pytorch/supermodule/sampling.py @@ -6,19 +6,19 @@ from __future__ import annotations import copy +import functools import random from typing import Any, List, Dict, Sequence, cast import torch import torch.nn as nn -from nni.common.hpo_utils import ParameterSpec -from nni.nas.nn.pytorch import LayerChoice, InputChoice, Repeat, ChoiceOf, Cell -from nni.nas.nn.pytorch.choice import ValueChoiceX +from nni.mutable import MutableExpression, label_scope, Mutable, Categorical, CategoricalMultiple +from nni.nas.nn.pytorch import LayerChoice, InputChoice, Repeat, Cell from nni.nas.nn.pytorch.cell import CellOpFactory, create_cell_op_candidates, preprocess_cell_inputs -from .base import BaseSuperNetModule, sub_state_dict -from ._valuechoice_utils import evaluate_value_choice_with_dict, dedup_inner_choices, weighted_sum +from .base import BaseSuperNetModule +from ._expression_utils import weighted_sum from .operation import MixedOperationSamplingPolicy, MixedOperation __all__ = [ @@ -28,7 +28,7 @@ __all__ = [ ] -class PathSamplingLayer(BaseSuperNetModule): +class PathSamplingLayer(LayerChoice, BaseSuperNetModule): """ Mixed layer, in which fprop is decided by exactly one inner layer or sum of multiple (sampled) layers. If multiple modules are selected, the result will be summed and returned. @@ -41,62 +41,43 @@ class PathSamplingLayer(BaseSuperNetModule): Name of the choice. """ - def __init__(self, paths: list[tuple[str, nn.Module]], label: str): - super().__init__() - self.op_names = [] - for name, module in paths: - self.add_module(name, module) - self.op_names.append(name) - assert self.op_names, 'There has to be at least one op to choose from.' 
+ def __init__(self, paths: dict[str, nn.Module] | list[nn.Module], label: str): + super().__init__(paths, label=label) self._sampled: list[str] | str | None = None # sampled can be either a list of indices or an index - self.label = label def resample(self, memo): """Random choose one path if label is not found in memo.""" if self.label in memo: self._sampled = memo[self.label] else: - self._sampled = random.choice(self.op_names) + self._sampled = self.choice.random() return {self.label: self._sampled} def export(self, memo): """Random choose one name if label isn't found in memo.""" if self.label in memo: return {} # nothing new to export - return {self.label: random.choice(self.op_names)} - - def search_space_spec(self): - return {self.label: ParameterSpec(self.label, 'choice', self.op_names, (self.label, ), - True, size=len(self.op_names))} + return {self.label: self.choice.random()} @classmethod def mutate(cls, module, name, memo, mutate_kwargs): - if isinstance(module, LayerChoice): - return cls(list(module.named_children()), module.label) + if type(module) is LayerChoice: + return cls(module.candidates, module.label) - def reduction(self, items: list[Any], sampled: list[Any]): + def _reduction(self, items: list[Any], sampled: list[Any]): """Override this to implement customized reduction.""" return weighted_sum(items) - def _save_module_to_state_dict(self, destination, prefix, keep_vars): - sampled = [self._sampled] if not isinstance(self._sampled, list) else self._sampled - - for samp in sampled: - module = getattr(self, str(samp)) - if module is not None: - sub_state_dict(module, destination=destination, prefix=prefix, keep_vars=keep_vars) - def forward(self, *args, **kwargs): if self._sampled is None: raise RuntimeError('At least one path needs to be sampled before fprop.') sampled = [self._sampled] if not isinstance(self._sampled, list) else self._sampled - # str(samp) is needed here because samp can sometimes be integers, but attr are always str - res = [getattr(self, str(samp))(*args, **kwargs) for samp in sampled] - return self.reduction(res, sampled) + res = [self[samp](*args, **kwargs) for samp in sampled] + return self._reduction(res, sampled) -class PathSamplingInput(BaseSuperNetModule): +class PathSamplingInput(InputChoice, BaseSuperNetModule): """ Mixed input. Take a list of tensor as input, select some of them and return the sum. @@ -106,22 +87,9 @@ class PathSamplingInput(BaseSuperNetModule): Sampled input indices. """ - def __init__(self, n_candidates: int, n_chosen: int, reduction_type: str, label: str): - super().__init__() - self.n_candidates = n_candidates - self.n_chosen = n_chosen - self.reduction_type = reduction_type + def __init__(self, n_candidates: int, n_chosen: int, reduction: str, label: str): + super().__init__(n_candidates, n_chosen=n_chosen, reduction=reduction, label=label) self._sampled: list[int] | int | None = None - self.label = label - - def _random_choose_n(self): - sampling = list(range(self.n_candidates)) - random.shuffle(sampling) - sampling = sorted(sampling[:self.n_chosen]) - if len(sampling) == 1: - return sampling[0] - else: - return sampling def resample(self, memo): """Random choose one path / multiple paths if label is not found in memo. 
@@ -131,42 +99,36 @@ class PathSamplingInput(BaseSuperNetModule): if self.label in memo: self._sampled = memo[self.label] else: - self._sampled = self._random_choose_n() + self._sampled = self.choice.random() return {self.label: self._sampled} def export(self, memo): """Random choose one name if label isn't found in memo.""" if self.label in memo: return {} # nothing new to export - return {self.label: self._random_choose_n()} - - def search_space_spec(self): - return { - self.label: ParameterSpec(self.label, 'choice', list(range(self.n_candidates)), - (self.label, ), True, size=self.n_candidates, chosen_size=self.n_chosen) - } + return {self.label: self.choice.random()} @classmethod def mutate(cls, module, name, memo, mutate_kwargs): - if isinstance(module, InputChoice): + if type(module) is InputChoice: if module.reduction not in ['sum', 'mean', 'concat']: raise ValueError('Only input choice of sum/mean/concat reduction is supported.') if module.n_chosen is None: raise ValueError('n_chosen is None is not supported yet.') return cls(module.n_candidates, module.n_chosen, module.reduction, module.label) - def reduction(self, items: list[Any], sampled: list[Any]) -> Any: + def _reduction(self, items: list[Any], sampled: list[Any]) -> Any: """Override this to implement customized reduction.""" if len(items) == 1: return items[0] else: - if self.reduction_type == 'sum': + if self.reduction == 'sum': return sum(items) - elif self.reduction_type == 'mean': + elif self.reduction == 'mean': return sum(items) / len(items) - elif self.reduction_type == 'concat': + elif self.reduction == 'concat': return torch.cat(items, 1) - raise ValueError(f'Unsupported reduction type: {self.reduction_type}') + raise ValueError(f'Unsupported reduction type: {self.reduction}') def forward(self, input_tensors): if self._sampled is None: @@ -175,7 +137,7 @@ class PathSamplingInput(BaseSuperNetModule): raise ValueError(f'Expect {self.n_candidates} input tensors, found {len(input_tensors)}.') sampled = [self._sampled] if not isinstance(self._sampled, list) else self._sampled res = [input_tensors[samp] for samp in sampled] - return self.reduction(res, sampled) + return self._reduction(res, sampled) class MixedOpPathSamplingPolicy(MixedOperationSamplingPolicy): @@ -193,28 +155,28 @@ class MixedOpPathSamplingPolicy(MixedOperationSamplingPolicy): def resample(self, operation: MixedOperation, memo: dict[str, Any]) -> dict[str, Any]: """Random sample for each leaf value choice.""" result = {} - space_spec = operation.search_space_spec() - for label in space_spec: + space_spec = operation.simplify() + for label, mutable in space_spec.items(): if label in memo: result[label] = memo[label] else: - result[label] = random.choice(space_spec[label].values) + result[label] = mutable.random() - # composits to kwargs + # composites to kwargs # example: result = {"exp_ratio": 3}, self._sampled = {"in_channels": 48, "out_channels": 96} self._sampled = {} for key, value in operation.mutable_arguments.items(): - self._sampled[key] = evaluate_value_choice_with_dict(value, result) + self._sampled[key] = value.freeze(result) return result def export(self, operation: MixedOperation, memo: dict[str, Any]) -> dict[str, Any]: """Export is also random for each leaf value choice.""" result = {} - space_spec = operation.search_space_spec() - for label in space_spec: + space_spec = operation.simplify() + for label, mutable in space_spec.items(): if label not in memo: - result[label] = random.choice(space_spec[label].values) + result[label] = 
mutable.random() return result def forward_argument(self, operation: MixedOperation, name: str) -> Any: @@ -226,9 +188,9 @@ class MixedOpPathSamplingPolicy(MixedOperationSamplingPolicy): return operation.init_arguments[name] -class PathSamplingRepeat(BaseSuperNetModule): +class PathSamplingRepeat(Repeat, BaseSuperNetModule): """ - Implementaion of Repeat in a path-sampling supernet. + Implementation of Repeat in a path-sampling supernet. Samples one / some of the prefixes of the repeated blocks. Attributes @@ -237,56 +199,43 @@ class PathSamplingRepeat(BaseSuperNetModule): Sampled depth. """ - def __init__(self, blocks: list[nn.Module], depth: ChoiceOf[int]): - super().__init__() - self.blocks: Any = blocks - self.depth = depth - self._space_spec: dict[str, ParameterSpec] = dedup_inner_choices([depth]) + def __init__(self, blocks: list[nn.Module], depth: MutableExpression[int]): + super().__init__(blocks, depth) self._sampled: list[int] | int | None = None def resample(self, memo): """Since depth is based on ValueChoice, we only need to randomly sample every leaf value choices.""" result = {} - for label in self._space_spec: + assert isinstance(self.depth_choice, Mutable) + for label, mutable in self.depth_choice.simplify().items(): if label in memo: result[label] = memo[label] else: - result[label] = random.choice(self._space_spec[label].values) + result[label] = mutable.random() - self._sampled = evaluate_value_choice_with_dict(self.depth, result) + self._sampled = self.depth_choice.freeze(result) return result def export(self, memo): """Random choose one if every choice not in memo.""" result = {} - for label in self._space_spec: + assert isinstance(self.depth_choice, Mutable) + for label, mutable in self.depth_choice.simplify().items(): if label not in memo: - result[label] = random.choice(self._space_spec[label].values) + result[label] = mutable.random() return result - def search_space_spec(self): - return self._space_spec - @classmethod def mutate(cls, module, name, memo, mutate_kwargs): - if isinstance(module, Repeat) and isinstance(module.depth_choice, ValueChoiceX): + if type(module) == Repeat and isinstance(module.depth_choice, MutableExpression): # Only interesting when depth is mutable - return cls(cast(List[nn.Module], module.blocks), module.depth_choice) + return cls(list(module.blocks), module.depth_choice) - def reduction(self, items: list[Any], sampled: list[Any]): + def _reduction(self, items: list[Any], sampled: list[Any]): """Override this to implement customized reduction.""" return weighted_sum(items) - def _save_module_to_state_dict(self, destination, prefix, keep_vars): - sampled: Any = [self._sampled] if not isinstance(self._sampled, list) else self._sampled - - for cur_depth, (name, module) in enumerate(self.blocks.named_children(), start=1): - if module is not None: - sub_state_dict(module, destination=destination, prefix=prefix + name + '.', keep_vars=keep_vars) - if not any(d > cur_depth for d in sampled): - break - def forward(self, x): if self._sampled is None: raise RuntimeError('At least one depth needs to be sampled before fprop.') @@ -299,7 +248,7 @@ class PathSamplingRepeat(BaseSuperNetModule): res.append(x) if not any(d > cur_depth for d in sampled): break - return self.reduction(res, sampled) + return self._reduction(res, sampled) class PathSamplingCell(BaseSuperNetModule): @@ -345,39 +294,46 @@ class PathSamplingCell(BaseSuperNetModule): # InputChoice is implicit in this graph. 
for i in self.output_node_indices: self.ops.append(nn.ModuleList()) - for k in range(i + self.num_predecessors): + for k in range(i): # Second argument in (i, **0**, k) is always 0. # One-shot strategy can't handle the cases where op spec is dependent on `op_index`. ops, _ = create_cell_op_candidates(op_factory, i, 0, k) self.op_names = list(ops.keys()) cast(nn.ModuleList, self.ops[-1]).append(nn.ModuleDict(ops)) - self.label = label + with label_scope(label) as self.label_scope: + for i in range(self.num_predecessors, self.num_nodes + self.num_predecessors): + for k in range(self.num_ops_per_node): + op_label = f'op_{i}_{k}' + input_label = f'input_{i}_{k}' + self.add_mutable(Categorical(self.op_names, label=op_label)) + # Need multiple here to align with the original cell. + self.add_mutable(CategoricalMultiple(range(i), n_chosen=1, label=input_label)) self._sampled: dict[str, str | int] = {} - def search_space_spec(self) -> dict[str, ParameterSpec]: - # TODO: Recreating the space here. - # The spec should be moved to definition of Cell itself. - space_spec = {} - for i in range(self.num_predecessors, self.num_nodes + self.num_predecessors): - for k in range(self.num_ops_per_node): - op_label = f'{self.label}/op_{i}_{k}' - input_label = f'{self.label}/input_{i}_{k}' - space_spec[op_label] = ParameterSpec(op_label, 'choice', self.op_names, (op_label,), True, size=len(self.op_names)) - space_spec[input_label] = ParameterSpec(input_label, 'choice', list(range(i)), (input_label, ), True, size=i) - return space_spec + @property + def label(self) -> str: + return self.label_scope.name + + def freeze(self, sample): + raise NotImplementedError('PathSamplingCell does not support freeze.') def resample(self, memo): """Random choose one path if label is not found in memo.""" self._sampled = {} new_sampled = {} - for label, param_spec in self.search_space_spec().items(): + for label, param_spec in self.simplify().items(): if label in memo: - assert not isinstance(memo[label], list), 'Multi-path sampling is currently unsupported on cell.' + if isinstance(memo[label], list) and len(memo[label]) > 1: + raise ValueError(f'Multi-path sampling is currently unsupported on cell: {memo[label]}') self._sampled[label] = memo[label] else: - self._sampled[label] = new_sampled[label] = random.choice(param_spec.values) + if isinstance(param_spec, Categorical): + self._sampled[label] = new_sampled[label] = random.choice(param_spec.values) + elif isinstance(param_spec, CategoricalMultiple): + assert param_spec.n_chosen == 1 + self._sampled[label] = new_sampled[label] = [random.choice(param_spec.values)] return new_sampled def export(self, memo): @@ -392,7 +348,7 @@ class PathSamplingCell(BaseSuperNetModule): for k in range(self.num_ops_per_node): # Select op list based on the input chosen - input_index = self._sampled[f'{self.label}/input_{i}_{k}'] + input_index = self._sampled[f'{self.label}/input_{i}_{k}'][0] # [0] because it's a list and n_chosen=1 op_candidates = ops[cast(int, input_index)] # Select op from op list based on the op chosen op_index = self._sampled[f'{self.label}/op_{i}_{k}'] @@ -411,7 +367,7 @@ class PathSamplingCell(BaseSuperNetModule): Mutate only handles cells of specific configurations (e.g., with loose end). Fallback to the default mutate if the cell is not handled here. 
""" - if isinstance(module, Cell): + if type(module) is Cell: op_factory = None # not all the cells need to be replaced if module.op_candidates_factory is not None: op_factory = module.op_candidates_factory @@ -420,10 +376,16 @@ class PathSamplingCell(BaseSuperNetModule): elif module.merge_op == 'loose_end': op_candidates_lc = module.ops[-1][-1] # type: ignore assert isinstance(op_candidates_lc, LayerChoice) - op_factory = { # create a factory - name: lambda _, __, ___: copy.deepcopy(op_candidates_lc[name]) - for name in op_candidates_lc.names - } + candidates = op_candidates_lc.candidates + def _copy(_, __, ___, op): + return copy.deepcopy(op) + + if isinstance(candidates, list): + op_factory = [functools.partial(_copy, op=op) for op in candidates] + elif isinstance(candidates, dict): + op_factory = {name: functools.partial(_copy, op=op) for name, op in candidates.items()} + else: + raise ValueError(f'Unsupported type of candidates: {type(candidates)}') if op_factory is not None: return cls( op_factory, diff --git a/test/algo/nas/oneshot/test_supermodules.py b/test/algo/nas/oneshot/test_supermodules.py new file mode 100644 index 000000000..adced554c --- /dev/null +++ b/test/algo/nas/oneshot/test_supermodules.py @@ -0,0 +1,633 @@ +import pytest + +import numpy as np +import torch +import torch.nn as nn + +import nni +from nni.mutable import MutableList, frozen +from nni.nas.nn.pytorch import LayerChoice, ModelSpace, Cell, MutableConv2d, MutableBatchNorm2d, MutableLayerNorm, MutableLinear, MutableMultiheadAttention +from nni.nas.oneshot.pytorch.differentiable import DartsLightningModule +from nni.nas.oneshot.pytorch.strategy import RandomOneShot, DARTS +from nni.nas.oneshot.pytorch.supermodule.base import BaseSuperNetModule +from nni.nas.oneshot.pytorch.supermodule.differentiable import ( + MixedOpDifferentiablePolicy, DifferentiableMixedLayer, DifferentiableMixedInput, GumbelSoftmax, + DifferentiableMixedRepeat, DifferentiableMixedCell +) +from nni.nas.oneshot.pytorch.supermodule.sampling import ( + MixedOpPathSamplingPolicy, PathSamplingLayer, PathSamplingInput, PathSamplingRepeat, PathSamplingCell +) +from nni.nas.oneshot.pytorch.supermodule.operation import MixedConv2d, NATIVE_MIXED_OPERATIONS +from nni.nas.oneshot.pytorch.supermodule.proxyless import ProxylessMixedLayer, ProxylessMixedInput +from nni.nas.oneshot.pytorch.supermodule._operation_utils import Slicable as S, MaybeWeighted as W +from nni.nas.oneshot.pytorch.supermodule._expression_utils import * + +from ut.nas.nn.models import ( + CellSimple, CellDefaultArgs, CellCustomProcessor, CellLooseEnd, CellOpFactory +) + +@pytest.fixture(autouse=True) +def context(): + frozen._ENSURE_FROZEN_STRICT = False + yield + frozen._ENSURE_FROZEN_STRICT = True + + +def test_slice(): + weight = np.ones((3, 7, 24, 23)) + assert S(weight)[:, 1:3, :, 9:13].shape == (3, 2, 24, 4) + assert S(weight)[:, 1:W(3)*2+1, :, 9:13].shape == (3, 6, 24, 4) + assert S(weight)[:, 1:W(3)*2+1].shape == (3, 6, 24, 23) + + # Ellipsis + assert S(weight)[..., 9:13].shape == (3, 7, 24, 4) + assert S(weight)[:2, ..., 1:W(3)+1].shape == (2, 7, 24, 3) + assert S(weight)[..., 1:W(3)*2+1].shape == (3, 7, 24, 6) + assert S(weight)[..., :10, 1:W(3)*2+1].shape == (3, 7, 10, 6) + + # no effect + assert S(weight)[:] is weight + + # list + assert S(weight)[[slice(1), slice(2, 3)]].shape == (2, 7, 24, 23) + assert S(weight)[[slice(1), slice(2, W(2) + 1)], W(2):].shape == (2, 5, 24, 23) + + # weighted + weight = S(weight)[:W({1: 0.5, 2: 0.3, 3: 0.2})] + weight = weight[:, 0, 0, 0] 
+ assert weight[0] == 1 and weight[1] == 0.5 and weight[2] == 0.2 + + weight = np.ones((3, 6, 6)) + value = W({1: 0.5, 3: 0.5}) + weight = S(weight)[:, 3 - value:3 + value, 3 - value:3 + value] + for i in range(0, 6): + for j in range(0, 6): + if 2 <= i <= 3 and 2 <= j <= 3: + assert weight[0, i, j] == 1 + else: + assert weight[1, i, j] == 0.5 + + # weighted + list + value = W({1: 0.5, 3: 0.5}) + weight = np.ones((8, 4)) + weight = S(weight)[[slice(value), slice(4, value + 4)]] + assert weight.sum(1).tolist() == [4, 2, 2, 0, 4, 2, 2, 0] + + with pytest.raises(ValueError, match='one distinct'): + # has to be exactly the same instance, equal is not enough + weight = S(weight)[:W({1: 0.5}), : W({1: 0.5})] + + +def test_valuechoice_utils(): + chosen = {"exp": 3, "add": 1} + vc0 = nni.choice('exp', [3, 4, 6]) * 2 + nni.choice('add', [0, 1]) + + assert vc0.freeze(chosen) == 7 + vc = vc0 + nni.choice('exp', [3, 4, 6]) + assert vc.freeze(chosen) == 10 + + assert list(MutableList([vc0, vc]).simplify().keys()) == ['exp', 'add'] + + assert traverse_all_options(vc) == [9, 10, 12, 13, 18, 19] + weights = dict(traverse_all_options(vc, weights={'exp': [0.5, 0.3, 0.2], 'add': [0.4, 0.6]})) + ans = dict([(9, 0.2), (10, 0.3), (12, 0.12), (13, 0.18), (18, 0.08), (19, 0.12)]) + assert len(weights) == len(ans) + for value, weight in ans.items(): + assert abs(weight - weights[value]) < 1e-6 + + assert evaluate_constant(nni.choice('x', [3, 4, 6]) - nni.choice('x', [3, 4, 6])) == 0 + with pytest.raises(ValueError): + evaluate_constant(nni.choice('x', [3, 4, 6]) - nni.choice('y', [3, 4, 6])) + + assert evaluate_constant(nni.choice('x', [3, 4, 6]) * 2 / nni.choice('x', [3, 4, 6])) == 2 + + +def test_expectation(): + vc = nni.choice('exp', [3, 4, 6]) * 2 + nni.choice('add', [0, 1]) + assert expression_expectation(vc, {'exp': [0.5, 0.3, 0.2], 'add': [0.4, 0.6]}) == 8.4 + + vc = sum([nni.choice(f'e{i}', [0, 1]) for i in range(100)]) + assert expression_expectation(vc, {f'e{i}': [0.5] * 2 for i in range(100)}) == 50 + + vc = nni.choice('a', [1, 2, 3]) * nni.choice('b', [1, 2, 3]) - nni.choice('c', [1, 2, 3]) + probs1 = [0.2, 0.3, 0.5] + probs2 = [0.1, 0.2, 0.7] + probs3 = [0.3, 0.4, 0.3] + expect = sum( + (i * j - k) * p1 * p2 * p3 + for i, p1 in enumerate(probs1, 1) + for j, p2 in enumerate(probs2, 1) + for k, p3 in enumerate(probs3, 1) + ) + assert abs(expression_expectation(vc, {'a': probs1, 'b': probs2, 'c': probs3}) - expect) < 1e-12 + + vc = nni.choice('a', [1, 2, 3]) + 1 + assert expression_expectation(vc, {'a': [0.2, 0.3, 0.5]}) == 3.3 + + +def test_weighted_sum(): + weights = [0.1, 0.2, 0.7] + items = [1, 2, 3] + assert abs(weighted_sum(items, weights) - 2.6) < 1e-6 + + assert weighted_sum(items) == 6 + + with pytest.raises(TypeError, match='Unsupported'): + weighted_sum(['a', 'b', 'c'], weights) + + assert abs(weighted_sum(np.arange(3), weights).item() - 1.6) < 1e-6 + + items = [torch.full((2, 3, 5), i) for i in items] + assert abs(weighted_sum(items, weights).flatten()[0].item() - 2.6) < 1e-6 + + items = [torch.randn(2, 3, i) for i in [1, 2, 3]] + with pytest.raises(ValueError, match=r'does not match.*\n.*torch\.Tensor\(2, 3, 1\)'): + weighted_sum(items, weights) + + items = [(1, 2), (3, 4), (5, 6)] + res = weighted_sum(items, weights) + assert len(res) == 2 and abs(res[0] - 4.2) < 1e-6 and abs(res[1] - 5.2) < 1e-6 + + items = [(1, 2), (3, 4), (5, 6, 7)] + with pytest.raises(ValueError): + weighted_sum(items, weights) + + items = [{"a": i, "b": np.full((2, 3, 5), i)} for i in [1, 2, 3]] + res = 
weighted_sum(items, weights) + assert res['b'].shape == (2, 3, 5) + assert abs(res['b'][0][0][0] - res['a']) < 1e-6 + assert abs(res['a'] - 2.6) < 1e-6 + + +def test_pathsampling_valuechoice(): + orig_conv = MutableConv2d(3, nni.choice('123', [3, 5, 7]), kernel_size=3) + conv = MixedConv2d.mutate(orig_conv, 'dummy', {}, {'mixed_op_sampling': MixedOpPathSamplingPolicy}) + conv.resample(memo={'123': 5}) + assert conv(torch.zeros((1, 3, 5, 5))).size(1) == 5 + conv.resample(memo={'123': 7}) + assert conv(torch.zeros((1, 3, 5, 5))).size(1) == 7 + assert conv.export({})['123'] in [3, 5, 7] + + +def test_differentiable_valuechoice(): + orig_conv = MutableConv2d(3, nni.choice('456', [3, 5, 7]), + kernel_size=nni.choice('123', [3, 5, 7]), + padding=nni.choice('123', [3, 5, 7]) // 2 + ) + memo = { + '123': nn.Parameter(torch.zeros(3)), + '456': nn.Parameter(torch.zeros(3)), + } + conv = MixedConv2d.mutate(orig_conv, 'dummy', memo, {'mixed_op_sampling': MixedOpDifferentiablePolicy}) + assert conv(torch.zeros((1, 3, 7, 7))).size(2) == 7 + + assert set(conv.export({}).keys()) == {'123', '456'} + + +def test_differentiable_layerchoice_dedup(): + layerchoice1 = LayerChoice([MutableConv2d(3, 3, 3), MutableConv2d(3, 3, 3)], label='a') + layerchoice2 = LayerChoice([MutableConv2d(3, 3, 3), MutableConv2d(3, 3, 3)], label='a') + + memo = {'a': nn.Parameter(torch.zeros(2))} + DifferentiableMixedLayer.mutate(layerchoice1, 'x', memo, {}) + DifferentiableMixedLayer.mutate(layerchoice2, 'x', memo, {}) + assert len(memo) == 1 and 'a' in memo + + +def _mutate_op_path_sampling_policy(operation): + for native_op in NATIVE_MIXED_OPERATIONS: + if native_op.bound_type == type(operation): + mutate_op = native_op.mutate(operation, 'dummy', {}, {'mixed_op_sampling': MixedOpPathSamplingPolicy}) + break + return mutate_op + + +def _mixed_operation_sampling_sanity_check(operation, memo, *input): + mutate_op = _mutate_op_path_sampling_policy(operation) + mutate_op.resample(memo=memo) + return mutate_op(*input) + + +def _mixed_operation_state_dict_sanity_check(operation, model, memo, *input): + mutate_op = _mutate_op_path_sampling_policy(operation) + mutate_op.resample(memo=memo) + frozen_op = mutate_op.freeze(memo) + return frozen_op(*input), mutate_op(*input) + + +def _mixed_operation_differentiable_sanity_check(operation, *input): + memo = {k: nn.Parameter(torch.zeros(len(v))) for k, v in operation.simplify().items()} + for native_op in NATIVE_MIXED_OPERATIONS: + if native_op.bound_type == type(operation): + mutate_op = native_op.mutate(operation, 'dummy', memo, {'mixed_op_sampling': MixedOpDifferentiablePolicy}) + break + + mutate_op(*input) + mutate_op.export({}) + mutate_op.export_probs({}) + + +def test_mixed_linear(): + linear = MutableLinear(nni.choice('shared', [3, 6, 9]), nni.choice('xx', [2, 4, 8])) + _mixed_operation_sampling_sanity_check(linear, {'shared': 3}, torch.randn(2, 3)) + _mixed_operation_sampling_sanity_check(linear, {'shared': 9}, torch.randn(2, 9)) + _mixed_operation_differentiable_sanity_check(linear, torch.randn(2, 9)) + + linear = MutableLinear(nni.choice('shared', [3, 6, 9]), nni.choice('xx', [2, 4, 8]), bias=False) + _mixed_operation_sampling_sanity_check(linear, {'shared': 3}, torch.randn(2, 3)) + + with pytest.raises(TypeError): + linear = MutableLinear(nni.choice('shared', [3, 6, 9]), nni.choice('xx', [2, 4, 8]), bias=nni.choice('yy', [False, True])) + _mixed_operation_sampling_sanity_check(linear, {'shared': 3}, torch.randn(2, 3)) + + linear = MutableLinear(nni.choice('in_features', [3, 6, 9]), 
nni.choice('out_features', [2, 4, 8]), bias=True) + kwargs = {'in_features': 6, 'out_features': 4} + out1, out2 = _mixed_operation_state_dict_sanity_check(linear, MutableLinear(**kwargs), kwargs, torch.randn(2, 6)) + assert torch.allclose(out1, out2) + + +def test_mixed_conv2d(): + conv = MutableConv2d(nni.choice('in', [3, 6, 9]), nni.choice('out', [2, 4, 8]) * 2, 1) + assert _mixed_operation_sampling_sanity_check(conv, {'in': 3, 'out': 4}, torch.randn(2, 3, 9, 9)).size(1) == 8 + _mixed_operation_differentiable_sanity_check(conv, torch.randn(2, 9, 3, 3)) + + # stride + conv = MutableConv2d(nni.choice('in', [3, 6, 9]), nni.choice('out', [2, 4, 8]), 1, stride=nni.choice('stride', [1, 2])) + assert _mixed_operation_sampling_sanity_check(conv, {'in': 3, 'stride': 2}, torch.randn(2, 3, 10, 10)).size(2) == 5 + assert _mixed_operation_sampling_sanity_check(conv, {'in': 3, 'stride': 1}, torch.randn(2, 3, 10, 10)).size(2) == 10 + with pytest.raises(ValueError, match='must not be mutable'): + _mixed_operation_differentiable_sanity_check(conv, torch.randn(2, 9, 10, 10)) + + # groups, dw conv + conv = MutableConv2d(nni.choice('in', [3, 6, 9]), nni.choice('in', [3, 6, 9]), 1, groups=nni.choice('in', [3, 6, 9])) + assert _mixed_operation_sampling_sanity_check(conv, {'in': 6}, torch.randn(2, 6, 10, 10)).size() == torch.Size([2, 6, 10, 10]) + + # groups, invalid case + conv = MutableConv2d(nni.choice('in', [9, 6, 3]), nni.choice('in', [9, 6, 3]), 1, groups=9) + with pytest.raises(RuntimeError): + assert _mixed_operation_sampling_sanity_check(conv, {'in': 6}, torch.randn(2, 6, 10, 10)) + + # groups, differentiable + conv = MutableConv2d(nni.choice('in', [3, 6, 9]), nni.choice('out', [3, 6, 9]), 1, groups=nni.choice('in', [3, 6, 9])) + _mixed_operation_differentiable_sanity_check(conv, torch.randn(2, 9, 3, 3)) + + conv = MutableConv2d(nni.choice('in', [3, 6, 9]), nni.choice('in', [3, 6, 9]), 1, groups=nni.choice('in', [3, 6, 9])) + _mixed_operation_differentiable_sanity_check(conv, torch.randn(2, 9, 3, 3)) + + with pytest.raises(ValueError): + conv = MutableConv2d(nni.choice('in', [3, 6, 9]), nni.choice('in', [3, 6, 9]), 1, groups=nni.choice('groups', [3, 9])) + _mixed_operation_differentiable_sanity_check(conv, torch.randn(2, 9, 3, 3)) + + with pytest.raises(RuntimeError): + conv = MutableConv2d(nni.choice('in', [3, 6, 9]), nni.choice('in', [3, 6, 9]), 1, groups=nni.choice('in', [3, 6, 9]) // 3) + _mixed_operation_differentiable_sanity_check(conv, torch.randn(2, 10, 3, 3)) + + # make sure kernel is sliced correctly + conv = MutableConv2d(1, 1, nni.choice('k', [1, 3]), bias=False) + conv = MixedConv2d.mutate(conv, 'dummy', {}, {'mixed_op_sampling': MixedOpPathSamplingPolicy}) + with torch.no_grad(): + conv.weight.zero_() + # only center is 1, must pick center to pass this test + conv.weight[0, 0, 1, 1] = 1 + conv.resample({'k': 1}) + assert conv(torch.ones((1, 1, 3, 3))).sum().item() == 9 + + # only `in_channels`, `out_channels`, `kernel_size`, and `groups` influence state_dict + conv = MutableConv2d( + nni.choice('in_channels', [2, 4, 8]), nni.choice('out_channels', [6, 12, 24]), + kernel_size=nni.choice('kernel_size', [3, 5, 7]), groups=nni.choice('groups', [1, 2]) + ) + kwargs = { + 'in_channels': 8, 'out_channels': 12, + 'kernel_size': 5, 'groups': 2 + } + out1, out2 = _mixed_operation_state_dict_sanity_check(conv, MutableConv2d(**kwargs), kwargs, torch.randn(2, 8, 16, 16)) + assert torch.allclose(out1, out2) + +def test_mixed_batchnorm2d(): + bn = MutableBatchNorm2d(nni.choice('dim', [32, 64])) + + 
assert _mixed_operation_sampling_sanity_check(bn, {'dim': 32}, torch.randn(2, 32, 3, 3)).size(1) == 32 + assert _mixed_operation_sampling_sanity_check(bn, {'dim': 64}, torch.randn(2, 64, 3, 3)).size(1) == 64 + + _mixed_operation_differentiable_sanity_check(bn, torch.randn(2, 64, 3, 3)) + + bn = MutableBatchNorm2d(nni.choice('num_features', [32, 48, 64])) + kwargs = {'num_features': 48} + out1, out2 = _mixed_operation_state_dict_sanity_check(bn, MutableBatchNorm2d(**kwargs), kwargs, torch.randn(2, 48, 3, 3)) + assert torch.allclose(out1, out2) + +def test_mixed_layernorm(): + ln = MutableLayerNorm(nni.choice('normalized_shape', [32, 64]), elementwise_affine=True) + + assert _mixed_operation_sampling_sanity_check(ln, {'normalized_shape': 32}, torch.randn(2, 16, 32)).size(-1) == 32 + assert _mixed_operation_sampling_sanity_check(ln, {'normalized_shape': 64}, torch.randn(2, 16, 64)).size(-1) == 64 + + _mixed_operation_differentiable_sanity_check(ln, torch.randn(2, 16, 64)) + + import itertools + ln = MutableLayerNorm(nni.choice('normalized_shape', list(itertools.product([16, 32, 64], [8, 16])))) + + assert list(_mixed_operation_sampling_sanity_check(ln, {'normalized_shape': (16, 8)}, torch.randn(2, 16, 8)).shape[-2:]) == [16, 8] + assert list(_mixed_operation_sampling_sanity_check(ln, {'normalized_shape': (64, 16)}, torch.randn(2, 64, 16)).shape[-2:]) == [64, 16] + + _mixed_operation_differentiable_sanity_check(ln, torch.randn(2, 64, 16)) + + ln = MutableLayerNorm(nni.choice('normalized_shape', [32, 48, 64])) + kwargs = {'normalized_shape': 48} + out1, out2 = _mixed_operation_state_dict_sanity_check(ln, MutableLayerNorm(**kwargs), kwargs, torch.randn(2, 8, 48)) + assert torch.allclose(out1, out2) + +def test_mixed_mhattn(): + mhattn = MutableMultiheadAttention(nni.choice('emb', [4, 8]), 4) + + assert _mixed_operation_sampling_sanity_check(mhattn, {'emb': 4}, + torch.randn(7, 2, 4), torch.randn(7, 2, 4), torch.randn(7, 2, 4))[0].size(-1) == 4 + assert _mixed_operation_sampling_sanity_check(mhattn, {'emb': 8}, + torch.randn(7, 2, 8), torch.randn(7, 2, 8), torch.randn(7, 2, 8))[0].size(-1) == 8 + + _mixed_operation_differentiable_sanity_check(mhattn, torch.randn(7, 2, 8), torch.randn(7, 2, 8), torch.randn(7, 2, 8)) + + mhattn = MutableMultiheadAttention(nni.choice('emb', [4, 8]), nni.choice('heads', [2, 3, 4])) + assert _mixed_operation_sampling_sanity_check(mhattn, {'emb': 4, 'heads': 2}, + torch.randn(7, 2, 4), torch.randn(7, 2, 4), torch.randn(7, 2, 4))[0].size(-1) == 4 + with pytest.raises(AssertionError, match='divisible'): + assert _mixed_operation_sampling_sanity_check(mhattn, {'emb': 4, 'heads': 3}, + torch.randn(7, 2, 4), torch.randn(7, 2, 4), torch.randn(7, 2, 4))[0].size(-1) == 4 + + mhattn = MutableMultiheadAttention(nni.choice('emb', [4, 8]), 4, kdim=nni.choice('kdim', [5, 7])) + assert _mixed_operation_sampling_sanity_check(mhattn, {'emb': 4, 'kdim': 7}, + torch.randn(7, 2, 4), torch.randn(7, 2, 7), torch.randn(7, 2, 4))[0].size(-1) == 4 + assert _mixed_operation_sampling_sanity_check(mhattn, {'emb': 8, 'kdim': 5}, + torch.randn(7, 2, 8), torch.randn(7, 2, 5), torch.randn(7, 2, 8))[0].size(-1) == 8 + + mhattn = MutableMultiheadAttention(nni.choice('emb', [4, 8]), 4, vdim=nni.choice('vdim', [5, 8])) + assert _mixed_operation_sampling_sanity_check(mhattn, {'emb': 4, 'vdim': 8}, + torch.randn(7, 2, 4), torch.randn(7, 2, 4), torch.randn(7, 2, 8))[0].size(-1) == 4 + assert _mixed_operation_sampling_sanity_check(mhattn, {'emb': 8, 'vdim': 5}, + torch.randn(7, 2, 8), torch.randn(7, 2, 8), 
torch.randn(7, 2, 5))[0].size(-1) == 8 + + _mixed_operation_differentiable_sanity_check(mhattn, torch.randn(5, 3, 8), torch.randn(5, 3, 8), torch.randn(5, 3, 8)) + + mhattn = MutableMultiheadAttention(embed_dim=nni.choice('embed_dim', [4, 8, 16]), num_heads=nni.choice('num_heads', [1, 2, 4]), + kdim=nni.choice('kdim', [4, 8, 16]), vdim=nni.choice('vdim', [4, 8, 16])) + kwargs = {'embed_dim': 16, 'num_heads': 2, 'kdim': 4, 'vdim': 8} + (out1, _), (out2, _) = _mixed_operation_state_dict_sanity_check(mhattn, MutableMultiheadAttention(**kwargs), kwargs, torch.randn(7, 2, 16), torch.randn(7, 2, 4), torch.randn(7, 2, 8)) + assert torch.allclose(out1, out2) + +@pytest.mark.skipif(torch.__version__.startswith('1.7'), reason='batch_first is not supported for legacy PyTorch') +def test_mixed_mhattn_batch_first(): + # batch_first is not supported for legacy pytorch versions + # mark 1.7 because 1.7 is used on legacy pipeline + + mhattn = MutableMultiheadAttention(nni.choice('emb', [4, 8]), 2, kdim=(nni.choice('kdim', [3, 7])), vdim=nni.choice('vdim', [5, 8]), + bias=False, add_bias_kv=True, batch_first=True) + assert _mixed_operation_sampling_sanity_check(mhattn, {'emb': 4, 'kdim': 7, 'vdim': 8}, + torch.randn(2, 7, 4), torch.randn(2, 7, 7), torch.randn(2, 7, 8))[0].size(-1) == 4 + assert _mixed_operation_sampling_sanity_check(mhattn, {'emb': 8, 'kdim': 3, 'vdim': 5}, + torch.randn(2, 7, 8), torch.randn(2, 7, 3), torch.randn(2, 7, 5))[0].size(-1) == 8 + + _mixed_operation_differentiable_sanity_check(mhattn, torch.randn(1, 7, 8), torch.randn(1, 7, 7), torch.randn(1, 7, 8)) + + +def test_pathsampling_layer_input(): + op = PathSamplingLayer({'a': MutableLinear(2, 3, bias=False), 'b': MutableLinear(2, 3, bias=True)}, label='ccc') + with pytest.raises(RuntimeError, match='sample'): + op(torch.randn(4, 2)) + + op.resample({}) + assert op(torch.randn(4, 2)).size(-1) == 3 + assert op.simplify()['ccc'].values == ['a', 'b'] + assert op.export({})['ccc'] in ['a', 'b'] + + input = PathSamplingInput(5, 2, 'concat', 'ddd') + sample = input.resample({}) + assert 'ddd' in sample + assert len(sample['ddd']) == 2 + assert input([torch.randn(4, 2) for _ in range(5)]).size(-1) == 4 + assert len(input.export({})['ddd']) == 2 + + +def test_differentiable_layer_input(): + op = DifferentiableMixedLayer({'a': MutableLinear(2, 3, bias=False), 'b': MutableLinear(2, 3, bias=True)}, nn.Parameter(torch.randn(2)), nn.Softmax(-1), 'eee') + assert op(torch.randn(4, 2)).size(-1) == 3 + assert op.export({})['eee'] in ['a', 'b'] + probs = op.export_probs({}) + assert len(probs) == 1 + assert len(probs['eee']) == 2 + assert abs(probs['eee']['a'] + probs['eee']['b'] - 1) < 1e-4 + assert len(list(op.parameters())) == 4 + assert len(list(op.arch_parameters())) == 1 + + with pytest.raises(ValueError): + op = DifferentiableMixedLayer({'a': MutableLinear(2, 3), 'b': MutableLinear(2, 4)}, nn.Parameter(torch.randn(2)), nn.Softmax(-1), 'eee') + op(torch.randn(4, 2)) + + input = DifferentiableMixedInput(5, 2, nn.Parameter(torch.zeros(5)), GumbelSoftmax(-1), 'ddd') + assert input([torch.randn(4, 2) for _ in range(5)]).size(-1) == 2 + assert len(input.export({})['ddd']) == 2 + assert len(input.export_probs({})) == 1 + assert len(input.export_probs({})['ddd']) == 5 + assert 3 in input.export_probs({})['ddd'] + + +def test_proxyless_layer_input(): + op = ProxylessMixedLayer({'a': MutableLinear(2, 3, bias=False), 'b': MutableLinear(2, 3, bias=True)}, nn.Parameter(torch.randn(2)), + nn.Softmax(-1), 'eee') + assert op.resample({})['eee'] in ['a', 'b'] + 
assert op(torch.randn(4, 2)).size(-1) == 3 + assert op.export({})['eee'] in ['a', 'b'] + assert len(list(op.parameters())) == 4 + assert len(list(op.arch_parameters())) == 1 + + input = ProxylessMixedInput(5, 2, nn.Parameter(torch.zeros(5)), GumbelSoftmax(-1), 'ddd') + assert all(x in list(range(5)) for x in input.resample({})['ddd']) + assert input([torch.randn(4, 2) for _ in range(5)]).size() == torch.Size([4, 2]) + exported = input.export({})['ddd'] + assert len(exported) == 2 and all(e in list(range(5)) for e in exported) + + +def test_pathsampling_repeat(): + op = PathSamplingRepeat([MutableLinear(16, 16), MutableLinear(16, 8), MutableLinear(8, 4)], nni.choice('ccc', [1, 2, 3])) + sample = op.resample({}) + assert sample['ccc'] in [1, 2, 3] + for i in range(1, 4): + op.resample({'ccc': i}) + out = op(torch.randn(2, 16)) + assert out.shape[1] == [16, 8, 4][i - 1] + + op = PathSamplingRepeat([MutableLinear(i + 1, i + 2) for i in range(7)], 2 * nni.choice('ddd', [1, 2, 3]) + 1) + sample = op.resample({}) + assert sample['ddd'] in [1, 2, 3] + for i in range(1, 4): + op.resample({'ddd': i}) + out = op(torch.randn(2, 1)) + assert out.shape[1] == (2 * i + 1) + 1 + + +def test_differentiable_repeat(): + op = DifferentiableMixedRepeat( + [MutableLinear(8 if i == 0 else 16, 16) for i in range(4)], + nni.choice('ccc', [0, 1]) * 2 + 1, + GumbelSoftmax(-1), + {'ccc': nn.Parameter(torch.randn(2))}, + ) + op.resample({}) + assert op(torch.randn(2, 8)).size() == torch.Size([2, 16]) + sample = op.export({}) + assert 'ccc' in sample and sample['ccc'] in [0, 1] + assert sorted(op.export_probs({})['ccc'].keys()) == [0, 1] + + class TupleModule(nn.Module): + def __init__(self, num): + super().__init__() + self.num = num + + def forward(self, *args, **kwargs): + return torch.full((2, 3), self.num), torch.full((3, 5), self.num), {'a': 7, 'b': [self.num] * 11} + + class CustomSoftmax(nn.Softmax): + def forward(self, *args, **kwargs): + return [0.3, 0.3, 0.4] + + op = DifferentiableMixedRepeat( + [TupleModule(i + 1) for i in range(4)], + nni.choice('ccc', [1, 2, 4]), + CustomSoftmax(), + {'ccc': nn.Parameter(torch.randn(3))}, + ) + op.resample({}) + res = op(None) + assert len(res) == 3 + assert res[0].shape == (2, 3) and res[0][0][0].item() == 2.5 + assert res[2]['a'] == 7 + assert len(res[2]['b']) == 11 and res[2]['b'][-1] == 2.5 + + +def test_pathsampling_cell(): + for cell_cls in [CellSimple, CellDefaultArgs, CellCustomProcessor, CellLooseEnd, CellOpFactory]: + model = cell_cls() + strategy = RandomOneShot() + model = strategy.mutate_model(model) + nas_modules = [m for m in model.modules() if isinstance(m, BaseSuperNetModule)] + result = {} + for module in nas_modules: + result.update(module.resample(memo=result)) + assert len(result) == model.cell.num_nodes * model.cell.num_ops_per_node * 2 + result = {} + for module in nas_modules: + result.update(module.export(memo=result)) + assert len(result) == model.cell.num_nodes * model.cell.num_ops_per_node * 2 + + if cell_cls in [CellLooseEnd, CellOpFactory]: + assert isinstance(model.cell, PathSamplingCell) + else: + assert not isinstance(model.cell, PathSamplingCell) + + inputs = { + CellSimple: (torch.randn(2, 16), torch.randn(2, 16)), + CellDefaultArgs: (torch.randn(2, 16),), + CellCustomProcessor: (torch.randn(2, 3), torch.randn(2, 16)), + CellLooseEnd: (torch.randn(2, 16), torch.randn(2, 16)), + CellOpFactory: (torch.randn(2, 3), torch.randn(2, 16)), + }[cell_cls] + + output = model(*inputs) + if cell_cls == CellCustomProcessor: + assert 
isinstance(output, tuple) and len(output) == 2 and \ + output[1].shape == torch.Size([2, 16 * model.cell.num_nodes]) + else: + # no loose-end support for now + assert output.shape == torch.Size([2, 16 * model.cell.num_nodes]) + + +def test_differentiable_cell(): + for cell_cls in [CellSimple, CellDefaultArgs, CellCustomProcessor, CellLooseEnd, CellOpFactory]: + model = cell_cls() + strategy = DARTS() + model = strategy.mutate_model(model) + nas_modules = [m for m in model.modules() if isinstance(m, BaseSuperNetModule)] + result = {} + for module in nas_modules: + result.update(module.export(memo=result)) + assert len(result) == model.cell.num_nodes * model.cell.num_ops_per_node * 2 + for k, v in result.items(): + if 'input' in k: + assert isinstance(v, list) and len(v) == 1 + + result_prob = {} + for module in nas_modules: + result_prob.update(module.export_probs(memo=result_prob)) + + ctrl_params = [] + for m in nas_modules: + ctrl_params += list(m.arch_parameters()) + if cell_cls in [CellLooseEnd, CellOpFactory]: + assert len(ctrl_params) == model.cell.num_nodes * (model.cell.num_nodes + 3) // 2 + assert len(result_prob) == len(ctrl_params) # len(op_names) == 2 + for v in result_prob.values(): + assert len(v) == 2 + assert isinstance(model.cell, DifferentiableMixedCell) + else: + assert not isinstance(model.cell, DifferentiableMixedCell) + + inputs = { + CellSimple: (torch.randn(2, 16), torch.randn(2, 16)), + CellDefaultArgs: (torch.randn(2, 16),), + CellCustomProcessor: (torch.randn(2, 3), torch.randn(2, 16)), + CellLooseEnd: (torch.randn(2, 16), torch.randn(2, 16)), + CellOpFactory: (torch.randn(2, 3), torch.randn(2, 16)), + }[cell_cls] + + output = model(*inputs) + if cell_cls == CellCustomProcessor: + assert isinstance(output, tuple) and len(output) == 2 and \ + output[1].shape == torch.Size([2, 16 * model.cell.num_nodes]) + else: + # no loose-end support for now + assert output.shape == torch.Size([2, 16 * model.cell.num_nodes]) + + +def test_memo_sharing(): + class TestModelSpace(ModelSpace): + def __init__(self): + super().__init__() + self.linear1 = Cell( + [nn.Linear(16, 16), nn.Linear(16, 16, bias=False)], + num_nodes=3, num_ops_per_node=2, num_predecessors=2, merge_op='loose_end', + label='cell' + ) + self.linear2 = Cell( + [nn.Linear(16, 16), nn.Linear(16, 16, bias=False)], + num_nodes=3, num_ops_per_node=2, num_predecessors=2, merge_op='loose_end', + label='cell' + ) + + strategy = DARTS() + model = strategy.mutate_model(TestModelSpace()) + assert model.linear1._arch_alpha['cell/2_0'] is model.linear2._arch_alpha['cell/2_0'] + + +def test_parameters(): + class Model(ModelSpace): + def __init__(self): + super().__init__() + self.op = DifferentiableMixedLayer( + { + 'a': MutableLinear(2, 3, bias=False), + 'b': MutableLinear(2, 3, bias=True) + }, + nn.Parameter(torch.randn(2)), nn.Softmax(-1), 'abc' + ) + + def forward(self, x): + return self.op(x) + + model = Model() + assert len(list(model.parameters())) == 4 + assert len(list(model.op.arch_parameters())) == 1 + + optimizer = torch.optim.SGD(model.parameters(), 0.1) + assert len(DartsLightningModule(model).arch_parameters()) == 1 + optimizer = DartsLightningModule(model).postprocess_weight_optimizers(optimizer) + assert len(optimizer.param_groups[0]['params']) == 3
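+    # Expected: postprocess_weight_optimizers drops the architecture parameter (the alpha of
+    # DifferentiableMixedLayer), so only the two Linear weights and the single bias remain.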