зеркало из https://github.com/microsoft/archai.git
init ranks in constructor
This commit is contained in:
Родитель
3545fb852e
Коммит
a81c871435
|
@ -19,12 +19,12 @@ class ApexUtils:
|
|||
self._amp = self._ddp = None
|
||||
self._op_map = {'mean': dist.ReduceOp.SUM, 'sum': dist.ReduceOp.SUM,
|
||||
'min': dist.ReduceOp.MIN, 'max': dist.ReduceOp.MAX}
|
||||
self.gpu_ids = [] # use all gpus
|
||||
|
||||
self._mixed_prec_enabled = False
|
||||
self._distributed_enabled = False
|
||||
self._world_size = 1 # total number of processes in distributed run
|
||||
self.local_rank = 0
|
||||
self.global_rank = 0
|
||||
|
||||
self._set_ranks()
|
||||
|
||||
def reset(self, logger:OrderedDictLogger, apex_config:Config)->None:
|
||||
# reset allows to configure differently for search or eval modes
|
||||
|
@ -50,9 +50,6 @@ class ApexUtils:
|
|||
self.gpu_ids = [int(i) for i in conf_gpu_ids.split(',') if i]
|
||||
self._amp, self._ddp = None, None
|
||||
self._gpu = self.gpu_ids[0] if len(self.gpu_ids) else 0 # which GPU to use, we will use only 1 GPU
|
||||
self._world_size = 1 # total number of processes in distributed run
|
||||
self.local_rank = 0
|
||||
self.global_rank = 0
|
||||
|
||||
#logger.info({'apex_config': apex_config.to_dict()})
|
||||
logger.info({'torch.distributed.is_available': dist.is_available()})
|
||||
|
@ -80,11 +77,17 @@ class ApexUtils:
|
|||
assert dist.is_initialized()
|
||||
|
||||
self._set_ranks()
|
||||
assert dist.get_world_size() == self.world_size
|
||||
assert dist.get_rank() == self.global_rank
|
||||
else:
|
||||
assert self.world_size == 1
|
||||
assert self.local_rank == 0
|
||||
assert self.global_rank == 0
|
||||
|
||||
assert self._world_size >= 1
|
||||
assert not self._min_world_size or self._world_size >= self._min_world_size
|
||||
assert self.local_rank >= 0 and self.local_rank < self._world_size
|
||||
assert self.global_rank >= 0 and self.global_rank < self._world_size
|
||||
assert self.world_size >= 1
|
||||
assert not self._min_world_size or self.world_size >= self._min_world_size
|
||||
assert self.local_rank >= 0 and self.local_rank < self.world_size
|
||||
assert self.global_rank >= 0 and self.global_rank < self.world_size
|
||||
|
||||
assert self._gpu < torch.cuda.device_count()
|
||||
torch.cuda.set_device(self._gpu)
|
||||
|
@ -94,7 +97,7 @@ class ApexUtils:
|
|||
logger.info({'amp_available': self._amp is not None,
|
||||
'distributed_available': self._ddp is not None})
|
||||
logger.info({'dist_initialized': dist.is_initialized() if dist.is_available() else False,
|
||||
'world_size': self._world_size,
|
||||
'world_size': self.world_size,
|
||||
'gpu': self._gpu, 'gpu_ids':self.gpu_ids,
|
||||
'local_rank': self.local_rank})
|
||||
|
||||
|
@ -128,21 +131,19 @@ class ApexUtils:
|
|||
|
||||
def _set_ranks(self)->None:
|
||||
if 'WORLD_SIZE' in os.environ:
|
||||
self._world_size = int(os.environ['WORLD_SIZE'])
|
||||
assert dist.get_world_size() == self._world_size
|
||||
self.world_size = int(os.environ['WORLD_SIZE'])
|
||||
else:
|
||||
raise RuntimeError('WORLD_SIZE must be set by distributed launcher when distributed mode is enabled')
|
||||
self.world_size = 1
|
||||
|
||||
if 'LOCAL_RANK' in os.environ:
|
||||
self.local_rank = int(os.environ['LOCAL_RANK'])
|
||||
else:
|
||||
raise RuntimeError('LOCAL_RANK must be set by distributed launcher when distributed mode is enabled')
|
||||
self.local_rank = 0
|
||||
|
||||
self.global_rank = dist.get_rank()
|
||||
# parser = argparse.ArgumentParser()
|
||||
# parser.add_argument('--local-rank', type=int, help='local-rank must be supplied by torch distributed launcher!')
|
||||
# args, extra_args = parser.parse_known_args()
|
||||
# self.local_rank = args.local_rank
|
||||
if 'RANK' in os.environ:
|
||||
self.global_rank = int(os.environ['RANK'])
|
||||
else:
|
||||
self.global_rank = 0
|
||||
|
||||
assert self.local_rank < torch.cuda.device_count()
|
||||
self._gpu = self.local_rank % torch.cuda.device_count()
|
||||
|
@ -174,7 +175,7 @@ class ApexUtils:
|
|||
r_op = self._op_map[op]
|
||||
dist.all_reduce(rt, op=r_op)
|
||||
if op=='mean':
|
||||
rt /= self._world_size
|
||||
rt /= self.world_size
|
||||
|
||||
if converted and len(rt.shape)==0:
|
||||
return rt.item()
|
||||
|
@ -202,7 +203,7 @@ class ApexUtils:
|
|||
# scale LR
|
||||
if self._scale_lr:
|
||||
lr = ml_utils.get_optim_lr(optim)
|
||||
scaled_lr = lr * self._world_size / float(batch_size)
|
||||
scaled_lr = lr * self.world_size / float(batch_size)
|
||||
ml_utils.set_optim_lr(optim, scaled_lr)
|
||||
self.logger.info({'lr_scaled': True, 'old_lr': lr, 'new_lr': scaled_lr})
|
||||
|
||||
|
|
|
@ -236,7 +236,7 @@ def _setup_logger():
|
|||
logs_yaml_filepath = utils.full_path(os.path.join(distdir, f'logs_{global_rank}.yaml'))
|
||||
experiment_name = get_experiment_name() + '_' + str(global_rank)
|
||||
enable_stdout = False
|
||||
print('No stdout logging for replica {global_rank}')
|
||||
print(f'No stdout logging for replica {global_rank}')
|
||||
|
||||
sys_logger = utils.create_logger(filepath=sys_log_filepath,
|
||||
name=experiment_name,
|
||||
|
|
Загрузка…
Ссылка в новой задаче