зеркало из https://github.com/microsoft/torchgeo.git
Add AgriFieldNet datamodule (#1873)
* add agrifieldnet datamodule
* fix codecov
* extra_args not needed
* Bigger default batch size
* Revert "extra_args not needed"
This reverts commit f690d8b1f8
.
* Same split as everyone else
---------
Co-authored-by: Adam J. Stewart <ajstewart426@gmail.com>
This commit is contained in:
Родитель
bd48efe988
Коммит
d0300449c6
|
@ -6,6 +6,11 @@ torchgeo.datamodules
|
|||
Geospatial DataModules
|
||||
----------------------
|
||||
|
||||
AgriFieldNet
|
||||
^^^^^^^^^^^^
|
||||
|
||||
.. autoclass:: AgriFieldNetDataModule
|
||||
|
||||
Chesapeake Land Cover
|
||||
^^^^^^^^^^^^^^^^^^^^^
|
||||
|
||||
|
|
|
@ -0,0 +1,17 @@
|
|||
model:
|
||||
class_path: SemanticSegmentationTask
|
||||
init_args:
|
||||
loss: "ce"
|
||||
model: "unet"
|
||||
backbone: "resnet18"
|
||||
in_channels: 12
|
||||
num_classes: 14
|
||||
num_filters: 1
|
||||
ignore_index: 0
|
||||
data:
|
||||
class_path: AgriFieldNetDataModule
|
||||
init_args:
|
||||
batch_size: 2
|
||||
patch_size: 16
|
||||
dict_kwargs:
|
||||
paths: "tests/data/agrifieldnet"
|
|
@ -55,6 +55,7 @@ class TestSemanticSegmentationTask:
|
|||
@pytest.mark.parametrize(
|
||||
"name",
|
||||
[
|
||||
"agrifieldnet",
|
||||
"chabud",
|
||||
"chesapeake_cvpr_5",
|
||||
"chesapeake_cvpr_7",
|
||||
|
|
|
@ -3,6 +3,7 @@
|
|||
|
||||
"""TorchGeo datamodules."""
|
||||
|
||||
from .agrifieldnet import AgriFieldNetDataModule
|
||||
from .bigearthnet import BigEarthNetDataModule
|
||||
from .chabud import ChaBuDDataModule
|
||||
from .chesapeake import ChesapeakeCVPRDataModule
|
||||
|
@ -45,6 +46,7 @@ from .xview import XView2DataModule
|
|||
|
||||
__all__ = (
|
||||
# GeoDataset
|
||||
"AgriFieldNetDataModule",
|
||||
"ChesapeakeCVPRDataModule",
|
||||
"L7IrishDataModule",
|
||||
"L8BiomeDataModule",
|
||||
|
|
|
@ -0,0 +1,86 @@
|
|||
# Copyright (c) Microsoft Corporation. All rights reserved.
|
||||
# Licensed under the MIT License.
|
||||
|
||||
"""AgriFieldNet datamodule."""
|
||||
|
||||
from typing import Any, Optional, Union
|
||||
|
||||
import kornia.augmentation as K
|
||||
import torch
|
||||
from kornia.constants import DataKey, Resample
|
||||
|
||||
from ..datasets import AgriFieldNet, random_bbox_assignment
|
||||
from ..samplers import GridGeoSampler, RandomBatchGeoSampler
|
||||
from ..samplers.utils import _to_tuple
|
||||
from ..transforms import AugmentationSequential
|
||||
from .geo import GeoDataModule
|
||||
|
||||
|
||||
class AgriFieldNetDataModule(GeoDataModule):
|
||||
"""LightningDataModule implementation for the AgriFieldNet dataset.
|
||||
|
||||
.. versionadded:: 0.6
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
batch_size: int = 64,
|
||||
patch_size: Union[int, tuple[int, int]] = 256,
|
||||
length: Optional[int] = None,
|
||||
num_workers: int = 0,
|
||||
**kwargs: Any,
|
||||
) -> None:
|
||||
"""Initialize a new AgriFieldNetDataModule instance.
|
||||
|
||||
Args:
|
||||
batch_size: Size of each mini-batch.
|
||||
patch_size: Size of each patch, either ``size`` or ``(height, width)``.
|
||||
length: Length of each training epoch.
|
||||
num_workers: Number of workers for parallel data loading.
|
||||
**kwargs: Additional keyword arguments passed to
|
||||
:class:`~torchgeo.datasets.AgriFieldNet`.
|
||||
"""
|
||||
super().__init__(
|
||||
AgriFieldNet,
|
||||
batch_size=batch_size,
|
||||
patch_size=patch_size,
|
||||
length=length,
|
||||
num_workers=num_workers,
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
self.train_aug = AugmentationSequential(
|
||||
K.Normalize(mean=self.mean, std=self.std),
|
||||
K.RandomResizedCrop(_to_tuple(self.patch_size), scale=(0.6, 1.0)),
|
||||
K.RandomVerticalFlip(p=0.5),
|
||||
K.RandomHorizontalFlip(p=0.5),
|
||||
data_keys=["image", "mask"],
|
||||
extra_args={
|
||||
DataKey.MASK: {"resample": Resample.NEAREST, "align_corners": None}
|
||||
},
|
||||
)
|
||||
|
||||
def setup(self, stage: str) -> None:
|
||||
"""Set up datasets.
|
||||
|
||||
Args:
|
||||
stage: Either 'fit', 'validate', 'test', or 'predict'.
|
||||
"""
|
||||
dataset = AgriFieldNet(**self.kwargs)
|
||||
generator = torch.Generator().manual_seed(0)
|
||||
(self.train_dataset, self.val_dataset, self.test_dataset) = (
|
||||
random_bbox_assignment(dataset, [0.8, 0.1, 0.1], generator)
|
||||
)
|
||||
|
||||
if stage in ["fit"]:
|
||||
self.train_batch_sampler = RandomBatchGeoSampler(
|
||||
self.train_dataset, self.patch_size, self.batch_size, self.length
|
||||
)
|
||||
if stage in ["fit", "validate"]:
|
||||
self.val_sampler = GridGeoSampler(
|
||||
self.val_dataset, self.patch_size, self.patch_size
|
||||
)
|
||||
if stage in ["test"]:
|
||||
self.test_sampler = GridGeoSampler(
|
||||
self.test_dataset, self.patch_size, self.patch_size
|
||||
)
|
Загрузка…
Ссылка в новой задаче