Mirror of https://github.com/microsoft/torchgeo.git

Adds the LandCoverAI100 dataset and datamodule for use in semantic segmentation notebooks (#2262)

* Add dataset and datamodule
* Add docs
* Tests
* Ran ruff one time
* Fixture needs a params kwarg
* Make dataset work
* Add versionadded to datamodule
* Add conf file to test new datamodule
* Test datamodule
* Changing dataset URL
* Update main hash

This commit is contained in:
Parent: de315490b7
Commit: 94960bbcf6
docs/api/datamodules.rst
@@ -123,6 +123,7 @@ LandCover.ai
 ^^^^^^^^^^^^
 
 .. autoclass:: LandCoverAIDataModule
+.. autoclass:: LandCoverAI100DataModule
 
 LEVIR-CD
 ^^^^^^^^
docs/api/datasets.rst
@@ -316,6 +316,7 @@ LandCover.ai
 ^^^^^^^^^^^^
 
 .. autoclass:: LandCoverAI
+.. autoclass:: LandCoverAI100
 
 LEVIR-CD
 ^^^^^^^^
tests/conf/landcoverai100.yaml (new file)
@@ -0,0 +1,16 @@
+model:
+  class_path: SemanticSegmentationTask
+  init_args:
+    loss: "ce"
+    model: "unet"
+    backbone: "resnet18"
+    in_channels: 3
+    num_classes: 5
+    num_filters: 1
+    ignore_index: null
+data:
+  class_path: LandCoverAI100DataModule
+  init_args:
+    batch_size: 1
+  dict_kwargs:
+    root: "tests/data/landcoverai"
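For orientation, this test config wires the new datamodule into the existing SemanticSegmentationTask. A rough Python equivalent of what the LightningCLI builds from it (a sketch only; the tiny hyperparameters are test values, not training defaults):

from torchgeo.datamodules import LandCoverAI100DataModule
from torchgeo.trainers import SemanticSegmentationTask

# Mirrors the YAML above: a minimal U-Net task plus the new datamodule,
# with dict_kwargs forwarded to the underlying dataset as **kwargs
task = SemanticSegmentationTask(
    model='unet',
    backbone='resnet18',
    loss='ce',
    in_channels=3,
    num_classes=5,
    num_filters=1,
    ignore_index=None,
)
datamodule = LandCoverAI100DataModule(batch_size=1, root='tests/data/landcoverai')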
tests/datasets/test_landcoverai.py
@@ -3,6 +3,7 @@
 
 import os
 import shutil
+from itertools import product
 from pathlib import Path
 
 import matplotlib.pyplot as plt
@@ -17,6 +18,7 @@ from torchgeo.datasets import (
     BoundingBox,
     DatasetNotFoundError,
     LandCoverAI,
+    LandCoverAI100,
     LandCoverAIGeo,
 )
 
@@ -72,20 +74,25 @@ class TestLandCoverAIGeo:
 class TestLandCoverAI:
     pytest.importorskip('cv2', minversion='4.5.4')
 
-    @pytest.fixture(params=['train', 'val', 'test'])
+    @pytest.fixture(
+        params=product([LandCoverAI100, LandCoverAI], ['train', 'val', 'test'])
+    )
     def dataset(
         self, monkeypatch: MonkeyPatch, tmp_path: Path, request: SubRequest
     ) -> LandCoverAI:
+        base_class: type[LandCoverAI] = request.param[0]
+        split: str = request.param[1]
         md5 = 'ff8998857cc8511f644d3f7d0f3688d0'
-        monkeypatch.setattr(LandCoverAI, 'md5', md5)
+        monkeypatch.setattr(base_class, 'md5', md5)
         url = os.path.join('tests', 'data', 'landcoverai', 'landcover.ai.v1.zip')
-        monkeypatch.setattr(LandCoverAI, 'url', url)
+        monkeypatch.setattr(base_class, 'url', url)
         sha256 = 'ecec8e871faf1bbd8ca525ca95ddc1c1f5213f40afb94599884bd85f990ebd6b'
-        monkeypatch.setattr(LandCoverAI, 'sha256', sha256)
+        monkeypatch.setattr(base_class, 'sha256', sha256)
+        if base_class == LandCoverAI100:
+            monkeypatch.setattr(base_class, 'filename', 'landcover.ai.v1.zip')
         root = tmp_path
-        split = request.param
         transforms = nn.Identity()
-        return LandCoverAI(root, split, transforms, download=True, checksum=True)
+        return base_class(root, split, transforms, download=True, checksum=True)
 
     def test_getitem(self, dataset: LandCoverAI) -> None:
         x = dataset[0]
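The switch from a plain params list to product(...) means the single fixture now exercises both dataset classes against every split; conceptually the parametrization expands to six (class, split) pairs, as this small sketch illustrates:

from itertools import product

# The fixture's params kwarg expands to:
# (LandCoverAI100, 'train'), (LandCoverAI100, 'val'), (LandCoverAI100, 'test'),
# (LandCoverAI, 'train'), (LandCoverAI, 'val'), (LandCoverAI, 'test')
pairs = list(product(['LandCoverAI100', 'LandCoverAI'], ['train', 'val', 'test']))
assert len(pairs) == 6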
tests/trainers/test_segmentation.py
@@ -62,6 +62,7 @@ class TestSemanticSegmentationTask:
             'l7irish',
             'l8biome',
             'landcoverai',
+            'landcoverai100',
             'loveda',
             'naipchesapeake',
             'potsdam2d',
torchgeo/datamodules/__init__.py
@@ -23,7 +23,7 @@ from .inria import InriaAerialImageLabelingDataModule
 from .iobench import IOBenchDataModule
 from .l7irish import L7IrishDataModule
 from .l8biome import L8BiomeDataModule
-from .landcoverai import LandCoverAIDataModule
+from .landcoverai import LandCoverAI100DataModule, LandCoverAIDataModule
 from .levircd import LEVIRCDDataModule, LEVIRCDPlusDataModule
 from .loveda import LoveDADataModule
 from .naip import NAIPChesapeakeDataModule
@@ -82,6 +82,7 @@ __all__ = (
     'GID15DataModule',
     'InriaAerialImageLabelingDataModule',
     'LandCoverAIDataModule',
+    'LandCoverAI100DataModule',
     'LEVIRCDDataModule',
     'LEVIRCDPlusDataModule',
     'LoveDADataModule',
torchgeo/datamodules/landcoverai.py
@@ -1,13 +1,13 @@
 # Copyright (c) Microsoft Corporation. All rights reserved.
 # Licensed under the MIT License.
 
-"""LandCover.ai datamodule."""
+"""LandCover.ai datamodules."""
 
 from typing import Any
 
 import kornia.augmentation as K
 
-from ..datasets import LandCoverAI
+from ..datasets import LandCoverAI, LandCoverAI100
 from ..transforms import AugmentationSequential
 from .geo import NonGeoDataModule
 
@@ -43,3 +43,29 @@ class LandCoverAIDataModule(NonGeoDataModule):
         self.aug = AugmentationSequential(
             K.Normalize(mean=self.mean, std=self.std), data_keys=['image', 'mask']
         )
+
+
+class LandCoverAI100DataModule(NonGeoDataModule):
+    """LightningDataModule implementation for the LandCoverAI100 dataset.
+
+    Uses the train/val/test splits from the dataset.
+
+    .. versionadded:: 0.7
+    """
+
+    def __init__(
+        self, batch_size: int = 64, num_workers: int = 0, **kwargs: Any
+    ) -> None:
+        """Initialize a new LandCoverAI100DataModule instance.
+
+        Args:
+            batch_size: Size of each mini-batch.
+            num_workers: Number of workers for parallel data loading.
+            **kwargs: Additional keyword arguments passed to
+                :class:`~torchgeo.datasets.LandCoverAI100`.
+        """
+        super().__init__(LandCoverAI100, batch_size, num_workers, **kwargs)
+
+        self.aug = AugmentationSequential(
+            K.Normalize(mean=self.mean, std=self.std), data_keys=['image', 'mask']
+        )
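A minimal, hypothetical usage sketch for the new datamodule (the root path is a placeholder; extra kwargs such as download are forwarded to LandCoverAI100):

import lightning.pytorch as pl

from torchgeo.datamodules import LandCoverAI100DataModule
from torchgeo.trainers import SemanticSegmentationTask

datamodule = LandCoverAI100DataModule(
    batch_size=16, num_workers=4, root='data/landcoverai100', download=True
)
task = SemanticSegmentationTask(
    model='unet', backbone='resnet18', in_channels=3, num_classes=5
)
trainer = pl.Trainer(max_epochs=1)  # short run, in keeping with the tutorial-sized dataset
trainer.fit(model=task, datamodule=datamodule)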
torchgeo/datasets/__init__.py
@@ -65,7 +65,7 @@ from .inria import InriaAerialImageLabeling
 from .iobench import IOBench
 from .l7irish import L7Irish
 from .l8biome import L8Biome
-from .landcoverai import LandCoverAI, LandCoverAIBase, LandCoverAIGeo
+from .landcoverai import LandCoverAI, LandCoverAI100, LandCoverAIBase, LandCoverAIGeo
 from .landsat import (
     Landsat,
     Landsat1,
@@ -224,6 +224,7 @@ __all__ = (
     'IDTReeS',
     'InriaAerialImageLabeling',
     'LandCoverAI',
+    'LandCoverAI100',
     'LEVIRCD',
     'LEVIRCDBase',
     'LEVIRCDPlus',
torchgeo/datasets/landcoverai.py
@@ -401,10 +401,26 @@ class LandCoverAI(LandCoverAIBase, NonGeoDataset):
         super()._extract()
 
         # Generate train/val/test splits
-        # Always check the sha256 of this file before executing
-        # to avoid malicious code injection
-        with working_dir(self.root):
-            with open('split.py') as f:
-                split = f.read().encode('utf-8')
-                assert hashlib.sha256(split).hexdigest() == self.sha256
-                exec(split)
+        # Always check the sha256 of this file before executing to avoid malicious code injection
+        # The LandCoverAI100 dataset doesn't contain split.py, so only run if split.py exists
+        if os.path.exists(os.path.join(self.root, 'split.py')):
+            with working_dir(self.root):
+                with open('split.py') as f:
+                    split = f.read().encode('utf-8')
+                    assert hashlib.sha256(split).hexdigest() == self.sha256
+                    exec(split)
+
+
+class LandCoverAI100(LandCoverAI):
+    """Subset of LandCoverAI containing only 100 images.
+
+    Intended for tutorials and demonstrations, not for benchmarking.
+
+    Maintains the same file structure, classes, and train-val-test split.
+
+    .. versionadded:: 0.7
+    """
+
+    url = 'https://huggingface.co/datasets/torchgeo/landcoverai/resolve/5cdf9299bd6c1232506cf79373df01f6e6596b50/landcoverai100.zip'
+    filename = 'landcoverai100.zip'
+    md5 = '66eb33b5a0cabb631836ce0a4eafb7cd'
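And a hedged sketch of using the dataset class directly (placeholder root; download=True fetches landcoverai100.zip from the Hugging Face URL above, and checksum=True verifies the md5):

from torchgeo.datasets import LandCoverAI100

# The 100-image archive already ships its train/val/test file lists,
# so the split.py step in _extract() is skipped for this subset
ds = LandCoverAI100(root='data/landcoverai100', split='train', download=True, checksum=True)
sample = ds[0]  # dict with an 'image' tensor and a 'mask' tensor
print(len(ds), sample['image'].shape)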