From 0235072c8eab147893479c4e6183a66707fad848 Mon Sep 17 00:00:00 2001 From: Shital Shah Date: Thu, 23 Apr 2020 10:17:54 -0700 Subject: [PATCH] move apex section under trainer --- archai/common/tester.py | 8 +++++--- archai/common/trainer.py | 3 +++ archai/nas/evaluate.py | 5 +---- archai/nas/search.py | 5 +---- confs/algos/darts.yaml | 14 ++++++++------ 5 files changed, 18 insertions(+), 17 deletions(-) diff --git a/archai/common/tester.py b/archai/common/tester.py index 58910ed0..afd7b3be 100644 --- a/archai/common/tester.py +++ b/archai/common/tester.py @@ -10,12 +10,11 @@ from .metrics import Metrics from .config import Config from . import utils, ml_utils from .common import logger, get_device +from archai.common.common import get_apex_utils class Tester(EnforceOverrides): - """Evaluate model on given data - """ - def __init__(self, conf_eval:Config, model:nn.Module)->None: + # TODO: currently we expect that given model and dataloader will already be distributed self._title = conf_eval['title'] self._logger_freq = conf_eval['logger_freq'] conf_lossfn = conf_eval['lossfn'] @@ -57,6 +56,9 @@ class Tester(EnforceOverrides): loss = self._lossfn(logits, y) self._post_step(x, y, logits, loss, steps, self._metrics) + # TODO: we possibly need to sync so all replicas are upto date + get_apex_utils().sync_devices() + logger.popd() self._metrics.post_epoch(None) diff --git a/archai/common/trainer.py b/archai/common/trainer.py index 87b22719..10a52d7f 100644 --- a/archai/common/trainer.py +++ b/archai/common/trainer.py @@ -29,9 +29,12 @@ class Trainer(EnforceOverrides): self._conf_optim = conf_train['optimizer'] self._conf_sched = conf_train['lr_schedule'] conf_validation = conf_train['validation'] + conf_apex = conf_train['apex'] self._validation_freq = 0 if conf_validation is None else conf_validation['freq'] # endregion + get_apex_utils().reset(logger, conf_apex) + self._checkpoint = checkpoint self.model = model diff --git a/archai/nas/evaluate.py b/archai/nas/evaluate.py index 485222c7..c3daf35c 100644 --- a/archai/nas/evaluate.py +++ b/archai/nas/evaluate.py @@ -7,7 +7,7 @@ from torch import nn from archai.common.trainer import Trainer from archai.common.config import Config -from archai.common.common import logger, get_apex_utils +from archai.common.common import logger from archai.datasets import data from archai.nas.model_desc import ModelDesc from archai.nas.cell_builder import CellBuilder @@ -24,11 +24,8 @@ def eval_arch(conf_eval:Config, cell_builder:Optional[CellBuilder]): conf_checkpoint = conf_eval['checkpoint'] resume = conf_eval['resume'] conf_train = conf_eval['trainer'] - conf_apex = conf_eval['apex'] # endregion - get_apex_utils().reset(logger, conf_apex) - if cell_builder: cell_builder.register_ops() diff --git a/archai/nas/search.py b/archai/nas/search.py index 2e96a96b..7a7b6fe0 100644 --- a/archai/nas/search.py +++ b/archai/nas/search.py @@ -8,7 +8,7 @@ import tensorwatch as tw from torch.utils.data.dataloader import DataLoader import yaml -from archai.common.common import logger, get_apex_utils +from archai.common.common import logger from archai.common.checkpoint import CheckPoint from archai.common.config import Config from .cell_builder import CellBuilder @@ -108,7 +108,6 @@ class Search: self.search_iters = conf_search['search_iters'] self.pareto_enabled = conf_pareto['enabled'] pareto_summary_filename = conf_pareto['summary_filename'] - conf_apex = conf_search['apex'] # endregion self.cell_builder = cell_builder @@ -117,8 +116,6 @@ class Search: self._parito_filepath = utils.full_path(pareto_summary_filename) self._checkpoint = nas_utils.create_checkpoint(conf_checkpoint, resume) - get_apex_utils().reset(logger, conf_apex) - logger.info({'pareto_enabled': self.pareto_enabled, 'base_reductions': self.base_reductions, 'base_cells': self.base_cells, diff --git a/confs/algos/darts.yaml b/confs/algos/darts.yaml index 9fec7793..79e8491c 100644 --- a/confs/algos/darts.yaml +++ b/confs/algos/darts.yaml @@ -50,8 +50,6 @@ nas: metric_filename: '$expdir/eval_train_metrics.yaml' model_filename: '$expdir/model.pt' # file to which trained model will be saved data_parallel: False - apex: - _copy: 'common/apex' checkpoint: _copy: 'common/checkpoint' resume: '_copy: common/resume' @@ -88,7 +86,8 @@ nas: dataset: _copy: '/dataset' trainer: - apex: False + apex: + _copy: 'common/apex' aux_weight: '_copy: nas/eval/model_desc/aux_weight' drop_path_prob: 0.2 # probability that given edge will be dropped grad_clip: 5.0 # grads above this value is clipped @@ -122,8 +121,6 @@ nas: data_parallel: False checkpoint: _copy: 'common/checkpoint' - apex: - _copy: 'common/apex' resume: '_copy: common/resume' search_iters: 1 full_desc_filename: '$expdir/full_model_desc.yaml' # arch before it was finalized @@ -132,6 +129,8 @@ nas: seed_train: trainer: _copy: 'nas/eval/trainer' + apex: + enabled: False title: 'seed_train' epochs: 0 # number of epochs model will be trained before search aux_weight: 0.0 @@ -143,6 +142,8 @@ nas: post_train: trainer: _copy: 'nas/eval/trainer' + apex: + enabled: False title: 'post_train' epochs: 0 # number of epochs model will be trained after search aux_weight: 0.0 @@ -193,7 +194,8 @@ nas: dataset: _copy: '/dataset' trainer: - apex: False + apex: + _copy: 'common/apex' aux_weight: '_copy: nas/search/model_desc/aux_weight' drop_path_prob: 0.0 # probability that given edge will be dropped grad_clip: 5.0 # grads above this value is clipped