зеркало из https://github.com/microsoft/msrflute.git
Merged PR 1082: Save the best model
This commit is contained in:
Родитель
cfd9f57049
Коммит
f582e43893
|
@ -49,6 +49,7 @@ class Evaluation():
|
||||||
'''
|
'''
|
||||||
|
|
||||||
self.worker_trainer = req['worker_trainer']
|
self.worker_trainer = req['worker_trainer']
|
||||||
|
save_model = False
|
||||||
if metric_logger is None:
|
if metric_logger is None:
|
||||||
metric_logger = run.log
|
metric_logger = run.log
|
||||||
|
|
||||||
|
@ -72,16 +73,19 @@ class Evaluation():
|
||||||
if value['higher_is_better']:
|
if value['higher_is_better']:
|
||||||
if self.metrics[key]['value'] > req[attr]:
|
if self.metrics[key]['value'] > req[attr]:
|
||||||
req[attr] = self.metrics[key]['value']
|
req[attr] = self.metrics[key]['value']
|
||||||
|
save_model = True
|
||||||
else:
|
else:
|
||||||
if self.metrics[key]['value'] < req[attr]:
|
if self.metrics[key]['value'] < req[attr]:
|
||||||
req[attr] = self.metrics[key]['value']
|
req[attr] = self.metrics[key]['value']
|
||||||
|
save_model = True
|
||||||
|
|
||||||
if mode == 'val':
|
if save_model and mode == 'val':
|
||||||
self.worker_trainer.save(
|
self.worker_trainer.save(
|
||||||
model_path=self.model_path,
|
model_path=self.model_path,
|
||||||
token=str('best_'+ mode +'_'+key),
|
token=str('best_'+ mode +'_'+key),
|
||||||
config=self.config['server_config']
|
config=self.config['server_config']
|
||||||
)
|
)
|
||||||
|
save_model = False
|
||||||
|
|
||||||
return req
|
return req
|
||||||
|
|
||||||
|
|
Загрузка…
Ссылка в новой задаче