2023-01-24 01:08:17 +03:00
|
|
|
# Copyright (c) Microsoft Corporation. All rights reserved.
|
|
|
|
# Licensed under the MIT License.
|
|
|
|
|
2023-04-16 04:27:51 +03:00
|
|
|
from typing import Any
|
2023-01-24 01:08:17 +03:00
|
|
|
|
2023-03-17 22:20:25 +03:00
|
|
|
import matplotlib.pyplot as plt
|
2023-01-24 01:08:17 +03:00
|
|
|
import pytest
|
|
|
|
import torch
|
2023-03-17 22:20:25 +03:00
|
|
|
from _pytest.fixtures import SubRequest
|
2023-03-18 07:37:16 +03:00
|
|
|
from lightning.pytorch import Trainer
|
2023-09-21 14:35:57 +03:00
|
|
|
from matplotlib.figure import Figure
|
2023-03-17 22:20:25 +03:00
|
|
|
from rasterio.crs import CRS
|
2023-01-24 01:08:17 +03:00
|
|
|
from torch import Tensor
|
|
|
|
|
|
|
|
from torchgeo.datamodules import (
|
|
|
|
GeoDataModule,
|
|
|
|
MisconfigurationException,
|
|
|
|
NonGeoDataModule,
|
|
|
|
)
|
|
|
|
from torchgeo.datasets import BoundingBox, GeoDataset, NonGeoDataset
|
|
|
|
from torchgeo.samplers import RandomBatchGeoSampler, RandomGeoSampler
|
|
|
|
|
|
|
|
|
|
|
|
class CustomGeoDataset(GeoDataset):
    """Minimal in-memory ``GeoDataset`` stub used by the data module tests.

    The spatial index is filled with ``length`` identical entries and every
    query returns the same small synthetic image.
    """

    def __init__(
        self, split: str = 'train', length: int = 1, download: bool = False
    ) -> None:
        """Populate the index with ``length`` copies of one bounding box.

        Args:
            split: Accepted but not used by this stub.
            length: Number of entries inserted into the index.
            download: Accepted but not used by this stub.
        """
        super().__init__()
        bounds = (0, 1, 2, 3, 4, 5)
        for entry_id in range(length):
            self.index.insert(entry_id, bounds)
        self.res = 1

    def __getitem__(self, query: BoundingBox) -> dict[str, Any]:
        """Return the same fixed sample for any query.

        Args:
            query: Requested bounding box; echoed back under ``'bounds'``.

        Returns:
            A dict holding a 3x2x2 ``'image'`` tensor, an EPSG:4326 ``'crs'``
            and the ``query`` itself as ``'bounds'``.
        """
        pixels = torch.arange(12).view(3, 2, 2)
        sample: dict[str, Any] = {
            'image': pixels,
            'crs': CRS.from_epsg(4326),
            'bounds': query,
        }
        return sample

    def plot(self, *args: Any, **kwargs: Any) -> Figure:
        """Return a new matplotlib figure; all arguments are ignored."""
        figure = plt.figure()
        return figure
|
2023-01-24 01:08:17 +03:00
|
|
|
|
|
|
|
|
|
|
|
class CustomGeoDataModule(GeoDataModule):
    """``GeoDataModule`` bound to ``CustomGeoDataset`` for use in the tests."""

    def __init__(self) -> None:
        """Initialise the parent with ``CustomGeoDataset`` and ``download=True``."""
        # Keyword arguments are collected separately from the positional
        # configuration values handed to the parent constructor.
        dataset_kwargs = {'download': True}
        super().__init__(CustomGeoDataset, 1, 1, 1, 0, **dataset_kwargs)
|
|
|
|
|
|
|
|
|
2023-03-17 22:20:25 +03:00
|
|
|
class SamplerGeoDataModule(CustomGeoDataModule):
    """Data module that defines one ``RandomGeoSampler`` per split."""

    def setup(self, stage: str) -> None:
        """Create the dataset and a separate sampler for every split.

        Args:
            stage: Requested stage; not consulted, every sampler is always built.
        """
        self.dataset = CustomGeoDataset()
        # Each split gets its own sampler instance over the shared dataset.
        for split in ('train', 'val', 'test', 'predict'):
            setattr(self, f'{split}_sampler', RandomGeoSampler(self.dataset, 1, 1))
