torchgeo/tests/datamodules/test_geo.py


268 lines
10 KiB
Python

DataModules: run all data augmentation on the GPU (#992) * DataModules: run all data augmentation on the GPU * Passing tests * Update BigEarthNet * Break ChesapeakeCVPR * Update COWC * Update Cyclone * Update ETCI2021 * mypy fixes * Update FAIR1M * Update Inria * Update LandCoverAI * Update LoveDA * Update NAIP * Update NASA * Update OSCD * Update RESISC45 * Update SEN12MS * Update So2Sat * Update SpaceNet * Update UCMerced * Update USAVars * Update xview * Remove seed * mypy fixes * OSCD hacks * Add NonGeoDataModule base class * Fixes * Add base class to docs * mypy fixes * Fix several tests * Fix Normalize * Syntax error * Fix bigearthnet * Fix dtype * Consistent kornia import * Get regression datasets working * Fix detection tests * Fix some chesapeake bugs * Fix several segmentation issues * isort fixes * Undo breaking change * Remove more code duplication, standardize docstrings * mypy fixes * Add default augmentation * Augmentations can be any callable * Fix datasets tests * Fix datamodule tests * Fix more datamodules * Typo fix * Set up val_dataset even when fit * Fix classification tests * Fix ETCI2021 * Fix SEN12MS * Add GeoDataModule base class * Fix several chesapeake bugs * Fix dtype and shape * Fix crs/bbox issue * Fix test dtype * Fix unequal size stacking error * flake8 fix * Better checks on sampler * Fix bug introduced in NAIP dm * Fix chesapeake dimensions * Add one to mask * Fix missing imports * Fix batch size * Simplify augmentations * Don't run test or predict without datasets * Fix tests * Allow shared dataset * One more try * Fix typo * Fix another typo * Fix Chesapeake dimensions * Apply augmentations during sanity check too * Don't reuse fixtures * Increase coverage * Fix ETCI tests * Test predict_step * Test all loss methods * Simplify validation plotting * Document new classes * Fix plotting * Plotting should be robust in case dataset does not contain RGB bands * Fix flake8 * 100% coverage of trainers * Add lightning-lite dependency * 
Revert "Add lightning-lite dependency" This reverts commit 1df7291ae59f6257a2cabd20a6c767e178bf4f0f. * Define our own MisconfigurationException * Properly test new data module base classes * Fix mistake in setup call * ExtractTensorPatches runs into OOM errors * Test both fast_dev_run True and False * Fix plot methods * Fix OSCD tests * Fix bug with inconsistent train/val/test splits between stages * Fix issues with images of different sizes * Fix OSCD tests * Fix OSCD tests * Bad rebase * No trainer for OSCD so no need for config * Bad rebase * plot: only works during validation * Fix collation of NASA Marine Debris dataset * flake8 fix * Quick test * Revert "Quick test" This reverts commit f465efcbef904b8a5bc2257f2800eed931c491ab. * 56 workers is a bit excessive Co-authored-by: Caleb Robinson <calebrob6@gmail.com>
2023-01-24 01:08:17 +03:00
# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT License.

from typing import Any

import matplotlib.pyplot as plt
import pytest
import torch
from _pytest.fixtures import SubRequest
from lightning.pytorch import Trainer
from matplotlib.figure import Figure
from rasterio.crs import CRS
from torch import Tensor

from torchgeo.datamodules import (
    GeoDataModule,
    MisconfigurationException,
    NonGeoDataModule,
)
from torchgeo.datasets import BoundingBox, GeoDataset, NonGeoDataset
from torchgeo.samplers import RandomBatchGeoSampler, RandomGeoSampler


class CustomGeoDataset(GeoDataset):
    def __init__(
        self, split: str = 'train', length: int = 1, download: bool = False
    ) -> None:
        super().__init__()
        for i in range(length):
            self.index.insert(i, (0, 1, 2, 3, 4, 5))
        self.res = 1

    def __getitem__(self, query: BoundingBox) -> dict[str, Any]:
        image = torch.arange(3 * 2 * 2).view(3, 2, 2)
        return {'image': image, 'crs': CRS.from_epsg(4326), 'bounds': query}

    def plot(self, *args: Any, **kwargs: Any) -> Figure:
        return plt.figure()
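

# Illustrative sketch, not part of torchgeo: assuming GeoDataset builds its
# R-tree with ``interleaved=False`` and three dimensions, each tuple passed to
# ``self.index.insert`` in CustomGeoDataset is ordered
# (minx, maxx, miny, maxy, mint, maxt), the same field order as BoundingBox.
# The hypothetical helper below names those coordinates, so the fixture entry
# (0, 1, 2, 3, 4, 5) reads as x in [0, 1], y in [2, 3], t in [4, 5].
def _name_index_bounds(coords: tuple[float, ...]) -> dict[str, float]:
    """Map a 6-tuple of R-tree coordinates onto named bounds (hypothetical)."""
    names = ('minx', 'maxx', 'miny', 'maxy', 'mint', 'maxt')
    if len(coords) != len(names):
        raise ValueError(f'expected {len(names)} coordinates, got {len(coords)}')
    return dict(zip(names, coords))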


class CustomGeoDataModule(GeoDataModule):
    def __init__(self) -> None:
        super().__init__(CustomGeoDataset, 1, 1, 1, 0, download=True)


class SamplerGeoDataModule(CustomGeoDataModule):
    def setup(self, stage: str) -> None:
        self.dataset = CustomGeoDataset()
        self.train_sampler = RandomGeoSampler(self.dataset, 1, 1)
        self.val_sampler = RandomGeoSampler(self.dataset, 1, 1)
        self.test_sampler = RandomGeoSampler(self.dataset, 1, 1)
        self.predict_sampler = RandomGeoSampler(self.dataset, 1, 1)
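

# Illustrative sketch, not torchgeo's implementation: assuming
# RandomGeoSampler(dataset, size, length) draws ``length`` random patches of
# ``size`` x ``size`` pixels, a patch spans ``size * res`` map units and its
# lower-left corner must lie where the whole patch still fits in the bounds.
# With the CustomGeoDataset fixture (x in [0, 1], y in [2, 3], res = 1) and
# size = 1, the only patch that fits is the full extent. This simplified,
# hypothetical helper takes the random draws ``u`` and ``v`` in [0, 1] as
# arguments so that it stays deterministic.
def _random_patch(
    bounds: tuple[float, float, float, float],
    size: int,
    res: float,
    u: float,
    v: float,
) -> tuple[float, float, float, float]:
    """Return (minx, maxx, miny, maxy) of a square pixel patch (hypothetical)."""
    minx, maxx, miny, maxy = bounds
    extent = size * res  # patch side length in map units
    slack_x = (maxx - minx) - extent  # room left to shift the patch along x
    slack_y = (maxy - miny) - extent  # room left to shift the patch along y
    if slack_x < 0 or slack_y < 0:
        raise ValueError('patch does not fit inside the bounds')
    x0 = minx + u * slack_x
    y0 = miny + v * slack_y
    return (x0, x0 + extent, y0, y0 + extent)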


class BatchSamplerGeoDataModule(CustomGeoDataModule):
    def setup(self, stage: str) -> None:
        self.dataset = CustomGeoDataset()
        self.train_batch_sampler = RandomBatchGeoSampler(self.dataset, 1, 1, 1)
        self.val_batch_sampler = RandomBatchGeoSampler(self.dataset, 1, 1, 1)
        self.test_batch_sampler = RandomBatchGeoSampler(self.dataset, 1, 1, 1)
        self.predict_batch_sampler = RandomBatchGeoSampler(self.dataset, 1, 1, 1)
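

# Illustrative sketch, not torchgeo's implementation: a batch sampler such as
# RandomBatchGeoSampler(dataset, size, batch_size, length) is assumed to yield
# whole batches (lists of query boxes) rather than one query at a time, so a
# PyTorch DataLoader takes it as ``batch_sampler`` instead of combining a
# ``sampler`` with ``batch_size``. The hypothetical helper below shows that
# grouping on plain items.
def _group_into_batches(items: list[int], batch_size: int) -> list[list[int]]:
    """Split ``items`` into consecutive batches of at most ``batch_size``."""
    if batch_size < 1:
        raise ValueError('batch_size must be positive')
    return [items[i : i + batch_size] for i in range(0, len(items), batch_size)]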


class CustomNonGeoDataset(NonGeoDataset):
    def __init__(
        self, split: str = 'train', length: int = 1, download: bool = False
    ) -> None:
        self.length = length

    def __getitem__(self, index: int) -> dict[str, Tensor]:
        return {'image': torch.arange(3 * 2 * 2).view(3, 2, 2)}

    def __len__(self) -> int:
        return self.length

    def plot(self, *args: Any, **kwargs: Any) -> Figure:
        return plt.figure()


class CustomNonGeoDataModule(NonGeoDataModule):
    def __init__(self) -> None:
        super().__init__(CustomNonGeoDataset, 1, 0, download=True)

    def setup(self, stage: str) -> None:
        super().setup(stage)

        if stage in ['predict']:
            self.predict_dataset = CustomNonGeoDataset()


class TestGeoDataModule:
    @pytest.fixture(params=[SamplerGeoDataModule, BatchSamplerGeoDataModule])
    def datamodule(self, request: SubRequest) -> CustomGeoDataModule:
        dm: CustomGeoDataModule = request.param()
        dm.trainer = Trainer(accelerator='cpu', max_epochs=1)
        return dm

    @pytest.mark.parametrize('stage', ['fit', 'validate', 'test'])
    def test_setup(self, stage: str) -> None:
        dm = CustomGeoDataModule()
        dm.prepare_data()
        dm.setup(stage)

    def test_train(self, datamodule: CustomGeoDataModule) -> None:
        datamodule.setup('fit')
        # Flag the trainer stage so that on_after_batch_transfer applies
        # the matching (training) augmentation.
        if datamodule.trainer:
            datamodule.trainer.training = True
        batch = next(iter(datamodule.train_dataloader()))
        batch = datamodule.transfer_batch_to_device(batch, torch.device('cpu'), 1)
        batch = datamodule.on_after_batch_transfer(batch, 0)

    def test_val(self, datamodule: CustomGeoDataModule) -> None:
        datamodule.setup('validate')
        if datamodule.trainer:
            datamodule.trainer.validating = True
        batch = next(iter(datamodule.val_dataloader()))
        batch = datamodule.transfer_batch_to_device(batch, torch.device('cpu'), 1)
        batch = datamodule.on_after_batch_transfer(batch, 0)

    def test_test(self, datamodule: CustomGeoDataModule) -> None:
        datamodule.setup('test')
        if datamodule.trainer:
            datamodule.trainer.testing = True
        batch = next(iter(datamodule.test_dataloader()))
        batch = datamodule.transfer_batch_to_device(batch, torch.device('cpu'), 1)
        batch = datamodule.on_after_batch_transfer(batch, 0)

    def test_predict(self, datamodule: CustomGeoDataModule) -> None:
        datamodule.setup('predict')
        if datamodule.trainer:
            datamodule.trainer.predicting = True
        batch = next(iter(datamodule.predict_dataloader()))
        batch = datamodule.transfer_batch_to_device(batch, torch.device('cpu'), 1)
        batch = datamodule.on_after_batch_transfer(batch, 0)

    def test_plot(self, datamodule: CustomGeoDataModule) -> None:
        datamodule.setup('validate')
        datamodule.plot()
        plt.close()

    def test_no_datasets(self) -> None:
        dm = CustomGeoDataModule()
        msg = r'CustomGeoDataModule\.setup must define one of '
        msg += r"\('{0}_dataset', 'dataset'\)\."
        with pytest.raises(MisconfigurationException, match=msg.format('train')):
            dm.train_dataloader()
        with pytest.raises(MisconfigurationException, match=msg.format('val')):
            dm.val_dataloader()
        with pytest.raises(MisconfigurationException, match=msg.format('test')):
            dm.test_dataloader()
        with pytest.raises(MisconfigurationException, match=msg.format('predict')):
            dm.predict_dataloader()

    def test_no_samplers(self) -> None:
        dm = CustomGeoDataModule()
        dm.dataset = CustomGeoDataset()
        msg = r'CustomGeoDataModule\.setup must define one of '
        msg += r"\('{0}_batch_sampler', '{0}_sampler', 'batch_sampler', 'sampler'\)\."
        with pytest.raises(MisconfigurationException, match=msg.format('train')):
            dm.train_dataloader()
        with pytest.raises(MisconfigurationException, match=msg.format('val')):
            dm.val_dataloader()
        with pytest.raises(MisconfigurationException, match=msg.format('test')):
            dm.test_dataloader()
        with pytest.raises(MisconfigurationException, match=msg.format('predict')):
            dm.predict_dataloader()

    def test_zero_length_dataset(self) -> None:
        dm = CustomGeoDataModule()
        dm.dataset = CustomGeoDataset(length=0)
        msg = r'CustomGeoDataModule\.dataset has length 0.'
        with pytest.raises(MisconfigurationException, match=msg):
            dm.train_dataloader()
        with pytest.raises(MisconfigurationException, match=msg):
            dm.val_dataloader()
        with pytest.raises(MisconfigurationException, match=msg):
            dm.test_dataloader()
        with pytest.raises(MisconfigurationException, match=msg):
            dm.predict_dataloader()

    def test_zero_length_sampler(self) -> None:
        dm = CustomGeoDataModule()
        dm.dataset = CustomGeoDataset()
        dm.sampler = RandomGeoSampler(dm.dataset, 1, 1)
        dm.sampler.length = 0
        msg = r'CustomGeoDataModule\.sampler has length 0.'
        with pytest.raises(MisconfigurationException, match=msg):
            dm.train_dataloader()
        with pytest.raises(MisconfigurationException, match=msg):
            dm.val_dataloader()
        with pytest.raises(MisconfigurationException, match=msg):
            dm.test_dataloader()
        with pytest.raises(MisconfigurationException, match=msg):
            dm.predict_dataloader()


class TestNonGeoDataModule:
    @pytest.fixture
    def datamodule(self) -> CustomNonGeoDataModule:
        dm = CustomNonGeoDataModule()
        dm.trainer = Trainer(accelerator='cpu', max_epochs=1)
        return dm

    @pytest.mark.parametrize('stage', ['fit', 'validate', 'test', 'predict'])
    def test_setup(self, stage: str) -> None:
        dm = CustomNonGeoDataModule()
        dm.prepare_data()
        dm.setup(stage)

    def test_train(self, datamodule: CustomNonGeoDataModule) -> None:
        datamodule.setup('fit')
        if datamodule.trainer:
            datamodule.trainer.training = True
        batch = next(iter(datamodule.train_dataloader()))
        batch = datamodule.on_after_batch_transfer(batch, 0)

    def test_val(self, datamodule: CustomNonGeoDataModule) -> None:
        datamodule.setup('validate')
        if datamodule.trainer:
            datamodule.trainer.validating = True
        batch = next(iter(datamodule.val_dataloader()))
        batch = datamodule.on_after_batch_transfer(batch, 0)

    def test_test(self, datamodule: CustomNonGeoDataModule) -> None:
        datamodule.setup('test')
        if datamodule.trainer:
            datamodule.trainer.testing = True
        batch = next(iter(datamodule.test_dataloader()))
        batch = datamodule.on_after_batch_transfer(batch, 0)

    def test_predict(self, datamodule: CustomNonGeoDataModule) -> None:
        datamodule.setup('predict')
        if datamodule.trainer:
            datamodule.trainer.predicting = True
        batch = next(iter(datamodule.predict_dataloader()))
        batch = datamodule.on_after_batch_transfer(batch, 0)

    def test_plot(self, datamodule: CustomNonGeoDataModule) -> None:
        datamodule.setup('validate')
        datamodule.plot()
        plt.close()
    def test_no_datasets(self) -> None:
        dm = CustomNonGeoDataModule()
        msg = r'CustomNonGeoDataModule\.setup must define one of '
        msg += r"\('{0}_dataset', 'dataset'\)\."
        with pytest.raises(MisconfigurationException, match=msg.format('train')):
            dm.train_dataloader()
        with pytest.raises(MisconfigurationException, match=msg.format('val')):
            dm.val_dataloader()
        with pytest.raises(MisconfigurationException, match=msg.format('test')):
            dm.test_dataloader()
        with pytest.raises(MisconfigurationException, match=msg.format('predict')):
            dm.predict_dataloader()

    def test_zero_length_dataset(self) -> None:
        dm = CustomNonGeoDataModule()
        dm.dataset = CustomNonGeoDataset(length=0)
        # Escape the trailing period so the regex matches it literally.
        msg = r'CustomNonGeoDataModule\.dataset has length 0\.'
        with pytest.raises(MisconfigurationException, match=msg):
            dm.train_dataloader()
        with pytest.raises(MisconfigurationException, match=msg):
            dm.val_dataloader()
        with pytest.raises(MisconfigurationException, match=msg):
            dm.test_dataloader()
        with pytest.raises(MisconfigurationException, match=msg):
            dm.predict_dataloader()
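The two tests above assert on exact error messages. The following is a minimal, self-contained sketch of a dataset lookup that would produce messages matching those regexes. It is not torchgeo's actual implementation. It first checks a split-specific attribute such as `train_dataset`, then falls back to a shared `dataset`, and raises if neither is set or if the one it finds is empty. `SketchDataModule`, `_find_dataset`, and the local `MisconfigurationException` are hypothetical stand-ins chosen for illustration.

```python
import re
from typing import Any, Optional, Sequence


class MisconfigurationException(Exception):
    """Hypothetical stand-in for the exception raised on bad configuration."""


class SketchDataModule:
    """Hypothetical data module illustrating the lookup the tests exercise."""

    def __init__(self) -> None:
        self.dataset: Optional[Sequence[Any]] = None
        self.train_dataset: Optional[Sequence[Any]] = None

    def _find_dataset(self, split: str) -> Sequence[Any]:
        # Prefer the split-specific attribute, then fall back to the shared one.
        keys = (f'{split}_dataset', 'dataset')
        for key in keys:
            obj = getattr(self, key, None)
            if obj is None:
                continue  # attribute missing or unset: try the next candidate
            if len(obj) == 0:
                raise MisconfigurationException(
                    f'{type(self).__name__}.{key} has length 0.'
                )
            return obj
        # Formatting the tuple yields "('val_dataset', 'dataset')", which is
        # why the test regex escapes the parentheses and quotes the names.
        raise MisconfigurationException(
            f'{type(self).__name__}.setup must define one of {keys}.'
        )


# Example: an unconfigured module produces the message the tests match against.
dm = SketchDataModule()
try:
    dm._find_dataset('val')
except MisconfigurationException as err:
    pattern = r"SketchDataModule\.setup must define one of \('val_dataset', 'dataset'\)\."
    print(bool(re.search(pattern, str(err))))  # prints True
```

Because `pytest.raises(..., match=...)` applies `re.search` to the string form of the exception, a literal `.` or `(` in the expected message must be escaped, as the tests do.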