New trial runner and import trial command channel (#5398)

This commit is contained in:
Yuge Zhang 2023-02-23 14:14:33 +08:00 коммит произвёл GitHub
Родитель 99f9c71b51
Коммит dd4e5909c3
Не найден ключ, соответствующий данной подписи
Идентификатор ключа GPG: 4AEE18F83AFDEB23
7 изменённых файлов: 947 добавлений и 1 удалений

Просмотреть файл

Просмотреть файл

@ -0,0 +1,198 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.
from __future__ import annotations
import logging
import time
from pathlib import Path
Command = str
class _PrefixAdapter(logging.LoggerAdapter):
def process(self, msg, kwargs):
if isinstance(self.extra, dict) and 'prefix' in self.extra:
return f"{self.extra['prefix']} {msg}", kwargs
return msg, kwargs
class FileChannel:
"""
Command channel based on access to the same directory.
The directory can be mounted at a different place for the other side,
as long as it's available to both of them.
Both side must have read and write access to the directory as well as the files inside.
If the directory doesn't exist. They must have the privilege to create it.
:class:`FileChannel` is stateful. It (at least) has a state to mark which messages (i.e. files)
that has been read. Recovering the channel from faults might lose that state and consume duplicated messages.
Thus the reader side needs the state of "current reading progress" to be persistent.
Therefore a sender can broadcast to multiple receivers via a :class:`FileChannel`,
but it can only be listening to one channel in current implementation.
All the files written by the peer are in the URL, starting with ``<peer_name>.``.
Synchronization can leverage this glob pattern.
Parameters
----------
url
A directory on the current file system.
local_peer
Join as which peer. IDs are used to identify self, with no other limitations.
Possible values might be "server", "client", "1", "2", ...
Peer name can't contain ``.``.
remote_peer
The peer name that is connected to.
This only matters in :meth:`receive`.
Warnings
--------
The channel is not thread-safe. Behavior is undefined when two threads / processes belonging to the same peer
try to access the channel at the same time. Try to use the channel ONLY on rank 0 whenever possible.
"""
def __init__(self, url: str | Path, local_peer: str, remote_peer: str):
self._url: Path = Path(url)
self._local_peer = local_peer
self._remote_peer = remote_peer
assert '.' not in self._local_peer
assert '.' not in self._remote_peer
self._logger = _PrefixAdapter(
logging.getLogger(__name__),
{'prefix': f'(file channel {local_peer} -> {remote_peer})'}
)
self._file_capacity: int = 100000000 # 1e8
self._line_limit_per_file: int = 100
# For write. Next to write is 1.
self._write_progress: int = 1
self._recover_write_state()
# For read. Has already read 0.
self._read_progress: int = 0
self._recover_read_state()
def __repr__(self):
return f'{self.__class__.__name__}({self._url}, {self._local_peer}, {self._remote_peer})'
def send(self, command: Command) -> None:
"""Send a command.
Returns immediately without checking whether the command is received successfully.
If the send (itself) is unsuccessful (e.g., due to the command is invalid),
the error is logged and ignored.
"""
if not isinstance(command, str):
self._logger.error('Sent command must be str, found %s, ignore: %s', type(command), command)
return
self._url.mkdir(exist_ok=True, parents=True)
# Find a room for this message
if self._write_progress % self._file_capacity >= self._line_limit_per_file:
self._logger.debug('File full. Need a new file: %d', self._write_progress)
# 2300100 -> 2400001
self._write_progress = (self._write_progress // self._file_capacity + 1) * self._file_capacity + 1
filename = self._format_filename(self._local_peer, self._write_progress // self._file_capacity)
try:
with filename.open('a') as f:
f.write('%016d\t' % self._write_progress + command + '\n')
f.flush()
self._logger.debug('Sent command: %s', command)
self._write_progress += 1
except:
self._logger.exception('Write to file failed: %s', filename)
def receive(self, non_blocking: bool = False) -> Command | None:
"""Receive a command.
Parameters
----------
non_blocking
If ``True``, return immediately if no command is received.
Otherwise, block until a command comes.
"""
while True:
# Find a new message from two places.
# 1. Check whether there is a message from the file corresponding to current progress.
current_filename = self._format_filename(self._remote_peer, self._read_progress // self._file_capacity)
content = self._receive_from_file(current_filename)
if content is not None:
return content
# 2. Check whether there is a message from the next file.
next_filename = self._format_filename(self._remote_peer, self._read_progress // self._file_capacity + 1)
content = self._receive_from_file(next_filename)
if content is not None:
return content
if non_blocking:
return None
self._logger.debug('Nothing received. Try again later.')
time.sleep(1.)
def _format_filename(self, peer_name: str, file_index: int) -> Path:
assert peer_name in [self._local_peer, self._remote_peer]
return self._url / f'{peer_name}.{file_index:08d}'
def _recover_write_state(self) -> None:
while True:
path = self._format_filename(self._local_peer, self._write_progress // self._file_capacity)
if path.exists():
# Regardless of whether it's full or not.
self._write_progress += self._file_capacity
else:
break
if self._write_progress > 1:
self._logger.info('Write progress is recovered to be: %d', self._write_progress)
def _recover_read_state(self) -> None:
path = self._url / f'{self._local_peer}.read'
if not path.exists():
self._logger.debug('Reading state does not exist. Nothing to recover.')
else:
try:
with path.open() as f:
self._read_progress = int(f.read())
self._logger.info('Read progress is recovered to be: %d', self._read_progress)
except:
self._logger.exception('Reading state appears to be corrupted: %s', path)
def _save_read_state(self) -> None:
try:
self._url.mkdir(exist_ok=True, parents=True)
with (self._url / f'{self._local_peer}.read').open('w') as f:
f.write(str(self._read_progress))
self._logger.debug('Read progress successfully updated: %d', self._read_progress)
except:
self._logger.exception('Reading state fails to dump: %d', self._read_progress)
def _receive_from_file(self, file: Path) -> str | None:
if not file.exists():
self._logger.debug('%s does not exist yet.', file)
return None
try:
with file.open() as f:
for line in f.readlines():
id, content = line.split('\t', 1) # pylint: disable=redefined-builtin
if int(id) > self._read_progress:
content = content.rstrip('\n')
self._logger.debug('Received command: %s', content)
self._read_progress = int(id)
self._save_read_state()
return content
except:
self._logger.exception('File appears to be corrupted: %s', file)
return None

Просмотреть файл

@ -0,0 +1,67 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.
from __future__ import annotations
import logging
from typing_extensions import Literal
import requests
import nni
from nni.typehint import ParameterRecord, TrialMetric
from nni.runtime.env_vars import trial_env_vars
from nni.runtime.trial_command_channel import TrialCommandChannel
from .trial_runner import TrialServerHandler
from .typehint import MetricCommand
_logger = logging.getLogger(__name__)
class TrialClient(TrialCommandChannel):
"""The client side of :class:`TrialServer`."""
def __init__(self, url: str | None = None, trial_id: str | None = None) -> None:
if url is not None:
self._url = url
else:
self._url = TrialServerHandler.ADDRESS
if trial_id is not None:
self._trial_id = trial_id
else:
self._trial_id = trial_env_vars.NNI_TRIAL_JOB_ID
def receive_parameter(self) -> ParameterRecord | None:
response = requests.get(TrialServerHandler.ADDRESS + '/parameter/' + self._trial_id)
if response.status_code != 200:
_logger.error('Failed to receive parameter: %s', response)
return None
parameter = response.json()['parameter']
if not parameter:
_logger.error('Received empty parameter: \'%s\'', parameter)
return None
if not isinstance(parameter, str):
_logger.error('Received invalid parameter: \'%s\'', parameter)
return None
return nni.load(parameter) # Unpack the parameter generated by tuner.
def send_metric(
self,
type: Literal['PERIODICAL', 'FINAL'], # pylint: disable=redefined-builtin
parameter_id: int | None,
trial_job_id: str,
sequence: int,
value: TrialMetric
) -> None:
metric = {
'parameter_id': parameter_id,
'trial_job_id': trial_job_id,
'type': type,
'sequence': sequence,
'value': nni.dump(value), # Pack the metric value, which will be unpacked by tuner.
}
command = MetricCommand(command_type='metric', id=trial_job_id, metric=nni.dump(metric))
response = requests.post(self._url + '/metric', json=command)
if response.status_code != 200:
_logger.error('Failed to send metric: %s', response)

Просмотреть файл

@ -0,0 +1,434 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.
from __future__ import annotations
import argparse
import json
import logging
import os
import subprocess
import sys
import threading
import time
from collections import deque
from functools import partial
from http.server import BaseHTTPRequestHandler, HTTPServer
from pathlib import Path
from typing import Sequence, Callable
from .file_channel import FileChannel
from .typehint import (
CreateCommand, KillCommand, MetricCommand, TrialStatusCommand, WakeUpCommand, ReportAwakeCommand,
Trial, Status, typed_dict_validation
)
from .utils import graceful_kill, add_handler
_logger = logging.getLogger('nni_amlt.trial_runner')
class TrialServerHandler(BaseHTTPRequestHandler):
"""A server for trial to get parameters and report metrics to the trial runner."""
PORT = 36378
ADDRESS = 'http://localhost:36378'
def __init__(self, trials: Sequence[Trial], on_metric: Callable[[MetricCommand], None], *args, **kwargs):
self.trials = trials
self.on_metric = on_metric
# Must be before super.init. The handler will start to handle requests within super.init.
super().__init__(*args, **kwargs)
def _set_headers(self):
self.send_response(200)
self.send_header('Content-type', 'application/json')
self.end_headers()
def _send_bad_request(self):
self.send_response(400)
self.end_headers()
def _send_not_found(self):
self.send_response(404)
self.end_headers()
def _send_ok(self):
self.send_response(200)
self.end_headers()
def do_HEAD(self):
self._set_headers()
def do_GET(self):
"""GET request must be requesting parameters."""
if not self.path.startswith('/parameter/'):
_logger.error('Invalid path for HTTP GET: %s', self.path)
self._send_bad_request()
return
trial_id = self.path.split('/')[-1]
for trial in self.trials:
if trial['id'] == trial_id:
self._set_headers()
self.wfile.write(json.dumps(trial).encode())
return
_logger.error('Trial ID %s not found in parameters', trial_id)
self._send_not_found()
return
def do_POST(self):
"""POST request must be sending results."""
if self.path != '/metric':
_logger.error('Invalid path for HTTP POST: %s', self.path)
self._send_bad_request()
return
content_type = self.headers.get_content_type()
# refuse to receive non-json content
if content_type != 'application/json':
self._send_bad_request()
return
content_length = int(self.headers.get('content-length'))
message = json.loads(self.rfile.read(content_length))
if not typed_dict_validation(MetricCommand, message):
_logger.error('Invalid message: %s', message)
self._send_bad_request()
return
self.on_metric(message)
self._send_ok()
class TrialRunner:
"""
Runner to process incoming trial commands.
Parameters
----------
channel
Channel to communicate with the management end.
The runner only uses the "send" direction of the channel.
runner_dir
Directory for runner to save logs, save/restore checkpoints.
Usually **unshared** between multiple ranks (nodes).
trial_output_dir
Directory for trials to save their output files.
Subdirectory with trial IDs will be created inside.
Usually **shared** between multiple ranks.
trial_log_dir
Path to where trial log is stored.
Usually **unshared** between ranks.
log_buffer_size
Buffer size of trial stdout.
"""
def __init__(self, channel: FileChannel, runner_dir: Path,
trial_output_dir: Path, trial_log_dir: Path,
log_buffer_size: int) -> None:
self._channel = channel
self._runner_dir = runner_dir
self._trial_output_dir = trial_output_dir
self._trial_log_dir = trial_log_dir
self._log_buffer_size = log_buffer_size
self._processing_trials: deque[Trial] = deque() # including the current running one.
self._running_process: subprocess.Popen | None = None
if self._checkpoint_path.exists():
self.load_checkpoint()
self._server_thread = threading.Thread(target=self._server, daemon=True)
@property
def _checkpoint_path(self) -> Path:
return self._runner_dir / 'trial_runner.json'
def _server(self) -> None:
server_address = ('', TrialServerHandler.PORT)
httpd = HTTPServer(server_address, partial(TrialServerHandler, self._processing_trials, self._on_metric))
httpd.serve_forever()
def _on_metric(self, command: MetricCommand) -> None:
self._channel.send(json.dumps(command))
def load_checkpoint(self) -> None:
try:
with self._checkpoint_path.open() as f:
checkpoint_data = json.load(f)
self._processing_trials = deque()
for t in checkpoint_data['queued_trials']:
if typed_dict_validation(Trial, t):
self._processing_trials.append(t)
else:
_logger.error('Ignored when loading checkpoint as it appears not a valid trial: %s', t)
if isinstance(checkpoint_data['parameters'], dict):
self._parameters = checkpoint_data['parameters']
_logger.info('Checkpoint loaded. Processing trials: %s', self._processing_trials)
except:
_logger.exception('Checkpoint loaded failed: %s', self._checkpoint_path)
self._refresh_queue()
def save_checkpoint(self) -> None:
try:
checkpoint_data = {
'queued_trials': list(self._processing_trials),
'parameters': self._parameters
}
self._checkpoint_path.parent.mkdir(exist_ok=True, parents=True)
with self._checkpoint_path.open('w') as f:
json.dump(checkpoint_data, f)
except:
_logger.exception('Checkpoint saved failed: %s', self._checkpoint_path)
def check_status(self) -> list[Trial]:
"""
Check the status of the runner and return processing trials (including running + pending).
The caller should be responsible for :meth:`check_status` regularly.
Otherwise the trials in the queue won't be auto-processed.
"""
# Check the status of current running trial.
self._refresh_queue()
# List running and pending trials.
return list(self._processing_trials)
def create_trial(self, trial: Trial) -> None:
"""Submit a trial for running.
Returns instantly.
"""
self._processing_trials.append(trial)
self._refresh_queue()
def kill_trial(self, id: str) -> None: # pylint: disable=redefined-builtin
"""Kill a trial.
Currently must be the running trial.
"""
if len(self._processing_trials) > 0 and self._running_process is not None:
trial = self._processing_trials[0]
if trial['id'] == id:
graceful_kill(self._running_process)
returncode = self._running_process.returncode
_logger.info('Process %s is killed with exit code: %s', self._running_process, returncode)
self._processing_trials.popleft()
self._emit_status_change(trial['id'], 'interrupted')
self._running_process = None
# Run the next trial if any.
self._refresh_queue()
return
_logger.warning('Trial %s is not running. Cannot kill it.', id)
def send_heartbeat(self) -> float:
"""Send a heartbeat to the other side."""
current_time = time.time()
command = ReportAwakeCommand(
command_type='awake',
time=current_time,
idle=not self._processing_trials
)
self._channel.send(json.dumps(command))
return current_time
def _refresh_queue(self) -> None:
if not self._processing_trials:
_logger.debug('No trials. Nothing to refresh.')
return
# Move the queue. See if the upfront trial is completed,
# and whether the next trial should be run.
if self._running_process is not None:
if self._running_process.poll() is not None:
returncode = self._running_process.returncode
_logger.info('Process %s return with exit code: %s', self._running_process, returncode)
if returncode == 0:
status: Status = 'succeeded'
else:
status: Status = 'failed'
trial = self._processing_trials.popleft()
self._emit_status_change(trial['id'], status)
self._running_process = None
# Run a new trial.
if len(self._processing_trials) > 0 and self._running_process is None:
trial = self._processing_trials[0]
_logger.info('Running: %s', trial['command'])
self._running_process = subprocess.Popen(
trial['command'],
stdout=subprocess.PIPE,
stderr=subprocess.STDOUT,
bufsize=self._log_buffer_size,
shell=True
)
self._start_stdout_logging(self._trial_log_dir / (trial['id'] + '.txt'))
self._emit_status_change(trial['id'], 'running')
def _environ(self, trial: Trial) -> dict:
"""Generate environment variables for a trial."""
environ_base = dict(os.environ)
output_dir = str(self._trial_output_dir / trial['id'])
nni_environ = dict(
NNI_PLATFORM='amlt',
NNI_EXP_ID=trial['experiment'],
NNI_TRIAL_JOB_ID=trial['id'],
NNI_SYS_DIR=output_dir,
NNI_OUTPUT_DIR=output_dir,
NNI_TRIAL_SEQ_ID=trial['sequence'],
)
return {
**environ_base,
**nni_environ
}
def _start_stdout_logging(self, file: Path) -> None:
if self._running_process is None:
_logger.error('No running process to start logging.')
return
def _tee(infile, log_file: Path) -> None:
log_file.parent.mkdir(exist_ok=True, parents=True)
with infile, open(log_file, 'ab') as f:
for line in iter(infile.readline, b''):
f.write(line)
sys.stdout.buffer.write(line)
# Did not flush here.
file.parent.mkdir(exist_ok=True, parents=True)
t = threading.Thread(target=_tee, args=(self._running_process.stdout, file), daemon=True)
t.start()
def _emit_status_change(self, trial_id: str, status: Status) -> None:
command = TrialStatusCommand(
command_type='status',
id=trial_id,
status=status,
)
_logger.debug('Emit status change: %s', command)
self._channel.send(json.dumps(command))
def trial_runner_loop(
channel: str | Path,
out: str | Path,
rank: int,
interval: float,
patience: float,
log_buffer_size: int
) -> None:
output_dir = Path(out)
runner_dir = output_dir / f'trial_runner_{rank}'
trial_log_dir = output_dir / f'logs_{rank}'
runner_dir.mkdir(exist_ok=True, parents=True)
# Init logger if not inited.
add_handler(_logger, runner_dir / f'trial_runner.log')
_logger.info('Trial runner started.')
_logger.info('Saving trial runner states to: %s', runner_dir)
file_channel = FileChannel(channel, f'worker-{rank}', 'manager')
_logger.info('Using channel %s to communicate with NNI manager', file_channel)
log_buffer_size = log_buffer_size
_logger.info('Buffer size for trial stodut: %d', log_buffer_size)
trial_runner = TrialRunner(file_channel, runner_dir, output_dir, trial_log_dir, log_buffer_size)
last_good = time.time()
last_heartbeat = time.time()
heartbeat_interval = interval
trial_runner.send_heartbeat()
while True:
if trial_runner.check_status():
_logger.info('Trial runner has running trials. Be patient.')
last_good = time.time()
trial_runner.save_checkpoint()
# Receive a control command from manager side.
command = file_channel.receive(non_blocking=True)
if command is not None:
try:
command = json.loads(command)
except:
_logger.exception('Command decode error. Skip this command: %s', command)
command = None
if command is not None:
if not isinstance(command, dict) or 'command_type' not in command:
_logger.error('Invalid command: %s', command)
else:
command_type = command['command_type']
if command_type == 'create' and typed_dict_validation(CreateCommand, command):
trial_runner.create_trial(command['trial'])
elif command_type == 'kill' and typed_dict_validation(KillCommand, command):
trial_runner.kill_trial(command['id'])
elif command_type == 'wakeup' and typed_dict_validation(WakeUpCommand, command):
last_heartbeat = trial_runner.send_heartbeat()
else:
_logger.error('Unsupported command: %s', command)
trial_runner.save_checkpoint()
# Reset heartbeat interval to communicate more frequently
heartbeat_interval = interval
# No sleep. Continue to next command.
else:
elapsed = time.time() - last_good
_logger.info('No command received. Patience: %f / %f', elapsed, patience)
if elapsed > patience:
_logger.warning('No command received for too long. Quit the runner.')
break
if time.time() - last_heartbeat > heartbeat_interval:
last_heartbeat = trial_runner.send_heartbeat()
# Exponentially increase heartbeat interval
heartbeat_interval = heartbeat_interval * 1.5
time.sleep(interval)
def main():
parser = argparse.ArgumentParser(description='Amulet training service trial runner')
parser.add_argument('channel', type=str, help='The path where file channel is mounted (in cluster container)')
parser.add_argument('out', type=str, default=None,
help='Checkpoint directory of the trial runner. If specified, trial runner will try to find its checkpoint.')
parser.add_argument('--rank', type=int, default=None,
help='Rank of trial runner. Meaningful for distributed training. '
'If not set, will try to read from environment variable `RANK`.')
parser.add_argument('--interval', type=float, default=60.,
help='Interval (seconds) between two polls of the channel')
parser.add_argument('--heartbeat-max-interval', type=float, default=600.,
help='Max interval (seconds) between two heartbeats. '
'Heartbeat is used to tell the manager that the runner is still alive. '
'The initial heartbeat interval is `interval`. '
'It will be exponentially increased until it reaches this value if no message from manager is received.')
parser.add_argument('--patience', type=float, default=1800.,
help='Number of seconds without any updates or running trials before the runner shutdowns')
parser.add_argument('--log-buffer-size', type=int, default=0,
help='Buffer size for trial stdout. See bufsize in `subprocess.Popen`.')
args = parser.parse_args()
if args.rank is None:
args.rank = int(os.environ.get('RANK', 0))
trial_runner_loop(**vars(args))
if __name__ == '__main__':
main()

Просмотреть файл

@ -0,0 +1,101 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.
import logging
from typing import TypedDict, Any, Type, TypeVar, Optional, get_origin
from typing_extensions import Literal, TypeGuard
_logger = logging.getLogger(__name__)
T = TypeVar('T')
def typed_dict_validation(typ: Type[T], instance: Any) -> TypeGuard[T]:
# https://stackoverflow.com/questions/66665336/validate-python-typeddict-at-runtime
if not isinstance(instance, dict):
_logger.error('Validation failed for %s. Instance is not a dict: %s', typ, type(instance))
return False
for property_name, property_type in typ.__annotations__.items():
if property_name not in instance:
# Check for missing keys
_logger.error('Validation failed for %s. Missing key: %s', typ, property_name)
return False
value = instance[property_name]
if property_type in (int, float, bool, str):
# Check for type equality
if not isinstance(value, property_type):
_logger.error('Validation failed for %s. Wrong type: %s. Expected %s, got %s',
typ, property_name, property_type, type(value))
return False
elif get_origin(property_type) == Literal:
# Check literal.
if value not in property_type.__args__:
_logger.error('Validation failed for %s. Expect literal to be one of %s, got %s',
typ, property_type.__args__, value)
return False
else:
# Assuming a nested typed dict.
result = typed_dict_validation(property_type, value)
if result is False:
return False
return True
class Trial(TypedDict):
id: str
sequence: int
experiment: str
command: str
parameter: Optional[str] # Serialized JSON string.
# time_limit: float
# Command types are as few as possible.
# The implementation also tries to avoid unnecessary dependencies,
# to increase the robustness.
UpstreamCommandType = Literal['create', 'kill', 'wakeup'] # manager -> worker
DownstreamCommandType = Literal['metric', 'status', 'awake'] # worker -> manager
Status = Literal['waiting', 'running', 'succeeded', 'failed', 'interrupted']
class CreateCommand(TypedDict):
command_type: Literal['create']
trial: Trial
class KillCommand(TypedDict):
command_type: Literal['kill']
id: str
class MetricCommand(TypedDict):
command_type: Literal['metric']
id: str
metric: str # Serialized JSON string.
class TrialStatusCommand(TypedDict):
command_type: Literal['status']
id: str
status: Status
class WakeUpCommand(TypedDict):
# Request the worker to report its status (more frequently).
command_type: Literal['wakeup']
class ReportAwakeCommand(TypedDict):
# The only way to report that the worker is alive (and idle or occupied).
command_type: Literal['awake']
time: float
# NOTE: time here is only for verbose.
# It should be avoided from usage because the cluster might have a different time from local.
idle: bool

Просмотреть файл

@ -0,0 +1,138 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.
from __future__ import annotations
import logging
import signal
import sys
import time
from pathlib import Path
from subprocess import Popen, PIPE
_logger = logging.getLogger(__name__)
_logger_init: bool = False
def graceful_kill(popen: Popen) -> int | None:
for retry in [1., 5., 20., 60.]:
_logger.info('Gracefully terminating %s...', popen)
if retry > 10:
_logger.info('Use "terminate" instead of "interrupt".')
popen.terminate()
else:
popen.send_signal(signal.SIGINT)
time.sleep(1.) # Wait for the kill to take effect.
retcode = popen.poll()
if retcode is not None:
return retcode
_logger.warning('%s still alive. Retry to kill in %d seconds.', popen, retry)
time.sleep(retry)
_logger.warning('Force kill process %s...', popen)
time.sleep(10.) # Wait for the kill
retcode = popen.poll()
if retcode is not None:
return retcode
_logger.error('Failed to kill process %s.', popen)
return None
def run_subprocess(command: list[str], log_file: Path, timeout: float | None = None) -> tuple[int, str, str]:
if timeout:
_logger.info('Running command with timeout %f seconds: %s', timeout, command)
else:
_logger.info('Running command: %s', command)
_logger.info('Output saved to: %s', log_file)
stdout, stderr = '', ''
file_handle = None
try:
start_time = time.time()
file_handle = log_file.open('w')
file_handle.write(f'Command: {command}')
proc = Popen(
command,
stdout=PIPE,
stderr=PIPE,
encoding='utf-8',
)
while True:
out, err = proc.communicate(timeout=1)
if out:
sys.stdout.write(out)
sys.stdout.flush()
file_handle.write(out)
stdout += out
if err:
sys.stderr.write(err)
sys.stderr.flush()
file_handle.write(err)
stderr += err
file_handle.flush()
# See if the process has terminated
if proc.poll() is not None:
returncode = proc.returncode
if returncode != 0:
_logger.error('Command failed with return code %d: %s', returncode, command)
else:
_logger.info('Command finished with return code %d: %s', returncode, command)
return returncode, stdout, stderr
# See if we timed out
if timeout is not None and time.time() - start_time > timeout:
_logger.warning('Command timed out (%f seconds): %s', timeout, command)
returncode = graceful_kill(proc)
if returncode is None:
_logger.error('Return code is still none after attempting to kill it. The process (%d) may be stuck.', proc.pid)
returncode = 1
return returncode, stdout, stderr
finally:
if file_handle is not None:
file_handle.close()
def init_logger() -> None:
"""
Initialize the logger. Log to stdout by default.
"""
global _logger_init
if _logger_init:
return
logger = logging.getLogger('nni_amlt')
logger.setLevel(level=logging.INFO)
add_handler(logger)
_logger_init = True
def add_handler(logger: logging.Logger, file: Path | None = None) -> logging.Handler:
"""
Add a logging handler.
If ``file`` is specified, log to file.
Otherwise, add a handler to stdout.
"""
fmt = '[%(asctime)s] %(levelname)s (%(threadName)s:%(name)s) %(message)s'
datefmt = '%Y-%m-%d %H:%M:%S'
formatter = logging.Formatter(fmt, datefmt)
if file is None:
# Log to stdout.
handler = logging.StreamHandler(sys.stdout)
else:
handler = logging.FileHandler(file)
handler.setLevel(level=logging.DEBUG) # Print all the logs.
handler.setFormatter(formatter)
logger.addHandler(handler)
return handler

Просмотреть файл

@ -44,7 +44,15 @@ def set_default_trial_command_channel(channel: Optional[TrialCommandChannel] = N
assert dispatcher_env_vars.SDK_PROCESS != 'dispatcher'
channel_url = trial_env_vars.NNI_TRIAL_COMMAND_CHANNEL
if channel_url:
if isinstance(channel_url, str) and channel_url.startswith('import://'):
_, channel_class_name = channel_url.split('://', 1)
module_name, class_name = channel_class_name.rsplit('.', 1)
module = __import__(module_name)
channel_class = getattr(module, class_name)
_channel = channel_class()
if not isinstance(_channel, TrialCommandChannel):
raise TypeError(f'{_channel} is not an instance of TrialCommandChannel')
elif channel_url:
from .v3 import TrialCommandChannelV3
_channel = TrialCommandChannelV3(channel_url)
elif trial_env_vars.NNI_PLATFORM is None or trial_env_vars.NNI_PLATFORM == 'unittest':