Adds the LandCoverAI100 dataset and datamodule for use in semantic segmentation notebooks (#2262)

* Add dataset and datamodule

* Add docs

* Tests

* Ran ruff one time

* Fixture needs a params kwarg

* Make dataset work

* Add versionadded to datamodule

* Add conf file to test new datamodule

* Test datamodule

* Changing dataset URL

* Update main hash
This commit is contained in:
Caleb Robinson 2024-09-11 13:48:24 -07:00 коммит произвёл GitHub
Родитель de315490b7
Коммит 94960bbcf6
Не найден ключ, соответствующий данной подписи
Идентификатор ключа GPG: B5690EEEBB952194
9 изменённых файлов: 87 добавлений и 17 удалений

Просмотреть файл

@ -123,6 +123,7 @@ LandCover.ai
^^^^^^^^^^^^ ^^^^^^^^^^^^
.. autoclass:: LandCoverAIDataModule .. autoclass:: LandCoverAIDataModule
.. autoclass:: LandCoverAI100DataModule
LEVIR-CD LEVIR-CD
^^^^^^^^ ^^^^^^^^

Просмотреть файл

@ -316,6 +316,7 @@ LandCover.ai
^^^^^^^^^^^^ ^^^^^^^^^^^^
.. autoclass:: LandCoverAI .. autoclass:: LandCoverAI
.. autoclass:: LandCoverAI100
LEVIR-CD LEVIR-CD
^^^^^^^^ ^^^^^^^^

Просмотреть файл

@ -0,0 +1,16 @@
model:
class_path: SemanticSegmentationTask
init_args:
loss: "ce"
model: "unet"
backbone: "resnet18"
in_channels: 3
num_classes: 5
num_filters: 1
ignore_index: null
data:
class_path: LandCoverAI100DataModule
init_args:
batch_size: 1
dict_kwargs:
root: "tests/data/landcoverai"

Просмотреть файл

@ -3,6 +3,7 @@
import os import os
import shutil import shutil
from itertools import product
from pathlib import Path from pathlib import Path
import matplotlib.pyplot as plt import matplotlib.pyplot as plt
@ -17,6 +18,7 @@ from torchgeo.datasets import (
BoundingBox, BoundingBox,
DatasetNotFoundError, DatasetNotFoundError,
LandCoverAI, LandCoverAI,
LandCoverAI100,
LandCoverAIGeo, LandCoverAIGeo,
) )
@ -72,20 +74,25 @@ class TestLandCoverAIGeo:
class TestLandCoverAI: class TestLandCoverAI:
pytest.importorskip('cv2', minversion='4.5.4') pytest.importorskip('cv2', minversion='4.5.4')
@pytest.fixture(params=['train', 'val', 'test']) @pytest.fixture(
params=product([LandCoverAI100, LandCoverAI], ['train', 'val', 'test'])
)
def dataset( def dataset(
self, monkeypatch: MonkeyPatch, tmp_path: Path, request: SubRequest self, monkeypatch: MonkeyPatch, tmp_path: Path, request: SubRequest
) -> LandCoverAI: ) -> LandCoverAI:
base_class: type[LandCoverAI] = request.param[0]
split: str = request.param[1]
md5 = 'ff8998857cc8511f644d3f7d0f3688d0' md5 = 'ff8998857cc8511f644d3f7d0f3688d0'
monkeypatch.setattr(LandCoverAI, 'md5', md5) monkeypatch.setattr(base_class, 'md5', md5)
url = os.path.join('tests', 'data', 'landcoverai', 'landcover.ai.v1.zip') url = os.path.join('tests', 'data', 'landcoverai', 'landcover.ai.v1.zip')
monkeypatch.setattr(LandCoverAI, 'url', url) monkeypatch.setattr(base_class, 'url', url)
sha256 = 'ecec8e871faf1bbd8ca525ca95ddc1c1f5213f40afb94599884bd85f990ebd6b' sha256 = 'ecec8e871faf1bbd8ca525ca95ddc1c1f5213f40afb94599884bd85f990ebd6b'
monkeypatch.setattr(LandCoverAI, 'sha256', sha256) monkeypatch.setattr(base_class, 'sha256', sha256)
if base_class == LandCoverAI100:
monkeypatch.setattr(base_class, 'filename', 'landcover.ai.v1.zip')
root = tmp_path root = tmp_path
split = request.param
transforms = nn.Identity() transforms = nn.Identity()
return LandCoverAI(root, split, transforms, download=True, checksum=True) return base_class(root, split, transforms, download=True, checksum=True)
def test_getitem(self, dataset: LandCoverAI) -> None: def test_getitem(self, dataset: LandCoverAI) -> None:
x = dataset[0] x = dataset[0]

Просмотреть файл

@ -62,6 +62,7 @@ class TestSemanticSegmentationTask:
'l7irish', 'l7irish',
'l8biome', 'l8biome',
'landcoverai', 'landcoverai',
'landcoverai100',
'loveda', 'loveda',
'naipchesapeake', 'naipchesapeake',
'potsdam2d', 'potsdam2d',

Просмотреть файл

