Separate apex config for search and eval

Shital Shah 2020-04-22 14:05:23 -07:00
Parent aa2f8fafca
Commit ee0540d0d0
5 changed files: 77 additions and 92 deletions

View file

@@ -8,14 +8,29 @@ from torch import Tensor, nn
from torch.backends import cudnn
import torch.distributed as dist
from archai.common.config import Config
import psutil
from archai.common.config import Config
from archai.common import ml_utils, utils
from archai.common.ordereddict_logger import OrderedDictLogger
class ApexUtils:
def __init__(self, distdir:Optional[str], apex_config:Config)->None:
logger = self._create_init_logger(distdir)
def __init__(self)->None:
self._amp = self._ddp = None
self._op_map = {'mean': dist.ReduceOp.SUM, 'sum': dist.ReduceOp.SUM,
'min': dist.ReduceOp.MIN, 'max': dist.ReduceOp.MAX}
self._mixed_prec_enabled = False
self._distributed_enabled = False
self._world_size = 1 # total number of processes in distributed run
self.local_rank = 0
self.global_rank = 0
def reset(self, logger:OrderedDictLogger, apex_config:Config)->None:
# reset allows configuring apex differently for the search and eval modes
# to avoid circular references with common, logger is passed in from outside
self.logger = logger
# region conf vars
self._enabled = apex_config['enabled'] # global switch to disable anything apex
@@ -27,6 +42,8 @@ class ApexUtils:
self._sync_bn = apex_config['sync_bn'] # whether to replace BNs with sync BNs for distributed model
self._scale_lr = apex_config['scale_lr'] # scale learning rate by world size in distributed mode
self._min_world_size = apex_config['min_world_size'] # allows confirming we are indeed in a distributed setting
seed = apex_config['seed']
detect_anomaly = apex_config['detect_anomaly']
conf_gpu_ids = apex_config['gpus']
# endregion
@@ -58,11 +75,9 @@ class ApexUtils:
assert dist.is_available() # distributed module is available
assert dist.is_nccl_available()
dist.init_process_group(backend='nccl', init_method='env://')
assert dist.is_initialized()
self._op_map = {'mean': dist.ReduceOp.SUM, 'sum': dist.ReduceOp.SUM,
'min': dist.ReduceOp.MIN, 'max': dist.ReduceOp.MAX}
if not dist.is_initialized():
dist.init_process_group(backend='nccl', init_method='env://')
assert dist.is_initialized()
self._set_ranks()
@@ -74,6 +89,7 @@ class ApexUtils:
assert self._gpu < torch.cuda.device_count()
torch.cuda.set_device(self._gpu)
self.device = torch.device('cuda', self._gpu)
self._setup_gpus(seed, detect_anomaly)
logger.info({'amp_available': self._amp is not None,
'distributed_available': self._ddp is not None})
@@ -82,10 +98,33 @@ class ApexUtils:
'gpu': self._gpu, 'gpu_ids':self.gpu_ids,
'local_rank': self.local_rank})
logger.info({})
logger.close()
def _setup_gpus(self, seed:float, detect_anomaly:bool):
utils.setup_cuda(seed, self.local_rank)
torch.autograd.set_detect_anomaly(detect_anomaly)
self.logger.info({'set_detect_anomaly': detect_anomaly,
'is_anomaly_enabled': torch.is_anomaly_enabled()})
self.logger.info({'gpu_names': utils.cuda_device_names(),
'gpu_count': torch.cuda.device_count(),
'CUDA_VISIBLE_DEVICES': os.environ['CUDA_VISIBLE_DEVICES']
if 'CUDA_VISIBLE_DEVICES' in os.environ else 'NotSet',
'cudnn.enabled': cudnn.enabled,
'cudnn.benchmark': cudnn.benchmark,
'cudnn.deterministic': cudnn.deterministic,
'cudnn.version': cudnn.version()
})
self.logger.info({'memory': str(psutil.virtual_memory())})
self.logger.info({'CPUs': str(psutil.cpu_count())})
# gpu_usage = os.popen(
# 'nvidia-smi --query-gpu=memory.total,memory.used --format=csv,nounits,noheader'
# ).read().split('\n')
# for i, line in enumerate(gpu_usage):
# vals = line.split(',')
# if len(vals) == 2:
# logger.info('GPU {} mem: {}, used: {}'.format(i, vals[0], vals[1]))
def _set_ranks(self)->None:
if 'WORLD_SIZE' in os.environ:
@@ -112,35 +151,10 @@ class ApexUtils:
assert len(self.gpu_ids) > self.local_rank
self._gpu = self.gpu_ids[self.local_rank]
def _create_init_logger(self, distdir:Optional[str])->OrderedDictLogger:
# create PID specific logger to support many distributed processes
init_log_filepath, yaml_log_filepath = None, None
if distdir:
init_log_filepath = os.path.join(utils.full_path(distdir),
'apex_' + str(os.getpid()) + '.log')
yaml_log_filepath = os.path.join(utils.full_path(distdir),
'apex_' + str(os.getpid()) + '.yaml')
sys_logger = utils.create_logger(filepath=init_log_filepath)
if not init_log_filepath:
sys_logger.warn('logdir not specified, no logs will be created or any models saved')
logger = OrderedDictLogger(yaml_log_filepath, sys_logger)
return logger
def set_replica_logger(self, logger:OrderedDictLogger)->None:
# To avoid circular dependency we don't reference logger in common.
# Furthermore, each replica has its own logger but sharing same exp directory.
# We can't create replica specific logger at time of init so this is set later.
self.logger = logger
def is_mixed(self)->bool:
return self._amp is not None
return self._mixed_prec_enabled
def is_dist(self)->bool:
return self._ddp is not None
return self._distributed_enabled
def is_master(self)->bool:
return self.global_rank == 0
@@ -164,7 +178,7 @@ class ApexUtils:
return val
def backward(self, loss:torch.Tensor, optim:Optimizer)->None:
if self._amp:
if self._mixed_prec_enabled:
with self._amp.scale_loss(loss, optim) as scaled_loss:
scaled_loss.backward()
else:
@@ -173,13 +187,13 @@ class ApexUtils:
def to_amp(self, model:nn.Module, optim:Optimizer, batch_size:int)\
->Tuple[nn.Module, Optimizer]:
# convert BNs to sync BNs in distributed mode
if self._ddp and self._sync_bn:
if self._distributed_enabled and self._sync_bn:
model = self._ddp.convert_syncbn_model(model)
self.logger.info({'BNs_converted': True})
model = model.to(self.device)
if self._amp:
if self._mixed_prec_enabled:
# scale LR
if self._scale_lr:
lr = ml_utils.get_optim_lr(optim)
@@ -192,7 +206,7 @@ class ApexUtils:
keep_batchnorm_fp32=self._bn_fp32, loss_scale=self._loss_scale
)
if self._ddp:
if self._distributed_enabled:
# By default, apex.parallel.DistributedDataParallel overlaps communication with
# computation in the backward pass.
# delay_allreduce delays all communication to the end of the backward pass.
@@ -202,19 +216,19 @@ class ApexUtils:
def clip_grad(self, clip:float, model:nn.Module, optim:Optimizer)->None:
if clip > 0.0:
if self._amp:
if self._mixed_prec_enabled:
nn.utils.clip_grad_norm_(self._amp.master_params(optim), clip)
else:
nn.utils.clip_grad_norm_(model.parameters(), clip)
def state_dict(self):
if self._amp:
if self._mixed_prec_enabled:
return self._amp.state_dict()
else:
return None
def load_state_dict(self, state_dict):
if self._amp:
if self._mixed_prec_enabled:
self._amp.load_state_dict(state_dict)
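Taken together, these changes split ApexUtils construction from its configuration: __init__ now only sets benign defaults, while reset() applies a mode-specific apex config together with an externally owned logger, so the same object can be reconfigured for search and then for eval. A minimal usage sketch of that lifecycle (import path assumed from the repo layout; conf_search/conf_eval are the Config nodes read elsewhere in this commit):

from archai.common.apex_utils import ApexUtils

apex = ApexUtils()                       # cheap: no GPU, seed or distributed setup happens yet
# later, once the experiment logger and the run mode are known:
apex.reset(logger, conf_search['apex'])  # configure for the search run
# an eval run can reconfigure the very same object with different settings:
apex.reset(logger, conf_eval['apex'])

Note that is_mixed() and is_dist() now report the mixed_prec_enabled / distributed_enabled config flags rather than checking whether self._amp / self._ddp were created, so they can be queried before amp.initialize has run.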

View file

@@ -30,7 +30,7 @@ class SummaryWriterDummy:
SummaryWriterAny = Union[SummaryWriterDummy, SummaryWriter]
logger = OrderedDictLogger(None, None)
_tb_writer: SummaryWriterAny = None
_apex_utils = None
_apex_utils = ApexUtils()
_atexit_reg = False # is hook for atexit registered?
def get_conf()->Config:
@@ -138,12 +138,8 @@ def common_init(config_filepath: Optional[str]=None,
logger.info({'expdir': expdir,
'PT_DATA_DIR': pt_data_dir, 'PT_OUTPUT_DIR': pt_output_dir})
# set up amp, apex, mixed-prec, distributed training stubs
_setup_apex()
# create global logger
_setup_logger()
# init GPU settings
_setup_gpus()
# create info file for current system
_create_sysinfo(conf)
@@ -249,10 +245,6 @@ def _setup_logger():
sys_logger.warn(
'logdir not specified, no logs will be created or any models saved')
# We need to create ApexUtils before we have logger. Now that we have logger
# lets give it to ApexUtils
get_apex_utils().set_replica_logger(logger)
# reset to new file path
logger.reset(logs_yaml_filepath, sys_logger)
logger.info({
@@ -263,40 +255,7 @@ def _setup_logger():
'sys_log_filepath': sys_log_filepath
})
def _setup_apex():
conf_common = get_conf_common()
distdir = conf_common['distdir']
global _apex_utils
_apex_utils = ApexUtils(distdir, conf_common['apex'])
def _setup_gpus():
conf_common = get_conf_common()
utils.setup_cuda(conf_common['seed'], get_apex_utils().local_rank)
if conf_common['detect_anomaly']:
logger.warn({'set_detect_anomaly':True})
torch.autograd.set_detect_anomaly(True)
logger.info({'gpu_names': utils.cuda_device_names(),
'gpu_count': torch.cuda.device_count(),
'CUDA_VISIBLE_DEVICES': os.environ['CUDA_VISIBLE_DEVICES']
if 'CUDA_VISIBLE_DEVICES' in os.environ else 'NotSet',
'cudnn.enabled': cudnn.enabled,
'cudnn.benchmark': cudnn.benchmark,
'cudnn.deterministic': cudnn.deterministic,
'cudnn.version': cudnn.version()
})
logger.info({'memory': str(psutil.virtual_memory())})
logger.info({'CPUs': str(psutil.cpu_count())})
# gpu_usage = os.popen(
# 'nvidia-smi --query-gpu=memory.total,memory.used --format=csv,nounits,noheader'
# ).read().split('\n')
# for i, line in enumerate(gpu_usage):
# vals = line.split(',')
# if len(vals) == 2:
# logger.info('GPU {} mem: {}, used: {}'.format(i, vals[0], vals[1]))
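In common.py the per-run setup is correspondingly simplified: a single module-level ApexUtils is created eagerly, _setup_apex and _setup_gpus disappear from common_init, and the seed/cudnn/anomaly logging they performed now runs inside ApexUtils._setup_gpus whenever reset() is called. The accessor itself is not shown in this hunk, so the following is only an approximation of the pattern:

_apex_utils = ApexUtils()   # module-level instance, configured later via reset()

def get_apex_utils()->ApexUtils:
    global _apex_utils
    assert _apex_utils is not None
    return _apex_utils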

View file

@@ -7,7 +7,7 @@ from torch import nn
from archai.common.trainer import Trainer
from archai.common.config import Config
from archai.common.common import logger
from archai.common.common import logger, get_apex_utils
from archai.datasets import data
from archai.nas.model_desc import ModelDesc
from archai.nas.cell_builder import CellBuilder
@@ -24,8 +24,11 @@ def eval_arch(conf_eval:Config, cell_builder:Optional[CellBuilder]):
conf_checkpoint = conf_eval['checkpoint']
resume = conf_eval['resume']
conf_train = conf_eval['trainer']
conf_apex = conf_eval['apex']
# endregion
get_apex_utils().reset(logger, conf_apex)
if cell_builder:
cell_builder.register_ops()
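For context, here is an illustrative training step (not code from the repo; model, optim, train_dl, loss_fn, batch_size and the clip value are placeholders) showing how the apex section that eval_arch passes to reset() ends up mattering downstream, since the Trainer funnels mixed precision and gradient clipping through the ApexUtils API changed in the first file:

apex = get_apex_utils()
model, optim = apex.to_amp(model, optim, batch_size=batch_size)  # no-op wrappers if amp/ddp are disabled
for x, y in train_dl:
    x, y = x.to(apex.device), y.to(apex.device)
    optim.zero_grad()
    loss = loss_fn(model(x), y)
    apex.backward(loss, optim)          # uses amp.scale_loss only when mixed precision is enabled
    apex.clip_grad(5.0, model, optim)   # clips amp master params when enabled, raw params otherwise
    optim.step()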

View file

@@ -8,7 +8,7 @@ import tensorwatch as tw
from torch.utils.data.dataloader import DataLoader
import yaml
from archai.common.common import logger
from archai.common.common import logger, get_apex_utils
from archai.common.checkpoint import CheckPoint
from archai.common.config import Config
from .cell_builder import CellBuilder
@@ -108,6 +108,7 @@ class Search:
self.search_iters = conf_search['search_iters']
self.pareto_enabled = conf_pareto['enabled']
pareto_summary_filename = conf_pareto['summary_filename']
conf_apex = conf_search['apex']
# endregion
self.cell_builder = cell_builder
@@ -116,6 +117,8 @@ class Search:
self._parito_filepath = utils.full_path(pareto_summary_filename)
self._checkpoint = nas_utils.create_checkpoint(conf_checkpoint, resume)
get_apex_utils().reset(logger, conf_apex)
logger.info({'pareto_enabled': self.pareto_enabled,
'base_reductions': self.base_reductions,
'base_cells': self.base_cells,

View file

@@ -10,15 +10,15 @@ common:
checkpoint:
filename: '$expdir/checkpoint.pth'
freq: 10
detect_anomaly: False # if True, PyTorch code will run 6X slower
# TODO: workers setting
# reddis address of Ray cluster. Use None for single node run
# otherwise it should something like host:6379. Make sure to run on head node:
# "ray start --head --redis-port=6379"
redis: null
apex:
enabled: True # global switch to disable anything apex
apex: # this is overridden individually in search and eval
enabled: False # global switch to disable anything apex
distributed_enabled: False # enable/disable distributed mode
mixed_prec_enabled: False # enable/disable mixed precision with apex amp
gpus: '' # use GPU IDs specified here (comma separated), if '' then use all GPUs
@@ -28,6 +28,8 @@ common:
sync_bn: False # whether to replace BNs with sync BNs for distributed model
scale_lr: True # scale learning rate by world size in distributed mode
min_world_size: 0 # allows confirming we are indeed in a distributed setting
detect_anomaly: False # if True, PyTorch code will run 6X slower
seed: '_copy: common/seed'
smoke_test: False
only_eval: False
@@ -48,6 +50,8 @@ nas:
metric_filename: '$expdir/eval_train_metrics.yaml'
model_filename: '$expdir/model.pt' # file to which trained model will be saved
data_parallel: False
apex:
_copy: 'common/apex'
checkpoint:
_copy: 'common/checkpoint'
resume: '_copy: common/resume'
@@ -118,6 +122,8 @@ nas:
data_parallel: False
checkpoint:
_copy: 'common/checkpoint'
apex:
_copy: 'common/apex'
resume: '_copy: common/resume'
search_iters: 1
full_desc_filename: '$expdir/full_model_desc.yaml' # arch before it was finalized
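The config change mirrors the code: common/apex now holds the shared defaults, and nas/search/apex and nas/eval/apex each take a copy via _copy, so either mode can be tuned without affecting the other. How _copy interacts with per-node overrides is specific to archai's Config loader, so the snippet below is only a plain-Python illustration of the intent (the keys mirror the yaml above; the eval overrides are hypothetical):

import copy

common_apex = {'enabled': False, 'distributed_enabled': False, 'mixed_prec_enabled': False}
search_apex = copy.deepcopy(common_apex)   # effect of nas/search/apex: _copy 'common/apex'
eval_apex = copy.deepcopy(common_apex)     # effect of nas/eval/apex: _copy 'common/apex'
eval_apex['enabled'] = True                # hypothetical: turn apex on for eval only
eval_apex['mixed_prec_enabled'] = True
assert search_apex['mixed_prec_enabled'] is False  # search settings stay untouched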