зеркало из https://github.com/microsoft/archai.git
robustify the error case a bit more so search jobs can continue when a small number of training jobs fail. (#234)
* remove old notebook * store onnx latency allow aml partial training with no snapdragon mode * fix docker file now that aml branch is merged. * fix bug in reset add notebook * add link to notebook. * Add an on_start_iteration callback so that user can track which models came from which iterations. * fix conda file. * new version * robustify the error case a bit more so search jobs can continue when a small number of training jobs fail. * re-use onnx latency numbers.
This commit is contained in:
Родитель
5410e8bdd1
Коммит
4a3cf62a77
|
@ -15,19 +15,88 @@ class JobCompletionMonitor:
|
|||
""" This helper class uses the ArchaiStore to monitor the status of some long running
|
||||
training operations and the status of the Azure ML pipeline those jobs are running in
|
||||
and waits for them to finish (either successfully or with a failure)"""
|
||||
def __init__(self, store : ArchaiStore, ml_client : MLClient, metric_keys: List[str], pipeline_id=None, timeout=3600):
|
||||
def __init__(self, store : ArchaiStore, ml_client : MLClient, metric_keys: List[str], pipeline_id=None, timeout=3600, throw_on_failure_rate=0.1):
|
||||
"""
|
||||
Initialize a JobCompletionMonitor instance.
|
||||
:param store: an instance of ArchaiStore to monitor the status of some long running training operations
|
||||
:param ml_client: an instance of MLClient to check the status of the Azure ML pipeline those jobs are running in
|
||||
:param metric_keys: a list of column names to monitor and return from the Azure table.
|
||||
:param pipeline_id: (optional) the ID of the Azure ML pipeline to monitor, if not provided we can get this from the ArchaiStore.
|
||||
:param timeout: (optional) the timeout in seconds
|
||||
:param throw_on_failure_rate: (optional) the ratio of failed jobs (between 0 and 1) above which an exception is raised.
|
||||
Zero means throw exception on any failure.
|
||||
This is handy if you want to allow the search to continue even when a small percentage of jobs fails.
|
||||
Default is 0.1, meaning an exception is raised when more than 10% of the jobs fail.
|
||||
"""
|
||||
self.store = store
|
||||
self.ml_client = ml_client
|
||||
self.timeout = timeout
|
||||
self.pipeline_id = pipeline_id
|
||||
self.metric_keys = metric_keys
|
||||
self.throw_on_failure_rate = throw_on_failure_rate
|
||||
|
||||
def _check_entity_status(self, waiting, completed):
|
||||
failed = 0
|
||||
for i in range(len(waiting) - 1, -1, -1):
|
||||
id = waiting[i]
|
||||
e = self.store.get_status(id)
|
||||
if self.pipeline_id is None and 'pipeline_id' in e:
|
||||
self.pipeline_id = e['pipeline_id']
|
||||
if e is not None and 'status' in e and (e['status'] == 'complete' or e['status'] == 'failed'):
|
||||
del waiting[i]
|
||||
completed[id] = e
|
||||
if e['status'] == 'failed':
|
||||
error = e['error']
|
||||
print(f'Training job {id} failed with error: {error}')
|
||||
failed += 1
|
||||
else:
|
||||
if len(self.metric_keys) > 0 and self.metric_keys[0] in e:
|
||||
key = self.metric_keys[0]
|
||||
metric = e[key]
|
||||
print(f'Training job {id} completed with {key} = {metric}')
|
||||
else:
|
||||
print(f'Training job {id} completed')
|
||||
return failed
|
||||
|
||||
def _get_model_results(self, model_ids, completed):
|
||||
# stitch together the models.json file from our status table.
|
||||
print('Top model results: ')
|
||||
models = []
|
||||
interesting_columns = self.metric_keys + ['status', 'error', 'epochs']
|
||||
for id in model_ids:
|
||||
row = {'id': id}
|
||||
e = completed[id] if id in completed else {}
|
||||
for key in interesting_columns:
|
||||
if key in e:
|
||||
row[key] = e[key]
|
||||
models += [row]
|
||||
|
||||
return {
|
||||
'models': models
|
||||
}
|
||||
|
||||
def _cancel_waiting_list(self, waiting, pipeline_status):
|
||||
# cancel any remaining jobs in the waiting list by marking an error status on the entity
|
||||
for i in range(len(waiting) - 1, -1, -1):
|
||||
id = waiting[i]
|
||||
del waiting[i]
|
||||
e = self.store.get_status(id)
|
||||
if 'error' not in e:
|
||||
e['error'] = f'Pipeline {pipeline_status}'
|
||||
if 'status' not in e or e['status'] != 'complete':
|
||||
e['status'] = pipeline_status.lower()
|
||||
self.store.merge_status_entity(e)
|
||||
|
||||
def _get_pipeline_status(self):
|
||||
# try and get the status of the Azure ML pipeline, it returns strings like
|
||||
# 'Completed', 'Failed', 'Running', 'Preparing', 'Canceled' and so on.
|
||||
try:
|
||||
if self.pipeline_id is not None:
|
||||
train_job = self.ml_client.jobs.get(self.pipeline_id)
|
||||
if train_job is not None:
|
||||
return train_job.status
|
||||
except Exception as e:
|
||||
print(f'Error getting pipeline status for pipeline {self.pipeline_id}: {e}')
|
||||
|
||||
def wait(self, model_ids: List[str]) -> List[Dict[str, str]]:
|
||||
"""
|
||||
|
@ -42,54 +111,19 @@ class JobCompletionMonitor:
|
|||
failed = 0
|
||||
|
||||
while len(waiting) > 0:
|
||||
for i in range(len(waiting) - 1, -1, -1):
|
||||
id = waiting[i]
|
||||
e = self.store.get_status(id)
|
||||
if self.pipeline_id is None and 'pipeline_id' in e:
|
||||
self.pipeline_id = e['pipeline_id']
|
||||
if e is not None and 'status' in e and (e['status'] == 'complete' or e['status'] == 'failed'):
|
||||
del waiting[i]
|
||||
completed[id] = e
|
||||
if e['status'] == 'failed':
|
||||
error = e['error']
|
||||
print(f'Training job {id} failed with error: {error}')
|
||||
failed += 1
|
||||
self.store.update_status_entity(e)
|
||||
else:
|
||||
if len(self.metric_keys) > 0 and self.metric_keys[0] in e:
|
||||
key = self.metric_keys[0]
|
||||
metric = e[key]
|
||||
print(f'Training job {id} completed with {key} = {metric}')
|
||||
else:
|
||||
print(f'Training job {id} completed')
|
||||
|
||||
failed += self._check_entity_status(waiting, completed)
|
||||
if len(waiting) == 0:
|
||||
break
|
||||
|
||||
# check the overall pipeline status just in case training jobs failed to even start.
|
||||
pipeline_status = None
|
||||
try:
|
||||
if self.pipeline_id is not None:
|
||||
train_job = self.ml_client.jobs.get(self.pipeline_id)
|
||||
if train_job is not None:
|
||||
pipeline_status = train_job.status
|
||||
except Exception as e:
|
||||
print(f'Error getting pipeline status for pipeline {self.pipeline_id}: {e}')
|
||||
|
||||
pipeline_status = self._get_pipeline_status()
|
||||
if pipeline_status is not None:
|
||||
if pipeline_status == 'Completed':
|
||||
# ok, all jobs are done, which means if we still have waiting tasks then they failed to
|
||||
# even start.
|
||||
break
|
||||
self._cancel_waiting_list(waiting, 'failed to start')
|
||||
elif pipeline_status == 'Failed' or pipeline_status == 'Canceled':
|
||||
for id in waiting:
|
||||
e = self.store.get_status(id)
|
||||
if 'error' not in e:
|
||||
e['error'] = f'Pipeline {pipeline_status}'
|
||||
if 'status' not in e or e['status'] != 'complete':
|
||||
e['status'] = pipeline_status.lower()
|
||||
self.store.update_status_entity(e)
|
||||
raise Exception('Partial Training Pipeline failed')
|
||||
self._cancel_waiting_list(waiting, pipeline_status)
|
||||
|
||||
if len(waiting) > 0:
|
||||
if time.time() > self.timeout + start:
|
||||
|
@ -104,24 +138,11 @@ class JobCompletionMonitor:
|
|||
else:
|
||||
raise Exception('Partial Training Pipeline failed to start')
|
||||
|
||||
if failed == len(completed):
|
||||
raise Exception('Partial Training Pipeline failed all jobs')
|
||||
failure_rate = float(failed) / float(len(model_ids))
|
||||
if failure_rate > self.throw_on_failure_rate:
|
||||
raise Exception(f'Partial Training Pipeline failure rate {failure_rate} exceeds allowed threshold of {self.throw_on_failure_rate}')
|
||||
|
||||
# stitch together the models.json file from our status table.
|
||||
print('Top model results: ')
|
||||
models = []
|
||||
interesting_columns = self.metric_keys + ['status', 'error', 'epochs']
|
||||
for id in model_ids:
|
||||
row = {'id': id}
|
||||
e = completed[id] if id in completed else {}
|
||||
for key in interesting_columns:
|
||||
if key in e:
|
||||
row[key] = e[key]
|
||||
models += [row]
|
||||
|
||||
results = {
|
||||
'models': models
|
||||
}
|
||||
results = self._get_model_results(model_ids, completed)
|
||||
|
||||
timespan = time.strftime('%H:%M:%S', time.gmtime(time.time() - start))
|
||||
print(f'Training: Distributed training completed in {timespan} ')
|
||||
|
|
|
@ -12,7 +12,6 @@
|
|||
"console": "integratedTerminal",
|
||||
"justMyCode": true,
|
||||
"args":[
|
||||
"--search_config", "output/confs/aml_search.yaml"
|
||||
]
|
||||
}
|
||||
]
|
||||
|
|
|
@ -129,7 +129,7 @@ def main(output_dir: Path, experiment_name: str, seed: int):
|
|||
ml_client,
|
||||
image="mcr.microsoft.com/azureml/openmpi3.1.2-ubuntu18.04:latest",
|
||||
conda_file="conda.yaml",
|
||||
version='1.0.22')
|
||||
version='1.0.25')
|
||||
environment_name = f"{archai_job_env.name}:{archai_job_env.version}"
|
||||
|
||||
# Register the datastore with AML
|
||||
|
|
Различия файлов скрыты, потому что одна или несколько строк слишком длинны
|
@ -54,6 +54,7 @@ class AmlPartialTrainingEvaluator(AsyncModelEvaluator):
|
|||
self.store = configure_store(aml_config)
|
||||
self.results = []
|
||||
self.metric_key = self.config['training'].get('metric_key', 'val_iou')
|
||||
self.failure_rate = 0.25
|
||||
|
||||
@overrides
|
||||
def send(self, arch: ArchaiModel, budget: Optional[float] = None) -> None:
|
||||
|
@ -61,7 +62,6 @@ class AmlPartialTrainingEvaluator(AsyncModelEvaluator):
|
|||
model_id = get_valid_arch_id(arch)
|
||||
e = self.store.get_status(model_id)
|
||||
if self.metric_key in e and e[self.metric_key]:
|
||||
|
||||
metric = float(e[self.metric_key])
|
||||
self.results += [{
|
||||
'id': model_id,
|
||||
|
@ -102,7 +102,7 @@ class AmlPartialTrainingEvaluator(AsyncModelEvaluator):
|
|||
|
||||
# wait for all the parallel training jobs to finish
|
||||
keys = [self.metric_key]
|
||||
monitor = JobCompletionMonitor(self.store, self.ml_client, keys, job_id, self.timeout)
|
||||
monitor = JobCompletionMonitor(self.store, self.ml_client, keys, job_id, self.timeout, throw_on_failure_rate=self.failure_rate)
|
||||
models = monitor.wait(model_names)['models']
|
||||
for m in models:
|
||||
id = m['id']
|
||||
|
@ -135,7 +135,10 @@ class AmlPartialTrainingEvaluator(AsyncModelEvaluator):
|
|||
# not so good.
|
||||
metrics = []
|
||||
for m in results['models']:
|
||||
metric = m[self.metric_key]
|
||||
if self.metric_key in m:
|
||||
metric = m[self.metric_key]
|
||||
else:
|
||||
metric = None
|
||||
metrics += [metric]
|
||||
|
||||
self.models = [] # reset for next run.
|
||||
|
|
|
@ -40,12 +40,21 @@ class AvgOnnxLatencyEvaluator(AvgOnnxLatency):
|
|||
|
||||
@overrides
|
||||
def evaluate(self, model: ArchaiModel, budget: Optional[float] = None) -> float:
|
||||
archid = f'id_{model.archid}'
|
||||
if self.store is not None:
|
||||
e = self.store.get_status(archid)
|
||||
if 'iteration' not in e or e['iteration'] != self.iteration:
|
||||
e['iteration'] = self.iteration
|
||||
self.store.merge_status_entity(e)
|
||||
if self.metric_key in e:
|
||||
# Use the cached value and skip the more expensive Onnx evaluation.
|
||||
# This also ensures maximum re-use of previous training jobs.
|
||||
return e[self.metric_key]
|
||||
|
||||
result = super(AvgOnnxLatencyEvaluator, self).evaluate(model, budget)
|
||||
if self.store is not None:
|
||||
archid = f'id_{model.archid}'
|
||||
e = self.store.get_status(archid)
|
||||
e['status'] = 'complete'
|
||||
e['iteration'] = self.iteration
|
||||
e[self.metric_key] = result
|
||||
self.store.merge_status_entity(e)
|
||||
return result
|
||||
|
|
|
@ -79,4 +79,4 @@ def copy_code_folder(src_dir, target_dir):
|
|||
|
||||
def get_valid_arch_id(arch: ArchaiModel):
|
||||
# bug in azure ml sdk requires blob store folder names not begin with digits, so we prefix with 'id_'
|
||||
return f'id_{arch.archid}'
|
||||
return f'id_{arch.archid}'
|
||||
|
|
|
@ -91,7 +91,10 @@ aml:
|
|||
- matplotlib
|
||||
- mldesigner
|
||||
- mlflow
|
||||
- tensorwatch
|
||||
- torch
|
||||
- torchvision
|
||||
- torchaudio
|
||||
- transformers==4.27.4
|
||||
- xformers
|
||||
- archai[dev] @ git+https://github.com/microsoft/archai.git
|
|
@ -131,4 +131,4 @@ def main():
|
|||
|
||||
|
||||
if __name__ == '__main__':
|
||||
main()
|
||||
main()
|
||||
|
|
Загрузка…
Ссылка в новой задаче