From f8e7d141da053a34f0dbd9a76079674bf8c43804 Mon Sep 17 00:00:00 2001
From: Shital Shah
Date: Thu, 23 Apr 2020 22:08:53 -0700
Subject: [PATCH] properly set up ranks in init

---
 archai/common/apex_utils.py | 22 +++++++++++-----------
 1 file changed, 11 insertions(+), 11 deletions(-)

diff --git a/archai/common/apex_utils.py b/archai/common/apex_utils.py
index 2aed0b5d..11815f3e 100644
--- a/archai/common/apex_utils.py
+++ b/archai/common/apex_utils.py
@@ -36,13 +36,7 @@ class ApexUtils:
 
         # defaults for non-distributed mode
         self._amp, self._ddp = None, None
-        self.world_size = 1
-        self.local_rank = self.global_rank = 0
-
-        self.gpu_ids = [int(i) for i in conf_gpu_ids.split(',') if i]
-
-        # which GPU to use, we will use only 1 GPU per process to avoid complications with apex
-        self._gpu = self.gpu_ids[0] if len(self.gpu_ids) else 0
+        self._set_ranks(conf_gpu_ids)
 
         #_log_info({'apex_config': apex_config.to_dict()})
         self._log_info({'torch.distributed.is_available': dist.is_available()})
@@ -70,8 +64,6 @@ class ApexUtils:
             if not dist.is_initialized():
                 dist.init_process_group(backend='nccl', init_method='env://')
             assert dist.is_initialized()
-
-            self._set_ranks()
             assert dist.get_world_size() == self.world_size
             assert dist.get_rank() == self.global_rank
         else:
@@ -125,7 +117,10 @@ class ApexUtils:
     #         if len(vals) == 2:
     #             _log_info('GPU {} mem: {}, used: {}'.format(i, vals[0], vals[1]))
 
-    def _set_ranks(self)->None:
+    def _set_ranks(self, conf_gpu_ids:str)->None:
+
+        # this function needs to work even when torch.distributed is not available
+
         if 'WORLD_SIZE' in os.environ:
             self.world_size = int(os.environ['WORLD_SIZE'])
         else:
@@ -142,11 +137,16 @@ class ApexUtils:
             self.global_rank = 0
 
         assert self.local_rank < torch.cuda.device_count()
-        self._gpu = self.local_rank % torch.cuda.device_count()
+
+        self.gpu_ids = [int(i) for i in conf_gpu_ids.split(',') if i]
+        # which GPU to use, we will use only 1 GPU per process to avoid complications with apex
         # remap if GPU IDs are specified
         if len(self.gpu_ids):
             assert len(self.gpu_ids) > self.local_rank
             self._gpu = self.gpu_ids[self.local_rank]
+        else:
+            self._gpu = self.local_rank % torch.cuda.device_count()
+
 
     def is_mixed(self)->bool:
         return self._enabled and self._mixed_prec_enabled
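
The rank bookkeeping that the reworked _set_ranks performs can be illustrated with a minimal standalone sketch. This is not the archai implementation: resolve_ranks and RankInfo below are hypothetical names, and the sketch assumes the environment variables conventionally exported by torch.distributed.launch (WORLD_SIZE, RANK, LOCAL_RANK), falling back to single-process defaults when they are absent so it also runs where torch.distributed is unavailable.

# Minimal sketch (hypothetical names) of deriving world size, ranks, and the
# per-process GPU from the environment, in the style _set_ranks now uses.
import os
from dataclasses import dataclass

@dataclass
class RankInfo:
    world_size: int
    global_rank: int
    local_rank: int
    gpu: int

def resolve_ranks(conf_gpu_ids: str, num_gpus: int) -> RankInfo:
    # torch.distributed.launch exports these; default to single-process values
    world_size = int(os.environ.get('WORLD_SIZE', '1'))
    global_rank = int(os.environ.get('RANK', '0'))
    local_rank = int(os.environ.get('LOCAL_RANK', '0'))

    assert local_rank < num_gpus

    # one GPU per process: remap through the configured IDs if any were given,
    # otherwise fall back to local_rank modulo the device count
    gpu_ids = [int(i) for i in conf_gpu_ids.split(',') if i]
    if gpu_ids:
        assert len(gpu_ids) > local_rank
        gpu = gpu_ids[local_rank]
    else:
        gpu = local_rank % num_gpus

    return RankInfo(world_size, global_rank, local_rank, gpu)

if __name__ == '__main__':
    # e.g. with WORLD_SIZE=2 RANK=1 LOCAL_RANK=1 and conf_gpu_ids='2,3' this
    # would report gpu=3; with no env vars it reports single-process defaults.
    print(resolve_ranks(conf_gpu_ids='', num_gpus=1))

Keeping one GPU per process, remapped through user-specified GPU IDs when given, matches the intent of the in-code comment about avoiding multi-GPU-per-process complications with apex.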