Mirror of https://github.com/microsoft/nni.git
NAS oneshot (stage 2) - Supernet modules (#5372)
This commit is contained in:
Parent: 4bd3f33a3a
Commit: d40f408a0a
@ -9,59 +9,12 @@ from typing import Any, Dict
|
|||
|
||||
import torch.nn as nn
|
||||
|
||||
from nni.common.hpo_utils import ParameterSpec
|
||||
from nni.nas.nn.pytorch import MutableModule
|
||||
|
||||
__all__ = ['BaseSuperNetModule', 'sub_state_dict']
|
||||
__all__ = ['BaseSuperNetModule']
|
||||
|
||||
|
||||
def sub_state_dict(module: Any, destination: Any=None, prefix: str='', keep_vars: bool=False) -> Dict[str, Any]:
    """Returns a dictionary containing a whole state of the BaseSuperNetModule.

    Both parameters and persistent buffers (e.g. running averages) are
    included. Keys are corresponding parameter and buffer names.
    Parameters and buffers set to ``None`` are not included.

    Parameters
    ----------
    module : Any
        The module (typically a :class:`BaseSuperNetModule`) whose subnet state is extracted.
    destination (dict, optional):
        If provided, the state of module will be updated into the dict
        and the same object is returned. Otherwise, an ``OrderedDict``
        will be created and returned. Default: ``None``.
    prefix (str, optional):
        a prefix added to parameter and buffer names to compose the keys in state_dict.
        Default: ``''``.
    keep_vars (bool, optional):
        by default the :class:`~torch.Tensor` s returned in the state dict are
        detached from autograd. If it's set to ``True``, detaching will not be performed.
        Default: ``False``.

    Returns
    -------
    dict
        Subnet state dictionary.
    """
    if destination is None:
        destination = OrderedDict()
        destination._metadata = OrderedDict()

    local_metadata = dict(version=module._version)
    if hasattr(destination, "_metadata"):
        destination._metadata[prefix[:-1]] = local_metadata

    if isinstance(module, BaseSuperNetModule):
        module._save_to_sub_state_dict(destination, prefix, keep_vars)
    else:
        module._save_to_state_dict(destination, prefix, keep_vars)
        for name, m in module._modules.items():
            if m is not None:
                sub_state_dict(m, destination=destination, prefix=prefix + name + '.', keep_vars=keep_vars)

    return destination
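For intuition, the recursion above mirrors how ``torch.nn.Module.state_dict`` composes keys from dotted prefixes. A minimal, self-contained sketch of that prefix walk in plain PyTorch (the helper name is illustrative, not NNI API):

from collections import OrderedDict
import torch.nn as nn

def walk_state(module: nn.Module, destination=None, prefix=''):
    # Collect parameters and buffers of `module` and all children, keyed by dotted prefix.
    if destination is None:
        destination = OrderedDict()
    for name, value in list(module._parameters.items()) + list(module._buffers.items()):
        if value is not None:
            destination[prefix + name] = value.detach()
    for name, child in module._modules.items():
        if child is not None:
            walk_state(child, destination, prefix + name + '.')
    return destination

net = nn.Sequential(nn.Linear(4, 4), nn.BatchNorm1d(4))
assert set(walk_state(net)) == set(net.state_dict())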
|
||||
|
||||
|
||||
class BaseSuperNetModule(nn.Module):
|
||||
class BaseSuperNetModule(MutableModule):
|
||||
"""
|
||||
Mutated module in super-net.
|
||||
Usually, the feed-forward of the module itself is undefined.
|
||||
|
@ -116,17 +69,6 @@ class BaseSuperNetModule(nn.Module):
|
|||
"""
|
||||
raise NotImplementedError()
|
||||
|
||||
def search_space_spec(self) -> dict[str, ParameterSpec]:
|
||||
"""
|
||||
Space specification (sample points).
|
||||
Mapping from spec name to ParameterSpec. The names in choices should be in the same format of export.
|
||||
|
||||
For example: ::
|
||||
|
||||
{"layer1": ParameterSpec(values=["conv", "pool"])}
|
||||
"""
|
||||
raise NotImplementedError()
|
||||
|
||||
@classmethod
|
||||
def mutate(cls, module: nn.Module, name: str, memo: dict[str, Any], mutate_kwargs: dict[str, Any]) -> \
|
||||
'BaseSuperNetModule' | bool | tuple['BaseSuperNetModule', bool]:
|
||||
|
@ -153,22 +95,3 @@ class BaseSuperNetModule(nn.Module):
|
|||
See :class:`BaseOneShotLightningModule <nni.retiarii.oneshot.pytorch.base_lightning.BaseOneShotLightningModule>` for details.
|
||||
"""
|
||||
raise NotImplementedError()
|
||||
|
||||
def _save_param_buff_to_state_dict(self, destination, prefix, keep_vars):
|
||||
"""Save the params and buffers of the current module to state dict."""
|
||||
for name, value in itertools.chain(self._parameters.items(), self._buffers.items()): # direct children
|
||||
if value is None or name in self._non_persistent_buffers_set:
|
||||
# it won't appear in state dict
|
||||
continue
|
||||
destination[prefix + name] = value if keep_vars else value.detach()
|
||||
|
||||
def _save_module_to_state_dict(self, destination, prefix, keep_vars):
|
||||
"""Save the sub-module to state dict."""
|
||||
for name, module in self._modules.items():
|
||||
if module is not None:
|
||||
sub_state_dict(module, destination=destination, prefix=prefix + name + '.', keep_vars=keep_vars)
|
||||
|
||||
def _save_to_sub_state_dict(self, destination, prefix, keep_vars):
|
||||
"""Save to state dict."""
|
||||
self._save_param_buff_to_state_dict(destination, prefix, keep_vars)
|
||||
self._save_module_to_state_dict(destination, prefix, keep_vars)
|
||||
|
|
|
@ -13,15 +13,14 @@ import torch
|
|||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
|
||||
from nni.common.hpo_utils import ParameterSpec
|
||||
from nni.nas.nn.pytorch import LayerChoice, InputChoice, ChoiceOf, Repeat
|
||||
from nni.nas.nn.pytorch.choice import ValueChoiceX
|
||||
from nni.mutable import MutableExpression, Mutable, Categorical
|
||||
from nni.nas.nn.pytorch import LayerChoice, InputChoice, Repeat
|
||||
from nni.nas.nn.pytorch.cell import preprocess_cell_inputs
|
||||
|
||||
from .base import BaseSuperNetModule
|
||||
from .operation import MixedOperation, MixedOperationSamplingPolicy
|
||||
from .sampling import PathSamplingCell
|
||||
from ._valuechoice_utils import traverse_all_options, dedup_inner_choices, weighted_sum
|
||||
from ._expression_utils import traverse_all_options, weighted_sum
|
||||
|
||||
_logger = logging.getLogger(__name__)
|
||||
|
||||
|
@ -46,7 +45,7 @@ class GumbelSoftmax(nn.Softmax):
|
|||
return F.gumbel_softmax(inputs, tau=self.tau, hard=self.hard, dim=self.dim)
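For reference, what the Gumbel-Softmax relaxation does to a set of architecture logits (plain PyTorch; ``tau`` and ``hard`` are the same knobs as above):

import torch
import torch.nn.functional as F

alpha = torch.tensor([1.0, 0.2, -0.5])                 # architecture logits
soft = F.gumbel_softmax(alpha, tau=1.0, hard=False)    # noisy but fully differentiable weights
hard = F.gumbel_softmax(alpha, tau=1.0, hard=True)     # one-hot in the forward pass, soft gradient
print(soft.sum().item(), hard)                         # soft sums to 1; hard is a one-hot vector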
|
||||
|
||||
|
||||
class DifferentiableMixedLayer(BaseSuperNetModule):
|
||||
class DifferentiableMixedLayer(LayerChoice, BaseSuperNetModule):
|
||||
"""
|
||||
Mixed layer, in which fprop is decided by a weighted sum of several layers.
|
||||
Proposed in `DARTS: Differentiable Architecture Search <https://arxiv.org/abs/1806.09055>`__.
|
||||
|
@ -66,31 +65,18 @@ class DifferentiableMixedLayer(BaseSuperNetModule):
|
|||
Customizable softmax function. Usually ``nn.Softmax(-1)``.
|
||||
label : str
|
||||
Name of the choice.
|
||||
|
||||
Attributes
|
||||
----------
|
||||
op_names : str
|
||||
Operator names.
|
||||
label : str
|
||||
Name of the choice.
|
||||
"""
|
||||
|
||||
_arch_parameter_names: list[str] = ['_arch_alpha']
|
||||
|
||||
def __init__(self,
|
||||
paths: list[tuple[str, nn.Module]],
|
||||
paths: list[nn.Module] | dict[str, nn.Module],
|
||||
alpha: torch.Tensor,
|
||||
softmax: nn.Module,
|
||||
label: str):
|
||||
super().__init__()
|
||||
self.op_names = []
|
||||
super().__init__(paths, label=label)
|
||||
if len(alpha) != len(paths):
|
||||
raise ValueError(f'The size of alpha ({len(alpha)}) must match number of candidates ({len(paths)}).')
|
||||
for name, module in paths:
|
||||
self.add_module(name, module)
|
||||
self.op_names.append(name)
|
||||
assert self.op_names, 'There has to be at least one op to choose from.'
|
||||
self.label = label
|
||||
self._arch_alpha = alpha
|
||||
self._softmax = softmax
|
||||
|
||||
|
@ -102,62 +88,41 @@ class DifferentiableMixedLayer(BaseSuperNetModule):
|
|||
"""Choose the operator with the maximum logit."""
|
||||
if self.label in memo:
|
||||
return {} # nothing new to export
|
||||
return {self.label: self.op_names[int(torch.argmax(self._arch_alpha).item())]}
|
||||
return {self.label: self.names[int(torch.argmax(self._arch_alpha).item())]}
|
||||
|
||||
def export_probs(self, memo):
|
||||
if any(k.startswith(self.label + '/') for k in memo):
|
||||
return {} # nothing new
|
||||
if self.label in memo:
|
||||
return {}
|
||||
weights = self._softmax(self._arch_alpha).cpu().tolist()
|
||||
ret = {f'{self.label}/{name}': value for name, value in zip(self.op_names, weights)}
|
||||
return ret
|
||||
|
||||
def search_space_spec(self):
|
||||
return {self.label: ParameterSpec(self.label, 'choice', self.op_names, (self.label, ),
|
||||
True, size=len(self.op_names))}
|
||||
return {self.label: dict(zip(self.names, weights))}
|
||||
|
||||
@classmethod
|
||||
def mutate(cls, module, name, memo, mutate_kwargs):
|
||||
if isinstance(module, LayerChoice):
|
||||
size = len(module)
|
||||
if module.label in memo:
|
||||
if type(module) is LayerChoice: # must be exactly LayerChoice
|
||||
if module.label not in memo:
|
||||
raise KeyError(f'LayerChoice {module.label} not found in memo.')
|
||||
alpha = memo[module.label]
|
||||
if len(alpha) != size:
|
||||
raise ValueError(f'Architecture parameter size of same label {module.label} conflict: {len(alpha)} vs. {size}')
|
||||
else:
|
||||
alpha = nn.Parameter(torch.randn(size) * 1E-3) # the numbers in the parameter can be reinitialized later
|
||||
memo[module.label] = alpha
|
||||
|
||||
softmax = mutate_kwargs.get('softmax', nn.Softmax(-1))
|
||||
return cls(list(module.named_children()), alpha, softmax, module.label)
|
||||
return cls(module.candidates, alpha, softmax, module.label)
|
||||
|
||||
def reduction(self, items: list[Any], weights: list[float]) -> Any:
|
||||
def _reduction(self, items: list[Any], weights: list[float]) -> Any:
|
||||
"""Override this for customized reduction."""
|
||||
# Use weighted_sum to handle complex cases where sequential output is not a single tensor
|
||||
return weighted_sum(items, weights)
|
||||
|
||||
def forward(self, *args, **kwargs):
|
||||
"""The forward of mixed layer accepts same arguments as its sub-layer."""
|
||||
all_op_results = [getattr(self, op)(*args, **kwargs) for op in self.op_names]
|
||||
return self.reduction(all_op_results, self._softmax(self._arch_alpha))
|
||||
all_op_results = [self[op](*args, **kwargs) for op in self.names]
|
||||
return self._reduction(all_op_results, self._softmax(self._arch_alpha))
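As a standalone reference (plain PyTorch, independent of the NNI classes in this diff), the DARTS-style mixed layer boils down to a softmax-weighted sum over candidate ops, with an argmax over the logits as the exported choice:

import torch
import torch.nn as nn

class TinyMixedLayer(nn.Module):
    def __init__(self, candidates: dict):
        super().__init__()
        self.ops = nn.ModuleDict(candidates)
        self.alpha = nn.Parameter(torch.randn(len(candidates)) * 1e-3)  # architecture logits

    def forward(self, x):
        weights = torch.softmax(self.alpha, dim=-1)
        return sum(w * op(x) for w, op in zip(weights, self.ops.values()))

    def export(self):
        return list(self.ops)[int(self.alpha.argmax())]

layer = TinyMixedLayer({'conv': nn.Conv1d(4, 4, 3, padding=1), 'skip': nn.Identity()})
out = layer(torch.randn(2, 4, 8))   # weighted sum of both candidates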
|
||||
|
||||
def parameters(self, *args, **kwargs):
|
||||
"""Parameters excluding architecture parameters."""
|
||||
for _, p in self.named_parameters(*args, **kwargs):
|
||||
def arch_parameters(self):
|
||||
"""Iterate over architecture parameters. Not recursive."""
|
||||
for name, p in self.named_parameters():
|
||||
if any(name == par_name for par_name in self._arch_parameter_names):
|
||||
yield p
|
||||
|
||||
def named_parameters(self, *args, **kwargs):
|
||||
"""Named parameters excluding architecture parameters."""
|
||||
arch = kwargs.pop('arch', False)
|
||||
for name, p in super().named_parameters(*args, **kwargs):
|
||||
if any(name == par_name for par_name in self._arch_parameter_names):
|
||||
if arch:
|
||||
yield name, p
|
||||
else:
|
||||
if not arch:
|
||||
yield name, p
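The point of filtering on ``_arch_alpha`` is to let a DARTS-style trainer hand architecture parameters and ordinary weights to two different optimizers. A hedged sketch of that split, assuming only the ``_arch_alpha`` naming convention used in this file:

import torch

def split_parameters(model):
    arch, weight = [], []
    for name, p in model.named_parameters():
        # parameters stored under an `_arch_alpha` attribute are architecture parameters
        (arch if '_arch_alpha' in name else weight).append(p)
    return arch, weight

# arch_params, weight_params = split_parameters(supernet)   # `supernet` is hypothetical
# optim_w = torch.optim.SGD(weight_params, lr=0.025, momentum=0.9)
# optim_a = torch.optim.Adam(arch_params, lr=3e-4, weight_decay=1e-3)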
|
||||
|
||||
|
||||
class DifferentiableMixedInput(BaseSuperNetModule):
|
||||
class DifferentiableMixedInput(InputChoice, BaseSuperNetModule):
|
||||
"""
|
||||
Mixed input. Forward returns a weighted sum of candidates.
|
||||
Implementation is very similar to :class:`DifferentiableMixedLayer`.
|
||||
|
@ -174,11 +139,6 @@ class DifferentiableMixedInput(BaseSuperNetModule):
|
|||
Customizable softmax function. Usually ``nn.Softmax(-1)``.
|
||||
label : str
|
||||
Name of the choice.
|
||||
|
||||
Attributes
|
||||
----------
|
||||
label : str
|
||||
Name of the choice.
|
||||
"""
|
||||
|
||||
_arch_parameter_names: list[str] = ['_arch_alpha']
|
||||
|
@ -189,18 +149,14 @@ class DifferentiableMixedInput(BaseSuperNetModule):
|
|||
alpha: torch.Tensor,
|
||||
softmax: nn.Module,
|
||||
label: str):
|
||||
super().__init__()
|
||||
self.n_candidates = n_candidates
|
||||
if len(alpha) != n_candidates:
|
||||
raise ValueError(f'The size of alpha ({len(alpha)}) must match number of candidates ({n_candidates}).')
|
||||
if n_chosen is None:
|
||||
warnings.warn('Differentiable architecture search does not support choosing multiple inputs. Assuming one.',
|
||||
RuntimeWarning)
|
||||
self.n_chosen = 1
|
||||
self.n_chosen = n_chosen
|
||||
self.label = label
|
||||
n_chosen = 1
|
||||
super().__init__(n_candidates, n_chosen=n_chosen, label=label)
|
||||
if len(alpha) != n_candidates:
|
||||
raise ValueError(f'The size of alpha ({len(alpha)}) must match number of candidates ({n_candidates}).')
|
||||
self._softmax = softmax
|
||||
|
||||
self._arch_alpha = alpha
|
||||
|
||||
def resample(self, memo):
|
||||
|
@ -212,68 +168,44 @@ class DifferentiableMixedInput(BaseSuperNetModule):
|
|||
if self.label in memo:
|
||||
return {} # nothing new to export
|
||||
chosen = sorted(torch.argsort(-self._arch_alpha).cpu().numpy().tolist()[:self.n_chosen])
|
||||
if len(chosen) == 1:
|
||||
chosen = chosen[0]
|
||||
return {self.label: chosen}
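The export above is simply a top-``n_chosen`` selection over the logits, for example:

import torch

alpha = torch.tensor([0.1, 2.0, -1.0, 0.7])
n_chosen = 2
chosen = sorted(torch.argsort(-alpha).tolist()[:n_chosen])   # -> [1, 3]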
|
||||
|
||||
def export_probs(self, memo):
|
||||
if any(k.startswith(self.label + '/') for k in memo):
|
||||
return {} # nothing new
|
||||
if self.label in memo:
|
||||
return {}
|
||||
weights = self._softmax(self._arch_alpha).cpu().tolist()
|
||||
ret = {f'{self.label}/{index}': value for index, value in enumerate(weights)}
|
||||
return ret
|
||||
|
||||
def search_space_spec(self):
|
||||
return {
|
||||
self.label: ParameterSpec(self.label, 'choice', list(range(self.n_candidates)),
|
||||
(self.label, ), True, size=self.n_candidates, chosen_size=self.n_chosen)
|
||||
}
|
||||
return {self.label: dict(enumerate(weights))}
|
||||
|
||||
@classmethod
|
||||
def mutate(cls, module, name, memo, mutate_kwargs):
|
||||
if isinstance(module, InputChoice):
|
||||
if type(module) == InputChoice: # must be exactly InputChoice
|
||||
module = cast(InputChoice, module)
|
||||
if module.reduction not in ['sum', 'mean']:
|
||||
raise ValueError('Only input choice of sum/mean reduction is supported.')
|
||||
size = module.n_candidates
|
||||
if module.label in memo:
|
||||
if module.label not in memo:
|
||||
raise KeyError(f'InputChoice {module.label} not found in memo.')
|
||||
alpha = memo[module.label]
|
||||
if len(alpha) != size:
|
||||
raise ValueError(f'Architecture parameter size of same label {module.label} conflict: {len(alpha)} vs. {size}')
|
||||
else:
|
||||
alpha = nn.Parameter(torch.randn(size) * 1E-3) # the numbers in the parameter can be reinitialized later
|
||||
memo[module.label] = alpha
|
||||
|
||||
softmax = mutate_kwargs.get('softmax', nn.Softmax(-1))
|
||||
return cls(module.n_candidates, module.n_chosen, alpha, softmax, module.label)
|
||||
|
||||
def reduction(self, items: list[Any], weights: list[float]) -> Any:
|
||||
def _reduction(self, items: list[Any], weights: list[float]) -> Any:
|
||||
"""Override this for customized reduction."""
|
||||
# Use weighted_sum to handle complex cases where sequential output is not a single tensor
|
||||
return weighted_sum(items, weights)
|
||||
|
||||
def forward(self, inputs):
|
||||
"""Forward takes a list of input candidates."""
|
||||
return self.reduction(inputs, self._softmax(self._arch_alpha))
|
||||
return self._reduction(inputs, self._softmax(self._arch_alpha))
|
||||
|
||||
def parameters(self, *args, **kwargs):
|
||||
"""Parameters excluding architecture parameters."""
|
||||
for _, p in self.named_parameters(*args, **kwargs):
|
||||
yield p
|
||||
|
||||
def named_parameters(self, *args, **kwargs):
|
||||
"""Named parameters excluding architecture parameters."""
|
||||
arch = kwargs.pop('arch', False)
|
||||
for name, p in super().named_parameters(*args, **kwargs):
|
||||
def arch_parameters(self):
|
||||
"""Iterate over architecture parameters. Not recursive."""
|
||||
for name, p in self.named_parameters():
|
||||
if any(name == par_name for par_name in self._arch_parameter_names):
|
||||
if arch:
|
||||
yield name, p
|
||||
else:
|
||||
if not arch:
|
||||
yield name, p
|
||||
yield p
|
||||
|
||||
|
||||
class MixedOpDifferentiablePolicy(MixedOperationSamplingPolicy):
|
||||
"""Implementes the differentiable sampling in mixed operation.
|
||||
"""Implements the differentiable sampling in mixed operation.
|
||||
|
||||
One mixed operation can have multiple value choices in its arguments.
|
||||
Thus the ``_arch_alpha`` here is a parameter dict, and ``named_parameters``
|
||||
|
@ -293,36 +225,21 @@ class MixedOpDifferentiablePolicy(MixedOperationSamplingPolicy):
|
|||
def __init__(self, operation: MixedOperation, memo: dict[str, Any], mutate_kwargs: dict[str, Any]) -> None:
|
||||
# Sampling arguments. This should have the same keys as `operation.mutable_arguments`
|
||||
operation._arch_alpha = nn.ParameterDict()
|
||||
for name, spec in operation.search_space_spec().items():
|
||||
if name in memo:
|
||||
alpha = memo[name]
|
||||
if len(alpha) != spec.size:
|
||||
raise ValueError(f'Architecture parameter size of same label {name} conflict: {len(alpha)} vs. {spec.size}')
|
||||
else:
|
||||
alpha = nn.Parameter(torch.randn(spec.size) * 1E-3)
|
||||
memo[name] = alpha
|
||||
operation._arch_alpha[name] = alpha
|
||||
for name in operation.simplify():
|
||||
if name not in memo:
|
||||
raise KeyError(f'Argument {name} not found in memo.')
|
||||
operation._arch_alpha[str(name)] = memo[name]
|
||||
|
||||
operation.parameters = functools.partial(self.parameters, module=operation) # bind self
|
||||
operation.named_parameters = functools.partial(self.named_parameters, module=operation)
|
||||
operation.arch_parameters = functools.partial(self.arch_parameters, module=operation)
|
||||
|
||||
operation._softmax = mutate_kwargs.get('softmax', nn.Softmax(-1))
|
||||
|
||||
@staticmethod
|
||||
def parameters(module, *args, **kwargs):
|
||||
for _, p in module.named_parameters(*args, **kwargs):
|
||||
yield p
|
||||
|
||||
@staticmethod
|
||||
def named_parameters(module, *args, **kwargs):
|
||||
arch = kwargs.pop('arch', False)
|
||||
for name, p in super(module.__class__, module).named_parameters(*args, **kwargs): # pylint: disable=bad-super-call
|
||||
def arch_parameters(module):
|
||||
"""Iterate over architecture parameters. Not recursive."""
|
||||
for name, p in module.named_parameters():
|
||||
if any(name.startswith(par_name) for par_name in MixedOpDifferentiablePolicy._arch_parameter_names):
|
||||
if arch:
|
||||
yield name, p
|
||||
else:
|
||||
if not arch:
|
||||
yield name, p
|
||||
yield p
|
||||
|
||||
def resample(self, operation: MixedOperation, memo: dict[str, Any]) -> dict[str, Any]:
|
||||
"""Differentiable. Do nothing in resample."""
|
||||
|
@ -331,21 +248,21 @@ class MixedOpDifferentiablePolicy(MixedOperationSamplingPolicy):
|
|||
def export(self, operation: MixedOperation, memo: dict[str, Any]) -> dict[str, Any]:
|
||||
"""Export is argmax for each leaf value choice."""
|
||||
result = {}
|
||||
for name, spec in operation.search_space_spec().items():
|
||||
for name, spec in operation.simplify().items():
|
||||
if name in memo:
|
||||
continue
|
||||
chosen_index = int(torch.argmax(cast(dict, operation._arch_alpha)[name]).item())
|
||||
result[name] = spec.values[chosen_index]
|
||||
result[name] = cast(Categorical, spec).values[chosen_index]
|
||||
return result
|
||||
|
||||
def export_probs(self, operation: MixedOperation, memo: dict[str, Any]):
|
||||
"""Export the weight for every leaf value choice."""
|
||||
ret = {}
|
||||
for name, spec in operation.search_space_spec().items():
|
||||
if any(k.startswith(name + '/') for k in memo):
|
||||
for name, spec in operation.simplify().items():
|
||||
if name in memo:
|
||||
continue
|
||||
weights = operation._softmax(operation._arch_alpha[name]).cpu().tolist() # type: ignore
|
||||
ret.update({f'{name}/{value}': weight for value, weight in zip(spec.values, weights)})
|
||||
ret.update({name: dict(zip(cast(Categorical, spec).values, weights))})
|
||||
return ret
|
||||
|
||||
def forward_argument(self, operation: MixedOperation, name: str) -> dict[Any, float] | Any:
|
||||
|
@ -357,9 +274,9 @@ class MixedOpDifferentiablePolicy(MixedOperationSamplingPolicy):
|
|||
return operation.init_arguments[name]
|
||||
|
||||
|
||||
class DifferentiableMixedRepeat(BaseSuperNetModule):
|
||||
class DifferentiableMixedRepeat(Repeat, BaseSuperNetModule):
|
||||
"""
|
||||
Implementaion of Repeat in a differentiable supernet.
|
||||
Implementation of Repeat in a differentiable supernet.
|
||||
Result is a weighted sum of possible prefixes, sliced by possible depths.
|
||||
|
||||
If the output is not a single tensor, it will be summed at every independent dimension.
|
||||
|
@ -368,27 +285,17 @@ class DifferentiableMixedRepeat(BaseSuperNetModule):
|
|||
|
||||
_arch_parameter_names: list[str] = ['_arch_alpha']
|
||||
|
||||
depth_choice: MutableExpression[int]
|
||||
|
||||
def __init__(self,
|
||||
blocks: list[nn.Module],
|
||||
depth: ChoiceOf[int],
|
||||
depth: MutableExpression[int],
|
||||
softmax: nn.Module,
|
||||
memo: dict[str, Any]):
|
||||
super().__init__()
|
||||
self.blocks = blocks
|
||||
self.depth = depth
|
||||
alphas: dict[str, Any]):
|
||||
assert isinstance(depth, Mutable)
|
||||
super().__init__(blocks, depth)
|
||||
self._softmax = softmax
|
||||
self._space_spec: dict[str, ParameterSpec] = dedup_inner_choices([depth])
|
||||
self._arch_alpha = nn.ParameterDict()
|
||||
|
||||
for name, spec in self._space_spec.items():
|
||||
if name in memo:
|
||||
alpha = memo[name]
|
||||
if len(alpha) != spec.size:
|
||||
raise ValueError(f'Architecture parameter size of same label {name} conflict: {len(alpha)} vs. {spec.size}')
|
||||
else:
|
||||
alpha = nn.Parameter(torch.randn(spec.size) * 1E-3)
|
||||
memo[name] = alpha
|
||||
self._arch_alpha[name] = alpha
|
||||
self._arch_alpha = nn.ParameterDict(alphas)
|
||||
|
||||
def resample(self, memo):
|
||||
"""Do nothing."""
|
||||
|
@ -397,48 +304,43 @@ class DifferentiableMixedRepeat(BaseSuperNetModule):
|
|||
def export(self, memo):
|
||||
"""Choose argmax for each leaf value choice."""
|
||||
result = {}
|
||||
for name, spec in self._space_spec.items():
|
||||
for name, spec in self.depth_choice.simplify().items():
|
||||
if name in memo:
|
||||
continue
|
||||
chosen_index = int(torch.argmax(self._arch_alpha[name]).item())
|
||||
result[name] = spec.values[chosen_index]
|
||||
result[name] = cast(Categorical, spec).values[chosen_index]
|
||||
return result
|
||||
|
||||
def export_probs(self, memo):
|
||||
"""Export the weight for every leaf value choice."""
|
||||
ret = {}
|
||||
for name, spec in self.search_space_spec().items():
|
||||
if any(k.startswith(name + '/') for k in memo):
|
||||
for name, spec in self.depth_choice.simplify().items():
|
||||
if name in memo:
|
||||
continue
|
||||
weights = self._softmax(self._arch_alpha[name]).cpu().tolist()
|
||||
ret.update({f'{name}/{value}': weight for value, weight in zip(spec.values, weights)})
|
||||
ret.update({name: dict(zip(cast(Categorical, spec).values, weights))})
|
||||
return ret
|
||||
|
||||
def search_space_spec(self):
|
||||
return self._space_spec
|
||||
|
||||
@classmethod
|
||||
def mutate(cls, module, name, memo, mutate_kwargs):
|
||||
if isinstance(module, Repeat) and isinstance(module.depth_choice, ValueChoiceX):
|
||||
if type(module) == Repeat and isinstance(module.depth_choice, Mutable): # Repeat and depth is mutable
|
||||
# Only interesting when depth is mutable
|
||||
module = cast(Repeat, module)
|
||||
alphas = {}
|
||||
for name in cast(Mutable, module.depth_choice).simplify():
|
||||
if name not in memo:
|
||||
raise KeyError(f'Mutable depth "{name}" not found in memo')
|
||||
alphas[name] = memo[name]
|
||||
softmax = mutate_kwargs.get('softmax', nn.Softmax(-1))
|
||||
return cls(cast(List[nn.Module], module.blocks), module.depth_choice, softmax, memo)
|
||||
return cls(list(module.blocks), cast(MutableExpression[int], module.depth_choice), softmax, alphas)
|
||||
|
||||
def parameters(self, *args, **kwargs):
|
||||
for _, p in self.named_parameters(*args, **kwargs):
|
||||
def arch_parameters(self):
|
||||
"""Iterate over architecture parameters. Not recursive."""
|
||||
for name, p in self.named_parameters():
|
||||
if any(name.startswith(par_name) for par_name in self._arch_parameter_names):
|
||||
yield p
|
||||
|
||||
def named_parameters(self, *args, **kwargs):
|
||||
arch = kwargs.pop('arch', False)
|
||||
for name, p in super().named_parameters(*args, **kwargs):
|
||||
if any(name.startswith(par_name) for par_name in MixedOpDifferentiablePolicy._arch_parameter_names):
|
||||
if arch:
|
||||
yield name, p
|
||||
else:
|
||||
if not arch:
|
||||
yield name, p
|
||||
|
||||
def reduction(self, items: list[Any], weights: list[float], depths: list[int]) -> Any:
|
||||
def _reduction(self, items: list[Any], weights: list[float], depths: list[int]) -> Any:
|
||||
"""Override this for customized reduction."""
|
||||
# Use weighted_sum to handle complex cases where sequential output is not a single tensor
|
||||
return weighted_sum(items, weights)
|
||||
|
@ -447,7 +349,7 @@ class DifferentiableMixedRepeat(BaseSuperNetModule):
|
|||
weights: dict[str, torch.Tensor] = {
|
||||
label: self._softmax(alpha) for label, alpha in self._arch_alpha.items()
|
||||
}
|
||||
depth_weights = dict(cast(List[Tuple[int, float]], traverse_all_options(self.depth, weights=weights)))
|
||||
depth_weights = dict(cast(List[Tuple[int, float]], traverse_all_options(self.depth_choice, weights=weights)))
|
||||
|
||||
res: list[torch.Tensor] = []
|
||||
weight_list: list[float] = []
|
||||
|
@ -459,7 +361,7 @@ class DifferentiableMixedRepeat(BaseSuperNetModule):
|
|||
res.append(x)
|
||||
depths.append(i)
|
||||
|
||||
return self.reduction(res, weight_list, depths)
|
||||
return self._reduction(res, weight_list, depths)
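A self-contained sketch of the weighted-sum-over-depths idea used in the forward above (plain PyTorch; the depth probabilities stand in for what ``traverse_all_options`` computes from the alphas):

import torch
import torch.nn as nn

blocks = nn.ModuleList([nn.Linear(8, 8) for _ in range(3)])
depth_probs = {1: 0.2, 2: 0.3, 3: 0.5}   # P(depth = d), normally a softmax over depth logits

x = torch.randn(4, 8)
out, state = 0.0, x
for depth, block in enumerate(blocks, start=1):
    state = block(state)                 # output of the prefix of length `depth`
    out = out + depth_probs[depth] * state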
|
||||
|
||||
|
||||
class DifferentiableMixedCell(PathSamplingCell):
|
||||
|
@ -473,6 +375,8 @@ class DifferentiableMixedCell(PathSamplingCell):
|
|||
# TODO: It inherits :class:`PathSamplingCell` to reduce some duplicated code.
|
||||
# Possibly need another refactor here.
|
||||
|
||||
_arch_parameter_names: list[str] = ['_arch_alpha']
|
||||
|
||||
def __init__(
|
||||
self, op_factory, num_nodes, num_ops_per_node,
|
||||
num_predecessors, preprocessor, postprocessor, concat_dim,
|
||||
|
@ -486,10 +390,13 @@ class DifferentiableMixedCell(PathSamplingCell):
|
|||
self._arch_alpha = nn.ParameterDict()
|
||||
for i in range(self.num_predecessors, self.num_nodes + self.num_predecessors):
|
||||
for j in range(i):
|
||||
edge_label = f'{label}/{i}_{j}'
|
||||
edge_label = f'{self.label}/{i}_{j}'
|
||||
# Some parameters still need to be created here inside.
|
||||
# We should avoid conflict with "outside parameters".
|
||||
memo_label = edge_label + '/in_cell'
|
||||
op = cast(List[Dict[str, nn.Module]], self.ops[i - self.num_predecessors])[j]
|
||||
if edge_label in memo:
|
||||
alpha = memo[edge_label]
|
||||
if memo_label in memo:
|
||||
alpha = memo[memo_label]
|
||||
if len(alpha) != len(op) + 1:
|
||||
if len(alpha) != len(op):
|
||||
raise ValueError(
|
||||
|
@ -504,7 +411,7 @@ class DifferentiableMixedCell(PathSamplingCell):
|
|||
else:
|
||||
# +1 to emulate the input choice.
|
||||
alpha = nn.Parameter(torch.randn(len(op) + 1) * 1E-3)
|
||||
memo[edge_label] = alpha
|
||||
memo[memo_label] = alpha
|
||||
self._arch_alpha[edge_label] = alpha
|
||||
|
||||
self._softmax = mutate_kwargs.get('softmax', nn.Softmax(-1))
|
||||
|
@ -517,10 +424,10 @@ class DifferentiableMixedCell(PathSamplingCell):
|
|||
"""When export probability, we follow the structure in arch alpha."""
|
||||
ret = {}
|
||||
for name, parameter in self._arch_alpha.items():
|
||||
if any(k.startswith(name + '/') for k in memo):
|
||||
if name in memo:
|
||||
continue
|
||||
weights = self._softmax(parameter).cpu().tolist()
|
||||
ret.update({f'{name}/{value}': weight for value, weight in zip(self.op_names, weights)})
|
||||
ret.update({name: dict(zip(self.op_names, weights))})
|
||||
return ret
|
||||
|
||||
def export(self, memo):
|
||||
|
@ -565,7 +472,7 @@ class DifferentiableMixedCell(PathSamplingCell):
|
|||
# all_weights could be too short in case ``num_ops_per_node`` is too large.
|
||||
_, j, op_name = all_weights[k % len(all_weights)]
|
||||
exported[f'{self.label}/op_{i}_{k}'] = op_name
|
||||
exported[f'{self.label}/input_{i}_{k}'] = j
|
||||
exported[f'{self.label}/input_{i}_{k}'] = [j]
|
||||
|
||||
return exported
|
||||
|
||||
|
@ -579,10 +486,10 @@ class DifferentiableMixedCell(PathSamplingCell):
|
|||
op_results = torch.stack([op(states[j]) for op in ops[j].values()])
|
||||
alpha_shape = [-1] + [1] * (len(op_results.size()) - 1) # (-1, 1, 1, 1, 1, ...)
|
||||
op_weights = self._softmax(self._arch_alpha[f'{self.label}/{i}_{j}'])
|
||||
if len(op_weights) == len(op_results) + 1:
|
||||
if op_weights.size(0) == op_results.size(0) + 1:
|
||||
# concatenate with a zero operation, indicating this path is not chosen at all.
|
||||
op_results = torch.cat((op_results, torch.zeros_like(op_results[:1])), 0)
|
||||
edge_sum = torch.sum(op_results * self._softmax(self._arch_alpha[f'{self.label}/{i}_{j}']).view(*alpha_shape), 0)
|
||||
edge_sum = torch.sum(op_results * op_weights.view(*alpha_shape), 0)
|
||||
current_state.append(edge_sum)
|
||||
|
||||
states.append(sum(current_state)) # type: ignore
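The ``+ 1`` on the alpha size earlier adds a "zero" pseudo-op, so an edge can learn to switch itself off. That weighting step in isolation (plain PyTorch, shapes made up):

import torch

op_results = torch.randn(3, 2, 8)                  # 3 candidate ops, batch 2, feature dim 8
alpha = torch.randn(4)                             # 3 ops + 1 "zero" pseudo-op
op_results = torch.cat((op_results, torch.zeros_like(op_results[:1])), 0)
weights = torch.softmax(alpha, dim=-1).view(-1, 1, 1)
edge_output = torch.sum(op_results * weights, 0)   # shape (2, 8)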
|
||||
|
@ -591,16 +498,8 @@ class DifferentiableMixedCell(PathSamplingCell):
|
|||
this_cell = torch.cat(states[self.num_predecessors:], self.concat_dim)
|
||||
return self.postprocessor(this_cell, processed_inputs)
|
||||
|
||||
def parameters(self, *args, **kwargs):
|
||||
for _, p in self.named_parameters(*args, **kwargs):
|
||||
def arch_parameters(self):
|
||||
"""Iterate over architecture parameters. Not recursive."""
|
||||
for name, p in self.named_parameters():
|
||||
if any(name.startswith(par_name) for par_name in self._arch_parameter_names):
|
||||
yield p
|
||||
|
||||
def named_parameters(self, *args, **kwargs):
|
||||
arch = kwargs.pop('arch', False)
|
||||
for name, p in super().named_parameters(*args, **kwargs):
|
||||
if any(name.startswith(par_name) for par_name in MixedOpDifferentiablePolicy._arch_parameter_names):
|
||||
if arch:
|
||||
yield name, p
|
||||
else:
|
||||
if not arch:
|
||||
yield name, p
|
||||
|
|
|
@ -18,13 +18,15 @@ import torch.nn as nn
|
|||
import torch.nn.functional as F
|
||||
from torch import Tensor
|
||||
|
||||
import nni.nas.nn.pytorch as nas_nn
|
||||
from nni.common.hpo_utils import ParameterSpec
|
||||
from nni.common.serializer import is_traceable
|
||||
from nni.nas.nn.pytorch.choice import ValueChoiceX
|
||||
from nni.mutable import MutableExpression
|
||||
from nni.nas.nn.pytorch import (
|
||||
ParametrizedModule,
|
||||
MutableConv2d, MutableLinear, MutableBatchNorm2d, MutableLayerNorm, MutableMultiheadAttention
|
||||
)
|
||||
|
||||
from .base import BaseSuperNetModule, sub_state_dict
|
||||
from ._valuechoice_utils import traverse_all_options, dedup_inner_choices, evaluate_constant
|
||||
from .base import BaseSuperNetModule
|
||||
from ._expression_utils import traverse_all_options, evaluate_constant
|
||||
from ._operation_utils import Slicable as _S, MaybeWeighted as _W, int_or_int_dict, scalar_or_scalar_dict
|
||||
|
||||
T = TypeVar('T')
|
||||
|
@ -40,7 +42,7 @@ __all__ = [
|
|||
'NATIVE_MIXED_OPERATIONS',
|
||||
]
|
||||
|
||||
_diff_not_compatible_error = 'To be compatible with differentiable one-shot strategy, {} in {} must not be ValueChoice.'
|
||||
_diff_not_compatible_error = 'To be compatible with differentiable one-shot strategy, {} in {} must not be mutable.'
|
||||
|
||||
|
||||
class MixedOperationSamplingPolicy:
|
||||
|
@ -84,10 +86,10 @@ class MixedOperationSamplingPolicy:
|
|||
|
||||
class MixedOperation(BaseSuperNetModule):
|
||||
"""This is the base class for all mixed operations.
|
||||
It's what you should inherit to support a new operation with ValueChoice.
|
||||
It's what you should inherit to support a new operation with mutable arguments.
|
||||
|
||||
It contains commonly used utilities that will ease the effort to write customized mixed oeprations,
|
||||
i.e., operations with ValueChoice in its arguments.
|
||||
It contains commonly used utilities that will ease the effort to write customized mixed operations,
|
||||
i.e., operations with mutables in their arguments.
|
||||
To customize, please write your own mixed operation, and add the hook into ``mutation_hooks`` parameter when using the strategy.
|
||||
|
||||
By design, for a mixed operation to work in a specific algorithm,
|
||||
|
@ -116,14 +118,14 @@ class MixedOperation(BaseSuperNetModule):
|
|||
|
||||
sampling_policy: MixedOperationSamplingPolicy
|
||||
|
||||
def super_init_argument(self, name: str, value_choice: ValueChoiceX) -> Any:
|
||||
def super_init_argument(self, name: str, value_choice: MutableExpression) -> Any:
|
||||
"""Get the initialization argument when constructing super-kernel, i.e., calling ``super().__init__()``.
|
||||
This is often related to specific operator, rather than algo.
|
||||
|
||||
For example::
|
||||
|
||||
def super_init_argument(self, name, value_choice):
|
||||
return max(value_choice.candidates)
|
||||
return max(value_choice.grid())
|
||||
"""
|
||||
raise NotImplementedError()
|
||||
|
||||
|
@ -138,8 +140,8 @@ class MixedOperation(BaseSuperNetModule):
|
|||
|
||||
def __init__(self, module_kwargs: dict[str, Any]) -> None:
|
||||
# Concerned arguments
|
||||
self.mutable_arguments: dict[str, ValueChoiceX] = {}
|
||||
# Useful when retrieving arguments without ValueChoice
|
||||
self.mutable_arguments: dict[str, MutableExpression] = {}
|
||||
# Useful when retrieving arguments without mutable
|
||||
self.init_arguments: dict[str, Any] = {**module_kwargs}
|
||||
self._fill_missing_init_arguments()
|
||||
|
||||
|
@ -147,21 +149,37 @@ class MixedOperation(BaseSuperNetModule):
|
|||
super_init_kwargs = {}
|
||||
|
||||
for key, value in module_kwargs.items():
|
||||
if isinstance(value, ValueChoiceX):
|
||||
if isinstance(value, MutableExpression):
|
||||
if key not in self.argument_list:
|
||||
raise TypeError(f'Unsupported value choice on argument of {self.bound_type}: {key}')
|
||||
raise TypeError(f'Unsupported mutable argument of "{self.bound_type}": {key}')
|
||||
super_init_kwargs[key] = self.super_init_argument(key, value)
|
||||
self.mutable_arguments[key] = value
|
||||
else:
|
||||
super_init_kwargs[key] = value
|
||||
|
||||
# get all inner leaf value choices
|
||||
self._space_spec: dict[str, ParameterSpec] = dedup_inner_choices(list(self.mutable_arguments.values()))
|
||||
|
||||
super().__init__(**super_init_kwargs)
|
||||
|
||||
for mutable in self.mutable_arguments.values():
|
||||
self.add_mutable(mutable)
|
||||
|
||||
self.__post_init__()
|
||||
|
||||
def freeze(self, sample) -> Any:
|
||||
"""Freeze the mixed operation to a specific operation.
|
||||
Weights will be copied from the mixed operation to the frozen operation.
|
||||
|
||||
The returned operation will be of the ``bound_type``.
|
||||
"""
|
||||
arguments = {**self.init_arguments}
|
||||
for name, mutable in self.mutable_arguments.items():
|
||||
arguments[name] = mutable.freeze(sample)
|
||||
operation = self.bound_type(**arguments)
|
||||
|
||||
# copy weights
|
||||
state_dict = self.freeze_weight(**arguments)
|
||||
operation.load_state_dict(state_dict)
|
||||
return operation
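Conceptually, freezing resolves every mutable argument against the sample and then loads correspondingly sliced weights into a plain module. A minimal stand-alone analogue for a linear super-kernel (the names are illustrative, not the NNI API):

import torch
import torch.nn as nn

def freeze_linear(super_linear: nn.Linear, out_features: int) -> nn.Linear:
    # Instantiate the frozen op, then copy the matching slice of the super-kernel.
    frozen = nn.Linear(super_linear.in_features, out_features)
    frozen.load_state_dict({
        'weight': super_linear.weight.detach()[:out_features],
        'bias': super_linear.bias.detach()[:out_features],
    })
    return frozen

small = freeze_linear(nn.Linear(16, 64), out_features=32)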
|
||||
|
||||
def resample(self, memo):
|
||||
"""Delegates to :meth:`MixedOperationSamplingPolicy.resample`."""
|
||||
return self.sampling_policy.resample(self, memo)
|
||||
|
@ -174,21 +192,12 @@ class MixedOperation(BaseSuperNetModule):
|
|||
"""Delegates to :meth:`MixedOperationSamplingPolicy.export`."""
|
||||
return self.sampling_policy.export(self, memo)
|
||||
|
||||
def search_space_spec(self):
|
||||
return self._space_spec
|
||||
|
||||
@classmethod
|
||||
def mutate(cls, module, name, memo, mutate_kwargs):
|
||||
"""Find value choice in module's arguments and replace the whole module"""
|
||||
has_valuechoice = False
|
||||
if isinstance(module, cls.bound_type) and is_traceable(module):
|
||||
for arg in itertools.chain(cast(list, module.trace_args), cast(dict, module.trace_kwargs).values()):
|
||||
if isinstance(arg, ValueChoiceX):
|
||||
has_valuechoice = True
|
||||
|
||||
if has_valuechoice:
|
||||
if isinstance(module, cls.bound_type) and isinstance(module, ParametrizedModule):
|
||||
if module.trace_args:
|
||||
raise ValueError('ValueChoice on class arguments cannot appear together with ``trace_args``. '
|
||||
raise ValueError('Mutable on class arguments cannot appear together with ``trace_args``. '
|
||||
'Please enable ``kw_only`` on nni.trace.')
|
||||
|
||||
# save type and kwargs
|
||||
|
@ -232,21 +241,12 @@ class MixedOperation(BaseSuperNetModule):
|
|||
if param.default is not param.empty and param.name not in self.init_arguments:
|
||||
self.init_arguments[param.name] = param.default
|
||||
|
||||
def slice_param(self, **kwargs):
|
||||
def freeze_weight(self, **kwargs):
|
||||
"""Slice the params and buffers for subnet forward and state dict.
|
||||
When there is a `mapping=True` in kwargs, the return result will be wrapped in dict.
|
||||
"""
|
||||
raise NotImplementedError()
|
||||
|
||||
def _save_param_buff_to_state_dict(self, destination, prefix, keep_vars):
|
||||
kwargs = {name: self.forward_argument(name) for name in self.argument_list}
|
||||
params_mapping: dict[str, Any] = self.slice_param(**kwargs)
|
||||
for name, value in itertools.chain(self._parameters.items(), self._buffers.items()): # direct children
|
||||
if value is None or name in self._non_persistent_buffers_set:
|
||||
# it won't appear in state dict
|
||||
continue
|
||||
value = params_mapping.get(name, value)
|
||||
destination[prefix + name] = value if keep_vars else value.detach()
|
||||
The arguments are the same as the arguments passed to ``__init__``.
|
||||
"""
|
||||
raise NotImplementedError('freeze_weight is not implemented.')
|
||||
|
||||
|
||||
class MixedLinear(MixedOperation, nn.Linear):
|
||||
|
@ -260,13 +260,13 @@ class MixedLinear(MixedOperation, nn.Linear):
|
|||
Prefix of weight and bias will be sliced.
|
||||
"""
|
||||
|
||||
bound_type = nas_nn.Linear
|
||||
bound_type = MutableLinear
|
||||
argument_list = ['in_features', 'out_features']
|
||||
|
||||
def super_init_argument(self, name: str, value_choice: ValueChoiceX):
|
||||
return max(traverse_all_options(value_choice))
|
||||
def super_init_argument(self, name: str, value_choice: MutableExpression):
|
||||
return max(value_choice.grid())
|
||||
|
||||
def slice_param(self, in_features: int_or_int_dict, out_features: int_or_int_dict, **kwargs) -> Any:
|
||||
def freeze_weight(self, in_features: int_or_int_dict, out_features: int_or_int_dict, **kwargs) -> Any:
|
||||
in_features_ = _W(in_features)
|
||||
out_features_ = _W(out_features)
|
||||
|
||||
|
@ -281,7 +281,7 @@ class MixedLinear(MixedOperation, nn.Linear):
|
|||
out_features: int_or_int_dict,
|
||||
inputs: torch.Tensor) -> torch.Tensor:
|
||||
|
||||
params_mapping = self.slice_param(in_features, out_features)
|
||||
params_mapping = self.freeze_weight(in_features, out_features)
|
||||
weight, bias = [params_mapping.get(name) for name in ['weight', 'bias']]
|
||||
|
||||
return F.linear(inputs, weight, bias)
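The super-kernel trick in isolation: keep one weight of the maximum size, slice it to the currently sampled sizes, and call the functional op (plain PyTorch sketch, sizes made up):

import torch
import torch.nn as nn
import torch.nn.functional as F

max_in, max_out = 64, 128
weight = nn.Parameter(torch.randn(max_out, max_in))
bias = nn.Parameter(torch.zeros(max_out))

def forward_sampled(x, in_features, out_features):
    w = weight[:out_features, :in_features]
    b = bias[:out_features]
    return F.linear(x, w, b)

y = forward_sampled(torch.randn(8, 32), in_features=32, out_features=96)   # shape (8, 96)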
|
||||
|
@ -321,7 +321,7 @@ class MixedConv2d(MixedOperation, nn.Conv2d):
|
|||
□ □ □ □ □ □ □ □ □ □
|
||||
"""
|
||||
|
||||
bound_type = nas_nn.Conv2d
|
||||
bound_type = MutableConv2d
|
||||
argument_list = [
|
||||
'in_channels', 'out_channels', 'kernel_size', 'stride', 'padding', 'dilation', 'groups'
|
||||
]
|
||||
|
@ -332,12 +332,12 @@ class MixedConv2d(MixedOperation, nn.Conv2d):
|
|||
return (value, value)
|
||||
return value
|
||||
|
||||
def super_init_argument(self, name: str, value_choice: ValueChoiceX):
|
||||
def super_init_argument(self, name: str, mutable_expr: MutableExpression):
|
||||
if name not in ['in_channels', 'out_channels', 'groups', 'stride', 'kernel_size', 'padding', 'dilation']:
|
||||
raise NotImplementedError(f'Unsupported value choice on argument: {name}')
|
||||
|
||||
if name in ['kernel_size', 'padding']:
|
||||
all_sizes = set(traverse_all_options(value_choice))
|
||||
all_sizes = set(traverse_all_options(mutable_expr))
|
||||
if any(isinstance(sz, tuple) for sz in all_sizes):
|
||||
# maximum kernel should be calculated on every dimension
|
||||
return (
|
||||
|
@ -351,28 +351,37 @@ class MixedConv2d(MixedOperation, nn.Conv2d):
|
|||
if 'in_channels' in self.mutable_arguments:
|
||||
# If the ratio is constant, we don't need to try the maximum groups.
|
||||
try:
|
||||
constant = evaluate_constant(self.mutable_arguments['in_channels'] / value_choice)
|
||||
return max(cast(List[float], traverse_all_options(value_choice))) // int(constant)
|
||||
constant = evaluate_constant(self.mutable_arguments['in_channels'] / mutable_expr)
|
||||
return max(cast(List[float], traverse_all_options(mutable_expr))) // int(constant)
|
||||
except ValueError:
|
||||
warnings.warn(
|
||||
'Both input channels and groups are ValueChoice in a convolution, and their relative ratio is not a constant. '
|
||||
'Both input channels and groups are mutable in a convolution, and their relative ratio is not a constant. '
|
||||
'This can be problematic for most one-shot algorithms. Please check whether this is your intention.',
|
||||
RuntimeWarning
|
||||
)
|
||||
|
||||
# minimum groups, maximum kernel
|
||||
return min(traverse_all_options(value_choice))
|
||||
return min(traverse_all_options(mutable_expr))
|
||||
|
||||
else:
|
||||
return max(traverse_all_options(value_choice))
|
||||
return max(traverse_all_options(mutable_expr))
|
||||
|
||||
def slice_param(self,
|
||||
def freeze_weight(self,
|
||||
in_channels: int_or_int_dict,
|
||||
out_channels: int_or_int_dict,
|
||||
kernel_size: scalar_or_scalar_dict[_int_or_tuple],
|
||||
groups: int_or_int_dict,
|
||||
**kwargs
|
||||
) -> Any:
|
||||
**kwargs) -> Any:
|
||||
rv = self._freeze_weight_impl(in_channels, out_channels, kernel_size, groups)
|
||||
rv.pop('in_channels_per_group', None)
|
||||
return rv
|
||||
|
||||
def _freeze_weight_impl(self,
|
||||
in_channels: int_or_int_dict,
|
||||
out_channels: int_or_int_dict,
|
||||
kernel_size: scalar_or_scalar_dict[_int_or_tuple],
|
||||
groups: int_or_int_dict,
|
||||
**kwargs) -> Any:
|
||||
in_channels_ = _W(in_channels)
|
||||
out_channels_ = _W(out_channels)
|
||||
|
||||
|
@ -386,8 +395,8 @@ class MixedConv2d(MixedOperation, nn.Conv2d):
|
|||
in_channels_per_group = None
|
||||
else:
|
||||
assert 'groups' in self.mutable_arguments
|
||||
err_message = 'For differentiable one-shot strategy, when groups is a ValueChoice, ' \
|
||||
'in_channels and out_channels should also be a ValueChoice. ' \
|
||||
err_message = 'For differentiable one-shot strategy, when groups is a mutable, ' \
|
||||
'in_channels and out_channels should also be a mutable. ' \
|
||||
'Also, the ratios of in_channels divided by groups, and out_channels divided by groups ' \
|
||||
'should be constants.'
|
||||
if 'in_channels' not in self.mutable_arguments or 'out_channels' not in self.mutable_arguments:
|
||||
|
@ -412,7 +421,6 @@ class MixedConv2d(MixedOperation, nn.Conv2d):
|
|||
|
||||
return {'weight': weight, 'bias': bias, 'in_channels_per_group': in_channels_per_group}
|
||||
|
||||
|
||||
def forward_with_args(self,
|
||||
in_channels: int_or_int_dict,
|
||||
out_channels: int_or_int_dict,
|
||||
|
@ -426,7 +434,7 @@ class MixedConv2d(MixedOperation, nn.Conv2d):
|
|||
if any(isinstance(arg, dict) for arg in [stride, dilation]):
|
||||
raise ValueError(_diff_not_compatible_error.format('stride, dilation', 'Conv2d'))
|
||||
|
||||
params_mapping = self.slice_param(in_channels, out_channels, kernel_size, groups)
|
||||
params_mapping = self._freeze_weight_impl(in_channels, out_channels, kernel_size, groups)
|
||||
weight, bias, in_channels_per_group = [
|
||||
params_mapping.get(name)
|
||||
for name in ['weight', 'bias', 'in_channels_per_group']
|
||||
|
@ -478,13 +486,13 @@ class MixedBatchNorm2d(MixedOperation, nn.BatchNorm2d):
|
|||
PyTorch BatchNorm supports a case where momentum can be none, which is not supported here.
|
||||
"""
|
||||
|
||||
bound_type = nas_nn.BatchNorm2d
|
||||
bound_type = MutableBatchNorm2d
|
||||
argument_list = ['num_features', 'eps', 'momentum']
|
||||
|
||||
def super_init_argument(self, name: str, value_choice: ValueChoiceX):
|
||||
return max(traverse_all_options(value_choice))
|
||||
def super_init_argument(self, name: str, mutable_expr: MutableExpression):
|
||||
return max(traverse_all_options(mutable_expr))
|
||||
|
||||
def slice_param(self, num_features: int_or_int_dict, **kwargs) -> Any:
|
||||
def freeze_weight(self, num_features: int_or_int_dict, **kwargs) -> Any:
|
||||
if isinstance(num_features, dict):
|
||||
num_features = self.num_features
|
||||
weight, bias = self.weight, self.bias
|
||||
|
@ -508,7 +516,7 @@ class MixedBatchNorm2d(MixedOperation, nn.BatchNorm2d):
|
|||
if any(isinstance(arg, dict) for arg in [eps, momentum]):
|
||||
raise ValueError(_diff_not_compatible_error.format('eps and momentum', 'BatchNorm2d'))
|
||||
|
||||
params_mapping = self.slice_param(num_features)
|
||||
params_mapping = self.freeze_weight(num_features)
|
||||
weight, bias, running_mean, running_var = [
|
||||
params_mapping.get(name)
|
||||
for name in ['weight', 'bias', 'running_mean', 'running_var']
|
||||
|
@ -547,7 +555,7 @@ class MixedLayerNorm(MixedOperation, nn.LayerNorm):
|
|||
eps is required to be float.
|
||||
"""
|
||||
|
||||
bound_type = nas_nn.LayerNorm
|
||||
bound_type = MutableLayerNorm
|
||||
argument_list = ['normalized_shape', 'eps']
|
||||
|
||||
@staticmethod
|
||||
|
@ -556,10 +564,10 @@ class MixedLayerNorm(MixedOperation, nn.LayerNorm):
|
|||
return (value, value)
|
||||
return value
|
||||
|
||||
def super_init_argument(self, name: str, value_choice: ValueChoiceX):
|
||||
def super_init_argument(self, name: str, mutable_expr: MutableExpression):
|
||||
if name not in ['normalized_shape']:
|
||||
raise NotImplementedError(f'Unsupported value choice on argument: {name}')
|
||||
all_sizes = set(traverse_all_options(value_choice))
|
||||
all_sizes = set(traverse_all_options(mutable_expr))
|
||||
if any(isinstance(sz, (tuple, list)) for sz in all_sizes):
|
||||
# transpose
|
||||
all_sizes = list(zip(*all_sizes))
|
||||
|
@ -568,7 +576,12 @@ class MixedLayerNorm(MixedOperation, nn.LayerNorm):
|
|||
else:
|
||||
return max(all_sizes)
|
||||
|
||||
def slice_param(self, normalized_shape, **kwargs) -> Any:
|
||||
def freeze_weight(self, normalized_shape, **kwargs) -> Any:
|
||||
rv = self._freeze_weight_impl(normalized_shape)
|
||||
rv.pop('normalized_shape')
|
||||
return rv
|
||||
|
||||
def _freeze_weight_impl(self, normalized_shape, **kwargs) -> Any:
|
||||
if isinstance(normalized_shape, dict):
|
||||
normalized_shape = self.normalized_shape
|
||||
|
||||
|
@ -595,7 +608,7 @@ class MixedLayerNorm(MixedOperation, nn.LayerNorm):
|
|||
if any(isinstance(arg, dict) for arg in [eps]):
|
||||
raise ValueError(_diff_not_compatible_error.format('eps', 'LayerNorm'))
|
||||
|
||||
params_mapping = self.slice_param(normalized_shape)
|
||||
params_mapping = self._freeze_weight_impl(normalized_shape)
|
||||
weight, bias, normalized_shape = [
|
||||
params_mapping.get(name)
|
||||
for name in ['weight', 'bias', 'normalized_shape']
|
||||
|
@ -633,7 +646,7 @@ class MixedMultiHeadAttention(MixedOperation, nn.MultiheadAttention):
|
|||
All candidates of ``embed_dim`` should be divisible by all candidates of ``num_heads``.
|
||||
"""
|
||||
|
||||
bound_type = nas_nn.MultiheadAttention
|
||||
bound_type = MutableMultiheadAttention
|
||||
argument_list = ['embed_dim', 'num_heads', 'kdim', 'vdim', 'dropout']
|
||||
|
||||
def __post_init__(self):
|
||||
|
@ -671,8 +684,17 @@ class MixedMultiHeadAttention(MixedOperation, nn.MultiheadAttention):
|
|||
nn.init.xavier_uniform_(self.k_proj_weight)
|
||||
nn.init.xavier_uniform_(self.v_proj_weight)
|
||||
|
||||
def super_init_argument(self, name: str, value_choice: ValueChoiceX):
|
||||
return max(traverse_all_options(value_choice))
|
||||
def super_init_argument(self, name: str, mutable_expr: MutableExpression):
|
||||
return max(traverse_all_options(mutable_expr))
|
||||
|
||||
def freeze_weight(self, embed_dim, kdim, vdim, **kwargs):
|
||||
rv = self._freeze_weight_impl(embed_dim, kdim, vdim, **kwargs)
|
||||
# pop flags and nones, as they won't show in state dict
|
||||
rv.pop('qkv_same_embed_dim')
|
||||
for k in list(rv):
|
||||
if rv[k] is None:
|
||||
rv.pop(k)
|
||||
return rv
|
||||
|
||||
def _to_proj_slice(self, embed_dim: _W) -> list[slice]:
|
||||
# slice three parts, corresponding to q, k, v respectively
|
||||
|
@ -682,7 +704,7 @@ class MixedMultiHeadAttention(MixedOperation, nn.MultiheadAttention):
|
|||
slice(self.embed_dim * 2, self.embed_dim * 2 + embed_dim)
|
||||
]
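PyTorch packs the q/k/v projections into a single ``in_proj_weight`` of shape ``(3 * embed_dim, embed_dim)``; each slice above picks the sampled ``embed_dim`` rows out of one third. In isolation (plain PyTorch, sizes made up):

import torch

max_embed, sampled = 8, 4
in_proj_weight = torch.randn(3 * max_embed, max_embed)
rows = [slice(i * max_embed, i * max_embed + sampled) for i in range(3)]   # q, k, v thirds
q_w, k_w, v_w = (in_proj_weight[r, :sampled] for r in rows)                # each (4, 4)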
|
||||
|
||||
def slice_param(self, embed_dim, kdim, vdim, **kwargs):
|
||||
def _freeze_weight_impl(self, embed_dim, kdim, vdim, **kwargs):
|
||||
# by default, kdim, vdim can be none
|
||||
if kdim is None:
|
||||
kdim = embed_dim
|
||||
|
@ -723,31 +745,6 @@ class MixedMultiHeadAttention(MixedOperation, nn.MultiheadAttention):
|
|||
'qkv_same_embed_dim': qkv_same_embed_dim
|
||||
}
|
||||
|
||||
def _save_param_buff_to_state_dict(self, destination, prefix, keep_vars):
|
||||
kwargs = {name: self.forward_argument(name) for name in self.argument_list}
|
||||
params_mapping = self.slice_param(**kwargs, mapping=True)
|
||||
for name, value in itertools.chain(self._parameters.items(), self._buffers.items()):
|
||||
if value is None or name in self._non_persistent_buffers_set:
|
||||
continue
|
||||
value = params_mapping.get(name, value)
|
||||
destination[prefix + name] = value if keep_vars else value.detach()
|
||||
|
||||
# params of out_proj is handled in ``MixedMultiHeadAttention`` rather than
|
||||
# ``NonDynamicallyQuantizableLinear`` sub-module. We also convert it to state dict here.
|
||||
for name in ["out_proj.weight", "out_proj.bias"]:
|
||||
value = params_mapping.get(name, None)
|
||||
if value is None:
|
||||
continue
|
||||
destination[prefix + name] = value if keep_vars else value.detach()
|
||||
|
||||
def _save_module_to_state_dict(self, destination, prefix, keep_vars):
|
||||
for name, module in self._modules.items():
|
||||
# the weights of ``NonDynamicallyQuantizableLinear`` has been handled in `_save_param_buff_to_state_dict`.
|
||||
if isinstance(module, nn.modules.linear.NonDynamicallyQuantizableLinear):
|
||||
continue
|
||||
if module is not None:
|
||||
sub_state_dict(module, destination=destination, prefix=prefix + name + '.', keep_vars=keep_vars)
|
||||
|
||||
def forward_with_args(
|
||||
self,
|
||||
embed_dim: int_or_int_dict, num_heads: int,
|
||||
|
@ -770,7 +767,7 @@ class MixedMultiHeadAttention(MixedOperation, nn.MultiheadAttention):
|
|||
else:
|
||||
used_embed_dim = embed_dim
|
||||
|
||||
params_mapping = self.slice_param(embed_dim, kdim, vdim)
|
||||
params_mapping = self._freeze_weight_impl(embed_dim, kdim, vdim)
|
||||
in_proj_bias, in_proj_weight, bias_k, bias_v, \
|
||||
out_proj_weight, out_proj_bias, q_proj, k_proj, v_proj, qkv_same_embed_dim = [
|
||||
params_mapping.get(name)
|
||||
|
|
|
@ -6,19 +6,19 @@
|
|||
from __future__ import annotations
|
||||
|
||||
import copy
|
||||
import functools
|
||||
import random
|
||||
from typing import Any, List, Dict, Sequence, cast
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
|
||||
from nni.common.hpo_utils import ParameterSpec
|
||||
from nni.nas.nn.pytorch import LayerChoice, InputChoice, Repeat, ChoiceOf, Cell
|
||||
from nni.nas.nn.pytorch.choice import ValueChoiceX
|
||||
from nni.mutable import MutableExpression, label_scope, Mutable, Categorical, CategoricalMultiple
|
||||
from nni.nas.nn.pytorch import LayerChoice, InputChoice, Repeat, Cell
|
||||
from nni.nas.nn.pytorch.cell import CellOpFactory, create_cell_op_candidates, preprocess_cell_inputs
|
||||
|
||||
from .base import BaseSuperNetModule, sub_state_dict
|
||||
from ._valuechoice_utils import evaluate_value_choice_with_dict, dedup_inner_choices, weighted_sum
|
||||
from .base import BaseSuperNetModule
|
||||
from ._expression_utils import weighted_sum
|
||||
from .operation import MixedOperationSamplingPolicy, MixedOperation
|
||||
|
||||
__all__ = [
|
||||
|
@ -28,7 +28,7 @@ __all__ = [
|
|||
]
|
||||
|
||||
|
||||
class PathSamplingLayer(BaseSuperNetModule):
|
||||
class PathSamplingLayer(LayerChoice, BaseSuperNetModule):
|
||||
"""
|
||||
Mixed layer, in which fprop is decided by exactly one inner layer or sum of multiple (sampled) layers.
|
||||
If multiple modules are selected, the result will be summed and returned.
|
||||
|
@ -41,62 +41,43 @@ class PathSamplingLayer(BaseSuperNetModule):
|
|||
Name of the choice.
|
||||
"""
|
||||
|
||||
def __init__(self, paths: list[tuple[str, nn.Module]], label: str):
|
||||
super().__init__()
|
||||
self.op_names = []
|
||||
for name, module in paths:
|
||||
self.add_module(name, module)
|
||||
self.op_names.append(name)
|
||||
assert self.op_names, 'There has to be at least one op to choose from.'
|
||||
def __init__(self, paths: dict[str, nn.Module] | list[nn.Module], label: str):
|
||||
super().__init__(paths, label=label)
|
||||
self._sampled: list[str] | str | None = None # sampled can be either a list of indices or an index
|
||||
self.label = label
|
||||
|
||||
def resample(self, memo):
|
||||
"""Random choose one path if label is not found in memo."""
|
||||
if self.label in memo:
|
||||
self._sampled = memo[self.label]
|
||||
else:
|
||||
self._sampled = random.choice(self.op_names)
|
||||
self._sampled = self.choice.random()
|
||||
return {self.label: self._sampled}
|
||||
|
||||
def export(self, memo):
|
||||
"""Random choose one name if label isn't found in memo."""
|
||||
if self.label in memo:
|
||||
return {} # nothing new to export
|
||||
return {self.label: random.choice(self.op_names)}
|
||||
|
||||
def search_space_spec(self):
|
||||
return {self.label: ParameterSpec(self.label, 'choice', self.op_names, (self.label, ),
|
||||
True, size=len(self.op_names))}
|
||||
return {self.label: self.choice.random()}
|
||||
|
||||
@classmethod
|
||||
def mutate(cls, module, name, memo, mutate_kwargs):
|
||||
if isinstance(module, LayerChoice):
|
||||
return cls(list(module.named_children()), module.label)
|
||||
if type(module) is LayerChoice:
|
||||
return cls(module.candidates, module.label)
|
||||
|
||||
def reduction(self, items: list[Any], sampled: list[Any]):
|
||||
def _reduction(self, items: list[Any], sampled: list[Any]):
|
||||
"""Override this to implement customized reduction."""
|
||||
return weighted_sum(items)
|
||||
|
||||
def _save_module_to_state_dict(self, destination, prefix, keep_vars):
|
||||
sampled = [self._sampled] if not isinstance(self._sampled, list) else self._sampled
|
||||
|
||||
for samp in sampled:
|
||||
module = getattr(self, str(samp))
|
||||
if module is not None:
|
||||
sub_state_dict(module, destination=destination, prefix=prefix, keep_vars=keep_vars)
|
||||
|
||||
def forward(self, *args, **kwargs):
|
||||
if self._sampled is None:
|
||||
raise RuntimeError('At least one path needs to be sampled before fprop.')
|
||||
sampled = [self._sampled] if not isinstance(self._sampled, list) else self._sampled
|
||||
|
||||
# str(samp) is needed here because samp can sometimes be integers, but attr are always str
|
||||
res = [getattr(self, str(samp))(*args, **kwargs) for samp in sampled]
|
||||
return self.reduction(res, sampled)
|
||||
res = [self[samp](*args, **kwargs) for samp in sampled]
|
||||
return self._reduction(res, sampled)
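For contrast with the differentiable mixed layer, a minimal single-path sampling layer in plain PyTorch (independent of the NNI classes here): resample picks one candidate, forward runs only that candidate.

import random
import torch
import torch.nn as nn

class TinyPathSamplingLayer(nn.Module):
    def __init__(self, candidates: dict):
        super().__init__()
        self.ops = nn.ModuleDict(candidates)
        self.sampled = None

    def resample(self):
        self.sampled = random.choice(list(self.ops))
        return self.sampled

    def forward(self, x):
        assert self.sampled is not None, 'call resample() before forward()'
        return self.ops[self.sampled](x)

layer = TinyPathSamplingLayer({'linear': nn.Linear(8, 8), 'skip': nn.Identity()})
layer.resample()
out = layer(torch.randn(2, 8))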
|
||||
|
||||
|
||||
class PathSamplingInput(BaseSuperNetModule):
|
||||
class PathSamplingInput(InputChoice, BaseSuperNetModule):
|
||||
"""
|
||||
Mixed input. Take a list of tensor as input, select some of them and return the sum.
|
||||
|
||||
|
@ -106,22 +87,9 @@ class PathSamplingInput(BaseSuperNetModule):
|
|||
Sampled input indices.
|
||||
"""
|
||||
|
||||
def __init__(self, n_candidates: int, n_chosen: int, reduction_type: str, label: str):
|
||||
super().__init__()
|
||||
self.n_candidates = n_candidates
|
||||
self.n_chosen = n_chosen
|
||||
self.reduction_type = reduction_type
|
||||
def __init__(self, n_candidates: int, n_chosen: int, reduction: str, label: str):
|
||||
super().__init__(n_candidates, n_chosen=n_chosen, reduction=reduction, label=label)
|
||||
self._sampled: list[int] | int | None = None
|
||||
self.label = label
|
||||
|
||||
def _random_choose_n(self):
|
||||
sampling = list(range(self.n_candidates))
|
||||
random.shuffle(sampling)
|
||||
sampling = sorted(sampling[:self.n_chosen])
|
||||
if len(sampling) == 1:
|
||||
return sampling[0]
|
||||
else:
|
||||
return sampling
|
||||
|
||||
def resample(self, memo):
|
||||
"""Random choose one path / multiple paths if label is not found in memo.
|
||||
|
@ -131,42 +99,36 @@ class PathSamplingInput(BaseSuperNetModule):
|
|||
if self.label in memo:
|
||||
self._sampled = memo[self.label]
|
||||
else:
|
||||
self._sampled = self._random_choose_n()
|
||||
self._sampled = self.choice.random()
|
||||
return {self.label: self._sampled}
|
||||
|
||||
def export(self, memo):
|
||||
"""Random choose one name if label isn't found in memo."""
|
||||
if self.label in memo:
|
||||
return {} # nothing new to export
|
||||
return {self.label: self._random_choose_n()}
|
||||
|
||||
def search_space_spec(self):
|
||||
return {
|
||||
self.label: ParameterSpec(self.label, 'choice', list(range(self.n_candidates)),
|
||||
(self.label, ), True, size=self.n_candidates, chosen_size=self.n_chosen)
|
||||
}
|
||||
return {self.label: self.choice.random()}
|
||||
|
||||
@classmethod
|
||||
def mutate(cls, module, name, memo, mutate_kwargs):
|
||||
if isinstance(module, InputChoice):
|
||||
if type(module) is InputChoice:
|
||||
if module.reduction not in ['sum', 'mean', 'concat']:
|
||||
raise ValueError('Only input choice of sum/mean/concat reduction is supported.')
|
||||
if module.n_chosen is None:
|
||||
raise ValueError('n_chosen is None is not supported yet.')
|
||||
return cls(module.n_candidates, module.n_chosen, module.reduction, module.label)
|
||||
|
||||
def reduction(self, items: list[Any], sampled: list[Any]) -> Any:
|
||||
def _reduction(self, items: list[Any], sampled: list[Any]) -> Any:
|
||||
"""Override this to implement customized reduction."""
|
||||
if len(items) == 1:
|
||||
return items[0]
|
||||
else:
|
||||
if self.reduction_type == 'sum':
|
||||
if self.reduction == 'sum':
|
||||
return sum(items)
|
||||
elif self.reduction_type == 'mean':
|
||||
elif self.reduction == 'mean':
|
||||
return sum(items) / len(items)
|
||||
elif self.reduction_type == 'concat':
|
||||
elif self.reduction == 'concat':
|
||||
return torch.cat(items, 1)
|
||||
raise ValueError(f'Unsupported reduction type: {self.reduction_type}')
|
||||
raise ValueError(f'Unsupported reduction type: {self.reduction}')
|
||||
|
||||
def forward(self, input_tensors):
|
||||
if self._sampled is None:
|
||||
|
@ -175,7 +137,7 @@ class PathSamplingInput(BaseSuperNetModule):
|
|||
raise ValueError(f'Expect {self.n_candidates} input tensors, found {len(input_tensors)}.')
|
||||
sampled = [self._sampled] if not isinstance(self._sampled, list) else self._sampled
|
||||
res = [input_tensors[samp] for samp in sampled]
|
||||
return self.reduction(res, sampled)
|
||||
return self._reduction(res, sampled)
|
||||
|
||||
|
||||
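A condensed usage sketch of PathSamplingInput, taken from the unit tests added later in this commit (the 'ddd' label and sizes are the test's choices, not part of the API):

import torch
from nni.nas.oneshot.pytorch.supermodule.sampling import PathSamplingInput

# Five candidate inputs, choose two, concatenate along the last dimension.
inp = PathSamplingInput(5, 2, 'concat', 'ddd')
sample = inp.resample({})                      # e.g. {'ddd': [1, 4]}
out = inp([torch.randn(4, 2) for _ in range(5)])
assert out.size(-1) == 4                       # two width-2 tensors, concatenated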
class MixedOpPathSamplingPolicy(MixedOperationSamplingPolicy):
|
||||
|
@ -193,28 +155,28 @@ class MixedOpPathSamplingPolicy(MixedOperationSamplingPolicy):
|
|||
def resample(self, operation: MixedOperation, memo: dict[str, Any]) -> dict[str, Any]:
|
||||
"""Random sample for each leaf value choice."""
|
||||
result = {}
|
||||
space_spec = operation.search_space_spec()
|
||||
for label in space_spec:
|
||||
space_spec = operation.simplify()
|
||||
for label, mutable in space_spec.items():
|
||||
if label in memo:
|
||||
result[label] = memo[label]
|
||||
else:
|
||||
result[label] = random.choice(space_spec[label].values)
|
||||
result[label] = mutable.random()
|
||||
|
||||
# composits to kwargs
|
||||
# compose into kwargs
|
||||
# example: result = {"exp_ratio": 3}, self._sampled = {"in_channels": 48, "out_channels": 96}
|
||||
self._sampled = {}
|
||||
for key, value in operation.mutable_arguments.items():
|
||||
self._sampled[key] = evaluate_value_choice_with_dict(value, result)
|
||||
self._sampled[key] = value.freeze(result)
|
||||
|
||||
return result
|
||||
|
||||
def export(self, operation: MixedOperation, memo: dict[str, Any]) -> dict[str, Any]:
|
||||
"""Export is also random for each leaf value choice."""
|
||||
result = {}
|
||||
space_spec = operation.search_space_spec()
|
||||
for label in space_spec:
|
||||
space_spec = operation.simplify()
|
||||
for label, mutable in space_spec.items():
|
||||
if label not in memo:
|
||||
result[label] = random.choice(space_spec[label].values)
|
||||
result[label] = mutable.random()
|
||||
return result
|
||||
|
||||
def forward_argument(self, operation: MixedOperation, name: str) -> Any:
|
||||
|
@ -226,9 +188,9 @@ class MixedOpPathSamplingPolicy(MixedOperationSamplingPolicy):
|
|||
return operation.init_arguments[name]
|
||||
|
||||
|
||||
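A condensed sketch of how this sampling policy is wired into a mixed operation, following the unit tests added later in this commit (the '123' label is the test's choice):

import nni
import torch
from nni.nas.nn.pytorch import MutableConv2d
from nni.nas.oneshot.pytorch.supermodule.operation import MixedConv2d
from nni.nas.oneshot.pytorch.supermodule.sampling import MixedOpPathSamplingPolicy

# Each leaf choice is resolved first, then mutable arguments are frozen into kwargs.
conv = MutableConv2d(3, nni.choice('123', [3, 5, 7]), kernel_size=3)
conv = MixedConv2d.mutate(conv, 'dummy', {}, {'mixed_op_sampling': MixedOpPathSamplingPolicy})
conv.resample(memo={'123': 5})                 # out_channels frozen to 5 for this forward
assert conv(torch.zeros((1, 3, 5, 5))).size(1) == 5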
class PathSamplingRepeat(BaseSuperNetModule):
|
||||
class PathSamplingRepeat(Repeat, BaseSuperNetModule):
|
||||
"""
|
||||
Implementaion of Repeat in a path-sampling supernet.
|
||||
Implementation of Repeat in a path-sampling supernet.
|
||||
Samples one / some of the prefixes of the repeated blocks.
|
||||
|
||||
Attributes
|
||||
|
@ -237,56 +199,43 @@ class PathSamplingRepeat(BaseSuperNetModule):
|
|||
Sampled depth.
|
||||
"""
|
||||
|
||||
def __init__(self, blocks: list[nn.Module], depth: ChoiceOf[int]):
|
||||
super().__init__()
|
||||
self.blocks: Any = blocks
|
||||
self.depth = depth
|
||||
self._space_spec: dict[str, ParameterSpec] = dedup_inner_choices([depth])
|
||||
def __init__(self, blocks: list[nn.Module], depth: MutableExpression[int]):
|
||||
super().__init__(blocks, depth)
|
||||
self._sampled: list[int] | int | None = None
|
||||
|
||||
def resample(self, memo):
|
||||
"""Since depth is based on ValueChoice, we only need to randomly sample every leaf value choices."""
|
||||
result = {}
|
||||
for label in self._space_spec:
|
||||
assert isinstance(self.depth_choice, Mutable)
|
||||
for label, mutable in self.depth_choice.simplify().items():
|
||||
if label in memo:
|
||||
result[label] = memo[label]
|
||||
else:
|
||||
result[label] = random.choice(self._space_spec[label].values)
|
||||
result[label] = mutable.random()
|
||||
|
||||
self._sampled = evaluate_value_choice_with_dict(self.depth, result)
|
||||
self._sampled = self.depth_choice.freeze(result)
|
||||
|
||||
return result
|
||||
|
||||
def export(self, memo):
|
||||
"""Random choose one if every choice not in memo."""
|
||||
result = {}
|
||||
for label in self._space_spec:
|
||||
assert isinstance(self.depth_choice, Mutable)
|
||||
for label, mutable in self.depth_choice.simplify().items():
|
||||
if label not in memo:
|
||||
result[label] = random.choice(self._space_spec[label].values)
|
||||
result[label] = mutable.random()
|
||||
return result
|
||||
|
||||
def search_space_spec(self):
|
||||
return self._space_spec
|
||||
|
||||
@classmethod
|
||||
def mutate(cls, module, name, memo, mutate_kwargs):
|
||||
if isinstance(module, Repeat) and isinstance(module.depth_choice, ValueChoiceX):
|
||||
if type(module) == Repeat and isinstance(module.depth_choice, MutableExpression):
|
||||
# Only interesting when depth is mutable
|
||||
return cls(cast(List[nn.Module], module.blocks), module.depth_choice)
|
||||
return cls(list(module.blocks), module.depth_choice)
|
||||
|
||||
def reduction(self, items: list[Any], sampled: list[Any]):
|
||||
def _reduction(self, items: list[Any], sampled: list[Any]):
|
||||
"""Override this to implement customized reduction."""
|
||||
return weighted_sum(items)
|
||||
|
||||
def _save_module_to_state_dict(self, destination, prefix, keep_vars):
|
||||
sampled: Any = [self._sampled] if not isinstance(self._sampled, list) else self._sampled
|
||||
|
||||
for cur_depth, (name, module) in enumerate(self.blocks.named_children(), start=1):
|
||||
if module is not None:
|
||||
sub_state_dict(module, destination=destination, prefix=prefix + name + '.', keep_vars=keep_vars)
|
||||
if not any(d > cur_depth for d in sampled):
|
||||
break
|
||||
|
||||
def forward(self, x):
|
||||
if self._sampled is None:
|
||||
raise RuntimeError('At least one depth needs to be sampled before fprop.')
|
||||
|
@ -299,7 +248,7 @@ class PathSamplingRepeat(BaseSuperNetModule):
|
|||
res.append(x)
|
||||
if not any(d > cur_depth for d in sampled):
|
||||
break
|
||||
return self.reduction(res, sampled)
|
||||
return self._reduction(res, sampled)
|
||||
|
||||
|
||||
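A condensed usage sketch of PathSamplingRepeat, mirroring the unit tests added later in this commit (the 'ccc' label and block shapes come from the test):

import nni
import torch
from nni.nas.nn.pytorch import MutableLinear
from nni.nas.oneshot.pytorch.supermodule.sampling import PathSamplingRepeat

# Depth is a mutable expression; sampling it decides how many blocks are applied.
op = PathSamplingRepeat([MutableLinear(16, 16), MutableLinear(16, 8), MutableLinear(8, 4)],
                        nni.choice('ccc', [1, 2, 3]))
op.resample({'ccc': 2})
assert op(torch.randn(2, 16)).shape[1] == 8    # only the first two blocks run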
class PathSamplingCell(BaseSuperNetModule):
|
||||
|
@ -345,39 +294,46 @@ class PathSamplingCell(BaseSuperNetModule):
|
|||
# InputChoice is implicit in this graph.
|
||||
for i in self.output_node_indices:
|
||||
self.ops.append(nn.ModuleList())
|
||||
for k in range(i + self.num_predecessors):
|
||||
for k in range(i):
|
||||
# Second argument in (i, **0**, k) is always 0.
|
||||
# One-shot strategy can't handle the cases where op spec is dependent on `op_index`.
|
||||
ops, _ = create_cell_op_candidates(op_factory, i, 0, k)
|
||||
self.op_names = list(ops.keys())
|
||||
cast(nn.ModuleList, self.ops[-1]).append(nn.ModuleDict(ops))
|
||||
|
||||
self.label = label
|
||||
with label_scope(label) as self.label_scope:
|
||||
for i in range(self.num_predecessors, self.num_nodes + self.num_predecessors):
|
||||
for k in range(self.num_ops_per_node):
|
||||
op_label = f'op_{i}_{k}'
|
||||
input_label = f'input_{i}_{k}'
|
||||
self.add_mutable(Categorical(self.op_names, label=op_label))
|
||||
# Need a CategoricalMultiple here to align with the original cell.
|
||||
self.add_mutable(CategoricalMultiple(range(i), n_chosen=1, label=input_label))
|
||||
|
||||
self._sampled: dict[str, str | int] = {}
|
||||
|
||||
def search_space_spec(self) -> dict[str, ParameterSpec]:
|
||||
# TODO: Recreating the space here.
|
||||
# The spec should be moved to definition of Cell itself.
|
||||
space_spec = {}
|
||||
for i in range(self.num_predecessors, self.num_nodes + self.num_predecessors):
|
||||
for k in range(self.num_ops_per_node):
|
||||
op_label = f'{self.label}/op_{i}_{k}'
|
||||
input_label = f'{self.label}/input_{i}_{k}'
|
||||
space_spec[op_label] = ParameterSpec(op_label, 'choice', self.op_names, (op_label,), True, size=len(self.op_names))
|
||||
space_spec[input_label] = ParameterSpec(input_label, 'choice', list(range(i)), (input_label, ), True, size=i)
|
||||
return space_spec
|
||||
@property
|
||||
def label(self) -> str:
|
||||
return self.label_scope.name
|
||||
|
||||
def freeze(self, sample):
|
||||
raise NotImplementedError('PathSamplingCell does not support freeze.')
|
||||
|
||||
def resample(self, memo):
|
||||
"""Random choose one path if label is not found in memo."""
|
||||
self._sampled = {}
|
||||
new_sampled = {}
|
||||
for label, param_spec in self.search_space_spec().items():
|
||||
for label, param_spec in self.simplify().items():
|
||||
if label in memo:
|
||||
assert not isinstance(memo[label], list), 'Multi-path sampling is currently unsupported on cell.'
|
||||
if isinstance(memo[label], list) and len(memo[label]) > 1:
|
||||
raise ValueError(f'Multi-path sampling is currently unsupported on cell: {memo[label]}')
|
||||
self._sampled[label] = memo[label]
|
||||
else:
|
||||
if isinstance(param_spec, Categorical):
|
||||
self._sampled[label] = new_sampled[label] = random.choice(param_spec.values)
|
||||
elif isinstance(param_spec, CategoricalMultiple):
|
||||
assert param_spec.n_chosen == 1
|
||||
self._sampled[label] = new_sampled[label] = [random.choice(param_spec.values)]
|
||||
return new_sampled
|
||||
|
||||
def export(self, memo):
|
||||
|
@ -392,7 +348,7 @@ class PathSamplingCell(BaseSuperNetModule):
|
|||
|
||||
for k in range(self.num_ops_per_node):
|
||||
# Select op list based on the input chosen
|
||||
input_index = self._sampled[f'{self.label}/input_{i}_{k}']
|
||||
input_index = self._sampled[f'{self.label}/input_{i}_{k}'][0] # [0] because it's a list and n_chosen=1
|
||||
op_candidates = ops[cast(int, input_index)]
|
||||
# Select op from op list based on the op chosen
|
||||
op_index = self._sampled[f'{self.label}/op_{i}_{k}']
|
||||
|
@ -411,7 +367,7 @@ class PathSamplingCell(BaseSuperNetModule):
|
|||
Mutate only handles cells of specific configurations (e.g., with loose end).
|
||||
Fall back to the default mutate if the cell is not handled here.
|
||||
"""
|
||||
if isinstance(module, Cell):
|
||||
if type(module) is Cell:
|
||||
op_factory = None # not all the cells need to be replaced
|
||||
if module.op_candidates_factory is not None:
|
||||
op_factory = module.op_candidates_factory
|
||||
|
@ -420,10 +376,16 @@ class PathSamplingCell(BaseSuperNetModule):
|
|||
elif module.merge_op == 'loose_end':
|
||||
op_candidates_lc = module.ops[-1][-1] # type: ignore
|
||||
assert isinstance(op_candidates_lc, LayerChoice)
|
||||
op_factory = { # create a factory
|
||||
name: lambda _, __, ___: copy.deepcopy(op_candidates_lc[name])
|
||||
for name in op_candidates_lc.names
|
||||
}
|
||||
candidates = op_candidates_lc.candidates
|
||||
def _copy(_, __, ___, op):
|
||||
return copy.deepcopy(op)
|
||||
|
||||
if isinstance(candidates, list):
|
||||
op_factory = [functools.partial(_copy, op=op) for op in candidates]
|
||||
elif isinstance(candidates, dict):
|
||||
op_factory = {name: functools.partial(_copy, op=op) for name, op in candidates.items()}
|
||||
else:
|
||||
raise ValueError(f'Unsupported type of candidates: {type(candidates)}')
|
||||
if op_factory is not None:
|
||||
return cls(
|
||||
op_factory,
|
||||
|
|
|
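PathSamplingCell is normally created indirectly by a one-shot strategy rather than constructed by hand. A condensed sketch based on the cell tests added below (CellLooseEnd is a model space defined in the test fixtures, not part of the public API):

import torch
from nni.nas.oneshot.pytorch.strategy import RandomOneShot
from nni.nas.oneshot.pytorch.supermodule.base import BaseSuperNetModule
from nni.nas.oneshot.pytorch.supermodule.sampling import PathSamplingCell
from ut.nas.nn.models import CellLooseEnd      # test fixture model space

model = RandomOneShot().mutate_model(CellLooseEnd())
assert isinstance(model.cell, PathSamplingCell)
memo = {}
for module in model.modules():                 # resample every supernet module before forwarding
    if isinstance(module, BaseSuperNetModule):
        memo.update(module.resample(memo=memo))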
@ -0,0 +1,633 @@
|
|||
import pytest
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
|
||||
import nni
|
||||
from nni.mutable import MutableList, frozen
|
||||
from nni.nas.nn.pytorch import LayerChoice, ModelSpace, Cell, MutableConv2d, MutableBatchNorm2d, MutableLayerNorm, MutableLinear, MutableMultiheadAttention
|
||||
from nni.nas.oneshot.pytorch.differentiable import DartsLightningModule
|
||||
from nni.nas.oneshot.pytorch.strategy import RandomOneShot, DARTS
|
||||
from nni.nas.oneshot.pytorch.supermodule.base import BaseSuperNetModule
|
||||
from nni.nas.oneshot.pytorch.supermodule.differentiable import (
|
||||
MixedOpDifferentiablePolicy, DifferentiableMixedLayer, DifferentiableMixedInput, GumbelSoftmax,
|
||||
DifferentiableMixedRepeat, DifferentiableMixedCell
|
||||
)
|
||||
from nni.nas.oneshot.pytorch.supermodule.sampling import (
|
||||
MixedOpPathSamplingPolicy, PathSamplingLayer, PathSamplingInput, PathSamplingRepeat, PathSamplingCell
|
||||
)
|
||||
from nni.nas.oneshot.pytorch.supermodule.operation import MixedConv2d, NATIVE_MIXED_OPERATIONS
|
||||
from nni.nas.oneshot.pytorch.supermodule.proxyless import ProxylessMixedLayer, ProxylessMixedInput
|
||||
from nni.nas.oneshot.pytorch.supermodule._operation_utils import Slicable as S, MaybeWeighted as W
|
||||
from nni.nas.oneshot.pytorch.supermodule._expression_utils import *
|
||||
|
||||
from ut.nas.nn.models import (
|
||||
CellSimple, CellDefaultArgs, CellCustomProcessor, CellLooseEnd, CellOpFactory
|
||||
)
|
||||
|
||||
@pytest.fixture(autouse=True)
|
||||
def context():
|
||||
frozen._ENSURE_FROZEN_STRICT = False
|
||||
yield
|
||||
frozen._ENSURE_FROZEN_STRICT = True
|
||||
|
||||
|
||||
def test_slice():
|
||||
weight = np.ones((3, 7, 24, 23))
|
||||
assert S(weight)[:, 1:3, :, 9:13].shape == (3, 2, 24, 4)
|
||||
assert S(weight)[:, 1:W(3)*2+1, :, 9:13].shape == (3, 6, 24, 4)
|
||||
assert S(weight)[:, 1:W(3)*2+1].shape == (3, 6, 24, 23)
|
||||
|
||||
# Ellipsis
|
||||
assert S(weight)[..., 9:13].shape == (3, 7, 24, 4)
|
||||
assert S(weight)[:2, ..., 1:W(3)+1].shape == (2, 7, 24, 3)
|
||||
assert S(weight)[..., 1:W(3)*2+1].shape == (3, 7, 24, 6)
|
||||
assert S(weight)[..., :10, 1:W(3)*2+1].shape == (3, 7, 10, 6)
|
||||
|
||||
# no effect
|
||||
assert S(weight)[:] is weight
|
||||
|
||||
# list
|
||||
assert S(weight)[[slice(1), slice(2, 3)]].shape == (2, 7, 24, 23)
|
||||
assert S(weight)[[slice(1), slice(2, W(2) + 1)], W(2):].shape == (2, 5, 24, 23)
|
||||
|
||||
# weighted
|
||||
weight = S(weight)[:W({1: 0.5, 2: 0.3, 3: 0.2})]
|
||||
weight = weight[:, 0, 0, 0]
|
||||
assert weight[0] == 1 and weight[1] == 0.5 and weight[2] == 0.2
|
||||
|
||||
weight = np.ones((3, 6, 6))
|
||||
value = W({1: 0.5, 3: 0.5})
|
||||
weight = S(weight)[:, 3 - value:3 + value, 3 - value:3 + value]
|
||||
for i in range(0, 6):
|
||||
for j in range(0, 6):
|
||||
if 2 <= i <= 3 and 2 <= j <= 3:
|
||||
assert weight[0, i, j] == 1
|
||||
else:
|
||||
assert weight[1, i, j] == 0.5
|
||||
|
||||
# weighted + list
|
||||
value = W({1: 0.5, 3: 0.5})
|
||||
weight = np.ones((8, 4))
|
||||
weight = S(weight)[[slice(value), slice(4, value + 4)]]
|
||||
assert weight.sum(1).tolist() == [4, 2, 2, 0, 4, 2, 2, 0]
|
||||
|
||||
with pytest.raises(ValueError, match='one distinct'):
|
||||
# has to be exactly the same instance, equal is not enough
|
||||
weight = S(weight)[:W({1: 0.5}), : W({1: 0.5})]
|
||||
|
||||
|
||||
def test_valuechoice_utils():
|
||||
chosen = {"exp": 3, "add": 1}
|
||||
vc0 = nni.choice('exp', [3, 4, 6]) * 2 + nni.choice('add', [0, 1])
|
||||
|
||||
assert vc0.freeze(chosen) == 7
|
||||
vc = vc0 + nni.choice('exp', [3, 4, 6])
|
||||
assert vc.freeze(chosen) == 10
|
||||
|
||||
assert list(MutableList([vc0, vc]).simplify().keys()) == ['exp', 'add']
|
||||
|
||||
assert traverse_all_options(vc) == [9, 10, 12, 13, 18, 19]
|
||||
weights = dict(traverse_all_options(vc, weights={'exp': [0.5, 0.3, 0.2], 'add': [0.4, 0.6]}))
|
||||
ans = dict([(9, 0.2), (10, 0.3), (12, 0.12), (13, 0.18), (18, 0.08), (19, 0.12)])
|
||||
assert len(weights) == len(ans)
|
||||
for value, weight in ans.items():
|
||||
assert abs(weight - weights[value]) < 1e-6
|
||||
|
||||
assert evaluate_constant(nni.choice('x', [3, 4, 6]) - nni.choice('x', [3, 4, 6])) == 0
|
||||
with pytest.raises(ValueError):
|
||||
evaluate_constant(nni.choice('x', [3, 4, 6]) - nni.choice('y', [3, 4, 6]))
|
||||
|
||||
assert evaluate_constant(nni.choice('x', [3, 4, 6]) * 2 / nni.choice('x', [3, 4, 6])) == 2
|
||||
|
||||
|
||||
def test_expectation():
|
||||
vc = nni.choice('exp', [3, 4, 6]) * 2 + nni.choice('add', [0, 1])
|
||||
assert expression_expectation(vc, {'exp': [0.5, 0.3, 0.2], 'add': [0.4, 0.6]}) == 8.4
|
||||
|
||||
vc = sum([nni.choice(f'e{i}', [0, 1]) for i in range(100)])
|
||||
assert expression_expectation(vc, {f'e{i}': [0.5] * 2 for i in range(100)}) == 50
|
||||
|
||||
vc = nni.choice('a', [1, 2, 3]) * nni.choice('b', [1, 2, 3]) - nni.choice('c', [1, 2, 3])
|
||||
probs1 = [0.2, 0.3, 0.5]
|
||||
probs2 = [0.1, 0.2, 0.7]
|
||||
probs3 = [0.3, 0.4, 0.3]
|
||||
expect = sum(
|
||||
(i * j - k) * p1 * p2 * p3
|
||||
for i, p1 in enumerate(probs1, 1)
|
||||
for j, p2 in enumerate(probs2, 1)
|
||||
for k, p3 in enumerate(probs3, 1)
|
||||
)
|
||||
assert abs(expression_expectation(vc, {'a': probs1, 'b': probs2, 'c': probs3}) - expect) < 1e-12
|
||||
|
||||
vc = nni.choice('a', [1, 2, 3]) + 1
|
||||
assert expression_expectation(vc, {'a': [0.2, 0.3, 0.5]}) == 3.3
|
||||
|
||||
|
||||
def test_weighted_sum():
|
||||
weights = [0.1, 0.2, 0.7]
|
||||
items = [1, 2, 3]
|
||||
assert abs(weighted_sum(items, weights) - 2.6) < 1e-6
|
||||
|
||||
assert weighted_sum(items) == 6
|
||||
|
||||
with pytest.raises(TypeError, match='Unsupported'):
|
||||
weighted_sum(['a', 'b', 'c'], weights)
|
||||
|
||||
assert abs(weighted_sum(np.arange(3), weights).item() - 1.6) < 1e-6
|
||||
|
||||
items = [torch.full((2, 3, 5), i) for i in items]
|
||||
assert abs(weighted_sum(items, weights).flatten()[0].item() - 2.6) < 1e-6
|
||||
|
||||
items = [torch.randn(2, 3, i) for i in [1, 2, 3]]
|
||||
with pytest.raises(ValueError, match=r'does not match.*\n.*torch\.Tensor\(2, 3, 1\)'):
|
||||
weighted_sum(items, weights)
|
||||
|
||||
items = [(1, 2), (3, 4), (5, 6)]
|
||||
res = weighted_sum(items, weights)
|
||||
assert len(res) == 2 and abs(res[0] - 4.2) < 1e-6 and abs(res[1] - 5.2) < 1e-6
|
||||
|
||||
items = [(1, 2), (3, 4), (5, 6, 7)]
|
||||
with pytest.raises(ValueError):
|
||||
weighted_sum(items, weights)
|
||||
|
||||
items = [{"a": i, "b": np.full((2, 3, 5), i)} for i in [1, 2, 3]]
|
||||
res = weighted_sum(items, weights)
|
||||
assert res['b'].shape == (2, 3, 5)
|
||||
assert abs(res['b'][0][0][0] - res['a']) < 1e-6
|
||||
assert abs(res['a'] - 2.6) < 1e-6
|
||||
|
||||
|
||||
def test_pathsampling_valuechoice():
|
||||
orig_conv = MutableConv2d(3, nni.choice('123', [3, 5, 7]), kernel_size=3)
|
||||
conv = MixedConv2d.mutate(orig_conv, 'dummy', {}, {'mixed_op_sampling': MixedOpPathSamplingPolicy})
|
||||
conv.resample(memo={'123': 5})
|
||||
assert conv(torch.zeros((1, 3, 5, 5))).size(1) == 5
|
||||
conv.resample(memo={'123': 7})
|
||||
assert conv(torch.zeros((1, 3, 5, 5))).size(1) == 7
|
||||
assert conv.export({})['123'] in [3, 5, 7]
|
||||
|
||||
|
||||
def test_differentiable_valuechoice():
|
||||
orig_conv = MutableConv2d(3, nni.choice('456', [3, 5, 7]),
|
||||
kernel_size=nni.choice('123', [3, 5, 7]),
|
||||
padding=nni.choice('123', [3, 5, 7]) // 2
|
||||
)
|
||||
memo = {
|
||||
'123': nn.Parameter(torch.zeros(3)),
|
||||
'456': nn.Parameter(torch.zeros(3)),
|
||||
}
|
||||
conv = MixedConv2d.mutate(orig_conv, 'dummy', memo, {'mixed_op_sampling': MixedOpDifferentiablePolicy})
|
||||
assert conv(torch.zeros((1, 3, 7, 7))).size(2) == 7
|
||||
|
||||
assert set(conv.export({}).keys()) == {'123', '456'}
|
||||
|
||||
|
||||
def test_differentiable_layerchoice_dedup():
|
||||
layerchoice1 = LayerChoice([MutableConv2d(3, 3, 3), MutableConv2d(3, 3, 3)], label='a')
|
||||
layerchoice2 = LayerChoice([MutableConv2d(3, 3, 3), MutableConv2d(3, 3, 3)], label='a')
|
||||
|
||||
memo = {'a': nn.Parameter(torch.zeros(2))}
|
||||
DifferentiableMixedLayer.mutate(layerchoice1, 'x', memo, {})
|
||||
DifferentiableMixedLayer.mutate(layerchoice2, 'x', memo, {})
|
||||
assert len(memo) == 1 and 'a' in memo
|
||||
|
||||
|
||||
def _mutate_op_path_sampling_policy(operation):
|
||||
for native_op in NATIVE_MIXED_OPERATIONS:
|
||||
if native_op.bound_type == type(operation):
|
||||
mutate_op = native_op.mutate(operation, 'dummy', {}, {'mixed_op_sampling': MixedOpPathSamplingPolicy})
|
||||
break
|
||||
return mutate_op
|
||||
|
||||
|
||||
def _mixed_operation_sampling_sanity_check(operation, memo, *input):
|
||||
mutate_op = _mutate_op_path_sampling_policy(operation)
|
||||
mutate_op.resample(memo=memo)
|
||||
return mutate_op(*input)
|
||||
|
||||
|
||||
def _mixed_operation_state_dict_sanity_check(operation, model, memo, *input):
|
||||
mutate_op = _mutate_op_path_sampling_policy(operation)
|
||||
mutate_op.resample(memo=memo)
|
||||
frozen_op = mutate_op.freeze(memo)
|
||||
return frozen_op(*input), mutate_op(*input)
|
||||
|
||||
|
||||
def _mixed_operation_differentiable_sanity_check(operation, *input):
|
||||
memo = {k: nn.Parameter(torch.zeros(len(v))) for k, v in operation.simplify().items()}
|
||||
for native_op in NATIVE_MIXED_OPERATIONS:
|
||||
if native_op.bound_type == type(operation):
|
||||
mutate_op = native_op.mutate(operation, 'dummy', memo, {'mixed_op_sampling': MixedOpDifferentiablePolicy})
|
||||
break
|
||||
|
||||
mutate_op(*input)
|
||||
mutate_op.export({})
|
||||
mutate_op.export_probs({})
|
||||
|
||||
|
||||
def test_mixed_linear():
|
||||
linear = MutableLinear(nni.choice('shared', [3, 6, 9]), nni.choice('xx', [2, 4, 8]))
|
||||
_mixed_operation_sampling_sanity_check(linear, {'shared': 3}, torch.randn(2, 3))
|
||||
_mixed_operation_sampling_sanity_check(linear, {'shared': 9}, torch.randn(2, 9))
|
||||
_mixed_operation_differentiable_sanity_check(linear, torch.randn(2, 9))
|
||||
|
||||
linear = MutableLinear(nni.choice('shared', [3, 6, 9]), nni.choice('xx', [2, 4, 8]), bias=False)
|
||||
_mixed_operation_sampling_sanity_check(linear, {'shared': 3}, torch.randn(2, 3))
|
||||
|
||||
with pytest.raises(TypeError):
|
||||
linear = MutableLinear(nni.choice('shared', [3, 6, 9]), nni.choice('xx', [2, 4, 8]), bias=nni.choice('yy', [False, True]))
|
||||
_mixed_operation_sampling_sanity_check(linear, {'shared': 3}, torch.randn(2, 3))
|
||||
|
||||
linear = MutableLinear(nni.choice('in_features', [3, 6, 9]), nni.choice('out_features', [2, 4, 8]), bias=True)
|
||||
kwargs = {'in_features': 6, 'out_features': 4}
|
||||
out1, out2 = _mixed_operation_state_dict_sanity_check(linear, MutableLinear(**kwargs), kwargs, torch.randn(2, 6))
|
||||
assert torch.allclose(out1, out2)
|
||||
|
||||
|
||||
def test_mixed_conv2d():
|
||||
conv = MutableConv2d(nni.choice('in', [3, 6, 9]), nni.choice('out', [2, 4, 8]) * 2, 1)
|
||||
assert _mixed_operation_sampling_sanity_check(conv, {'in': 3, 'out': 4}, torch.randn(2, 3, 9, 9)).size(1) == 8
|
||||
_mixed_operation_differentiable_sanity_check(conv, torch.randn(2, 9, 3, 3))
|
||||
|
||||
# stride
|
||||
conv = MutableConv2d(nni.choice('in', [3, 6, 9]), nni.choice('out', [2, 4, 8]), 1, stride=nni.choice('stride', [1, 2]))
|
||||
assert _mixed_operation_sampling_sanity_check(conv, {'in': 3, 'stride': 2}, torch.randn(2, 3, 10, 10)).size(2) == 5
|
||||
assert _mixed_operation_sampling_sanity_check(conv, {'in': 3, 'stride': 1}, torch.randn(2, 3, 10, 10)).size(2) == 10
|
||||
with pytest.raises(ValueError, match='must not be mutable'):
|
||||
_mixed_operation_differentiable_sanity_check(conv, torch.randn(2, 9, 10, 10))
|
||||
|
||||
# groups, dw conv
|
||||
conv = MutableConv2d(nni.choice('in', [3, 6, 9]), nni.choice('in', [3, 6, 9]), 1, groups=nni.choice('in', [3, 6, 9]))
|
||||
assert _mixed_operation_sampling_sanity_check(conv, {'in': 6}, torch.randn(2, 6, 10, 10)).size() == torch.Size([2, 6, 10, 10])
|
||||
|
||||
# groups, invalid case
|
||||
conv = MutableConv2d(nni.choice('in', [9, 6, 3]), nni.choice('in', [9, 6, 3]), 1, groups=9)
|
||||
with pytest.raises(RuntimeError):
|
||||
assert _mixed_operation_sampling_sanity_check(conv, {'in': 6}, torch.randn(2, 6, 10, 10))
|
||||
|
||||
# groups, differentiable
|
||||
conv = MutableConv2d(nni.choice('in', [3, 6, 9]), nni.choice('out', [3, 6, 9]), 1, groups=nni.choice('in', [3, 6, 9]))
|
||||
_mixed_operation_differentiable_sanity_check(conv, torch.randn(2, 9, 3, 3))
|
||||
|
||||
conv = MutableConv2d(nni.choice('in', [3, 6, 9]), nni.choice('in', [3, 6, 9]), 1, groups=nni.choice('in', [3, 6, 9]))
|
||||
_mixed_operation_differentiable_sanity_check(conv, torch.randn(2, 9, 3, 3))
|
||||
|
||||
with pytest.raises(ValueError):
|
||||
conv = MutableConv2d(nni.choice('in', [3, 6, 9]), nni.choice('in', [3, 6, 9]), 1, groups=nni.choice('groups', [3, 9]))
|
||||
_mixed_operation_differentiable_sanity_check(conv, torch.randn(2, 9, 3, 3))
|
||||
|
||||
with pytest.raises(RuntimeError):
|
||||
conv = MutableConv2d(nni.choice('in', [3, 6, 9]), nni.choice('in', [3, 6, 9]), 1, groups=nni.choice('in', [3, 6, 9]) // 3)
|
||||
_mixed_operation_differentiable_sanity_check(conv, torch.randn(2, 10, 3, 3))
|
||||
|
||||
# make sure kernel is sliced correctly
|
||||
conv = MutableConv2d(1, 1, nni.choice('k', [1, 3]), bias=False)
|
||||
conv = MixedConv2d.mutate(conv, 'dummy', {}, {'mixed_op_sampling': MixedOpPathSamplingPolicy})
|
||||
with torch.no_grad():
|
||||
conv.weight.zero_()
|
||||
# only center is 1, must pick center to pass this test
|
||||
conv.weight[0, 0, 1, 1] = 1
|
||||
conv.resample({'k': 1})
|
||||
assert conv(torch.ones((1, 1, 3, 3))).sum().item() == 9
|
||||
|
||||
# only `in_channels`, `out_channels`, `kernel_size`, and `groups` influence state_dict
|
||||
conv = MutableConv2d(
|
||||
nni.choice('in_channels', [2, 4, 8]), nni.choice('out_channels', [6, 12, 24]),
|
||||
kernel_size=nni.choice('kernel_size', [3, 5, 7]), groups=nni.choice('groups', [1, 2])
|
||||
)
|
||||
kwargs = {
|
||||
'in_channels': 8, 'out_channels': 12,
|
||||
'kernel_size': 5, 'groups': 2
|
||||
}
|
||||
out1, out2 = _mixed_operation_state_dict_sanity_check(conv, MutableConv2d(**kwargs), kwargs, torch.randn(2, 8, 16, 16))
|
||||
assert torch.allclose(out1, out2)
|
||||
|
||||
def test_mixed_batchnorm2d():
|
||||
bn = MutableBatchNorm2d(nni.choice('dim', [32, 64]))
|
||||
|
||||
assert _mixed_operation_sampling_sanity_check(bn, {'dim': 32}, torch.randn(2, 32, 3, 3)).size(1) == 32
|
||||
assert _mixed_operation_sampling_sanity_check(bn, {'dim': 64}, torch.randn(2, 64, 3, 3)).size(1) == 64
|
||||
|
||||
_mixed_operation_differentiable_sanity_check(bn, torch.randn(2, 64, 3, 3))
|
||||
|
||||
bn = MutableBatchNorm2d(nni.choice('num_features', [32, 48, 64]))
|
||||
kwargs = {'num_features': 48}
|
||||
out1, out2 = _mixed_operation_state_dict_sanity_check(bn, MutableBatchNorm2d(**kwargs), kwargs, torch.randn(2, 48, 3, 3))
|
||||
assert torch.allclose(out1, out2)
|
||||
|
||||
def test_mixed_layernorm():
|
||||
ln = MutableLayerNorm(nni.choice('normalized_shape', [32, 64]), elementwise_affine=True)
|
||||
|
||||
assert _mixed_operation_sampling_sanity_check(ln, {'normalized_shape': 32}, torch.randn(2, 16, 32)).size(-1) == 32
|
||||
assert _mixed_operation_sampling_sanity_check(ln, {'normalized_shape': 64}, torch.randn(2, 16, 64)).size(-1) == 64
|
||||
|
||||
_mixed_operation_differentiable_sanity_check(ln, torch.randn(2, 16, 64))
|
||||
|
||||
import itertools
|
||||
ln = MutableLayerNorm(nni.choice('normalized_shape', list(itertools.product([16, 32, 64], [8, 16]))))
|
||||
|
||||
assert list(_mixed_operation_sampling_sanity_check(ln, {'normalized_shape': (16, 8)}, torch.randn(2, 16, 8)).shape[-2:]) == [16, 8]
|
||||
assert list(_mixed_operation_sampling_sanity_check(ln, {'normalized_shape': (64, 16)}, torch.randn(2, 64, 16)).shape[-2:]) == [64, 16]
|
||||
|
||||
_mixed_operation_differentiable_sanity_check(ln, torch.randn(2, 64, 16))
|
||||
|
||||
ln = MutableLayerNorm(nni.choice('normalized_shape', [32, 48, 64]))
|
||||
kwargs = {'normalized_shape': 48}
|
||||
out1, out2 = _mixed_operation_state_dict_sanity_check(ln, MutableLayerNorm(**kwargs), kwargs, torch.randn(2, 8, 48))
|
||||
assert torch.allclose(out1, out2)
|
||||
|
||||
def test_mixed_mhattn():
|
||||
mhattn = MutableMultiheadAttention(nni.choice('emb', [4, 8]), 4)
|
||||
|
||||
assert _mixed_operation_sampling_sanity_check(mhattn, {'emb': 4},
|
||||
torch.randn(7, 2, 4), torch.randn(7, 2, 4), torch.randn(7, 2, 4))[0].size(-1) == 4
|
||||
assert _mixed_operation_sampling_sanity_check(mhattn, {'emb': 8},
|
||||
torch.randn(7, 2, 8), torch.randn(7, 2, 8), torch.randn(7, 2, 8))[0].size(-1) == 8
|
||||
|
||||
_mixed_operation_differentiable_sanity_check(mhattn, torch.randn(7, 2, 8), torch.randn(7, 2, 8), torch.randn(7, 2, 8))
|
||||
|
||||
mhattn = MutableMultiheadAttention(nni.choice('emb', [4, 8]), nni.choice('heads', [2, 3, 4]))
|
||||
assert _mixed_operation_sampling_sanity_check(mhattn, {'emb': 4, 'heads': 2},
|
||||
torch.randn(7, 2, 4), torch.randn(7, 2, 4), torch.randn(7, 2, 4))[0].size(-1) == 4
|
||||
with pytest.raises(AssertionError, match='divisible'):
|
||||
assert _mixed_operation_sampling_sanity_check(mhattn, {'emb': 4, 'heads': 3},
|
||||
torch.randn(7, 2, 4), torch.randn(7, 2, 4), torch.randn(7, 2, 4))[0].size(-1) == 4
|
||||
|
||||
mhattn = MutableMultiheadAttention(nni.choice('emb', [4, 8]), 4, kdim=nni.choice('kdim', [5, 7]))
|
||||
assert _mixed_operation_sampling_sanity_check(mhattn, {'emb': 4, 'kdim': 7},
|
||||
torch.randn(7, 2, 4), torch.randn(7, 2, 7), torch.randn(7, 2, 4))[0].size(-1) == 4
|
||||
assert _mixed_operation_sampling_sanity_check(mhattn, {'emb': 8, 'kdim': 5},
|
||||
torch.randn(7, 2, 8), torch.randn(7, 2, 5), torch.randn(7, 2, 8))[0].size(-1) == 8
|
||||
|
||||
mhattn = MutableMultiheadAttention(nni.choice('emb', [4, 8]), 4, vdim=nni.choice('vdim', [5, 8]))
|
||||
assert _mixed_operation_sampling_sanity_check(mhattn, {'emb': 4, 'vdim': 8},
|
||||
torch.randn(7, 2, 4), torch.randn(7, 2, 4), torch.randn(7, 2, 8))[0].size(-1) == 4
|
||||
assert _mixed_operation_sampling_sanity_check(mhattn, {'emb': 8, 'vdim': 5},
|
||||
torch.randn(7, 2, 8), torch.randn(7, 2, 8), torch.randn(7, 2, 5))[0].size(-1) == 8
|
||||
|
||||
_mixed_operation_differentiable_sanity_check(mhattn, torch.randn(5, 3, 8), torch.randn(5, 3, 8), torch.randn(5, 3, 8))
|
||||
|
||||
mhattn = MutableMultiheadAttention(embed_dim=nni.choice('embed_dim', [4, 8, 16]), num_heads=nni.choice('num_heads', [1, 2, 4]),
|
||||
kdim=nni.choice('kdim', [4, 8, 16]), vdim=nni.choice('vdim', [4, 8, 16]))
|
||||
kwargs = {'embed_dim': 16, 'num_heads': 2, 'kdim': 4, 'vdim': 8}
|
||||
(out1, _), (out2, _) = _mixed_operation_state_dict_sanity_check(mhattn, MutableMultiheadAttention(**kwargs), kwargs, torch.randn(7, 2, 16), torch.randn(7, 2, 4), torch.randn(7, 2, 8))
|
||||
assert torch.allclose(out1, out2)
|
||||
|
||||
@pytest.mark.skipif(torch.__version__.startswith('1.7'), reason='batch_first is not supported for legacy PyTorch')
|
||||
def test_mixed_mhattn_batch_first():
|
||||
# batch_first is not supported for legacy pytorch versions
|
||||
# mark 1.7 because 1.7 is used on legacy pipeline
|
||||
|
||||
mhattn = MutableMultiheadAttention(nni.choice('emb', [4, 8]), 2, kdim=(nni.choice('kdim', [3, 7])), vdim=nni.choice('vdim', [5, 8]),
|
||||
bias=False, add_bias_kv=True, batch_first=True)
|
||||
assert _mixed_operation_sampling_sanity_check(mhattn, {'emb': 4, 'kdim': 7, 'vdim': 8},
|
||||
torch.randn(2, 7, 4), torch.randn(2, 7, 7), torch.randn(2, 7, 8))[0].size(-1) == 4
|
||||
assert _mixed_operation_sampling_sanity_check(mhattn, {'emb': 8, 'kdim': 3, 'vdim': 5},
|
||||
torch.randn(2, 7, 8), torch.randn(2, 7, 3), torch.randn(2, 7, 5))[0].size(-1) == 8
|
||||
|
||||
_mixed_operation_differentiable_sanity_check(mhattn, torch.randn(1, 7, 8), torch.randn(1, 7, 7), torch.randn(1, 7, 8))
|
||||
|
||||
|
||||
def test_pathsampling_layer_input():
|
||||
op = PathSamplingLayer({'a': MutableLinear(2, 3, bias=False), 'b': MutableLinear(2, 3, bias=True)}, label='ccc')
|
||||
with pytest.raises(RuntimeError, match='sample'):
|
||||
op(torch.randn(4, 2))
|
||||
|
||||
op.resample({})
|
||||
assert op(torch.randn(4, 2)).size(-1) == 3
|
||||
assert op.simplify()['ccc'].values == ['a', 'b']
|
||||
assert op.export({})['ccc'] in ['a', 'b']
|
||||
|
||||
input = PathSamplingInput(5, 2, 'concat', 'ddd')
|
||||
sample = input.resample({})
|
||||
assert 'ddd' in sample
|
||||
assert len(sample['ddd']) == 2
|
||||
assert input([torch.randn(4, 2) for _ in range(5)]).size(-1) == 4
|
||||
assert len(input.export({})['ddd']) == 2
|
||||
|
||||
|
||||
def test_differentiable_layer_input():
|
||||
op = DifferentiableMixedLayer({'a': MutableLinear(2, 3, bias=False), 'b': MutableLinear(2, 3, bias=True)}, nn.Parameter(torch.randn(2)), nn.Softmax(-1), 'eee')
|
||||
assert op(torch.randn(4, 2)).size(-1) == 3
|
||||
assert op.export({})['eee'] in ['a', 'b']
|
||||
probs = op.export_probs({})
|
||||
assert len(probs) == 1
|
||||
assert len(probs['eee']) == 2
|
||||
assert abs(probs['eee']['a'] + probs['eee']['b'] - 1) < 1e-4
|
||||
assert len(list(op.parameters())) == 4
|
||||
assert len(list(op.arch_parameters())) == 1
|
||||
|
||||
with pytest.raises(ValueError):
|
||||
op = DifferentiableMixedLayer({'a': MutableLinear(2, 3), 'b': MutableLinear(2, 4)}, nn.Parameter(torch.randn(2)), nn.Softmax(-1), 'eee')
|
||||
op(torch.randn(4, 2))
|
||||
|
||||
input = DifferentiableMixedInput(5, 2, nn.Parameter(torch.zeros(5)), GumbelSoftmax(-1), 'ddd')
|
||||
assert input([torch.randn(4, 2) for _ in range(5)]).size(-1) == 2
|
||||
assert len(input.export({})['ddd']) == 2
|
||||
assert len(input.export_probs({})) == 1
|
||||
assert len(input.export_probs({})['ddd']) == 5
|
||||
assert 3 in input.export_probs({})['ddd']
|
||||
|
||||
|
||||
def test_proxyless_layer_input():
|
||||
op = ProxylessMixedLayer({'a': MutableLinear(2, 3, bias=False), 'b': MutableLinear(2, 3, bias=True)}, nn.Parameter(torch.randn(2)),
|
||||
nn.Softmax(-1), 'eee')
|
||||
assert op.resample({})['eee'] in ['a', 'b']
|
||||
assert op(torch.randn(4, 2)).size(-1) == 3
|
||||
assert op.export({})['eee'] in ['a', 'b']
|
||||
assert len(list(op.parameters())) == 4
|
||||
assert len(list(op.arch_parameters())) == 1
|
||||
|
||||
input = ProxylessMixedInput(5, 2, nn.Parameter(torch.zeros(5)), GumbelSoftmax(-1), 'ddd')
|
||||
assert all(x in list(range(5)) for x in input.resample({})['ddd'])
|
||||
assert input([torch.randn(4, 2) for _ in range(5)]).size() == torch.Size([4, 2])
|
||||
exported = input.export({})['ddd']
|
||||
assert len(exported) == 2 and all(e in list(range(5)) for e in exported)
|
||||
|
||||
|
||||
def test_pathsampling_repeat():
|
||||
op = PathSamplingRepeat([MutableLinear(16, 16), MutableLinear(16, 8), MutableLinear(8, 4)], nni.choice('ccc', [1, 2, 3]))
|
||||
sample = op.resample({})
|
||||
assert sample['ccc'] in [1, 2, 3]
|
||||
for i in range(1, 4):
|
||||
op.resample({'ccc': i})
|
||||
out = op(torch.randn(2, 16))
|
||||
assert out.shape[1] == [16, 8, 4][i - 1]
|
||||
|
||||
op = PathSamplingRepeat([MutableLinear(i + 1, i + 2) for i in range(7)], 2 * nni.choice('ddd', [1, 2, 3]) + 1)
|
||||
sample = op.resample({})
|
||||
assert sample['ddd'] in [1, 2, 3]
|
||||
for i in range(1, 4):
|
||||
op.resample({'ddd': i})
|
||||
out = op(torch.randn(2, 1))
|
||||
assert out.shape[1] == (2 * i + 1) + 1
|
||||
|
||||
|
||||
def test_differentiable_repeat():
|
||||
op = DifferentiableMixedRepeat(
|
||||
[MutableLinear(8 if i == 0 else 16, 16) for i in range(4)],
|
||||
nni.choice('ccc', [0, 1]) * 2 + 1,
|
||||
GumbelSoftmax(-1),
|
||||
{'ccc': nn.Parameter(torch.randn(2))},
|
||||
)
|
||||
op.resample({})
|
||||
assert op(torch.randn(2, 8)).size() == torch.Size([2, 16])
|
||||
sample = op.export({})
|
||||
assert 'ccc' in sample and sample['ccc'] in [0, 1]
|
||||
assert sorted(op.export_probs({})['ccc'].keys()) == [0, 1]
|
||||
|
||||
class TupleModule(nn.Module):
|
||||
def __init__(self, num):
|
||||
super().__init__()
|
||||
self.num = num
|
||||
|
||||
def forward(self, *args, **kwargs):
|
||||
return torch.full((2, 3), self.num), torch.full((3, 5), self.num), {'a': 7, 'b': [self.num] * 11}
|
||||
|
||||
class CustomSoftmax(nn.Softmax):
|
||||
def forward(self, *args, **kwargs):
|
||||
return [0.3, 0.3, 0.4]
|
||||
|
||||
op = DifferentiableMixedRepeat(
|
||||
[TupleModule(i + 1) for i in range(4)],
|
||||
nni.choice('ccc', [1, 2, 4]),
|
||||
CustomSoftmax(),
|
||||
{'ccc': nn.Parameter(torch.randn(3))},
|
||||
)
|
||||
op.resample({})
|
||||
res = op(None)
|
||||
assert len(res) == 3
|
||||
assert res[0].shape == (2, 3) and res[0][0][0].item() == 2.5
|
||||
assert res[2]['a'] == 7
|
||||
assert len(res[2]['b']) == 11 and res[2]['b'][-1] == 2.5
|
||||
|
||||
|
||||
def test_pathsampling_cell():
|
||||
for cell_cls in [CellSimple, CellDefaultArgs, CellCustomProcessor, CellLooseEnd, CellOpFactory]:
|
||||
model = cell_cls()
|
||||
strategy = RandomOneShot()
|
||||
model = strategy.mutate_model(model)
|
||||
nas_modules = [m for m in model.modules() if isinstance(m, BaseSuperNetModule)]
|
||||
result = {}
|
||||
for module in nas_modules:
|
||||
result.update(module.resample(memo=result))
|
||||
assert len(result) == model.cell.num_nodes * model.cell.num_ops_per_node * 2
|
||||
result = {}
|
||||
for module in nas_modules:
|
||||
result.update(module.export(memo=result))
|
||||
assert len(result) == model.cell.num_nodes * model.cell.num_ops_per_node * 2
|
||||
|
||||
if cell_cls in [CellLooseEnd, CellOpFactory]:
|
||||
assert isinstance(model.cell, PathSamplingCell)
|
||||
else:
|
||||
assert not isinstance(model.cell, PathSamplingCell)
|
||||
|
||||
inputs = {
|
||||
CellSimple: (torch.randn(2, 16), torch.randn(2, 16)),
|
||||
CellDefaultArgs: (torch.randn(2, 16),),
|
||||
CellCustomProcessor: (torch.randn(2, 3), torch.randn(2, 16)),
|
||||
CellLooseEnd: (torch.randn(2, 16), torch.randn(2, 16)),
|
||||
CellOpFactory: (torch.randn(2, 3), torch.randn(2, 16)),
|
||||
}[cell_cls]
|
||||
|
||||
output = model(*inputs)
|
||||
if cell_cls == CellCustomProcessor:
|
||||
assert isinstance(output, tuple) and len(output) == 2 and \
|
||||
output[1].shape == torch.Size([2, 16 * model.cell.num_nodes])
|
||||
else:
|
||||
# no loose-end support for now
|
||||
assert output.shape == torch.Size([2, 16 * model.cell.num_nodes])
|
||||
|
||||
|
||||
def test_differentiable_cell():
|
||||
for cell_cls in [CellSimple, CellDefaultArgs, CellCustomProcessor, CellLooseEnd, CellOpFactory]:
|
||||
model = cell_cls()
|
||||
strategy = DARTS()
|
||||
model = strategy.mutate_model(model)
|
||||
nas_modules = [m for m in model.modules() if isinstance(m, BaseSuperNetModule)]
|
||||
result = {}
|
||||
for module in nas_modules:
|
||||
result.update(module.export(memo=result))
|
||||
assert len(result) == model.cell.num_nodes * model.cell.num_ops_per_node * 2
|
||||
for k, v in result.items():
|
||||
if 'input' in k:
|
||||
assert isinstance(v, list) and len(v) == 1
|
||||
|
||||
result_prob = {}
|
||||
for module in nas_modules:
|
||||
result_prob.update(module.export_probs(memo=result_prob))
|
||||
|
||||
ctrl_params = []
|
||||
for m in nas_modules:
|
||||
ctrl_params += list(m.arch_parameters())
|
||||
if cell_cls in [CellLooseEnd, CellOpFactory]:
|
||||
assert len(ctrl_params) == model.cell.num_nodes * (model.cell.num_nodes + 3) // 2
|
||||
assert len(result_prob) == len(ctrl_params) # len(op_names) == 2
|
||||
for v in result_prob.values():
|
||||
assert len(v) == 2
|
||||
assert isinstance(model.cell, DifferentiableMixedCell)
|
||||
else:
|
||||
assert not isinstance(model.cell, DifferentiableMixedCell)
|
||||
|
||||
inputs = {
|
||||
CellSimple: (torch.randn(2, 16), torch.randn(2, 16)),
|
||||
CellDefaultArgs: (torch.randn(2, 16),),
|
||||
CellCustomProcessor: (torch.randn(2, 3), torch.randn(2, 16)),
|
||||
CellLooseEnd: (torch.randn(2, 16), torch.randn(2, 16)),
|
||||
CellOpFactory: (torch.randn(2, 3), torch.randn(2, 16)),
|
||||
}[cell_cls]
|
||||
|
||||
output = model(*inputs)
|
||||
if cell_cls == CellCustomProcessor:
|
||||
assert isinstance(output, tuple) and len(output) == 2 and \
|
||||
output[1].shape == torch.Size([2, 16 * model.cell.num_nodes])
|
||||
else:
|
||||
# no loose-end support for now
|
||||
assert output.shape == torch.Size([2, 16 * model.cell.num_nodes])
|
||||
|
||||
|
||||
def test_memo_sharing():
|
||||
class TestModelSpace(ModelSpace):
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
self.linear1 = Cell(
|
||||
[nn.Linear(16, 16), nn.Linear(16, 16, bias=False)],
|
||||
num_nodes=3, num_ops_per_node=2, num_predecessors=2, merge_op='loose_end',
|
||||
label='cell'
|
||||
)
|
||||
self.linear2 = Cell(
|
||||
[nn.Linear(16, 16), nn.Linear(16, 16, bias=False)],
|
||||
num_nodes=3, num_ops_per_node=2, num_predecessors=2, merge_op='loose_end',
|
||||
label='cell'
|
||||
)
|
||||
|
||||
strategy = DARTS()
|
||||
model = strategy.mutate_model(TestModelSpace())
|
||||
assert model.linear1._arch_alpha['cell/2_0'] is model.linear2._arch_alpha['cell/2_0']
|
||||
|
||||
|
||||
def test_parameters():
|
||||
class Model(ModelSpace):
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
self.op = DifferentiableMixedLayer(
|
||||
{
|
||||
'a': MutableLinear(2, 3, bias=False),
|
||||
'b': MutableLinear(2, 3, bias=True)
|
||||
},
|
||||
nn.Parameter(torch.randn(2)), nn.Softmax(-1), 'abc'
|
||||
)
|
||||
|
||||
def forward(self, x):
|
||||
return self.op(x)
|
||||
|
||||
model = Model()
|
||||
assert len(list(model.parameters())) == 4
|
||||
assert len(list(model.op.arch_parameters())) == 1
|
||||
|
||||
optimizer = torch.optim.SGD(model.parameters(), 0.1)
|
||||
assert len(DartsLightningModule(model).arch_parameters()) == 1
|
||||
optimizer = DartsLightningModule(model).postprocess_weight_optimizers(optimizer)
|
||||
assert len(optimizer.param_groups[0]['params']) == 3
|