Shital Shah 2020-04-23 13:06:09 -07:00
Parent 0235072c8e
Commit d861b5f2a6
12 changed files with 198 additions and 174 deletions

View file

@@ -15,7 +15,7 @@ from archai.nas.arch_trainer import ArchTrainer
from archai.common import utils, ml_utils
from archai.nas.model import Model
from archai.common.checkpoint import CheckPoint
from archai.common.common import logger, get_device
from archai.common.common import logger
from .bilevel_optimizer import BilevelOptimizer
class BilevelArchTrainer(ArchTrainer):
@@ -36,10 +36,11 @@ class BilevelArchTrainer(ArchTrainer):
assert val_dl is not None
w_momentum = self._conf_w_optim['momentum']
w_decay = self._conf_w_optim['decay']
lossfn = ml_utils.get_lossfn(self._conf_w_lossfn).to(get_device())
lossfn = ml_utils.get_lossfn(self._conf_w_lossfn).to(self.get_device())
self._bilevel_optim = BilevelOptimizer(self._conf_alpha_optim, w_momentum,
w_decay, self.model, lossfn)
w_decay, self.model, lossfn,
self.get_device())
@overrides
def post_fit(self, train_dl:DataLoader, val_dl:Optional[DataLoader])->None:
@@ -71,7 +72,7 @@ class BilevelArchTrainer(ArchTrainer):
self._valid_iter = iter(self._val_dl)
x_val, y_val = next(self._valid_iter)
x_val, y_val = x_val.to(get_device()), y_val.to(get_device(), non_blocking=True)
x_val, y_val = x_val.to(self.get_device()), y_val.to(self.get_device(), non_blocking=True)
# update alphas
self._bilevel_optim.step(x, y, x_val, y_val, super().get_optimizer())
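The two lines above re-arm the validation iterator when it runs dry and move the batch to the trainer's device. The same pattern as a standalone generator (a hypothetical helper, not part of this commit):

import torch

def infinite_batches(val_dl, device: torch.device):
    # cycle the validation loader forever; each yielded batch is already
    # on the target device, so next(...) never raises StopIteration
    while True:
        for x_val, y_val in val_dl:
            yield x_val.to(device), y_val.to(device, non_blocking=True)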

View file

@@ -9,11 +9,11 @@ from torch.optim.optimizer import Optimizer
from archai.common.config import Config
from archai.common import utils, ml_utils
from archai.nas.model import Model
from archai.common.common import logger, get_device
from archai.common.common import logger
class BilevelOptimizer:
def __init__(self, conf_alpha_optim:Config, w_momentum: float, w_decay: float,
model: Model, lossfn: _Loss) -> None:
model: Model, lossfn: _Loss, device) -> None:
self._w_momentum = w_momentum # momentum for w
self._w_weight_decay = w_decay # weight decay for w
self._lossfn = lossfn
@@ -22,7 +22,7 @@ class BilevelOptimizer:
# create a copy of model which we will use
# to compute grads for alphas without disturbing
# original weights
self._vmodel = copy.deepcopy(model).to(get_device())
self._vmodel = copy.deepcopy(model).to(device)
# this is the optimizer to optimize alphas parameter
self._alpha_optim = ml_utils.create_optimizer(conf_alpha_optim, model.alphas())
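The comment above explains the virtual-model trick: alpha gradients are computed on a copy so the original weights stay untouched. A minimal sketch of how such a copy is advanced by one inner step, assuming a plain SGD update (the real optimizer also folds in its momentum buffer):

import copy
import torch
from torch import nn

def make_virtual_model(model: nn.Module, lr: float, w_decay: float) -> nn.Module:
    # deep-copy the network (the copy stays on the same device as the
    # original), then take one SGD step on the copy using gradients that
    # were already computed on the training batch
    vmodel = copy.deepcopy(model)
    with torch.no_grad():
        for w, vw in zip(model.parameters(), vmodel.parameters()):
            if w.grad is not None:
                vw.copy_(w - lr * (w.grad + w_decay * w))
    return vmodel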

View file

@@ -15,7 +15,7 @@ from archai.nas.arch_trainer import ArchTrainer
from archai.common import utils, ml_utils
from archai.nas.model import Model
from archai.common.checkpoint import CheckPoint
from archai.common.common import logger, get_device
from archai.common.common import logger
class XnasArchTrainer(ArchTrainer):
@@ -39,7 +39,7 @@ class XnasArchTrainer(ArchTrainer):
# optimizers, schedulers need to be recreated for each fit call
# as they have state
assert val_dl is not None
lossfn = ml_utils.get_lossfn(self._conf_w_lossfn).to(get_device())
lossfn = ml_utils.get_lossfn(self._conf_w_lossfn).to(self.get_device())
self._xnas_optim = _XnasOptimizer(self._conf_alpha_optim, self.model, lossfn)
@@ -74,7 +74,7 @@ class XnasArchTrainer(ArchTrainer):
self._valid_iter = iter(self._val_dl)
x_val, y_val = next(self._valid_iter)
x_val, y_val = x_val.to(get_device()), y_val.to(get_device(), non_blocking=True)
x_val, y_val = x_val.to(self.get_device()), y_val.to(self.get_device(), non_blocking=True)
# update alphas
self._xnas_optim.step(x, y, x_val, y_val)

View file

@@ -15,23 +15,7 @@ from archai.common import ml_utils, utils
from archai.common.ordereddict_logger import OrderedDictLogger
class ApexUtils:
def __init__(self)->None:
self._amp = self._ddp = None
self._op_map = {'mean': dist.ReduceOp.SUM, 'sum': dist.ReduceOp.SUM,
'min': dist.ReduceOp.MIN, 'max': dist.ReduceOp.MAX}
self.gpu_ids = [] # use all gpus
self._mixed_prec_enabled = False
self._distributed_enabled = False
self._set_ranks()
def reset(self, logger:OrderedDictLogger, apex_config:Config)->None:
# reset allows reconfiguring for search or eval modes
# to avoid circular references with common, logger is passed from outside
self.logger = logger
def __init__(self, apex_config:Config, logger:Optional[OrderedDictLogger])->None:
# region conf vars
self._enabled = apex_config['enabled'] # global switch to disable anything apex
self._distributed_enabled = apex_config['distributed_enabled'] # enable/disable distributed mode
@@ -47,42 +31,53 @@ class ApexUtils:
conf_gpu_ids = apex_config['gpus']
# endregion
self.gpu_ids = [int(i) for i in conf_gpu_ids.split(',') if i]
self._amp, self._ddp = None, None
self._gpu = self.gpu_ids[0] if len(self.gpu_ids) else 0 # which GPU to use, we will use only 1 GPU
# to avoid circular references with common, logger is passed from outside
self.logger = logger
#logger.info({'apex_config': apex_config.to_dict()})
logger.info({'torch.distributed.is_available': dist.is_available()})
self._op_map = {'mean': dist.ReduceOp.SUM, 'sum': dist.ReduceOp.SUM,
'min': dist.ReduceOp.MIN, 'max': dist.ReduceOp.MAX}
# defaults for non-distributed mode
self._amp, self._ddp = None, None
self.world_size = 1
self.local_rank = self.global_rank = 0
self.gpu_ids = [int(i) for i in conf_gpu_ids.split(',') if i]
# which GPU to use, we will use only 1 GPU per process to avoid complications with apex
self._gpu = self.gpu_ids[0] if len(self.gpu_ids) else 0
#_log_info({'apex_config': apex_config.to_dict()})
self._log_info({'torch.distributed.is_available': dist.is_available()})
if dist.is_available():
logger.info({'gloo_available': dist.is_gloo_available(),
self._log_info({'gloo_available': dist.is_gloo_available(),
'mpi_available': dist.is_mpi_available(),
'nccl_available': dist.is_nccl_available()})
if self._enabled:
if self._mixed_prec_enabled:
# init enable mixed precision
assert cudnn.enabled, "Amp requires cudnn backend to be enabled."
from apex import amp
self._amp = amp
if self.is_mixed():
# init enable mixed precision
assert cudnn.enabled, "Amp requires cudnn backend to be enabled."
from apex import amp
self._amp = amp
# enable distributed processing
if self._distributed_enabled:
from apex import parallel
self._ddp = parallel
# enable distributed processing
if self.is_dist():
from apex import parallel
self._ddp = parallel
assert dist.is_available() # distributed module is available
assert dist.is_nccl_available()
if not dist.is_initialized():
dist.init_process_group(backend='nccl', init_method='env://')
assert dist.is_initialized()
assert dist.is_available() # distributed module is available
assert dist.is_nccl_available()
if not dist.is_initialized():
dist.init_process_group(backend='nccl', init_method='env://')
assert dist.is_initialized()
self._set_ranks()
assert dist.get_world_size() == self.world_size
assert dist.get_rank() == self.global_rank
else:
assert self.world_size == 1
assert self.local_rank == 0
assert self.global_rank == 0
self._set_ranks()
assert dist.get_world_size() == self.world_size
assert dist.get_rank() == self.global_rank
else:
assert self.world_size == 1
assert self.local_rank == 0
assert self.global_rank == 0
assert self.world_size >= 1
assert not self._min_world_size or self.world_size >= self._min_world_size
@@ -94,9 +89,9 @@ class ApexUtils:
self.device = torch.device('cuda', self._gpu)
self._setup_gpus(seed, detect_anomaly)
logger.info({'amp_available': self._amp is not None,
self._log_info({'amp_available': self._amp is not None,
'distributed_available': self._ddp is not None})
logger.info({'dist_initialized': dist.is_initialized() if dist.is_available() else False,
self._log_info({'dist_initialized': dist.is_initialized() if dist.is_available() else False,
'world_size': self.world_size,
'gpu': self._gpu, 'gpu_ids':self.gpu_ids,
'local_rank': self.local_rank})
@@ -106,10 +101,10 @@ class ApexUtils:
utils.setup_cuda(seed, self.local_rank)
torch.autograd.set_detect_anomaly(detect_anomaly)
self.logger.info({'set_detect_anomaly': detect_anomaly,
self._log_info({'set_detect_anomaly': detect_anomaly,
'is_anomaly_enabled': torch.is_anomaly_enabled()})
self.logger.info({'gpu_names': utils.cuda_device_names(),
self._log_info({'gpu_names': utils.cuda_device_names(),
'gpu_count': torch.cuda.device_count(),
'CUDA_VISIBLE_DEVICES': os.environ['CUDA_VISIBLE_DEVICES']
if 'CUDA_VISIBLE_DEVICES' in os.environ else 'NotSet',
@@ -118,8 +113,8 @@ class ApexUtils:
'cudnn.deterministic': cudnn.deterministic,
'cudnn.version': cudnn.version()
})
self.logger.info({'memory': str(psutil.virtual_memory())})
self.logger.info({'CPUs': str(psutil.cpu_count())})
self._log_info({'memory': str(psutil.virtual_memory())})
self._log_info({'CPUs': str(psutil.cpu_count())})
# gpu_usage = os.popen(
# 'nvidia-smi --query-gpu=memory.total,memory.used --format=csv,nounits,noheader'
@@ -127,7 +122,7 @@ class ApexUtils:
# for i, line in enumerate(gpu_usage):
# vals = line.split(',')
# if len(vals) == 2:
# logger.info('GPU {} mem: {}, used: {}'.format(i, vals[0], vals[1]))
# _log_info('GPU {} mem: {}, used: {}'.format(i, vals[0], vals[1]))
def _set_ranks(self)->None:
if 'WORLD_SIZE' in os.environ:
@@ -153,18 +148,22 @@ class ApexUtils:
self._gpu = self.gpu_ids[self.local_rank]
def is_mixed(self)->bool:
return self._mixed_prec_enabled
return self._enabled and self._mixed_prec_enabled
def is_dist(self)->bool:
return self._distributed_enabled
return self._enabled and self._distributed_enabled
def is_master(self)->bool:
return self.global_rank == 0
def _log_info(self, d:dict)->None:
if self.logger is not None:
self.logger.info(d)
def sync_devices(self)->None:
if self._distributed_enabled:
if self.is_dist():
torch.cuda.synchronize(self.device)
def reduce(self, val, op='mean'):
if self._distributed_enabled:
if self.is_dist():
if not isinstance(val, Tensor):
rt = torch.tensor(val).to(self.device)
converted = True
@@ -184,7 +183,7 @@ class ApexUtils:
return val
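Note that _op_map binds 'mean' to ReduceOp.SUM: NCCL has no mean op, so a mean is presumably computed as an all-reduce sum followed by division by world size. A sketch of that path as a standalone function (hypothetical, simplified from the method above):

import torch
import torch.distributed as dist

def reduce_mean(val, device: torch.device, world_size: int):
    # sum the value across all ranks, then divide by the rank count
    rt = val.detach().clone() if isinstance(val, torch.Tensor) \
         else torch.tensor(float(val), device=device)
    dist.all_reduce(rt, op=dist.ReduceOp.SUM)
    return rt / world_size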
def backward(self, loss:torch.Tensor, optim:Optimizer)->None:
if self._mixed_prec_enabled:
if self.is_mixed():
with self._amp.scale_loss(loss, optim) as scaled_loss:
scaled_loss.backward()
else:
@@ -193,26 +192,26 @@ class ApexUtils:
def to_amp(self, model:nn.Module, optim:Optimizer, batch_size:int)\
->Tuple[nn.Module, Optimizer]:
# convert BNs to sync BNs in distributed mode
if self._distributed_enabled and self._sync_bn:
if self.is_dist() and self._sync_bn:
model = self._ddp.convert_syncbn_model(model)
self.logger.info({'BNs_converted': True})
self._log_info({'BNs_converted': True})
model = model.to(self.device)
if self._mixed_prec_enabled:
if self.is_mixed():
# scale LR
if self._scale_lr:
lr = ml_utils.get_optim_lr(optim)
scaled_lr = lr * self.world_size / float(batch_size)
ml_utils.set_optim_lr(optim, scaled_lr)
self.logger.info({'lr_scaled': True, 'old_lr': lr, 'new_lr': scaled_lr})
self._log_info({'lr_scaled': True, 'old_lr': lr, 'new_lr': scaled_lr})
model, optim = self._amp.initialize(
model, optim, opt_level=self._opt_level,
keep_batchnorm_fp32=self._bn_fp32, loss_scale=self._loss_scale
)
if self._distributed_enabled:
if self.is_dist():
# By default, apex.parallel.DistributedDataParallel overlaps communication with
# computation in the backward pass.
# delay_allreduce delays all communication to the end of the backward pass.
@@ -222,19 +221,19 @@ class ApexUtils:
def clip_grad(self, clip:float, model:nn.Module, optim:Optimizer)->None:
if clip > 0.0:
if self._mixed_prec_enabled:
if self.is_mixed():
nn.utils.clip_grad_norm_(self._amp.master_params(optim), clip)
else:
nn.utils.clip_grad_norm_(model.parameters(), clip)
def state_dict(self):
if self._mixed_prec_enabled:
if self.is_mixed():
return self._amp.state_dict()
else:
return None
def load_state_dict(self, state_dict):
if self._mixed_prec_enabled:
if self.is_mixed():
self._amp.load_state_dict(state_dict)
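Taken together, the refactor turns ApexUtils into an injected dependency: construct it from config, wrap the model and optimizer once, then route backward, clipping and stepping through it. A minimal usage sketch (model, optim, lossfn and train_dl are assumed to come from the caller):

def train_epoch(apex, model, optim, lossfn, train_dl, grad_clip: float = 5.0):
    # to_amp is largely a pass-through when apex is disabled, apart from
    # moving the model to apex.device, so one loop serves all modes
    model, optim = apex.to_amp(model, optim, batch_size=train_dl.batch_size)
    for x, y in train_dl:
        x = x.to(apex.device, non_blocking=True)
        y = y.to(apex.device, non_blocking=True)
        loss = lossfn(model(x), y)
        optim.zero_grad()
        apex.backward(loss, optim)      # amp-scaled backward when is_mixed()
        apex.clip_grad(grad_clip, model, optim)
        optim.step()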

View file

@@ -30,42 +30,11 @@ class SummaryWriterDummy:
SummaryWriterAny = Union[SummaryWriterDummy, SummaryWriter]
logger = OrderedDictLogger(None, None)
_tb_writer: SummaryWriterAny = None
_apex_utils = ApexUtils()
_atexit_reg = False # is hook for atexit registered?
def get_conf()->Config:
return Config.get()
def get_device():
global _apex_utils
return _apex_utils.device
def get_apex_utils()->ApexUtils:
global _apex_utils
assert _apex_utils
return _apex_utils
def is_dist()->bool:
global _apex_utils
return _apex_utils.is_dist()
def reduce_min(val):
global _apex_utils
return _apex_utils.reduce(val, op='min')
def reduce_max(val):
global _apex_utils
return _apex_utils.reduce(val, op='max')
def reduce_sum(val):
global _apex_utils
return _apex_utils.reduce(val, op='sum')
def reduce_mean(val):
global _apex_utils
return _apex_utils.reduce(val, op='mean')
def is_dist()->bool:
global _apex_utils
return _apex_utils.is_dist()
def get_conf_common()->Config:
return get_conf()['common']
@@ -138,14 +107,18 @@ def common_init(config_filepath: Optional[str]=None,
logger.info({'expdir': expdir,
'PT_DATA_DIR': pt_data_dir, 'PT_OUTPUT_DIR': pt_output_dir})
# create apex to know distributed processing parameters
conf_apex = get_conf_common()['apex']
apex = ApexUtils(conf_apex, None)
# create global logger
_setup_logger()
_setup_logger(apex)
# create info file for current system
_create_sysinfo(conf)
# setup tensorboard
global _tb_writer
_tb_writer = _create_tb_writer(get_apex_utils().is_master())
_tb_writer = _create_tb_writer(apex.is_master())
# create hooks to execute code when script exits
global _atexit_reg
@@ -216,17 +189,18 @@ def _setup_dirs()->Optional[str]:
os.environ['distdir'] = conf_common['distdir'] = distdir
def _setup_logger():
def _setup_logger(apex:ApexUtils):
global logger
logger.close() # close any previous instances
conf_common = get_conf_common()
expdir = conf_common['expdir']
distdir = conf_common['distdir']
global_rank = get_apex_utils().global_rank
global_rank = apex.global_rank
# file where logger would log messages
if get_apex_utils().is_master():
if apex.is_master():
sys_log_filepath = utils.full_path(os.path.join(expdir, 'logs.log'))
logs_yaml_filepath = utils.full_path(os.path.join(expdir, 'logs.yaml'))
experiment_name = get_experiment_name()

View file

@@ -11,7 +11,8 @@ from torch import Tensor
import yaml
from . import utils, ml_utils
from .common import logger, get_tb_writer, is_dist, reduce_mean, reduce_sum, reduce_min, reduce_max
from .common import logger, get_tb_writer
from .apex_utils import ApexUtils
class Metrics:
@@ -29,7 +30,7 @@ class Metrics:
best we have seen for each epoch.
"""
def __init__(self, title:str, logger_freq:int=50) -> None:
def __init__(self, title:str, apex:Optional[ApexUtils], logger_freq:int=50) -> None:
"""Create the metrics object to maintain epoch stats
Arguments:
@@ -39,6 +40,7 @@
"""
self.logger_freq = logger_freq
self.title = title
self._apex = apex
self._reset_run()
def _reset_run(self)->None:
@@ -59,27 +61,27 @@
logger.info({'epoch':self.run_metrics.epoch_time_avg(),
'step': self.run_metrics.step_time_avg(),
'run': self.run_metrics.duration()})
if is_dist():
logger.info({'dist_epoch_sum': reduce_sum(self.run_metrics.epoch_time_avg()),
'dist_step': reduce_mean(self.run_metrics.step_time_avg()),
'dist_run_sum': reduce_sum(self.run_metrics.duration())})
if self.is_dist():
logger.info({'dist_epoch_sum': self.reduce_sum(self.run_metrics.epoch_time_avg()),
'dist_step': self.reduce_mean(self.run_metrics.step_time_avg()),
'dist_run_sum': self.reduce_sum(self.run_metrics.duration())})
best_train, best_val = self.run_metrics.best_epoch()
with logger.pushd('best_train'):
logger.info({'epoch': best_train.index,
'top1': best_train.top1.avg})
if is_dist():
logger.info({'dist_epoch': reduce_mean(best_train.index),
'dist_top1': reduce_mean(best_train.top1.avg)})
if self.is_dist():
logger.info({'dist_epoch': self.reduce_mean(best_train.index),
'dist_top1': self.reduce_mean(best_train.top1.avg)})
if best_val:
with logger.pushd('best_val'):
logger.info({'epoch': best_val.index,
'top1': best_val.val_metrics.top1.avg})
if is_dist():
logger.info({'dist_epoch': reduce_mean(best_val.index),
'dist_top1': reduce_mean(best_val.val_metrics.top1.avg)})
if self.is_dist():
logger.info({'dist_epoch': self.reduce_mean(best_val.index),
'dist_top1': self.reduce_mean(best_val.val_metrics.top1.avg)})
def pre_step(self, x: Tensor, y: Tensor):
self.run_metrics.cur_epoch().pre_step()
@@ -102,11 +104,11 @@ class Metrics:
'loss': epoch.loss.avg,
'step_time': epoch.step_time.last})
if is_dist():
logger.info({'dist_top1': reduce_mean(epoch.top1.avg),
'dist_top5': reduce_mean(epoch.top5.avg),
'dist_loss': reduce_mean(epoch.loss.avg),
'dist_step_time': reduce_mean(epoch.step_time.last)})
if self.is_dist():
logger.info({'dist_top1': self.reduce_mean(epoch.top1.avg),
'dist_top5': self.reduce_mean(epoch.top5.avg),
'dist_loss': self.reduce_mean(epoch.loss.avg),
'dist_step_time': self.reduce_mean(epoch.step_time.last)})
# NOTE: Tensorboard step-level logging is removed as it becomes exponentially expensive on Azure blobs
@@ -143,24 +145,24 @@ class Metrics:
'duration': epoch.duration(),
'step_time': epoch.step_time.avg,
'end_lr': lr})
if is_dist():
logger.info({'dist_top1': reduce_mean(epoch.top1.avg),
'dist_top5': reduce_mean(epoch.top5.avg),
'dist_loss': reduce_mean(epoch.loss.avg),
'dist_duration': reduce_mean(epoch.duration()),
'dist_step_time': reduce_mean(epoch.step_time.avg),
'dist_end_lr': reduce_mean(lr)})
if self.is_dist():
logger.info({'dist_top1': self.reduce_mean(epoch.top1.avg),
'dist_top5': self.reduce_mean(epoch.top5.avg),
'dist_loss': self.reduce_mean(epoch.loss.avg),
'dist_duration': self.reduce_mean(epoch.duration()),
'dist_step_time': self.reduce_mean(epoch.step_time.avg),
'dist_end_lr': self.reduce_mean(lr)})
if test_epoch:
with logger.pushd('val'):
logger.info({'top1': test_epoch.top1.avg,
'top5': test_epoch.top5.avg,
'loss': test_epoch.loss.avg,
'duration': epoch.duration()})
if is_dist():
logger.info({'dist_top1': reduce_mean(test_epoch.top1.avg),
'dist_top5': reduce_mean(test_epoch.top5.avg),
'dist_loss': reduce_mean(test_epoch.loss.avg),
'dist_duration': reduce_mean(test_epoch.duration())})
if self.is_dist():
logger.info({'dist_top1': self.reduce_mean(test_epoch.top1.avg),
'dist_top5': self.reduce_mean(test_epoch.top5.avg),
'dist_loss': self.reduce_mean(test_epoch.loss.avg),
'dist_duration': self.reduce_mean(test_epoch.duration())})
# writer = get_tb_writer()
# writer.add_scalar(f'{self._tb_path}/train_epochs/loss',
@@ -181,9 +183,14 @@ class Metrics:
return utils.state_dict(self)
def load_state_dict(self, state_dict:dict)->None:
# restore current object from the given state dict
utils.load_state_dict(self, state_dict)
def __getstate__(self):
state = self.__dict__.copy()
del state['_apex'] # cannot serialize this
return state
# no need to define __setstate__ because _apex should be set from constructor
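A quick illustration of what __getstate__ buys here (a hypothetical stand-in class; Metrics behaves the same way with its _apex handle):

import pickle

class Holder:
    def __init__(self, apex):
        self._apex = apex       # process-local handle, not serializable
        self.top1 = 0.97
    def __getstate__(self):
        state = self.__dict__.copy()
        del state['_apex']      # keep it out of any pickle/yaml payload
        return state

restored = pickle.loads(pickle.dumps(Holder(apex=object())))
assert restored.top1 == 0.97 and not hasattr(restored, '_apex')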
def save(self, filepath:str)->Optional[str]:
if filepath:
filepath = utils.full_path(filepath)
@@ -197,6 +204,27 @@ class Metrics:
def cur_epoch(self)->'EpochMetrics':
return self.run_metrics.cur_epoch()
def reduce_min(self, val):
if not self._apex:
return val
return self._apex.reduce(val, op='min')
def reduce_max(self, val):
if not self._apex:
return val
return self._apex.reduce(val, op='max')
def reduce_sum(self, val):
if not self._apex:
return val
return self._apex.reduce(val, op='sum')
def reduce_mean(self, val):
if not self._apex:
return val
return self._apex.reduce(val, op='mean')
def is_dist(self)->bool:
if not self._apex:
return False
return self._apex.is_dist()
class Accumulator:
# TODO: replace this with Metrics class

View file

@@ -9,18 +9,18 @@ from overrides import EnforceOverrides
from .metrics import Metrics
from .config import Config
from . import utils, ml_utils
from .common import logger, get_device
from archai.common.common import get_apex_utils
from .common import logger
from archai.common.apex_utils import ApexUtils
class Tester(EnforceOverrides):
def __init__(self, conf_eval:Config, model:nn.Module)->None:
# TODO: currently we expect that the given model and dataloader are already distributed
def __init__(self, conf_eval:Config, model:nn.Module, apex:ApexUtils)->None:
self._title = conf_eval['title']
self._logger_freq = conf_eval['logger_freq']
conf_lossfn = conf_eval['lossfn']
self._apex = apex
self.model = model
self._lossfn = ml_utils.get_lossfn(conf_lossfn).to(get_device())
self._lossfn = ml_utils.get_lossfn(conf_lossfn).to(apex.device)
self._metrics = None
def test(self, test_dl: DataLoader)->Metrics:
@@ -43,7 +43,7 @@ class Tester(EnforceOverrides):
with torch.no_grad(), logger.pushd('steps'):
for step, (x, y) in enumerate(test_dl):
x, y = x.to(get_device(), non_blocking=True), y.to(get_device(), non_blocking=True)
x, y = x.to(self._apex.device, non_blocking=True), y.to(self._apex.device, non_blocking=True)
assert not self.model.training # derived class might alter the mode
logger.pushd(step)
@@ -57,7 +57,7 @@ class Tester(EnforceOverrides):
self._post_step(x, y, logits, loss, steps, self._metrics)
# TODO: we possibly need to sync so all replicas are up to date
get_apex_utils().sync_devices()
self._apex.sync_devices()
logger.popd()
self._metrics.post_epoch(None)
@@ -87,5 +87,5 @@ class Tester(EnforceOverrides):
metrics.post_step(x, y, logits, loss, steps)
def _create_metrics(self)->Metrics:
return Metrics(self._title, logger_freq=self._logger_freq)
return Metrics(self._title, self._apex, logger_freq=self._logger_freq)

View file

@@ -11,8 +11,9 @@ from .metrics import Metrics
from .tester import Tester
from .config import Config
from . import utils, ml_utils
from ..common.common import logger, get_device, get_apex_utils
from ..common.common import logger
from ..common.checkpoint import CheckPoint
from ..common.apex_utils import ApexUtils
class Trainer(EnforceOverrides):
@@ -33,13 +34,13 @@ class Trainer(EnforceOverrides):
self._validation_freq = 0 if conf_validation is None else conf_validation['freq']
# endregion
get_apex_utils().reset(logger, conf_apex)
self._apex = ApexUtils(conf_apex, logger)
self._checkpoint = checkpoint
self.model = model
self._lossfn = ml_utils.get_lossfn(conf_lossfn)
self._tester = Tester(conf_validation, model) \
self._tester = Tester(conf_validation, model, self._apex) \
if conf_validation else None
self._metrics:Optional[Metrics] = None
@@ -52,7 +53,7 @@ class Trainer(EnforceOverrides):
def fit(self, train_dl:DataLoader, val_dl:Optional[DataLoader])->Metrics:
logger.pushd(self._title)
self._metrics = Metrics(self._title, logger_freq=self._logger_freq)
self._metrics = Metrics(self._title, self._apex, logger_freq=self._logger_freq)
# optimizers, schedulers need to be recreated for each fit call
# as they have state specific to each run
@@ -60,10 +61,10 @@ class Trainer(EnforceOverrides):
# create scheduler for optim before applying amp
self._sched, self._sched_on_epoch = self._create_scheduler(optim, len(train_dl))
# before checkpoint restore, convert to amp
self.model, self._optim = get_apex_utils().to_amp(self.model, optim,
self.model, self._optim = self._apex.to_amp(self.model, optim,
batch_size=train_dl.batch_size)
self._lossfn = self._lossfn.to(get_device())
self._lossfn = self._lossfn.to(self.get_device())
self.pre_fit(train_dl, val_dl)
@@ -158,7 +159,7 @@ class Trainer(EnforceOverrides):
self._metrics.post_epoch(val_metrics, lr=self._optim.param_groups[0]['lr'])
# checkpoint if enabled with given freq or if this is the last epoch
if self._checkpoint is not None and get_apex_utils().is_master() and \
if self._checkpoint is not None and self._apex.is_master() and \
self._checkpoint.freq > 0 and (self._metrics.epochs() % self._checkpoint.freq == 0 or \
self._metrics.epochs() >= self._epochs):
self._checkpoint.new()
@@ -173,6 +174,9 @@ class Trainer(EnforceOverrides):
self._metrics.post_step(x, y, logits, loss, steps)
######################### hooks #########################
def get_device(self):
return self._apex.device
def restore_checkpoint(self)->None:
state = self._checkpoint['trainer']
last_epoch = state['last_epoch']
@@ -180,7 +184,7 @@ class Trainer(EnforceOverrides):
self._metrics.load_state_dict(state['metrics'])
assert self._metrics.epochs() == last_epoch+1
get_apex_utils().load_state_dict(state['amp'])
self._apex.load_state_dict(state['amp'])
self.model.load_state_dict(state['model'])
self._optim.load_state_dict(state['optim'])
if self._sched:
@@ -198,7 +202,7 @@ class Trainer(EnforceOverrides):
'model': self.model.state_dict(),
'optim': self._optim.state_dict(),
'sched': self._sched.state_dict() if self._sched else None,
'amp': get_apex_utils().state_dict()
'amp': self._apex.state_dict()
}
self._checkpoint['trainer'] = state
@@ -208,7 +212,7 @@ class Trainer(EnforceOverrides):
logger.pushd('steps')
for step, (x, y) in enumerate(train_dl):
x, y = x.to(get_device(), non_blocking=True), y.to(get_device(), non_blocking=True)
x, y = x.to(self.get_device(), non_blocking=True), y.to(self.get_device(), non_blocking=True)
logger.pushd(step)
assert self.model.training # derived class might alter the mode
@@ -226,15 +230,15 @@ class Trainer(EnforceOverrides):
loss = self.compute_loss(self._lossfn, x, y, logits,
self._aux_weight, aux_logits)
get_apex_utils().backward(loss, self._optim)
self._apex.backward(loss, self._optim)
# TODO: original darts clips alphas as well but pt.darts doesn't
get_apex_utils().clip_grad(self._grad_clip, self.model, self._optim)
self._apex.clip_grad(self._grad_clip, self.model, self._optim)
self._optim.step()
# TODO: we possibly need to sync so all replicas are up to date
get_apex_utils().sync_devices()
self._apex.sync_devices()
if self._sched and not self._sched_on_epoch:
self._sched.step()

View file

@@ -55,7 +55,7 @@ def deep_update(d:MutableMapping, u:Mapping, map_type:Type[MutableMapping]=dict)
return d
def state_dict(val)->Mapping:
assert hasattr(val, '__dict__'), 'val must be object with __dict__'
assert hasattr(val, '__dict__'), 'val must be object with __dict__ otherwise it cannot be loaded back in load_state_dict'
# Can't do below because val has state_dict() which calls utils.state_dict
# if has_method(val, 'state_dict'):
@@ -76,8 +76,8 @@ def load_state_dict(val:Any, state_dict:Mapping)->None:
assert s is not None, 'state_dict must contain yaml key'
obj = yaml.load(s, Loader=yaml.Loader)
for k in val.__dict__.keys():
setattr(val, k, getattr(obj, k))
for k, v in obj.__dict__.items():
setattr(val, k, v)
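This change pairs with Metrics.__getstate__ above: the loader now iterates the saved object's __dict__, so attributes dropped at save time (such as _apex) no longer cause a getattr failure on restore. A minimal sketch of the two helpers, assuming state is stored as a single yaml payload as the assert above implies:

import yaml

def state_dict(val) -> dict:
    # serialize the whole object as one yaml document
    return {'yaml': yaml.dump(val)}

def load_state_dict(val, sd: dict) -> None:
    obj = yaml.load(sd['yaml'], Loader=yaml.Loader)
    # copy only what was saved; attributes absent from the saved object
    # (e.g. removed by __getstate__) are left untouched on val
    for k, v in obj.__dict__.items():
        setattr(val, k, v)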
def deep_comp(o1:Any, o2:Any)->bool:
# NOTE: dict don't have __dict__

View file

@@ -19,8 +19,8 @@ common:
redis: null
apex: # this is overridden in search and eval individually
enabled: False # global switch to disable anything apex
distributed_enabled: False # enable/disable distributed mode
mixed_prec_enabled: False # switch to disable amp mixed precision
distributed_enabled: True # enable/disable distributed mode
mixed_prec_enabled: True # switch to disable amp mixed precision
gpus: '' # use GPU IDs specified here (comma separated), if '' then use all GPUs
opt_level: 'O2' # optimization level for mixed precision
bn_fp32: True # keep BN in fp32
@@ -88,6 +88,7 @@ nas:
trainer:
apex:
_copy: 'common/apex'
enabled: True
aux_weight: '_copy: nas/eval/model_desc/aux_weight'
drop_path_prob: 0.2 # probability that given edge will be dropped
grad_clip: 5.0 # grads above this value are clipped
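The sub-flags now default to True while the master switch common/apex/enabled stays False; together with the is_mixed()/is_dist() change above, nothing activates until a section such as nas/eval/trainer/apex overrides enabled. A sketch of the effective logic (names taken from the code above):

enabled = False             # common/apex: master switch, off by default
mixed_prec_enabled = True   # sub-flags may default on...
distributed_enabled = True
is_mixed = enabled and mixed_prec_enabled    # ...yet both resolve to False
is_dist = enabled and distributed_enabled    # until 'enabled' is overridden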

View file

@@ -1,10 +1,27 @@
from collections import UserDict
import yaml
from typing import Iterator
y = """
a: .NaN
class A(object):
def __init__(self):
self.hidden = 42
self.visible = 5
"""
def __getstate__(self):
state = self.__dict__.copy()
del state['hidden'] # cannot serialize this
return state
d=yaml.load(y, Loader=yaml.Loader)
a = A()
d = yaml.dump(a)
print(d)
print(type( d['a']))
# y = """
# a: .NaN
# """
# d=yaml.load(y, Loader=yaml.Loader)
# print(d)
# print(type( d['a']))

View file

@@ -5,7 +5,7 @@ from archai import cifar10_models
from archai.common.trainer import Trainer
from archai.common.config import Config
from archai.common.common import logger, common_init, get_device
from archai.common.common import logger, common_init
from archai.datasets import data
def train_test(conf_eval:Config):
@@ -24,7 +24,7 @@ def train_test(conf_eval:Config):
conf_trainer['aux_weight'] = 0.0
Net = cifar10_models.resnet34
model = Net().to(get_device())
model = Net().to(torch.device('cuda', 0))
# get data
train_dl, _, test_dl = data.get_data(conf_loader)
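The replacement above hard-codes cuda:0 for this test script. A slightly more defensive variant (an assumption, not what the commit does) would fall back to CPU when CUDA is absent:

import torch

# pick GPU 0 when available, otherwise run on CPU
device = torch.device('cuda', 0) if torch.cuda.is_available() else torch.device('cpu')
model = Net().to(device)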