зеркало из https://github.com/microsoft/nni.git
Add license header and typehints for NAS (#4774)
This commit is contained in:
Родитель
8c2f717d83
Коммит
1896212902
|
@ -19,5 +19,5 @@ scikit-learn >= 0.24.1
|
|||
scipy < 1.8 ; python_version < "3.8"
|
||||
scipy ; python_version >= "3.8"
|
||||
typeguard
|
||||
typing_extensions >= 4.0.0 ; python_version < "3.8"
|
||||
typing_extensions >= 4.0.0
|
||||
websockets >= 10.1
|
||||
|
|
|
@ -13,7 +13,7 @@ import sys
|
|||
import types
|
||||
import warnings
|
||||
from io import IOBase
|
||||
from typing import Any, Dict, List, Optional, TypeVar, Union
|
||||
from typing import Any, Dict, List, Optional, Type, TypeVar, Union, cast, Generic
|
||||
|
||||
import cloudpickle # use cloudpickle as backend for unserializable types and instances
|
||||
import json_tricks # use json_tricks as serializer backend
|
||||
|
@ -115,7 +115,7 @@ def is_wrapped_with_trace(cls_or_func: Any) -> bool:
|
|||
)
|
||||
|
||||
|
||||
class SerializableObject(Traceable):
|
||||
class SerializableObject(Generic[T], Traceable):
|
||||
"""
|
||||
Serializable object is a wrapper of existing python objects, that supports dump and load easily.
|
||||
Stores a symbol ``s`` and a dict of arguments ``args``, and the object can be restored with ``s(**args)``.
|
||||
|
@ -147,7 +147,7 @@ class SerializableObject(Traceable):
|
|||
# Reinitialize
|
||||
return trace(self.trace_symbol)(*self.trace_args, **self.trace_kwargs)
|
||||
|
||||
return self
|
||||
return cast(T, self)
|
||||
|
||||
@property
|
||||
def trace_symbol(self) -> Any:
|
||||
|
@ -187,7 +187,7 @@ class SerializableObject(Traceable):
|
|||
')'
|
||||
|
||||
|
||||
def inject_trace_info(obj: Any, symbol: T, args: List[Any], kwargs: Dict[str, Any]) -> Any:
|
||||
def inject_trace_info(obj: Any, symbol: T, args: List[Any], kwargs: Dict[str, Any]) -> T:
|
||||
# If an object is already created, this can be a fix so that the necessary info are re-injected into the object.
|
||||
# Make obj complying with the interface of traceable, though we cannot change its base class.
|
||||
obj.__dict__.update(_nni_symbol=symbol, _nni_args=args, _nni_kwargs=kwargs)
|
||||
|
@ -233,11 +233,11 @@ def _make_class_traceable(cls: T, create_wrapper: bool = False) -> T:
|
|||
else:
|
||||
# sometimes create_wrapper is mandatory, e.g., for built-in types like list/int.
|
||||
# but I don't want to check here because it's unreliable.
|
||||
wrapper = type('wrapper', (Traceable, cls), attributes)
|
||||
return wrapper
|
||||
wrapper = type('wrapper', (Traceable, cast(Type, cls)), attributes)
|
||||
return cast(T, wrapper)
|
||||
|
||||
|
||||
def trace(cls_or_func: T = None, *, kw_only: bool = True, inheritable: bool = False) -> Union[T, Traceable]:
|
||||
def trace(cls_or_func: T = cast(T, None), *, kw_only: bool = True, inheritable: bool = False) -> T:
|
||||
"""
|
||||
Annotate a function or a class if you want to preserve where it comes from.
|
||||
This is usually used in the following scenarios:
|
||||
|
@ -283,7 +283,7 @@ def trace(cls_or_func: T = None, *, kw_only: bool = True, inheritable: bool = Fa
|
|||
# Might be changed in future.
|
||||
nni_trace_flag = os.environ.get('NNI_TRACE_FLAG', '')
|
||||
if nni_trace_flag.lower() == 'disable':
|
||||
return cls_or_func
|
||||
return cast(T, cls_or_func)
|
||||
|
||||
def wrap(cls_or_func):
|
||||
# already annotated, do nothing
|
||||
|
@ -301,20 +301,22 @@ def trace(cls_or_func: T = None, *, kw_only: bool = True, inheritable: bool = Fa
|
|||
|
||||
# if we're being called as @trace()
|
||||
if cls_or_func is None:
|
||||
return wrap
|
||||
return wrap # type: ignore
|
||||
|
||||
# if we are called without parentheses
|
||||
return wrap(cls_or_func)
|
||||
return wrap(cls_or_func) # type: ignore
|
||||
|
||||
|
||||
def dump(obj: Any, fp: Optional[Any] = None, *, use_trace: bool = True, pickle_size_limit: int = 4096,
|
||||
allow_nan: bool = True, **json_tricks_kwargs) -> Union[str, bytes]:
|
||||
allow_nan: bool = True, **json_tricks_kwargs) -> str:
|
||||
"""
|
||||
Convert a nested data structure to a json string. Save to file if fp is specified.
|
||||
Use json-tricks as main backend. For unhandled cases in json-tricks, use cloudpickle.
|
||||
The serializer is not designed for long-term storage use, but rather to copy data between processes.
|
||||
The format is also subject to change between NNI releases.
|
||||
|
||||
To compress the payload, please use :func:`dump_bytes`.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
obj : any
|
||||
|
@ -334,6 +336,39 @@ def dump(obj: Any, fp: Optional[Any] = None, *, use_trace: bool = True, pickle_s
|
|||
Normally str. Sometimes bytes (if compressed).
|
||||
"""
|
||||
|
||||
if json_tricks_kwargs.get('compression') is not None:
|
||||
raise ValueError('If you meant to compress the dumped payload, please use `dump_bytes`.')
|
||||
result = _dump(
|
||||
obj=obj,
|
||||
fp=fp,
|
||||
use_trace=use_trace,
|
||||
pickle_size_limit=pickle_size_limit,
|
||||
allow_nan=allow_nan,
|
||||
**json_tricks_kwargs)
|
||||
return cast(str, result)
|
||||
|
||||
|
||||
def dump_bytes(obj: Any, fp: Optional[Any] = None, *, compression: int = cast(int, None),
|
||||
use_trace: bool = True, pickle_size_limit: int = 4096,
|
||||
allow_nan: bool = True, **json_tricks_kwargs) -> bytes:
|
||||
"""
|
||||
Same as :func:`dump`, but to comporess payload, with `compression <https://json-tricks.readthedocs.io/en/stable/#dump>`__.
|
||||
"""
|
||||
if compression is None:
|
||||
raise ValueError('compression must be set.')
|
||||
result = _dump(
|
||||
obj=obj,
|
||||
fp=fp,
|
||||
compression=compression,
|
||||
use_trace=use_trace,
|
||||
pickle_size_limit=pickle_size_limit,
|
||||
allow_nan=allow_nan,
|
||||
**json_tricks_kwargs)
|
||||
return cast(bytes, result)
|
||||
|
||||
|
||||
def _dump(*, obj: Any, fp: Optional[Any], use_trace: bool, pickle_size_limit: int,
|
||||
allow_nan: bool, **json_tricks_kwargs) -> Union[str, bytes]:
|
||||
encoders = [
|
||||
# we don't need to check for dependency as many of those have already been required by NNI
|
||||
json_tricks.pathlib_encode, # pathlib is a required dependency for NNI
|
||||
|
@ -456,7 +491,7 @@ def _trace_cls(base, kw_only, call_super=True, inheritable=False):
|
|||
raise TypeError(f"{base} has a superclass already decorated with trace, and it's using a customized metaclass {type(base)}. "
|
||||
"Please either use the default metaclass, or remove trace from the super-class.")
|
||||
|
||||
class wrapper(SerializableObject, base, metaclass=metaclass):
|
||||
class wrapper(SerializableObject, base, metaclass=metaclass): # type: ignore
|
||||
def __init__(self, *args, **kwargs):
|
||||
# store a copy of initial parameters
|
||||
args, kwargs = _formulate_arguments(base.__init__, args, kwargs, kw_only, is_class_init=True)
|
||||
|
@ -528,7 +563,8 @@ def _trace_func(func, kw_only):
|
|||
# and thus not possible to restore the trace parameters after dump and reload.
|
||||
# this is a known limitation.
|
||||
new_type = _make_class_traceable(type(res), True)
|
||||
res = new_type(res) # re-creating the object
|
||||
# re-creating the object
|
||||
res = new_type(res) # type: ignore
|
||||
res = inject_trace_info(res, func, args, kwargs)
|
||||
else:
|
||||
raise TypeError(f'Try to add trace info to {res}, but the type "{type(res)}" is unknown. '
|
||||
|
@ -750,7 +786,7 @@ def import_cls_or_func_from_hybrid_name(s: str) -> Any:
|
|||
return _import_cls_or_func_from_name(s)
|
||||
|
||||
|
||||
def _json_tricks_func_or_cls_encode(cls_or_func: Any, primitives: bool = False, pickle_size_limit: int = 4096) -> str:
|
||||
def _json_tricks_func_or_cls_encode(cls_or_func: Any, primitives: bool = False, pickle_size_limit: int = 4096) -> Dict[str, str]:
|
||||
if not isinstance(cls_or_func, type) and not _is_function(cls_or_func):
|
||||
# not a function or class, continue
|
||||
return cls_or_func
|
||||
|
@ -762,8 +798,7 @@ def _json_tricks_func_or_cls_encode(cls_or_func: Any, primitives: bool = False,
|
|||
|
||||
def _json_tricks_func_or_cls_decode(s: Dict[str, Any]) -> Any:
|
||||
if isinstance(s, dict) and '__nni_type__' in s:
|
||||
s = s['__nni_type__']
|
||||
return import_cls_or_func_from_hybrid_name(s)
|
||||
return import_cls_or_func_from_hybrid_name(s['__nni_type__'])
|
||||
return s
|
||||
|
||||
|
||||
|
@ -815,8 +850,7 @@ def _json_tricks_any_object_encode(obj: Any, primitives: bool = False, pickle_si
|
|||
|
||||
def _json_tricks_any_object_decode(obj: Dict[str, Any]) -> Any:
|
||||
if isinstance(obj, dict) and '__nni_obj__' in obj:
|
||||
obj = obj['__nni_obj__']
|
||||
b = base64.b64decode(obj)
|
||||
b = base64.b64decode(obj['__nni_obj__'])
|
||||
return _wrapped_cloudpickle_loads(b)
|
||||
return obj
|
||||
|
||||
|
|
|
@ -1 +1,4 @@
|
|||
# Copyright (c) Microsoft Corporation.
|
||||
# Licensed under the MIT license.
|
||||
|
||||
from .utils import load_benchmark, download_benchmark
|
||||
|
|
|
@ -1,3 +1,6 @@
|
|||
# Copyright (c) Microsoft Corporation.
|
||||
# Licensed under the MIT license.
|
||||
|
||||
import os
|
||||
|
||||
|
||||
|
|
|
@ -1,3 +1,6 @@
|
|||
# Copyright (c) Microsoft Corporation.
|
||||
# Licensed under the MIT license.
|
||||
|
||||
import argparse
|
||||
|
||||
if __name__ == '__main__':
|
||||
|
|
|
@ -1,3 +1,6 @@
|
|||
# Copyright (c) Microsoft Corporation.
|
||||
# Licensed under the MIT license.
|
||||
|
||||
from .constants import INPUT, OUTPUT, CONV3X3_BN_RELU, CONV1X1_BN_RELU, MAXPOOL3X3
|
||||
from .model import Nb101TrialStats, Nb101IntermediateStats, Nb101TrialConfig
|
||||
from .query import query_nb101_trial_stats
|
||||
|
|
|
@ -1,3 +1,6 @@
|
|||
# Copyright (c) Microsoft Corporation.
|
||||
# Licensed under the MIT license.
|
||||
|
||||
INPUT = 'input'
|
||||
OUTPUT = 'output'
|
||||
CONV3X3_BN_RELU = 'conv3x3-bn-relu'
|
||||
|
|
|
@ -1,3 +1,6 @@
|
|||
# Copyright (c) Microsoft Corporation.
|
||||
# Licensed under the MIT license.
|
||||
|
||||
import argparse
|
||||
|
||||
from tqdm import tqdm
|
||||
|
|
|
@ -1,3 +1,6 @@
|
|||
# Copyright (c) Microsoft Corporation.
|
||||
# Licensed under the MIT license.
|
||||
|
||||
import hashlib
|
||||
|
||||
import numpy as np
|
||||
|
@ -10,7 +13,7 @@ def _labeling_from_architecture(architecture, vertices):
|
|||
|
||||
|
||||
def _adjancency_matrix_from_architecture(architecture, vertices):
|
||||
matrix = np.zeros((vertices, vertices), dtype=np.bool)
|
||||
matrix = np.zeros((vertices, vertices), dtype=np.bool) # type: ignore
|
||||
for i in range(1, vertices):
|
||||
for k in architecture['input{}'.format(i)]:
|
||||
matrix[k, i] = 1
|
||||
|
|
|
@ -1,3 +1,6 @@
|
|||
# Copyright (c) Microsoft Corporation.
|
||||
# Licensed under the MIT license.
|
||||
|
||||
from peewee import CharField, FloatField, ForeignKeyField, IntegerField, Model, Proxy
|
||||
from playhouse.sqlite_ext import JSONField
|
||||
|
||||
|
|
|
@ -1,3 +1,6 @@
|
|||
# Copyright (c) Microsoft Corporation.
|
||||
# Licensed under the MIT license.
|
||||
|
||||
import functools
|
||||
|
||||
from peewee import fn
|
||||
|
|
|
@ -1,3 +1,6 @@
|
|||
# Copyright (c) Microsoft Corporation.
|
||||
# Licensed under the MIT license.
|
||||
|
||||
from .constants import NONE, SKIP_CONNECT, CONV_1X1, CONV_3X3, AVG_POOL_3X3
|
||||
from .model import Nb201TrialStats, Nb201IntermediateStats, Nb201TrialConfig
|
||||
from .query import query_nb201_trial_stats
|
||||
|
|
|
@ -1,3 +1,6 @@
|
|||
# Copyright (c) Microsoft Corporation.
|
||||
# Licensed under the MIT license.
|
||||
|
||||
NONE = 'none'
|
||||
SKIP_CONNECT = 'skip_connect'
|
||||
CONV_1X1 = 'conv_1x1'
|
||||
|
|
|
@ -1,3 +1,6 @@
|
|||
# Copyright (c) Microsoft Corporation.
|
||||
# Licensed under the MIT license.
|
||||
|
||||
import argparse
|
||||
import re
|
||||
|
||||
|
@ -17,7 +20,7 @@ def parse_arch_str(arch_str):
|
|||
'nor_conv_3x3': CONV_3X3,
|
||||
'avg_pool_3x3': AVG_POOL_3X3
|
||||
}
|
||||
m = re.match(r'\|(.*)~0\|\+\|(.*)~0\|(.*)~1\|\+\|(.*)~0\|(.*)~1\|(.*)~2\|', arch_str)
|
||||
m: re.Match = re.match(r'\|(.*)~0\|\+\|(.*)~0\|(.*)~1\|\+\|(.*)~0\|(.*)~1\|(.*)~2\|', arch_str) # type: ignore
|
||||
return {
|
||||
'0_1': mp[m.group(1)],
|
||||
'0_2': mp[m.group(2)],
|
||||
|
|
|
@ -1,3 +1,6 @@
|
|||
# Copyright (c) Microsoft Corporation.
|
||||
# Licensed under the MIT license.
|
||||
|
||||
from peewee import CharField, FloatField, ForeignKeyField, IntegerField, Model, Proxy
|
||||
from playhouse.sqlite_ext import JSONField
|
||||
|
||||
|
|
|
@ -1,3 +1,6 @@
|
|||
# Copyright (c) Microsoft Corporation.
|
||||
# Licensed under the MIT license.
|
||||
|
||||
import functools
|
||||
|
||||
from peewee import fn
|
||||
|
|
|
@ -1,3 +1,6 @@
|
|||
# Copyright (c) Microsoft Corporation.
|
||||
# Licensed under the MIT license.
|
||||
|
||||
from .constants import *
|
||||
from .model import NdsTrialConfig, NdsTrialStats, NdsIntermediateStats
|
||||
from .query import query_nds_trial_stats
|
||||
|
|
|
@ -1,3 +1,6 @@
|
|||
# Copyright (c) Microsoft Corporation.
|
||||
# Licensed under the MIT license.
|
||||
|
||||
NONE = 'none'
|
||||
SKIP_CONNECT = 'skip_connect'
|
||||
AVG_POOL_3X3 = 'avg_pool_3x3'
|
||||
|
|
|
@ -1,3 +1,6 @@
|
|||
# Copyright (c) Microsoft Corporation.
|
||||
# Licensed under the MIT license.
|
||||
|
||||
import json
|
||||
import argparse
|
||||
import os
|
||||
|
|
|
@ -1,3 +1,6 @@
|
|||
# Copyright (c) Microsoft Corporation.
|
||||
# Licensed under the MIT license.
|
||||
|
||||
from peewee import CharField, FloatField, ForeignKeyField, IntegerField, Model, Proxy
|
||||
from playhouse.sqlite_ext import JSONField
|
||||
|
||||
|
|
|
@ -1,3 +1,6 @@
|
|||
# Copyright (c) Microsoft Corporation.
|
||||
# Licensed under the MIT license.
|
||||
|
||||
import functools
|
||||
|
||||
from peewee import fn
|
||||
|
|
|
@ -1,4 +1,5 @@
|
|||
# Copyright (c) Microsoft Corporation.
|
||||
# Licensed under the MIT license.
|
||||
|
||||
from .model import NlpTrialStats, NlpIntermediateStats, NlpTrialConfig
|
||||
from .query import query_nlp_trial_stats
|
||||
|
||||
|
||||
|
|
|
@ -1,3 +1,6 @@
|
|||
# Copyright (c) Microsoft Corporation.
|
||||
# Licensed under the MIT license.
|
||||
|
||||
import json
|
||||
import os
|
||||
import argparse
|
||||
|
|
|
@ -1,3 +1,6 @@
|
|||
# Copyright (c) Microsoft Corporation.
|
||||
# Licensed under the MIT license.
|
||||
|
||||
import os
|
||||
|
||||
from peewee import CharField, FloatField, ForeignKeyField, IntegerField, Model
|
||||
|
|
|
@ -1,3 +1,6 @@
|
|||
# Copyright (c) Microsoft Corporation.
|
||||
# Licensed under the MIT license.
|
||||
|
||||
import functools
|
||||
|
||||
from peewee import fn
|
||||
|
|
|
@ -1,3 +1,6 @@
|
|||
# Copyright (c) Microsoft Corporation.
|
||||
# Licensed under the MIT license.
|
||||
|
||||
import functools
|
||||
import hashlib
|
||||
import json
|
||||
|
@ -6,6 +9,7 @@ import os
|
|||
import shutil
|
||||
import tempfile
|
||||
from pathlib import Path
|
||||
from typing import Optional
|
||||
|
||||
import requests
|
||||
import tqdm
|
||||
|
@ -49,8 +53,9 @@ def load_or_download_file(local_path: str, download_url: str, download: bool = F
|
|||
|
||||
f = tempfile.NamedTemporaryFile(delete=False, dir=dst_dir)
|
||||
r = requests.get(download_url, stream=True)
|
||||
total_length = int(r.headers.get('content-length'))
|
||||
with tqdm.tqdm(total=total_length, disable=not progress,
|
||||
total_length: Optional[str] = r.headers.get('content-length')
|
||||
assert total_length is not None, f'Content length is not found in the response of {download_url}'
|
||||
with tqdm.tqdm(total=int(total_length), disable=not progress,
|
||||
unit='B', unit_scale=True, unit_divisor=1024) as pbar:
|
||||
for chunk in r.iter_content(8192):
|
||||
f.write(chunk)
|
||||
|
|
|
@ -1,3 +1,6 @@
|
|||
# Copyright (c) Microsoft Corporation.
|
||||
# Licensed under the MIT license.
|
||||
|
||||
from .base_mutator import BaseMutator
|
||||
from .base_trainer import BaseTrainer
|
||||
from .fixed import apply_fixed_architecture
|
||||
|
|
|
@ -1 +1,4 @@
|
|||
# Copyright (c) Microsoft Corporation.
|
||||
# Licensed under the MIT license.
|
||||
|
||||
from .nasbench201 import NASBench201Cell
|
||||
|
|
|
@ -1,3 +1,6 @@
|
|||
# Copyright (c) Microsoft Corporation.
|
||||
# Licensed under the MIT license.
|
||||
|
||||
from collections import OrderedDict
|
||||
import torch.nn as nn
|
||||
from nni.nas.pytorch.mutables import LayerChoice
|
||||
|
|
|
@ -1,3 +1,6 @@
|
|||
# Copyright (c) Microsoft Corporation.
|
||||
# Licensed under the MIT license.
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
|
||||
|
|
|
@ -1,3 +1,6 @@
|
|||
# Copyright (c) Microsoft Corporation.
|
||||
# Licensed under the MIT license.
|
||||
|
||||
from .darts_cell import DartsCell
|
||||
from .enas_cell import ENASMicroLayer
|
||||
from .enas_cell import ENASMacroLayer
|
||||
|
|
|
@ -1 +1,4 @@
|
|||
# Copyright (c) Microsoft Corporation.
|
||||
# Licensed under the MIT license.
|
||||
|
||||
from .pytorch import model_to_pytorch_script
|
||||
|
|
|
@ -3,8 +3,9 @@
|
|||
|
||||
import logging
|
||||
import re
|
||||
from typing import Dict, List, Tuple, Any
|
||||
from typing import Dict, List, Tuple, Any, cast
|
||||
|
||||
from nni.retiarii.operation import PyTorchOperation
|
||||
from nni.retiarii.operation_def.torch_op_def import ToDevice
|
||||
from nni.retiarii.utils import STATE_DICT_PY_MAPPING
|
||||
from nni.common.device import Device, GPUDevice
|
||||
|
@ -34,7 +35,7 @@ def _sorted_incoming_edges(node: Node) -> List[Edge]:
|
|||
if all(edge.tail_slot is None for edge in edges):
|
||||
return edges
|
||||
if all(isinstance(edge.tail_slot, int) for edge in edges):
|
||||
edges = sorted(edges, key=(lambda edge: edge.tail_slot))
|
||||
edges = sorted(edges, key=(lambda edge: cast(int, edge.tail_slot)))
|
||||
if [edge.tail_slot for edge in edges] == list(range(len(edges))):
|
||||
return edges
|
||||
raise IllegalGraphError(node.graph, 'Node {} has bad inputs'.format(node.name))
|
||||
|
@ -98,7 +99,7 @@ def _format_variable_name(name: str, graph_name: str) -> str:
|
|||
name = name.replace('/', '__')
|
||||
|
||||
# https://stackoverflow.com/questions/3303312/how-do-i-convert-a-string-to-a-valid-variable-name-in-python
|
||||
name = re.sub('\W|^(?=\d)','_', name)
|
||||
name = re.sub(r'\W|^(?=\d)','_', name)
|
||||
|
||||
if name.startswith('__') and (len(name) > 2 and name[2] != '_'):
|
||||
# name can't start with double underscore
|
||||
|
@ -130,7 +131,7 @@ def generate_cuda_mapping(placement: Dict[Node, Device]) -> Dict[Device, int]:
|
|||
return cuda_remapped_id
|
||||
|
||||
|
||||
def graph_to_pytorch_model(graph_name: str, graph: Graph, placement=None) -> str:
|
||||
def graph_to_pytorch_model(graph_name: str, graph: Graph, placement=None) -> Tuple[set, str]:
|
||||
nodes = graph.topo_sort()
|
||||
|
||||
# handle module node and function node differently
|
||||
|
@ -144,11 +145,12 @@ def graph_to_pytorch_model(graph_name: str, graph: Graph, placement=None) -> str
|
|||
for node in nodes:
|
||||
if node.operation:
|
||||
if placement and isinstance(node.operation, ToDevice):
|
||||
cuda_remapped_id = cast(dict, cuda_remapped_id)
|
||||
node.operation.override_device_repr("cuda:%d" % cuda_remapped_id[node.operation.device])
|
||||
|
||||
if node.operation.type == 'shared':
|
||||
continue
|
||||
pkg_name = node.operation.get_import_pkg()
|
||||
pkg_name = cast(PyTorchOperation, node.operation).get_import_pkg()
|
||||
if pkg_name is not None:
|
||||
import_pkgs.add(pkg_name)
|
||||
|
||||
|
@ -157,6 +159,7 @@ def graph_to_pytorch_model(graph_name: str, graph: Graph, placement=None) -> str
|
|||
if node_code is not None:
|
||||
if placement and node in placement and len(node_code) > 0:
|
||||
if isinstance(placement[node], GPUDevice):
|
||||
assert cuda_remapped_id is not None
|
||||
device_repr = "cuda:%d" % cuda_remapped_id[placement[node]]
|
||||
else:
|
||||
device_repr = placement[node].device_repr()
|
||||
|
|
|
@ -2,6 +2,7 @@
|
|||
# Licensed under the MIT license.
|
||||
|
||||
# pylint: skip-file
|
||||
# type: ignore
|
||||
|
||||
"""
|
||||
FIXME
|
||||
|
|
|
@ -1 +1,4 @@
|
|||
# Copyright (c) Microsoft Corporation.
|
||||
# Licensed under the MIT license.
|
||||
|
||||
from .graph_gen import convert_to_graph
|
||||
|
|
|
@ -411,6 +411,7 @@ class GraphConverter:
|
|||
edge.graph = ir_graph
|
||||
if edge.head == method_ir_graph.input_node:
|
||||
# this is a member method, 'self' is the first argument, thus +1
|
||||
assert edge.head_slot is not None
|
||||
_input = node.inputsAt(edge.head_slot + 1)
|
||||
src_node, src_node_idx = self._add_edge_handle_source_node(_input, graph_inputs, ir_graph, output_remap, node_index)
|
||||
edge.head = src_node
|
||||
|
@ -745,6 +746,7 @@ class GraphConverterWithShape(GraphConverter):
|
|||
if isinstance(submodule, LayerChoice):
|
||||
full_name = get_full_name_by_scope_name(ir_model, name.split('.'), module_name)
|
||||
lc_node = ir_model.get_node_by_name(full_name)
|
||||
assert lc_node is not None, f'Cannot find a node with name {full_name}'
|
||||
|
||||
for cand_name in submodule.names:
|
||||
cand = submodule[cand_name]
|
||||
|
@ -761,12 +763,14 @@ class GraphConverterWithShape(GraphConverter):
|
|||
return
|
||||
|
||||
graph_node = ir_model.get_node_by_name(graph.name)
|
||||
assert graph_node is not None, f'Cannot find a node with name {graph.name}'
|
||||
if not _without_shape_info(graph_node):
|
||||
return
|
||||
|
||||
if is_layerchoice_node(graph_node):
|
||||
cand_name = graph_node.operation.parameters['candidates'][0]
|
||||
cand_node = ir_model.get_node_by_name(cand_name)
|
||||
assert cand_node is not None, f'Cannot find a node with name {cand_name}'
|
||||
if _without_shape_info(cand_node):
|
||||
propagate_shape_for_graph(ir_model.graphs[cand_name])
|
||||
graph_node.operation.attributes['input_shape'] = cand_node.operation.attributes['input_shape']
|
||||
|
|
|
@ -1,6 +1,10 @@
|
|||
# Copyright (c) Microsoft Corporation.
|
||||
# Licensed under the MIT license.
|
||||
|
||||
from typing import Optional
|
||||
|
||||
from typing_extensions import TypeGuard
|
||||
|
||||
from ..operation import Cell
|
||||
from ..graph import Model, Graph, Node, Edge
|
||||
|
||||
|
@ -77,7 +81,7 @@ def _extract_info_from_trace_node(trace_node):
|
|||
return shape_parameters, None
|
||||
|
||||
|
||||
def is_layerchoice_node(ir_node: Node):
|
||||
def is_layerchoice_node(ir_node: Optional[Node]) -> TypeGuard[Node]:
|
||||
if ir_node is not None and isinstance(ir_node.operation, Cell) and ir_node.operation.parameters.get('mutation') == 'layerchoice':
|
||||
return True
|
||||
else:
|
||||
|
|
|
@ -1,3 +1,6 @@
|
|||
# Copyright (c) Microsoft Corporation.
|
||||
# Licensed under the MIT license.
|
||||
|
||||
# we will support tensorflow in future release
|
||||
|
||||
framework = 'pytorch'
|
||||
|
|
|
@ -1 +1,4 @@
|
|||
# Copyright (c) Microsoft Corporation.
|
||||
# Licensed under the MIT license.
|
||||
|
||||
from .lightning import *
|
||||
|
|
|
@ -2,7 +2,7 @@
|
|||
# Licensed under the MIT license.
|
||||
|
||||
import warnings
|
||||
from typing import Dict, List, Optional, Union
|
||||
from typing import Dict, List, Optional, Union, Type
|
||||
|
||||
|
||||
import torch.nn as nn
|
||||
|
@ -18,7 +18,7 @@ from .trainer import Trainer
|
|||
|
||||
@nni.trace
|
||||
class _MultiModelSupervisedLearningModule(LightningModule):
|
||||
def __init__(self, criterion: nn.Module, metrics: Dict[str, torchmetrics.Metric],
|
||||
def __init__(self, criterion: Type[nn.Module], metrics: Dict[str, torchmetrics.Metric],
|
||||
n_models: int = 0,
|
||||
learning_rate: float = 0.001,
|
||||
weight_decay: float = 0.,
|
||||
|
@ -171,7 +171,7 @@ class Classification(Lightning):
|
|||
`Lightning documentation <https://pytorch-lightning.readthedocs.io/en/stable/common/trainer.html>`__ for details.
|
||||
"""
|
||||
|
||||
def __init__(self, criterion: nn.Module = nn.CrossEntropyLoss,
|
||||
def __init__(self, criterion: Type[nn.Module] = nn.CrossEntropyLoss,
|
||||
learning_rate: float = 0.001,
|
||||
weight_decay: float = 0.,
|
||||
optimizer: optim.Optimizer = optim.Adam,
|
||||
|
@ -184,7 +184,7 @@ class Classification(Lightning):
|
|||
train_dataloader=train_dataloader, val_dataloaders=val_dataloaders)
|
||||
|
||||
class _RegressionModule(_MultiModelSupervisedLearningModule):
|
||||
def __init__(self, criterion: nn.Module = nn.MSELoss,
|
||||
def __init__(self, criterion: Type[nn.Module] = nn.MSELoss,
|
||||
learning_rate: float = 0.001,
|
||||
weight_decay: float = 0.,
|
||||
optimizer: optim.Optimizer = optim.Adam):
|
||||
|
|
|
@ -4,10 +4,11 @@
|
|||
import os
|
||||
import warnings
|
||||
from pathlib import Path
|
||||
from typing import Dict, Union, Optional, List, Callable
|
||||
from typing import Dict, Union, Optional, List, Callable, Type
|
||||
|
||||
import pytorch_lightning as pl
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as nn_functional
|
||||
import torch.optim as optim
|
||||
import torchmetrics
|
||||
import torch.utils.data as torch_data
|
||||
|
@ -124,12 +125,12 @@ class Lightning(Evaluator):
|
|||
if other is None:
|
||||
return False
|
||||
if hasattr(self, "function") and hasattr(other, "function"):
|
||||
eq_func = (self.function == other.function)
|
||||
eq_func = getattr(self, "function") == getattr(other, "function")
|
||||
elif not (hasattr(self, "function") or hasattr(other, "function")):
|
||||
eq_func = True
|
||||
|
||||
if hasattr(self, "arguments") and hasattr(other, "arguments"):
|
||||
eq_args = (self.arguments == other.arguments)
|
||||
eq_args = getattr(self, "arguments") == getattr(other, "arguments")
|
||||
elif not (hasattr(self, "arguments") or hasattr(other, "arguments")):
|
||||
eq_args = True
|
||||
|
||||
|
@ -159,10 +160,13 @@ def _check_dataloader(dataloader):
|
|||
### The following are some commonly used Lightning modules ###
|
||||
|
||||
class _SupervisedLearningModule(LightningModule):
|
||||
def __init__(self, criterion: nn.Module, metrics: Dict[str, torchmetrics.Metric],
|
||||
|
||||
trainer: pl.Trainer
|
||||
|
||||
def __init__(self, criterion: Type[nn.Module], metrics: Dict[str, Type[torchmetrics.Metric]],
|
||||
learning_rate: float = 0.001,
|
||||
weight_decay: float = 0.,
|
||||
optimizer: optim.Optimizer = optim.Adam,
|
||||
optimizer: Type[optim.Optimizer] = optim.Adam,
|
||||
export_onnx: Union[Path, str, bool, None] = None):
|
||||
super().__init__()
|
||||
self.save_hyperparameters('criterion', 'optimizer', 'learning_rate', 'weight_decay')
|
||||
|
@ -214,7 +218,7 @@ class _SupervisedLearningModule(LightningModule):
|
|||
self.log('test_' + name, metric(y_hat, y), prog_bar=True)
|
||||
|
||||
def configure_optimizers(self):
|
||||
return self.optimizer(self.parameters(), lr=self.hparams.learning_rate, weight_decay=self.hparams.weight_decay)
|
||||
return self.optimizer(self.parameters(), lr=self.hparams.learning_rate, weight_decay=self.hparams.weight_decay) # type: ignore
|
||||
|
||||
def on_validation_epoch_end(self):
|
||||
nni.report_intermediate_result(self._get_validation_metrics())
|
||||
|
@ -233,15 +237,15 @@ class _SupervisedLearningModule(LightningModule):
|
|||
|
||||
class _AccuracyWithLogits(torchmetrics.Accuracy):
|
||||
def update(self, pred, target):
|
||||
return super().update(nn.functional.softmax(pred), target)
|
||||
return super().update(nn_functional.softmax(pred), target)
|
||||
|
||||
|
||||
@nni.trace
|
||||
class _ClassificationModule(_SupervisedLearningModule):
|
||||
def __init__(self, criterion: nn.Module = nn.CrossEntropyLoss,
|
||||
def __init__(self, criterion: Type[nn.Module] = nn.CrossEntropyLoss,
|
||||
learning_rate: float = 0.001,
|
||||
weight_decay: float = 0.,
|
||||
optimizer: optim.Optimizer = optim.Adam,
|
||||
optimizer: Type[optim.Optimizer] = optim.Adam,
|
||||
export_onnx: bool = True):
|
||||
super().__init__(criterion, {'acc': _AccuracyWithLogits},
|
||||
learning_rate=learning_rate, weight_decay=weight_decay, optimizer=optimizer,
|
||||
|
@ -275,10 +279,10 @@ class Classification(Lightning):
|
|||
`Lightning documentation <https://pytorch-lightning.readthedocs.io/en/stable/common/trainer.html>`__ for details.
|
||||
"""
|
||||
|
||||
def __init__(self, criterion: nn.Module = nn.CrossEntropyLoss,
|
||||
def __init__(self, criterion: Type[nn.Module] = nn.CrossEntropyLoss,
|
||||
learning_rate: float = 0.001,
|
||||
weight_decay: float = 0.,
|
||||
optimizer: optim.Optimizer = optim.Adam,
|
||||
optimizer: Type[optim.Optimizer] = optim.Adam,
|
||||
train_dataloader: Optional[DataLoader] = None,
|
||||
val_dataloaders: Union[DataLoader, List[DataLoader], None] = None,
|
||||
export_onnx: bool = True,
|
||||
|
@ -291,10 +295,10 @@ class Classification(Lightning):
|
|||
|
||||
@nni.trace
|
||||
class _RegressionModule(_SupervisedLearningModule):
|
||||
def __init__(self, criterion: nn.Module = nn.MSELoss,
|
||||
def __init__(self, criterion: Type[nn.Module] = nn.MSELoss,
|
||||
learning_rate: float = 0.001,
|
||||
weight_decay: float = 0.,
|
||||
optimizer: optim.Optimizer = optim.Adam,
|
||||
optimizer: Type[optim.Optimizer] = optim.Adam,
|
||||
export_onnx: bool = True):
|
||||
super().__init__(criterion, {'mse': torchmetrics.MeanSquaredError},
|
||||
learning_rate=learning_rate, weight_decay=weight_decay, optimizer=optimizer,
|
||||
|
@ -328,10 +332,10 @@ class Regression(Lightning):
|
|||
`Lightning documentation <https://pytorch-lightning.readthedocs.io/en/stable/common/trainer.html>`__ for details.
|
||||
"""
|
||||
|
||||
def __init__(self, criterion: nn.Module = nn.MSELoss,
|
||||
def __init__(self, criterion: Type[nn.Module] = nn.MSELoss,
|
||||
learning_rate: float = 0.001,
|
||||
weight_decay: float = 0.,
|
||||
optimizer: optim.Optimizer = optim.Adam,
|
||||
optimizer: Type[optim.Optimizer] = optim.Adam,
|
||||
train_dataloader: Optional[DataLoader] = None,
|
||||
val_dataloaders: Union[DataLoader, List[DataLoader], None] = None,
|
||||
export_onnx: bool = True,
|
||||
|
|
|
@ -1 +1,4 @@
|
|||
# Copyright (c) Microsoft Corporation.
|
||||
# Licensed under the MIT license.
|
||||
|
||||
from .api import *
|
||||
|
|
|
@ -129,6 +129,7 @@ class BaseExecutionEngine(AbstractExecutionEngine):
|
|||
@classmethod
|
||||
def pack_model_data(cls, model: Model) -> Any:
|
||||
mutation_summary = get_mutation_summary(model)
|
||||
assert model.evaluator is not None, 'Model evaluator can not be None'
|
||||
return BaseGraphData(codegen.model_to_pytorch_script(model), model.evaluator, mutation_summary)
|
||||
|
||||
@classmethod
|
||||
|
|
|
@ -1,6 +1,9 @@
|
|||
# Copyright (c) Microsoft Corporation.
|
||||
# Licensed under the MIT license.
|
||||
|
||||
import os
|
||||
import random
|
||||
from typing import Dict, Any, List, Optional, Union, Tuple, Callable, Iterable
|
||||
from typing import Dict, Any, List, Optional, Union, Tuple, Callable, Iterable, cast
|
||||
|
||||
from ..graph import Model
|
||||
from ..integration_api import receive_trial_parameters
|
||||
|
@ -39,6 +42,9 @@ class BenchmarkGraphData:
|
|||
def load(data) -> 'BenchmarkGraphData':
|
||||
return BenchmarkGraphData(data['mutation'], data['benchmark'], data['metric_name'], data['db_path'])
|
||||
|
||||
def __repr__(self) -> str:
|
||||
return f"BenchmarkGraphData({self.mutation}, {self.benchmark}, {self.db_path})"
|
||||
|
||||
|
||||
class BenchmarkExecutionEngine(BaseExecutionEngine):
|
||||
"""
|
||||
|
@ -67,6 +73,7 @@ class BenchmarkExecutionEngine(BaseExecutionEngine):
|
|||
@classmethod
|
||||
def trial_execute_graph(cls) -> None:
|
||||
graph_data = BenchmarkGraphData.load(receive_trial_parameters())
|
||||
assert graph_data.db_path is not None, f'Invalid graph data because db_path is None: {graph_data}'
|
||||
os.environ['NASBENCHMARK_DIR'] = graph_data.db_path
|
||||
final, intermediates = cls.query_in_benchmark(graph_data)
|
||||
|
||||
|
@ -89,7 +96,6 @@ class BenchmarkExecutionEngine(BaseExecutionEngine):
|
|||
arch = t
|
||||
if arch is None:
|
||||
raise ValueError(f'Cannot identify architecture from mutation dict: {graph_data.mutation}')
|
||||
print(arch)
|
||||
return _convert_to_final_and_intermediates(
|
||||
query_nb101_trial_stats(arch, 108, include_intermediates=True),
|
||||
'valid_acc'
|
||||
|
@ -146,4 +152,5 @@ def _convert_to_final_and_intermediates(benchmark_result: Iterable[Any], metric_
|
|||
benchmark_result = random.choice(benchmark_result)
|
||||
else:
|
||||
benchmark_result = benchmark_result[0]
|
||||
benchmark_result = cast(dict, benchmark_result)
|
||||
return benchmark_result[metric_name], [i[metric_name] for i in benchmark_result['intermediates'] if i[metric_name] is not None]
|
||||
|
|
|
@ -1,7 +1,6 @@
|
|||
# Copyright (c) Microsoft Corporation.
|
||||
# Licensed under the MIT license.
|
||||
|
||||
|
||||
import logging
|
||||
import os
|
||||
import random
|
||||
|
|
|
@ -1,3 +1,6 @@
|
|||
# Copyright (c) Microsoft Corporation.
|
||||
# Licensed under the MIT license.
|
||||
|
||||
from typing import Dict, Any, Type
|
||||
|
||||
import torch.nn as nn
|
||||
|
@ -49,7 +52,8 @@ class PurePythonExecutionEngine(BaseExecutionEngine):
|
|||
@classmethod
|
||||
def pack_model_data(cls, model: Model) -> Any:
|
||||
mutation = get_mutation_dict(model)
|
||||
graph_data = PythonGraphData(model.python_class, model.python_init_params, mutation, model.evaluator)
|
||||
assert model.evaluator is not None, 'Model evaluator is not available.'
|
||||
graph_data = PythonGraphData(model.python_class, model.python_init_params or {}, mutation, model.evaluator)
|
||||
return graph_data
|
||||
|
||||
@classmethod
|
||||
|
|
|
@ -1,3 +1,6 @@
|
|||
# Copyright (c) Microsoft Corporation.
|
||||
# Licensed under the MIT license.
|
||||
|
||||
from typing import Any, List
|
||||
from ..graph import Model
|
||||
|
||||
|
|
|
@ -11,7 +11,7 @@ from dataclasses import dataclass
|
|||
from pathlib import Path
|
||||
from subprocess import Popen
|
||||
from threading import Thread
|
||||
from typing import Any, List, Optional, Union
|
||||
from typing import Any, List, Optional, Union, cast
|
||||
|
||||
import colorama
|
||||
import psutil
|
||||
|
@ -23,6 +23,7 @@ from nni.experiment import Experiment, launcher, management, rest
|
|||
from nni.experiment.config import utils
|
||||
from nni.experiment.config.base import ConfigBase
|
||||
from nni.experiment.config.training_service import TrainingServiceConfig
|
||||
from nni.experiment.config.training_services import RemoteConfig
|
||||
from nni.experiment.pipe import Pipe
|
||||
from nni.tools.nnictl.command_utils import kill_command
|
||||
|
||||
|
@ -222,6 +223,7 @@ class RetiariiExperiment(Experiment):
|
|||
Examples
|
||||
--------
|
||||
Multi-trial NAS:
|
||||
|
||||
>>> base_model = Net()
|
||||
>>> search_strategy = strategy.Random()
|
||||
>>> model_evaluator = FunctionalEvaluator(evaluate_model)
|
||||
|
@ -233,6 +235,7 @@ class RetiariiExperiment(Experiment):
|
|||
>>> exp.run(exp_config, 8081)
|
||||
|
||||
One-shot NAS:
|
||||
|
||||
>>> base_model = Net()
|
||||
>>> search_strategy = strategy.DARTS()
|
||||
>>> evaluator = pl.Classification(train_dataloader=train_loader, val_dataloaders=valid_loader)
|
||||
|
@ -242,15 +245,16 @@ class RetiariiExperiment(Experiment):
|
|||
>>> exp.run(exp_config)
|
||||
|
||||
Export top models:
|
||||
|
||||
>>> for model_dict in exp.export_top_models(formatter='dict'):
|
||||
... print(model_dict)
|
||||
>>> with nni.retarii.fixed_arch(model_dict):
|
||||
... final_model = Net()
|
||||
"""
|
||||
|
||||
def __init__(self, base_model: nn.Module, evaluator: Union[BaseOneShotTrainer, Evaluator] = None,
|
||||
applied_mutators: List[Mutator] = None, strategy: BaseStrategy = None,
|
||||
trainer: BaseOneShotTrainer = None):
|
||||
def __init__(self, base_model: nn.Module, evaluator: Union[BaseOneShotTrainer, Evaluator] = cast(Evaluator, None),
|
||||
applied_mutators: List[Mutator] = cast(List[Mutator], None), strategy: BaseStrategy = cast(BaseStrategy, None),
|
||||
trainer: BaseOneShotTrainer = cast(BaseOneShotTrainer, None)):
|
||||
if trainer is not None:
|
||||
warnings.warn('Usage of `trainer` in RetiariiExperiment is deprecated and will be removed soon. '
|
||||
'Please consider specifying it as a positional argument, or use `evaluator`.', DeprecationWarning)
|
||||
|
@ -260,18 +264,19 @@ class RetiariiExperiment(Experiment):
|
|||
raise ValueError('Evaluator should not be none.')
|
||||
|
||||
# TODO: The current design of init interface of Retiarii experiment needs to be reviewed.
|
||||
self.config: RetiariiExeConfig = None
|
||||
self.config: RetiariiExeConfig = cast(RetiariiExeConfig, None)
|
||||
self.port: Optional[int] = None
|
||||
|
||||
self.base_model = base_model
|
||||
self.evaluator: Evaluator = evaluator
|
||||
self.evaluator: Union[Evaluator, BaseOneShotTrainer] = evaluator
|
||||
self.applied_mutators = applied_mutators
|
||||
self.strategy = strategy
|
||||
|
||||
# FIXME: this is only a workaround
|
||||
from nni.retiarii.oneshot.pytorch.strategy import OneShotStrategy
|
||||
if not isinstance(strategy, OneShotStrategy):
|
||||
self._dispatcher = RetiariiAdvisor()
|
||||
else:
|
||||
self._dispatcher = cast(RetiariiAdvisor, None)
|
||||
self._dispatcher_thread: Optional[Thread] = None
|
||||
self._proc: Optional[Popen] = None
|
||||
self._pipe: Optional[Pipe] = None
|
||||
|
@ -325,7 +330,7 @@ class RetiariiExperiment(Experiment):
|
|||
|
||||
assert self.config.training_service.platform == 'remote', \
|
||||
"CGO execution engine currently only supports remote training service"
|
||||
assert self.config.batch_waiting_time is not None
|
||||
assert self.config.batch_waiting_time is not None and self.config.max_concurrency_cgo is not None
|
||||
devices = self._construct_devices()
|
||||
engine = CGOExecutionEngine(devices,
|
||||
max_concurrency=self.config.max_concurrency_cgo,
|
||||
|
@ -335,7 +340,10 @@ class RetiariiExperiment(Experiment):
|
|||
engine = PurePythonExecutionEngine()
|
||||
elif self.config.execution_engine == 'benchmark':
|
||||
from ..execution.benchmark import BenchmarkExecutionEngine
|
||||
assert self.config.benchmark is not None, '"benchmark" must be set when benchmark execution engine is used.'
|
||||
engine = BenchmarkExecutionEngine(self.config.benchmark)
|
||||
else:
|
||||
raise ValueError(f'Unsupported engine type: {self.config.execution_engine}')
|
||||
set_execution_engine(engine)
|
||||
|
||||
self.id = management.generate_experiment_id()
|
||||
|
@ -377,9 +385,10 @@ class RetiariiExperiment(Experiment):
|
|||
def _construct_devices(self):
|
||||
devices = []
|
||||
if hasattr(self.config.training_service, 'machine_list'):
|
||||
for machine in self.config.training_service.machine_list:
|
||||
for machine in cast(RemoteConfig, self.config.training_service).machine_list:
|
||||
assert machine.gpu_indices is not None, \
|
||||
'gpu_indices must be set in RemoteMachineConfig for CGO execution engine'
|
||||
assert isinstance(machine.gpu_indices, list), 'gpu_indices must be a list'
|
||||
for gpu_idx in machine.gpu_indices:
|
||||
devices.append(GPUDevice(machine.host, gpu_idx))
|
||||
return devices
|
||||
|
@ -387,7 +396,7 @@ class RetiariiExperiment(Experiment):
|
|||
def _create_dispatcher(self):
|
||||
return self._dispatcher
|
||||
|
||||
def run(self, config: RetiariiExeConfig = None, port: int = 8080, debug: bool = False) -> str:
|
||||
def run(self, config: Optional[RetiariiExeConfig] = None, port: int = 8080, debug: bool = False) -> None:
|
||||
"""
|
||||
Run the experiment.
|
||||
This function will block until experiment finish or error.
|
||||
|
@ -420,6 +429,7 @@ class RetiariiExperiment(Experiment):
|
|||
This function will block until experiment finish or error.
|
||||
Return `True` when experiment done; or return `False` when experiment failed.
|
||||
"""
|
||||
assert self._proc is not None
|
||||
try:
|
||||
while True:
|
||||
time.sleep(10)
|
||||
|
@ -437,6 +447,7 @@ class RetiariiExperiment(Experiment):
|
|||
_logger.warning('KeyboardInterrupt detected')
|
||||
finally:
|
||||
self.stop()
|
||||
raise RuntimeError('Check experiment status failed.')
|
||||
|
||||
def stop(self) -> None:
|
||||
"""
|
||||
|
@ -466,11 +477,11 @@ class RetiariiExperiment(Experiment):
|
|||
if self._pipe is not None:
|
||||
self._pipe.close()
|
||||
|
||||
self.id = None
|
||||
self.port = None
|
||||
self.id = cast(str, None)
|
||||
self.port = cast(int, None)
|
||||
self._proc = None
|
||||
self._pipe = None
|
||||
self._dispatcher = None
|
||||
self._dispatcher = cast(RetiariiAdvisor, None)
|
||||
self._dispatcher_thread = None
|
||||
_logger.info('Experiment stopped')
|
||||
|
||||
|
|
|
@ -1,3 +1,6 @@
|
|||
# Copyright (c) Microsoft Corporation.
|
||||
# Licensed under the MIT license.
|
||||
|
||||
import json
|
||||
import logging
|
||||
from pathlib import Path
|
||||
|
|
|
@ -5,10 +5,16 @@
|
|||
Model representation.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import abc
|
||||
import json
|
||||
from enum import Enum
|
||||
from typing import (Any, Dict, Iterable, List, Optional, Tuple, Type, Union, overload)
|
||||
from typing import (TYPE_CHECKING, Any, Callable, Dict, Iterable, List,
|
||||
Optional, Set, Tuple, Type, Union, cast, overload)
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from .mutator import Mutator
|
||||
|
||||
from .operation import Cell, Operation, _IOPseudoOperation
|
||||
from .utils import uid
|
||||
|
@ -63,7 +69,7 @@ class Evaluator(abc.ABC):
|
|||
pass
|
||||
|
||||
@abc.abstractmethod
|
||||
def _execute(self, model_cls: type) -> Any:
|
||||
def _execute(self, model_cls: Union[Callable[[], Any], Any]) -> Any:
|
||||
pass
|
||||
|
||||
@abc.abstractmethod
|
||||
|
@ -203,7 +209,7 @@ class Model:
|
|||
matched_nodes.extend(nodes)
|
||||
return matched_nodes
|
||||
|
||||
def get_node_by_name(self, node_name: str) -> 'Node':
|
||||
def get_node_by_name(self, node_name: str) -> 'Node' | None:
|
||||
"""
|
||||
Traverse all the nodes to find the matched node with the given name.
|
||||
"""
|
||||
|
@ -217,7 +223,7 @@ class Model:
|
|||
else:
|
||||
return None
|
||||
|
||||
def get_node_by_python_name(self, python_name: str) -> 'Node':
|
||||
def get_node_by_python_name(self, python_name: str) -> Optional['Node']:
|
||||
"""
|
||||
Traverse all the nodes to find the matched node with the given python_name.
|
||||
"""
|
||||
|
@ -297,7 +303,7 @@ class Graph:
|
|||
The name of torch.nn.Module, should have one-to-one mapping with items in python model.
|
||||
"""
|
||||
|
||||
def __init__(self, model: Model, graph_id: int, name: str = None, _internal: bool = False):
|
||||
def __init__(self, model: Model, graph_id: int, name: str = cast(str, None), _internal: bool = False):
|
||||
assert _internal, '`Graph()` is private'
|
||||
|
||||
self.model: Model = model
|
||||
|
@ -338,9 +344,9 @@ class Graph:
|
|||
@overload
|
||||
def add_node(self, name: str, operation: Operation) -> 'Node': ...
|
||||
@overload
|
||||
def add_node(self, name: str, type_name: str, parameters: Dict[str, Any] = None) -> 'Node': ...
|
||||
def add_node(self, name: str, type_name: str, parameters: Dict[str, Any] = cast(Dict[str, Any], None)) -> 'Node': ...
|
||||
|
||||
def add_node(self, name, operation_or_type, parameters=None):
|
||||
def add_node(self, name, operation_or_type, parameters=None): # type: ignore
|
||||
if isinstance(operation_or_type, Operation):
|
||||
op = operation_or_type
|
||||
else:
|
||||
|
@ -350,9 +356,10 @@ class Graph:
|
|||
@overload
|
||||
def insert_node_on_edge(self, edge: 'Edge', name: str, operation: Operation) -> 'Node': ...
|
||||
@overload
|
||||
def insert_node_on_edge(self, edge: 'Edge', name: str, type_name: str, parameters: Dict[str, Any] = None) -> 'Node': ...
|
||||
def insert_node_on_edge(self, edge: 'Edge', name: str, type_name: str,
|
||||
parameters: Dict[str, Any] = cast(Dict[str, Any], None)) -> 'Node': ...
|
||||
|
||||
def insert_node_on_edge(self, edge, name, operation_or_type, parameters=None) -> 'Node':
|
||||
def insert_node_on_edge(self, edge, name, operation_or_type, parameters=None) -> 'Node': # type: ignore
|
||||
if isinstance(operation_or_type, Operation):
|
||||
op = operation_or_type
|
||||
else:
|
||||
|
@ -405,7 +412,7 @@ class Graph:
|
|||
def get_nodes_by_name(self, name: str) -> List['Node']:
|
||||
return [node for node in self.hidden_nodes if node.name == name]
|
||||
|
||||
def get_nodes_by_python_name(self, python_name: str) -> Optional['Node']:
|
||||
def get_nodes_by_python_name(self, python_name: str) -> List['Node']:
|
||||
return [node for node in self.nodes if node.python_name == python_name]
|
||||
|
||||
def topo_sort(self) -> List['Node']:
|
||||
|
@ -594,7 +601,7 @@ class Node:
|
|||
return sorted(set(edge.tail for edge in self.outgoing_edges), key=(lambda node: node.id))
|
||||
|
||||
@property
|
||||
def successor_slots(self) -> List[Tuple['Node', Union[int, None]]]:
|
||||
def successor_slots(self) -> Set[Tuple['Node', Union[int, None]]]:
|
||||
return set((edge.tail, edge.tail_slot) for edge in self.outgoing_edges)
|
||||
|
||||
@property
|
||||
|
@ -610,19 +617,19 @@ class Node:
|
|||
assert isinstance(self.operation, Cell)
|
||||
return self.graph.model.graphs[self.operation.parameters['cell']]
|
||||
|
||||
def update_label(self, label: str) -> None:
|
||||
def update_label(self, label: Optional[str]) -> None:
|
||||
self.label = label
|
||||
|
||||
@overload
|
||||
def update_operation(self, operation: Operation) -> None: ...
|
||||
@overload
|
||||
def update_operation(self, type_name: str, parameters: Dict[str, Any] = None) -> None: ...
|
||||
def update_operation(self, type_name: str, parameters: Dict[str, Any] = cast(Dict[str, Any], None)) -> None: ...
|
||||
|
||||
def update_operation(self, operation_or_type, parameters=None):
|
||||
def update_operation(self, operation_or_type, parameters=None): # type: ignore
|
||||
if isinstance(operation_or_type, Operation):
|
||||
self.operation = operation_or_type
|
||||
else:
|
||||
self.operation = Operation.new(operation_or_type, parameters)
|
||||
self.operation = Operation.new(operation_or_type, cast(dict, parameters))
|
||||
|
||||
# mutation
|
||||
def remove(self) -> None:
|
||||
|
@ -663,7 +670,13 @@ class Node:
|
|||
return node
|
||||
|
||||
def _dump(self) -> Any:
|
||||
ret = {'operation': {'type': self.operation.type, 'parameters': self.operation.parameters, 'attributes': self.operation.attributes}}
|
||||
ret: Dict[str, Any] = {
|
||||
'operation': {
|
||||
'type': self.operation.type,
|
||||
'parameters': self.operation.parameters,
|
||||
'attributes': self.operation.attributes
|
||||
}
|
||||
}
|
||||
if isinstance(self.operation, Cell):
|
||||
ret['operation']['cell_name'] = self.operation.cell_name
|
||||
if self.label is not None:
|
||||
|
|
|
@ -1,7 +1,7 @@
|
|||
# Copyright (c) Microsoft Corporation.
|
||||
# Licensed under the MIT license.
|
||||
|
||||
from typing import Tuple, Optional, Callable
|
||||
from typing import Tuple, Optional, Callable, cast
|
||||
|
||||
import nni.retiarii.nn.pytorch as nn
|
||||
from nni.retiarii import model_wrapper
|
||||
|
@ -75,10 +75,10 @@ class MobileNetV3Space(nn.Module):
|
|||
bn_momentum: float = 0.1):
|
||||
super().__init__()
|
||||
|
||||
self.widths = [
|
||||
self.widths = cast(nn.ChoiceOf[int], [
|
||||
nn.ValueChoice([make_divisible(base_width * mult, 8) for mult in width_multipliers], label=f'width_{i}')
|
||||
for i, base_width in enumerate(base_widths)
|
||||
]
|
||||
])
|
||||
self.expand_ratios = expand_ratios
|
||||
|
||||
blocks = [
|
||||
|
@ -115,7 +115,7 @@ class MobileNetV3Space(nn.Module):
|
|||
|
||||
self.classifier = nn.Sequential(
|
||||
nn.Dropout(dropout_rate),
|
||||
nn.Linear(self.widths[7], num_labels),
|
||||
nn.Linear(cast(int, self.widths[7]), num_labels),
|
||||
)
|
||||
|
||||
reset_parameters(self, bn_momentum=bn_momentum, bn_eps=bn_eps)
|
||||
|
|
|
@ -3,6 +3,7 @@
|
|||
|
||||
import math
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
from nni.retiarii import model_wrapper
|
||||
from nni.retiarii.nn.pytorch import NasBench101Cell
|
||||
|
@ -11,7 +12,7 @@ from nni.retiarii.nn.pytorch import NasBench101Cell
|
|||
__all__ = ['NasBench101']
|
||||
|
||||
|
||||
def truncated_normal_(tensor, mean=0, std=1):
|
||||
def truncated_normal_(tensor: torch.Tensor, mean: float = 0, std: float = 1):
|
||||
# https://discuss.pytorch.org/t/implementing-truncated-normal-initializer/4778/15
|
||||
size = tensor.shape
|
||||
tmp = tensor.new_empty(size + (4,)).normal_()
|
||||
|
@ -117,9 +118,3 @@ class NasBench101(nn.Module):
|
|||
out = self.gap(out).view(bs, -1)
|
||||
out = self.classifier(out)
|
||||
return out
|
||||
|
||||
def reset_parameters(self):
|
||||
for module in self.modules():
|
||||
if isinstance(module, nn.BatchNorm2d):
|
||||
module.eps = self.config.bn_eps
|
||||
module.momentum = self.config.bn_momentum
|
||||
|
|
|
@ -1,6 +1,8 @@
|
|||
# Copyright (c) Microsoft Corporation.
|
||||
# Licensed under the MIT license.
|
||||
|
||||
from typing import Callable, Dict
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
|
||||
|
@ -176,8 +178,10 @@ class NasBench201(nn.Module):
|
|||
if reduction:
|
||||
cell = ResNetBasicblock(C_prev, C_curr, 2)
|
||||
else:
|
||||
cell = NasBench201Cell({prim: lambda C_in, C_out: OPS_WITH_STRIDE[prim](C_in, C_out, 1) for prim in PRIMITIVES},
|
||||
C_prev, C_curr, label='cell')
|
||||
ops: Dict[str, Callable[[int, int], nn.Module]] = {
|
||||
prim: lambda C_in, C_out: OPS_WITH_STRIDE[prim](C_in, C_out, 1) for prim in PRIMITIVES
|
||||
}
|
||||
cell = NasBench201Cell(ops, C_prev, C_curr, label='cell')
|
||||
self.cells.append(cell)
|
||||
C_prev = C_curr
|
||||
|
||||
|
|
|
@ -8,7 +8,7 @@ It's called ``nasnet.py`` simply because NASNet is the first to propose such str
|
|||
"""
|
||||
|
||||
from collections import OrderedDict
|
||||
from typing import Tuple, List, Union, Iterable, Dict, Callable
|
||||
from typing import Tuple, List, Union, Iterable, Dict, Callable, Optional, cast
|
||||
|
||||
try:
|
||||
from typing import Literal
|
||||
|
@ -250,14 +250,14 @@ class CellPreprocessor(nn.Module):
|
|||
See :class:`CellBuilder` on how to calculate those channel numbers.
|
||||
"""
|
||||
|
||||
def __init__(self, C_pprev: int, C_prev: int, C: int, last_cell_reduce: bool) -> None:
|
||||
def __init__(self, C_pprev: nn.MaybeChoice[int], C_prev: nn.MaybeChoice[int], C: nn.MaybeChoice[int], last_cell_reduce: bool) -> None:
|
||||
super().__init__()
|
||||
|
||||
if last_cell_reduce:
|
||||
self.pre0 = FactorizedReduce(C_pprev, C)
|
||||
self.pre0 = FactorizedReduce(cast(int, C_pprev), cast(int, C))
|
||||
else:
|
||||
self.pre0 = ReLUConvBN(C_pprev, C, 1, 1, 0)
|
||||
self.pre1 = ReLUConvBN(C_prev, C, 1, 1, 0)
|
||||
self.pre0 = ReLUConvBN(cast(int, C_pprev), cast(int, C), 1, 1, 0)
|
||||
self.pre1 = ReLUConvBN(cast(int, C_prev), cast(int, C), 1, 1, 0)
|
||||
|
||||
def forward(self, cells):
|
||||
assert len(cells) == 2
|
||||
|
@ -283,15 +283,19 @@ class CellBuilder:
|
|||
Note that the builder is ephemeral, it can only be called once for every index.
|
||||
"""
|
||||
|
||||
def __init__(self, op_candidates: List[str], C_prev_in: int, C_in: int, C: int,
|
||||
num_nodes: int, merge_op: Literal['all', 'loose_end'],
|
||||
def __init__(self, op_candidates: List[str],
|
||||
C_prev_in: nn.MaybeChoice[int],
|
||||
C_in: nn.MaybeChoice[int],
|
||||
C: nn.MaybeChoice[int],
|
||||
num_nodes: int,
|
||||
merge_op: Literal['all', 'loose_end'],
|
||||
first_cell_reduce: bool, last_cell_reduce: bool):
|
||||
self.C_prev_in = C_prev_in # This is the out channels of the cell before last cell.
|
||||
self.C_in = C_in # This is the out channesl of last cell.
|
||||
self.C = C # This is NOT C_out of this stage, instead, C_out = C * len(cell.output_node_indices)
|
||||
self.op_candidates = op_candidates
|
||||
self.num_nodes = num_nodes
|
||||
self.merge_op = merge_op
|
||||
self.merge_op: Literal['all', 'loose_end'] = merge_op
|
||||
self.first_cell_reduce = first_cell_reduce
|
||||
self.last_cell_reduce = last_cell_reduce
|
||||
self._expect_idx = 0
|
||||
|
@ -312,7 +316,7 @@ class CellBuilder:
|
|||
# self.C_prev_in, self.C_in, self.last_cell_reduce are updated after each cell is built.
|
||||
preprocessor = CellPreprocessor(self.C_prev_in, self.C_in, self.C, self.last_cell_reduce)
|
||||
|
||||
ops_factory: Dict[str, Callable[[int, int, int], nn.Module]] = {
|
||||
ops_factory: Dict[str, Callable[[int, int, Optional[int]], nn.Module]] = {
|
||||
op: # make final chosen ops named with their aliases
|
||||
lambda node_index, op_index, input_index:
|
||||
OPS[op](self.C, 2 if is_reduction_cell and (
|
||||
|
@ -353,7 +357,7 @@ _INIT_PARAMETER_DOCS = """
|
|||
|
||||
|
||||
class NDS(nn.Module):
|
||||
"""
|
||||
__doc__ = """
|
||||
The unified version of NASNet search space.
|
||||
|
||||
We follow the implementation in
|
||||
|
@ -378,8 +382,8 @@ class NDS(nn.Module):
|
|||
op_candidates: List[str],
|
||||
merge_op: Literal['all', 'loose_end'] = 'all',
|
||||
num_nodes_per_cell: int = 4,
|
||||
width: Union[Tuple[int], int] = 16,
|
||||
num_cells: Union[Tuple[int], int] = 20,
|
||||
width: Union[Tuple[int, ...], int] = 16,
|
||||
num_cells: Union[Tuple[int, ...], int] = 20,
|
||||
dataset: Literal['cifar', 'imagenet'] = 'imagenet',
|
||||
auxiliary_loss: bool = False):
|
||||
super().__init__()
|
||||
|
@ -394,30 +398,31 @@ class NDS(nn.Module):
|
|||
else:
|
||||
C = width
|
||||
|
||||
self.num_cells: nn.MaybeChoice[int] = cast(int, num_cells)
|
||||
if isinstance(num_cells, Iterable):
|
||||
num_cells = nn.ValueChoice(list(num_cells), label='depth')
|
||||
num_cells_per_stage = [i * num_cells // 3 - (i - 1) * num_cells // 3 for i in range(3)]
|
||||
self.num_cells = nn.ValueChoice(list(num_cells), label='depth')
|
||||
num_cells_per_stage = [i * self.num_cells // 3 - (i - 1) * self.num_cells // 3 for i in range(3)]
|
||||
|
||||
# auxiliary head is different for network targetted at different datasets
|
||||
if dataset == 'imagenet':
|
||||
self.stem0 = nn.Sequential(
|
||||
nn.Conv2d(3, C // 2, kernel_size=3, stride=2, padding=1, bias=False),
|
||||
nn.BatchNorm2d(C // 2),
|
||||
nn.Conv2d(3, cast(int, C // 2), kernel_size=3, stride=2, padding=1, bias=False),
|
||||
nn.BatchNorm2d(cast(int, C // 2)),
|
||||
nn.ReLU(inplace=True),
|
||||
nn.Conv2d(C // 2, C, 3, stride=2, padding=1, bias=False),
|
||||
nn.Conv2d(cast(int, C // 2), cast(int, C), 3, stride=2, padding=1, bias=False),
|
||||
nn.BatchNorm2d(C),
|
||||
)
|
||||
self.stem1 = nn.Sequential(
|
||||
nn.ReLU(inplace=True),
|
||||
nn.Conv2d(C, C, 3, stride=2, padding=1, bias=False),
|
||||
nn.Conv2d(cast(int, C), cast(int, C), 3, stride=2, padding=1, bias=False),
|
||||
nn.BatchNorm2d(C),
|
||||
)
|
||||
C_pprev = C_prev = C_curr = C
|
||||
last_cell_reduce = True
|
||||
elif dataset == 'cifar':
|
||||
self.stem = nn.Sequential(
|
||||
nn.Conv2d(3, 3 * C, 3, padding=1, bias=False),
|
||||
nn.BatchNorm2d(3 * C)
|
||||
nn.Conv2d(3, cast(int, 3 * C), 3, padding=1, bias=False),
|
||||
nn.BatchNorm2d(cast(int, 3 * C))
|
||||
)
|
||||
C_pprev = C_prev = 3 * C
|
||||
C_curr = C
|
||||
|
@ -439,7 +444,7 @@ class NDS(nn.Module):
|
|||
# C_pprev is output channel number of last second cell among all the cells already built.
|
||||
if len(stage) > 1:
|
||||
# Contains more than one cell
|
||||
C_pprev = len(stage[-2].output_node_indices) * C_curr
|
||||
C_pprev = len(cast(nn.Cell, stage[-2]).output_node_indices) * C_curr
|
||||
else:
|
||||
# Look up in the out channels of last stage.
|
||||
C_pprev = C_prev
|
||||
|
@ -447,7 +452,7 @@ class NDS(nn.Module):
|
|||
# This was originally,
|
||||
# C_prev = num_nodes_per_cell * C_curr.
|
||||
# but due to loose end, it becomes,
|
||||
C_prev = len(stage[-1].output_node_indices) * C_curr
|
||||
C_prev = len(cast(nn.Cell, stage[-1]).output_node_indices) * C_curr
|
||||
|
||||
# Useful in aligning the pprev and prev cell.
|
||||
last_cell_reduce = cell_builder.last_cell_reduce
|
||||
|
@ -457,11 +462,11 @@ class NDS(nn.Module):
|
|||
|
||||
if auxiliary_loss:
|
||||
assert isinstance(self.stages[2], nn.Sequential), 'Auxiliary loss can only be enabled in retrain mode.'
|
||||
self.stages[2] = SequentialBreakdown(self.stages[2])
|
||||
self.auxiliary_head = AuxiliaryHead(C_to_auxiliary, self.num_labels, dataset=self.dataset)
|
||||
self.stages[2] = SequentialBreakdown(cast(nn.Sequential, self.stages[2]))
|
||||
self.auxiliary_head = AuxiliaryHead(C_to_auxiliary, self.num_labels, dataset=self.dataset) # type: ignore
|
||||
|
||||
self.global_pooling = nn.AdaptiveAvgPool2d((1, 1))
|
||||
self.classifier = nn.Linear(C_prev, self.num_labels)
|
||||
self.classifier = nn.Linear(cast(int, C_prev), self.num_labels)
|
||||
|
||||
def forward(self, inputs):
|
||||
if self.dataset == 'imagenet':
|
||||
|
@ -483,7 +488,7 @@ class NDS(nn.Module):
|
|||
out = self.global_pooling(s1)
|
||||
logits = self.classifier(out.view(out.size(0), -1))
|
||||
if self.training and self.auxiliary_loss:
|
||||
return logits, logits_aux
|
||||
return logits, logits_aux # type: ignore
|
||||
else:
|
||||
return logits
|
||||
|
||||
|
@ -524,8 +529,8 @@ class NASNet(NDS):
|
|||
]
|
||||
|
||||
def __init__(self,
|
||||
width: Union[Tuple[int], int] = (16, 24, 32),
|
||||
num_cells: Union[Tuple[int], int] = (4, 8, 12, 16, 20),
|
||||
width: Union[Tuple[int, ...], int] = (16, 24, 32),
|
||||
num_cells: Union[Tuple[int, ...], int] = (4, 8, 12, 16, 20),
|
||||
dataset: Literal['cifar', 'imagenet'] = 'cifar',
|
||||
auxiliary_loss: bool = False):
|
||||
super().__init__(self.NASNET_OPS,
|
||||
|
@ -555,8 +560,8 @@ class ENAS(NDS):
|
|||
]
|
||||
|
||||
def __init__(self,
|
||||
width: Union[Tuple[int], int] = (16, 24, 32),
|
||||
num_cells: Union[Tuple[int], int] = (4, 8, 12, 16, 20),
|
||||
width: Union[Tuple[int, ...], int] = (16, 24, 32),
|
||||
num_cells: Union[Tuple[int, ...], int] = (4, 8, 12, 16, 20),
|
||||
dataset: Literal['cifar', 'imagenet'] = 'cifar',
|
||||
auxiliary_loss: bool = False):
|
||||
super().__init__(self.ENAS_OPS,
|
||||
|
@ -590,8 +595,8 @@ class AmoebaNet(NDS):
|
|||
]
|
||||
|
||||
def __init__(self,
|
||||
width: Union[Tuple[int], int] = (16, 24, 32),
|
||||
num_cells: Union[Tuple[int], int] = (4, 8, 12, 16, 20),
|
||||
width: Union[Tuple[int, ...], int] = (16, 24, 32),
|
||||
num_cells: Union[Tuple[int, ...], int] = (4, 8, 12, 16, 20),
|
||||
dataset: Literal['cifar', 'imagenet'] = 'cifar',
|
||||
auxiliary_loss: bool = False):
|
||||
|
||||
|
@ -626,8 +631,8 @@ class PNAS(NDS):
|
|||
]
|
||||
|
||||
def __init__(self,
|
||||
width: Union[Tuple[int], int] = (16, 24, 32),
|
||||
num_cells: Union[Tuple[int], int] = (4, 8, 12, 16, 20),
|
||||
width: Union[Tuple[int, ...], int] = (16, 24, 32),
|
||||
num_cells: Union[Tuple[int, ...], int] = (4, 8, 12, 16, 20),
|
||||
dataset: Literal['cifar', 'imagenet'] = 'cifar',
|
||||
auxiliary_loss: bool = False):
|
||||
super().__init__(self.PNAS_OPS,
|
||||
|
@ -660,8 +665,8 @@ class DARTS(NDS):
|
|||
]
|
||||
|
||||
def __init__(self,
|
||||
width: Union[Tuple[int], int] = (16, 24, 32),
|
||||
num_cells: Union[Tuple[int], int] = (4, 8, 12, 16, 20),
|
||||
width: Union[Tuple[int, ...], int] = (16, 24, 32),
|
||||
num_cells: Union[Tuple[int, ...], int] = (4, 8, 12, 16, 20),
|
||||
dataset: Literal['cifar', 'imagenet'] = 'cifar',
|
||||
auxiliary_loss: bool = False):
|
||||
super().__init__(self.DARTS_OPS,
|
||||
|
|
|
@ -2,7 +2,7 @@
|
|||
# Licensed under the MIT license.
|
||||
|
||||
import math
|
||||
from typing import Optional, Callable, List, Tuple
|
||||
from typing import Optional, Callable, List, Tuple, cast
|
||||
|
||||
import torch
|
||||
import nni.retiarii.nn.pytorch as nn
|
||||
|
@ -31,12 +31,12 @@ class ConvBNReLU(nn.Sequential):
|
|||
|
||||
def __init__(
|
||||
self,
|
||||
in_channels: int,
|
||||
out_channels: int,
|
||||
kernel_size: int = 3,
|
||||
in_channels: nn.MaybeChoice[int],
|
||||
out_channels: nn.MaybeChoice[int],
|
||||
kernel_size: nn.MaybeChoice[int] = 3,
|
||||
stride: int = 1,
|
||||
groups: int = 1,
|
||||
norm_layer: Optional[Callable[..., nn.Module]] = None,
|
||||
groups: nn.MaybeChoice[int] = 1,
|
||||
norm_layer: Optional[Callable[[int], nn.Module]] = None,
|
||||
activation_layer: Optional[Callable[..., nn.Module]] = None,
|
||||
dilation: int = 1,
|
||||
) -> None:
|
||||
|
@ -46,9 +46,17 @@ class ConvBNReLU(nn.Sequential):
|
|||
if activation_layer is None:
|
||||
activation_layer = nn.ReLU6
|
||||
super().__init__(
|
||||
nn.Conv2d(in_channels, out_channels, kernel_size, stride, padding, dilation=dilation, groups=groups,
|
||||
bias=False),
|
||||
norm_layer(out_channels),
|
||||
nn.Conv2d(
|
||||
cast(int, in_channels),
|
||||
cast(int, out_channels),
|
||||
cast(int, kernel_size),
|
||||
stride,
|
||||
cast(int, padding),
|
||||
dilation=dilation,
|
||||
groups=cast(int, groups),
|
||||
bias=False
|
||||
),
|
||||
norm_layer(cast(int, out_channels)),
|
||||
activation_layer(inplace=True)
|
||||
)
|
||||
self.out_channels = out_channels
|
||||
|
@ -62,11 +70,11 @@ class SeparableConv(nn.Sequential):
|
|||
|
||||
def __init__(
|
||||
self,
|
||||
in_channels: int,
|
||||
out_channels: int,
|
||||
kernel_size: int = 3,
|
||||
in_channels: nn.MaybeChoice[int],
|
||||
out_channels: nn.MaybeChoice[int],
|
||||
kernel_size: nn.MaybeChoice[int] = 3,
|
||||
stride: int = 1,
|
||||
norm_layer: Optional[Callable[..., nn.Module]] = None,
|
||||
norm_layer: Optional[Callable[[int], nn.Module]] = None,
|
||||
activation_layer: Optional[Callable[..., nn.Module]] = None,
|
||||
) -> None:
|
||||
super().__init__(
|
||||
|
@ -101,13 +109,13 @@ class InvertedResidual(nn.Sequential):
|
|||
|
||||
def __init__(
|
||||
self,
|
||||
in_channels: int,
|
||||
out_channels: int,
|
||||
expand_ratio: int,
|
||||
kernel_size: int = 3,
|
||||
in_channels: nn.MaybeChoice[int],
|
||||
out_channels: nn.MaybeChoice[int],
|
||||
expand_ratio: nn.MaybeChoice[float],
|
||||
kernel_size: nn.MaybeChoice[int] = 3,
|
||||
stride: int = 1,
|
||||
squeeze_and_excite: Optional[Callable[[int], nn.Module]] = None,
|
||||
norm_layer: Optional[Callable[..., nn.Module]] = None,
|
||||
squeeze_and_excite: Optional[Callable[[nn.MaybeChoice[int]], nn.Module]] = None,
|
||||
norm_layer: Optional[Callable[[int], nn.Module]] = None,
|
||||
activation_layer: Optional[Callable[..., nn.Module]] = None,
|
||||
) -> None:
|
||||
super().__init__()
|
||||
|
@ -115,7 +123,7 @@ class InvertedResidual(nn.Sequential):
|
|||
self.out_channels = out_channels
|
||||
assert stride in [1, 2]
|
||||
|
||||
hidden_ch = nn.ValueChoice.to_int(round(in_channels * expand_ratio))
|
||||
hidden_ch = nn.ValueChoice.to_int(round(cast(int, in_channels * expand_ratio)))
|
||||
|
||||
# FIXME: check whether this equal works
|
||||
# Residual connection is added here stride = 1 and input channels and output channels are the same.
|
||||
|
@ -215,7 +223,7 @@ class ProxylessNAS(nn.Module):
|
|||
|
||||
self.first_conv = ConvBNReLU(3, widths[0], stride=2, norm_layer=nn.BatchNorm2d)
|
||||
|
||||
blocks = [
|
||||
blocks: List[nn.Module] = [
|
||||
# first stage is fixed
|
||||
SeparableConv(widths[0], widths[1], kernel_size=3, stride=1)
|
||||
]
|
||||
|
|
|
@ -1,6 +1,8 @@
|
|||
# Copyright (c) Microsoft Corporation.
|
||||
# Licensed under the MIT license.
|
||||
|
||||
from typing import cast
|
||||
|
||||
import torch
|
||||
import nni.retiarii.nn.pytorch as nn
|
||||
from nni.retiarii import model_wrapper
|
||||
|
@ -14,7 +16,7 @@ class ShuffleNetBlock(nn.Module):
|
|||
When stride = 1, the block expects an input with ``2 * input channels``. Otherwise input channels.
|
||||
"""
|
||||
|
||||
def __init__(self, in_channels: int, out_channels: int, mid_channels: int, *,
|
||||
def __init__(self, in_channels: int, out_channels: int, mid_channels: nn.MaybeChoice[int], *,
|
||||
kernel_size: int, stride: int, sequence: str = "pdp", affine: bool = True):
|
||||
super().__init__()
|
||||
assert stride in [1, 2]
|
||||
|
@ -57,14 +59,15 @@ class ShuffleNetBlock(nn.Module):
|
|||
def _decode_point_depth_conv(self, sequence):
|
||||
result = []
|
||||
first_depth = first_point = True
|
||||
pc = c = self.channels
|
||||
pc: int = self.channels
|
||||
c: int = self.channels
|
||||
for i, token in enumerate(sequence):
|
||||
# compute output channels of this conv
|
||||
if i + 1 == len(sequence):
|
||||
assert token == "p", "Last conv must be point-wise conv."
|
||||
c = self.oup_main
|
||||
elif token == "p" and first_point:
|
||||
c = self.mid_channels
|
||||
c = cast(int, self.mid_channels)
|
||||
if token == "d":
|
||||
# depth-wise conv
|
||||
if isinstance(pc, int) and isinstance(c, int):
|
||||
|
@ -101,7 +104,7 @@ class ShuffleXceptionBlock(ShuffleNetBlock):
|
|||
`Single Path One-shot <https://www.ecva.net/papers/eccv_2020/papers_ECCV/papers/123610528.pdf>`__.
|
||||
"""
|
||||
|
||||
def __init__(self, in_channels: int, out_channels: int, mid_channels: int, *, stride: int, affine: bool = True):
|
||||
def __init__(self, in_channels: int, out_channels: int, mid_channels: nn.MaybeChoice[int], *, stride: int, affine: bool = True):
|
||||
super().__init__(in_channels, out_channels, mid_channels,
|
||||
kernel_size=3, stride=stride, sequence="dpdpdp", affine=affine)
|
||||
|
||||
|
@ -154,7 +157,7 @@ class ShuffleNetSpace(nn.Module):
|
|||
nn.ReLU(inplace=True),
|
||||
)
|
||||
|
||||
self.features = []
|
||||
feature_blocks = []
|
||||
|
||||
global_block_idx = 0
|
||||
for stage_idx, num_repeat in enumerate(self.stage_repeats):
|
||||
|
@ -175,15 +178,17 @@ class ShuffleNetSpace(nn.Module):
|
|||
else:
|
||||
mid_channels = int(base_mid_channels)
|
||||
|
||||
mid_channels = cast(nn.MaybeChoice[int], mid_channels)
|
||||
|
||||
choice_block = nn.LayerChoice([
|
||||
ShuffleNetBlock(in_channels, out_channels, mid_channels=mid_channels, kernel_size=3, stride=stride, affine=affine),
|
||||
ShuffleNetBlock(in_channels, out_channels, mid_channels=mid_channels, kernel_size=5, stride=stride, affine=affine),
|
||||
ShuffleNetBlock(in_channels, out_channels, mid_channels=mid_channels, kernel_size=7, stride=stride, affine=affine),
|
||||
ShuffleXceptionBlock(in_channels, out_channels, mid_channels=mid_channels, stride=stride, affine=affine)
|
||||
], label=f'layer_{global_block_idx}')
|
||||
self.features.append(choice_block)
|
||||
feature_blocks.append(choice_block)
|
||||
|
||||
self.features = nn.Sequential(*self.features)
|
||||
self.features = nn.Sequential(*feature_blocks)
|
||||
|
||||
# final layers
|
||||
last_conv_channels = self.stage_out_channels[-1]
|
||||
|
@ -226,12 +231,14 @@ class ShuffleNetSpace(nn.Module):
|
|||
torch.nn.init.constant_(m.weight, 1)
|
||||
if m.bias is not None:
|
||||
torch.nn.init.constant_(m.bias, 0.0001)
|
||||
if m.running_mean is not None:
|
||||
torch.nn.init.constant_(m.running_mean, 0)
|
||||
elif isinstance(m, nn.BatchNorm1d):
|
||||
if m.weight is not None:
|
||||
torch.nn.init.constant_(m.weight, 1)
|
||||
if m.bias is not None:
|
||||
torch.nn.init.constant_(m.bias, 0.0001)
|
||||
if m.running_mean is not None:
|
||||
torch.nn.init.constant_(m.running_mean, 0)
|
||||
elif isinstance(m, nn.Linear):
|
||||
torch.nn.init.normal_(m.weight, 0, 0.01)
|
||||
|
|
|
@ -0,0 +1,5 @@
|
|||
# Copyright (c) Microsoft Corporation.
|
||||
# Licensed under the MIT license.
|
||||
|
||||
# Useful type hints
|
||||
|
|
@ -3,7 +3,7 @@
|
|||
|
||||
import logging
|
||||
import os
|
||||
from typing import Any, Callable
|
||||
from typing import Any, Callable, Optional
|
||||
|
||||
import nni
|
||||
from nni.common.serializer import PayloadTooLarge
|
||||
|
@ -53,11 +53,11 @@ class RetiariiAdvisor(MsgDispatcherBase):
|
|||
register_advisor(self) # register the current advisor as the "global only" advisor
|
||||
self.search_space = None
|
||||
|
||||
self.send_trial_callback: Callable[[dict], None] = None
|
||||
self.request_trial_jobs_callback: Callable[[int], None] = None
|
||||
self.trial_end_callback: Callable[[int, bool], None] = None
|
||||
self.intermediate_metric_callback: Callable[[int, MetricData], None] = None
|
||||
self.final_metric_callback: Callable[[int, MetricData], None] = None
|
||||
self.send_trial_callback: Optional[Callable[[dict], None]] = None
|
||||
self.request_trial_jobs_callback: Optional[Callable[[int], None]] = None
|
||||
self.trial_end_callback: Optional[Callable[[int, bool], None]] = None
|
||||
self.intermediate_metric_callback: Optional[Callable[[int, MetricData], None]] = None
|
||||
self.final_metric_callback: Optional[Callable[[int, MetricData], None]] = None
|
||||
|
||||
self.parameters_count = 0
|
||||
|
||||
|
@ -158,6 +158,7 @@ class RetiariiAdvisor(MsgDispatcherBase):
|
|||
|
||||
def handle_trial_end(self, data):
|
||||
_logger.debug('Trial end: %s', data)
|
||||
if self.trial_end_callback is not None:
|
||||
self.trial_end_callback(nni.load(data['hyper_params'])['parameter_id'], # pylint: disable=not-callable
|
||||
data['event'] == 'SUCCEEDED')
|
||||
|
||||
|
@ -166,9 +167,11 @@ class RetiariiAdvisor(MsgDispatcherBase):
|
|||
if data['type'] == MetricType.REQUEST_PARAMETER:
|
||||
raise ValueError('Request parameter not supported')
|
||||
elif data['type'] == MetricType.PERIODICAL:
|
||||
if self.intermediate_metric_callback is not None:
|
||||
self.intermediate_metric_callback(data['parameter_id'], # pylint: disable=not-callable
|
||||
self._process_value(data['value']))
|
||||
elif data['type'] == MetricType.FINAL:
|
||||
if self.final_metric_callback is not None:
|
||||
self.final_metric_callback(data['parameter_id'], # pylint: disable=not-callable
|
||||
self._process_value(data['value']))
|
||||
|
||||
|
|
|
@ -1,7 +1,8 @@
|
|||
# Copyright (c) Microsoft Corporation.
|
||||
# Licensed under the MIT license.
|
||||
|
||||
from typing import (Any, Iterable, List, Optional, Tuple)
|
||||
import warnings
|
||||
from typing import (Any, Iterable, List, Optional, Tuple, cast)
|
||||
|
||||
from .graph import Model, Mutation, ModelStatus
|
||||
|
||||
|
@ -44,9 +45,11 @@ class Mutator:
|
|||
If mutator has a label, in most cases, it means that this mutator is applied to nodes with this label.
|
||||
"""
|
||||
|
||||
def __init__(self, sampler: Optional[Sampler] = None, label: Optional[str] = None):
|
||||
def __init__(self, sampler: Optional[Sampler] = None, label: str = cast(str, None)):
|
||||
self.sampler: Optional[Sampler] = sampler
|
||||
self.label: Optional[str] = label
|
||||
if label is None:
|
||||
warnings.warn('Each mutator should have an explicit label. Mutator without label is deprecated.', DeprecationWarning)
|
||||
self.label: str = label
|
||||
self._cur_model: Optional[Model] = None
|
||||
self._cur_choice_idx: Optional[int] = None
|
||||
|
||||
|
|
|
@ -1,23 +1,40 @@
|
|||
# Copyright (c) Microsoft Corporation.
|
||||
# Licensed under the MIT license.
|
||||
|
||||
import math
|
||||
import itertools
|
||||
import math
|
||||
import operator
|
||||
import warnings
|
||||
from typing import Any, List, Union, Dict, Optional, Callable, Iterable, NoReturn, TypeVar, Sequence
|
||||
from typing import (Any, Callable, Dict, Generic, Iterable, Iterator, List,
|
||||
NoReturn, Optional, Sequence, SupportsRound, TypeVar,
|
||||
Union, cast)
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
|
||||
from nni.common.hpo_utils import ParameterSpec
|
||||
from nni.common.serializer import Translatable
|
||||
from nni.retiarii.serializer import basic_unit
|
||||
from nni.retiarii.utils import STATE_DICT_PY_MAPPING_PARTIAL, ModelNamespace, NoContextError
|
||||
from nni.retiarii.utils import (STATE_DICT_PY_MAPPING_PARTIAL, ModelNamespace,
|
||||
NoContextError)
|
||||
|
||||
from .mutation_utils import Mutable, generate_new_label, get_fixed_value
|
||||
|
||||
__all__ = [
|
||||
# APIs
|
||||
'LayerChoice',
|
||||
'InputChoice',
|
||||
'ValueChoice',
|
||||
'ModelParameterChoice',
|
||||
'Placeholder',
|
||||
|
||||
__all__ = ['LayerChoice', 'InputChoice', 'ValueChoice', 'ModelParameterChoice', 'Placeholder', 'ChosenInputs']
|
||||
# Fixed module
|
||||
'ChosenInputs',
|
||||
|
||||
# Type utils
|
||||
'ReductionType',
|
||||
'MaybeChoice',
|
||||
'ChoiceOf',
|
||||
]
|
||||
|
||||
|
||||
class LayerChoice(Mutable):
|
||||
|
@ -130,26 +147,16 @@ class LayerChoice(Mutable):
|
|||
self.names.append(str(i))
|
||||
else:
|
||||
raise TypeError("Unsupported candidates type: {}".format(type(candidates)))
|
||||
self._first_module = self._modules[self.names[0]] # to make the dummy forward meaningful
|
||||
|
||||
@property
|
||||
def key(self):
|
||||
return self._key()
|
||||
|
||||
@torch.jit.ignore
|
||||
def _key(self):
|
||||
warnings.warn('Using key to access the identifier of LayerChoice is deprecated. Please use label instead.',
|
||||
category=DeprecationWarning)
|
||||
return self._label
|
||||
self._first_module = cast(nn.Module, self._modules[self.names[0]]) # to make the dummy forward meaningful
|
||||
|
||||
@property
|
||||
def label(self):
|
||||
return self._label
|
||||
|
||||
def __getitem__(self, idx):
|
||||
def __getitem__(self, idx: Union[int, str]) -> nn.Module:
|
||||
if isinstance(idx, str):
|
||||
return self._modules[idx]
|
||||
return list(self)[idx]
|
||||
return cast(nn.Module, self._modules[idx])
|
||||
return cast(nn.Module, list(self)[idx])
|
||||
|
||||
def __setitem__(self, idx, module):
|
||||
key = idx if isinstance(idx, str) else self.names[idx]
|
||||
|
@ -173,15 +180,6 @@ class LayerChoice(Mutable):
|
|||
def __iter__(self):
|
||||
return map(lambda name: self._modules[name], self.names)
|
||||
|
||||
@property
|
||||
def choices(self):
|
||||
return self._choices()
|
||||
|
||||
@torch.jit.ignore
|
||||
def _choices(self):
|
||||
warnings.warn("layer_choice.choices is deprecated. Use `list(layer_choice)` instead.", category=DeprecationWarning)
|
||||
return list(self)
|
||||
|
||||
def forward(self, x):
|
||||
"""
|
||||
The forward of layer choice is simply running the first candidate module.
|
||||
|
@ -266,16 +264,6 @@ class InputChoice(Mutable):
|
|||
assert self.reduction in ['mean', 'concat', 'sum', 'none']
|
||||
self._label = generate_new_label(label)
|
||||
|
||||
@property
|
||||
def key(self):
|
||||
return self._key()
|
||||
|
||||
@torch.jit.ignore
|
||||
def _key(self):
|
||||
warnings.warn('Using key to access the identifier of InputChoice is deprecated. Please use label instead.',
|
||||
category=DeprecationWarning)
|
||||
return self._label
|
||||
|
||||
@property
|
||||
def label(self):
|
||||
return self._label
|
||||
|
@ -350,7 +338,7 @@ def _valuechoice_codegen(*, _internal: bool = False):
|
|||
'truediv': '//', 'floordiv': '/', 'mod': '%',
|
||||
'lshift': '<<', 'rshift': '>>',
|
||||
'and': '&', 'xor': '^', 'or': '|',
|
||||
# no reflection
|
||||
# no reverse
|
||||
'lt': '<', 'le': '<=', 'eq': '==',
|
||||
'ne': '!=', 'ge': '>=', 'gt': '>',
|
||||
# NOTE
|
||||
|
@ -358,14 +346,14 @@ def _valuechoice_codegen(*, _internal: bool = False):
|
|||
# Might support them in future when we actually need them.
|
||||
}
|
||||
|
||||
binary_template = """ def __{op}__(self, other: Any) -> 'ValueChoiceX':
|
||||
binary_template = """ def __{op}__(self: 'ChoiceOf[Any]', other: 'MaybeChoice[Any]') -> 'ChoiceOf[Any]':
|
||||
return ValueChoiceX(operator.{opt}, '{{}} {sym} {{}}', [self, other])"""
|
||||
|
||||
binary_r_template = """ def __r{op}__(self, other: Any) -> 'ValueChoiceX':
|
||||
binary_r_template = """ def __r{op}__(self: 'ChoiceOf[Any]', other: 'MaybeChoice[Any]') -> 'ChoiceOf[Any]':
|
||||
return ValueChoiceX(operator.{opt}, '{{}} {sym} {{}}', [other, self])"""
|
||||
|
||||
unary_template = """ def __{op}__(self) -> 'ValueChoiceX':
|
||||
return ValueChoiceX(operator.{op}, '{sym}{{}}', [self])"""
|
||||
unary_template = """ def __{op}__(self: 'ChoiceOf[_value]') -> 'ChoiceOf[_value]':
|
||||
return cast(ChoiceOf[_value], ValueChoiceX(operator.{op}, '{sym}{{}}', [self]))"""
|
||||
|
||||
for op, sym in MAPPING.items():
|
||||
if op in ['neg', 'pos', 'invert']:
|
||||
|
@ -377,7 +365,13 @@ def _valuechoice_codegen(*, _internal: bool = False):
|
|||
print(binary_r_template.format(op=op, opt=opt, sym=sym) + '\n')
|
||||
|
||||
|
||||
def _valuechoice_staticmethod_helper(orig_func):
|
||||
_func = TypeVar('_func')
|
||||
_cand = TypeVar('_cand')
|
||||
_value = TypeVar('_value')
|
||||
|
||||
|
||||
def _valuechoice_staticmethod_helper(orig_func: _func) -> _func:
|
||||
if orig_func.__doc__ is not None:
|
||||
orig_func.__doc__ += """
|
||||
Notes
|
||||
-----
|
||||
|
@ -388,7 +382,7 @@ def _valuechoice_staticmethod_helper(orig_func):
|
|||
return orig_func
|
||||
|
||||
|
||||
class ValueChoiceX(Translatable, nn.Module):
|
||||
class ValueChoiceX(Generic[_cand], Translatable, nn.Module):
|
||||
"""Internal API. Implementation note:
|
||||
|
||||
The transformed (X) version of value choice.
|
||||
|
@ -408,7 +402,10 @@ class ValueChoiceX(Translatable, nn.Module):
|
|||
This class is implemented as a ``nn.Module`` so that it can be scanned by python engine / torchscript.
|
||||
"""
|
||||
|
||||
def __init__(self, function: Callable[..., Any], repr_template: str, arguments: List[Any], dry_run: bool = True):
|
||||
def __init__(self, function: Callable[..., _cand] = cast(Callable[..., _cand], None),
|
||||
repr_template: str = cast(str, None),
|
||||
arguments: List[Any] = cast('List[MaybeChoice[_cand]]', None),
|
||||
dry_run: bool = True):
|
||||
super().__init__()
|
||||
|
||||
if function is None:
|
||||
|
@ -431,7 +428,7 @@ class ValueChoiceX(Translatable, nn.Module):
|
|||
|
||||
def inner_choices(self) -> Iterable['ValueChoice']:
|
||||
"""
|
||||
Return an iterable of all leaf value choices.
|
||||
Return a generator of all leaf value choices.
|
||||
Useful for composition of value choices.
|
||||
No deduplication on labels. Mutators should take care.
|
||||
"""
|
||||
|
@ -439,18 +436,18 @@ class ValueChoiceX(Translatable, nn.Module):
|
|||
if isinstance(arg, ValueChoiceX):
|
||||
yield from arg.inner_choices()
|
||||
|
||||
def dry_run(self) -> Any:
|
||||
def dry_run(self) -> _cand:
|
||||
"""
|
||||
Dry run the value choice to get one of its possible evaluation results.
|
||||
"""
|
||||
# values are not used
|
||||
return self._evaluate(iter([]), True)
|
||||
|
||||
def all_options(self) -> Iterable[Any]:
|
||||
def all_options(self) -> Iterable[_cand]:
|
||||
"""Explore all possibilities of a value choice.
|
||||
"""
|
||||
# Record all inner choices: label -> candidates, no duplicates.
|
||||
dedup_inner_choices: Dict[str, List[Any]] = {}
|
||||
dedup_inner_choices: Dict[str, List[_cand]] = {}
|
||||
# All labels of leaf nodes on tree, possibly duplicates.
|
||||
all_labels: List[str] = []
|
||||
|
||||
|
@ -470,14 +467,14 @@ class ValueChoiceX(Translatable, nn.Module):
|
|||
chosen = dict(zip(dedup_labels, chosen))
|
||||
yield self.evaluate([chosen[label] for label in all_labels])
|
||||
|
||||
def evaluate(self, values: Iterable[Any]) -> Any:
|
||||
def evaluate(self, values: Iterable[_cand]) -> _cand:
|
||||
"""
|
||||
Evaluate the result of this group.
|
||||
``values`` should in the same order of ``inner_choices()``.
|
||||
"""
|
||||
return self._evaluate(iter(values), False)
|
||||
|
||||
def _evaluate(self, values: Iterable[Any], dry_run: bool = False) -> Any:
|
||||
def _evaluate(self, values: Iterator[_cand], dry_run: bool = False) -> _cand:
|
||||
# "values" iterates in the recursion
|
||||
eval_args = []
|
||||
for arg in self.arguments:
|
||||
|
@ -497,7 +494,7 @@ class ValueChoiceX(Translatable, nn.Module):
|
|||
"""
|
||||
return self.dry_run()
|
||||
|
||||
def __repr__(self):
|
||||
def __repr__(self) -> str:
|
||||
reprs = []
|
||||
for arg in self.arguments:
|
||||
if isinstance(arg, ValueChoiceX) and not isinstance(arg, ValueChoice):
|
||||
|
@ -513,7 +510,7 @@ class ValueChoiceX(Translatable, nn.Module):
|
|||
# Special operators that can be useful in place of built-in conditional operators.
|
||||
@staticmethod
|
||||
@_valuechoice_staticmethod_helper
|
||||
def to_int(obj: 'ValueChoiceOrAny') -> Union['ValueChoiceX', int]:
|
||||
def to_int(obj: 'MaybeChoice[Any]') -> 'MaybeChoice[int]':
|
||||
"""
|
||||
Convert a ``ValueChoice`` to an integer.
|
||||
"""
|
||||
|
@ -523,7 +520,7 @@ class ValueChoiceX(Translatable, nn.Module):
|
|||
|
||||
@staticmethod
|
||||
@_valuechoice_staticmethod_helper
|
||||
def to_float(obj: 'ValueChoiceOrAny') -> Union['ValueChoiceX', float]:
|
||||
def to_float(obj: 'MaybeChoice[Any]') -> 'MaybeChoice[float]':
|
||||
"""
|
||||
Convert a ``ValueChoice`` to a float.
|
||||
"""
|
||||
|
@ -533,9 +530,9 @@ class ValueChoiceX(Translatable, nn.Module):
|
|||
|
||||
@staticmethod
|
||||
@_valuechoice_staticmethod_helper
|
||||
def condition(pred: 'ValueChoiceOrAny',
|
||||
true: 'ValueChoiceOrAny',
|
||||
false: 'ValueChoiceOrAny') -> 'ValueChoiceOrAny':
|
||||
def condition(pred: 'MaybeChoice[bool]',
|
||||
true: 'MaybeChoice[_value]',
|
||||
false: 'MaybeChoice[_value]') -> 'MaybeChoice[_value]':
|
||||
"""
|
||||
Return ``true`` if the predicate ``pred`` is true else ``false``.
|
||||
|
||||
|
@ -549,35 +546,39 @@ class ValueChoiceX(Translatable, nn.Module):
|
|||
|
||||
@staticmethod
|
||||
@_valuechoice_staticmethod_helper
|
||||
def max(arg0: Union[Iterable['ValueChoiceOrAny'], 'ValueChoiceOrAny'],
|
||||
*args: List['ValueChoiceOrAny']) -> 'ValueChoiceOrAny':
|
||||
def max(arg0: Union[Iterable['MaybeChoice[_value]'], 'MaybeChoice[_value]'],
|
||||
*args: 'MaybeChoice[_value]') -> 'MaybeChoice[_value]':
|
||||
"""
|
||||
Returns the maximum value from a list of value choices.
|
||||
The usage should be similar to Python's built-in value choices,
|
||||
where the parameters could be an iterable, or at least two arguments.
|
||||
"""
|
||||
if not args:
|
||||
return ValueChoiceX.max(*list(arg0))
|
||||
lst = [arg0] + list(args)
|
||||
if not isinstance(arg0, Iterable):
|
||||
raise TypeError('Expect more than one items to compare max')
|
||||
return cast(MaybeChoice[_value], ValueChoiceX.max(*list(arg0)))
|
||||
lst = list(arg0) if isinstance(arg0, Iterable) else [arg0] + list(args)
|
||||
if any(isinstance(obj, ValueChoiceX) for obj in lst):
|
||||
return ValueChoiceX(max, 'max({})', lst)
|
||||
return max(lst)
|
||||
return max(cast(Any, lst))
|
||||
|
||||
@staticmethod
|
||||
@_valuechoice_staticmethod_helper
|
||||
def min(arg0: Union[Iterable['ValueChoiceOrAny'], 'ValueChoiceOrAny'],
|
||||
*args: List['ValueChoiceOrAny']) -> 'ValueChoiceOrAny':
|
||||
def min(arg0: Union[Iterable['MaybeChoice[_value]'], 'MaybeChoice[_value]'],
|
||||
*args: 'MaybeChoice[_value]') -> 'MaybeChoice[_value]':
|
||||
"""
|
||||
Returns the minunum value from a list of value choices.
|
||||
The usage should be similar to Python's built-in value choices,
|
||||
where the parameters could be an iterable, or at least two arguments.
|
||||
"""
|
||||
if not args:
|
||||
return ValueChoiceX.min(*list(arg0))
|
||||
lst = [arg0] + list(args)
|
||||
if not isinstance(arg0, Iterable):
|
||||
raise TypeError('Expect more than one items to compare min')
|
||||
return cast(MaybeChoice[_value], ValueChoiceX.min(*list(arg0)))
|
||||
lst = list(arg0) if isinstance(arg0, Iterable) else [arg0] + list(args)
|
||||
if any(isinstance(obj, ValueChoiceX) for obj in lst):
|
||||
return ValueChoiceX(min, 'min({})', lst)
|
||||
return min(lst)
|
||||
return min(cast(Any, lst))
|
||||
|
||||
def __hash__(self):
|
||||
# this is required because we have implemented ``__eq__``
|
||||
|
@ -589,24 +590,25 @@ class ValueChoiceX(Translatable, nn.Module):
|
|||
# - Implementation effort is too huge.
|
||||
# As a result, inplace operators like +=, *=, magic methods like `__getattr__` are not included in this list.
|
||||
|
||||
def __getitem__(self, key: Any) -> 'ValueChoiceX':
|
||||
def __getitem__(self: 'ChoiceOf[Any]', key: Any) -> 'ChoiceOf[Any]':
|
||||
return ValueChoiceX(lambda x, y: x[y], '{}[{}]', [self, key])
|
||||
|
||||
# region implement int, float, round, trunc, floor, ceil
|
||||
# because I believe sometimes we need them to calculate #channels
|
||||
# `__int__` and `__float__` are not supported because `__int__` is required to return int.
|
||||
def __round__(self, ndigits: Optional[Any] = None) -> 'ValueChoiceX':
|
||||
def __round__(self: 'ChoiceOf[SupportsRound[_value]]',
|
||||
ndigits: Optional['MaybeChoice[int]'] = None) -> 'ChoiceOf[Union[int, SupportsRound[_value]]]':
|
||||
if ndigits is not None:
|
||||
return ValueChoiceX(round, 'round({}, {})', [self, ndigits])
|
||||
return ValueChoiceX(round, 'round({})', [self])
|
||||
return cast(ChoiceOf[Union[int, SupportsRound[_value]]], ValueChoiceX(round, 'round({}, {})', [self, ndigits]))
|
||||
return cast(ChoiceOf[Union[int, SupportsRound[_value]]], ValueChoiceX(round, 'round({})', [self]))
|
||||
|
||||
def __trunc__(self) -> 'ValueChoiceX':
|
||||
def __trunc__(self) -> NoReturn:
|
||||
raise RuntimeError("Try to use `ValueChoice.to_int()` instead of `math.trunc()` on value choices.")
|
||||
|
||||
def __floor__(self) -> 'ValueChoiceX':
|
||||
def __floor__(self: 'ChoiceOf[Any]') -> 'ChoiceOf[int]':
|
||||
return ValueChoiceX(math.floor, 'math.floor({})', [self])
|
||||
|
||||
def __ceil__(self) -> 'ValueChoiceX':
|
||||
def __ceil__(self: 'ChoiceOf[Any]') -> 'ChoiceOf[int]':
|
||||
return ValueChoiceX(math.ceil, 'math.ceil({})', [self])
|
||||
|
||||
def __index__(self) -> NoReturn:
|
||||
|
@ -622,132 +624,133 @@ class ValueChoiceX(Translatable, nn.Module):
|
|||
|
||||
# region the following code is generated with codegen (see above)
|
||||
# Annotated with "region" because I want to collapse them in vscode
|
||||
def __neg__(self) -> 'ValueChoiceX':
|
||||
return ValueChoiceX(operator.neg, '-{}', [self])
|
||||
def __neg__(self: 'ChoiceOf[_value]') -> 'ChoiceOf[_value]':
|
||||
return cast(ChoiceOf[_value], ValueChoiceX(operator.neg, '-{}', [self]))
|
||||
|
||||
def __pos__(self) -> 'ValueChoiceX':
|
||||
return ValueChoiceX(operator.pos, '+{}', [self])
|
||||
def __pos__(self: 'ChoiceOf[_value]') -> 'ChoiceOf[_value]':
|
||||
return cast(ChoiceOf[_value], ValueChoiceX(operator.pos, '+{}', [self]))
|
||||
|
||||
def __invert__(self) -> 'ValueChoiceX':
|
||||
return ValueChoiceX(operator.invert, '~{}', [self])
|
||||
def __invert__(self: 'ChoiceOf[_value]') -> 'ChoiceOf[_value]':
|
||||
return cast(ChoiceOf[_value], ValueChoiceX(operator.invert, '~{}', [self]))
|
||||
|
||||
def __add__(self, other: Any) -> 'ValueChoiceX':
|
||||
def __add__(self: 'ChoiceOf[Any]', other: 'MaybeChoice[Any]') -> 'ChoiceOf[Any]':
|
||||
return ValueChoiceX(operator.add, '{} + {}', [self, other])
|
||||
|
||||
def __radd__(self, other: Any) -> 'ValueChoiceX':
|
||||
def __radd__(self: 'ChoiceOf[Any]', other: 'MaybeChoice[Any]') -> 'ChoiceOf[Any]':
|
||||
return ValueChoiceX(operator.add, '{} + {}', [other, self])
|
||||
|
||||
def __sub__(self, other: Any) -> 'ValueChoiceX':
|
||||
def __sub__(self: 'ChoiceOf[Any]', other: 'MaybeChoice[Any]') -> 'ChoiceOf[Any]':
|
||||
return ValueChoiceX(operator.sub, '{} - {}', [self, other])
|
||||
|
||||
def __rsub__(self, other: Any) -> 'ValueChoiceX':
|
||||
def __rsub__(self: 'ChoiceOf[Any]', other: 'MaybeChoice[Any]') -> 'ChoiceOf[Any]':
|
||||
return ValueChoiceX(operator.sub, '{} - {}', [other, self])
|
||||
|
||||
def __mul__(self, other: Any) -> 'ValueChoiceX':
|
||||
def __mul__(self: 'ChoiceOf[Any]', other: 'MaybeChoice[Any]') -> 'ChoiceOf[Any]':
|
||||
return ValueChoiceX(operator.mul, '{} * {}', [self, other])
|
||||
|
||||
def __rmul__(self, other: Any) -> 'ValueChoiceX':
|
||||
def __rmul__(self: 'ChoiceOf[Any]', other: 'MaybeChoice[Any]') -> 'ChoiceOf[Any]':
|
||||
return ValueChoiceX(operator.mul, '{} * {}', [other, self])
|
||||
|
||||
def __matmul__(self, other: Any) -> 'ValueChoiceX':
|
||||
def __matmul__(self: 'ChoiceOf[Any]', other: 'MaybeChoice[Any]') -> 'ChoiceOf[Any]':
|
||||
return ValueChoiceX(operator.matmul, '{} @ {}', [self, other])
|
||||
|
||||
def __rmatmul__(self, other: Any) -> 'ValueChoiceX':
|
||||
def __rmatmul__(self: 'ChoiceOf[Any]', other: 'MaybeChoice[Any]') -> 'ChoiceOf[Any]':
|
||||
return ValueChoiceX(operator.matmul, '{} @ {}', [other, self])
|
||||
|
||||
def __truediv__(self, other: Any) -> 'ValueChoiceX':
|
||||
def __truediv__(self: 'ChoiceOf[Any]', other: 'MaybeChoice[Any]') -> 'ChoiceOf[Any]':
|
||||
return ValueChoiceX(operator.truediv, '{} // {}', [self, other])
|
||||
|
||||
def __rtruediv__(self, other: Any) -> 'ValueChoiceX':
|
||||
def __rtruediv__(self: 'ChoiceOf[Any]', other: 'MaybeChoice[Any]') -> 'ChoiceOf[Any]':
|
||||
return ValueChoiceX(operator.truediv, '{} // {}', [other, self])
|
||||
|
||||
def __floordiv__(self, other: Any) -> 'ValueChoiceX':
|
||||
def __floordiv__(self: 'ChoiceOf[Any]', other: 'MaybeChoice[Any]') -> 'ChoiceOf[Any]':
|
||||
return ValueChoiceX(operator.floordiv, '{} / {}', [self, other])
|
||||
|
||||
def __rfloordiv__(self, other: Any) -> 'ValueChoiceX':
|
||||
def __rfloordiv__(self: 'ChoiceOf[Any]', other: 'MaybeChoice[Any]') -> 'ChoiceOf[Any]':
|
||||
return ValueChoiceX(operator.floordiv, '{} / {}', [other, self])
|
||||
|
||||
def __mod__(self, other: Any) -> 'ValueChoiceX':
|
||||
def __mod__(self: 'ChoiceOf[Any]', other: 'MaybeChoice[Any]') -> 'ChoiceOf[Any]':
|
||||
return ValueChoiceX(operator.mod, '{} % {}', [self, other])
|
||||
|
||||
def __rmod__(self, other: Any) -> 'ValueChoiceX':
|
||||
def __rmod__(self: 'ChoiceOf[Any]', other: 'MaybeChoice[Any]') -> 'ChoiceOf[Any]':
|
||||
return ValueChoiceX(operator.mod, '{} % {}', [other, self])
|
||||
|
||||
def __lshift__(self, other: Any) -> 'ValueChoiceX':
|
||||
def __lshift__(self: 'ChoiceOf[Any]', other: 'MaybeChoice[Any]') -> 'ChoiceOf[Any]':
|
||||
return ValueChoiceX(operator.lshift, '{} << {}', [self, other])
|
||||
|
||||
def __rlshift__(self, other: Any) -> 'ValueChoiceX':
|
||||
def __rlshift__(self: 'ChoiceOf[Any]', other: 'MaybeChoice[Any]') -> 'ChoiceOf[Any]':
|
||||
return ValueChoiceX(operator.lshift, '{} << {}', [other, self])
|
||||
|
||||
def __rshift__(self, other: Any) -> 'ValueChoiceX':
|
||||
def __rshift__(self: 'ChoiceOf[Any]', other: 'MaybeChoice[Any]') -> 'ChoiceOf[Any]':
|
||||
return ValueChoiceX(operator.rshift, '{} >> {}', [self, other])
|
||||
|
||||
def __rrshift__(self, other: Any) -> 'ValueChoiceX':
|
||||
def __rrshift__(self: 'ChoiceOf[Any]', other: 'MaybeChoice[Any]') -> 'ChoiceOf[Any]':
|
||||
return ValueChoiceX(operator.rshift, '{} >> {}', [other, self])
|
||||
|
||||
def __and__(self, other: Any) -> 'ValueChoiceX':
|
||||
def __and__(self: 'ChoiceOf[Any]', other: 'MaybeChoice[Any]') -> 'ChoiceOf[Any]':
|
||||
return ValueChoiceX(operator.and_, '{} & {}', [self, other])
|
||||
|
||||
def __rand__(self, other: Any) -> 'ValueChoiceX':
|
||||
def __rand__(self: 'ChoiceOf[Any]', other: 'MaybeChoice[Any]') -> 'ChoiceOf[Any]':
|
||||
return ValueChoiceX(operator.and_, '{} & {}', [other, self])
|
||||
|
||||
def __xor__(self, other: Any) -> 'ValueChoiceX':
|
||||
def __xor__(self: 'ChoiceOf[Any]', other: 'MaybeChoice[Any]') -> 'ChoiceOf[Any]':
|
||||
return ValueChoiceX(operator.xor, '{} ^ {}', [self, other])
|
||||
|
||||
def __rxor__(self, other: Any) -> 'ValueChoiceX':
|
||||
def __rxor__(self: 'ChoiceOf[Any]', other: 'MaybeChoice[Any]') -> 'ChoiceOf[Any]':
|
||||
return ValueChoiceX(operator.xor, '{} ^ {}', [other, self])
|
||||
|
||||
def __or__(self, other: Any) -> 'ValueChoiceX':
|
||||
def __or__(self: 'ChoiceOf[Any]', other: 'MaybeChoice[Any]') -> 'ChoiceOf[Any]':
|
||||
return ValueChoiceX(operator.or_, '{} | {}', [self, other])
|
||||
|
||||
def __ror__(self, other: Any) -> 'ValueChoiceX':
|
||||
def __ror__(self: 'ChoiceOf[Any]', other: 'MaybeChoice[Any]') -> 'ChoiceOf[Any]':
|
||||
return ValueChoiceX(operator.or_, '{} | {}', [other, self])
|
||||
|
||||
def __lt__(self, other: Any) -> 'ValueChoiceX':
|
||||
def __lt__(self: 'ChoiceOf[Any]', other: 'MaybeChoice[Any]') -> 'ChoiceOf[Any]':
|
||||
return ValueChoiceX(operator.lt, '{} < {}', [self, other])
|
||||
|
||||
def __le__(self, other: Any) -> 'ValueChoiceX':
|
||||
def __le__(self: 'ChoiceOf[Any]', other: 'MaybeChoice[Any]') -> 'ChoiceOf[Any]':
|
||||
return ValueChoiceX(operator.le, '{} <= {}', [self, other])
|
||||
|
||||
def __eq__(self, other: Any) -> 'ValueChoiceX':
|
||||
def __eq__(self: 'ChoiceOf[Any]', other: 'MaybeChoice[Any]') -> 'ChoiceOf[Any]':
|
||||
return ValueChoiceX(operator.eq, '{} == {}', [self, other])
|
||||
|
||||
def __ne__(self, other: Any) -> 'ValueChoiceX':
|
||||
def __ne__(self: 'ChoiceOf[Any]', other: 'MaybeChoice[Any]') -> 'ChoiceOf[Any]':
|
||||
return ValueChoiceX(operator.ne, '{} != {}', [self, other])
|
||||
|
||||
def __ge__(self, other: Any) -> 'ValueChoiceX':
|
||||
def __ge__(self: 'ChoiceOf[Any]', other: 'MaybeChoice[Any]') -> 'ChoiceOf[Any]':
|
||||
return ValueChoiceX(operator.ge, '{} >= {}', [self, other])
|
||||
|
||||
def __gt__(self, other: Any) -> 'ValueChoiceX':
|
||||
def __gt__(self: 'ChoiceOf[Any]', other: 'MaybeChoice[Any]') -> 'ChoiceOf[Any]':
|
||||
return ValueChoiceX(operator.gt, '{} > {}', [self, other])
|
||||
# endregion
|
||||
|
||||
# __pow__, __divmod__, __abs__ are special ones.
|
||||
# Not easy to cover those cases with codegen.
|
||||
def __pow__(self, other: Any, modulo: Optional[Any] = None) -> 'ValueChoiceX':
|
||||
def __pow__(self: 'ChoiceOf[Any]', other: 'MaybeChoice[Any]', modulo: Optional['MaybeChoice[Any]'] = None) -> 'ChoiceOf[Any]':
|
||||
if modulo is not None:
|
||||
return ValueChoiceX(pow, 'pow({}, {}, {})', [self, other, modulo])
|
||||
return ValueChoiceX(lambda a, b: a ** b, '{} ** {}', [self, other])
|
||||
|
||||
def __rpow__(self, other: Any, modulo: Optional[Any] = None) -> 'ValueChoiceX':
|
||||
def __rpow__(self: 'ChoiceOf[Any]', other: 'MaybeChoice[Any]', modulo: Optional['MaybeChoice[Any]'] = None) -> 'ChoiceOf[Any]':
|
||||
if modulo is not None:
|
||||
return ValueChoiceX(pow, 'pow({}, {}, {})', [other, self, modulo])
|
||||
return ValueChoiceX(lambda a, b: a ** b, '{} ** {}', [other, self])
|
||||
|
||||
def __divmod__(self, other: Any) -> 'ValueChoiceX':
|
||||
def __divmod__(self: 'ChoiceOf[Any]', other: 'MaybeChoice[Any]') -> 'ChoiceOf[Any]':
|
||||
return ValueChoiceX(divmod, 'divmod({}, {})', [self, other])
|
||||
|
||||
def __rdivmod__(self, other: Any) -> 'ValueChoiceX':
|
||||
def __rdivmod__(self: 'ChoiceOf[Any]', other: 'MaybeChoice[Any]') -> 'ChoiceOf[Any]':
|
||||
return ValueChoiceX(divmod, 'divmod({}, {})', [other, self])
|
||||
|
||||
def __abs__(self) -> 'ValueChoiceX':
|
||||
def __abs__(self: 'ChoiceOf[Any]') -> 'ChoiceOf[Any]':
|
||||
return ValueChoiceX(abs, 'abs({})', [self])
|
||||
|
||||
|
||||
ValueChoiceOrAny = TypeVar('ValueChoiceOrAny', ValueChoiceX, Any)
|
||||
ChoiceOf = ValueChoiceX
|
||||
MaybeChoice = Union[ValueChoiceX[_cand], _cand]
|
||||
|
||||
|
||||
class ValueChoice(ValueChoiceX, Mutable):
|
||||
class ValueChoice(ValueChoiceX[_cand], Mutable):
|
||||
"""
|
||||
ValueChoice is to choose one from ``candidates``. The most common use cases are:
|
||||
|
||||
|
@ -865,14 +868,14 @@ class ValueChoice(ValueChoiceX, Mutable):
|
|||
# FIXME: prior is designed but not supported yet
|
||||
|
||||
@classmethod
|
||||
def create_fixed_module(cls, candidates: List[Any], *, label: Optional[str] = None, **kwargs):
|
||||
def create_fixed_module(cls, candidates: List[_cand], *, label: Optional[str] = None, **kwargs):
|
||||
value = get_fixed_value(label)
|
||||
if value not in candidates:
|
||||
raise ValueError(f'Value {value} does not belong to the candidates: {candidates}.')
|
||||
return value
|
||||
|
||||
def __init__(self, candidates: List[Any], *, prior: Optional[List[float]] = None, label: Optional[str] = None):
|
||||
super().__init__(None, None, None)
|
||||
def __init__(self, candidates: List[_cand], *, prior: Optional[List[float]] = None, label: Optional[str] = None):
|
||||
super().__init__()
|
||||
self.candidates = candidates
|
||||
self.prior = prior or [1 / len(candidates) for _ in range(len(candidates))]
|
||||
assert abs(sum(self.prior) - 1) < 1e-5, 'Sum of prior distribution is not 1.'
|
||||
|
@ -894,10 +897,10 @@ class ValueChoice(ValueChoiceX, Mutable):
|
|||
# yield self because self is the only value choice here
|
||||
yield self
|
||||
|
||||
def dry_run(self) -> Any:
|
||||
def dry_run(self) -> _cand:
|
||||
return self.candidates[0]
|
||||
|
||||
def _evaluate(self, values: Iterable[Any], dry_run: bool = False) -> Any:
|
||||
def _evaluate(self, values: Iterator[_cand], dry_run: bool = False) -> _cand:
|
||||
if dry_run:
|
||||
return self.candidates[0]
|
||||
try:
|
||||
|
@ -986,6 +989,7 @@ class ModelParameterChoice:
|
|||
Examples
|
||||
--------
|
||||
Get a dynamic-shaped parameter. Because ``torch.zeros`` is not a basic unit, we can't use :class:`ValueChoice` on it.
|
||||
|
||||
>>> parameter_dim = nn.ModelParameterChoice([64, 128, 256])
|
||||
>>> self.token = nn.Parameter(torch.zeros(1, parameter_dim, 32, 32))
|
||||
"""
|
||||
|
@ -1016,12 +1020,14 @@ class ModelParameterChoice:
|
|||
if default not in candidates:
|
||||
# could be callable
|
||||
try:
|
||||
default = default(candidates)
|
||||
default = cast(Callable[[List[ValueType]], ValueType], default)(candidates)
|
||||
except TypeError as e:
|
||||
if 'not callable' in str(e):
|
||||
raise TypeError("`default` is not in `candidates`, and it's also not callable.")
|
||||
raise
|
||||
|
||||
default = cast(ValueType, default)
|
||||
|
||||
label = generate_new_label(label)
|
||||
parameter_spec = ParameterSpec(
|
||||
label, # name
|
||||
|
|
|
@ -1,6 +1,9 @@
|
|||
# Copyright (c) Microsoft Corporation.
|
||||
# Licensed under the MIT license.
|
||||
|
||||
import copy
|
||||
import warnings
|
||||
from typing import Callable, Dict, List, Union, Optional, Tuple
|
||||
from typing import Callable, Dict, List, Union, Optional, Tuple, Sequence, cast
|
||||
try:
|
||||
from typing import Literal
|
||||
except ImportError:
|
||||
|
@ -193,8 +196,10 @@ class Cell(nn.Module):
|
|||
def __init__(self,
|
||||
op_candidates: Union[
|
||||
Callable[[], List[nn.Module]],
|
||||
List[Union[nn.Module, _cell_op_factory_type]],
|
||||
Dict[str, Union[nn.Module, _cell_op_factory_type]]
|
||||
List[nn.Module],
|
||||
List[_cell_op_factory_type],
|
||||
Dict[str, nn.Module],
|
||||
Dict[str, _cell_op_factory_type]
|
||||
],
|
||||
num_nodes: int,
|
||||
num_ops_per_node: int = 1,
|
||||
|
@ -251,8 +256,8 @@ class Cell(nn.Module):
|
|||
ops = self._convert_op_candidates(op_candidates, i, k, chosen)
|
||||
|
||||
# though it's layer choice and input choice here, in fixed mode, the chosen module will be created.
|
||||
self.ops[-1].append(LayerChoice(ops, label=f'{self.label}/op_{i}_{k}'))
|
||||
self.inputs[-1].append(inp)
|
||||
cast(ModuleList, self.ops[-1]).append(LayerChoice(ops, label=f'{self.label}/op_{i}_{k}'))
|
||||
cast(ModuleList, self.inputs[-1]).append(inp)
|
||||
|
||||
@property
|
||||
def label(self):
|
||||
|
@ -274,13 +279,17 @@ class Cell(nn.Module):
|
|||
By default, it's the output of ``merge_op``, which is a contenation (on ``concat_dim``)
|
||||
of some of (possibly all) the nodes' outputs in the cell.
|
||||
"""
|
||||
processed_inputs: List[torch.Tensor]
|
||||
if len(inputs) == 1 and isinstance(inputs[0], list):
|
||||
inputs = inputs[0]
|
||||
processed_inputs = list(inputs[0]) # shallow copy
|
||||
else:
|
||||
inputs = list(inputs)
|
||||
assert len(inputs) == self.num_predecessors, 'The number of inputs must be equal to `num_predecessors`.'
|
||||
states = self.preprocessor(inputs)
|
||||
for ops, inps in zip(self.ops, self.inputs):
|
||||
processed_inputs = cast(List[torch.Tensor], list(inputs))
|
||||
assert len(processed_inputs) == self.num_predecessors, 'The number of inputs must be equal to `num_predecessors`.'
|
||||
states: List[torch.Tensor] = self.preprocessor(processed_inputs)
|
||||
for ops, inps in zip(
|
||||
cast(Sequence[Sequence[LayerChoice]], self.ops),
|
||||
cast(Sequence[Sequence[InputChoice]], self.inputs)
|
||||
):
|
||||
current_state = []
|
||||
for op, inp in zip(ops, inps):
|
||||
current_state.append(op(inp(states)))
|
||||
|
@ -291,7 +300,7 @@ class Cell(nn.Module):
|
|||
this_cell = torch.cat(states[self.num_predecessors:], self.concat_dim)
|
||||
else:
|
||||
this_cell = torch.cat([states[k] for k in self.output_node_indices], self.concat_dim)
|
||||
return self.postprocessor(this_cell, inputs)
|
||||
return self.postprocessor(this_cell, processed_inputs)
|
||||
|
||||
@staticmethod
|
||||
def _convert_op_candidates(op_candidates, node_index, op_index, chosen) -> Union[Dict[str, nn.Module], List[nn.Module]]:
|
||||
|
|
|
@ -1,14 +1,17 @@
|
|||
# Copyright (c) Microsoft Corporation.
|
||||
# Licensed under the MIT license.
|
||||
|
||||
import copy
|
||||
import warnings
|
||||
from collections import OrderedDict
|
||||
from typing import Callable, List, Union, Tuple, Optional
|
||||
from typing import Callable, List, Dict, Union, Tuple, Optional
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
|
||||
from nni.retiarii.utils import NoContextError, STATE_DICT_PY_MAPPING_PARTIAL
|
||||
|
||||
from .api import LayerChoice, ValueChoice, ValueChoiceX
|
||||
from .api import LayerChoice, ValueChoice, ValueChoiceX, ChoiceOf
|
||||
from .cell import Cell
|
||||
from .nasbench101 import NasBench101Cell, NasBench101Mutator
|
||||
from .mutation_utils import Mutable, generate_new_label, get_fixed_value
|
||||
|
@ -64,7 +67,7 @@ class Repeat(Mutable):
|
|||
List[Callable[[int], nn.Module]],
|
||||
nn.Module,
|
||||
List[nn.Module]],
|
||||
depth: Union[int, Tuple[int, int], ValueChoice], *, label: Optional[str] = None):
|
||||
depth: Union[int, Tuple[int, int], ChoiceOf[int]], *, label: Optional[str] = None):
|
||||
if isinstance(depth, tuple):
|
||||
# we can't create a value choice here,
|
||||
# otherwise we will have two value choices, one created here, another in init.
|
||||
|
@ -90,7 +93,7 @@ class Repeat(Mutable):
|
|||
List[Callable[[int], nn.Module]],
|
||||
nn.Module,
|
||||
List[nn.Module]],
|
||||
depth: Union[int, Tuple[int, int]], *, label: Optional[str] = None):
|
||||
depth: Union[int, Tuple[int, int], ChoiceOf[int]], *, label: Optional[str] = None):
|
||||
super().__init__()
|
||||
|
||||
self._label = None # by default, no label
|
||||
|
@ -192,7 +195,7 @@ class NasBench201Cell(nn.Module):
|
|||
return OrderedDict([(str(i), t) for i, t in enumerate(x)])
|
||||
return OrderedDict(x)
|
||||
|
||||
def __init__(self, op_candidates: List[Callable[[int, int], nn.Module]],
|
||||
def __init__(self, op_candidates: Union[Dict[str, Callable[[int, int], nn.Module]], List[Callable[[int, int], nn.Module]]],
|
||||
in_features: int, out_features: int, num_tensors: int = 4,
|
||||
label: Optional[str] = None):
|
||||
super().__init__()
|
||||
|
@ -214,16 +217,15 @@ class NasBench201Cell(nn.Module):
|
|||
node_ops.append(LayerChoice(op_choices, label=f'{self._label}__{j}_{tid}')) # put __ here to be compatible with base engine
|
||||
self.layers.append(node_ops)
|
||||
|
||||
def forward(self, inputs):
|
||||
def forward(self, inputs: torch.Tensor) -> torch.Tensor:
|
||||
"""
|
||||
The forward of input choice is simply selecting first on all choices.
|
||||
It shouldn't be called directly by users in most cases.
|
||||
"""
|
||||
tensors = [inputs]
|
||||
tensors: List[torch.Tensor] = [inputs]
|
||||
for layer in self.layers:
|
||||
current_tensor = []
|
||||
for i, op in enumerate(layer):
|
||||
current_tensor.append(op(tensors[i]))
|
||||
current_tensor = torch.sum(torch.stack(current_tensor), 0)
|
||||
tensors.append(current_tensor)
|
||||
current_tensor: List[torch.Tensor] = []
|
||||
for i, op in enumerate(layer): # type: ignore
|
||||
current_tensor.append(op(tensors[i])) # type: ignore
|
||||
tensors.append(torch.sum(torch.stack(current_tensor), 0))
|
||||
return tensors[-1]
|
||||
|
|
|
@ -1,6 +1,8 @@
|
|||
# Copyright (c) Microsoft Corporation.
|
||||
# Licensed under the MIT license.
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from packaging.version import Version
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
|
@ -233,7 +235,7 @@ class AutoActivation(nn.Module):
|
|||
-----
|
||||
Current `beta` is not per-channel parameter.
|
||||
"""
|
||||
def __init__(self, unit_num: int = 1, label: str = None):
|
||||
def __init__(self, unit_num: int = 1, label: str | None = None):
|
||||
super().__init__()
|
||||
self._label = generate_new_label(label)
|
||||
self.unaries = nn.ModuleList()
|
||||
|
|
|
@ -1,3 +1,6 @@
|
|||
# Copyright (c) Microsoft Corporation.
|
||||
# Licensed under the MIT license.
|
||||
|
||||
from typing import Any, Optional, Tuple, Union
|
||||
|
||||
import torch.nn as nn
|
||||
|
@ -41,7 +44,7 @@ def generate_new_label(label: Optional[str]):
|
|||
return label
|
||||
|
||||
|
||||
def get_fixed_value(label: str) -> Any:
|
||||
def get_fixed_value(label: Optional[str]) -> Any:
|
||||
ret = get_current_context('fixed')
|
||||
try:
|
||||
return ret[generate_new_label(label)]
|
||||
|
@ -49,7 +52,7 @@ def get_fixed_value(label: str) -> Any:
|
|||
raise KeyError(f'Fixed context with {label} not found. Existing values are: {ret}')
|
||||
|
||||
|
||||
def get_fixed_dict(label_prefix: str) -> Tuple[str, Any]:
|
||||
def get_fixed_dict(label_prefix: Optional[str]) -> Tuple[str, Any]:
|
||||
ret = get_current_context('fixed')
|
||||
try:
|
||||
label_prefix = generate_new_label(label_prefix)
|
||||
|
|
|
@ -2,7 +2,7 @@
|
|||
# Licensed under the MIT license.
|
||||
|
||||
import inspect
|
||||
from typing import Any, List, Optional, Tuple, Dict, Iterator
|
||||
from typing import Any, List, Optional, Tuple, Dict, Iterator, Iterable, cast
|
||||
|
||||
import torch.nn as nn
|
||||
|
||||
|
@ -28,12 +28,14 @@ class LayerChoiceMutator(Mutator):
|
|||
# Each layer choice corresponds to a cell, which is unconnected in the base graph.
|
||||
# We add the connections here in the mutation logic.
|
||||
# Thus, the mutated model should not be mutated again. Everything should be based on the original base graph.
|
||||
target = model.graphs[node.operation.cell_name]
|
||||
target = model.graphs[cast(Cell, node.operation).cell_name]
|
||||
chosen_node = target.get_node_by_name(chosen)
|
||||
assert chosen_node is not None
|
||||
target.add_edge((target.input_node, 0), (chosen_node, None))
|
||||
target.add_edge((chosen_node, None), (target.output_node, None))
|
||||
model.get_node_by_name(node.name).update_operation(Cell(node.operation.cell_name))
|
||||
operation = cast(Cell, node.operation)
|
||||
target_node = cast(Node, model.get_node_by_name(node.name))
|
||||
target_node.update_operation(Cell(operation.cell_name))
|
||||
|
||||
# remove redundant nodes
|
||||
for rm_node in list(target.hidden_nodes): # remove from a list on the fly will cause issues
|
||||
|
@ -57,7 +59,7 @@ class InputChoiceMutator(Mutator):
|
|||
else:
|
||||
chosen = [self.choice(candidates) for _ in range(n_chosen)]
|
||||
for node in self.nodes:
|
||||
target = model.get_node_by_name(node.name)
|
||||
target = cast(Node, model.get_node_by_name(node.name))
|
||||
target.update_operation('__torch__.nni.retiarii.nn.pytorch.ChosenInputs',
|
||||
{'chosen': chosen, 'reduction': node.operation.parameters['reduction']})
|
||||
|
||||
|
@ -74,7 +76,7 @@ class ValueChoiceMutator(Mutator):
|
|||
# no need to support transformation here,
|
||||
# because it is naturally done in forward loop
|
||||
for node in self.nodes:
|
||||
target = model.get_node_by_name(node.name)
|
||||
target = cast(Node, model.get_node_by_name(node.name))
|
||||
target.update_operation('prim::Constant', {'type': type(chosen).__name__, 'value': chosen})
|
||||
|
||||
|
||||
|
@ -86,7 +88,7 @@ class ParameterChoiceLeafMutator(Mutator):
|
|||
super().__init__(label=label)
|
||||
self.candidates = candidates
|
||||
|
||||
def mutate(self, model: Model) -> Model:
|
||||
def mutate(self, model: Model) -> None:
|
||||
# leave a record here
|
||||
# real mutations will be done in ParameterChoiceMutator
|
||||
self.choice(self.candidates)
|
||||
|
@ -103,7 +105,7 @@ class ParameterChoiceMutator(Mutator):
|
|||
|
||||
self.nodes = nodes
|
||||
|
||||
def mutate(self, model: Model) -> Model:
|
||||
def mutate(self, model: Model) -> None:
|
||||
# looks like {"label1": "cat", "label2": 123}
|
||||
value_choice_decisions = {}
|
||||
for mutation in model.history:
|
||||
|
@ -122,7 +124,7 @@ class ParameterChoiceMutator(Mutator):
|
|||
result_value = value_choice.evaluate(leaf_node_values)
|
||||
|
||||
# update model with graph mutation primitives
|
||||
target = model.get_node_by_name(node.name)
|
||||
target = cast(Node, model.get_node_by_name(node.name))
|
||||
target.update_operation(target.operation.type, {**target.operation.parameters, argname: result_value})
|
||||
|
||||
|
||||
|
@ -138,20 +140,20 @@ class RepeatMutator(Mutator):
|
|||
while u != graph.output_node:
|
||||
if u != graph.input_node:
|
||||
chain.append(u)
|
||||
assert len(u.successors) == 1, f'This graph is an illegal chain. {u} has output {u.successor}.'
|
||||
assert len(u.successors) == 1, f'This graph is an illegal chain. {u} has output {u.successors}.'
|
||||
u = u.successors[0]
|
||||
return chain
|
||||
|
||||
def mutate(self, model):
|
||||
for node in self.nodes:
|
||||
# the logic here is similar to layer choice. We find cell attached to each node.
|
||||
target: Graph = model.graphs[node.operation.cell_name]
|
||||
target: Graph = model.graphs[cast(Cell, node.operation).cell_name]
|
||||
chain = self._retrieve_chain_from_graph(target)
|
||||
# and we get the chosen depth (by value choice)
|
||||
node_in_model = model.get_node_by_name(node.name)
|
||||
node_in_model = cast(Node, model.get_node_by_name(node.name))
|
||||
# depth is a value choice in base model
|
||||
# but it's already mutated by a ParameterChoiceMutator here
|
||||
chosen_depth = node_in_model.operation.parameters['depth']
|
||||
chosen_depth: int = node_in_model.operation.parameters['depth']
|
||||
for edge in chain[chosen_depth - 1].outgoing_edges:
|
||||
edge.remove()
|
||||
target.add_edge((chain[chosen_depth - 1], None), (target.output_node, None))
|
||||
|
@ -159,8 +161,11 @@ class RepeatMutator(Mutator):
|
|||
for edge in rm_node.outgoing_edges:
|
||||
edge.remove()
|
||||
rm_node.remove()
|
||||
|
||||
# to delete the unused parameters.
|
||||
model.get_node_by_name(node.name).update_operation(Cell(node.operation.cell_name))
|
||||
target_node = cast(Node, model.get_node_by_name(node.name))
|
||||
cell_operation = cast(Cell, node.operation)
|
||||
target_node.update_operation(Cell(cell_operation.cell_name))
|
||||
|
||||
|
||||
def process_inline_mutation(model: Model) -> Optional[List[Mutator]]:
|
||||
|
@ -241,7 +246,7 @@ class ManyChooseManyMutator(Mutator):
|
|||
Choose based on labels. Will not affect the model itself.
|
||||
"""
|
||||
|
||||
def __init__(self, label: Optional[str]):
|
||||
def __init__(self, label: str):
|
||||
super().__init__(label=label)
|
||||
|
||||
@staticmethod
|
||||
|
@ -257,7 +262,7 @@ class ManyChooseManyMutator(Mutator):
|
|||
return node.operation.parameters['n_chosen']
|
||||
return 1
|
||||
|
||||
def mutate(self, model: Model):
|
||||
def mutate(self, model: Model) -> None:
|
||||
# this mutate does not have any effect, but it is recorded in the mutation history
|
||||
for node in model.get_nodes_by_label(self.label):
|
||||
n_chosen = self.number_of_chosen(node)
|
||||
|
@ -280,12 +285,12 @@ def extract_mutation_from_pt_module(pytorch_model: nn.Module) -> Tuple[Model, Op
|
|||
if not is_model_wrapped(pytorch_model):
|
||||
raise ValueError('Please annotate the model with @model_wrapper decorator in python execution mode '
|
||||
'if your model has init parameters.')
|
||||
model.python_init_params = pytorch_model.trace_kwargs
|
||||
model.python_init_params = cast(dict, pytorch_model.trace_kwargs)
|
||||
else:
|
||||
model.python_init_params = {}
|
||||
|
||||
# hyper-parameter choice
|
||||
namespace: ModelNamespace = pytorch_model._model_namespace
|
||||
namespace: ModelNamespace = cast(ModelNamespace, pytorch_model._model_namespace)
|
||||
for param_spec in namespace.parameter_specs:
|
||||
assert param_spec.categorical and param_spec.type == 'choice'
|
||||
node = graph.add_node(f'param_spec_{param_spec.name}', 'ModelParameterChoice', {'candidates': param_spec.values})
|
||||
|
@ -294,7 +299,8 @@ def extract_mutation_from_pt_module(pytorch_model: nn.Module) -> Tuple[Model, Op
|
|||
for name, module in pytorch_model.named_modules():
|
||||
# tricky case: value choice that serves as parameters are stored in traced arguments
|
||||
if is_basic_unit(module):
|
||||
for key, value in module.trace_kwargs.items():
|
||||
trace_kwargs = cast(Dict[str, Any], module.trace_kwargs)
|
||||
for key, value in trace_kwargs.items():
|
||||
if isinstance(value, ValueChoiceX):
|
||||
for i, choice in enumerate(value.inner_choices()):
|
||||
node = graph.add_node(f'{name}.init.{key}.{i}', 'ValueChoice', {'candidates': choice.candidates})
|
||||
|
@ -329,14 +335,17 @@ def extract_mutation_from_pt_module(pytorch_model: nn.Module) -> Tuple[Model, Op
|
|||
mutators = []
|
||||
mutators_final = []
|
||||
for nodes in _group_by_label_and_type(graph.hidden_nodes):
|
||||
label = nodes[0].label
|
||||
assert label is not None, f'label of {nodes[0]} can not be None.'
|
||||
assert _is_all_equal(map(lambda n: n.operation.type, nodes)), \
|
||||
f'Node with label "{nodes[0].label}" does not all have the same type.'
|
||||
f'Node with label "{label}" does not all have the same type.'
|
||||
assert _is_all_equal(map(lambda n: n.operation.parameters, nodes)), \
|
||||
f'Node with label "{nodes[0].label}" does not agree on parameters.'
|
||||
f'Node with label "{label}" does not agree on parameters.'
|
||||
if nodes[0].operation.type == 'NasBench101Cell':
|
||||
mutators_final.append(NasBench101Mutator(nodes[0].label))
|
||||
# The mutation of Nas-bench-101 is special, and has to be done lastly.
|
||||
mutators_final.append(NasBench101Mutator(label))
|
||||
else:
|
||||
mutators.append(ManyChooseManyMutator(nodes[0].label))
|
||||
mutators.append(ManyChooseManyMutator(label))
|
||||
return model, mutators + mutators_final
|
||||
|
||||
|
||||
|
@ -350,7 +359,7 @@ class EvaluatorValueChoiceLeafMutator(Mutator):
|
|||
super().__init__(label=label)
|
||||
self.candidates = candidates
|
||||
|
||||
def mutate(self, model: Model) -> Model:
|
||||
def mutate(self, model: Model) -> None:
|
||||
# leave a record here
|
||||
# real mutations will be done in ParameterChoiceMutator
|
||||
self.choice(self.candidates)
|
||||
|
@ -388,7 +397,7 @@ class EvaluatorValueChoiceMutator(Mutator):
|
|||
|
||||
return obj
|
||||
|
||||
def mutate(self, model: Model):
|
||||
def mutate(self, model: Model) -> None:
|
||||
value_choice_decisions = {}
|
||||
for mutation in model.history:
|
||||
if isinstance(mutation.mutator, EvaluatorValueChoiceLeafMutator):
|
||||
|
@ -454,7 +463,7 @@ def _is_all_equal(lst):
|
|||
return True
|
||||
|
||||
|
||||
def _group_by_label_and_type(nodes: List[Node]) -> List[List[Node]]:
|
||||
def _group_by_label_and_type(nodes: Iterable[Node]) -> List[List[Node]]:
|
||||
result = {}
|
||||
for node in nodes:
|
||||
key = (node.label, node.operation.type)
|
||||
|
@ -464,7 +473,7 @@ def _group_by_label_and_type(nodes: List[Node]) -> List[List[Node]]:
|
|||
return list(result.values())
|
||||
|
||||
|
||||
def _group_by_label(nodes: List[Node]) -> List[List[Node]]:
|
||||
def _group_by_label(nodes: Iterable[Node]) -> List[List[Node]]:
|
||||
result = {}
|
||||
for node in nodes:
|
||||
label = node.operation.parameters['label']
|
||||
|
|
|
@ -1,6 +1,9 @@
|
|||
# Copyright (c) Microsoft Corporation.
|
||||
# Licensed under the MIT license.
|
||||
|
||||
import logging
|
||||
from collections import OrderedDict
|
||||
from typing import Callable, List, Optional, Union, Dict
|
||||
from typing import Callable, List, Optional, Union, Dict, Tuple, cast
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
|
@ -89,7 +92,7 @@ def compute_vertex_channels(input_channels, output_channels, matrix):
|
|||
return vertex_channels
|
||||
|
||||
|
||||
def prune(matrix, ops):
|
||||
def prune(matrix, ops) -> Tuple[np.ndarray, List[Union[str, Callable[[int], nn.Module]]]]:
|
||||
"""
|
||||
Prune the extraneous parts of the graph.
|
||||
|
||||
|
@ -152,11 +155,17 @@ class _NasBench101CellFixed(nn.Module):
|
|||
|
||||
assert num_nodes == len(operations) + 2 == len(adjacency_list) + 1
|
||||
|
||||
self.operations = ['IN'] + operations + ['OUT'] # add psuedo nodes
|
||||
raw_operations: List[Union[str, Callable[[int], nn.Module]]] = list(operations)
|
||||
del operations # operations is no longer needed. Delete it to avoid misuse
|
||||
|
||||
# add psuedo nodes
|
||||
raw_operations.insert(0, 'IN')
|
||||
raw_operations.append('OUT')
|
||||
|
||||
self.connection_matrix = self.build_connection_matrix(adjacency_list, num_nodes)
|
||||
del num_nodes # raw number of nodes is no longer used
|
||||
|
||||
self.connection_matrix, self.operations = prune(self.connection_matrix, self.operations)
|
||||
self.connection_matrix, self.operations = prune(self.connection_matrix, raw_operations)
|
||||
|
||||
self.hidden_features = compute_vertex_channels(in_features, out_features, self.connection_matrix)
|
||||
|
||||
|
@ -172,7 +181,8 @@ class _NasBench101CellFixed(nn.Module):
|
|||
self.projections.append(projection(in_features, self.hidden_features[i]))
|
||||
|
||||
for i in range(1, self.num_nodes - 1):
|
||||
self.ops.append(operations[i - 1](self.hidden_features[i]))
|
||||
operation = cast(Callable[[int], nn.Module], self.operations[i])
|
||||
self.ops.append(operation(self.hidden_features[i]))
|
||||
|
||||
@staticmethod
|
||||
def build_connection_matrix(adjacency_list, num_nodes):
|
||||
|
@ -361,7 +371,7 @@ class NasBench101Mutator(Mutator):
|
|||
# for validation purposes
|
||||
# for python execution engine
|
||||
|
||||
def __init__(self, label: Optional[str]):
|
||||
def __init__(self, label: str):
|
||||
super().__init__(label=label)
|
||||
|
||||
@staticmethod
|
||||
|
@ -378,9 +388,11 @@ class NasBench101Mutator(Mutator):
|
|||
return 1
|
||||
|
||||
def mutate(self, model: Model):
|
||||
max_num_edges = cast(int, None)
|
||||
for node in model.get_nodes_by_label(self.label):
|
||||
max_num_edges = node.operation.parameters['max_num_edges']
|
||||
break
|
||||
assert max_num_edges is not None
|
||||
mutation_dict = {mut.mutator.label: mut.samples for mut in model.history}
|
||||
num_nodes = mutation_dict[f'{self.label}/num_nodes'][0]
|
||||
adjacency_list = [mutation_dict[f'{self.label}/input{i}'] for i in range(1, num_nodes)]
|
||||
|
|
|
@ -1,26 +1,42 @@
|
|||
# Copyright (c) Microsoft Corporation.
|
||||
# Licensed under the MIT license.
|
||||
|
||||
from pathlib import Path
|
||||
|
||||
# To make auto-completion happy, we generate a _nn.py that lists out all the classes.
|
||||
nn_cache_file_path = Path(__file__).parent / '_nn.py'
|
||||
|
||||
# Update this when cache format changes, to enforce an update.
|
||||
cache_version = 2
|
||||
|
||||
|
||||
def validate_cache() -> bool:
|
||||
import torch
|
||||
|
||||
cache_valid = []
|
||||
|
||||
if nn_cache_file_path.exists():
|
||||
lines = nn_cache_file_path.read_text().splitlines()
|
||||
for line in lines:
|
||||
if line.startswith('# _torch_version'):
|
||||
_cached_torch_version = line[line.find('=') + 1:].strip()
|
||||
if _cached_torch_version == torch.__version__:
|
||||
cache_valid.append(True)
|
||||
if line.startswith('# _torch_nn_cache_version'):
|
||||
_cached_cache_version = int(line[line.find('=') + 1:].strip())
|
||||
if _cached_cache_version == cache_version:
|
||||
cache_valid.append(True)
|
||||
|
||||
return len(cache_valid) >= 2 and all(cache_valid)
|
||||
|
||||
|
||||
def generate_stub_file() -> str:
|
||||
import inspect
|
||||
import warnings
|
||||
from pathlib import Path
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
|
||||
# To make auto-completion happy, we generate a _nn.py that lists out all the classes.
|
||||
|
||||
nn_cache_file_path = Path(__file__).parent / '_nn.py'
|
||||
|
||||
cache_valid = False
|
||||
|
||||
if nn_cache_file_path.exists():
|
||||
from . import _nn # pylint: disable=no-name-in-module
|
||||
# valid only when torch version match
|
||||
if _nn._torch_version == torch.__version__:
|
||||
cache_valid = True
|
||||
|
||||
if not cache_valid:
|
||||
_NO_WRAP_CLASSES = [
|
||||
# not an nn.Module
|
||||
'Parameter',
|
||||
|
@ -47,7 +63,10 @@ if not cache_valid:
|
|||
'# This file is auto-generated to make auto-completion work.',
|
||||
'# When pytorch version does not match, it will get automatically updated.',
|
||||
'# pylint: skip-file',
|
||||
f'_torch_version = "{torch.__version__}"',
|
||||
'# pyright: reportGeneralTypeIssues=false',
|
||||
f'# _torch_version = {torch.__version__}',
|
||||
f'# _torch_nn_cache_version = {cache_version}',
|
||||
'import typing',
|
||||
'import torch.nn as nn',
|
||||
'from nni.retiarii.serializer import basic_unit',
|
||||
]
|
||||
|
@ -66,10 +85,9 @@ if not cache_valid:
|
|||
'It means your PyTorch version might not be supported.', RuntimeWarning)
|
||||
code.append(f'{name} = nn.{name}')
|
||||
elif name in _WRAP_WITHOUT_TAG_CLASSES:
|
||||
code.append(f'{name} = basic_unit(nn.{name}, basic_unit_tag=False)')
|
||||
code.append(f'{name} = typing.cast(typing.Type[nn.{name}], basic_unit(nn.{name}, basic_unit_tag=False))')
|
||||
else:
|
||||
code.append(f'{name} = basic_unit(nn.{name})')
|
||||
|
||||
code.append(f'{name} = typing.cast(typing.Type[nn.{name}], basic_unit(nn.{name}))')
|
||||
all_names.append(name)
|
||||
|
||||
elif inspect.isfunction(obj) or inspect.ismodule(obj):
|
||||
|
@ -78,12 +96,19 @@ if not cache_valid:
|
|||
|
||||
code.append(f'__all__ = {all_names}')
|
||||
|
||||
return '\n'.join(code)
|
||||
|
||||
|
||||
def write_cache(code: str) -> None:
|
||||
with nn_cache_file_path.open('w') as fp:
|
||||
fp.write('\n'.join(code))
|
||||
fp.write(code)
|
||||
|
||||
|
||||
# Import all modules from generated _nn.py
|
||||
code = generate_stub_file()
|
||||
|
||||
from . import _nn # pylint: disable=no-name-in-module
|
||||
__all__ = _nn.__all__
|
||||
from ._nn import * # pylint: disable=import-error, wildcard-import
|
||||
if not validate_cache():
|
||||
write_cache(code)
|
||||
|
||||
del Path, validate_cache, write_cache, cache_version, nn_cache_file_path, code
|
||||
|
||||
from ._nn import * # pylint: disable=import-error, wildcard-import, unused-wildcard-import
|
||||
|
|
|
@ -20,7 +20,7 @@ from .supermodule.base import BaseSuperNetModule
|
|||
__all__ = ['MutationHook', 'BaseSuperNetModule', 'BaseOneShotLightningModule', 'traverse_and_mutate_submodules']
|
||||
|
||||
|
||||
MutationHook = Callable[[nn.Module, str, Dict[str, Any]], Union[nn.Module, bool, Tuple[nn.Module, bool]]]
|
||||
MutationHook = Callable[[nn.Module, str, Dict[str, Any], Dict[str, Any]], Union[nn.Module, bool, Tuple[nn.Module, bool]]]
|
||||
|
||||
|
||||
def traverse_and_mutate_submodules(
|
||||
|
@ -149,11 +149,12 @@ class BaseOneShotLightningModule(pl.LightningModule):
|
|||
|
||||
The hook list will be appended by ``default_mutation_hooks`` in each one-shot module.
|
||||
|
||||
To be more specific, the input arguments are three arguments:
|
||||
To be more specific, the input arguments are four arguments:
|
||||
|
||||
#. a module that might be processed,
|
||||
#. name of the module in its parent module,
|
||||
#. a memo dict whose usage depends on the particular algorithm.
|
||||
#. keyword arguments (configurations).
|
||||
|
||||
Note that the memo should be read/written by hooks.
|
||||
There won't be any hooks called on root module.
|
||||
|
|
|
@ -1,6 +1,8 @@
|
|||
# Copyright (c) Microsoft Corporation.
|
||||
# Licensed under the MIT license.
|
||||
|
||||
# type: ignore
|
||||
|
||||
import copy
|
||||
import logging
|
||||
from collections import OrderedDict
|
||||
|
|
|
@ -1,6 +1,8 @@
|
|||
# Copyright (c) Microsoft Corporation.
|
||||
# Licensed under the MIT license.
|
||||
|
||||
# type: ignore
|
||||
|
||||
import logging
|
||||
|
||||
import torch
|
||||
|
|
|
@ -27,7 +27,7 @@ which interprets the slice and apply it on a tensor.
|
|||
"""
|
||||
|
||||
import operator
|
||||
from typing import Tuple, Union, List, Dict, Callable, Optional, Iterator, TypeVar, Any, Generic
|
||||
from typing import Tuple, Union, List, Dict, Callable, Optional, Iterator, TypeVar, Any, Generic, cast
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
|
@ -128,9 +128,10 @@ class Slicable(Generic[T]):
|
|||
raise TypeError(f'Unsuppoted weight type: {type(weight)}')
|
||||
self.weight = weight
|
||||
|
||||
def __getitem__(self, index: multidim_slice) -> T:
|
||||
def __getitem__(self, index: Union[slice_type, multidim_slice]) -> T:
|
||||
if not isinstance(index, tuple):
|
||||
index = (index, )
|
||||
index = cast(multidim_slice, index)
|
||||
|
||||
# Get the dict value in index's leafs
|
||||
# There can be at most one dict
|
||||
|
|
|
@ -24,7 +24,7 @@ class BaseSuperNetModule(nn.Module):
|
|||
rather than their compositions.
|
||||
"""
|
||||
|
||||
def resample(self, memo: Dict[str, Any] = None) -> Dict[str, Any]:
|
||||
def resample(self, memo: Dict[str, Any]) -> Dict[str, Any]:
|
||||
"""
|
||||
Resample the super-net module.
|
||||
|
||||
|
@ -40,7 +40,7 @@ class BaseSuperNetModule(nn.Module):
|
|||
"""
|
||||
raise NotImplementedError()
|
||||
|
||||
def export(self, memo: Dict[str, Any] = None) -> Dict[str, Any]:
|
||||
def export(self, memo: Dict[str, Any]) -> Dict[str, Any]:
|
||||
"""
|
||||
Export the final architecture within this module.
|
||||
It should have the same keys as ``search_space_spec()``.
|
||||
|
|
|
@ -275,11 +275,11 @@ class MixedOpDifferentiablePolicy(MixedOperationSamplingPolicy):
|
|||
if not arch:
|
||||
yield name, p
|
||||
|
||||
def resample(self, operation: MixedOperation, memo: Dict[str, Any] = None) -> Dict[str, Any]:
|
||||
def resample(self, operation: MixedOperation, memo: Dict[str, Any]) -> Dict[str, Any]:
|
||||
"""Differentiable. Do nothing in resample."""
|
||||
return {}
|
||||
|
||||
def export(self, operation: MixedOperation, memo: Dict[str, Any] = None) -> Dict[str, Any]:
|
||||
def export(self, operation: MixedOperation, memo: Dict[str, Any]) -> Dict[str, Any]:
|
||||
"""Export is also random for each leaf value choice."""
|
||||
result = {}
|
||||
for name, spec in operation.search_space_spec().items():
|
||||
|
|
|
@ -8,11 +8,12 @@ which is commonly known as super-kernel (as in channel search), or weight entang
|
|||
|
||||
import inspect
|
||||
import itertools
|
||||
from typing import Union, Tuple, Dict, List, Any, Type, Optional, TypeVar
|
||||
from typing import Union, Tuple, Dict, List, Any, Type, Optional, TypeVar, cast
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
from torch import Tensor
|
||||
|
||||
import nni.retiarii.nn.pytorch as retiarii_nn
|
||||
from nni.common.hpo_utils import ParameterSpec
|
||||
|
@ -46,11 +47,11 @@ class MixedOperationSamplingPolicy:
|
|||
"""
|
||||
pass
|
||||
|
||||
def resample(self, operation: 'MixedOperation', memo: Dict[str, Any] = None) -> Dict[str, Any]:
|
||||
def resample(self, operation: 'MixedOperation', memo: Dict[str, Any]) -> Dict[str, Any]:
|
||||
"""The handler of :meth:`MixedOperation.resample`."""
|
||||
raise NotImplementedError()
|
||||
|
||||
def export(self, operation: 'MixedOperation', memo: Dict[str, Any] = None) -> Dict[str, Any]:
|
||||
def export(self, operation: 'MixedOperation', memo: Dict[str, Any]) -> Dict[str, Any]:
|
||||
"""The handler of :meth:`MixedOperation.export`."""
|
||||
raise NotImplementedError()
|
||||
|
||||
|
@ -513,43 +514,42 @@ class MixedMultiHeadAttention(MixedOperation, nn.MultiheadAttention):
|
|||
embed_dim = _W(embed_dim)
|
||||
|
||||
# in projection weights & biases has q, k, v weights concatenated together
|
||||
in_proj_bias = in_proj_weight = None
|
||||
in_proj_bias: Optional[Tensor] = None
|
||||
in_proj_weight: Optional[Tensor] = None
|
||||
if self.in_proj_bias is not None:
|
||||
in_proj_bias = _S(self.in_proj_bias)[self._to_proj_slice(embed_dim)]
|
||||
in_proj_bias = _S(cast(Tensor, self.in_proj_bias))[self._to_proj_slice(embed_dim)]
|
||||
if self.in_proj_weight is not None:
|
||||
in_proj_weight = _S(self.in_proj_weight)[self._to_proj_slice(embed_dim), :embed_dim]
|
||||
in_proj_weight = _S(cast(Tensor, self.in_proj_weight))[self._to_proj_slice(embed_dim), :embed_dim]
|
||||
|
||||
bias_k = _S(self.bias_k)[:, :, :embed_dim] if self.bias_k is not None else None
|
||||
bias_v = _S(self.bias_v)[:, :, :embed_dim] if self.bias_v is not None else None
|
||||
out_proj_weight = _S(self.out_proj.weight)[:embed_dim, :embed_dim]
|
||||
out_proj_bias = _S(self.out_proj.bias)[:embed_dim] if self.out_proj.bias is not None else None
|
||||
bias_k = _S(cast(Tensor, self.bias_k))[:, :, :embed_dim] if self.bias_k is not None else None
|
||||
bias_v = _S(cast(Tensor, self.bias_v))[:, :, :embed_dim] if self.bias_v is not None else None
|
||||
out_proj_weight = _S(cast(Tensor, self.out_proj.weight))[:embed_dim, :embed_dim]
|
||||
out_proj_bias = _S(cast(Tensor, self.out_proj.bias))[:embed_dim] if self.out_proj.bias is not None else None
|
||||
|
||||
if not qkv_same_embed_dim:
|
||||
kdim = _W(kdim)
|
||||
vdim = _W(vdim)
|
||||
|
||||
q_proj = _S(self.q_proj_weight)[:embed_dim, :embed_dim]
|
||||
k_proj = _S(self.k_proj_weight)[:embed_dim]
|
||||
k_proj = _S(k_proj)[:, :kdim]
|
||||
v_proj = _S(self.v_proj_weight)[:embed_dim]
|
||||
v_proj = _S(v_proj)[:, :vdim]
|
||||
q_proj = _S(cast(Tensor, self.q_proj_weight))[:embed_dim, :embed_dim]
|
||||
k_proj = _S(cast(Tensor, self.k_proj_weight))[:embed_dim]
|
||||
k_proj = _S(k_proj)[:, :_W(kdim)]
|
||||
v_proj = _S(cast(Tensor, self.v_proj_weight))[:embed_dim]
|
||||
v_proj = _S(v_proj)[:, :_W(vdim)]
|
||||
|
||||
# The rest part is basically same as pytorch
|
||||
attn_output, attn_output_weights = F.multi_head_attention_forward(
|
||||
query, key, value, used_embed_dim, num_heads,
|
||||
in_proj_weight, in_proj_bias,
|
||||
cast(Tensor, in_proj_weight), cast(Tensor, in_proj_bias),
|
||||
bias_k, bias_v, self.add_zero_attn,
|
||||
dropout, out_proj_weight, out_proj_bias,
|
||||
dropout, out_proj_weight, cast(Tensor, out_proj_bias),
|
||||
training=self.training,
|
||||
key_padding_mask=key_padding_mask, need_weights=need_weights,
|
||||
attn_mask=attn_mask, use_separate_proj_weight=True,
|
||||
q_proj_weight=q_proj, k_proj_weight=k_proj, v_proj_weight=v_proj)
|
||||
else:
|
||||
# Cast tensor here because of a bug in pytorch stub
|
||||
attn_output, attn_output_weights = F.multi_head_attention_forward(
|
||||
query, key, value, used_embed_dim, num_heads,
|
||||
in_proj_weight, in_proj_bias,
|
||||
cast(Tensor, in_proj_weight), cast(Tensor, in_proj_bias),
|
||||
bias_k, bias_v, self.add_zero_attn,
|
||||
dropout, out_proj_weight, out_proj_bias,
|
||||
dropout, out_proj_weight, cast(Tensor, out_proj_bias),
|
||||
training=self.training,
|
||||
key_padding_mask=key_padding_mask, need_weights=need_weights,
|
||||
attn_mask=attn_mask)
|
||||
|
|
|
@ -9,7 +9,7 @@ The support remains limited. Known limitations include:
|
|||
- The code contains duplicates. Needs refactor.
|
||||
"""
|
||||
|
||||
from typing import List, Tuple, Optional
|
||||
from typing import List, Tuple, Optional, cast
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
|
@ -94,7 +94,7 @@ class ProxylessMixedLayer(DifferentiableMixedLayer):
|
|||
self._sample_idx = self.op_names.index(self._sampled)
|
||||
else:
|
||||
probs = self._softmax(self._arch_alpha)
|
||||
self._sample_idx = torch.multinomial(probs, 1)[0].item()
|
||||
self._sample_idx = int(torch.multinomial(probs, 1)[0].item())
|
||||
self._sampled = self.op_names[self._sample_idx]
|
||||
|
||||
# set binary gates
|
||||
|
@ -109,10 +109,11 @@ class ProxylessMixedLayer(DifferentiableMixedLayer):
|
|||
"""Chose the argmax if label isn't found in memo."""
|
||||
if self.label in memo:
|
||||
return {} # nothing new to export
|
||||
return {self.label: self.op_names[torch.argmax(self._arch_alpha).item()]}
|
||||
return {self.label: self.op_names[int(torch.argmax(self._arch_alpha).item())]}
|
||||
|
||||
def finalize_grad(self):
|
||||
binary_grads = self._binary_gates.grad
|
||||
assert binary_grads is not None
|
||||
with torch.no_grad():
|
||||
if self._arch_alpha.grad is None:
|
||||
self._arch_alpha.grad = torch.zeros_like(self._arch_alpha.data)
|
||||
|
@ -164,13 +165,13 @@ class ProxylessMixedInput(DifferentiableMixedInput):
|
|||
else:
|
||||
probs = self._softmax(self._arch_alpha)
|
||||
sample = torch.multinomial(probs, 1)[0].item()
|
||||
self._sampled = sample
|
||||
self._sampled = int(sample)
|
||||
|
||||
# set binary gates
|
||||
with torch.no_grad():
|
||||
self._binary_gates.zero_()
|
||||
self._binary_gates.grad = torch.zeros_like(self._binary_gates.data)
|
||||
self._binary_gates.data[sample] = 1.0
|
||||
self._binary_gates.data[cast(int, self._sampled)] = 1.0
|
||||
|
||||
return {self.label: self._sampled}
|
||||
|
||||
|
@ -182,6 +183,7 @@ class ProxylessMixedInput(DifferentiableMixedInput):
|
|||
|
||||
def finalize_grad(self):
|
||||
binary_grads = self._binary_gates.grad
|
||||
assert binary_grads is not None
|
||||
with torch.no_grad():
|
||||
if self._arch_alpha.grad is None:
|
||||
self._arch_alpha.grad = torch.zeros_like(self._arch_alpha.data)
|
||||
|
|
|
@ -129,6 +129,8 @@ class PathSamplingInput(BaseSuperNetModule):
|
|||
if isinstance(module, InputChoice):
|
||||
if module.reduction not in ['sum', 'mean', 'concat']:
|
||||
raise ValueError('Only input choice of sum/mean/concat reduction is supported.')
|
||||
if module.n_chosen is None:
|
||||
raise ValueError('n_chosen is None is not supported yet.')
|
||||
return cls(module.n_candidates, module.n_chosen, module.reduction, module.label)
|
||||
|
||||
def forward(self, input_tensors):
|
||||
|
@ -161,7 +163,7 @@ class MixedOpPathSamplingPolicy(MixedOperationSamplingPolicy):
|
|||
# Sampling arguments. This should have the same keys with `operation.mutable_arguments`
|
||||
self._sampled: Optional[Dict[str, Any]] = None
|
||||
|
||||
def resample(self, operation: MixedOperation, memo: Dict[str, Any] = None) -> Dict[str, Any]:
|
||||
def resample(self, operation: MixedOperation, memo: Dict[str, Any]) -> Dict[str, Any]:
|
||||
"""Random sample for each leaf value choice."""
|
||||
result = {}
|
||||
space_spec = operation.search_space_spec()
|
||||
|
@ -179,7 +181,7 @@ class MixedOpPathSamplingPolicy(MixedOperationSamplingPolicy):
|
|||
|
||||
return result
|
||||
|
||||
def export(self, operation: MixedOperation, memo: Dict[str, Any] = None) -> Dict[str, Any]:
|
||||
def export(self, operation: MixedOperation, memo: Dict[str, Any]) -> Dict[str, Any]:
|
||||
"""Export is also random for each leaf value choice."""
|
||||
result = {}
|
||||
space_spec = operation.search_space_spec()
|
||||
|
|
|
@ -132,7 +132,7 @@ def _replace_module_with_type(root_module, init_fn, type_name, modules):
|
|||
for name, child in m.named_children():
|
||||
if isinstance(child, type_name):
|
||||
setattr(m, name, init_fn(child))
|
||||
modules.append((child.key, getattr(m, name)))
|
||||
modules.append((child.label, getattr(m, name)))
|
||||
else:
|
||||
apply(child)
|
||||
|
||||
|
|
|
@ -1,7 +1,7 @@
|
|||
# Copyright (c) Microsoft Corporation.
|
||||
# Licensed under the MIT license.
|
||||
|
||||
from typing import (Any, Dict, List)
|
||||
from typing import (Any, Dict, List, Optional, cast)
|
||||
|
||||
from . import debug_configs
|
||||
|
||||
|
@ -34,6 +34,8 @@ class Operation:
|
|||
Arbitrary key-value parameters (e.g. kernel_size).
|
||||
"""
|
||||
|
||||
io_names: List[str] = []
|
||||
|
||||
def __init__(self, type_name: str, parameters: Dict[str, Any] = {}, _internal: bool = False, attributes: Dict[str, Any] = {}):
|
||||
assert _internal, '`Operation()` is private, use `Operation.new()` instead'
|
||||
self.type: str = type_name
|
||||
|
@ -43,7 +45,7 @@ class Operation:
|
|||
def to_init_code(self, field: str) -> str:
|
||||
raise NotImplementedError()
|
||||
|
||||
def to_forward_code(self, field: str, output: str, inputs: List[str]) -> str:
|
||||
def to_forward_code(self, field: str, output: str, inputs: List[str], inputs_value: List[Any]) -> str:
|
||||
raise NotImplementedError()
|
||||
|
||||
def _to_class_name(self) -> str:
|
||||
|
@ -53,8 +55,8 @@ class Operation:
|
|||
return True
|
||||
|
||||
@staticmethod
|
||||
def new(type_name: str, parameters: Dict[str, Any] = None, cell_name: str = None,
|
||||
attributes: Dict[str, Any] = None) -> 'Operation':
|
||||
def new(type_name: str, parameters: Dict[str, Any] = cast(Dict[str, Any], None), cell_name: str = cast(str, None),
|
||||
attributes: Dict[str, Any] = cast(Dict[str, Any], None)) -> 'Operation':
|
||||
parameters = parameters or {}
|
||||
attributes = attributes or {}
|
||||
if type_name == '_cell':
|
||||
|
@ -98,16 +100,16 @@ class PyTorchOperation(Operation):
|
|||
subclass_name = 'FunctionalOperator'
|
||||
for subclass in cls.__subclasses__():
|
||||
if hasattr(subclass, '_ori_type_name') and \
|
||||
subclass_name in subclass._ori_type_name:
|
||||
subclass_name in cast(Any, subclass)._ori_type_name:
|
||||
return subclass
|
||||
for subclass in cls.__subclasses__():
|
||||
if hasattr(subclass, '_artificial_op_name') and \
|
||||
subclass_name in subclass._artificial_op_name:
|
||||
subclass_name in cast(Any, subclass)._artificial_op_name:
|
||||
return subclass
|
||||
return cls
|
||||
|
||||
@classmethod
|
||||
def to_class_name(cls, type_name) -> str:
|
||||
def to_class_name(cls, type_name) -> Optional[str]:
|
||||
if type_name.startswith('__torch__.'):
|
||||
return type_name[len('__torch__.'):]
|
||||
elif type_name.startswith('__mutated__.'):
|
||||
|
@ -119,7 +121,7 @@ class PyTorchOperation(Operation):
|
|||
def is_functional(cls, type_name) -> bool:
|
||||
return type_name.startswith('Function.')
|
||||
|
||||
def _to_class_name(self) -> str:
|
||||
def _to_class_name(self) -> Optional[str]:
|
||||
if self.type.startswith('__torch__.'):
|
||||
return self.type[len('__torch__.'):]
|
||||
elif self.type.startswith('__mutated__.'):
|
||||
|
@ -127,7 +129,7 @@ class PyTorchOperation(Operation):
|
|||
else:
|
||||
return None
|
||||
|
||||
def get_import_pkg(self) -> str:
|
||||
def get_import_pkg(self) -> Optional[str]:
|
||||
if self.type.startswith('__torch__.'):
|
||||
return self.type[len('__torch__.'):].split('.')[0]
|
||||
elif self.type.startswith('__mutated__.'):
|
||||
|
@ -135,14 +137,14 @@ class PyTorchOperation(Operation):
|
|||
else:
|
||||
return None
|
||||
|
||||
def to_init_code(self, field: str) -> str:
|
||||
def to_init_code(self, field: str) -> Optional[str]:
|
||||
if self._to_class_name() is not None:
|
||||
assert 'positional_args' not in self.parameters
|
||||
kw_params = ', '.join(f'{key}={repr(value)}' for key, value in self.parameters.items())
|
||||
return f'self.{field} = {self._to_class_name()}({kw_params})'
|
||||
return None
|
||||
|
||||
def to_forward_code(self, field: str, output: str, inputs: List[str], inputs_value: List[Any] = None) -> str:
|
||||
def to_forward_code(self, field: str, output: str, inputs: List[str], inputs_value: List[Any]) -> str:
|
||||
"""
|
||||
Parameters
|
||||
----------
|
||||
|
@ -207,7 +209,9 @@ class Cell(PyTorchOperation):
|
|||
No real usage. Exists for compatibility with base class.
|
||||
"""
|
||||
|
||||
def __init__(self, cell_name: str, parameters: Dict[str, Any] = None, attributes: Dict[str, Any] = None):
|
||||
def __init__(self, cell_name: str,
|
||||
parameters: Dict[str, Any] = cast(Dict[str, Any], None),
|
||||
attributes: Dict[str, Any] = cast(Dict[str, Any], None)):
|
||||
self.type = '_cell'
|
||||
self.cell_name = cell_name
|
||||
self.parameters = parameters or {}
|
||||
|
@ -217,7 +221,7 @@ class Cell(PyTorchOperation):
|
|||
# TODO: ugly, think about how to refactor this part
|
||||
return _convert_name(self.cell_name)
|
||||
|
||||
def to_forward_code(self, field: str, output: str, inputs: List[str], inputs_value: List[Any] = None) -> str:
|
||||
def to_forward_code(self, field: str, output: str, inputs: List[str], inputs_value: List[Any]) -> str:
|
||||
return f'{output} = self.{field}({", ".join(inputs)})'
|
||||
|
||||
class _IOPseudoOperation(Operation):
|
||||
|
@ -227,7 +231,7 @@ class _IOPseudoOperation(Operation):
|
|||
especially in static type checking.
|
||||
"""
|
||||
|
||||
def __init__(self, type_name: str, io_names: List = None):
|
||||
def __init__(self, type_name: str, io_names: List[str] = cast(List[str], None)):
|
||||
assert type_name.startswith('_')
|
||||
super(_IOPseudoOperation, self).__init__(type_name, {}, True)
|
||||
self.io_names = io_names
|
||||
|
@ -235,7 +239,7 @@ class _IOPseudoOperation(Operation):
|
|||
def to_init_code(self, field: str) -> str:
|
||||
raise ValueError(f'Cannot generate code for pseudo operation "{self.type}"')
|
||||
|
||||
def to_forward_code(self, field: str, output: str, inputs: List[str]) -> str:
|
||||
def to_forward_code(self, field: str, output: str, inputs: List[str], inputs_value: List[Any]) -> str:
|
||||
raise ValueError(f'Cannot generate code for pseudo operation "{self.type}"')
|
||||
|
||||
def __bool__(self) -> bool:
|
||||
|
|
|
@ -1,3 +1,6 @@
|
|||
# Copyright (c) Microsoft Corporation.
|
||||
# Licensed under the MIT license.
|
||||
|
||||
"""
|
||||
Definition of operation types.
|
||||
|
||||
|
|
|
@ -1,9 +1,12 @@
|
|||
# Copyright (c) Microsoft Corporation.
|
||||
# Licensed under the MIT license.
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import (Any, Dict, List)
|
||||
|
||||
import torch
|
||||
import torch.nn.functional as nn_functional
|
||||
|
||||
from ..operation import PyTorchOperation
|
||||
|
||||
|
@ -39,23 +42,23 @@ class NoOpIdentity(PyTorchOperation):
|
|||
"""
|
||||
_ori_type_name = ['noop_identity']
|
||||
|
||||
def to_forward_code(self, field: str, output: str, inputs: List[str], inputs_value: List[Any] = None) -> str:
|
||||
def to_forward_code(self, field: str, output: str, inputs: List[str], inputs_value: List[Any]) -> str:
|
||||
return f'{output} = {", ".join(inputs)}'
|
||||
|
||||
|
||||
class ModuleOperator(PyTorchOperation):
|
||||
_ori_type_name = ['ModuleOperator', 'shared']
|
||||
|
||||
def to_forward_code(self, field: str, output: str, inputs: List[str], inputs_value: List[Any] = None) -> str:
|
||||
def to_forward_code(self, field: str, output: str, inputs: List[str], inputs_value: List[Any]) -> str:
|
||||
return f'{output} = self.{field}({", ".join(inputs)})'
|
||||
|
||||
|
||||
class FunctionalOperator(PyTorchOperation):
|
||||
_ori_type_name = ['FunctionalOperator']
|
||||
|
||||
def to_forward_code(self, field: str, output: str, inputs: List[str], inputs_value: List[Any] = None) -> str:
|
||||
def to_forward_code(self, field: str, output: str, inputs: List[str], inputs_value: List[Any]) -> str:
|
||||
func_name = self.type[len('Function.'):]
|
||||
if not hasattr(torch.nn.functional, func_name):
|
||||
if not hasattr(nn_functional, func_name):
|
||||
raise RuntimeError('For now, we only support calling independent functions from `torch.nn.functional`, '
|
||||
f'{func_name} is not in it.')
|
||||
return f'{output} = F.{func_name}({", ".join(inputs)})'
|
||||
|
@ -64,7 +67,7 @@ class FunctionalOperator(PyTorchOperation):
|
|||
class PrimConstant(PyTorchOperation):
|
||||
_ori_type_name = ['prim::Constant']
|
||||
|
||||
def to_forward_code(self, field: str, output: str, inputs: List[str], inputs_value: List[Any] = None) -> str:
|
||||
def to_forward_code(self, field: str, output: str, inputs: List[str], inputs_value: List[Any]) -> str:
|
||||
# TODO: refactor this part, maybe we can remove the code gen of prim::Constant
|
||||
# TODO: deal with all the types
|
||||
if self.parameters['type'] in ['None', 'NoneType']:
|
||||
|
@ -87,28 +90,28 @@ class PrimConstant(PyTorchOperation):
|
|||
class PrimListConstruct(PyTorchOperation):
|
||||
_ori_type_name = ['prim::ListConstruct']
|
||||
|
||||
def to_forward_code(self, field: str, output: str, inputs: List[str], inputs_value: List[Any] = None) -> str:
|
||||
def to_forward_code(self, field: str, output: str, inputs: List[str], inputs_value: List[Any]) -> str:
|
||||
return f'{output} = [{", ".join(inputs)}]'
|
||||
|
||||
|
||||
class PrimListUnpack(PyTorchOperation):
|
||||
_ori_type_name = ['prim::ListUnpack']
|
||||
|
||||
def to_forward_code(self, field: str, output: str, inputs: List[str], inputs_value: List[Any] = None) -> str:
|
||||
def to_forward_code(self, field: str, output: str, inputs: List[str], inputs_value: List[Any]) -> str:
|
||||
return f'{output} = {inputs[0]}'
|
||||
|
||||
|
||||
class PrimTupleConstruct(PyTorchOperation):
|
||||
_ori_type_name = ['prim::TupleConstruct']
|
||||
|
||||
def to_forward_code(self, field: str, output: str, inputs: List[str], inputs_value: List[Any] = None) -> str:
|
||||
def to_forward_code(self, field: str, output: str, inputs: List[str], inputs_value: List[Any]) -> str:
|
||||
return f'{output} = ({", ".join(inputs)})'
|
||||
|
||||
|
||||
class PrimTupleUnpack(PyTorchOperation):
|
||||
_ori_type_name = ['prim::TupleUnpack']
|
||||
|
||||
def to_forward_code(self, field: str, output: str, inputs: List[str], inputs_value: List[Any] = None) -> str:
|
||||
def to_forward_code(self, field: str, output: str, inputs: List[str], inputs_value: List[Any]) -> str:
|
||||
# have single output here, because the following code uses index to access the unpacked values
|
||||
assert len(inputs) == 1
|
||||
return f'{output} = {inputs[0]}'
|
||||
|
@ -117,7 +120,7 @@ class PrimTupleUnpack(PyTorchOperation):
|
|||
class PrimGetAttr(PyTorchOperation):
|
||||
_ori_type_name = ['prim::GetAttr']
|
||||
|
||||
def to_forward_code(self, field: str, output: str, inputs: List[str], inputs_value: List[Any] = None) -> str:
|
||||
def to_forward_code(self, field: str, output: str, inputs: List[str], inputs_value: List[Any]) -> str:
|
||||
if self.parameters['value'] is not None:
|
||||
return f"{output} = {self.parameters['value']}"
|
||||
else:
|
||||
|
@ -127,14 +130,14 @@ class PrimGetAttr(PyTorchOperation):
|
|||
class PrimUncheckedCast(PyTorchOperation):
|
||||
_ori_type_name = ['prim::unchecked_cast']
|
||||
|
||||
def to_forward_code(self, field: str, output: str, inputs: List[str], inputs_value: List[Any] = None) -> str:
|
||||
def to_forward_code(self, field: str, output: str, inputs: List[str], inputs_value: List[Any]) -> str:
|
||||
return f'{output} = {inputs[0]}'
|
||||
|
||||
|
||||
class SimpleMember(PyTorchOperation):
|
||||
_ori_type_name = ['prim::is_cuda', 'prim::data']
|
||||
|
||||
def to_forward_code(self, field: str, output: str, inputs: List[str], inputs_value: List[Any] = None) -> str:
|
||||
def to_forward_code(self, field: str, output: str, inputs: List[str], inputs_value: List[Any]) -> str:
|
||||
member_name = self.type.split('::')[-1]
|
||||
return f'{output} = {inputs[0]}.{member_name}'
|
||||
|
||||
|
@ -142,16 +145,16 @@ class SimpleMember(PyTorchOperation):
|
|||
class AtenContiguous(PyTorchOperation):
|
||||
_ori_type_name = ['aten::contiguous']
|
||||
|
||||
def to_forward_code(self, field: str, output: str, inputs: List[str], inputs_value: List[Any] = None) -> str:
|
||||
def to_forward_code(self, field: str, output: str, inputs: List[str], inputs_value: List[Any]) -> str:
|
||||
# defined in pytorch/c10/core/MemoryFormat.h
|
||||
assert inputs_value[1] in [0, 1, 2]
|
||||
assert inputs_value is not None and inputs_value[1] in [0, 1, 2]
|
||||
return f'{output} = {inputs[0]}.contiguous(memory_format={mem_format[inputs_value[1]]})'
|
||||
|
||||
|
||||
class AtenGetitem(PyTorchOperation):
|
||||
_ori_type_name = ['aten::__getitem__']
|
||||
|
||||
def to_forward_code(self, field: str, output: str, inputs: List[str], inputs_value: List[Any] = None) -> str:
|
||||
def to_forward_code(self, field: str, output: str, inputs: List[str], inputs_value: List[Any]) -> str:
|
||||
assert len(inputs) == 2
|
||||
return f'{output} = {inputs[0]}[{inputs[1]}]'
|
||||
|
||||
|
@ -159,7 +162,7 @@ class AtenGetitem(PyTorchOperation):
|
|||
class AtenAppend(PyTorchOperation):
|
||||
_ori_type_name = ['aten::append']
|
||||
|
||||
def to_forward_code(self, field: str, output: str, inputs: List[str], inputs_value: List[Any] = None) -> str:
|
||||
def to_forward_code(self, field: str, output: str, inputs: List[str], inputs_value: List[Any]) -> str:
|
||||
assert len(inputs) == 2
|
||||
return f'_, {output} = {inputs[0]}.append({inputs[1]}), {inputs[0]}'
|
||||
|
||||
|
@ -167,7 +170,7 @@ class AtenAppend(PyTorchOperation):
|
|||
class MergedSlice(PyTorchOperation):
|
||||
_ori_type_name = ['MergedSlice']
|
||||
|
||||
def to_forward_code(self, field: str, output: str, inputs: List[str], inputs_value: List[Any] = None) -> str:
|
||||
def to_forward_code(self, field: str, output: str, inputs: List[str], inputs_value: List[Any]) -> str:
|
||||
if (len(inputs) - 1) % 4 == 0:
|
||||
slices = []
|
||||
dim = int((len(inputs) - 1) / 4)
|
||||
|
@ -187,21 +190,21 @@ class MergedSlice(PyTorchOperation):
|
|||
class AtenBool(PyTorchOperation):
|
||||
_ori_type_name = ['aten::Bool']
|
||||
|
||||
def to_forward_code(self, field: str, output: str, inputs: List[str], inputs_value: List[Any] = None) -> str:
|
||||
def to_forward_code(self, field: str, output: str, inputs: List[str], inputs_value: List[Any]) -> str:
|
||||
return f'{output} = bool({inputs[0]})'
|
||||
|
||||
|
||||
class AtenNot(PyTorchOperation):
|
||||
_ori_type_name = ['aten::__not__']
|
||||
|
||||
def to_forward_code(self, field: str, output: str, inputs: List[str], inputs_value: List[Any] = None) -> str:
|
||||
def to_forward_code(self, field: str, output: str, inputs: List[str], inputs_value: List[Any]) -> str:
|
||||
return f'{output} = not {inputs[0]}'
|
||||
|
||||
|
||||
class AtenCat(PyTorchOperation):
|
||||
_ori_type_name = ['aten::cat']
|
||||
|
||||
def to_forward_code(self, field: str, output: str, inputs: List[str], inputs_value: List[Any] = None) -> str:
|
||||
def to_forward_code(self, field: str, output: str, inputs: List[str], inputs_value: List[Any]) -> str:
|
||||
assert len(inputs) == 2
|
||||
return f'{output} = torch.cat({inputs[0]}, dim={inputs[1]})'
|
||||
|
||||
|
@ -215,7 +218,7 @@ class AtenTensors(PyTorchOperation):
|
|||
'aten::new_empty', 'aten::new_zeros', 'aten::arange',
|
||||
'aten::tensor', 'aten::ones', 'aten::zeros', 'aten::as_tensor']
|
||||
|
||||
def to_forward_code(self, field: str, output: str, inputs: List[str], inputs_value: List[Any] = None) -> str:
|
||||
def to_forward_code(self, field: str, output: str, inputs: List[str], inputs_value: List[Any]) -> str:
|
||||
schemas = torch._C._jit_get_schemas_for_operator(self.type)
|
||||
# match number of inputs
|
||||
overloaded_defs = [len(s.arguments) for s in schemas]
|
||||
|
@ -257,40 +260,41 @@ class AtenTensors(PyTorchOperation):
|
|||
class AtenFloordiv(PyTorchOperation):
|
||||
_ori_type_name = ['aten::floordiv']
|
||||
|
||||
def to_forward_code(self, field: str, output: str, inputs: List[str], inputs_value: List[Any] = None) -> str:
|
||||
def to_forward_code(self, field: str, output: str, inputs: List[str], inputs_value: List[Any]) -> str:
|
||||
return f'{output} = {inputs[0]} // {inputs[1]}'
|
||||
|
||||
|
||||
class AtenMul(PyTorchOperation):
|
||||
_ori_type_name = ['aten::mul']
|
||||
|
||||
def to_forward_code(self, field: str, output: str, inputs: List[str], inputs_value: List[Any] = None) -> str:
|
||||
def to_forward_code(self, field: str, output: str, inputs: List[str], inputs_value: List[Any]) -> str:
|
||||
return f'{output} = {inputs[0]} * {inputs[1]}'
|
||||
|
||||
|
||||
class AtenLen(PyTorchOperation):
|
||||
_ori_type_name = ['aten::len']
|
||||
|
||||
def to_forward_code(self, field: str, output: str, inputs: List[str], inputs_value: List[Any] = None) -> str:
|
||||
def to_forward_code(self, field: str, output: str, inputs: List[str], inputs_value: List[Any]) -> str:
|
||||
return f'{output} = len({inputs[0]})'
|
||||
|
||||
|
||||
class AtenIntImplicit(PyTorchOperation):
|
||||
_ori_type_name = ['aten::IntImplicit', 'aten::Float', 'aten::Int', 'aten::ScalarImplicit']
|
||||
|
||||
def to_forward_code(self, field: str, output: str, inputs: List[str], inputs_value: List[Any] = None) -> str:
|
||||
def to_forward_code(self, field: str, output: str, inputs: List[str], inputs_value: List[Any]) -> str:
|
||||
if self.type.endswith('Implicit'):
|
||||
return f'{output} = {inputs[0]}'
|
||||
elif self.type == 'aten::Int':
|
||||
return f'{output} = int({inputs[0]})'
|
||||
elif self.type == 'aten::Float':
|
||||
return f'{output} = float({inputs[0]})'
|
||||
raise TypeError(f'Unexpected type: {self.type}')
|
||||
|
||||
|
||||
class AtenIndex(PyTorchOperation):
|
||||
_ori_type_name = ['aten::index']
|
||||
|
||||
def to_forward_code(self, field: str, output: str, inputs: List[str], inputs_value: List[Any] = None) -> str:
|
||||
def to_forward_code(self, field: str, output: str, inputs: List[str], inputs_value: List[Any]) -> str:
|
||||
return f'{output} = {inputs[0]}[{inputs[1]}]'
|
||||
|
||||
|
||||
|
@ -355,13 +359,13 @@ def _get_tensor_ops():
|
|||
|
||||
def _get_torch_ops():
|
||||
torch_op_args = {}
|
||||
for mod in torch.jit._builtins._modules_containing_builtins:
|
||||
for mod in torch.jit._builtins._modules_containing_builtins: # type: ignore
|
||||
name = mod.__name__
|
||||
if name == 'torch._C._nn':
|
||||
continue
|
||||
# only process 'torch.XXX'
|
||||
for elem in dir(mod):
|
||||
builtin = torch.jit._builtins._find_builtin(getattr(mod, elem))
|
||||
builtin = torch.jit._builtins._find_builtin(getattr(mod, elem)) # type: ignore
|
||||
if builtin is not None:
|
||||
schemas = torch._C._jit_get_schemas_for_operator(builtin)
|
||||
for schema in schemas:
|
||||
|
@ -436,7 +440,7 @@ class TensorOps(PyTorchOperation):
|
|||
return None
|
||||
raise RuntimeError(f'tensor op type {_type} has no matched')
|
||||
|
||||
def to_forward_code(self, field: str, output: str, inputs: List[str], inputs_value: List[Any] = None) -> str:
|
||||
def to_forward_code(self, field: str, output: str, inputs: List[str], inputs_value: List[Any]) -> str:
|
||||
# TODO: deal with conditional ops
|
||||
if self.type in TensorOps.comparison_ops:
|
||||
return f'{output} = ({inputs[0]} {TensorOps.comparison_ops[self.type]} {inputs[1]})'
|
||||
|
@ -486,7 +490,7 @@ class TorchOps(PyTorchOperation):
|
|||
else:
|
||||
raise RuntimeError(f'torch op type {_type} has no matched')
|
||||
|
||||
def to_forward_code(self, field: str, output: str, inputs: List[str], inputs_value: List[Any] = None) -> str:
|
||||
def to_forward_code(self, field: str, output: str, inputs: List[str], inputs_value: List[Any]) -> str:
|
||||
matched_args = TorchOps._get_matched_args(self.type, inputs)
|
||||
op_name = self.type.split('::')[-1]
|
||||
args_str = ', '.join([f'{name}={inputs[i]}' if t.startswith('Optional[') else f'{inputs[i]}'
|
||||
|
@ -498,7 +502,7 @@ class AtenAvgpool2d(PyTorchOperation):
|
|||
# NOTE: it is not included in the above aten ops for unkown reason
|
||||
_ori_type_name = ['aten::avg_pool2d']
|
||||
|
||||
def to_forward_code(self, field: str, output: str, inputs: List[str], inputs_value: List[Any] = None) -> str:
|
||||
def to_forward_code(self, field: str, output: str, inputs: List[str], inputs_value: List[Any]) -> str:
|
||||
return f'{output} = F.avg_pool2d({", ".join(inputs)})'
|
||||
|
||||
|
||||
|
@ -506,7 +510,7 @@ class ToDevice(PyTorchOperation):
|
|||
_artificial_op_name = "ToDevice"
|
||||
|
||||
def __init__(self, type_name: str, parameters: Dict[str, Any], _internal: bool = False,
|
||||
attributes: Dict[str, Any] = None):
|
||||
attributes: Dict[str, Any] = {}):
|
||||
self.type = "ToDevice"
|
||||
self.device = parameters['device']
|
||||
self.overridden_device_repr = None
|
||||
|
@ -540,5 +544,5 @@ class AtenDet(PyTorchOperation):
|
|||
# NOTE: it is not included in the above aten ops, maybe because torch.det is alias for torch.linalg.det
|
||||
_ori_type_name = ['aten::linalg_det']
|
||||
|
||||
def to_forward_code(self, field: str, output: str, inputs: List[str], inputs_value: List[Any] = None) -> str:
|
||||
def to_forward_code(self, field: str, output: str, inputs: List[str], inputs_value: List[Any]) -> str:
|
||||
return f'{output} = torch.det({inputs[0]})'
|
||||
|
|
|
@ -4,9 +4,9 @@
|
|||
import inspect
|
||||
import os
|
||||
import warnings
|
||||
from typing import Any, TypeVar, Union
|
||||
from typing import Any, TypeVar, Type
|
||||
|
||||
from nni.common.serializer import Traceable, is_traceable, is_wrapped_with_trace, trace, _copy_class_wrapper_attributes
|
||||
from nni.common.serializer import is_traceable, is_wrapped_with_trace, trace, _copy_class_wrapper_attributes
|
||||
from .utils import ModelNamespace
|
||||
|
||||
__all__ = ['get_init_parameters_or_fail', 'serialize', 'serialize_cls', 'basic_unit', 'model_wrapper',
|
||||
|
@ -48,7 +48,7 @@ def serialize_cls(cls):
|
|||
return trace(cls)
|
||||
|
||||
|
||||
def basic_unit(cls: T, basic_unit_tag: bool = True) -> Union[T, Traceable]:
|
||||
def basic_unit(cls: T, basic_unit_tag: bool = True) -> T:
|
||||
"""
|
||||
To wrap a module as a basic unit, is to make it a primitive and stop the engine from digging deeper into it.
|
||||
|
||||
|
@ -75,17 +75,17 @@ def basic_unit(cls: T, basic_unit_tag: bool = True) -> Union[T, Traceable]:
|
|||
return cls
|
||||
|
||||
import torch.nn as nn
|
||||
assert issubclass(cls, nn.Module), 'When using @basic_unit, the class must be a subclass of nn.Module.'
|
||||
assert issubclass(cls, nn.Module), 'When using @basic_unit, the class must be a subclass of nn.Module.' # type: ignore
|
||||
|
||||
cls = trace(cls)
|
||||
cls._nni_basic_unit = basic_unit_tag
|
||||
cls._nni_basic_unit = basic_unit_tag # type: ignore
|
||||
|
||||
_torchscript_patch(cls)
|
||||
|
||||
return cls
|
||||
|
||||
|
||||
def model_wrapper(cls: T) -> Union[T, Traceable]:
|
||||
def model_wrapper(cls: T) -> T:
|
||||
"""
|
||||
Wrap the base model (search space). For example,
|
||||
|
||||
|
@ -113,7 +113,7 @@ def model_wrapper(cls: T) -> Union[T, Traceable]:
|
|||
return cls
|
||||
|
||||
import torch.nn as nn
|
||||
assert issubclass(cls, nn.Module)
|
||||
assert issubclass(cls, nn.Module) # type: ignore
|
||||
|
||||
# subclass can still use trace info
|
||||
wrapper = trace(cls, inheritable=True)
|
||||
|
@ -146,7 +146,7 @@ def is_model_wrapped(cls_or_instance) -> bool:
|
|||
return getattr(cls_or_instance, '_nni_model_wrapper', False)
|
||||
|
||||
|
||||
def _check_wrapped(cls: T, rewrap: str) -> bool:
|
||||
def _check_wrapped(cls: Type, rewrap: str) -> bool:
|
||||
wrapped = None
|
||||
if is_model_wrapped(cls):
|
||||
wrapped = 'model_wrapper'
|
||||
|
|
|
@ -1,19 +1,26 @@
|
|||
# Copyright (c) Microsoft Corporation.
|
||||
# Licensed under the MIT license.
|
||||
|
||||
# This file might cause import error for those who didn't install RL-related dependencies
|
||||
|
||||
import logging
|
||||
import threading
|
||||
from multiprocessing.pool import ThreadPool
|
||||
from typing import Tuple
|
||||
|
||||
import gym
|
||||
import numpy as np
|
||||
import tianshou
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
|
||||
from gym import spaces
|
||||
from tianshou.data import to_torch
|
||||
from tianshou.env.worker import EnvWorker
|
||||
|
||||
from nni.typehint import TypedDict
|
||||
|
||||
from .utils import get_targeted_model
|
||||
from ..graph import ModelStatus
|
||||
from ..execution import submit_models, wait_models
|
||||
|
@ -76,8 +83,13 @@ class MultiThreadEnvWorker(EnvWorker):
|
|||
self.pool.terminate()
|
||||
return self.env.close()
|
||||
|
||||
class ObservationType(TypedDict):
|
||||
action_history: np.ndarray
|
||||
cur_step: int
|
||||
action_dim: int
|
||||
|
||||
class ModelEvaluationEnv(gym.Env):
|
||||
|
||||
class ModelEvaluationEnv(gym.Env[ObservationType, int]):
|
||||
def __init__(self, base_model, mutators, search_space):
|
||||
self.base_model = base_model
|
||||
self.mutators = mutators
|
||||
|
@ -98,7 +110,7 @@ class ModelEvaluationEnv(gym.Env):
|
|||
def action_space(self):
|
||||
return spaces.Discrete(self.action_dim)
|
||||
|
||||
def reset(self):
|
||||
def reset(self) -> ObservationType:
|
||||
self.action_history = np.zeros(self.num_steps, dtype=np.int32)
|
||||
self.cur_step = 0
|
||||
self.sample = {}
|
||||
|
@ -108,14 +120,14 @@ class ModelEvaluationEnv(gym.Env):
|
|||
'action_dim': len(self.search_space[self.ss_keys[self.cur_step]])
|
||||
}
|
||||
|
||||
def step(self, action):
|
||||
def step(self, action: int) -> Tuple[ObservationType, float, bool, dict]:
|
||||
cur_key = self.ss_keys[self.cur_step]
|
||||
assert action < len(self.search_space[cur_key]), \
|
||||
f'Current action {action} out of range {self.search_space[cur_key]}.'
|
||||
self.action_history[self.cur_step] = action
|
||||
self.sample[cur_key] = self.search_space[cur_key][action]
|
||||
self.cur_step += 1
|
||||
obs = {
|
||||
obs: ObservationType = {
|
||||
'action_history': self.action_history,
|
||||
'cur_step': self.cur_step,
|
||||
'action_dim': len(self.search_space[self.ss_keys[self.cur_step]]) \
|
||||
|
@ -129,7 +141,7 @@ class ModelEvaluationEnv(gym.Env):
|
|||
wait_models(model)
|
||||
if model.status == ModelStatus.Failed:
|
||||
return self.reset(), 0., False, {}
|
||||
rew = model.metric
|
||||
rew = float(model.metric)
|
||||
_logger.info(f'Model metric received as reward: {rew}')
|
||||
return obs, rew, True, {}
|
||||
else:
|
||||
|
@ -147,7 +159,7 @@ class Preprocessor(nn.Module):
|
|||
self.rnn = nn.LSTM(hidden_dim, hidden_dim, num_layers, batch_first=True)
|
||||
|
||||
def forward(self, obs):
|
||||
seq = nn.functional.pad(obs['action_history'] + 1, (1, 1)) # pad the start token and end token
|
||||
seq = F.pad(obs['action_history'] + 1, (1, 1)) # pad the start token and end token
|
||||
# end token is used to avoid out-of-range of v_s_. Will not actually affect BP.
|
||||
seq = self.embedding(seq.long())
|
||||
feature, _ = self.rnn(seq)
|
||||
|
@ -167,7 +179,7 @@ class Actor(nn.Module):
|
|||
# to take care of choices with different number of options
|
||||
mask = torch.arange(self.action_dim).expand(len(out), self.action_dim) >= obs['action_dim'].unsqueeze(1)
|
||||
out[mask.to(out.device)] = float('-inf')
|
||||
return nn.functional.softmax(out, dim=-1), kwargs.get('state', None)
|
||||
return F.softmax(out, dim=-1), kwargs.get('state', None)
|
||||
|
||||
|
||||
class Critic(nn.Module):
|
||||
|
|
|
@ -14,5 +14,5 @@ class BaseStrategy(abc.ABC):
|
|||
def run(self, base_model: Model, applied_mutators: List[Mutator]) -> None:
|
||||
pass
|
||||
|
||||
def export_top_models(self) -> List[Any]:
|
||||
def export_top_models(self, top_k: int) -> List[Any]:
|
||||
raise NotImplementedError('"export_top_models" is not implemented.')
|
||||
|
|
|
@ -6,7 +6,7 @@ import itertools
|
|||
import logging
|
||||
import random
|
||||
import time
|
||||
from typing import Any, Dict, List
|
||||
from typing import Any, Dict, List, Sequence, Optional
|
||||
|
||||
from .. import InvalidMutation, Sampler, submit_models, query_available_resources, budget_exhausted
|
||||
from .base import BaseStrategy
|
||||
|
@ -30,6 +30,7 @@ def random_generator(search_space: Dict[Any, List[Any]], dedup=True, retries=500
|
|||
history = set()
|
||||
search_space_values = copy.deepcopy(list(search_space.values()))
|
||||
while True:
|
||||
selected: Optional[Sequence[int]] = None
|
||||
for retry_count in range(retries):
|
||||
selected = [random.choice(v) for v in search_space_values]
|
||||
if not dedup:
|
||||
|
@ -41,6 +42,7 @@ def random_generator(search_space: Dict[Any, List[Any]], dedup=True, retries=500
|
|||
if retry_count + 1 == retries:
|
||||
_logger.debug('Random generation has run out of patience. There is nothing to search. Exiting.')
|
||||
return
|
||||
assert selected is not None, 'Retry attempts exhausted.'
|
||||
yield {key: value for key, value in zip(keys, selected)}
|
||||
|
||||
|
||||
|
|
|
@ -1,3 +1,6 @@
|
|||
# Copyright (c) Microsoft Corporation.
|
||||
# Licensed under the MIT license.
|
||||
|
||||
import logging
|
||||
from typing import Optional, Callable
|
||||
|
||||
|
|
|
@ -3,6 +3,7 @@
|
|||
|
||||
import logging
|
||||
import time
|
||||
from typing import Optional
|
||||
|
||||
from nni.algorithms.hpo.hyperopt_tuner import HyperoptTuner
|
||||
|
||||
|
@ -15,8 +16,8 @@ _logger = logging.getLogger(__name__)
|
|||
class TPESampler(Sampler):
|
||||
def __init__(self, optimize_mode='minimize'):
|
||||
self.tpe_tuner = HyperoptTuner('tpe', optimize_mode)
|
||||
self.cur_sample = None
|
||||
self.index = None
|
||||
self.cur_sample: Optional[dict] = None
|
||||
self.index: Optional[int] = None
|
||||
self.total_parameters = {}
|
||||
|
||||
def update_sample_space(self, sample_space):
|
||||
|
@ -34,6 +35,7 @@ class TPESampler(Sampler):
|
|||
self.tpe_tuner.receive_trial_result(model_id, self.total_parameters[model_id], result)
|
||||
|
||||
def choice(self, candidates, mutator, model, index):
|
||||
assert isinstance(self.index, int) and isinstance(self.cur_sample, dict)
|
||||
chosen = self.cur_sample[str(self.index)]
|
||||
self.index += 1
|
||||
return chosen
|
||||
|
|
|
@ -1,7 +1,6 @@
|
|||
# Copyright (c) Microsoft Corporation.
|
||||
# Licensed under the MIT license.
|
||||
|
||||
|
||||
import collections
|
||||
import logging
|
||||
from typing import Dict, Any, List
|
||||
|
|
|
@ -25,4 +25,6 @@ if __name__ == '__main__':
|
|||
elif args.exec == 'benchmark':
|
||||
from .execution.benchmark import BenchmarkExecutionEngine
|
||||
engine = BenchmarkExecutionEngine
|
||||
else:
|
||||
raise ValueError(f'Unrecognized benchmark name: {args.exec}')
|
||||
engine.trial_execute_graph()
|
||||
|
|
|
@ -6,7 +6,7 @@ import itertools
|
|||
import warnings
|
||||
from collections import defaultdict
|
||||
from contextlib import contextmanager
|
||||
from typing import Any, List, Dict
|
||||
from typing import Any, List, Dict, cast
|
||||
from pathlib import Path
|
||||
|
||||
from nni.common.hpo_utils import ParameterSpec
|
||||
|
@ -41,9 +41,10 @@ def get_module_name(cls_or_func):
|
|||
if module_name == '__main__':
|
||||
# infer the module name with inspect
|
||||
for frm in inspect.stack():
|
||||
if inspect.getmodule(frm[0]).__name__ == '__main__':
|
||||
module = inspect.getmodule(frm[0])
|
||||
if module is not None and module.__name__ == '__main__':
|
||||
# main module found
|
||||
main_file_path = Path(inspect.getsourcefile(frm[0]))
|
||||
main_file_path = Path(cast(str, inspect.getsourcefile(frm[0])))
|
||||
if not Path().samefile(main_file_path.parent):
|
||||
raise RuntimeError(f'You are using "{main_file_path}" to launch your experiment, '
|
||||
f'please launch the experiment under the directory where "{main_file_path.name}" is located.')
|
||||
|
@ -227,6 +228,7 @@ def original_state_dict_hooks(model: Any):
|
|||
supernet_style_state_dict = model.state_dict()
|
||||
"""
|
||||
|
||||
import torch.utils.hooks
|
||||
import torch.nn as nn
|
||||
assert isinstance(model, nn.Module), 'PyTorch is the only supported framework for now.'
|
||||
|
||||
|
@ -297,8 +299,8 @@ def original_state_dict_hooks(model: Any):
|
|||
raise KeyError(f'"{src}" not in state dict, but found in mapping.')
|
||||
destination.update(result)
|
||||
|
||||
hooks: List[torch.utils.hooks.RemovableHandle] = []
|
||||
try:
|
||||
hooks = []
|
||||
hooks.append(model._register_load_state_dict_pre_hook(load_state_dict_hook))
|
||||
hooks.append(model._register_state_dict_hook(state_dict_hook))
|
||||
yield
|
||||
|
|
|
@ -6,7 +6,7 @@ Types for static checking.
|
|||
"""
|
||||
|
||||
__all__ = [
|
||||
'Literal',
|
||||
'Literal', 'TypedDict',
|
||||
'Parameters', 'SearchSpace', 'TrialMetric', 'TrialRecord',
|
||||
]
|
||||
|
||||
|
|
|
@ -64,6 +64,9 @@ stages:
|
|||
python -m pip install "typing-extensions>=3.10"
|
||||
displayName: Resolve dependency version
|
||||
|
||||
- script: python test/vso_tools/trigger_import.py
|
||||
displayName: Trigger import
|
||||
|
||||
- script: |
|
||||
python -m pylint --rcfile pylintrc nni
|
||||
displayName: pylint
|
||||
|
|
|
@ -3,10 +3,13 @@
|
|||
"nni/algorithms",
|
||||
"nni/common/device.py",
|
||||
"nni/common/graph_utils.py",
|
||||
"nni/common/serializer.py",
|
||||
"nni/compression",
|
||||
"nni/nas",
|
||||
"nni/retiarii",
|
||||
"nni/nas/tensorflow",
|
||||
"nni/nas/pytorch",
|
||||
"nni/retiarii/execution/cgo_engine.py",
|
||||
"nni/retiarii/execution/logical_optimizer",
|
||||
"nni/retiarii/evaluator/pytorch/cgo",
|
||||
"nni/retiarii/oneshot",
|
||||
"nni/smartparam.py",
|
||||
"nni/tools/annotation",
|
||||
"nni/tools/gpu_tool",
|
||||
|
@ -14,5 +17,6 @@
|
|||
"nni/tools/nnictl",
|
||||
"nni/tools/trial_tool"
|
||||
],
|
||||
"reportMissingImports": false
|
||||
"reportMissingImports": false,
|
||||
"reportPrivateImportUsage": false
|
||||
}
|
||||
|
|
|
@ -4,4 +4,5 @@ filterwarnings =
|
|||
ignore:Using key to access the identifier of:DeprecationWarning
|
||||
ignore:layer_choice.choices is deprecated.:DeprecationWarning
|
||||
ignore:The truth value of an empty array is ambiguous.:DeprecationWarning
|
||||
ignore:`np.bool` is a deprecated alias for the builtin `bool`:DeprecationWarning
|
||||
ignore:nni.retiarii.serialize is deprecated and will be removed in future release.:DeprecationWarning
|
||||
|
|
|
@ -36,5 +36,9 @@
|
|||
{"head": ["conv2", null], "tail": ["pool2", null]},
|
||||
{"head": ["pool2", null], "tail": ["_outputs", 0]}
|
||||
]
|
||||
},
|
||||
|
||||
"_evaluator": {
|
||||
"type": "DebugEvaluator"
|
||||
}
|
||||
}
|
||||
|
|
|
@ -9,6 +9,7 @@ from nni.retiarii.codegen import model_to_pytorch_script
|
|||
from nni.retiarii.execution import set_execution_engine
|
||||
from nni.retiarii.execution.base import BaseExecutionEngine
|
||||
from nni.retiarii.execution.python import PurePythonExecutionEngine
|
||||
from nni.retiarii.graph import DebugEvaluator
|
||||
from nni.retiarii.integration import RetiariiAdvisor
|
||||
|
||||
|
||||
|
@ -51,6 +52,7 @@ class EngineTest(unittest.TestCase):
|
|||
'edges': []
|
||||
}
|
||||
})
|
||||
model.evaluator = DebugEvaluator()
|
||||
model.python_class = object
|
||||
submit_models(model, model)
|
||||
|
||||
|
|
|
@ -0,0 +1,10 @@
|
|||
"""Trigger import of some modules to write some caches,
|
||||
so that static analysis (e.g., pyright) can know the type."""
|
||||
|
||||
import os
|
||||
import sys
|
||||
|
||||
sys.path.insert(0, os.path.join(os.path.dirname(__file__), '../../'))
|
||||
|
||||
import nni
|
||||
import nni.retiarii.nn.pytorch
|
Загрузка…
Ссылка в новой задаче