Mirror of https://github.com/microsoft/archai.git
properly set up ranks in init
This commit is contained in:
Parent: d1c77267a6
Commit: f8e7d141da
@@ -36,13 +36,7 @@ class ApexUtils:
         # defaults for non-distributed mode
         self._amp, self._ddp = None, None
         self.world_size = 1
         self.local_rank = self.global_rank = 0
 
-        self.gpu_ids = [int(i) for i in conf_gpu_ids.split(',') if i]
-
-        # which GPU to use, we will use only 1 GPU per process to avoid complications with apex
-        self._gpu = self.gpu_ids[0] if len(self.gpu_ids) else 0
+        self._set_ranks(conf_gpu_ids)
 
         #_log_info({'apex_config': apex_config.to_dict()})
         self._log_info({'torch.distributed.is_available': dist.is_available()})
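Note: the conf_gpu_ids value passed to _set_ranks above is a comma-separated string of GPU indices. A minimal standalone sketch of the parse used in this diff (parse_gpu_ids is a hypothetical helper and the sample inputs are illustrative, not from the commit):

# Standalone version of the list comprehension in this diff; the
# 'if i' filter drops empty tokens, so an empty config string
# yields an empty list rather than raising on int('').
def parse_gpu_ids(conf_gpu_ids: str) -> list:
    return [int(i) for i in conf_gpu_ids.split(',') if i]

assert parse_gpu_ids('') == []
assert parse_gpu_ids('0,1,3') == [0, 1, 3]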
@@ -70,8 +64,6 @@ class ApexUtils:
             if not dist.is_initialized():
                 dist.init_process_group(backend='nccl', init_method='env://')
             assert dist.is_initialized()
-
-            self._set_ranks()
             assert dist.get_world_size() == self.world_size
             assert dist.get_rank() == self.global_rank
         else:
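Note: with init_method='env://', torch.distributed reads its rendezvous settings from environment variables that a launcher normally provides. A minimal sketch, assuming a single-process run with a CUDA device (the fallback values below are illustrative, not from this commit):

# Environment that init_process_group(init_method='env://') expects;
# a launcher such as torch.distributed.launch normally sets these,
# the setdefault values here are illustrative only.
import os
import torch.distributed as dist

os.environ.setdefault('MASTER_ADDR', '127.0.0.1')
os.environ.setdefault('MASTER_PORT', '29500')
os.environ.setdefault('WORLD_SIZE', '1')
os.environ.setdefault('RANK', '0')

if not dist.is_initialized():
    # nccl requires a GPU; gloo would work for a CPU-only check
    dist.init_process_group(backend='nccl', init_method='env://')
assert dist.get_world_size() == int(os.environ['WORLD_SIZE'])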
@@ -125,7 +117,10 @@ class ApexUtils:
         # if len(vals) == 2:
         #     _log_info('GPU {} mem: {}, used: {}'.format(i, vals[0], vals[1]))
 
-    def _set_ranks(self)->None:
+    def _set_ranks(self, conf_gpu_ids:str)->None:
+
+        # this function needs to work even when torch.distributed is not available
+
         if 'WORLD_SIZE' in os.environ:
             self.world_size = int(os.environ['WORLD_SIZE'])
         else:
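Note: the hunk above (continued below) makes _set_ranks fall back to single-process defaults when the launcher's variables are absent. A sketch of that pattern in isolation (get_ranks is a hypothetical helper, not part of archai; LOCAL_RANK is only set by launchers run with --use_env or by newer launchers such as torchrun):

# Env-driven rank discovery with non-distributed defaults; works
# even when torch.distributed is unavailable, as the comment in
# the diff requires.
import os

def get_ranks() -> tuple:
    world_size = int(os.environ.get('WORLD_SIZE', 1))
    global_rank = int(os.environ.get('RANK', 0))
    local_rank = int(os.environ.get('LOCAL_RANK', 0))
    return world_size, global_rank, local_rank

print(get_ranks())  # (1, 0, 0) when no launcher variables are set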
@@ -142,11 +137,16 @@ class ApexUtils:
             self.global_rank = 0
 
-        assert self.local_rank < torch.cuda.device_count()
-        self._gpu = self.local_rank % torch.cuda.device_count()
+        self.gpu_ids = [int(i) for i in conf_gpu_ids.split(',') if i]
+        # which GPU to use, we will use only 1 GPU per process to avoid complications with apex
+        # remap if GPU IDs are specified
+        if len(self.gpu_ids):
+            assert len(self.gpu_ids) > self.local_rank
+            self._gpu = self.gpu_ids[self.local_rank]
+        else:
+            self._gpu = self.local_rank % torch.cuda.device_count()
 
 
     def is_mixed(self)->bool:
         return self._enabled and self._mixed_prec_enabled
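Note: the remapping added in the last hunk picks this process's GPU from the user-supplied list when one is given, and otherwise wraps local_rank around the visible devices. A self-contained sketch (pick_gpu is a hypothetical standalone version; device_count is a parameter so the example runs without CUDA):

# Mirrors the GPU selection logic introduced in this commit.
def pick_gpu(gpu_ids: list, local_rank: int, device_count: int) -> int:
    if len(gpu_ids):
        # remap: the process with local_rank k uses the k-th listed GPU
        assert len(gpu_ids) > local_rank
        return gpu_ids[local_rank]
    # default: round-robin over the visible devices
    return local_rank % device_count

assert pick_gpu([2, 3], 1, 4) == 3  # explicit list: local_rank 1 -> GPU 3
assert pick_gpu([], 5, 4) == 1      # no list: 5 % 4 -> GPU 1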