зеркало из https://github.com/microsoft/nni.git
Use "tuner" to config advisor (#4773)
This commit is contained in:
Родитель
c8cc5c625c
Коммит
f24c8380cf
|
@ -38,7 +38,7 @@ As a general example, random tuner can be configured as follow:
|
|||
'x': {'_type': 'uniform', '_value': [0, 1]},
|
||||
'y': {'_type': 'choice', '_value': ['a', 'b', 'c']}
|
||||
}
|
||||
config.tuner.name = 'Random'
|
||||
config.tuner.name = 'random'
|
||||
config.tuner.class_args = {'seed': 0}
|
||||
|
||||
Built-in Tuners
|
||||
|
|
|
@ -203,9 +203,9 @@ ExperimentConfig
|
|||
|
||||
* - tunerGpuIndices
|
||||
- ``list[int]`` or ``str`` or ``int``, optional
|
||||
- Limit the GPUs visible to tuner, assessor, and advisor.
|
||||
- Limit the GPUs visible to tuner and assessor.
|
||||
This will be the ``CUDA_VISIBLE_DEVICES`` environment variable of tuner process.
|
||||
Because tuner, assessor, and advisor run in the same process, this option will affect them all.
|
||||
Because tuner and assessor run in the same process, this option will affect both of them.
|
||||
|
||||
* - tuner
|
||||
- ``AlgorithmConfig``, optional
|
||||
|
@ -219,8 +219,7 @@ ExperimentConfig
|
|||
|
||||
* - advisor
|
||||
- ``AlgorithmConfig``, optional
|
||||
- Specify the advisor.
|
||||
NNI provides two built-in advisors: :class:`BOHB <nni.algorithms.hpo.bohb_advisor.BOHB>` and :class:`Hyperband <nni.algorithms.hpo.hyperband_advisor.Hyperband>`.
|
||||
- Deprecated, use ``tuner`` instead.
|
||||
|
||||
* - trainingService
|
||||
- ``TrainingServiceConfig``
|
||||
|
@ -251,7 +250,7 @@ For customized algorithms, there are two ways to describe them:
|
|||
|
||||
* - name
|
||||
- ``str`` or ``None``, optional
|
||||
- Default: None. Name of the built-in or registered algorithm.
|
||||
- Default: None. Name of the built-in or registered algorithm, case insensitive.
|
||||
``str`` for the built-in and registered algorithm, ``None`` for other customized algorithms.
|
||||
|
||||
* - className
|
||||
|
|
|
@ -7,8 +7,8 @@ import logging
|
|||
import json
|
||||
import base64
|
||||
|
||||
from .runtime.common import enable_multi_thread
|
||||
from .runtime.msg_dispatcher import MsgDispatcher
|
||||
from .runtime.msg_dispatcher_base import MsgDispatcherBase
|
||||
from .tools.package_utils import create_builtin_class_instance, create_customized_class_instance
|
||||
|
||||
logger = logging.getLogger('nni.main')
|
||||
|
@ -29,82 +29,50 @@ def main():
|
|||
exp_params = json.loads(exp_params_decode)
|
||||
logger.debug('exp_params json obj: [%s]', json.dumps(exp_params, indent=4))
|
||||
|
||||
if exp_params.get('deprecated', {}).get('multiThread'):
|
||||
enable_multi_thread()
|
||||
|
||||
if 'trainingServicePlatform' in exp_params: # config schema is v1
|
||||
from .experiment.config.convert import convert_algo
|
||||
for algo_type in ['tuner', 'assessor', 'advisor']:
|
||||
for algo_type in ['tuner', 'assessor']:
|
||||
if algo_type in exp_params:
|
||||
exp_params[algo_type] = convert_algo(algo_type, exp_params[algo_type])
|
||||
if 'advisor' in exp_params:
|
||||
exp_params['tuner'] = convert_algo('advisor', exp_params['advisor'])
|
||||
|
||||
if exp_params.get('advisor') is not None:
|
||||
# advisor is enabled and starts to run
|
||||
_run_advisor(exp_params)
|
||||
else:
|
||||
# tuner (and assessor) is enabled and starts to run
|
||||
assert exp_params.get('tuner') is not None
|
||||
tuner = _create_tuner(exp_params)
|
||||
assert exp_params.get('tuner') is not None
|
||||
tuner = _create_algo(exp_params['tuner'], 'tuner')
|
||||
|
||||
if isinstance(tuner, MsgDispatcherBase): # is advisor
|
||||
logger.debug(f'Tuner {type(tuner).__name__} is advisor.')
|
||||
if exp_params.get('assessor') is not None:
|
||||
assessor = _create_assessor(exp_params)
|
||||
else:
|
||||
assessor = None
|
||||
dispatcher = MsgDispatcher(tuner, assessor)
|
||||
logger.error('Tuner {type(tuner).__name__} has built-in early stopping logic. Assessor is ignored.')
|
||||
tuner.run()
|
||||
return
|
||||
|
||||
try:
|
||||
dispatcher.run()
|
||||
tuner._on_exit()
|
||||
if assessor is not None:
|
||||
assessor._on_exit()
|
||||
except Exception as exception:
|
||||
logger.exception(exception)
|
||||
tuner._on_error()
|
||||
if assessor is not None:
|
||||
assessor._on_error()
|
||||
raise
|
||||
|
||||
|
||||
def _run_advisor(exp_params):
|
||||
if exp_params.get('advisor').get('name'):
|
||||
dispatcher = create_builtin_class_instance(
|
||||
exp_params['advisor']['name'],
|
||||
exp_params['advisor'].get('classArgs'),
|
||||
'advisors')
|
||||
if exp_params.get('assessor') is not None:
|
||||
assessor = _create_algo(exp_params['assessor'], 'assessor')
|
||||
else:
|
||||
dispatcher = create_customized_class_instance(exp_params.get('advisor'))
|
||||
if dispatcher is None:
|
||||
raise AssertionError('Failed to create Advisor instance')
|
||||
assessor = None
|
||||
dispatcher = MsgDispatcher(tuner, assessor)
|
||||
|
||||
try:
|
||||
dispatcher.run()
|
||||
except Exception as exception:
|
||||
logger.exception(exception)
|
||||
tuner._on_exit()
|
||||
if assessor is not None:
|
||||
assessor._on_exit()
|
||||
except Exception:
|
||||
tuner._on_error()
|
||||
if assessor is not None:
|
||||
assessor._on_error()
|
||||
raise
|
||||
|
||||
|
||||
def _create_tuner(exp_params):
|
||||
if exp_params['tuner'].get('name'):
|
||||
tuner = create_builtin_class_instance(
|
||||
exp_params['tuner']['name'],
|
||||
exp_params['tuner'].get('classArgs'),
|
||||
'tuners')
|
||||
def _create_algo(algo_config, algo_type):
|
||||
if algo_config.get('name'):
|
||||
algo = create_builtin_class_instance(algo_config['name'], algo_config.get('classArgs'), algo_type + 's')
|
||||
else:
|
||||
tuner = create_customized_class_instance(exp_params['tuner'])
|
||||
if tuner is None:
|
||||
raise AssertionError('Failed to create Tuner instance')
|
||||
return tuner
|
||||
|
||||
|
||||
def _create_assessor(exp_params):
|
||||
if exp_params['assessor'].get('name'):
|
||||
assessor = create_builtin_class_instance(
|
||||
exp_params['assessor']['name'],
|
||||
exp_params['assessor'].get('classArgs'),
|
||||
'assessors')
|
||||
else:
|
||||
assessor = create_customized_class_instance(exp_params['assessor'])
|
||||
if assessor is None:
|
||||
raise AssertionError('Failed to create Assessor instance')
|
||||
return assessor
|
||||
algo = create_customized_class_instance(algo_config)
|
||||
if algo is None:
|
||||
raise AssertionError(f'Failed to create {algo_type} instance')
|
||||
return algo
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
|
|
|
@ -70,7 +70,7 @@ class BatchTuner(Tuner):
|
|||
]
|
||||
}
|
||||
}
|
||||
config.tuner.name = 'BatchTuner'
|
||||
config.tuner.name = 'Batch'
|
||||
"""
|
||||
|
||||
def __init__(self):
|
||||
|
|
|
@ -272,8 +272,8 @@ class BOHB(MsgDispatcherBase):
|
|||
|
||||
.. code-block::
|
||||
|
||||
config.advisor.name = 'BOHB'
|
||||
config.advisor.class_args = {
|
||||
config.tuner.name = 'BOHB'
|
||||
config.tuner.class_args = {
|
||||
'optimize_mode': 'maximize',
|
||||
'min_budget': 1,
|
||||
'max_budget': 27,
|
||||
|
|
|
@ -41,7 +41,7 @@ class GPClassArgsValidator(ClassArgsValidator):
|
|||
|
||||
class GPTuner(Tuner):
|
||||
"""
|
||||
GPTuner is a Bayesian Optimization method where Gaussian Process
|
||||
GP tuner is a Bayesian Optimization method where Gaussian Process
|
||||
is used for modeling loss functions.
|
||||
|
||||
Bayesian optimization works by constructing a posterior distribution of functions
|
||||
|
@ -50,7 +50,7 @@ class GPTuner(Tuner):
|
|||
and the algorithm becomes more certain of which regions in parameter space
|
||||
are worth exploring and which are not.
|
||||
|
||||
GPTuner is designed to minimize/maximize the number of steps required to find
|
||||
GP tuner is designed to minimize/maximize the number of steps required to find
|
||||
a combination of parameters that are close to the optimal combination.
|
||||
To do so, this method uses a proxy optimization problem (finding the maximum of
|
||||
the acquisition function) that, albeit still a hard problem, is cheaper
|
||||
|
@ -70,7 +70,7 @@ class GPTuner(Tuner):
|
|||
|
||||
.. code-block::
|
||||
|
||||
config.tuner.name = 'GPTuner'
|
||||
config.tuner.name = 'GP'
|
||||
config.tuner.class_args = {
|
||||
'optimize_mode': 'maximize',
|
||||
'utility': 'ei',
|
||||
|
|
|
@ -284,8 +284,8 @@ class Hyperband(MsgDispatcherBase):
|
|||
|
||||
.. code-block::
|
||||
|
||||
config.advisor.name = 'Hyperband'
|
||||
config.advisor.class_args = {
|
||||
config.tuner.name = 'Hyperband'
|
||||
config.tuner.class_args = {
|
||||
'optimize_mode': 'maximize',
|
||||
'R': 60,
|
||||
'eta': 3
|
||||
|
|
|
@ -81,7 +81,7 @@ class MetisTuner(Tuner):
|
|||
|
||||
.. code-block::
|
||||
|
||||
config.tuner.name = 'MetisTuner'
|
||||
config.tuner.name = 'Metis'
|
||||
config.tuner.class_args = {
|
||||
'optimize_mode': 'maximize'
|
||||
}
|
||||
|
|
|
@ -179,14 +179,14 @@ class PBTTuner(Tuner):
|
|||
|
||||
.. image:: ../../img/pbt.jpg
|
||||
|
||||
PBTTuner initializes a population with several trials (i.e., ``population_size``).
|
||||
PBT tuner initializes a population with several trials (i.e., ``population_size``).
|
||||
There are four steps in the above figure, each trial only runs by one step. How long is one step is controlled by trial code,
|
||||
e.g., one epoch. When a trial starts, it loads a checkpoint specified by PBTTuner and continues to run one step,
|
||||
then saves checkpoint to a directory specified by PBTTuner and exits.
|
||||
e.g., one epoch. When a trial starts, it loads a checkpoint specified by PBT tuner and continues to run one step,
|
||||
then saves checkpoint to a directory specified by PBT tuner and exits.
|
||||
The trials in a population run steps synchronously, that is, after all the trials finish the ``i``-th step,
|
||||
the ``(i+1)``-th step can be started. Exploitation and exploration of PBT are executed between two consecutive steps.
|
||||
|
||||
Two important steps to follow if you are trying to use PBTTuner:
|
||||
Two important steps to follow if you are trying to use PBT tuner:
|
||||
|
||||
1. **Provide checkpoint directory**. Since some trials need to load other trial's checkpoint,
|
||||
users should provide a directory (i.e., ``all_checkpoint_dir``) which is accessible by every trial.
|
||||
|
@ -196,7 +196,7 @@ class PBTTuner(Tuner):
|
|||
to provide a directory in a shared storage, such as NFS, Azure storage.
|
||||
|
||||
2. **Modify your trial code**. Before running a step, a trial needs to load a checkpoint,
|
||||
the checkpoint directory is specified in hyper-parameter configuration generated by PBTTuner,
|
||||
the checkpoint directory is specified in hyper-parameter configuration generated by PBT tuner,
|
||||
i.e., ``params['load_checkpoint_dir']``. Similarly, the directory for saving checkpoint is also included in the configuration,
|
||||
i.e., ``params['save_checkpoint_dir']``. Here, ``all_checkpoint_dir`` is base folder of ``load_checkpoint_dir``
|
||||
and ``save_checkpoint_dir`` whose format is ``all_checkpoint_dir/<population-id>/<step>``.
|
||||
|
@ -238,12 +238,12 @@ class PBTTuner(Tuner):
|
|||
|
||||
Examples
|
||||
--------
|
||||
Below is an example of PBTTuner configuration in experiment config file.
|
||||
Below is an example of PBT tuner configuration in experiment config file.
|
||||
|
||||
.. code-block:: yaml
|
||||
|
||||
tuner:
|
||||
name: PBTTuner
|
||||
name: PBT
|
||||
classArgs:
|
||||
optimize_mode: maximize
|
||||
all_checkpoint_dir: /the/path/to/store/checkpoints
|
||||
|
@ -251,7 +251,7 @@ class PBTTuner(Tuner):
|
|||
|
||||
Notes
|
||||
-----
|
||||
Assessor is not allowed if PBTTuner is used.
|
||||
Assessor is not allowed if PBT tuner is used.
|
||||
"""
|
||||
|
||||
def __init__(self, optimize_mode="maximize", all_checkpoint_dir=None, population_size=10, factor=0.2,
|
||||
|
|
|
@ -2,7 +2,7 @@
|
|||
# Licensed under the MIT license.
|
||||
|
||||
"""
|
||||
Config classes for tuner/assessor/advisor algorithms.
|
||||
Config classes for tuner and assessor algorithms.
|
||||
|
||||
Use ``AlgorithmConfig`` to specify a built-in algorithm;
|
||||
use ``CustomAlgorithmConfig`` to specify a custom algorithm.
|
||||
|
|
|
@ -113,6 +113,11 @@ class ExperimentConfig(ConfigBase):
|
|||
if algo is not None and algo.name == '_none_':
|
||||
setattr(self, algo_type, None)
|
||||
|
||||
if self.advisor is not None:
|
||||
assert self.tuner is None, '"advisor" is deprecated. You should only set "tuner".'
|
||||
self.tuner = self.advisor
|
||||
self.advisor = None
|
||||
|
||||
super()._canonicalize([self])
|
||||
|
||||
if self.search_space_file is not None:
|
||||
|
@ -161,9 +166,8 @@ class ExperimentConfig(ConfigBase):
|
|||
|
||||
utils.validate_gpu_indices(self.tuner_gpu_indices)
|
||||
|
||||
tuner_cnt = (self.tuner is not None) + (self.advisor is not None)
|
||||
if tuner_cnt != 1:
|
||||
raise ValueError('ExperimentConfig: tuner and advisor must be set one')
|
||||
if self.tuner is None:
|
||||
raise ValueError('ExperimentConfig: tuner must be set')
|
||||
|
||||
def _load_search_space_file(search_space_path):
|
||||
# FIXME
|
||||
|
|
|
@ -1,12 +1,3 @@
|
|||
advisors:
|
||||
- builtinName: Hyperband
|
||||
classArgsValidator: nni.algorithms.hpo.hyperband_advisor.HyperbandClassArgsValidator
|
||||
className: nni.algorithms.hpo.hyperband_advisor.Hyperband
|
||||
source: nni
|
||||
- builtinName: BOHB
|
||||
classArgsValidator: nni.algorithms.hpo.bohb_advisor.BOHBClassArgsValidator
|
||||
className: nni.algorithms.hpo.bohb_advisor.BOHB
|
||||
source: nni
|
||||
assessors:
|
||||
- builtinName: Medianstop
|
||||
classArgsValidator: nni.algorithms.hpo.medianstop_assessor.MedianstopClassArgsValidator
|
||||
|
@ -17,7 +8,8 @@ assessors:
|
|||
className: nni.algorithms.hpo.curvefitting_assessor.CurvefittingAssessor
|
||||
source: nni
|
||||
tuners:
|
||||
- builtinName: PPOTuner
|
||||
- alias: PPOTuner
|
||||
builtinName: PPO
|
||||
classArgsValidator: nni.algorithms.hpo.ppo_tuner.PPOClassArgsValidator
|
||||
className: nni.algorithms.hpo.ppo_tuner.PPOTuner
|
||||
source: nni
|
||||
|
@ -49,7 +41,8 @@ tuners:
|
|||
className: nni.algorithms.hpo.evolution_tuner.EvolutionTuner
|
||||
source: nni
|
||||
- acceptClassArgs: false
|
||||
builtinName: BatchTuner
|
||||
alias: BatchTuner
|
||||
builtinName: Batch
|
||||
className: nni.algorithms.hpo.batch_tuner.BatchTuner
|
||||
source: nni
|
||||
- acceptClassArgs: false
|
||||
|
@ -60,15 +53,18 @@ tuners:
|
|||
classArgsValidator: nni.algorithms.hpo.networkmorphism_tuner.NetworkMorphismClassArgsValidator
|
||||
className: nni.algorithms.hpo.networkmorphism_tuner.NetworkMorphismTuner
|
||||
source: nni
|
||||
- builtinName: MetisTuner
|
||||
- alias: MetisTuner
|
||||
builtinName: Metis
|
||||
classArgsValidator: nni.algorithms.hpo.metis_tuner.MetisClassArgsValidator
|
||||
className: nni.algorithms.hpo.metis_tuner.MetisTuner
|
||||
source: nni
|
||||
- builtinName: GPTuner
|
||||
- alias: GPTuner
|
||||
builtinName: GP
|
||||
classArgsValidator: nni.algorithms.hpo.gp_tuner.GPClassArgsValidator
|
||||
className: nni.algorithms.hpo.gp_tuner.GPTuner
|
||||
source: nni
|
||||
- builtinName: PBTTuner
|
||||
- alias: PBTTuner
|
||||
builtinName: PBT
|
||||
classArgsValidator: nni.algorithms.hpo.pbt_tuner.PBTClassArgsValidator
|
||||
className: nni.algorithms.hpo.pbt_tuner.PBTTuner
|
||||
source: nni
|
||||
|
@ -76,7 +72,18 @@ tuners:
|
|||
classArgsValidator: nni.algorithms.hpo.regularized_evolution_tuner.EvolutionClassArgsValidator
|
||||
className: nni.algorithms.hpo.regularized_evolution_tuner.RegularizedEvolutionTuner
|
||||
source: nni
|
||||
- builtinName: DNGOTuner
|
||||
- alias: DNGOTuner
|
||||
builtinName: DNGO
|
||||
classArgsValidator: nni.algorithms.hpo.dngo_tuner.DNGOClassArgsValidator
|
||||
className: nni.algorithms.hpo.dngo_tuner.DNGOTuner
|
||||
source: nni
|
||||
- builtinName: Hyperband
|
||||
classArgsValidator: nni.algorithms.hpo.hyperband_advisor.HyperbandClassArgsValidator
|
||||
className: nni.algorithms.hpo.hyperband_advisor.Hyperband
|
||||
isAdvisor: true
|
||||
source: nni
|
||||
- builtinName: BOHB
|
||||
classArgsValidator: nni.algorithms.hpo.bohb_advisor.BOHBClassArgsValidator
|
||||
className: nni.algorithms.hpo.bohb_advisor.BOHB
|
||||
isAdvisor: true
|
||||
source: nni
|
||||
|
|
|
@ -6,11 +6,7 @@ import logging
|
|||
import os
|
||||
|
||||
from schema import And, Optional, Or, Regex, Schema, SchemaError
|
||||
from nni.tools.package_utils.tuner_factory import (
|
||||
create_validator_instance,
|
||||
get_all_builtin_names,
|
||||
get_registered_algo_meta,
|
||||
)
|
||||
from nni.tools.package_utils.tuner_factory import create_validator_instance
|
||||
|
||||
from .common_utils import get_yml_content, print_warning
|
||||
from .constants import SCHEMA_PATH_ERROR, SCHEMA_RANGE_ERROR, SCHEMA_TYPE_ERROR
|
||||
|
@ -73,16 +69,13 @@ class AlgoSchema:
|
|||
}
|
||||
self.builtin_name_schema = {}
|
||||
for k, n in self.builtin_keys.items():
|
||||
self.builtin_name_schema[k] = {Optional(n): setChoice(n, *get_all_builtin_names(k+'s'))}
|
||||
self.builtin_name_schema[k] = {Optional(n): setType(n, str)}
|
||||
|
||||
self.customized_keys = set(['codeDir', 'classFileName', 'className'])
|
||||
|
||||
def validate_class_args(self, class_args, algo_type, builtin_name):
|
||||
if not builtin_name or not class_args:
|
||||
return
|
||||
meta = get_registered_algo_meta(builtin_name, algo_type+'s')
|
||||
if meta and 'acceptClassArgs' in meta and meta['acceptClassArgs'] == False:
|
||||
raise SchemaError('classArgs is not allowed.')
|
||||
|
||||
logging.getLogger('nni.protocol').setLevel(logging.ERROR) # we know IPC is not there, don't complain
|
||||
validator = create_validator_instance(algo_type+'s', builtin_name)
|
||||
|
|
|
@ -1,38 +1,48 @@
|
|||
# Copyright (c) Microsoft Corporation.
|
||||
# Licensed under the MIT license.
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
__all__ = ['AlgoMeta']
|
||||
|
||||
from typing import Dict, NamedTuple, Optional
|
||||
from typing import NamedTuple
|
||||
|
||||
from nni.typehint import Literal
|
||||
|
||||
class AlgoMeta(NamedTuple):
|
||||
name: str
|
||||
class_name: Optional[str]
|
||||
alias: str | None
|
||||
class_name: str | None
|
||||
accept_class_args: bool
|
||||
class_args: Optional[dict]
|
||||
validator_class_name: Optional[str]
|
||||
algo_type: str # 'tuner' | 'assessor' | 'advisor'
|
||||
class_args: dict | None
|
||||
validator_class_name: str | None
|
||||
algo_type: Literal['tuner', 'assessor']
|
||||
is_advisor: bool
|
||||
is_builtin: bool
|
||||
nni_version: Optional[str]
|
||||
nni_version: str | None
|
||||
|
||||
@staticmethod
|
||||
def load(meta: Dict, algo_type: Optional[str] = None) -> 'AlgoMeta':
|
||||
def load(meta: dict, algo_type: Literal['tuner', 'assessor', 'advisor'] | None = None) -> AlgoMeta:
|
||||
if algo_type is None:
|
||||
algo_type = meta['algoType']
|
||||
algo_type = meta['algoType'] # type: ignore
|
||||
return AlgoMeta(
|
||||
name=meta['builtinName'],
|
||||
class_name=meta['className'],
|
||||
accept_class_args=meta.get('acceptClassArgs', True),
|
||||
class_args=meta.get('classArgs'),
|
||||
validator_class_name=meta.get('classArgsValidator'),
|
||||
algo_type=algo_type,
|
||||
is_builtin=(meta.get('source') == 'nni'),
|
||||
nni_version=meta.get('nniVersion')
|
||||
name = meta['builtinName'],
|
||||
alias = meta.get('alias'),
|
||||
class_name = meta['className'],
|
||||
accept_class_args = meta.get('acceptClassArgs', True),
|
||||
class_args = meta.get('classArgs'),
|
||||
validator_class_name = meta.get('classArgsValidator'),
|
||||
algo_type = ('assessor' if algo_type == 'assessor' else 'tuner'),
|
||||
is_advisor = meta.get('isAdvisor', algo_type == 'advisor'),
|
||||
is_builtin = (meta.get('source') == 'nni'),
|
||||
nni_version = meta.get('nniVersion')
|
||||
)
|
||||
|
||||
def dump(self) -> Dict:
|
||||
def dump(self) -> dict:
|
||||
ret = {}
|
||||
ret['builtinName'] = self.name
|
||||
if self.alias is not None:
|
||||
ret['alias'] = self.alias
|
||||
ret['className'] = self.class_name
|
||||
if not self.accept_class_args:
|
||||
ret['acceptClassArgs'] = False
|
||||
|
@ -40,6 +50,8 @@ class AlgoMeta(NamedTuple):
|
|||
ret['classArgs'] = self.class_args
|
||||
if self.validator_class_name is not None:
|
||||
ret['classArgsValidator'] = self.validator_class_name
|
||||
if self.is_advisor:
|
||||
ret['isAdvisor'] = True
|
||||
ret['source'] = 'nni' if self.is_builtin else 'user'
|
||||
if self.nni_version is not None:
|
||||
ret['nniVersion'] = self.nni_version
|
||||
|
|
|
@ -1,6 +1,8 @@
|
|||
# Copyright (c) Microsoft Corporation.
|
||||
# Licensed under the MIT license.
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
__all__ = [
|
||||
'get_algo_meta',
|
||||
'get_all_algo_meta',
|
||||
|
@ -9,24 +11,26 @@ __all__ = [
|
|||
]
|
||||
|
||||
from collections import defaultdict
|
||||
from typing import List, Optional
|
||||
|
||||
import yaml
|
||||
|
||||
from nni.runtime.config import get_builtin_config_file, get_config_file
|
||||
from .common import AlgoMeta
|
||||
|
||||
def get_algo_meta(name: AlgoMeta) -> Optional[AlgoMeta]:
|
||||
def get_algo_meta(name: str) -> AlgoMeta | None:
|
||||
"""
|
||||
Get meta information of a built-in or registered algorithm.
|
||||
Return None if not found.
|
||||
"""
|
||||
name = name.lower()
|
||||
for algo in get_all_algo_meta():
|
||||
if algo.name == name:
|
||||
if algo.name.lower() == name:
|
||||
return algo
|
||||
if algo.alias is not None and algo.alias.lower() == name:
|
||||
return algo
|
||||
return None
|
||||
|
||||
def get_all_algo_meta() -> List[AlgoMeta]:
|
||||
def get_all_algo_meta() -> list[AlgoMeta]:
|
||||
"""
|
||||
Get meta information of all built-in and registered algorithms.
|
||||
"""
|
||||
|
@ -64,7 +68,7 @@ def _load_config_file(path):
|
|||
algos = []
|
||||
for algo_type in ['tuner', 'assessor', 'advisor']:
|
||||
for algo in config.get(algo_type + 's', []):
|
||||
algos.append(AlgoMeta.load(algo, algo_type))
|
||||
algos.append(AlgoMeta.load(algo, algo_type)) # type: ignore
|
||||
return algos
|
||||
|
||||
def _save_custom_config(custom_algos):
|
||||
|
|
|
@ -1,6 +1,8 @@
|
|||
# Copyright (c) Microsoft Corporation.
|
||||
# Licensed under the MIT license.
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
__all__ = [
|
||||
'create_builtin_class_instance',
|
||||
'create_customized_class_instance',
|
||||
|
@ -9,38 +11,23 @@ __all__ = [
|
|||
import importlib
|
||||
import os
|
||||
import sys
|
||||
from typing import Any
|
||||
|
||||
from nni.typehint import Literal
|
||||
from . import config_manager
|
||||
|
||||
ALGO_TYPES = ['tuners', 'assessors', 'advisors']
|
||||
ALGO_TYPES = ['tuners', 'assessors']
|
||||
|
||||
def get_all_builtin_names(algo_type):
|
||||
"""Get all builtin names of registered algorithms of specified type
|
||||
|
||||
Parameters
|
||||
----------
|
||||
algo_type: str
|
||||
can be one of 'tuners', 'assessors' or 'advisors'
|
||||
|
||||
Returns: list of string
|
||||
-------
|
||||
All builtin names of specified type, for example, if algo_type is 'tuners', returns
|
||||
all builtin tuner names.
|
||||
"""
|
||||
def _get_all_builtin_names(algo_type: Literal['tuners', 'assessors']) -> list[str]:
|
||||
algos = config_manager.get_all_algo_meta()
|
||||
return [meta.name for meta in algos if meta.algo_type == algo_type.rstrip('s')]
|
||||
algos = [meta for meta in algos if meta.algo_type + 's' == algo_type]
|
||||
names = [meta.name for meta in algos] + [meta.alias for meta in algos if meta.alias is not None]
|
||||
return [name.lower() for name in names]
|
||||
|
||||
def get_registered_algo_meta(builtin_name, algo_type=None):
|
||||
def _get_registered_algo_meta(builtin_name: str) -> dict | None:
|
||||
""" Get meta information of registered algorithms.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
builtin_name: str
|
||||
builtin name.
|
||||
algo_type: str | None
|
||||
can be one of 'tuners', 'assessors', 'advisors' or None
|
||||
|
||||
Returns: dict | None
|
||||
Returns
|
||||
-------
|
||||
Returns meta information of speicified builtin alogorithms, for example:
|
||||
{
|
||||
|
@ -52,8 +39,6 @@ def get_registered_algo_meta(builtin_name, algo_type=None):
|
|||
algo = config_manager.get_algo_meta(builtin_name)
|
||||
if algo is None:
|
||||
return None
|
||||
if algo_type is not None and algo.algo_type != algo_type.rstrip('s'):
|
||||
return None
|
||||
return algo.dump()
|
||||
|
||||
def parse_full_class_name(full_class_name):
|
||||
|
@ -69,7 +54,7 @@ def get_builtin_module_class_name(algo_type, builtin_name):
|
|||
Parameters
|
||||
----------
|
||||
algo_type: str
|
||||
can be one of 'tuners', 'assessors', 'advisors'
|
||||
can be one of 'tuners', 'assessors'
|
||||
builtin_name: str
|
||||
builtin name.
|
||||
|
||||
|
@ -79,7 +64,7 @@ def get_builtin_module_class_name(algo_type, builtin_name):
|
|||
"""
|
||||
assert algo_type in ALGO_TYPES
|
||||
assert builtin_name is not None
|
||||
meta = get_registered_algo_meta(builtin_name, algo_type)
|
||||
meta = _get_registered_algo_meta(builtin_name)
|
||||
if not meta:
|
||||
return None, None
|
||||
return parse_full_class_name(meta['className'])
|
||||
|
@ -90,7 +75,7 @@ def create_validator_instance(algo_type, builtin_name):
|
|||
Parameters
|
||||
----------
|
||||
algo_type: str
|
||||
can be one of 'tuners', 'assessors', 'advisors'
|
||||
can be one of 'tuners', 'assessors'
|
||||
builtin_name: str
|
||||
builtin name.
|
||||
|
||||
|
@ -101,16 +86,20 @@ def create_validator_instance(algo_type, builtin_name):
|
|||
"""
|
||||
assert algo_type in ALGO_TYPES
|
||||
assert builtin_name is not None
|
||||
meta = get_registered_algo_meta(builtin_name, algo_type)
|
||||
meta = _get_registered_algo_meta(builtin_name)
|
||||
if not meta or 'classArgsValidator' not in meta:
|
||||
return None
|
||||
module_name, class_name = parse_full_class_name(meta['classArgsValidator'])
|
||||
assert module_name is not None
|
||||
class_module = importlib.import_module(module_name)
|
||||
class_constructor = getattr(class_module, class_name)
|
||||
|
||||
return class_constructor()
|
||||
|
||||
def create_builtin_class_instance(builtin_name, input_class_args, algo_type):
|
||||
def create_builtin_class_instance(
|
||||
builtin_name: str,
|
||||
input_class_args: dict,
|
||||
algo_type: Literal['tuners', 'assessors']) -> Any:
|
||||
"""Create instance of builtin algorithms
|
||||
|
||||
Parameters
|
||||
|
@ -120,14 +109,15 @@ def create_builtin_class_instance(builtin_name, input_class_args, algo_type):
|
|||
input_class_args: dict
|
||||
kwargs for builtin class constructor
|
||||
algo_type: str
|
||||
can be one of 'tuners', 'assessors', 'advisors'
|
||||
can be one of 'tuners', 'assessors'
|
||||
|
||||
Returns: object
|
||||
-------
|
||||
Returns builtin class instance.
|
||||
"""
|
||||
assert algo_type in ALGO_TYPES
|
||||
if builtin_name not in get_all_builtin_names(algo_type):
|
||||
builtin_name = builtin_name.lower()
|
||||
if builtin_name not in _get_all_builtin_names(algo_type):
|
||||
raise RuntimeError('Builtin name is not found: {}'.format(builtin_name))
|
||||
|
||||
def parse_algo_meta(algo_meta, input_class_args):
|
||||
|
@ -150,10 +140,11 @@ def create_builtin_class_instance(builtin_name, input_class_args, algo_type):
|
|||
|
||||
return module_name, class_name, class_args
|
||||
|
||||
algo_meta = get_registered_algo_meta(builtin_name, algo_type)
|
||||
algo_meta = _get_registered_algo_meta(builtin_name)
|
||||
module_name, class_name, class_args = parse_algo_meta(algo_meta, input_class_args)
|
||||
assert module_name is not None
|
||||
|
||||
if importlib.util.find_spec(module_name) is None:
|
||||
if importlib.util.find_spec(module_name) is None: # type: ignore
|
||||
raise RuntimeError('Builtin module can not be loaded: {}'.format(module_name))
|
||||
|
||||
class_module = importlib.import_module(module_name)
|
||||
|
|
|
@ -8,7 +8,11 @@
|
|||
"nni/nas",
|
||||
"nni/retiarii",
|
||||
"nni/smartparam.py",
|
||||
"nni/tools"
|
||||
"nni/tools/annotation",
|
||||
"nni/tools/gpu_tool",
|
||||
"nni/tools/jupyter_extension",
|
||||
"nni/tools/nnictl",
|
||||
"nni/tools/trial_tool"
|
||||
],
|
||||
"reportMissingImports": false
|
||||
}
|
||||
|
|
|
@ -13,8 +13,8 @@ logLevel: warning
|
|||
tunerGpuIndices: 0
|
||||
assessor:
|
||||
name: assess
|
||||
advisor:
|
||||
className: Advisor
|
||||
tuner:
|
||||
className: Tuner
|
||||
codeDirectory: .
|
||||
classArgs: {random_seed: 0}
|
||||
trainingService:
|
||||
|
|
|
@ -72,8 +72,8 @@ detailed_canon = {
|
|||
'assessor': {
|
||||
'name': 'assess',
|
||||
},
|
||||
'advisor': {
|
||||
'className': 'Advisor',
|
||||
'tuner': {
|
||||
'className': 'Tuner',
|
||||
'codeDirectory': expand_path('assets'),
|
||||
'classArgs': {'random_seed': 0},
|
||||
},
|
||||
|
|
Загрузка…
Ссылка в новой задаче