зеркало из https://github.com/microsoft/nni.git
One-shot documentation (more imprv.) (#4924)
This commit is contained in:
Родитель
2bc984412c
Коммит
993109bb3f
|
@ -72,7 +72,7 @@ For example,
|
|||
test_dataset = nni.trace(MNIST, root='data/mnist', train=False, download=True, transform=transform)
|
||||
|
||||
# pl.DataLoader and pl.Classification is already traced and supports serialization.
|
||||
evaluator = pl.Classification(train_dataloader=pl.DataLoader(train_dataset, batch_size=100),
|
||||
evaluator = pl.Classification(train_dataloaders=pl.DataLoader(train_dataset, batch_size=100),
|
||||
val_dataloaders=pl.DataLoader(test_dataset, batch_size=100),
|
||||
max_epochs=10)
|
||||
|
||||
|
@ -143,6 +143,6 @@ Then, users need to wrap everything (including LightningModule, trainer and data
|
|||
|
||||
lightning = pl.Lightning(AutoEncoder(),
|
||||
pl.Trainer(max_epochs=10),
|
||||
train_dataloader=pl.DataLoader(train_dataset, batch_size=100),
|
||||
train_dataloaders=pl.DataLoader(train_dataset, batch_size=100),
|
||||
val_dataloaders=pl.DataLoader(test_dataset, batch_size=100))
|
||||
experiment = RetiariiExperiment(base_model, lightning, mutators, strategy)
|
||||
|
|
|
@ -100,7 +100,7 @@ We have already implemented two trainers: :class:`nni.retiarii.evaluator.pytorch
|
|||
|
||||
from nni.retiarii.evaluator.pytorch.cgo.evaluator import Classification
|
||||
|
||||
trainer = Classification(train_dataloader=pl.DataLoader(train_dataset, batch_size=100),
|
||||
trainer = Classification(train_dataloaders=pl.DataLoader(train_dataset, batch_size=100),
|
||||
val_dataloaders=pl.DataLoader(test_dataset, batch_size=100),
|
||||
max_epochs=1, limit_train_batches=0.2)
|
||||
|
||||
|
|
|
@ -69,7 +69,7 @@ One-shot strategy
|
|||
|
||||
One-shot NAS algorithms leverage weight sharing among models in neural architecture search space to train a supernet, and use this supernet to guide the selection of better models. This type of algorihtms greatly reduces computational resource compared to independently training each model from scratch (which we call "Multi-trial NAS").
|
||||
|
||||
Starting from v2.8, the usage of one-shot strategies are much alike to multi-trial strategies. Users simply need to create a strategy and run :class:`~nni.retiarii.experiment.pytorch.RetiariiExperiment`. Since one-shot strategies will manipulate the training recipe, to use a one-shot strategy, the evaluator needs to be one of the :ref:`PyTorch-Lightning evaluators <lightning-evaluator>`, either built-in or customized. Example follows:
|
||||
Starting from v2.8, the usage of one-shot strategies are much alike to multi-trial strategies. Users simply need to create a strategy and run :class:`~nni.retiarii.experiment.pytorch.RetiariiExperiment`. Since one-shot strategies will manipulate the training recipe, to use a one-shot strategy, the evaluator needs to be one of the :ref:`PyTorch-Lightning evaluators <lightning-evaluator>`, either built-in or customized. Last but not least, don't forget to set execution engine to ``oneshot``. Example follows:
|
||||
|
||||
.. code-block:: python
|
||||
|
||||
|
@ -78,14 +78,22 @@ Starting from v2.8, the usage of one-shot strategies are much alike to multi-tri
|
|||
evaluator = pl.Classification(...)
|
||||
exploration_strategy = strategy.DARTS()
|
||||
|
||||
exp_config.execution_engine = 'oneshot'
|
||||
|
||||
One-shot strategies only support a limited set of :ref:`mutation-primitives`, and does not support :doc:`customizing mutators <mutator>` at all. See the :ref:`reference <one-shot-strategy-reference>` for the detailed support list of each algorithm.
|
||||
|
||||
*New in v2.8*: One-shot strategy is now compatible with `Lightning accelerators <https://pytorch-lightning.readthedocs.io/en/stable/accelerators/gpu.html>`__. It means that, you can accelerate one-shot strategies on hardwares like multiple GPUs. To enable this feature, you only need to pass the keyword arguments which used to be set in ``pytorch_lightning.Trainer``, to your evaluator. See :doc:`this reference </reference/nas/evaluator>` for more details.
|
||||
.. versionadded:: 2.8
|
||||
|
||||
One-shot strategy is now compatible with `Lightning accelerators <https://pytorch-lightning.readthedocs.io/en/stable/accelerators/gpu.html>`__. It means that, you can accelerate one-shot strategies on hardwares like multiple GPUs. To enable this feature, you only need to pass the keyword arguments which used to be set in ``pytorch_lightning.Trainer``, to your evaluator. See :doc:`this reference </reference/nas/evaluator>` for more details.
|
||||
|
||||
One-shot strategy (legacy)
|
||||
--------------------------
|
||||
|
||||
.. warning:: The following usages are deprecated and will be removed in future releases. If you intend to use them, the references can be found :doc:`here </deprecated/oneshot_legacy>`.
|
||||
.. warning::
|
||||
|
||||
.. deprecated:: 2.8
|
||||
|
||||
The following usages are deprecated and will be removed in future releases. If you intend to use them, the references can be found :doc:`here </deprecated/oneshot_legacy>`.
|
||||
|
||||
The usage of one-shot NAS strategy is a little different from multi-trial strategy. One-shot strategy is implemented with a special type of objects named *Trainer*. Following the common practice of one-shot NAS, *Trainer* trains the super-net and searches for the optimal architecture in a single run. For example,
|
||||
|
||||
|
|
|
@ -53,7 +53,7 @@ Below is an example, ``transforms.Compose``, ``transforms.Normalize``, and ``MNI
|
|||
|
||||
train_dataset = nni.trace(MNIST)(root='data/mnist', train=True, download=True, transform=transform)
|
||||
test_dataset = nni.trace(create_mnist_dataset)('data/mnist', transform=transform) # factory is also acceptable
|
||||
evaluator = pl.Classification(train_dataloader=pl.DataLoader(train_dataset, batch_size=100),
|
||||
evaluator = pl.Classification(train_dataloaders=pl.DataLoader(train_dataset, batch_size=100),
|
||||
val_dataloaders=pl.DataLoader(test_dataset, batch_size=100),
|
||||
max_epochs=10)
|
||||
|
||||
|
|
|
@ -60,6 +60,11 @@ dl.field-list > dd {
|
|||
margin-left: 1.5em;
|
||||
}
|
||||
|
||||
/* Version-related */
|
||||
span.versionmodified {
|
||||
font-weight: 600;
|
||||
}
|
||||
|
||||
/* show headerlink when hover/focus */
|
||||
dt.sig-object:focus .headerlink, dt.sig-object:hover .headerlink {
|
||||
-webkit-transform: translate(0);
|
||||
|
|
|
@ -36,7 +36,9 @@ class Repeat(Mutable):
|
|||
meaning that the block will be repeated at least ``min`` times and at most ``max`` times.
|
||||
If a ValueChoice, it should choose from a series of positive integers.
|
||||
|
||||
*New in v2.8*: Minimum depth can be 0. But this feature is NOT supported on graph engine.
|
||||
.. versionadded:: 2.8
|
||||
|
||||
Minimum depth can be 0. But this feature is NOT supported on graph engine.
|
||||
|
||||
Examples
|
||||
--------
|
||||
|
|
|
@ -31,9 +31,11 @@ class DartsLightningModule(BaseOneShotLightningModule):
|
|||
|
||||
The current implementation is for DARTS in first order. Second order (unrolled) is not supported yet.
|
||||
|
||||
*New in v2.8*: Supports searching for ValueChoices on operations, with the technique described in
|
||||
`FBNetV2: Differentiable Neural Architecture Search for Spatial and Channel Dimensions <https://arxiv.org/abs/2004.05565>`__.
|
||||
One difference is that, in DARTS, we are using Softmax instead of GumbelSoftmax.
|
||||
.. versionadded:: 2.8
|
||||
|
||||
Supports searching for ValueChoices on operations, with the technique described in
|
||||
`FBNetV2: Differentiable Neural Architecture Search for Spatial and Channel Dimensions <https://arxiv.org/abs/2004.05565>`__.
|
||||
One difference is that, in DARTS, we are using Softmax instead of GumbelSoftmax.
|
||||
|
||||
The supported mutation primitives of DARTS are:
|
||||
|
||||
|
@ -187,8 +189,10 @@ class GumbelDartsLightningModule(DartsLightningModule):
|
|||
Essentially, it samples one path on forward,
|
||||
and implements its own backward to update the architecture parameters based on only one path.
|
||||
|
||||
*New in v2.8*: Supports searching for ValueChoices on operations, with the technique described in
|
||||
`FBNetV2: Differentiable Neural Architecture Search for Spatial and Channel Dimensions <https://arxiv.org/abs/2004.05565>`__.
|
||||
.. versionadded:: 2.8
|
||||
|
||||
Supports searching for ValueChoices on operations, with the technique described in
|
||||
`FBNetV2: Differentiable Neural Architecture Search for Spatial and Channel Dimensions <https://arxiv.org/abs/2004.05565>`__.
|
||||
|
||||
The supported mutation primitives of GumbelDARTS are:
|
||||
|
||||
|
|
|
@ -38,6 +38,9 @@ __all__ = [
|
|||
'NATIVE_MIXED_OPERATIONS',
|
||||
]
|
||||
|
||||
_diff_not_compatible_error = 'To be compatible with differentiable one-shot strategy, {} in {} must not be ValueChoice.'
|
||||
|
||||
|
||||
class MixedOperationSamplingPolicy:
|
||||
"""
|
||||
Algo-related part for mixed Operation.
|
||||
|
@ -182,7 +185,7 @@ class MixedOperation(BaseSuperNetModule):
|
|||
mixed_op = cls(cast(dict, module.trace_kwargs))
|
||||
|
||||
if 'mixed_op_sampling' not in mutate_kwargs:
|
||||
raise ValueError('Need to sampling policy of mixed op, but not found in `mutate_kwargs`.')
|
||||
raise ValueError("Need a sampling policy for mixed op, but it's not found in `mutate_kwargs`.")
|
||||
policy_cls: Type[MixedOperationSamplingPolicy] = mutate_kwargs['mixed_op_sampling']
|
||||
# initialize policy class
|
||||
# this is put in mutate because we need to access memo
|
||||
|
@ -329,7 +332,7 @@ class MixedConv2d(MixedOperation, nn.Conv2d):
|
|||
inputs: torch.Tensor) -> torch.Tensor:
|
||||
|
||||
if any(isinstance(arg, dict) for arg in [stride, dilation, groups]):
|
||||
raise ValueError('stride, dilation, groups does not support weighted sampling.')
|
||||
raise ValueError(_diff_not_compatible_error.format('stride, dilation and groups', 'Conv2d'))
|
||||
|
||||
in_channels_ = _W(in_channels)
|
||||
out_channels_ = _W(out_channels)
|
||||
|
@ -394,7 +397,7 @@ class MixedBatchNorm2d(MixedOperation, nn.BatchNorm2d):
|
|||
inputs: torch.Tensor) -> torch.Tensor:
|
||||
|
||||
if any(isinstance(arg, dict) for arg in [eps, momentum]):
|
||||
raise ValueError('eps, momentum do not support weighted sampling')
|
||||
raise ValueError(_diff_not_compatible_error.format('eps and momentum', 'BatchNorm2d'))
|
||||
|
||||
if isinstance(num_features, dict):
|
||||
num_features = self.num_features
|
||||
|
@ -511,7 +514,7 @@ class MixedMultiHeadAttention(MixedOperation, nn.MultiheadAttention):
|
|||
) -> tuple[torch.Tensor, torch.Tensor | None]:
|
||||
|
||||
if any(isinstance(arg, dict) for arg in [num_heads, dropout]):
|
||||
raise ValueError('num_heads, dropout do not support weighted sampling.')
|
||||
raise ValueError(_diff_not_compatible_error.format('num_heads and dropout', 'MultiHeadAttention'))
|
||||
|
||||
# by default, kdim, vdim can be none
|
||||
if kdim is None:
|
||||
|
|
|
@ -140,6 +140,8 @@ def test_mixed_conv2d():
|
|||
conv = Conv2d(ValueChoice([3, 6, 9], label='in'), ValueChoice([2, 4, 8], label='out'), 1, stride=ValueChoice([1, 2], label='stride'))
|
||||
assert _mixed_operation_sampling_sanity_check(conv, {'in': 3, 'stride': 2}, torch.randn(2, 3, 10, 10)).size(2) == 5
|
||||
assert _mixed_operation_sampling_sanity_check(conv, {'in': 3, 'stride': 1}, torch.randn(2, 3, 10, 10)).size(2) == 10
|
||||
with pytest.raises(ValueError, match='must not be ValueChoice'):
|
||||
_mixed_operation_differentiable_sanity_check(conv, torch.randn(2, 9, 10, 10))
|
||||
|
||||
# groups, dw conv
|
||||
conv = Conv2d(ValueChoice([3, 6, 9], label='in'), ValueChoice([3, 6, 9], label='in'), 1, groups=ValueChoice([3, 6, 9], label='in'))
|
||||
|
|
Загрузка…
Ссылка в новой задаче