torchgeo/tests/datamodules/test_geo.py

197 lines
7.7 KiB
Python

DataModules: run all data augmentation on the GPU (#992) * DataModules: run all data augmentation on the GPU * Passing tests * Update BigEarthNet * Break ChesapeakeCVPR * Update COWC * Update Cyclone * Update ETCI2021 * mypy fixes * Update FAIR1M * Update Inria * Update LandCoverAI * Update LoveDA * Update NAIP * Update NASA * Update OSCD * Update RESISC45 * Update SEN12MS * Update So2Sat * Update SpaceNet * Update UCMerced * Update USAVars * Update xview * Remove seed * mypy fixes * OSCD hacks * Add NonGeoDataModule base class * Fixes * Add base class to docs * mypy fixes * Fix several tests * Fix Normalize * Syntax error * Fix bigearthnet * Fix dtype * Consistent kornia import * Get regression datasets working * Fix detection tests * Fix some chesapeake bugs * Fix several segmentation issues * isort fixes * Undo breaking change * Remove more code duplication, standardize docstrings * mypy fixes * Add default augmentation * Augmentations can be any callable * Fix datasets tests * Fix datamodule tests * Fix more datamodules * Typo fix * Set up val_dataset even when fit * Fix classification tests * Fix ETCI2021 * Fix SEN12MS * Add GeoDataModule base class * Fix several chesapeake bugs * Fix dtype and shape * Fix crs/bbox issue * Fix test dtype * Fix unequal size stacking error * flake8 fix * Better checks on sampler * Fix bug introduced in NAIP dm * Fix chesapeake dimensions * Add one to mask * Fix missing imports * Fix batch size * Simplify augmentations * Don't run test or predict without datasets * Fix tests * Allow shared dataset * One more try * Fix typo * Fix another typo * Fix Chesapeake dimensions * Apply augmentations during sanity check too * Don't reuse fixtures * Increase coverage * Fix ETCI tests * Test predict_step * Test all loss methods * Simplify validation plotting * Document new classes * Fix plotting * Plotting should be robust in case dataset does not contain RGB bands * Fix flake8 * 100% coverage of trainers * Add lightning-lite dependency * 
Revert "Add lightning-lite dependency" This reverts commit 1df7291ae59f6257a2cabd20a6c767e178bf4f0f. * Define our own MisconfigurationException * Properly test new data module base classes * Fix mistake in setup call * ExtractTensorPatches runs into OOM errors * Test both fast_dev_run True and False * Fix plot methods * Fix OSCD tests * Fix bug with inconsistent train/val/test splits between stages * Fix issues with images of different sizes * Fix OSCD tests * Fix OSCD tests * Bad rebase * No trainer for OSCD so no need for config * Bad rebase * plot: only works during validation * Fix collation of NASA Marine Debris dataset * flake8 fix * Quick test * Revert "Quick test" This reverts commit f465efcbef904b8a5bc2257f2800eed931c491ab. * 56 workers is a bit excessive Co-authored-by: Caleb Robinson <calebrob6@gmail.com>
2023-01-24 01:08:17 +03:00
# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT License.
from typing import Any, Dict
import matplotlib.pyplot as plt
import pytest
import torch
from _pytest.fixtures import SubRequest
from pytorch_lightning import Trainer
from rasterio.crs import CRS
from torch import Tensor
from torchgeo.datamodules import (
    GeoDataModule,
    MisconfigurationException,
    NonGeoDataModule,
)
from torchgeo.datasets import BoundingBox, GeoDataset, NonGeoDataset
from torchgeo.samplers import RandomBatchGeoSampler, RandomGeoSampler


class CustomGeoDataset(GeoDataset):
    def __init__(self, split: str = "train", download: bool = False) -> None:
        super().__init__()
        self.index.insert(0, (0, 1, 2, 3, 4, 5))
        self.res = 1

    def __getitem__(self, query: BoundingBox) -> Dict[str, Any]:
        image = torch.arange(3 * 2 * 2).view(3, 2, 2)
        return {"image": image, "crs": CRS.from_epsg(4326), "bbox": query}

    def plot(self, *args: Any, **kwargs: Any) -> plt.Figure:
        return plt.figure()
class CustomGeoDataModule(GeoDataModule):
    def __init__(self) -> None:
        super().__init__(CustomGeoDataset, 1, 1, 1, 0, download=True)


class SamplerGeoDataModule(CustomGeoDataModule):
    def setup(self, stage: str) -> None:
        self.dataset = CustomGeoDataset()
        self.train_sampler = RandomGeoSampler(self.dataset, 1, 1)
        self.val_sampler = RandomGeoSampler(self.dataset, 1, 1)
        self.test_sampler = RandomGeoSampler(self.dataset, 1, 1)
        self.predict_sampler = RandomGeoSampler(self.dataset, 1, 1)


class BatchSamplerGeoDataModule(CustomGeoDataModule):
    def setup(self, stage: str) -> None:
        self.dataset = CustomGeoDataset()
        self.train_batch_sampler = RandomBatchGeoSampler(self.dataset, 1, 1, 1)
        self.val_batch_sampler = RandomBatchGeoSampler(self.dataset, 1, 1, 1)
        self.test_batch_sampler = RandomBatchGeoSampler(self.dataset, 1, 1, 1)
        self.predict_batch_sampler = RandomBatchGeoSampler(self.dataset, 1, 1, 1)


class CustomNonGeoDataset(NonGeoDataset):
    def __init__(self, split: str = "train", download: bool = False) -> None:
        pass

    def __getitem__(self, index: int) -> Dict[str, Tensor]:
        return {"image": torch.arange(3 * 2 * 2).view(3, 2, 2)}

    def __len__(self) -> int:
        return 1

    def plot(self, *args: Any, **kwargs: Any) -> plt.Figure:
        return plt.figure()
class CustomNonGeoDataModule(NonGeoDataModule):
    def __init__(self) -> None:
        super().__init__(CustomNonGeoDataset, 1, 0, download=True)

    def setup(self, stage: str) -> None:
        super().setup(stage)
        if stage in ["predict"]:
            self.predict_dataset = CustomNonGeoDataset()
class TestGeoDataModule:
    @pytest.fixture(params=[SamplerGeoDataModule, BatchSamplerGeoDataModule])
    def datamodule(self, request: SubRequest) -> CustomGeoDataModule:
        dm: CustomGeoDataModule = request.param()
        dm.trainer = Trainer(max_epochs=1)
        return dm
    @pytest.mark.parametrize("stage", ["fit", "validate", "test"])
    def test_setup(self, stage: str) -> None:
        dm = CustomGeoDataModule()
        dm.prepare_data()
        dm.setup(stage)

    def test_train(self, datamodule: CustomGeoDataModule) -> None:
        datamodule.setup("fit")
        datamodule.trainer.training = True  # type: ignore[union-attr]
        batch = next(iter(datamodule.train_dataloader()))
        batch = datamodule.transfer_batch_to_device(batch, torch.device("cpu"), 1)
        batch = datamodule.on_after_batch_transfer(batch, 0)

    def test_val(self, datamodule: CustomGeoDataModule) -> None:
        datamodule.setup("validate")
        datamodule.trainer.validating = True  # type: ignore[union-attr]
        batch = next(iter(datamodule.val_dataloader()))
        batch = datamodule.transfer_batch_to_device(batch, torch.device("cpu"), 1)
        batch = datamodule.on_after_batch_transfer(batch, 0)

    def test_test(self, datamodule: CustomGeoDataModule) -> None:
        datamodule.setup("test")
        datamodule.trainer.testing = True  # type: ignore[union-attr]
        batch = next(iter(datamodule.test_dataloader()))
        batch = datamodule.transfer_batch_to_device(batch, torch.device("cpu"), 1)
        batch = datamodule.on_after_batch_transfer(batch, 0)

    def test_predict(self, datamodule: CustomGeoDataModule) -> None:
        datamodule.setup("predict")
        datamodule.trainer.predicting = True  # type: ignore[union-attr]
        batch = next(iter(datamodule.predict_dataloader()))
        batch = datamodule.transfer_batch_to_device(batch, torch.device("cpu"), 1)
        batch = datamodule.on_after_batch_transfer(batch, 0)

    def test_plot(self, datamodule: CustomGeoDataModule) -> None:
        datamodule.setup("validate")
        datamodule.plot()
        plt.close()

    def test_no_datasets(self) -> None:
        dm = CustomGeoDataModule()
        msg = "CustomGeoDataModule.setup does not define a '{}_dataset'"
        with pytest.raises(MisconfigurationException, match=msg.format("train")):
            dm.train_dataloader()
        with pytest.raises(MisconfigurationException, match=msg.format("val")):
            dm.val_dataloader()
        with pytest.raises(MisconfigurationException, match=msg.format("test")):
            dm.test_dataloader()
        with pytest.raises(MisconfigurationException, match=msg.format("predict")):
            dm.predict_dataloader()


class TestNonGeoDataModule:
    @pytest.fixture
    def datamodule(self) -> CustomNonGeoDataModule:
        dm = CustomNonGeoDataModule()
        dm.trainer = Trainer(max_epochs=1)
        return dm

    @pytest.mark.parametrize("stage", ["fit", "validate", "test", "predict"])
    def test_setup(self, stage: str) -> None:
        dm = CustomNonGeoDataModule()
        dm.prepare_data()
        dm.setup(stage)

    def test_train(self, datamodule: CustomNonGeoDataModule) -> None:
        datamodule.setup("fit")
        datamodule.trainer.training = True  # type: ignore[union-attr]
        batch = next(iter(datamodule.train_dataloader()))
        batch = datamodule.on_after_batch_transfer(batch, 0)

    def test_val(self, datamodule: CustomNonGeoDataModule) -> None:
        datamodule.setup("validate")
        datamodule.trainer.validating = True  # type: ignore[union-attr]
        batch = next(iter(datamodule.val_dataloader()))
        batch = datamodule.on_after_batch_transfer(batch, 0)

    def test_test(self, datamodule: CustomNonGeoDataModule) -> None:
        datamodule.setup("test")
        datamodule.trainer.testing = True  # type: ignore[union-attr]
        batch = next(iter(datamodule.test_dataloader()))
        batch = datamodule.on_after_batch_transfer(batch, 0)

    def test_predict(self, datamodule: CustomNonGeoDataModule) -> None:
        datamodule.setup("predict")
        datamodule.trainer.predicting = True  # type: ignore[union-attr]
        batch = next(iter(datamodule.predict_dataloader()))
        batch = datamodule.on_after_batch_transfer(batch, 0)

    def test_plot(self, datamodule: CustomNonGeoDataModule) -> None:
        datamodule.setup("validate")
        datamodule.plot()
        plt.close()

    def test_no_datasets(self) -> None:
        dm = CustomNonGeoDataModule()
        msg = "CustomNonGeoDataModule.setup does not define a '{}_dataset'"
        with pytest.raises(MisconfigurationException, match=msg.format("train")):
            dm.train_dataloader()
        with pytest.raises(MisconfigurationException, match=msg.format("val")):
            dm.val_dataloader()
        with pytest.raises(MisconfigurationException, match=msg.format("test")):
            dm.test_dataloader()
        with pytest.raises(MisconfigurationException, match=msg.format("predict")):
            dm.predict_dataloader()