sentinel2nccm datamodule on new branch (#1950)

* sentinel2nccm datamodule

* Fixed style errors

* added 2019 to sentinel2, removed 2022 from nccm

* fixed error

* Use matching split size

---------

Co-authored-by: shreya28 <shreya28@illinois.edu>
Co-authored-by: Adam J. Stewart <ajstewart426@gmail.com>
This commit is contained in:
shreyakannan1205 2024-03-22 17:05:01 -05:00 коммит произвёл GitHub
Родитель 5a7b9e58bc
Коммит bd48efe988
Не найден ключ, соответствующий данной подписи
Идентификатор ключа GPG: B5690EEEBB952194
98 изменённых файлов: 198 добавлений и 12 удалений

Просмотреть файл

@ -30,6 +30,7 @@ Sentinel
^^^^^^^^
.. autoclass:: Sentinel2CDLDataModule
.. autoclass:: Sentinel2NCCMDataModule
Non-geospatial DataModules
--------------------------

Просмотреть файл

@ -0,0 +1,18 @@
# Configuration pairing SemanticSegmentationTask with Sentinel2NCCMDataModule.
model:
class_path: SemanticSegmentationTask
init_args:
loss: "ce"
model: "unet"
backbone: "resnet18"
in_channels: 13
num_classes: 5
num_filters: 1
ignore_index: 4
data:
class_path: Sentinel2NCCMDataModule
init_args:
batch_size: 2
patch_size: 16
# The "nccm_" / "sentinel2_" key prefixes are stripped by the datamodule, which
# forwards the remainder to the NCCM and Sentinel2 datasets respectively.
dict_kwargs:
nccm_paths: "tests/data/nccm"
sentinel2_paths: "tests/data/sentinel2"

Двоичные данные
tests/data/nccm/CDL2017_clip.tif

Двоичный файл не отображается.

Двоичные данные
tests/data/nccm/CDL2018_clip1.tif

Двоичный файл не отображается.

Двоичные данные
tests/data/nccm/CDL2019_clip.tif

Двоичный файл не отображается.

Просмотреть файл

@ -11,7 +11,7 @@ import rasterio
from rasterio.crs import CRS
from rasterio.transform import Affine
SIZE = 32
SIZE = 128
np.random.seed(0)
files = ["CDL2017_clip.tif", "CDL2018_clip1.tif", "CDL2019_clip.tif"]
@ -23,15 +23,8 @@ def create_file(path: str, dtype: str):
"driver": "GTiff",
"dtype": dtype,
"count": 1,
"crs": CRS.from_epsg(4326),
"transform": Affine(
8.983152841195208e-05,
0.0,
115.483402043364,
0.0,
-8.983152841195208e-05,
53.531397320113605,
),
"crs": CRS.from_epsg(32616),
"transform": Affine(10, 0.0, 399960.0, 0.0, -10, 4500000.0),
"height": SIZE,
"width": SIZE,
"compress": "lzw",

Просмотреть файл

@ -37,6 +37,20 @@ filenames: FILENAME_HIERARCHY = {
"T16TFM_20220412T162841_B12.jp2",
"T16TFM_20220412T162841_B8A.jp2",
"T16TFM_20220412T162841_TCI.jp2",
"T16TFM_20190412T162841_B01.jp2",
"T16TFM_20190412T162841_B02.jp2",
"T16TFM_20190412T162841_B03.jp2",
"T16TFM_20190412T162841_B04.jp2",
"T16TFM_20190412T162841_B05.jp2",
"T16TFM_20190412T162841_B06.jp2",
"T16TFM_20190412T162841_B07.jp2",
"T16TFM_20190412T162841_B08.jp2",
"T16TFM_20190412T162841_B09.jp2",
"T16TFM_20190412T162841_B10.jp2",
"T16TFM_20190412T162841_B11.jp2",
"T16TFM_20190412T162841_B12.jp2",
"T16TFM_20190412T162841_B8A.jp2",
"T16TFM_20190412T162841_TCI.jp2",
]
}
}
@ -54,6 +68,13 @@ filenames: FILENAME_HIERARCHY = {
"T26EMU_20220414T110751_B08_10m.jp2",
"T26EMU_20220414T110751_TCI_10m.jp2",
"T26EMU_20220414T110751_WVP_10m.jp2",
"T26EMU_20190414T110751_AOT_10m.jp2",
"T26EMU_20190414T110751_B02_10m.jp2",
"T26EMU_20190414T110751_B03_10m.jp2",
"T26EMU_20190414T110751_B04_10m.jp2",
"T26EMU_20190414T110751_B08_10m.jp2",
"T26EMU_20190414T110751_TCI_10m.jp2",
"T26EMU_20190414T110751_WVP_10m.jp2",
],
"R20m": [
"T26EMU_20220414T110751_AOT_20m.jp2",
@ -70,6 +91,20 @@ filenames: FILENAME_HIERARCHY = {
"T26EMU_20220414T110751_SCL_20m.jp2",
"T26EMU_20220414T110751_TCI_20m.jp2",
"T26EMU_20220414T110751_WVP_20m.jp2",
"T26EMU_20190414T110751_AOT_20m.jp2",
"T26EMU_20190414T110751_B01_20m.jp2",
"T26EMU_20190414T110751_B02_20m.jp2",
"T26EMU_20190414T110751_B03_20m.jp2",
"T26EMU_20190414T110751_B04_20m.jp2",
"T26EMU_20190414T110751_B05_20m.jp2",
"T26EMU_20190414T110751_B06_20m.jp2",
"T26EMU_20190414T110751_B07_20m.jp2",
"T26EMU_20190414T110751_B11_20m.jp2",
"T26EMU_20190414T110751_B12_20m.jp2",
"T26EMU_20190414T110751_B8A_20m.jp2",
"T26EMU_20190414T110751_SCL_20m.jp2",
"T26EMU_20190414T110751_TCI_20m.jp2",
"T26EMU_20190414T110751_WVP_20m.jp2",
],
"R60m": [
"T26EMU_20220414T110751_AOT_60m.jp2",
@ -87,6 +122,21 @@ filenames: FILENAME_HIERARCHY = {
"T26EMU_20220414T110751_SCL_60m.jp2",
"T26EMU_20220414T110751_TCI_60m.jp2",
"T26EMU_20220414T110751_WVP_60m.jp2",
"T26EMU_20190414T110751_AOT_60m.jp2",
"T26EMU_20190414T110751_B01_60m.jp2",
"T26EMU_20190414T110751_B02_60m.jp2",
"T26EMU_20190414T110751_B03_60m.jp2",
"T26EMU_20190414T110751_B04_60m.jp2",
"T26EMU_20190414T110751_B05_60m.jp2",
"T26EMU_20190414T110751_B06_60m.jp2",
"T26EMU_20190414T110751_B07_60m.jp2",
"T26EMU_20190414T110751_B09_60m.jp2",
"T26EMU_20190414T110751_B11_60m.jp2",
"T26EMU_20190414T110751_B12_60m.jp2",
"T26EMU_20190414T110751_B8A_60m.jp2",
"T26EMU_20190414T110751_SCL_60m.jp2",
"T26EMU_20190414T110751_TCI_60m.jp2",
"T26EMU_20190414T110751_WVP_60m.jp2",
],
}
}

Просмотреть файл

@ -108,7 +108,7 @@ class TestSentinel2:
return Sentinel2(root, res=res, bands=bands, transforms=transforms)
def test_separate_files(self, dataset: Sentinel2) -> None:
assert dataset.index.count(dataset.index.bounds) == 2
assert dataset.index.count(dataset.index.bounds) == 4
def test_getitem(self, dataset: Sentinel2) -> None:
x = dataset[dataset.bounds]

Просмотреть файл

@ -73,6 +73,7 @@ class TestSemanticSegmentationTask:
"sen12ms_s2_all",
"sen12ms_s2_reduced",
"sentinel2_cdl",
"sentinel2_nccm",
"spacenet1",
"ssl4eo_l_benchmark_cdl",
"ssl4eo_l_benchmark_nlcd",

Просмотреть файл

@ -29,6 +29,7 @@ from .resisc45 import RESISC45DataModule
from .seco import SeasonalContrastS2DataModule
from .sen12ms import SEN12MSDataModule
from .sentinel2_cdl import Sentinel2CDLDataModule
from .sentinel2_nccm import Sentinel2NCCMDataModule
from .skippd import SKIPPDDataModule
from .so2sat import So2SatDataModule
from .spacenet import SpaceNet1DataModule
@ -49,6 +50,7 @@ __all__ = (
"L8BiomeDataModule",
"NAIPChesapeakeDataModule",
"Sentinel2CDLDataModule",
"Sentinel2NCCMDataModule",
# NonGeoDataset
"BigEarthNetDataModule",
"ChaBuDDataModule",

Просмотреть файл

@ -92,7 +92,7 @@ class Sentinel2CDLDataModule(GeoDataModule):
(self.train_dataset, self.val_dataset, self.test_dataset) = (
random_grid_cell_assignment(
self.dataset, [0.8, 0.10, 0.10], grid_size=8, generator=generator
self.dataset, [0.8, 0.1, 0.1], grid_size=8, generator=generator
)
)
if stage in ["fit"]:

Просмотреть файл

@ -0,0 +1,121 @@
# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT License.
"""Sentinel-2 and NCCM datamodule."""
from typing import Any, Optional, Union
import kornia.augmentation as K
import torch
from kornia.constants import DataKey, Resample
from matplotlib.figure import Figure
from ..datasets import NCCM, Sentinel2, random_grid_cell_assignment
from ..samplers import GridGeoSampler, RandomBatchGeoSampler
from ..samplers.utils import _to_tuple
from ..transforms import AugmentationSequential
from .geo import GeoDataModule
class Sentinel2NCCMDataModule(GeoDataModule):
    """LightningDataModule implementation for the Sentinel-2 and NCCM dataset.

    The underlying dataset is built in :meth:`setup` by combining a
    :class:`~torchgeo.datasets.Sentinel2` instance and a
    :class:`~torchgeo.datasets.NCCM` instance with the ``&`` operator.

    .. versionadded:: 0.6
    """

    def __init__(
        self,
        batch_size: int = 64,
        patch_size: Union[int, tuple[int, int]] = 64,
        length: Optional[int] = None,
        num_workers: int = 0,
        **kwargs: Any,
    ) -> None:
        """Initialize a new Sentinel2NCCMDataModule instance.

        Keyword arguments are routed by prefix: the prefix is stripped and the
        remainder of the key is stored for the matching dataset. Keys carrying
        neither prefix are not stored.

        Args:
            batch_size: Size of each mini-batch.
            patch_size: Size of each patch, either ``size`` or ``(height, width)``.
            length: Length of each training epoch.
            num_workers: Number of workers for parallel data loading.
            **kwargs: Additional keyword arguments passed to
                :class:`~torchgeo.datasets.NCCM` (prefix keys with ``nccm_``) and
                :class:`~torchgeo.datasets.Sentinel2`
                (prefix keys with ``sentinel2_``).
        """
        nccm_prefix = "nccm_"
        sentinel2_prefix = "sentinel2_"

        # The two prefixes cannot both match one key, so each dictionary can be
        # built with its own independent pass over the keyword arguments.
        self.nccm_kwargs = {
            name[len(nccm_prefix) :]: value
            for name, value in kwargs.items()
            if name.startswith(nccm_prefix)
        }
        self.sentinel2_kwargs = {
            name[len(sentinel2_prefix) :]: value
            for name, value in kwargs.items()
            if name.startswith(sentinel2_prefix)
        }

        super().__init__(
            NCCM, batch_size, patch_size, length, num_workers, **self.nccm_kwargs
        )

        # Masks are resampled with nearest-neighbour interpolation.
        mask_options = {"resample": Resample.NEAREST, "align_corners": None}
        keys = ["image", "mask"]

        # Training pipeline: normalisation, random resized crop, random flips.
        self.train_aug = AugmentationSequential(
            K.Normalize(mean=self.mean, std=self.std),
            K.RandomResizedCrop(_to_tuple(self.patch_size), scale=(0.6, 1.0)),
            K.RandomVerticalFlip(p=0.5),
            K.RandomHorizontalFlip(p=0.5),
            data_keys=keys,
            extra_args={DataKey.MASK: mask_options},
        )

        # Default pipeline: normalisation only.
        self.aug = AugmentationSequential(
            K.Normalize(mean=self.mean, std=self.std), data_keys=["image", "mask"]
        )

    def setup(self, stage: str) -> None:
        """Set up datasets and samplers.

        The datasets and the train/val/test split are rebuilt on every call;
        only the samplers created depend on ``stage``.

        Args:
            stage: Either 'fit', 'validate', 'test', or 'predict'.
        """
        self.sentinel2 = Sentinel2(**self.sentinel2_kwargs)
        self.nccm = NCCM(**self.nccm_kwargs)
        self.dataset = self.sentinel2 & self.nccm

        # A fixed seed is used for the 80/10/10 grid-cell assignment.
        generator = torch.Generator().manual_seed(0)
        splits = random_grid_cell_assignment(
            self.dataset, [0.8, 0.1, 0.1], grid_size=8, generator=generator
        )
        self.train_dataset, self.val_dataset, self.test_dataset = splits

        if stage == "fit":
            self.train_batch_sampler = RandomBatchGeoSampler(
                self.train_dataset, self.patch_size, self.batch_size, self.length
            )

        if stage in ("fit", "validate"):
            # patch_size is passed twice: once as the size, once as the stride.
            self.val_sampler = GridGeoSampler(
                self.val_dataset, self.patch_size, self.patch_size
            )

        if stage == "test":
            self.test_sampler = GridGeoSampler(
                self.test_dataset, self.patch_size, self.patch_size
            )

    def plot(self, *args: Any, **kwargs: Any) -> Figure:
        """Run NCCM plot method.

        Delegates to ``self.nccm``, which is created in :meth:`setup`.

        Args:
            *args: Arguments passed to plot method.
            **kwargs: Keyword arguments passed to plot method.

        Returns:
            A matplotlib Figure with the image, ground truth, and predictions.
        """
        nccm_dataset = self.nccm
        return nccm_dataset.plot(*args, **kwargs)