зеркало из https://github.com/microsoft/archai.git
fix getting dist rank
This commit is contained in:
Родитель
00a5d5f16b
Коммит
8969661426
|
@ -40,15 +40,15 @@ class DistributedStratifiedSampler(Sampler):
|
|||
assert hasattr(dataset, 'targets') and dataset.targets is not None, 'dataset needs to have targets attribute to work with this sampler'
|
||||
|
||||
if num_replicas is None:
|
||||
if not dist.is_available():
|
||||
num_replicas = 1
|
||||
else:
|
||||
if dist.is_available() and dist.is_initialized():
|
||||
num_replicas = dist.get_world_size()
|
||||
if rank is None:
|
||||
if not dist.is_available():
|
||||
rank = 0
|
||||
else:
|
||||
num_replicas = 1
|
||||
if rank is None:
|
||||
if dist.is_available() and dist.is_initialized():
|
||||
rank = dist.get_rank()
|
||||
else:
|
||||
rank = 0
|
||||
|
||||
assert num_replicas >= 1
|
||||
assert rank >= 0 and rank < num_replicas
|
||||
|
|
Загрузка…
Ссылка в новой задаче