From 1777368ad803c26157eb8e010c20a2aa0103d1d6 Mon Sep 17 00:00:00 2001 From: Yuge Zhang Date: Mon, 6 Mar 2023 15:25:51 +0800 Subject: [PATCH] Bug fix for trial runner and etc. (#5422) --- nni/contrib/training_service/trial_client.py | 7 ++- nni/contrib/training_service/trial_runner.py | 44 ++++++++++++------- nni/contrib/training_service/typehint.py | 4 +- nni/runtime/trial_command_channel/__init__.py | 8 ++-- 4 files changed, 40 insertions(+), 23 deletions(-) diff --git a/nni/contrib/training_service/trial_client.py b/nni/contrib/training_service/trial_client.py index 77abe660e..e5e611cec 100644 --- a/nni/contrib/training_service/trial_client.py +++ b/nni/contrib/training_service/trial_client.py @@ -33,13 +33,13 @@ class TrialClient(TrialCommandChannel): self._trial_id = trial_env_vars.NNI_TRIAL_JOB_ID def receive_parameter(self) -> ParameterRecord | None: - response = requests.get(TrialServerHandler.ADDRESS + '/parameter/' + self._trial_id) + response = requests.get(self._url + '/parameter/' + self._trial_id) if response.status_code != 200: _logger.error('Failed to receive parameter: %s', response) return None parameter = response.json()['parameter'] if not parameter: - _logger.error('Received empty parameter: \'%s\'', parameter) + _logger.warning('Received empty parameter: \'%s\'', parameter) return None if not isinstance(parameter, str): _logger.error('Received invalid parameter: \'%s\'', parameter) @@ -54,6 +54,9 @@ class TrialClient(TrialCommandChannel): sequence: int, value: TrialMetric ) -> None: + if trial_job_id != self._trial_id: + _logger.warning('Trial job id does not match: %s vs. %s. Metric will be ignored.', trial_job_id, self._trial_id) + return metric = { 'parameter_id': parameter_id, 'trial_job_id': trial_job_id, diff --git a/nni/contrib/training_service/trial_runner.py b/nni/contrib/training_service/trial_runner.py index f7d09be66..54a1bc4ee 100644 --- a/nni/contrib/training_service/trial_runner.py +++ b/nni/contrib/training_service/trial_runner.py @@ -3,6 +3,7 @@ from __future__ import annotations +import atexit import argparse import json import logging @@ -141,16 +142,29 @@ class TrialRunner: if self._checkpoint_path.exists(): self.load_checkpoint() - self._server_thread = threading.Thread(target=self._server, daemon=True) + self._server = self._server_start() + atexit.register(self._server_stop) @property def _checkpoint_path(self) -> Path: return self._runner_dir / 'trial_runner.json' - def _server(self) -> None: + def _server_start(self) -> HTTPServer: + _logger.info('Starting trial server at %s.', TrialServerHandler.ADDRESS) + atexit.register(self._server_stop) server_address = ('', TrialServerHandler.PORT) httpd = HTTPServer(server_address, partial(TrialServerHandler, self._processing_trials, self._on_metric)) - httpd.serve_forever() + + def _start() -> None: + httpd.serve_forever() + + threading.Thread(target=_start, daemon=True).start() + return httpd + + def _server_stop(self) -> None: + _logger.info('Stopping trial server.') + atexit.unregister(self._server_stop) + self._server.shutdown() def _on_metric(self, command: MetricCommand) -> None: self._channel.send(json.dumps(command)) @@ -165,11 +179,9 @@ class TrialRunner: self._processing_trials.append(t) else: _logger.error('Ignored when loading checkpoint as it appears not a valid trial: %s', t) - if isinstance(checkpoint_data['parameters'], dict): - self._parameters = checkpoint_data['parameters'] _logger.info('Checkpoint loaded. Processing trials: %s', self._processing_trials) except: - _logger.exception('Checkpoint loaded failed: %s', self._checkpoint_path) + _logger.exception('Checkpoint loading failed: %s', self._checkpoint_path) self._refresh_queue() @@ -177,13 +189,12 @@ class TrialRunner: try: checkpoint_data = { 'queued_trials': list(self._processing_trials), - 'parameters': self._parameters } self._checkpoint_path.parent.mkdir(exist_ok=True, parents=True) with self._checkpoint_path.open('w') as f: json.dump(checkpoint_data, f) except: - _logger.exception('Checkpoint saved failed: %s', self._checkpoint_path) + _logger.exception('Checkpoint saving failed: %s', self._checkpoint_path) def check_status(self) -> list[Trial]: """ @@ -253,6 +264,7 @@ class TrialRunner: else: status: Status = 'failed' trial = self._processing_trials.popleft() + _logger.info('Trial %s ended with status: %s', trial['id'], status) self._emit_status_change(trial['id'], status) self._running_process = None @@ -265,7 +277,8 @@ class TrialRunner: stdout=subprocess.PIPE, stderr=subprocess.STDOUT, bufsize=self._log_buffer_size, - shell=True + shell=True, + env=self._environ(trial) ) self._start_stdout_logging(self._trial_log_dir / (trial['id'] + '.txt')) self._emit_status_change(trial['id'], 'running') @@ -279,9 +292,10 @@ class TrialRunner: NNI_PLATFORM='amlt', NNI_EXP_ID=trial['experiment'], NNI_TRIAL_JOB_ID=trial['id'], - NNI_SYS_DIR=output_dir, - NNI_OUTPUT_DIR=output_dir, - NNI_TRIAL_SEQ_ID=trial['sequence'], + NNI_SYS_DIR=str(output_dir), + NNI_OUTPUT_DIR=str(output_dir), + NNI_TRIAL_SEQ_ID=str(trial['sequence']), + NNI_TRIAL_COMMAND_CHANNEL='import://nni_amlt.trial_client.TrialClient' ) return { @@ -337,10 +351,10 @@ def trial_runner_loop( _logger.info('Saving trial runner states to: %s', runner_dir) file_channel = FileChannel(channel, f'worker-{rank}', 'manager') - _logger.info('Using channel %s to communicate with NNI manager', file_channel) + _logger.info('Using channel %s to communicate with NNI manager.', file_channel) log_buffer_size = log_buffer_size - _logger.info('Buffer size for trial stodut: %d', log_buffer_size) + _logger.info('Buffer size for trial stdout: %d', log_buffer_size) trial_runner = TrialRunner(file_channel, runner_dir, output_dir, trial_log_dir, log_buffer_size) @@ -390,7 +404,7 @@ def trial_runner_loop( else: elapsed = time.time() - last_good - _logger.info('No command received. Patience: %f / %f', elapsed, patience) + _logger.info('No command received. Since last receiving: %f seconds (%f maximum).', elapsed, patience) if elapsed > patience: _logger.warning('No command received for too long. Quit the runner.') diff --git a/nni/contrib/training_service/typehint.py b/nni/contrib/training_service/typehint.py index d0f3512da..b35d2aedb 100644 --- a/nni/contrib/training_service/typehint.py +++ b/nni/contrib/training_service/typehint.py @@ -3,7 +3,7 @@ import logging -from typing import TypedDict, Any, Type, TypeVar, Optional, get_origin +from typing import TypedDict, Any, Type, TypeVar, get_origin from typing_extensions import Literal, TypeGuard _logger = logging.getLogger(__name__) @@ -52,7 +52,7 @@ class Trial(TypedDict): sequence: int experiment: str command: str - parameter: Optional[str] # Serialized JSON string. + parameter: str # Serialized JSON string. If empty, the trial will receive no parameter. # time_limit: float # Command types are as few as possible. diff --git a/nni/runtime/trial_command_channel/__init__.py b/nni/runtime/trial_command_channel/__init__.py index 97b84027b..936b5141a 100644 --- a/nni/runtime/trial_command_channel/__init__.py +++ b/nni/runtime/trial_command_channel/__init__.py @@ -46,10 +46,10 @@ def set_default_trial_command_channel(channel: Optional[TrialCommandChannel] = N channel_url = trial_env_vars.NNI_TRIAL_COMMAND_CHANNEL if isinstance(channel_url, str) and channel_url.startswith('import://'): _, channel_class_name = channel_url.split('://', 1) - module_name, class_name = channel_class_name.rsplit('.', 1) - module = __import__(module_name) - channel_class = getattr(module, class_name) - _channel = channel_class() + path, identifier = channel_class_name.rsplit('.', 1) + module = __import__(path, globals(), locals(), [identifier]) + class_ = getattr(module, identifier) + _channel = class_() if not isinstance(_channel, TrialCommandChannel): raise TypeError(f'{_channel} is not an instance of TrialCommandChannel') elif channel_url: