Add license header and typehints for NAS (#4774)

This commit is contained in:
Yuge Zhang 2022-04-25 15:33:04 +08:00 коммит произвёл GitHub
Родитель 8c2f717d83
Коммит 1896212902
Не найден ключ, соответствующий данной подписи
Идентификатор ключа GPG: 4AEE18F83AFDEB23
96 изменённых файлов: 858 добавлений и 500 удалений

2
dependencies/required.txt поставляемый
Просмотреть файл

@ -19,5 +19,5 @@ scikit-learn >= 0.24.1
scipy < 1.8 ; python_version < "3.8"
scipy ; python_version >= "3.8"
typeguard
typing_extensions >= 4.0.0 ; python_version < "3.8"
typing_extensions >= 4.0.0
websockets >= 10.1

Просмотреть файл

@ -13,7 +13,7 @@ import sys
import types
import warnings
from io import IOBase
from typing import Any, Dict, List, Optional, TypeVar, Union
from typing import Any, Dict, List, Optional, Type, TypeVar, Union, cast, Generic
import cloudpickle # use cloudpickle as backend for unserializable types and instances
import json_tricks # use json_tricks as serializer backend
@ -115,7 +115,7 @@ def is_wrapped_with_trace(cls_or_func: Any) -> bool:
)
class SerializableObject(Traceable):
class SerializableObject(Generic[T], Traceable):
"""
Serializable object is a wrapper of existing python objects, that supports dump and load easily.
Stores a symbol ``s`` and a dict of arguments ``args``, and the object can be restored with ``s(**args)``.
@ -147,7 +147,7 @@ class SerializableObject(Traceable):
# Reinitialize
return trace(self.trace_symbol)(*self.trace_args, **self.trace_kwargs)
return self
return cast(T, self)
@property
def trace_symbol(self) -> Any:
@ -187,7 +187,7 @@ class SerializableObject(Traceable):
')'
def inject_trace_info(obj: Any, symbol: T, args: List[Any], kwargs: Dict[str, Any]) -> Any:
def inject_trace_info(obj: Any, symbol: T, args: List[Any], kwargs: Dict[str, Any]) -> T:
# If an object is already created, this can be a fix so that the necessary info are re-injected into the object.
# Make obj complying with the interface of traceable, though we cannot change its base class.
obj.__dict__.update(_nni_symbol=symbol, _nni_args=args, _nni_kwargs=kwargs)
@ -233,11 +233,11 @@ def _make_class_traceable(cls: T, create_wrapper: bool = False) -> T:
else:
# sometimes create_wrapper is mandatory, e.g., for built-in types like list/int.
# but I don't want to check here because it's unreliable.
wrapper = type('wrapper', (Traceable, cls), attributes)
return wrapper
wrapper = type('wrapper', (Traceable, cast(Type, cls)), attributes)
return cast(T, wrapper)
def trace(cls_or_func: T = None, *, kw_only: bool = True, inheritable: bool = False) -> Union[T, Traceable]:
def trace(cls_or_func: T = cast(T, None), *, kw_only: bool = True, inheritable: bool = False) -> T:
"""
Annotate a function or a class if you want to preserve where it comes from.
This is usually used in the following scenarios:
@ -283,7 +283,7 @@ def trace(cls_or_func: T = None, *, kw_only: bool = True, inheritable: bool = Fa
# Might be changed in future.
nni_trace_flag = os.environ.get('NNI_TRACE_FLAG', '')
if nni_trace_flag.lower() == 'disable':
return cls_or_func
return cast(T, cls_or_func)
def wrap(cls_or_func):
# already annotated, do nothing
@ -301,20 +301,22 @@ def trace(cls_or_func: T = None, *, kw_only: bool = True, inheritable: bool = Fa
# if we're being called as @trace()
if cls_or_func is None:
return wrap
return wrap # type: ignore
# if we are called without parentheses
return wrap(cls_or_func)
return wrap(cls_or_func) # type: ignore
def dump(obj: Any, fp: Optional[Any] = None, *, use_trace: bool = True, pickle_size_limit: int = 4096,
allow_nan: bool = True, **json_tricks_kwargs) -> Union[str, bytes]:
allow_nan: bool = True, **json_tricks_kwargs) -> str:
"""
Convert a nested data structure to a json string. Save to file if fp is specified.
Use json-tricks as main backend. For unhandled cases in json-tricks, use cloudpickle.
The serializer is not designed for long-term storage use, but rather to copy data between processes.
The format is also subject to change between NNI releases.
To compress the payload, please use :func:`dump_bytes`.
Parameters
----------
obj : any
@ -334,6 +336,39 @@ def dump(obj: Any, fp: Optional[Any] = None, *, use_trace: bool = True, pickle_s
Normally str. Sometimes bytes (if compressed).
"""
if json_tricks_kwargs.get('compression') is not None:
raise ValueError('If you meant to compress the dumped payload, please use `dump_bytes`.')
result = _dump(
obj=obj,
fp=fp,
use_trace=use_trace,
pickle_size_limit=pickle_size_limit,
allow_nan=allow_nan,
**json_tricks_kwargs)
return cast(str, result)
def dump_bytes(obj: Any, fp: Optional[Any] = None, *, compression: int = cast(int, None),
use_trace: bool = True, pickle_size_limit: int = 4096,
allow_nan: bool = True, **json_tricks_kwargs) -> bytes:
"""
Same as :func:`dump`, but to comporess payload, with `compression <https://json-tricks.readthedocs.io/en/stable/#dump>`__.
"""
if compression is None:
raise ValueError('compression must be set.')
result = _dump(
obj=obj,
fp=fp,
compression=compression,
use_trace=use_trace,
pickle_size_limit=pickle_size_limit,
allow_nan=allow_nan,
**json_tricks_kwargs)
return cast(bytes, result)
def _dump(*, obj: Any, fp: Optional[Any], use_trace: bool, pickle_size_limit: int,
allow_nan: bool, **json_tricks_kwargs) -> Union[str, bytes]:
encoders = [
# we don't need to check for dependency as many of those have already been required by NNI
json_tricks.pathlib_encode, # pathlib is a required dependency for NNI
@ -456,7 +491,7 @@ def _trace_cls(base, kw_only, call_super=True, inheritable=False):
raise TypeError(f"{base} has a superclass already decorated with trace, and it's using a customized metaclass {type(base)}. "
"Please either use the default metaclass, or remove trace from the super-class.")
class wrapper(SerializableObject, base, metaclass=metaclass):
class wrapper(SerializableObject, base, metaclass=metaclass): # type: ignore
def __init__(self, *args, **kwargs):
# store a copy of initial parameters
args, kwargs = _formulate_arguments(base.__init__, args, kwargs, kw_only, is_class_init=True)
@ -528,7 +563,8 @@ def _trace_func(func, kw_only):
# and thus not possible to restore the trace parameters after dump and reload.
# this is a known limitation.
new_type = _make_class_traceable(type(res), True)
res = new_type(res) # re-creating the object
# re-creating the object
res = new_type(res) # type: ignore
res = inject_trace_info(res, func, args, kwargs)
else:
raise TypeError(f'Try to add trace info to {res}, but the type "{type(res)}" is unknown. '
@ -750,7 +786,7 @@ def import_cls_or_func_from_hybrid_name(s: str) -> Any:
return _import_cls_or_func_from_name(s)
def _json_tricks_func_or_cls_encode(cls_or_func: Any, primitives: bool = False, pickle_size_limit: int = 4096) -> str:
def _json_tricks_func_or_cls_encode(cls_or_func: Any, primitives: bool = False, pickle_size_limit: int = 4096) -> Dict[str, str]:
if not isinstance(cls_or_func, type) and not _is_function(cls_or_func):
# not a function or class, continue
return cls_or_func
@ -762,8 +798,7 @@ def _json_tricks_func_or_cls_encode(cls_or_func: Any, primitives: bool = False,
def _json_tricks_func_or_cls_decode(s: Dict[str, Any]) -> Any:
if isinstance(s, dict) and '__nni_type__' in s:
s = s['__nni_type__']
return import_cls_or_func_from_hybrid_name(s)
return import_cls_or_func_from_hybrid_name(s['__nni_type__'])
return s
@ -815,8 +850,7 @@ def _json_tricks_any_object_encode(obj: Any, primitives: bool = False, pickle_si
def _json_tricks_any_object_decode(obj: Dict[str, Any]) -> Any:
if isinstance(obj, dict) and '__nni_obj__' in obj:
obj = obj['__nni_obj__']
b = base64.b64decode(obj)
b = base64.b64decode(obj['__nni_obj__'])
return _wrapped_cloudpickle_loads(b)
return obj

Просмотреть файл

@ -1 +1,4 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.
from .utils import load_benchmark, download_benchmark

Просмотреть файл

@ -1,3 +1,6 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.
import os

Просмотреть файл

@ -1,3 +1,6 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.
import argparse
if __name__ == '__main__':

Просмотреть файл

@ -1,3 +1,6 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.
from .constants import INPUT, OUTPUT, CONV3X3_BN_RELU, CONV1X1_BN_RELU, MAXPOOL3X3
from .model import Nb101TrialStats, Nb101IntermediateStats, Nb101TrialConfig
from .query import query_nb101_trial_stats

Просмотреть файл

@ -1,3 +1,6 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.
INPUT = 'input'
OUTPUT = 'output'
CONV3X3_BN_RELU = 'conv3x3-bn-relu'

Просмотреть файл

@ -1,3 +1,6 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.
import argparse
from tqdm import tqdm

Просмотреть файл

@ -1,3 +1,6 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.
import hashlib
import numpy as np
@ -10,7 +13,7 @@ def _labeling_from_architecture(architecture, vertices):
def _adjancency_matrix_from_architecture(architecture, vertices):
matrix = np.zeros((vertices, vertices), dtype=np.bool)
matrix = np.zeros((vertices, vertices), dtype=np.bool) # type: ignore
for i in range(1, vertices):
for k in architecture['input{}'.format(i)]:
matrix[k, i] = 1

Просмотреть файл

@ -1,3 +1,6 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.
from peewee import CharField, FloatField, ForeignKeyField, IntegerField, Model, Proxy
from playhouse.sqlite_ext import JSONField

Просмотреть файл

@ -1,3 +1,6 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.
import functools
from peewee import fn

Просмотреть файл

@ -1,3 +1,6 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.
from .constants import NONE, SKIP_CONNECT, CONV_1X1, CONV_3X3, AVG_POOL_3X3
from .model import Nb201TrialStats, Nb201IntermediateStats, Nb201TrialConfig
from .query import query_nb201_trial_stats

Просмотреть файл

@ -1,3 +1,6 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.
NONE = 'none'
SKIP_CONNECT = 'skip_connect'
CONV_1X1 = 'conv_1x1'

Просмотреть файл

@ -1,3 +1,6 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.
import argparse
import re
@ -17,7 +20,7 @@ def parse_arch_str(arch_str):
'nor_conv_3x3': CONV_3X3,
'avg_pool_3x3': AVG_POOL_3X3
}
m = re.match(r'\|(.*)~0\|\+\|(.*)~0\|(.*)~1\|\+\|(.*)~0\|(.*)~1\|(.*)~2\|', arch_str)
m: re.Match = re.match(r'\|(.*)~0\|\+\|(.*)~0\|(.*)~1\|\+\|(.*)~0\|(.*)~1\|(.*)~2\|', arch_str) # type: ignore
return {
'0_1': mp[m.group(1)],
'0_2': mp[m.group(2)],

Просмотреть файл

@ -1,3 +1,6 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.
from peewee import CharField, FloatField, ForeignKeyField, IntegerField, Model, Proxy
from playhouse.sqlite_ext import JSONField

Просмотреть файл

@ -1,3 +1,6 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.
import functools
from peewee import fn

Просмотреть файл

@ -1,3 +1,6 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.
from .constants import *
from .model import NdsTrialConfig, NdsTrialStats, NdsIntermediateStats
from .query import query_nds_trial_stats

Просмотреть файл

@ -1,3 +1,6 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.
NONE = 'none'
SKIP_CONNECT = 'skip_connect'
AVG_POOL_3X3 = 'avg_pool_3x3'

Просмотреть файл

@ -1,3 +1,6 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.
import json
import argparse
import os

Просмотреть файл

@ -1,3 +1,6 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.
from peewee import CharField, FloatField, ForeignKeyField, IntegerField, Model, Proxy
from playhouse.sqlite_ext import JSONField

Просмотреть файл

@ -1,3 +1,6 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.
import functools
from peewee import fn

Просмотреть файл

@ -1,4 +1,5 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.
from .model import NlpTrialStats, NlpIntermediateStats, NlpTrialConfig
from .query import query_nlp_trial_stats

Просмотреть файл

@ -1,3 +1,6 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.
import json
import os
import argparse

Просмотреть файл

@ -1,3 +1,6 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.
import os
from peewee import CharField, FloatField, ForeignKeyField, IntegerField, Model

Просмотреть файл

@ -1,3 +1,6 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.
import functools
from peewee import fn

Просмотреть файл

@ -1,3 +1,6 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.
import functools
import hashlib
import json
@ -6,6 +9,7 @@ import os
import shutil
import tempfile
from pathlib import Path
from typing import Optional
import requests
import tqdm
@ -49,8 +53,9 @@ def load_or_download_file(local_path: str, download_url: str, download: bool = F
f = tempfile.NamedTemporaryFile(delete=False, dir=dst_dir)
r = requests.get(download_url, stream=True)
total_length = int(r.headers.get('content-length'))
with tqdm.tqdm(total=total_length, disable=not progress,
total_length: Optional[str] = r.headers.get('content-length')
assert total_length is not None, f'Content length is not found in the response of {download_url}'
with tqdm.tqdm(total=int(total_length), disable=not progress,
unit='B', unit_scale=True, unit_divisor=1024) as pbar:
for chunk in r.iter_content(8192):
f.write(chunk)

Просмотреть файл

@ -1,3 +1,6 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.
from .base_mutator import BaseMutator
from .base_trainer import BaseTrainer
from .fixed import apply_fixed_architecture

Просмотреть файл

@ -1 +1,4 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.
from .nasbench201 import NASBench201Cell

Просмотреть файл

@ -1,3 +1,6 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.
from collections import OrderedDict
import torch.nn as nn
from nni.nas.pytorch.mutables import LayerChoice

Просмотреть файл

@ -1,3 +1,6 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.
import torch
import torch.nn as nn

Просмотреть файл

@ -1,3 +1,6 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.
from .darts_cell import DartsCell
from .enas_cell import ENASMicroLayer
from .enas_cell import ENASMacroLayer

Просмотреть файл

@ -1 +1,4 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.
from .pytorch import model_to_pytorch_script

Просмотреть файл

@ -3,8 +3,9 @@
import logging
import re
from typing import Dict, List, Tuple, Any
from typing import Dict, List, Tuple, Any, cast
from nni.retiarii.operation import PyTorchOperation
from nni.retiarii.operation_def.torch_op_def import ToDevice
from nni.retiarii.utils import STATE_DICT_PY_MAPPING
from nni.common.device import Device, GPUDevice
@ -34,7 +35,7 @@ def _sorted_incoming_edges(node: Node) -> List[Edge]:
if all(edge.tail_slot is None for edge in edges):
return edges
if all(isinstance(edge.tail_slot, int) for edge in edges):
edges = sorted(edges, key=(lambda edge: edge.tail_slot))
edges = sorted(edges, key=(lambda edge: cast(int, edge.tail_slot)))
if [edge.tail_slot for edge in edges] == list(range(len(edges))):
return edges
raise IllegalGraphError(node.graph, 'Node {} has bad inputs'.format(node.name))
@ -98,7 +99,7 @@ def _format_variable_name(name: str, graph_name: str) -> str:
name = name.replace('/', '__')
# https://stackoverflow.com/questions/3303312/how-do-i-convert-a-string-to-a-valid-variable-name-in-python
name = re.sub('\W|^(?=\d)','_', name)
name = re.sub(r'\W|^(?=\d)','_', name)
if name.startswith('__') and (len(name) > 2 and name[2] != '_'):
# name can't start with double underscore
@ -130,7 +131,7 @@ def generate_cuda_mapping(placement: Dict[Node, Device]) -> Dict[Device, int]:
return cuda_remapped_id
def graph_to_pytorch_model(graph_name: str, graph: Graph, placement=None) -> str:
def graph_to_pytorch_model(graph_name: str, graph: Graph, placement=None) -> Tuple[set, str]:
nodes = graph.topo_sort()
# handle module node and function node differently
@ -144,11 +145,12 @@ def graph_to_pytorch_model(graph_name: str, graph: Graph, placement=None) -> str
for node in nodes:
if node.operation:
if placement and isinstance(node.operation, ToDevice):
cuda_remapped_id = cast(dict, cuda_remapped_id)
node.operation.override_device_repr("cuda:%d" % cuda_remapped_id[node.operation.device])
if node.operation.type == 'shared':
continue
pkg_name = node.operation.get_import_pkg()
pkg_name = cast(PyTorchOperation, node.operation).get_import_pkg()
if pkg_name is not None:
import_pkgs.add(pkg_name)
@ -157,6 +159,7 @@ def graph_to_pytorch_model(graph_name: str, graph: Graph, placement=None) -> str
if node_code is not None:
if placement and node in placement and len(node_code) > 0:
if isinstance(placement[node], GPUDevice):
assert cuda_remapped_id is not None
device_repr = "cuda:%d" % cuda_remapped_id[placement[node]]
else:
device_repr = placement[node].device_repr()

Просмотреть файл

@ -2,6 +2,7 @@
# Licensed under the MIT license.
# pylint: skip-file
# type: ignore
"""
FIXME

Просмотреть файл

@ -1 +1,4 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.
from .graph_gen import convert_to_graph

Просмотреть файл

@ -411,6 +411,7 @@ class GraphConverter:
edge.graph = ir_graph
if edge.head == method_ir_graph.input_node:
# this is a member method, 'self' is the first argument, thus +1
assert edge.head_slot is not None
_input = node.inputsAt(edge.head_slot + 1)
src_node, src_node_idx = self._add_edge_handle_source_node(_input, graph_inputs, ir_graph, output_remap, node_index)
edge.head = src_node
@ -745,6 +746,7 @@ class GraphConverterWithShape(GraphConverter):
if isinstance(submodule, LayerChoice):
full_name = get_full_name_by_scope_name(ir_model, name.split('.'), module_name)
lc_node = ir_model.get_node_by_name(full_name)
assert lc_node is not None, f'Cannot find a node with name {full_name}'
for cand_name in submodule.names:
cand = submodule[cand_name]
@ -761,12 +763,14 @@ class GraphConverterWithShape(GraphConverter):
return
graph_node = ir_model.get_node_by_name(graph.name)
assert graph_node is not None, f'Cannot find a node with name {graph.name}'
if not _without_shape_info(graph_node):
return
if is_layerchoice_node(graph_node):
cand_name = graph_node.operation.parameters['candidates'][0]
cand_node = ir_model.get_node_by_name(cand_name)
assert cand_node is not None, f'Cannot find a node with name {cand_name}'
if _without_shape_info(cand_node):
propagate_shape_for_graph(ir_model.graphs[cand_name])
graph_node.operation.attributes['input_shape'] = cand_node.operation.attributes['input_shape']

Просмотреть файл

@ -1,6 +1,10 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.
from typing import Optional
from typing_extensions import TypeGuard
from ..operation import Cell
from ..graph import Model, Graph, Node, Edge
@ -77,7 +81,7 @@ def _extract_info_from_trace_node(trace_node):
return shape_parameters, None
def is_layerchoice_node(ir_node: Node):
def is_layerchoice_node(ir_node: Optional[Node]) -> TypeGuard[Node]:
if ir_node is not None and isinstance(ir_node.operation, Cell) and ir_node.operation.parameters.get('mutation') == 'layerchoice':
return True
else:

Просмотреть файл

@ -1,3 +1,6 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.
# we will support tensorflow in future release
framework = 'pytorch'

Просмотреть файл

@ -1 +1,4 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.
from .lightning import *

Просмотреть файл

@ -2,7 +2,7 @@
# Licensed under the MIT license.
import warnings
from typing import Dict, List, Optional, Union
from typing import Dict, List, Optional, Union, Type
import torch.nn as nn
@ -18,7 +18,7 @@ from .trainer import Trainer
@nni.trace
class _MultiModelSupervisedLearningModule(LightningModule):
def __init__(self, criterion: nn.Module, metrics: Dict[str, torchmetrics.Metric],
def __init__(self, criterion: Type[nn.Module], metrics: Dict[str, torchmetrics.Metric],
n_models: int = 0,
learning_rate: float = 0.001,
weight_decay: float = 0.,
@ -171,7 +171,7 @@ class Classification(Lightning):
`Lightning documentation <https://pytorch-lightning.readthedocs.io/en/stable/common/trainer.html>`__ for details.
"""
def __init__(self, criterion: nn.Module = nn.CrossEntropyLoss,
def __init__(self, criterion: Type[nn.Module] = nn.CrossEntropyLoss,
learning_rate: float = 0.001,
weight_decay: float = 0.,
optimizer: optim.Optimizer = optim.Adam,
@ -184,7 +184,7 @@ class Classification(Lightning):
train_dataloader=train_dataloader, val_dataloaders=val_dataloaders)
class _RegressionModule(_MultiModelSupervisedLearningModule):
def __init__(self, criterion: nn.Module = nn.MSELoss,
def __init__(self, criterion: Type[nn.Module] = nn.MSELoss,
learning_rate: float = 0.001,
weight_decay: float = 0.,
optimizer: optim.Optimizer = optim.Adam):

Просмотреть файл

@ -4,10 +4,11 @@
import os
import warnings
from pathlib import Path
from typing import Dict, Union, Optional, List, Callable
from typing import Dict, Union, Optional, List, Callable, Type
import pytorch_lightning as pl
import torch.nn as nn
import torch.nn.functional as nn_functional
import torch.optim as optim
import torchmetrics
import torch.utils.data as torch_data
@ -124,12 +125,12 @@ class Lightning(Evaluator):
if other is None:
return False
if hasattr(self, "function") and hasattr(other, "function"):
eq_func = (self.function == other.function)
eq_func = getattr(self, "function") == getattr(other, "function")
elif not (hasattr(self, "function") or hasattr(other, "function")):
eq_func = True
if hasattr(self, "arguments") and hasattr(other, "arguments"):
eq_args = (self.arguments == other.arguments)
eq_args = getattr(self, "arguments") == getattr(other, "arguments")
elif not (hasattr(self, "arguments") or hasattr(other, "arguments")):
eq_args = True
@ -159,10 +160,13 @@ def _check_dataloader(dataloader):
### The following are some commonly used Lightning modules ###
class _SupervisedLearningModule(LightningModule):
def __init__(self, criterion: nn.Module, metrics: Dict[str, torchmetrics.Metric],
trainer: pl.Trainer
def __init__(self, criterion: Type[nn.Module], metrics: Dict[str, Type[torchmetrics.Metric]],
learning_rate: float = 0.001,
weight_decay: float = 0.,
optimizer: optim.Optimizer = optim.Adam,
optimizer: Type[optim.Optimizer] = optim.Adam,
export_onnx: Union[Path, str, bool, None] = None):
super().__init__()
self.save_hyperparameters('criterion', 'optimizer', 'learning_rate', 'weight_decay')
@ -214,7 +218,7 @@ class _SupervisedLearningModule(LightningModule):
self.log('test_' + name, metric(y_hat, y), prog_bar=True)
def configure_optimizers(self):
return self.optimizer(self.parameters(), lr=self.hparams.learning_rate, weight_decay=self.hparams.weight_decay)
return self.optimizer(self.parameters(), lr=self.hparams.learning_rate, weight_decay=self.hparams.weight_decay) # type: ignore
def on_validation_epoch_end(self):
nni.report_intermediate_result(self._get_validation_metrics())
@ -233,15 +237,15 @@ class _SupervisedLearningModule(LightningModule):
class _AccuracyWithLogits(torchmetrics.Accuracy):
def update(self, pred, target):
return super().update(nn.functional.softmax(pred), target)
return super().update(nn_functional.softmax(pred), target)
@nni.trace
class _ClassificationModule(_SupervisedLearningModule):
def __init__(self, criterion: nn.Module = nn.CrossEntropyLoss,
def __init__(self, criterion: Type[nn.Module] = nn.CrossEntropyLoss,
learning_rate: float = 0.001,
weight_decay: float = 0.,
optimizer: optim.Optimizer = optim.Adam,
optimizer: Type[optim.Optimizer] = optim.Adam,
export_onnx: bool = True):
super().__init__(criterion, {'acc': _AccuracyWithLogits},
learning_rate=learning_rate, weight_decay=weight_decay, optimizer=optimizer,
@ -275,10 +279,10 @@ class Classification(Lightning):
`Lightning documentation <https://pytorch-lightning.readthedocs.io/en/stable/common/trainer.html>`__ for details.
"""
def __init__(self, criterion: nn.Module = nn.CrossEntropyLoss,
def __init__(self, criterion: Type[nn.Module] = nn.CrossEntropyLoss,
learning_rate: float = 0.001,
weight_decay: float = 0.,
optimizer: optim.Optimizer = optim.Adam,
optimizer: Type[optim.Optimizer] = optim.Adam,
train_dataloader: Optional[DataLoader] = None,
val_dataloaders: Union[DataLoader, List[DataLoader], None] = None,
export_onnx: bool = True,
@ -291,10 +295,10 @@ class Classification(Lightning):
@nni.trace
class _RegressionModule(_SupervisedLearningModule):
def __init__(self, criterion: nn.Module = nn.MSELoss,
def __init__(self, criterion: Type[nn.Module] = nn.MSELoss,
learning_rate: float = 0.001,
weight_decay: float = 0.,
optimizer: optim.Optimizer = optim.Adam,
optimizer: Type[optim.Optimizer] = optim.Adam,
export_onnx: bool = True):
super().__init__(criterion, {'mse': torchmetrics.MeanSquaredError},
learning_rate=learning_rate, weight_decay=weight_decay, optimizer=optimizer,
@ -328,10 +332,10 @@ class Regression(Lightning):
`Lightning documentation <https://pytorch-lightning.readthedocs.io/en/stable/common/trainer.html>`__ for details.
"""
def __init__(self, criterion: nn.Module = nn.MSELoss,
def __init__(self, criterion: Type[nn.Module] = nn.MSELoss,
learning_rate: float = 0.001,
weight_decay: float = 0.,
optimizer: optim.Optimizer = optim.Adam,
optimizer: Type[optim.Optimizer] = optim.Adam,
train_dataloader: Optional[DataLoader] = None,
val_dataloaders: Union[DataLoader, List[DataLoader], None] = None,
export_onnx: bool = True,

Просмотреть файл

@ -1 +1,4 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.
from .api import *

Просмотреть файл

@ -129,6 +129,7 @@ class BaseExecutionEngine(AbstractExecutionEngine):
@classmethod
def pack_model_data(cls, model: Model) -> Any:
mutation_summary = get_mutation_summary(model)
assert model.evaluator is not None, 'Model evaluator can not be None'
return BaseGraphData(codegen.model_to_pytorch_script(model), model.evaluator, mutation_summary)
@classmethod

Просмотреть файл

@ -1,6 +1,9 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.
import os
import random
from typing import Dict, Any, List, Optional, Union, Tuple, Callable, Iterable
from typing import Dict, Any, List, Optional, Union, Tuple, Callable, Iterable, cast
from ..graph import Model
from ..integration_api import receive_trial_parameters
@ -39,6 +42,9 @@ class BenchmarkGraphData:
def load(data) -> 'BenchmarkGraphData':
return BenchmarkGraphData(data['mutation'], data['benchmark'], data['metric_name'], data['db_path'])
def __repr__(self) -> str:
return f"BenchmarkGraphData({self.mutation}, {self.benchmark}, {self.db_path})"
class BenchmarkExecutionEngine(BaseExecutionEngine):
"""
@ -67,6 +73,7 @@ class BenchmarkExecutionEngine(BaseExecutionEngine):
@classmethod
def trial_execute_graph(cls) -> None:
graph_data = BenchmarkGraphData.load(receive_trial_parameters())
assert graph_data.db_path is not None, f'Invalid graph data because db_path is None: {graph_data}'
os.environ['NASBENCHMARK_DIR'] = graph_data.db_path
final, intermediates = cls.query_in_benchmark(graph_data)
@ -89,7 +96,6 @@ class BenchmarkExecutionEngine(BaseExecutionEngine):
arch = t
if arch is None:
raise ValueError(f'Cannot identify architecture from mutation dict: {graph_data.mutation}')
print(arch)
return _convert_to_final_and_intermediates(
query_nb101_trial_stats(arch, 108, include_intermediates=True),
'valid_acc'
@ -146,4 +152,5 @@ def _convert_to_final_and_intermediates(benchmark_result: Iterable[Any], metric_
benchmark_result = random.choice(benchmark_result)
else:
benchmark_result = benchmark_result[0]
benchmark_result = cast(dict, benchmark_result)
return benchmark_result[metric_name], [i[metric_name] for i in benchmark_result['intermediates'] if i[metric_name] is not None]

Просмотреть файл

@ -1,7 +1,6 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.
import logging
import os
import random

Просмотреть файл

@ -1,3 +1,6 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.
from typing import Dict, Any, Type
import torch.nn as nn
@ -49,7 +52,8 @@ class PurePythonExecutionEngine(BaseExecutionEngine):
@classmethod
def pack_model_data(cls, model: Model) -> Any:
mutation = get_mutation_dict(model)
graph_data = PythonGraphData(model.python_class, model.python_init_params, mutation, model.evaluator)
assert model.evaluator is not None, 'Model evaluator is not available.'
graph_data = PythonGraphData(model.python_class, model.python_init_params or {}, mutation, model.evaluator)
return graph_data
@classmethod

Просмотреть файл

@ -1,3 +1,6 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.
from typing import Any, List
from ..graph import Model

Просмотреть файл

@ -11,7 +11,7 @@ from dataclasses import dataclass
from pathlib import Path
from subprocess import Popen
from threading import Thread
from typing import Any, List, Optional, Union
from typing import Any, List, Optional, Union, cast
import colorama
import psutil
@ -23,6 +23,7 @@ from nni.experiment import Experiment, launcher, management, rest
from nni.experiment.config import utils
from nni.experiment.config.base import ConfigBase
from nni.experiment.config.training_service import TrainingServiceConfig
from nni.experiment.config.training_services import RemoteConfig
from nni.experiment.pipe import Pipe
from nni.tools.nnictl.command_utils import kill_command
@ -222,6 +223,7 @@ class RetiariiExperiment(Experiment):
Examples
--------
Multi-trial NAS:
>>> base_model = Net()
>>> search_strategy = strategy.Random()
>>> model_evaluator = FunctionalEvaluator(evaluate_model)
@ -233,6 +235,7 @@ class RetiariiExperiment(Experiment):
>>> exp.run(exp_config, 8081)
One-shot NAS:
>>> base_model = Net()
>>> search_strategy = strategy.DARTS()
>>> evaluator = pl.Classification(train_dataloader=train_loader, val_dataloaders=valid_loader)
@ -242,15 +245,16 @@ class RetiariiExperiment(Experiment):
>>> exp.run(exp_config)
Export top models:
>>> for model_dict in exp.export_top_models(formatter='dict'):
... print(model_dict)
>>> with nni.retarii.fixed_arch(model_dict):
... final_model = Net()
"""
def __init__(self, base_model: nn.Module, evaluator: Union[BaseOneShotTrainer, Evaluator] = None,
applied_mutators: List[Mutator] = None, strategy: BaseStrategy = None,
trainer: BaseOneShotTrainer = None):
def __init__(self, base_model: nn.Module, evaluator: Union[BaseOneShotTrainer, Evaluator] = cast(Evaluator, None),
applied_mutators: List[Mutator] = cast(List[Mutator], None), strategy: BaseStrategy = cast(BaseStrategy, None),
trainer: BaseOneShotTrainer = cast(BaseOneShotTrainer, None)):
if trainer is not None:
warnings.warn('Usage of `trainer` in RetiariiExperiment is deprecated and will be removed soon. '
'Please consider specifying it as a positional argument, or use `evaluator`.', DeprecationWarning)
@ -260,21 +264,22 @@ class RetiariiExperiment(Experiment):
raise ValueError('Evaluator should not be none.')
# TODO: The current design of init interface of Retiarii experiment needs to be reviewed.
self.config: RetiariiExeConfig = None
self.config: RetiariiExeConfig = cast(RetiariiExeConfig, None)
self.port: Optional[int] = None
self.base_model = base_model
self.evaluator: Evaluator = evaluator
self.evaluator: Union[Evaluator, BaseOneShotTrainer] = evaluator
self.applied_mutators = applied_mutators
self.strategy = strategy
# FIXME: this is only a workaround
from nni.retiarii.oneshot.pytorch.strategy import OneShotStrategy
if not isinstance(strategy, OneShotStrategy):
self._dispatcher = RetiariiAdvisor()
self._dispatcher_thread: Optional[Thread] = None
self._proc: Optional[Popen] = None
self._pipe: Optional[Pipe] = None
else:
self._dispatcher = cast(RetiariiAdvisor, None)
self._dispatcher_thread: Optional[Thread] = None
self._proc: Optional[Popen] = None
self._pipe: Optional[Pipe] = None
self.url_prefix = None
@ -325,7 +330,7 @@ class RetiariiExperiment(Experiment):
assert self.config.training_service.platform == 'remote', \
"CGO execution engine currently only supports remote training service"
assert self.config.batch_waiting_time is not None
assert self.config.batch_waiting_time is not None and self.config.max_concurrency_cgo is not None
devices = self._construct_devices()
engine = CGOExecutionEngine(devices,
max_concurrency=self.config.max_concurrency_cgo,
@ -335,7 +340,10 @@ class RetiariiExperiment(Experiment):
engine = PurePythonExecutionEngine()
elif self.config.execution_engine == 'benchmark':
from ..execution.benchmark import BenchmarkExecutionEngine
assert self.config.benchmark is not None, '"benchmark" must be set when benchmark execution engine is used.'
engine = BenchmarkExecutionEngine(self.config.benchmark)
else:
raise ValueError(f'Unsupported engine type: {self.config.execution_engine}')
set_execution_engine(engine)
self.id = management.generate_experiment_id()
@ -377,9 +385,10 @@ class RetiariiExperiment(Experiment):
def _construct_devices(self):
devices = []
if hasattr(self.config.training_service, 'machine_list'):
for machine in self.config.training_service.machine_list:
for machine in cast(RemoteConfig, self.config.training_service).machine_list:
assert machine.gpu_indices is not None, \
'gpu_indices must be set in RemoteMachineConfig for CGO execution engine'
assert isinstance(machine.gpu_indices, list), 'gpu_indices must be a list'
for gpu_idx in machine.gpu_indices:
devices.append(GPUDevice(machine.host, gpu_idx))
return devices
@ -387,7 +396,7 @@ class RetiariiExperiment(Experiment):
def _create_dispatcher(self):
return self._dispatcher
def run(self, config: RetiariiExeConfig = None, port: int = 8080, debug: bool = False) -> str:
def run(self, config: Optional[RetiariiExeConfig] = None, port: int = 8080, debug: bool = False) -> None:
"""
Run the experiment.
This function will block until experiment finish or error.
@ -420,6 +429,7 @@ class RetiariiExperiment(Experiment):
This function will block until experiment finish or error.
Return `True` when experiment done; or return `False` when experiment failed.
"""
assert self._proc is not None
try:
while True:
time.sleep(10)
@ -437,6 +447,7 @@ class RetiariiExperiment(Experiment):
_logger.warning('KeyboardInterrupt detected')
finally:
self.stop()
raise RuntimeError('Check experiment status failed.')
def stop(self) -> None:
"""
@ -466,11 +477,11 @@ class RetiariiExperiment(Experiment):
if self._pipe is not None:
self._pipe.close()
self.id = None
self.port = None
self.id = cast(str, None)
self.port = cast(int, None)
self._proc = None
self._pipe = None
self._dispatcher = None
self._dispatcher = cast(RetiariiAdvisor, None)
self._dispatcher_thread = None
_logger.info('Experiment stopped')

Просмотреть файл

@ -1,3 +1,6 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.
import json
import logging
from pathlib import Path

Просмотреть файл

@ -5,10 +5,16 @@
Model representation.
"""
from __future__ import annotations
import abc
import json
from enum import Enum
from typing import (Any, Dict, Iterable, List, Optional, Tuple, Type, Union, overload)
from typing import (TYPE_CHECKING, Any, Callable, Dict, Iterable, List,
Optional, Set, Tuple, Type, Union, cast, overload)
if TYPE_CHECKING:
from .mutator import Mutator
from .operation import Cell, Operation, _IOPseudoOperation
from .utils import uid
@ -63,7 +69,7 @@ class Evaluator(abc.ABC):
pass
@abc.abstractmethod
def _execute(self, model_cls: type) -> Any:
def _execute(self, model_cls: Union[Callable[[], Any], Any]) -> Any:
pass
@abc.abstractmethod
@ -203,7 +209,7 @@ class Model:
matched_nodes.extend(nodes)
return matched_nodes
def get_node_by_name(self, node_name: str) -> 'Node':
def get_node_by_name(self, node_name: str) -> 'Node' | None:
"""
Traverse all the nodes to find the matched node with the given name.
"""
@ -217,7 +223,7 @@ class Model:
else:
return None
def get_node_by_python_name(self, python_name: str) -> 'Node':
def get_node_by_python_name(self, python_name: str) -> Optional['Node']:
"""
Traverse all the nodes to find the matched node with the given python_name.
"""
@ -297,7 +303,7 @@ class Graph:
The name of torch.nn.Module, should have one-to-one mapping with items in python model.
"""
def __init__(self, model: Model, graph_id: int, name: str = None, _internal: bool = False):
def __init__(self, model: Model, graph_id: int, name: str = cast(str, None), _internal: bool = False):
assert _internal, '`Graph()` is private'
self.model: Model = model
@ -338,9 +344,9 @@ class Graph:
@overload
def add_node(self, name: str, operation: Operation) -> 'Node': ...
@overload
def add_node(self, name: str, type_name: str, parameters: Dict[str, Any] = None) -> 'Node': ...
def add_node(self, name: str, type_name: str, parameters: Dict[str, Any] = cast(Dict[str, Any], None)) -> 'Node': ...
def add_node(self, name, operation_or_type, parameters=None):
def add_node(self, name, operation_or_type, parameters=None): # type: ignore
if isinstance(operation_or_type, Operation):
op = operation_or_type
else:
@ -350,9 +356,10 @@ class Graph:
@overload
def insert_node_on_edge(self, edge: 'Edge', name: str, operation: Operation) -> 'Node': ...
@overload
def insert_node_on_edge(self, edge: 'Edge', name: str, type_name: str, parameters: Dict[str, Any] = None) -> 'Node': ...
def insert_node_on_edge(self, edge: 'Edge', name: str, type_name: str,
parameters: Dict[str, Any] = cast(Dict[str, Any], None)) -> 'Node': ...
def insert_node_on_edge(self, edge, name, operation_or_type, parameters=None) -> 'Node':
def insert_node_on_edge(self, edge, name, operation_or_type, parameters=None) -> 'Node': # type: ignore
if isinstance(operation_or_type, Operation):
op = operation_or_type
else:
@ -405,7 +412,7 @@ class Graph:
def get_nodes_by_name(self, name: str) -> List['Node']:
return [node for node in self.hidden_nodes if node.name == name]
def get_nodes_by_python_name(self, python_name: str) -> Optional['Node']:
def get_nodes_by_python_name(self, python_name: str) -> List['Node']:
return [node for node in self.nodes if node.python_name == python_name]
def topo_sort(self) -> List['Node']:
@ -594,7 +601,7 @@ class Node:
return sorted(set(edge.tail for edge in self.outgoing_edges), key=(lambda node: node.id))
@property
def successor_slots(self) -> List[Tuple['Node', Union[int, None]]]:
def successor_slots(self) -> Set[Tuple['Node', Union[int, None]]]:
return set((edge.tail, edge.tail_slot) for edge in self.outgoing_edges)
@property
@ -610,19 +617,19 @@ class Node:
assert isinstance(self.operation, Cell)
return self.graph.model.graphs[self.operation.parameters['cell']]
def update_label(self, label: str) -> None:
def update_label(self, label: Optional[str]) -> None:
self.label = label
@overload
def update_operation(self, operation: Operation) -> None: ...
@overload
def update_operation(self, type_name: str, parameters: Dict[str, Any] = None) -> None: ...
def update_operation(self, type_name: str, parameters: Dict[str, Any] = cast(Dict[str, Any], None)) -> None: ...
def update_operation(self, operation_or_type, parameters=None):
def update_operation(self, operation_or_type, parameters=None): # type: ignore
if isinstance(operation_or_type, Operation):
self.operation = operation_or_type
else:
self.operation = Operation.new(operation_or_type, parameters)
self.operation = Operation.new(operation_or_type, cast(dict, parameters))
# mutation
def remove(self) -> None:
@ -663,7 +670,13 @@ class Node:
return node
def _dump(self) -> Any:
ret = {'operation': {'type': self.operation.type, 'parameters': self.operation.parameters, 'attributes': self.operation.attributes}}
ret: Dict[str, Any] = {
'operation': {
'type': self.operation.type,
'parameters': self.operation.parameters,
'attributes': self.operation.attributes
}
}
if isinstance(self.operation, Cell):
ret['operation']['cell_name'] = self.operation.cell_name
if self.label is not None:

Просмотреть файл

@ -1,7 +1,7 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.
from typing import Tuple, Optional, Callable
from typing import Tuple, Optional, Callable, cast
import nni.retiarii.nn.pytorch as nn
from nni.retiarii import model_wrapper
@ -75,10 +75,10 @@ class MobileNetV3Space(nn.Module):
bn_momentum: float = 0.1):
super().__init__()
self.widths = [
self.widths = cast(nn.ChoiceOf[int], [
nn.ValueChoice([make_divisible(base_width * mult, 8) for mult in width_multipliers], label=f'width_{i}')
for i, base_width in enumerate(base_widths)
]
])
self.expand_ratios = expand_ratios
blocks = [
@ -115,7 +115,7 @@ class MobileNetV3Space(nn.Module):
self.classifier = nn.Sequential(
nn.Dropout(dropout_rate),
nn.Linear(self.widths[7], num_labels),
nn.Linear(cast(int, self.widths[7]), num_labels),
)
reset_parameters(self, bn_momentum=bn_momentum, bn_eps=bn_eps)

Просмотреть файл

@ -3,6 +3,7 @@
import math
import torch
import torch.nn as nn
from nni.retiarii import model_wrapper
from nni.retiarii.nn.pytorch import NasBench101Cell
@ -11,7 +12,7 @@ from nni.retiarii.nn.pytorch import NasBench101Cell
__all__ = ['NasBench101']
def truncated_normal_(tensor, mean=0, std=1):
def truncated_normal_(tensor: torch.Tensor, mean: float = 0, std: float = 1):
# https://discuss.pytorch.org/t/implementing-truncated-normal-initializer/4778/15
size = tensor.shape
tmp = tensor.new_empty(size + (4,)).normal_()
@ -117,9 +118,3 @@ class NasBench101(nn.Module):
out = self.gap(out).view(bs, -1)
out = self.classifier(out)
return out
def reset_parameters(self):
for module in self.modules():
if isinstance(module, nn.BatchNorm2d):
module.eps = self.config.bn_eps
module.momentum = self.config.bn_momentum

Просмотреть файл

@ -1,6 +1,8 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.
from typing import Callable, Dict
import torch
import torch.nn as nn
@ -176,8 +178,10 @@ class NasBench201(nn.Module):
if reduction:
cell = ResNetBasicblock(C_prev, C_curr, 2)
else:
cell = NasBench201Cell({prim: lambda C_in, C_out: OPS_WITH_STRIDE[prim](C_in, C_out, 1) for prim in PRIMITIVES},
C_prev, C_curr, label='cell')
ops: Dict[str, Callable[[int, int], nn.Module]] = {
prim: lambda C_in, C_out: OPS_WITH_STRIDE[prim](C_in, C_out, 1) for prim in PRIMITIVES
}
cell = NasBench201Cell(ops, C_prev, C_curr, label='cell')
self.cells.append(cell)
C_prev = C_curr

Просмотреть файл

@ -8,7 +8,7 @@ It's called ``nasnet.py`` simply because NASNet is the first to propose such str
"""
from collections import OrderedDict
from typing import Tuple, List, Union, Iterable, Dict, Callable
from typing import Tuple, List, Union, Iterable, Dict, Callable, Optional, cast
try:
from typing import Literal
@ -250,14 +250,14 @@ class CellPreprocessor(nn.Module):
See :class:`CellBuilder` on how to calculate those channel numbers.
"""
def __init__(self, C_pprev: int, C_prev: int, C: int, last_cell_reduce: bool) -> None:
def __init__(self, C_pprev: nn.MaybeChoice[int], C_prev: nn.MaybeChoice[int], C: nn.MaybeChoice[int], last_cell_reduce: bool) -> None:
super().__init__()
if last_cell_reduce:
self.pre0 = FactorizedReduce(C_pprev, C)
self.pre0 = FactorizedReduce(cast(int, C_pprev), cast(int, C))
else:
self.pre0 = ReLUConvBN(C_pprev, C, 1, 1, 0)
self.pre1 = ReLUConvBN(C_prev, C, 1, 1, 0)
self.pre0 = ReLUConvBN(cast(int, C_pprev), cast(int, C), 1, 1, 0)
self.pre1 = ReLUConvBN(cast(int, C_prev), cast(int, C), 1, 1, 0)
def forward(self, cells):
assert len(cells) == 2
@ -283,15 +283,19 @@ class CellBuilder:
Note that the builder is ephemeral, it can only be called once for every index.
"""
def __init__(self, op_candidates: List[str], C_prev_in: int, C_in: int, C: int,
num_nodes: int, merge_op: Literal['all', 'loose_end'],
def __init__(self, op_candidates: List[str],
C_prev_in: nn.MaybeChoice[int],
C_in: nn.MaybeChoice[int],
C: nn.MaybeChoice[int],
num_nodes: int,
merge_op: Literal['all', 'loose_end'],
first_cell_reduce: bool, last_cell_reduce: bool):
self.C_prev_in = C_prev_in # This is the out channels of the cell before last cell.
self.C_in = C_in # This is the out channesl of last cell.
self.C = C # This is NOT C_out of this stage, instead, C_out = C * len(cell.output_node_indices)
self.op_candidates = op_candidates
self.num_nodes = num_nodes
self.merge_op = merge_op
self.merge_op: Literal['all', 'loose_end'] = merge_op
self.first_cell_reduce = first_cell_reduce
self.last_cell_reduce = last_cell_reduce
self._expect_idx = 0
@ -312,7 +316,7 @@ class CellBuilder:
# self.C_prev_in, self.C_in, self.last_cell_reduce are updated after each cell is built.
preprocessor = CellPreprocessor(self.C_prev_in, self.C_in, self.C, self.last_cell_reduce)
ops_factory: Dict[str, Callable[[int, int, int], nn.Module]] = {
ops_factory: Dict[str, Callable[[int, int, Optional[int]], nn.Module]] = {
op: # make final chosen ops named with their aliases
lambda node_index, op_index, input_index:
OPS[op](self.C, 2 if is_reduction_cell and (
@ -353,7 +357,7 @@ _INIT_PARAMETER_DOCS = """
class NDS(nn.Module):
"""
__doc__ = """
The unified version of NASNet search space.
We follow the implementation in
@ -378,8 +382,8 @@ class NDS(nn.Module):
op_candidates: List[str],
merge_op: Literal['all', 'loose_end'] = 'all',
num_nodes_per_cell: int = 4,
width: Union[Tuple[int], int] = 16,
num_cells: Union[Tuple[int], int] = 20,
width: Union[Tuple[int, ...], int] = 16,
num_cells: Union[Tuple[int, ...], int] = 20,
dataset: Literal['cifar', 'imagenet'] = 'imagenet',
auxiliary_loss: bool = False):
super().__init__()
@ -394,30 +398,31 @@ class NDS(nn.Module):
else:
C = width
self.num_cells: nn.MaybeChoice[int] = cast(int, num_cells)
if isinstance(num_cells, Iterable):
num_cells = nn.ValueChoice(list(num_cells), label='depth')
num_cells_per_stage = [i * num_cells // 3 - (i - 1) * num_cells // 3 for i in range(3)]
self.num_cells = nn.ValueChoice(list(num_cells), label='depth')
num_cells_per_stage = [i * self.num_cells // 3 - (i - 1) * self.num_cells // 3 for i in range(3)]
# auxiliary head is different for network targetted at different datasets
if dataset == 'imagenet':
self.stem0 = nn.Sequential(
nn.Conv2d(3, C // 2, kernel_size=3, stride=2, padding=1, bias=False),
nn.BatchNorm2d(C // 2),
nn.Conv2d(3, cast(int, C // 2), kernel_size=3, stride=2, padding=1, bias=False),
nn.BatchNorm2d(cast(int, C // 2)),
nn.ReLU(inplace=True),
nn.Conv2d(C // 2, C, 3, stride=2, padding=1, bias=False),
nn.Conv2d(cast(int, C // 2), cast(int, C), 3, stride=2, padding=1, bias=False),
nn.BatchNorm2d(C),
)
self.stem1 = nn.Sequential(
nn.ReLU(inplace=True),
nn.Conv2d(C, C, 3, stride=2, padding=1, bias=False),
nn.Conv2d(cast(int, C), cast(int, C), 3, stride=2, padding=1, bias=False),
nn.BatchNorm2d(C),
)
C_pprev = C_prev = C_curr = C
last_cell_reduce = True
elif dataset == 'cifar':
self.stem = nn.Sequential(
nn.Conv2d(3, 3 * C, 3, padding=1, bias=False),
nn.BatchNorm2d(3 * C)
nn.Conv2d(3, cast(int, 3 * C), 3, padding=1, bias=False),
nn.BatchNorm2d(cast(int, 3 * C))
)
C_pprev = C_prev = 3 * C
C_curr = C
@ -439,7 +444,7 @@ class NDS(nn.Module):
# C_pprev is output channel number of last second cell among all the cells already built.
if len(stage) > 1:
# Contains more than one cell
C_pprev = len(stage[-2].output_node_indices) * C_curr
C_pprev = len(cast(nn.Cell, stage[-2]).output_node_indices) * C_curr
else:
# Look up in the out channels of last stage.
C_pprev = C_prev
@ -447,7 +452,7 @@ class NDS(nn.Module):
# This was originally,
# C_prev = num_nodes_per_cell * C_curr.
# but due to loose end, it becomes,
C_prev = len(stage[-1].output_node_indices) * C_curr
C_prev = len(cast(nn.Cell, stage[-1]).output_node_indices) * C_curr
# Useful in aligning the pprev and prev cell.
last_cell_reduce = cell_builder.last_cell_reduce
@ -457,11 +462,11 @@ class NDS(nn.Module):
if auxiliary_loss:
assert isinstance(self.stages[2], nn.Sequential), 'Auxiliary loss can only be enabled in retrain mode.'
self.stages[2] = SequentialBreakdown(self.stages[2])
self.auxiliary_head = AuxiliaryHead(C_to_auxiliary, self.num_labels, dataset=self.dataset)
self.stages[2] = SequentialBreakdown(cast(nn.Sequential, self.stages[2]))
self.auxiliary_head = AuxiliaryHead(C_to_auxiliary, self.num_labels, dataset=self.dataset) # type: ignore
self.global_pooling = nn.AdaptiveAvgPool2d((1, 1))
self.classifier = nn.Linear(C_prev, self.num_labels)
self.classifier = nn.Linear(cast(int, C_prev), self.num_labels)
def forward(self, inputs):
if self.dataset == 'imagenet':
@ -483,7 +488,7 @@ class NDS(nn.Module):
out = self.global_pooling(s1)
logits = self.classifier(out.view(out.size(0), -1))
if self.training and self.auxiliary_loss:
return logits, logits_aux
return logits, logits_aux # type: ignore
else:
return logits
@ -524,8 +529,8 @@ class NASNet(NDS):
]
def __init__(self,
width: Union[Tuple[int], int] = (16, 24, 32),
num_cells: Union[Tuple[int], int] = (4, 8, 12, 16, 20),
width: Union[Tuple[int, ...], int] = (16, 24, 32),
num_cells: Union[Tuple[int, ...], int] = (4, 8, 12, 16, 20),
dataset: Literal['cifar', 'imagenet'] = 'cifar',
auxiliary_loss: bool = False):
super().__init__(self.NASNET_OPS,
@ -555,8 +560,8 @@ class ENAS(NDS):
]
def __init__(self,
width: Union[Tuple[int], int] = (16, 24, 32),
num_cells: Union[Tuple[int], int] = (4, 8, 12, 16, 20),
width: Union[Tuple[int, ...], int] = (16, 24, 32),
num_cells: Union[Tuple[int, ...], int] = (4, 8, 12, 16, 20),
dataset: Literal['cifar', 'imagenet'] = 'cifar',
auxiliary_loss: bool = False):
super().__init__(self.ENAS_OPS,
@ -590,8 +595,8 @@ class AmoebaNet(NDS):
]
def __init__(self,
width: Union[Tuple[int], int] = (16, 24, 32),
num_cells: Union[Tuple[int], int] = (4, 8, 12, 16, 20),
width: Union[Tuple[int, ...], int] = (16, 24, 32),
num_cells: Union[Tuple[int, ...], int] = (4, 8, 12, 16, 20),
dataset: Literal['cifar', 'imagenet'] = 'cifar',
auxiliary_loss: bool = False):
@ -626,8 +631,8 @@ class PNAS(NDS):
]
def __init__(self,
width: Union[Tuple[int], int] = (16, 24, 32),
num_cells: Union[Tuple[int], int] = (4, 8, 12, 16, 20),
width: Union[Tuple[int, ...], int] = (16, 24, 32),
num_cells: Union[Tuple[int, ...], int] = (4, 8, 12, 16, 20),
dataset: Literal['cifar', 'imagenet'] = 'cifar',
auxiliary_loss: bool = False):
super().__init__(self.PNAS_OPS,
@ -660,8 +665,8 @@ class DARTS(NDS):
]
def __init__(self,
width: Union[Tuple[int], int] = (16, 24, 32),
num_cells: Union[Tuple[int], int] = (4, 8, 12, 16, 20),
width: Union[Tuple[int, ...], int] = (16, 24, 32),
num_cells: Union[Tuple[int, ...], int] = (4, 8, 12, 16, 20),
dataset: Literal['cifar', 'imagenet'] = 'cifar',
auxiliary_loss: bool = False):
super().__init__(self.DARTS_OPS,

Просмотреть файл

@ -2,7 +2,7 @@
# Licensed under the MIT license.
import math
from typing import Optional, Callable, List, Tuple
from typing import Optional, Callable, List, Tuple, cast
import torch
import nni.retiarii.nn.pytorch as nn
@ -31,12 +31,12 @@ class ConvBNReLU(nn.Sequential):
def __init__(
self,
in_channels: int,
out_channels: int,
kernel_size: int = 3,
in_channels: nn.MaybeChoice[int],
out_channels: nn.MaybeChoice[int],
kernel_size: nn.MaybeChoice[int] = 3,
stride: int = 1,
groups: int = 1,
norm_layer: Optional[Callable[..., nn.Module]] = None,
groups: nn.MaybeChoice[int] = 1,
norm_layer: Optional[Callable[[int], nn.Module]] = None,
activation_layer: Optional[Callable[..., nn.Module]] = None,
dilation: int = 1,
) -> None:
@ -46,9 +46,17 @@ class ConvBNReLU(nn.Sequential):
if activation_layer is None:
activation_layer = nn.ReLU6
super().__init__(
nn.Conv2d(in_channels, out_channels, kernel_size, stride, padding, dilation=dilation, groups=groups,
bias=False),
norm_layer(out_channels),
nn.Conv2d(
cast(int, in_channels),
cast(int, out_channels),
cast(int, kernel_size),
stride,
cast(int, padding),
dilation=dilation,
groups=cast(int, groups),
bias=False
),
norm_layer(cast(int, out_channels)),
activation_layer(inplace=True)
)
self.out_channels = out_channels
@ -62,11 +70,11 @@ class SeparableConv(nn.Sequential):
def __init__(
self,
in_channels: int,
out_channels: int,
kernel_size: int = 3,
in_channels: nn.MaybeChoice[int],
out_channels: nn.MaybeChoice[int],
kernel_size: nn.MaybeChoice[int] = 3,
stride: int = 1,
norm_layer: Optional[Callable[..., nn.Module]] = None,
norm_layer: Optional[Callable[[int], nn.Module]] = None,
activation_layer: Optional[Callable[..., nn.Module]] = None,
) -> None:
super().__init__(
@ -101,13 +109,13 @@ class InvertedResidual(nn.Sequential):
def __init__(
self,
in_channels: int,
out_channels: int,
expand_ratio: int,
kernel_size: int = 3,
in_channels: nn.MaybeChoice[int],
out_channels: nn.MaybeChoice[int],
expand_ratio: nn.MaybeChoice[float],
kernel_size: nn.MaybeChoice[int] = 3,
stride: int = 1,
squeeze_and_excite: Optional[Callable[[int], nn.Module]] = None,
norm_layer: Optional[Callable[..., nn.Module]] = None,
squeeze_and_excite: Optional[Callable[[nn.MaybeChoice[int]], nn.Module]] = None,
norm_layer: Optional[Callable[[int], nn.Module]] = None,
activation_layer: Optional[Callable[..., nn.Module]] = None,
) -> None:
super().__init__()
@ -115,7 +123,7 @@ class InvertedResidual(nn.Sequential):
self.out_channels = out_channels
assert stride in [1, 2]
hidden_ch = nn.ValueChoice.to_int(round(in_channels * expand_ratio))
hidden_ch = nn.ValueChoice.to_int(round(cast(int, in_channels * expand_ratio)))
# FIXME: check whether this equal works
# Residual connection is added here stride = 1 and input channels and output channels are the same.
@ -215,7 +223,7 @@ class ProxylessNAS(nn.Module):
self.first_conv = ConvBNReLU(3, widths[0], stride=2, norm_layer=nn.BatchNorm2d)
blocks = [
blocks: List[nn.Module] = [
# first stage is fixed
SeparableConv(widths[0], widths[1], kernel_size=3, stride=1)
]

Просмотреть файл

@ -1,6 +1,8 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.
from typing import cast
import torch
import nni.retiarii.nn.pytorch as nn
from nni.retiarii import model_wrapper
@ -14,7 +16,7 @@ class ShuffleNetBlock(nn.Module):
When stride = 1, the block expects an input with ``2 * input channels``. Otherwise input channels.
"""
def __init__(self, in_channels: int, out_channels: int, mid_channels: int, *,
def __init__(self, in_channels: int, out_channels: int, mid_channels: nn.MaybeChoice[int], *,
kernel_size: int, stride: int, sequence: str = "pdp", affine: bool = True):
super().__init__()
assert stride in [1, 2]
@ -57,14 +59,15 @@ class ShuffleNetBlock(nn.Module):
def _decode_point_depth_conv(self, sequence):
result = []
first_depth = first_point = True
pc = c = self.channels
pc: int = self.channels
c: int = self.channels
for i, token in enumerate(sequence):
# compute output channels of this conv
if i + 1 == len(sequence):
assert token == "p", "Last conv must be point-wise conv."
c = self.oup_main
elif token == "p" and first_point:
c = self.mid_channels
c = cast(int, self.mid_channels)
if token == "d":
# depth-wise conv
if isinstance(pc, int) and isinstance(c, int):
@ -101,7 +104,7 @@ class ShuffleXceptionBlock(ShuffleNetBlock):
`Single Path One-shot <https://www.ecva.net/papers/eccv_2020/papers_ECCV/papers/123610528.pdf>`__.
"""
def __init__(self, in_channels: int, out_channels: int, mid_channels: int, *, stride: int, affine: bool = True):
def __init__(self, in_channels: int, out_channels: int, mid_channels: nn.MaybeChoice[int], *, stride: int, affine: bool = True):
super().__init__(in_channels, out_channels, mid_channels,
kernel_size=3, stride=stride, sequence="dpdpdp", affine=affine)
@ -154,7 +157,7 @@ class ShuffleNetSpace(nn.Module):
nn.ReLU(inplace=True),
)
self.features = []
feature_blocks = []
global_block_idx = 0
for stage_idx, num_repeat in enumerate(self.stage_repeats):
@ -175,15 +178,17 @@ class ShuffleNetSpace(nn.Module):
else:
mid_channels = int(base_mid_channels)
mid_channels = cast(nn.MaybeChoice[int], mid_channels)
choice_block = nn.LayerChoice([
ShuffleNetBlock(in_channels, out_channels, mid_channels=mid_channels, kernel_size=3, stride=stride, affine=affine),
ShuffleNetBlock(in_channels, out_channels, mid_channels=mid_channels, kernel_size=5, stride=stride, affine=affine),
ShuffleNetBlock(in_channels, out_channels, mid_channels=mid_channels, kernel_size=7, stride=stride, affine=affine),
ShuffleXceptionBlock(in_channels, out_channels, mid_channels=mid_channels, stride=stride, affine=affine)
], label=f'layer_{global_block_idx}')
self.features.append(choice_block)
feature_blocks.append(choice_block)
self.features = nn.Sequential(*self.features)
self.features = nn.Sequential(*feature_blocks)
# final layers
last_conv_channels = self.stage_out_channels[-1]
@ -226,13 +231,15 @@ class ShuffleNetSpace(nn.Module):
torch.nn.init.constant_(m.weight, 1)
if m.bias is not None:
torch.nn.init.constant_(m.bias, 0.0001)
torch.nn.init.constant_(m.running_mean, 0)
if m.running_mean is not None:
torch.nn.init.constant_(m.running_mean, 0)
elif isinstance(m, nn.BatchNorm1d):
if m.weight is not None:
torch.nn.init.constant_(m.weight, 1)
if m.bias is not None:
torch.nn.init.constant_(m.bias, 0.0001)
torch.nn.init.constant_(m.running_mean, 0)
if m.running_mean is not None:
torch.nn.init.constant_(m.running_mean, 0)
elif isinstance(m, nn.Linear):
torch.nn.init.normal_(m.weight, 0, 0.01)
if m.bias is not None:

Просмотреть файл

@ -0,0 +1,5 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.
# Useful type hints

Просмотреть файл

@ -3,7 +3,7 @@
import logging
import os
from typing import Any, Callable
from typing import Any, Callable, Optional
import nni
from nni.common.serializer import PayloadTooLarge
@ -53,11 +53,11 @@ class RetiariiAdvisor(MsgDispatcherBase):
register_advisor(self) # register the current advisor as the "global only" advisor
self.search_space = None
self.send_trial_callback: Callable[[dict], None] = None
self.request_trial_jobs_callback: Callable[[int], None] = None
self.trial_end_callback: Callable[[int, bool], None] = None
self.intermediate_metric_callback: Callable[[int, MetricData], None] = None
self.final_metric_callback: Callable[[int, MetricData], None] = None
self.send_trial_callback: Optional[Callable[[dict], None]] = None
self.request_trial_jobs_callback: Optional[Callable[[int], None]] = None
self.trial_end_callback: Optional[Callable[[int, bool], None]] = None
self.intermediate_metric_callback: Optional[Callable[[int, MetricData], None]] = None
self.final_metric_callback: Optional[Callable[[int, MetricData], None]] = None
self.parameters_count = 0
@ -158,19 +158,22 @@ class RetiariiAdvisor(MsgDispatcherBase):
def handle_trial_end(self, data):
_logger.debug('Trial end: %s', data)
self.trial_end_callback(nni.load(data['hyper_params'])['parameter_id'], # pylint: disable=not-callable
data['event'] == 'SUCCEEDED')
if self.trial_end_callback is not None:
self.trial_end_callback(nni.load(data['hyper_params'])['parameter_id'], # pylint: disable=not-callable
data['event'] == 'SUCCEEDED')
def handle_report_metric_data(self, data):
_logger.debug('Metric reported: %s', data)
if data['type'] == MetricType.REQUEST_PARAMETER:
raise ValueError('Request parameter not supported')
elif data['type'] == MetricType.PERIODICAL:
self.intermediate_metric_callback(data['parameter_id'], # pylint: disable=not-callable
self._process_value(data['value']))
if self.intermediate_metric_callback is not None:
self.intermediate_metric_callback(data['parameter_id'], # pylint: disable=not-callable
self._process_value(data['value']))
elif data['type'] == MetricType.FINAL:
self.final_metric_callback(data['parameter_id'], # pylint: disable=not-callable
self._process_value(data['value']))
if self.final_metric_callback is not None:
self.final_metric_callback(data['parameter_id'], # pylint: disable=not-callable
self._process_value(data['value']))
@staticmethod
def _process_value(value) -> Any: # hopefully a float

Просмотреть файл

@ -1,7 +1,8 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.
from typing import (Any, Iterable, List, Optional, Tuple)
import warnings
from typing import (Any, Iterable, List, Optional, Tuple, cast)
from .graph import Model, Mutation, ModelStatus
@ -44,9 +45,11 @@ class Mutator:
If mutator has a label, in most cases, it means that this mutator is applied to nodes with this label.
"""
def __init__(self, sampler: Optional[Sampler] = None, label: Optional[str] = None):
def __init__(self, sampler: Optional[Sampler] = None, label: str = cast(str, None)):
self.sampler: Optional[Sampler] = sampler
self.label: Optional[str] = label
if label is None:
warnings.warn('Each mutator should have an explicit label. Mutator without label is deprecated.', DeprecationWarning)
self.label: str = label
self._cur_model: Optional[Model] = None
self._cur_choice_idx: Optional[int] = None

Просмотреть файл

@ -1,23 +1,40 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.
import math
import itertools
import math
import operator
import warnings
from typing import Any, List, Union, Dict, Optional, Callable, Iterable, NoReturn, TypeVar, Sequence
from typing import (Any, Callable, Dict, Generic, Iterable, Iterator, List,
NoReturn, Optional, Sequence, SupportsRound, TypeVar,
Union, cast)
import torch
import torch.nn as nn
from nni.common.hpo_utils import ParameterSpec
from nni.common.serializer import Translatable
from nni.retiarii.serializer import basic_unit
from nni.retiarii.utils import STATE_DICT_PY_MAPPING_PARTIAL, ModelNamespace, NoContextError
from nni.retiarii.utils import (STATE_DICT_PY_MAPPING_PARTIAL, ModelNamespace,
NoContextError)
from .mutation_utils import Mutable, generate_new_label, get_fixed_value
__all__ = [
# APIs
'LayerChoice',
'InputChoice',
'ValueChoice',
'ModelParameterChoice',
'Placeholder',
__all__ = ['LayerChoice', 'InputChoice', 'ValueChoice', 'ModelParameterChoice', 'Placeholder', 'ChosenInputs']
# Fixed module
'ChosenInputs',
# Type utils
'ReductionType',
'MaybeChoice',
'ChoiceOf',
]
class LayerChoice(Mutable):
@ -130,26 +147,16 @@ class LayerChoice(Mutable):
self.names.append(str(i))
else:
raise TypeError("Unsupported candidates type: {}".format(type(candidates)))
self._first_module = self._modules[self.names[0]] # to make the dummy forward meaningful
@property
def key(self):
return self._key()
@torch.jit.ignore
def _key(self):
warnings.warn('Using key to access the identifier of LayerChoice is deprecated. Please use label instead.',
category=DeprecationWarning)
return self._label
self._first_module = cast(nn.Module, self._modules[self.names[0]]) # to make the dummy forward meaningful
@property
def label(self):
return self._label
def __getitem__(self, idx):
def __getitem__(self, idx: Union[int, str]) -> nn.Module:
if isinstance(idx, str):
return self._modules[idx]
return list(self)[idx]
return cast(nn.Module, self._modules[idx])
return cast(nn.Module, list(self)[idx])
def __setitem__(self, idx, module):
key = idx if isinstance(idx, str) else self.names[idx]
@ -173,15 +180,6 @@ class LayerChoice(Mutable):
def __iter__(self):
return map(lambda name: self._modules[name], self.names)
@property
def choices(self):
return self._choices()
@torch.jit.ignore
def _choices(self):
warnings.warn("layer_choice.choices is deprecated. Use `list(layer_choice)` instead.", category=DeprecationWarning)
return list(self)
def forward(self, x):
"""
The forward of layer choice is simply running the first candidate module.
@ -266,16 +264,6 @@ class InputChoice(Mutable):
assert self.reduction in ['mean', 'concat', 'sum', 'none']
self._label = generate_new_label(label)
@property
def key(self):
return self._key()
@torch.jit.ignore
def _key(self):
warnings.warn('Using key to access the identifier of InputChoice is deprecated. Please use label instead.',
category=DeprecationWarning)
return self._label
@property
def label(self):
return self._label
@ -350,7 +338,7 @@ def _valuechoice_codegen(*, _internal: bool = False):
'truediv': '//', 'floordiv': '/', 'mod': '%',
'lshift': '<<', 'rshift': '>>',
'and': '&', 'xor': '^', 'or': '|',
# no reflection
# no reverse
'lt': '<', 'le': '<=', 'eq': '==',
'ne': '!=', 'ge': '>=', 'gt': '>',
# NOTE
@ -358,14 +346,14 @@ def _valuechoice_codegen(*, _internal: bool = False):
# Might support them in future when we actually need them.
}
binary_template = """ def __{op}__(self, other: Any) -> 'ValueChoiceX':
binary_template = """ def __{op}__(self: 'ChoiceOf[Any]', other: 'MaybeChoice[Any]') -> 'ChoiceOf[Any]':
return ValueChoiceX(operator.{opt}, '{{}} {sym} {{}}', [self, other])"""
binary_r_template = """ def __r{op}__(self, other: Any) -> 'ValueChoiceX':
binary_r_template = """ def __r{op}__(self: 'ChoiceOf[Any]', other: 'MaybeChoice[Any]') -> 'ChoiceOf[Any]':
return ValueChoiceX(operator.{opt}, '{{}} {sym} {{}}', [other, self])"""
unary_template = """ def __{op}__(self) -> 'ValueChoiceX':
return ValueChoiceX(operator.{op}, '{sym}{{}}', [self])"""
unary_template = """ def __{op}__(self: 'ChoiceOf[_value]') -> 'ChoiceOf[_value]':
return cast(ChoiceOf[_value], ValueChoiceX(operator.{op}, '{sym}{{}}', [self]))"""
for op, sym in MAPPING.items():
if op in ['neg', 'pos', 'invert']:
@ -377,8 +365,14 @@ def _valuechoice_codegen(*, _internal: bool = False):
print(binary_r_template.format(op=op, opt=opt, sym=sym) + '\n')
def _valuechoice_staticmethod_helper(orig_func):
orig_func.__doc__ += """
_func = TypeVar('_func')
_cand = TypeVar('_cand')
_value = TypeVar('_value')
def _valuechoice_staticmethod_helper(orig_func: _func) -> _func:
if orig_func.__doc__ is not None:
orig_func.__doc__ += """
Notes
-----
This function performs lazy evaluation.
@ -388,7 +382,7 @@ def _valuechoice_staticmethod_helper(orig_func):
return orig_func
class ValueChoiceX(Translatable, nn.Module):
class ValueChoiceX(Generic[_cand], Translatable, nn.Module):
"""Internal API. Implementation note:
The transformed (X) version of value choice.
@ -408,7 +402,10 @@ class ValueChoiceX(Translatable, nn.Module):
This class is implemented as a ``nn.Module`` so that it can be scanned by python engine / torchscript.
"""
def __init__(self, function: Callable[..., Any], repr_template: str, arguments: List[Any], dry_run: bool = True):
def __init__(self, function: Callable[..., _cand] = cast(Callable[..., _cand], None),
repr_template: str = cast(str, None),
arguments: List[Any] = cast('List[MaybeChoice[_cand]]', None),
dry_run: bool = True):
super().__init__()
if function is None:
@ -431,7 +428,7 @@ class ValueChoiceX(Translatable, nn.Module):
def inner_choices(self) -> Iterable['ValueChoice']:
"""
Return an iterable of all leaf value choices.
Return a generator of all leaf value choices.
Useful for composition of value choices.
No deduplication on labels. Mutators should take care.
"""
@ -439,18 +436,18 @@ class ValueChoiceX(Translatable, nn.Module):
if isinstance(arg, ValueChoiceX):
yield from arg.inner_choices()
def dry_run(self) -> Any:
def dry_run(self) -> _cand:
"""
Dry run the value choice to get one of its possible evaluation results.
"""
# values are not used
return self._evaluate(iter([]), True)
def all_options(self) -> Iterable[Any]:
def all_options(self) -> Iterable[_cand]:
"""Explore all possibilities of a value choice.
"""
# Record all inner choices: label -> candidates, no duplicates.
dedup_inner_choices: Dict[str, List[Any]] = {}
dedup_inner_choices: Dict[str, List[_cand]] = {}
# All labels of leaf nodes on tree, possibly duplicates.
all_labels: List[str] = []
@ -470,14 +467,14 @@ class ValueChoiceX(Translatable, nn.Module):
chosen = dict(zip(dedup_labels, chosen))
yield self.evaluate([chosen[label] for label in all_labels])
def evaluate(self, values: Iterable[Any]) -> Any:
def evaluate(self, values: Iterable[_cand]) -> _cand:
"""
Evaluate the result of this group.
``values`` should in the same order of ``inner_choices()``.
"""
return self._evaluate(iter(values), False)
def _evaluate(self, values: Iterable[Any], dry_run: bool = False) -> Any:
def _evaluate(self, values: Iterator[_cand], dry_run: bool = False) -> _cand:
# "values" iterates in the recursion
eval_args = []
for arg in self.arguments:
@ -497,7 +494,7 @@ class ValueChoiceX(Translatable, nn.Module):
"""
return self.dry_run()
def __repr__(self):
def __repr__(self) -> str:
reprs = []
for arg in self.arguments:
if isinstance(arg, ValueChoiceX) and not isinstance(arg, ValueChoice):
@ -513,7 +510,7 @@ class ValueChoiceX(Translatable, nn.Module):
# Special operators that can be useful in place of built-in conditional operators.
@staticmethod
@_valuechoice_staticmethod_helper
def to_int(obj: 'ValueChoiceOrAny') -> Union['ValueChoiceX', int]:
def to_int(obj: 'MaybeChoice[Any]') -> 'MaybeChoice[int]':
"""
Convert a ``ValueChoice`` to an integer.
"""
@ -523,7 +520,7 @@ class ValueChoiceX(Translatable, nn.Module):
@staticmethod
@_valuechoice_staticmethod_helper
def to_float(obj: 'ValueChoiceOrAny') -> Union['ValueChoiceX', float]:
def to_float(obj: 'MaybeChoice[Any]') -> 'MaybeChoice[float]':
"""
Convert a ``ValueChoice`` to a float.
"""
@ -533,9 +530,9 @@ class ValueChoiceX(Translatable, nn.Module):
@staticmethod
@_valuechoice_staticmethod_helper
def condition(pred: 'ValueChoiceOrAny',
true: 'ValueChoiceOrAny',
false: 'ValueChoiceOrAny') -> 'ValueChoiceOrAny':
def condition(pred: 'MaybeChoice[bool]',
true: 'MaybeChoice[_value]',
false: 'MaybeChoice[_value]') -> 'MaybeChoice[_value]':
"""
Return ``true`` if the predicate ``pred`` is true else ``false``.
@ -549,35 +546,39 @@ class ValueChoiceX(Translatable, nn.Module):
@staticmethod
@_valuechoice_staticmethod_helper
def max(arg0: Union[Iterable['ValueChoiceOrAny'], 'ValueChoiceOrAny'],
*args: List['ValueChoiceOrAny']) -> 'ValueChoiceOrAny':
def max(arg0: Union[Iterable['MaybeChoice[_value]'], 'MaybeChoice[_value]'],
*args: 'MaybeChoice[_value]') -> 'MaybeChoice[_value]':
"""
Returns the maximum value from a list of value choices.
The usage should be similar to Python's built-in value choices,
where the parameters could be an iterable, or at least two arguments.
"""
if not args:
return ValueChoiceX.max(*list(arg0))
lst = [arg0] + list(args)
if not isinstance(arg0, Iterable):
raise TypeError('Expect more than one items to compare max')
return cast(MaybeChoice[_value], ValueChoiceX.max(*list(arg0)))
lst = list(arg0) if isinstance(arg0, Iterable) else [arg0] + list(args)
if any(isinstance(obj, ValueChoiceX) for obj in lst):
return ValueChoiceX(max, 'max({})', lst)
return max(lst)
return max(cast(Any, lst))
@staticmethod
@_valuechoice_staticmethod_helper
def min(arg0: Union[Iterable['ValueChoiceOrAny'], 'ValueChoiceOrAny'],
*args: List['ValueChoiceOrAny']) -> 'ValueChoiceOrAny':
def min(arg0: Union[Iterable['MaybeChoice[_value]'], 'MaybeChoice[_value]'],
*args: 'MaybeChoice[_value]') -> 'MaybeChoice[_value]':
"""
Returns the minunum value from a list of value choices.
The usage should be similar to Python's built-in value choices,
where the parameters could be an iterable, or at least two arguments.
"""
if not args:
return ValueChoiceX.min(*list(arg0))
lst = [arg0] + list(args)
if not isinstance(arg0, Iterable):
raise TypeError('Expect more than one items to compare min')
return cast(MaybeChoice[_value], ValueChoiceX.min(*list(arg0)))
lst = list(arg0) if isinstance(arg0, Iterable) else [arg0] + list(args)
if any(isinstance(obj, ValueChoiceX) for obj in lst):
return ValueChoiceX(min, 'min({})', lst)
return min(lst)
return min(cast(Any, lst))
def __hash__(self):
# this is required because we have implemented ``__eq__``
@ -589,24 +590,25 @@ class ValueChoiceX(Translatable, nn.Module):
# - Implementation effort is too huge.
# As a result, inplace operators like +=, *=, magic methods like `__getattr__` are not included in this list.
def __getitem__(self, key: Any) -> 'ValueChoiceX':
def __getitem__(self: 'ChoiceOf[Any]', key: Any) -> 'ChoiceOf[Any]':
return ValueChoiceX(lambda x, y: x[y], '{}[{}]', [self, key])
# region implement int, float, round, trunc, floor, ceil
# because I believe sometimes we need them to calculate #channels
# `__int__` and `__float__` are not supported because `__int__` is required to return int.
def __round__(self, ndigits: Optional[Any] = None) -> 'ValueChoiceX':
def __round__(self: 'ChoiceOf[SupportsRound[_value]]',
ndigits: Optional['MaybeChoice[int]'] = None) -> 'ChoiceOf[Union[int, SupportsRound[_value]]]':
if ndigits is not None:
return ValueChoiceX(round, 'round({}, {})', [self, ndigits])
return ValueChoiceX(round, 'round({})', [self])
return cast(ChoiceOf[Union[int, SupportsRound[_value]]], ValueChoiceX(round, 'round({}, {})', [self, ndigits]))
return cast(ChoiceOf[Union[int, SupportsRound[_value]]], ValueChoiceX(round, 'round({})', [self]))
def __trunc__(self) -> 'ValueChoiceX':
def __trunc__(self) -> NoReturn:
raise RuntimeError("Try to use `ValueChoice.to_int()` instead of `math.trunc()` on value choices.")
def __floor__(self) -> 'ValueChoiceX':
def __floor__(self: 'ChoiceOf[Any]') -> 'ChoiceOf[int]':
return ValueChoiceX(math.floor, 'math.floor({})', [self])
def __ceil__(self) -> 'ValueChoiceX':
def __ceil__(self: 'ChoiceOf[Any]') -> 'ChoiceOf[int]':
return ValueChoiceX(math.ceil, 'math.ceil({})', [self])
def __index__(self) -> NoReturn:
@ -622,132 +624,133 @@ class ValueChoiceX(Translatable, nn.Module):
# region the following code is generated with codegen (see above)
# Annotated with "region" because I want to collapse them in vscode
def __neg__(self) -> 'ValueChoiceX':
return ValueChoiceX(operator.neg, '-{}', [self])
def __neg__(self: 'ChoiceOf[_value]') -> 'ChoiceOf[_value]':
return cast(ChoiceOf[_value], ValueChoiceX(operator.neg, '-{}', [self]))
def __pos__(self) -> 'ValueChoiceX':
return ValueChoiceX(operator.pos, '+{}', [self])
def __pos__(self: 'ChoiceOf[_value]') -> 'ChoiceOf[_value]':
return cast(ChoiceOf[_value], ValueChoiceX(operator.pos, '+{}', [self]))
def __invert__(self) -> 'ValueChoiceX':
return ValueChoiceX(operator.invert, '~{}', [self])
def __invert__(self: 'ChoiceOf[_value]') -> 'ChoiceOf[_value]':
return cast(ChoiceOf[_value], ValueChoiceX(operator.invert, '~{}', [self]))
def __add__(self, other: Any) -> 'ValueChoiceX':
def __add__(self: 'ChoiceOf[Any]', other: 'MaybeChoice[Any]') -> 'ChoiceOf[Any]':
return ValueChoiceX(operator.add, '{} + {}', [self, other])
def __radd__(self, other: Any) -> 'ValueChoiceX':
def __radd__(self: 'ChoiceOf[Any]', other: 'MaybeChoice[Any]') -> 'ChoiceOf[Any]':
return ValueChoiceX(operator.add, '{} + {}', [other, self])
def __sub__(self, other: Any) -> 'ValueChoiceX':
def __sub__(self: 'ChoiceOf[Any]', other: 'MaybeChoice[Any]') -> 'ChoiceOf[Any]':
return ValueChoiceX(operator.sub, '{} - {}', [self, other])
def __rsub__(self, other: Any) -> 'ValueChoiceX':
def __rsub__(self: 'ChoiceOf[Any]', other: 'MaybeChoice[Any]') -> 'ChoiceOf[Any]':
return ValueChoiceX(operator.sub, '{} - {}', [other, self])
def __mul__(self, other: Any) -> 'ValueChoiceX':
def __mul__(self: 'ChoiceOf[Any]', other: 'MaybeChoice[Any]') -> 'ChoiceOf[Any]':
return ValueChoiceX(operator.mul, '{} * {}', [self, other])
def __rmul__(self, other: Any) -> 'ValueChoiceX':
def __rmul__(self: 'ChoiceOf[Any]', other: 'MaybeChoice[Any]') -> 'ChoiceOf[Any]':
return ValueChoiceX(operator.mul, '{} * {}', [other, self])
def __matmul__(self, other: Any) -> 'ValueChoiceX':
def __matmul__(self: 'ChoiceOf[Any]', other: 'MaybeChoice[Any]') -> 'ChoiceOf[Any]':
return ValueChoiceX(operator.matmul, '{} @ {}', [self, other])
def __rmatmul__(self, other: Any) -> 'ValueChoiceX':
def __rmatmul__(self: 'ChoiceOf[Any]', other: 'MaybeChoice[Any]') -> 'ChoiceOf[Any]':
return ValueChoiceX(operator.matmul, '{} @ {}', [other, self])
def __truediv__(self, other: Any) -> 'ValueChoiceX':
def __truediv__(self: 'ChoiceOf[Any]', other: 'MaybeChoice[Any]') -> 'ChoiceOf[Any]':
return ValueChoiceX(operator.truediv, '{} // {}', [self, other])
def __rtruediv__(self, other: Any) -> 'ValueChoiceX':
def __rtruediv__(self: 'ChoiceOf[Any]', other: 'MaybeChoice[Any]') -> 'ChoiceOf[Any]':
return ValueChoiceX(operator.truediv, '{} // {}', [other, self])
def __floordiv__(self, other: Any) -> 'ValueChoiceX':
def __floordiv__(self: 'ChoiceOf[Any]', other: 'MaybeChoice[Any]') -> 'ChoiceOf[Any]':
return ValueChoiceX(operator.floordiv, '{} / {}', [self, other])
def __rfloordiv__(self, other: Any) -> 'ValueChoiceX':
def __rfloordiv__(self: 'ChoiceOf[Any]', other: 'MaybeChoice[Any]') -> 'ChoiceOf[Any]':
return ValueChoiceX(operator.floordiv, '{} / {}', [other, self])
def __mod__(self, other: Any) -> 'ValueChoiceX':
def __mod__(self: 'ChoiceOf[Any]', other: 'MaybeChoice[Any]') -> 'ChoiceOf[Any]':
return ValueChoiceX(operator.mod, '{} % {}', [self, other])
def __rmod__(self, other: Any) -> 'ValueChoiceX':
def __rmod__(self: 'ChoiceOf[Any]', other: 'MaybeChoice[Any]') -> 'ChoiceOf[Any]':
return ValueChoiceX(operator.mod, '{} % {}', [other, self])
def __lshift__(self, other: Any) -> 'ValueChoiceX':
def __lshift__(self: 'ChoiceOf[Any]', other: 'MaybeChoice[Any]') -> 'ChoiceOf[Any]':
return ValueChoiceX(operator.lshift, '{} << {}', [self, other])
def __rlshift__(self, other: Any) -> 'ValueChoiceX':
def __rlshift__(self: 'ChoiceOf[Any]', other: 'MaybeChoice[Any]') -> 'ChoiceOf[Any]':
return ValueChoiceX(operator.lshift, '{} << {}', [other, self])
def __rshift__(self, other: Any) -> 'ValueChoiceX':
def __rshift__(self: 'ChoiceOf[Any]', other: 'MaybeChoice[Any]') -> 'ChoiceOf[Any]':
return ValueChoiceX(operator.rshift, '{} >> {}', [self, other])
def __rrshift__(self, other: Any) -> 'ValueChoiceX':
def __rrshift__(self: 'ChoiceOf[Any]', other: 'MaybeChoice[Any]') -> 'ChoiceOf[Any]':
return ValueChoiceX(operator.rshift, '{} >> {}', [other, self])
def __and__(self, other: Any) -> 'ValueChoiceX':
def __and__(self: 'ChoiceOf[Any]', other: 'MaybeChoice[Any]') -> 'ChoiceOf[Any]':
return ValueChoiceX(operator.and_, '{} & {}', [self, other])
def __rand__(self, other: Any) -> 'ValueChoiceX':
def __rand__(self: 'ChoiceOf[Any]', other: 'MaybeChoice[Any]') -> 'ChoiceOf[Any]':
return ValueChoiceX(operator.and_, '{} & {}', [other, self])
def __xor__(self, other: Any) -> 'ValueChoiceX':
def __xor__(self: 'ChoiceOf[Any]', other: 'MaybeChoice[Any]') -> 'ChoiceOf[Any]':
return ValueChoiceX(operator.xor, '{} ^ {}', [self, other])
def __rxor__(self, other: Any) -> 'ValueChoiceX':
def __rxor__(self: 'ChoiceOf[Any]', other: 'MaybeChoice[Any]') -> 'ChoiceOf[Any]':
return ValueChoiceX(operator.xor, '{} ^ {}', [other, self])
def __or__(self, other: Any) -> 'ValueChoiceX':
def __or__(self: 'ChoiceOf[Any]', other: 'MaybeChoice[Any]') -> 'ChoiceOf[Any]':
return ValueChoiceX(operator.or_, '{} | {}', [self, other])
def __ror__(self, other: Any) -> 'ValueChoiceX':
def __ror__(self: 'ChoiceOf[Any]', other: 'MaybeChoice[Any]') -> 'ChoiceOf[Any]':
return ValueChoiceX(operator.or_, '{} | {}', [other, self])
def __lt__(self, other: Any) -> 'ValueChoiceX':
def __lt__(self: 'ChoiceOf[Any]', other: 'MaybeChoice[Any]') -> 'ChoiceOf[Any]':
return ValueChoiceX(operator.lt, '{} < {}', [self, other])
def __le__(self, other: Any) -> 'ValueChoiceX':
def __le__(self: 'ChoiceOf[Any]', other: 'MaybeChoice[Any]') -> 'ChoiceOf[Any]':
return ValueChoiceX(operator.le, '{} <= {}', [self, other])
def __eq__(self, other: Any) -> 'ValueChoiceX':
def __eq__(self: 'ChoiceOf[Any]', other: 'MaybeChoice[Any]') -> 'ChoiceOf[Any]':
return ValueChoiceX(operator.eq, '{} == {}', [self, other])
def __ne__(self, other: Any) -> 'ValueChoiceX':
def __ne__(self: 'ChoiceOf[Any]', other: 'MaybeChoice[Any]') -> 'ChoiceOf[Any]':
return ValueChoiceX(operator.ne, '{} != {}', [self, other])
def __ge__(self, other: Any) -> 'ValueChoiceX':
def __ge__(self: 'ChoiceOf[Any]', other: 'MaybeChoice[Any]') -> 'ChoiceOf[Any]':
return ValueChoiceX(operator.ge, '{} >= {}', [self, other])
def __gt__(self, other: Any) -> 'ValueChoiceX':
def __gt__(self: 'ChoiceOf[Any]', other: 'MaybeChoice[Any]') -> 'ChoiceOf[Any]':
return ValueChoiceX(operator.gt, '{} > {}', [self, other])
# endregion
# __pow__, __divmod__, __abs__ are special ones.
# Not easy to cover those cases with codegen.
def __pow__(self, other: Any, modulo: Optional[Any] = None) -> 'ValueChoiceX':
def __pow__(self: 'ChoiceOf[Any]', other: 'MaybeChoice[Any]', modulo: Optional['MaybeChoice[Any]'] = None) -> 'ChoiceOf[Any]':
if modulo is not None:
return ValueChoiceX(pow, 'pow({}, {}, {})', [self, other, modulo])
return ValueChoiceX(lambda a, b: a ** b, '{} ** {}', [self, other])
def __rpow__(self, other: Any, modulo: Optional[Any] = None) -> 'ValueChoiceX':
def __rpow__(self: 'ChoiceOf[Any]', other: 'MaybeChoice[Any]', modulo: Optional['MaybeChoice[Any]'] = None) -> 'ChoiceOf[Any]':
if modulo is not None:
return ValueChoiceX(pow, 'pow({}, {}, {})', [other, self, modulo])
return ValueChoiceX(lambda a, b: a ** b, '{} ** {}', [other, self])
def __divmod__(self, other: Any) -> 'ValueChoiceX':
def __divmod__(self: 'ChoiceOf[Any]', other: 'MaybeChoice[Any]') -> 'ChoiceOf[Any]':
return ValueChoiceX(divmod, 'divmod({}, {})', [self, other])
def __rdivmod__(self, other: Any) -> 'ValueChoiceX':
def __rdivmod__(self: 'ChoiceOf[Any]', other: 'MaybeChoice[Any]') -> 'ChoiceOf[Any]':
return ValueChoiceX(divmod, 'divmod({}, {})', [other, self])
def __abs__(self) -> 'ValueChoiceX':
def __abs__(self: 'ChoiceOf[Any]') -> 'ChoiceOf[Any]':
return ValueChoiceX(abs, 'abs({})', [self])
ValueChoiceOrAny = TypeVar('ValueChoiceOrAny', ValueChoiceX, Any)
ChoiceOf = ValueChoiceX
MaybeChoice = Union[ValueChoiceX[_cand], _cand]
class ValueChoice(ValueChoiceX, Mutable):
class ValueChoice(ValueChoiceX[_cand], Mutable):
"""
ValueChoice is to choose one from ``candidates``. The most common use cases are:
@ -865,14 +868,14 @@ class ValueChoice(ValueChoiceX, Mutable):
# FIXME: prior is designed but not supported yet
@classmethod
def create_fixed_module(cls, candidates: List[Any], *, label: Optional[str] = None, **kwargs):
def create_fixed_module(cls, candidates: List[_cand], *, label: Optional[str] = None, **kwargs):
value = get_fixed_value(label)
if value not in candidates:
raise ValueError(f'Value {value} does not belong to the candidates: {candidates}.')
return value
def __init__(self, candidates: List[Any], *, prior: Optional[List[float]] = None, label: Optional[str] = None):
super().__init__(None, None, None)
def __init__(self, candidates: List[_cand], *, prior: Optional[List[float]] = None, label: Optional[str] = None):
super().__init__()
self.candidates = candidates
self.prior = prior or [1 / len(candidates) for _ in range(len(candidates))]
assert abs(sum(self.prior) - 1) < 1e-5, 'Sum of prior distribution is not 1.'
@ -894,10 +897,10 @@ class ValueChoice(ValueChoiceX, Mutable):
# yield self because self is the only value choice here
yield self
def dry_run(self) -> Any:
def dry_run(self) -> _cand:
return self.candidates[0]
def _evaluate(self, values: Iterable[Any], dry_run: bool = False) -> Any:
def _evaluate(self, values: Iterator[_cand], dry_run: bool = False) -> _cand:
if dry_run:
return self.candidates[0]
try:
@ -986,6 +989,7 @@ class ModelParameterChoice:
Examples
--------
Get a dynamic-shaped parameter. Because ``torch.zeros`` is not a basic unit, we can't use :class:`ValueChoice` on it.
>>> parameter_dim = nn.ModelParameterChoice([64, 128, 256])
>>> self.token = nn.Parameter(torch.zeros(1, parameter_dim, 32, 32))
"""
@ -1016,12 +1020,14 @@ class ModelParameterChoice:
if default not in candidates:
# could be callable
try:
default = default(candidates)
default = cast(Callable[[List[ValueType]], ValueType], default)(candidates)
except TypeError as e:
if 'not callable' in str(e):
raise TypeError("`default` is not in `candidates`, and it's also not callable.")
raise
default = cast(ValueType, default)
label = generate_new_label(label)
parameter_spec = ParameterSpec(
label, # name

Просмотреть файл

@ -1,6 +1,9 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.
import copy
import warnings
from typing import Callable, Dict, List, Union, Optional, Tuple
from typing import Callable, Dict, List, Union, Optional, Tuple, Sequence, cast
try:
from typing import Literal
except ImportError:
@ -193,8 +196,10 @@ class Cell(nn.Module):
def __init__(self,
op_candidates: Union[
Callable[[], List[nn.Module]],
List[Union[nn.Module, _cell_op_factory_type]],
Dict[str, Union[nn.Module, _cell_op_factory_type]]
List[nn.Module],
List[_cell_op_factory_type],
Dict[str, nn.Module],
Dict[str, _cell_op_factory_type]
],
num_nodes: int,
num_ops_per_node: int = 1,
@ -251,8 +256,8 @@ class Cell(nn.Module):
ops = self._convert_op_candidates(op_candidates, i, k, chosen)
# though it's layer choice and input choice here, in fixed mode, the chosen module will be created.
self.ops[-1].append(LayerChoice(ops, label=f'{self.label}/op_{i}_{k}'))
self.inputs[-1].append(inp)
cast(ModuleList, self.ops[-1]).append(LayerChoice(ops, label=f'{self.label}/op_{i}_{k}'))
cast(ModuleList, self.inputs[-1]).append(inp)
@property
def label(self):
@ -274,13 +279,17 @@ class Cell(nn.Module):
By default, it's the output of ``merge_op``, which is a contenation (on ``concat_dim``)
of some of (possibly all) the nodes' outputs in the cell.
"""
processed_inputs: List[torch.Tensor]
if len(inputs) == 1 and isinstance(inputs[0], list):
inputs = inputs[0]
processed_inputs = list(inputs[0]) # shallow copy
else:
inputs = list(inputs)
assert len(inputs) == self.num_predecessors, 'The number of inputs must be equal to `num_predecessors`.'
states = self.preprocessor(inputs)
for ops, inps in zip(self.ops, self.inputs):
processed_inputs = cast(List[torch.Tensor], list(inputs))
assert len(processed_inputs) == self.num_predecessors, 'The number of inputs must be equal to `num_predecessors`.'
states: List[torch.Tensor] = self.preprocessor(processed_inputs)
for ops, inps in zip(
cast(Sequence[Sequence[LayerChoice]], self.ops),
cast(Sequence[Sequence[InputChoice]], self.inputs)
):
current_state = []
for op, inp in zip(ops, inps):
current_state.append(op(inp(states)))
@ -291,7 +300,7 @@ class Cell(nn.Module):
this_cell = torch.cat(states[self.num_predecessors:], self.concat_dim)
else:
this_cell = torch.cat([states[k] for k in self.output_node_indices], self.concat_dim)
return self.postprocessor(this_cell, inputs)
return self.postprocessor(this_cell, processed_inputs)
@staticmethod
def _convert_op_candidates(op_candidates, node_index, op_index, chosen) -> Union[Dict[str, nn.Module], List[nn.Module]]:

Просмотреть файл

@ -1,14 +1,17 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.
import copy
import warnings
from collections import OrderedDict
from typing import Callable, List, Union, Tuple, Optional
from typing import Callable, List, Dict, Union, Tuple, Optional
import torch
import torch.nn as nn
from nni.retiarii.utils import NoContextError, STATE_DICT_PY_MAPPING_PARTIAL
from .api import LayerChoice, ValueChoice, ValueChoiceX
from .api import LayerChoice, ValueChoice, ValueChoiceX, ChoiceOf
from .cell import Cell
from .nasbench101 import NasBench101Cell, NasBench101Mutator
from .mutation_utils import Mutable, generate_new_label, get_fixed_value
@ -64,7 +67,7 @@ class Repeat(Mutable):
List[Callable[[int], nn.Module]],
nn.Module,
List[nn.Module]],
depth: Union[int, Tuple[int, int], ValueChoice], *, label: Optional[str] = None):
depth: Union[int, Tuple[int, int], ChoiceOf[int]], *, label: Optional[str] = None):
if isinstance(depth, tuple):
# we can't create a value choice here,
# otherwise we will have two value choices, one created here, another in init.
@ -90,7 +93,7 @@ class Repeat(Mutable):
List[Callable[[int], nn.Module]],
nn.Module,
List[nn.Module]],
depth: Union[int, Tuple[int, int]], *, label: Optional[str] = None):
depth: Union[int, Tuple[int, int], ChoiceOf[int]], *, label: Optional[str] = None):
super().__init__()
self._label = None # by default, no label
@ -192,7 +195,7 @@ class NasBench201Cell(nn.Module):
return OrderedDict([(str(i), t) for i, t in enumerate(x)])
return OrderedDict(x)
def __init__(self, op_candidates: List[Callable[[int, int], nn.Module]],
def __init__(self, op_candidates: Union[Dict[str, Callable[[int, int], nn.Module]], List[Callable[[int, int], nn.Module]]],
in_features: int, out_features: int, num_tensors: int = 4,
label: Optional[str] = None):
super().__init__()
@ -214,16 +217,15 @@ class NasBench201Cell(nn.Module):
node_ops.append(LayerChoice(op_choices, label=f'{self._label}__{j}_{tid}')) # put __ here to be compatible with base engine
self.layers.append(node_ops)
def forward(self, inputs):
def forward(self, inputs: torch.Tensor) -> torch.Tensor:
"""
The forward of input choice is simply selecting first on all choices.
It shouldn't be called directly by users in most cases.
"""
tensors = [inputs]
tensors: List[torch.Tensor] = [inputs]
for layer in self.layers:
current_tensor = []
for i, op in enumerate(layer):
current_tensor.append(op(tensors[i]))
current_tensor = torch.sum(torch.stack(current_tensor), 0)
tensors.append(current_tensor)
current_tensor: List[torch.Tensor] = []
for i, op in enumerate(layer): # type: ignore
current_tensor.append(op(tensors[i])) # type: ignore
tensors.append(torch.sum(torch.stack(current_tensor), 0))
return tensors[-1]

Просмотреть файл

@ -1,6 +1,8 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.
from __future__ import annotations
from packaging.version import Version
import torch
import torch.nn as nn
@ -233,7 +235,7 @@ class AutoActivation(nn.Module):
-----
Current `beta` is not per-channel parameter.
"""
def __init__(self, unit_num: int = 1, label: str = None):
def __init__(self, unit_num: int = 1, label: str | None = None):
super().__init__()
self._label = generate_new_label(label)
self.unaries = nn.ModuleList()

Просмотреть файл

@ -1,3 +1,6 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.
from typing import Any, Optional, Tuple, Union
import torch.nn as nn
@ -41,7 +44,7 @@ def generate_new_label(label: Optional[str]):
return label
def get_fixed_value(label: str) -> Any:
def get_fixed_value(label: Optional[str]) -> Any:
ret = get_current_context('fixed')
try:
return ret[generate_new_label(label)]
@ -49,7 +52,7 @@ def get_fixed_value(label: str) -> Any:
raise KeyError(f'Fixed context with {label} not found. Existing values are: {ret}')
def get_fixed_dict(label_prefix: str) -> Tuple[str, Any]:
def get_fixed_dict(label_prefix: Optional[str]) -> Tuple[str, Any]:
ret = get_current_context('fixed')
try:
label_prefix = generate_new_label(label_prefix)

Просмотреть файл

@ -2,7 +2,7 @@
# Licensed under the MIT license.
import inspect
from typing import Any, List, Optional, Tuple, Dict, Iterator
from typing import Any, List, Optional, Tuple, Dict, Iterator, Iterable, cast
import torch.nn as nn
@ -28,12 +28,14 @@ class LayerChoiceMutator(Mutator):
# Each layer choice corresponds to a cell, which is unconnected in the base graph.
# We add the connections here in the mutation logic.
# Thus, the mutated model should not be mutated again. Everything should be based on the original base graph.
target = model.graphs[node.operation.cell_name]
target = model.graphs[cast(Cell, node.operation).cell_name]
chosen_node = target.get_node_by_name(chosen)
assert chosen_node is not None
target.add_edge((target.input_node, 0), (chosen_node, None))
target.add_edge((chosen_node, None), (target.output_node, None))
model.get_node_by_name(node.name).update_operation(Cell(node.operation.cell_name))
operation = cast(Cell, node.operation)
target_node = cast(Node, model.get_node_by_name(node.name))
target_node.update_operation(Cell(operation.cell_name))
# remove redundant nodes
for rm_node in list(target.hidden_nodes): # remove from a list on the fly will cause issues
@ -57,7 +59,7 @@ class InputChoiceMutator(Mutator):
else:
chosen = [self.choice(candidates) for _ in range(n_chosen)]
for node in self.nodes:
target = model.get_node_by_name(node.name)
target = cast(Node, model.get_node_by_name(node.name))
target.update_operation('__torch__.nni.retiarii.nn.pytorch.ChosenInputs',
{'chosen': chosen, 'reduction': node.operation.parameters['reduction']})
@ -74,7 +76,7 @@ class ValueChoiceMutator(Mutator):
# no need to support transformation here,
# because it is naturally done in forward loop
for node in self.nodes:
target = model.get_node_by_name(node.name)
target = cast(Node, model.get_node_by_name(node.name))
target.update_operation('prim::Constant', {'type': type(chosen).__name__, 'value': chosen})
@ -86,7 +88,7 @@ class ParameterChoiceLeafMutator(Mutator):
super().__init__(label=label)
self.candidates = candidates
def mutate(self, model: Model) -> Model:
def mutate(self, model: Model) -> None:
# leave a record here
# real mutations will be done in ParameterChoiceMutator
self.choice(self.candidates)
@ -103,7 +105,7 @@ class ParameterChoiceMutator(Mutator):
self.nodes = nodes
def mutate(self, model: Model) -> Model:
def mutate(self, model: Model) -> None:
# looks like {"label1": "cat", "label2": 123}
value_choice_decisions = {}
for mutation in model.history:
@ -122,7 +124,7 @@ class ParameterChoiceMutator(Mutator):
result_value = value_choice.evaluate(leaf_node_values)
# update model with graph mutation primitives
target = model.get_node_by_name(node.name)
target = cast(Node, model.get_node_by_name(node.name))
target.update_operation(target.operation.type, {**target.operation.parameters, argname: result_value})
@ -138,20 +140,20 @@ class RepeatMutator(Mutator):
while u != graph.output_node:
if u != graph.input_node:
chain.append(u)
assert len(u.successors) == 1, f'This graph is an illegal chain. {u} has output {u.successor}.'
assert len(u.successors) == 1, f'This graph is an illegal chain. {u} has output {u.successors}.'
u = u.successors[0]
return chain
def mutate(self, model):
for node in self.nodes:
# the logic here is similar to layer choice. We find cell attached to each node.
target: Graph = model.graphs[node.operation.cell_name]
target: Graph = model.graphs[cast(Cell, node.operation).cell_name]
chain = self._retrieve_chain_from_graph(target)
# and we get the chosen depth (by value choice)
node_in_model = model.get_node_by_name(node.name)
node_in_model = cast(Node, model.get_node_by_name(node.name))
# depth is a value choice in base model
# but it's already mutated by a ParameterChoiceMutator here
chosen_depth = node_in_model.operation.parameters['depth']
chosen_depth: int = node_in_model.operation.parameters['depth']
for edge in chain[chosen_depth - 1].outgoing_edges:
edge.remove()
target.add_edge((chain[chosen_depth - 1], None), (target.output_node, None))
@ -159,8 +161,11 @@ class RepeatMutator(Mutator):
for edge in rm_node.outgoing_edges:
edge.remove()
rm_node.remove()
# to delete the unused parameters.
model.get_node_by_name(node.name).update_operation(Cell(node.operation.cell_name))
target_node = cast(Node, model.get_node_by_name(node.name))
cell_operation = cast(Cell, node.operation)
target_node.update_operation(Cell(cell_operation.cell_name))
def process_inline_mutation(model: Model) -> Optional[List[Mutator]]:
@ -241,7 +246,7 @@ class ManyChooseManyMutator(Mutator):
Choose based on labels. Will not affect the model itself.
"""
def __init__(self, label: Optional[str]):
def __init__(self, label: str):
super().__init__(label=label)
@staticmethod
@ -257,7 +262,7 @@ class ManyChooseManyMutator(Mutator):
return node.operation.parameters['n_chosen']
return 1
def mutate(self, model: Model):
def mutate(self, model: Model) -> None:
# this mutate does not have any effect, but it is recorded in the mutation history
for node in model.get_nodes_by_label(self.label):
n_chosen = self.number_of_chosen(node)
@ -280,12 +285,12 @@ def extract_mutation_from_pt_module(pytorch_model: nn.Module) -> Tuple[Model, Op
if not is_model_wrapped(pytorch_model):
raise ValueError('Please annotate the model with @model_wrapper decorator in python execution mode '
'if your model has init parameters.')
model.python_init_params = pytorch_model.trace_kwargs
model.python_init_params = cast(dict, pytorch_model.trace_kwargs)
else:
model.python_init_params = {}
# hyper-parameter choice
namespace: ModelNamespace = pytorch_model._model_namespace
namespace: ModelNamespace = cast(ModelNamespace, pytorch_model._model_namespace)
for param_spec in namespace.parameter_specs:
assert param_spec.categorical and param_spec.type == 'choice'
node = graph.add_node(f'param_spec_{param_spec.name}', 'ModelParameterChoice', {'candidates': param_spec.values})
@ -294,7 +299,8 @@ def extract_mutation_from_pt_module(pytorch_model: nn.Module) -> Tuple[Model, Op
for name, module in pytorch_model.named_modules():
# tricky case: value choice that serves as parameters are stored in traced arguments
if is_basic_unit(module):
for key, value in module.trace_kwargs.items():
trace_kwargs = cast(Dict[str, Any], module.trace_kwargs)
for key, value in trace_kwargs.items():
if isinstance(value, ValueChoiceX):
for i, choice in enumerate(value.inner_choices()):
node = graph.add_node(f'{name}.init.{key}.{i}', 'ValueChoice', {'candidates': choice.candidates})
@ -329,14 +335,17 @@ def extract_mutation_from_pt_module(pytorch_model: nn.Module) -> Tuple[Model, Op
mutators = []
mutators_final = []
for nodes in _group_by_label_and_type(graph.hidden_nodes):
label = nodes[0].label
assert label is not None, f'label of {nodes[0]} can not be None.'
assert _is_all_equal(map(lambda n: n.operation.type, nodes)), \
f'Node with label "{nodes[0].label}" does not all have the same type.'
f'Node with label "{label}" does not all have the same type.'
assert _is_all_equal(map(lambda n: n.operation.parameters, nodes)), \
f'Node with label "{nodes[0].label}" does not agree on parameters.'
f'Node with label "{label}" does not agree on parameters.'
if nodes[0].operation.type == 'NasBench101Cell':
mutators_final.append(NasBench101Mutator(nodes[0].label))
# The mutation of Nas-bench-101 is special, and has to be done lastly.
mutators_final.append(NasBench101Mutator(label))
else:
mutators.append(ManyChooseManyMutator(nodes[0].label))
mutators.append(ManyChooseManyMutator(label))
return model, mutators + mutators_final
@ -350,7 +359,7 @@ class EvaluatorValueChoiceLeafMutator(Mutator):
super().__init__(label=label)
self.candidates = candidates
def mutate(self, model: Model) -> Model:
def mutate(self, model: Model) -> None:
# leave a record here
# real mutations will be done in ParameterChoiceMutator
self.choice(self.candidates)
@ -388,7 +397,7 @@ class EvaluatorValueChoiceMutator(Mutator):
return obj
def mutate(self, model: Model):
def mutate(self, model: Model) -> None:
value_choice_decisions = {}
for mutation in model.history:
if isinstance(mutation.mutator, EvaluatorValueChoiceLeafMutator):
@ -454,7 +463,7 @@ def _is_all_equal(lst):
return True
def _group_by_label_and_type(nodes: List[Node]) -> List[List[Node]]:
def _group_by_label_and_type(nodes: Iterable[Node]) -> List[List[Node]]:
result = {}
for node in nodes:
key = (node.label, node.operation.type)
@ -464,7 +473,7 @@ def _group_by_label_and_type(nodes: List[Node]) -> List[List[Node]]:
return list(result.values())
def _group_by_label(nodes: List[Node]) -> List[List[Node]]:
def _group_by_label(nodes: Iterable[Node]) -> List[List[Node]]:
result = {}
for node in nodes:
label = node.operation.parameters['label']

Просмотреть файл

@ -1,6 +1,9 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.
import logging
from collections import OrderedDict
from typing import Callable, List, Optional, Union, Dict
from typing import Callable, List, Optional, Union, Dict, Tuple, cast
import numpy as np
import torch
@ -89,7 +92,7 @@ def compute_vertex_channels(input_channels, output_channels, matrix):
return vertex_channels
def prune(matrix, ops):
def prune(matrix, ops) -> Tuple[np.ndarray, List[Union[str, Callable[[int], nn.Module]]]]:
"""
Prune the extraneous parts of the graph.
@ -152,11 +155,17 @@ class _NasBench101CellFixed(nn.Module):
assert num_nodes == len(operations) + 2 == len(adjacency_list) + 1
self.operations = ['IN'] + operations + ['OUT'] # add psuedo nodes
raw_operations: List[Union[str, Callable[[int], nn.Module]]] = list(operations)
del operations # operations is no longer needed. Delete it to avoid misuse
# add psuedo nodes
raw_operations.insert(0, 'IN')
raw_operations.append('OUT')
self.connection_matrix = self.build_connection_matrix(adjacency_list, num_nodes)
del num_nodes # raw number of nodes is no longer used
self.connection_matrix, self.operations = prune(self.connection_matrix, self.operations)
self.connection_matrix, self.operations = prune(self.connection_matrix, raw_operations)
self.hidden_features = compute_vertex_channels(in_features, out_features, self.connection_matrix)
@ -172,7 +181,8 @@ class _NasBench101CellFixed(nn.Module):
self.projections.append(projection(in_features, self.hidden_features[i]))
for i in range(1, self.num_nodes - 1):
self.ops.append(operations[i - 1](self.hidden_features[i]))
operation = cast(Callable[[int], nn.Module], self.operations[i])
self.ops.append(operation(self.hidden_features[i]))
@staticmethod
def build_connection_matrix(adjacency_list, num_nodes):
@ -361,7 +371,7 @@ class NasBench101Mutator(Mutator):
# for validation purposes
# for python execution engine
def __init__(self, label: Optional[str]):
def __init__(self, label: str):
super().__init__(label=label)
@staticmethod
@ -378,9 +388,11 @@ class NasBench101Mutator(Mutator):
return 1
def mutate(self, model: Model):
max_num_edges = cast(int, None)
for node in model.get_nodes_by_label(self.label):
max_num_edges = node.operation.parameters['max_num_edges']
break
assert max_num_edges is not None
mutation_dict = {mut.mutator.label: mut.samples for mut in model.history}
num_nodes = mutation_dict[f'{self.label}/num_nodes'][0]
adjacency_list = [mutation_dict[f'{self.label}/input{i}'] for i in range(1, num_nodes)]

Просмотреть файл

@ -1,26 +1,42 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.
import inspect
import warnings
from pathlib import Path
import torch
import torch.nn as nn
# To make auto-completion happy, we generate a _nn.py that lists out all the classes.
nn_cache_file_path = Path(__file__).parent / '_nn.py'
cache_valid = False
# Update this when cache format changes, to enforce an update.
cache_version = 2
if nn_cache_file_path.exists():
from . import _nn # pylint: disable=no-name-in-module
# valid only when torch version match
if _nn._torch_version == torch.__version__:
cache_valid = True
if not cache_valid:
def validate_cache() -> bool:
import torch
cache_valid = []
if nn_cache_file_path.exists():
lines = nn_cache_file_path.read_text().splitlines()
for line in lines:
if line.startswith('# _torch_version'):
_cached_torch_version = line[line.find('=') + 1:].strip()
if _cached_torch_version == torch.__version__:
cache_valid.append(True)
if line.startswith('# _torch_nn_cache_version'):
_cached_cache_version = int(line[line.find('=') + 1:].strip())
if _cached_cache_version == cache_version:
cache_valid.append(True)
return len(cache_valid) >= 2 and all(cache_valid)
def generate_stub_file() -> str:
import inspect
import warnings
import torch
import torch.nn as nn
_NO_WRAP_CLASSES = [
# not an nn.Module
'Parameter',
@ -47,7 +63,10 @@ if not cache_valid:
'# This file is auto-generated to make auto-completion work.',
'# When pytorch version does not match, it will get automatically updated.',
'# pylint: skip-file',
f'_torch_version = "{torch.__version__}"',
'# pyright: reportGeneralTypeIssues=false',
f'# _torch_version = {torch.__version__}',
f'# _torch_nn_cache_version = {cache_version}',
'import typing',
'import torch.nn as nn',
'from nni.retiarii.serializer import basic_unit',
]
@ -66,10 +85,9 @@ if not cache_valid:
'It means your PyTorch version might not be supported.', RuntimeWarning)
code.append(f'{name} = nn.{name}')
elif name in _WRAP_WITHOUT_TAG_CLASSES:
code.append(f'{name} = basic_unit(nn.{name}, basic_unit_tag=False)')
code.append(f'{name} = typing.cast(typing.Type[nn.{name}], basic_unit(nn.{name}, basic_unit_tag=False))')
else:
code.append(f'{name} = basic_unit(nn.{name})')
code.append(f'{name} = typing.cast(typing.Type[nn.{name}], basic_unit(nn.{name}))')
all_names.append(name)
elif inspect.isfunction(obj) or inspect.ismodule(obj):
@ -78,12 +96,19 @@ if not cache_valid:
code.append(f'__all__ = {all_names}')
return '\n'.join(code)
def write_cache(code: str) -> None:
with nn_cache_file_path.open('w') as fp:
fp.write('\n'.join(code))
fp.write(code)
# Import all modules from generated _nn.py
code = generate_stub_file()
from . import _nn # pylint: disable=no-name-in-module
__all__ = _nn.__all__
from ._nn import * # pylint: disable=import-error, wildcard-import
if not validate_cache():
write_cache(code)
del Path, validate_cache, write_cache, cache_version, nn_cache_file_path, code
from ._nn import * # pylint: disable=import-error, wildcard-import, unused-wildcard-import

Просмотреть файл

@ -20,7 +20,7 @@ from .supermodule.base import BaseSuperNetModule
__all__ = ['MutationHook', 'BaseSuperNetModule', 'BaseOneShotLightningModule', 'traverse_and_mutate_submodules']
MutationHook = Callable[[nn.Module, str, Dict[str, Any]], Union[nn.Module, bool, Tuple[nn.Module, bool]]]
MutationHook = Callable[[nn.Module, str, Dict[str, Any], Dict[str, Any]], Union[nn.Module, bool, Tuple[nn.Module, bool]]]
def traverse_and_mutate_submodules(
@ -149,11 +149,12 @@ class BaseOneShotLightningModule(pl.LightningModule):
The hook list will be appended by ``default_mutation_hooks`` in each one-shot module.
To be more specific, the input arguments are three arguments:
To be more specific, the input arguments are four arguments:
#. a module that might be processed,
#. name of the module in its parent module,
#. a memo dict whose usage depends on the particular algorithm.
#. keyword arguments (configurations).
Note that the memo should be read/written by hooks.
There won't be any hooks called on root module.

Просмотреть файл

@ -1,6 +1,8 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.
# type: ignore
import copy
import logging
from collections import OrderedDict

Просмотреть файл

@ -1,6 +1,8 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.
# type: ignore
import logging
import torch

Просмотреть файл

@ -27,7 +27,7 @@ which interprets the slice and apply it on a tensor.
"""
import operator
from typing import Tuple, Union, List, Dict, Callable, Optional, Iterator, TypeVar, Any, Generic
from typing import Tuple, Union, List, Dict, Callable, Optional, Iterator, TypeVar, Any, Generic, cast
import numpy as np
import torch
@ -128,9 +128,10 @@ class Slicable(Generic[T]):
raise TypeError(f'Unsuppoted weight type: {type(weight)}')
self.weight = weight
def __getitem__(self, index: multidim_slice) -> T:
def __getitem__(self, index: Union[slice_type, multidim_slice]) -> T:
if not isinstance(index, tuple):
index = (index, )
index = cast(multidim_slice, index)
# Get the dict value in index's leafs
# There can be at most one dict

Просмотреть файл

@ -24,7 +24,7 @@ class BaseSuperNetModule(nn.Module):
rather than their compositions.
"""
def resample(self, memo: Dict[str, Any] = None) -> Dict[str, Any]:
def resample(self, memo: Dict[str, Any]) -> Dict[str, Any]:
"""
Resample the super-net module.
@ -40,7 +40,7 @@ class BaseSuperNetModule(nn.Module):
"""
raise NotImplementedError()
def export(self, memo: Dict[str, Any] = None) -> Dict[str, Any]:
def export(self, memo: Dict[str, Any]) -> Dict[str, Any]:
"""
Export the final architecture within this module.
It should have the same keys as ``search_space_spec()``.

Просмотреть файл

@ -275,11 +275,11 @@ class MixedOpDifferentiablePolicy(MixedOperationSamplingPolicy):
if not arch:
yield name, p
def resample(self, operation: MixedOperation, memo: Dict[str, Any] = None) -> Dict[str, Any]:
def resample(self, operation: MixedOperation, memo: Dict[str, Any]) -> Dict[str, Any]:
"""Differentiable. Do nothing in resample."""
return {}
def export(self, operation: MixedOperation, memo: Dict[str, Any] = None) -> Dict[str, Any]:
def export(self, operation: MixedOperation, memo: Dict[str, Any]) -> Dict[str, Any]:
"""Export is also random for each leaf value choice."""
result = {}
for name, spec in operation.search_space_spec().items():

Просмотреть файл

@ -8,11 +8,12 @@ which is commonly known as super-kernel (as in channel search), or weight entang
import inspect
import itertools
from typing import Union, Tuple, Dict, List, Any, Type, Optional, TypeVar
from typing import Union, Tuple, Dict, List, Any, Type, Optional, TypeVar, cast
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch import Tensor
import nni.retiarii.nn.pytorch as retiarii_nn
from nni.common.hpo_utils import ParameterSpec
@ -46,11 +47,11 @@ class MixedOperationSamplingPolicy:
"""
pass
def resample(self, operation: 'MixedOperation', memo: Dict[str, Any] = None) -> Dict[str, Any]:
def resample(self, operation: 'MixedOperation', memo: Dict[str, Any]) -> Dict[str, Any]:
"""The handler of :meth:`MixedOperation.resample`."""
raise NotImplementedError()
def export(self, operation: 'MixedOperation', memo: Dict[str, Any] = None) -> Dict[str, Any]:
def export(self, operation: 'MixedOperation', memo: Dict[str, Any]) -> Dict[str, Any]:
"""The handler of :meth:`MixedOperation.export`."""
raise NotImplementedError()
@ -513,43 +514,42 @@ class MixedMultiHeadAttention(MixedOperation, nn.MultiheadAttention):
embed_dim = _W(embed_dim)
# in projection weights & biases has q, k, v weights concatenated together
in_proj_bias = in_proj_weight = None
in_proj_bias: Optional[Tensor] = None
in_proj_weight: Optional[Tensor] = None
if self.in_proj_bias is not None:
in_proj_bias = _S(self.in_proj_bias)[self._to_proj_slice(embed_dim)]
in_proj_bias = _S(cast(Tensor, self.in_proj_bias))[self._to_proj_slice(embed_dim)]
if self.in_proj_weight is not None:
in_proj_weight = _S(self.in_proj_weight)[self._to_proj_slice(embed_dim), :embed_dim]
in_proj_weight = _S(cast(Tensor, self.in_proj_weight))[self._to_proj_slice(embed_dim), :embed_dim]
bias_k = _S(self.bias_k)[:, :, :embed_dim] if self.bias_k is not None else None
bias_v = _S(self.bias_v)[:, :, :embed_dim] if self.bias_v is not None else None
out_proj_weight = _S(self.out_proj.weight)[:embed_dim, :embed_dim]
out_proj_bias = _S(self.out_proj.bias)[:embed_dim] if self.out_proj.bias is not None else None
bias_k = _S(cast(Tensor, self.bias_k))[:, :, :embed_dim] if self.bias_k is not None else None
bias_v = _S(cast(Tensor, self.bias_v))[:, :, :embed_dim] if self.bias_v is not None else None
out_proj_weight = _S(cast(Tensor, self.out_proj.weight))[:embed_dim, :embed_dim]
out_proj_bias = _S(cast(Tensor, self.out_proj.bias))[:embed_dim] if self.out_proj.bias is not None else None
if not qkv_same_embed_dim:
kdim = _W(kdim)
vdim = _W(vdim)
q_proj = _S(self.q_proj_weight)[:embed_dim, :embed_dim]
k_proj = _S(self.k_proj_weight)[:embed_dim]
k_proj = _S(k_proj)[:, :kdim]
v_proj = _S(self.v_proj_weight)[:embed_dim]
v_proj = _S(v_proj)[:, :vdim]
q_proj = _S(cast(Tensor, self.q_proj_weight))[:embed_dim, :embed_dim]
k_proj = _S(cast(Tensor, self.k_proj_weight))[:embed_dim]
k_proj = _S(k_proj)[:, :_W(kdim)]
v_proj = _S(cast(Tensor, self.v_proj_weight))[:embed_dim]
v_proj = _S(v_proj)[:, :_W(vdim)]
# The rest part is basically same as pytorch
attn_output, attn_output_weights = F.multi_head_attention_forward(
query, key, value, used_embed_dim, num_heads,
in_proj_weight, in_proj_bias,
cast(Tensor, in_proj_weight), cast(Tensor, in_proj_bias),
bias_k, bias_v, self.add_zero_attn,
dropout, out_proj_weight, out_proj_bias,
dropout, out_proj_weight, cast(Tensor, out_proj_bias),
training=self.training,
key_padding_mask=key_padding_mask, need_weights=need_weights,
attn_mask=attn_mask, use_separate_proj_weight=True,
q_proj_weight=q_proj, k_proj_weight=k_proj, v_proj_weight=v_proj)
else:
# Cast tensor here because of a bug in pytorch stub
attn_output, attn_output_weights = F.multi_head_attention_forward(
query, key, value, used_embed_dim, num_heads,
in_proj_weight, in_proj_bias,
cast(Tensor, in_proj_weight), cast(Tensor, in_proj_bias),
bias_k, bias_v, self.add_zero_attn,
dropout, out_proj_weight, out_proj_bias,
dropout, out_proj_weight, cast(Tensor, out_proj_bias),
training=self.training,
key_padding_mask=key_padding_mask, need_weights=need_weights,
attn_mask=attn_mask)

Просмотреть файл

@ -9,7 +9,7 @@ The support remains limited. Known limitations include:
- The code contains duplicates. Needs refactor.
"""
from typing import List, Tuple, Optional
from typing import List, Tuple, Optional, cast
import torch
import torch.nn as nn
@ -94,7 +94,7 @@ class ProxylessMixedLayer(DifferentiableMixedLayer):
self._sample_idx = self.op_names.index(self._sampled)
else:
probs = self._softmax(self._arch_alpha)
self._sample_idx = torch.multinomial(probs, 1)[0].item()
self._sample_idx = int(torch.multinomial(probs, 1)[0].item())
self._sampled = self.op_names[self._sample_idx]
# set binary gates
@ -109,10 +109,11 @@ class ProxylessMixedLayer(DifferentiableMixedLayer):
"""Chose the argmax if label isn't found in memo."""
if self.label in memo:
return {} # nothing new to export
return {self.label: self.op_names[torch.argmax(self._arch_alpha).item()]}
return {self.label: self.op_names[int(torch.argmax(self._arch_alpha).item())]}
def finalize_grad(self):
binary_grads = self._binary_gates.grad
assert binary_grads is not None
with torch.no_grad():
if self._arch_alpha.grad is None:
self._arch_alpha.grad = torch.zeros_like(self._arch_alpha.data)
@ -164,13 +165,13 @@ class ProxylessMixedInput(DifferentiableMixedInput):
else:
probs = self._softmax(self._arch_alpha)
sample = torch.multinomial(probs, 1)[0].item()
self._sampled = sample
self._sampled = int(sample)
# set binary gates
with torch.no_grad():
self._binary_gates.zero_()
self._binary_gates.grad = torch.zeros_like(self._binary_gates.data)
self._binary_gates.data[sample] = 1.0
self._binary_gates.data[cast(int, self._sampled)] = 1.0
return {self.label: self._sampled}
@ -182,6 +183,7 @@ class ProxylessMixedInput(DifferentiableMixedInput):
def finalize_grad(self):
binary_grads = self._binary_gates.grad
assert binary_grads is not None
with torch.no_grad():
if self._arch_alpha.grad is None:
self._arch_alpha.grad = torch.zeros_like(self._arch_alpha.data)

Просмотреть файл

@ -129,6 +129,8 @@ class PathSamplingInput(BaseSuperNetModule):
if isinstance(module, InputChoice):
if module.reduction not in ['sum', 'mean', 'concat']:
raise ValueError('Only input choice of sum/mean/concat reduction is supported.')
if module.n_chosen is None:
raise ValueError('n_chosen is None is not supported yet.')
return cls(module.n_candidates, module.n_chosen, module.reduction, module.label)
def forward(self, input_tensors):
@ -161,7 +163,7 @@ class MixedOpPathSamplingPolicy(MixedOperationSamplingPolicy):
# Sampling arguments. This should have the same keys with `operation.mutable_arguments`
self._sampled: Optional[Dict[str, Any]] = None
def resample(self, operation: MixedOperation, memo: Dict[str, Any] = None) -> Dict[str, Any]:
def resample(self, operation: MixedOperation, memo: Dict[str, Any]) -> Dict[str, Any]:
"""Random sample for each leaf value choice."""
result = {}
space_spec = operation.search_space_spec()
@ -179,7 +181,7 @@ class MixedOpPathSamplingPolicy(MixedOperationSamplingPolicy):
return result
def export(self, operation: MixedOperation, memo: Dict[str, Any] = None) -> Dict[str, Any]:
def export(self, operation: MixedOperation, memo: Dict[str, Any]) -> Dict[str, Any]:
"""Export is also random for each leaf value choice."""
result = {}
space_spec = operation.search_space_spec()

Просмотреть файл

@ -132,7 +132,7 @@ def _replace_module_with_type(root_module, init_fn, type_name, modules):
for name, child in m.named_children():
if isinstance(child, type_name):
setattr(m, name, init_fn(child))
modules.append((child.key, getattr(m, name)))
modules.append((child.label, getattr(m, name)))
else:
apply(child)

Просмотреть файл

@ -1,7 +1,7 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.
from typing import (Any, Dict, List)
from typing import (Any, Dict, List, Optional, cast)
from . import debug_configs
@ -34,6 +34,8 @@ class Operation:
Arbitrary key-value parameters (e.g. kernel_size).
"""
io_names: List[str] = []
def __init__(self, type_name: str, parameters: Dict[str, Any] = {}, _internal: bool = False, attributes: Dict[str, Any] = {}):
assert _internal, '`Operation()` is private, use `Operation.new()` instead'
self.type: str = type_name
@ -43,7 +45,7 @@ class Operation:
def to_init_code(self, field: str) -> str:
raise NotImplementedError()
def to_forward_code(self, field: str, output: str, inputs: List[str]) -> str:
def to_forward_code(self, field: str, output: str, inputs: List[str], inputs_value: List[Any]) -> str:
raise NotImplementedError()
def _to_class_name(self) -> str:
@ -53,8 +55,8 @@ class Operation:
return True
@staticmethod
def new(type_name: str, parameters: Dict[str, Any] = None, cell_name: str = None,
attributes: Dict[str, Any] = None) -> 'Operation':
def new(type_name: str, parameters: Dict[str, Any] = cast(Dict[str, Any], None), cell_name: str = cast(str, None),
attributes: Dict[str, Any] = cast(Dict[str, Any], None)) -> 'Operation':
parameters = parameters or {}
attributes = attributes or {}
if type_name == '_cell':
@ -98,16 +100,16 @@ class PyTorchOperation(Operation):
subclass_name = 'FunctionalOperator'
for subclass in cls.__subclasses__():
if hasattr(subclass, '_ori_type_name') and \
subclass_name in subclass._ori_type_name:
subclass_name in cast(Any, subclass)._ori_type_name:
return subclass
for subclass in cls.__subclasses__():
if hasattr(subclass, '_artificial_op_name') and \
subclass_name in subclass._artificial_op_name:
subclass_name in cast(Any, subclass)._artificial_op_name:
return subclass
return cls
@classmethod
def to_class_name(cls, type_name) -> str:
def to_class_name(cls, type_name) -> Optional[str]:
if type_name.startswith('__torch__.'):
return type_name[len('__torch__.'):]
elif type_name.startswith('__mutated__.'):
@ -119,7 +121,7 @@ class PyTorchOperation(Operation):
def is_functional(cls, type_name) -> bool:
return type_name.startswith('Function.')
def _to_class_name(self) -> str:
def _to_class_name(self) -> Optional[str]:
if self.type.startswith('__torch__.'):
return self.type[len('__torch__.'):]
elif self.type.startswith('__mutated__.'):
@ -127,7 +129,7 @@ class PyTorchOperation(Operation):
else:
return None
def get_import_pkg(self) -> str:
def get_import_pkg(self) -> Optional[str]:
if self.type.startswith('__torch__.'):
return self.type[len('__torch__.'):].split('.')[0]
elif self.type.startswith('__mutated__.'):
@ -135,14 +137,14 @@ class PyTorchOperation(Operation):
else:
return None
def to_init_code(self, field: str) -> str:
def to_init_code(self, field: str) -> Optional[str]:
if self._to_class_name() is not None:
assert 'positional_args' not in self.parameters
kw_params = ', '.join(f'{key}={repr(value)}' for key, value in self.parameters.items())
return f'self.{field} = {self._to_class_name()}({kw_params})'
return None
def to_forward_code(self, field: str, output: str, inputs: List[str], inputs_value: List[Any] = None) -> str:
def to_forward_code(self, field: str, output: str, inputs: List[str], inputs_value: List[Any]) -> str:
"""
Parameters
----------
@ -207,7 +209,9 @@ class Cell(PyTorchOperation):
No real usage. Exists for compatibility with base class.
"""
def __init__(self, cell_name: str, parameters: Dict[str, Any] = None, attributes: Dict[str, Any] = None):
def __init__(self, cell_name: str,
parameters: Dict[str, Any] = cast(Dict[str, Any], None),
attributes: Dict[str, Any] = cast(Dict[str, Any], None)):
self.type = '_cell'
self.cell_name = cell_name
self.parameters = parameters or {}
@ -217,7 +221,7 @@ class Cell(PyTorchOperation):
# TODO: ugly, think about how to refactor this part
return _convert_name(self.cell_name)
def to_forward_code(self, field: str, output: str, inputs: List[str], inputs_value: List[Any] = None) -> str:
def to_forward_code(self, field: str, output: str, inputs: List[str], inputs_value: List[Any]) -> str:
return f'{output} = self.{field}({", ".join(inputs)})'
class _IOPseudoOperation(Operation):
@ -227,7 +231,7 @@ class _IOPseudoOperation(Operation):
especially in static type checking.
"""
def __init__(self, type_name: str, io_names: List = None):
def __init__(self, type_name: str, io_names: List[str] = cast(List[str], None)):
assert type_name.startswith('_')
super(_IOPseudoOperation, self).__init__(type_name, {}, True)
self.io_names = io_names
@ -235,7 +239,7 @@ class _IOPseudoOperation(Operation):
def to_init_code(self, field: str) -> str:
raise ValueError(f'Cannot generate code for pseudo operation "{self.type}"')
def to_forward_code(self, field: str, output: str, inputs: List[str]) -> str:
def to_forward_code(self, field: str, output: str, inputs: List[str], inputs_value: List[Any]) -> str:
raise ValueError(f'Cannot generate code for pseudo operation "{self.type}"')
def __bool__(self) -> bool:

Просмотреть файл

@ -1,3 +1,6 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.
"""
Definition of operation types.

Просмотреть файл

@ -1,9 +1,12 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.
from __future__ import annotations
from typing import (Any, Dict, List)
import torch
import torch.nn.functional as nn_functional
from ..operation import PyTorchOperation
@ -39,23 +42,23 @@ class NoOpIdentity(PyTorchOperation):
"""
_ori_type_name = ['noop_identity']
def to_forward_code(self, field: str, output: str, inputs: List[str], inputs_value: List[Any] = None) -> str:
def to_forward_code(self, field: str, output: str, inputs: List[str], inputs_value: List[Any]) -> str:
return f'{output} = {", ".join(inputs)}'
class ModuleOperator(PyTorchOperation):
_ori_type_name = ['ModuleOperator', 'shared']
def to_forward_code(self, field: str, output: str, inputs: List[str], inputs_value: List[Any] = None) -> str:
def to_forward_code(self, field: str, output: str, inputs: List[str], inputs_value: List[Any]) -> str:
return f'{output} = self.{field}({", ".join(inputs)})'
class FunctionalOperator(PyTorchOperation):
_ori_type_name = ['FunctionalOperator']
def to_forward_code(self, field: str, output: str, inputs: List[str], inputs_value: List[Any] = None) -> str:
def to_forward_code(self, field: str, output: str, inputs: List[str], inputs_value: List[Any]) -> str:
func_name = self.type[len('Function.'):]
if not hasattr(torch.nn.functional, func_name):
if not hasattr(nn_functional, func_name):
raise RuntimeError('For now, we only support calling independent functions from `torch.nn.functional`, '
f'{func_name} is not in it.')
return f'{output} = F.{func_name}({", ".join(inputs)})'
@ -64,7 +67,7 @@ class FunctionalOperator(PyTorchOperation):
class PrimConstant(PyTorchOperation):
_ori_type_name = ['prim::Constant']
def to_forward_code(self, field: str, output: str, inputs: List[str], inputs_value: List[Any] = None) -> str:
def to_forward_code(self, field: str, output: str, inputs: List[str], inputs_value: List[Any]) -> str:
# TODO: refactor this part, maybe we can remove the code gen of prim::Constant
# TODO: deal with all the types
if self.parameters['type'] in ['None', 'NoneType']:
@ -87,28 +90,28 @@ class PrimConstant(PyTorchOperation):
class PrimListConstruct(PyTorchOperation):
_ori_type_name = ['prim::ListConstruct']
def to_forward_code(self, field: str, output: str, inputs: List[str], inputs_value: List[Any] = None) -> str:
def to_forward_code(self, field: str, output: str, inputs: List[str], inputs_value: List[Any]) -> str:
return f'{output} = [{", ".join(inputs)}]'
class PrimListUnpack(PyTorchOperation):
_ori_type_name = ['prim::ListUnpack']
def to_forward_code(self, field: str, output: str, inputs: List[str], inputs_value: List[Any] = None) -> str:
def to_forward_code(self, field: str, output: str, inputs: List[str], inputs_value: List[Any]) -> str:
return f'{output} = {inputs[0]}'
class PrimTupleConstruct(PyTorchOperation):
_ori_type_name = ['prim::TupleConstruct']
def to_forward_code(self, field: str, output: str, inputs: List[str], inputs_value: List[Any] = None) -> str:
def to_forward_code(self, field: str, output: str, inputs: List[str], inputs_value: List[Any]) -> str:
return f'{output} = ({", ".join(inputs)})'
class PrimTupleUnpack(PyTorchOperation):
_ori_type_name = ['prim::TupleUnpack']
def to_forward_code(self, field: str, output: str, inputs: List[str], inputs_value: List[Any] = None) -> str:
def to_forward_code(self, field: str, output: str, inputs: List[str], inputs_value: List[Any]) -> str:
# have single output here, because the following code uses index to access the unpacked values
assert len(inputs) == 1
return f'{output} = {inputs[0]}'
@ -117,7 +120,7 @@ class PrimTupleUnpack(PyTorchOperation):
class PrimGetAttr(PyTorchOperation):
_ori_type_name = ['prim::GetAttr']
def to_forward_code(self, field: str, output: str, inputs: List[str], inputs_value: List[Any] = None) -> str:
def to_forward_code(self, field: str, output: str, inputs: List[str], inputs_value: List[Any]) -> str:
if self.parameters['value'] is not None:
return f"{output} = {self.parameters['value']}"
else:
@ -127,14 +130,14 @@ class PrimGetAttr(PyTorchOperation):
class PrimUncheckedCast(PyTorchOperation):
_ori_type_name = ['prim::unchecked_cast']
def to_forward_code(self, field: str, output: str, inputs: List[str], inputs_value: List[Any] = None) -> str:
def to_forward_code(self, field: str, output: str, inputs: List[str], inputs_value: List[Any]) -> str:
return f'{output} = {inputs[0]}'
class SimpleMember(PyTorchOperation):
_ori_type_name = ['prim::is_cuda', 'prim::data']
def to_forward_code(self, field: str, output: str, inputs: List[str], inputs_value: List[Any] = None) -> str:
def to_forward_code(self, field: str, output: str, inputs: List[str], inputs_value: List[Any]) -> str:
member_name = self.type.split('::')[-1]
return f'{output} = {inputs[0]}.{member_name}'
@ -142,16 +145,16 @@ class SimpleMember(PyTorchOperation):
class AtenContiguous(PyTorchOperation):
_ori_type_name = ['aten::contiguous']
def to_forward_code(self, field: str, output: str, inputs: List[str], inputs_value: List[Any] = None) -> str:
def to_forward_code(self, field: str, output: str, inputs: List[str], inputs_value: List[Any]) -> str:
# defined in pytorch/c10/core/MemoryFormat.h
assert inputs_value[1] in [0, 1, 2]
assert inputs_value is not None and inputs_value[1] in [0, 1, 2]
return f'{output} = {inputs[0]}.contiguous(memory_format={mem_format[inputs_value[1]]})'
class AtenGetitem(PyTorchOperation):
_ori_type_name = ['aten::__getitem__']
def to_forward_code(self, field: str, output: str, inputs: List[str], inputs_value: List[Any] = None) -> str:
def to_forward_code(self, field: str, output: str, inputs: List[str], inputs_value: List[Any]) -> str:
assert len(inputs) == 2
return f'{output} = {inputs[0]}[{inputs[1]}]'
@ -159,7 +162,7 @@ class AtenGetitem(PyTorchOperation):
class AtenAppend(PyTorchOperation):
_ori_type_name = ['aten::append']
def to_forward_code(self, field: str, output: str, inputs: List[str], inputs_value: List[Any] = None) -> str:
def to_forward_code(self, field: str, output: str, inputs: List[str], inputs_value: List[Any]) -> str:
assert len(inputs) == 2
return f'_, {output} = {inputs[0]}.append({inputs[1]}), {inputs[0]}'
@ -167,7 +170,7 @@ class AtenAppend(PyTorchOperation):
class MergedSlice(PyTorchOperation):
_ori_type_name = ['MergedSlice']
def to_forward_code(self, field: str, output: str, inputs: List[str], inputs_value: List[Any] = None) -> str:
def to_forward_code(self, field: str, output: str, inputs: List[str], inputs_value: List[Any]) -> str:
if (len(inputs) - 1) % 4 == 0:
slices = []
dim = int((len(inputs) - 1) / 4)
@ -187,21 +190,21 @@ class MergedSlice(PyTorchOperation):
class AtenBool(PyTorchOperation):
_ori_type_name = ['aten::Bool']
def to_forward_code(self, field: str, output: str, inputs: List[str], inputs_value: List[Any] = None) -> str:
def to_forward_code(self, field: str, output: str, inputs: List[str], inputs_value: List[Any]) -> str:
return f'{output} = bool({inputs[0]})'
class AtenNot(PyTorchOperation):
_ori_type_name = ['aten::__not__']
def to_forward_code(self, field: str, output: str, inputs: List[str], inputs_value: List[Any] = None) -> str:
def to_forward_code(self, field: str, output: str, inputs: List[str], inputs_value: List[Any]) -> str:
return f'{output} = not {inputs[0]}'
class AtenCat(PyTorchOperation):
_ori_type_name = ['aten::cat']
def to_forward_code(self, field: str, output: str, inputs: List[str], inputs_value: List[Any] = None) -> str:
def to_forward_code(self, field: str, output: str, inputs: List[str], inputs_value: List[Any]) -> str:
assert len(inputs) == 2
return f'{output} = torch.cat({inputs[0]}, dim={inputs[1]})'
@ -215,7 +218,7 @@ class AtenTensors(PyTorchOperation):
'aten::new_empty', 'aten::new_zeros', 'aten::arange',
'aten::tensor', 'aten::ones', 'aten::zeros', 'aten::as_tensor']
def to_forward_code(self, field: str, output: str, inputs: List[str], inputs_value: List[Any] = None) -> str:
def to_forward_code(self, field: str, output: str, inputs: List[str], inputs_value: List[Any]) -> str:
schemas = torch._C._jit_get_schemas_for_operator(self.type)
# match number of inputs
overloaded_defs = [len(s.arguments) for s in schemas]
@ -257,40 +260,41 @@ class AtenTensors(PyTorchOperation):
class AtenFloordiv(PyTorchOperation):
_ori_type_name = ['aten::floordiv']
def to_forward_code(self, field: str, output: str, inputs: List[str], inputs_value: List[Any] = None) -> str:
def to_forward_code(self, field: str, output: str, inputs: List[str], inputs_value: List[Any]) -> str:
return f'{output} = {inputs[0]} // {inputs[1]}'
class AtenMul(PyTorchOperation):
_ori_type_name = ['aten::mul']
def to_forward_code(self, field: str, output: str, inputs: List[str], inputs_value: List[Any] = None) -> str:
def to_forward_code(self, field: str, output: str, inputs: List[str], inputs_value: List[Any]) -> str:
return f'{output} = {inputs[0]} * {inputs[1]}'
class AtenLen(PyTorchOperation):
_ori_type_name = ['aten::len']
def to_forward_code(self, field: str, output: str, inputs: List[str], inputs_value: List[Any] = None) -> str:
def to_forward_code(self, field: str, output: str, inputs: List[str], inputs_value: List[Any]) -> str:
return f'{output} = len({inputs[0]})'
class AtenIntImplicit(PyTorchOperation):
_ori_type_name = ['aten::IntImplicit', 'aten::Float', 'aten::Int', 'aten::ScalarImplicit']
def to_forward_code(self, field: str, output: str, inputs: List[str], inputs_value: List[Any] = None) -> str:
def to_forward_code(self, field: str, output: str, inputs: List[str], inputs_value: List[Any]) -> str:
if self.type.endswith('Implicit'):
return f'{output} = {inputs[0]}'
elif self.type == 'aten::Int':
return f'{output} = int({inputs[0]})'
elif self.type == 'aten::Float':
return f'{output} = float({inputs[0]})'
raise TypeError(f'Unexpected type: {self.type}')
class AtenIndex(PyTorchOperation):
_ori_type_name = ['aten::index']
def to_forward_code(self, field: str, output: str, inputs: List[str], inputs_value: List[Any] = None) -> str:
def to_forward_code(self, field: str, output: str, inputs: List[str], inputs_value: List[Any]) -> str:
return f'{output} = {inputs[0]}[{inputs[1]}]'
@ -355,13 +359,13 @@ def _get_tensor_ops():
def _get_torch_ops():
torch_op_args = {}
for mod in torch.jit._builtins._modules_containing_builtins:
for mod in torch.jit._builtins._modules_containing_builtins: # type: ignore
name = mod.__name__
if name == 'torch._C._nn':
continue
# only process 'torch.XXX'
for elem in dir(mod):
builtin = torch.jit._builtins._find_builtin(getattr(mod, elem))
builtin = torch.jit._builtins._find_builtin(getattr(mod, elem)) # type: ignore
if builtin is not None:
schemas = torch._C._jit_get_schemas_for_operator(builtin)
for schema in schemas:
@ -436,7 +440,7 @@ class TensorOps(PyTorchOperation):
return None
raise RuntimeError(f'tensor op type {_type} has no matched')
def to_forward_code(self, field: str, output: str, inputs: List[str], inputs_value: List[Any] = None) -> str:
def to_forward_code(self, field: str, output: str, inputs: List[str], inputs_value: List[Any]) -> str:
# TODO: deal with conditional ops
if self.type in TensorOps.comparison_ops:
return f'{output} = ({inputs[0]} {TensorOps.comparison_ops[self.type]} {inputs[1]})'
@ -486,7 +490,7 @@ class TorchOps(PyTorchOperation):
else:
raise RuntimeError(f'torch op type {_type} has no matched')
def to_forward_code(self, field: str, output: str, inputs: List[str], inputs_value: List[Any] = None) -> str:
def to_forward_code(self, field: str, output: str, inputs: List[str], inputs_value: List[Any]) -> str:
matched_args = TorchOps._get_matched_args(self.type, inputs)
op_name = self.type.split('::')[-1]
args_str = ', '.join([f'{name}={inputs[i]}' if t.startswith('Optional[') else f'{inputs[i]}'
@ -498,7 +502,7 @@ class AtenAvgpool2d(PyTorchOperation):
# NOTE: it is not included in the above aten ops for unkown reason
_ori_type_name = ['aten::avg_pool2d']
def to_forward_code(self, field: str, output: str, inputs: List[str], inputs_value: List[Any] = None) -> str:
def to_forward_code(self, field: str, output: str, inputs: List[str], inputs_value: List[Any]) -> str:
return f'{output} = F.avg_pool2d({", ".join(inputs)})'
@ -506,7 +510,7 @@ class ToDevice(PyTorchOperation):
_artificial_op_name = "ToDevice"
def __init__(self, type_name: str, parameters: Dict[str, Any], _internal: bool = False,
attributes: Dict[str, Any] = None):
attributes: Dict[str, Any] = {}):
self.type = "ToDevice"
self.device = parameters['device']
self.overridden_device_repr = None
@ -540,5 +544,5 @@ class AtenDet(PyTorchOperation):
# NOTE: it is not included in the above aten ops, maybe because torch.det is alias for torch.linalg.det
_ori_type_name = ['aten::linalg_det']
def to_forward_code(self, field: str, output: str, inputs: List[str], inputs_value: List[Any] = None) -> str:
def to_forward_code(self, field: str, output: str, inputs: List[str], inputs_value: List[Any]) -> str:
return f'{output} = torch.det({inputs[0]})'

Просмотреть файл

@ -4,9 +4,9 @@
import inspect
import os
import warnings
from typing import Any, TypeVar, Union
from typing import Any, TypeVar, Type
from nni.common.serializer import Traceable, is_traceable, is_wrapped_with_trace, trace, _copy_class_wrapper_attributes
from nni.common.serializer import is_traceable, is_wrapped_with_trace, trace, _copy_class_wrapper_attributes
from .utils import ModelNamespace
__all__ = ['get_init_parameters_or_fail', 'serialize', 'serialize_cls', 'basic_unit', 'model_wrapper',
@ -48,7 +48,7 @@ def serialize_cls(cls):
return trace(cls)
def basic_unit(cls: T, basic_unit_tag: bool = True) -> Union[T, Traceable]:
def basic_unit(cls: T, basic_unit_tag: bool = True) -> T:
"""
To wrap a module as a basic unit, is to make it a primitive and stop the engine from digging deeper into it.
@ -75,17 +75,17 @@ def basic_unit(cls: T, basic_unit_tag: bool = True) -> Union[T, Traceable]:
return cls
import torch.nn as nn
assert issubclass(cls, nn.Module), 'When using @basic_unit, the class must be a subclass of nn.Module.'
assert issubclass(cls, nn.Module), 'When using @basic_unit, the class must be a subclass of nn.Module.' # type: ignore
cls = trace(cls)
cls._nni_basic_unit = basic_unit_tag
cls._nni_basic_unit = basic_unit_tag # type: ignore
_torchscript_patch(cls)
return cls
def model_wrapper(cls: T) -> Union[T, Traceable]:
def model_wrapper(cls: T) -> T:
"""
Wrap the base model (search space). For example,
@ -113,7 +113,7 @@ def model_wrapper(cls: T) -> Union[T, Traceable]:
return cls
import torch.nn as nn
assert issubclass(cls, nn.Module)
assert issubclass(cls, nn.Module) # type: ignore
# subclass can still use trace info
wrapper = trace(cls, inheritable=True)
@ -146,7 +146,7 @@ def is_model_wrapped(cls_or_instance) -> bool:
return getattr(cls_or_instance, '_nni_model_wrapper', False)
def _check_wrapped(cls: T, rewrap: str) -> bool:
def _check_wrapped(cls: Type, rewrap: str) -> bool:
wrapped = None
if is_model_wrapped(cls):
wrapped = 'model_wrapper'

Просмотреть файл

@ -1,19 +1,26 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.
# This file might cause import error for those who didn't install RL-related dependencies
import logging
import threading
from multiprocessing.pool import ThreadPool
from typing import Tuple
import gym
import numpy as np
import tianshou
import torch
import torch.nn as nn
import torch.nn.functional as F
from gym import spaces
from tianshou.data import to_torch
from tianshou.env.worker import EnvWorker
from nni.typehint import TypedDict
from .utils import get_targeted_model
from ..graph import ModelStatus
from ..execution import submit_models, wait_models
@ -76,8 +83,13 @@ class MultiThreadEnvWorker(EnvWorker):
self.pool.terminate()
return self.env.close()
class ObservationType(TypedDict):
action_history: np.ndarray
cur_step: int
action_dim: int
class ModelEvaluationEnv(gym.Env):
class ModelEvaluationEnv(gym.Env[ObservationType, int]):
def __init__(self, base_model, mutators, search_space):
self.base_model = base_model
self.mutators = mutators
@ -98,7 +110,7 @@ class ModelEvaluationEnv(gym.Env):
def action_space(self):
return spaces.Discrete(self.action_dim)
def reset(self):
def reset(self) -> ObservationType:
self.action_history = np.zeros(self.num_steps, dtype=np.int32)
self.cur_step = 0
self.sample = {}
@ -108,14 +120,14 @@ class ModelEvaluationEnv(gym.Env):
'action_dim': len(self.search_space[self.ss_keys[self.cur_step]])
}
def step(self, action):
def step(self, action: int) -> Tuple[ObservationType, float, bool, dict]:
cur_key = self.ss_keys[self.cur_step]
assert action < len(self.search_space[cur_key]), \
f'Current action {action} out of range {self.search_space[cur_key]}.'
self.action_history[self.cur_step] = action
self.sample[cur_key] = self.search_space[cur_key][action]
self.cur_step += 1
obs = {
obs: ObservationType = {
'action_history': self.action_history,
'cur_step': self.cur_step,
'action_dim': len(self.search_space[self.ss_keys[self.cur_step]]) \
@ -129,7 +141,7 @@ class ModelEvaluationEnv(gym.Env):
wait_models(model)
if model.status == ModelStatus.Failed:
return self.reset(), 0., False, {}
rew = model.metric
rew = float(model.metric)
_logger.info(f'Model metric received as reward: {rew}')
return obs, rew, True, {}
else:
@ -147,7 +159,7 @@ class Preprocessor(nn.Module):
self.rnn = nn.LSTM(hidden_dim, hidden_dim, num_layers, batch_first=True)
def forward(self, obs):
seq = nn.functional.pad(obs['action_history'] + 1, (1, 1)) # pad the start token and end token
seq = F.pad(obs['action_history'] + 1, (1, 1)) # pad the start token and end token
# end token is used to avoid out-of-range of v_s_. Will not actually affect BP.
seq = self.embedding(seq.long())
feature, _ = self.rnn(seq)
@ -167,7 +179,7 @@ class Actor(nn.Module):
# to take care of choices with different number of options
mask = torch.arange(self.action_dim).expand(len(out), self.action_dim) >= obs['action_dim'].unsqueeze(1)
out[mask.to(out.device)] = float('-inf')
return nn.functional.softmax(out, dim=-1), kwargs.get('state', None)
return F.softmax(out, dim=-1), kwargs.get('state', None)
class Critic(nn.Module):

Просмотреть файл

@ -14,5 +14,5 @@ class BaseStrategy(abc.ABC):
def run(self, base_model: Model, applied_mutators: List[Mutator]) -> None:
pass
def export_top_models(self) -> List[Any]:
def export_top_models(self, top_k: int) -> List[Any]:
raise NotImplementedError('"export_top_models" is not implemented.')

Просмотреть файл

@ -6,7 +6,7 @@ import itertools
import logging
import random
import time
from typing import Any, Dict, List
from typing import Any, Dict, List, Sequence, Optional
from .. import InvalidMutation, Sampler, submit_models, query_available_resources, budget_exhausted
from .base import BaseStrategy
@ -30,6 +30,7 @@ def random_generator(search_space: Dict[Any, List[Any]], dedup=True, retries=500
history = set()
search_space_values = copy.deepcopy(list(search_space.values()))
while True:
selected: Optional[Sequence[int]] = None
for retry_count in range(retries):
selected = [random.choice(v) for v in search_space_values]
if not dedup:
@ -41,6 +42,7 @@ def random_generator(search_space: Dict[Any, List[Any]], dedup=True, retries=500
if retry_count + 1 == retries:
_logger.debug('Random generation has run out of patience. There is nothing to search. Exiting.')
return
assert selected is not None, 'Retry attempts exhausted.'
yield {key: value for key, value in zip(keys, selected)}

Просмотреть файл

@ -1,3 +1,6 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.
import logging
from typing import Optional, Callable

Просмотреть файл

@ -3,6 +3,7 @@
import logging
import time
from typing import Optional
from nni.algorithms.hpo.hyperopt_tuner import HyperoptTuner
@ -15,8 +16,8 @@ _logger = logging.getLogger(__name__)
class TPESampler(Sampler):
def __init__(self, optimize_mode='minimize'):
self.tpe_tuner = HyperoptTuner('tpe', optimize_mode)
self.cur_sample = None
self.index = None
self.cur_sample: Optional[dict] = None
self.index: Optional[int] = None
self.total_parameters = {}
def update_sample_space(self, sample_space):
@ -34,6 +35,7 @@ class TPESampler(Sampler):
self.tpe_tuner.receive_trial_result(model_id, self.total_parameters[model_id], result)
def choice(self, candidates, mutator, model, index):
assert isinstance(self.index, int) and isinstance(self.cur_sample, dict)
chosen = self.cur_sample[str(self.index)]
self.index += 1
return chosen

Просмотреть файл

@ -1,7 +1,6 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.
import collections
import logging
from typing import Dict, Any, List

Просмотреть файл

@ -25,4 +25,6 @@ if __name__ == '__main__':
elif args.exec == 'benchmark':
from .execution.benchmark import BenchmarkExecutionEngine
engine = BenchmarkExecutionEngine
else:
raise ValueError(f'Unrecognized benchmark name: {args.exec}')
engine.trial_execute_graph()

Просмотреть файл

@ -6,7 +6,7 @@ import itertools
import warnings
from collections import defaultdict
from contextlib import contextmanager
from typing import Any, List, Dict
from typing import Any, List, Dict, cast
from pathlib import Path
from nni.common.hpo_utils import ParameterSpec
@ -41,9 +41,10 @@ def get_module_name(cls_or_func):
if module_name == '__main__':
# infer the module name with inspect
for frm in inspect.stack():
if inspect.getmodule(frm[0]).__name__ == '__main__':
module = inspect.getmodule(frm[0])
if module is not None and module.__name__ == '__main__':
# main module found
main_file_path = Path(inspect.getsourcefile(frm[0]))
main_file_path = Path(cast(str, inspect.getsourcefile(frm[0])))
if not Path().samefile(main_file_path.parent):
raise RuntimeError(f'You are using "{main_file_path}" to launch your experiment, '
f'please launch the experiment under the directory where "{main_file_path.name}" is located.')
@ -227,6 +228,7 @@ def original_state_dict_hooks(model: Any):
supernet_style_state_dict = model.state_dict()
"""
import torch.utils.hooks
import torch.nn as nn
assert isinstance(model, nn.Module), 'PyTorch is the only supported framework for now.'
@ -297,8 +299,8 @@ def original_state_dict_hooks(model: Any):
raise KeyError(f'"{src}" not in state dict, but found in mapping.')
destination.update(result)
hooks: List[torch.utils.hooks.RemovableHandle] = []
try:
hooks = []
hooks.append(model._register_load_state_dict_pre_hook(load_state_dict_hook))
hooks.append(model._register_state_dict_hook(state_dict_hook))
yield

Просмотреть файл

@ -6,7 +6,7 @@ Types for static checking.
"""
__all__ = [
'Literal',
'Literal', 'TypedDict',
'Parameters', 'SearchSpace', 'TrialMetric', 'TrialRecord',
]

Просмотреть файл

@ -64,6 +64,9 @@ stages:
python -m pip install "typing-extensions>=3.10"
displayName: Resolve dependency version
- script: python test/vso_tools/trigger_import.py
displayName: Trigger import
- script: |
python -m pylint --rcfile pylintrc nni
displayName: pylint

Просмотреть файл

@ -3,10 +3,13 @@
"nni/algorithms",
"nni/common/device.py",
"nni/common/graph_utils.py",
"nni/common/serializer.py",
"nni/compression",
"nni/nas",
"nni/retiarii",
"nni/nas/tensorflow",
"nni/nas/pytorch",
"nni/retiarii/execution/cgo_engine.py",
"nni/retiarii/execution/logical_optimizer",
"nni/retiarii/evaluator/pytorch/cgo",
"nni/retiarii/oneshot",
"nni/smartparam.py",
"nni/tools/annotation",
"nni/tools/gpu_tool",
@ -14,5 +17,6 @@
"nni/tools/nnictl",
"nni/tools/trial_tool"
],
"reportMissingImports": false
"reportMissingImports": false,
"reportPrivateImportUsage": false
}

Просмотреть файл

@ -4,4 +4,5 @@ filterwarnings =
ignore:Using key to access the identifier of:DeprecationWarning
ignore:layer_choice.choices is deprecated.:DeprecationWarning
ignore:The truth value of an empty array is ambiguous.:DeprecationWarning
ignore:`np.bool` is a deprecated alias for the builtin `bool`:DeprecationWarning
ignore:nni.retiarii.serialize is deprecated and will be removed in future release.:DeprecationWarning

Просмотреть файл

@ -36,5 +36,9 @@
{"head": ["conv2", null], "tail": ["pool2", null]},
{"head": ["pool2", null], "tail": ["_outputs", 0]}
]
},
"_evaluator": {
"type": "DebugEvaluator"
}
}

Просмотреть файл

@ -9,6 +9,7 @@ from nni.retiarii.codegen import model_to_pytorch_script
from nni.retiarii.execution import set_execution_engine
from nni.retiarii.execution.base import BaseExecutionEngine
from nni.retiarii.execution.python import PurePythonExecutionEngine
from nni.retiarii.graph import DebugEvaluator
from nni.retiarii.integration import RetiariiAdvisor
@ -51,6 +52,7 @@ class EngineTest(unittest.TestCase):
'edges': []
}
})
model.evaluator = DebugEvaluator()
model.python_class = object
submit_models(model, model)

Просмотреть файл

@ -0,0 +1,10 @@
"""Trigger import of some modules to write some caches,
so that static analysis (e.g., pyright) can know the type."""
import os
import sys
sys.path.insert(0, os.path.join(os.path.dirname(__file__), '../../'))
import nni
import nni.retiarii.nn.pytorch