зеркало из https://github.com/microsoft/archai.git
set rank refactor, add local_rank to seed
This commit is contained in:
Родитель
ee99587b14
Коммит
1cdcb8d7ba
|
@ -53,32 +53,15 @@ class ApexUtils:
|
|||
|
||||
# enable distributed processing
|
||||
if self._distributed:
|
||||
from apex import parallel
|
||||
self._ddp = parallel
|
||||
|
||||
assert dist.is_available() # distributed module is available
|
||||
assert dist.is_nccl_available()
|
||||
dist.init_process_group(backend='nccl', init_method='env://')
|
||||
assert dist.is_initialized()
|
||||
|
||||
if 'WORLD_SIZE' in os.environ:
|
||||
self._world_size = int(os.environ['WORLD_SIZE'])
|
||||
assert dist.get_world_size() == self._world_size
|
||||
else:
|
||||
raise RuntimeError('WORLD_SIZE must be set by distributed launcher when distributed mode is enabled')
|
||||
|
||||
parser = argparse.ArgumentParser()
|
||||
parser.add_argument('--local-rank', type=int, help='local-rank must be supplied by torch distributed launcher!')
|
||||
args, extra_args = parser.parse_known_args()
|
||||
|
||||
self.local_rank = args.local_rank
|
||||
self.global_rank = dist.get_rank()
|
||||
|
||||
from apex import parallel
|
||||
self._ddp = parallel
|
||||
assert self.local_rank < torch.cuda.device_count()
|
||||
self._gpu = self.local_rank # reset to default assignment for rank
|
||||
# remap if GPU IDs are specified
|
||||
if len(self.gpu_ids):
|
||||
assert len(self.gpu_ids) > self.local_rank
|
||||
self._gpu = self.gpu_ids[self.local_rank]
|
||||
self._set_ranks()
|
||||
|
||||
assert self._world_size >= 1
|
||||
assert not self._min_world_size or self._world_size >= self._min_world_size
|
||||
|
@ -98,6 +81,32 @@ class ApexUtils:
|
|||
logger.close()
|
||||
|
||||
|
||||
def _set_ranks(self)->None:
|
||||
if 'WORLD_SIZE' in os.environ:
|
||||
self._world_size = int(os.environ['WORLD_SIZE'])
|
||||
assert dist.get_world_size() == self._world_size
|
||||
else:
|
||||
raise RuntimeError('WORLD_SIZE must be set by distributed launcher when distributed mode is enabled')
|
||||
|
||||
if 'LOCAL_RANK' in os.environ:
|
||||
self.local_rank = int(os.environ['LOCAL_RANK'])
|
||||
else:
|
||||
raise RuntimeError('LOCAL_RANK must be set by distributed launcher when distributed mode is enabled')
|
||||
|
||||
self.global_rank = dist.get_rank()
|
||||
# parser = argparse.ArgumentParser()
|
||||
# parser.add_argument('--local-rank', type=int, help='local-rank must be supplied by torch distributed launcher!')
|
||||
# args, extra_args = parser.parse_known_args()
|
||||
# self.local_rank = args.local_rank
|
||||
|
||||
assert self.local_rank < torch.cuda.device_count()
|
||||
self._gpu = self.local_rank % torch.cuda.device_count()
|
||||
# remap if GPU IDs are specified
|
||||
if len(self.gpu_ids):
|
||||
assert len(self.gpu_ids) > self.local_rank
|
||||
self._gpu = self.gpu_ids[self.local_rank]
|
||||
|
||||
|
||||
def _create_init_logger(self, distdir:Optional[str])->OrderedDictLogger:
|
||||
# create PID specific logger to support many distributed processes
|
||||
init_log_filepath, yaml_log_filepath = None, None
|
||||
|
|
|
@ -253,7 +253,7 @@ def _setup_apex():
|
|||
def _setup_gpus():
|
||||
conf_common = get_conf_common()
|
||||
|
||||
utils.setup_cuda(conf_common['seed'])
|
||||
utils.setup_cuda(conf_common['seed'], get_apex_utils().local_rank)
|
||||
|
||||
if conf_common['detect_anomaly']:
|
||||
logger.warn({'set_detect_anomaly':True})
|
||||
|
|
|
@ -1,4 +1,4 @@
|
|||
from typing import Iterable, Type, MutableMapping, Mapping, Any, Optional, Tuple, List
|
||||
from typing import Iterable, Type, MutableMapping, Mapping, Any, Optional, Tuple, List, Union
|
||||
import numpy as np
|
||||
import logging
|
||||
import csv
|
||||
|
@ -6,7 +6,7 @@ from collections import OrderedDict
|
|||
import sys
|
||||
import os
|
||||
import pathlib
|
||||
import time
|
||||
import random
|
||||
|
||||
import torch
|
||||
import torch.backends.cudnn as cudnn
|
||||
|
@ -207,12 +207,15 @@ def download_and_extract_tar(url, download_root, extract_root=None, filename=Non
|
|||
|
||||
extract_tar(os.path.join(download_root, filename), extract_root, **kwargs)
|
||||
|
||||
def setup_cuda(seed):
|
||||
seed = int(seed)
|
||||
def setup_cuda(seed:Union[float, int], local_rank:int):
|
||||
seed = int(seed) + local_rank
|
||||
# setup cuda
|
||||
cudnn.enabled = True
|
||||
np.random.seed(seed)
|
||||
torch.manual_seed(seed)
|
||||
torch.cuda.manual_seed(seed)
|
||||
np.random.seed(seed)
|
||||
random.seed(seed)
|
||||
|
||||
#torch.cuda.manual_seed_all(seed)
|
||||
cudnn.benchmark = True # set to false if deterministic
|
||||
torch.set_printoptions(precision=10)
|
||||
|
|
|
@ -6,7 +6,7 @@ import torch.backends.cudnn as cudnn
|
|||
import numpy as np
|
||||
|
||||
|
||||
utils.setup_cuda(2)
|
||||
utils.setup_cuda(2, 0)
|
||||
|
||||
device = torch.device('cuda')
|
||||
|
||||
|
|
|
@ -10,7 +10,7 @@ from torch_testbed.dataloader_dali import cifar10_dataloaders
|
|||
|
||||
|
||||
utils.create_logger()
|
||||
utils.setup_cuda(42)
|
||||
utils.setup_cuda(42, 0)
|
||||
|
||||
batch_size = 512
|
||||
half = True
|
||||
|
|
|
@ -5,7 +5,7 @@ from torch_testbed import utils, cifar10_models
|
|||
from torch_testbed.timing import MeasureTime, print_all_timings, print_timing, get_timing
|
||||
|
||||
utils.create_logger()
|
||||
utils.setup_cuda(42)
|
||||
utils.setup_cuda(42, 0)
|
||||
|
||||
batch_size = 512
|
||||
half = True
|
||||
|
|
Загрузка…
Ссылка в новой задаче