зеркало из https://github.com/microsoft/torchgeo.git
Adds the LandCoverAI100 dataset and datamodule for use in semantic segmentation notebooks (#2262)
* Add dataset and datamodule
* Add docs
* Tests
* Ran ruff one time
* Fixture needs a params kwarg
* Make dataset work
* Add versionadded to datamodule
* Add conf file to test new datamodule
* Test datamodule
* Changing dataset URL
* Update main hash
This commit is contained in:
Родитель
de315490b7
Коммит
94960bbcf6
|
@ -123,6 +123,7 @@ LandCover.ai
|
||||||
^^^^^^^^^^^^
|
^^^^^^^^^^^^
|
||||||
|
|
||||||
.. autoclass:: LandCoverAIDataModule
|
.. autoclass:: LandCoverAIDataModule
|
||||||
|
.. autoclass:: LandCoverAI100DataModule
|
||||||
|
|
||||||
LEVIR-CD
|
LEVIR-CD
|
||||||
^^^^^^^^
|
^^^^^^^^
|
||||||
|
|
|
@ -316,6 +316,7 @@ LandCover.ai
|
||||||
^^^^^^^^^^^^
|
^^^^^^^^^^^^
|
||||||
|
|
||||||
.. autoclass:: LandCoverAI
|
.. autoclass:: LandCoverAI
|
||||||
|
.. autoclass:: LandCoverAI100
|
||||||
|
|
||||||
LEVIR-CD
|
LEVIR-CD
|
||||||
^^^^^^^^
|
^^^^^^^^
|
||||||
|
|
|
@ -0,0 +1,16 @@
|
||||||
|
model:
|
||||||
|
class_path: SemanticSegmentationTask
|
||||||
|
init_args:
|
||||||
|
loss: "ce"
|
||||||
|
model: "unet"
|
||||||
|
backbone: "resnet18"
|
||||||
|
in_channels: 3
|
||||||
|
num_classes: 5
|
||||||
|
num_filters: 1
|
||||||
|
ignore_index: null
|
||||||
|
data:
|
||||||
|
class_path: LandCoverAI100DataModule
|
||||||
|
init_args:
|
||||||
|
batch_size: 1
|
||||||
|
dict_kwargs:
|
||||||
|
root: "tests/data/landcoverai"
|
|
@ -3,6 +3,7 @@
|
||||||
|
|
||||||
import os
|
import os
|
||||||
import shutil
|
import shutil
|
||||||
|
from itertools import product
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
|
|
||||||
import matplotlib.pyplot as plt
|
import matplotlib.pyplot as plt
|
||||||
|
@ -17,6 +18,7 @@ from torchgeo.datasets import (
|
||||||
BoundingBox,
|
BoundingBox,
|
||||||
DatasetNotFoundError,
|
DatasetNotFoundError,
|
||||||
LandCoverAI,
|
LandCoverAI,
|
||||||
|
LandCoverAI100,
|
||||||
LandCoverAIGeo,
|
LandCoverAIGeo,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
@ -72,20 +74,25 @@ class TestLandCoverAIGeo:
|
||||||
class TestLandCoverAI:
|
class TestLandCoverAI:
|
||||||
pytest.importorskip('cv2', minversion='4.5.4')
|
pytest.importorskip('cv2', minversion='4.5.4')
|
||||||
|
|
||||||
@pytest.fixture(params=['train', 'val', 'test'])
|
@pytest.fixture(
|
||||||
|
params=product([LandCoverAI100, LandCoverAI], ['train', 'val', 'test'])
|
||||||
|
)
|
||||||
def dataset(
|
def dataset(
|
||||||
self, monkeypatch: MonkeyPatch, tmp_path: Path, request: SubRequest
|
self, monkeypatch: MonkeyPatch, tmp_path: Path, request: SubRequest
|
||||||
) -> LandCoverAI:
|
) -> LandCoverAI:
|
||||||
|
base_class: type[LandCoverAI] = request.param[0]
|
||||||
|
split: str = request.param[1]
|
||||||
md5 = 'ff8998857cc8511f644d3f7d0f3688d0'
|
md5 = 'ff8998857cc8511f644d3f7d0f3688d0'
|
||||||
monkeypatch.setattr(LandCoverAI, 'md5', md5)
|
monkeypatch.setattr(base_class, 'md5', md5)
|
||||||
url = os.path.join('tests', 'data', 'landcoverai', 'landcover.ai.v1.zip')
|
url = os.path.join('tests', 'data', 'landcoverai', 'landcover.ai.v1.zip')
|
||||||
monkeypatch.setattr(LandCoverAI, 'url', url)
|
monkeypatch.setattr(base_class, 'url', url)
|
||||||
sha256 = 'ecec8e871faf1bbd8ca525ca95ddc1c1f5213f40afb94599884bd85f990ebd6b'
|
sha256 = 'ecec8e871faf1bbd8ca525ca95ddc1c1f5213f40afb94599884bd85f990ebd6b'
|
||||||
monkeypatch.setattr(LandCoverAI, 'sha256', sha256)
|
monkeypatch.setattr(base_class, 'sha256', sha256)
|
||||||
|
if base_class == LandCoverAI100:
|
||||||
|
monkeypatch.setattr(base_class, 'filename', 'landcover.ai.v1.zip')
|
||||||
root = tmp_path
|
root = tmp_path
|
||||||
split = request.param
|
|
||||||
transforms = nn.Identity()
|
transforms = nn.Identity()
|
||||||
return LandCoverAI(root, split, transforms, download=True, checksum=True)
|
return base_class(root, split, transforms, download=True, checksum=True)
|
||||||
|
|
||||||
def test_getitem(self, dataset: LandCoverAI) -> None:
|
def test_getitem(self, dataset: LandCoverAI) -> None:
|
||||||
x = dataset[0]
|
x = dataset[0]
|
||||||
|
|
|
@ -62,6 +62,7 @@ class TestSemanticSegmentationTask:
|
||||||
'l7irish',
|
'l7irish',
|
||||||
'l8biome',
|
'l8biome',
|
||||||
'landcoverai',
|
'landcoverai',
|
||||||
|
'landcoverai100',
|
||||||
'loveda',
|
'loveda',
|
||||||
'naipchesapeake',
|
'naipchesapeake',
|
||||||
'potsdam2d',
|
'potsdam2d',
|
||||||
|
|
|
@ -23,7 +23,7 @@ from .inria import InriaAerialImageLabelingDataModule
|
||||||
from .iobench import IOBenchDataModule
|
from .iobench import IOBenchDataModule
|
||||||
from .l7irish import L7IrishDataModule
|
from .l7irish import L7IrishDataModule
|
||||||
from .l8biome import L8BiomeDataModule
|
from .l8biome import L8BiomeDataModule
|
||||||
from .landcoverai import LandCoverAIDataModule
|
from .landcoverai import LandCoverAI100DataModule, LandCoverAIDataModule
|
||||||
from .levircd import LEVIRCDDataModule, LEVIRCDPlusDataModule
|
from .levircd import LEVIRCDDataModule, LEVIRCDPlusDataModule
|
||||||
from .loveda import LoveDADataModule
|
from .loveda import LoveDADataModule
|
||||||
from .naip import NAIPChesapeakeDataModule
|
from .naip import NAIPChesapeakeDataModule
|
||||||
|
@ -82,6 +82,7 @@ __all__ = (
|
||||||
'GID15DataModule',
|
'GID15DataModule',
|
||||||
'InriaAerialImageLabelingDataModule',
|
'InriaAerialImageLabelingDataModule',
|
||||||
'LandCoverAIDataModule',
|
'LandCoverAIDataModule',
|
||||||
|
'LandCoverAI100DataModule',
|
||||||
'LEVIRCDDataModule',
|
'LEVIRCDDataModule',
|
||||||
'LEVIRCDPlusDataModule',
|
'LEVIRCDPlusDataModule',
|
||||||
'LoveDADataModule',
|
'LoveDADataModule',
|
||||||
|
|
|
@ -1,13 +1,13 @@
|
||||||
# Copyright (c) Microsoft Corporation. All rights reserved.
|
# Copyright (c) Microsoft Corporation. All rights reserved.
|
||||||
# Licensed under the MIT License.
|
# Licensed under the MIT License.
|
||||||
|
|
||||||
"""LandCover.ai datamodule."""
|
"""LandCover.ai datamodules."""
|
||||||
|
|
||||||
from typing import Any
|
from typing import Any
|
||||||
|
|
||||||
import kornia.augmentation as K
|
import kornia.augmentation as K
|
||||||
|
|
||||||
from ..datasets import LandCoverAI
|
from ..datasets import LandCoverAI, LandCoverAI100
|
||||||
from ..transforms import AugmentationSequential
|
from ..transforms import AugmentationSequential
|
||||||
from .geo import NonGeoDataModule
|
from .geo import NonGeoDataModule
|
||||||
|
|
||||||
|
@ -43,3 +43,29 @@ class LandCoverAIDataModule(NonGeoDataModule):
|
||||||
self.aug = AugmentationSequential(
|
self.aug = AugmentationSequential(
|
||||||
K.Normalize(mean=self.mean, std=self.std), data_keys=['image', 'mask']
|
K.Normalize(mean=self.mean, std=self.std), data_keys=['image', 'mask']
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class LandCoverAI100DataModule(NonGeoDataModule):
|
||||||
|
"""LightningDataModule implementation for the LandCoverAI100 dataset.
|
||||||
|
|
||||||
|
Uses the train/val/test splits from the dataset.
|
||||||
|
|
||||||
|
.. versionadded:: 0.7
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self, batch_size: int = 64, num_workers: int = 0, **kwargs: Any
|
||||||
|
) -> None:
|
||||||
|
"""Initialize a new LandCoverAI100DataModule instance.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
batch_size: Size of each mini-batch.
|
||||||
|
num_workers: Number of workers for parallel data loading.
|
||||||
|
**kwargs: Additional keyword arguments passed to
|
||||||
|
:class:`~torchgeo.datasets.LandCoverAI100`.
|
||||||
|
"""
|
||||||
|
super().__init__(LandCoverAI100, batch_size, num_workers, **kwargs)
|
||||||
|
|
||||||
|
self.aug = AugmentationSequential(
|
||||||
|
K.Normalize(mean=self.mean, std=self.std), data_keys=['image', 'mask']
|
||||||
|
)
|
||||||
|
|
|
@ -65,7 +65,7 @@ from .inria import InriaAerialImageLabeling
|
||||||
from .iobench import IOBench
|
from .iobench import IOBench
|
||||||
from .l7irish import L7Irish
|
from .l7irish import L7Irish
|
||||||
from .l8biome import L8Biome
|
from .l8biome import L8Biome
|
||||||
from .landcoverai import LandCoverAI, LandCoverAIBase, LandCoverAIGeo
|
from .landcoverai import LandCoverAI, LandCoverAI100, LandCoverAIBase, LandCoverAIGeo
|
||||||
from .landsat import (
|
from .landsat import (
|
||||||
Landsat,
|
Landsat,
|
||||||
Landsat1,
|
Landsat1,
|
||||||
|
@ -224,6 +224,7 @@ __all__ = (
|
||||||
'IDTReeS',
|
'IDTReeS',
|
||||||
'InriaAerialImageLabeling',
|
'InriaAerialImageLabeling',
|
||||||
'LandCoverAI',
|
'LandCoverAI',
|
||||||
|
'LandCoverAI100',
|
||||||
'LEVIRCD',
|
'LEVIRCD',
|
||||||
'LEVIRCDBase',
|
'LEVIRCDBase',
|
||||||
'LEVIRCDPlus',
|
'LEVIRCDPlus',
|
||||||
|
|
|
@ -401,10 +401,26 @@ class LandCoverAI(LandCoverAIBase, NonGeoDataset):
|
||||||
super()._extract()
|
super()._extract()
|
||||||
|
|
||||||
# Generate train/val/test splits
|
# Generate train/val/test splits
|
||||||
# Always check the sha256 of this file before executing
|
# Always check the sha256 of this file before executing to avoid malicious code injection
|
||||||
# to avoid malicious code injection
|
# The LandCoverAI100 dataset doesn't contain split.py, so only run if split.py exists
|
||||||
with working_dir(self.root):
|
if os.path.exists(os.path.join(self.root, 'split.py')):
|
||||||
with open('split.py') as f:
|
with working_dir(self.root):
|
||||||
split = f.read().encode('utf-8')
|
with open('split.py') as f:
|
||||||
assert hashlib.sha256(split).hexdigest() == self.sha256
|
split = f.read().encode('utf-8')
|
||||||
exec(split)
|
assert hashlib.sha256(split).hexdigest() == self.sha256
|
||||||
|
exec(split)
|
||||||
|
|
||||||
|
|
||||||
|
class LandCoverAI100(LandCoverAI):
|
||||||
|
"""Subset of LandCoverAI containing only 100 images.
|
||||||
|
|
||||||
|
Intended for tutorials and demonstrations, not for benchmarking.
|
||||||
|
|
||||||
|
Maintains the same file structure, classes, and train-val-test split.
|
||||||
|
|
||||||
|
.. versionadded:: 0.7
|
||||||
|
"""
|
||||||
|
|
||||||
|
url = 'https://huggingface.co/datasets/torchgeo/landcoverai/resolve/5cdf9299bd6c1232506cf79373df01f6e6596b50/landcoverai100.zip'
|
||||||
|
filename = 'landcoverai100.zip'
|
||||||
|
md5 = '66eb33b5a0cabb631836ce0a4eafb7cd'
|
||||||
|
|
Загрузка…
Ссылка в новой задаче