This commit is contained in:
Shital Shah 2020-04-22 12:17:50 -07:00
Родитель 00a5d5f16b
Коммит 8969661426
1 изменённых файлов: 6 добавлений и 6 удалений

Просмотреть файл

@ -40,15 +40,15 @@ class DistributedStratifiedSampler(Sampler):
assert hasattr(dataset, 'targets') and dataset.targets is not None, 'dataset needs to have targets attribute to work with this sampler'
if num_replicas is None:
if not dist.is_available():
num_replicas = 1
else:
if dist.is_available() and dist.is_initialized():
num_replicas = dist.get_world_size()
if rank is None:
if not dist.is_available():
rank = 0
else:
num_replicas = 1
if rank is None:
if dist.is_available() and dist.is_initialized():
rank = dist.get_rank()
else:
rank = 0
assert num_replicas >= 1
assert rank >= 0 and rank < num_replicas