Mirror of https://github.com/microsoft/nni.git
[retiarii] refactor of nas experiment (#4841)
This commit is contained in:
Parent: c80bda297e
Commit: 2fc4724771
@@ -54,6 +54,11 @@ class ConfigBase:
     Config objects will remember where they are loaded; therefore relative paths can be resolved smartly.
     If a config object is created with constructor, the base path will be current working directory.
     If it is loaded with ``ConfigBase.load(path)``, the base path will be ``path``'s parent.
+
+    .. attention::
+
+        All the classes that inherit ``ConfigBase`` are not allowed to use ``from __future__ import annotations``,
+        because ``ConfigBase`` uses ``typeguard`` to perform runtime check and it does not support lazy annotations.
     """

     def __init__(self, **kwargs):
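The docstring added above pins down the base-path rule. A minimal usage sketch (the file path here is hypothetical; ``ExperimentConfig`` is one ``ConfigBase`` subclass):

    from nni.experiment.config import ExperimentConfig

    # Loaded from a file: relative paths inside the config resolve
    # against the parent of the given path (/workspace/exp here).
    config = ExperimentConfig.load('/workspace/exp/config.yml')

    # Created with the constructor: relative paths resolve against
    # the current working directory instead.
    config = ExperimentConfig('local')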
@@ -164,10 +164,11 @@ class ExperimentConfig(ConfigBase):
         # currently I have only seen one issue of this kind
         #Path(self.experiment_working_directory).mkdir(parents=True, exist_ok=True)

-        utils.validate_gpu_indices(self.tuner_gpu_indices)
+        if type(self).__name__ != 'RetiariiExeConfig':
+            utils.validate_gpu_indices(self.tuner_gpu_indices)

-        if self.tuner is None:
-            raise ValueError('ExperimentConfig: tuner must be set')
+            if self.tuner is None:
+                raise ValueError('ExperimentConfig: tuner must be set')

 def _load_search_space_file(search_space_path):
     # FIXME
@@ -84,20 +84,9 @@ class Experiment:
         else:
             self.config = config_or_platform

-    def start(self, port: int = 8080, debug: bool = False, run_mode: RunMode = RunMode.Background) -> None:
-        """
-        Start the experiment in background.
-
-        This method will raise exception on failure.
-        If it returns, the experiment should have been successfully started.
-
-        Parameters
-        ----------
-        port
-            The port of web UI.
-        debug
-            Whether to start in debug mode.
-        """
+    def _start_impl(self, port: int, debug: bool, run_mode: RunMode,
+                    tuner_command_channel: str | None,
+                    tags: list[str] = []) -> ExperimentConfig:
         assert self.config is not None
+        if run_mode is not RunMode.Detach:
+            atexit.register(self.stop)
@@ -111,7 +100,8 @@ class Experiment:
         log_level = 'debug' if (debug or config.log_level == 'trace') else config.log_level
         start_experiment_logging(self.id, log_file, cast(str, log_level))

-        self._proc = launcher.start_experiment(self._action, self.id, config, port, debug, run_mode, self.url_prefix)
+        self._proc = launcher.start_experiment(self._action, self.id, config, port, debug, run_mode,
+                                               self.url_prefix, tuner_command_channel, tags)
         assert self._proc is not None

         self.port = port  # port will be None if start up failed
@@ -124,12 +114,27 @@ class Experiment:
         ips = [f'http://{ip}:{port}' for ip in ips if ip]
         msg = 'Web portal URLs: ${CYAN}' + ' '.join(ips)
         _logger.info(msg)
+        return config

-    def stop(self) -> None:
+    def start(self, port: int = 8080, debug: bool = False, run_mode: RunMode = RunMode.Background) -> None:
         """
-        Stop the experiment.
+        Start the experiment in background.
+
+        This method will raise exception on failure.
+        If it returns, the experiment should have been successfully started.
+
+        Parameters
+        ----------
+        port
+            The port of web UI.
+        debug
+            Whether to start in debug mode.
+        run_mode
+            Running the experiment in foreground or background
         """
-        _logger.info('Stopping experiment, please wait...')
+        self._start_impl(port, debug, run_mode, None, [])
+
+    def _stop_impl(self) -> None:
         atexit.unregister(self.stop)

         stop_experiment_logging(self.id)
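``start`` is now a thin wrapper over the extracted ``_start_impl``, and ``run_mode`` decides whether an ``atexit`` stop hook is registered. A hedged sketch of the call (config details elided):

    from nni.experiment import Experiment, RunMode

    exp = Experiment('local')
    exp.start(port=8080, run_mode=RunMode.Background)  # registers the atexit stop hook
    # RunMode.Detach skips the atexit hook, leaving the experiment running on interpreter exit.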
@@ -144,8 +149,24 @@ class Experiment:
         self.id = None  # type: ignore
         self.port = None
         self._proc = None

+    def stop(self) -> None:
+        """
+        Stop the experiment.
+        """
+        _logger.info('Stopping experiment, please wait...')
+        self._stop_impl()
+        _logger.info('Experiment stopped')
+
+    def _wait_completion(self) -> bool:
+        while True:
+            status = self.get_status()
+            if status == 'DONE' or status == 'STOPPED':
+                return True
+            if status == 'ERROR':
+                return False
+            time.sleep(10)
+
     def run(self, port: int = 8080, wait_completion: bool = True, debug: bool = False) -> bool | None:
         """
         Run the experiment.
@@ -159,13 +180,7 @@ class Experiment:
         self.start(port, debug)
         if wait_completion:
             try:
-                while True:
-                    time.sleep(10)
-                    status = self.get_status()
-                    if status == 'DONE' or status == 'STOPPED':
-                        return True
-                    if status == 'ERROR':
-                        return False
+                self._wait_completion()
             except KeyboardInterrupt:
                 _logger.warning('KeyboardInterrupt detected')
                 self.stop()
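``run`` now delegates its old polling loop to ``_wait_completion``, which returns ``True`` on ``DONE``/``STOPPED`` and ``False`` on ``ERROR``, checking every 10 seconds. For example (config details elided):

    exp = Experiment('local')
    succeeded = exp.run(port=8080, wait_completion=True)  # blocks until a terminal status
    if succeeded is False:
        print('experiment ended in ERROR state')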
@@ -2,6 +2,7 @@
 # Licensed under the MIT license.

 import time
+import warnings
 from typing import Iterable

 from ..graph import Model, ModelStatus
@@ -18,12 +19,12 @@ __all__ = ['get_execution_engine', 'get_and_register_default_listener',

 def set_execution_engine(engine: AbstractExecutionEngine) -> None:
     global _execution_engine
-    if _execution_engine is None:
-        _execution_engine = engine
-    else:
-        raise RuntimeError('Execution engine is already set. '
-                           'You should avoid instantiating RetiariiExperiment twice in one process. '
-                           'If you are running in a Jupyter notebook, please restart the kernel.')
+    if _execution_engine is not None:
+        warnings.warn('Execution engine is already set. '
+                      'You should avoid instantiating RetiariiExperiment twice in one process. '
+                      'If you are running in a Jupyter notebook, please restart the kernel.',
+                      RuntimeWarning)
+    _execution_engine = engine


 def get_execution_engine() -> AbstractExecutionEngine:
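With this change, registering an engine a second time warns and overwrites instead of raising. A small sketch of the new behavior (``engine`` stands for any ``AbstractExecutionEngine`` instance):

    import warnings
    from nni.retiarii.execution.api import set_execution_engine

    set_execution_engine(engine)        # first call: registered silently
    with warnings.catch_warnings(record=True) as caught:
        warnings.simplefilter('always')
        set_execution_engine(engine)    # second call: RuntimeWarning, engine replaced
        assert issubclass(caught[-1].category, RuntimeWarning)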
@@ -1,12 +1,16 @@
 # Copyright (c) Microsoft Corporation.
 # Licensed under the MIT license.

+from __future__ import annotations
+
 import logging
 import os
 import random
 import string
 from typing import Any, Dict, Iterable, List

+from nni.experiment import rest
+
 from .interface import AbstractExecutionEngine, AbstractGraphListener
 from .utils import get_mutation_summary
 from .. import codegen, utils
@@ -54,12 +58,22 @@ class BaseExecutionEngine(AbstractExecutionEngine):
     Resource management is implemented in this class.
     """

-    def __init__(self) -> None:
+    def __init__(self, rest_port: int | None = None, rest_url_prefix: str | None = None) -> None:
         """
         Upon initialization, advisor callbacks need to be registered.
         Advisor will call the callbacks when the corresponding event has been triggered.
         Base execution engine will get those callbacks and broadcast them to graph listener.
+
+        Parameters
+        ----------
+        rest_port
+            The port of the experiment's rest server
+        rest_url_prefix
+            The url prefix of the experiment's rest entry
         """
+        self.port = rest_port
+        self.url_prefix = rest_url_prefix
+
         self._listeners: List[AbstractGraphListener] = []

         # register advisor callbacks
@@ -123,8 +137,8 @@ class BaseExecutionEngine(AbstractExecutionEngine):
         return self.resources

     def budget_exhausted(self) -> bool:
-        advisor = get_advisor()
-        return advisor.stopping
+        resp = rest.get(self.port, '/check-status', self.url_prefix)
+        return resp['status'] == 'DONE'

     @classmethod
     def pack_model_data(cls, model: Model) -> Any:
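``budget_exhausted`` no longer asks the advisor; it polls the experiment's REST API. Roughly what the call amounts to, using the port and prefix that were passed to the engine's constructor:

    from nni.experiment import rest

    resp = rest.get(8080, '/check-status', None)   # (rest_port, api path, rest_url_prefix)
    budget_exhausted = resp['status'] == 'DONE'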
@@ -1,16 +1,19 @@
 # Copyright (c) Microsoft Corporation.
 # Licensed under the MIT license.

+from __future__ import annotations
+
 import logging
 import os
 import random
 import string
 import time
 import threading
-from typing import Iterable, List, Dict, Tuple
+from typing import Iterable, List, Dict, Tuple, cast
 from dataclasses import dataclass

 from nni.common.device import GPUDevice, Device
+from nni.experiment.config.training_services import RemoteConfig
 from .interface import AbstractExecutionEngine, AbstractGraphListener, WorkerInfo
 from .. import codegen, utils
 from ..graph import Model, ModelStatus, MetricData, Node
@@ -31,7 +34,6 @@ class TrialSubmission:
     placement: Dict[Node, Device]
     grouped_models: List[Model]

-
 class CGOExecutionEngine(AbstractExecutionEngine):
     """
     The execution engine with Cross-Graph Optimization (CGO).
@@ -41,24 +43,35 @@ class CGOExecutionEngine(AbstractExecutionEngine):

     Parameters
     ----------
-    devices : List[Device]
-        Available devices for execution.
-    max_concurrency : int
+    training_service
+        The remote training service config.
+    max_concurrency
         The maximum number of trials to run concurrently.
-    batch_waiting_time: int
+    batch_waiting_time
         Seconds to wait for each batch of trial submission.
         The trials within one batch could apply cross-graph optimization.
+    rest_port
+        The port of the experiment's rest server
+    rest_url_prefix
+        The url prefix of the experiment's rest entry
     """

-    def __init__(self, devices: List[Device] = None,
+    def __init__(self, training_service: RemoteConfig,
                  max_concurrency: int = None,
                  batch_waiting_time: int = 60,
+                 rest_port: int | None = None,
+                 rest_url_prefix: str | None = None
+                 ) -> None:
+        self.port = rest_port
+        self.url_prefix = rest_url_prefix
+
         self._listeners: List[AbstractGraphListener] = []
         self._running_models: Dict[int, Model] = dict()
         self.logical_plan_counter = 0
         self.available_devices: List[Device] = []
         self.max_concurrency: int = max_concurrency

+        devices = self._construct_devices(training_service)
         for device in devices:
             self.available_devices.append(device)
         self.all_devices = self.available_devices.copy()
@@ -88,6 +101,17 @@ class CGOExecutionEngine(AbstractExecutionEngine):
         self._consumer_thread = threading.Thread(target=self._consume_models)
         self._consumer_thread.start()

+    def _construct_devices(self, training_service):
+        devices = []
+        if hasattr(training_service, 'machine_list'):
+            for machine in cast(RemoteConfig, training_service).machine_list:
+                assert machine.gpu_indices is not None, \
+                    'gpu_indices must be set in RemoteMachineConfig for CGO execution engine'
+                assert isinstance(machine.gpu_indices, list), 'gpu_indices must be a list'
+                for gpu_idx in machine.gpu_indices:
+                    devices.append(GPUDevice(machine.host, gpu_idx))
+        return devices
+
     def join(self):
         self._stopped = True
         self._consumer_thread.join()
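``_construct_devices`` flattens the remote machine list into one ``GPUDevice`` per (host, gpu index) pair. For instance (host address hypothetical):

    from nni.experiment.config.training_services import RemoteConfig, RemoteMachineConfig

    remote = RemoteConfig(machine_list=[])
    remote.machine_list.append(RemoteMachineConfig(host='10.0.0.1', gpu_indices=[0, 1]))
    # _construct_devices(remote) yields:
    #   [GPUDevice('10.0.0.1', 0), GPUDevice('10.0.0.1', 1)]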
@@ -0,0 +1,2 @@
+# Copyright (c) Microsoft Corporation.
+# Licensed under the MIT license.
@@ -0,0 +1,5 @@
+# Copyright (c) Microsoft Corporation.
+# Licensed under the MIT license.
+
+from .experiment_config import *
+from .engine_config import *
@@ -0,0 +1,41 @@
+# Copyright (c) Microsoft Corporation.
+# Licensed under the MIT license.
+
+from dataclasses import dataclass
+from typing import Optional, List
+
+from nni.experiment.config.base import ConfigBase
+
+__all__ = ['ExecutionEngineConfig', 'BaseEngineConfig', 'OneshotEngineConfig',
+           'PyEngineConfig', 'CgoEngineConfig', 'BenchmarkEngineConfig']
+
+@dataclass(init=False)
+class ExecutionEngineConfig(ConfigBase):
+    name: str
+
+@dataclass(init=False)
+class PyEngineConfig(ExecutionEngineConfig):
+    name: str = 'py'
+
+@dataclass(init=False)
+class OneshotEngineConfig(ExecutionEngineConfig):
+    name: str = 'oneshot'
+
+@dataclass(init=False)
+class BaseEngineConfig(ExecutionEngineConfig):
+    name: str = 'base'
+    # input used in GraphConverterWithShape. Currently support shape tuple only.
+    dummy_input: Optional[List[int]] = None
+
+@dataclass(init=False)
+class CgoEngineConfig(ExecutionEngineConfig):
+    name: str = 'cgo'
+    max_concurrency_cgo: Optional[int] = None
+    batch_waiting_time: Optional[int] = None
+    # input used in GraphConverterWithShape. Currently support shape tuple only.
+    dummy_input: Optional[List[int]] = None
+
+@dataclass(init=False)
+class BenchmarkEngineConfig(ExecutionEngineConfig):
+    name: str = 'benchmark'
+    benchmark: Optional[str] = None
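Each execution engine now has a typed config class keyed by its ``name`` field, replacing the old string-plus-loose-fields style. A usage sketch (import path as re-exported by ``pytorch.py`` below):

    from nni.retiarii.experiment.pytorch import CgoEngineConfig, BaseEngineConfig

    cgo = CgoEngineConfig(max_concurrency_cgo=3, batch_waiting_time=0)
    base = BaseEngineConfig(dummy_input=[1, 3, 224, 224])  # shape for GraphConverterWithShape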
@@ -0,0 +1,60 @@
+# Copyright (c) Microsoft Corporation.
+# Licensed under the MIT license.
+
+import os
+from dataclasses import dataclass
+from typing import Any, Union
+
+from nni.experiment.config import utils, ExperimentConfig
+
+from .engine_config import ExecutionEngineConfig
+
+__all__ = ['RetiariiExeConfig']
+
+def execution_engine_config_factory(engine_name):
+    # FIXME: may move this function to experiment utils in future
+    cls = _get_ee_config_class(engine_name)
+    if cls is None:
+        raise ValueError(f'Invalid execution engine name: {engine_name}')
+    return cls()
+
+def _get_ee_config_class(engine_name):
+    for cls in ExecutionEngineConfig.__subclasses__():
+        if cls.name == engine_name:
+            return cls
+    return None
+
+@dataclass(init=False)
+class RetiariiExeConfig(ExperimentConfig):
+    # FIXME: refactor this class to inherit from a new common base class with HPO config
+    search_space: Any = ''
+    trial_code_directory: utils.PathLike = '.'
+    trial_command: str = '_reserved'
+    # new config field for NAS
+    execution_engine: Union[str, ExecutionEngineConfig]
+
+    def __init__(self, training_service_platform: Union[str, None] = None,
+                 execution_engine: Union[str, ExecutionEngineConfig] = 'py',
+                 **kwargs):
+        super().__init__(training_service_platform, **kwargs)
+        self.execution_engine = execution_engine
+
+    def _canonicalize(self, _parents):
+        msg = '{} is not supposed to be set in Retiarii experiment by users, your config is {}.'
+        if self.search_space != '':
+            raise ValueError(msg.format('search_space', self.search_space))
+        # TODO: maybe we should also allow users to specify trial_code_directory
+        if str(self.trial_code_directory) != '.' and not os.path.isabs(self.trial_code_directory):
+            raise ValueError(msg.format('trial_code_directory', self.trial_code_directory))
+        if self.trial_command != '_reserved' and \
+            not self.trial_command.startswith('python3 -m nni.retiarii.trial_entry '):
+            raise ValueError(msg.format('trial_command', self.trial_command))
+
+        if isinstance(self.execution_engine, str):
+            self.execution_engine = execution_engine_config_factory(self.execution_engine)
+        if self.execution_engine.name in ('py', 'base', 'cgo'):
+            # TODO: replace python3 with more elegant approach
+            # maybe use sys.executable rendered in trial side (e.g., trial_runner)
+            self.trial_command = 'python3 -m nni.retiarii.trial_entry ' + self.execution_engine.name
+
+        super()._canonicalize([self])
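A sketch of the intended usage; ``_canonicalize`` later swaps an engine-name string for its config object and fills the reserved ``trial_command``:

    from nni.retiarii.experiment.pytorch import RetiariiExeConfig

    config = RetiariiExeConfig('local', execution_engine='py')
    config.trial_concurrency = 2
    # After canonicalization: config.execution_engine == PyEngineConfig()
    # and config.trial_command == 'python3 -m nni.retiarii.trial_entry py'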
@@ -1,32 +1,25 @@
 # Copyright (c) Microsoft Corporation.
 # Licensed under the MIT license.

-import atexit
+from __future__ import annotations
+
 import logging
-import os
-import socket
 import time

 import warnings
-from dataclasses import dataclass
-from pathlib import Path
-from subprocess import Popen
 from threading import Thread
-from typing import Any, List, Optional, Union, cast
+from typing import Any, List, Union, cast

 import colorama
-import psutil

 import torch
 import torch.nn as nn
-import nni.runtime.log
-from nni.common.device import GPUDevice
-from nni.experiment import Experiment, RunMode, launcher, management, rest
-from nni.experiment.config import utils
-from nni.experiment.config.base import ConfigBase
-from nni.experiment.config.training_service import TrainingServiceConfig
+from nni.experiment import Experiment, RunMode
+from nni.experiment.config.training_services import RemoteConfig
 from nni.runtime.tuner_command_channel import TunerCommandChannel
 from nni.tools.nnictl.command_utils import kill_command

+from .config import (
+    RetiariiExeConfig, OneshotEngineConfig, BaseEngineConfig,
+    PyEngineConfig, CgoEngineConfig, BenchmarkEngineConfig
+)
 from ..codegen import model_to_pytorch_script
 from ..converter import convert_to_graph
 from ..converter.graph_gen import GraphConverterWithShape
@@ -46,79 +39,7 @@ from ..strategy.utils import dry_run_for_formatted_search_space
 _logger = logging.getLogger(__name__)


-__all__ = ['RetiariiExeConfig', 'RetiariiExperiment']
-
-
-@dataclass(init=False)
-class RetiariiExeConfig(ConfigBase):
-    experiment_name: Optional[str] = None
-    search_space: Any = ''  # TODO: remove
-    trial_command: str = '_reserved'
-    trial_code_directory: utils.PathLike = '.'
-    trial_concurrency: int
-    trial_gpu_number: int = 0
-    devices: Optional[List[Union[str, GPUDevice]]] = None
-    max_experiment_duration: Optional[str] = None
-    max_trial_number: Optional[int] = None
-    max_concurrency_cgo: Optional[int] = None
-    batch_waiting_time: Optional[int] = None
-    nni_manager_ip: Optional[str] = None
-    debug: bool = False
-    log_level: str = 'info'
-    experiment_working_directory: utils.PathLike = '~/nni-experiments'
-    # remove configuration of tuner/assessor/advisor
-    training_service: TrainingServiceConfig
-    execution_engine: str = 'py'
-
-    # input used in GraphConverterWithShape. Currently support shape tuple only.
-    dummy_input: Optional[List[int]] = None
-
-    # input used for benchmark engine.
-    benchmark: Optional[str] = None
-
-    def __init__(self, training_service_platform: Optional[str] = None, **kwargs):
-        super().__init__(**kwargs)
-        if training_service_platform is not None:
-            assert 'training_service' not in kwargs
-            self.training_service = utils.training_service_config_factory(platform=training_service_platform)
-        self.__dict__['trial_command'] = 'python3 -m nni.retiarii.trial_entry py'
-
-    def __setattr__(self, key, value):
-        fixed_attrs = {'search_space': '',
-                       'trial_command': '_reserved'}
-        if key in fixed_attrs and fixed_attrs[key] != value:
-            raise AttributeError(f'{key} is not supposed to be set in Retiarii mode by users!')
-        # 'trial_code_directory' is handled differently because the path will be converted to absolute path by us
-        if key == 'trial_code_directory' and not (str(value) == '.' or os.path.isabs(value)):
-            raise AttributeError(f'{key} is not supposed to be set in Retiarii mode by users!')
-        if key == 'execution_engine':
-            assert value in ['base', 'py', 'cgo', 'benchmark', 'oneshot'], f'The specified execution engine "{value}" is not supported.'
-            self.__dict__['trial_command'] = 'python3 -m nni.retiarii.trial_entry ' + value
-        self.__dict__[key] = value
-
-    def validate(self, initialized_tuner: bool = False) -> None:
-        super().validate()
-
-    @property
-    def _canonical_rules(self):
-        return _canonical_rules
-
-    @property
-    def _validation_rules(self):
-        return _validation_rules
-
-
-_canonical_rules = {
-}
-
-_validation_rules = {
-    'trial_code_directory': lambda value: (Path(value).is_dir(), f'"{value}" does not exist or is not directory'),
-    'trial_concurrency': lambda value: value > 0,
-    'trial_gpu_number': lambda value: value >= 0,
-    'max_trial_number': lambda value: value > 0,
-    'log_level': lambda value: value in ["trace", "debug", "info", "warning", "error", "fatal"],
-    'training_service': lambda value: (type(value) is not TrainingServiceConfig, 'cannot be abstract base class')
-}
+__all__ = ['RetiariiExperiment']


 def preprocess_model(base_model, evaluator, applied_mutators, full_ir=True, dummy_input=None, oneshot=False):
@@ -252,9 +173,14 @@ class RetiariiExperiment(Experiment):
         ...     final_model = Net()
     """

-    def __init__(self, base_model: nn.Module, evaluator: Union[BaseOneShotTrainer, Evaluator] = cast(Evaluator, None),
-                 applied_mutators: List[Mutator] = cast(List[Mutator], None), strategy: BaseStrategy = cast(BaseStrategy, None),
+    def __init__(self, base_model: nn.Module,
+                 evaluator: Union[BaseOneShotTrainer, Evaluator] = cast(Evaluator, None),
+                 applied_mutators: List[Mutator] = cast(List[Mutator], None),
+                 strategy: BaseStrategy = cast(BaseStrategy, None),
                  trainer: BaseOneShotTrainer = cast(BaseOneShotTrainer, None)):
+        super().__init__(None)
+        self.config: RetiariiExeConfig = cast(RetiariiExeConfig, None)
+
         if trainer is not None:
             warnings.warn('Usage of `trainer` in RetiariiExperiment is deprecated and will be removed soon. '
                           'Please consider specifying it as a positional argument, or use `evaluator`.', DeprecationWarning)
@@ -263,25 +189,13 @@ class RetiariiExperiment(Experiment):
         if evaluator is None:
             raise ValueError('Evaluator should not be none.')

-        # TODO: The current design of init interface of Retiarii experiment needs to be reviewed.
-        self.config: RetiariiExeConfig = cast(RetiariiExeConfig, None)
-        self.port: Optional[int] = None
-
         self.base_model = base_model
         self.evaluator: Union[Evaluator, BaseOneShotTrainer] = evaluator
         self.applied_mutators = applied_mutators
         self.strategy = strategy

-        from nni.retiarii.oneshot.pytorch.strategy import OneShotStrategy
-        if not isinstance(strategy, OneShotStrategy):
-            # FIXME: Dispatcher should not be created this early.
-            self._dispatcher = RetiariiAdvisor('_placeholder_')
-        else:
-            self._dispatcher = cast(RetiariiAdvisor, None)
-        self._dispatcher_thread: Optional[Thread] = None
-        self._proc: Optional[Popen] = None
-
-        self.url_prefix = None
+        self._dispatcher = None
+        self._dispatcher_thread = None

         # check for sanity
         if not is_model_wrapped(base_model):
@@ -290,11 +204,12 @@ class RetiariiExperiment(Experiment):
                           'but it may cause inconsistent behavior compared to the time when you add it.' + colorama.Style.RESET_ALL,
                           RuntimeWarning)

-    def _start_strategy(self):
+    def _run_strategy(self, config: RetiariiExeConfig):
         base_model_ir, self.applied_mutators = preprocess_model(
             self.base_model, self.evaluator, self.applied_mutators,
-            full_ir=self.config.execution_engine not in ['py', 'benchmark'],
-            dummy_input=self.config.dummy_input
+            full_ir=not isinstance(config.execution_engine, (PyEngineConfig, BenchmarkEngineConfig)),
+            dummy_input=config.execution_engine.dummy_input
+            if isinstance(config.execution_engine, (BaseEngineConfig, CgoEngineConfig)) else None
         )

         _logger.info('Start strategy...')
@@ -303,102 +218,49 @@ class RetiariiExperiment(Experiment):
         self.strategy.run(base_model_ir, self.applied_mutators)
         _logger.info('Strategy exit')
         # TODO: find out a proper way to show no more trial message on WebUI
         # self._dispatcher.mark_experiment_as_ending()

-    def start(self, port: int = 8080, debug: bool = False) -> None:
-        """
-        Start the experiment in background.
-        This method will raise exception on failure.
-        If it returns, the experiment should have been successfully started.
-        Parameters
-        ----------
-        port
-            The port of web UI.
-        debug
-            Whether to start in debug mode.
-        """
-        atexit.register(self.stop)
-
-        self.config = self.config.canonical_copy()
-
-        # we will probably need a execution engine factory to make this clean and elegant
-        if self.config.execution_engine == 'base':
+    def _create_execution_engine(self, config: RetiariiExeConfig) -> None:
+        #TODO: we will probably need a execution engine factory to make this clean and elegant
+        if isinstance(config.execution_engine, BaseEngineConfig):
             from ..execution.base import BaseExecutionEngine
-            engine = BaseExecutionEngine()
-        elif self.config.execution_engine == 'cgo':
+            engine = BaseExecutionEngine(self.port, self.url_prefix)
+        elif isinstance(config.execution_engine, CgoEngineConfig):
             from ..execution.cgo_engine import CGOExecutionEngine

-            assert self.config.training_service.platform == 'remote', \
+            assert not isinstance(config.training_service, list) \
+                and config.training_service.platform == 'remote', \
                 "CGO execution engine currently only supports remote training service"
-            assert self.config.batch_waiting_time is not None and self.config.max_concurrency_cgo is not None
-            devices = self._construct_devices()
-            engine = CGOExecutionEngine(devices,
-                                        max_concurrency=self.config.max_concurrency_cgo,
-                                        batch_waiting_time=self.config.batch_waiting_time)
-        elif self.config.execution_engine == 'py':
+            assert config.execution_engine.batch_waiting_time is not None \
+                and config.execution_engine.max_concurrency_cgo is not None
+            engine = CGOExecutionEngine(cast(RemoteConfig, config.training_service),
+                                        max_concurrency=config.execution_engine.max_concurrency_cgo,
+                                        batch_waiting_time=config.execution_engine.batch_waiting_time,
+                                        rest_port=self.port,
+                                        rest_url_prefix=self.url_prefix)
+        elif isinstance(config.execution_engine, PyEngineConfig):
             from ..execution.python import PurePythonExecutionEngine
-            engine = PurePythonExecutionEngine()
-        elif self.config.execution_engine == 'benchmark':
+            engine = PurePythonExecutionEngine(self.port, self.url_prefix)
+        elif isinstance(config.execution_engine, BenchmarkEngineConfig):
             from ..execution.benchmark import BenchmarkExecutionEngine
-            assert self.config.benchmark is not None, '"benchmark" must be set when benchmark execution engine is used.'
-            engine = BenchmarkExecutionEngine(self.config.benchmark)
+            assert config.execution_engine.benchmark is not None, \
+                '"benchmark" must be set when benchmark execution engine is used.'
+            engine = BenchmarkExecutionEngine(config.execution_engine.benchmark)
         else:
-            raise ValueError(f'Unsupported engine type: {self.config.execution_engine}')
+            raise ValueError(f'Unsupported engine type: {config.execution_engine}')
         set_execution_engine(engine)

-        self.id = management.generate_experiment_id()
+    def start(self, *args, **kwargs) -> None:
+        """
+        By design, the only different between `start` and `run` is that `start` is asynchronous,
+        while `run` waits the experiment to complete. RetiariiExperiment always waits the experiment
+        to complete as strategy runs in foreground.
+        """
+        raise NotImplementedError('RetiariiExperiment is not supposed to provide `start` method')
+
-        log_file = Path(self.config.experiment_working_directory, self.id, 'log', 'experiment.log')
-        log_file.parent.mkdir(parents=True, exist_ok=True)
-        log_level = 'debug' if (debug or self.config.log_level == 'trace') else self.config.log_level
-        nni.runtime.log.start_experiment_logging(self.id, log_file, cast(str, log_level))
-
-        ws_url = f'ws://localhost:{port}/tuner'
-        self._proc = launcher.start_experiment('create', self.id, self.config, port, debug,  # type: ignore
-                                               RunMode.Background, None, ws_url, ['retiarii'])
-        assert self._proc is not None
-
-        self.port = port  # port will be None if start up failed
-
-        # dispatcher must be launched after pipe initialized
-        # the logic to launch dispatcher in background should be refactored into dispatcher api
-        self._dispatcher = self._create_dispatcher()
-        if self._dispatcher is not None:
-            self._dispatcher._channel = TunerCommandChannel(ws_url)
-            self._dispatcher_thread = Thread(target=self._dispatcher.run)
-            self._dispatcher_thread.start()
-
-        ips = [self.config.nni_manager_ip]
-        for interfaces in psutil.net_if_addrs().values():
-            for interface in interfaces:
-                if interface.family == socket.AF_INET:
-                    ips.append(interface.address)
-        ips = [f'http://{ip}:{port}' for ip in ips if ip]
-        msg = 'Web UI URLs: ' + colorama.Fore.CYAN + ' '.join(ips) + colorama.Style.RESET_ALL
-        _logger.info(msg)
-
-        exp_status_checker = Thread(target=self._check_exp_status)
-        exp_status_checker.start()
-        self._start_strategy()
-        # TODO: the experiment should be completed, when strategy exits and there is no running job
-        _logger.info('Waiting for experiment to become DONE (you can ctrl+c if there is no running trial jobs)...')
-        exp_status_checker.join()
-
-    def _construct_devices(self):
-        devices = []
-        if hasattr(self.config.training_service, 'machine_list'):
-            for machine in cast(RemoteConfig, self.config.training_service).machine_list:
-                assert machine.gpu_indices is not None, \
-                    'gpu_indices must be set in RemoteMachineConfig for CGO execution engine'
-                assert isinstance(machine.gpu_indices, list), 'gpu_indices must be a list'
-                for gpu_idx in machine.gpu_indices:
-                    devices.append(GPUDevice(machine.host, gpu_idx))
-        return devices
-
-    def _create_dispatcher(self):
-        return self._dispatcher
-
-    def run(self, config: Optional[RetiariiExeConfig] = None, port: int = 8080, debug: bool = False) -> None:
+    def run(self,
+            config: RetiariiExeConfig | None = None,
+            port: int = 8080,
+            debug: bool = False) -> None:
         """
         Run the experiment.
         This function will block until experiment finish or error.
@@ -410,75 +272,47 @@ class RetiariiExperiment(Experiment):
             # 'In case you want to stick to the old implementation, '
             # 'please consider using ``trainer.fit()`` instead of experiment.', DeprecationWarning)
             self.evaluator.fit()
             return

         if config is None:
             warnings.warn('config = None is deprecate in future. If you are running a one-shot experiment, '
                           'please consider creating a config and set execution engine to `oneshot`.', DeprecationWarning)
-            config = RetiariiExeConfig()
-            config.execution_engine = 'oneshot'
+            self.config = RetiariiExeConfig()
+            self.config.execution_engine = OneshotEngineConfig()
+        else:
+            self.config = config

-        if config.execution_engine == 'oneshot':
+        if isinstance(self.config.execution_engine, OneshotEngineConfig) \
+            or (isinstance(self.config.execution_engine, str) and self.config.execution_engine == 'oneshot'):
             # this is hacky, will be refactored when oneshot can run on training services
             base_model_ir, self.applied_mutators = preprocess_model(self.base_model, self.evaluator, self.applied_mutators, oneshot=True)
             self.strategy.run(base_model_ir, self.applied_mutators)
         else:
-            assert config is not None, 'You are using classic search mode, config cannot be None!'
-            self.config = config
-            self.start(port, debug)
-
-    def _check_exp_status(self) -> bool:
-        """
-        Run the experiment.
-        This function will block until experiment finish or error.
-        Return `True` when experiment done; or return `False` when experiment failed.
-        """
-        assert self._proc is not None
-        try:
-            while True:
-                time.sleep(10)
-                # this if is to deal with the situation that
-                # nnimanager is cleaned up by ctrl+c first
-                if self._proc.poll() is None:
-                    status = self.get_status()
-                else:
-                    return False
-                if status == 'DONE' or status == 'STOPPED':
-                    return True
-                if status == 'ERROR':
-                    return False
-        except KeyboardInterrupt:
-            _logger.warning('KeyboardInterrupt detected')
-        finally:
-            self.stop()
-        raise RuntimeError('Check experiment status failed.')
+            ws_url = f'ws://localhost:{port}/tuner'
+            canonicalized_config = self._start_impl(port, debug, RunMode.Background, ws_url, ['retiarii'])
+            canonicalized_config = cast(RetiariiExeConfig, canonicalized_config)
+            self._dispatcher = RetiariiAdvisor(ws_url)
+            self._dispatcher_thread = Thread(target=self._dispatcher.run, daemon=True)
+            self._dispatcher_thread.start()
+            # FIXME: engine cannot be created twice
+            self._create_execution_engine(canonicalized_config)
+            try:
+                self._run_strategy(canonicalized_config)
+                # FIXME: move this logic to strategy with a new API provided by execution engine
+                self._wait_completion()
+            except KeyboardInterrupt:
+                _logger.warning('KeyboardInterrupt detected')
+                self.stop()
+            _logger.info('Search process is done, the experiment is still alive, `stop()` can terminate the experiment.')

     def stop(self) -> None:
         """
         Stop background experiment.
         """
         _logger.info('Stopping experiment, please wait...')
-        atexit.unregister(self.stop)
-
-        # stop strategy first
-        if self._dispatcher_thread is not None:
-            self._dispatcher.stopping = True
-            self._dispatcher_thread.join(timeout=1)
-
-        if self.id is not None:
-            nni.runtime.log.stop_experiment_logging(self.id)
-        if self._proc is not None:
-            try:
-                # this if is to deal with the situation that
-                # nnimanager is cleaned up by ctrl+c first
-                if self._proc.poll() is None:
-                    rest.delete(self.port, '/experiment')
-            except Exception as e:
-                _logger.exception(e)
-                _logger.warning('Cannot gracefully stop experiment, killing NNI process...')
-                kill_command(self._proc.pid)
-
-        self.id = cast(str, None)
-        self.port = cast(int, None)
-        self._proc = None
+        self._stop_impl()
+        if self._dispatcher_thread:
+            self._dispatcher_thread.join()
+        self._dispatcher = cast(RetiariiAdvisor, None)
+        self._dispatcher_thread = None
         _logger.info('Experiment stopped')
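Taken together, ``run`` now owns the whole lifecycle that ``start`` used to: ``_start_impl``, the dispatcher thread, engine creation, the strategy, then ``_wait_completion``. A hedged end-to-end sketch (model, evaluator, mutators and strategy elided):

    from nni.retiarii.experiment.pytorch import RetiariiExperiment, RetiariiExeConfig

    exp = RetiariiExperiment(base_model, evaluator, applied_mutators, strategy)
    config = RetiariiExeConfig('local', execution_engine='base')
    config.trial_concurrency = 2
    exp.run(config, port=8081)  # blocks until the strategy finishes
    exp.stop()                  # the experiment stays alive after run(); stop it explicitly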
@@ -502,8 +336,11 @@ class RetiariiExperiment(Experiment):
         If ``code``, the python code of model will be returned.
         If ``dict``, the mutation history will be returned.
         """
+        # TODO: the base class may also need this method
         if formatter == 'code':
-            assert self.config.execution_engine != 'py', 'You should use `dict` formatter when using Python execution engine.'
+            config = self.config.canonical_copy()
+            assert not isinstance(config.execution_engine, PyEngineConfig), \
+                'You should use `dict` formatter when using Python execution engine.'
         if isinstance(self.evaluator, BaseOneShotTrainer):
             assert top_k == 1, 'Only support top_k is 1 for now.'
             return self.evaluator.export()
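The formatter check now inspects the canonicalized engine config instead of a raw string:

    models = exp.export_top_models(top_k=1, formatter='dict')
    # formatter='code' asserts the engine is graph-based (not PyEngineConfig).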
@@ -520,9 +357,3 @@ class RetiariiExperiment(Experiment):
             return [model_to_pytorch_script(model) for model in all_models[:top_k]]
         elif formatter == 'dict':
             return [get_mutation_dict(model) for model in all_models[:top_k]]
-
-    def retrain_model(self, model):
-        """
-        this function retrains the exported model, and test it to output test accuracy
-        """
-        raise NotImplementedError
@@ -22,7 +22,10 @@ def get_advisor() -> 'RetiariiAdvisor':

 def register_advisor(advisor: 'RetiariiAdvisor'):
     global _advisor
-    assert _advisor is None
+    if _advisor is not None:
+        warnings.warn('Advisor is already set.'
+                      'You should avoid instantiating RetiariiExperiment twice in one proces.'
+                      'If you are running in a Jupyter notebook, please restart the kernel.')
     _advisor = advisor


@@ -18,8 +18,15 @@ _worker_fast_exit_on_terminate = True


 class MsgDispatcherBase(Recoverable):
-    """This is where tuners and assessors are not defined yet.
+    """
+    This is where tuners and assessors are not defined yet.
+    Inherits this class to make your own advisor.
+
+    .. note::
+
+        The class inheriting MsgDispatcherBase should be instantiated
+        after nnimanager (rest server) is started, so that the object
+        is ready to use right after its instantiation.
     """

     def __init__(self, command_channel_url=None):
@@ -27,6 +34,16 @@ class MsgDispatcherBase(Recoverable):
         if command_channel_url is None:
             command_channel_url = dispatcher_env_vars.NNI_TUNER_COMMAND_CHANNEL
         self._channel = TunerCommandChannel(command_channel_url)
+        # NOTE: `connect()` should be put in __init__. First, this `connect()` affects nnimanager's
+        # starting process, without `connect()` nnimanager is blocked in `dispatcher.init()`.
+        # Second, nas experiment uses a thread to execute `run()` of this class, thus, there is
+        # no way to know when the websocket between nnimanager and dispatcher is built. The following
+        # logic may crash is websocket is not built. One example is updating search space. If updating
+        # search space too soon, as the websocket has not been built, the rest api of updating search
+        # space will timeout.
+        # FIXME: this is making unittest happy
+        if not command_channel_url.startswith('ws://_unittest_'):
+            self._channel.connect()
         self.default_command_queue = Queue()
         self.assessor_command_queue = Queue()
         self.default_worker = threading.Thread(target=self.command_queue_worker, args=(self.default_command_queue,))
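The ``ws://_unittest_`` prefix lets tests construct a dispatcher without a live nnimanager and swap in a stub channel afterwards, which is exactly what the updated tests below do:

    advisor = RetiariiAdvisor('ws://_unittest_placeholder_')  # connect() is skipped
    advisor._channel = LegacyCommandChannel()                 # stub channel for tests
    advisor.default_worker.start()
    advisor.assessor_worker.start()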
@@ -39,7 +56,6 @@ class MsgDispatcherBase(Recoverable):
         """
         _logger.info('Dispatcher started')

-        self._channel.connect()
         self.default_worker.start()
         self.assessor_worker.start()

@@ -4,7 +4,6 @@ import warnings

 import torch
 import torch.nn as torch_nn
-from torchvision.models.utils import load_state_dict_from_url
 import torch.nn.functional as F

 import sys
@@ -8,7 +8,7 @@ import nni.retiarii.evaluator.pytorch.cgo.evaluator as cgo
 from nni.retiarii import serialize
 from base_mnasnet import MNASNet
 from nni.experiment import RemoteMachineConfig
-from nni.retiarii.experiment.pytorch import RetiariiExperiment, RetiariiExeConfig
+from nni.retiarii.experiment.pytorch import RetiariiExperiment, RetiariiExeConfig, CgoEngineConfig
 from nni.retiarii.strategy import TPEStrategy
 from torchvision import transforms
 from torchvision.datasets import CIFAR10
@@ -59,8 +59,6 @@ if __name__ == '__main__':
     exp_config.max_trial_number = 10
     exp_config.trial_gpu_number = 1
     exp_config.training_service.reuse_mode = True
-    exp_config.max_concurrency_cgo = 3
-    exp_config.batch_waiting_time = 0

     rm_conf = RemoteMachineConfig()
     rm_conf.host = '127.0.0.1'
@@ -73,6 +71,6 @@ if __name__ == '__main__':
     rm_conf.max_trial_number_per_gpu = 3

     exp_config.training_service.machine_list = [rm_conf]
-    exp_config.execution_engine = 'cgo'
+    exp_config.execution_engine = CgoEngineConfig(max_concurrency_cgo = 3, batch_waiting_time = 0)

-    exp.run(exp_config, 8099)
+    exp.run(exp_config, 8099)
@@ -9,6 +9,7 @@ from pytorch_lightning.utilities.seed import seed_everything
 from pathlib import Path

 import nni
+from nni.experiment.config import RemoteConfig, RemoteMachineConfig
 import nni.runtime.platform.test
 from nni.runtime.tuner_command_channel import legacy as protocol
 import json
@@ -263,13 +264,14 @@ class CGOEngineTest(unittest.TestCase):
         opt = DedupInputOptimizer()
         opt.convert(lp)

-        advisor = RetiariiAdvisor('ws://_placeholder_')
+        advisor = RetiariiAdvisor('ws://_unittest_placeholder_')
         advisor._channel = protocol.LegacyCommandChannel()
         advisor.default_worker.start()
         advisor.assessor_worker.start()

-        available_devices = [GPUDevice("test", 0), GPUDevice("test", 1), GPUDevice("test", 2), GPUDevice("test", 3)]
-        cgo = CGOExecutionEngine(devices=available_devices, batch_waiting_time=0)
+        remote = RemoteConfig(machine_list=[])
+        remote.machine_list.append(RemoteMachineConfig(host='test', gpu_indices=[0,1,2,3]))
+        cgo = CGOExecutionEngine(training_service=remote, batch_waiting_time=0)

         phy_models = cgo._assemble(lp)
         self.assertTrue(len(phy_models) == 1)
@@ -286,13 +288,14 @@ class CGOEngineTest(unittest.TestCase):
         opt = DedupInputOptimizer()
         opt.convert(lp)

-        advisor = RetiariiAdvisor('ws://_placeholder_')
+        advisor = RetiariiAdvisor('ws://_unittest_placeholder_')
         advisor._channel = protocol.LegacyCommandChannel()
         advisor.default_worker.start()
         advisor.assessor_worker.start()

-        available_devices = [GPUDevice("test", 0), GPUDevice("test", 1)]
-        cgo = CGOExecutionEngine(devices=available_devices, batch_waiting_time=0)
+        remote = RemoteConfig(machine_list=[])
+        remote.machine_list.append(RemoteMachineConfig(host='test', gpu_indices=[0,1]))
+        cgo = CGOExecutionEngine(training_service=remote, batch_waiting_time=0)

         phy_models = cgo._assemble(lp)
         self.assertTrue(len(phy_models) == 2)
@@ -311,13 +314,14 @@ class CGOEngineTest(unittest.TestCase):

         models = _load_mnist(2)

-        advisor = RetiariiAdvisor('ws://_placeholder_')
+        advisor = RetiariiAdvisor('ws://_unittest_placeholder_')
         advisor._channel = protocol.LegacyCommandChannel()
         advisor.default_worker.start()
         advisor.assessor_worker.start()

-        cgo_engine = CGOExecutionEngine(devices=[GPUDevice("test", 0), GPUDevice("test", 1),
-                                                 GPUDevice("test", 2), GPUDevice("test", 3)], batch_waiting_time=0)
+        remote = RemoteConfig(machine_list=[])
+        remote.machine_list.append(RemoteMachineConfig(host='test', gpu_indices=[0,1,2,3]))
+        cgo_engine = CGOExecutionEngine(training_service=remote, batch_waiting_time=0)
         set_execution_engine(cgo_engine)
         submit_models(*models)
         time.sleep(3)
@@ -25,7 +25,7 @@ class EngineTest(unittest.TestCase):
     def test_base_execution_engine(self):
         nni.retiarii.integration_api._advisor = None
         nni.retiarii.execution.api._execution_engine = None
-        advisor = RetiariiAdvisor('ws://_placeholder_')
+        advisor = RetiariiAdvisor('ws://_unittest_placeholder_')
         advisor._channel = LegacyCommandChannel()
         advisor.default_worker.start()
         advisor.assessor_worker.start()

@@ -42,7 +42,7 @@ class EngineTest(unittest.TestCase):
     def test_py_execution_engine(self):
         nni.retiarii.integration_api._advisor = None
         nni.retiarii.execution.api._execution_engine = None
-        advisor = RetiariiAdvisor('ws://_placeholder_')
+        advisor = RetiariiAdvisor('ws://_unittest_placeholder_')
         advisor._channel = LegacyCommandChannel()
         advisor.default_worker.start()
         advisor.assessor_worker.start()
@@ -57,7 +57,7 @@ class AssessorTestCase(TestCase):
         _restore_io()

         assessor = NaiveAssessor()
-        dispatcher = MsgDispatcher('ws://_placeholder_', None, assessor)
+        dispatcher = MsgDispatcher('ws://_unittest_placeholder_', None, assessor)
         dispatcher._channel = LegacyCommandChannel()
         msg_dispatcher_base._worker_fast_exit_on_terminate = False

@@ -66,7 +66,7 @@ class MsgDispatcherTestCase(TestCase):
         _restore_io()

         tuner = NaiveTuner()
-        dispatcher = MsgDispatcher('ws://_placeholder_', tuner)
+        dispatcher = MsgDispatcher('ws://_unittest_placeholder_', tuner)
         dispatcher._channel = LegacyCommandChannel()
         msg_dispatcher_base._worker_fast_exit_on_terminate = False

@@ -303,8 +303,11 @@ class NNIManager implements Manager {
         }

         this.trainingService.removeTrialJobMetricListener(this.trialJobMetricListener);
+        // NOTE: this sending TERMINATE should be out of the if clause,
+        // because when python dispatcher is started before nnimanager
+        // this.dispatcherPid would not have a valid value (i.e., not >0).
+        this.dispatcher.sendCommand(TERMINATE);
         if (this.dispatcherPid > 0) {
-            this.dispatcher.sendCommand(TERMINATE);
             // gracefully terminate tuner and assessor here, wait at most 30 seconds.
             for (let i: number = 0; i < 30; i++) {
                 if (!await isAlive(this.dispatcherPid)) {