Adds the LandCoverAI100 dataset and datamodule for use in semantic segmentation notebooks (#2262)

* Add dataset and datamodule

* Add docs

* Tests

* Ran ruff one time

* Fixture needs a params kwarg

* Make dataset work

* Add versionadded to datamodule

* Add conf file to test new datamodule

* Test datamodule

* Changing dataset URL

* Update main hash
This commit is contained in:
Caleb Robinson 2024-09-11 13:48:24 -07:00 committed by GitHub
Parent de315490b7
Commit 94960bbcf6
No key found matching this signature
GPG key ID: B5690EEEBB952194
9 changed files with 87 additions and 17 deletions


@@ -123,6 +123,7 @@ LandCover.ai
 ^^^^^^^^^^^^
 .. autoclass:: LandCoverAIDataModule
+.. autoclass:: LandCoverAI100DataModule

 LEVIR-CD
 ^^^^^^^^


@@ -316,6 +316,7 @@ LandCover.ai
 ^^^^^^^^^^^^
 .. autoclass:: LandCoverAI
+.. autoclass:: LandCoverAI100

 LEVIR-CD
 ^^^^^^^^


@@ -0,0 +1,16 @@
+model:
+  class_path: SemanticSegmentationTask
+  init_args:
+    loss: "ce"
+    model: "unet"
+    backbone: "resnet18"
+    in_channels: 3
+    num_classes: 5
+    num_filters: 1
+    ignore_index: null
+data:
+  class_path: LandCoverAI100DataModule
+  init_args:
+    batch_size: 1
+    dict_kwargs:
+      root: "tests/data/landcoverai"


@@ -3,6 +3,7 @@
 import os
 import shutil
+from itertools import product
 from pathlib import Path

 import matplotlib.pyplot as plt
@@ -17,6 +18,7 @@ from torchgeo.datasets import (
     BoundingBox,
     DatasetNotFoundError,
     LandCoverAI,
+    LandCoverAI100,
     LandCoverAIGeo,
 )
@@ -72,20 +74,25 @@ class TestLandCoverAIGeo:
 class TestLandCoverAI:
     pytest.importorskip('cv2', minversion='4.5.4')

-    @pytest.fixture(params=['train', 'val', 'test'])
+    @pytest.fixture(
+        params=product([LandCoverAI100, LandCoverAI], ['train', 'val', 'test'])
+    )
     def dataset(
         self, monkeypatch: MonkeyPatch, tmp_path: Path, request: SubRequest
     ) -> LandCoverAI:
+        base_class: type[LandCoverAI] = request.param[0]
+        split: str = request.param[1]
         md5 = 'ff8998857cc8511f644d3f7d0f3688d0'
-        monkeypatch.setattr(LandCoverAI, 'md5', md5)
+        monkeypatch.setattr(base_class, 'md5', md5)
         url = os.path.join('tests', 'data', 'landcoverai', 'landcover.ai.v1.zip')
-        monkeypatch.setattr(LandCoverAI, 'url', url)
+        monkeypatch.setattr(base_class, 'url', url)
         sha256 = 'ecec8e871faf1bbd8ca525ca95ddc1c1f5213f40afb94599884bd85f990ebd6b'
-        monkeypatch.setattr(LandCoverAI, 'sha256', sha256)
+        monkeypatch.setattr(base_class, 'sha256', sha256)
+        if base_class == LandCoverAI100:
+            monkeypatch.setattr(base_class, 'filename', 'landcover.ai.v1.zip')
         root = tmp_path
-        split = request.param
         transforms = nn.Identity()
-        return LandCoverAI(root, split, transforms, download=True, checksum=True)
+        return base_class(root, split, transforms, download=True, checksum=True)

     def test_getitem(self, dataset: LandCoverAI) -> None:
         x = dataset[0]
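A quick illustration of what the product(...) params above expand to: every test that requests the dataset fixture now runs six times, once per (class, split) pair.

from itertools import product

# Mirrors the fixture's params: two dataset classes crossed with three splits.
for cls_name, split in product(['LandCoverAI100', 'LandCoverAI'],
                               ['train', 'val', 'test']):
    print(cls_name, split)  # LandCoverAI100 train, ..., LandCoverAI test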


@@ -62,6 +62,7 @@ class TestSemanticSegmentationTask:
         'l7irish',
         'l8biome',
         'landcoverai',
+        'landcoverai100',
         'loveda',
         'naipchesapeake',
         'potsdam2d',


@@ -23,7 +23,7 @@ from .inria import InriaAerialImageLabelingDataModule
 from .iobench import IOBenchDataModule
 from .l7irish import L7IrishDataModule
 from .l8biome import L8BiomeDataModule
-from .landcoverai import LandCoverAIDataModule
+from .landcoverai import LandCoverAI100DataModule, LandCoverAIDataModule
 from .levircd import LEVIRCDDataModule, LEVIRCDPlusDataModule
 from .loveda import LoveDADataModule
 from .naip import NAIPChesapeakeDataModule
@@ -82,6 +82,7 @@ __all__ = (
     'GID15DataModule',
     'InriaAerialImageLabelingDataModule',
     'LandCoverAIDataModule',
+    'LandCoverAI100DataModule',
     'LEVIRCDDataModule',
     'LEVIRCDPlusDataModule',
     'LoveDADataModule',


@@ -1,13 +1,13 @@
 # Copyright (c) Microsoft Corporation. All rights reserved.
 # Licensed under the MIT License.

-"""LandCover.ai datamodule."""
+"""LandCover.ai datamodules."""

 from typing import Any

 import kornia.augmentation as K

-from ..datasets import LandCoverAI
+from ..datasets import LandCoverAI, LandCoverAI100
 from ..transforms import AugmentationSequential
 from .geo import NonGeoDataModule
@@ -43,3 +43,29 @@ class LandCoverAIDataModule(NonGeoDataModule):
         self.aug = AugmentationSequential(
             K.Normalize(mean=self.mean, std=self.std), data_keys=['image', 'mask']
         )
+
+
+class LandCoverAI100DataModule(NonGeoDataModule):
+    """LightningDataModule implementation for the LandCoverAI100 dataset.
+
+    Uses the train/val/test splits from the dataset.
+
+    .. versionadded:: 0.7
+    """
+
+    def __init__(
+        self, batch_size: int = 64, num_workers: int = 0, **kwargs: Any
+    ) -> None:
+        """Initialize a new LandCoverAI100DataModule instance.
+
+        Args:
+            batch_size: Size of each mini-batch.
+            num_workers: Number of workers for parallel data loading.
+            **kwargs: Additional keyword arguments passed to
+                :class:`~torchgeo.datasets.LandCoverAI100`.
+        """
+        super().__init__(LandCoverAI100, batch_size, num_workers, **kwargs)
+        self.aug = AugmentationSequential(
+            K.Normalize(mean=self.mean, std=self.std), data_keys=['image', 'mask']
+        )
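For orientation, a minimal sketch of using the new datamodule outside the test suite; the root path and download=True are illustrative, not taken from this PR.

from torchgeo.datamodules import LandCoverAI100DataModule

# Keyword arguments other than batch_size/num_workers are forwarded to the
# LandCoverAI100 dataset (root, download, checksum, ...).
dm = LandCoverAI100DataModule(
    batch_size=8, num_workers=2, root='data/landcoverai100', download=True
)
dm.prepare_data()  # downloads and extracts the 100-image subset
dm.setup('fit')    # builds the train/val datasets
batch = next(iter(dm.train_dataloader()))
print(batch['image'].shape, batch['mask'].shape)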


@@ -65,7 +65,7 @@ from .inria import InriaAerialImageLabeling
 from .iobench import IOBench
 from .l7irish import L7Irish
 from .l8biome import L8Biome
-from .landcoverai import LandCoverAI, LandCoverAIBase, LandCoverAIGeo
+from .landcoverai import LandCoverAI, LandCoverAI100, LandCoverAIBase, LandCoverAIGeo
 from .landsat import (
     Landsat,
     Landsat1,
@@ -224,6 +224,7 @@ __all__ = (
     'IDTReeS',
     'InriaAerialImageLabeling',
     'LandCoverAI',
+    'LandCoverAI100',
     'LEVIRCD',
     'LEVIRCDBase',
     'LEVIRCDPlus',


@@ -401,10 +401,26 @@ class LandCoverAI(LandCoverAIBase, NonGeoDataset):
         super()._extract()

         # Generate train/val/test splits
-        # Always check the sha256 of this file before executing
-        # to avoid malicious code injection
-        with working_dir(self.root):
-            with open('split.py') as f:
-                split = f.read().encode('utf-8')
-                assert hashlib.sha256(split).hexdigest() == self.sha256
-                exec(split)
+        # Always check the sha256 of this file before executing to avoid malicious code injection
+        # The LandCoverAI100 dataset doesn't contain split.py, so only run if split.py exists
+        if os.path.exists(os.path.join(self.root, 'split.py')):
+            with working_dir(self.root):
+                with open('split.py') as f:
+                    split = f.read().encode('utf-8')
+                    assert hashlib.sha256(split).hexdigest() == self.sha256
+                    exec(split)
+
+
+class LandCoverAI100(LandCoverAI):
+    """Subset of LandCoverAI containing only 100 images.
+
+    Intended for tutorials and demonstrations, not for benchmarking.
+
+    Maintains the same file structure, classes, and train-val-test split.
+
+    .. versionadded:: 0.7
+    """
+
+    url = 'https://huggingface.co/datasets/torchgeo/landcoverai/resolve/5cdf9299bd6c1232506cf79373df01f6e6596b50/landcoverai100.zip'
+    filename = 'landcoverai100.zip'
+    md5 = '66eb33b5a0cabb631836ce0a4eafb7cd'
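Finally, a short sketch of the dataset itself in the notebook setting this PR targets; the root path is illustrative.

from torchgeo.datasets import LandCoverAI100

# Same API as LandCoverAI, but only 100 images, so the download stays small.
ds = LandCoverAI100('data/landcoverai100', split='train', download=True, checksum=True)
sample = ds[0]         # dict with 'image' and 'mask' tensors
fig = ds.plot(sample)  # plotting is inherited from LandCoverAI
fig.savefig('landcoverai100_sample.png')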