зеркало из https://github.com/microsoft/torchgeo.git
51 строка
1.7 KiB
Python
51 строка
1.7 KiB
Python
# Copyright (c) Microsoft Corporation. All rights reserved.
|
|
# Licensed under the MIT License.
|
|
|
|
import os
|
|
|
|
import matplotlib.pyplot as plt
|
|
import pytest
|
|
from _pytest.fixtures import SubRequest
|
|
|
|
from torchgeo.datamodules import USAVarsDataModule
|
|
from torchgeo.datasets import unbind_samples
|
|
|
|
|
|
class TestUSAVarsDataModule:
|
|
@pytest.fixture
|
|
def datamodule(self, request: SubRequest) -> USAVarsDataModule:
|
|
root = os.path.join('tests', 'data', 'usavars')
|
|
batch_size = 1
|
|
num_workers = 0
|
|
|
|
dm = USAVarsDataModule(
|
|
root=root, batch_size=batch_size, num_workers=num_workers, download=True
|
|
)
|
|
dm.prepare_data()
|
|
return dm
|
|
|
|
def test_train_dataloader(self, datamodule: USAVarsDataModule) -> None:
|
|
datamodule.setup('fit')
|
|
assert len(datamodule.train_dataloader()) == 3
|
|
batch = next(iter(datamodule.train_dataloader()))
|
|
assert batch['image'].shape[0] == datamodule.batch_size
|
|
|
|
def test_val_dataloader(self, datamodule: USAVarsDataModule) -> None:
|
|
datamodule.setup('validate')
|
|
assert len(datamodule.val_dataloader()) == 2
|
|
batch = next(iter(datamodule.val_dataloader()))
|
|
assert batch['image'].shape[0] == datamodule.batch_size
|
|
|
|
def test_test_dataloader(self, datamodule: USAVarsDataModule) -> None:
|
|
datamodule.setup('test')
|
|
assert len(datamodule.test_dataloader()) == 1
|
|
batch = next(iter(datamodule.test_dataloader()))
|
|
assert batch['image'].shape[0] == datamodule.batch_size
|
|
|
|
def test_plot(self, datamodule: USAVarsDataModule) -> None:
|
|
datamodule.setup('validate')
|
|
batch = next(iter(datamodule.val_dataloader()))
|
|
sample = unbind_samples(batch)[0]
|
|
datamodule.plot(sample)
|
|
plt.close()
|