зеркало из https://github.com/microsoft/nni.git
Bug fix for trial runner and etc. (#5422)
This commit is contained in:
Родитель
33d41faac4
Коммит
1777368ad8
|
@ -33,13 +33,13 @@ class TrialClient(TrialCommandChannel):
|
|||
self._trial_id = trial_env_vars.NNI_TRIAL_JOB_ID
|
||||
|
||||
def receive_parameter(self) -> ParameterRecord | None:
|
||||
response = requests.get(TrialServerHandler.ADDRESS + '/parameter/' + self._trial_id)
|
||||
response = requests.get(self._url + '/parameter/' + self._trial_id)
|
||||
if response.status_code != 200:
|
||||
_logger.error('Failed to receive parameter: %s', response)
|
||||
return None
|
||||
parameter = response.json()['parameter']
|
||||
if not parameter:
|
||||
_logger.error('Received empty parameter: \'%s\'', parameter)
|
||||
_logger.warning('Received empty parameter: \'%s\'', parameter)
|
||||
return None
|
||||
if not isinstance(parameter, str):
|
||||
_logger.error('Received invalid parameter: \'%s\'', parameter)
|
||||
|
@ -54,6 +54,9 @@ class TrialClient(TrialCommandChannel):
|
|||
sequence: int,
|
||||
value: TrialMetric
|
||||
) -> None:
|
||||
if trial_job_id != self._trial_id:
|
||||
_logger.warning('Trial job id does not match: %s vs. %s. Metric will be ignored.', trial_job_id, self._trial_id)
|
||||
return
|
||||
metric = {
|
||||
'parameter_id': parameter_id,
|
||||
'trial_job_id': trial_job_id,
|
||||
|
|
|
@ -3,6 +3,7 @@
|
|||
|
||||
from __future__ import annotations
|
||||
|
||||
import atexit
|
||||
import argparse
|
||||
import json
|
||||
import logging
|
||||
|
@ -141,16 +142,29 @@ class TrialRunner:
|
|||
if self._checkpoint_path.exists():
|
||||
self.load_checkpoint()
|
||||
|
||||
self._server_thread = threading.Thread(target=self._server, daemon=True)
|
||||
self._server = self._server_start()
|
||||
atexit.register(self._server_stop)
|
||||
|
||||
@property
|
||||
def _checkpoint_path(self) -> Path:
|
||||
return self._runner_dir / 'trial_runner.json'
|
||||
|
||||
def _server(self) -> None:
|
||||
def _server_start(self) -> HTTPServer:
|
||||
_logger.info('Starting trial server at %s.', TrialServerHandler.ADDRESS)
|
||||
atexit.register(self._server_stop)
|
||||
server_address = ('', TrialServerHandler.PORT)
|
||||
httpd = HTTPServer(server_address, partial(TrialServerHandler, self._processing_trials, self._on_metric))
|
||||
httpd.serve_forever()
|
||||
|
||||
def _start() -> None:
|
||||
httpd.serve_forever()
|
||||
|
||||
threading.Thread(target=_start, daemon=True).start()
|
||||
return httpd
|
||||
|
||||
def _server_stop(self) -> None:
|
||||
_logger.info('Stopping trial server.')
|
||||
atexit.unregister(self._server_stop)
|
||||
self._server.shutdown()
|
||||
|
||||
def _on_metric(self, command: MetricCommand) -> None:
|
||||
self._channel.send(json.dumps(command))
|
||||
|
@ -165,11 +179,9 @@ class TrialRunner:
|
|||
self._processing_trials.append(t)
|
||||
else:
|
||||
_logger.error('Ignored when loading checkpoint as it appears not a valid trial: %s', t)
|
||||
if isinstance(checkpoint_data['parameters'], dict):
|
||||
self._parameters = checkpoint_data['parameters']
|
||||
_logger.info('Checkpoint loaded. Processing trials: %s', self._processing_trials)
|
||||
except:
|
||||
_logger.exception('Checkpoint loaded failed: %s', self._checkpoint_path)
|
||||
_logger.exception('Checkpoint loading failed: %s', self._checkpoint_path)
|
||||
|
||||
self._refresh_queue()
|
||||
|
||||
|
@ -177,13 +189,12 @@ class TrialRunner:
|
|||
try:
|
||||
checkpoint_data = {
|
||||
'queued_trials': list(self._processing_trials),
|
||||
'parameters': self._parameters
|
||||
}
|
||||
self._checkpoint_path.parent.mkdir(exist_ok=True, parents=True)
|
||||
with self._checkpoint_path.open('w') as f:
|
||||
json.dump(checkpoint_data, f)
|
||||
except:
|
||||
_logger.exception('Checkpoint saved failed: %s', self._checkpoint_path)
|
||||
_logger.exception('Checkpoint saving failed: %s', self._checkpoint_path)
|
||||
|
||||
def check_status(self) -> list[Trial]:
|
||||
"""
|
||||
|
@ -253,6 +264,7 @@ class TrialRunner:
|
|||
else:
|
||||
status: Status = 'failed'
|
||||
trial = self._processing_trials.popleft()
|
||||
_logger.info('Trial %s ended with status: %s', trial['id'], status)
|
||||
self._emit_status_change(trial['id'], status)
|
||||
self._running_process = None
|
||||
|
||||
|
@ -265,7 +277,8 @@ class TrialRunner:
|
|||
stdout=subprocess.PIPE,
|
||||
stderr=subprocess.STDOUT,
|
||||
bufsize=self._log_buffer_size,
|
||||
shell=True
|
||||
shell=True,
|
||||
env=self._environ(trial)
|
||||
)
|
||||
self._start_stdout_logging(self._trial_log_dir / (trial['id'] + '.txt'))
|
||||
self._emit_status_change(trial['id'], 'running')
|
||||
|
@ -279,9 +292,10 @@ class TrialRunner:
|
|||
NNI_PLATFORM='amlt',
|
||||
NNI_EXP_ID=trial['experiment'],
|
||||
NNI_TRIAL_JOB_ID=trial['id'],
|
||||
NNI_SYS_DIR=output_dir,
|
||||
NNI_OUTPUT_DIR=output_dir,
|
||||
NNI_TRIAL_SEQ_ID=trial['sequence'],
|
||||
NNI_SYS_DIR=str(output_dir),
|
||||
NNI_OUTPUT_DIR=str(output_dir),
|
||||
NNI_TRIAL_SEQ_ID=str(trial['sequence']),
|
||||
NNI_TRIAL_COMMAND_CHANNEL='import://nni_amlt.trial_client.TrialClient'
|
||||
)
|
||||
|
||||
return {
|
||||
|
@ -337,10 +351,10 @@ def trial_runner_loop(
|
|||
_logger.info('Saving trial runner states to: %s', runner_dir)
|
||||
|
||||
file_channel = FileChannel(channel, f'worker-{rank}', 'manager')
|
||||
_logger.info('Using channel %s to communicate with NNI manager', file_channel)
|
||||
_logger.info('Using channel %s to communicate with NNI manager.', file_channel)
|
||||
|
||||
log_buffer_size = log_buffer_size
|
||||
_logger.info('Buffer size for trial stodut: %d', log_buffer_size)
|
||||
_logger.info('Buffer size for trial stdout: %d', log_buffer_size)
|
||||
|
||||
trial_runner = TrialRunner(file_channel, runner_dir, output_dir, trial_log_dir, log_buffer_size)
|
||||
|
||||
|
@ -390,7 +404,7 @@ def trial_runner_loop(
|
|||
|
||||
else:
|
||||
elapsed = time.time() - last_good
|
||||
_logger.info('No command received. Patience: %f / %f', elapsed, patience)
|
||||
_logger.info('No command received. Since last receiving: %f seconds (%f maximum).', elapsed, patience)
|
||||
|
||||
if elapsed > patience:
|
||||
_logger.warning('No command received for too long. Quit the runner.')
|
||||
|
|
|
@ -3,7 +3,7 @@
|
|||
|
||||
import logging
|
||||
|
||||
from typing import TypedDict, Any, Type, TypeVar, Optional, get_origin
|
||||
from typing import TypedDict, Any, Type, TypeVar, get_origin
|
||||
from typing_extensions import Literal, TypeGuard
|
||||
|
||||
_logger = logging.getLogger(__name__)
|
||||
|
@ -52,7 +52,7 @@ class Trial(TypedDict):
|
|||
sequence: int
|
||||
experiment: str
|
||||
command: str
|
||||
parameter: Optional[str] # Serialized JSON string.
|
||||
parameter: str # Serialized JSON string. If empty, the trial will receive no parameter.
|
||||
# time_limit: float
|
||||
|
||||
# Command types are as few as possible.
|
||||
|
|
|
@ -46,10 +46,10 @@ def set_default_trial_command_channel(channel: Optional[TrialCommandChannel] = N
|
|||
channel_url = trial_env_vars.NNI_TRIAL_COMMAND_CHANNEL
|
||||
if isinstance(channel_url, str) and channel_url.startswith('import://'):
|
||||
_, channel_class_name = channel_url.split('://', 1)
|
||||
module_name, class_name = channel_class_name.rsplit('.', 1)
|
||||
module = __import__(module_name)
|
||||
channel_class = getattr(module, class_name)
|
||||
_channel = channel_class()
|
||||
path, identifier = channel_class_name.rsplit('.', 1)
|
||||
module = __import__(path, globals(), locals(), [identifier])
|
||||
class_ = getattr(module, identifier)
|
||||
_channel = class_()
|
||||
if not isinstance(_channel, TrialCommandChannel):
|
||||
raise TypeError(f'{_channel} is not an instance of TrialCommandChannel')
|
||||
elif channel_url:
|
||||
|
|
Загрузка…
Ссылка в новой задаче