From d861b5f2a6f522fabe7d6350f04e77b428ef371b Mon Sep 17 00:00:00 2001 From: Shital Shah Date: Thu, 23 Apr 2020 13:06:09 -0700 Subject: [PATCH] Move apex to trainer --- archai/algos/darts/bilevel_arch_trainer.py | 9 +- archai/algos/darts/bilevel_optimizer.py | 6 +- archai/algos/xnas/xnas_arch_trainer.py | 6 +- archai/common/apex_utils.py | 129 ++++++++++----------- archai/common/common.py | 46 ++------ archai/common/metrics.py | 88 +++++++++----- archai/common/tester.py | 16 +-- archai/common/trainer.py | 30 ++--- archai/common/utils.py | 6 +- confs/algos/darts.yaml | 5 +- scripts/misc/yaml_playground.py | 27 ++++- scripts/perf/resnet_test.py | 4 +- 12 files changed, 198 insertions(+), 174 deletions(-) diff --git a/archai/algos/darts/bilevel_arch_trainer.py b/archai/algos/darts/bilevel_arch_trainer.py index f81d1376..3ac7f902 100644 --- a/archai/algos/darts/bilevel_arch_trainer.py +++ b/archai/algos/darts/bilevel_arch_trainer.py @@ -15,7 +15,7 @@ from archai.nas.arch_trainer import ArchTrainer from archai.common import utils, ml_utils from archai.nas.model import Model from archai.common.checkpoint import CheckPoint -from archai.common.common import logger, get_device +from archai.common.common import logger from .bilevel_optimizer import BilevelOptimizer class BilevelArchTrainer(ArchTrainer): @@ -36,10 +36,11 @@ class BilevelArchTrainer(ArchTrainer): assert val_dl is not None w_momentum = self._conf_w_optim['momentum'] w_decay = self._conf_w_optim['decay'] - lossfn = ml_utils.get_lossfn(self._conf_w_lossfn).to(get_device()) + lossfn = ml_utils.get_lossfn(self._conf_w_lossfn).to(self.get_device()) self._bilevel_optim = BilevelOptimizer(self._conf_alpha_optim, w_momentum, - w_decay, self.model, lossfn) + w_decay, self.model, lossfn, + self.get_device()) @overrides def post_fit(self, train_dl:DataLoader, val_dl:Optional[DataLoader])->None: @@ -71,7 +72,7 @@ class BilevelArchTrainer(ArchTrainer): self._valid_iter = iter(self._val_dl) x_val, y_val = next(self._valid_iter) - x_val, y_val = x_val.to(get_device()), y_val.to(get_device(), non_blocking=True) + x_val, y_val = x_val.to(self.get_device()), y_val.to(self.get_device(), non_blocking=True) # update alphas self._bilevel_optim.step(x, y, x_val, y_val, super().get_optimizer()) diff --git a/archai/algos/darts/bilevel_optimizer.py b/archai/algos/darts/bilevel_optimizer.py index b9ede32e..fd0ea5a7 100644 --- a/archai/algos/darts/bilevel_optimizer.py +++ b/archai/algos/darts/bilevel_optimizer.py @@ -9,11 +9,11 @@ from torch.optim.optimizer import Optimizer from archai.common.config import Config from archai.common import utils, ml_utils from archai.nas.model import Model -from archai.common.common import logger, get_device +from archai.common.common import logger class BilevelOptimizer: def __init__(self, conf_alpha_optim:Config, w_momentum: float, w_decay: float, - model: Model, lossfn: _Loss) -> None: + model: Model, lossfn: _Loss, device) -> None: self._w_momentum = w_momentum # momentum for w self._w_weight_decay = w_decay # weight decay for w self._lossfn = lossfn @@ -22,7 +22,7 @@ class BilevelOptimizer: # create a copy of model which we will use # to compute grads for alphas without disturbing # original weights - self._vmodel = copy.deepcopy(model).to(get_device()) + self._vmodel = copy.deepcopy(model).to(device) # this is the optimizer to optimize alphas parameter self._alpha_optim = ml_utils.create_optimizer(conf_alpha_optim, model.alphas()) diff --git a/archai/algos/xnas/xnas_arch_trainer.py b/archai/algos/xnas/xnas_arch_trainer.py 
index 5d2df886..e29e54bd 100644 --- a/archai/algos/xnas/xnas_arch_trainer.py +++ b/archai/algos/xnas/xnas_arch_trainer.py @@ -15,7 +15,7 @@ from archai.nas.arch_trainer import ArchTrainer from archai.common import utils, ml_utils from archai.nas.model import Model from archai.common.checkpoint import CheckPoint -from archai.common.common import logger, get_device +from archai.common.common import logger class XnasArchTrainer(ArchTrainer): @@ -39,7 +39,7 @@ class XnasArchTrainer(ArchTrainer): # optimizers, schedulers needs to be recreated for each fit call # as they have state assert val_dl is not None - lossfn = ml_utils.get_lossfn(self._conf_w_lossfn).to(get_device()) + lossfn = ml_utils.get_lossfn(self._conf_w_lossfn).to(self.get_device()) self._xnas_optim = _XnasOptimizer(self._conf_alpha_optim, self.model, lossfn) @@ -74,7 +74,7 @@ class XnasArchTrainer(ArchTrainer): self._valid_iter = iter(self._val_dl) x_val, y_val = next(self._valid_iter) - x_val, y_val = x_val.to(get_device()), y_val.to(get_device(), non_blocking=True) + x_val, y_val = x_val.to(self.get_device()), y_val.to(self.get_device(), non_blocking=True) # update alphas self._xnas_optim.step(x, y, x_val, y_val) diff --git a/archai/common/apex_utils.py b/archai/common/apex_utils.py index 32cd9a23..382e7e5c 100644 --- a/archai/common/apex_utils.py +++ b/archai/common/apex_utils.py @@ -15,23 +15,7 @@ from archai.common import ml_utils, utils from archai.common.ordereddict_logger import OrderedDictLogger class ApexUtils: - def __init__(self)->None: - self._amp = self._ddp = None - self._op_map = {'mean': dist.ReduceOp.SUM, 'sum': dist.ReduceOp.SUM, - 'min': dist.ReduceOp.MIN, 'max': dist.ReduceOp.MAX} - self.gpu_ids = [] # use all gpus - - self._mixed_prec_enabled = False - self._distributed_enabled = False - - self._set_ranks() - - def reset(self, logger:OrderedDictLogger, apex_config:Config)->None: - # reset allows to configure differently for search or eval modes - - # to avoid circular references= with common, logger is passed from outside - self.logger = logger - + def __init__(self, apex_config:Config, logger:Optional[OrderedDictLogger])->None: # region conf vars self._enabled = apex_config['enabled'] # global switch to disable anything apex self._distributed_enabled = apex_config['distributed_enabled'] # enable/disable distributed mode @@ -47,42 +31,53 @@ class ApexUtils: conf_gpu_ids = apex_config['gpus'] # endregion - self.gpu_ids = [int(i) for i in conf_gpu_ids.split(',') if i] - self._amp, self._ddp = None, None - self._gpu = self.gpu_ids[0] if len(self.gpu_ids) else 0 # which GPU to use, we will use only 1 GPU + # to avoid circular references= with common, logger is passed from outside + self.logger = logger - #logger.info({'apex_config': apex_config.to_dict()}) - logger.info({'torch.distributed.is_available': dist.is_available()}) + self._op_map = {'mean': dist.ReduceOp.SUM, 'sum': dist.ReduceOp.SUM, + 'min': dist.ReduceOp.MIN, 'max': dist.ReduceOp.MAX} + + # defaults for non-distributed mode + self._amp, self._ddp = None, None + self.world_size = 1 + self.local_rank = self.global_rank = 0 + + self.gpu_ids = [int(i) for i in conf_gpu_ids.split(',') if i] + + # which GPU to use, we will use only 1 GPU per process to avoid complications with apex + self._gpu = self.gpu_ids[0] if len(self.gpu_ids) else 0 + + #_log_info({'apex_config': apex_config.to_dict()}) + self._log_info({'torch.distributed.is_available': dist.is_available()}) if dist.is_available(): - logger.info({'gloo_available': dist.is_gloo_available(), + 
self._log_info({'gloo_available': dist.is_gloo_available(), 'mpi_available': dist.is_mpi_available(), 'nccl_available': dist.is_nccl_available()}) - if self._enabled: - if self._mixed_prec_enabled: - # init enable mixed precision - assert cudnn.enabled, "Amp requires cudnn backend to be enabled." - from apex import amp - self._amp = amp + if self.is_mixed(): + # init enable mixed precision + assert cudnn.enabled, "Amp requires cudnn backend to be enabled." + from apex import amp + self._amp = amp - # enable distributed processing - if self._distributed_enabled: - from apex import parallel - self._ddp = parallel + # enable distributed processing + if self.is_dist(): + from apex import parallel + self._ddp = parallel - assert dist.is_available() # distributed module is available - assert dist.is_nccl_available() - if not dist.is_initialized(): - dist.init_process_group(backend='nccl', init_method='env://') - assert dist.is_initialized() + assert dist.is_available() # distributed module is available + assert dist.is_nccl_available() + if not dist.is_initialized(): + dist.init_process_group(backend='nccl', init_method='env://') + assert dist.is_initialized() - self._set_ranks() - assert dist.get_world_size() == self.world_size - assert dist.get_rank() == self.global_rank - else: - assert self.world_size == 1 - assert self.local_rank == 0 - assert self.global_rank == 0 + self._set_ranks() + assert dist.get_world_size() == self.world_size + assert dist.get_rank() == self.global_rank + else: + assert self.world_size == 1 + assert self.local_rank == 0 + assert self.global_rank == 0 assert self.world_size >= 1 assert not self._min_world_size or self.world_size >= self._min_world_size @@ -94,9 +89,9 @@ class ApexUtils: self.device = torch.device('cuda', self._gpu) self._setup_gpus(seed, detect_anomaly) - logger.info({'amp_available': self._amp is not None, + self._log_info({'amp_available': self._amp is not None, 'distributed_available': self._ddp is not None}) - logger.info({'dist_initialized': dist.is_initialized() if dist.is_available() else False, + self._log_info({'dist_initialized': dist.is_initialized() if dist.is_available() else False, 'world_size': self.world_size, 'gpu': self._gpu, 'gpu_ids':self.gpu_ids, 'local_rank': self.local_rank}) @@ -106,10 +101,10 @@ class ApexUtils: utils.setup_cuda(seed, self.local_rank) torch.autograd.set_detect_anomaly(detect_anomaly) - self.logger.info({'set_detect_anomaly': detect_anomaly, + self._log_info({'set_detect_anomaly': detect_anomaly, 'is_anomaly_enabled': torch.is_anomaly_enabled()}) - self.logger.info({'gpu_names': utils.cuda_device_names(), + self._log_info({'gpu_names': utils.cuda_device_names(), 'gpu_count': torch.cuda.device_count(), 'CUDA_VISIBLE_DEVICES': os.environ['CUDA_VISIBLE_DEVICES'] if 'CUDA_VISIBLE_DEVICES' in os.environ else 'NotSet', @@ -118,8 +113,8 @@ class ApexUtils: 'cudnn.deterministic': cudnn.deterministic, 'cudnn.version': cudnn.version() }) - self.logger.info({'memory': str(psutil.virtual_memory())}) - self.logger.info({'CPUs': str(psutil.cpu_count())}) + self._log_info({'memory': str(psutil.virtual_memory())}) + self._log_info({'CPUs': str(psutil.cpu_count())}) # gpu_usage = os.popen( # 'nvidia-smi --query-gpu=memory.total,memory.used --format=csv,nounits,noheader' @@ -127,7 +122,7 @@ class ApexUtils: # for i, line in enumerate(gpu_usage): # vals = line.split(',') # if len(vals) == 2: - # logger.info('GPU {} mem: {}, used: {}'.format(i, vals[0], vals[1])) + # _log_info('GPU {} mem: {}, used: {}'.format(i, vals[0], 
vals[1])) def _set_ranks(self)->None: if 'WORLD_SIZE' in os.environ: @@ -153,18 +148,22 @@ class ApexUtils: self._gpu = self.gpu_ids[self.local_rank] def is_mixed(self)->bool: - return self._mixed_prec_enabled + return self._enabled and self._mixed_prec_enabled def is_dist(self)->bool: - return self._distributed_enabled + return self._enabled and self._distributed_enabled def is_master(self)->bool: return self.global_rank == 0 + def _log_info(self, d:dict)->None: + if self.logger is not None: + self.logger.info(d) + def sync_devices(self)->None: - if self._distributed_enabled: + if self.is_dist(): torch.cuda.synchronize(self.device) def reduce(self, val, op='mean'): - if self._distributed_enabled: + if self.is_dist(): if not isinstance(val, Tensor): rt = torch.tensor(val).to(self.device) converted = True @@ -184,7 +183,7 @@ class ApexUtils: return val def backward(self, loss:torch.Tensor, optim:Optimizer)->None: - if self._mixed_prec_enabled: + if self.is_mixed(): with self._amp.scale_loss(loss, optim) as scaled_loss: scaled_loss.backward() else: @@ -193,26 +192,26 @@ class ApexUtils: def to_amp(self, model:nn.Module, optim:Optimizer, batch_size:int)\ ->Tuple[nn.Module, Optimizer]: # conver BNs to sync BNs in distributed mode - if self._distributed_enabled and self._sync_bn: + if self.is_dist() and self._sync_bn: model = self._ddp.convert_syncbn_model(model) - self.logger.info({'BNs_converted': True}) + self._log_info({'BNs_converted': True}) model = model.to(self.device) - if self._mixed_prec_enabled: + if self.is_mixed(): # scale LR if self._scale_lr: lr = ml_utils.get_optim_lr(optim) scaled_lr = lr * self.world_size / float(batch_size) ml_utils.set_optim_lr(optim, scaled_lr) - self.logger.info({'lr_scaled': True, 'old_lr': lr, 'new_lr': scaled_lr}) + self._log_info({'lr_scaled': True, 'old_lr': lr, 'new_lr': scaled_lr}) model, optim = self._amp.initialize( model, optim, opt_level=self._opt_level, keep_batchnorm_fp32=self._bn_fp32, loss_scale=self._loss_scale ) - if self._distributed_enabled: + if self.is_dist(): # By default, apex.parallel.DistributedDataParallel overlaps communication with # computation in the backward pass. # delay_allreduce delays all communication to the end of the backward pass. @@ -222,19 +221,19 @@ class ApexUtils: def clip_grad(self, clip:float, model:nn.Module, optim:Optimizer)->None: if clip > 0.0: - if self._mixed_prec_enabled: + if self.is_mixed(): nn.utils.clip_grad_norm_(self._amp.master_params(optim), clip) else: nn.utils.clip_grad_norm_(model.parameters(), clip) def state_dict(self): - if self._mixed_prec_enabled: + if self.is_mixed(): return self._amp.state_dict() else: return None def load_state_dict(self, state_dict): - if self._mixed_prec_enabled: + if self.is_mixed(): self._amp.load_state_dict() diff --git a/archai/common/common.py b/archai/common/common.py index f5ba331d..8d70d9fc 100644 --- a/archai/common/common.py +++ b/archai/common/common.py @@ -30,42 +30,11 @@ class SummaryWriterDummy: SummaryWriterAny = Union[SummaryWriterDummy, SummaryWriter] logger = OrderedDictLogger(None, None) _tb_writer: SummaryWriterAny = None -_apex_utils = ApexUtils() _atexit_reg = False # is hook for atexit registered? 
def get_conf()->Config: return Config.get() -def get_device(): - global _apex_utils - return _apex_utils.device - -def get_apex_utils()->ApexUtils: - global _apex_utils - assert _apex_utils - return _apex_utils - -def is_dist()->bool: - global _apex_utils - return _apex_utils.is_dist() - -def reduce_min(val): - global _apex_utils - return _apex_utils.reduce(val, op='min') -def reduce_max(val): - global _apex_utils - return _apex_utils.reduce(val, op='max') -def reduce_sum(val): - global _apex_utils - return _apex_utils.reduce(val, op='sum') -def reduce_mean(val): - global _apex_utils - return _apex_utils.reduce(val, op='mean') - -def is_dist()->bool: - global _apex_utils - return _apex_utils.is_dist() - def get_conf_common()->Config: return get_conf()['common'] @@ -138,14 +107,18 @@ def common_init(config_filepath: Optional[str]=None, logger.info({'expdir': expdir, 'PT_DATA_DIR': pt_data_dir, 'PT_OUTPUT_DIR': pt_output_dir}) + # create apex to know distributed processing parameters + conf_apex = get_conf_common()['apex'] + apex = ApexUtils(conf_apex, None) + # create global logger - _setup_logger() + _setup_logger(apex) # create info file for current system _create_sysinfo(conf) # setup tensorboard global _tb_writer - _tb_writer = _create_tb_writer(get_apex_utils().is_master()) + _tb_writer = _create_tb_writer(apex.is_master()) # create hooks to execute code when script exits global _atexit_reg @@ -216,17 +189,18 @@ def _setup_dirs()->Optional[str]: os.environ['distdir'] = conf_common['distdir'] = distdir -def _setup_logger(): +def _setup_logger(apex:ApexUtils): global logger logger.close() # close any previous instances conf_common = get_conf_common() expdir = conf_common['expdir'] distdir = conf_common['distdir'] - global_rank = get_apex_utils().global_rank + + global_rank = apex.global_rank # file where logger would log messages - if get_apex_utils().is_master(): + if apex.is_master(): sys_log_filepath = utils.full_path(os.path.join(expdir, 'logs.log')) logs_yaml_filepath = utils.full_path(os.path.join(expdir, 'logs.yaml')) experiment_name = get_experiment_name() diff --git a/archai/common/metrics.py b/archai/common/metrics.py index 712e1988..75bca302 100644 --- a/archai/common/metrics.py +++ b/archai/common/metrics.py @@ -11,7 +11,8 @@ from torch import Tensor import yaml from . import utils, ml_utils -from .common import logger, get_tb_writer, is_dist, reduce_mean, reduce_sum, reduce_min, reduce_max +from .common import logger, get_tb_writer +from .apex_utils import ApexUtils class Metrics: @@ -29,7 +30,7 @@ class Metrics: best we have seen for each epoch.
""" - def __init__(self, title:str, logger_freq:int=50) -> None: + def __init__(self, title:str, apex:Optional[ApexUtils], logger_freq:int=50) -> None: """Create the metrics object to maintain epoch stats Arguments: @@ -39,6 +40,7 @@ class Metrics: """ self.logger_freq = logger_freq self.title = title + self._apex = apex self._reset_run() def _reset_run(self)->None: @@ -59,27 +61,27 @@ class Metrics: logger.info({'epoch':self.run_metrics.epoch_time_avg(), 'step': self.run_metrics.step_time_avg(), 'run': self.run_metrics.duration()}) - if is_dist(): - logger.info({'dist_epoch_sum': reduce_sum(self.run_metrics.epoch_time_avg()), - 'dist_step': reduce_mean(self.run_metrics.step_time_avg()), - 'dist_run_sum': reduce_sum(self.run_metrics.duration())}) + if self.is_dist(): + logger.info({'dist_epoch_sum': self.reduce_sum(self.run_metrics.epoch_time_avg()), + 'dist_step': self.reduce_mean(self.run_metrics.step_time_avg()), + 'dist_run_sum': self.reduce_sum(self.run_metrics.duration())}) best_train, best_val = self.run_metrics.best_epoch() with logger.pushd('best_train'): logger.info({'epoch': best_train.index, 'top1': best_train.top1.avg}) - if is_dist(): - logger.info({'dist_epoch': reduce_mean(best_train.index), - 'dist_top1': reduce_mean(best_train.top1.avg)}) + if self.is_dist(): + logger.info({'dist_epoch': self.reduce_mean(best_train.index), + 'dist_top1': self.reduce_mean(best_train.top1.avg)}) if best_val: with logger.pushd('best_val'): logger.info({'epoch': best_val.index, 'top1': best_val.val_metrics.top1.avg}) - if is_dist(): - logger.info({'dist_epoch': reduce_mean(best_val.index), - 'dist_top1': reduce_mean(best_val.val_metrics.top1.avg)}) + if self.is_dist(): + logger.info({'dist_epoch': self.reduce_mean(best_val.index), + 'dist_top1': self.reduce_mean(best_val.val_metrics.top1.avg)}) def pre_step(self, x: Tensor, y: Tensor): self.run_metrics.cur_epoch().pre_step() @@ -102,11 +104,11 @@ class Metrics: 'loss': epoch.loss.avg, 'step_time': epoch.step_time.last}) - if is_dist(): - logger.info({'dist_top1': reduce_mean(epoch.top1.avg), - 'dist_top5': reduce_mean(epoch.top5.avg), - 'dist_loss': reduce_mean(epoch.loss.avg), - 'dist_step_time': reduce_mean(epoch.step_time.last)}) + if self.is_dist(): + logger.info({'dist_top1': self.reduce_mean(epoch.top1.avg), + 'dist_top5': self.reduce_mean(epoch.top5.avg), + 'dist_loss': self.reduce_mean(epoch.loss.avg), + 'dist_step_time': self.reduce_mean(epoch.step_time.last)}) # NOTE: Tensorboard step-level logging is removed as it becomes exponentially expensive on Azure blobs @@ -143,24 +145,24 @@ class Metrics: 'duration': epoch.duration(), 'step_time': epoch.step_time.avg, 'end_lr': lr}) - if is_dist(): - logger.info({'dist_top1': reduce_mean(epoch.top1.avg), - 'dist_top5': reduce_mean(epoch.top5.avg), - 'dist_loss': reduce_mean(epoch.loss.avg), - 'dist_duration': reduce_mean(epoch.duration()), - 'dist_step_time': reduce_mean(epoch.step_time.avg), - 'dist_end_lr': reduce_mean(lr)}) + if self.is_dist(): + logger.info({'dist_top1': self.reduce_mean(epoch.top1.avg), + 'dist_top5': self.reduce_mean(epoch.top5.avg), + 'dist_loss': self.reduce_mean(epoch.loss.avg), + 'dist_duration': self.reduce_mean(epoch.duration()), + 'dist_step_time': self.reduce_mean(epoch.step_time.avg), + 'dist_end_lr': self.reduce_mean(lr)}) if test_epoch: with logger.pushd('val'): logger.info({'top1': test_epoch.top1.avg, 'top5': test_epoch.top5.avg, 'loss': test_epoch.loss.avg, 'duration': epoch.duration()}) - if is_dist(): - logger.info({'dist_top1': 
reduce_mean(test_epoch.top1.avg), - 'dist_top5': reduce_mean(test_epoch.top5.avg), - 'dist_loss': reduce_mean(test_epoch.loss.avg), - 'dist_duration': reduce_mean(test_epoch.duration())}) + if self.is_dist(): + logger.info({'dist_top1': self.reduce_mean(test_epoch.top1.avg), + 'dist_top5': self.reduce_mean(test_epoch.top5.avg), + 'dist_loss': self.reduce_mean(test_epoch.loss.avg), + 'dist_duration': self.reduce_mean(test_epoch.duration())}) # writer = get_tb_writer() # writer.add_scalar(f'{self._tb_path}/train_epochs/loss', @@ -181,9 +183,14 @@ class Metrics: return utils.state_dict(self) def load_state_dict(self, state_dict:dict)->None: - # simply convert current object to dictionary utils.load_state_dict(self, state_dict) + def __getstate__(self): + state = self.__dict__.copy() + del state['_apex'] # cannot serialize this + return state + # no need to define __setstate__ because _apex should be set from constructor + def save(self, filepath:str)->Optional[str]: if filepath: filepath = utils.full_path(filepath) @@ -197,6 +204,27 @@ class Metrics: def cur_epoch(self)->'EpochMetrics': return self.run_metrics.cur_epoch() + def reduce_min(self, val): + if not self._apex: + return val + return self._apex.reduce(val, op='min') + def reduce_max(self, val): + if not self._apex: + return val + return self._apex.reduce(val, op='max') + def reduce_sum(self, val): + if not self._apex: + return val + return self._apex.reduce(val, op='sum') + def reduce_mean(self, val): + if not self._apex: + return val + return self._apex.reduce(val, op='mean') + def is_dist(self)->bool: + if not self._apex: + return False + return self._apex.is_dist() + class Accumulator: # TODO: replace this with Metrics class diff --git a/archai/common/tester.py b/archai/common/tester.py index afd7b3be..4afb93b8 100644 --- a/archai/common/tester.py +++ b/archai/common/tester.py @@ -9,18 +9,18 @@ from overrides import EnforceOverrides from .metrics import Metrics from .config import Config from . 
import utils, ml_utils -from .common import logger, get_device -from archai.common.common import get_apex_utils +from .common import logger +from archai.common.apex_utils import ApexUtils class Tester(EnforceOverrides): - def __init__(self, conf_eval:Config, model:nn.Module)->None: - # TODO: currently we expect that given model and dataloader will already be distributed + def __init__(self, conf_eval:Config, model:nn.Module, apex:ApexUtils)->None: self._title = conf_eval['title'] self._logger_freq = conf_eval['logger_freq'] conf_lossfn = conf_eval['lossfn'] + self._apex = apex self.model = model - self._lossfn = ml_utils.get_lossfn(conf_lossfn).to(get_device()) + self._lossfn = ml_utils.get_lossfn(conf_lossfn).to(apex.device) self._metrics = None def test(self, test_dl: DataLoader)->Metrics: @@ -43,7 +43,7 @@ class Tester(EnforceOverrides): with torch.no_grad(), logger.pushd('steps'): for step, (x, y) in enumerate(test_dl): - x, y = x.to(get_device(), non_blocking=True), y.to(get_device(), non_blocking=True) + x, y = x.to(self._apex.device, non_blocking=True), y.to(self._apex.device, non_blocking=True) assert not self.model.training # derived class might alter the mode logger.pushd(step) @@ -57,7 +57,7 @@ class Tester(EnforceOverrides): self._post_step(x, y, logits, loss, steps, self._metrics) # TODO: we possibly need to sync so all replicas are upto date - get_apex_utils().sync_devices() + self._apex.sync_devices() logger.popd() self._metrics.post_epoch(None) @@ -87,5 +87,5 @@ class Tester(EnforceOverrides): metrics.post_step(x, y, logits, loss, steps) def _create_metrics(self)->Metrics: - return Metrics(self._title, logger_freq=self._logger_freq) + return Metrics(self._title, self._apex, logger_freq=self._logger_freq) diff --git a/archai/common/trainer.py b/archai/common/trainer.py index 10a52d7f..98ea4091 100644 --- a/archai/common/trainer.py +++ b/archai/common/trainer.py @@ -11,8 +11,9 @@ from .metrics import Metrics from .tester import Tester from .config import Config from . 
import utils, ml_utils -from ..common.common import logger, get_device, get_apex_utils +from ..common.common import logger from ..common.checkpoint import CheckPoint +from ..common.apex_utils import ApexUtils class Trainer(EnforceOverrides): @@ -33,13 +34,13 @@ class Trainer(EnforceOverrides): self._validation_freq = 0 if conf_validation is None else conf_validation['freq'] # endregion - get_apex_utils().reset(logger, conf_apex) + self._apex = ApexUtils(conf_apex, logger) self._checkpoint = checkpoint self.model = model self._lossfn = ml_utils.get_lossfn(conf_lossfn) - self._tester = Tester(conf_validation, model) \ + self._tester = Tester(conf_validation, model, self._apex) \ if conf_validation else None self._metrics:Optional[Metrics] = None @@ -52,7 +53,7 @@ class Trainer(EnforceOverrides): def fit(self, train_dl:DataLoader, val_dl:Optional[DataLoader])->Metrics: logger.pushd(self._title) - self._metrics = Metrics(self._title, logger_freq=self._logger_freq) + self._metrics = Metrics(self._title, self._apex, logger_freq=self._logger_freq) # optimizers, schedulers needs to be recreated for each fit call # as they have state specific to each run @@ -60,10 +61,10 @@ class Trainer(EnforceOverrides): # create scheduler for optim before applying amp self._sched, self._sched_on_epoch = self._create_scheduler(optim, len(train_dl)) # before checkpoint restore, convert to amp - self.model, self._optim = get_apex_utils().to_amp(self.model, optim, + self.model, self._optim = self._apex.to_amp(self.model, optim, batch_size=train_dl.batch_size) - self._lossfn = self._lossfn.to(get_device()) + self._lossfn = self._lossfn.to(self.get_device()) self.pre_fit(train_dl, val_dl) @@ -158,7 +159,7 @@ class Trainer(EnforceOverrides): self._metrics.post_epoch(val_metrics, lr=self._optim.param_groups[0]['lr']) # checkpoint if enabled with given freq or if this is the last epoch - if self._checkpoint is not None and get_apex_utils().is_master() and \ + if self._checkpoint is not None and self._apex.is_master() and \ self._checkpoint.freq > 0 and (self._metrics.epochs() % self._checkpoint.freq == 0 or \ self._metrics.epochs() >= self._epochs): self._checkpoint.new() @@ -173,6 +174,9 @@ class Trainer(EnforceOverrides): self._metrics.post_step(x, y, logits, loss, steps) ######################### hooks ######################### + def get_device(self): + return self._apex.device + def restore_checkpoint(self)->None: state = self._checkpoint['trainer'] last_epoch = state['last_epoch'] @@ -180,7 +184,7 @@ class Trainer(EnforceOverrides): self._metrics.load_state_dict(state['metrics']) assert self._metrics.epochs() == last_epoch+1 - get_apex_utils().load_state_dict(state['amp']) + self._apex.load_state_dict(state['amp']) self.model.load_state_dict(state['model']) self._optim.load_state_dict(state['optim']) if self._sched: @@ -198,7 +202,7 @@ class Trainer(EnforceOverrides): 'model': self.model.state_dict(), 'optim': self._optim.state_dict(), 'sched': self._sched.state_dict() if self._sched else None, - 'amp': get_apex_utils().state_dict() + 'amp': self._apex.state_dict() } self._checkpoint['trainer'] = state @@ -208,7 +212,7 @@ class Trainer(EnforceOverrides): logger.pushd('steps') for step, (x, y) in enumerate(train_dl): - x, y = x.to(get_device(), non_blocking=True), y.to(get_device(), non_blocking=True) + x, y = x.to(self.get_device(), non_blocking=True), y.to(self.get_device(), non_blocking=True) logger.pushd(step) assert self.model.training # derived class might alter the mode @@ -226,15 +230,15 @@ class 
Trainer(EnforceOverrides): loss = self.compute_loss(self._lossfn, x, y, logits, self._aux_weight, aux_logits) - get_apex_utils().backward(loss, self._optim) + self._apex.backward(loss, self._optim) # TODO: original darts clips alphas as well but pt.darts doesn't - get_apex_utils().clip_grad(self._grad_clip, self.model, self._optim) + self._apex.clip_grad(self._grad_clip, self.model, self._optim) self._optim.step() # TODO: we possibly need to sync so all replicas are upto date - get_apex_utils().sync_devices() + self._apex.sync_devices() if self._sched and not self._sched_on_epoch: self._sched.step() diff --git a/archai/common/utils.py b/archai/common/utils.py index 5582255d..58f8694e 100644 --- a/archai/common/utils.py +++ b/archai/common/utils.py @@ -55,7 +55,7 @@ def deep_update(d:MutableMapping, u:Mapping, map_type:Type[MutableMapping]=dict) return d def state_dict(val)->Mapping: - assert hasattr(val, '__dict__'), 'val must be object with __dict__' + assert hasattr(val, '__dict__'), 'val must be object with __dict__ otherwise it cannot be loaded back in load_state_dict' # Can't do below because val has state_dict() which calls utils.state_dict # if has_method(val, 'state_dict'): @@ -76,8 +76,8 @@ def load_state_dict(val:Any, state_dict:Mapping)->None: assert s is not None, 'state_dict must contain yaml key' obj = yaml.load(s, Loader=yaml.Loader) - for k in val.__dict__.keys(): - setattr(val, k, getattr(obj, k)) + for k, v in obj.__dict__.items(): + setattr(val, k, v) def deep_comp(o1:Any, o2:Any)->bool: # NOTE: dict don't have __dict__ diff --git a/confs/algos/darts.yaml b/confs/algos/darts.yaml index 79e8491c..93fbbecd 100644 --- a/confs/algos/darts.yaml +++ b/confs/algos/darts.yaml @@ -19,8 +19,8 @@ common: redis: null apex: # this is overriden in search and eval individually enabled: False # global switch to disable anything apex - distributed_enabled: False # enable/disable distributed mode - mixed_prec_enabled: False # switch to disable amp mixed precision + distributed_enabled: True # enable/disable distributed mode + mixed_prec_enabled: True # switch to disable amp mixed precision gpus: '' # use GPU IDs specified here (comma separated), if '' then use all GPUs opt_level: 'O2' # optimization level for mixed precision bn_fp32: True # keep BN in fp32 @@ -88,6 +88,7 @@ nas: trainer: apex: _copy: 'common/apex' + enabled: True aux_weight: '_copy: nas/eval/model_desc/aux_weight' drop_path_prob: 0.2 # probability that given edge will be dropped grad_clip: 5.0 # grads above this value is clipped diff --git a/scripts/misc/yaml_playground.py b/scripts/misc/yaml_playground.py index 15ce4862..eed50f37 100644 --- a/scripts/misc/yaml_playground.py +++ b/scripts/misc/yaml_playground.py @@ -1,10 +1,27 @@ +from collections import UserDict import yaml +from typing import Iterator -y = """ -a: .NaN +class A(object): + def __init__(self): + self.hidden = 42 + self.visible = 5 -""" + def __getstate__(self): + state = self.__dict__.copy() + del state['hidden'] # cannot serialize this + return state -d=yaml.load(y, Loader=yaml.Loader) +a = A() +d = yaml.dump(a) print(d) -print(type( d['a'])) + + +# y = """ +# a: .NaN + +# """ + +# d=yaml.load(y, Loader=yaml.Loader) +# print(d) +# print(type( d['a'])) diff --git a/scripts/perf/resnet_test.py b/scripts/perf/resnet_test.py index 34b3d1a2..a9a36791 100644 --- a/scripts/perf/resnet_test.py +++ b/scripts/perf/resnet_test.py @@ -5,7 +5,7 @@ from archai import cifar10_models from archai.common.trainer import Trainer from archai.common.config import Config -from 
archai.common.common import logger, common_init, get_device +from archai.common.common import logger, common_init from archai.datasets import data def train_test(conf_eval:Config): @@ -24,7 +24,7 @@ def train_test(conf_eval:Config): conf_trainer['aux_weight'] = 0.0 Net = cifar10_models.resnet34 - model = Net().to(get_device()) + model = Net().to(torch.device('cuda', 0)) # get data train_dl, _, test_dl = data.get_data(conf_loader)
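
Below is a minimal usage sketch, not part of the patch itself, illustrating how the refactored pieces fit together after this change: ApexUtils is now constructed explicitly from the common/apex config section and passed into Metrics and Tester (Trainer builds its own instance internally), instead of being reached through module-level globals in archai.common.common. The config path, the 'standalone_eval' title, and the assumption that archai is installed, a CUDA device is present, and common_init() has been called are all illustrative; the constructor signatures follow the diffs above.

from archai.common.common import common_init, get_conf_common
from archai.common.apex_utils import ApexUtils
from archai.common.metrics import Metrics

# load config and set up logging/dirs once per process (path is illustrative)
common_init(config_filepath='confs/algos/darts.yaml')

# ApexUtils now takes its config section and an optional logger in the constructor;
# passing logger=None keeps it silent (see _log_info in apex_utils.py above)
apex = ApexUtils(get_conf_common()['apex'], logger=None)

# Metrics accepts the apex instance (or None); its reduce_* helpers become no-ops
# and is_dist() returns False when no apex object is supplied
metrics = Metrics('standalone_eval', apex, logger_freq=50)

print(apex.is_master(), metrics.is_dist(), metrics.reduce_mean(1.0))

Because Metrics now holds a reference to ApexUtils, the patch also adds __getstate__ so the non-picklable apex object is dropped when metrics are serialized to YAML, which is the same pattern exercised in scripts/misc/yaml_playground.py above.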