Mirror of https://github.com/microsoft/nni.git
New trial runner and import trial command channel (#5398)
This commit is contained in:
Parent
99f9c71b51
Commit
dd4e5909c3
@@ -0,0 +1,198 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.

from __future__ import annotations

import logging
import time
from pathlib import Path

Command = str


class _PrefixAdapter(logging.LoggerAdapter):
    def process(self, msg, kwargs):
        if isinstance(self.extra, dict) and 'prefix' in self.extra:
            return f"{self.extra['prefix']} {msg}", kwargs
        return msg, kwargs


class FileChannel:
    """
    Command channel based on access to the same directory.

    The directory can be mounted at a different place for the other side,
    as long as it's available to both of them.

    Both sides must have read and write access to the directory as well as the files inside.
    If the directory doesn't exist, they must have the privilege to create it.

    :class:`FileChannel` is stateful. It (at least) keeps a state marking which messages (i.e., files)
    have been read. Recovering the channel from faults might lose that state and consume duplicated messages.
    Thus the reader side needs the "current reading progress" state to be persistent.
    Therefore a sender can broadcast to multiple receivers via a :class:`FileChannel`,
    but a receiver can only listen to one channel in the current implementation.

    All the files written by a peer live in the URL directory, with names starting with ``<peer_name>.``.
    Synchronization can leverage this glob pattern.

    Parameters
    ----------
    url
        A directory on the current file system.
    local_peer
        The peer to join as. IDs are only used to identify self, with no other limitations.
        Possible values might be "server", "client", "1", "2", ...
        Peer names can't contain ``.``.
    remote_peer
        The peer name that is connected to.
        This only matters in :meth:`receive`.

    Warnings
    --------
    The channel is not thread-safe. Behavior is undefined when two threads / processes belonging to the same peer
    try to access the channel at the same time. Try to use the channel ONLY on rank 0 whenever possible.
    """

    def __init__(self, url: str | Path, local_peer: str, remote_peer: str):
        self._url: Path = Path(url)
        self._local_peer = local_peer
        self._remote_peer = remote_peer

        assert '.' not in self._local_peer
        assert '.' not in self._remote_peer

        self._logger = _PrefixAdapter(
            logging.getLogger(__name__),
            {'prefix': f'(file channel {local_peer} -> {remote_peer})'}
        )

        self._file_capacity: int = 100000000  # 1e8
        self._line_limit_per_file: int = 100

        # For write. Next to write is 1.
        self._write_progress: int = 1
        self._recover_write_state()

        # For read. Has already read 0.
        self._read_progress: int = 0
        self._recover_read_state()

    def __repr__(self):
        return f'{self.__class__.__name__}({self._url}, {self._local_peer}, {self._remote_peer})'

    def send(self, command: Command) -> None:
        """Send a command.

        Returns immediately without checking whether the command is received successfully.

        If the send itself is unsuccessful (e.g., because the command is invalid),
        the error is logged and ignored.
        """
        if not isinstance(command, str):
            self._logger.error('Sent command must be str, found %s, ignore: %s', type(command), command)
            return

        self._url.mkdir(exist_ok=True, parents=True)

        # Find room for this message.
        if self._write_progress % self._file_capacity >= self._line_limit_per_file:
            self._logger.debug('File full. Need a new file: %d', self._write_progress)
            # 2300100 -> 2400001
            self._write_progress = (self._write_progress // self._file_capacity + 1) * self._file_capacity + 1

        filename = self._format_filename(self._local_peer, self._write_progress // self._file_capacity)

        try:
            with filename.open('a') as f:
                f.write('%016d\t' % self._write_progress + command + '\n')
                f.flush()
            self._logger.debug('Sent command: %s', command)
            self._write_progress += 1
        except:
            self._logger.exception('Write to file failed: %s', filename)

    def receive(self, non_blocking: bool = False) -> Command | None:
        """Receive a command.

        Parameters
        ----------
        non_blocking
            If ``True``, return immediately if no command is received.
            Otherwise, block until a command comes.
        """
        while True:
            # Find a new message from two places.

            # 1. Check whether there is a message in the file corresponding to the current progress.
            current_filename = self._format_filename(self._remote_peer, self._read_progress // self._file_capacity)
            content = self._receive_from_file(current_filename)
            if content is not None:
                return content

            # 2. Check whether there is a message in the next file.
            next_filename = self._format_filename(self._remote_peer, self._read_progress // self._file_capacity + 1)
            content = self._receive_from_file(next_filename)
            if content is not None:
                return content

            if non_blocking:
                return None

            self._logger.debug('Nothing received. Try again later.')
            time.sleep(1.)

    def _format_filename(self, peer_name: str, file_index: int) -> Path:
        assert peer_name in [self._local_peer, self._remote_peer]
        return self._url / f'{peer_name}.{file_index:08d}'

    def _recover_write_state(self) -> None:
        while True:
            path = self._format_filename(self._local_peer, self._write_progress // self._file_capacity)
            if path.exists():
                # Regardless of whether it's full or not.
                self._write_progress += self._file_capacity
            else:
                break
        if self._write_progress > 1:
            self._logger.info('Write progress is recovered to be: %d', self._write_progress)

    def _recover_read_state(self) -> None:
        path = self._url / f'{self._local_peer}.read'
        if not path.exists():
            self._logger.debug('Reading state does not exist. Nothing to recover.')
        else:
            try:
                with path.open() as f:
                    self._read_progress = int(f.read())
                self._logger.info('Read progress is recovered to be: %d', self._read_progress)
            except:
                self._logger.exception('Reading state appears to be corrupted: %s', path)

    def _save_read_state(self) -> None:
        try:
            self._url.mkdir(exist_ok=True, parents=True)
            with (self._url / f'{self._local_peer}.read').open('w') as f:
                f.write(str(self._read_progress))
            self._logger.debug('Read progress successfully updated: %d', self._read_progress)
        except:
            self._logger.exception('Reading state failed to dump: %d', self._read_progress)

    def _receive_from_file(self, file: Path) -> str | None:
        if not file.exists():
            self._logger.debug('%s does not exist yet.', file)
            return None

        try:
            with file.open() as f:
                for line in f.readlines():
                    id, content = line.split('\t', 1)  # pylint: disable=redefined-builtin
                    if int(id) > self._read_progress:
                        content = content.rstrip('\n')
                        self._logger.debug('Received command: %s', content)
                        self._read_progress = int(id)
                        self._save_read_state()
                        return content
        except:
            self._logger.exception('File appears to be corrupted: %s', file)
        return None
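For orientation, a minimal usage sketch of the channel above. It assumes the module is importable as nni_amlt.file_channel (the package name is inferred from the logger names used by the trial runner below) and uses a temporary directory in place of a shared mount:

from pathlib import Path
from tempfile import mkdtemp

# Assumed module path; the temporary directory stands in for a directory mounted on both sides.
from nni_amlt.file_channel import FileChannel

shared = Path(mkdtemp())
manager = FileChannel(shared, local_peer='manager', remote_peer='worker-0')
worker = FileChannel(shared, local_peer='worker-0', remote_peer='manager')

manager.send('{"command_type": "wakeup"}')
print(worker.receive(non_blocking=True))  # -> '{"command_type": "wakeup"}'
print(worker.receive(non_blocking=True))  # -> None; read progress is persisted in worker-0.read

A peer only appends to files named after itself and tracks its own progress in ``<peer>.read``, which is why one sender can broadcast to many receivers while each receiver listens to a single remote peer.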
@@ -0,0 +1,67 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.

from __future__ import annotations

import logging
from typing_extensions import Literal

import requests

import nni
from nni.typehint import ParameterRecord, TrialMetric
from nni.runtime.env_vars import trial_env_vars
from nni.runtime.trial_command_channel import TrialCommandChannel

from .trial_runner import TrialServerHandler
from .typehint import MetricCommand

_logger = logging.getLogger(__name__)


class TrialClient(TrialCommandChannel):
    """The client side of the trial runner's :class:`TrialServerHandler`."""

    def __init__(self, url: str | None = None, trial_id: str | None = None) -> None:
        if url is not None:
            self._url = url
        else:
            self._url = TrialServerHandler.ADDRESS
        if trial_id is not None:
            self._trial_id = trial_id
        else:
            self._trial_id = trial_env_vars.NNI_TRIAL_JOB_ID

    def receive_parameter(self) -> ParameterRecord | None:
        # Use the configured URL (falls back to TrialServerHandler.ADDRESS in __init__).
        response = requests.get(self._url + '/parameter/' + self._trial_id)
        if response.status_code != 200:
            _logger.error('Failed to receive parameter: %s', response)
            return None
        parameter = response.json()['parameter']
        if not parameter:
            _logger.error('Received empty parameter: \'%s\'', parameter)
            return None
        if not isinstance(parameter, str):
            _logger.error('Received invalid parameter: \'%s\'', parameter)
            return None
        return nni.load(parameter)  # Unpack the parameter generated by the tuner.

    def send_metric(
        self,
        type: Literal['PERIODICAL', 'FINAL'],  # pylint: disable=redefined-builtin
        parameter_id: int | None,
        trial_job_id: str,
        sequence: int,
        value: TrialMetric
    ) -> None:
        metric = {
            'parameter_id': parameter_id,
            'trial_job_id': trial_job_id,
            'type': type,
            'sequence': sequence,
            'value': nni.dump(value),  # Pack the metric value, which will be unpacked by the tuner.
        }
        command = MetricCommand(command_type='metric', id=trial_job_id, metric=nni.dump(metric))
        response = requests.post(self._url + '/metric', json=command)
        if response.status_code != 200:
            _logger.error('Failed to send metric: %s', response)
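A hedged sketch of how a trial process might drive this client directly, assuming a trial runner is serving on the default local address; the module path and the trial ID 'Abc123' are made up for illustration:

# Assumed module path; 'Abc123' is a made-up trial ID.
from nni_amlt.trial_client import TrialClient

client = TrialClient(url='http://localhost:36378', trial_id='Abc123')

record = client.receive_parameter()   # GET /parameter/Abc123 on the runner, deserialized via nni.load
print(record)

client.send_metric(                   # POST /metric, forwarded over the file channel by the runner
    type='FINAL',
    parameter_id=0,                   # made-up parameter ID
    trial_job_id='Abc123',
    sequence=0,
    value=0.93,
)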
@@ -0,0 +1,434 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.

from __future__ import annotations

import argparse
import json
import logging
import os
import subprocess
import sys
import threading
import time
from collections import deque
from functools import partial
from http.server import BaseHTTPRequestHandler, HTTPServer
from pathlib import Path
from typing import Sequence, Callable

from .file_channel import FileChannel
from .typehint import (
    CreateCommand, KillCommand, MetricCommand, TrialStatusCommand, WakeUpCommand, ReportAwakeCommand,
    Trial, Status, typed_dict_validation
)
from .utils import graceful_kill, add_handler

_logger = logging.getLogger('nni_amlt.trial_runner')


class TrialServerHandler(BaseHTTPRequestHandler):
    """A server for trials to get parameters from, and report metrics to, the trial runner."""

    PORT = 36378
    ADDRESS = 'http://localhost:36378'

    def __init__(self, trials: Sequence[Trial], on_metric: Callable[[MetricCommand], None], *args, **kwargs):
        self.trials = trials
        self.on_metric = on_metric
        # Must be set before super().__init__(). The handler starts to handle requests within super().__init__().
        super().__init__(*args, **kwargs)

    def _set_headers(self):
        self.send_response(200)
        self.send_header('Content-type', 'application/json')
        self.end_headers()

    def _send_bad_request(self):
        self.send_response(400)
        self.end_headers()

    def _send_not_found(self):
        self.send_response(404)
        self.end_headers()

    def _send_ok(self):
        self.send_response(200)
        self.end_headers()

    def do_HEAD(self):
        self._set_headers()

    def do_GET(self):
        """A GET request must be requesting parameters."""
        if not self.path.startswith('/parameter/'):
            _logger.error('Invalid path for HTTP GET: %s', self.path)
            self._send_bad_request()
            return

        trial_id = self.path.split('/')[-1]

        for trial in self.trials:
            if trial['id'] == trial_id:
                self._set_headers()
                self.wfile.write(json.dumps(trial).encode())
                return

        _logger.error('Trial ID %s not found in parameters', trial_id)
        self._send_not_found()
        return

    def do_POST(self):
        """A POST request must be sending results."""
        if self.path != '/metric':
            _logger.error('Invalid path for HTTP POST: %s', self.path)
            self._send_bad_request()
            return

        content_type = self.headers.get_content_type()
        # Refuse to receive non-JSON content.
        if content_type != 'application/json':
            self._send_bad_request()
            return

        content_length = int(self.headers.get('content-length'))
        message = json.loads(self.rfile.read(content_length))

        if not typed_dict_validation(MetricCommand, message):
            _logger.error('Invalid message: %s', message)
            self._send_bad_request()
            return

        self.on_metric(message)
        self._send_ok()


class TrialRunner:
    """
    Runner to process incoming trial commands.

    Parameters
    ----------
    channel
        Channel to communicate with the management end.
        The runner only uses the "send" direction of the channel.
    runner_dir
        Directory for the runner to save logs and save/restore checkpoints.
        Usually **unshared** between multiple ranks (nodes).
    trial_output_dir
        Directory for trials to save their output files.
        A subdirectory named after the trial ID will be created inside.
        Usually **shared** between multiple ranks.
    trial_log_dir
        Path to where trial logs are stored.
        Usually **unshared** between ranks.
    log_buffer_size
        Buffer size of trial stdout.
    """

    def __init__(self, channel: FileChannel, runner_dir: Path,
                 trial_output_dir: Path, trial_log_dir: Path,
                 log_buffer_size: int) -> None:
        self._channel = channel
        self._runner_dir = runner_dir
        self._trial_output_dir = trial_output_dir
        self._trial_log_dir = trial_log_dir
        self._log_buffer_size = log_buffer_size

        self._processing_trials: deque[Trial] = deque()  # Including the currently running one.
        self._running_process: subprocess.Popen | None = None
        self._parameters: dict = {}  # Saved and restored together with the trial queue in the checkpoint.

        if self._checkpoint_path.exists():
            self.load_checkpoint()

        self._server_thread = threading.Thread(target=self._server, daemon=True)
        self._server_thread.start()  # Serve parameter requests and metric reports in the background.

    @property
    def _checkpoint_path(self) -> Path:
        return self._runner_dir / 'trial_runner.json'

    def _server(self) -> None:
        server_address = ('', TrialServerHandler.PORT)
        httpd = HTTPServer(server_address, partial(TrialServerHandler, self._processing_trials, self._on_metric))
        httpd.serve_forever()

    def _on_metric(self, command: MetricCommand) -> None:
        self._channel.send(json.dumps(command))

    def load_checkpoint(self) -> None:
        try:
            with self._checkpoint_path.open() as f:
                checkpoint_data = json.load(f)
            self._processing_trials = deque()
            for t in checkpoint_data['queued_trials']:
                if typed_dict_validation(Trial, t):
                    self._processing_trials.append(t)
                else:
                    _logger.error('Ignored when loading checkpoint as it appears not a valid trial: %s', t)
            if isinstance(checkpoint_data['parameters'], dict):
                self._parameters = checkpoint_data['parameters']
            _logger.info('Checkpoint loaded. Processing trials: %s', self._processing_trials)
        except:
            _logger.exception('Checkpoint loading failed: %s', self._checkpoint_path)

        self._refresh_queue()

    def save_checkpoint(self) -> None:
        try:
            checkpoint_data = {
                'queued_trials': list(self._processing_trials),
                'parameters': self._parameters
            }
            self._checkpoint_path.parent.mkdir(exist_ok=True, parents=True)
            with self._checkpoint_path.open('w') as f:
                json.dump(checkpoint_data, f)
        except:
            _logger.exception('Checkpoint saving failed: %s', self._checkpoint_path)

    def check_status(self) -> list[Trial]:
        """
        Check the status of the runner and return the processing trials (running + pending).

        The caller is responsible for calling :meth:`check_status` regularly.
        Otherwise the trials in the queue won't be auto-processed.
        """
        # Check the status of the currently running trial.
        self._refresh_queue()
        # List running and pending trials.
        return list(self._processing_trials)

    def create_trial(self, trial: Trial) -> None:
        """Submit a trial for running.

        Returns instantly.
        """
        self._processing_trials.append(trial)
        self._refresh_queue()

    def kill_trial(self, id: str) -> None:  # pylint: disable=redefined-builtin
        """Kill a trial.

        Currently it must be the running trial.
        """
        if len(self._processing_trials) > 0 and self._running_process is not None:
            trial = self._processing_trials[0]
            if trial['id'] == id:
                graceful_kill(self._running_process)
                returncode = self._running_process.returncode
                _logger.info('Process %s is killed with exit code: %s', self._running_process, returncode)
                self._processing_trials.popleft()
                self._emit_status_change(trial['id'], 'interrupted')
                self._running_process = None

                # Run the next trial if any.
                self._refresh_queue()
                return

        _logger.warning('Trial %s is not running. Cannot kill it.', id)

    def send_heartbeat(self) -> float:
        """Send a heartbeat to the other side."""
        current_time = time.time()
        command = ReportAwakeCommand(
            command_type='awake',
            time=current_time,
            idle=not self._processing_trials
        )
        self._channel.send(json.dumps(command))
        return current_time

    def _refresh_queue(self) -> None:
        if not self._processing_trials:
            _logger.debug('No trials. Nothing to refresh.')
            return

        # Move the queue. See if the upfront trial is completed,
        # and whether the next trial should be run.
        if self._running_process is not None:
            if self._running_process.poll() is not None:
                returncode = self._running_process.returncode
                _logger.info('Process %s returned with exit code: %s', self._running_process, returncode)
                if returncode == 0:
                    status: Status = 'succeeded'
                else:
                    status: Status = 'failed'
                trial = self._processing_trials.popleft()
                self._emit_status_change(trial['id'], status)
                self._running_process = None

        # Run a new trial.
        if len(self._processing_trials) > 0 and self._running_process is None:
            trial = self._processing_trials[0]
            _logger.info('Running: %s', trial['command'])
            self._running_process = subprocess.Popen(
                trial['command'],
                stdout=subprocess.PIPE,
                stderr=subprocess.STDOUT,
                bufsize=self._log_buffer_size,
                env=self._environ(trial),  # Inject NNI variables so that the trial can reach the runner.
                shell=True
            )
            self._start_stdout_logging(self._trial_log_dir / (trial['id'] + '.txt'))
            self._emit_status_change(trial['id'], 'running')

    def _environ(self, trial: Trial) -> dict:
        """Generate environment variables for a trial."""
        environ_base = dict(os.environ)

        output_dir = str(self._trial_output_dir / trial['id'])
        nni_environ = dict(
            NNI_PLATFORM='amlt',
            NNI_EXP_ID=trial['experiment'],
            NNI_TRIAL_JOB_ID=trial['id'],
            NNI_SYS_DIR=output_dir,
            NNI_OUTPUT_DIR=output_dir,
            NNI_TRIAL_SEQ_ID=str(trial['sequence']),  # Environment variable values must be strings.
        )

        return {
            **environ_base,
            **nni_environ
        }

    def _start_stdout_logging(self, file: Path) -> None:
        if self._running_process is None:
            _logger.error('No running process to start logging.')
            return

        def _tee(infile, log_file: Path) -> None:
            log_file.parent.mkdir(exist_ok=True, parents=True)
            with infile, open(log_file, 'ab') as f:
                for line in iter(infile.readline, b''):
                    f.write(line)
                    sys.stdout.buffer.write(line)
                    # Did not flush here.

        file.parent.mkdir(exist_ok=True, parents=True)
        t = threading.Thread(target=_tee, args=(self._running_process.stdout, file), daemon=True)
        t.start()

    def _emit_status_change(self, trial_id: str, status: Status) -> None:
        command = TrialStatusCommand(
            command_type='status',
            id=trial_id,
            status=status,
        )
        _logger.debug('Emit status change: %s', command)
        self._channel.send(json.dumps(command))


def trial_runner_loop(
    channel: str | Path,
    out: str | Path,
    rank: int,
    interval: float,
    heartbeat_max_interval: float,
    patience: float,
    log_buffer_size: int
) -> None:
    output_dir = Path(out)
    runner_dir = output_dir / f'trial_runner_{rank}'
    trial_log_dir = output_dir / f'logs_{rank}'
    runner_dir.mkdir(exist_ok=True, parents=True)

    # Initialize the logger if it hasn't been initialized yet.
    add_handler(_logger, runner_dir / 'trial_runner.log')

    _logger.info('Trial runner started.')

    _logger.info('Saving trial runner states to: %s', runner_dir)

    file_channel = FileChannel(channel, f'worker-{rank}', 'manager')
    _logger.info('Using channel %s to communicate with NNI manager', file_channel)

    _logger.info('Buffer size for trial stdout: %d', log_buffer_size)

    trial_runner = TrialRunner(file_channel, runner_dir, output_dir, trial_log_dir, log_buffer_size)

    last_good = time.time()
    last_heartbeat = time.time()
    heartbeat_interval = interval

    trial_runner.send_heartbeat()

    while True:
        if trial_runner.check_status():
            _logger.info('Trial runner has running trials. Be patient.')
            last_good = time.time()

        trial_runner.save_checkpoint()

        # Receive a control command from the manager side.
        command = file_channel.receive(non_blocking=True)

        if command is not None:
            try:
                command = json.loads(command)
            except:
                _logger.exception('Command decode error. Skip this command: %s', command)
                command = None

        if command is not None:
            if not isinstance(command, dict) or 'command_type' not in command:
                _logger.error('Invalid command: %s', command)
            else:
                command_type = command['command_type']
                if command_type == 'create' and typed_dict_validation(CreateCommand, command):
                    trial_runner.create_trial(command['trial'])
                elif command_type == 'kill' and typed_dict_validation(KillCommand, command):
                    trial_runner.kill_trial(command['id'])
                elif command_type == 'wakeup' and typed_dict_validation(WakeUpCommand, command):
                    last_heartbeat = trial_runner.send_heartbeat()
                else:
                    _logger.error('Unsupported command: %s', command)

            trial_runner.save_checkpoint()

            # Reset the heartbeat interval to communicate more frequently.
            heartbeat_interval = interval

            # No sleep. Continue to the next command.

        else:
            elapsed = time.time() - last_good
            _logger.info('No command received. Patience: %f / %f', elapsed, patience)

            if elapsed > patience:
                _logger.warning('No command received for too long. Quit the runner.')
                break

            if time.time() - last_heartbeat > heartbeat_interval:
                last_heartbeat = trial_runner.send_heartbeat()
                # Exponentially increase the heartbeat interval, up to the configured maximum.
                heartbeat_interval = min(heartbeat_interval * 1.5, heartbeat_max_interval)

            time.sleep(interval)


def main():
    parser = argparse.ArgumentParser(description='Amulet training service trial runner')
    parser.add_argument('channel', type=str, help='The path where the file channel is mounted (in the cluster container)')
    parser.add_argument('out', type=str, default=None,
                        help='Checkpoint directory of the trial runner. If specified, the trial runner will try to find its checkpoint.')
    parser.add_argument('--rank', type=int, default=None,
                        help='Rank of the trial runner. Meaningful for distributed training. '
                             'If not set, will try to read from environment variable `RANK`.')
    parser.add_argument('--interval', type=float, default=60.,
                        help='Interval (seconds) between two polls of the channel')
    parser.add_argument('--heartbeat-max-interval', type=float, default=600.,
                        help='Max interval (seconds) between two heartbeats. '
                             'Heartbeat is used to tell the manager that the runner is still alive. '
                             'The initial heartbeat interval is `interval`. '
                             'It will be exponentially increased until it reaches this value if no message from the manager is received.')
    parser.add_argument('--patience', type=float, default=1800.,
                        help='Number of seconds without any updates or running trials before the runner shuts down')
    parser.add_argument('--log-buffer-size', type=int, default=0,
                        help='Buffer size for trial stdout. See bufsize in `subprocess.Popen`.')

    args = parser.parse_args()
    if args.rank is None:
        args.rank = int(os.environ.get('RANK', 0))
    trial_runner_loop(**vars(args))


if __name__ == '__main__':
    main()
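To show the other end of the protocol, a hedged sketch of what a manager-side process might push through the file channel to start a trial. The module path, the mount point and all IDs are made up, and the runner invocation in the comment is only an assumption based on main() above:

import json
from pathlib import Path

# Assumed module path; '/mnt/nni-channel' stands in for the directory that is also
# mounted inside the cluster container and passed to the runner's CLI, e.g.
# `python -m nni_amlt.trial_runner /mnt/nni-channel ./runner-out --rank 0` (assumed entry point).
from nni_amlt.file_channel import FileChannel

channel = FileChannel(Path('/mnt/nni-channel'), local_peer='manager', remote_peer='worker-0')

channel.send(json.dumps({
    'command_type': 'create',
    'trial': {
        'id': 'trial0',                  # made-up IDs
        'sequence': 0,
        'experiment': 'exp0',
        'command': 'python trial.py',
        'parameter': '{"parameter_id": 0, "parameters": {"lr": 0.01}}',
    },
}))

# The runner answers on its own direction of the channel with commands such as
# {"command_type": "status", "id": "trial0", "status": "running"} or
# {"command_type": "awake", "time": ..., "idle": false}.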
@@ -0,0 +1,101 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.

import logging

from typing import TypedDict, Any, Type, TypeVar, Optional, Union, get_args, get_origin
from typing_extensions import Literal, TypeGuard

_logger = logging.getLogger(__name__)

T = TypeVar('T')


def typed_dict_validation(typ: Type[T], instance: Any) -> TypeGuard[T]:
    # https://stackoverflow.com/questions/66665336/validate-python-typeddict-at-runtime
    if not isinstance(instance, dict):
        _logger.error('Validation failed for %s. Instance is not a dict: %s', typ, type(instance))
        return False

    for property_name, property_type in typ.__annotations__.items():
        if property_name not in instance:
            # Check for missing keys.
            _logger.error('Validation failed for %s. Missing key: %s', typ, property_name)
            return False

        value = instance[property_name]
        if property_type in (int, float, bool, str):
            # Check for type equality.
            if not isinstance(value, property_type):
                _logger.error('Validation failed for %s. Wrong type: %s. Expected %s, got %s',
                              typ, property_name, property_type, type(value))
                return False

        elif get_origin(property_type) == Literal:
            # Check the literal.
            if value not in property_type.__args__:
                _logger.error('Validation failed for %s. Expect literal to be one of %s, got %s',
                              typ, property_type.__args__, value)
                return False

        elif get_origin(property_type) == Union:
            # Union (e.g., Optional) fields: accept a value matching any simple member type.
            if not isinstance(value, tuple(t for t in get_args(property_type) if isinstance(t, type))):
                _logger.error('Validation failed for %s. Expect %s to be one of %s, got %s',
                              typ, property_name, get_args(property_type), type(value))
                return False

        else:
            # Assuming a nested typed dict.
            result = typed_dict_validation(property_type, value)
            if result is False:
                return False

    return True


class Trial(TypedDict):
    id: str
    sequence: int
    experiment: str
    command: str
    parameter: Optional[str]  # Serialized JSON string.
    # time_limit: float


# Command types are kept as few as possible.
# The implementation also tries to avoid unnecessary dependencies,
# to increase the robustness.


UpstreamCommandType = Literal['create', 'kill', 'wakeup']      # manager -> worker
DownstreamCommandType = Literal['metric', 'status', 'awake']   # worker -> manager
Status = Literal['waiting', 'running', 'succeeded', 'failed', 'interrupted']


class CreateCommand(TypedDict):
    command_type: Literal['create']
    trial: Trial


class KillCommand(TypedDict):
    command_type: Literal['kill']
    id: str


class MetricCommand(TypedDict):
    command_type: Literal['metric']
    id: str
    metric: str  # Serialized JSON string.


class TrialStatusCommand(TypedDict):
    command_type: Literal['status']
    id: str
    status: Status


class WakeUpCommand(TypedDict):
    # Request the worker to report its status (more frequently).
    command_type: Literal['wakeup']


class ReportAwakeCommand(TypedDict):
    # The only way to report that the worker is alive (and idle or occupied).
    command_type: Literal['awake']
    time: float
    # NOTE: ``time`` is only for verbosity.
    # Avoid relying on it, because the cluster clock might differ from the local one.
    idle: bool
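A quick illustration of how the validator above is intended to be used on incoming messages (the module path is assumed; the payloads are made up):

import json

# Assumed module path for illustration.
from nni_amlt.typehint import KillCommand, typed_dict_validation

message = json.loads('{"command_type": "kill", "id": "trial0"}')
assert typed_dict_validation(KillCommand, message)   # keys, simple types and literals all match

bad = {'command_type': 'kill'}                       # missing "id"
assert not typed_dict_validation(KillCommand, bad)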
@@ -0,0 +1,138 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.

from __future__ import annotations

import logging
import signal
import sys
import time
from pathlib import Path

from subprocess import Popen, PIPE, TimeoutExpired

_logger = logging.getLogger(__name__)
_logger_init: bool = False


def graceful_kill(popen: Popen) -> int | None:
    for retry in [1., 5., 20., 60.]:
        _logger.info('Gracefully terminating %s...', popen)

        if retry > 10:
            _logger.info('Use "terminate" instead of "interrupt".')
            popen.terminate()
        else:
            popen.send_signal(signal.SIGINT)

        time.sleep(1.)  # Wait for the kill to take effect.

        retcode = popen.poll()
        if retcode is not None:
            return retcode

        _logger.warning('%s still alive. Retry to kill in %d seconds.', popen, retry)
        time.sleep(retry)

    _logger.warning('Force killing process %s...', popen)
    popen.kill()
    time.sleep(10.)  # Wait for the kill.
    retcode = popen.poll()
    if retcode is not None:
        return retcode

    _logger.error('Failed to kill process %s.', popen)
    return None


def run_subprocess(command: list[str], log_file: Path, timeout: float | None = None) -> tuple[int, str, str]:
    if timeout:
        _logger.info('Running command with timeout %f seconds: %s', timeout, command)
    else:
        _logger.info('Running command: %s', command)
    _logger.info('Output saved to: %s', log_file)

    stdout, stderr = '', ''
    file_handle = None
    try:
        start_time = time.time()
        file_handle = log_file.open('w')
        file_handle.write(f'Command: {command}')
        proc = Popen(
            command,
            stdout=PIPE,
            stderr=PIPE,
            encoding='utf-8',
        )

        while True:
            try:
                # Wait up to one second; TimeoutExpired only means the process is still running.
                out, err = proc.communicate(timeout=1)
            except TimeoutExpired:
                out, err = None, None
            if out:
                sys.stdout.write(out)
                sys.stdout.flush()
                file_handle.write(out)
                stdout += out
            if err:
                sys.stderr.write(err)
                sys.stderr.flush()
                file_handle.write(err)
                stderr += err
            file_handle.flush()

            # See if the process has terminated.
            if proc.poll() is not None:
                returncode = proc.returncode
                if returncode != 0:
                    _logger.error('Command failed with return code %d: %s', returncode, command)
                else:
                    _logger.info('Command finished with return code %d: %s', returncode, command)
                return returncode, stdout, stderr

            # See if we timed out.
            if timeout is not None and time.time() - start_time > timeout:
                _logger.warning('Command timed out (%f seconds): %s', timeout, command)
                returncode = graceful_kill(proc)
                if returncode is None:
                    _logger.error('Return code is still none after attempting to kill it. The process (%d) may be stuck.', proc.pid)
                    returncode = 1
                return returncode, stdout, stderr
    finally:
        if file_handle is not None:
            file_handle.close()


def init_logger() -> None:
    """
    Initialize the logger. Log to stdout by default.
    """
    global _logger_init
    if _logger_init:
        return

    logger = logging.getLogger('nni_amlt')
    logger.setLevel(level=logging.INFO)
    add_handler(logger)

    _logger_init = True


def add_handler(logger: logging.Logger, file: Path | None = None) -> logging.Handler:
    """
    Add a logging handler.
    If ``file`` is specified, log to the file.
    Otherwise, add a handler to stdout.
    """
    fmt = '[%(asctime)s] %(levelname)s (%(threadName)s:%(name)s) %(message)s'
    datefmt = '%Y-%m-%d %H:%M:%S'

    formatter = logging.Formatter(fmt, datefmt)
    if file is None:
        # Log to stdout.
        handler = logging.StreamHandler(sys.stdout)
    else:
        handler = logging.FileHandler(file)
    handler.setLevel(level=logging.DEBUG)  # Print all the logs.
    handler.setFormatter(formatter)

    logger.addHandler(handler)

    return handler
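A minimal sketch of how these helpers compose, assuming the module path nni_amlt.utils; the command and log path are placeholders:

from pathlib import Path

# Assumed module path; the command and log path are placeholders.
from nni_amlt.utils import init_logger, run_subprocess

init_logger()  # route 'nni_amlt.*' loggers to stdout once

returncode, stdout, stderr = run_subprocess(
    ['python', '-c', 'print("hello")'],
    log_file=Path('/tmp/hello.log'),
    timeout=60.,
)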
@@ -44,7 +44,15 @@ def set_default_trial_command_channel(channel: Optional[TrialCommandChannel] = N
     assert dispatcher_env_vars.SDK_PROCESS != 'dispatcher'

     channel_url = trial_env_vars.NNI_TRIAL_COMMAND_CHANNEL
-    if channel_url:
+    if isinstance(channel_url, str) and channel_url.startswith('import://'):
+        _, channel_class_name = channel_url.split('://', 1)
+        module_name, class_name = channel_class_name.rsplit('.', 1)
+        module = __import__(module_name)
+        channel_class = getattr(module, class_name)
+        _channel = channel_class()
+        if not isinstance(_channel, TrialCommandChannel):
+            raise TypeError(f'{_channel} is not an instance of TrialCommandChannel')
+    elif channel_url:
         from .v3 import TrialCommandChannelV3
         _channel = TrialCommandChannelV3(channel_url)
     elif trial_env_vars.NNI_PLATFORM is None or trial_env_vars.NNI_PLATFORM == 'unittest':
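Finally, a hedged sketch of the import:// scheme handled by this hunk. my_channel.MyTrialCommandChannel is a made-up top-level module and class that would subclass TrialCommandChannel; note that the __import__ call above resolves only the segment before the last dot as a module, so a top-level module is the safest target. Presumably the Amulet trial runner points this at its own TrialClient in a similar way.

import os

# Made-up module/class for illustration; it must subclass TrialCommandChannel
# and be importable as a top-level module on the trial's PYTHONPATH.
os.environ['NNI_TRIAL_COMMAND_CHANNEL'] = 'import://my_channel.MyTrialCommandChannel'

import nni  # The environment variable must be set before NNI reads it.

params = nni.get_next_parameter()   # dispatched to MyTrialCommandChannel.receive_parameter()
nni.report_final_result(0.93)       # dispatched to MyTrialCommandChannel.send_metric(...)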