Bug fix for trial runner and etc. (#5422)

This commit is contained in:
Yuge Zhang 2023-03-06 15:25:51 +08:00 коммит произвёл GitHub
Родитель 33d41faac4
Коммит 1777368ad8
Не найден ключ, соответствующий данной подписи
Идентификатор ключа GPG: 4AEE18F83AFDEB23
4 изменённых файлов: 40 добавлений и 23 удалений

Просмотреть файл

@ -33,13 +33,13 @@ class TrialClient(TrialCommandChannel):
self._trial_id = trial_env_vars.NNI_TRIAL_JOB_ID
def receive_parameter(self) -> ParameterRecord | None:
response = requests.get(TrialServerHandler.ADDRESS + '/parameter/' + self._trial_id)
response = requests.get(self._url + '/parameter/' + self._trial_id)
if response.status_code != 200:
_logger.error('Failed to receive parameter: %s', response)
return None
parameter = response.json()['parameter']
if not parameter:
_logger.error('Received empty parameter: \'%s\'', parameter)
_logger.warning('Received empty parameter: \'%s\'', parameter)
return None
if not isinstance(parameter, str):
_logger.error('Received invalid parameter: \'%s\'', parameter)
@ -54,6 +54,9 @@ class TrialClient(TrialCommandChannel):
sequence: int,
value: TrialMetric
) -> None:
if trial_job_id != self._trial_id:
_logger.warning('Trial job id does not match: %s vs. %s. Metric will be ignored.', trial_job_id, self._trial_id)
return
metric = {
'parameter_id': parameter_id,
'trial_job_id': trial_job_id,

Просмотреть файл

@ -3,6 +3,7 @@
from __future__ import annotations
import atexit
import argparse
import json
import logging
@ -141,16 +142,29 @@ class TrialRunner:
if self._checkpoint_path.exists():
self.load_checkpoint()
self._server_thread = threading.Thread(target=self._server, daemon=True)
self._server = self._server_start()
atexit.register(self._server_stop)
@property
def _checkpoint_path(self) -> Path:
return self._runner_dir / 'trial_runner.json'
def _server(self) -> None:
def _server_start(self) -> HTTPServer:
_logger.info('Starting trial server at %s.', TrialServerHandler.ADDRESS)
atexit.register(self._server_stop)
server_address = ('', TrialServerHandler.PORT)
httpd = HTTPServer(server_address, partial(TrialServerHandler, self._processing_trials, self._on_metric))
httpd.serve_forever()
def _start() -> None:
httpd.serve_forever()
threading.Thread(target=_start, daemon=True).start()
return httpd
def _server_stop(self) -> None:
_logger.info('Stopping trial server.')
atexit.unregister(self._server_stop)
self._server.shutdown()
def _on_metric(self, command: MetricCommand) -> None:
self._channel.send(json.dumps(command))
@ -165,11 +179,9 @@ class TrialRunner:
self._processing_trials.append(t)
else:
_logger.error('Ignored when loading checkpoint as it appears not a valid trial: %s', t)
if isinstance(checkpoint_data['parameters'], dict):
self._parameters = checkpoint_data['parameters']
_logger.info('Checkpoint loaded. Processing trials: %s', self._processing_trials)
except:
_logger.exception('Checkpoint loaded failed: %s', self._checkpoint_path)
_logger.exception('Checkpoint loading failed: %s', self._checkpoint_path)
self._refresh_queue()
@ -177,13 +189,12 @@ class TrialRunner:
try:
checkpoint_data = {
'queued_trials': list(self._processing_trials),
'parameters': self._parameters
}
self._checkpoint_path.parent.mkdir(exist_ok=True, parents=True)
with self._checkpoint_path.open('w') as f:
json.dump(checkpoint_data, f)
except:
_logger.exception('Checkpoint saved failed: %s', self._checkpoint_path)
_logger.exception('Checkpoint saving failed: %s', self._checkpoint_path)
def check_status(self) -> list[Trial]:
"""
@ -253,6 +264,7 @@ class TrialRunner:
else:
status: Status = 'failed'
trial = self._processing_trials.popleft()
_logger.info('Trial %s ended with status: %s', trial['id'], status)
self._emit_status_change(trial['id'], status)
self._running_process = None
@ -265,7 +277,8 @@ class TrialRunner:
stdout=subprocess.PIPE,
stderr=subprocess.STDOUT,
bufsize=self._log_buffer_size,
shell=True
shell=True,
env=self._environ(trial)
)
self._start_stdout_logging(self._trial_log_dir / (trial['id'] + '.txt'))
self._emit_status_change(trial['id'], 'running')
@ -279,9 +292,10 @@ class TrialRunner:
NNI_PLATFORM='amlt',
NNI_EXP_ID=trial['experiment'],
NNI_TRIAL_JOB_ID=trial['id'],
NNI_SYS_DIR=output_dir,
NNI_OUTPUT_DIR=output_dir,
NNI_TRIAL_SEQ_ID=trial['sequence'],
NNI_SYS_DIR=str(output_dir),
NNI_OUTPUT_DIR=str(output_dir),
NNI_TRIAL_SEQ_ID=str(trial['sequence']),
NNI_TRIAL_COMMAND_CHANNEL='import://nni_amlt.trial_client.TrialClient'
)
return {
@ -337,10 +351,10 @@ def trial_runner_loop(
_logger.info('Saving trial runner states to: %s', runner_dir)
file_channel = FileChannel(channel, f'worker-{rank}', 'manager')
_logger.info('Using channel %s to communicate with NNI manager', file_channel)
_logger.info('Using channel %s to communicate with NNI manager.', file_channel)
log_buffer_size = log_buffer_size
_logger.info('Buffer size for trial stodut: %d', log_buffer_size)
_logger.info('Buffer size for trial stdout: %d', log_buffer_size)
trial_runner = TrialRunner(file_channel, runner_dir, output_dir, trial_log_dir, log_buffer_size)
@ -390,7 +404,7 @@ def trial_runner_loop(
else:
elapsed = time.time() - last_good
_logger.info('No command received. Patience: %f / %f', elapsed, patience)
_logger.info('No command received. Since last receiving: %f seconds (%f maximum).', elapsed, patience)
if elapsed > patience:
_logger.warning('No command received for too long. Quit the runner.')

Просмотреть файл

@ -3,7 +3,7 @@
import logging
from typing import TypedDict, Any, Type, TypeVar, Optional, get_origin
from typing import TypedDict, Any, Type, TypeVar, get_origin
from typing_extensions import Literal, TypeGuard
_logger = logging.getLogger(__name__)
@ -52,7 +52,7 @@ class Trial(TypedDict):
sequence: int
experiment: str
command: str
parameter: Optional[str] # Serialized JSON string.
parameter: str # Serialized JSON string. If empty, the trial will receive no parameter.
# time_limit: float
# Command types are as few as possible.

Просмотреть файл

@ -46,10 +46,10 @@ def set_default_trial_command_channel(channel: Optional[TrialCommandChannel] = N
channel_url = trial_env_vars.NNI_TRIAL_COMMAND_CHANNEL
if isinstance(channel_url, str) and channel_url.startswith('import://'):
_, channel_class_name = channel_url.split('://', 1)
module_name, class_name = channel_class_name.rsplit('.', 1)
module = __import__(module_name)
channel_class = getattr(module, class_name)
_channel = channel_class()
path, identifier = channel_class_name.rsplit('.', 1)
module = __import__(path, globals(), locals(), [identifier])
class_ = getattr(module, identifier)
_channel = class_()
if not isinstance(_channel, TrialCommandChannel):
raise TypeError(f'{_channel} is not an instance of TrialCommandChannel')
elif channel_url: