зеркало из https://github.com/microsoft/torchgeo.git
Add South America Soybean DataModule (#1959)
* Add South America Soybean DataModule * Add train_aug * Regenerate data --------- Co-authored-by: Adam J. Stewart <ajstewart426@gmail.com>
This commit is contained in:
Родитель
2849944a02
Коммит
5d253c55d8
|
@ -36,6 +36,7 @@ Sentinel
|
|||
|
||||
.. autoclass:: Sentinel2CDLDataModule
|
||||
.. autoclass:: Sentinel2NCCMDataModule
|
||||
.. autoclass:: Sentinel2SouthAmericaSoybeanDataModule
|
||||
|
||||
Non-geospatial DataModules
|
||||
--------------------------
|
||||
|
|
|
@ -0,0 +1,17 @@
|
|||
model:
|
||||
class_path: SemanticSegmentationTask
|
||||
init_args:
|
||||
loss: "ce"
|
||||
model: "deeplabv3+"
|
||||
backbone: "resnet18"
|
||||
in_channels: 13
|
||||
num_classes: 2
|
||||
num_filters: 1
|
||||
data:
|
||||
class_path: Sentinel2SouthAmericaSoybeanDataModule
|
||||
init_args:
|
||||
batch_size: 2
|
||||
patch_size: 16
|
||||
dict_kwargs:
|
||||
south_america_soybean_paths: "tests/data/south_america_soybean"
|
||||
sentinel2_paths: "tests/data/sentinel2"
|
Двоичные данные
tests/data/south_america_soybean/SouthAmericaSoybean.zip
Двоичные данные
tests/data/south_america_soybean/SouthAmericaSoybean.zip
Двоичный файл не отображается.
Двоичный файл не отображается.
Двоичный файл не отображается.
|
@ -11,7 +11,7 @@ import rasterio
|
|||
from rasterio.crs import CRS
|
||||
from rasterio.transform import Affine
|
||||
|
||||
SIZE = 32
|
||||
SIZE = 128
|
||||
|
||||
|
||||
np.random.seed(0)
|
||||
|
@ -24,15 +24,8 @@ def create_file(path: str, dtype: str):
|
|||
"driver": "GTiff",
|
||||
"dtype": dtype,
|
||||
"count": 1,
|
||||
"crs": CRS.from_epsg(4326),
|
||||
"transform": Affine(
|
||||
0.0002499999999999943131,
|
||||
0.0,
|
||||
-82.0005000000000024,
|
||||
0.0,
|
||||
-0.0002499999999999943131,
|
||||
0.0005000000000000,
|
||||
),
|
||||
"crs": CRS.from_epsg(32616),
|
||||
"transform": Affine(10, 0.0, 399960.0, 0.0, -10, 4500000.0),
|
||||
"height": SIZE,
|
||||
"width": SIZE,
|
||||
"compress": "lzw",
|
||||
|
|
|
@ -75,6 +75,7 @@ class TestSemanticSegmentationTask:
|
|||
"sen12ms_s2_reduced",
|
||||
"sentinel2_cdl",
|
||||
"sentinel2_nccm",
|
||||
"sentinel2_south_america_soybean",
|
||||
"spacenet1",
|
||||
"ssl4eo_l_benchmark_cdl",
|
||||
"ssl4eo_l_benchmark_nlcd",
|
||||
|
|
|
@ -31,6 +31,7 @@ from .seco import SeasonalContrastS2DataModule
|
|||
from .sen12ms import SEN12MSDataModule
|
||||
from .sentinel2_cdl import Sentinel2CDLDataModule
|
||||
from .sentinel2_nccm import Sentinel2NCCMDataModule
|
||||
from .sentinel2_south_america_soybean import Sentinel2SouthAmericaSoybeanDataModule
|
||||
from .skippd import SKIPPDDataModule
|
||||
from .so2sat import So2SatDataModule
|
||||
from .spacenet import SpaceNet1DataModule
|
||||
|
@ -53,6 +54,7 @@ __all__ = (
|
|||
"NAIPChesapeakeDataModule",
|
||||
"Sentinel2CDLDataModule",
|
||||
"Sentinel2NCCMDataModule",
|
||||
"Sentinel2SouthAmericaSoybeanDataModule",
|
||||
# NonGeoDataset
|
||||
"BigEarthNetDataModule",
|
||||
"ChaBuDDataModule",
|
||||
|
|
|
@ -0,0 +1,123 @@
|
|||
# Copyright (c) Microsoft Corporation. All rights reserved.
|
||||
# Licensed under the MIT License.
|
||||
|
||||
|
||||
"""South America Soybean datamodule."""
|
||||
|
||||
|
||||
from typing import Any, Optional, Union
|
||||
|
||||
import kornia.augmentation as K
|
||||
import torch
|
||||
from kornia.constants import DataKey, Resample
|
||||
from matplotlib.figure import Figure
|
||||
|
||||
from ..datasets import Sentinel2, SouthAmericaSoybean, random_grid_cell_assignment
|
||||
from ..samplers import GridGeoSampler, RandomBatchGeoSampler
|
||||
from ..samplers.utils import _to_tuple
|
||||
from ..transforms import AugmentationSequential
|
||||
from .geo import GeoDataModule
|
||||
|
||||
|
||||
class Sentinel2SouthAmericaSoybeanDataModule(GeoDataModule):
|
||||
"""LightningDataModule for SouthAmericaSoybean and Sentinel2 datasets.
|
||||
|
||||
.. versionadded:: 0.6
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
batch_size: int = 64,
|
||||
patch_size: Union[int, tuple[int, int]] = 64,
|
||||
length: Optional[int] = None,
|
||||
num_workers: int = 0,
|
||||
**kwargs: Any,
|
||||
) -> None:
|
||||
"""Initialize a new Sentinel2SouthAmericaSoybeanDataModule instance.
|
||||
|
||||
Args:
|
||||
batch_size: Size of each mini-batch.
|
||||
patch_size: Size of each patch, either ``size`` or ``(height, width)``.
|
||||
length: Length of each training epoch.
|
||||
num_workers: Number of workers for parallel data loading.
|
||||
**kwargs: Additional keyword arguments passed to
|
||||
:class:`~torchgeo.datasets.SouthAmericaSoybean`
|
||||
(prefix keys with ``south_america_soybean_``) and
|
||||
:class:`~torchgeo.datasets.Sentinel2`
|
||||
(prefix keys with ``sentinel2_``).
|
||||
"""
|
||||
self.south_america_soybean_kwargs = {}
|
||||
self.sentinel2_kwargs = {}
|
||||
for key, val in kwargs.items():
|
||||
if key.startswith("south_america_soybean_"):
|
||||
self.south_america_soybean_kwargs[key[22:]] = val
|
||||
elif key.startswith("sentinel2_"):
|
||||
self.sentinel2_kwargs[key[10:]] = val
|
||||
|
||||
super().__init__(
|
||||
SouthAmericaSoybean,
|
||||
batch_size=batch_size,
|
||||
patch_size=patch_size,
|
||||
length=length,
|
||||
num_workers=num_workers,
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
self.train_aug = AugmentationSequential(
|
||||
K.Normalize(mean=self.mean, std=self.std),
|
||||
K.RandomResizedCrop(_to_tuple(self.patch_size), scale=(0.6, 1.0)),
|
||||
K.RandomVerticalFlip(p=0.5),
|
||||
K.RandomHorizontalFlip(p=0.5),
|
||||
data_keys=["image", "mask"],
|
||||
extra_args={
|
||||
DataKey.MASK: {"resample": Resample.NEAREST, "align_corners": None}
|
||||
},
|
||||
)
|
||||
|
||||
self.aug = AugmentationSequential(
|
||||
K.Normalize(mean=self.mean, std=self.std), data_keys=["image", "mask"]
|
||||
)
|
||||
|
||||
def setup(self, stage: str) -> None:
|
||||
"""Set up datasets and samplers.
|
||||
|
||||
Args:
|
||||
stage: Either 'fit', 'validate', 'test', or 'predict'.
|
||||
"""
|
||||
self.sentinel2 = Sentinel2(**self.sentinel2_kwargs)
|
||||
self.south_america_soybean = SouthAmericaSoybean(
|
||||
**self.south_america_soybean_kwargs
|
||||
)
|
||||
self.dataset = self.sentinel2 & self.south_america_soybean
|
||||
|
||||
generator = torch.Generator().manual_seed(1)
|
||||
(self.train_dataset, self.val_dataset, self.test_dataset) = (
|
||||
random_grid_cell_assignment(
|
||||
self.dataset, [0.8, 0.1, 0.1], grid_size=8, generator=generator
|
||||
)
|
||||
)
|
||||
|
||||
if stage in ["fit"]:
|
||||
self.train_batch_sampler = RandomBatchGeoSampler(
|
||||
self.train_dataset, self.patch_size, self.batch_size, self.length
|
||||
)
|
||||
if stage in ["fit", "validate"]:
|
||||
self.val_sampler = GridGeoSampler(
|
||||
self.val_dataset, self.patch_size, self.patch_size
|
||||
)
|
||||
if stage in ["test"]:
|
||||
self.test_sampler = GridGeoSampler(
|
||||
self.test_dataset, self.patch_size, self.patch_size
|
||||
)
|
||||
|
||||
def plot(self, *args: Any, **kwargs: Any) -> Figure:
|
||||
"""Run SouthAmericaSoybean plot method.
|
||||
|
||||
Args:
|
||||
*args: Arguments passed to plot method.
|
||||
**kwargs: Keyword arguments passed to plot method.
|
||||
|
||||
Returns:
|
||||
A matplotlib Figure with the image, ground truth, and predictions.
|
||||
"""
|
||||
return self.south_america_soybean.plot(*args, **kwargs)
|
Загрузка…
Ссылка в новой задаче