robustify the error case a bit more so search jobs can continue when a small number of training jobs fail. (#234)

* remove old notebook

* store onnx latency
allow aml partial training with no snapdragon mode

* fix docker file now that aml branch is merged.

* fix bug in reset
add notebook

* add link to notebook.

* Add an on_start_iteration callback so that user can track which models came from which iterations.

* fix conda file.

* new version

* robustify the error case a bit more so search jobs can continue when a small number of training jobs fail.

* re-use onnx latency numbers.
This commit is contained in:
Chris Lovett 2023-04-24 17:53:25 -07:00 коммит произвёл GitHub
Родитель 5410e8bdd1
Коммит 4a3cf62a77
Не найден ключ, соответствующий данной подписи
Идентификатор ключа GPG: 4AEE18F83AFDEB23
9 изменённых файлов: 110 добавлений и 75 удалений

Просмотреть файл

@ -15,19 +15,88 @@ class JobCompletionMonitor:
""" This helper class uses the ArchaiStore to monitor the status of some long running
training operations and the status of the Azure ML pipeline those jobs are running in
and waits for them to finish (either successfully or with a failure)"""
def __init__(self, store : ArchaiStore, ml_client : MLClient, metric_keys: List[str], pipeline_id=None, timeout=3600, throw_on_failure_rate=0.1):
    """
    Initialize a JobCompletionMonitor instance.

    :param store: an instance of ArchaiStore to monitor the status of some long running training operations
    :param ml_client: an instance of MLClient to check the status of the Azure ML pipeline those jobs are running in
    :param metric_keys: a list of column names to monitor and return from the Azure table.
    :param pipeline_id: (optional) the ID of the Azure ML pipeline to monitor, if not provided we can get this from the ArchaiStore.
    :param timeout: (optional) the timeout in seconds
    :param throw_on_failure_rate: (optional) the ratio of failed jobs (between 0 and 1) that should result in raising an exception.
        Zero means an exception is raised on any failure.
        This is handy if you want to allow the search to continue even when a small percentage of jobs fails.
        The default of 0.1 means an exception is raised when 10% or more of the jobs fail.
    """
    self.store = store
    self.ml_client = ml_client
    self.timeout = timeout
    self.pipeline_id = pipeline_id
    self.metric_keys = metric_keys
    self.throw_on_failure_rate = throw_on_failure_rate
def _check_entity_status(self, waiting, completed):
failed = 0
for i in range(len(waiting) - 1, -1, -1):
id = waiting[i]
e = self.store.get_status(id)
if self.pipeline_id is None and 'pipeline_id' in e:
self.pipeline_id = e['pipeline_id']
if e is not None and 'status' in e and (e['status'] == 'complete' or e['status'] == 'failed'):
del waiting[i]
completed[id] = e
if e['status'] == 'failed':
error = e['error']
print(f'Training job {id} failed with error: {error}')
failed += 1
else:
if len(self.metric_keys) > 0 and self.metric_keys[0] in e:
key = self.metric_keys[0]
metric = e[key]
print(f'Training job {id} completed with {key} = {metric}')
else:
print(f'Training job {id} completed')
return failed
def _get_model_results(self, model_ids, completed):
# stitch together the models.json file from our status table.
print('Top model results: ')
models = []
interesting_columns = self.metric_keys + ['status', 'error', 'epochs']
for id in model_ids:
row = {'id': id}
e = completed[id] if id in completed else {}
for key in interesting_columns:
if key in e:
row[key] = e[key]
models += [row]
return {
'models': models
}
def _cancel_waiting_list(self, waiting, pipeline_status):
# cancel any remaining jobs in the waiting list by marking an error status on the entity
for i in range(len(waiting) - 1, -1, -1):
id = waiting[i]
del waiting[i]
e = self.store.get_status(id)
if 'error' not in e:
e['error'] = f'Pipeline {pipeline_status}'
if 'status' not in e or e['status'] != 'complete':
e['status'] = pipeline_status.lower()
self.store.merge_status_entity(e)
def _get_pipeline_status(self):
# try and get the status of the Azure ML pipeline, it returns strings like
# 'Completed', 'Failed', 'Running', 'Preparing', 'Canceled' and so on.
try:
if self.pipeline_id is not None:
train_job = self.ml_client.jobs.get(self.pipeline_id)
if train_job is not None:
return train_job.status
except Exception as e:
print(f'Error getting pipeline status for pipeline {self.pipeline_id}: {e}')
def wait(self, model_ids: List[str]) -> List[Dict[str, str]]:
"""
@ -42,54 +111,19 @@ class JobCompletionMonitor:
failed = 0
while len(waiting) > 0:
for i in range(len(waiting) - 1, -1, -1):
id = waiting[i]
e = self.store.get_status(id)
if self.pipeline_id is None and 'pipeline_id' in e:
self.pipeline_id = e['pipeline_id']
if e is not None and 'status' in e and (e['status'] == 'complete' or e['status'] == 'failed'):
del waiting[i]
completed[id] = e
if e['status'] == 'failed':
error = e['error']
print(f'Training job {id} failed with error: {error}')
failed += 1
self.store.update_status_entity(e)
else:
if len(self.metric_keys) > 0 and self.metric_keys[0] in e:
key = self.metric_keys[0]
metric = e[key]
print(f'Training job {id} completed with {key} = {metric}')
else:
print(f'Training job {id} completed')
failed += self._check_entity_status(waiting, completed)
if len(waiting) == 0:
break
# check the overall pipeline status just in case training jobs failed to even start.
pipeline_status = None
try:
if self.pipeline_id is not None:
train_job = self.ml_client.jobs.get(self.pipeline_id)
if train_job is not None:
pipeline_status = train_job.status
except Exception as e:
print(f'Error getting pipeline status for pipeline {self.pipeline_id}: {e}')
pipeline_status = self._get_pipeline_status()
if pipeline_status is not None:
if pipeline_status == 'Completed':
# ok, all jobs are done, which means if we still have waiting tasks then they failed to
# even start.
break
self._cancel_waiting_list(waiting, 'failed to start')
elif pipeline_status == 'Failed' or pipeline_status == 'Canceled':
for id in waiting:
e = self.store.get_status(id)
if 'error' not in e:
e['error'] = f'Pipeline {pipeline_status}'
if 'status' not in e or e['status'] != 'complete':
e['status'] = pipeline_status.lower()
self.store.update_status_entity(e)
raise Exception('Partial Training Pipeline failed')
self._cancel_waiting_list(waiting, pipeline_status)
if len(waiting) > 0:
if time.time() > self.timeout + start:
@ -104,24 +138,11 @@ class JobCompletionMonitor:
else:
raise Exception('Partial Training Pipeline failed to start')
if failed == len(completed):
raise Exception('Partial Training Pipeline failed all jobs')
failure_rate = float(failed) / float(len(model_ids))
if failure_rate > self.throw_on_failure_rate:
raise Exception(f'Partial Training Pipeline failure rate {failure_rate} exceeds allowed threshold of {self.throw_on_failure_rate}')
# stitch together the models.json file from our status table.
print('Top model results: ')
models = []
interesting_columns = self.metric_keys + ['status', 'error', 'epochs']
for id in model_ids:
row = {'id': id}
e = completed[id] if id in completed else {}
for key in interesting_columns:
if key in e:
row[key] = e[key]
models += [row]
results = {
'models': models
}
results = self._get_model_results(model_ids, completed)
timespan = time.strftime('%H:%M:%S', time.gmtime(time.time() - start))
print(f'Training: Distributed training completed in {timespan} ')

1
tasks/face_segmentation/.vscode/launch.json поставляемый
Просмотреть файл

@ -12,7 +12,6 @@
"console": "integratedTerminal",
"justMyCode": true,
"args":[
"--search_config", "output/confs/aml_search.yaml"
]
}
]

Просмотреть файл

@ -129,7 +129,7 @@ def main(output_dir: Path, experiment_name: str, seed: int):
ml_client,
image="mcr.microsoft.com/azureml/openmpi3.1.2-ubuntu18.04:latest",
conda_file="conda.yaml",
version='1.0.22')
version='1.0.25')
environment_name = f"{archai_job_env.name}:{archai_job_env.version}"
# Register the datastore with AML

Различия файлов скрыты, потому что одна или несколько строк слишком длинны

Просмотреть файл

@ -54,6 +54,7 @@ class AmlPartialTrainingEvaluator(AsyncModelEvaluator):
self.store = configure_store(aml_config)
self.results = []
self.metric_key = self.config['training'].get('metric_key', 'val_iou')
self.failure_rate = 0.25
@overrides
def send(self, arch: ArchaiModel, budget: Optional[float] = None) -> None:
@ -61,7 +62,6 @@ class AmlPartialTrainingEvaluator(AsyncModelEvaluator):
model_id = get_valid_arch_id(arch)
e = self.store.get_status(model_id)
if self.metric_key in e and e[self.metric_key]:
metric = float(e[self.metric_key])
self.results += [{
'id': model_id,
@ -102,7 +102,7 @@ class AmlPartialTrainingEvaluator(AsyncModelEvaluator):
# wait for all the parallel training jobs to finish
keys = [self.metric_key]
monitor = JobCompletionMonitor(self.store, self.ml_client, keys, job_id, self.timeout)
monitor = JobCompletionMonitor(self.store, self.ml_client, keys, job_id, self.timeout, throw_on_failure_rate=self.failure_rate)
models = monitor.wait(model_names)['models']
for m in models:
id = m['id']
@ -135,7 +135,10 @@ class AmlPartialTrainingEvaluator(AsyncModelEvaluator):
# not so good.
metrics = []
for m in results['models']:
metric = m[self.metric_key]
if self.metric_key in m:
metric = m[self.metric_key]
else:
metric = None
metrics += [metric]
self.models = [] # reset for next run.

Просмотреть файл

@ -40,12 +40,21 @@ class AvgOnnxLatencyEvaluator(AvgOnnxLatency):
@overrides
def evaluate(self, model: ArchaiModel, budget: Optional[float] = None) -> float:
archid = f'id_{model.archid}'
if self.store is not None:
e = self.store.get_status(archid)
if 'iteration' not in e or e['iteration'] != self.iteration:
e['iteration'] = self.iteration
self.store.merge_status_entity(e)
if self.metric_key in e:
# Use the cached value and skip the more expensive Onnx evaluation.
# This also ensures maximum re-use of previous training jobs.
return e[self.metric_key]
result = super(AvgOnnxLatencyEvaluator, self).evaluate(model, budget)
if self.store is not None:
archid = f'id_{model.archid}'
e = self.store.get_status(archid)
e['status'] = 'complete'
e['iteration'] = self.iteration
e[self.metric_key] = result
self.store.merge_status_entity(e)
return result

Просмотреть файл

@ -79,4 +79,4 @@ def copy_code_folder(src_dir, target_dir):
def get_valid_arch_id(arch: ArchaiModel):
# bug in azure ml sdk requires blob store folder names not begin with digits, so we prefix with 'id_'
return f'id_{arch.archid}'
return f'id_{arch.archid}'

Просмотреть файл

@ -91,7 +91,10 @@ aml:
- matplotlib
- mldesigner
- mlflow
- tensorwatch
- torch
- torchvision
- torchaudio
- transformers==4.27.4
- xformers
- archai[dev] @ git+https://github.com/microsoft/archai.git

Просмотреть файл

@ -131,4 +131,4 @@ def main():
if __name__ == '__main__':
main()
main()