зеркало из https://github.com/microsoft/torchgeo.git
Use pytest parametrize cross product (#1504)
* Use pytest parametrize cross product * Bug fixes
This commit is contained in:
Родитель
5b189ecfd2
Коммит
ef5c8de28b
|
@ -1,8 +1,6 @@
|
|||
# Copyright (c) Microsoft Corporation. All rights reserved.
|
||||
# Licensed under the MIT License.
|
||||
|
||||
import itertools
|
||||
|
||||
import pytest
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
|
@ -56,9 +54,10 @@ class TestChangeStar:
|
|||
ChangeStarFarSeg(classes=4, backbone="anynet", backbone_pretrained=False)
|
||||
|
||||
@torch.no_grad()
|
||||
@pytest.mark.parametrize(
|
||||
"inc,innerc,nc,sf", list(itertools.product(IN_CHANNELS, INNNR_CHANNELS, NC, SF))
|
||||
)
|
||||
@pytest.mark.parametrize("inc", IN_CHANNELS)
|
||||
@pytest.mark.parametrize("innerc", INNNR_CHANNELS)
|
||||
@pytest.mark.parametrize("nc", NC)
|
||||
@pytest.mark.parametrize("sf", SF)
|
||||
def test_changemixin_output_size(
|
||||
self, inc: int, innerc: int, nc: int, sf: int
|
||||
) -> None:
|
||||
|
|
|
@ -1,8 +1,6 @@
|
|||
# Copyright (c) Microsoft Corporation. All rights reserved.
|
||||
# Licensed under the MIT License.
|
||||
|
||||
import itertools
|
||||
|
||||
import pytest
|
||||
import torch
|
||||
|
||||
|
@ -15,7 +13,8 @@ CLASSES = [1, 2]
|
|||
|
||||
class TestFCSiamConc:
|
||||
@torch.no_grad()
|
||||
@pytest.mark.parametrize("b, c", list(itertools.product(BATCH_SIZE, CHANNELS)))
|
||||
@pytest.mark.parametrize("b", BATCH_SIZE)
|
||||
@pytest.mark.parametrize("c", CHANNELS)
|
||||
def test_in_channels(self, b: int, c: int) -> None:
|
||||
classes = 2
|
||||
t, h, w = 2, 64, 64
|
||||
|
@ -25,7 +24,8 @@ class TestFCSiamConc:
|
|||
assert y.shape == (b, classes, h, w)
|
||||
|
||||
@torch.no_grad()
|
||||
@pytest.mark.parametrize("b, classes", list(itertools.product(BATCH_SIZE, CLASSES)))
|
||||
@pytest.mark.parametrize("b", BATCH_SIZE)
|
||||
@pytest.mark.parametrize("classes", CLASSES)
|
||||
def test_classes(self, b: int, classes: int) -> None:
|
||||
t, c, h, w = 2, 3, 64, 64
|
||||
model = FCSiamConc(in_channels=3, classes=classes, encoder_weights=None)
|
||||
|
@ -36,7 +36,8 @@ class TestFCSiamConc:
|
|||
|
||||
class TestFCSiamDiff:
|
||||
@torch.no_grad()
|
||||
@pytest.mark.parametrize("b, c", list(itertools.product(BATCH_SIZE, CHANNELS)))
|
||||
@pytest.mark.parametrize("b", BATCH_SIZE)
|
||||
@pytest.mark.parametrize("c", CHANNELS)
|
||||
def test_in_channels(self, b: int, c: int) -> None:
|
||||
classes = 2
|
||||
t, h, w = 2, 64, 64
|
||||
|
@ -46,7 +47,8 @@ class TestFCSiamDiff:
|
|||
assert y.shape == (b, classes, h, w)
|
||||
|
||||
@torch.no_grad()
|
||||
@pytest.mark.parametrize("b, classes", list(itertools.product(BATCH_SIZE, CLASSES)))
|
||||
@pytest.mark.parametrize("b", BATCH_SIZE)
|
||||
@pytest.mark.parametrize("classes", CLASSES)
|
||||
def test_classes(self, b: int, classes: int) -> None:
|
||||
t, c, h, w = 2, 3, 64, 64
|
||||
model = FCSiamDiff(in_channels=3, classes=classes, encoder_weights=None)
|
||||
|
|
Загрузка…
Ссылка в новой задаче