2023-01-24 01:08:17 +03:00
|
|
|
# Copyright (c) Microsoft Corporation. All rights reserved.
|
|
|
|
# Licensed under the MIT License.
|
|
|
|
|
|
|
|
from typing import Any, Dict
|
|
|
|
|
2023-03-17 22:20:25 +03:00
|
|
|
import matplotlib.pyplot as plt
|
2023-01-24 01:08:17 +03:00
|
|
|
import pytest
|
|
|
|
import torch
|
2023-03-17 22:20:25 +03:00
|
|
|
from _pytest.fixtures import SubRequest
|
|
|
|
from pytorch_lightning import Trainer
|
|
|
|
from rasterio.crs import CRS
|
2023-01-24 01:08:17 +03:00
|
|
|
from torch import Tensor
|
|
|
|
|
|
|
|
from torchgeo.datamodules import (
|
|
|
|
GeoDataModule,
|
|
|
|
MisconfigurationException,
|
|
|
|
NonGeoDataModule,
|
|
|
|
)
|
|
|
|
from torchgeo.datasets import BoundingBox, GeoDataset, NonGeoDataset
|
|
|
|
from torchgeo.samplers import RandomBatchGeoSampler, RandomGeoSampler
|
|
|
|
|
|
|
|
|
|
|
|
class CustomGeoDataset(GeoDataset):
    """Tiny synthetic :class:`GeoDataset` used as a stand-in by the tests below."""

    def __init__(self, split: str = "train", download: bool = False) -> None:
        """Initialize the dataset.

        Args:
            split: accepted for signature compatibility; not used here.
            download: accepted for signature compatibility; not used here.
        """
        super().__init__()
        # Register a single entry (id 0) with fixed coordinates in the index
        # created by the parent constructor.
        bounds = (0, 1, 2, 3, 4, 5)
        self.index.insert(0, bounds)
        self.res = 1

    def __getitem__(self, query: BoundingBox) -> Dict[str, Any]:
        """Return a fixed sample for any query.

        Args:
            query: bounding box that is echoed back under the ``"bbox"`` key.

        Returns:
            A dict holding a 3x2x2 integer tensor (values 0..11), an
            EPSG:4326 CRS, and the query itself.
        """
        sample: Dict[str, Any] = {}
        sample["image"] = torch.arange(12).view(3, 2, 2)
        sample["crs"] = CRS.from_epsg(4326)
        sample["bbox"] = query
        return sample

    def plot(self, *args: Any, **kwargs: Any) -> plt.Figure:
        """Return a new, empty matplotlib figure; all arguments are ignored."""
        figure = plt.figure()
        return figure
|
2023-01-24 01:08:17 +03:00
|
|
|
|
|
|
|
|
|
|
|
class CustomGeoDataModule(GeoDataModule):
    """:class:`GeoDataModule` wired to :class:`CustomGeoDataset` for testing."""

    def __init__(self) -> None:
        """Forward fixed positional settings and ``download=True`` to the parent."""
        dataset_kwargs = {"download": True}
        super().__init__(CustomGeoDataset, 1, 1, 1, 0, **dataset_kwargs)
|
|
|
|
|
|
|
|
|
2023-03-17 22:20:25 +03:00
|
|
|
class SamplerGeoDataModule(CustomGeoDataModule):
    """Data module variant that configures one plain sampler per stage."""

    def setup(self, stage: str) -> None:
        """Create the shared dataset and a sampler for every stage.

        Args:
            stage: requested stage; ignored, since all samplers are always built.
        """
        self.dataset = CustomGeoDataset()
        # One independent sampler per stage, all backed by the same dataset.
        for prefix in ("train", "val", "test", "predict"):
            setattr(self, f"{prefix}_sampler", RandomGeoSampler(self.dataset, 1, 1))
|
|
|
|
|
|
|
|
|
2023-03-17 22:20:25 +03:00
|
|
|
class BatchSamplerGeoDataModule(CustomGeoDataModule):
    """Data module variant that configures one batch sampler per stage."""

    def setup(self, stage: str) -> None:
        """Create the shared dataset and a batch sampler for every stage.

        Args:
            stage: requested stage; ignored, since all samplers are always built.
        """
        self.dataset = CustomGeoDataset()
        # One independent batch sampler per stage, all backed by the same dataset.
        for prefix in ("train", "val", "test", "predict"):
            batch_sampler = RandomBatchGeoSampler(self.dataset, 1, 1, 1)
            setattr(self, f"{prefix}_batch_sampler", batch_sampler)
