Mirror of https://github.com/microsoft/archai.git

Move apex to trainer

Parent: 0235072c8e
Commit: d861b5f2a6
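At a glance, this commit retires the module-level ApexUtils singleton in archai.common.common (get_apex_utils(), get_device() and the free reduce_* helpers) and makes Trainer own an ApexUtils instance, passing it explicitly to Tester and Metrics; arch trainers reach the device through self.get_device(). Below is a condensed, hypothetical sketch of the resulting wiring, not the repo's actual classes; the config keys marked in comments are assumptions based on the hunks that follow.

from archai.common.apex_utils import ApexUtils
from archai.common.common import logger
from archai.common.metrics import Metrics
from archai.common.tester import Tester

class TrainerSketch:
    """Hypothetical condensation of what archai's Trainer now does."""
    def __init__(self, conf_train, model, checkpoint=None) -> None:
        conf_apex = conf_train['apex']                      # assumed key, per the 'trainer: apex:' config section
        self._apex = ApexUtils(conf_apex, logger)           # apex now lives on the trainer
        self._tester = Tester(conf_train['validation'], model, self._apex)   # 'validation' key assumed
        self.model = model

    def get_device(self):
        return self._apex.device                            # replaces the old common.get_device()

    def fit(self, train_dl, val_dl):
        metrics = Metrics('train', self._apex)              # Metrics receives apex explicitly
        # ... training loop elided; see the trainer.py hunks below
        return metrics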
@@ -15,7 +15,7 @@ from archai.nas.arch_trainer import ArchTrainer
 from archai.common import utils, ml_utils
 from archai.nas.model import Model
 from archai.common.checkpoint import CheckPoint
-from archai.common.common import logger, get_device
+from archai.common.common import logger
 from .bilevel_optimizer import BilevelOptimizer

 class BilevelArchTrainer(ArchTrainer):
@@ -36,10 +36,11 @@ class BilevelArchTrainer(ArchTrainer):
         assert val_dl is not None
         w_momentum = self._conf_w_optim['momentum']
         w_decay = self._conf_w_optim['decay']
-        lossfn = ml_utils.get_lossfn(self._conf_w_lossfn).to(get_device())
+        lossfn = ml_utils.get_lossfn(self._conf_w_lossfn).to(self.get_device())

         self._bilevel_optim = BilevelOptimizer(self._conf_alpha_optim, w_momentum,
-                                               w_decay, self.model, lossfn)
+                                               w_decay, self.model, lossfn,
+                                               self.get_device())

     @overrides
     def post_fit(self, train_dl:DataLoader, val_dl:Optional[DataLoader])->None:
@@ -71,7 +72,7 @@ class BilevelArchTrainer(ArchTrainer):
             self._valid_iter = iter(self._val_dl)
             x_val, y_val = next(self._valid_iter)

-        x_val, y_val = x_val.to(get_device()), y_val.to(get_device(), non_blocking=True)
+        x_val, y_val = x_val.to(self.get_device()), y_val.to(self.get_device(), non_blocking=True)

         # update alphas
         self._bilevel_optim.step(x, y, x_val, y_val, super().get_optimizer())
@@ -9,11 +9,11 @@ from torch.optim.optimizer import Optimizer
 from archai.common.config import Config
 from archai.common import utils, ml_utils
 from archai.nas.model import Model
-from archai.common.common import logger, get_device
+from archai.common.common import logger

 class BilevelOptimizer:
     def __init__(self, conf_alpha_optim:Config, w_momentum: float, w_decay: float,
-                 model: Model, lossfn: _Loss) -> None:
+                 model: Model, lossfn: _Loss, device) -> None:
         self._w_momentum = w_momentum # momentum for w
         self._w_weight_decay = w_decay # weight decay for w
         self._lossfn = lossfn
@@ -22,7 +22,7 @@ class BilevelOptimizer:
         # create a copy of model which we will use
         # to compute grads for alphas without disturbing
         # original weights
-        self._vmodel = copy.deepcopy(model).to(get_device())
+        self._vmodel = copy.deepcopy(model).to(device)
         # this is the optimizer to optimize alphas parameter
         self._alpha_optim = ml_utils.create_optimizer(conf_alpha_optim, model.alphas())

@@ -15,7 +15,7 @@ from archai.nas.arch_trainer import ArchTrainer
 from archai.common import utils, ml_utils
 from archai.nas.model import Model
 from archai.common.checkpoint import CheckPoint
-from archai.common.common import logger, get_device
+from archai.common.common import logger


 class XnasArchTrainer(ArchTrainer):
@@ -39,7 +39,7 @@ class XnasArchTrainer(ArchTrainer):
         # optimizers, schedulers needs to be recreated for each fit call
         # as they have state
         assert val_dl is not None
-        lossfn = ml_utils.get_lossfn(self._conf_w_lossfn).to(get_device())
+        lossfn = ml_utils.get_lossfn(self._conf_w_lossfn).to(self.get_device())

         self._xnas_optim = _XnasOptimizer(self._conf_alpha_optim, self.model, lossfn)

@@ -74,7 +74,7 @@ class XnasArchTrainer(ArchTrainer):
             self._valid_iter = iter(self._val_dl)
             x_val, y_val = next(self._valid_iter)

-        x_val, y_val = x_val.to(get_device()), y_val.to(get_device(), non_blocking=True)
+        x_val, y_val = x_val.to(self.get_device()), y_val.to(self.get_device(), non_blocking=True)

         # update alphas
         self._xnas_optim.step(x, y, x_val, y_val)
@@ -15,23 +15,7 @@ from archai.common import ml_utils, utils
 from archai.common.ordereddict_logger import OrderedDictLogger

 class ApexUtils:
-    def __init__(self)->None:
-        self._amp = self._ddp = None
-        self._op_map = {'mean': dist.ReduceOp.SUM, 'sum': dist.ReduceOp.SUM,
-                        'min': dist.ReduceOp.MIN, 'max': dist.ReduceOp.MAX}
-        self.gpu_ids = [] # use all gpus
-
-        self._mixed_prec_enabled = False
-        self._distributed_enabled = False
-
-        self._set_ranks()
-
-    def reset(self, logger:OrderedDictLogger, apex_config:Config)->None:
-        # reset allows to configure differently for search or eval modes
-
-        # to avoid circular references= with common, logger is passed from outside
-        self.logger = logger
-
+    def __init__(self, apex_config:Config, logger:Optional[OrderedDictLogger])->None:
         # region conf vars
         self._enabled = apex_config['enabled'] # global switch to disable anything apex
         self._distributed_enabled = apex_config['distributed_enabled'] # enable/disable distributed mode
@@ -47,42 +31,53 @@ class ApexUtils:
         conf_gpu_ids = apex_config['gpus']
         # endregion

-        self.gpu_ids = [int(i) for i in conf_gpu_ids.split(',') if i]
-        self._amp, self._ddp = None, None
-        self._gpu = self.gpu_ids[0] if len(self.gpu_ids) else 0 # which GPU to use, we will use only 1 GPU
+        # to avoid circular references= with common, logger is passed from outside
+        self.logger = logger

-        #logger.info({'apex_config': apex_config.to_dict()})
-        logger.info({'torch.distributed.is_available': dist.is_available()})
+        self._op_map = {'mean': dist.ReduceOp.SUM, 'sum': dist.ReduceOp.SUM,
+                        'min': dist.ReduceOp.MIN, 'max': dist.ReduceOp.MAX}
+
+        # defaults for non-distributed mode
+        self._amp, self._ddp = None, None
+        self.world_size = 1
+        self.local_rank = self.global_rank = 0
+
+        self.gpu_ids = [int(i) for i in conf_gpu_ids.split(',') if i]
+
+        # which GPU to use, we will use only 1 GPU per process to avoid complications with apex
+        self._gpu = self.gpu_ids[0] if len(self.gpu_ids) else 0
+
+        #_log_info({'apex_config': apex_config.to_dict()})
+        self._log_info({'torch.distributed.is_available': dist.is_available()})
         if dist.is_available():
-            logger.info({'gloo_available': dist.is_gloo_available(),
+            self._log_info({'gloo_available': dist.is_gloo_available(),
                          'mpi_available': dist.is_mpi_available(),
                          'nccl_available': dist.is_nccl_available()})

-        if self._enabled:
-            if self._mixed_prec_enabled:
-                # init enable mixed precision
-                assert cudnn.enabled, "Amp requires cudnn backend to be enabled."
-                from apex import amp
-                self._amp = amp
+        if self.is_mixed():
+            # init enable mixed precision
+            assert cudnn.enabled, "Amp requires cudnn backend to be enabled."
+            from apex import amp
+            self._amp = amp

-            # enable distributed processing
-            if self._distributed_enabled:
-                from apex import parallel
-                self._ddp = parallel
+        # enable distributed processing
+        if self.is_dist():
+            from apex import parallel
+            self._ddp = parallel

-                assert dist.is_available() # distributed module is available
-                assert dist.is_nccl_available()
-                if not dist.is_initialized():
-                    dist.init_process_group(backend='nccl', init_method='env://')
-                assert dist.is_initialized()
+            assert dist.is_available() # distributed module is available
+            assert dist.is_nccl_available()
+            if not dist.is_initialized():
+                dist.init_process_group(backend='nccl', init_method='env://')
+            assert dist.is_initialized()

-                self._set_ranks()
-                assert dist.get_world_size() == self.world_size
-                assert dist.get_rank() == self.global_rank
-            else:
-                assert self.world_size == 1
-                assert self.local_rank == 0
-                assert self.global_rank == 0
+            self._set_ranks()
+            assert dist.get_world_size() == self.world_size
+            assert dist.get_rank() == self.global_rank
+        else:
+            assert self.world_size == 1
+            assert self.local_rank == 0
+            assert self.global_rank == 0

         assert self.world_size >= 1
         assert not self._min_world_size or self.world_size >= self._min_world_size
@@ -94,9 +89,9 @@ class ApexUtils:
         self.device = torch.device('cuda', self._gpu)
         self._setup_gpus(seed, detect_anomaly)

-        logger.info({'amp_available': self._amp is not None,
+        self._log_info({'amp_available': self._amp is not None,
                      'distributed_available': self._ddp is not None})
-        logger.info({'dist_initialized': dist.is_initialized() if dist.is_available() else False,
+        self._log_info({'dist_initialized': dist.is_initialized() if dist.is_available() else False,
                      'world_size': self.world_size,
                      'gpu': self._gpu, 'gpu_ids':self.gpu_ids,
                      'local_rank': self.local_rank})
@@ -106,10 +101,10 @@ class ApexUtils:
         utils.setup_cuda(seed, self.local_rank)

         torch.autograd.set_detect_anomaly(detect_anomaly)
-        self.logger.info({'set_detect_anomaly': detect_anomaly,
+        self._log_info({'set_detect_anomaly': detect_anomaly,
                           'is_anomaly_enabled': torch.is_anomaly_enabled()})

-        self.logger.info({'gpu_names': utils.cuda_device_names(),
+        self._log_info({'gpu_names': utils.cuda_device_names(),
                           'gpu_count': torch.cuda.device_count(),
                           'CUDA_VISIBLE_DEVICES': os.environ['CUDA_VISIBLE_DEVICES']
                               if 'CUDA_VISIBLE_DEVICES' in os.environ else 'NotSet',
@@ -118,8 +113,8 @@ class ApexUtils:
                           'cudnn.deterministic': cudnn.deterministic,
                           'cudnn.version': cudnn.version()
                           })
-        self.logger.info({'memory': str(psutil.virtual_memory())})
-        self.logger.info({'CPUs': str(psutil.cpu_count())})
+        self._log_info({'memory': str(psutil.virtual_memory())})
+        self._log_info({'CPUs': str(psutil.cpu_count())})

         # gpu_usage = os.popen(
         #     'nvidia-smi --query-gpu=memory.total,memory.used --format=csv,nounits,noheader'
@@ -127,7 +122,7 @@ class ApexUtils:
         # for i, line in enumerate(gpu_usage):
         #     vals = line.split(',')
         #     if len(vals) == 2:
-        #         logger.info('GPU {} mem: {}, used: {}'.format(i, vals[0], vals[1]))
+        #         _log_info('GPU {} mem: {}, used: {}'.format(i, vals[0], vals[1]))

     def _set_ranks(self)->None:
         if 'WORLD_SIZE' in os.environ:
@@ -153,18 +148,22 @@ class ApexUtils:
             self._gpu = self.gpu_ids[self.local_rank]

     def is_mixed(self)->bool:
-        return self._mixed_prec_enabled
+        return self._enabled and self._mixed_prec_enabled
     def is_dist(self)->bool:
-        return self._distributed_enabled
+        return self._enabled and self._distributed_enabled
     def is_master(self)->bool:
         return self.global_rank == 0

+    def _log_info(self, d:dict)->None:
+        if self.logger is not None:
+            self.logger.info(d)
+
     def sync_devices(self)->None:
-        if self._distributed_enabled:
+        if self.is_dist():
             torch.cuda.synchronize(self.device)

     def reduce(self, val, op='mean'):
-        if self._distributed_enabled:
+        if self.is_dist():
             if not isinstance(val, Tensor):
                 rt = torch.tensor(val).to(self.device)
                 converted = True
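The helpers above are what make the rest of the class tolerant of the "globally disabled" and "no logger yet" cases: is_mixed()/is_dist() now fold in the enabled switch, and _log_info() accepts a missing logger (as happens when common_init builds a temporary ApexUtils). A minimal, self-contained illustration of the same guard pattern — a hypothetical class, not the repo's:

from typing import Optional

class _GuardSketch:
    """Illustrates the guard pattern ApexUtils uses after this commit."""
    def __init__(self, enabled: bool, mixed: bool, distributed: bool,
                 logger: Optional[object] = None) -> None:
        self._enabled, self._mixed_prec_enabled = enabled, mixed
        self._distributed_enabled, self.logger = distributed, logger

    def is_mixed(self) -> bool:
        # the global switch gates the feature flag
        return self._enabled and self._mixed_prec_enabled

    def is_dist(self) -> bool:
        return self._enabled and self._distributed_enabled

    def _log_info(self, d: dict) -> None:
        if self.logger is not None:    # logger may legitimately be None
            self.logger.info(d)

g = _GuardSketch(enabled=False, mixed=True, distributed=True)
assert not g.is_mixed() and not g.is_dist()   # everything stays off while 'enabled' is False
g._log_info({'noop': True})                    # safe even without a logger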
@@ -184,7 +183,7 @@ class ApexUtils:
            return val

     def backward(self, loss:torch.Tensor, optim:Optimizer)->None:
-        if self._mixed_prec_enabled:
+        if self.is_mixed():
             with self._amp.scale_loss(loss, optim) as scaled_loss:
                 scaled_loss.backward()
         else:
@@ -193,26 +192,26 @@ class ApexUtils:
     def to_amp(self, model:nn.Module, optim:Optimizer, batch_size:int)\
             ->Tuple[nn.Module, Optimizer]:
         # conver BNs to sync BNs in distributed mode
-        if self._distributed_enabled and self._sync_bn:
+        if self.is_dist() and self._sync_bn:
             model = self._ddp.convert_syncbn_model(model)
-            self.logger.info({'BNs_converted': True})
+            self._log_info({'BNs_converted': True})

         model = model.to(self.device)

-        if self._mixed_prec_enabled:
+        if self.is_mixed():
             # scale LR
             if self._scale_lr:
                 lr = ml_utils.get_optim_lr(optim)
                 scaled_lr = lr * self.world_size / float(batch_size)
                 ml_utils.set_optim_lr(optim, scaled_lr)
-                self.logger.info({'lr_scaled': True, 'old_lr': lr, 'new_lr': scaled_lr})
+                self._log_info({'lr_scaled': True, 'old_lr': lr, 'new_lr': scaled_lr})

             model, optim = self._amp.initialize(
                 model, optim, opt_level=self._opt_level,
                 keep_batchnorm_fp32=self._bn_fp32, loss_scale=self._loss_scale
             )

-        if self._distributed_enabled:
+        if self.is_dist():
             # By default, apex.parallel.DistributedDataParallel overlaps communication with
             # computation in the backward pass.
             # delay_allreduce delays all communication to the end of the backward pass.
@@ -222,19 +221,19 @@ class ApexUtils:

     def clip_grad(self, clip:float, model:nn.Module, optim:Optimizer)->None:
         if clip > 0.0:
-            if self._mixed_prec_enabled:
+            if self.is_mixed():
                 nn.utils.clip_grad_norm_(self._amp.master_params(optim), clip)
             else:
                 nn.utils.clip_grad_norm_(model.parameters(), clip)

     def state_dict(self):
-        if self._mixed_prec_enabled:
+        if self.is_mixed():
             return self._amp.state_dict()
         else:
             return None

     def load_state_dict(self, state_dict):
-        if self._mixed_prec_enabled:
+        if self.is_mixed():
             self._amp.load_state_dict()

@@ -30,42 +30,11 @@ class SummaryWriterDummy:
 SummaryWriterAny = Union[SummaryWriterDummy, SummaryWriter]
 logger = OrderedDictLogger(None, None)
 _tb_writer: SummaryWriterAny = None
-_apex_utils = ApexUtils()
 _atexit_reg = False # is hook for atexit registered?

 def get_conf()->Config:
     return Config.get()

-def get_device():
-    global _apex_utils
-    return _apex_utils.device
-
-def get_apex_utils()->ApexUtils:
-    global _apex_utils
-    assert _apex_utils
-    return _apex_utils
-
-def is_dist()->bool:
-    global _apex_utils
-    return _apex_utils.is_dist()
-
-def reduce_min(val):
-    global _apex_utils
-    return _apex_utils.reduce(val, op='min')
-def reduce_max(val):
-    global _apex_utils
-    return _apex_utils.reduce(val, op='max')
-def reduce_sum(val):
-    global _apex_utils
-    return _apex_utils.reduce(val, op='sum')
-def reduce_mean(val):
-    global _apex_utils
-    return _apex_utils.reduce(val, op='mean')
-
-def is_dist()->bool:
-    global _apex_utils
-    return _apex_utils.is_dist()
-
 def get_conf_common()->Config:
     return get_conf()['common']

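Since the module-level helpers above are gone, call sites that imported get_device(), get_apex_utils() or the free reduce_* functions from archai.common.common now need to hold an ApexUtils (typically the Trainer's). A hypothetical migration of such a call site; names are illustrative:

import torch
from archai.common.apex_utils import ApexUtils

def step_stats(x: torch.Tensor, loss_value: float, apex: ApexUtils):
    """Illustrative replacement for code that used common.get_device()/reduce_mean()."""
    x = x.to(apex.device)                       # was: x.to(get_device())
    avg = apex.reduce(loss_value, op='mean')    # was: reduce_mean(loss_value); pass-through when not distributed
    return x, avg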
@@ -138,14 +107,18 @@ def common_init(config_filepath: Optional[str]=None,
     logger.info({'expdir': expdir,
                  'PT_DATA_DIR': pt_data_dir, 'PT_OUTPUT_DIR': pt_output_dir})

+    # create a[ex to know distributed processing paramters
+    conf_apex = get_conf_common()['apex']
+    apex = ApexUtils(conf_apex, None)
+
     # create global logger
-    _setup_logger()
+    _setup_logger(apex)
     # create info file for current system
     _create_sysinfo(conf)

     # setup tensorboard
     global _tb_writer
-    _tb_writer = _create_tb_writer(get_apex_utils().is_master())
+    _tb_writer = _create_tb_writer(apex.is_master())

     # create hooks to execute code when script exits
     global _atexit_reg
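Note that common_init now builds a short-lived ApexUtils with no logger, purely to learn the distributed parameters needed by _setup_logger and the TensorBoard writer; Trainer constructs a separate, fully wired instance later. A hypothetical outline of the startup sequence ('my_conf.yaml' is illustrative):

from archai.common.common import common_init

conf = common_init(config_filepath='my_conf.yaml')
# inside common_init (hunk above):  apex = ApexUtils(conf_apex, None)
#   -> used only for apex.global_rank / apex.is_master() during logger and TB setup
# later, Trainer.__init__ (trainer.py hunks below):  self._apex = ApexUtils(conf_apex, logger)
#   -> the instance actually used for amp, DDP, reductions and checkpointing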
@@ -216,17 +189,18 @@ def _setup_dirs()->Optional[str]:
     os.environ['distdir'] = conf_common['distdir'] = distdir


-def _setup_logger():
+def _setup_logger(apex:ApexUtils):
     global logger
     logger.close() # close any previous instances

     conf_common = get_conf_common()
     expdir = conf_common['expdir']
     distdir = conf_common['distdir']
-    global_rank = get_apex_utils().global_rank
+
+    global_rank = apex.global_rank
+
     # file where logger would log messages
-    if get_apex_utils().is_master():
+    if apex.is_master():
         sys_log_filepath = utils.full_path(os.path.join(expdir, 'logs.log'))
         logs_yaml_filepath = utils.full_path(os.path.join(expdir, 'logs.yaml'))
         experiment_name = get_experiment_name()
@@ -11,7 +11,8 @@ from torch import Tensor
 import yaml

 from . import utils, ml_utils
-from .common import logger, get_tb_writer, is_dist, reduce_mean, reduce_sum, reduce_min, reduce_max
+from .common import logger, get_tb_writer
+from .apex_utils import ApexUtils


 class Metrics:
@@ -29,7 +30,7 @@ class Metrics:
     best we have seen for each epoch.
     """

-    def __init__(self, title:str, logger_freq:int=50) -> None:
+    def __init__(self, title:str, apex:Optional[ApexUtils], logger_freq:int=50) -> None:
         """Create the metrics object to maintain epoch stats

         Arguments:
@@ -39,6 +40,7 @@ class Metrics:
         """
         self.logger_freq = logger_freq
         self.title = title
+        self._apex = apex
         self._reset_run()

     def _reset_run(self)->None:
@@ -59,27 +61,27 @@ class Metrics:
             logger.info({'epoch':self.run_metrics.epoch_time_avg(),
                          'step': self.run_metrics.step_time_avg(),
                          'run': self.run_metrics.duration()})
-            if is_dist():
-                logger.info({'dist_epoch_sum': reduce_sum(self.run_metrics.epoch_time_avg()),
-                             'dist_step': reduce_mean(self.run_metrics.step_time_avg()),
-                             'dist_run_sum': reduce_sum(self.run_metrics.duration())})
+            if self.is_dist():
+                logger.info({'dist_epoch_sum': self.reduce_sum(self.run_metrics.epoch_time_avg()),
+                             'dist_step': self.reduce_mean(self.run_metrics.step_time_avg()),
+                             'dist_run_sum': self.reduce_sum(self.run_metrics.duration())})


         best_train, best_val = self.run_metrics.best_epoch()
         with logger.pushd('best_train'):
             logger.info({'epoch': best_train.index,
                          'top1': best_train.top1.avg})
-            if is_dist():
-                logger.info({'dist_epoch': reduce_mean(best_train.index),
-                             'dist_top1': reduce_mean(best_train.top1.avg)})
+            if self.is_dist():
+                logger.info({'dist_epoch': self.reduce_mean(best_train.index),
+                             'dist_top1': self.reduce_mean(best_train.top1.avg)})

         if best_val:
             with logger.pushd('best_val'):
                 logger.info({'epoch': best_val.index,
                              'top1': best_val.val_metrics.top1.avg})
-                if is_dist():
-                    logger.info({'dist_epoch': reduce_mean(best_val.index),
-                                 'dist_top1': reduce_mean(best_val.val_metrics.top1.avg)})
+                if self.is_dist():
+                    logger.info({'dist_epoch': self.reduce_mean(best_val.index),
+                                 'dist_top1': self.reduce_mean(best_val.val_metrics.top1.avg)})

     def pre_step(self, x: Tensor, y: Tensor):
         self.run_metrics.cur_epoch().pre_step()
@@ -102,11 +104,11 @@ class Metrics:
                          'loss': epoch.loss.avg,
                          'step_time': epoch.step_time.last})

-            if is_dist():
-                logger.info({'dist_top1': reduce_mean(epoch.top1.avg),
-                             'dist_top5': reduce_mean(epoch.top5.avg),
-                             'dist_loss': reduce_mean(epoch.loss.avg),
-                             'dist_step_time': reduce_mean(epoch.step_time.last)})
+            if self.is_dist():
+                logger.info({'dist_top1': self.reduce_mean(epoch.top1.avg),
+                             'dist_top5': self.reduce_mean(epoch.top5.avg),
+                             'dist_loss': self.reduce_mean(epoch.loss.avg),
+                             'dist_step_time': self.reduce_mean(epoch.step_time.last)})


         # NOTE: Tensorboard step-level logging is removed as it becomes exponentially expensive on Azure blobs
@@ -143,24 +145,24 @@ class Metrics:
                          'duration': epoch.duration(),
                          'step_time': epoch.step_time.avg,
                          'end_lr': lr})
-            if is_dist():
-                logger.info({'dist_top1': reduce_mean(epoch.top1.avg),
-                             'dist_top5': reduce_mean(epoch.top5.avg),
-                             'dist_loss': reduce_mean(epoch.loss.avg),
-                             'dist_duration': reduce_mean(epoch.duration()),
-                             'dist_step_time': reduce_mean(epoch.step_time.avg),
-                             'dist_end_lr': reduce_mean(lr)})
+            if self.is_dist():
+                logger.info({'dist_top1': self.reduce_mean(epoch.top1.avg),
+                             'dist_top5': self.reduce_mean(epoch.top5.avg),
+                             'dist_loss': self.reduce_mean(epoch.loss.avg),
+                             'dist_duration': self.reduce_mean(epoch.duration()),
+                             'dist_step_time': self.reduce_mean(epoch.step_time.avg),
+                             'dist_end_lr': self.reduce_mean(lr)})
         if test_epoch:
             with logger.pushd('val'):
                 logger.info({'top1': test_epoch.top1.avg,
                              'top5': test_epoch.top5.avg,
                              'loss': test_epoch.loss.avg,
                              'duration': epoch.duration()})
-                if is_dist():
-                    logger.info({'dist_top1': reduce_mean(test_epoch.top1.avg),
-                                 'dist_top5': reduce_mean(test_epoch.top5.avg),
-                                 'dist_loss': reduce_mean(test_epoch.loss.avg),
-                                 'dist_duration': reduce_mean(test_epoch.duration())})
+                if self.is_dist():
+                    logger.info({'dist_top1': self.reduce_mean(test_epoch.top1.avg),
+                                 'dist_top5': self.reduce_mean(test_epoch.top5.avg),
+                                 'dist_loss': self.reduce_mean(test_epoch.loss.avg),
+                                 'dist_duration': self.reduce_mean(test_epoch.duration())})

         # writer = get_tb_writer()
         # writer.add_scalar(f'{self._tb_path}/train_epochs/loss',
@@ -181,9 +183,14 @@ class Metrics:
         return utils.state_dict(self)

     def load_state_dict(self, state_dict:dict)->None:
         # simply convert current object to dictionary
         utils.load_state_dict(self, state_dict)

+    def __getstate__(self):
+        state = self.__dict__.copy()
+        del state['_apex'] # cannot serialize this
+        return state
+    # no need to define __setstate__ because _apex should be set from constructor

     def save(self, filepath:str)->Optional[str]:
         if filepath:
             filepath = utils.full_path(filepath)
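The __getstate__ override keeps Metrics serializable through utils.state_dict (and pickle) even though ApexUtils holds process-wide handles that cannot be dumped: the _apex reference is dropped from the state and is expected to be re-supplied via the constructor on restore. A small standalone illustration of the same pattern with a hypothetical class, using pickle, which honors the same hook:

import pickle

class Carrier:
    """Hypothetical stand-in for an object holding a non-serializable helper."""
    def __init__(self, helper) -> None:
        self._helper = helper          # e.g. an ApexUtils-like object
        self.count = 3

    def __getstate__(self):
        state = self.__dict__.copy()
        del state['_helper']           # drop what cannot be serialized
        return state

c = Carrier(helper=lambda: None)                 # a lambda is not picklable, but it never gets that far
restored = pickle.loads(pickle.dumps(c))
assert restored.count == 3
assert not hasattr(restored, '_helper')          # the caller must re-attach the helper after restore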
@@ -197,6 +204,27 @@ class Metrics:
     def cur_epoch(self)->'EpochMetrics':
         return self.run_metrics.cur_epoch()

+    def reduce_min(self, val):
+        if not self._apex:
+            return val
+        return self._apex.reduce(val, op='min')
+    def reduce_max(self, val):
+        if not self._apex:
+            return val
+        return self._apex.reduce(val, op='max')
+    def reduce_sum(self, val):
+        if not self._apex:
+            return val
+        return self._apex.reduce(val, op='sum')
+    def reduce_mean(self, val):
+        if not self._apex:
+            return val
+        return self._apex.reduce(val, op='mean')
+    def is_dist(self)->bool:
+        if not self._apex:
+            return False
+        return self._apex.is_dist()
+

 class Accumulator:
     # TODO: replace this with Metrics class
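Because every helper above checks self._apex first, a Metrics built without apex (for example in unit tests or single-process runs where no ApexUtils is available) degrades gracefully: is_dist() reports False and the reduce_* calls return their input unchanged. A hypothetical usage, assuming Metrics construction has no other external dependencies:

from archai.common.metrics import Metrics

m = Metrics('smoke_test', apex=None, logger_freq=0)   # title and logger_freq are illustrative
assert m.is_dist() is False
assert m.reduce_mean(0.75) == 0.75                    # pass-through without an ApexUtils
assert m.reduce_sum(5) == 5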
@@ -9,18 +9,18 @@ from overrides import EnforceOverrides
 from .metrics import Metrics
 from .config import Config
 from . import utils, ml_utils
-from .common import logger, get_device
-from archai.common.common import get_apex_utils
+from .common import logger
+from archai.common.apex_utils import ApexUtils

 class Tester(EnforceOverrides):
-    def __init__(self, conf_eval:Config, model:nn.Module)->None:
-        # TODO: currently we expect that given model and dataloader will already be distributed
+    def __init__(self, conf_eval:Config, model:nn.Module, apex:ApexUtils)->None:
         self._title = conf_eval['title']
         self._logger_freq = conf_eval['logger_freq']
         conf_lossfn = conf_eval['lossfn']

+        self._apex = apex
         self.model = model
-        self._lossfn = ml_utils.get_lossfn(conf_lossfn).to(get_device())
+        self._lossfn = ml_utils.get_lossfn(conf_lossfn).to(apex.device)
         self._metrics = None

     def test(self, test_dl: DataLoader)->Metrics:
@@ -43,7 +43,7 @@ class Tester(EnforceOverrides):

         with torch.no_grad(), logger.pushd('steps'):
             for step, (x, y) in enumerate(test_dl):
-                x, y = x.to(get_device(), non_blocking=True), y.to(get_device(), non_blocking=True)
+                x, y = x.to(self._apex.device, non_blocking=True), y.to(self._apex.device, non_blocking=True)

                 assert not self.model.training # derived class might alter the mode
                 logger.pushd(step)
@@ -57,7 +57,7 @@ class Tester(EnforceOverrides):
                 self._post_step(x, y, logits, loss, steps, self._metrics)

                 # TODO: we possibly need to sync so all replicas are upto date
-                get_apex_utils().sync_devices()
+                self._apex.sync_devices()

                 logger.popd()
         self._metrics.post_epoch(None)
@@ -87,5 +87,5 @@ class Tester(EnforceOverrides):
         metrics.post_step(x, y, logits, loss, steps)

     def _create_metrics(self)->Metrics:
-        return Metrics(self._title, logger_freq=self._logger_freq)
+        return Metrics(self._title, self._apex, logger_freq=self._logger_freq)

@@ -11,8 +11,9 @@ from .metrics import Metrics
 from .tester import Tester
 from .config import Config
 from . import utils, ml_utils
-from ..common.common import logger, get_device, get_apex_utils
+from ..common.common import logger
 from ..common.checkpoint import CheckPoint
+from ..common.apex_utils import ApexUtils


 class Trainer(EnforceOverrides):
@@ -33,13 +34,13 @@ class Trainer(EnforceOverrides):
         self._validation_freq = 0 if conf_validation is None else conf_validation['freq']
         # endregion

-        get_apex_utils().reset(logger, conf_apex)
+        self._apex = ApexUtils(conf_apex, logger)

         self._checkpoint = checkpoint
         self.model = model

         self._lossfn = ml_utils.get_lossfn(conf_lossfn)
-        self._tester = Tester(conf_validation, model) \
+        self._tester = Tester(conf_validation, model, self._apex) \
                         if conf_validation else None
         self._metrics:Optional[Metrics] = None

@@ -52,7 +53,7 @@ class Trainer(EnforceOverrides):
     def fit(self, train_dl:DataLoader, val_dl:Optional[DataLoader])->Metrics:
         logger.pushd(self._title)

-        self._metrics = Metrics(self._title, logger_freq=self._logger_freq)
+        self._metrics = Metrics(self._title, self._apex, logger_freq=self._logger_freq)

         # optimizers, schedulers needs to be recreated for each fit call
         # as they have state specific to each run
@@ -60,10 +61,10 @@ class Trainer(EnforceOverrides):
         # create scheduler for optim before applying amp
         self._sched, self._sched_on_epoch = self._create_scheduler(optim, len(train_dl))
         # before checkpoint restore, convert to amp
-        self.model, self._optim = get_apex_utils().to_amp(self.model, optim,
+        self.model, self._optim = self._apex.to_amp(self.model, optim,
                                                    batch_size=train_dl.batch_size)

-        self._lossfn = self._lossfn.to(get_device())
+        self._lossfn = self._lossfn.to(self.get_device())

         self.pre_fit(train_dl, val_dl)

@@ -158,7 +159,7 @@ class Trainer(EnforceOverrides):
         self._metrics.post_epoch(val_metrics, lr=self._optim.param_groups[0]['lr'])

         # checkpoint if enabled with given freq or if this is the last epoch
-        if self._checkpoint is not None and get_apex_utils().is_master() and \
+        if self._checkpoint is not None and self._apex.is_master() and \
                 self._checkpoint.freq > 0 and (self._metrics.epochs() % self._checkpoint.freq == 0 or \
                     self._metrics.epochs() >= self._epochs):
             self._checkpoint.new()
@@ -173,6 +174,9 @@ class Trainer(EnforceOverrides):
         self._metrics.post_step(x, y, logits, loss, steps)
     ######################### hooks #########################

+    def get_device(self):
+        return self._apex.device
+
     def restore_checkpoint(self)->None:
         state = self._checkpoint['trainer']
         last_epoch = state['last_epoch']
@@ -180,7 +184,7 @@ class Trainer(EnforceOverrides):

         self._metrics.load_state_dict(state['metrics'])
         assert self._metrics.epochs() == last_epoch+1
-        get_apex_utils().load_state_dict(state['amp'])
+        self._apex.load_state_dict(state['amp'])
         self.model.load_state_dict(state['model'])
         self._optim.load_state_dict(state['optim'])
         if self._sched:
@@ -198,7 +202,7 @@ class Trainer(EnforceOverrides):
             'model': self.model.state_dict(),
             'optim': self._optim.state_dict(),
             'sched': self._sched.state_dict() if self._sched else None,
-            'amp': get_apex_utils().state_dict()
+            'amp': self._apex.state_dict()
         }
         self._checkpoint['trainer'] = state

@@ -208,7 +212,7 @@ class Trainer(EnforceOverrides):

         logger.pushd('steps')
         for step, (x, y) in enumerate(train_dl):
-            x, y = x.to(get_device(), non_blocking=True), y.to(get_device(), non_blocking=True)
+            x, y = x.to(self.get_device(), non_blocking=True), y.to(self.get_device(), non_blocking=True)

             logger.pushd(step)
             assert self.model.training # derived class might alter the mode
@@ -226,15 +230,15 @@ class Trainer(EnforceOverrides):
             loss = self.compute_loss(self._lossfn, x, y, logits,
                                      self._aux_weight, aux_logits)

-            get_apex_utils().backward(loss, self._optim)
+            self._apex.backward(loss, self._optim)

             # TODO: original darts clips alphas as well but pt.darts doesn't
-            get_apex_utils().clip_grad(self._grad_clip, self.model, self._optim)
+            self._apex.clip_grad(self._grad_clip, self.model, self._optim)

             self._optim.step()

             # TODO: we possibly need to sync so all replicas are upto date
-            get_apex_utils().sync_devices()
+            self._apex.sync_devices()

             if self._sched and not self._sched_on_epoch:
                 self._sched.step()
@@ -55,7 +55,7 @@ def deep_update(d:MutableMapping, u:Mapping, map_type:Type[MutableMapping]=dict)
     return d

 def state_dict(val)->Mapping:
-    assert hasattr(val, '__dict__'), 'val must be object with __dict__'
+    assert hasattr(val, '__dict__'), 'val must be object with __dict__ otherwise it cannot be loaded back in load_state_dict'

     # Can't do below because val has state_dict() which calls utils.state_dict
     # if has_method(val, 'state_dict'):
@@ -76,8 +76,8 @@ def load_state_dict(val:Any, state_dict:Mapping)->None:
     assert s is not None, 'state_dict must contain yaml key'

     obj = yaml.load(s, Loader=yaml.Loader)
-    for k in val.__dict__.keys():
-        setattr(val, k, getattr(obj, k))
+    for k, v in obj.__dict__.items():
+        setattr(val, k, v)

 def deep_comp(o1:Any, o2:Any)->bool:
     # NOTE: dict don't have __dict__
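The load_state_dict change flips the iteration to the snapshot's own __dict__: the old loop walked the live object's attributes and read them from the snapshot, so it raised AttributeError whenever the live object had gained a field the snapshot never recorded. A small illustration of the failure mode, with hypothetical classes:

class Stats:
    pass

saved = Stats();  saved.top1 = 0.91                     # snapshot taken by an older run
fresh = Stats();  fresh.top1 = 0.0;  fresh.top5 = 0.0   # newer code added a field

# old approach: iterate the live object's keys and read them from the snapshot
#   for k in fresh.__dict__.keys():
#       setattr(fresh, k, getattr(saved, k))   # AttributeError: 'Stats' object has no attribute 'top5'

# new approach: copy only what the snapshot actually recorded
for k, v in saved.__dict__.items():
    setattr(fresh, k, v)

assert fresh.top1 == 0.91 and fresh.top5 == 0.0         # unknown fields keep their defaults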
@@ -19,8 +19,8 @@ common:
   redis: null
   apex: # this is overriden in search and eval individually
     enabled: False # global switch to disable anything apex
-    distributed_enabled: False # enable/disable distributed mode
-    mixed_prec_enabled: False # switch to disable amp mixed precision
+    distributed_enabled: True # enable/disable distributed mode
+    mixed_prec_enabled: True # switch to disable amp mixed precision
     gpus: '' # use GPU IDs specified here (comma separated), if '' then use all GPUs
     opt_level: 'O2' # optimization level for mixed precision
     bn_fp32: True # keep BN in fp32
@@ -88,6 +88,7 @@ nas:
     trainer:
       apex:
         _copy: 'common/apex'
+        enabled: True
       aux_weight: '_copy: nas/eval/model_desc/aux_weight'
      drop_path_prob: 0.2 # probability that given edge will be dropped
      grad_clip: 5.0 # grads above this value is clipped
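Taken together, the defaults above leave apex globally off (enabled: False) but pre-enable distributed mode and mixed precision, so a section only has to flip the one enabled switch, which is what the NAS trainer section does through _copy. Assuming _copy: 'common/apex' pulls in the referenced node and locally listed keys override it, the effective apex settings for that trainer come out roughly as sketched below (Python dict form, values inferred from the two hunks above):

effective_trainer_apex = {
    'enabled': True,               # overridden locally next to the _copy
    'distributed_enabled': True,   # inherited from common/apex
    'mixed_prec_enabled': True,    # inherited from common/apex
    'gpus': '',
    'opt_level': 'O2',
    'bn_fp32': True,
    # ...remaining common/apex keys carried over unchanged
}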
@@ -1,10 +1,27 @@
 from collections import UserDict
 import yaml
 from typing import Iterator

-y = """
-a: .NaN
-"""
-d=yaml.load(y, Loader=yaml.Loader)
+class A(object):
+    def __init__(self):
+        self.hidden = 42
+        self.visible = 5
+
+    def __getstate__(self):
+        state = self.__dict__.copy()
+        del state['hidden'] # cannot serialize this
+        return state
+
+a = A()
+d = yaml.dump(a)
 print(d)
-print(type( d['a']))
+
+
+# y = """
+# a: .NaN
+
+# """
+
+# d=yaml.load(y, Loader=yaml.Loader)
+# print(d)
+# print(type( d['a']))
@@ -5,7 +5,7 @@ from archai import cifar10_models

 from archai.common.trainer import Trainer
 from archai.common.config import Config
-from archai.common.common import logger, common_init, get_device
+from archai.common.common import logger, common_init
 from archai.datasets import data

 def train_test(conf_eval:Config):
@@ -24,7 +24,7 @@ def train_test(conf_eval:Config):
     conf_trainer['aux_weight'] = 0.0

     Net = cifar10_models.resnet34
-    model = Net().to(get_device())
+    model = Net().to(torch.device('cuda', 0))

     # get data
     train_dl, _, test_dl = data.get_data(conf_loader)