Change local_rank to 0 to make test pass on 1 GPU machine.
This commit is contained in:
Родитель
fc1517efb7
Коммит
b5a9f502f3
|
@ -46,10 +46,10 @@ def test_get_device_all_gpus():
|
|||
|
||||
@pytest.mark.gpu
|
||||
def test_get_device_local_rank():
|
||||
device, gpus = get_device(local_rank=1)
|
||||
device, gpus = get_device(local_rank=0)
|
||||
assert isinstance(device, torch.device)
|
||||
assert device.type == "cuda"
|
||||
assert device.index == 1
|
||||
assert device.index == 0
|
||||
assert gpus == 1
|
||||
|
||||
|
||||
|
@ -121,4 +121,3 @@ def test_move_to_device_gpu(model):
|
|||
assert isinstance(model_cuda_same_gpu, DataParallel)
|
||||
else:
|
||||
assert isinstance(model_cuda_same_gpu, Sequential)
|
||||
|
||||
|
|
Загрузка…
Ссылка в новой задаче