diff --git a/archai/common/apex_utils.py b/archai/common/apex_utils.py
index 4497ba66..ddba4233 100644
--- a/archai/common/apex_utils.py
+++ b/archai/common/apex_utils.py
@@ -53,32 +53,15 @@ class ApexUtils:
 
         # enable distributed processing
         if self._distributed:
+            from apex import parallel
+            self._ddp = parallel
+
             assert dist.is_available() # distributed module is available
             assert dist.is_nccl_available()
             dist.init_process_group(backend='nccl', init_method='env://')
             assert dist.is_initialized()
 
-            if 'WORLD_SIZE' in os.environ:
-                self._world_size = int(os.environ['WORLD_SIZE'])
-                assert dist.get_world_size() == self._world_size
-            else:
-                raise RuntimeError('WORLD_SIZE must be set by distributed launcher when distributed mode is enabled')
-
-            parser = argparse.ArgumentParser()
-            parser.add_argument('--local-rank', type=int, help='local-rank must be supplied by torch distributed launcher!')
-            args, extra_args = parser.parse_known_args()
-
-            self.local_rank = args.local_rank
-            self.global_rank = dist.get_rank()
-
-            from apex import parallel
-            self._ddp = parallel
-            assert self.local_rank < torch.cuda.device_count()
-            self._gpu = self.local_rank # reset to default assignment for rank
-            # remap if GPU IDs are specified
-            if len(self.gpu_ids):
-                assert len(self.gpu_ids) > self.local_rank
-                self._gpu = self.gpu_ids[self.local_rank]
+            self._set_ranks()
 
         assert self._world_size >= 1
         assert not self._min_world_size or self._world_size >= self._min_world_size
@@ -98,6 +81,29 @@ class ApexUtils:
 
             logger.close()
 
+    def _set_ranks(self)->None:
+        """Derive world size, global/local rank and GPU assignment from env vars set by the distributed launcher."""
+        if 'WORLD_SIZE' in os.environ:
+            self._world_size = int(os.environ['WORLD_SIZE'])
+            assert dist.get_world_size() == self._world_size
+        else:
+            raise RuntimeError('WORLD_SIZE must be set by distributed launcher when distributed mode is enabled')
+
+        if 'LOCAL_RANK' in os.environ:
+            self.local_rank = int(os.environ['LOCAL_RANK'])
+        else:
+            raise RuntimeError('LOCAL_RANK must be set by distributed launcher when distributed mode is enabled')
+
+        self.global_rank = dist.get_rank()
+
+        assert self.local_rank < torch.cuda.device_count()
+        self._gpu = self.local_rank % torch.cuda.device_count()
+        # remap if GPU IDs are specified
+        if len(self.gpu_ids):
+            assert len(self.gpu_ids) > self.local_rank
+            self._gpu = self.gpu_ids[self.local_rank]
+
+
     def _create_init_logger(self, distdir:Optional[str])->OrderedDictLogger:
         # create PID specific logger to support many distributed processes
         init_log_filepath, yaml_log_filepath = None, None
diff --git a/archai/common/common.py b/archai/common/common.py
index 2c070c86..1a9eacfb 100644
--- a/archai/common/common.py
+++ b/archai/common/common.py
@@ -253,7 +253,7 @@ def _setup_apex():
 
 def _setup_gpus():
     conf_common = get_conf_common()
-    utils.setup_cuda(conf_common['seed'])
+    utils.setup_cuda(conf_common['seed'], get_apex_utils().local_rank)
 
     if conf_common['detect_anomaly']:
         logger.warn({'set_detect_anomaly':True})
diff --git a/archai/common/utils.py b/archai/common/utils.py
index 8804848b..5582255d 100644
--- a/archai/common/utils.py
+++ b/archai/common/utils.py
@@ -1,4 +1,4 @@
-from typing import Iterable, Type, MutableMapping, Mapping, Any, Optional, Tuple, List
+from typing import Iterable, Type, MutableMapping, Mapping, Any, Optional, Tuple, List, Union
 import numpy as np
 import logging
 import csv
@@ -6,7 +6,7 @@ from collections import OrderedDict
 import sys
 import os
 import pathlib
-import time
+import random
 
 import torch
 import torch.backends.cudnn as cudnn
@@ -207,12 +207,15 @@ def download_and_extract_tar(url, download_root, extract_root=None, filename=Non
     extract_tar(os.path.join(download_root, filename), extract_root, **kwargs)
 
 
-def setup_cuda(seed):
-    seed = int(seed)
+def setup_cuda(seed:Union[float, int], local_rank:int):
+    seed = int(seed) + local_rank
     # setup cuda
     cudnn.enabled = True
-    np.random.seed(seed)
     torch.manual_seed(seed)
+    torch.cuda.manual_seed(seed)
+    np.random.seed(seed)
+    random.seed(seed)
+    #torch.cuda.manual_seed_all(seed)
     cudnn.benchmark = True # set to false if deterministic
     torch.set_printoptions(precision=10)
 
diff --git a/scripts/perf/mixed_ops.py b/scripts/perf/mixed_ops.py
index 466d8fff..70adfa35 100644
--- a/scripts/perf/mixed_ops.py
+++ b/scripts/perf/mixed_ops.py
@@ -6,7 +6,7 @@ import torch.backends.cudnn as cudnn
 import numpy as np
 
 
-utils.setup_cuda(2)
+utils.setup_cuda(2, 0)
 
 
 device = torch.device('cuda')
diff --git a/scripts/perf/model_dl_test.py b/scripts/perf/model_dl_test.py
index 23eec7c1..2e1af4b7 100644
--- a/scripts/perf/model_dl_test.py
+++ b/scripts/perf/model_dl_test.py
@@ -10,7 +10,7 @@ from torch_testbed.dataloader_dali import cifar10_dataloaders
 
 
 utils.create_logger()
-utils.setup_cuda(42)
+utils.setup_cuda(42, 0)
 
 batch_size = 512
 half = True
diff --git a/scripts/perf/model_test.py b/scripts/perf/model_test.py
index effed2cd..23b25bbe 100644
--- a/scripts/perf/model_test.py
+++ b/scripts/perf/model_test.py
@@ -5,7 +5,7 @@ from torch_testbed import utils, cifar10_models
 from torch_testbed.timing import MeasureTime, print_all_timings, print_timing, get_timing
 
 utils.create_logger()
-utils.setup_cuda(42)
+utils.setup_cuda(42, 0)
 
 batch_size = 512
 half = True