@ -23,7 +23,7 @@ from .inria import InriaAerialImageLabelingDataModule
from .iobench import IOBenchDataModule from .iobench import IOBenchDataModule
from .l7irish import L7IrishDataModule from .l7irish import L7IrishDataModule
from .l8biome import L8BiomeDataModule from .l8biome import L8BiomeDataModule
from .landcoverai import LandCoverAIDataModule from .landcoverai import LandCoverAI100DataModule, LandCoverAIDataModule
from .levircd import LEVIRCDDataModule, LEVIRCDPlusDataModule from .levircd import LEVIRCDDataModule, LEVIRCDPlusDataModule
from .loveda import LoveDADataModule from .loveda import LoveDADataModule
from .naip import NAIPChesapeakeDataModule from .naip import NAIPChesapeakeDataModule
@ -82,6 +82,7 @@ __all__ = (
'GID15DataModule', 'GID15DataModule',
'InriaAerialImageLabelingDataModule', 'InriaAerialImageLabelingDataModule',
'LandCoverAIDataModule', 'LandCoverAIDataModule',
'LandCoverAI100DataModule',
'LEVIRCDDataModule', 'LEVIRCDDataModule',
'LEVIRCDPlusDataModule', 'LEVIRCDPlusDataModule',
'LoveDADataModule', 'LoveDADataModule',

Просмотреть файл

@ -1,13 +1,13 @@
# Copyright (c) Microsoft Corporation. All rights reserved. # Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT License. # Licensed under the MIT License.
"""LandCover.ai datamodule.""" """LandCover.ai datamodules."""
from typing import Any from typing import Any
import kornia.augmentation as K import kornia.augmentation as K
from ..datasets import LandCoverAI from ..datasets import LandCoverAI, LandCoverAI100
from ..transforms import AugmentationSequential from ..transforms import AugmentationSequential
from .geo import NonGeoDataModule from .geo import NonGeoDataModule
@ -43,3 +43,29 @@ class LandCoverAIDataModule(NonGeoDataModule):
self.aug = AugmentationSequential( self.aug = AugmentationSequential(
K.Normalize(mean=self.mean, std=self.std), data_keys=['image', 'mask'] K.Normalize(mean=self.mean, std=self.std), data_keys=['image', 'mask']
) )
class LandCoverAI100DataModule(NonGeoDataModule):
"""LightningDataModule implementation for the LandCoverAI100 dataset.
Uses the train/val/test splits from the dataset.
.. versionadded:: 0.7
"""
def __init__(
self, batch_size: int = 64, num_workers: int = 0, **kwargs: Any
) -> None:
"""Initialize a new LandCoverAI100DataModule instance.
Args:
batch_size: Size of each mini-batch.
num_workers: Number of workers for parallel data loading.
**kwargs: Additional keyword arguments passed to
:class:`~torchgeo.datasets.LandCoverAI100`.
"""
super().__init__(LandCoverAI100, batch_size, num_workers, **kwargs)
self.aug = AugmentationSequential(
K.Normalize(mean=self.mean, std=self.std), data_keys=['image', 'mask']
)

Просмотреть файл

@ -65,7 +65,7 @@ from .inria import InriaAerialImageLabeling
from .iobench import IOBench from .iobench import IOBench
from .l7irish import L7Irish from .l7irish import L7Irish
from .l8biome import L8Biome from .l8biome import L8Biome
from .landcoverai import LandCoverAI, LandCoverAIBase, LandCoverAIGeo from .landcoverai import LandCoverAI, LandCoverAI100, LandCoverAIBase, LandCoverAIGeo
from .landsat import ( from .landsat import (
Landsat, Landsat,
Landsat1, Landsat1,
@ -224,6 +224,7 @@ __all__ = (
'IDTReeS', 'IDTReeS',
'InriaAerialImageLabeling', 'InriaAerialImageLabeling',
'LandCoverAI', 'LandCoverAI',
'LandCoverAI100',
'LEVIRCD', 'LEVIRCD',
'LEVIRCDBase', 'LEVIRCDBase',
'LEVIRCDPlus', 'LEVIRCDPlus',

Просмотреть файл

@ -401,10 +401,26 @@ class LandCoverAI(LandCoverAIBase, NonGeoDataset):
super()._extract() super()._extract()
# Generate train/val/test splits # Generate train/val/test splits
# Always check the sha256 of this file before executing # Always check the sha256 of this file before executing to avoid malicious code injection
# to avoid malicious code injection # The LandCoverAI100 dataset doesn't contain split.py, so only run if split.py exists
with working_dir(self.root): if os.path.exists(os.path.join(self.root, 'split.py')):
with open('split.py') as f: with working_dir(self.root):
split = f.read().encode('utf-8') with open('split.py') as f:
assert hashlib.sha256(split).hexdigest() == self.sha256 split = f.read().encode('utf-8')
exec(split) assert hashlib.sha256(split).hexdigest() == self.sha256
exec(split)
class LandCoverAI100(LandCoverAI):
"""Subset of LandCoverAI containing only 100 images.
Intended for tutorials and demonstrations, not for benchmarking.
Maintains the same file structure, classes, and train-val-test split.
.. versionadded:: 0.7
"""
url = 'https://huggingface.co/datasets/torchgeo/landcoverai/resolve/5cdf9299bd6c1232506cf79373df01f6e6596b50/landcoverai100.zip'
filename = 'landcoverai100.zip'
md5 = '66eb33b5a0cabb631836ce0a4eafb7cd'