зеркало из https://github.com/microsoft/archai.git
move apex section under trainer
This commit is contained in:
Родитель
f13b652e35
Коммит
0235072c8e
|
@ -10,12 +10,11 @@ from .metrics import Metrics
|
||||||
from .config import Config
|
from .config import Config
|
||||||
from . import utils, ml_utils
|
from . import utils, ml_utils
|
||||||
from .common import logger, get_device
|
from .common import logger, get_device
|
||||||
|
from archai.common.common import get_apex_utils
|
||||||
|
|
||||||
class Tester(EnforceOverrides):
|
class Tester(EnforceOverrides):
|
||||||
"""Evaluate model on given data
|
|
||||||
"""
|
|
||||||
|
|
||||||
def __init__(self, conf_eval:Config, model:nn.Module)->None:
|
def __init__(self, conf_eval:Config, model:nn.Module)->None:
|
||||||
|
# TODO: currently we expect that given model and dataloader will already be distributed
|
||||||
self._title = conf_eval['title']
|
self._title = conf_eval['title']
|
||||||
self._logger_freq = conf_eval['logger_freq']
|
self._logger_freq = conf_eval['logger_freq']
|
||||||
conf_lossfn = conf_eval['lossfn']
|
conf_lossfn = conf_eval['lossfn']
|
||||||
|
@ -57,6 +56,9 @@ class Tester(EnforceOverrides):
|
||||||
loss = self._lossfn(logits, y)
|
loss = self._lossfn(logits, y)
|
||||||
self._post_step(x, y, logits, loss, steps, self._metrics)
|
self._post_step(x, y, logits, loss, steps, self._metrics)
|
||||||
|
|
||||||
|
# TODO: we possibly need to sync so all replicas are upto date
|
||||||
|
get_apex_utils().sync_devices()
|
||||||
|
|
||||||
logger.popd()
|
logger.popd()
|
||||||
self._metrics.post_epoch(None)
|
self._metrics.post_epoch(None)
|
||||||
|
|
||||||
|
|
|
@ -29,9 +29,12 @@ class Trainer(EnforceOverrides):
|
||||||
self._conf_optim = conf_train['optimizer']
|
self._conf_optim = conf_train['optimizer']
|
||||||
self._conf_sched = conf_train['lr_schedule']
|
self._conf_sched = conf_train['lr_schedule']
|
||||||
conf_validation = conf_train['validation']
|
conf_validation = conf_train['validation']
|
||||||
|
conf_apex = conf_train['apex']
|
||||||
self._validation_freq = 0 if conf_validation is None else conf_validation['freq']
|
self._validation_freq = 0 if conf_validation is None else conf_validation['freq']
|
||||||
# endregion
|
# endregion
|
||||||
|
|
||||||
|
get_apex_utils().reset(logger, conf_apex)
|
||||||
|
|
||||||
self._checkpoint = checkpoint
|
self._checkpoint = checkpoint
|
||||||
self.model = model
|
self.model = model
|
||||||
|
|
||||||
|
|
|
@ -7,7 +7,7 @@ from torch import nn
|
||||||
|
|
||||||
from archai.common.trainer import Trainer
|
from archai.common.trainer import Trainer
|
||||||
from archai.common.config import Config
|
from archai.common.config import Config
|
||||||
from archai.common.common import logger, get_apex_utils
|
from archai.common.common import logger
|
||||||
from archai.datasets import data
|
from archai.datasets import data
|
||||||
from archai.nas.model_desc import ModelDesc
|
from archai.nas.model_desc import ModelDesc
|
||||||
from archai.nas.cell_builder import CellBuilder
|
from archai.nas.cell_builder import CellBuilder
|
||||||
|
@ -24,11 +24,8 @@ def eval_arch(conf_eval:Config, cell_builder:Optional[CellBuilder]):
|
||||||
conf_checkpoint = conf_eval['checkpoint']
|
conf_checkpoint = conf_eval['checkpoint']
|
||||||
resume = conf_eval['resume']
|
resume = conf_eval['resume']
|
||||||
conf_train = conf_eval['trainer']
|
conf_train = conf_eval['trainer']
|
||||||
conf_apex = conf_eval['apex']
|
|
||||||
# endregion
|
# endregion
|
||||||
|
|
||||||
get_apex_utils().reset(logger, conf_apex)
|
|
||||||
|
|
||||||
if cell_builder:
|
if cell_builder:
|
||||||
cell_builder.register_ops()
|
cell_builder.register_ops()
|
||||||
|
|
||||||
|
|
|
@ -8,7 +8,7 @@ import tensorwatch as tw
|
||||||
from torch.utils.data.dataloader import DataLoader
|
from torch.utils.data.dataloader import DataLoader
|
||||||
import yaml
|
import yaml
|
||||||
|
|
||||||
from archai.common.common import logger, get_apex_utils
|
from archai.common.common import logger
|
||||||
from archai.common.checkpoint import CheckPoint
|
from archai.common.checkpoint import CheckPoint
|
||||||
from archai.common.config import Config
|
from archai.common.config import Config
|
||||||
from .cell_builder import CellBuilder
|
from .cell_builder import CellBuilder
|
||||||
|
@ -108,7 +108,6 @@ class Search:
|
||||||
self.search_iters = conf_search['search_iters']
|
self.search_iters = conf_search['search_iters']
|
||||||
self.pareto_enabled = conf_pareto['enabled']
|
self.pareto_enabled = conf_pareto['enabled']
|
||||||
pareto_summary_filename = conf_pareto['summary_filename']
|
pareto_summary_filename = conf_pareto['summary_filename']
|
||||||
conf_apex = conf_search['apex']
|
|
||||||
# endregion
|
# endregion
|
||||||
|
|
||||||
self.cell_builder = cell_builder
|
self.cell_builder = cell_builder
|
||||||
|
@ -117,8 +116,6 @@ class Search:
|
||||||
self._parito_filepath = utils.full_path(pareto_summary_filename)
|
self._parito_filepath = utils.full_path(pareto_summary_filename)
|
||||||
self._checkpoint = nas_utils.create_checkpoint(conf_checkpoint, resume)
|
self._checkpoint = nas_utils.create_checkpoint(conf_checkpoint, resume)
|
||||||
|
|
||||||
get_apex_utils().reset(logger, conf_apex)
|
|
||||||
|
|
||||||
logger.info({'pareto_enabled': self.pareto_enabled,
|
logger.info({'pareto_enabled': self.pareto_enabled,
|
||||||
'base_reductions': self.base_reductions,
|
'base_reductions': self.base_reductions,
|
||||||
'base_cells': self.base_cells,
|
'base_cells': self.base_cells,
|
||||||
|
|
|
@ -50,8 +50,6 @@ nas:
|
||||||
metric_filename: '$expdir/eval_train_metrics.yaml'
|
metric_filename: '$expdir/eval_train_metrics.yaml'
|
||||||
model_filename: '$expdir/model.pt' # file to which trained model will be saved
|
model_filename: '$expdir/model.pt' # file to which trained model will be saved
|
||||||
data_parallel: False
|
data_parallel: False
|
||||||
apex:
|
|
||||||
_copy: 'common/apex'
|
|
||||||
checkpoint:
|
checkpoint:
|
||||||
_copy: 'common/checkpoint'
|
_copy: 'common/checkpoint'
|
||||||
resume: '_copy: common/resume'
|
resume: '_copy: common/resume'
|
||||||
|
@ -88,7 +86,8 @@ nas:
|
||||||
dataset:
|
dataset:
|
||||||
_copy: '/dataset'
|
_copy: '/dataset'
|
||||||
trainer:
|
trainer:
|
||||||
apex: False
|
apex:
|
||||||
|
_copy: 'common/apex'
|
||||||
aux_weight: '_copy: nas/eval/model_desc/aux_weight'
|
aux_weight: '_copy: nas/eval/model_desc/aux_weight'
|
||||||
drop_path_prob: 0.2 # probability that given edge will be dropped
|
drop_path_prob: 0.2 # probability that given edge will be dropped
|
||||||
grad_clip: 5.0 # grads above this value is clipped
|
grad_clip: 5.0 # grads above this value is clipped
|
||||||
|
@ -122,8 +121,6 @@ nas:
|
||||||
data_parallel: False
|
data_parallel: False
|
||||||
checkpoint:
|
checkpoint:
|
||||||
_copy: 'common/checkpoint'
|
_copy: 'common/checkpoint'
|
||||||
apex:
|
|
||||||
_copy: 'common/apex'
|
|
||||||
resume: '_copy: common/resume'
|
resume: '_copy: common/resume'
|
||||||
search_iters: 1
|
search_iters: 1
|
||||||
full_desc_filename: '$expdir/full_model_desc.yaml' # arch before it was finalized
|
full_desc_filename: '$expdir/full_model_desc.yaml' # arch before it was finalized
|
||||||
|
@ -132,6 +129,8 @@ nas:
|
||||||
seed_train:
|
seed_train:
|
||||||
trainer:
|
trainer:
|
||||||
_copy: 'nas/eval/trainer'
|
_copy: 'nas/eval/trainer'
|
||||||
|
apex:
|
||||||
|
enabled: False
|
||||||
title: 'seed_train'
|
title: 'seed_train'
|
||||||
epochs: 0 # number of epochs model will be trained before search
|
epochs: 0 # number of epochs model will be trained before search
|
||||||
aux_weight: 0.0
|
aux_weight: 0.0
|
||||||
|
@ -143,6 +142,8 @@ nas:
|
||||||
post_train:
|
post_train:
|
||||||
trainer:
|
trainer:
|
||||||
_copy: 'nas/eval/trainer'
|
_copy: 'nas/eval/trainer'
|
||||||
|
apex:
|
||||||
|
enabled: False
|
||||||
title: 'post_train'
|
title: 'post_train'
|
||||||
epochs: 0 # number of epochs model will be trained after search
|
epochs: 0 # number of epochs model will be trained after search
|
||||||
aux_weight: 0.0
|
aux_weight: 0.0
|
||||||
|
@ -193,7 +194,8 @@ nas:
|
||||||
dataset:
|
dataset:
|
||||||
_copy: '/dataset'
|
_copy: '/dataset'
|
||||||
trainer:
|
trainer:
|
||||||
apex: False
|
apex:
|
||||||
|
_copy: 'common/apex'
|
||||||
aux_weight: '_copy: nas/search/model_desc/aux_weight'
|
aux_weight: '_copy: nas/search/model_desc/aux_weight'
|
||||||
drop_path_prob: 0.0 # probability that given edge will be dropped
|
drop_path_prob: 0.0 # probability that given edge will be dropped
|
||||||
grad_clip: 5.0 # grads above this value is clipped
|
grad_clip: 5.0 # grads above this value is clipped
|
||||||
|
|
Загрузка…
Ссылка в новой задаче