move apex section under trainer

This commit is contained in:
Shital Shah 2020-04-23 10:17:54 -07:00
Родитель f13b652e35
Коммит 0235072c8e
5 изменённых файлов: 18 добавлений и 17 удалений

Просмотреть файл

@ -10,12 +10,11 @@ from .metrics import Metrics
from .config import Config
from . import utils, ml_utils
from .common import logger, get_device
from archai.common.common import get_apex_utils
class Tester(EnforceOverrides):
"""Evaluate model on given data
"""
def __init__(self, conf_eval:Config, model:nn.Module)->None:
# TODO: currently we expect that given model and dataloader will already be distributed
self._title = conf_eval['title']
self._logger_freq = conf_eval['logger_freq']
conf_lossfn = conf_eval['lossfn']
@ -57,6 +56,9 @@ class Tester(EnforceOverrides):
loss = self._lossfn(logits, y)
self._post_step(x, y, logits, loss, steps, self._metrics)
# TODO: we possibly need to sync so all replicas are up to date
get_apex_utils().sync_devices()
logger.popd()
self._metrics.post_epoch(None)

Просмотреть файл

@ -29,9 +29,12 @@ class Trainer(EnforceOverrides):
self._conf_optim = conf_train['optimizer']
self._conf_sched = conf_train['lr_schedule']
conf_validation = conf_train['validation']
conf_apex = conf_train['apex']
self._validation_freq = 0 if conf_validation is None else conf_validation['freq']
# endregion
get_apex_utils().reset(logger, conf_apex)
self._checkpoint = checkpoint
self.model = model

Просмотреть файл

@ -7,7 +7,7 @@ from torch import nn
from archai.common.trainer import Trainer
from archai.common.config import Config
from archai.common.common import logger, get_apex_utils
from archai.common.common import logger
from archai.datasets import data
from archai.nas.model_desc import ModelDesc
from archai.nas.cell_builder import CellBuilder
@ -24,11 +24,8 @@ def eval_arch(conf_eval:Config, cell_builder:Optional[CellBuilder]):
conf_checkpoint = conf_eval['checkpoint']
resume = conf_eval['resume']
conf_train = conf_eval['trainer']
conf_apex = conf_eval['apex']
# endregion
get_apex_utils().reset(logger, conf_apex)
if cell_builder:
cell_builder.register_ops()

Просмотреть файл

@ -8,7 +8,7 @@ import tensorwatch as tw
from torch.utils.data.dataloader import DataLoader
import yaml
from archai.common.common import logger, get_apex_utils
from archai.common.common import logger
from archai.common.checkpoint import CheckPoint
from archai.common.config import Config
from .cell_builder import CellBuilder
@ -108,7 +108,6 @@ class Search:
self.search_iters = conf_search['search_iters']
self.pareto_enabled = conf_pareto['enabled']
pareto_summary_filename = conf_pareto['summary_filename']
conf_apex = conf_search['apex']
# endregion
self.cell_builder = cell_builder
@ -117,8 +116,6 @@ class Search:
self._parito_filepath = utils.full_path(pareto_summary_filename)
self._checkpoint = nas_utils.create_checkpoint(conf_checkpoint, resume)
get_apex_utils().reset(logger, conf_apex)
logger.info({'pareto_enabled': self.pareto_enabled,
'base_reductions': self.base_reductions,
'base_cells': self.base_cells,

Просмотреть файл

@ -50,8 +50,6 @@ nas:
metric_filename: '$expdir/eval_train_metrics.yaml'
model_filename: '$expdir/model.pt' # file to which trained model will be saved
data_parallel: False
apex:
_copy: 'common/apex'
checkpoint:
_copy: 'common/checkpoint'
resume: '_copy: common/resume'
@ -88,7 +86,8 @@ nas:
dataset:
_copy: '/dataset'
trainer:
apex: False
apex:
_copy: 'common/apex'
aux_weight: '_copy: nas/eval/model_desc/aux_weight'
drop_path_prob: 0.2 # probability that given edge will be dropped
grad_clip: 5.0 # grads above this value are clipped
@ -122,8 +121,6 @@ nas:
data_parallel: False
checkpoint:
_copy: 'common/checkpoint'
apex:
_copy: 'common/apex'
resume: '_copy: common/resume'
search_iters: 1
full_desc_filename: '$expdir/full_model_desc.yaml' # arch before it was finalized
@ -132,6 +129,8 @@ nas:
seed_train:
trainer:
_copy: 'nas/eval/trainer'
apex:
enabled: False
title: 'seed_train'
epochs: 0 # number of epochs model will be trained before search
aux_weight: 0.0
@ -143,6 +142,8 @@ nas:
post_train:
trainer:
_copy: 'nas/eval/trainer'
apex:
enabled: False
title: 'post_train'
epochs: 0 # number of epochs model will be trained after search
aux_weight: 0.0
@ -193,7 +194,8 @@ nas:
dataset:
_copy: '/dataset'
trainer:
apex: False
apex:
_copy: 'common/apex'
aux_weight: '_copy: nas/search/model_desc/aux_weight'
drop_path_prob: 0.0 # probability that given edge will be dropped
grad_clip: 5.0 # grads above this value are clipped