set rank refactor, add local_rank to seed

Shital Shah 2020-04-21 23:55:00 -07:00
Parent ee99587b14
Commit 1cdcb8d7ba
6 changed files with 42 additions and 30 deletions

View file

@@ -53,32 +53,15 @@ class ApexUtils:
         # enable distributed processing
         if self._distributed:
-            from apex import parallel
-            self._ddp = parallel
             assert dist.is_available() # distributed module is available
             assert dist.is_nccl_available()
             dist.init_process_group(backend='nccl', init_method='env://')
             assert dist.is_initialized()
-            if 'WORLD_SIZE' in os.environ:
-                self._world_size = int(os.environ['WORLD_SIZE'])
-                assert dist.get_world_size() == self._world_size
-            else:
-                raise RuntimeError('WORLD_SIZE must be set by distributed launcher when distributed mode is enabled')
-            parser = argparse.ArgumentParser()
-            parser.add_argument('--local-rank', type=int, help='local-rank must be supplied by torch distributed launcher!')
-            args, extra_args = parser.parse_known_args()
-            self.local_rank = args.local_rank
-            self.global_rank = dist.get_rank()
+            from apex import parallel
+            self._ddp = parallel
-            assert self.local_rank < torch.cuda.device_count()
-            self._gpu = self.local_rank # reset to default assignment for rank
-            # remap if GPU IDs are specified
-            if len(self.gpu_ids):
-                assert len(self.gpu_ids) > self.local_rank
-                self._gpu = self.gpu_ids[self.local_rank]
+            self._set_ranks()
 
         assert self._world_size >= 1
         assert not self._min_world_size or self._world_size >= self._min_world_size
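
For readability, the '+' side of this hunk assembles to roughly the following distributed-init block (a sketch; surrounding constructor code and exact blank-line placement are not shown in the diff):

# sketch: refactored distributed branch, assembled from the added lines above
if self._distributed:
    assert dist.is_available() # distributed module is available
    assert dist.is_nccl_available()
    dist.init_process_group(backend='nccl', init_method='env://')
    assert dist.is_initialized()

    from apex import parallel
    self._ddp = parallel

    self._set_ranks()  # world size, local/global rank and GPU pick now live here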
@@ -98,6 +81,32 @@ class ApexUtils:
         logger.close()
 
+    def _set_ranks(self)->None:
+        if 'WORLD_SIZE' in os.environ:
+            self._world_size = int(os.environ['WORLD_SIZE'])
+            assert dist.get_world_size() == self._world_size
+        else:
+            raise RuntimeError('WORLD_SIZE must be set by distributed launcher when distributed mode is enabled')
+
+        if 'LOCAL_RANK' in os.environ:
+            self.local_rank = int(os.environ['LOCAL_RANK'])
+        else:
+            raise RuntimeError('LOCAL_RANK must be set by distributed launcher when distributed mode is enabled')
+
+        self.global_rank = dist.get_rank()
+
+        # parser = argparse.ArgumentParser()
+        # parser.add_argument('--local-rank', type=int, help='local-rank must be supplied by torch distributed launcher!')
+        # args, extra_args = parser.parse_known_args()
+        # self.local_rank = args.local_rank
+
+        assert self.local_rank < torch.cuda.device_count()
+        self._gpu = self.local_rank % torch.cuda.device_count()
+        # remap if GPU IDs are specified
+        if len(self.gpu_ids):
+            assert len(self.gpu_ids) > self.local_rank
+            self._gpu = self.gpu_ids[self.local_rank]
+
     def _create_init_logger(self, distdir:Optional[str])->OrderedDictLogger:
         # create PID specific logger to support many distributed processes
         init_log_filepath, yaml_log_filepath = None, None
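
The new _set_ranks() reads WORLD_SIZE and LOCAL_RANK from the environment instead of parsing a --local-rank argument. A minimal sketch of that contract, assuming a launcher that exports both variables (for example torch.distributed.launch with --use_env in the PyTorch of that era, or torchrun in newer releases); read_ranks and its locals are illustrative names, not part of this repo:

# standalone sketch of the env-var contract the refactored code relies on
import os
import torch
import torch.distributed as dist

def read_ranks():
    # the distributed launcher is expected to export these before the script runs
    world_size = int(os.environ['WORLD_SIZE'])   # total number of processes
    local_rank = int(os.environ['LOCAL_RANK'])   # process index on this node
    dist.init_process_group(backend='nccl', init_method='env://')
    global_rank = dist.get_rank()                # process index across all nodes
    assert local_rank < torch.cuda.device_count()
    torch.cuda.set_device(local_rank)            # one GPU per local process
    return world_size, local_rank, global_rank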

View file

@@ -253,7 +253,7 @@ def _setup_apex():
 def _setup_gpus():
     conf_common = get_conf_common()
-    utils.setup_cuda(conf_common['seed'])
+    utils.setup_cuda(conf_common['seed'], get_apex_utils().local_rank)
 
     if conf_common['detect_anomaly']:
         logger.warn({'set_detect_anomaly':True})

View file

@@ -1,4 +1,4 @@
-from typing import Iterable, Type, MutableMapping, Mapping, Any, Optional, Tuple, List
+from typing import Iterable, Type, MutableMapping, Mapping, Any, Optional, Tuple, List, Union
 import numpy as np
 import logging
 import csv
@@ -6,7 +6,7 @@ from collections import OrderedDict
 import sys
 import os
 import pathlib
-import time
+import random
 import torch
 import torch.backends.cudnn as cudnn
@@ -207,12 +207,15 @@ def download_and_extract_tar(url, download_root, extract_root=None, filename=Non
     extract_tar(os.path.join(download_root, filename), extract_root, **kwargs)
 
-def setup_cuda(seed):
-    seed = int(seed)
+def setup_cuda(seed:Union[float, int], local_rank:int):
+    seed = int(seed) + local_rank
     # setup cuda
     cudnn.enabled = True
-    np.random.seed(seed)
     torch.manual_seed(seed)
     torch.cuda.manual_seed(seed)
+    np.random.seed(seed)
+    random.seed(seed)
     #torch.cuda.manual_seed_all(seed)
     cudnn.benchmark = True # set to false if deterministic
     torch.set_printoptions(precision=10)
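
setup_cuda() now offsets the seed by local_rank, so each worker process seeds torch, numpy and Python's random module with a distinct value (shuffling and augmentation streams differ across ranks instead of being identical). A minimal sketch of the effect with hypothetical names, not the repo's code:

import random
import numpy as np
import torch

def seed_everything(seed: int, local_rank: int) -> None:
    # rank 0, 1, 2, ... each get seed + local_rank and therefore different RNG streams
    s = int(seed) + local_rank
    torch.manual_seed(s)
    torch.cuda.manual_seed(s)
    np.random.seed(s)
    random.seed(s)

seed_everything(42, local_rank=0)  # single-process scripts simply pass local_rank=0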

View file

@@ -6,7 +6,7 @@ import torch.backends.cudnn as cudnn
 import numpy as np
 
-utils.setup_cuda(2)
+utils.setup_cuda(2, 0)
 
 device = torch.device('cuda')

View file

@@ -10,7 +10,7 @@ from torch_testbed.dataloader_dali import cifar10_dataloaders
 utils.create_logger()
 
-utils.setup_cuda(42)
+utils.setup_cuda(42, 0)
 
 batch_size = 512
 half = True

View file

@@ -5,7 +5,7 @@ from torch_testbed import utils, cifar10_models
 from torch_testbed.timing import MeasureTime, print_all_timings, print_timing, get_timing
 
 utils.create_logger()
-utils.setup_cuda(42)
+utils.setup_cuda(42, 0)
 
 batch_size = 512
 half = True