|
|
|
|
|
|
|
|
|
|
|
|
class CustomNonGeoDataset(NonGeoDataset):
    """Single-sample synthetic :class:`NonGeoDataset` used by the tests below."""

    def __init__(self, split: str = "train", download: bool = False) -> None:
        """Accept ``split`` and ``download`` without using them; nothing is set up."""

    def __getitem__(self, index: int) -> Dict[str, Tensor]:
        """Return the same sample regardless of ``index``.

        Args:
            index: requested position; ignored.

        Returns:
            A dict with a 3x2x2 integer tensor (values 0..11) under ``"image"``.
        """
        image = torch.arange(12).view(3, 2, 2)
        return {"image": image}

    def __len__(self) -> int:
        """Report a dataset length of exactly one."""
        return 1

    def plot(self, *args: Any, **kwargs: Any) -> plt.Figure:
        """Return a new, empty matplotlib figure; all arguments are ignored."""
        figure = plt.figure()
        return figure
|
|
|
|
|
2023-01-24 01:08:17 +03:00
|
|
|
|
|
|
|
class CustomNonGeoDataModule(NonGeoDataModule):
    """:class:`NonGeoDataModule` wired to :class:`CustomNonGeoDataset` for testing."""

    def __init__(self) -> None:
        """Forward fixed positional settings and ``download=True`` to the parent."""
        dataset_kwargs = {"download": True}
        super().__init__(CustomNonGeoDataset, 1, 0, **dataset_kwargs)

    def setup(self, stage: str) -> None:
        """Run the parent setup, then add a predict dataset when asked for one.

        Args:
            stage: requested stage; ``"predict"`` additionally creates
                ``self.predict_dataset``.
        """
        super().setup(stage)

        if stage == "predict":
            self.predict_dataset = CustomNonGeoDataset()
|
|
|
|
|
2023-01-24 01:08:17 +03:00
|
|
|
|
|
|
|
class TestGeoDataModule:
    """Tests for the GeoDataModule subclasses defined in this module."""

    @pytest.fixture(params=[SamplerGeoDataModule, BatchSamplerGeoDataModule])
    def datamodule(self, request: SubRequest) -> CustomGeoDataModule:
        """Build each sampler-based data module and attach a one-epoch Trainer."""
        module_cls = request.param
        dm: CustomGeoDataModule = module_cls()
        dm.trainer = Trainer(max_epochs=1)
        return dm

    @pytest.mark.parametrize("stage", ["fit", "validate", "test"])
    def test_setup(self, stage: str) -> None:
        """``prepare_data`` followed by ``setup`` runs for each listed stage."""
        module = CustomGeoDataModule()
        module.prepare_data()
        module.setup(stage)

    def test_train(self, datamodule: CustomGeoDataModule) -> None:
        """Pull one training batch through the transfer hooks."""
        datamodule.setup("fit")
        datamodule.trainer.training = True  # type: ignore[union-attr]
        loader = datamodule.train_dataloader()
        sample = next(iter(loader))
        cpu = torch.device("cpu")
        sample = datamodule.transfer_batch_to_device(sample, cpu, 1)
        datamodule.on_after_batch_transfer(sample, 0)

    def test_val(self, datamodule: CustomGeoDataModule) -> None:
        """Pull one validation batch through the transfer hooks."""
        datamodule.setup("validate")
        datamodule.trainer.validating = True  # type: ignore[union-attr]
        loader = datamodule.val_dataloader()
        sample = next(iter(loader))
        cpu = torch.device("cpu")
        sample = datamodule.transfer_batch_to_device(sample, cpu, 1)
        datamodule.on_after_batch_transfer(sample, 0)

    def test_test(self, datamodule: CustomGeoDataModule) -> None:
        """Pull one test batch through the transfer hooks."""
        datamodule.setup("test")
        datamodule.trainer.testing = True  # type: ignore[union-attr]
        loader = datamodule.test_dataloader()
        sample = next(iter(loader))
        cpu = torch.device("cpu")
        sample = datamodule.transfer_batch_to_device(sample, cpu, 1)
        datamodule.on_after_batch_transfer(sample, 0)

    def test_predict(self, datamodule: CustomGeoDataModule) -> None:
        """Pull one prediction batch through the transfer hooks."""
        datamodule.setup("predict")
        datamodule.trainer.predicting = True  # type: ignore[union-attr]
        loader = datamodule.predict_dataloader()
        sample = next(iter(loader))
        cpu = torch.device("cpu")
        sample = datamodule.transfer_batch_to_device(sample, cpu, 1)
        datamodule.on_after_batch_transfer(sample, 0)

    def test_plot(self, datamodule: CustomGeoDataModule) -> None:
        """``plot`` can be called after validation setup; close the figure after."""
        datamodule.setup("validate")
        datamodule.plot()
        plt.close()

    def test_no_datasets(self) -> None:
        """Every dataloader raises when ``setup`` has not been called."""
        module = CustomGeoDataModule()
        msg = "CustomGeoDataModule.setup does not define a '{}_dataset'"
        # Same order as the stages are normally exercised.
        for prefix in ("train", "val", "test", "predict"):
            with pytest.raises(MisconfigurationException, match=msg.format(prefix)):
                getattr(module, f"{prefix}_dataloader")()
