From 2fc472477166c75d5bde95ce1ada3357f0bb24fb Mon Sep 17 00:00:00 2001 From: QuanluZhang Date: Tue, 24 May 2022 09:18:45 +0800 Subject: [PATCH] [retiarii] refactor of nas experiment (#4841) --- nni/experiment/config/base.py | 5 + nni/experiment/config/experiment_config.py | 7 +- nni/experiment/experiment.py | 65 ++-- nni/retiarii/execution/api.py | 13 +- nni/retiarii/execution/base.py | 20 +- nni/retiarii/execution/cgo_engine.py | 38 +- nni/retiarii/experiment/__init__.py | 2 + nni/retiarii/experiment/config/__init__.py | 5 + .../experiment/config/engine_config.py | 41 +++ .../experiment/config/experiment_config.py | 60 +++ nni/retiarii/experiment/pytorch.py | 343 +++++------------- nni/retiarii/integration_api.py | 5 +- nni/runtime/msg_dispatcher_base.py | 20 +- .../retiarii_test/cgo_mnasnet/base_mnasnet.py | 1 - test/retiarii_test/cgo_mnasnet/test.py | 8 +- test/ut/retiarii/test_cgo_engine.py | 22 +- test/ut/retiarii/test_engine.py | 4 +- test/ut/sdk/test_assessor.py | 2 +- test/ut/sdk/test_msg_dispatcher.py | 2 +- ts/nni_manager/core/nnimanager.ts | 5 +- 20 files changed, 345 insertions(+), 323 deletions(-) create mode 100644 nni/retiarii/experiment/config/__init__.py create mode 100644 nni/retiarii/experiment/config/engine_config.py create mode 100644 nni/retiarii/experiment/config/experiment_config.py diff --git a/nni/experiment/config/base.py b/nni/experiment/config/base.py index f3d44e063..ab8b6f061 100644 --- a/nni/experiment/config/base.py +++ b/nni/experiment/config/base.py @@ -54,6 +54,11 @@ class ConfigBase: Config objects will remember where they are loaded; therefore relative paths can be resolved smartly. If a config object is created with constructor, the base path will be current working directory. If it is loaded with ``ConfigBase.load(path)``, the base path will be ``path``'s parent. + + .. 
attention:: + + All the classes that inherit ``ConfigBase`` are not allowed to use ``from __future__ import annotations``, + because ``ConfigBase`` uses ``typeguard`` to perform runtime check and it does not support lazy annotations. """ def __init__(self, **kwargs): diff --git a/nni/experiment/config/experiment_config.py b/nni/experiment/config/experiment_config.py index 20216d7c2..64af5e3af 100644 --- a/nni/experiment/config/experiment_config.py +++ b/nni/experiment/config/experiment_config.py @@ -164,10 +164,11 @@ class ExperimentConfig(ConfigBase): # currently I have only seen one issue of this kind #Path(self.experiment_working_directory).mkdir(parents=True, exist_ok=True) - utils.validate_gpu_indices(self.tuner_gpu_indices) + if type(self).__name__ != 'RetiariiExeConfig': + utils.validate_gpu_indices(self.tuner_gpu_indices) - if self.tuner is None: - raise ValueError('ExperimentConfig: tuner must be set') + if self.tuner is None: + raise ValueError('ExperimentConfig: tuner must be set') def _load_search_space_file(search_space_path): # FIXME diff --git a/nni/experiment/experiment.py b/nni/experiment/experiment.py index 44a5263a1..ee0fa67b8 100644 --- a/nni/experiment/experiment.py +++ b/nni/experiment/experiment.py @@ -84,20 +84,9 @@ class Experiment: else: self.config = config_or_platform - def start(self, port: int = 8080, debug: bool = False, run_mode: RunMode = RunMode.Background) -> None: - """ - Start the experiment in background. - - This method will raise exception on failure. - If it returns, the experiment should have been successfully started. - - Parameters - ---------- - port - The port of web UI. - debug - Whether to start in debug mode. 
- """ + def _start_impl(self, port: int, debug: bool, run_mode: RunMode, + tuner_command_channel: str | None, + tags: list[str] = []) -> ExperimentConfig: assert self.config is not None if run_mode is not RunMode.Detach: atexit.register(self.stop) @@ -111,7 +100,8 @@ class Experiment: log_level = 'debug' if (debug or config.log_level == 'trace') else config.log_level start_experiment_logging(self.id, log_file, cast(str, log_level)) - self._proc = launcher.start_experiment(self._action, self.id, config, port, debug, run_mode, self.url_prefix) + self._proc = launcher.start_experiment(self._action, self.id, config, port, debug, run_mode, + self.url_prefix, tuner_command_channel, tags) assert self._proc is not None self.port = port # port will be None if start up failed @@ -124,12 +114,27 @@ class Experiment: ips = [f'http://{ip}:{port}' for ip in ips if ip] msg = 'Web portal URLs: ${CYAN}' + ' '.join(ips) _logger.info(msg) + return config - def stop(self) -> None: + def start(self, port: int = 8080, debug: bool = False, run_mode: RunMode = RunMode.Background) -> None: """ - Stop the experiment. + Start the experiment in background. + + This method will raise exception on failure. + If it returns, the experiment should have been successfully started. + + Parameters + ---------- + port + The port of web UI. + debug + Whether to start in debug mode. + run_mode + Running the experiment in foreground or background """ - _logger.info('Stopping experiment, please wait...') + self._start_impl(port, debug, run_mode, None, []) + + def _stop_impl(self) -> None: atexit.unregister(self.stop) stop_experiment_logging(self.id) @@ -144,8 +149,24 @@ class Experiment: self.id = None # type: ignore self.port = None self._proc = None + + def stop(self) -> None: + """ + Stop the experiment. 
+ """ + _logger.info('Stopping experiment, please wait...') + self._stop_impl() _logger.info('Experiment stopped') + def _wait_completion(self) -> bool: + while True: + status = self.get_status() + if status == 'DONE' or status == 'STOPPED': + return True + if status == 'ERROR': + return False + time.sleep(10) + def run(self, port: int = 8080, wait_completion: bool = True, debug: bool = False) -> bool | None: """ Run the experiment. @@ -159,13 +180,7 @@ class Experiment: self.start(port, debug) if wait_completion: try: - while True: - time.sleep(10) - status = self.get_status() - if status == 'DONE' or status == 'STOPPED': - return True - if status == 'ERROR': - return False + self._wait_completion() except KeyboardInterrupt: _logger.warning('KeyboardInterrupt detected') self.stop() diff --git a/nni/retiarii/execution/api.py b/nni/retiarii/execution/api.py index d0028e5e7..01c85f81e 100644 --- a/nni/retiarii/execution/api.py +++ b/nni/retiarii/execution/api.py @@ -2,6 +2,7 @@ # Licensed under the MIT license. import time +import warnings from typing import Iterable from ..graph import Model, ModelStatus @@ -18,12 +19,12 @@ __all__ = ['get_execution_engine', 'get_and_register_default_listener', def set_execution_engine(engine: AbstractExecutionEngine) -> None: global _execution_engine - if _execution_engine is None: - _execution_engine = engine - else: - raise RuntimeError('Execution engine is already set. ' - 'You should avoid instantiating RetiariiExperiment twice in one process. ' - 'If you are running in a Jupyter notebook, please restart the kernel.') + if _execution_engine is not None: + warnings.warn('Execution engine is already set. ' + 'You should avoid instantiating RetiariiExperiment twice in one process. 
' + 'If you are running in a Jupyter notebook, please restart the kernel.', + RuntimeWarning) + _execution_engine = engine def get_execution_engine() -> AbstractExecutionEngine: diff --git a/nni/retiarii/execution/base.py b/nni/retiarii/execution/base.py index d8cda6cc8..a45299065 100644 --- a/nni/retiarii/execution/base.py +++ b/nni/retiarii/execution/base.py @@ -1,12 +1,16 @@ # Copyright (c) Microsoft Corporation. # Licensed under the MIT license. +from __future__ import annotations + import logging import os import random import string from typing import Any, Dict, Iterable, List +from nni.experiment import rest + from .interface import AbstractExecutionEngine, AbstractGraphListener from .utils import get_mutation_summary from .. import codegen, utils @@ -54,12 +58,22 @@ class BaseExecutionEngine(AbstractExecutionEngine): Resource management is implemented in this class. """ - def __init__(self) -> None: + def __init__(self, rest_port: int | None = None, rest_url_prefix: str | None = None) -> None: """ Upon initialization, advisor callbacks need to be registered. Advisor will call the callbacks when the corresponding event has been triggered. Base execution engine will get those callbacks and broadcast them to graph listener. 
+ + Parameters + ---------- + rest_port + The port of the experiment's rest server + rest_url_prefix + The url prefix of the experiment's rest entry """ + self.port = rest_port + self.url_prefix = rest_url_prefix + self._listeners: List[AbstractGraphListener] = [] # register advisor callbacks @@ -123,8 +137,8 @@ class BaseExecutionEngine(AbstractExecutionEngine): return self.resources def budget_exhausted(self) -> bool: - advisor = get_advisor() - return advisor.stopping + resp = rest.get(self.port, '/check-status', self.url_prefix) + return resp['status'] == 'DONE' @classmethod def pack_model_data(cls, model: Model) -> Any: diff --git a/nni/retiarii/execution/cgo_engine.py b/nni/retiarii/execution/cgo_engine.py index 4ba11987a..b959c54a4 100644 --- a/nni/retiarii/execution/cgo_engine.py +++ b/nni/retiarii/execution/cgo_engine.py @@ -1,16 +1,19 @@ # Copyright (c) Microsoft Corporation. # Licensed under the MIT license. +from __future__ import annotations + import logging import os import random import string import time import threading -from typing import Iterable, List, Dict, Tuple +from typing import Iterable, List, Dict, Tuple, cast from dataclasses import dataclass from nni.common.device import GPUDevice, Device +from nni.experiment.config.training_services import RemoteConfig from .interface import AbstractExecutionEngine, AbstractGraphListener, WorkerInfo from .. import codegen, utils from ..graph import Model, ModelStatus, MetricData, Node @@ -31,7 +34,6 @@ class TrialSubmission: placement: Dict[Node, Device] grouped_models: List[Model] - class CGOExecutionEngine(AbstractExecutionEngine): """ The execution engine with Cross-Graph Optimization (CGO). @@ -41,24 +43,35 @@ class CGOExecutionEngine(AbstractExecutionEngine): Parameters ---------- - devices : List[Device] - Available devices for execution. - max_concurrency : int + training_service + The remote training service config. + max_concurrency The maximum number of trials to run concurrently. 
- batch_waiting_time: int + batch_waiting_time Seconds to wait for each batch of trial submission. The trials within one batch could apply cross-graph optimization. + rest_port + The port of the experiment's rest server + rest_url_prefix + The url prefix of the experiment's rest entry """ - def __init__(self, devices: List[Device] = None, + def __init__(self, training_service: RemoteConfig, max_concurrency: int = None, batch_waiting_time: int = 60, + rest_port: int | None = None, + rest_url_prefix: str | None = None ) -> None: + self.port = rest_port + self.url_prefix = rest_url_prefix + self._listeners: List[AbstractGraphListener] = [] self._running_models: Dict[int, Model] = dict() self.logical_plan_counter = 0 self.available_devices: List[Device] = [] self.max_concurrency: int = max_concurrency + + devices = self._construct_devices(training_service) for device in devices: self.available_devices.append(device) self.all_devices = self.available_devices.copy() @@ -88,6 +101,17 @@ class CGOExecutionEngine(AbstractExecutionEngine): self._consumer_thread = threading.Thread(target=self._consume_models) self._consumer_thread.start() + def _construct_devices(self, training_service): + devices = [] + if hasattr(training_service, 'machine_list'): + for machine in cast(RemoteConfig, training_service).machine_list: + assert machine.gpu_indices is not None, \ + 'gpu_indices must be set in RemoteMachineConfig for CGO execution engine' + assert isinstance(machine.gpu_indices, list), 'gpu_indices must be a list' + for gpu_idx in machine.gpu_indices: + devices.append(GPUDevice(machine.host, gpu_idx)) + return devices + def join(self): self._stopped = True self._consumer_thread.join() diff --git a/nni/retiarii/experiment/__init__.py b/nni/retiarii/experiment/__init__.py index e69de29bb..0eca6426d 100644 --- a/nni/retiarii/experiment/__init__.py +++ b/nni/retiarii/experiment/__init__.py @@ -0,0 +1,2 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT license. 
\ No newline at end of file diff --git a/nni/retiarii/experiment/config/__init__.py b/nni/retiarii/experiment/config/__init__.py new file mode 100644 index 000000000..38bc42747 --- /dev/null +++ b/nni/retiarii/experiment/config/__init__.py @@ -0,0 +1,5 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT license. + +from .experiment_config import * +from .engine_config import * \ No newline at end of file diff --git a/nni/retiarii/experiment/config/engine_config.py b/nni/retiarii/experiment/config/engine_config.py new file mode 100644 index 000000000..214714762 --- /dev/null +++ b/nni/retiarii/experiment/config/engine_config.py @@ -0,0 +1,41 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT license. + +from dataclasses import dataclass +from typing import Optional, List + +from nni.experiment.config.base import ConfigBase + +__all__ = ['ExecutionEngineConfig', 'BaseEngineConfig', 'OneshotEngineConfig', + 'PyEngineConfig', 'CgoEngineConfig', 'BenchmarkEngineConfig'] + +@dataclass(init=False) +class ExecutionEngineConfig(ConfigBase): + name: str + +@dataclass(init=False) +class PyEngineConfig(ExecutionEngineConfig): + name: str = 'py' + +@dataclass(init=False) +class OneshotEngineConfig(ExecutionEngineConfig): + name: str = 'oneshot' + +@dataclass(init=False) +class BaseEngineConfig(ExecutionEngineConfig): + name: str = 'base' + # input used in GraphConverterWithShape. Currently support shape tuple only. + dummy_input: Optional[List[int]] = None + +@dataclass(init=False) +class CgoEngineConfig(ExecutionEngineConfig): + name: str = 'cgo' + max_concurrency_cgo: Optional[int] = None + batch_waiting_time: Optional[int] = None + # input used in GraphConverterWithShape. Currently support shape tuple only. 
+ dummy_input: Optional[List[int]] = None + +@dataclass(init=False) +class BenchmarkEngineConfig(ExecutionEngineConfig): + name: str = 'benchmark' + benchmark: Optional[str] = None \ No newline at end of file diff --git a/nni/retiarii/experiment/config/experiment_config.py b/nni/retiarii/experiment/config/experiment_config.py new file mode 100644 index 000000000..18869e90f --- /dev/null +++ b/nni/retiarii/experiment/config/experiment_config.py @@ -0,0 +1,60 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT license. + +import os +from dataclasses import dataclass +from typing import Any, Union + +from nni.experiment.config import utils, ExperimentConfig + +from .engine_config import ExecutionEngineConfig + +__all__ = ['RetiariiExeConfig'] + +def execution_engine_config_factory(engine_name): + # FIXME: may move this function to experiment utils in future + cls = _get_ee_config_class(engine_name) + if cls is None: + raise ValueError(f'Invalid execution engine name: {engine_name}') + return cls() + +def _get_ee_config_class(engine_name): + for cls in ExecutionEngineConfig.__subclasses__(): + if cls.name == engine_name: + return cls + return None + +@dataclass(init=False) +class RetiariiExeConfig(ExperimentConfig): + # FIXME: refactor this class to inherit from a new common base class with HPO config + search_space: Any = '' + trial_code_directory: utils.PathLike = '.' + trial_command: str = '_reserved' + # new config field for NAS + execution_engine: Union[str, ExecutionEngineConfig] + + def __init__(self, training_service_platform: Union[str, None] = None, + execution_engine: Union[str, ExecutionEngineConfig] = 'py', + **kwargs): + super().__init__(training_service_platform, **kwargs) + self.execution_engine = execution_engine + + def _canonicalize(self, _parents): + msg = '{} is not supposed to be set in Retiarii experiment by users, your config is {}.' 
+ if self.search_space != '': + raise ValueError(msg.format('search_space', self.search_space)) + # TODO: maybe we should also allow users to specify trial_code_directory + if str(self.trial_code_directory) != '.' and not os.path.isabs(self.trial_code_directory): + raise ValueError(msg.format('trial_code_directory', self.trial_code_directory)) + if self.trial_command != '_reserved' and \ + not self.trial_command.startswith('python3 -m nni.retiarii.trial_entry '): + raise ValueError(msg.format('trial_command', self.trial_command)) + + if isinstance(self.execution_engine, str): + self.execution_engine = execution_engine_config_factory(self.execution_engine) + if self.execution_engine.name in ('py', 'base', 'cgo'): + # TODO: replace python3 with more elegant approach + # maybe use sys.executable rendered in trial side (e.g., trial_runner) + self.trial_command = 'python3 -m nni.retiarii.trial_entry ' + self.execution_engine.name + + super()._canonicalize([self]) diff --git a/nni/retiarii/experiment/pytorch.py b/nni/retiarii/experiment/pytorch.py index dcc2712ce..2f81d781b 100644 --- a/nni/retiarii/experiment/pytorch.py +++ b/nni/retiarii/experiment/pytorch.py @@ -1,32 +1,25 @@ # Copyright (c) Microsoft Corporation. # Licensed under the MIT license. 
-import atexit +from __future__ import annotations + import logging -import os -import socket -import time + import warnings -from dataclasses import dataclass -from pathlib import Path -from subprocess import Popen from threading import Thread -from typing import Any, List, Optional, Union, cast +from typing import Any, List, Union, cast import colorama -import psutil + import torch import torch.nn as nn -import nni.runtime.log -from nni.common.device import GPUDevice -from nni.experiment import Experiment, RunMode, launcher, management, rest -from nni.experiment.config import utils -from nni.experiment.config.base import ConfigBase -from nni.experiment.config.training_service import TrainingServiceConfig +from nni.experiment import Experiment, RunMode from nni.experiment.config.training_services import RemoteConfig -from nni.runtime.tuner_command_channel import TunerCommandChannel -from nni.tools.nnictl.command_utils import kill_command +from .config import ( + RetiariiExeConfig, OneshotEngineConfig, BaseEngineConfig, + PyEngineConfig, CgoEngineConfig, BenchmarkEngineConfig +) from ..codegen import model_to_pytorch_script from ..converter import convert_to_graph from ..converter.graph_gen import GraphConverterWithShape @@ -46,79 +39,7 @@ from ..strategy.utils import dry_run_for_formatted_search_space _logger = logging.getLogger(__name__) -__all__ = ['RetiariiExeConfig', 'RetiariiExperiment'] - - -@dataclass(init=False) -class RetiariiExeConfig(ConfigBase): - experiment_name: Optional[str] = None - search_space: Any = '' # TODO: remove - trial_command: str = '_reserved' - trial_code_directory: utils.PathLike = '.' 
- trial_concurrency: int - trial_gpu_number: int = 0 - devices: Optional[List[Union[str, GPUDevice]]] = None - max_experiment_duration: Optional[str] = None - max_trial_number: Optional[int] = None - max_concurrency_cgo: Optional[int] = None - batch_waiting_time: Optional[int] = None - nni_manager_ip: Optional[str] = None - debug: bool = False - log_level: str = 'info' - experiment_working_directory: utils.PathLike = '~/nni-experiments' - # remove configuration of tuner/assessor/advisor - training_service: TrainingServiceConfig - execution_engine: str = 'py' - - # input used in GraphConverterWithShape. Currently support shape tuple only. - dummy_input: Optional[List[int]] = None - - # input used for benchmark engine. - benchmark: Optional[str] = None - - def __init__(self, training_service_platform: Optional[str] = None, **kwargs): - super().__init__(**kwargs) - if training_service_platform is not None: - assert 'training_service' not in kwargs - self.training_service = utils.training_service_config_factory(platform=training_service_platform) - self.__dict__['trial_command'] = 'python3 -m nni.retiarii.trial_entry py' - - def __setattr__(self, key, value): - fixed_attrs = {'search_space': '', - 'trial_command': '_reserved'} - if key in fixed_attrs and fixed_attrs[key] != value: - raise AttributeError(f'{key} is not supposed to be set in Retiarii mode by users!') - # 'trial_code_directory' is handled differently because the path will be converted to absolute path by us - if key == 'trial_code_directory' and not (str(value) == '.' or os.path.isabs(value)): - raise AttributeError(f'{key} is not supposed to be set in Retiarii mode by users!') - if key == 'execution_engine': - assert value in ['base', 'py', 'cgo', 'benchmark', 'oneshot'], f'The specified execution engine "{value}" is not supported.' 
- self.__dict__['trial_command'] = 'python3 -m nni.retiarii.trial_entry ' + value - self.__dict__[key] = value - - def validate(self, initialized_tuner: bool = False) -> None: - super().validate() - - @property - def _canonical_rules(self): - return _canonical_rules - - @property - def _validation_rules(self): - return _validation_rules - - -_canonical_rules = { -} - -_validation_rules = { - 'trial_code_directory': lambda value: (Path(value).is_dir(), f'"{value}" does not exist or is not directory'), - 'trial_concurrency': lambda value: value > 0, - 'trial_gpu_number': lambda value: value >= 0, - 'max_trial_number': lambda value: value > 0, - 'log_level': lambda value: value in ["trace", "debug", "info", "warning", "error", "fatal"], - 'training_service': lambda value: (type(value) is not TrainingServiceConfig, 'cannot be abstract base class') -} +__all__ = ['RetiariiExperiment'] def preprocess_model(base_model, evaluator, applied_mutators, full_ir=True, dummy_input=None, oneshot=False): @@ -252,9 +173,14 @@ class RetiariiExperiment(Experiment): ... final_model = Net() """ - def __init__(self, base_model: nn.Module, evaluator: Union[BaseOneShotTrainer, Evaluator] = cast(Evaluator, None), - applied_mutators: List[Mutator] = cast(List[Mutator], None), strategy: BaseStrategy = cast(BaseStrategy, None), + def __init__(self, base_model: nn.Module, + evaluator: Union[BaseOneShotTrainer, Evaluator] = cast(Evaluator, None), + applied_mutators: List[Mutator] = cast(List[Mutator], None), + strategy: BaseStrategy = cast(BaseStrategy, None), trainer: BaseOneShotTrainer = cast(BaseOneShotTrainer, None)): + super().__init__(None) + self.config: RetiariiExeConfig = cast(RetiariiExeConfig, None) + if trainer is not None: warnings.warn('Usage of `trainer` in RetiariiExperiment is deprecated and will be removed soon. 
' 'Please consider specifying it as a positional argument, or use `evaluator`.', DeprecationWarning) @@ -263,25 +189,13 @@ class RetiariiExperiment(Experiment): if evaluator is None: raise ValueError('Evaluator should not be none.') - # TODO: The current design of init interface of Retiarii experiment needs to be reviewed. - self.config: RetiariiExeConfig = cast(RetiariiExeConfig, None) - self.port: Optional[int] = None - self.base_model = base_model self.evaluator: Union[Evaluator, BaseOneShotTrainer] = evaluator self.applied_mutators = applied_mutators self.strategy = strategy - from nni.retiarii.oneshot.pytorch.strategy import OneShotStrategy - if not isinstance(strategy, OneShotStrategy): - # FIXME: Dispatcher should not be created this early. - self._dispatcher = RetiariiAdvisor('_placeholder_') - else: - self._dispatcher = cast(RetiariiAdvisor, None) - self._dispatcher_thread: Optional[Thread] = None - self._proc: Optional[Popen] = None - - self.url_prefix = None + self._dispatcher = None + self._dispatcher_thread = None # check for sanity if not is_model_wrapped(base_model): @@ -290,11 +204,12 @@ class RetiariiExperiment(Experiment): 'but it may cause inconsistent behavior compared to the time when you add it.' 
+ colorama.Style.RESET_ALL, RuntimeWarning) - def _start_strategy(self): + def _run_strategy(self, config: RetiariiExeConfig): base_model_ir, self.applied_mutators = preprocess_model( self.base_model, self.evaluator, self.applied_mutators, - full_ir=self.config.execution_engine not in ['py', 'benchmark'], - dummy_input=self.config.dummy_input + full_ir=not isinstance(config.execution_engine, (PyEngineConfig, BenchmarkEngineConfig)), + dummy_input=config.execution_engine.dummy_input + if isinstance(config.execution_engine, (BaseEngineConfig, CgoEngineConfig)) else None ) _logger.info('Start strategy...') @@ -303,102 +218,49 @@ class RetiariiExperiment(Experiment): self.strategy.run(base_model_ir, self.applied_mutators) _logger.info('Strategy exit') # TODO: find out a proper way to show no more trial message on WebUI - # self._dispatcher.mark_experiment_as_ending() - def start(self, port: int = 8080, debug: bool = False) -> None: - """ - Start the experiment in background. - This method will raise exception on failure. - If it returns, the experiment should have been successfully started. - Parameters - ---------- - port - The port of web UI. - debug - Whether to start in debug mode. 
- """ - atexit.register(self.stop) - - self.config = self.config.canonical_copy() - - # we will probably need a execution engine factory to make this clean and elegant - if self.config.execution_engine == 'base': + def _create_execution_engine(self, config: RetiariiExeConfig) -> None: + #TODO: we will probably need a execution engine factory to make this clean and elegant + if isinstance(config.execution_engine, BaseEngineConfig): from ..execution.base import BaseExecutionEngine - engine = BaseExecutionEngine() - elif self.config.execution_engine == 'cgo': + engine = BaseExecutionEngine(self.port, self.url_prefix) + elif isinstance(config.execution_engine, CgoEngineConfig): from ..execution.cgo_engine import CGOExecutionEngine - assert self.config.training_service.platform == 'remote', \ + assert not isinstance(config.training_service, list) \ + and config.training_service.platform == 'remote', \ "CGO execution engine currently only supports remote training service" - assert self.config.batch_waiting_time is not None and self.config.max_concurrency_cgo is not None - devices = self._construct_devices() - engine = CGOExecutionEngine(devices, - max_concurrency=self.config.max_concurrency_cgo, - batch_waiting_time=self.config.batch_waiting_time) - elif self.config.execution_engine == 'py': + assert config.execution_engine.batch_waiting_time is not None \ + and config.execution_engine.max_concurrency_cgo is not None + engine = CGOExecutionEngine(cast(RemoteConfig, config.training_service), + max_concurrency=config.execution_engine.max_concurrency_cgo, + batch_waiting_time=config.execution_engine.batch_waiting_time, + rest_port=self.port, + rest_url_prefix=self.url_prefix) + elif isinstance(config.execution_engine, PyEngineConfig): from ..execution.python import PurePythonExecutionEngine - engine = PurePythonExecutionEngine() - elif self.config.execution_engine == 'benchmark': + engine = PurePythonExecutionEngine(self.port, self.url_prefix) + elif 
isinstance(config.execution_engine, BenchmarkEngineConfig): from ..execution.benchmark import BenchmarkExecutionEngine - assert self.config.benchmark is not None, '"benchmark" must be set when benchmark execution engine is used.' - engine = BenchmarkExecutionEngine(self.config.benchmark) + assert config.execution_engine.benchmark is not None, \ + '"benchmark" must be set when benchmark execution engine is used.' + engine = BenchmarkExecutionEngine(config.execution_engine.benchmark) else: - raise ValueError(f'Unsupported engine type: {self.config.execution_engine}') + raise ValueError(f'Unsupported engine type: {config.execution_engine}') set_execution_engine(engine) - self.id = management.generate_experiment_id() + def start(self, *args, **kwargs) -> None: + """ + By design, the only different between `start` and `run` is that `start` is asynchronous, + while `run` waits the experiment to complete. RetiariiExperiment always waits the experiment + to complete as strategy runs in foreground. 
+ """ + raise NotImplementedError('RetiariiExperiment is not supposed to provide `start` method') - log_file = Path(self.config.experiment_working_directory, self.id, 'log', 'experiment.log') - log_file.parent.mkdir(parents=True, exist_ok=True) - log_level = 'debug' if (debug or self.config.log_level == 'trace') else self.config.log_level - nni.runtime.log.start_experiment_logging(self.id, log_file, cast(str, log_level)) - - ws_url = f'ws://localhost:{port}/tuner' - self._proc = launcher.start_experiment('create', self.id, self.config, port, debug, # type: ignore - RunMode.Background, None, ws_url, ['retiarii']) - assert self._proc is not None - - self.port = port # port will be None if start up failed - - # dispatcher must be launched after pipe initialized - # the logic to launch dispatcher in background should be refactored into dispatcher api - self._dispatcher = self._create_dispatcher() - if self._dispatcher is not None: - self._dispatcher._channel = TunerCommandChannel(ws_url) - self._dispatcher_thread = Thread(target=self._dispatcher.run) - self._dispatcher_thread.start() - - ips = [self.config.nni_manager_ip] - for interfaces in psutil.net_if_addrs().values(): - for interface in interfaces: - if interface.family == socket.AF_INET: - ips.append(interface.address) - ips = [f'http://{ip}:{port}' for ip in ips if ip] - msg = 'Web UI URLs: ' + colorama.Fore.CYAN + ' '.join(ips) + colorama.Style.RESET_ALL - _logger.info(msg) - - exp_status_checker = Thread(target=self._check_exp_status) - exp_status_checker.start() - self._start_strategy() - # TODO: the experiment should be completed, when strategy exits and there is no running job - _logger.info('Waiting for experiment to become DONE (you can ctrl+c if there is no running trial jobs)...') - exp_status_checker.join() - - def _construct_devices(self): - devices = [] - if hasattr(self.config.training_service, 'machine_list'): - for machine in cast(RemoteConfig, self.config.training_service).machine_list: - assert 
machine.gpu_indices is not None, \ - 'gpu_indices must be set in RemoteMachineConfig for CGO execution engine' - assert isinstance(machine.gpu_indices, list), 'gpu_indices must be a list' - for gpu_idx in machine.gpu_indices: - devices.append(GPUDevice(machine.host, gpu_idx)) - return devices - - def _create_dispatcher(self): - return self._dispatcher - - def run(self, config: Optional[RetiariiExeConfig] = None, port: int = 8080, debug: bool = False) -> None: + def run(self, + config: RetiariiExeConfig | None = None, + port: int = 8080, + debug: bool = False) -> None: """ Run the experiment. This function will block until experiment finish or error. @@ -410,75 +272,47 @@ class RetiariiExperiment(Experiment): # 'In case you want to stick to the old implementation, ' # 'please consider using ``trainer.fit()`` instead of experiment.', DeprecationWarning) self.evaluator.fit() + return if config is None: warnings.warn('config = None is deprecate in future. If you are running a one-shot experiment, ' 'please consider creating a config and set execution engine to `oneshot`.', DeprecationWarning) - config = RetiariiExeConfig() - config.execution_engine = 'oneshot' + self.config = RetiariiExeConfig() + self.config.execution_engine = OneshotEngineConfig() + else: + self.config = config - if config.execution_engine == 'oneshot': + if isinstance(self.config.execution_engine, OneshotEngineConfig) \ + or (isinstance(self.config.execution_engine, str) and self.config.execution_engine == 'oneshot'): + # this is hacky, will be refactored when oneshot can run on training services base_model_ir, self.applied_mutators = preprocess_model(self.base_model, self.evaluator, self.applied_mutators, oneshot=True) self.strategy.run(base_model_ir, self.applied_mutators) else: - assert config is not None, 'You are using classic search mode, config cannot be None!' - self.config = config - self.start(port, debug) - - def _check_exp_status(self) -> bool: - """ - Run the experiment. 
- This function will block until experiment finish or error. - Return `True` when experiment done; or return `False` when experiment failed. - """ - assert self._proc is not None - try: - while True: - time.sleep(10) - # this if is to deal with the situation that - # nnimanager is cleaned up by ctrl+c first - if self._proc.poll() is None: - status = self.get_status() - else: - return False - if status == 'DONE' or status == 'STOPPED': - return True - if status == 'ERROR': - return False - except KeyboardInterrupt: - _logger.warning('KeyboardInterrupt detected') - finally: - self.stop() - raise RuntimeError('Check experiment status failed.') + ws_url = f'ws://localhost:{port}/tuner' + canonicalized_config = self._start_impl(port, debug, RunMode.Background, ws_url, ['retiarii']) + canonicalized_config = cast(RetiariiExeConfig, canonicalized_config) + self._dispatcher = RetiariiAdvisor(ws_url) + self._dispatcher_thread = Thread(target=self._dispatcher.run, daemon=True) + self._dispatcher_thread.start() + # FIXME: engine cannot be created twice + self._create_execution_engine(canonicalized_config) + try: + self._run_strategy(canonicalized_config) + # FIXME: move this logic to strategy with a new API provided by execution engine + self._wait_completion() + except KeyboardInterrupt: + _logger.warning('KeyboardInterrupt detected') + self.stop() + _logger.info('Search process is done, the experiment is still alive, `stop()` can terminate the experiment.') def stop(self) -> None: """ Stop background experiment. 
""" _logger.info('Stopping experiment, please wait...') - atexit.unregister(self.stop) - - # stop strategy first - if self._dispatcher_thread is not None: - self._dispatcher.stopping = True - self._dispatcher_thread.join(timeout=1) - - if self.id is not None: - nni.runtime.log.stop_experiment_logging(self.id) - if self._proc is not None: - try: - # this if is to deal with the situation that - # nnimanager is cleaned up by ctrl+c first - if self._proc.poll() is None: - rest.delete(self.port, '/experiment') - except Exception as e: - _logger.exception(e) - _logger.warning('Cannot gracefully stop experiment, killing NNI process...') - kill_command(self._proc.pid) - - self.id = cast(str, None) - self.port = cast(int, None) - self._proc = None + self._stop_impl() + if self._dispatcher_thread: + self._dispatcher_thread.join() self._dispatcher = cast(RetiariiAdvisor, None) self._dispatcher_thread = None _logger.info('Experiment stopped') @@ -502,8 +336,11 @@ class RetiariiExperiment(Experiment): If ``code``, the python code of model will be returned. If ``dict``, the mutation history will be returned. """ + # TODO: the base class may also need this method if formatter == 'code': - assert self.config.execution_engine != 'py', 'You should use `dict` formatter when using Python execution engine.' + config = self.config.canonical_copy() + assert not isinstance(config.execution_engine, PyEngineConfig), \ + 'You should use `dict` formatter when using Python execution engine.' if isinstance(self.evaluator, BaseOneShotTrainer): assert top_k == 1, 'Only support top_k is 1 for now.' 
return self.evaluator.export()
@@ -520,9 +357,3 @@ class RetiariiExperiment(Experiment):
             return [model_to_pytorch_script(model) for model in all_models[:top_k]]
         elif formatter == 'dict':
             return [get_mutation_dict(model) for model in all_models[:top_k]]
-
-    def retrain_model(self, model):
-        """
-        this function retrains the exported model, and test it to output test accuracy
-        """
-        raise NotImplementedError
diff --git a/nni/retiarii/integration_api.py b/nni/retiarii/integration_api.py
index dfc77bdc2..643758ec2 100644
--- a/nni/retiarii/integration_api.py
+++ b/nni/retiarii/integration_api.py
@@ -22,7 +22,10 @@ def get_advisor() -> 'RetiariiAdvisor':
 
 def register_advisor(advisor: 'RetiariiAdvisor'):
     global _advisor
-    assert _advisor is None
+    if _advisor is not None:
+        warnings.warn('Advisor is already set. '
+                      'You should avoid instantiating RetiariiExperiment twice in one process. '
+                      'If you are running in a Jupyter notebook, please restart the kernel.')
     _advisor = advisor
 
 
diff --git a/nni/runtime/msg_dispatcher_base.py b/nni/runtime/msg_dispatcher_base.py
index 83831d812..6873ff240 100644
--- a/nni/runtime/msg_dispatcher_base.py
+++ b/nni/runtime/msg_dispatcher_base.py
@@ -18,8 +18,15 @@ _worker_fast_exit_on_terminate = True
 
 class MsgDispatcherBase(Recoverable):
-    """This is where tuners and assessors are not defined yet.
+    """
+    This is where tuners and assessors are not defined yet.
     Inherits this class to make your own advisor.
+
+    .. note::
+
+        The class inheriting MsgDispatcherBase should be instantiated
+        after nnimanager (rest server) is started, so that the object
+        is ready to use right after its instantiation.
     """
 
     def __init__(self, command_channel_url=None):
@@ -27,6 +34,16 @@ class MsgDispatcherBase(Recoverable):
         if command_channel_url is None:
             command_channel_url = dispatcher_env_vars.NNI_TUNER_COMMAND_CHANNEL
         self._channel = TunerCommandChannel(command_channel_url)
+        # NOTE: `connect()` should be put in __init__. 
First, this `connect()` affects nnimanager's
+        # starting process, without `connect()` nnimanager is blocked in `dispatcher.init()`.
+        # Second, nas experiment uses a thread to execute `run()` of this class, thus, there is
+        # no way to know when the websocket between nnimanager and dispatcher is built. The following
+        # logic may crash if the websocket is not built. One example is updating search space. If updating
+        # search space too soon, as the websocket has not been built, the rest api of updating search
+        # space will timeout.
+        # FIXME: this is making unittest happy
+        if not command_channel_url.startswith('ws://_unittest_'):
+            self._channel.connect()
         self.default_command_queue = Queue()
         self.assessor_command_queue = Queue()
         self.default_worker = threading.Thread(target=self.command_queue_worker, args=(self.default_command_queue,))
@@ -39,7 +56,6 @@ class MsgDispatcherBase(Recoverable):
         """
         _logger.info('Dispatcher started')
 
-        self._channel.connect()
         self.default_worker.start()
         self.assessor_worker.start()
 
diff --git a/test/retiarii_test/cgo_mnasnet/base_mnasnet.py b/test/retiarii_test/cgo_mnasnet/base_mnasnet.py
index 3e76d0bf7..3cbb7f6c0 100644
--- a/test/retiarii_test/cgo_mnasnet/base_mnasnet.py
+++ b/test/retiarii_test/cgo_mnasnet/base_mnasnet.py
@@ -4,7 +4,6 @@ import warnings
 import torch
 import torch.nn as torch_nn
-from torchvision.models.utils import load_state_dict_from_url
 import torch.nn.functional as F
 import sys
 
diff --git a/test/retiarii_test/cgo_mnasnet/test.py b/test/retiarii_test/cgo_mnasnet/test.py
index eac4956f3..651591d51 100644
--- a/test/retiarii_test/cgo_mnasnet/test.py
+++ b/test/retiarii_test/cgo_mnasnet/test.py
@@ -8,7 +8,7 @@ import nni.retiarii.evaluator.pytorch.cgo.evaluator as cgo
 from nni.retiarii import serialize
 from base_mnasnet import MNASNet
 from nni.experiment import RemoteMachineConfig
-from nni.retiarii.experiment.pytorch import RetiariiExperiment, RetiariiExeConfig
+from nni.retiarii.experiment.pytorch import 
RetiariiExperiment, RetiariiExeConfig, CgoEngineConfig from nni.retiarii.strategy import TPEStrategy from torchvision import transforms from torchvision.datasets import CIFAR10 @@ -59,8 +59,6 @@ if __name__ == '__main__': exp_config.max_trial_number = 10 exp_config.trial_gpu_number = 1 exp_config.training_service.reuse_mode = True - exp_config.max_concurrency_cgo = 3 - exp_config.batch_waiting_time = 0 rm_conf = RemoteMachineConfig() rm_conf.host = '127.0.0.1' @@ -73,6 +71,6 @@ if __name__ == '__main__': rm_conf.max_trial_number_per_gpu = 3 exp_config.training_service.machine_list = [rm_conf] - exp_config.execution_engine = 'cgo' + exp_config.execution_engine = CgoEngineConfig(max_concurrency_cgo = 3, batch_waiting_time = 0) - exp.run(exp_config, 8099) \ No newline at end of file + exp.run(exp_config, 8099) diff --git a/test/ut/retiarii/test_cgo_engine.py b/test/ut/retiarii/test_cgo_engine.py index 67dde0938..5a3605eb2 100644 --- a/test/ut/retiarii/test_cgo_engine.py +++ b/test/ut/retiarii/test_cgo_engine.py @@ -9,6 +9,7 @@ from pytorch_lightning.utilities.seed import seed_everything from pathlib import Path import nni +from nni.experiment.config import RemoteConfig, RemoteMachineConfig import nni.runtime.platform.test from nni.runtime.tuner_command_channel import legacy as protocol import json @@ -263,13 +264,14 @@ class CGOEngineTest(unittest.TestCase): opt = DedupInputOptimizer() opt.convert(lp) - advisor = RetiariiAdvisor('ws://_placeholder_') + advisor = RetiariiAdvisor('ws://_unittest_placeholder_') advisor._channel = protocol.LegacyCommandChannel() advisor.default_worker.start() advisor.assessor_worker.start() - available_devices = [GPUDevice("test", 0), GPUDevice("test", 1), GPUDevice("test", 2), GPUDevice("test", 3)] - cgo = CGOExecutionEngine(devices=available_devices, batch_waiting_time=0) + remote = RemoteConfig(machine_list=[]) + remote.machine_list.append(RemoteMachineConfig(host='test', gpu_indices=[0,1,2,3])) + cgo = 
CGOExecutionEngine(training_service=remote, batch_waiting_time=0) phy_models = cgo._assemble(lp) self.assertTrue(len(phy_models) == 1) @@ -286,13 +288,14 @@ class CGOEngineTest(unittest.TestCase): opt = DedupInputOptimizer() opt.convert(lp) - advisor = RetiariiAdvisor('ws://_placeholder_') + advisor = RetiariiAdvisor('ws://_unittest_placeholder_') advisor._channel = protocol.LegacyCommandChannel() advisor.default_worker.start() advisor.assessor_worker.start() - available_devices = [GPUDevice("test", 0), GPUDevice("test", 1)] - cgo = CGOExecutionEngine(devices=available_devices, batch_waiting_time=0) + remote = RemoteConfig(machine_list=[]) + remote.machine_list.append(RemoteMachineConfig(host='test', gpu_indices=[0,1])) + cgo = CGOExecutionEngine(training_service=remote, batch_waiting_time=0) phy_models = cgo._assemble(lp) self.assertTrue(len(phy_models) == 2) @@ -311,13 +314,14 @@ class CGOEngineTest(unittest.TestCase): models = _load_mnist(2) - advisor = RetiariiAdvisor('ws://_placeholder_') + advisor = RetiariiAdvisor('ws://_unittest_placeholder_') advisor._channel = protocol.LegacyCommandChannel() advisor.default_worker.start() advisor.assessor_worker.start() - cgo_engine = CGOExecutionEngine(devices=[GPUDevice("test", 0), GPUDevice("test", 1), - GPUDevice("test", 2), GPUDevice("test", 3)], batch_waiting_time=0) + remote = RemoteConfig(machine_list=[]) + remote.machine_list.append(RemoteMachineConfig(host='test', gpu_indices=[0,1,2,3])) + cgo_engine = CGOExecutionEngine(training_service=remote, batch_waiting_time=0) set_execution_engine(cgo_engine) submit_models(*models) time.sleep(3) diff --git a/test/ut/retiarii/test_engine.py b/test/ut/retiarii/test_engine.py index 8e8f050c1..c8cd760b8 100644 --- a/test/ut/retiarii/test_engine.py +++ b/test/ut/retiarii/test_engine.py @@ -25,7 +25,7 @@ class EngineTest(unittest.TestCase): def test_base_execution_engine(self): nni.retiarii.integration_api._advisor = None nni.retiarii.execution.api._execution_engine = None - 
advisor = RetiariiAdvisor('ws://_placeholder_') + advisor = RetiariiAdvisor('ws://_unittest_placeholder_') advisor._channel = LegacyCommandChannel() advisor.default_worker.start() advisor.assessor_worker.start() @@ -42,7 +42,7 @@ class EngineTest(unittest.TestCase): def test_py_execution_engine(self): nni.retiarii.integration_api._advisor = None nni.retiarii.execution.api._execution_engine = None - advisor = RetiariiAdvisor('ws://_placeholder_') + advisor = RetiariiAdvisor('ws://_unittest_placeholder_') advisor._channel = LegacyCommandChannel() advisor.default_worker.start() advisor.assessor_worker.start() diff --git a/test/ut/sdk/test_assessor.py b/test/ut/sdk/test_assessor.py index 0d5e07802..48c2c0332 100644 --- a/test/ut/sdk/test_assessor.py +++ b/test/ut/sdk/test_assessor.py @@ -57,7 +57,7 @@ class AssessorTestCase(TestCase): _restore_io() assessor = NaiveAssessor() - dispatcher = MsgDispatcher('ws://_placeholder_', None, assessor) + dispatcher = MsgDispatcher('ws://_unittest_placeholder_', None, assessor) dispatcher._channel = LegacyCommandChannel() msg_dispatcher_base._worker_fast_exit_on_terminate = False diff --git a/test/ut/sdk/test_msg_dispatcher.py b/test/ut/sdk/test_msg_dispatcher.py index 356308501..643d4d9b7 100644 --- a/test/ut/sdk/test_msg_dispatcher.py +++ b/test/ut/sdk/test_msg_dispatcher.py @@ -66,7 +66,7 @@ class MsgDispatcherTestCase(TestCase): _restore_io() tuner = NaiveTuner() - dispatcher = MsgDispatcher('ws://_placeholder_', tuner) + dispatcher = MsgDispatcher('ws://_unittest_placeholder_', tuner) dispatcher._channel = LegacyCommandChannel() msg_dispatcher_base._worker_fast_exit_on_terminate = False diff --git a/ts/nni_manager/core/nnimanager.ts b/ts/nni_manager/core/nnimanager.ts index 7ad5a0130..8bce93e86 100644 --- a/ts/nni_manager/core/nnimanager.ts +++ b/ts/nni_manager/core/nnimanager.ts @@ -303,8 +303,11 @@ class NNIManager implements Manager { } this.trainingService.removeTrialJobMetricListener(this.trialJobMetricListener); + // 
NOTE: sending TERMINATE should be outside the if clause,
+        // because when python dispatcher is started before nnimanager
+        // this.dispatcherPid would not have a valid value (i.e., not >0).
+        this.dispatcher.sendCommand(TERMINATE);
         if (this.dispatcherPid > 0) {
-            this.dispatcher.sendCommand(TERMINATE);
             // gracefully terminate tuner and assessor here, wait at most 30 seconds.
             for (let i: number = 0; i < 30; i++) {
                 if (!await isAlive(this.dispatcherPid)) {