refactor
This commit is contained in:
Родитель
7c3b82fda0
Коммит
6b41ede682
|
@ -4,9 +4,19 @@
|
|||
import pytest
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
from torch.nn.parallel.data_parallel import DataParallel
|
||||
from torch.nn.modules.container import Sequential
|
||||
|
||||
from utils_nlp.common.pytorch_utils import get_device, move_to_device
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def model():
|
||||
return nn.Sequential(
|
||||
nn.Linear(24, 8), nn.ReLU(), nn.Linear(8, 2), nn.Sigmoid()
|
||||
)
|
||||
|
||||
|
||||
def test_get_device_cpu():
|
||||
device = get_device("cpu")
|
||||
assert isinstance(device, torch.device)
|
||||
|
@ -19,7 +29,7 @@ def test_get_device_exception():
|
|||
|
||||
|
||||
@pytest.mark.gpu
|
||||
def test_gpu_machine():
|
||||
def test_machine_is_gpu_machine():
|
||||
assert torch.cuda.is_available() is True
|
||||
|
||||
|
||||
|
@ -30,11 +40,13 @@ def test_get_device_gpu():
|
|||
assert device.type == "cuda"
|
||||
|
||||
|
||||
def test_move_to_device():
|
||||
model = nn.Sequential(
|
||||
nn.Linear(24, 8), nn.ReLU(), nn.Linear(8, 2), nn.Sigmoid()
|
||||
)
|
||||
def test_move_to_device_cpu(model):
|
||||
# test when device.type="cpu"
|
||||
model_cpu = move_to_device(model, torch.device("cpu"))
|
||||
assert isinstance(model_cpu, nn.modules.container.Sequential)
|
||||
|
||||
|
||||
def test_move_to_device_cpu_parallelized(model):
|
||||
# test when input model is parallelized
|
||||
model_parallelized = nn.DataParallel(model)
|
||||
model_parallelized_output = move_to_device(
|
||||
|
@ -44,62 +56,61 @@ def test_move_to_device():
|
|||
model_parallelized_output, nn.modules.container.Sequential
|
||||
)
|
||||
|
||||
|
||||
def test_move_to_device_exception_not_torch_device(model):
|
||||
# test when device is not torch.device
|
||||
with pytest.raises(ValueError):
|
||||
move_to_device(model, "abc")
|
||||
|
||||
if torch.cuda.is_available():
|
||||
|
||||
def test_move_to_device_exception_wrong_type(model):
|
||||
# test when device.type is not "cuda" or "cpu"
|
||||
with pytest.raises(Exception):
|
||||
move_to_device(model, torch.device("opengl"))
|
||||
|
||||
|
||||
def test_move_to_device_exception_gpu_model_on_cpu_machine(model):
|
||||
# test when the model is moved to a gpu but it is a cpu machine
|
||||
with pytest.raises(Exception):
|
||||
move_to_device(model, torch.device("cuda"))
|
||||
|
||||
|
||||
@pytest.mark.gpu
|
||||
def test_move_to_device_exception_cuda_zero_gpus(model):
|
||||
# test when device.type is cuda, but num_gpus is 0
|
||||
with pytest.raises(ValueError):
|
||||
move_to_device(model, torch.device("cuda"), num_gpus=0)
|
||||
|
||||
|
||||
@pytest.mark.gpu
|
||||
def test_move_to_device_gpu(model):
|
||||
# test when device.type="cuda"
|
||||
model_cuda = move_to_device(model, torch.device("cuda"))
|
||||
num_cuda_devices = torch.cuda.device_count()
|
||||
|
||||
if num_cuda_devices > 1:
|
||||
assert isinstance(
|
||||
model_cuda, nn.parallel.data_parallel.DataParallel
|
||||
)
|
||||
assert isinstance(model_cuda, DataParallel)
|
||||
else:
|
||||
assert isinstance(model_cuda, nn.modules.container.Sequential)
|
||||
assert isinstance(model_cuda, Sequential)
|
||||
|
||||
model_cuda_1_gpu = move_to_device(
|
||||
model, torch.device("cuda"), num_gpus=1
|
||||
)
|
||||
assert isinstance(model_cuda_1_gpu, nn.modules.container.Sequential)
|
||||
assert isinstance(model_cuda_1_gpu, Sequential)
|
||||
|
||||
model_cuda_1_more_gpu = move_to_device(
|
||||
model, torch.device("cuda"), num_gpus=num_cuda_devices + 1
|
||||
)
|
||||
if num_cuda_devices > 1:
|
||||
assert isinstance(
|
||||
model_cuda_1_more_gpu, nn.parallel.data_parallel.DataParallel
|
||||
)
|
||||
assert isinstance(model_cuda_1_more_gpu, DataParallel)
|
||||
else:
|
||||
assert isinstance(
|
||||
model_cuda_1_more_gpu, nn.modules.container.Sequential
|
||||
)
|
||||
assert isinstance(model_cuda_1_more_gpu, Sequential)
|
||||
|
||||
model_cuda_same_gpu = move_to_device(
|
||||
model, torch.device("cuda"), num_gpus=num_cuda_devices
|
||||
)
|
||||
if num_cuda_devices > 1:
|
||||
assert isinstance(
|
||||
model_cuda_same_gpu, nn.parallel.data_parallel.DataParallel
|
||||
)
|
||||
assert isinstance(model_cuda_same_gpu, DataParallel)
|
||||
else:
|
||||
assert isinstance(
|
||||
model_cuda_same_gpu, nn.modules.container.Sequential
|
||||
)
|
||||
assert isinstance(model_cuda_same_gpu, Sequential)
|
||||
|
||||
# test when device.type is cuda, but num_gpus is 0
|
||||
with pytest.raises(ValueError):
|
||||
move_to_device(model, torch.device("cuda"), num_gpus=0)
|
||||
else:
|
||||
with pytest.raises(Exception):
|
||||
move_to_device(model, torch.device("cuda"))
|
||||
|
||||
# test when device.type="cpu"
|
||||
model_cpu = move_to_device(model, torch.device("cpu"))
|
||||
assert isinstance(model_cpu, nn.modules.container.Sequential)
|
||||
|
||||
# test when device.type is not "cuda" or "cpu"
|
||||
with pytest.raises(Exception):
|
||||
move_to_device(model, torch.device("opengl"))
|
||||
|
|
Загрузка…
Ссылка в новой задаче