diff --git a/tests/unit/test_common_pytorch_utils.py b/tests/unit/test_common_pytorch_utils.py index 1f5b6b2..c05223b 100644 --- a/tests/unit/test_common_pytorch_utils.py +++ b/tests/unit/test_common_pytorch_utils.py @@ -4,8 +4,18 @@ import pytest import torch import torch.nn as nn +from torch.nn.parallel.data_parallel import DataParallel +from torch.nn.modules.container import Sequential + from utils_nlp.common.pytorch_utils import get_device, move_to_device + +@pytest.fixture +def model(): + return nn.Sequential( + nn.Linear(24, 8), nn.ReLU(), nn.Linear(8, 2), nn.Sigmoid() + ) + def test_get_device_cpu(): device = get_device("cpu") @@ -19,7 +29,7 @@ def test_get_device_exception(): @pytest.mark.gpu -def test_gpu_machine(): +def test_machine_is_gpu_machine(): assert torch.cuda.is_available() is True @@ -29,12 +39,14 @@ def test_get_device_gpu(): assert isinstance(device, torch.device) assert device.type == "cuda" - -def test_move_to_device(): - model = nn.Sequential( - nn.Linear(24, 8), nn.ReLU(), nn.Linear(8, 2), nn.Sigmoid() - ) +def test_move_to_device_cpu(model): + # test when device.type="cpu" + model_cpu = move_to_device(model, torch.device("cpu")) + assert isinstance(model_cpu, nn.modules.container.Sequential) + + +def test_move_to_device_cpu_parallelized(model): # test when input model is parallelized model_parallelized = nn.DataParallel(model) model_parallelized_output = move_to_device( @@ -43,63 +55,62 @@ def test_move_to_device(): assert isinstance( model_parallelized_output, nn.modules.container.Sequential ) - + + +def test_move_to_device_exception_not_torch_device(model): # test when device is not torch.device with pytest.raises(ValueError): move_to_device(model, "abc") - - if torch.cuda.is_available(): - # test when device.type="cuda" - model_cuda = move_to_device(model, torch.device("cuda")) - num_cuda_devices = torch.cuda.device_count() - - if num_cuda_devices > 1: - assert isinstance( - model_cuda, nn.parallel.data_parallel.DataParallel - ) - else: - assert isinstance(model_cuda, nn.modules.container.Sequential) - - model_cuda_1_gpu = move_to_device( - model, torch.device("cuda"), num_gpus=1 - ) - assert isinstance(model_cuda_1_gpu, nn.modules.container.Sequential) - - model_cuda_1_more_gpu = move_to_device( - model, torch.device("cuda"), num_gpus=num_cuda_devices + 1 - ) - if num_cuda_devices > 1: - assert isinstance( - model_cuda_1_more_gpu, nn.parallel.data_parallel.DataParallel - ) - else: - assert isinstance( - model_cuda_1_more_gpu, nn.modules.container.Sequential - ) - - model_cuda_same_gpu = move_to_device( - model, torch.device("cuda"), num_gpus=num_cuda_devices - ) - if num_cuda_devices > 1: - assert isinstance( - model_cuda_same_gpu, nn.parallel.data_parallel.DataParallel - ) - else: - assert isinstance( - model_cuda_same_gpu, nn.modules.container.Sequential - ) - - # test when device.type is cuda, but num_gpus is 0 - with pytest.raises(ValueError): - move_to_device(model, torch.device("cuda"), num_gpus=0) - else: - with pytest.raises(Exception): - move_to_device(model, torch.device("cuda")) - - # test when device.type="cpu" - model_cpu = move_to_device(model, torch.device("cpu")) - assert isinstance(model_cpu, nn.modules.container.Sequential) - + + +def test_move_to_device_exception_wrong_type(model): # test when device.type is not "cuda" or "cpu" with pytest.raises(Exception): move_to_device(model, torch.device("opengl")) + + +def test_move_to_device_exception_gpu_model_on_cpu_machine(model): + # test when the model is moved to a gpu but it is a cpu machine + with pytest.raises(Exception): + move_to_device(model, torch.device("cuda")) + + +@pytest.mark.gpu +def test_move_to_device_exception_cuda_zero_gpus(model): + # test when device.type is cuda, but num_gpus is 0 + with pytest.raises(ValueError): + move_to_device(model, torch.device("cuda"), num_gpus=0) + + +@pytest.mark.gpu +def test_move_to_device_gpu(model): + # test when device.type="cuda" + model_cuda = move_to_device(model, torch.device("cuda")) + num_cuda_devices = torch.cuda.device_count() + + if num_cuda_devices > 1: + assert isinstance(model_cuda, DataParallel) + else: + assert isinstance(model_cuda, Sequential) + + model_cuda_1_gpu = move_to_device( + model, torch.device("cuda"), num_gpus=1 + ) + assert isinstance(model_cuda_1_gpu, Sequential) + + model_cuda_1_more_gpu = move_to_device( + model, torch.device("cuda"), num_gpus=num_cuda_devices + 1 + ) + if num_cuda_devices > 1: + assert isinstance(model_cuda_1_more_gpu, DataParallel) + else: + assert isinstance(model_cuda_1_more_gpu, Sequential) + + model_cuda_same_gpu = move_to_device( + model, torch.device("cuda"), num_gpus=num_cuda_devices + ) + if num_cuda_devices > 1: + assert isinstance(model_cuda_same_gpu, DataParallel) + else: + assert isinstance(model_cuda_same_gpu, Sequential) +