Shital Shah 2020-04-23 22:08:53 -07:00
Parent d1c77267a6
Commit f8e7d141da
1 changed file with 11 additions and 11 deletions


@@ -36,13 +36,7 @@ class ApexUtils:
# defaults for non-distributed mode
self._amp, self._ddp = None, None
self.world_size = 1
self.local_rank = self.global_rank = 0
self.gpu_ids = [int(i) for i in conf_gpu_ids.split(',') if i]
# which GPU to use, we will use only 1 GPU per process to avoid complications with apex
self._gpu = self.gpu_ids[0] if len(self.gpu_ids) else 0
self._set_ranks(conf_gpu_ids)
#_log_info({'apex_config': apex_config.to_dict()})
self._log_info({'torch.distributed.is_available': dist.is_available()})
@@ -70,8 +64,6 @@ class ApexUtils:
if not dist.is_initialized():
dist.init_process_group(backend='nccl', init_method='env://')
assert dist.is_initialized()
self._set_ranks()
assert dist.get_world_size() == self.world_size
assert dist.get_rank() == self.global_rank
else:
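
For context, a minimal standalone sketch (not part of this commit) of the env:// initialization used in the hunk above, assuming the usual torchrun / torch.distributed.launch environment variables MASTER_ADDR, MASTER_PORT, RANK and WORLD_SIZE; it shows why the asserts on world_size and global_rank can be checked against what _set_ranks reads from the environment:

import os
import torch.distributed as dist

# Sketch only: env:// pulls MASTER_ADDR, MASTER_PORT, RANK and WORLD_SIZE
# from the environment set up by the launcher.
if not dist.is_initialized():
    dist.init_process_group(backend='nccl', init_method='env://')
# the process group must agree with the variables _set_ranks parses
assert dist.get_world_size() == int(os.environ['WORLD_SIZE'])
assert dist.get_rank() == int(os.environ['RANK'])
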
@@ -125,7 +117,10 @@ class ApexUtils:
# if len(vals) == 2:
# _log_info('GPU {} mem: {}, used: {}'.format(i, vals[0], vals[1]))
def _set_ranks(self)->None:
def _set_ranks(self, conf_gpu_ids:str)->None:
# this function needs to work even when torch.distributed is not available
if 'WORLD_SIZE' in os.environ:
self.world_size = int(os.environ['WORLD_SIZE'])
else:
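
A rough sketch of the environment-variable pattern _set_ranks relies on; WORLD_SIZE appears in the hunk above, while LOCAL_RANK and RANK are assumptions about the elided fallback branches (launchers such as torchrun conventionally export all three):

import os

# Assumed pattern: only WORLD_SIZE is visible in this hunk; LOCAL_RANK and RANK
# are the conventional companions exported by torch.distributed launchers.
world_size  = int(os.environ.get('WORLD_SIZE', 1))  # total number of processes
local_rank  = int(os.environ.get('LOCAL_RANK', 0))  # process index on this node
global_rank = int(os.environ.get('RANK', 0))        # process index across all nodes
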
@@ -142,11 +137,16 @@
self.global_rank = 0
assert self.local_rank < torch.cuda.device_count()
self._gpu = self.local_rank % torch.cuda.device_count()
self.gpu_ids = [int(i) for i in conf_gpu_ids.split(',') if i]
# which GPU to use, we will use only 1 GPU per process to avoid complications with apex
# remap if GPU IDs are specified
if len(self.gpu_ids):
assert len(self.gpu_ids) > self.local_rank
self._gpu = self.gpu_ids[self.local_rank]
else:
self._gpu = self.local_rank % torch.cuda.device_count()
def is_mixed(self)->bool:
return self._enabled and self._mixed_prec_enabled
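
To make the remapping moved into _set_ranks concrete, a hypothetical helper (select_gpu is not part of the repo) that mirrors the added logic: with conf_gpu_ids='2,3' the process with local_rank 1 is pinned to GPU 3, and an empty string falls back to local_rank modulo the visible device count:

def select_gpu(conf_gpu_ids: str, local_rank: int, device_count: int) -> int:
    # hypothetical helper mirroring the remapping added to _set_ranks above
    gpu_ids = [int(i) for i in conf_gpu_ids.split(',') if i]
    if len(gpu_ids):
        assert len(gpu_ids) > local_rank
        return gpu_ids[local_rank]        # explicit IDs: rank i uses gpu_ids[i]
    return local_rank % device_count      # default: one GPU per process

assert select_gpu('2,3', 1, 4) == 3
assert select_gpu('', 1, 4) == 1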