зеркало из https://github.com/microsoft/archai.git
chore(tests): Adds nlp.objectives tests.
This commit is contained in:
Родитель
725b857776
Коммит
96708cf379
|
@ -7,7 +7,8 @@ addopts=-vv --durations=10
|
|||
# Do not run tests in the build folder
|
||||
norecursedirs=build
|
||||
|
||||
# Deprecated warnings to be ignored
|
||||
# Warnings to be ignored
|
||||
filterwarnings=
|
||||
ignore::DeprecationWarning
|
||||
ignore::UserWarning
|
||||
ignore::torch.jit.TracerWarning
|
||||
|
|
|
@ -0,0 +1,60 @@
|
|||
# Copyright (c) Microsoft Corporation.
|
||||
# Licensed under the MIT license.
|
||||
|
||||
import pytest
|
||||
from torch import nn
|
||||
from torch.nn import functional as F
|
||||
|
||||
from archai.discrete_search import ArchaiModel
|
||||
from archai.nlp.objectives.parameters import NonEmbeddingParamsProxy, TotalParamsProxy
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def model():
|
||||
class Model(nn.Module):
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
self.embd = nn.Embedding(10, 10)
|
||||
self.conv1 = nn.Conv2d(1, 10, kernel_size=5)
|
||||
self.conv2 = nn.Conv2d(10, 20, kernel_size=5)
|
||||
self.fc1 = nn.Linear(320, 50)
|
||||
self.fc2 = nn.Linear(50, 10)
|
||||
|
||||
def forward(self, x):
|
||||
x = self.embd(x)
|
||||
x = F.relu(F.max_pool2d(self.conv1(x), 2))
|
||||
x = F.relu(F.max_pool2d(self.conv2(x), 2))
|
||||
x = x.view(-1, 320)
|
||||
x = F.relu(self.fc1(x))
|
||||
x = self.fc2(x)
|
||||
return x
|
||||
|
||||
return ArchaiModel(Model(), archid="test")
|
||||
|
||||
|
||||
def test_total_params_proxy(model):
|
||||
# Assert that the number of trainable parameters is correct
|
||||
proxy = TotalParamsProxy(trainable_only=True)
|
||||
num_params = proxy.evaluate(model, None)
|
||||
assert num_params == sum(param.numel() for param in model.arch.parameters() if param.requires_grad)
|
||||
|
||||
# Assert that the number of all parameters is correct
|
||||
proxy = TotalParamsProxy(trainable_only=False)
|
||||
num_params = proxy.evaluate(model, None)
|
||||
assert num_params == sum(param.numel() for param in model.arch.parameters())
|
||||
|
||||
|
||||
def test_non_embedding_params_proxy(model):
|
||||
# Assert that the number of non-embedding trainable parameters is correct
|
||||
proxy = NonEmbeddingParamsProxy(trainable_only=True)
|
||||
non_embedding_params = proxy.evaluate(model, None)
|
||||
embedding_params = sum(param.numel() for param in model.arch.embd.parameters() if param.requires_grad)
|
||||
assert non_embedding_params + embedding_params == sum(
|
||||
param.numel() for param in model.arch.parameters() if param.requires_grad
|
||||
)
|
||||
|
||||
# Assert that the number of non-embedding parameters is correct
|
||||
proxy = NonEmbeddingParamsProxy(trainable_only=False)
|
||||
non_embedding_params = proxy.evaluate(model, None)
|
||||
embedding_params = sum(param.numel() for param in model.arch.embd.parameters())
|
||||
assert non_embedding_params + embedding_params == sum(param.numel() for param in model.arch.parameters())
|
|
@ -0,0 +1,23 @@
|
|||
# Copyright (c) Microsoft Corporation.
|
||||
# Licensed under the MIT license.
|
||||
|
||||
import pytest
|
||||
|
||||
from archai.nlp.objectives.transformer_flex_latency import TransformerFlexOnnxLatency
|
||||
from archai.nlp.search_spaces.transformer_flex.search_space import (
|
||||
TransformerFlexSearchSpace,
|
||||
)
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def search_space():
|
||||
return TransformerFlexSearchSpace("gpt2")
|
||||
|
||||
|
||||
def test_transformer_flex_onnx_latency(search_space):
|
||||
arch = search_space.random_sample()
|
||||
objective = TransformerFlexOnnxLatency(search_space)
|
||||
|
||||
# Assert that the returned latency is valid
|
||||
latency = objective.evaluate(arch, None)
|
||||
assert latency > 0.0
|
|
@ -0,0 +1,23 @@
|
|||
# Copyright (c) Microsoft Corporation.
|
||||
# Licensed under the MIT license.
|
||||
|
||||
import pytest
|
||||
|
||||
from archai.nlp.objectives.transformer_flex_memory import TransformerFlexOnnxMemory
|
||||
from archai.nlp.search_spaces.transformer_flex.search_space import (
|
||||
TransformerFlexSearchSpace,
|
||||
)
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def search_space():
|
||||
return TransformerFlexSearchSpace("gpt2")
|
||||
|
||||
|
||||
def test_transformer_flex_onnx_memory(search_space):
|
||||
arch = search_space.random_sample()
|
||||
objective = TransformerFlexOnnxMemory(search_space)
|
||||
|
||||
# Assert that the returned memory is valid
|
||||
memory = objective.evaluate(arch, None)
|
||||
assert memory > 0.0
|
Загрузка…
Ссылка в новой задаче