зеркало из https://github.com/microsoft/msrflute.git
Merged PR 1082: Save the best model
This commit is contained in:
Родитель
cfd9f57049
Коммит
f582e43893
|
@ -49,6 +49,7 @@ class Evaluation():
|
|||
'''
|
||||
|
||||
self.worker_trainer = req['worker_trainer']
|
||||
save_model = False
|
||||
if metric_logger is None:
|
||||
metric_logger = run.log
|
||||
|
||||
|
@ -72,16 +73,19 @@ class Evaluation():
|
|||
if value['higher_is_better']:
|
||||
if self.metrics[key]['value'] > req[attr]:
|
||||
req[attr] = self.metrics[key]['value']
|
||||
save_model = True
|
||||
else:
|
||||
if self.metrics[key]['value'] < req[attr]:
|
||||
req[attr] = self.metrics[key]['value']
|
||||
|
||||
if mode == 'val':
|
||||
save_model = True
|
||||
|
||||
if save_model and mode == 'val':
|
||||
self.worker_trainer.save(
|
||||
model_path=self.model_path,
|
||||
token=str('best_'+ mode +'_'+key),
|
||||
config=self.config['server_config']
|
||||
)
|
||||
save_model = False
|
||||
|
||||
return req
|
||||
|
||||
|
|
Загрузка…
Ссылка в новой задаче