refactor: split test_move_to_device into per-scenario tests with a shared model fixture
Parent: 7c3b82fda0
Commit: 6b41ede682
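Split the monolithic `test_move_to_device` into focused per-scenario tests (a CPU case, a parallelized-input case, the exception cases, and `@pytest.mark.gpu`-marked GPU cases), share the toy model through a pytest fixture, rename `test_gpu_machine` to `test_machine_is_gpu_machine`, and import `DataParallel` and `Sequential` directly to shorten the assertions.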
@@ -4,8 +4,18 @@
 import pytest
 import torch
 import torch.nn as nn
+from torch.nn.parallel.data_parallel import DataParallel
+from torch.nn.modules.container import Sequential
+
 from utils_nlp.common.pytorch_utils import get_device, move_to_device
 
 
+@pytest.fixture
+def model():
+    return nn.Sequential(
+        nn.Linear(24, 8), nn.ReLU(), nn.Linear(8, 2), nn.Sigmoid()
+    )
+
+
 def test_get_device_cpu():
     device = get_device("cpu")
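The hunk above introduces a shared `model` fixture in place of the network that `test_move_to_device` used to build inline (removed in the hunks below). As a minimal, self-contained sketch of the mechanism, assuming nothing beyond stock pytest (the test name here is illustrative, not part of the commit):

```python
import pytest
import torch.nn as nn


@pytest.fixture
def model():
    # Same toy network as in the hunk above: 24 inputs -> 2 sigmoid outputs.
    return nn.Sequential(
        nn.Linear(24, 8), nn.ReLU(), nn.Linear(8, 2), nn.Sigmoid()
    )


def test_fixture_injection(model):
    # pytest matches the parameter name "model" against known fixtures and
    # builds a fresh instance per test (default "function" scope), so the
    # split-out tests below cannot leak state through a shared module.
    assert isinstance(model, nn.Sequential)
```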
@@ -19,7 +29,7 @@ def test_get_device_exception():
 
 
 @pytest.mark.gpu
-def test_gpu_machine():
+def test_machine_is_gpu_machine():
     assert torch.cuda.is_available() is True
 
 
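The rename above only makes the test's intent explicit; its body is unchanged. Assuming the `gpu` marker is registered in the project's pytest configuration, CPU-only machines can deselect every CUDA-dependent test in this file with `pytest -m "not gpu"`, while a GPU build agent can run just those tests with `pytest -m gpu`.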
@@ -29,12 +39,14 @@ def test_get_device_gpu():
     assert isinstance(device, torch.device)
     assert device.type == "cuda"
 
 
-def test_move_to_device():
-    model = nn.Sequential(
-        nn.Linear(24, 8), nn.ReLU(), nn.Linear(8, 2), nn.Sigmoid()
-    )
-
+def test_move_to_device_cpu(model):
+    # test when device.type="cpu"
+    model_cpu = move_to_device(model, torch.device("cpu"))
+    assert isinstance(model_cpu, nn.modules.container.Sequential)
+
+
+def test_move_to_device_cpu_parallelized(model):
     # test when input model is parallelized
     model_parallelized = nn.DataParallel(model)
     model_parallelized_output = move_to_device(
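Per the assertion carried into the next hunk, the parallelized case hands `move_to_device` an `nn.DataParallel`-wrapped module and expects a plain `Sequential` back, so the helper must unwrap the container before moving the module. A hypothetical illustration of that unwrap step (the real `utils_nlp` implementation is not shown in this diff):

```python
import torch.nn as nn


def unwrap_if_parallel(model: nn.Module) -> nn.Module:
    # nn.DataParallel keeps the wrapped network in its .module attribute.
    return model.module if isinstance(model, nn.DataParallel) else model
```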
@@ -43,63 +55,62 @@ def test_move_to_device():
     assert isinstance(
         model_parallelized_output, nn.modules.container.Sequential
     )
 
+
+def test_move_to_device_exception_not_torch_device(model):
     # test when device is not torch.device
     with pytest.raises(ValueError):
         move_to_device(model, "abc")
 
-    if torch.cuda.is_available():
-        # test when device.type="cuda"
-        model_cuda = move_to_device(model, torch.device("cuda"))
-        num_cuda_devices = torch.cuda.device_count()
-
-        if num_cuda_devices > 1:
-            assert isinstance(
-                model_cuda, nn.parallel.data_parallel.DataParallel
-            )
-        else:
-            assert isinstance(model_cuda, nn.modules.container.Sequential)
-
-        model_cuda_1_gpu = move_to_device(
-            model, torch.device("cuda"), num_gpus=1
-        )
-        assert isinstance(model_cuda_1_gpu, nn.modules.container.Sequential)
-
-        model_cuda_1_more_gpu = move_to_device(
-            model, torch.device("cuda"), num_gpus=num_cuda_devices + 1
-        )
-        if num_cuda_devices > 1:
-            assert isinstance(
-                model_cuda_1_more_gpu, nn.parallel.data_parallel.DataParallel
-            )
-        else:
-            assert isinstance(
-                model_cuda_1_more_gpu, nn.modules.container.Sequential
-            )
-
-        model_cuda_same_gpu = move_to_device(
-            model, torch.device("cuda"), num_gpus=num_cuda_devices
-        )
-        if num_cuda_devices > 1:
-            assert isinstance(
-                model_cuda_same_gpu, nn.parallel.data_parallel.DataParallel
-            )
-        else:
-            assert isinstance(
-                model_cuda_same_gpu, nn.modules.container.Sequential
-            )
-
-        # test when device.type is cuda, but num_gpus is 0
-        with pytest.raises(ValueError):
-            move_to_device(model, torch.device("cuda"), num_gpus=0)
-    else:
-        with pytest.raises(Exception):
-            move_to_device(model, torch.device("cuda"))
-
-    # test when device.type="cpu"
-    model_cpu = move_to_device(model, torch.device("cpu"))
-    assert isinstance(model_cpu, nn.modules.container.Sequential)
-
+
+
+def test_move_to_device_exception_wrong_type(model):
     # test when device.type is not "cuda" or "cpu"
     with pytest.raises(Exception):
         move_to_device(model, torch.device("opengl"))
+
+
+def test_move_to_device_exception_gpu_model_on_cpu_machine(model):
+    # test when the model is moved to a gpu but it is a cpu machine
+    with pytest.raises(Exception):
+        move_to_device(model, torch.device("cuda"))
+
+
+@pytest.mark.gpu
+def test_move_to_device_exception_cuda_zero_gpus(model):
+    # test when device.type is cuda, but num_gpus is 0
+    with pytest.raises(ValueError):
+        move_to_device(model, torch.device("cuda"), num_gpus=0)
+
+
+@pytest.mark.gpu
+def test_move_to_device_gpu(model):
+    # test when device.type="cuda"
+    model_cuda = move_to_device(model, torch.device("cuda"))
+    num_cuda_devices = torch.cuda.device_count()
+
+    if num_cuda_devices > 1:
+        assert isinstance(model_cuda, DataParallel)
+    else:
+        assert isinstance(model_cuda, Sequential)
+
+    model_cuda_1_gpu = move_to_device(
+        model, torch.device("cuda"), num_gpus=1
+    )
+    assert isinstance(model_cuda_1_gpu, Sequential)
+
+    model_cuda_1_more_gpu = move_to_device(
+        model, torch.device("cuda"), num_gpus=num_cuda_devices + 1
+    )
+    if num_cuda_devices > 1:
+        assert isinstance(model_cuda_1_more_gpu, DataParallel)
+    else:
+        assert isinstance(model_cuda_1_more_gpu, Sequential)
+
+    model_cuda_same_gpu = move_to_device(
+        model, torch.device("cuda"), num_gpus=num_cuda_devices
+    )
+    if num_cuda_devices > 1:
+        assert isinstance(model_cuda_same_gpu, DataParallel)
+    else:
+        assert isinstance(model_cuda_same_gpu, Sequential)
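Taken together, the new tests pin down a contract for `move_to_device`. The sketch below is a hypothetical reconstruction inferred only from these assertions; the actual `utils_nlp.common.pytorch_utils` code is not part of this diff, and every detail beyond the asserted behavior is an assumption:

```python
import torch
import torch.nn as nn


def move_to_device_sketch(model, device, num_gpus=None):
    """Hypothetical reference implementation of the tested behavior."""
    if not isinstance(device, torch.device):
        raise ValueError("device must be a torch.device")  # the "abc" case
    # Unwrap an existing DataParallel container first (parallelized case).
    if isinstance(model, nn.DataParallel):
        model = model.module
    if device.type == "cpu":
        return model.to(device)
    if device.type == "cuda":
        if not torch.cuda.is_available():
            # The tests only require some Exception on a CPU-only machine.
            raise Exception("CUDA device requested on a CPU-only machine")
        if num_gpus == 0:
            raise ValueError("num_gpus must be non-zero for cuda devices")
        available = torch.cuda.device_count()
        n = available if num_gpus is None else min(num_gpus, available)
        model = model.to(device)
        # Wrap only when more than one GPU will actually be used, which is
        # why num_gpus=1 must yield a plain Sequential in the tests above.
        if n > 1:
            return nn.DataParallel(model, device_ids=list(range(n)))
        return model
    # Anything else (e.g. torch.device("opengl")) is rejected.
    raise Exception("Unsupported device type: {}".format(device.type))
```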