Mirror of https://github.com/microsoft/archai.git

move apex section under trainer

Parent: f13b652e35
Commit: 0235072c8e
@@ -10,12 +10,11 @@ from .metrics import Metrics
 from .config import Config
 from . import utils, ml_utils
 from .common import logger, get_device
+from archai.common.common import get_apex_utils

 class Tester(EnforceOverrides):
     """Evaluate model on given data
     """

     def __init__(self, conf_eval:Config, model:nn.Module)->None:
         # TODO: currently we expect that given model and dataloader will already be distributed
         self._title = conf_eval['title']
         self._logger_freq = conf_eval['logger_freq']
         conf_lossfn = conf_eval['lossfn']
@@ -57,6 +56,9 @@ class Tester(EnforceOverrides):
                 loss = self._lossfn(logits, y)
                 self._post_step(x, y, logits, loss, steps, self._metrics)

+        # TODO: we possibly need to sync so all replicas are upto date
+        get_apex_utils().sync_devices()
+
         logger.popd()
         self._metrics.post_epoch(None)
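The added sync_devices() call runs once per evaluation pass, after the batch loop and before the epoch metrics are finalized. A minimal sketch of that pattern, reusing the get_apex_utils import from the hunk above (the helper function, loader and loss-function names here are placeholders, not the actual Tester code):

import torch
from torch import nn
from archai.common.common import get_apex_utils

def evaluate(model: nn.Module, data_loader, lossfn: nn.Module) -> float:
    # placeholder evaluation loop mirroring the Tester flow
    model.eval()
    total_loss, n = 0.0, 0
    with torch.no_grad():
        for x, y in data_loader:
            logits = model(x)
            loss = lossfn(logits, y)
            total_loss += loss.item() * y.size(0)
            n += y.size(0)

    # same idea as the added call above: make sure every replica has finished
    # the epoch before results are aggregated or reported
    get_apex_utils().sync_devices()

    return total_loss / max(n, 1)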
@@ -29,9 +29,12 @@ class Trainer(EnforceOverrides):
         self._conf_optim = conf_train['optimizer']
         self._conf_sched = conf_train['lr_schedule']
         conf_validation = conf_train['validation']
+        conf_apex = conf_train['apex']
         self._validation_freq = 0 if conf_validation is None else conf_validation['freq']
         # endregion

+        get_apex_utils().reset(logger, conf_apex)
+
         self._checkpoint = checkpoint
         self.model = model
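With this hunk the trainer owns apex setup: it reads the apex section from its own config and calls get_apex_utils().reset() before anything else uses the model. A condensed sketch of that constructor shape (a stand-in class, not archai's Trainer; only the lines that appear in the hunk are taken from the source, the signature is an assumption):

from torch import nn
from archai.common.config import Config
from archai.common.common import logger, get_apex_utils

class TrainerSketch:  # stand-in, not archai.common.trainer.Trainer
    def __init__(self, conf_train: Config, model: nn.Module, checkpoint=None) -> None:
        # region config vars (as in the hunk)
        self._conf_optim = conf_train['optimizer']
        self._conf_sched = conf_train['lr_schedule']
        conf_validation = conf_train['validation']
        conf_apex = conf_train['apex']  # apex section now lives under the trainer config
        self._validation_freq = 0 if conf_validation is None else conf_validation['freq']
        # endregion

        # apex (mixed precision / distributed) is initialized here rather than
        # by eval_arch() or Search, as in the hunks further down
        get_apex_utils().reset(logger, conf_apex)

        self._checkpoint = checkpoint
        self.model = model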
@@ -7,7 +7,7 @@ from torch import nn

 from archai.common.trainer import Trainer
 from archai.common.config import Config
-from archai.common.common import logger, get_apex_utils
+from archai.common.common import logger
 from archai.datasets import data
 from archai.nas.model_desc import ModelDesc
 from archai.nas.cell_builder import CellBuilder
@@ -24,11 +24,8 @@ def eval_arch(conf_eval:Config, cell_builder:Optional[CellBuilder]):
     conf_checkpoint = conf_eval['checkpoint']
     resume = conf_eval['resume']
     conf_train = conf_eval['trainer']
-    conf_apex = conf_eval['apex']
     # endregion
-
-    get_apex_utils().reset(logger, conf_apex)

     if cell_builder:
         cell_builder.register_ops()
@@ -8,7 +8,7 @@ import tensorwatch as tw
 from torch.utils.data.dataloader import DataLoader
 import yaml

-from archai.common.common import logger, get_apex_utils
+from archai.common.common import logger
 from archai.common.checkpoint import CheckPoint
 from archai.common.config import Config
 from .cell_builder import CellBuilder
@@ -108,7 +108,6 @@ class Search:
         self.search_iters = conf_search['search_iters']
         self.pareto_enabled = conf_pareto['enabled']
         pareto_summary_filename = conf_pareto['summary_filename']
-        conf_apex = conf_search['apex']
         # endregion

         self.cell_builder = cell_builder
@@ -117,8 +116,6 @@
         self._parito_filepath = utils.full_path(pareto_summary_filename)
         self._checkpoint = nas_utils.create_checkpoint(conf_checkpoint, resume)

-        get_apex_utils().reset(logger, conf_apex)
-
         logger.info({'pareto_enabled': self.pareto_enabled,
                      'base_reductions': self.base_reductions,
                      'base_cells': self.base_cells,
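The two calls used throughout this commit, reset(logger, conf_apex) and sync_devices(), come from the object returned by get_apex_utils(); its implementation is not part of this diff. A rough stub of that call contract, assuming only the 'enabled' key that appears in the config hunks below (everything else here is illustrative, not archai's actual apex helper):

import torch

class ApexUtilsStub:
    """Illustrative stand-in for the object returned by get_apex_utils()."""

    def reset(self, logger, conf_apex) -> None:
        # pick up the apex section that now lives under trainer.apex
        self._enabled = bool(conf_apex['enabled']) if conf_apex is not None else False
        if logger is not None:
            logger.info({'apex_enabled': self._enabled})

    def sync_devices(self) -> None:
        # barrier so all replicas reach the same point before metrics are reduced
        if torch.distributed.is_available() and torch.distributed.is_initialized():
            torch.distributed.barrier()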
@@ -50,8 +50,6 @@ nas:
     metric_filename: '$expdir/eval_train_metrics.yaml'
     model_filename: '$expdir/model.pt' # file to which trained model will be saved
     data_parallel: False
-    apex:
-      _copy: 'common/apex'
     checkpoint:
       _copy: 'common/checkpoint'
     resume: '_copy: common/resume'
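The '_copy: ...' values in these YAML sections are archai Config references: the node is filled in from the section at the given path, which is how trainer.apex picks up the shared common/apex settings. A toy resolver showing the idea for one absolute path (not archai's actual implementation; the keys under common/apex other than 'enabled' are not shown in this diff):

import yaml

doc = yaml.safe_load("""
common:
  apex:
    enabled: False
nas:
  eval:
    trainer:
      apex:
        _copy: 'common/apex'
""")

def resolve(root, path):
    # walk a 'common/apex'-style path from the document root
    node = root
    for part in path.strip('/').split('/'):
        node = node[part]
    return node

trainer = doc['nas']['eval']['trainer']
if '_copy' in trainer['apex']:
    trainer['apex'] = dict(resolve(doc, trainer['apex']['_copy']))

print(trainer['apex'])  # {'enabled': False}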
@@ -88,7 +86,8 @@ nas:
     dataset:
       _copy: '/dataset'
     trainer:
-      apex: False
+      apex:
+        _copy: 'common/apex'
       aux_weight: '_copy: nas/eval/model_desc/aux_weight'
       drop_path_prob: 0.2 # probability that given edge will be dropped
       grad_clip: 5.0 # grads above this value is clipped
@@ -122,8 +121,6 @@ nas:
     data_parallel: False
     checkpoint:
       _copy: 'common/checkpoint'
-    apex:
-      _copy: 'common/apex'
     resume: '_copy: common/resume'
     search_iters: 1
     full_desc_filename: '$expdir/full_model_desc.yaml' # arch before it was finalized
@@ -132,6 +129,8 @@ nas:
     seed_train:
       trainer:
         _copy: 'nas/eval/trainer'
+        apex:
+          enabled: False
         title: 'seed_train'
         epochs: 0 # number of epochs model will be trained before search
         aux_weight: 0.0
@@ -143,6 +142,8 @@ nas:
     post_train:
       trainer:
         _copy: 'nas/eval/trainer'
+        apex:
+          enabled: False
         title: 'post_train'
         epochs: 0 # number of epochs model will be trained after search
         aux_weight: 0.0
@@ -193,7 +194,8 @@ nas:
     dataset:
       _copy: '/dataset'
     trainer:
-      apex: False
+      apex:
+        _copy: 'common/apex'
       aux_weight: '_copy: nas/search/model_desc/aux_weight'
       drop_path_prob: 0.0 # probability that given edge will be dropped
       grad_clip: 5.0 # grads above this value is clipped
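Taken together, the YAML hunks leave the apex section under each trainer instead of at the eval/search level, with seed_train and post_train overriding the inherited trainer settings to turn apex off. A condensed view of the resulting nesting, loadable as a quick sanity check (only keys visible in this diff are included; nesting outside the shown hunks is approximate):

import yaml

layout = yaml.safe_load("""
nas:
  eval:
    trainer:
      apex:
        _copy: 'common/apex'
  search:
    trainer:
      apex:
        _copy: 'common/apex'
    seed_train:
      trainer:
        _copy: 'nas/eval/trainer'
        apex:
          enabled: False
    post_train:
      trainer:
        _copy: 'nas/eval/trainer'
        apex:
          enabled: False
""")

# apex is now a per-trainer section rather than a sibling of trainer
assert 'apex' in layout['nas']['eval']['trainer']
assert 'apex' in layout['nas']['search']['trainer']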