|
|
|
|
|
|
|
|
|
2023-03-17 22:20:25 +03:00
|
|
|
class BatchSamplerGeoDataModule(CustomGeoDataModule):
    """Data module that defines one ``RandomBatchGeoSampler`` per split."""

    def setup(self, stage: str) -> None:
        """Create the dataset and a separate batch sampler for every split.

        Args:
            stage: Requested stage; not consulted, every sampler is always built.
        """
        self.dataset = CustomGeoDataset()
        # Each split gets its own batch sampler instance over the shared dataset.
        for split in ('train', 'val', 'test', 'predict'):
            batch_sampler = RandomBatchGeoSampler(self.dataset, 1, 1, 1)
            setattr(self, f'{split}_batch_sampler', batch_sampler)
|
|
|
|
|
|
|
|
|
|
|
|
class CustomNonGeoDataset(NonGeoDataset):
    """Minimal ``NonGeoDataset`` stub with a configurable reported length."""

    def __init__(
        self, split: str = 'train', length: int = 1, download: bool = False
    ) -> None:
        """Remember the length this dataset should report.

        Args:
            split: Accepted but not used by this stub.
            length: Value returned by ``len()``.
            download: Accepted but not used by this stub.
        """
        self.length = length

    def __getitem__(self, index: int) -> dict[str, Tensor]:
        """Return the same fixed sample for any index.

        Args:
            index: Requested position; not consulted.

        Returns:
            A dict holding a 3x2x2 ``'image'`` tensor.
        """
        pixels = torch.arange(12).view(3, 2, 2)
        return {'image': pixels}

    def __len__(self) -> int:
        """Return the length supplied at construction time."""
        size = self.length
        return size

    def plot(self, *args: Any, **kwargs: Any) -> Figure:
        """Return a new matplotlib figure; all arguments are ignored."""
        figure = plt.figure()
        return figure
|
|
|
|
|
2023-01-24 01:08:17 +03:00
|
|
|
|
|
|
|
class CustomNonGeoDataModule(NonGeoDataModule):
    """``NonGeoDataModule`` bound to ``CustomNonGeoDataset`` for use in the tests."""

    def __init__(self) -> None:
        """Initialise the parent with ``CustomNonGeoDataset`` and ``download=True``."""
        dataset_kwargs = {'download': True}
        super().__init__(CustomNonGeoDataset, 1, 0, **dataset_kwargs)

    def setup(self, stage: str) -> None:
        """Run the parent setup, then add a predict dataset when predicting.

        Args:
            stage: Stage being set up; only ``'predict'`` triggers the extra
                ``predict_dataset`` assignment.
        """
        super().setup(stage)

        if stage == 'predict':
            self.predict_dataset = CustomNonGeoDataset()
