This commit is contained in:
Yuge Zhang 2023-02-27 10:35:37 +08:00 committed by GitHub
Parent 4f92f30a0f
Commit 0dce45a845
No key found matching this signature
GPG key ID: 4AEE18F83AFDEB23
3 changed files: 289 additions and 7 deletions

View file

@@ -5,5 +5,5 @@
Low level APIs for algorithms to communicate with NNI manager.
"""
from .command_type import CommandType
from .command_type import CommandType, TunerIncomingCommand
from .channel import TunerCommandChannel
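For reference, a minimal sketch of how an algorithm might import the two exported names (the package path nni.runtime.tuner_command_channel is an assumption based on the module docstring, not stated in this diff):

from nni.runtime.tuner_command_channel import TunerCommandChannel, TunerIncomingCommand

def handle_command(command: TunerIncomingCommand) -> None:
    # Hypothetical handler; real handlers are registered via the on_* methods shown below.
    print(command.command_type)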

View file

@@ -10,13 +10,26 @@ from __future__ import annotations
__all__ = ['TunerCommandChannel']
import logging
import os
import time
from collections import defaultdict
from threading import Event
from typing import Any, Callable
from .command_type import CommandType
from nni.common.serializer import dump, load, PayloadTooLarge
from nni.common.version import version_dump
from nni.typehint import Parameters
from .command_type import (
CommandType, TunerIncomingCommand,
Initialize, RequestTrialJobs, UpdateSearchSpace, ReportMetricData, TrialEnd, Terminate
)
from .websocket import WebSocket
_logger = logging.getLogger(__name__)
TunerCommandCallback = Callable[[TunerIncomingCommand], None]
class TunerCommandChannel:
"""
A channel to communicate with NNI manager.
@@ -44,17 +57,213 @@ class TunerCommandChannel:
self._channel = WebSocket(url)
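# Presumably the wait times (in seconds) between reconnection attempts; once these are exhausted, the reconnect logic below gives up with 'Connection lost'.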
self._retry_intervals = [0, 1, 10]
self._callbacks: dict[CommandType, list[Callable[..., None]]] = defaultdict(list)
def connect(self) -> None:
self._channel.connect()
def disconnect(self) -> None:
self._channel.disconnect()
# TODO: Define semantic command class like `KillTrialJob(trial_id='abc')`.
# def send(self, command: Command) -> None:
# ...
# def receive(self) -> Command | None:
# ...
def listen(self, stop_event: Event) -> None:
"""Listen for incoming commands.
Call :meth:`receive` in a loop and call ``callback`` for each command,
until ``stop_event`` is set, or a Terminate command is received.
All commands will go into callback, including Terminate command.
It usually runs in a separate thread.
Parameters
----------
callback
A callback function that takes a :class:`TunerIncomingCommand` as argument.
It's not expected to return anything.
stop_event
A threading event that can be used to stop the loop.
"""
while not stop_event.is_set():
received = self.receive()
for callback in self._callbacks[received.command_type]:
callback(received)
# Two ways to stop the loop:
# 1. The received command is a Terminate command, which is triggered by an NNI manager stop.
# 2. The stop_event is set from another thread (possibly main thread), which could be an engine shutdown.
if received.command_type == CommandType.Terminate:
_logger.debug('Received command type is terminate. Stop listening.')
stop_event.set()
# NOTE: The semantic commands are only a partial set, for the convenience of the NAS implementation.
# Send commands are broken into different functions and signatures.
# Ideally it should be similar for receive commands, but we can't predict which command will appear in receive.
def send_initialized(self) -> None:
"""Send an initialized command to NNI manager."""
self._send(CommandType.Initialized, '')
def send_trial(
self,
parameter_id: int,
parameters: Parameters,
parameter_source: str = 'algorithm',
parameter_index: int = 0,
placement_constraint: dict[str, Any] | None = None, # TODO: Define PlacementConstraint class.
):
"""
Send a new trial job to NNI manager.
Since multi-phase is not considered, one set of parameters corresponds to one trial.
Parameters
----------
parameter_id
The ID of the current parameter.
It's used by whoever calls the :meth:`send_trial` function to identify the parameters.
In most cases, they are non-negative integers starting from 0.
parameters
The parameters.
parameter_source
The source of the parameters. ``algorithm`` means the parameters are generated by the algorithm.
It should be left as default in most cases.
parameter_index
The index of the parameters. This was previously used in multi-phase, but is now kept only for compatibility reasons.
placement_constraint
The placement constraint of the created trial job.
"""
trial_dict = {
'parameter_id': parameter_id,
'parameters': parameters,
'parameter_source': parameter_source,
'parameter_index': parameter_index,
'version_info': version_dump()
}
if placement_constraint is not None:
_validate_placement_constraint(placement_constraint)
trial_dict['placement_constraint'] = placement_constraint
try:
send_payload = dump(trial_dict, pickle_size_limit=int(os.getenv('PICKLE_SIZE_LIMIT', 64 * 1024)))
except PayloadTooLarge:
raise ValueError(
'Serialization failed when trying to dump the model because the payload is too large (larger than 64 KB). '
'This is usually caused by pickling large objects (like datasets) by mistake. '
'See the full error traceback for details and https://nni.readthedocs.io/en/stable/NAS/Serialization.html '
'for how to resolve such issues. '
)
self._send(CommandType.NewTrialJob, send_payload)
def send_no_more_trial_jobs(self) -> None:
"""Tell NNI manager that there are no more trial jobs to send for now."""
self._send(CommandType.NoMoreTrialJobs, '')
def receive(self) -> TunerIncomingCommand:
"""Receives a command from NNI manager."""
command_type, data = self._receive()
if data:
data = load(data)
# NOTE: Only handles the commands that are used by NAS.
# It uses a somewhat hacky way to convert the data received from NNI manager
# to a semantic command.
if command_type is None:
# This shouldn't happen. Only for robustness.
_logger.warning('Received command is empty. Terminating...')
return Terminate()
elif command_type == CommandType.Terminate:
return Terminate()
elif command_type == CommandType.Initialize:
if not isinstance(data, dict):
raise TypeError(f'Initialize command data must be a dict, but got {type(data)}')
return Initialize(data)
elif command_type == CommandType.RequestTrialJobs:
if not isinstance(data, int):
raise TypeError(f'RequestTrialJobs command data must be an integer, but got {type(data)}')
return RequestTrialJobs(data)
elif command_type == CommandType.UpdateSearchSpace:
if not isinstance(data, dict):
raise TypeError(f'UpdateSearchSpace command data must be a dict, but got {type(data)}')
return UpdateSearchSpace(data)
elif command_type == CommandType.ReportMetricData:
if not isinstance(data, dict):
raise TypeError(f'ReportMetricData command data must be a dict, but got {type(data)}')
if 'value' in data:
data['value'] = load(data['value'])
return ReportMetricData(**data)
elif command_type == CommandType.TrialEnd:
if not isinstance(data, dict):
raise TypeError(f'TrialEnd command data must be a dict, but got {type(data)}')
# For some reason, only one parameter (presumably the first one) shows up in the data,
# but a trial is technically associated with multiple parameters.
parameter_id = load(data['hyper_params'])['parameter_id']
return TrialEnd(
trial_job_id=data['trial_job_id'],
parameter_ids=[parameter_id],
event=data['event']
)
else:
raise ValueError(f'Unknown command type: {command_type}')
def on_terminate(self, callback: Callable[[Terminate], None]) -> None:
"""Register a callback for Terminate command.
Parameters
----------
callback
A callback function that takes a :class:`Terminate` as argument.
"""
self._callbacks[Terminate.command_type].append(callback)
def on_initialize(self, callback: Callable[[Initialize], None]) -> None:
"""Register a callback for Initialize command.
Parameters
----------
callback
A callback function that takes an :class:`Initialize` as argument.
"""
self._callbacks[Initialize.command_type].append(callback)
def on_request_trial_jobs(self, callback: Callable[[RequestTrialJobs], None]) -> None:
"""Register a callback for RequestTrialJobs command.
Parameters
----------
callback
A callback function that takes a :class:`RequestTrialJobs` as argument.
"""
self._callbacks[RequestTrialJobs.command_type].append(callback)
def on_update_search_space(self, callback: Callable[[UpdateSearchSpace], None]) -> None:
"""Register a callback for UpdateSearchSpace command.
Parameters
----------
callback
A callback function that takes an :class:`UpdateSearchSpace` as argument.
"""
self._callbacks[UpdateSearchSpace.command_type].append(callback)
def on_report_metric_data(self, callback: Callable[[ReportMetricData], None]) -> None:
"""Register a callback for ReportMetricData command.
Parameters
----------
callback
A callback function that takes a :class:`ReportMetricData` as argument.
"""
self._callbacks[ReportMetricData.command_type].append(callback)
def on_trial_end(self, callback: Callable[[TrialEnd], None]) -> None:
"""Register a callback for TrialEnd command.
Parameters
----------
callback
A callback function that takes a :class:`TrialEnd` as argument.
"""
self._callbacks[TrialEnd.command_type].append(callback)
def _send(self, command_type: CommandType, data: str) -> None:
command = command_type.value.decode() + data
@@ -112,3 +321,29 @@ class TunerCommandChannel:
return command
_logger.error('Failed to reconnect.')
raise RuntimeError('Connection lost')
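# Illustrative examples of constraints accepted by _validate_placement_constraint below
# (the semantic interpretation is an assumption; the code only checks the shapes):
#   {'type': 'None', 'gpus': []}                              # no constraint
#   {'type': 'GPUNumber', 'gpus': [2]}                        # a single integer GPU count
#   {'type': 'Device', 'gpus': [('host0', 0), ('host0', 1)]}  # (str, int) pairs, e.g. (host, gpu index)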
def _validate_placement_constraint(placement_constraint):
# Currently only for CGO.
if placement_constraint is None:
raise ValueError('placement_constraint is None')
if 'type' not in placement_constraint:
raise ValueError('placement_constraint must have `type`')
if 'gpus' not in placement_constraint:
raise ValueError('placement_constraint must have `gpus`')
if placement_constraint['type'] not in ['None', 'GPUNumber', 'Device']:
raise ValueError('placement_constraint.type must be either `None`, `GPUNumber` or `Device`')
if placement_constraint['type'] == 'None' and len(placement_constraint['gpus']) > 0:
raise ValueError('placement_constraint.gpus must be an empty list when type == None')
if placement_constraint['type'] == 'GPUNumber':
if len(placement_constraint['gpus']) != 1:
raise ValueError('placement_constraint.gpus currently only supports one host when type == GPUNumber')
for e in placement_constraint['gpus']:
if not isinstance(e, int):
raise ValueError('placement_constraint.gpus must be a list of numbers when type == GPUNumber')
if placement_constraint['type'] == 'Device':
for e in placement_constraint['gpus']:
if not isinstance(e, tuple):
raise ValueError('placement_constraint.gpus must be a list of tuples when type == Device')
if not (len(e) == 2 and isinstance(e[0], str) and isinstance(e[1], int)):
raise ValueError('each tuple in placement_constraint.gpus must be (str, int)')
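To show how the pieces above fit together, here is a hypothetical usage sketch of the callback-based API; the URL, parameter values, and handler bodies are invented for illustration, and only the channel methods come from this file:

from threading import Event, Thread

channel = TunerCommandChannel('ws://localhost:8080/tuner')  # made-up URL
channel.connect()

def on_request(command: RequestTrialJobs) -> None:
    # Send one made-up parameter set per requested trial.
    for i in range(command.count):
        channel.send_trial(parameter_id=i, parameters={'lr': 0.001})

def on_metric(command: ReportMetricData) -> None:
    print(command.parameter_id, command.type, command.value)

channel.on_request_trial_jobs(on_request)
channel.on_report_metric_data(on_metric)
channel.send_initialized()

stop_event = Event()
listener = Thread(target=channel.listen, args=(stop_event,))
listener.start()
# ... later, stop_event.set() (or a received Terminate command) ends the loop.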

View file

@@ -2,6 +2,11 @@
# Licensed under the MIT license.
from enum import Enum
from dataclasses import dataclass
from typing import List, Optional, ClassVar
from nni.typehint import TrialMetric, SearchSpace
from nni.utils import MetricType
class CommandType(Enum):
# in
@@ -22,3 +27,45 @@ class CommandType(Enum):
NoMoreTrialJobs = b'NO'
KillTrialJob = b'KI'
Error = b'ER'
class TunerIncomingCommand:
# For type checking.
command_type: ClassVar[CommandType]
# Only the commands necessary to make NAS work are defined here.
@dataclass
class Initialize(TunerIncomingCommand):
command_type: ClassVar[CommandType] = CommandType.Initialize
search_space: SearchSpace
@dataclass
class RequestTrialJobs(TunerIncomingCommand):
command_type: ClassVar[CommandType] = CommandType.RequestTrialJobs
count: int
@dataclass
class UpdateSearchSpace(TunerIncomingCommand):
command_type: ClassVar[CommandType] = CommandType.UpdateSearchSpace
search_space: SearchSpace
@dataclass
class ReportMetricData(TunerIncomingCommand):
command_type: ClassVar[CommandType] = CommandType.ReportMetricData
parameter_id: int # Parameter ID.
type: MetricType # Request parameter, periodical, or final.
sequence: int # Sequence number of the metric.
value: Optional[TrialMetric] = None # The metric value. When type is NOT request parameter.
trial_job_id: Optional[str] = None # Only available when type is request parameter.
parameter_index: Optional[int] = None # Only available when type is request parameter.
@dataclass
class TrialEnd(TunerIncomingCommand):
command_type: ClassVar[CommandType] = CommandType.TrialEnd
trial_job_id: str # The trial job id.
parameter_ids: List[int] # All parameter ids of the trial job.
event: str # The job's state.
@dataclass
class Terminate(TunerIncomingCommand):
command_type: ClassVar[CommandType] = CommandType.Terminate
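For illustration, a small sketch of how these dataclasses could be constructed and dispatched on their class-level command_type; all field values are made up:

# Values are invented; only the dataclass signatures come from this file.
commands = [
    Initialize(search_space={'lr': {'_type': 'loguniform', '_value': [1e-4, 1e-1]}}),
    RequestTrialJobs(count=2),
    TrialEnd(trial_job_id='abc', parameter_ids=[0], event='SUCCEEDED'),
    Terminate(),
]
for command in commands:
    # command_type is a ClassVar, shared by all instances of the same command class.
    print(type(command).__name__, command.command_type)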