Mirror of https://github.com/microsoft/archai.git

Separate apex config for search and eval

Parent: aa2f8fafca
Commit: ee0540d0d0
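This commit splits apex configuration so that NAS search and final evaluation can each supply their own apex section: ApexUtils becomes a module-level instance that is configured on demand through a new reset() method instead of once at construction. A minimal sketch of the intended lifecycle (the phase-running function is illustrative, not part of this commit):

    from archai.common.config import Config
    from archai.common.common import logger, get_apex_utils

    def run_phase(conf_phase:Config)->None:
        # each phase re-applies its own apex section to the shared instance
        get_apex_utils().reset(logger, conf_phase['apex'])
        # ... build model/optimizer, then e.g.:
        # model, optim = get_apex_utils().to_amp(model, optim, batch_size)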
@@ -8,14 +8,29 @@ from torch import Tensor, nn
 from torch.backends import cudnn
 import torch.distributed as dist
 
-from archai.common.config import Config
+import psutil
+
+from archai.common.config import Config
 from archai.common import ml_utils, utils
 from archai.common.ordereddict_logger import OrderedDictLogger
 
 class ApexUtils:
-    def __init__(self, distdir:Optional[str], apex_config:Config)->None:
-        logger = self._create_init_logger(distdir)
+    def __init__(self)->None:
+        self._amp = self._ddp = None
+        self._op_map = {'mean': dist.ReduceOp.SUM, 'sum': dist.ReduceOp.SUM,
+                        'min': dist.ReduceOp.MIN, 'max': dist.ReduceOp.MAX}
+
+        self._mixed_prec_enabled = False
+        self._distributed_enabled = False
+        self._world_size = 1 # total number of processes in distributed run
+        self.local_rank = 0
+        self.global_rank = 0
+
+    def reset(self, logger:OrderedDictLogger, apex_config:Config)->None:
+        # reset allows to configure differently for search or eval modes
+
+        # to avoid circular references with common, logger is passed from outside
+        self.logger = logger
+
         # region conf vars
         self._enabled = apex_config['enabled'] # global switch to disable anything apex
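Note: _op_map maps 'mean' to ReduceOp.SUM on purpose; NCCL has no mean reduction, so a mean is presumably computed as an all-reduce sum followed by division by world size elsewhere in ApexUtils. A hedged sketch of that convention in plain torch.distributed:

    import torch
    import torch.distributed as dist

    def reduce_mean(val:torch.Tensor)->torch.Tensor:
        # NCCL offers SUM/MIN/MAX but no MEAN: sum across ranks, then divide
        dist.all_reduce(val, op=dist.ReduceOp.SUM)
        return val / dist.get_world_size()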
@@ -27,6 +42,8 @@ class ApexUtils:
         self._sync_bn = apex_config['sync_bn'] # whether to replace BNs with sync BNs for distributed model
         self._scale_lr = apex_config['scale_lr'] # scale learning rate by world size in distributed mode
         self._min_world_size = apex_config['min_world_size'] # allows to confirm we are indeed in distributed setting
+        seed = apex_config['seed']
+        detect_anomaly = apex_config['detect_anomaly']
         conf_gpu_ids = apex_config['gpus']
         # endregion
 
@@ -58,11 +75,9 @@ class ApexUtils:
 
         assert dist.is_available() # distributed module is available
         assert dist.is_nccl_available()
-        dist.init_process_group(backend='nccl', init_method='env://')
-        assert dist.is_initialized()
-
-        self._op_map = {'mean': dist.ReduceOp.SUM, 'sum': dist.ReduceOp.SUM,
-                        'min': dist.ReduceOp.MIN, 'max': dist.ReduceOp.MAX}
+        if not dist.is_initialized():
+            dist.init_process_group(backend='nccl', init_method='env://')
+        assert dist.is_initialized()
 
         self._set_ranks()
 
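Note: guarding init_process_group with is_initialized() makes reset() safe to call more than once in the same process, which matters now that search and eval may each call it; PyTorch allows process-group initialization only once. With the env:// method, rendezvous comes from environment variables set by the launcher; a hedged sketch:

    import torch.distributed as dist

    def ensure_process_group()->None:
        # env:// reads MASTER_ADDR, MASTER_PORT, RANK and WORLD_SIZE, which are
        # normally exported by torch.distributed.launch or a similar launcher
        if not dist.is_initialized():
            dist.init_process_group(backend='nccl', init_method='env://')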
@@ -74,6 +89,7 @@ class ApexUtils:
         assert self._gpu < torch.cuda.device_count()
         torch.cuda.set_device(self._gpu)
         self.device = torch.device('cuda', self._gpu)
+        self._setup_gpus(seed, detect_anomaly)
 
         logger.info({'amp_available': self._amp is not None,
                      'distributed_available': self._ddp is not None})
@@ -82,10 +98,33 @@ class ApexUtils:
                      'gpu': self._gpu, 'gpu_ids':self.gpu_ids,
                      'local_rank': self.local_rank})
 
-        logger.info({})
-        logger.close()
+    def _setup_gpus(self, seed:float, detect_anomaly:bool):
+        utils.setup_cuda(seed, self.local_rank)
+
+        torch.autograd.set_detect_anomaly(detect_anomaly)
+        self.logger.info({'set_detect_anomaly': detect_anomaly,
+                          'is_anomaly_enabled': torch.is_anomaly_enabled()})
+
+        self.logger.info({'gpu_names': utils.cuda_device_names(),
+                          'gpu_count': torch.cuda.device_count(),
+                          'CUDA_VISIBLE_DEVICES': os.environ['CUDA_VISIBLE_DEVICES']
+                              if 'CUDA_VISIBLE_DEVICES' in os.environ else 'NotSet',
+                          'cudnn.enabled': cudnn.enabled,
+                          'cudnn.benchmark': cudnn.benchmark,
+                          'cudnn.deterministic': cudnn.deterministic,
+                          'cudnn.version': cudnn.version()
+                          })
+        self.logger.info({'memory': str(psutil.virtual_memory())})
+        self.logger.info({'CPUs': str(psutil.cpu_count())})
+
+        # gpu_usage = os.popen(
+        #     'nvidia-smi --query-gpu=memory.total,memory.used --format=csv,nounits,noheader'
+        # ).read().split('\n')
+        # for i, line in enumerate(gpu_usage):
+        #     vals = line.split(',')
+        #     if len(vals) == 2:
+        #         logger.info('GPU {} mem: {}, used: {}'.format(i, vals[0], vals[1]))
 
     def _set_ranks(self)->None:
         if 'WORLD_SIZE' in os.environ:
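Note: GPU setup (seeding, anomaly detection, cudnn and host diagnostics) moves from common._setup_gpus into ApexUtils so that it runs on every reset with the phase's own seed. utils.setup_cuda is not shown in this diff; presumably it seeds the RNGs per rank and sets the cudnn flags logged above. A hedged sketch of such a helper (details are assumptions, not archai's exact code):

    import random
    import numpy as np
    import torch
    from torch.backends import cudnn

    def setup_cuda(seed:float, local_rank:int)->None:
        seed = int(seed) + local_rank # decorrelate RNG streams across ranks
        random.seed(seed)
        np.random.seed(seed)
        torch.manual_seed(seed)
        torch.cuda.manual_seed_all(seed)
        cudnn.benchmark = True # fast for fixed-size inputs; trades off determinism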
@@ -112,35 +151,10 @@ class ApexUtils:
         assert len(self.gpu_ids) > self.local_rank
         self._gpu = self.gpu_ids[self.local_rank]
 
-
-    def _create_init_logger(self, distdir:Optional[str])->OrderedDictLogger:
-        # create PID specific logger to support many distributed processes
-        init_log_filepath, yaml_log_filepath = None, None
-        if distdir:
-            init_log_filepath = os.path.join(utils.full_path(distdir),
-                                             'apex_' + str(os.getpid()) + '.log')
-            yaml_log_filepath = os.path.join(utils.full_path(distdir),
-                                             'apex_' + str(os.getpid()) + '.yaml')
-
-        sys_logger = utils.create_logger(filepath=init_log_filepath)
-        if not init_log_filepath:
-            sys_logger.warn('logdir not specified, no logs will be created or any models saved')
-
-        logger = OrderedDictLogger(yaml_log_filepath, sys_logger)
-
-        return logger
-
-    def set_replica_logger(self, logger:OrderedDictLogger)->None:
-        # To avoid circular dependency we don't reference logger in common.
-        # Furthermore, each replica has its own logger but sharing same exp directory.
-        # We can't create replica specific logger at time of init so this is set later.
-        self.logger = logger
-
     def is_mixed(self)->bool:
-        return self._amp is not None
+        return self._mixed_prec_enabled
 
     def is_dist(self)->bool:
-        return self._ddp is not None
+        return self._distributed_enabled
 
     def is_master(self)->bool:
         return self.global_rank == 0
@@ -164,7 +178,7 @@ class ApexUtils:
         return val
 
     def backward(self, loss:torch.Tensor, optim:Optimizer)->None:
-        if self._amp:
+        if self._mixed_prec_enabled:
             with self._amp.scale_loss(loss, optim) as scaled_loss:
                 scaled_loss.backward()
         else:
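Note: backward() and the predicates above now branch on the config-driven flags rather than on whether the amp/ddp modules happen to be loaded. A hedged sketch of a training step written against this API (scale_loss is NVIDIA apex's documented pattern; the surrounding names are illustrative):

    def train_step(apex_utils, model, optim, x, y, loss_fn):
        optim.zero_grad()
        loss = loss_fn(model(x), y)
        apex_utils.backward(loss, optim) # scales the loss only when mixed precision is on
        apex_utils.clip_grad(5.0, model, optim) # clip threshold here is illustrative
        optim.step()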
@@ -173,13 +187,13 @@ class ApexUtils:
     def to_amp(self, model:nn.Module, optim:Optimizer, batch_size:int)\
             ->Tuple[nn.Module, Optimizer]:
         # convert BNs to sync BNs in distributed mode
-        if self._ddp and self._sync_bn:
+        if self._distributed_enabled and self._sync_bn:
             model = self._ddp.convert_syncbn_model(model)
             self.logger.info({'BNs_converted': True})
 
         model = model.to(self.device)
 
-        if self._amp:
+        if self._mixed_prec_enabled:
             # scale LR
             if self._scale_lr:
                 lr = ml_utils.get_optim_lr(optim)
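Note: scale_lr applies the linear scaling rule for synchronous data-parallel training: with N workers the effective batch is N times larger, so the learning rate is scaled up to match (Goyal et al., "Accurate, Large Minibatch SGD"). A hedged sketch of what the elided lines presumably do:

    from torch.optim import Optimizer

    def scale_lr(optim:Optimizer, world_size:int)->None:
        # multiply each param group's lr by the number of workers; whether archai
        # scales by world size alone (or factors in batch size) is an assumption here
        for group in optim.param_groups:
            group['lr'] *= world_size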
@@ -192,7 +206,7 @@ class ApexUtils:
                 keep_batchnorm_fp32=self._bn_fp32, loss_scale=self._loss_scale
             )
 
-        if self._ddp:
+        if self._distributed_enabled:
             # By default, apex.parallel.DistributedDataParallel overlaps communication with
             # computation in the backward pass.
             # delay_allreduce delays all communication to the end of the backward pass.
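Note: for reference, the apex calls these branches guard follow NVIDIA's documented API; a hedged, condensed sketch (opt_level and flags are illustrative; the real code wires them from the config fields above):

    import torch
    from torch import nn
    from apex import amp
    from apex.parallel import DistributedDataParallel

    model = nn.Linear(8, 2).cuda()
    optim = torch.optim.SGD(model.parameters(), lr=0.1)
    model, optim = amp.initialize(model, optim, opt_level='O2')
    model = DistributedDataParallel(model, delay_allreduce=True)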
@@ -202,19 +216,19 @@ class ApexUtils:
 
     def clip_grad(self, clip:float, model:nn.Module, optim:Optimizer)->None:
         if clip > 0.0:
-            if self._amp:
+            if self._mixed_prec_enabled:
                 nn.utils.clip_grad_norm_(self._amp.master_params(optim), clip)
             else:
                 nn.utils.clip_grad_norm_(model.parameters(), clip)
 
     def state_dict(self):
-        if self._amp:
+        if self._mixed_prec_enabled:
             return self._amp.state_dict()
         else:
             return None
 
     def load_state_dict(self, state_dict):
-        if self._amp:
+        if self._mixed_prec_enabled:
             self._amp.load_state_dict(state_dict)
 
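Note: state_dict()/load_state_dict() let the amp loss-scaler state travel inside training checkpoints. A hedged round-trip sketch (the checkpoint keys are illustrative):

    import torch

    def save_checkpoint(path, model, optim, apex_utils):
        torch.save({'model': model.state_dict(),
                    'optim': optim.state_dict(),
                    'amp': apex_utils.state_dict()}, path) # 'amp' is None when amp is off

    def load_checkpoint(path, model, optim, apex_utils):
        ckpt = torch.load(path)
        model.load_state_dict(ckpt['model'])
        optim.load_state_dict(ckpt['optim'])
        if ckpt['amp'] is not None:
            apex_utils.load_state_dict(ckpt['amp'])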
@@ -30,7 +30,7 @@ class SummaryWriterDummy:
 SummaryWriterAny = Union[SummaryWriterDummy, SummaryWriter]
 logger = OrderedDictLogger(None, None)
 _tb_writer: SummaryWriterAny = None
-_apex_utils = None
+_apex_utils = ApexUtils()
 _atexit_reg = False # is hook for atexit registered?
 
 def get_conf()->Config:
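Note: since ApexUtils.__init__ no longer needs a config, the module can build the singleton eagerly at import time; get_apex_utils() then simply hands out this instance. A hedged sketch of the accessor, whose body is not shown in this diff:

    def get_apex_utils()->ApexUtils:
        assert _apex_utils is not None
        return _apex_utils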
@@ -138,12 +138,8 @@ def common_init(config_filepath: Optional[str]=None,
     logger.info({'expdir': expdir,
                  'PT_DATA_DIR': pt_data_dir, 'PT_OUTPUT_DIR': pt_output_dir})
 
-    # set up amp, apex, mixed-prec, distributed training stubs
-    _setup_apex()
     # create global logger
     _setup_logger()
-    # init GPU settings
-    _setup_gpus()
     # create info file for current system
     _create_sysinfo(conf)
@@ -249,10 +245,6 @@ def _setup_logger():
         sys_logger.warn(
             'logdir not specified, no logs will be created or any models saved')
 
-    # We need to create ApexUtils before we have logger. Now that we have logger
-    # lets give it to ApexUtils
-    get_apex_utils().set_replica_logger(logger)
-
     # reset to new file path
     logger.reset(logs_yaml_filepath, sys_logger)
     logger.info({
@@ -263,40 +255,7 @@ def _setup_logger():
         'sys_log_filepath': sys_log_filepath
     })
 
-def _setup_apex():
-    conf_common = get_conf_common()
-    distdir = conf_common['distdir']
-
-    global _apex_utils
-    _apex_utils = ApexUtils(distdir, conf_common['apex'])
-
-def _setup_gpus():
-    conf_common = get_conf_common()
-
-    utils.setup_cuda(conf_common['seed'], get_apex_utils().local_rank)
-
-    if conf_common['detect_anomaly']:
-        logger.warn({'set_detect_anomaly':True})
-        torch.autograd.set_detect_anomaly(True)
-
-    logger.info({'gpu_names': utils.cuda_device_names(),
-                 'gpu_count': torch.cuda.device_count(),
-                 'CUDA_VISIBLE_DEVICES': os.environ['CUDA_VISIBLE_DEVICES']
-                     if 'CUDA_VISIBLE_DEVICES' in os.environ else 'NotSet',
-                 'cudnn.enabled': cudnn.enabled,
-                 'cudnn.benchmark': cudnn.benchmark,
-                 'cudnn.deterministic': cudnn.deterministic,
-                 'cudnn.version': cudnn.version()
-                 })
-    logger.info({'memory': str(psutil.virtual_memory())})
-    logger.info({'CPUs': str(psutil.cpu_count())})
-
-
-    # gpu_usage = os.popen(
-    #     'nvidia-smi --query-gpu=memory.total,memory.used --format=csv,nounits,noheader'
-    # ).read().split('\n')
-    # for i, line in enumerate(gpu_usage):
-    #     vals = line.split(',')
-    #     if len(vals) == 2:
-    #         logger.info('GPU {} mem: {}, used: {}'.format(i, vals[0], vals[1]))
@@ -7,7 +7,7 @@ from torch import nn
 
 from archai.common.trainer import Trainer
 from archai.common.config import Config
-from archai.common.common import logger
+from archai.common.common import logger, get_apex_utils
 from archai.datasets import data
 from archai.nas.model_desc import ModelDesc
 from archai.nas.cell_builder import CellBuilder
@@ -24,8 +24,11 @@ def eval_arch(conf_eval:Config, cell_builder:Optional[CellBuilder]):
     conf_checkpoint = conf_eval['checkpoint']
     resume = conf_eval['resume']
     conf_train = conf_eval['trainer']
+    conf_apex = conf_eval['apex']
     # endregion
 
+    get_apex_utils().reset(logger, conf_apex)
+
     if cell_builder:
         cell_builder.register_ops()
 
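Note: the search entry point below mirrors this pattern exactly: read the phase-local apex node, then reset() the shared ApexUtils before any GPU work. This is the point of the commit: search and eval now own separate apex config nodes (see the YAML at the end), so eval can, say, turn on mixed precision while search stays plain. A hedged sketch of the two phases configured differently (run_search/run_eval are illustrative names; config paths follow the YAML below):

    conf = get_conf()
    get_apex_utils().reset(logger, conf['nas']['search']['apex']) # e.g. apex disabled
    run_search()
    get_apex_utils().reset(logger, conf['nas']['eval']['apex']) # e.g. amp enabled
    run_eval()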
@@ -8,7 +8,7 @@ import tensorwatch as tw
 from torch.utils.data.dataloader import DataLoader
 import yaml
 
-from archai.common.common import logger
+from archai.common.common import logger, get_apex_utils
 from archai.common.checkpoint import CheckPoint
 from archai.common.config import Config
 from .cell_builder import CellBuilder
@@ -108,6 +108,7 @@ class Search:
         self.search_iters = conf_search['search_iters']
         self.pareto_enabled = conf_pareto['enabled']
         pareto_summary_filename = conf_pareto['summary_filename']
+        conf_apex = conf_search['apex']
         # endregion
 
         self.cell_builder = cell_builder
@@ -116,6 +117,8 @@ class Search:
         self._parito_filepath = utils.full_path(pareto_summary_filename)
         self._checkpoint = nas_utils.create_checkpoint(conf_checkpoint, resume)
 
+        get_apex_utils().reset(logger, conf_apex)
+
         logger.info({'pareto_enabled': self.pareto_enabled,
                      'base_reductions': self.base_reductions,
                      'base_cells': self.base_cells,
@@ -10,15 +10,15 @@ common:
   checkpoint:
     filename: '$expdir/checkpoint.pth'
     freq: 10
-  detect_anomaly: False # if True, PyTorch code will run 6X slower
 
   # TODO: workers setting
 
-  # reddis address of Ray cluster. Use None for single node run
-  # otherwise it should something like host:6379. Make sure to run on head node:
+  # redis address of Ray cluster. Use None for single node run,
+  # otherwise it should be something like host:6379. Make sure to run on head node:
   # "ray start --head --redis-port=6379"
   redis: null
-  apex:
-    enabled: True # global switch to disable anything apex
+  apex: # this is overridden in search and eval individually
+    enabled: False # global switch to disable anything apex
     distributed_enabled: False # enable/disable distributed mode
     mixed_prec_enabled: False # switch to disable amp mixed precision
     gpus: '' # use GPU IDs specified here (comma separated), if '' then use all GPUs
@@ -28,6 +28,8 @@ common:
     sync_bn: False # whether to replace BNs with sync BNs for distributed model
     scale_lr: True # scale learning rate by world size in distributed mode
     min_world_size: 0 # allows to confirm we are indeed in distributed setting
+    detect_anomaly: False # if True, PyTorch code will run 6X slower
+    seed: '_copy: common/seed'
 
   smoke_test: False
   only_eval: False
@@ -48,6 +50,8 @@ nas:
     metric_filename: '$expdir/eval_train_metrics.yaml'
     model_filename: '$expdir/model.pt' # file to which trained model will be saved
     data_parallel: False
+    apex:
+      _copy: 'common/apex'
     checkpoint:
       _copy: 'common/checkpoint'
     resume: '_copy: common/resume'
@@ -118,6 +122,8 @@ nas:
     data_parallel: False
     checkpoint:
       _copy: 'common/checkpoint'
+    apex:
+      _copy: 'common/apex'
     resume: '_copy: common/resume'
     search_iters: 1
     full_desc_filename: '$expdir/full_model_desc.yaml' # arch before it was finalized
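Note: the _copy nodes are archai's config-inheritance mechanism: each phase clones the common/apex defaults and can then override individual keys in place, which is what makes per-phase apex settings cheap to express. A hedged sketch of how such a resolver can behave (archai's actual Config implementation is not shown in this diff):

    # illustrative '_copy' semantics, assuming slash-separated paths from the root
    def resolve_copy(root:dict, node:dict)->dict:
        if '_copy' not in node:
            return node
        src = root
        for key in node['_copy'].strip('/').split('/'):
            src = src[key]
        merged = dict(src) # start from the copied subtree...
        merged.update({k: v for k, v in node.items() if k != '_copy'})
        return merged # ...local keys override the copied ones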