|
|
|
|
|
2023-01-24 01:08:17 +03:00
|
|
|
|
|
|
|
class TestGeoDataModule:
    """Tests for ``GeoDataModule`` exercised through the custom data modules."""

    @pytest.fixture(params=[SamplerGeoDataModule, BatchSamplerGeoDataModule])
    def datamodule(self, request: SubRequest) -> CustomGeoDataModule:
        """Build each parametrized data module with a CPU trainer attached."""
        dm: CustomGeoDataModule = request.param()
        dm.trainer = Trainer(accelerator='cpu', max_epochs=1)
        return dm

    @pytest.mark.parametrize('stage', ['fit', 'validate', 'test'])
    def test_setup(self, stage: str) -> None:
        """Run ``prepare_data`` and ``setup`` for each stage on the base module."""
        dm = CustomGeoDataModule()
        dm.prepare_data()
        dm.setup(stage)

    def test_train(self, datamodule: CustomGeoDataModule) -> None:
        """Fetch one training batch and push it through the transfer hooks."""
        datamodule.setup('fit')
        if datamodule.trainer:
            datamodule.trainer.training = True
        batch = next(iter(datamodule.train_dataloader()))
        batch = datamodule.transfer_batch_to_device(batch, torch.device('cpu'), 1)
        batch = datamodule.on_after_batch_transfer(batch, 0)

    def test_val(self, datamodule: CustomGeoDataModule) -> None:
        """Fetch one validation batch and push it through the transfer hooks."""
        datamodule.setup('validate')
        if datamodule.trainer:
            datamodule.trainer.validating = True
        batch = next(iter(datamodule.val_dataloader()))
        batch = datamodule.transfer_batch_to_device(batch, torch.device('cpu'), 1)
        batch = datamodule.on_after_batch_transfer(batch, 0)

    def test_test(self, datamodule: CustomGeoDataModule) -> None:
        """Fetch one test batch and push it through the transfer hooks."""
        datamodule.setup('test')
        if datamodule.trainer:
            datamodule.trainer.testing = True
        batch = next(iter(datamodule.test_dataloader()))
        batch = datamodule.transfer_batch_to_device(batch, torch.device('cpu'), 1)
        batch = datamodule.on_after_batch_transfer(batch, 0)

    def test_predict(self, datamodule: CustomGeoDataModule) -> None:
        """Fetch one predict batch and push it through the transfer hooks."""
        datamodule.setup('predict')
        if datamodule.trainer:
            datamodule.trainer.predicting = True
        batch = next(iter(datamodule.predict_dataloader()))
        batch = datamodule.transfer_batch_to_device(batch, torch.device('cpu'), 1)
        batch = datamodule.on_after_batch_transfer(batch, 0)

    def test_plot(self, datamodule: CustomGeoDataModule) -> None:
        """Call ``plot`` after validation setup and close the current figure."""
        datamodule.setup('validate')
        datamodule.plot()
        plt.close()

    def test_no_datasets(self) -> None:
        """Every dataloader must raise when ``setup`` defined no dataset."""
        dm = CustomGeoDataModule()
        msg = r'CustomGeoDataModule\.setup must define one of '
        msg += r"\('{0}_dataset', 'dataset'\)\."
        with pytest.raises(MisconfigurationException, match=msg.format('train')):
            dm.train_dataloader()
        with pytest.raises(MisconfigurationException, match=msg.format('val')):
            dm.val_dataloader()
        with pytest.raises(MisconfigurationException, match=msg.format('test')):
            dm.test_dataloader()
        with pytest.raises(MisconfigurationException, match=msg.format('predict')):
            dm.predict_dataloader()

    def test_no_samplers(self) -> None:
        """Every dataloader must raise when a dataset exists but no sampler."""
        dm = CustomGeoDataModule()
        dm.dataset = CustomGeoDataset()
        msg = r'CustomGeoDataModule\.setup must define one of '
        msg += r"\('{0}_batch_sampler', '{0}_sampler', 'batch_sampler', 'sampler'\)\."
        with pytest.raises(MisconfigurationException, match=msg.format('train')):
            dm.train_dataloader()
        with pytest.raises(MisconfigurationException, match=msg.format('val')):
            dm.val_dataloader()
        with pytest.raises(MisconfigurationException, match=msg.format('test')):
            dm.test_dataloader()
        with pytest.raises(MisconfigurationException, match=msg.format('predict')):
            dm.predict_dataloader()

    def test_zero_length_dataset(self) -> None:
        """Every dataloader must raise when the dataset is empty."""
        dm = CustomGeoDataModule()
        dm.dataset = CustomGeoDataset(length=0)
        # The trailing period is escaped so it only matches a literal '.',
        # consistent with the other patterns in this class.
        msg = r'CustomGeoDataModule\.dataset has length 0\.'
        with pytest.raises(MisconfigurationException, match=msg):
            dm.train_dataloader()
        with pytest.raises(MisconfigurationException, match=msg):
            dm.val_dataloader()
        with pytest.raises(MisconfigurationException, match=msg):
            dm.test_dataloader()
        with pytest.raises(MisconfigurationException, match=msg):
            dm.predict_dataloader()

    def test_zero_length_sampler(self) -> None:
        """Every dataloader must raise when the sampler reports length 0."""
        dm = CustomGeoDataModule()
        dm.dataset = CustomGeoDataset()
        dm.sampler = RandomGeoSampler(dm.dataset, 1, 1)
        # Force the sampler to look empty without changing the dataset.
        dm.sampler.length = 0
        # The trailing period is escaped so it only matches a literal '.',
        # consistent with the other patterns in this class.
        msg = r'CustomGeoDataModule\.sampler has length 0\.'
        with pytest.raises(MisconfigurationException, match=msg):
            dm.train_dataloader()
        with pytest.raises(MisconfigurationException, match=msg):
            dm.val_dataloader()
        with pytest.raises(MisconfigurationException, match=msg):
            dm.test_dataloader()
        with pytest.raises(MisconfigurationException, match=msg):
            dm.predict_dataloader()
