зеркало из https://github.com/microsoft/torchgeo.git
Tests: downloaded weights have different number of classes (#1229)
This commit is contained in:
Родитель
80d78e5387
Коммит
674bc92dc9
|
@ -172,7 +172,8 @@ class TestClassificationTask:
|
|||
model_kwargs["model"] = weights.meta["model"]
|
||||
model_kwargs["in_channels"] = weights.meta["in_chans"]
|
||||
model_kwargs["weights"] = weights
|
||||
ClassificationTask(**model_kwargs)
|
||||
with pytest.warns(UserWarning):
|
||||
ClassificationTask(**model_kwargs)
|
||||
|
||||
@pytest.mark.slow
|
||||
def test_weight_str_download(
|
||||
|
@ -181,7 +182,8 @@ class TestClassificationTask:
|
|||
model_kwargs["model"] = weights.meta["model"]
|
||||
model_kwargs["in_channels"] = weights.meta["in_chans"]
|
||||
model_kwargs["weights"] = str(weights)
|
||||
ClassificationTask(**model_kwargs)
|
||||
with pytest.warns(UserWarning):
|
||||
ClassificationTask(**model_kwargs)
|
||||
|
||||
def test_invalid_loss(self, model_kwargs: Dict[str, Any]) -> None:
|
||||
model_kwargs["loss"] = "invalid_loss"
|
||||
|
|
|
@ -146,7 +146,8 @@ class TestRegressionTask:
|
|||
model_kwargs["model"] = weights.meta["model"]
|
||||
model_kwargs["in_channels"] = weights.meta["in_chans"]
|
||||
model_kwargs["weights"] = weights
|
||||
RegressionTask(**model_kwargs)
|
||||
with pytest.warns(UserWarning):
|
||||
RegressionTask(**model_kwargs)
|
||||
|
||||
@pytest.mark.slow
|
||||
def test_weight_str_download(
|
||||
|
@ -155,7 +156,8 @@ class TestRegressionTask:
|
|||
model_kwargs["model"] = weights.meta["model"]
|
||||
model_kwargs["in_channels"] = weights.meta["in_chans"]
|
||||
model_kwargs["weights"] = str(weights)
|
||||
RegressionTask(**model_kwargs)
|
||||
with pytest.warns(UserWarning):
|
||||
RegressionTask(**model_kwargs)
|
||||
|
||||
def test_no_rgb(
|
||||
self, monkeypatch: MonkeyPatch, model_kwargs: Dict[Any, Any], fast_dev_run: bool
|
||||
|
|
Загрузка…
Ссылка в новой задаче