This commit is contained in:
Shital Shah 2020-04-22 12:35:45 -07:00
Parent 8969661426
Commit aa2f8fafca
2 changed files: 23 additions and 21 deletions

View file

@@ -19,11 +19,12 @@ class ApexUtils:
# region conf vars
self._enabled = apex_config['enabled'] # global switch to disable anything apex
self._distributed_enabled = apex_config['distributed_enabled'] # enable/disable distributed mode
self._mixed_prec_enabled = apex_config['mixed_prec_enabled'] # enable/disable amp mixed precision
self._opt_level = apex_config['opt_level'] # optimization level for mixed precision
self._bn_fp32 = apex_config['bn_fp32'] # keep BN in fp32
self._loss_scale = apex_config['loss_scale'] # loss scaling mode for mixed prec
self._sync_bn = apex_config['sync_bn'] # should we replace BNs with sync BNs for distributed training
self._distributed = apex_config['distributed'] # enable/disable distributed mode
self._scale_lr = apex_config['scale_lr'] # scale learning rate by world size in distributed mode
self._min_world_size = apex_config['min_world_size'] # allows confirming we are indeed in a distributed setting
conf_gpu_ids = apex_config['gpus']
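Note: the 'gpus' value above is a comma-separated string of device ids, with '' meaning all GPUs. A minimal sketch of how such a string can be parsed, using a hypothetical helper (the actual parsing code is outside this hunk):

import torch

def parse_gpu_ids(conf_gpu_ids: str) -> list:
    # hypothetical helper, not part of the commit: '' -> all visible GPUs
    if conf_gpu_ids:
        return [int(i) for i in conf_gpu_ids.split(',')]  # e.g. '0,2' -> [0, 2]
    return list(range(torch.cuda.device_count()))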
@@ -37,20 +38,21 @@ class ApexUtils:
self.global_rank = 0
logger.info({'apex_config': apex_config.to_dict()})
logger.info({'torch.distributed_is_available': dist.is_available()})
logger.info({'torch.distributed.is_available': dist.is_available()})
if dist.is_available():
logger.info({'gloo_available': dist.is_gloo_available(),
'mpi_available': dist.is_mpi_available(),
'nccl_available': dist.is_nccl_available()})
if self._enabled:
# init enable mixed precision
assert cudnn.enabled, "Amp requires cudnn backend to be enabled."
from apex import amp
self._amp = amp
if self._mixed_prec_enabled:
# init enable mixed precision
assert cudnn.enabled, "Amp requires cudnn backend to be enabled."
from apex import amp
self._amp = amp
# enable distributed processing
if self._distributed:
if self._distributed_enabled:
from apex import parallel
self._ddp = parallel
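For context, the flags read in __init__ map onto apex's documented entry points. A hedged sketch of how amp and apex.parallel are typically wired together with these flags; model and optimizer are placeholders, and the exact calls made elsewhere in ApexUtils are not shown in this diff:

from apex import amp, parallel

def wrap_with_apex(model, optimizer, opt_level='O2', bn_fp32=True,
                   loss_scale='dynamic', sync_bn=False, distributed=False):
    # mixed precision: amp patches the model/optimizer pair
    model, optimizer = amp.initialize(model, optimizer,
                                      opt_level=opt_level,
                                      keep_batchnorm_fp32=bn_fp32,
                                      loss_scale=loss_scale)
    if distributed:
        if sync_bn:
            model = parallel.convert_syncbn_model(model)  # swap BNs for SyncBNs
        model = parallel.DistributedDataParallel(model, delay_allreduce=True)
    return model, optimizer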
@@ -73,11 +75,14 @@ class ApexUtils:
torch.cuda.set_device(self._gpu)
self.device = torch.device('cuda', self._gpu)
logger.info({'amp_available': self._amp is not None, 'distributed_available': self._distributed is not None})
logger.info({'distributed': self._distributed, 'world_size': self._world_size,
'gpu': self._gpu, 'gpu_ids':self.gpu_ids, 'local_rank': self.local_rank})
logger.info({'amp_available': self._amp is not None,
'distributed_available': self._ddp is not None})
logger.info({'dist_initialized': dist.is_initialized() if dist.is_available() else False,
'world_size': self._world_size,
'gpu': self._gpu, 'gpu_ids':self.gpu_ids,
'local_rank': self.local_rank})
logger.info({'dist_initialized': dist.is_initialized() if dist.is_available() else False})
logger.info({})
logger.close()
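The dist_initialized value logged above comes from torch.distributed.is_initialized(); a minimal sketch, assuming the usual env:// launch (torch.distributed.launch or torchrun sets RANK, WORLD_SIZE, MASTER_ADDR, MASTER_PORT), of the call that makes it True:

import torch
import torch.distributed as dist

def init_distributed(local_rank: int):
    # assumes RANK/WORLD_SIZE/MASTER_ADDR/MASTER_PORT are already in the environment
    dist.init_process_group(backend='nccl', init_method='env://')
    torch.cuda.set_device(local_rank)  # bind this process to its GPU
    return dist.get_rank(), dist.get_world_size()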
@@ -140,11 +145,11 @@ class ApexUtils:
return self.global_rank == 0
def sync_devices(self)->None:
if self._distributed:
torch.cuda.synchronize()
if self._distributed_enabled:
torch.cuda.synchronize(self.device)
def reduce(self, val, op='mean'):
if self._distributed:
if self._distributed_enabled:
if not isinstance(val, Tensor):
rt = torch.tensor(val)
else:
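The body of reduce is cut off by the hunk; a hedged sketch of the standard pattern it points at (sum-reduce the value across ranks, then divide by world size when op is 'mean'), not necessarily the exact committed code:

import torch
import torch.distributed as dist
from torch import Tensor

def reduce_value(val, world_size: int, op: str = 'mean'):
    # wrap plain Python numbers in a tensor, move to GPU for the NCCL backend
    rt = val.detach().clone() if isinstance(val, Tensor) else torch.tensor(val)
    rt = rt.cuda().float()
    dist.all_reduce(rt, op=dist.ReduceOp.SUM)  # sum across all ranks
    if op == 'mean':
        rt = rt / world_size  # turn the sum into a mean
    return rt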
@@ -211,9 +216,5 @@ class ApexUtils:
def load_state_dict(self, state_dict):
if self._amp:
self._amp.load_state_dict(state_dict)
else:
if state_dict is not None:
raise RuntimeError('checkpoint state_dict is not None but Nvidia apex (amp) is not enabled')
else:
pass
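For reference, apex documents a checkpointing pattern where amp.state_dict() is saved alongside the model and optimizer; a hedged sketch of both sides (the checkpoint key names are illustrative, only amp.state_dict()/amp.load_state_dict() come from apex itself):

import torch

def save_checkpoint(path, model, optimizer, amp=None):
    checkpoint = {'model': model.state_dict(),
                  'optimizer': optimizer.state_dict(),
                  'amp': amp.state_dict() if amp is not None else None}
    torch.save(checkpoint, path)

def load_checkpoint(path, model, optimizer, amp=None):
    checkpoint = torch.load(path, map_location='cpu')
    model.load_state_dict(checkpoint['model'])
    optimizer.load_state_dict(checkpoint['optimizer'])
    if amp is not None and checkpoint.get('amp') is not None:
        amp.load_state_dict(checkpoint['amp'])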

View file

@@ -18,13 +18,14 @@ common:
# "ray start --head --redis-port=6379"
redis: null
apex:
enabled: True # global switch to disable anything apex
distributed_enabled: False # enable/disable distributed mode
mixed_prec_enabled: False # enable/disable amp mixed precision
gpus: '' # use GPU IDs specified here (comma separated), if '' then use all GPUs
enabled: False # global switch to disable anything apex
opt_level: 'O2' # optimization level for mixed precision
bn_fp32: True # keep BN in fp32
loss_scale: "dynamic" # loss scaling mode for mixed prec, must be a string representing a float or "dynamic"
sync_bn: False # should we replace BNs with sync BNs for distributed training
distributed: False # enable/disable distributed mode
scale_lr: True # scale learning rate by world size in distributed mode
min_world_size: 0 # allows to confirm we are indeed in distributed setting
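A hedged sketch of reading this apex block back on the Python side; the file name and the top-level 'common' nesting are assumptions for illustration, only the keys themselves come from this diff:

import yaml

# 'config.yaml' is a placeholder path for the file shown above
with open('config.yaml') as f:
    conf = yaml.safe_load(f)

apex_config = conf['common']['apex']
print(apex_config['enabled'], apex_config['opt_level'], apex_config['loss_scale'])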