Shital Shah 2020-04-22 19:57:41 -07:00
Parent 3545fb852e
Commit a81c871435
2 changed files with 24 additions and 23 deletions

View file

@@ -19,12 +19,12 @@ class ApexUtils:
self._amp = self._ddp = None
self._op_map = {'mean': dist.ReduceOp.SUM, 'sum': dist.ReduceOp.SUM,
'min': dist.ReduceOp.MIN, 'max': dist.ReduceOp.MAX}
self.gpu_ids = [] # use all gpus
self._mixed_prec_enabled = False
self._distributed_enabled = False
self._world_size = 1 # total number of processes in distributed run
self.local_rank = 0
self.global_rank = 0
self._set_ranks()
def reset(self, logger:OrderedDictLogger, apex_config:Config)->None:
# reset allows to configure differently for search or eval modes
@@ -50,9 +50,6 @@ class ApexUtils:
self.gpu_ids = [int(i) for i in conf_gpu_ids.split(',') if i]
self._amp, self._ddp = None, None
self._gpu = self.gpu_ids[0] if len(self.gpu_ids) else 0 # which GPU to use, we will use only 1 GPU
self._world_size = 1 # total number of processes in distributed run
self.local_rank = 0
self.global_rank = 0
#logger.info({'apex_config': apex_config.to_dict()})
logger.info({'torch.distributed.is_available': dist.is_available()})
@@ -80,11 +77,17 @@ class ApexUtils:
assert dist.is_initialized()
self._set_ranks()
assert dist.get_world_size() == self.world_size
assert dist.get_rank() == self.global_rank
else:
assert self.world_size == 1
assert self.local_rank == 0
assert self.global_rank == 0
assert self._world_size >= 1
assert not self._min_world_size or self._world_size >= self._min_world_size
assert self.local_rank >= 0 and self.local_rank < self._world_size
assert self.global_rank >= 0 and self.global_rank < self._world_size
assert self.world_size >= 1
assert not self._min_world_size or self.world_size >= self._min_world_size
assert self.local_rank >= 0 and self.local_rank < self.world_size
assert self.global_rank >= 0 and self.global_rank < self.world_size
assert self._gpu < torch.cuda.device_count()
torch.cuda.set_device(self._gpu)
@@ -94,7 +97,7 @@ class ApexUtils:
logger.info({'amp_available': self._amp is not None,
'distributed_available': self._ddp is not None})
logger.info({'dist_initialized': dist.is_initialized() if dist.is_available() else False,
'world_size': self._world_size,
'world_size': self.world_size,
'gpu': self._gpu, 'gpu_ids':self.gpu_ids,
'local_rank': self.local_rank})
@@ -128,21 +131,19 @@ class ApexUtils:
def _set_ranks(self)->None:
if 'WORLD_SIZE' in os.environ:
self._world_size = int(os.environ['WORLD_SIZE'])
assert dist.get_world_size() == self._world_size
self.world_size = int(os.environ['WORLD_SIZE'])
else:
raise RuntimeError('WORLD_SIZE must be set by distributed launcher when distributed mode is enabled')
self.world_size = 1
if 'LOCAL_RANK' in os.environ:
self.local_rank = int(os.environ['LOCAL_RANK'])
else:
raise RuntimeError('LOCAL_RANK must be set by distributed launcher when distributed mode is enabled')
self.local_rank = 0
self.global_rank = dist.get_rank()
# parser = argparse.ArgumentParser()
# parser.add_argument('--local-rank', type=int, help='local-rank must be supplied by torch distributed launcher!')
# args, extra_args = parser.parse_known_args()
# self.local_rank = args.local_rank
if 'RANK' in os.environ:
self.global_rank = int(os.environ['RANK'])
else:
self.global_rank = 0
assert self.local_rank < torch.cuda.device_count()
self._gpu = self.local_rank % torch.cuda.device_count()
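
The reworked _set_ranks in the hunk above reads WORLD_SIZE, LOCAL_RANK and RANK from the environment (the variables that torch.distributed launchers export for each worker) and falls back to single-process defaults when they are absent, so it can run before torch.distributed is initialized. A minimal standalone sketch of that pattern, not the repository code, with names chosen only for illustration:

import os
import torch

# Pick up the rank info exported by a torch.distributed launcher; fall back
# to a single-process run (world_size 1, ranks 0) when the variables are unset.
world_size = int(os.environ.get('WORLD_SIZE', 1))
local_rank = int(os.environ.get('LOCAL_RANK', 0))
global_rank = int(os.environ.get('RANK', 0))

# Each process drives one GPU; guard against machines with no visible device.
gpu = local_rank % max(torch.cuda.device_count(), 1)
print(f'world_size={world_size} local_rank={local_rank} '
      f'global_rank={global_rank} gpu={gpu}')
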
@@ -174,7 +175,7 @@ class ApexUtils:
r_op = self._op_map[op]
dist.all_reduce(rt, op=r_op)
if op=='mean':
rt /= self._world_size
rt /= self.world_size
if converted and len(rt.shape)==0:
return rt.item()
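
Note that the 'mean' entry in _op_map points at dist.ReduceOp.SUM, presumably because no mean reduction is available here; the hunk above therefore sums across ranks and divides the result by world_size. A self-contained sketch of the same pattern, assuming a process group is already initialized and a CUDA device is available:

import torch
import torch.distributed as dist

def all_reduce_mean(value: float) -> float:
    # Every rank contributes its value; after all_reduce each rank holds the sum.
    rt = torch.tensor(float(value)).cuda()
    dist.all_reduce(rt, op=dist.ReduceOp.SUM)
    # Divide by the number of participating processes to turn the sum into a mean.
    rt /= dist.get_world_size()
    return rt.item()
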
@@ -202,7 +203,7 @@ class ApexUtils:
# scale LR
if self._scale_lr:
lr = ml_utils.get_optim_lr(optim)
scaled_lr = lr * self._world_size / float(batch_size)
scaled_lr = lr * self.world_size / float(batch_size)
ml_utils.set_optim_lr(optim, scaled_lr)
self.logger.info({'lr_scaled': True, 'old_lr': lr, 'new_lr': scaled_lr})
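
The LR-scaling hunk multiplies the configured learning rate by world_size and divides by batch_size, following the formula already in the code. A quick worked example of that arithmetic with made-up numbers, not values from the repository's configs:

# Illustrative numbers only.
lr, world_size, batch_size = 0.025, 8, 64
scaled_lr = lr * world_size / float(batch_size)
print(scaled_lr)  # 0.003125
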

View file

@@ -236,7 +236,7 @@ def _setup_logger():
logs_yaml_filepath = utils.full_path(os.path.join(distdir, f'logs_{global_rank}.yaml'))
experiment_name = get_experiment_name() + '_' + str(global_rank)
enable_stdout = False
print('No stdout logging for replica {global_rank}')
print(f'No stdout logging for replica {global_rank}')
sys_logger = utils.create_logger(filepath=sys_log_filepath,
name=experiment_name,
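
The only change in the second file is the added f prefix on the print call, so {global_rank} is interpolated instead of being printed literally. For completeness (3 below is just an example value):

global_rank = 3
print('No stdout logging for replica {global_rank}')   # prints the braces literally
print(f'No stdout logging for replica {global_rank}')  # prints: No stdout logging for replica 3
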