|
|
|
|
|
|
|
|
|
|
|
|
class TestNonGeoDataModule:
    """Tests for the NonGeoDataModule subclass defined in this module."""

    @pytest.fixture
    def datamodule(self) -> CustomNonGeoDataModule:
        """Build the data module and attach a one-epoch Trainer."""
        module = CustomNonGeoDataModule()
        module.trainer = Trainer(max_epochs=1)
        return module

    @pytest.mark.parametrize("stage", ["fit", "validate", "test", "predict"])
    def test_setup(self, stage: str) -> None:
        """``prepare_data`` followed by ``setup`` runs for each listed stage."""
        module = CustomNonGeoDataModule()
        module.prepare_data()
        module.setup(stage)

    def test_train(self, datamodule: CustomNonGeoDataModule) -> None:
        """Pull one training batch through ``on_after_batch_transfer``."""
        datamodule.setup("fit")
        datamodule.trainer.training = True  # type: ignore[union-attr]
        loader = datamodule.train_dataloader()
        sample = next(iter(loader))
        datamodule.on_after_batch_transfer(sample, 0)

    def test_val(self, datamodule: CustomNonGeoDataModule) -> None:
        """Pull one validation batch through ``on_after_batch_transfer``."""
        datamodule.setup("validate")
        datamodule.trainer.validating = True  # type: ignore[union-attr]
        loader = datamodule.val_dataloader()
        sample = next(iter(loader))
        datamodule.on_after_batch_transfer(sample, 0)

    def test_test(self, datamodule: CustomNonGeoDataModule) -> None:
        """Pull one test batch through ``on_after_batch_transfer``."""
        datamodule.setup("test")
        datamodule.trainer.testing = True  # type: ignore[union-attr]
        loader = datamodule.test_dataloader()
        sample = next(iter(loader))
        datamodule.on_after_batch_transfer(sample, 0)

    def test_predict(self, datamodule: CustomNonGeoDataModule) -> None:
        """Pull one prediction batch through ``on_after_batch_transfer``."""
        datamodule.setup("predict")
        datamodule.trainer.predicting = True  # type: ignore[union-attr]
        loader = datamodule.predict_dataloader()
        sample = next(iter(loader))
        datamodule.on_after_batch_transfer(sample, 0)

    def test_plot(self, datamodule: CustomNonGeoDataModule) -> None:
        """``plot`` can be called after validation setup; close the figure after."""
        datamodule.setup("validate")
        datamodule.plot()
        plt.close()

    def test_no_datasets(self) -> None:
        """Every dataloader raises when ``setup`` has not been called."""
        module = CustomNonGeoDataModule()
        msg = "CustomNonGeoDataModule.setup does not define a '{}_dataset'"
        # Same order as the stages are normally exercised.
        for prefix in ("train", "val", "test", "predict"):
            with pytest.raises(MisconfigurationException, match=msg.format(prefix)):
                getattr(module, f"{prefix}_dataloader")()
|