Shital Shah 2020-04-23 13:06:09 -07:00
Parent 0235072c8e
Commit d861b5f2a6
12 changed files with 198 additions and 174 deletions

View file

@@ -15,7 +15,7 @@ from archai.nas.arch_trainer import ArchTrainer
 from archai.common import utils, ml_utils
 from archai.nas.model import Model
 from archai.common.checkpoint import CheckPoint
-from archai.common.common import logger, get_device
+from archai.common.common import logger
 from .bilevel_optimizer import BilevelOptimizer

 class BilevelArchTrainer(ArchTrainer):
@@ -36,10 +36,11 @@ class BilevelArchTrainer(ArchTrainer):
         assert val_dl is not None
         w_momentum = self._conf_w_optim['momentum']
         w_decay = self._conf_w_optim['decay']
-        lossfn = ml_utils.get_lossfn(self._conf_w_lossfn).to(get_device())
+        lossfn = ml_utils.get_lossfn(self._conf_w_lossfn).to(self.get_device())

         self._bilevel_optim = BilevelOptimizer(self._conf_alpha_optim, w_momentum,
-                                               w_decay, self.model, lossfn)
+                                               w_decay, self.model, lossfn,
+                                               self.get_device())

     @overrides
     def post_fit(self, train_dl:DataLoader, val_dl:Optional[DataLoader])->None:
@@ -71,7 +72,7 @@ class BilevelArchTrainer(ArchTrainer):
             self._valid_iter = iter(self._val_dl)
             x_val, y_val = next(self._valid_iter)

-        x_val, y_val = x_val.to(get_device()), y_val.to(get_device(), non_blocking=True)
+        x_val, y_val = x_val.to(self.get_device()), y_val.to(self.get_device(), non_blocking=True)

         # update alphas
         self._bilevel_optim.step(x, y, x_val, y_val, super().get_optimizer())

View file

@@ -9,11 +9,11 @@ from torch.optim.optimizer import Optimizer
 from archai.common.config import Config
 from archai.common import utils, ml_utils
 from archai.nas.model import Model
-from archai.common.common import logger, get_device
+from archai.common.common import logger

 class BilevelOptimizer:
     def __init__(self, conf_alpha_optim:Config, w_momentum: float, w_decay: float,
-                 model: Model, lossfn: _Loss) -> None:
+                 model: Model, lossfn: _Loss, device) -> None:
         self._w_momentum = w_momentum # momentum for w
         self._w_weight_decay = w_decay # weight decay for w
         self._lossfn = lossfn
@@ -22,7 +22,7 @@ class BilevelOptimizer:
         # create a copy of model which we will use
         # to compute grads for alphas without disturbing
         # original weights
-        self._vmodel = copy.deepcopy(model).to(get_device())
+        self._vmodel = copy.deepcopy(model).to(device)

         # this is the optimizer to optimize alphas parameter
         self._alpha_optim = ml_utils.create_optimizer(conf_alpha_optim, model.alphas())

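The two files above make the same move: the device is no longer pulled from the global common.get_device(); the trainer resolves it once and injects it into BilevelOptimizer. A minimal, self-contained sketch of that constructor contract, using a toy class rather than the real BilevelOptimizer:

import copy
import torch

class ToyAlphaOptimizer:
    # device is an explicit constructor argument instead of a global lookup
    def __init__(self, model: torch.nn.Module, lossfn: torch.nn.Module, device: torch.device) -> None:
        self._lossfn = lossfn.to(device)
        # copy of the model used to compute alpha grads without touching the original weights
        self._vmodel = copy.deepcopy(model).to(device)

device = torch.device('cuda', 0) if torch.cuda.is_available() else torch.device('cpu')
opt = ToyAlphaOptimizer(torch.nn.Linear(8, 2), torch.nn.CrossEntropyLoss(), device)
print(opt._vmodel.weight.device)
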
View file

@@ -15,7 +15,7 @@ from archai.nas.arch_trainer import ArchTrainer
 from archai.common import utils, ml_utils
 from archai.nas.model import Model
 from archai.common.checkpoint import CheckPoint
-from archai.common.common import logger, get_device
+from archai.common.common import logger

 class XnasArchTrainer(ArchTrainer):
@@ -39,7 +39,7 @@ class XnasArchTrainer(ArchTrainer):
         # optimizers, schedulers needs to be recreated for each fit call
         # as they have state
         assert val_dl is not None
-        lossfn = ml_utils.get_lossfn(self._conf_w_lossfn).to(get_device())
+        lossfn = ml_utils.get_lossfn(self._conf_w_lossfn).to(self.get_device())

         self._xnas_optim = _XnasOptimizer(self._conf_alpha_optim, self.model, lossfn)
@@ -74,7 +74,7 @@ class XnasArchTrainer(ArchTrainer):
             self._valid_iter = iter(self._val_dl)
             x_val, y_val = next(self._valid_iter)

-        x_val, y_val = x_val.to(get_device()), y_val.to(get_device(), non_blocking=True)
+        x_val, y_val = x_val.to(self.get_device()), y_val.to(self.get_device(), non_blocking=True)

         # update alphas
         self._xnas_optim.step(x, y, x_val, y_val)

View file

@@ -15,23 +15,7 @@ from archai.common import ml_utils, utils
 from archai.common.ordereddict_logger import OrderedDictLogger

 class ApexUtils:
-    def __init__(self)->None:
-        self._amp = self._ddp = None
-        self._op_map = {'mean': dist.ReduceOp.SUM, 'sum': dist.ReduceOp.SUM,
-                        'min': dist.ReduceOp.MIN, 'max': dist.ReduceOp.MAX}
-        self.gpu_ids = [] # use all gpus
-        self._mixed_prec_enabled = False
-        self._distributed_enabled = False
-        self._set_ranks()
-
-    def reset(self, logger:OrderedDictLogger, apex_config:Config)->None:
-        # reset allows to configure differently for search or eval modes
-        # to avoid circular references with common, logger is passed from outside
-        self.logger = logger
-
+    def __init__(self, apex_config:Config, logger:Optional[OrderedDictLogger])->None:
         # region conf vars
         self._enabled = apex_config['enabled'] # global switch to disable anything apex
         self._distributed_enabled = apex_config['distributed_enabled'] # enable/disable distributed mode
@@ -47,42 +31,53 @@ class ApexUtils:
         conf_gpu_ids = apex_config['gpus']
         # endregion

-        self.gpu_ids = [int(i) for i in conf_gpu_ids.split(',') if i]
-        self._amp, self._ddp = None, None
-        self._gpu = self.gpu_ids[0] if len(self.gpu_ids) else 0 # which GPU to use, we will use only 1 GPU
-        #logger.info({'apex_config': apex_config.to_dict()})
-        logger.info({'torch.distributed.is_available': dist.is_available()})
+        # to avoid circular references with common, logger is passed from outside
+        self.logger = logger
+
+        self._op_map = {'mean': dist.ReduceOp.SUM, 'sum': dist.ReduceOp.SUM,
+                        'min': dist.ReduceOp.MIN, 'max': dist.ReduceOp.MAX}
+
+        # defaults for non-distributed mode
+        self._amp, self._ddp = None, None
+        self.world_size = 1
+        self.local_rank = self.global_rank = 0
+
+        self.gpu_ids = [int(i) for i in conf_gpu_ids.split(',') if i]
+        # which GPU to use, we will use only 1 GPU per process to avoid complications with apex
+        self._gpu = self.gpu_ids[0] if len(self.gpu_ids) else 0
+
+        #_log_info({'apex_config': apex_config.to_dict()})
+        self._log_info({'torch.distributed.is_available': dist.is_available()})
         if dist.is_available():
-            logger.info({'gloo_available': dist.is_gloo_available(),
+            self._log_info({'gloo_available': dist.is_gloo_available(),
                          'mpi_available': dist.is_mpi_available(),
                          'nccl_available': dist.is_nccl_available()})

-        if self._enabled:
-            if self._mixed_prec_enabled:
-                # init enable mixed precision
-                assert cudnn.enabled, "Amp requires cudnn backend to be enabled."
-                from apex import amp
-                self._amp = amp
+        if self.is_mixed():
+            # init enable mixed precision
+            assert cudnn.enabled, "Amp requires cudnn backend to be enabled."
+            from apex import amp
+            self._amp = amp

         # enable distributed processing
-        if self._distributed_enabled:
+        if self.is_dist():
             from apex import parallel
             self._ddp = parallel

             assert dist.is_available() # distributed module is available
             assert dist.is_nccl_available()
             if not dist.is_initialized():
                 dist.init_process_group(backend='nccl', init_method='env://')
                 assert dist.is_initialized()
             self._set_ranks()
             assert dist.get_world_size() == self.world_size
             assert dist.get_rank() == self.global_rank
         else:
             assert self.world_size == 1
             assert self.local_rank == 0
             assert self.global_rank == 0

         assert self.world_size >= 1
         assert not self._min_world_size or self.world_size >= self._min_world_size
@@ -94,9 +89,9 @@ class ApexUtils:
         self.device = torch.device('cuda', self._gpu)
         self._setup_gpus(seed, detect_anomaly)

-        logger.info({'amp_available': self._amp is not None,
+        self._log_info({'amp_available': self._amp is not None,
                      'distributed_available': self._ddp is not None})
-        logger.info({'dist_initialized': dist.is_initialized() if dist.is_available() else False,
+        self._log_info({'dist_initialized': dist.is_initialized() if dist.is_available() else False,
                      'world_size': self.world_size,
                      'gpu': self._gpu, 'gpu_ids':self.gpu_ids,
                      'local_rank': self.local_rank})
@@ -106,10 +101,10 @@ class ApexUtils:
         utils.setup_cuda(seed, self.local_rank)
         torch.autograd.set_detect_anomaly(detect_anomaly)
-        self.logger.info({'set_detect_anomaly': detect_anomaly,
+        self._log_info({'set_detect_anomaly': detect_anomaly,
                           'is_anomaly_enabled': torch.is_anomaly_enabled()})

-        self.logger.info({'gpu_names': utils.cuda_device_names(),
+        self._log_info({'gpu_names': utils.cuda_device_names(),
                           'gpu_count': torch.cuda.device_count(),
                           'CUDA_VISIBLE_DEVICES': os.environ['CUDA_VISIBLE_DEVICES']
                               if 'CUDA_VISIBLE_DEVICES' in os.environ else 'NotSet',
@@ -118,8 +113,8 @@ class ApexUtils:
                           'cudnn.deterministic': cudnn.deterministic,
                           'cudnn.version': cudnn.version()
                           })
-        self.logger.info({'memory': str(psutil.virtual_memory())})
-        self.logger.info({'CPUs': str(psutil.cpu_count())})
+        self._log_info({'memory': str(psutil.virtual_memory())})
+        self._log_info({'CPUs': str(psutil.cpu_count())})

         # gpu_usage = os.popen(
         #     'nvidia-smi --query-gpu=memory.total,memory.used --format=csv,nounits,noheader'
@@ -127,7 +122,7 @@ class ApexUtils:
         # for i, line in enumerate(gpu_usage):
         #     vals = line.split(',')
         #     if len(vals) == 2:
-        #         logger.info('GPU {} mem: {}, used: {}'.format(i, vals[0], vals[1]))
+        #         _log_info('GPU {} mem: {}, used: {}'.format(i, vals[0], vals[1]))

     def _set_ranks(self)->None:
         if 'WORLD_SIZE' in os.environ:
@@ -153,18 +148,22 @@ class ApexUtils:
             self._gpu = self.gpu_ids[self.local_rank]

     def is_mixed(self)->bool:
-        return self._mixed_prec_enabled
+        return self._enabled and self._mixed_prec_enabled

     def is_dist(self)->bool:
-        return self._distributed_enabled
+        return self._enabled and self._distributed_enabled

     def is_master(self)->bool:
         return self.global_rank == 0

+    def _log_info(self, d:dict)->None:
+        if self.logger is not None:
+            self.logger.info(d)
+
     def sync_devices(self)->None:
-        if self._distributed_enabled:
+        if self.is_dist():
             torch.cuda.synchronize(self.device)

     def reduce(self, val, op='mean'):
-        if self._distributed_enabled:
+        if self.is_dist():
             if not isinstance(val, Tensor):
                 rt = torch.tensor(val).to(self.device)
                 converted = True
@@ -184,7 +183,7 @@ class ApexUtils:
             return val

     def backward(self, loss:torch.Tensor, optim:Optimizer)->None:
-        if self._mixed_prec_enabled:
+        if self.is_mixed():
             with self._amp.scale_loss(loss, optim) as scaled_loss:
                 scaled_loss.backward()
         else:
@@ -193,26 +192,26 @@ class ApexUtils:
     def to_amp(self, model:nn.Module, optim:Optimizer, batch_size:int)\
             ->Tuple[nn.Module, Optimizer]:
         # convert BNs to sync BNs in distributed mode
-        if self._distributed_enabled and self._sync_bn:
+        if self.is_dist() and self._sync_bn:
             model = self._ddp.convert_syncbn_model(model)
-            self.logger.info({'BNs_converted': True})
+            self._log_info({'BNs_converted': True})

         model = model.to(self.device)

-        if self._mixed_prec_enabled:
+        if self.is_mixed():
             # scale LR
             if self._scale_lr:
                 lr = ml_utils.get_optim_lr(optim)
                 scaled_lr = lr * self.world_size / float(batch_size)
                 ml_utils.set_optim_lr(optim, scaled_lr)
-                self.logger.info({'lr_scaled': True, 'old_lr': lr, 'new_lr': scaled_lr})
+                self._log_info({'lr_scaled': True, 'old_lr': lr, 'new_lr': scaled_lr})

             model, optim = self._amp.initialize(
                 model, optim, opt_level=self._opt_level,
                 keep_batchnorm_fp32=self._bn_fp32, loss_scale=self._loss_scale
             )

-        if self._distributed_enabled:
+        if self.is_dist():
             # By default, apex.parallel.DistributedDataParallel overlaps communication with
             # computation in the backward pass.
             # delay_allreduce delays all communication to the end of the backward pass.
@@ -222,19 +221,19 @@ class ApexUtils:
     def clip_grad(self, clip:float, model:nn.Module, optim:Optimizer)->None:
         if clip > 0.0:
-            if self._mixed_prec_enabled:
+            if self.is_mixed():
                 nn.utils.clip_grad_norm_(self._amp.master_params(optim), clip)
             else:
                 nn.utils.clip_grad_norm_(model.parameters(), clip)

     def state_dict(self):
-        if self._mixed_prec_enabled:
+        if self.is_mixed():
             return self._amp.state_dict()
         else:
             return None

     def load_state_dict(self, state_dict):
-        if self._mixed_prec_enabled:
+        if self.is_mixed():
             self._amp.load_state_dict(state_dict)

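The net effect of the rewrite above: ApexUtils is configured entirely in __init__ from the apex config section plus an optional logger (no more reset()), and is_mixed()/is_dist() now also honor the global enabled switch. A rough sketch of just that gating logic, with a plain dict standing in for archai's Config and no real apex or distributed setup:

class ApexGateSketch:
    """Not the real ApexUtils; only illustrates the enabled/is_mixed/is_dist gating."""
    def __init__(self, apex_config: dict, logger=None) -> None:
        self.logger = logger                         # may be None, hence the _log_info guard
        self._enabled = apex_config['enabled']
        self._distributed_enabled = apex_config['distributed_enabled']
        self._mixed_prec_enabled = apex_config['mixed_prec_enabled']

    def is_mixed(self) -> bool:
        return self._enabled and self._mixed_prec_enabled

    def is_dist(self) -> bool:
        return self._enabled and self._distributed_enabled

    def _log_info(self, d: dict) -> None:
        if self.logger is not None:
            self.logger.info(d)

# With the defaults from the config change further below (enabled: False, but the
# sub-switches True), everything still stays off until a trainer flips 'enabled'.
gate = ApexGateSketch({'enabled': False, 'distributed_enabled': True, 'mixed_prec_enabled': True})
assert not gate.is_mixed() and not gate.is_dist()
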
View file

@@ -30,42 +30,11 @@ class SummaryWriterDummy:
 SummaryWriterAny = Union[SummaryWriterDummy, SummaryWriter]

 logger = OrderedDictLogger(None, None)
 _tb_writer: SummaryWriterAny = None
-_apex_utils = ApexUtils()
 _atexit_reg = False # is hook for atexit registered?

 def get_conf()->Config:
     return Config.get()

-def get_device():
-    global _apex_utils
-    return _apex_utils.device
-
-def get_apex_utils()->ApexUtils:
-    global _apex_utils
-    assert _apex_utils
-    return _apex_utils
-
-def is_dist()->bool:
-    global _apex_utils
-    return _apex_utils.is_dist()
-
-def reduce_min(val):
-    global _apex_utils
-    return _apex_utils.reduce(val, op='min')
-
-def reduce_max(val):
-    global _apex_utils
-    return _apex_utils.reduce(val, op='max')
-
-def reduce_sum(val):
-    global _apex_utils
-    return _apex_utils.reduce(val, op='sum')
-
-def reduce_mean(val):
-    global _apex_utils
-    return _apex_utils.reduce(val, op='mean')
-
-def is_dist()->bool:
-    global _apex_utils
-    return _apex_utils.is_dist()
-
 def get_conf_common()->Config:
     return get_conf()['common']
@@ -138,14 +107,18 @@ def common_init(config_filepath: Optional[str]=None,
     logger.info({'expdir': expdir,
                  'PT_DATA_DIR': pt_data_dir, 'PT_OUTPUT_DIR': pt_output_dir})

+    # create apex to know distributed processing parameters
+    conf_apex = get_conf_common()['apex']
+    apex = ApexUtils(conf_apex, None)
+
     # create global logger
-    _setup_logger()
+    _setup_logger(apex)
     # create info file for current system
     _create_sysinfo(conf)

     # setup tensorboard
     global _tb_writer
-    _tb_writer = _create_tb_writer(get_apex_utils().is_master())
+    _tb_writer = _create_tb_writer(apex.is_master())

     # create hooks to execute code when script exits
     global _atexit_reg
@@ -216,17 +189,18 @@ def _setup_dirs()->Optional[str]:
     os.environ['distdir'] = conf_common['distdir'] = distdir

-def _setup_logger():
+def _setup_logger(apex:ApexUtils):
     global logger
     logger.close() # close any previous instances

     conf_common = get_conf_common()
     expdir = conf_common['expdir']
     distdir = conf_common['distdir']

-    global_rank = get_apex_utils().global_rank
+    global_rank = apex.global_rank

     # file where logger would log messages
-    if get_apex_utils().is_master():
+    if apex.is_master():
         sys_log_filepath = utils.full_path(os.path.join(expdir, 'logs.log'))
         logs_yaml_filepath = utils.full_path(os.path.join(expdir, 'logs.yaml'))
         experiment_name = get_experiment_name()

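common.py drops its module-level ApexUtils and the get_device()/is_dist()/reduce_* wrappers; common_init now builds a short-lived ApexUtils(conf_apex, None) just to learn rank information before the logger and TensorBoard writer exist. A self-contained sketch of that ordering, with a hypothetical DistInfo class in place of ApexUtils and plain prints in place of _setup_logger/_create_tb_writer:

import os

class DistInfo:
    """Hypothetical stand-in for the rank bookkeeping ApexUtils now provides."""
    def __init__(self) -> None:
        self.world_size = int(os.environ.get('WORLD_SIZE', 1))
        self.global_rank = int(os.environ.get('RANK', 0))

    def is_master(self) -> bool:
        return self.global_rank == 0

def common_init_sketch() -> None:
    dist_info = DistInfo()                    # created before any logger exists
    # the master rank owns the main log files and the TensorBoard writer
    if dist_info.is_master():
        print('rank 0: set up master log files and TensorBoard writer')
    else:
        print(f'rank {dist_info.global_rank}: set up per-rank logging only')

common_init_sketch()
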
View file

@@ -11,7 +11,8 @@ from torch import Tensor
 import yaml

 from . import utils, ml_utils
-from .common import logger, get_tb_writer, is_dist, reduce_mean, reduce_sum, reduce_min, reduce_max
+from .common import logger, get_tb_writer
+from .apex_utils import ApexUtils

 class Metrics:
@@ -29,7 +30,7 @@ class Metrics:
     best we have seen for each epoch.
     """

-    def __init__(self, title:str, logger_freq:int=50) -> None:
+    def __init__(self, title:str, apex:Optional[ApexUtils], logger_freq:int=50) -> None:
         """Create the metrics object to maintain epoch stats

         Arguments:
@@ -39,6 +40,7 @@ class Metrics:
         """
         self.logger_freq = logger_freq
         self.title = title
+        self._apex = apex
         self._reset_run()

     def _reset_run(self)->None:
@@ -59,27 +61,27 @@ class Metrics:
         logger.info({'epoch':self.run_metrics.epoch_time_avg(),
                      'step': self.run_metrics.step_time_avg(),
                      'run': self.run_metrics.duration()})
-        if is_dist():
-            logger.info({'dist_epoch_sum': reduce_sum(self.run_metrics.epoch_time_avg()),
-                         'dist_step': reduce_mean(self.run_metrics.step_time_avg()),
-                         'dist_run_sum': reduce_sum(self.run_metrics.duration())})
+        if self.is_dist():
+            logger.info({'dist_epoch_sum': self.reduce_sum(self.run_metrics.epoch_time_avg()),
+                         'dist_step': self.reduce_mean(self.run_metrics.step_time_avg()),
+                         'dist_run_sum': self.reduce_sum(self.run_metrics.duration())})

         best_train, best_val = self.run_metrics.best_epoch()
         with logger.pushd('best_train'):
             logger.info({'epoch': best_train.index,
                          'top1': best_train.top1.avg})
-            if is_dist():
-                logger.info({'dist_epoch': reduce_mean(best_train.index),
-                             'dist_top1': reduce_mean(best_train.top1.avg)})
+            if self.is_dist():
+                logger.info({'dist_epoch': self.reduce_mean(best_train.index),
+                             'dist_top1': self.reduce_mean(best_train.top1.avg)})

         if best_val:
             with logger.pushd('best_val'):
                 logger.info({'epoch': best_val.index,
                              'top1': best_val.val_metrics.top1.avg})
-                if is_dist():
-                    logger.info({'dist_epoch': reduce_mean(best_val.index),
-                                 'dist_top1': reduce_mean(best_val.val_metrics.top1.avg)})
+                if self.is_dist():
+                    logger.info({'dist_epoch': self.reduce_mean(best_val.index),
+                                 'dist_top1': self.reduce_mean(best_val.val_metrics.top1.avg)})

     def pre_step(self, x: Tensor, y: Tensor):
         self.run_metrics.cur_epoch().pre_step()
@@ -102,11 +104,11 @@ class Metrics:
                          'loss': epoch.loss.avg,
                          'step_time': epoch.step_time.last})
-            if is_dist():
-                logger.info({'dist_top1': reduce_mean(epoch.top1.avg),
-                             'dist_top5': reduce_mean(epoch.top5.avg),
-                             'dist_loss': reduce_mean(epoch.loss.avg),
-                             'dist_step_time': reduce_mean(epoch.step_time.last)})
+            if self.is_dist():
+                logger.info({'dist_top1': self.reduce_mean(epoch.top1.avg),
+                             'dist_top5': self.reduce_mean(epoch.top5.avg),
+                             'dist_loss': self.reduce_mean(epoch.loss.avg),
+                             'dist_step_time': self.reduce_mean(epoch.step_time.last)})

         # NOTE: Tensorboard step-level logging is removed as it becomes exponentially expensive on Azure blobs
@@ -143,24 +145,24 @@ class Metrics:
                          'duration': epoch.duration(),
                          'step_time': epoch.step_time.avg,
                          'end_lr': lr})
-            if is_dist():
-                logger.info({'dist_top1': reduce_mean(epoch.top1.avg),
-                             'dist_top5': reduce_mean(epoch.top5.avg),
-                             'dist_loss': reduce_mean(epoch.loss.avg),
-                             'dist_duration': reduce_mean(epoch.duration()),
-                             'dist_step_time': reduce_mean(epoch.step_time.avg),
-                             'dist_end_lr': reduce_mean(lr)})
+            if self.is_dist():
+                logger.info({'dist_top1': self.reduce_mean(epoch.top1.avg),
+                             'dist_top5': self.reduce_mean(epoch.top5.avg),
+                             'dist_loss': self.reduce_mean(epoch.loss.avg),
+                             'dist_duration': self.reduce_mean(epoch.duration()),
+                             'dist_step_time': self.reduce_mean(epoch.step_time.avg),
+                             'dist_end_lr': self.reduce_mean(lr)})

         if test_epoch:
             with logger.pushd('val'):
                 logger.info({'top1': test_epoch.top1.avg,
                              'top5': test_epoch.top5.avg,
                              'loss': test_epoch.loss.avg,
                              'duration': epoch.duration()})
-                if is_dist():
-                    logger.info({'dist_top1': reduce_mean(test_epoch.top1.avg),
-                                 'dist_top5': reduce_mean(test_epoch.top5.avg),
-                                 'dist_loss': reduce_mean(test_epoch.loss.avg),
-                                 'dist_duration': reduce_mean(test_epoch.duration())})
+                if self.is_dist():
+                    logger.info({'dist_top1': self.reduce_mean(test_epoch.top1.avg),
+                                 'dist_top5': self.reduce_mean(test_epoch.top5.avg),
+                                 'dist_loss': self.reduce_mean(test_epoch.loss.avg),
+                                 'dist_duration': self.reduce_mean(test_epoch.duration())})

         # writer = get_tb_writer()
         # writer.add_scalar(f'{self._tb_path}/train_epochs/loss',
@@ -181,9 +183,14 @@ class Metrics:
         return utils.state_dict(self)

     def load_state_dict(self, state_dict:dict)->None:
+        # simply convert current object to dictionary
         utils.load_state_dict(self, state_dict)

+    def __getstate__(self):
+        state = self.__dict__.copy()
+        del state['_apex'] # cannot serialize this
+        return state
+    # no need to define __setstate__ because _apex should be set from constructor
+
     def save(self, filepath:str)->Optional[str]:
         if filepath:
             filepath = utils.full_path(filepath)
@@ -197,6 +204,27 @@ class Metrics:
     def cur_epoch(self)->'EpochMetrics':
         return self.run_metrics.cur_epoch()

+    def reduce_min(self, val):
+        if not self._apex:
+            return val
+        return self._apex.reduce(val, op='min')
+
+    def reduce_max(self, val):
+        if not self._apex:
+            return val
+        return self._apex.reduce(val, op='max')
+
+    def reduce_sum(self, val):
+        if not self._apex:
+            return val
+        return self._apex.reduce(val, op='sum')
+
+    def reduce_mean(self, val):
+        if not self._apex:
+            return val
+        return self._apex.reduce(val, op='mean')
+
+    def is_dist(self)->bool:
+        if not self._apex:
+            return False
+        return self._apex.is_dist()
+
 class Accumulator:
     # TODO: replace this with Metrics class

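Metrics now carries an ApexUtils handle but still has to round-trip through utils.state_dict()/load_state_dict() (which yaml-serialize __dict__), so __getstate__ drops _apex and the constructor is expected to supply it again. A minimal sketch of that pattern with pickle and hypothetical stand-in classes (UnpicklableHelper and MetricsSketch are illustrative, not the archai classes):

import pickle

class UnpicklableHelper:
    """Hypothetical stand-in for ApexUtils; refuses to be pickled."""
    def __reduce__(self):
        raise TypeError('pretend this cannot be serialized')

class MetricsSketch:
    def __init__(self, title: str, apex) -> None:
        self.title = title
        self.epochs = 0
        self._apex = apex                 # helper object, never serialized

    def __getstate__(self):
        state = self.__dict__.copy()
        del state['_apex']                # cannot serialize this
        return state
    # no __setstate__: _apex is expected to arrive via the constructor on the next run

m = MetricsSketch('search_train', UnpicklableHelper())
restored = pickle.loads(pickle.dumps(m))  # works only because __getstate__ dropped _apex
assert restored.title == 'search_train' and not hasattr(restored, '_apex')
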
View file

@@ -9,18 +9,18 @@ from overrides import EnforceOverrides
 from .metrics import Metrics
 from .config import Config
 from . import utils, ml_utils
-from .common import logger, get_device
-from archai.common.common import get_apex_utils
+from .common import logger
+from archai.common.apex_utils import ApexUtils

 class Tester(EnforceOverrides):
-    def __init__(self, conf_eval:Config, model:nn.Module)->None:
+    def __init__(self, conf_eval:Config, model:nn.Module, apex:ApexUtils)->None:
+        # TODO: currently we expect that given model and dataloader will already be distributed
         self._title = conf_eval['title']
         self._logger_freq = conf_eval['logger_freq']
         conf_lossfn = conf_eval['lossfn']

+        self._apex = apex
         self.model = model
-        self._lossfn = ml_utils.get_lossfn(conf_lossfn).to(get_device())
+        self._lossfn = ml_utils.get_lossfn(conf_lossfn).to(apex.device)
         self._metrics = None

     def test(self, test_dl: DataLoader)->Metrics:
@@ -43,7 +43,7 @@ class Tester(EnforceOverrides):
         with torch.no_grad(), logger.pushd('steps'):
             for step, (x, y) in enumerate(test_dl):
-                x, y = x.to(get_device(), non_blocking=True), y.to(get_device(), non_blocking=True)
+                x, y = x.to(self._apex.device, non_blocking=True), y.to(self._apex.device, non_blocking=True)

                 assert not self.model.training # derived class might alter the mode
                 logger.pushd(step)
@@ -57,7 +57,7 @@ class Tester(EnforceOverrides):
                 self._post_step(x, y, logits, loss, steps, self._metrics)

                 # TODO: we possibly need to sync so all replicas are upto date
-                get_apex_utils().sync_devices()
+                self._apex.sync_devices()

                 logger.popd()
         self._metrics.post_epoch(None)
@@ -87,5 +87,5 @@ class Tester(EnforceOverrides):
         metrics.post_step(x, y, logits, loss, steps)

     def _create_metrics(self)->Metrics:
-        return Metrics(self._title, logger_freq=self._logger_freq)
+        return Metrics(self._title, self._apex, logger_freq=self._logger_freq)

View file

@@ -11,8 +11,9 @@ from .metrics import Metrics
 from .tester import Tester
 from .config import Config
 from . import utils, ml_utils
-from ..common.common import logger, get_device, get_apex_utils
+from ..common.common import logger
 from ..common.checkpoint import CheckPoint
+from ..common.apex_utils import ApexUtils

 class Trainer(EnforceOverrides):
@@ -33,13 +34,13 @@ class Trainer(EnforceOverrides):
         self._validation_freq = 0 if conf_validation is None else conf_validation['freq']
         # endregion

-        get_apex_utils().reset(logger, conf_apex)
+        self._apex = ApexUtils(conf_apex, logger)

         self._checkpoint = checkpoint
         self.model = model

         self._lossfn = ml_utils.get_lossfn(conf_lossfn)
-        self._tester = Tester(conf_validation, model) \
+        self._tester = Tester(conf_validation, model, self._apex) \
                         if conf_validation else None
         self._metrics:Optional[Metrics] = None
@@ -52,7 +53,7 @@ class Trainer(EnforceOverrides):
     def fit(self, train_dl:DataLoader, val_dl:Optional[DataLoader])->Metrics:
         logger.pushd(self._title)

-        self._metrics = Metrics(self._title, logger_freq=self._logger_freq)
+        self._metrics = Metrics(self._title, self._apex, logger_freq=self._logger_freq)

         # optimizers, schedulers needs to be recreated for each fit call
         # as they have state specific to each run
@@ -60,10 +61,10 @@ class Trainer(EnforceOverrides):
         # create scheduler for optim before applying amp
         self._sched, self._sched_on_epoch = self._create_scheduler(optim, len(train_dl))
         # before checkpoint restore, convert to amp
-        self.model, self._optim = get_apex_utils().to_amp(self.model, optim,
+        self.model, self._optim = self._apex.to_amp(self.model, optim,
                                                     batch_size=train_dl.batch_size)
-        self._lossfn = self._lossfn.to(get_device())
+        self._lossfn = self._lossfn.to(self.get_device())

         self.pre_fit(train_dl, val_dl)
@@ -158,7 +159,7 @@ class Trainer(EnforceOverrides):
         self._metrics.post_epoch(val_metrics, lr=self._optim.param_groups[0]['lr'])

         # checkpoint if enabled with given freq or if this is the last epoch
-        if self._checkpoint is not None and get_apex_utils().is_master() and \
+        if self._checkpoint is not None and self._apex.is_master() and \
                 self._checkpoint.freq > 0 and (self._metrics.epochs() % self._checkpoint.freq == 0 or \
                 self._metrics.epochs() >= self._epochs):
             self._checkpoint.new()
@@ -173,6 +174,9 @@ class Trainer(EnforceOverrides):
         self._metrics.post_step(x, y, logits, loss, steps)

     ######################### hooks #########################
+    def get_device(self):
+        return self._apex.device
+
     def restore_checkpoint(self)->None:
         state = self._checkpoint['trainer']
         last_epoch = state['last_epoch']
@@ -180,7 +184,7 @@ class Trainer(EnforceOverrides):
         self._metrics.load_state_dict(state['metrics'])
         assert self._metrics.epochs() == last_epoch+1
-        get_apex_utils().load_state_dict(state['amp'])
+        self._apex.load_state_dict(state['amp'])
         self.model.load_state_dict(state['model'])
         self._optim.load_state_dict(state['optim'])
         if self._sched:
@@ -198,7 +202,7 @@ class Trainer(EnforceOverrides):
             'model': self.model.state_dict(),
             'optim': self._optim.state_dict(),
             'sched': self._sched.state_dict() if self._sched else None,
-            'amp': get_apex_utils().state_dict()
+            'amp': self._apex.state_dict()
         }
         self._checkpoint['trainer'] = state
@@ -208,7 +212,7 @@ class Trainer(EnforceOverrides):
         logger.pushd('steps')
         for step, (x, y) in enumerate(train_dl):
-            x, y = x.to(get_device(), non_blocking=True), y.to(get_device(), non_blocking=True)
+            x, y = x.to(self.get_device(), non_blocking=True), y.to(self.get_device(), non_blocking=True)

             logger.pushd(step)
             assert self.model.training # derived class might alter the mode
@@ -226,15 +230,15 @@ class Trainer(EnforceOverrides):
             loss = self.compute_loss(self._lossfn, x, y, logits,
                                      self._aux_weight, aux_logits)
-            get_apex_utils().backward(loss, self._optim)
+            self._apex.backward(loss, self._optim)

             # TODO: original darts clips alphas as well but pt.darts doesn't
-            get_apex_utils().clip_grad(self._grad_clip, self.model, self._optim)
+            self._apex.clip_grad(self._grad_clip, self.model, self._optim)

             self._optim.step()

             # TODO: we possibly need to sync so all replicas are upto date
-            get_apex_utils().sync_devices()
+            self._apex.sync_devices()

             if self._sched and not self._sched_on_epoch:
                 self._sched.step()

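Put together with tester.py and metrics.py above, the ownership now flows one way: the trainer builds its own ApexUtils from conf_apex and hands the same instance to its Tester and Metrics, and get_device() is just a view onto apex.device. A toy sketch of that wiring with stub classes (ApexStub, TesterStub, MetricsStub and TrainerStub are illustrative, not the archai classes):

import torch

class ApexStub:
    """Hypothetical stand-in for ApexUtils: just enough to show the wiring."""
    def __init__(self, conf_apex: dict, logger=None) -> None:
        self.device = torch.device('cuda', 0) if torch.cuda.is_available() else torch.device('cpu')

    def is_master(self) -> bool:
        return True

class MetricsStub:
    def __init__(self, title: str, apex: ApexStub) -> None:
        self.title, self._apex = title, apex

class TesterStub:
    def __init__(self, model: torch.nn.Module, apex: ApexStub) -> None:
        self._apex = apex
        self.model = model.to(apex.device)

class TrainerStub:
    def __init__(self, model: torch.nn.Module, conf_apex: dict) -> None:
        self._apex = ApexStub(conf_apex, logger=None)    # one apex helper per trainer
        self._tester = TesterStub(model, self._apex)     # shared by injection, not via a global
        self._metrics = MetricsStub('train', self._apex)

    def get_device(self) -> torch.device:
        return self._apex.device                         # replaces common.get_device()

trainer = TrainerStub(torch.nn.Linear(4, 2), {'enabled': False})
print(trainer.get_device())
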
View file

@@ -55,7 +55,7 @@ def deep_update(d:MutableMapping, u:Mapping, map_type:Type[MutableMapping]=dict)
     return d

 def state_dict(val)->Mapping:
-    assert hasattr(val, '__dict__'), 'val must be object with __dict__'
+    assert hasattr(val, '__dict__'), 'val must be object with __dict__ otherwise it cannot be loaded back in load_state_dict'

     # Can't do below because val has state_dict() which calls utils.state_dict
     # if has_method(val, 'state_dict'):
@@ -76,8 +76,8 @@ def load_state_dict(val:Any, state_dict:Mapping)->None:
     assert s is not None, 'state_dict must contain yaml key'

     obj = yaml.load(s, Loader=yaml.Loader)
-    for k in val.__dict__.keys():
-        setattr(val, k, getattr(obj, k))
+    for k, v in obj.__dict__.items():
+        setattr(val, k, v)

 def deep_comp(o1:Any, o2:Any)->bool:
     # NOTE: dict don't have __dict__

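The load_state_dict change matters because of the Metrics.__getstate__ above: once an attribute like _apex is dropped at save time, the yaml-restored object has fewer attributes than the live one, so iterating val.__dict__ and calling getattr(obj, k) would raise AttributeError; iterating obj.__dict__ copies only what was actually saved. A small self-contained illustration (Holder is a made-up class):

import yaml

class Holder:
    def __init__(self, helper) -> None:
        self.value = 1
        self._helper = helper             # dropped on save, re-supplied by the constructor

    def __getstate__(self):
        state = self.__dict__.copy()
        del state['_helper']
        return state

live = Holder(helper=object())
saved = yaml.dump(live)                   # serialized state contains 'value' only

fresh = Holder(helper=object())           # next run: _helper comes from the constructor
obj = yaml.load(saved, Loader=yaml.Loader)
for k, v in obj.__dict__.items():         # the new loop in load_state_dict
    setattr(fresh, k, v)

# the old loop iterated fresh.__dict__ and would have hit getattr(obj, '_helper') -> AttributeError
assert fresh.value == 1 and hasattr(fresh, '_helper')
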
View file

@@ -19,8 +19,8 @@ common:
   redis: null
   apex: # this is overriden in search and eval individually
     enabled: False # global switch to disable anything apex
-    distributed_enabled: False # enable/disable distributed mode
-    mixed_prec_enabled: False # switch to disable amp mixed precision
+    distributed_enabled: True # enable/disable distributed mode
+    mixed_prec_enabled: True # switch to disable amp mixed precision
     gpus: '' # use GPU IDs specified here (comma separated), if '' then use all GPUs
     opt_level: 'O2' # optimization level for mixed precision
     bn_fp32: True # keep BN in fp32
@@ -88,6 +88,7 @@ nas:
     trainer:
       apex:
        _copy: 'common/apex'
+       enabled: True
      aux_weight: '_copy: nas/eval/model_desc/aux_weight'
      drop_path_prob: 0.2 # probability that given edge will be dropped
      grad_clip: 5.0 # grads above this value is clipped

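The config now keeps the global apex switch off by default while turning the sub-features on, and the trainer section shown above opts in by overriding enabled after copying common/apex. A toy illustration of the copy-then-override intent (archai's Config resolves _copy itself; the dict merge here merely mimics the resulting values):

common_apex = {
    'enabled': False,              # global switch stays off by default
    'distributed_enabled': True,   # sub-features now default to on...
    'mixed_prec_enabled': True,    # ...but are gated by 'enabled'
}

# what '_copy: common/apex' followed by 'enabled: True' is meant to produce
trainer_apex = {**common_apex, 'enabled': True}

assert trainer_apex['enabled'] and trainer_apex['mixed_prec_enabled']
print(trainer_apex)
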
View file

@@ -1,10 +1,27 @@
+from collections import UserDict
 import yaml
+from typing import Iterator

-y = """
-a: .NaN
-"""
+class A(object):
+    def __init__(self):
+        self.hidden = 42
+        self.visible = 5
+
+    def __getstate__(self):
+        state = self.__dict__.copy()
+        del state['hidden'] # cannot serialize this
+        return state

-d=yaml.load(y, Loader=yaml.Loader)
+a = A()
+d = yaml.dump(a)
 print(d)
-print(type( d['a']))
+
+# y = """
+# a: .NaN
+# """
+# d=yaml.load(y, Loader=yaml.Loader)
+# print(d)
+# print(type( d['a']))

View file

@@ -5,7 +5,7 @@ from archai import cifar10_models
 from archai.common.trainer import Trainer
 from archai.common.config import Config
-from archai.common.common import logger, common_init, get_device
+from archai.common.common import logger, common_init
 from archai.datasets import data

 def train_test(conf_eval:Config):
@@ -24,7 +24,7 @@ def train_test(conf_eval:Config):
     conf_trainer['aux_weight'] = 0.0

     Net = cifar10_models.resnet34
-    model = Net().to(get_device())
+    model = Net().to(torch.device('cuda', 0))

     # get data
     train_dl, _, test_dl = data.get_data(conf_loader)