Use pytest parametrize cross product (#1504)

* Use pytest parametrize cross product

* Bug fixes
Adam J. Stewart 2023-08-06 18:51:11 -05:00 committed by GitHub
Parent 5b189ecfd2
Commit ef5c8de28b
No key found matching this signature
GPG key ID: 4AEE18F83AFDEB23
2 changed files: 12 additions and 11 deletions
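
For context: stacking several @pytest.mark.parametrize decorators makes pytest generate the full cross product of their value lists, so it is equivalent to a single decorator whose values are built with itertools.product. A minimal sketch of the two styles (the parameter pools XS and YS below are placeholders, not the constants used in these test files):

import itertools

import pytest

# Placeholder parameter pools, standing in for constants such as BATCH_SIZE or CHANNELS.
XS = [1, 2]
YS = [3, 4]


# Old style: build the cross product by hand with itertools.product.
@pytest.mark.parametrize("x,y", list(itertools.product(XS, YS)))
def test_product_style(x: int, y: int) -> None:
    assert (x, y) in set(itertools.product(XS, YS))


# New style: stacked decorators; pytest combines them into the same four cases.
@pytest.mark.parametrize("x", XS)
@pytest.mark.parametrize("y", YS)
def test_stacked_style(x: int, y: int) -> None:
    assert (x, y) in set(itertools.product(XS, YS))

Both styles collect the same combinations; the stacked form simply drops the itertools import and declares each parameter on its own line.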

View File

@@ -1,8 +1,6 @@
# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT License.
-import itertools
import pytest
import torch
import torch.nn as nn
@@ -56,9 +54,10 @@ class TestChangeStar:
ChangeStarFarSeg(classes=4, backbone="anynet", backbone_pretrained=False)
@torch.no_grad()
-@pytest.mark.parametrize(
-    "inc,innerc,nc,sf", list(itertools.product(IN_CHANNELS, INNNR_CHANNELS, NC, SF))
-)
+@pytest.mark.parametrize("inc", IN_CHANNELS)
+@pytest.mark.parametrize("innerc", INNNR_CHANNELS)
+@pytest.mark.parametrize("nc", NC)
+@pytest.mark.parametrize("sf", SF)
def test_changemixin_output_size(
self, inc: int, innerc: int, nc: int, sf: int
) -> None:

View File

@@ -1,8 +1,6 @@
# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT License.
-import itertools
import pytest
import torch
@@ -15,7 +13,8 @@ CLASSES = [1, 2]
class TestFCSiamConc:
@torch.no_grad()
@pytest.mark.parametrize("b, c", list(itertools.product(BATCH_SIZE, CHANNELS)))
@pytest.mark.parametrize("b", BATCH_SIZE)
@pytest.mark.parametrize("c", CHANNELS)
def test_in_channels(self, b: int, c: int) -> None:
classes = 2
t, h, w = 2, 64, 64
@@ -25,7 +24,8 @@ class TestFCSiamConc:
assert y.shape == (b, classes, h, w)
@torch.no_grad()
@pytest.mark.parametrize("b, classes", list(itertools.product(BATCH_SIZE, CLASSES)))
@pytest.mark.parametrize("b", BATCH_SIZE)
@pytest.mark.parametrize("classes", CLASSES)
def test_classes(self, b: int, classes: int) -> None:
t, c, h, w = 2, 3, 64, 64
model = FCSiamConc(in_channels=3, classes=classes, encoder_weights=None)
@@ -36,7 +36,8 @@ class TestFCSiamConc:
class TestFCSiamDiff:
@torch.no_grad()
@pytest.mark.parametrize("b, c", list(itertools.product(BATCH_SIZE, CHANNELS)))
@pytest.mark.parametrize("b", BATCH_SIZE)
@pytest.mark.parametrize("c", CHANNELS)
def test_in_channels(self, b: int, c: int) -> None:
classes = 2
t, h, w = 2, 64, 64
@@ -46,7 +47,8 @@ class TestFCSiamDiff:
assert y.shape == (b, classes, h, w)
@torch.no_grad()
@pytest.mark.parametrize("b, classes", list(itertools.product(BATCH_SIZE, CLASSES)))
@pytest.mark.parametrize("b", BATCH_SIZE)
@pytest.mark.parametrize("classes", CLASSES)
def test_classes(self, b: int, classes: int) -> None:
t, c, h, w = 2, 3, 64, 64
model = FCSiamDiff(in_channels=3, classes=classes, encoder_weights=None)