Change local_rank to 0 to make test pass on 1 GPU machine.

This commit is contained in:
hlums 2020-01-22 16:24:04 +00:00
Родитель fc1517efb7
Коммит b5a9f502f3
1 изменённых файлов: 2 добавлений и 3 удалений

Просмотреть файл

@ -46,10 +46,10 @@ def test_get_device_all_gpus():
@pytest.mark.gpu
def test_get_device_local_rank():
device, gpus = get_device(local_rank=1)
device, gpus = get_device(local_rank=0)
assert isinstance(device, torch.device)
assert device.type == "cuda"
assert device.index == 1
assert device.index == 0
assert gpus == 1
@ -121,4 +121,3 @@ def test_move_to_device_gpu(model):
assert isinstance(model_cuda_same_gpu, DataParallel)
else:
assert isinstance(model_cuda_same_gpu, Sequential)