|
|
|
|
|
2023-01-24 01:08:17 +03:00
|
|
|
|
|
|
|
class TestNonGeoDataModule:
    """Tests for ``NonGeoDataModule`` exercised through the custom data module."""

    @pytest.fixture
    def datamodule(self) -> CustomNonGeoDataModule:
        """Build the custom data module with a CPU trainer attached."""
        dm = CustomNonGeoDataModule()
        dm.trainer = Trainer(accelerator='cpu', max_epochs=1)
        return dm

    @pytest.mark.parametrize('stage', ['fit', 'validate', 'test', 'predict'])
    def test_setup(self, stage: str) -> None:
        """Run ``prepare_data`` and ``setup`` for each stage."""
        dm = CustomNonGeoDataModule()
        dm.prepare_data()
        dm.setup(stage)

    def test_train(self, datamodule: CustomNonGeoDataModule) -> None:
        """Fetch one training batch and run the after-transfer hook on it."""
        datamodule.setup('fit')
        if datamodule.trainer:
            datamodule.trainer.training = True
        batch = next(iter(datamodule.train_dataloader()))
        batch = datamodule.on_after_batch_transfer(batch, 0)

    def test_val(self, datamodule: CustomNonGeoDataModule) -> None:
        """Fetch one validation batch and run the after-transfer hook on it."""
        datamodule.setup('validate')
        if datamodule.trainer:
            datamodule.trainer.validating = True
        batch = next(iter(datamodule.val_dataloader()))
        batch = datamodule.on_after_batch_transfer(batch, 0)

    def test_test(self, datamodule: CustomNonGeoDataModule) -> None:
        """Fetch one test batch and run the after-transfer hook on it."""
        datamodule.setup('test')
        if datamodule.trainer:
            datamodule.trainer.testing = True
        batch = next(iter(datamodule.test_dataloader()))
        batch = datamodule.on_after_batch_transfer(batch, 0)

    def test_predict(self, datamodule: CustomNonGeoDataModule) -> None:
        """Fetch one predict batch and run the after-transfer hook on it."""
        datamodule.setup('predict')
        if datamodule.trainer:
            datamodule.trainer.predicting = True
        batch = next(iter(datamodule.predict_dataloader()))
        batch = datamodule.on_after_batch_transfer(batch, 0)

    def test_plot(self, datamodule: CustomNonGeoDataModule) -> None:
        """Call ``plot`` after validation setup and close the current figure."""
        datamodule.setup('validate')
        datamodule.plot()
        plt.close()

    def test_no_datasets(self) -> None:
        """Every dataloader must raise when ``setup`` defined no dataset."""
        dm = CustomNonGeoDataModule()
        msg = r'CustomNonGeoDataModule\.setup must define one of '
        msg += r"\('{0}_dataset', 'dataset'\)\."
        with pytest.raises(MisconfigurationException, match=msg.format('train')):
            dm.train_dataloader()
        with pytest.raises(MisconfigurationException, match=msg.format('val')):
            dm.val_dataloader()
        with pytest.raises(MisconfigurationException, match=msg.format('test')):
            dm.test_dataloader()
        with pytest.raises(MisconfigurationException, match=msg.format('predict')):
            dm.predict_dataloader()

    def test_zero_length_dataset(self) -> None:
        """Every dataloader must raise when the dataset is empty."""
        dm = CustomNonGeoDataModule()
        dm.dataset = CustomNonGeoDataset(length=0)
        # The trailing period is escaped so it only matches a literal '.',
        # consistent with the other patterns in this class.
        msg = r'CustomNonGeoDataModule\.dataset has length 0\.'
        with pytest.raises(MisconfigurationException, match=msg):
            dm.train_dataloader()
        with pytest.raises(MisconfigurationException, match=msg):
            dm.val_dataloader()
        with pytest.raises(MisconfigurationException, match=msg):
            dm.test_dataloader()
        with pytest.raises(MisconfigurationException, match=msg):
            dm.predict_dataloader()
|