Mirror of https://github.com/microsoft/archai.git
Move apex to trainer
Parent: 0235072c8e
Commit: d861b5f2a6
@@ -15,7 +15,7 @@ from archai.nas.arch_trainer import ArchTrainer
 from archai.common import utils, ml_utils
 from archai.nas.model import Model
 from archai.common.checkpoint import CheckPoint
-from archai.common.common import logger, get_device
+from archai.common.common import logger
 from .bilevel_optimizer import BilevelOptimizer

 class BilevelArchTrainer(ArchTrainer):
@@ -36,10 +36,11 @@ class BilevelArchTrainer(ArchTrainer):
 assert val_dl is not None
 w_momentum = self._conf_w_optim['momentum']
 w_decay = self._conf_w_optim['decay']
-lossfn = ml_utils.get_lossfn(self._conf_w_lossfn).to(get_device())
+lossfn = ml_utils.get_lossfn(self._conf_w_lossfn).to(self.get_device())

 self._bilevel_optim = BilevelOptimizer(self._conf_alpha_optim, w_momentum,
-w_decay, self.model, lossfn)
+w_decay, self.model, lossfn,
+self.get_device())

 @overrides
 def post_fit(self, train_dl:DataLoader, val_dl:Optional[DataLoader])->None:
@@ -71,7 +72,7 @@ class BilevelArchTrainer(ArchTrainer):
 self._valid_iter = iter(self._val_dl)
 x_val, y_val = next(self._valid_iter)

-x_val, y_val = x_val.to(get_device()), y_val.to(get_device(), non_blocking=True)
+x_val, y_val = x_val.to(self.get_device()), y_val.to(self.get_device(), non_blocking=True)

 # update alphas
 self._bilevel_optim.step(x, y, x_val, y_val, super().get_optimizer())
@@ -9,11 +9,11 @@ from torch.optim.optimizer import Optimizer
 from archai.common.config import Config
 from archai.common import utils, ml_utils
 from archai.nas.model import Model
-from archai.common.common import logger, get_device
+from archai.common.common import logger

 class BilevelOptimizer:
 def __init__(self, conf_alpha_optim:Config, w_momentum: float, w_decay: float,
-model: Model, lossfn: _Loss) -> None:
+model: Model, lossfn: _Loss, device) -> None:
 self._w_momentum = w_momentum # momentum for w
 self._w_weight_decay = w_decay # weight decay for w
 self._lossfn = lossfn
@@ -22,7 +22,7 @@ class BilevelOptimizer:
 # create a copy of model which we will use
 # to compute grads for alphas without disturbing
 # original weights
-self._vmodel = copy.deepcopy(model).to(get_device())
+self._vmodel = copy.deepcopy(model).to(device)
 # this is the optimizer to optimize alphas parameter
 self._alpha_optim = ml_utils.create_optimizer(conf_alpha_optim, model.alphas())

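The device is now injected into BilevelOptimizer by its owning trainer instead of being fetched from the module-level get_device(). A minimal, self-contained sketch of that dependency-injection pattern (TinyHelper and the toy model are illustrative stand-ins, not archai code):

    import copy
    import torch
    import torch.nn as nn

    class TinyHelper:
        # stand-in for BilevelOptimizer: the device arrives as a constructor argument
        def __init__(self, model: nn.Module, lossfn: nn.Module, device: torch.device) -> None:
            self._lossfn = lossfn
            # copy of the model moved to whatever device the caller (the trainer) owns
            self._vmodel = copy.deepcopy(model).to(device)

    device = torch.device('cuda', 0) if torch.cuda.is_available() else torch.device('cpu')
    helper = TinyHelper(nn.Linear(4, 2), nn.CrossEntropyLoss().to(device), device)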
@@ -15,7 +15,7 @@ from archai.nas.arch_trainer import ArchTrainer
 from archai.common import utils, ml_utils
 from archai.nas.model import Model
 from archai.common.checkpoint import CheckPoint
-from archai.common.common import logger, get_device
+from archai.common.common import logger


 class XnasArchTrainer(ArchTrainer):
@@ -39,7 +39,7 @@ class XnasArchTrainer(ArchTrainer):
 # optimizers, schedulers needs to be recreated for each fit call
 # as they have state
 assert val_dl is not None
-lossfn = ml_utils.get_lossfn(self._conf_w_lossfn).to(get_device())
+lossfn = ml_utils.get_lossfn(self._conf_w_lossfn).to(self.get_device())

 self._xnas_optim = _XnasOptimizer(self._conf_alpha_optim, self.model, lossfn)

@@ -74,7 +74,7 @@ class XnasArchTrainer(ArchTrainer):
 self._valid_iter = iter(self._val_dl)
 x_val, y_val = next(self._valid_iter)

-x_val, y_val = x_val.to(get_device()), y_val.to(get_device(), non_blocking=True)
+x_val, y_val = x_val.to(self.get_device()), y_val.to(self.get_device(), non_blocking=True)

 # update alphas
 self._xnas_optim.step(x, y, x_val, y_val)
@@ -15,23 +15,7 @@ from archai.common import ml_utils, utils
 from archai.common.ordereddict_logger import OrderedDictLogger

 class ApexUtils:
-def __init__(self)->None:
-self._amp = self._ddp = None
-self._op_map = {'mean': dist.ReduceOp.SUM, 'sum': dist.ReduceOp.SUM,
-'min': dist.ReduceOp.MIN, 'max': dist.ReduceOp.MAX}
-self.gpu_ids = [] # use all gpus
-
-self._mixed_prec_enabled = False
-self._distributed_enabled = False
-
-self._set_ranks()
-
-def reset(self, logger:OrderedDictLogger, apex_config:Config)->None:
-# reset allows to configure differently for search or eval modes
-
-# to avoid circular references= with common, logger is passed from outside
-self.logger = logger
-
+def __init__(self, apex_config:Config, logger:Optional[OrderedDictLogger])->None:
 # region conf vars
 self._enabled = apex_config['enabled'] # global switch to disable anything apex
 self._distributed_enabled = apex_config['distributed_enabled'] # enable/disable distributed mode
@@ -47,42 +31,53 @@ class ApexUtils:
 conf_gpu_ids = apex_config['gpus']
 # endregion

-self.gpu_ids = [int(i) for i in conf_gpu_ids.split(',') if i]
-self._amp, self._ddp = None, None
-self._gpu = self.gpu_ids[0] if len(self.gpu_ids) else 0 # which GPU to use, we will use only 1 GPU
+# to avoid circular references= with common, logger is passed from outside
+self.logger = logger

-#logger.info({'apex_config': apex_config.to_dict()})
-logger.info({'torch.distributed.is_available': dist.is_available()})
+self._op_map = {'mean': dist.ReduceOp.SUM, 'sum': dist.ReduceOp.SUM,
+'min': dist.ReduceOp.MIN, 'max': dist.ReduceOp.MAX}
+
+# defaults for non-distributed mode
+self._amp, self._ddp = None, None
+self.world_size = 1
+self.local_rank = self.global_rank = 0
+
+self.gpu_ids = [int(i) for i in conf_gpu_ids.split(',') if i]
+
+# which GPU to use, we will use only 1 GPU per process to avoid complications with apex
+self._gpu = self.gpu_ids[0] if len(self.gpu_ids) else 0
+
+#_log_info({'apex_config': apex_config.to_dict()})
+self._log_info({'torch.distributed.is_available': dist.is_available()})
 if dist.is_available():
-logger.info({'gloo_available': dist.is_gloo_available(),
+self._log_info({'gloo_available': dist.is_gloo_available(),
 'mpi_available': dist.is_mpi_available(),
 'nccl_available': dist.is_nccl_available()})

-if self._enabled:
-if self._mixed_prec_enabled:
-# init enable mixed precision
-assert cudnn.enabled, "Amp requires cudnn backend to be enabled."
-from apex import amp
-self._amp = amp
+if self.is_mixed():
+# init enable mixed precision
+assert cudnn.enabled, "Amp requires cudnn backend to be enabled."
+from apex import amp
+self._amp = amp

 # enable distributed processing
-if self._distributed_enabled:
+if self.is_dist():
 from apex import parallel
 self._ddp = parallel

 assert dist.is_available() # distributed module is available
 assert dist.is_nccl_available()
 if not dist.is_initialized():
 dist.init_process_group(backend='nccl', init_method='env://')
 assert dist.is_initialized()

 self._set_ranks()
 assert dist.get_world_size() == self.world_size
 assert dist.get_rank() == self.global_rank
 else:
 assert self.world_size == 1
 assert self.local_rank == 0
 assert self.global_rank == 0

 assert self.world_size >= 1
 assert not self._min_world_size or self.world_size >= self._min_world_size
@@ -94,9 +89,9 @@ class ApexUtils:
 self.device = torch.device('cuda', self._gpu)
 self._setup_gpus(seed, detect_anomaly)

-logger.info({'amp_available': self._amp is not None,
+self._log_info({'amp_available': self._amp is not None,
 'distributed_available': self._ddp is not None})
-logger.info({'dist_initialized': dist.is_initialized() if dist.is_available() else False,
+self._log_info({'dist_initialized': dist.is_initialized() if dist.is_available() else False,
 'world_size': self.world_size,
 'gpu': self._gpu, 'gpu_ids':self.gpu_ids,
 'local_rank': self.local_rank})
@@ -106,10 +101,10 @@ class ApexUtils:
 utils.setup_cuda(seed, self.local_rank)

 torch.autograd.set_detect_anomaly(detect_anomaly)
-self.logger.info({'set_detect_anomaly': detect_anomaly,
+self._log_info({'set_detect_anomaly': detect_anomaly,
 'is_anomaly_enabled': torch.is_anomaly_enabled()})

-self.logger.info({'gpu_names': utils.cuda_device_names(),
+self._log_info({'gpu_names': utils.cuda_device_names(),
 'gpu_count': torch.cuda.device_count(),
 'CUDA_VISIBLE_DEVICES': os.environ['CUDA_VISIBLE_DEVICES']
 if 'CUDA_VISIBLE_DEVICES' in os.environ else 'NotSet',
@@ -118,8 +113,8 @@ class ApexUtils:
 'cudnn.deterministic': cudnn.deterministic,
 'cudnn.version': cudnn.version()
 })
-self.logger.info({'memory': str(psutil.virtual_memory())})
-self.logger.info({'CPUs': str(psutil.cpu_count())})
+self._log_info({'memory': str(psutil.virtual_memory())})
+self._log_info({'CPUs': str(psutil.cpu_count())})

 # gpu_usage = os.popen(
 # 'nvidia-smi --query-gpu=memory.total,memory.used --format=csv,nounits,noheader'
@@ -127,7 +122,7 @@ class ApexUtils:
 # for i, line in enumerate(gpu_usage):
 # vals = line.split(',')
 # if len(vals) == 2:
-# logger.info('GPU {} mem: {}, used: {}'.format(i, vals[0], vals[1]))
+# _log_info('GPU {} mem: {}, used: {}'.format(i, vals[0], vals[1]))

 def _set_ranks(self)->None:
 if 'WORLD_SIZE' in os.environ:
@@ -153,18 +148,22 @@ class ApexUtils:
 self._gpu = self.gpu_ids[self.local_rank]

 def is_mixed(self)->bool:
-return self._mixed_prec_enabled
+return self._enabled and self._mixed_prec_enabled
 def is_dist(self)->bool:
-return self._distributed_enabled
+return self._enabled and self._distributed_enabled
 def is_master(self)->bool:
 return self.global_rank == 0

+def _log_info(self, d:dict)->None:
+if self.logger is not None:
+self.logger.info(d)
+
 def sync_devices(self)->None:
-if self._distributed_enabled:
+if self.is_dist():
 torch.cuda.synchronize(self.device)

 def reduce(self, val, op='mean'):
-if self._distributed_enabled:
+if self.is_dist():
 if not isinstance(val, Tensor):
 rt = torch.tensor(val).to(self.device)
 converted = True
@@ -184,7 +183,7 @@ class ApexUtils:
 return val

 def backward(self, loss:torch.Tensor, optim:Optimizer)->None:
-if self._mixed_prec_enabled:
+if self.is_mixed():
 with self._amp.scale_loss(loss, optim) as scaled_loss:
 scaled_loss.backward()
 else:
@@ -193,26 +192,26 @@ class ApexUtils:
 def to_amp(self, model:nn.Module, optim:Optimizer, batch_size:int)\
 ->Tuple[nn.Module, Optimizer]:
 # conver BNs to sync BNs in distributed mode
-if self._distributed_enabled and self._sync_bn:
+if self.is_dist() and self._sync_bn:
 model = self._ddp.convert_syncbn_model(model)
-self.logger.info({'BNs_converted': True})
+self._log_info({'BNs_converted': True})

 model = model.to(self.device)

-if self._mixed_prec_enabled:
+if self.is_mixed():
 # scale LR
 if self._scale_lr:
 lr = ml_utils.get_optim_lr(optim)
 scaled_lr = lr * self.world_size / float(batch_size)
 ml_utils.set_optim_lr(optim, scaled_lr)
-self.logger.info({'lr_scaled': True, 'old_lr': lr, 'new_lr': scaled_lr})
+self._log_info({'lr_scaled': True, 'old_lr': lr, 'new_lr': scaled_lr})

 model, optim = self._amp.initialize(
 model, optim, opt_level=self._opt_level,
 keep_batchnorm_fp32=self._bn_fp32, loss_scale=self._loss_scale
 )

-if self._distributed_enabled:
+if self.is_dist():
 # By default, apex.parallel.DistributedDataParallel overlaps communication with
 # computation in the backward pass.
 # delay_allreduce delays all communication to the end of the backward pass.
@@ -222,19 +221,19 @@ class ApexUtils:

 def clip_grad(self, clip:float, model:nn.Module, optim:Optimizer)->None:
 if clip > 0.0:
-if self._mixed_prec_enabled:
+if self.is_mixed():
 nn.utils.clip_grad_norm_(self._amp.master_params(optim), clip)
 else:
 nn.utils.clip_grad_norm_(model.parameters(), clip)

 def state_dict(self):
-if self._mixed_prec_enabled:
+if self.is_mixed():
 return self._amp.state_dict()
 else:
 return None

 def load_state_dict(self, state_dict):
-if self._mixed_prec_enabled:
+if self.is_mixed():
 self._amp.load_state_dict()

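ApexUtils is now configured entirely in its constructor (the separate reset() call is gone), and is_mixed()/is_dist() additionally require the global enabled switch. A small stand-alone sketch of those guard semantics, using config keys that appear elsewhere in this commit (GuardSketch is illustrative, not the real class):

    class GuardSketch:
        def __init__(self, apex_config: dict) -> None:
            self._enabled = apex_config['enabled']                        # global switch
            self._mixed_prec_enabled = apex_config['mixed_prec_enabled']
            self._distributed_enabled = apex_config['distributed_enabled']

        def is_mixed(self) -> bool:
            # mixed precision only when apex is globally enabled AND amp is requested
            return self._enabled and self._mixed_prec_enabled

        def is_dist(self) -> bool:
            return self._enabled and self._distributed_enabled

    # with the common default in this commit (enabled: False), both guards stay off
    # even though distributed_enabled/mixed_prec_enabled are now True in common/apex
    print(GuardSketch({'enabled': False, 'distributed_enabled': True,
                       'mixed_prec_enabled': True}).is_mixed())   # False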
@@ -30,42 +30,11 @@ class SummaryWriterDummy:
 SummaryWriterAny = Union[SummaryWriterDummy, SummaryWriter]
 logger = OrderedDictLogger(None, None)
 _tb_writer: SummaryWriterAny = None
-_apex_utils = ApexUtils()
 _atexit_reg = False # is hook for atexit registered?

 def get_conf()->Config:
 return Config.get()

-def get_device():
-global _apex_utils
-return _apex_utils.device
-
-def get_apex_utils()->ApexUtils:
-global _apex_utils
-assert _apex_utils
-return _apex_utils
-
-def is_dist()->bool:
-global _apex_utils
-return _apex_utils.is_dist()
-
-def reduce_min(val):
-global _apex_utils
-return _apex_utils.reduce(val, op='min')
-def reduce_max(val):
-global _apex_utils
-return _apex_utils.reduce(val, op='max')
-def reduce_sum(val):
-global _apex_utils
-return _apex_utils.reduce(val, op='sum')
-def reduce_mean(val):
-global _apex_utils
-return _apex_utils.reduce(val, op='mean')
-
-def is_dist()->bool:
-global _apex_utils
-return _apex_utils.is_dist()

 def get_conf_common()->Config:
 return get_conf()['common']
@@ -138,14 +107,18 @@ def common_init(config_filepath: Optional[str]=None,
 logger.info({'expdir': expdir,
 'PT_DATA_DIR': pt_data_dir, 'PT_OUTPUT_DIR': pt_output_dir})

+# create a[ex to know distributed processing paramters
+conf_apex = get_conf_common()['apex']
+apex = ApexUtils(conf_apex, None)
+
 # create global logger
-_setup_logger()
+_setup_logger(apex)
 # create info file for current system
 _create_sysinfo(conf)

 # setup tensorboard
 global _tb_writer
-_tb_writer = _create_tb_writer(get_apex_utils().is_master())
+_tb_writer = _create_tb_writer(apex.is_master())

 # create hooks to execute code when script exits
 global _atexit_reg
@@ -216,17 +189,18 @@ def _setup_dirs()->Optional[str]:
 os.environ['distdir'] = conf_common['distdir'] = distdir


-def _setup_logger():
+def _setup_logger(apex:ApexUtils):
 global logger
 logger.close() # close any previous instances

 conf_common = get_conf_common()
 expdir = conf_common['expdir']
 distdir = conf_common['distdir']
-global_rank = get_apex_utils().global_rank
+
+global_rank = apex.global_rank

 # file where logger would log messages
-if get_apex_utils().is_master():
+if apex.is_master():
 sys_log_filepath = utils.full_path(os.path.join(expdir, 'logs.log'))
 logs_yaml_filepath = utils.full_path(os.path.join(expdir, 'logs.yaml'))
 experiment_name = get_experiment_name()
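With the module-level singleton gone, common.py no longer exposes get_device(), get_apex_utils() or the reduce_* helpers; common_init() just builds a local ApexUtils to learn rank/master information for logger and TensorBoard setup. A rough sketch of that wiring, with the collaborators passed in as parameters because their real definitions live elsewhere in archai:

    def common_init_wiring(get_conf_common, ApexUtils, setup_logger, create_tb_writer):
        # mirrors the hunks above: apex is a local object, not a module global
        conf_apex = get_conf_common()['apex']
        apex = ApexUtils(conf_apex, None)       # no logger exists yet at this point
        setup_logger(apex)                      # logger placement depends on apex.global_rank
        return create_tb_writer(apex.is_master())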
@@ -11,7 +11,8 @@ from torch import Tensor
 import yaml

 from . import utils, ml_utils
-from .common import logger, get_tb_writer, is_dist, reduce_mean, reduce_sum, reduce_min, reduce_max
+from .common import logger, get_tb_writer
+from .apex_utils import ApexUtils


 class Metrics:
@@ -29,7 +30,7 @@ class Metrics:
 best we have seen for each epoch.
 """

-def __init__(self, title:str, logger_freq:int=50) -> None:
+def __init__(self, title:str, apex:Optional[ApexUtils], logger_freq:int=50) -> None:
 """Create the metrics object to maintain epoch stats

 Arguments:
@@ -39,6 +40,7 @@ class Metrics:
 """
 self.logger_freq = logger_freq
 self.title = title
+self._apex = apex
 self._reset_run()

 def _reset_run(self)->None:
@@ -59,27 +61,27 @@ class Metrics:
 logger.info({'epoch':self.run_metrics.epoch_time_avg(),
 'step': self.run_metrics.step_time_avg(),
 'run': self.run_metrics.duration()})
-if is_dist():
-logger.info({'dist_epoch_sum': reduce_sum(self.run_metrics.epoch_time_avg()),
-'dist_step': reduce_mean(self.run_metrics.step_time_avg()),
-'dist_run_sum': reduce_sum(self.run_metrics.duration())})
+if self.is_dist():
+logger.info({'dist_epoch_sum': self.reduce_sum(self.run_metrics.epoch_time_avg()),
+'dist_step': self.reduce_mean(self.run_metrics.step_time_avg()),
+'dist_run_sum': self.reduce_sum(self.run_metrics.duration())})


 best_train, best_val = self.run_metrics.best_epoch()
 with logger.pushd('best_train'):
 logger.info({'epoch': best_train.index,
 'top1': best_train.top1.avg})
-if is_dist():
-logger.info({'dist_epoch': reduce_mean(best_train.index),
-'dist_top1': reduce_mean(best_train.top1.avg)})
+if self.is_dist():
+logger.info({'dist_epoch': self.reduce_mean(best_train.index),
+'dist_top1': self.reduce_mean(best_train.top1.avg)})

 if best_val:
 with logger.pushd('best_val'):
 logger.info({'epoch': best_val.index,
 'top1': best_val.val_metrics.top1.avg})
-if is_dist():
-logger.info({'dist_epoch': reduce_mean(best_val.index),
-'dist_top1': reduce_mean(best_val.val_metrics.top1.avg)})
+if self.is_dist():
+logger.info({'dist_epoch': self.reduce_mean(best_val.index),
+'dist_top1': self.reduce_mean(best_val.val_metrics.top1.avg)})

 def pre_step(self, x: Tensor, y: Tensor):
 self.run_metrics.cur_epoch().pre_step()
@@ -102,11 +104,11 @@ class Metrics:
 'loss': epoch.loss.avg,
 'step_time': epoch.step_time.last})

-if is_dist():
-logger.info({'dist_top1': reduce_mean(epoch.top1.avg),
-'dist_top5': reduce_mean(epoch.top5.avg),
-'dist_loss': reduce_mean(epoch.loss.avg),
-'dist_step_time': reduce_mean(epoch.step_time.last)})
+if self.is_dist():
+logger.info({'dist_top1': self.reduce_mean(epoch.top1.avg),
+'dist_top5': self.reduce_mean(epoch.top5.avg),
+'dist_loss': self.reduce_mean(epoch.loss.avg),
+'dist_step_time': self.reduce_mean(epoch.step_time.last)})


 # NOTE: Tensorboard step-level logging is removed as it becomes exponentially expensive on Azure blobs
@@ -143,24 +145,24 @@ class Metrics:
 'duration': epoch.duration(),
 'step_time': epoch.step_time.avg,
 'end_lr': lr})
-if is_dist():
-logger.info({'dist_top1': reduce_mean(epoch.top1.avg),
-'dist_top5': reduce_mean(epoch.top5.avg),
-'dist_loss': reduce_mean(epoch.loss.avg),
-'dist_duration': reduce_mean(epoch.duration()),
-'dist_step_time': reduce_mean(epoch.step_time.avg),
-'dist_end_lr': reduce_mean(lr)})
+if self.is_dist():
+logger.info({'dist_top1': self.reduce_mean(epoch.top1.avg),
+'dist_top5': self.reduce_mean(epoch.top5.avg),
+'dist_loss': self.reduce_mean(epoch.loss.avg),
+'dist_duration': self.reduce_mean(epoch.duration()),
+'dist_step_time': self.reduce_mean(epoch.step_time.avg),
+'dist_end_lr': self.reduce_mean(lr)})
 if test_epoch:
 with logger.pushd('val'):
 logger.info({'top1': test_epoch.top1.avg,
 'top5': test_epoch.top5.avg,
 'loss': test_epoch.loss.avg,
 'duration': epoch.duration()})
-if is_dist():
-logger.info({'dist_top1': reduce_mean(test_epoch.top1.avg),
-'dist_top5': reduce_mean(test_epoch.top5.avg),
-'dist_loss': reduce_mean(test_epoch.loss.avg),
-'dist_duration': reduce_mean(test_epoch.duration())})
+if self.is_dist():
+logger.info({'dist_top1': self.reduce_mean(test_epoch.top1.avg),
+'dist_top5': self.reduce_mean(test_epoch.top5.avg),
+'dist_loss': self.reduce_mean(test_epoch.loss.avg),
+'dist_duration': self.reduce_mean(test_epoch.duration())})

 # writer = get_tb_writer()
 # writer.add_scalar(f'{self._tb_path}/train_epochs/loss',
@@ -181,9 +183,14 @@ class Metrics:
 return utils.state_dict(self)

 def load_state_dict(self, state_dict:dict)->None:
-# simply convert current object to dictionary
 utils.load_state_dict(self, state_dict)

+def __getstate__(self):
+state = self.__dict__.copy()
+del state['_apex'] # cannot serialize this
+return state
+# no need to define __setstate__ because _apex should be set from constructor
+
 def save(self, filepath:str)->Optional[str]:
 if filepath:
 filepath = utils.full_path(filepath)
@@ -197,6 +204,27 @@ class Metrics:
 def cur_epoch(self)->'EpochMetrics':
 return self.run_metrics.cur_epoch()

+def reduce_min(self, val):
+if not self._apex:
+return val
+return self._apex.reduce(val, op='min')
+def reduce_max(self, val):
+if not self._apex:
+return val
+return self._apex.reduce(val, op='max')
+def reduce_sum(self, val):
+if not self._apex:
+return val
+return self._apex.reduce(val, op='sum')
+def reduce_mean(self, val):
+if not self._apex:
+return val
+return self._apex.reduce(val, op='mean')
+def is_dist(self)->bool:
+if not self._apex:
+return False
+return self._apex.is_dist()
+

 class Accumulator:
 # TODO: replace this with Metrics class
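Metrics now carries its own (optional) ApexUtils: reductions degrade to the local value when no apex object is given, and __getstate__ drops the handle so the object can still be serialized. A self-contained sketch of that behaviour (MetricsSketch is a stand-in, not the real class):

    import pickle

    class MetricsSketch:
        def __init__(self, title: str, apex=None) -> None:
            self.title = title
            self._apex = apex

        def reduce_mean(self, val):
            # without a distributed ApexUtils the value is returned unchanged
            if not self._apex:
                return val
            return self._apex.reduce(val, op='mean')

        def __getstate__(self):
            state = self.__dict__.copy()
            del state['_apex']   # cannot serialize this
            return state

    m = MetricsSketch('search_train')
    print(m.reduce_mean(0.93))                      # 0.93: identity in the single-process case
    print(pickle.loads(pickle.dumps(m)).title)      # round-trips because '_apex' was dropped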
@@ -9,18 +9,18 @@ from overrides import EnforceOverrides
 from .metrics import Metrics
 from .config import Config
 from . import utils, ml_utils
-from .common import logger, get_device
-from archai.common.common import get_apex_utils
+from .common import logger
+from archai.common.apex_utils import ApexUtils

 class Tester(EnforceOverrides):
-def __init__(self, conf_eval:Config, model:nn.Module)->None:
-# TODO: currently we expect that given model and dataloader will already be distributed
+def __init__(self, conf_eval:Config, model:nn.Module, apex:ApexUtils)->None:
 self._title = conf_eval['title']
 self._logger_freq = conf_eval['logger_freq']
 conf_lossfn = conf_eval['lossfn']

+self._apex = apex
 self.model = model
-self._lossfn = ml_utils.get_lossfn(conf_lossfn).to(get_device())
+self._lossfn = ml_utils.get_lossfn(conf_lossfn).to(apex.device)
 self._metrics = None

 def test(self, test_dl: DataLoader)->Metrics:
@@ -43,7 +43,7 @@ class Tester(EnforceOverrides):

 with torch.no_grad(), logger.pushd('steps'):
 for step, (x, y) in enumerate(test_dl):
-x, y = x.to(get_device(), non_blocking=True), y.to(get_device(), non_blocking=True)
+x, y = x.to(self._apex.device, non_blocking=True), y.to(self._apex.device, non_blocking=True)

 assert not self.model.training # derived class might alter the mode
 logger.pushd(step)
@@ -57,7 +57,7 @@ class Tester(EnforceOverrides):
 self._post_step(x, y, logits, loss, steps, self._metrics)

 # TODO: we possibly need to sync so all replicas are upto date
-get_apex_utils().sync_devices()
+self._apex.sync_devices()

 logger.popd()
 self._metrics.post_epoch(None)
@@ -87,5 +87,5 @@ class Tester(EnforceOverrides):
 metrics.post_step(x, y, logits, loss, steps)

 def _create_metrics(self)->Metrics:
-return Metrics(self._title, logger_freq=self._logger_freq)
+return Metrics(self._title, self._apex, logger_freq=self._logger_freq)

@@ -11,8 +11,9 @@ from .metrics import Metrics
 from .tester import Tester
 from .config import Config
 from . import utils, ml_utils
-from ..common.common import logger, get_device, get_apex_utils
+from ..common.common import logger
 from ..common.checkpoint import CheckPoint
+from ..common.apex_utils import ApexUtils


 class Trainer(EnforceOverrides):
@@ -33,13 +34,13 @@ class Trainer(EnforceOverrides):
 self._validation_freq = 0 if conf_validation is None else conf_validation['freq']
 # endregion

-get_apex_utils().reset(logger, conf_apex)
+self._apex = ApexUtils(conf_apex, logger)

 self._checkpoint = checkpoint
 self.model = model

 self._lossfn = ml_utils.get_lossfn(conf_lossfn)
-self._tester = Tester(conf_validation, model) \
+self._tester = Tester(conf_validation, model, self._apex) \
 if conf_validation else None
 self._metrics:Optional[Metrics] = None

@@ -52,7 +53,7 @@ class Trainer(EnforceOverrides):
 def fit(self, train_dl:DataLoader, val_dl:Optional[DataLoader])->Metrics:
 logger.pushd(self._title)

-self._metrics = Metrics(self._title, logger_freq=self._logger_freq)
+self._metrics = Metrics(self._title, self._apex, logger_freq=self._logger_freq)

 # optimizers, schedulers needs to be recreated for each fit call
 # as they have state specific to each run
@@ -60,10 +61,10 @@ class Trainer(EnforceOverrides):
 # create scheduler for optim before applying amp
 self._sched, self._sched_on_epoch = self._create_scheduler(optim, len(train_dl))
 # before checkpoint restore, convert to amp
-self.model, self._optim = get_apex_utils().to_amp(self.model, optim,
+self.model, self._optim = self._apex.to_amp(self.model, optim,
 batch_size=train_dl.batch_size)

-self._lossfn = self._lossfn.to(get_device())
+self._lossfn = self._lossfn.to(self.get_device())

 self.pre_fit(train_dl, val_dl)

@@ -158,7 +159,7 @@ class Trainer(EnforceOverrides):
 self._metrics.post_epoch(val_metrics, lr=self._optim.param_groups[0]['lr'])

 # checkpoint if enabled with given freq or if this is the last epoch
-if self._checkpoint is not None and get_apex_utils().is_master() and \
+if self._checkpoint is not None and self._apex.is_master() and \
 self._checkpoint.freq > 0 and (self._metrics.epochs() % self._checkpoint.freq == 0 or \
 self._metrics.epochs() >= self._epochs):
 self._checkpoint.new()
@@ -173,6 +174,9 @@ class Trainer(EnforceOverrides):
 self._metrics.post_step(x, y, logits, loss, steps)
 ######################### hooks #########################

+def get_device(self):
+return self._apex.device
+
 def restore_checkpoint(self)->None:
 state = self._checkpoint['trainer']
 last_epoch = state['last_epoch']
@@ -180,7 +184,7 @@ class Trainer(EnforceOverrides):

 self._metrics.load_state_dict(state['metrics'])
 assert self._metrics.epochs() == last_epoch+1
-get_apex_utils().load_state_dict(state['amp'])
+self._apex.load_state_dict(state['amp'])
 self.model.load_state_dict(state['model'])
 self._optim.load_state_dict(state['optim'])
 if self._sched:
@@ -198,7 +202,7 @@ class Trainer(EnforceOverrides):
 'model': self.model.state_dict(),
 'optim': self._optim.state_dict(),
 'sched': self._sched.state_dict() if self._sched else None,
-'amp': get_apex_utils().state_dict()
+'amp': self._apex.state_dict()
 }
 self._checkpoint['trainer'] = state

@@ -208,7 +212,7 @@ class Trainer(EnforceOverrides):

 logger.pushd('steps')
 for step, (x, y) in enumerate(train_dl):
-x, y = x.to(get_device(), non_blocking=True), y.to(get_device(), non_blocking=True)
+x, y = x.to(self.get_device(), non_blocking=True), y.to(self.get_device(), non_blocking=True)

 logger.pushd(step)
 assert self.model.training # derived class might alter the mode
@@ -226,15 +230,15 @@ class Trainer(EnforceOverrides):
 loss = self.compute_loss(self._lossfn, x, y, logits,
 self._aux_weight, aux_logits)

-get_apex_utils().backward(loss, self._optim)
+self._apex.backward(loss, self._optim)

 # TODO: original darts clips alphas as well but pt.darts doesn't
-get_apex_utils().clip_grad(self._grad_clip, self.model, self._optim)
+self._apex.clip_grad(self._grad_clip, self.model, self._optim)

 self._optim.step()

 # TODO: we possibly need to sync so all replicas are upto date
-get_apex_utils().sync_devices()
+self._apex.sync_devices()

 if self._sched and not self._sched_on_epoch:
 self._sched.step()
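Taken together, the Trainer hunks make the trainer the owner of apex state: it constructs ApexUtils from its own apex config section and hands the same instance to Tester and Metrics. A rough sketch of that constructor wiring, with the archai classes passed in as parameters since they are not reproduced here:

    def wire_trainer(conf_apex, conf_validation, model, logger,
                     ApexUtils, Tester, Metrics, title, logger_freq):
        apex = ApexUtils(conf_apex, logger)                           # per-trainer instance
        tester = Tester(conf_validation, model, apex) if conf_validation else None
        metrics = Metrics(title, apex, logger_freq=logger_freq)
        return apex, tester, metrics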
@@ -55,7 +55,7 @@ def deep_update(d:MutableMapping, u:Mapping, map_type:Type[MutableMapping]=dict)
 return d

 def state_dict(val)->Mapping:
-assert hasattr(val, '__dict__'), 'val must be object with __dict__'
+assert hasattr(val, '__dict__'), 'val must be object with __dict__ otherwise it cannot be loaded back in load_state_dict'

 # Can't do below because val has state_dict() which calls utils.state_dict
 # if has_method(val, 'state_dict'):
@@ -76,8 +76,8 @@ def load_state_dict(val:Any, state_dict:Mapping)->None:
 assert s is not None, 'state_dict must contain yaml key'

 obj = yaml.load(s, Loader=yaml.Loader)
-for k in val.__dict__.keys():
-setattr(val, k, getattr(obj, k))
+for k, v in obj.__dict__.items():
+setattr(val, k, v)

 def deep_comp(o1:Any, o2:Any)->bool:
 # NOTE: dict don't have __dict__
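The load_state_dict change matters once objects like Metrics delete attributes in __getstate__: iterating the target's keys would try to read '_apex' from the restored object and fail, so the loop now copies only what actually survived serialization. A tiny illustration (the attribute names mirror the Metrics hunk; the classes are toy stand-ins):

    class Obj:
        pass

    restored = Obj()
    restored.title = 'train'                 # '_apex' was dropped by __getstate__ before dumping

    target = Obj()
    target.title = ''
    target._apex = object()                  # live handle that must not be overwritten

    # old loop: for k in target.__dict__: getattr(restored, k) -> AttributeError on '_apex'
    # new loop: copy only the surviving keys
    for k, v in restored.__dict__.items():
        setattr(target, k, v)
    print(target.title)                      # 'train'; target._apex is left as-is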
@@ -19,8 +19,8 @@ common:
 redis: null
 apex: # this is overriden in search and eval individually
 enabled: False # global switch to disable anything apex
-distributed_enabled: False # enable/disable distributed mode
-mixed_prec_enabled: False # switch to disable amp mixed precision
+distributed_enabled: True # enable/disable distributed mode
+mixed_prec_enabled: True # switch to disable amp mixed precision
 gpus: '' # use GPU IDs specified here (comma separated), if '' then use all GPUs
 opt_level: 'O2' # optimization level for mixed precision
 bn_fp32: True # keep BN in fp32
@@ -88,6 +88,7 @@ nas:
 trainer:
 apex:
 _copy: 'common/apex'
+enabled: True
 aux_weight: '_copy: nas/eval/model_desc/aux_weight'
 drop_path_prob: 0.2 # probability that given edge will be dropped
 grad_clip: 5.0 # grads above this value is clipped
@@ -1,10 +1,27 @@
+from collections import UserDict
 import yaml
+from typing import Iterator

-y = """
-a: .NaN
-
-"""
-
-d=yaml.load(y, Loader=yaml.Loader)
+class A(object):
+def __init__(self):
+self.hidden = 42
+self.visible = 5
+
+def __getstate__(self):
+state = self.__dict__.copy()
+del state['hidden'] # cannot serialize this
+return state
+
+a = A()
+d = yaml.dump(a)
 print(d)
-print(type( d['a']))
+
+# y = """
+# a: .NaN
+
+# """
+
+# d=yaml.load(y, Loader=yaml.Loader)
+# print(d)
+# print(type( d['a']))
@@ -5,7 +5,7 @@ from archai import cifar10_models

 from archai.common.trainer import Trainer
 from archai.common.config import Config
-from archai.common.common import logger, common_init, get_device
+from archai.common.common import logger, common_init
 from archai.datasets import data

 def train_test(conf_eval:Config):
@@ -24,7 +24,7 @@ def train_test(conf_eval:Config):
 conf_trainer['aux_weight'] = 0.0

 Net = cifar10_models.resnet34
-model = Net().to(get_device())
+model = Net().to(torch.device('cuda', 0))

 # get data
 train_dl, _, test_dl = data.get_data(conf_loader)