зеркало из https://github.com/microsoft/nni.git
Tuner Command Channel (#5364)
This commit is contained in:
Родитель
4f92f30a0f
Коммит
0dce45a845
|
@ -5,5 +5,5 @@
|
|||
Low level APIs for algorithms to communicate with NNI manager.
|
||||
"""
|
||||
|
||||
from .command_type import CommandType
|
||||
from .command_type import CommandType, TunerIncomingCommand
|
||||
from .channel import TunerCommandChannel
|
||||
|
|
|
@ -10,13 +10,26 @@ from __future__ import annotations
|
|||
__all__ = ['TunerCommandChannel']
|
||||
|
||||
import logging
|
||||
import os
|
||||
import time
|
||||
from collections import defaultdict
|
||||
from threading import Event
|
||||
from typing import Any, Callable
|
||||
|
||||
from .command_type import CommandType
|
||||
from nni.common.serializer import dump, load, PayloadTooLarge
|
||||
from nni.common.version import version_dump
|
||||
from nni.typehint import Parameters
|
||||
|
||||
from .command_type import (
|
||||
CommandType, TunerIncomingCommand,
|
||||
Initialize, RequestTrialJobs, UpdateSearchSpace, ReportMetricData, TrialEnd, Terminate
|
||||
)
|
||||
from .websocket import WebSocket
|
||||
|
||||
# Module-level logger, named after this module.
_logger = logging.getLogger(__name__)
|
||||
|
||||
# Type alias for a callable that takes one TunerIncomingCommand and returns nothing.
TunerCommandCallback = Callable[[TunerIncomingCommand], None]
|
||||
|
||||
class TunerCommandChannel:
|
||||
"""
|
||||
A channel to communicate with NNI manager.
|
||||
|
@ -44,17 +57,213 @@ class TunerCommandChannel:
|
|||
self._channel = WebSocket(url)
|
||||
self._retry_intervals = [0, 1, 10]
|
||||
|
||||
self._callbacks: dict[CommandType, list[Callable[..., None]]] = defaultdict(list)
|
||||
|
||||
def connect(self) -> None:
    """Connect the underlying channel by delegating to its ``connect()``."""
    channel = self._channel
    channel.connect()
|
||||
|
||||
def disconnect(self) -> None:
    """Disconnect the underlying channel by delegating to its ``disconnect()``."""
    channel = self._channel
    channel.disconnect()
|
||||
|
||||
# TODO: Define semantic command class like `KillTrialJob(trial_id='abc')`.
|
||||
# def send(self, command: Command) -> None:
|
||||
# ...
|
||||
# def receive(self) -> Command | None:
|
||||
# ...
|
||||
def listen(self, stop_event: Event) -> None:
    """Receive commands in a loop and dispatch them to the registered callbacks.

    Each iteration calls :meth:`receive` once and then invokes, in registration
    order, every callback stored for the received command's type (see the
    ``on_*`` methods). The Terminate command is dispatched like any other
    command before the loop is stopped.

    The loop ends when ``stop_event`` is set, either by another thread or by
    this method itself after a Terminate command has been dispatched.
    ``stop_event`` is only checked between commands.

    It usually runs in a separate thread.

    Parameters
    ----------
    stop_event
        A threading event that can be used to stop the loop.
        It is also set by this method when a Terminate command is received.
    """
    while not stop_event.is_set():
        command = self.receive()

        handlers = self._callbacks[command.command_type]
        for handler in handlers:
            handler(command)

        # Two ways to stop the loop:
        # 1. The received command is a Terminate command, which is triggered by a NNI manager stop.
        # 2. The stop_event is set from another thread (possibly main thread), which could be an engine shutdown.
        if command.command_type != CommandType.Terminate:
            continue

        _logger.debug('Received command type is terminate. Stop listening.')
        stop_event.set()
|
||||
|
||||
# NOTE: Only a subset of the semantic commands is defined, for the convenience of the NAS implementation.
|
||||
# Send commands are broken into different functions and signatures.
|
||||
# Ideally it should be similar for receive commands, but we can't predict which command will appear in receive.
|
||||
|
||||
def send_initialized(self) -> None:
    """Notify NNI manager with an ``Initialized`` command carrying an empty payload."""
    empty_payload = ''
    self._send(CommandType.Initialized, empty_payload)
|
||||
|
||||
def send_trial(
    self,
    parameter_id: int,
    parameters: Parameters,
    parameter_source: str = 'algorithm',
    parameter_index: int = 0,
    placement_constraint: dict[str, Any] | None = None,  # TODO: Define PlacementConstraint class.
) -> None:
    """
    Send a new trial job to NNI manager.

    Without multi-phase in mind, one parameter = one trial.

    Parameters
    ----------
    parameter_id
        The ID of the current parameter.
        It's used by whoever calls the :meth:`send_trial` function to identify the parameters.
        In most cases, they are non-negative integers starting from 0.
    parameters
        The parameters.
    parameter_source
        The source of the parameters. ``algorithm`` means the parameters are generated by the algorithm.
        It should be left as default in most cases.
    parameter_index
        The index of the parameters. This is previously used in multi-phase, but now it's only kept for compatibility reasons.
    placement_constraint
        The placement constraint of the created trial job.
        When given, it is validated before being attached to the trial.

    Raises
    ------
    ValueError
        If ``placement_constraint`` is invalid, or if the serialized trial exceeds the
        size limit (``PICKLE_SIZE_LIMIT`` environment variable, ``64 * 1024`` by default).
    """
    trial_dict = {
        'parameter_id': parameter_id,
        'parameters': parameters,
        'parameter_source': parameter_source,
        'parameter_index': parameter_index,
        'version_info': version_dump()
    }
    if placement_constraint is not None:
        _validate_placement_constraint(placement_constraint)
        trial_dict['placement_constraint'] = placement_constraint

    # The limit can be overridden through the environment, so keep the effective
    # value around to report it accurately if serialization fails.
    pickle_size_limit = int(os.getenv('PICKLE_SIZE_LIMIT', 64 * 1024))

    try:
        send_payload = dump(trial_dict, pickle_size_limit=pickle_size_limit)
    except PayloadTooLarge as err:
        # Chain explicitly: the message points the user to the full traceback.
        raise ValueError(
            'Serialization failed when trying to dump the model because payload too large '
            f'(larger than {pickle_size_limit} bytes). '
            'This is usually caused by pickling large objects (like datasets) by mistake. '
            'See the full error traceback for details and https://nni.readthedocs.io/en/stable/NAS/Serialization.html '
            'for how to resolve such issue. '
        ) from err

    self._send(CommandType.NewTrialJob, send_payload)
|
||||
|
||||
def send_no_more_trial_jobs(self) -> None:
    """Notify NNI manager, with an empty-payload ``NoMoreTrialJobs`` command, that no more trial jobs will be sent for now."""
    empty_payload = ''
    self._send(CommandType.NoMoreTrialJobs, empty_payload)
|
||||
|
||||
def receive(self) -> TunerIncomingCommand:
    """Receive one command from NNI manager and convert it to a semantic command.

    A non-empty payload is deserialized with ``load`` before the command type is
    inspected. An empty command (``command_type is None``) is reported as
    :class:`Terminate`.

    Returns
    -------
    TunerIncomingCommand
        One of :class:`Terminate`, :class:`Initialize`, :class:`RequestTrialJobs`,
        :class:`UpdateSearchSpace`, :class:`ReportMetricData` or :class:`TrialEnd`.

    Raises
    ------
    TypeError
        If the payload does not have the type expected for the command.
    ValueError
        If the command type is not one of the handled types.
    """
    command_type, payload = self._receive()
    if payload:
        payload = load(payload)

    # NOTE: Only handles the commands that are used by NAS.
    # It uses somewhat hacky way to convert the data received from NNI manager
    # to a semantic command.
    if command_type is None:
        # This shouldn't happen. Only for robustness.
        _logger.warning('Received command is empty. Terminating...')
        return Terminate()

    if command_type == CommandType.Terminate:
        return Terminate()

    if command_type == CommandType.Initialize:
        if not isinstance(payload, dict):
            raise TypeError(f'Initialize command data must be a dict, but got {type(payload)}')
        return Initialize(payload)

    if command_type == CommandType.RequestTrialJobs:
        if not isinstance(payload, int):
            raise TypeError(f'RequestTrialJobs command data must be an integer, but got {type(payload)}')
        return RequestTrialJobs(payload)

    if command_type == CommandType.UpdateSearchSpace:
        if not isinstance(payload, dict):
            raise TypeError(f'UpdateSearchSpace command data must be a dict, but got {type(payload)}')
        return UpdateSearchSpace(payload)

    if command_type == CommandType.ReportMetricData:
        if not isinstance(payload, dict):
            raise TypeError(f'ReportMetricData command data must be a dict, but got {type(payload)}')
        # The metric value is itself serialized inside the payload; decode it in place.
        if 'value' in payload:
            payload['value'] = load(payload['value'])
        return ReportMetricData(**payload)

    if command_type == CommandType.TrialEnd:
        if not isinstance(payload, dict):
            raise TypeError(f'TrialEnd command data must be a dict, but got {type(payload)}')
        # For some reason, only one parameter (I guess the first one) shows up in the data.
        # But a trial technically is associated with multiple parameters.
        hyper_params = load(payload['hyper_params'])
        parameter_id = hyper_params['parameter_id']
        return TrialEnd(
            trial_job_id=payload['trial_job_id'],
            parameter_ids=[parameter_id],
            event=payload['event']
        )

    raise ValueError(f'Unknown command type: {command_type}')
|
||||
|
||||
def on_terminate(self, callback: Callable[[Terminate], None]) -> None:
    """Subscribe ``callback`` to Terminate commands.

    The callback is appended to the list of callbacks kept for the Terminate
    command type; :meth:`listen` calls each of them with the received command.

    Parameters
    ----------
    callback
        A callback function that takes a :class:`Terminate` as argument.
    """
    handlers = self._callbacks[Terminate.command_type]
    handlers.append(callback)
|
||||
|
||||
def on_initialize(self, callback: Callable[[Initialize], None]) -> None:
    """Subscribe ``callback`` to Initialize commands.

    The callback is appended to the list of callbacks kept for the Initialize
    command type; :meth:`listen` calls each of them with the received command.

    Parameters
    ----------
    callback
        A callback function that takes a :class:`Initialize` as argument.
    """
    handlers = self._callbacks[Initialize.command_type]
    handlers.append(callback)
|
||||
|
||||
def on_request_trial_jobs(self, callback: Callable[[RequestTrialJobs], None]) -> None:
    """Subscribe ``callback`` to RequestTrialJobs commands.

    The callback is appended to the list of callbacks kept for the RequestTrialJobs
    command type; :meth:`listen` calls each of them with the received command.

    Parameters
    ----------
    callback
        A callback function that takes a :class:`RequestTrialJobs` as argument.
    """
    handlers = self._callbacks[RequestTrialJobs.command_type]
    handlers.append(callback)
|
||||
|
||||
def on_update_search_space(self, callback: Callable[[UpdateSearchSpace], None]) -> None:
    """Subscribe ``callback`` to UpdateSearchSpace commands.

    The callback is appended to the list of callbacks kept for the UpdateSearchSpace
    command type; :meth:`listen` calls each of them with the received command.

    Parameters
    ----------
    callback
        A callback function that takes a :class:`UpdateSearchSpace` as argument.
    """
    handlers = self._callbacks[UpdateSearchSpace.command_type]
    handlers.append(callback)
|
||||
|
||||
def on_report_metric_data(self, callback: Callable[[ReportMetricData], None]) -> None:
    """Subscribe ``callback`` to ReportMetricData commands.

    The callback is appended to the list of callbacks kept for the ReportMetricData
    command type; :meth:`listen` calls each of them with the received command.

    Parameters
    ----------
    callback
        A callback function that takes a :class:`ReportMetricData` as argument.
    """
    handlers = self._callbacks[ReportMetricData.command_type]
    handlers.append(callback)
|
||||
|
||||
def on_trial_end(self, callback: Callable[[TrialEnd], None]) -> None:
    """Subscribe ``callback`` to TrialEnd commands.

    The callback is appended to the list of callbacks kept for the TrialEnd
    command type; :meth:`listen` calls each of them with the received command.

    Parameters
    ----------
    callback
        A callback function that takes a :class:`TrialEnd` as argument.
    """
    handlers = self._callbacks[TrialEnd.command_type]
    handlers.append(callback)
|
||||
|
||||
def _send(self, command_type: CommandType, data: str) -> None:
|
||||
command = command_type.value.decode() + data
|
||||
|
@ -112,3 +321,29 @@ class TunerCommandChannel:
|
|||
return command
|
||||
_logger.error('Failed to reconnect.')
|
||||
raise RuntimeError('Connection lost')
|
||||
|
||||
|
||||
def _validate_placement_constraint(placement_constraint):
|
||||
# Currently only for CGO.
|
||||
if placement_constraint is None:
|
||||
raise ValueError('placement_constraint is None')
|
||||
if not 'type' in placement_constraint:
|
||||
raise ValueError('placement_constraint must have `type`')
|
||||
if not 'gpus' in placement_constraint:
|
||||
raise ValueError('placement_constraint must have `gpus`')
|
||||
if placement_constraint['type'] not in ['None', 'GPUNumber', 'Device']:
|
||||
raise ValueError('placement_constraint.type must be either `None`,. `GPUNumber` or `Device`')
|
||||
if placement_constraint['type'] == 'None' and len(placement_constraint['gpus']) > 0:
|
||||
raise ValueError('placement_constraint.gpus must be an empty list when type == None')
|
||||
if placement_constraint['type'] == 'GPUNumber':
|
||||
if len(placement_constraint['gpus']) != 1:
|
||||
raise ValueError('placement_constraint.gpus currently only support one host when type == GPUNumber')
|
||||
for e in placement_constraint['gpus']:
|
||||
if not isinstance(e, int):
|
||||
raise ValueError('placement_constraint.gpus must be a list of number when type == GPUNumber')
|
||||
if placement_constraint['type'] == 'Device':
|
||||
for e in placement_constraint['gpus']:
|
||||
if not isinstance(e, tuple):
|
||||
raise ValueError('placement_constraint.gpus must be a list of tuple when type == Device')
|
||||
if not (len(e) == 2 and isinstance(e[0], str) and isinstance(e[1], int)):
|
||||
raise ValueError('placement_constraint.gpus`s tuple must be (str, int)')
|
||||
|
|
|
@ -2,6 +2,11 @@
|
|||
# Licensed under the MIT license.
|
||||
|
||||
from enum import Enum
|
||||
from dataclasses import dataclass
|
||||
from typing import List, Optional, ClassVar
|
||||
|
||||
from nni.typehint import TrialMetric, SearchSpace
|
||||
from nni.utils import MetricType
|
||||
|
||||
class CommandType(Enum):
|
||||
# in
|
||||
|
@ -22,3 +27,45 @@ class CommandType(Enum):
|
|||
NoMoreTrialJobs = b'NO'
|
||||
KillTrialJob = b'KI'
|
||||
Error = b'ER'
|
||||
|
||||
class TunerIncomingCommand:
    """Common base class of the semantic commands defined in this module.

    Each concrete command subclass sets ``command_type`` to its
    :class:`CommandType` member.
    """

    # Declared without a value here, for type checking only; subclasses provide it.
    command_type: ClassVar[CommandType]
|
||||
|
||||
# Only necessary commands to make NAS work.
|
||||
|
||||
@dataclass
class Initialize(TunerIncomingCommand):
    """Semantic command for ``CommandType.Initialize``.

    Attributes
    ----------
    search_space
        The search space carried by the command.
    """

    command_type: ClassVar[CommandType] = CommandType.Initialize
    search_space: SearchSpace
|
||||
|
||||
@dataclass
class RequestTrialJobs(TunerIncomingCommand):
    """Semantic command for ``CommandType.RequestTrialJobs``.

    Attributes
    ----------
    count
        The integer carried by the command.
    """

    command_type: ClassVar[CommandType] = CommandType.RequestTrialJobs
    count: int
|
||||
|
||||
@dataclass
class UpdateSearchSpace(TunerIncomingCommand):
    """Semantic command for ``CommandType.UpdateSearchSpace``.

    Attributes
    ----------
    search_space
        The search space carried by the command.
    """

    command_type: ClassVar[CommandType] = CommandType.UpdateSearchSpace
    search_space: SearchSpace
|
||||
|
||||
@dataclass
class ReportMetricData(TunerIncomingCommand):
    """Semantic command for ``CommandType.ReportMetricData``.

    Attributes
    ----------
    parameter_id
        Parameter ID.
    type
        Request parameter, periodical, or final.
    sequence
        Sequence number of the metric.
    value
        The metric value. When type is NOT request parameter.
    trial_job_id
        Only available when type is request parameter.
    parameter_index
        Only available when type is request parameter.
    """

    command_type: ClassVar[CommandType] = CommandType.ReportMetricData
    parameter_id: int
    type: MetricType
    sequence: int
    value: Optional[TrialMetric] = None
    trial_job_id: Optional[str] = None
    parameter_index: Optional[int] = None
|
||||
|
||||
@dataclass
class TrialEnd(TunerIncomingCommand):
    """Semantic command for ``CommandType.TrialEnd``.

    Attributes
    ----------
    trial_job_id
        The trial job id.
    parameter_ids
        All parameter ids of the trial job.
    event
        The job's state.
    """

    command_type: ClassVar[CommandType] = CommandType.TrialEnd
    trial_job_id: str
    parameter_ids: List[int]
    event: str
|
||||
|
||||
@dataclass
class Terminate(TunerIncomingCommand):
    """Semantic command for ``CommandType.Terminate``; it carries no data fields."""

    command_type: ClassVar[CommandType] = CommandType.Terminate
|
||||
|
|
Загрузка…
Ссылка в новой задаче