Mirror of https://github.com/microsoft/archai.git
added mixed prec enabled flag
Parent: 8969661426
Commit: aa2f8fafca
@@ -19,11 +19,12 @@ class ApexUtils:
         # region conf vars
         self._enabled = apex_config['enabled'] # global switch to disable anything apex
+        self._distributed_enabled = apex_config['distributed_enabled'] # enable/disable distributed mode
+        self._mixed_prec_enabled = apex_config['mixed_prec_enabled'] # enable/disable amp mixed precision
         self._opt_level = apex_config['opt_level'] # optimization level for mixed precision
         self._bn_fp32 = apex_config['bn_fp32'] # keep BN in fp32
         self._loss_scale = apex_config['loss_scale'] # loss scaling mode for mixed prec
         self._sync_bn = apex_config['sync_bn'] # whether to replace BNs with sync BNs for distributed mode
-        self._distributed = apex_config['distributed'] # enable/disable distributed mode
         self._scale_lr = apex_config['scale_lr'] # scale LR by world size in distributed mode
         self._min_world_size = apex_config['min_world_size'] # allows to confirm we are indeed in distributed setting
         conf_gpu_ids = apex_config['gpus']
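The keys read above correspond one-to-one to the apex section of the YAML config in the second file of this diff. For orientation, a plain-dict sketch of what apex_config carries, using the defaults from that YAML hunk (illustrative only; the real object is archai's Config, not a dict):

# Illustrative stand-in for apex_config with the keys read above;
# values mirror the defaults in the YAML hunk at the end of this diff.
apex_config = {
    'enabled': False,              # global switch to disable anything apex
    'distributed_enabled': False,  # enable/disable distributed mode
    'mixed_prec_enabled': False,   # enable/disable amp mixed precision
    'gpus': '',                    # comma-separated GPU IDs, '' means use all GPUs
    'opt_level': 'O2',             # amp optimization level for mixed precision
    'bn_fp32': True,               # keep batch norm in fp32
    'loss_scale': 'dynamic',       # float given as a string, or "dynamic"
    'sync_bn': False,              # replace BNs with sync BNs in distributed mode
    'scale_lr': True,              # scale LR by world size in distributed mode
    'min_world_size': 0,           # confirm we really are in a distributed setting
}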
@@ -37,20 +38,21 @@ class ApexUtils:
         self.global_rank = 0

         logger.info({'apex_config': apex_config.to_dict()})
-        logger.info({'torch.distributed_is_available': dist.is_available()})
+        logger.info({'torch.distributed.is_available': dist.is_available()})
         if dist.is_available():
             logger.info({'gloo_available': dist.is_gloo_available(),
                          'mpi_available': dist.is_mpi_available(),
                          'nccl_available': dist.is_nccl_available()})

-        if self._enabled:
-            # init enable mixed precision
-            assert cudnn.enabled, "Amp requires cudnn backend to be enabled."
-            from apex import amp
-            self._amp = amp
+        if self._mixed_prec_enabled:
+            # init enable mixed precision
+            assert cudnn.enabled, "Amp requires cudnn backend to be enabled."
+            from apex import amp
+            self._amp = amp

         # enable distributed processing
-        if self._distributed:
+        if self._distributed_enabled:
             from apex import parallel
             self._ddp = parallel
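self._amp and self._ddp only capture module handles here; the actual wrapping happens once a model and optimizer are handed over. A rough sketch of how these handles and the conf vars from the first hunk are typically applied (hypothetical model/optimizer variables, not code from this commit):

# Hypothetical use of the captured handles; opt_level, bn_fp32, loss_scale
# and sync_bn are the conf vars read in the first hunk above.
if self._mixed_prec_enabled:
    model, optimizer = self._amp.initialize(
        model, optimizer,
        opt_level=self._opt_level,          # e.g. 'O2'
        keep_batchnorm_fp32=self._bn_fp32,  # keep BN in fp32
        loss_scale=self._loss_scale)        # float-as-string or 'dynamic'

if self._distributed_enabled:
    if self._sync_bn:
        model = self._ddp.convert_syncbn_model(model)  # apex sync batch norm
    model = self._ddp.DistributedDataParallel(model, delay_allreduce=True)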
@@ -73,11 +75,14 @@ class ApexUtils:
         torch.cuda.set_device(self._gpu)
         self.device = torch.device('cuda', self._gpu)

-        logger.info({'amp_available': self._amp is not None, 'distributed_available': self._distributed is not None})
-        logger.info({'distributed': self._distributed, 'world_size': self._world_size,
-                     'gpu': self._gpu, 'gpu_ids':self.gpu_ids, 'local_rank': self.local_rank})
+        logger.info({'amp_available': self._amp is not None,
+                     'distributed_available': self._ddp is not None})
+        logger.info({'dist_initialized': dist.is_initialized() if dist.is_available() else False,
+                     'world_size': self._world_size,
+                     'gpu': self._gpu, 'gpu_ids':self.gpu_ids,
+                     'local_rank': self.local_rank})

-        logger.info({'dist_initialized': dist.is_initialized() if dist.is_available() else False})
         logger.info({})

         logger.close()
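These log lines report the outcome of the device selection and process-group setup performed just above this hunk. Purely as a point of reference, a generic PyTorch bring-up that produces the logged values (world size, rank, device) looks roughly like this; none of the names below come from this commit:

# Generic distributed bring-up sketch; assumes the usual env:// variables
# (MASTER_ADDR, MASTER_PORT, RANK, WORLD_SIZE) are set by the launcher.
import os
import torch
import torch.distributed as dist

local_rank = int(os.environ.get('LOCAL_RANK', 0))
dist.init_process_group(backend='nccl', init_method='env://')
world_size = dist.get_world_size()
global_rank = dist.get_rank()
torch.cuda.set_device(local_rank)
device = torch.device('cuda', local_rank)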
@@ -140,11 +145,11 @@ class ApexUtils:
         return self.global_rank == 0

     def sync_devices(self)->None:
-        if self._distributed:
-            torch.cuda.synchronize()
+        if self._distributed_enabled:
+            torch.cuda.synchronize(self.device)

     def reduce(self, val, op='mean'):
-        if self._distributed:
+        if self._distributed_enabled:
             if not isinstance(val, Tensor):
                 rt = torch.tensor(val)
             else:
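The body of reduce is cut off by the hunk boundary above. For context, a mean reduction across ranks is conventionally finished along these lines; this is a sketch of the standard all_reduce pattern, not the exact continuation of the file:

            # Sketch of the usual continuation: sum the value across all
            # ranks in place, then average by world size for op == 'mean'.
            rt = val.clone()
            rt = rt.to(self.device)
            dist.all_reduce(rt, op=dist.ReduceOp.SUM)
            if op == 'mean':
                rt /= self._world_size
            return rt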
@@ -211,9 +216,5 @@ class ApexUtils:
     def load_state_dict(self, state_dict):
         if self._amp:
             self._amp.load_state_dict(state_dict)
-        else:
-            if state_dict is not None:
-                raise RuntimeError('checkpoint state_dict is not None but Nvidia apex (amp) is not enabled')
-            else:
-                pass
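Note that apex's amp.load_state_dict expects the state previously captured with amp.state_dict(); the call above is shown with that argument supplied. The documented checkpoint round-trip is symmetric, roughly as sketched below (illustrative; the save-side method is not part of this diff):

# Sketch of the matching save/load pair for amp state (apex's documented pattern).
def state_dict(self):
    if self._amp:
        return self._amp.state_dict()  # loss-scaler and related amp state
    return None

def load_state_dict(self, state_dict):
    if self._amp:
        self._amp.load_state_dict(state_dict)
    elif state_dict is not None:
        raise RuntimeError('checkpoint has amp state but Nvidia apex (amp) is not enabled')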
@@ -18,13 +18,14 @@ common:
   # "ray start --head --redis-port=6379"
   redis: null
   apex:
-    enabled: True # global switch to disable anything apex
+    enabled: False # global switch to disable anything apex
+    distributed_enabled: False # enable/disable distributed mode
+    mixed_prec_enabled: False # switch to disable amp mixed precision
     gpus: '' # use GPU IDs specified here (comma separated), if '' then use all GPUs
     opt_level: 'O2' # optimization level for mixed precision
     bn_fp32: True # keep BN in fp32
     loss_scale: "dynamic" # loss scaling mode for mixed prec, must be a string representing a float or "dynamic"
     sync_bn: False # whether to replace BNs with sync BNs for distributed mode
-    distributed: False # enable/disable distributed mode
     scale_lr: True # scale LR by world size in distributed mode
     min_world_size: 0 # allows to confirm we are indeed in distributed setting
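With the new flags, amp mixed precision and distributed mode can now be switched independently of each other under the global apex switch. For example, a hypothetical override that turns on mixed precision for a single-process run would only need to flip these keys:

# Hypothetical override of the defaults above: apex on, amp on, no distributed.
common:
  apex:
    enabled: True
    mixed_prec_enabled: True
    distributed_enabled: False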