[retiarii] refactor of nas experiment (#4841)

This commit is contained in:
QuanluZhang 2022-05-24 09:18:45 +08:00 committed by GitHub
Parent c80bda297e
Commit 2fc4724771
No known key found for this signature
GPG key ID: 4AEE18F83AFDEB23
20 changed files with 345 additions and 323 deletions

View file

@ -54,6 +54,11 @@ class ConfigBase:
Config objects remember where they are loaded from, so relative paths can be resolved sensibly.
If a config object is created with the constructor, the base path is the current working directory.
If it is loaded with ``ConfigBase.load(path)``, the base path is ``path``'s parent.
.. attention::
Classes that inherit ``ConfigBase`` must not use ``from __future__ import annotations``,
because ``ConfigBase`` uses ``typeguard`` to perform runtime checks, and typeguard does not support lazy annotations.
"""
def __init__(self, **kwargs):
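A minimal sketch of the base-path behavior described above; the platform name and file path are illustrative assumptions:

from nni.experiment.config import ExperimentConfig

# Created with the constructor: relative paths resolve against
# the current working directory.
config = ExperimentConfig('local')

# Loaded with ConfigBase.load(path): relative paths resolve against
# the parent directory of the loaded file ('experiments/' here).
config = ExperimentConfig.load('experiments/config.yml')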

View file

@ -164,10 +164,11 @@ class ExperimentConfig(ConfigBase):
# currently I have only seen one issue of this kind
#Path(self.experiment_working_directory).mkdir(parents=True, exist_ok=True)
utils.validate_gpu_indices(self.tuner_gpu_indices)
if type(self).__name__ != 'RetiariiExeConfig':
utils.validate_gpu_indices(self.tuner_gpu_indices)
if self.tuner is None:
raise ValueError('ExperimentConfig: tuner must be set')
if self.tuner is None:
raise ValueError('ExperimentConfig: tuner must be set')
def _load_search_space_file(search_space_path):
# FIXME

View file

@ -84,20 +84,9 @@ class Experiment:
else:
self.config = config_or_platform
def start(self, port: int = 8080, debug: bool = False, run_mode: RunMode = RunMode.Background) -> None:
"""
Start the experiment in background.
This method will raise exception on failure.
If it returns, the experiment should have been successfully started.
Parameters
----------
port
The port of web UI.
debug
Whether to start in debug mode.
"""
def _start_impl(self, port: int, debug: bool, run_mode: RunMode,
tuner_command_channel: str | None,
tags: list[str] = []) -> ExperimentConfig:
assert self.config is not None
if run_mode is not RunMode.Detach:
atexit.register(self.stop)
@ -111,7 +100,8 @@ class Experiment:
log_level = 'debug' if (debug or config.log_level == 'trace') else config.log_level
start_experiment_logging(self.id, log_file, cast(str, log_level))
self._proc = launcher.start_experiment(self._action, self.id, config, port, debug, run_mode, self.url_prefix)
self._proc = launcher.start_experiment(self._action, self.id, config, port, debug, run_mode,
self.url_prefix, tuner_command_channel, tags)
assert self._proc is not None
self.port = port # port will be None if start up failed
@ -124,12 +114,27 @@ class Experiment:
ips = [f'http://{ip}:{port}' for ip in ips if ip]
msg = 'Web portal URLs: ${CYAN}' + ' '.join(ips)
_logger.info(msg)
return config
def stop(self) -> None:
def start(self, port: int = 8080, debug: bool = False, run_mode: RunMode = RunMode.Background) -> None:
"""
Stop the experiment.
Start the experiment in background.
This method will raise exception on failure.
If it returns, the experiment should have been successfully started.
Parameters
----------
port
The port of web UI.
debug
Whether to start in debug mode.
run_mode
Whether to run the experiment in the foreground or in the background.
"""
_logger.info('Stopping experiment, please wait...')
self._start_impl(port, debug, run_mode, None, [])
def _stop_impl(self) -> None:
atexit.unregister(self.stop)
stop_experiment_logging(self.id)
@ -144,8 +149,24 @@ class Experiment:
self.id = None # type: ignore
self.port = None
self._proc = None
def stop(self) -> None:
"""
Stop the experiment.
"""
_logger.info('Stopping experiment, please wait...')
self._stop_impl()
_logger.info('Experiment stopped')
def _wait_completion(self) -> bool:
while True:
status = self.get_status()
if status == 'DONE' or status == 'STOPPED':
return True
if status == 'ERROR':
return False
time.sleep(10)
def run(self, port: int = 8080, wait_completion: bool = True, debug: bool = False) -> bool | None:
"""
Run the experiment.
@ -159,13 +180,7 @@ class Experiment:
self.start(port, debug)
if wait_completion:
try:
while True:
time.sleep(10)
status = self.get_status()
if status == 'DONE' or status == 'STOPPED':
return True
if status == 'ERROR':
return False
self._wait_completion()
except KeyboardInterrupt:
_logger.warning('KeyboardInterrupt detected')
self.stop()
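A hedged usage sketch of the refactored lifecycle above; the platform name is an illustrative assumption:

from nni.experiment import Experiment, RunMode

exp = Experiment('local')  # or pass an ExperimentConfig
# Asynchronous form: returns once the experiment has started.
exp.start(port=8080, debug=False, run_mode=RunMode.Background)
exp.stop()
# Blocking form: polls _wait_completion() every 10 seconds and
# returns True on DONE/STOPPED, False on ERROR.
# exp.run(port=8080, wait_completion=True)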

View file

@ -2,6 +2,7 @@
# Licensed under the MIT license.
import time
import warnings
from typing import Iterable
from ..graph import Model, ModelStatus
@ -18,12 +19,12 @@ __all__ = ['get_execution_engine', 'get_and_register_default_listener',
def set_execution_engine(engine: AbstractExecutionEngine) -> None:
global _execution_engine
if _execution_engine is None:
_execution_engine = engine
else:
raise RuntimeError('Execution engine is already set. '
'You should avoid instantiating RetiariiExperiment twice in one process. '
'If you are running in a Jupyter notebook, please restart the kernel.')
if _execution_engine is not None:
warnings.warn('Execution engine is already set. '
'You should avoid instantiating RetiariiExperiment twice in one process. '
'If you are running in a Jupyter notebook, please restart the kernel.',
RuntimeWarning)
_execution_engine = engine
def get_execution_engine() -> AbstractExecutionEngine:

View file

@ -1,12 +1,16 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.
from __future__ import annotations
import logging
import os
import random
import string
from typing import Any, Dict, Iterable, List
from nni.experiment import rest
from .interface import AbstractExecutionEngine, AbstractGraphListener
from .utils import get_mutation_summary
from .. import codegen, utils
@ -54,12 +58,22 @@ class BaseExecutionEngine(AbstractExecutionEngine):
Resource management is implemented in this class.
"""
def __init__(self) -> None:
def __init__(self, rest_port: int | None = None, rest_url_prefix: str | None = None) -> None:
"""
Upon initialization, advisor callbacks need to be registered.
The advisor calls these callbacks when the corresponding events are triggered.
The base execution engine receives those callbacks and broadcasts them to the graph listeners.
Parameters
----------
rest_port
The port of the experiment's REST server.
rest_url_prefix
The URL prefix of the experiment's REST entry.
"""
self.port = rest_port
self.url_prefix = rest_url_prefix
self._listeners: List[AbstractGraphListener] = []
# register advisor callbacks
@ -123,8 +137,8 @@ class BaseExecutionEngine(AbstractExecutionEngine):
return self.resources
def budget_exhausted(self) -> bool:
advisor = get_advisor()
return advisor.stopping
resp = rest.get(self.port, '/check-status', self.url_prefix)
return resp['status'] == 'DONE'
@classmethod
def pack_model_data(cls, model: Model) -> Any:
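With this change the engine decides budget exhaustion by querying nnimanager over REST instead of reading advisor state; a rough wiring sketch, with the port and prefix values as assumptions:

from nni.retiarii.execution.base import BaseExecutionEngine

# The experiment passes its own REST port/prefix to the engine, so
# budget_exhausted() can poll GET /check-status for a 'DONE' status.
engine = BaseExecutionEngine(rest_port=8080, rest_url_prefix=None)
if engine.budget_exhausted():
    print('experiment reported DONE')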

View file

@ -1,16 +1,19 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.
from __future__ import annotations
import logging
import os
import random
import string
import time
import threading
from typing import Iterable, List, Dict, Tuple
from typing import Iterable, List, Dict, Tuple, cast
from dataclasses import dataclass
from nni.common.device import GPUDevice, Device
from nni.experiment.config.training_services import RemoteConfig
from .interface import AbstractExecutionEngine, AbstractGraphListener, WorkerInfo
from .. import codegen, utils
from ..graph import Model, ModelStatus, MetricData, Node
@ -31,7 +34,6 @@ class TrialSubmission:
placement: Dict[Node, Device]
grouped_models: List[Model]
class CGOExecutionEngine(AbstractExecutionEngine):
"""
The execution engine with Cross-Graph Optimization (CGO).
@ -41,24 +43,35 @@ class CGOExecutionEngine(AbstractExecutionEngine):
Parameters
----------
devices : List[Device]
Available devices for execution.
max_concurrency : int
training_service
The remote training service config.
max_concurrency
The maximum number of trials to run concurrently.
batch_waiting_time: int
batch_waiting_time
Seconds to wait for each batch of trial submissions.
Trials within one batch can apply cross-graph optimization.
rest_port
The port of the experiment's REST server.
rest_url_prefix
The URL prefix of the experiment's REST entry.
"""
def __init__(self, devices: List[Device] = None,
def __init__(self, training_service: RemoteConfig,
max_concurrency: int = None,
batch_waiting_time: int = 60,
rest_port: int | None = None,
rest_url_prefix: str | None = None
) -> None:
self.port = rest_port
self.url_prefix = rest_url_prefix
self._listeners: List[AbstractGraphListener] = []
self._running_models: Dict[int, Model] = dict()
self.logical_plan_counter = 0
self.available_devices: List[Device] = []
self.max_concurrency: int = max_concurrency
devices = self._construct_devices(training_service)
for device in devices:
self.available_devices.append(device)
self.all_devices = self.available_devices.copy()
@ -88,6 +101,17 @@ class CGOExecutionEngine(AbstractExecutionEngine):
self._consumer_thread = threading.Thread(target=self._consume_models)
self._consumer_thread.start()
def _construct_devices(self, training_service):
devices = []
if hasattr(training_service, 'machine_list'):
for machine in cast(RemoteConfig, training_service).machine_list:
assert machine.gpu_indices is not None, \
'gpu_indices must be set in RemoteMachineConfig for CGO execution engine'
assert isinstance(machine.gpu_indices, list), 'gpu_indices must be a list'
for gpu_idx in machine.gpu_indices:
devices.append(GPUDevice(machine.host, gpu_idx))
return devices
def join(self):
self._stopped = True
self._consumer_thread.join()
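A construction sketch mirroring the unit tests further below: the engine now takes the whole RemoteConfig and derives GPUDevice objects itself via _construct_devices(); the host and GPU indices are illustrative:

from nni.experiment.config import RemoteConfig, RemoteMachineConfig
from nni.retiarii.execution.cgo_engine import CGOExecutionEngine

remote = RemoteConfig(machine_list=[])
# gpu_indices must be an explicit list for the CGO engine.
remote.machine_list.append(RemoteMachineConfig(host='127.0.0.1', gpu_indices=[0, 1]))
engine = CGOExecutionEngine(training_service=remote, batch_waiting_time=0)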

View file

@ -0,0 +1,2 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.

View file

@ -0,0 +1,5 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.
from .experiment_config import *
from .engine_config import *

View file

@ -0,0 +1,41 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.
from dataclasses import dataclass
from typing import Optional, List
from nni.experiment.config.base import ConfigBase
__all__ = ['ExecutionEngineConfig', 'BaseEngineConfig', 'OneshotEngineConfig',
'PyEngineConfig', 'CgoEngineConfig', 'BenchmarkEngineConfig']
@dataclass(init=False)
class ExecutionEngineConfig(ConfigBase):
name: str
@dataclass(init=False)
class PyEngineConfig(ExecutionEngineConfig):
name: str = 'py'
@dataclass(init=False)
class OneshotEngineConfig(ExecutionEngineConfig):
name: str = 'oneshot'
@dataclass(init=False)
class BaseEngineConfig(ExecutionEngineConfig):
name: str = 'base'
# Input used in GraphConverterWithShape. Currently supports shape tuples only.
dummy_input: Optional[List[int]] = None
@dataclass(init=False)
class CgoEngineConfig(ExecutionEngineConfig):
name: str = 'cgo'
max_concurrency_cgo: Optional[int] = None
batch_waiting_time: Optional[int] = None
# Input used in GraphConverterWithShape. Currently supports shape tuples only.
dummy_input: Optional[List[int]] = None
@dataclass(init=False)
class BenchmarkEngineConfig(ExecutionEngineConfig):
name: str = 'benchmark'
benchmark: Optional[str] = None
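Illustrative instantiation of the new engine configs: each subclass fixes name, and ConfigBase's keyword constructor fills the optional fields (this mirrors the CGO example script updated below):

from nni.retiarii.experiment.pytorch import CgoEngineConfig

engine = CgoEngineConfig(max_concurrency_cgo=3, batch_waiting_time=0)
assert engine.name == 'cgo'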

View file

@ -0,0 +1,60 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.
import os
from dataclasses import dataclass
from typing import Any, Union
from nni.experiment.config import utils, ExperimentConfig
from .engine_config import ExecutionEngineConfig
__all__ = ['RetiariiExeConfig']
def execution_engine_config_factory(engine_name):
# FIXME: may move this function to experiment utils in future
cls = _get_ee_config_class(engine_name)
if cls is None:
raise ValueError(f'Invalid execution engine name: {engine_name}')
return cls()
def _get_ee_config_class(engine_name):
for cls in ExecutionEngineConfig.__subclasses__():
if cls.name == engine_name:
return cls
return None
@dataclass(init=False)
class RetiariiExeConfig(ExperimentConfig):
# FIXME: refactor this class to inherit from a new common base class with HPO config
search_space: Any = ''
trial_code_directory: utils.PathLike = '.'
trial_command: str = '_reserved'
# new config field for NAS
execution_engine: Union[str, ExecutionEngineConfig]
def __init__(self, training_service_platform: Union[str, None] = None,
execution_engine: Union[str, ExecutionEngineConfig] = 'py',
**kwargs):
super().__init__(training_service_platform, **kwargs)
self.execution_engine = execution_engine
def _canonicalize(self, _parents):
msg = '{} is not supposed to be set in Retiarii experiment by users, your config is {}.'
if self.search_space != '':
raise ValueError(msg.format('search_space', self.search_space))
# TODO: maybe we should also allow users to specify trial_code_directory
if str(self.trial_code_directory) != '.' and not os.path.isabs(self.trial_code_directory):
raise ValueError(msg.format('trial_code_directory', self.trial_code_directory))
if self.trial_command != '_reserved' and \
not self.trial_command.startswith('python3 -m nni.retiarii.trial_entry '):
raise ValueError(msg.format('trial_command', self.trial_command))
if isinstance(self.execution_engine, str):
self.execution_engine = execution_engine_config_factory(self.execution_engine)
if self.execution_engine.name in ('py', 'base', 'cgo'):
# TODO: replace python3 with a more elegant approach,
# e.g., use sys.executable rendered on the trial side (e.g., trial_runner)
self.trial_command = 'python3 -m nni.retiarii.trial_entry ' + self.execution_engine.name
super()._canonicalize([self])
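A sketch of the intended flow, with the platform as an illustrative assumption: a string engine name is accepted at construction and replaced by the matching config object during canonicalization, which also fills in the reserved trial command:

from nni.retiarii.experiment.pytorch import RetiariiExeConfig

config = RetiariiExeConfig('local', execution_engine='py')
config.trial_concurrency = 1
canonical = config.canonical_copy()
# canonical.execution_engine is now a PyEngineConfig, and
# canonical.trial_command is 'python3 -m nni.retiarii.trial_entry py'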

View file

@ -1,32 +1,25 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.
import atexit
from __future__ import annotations
import logging
import os
import socket
import time
import warnings
from dataclasses import dataclass
from pathlib import Path
from subprocess import Popen
from threading import Thread
from typing import Any, List, Optional, Union, cast
from typing import Any, List, Union, cast
import colorama
import psutil
import torch
import torch.nn as nn
import nni.runtime.log
from nni.common.device import GPUDevice
from nni.experiment import Experiment, RunMode, launcher, management, rest
from nni.experiment.config import utils
from nni.experiment.config.base import ConfigBase
from nni.experiment.config.training_service import TrainingServiceConfig
from nni.experiment import Experiment, RunMode
from nni.experiment.config.training_services import RemoteConfig
from nni.runtime.tuner_command_channel import TunerCommandChannel
from nni.tools.nnictl.command_utils import kill_command
from .config import (
RetiariiExeConfig, OneshotEngineConfig, BaseEngineConfig,
PyEngineConfig, CgoEngineConfig, BenchmarkEngineConfig
)
from ..codegen import model_to_pytorch_script
from ..converter import convert_to_graph
from ..converter.graph_gen import GraphConverterWithShape
@ -46,79 +39,7 @@ from ..strategy.utils import dry_run_for_formatted_search_space
_logger = logging.getLogger(__name__)
__all__ = ['RetiariiExeConfig', 'RetiariiExperiment']
@dataclass(init=False)
class RetiariiExeConfig(ConfigBase):
experiment_name: Optional[str] = None
search_space: Any = '' # TODO: remove
trial_command: str = '_reserved'
trial_code_directory: utils.PathLike = '.'
trial_concurrency: int
trial_gpu_number: int = 0
devices: Optional[List[Union[str, GPUDevice]]] = None
max_experiment_duration: Optional[str] = None
max_trial_number: Optional[int] = None
max_concurrency_cgo: Optional[int] = None
batch_waiting_time: Optional[int] = None
nni_manager_ip: Optional[str] = None
debug: bool = False
log_level: str = 'info'
experiment_working_directory: utils.PathLike = '~/nni-experiments'
# remove configuration of tuner/assessor/advisor
training_service: TrainingServiceConfig
execution_engine: str = 'py'
# input used in GraphConverterWithShape. Currently support shape tuple only.
dummy_input: Optional[List[int]] = None
# input used for benchmark engine.
benchmark: Optional[str] = None
def __init__(self, training_service_platform: Optional[str] = None, **kwargs):
super().__init__(**kwargs)
if training_service_platform is not None:
assert 'training_service' not in kwargs
self.training_service = utils.training_service_config_factory(platform=training_service_platform)
self.__dict__['trial_command'] = 'python3 -m nni.retiarii.trial_entry py'
def __setattr__(self, key, value):
fixed_attrs = {'search_space': '',
'trial_command': '_reserved'}
if key in fixed_attrs and fixed_attrs[key] != value:
raise AttributeError(f'{key} is not supposed to be set in Retiarii mode by users!')
# 'trial_code_directory' is handled differently because the path will be converted to absolute path by us
if key == 'trial_code_directory' and not (str(value) == '.' or os.path.isabs(value)):
raise AttributeError(f'{key} is not supposed to be set in Retiarii mode by users!')
if key == 'execution_engine':
assert value in ['base', 'py', 'cgo', 'benchmark', 'oneshot'], f'The specified execution engine "{value}" is not supported.'
self.__dict__['trial_command'] = 'python3 -m nni.retiarii.trial_entry ' + value
self.__dict__[key] = value
def validate(self, initialized_tuner: bool = False) -> None:
super().validate()
@property
def _canonical_rules(self):
return _canonical_rules
@property
def _validation_rules(self):
return _validation_rules
_canonical_rules = {
}
_validation_rules = {
'trial_code_directory': lambda value: (Path(value).is_dir(), f'"{value}" does not exist or is not directory'),
'trial_concurrency': lambda value: value > 0,
'trial_gpu_number': lambda value: value >= 0,
'max_trial_number': lambda value: value > 0,
'log_level': lambda value: value in ["trace", "debug", "info", "warning", "error", "fatal"],
'training_service': lambda value: (type(value) is not TrainingServiceConfig, 'cannot be abstract base class')
}
__all__ = ['RetiariiExperiment']
def preprocess_model(base_model, evaluator, applied_mutators, full_ir=True, dummy_input=None, oneshot=False):
@ -252,9 +173,14 @@ class RetiariiExperiment(Experiment):
... final_model = Net()
"""
def __init__(self, base_model: nn.Module, evaluator: Union[BaseOneShotTrainer, Evaluator] = cast(Evaluator, None),
applied_mutators: List[Mutator] = cast(List[Mutator], None), strategy: BaseStrategy = cast(BaseStrategy, None),
def __init__(self, base_model: nn.Module,
evaluator: Union[BaseOneShotTrainer, Evaluator] = cast(Evaluator, None),
applied_mutators: List[Mutator] = cast(List[Mutator], None),
strategy: BaseStrategy = cast(BaseStrategy, None),
trainer: BaseOneShotTrainer = cast(BaseOneShotTrainer, None)):
super().__init__(None)
self.config: RetiariiExeConfig = cast(RetiariiExeConfig, None)
if trainer is not None:
warnings.warn('Usage of `trainer` in RetiariiExperiment is deprecated and will be removed soon. '
'Please consider specifying it as a positional argument, or use `evaluator`.', DeprecationWarning)
@ -263,25 +189,13 @@ class RetiariiExperiment(Experiment):
if evaluator is None:
raise ValueError('Evaluator should not be none.')
# TODO: The current design of init interface of Retiarii experiment needs to be reviewed.
self.config: RetiariiExeConfig = cast(RetiariiExeConfig, None)
self.port: Optional[int] = None
self.base_model = base_model
self.evaluator: Union[Evaluator, BaseOneShotTrainer] = evaluator
self.applied_mutators = applied_mutators
self.strategy = strategy
from nni.retiarii.oneshot.pytorch.strategy import OneShotStrategy
if not isinstance(strategy, OneShotStrategy):
# FIXME: Dispatcher should not be created this early.
self._dispatcher = RetiariiAdvisor('_placeholder_')
else:
self._dispatcher = cast(RetiariiAdvisor, None)
self._dispatcher_thread: Optional[Thread] = None
self._proc: Optional[Popen] = None
self.url_prefix = None
self._dispatcher = None
self._dispatcher_thread = None
# check for sanity
if not is_model_wrapped(base_model):
@ -290,11 +204,12 @@ class RetiariiExperiment(Experiment):
'but it may cause inconsistent behavior compared to the time when you add it.' + colorama.Style.RESET_ALL,
RuntimeWarning)
def _start_strategy(self):
def _run_strategy(self, config: RetiariiExeConfig):
base_model_ir, self.applied_mutators = preprocess_model(
self.base_model, self.evaluator, self.applied_mutators,
full_ir=self.config.execution_engine not in ['py', 'benchmark'],
dummy_input=self.config.dummy_input
full_ir=not isinstance(config.execution_engine, (PyEngineConfig, BenchmarkEngineConfig)),
dummy_input=config.execution_engine.dummy_input
if isinstance(config.execution_engine, (BaseEngineConfig, CgoEngineConfig)) else None
)
_logger.info('Start strategy...')
@ -303,102 +218,49 @@ class RetiariiExperiment(Experiment):
self.strategy.run(base_model_ir, self.applied_mutators)
_logger.info('Strategy exit')
# TODO: find out a proper way to show no more trial message on WebUI
# self._dispatcher.mark_experiment_as_ending()
def start(self, port: int = 8080, debug: bool = False) -> None:
"""
Start the experiment in background.
This method will raise exception on failure.
If it returns, the experiment should have been successfully started.
Parameters
----------
port
The port of web UI.
debug
Whether to start in debug mode.
"""
atexit.register(self.stop)
self.config = self.config.canonical_copy()
# we will probably need a execution engine factory to make this clean and elegant
if self.config.execution_engine == 'base':
def _create_execution_engine(self, config: RetiariiExeConfig) -> None:
# TODO: we will probably need an execution engine factory to make this clean and elegant
if isinstance(config.execution_engine, BaseEngineConfig):
from ..execution.base import BaseExecutionEngine
engine = BaseExecutionEngine()
elif self.config.execution_engine == 'cgo':
engine = BaseExecutionEngine(self.port, self.url_prefix)
elif isinstance(config.execution_engine, CgoEngineConfig):
from ..execution.cgo_engine import CGOExecutionEngine
assert self.config.training_service.platform == 'remote', \
assert not isinstance(config.training_service, list) \
and config.training_service.platform == 'remote', \
"CGO execution engine currently only supports remote training service"
assert self.config.batch_waiting_time is not None and self.config.max_concurrency_cgo is not None
devices = self._construct_devices()
engine = CGOExecutionEngine(devices,
max_concurrency=self.config.max_concurrency_cgo,
batch_waiting_time=self.config.batch_waiting_time)
elif self.config.execution_engine == 'py':
assert config.execution_engine.batch_waiting_time is not None \
and config.execution_engine.max_concurrency_cgo is not None
engine = CGOExecutionEngine(cast(RemoteConfig, config.training_service),
max_concurrency=config.execution_engine.max_concurrency_cgo,
batch_waiting_time=config.execution_engine.batch_waiting_time,
rest_port=self.port,
rest_url_prefix=self.url_prefix)
elif isinstance(config.execution_engine, PyEngineConfig):
from ..execution.python import PurePythonExecutionEngine
engine = PurePythonExecutionEngine()
elif self.config.execution_engine == 'benchmark':
engine = PurePythonExecutionEngine(self.port, self.url_prefix)
elif isinstance(config.execution_engine, BenchmarkEngineConfig):
from ..execution.benchmark import BenchmarkExecutionEngine
assert self.config.benchmark is not None, '"benchmark" must be set when benchmark execution engine is used.'
engine = BenchmarkExecutionEngine(self.config.benchmark)
assert config.execution_engine.benchmark is not None, \
'"benchmark" must be set when benchmark execution engine is used.'
engine = BenchmarkExecutionEngine(config.execution_engine.benchmark)
else:
raise ValueError(f'Unsupported engine type: {self.config.execution_engine}')
raise ValueError(f'Unsupported engine type: {config.execution_engine}')
set_execution_engine(engine)
self.id = management.generate_experiment_id()
def start(self, *args, **kwargs) -> None:
"""
By design, the only difference between `start` and `run` is that `start` is asynchronous,
while `run` waits for the experiment to complete. RetiariiExperiment always waits for the
experiment to complete, as the strategy runs in the foreground.
"""
raise NotImplementedError('RetiariiExperiment is not supposed to provide `start` method')
log_file = Path(self.config.experiment_working_directory, self.id, 'log', 'experiment.log')
log_file.parent.mkdir(parents=True, exist_ok=True)
log_level = 'debug' if (debug or self.config.log_level == 'trace') else self.config.log_level
nni.runtime.log.start_experiment_logging(self.id, log_file, cast(str, log_level))
ws_url = f'ws://localhost:{port}/tuner'
self._proc = launcher.start_experiment('create', self.id, self.config, port, debug, # type: ignore
RunMode.Background, None, ws_url, ['retiarii'])
assert self._proc is not None
self.port = port # port will be None if start up failed
# dispatcher must be launched after pipe initialized
# the logic to launch dispatcher in background should be refactored into dispatcher api
self._dispatcher = self._create_dispatcher()
if self._dispatcher is not None:
self._dispatcher._channel = TunerCommandChannel(ws_url)
self._dispatcher_thread = Thread(target=self._dispatcher.run)
self._dispatcher_thread.start()
ips = [self.config.nni_manager_ip]
for interfaces in psutil.net_if_addrs().values():
for interface in interfaces:
if interface.family == socket.AF_INET:
ips.append(interface.address)
ips = [f'http://{ip}:{port}' for ip in ips if ip]
msg = 'Web UI URLs: ' + colorama.Fore.CYAN + ' '.join(ips) + colorama.Style.RESET_ALL
_logger.info(msg)
exp_status_checker = Thread(target=self._check_exp_status)
exp_status_checker.start()
self._start_strategy()
# TODO: the experiment should be completed, when strategy exits and there is no running job
_logger.info('Waiting for experiment to become DONE (you can ctrl+c if there is no running trial jobs)...')
exp_status_checker.join()
def _construct_devices(self):
devices = []
if hasattr(self.config.training_service, 'machine_list'):
for machine in cast(RemoteConfig, self.config.training_service).machine_list:
assert machine.gpu_indices is not None, \
'gpu_indices must be set in RemoteMachineConfig for CGO execution engine'
assert isinstance(machine.gpu_indices, list), 'gpu_indices must be a list'
for gpu_idx in machine.gpu_indices:
devices.append(GPUDevice(machine.host, gpu_idx))
return devices
def _create_dispatcher(self):
return self._dispatcher
def run(self, config: Optional[RetiariiExeConfig] = None, port: int = 8080, debug: bool = False) -> None:
def run(self,
config: RetiariiExeConfig | None = None,
port: int = 8080,
debug: bool = False) -> None:
"""
Run the experiment.
This function will block until experiment finish or error.
@ -410,75 +272,47 @@ class RetiariiExperiment(Experiment):
# 'In case you want to stick to the old implementation, '
# 'please consider using ``trainer.fit()`` instead of experiment.', DeprecationWarning)
self.evaluator.fit()
return
if config is None:
warnings.warn('config = None is deprecate in future. If you are running a one-shot experiment, '
'please consider creating a config and set execution engine to `oneshot`.', DeprecationWarning)
config = RetiariiExeConfig()
config.execution_engine = 'oneshot'
self.config = RetiariiExeConfig()
self.config.execution_engine = OneshotEngineConfig()
else:
self.config = config
if config.execution_engine == 'oneshot':
if isinstance(self.config.execution_engine, OneshotEngineConfig) \
or (isinstance(self.config.execution_engine, str) and self.config.execution_engine == 'oneshot'):
# this is hacky, will be refactored when oneshot can run on training services
base_model_ir, self.applied_mutators = preprocess_model(self.base_model, self.evaluator, self.applied_mutators, oneshot=True)
self.strategy.run(base_model_ir, self.applied_mutators)
else:
assert config is not None, 'You are using classic search mode, config cannot be None!'
self.config = config
self.start(port, debug)
def _check_exp_status(self) -> bool:
"""
Run the experiment.
This function will block until experiment finish or error.
Return `True` when experiment done; or return `False` when experiment failed.
"""
assert self._proc is not None
try:
while True:
time.sleep(10)
# this if is to deal with the situation that
# nnimanager is cleaned up by ctrl+c first
if self._proc.poll() is None:
status = self.get_status()
else:
return False
if status == 'DONE' or status == 'STOPPED':
return True
if status == 'ERROR':
return False
except KeyboardInterrupt:
_logger.warning('KeyboardInterrupt detected')
finally:
self.stop()
raise RuntimeError('Check experiment status failed.')
ws_url = f'ws://localhost:{port}/tuner'
canonicalized_config = self._start_impl(port, debug, RunMode.Background, ws_url, ['retiarii'])
canonicalized_config = cast(RetiariiExeConfig, canonicalized_config)
self._dispatcher = RetiariiAdvisor(ws_url)
self._dispatcher_thread = Thread(target=self._dispatcher.run, daemon=True)
self._dispatcher_thread.start()
# FIXME: engine cannot be created twice
self._create_execution_engine(canonicalized_config)
try:
self._run_strategy(canonicalized_config)
# FIXME: move this logic to strategy with a new API provided by execution engine
self._wait_completion()
except KeyboardInterrupt:
_logger.warning('KeyboardInterrupt detected')
self.stop()
_logger.info('Search process is done; the experiment is still alive. Call `stop()` to terminate it.')
def stop(self) -> None:
"""
Stop background experiment.
"""
_logger.info('Stopping experiment, please wait...')
atexit.unregister(self.stop)
# stop strategy first
if self._dispatcher_thread is not None:
self._dispatcher.stopping = True
self._dispatcher_thread.join(timeout=1)
if self.id is not None:
nni.runtime.log.stop_experiment_logging(self.id)
if self._proc is not None:
try:
# this if is to deal with the situation that
# nnimanager is cleaned up by ctrl+c first
if self._proc.poll() is None:
rest.delete(self.port, '/experiment')
except Exception as e:
_logger.exception(e)
_logger.warning('Cannot gracefully stop experiment, killing NNI process...')
kill_command(self._proc.pid)
self.id = cast(str, None)
self.port = cast(int, None)
self._proc = None
self._stop_impl()
if self._dispatcher_thread:
self._dispatcher_thread.join()
self._dispatcher = cast(RetiariiAdvisor, None)
self._dispatcher_thread = None
_logger.info('Experiment stopped')
@ -502,8 +336,11 @@ class RetiariiExperiment(Experiment):
If ``code``, the python code of model will be returned.
If ``dict``, the mutation history will be returned.
"""
# TODO: the base class may also need this method
if formatter == 'code':
assert self.config.execution_engine != 'py', 'You should use `dict` formatter when using Python execution engine.'
config = self.config.canonical_copy()
assert not isinstance(config.execution_engine, PyEngineConfig), \
'You should use `dict` formatter when using Python execution engine.'
if isinstance(self.evaluator, BaseOneShotTrainer):
assert top_k == 1, 'Only support top_k is 1 for now.'
return self.evaluator.export()
@ -520,9 +357,3 @@ class RetiariiExperiment(Experiment):
return [model_to_pytorch_script(model) for model in all_models[:top_k]]
elif formatter == 'dict':
return [get_mutation_dict(model) for model in all_models[:top_k]]
def retrain_model(self, model):
"""
this function retrains the exported model, and test it to output test accuracy
"""
raise NotImplementedError
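An end-to-end sketch of the refactored flow; base_model, evaluator, and strategy are assumed to exist. run() launches nnimanager, starts the dispatcher thread, creates the execution engine from the canonicalized config, runs the strategy in the foreground, and then waits for completion:

from nni.retiarii.experiment.pytorch import RetiariiExperiment, RetiariiExeConfig

exp = RetiariiExperiment(base_model, evaluator, strategy=strategy)
config = RetiariiExeConfig('local', execution_engine='py')
config.trial_concurrency = 1
exp.run(config, port=8081)
# The Python engine requires the 'dict' formatter (see the export check above).
print(exp.export_top_models(formatter='dict'))
exp.stop()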

View file

@ -22,7 +22,10 @@ def get_advisor() -> 'RetiariiAdvisor':
def register_advisor(advisor: 'RetiariiAdvisor'):
global _advisor
assert _advisor is None
if _advisor is not None:
warnings.warn('Advisor is already set. '
'You should avoid instantiating RetiariiExperiment twice in one process. '
'If you are running in a Jupyter notebook, please restart the kernel.')
_advisor = advisor

View file

@ -18,8 +18,15 @@ _worker_fast_exit_on_terminate = True
class MsgDispatcherBase(Recoverable):
"""This is where tuners and assessors are not defined yet.
"""
This is where tuners and assessors are not defined yet.
Inherit from this class to make your own advisor.
.. note::
A class inheriting MsgDispatcherBase should be instantiated
after nnimanager (the REST server) has started, so that the object
is ready to use right after its instantiation.
"""
def __init__(self, command_channel_url=None):
@ -27,6 +34,16 @@ class MsgDispatcherBase(Recoverable):
if command_channel_url is None:
command_channel_url = dispatcher_env_vars.NNI_TUNER_COMMAND_CHANNEL
self._channel = TunerCommandChannel(command_channel_url)
# NOTE: `connect()` should be put in __init__. First, this `connect()` affects nnimanager's
# startup process: without `connect()`, nnimanager is blocked in `dispatcher.init()`.
# Second, the NAS experiment uses a thread to execute `run()` of this class, so there is
# no way to know when the websocket between nnimanager and the dispatcher is established.
# The following logic may crash if the websocket is not established; one example is updating
# the search space. If the search space is updated too soon, before the websocket is built,
# the REST API for updating the search space will time out.
# FIXME: this is to make unittest happy
if not command_channel_url.startswith('ws://_unittest_'):
self._channel.connect()
self.default_command_queue = Queue()
self.assessor_command_queue = Queue()
self.default_worker = threading.Thread(target=self.command_queue_worker, args=(self.default_command_queue,))
@ -39,7 +56,6 @@ class MsgDispatcherBase(Recoverable):
"""
_logger.info('Dispatcher started')
self._channel.connect()
self.default_worker.start()
self.assessor_worker.start()
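A sketch of the instantiation order this note requires, mirroring RetiariiExperiment.run in this PR; the URL and module path are assumptions:

from threading import Thread
from nni.retiarii.integration import RetiariiAdvisor  # path assumed from this repo's layout

# 1. nnimanager (REST server) is already running and serving the tuner channel.
ws_url = 'ws://localhost:8080/tuner'
# 2. Instantiating the advisor connects the websocket inside __init__.
advisor = RetiariiAdvisor(ws_url)
# 3. run() executes in a background thread, as the NAS experiment does.
Thread(target=advisor.run, daemon=True).start()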

View file

@ -4,7 +4,6 @@ import warnings
import torch
import torch.nn as torch_nn
from torchvision.models.utils import load_state_dict_from_url
import torch.nn.functional as F
import sys

View file

@ -8,7 +8,7 @@ import nni.retiarii.evaluator.pytorch.cgo.evaluator as cgo
from nni.retiarii import serialize
from base_mnasnet import MNASNet
from nni.experiment import RemoteMachineConfig
from nni.retiarii.experiment.pytorch import RetiariiExperiment, RetiariiExeConfig
from nni.retiarii.experiment.pytorch import RetiariiExperiment, RetiariiExeConfig, CgoEngineConfig
from nni.retiarii.strategy import TPEStrategy
from torchvision import transforms
from torchvision.datasets import CIFAR10
@ -59,8 +59,6 @@ if __name__ == '__main__':
exp_config.max_trial_number = 10
exp_config.trial_gpu_number = 1
exp_config.training_service.reuse_mode = True
exp_config.max_concurrency_cgo = 3
exp_config.batch_waiting_time = 0
rm_conf = RemoteMachineConfig()
rm_conf.host = '127.0.0.1'
@ -73,6 +71,6 @@ if __name__ == '__main__':
rm_conf.max_trial_number_per_gpu = 3
exp_config.training_service.machine_list = [rm_conf]
exp_config.execution_engine = 'cgo'
exp_config.execution_engine = CgoEngineConfig(max_concurrency_cgo = 3, batch_waiting_time = 0)
exp.run(exp_config, 8099)
exp.run(exp_config, 8099)

View file

@ -9,6 +9,7 @@ from pytorch_lightning.utilities.seed import seed_everything
from pathlib import Path
import nni
from nni.experiment.config import RemoteConfig, RemoteMachineConfig
import nni.runtime.platform.test
from nni.runtime.tuner_command_channel import legacy as protocol
import json
@ -263,13 +264,14 @@ class CGOEngineTest(unittest.TestCase):
opt = DedupInputOptimizer()
opt.convert(lp)
advisor = RetiariiAdvisor('ws://_placeholder_')
advisor = RetiariiAdvisor('ws://_unittest_placeholder_')
advisor._channel = protocol.LegacyCommandChannel()
advisor.default_worker.start()
advisor.assessor_worker.start()
available_devices = [GPUDevice("test", 0), GPUDevice("test", 1), GPUDevice("test", 2), GPUDevice("test", 3)]
cgo = CGOExecutionEngine(devices=available_devices, batch_waiting_time=0)
remote = RemoteConfig(machine_list=[])
remote.machine_list.append(RemoteMachineConfig(host='test', gpu_indices=[0,1,2,3]))
cgo = CGOExecutionEngine(training_service=remote, batch_waiting_time=0)
phy_models = cgo._assemble(lp)
self.assertTrue(len(phy_models) == 1)
@ -286,13 +288,14 @@ class CGOEngineTest(unittest.TestCase):
opt = DedupInputOptimizer()
opt.convert(lp)
advisor = RetiariiAdvisor('ws://_placeholder_')
advisor = RetiariiAdvisor('ws://_unittest_placeholder_')
advisor._channel = protocol.LegacyCommandChannel()
advisor.default_worker.start()
advisor.assessor_worker.start()
available_devices = [GPUDevice("test", 0), GPUDevice("test", 1)]
cgo = CGOExecutionEngine(devices=available_devices, batch_waiting_time=0)
remote = RemoteConfig(machine_list=[])
remote.machine_list.append(RemoteMachineConfig(host='test', gpu_indices=[0,1]))
cgo = CGOExecutionEngine(training_service=remote, batch_waiting_time=0)
phy_models = cgo._assemble(lp)
self.assertTrue(len(phy_models) == 2)
@ -311,13 +314,14 @@ class CGOEngineTest(unittest.TestCase):
models = _load_mnist(2)
advisor = RetiariiAdvisor('ws://_placeholder_')
advisor = RetiariiAdvisor('ws://_unittest_placeholder_')
advisor._channel = protocol.LegacyCommandChannel()
advisor.default_worker.start()
advisor.assessor_worker.start()
cgo_engine = CGOExecutionEngine(devices=[GPUDevice("test", 0), GPUDevice("test", 1),
GPUDevice("test", 2), GPUDevice("test", 3)], batch_waiting_time=0)
remote = RemoteConfig(machine_list=[])
remote.machine_list.append(RemoteMachineConfig(host='test', gpu_indices=[0,1,2,3]))
cgo_engine = CGOExecutionEngine(training_service=remote, batch_waiting_time=0)
set_execution_engine(cgo_engine)
submit_models(*models)
time.sleep(3)

View file

@ -25,7 +25,7 @@ class EngineTest(unittest.TestCase):
def test_base_execution_engine(self):
nni.retiarii.integration_api._advisor = None
nni.retiarii.execution.api._execution_engine = None
advisor = RetiariiAdvisor('ws://_placeholder_')
advisor = RetiariiAdvisor('ws://_unittest_placeholder_')
advisor._channel = LegacyCommandChannel()
advisor.default_worker.start()
advisor.assessor_worker.start()
@ -42,7 +42,7 @@ class EngineTest(unittest.TestCase):
def test_py_execution_engine(self):
nni.retiarii.integration_api._advisor = None
nni.retiarii.execution.api._execution_engine = None
advisor = RetiariiAdvisor('ws://_placeholder_')
advisor = RetiariiAdvisor('ws://_unittest_placeholder_')
advisor._channel = LegacyCommandChannel()
advisor.default_worker.start()
advisor.assessor_worker.start()

View file

@ -57,7 +57,7 @@ class AssessorTestCase(TestCase):
_restore_io()
assessor = NaiveAssessor()
dispatcher = MsgDispatcher('ws://_placeholder_', None, assessor)
dispatcher = MsgDispatcher('ws://_unittest_placeholder_', None, assessor)
dispatcher._channel = LegacyCommandChannel()
msg_dispatcher_base._worker_fast_exit_on_terminate = False

View file

@ -66,7 +66,7 @@ class MsgDispatcherTestCase(TestCase):
_restore_io()
tuner = NaiveTuner()
dispatcher = MsgDispatcher('ws://_placeholder_', tuner)
dispatcher = MsgDispatcher('ws://_unittest_placeholder_', tuner)
dispatcher._channel = LegacyCommandChannel()
msg_dispatcher_base._worker_fast_exit_on_terminate = False

View file

@ -303,8 +303,11 @@ class NNIManager implements Manager {
}
this.trainingService.removeTrialJobMetricListener(this.trialJobMetricListener);
// NOTE: sending TERMINATE must happen outside the if clause,
// because when the python dispatcher is started before nnimanager,
// this.dispatcherPid does not yet have a valid value (i.e., not > 0).
this.dispatcher.sendCommand(TERMINATE);
if (this.dispatcherPid > 0) {
this.dispatcher.sendCommand(TERMINATE);
// gracefully terminate tuner and assessor here, wait at most 30 seconds.
for (let i: number = 0; i < 30; i++) {
if (!await isAlive(this.dispatcherPid)) {