Move DataModules to torchgeo.datamodules (#321)

* Move DataModules to torchgeo.datamodules

* Clean up local imports
Adam J. Stewart 2021-12-23 20:10:50 -06:00 committed by GitHub
Parent 5a57d6c9a3
Commit cbebc1e0db
No known key found for this signature
GPG key ID: 4AEE18F83AFDEB23
100 changed files with 3978 additions and 3500 deletions

docs/api/datamodules.rst (new file, 105 lines)

@ -0,0 +1,105 @@
torchgeo.datamodules
====================
.. module:: torchgeo.datamodules
Geospatial DataModules
----------------------
Chesapeake Bay High-Resolution Land Cover Project
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
.. autoclass:: ChesapeakeCVPRDataModule
National Agriculture Imagery Program (NAIP)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
.. autoclass:: NAIPChesapeakeDataModule
Non-geospatial DataModules
--------------------------
BigEarthNet
^^^^^^^^^^^
.. autoclass:: BigEarthNetDataModule
Cars Overhead With Context (COWC)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
.. autoclass:: COWCCountingDataModule
ETCI2021 Flood Detection
^^^^^^^^^^^^^^^^^^^^^^^^
.. autoclass:: ETCI2021DataModule
EuroSAT
^^^^^^^
.. autoclass:: EuroSATDataModule
FAIR1M (Fine-grAined object recognItion in high-Resolution imagery)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
.. autoclass:: FAIR1MDataModule
LandCover.ai (Land Cover from Aerial Imagery)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
.. autoclass:: LandCoverAIDataModule
LoveDA (Land-cOVEr Domain Adaptive semantic segmentation)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
.. autoclass:: LoveDADataModule
NASA Marine Debris
^^^^^^^^^^^^^^^^^^
.. autoclass:: NASAMarineDebrisDataModule
OSCD (Onera Satellite Change Detection)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
.. autoclass:: OSCDDataModule
Potsdam
^^^^^^^
.. autoclass:: Potsdam2DDataModule
RESISC45 (Remote Sensing Image Scene Classification)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
.. autoclass:: RESISC45DataModule
SEN12MS
^^^^^^^
.. autoclass:: SEN12MSDataModule
So2Sat
^^^^^^
.. autoclass:: So2SatDataModule
Tropical Cyclone Wind Estimation Competition
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
.. autoclass:: CycloneDataModule
UC Merced
^^^^^^^^^
.. autoclass:: UCMercedDataModule
Vaihingen
^^^^^^^^^
.. autoclass:: Vaihingen2DDataModule
xView2
^^^^^^
.. autoclass:: XView2DataModule


@ -31,7 +31,6 @@ Chesapeake Bay High-Resolution Land Cover Project
.. autoclass:: ChesapeakeVA
.. autoclass:: ChesapeakeWV
.. autoclass:: ChesapeakeCVPR
.. autoclass:: ChesapeakeCVPRDataModule
Cropland Data Layer (CDL)
^^^^^^^^^^^^^^^^^^^^^^^^^
@ -57,7 +56,6 @@ National Agriculture Imagery Program (NAIP)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
.. autoclass:: NAIP
.. autoclass:: NAIPChesapeakeDataModule
Sentinel
^^^^^^^^
@ -86,7 +84,6 @@ BigEarthNet
^^^^^^^^^^^
.. autoclass:: BigEarthNet
.. autoclass:: BigEarthNetDataModule
Cars Overhead With Context (COWC)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
@ -94,7 +91,6 @@ Cars Overhead With Context (COWC)
.. autoclass:: COWC
.. autoclass:: COWCCounting
.. autoclass:: COWCDetection
.. autoclass:: COWCCountingDataModule
CV4A Kenya Crop Type Competition
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
@ -105,19 +101,16 @@ ETCI2021 Flood Detection
^^^^^^^^^^^^^^^^^^^^^^^^
.. autoclass:: ETCI2021
.. autoclass:: ETCI2021DataModule
EuroSAT
^^^^^^^^^^^^^^^^^^^^^^^^
^^^^^^^
.. autoclass:: EuroSAT
.. autoclass:: EuroSATDataModule
FAIR1M (Fine-grAined object recognItion in high-Resolution imagery)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
.. autoclass:: FAIR1M
.. autoclass:: FAIR1MDataModule
GID-15 (Gaofen Image Dataset)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
@ -133,7 +126,6 @@ LandCover.ai (Land Cover from Aerial Imagery)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
.. autoclass:: LandCoverAI
.. autoclass:: LandCoverAIDataModule
LEVIR-CD+ (LEVIR Change Detection +)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
@ -144,19 +136,16 @@ LoveDA (Land-cOVEr Domain Adaptive semantic segmentation)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
.. autoclass:: LoveDA
.. autoclass:: LoveDADataModule
NASA Marine Debris
^^^^^^^^^^^^^^^^^^
.. autoclass:: NASAMarineDebris
.. autoclass:: NASAMarineDebrisDataModule
OSCD (Onera Satellite Change Detection)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
.. autoclass:: OSCD
.. autoclass:: OSCDDataModule
PatternNet
^^^^^^^^^^
@ -167,13 +156,11 @@ Potsdam
^^^^^^^
.. autoclass:: Potsdam2D
.. autoclass:: Potsdam2DDataModule
RESISC45 (Remote Sensing Image Scene Classification)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
.. autoclass:: RESISC45
.. autoclass:: RESISC45DataModule
Seasonal Contrast
^^^^^^^^^^^^^^^^^
@ -184,13 +171,11 @@ SEN12MS
^^^^^^^
.. autoclass:: SEN12MS
.. autoclass:: SEN12MSDataModule
So2Sat
^^^^^^
.. autoclass:: So2Sat
.. autoclass:: So2SatDataModule
SpaceNet
^^^^^^^^
@ -206,30 +191,26 @@ Tropical Cyclone Wind Estimation Competition
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
.. autoclass:: TropicalCycloneWindEstimation
.. autoclass:: CycloneDataModule
UC Merced
^^^^^^^^^
.. autoclass:: UCMerced
Vaihingen
^^^^^^^^^
.. autoclass:: Vaihingen2D
.. autoclass:: Vaihingen2DDataModule
NWPU VHR-10
^^^^^^^^^^^
.. autoclass:: VHR10
UC Merced
^^^^^^^^^
.. autoclass:: UCMerced
.. autoclass:: UCMercedDataModule
xView2
^^^^^^
.. autoclass:: XView2
.. autoclass:: XView2DataModule
ZueriCrop
^^^^^^^^^


@ -15,6 +15,7 @@ torchgeo
:maxdepth: 2
:caption: Package Reference
api/datamodules
api/datasets
api/models
api/samplers


@ -11,7 +11,7 @@ import os
import pytorch_lightning as pl
import torch
from torchgeo.datasets import ChesapeakeCVPRDataModule
from torchgeo.datamodules import ChesapeakeCVPRDataModule
from torchgeo.trainers.chesapeake import ChesapeakeCVPRSegmentationTask
ALL_TEST_SPLITS = [["de-val"], ["pa-test"], ["ny-test"], ["pa-test", "ny-test"]]


@ -73,7 +73,7 @@ strict_equality = true
[tool.pydocstyle]
convention = "google"
match_dir = "(datasets|models|samplers|torchgeo|trainers|transforms)"
match_dir = "(datamodules|datasets|models|samplers|torchgeo|trainers|transforms)"
[tool.pytest.ini_options]
# Skip slow tests by default


@ -0,0 +1,2 @@
# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT License.


@ -0,0 +1,32 @@
# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT License.
import os
import pytest
from _pytest.fixtures import SubRequest
from torchgeo.datamodules import BigEarthNetDataModule
class TestBigEarthNetDataModule:
@pytest.fixture(scope="class", params=["s1", "s2", "all"])
def datamodule(self, request: SubRequest) -> BigEarthNetDataModule:
bands = request.param
root = os.path.join("tests", "data", "bigearthnet")
num_classes = 19
batch_size = 1
num_workers = 0
dm = BigEarthNetDataModule(root, bands, num_classes, batch_size, num_workers)
dm.prepare_data()
dm.setup()
return dm
def test_train_dataloader(self, datamodule: BigEarthNetDataModule) -> None:
next(iter(datamodule.train_dataloader()))
def test_val_dataloader(self, datamodule: BigEarthNetDataModule) -> None:
next(iter(datamodule.val_dataloader()))
def test_test_dataloader(self, datamodule: BigEarthNetDataModule) -> None:
next(iter(datamodule.test_dataloader()))


@ -0,0 +1,52 @@
# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT License.
import os
import pytest
import torch
from _pytest.fixtures import SubRequest
from torchgeo.datamodules import ChesapeakeCVPRDataModule
class TestChesapeakeCVPRDataModule:
@pytest.fixture(scope="class", params=[5, 7])
def datamodule(self, request: SubRequest) -> ChesapeakeCVPRDataModule:
dm = ChesapeakeCVPRDataModule(
os.path.join("tests", "data", "chesapeake", "cvpr"),
["de-test"],
["de-test"],
["de-test"],
patch_size=32,
patches_per_tile=2,
batch_size=2,
num_workers=0,
class_set=request.param,
)
dm.prepare_data()
dm.setup()
return dm
def test_train_dataloader(self, datamodule: ChesapeakeCVPRDataModule) -> None:
next(iter(datamodule.train_dataloader()))
def test_val_dataloader(self, datamodule: ChesapeakeCVPRDataModule) -> None:
next(iter(datamodule.val_dataloader()))
def test_test_dataloader(self, datamodule: ChesapeakeCVPRDataModule) -> None:
next(iter(datamodule.test_dataloader()))
def test_nodata_check(self, datamodule: ChesapeakeCVPRDataModule) -> None:
nodata_check = datamodule.nodata_check(4)
sample = {
"image": torch.ones(1, 2, 2), # type: ignore[attr-defined]
"mask": torch.ones(2, 2), # type: ignore[attr-defined]
}
out = nodata_check(sample)
assert torch.equal( # type: ignore[attr-defined]
out["image"], torch.zeros(1, 4, 4) # type: ignore[attr-defined]
)
assert torch.equal( # type: ignore[attr-defined]
out["mask"], torch.zeros(4, 4) # type: ignore[attr-defined]
)


@ -0,0 +1,30 @@
# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT License.
import os
import pytest
from torchgeo.datamodules import COWCCountingDataModule
class TestCOWCCountingDataModule:
@pytest.fixture(scope="class")
def datamodule(self) -> COWCCountingDataModule:
root = os.path.join("tests", "data", "cowc_counting")
seed = 0
batch_size = 1
num_workers = 0
dm = COWCCountingDataModule(root, seed, batch_size, num_workers)
dm.prepare_data()
dm.setup()
return dm
def test_train_dataloader(self, datamodule: COWCCountingDataModule) -> None:
next(iter(datamodule.train_dataloader()))
def test_val_dataloader(self, datamodule: COWCCountingDataModule) -> None:
next(iter(datamodule.val_dataloader()))
def test_test_dataloader(self, datamodule: COWCCountingDataModule) -> None:
next(iter(datamodule.test_dataloader()))


@ -0,0 +1,30 @@
# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT License.
import os
import pytest
from torchgeo.datamodules import CycloneDataModule
class TestCycloneDataModule:
@pytest.fixture(scope="class")
def datamodule(self) -> CycloneDataModule:
root = os.path.join("tests", "data", "cyclone")
seed = 0
batch_size = 1
num_workers = 0
dm = CycloneDataModule(root, seed, batch_size, num_workers)
dm.prepare_data()
dm.setup()
return dm
def test_train_dataloader(self, datamodule: CycloneDataModule) -> None:
next(iter(datamodule.train_dataloader()))
def test_val_dataloader(self, datamodule: CycloneDataModule) -> None:
next(iter(datamodule.val_dataloader()))
def test_test_dataloader(self, datamodule: CycloneDataModule) -> None:
next(iter(datamodule.test_dataloader()))


@ -0,0 +1,30 @@
# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT License.
import os
import pytest
from torchgeo.datamodules import ETCI2021DataModule
class TestETCI2021DataModule:
@pytest.fixture(scope="class")
def datamodule(self) -> ETCI2021DataModule:
root = os.path.join("tests", "data", "etci2021")
seed = 0
batch_size = 2
num_workers = 0
dm = ETCI2021DataModule(root, seed, batch_size, num_workers)
dm.prepare_data()
dm.setup()
return dm
def test_train_dataloader(self, datamodule: ETCI2021DataModule) -> None:
next(iter(datamodule.train_dataloader()))
def test_val_dataloader(self, datamodule: ETCI2021DataModule) -> None:
next(iter(datamodule.val_dataloader()))
def test_test_dataloader(self, datamodule: ETCI2021DataModule) -> None:
next(iter(datamodule.test_dataloader()))


@ -0,0 +1,29 @@
# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT License.
import os
import pytest
from torchgeo.datamodules import EuroSATDataModule
class TestEuroSATDataModule:
@pytest.fixture(scope="class")
def datamodule(self) -> EuroSATDataModule:
root = os.path.join("tests", "data", "eurosat")
batch_size = 1
num_workers = 0
dm = EuroSATDataModule(root, batch_size, num_workers)
dm.prepare_data()
dm.setup()
return dm
def test_train_dataloader(self, datamodule: EuroSATDataModule) -> None:
next(iter(datamodule.train_dataloader()))
def test_val_dataloader(self, datamodule: EuroSATDataModule) -> None:
next(iter(datamodule.val_dataloader()))
def test_test_dataloader(self, datamodule: EuroSATDataModule) -> None:
next(iter(datamodule.test_dataloader()))


@ -0,0 +1,30 @@
# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT License.
import os
import pytest
from torchgeo.datamodules import FAIR1MDataModule
class TestFAIR1MDataModule:
@pytest.fixture(scope="class", params=[True, False])
def datamodule(self) -> FAIR1MDataModule:
root = os.path.join("tests", "data", "fair1m")
batch_size = 2
num_workers = 0
dm = FAIR1MDataModule(
root, batch_size, num_workers, val_split_pct=0.33, test_split_pct=0.33
)
dm.setup()
return dm
def test_train_dataloader(self, datamodule: FAIR1MDataModule) -> None:
next(iter(datamodule.train_dataloader()))
def test_val_dataloader(self, datamodule: FAIR1MDataModule) -> None:
next(iter(datamodule.val_dataloader()))
def test_test_dataloader(self, datamodule: FAIR1MDataModule) -> None:
next(iter(datamodule.test_dataloader()))


@ -0,0 +1,29 @@
# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT License.
import os
import pytest
from torchgeo.datamodules import LandCoverAIDataModule
class TestLandCoverAIDataModule:
@pytest.fixture(scope="class")
def datamodule(self) -> LandCoverAIDataModule:
root = os.path.join("tests", "data", "landcoverai")
batch_size = 2
num_workers = 0
dm = LandCoverAIDataModule(root, batch_size, num_workers)
dm.prepare_data()
dm.setup()
return dm
def test_train_dataloader(self, datamodule: LandCoverAIDataModule) -> None:
next(iter(datamodule.train_dataloader()))
def test_val_dataloader(self, datamodule: LandCoverAIDataModule) -> None:
next(iter(datamodule.val_dataloader()))
def test_test_dataloader(self, datamodule: LandCoverAIDataModule) -> None:
next(iter(datamodule.test_dataloader()))


@ -0,0 +1,34 @@
# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT License.
import os
import pytest
from torchgeo.datamodules import LoveDADataModule
class TestLoveDADataModule:
@pytest.fixture(scope="class")
def datamodule(self) -> LoveDADataModule:
root = os.path.join("tests", "data", "loveda")
batch_size = 2
num_workers = 0
scene = ["rural", "urban"]
dm = LoveDADataModule(
root_dir=root, scene=scene, batch_size=batch_size, num_workers=num_workers
)
dm.prepare_data()
dm.setup()
return dm
def test_train_dataloader(self, datamodule: LoveDADataModule) -> None:
next(iter(datamodule.train_dataloader()))
def test_val_dataloader(self, datamodule: LoveDADataModule) -> None:
next(iter(datamodule.val_dataloader()))
def test_test_dataloader(self, datamodule: LoveDADataModule) -> None:
next(iter(datamodule.test_dataloader()))


@ -0,0 +1,32 @@
# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT License.
import os
import pytest
from torchgeo.datamodules import NAIPChesapeakeDataModule
class TestNAIPChesapeakeDataModule:
@pytest.fixture(scope="class")
def datamodule(self) -> NAIPChesapeakeDataModule:
dm = NAIPChesapeakeDataModule(
os.path.join("tests", "data", "naip"),
os.path.join("tests", "data", "chesapeake", "BAYWIDE"),
batch_size=2,
num_workers=0,
)
dm.patch_size = 32
dm.prepare_data()
dm.setup()
return dm
def test_train_dataloader(self, datamodule: NAIPChesapeakeDataModule) -> None:
next(iter(datamodule.train_dataloader()))
def test_val_dataloader(self, datamodule: NAIPChesapeakeDataModule) -> None:
next(iter(datamodule.val_dataloader()))
def test_test_dataloader(self, datamodule: NAIPChesapeakeDataModule) -> None:
next(iter(datamodule.test_dataloader()))


@ -0,0 +1,33 @@
# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT License.
import os
import pytest
from torchgeo.datamodules import NASAMarineDebrisDataModule
class TestNASAMarineDebrisDataModule:
@pytest.fixture(scope="class")
def datamodule(self) -> NASAMarineDebrisDataModule:
root = os.path.join("tests", "data", "nasa_marine_debris")
batch_size = 2
num_workers = 0
val_split_pct = 0.3
test_split_pct = 0.3
dm = NASAMarineDebrisDataModule(
root, batch_size, num_workers, val_split_pct, test_split_pct
)
dm.prepare_data()
dm.setup()
return dm
def test_train_dataloader(self, datamodule: NASAMarineDebrisDataModule) -> None:
next(iter(datamodule.train_dataloader()))
def test_val_dataloader(self, datamodule: NASAMarineDebrisDataModule) -> None:
next(iter(datamodule.val_dataloader()))
def test_test_dataloader(self, datamodule: NASAMarineDebrisDataModule) -> None:
next(iter(datamodule.test_dataloader()))


@ -0,0 +1,62 @@
# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT License.
import os
import pytest
from _pytest.fixtures import SubRequest
from torchgeo.datamodules import OSCDDataModule
class TestOSCDDataModule:
@pytest.fixture(scope="class", params=zip(["all", "rgb"], [0.0, 0.5]))
def datamodule(self, request: SubRequest) -> OSCDDataModule:
bands, val_split_pct = request.param
patch_size = (2, 2)
num_patches_per_tile = 2
root = os.path.join("tests", "data", "oscd")
batch_size = 1
num_workers = 0
dm = OSCDDataModule(
root,
bands,
batch_size,
num_workers,
val_split_pct,
patch_size,
num_patches_per_tile,
)
dm.prepare_data()
dm.setup()
return dm
def test_train_dataloader(self, datamodule: OSCDDataModule) -> None:
sample = next(iter(datamodule.train_dataloader()))
assert sample["image"].shape[-2:] == sample["mask"].shape[-2:] == (2, 2)
assert sample["image"].shape[0] == sample["mask"].shape[0] == 2
if datamodule.bands == "all":
assert sample["image"].shape[1] == 26
else:
assert sample["image"].shape[1] == 6
def test_val_dataloader(self, datamodule: OSCDDataModule) -> None:
sample = next(iter(datamodule.val_dataloader()))
if datamodule.val_split_pct > 0.0:
assert (
sample["image"].shape[-2:] == sample["mask"].shape[-2:] == (1280, 1280)
)
assert sample["image"].shape[0] == sample["mask"].shape[0] == 1
if datamodule.bands == "all":
assert sample["image"].shape[1] == 26
else:
assert sample["image"].shape[1] == 6
def test_test_dataloader(self, datamodule: OSCDDataModule) -> None:
sample = next(iter(datamodule.test_dataloader()))
assert sample["image"].shape[-2:] == sample["mask"].shape[-2:] == (1280, 1280)
assert sample["image"].shape[0] == sample["mask"].shape[0] == 1
if datamodule.bands == "all":
assert sample["image"].shape[1] == 26
else:
assert sample["image"].shape[1] == 6
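Note: the channel counts asserted above follow from OSCD being a bi-temporal change-detection dataset; assuming the two acquisition dates are concatenated along the channel axis, "all" yields 2 x 13 = 26 Sentinel-2 bands and "rgb" yields 2 x 3 = 6:

    # Expected channel counts for stacked image pairs (assumption: the two
    # acquisition dates are concatenated along the channel dimension).
    SENTINEL2_BANDS = 13
    assert 2 * SENTINEL2_BANDS == 26  # bands="all"
    assert 2 * 3 == 6                 # bands="rgb"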


@ -0,0 +1,33 @@
# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT License.
import os
import pytest
from _pytest.fixtures import SubRequest
from torchgeo.datamodules import Potsdam2DDataModule
class TestPotsdam2DDataModule:
@pytest.fixture(scope="class", params=[0.0, 0.5])
def datamodule(self, request: SubRequest) -> Potsdam2DDataModule:
root = os.path.join("tests", "data", "potsdam")
batch_size = 1
num_workers = 0
val_split_size = request.param
dm = Potsdam2DDataModule(
root, batch_size, num_workers, val_split_pct=val_split_size
)
dm.prepare_data()
dm.setup()
return dm
def test_train_dataloader(self, datamodule: Potsdam2DDataModule) -> None:
next(iter(datamodule.train_dataloader()))
def test_val_dataloader(self, datamodule: Potsdam2DDataModule) -> None:
next(iter(datamodule.val_dataloader()))
def test_test_dataloader(self, datamodule: Potsdam2DDataModule) -> None:
next(iter(datamodule.test_dataloader()))


@ -0,0 +1,29 @@
# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT License.
import os
import pytest
from torchgeo.datamodules import RESISC45DataModule
class TestRESISC45DataModule:
@pytest.fixture(scope="class")
def datamodule(self) -> RESISC45DataModule:
root = os.path.join("tests", "data", "resisc45")
batch_size = 2
num_workers = 0
dm = RESISC45DataModule(root, batch_size, num_workers)
dm.prepare_data()
dm.setup()
return dm
def test_train_dataloader(self, datamodule: RESISC45DataModule) -> None:
next(iter(datamodule.train_dataloader()))
def test_val_dataloader(self, datamodule: RESISC45DataModule) -> None:
next(iter(datamodule.val_dataloader()))
def test_test_dataloader(self, datamodule: RESISC45DataModule) -> None:
next(iter(datamodule.test_dataloader()))


@ -0,0 +1,32 @@
# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT License.
import os
import pytest
from _pytest.fixtures import SubRequest
from torchgeo.datamodules import SEN12MSDataModule
class TestSEN12MSDataModule:
@pytest.fixture(scope="class", params=["all", "s1", "s2-all", "s2-reduced"])
def datamodule(self, request: SubRequest) -> SEN12MSDataModule:
root = os.path.join("tests", "data", "sen12ms")
seed = 0
bands = request.param
batch_size = 1
num_workers = 0
dm = SEN12MSDataModule(root, seed, bands, batch_size, num_workers)
dm.prepare_data()
dm.setup()
return dm
def test_train_dataloader(self, datamodule: SEN12MSDataModule) -> None:
next(iter(datamodule.train_dataloader()))
def test_val_dataloader(self, datamodule: SEN12MSDataModule) -> None:
next(iter(datamodule.val_dataloader()))
def test_test_dataloader(self, datamodule: SEN12MSDataModule) -> None:
next(iter(datamodule.test_dataloader()))


@ -0,0 +1,33 @@
# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT License.
import os
import pytest
from _pytest.fixtures import SubRequest
from torchgeo.datamodules import So2SatDataModule
pytest.importorskip("h5py")
class TestSo2SatDataModule:
@pytest.fixture(scope="class", params=zip([True, False], ["rgb", "s2"]))
def datamodule(self, request: SubRequest) -> So2SatDataModule:
unsupervised_mode, bands = request.param
root = os.path.join("tests", "data", "so2sat")
batch_size = 2
num_workers = 0
dm = So2SatDataModule(root, batch_size, num_workers, bands, unsupervised_mode)
dm.prepare_data()
dm.setup()
return dm
def test_train_dataloader(self, datamodule: So2SatDataModule) -> None:
next(iter(datamodule.train_dataloader()))
def test_val_dataloader(self, datamodule: So2SatDataModule) -> None:
next(iter(datamodule.val_dataloader()))
def test_test_dataloader(self, datamodule: So2SatDataModule) -> None:
next(iter(datamodule.test_dataloader()))


@ -0,0 +1,29 @@
# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT License.
import os
import pytest
from torchgeo.datamodules import UCMercedDataModule
class TestUCMercedDataModule:
@pytest.fixture(scope="class")
def datamodule(self) -> UCMercedDataModule:
root = os.path.join("tests", "data", "ucmerced")
batch_size = 2
num_workers = 0
dm = UCMercedDataModule(root, batch_size, num_workers)
dm.prepare_data()
dm.setup()
return dm
def test_train_dataloader(self, datamodule: UCMercedDataModule) -> None:
next(iter(datamodule.train_dataloader()))
def test_val_dataloader(self, datamodule: UCMercedDataModule) -> None:
next(iter(datamodule.val_dataloader()))
def test_test_dataloader(self, datamodule: UCMercedDataModule) -> None:
next(iter(datamodule.test_dataloader()))


@ -0,0 +1,25 @@
# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT License.
import torch
from torch.utils.data import TensorDataset
from torchgeo.datamodules.utils import dataset_split
def test_dataset_split() -> None:
num_samples = 24
x = torch.ones(num_samples, 5) # type: ignore[attr-defined]
y = torch.randint(low=0, high=2, size=(num_samples,)) # type: ignore[attr-defined]
ds = TensorDataset(x, y)
# Test only train/val set split
train_ds, val_ds = dataset_split(ds, val_pct=1 / 2)
assert len(train_ds) == num_samples // 2
assert len(val_ds) == num_samples // 2
# Test train/val/test set split
train_ds, val_ds, test_ds = dataset_split(ds, val_pct=1 / 3, test_pct=1 / 3)
assert len(train_ds) == num_samples // 3
assert len(val_ds) == num_samples // 3
assert len(test_ds) == num_samples // 3
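Note: the helper under test is small enough to sketch; assuming it wraps torch.utils.data.random_split, an implementation consistent with these assertions looks like (a sketch, not necessarily the exact torchgeo code):

    from typing import Any, List, Optional
    from torch.utils.data import Dataset, Subset, random_split

    def dataset_split(
        dataset: Dataset, val_pct: float, test_pct: Optional[float] = None
    ) -> List[Subset[Any]]:
        # Split into train/val, or train/val/test when test_pct is given.
        if test_pct is None:
            val_length = int(len(dataset) * val_pct)
            train_length = len(dataset) - val_length
            return random_split(dataset, [train_length, val_length])
        else:
            val_length = int(len(dataset) * val_pct)
            test_length = int(len(dataset) * test_pct)
            train_length = len(dataset) - val_length - test_length
            return random_split(dataset, [train_length, val_length, test_length])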


@ -0,0 +1,33 @@
# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT License.
import os
import pytest
from _pytest.fixtures import SubRequest
from torchgeo.datamodules import Vaihingen2DDataModule
class TestVaihingen2DDataModule:
@pytest.fixture(scope="class", params=[0.0, 0.5])
def datamodule(self, request: SubRequest) -> Vaihingen2DDataModule:
root = os.path.join("tests", "data", "vaihingen")
batch_size = 1
num_workers = 0
val_split_size = request.param
dm = Vaihingen2DDataModule(
root, batch_size, num_workers, val_split_pct=val_split_size
)
dm.prepare_data()
dm.setup()
return dm
def test_train_dataloader(self, datamodule: Vaihingen2DDataModule) -> None:
next(iter(datamodule.train_dataloader()))
def test_val_dataloader(self, datamodule: Vaihingen2DDataModule) -> None:
next(iter(datamodule.val_dataloader()))
def test_test_dataloader(self, datamodule: Vaihingen2DDataModule) -> None:
next(iter(datamodule.test_dataloader()))


@ -0,0 +1,33 @@
# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT License.
import os
import pytest
from _pytest.fixtures import SubRequest
from torchgeo.datamodules import XView2DataModule
class TestXView2DataModule:
@pytest.fixture(scope="class", params=[0.0, 0.5])
def datamodule(self, request: SubRequest) -> XView2DataModule:
root = os.path.join("tests", "data", "xview2")
batch_size = 1
num_workers = 0
val_split_size = request.param
dm = XView2DataModule(
root, batch_size, num_workers, val_split_pct=val_split_size
)
dm.prepare_data()
dm.setup()
return dm
def test_train_dataloader(self, datamodule: XView2DataModule) -> None:
next(iter(datamodule.train_dataloader()))
def test_val_dataloader(self, datamodule: XView2DataModule) -> None:
next(iter(datamodule.val_dataloader()))
def test_test_dataloader(self, datamodule: XView2DataModule) -> None:
next(iter(datamodule.test_dataloader()))


@ -13,7 +13,7 @@ from _pytest.fixtures import SubRequest
from _pytest.monkeypatch import MonkeyPatch
import torchgeo.datasets.utils
from torchgeo.datasets import BigEarthNet, BigEarthNetDataModule
from torchgeo.datasets import BigEarthNet
def download_url(url: str, root: str, *args: str, **kwargs: str) -> None:
@ -148,26 +148,3 @@ class TestBigEarthNet:
"to automaticaly download the dataset."
with pytest.raises(RuntimeError, match=err):
BigEarthNet(str(tmp_path))
class TestBigEarthNetDataModule:
@pytest.fixture(scope="class", params=["s1", "s2", "all"])
def datamodule(self, request: SubRequest) -> BigEarthNetDataModule:
bands = request.param
root = os.path.join("tests", "data", "bigearthnet")
num_classes = 19
batch_size = 1
num_workers = 0
dm = BigEarthNetDataModule(root, bands, num_classes, batch_size, num_workers)
dm.prepare_data()
dm.setup()
return dm
def test_train_dataloader(self, datamodule: BigEarthNetDataModule) -> None:
next(iter(datamodule.train_dataloader()))
def test_val_dataloader(self, datamodule: BigEarthNetDataModule) -> None:
next(iter(datamodule.val_dataloader()))
def test_test_dataloader(self, datamodule: BigEarthNetDataModule) -> None:
next(iter(datamodule.test_dataloader()))


@ -19,7 +19,6 @@ from torchgeo.datasets import (
BoundingBox,
Chesapeake13,
ChesapeakeCVPR,
ChesapeakeCVPRDataModule,
IntersectionDataset,
UnionDataset,
)
@ -179,45 +178,3 @@ class TestChesapeakeCVPR:
IndexError, match="query: .* spans multiple tiles which is not valid"
):
ds[dataset.bounds]
class TestChesapeakeCVPRDataModule:
@pytest.fixture(scope="class", params=[5, 7])
def datamodule(self, request: SubRequest) -> ChesapeakeCVPRDataModule:
dm = ChesapeakeCVPRDataModule(
os.path.join("tests", "data", "chesapeake", "cvpr"),
["de-test"],
["de-test"],
["de-test"],
patch_size=32,
patches_per_tile=2,
batch_size=2,
num_workers=0,
class_set=request.param,
)
dm.prepare_data()
dm.setup()
return dm
def test_train_dataloader(self, datamodule: ChesapeakeCVPRDataModule) -> None:
next(iter(datamodule.train_dataloader()))
def test_val_dataloader(self, datamodule: ChesapeakeCVPRDataModule) -> None:
next(iter(datamodule.val_dataloader()))
def test_test_dataloader(self, datamodule: ChesapeakeCVPRDataModule) -> None:
next(iter(datamodule.test_dataloader()))
def test_nodata_check(self, datamodule: ChesapeakeCVPRDataModule) -> None:
nodata_check = datamodule.nodata_check(4)
sample = {
"image": torch.ones(1, 2, 2), # type: ignore[attr-defined]
"mask": torch.ones(2, 2), # type: ignore[attr-defined]
}
out = nodata_check(sample)
assert torch.equal( # type: ignore[attr-defined]
out["image"], torch.zeros(1, 4, 4) # type: ignore[attr-defined]
)
assert torch.equal( # type: ignore[attr-defined]
out["mask"], torch.zeros(4, 4) # type: ignore[attr-defined]
)


@ -14,7 +14,7 @@ from _pytest.monkeypatch import MonkeyPatch
from torch.utils.data import ConcatDataset
import torchgeo.datasets.utils
from torchgeo.datasets import COWCCounting, COWCCountingDataModule, COWCDetection
from torchgeo.datasets import COWCCounting, COWCDetection
from torchgeo.datasets.cowc import COWC
@ -148,25 +148,3 @@ class TestCOWCDetection:
def test_not_downloaded(self, tmp_path: Path) -> None:
with pytest.raises(RuntimeError, match="Dataset not found or corrupted."):
COWCDetection(str(tmp_path))
class TestCOWCCountingDataModule:
@pytest.fixture(scope="class")
def datamodule(self) -> COWCCountingDataModule:
root = os.path.join("tests", "data", "cowc_counting")
seed = 0
batch_size = 1
num_workers = 0
dm = COWCCountingDataModule(root, seed, batch_size, num_workers)
dm.prepare_data()
dm.setup()
return dm
def test_train_dataloader(self, datamodule: COWCCountingDataModule) -> None:
next(iter(datamodule.train_dataloader()))
def test_val_dataloader(self, datamodule: COWCCountingDataModule) -> None:
next(iter(datamodule.val_dataloader()))
def test_test_dataloader(self, datamodule: COWCCountingDataModule) -> None:
next(iter(datamodule.test_dataloader()))


@ -15,7 +15,7 @@ from _pytest.fixtures import SubRequest
from _pytest.monkeypatch import MonkeyPatch
from torch.utils.data import ConcatDataset
from torchgeo.datasets import CycloneDataModule, TropicalCycloneWindEstimation
from torchgeo.datasets import TropicalCycloneWindEstimation
class Dataset:
@ -103,25 +103,3 @@ class TestTropicalCycloneWindEstimation:
)
dataset.plot(sample)
plt.close()
class TestCycloneDataModule:
@pytest.fixture(scope="class")
def datamodule(self) -> CycloneDataModule:
root = os.path.join("tests", "data", "cyclone")
seed = 0
batch_size = 1
num_workers = 0
dm = CycloneDataModule(root, seed, batch_size, num_workers)
dm.prepare_data()
dm.setup()
return dm
def test_train_dataloader(self, datamodule: CycloneDataModule) -> None:
next(iter(datamodule.train_dataloader()))
def test_val_dataloader(self, datamodule: CycloneDataModule) -> None:
next(iter(datamodule.val_dataloader()))
def test_test_dataloader(self, datamodule: CycloneDataModule) -> None:
next(iter(datamodule.test_dataloader()))


@ -14,7 +14,7 @@ from _pytest.fixtures import SubRequest
from _pytest.monkeypatch import MonkeyPatch
import torchgeo.datasets.utils
from torchgeo.datasets import ETCI2021, ETCI2021DataModule
from torchgeo.datasets import ETCI2021
def download_url(url: str, root: str, *args: str) -> None:
@ -95,25 +95,3 @@ class TestETCI2021:
x["prediction"] = x["mask"][0].clone()
dataset.plot(x)
plt.close()
class TestETCI2021DataModule:
@pytest.fixture(scope="class")
def datamodule(self) -> ETCI2021DataModule:
root = os.path.join("tests", "data", "etci2021")
seed = 0
batch_size = 2
num_workers = 0
dm = ETCI2021DataModule(root, seed, batch_size, num_workers)
dm.prepare_data()
dm.setup()
return dm
def test_train_dataloader(self, datamodule: ETCI2021DataModule) -> None:
next(iter(datamodule.train_dataloader()))
def test_val_dataloader(self, datamodule: ETCI2021DataModule) -> None:
next(iter(datamodule.val_dataloader()))
def test_test_dataloader(self, datamodule: ETCI2021DataModule) -> None:
next(iter(datamodule.test_dataloader()))


@ -15,7 +15,7 @@ from _pytest.monkeypatch import MonkeyPatch
from torch.utils.data import ConcatDataset
import torchgeo.datasets.utils
from torchgeo.datasets import EuroSAT, EuroSATDataModule
from torchgeo.datasets import EuroSAT
def download_url(url: str, root: str, *args: str, **kwargs: str) -> None:
@ -100,24 +100,3 @@ class TestEuroSAT:
x["prediction"] = x["label"].clone()
dataset.plot(x)
plt.close()
class TestEuroSATDataModule:
@pytest.fixture(scope="class")
def datamodule(self) -> EuroSATDataModule:
root = os.path.join("tests", "data", "eurosat")
batch_size = 1
num_workers = 0
dm = EuroSATDataModule(root, batch_size, num_workers)
dm.prepare_data()
dm.setup()
return dm
def test_train_dataloader(self, datamodule: EuroSATDataModule) -> None:
next(iter(datamodule.train_dataloader()))
def test_val_dataloader(self, datamodule: EuroSATDataModule) -> None:
next(iter(datamodule.val_dataloader()))
def test_test_dataloader(self, datamodule: EuroSATDataModule) -> None:
next(iter(datamodule.test_dataloader()))


@ -12,7 +12,7 @@ import torch
import torch.nn as nn
from _pytest.monkeypatch import MonkeyPatch
from torchgeo.datasets import FAIR1M, FAIR1MDataModule
from torchgeo.datasets import FAIR1M
class TestFAIR1M:
@ -73,25 +73,3 @@ class TestFAIR1M:
x["prediction_boxes"] = x["boxes"].clone()
dataset.plot(x)
plt.close()
class TestFAIR1MDataModule:
@pytest.fixture(scope="class", params=[True, False])
def datamodule(self) -> FAIR1MDataModule:
root = os.path.join("tests", "data", "fair1m")
batch_size = 2
num_workers = 0
dm = FAIR1MDataModule(
root, batch_size, num_workers, val_split_pct=0.33, test_split_pct=0.33
)
dm.setup()
return dm
def test_train_dataloader(self, datamodule: FAIR1MDataModule) -> None:
next(iter(datamodule.train_dataloader()))
def test_val_dataloader(self, datamodule: FAIR1MDataModule) -> None:
next(iter(datamodule.val_dataloader()))
def test_test_dataloader(self, datamodule: FAIR1MDataModule) -> None:
next(iter(datamodule.test_dataloader()))


@ -15,7 +15,7 @@ from _pytest.monkeypatch import MonkeyPatch
from torch.utils.data import ConcatDataset
import torchgeo.datasets.utils
from torchgeo.datasets import LandCoverAI, LandCoverAIDataModule
from torchgeo.datasets import LandCoverAI
def download_url(url: str, root: str, *args: str) -> None:
@ -78,24 +78,3 @@ class TestLandCoverAI:
x["prediction"] = x["mask"].clone()
dataset.plot(x)
plt.close()
class TestLandCoverAIDataModule:
@pytest.fixture(scope="class")
def datamodule(self) -> LandCoverAIDataModule:
root = os.path.join("tests", "data", "landcoverai")
batch_size = 2
num_workers = 0
dm = LandCoverAIDataModule(root, batch_size, num_workers)
dm.prepare_data()
dm.setup()
return dm
def test_train_dataloader(self, datamodule: LandCoverAIDataModule) -> None:
next(iter(datamodule.train_dataloader()))
def test_val_dataloader(self, datamodule: LandCoverAIDataModule) -> None:
next(iter(datamodule.val_dataloader()))
def test_test_dataloader(self, datamodule: LandCoverAIDataModule) -> None:
next(iter(datamodule.test_dataloader()))


@ -14,7 +14,7 @@ from _pytest.fixtures import SubRequest
from _pytest.monkeypatch import MonkeyPatch
import torchgeo.datasets.utils
from torchgeo.datasets import LoveDA, LoveDADataModule
from torchgeo.datasets import LoveDA
def download_url(url: str, root: str, *args: str) -> None:
@ -99,29 +99,3 @@ class TestLoveDA:
def test_plot(self, dataset: LoveDA) -> None:
dataset.plot(dataset[0], suptitle="Test")
plt.close()
class TestLoveDADataModule:
@pytest.fixture(scope="class")
def datamodule(self) -> LoveDADataModule:
root = os.path.join("tests", "data", "loveda")
batch_size = 2
num_workers = 0
scene = ["rural", "urban"]
dm = LoveDADataModule(
root_dir=root, scene=scene, batch_size=batch_size, num_workers=num_workers
)
dm.prepare_data()
dm.setup()
return dm
def test_train_dataloader(self, datamodule: LoveDADataModule) -> None:
next(iter(datamodule.train_dataloader()))
def test_val_dataloader(self, datamodule: LoveDADataModule) -> None:
next(iter(datamodule.val_dataloader()))
def test_test_dataloader(self, datamodule: LoveDADataModule) -> None:
next(iter(datamodule.test_dataloader()))


@ -12,13 +12,7 @@ import torch.nn as nn
from _pytest.monkeypatch import MonkeyPatch
from rasterio.crs import CRS
from torchgeo.datasets import (
NAIP,
BoundingBox,
IntersectionDataset,
NAIPChesapeakeDataModule,
UnionDataset,
)
from torchgeo.datasets import NAIP, BoundingBox, IntersectionDataset, UnionDataset
class TestNAIP:
@ -60,27 +54,3 @@ class TestNAIP:
IndexError, match="query: .* not found in index with bounds:"
):
dataset[query]
class TestNAIPChesapeakeDataModule:
@pytest.fixture(scope="class")
def datamodule(self) -> NAIPChesapeakeDataModule:
dm = NAIPChesapeakeDataModule(
os.path.join("tests", "data", "naip"),
os.path.join("tests", "data", "chesapeake", "BAYWIDE"),
batch_size=2,
num_workers=0,
)
dm.patch_size = 32
dm.prepare_data()
dm.setup()
return dm
def test_train_dataloader(self, datamodule: NAIPChesapeakeDataModule) -> None:
next(iter(datamodule.train_dataloader()))
def test_val_dataloader(self, datamodule: NAIPChesapeakeDataModule) -> None:
next(iter(datamodule.val_dataloader()))
def test_test_dataloader(self, datamodule: NAIPChesapeakeDataModule) -> None:
next(iter(datamodule.test_dataloader()))


@ -13,7 +13,7 @@ import torch
import torch.nn as nn
from _pytest.monkeypatch import MonkeyPatch
from torchgeo.datasets import NASAMarineDebris, NASAMarineDebrisDataModule
from torchgeo.datasets import NASAMarineDebris
class Dataset:
@ -85,28 +85,3 @@ class TestNASAMarineDebris:
x["prediction_boxes"] = x["boxes"].clone()
dataset.plot(x)
plt.close()
class TestNASAMarineDebrisDataModule:
@pytest.fixture(scope="class")
def datamodule(self) -> NASAMarineDebrisDataModule:
root = os.path.join("tests", "data", "nasa_marine_debris")
batch_size = 2
num_workers = 0
val_split_pct = 0.3
test_split_pct = 0.3
dm = NASAMarineDebrisDataModule(
root, batch_size, num_workers, val_split_pct, test_split_pct
)
dm.prepare_data()
dm.setup()
return dm
def test_train_dataloader(self, datamodule: NASAMarineDebrisDataModule) -> None:
next(iter(datamodule.train_dataloader()))
def test_val_dataloader(self, datamodule: NASAMarineDebrisDataModule) -> None:
next(iter(datamodule.val_dataloader()))
def test_test_dataloader(self, datamodule: NASAMarineDebrisDataModule) -> None:
next(iter(datamodule.test_dataloader()))


@ -16,7 +16,7 @@ from matplotlib import pyplot as plt
from torch.utils.data import ConcatDataset
import torchgeo.datasets.utils
from torchgeo.datasets import OSCD, OSCDDataModule
from torchgeo.datasets import OSCD
def download_url(url: str, root: str, *args: str, **kwargs: str) -> None:
@ -105,56 +105,3 @@ class TestOSCD:
def test_plot(self, dataset: OSCD) -> None:
dataset.plot(dataset[0], suptitle="Test")
plt.close()
class TestOSCDDataModule:
@pytest.fixture(scope="class", params=zip(["all", "rgb"], [0.0, 0.5]))
def datamodule(self, request: SubRequest) -> OSCDDataModule:
bands, val_split_pct = request.param
patch_size = (2, 2)
num_patches_per_tile = 2
root = os.path.join("tests", "data", "oscd")
batch_size = 1
num_workers = 0
dm = OSCDDataModule(
root,
bands,
batch_size,
num_workers,
val_split_pct,
patch_size,
num_patches_per_tile,
)
dm.prepare_data()
dm.setup()
return dm
def test_train_dataloader(self, datamodule: OSCDDataModule) -> None:
sample = next(iter(datamodule.train_dataloader()))
assert sample["image"].shape[-2:] == sample["mask"].shape[-2:] == (2, 2)
assert sample["image"].shape[0] == sample["mask"].shape[0] == 2
if datamodule.bands == "all":
assert sample["image"].shape[1] == 26
else:
assert sample["image"].shape[1] == 6
def test_val_dataloader(self, datamodule: OSCDDataModule) -> None:
sample = next(iter(datamodule.val_dataloader()))
if datamodule.val_split_pct > 0.0:
assert (
sample["image"].shape[-2:] == sample["mask"].shape[-2:] == (1280, 1280)
)
assert sample["image"].shape[0] == sample["mask"].shape[0] == 1
if datamodule.bands == "all":
assert sample["image"].shape[1] == 26
else:
assert sample["image"].shape[1] == 6
def test_test_dataloader(self, datamodule: OSCDDataModule) -> None:
sample = next(iter(datamodule.test_dataloader()))
assert sample["image"].shape[-2:] == sample["mask"].shape[-2:] == (1280, 1280)
assert sample["image"].shape[0] == sample["mask"].shape[0] == 1
if datamodule.bands == "all":
assert sample["image"].shape[1] == 26
else:
assert sample["image"].shape[1] == 6


@ -13,7 +13,7 @@ import torch.nn as nn
from _pytest.fixtures import SubRequest
from _pytest.monkeypatch import MonkeyPatch
from torchgeo.datasets import Potsdam2D, Potsdam2DDataModule
from torchgeo.datasets import Potsdam2D
class TestPotsdam2D:
@ -75,27 +75,3 @@ class TestPotsdam2D:
x["prediction"] = x["mask"].clone()
dataset.plot(x)
plt.close()
class TestPotsdam2DDataModule:
@pytest.fixture(scope="class", params=[0.0, 0.5])
def datamodule(self, request: SubRequest) -> Potsdam2DDataModule:
root = os.path.join("tests", "data", "potsdam")
batch_size = 1
num_workers = 0
val_split_size = request.param
dm = Potsdam2DDataModule(
root, batch_size, num_workers, val_split_pct=val_split_size
)
dm.prepare_data()
dm.setup()
return dm
def test_train_dataloader(self, datamodule: Potsdam2DDataModule) -> None:
next(iter(datamodule.train_dataloader()))
def test_val_dataloader(self, datamodule: Potsdam2DDataModule) -> None:
next(iter(datamodule.val_dataloader()))
def test_test_dataloader(self, datamodule: Potsdam2DDataModule) -> None:
next(iter(datamodule.test_dataloader()))


@ -15,7 +15,7 @@ from _pytest.fixtures import SubRequest
from _pytest.monkeypatch import MonkeyPatch
import torchgeo.datasets.utils
from torchgeo.datasets import RESISC45, RESISC45DataModule
from torchgeo.datasets import RESISC45
def download_url(url: str, root: str, *args: str, **kwargs: str) -> None:
@ -101,24 +101,3 @@ class TestRESISC45:
x["prediction"] = x["label"].clone()
dataset.plot(x)
plt.close()
class TestRESISC45DataModule:
@pytest.fixture(scope="class")
def datamodule(self) -> RESISC45DataModule:
root = os.path.join("tests", "data", "resisc45")
batch_size = 2
num_workers = 0
dm = RESISC45DataModule(root, batch_size, num_workers)
dm.prepare_data()
dm.setup()
return dm
def test_train_dataloader(self, datamodule: RESISC45DataModule) -> None:
next(iter(datamodule.train_dataloader()))
def test_val_dataloader(self, datamodule: RESISC45DataModule) -> None:
next(iter(datamodule.val_dataloader()))
def test_test_dataloader(self, datamodule: RESISC45DataModule) -> None:
next(iter(datamodule.test_dataloader()))


@ -12,7 +12,7 @@ from _pytest.fixtures import SubRequest
from _pytest.monkeypatch import MonkeyPatch
from torch.utils.data import ConcatDataset
from torchgeo.datasets import SEN12MS, SEN12MSDataModule
from torchgeo.datasets import SEN12MS
class TestSEN12MS:
@ -82,26 +82,3 @@ class TestSEN12MS:
ds = SEN12MS(root, bands=bands, checksum=False)
x = ds[0]["image"]
assert x.shape[0] == len(bands)
class TestSEN12MSDataModule:
@pytest.fixture(scope="class", params=["all", "s1", "s2-all", "s2-reduced"])
def datamodule(self, request: SubRequest) -> SEN12MSDataModule:
root = os.path.join("tests", "data", "sen12ms")
seed = 0
bands = request.param
batch_size = 1
num_workers = 0
dm = SEN12MSDataModule(root, seed, bands, batch_size, num_workers)
dm.prepare_data()
dm.setup()
return dm
def test_train_dataloader(self, datamodule: SEN12MSDataModule) -> None:
next(iter(datamodule.train_dataloader()))
def test_val_dataloader(self, datamodule: SEN12MSDataModule) -> None:
next(iter(datamodule.val_dataloader()))
def test_test_dataloader(self, datamodule: SEN12MSDataModule) -> None:
next(iter(datamodule.test_dataloader()))


@ -13,7 +13,7 @@ import torch.nn as nn
from _pytest.fixtures import SubRequest
from _pytest.monkeypatch import MonkeyPatch
from torchgeo.datasets import So2Sat, So2SatDataModule
from torchgeo.datasets import So2Sat
pytest.importorskip("h5py")
@ -91,25 +91,3 @@ class TestSo2Sat:
match="h5py is not installed and is required to use this dataset",
):
So2Sat(dataset.root)
class TestSo2SatDataModule:
@pytest.fixture(scope="class", params=zip([True, False], ["rgb", "s2"]))
def datamodule(self, request: SubRequest) -> So2SatDataModule:
unsupervised_mode, bands = request.param
root = os.path.join("tests", "data", "so2sat")
batch_size = 2
num_workers = 0
dm = So2SatDataModule(root, batch_size, num_workers, bands, unsupervised_mode)
dm.prepare_data()
dm.setup()
return dm
def test_train_dataloader(self, datamodule: So2SatDataModule) -> None:
next(iter(datamodule.train_dataloader()))
def test_val_dataloader(self, datamodule: So2SatDataModule) -> None:
next(iter(datamodule.val_dataloader()))
def test_test_dataloader(self, datamodule: So2SatDataModule) -> None:
next(iter(datamodule.test_dataloader()))


@ -15,7 +15,7 @@ from _pytest.monkeypatch import MonkeyPatch
from torch.utils.data import ConcatDataset
import torchgeo.datasets.utils
from torchgeo.datasets import UCMerced, UCMercedDataModule
from torchgeo.datasets import UCMerced
def download_url(url: str, root: str, *args: str, **kwargs: str) -> None:
@ -102,24 +102,3 @@ class TestUCMerced:
x["prediction"] = x["label"].clone()
dataset.plot(x)
plt.close()
class TestUCMercedDataModule:
@pytest.fixture(scope="class")
def datamodule(self) -> UCMercedDataModule:
root = os.path.join("tests", "data", "ucmerced")
batch_size = 2
num_workers = 0
dm = UCMercedDataModule(root, batch_size, num_workers)
dm.prepare_data()
dm.setup()
return dm
def test_train_dataloader(self, datamodule: UCMercedDataModule) -> None:
next(iter(datamodule.train_dataloader()))
def test_val_dataloader(self, datamodule: UCMercedDataModule) -> None:
next(iter(datamodule.val_dataloader()))
def test_test_dataloader(self, datamodule: UCMercedDataModule) -> None:
next(iter(datamodule.test_dataloader()))


@ -18,13 +18,11 @@ import pytest
import torch
from _pytest.monkeypatch import MonkeyPatch
from rasterio.crs import CRS
from torch.utils.data import TensorDataset
import torchgeo.datasets.utils
from torchgeo.datasets.utils import (
BoundingBox,
concat_samples,
dataset_split,
disambiguate_timestamp,
download_and_extract_archive,
download_radiant_mlhub_collection,
@ -563,24 +561,6 @@ def test_nonexisting_directory(tmp_path: Path) -> None:
assert subdir.cwd() == subdir
def test_dataset_split() -> None:
num_samples = 24
x = torch.ones(num_samples, 5) # type: ignore[attr-defined]
y = torch.randint(low=0, high=2, size=(num_samples,)) # type: ignore[attr-defined]
ds = TensorDataset(x, y)
# Test only train/val set split
train_ds, val_ds = dataset_split(ds, val_pct=1 / 2)
assert len(train_ds) == num_samples // 2
assert len(val_ds) == num_samples // 2
# Test train/val/test set split
train_ds, val_ds, test_ds = dataset_split(ds, val_pct=1 / 3, test_pct=1 / 3)
assert len(train_ds) == num_samples // 3
assert len(val_ds) == num_samples // 3
assert len(test_ds) == num_samples // 3
def test_percentile_normalization() -> None:
img = np.array([[1, 2], [98, 100]])


@ -13,7 +13,7 @@ import torch.nn as nn
from _pytest.fixtures import SubRequest
from _pytest.monkeypatch import MonkeyPatch
from torchgeo.datasets import Vaihingen2D, Vaihingen2DDataModule
from torchgeo.datasets import Vaihingen2D
class TestVaihingen2D:
@ -84,27 +84,3 @@ class TestVaihingen2D:
x["prediction"] = x["mask"].clone()
dataset.plot(x)
plt.close()
class TestVaihingen2DDataModule:
@pytest.fixture(scope="class", params=[0.0, 0.5])
def datamodule(self, request: SubRequest) -> Vaihingen2DDataModule:
root = os.path.join("tests", "data", "vaihingen")
batch_size = 1
num_workers = 0
val_split_size = request.param
dm = Vaihingen2DDataModule(
root, batch_size, num_workers, val_split_pct=val_split_size
)
dm.prepare_data()
dm.setup()
return dm
def test_train_dataloader(self, datamodule: Vaihingen2DDataModule) -> None:
next(iter(datamodule.train_dataloader()))
def test_val_dataloader(self, datamodule: Vaihingen2DDataModule) -> None:
next(iter(datamodule.val_dataloader()))
def test_test_dataloader(self, datamodule: Vaihingen2DDataModule) -> None:
next(iter(datamodule.test_dataloader()))


@ -13,7 +13,7 @@ import torch.nn as nn
from _pytest.fixtures import SubRequest
from _pytest.monkeypatch import MonkeyPatch
from torchgeo.datasets import XView2, XView2DataModule
from torchgeo.datasets import XView2
class TestXView2:
@ -95,27 +95,3 @@ class TestXView2:
x["prediction"] = x["mask"][0].clone()
dataset.plot(x)
plt.close()
class TestXView2DataModule:
@pytest.fixture(scope="class", params=[0.0, 0.5])
def datamodule(self, request: SubRequest) -> XView2DataModule:
root = os.path.join("tests", "data", "xview2")
batch_size = 1
num_workers = 0
val_split_size = request.param
dm = XView2DataModule(
root, batch_size, num_workers, val_split_pct=val_split_size
)
dm.prepare_data()
dm.setup()
return dm
def test_train_dataloader(self, datamodule: XView2DataModule) -> None:
next(iter(datamodule.train_dataloader()))
def test_val_dataloader(self, datamodule: XView2DataModule) -> None:
next(iter(datamodule.val_dataloader()))
def test_test_dataloader(self, datamodule: XView2DataModule) -> None:
next(iter(datamodule.test_dataloader()))


@ -12,7 +12,7 @@ from omegaconf import OmegaConf
from pytorch_lightning.core.lightning import LightningModule
from torchvision.models import resnet18
from torchgeo.datasets import ChesapeakeCVPRDataModule
from torchgeo.datamodules import ChesapeakeCVPRDataModule
from torchgeo.trainers import BYOLTask
from torchgeo.trainers.byol import BYOL, SimCLRAugmentation


@ -9,7 +9,7 @@ from _pytest.fixtures import SubRequest
from _pytest.monkeypatch import MonkeyPatch
from omegaconf import OmegaConf
from torchgeo.datasets import ChesapeakeCVPRDataModule
from torchgeo.datamodules import ChesapeakeCVPRDataModule
from torchgeo.trainers.chesapeake import ChesapeakeCVPRSegmentationTask
from .test_utils import FakeTrainer, mocked_log


@ -8,7 +8,7 @@ import pytest
from _pytest.monkeypatch import MonkeyPatch
from omegaconf import OmegaConf
from torchgeo.datasets import LandCoverAIDataModule
from torchgeo.datamodules import LandCoverAIDataModule
from torchgeo.trainers.landcoverai import LandCoverAISegmentationTask
from .test_utils import FakeTrainer, mocked_log


@ -8,7 +8,7 @@ import pytest
from _pytest.monkeypatch import MonkeyPatch
from omegaconf import OmegaConf
from torchgeo.datasets import NAIPChesapeakeDataModule
from torchgeo.datamodules import NAIPChesapeakeDataModule
from torchgeo.trainers.naipchesapeake import NAIPChesapeakeSegmentationTask
from .test_utils import FakeTrainer, mocked_log


@ -8,7 +8,7 @@ import pytest
from _pytest.monkeypatch import MonkeyPatch
from omegaconf import OmegaConf
from torchgeo.datasets import CycloneDataModule
from torchgeo.datamodules import CycloneDataModule
from torchgeo.trainers import RegressionTask
from .test_utils import mocked_log


@ -7,7 +7,7 @@ from typing import Any, Dict, Generator
import pytest
from _pytest.monkeypatch import MonkeyPatch
from torchgeo.datasets import RESISC45DataModule
from torchgeo.datamodules import RESISC45DataModule
from torchgeo.trainers.resisc45 import RESISC45ClassificationTask
from .test_utils import FakeTrainer, mocked_log


@ -9,7 +9,7 @@ from _pytest.fixtures import SubRequest
from _pytest.monkeypatch import MonkeyPatch
from omegaconf import OmegaConf
from torchgeo.datasets import ChesapeakeCVPRDataModule
from torchgeo.datamodules import ChesapeakeCVPRDataModule
from torchgeo.trainers import SemanticSegmentationTask
from .test_utils import FakeTrainer, mocked_log


@ -13,7 +13,7 @@ from typing import Dict, Tuple, Type
import pytorch_lightning as pl
from .datasets import (
from .datamodules import (
BigEarthNetDataModule,
ChesapeakeCVPRDataModule,
COWCCountingDataModule,


@ -0,0 +1,52 @@
# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT License.
"""TorchGeo datamodules."""
from .bigearthnet import BigEarthNetDataModule
from .chesapeake import ChesapeakeCVPRDataModule
from .cowc import COWCCountingDataModule
from .cyclone import CycloneDataModule
from .etci2021 import ETCI2021DataModule
from .eurosat import EuroSATDataModule
from .fair1m import FAIR1MDataModule
from .landcoverai import LandCoverAIDataModule
from .loveda import LoveDADataModule
from .naip import NAIPChesapeakeDataModule
from .nasa_marine_debris import NASAMarineDebrisDataModule
from .oscd import OSCDDataModule
from .potsdam import Potsdam2DDataModule
from .resisc45 import RESISC45DataModule
from .sen12ms import SEN12MSDataModule
from .so2sat import So2SatDataModule
from .ucmerced import UCMercedDataModule
from .vaihingen import Vaihingen2DDataModule
from .xview import XView2DataModule
__all__ = (
# GeoDataset
"ChesapeakeCVPRDataModule",
"NAIPChesapeakeDataModule",
# VisionDataset
"BigEarthNetDataModule",
"COWCCountingDataModule",
"ETCI2021DataModule",
"EuroSATDataModule",
"FAIR1MDataModule",
"LandCoverAIDataModule",
"LoveDADataModule",
"NASAMarineDebrisDataModule",
"OSCDDataModule",
"Potsdam2DDataModule",
"RESISC45DataModule",
"SEN12MSDataModule",
"So2SatDataModule",
"CycloneDataModule",
"UCMercedDataModule",
"Vaihingen2DDataModule",
"XView2DataModule",
)
# https://stackoverflow.com/questions/40018681
for module in __all__:
globals()[module].__module__ = "torchgeo.datamodules"
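Note: the loop above rewrites each re-exported class's __module__ so that documentation tools and repr report the public torchgeo.datamodules location rather than the private submodule. Downstream imports change accordingly, as the trainer and test diffs elsewhere in this commit show:

    # Old location (removed in this commit):
    # from torchgeo.datasets import ChesapeakeCVPRDataModule
    # New location:
    from torchgeo.datamodules import ChesapeakeCVPRDataModule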


@ -0,0 +1,178 @@
# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT License.
"""BigEarthNet datamodule."""
from typing import Any, Dict, Optional
import pytorch_lightning as pl
import torch
from torch.utils.data import DataLoader
from torchvision.transforms import Compose
from ..datasets import BigEarthNet
# https://github.com/pytorch/pytorch/issues/60979
# https://github.com/pytorch/pytorch/pull/61045
DataLoader.__module__ = "torch.utils.data"
class BigEarthNetDataModule(pl.LightningDataModule):
"""LightningDataModule implementation for the BigEarthNet dataset.
Uses the train/val/test splits from the dataset.
"""
# (VV, VH, B01, B02, B03, B04, B05, B06, B07, B08, B8A, B09, B11, B12)
# min/max band statistics computed on 100k random samples
band_mins_raw = torch.tensor( # type: ignore[attr-defined]
[-70.0, -72.0, 1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 1.0, 0.0, 0.0]
)
band_maxs_raw = torch.tensor( # type: ignore[attr-defined]
[
31.0,
35.0,
18556.0,
20528.0,
18976.0,
17874.0,
16611.0,
16512.0,
16394.0,
16672.0,
16141.0,
16097.0,
15336.0,
15203.0,
]
)
# min/max band statistics computed by percentile clipping the
# above samples to [2, 98]
band_mins = torch.tensor( # type: ignore[attr-defined]
[-48.0, -42.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0]
)
band_maxs = torch.tensor( # type: ignore[attr-defined]
[
6.0,
16.0,
9859.0,
12872.0,
13163.0,
14445.0,
12477.0,
12563.0,
12289.0,
15596.0,
12183.0,
9458.0,
5897.0,
5544.0,
]
)
def __init__(
self,
root_dir: str,
bands: str = "all",
num_classes: int = 19,
batch_size: int = 64,
num_workers: int = 0,
**kwargs: Any,
) -> None:
"""Initialize a LightningDataModule for BigEarthNet based DataLoaders.
Args:
root_dir: The ``root`` argument to pass to the BigEarthNet Dataset classes
bands: load Sentinel-1 bands, Sentinel-2, or both. one of {s1, s2, all}
num_classes: number of classes to load in target. one of {19, 43}
batch_size: The batch size to use in all created DataLoaders
num_workers: The number of workers to use in all created DataLoaders
"""
super().__init__() # type: ignore[no-untyped-call]
self.root_dir = root_dir
self.bands = bands
self.num_classes = num_classes
self.batch_size = batch_size
self.num_workers = num_workers
if bands == "all":
self.mins = self.band_mins[:, None, None]
self.maxs = self.band_maxs[:, None, None]
elif bands == "s1":
self.mins = self.band_mins[:2, None, None]
self.maxs = self.band_maxs[:2, None, None]
else:
self.mins = self.band_mins[2:, None, None]
self.maxs = self.band_maxs[2:, None, None]
def preprocess(self, sample: Dict[str, Any]) -> Dict[str, Any]:
"""Transform a single sample from the Dataset."""
sample["image"] = sample["image"].float()
sample["image"] = (sample["image"] - self.mins) / (self.maxs - self.mins)
sample["image"] = torch.clip( # type: ignore[attr-defined]
sample["image"], min=0.0, max=1.0
)
return sample
def prepare_data(self) -> None:
"""Make sure that the dataset is downloaded.
This method is only called once per run.
"""
BigEarthNet(self.root_dir, split="train", bands=self.bands, checksum=False)
def setup(self, stage: Optional[str] = None) -> None:
"""Initialize the main ``Dataset`` objects.
This method is called once per GPU per run.
"""
transforms = Compose([self.preprocess])
self.train_dataset = BigEarthNet(
self.root_dir,
split="train",
bands=self.bands,
num_classes=self.num_classes,
transforms=transforms,
)
self.val_dataset = BigEarthNet(
self.root_dir,
split="val",
bands=self.bands,
num_classes=self.num_classes,
transforms=transforms,
)
self.test_dataset = BigEarthNet(
self.root_dir,
split="test",
bands=self.bands,
num_classes=self.num_classes,
transforms=transforms,
)
def train_dataloader(self) -> DataLoader[Any]:
"""Return a DataLoader for training."""
return DataLoader(
self.train_dataset,
batch_size=self.batch_size,
num_workers=self.num_workers,
shuffle=True,
)
def val_dataloader(self) -> DataLoader[Any]:
"""Return a DataLoader for validation."""
return DataLoader(
self.val_dataset,
batch_size=self.batch_size,
num_workers=self.num_workers,
shuffle=False,
)
def test_dataloader(self) -> DataLoader[Any]:
"""Return a DataLoader for testing."""
return DataLoader(
self.test_dataset,
batch_size=self.batch_size,
num_workers=self.num_workers,
shuffle=False,
)
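A minimal usage sketch, assuming the BigEarthNet archives have already been extracted under data/bigearthnet (the path and batch size are illustrative):

dm = BigEarthNetDataModule(root_dir="data/bigearthnet", bands="s2", batch_size=32)
dm.prepare_data()
dm.setup()
batch = next(iter(dm.train_dataloader()))
# 12 Sentinel-2 channels, min/max-scaled to [0, 1] by preprocess()
print(batch["image"].shape, batch["image"].min(), batch["image"].max())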


@@ -0,0 +1,312 @@
# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT License.
"""Chesapeake Bay High-Resolution Land Cover Project datamodule."""
from typing import Any, Callable, Dict, List, Optional
import torch
import torch.nn.functional as F
from pytorch_lightning.core.datamodule import LightningDataModule
from torch import Tensor
from torch.utils.data import DataLoader
from torchvision.transforms import Compose
from ..datasets import ChesapeakeCVPR, stack_samples
from ..samplers.batch import RandomBatchGeoSampler
from ..samplers.single import GridGeoSampler
# https://github.com/pytorch/pytorch/issues/60979
# https://github.com/pytorch/pytorch/pull/61045
DataLoader.__module__ = "torch.utils.data"
class ChesapeakeCVPRDataModule(LightningDataModule):
"""LightningDataModule implementation for the Chesapeake CVPR Land Cover dataset.
Uses the random splits defined per state to partition tiles into train, val,
and test sets.
"""
def __init__(
self,
root_dir: str,
train_splits: List[str],
val_splits: List[str],
test_splits: List[str],
patches_per_tile: int = 200,
patch_size: int = 256,
batch_size: int = 64,
num_workers: int = 0,
class_set: int = 7,
**kwargs: Any,
) -> None:
"""Initialize a LightningDataModule for Chesapeake CVPR based DataLoaders.
Args:
root_dir: The ``root`` argument to pass to the ChesapeakeCVPR Dataset class
train_splits: The splits used to train the model, e.g. ["ny-train"]
val_splits: The splits used to validate the model, e.g. ["ny-val"]
test_splits: The splits used to test the model, e.g. ["ny-test"]
patches_per_tile: The number of patches per tile to sample
patch_size: The size of each patch in pixels (test patches will be twice
this size)
batch_size: The batch size to use in all created DataLoaders
num_workers: The number of workers to use in all created DataLoaders
class_set: The high-resolution land cover class set to use - 5 or 7
"""
super().__init__() # type: ignore[no-untyped-call]
for state in train_splits + val_splits + test_splits:
assert state in ChesapeakeCVPR.splits
assert class_set in [5, 7]
self.root_dir = root_dir
self.train_splits = train_splits
self.val_splits = val_splits
self.test_splits = test_splits
self.layers = ["naip-new", "lc"]
self.patches_per_tile = patches_per_tile
self.patch_size = patch_size
# This is a rough estimate of how large of a patch we will need to sample in
# EPSG:3857 in order to guarantee a large enough patch in the local CRS.
self.original_patch_size = int(patch_size * 2.0)
self.batch_size = batch_size
self.num_workers = num_workers
self.class_set = class_set
def pad_to(
self, size: int = 512, image_value: int = 0, mask_value: int = 0
) -> Callable[[Dict[str, Tensor]], Dict[str, Tensor]]:
"""Returns a function to perform a padding transform on a single sample.
Args:
size: output image size
image_value: value to pad image with
mask_value: value to pad mask with
Returns:
function to perform padding
"""
def pad_inner(sample: Dict[str, Tensor]) -> Dict[str, Tensor]:
_, height, width = sample["image"].shape
assert height <= size and width <= size
height_pad = size - height
width_pad = size - width
# See https://pytorch.org/docs/stable/generated/torch.nn.functional.pad.html
# for a description of the format of the padding tuple
sample["image"] = F.pad(
sample["image"],
(0, width_pad, 0, height_pad),
mode="constant",
value=image_value,
)
sample["mask"] = F.pad(
sample["mask"],
(0, width_pad, 0, height_pad),
mode="constant",
value=mask_value,
)
return sample
return pad_inner
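# A note on the padding tuple above (an explanatory sketch, not part of the
# original code): F.pad's pairs start with the last dimension, so for a
# (C, H, W) tensor, (0, width_pad, 0, height_pad) pads only the right edge
# and the bottom edge. For example:
#   F.pad(torch.ones(3, 100, 90), (0, 38, 0, 28)).shape -> (3, 128, 128)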
def center_crop(
self, size: int = 512
) -> Callable[[Dict[str, Tensor]], Dict[str, Tensor]]:
"""Returns a function to perform a center crop transform on a single sample.
Args:
size: output image size
Returns:
function to perform center crop
"""
def center_crop_inner(sample: Dict[str, Tensor]) -> Dict[str, Tensor]:
_, height, width = sample["image"].shape
y1 = (height - size) // 2
x1 = (width - size) // 2
sample["image"] = sample["image"][:, y1 : y1 + size, x1 : x1 + size]
sample["mask"] = sample["mask"][:, y1 : y1 + size, x1 : x1 + size]
return sample
return center_crop_inner
def preprocess(self, sample: Dict[str, Any]) -> Dict[str, Any]:
"""Preprocesses a single sample.
Args:
sample: sample dictionary containing image and mask
Returns:
preprocessed sample
"""
sample["image"] = sample["image"] / 255.0
sample["mask"] = sample["mask"]
sample["mask"] = sample["mask"].squeeze()
if self.class_set == 5:
sample["mask"][sample["mask"] == 5] = 4
sample["mask"][sample["mask"] == 6] = 4
sample["image"] = sample["image"].float()
sample["mask"] = sample["mask"].long()
return sample
def nodata_check(
self, size: int = 512
) -> Callable[[Dict[str, Tensor]], Dict[str, Tensor]]:
"""Returns a function to check for nodata or mis-sized input.
Args:
size: output image size
Returns:
function to check for nodata values
"""
def nodata_check_inner(sample: Dict[str, Tensor]) -> Dict[str, Tensor]:
num_channels, height, width = sample["image"].shape
if height < size or width < size:
sample["image"] = torch.zeros( # type: ignore[attr-defined]
(num_channels, size, size)
)
sample["mask"] = torch.zeros((size, size)) # type: ignore[attr-defined]
return sample
return nodata_check_inner
def prepare_data(self) -> None:
"""Confirms that the dataset is downloaded on the local node.
This method is called once per node, while :func:`setup` is called once per GPU.
"""
ChesapeakeCVPR(
self.root_dir,
splits=self.train_splits,
layers=self.layers,
transforms=None,
download=False,
checksum=False,
)
def setup(self, stage: Optional[str] = None) -> None:
"""Create the train/val/test splits based on the original Dataset objects.
The splits should be done here vs. in :func:`__init__` per the docs:
https://pytorch-lightning.readthedocs.io/en/latest/extensions/datamodules.html#setup.
Args:
stage: stage to set up
"""
train_transforms = Compose(
[
self.center_crop(self.patch_size),
self.nodata_check(self.patch_size),
self.preprocess,
]
)
val_transforms = Compose(
[
self.center_crop(self.patch_size),
self.nodata_check(self.patch_size),
self.preprocess,
]
)
test_transforms = Compose(
[
self.pad_to(self.original_patch_size, image_value=0, mask_value=0),
self.preprocess,
]
)
self.train_dataset = ChesapeakeCVPR(
self.root_dir,
splits=self.train_splits,
layers=self.layers,
transforms=train_transforms,
download=False,
checksum=False,
)
self.val_dataset = ChesapeakeCVPR(
self.root_dir,
splits=self.val_splits,
layers=self.layers,
transforms=val_transforms,
download=False,
checksum=False,
)
self.test_dataset = ChesapeakeCVPR(
self.root_dir,
splits=self.test_splits,
layers=self.layers,
transforms=test_transforms,
download=False,
checksum=False,
)
def train_dataloader(self) -> DataLoader[Any]:
"""Return a DataLoader for training.
Returns:
training data loader
"""
sampler = RandomBatchGeoSampler(
self.train_dataset,
size=self.original_patch_size,
batch_size=self.batch_size,
length=self.patches_per_tile * len(self.train_dataset),
)
return DataLoader(
self.train_dataset,
batch_sampler=sampler,
num_workers=self.num_workers,
collate_fn=stack_samples,
)
def val_dataloader(self) -> DataLoader[Any]:
"""Return a DataLoader for validation.
Returns:
validation data loader
"""
sampler = GridGeoSampler(
self.val_dataset,
size=self.original_patch_size,
stride=self.original_patch_size,
)
return DataLoader(
self.val_dataset,
batch_size=self.batch_size,
sampler=sampler,
num_workers=self.num_workers,
collate_fn=stack_samples,
)
def test_dataloader(self) -> DataLoader[Any]:
"""Return a DataLoader for testing.
Returns:
testing data loader
"""
sampler = GridGeoSampler(
self.test_dataset,
size=self.original_patch_size,
stride=self.original_patch_size,
)
return DataLoader(
self.test_dataset,
batch_size=self.batch_size,
sampler=sampler,
num_workers=self.num_workers,
collate_fn=stack_samples,
)
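A minimal configuration sketch (paths and split names are illustrative; valid split names come from ``ChesapeakeCVPR.splits``):

dm = ChesapeakeCVPRDataModule(
    root_dir="data/chesapeake",
    train_splits=["ny-train"],
    val_splits=["ny-val"],
    test_splits=["ny-test"],
    patch_size=256,
    batch_size=16,
    class_set=7,
)
dm.prepare_data()
dm.setup()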


@@ -0,0 +1,123 @@
# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT License.
"""COWC datamodule."""
from typing import Any, Dict, Optional
import pytorch_lightning as pl
from torch import Generator # type: ignore[attr-defined]
from torch.utils.data import DataLoader, random_split
from ..datasets import COWCCounting
# https://github.com/pytorch/pytorch/issues/60979
# https://github.com/pytorch/pytorch/pull/61045
DataLoader.__module__ = "torch.utils.data"
class COWCCountingDataModule(pl.LightningDataModule):
"""LightningDataModule implementation for the COWC Counting dataset."""
def __init__(
self,
root_dir: str,
seed: int,
batch_size: int = 64,
num_workers: int = 0,
**kwargs: Any,
) -> None:
"""Initialize a LightningDataModule for COWC Counting based DataLoaders.
Args:
root_dir: The ``root`` argument to pass to the COWCCounting Dataset class
seed: The seed value to use when doing the dataset random_split
batch_size: The batch size to use in all created DataLoaders
num_workers: The number of workers to use in all created DataLoaders
"""
super().__init__() # type: ignore[no-untyped-call]
self.root_dir = root_dir
self.seed = seed
self.batch_size = batch_size
self.num_workers = num_workers
def custom_transform(self, sample: Dict[str, Any]) -> Dict[str, Any]:
"""Transform a single sample from the Dataset.
Args:
sample: dictionary containing image and target
Returns:
preprocessed sample
"""
sample["image"] = sample["image"] / 255.0 # scale to [0, 1]
sample["label"] = sample["label"].float()
return sample
def prepare_data(self) -> None:
"""Initialize the main ``Dataset`` objects for use in :func:`setup`.
This includes optionally downloading the dataset. This is done once per node,
while :func:`setup` is done once per GPU.
"""
COWCCounting(self.root_dir, download=False)
def setup(self, stage: Optional[str] = None) -> None:
"""Create the train/val/test splits based on the original Dataset objects.
The splits should be done here vs. in :func:`__init__` per the docs:
https://pytorch-lightning.readthedocs.io/en/latest/extensions/datamodules.html#setup.
Args:
stage: stage to set up
"""
train_val_dataset = COWCCounting(
self.root_dir, split="train", transforms=self.custom_transform
)
self.test_dataset = COWCCounting(
self.root_dir, split="test", transforms=self.custom_transform
)
self.train_dataset, self.val_dataset = random_split(
train_val_dataset,
[len(train_val_dataset) - len(self.test_dataset), len(self.test_dataset)],
generator=Generator().manual_seed(self.seed),
)
def train_dataloader(self) -> DataLoader[Any]:
"""Return a DataLoader for training.
Returns:
training data loader
"""
return DataLoader(
self.train_dataset,
batch_size=self.batch_size,
num_workers=self.num_workers,
shuffle=True,
)
def val_dataloader(self) -> DataLoader[Any]:
"""Return a DataLoader for validation.
Returns:
validation data loader
"""
return DataLoader(
self.val_dataset,
batch_size=self.batch_size,
num_workers=self.num_workers,
shuffle=False,
)
def test_dataloader(self) -> DataLoader[Any]:
"""Return a DataLoader for testing.
Returns:
testing data loader
"""
return DataLoader(
self.test_dataset,
batch_size=self.batch_size,
num_workers=self.num_workers,
shuffle=False,
)
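Because :func:`setup` seeds ``random_split``, the train/val partition is reproducible across runs. A small sketch, under the assumption that the data is already on disk:

dm_a = COWCCountingDataModule(root_dir="data/cowc", seed=42)
dm_b = COWCCountingDataModule(root_dir="data/cowc", seed=42)
dm_a.setup()
dm_b.setup()
# Identical seeds yield identical index partitions:
assert list(dm_a.train_dataset.indices) == list(dm_b.train_dataset.indices)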


@@ -0,0 +1,171 @@
# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT License.
"""Tropical Cyclone Wind Estimation Competition datamodule."""
from typing import Any, Dict, Optional
import pytorch_lightning as pl
import torch
from sklearn.model_selection import GroupShuffleSplit
from torch.utils.data import DataLoader, Subset
from ..datasets import TropicalCycloneWindEstimation
# https://github.com/pytorch/pytorch/issues/60979
# https://github.com/pytorch/pytorch/pull/61045
DataLoader.__module__ = "torch.utils.data"
class CycloneDataModule(pl.LightningDataModule):
"""LightningDataModule implementation for the NASA Cyclone dataset.
Implements 80/20 train/val splits based on hurricane storm ids.
See :func:`setup` for more details.
"""
def __init__(
self,
root_dir: str,
seed: int,
batch_size: int = 64,
num_workers: int = 0,
api_key: Optional[str] = None,
**kwargs: Any,
) -> None:
"""Initialize a LightningDataModule for NASA Cyclone based DataLoaders.
Args:
root_dir: The ``root`` argument to pass to the
TropicalCycloneWindEstimation Dataset class
seed: The seed value to use when doing the sklearn based GroupShuffleSplit
batch_size: The batch size to use in all created DataLoaders
num_workers: The number of workers to use in all created DataLoaders
api_key: The RadiantEarth MLHub API key to use if the dataset needs to be
downloaded
"""
super().__init__() # type: ignore[no-untyped-call]
self.root_dir = root_dir
self.seed = seed
self.batch_size = batch_size
self.num_workers = num_workers
self.api_key = api_key
def custom_transform(self, sample: Dict[str, Any]) -> Dict[str, Any]:
"""Transform a single sample from the Dataset.
Args:
sample: dictionary containing image and target
Returns:
preprocessed sample
"""
sample["image"] = sample["image"] / 255.0 # scale to [0,1]
sample["image"] = (
sample["image"].unsqueeze(0).repeat(3, 1, 1)
) # convert to 3 channel
sample["label"] = torch.as_tensor( # type: ignore[attr-defined]
sample["label"]
).float()
return sample
def prepare_data(self) -> None:
"""Initialize the main ``Dataset`` objects for use in :func:`setup`.
This includes optionally downloading the dataset. This is done once per node,
while :func:`setup` is done once per GPU.
"""
TropicalCycloneWindEstimation(
self.root_dir,
split="train",
transforms=self.custom_transform,
download=self.api_key is not None,
api_key=self.api_key,
)
def setup(self, stage: Optional[str] = None) -> None:
"""Create the train/val/test splits based on the original Dataset objects.
The splits should be done here vs. in :func:`__init__` per the docs:
https://pytorch-lightning.readthedocs.io/en/latest/extensions/datamodules.html#setup.
We split samples between train/val by the ``storm_id`` property. I.e. all
samples with the same ``storm_id`` value will be either in the train or the val
split. This is important to test one type of generalizability: given a new
storm, can we predict its wind speed? The test set, however, contains *some*
storms from the training set (specifically, the latter parts of the storms) as
well as some novel storms.
Args:
stage: stage to set up
"""
self.all_train_dataset = TropicalCycloneWindEstimation(
self.root_dir,
split="train",
transforms=self.custom_transform,
download=False,
)
self.all_test_dataset = TropicalCycloneWindEstimation(
self.root_dir,
split="test",
transforms=self.custom_transform,
download=False,
)
storm_ids = []
for item in self.all_train_dataset.collection:
storm_id = item["href"].split("/")[0].split("_")[-2]
storm_ids.append(storm_id)
train_indices, val_indices = next(
GroupShuffleSplit(test_size=0.2, n_splits=2, random_state=self.seed).split(
storm_ids, groups=storm_ids
)
)
self.train_dataset = Subset(self.all_train_dataset, train_indices)
self.val_dataset = Subset(self.all_train_dataset, val_indices)
self.test_dataset = Subset(
self.all_test_dataset, range(len(self.all_test_dataset))
)
def train_dataloader(self) -> DataLoader[Any]:
"""Return a DataLoader for training.
Returns:
training data loader
"""
return DataLoader(
self.train_dataset,
batch_size=self.batch_size,
num_workers=self.num_workers,
shuffle=True,
)
def val_dataloader(self) -> DataLoader[Any]:
"""Return a DataLoader for validation.
Returns:
validation data loader
"""
return DataLoader(
self.val_dataset,
batch_size=self.batch_size,
num_workers=self.num_workers,
shuffle=False,
)
def test_dataloader(self) -> DataLoader[Any]:
"""Return a DataLoader for testing.
Returns:
testing data loader
"""
return DataLoader(
self.test_dataset,
batch_size=self.batch_size,
num_workers=self.num_workers,
shuffle=False,
)
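The grouped split above keeps each ``storm_id`` entirely on one side of the train/val boundary; the mechanism can be seen in isolation with a toy example (synthetic group labels, not dataset values):

from sklearn.model_selection import GroupShuffleSplit

groups = ["storm_a"] * 3 + ["storm_b"] * 3 + ["storm_c"] * 2
train_idx, val_idx = next(
    GroupShuffleSplit(test_size=0.2, n_splits=2, random_state=0).split(
        groups, groups=groups
    )
)
# No group appears in both splits:
assert {groups[i] for i in train_idx}.isdisjoint(groups[i] for i in val_idx)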


@@ -0,0 +1,151 @@
# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT License.
"""ETCI 2021 datamodule."""
from typing import Any, Dict, Optional
import pytorch_lightning as pl
import torch
from torch import Generator # type: ignore[attr-defined]
from torch.utils.data import DataLoader, random_split
from torchvision.transforms import Normalize
from ..datasets import ETCI2021
class ETCI2021DataModule(pl.LightningDataModule):
"""LightningDataModule implementation for the ETCI2021 dataset.
Splits the existing train split from the dataset into train/val with 80/20
proportions, then uses the existing val dataset as the test data.
.. versionadded:: 0.2
"""
band_means = torch.tensor( # type: ignore[attr-defined]
[0.52253931, 0.52253931, 0.52253931, 0.61221701, 0.61221701, 0.61221701, 0]
)
band_stds = torch.tensor( # type: ignore[attr-defined]
[0.35221376, 0.35221376, 0.35221376, 0.37364622, 0.37364622, 0.37364622, 1]
)
def __init__(
self,
root_dir: str,
seed: int = 0,
batch_size: int = 64,
num_workers: int = 0,
**kwargs: Any,
) -> None:
"""Initialize a LightningDataModule for ETCI2021 based DataLoaders.
Args:
root_dir: The ``root`` argument to pass to the ETCI2021 Dataset class
seed: The seed value to use when doing the dataset random_split
batch_size: The batch size to use in all created DataLoaders
num_workers: The number of workers to use in all created DataLoaders
"""
super().__init__() # type: ignore[no-untyped-call]
self.root_dir = root_dir
self.seed = seed
self.batch_size = batch_size
self.num_workers = num_workers
self.norm = Normalize(self.band_means, self.band_stds)
def preprocess(self, sample: Dict[str, Any]) -> Dict[str, Any]:
"""Transform a single sample from the Dataset.
Notably, moves the given water mask to act as an input layer.
Args:
sample: input image dictionary
Returns:
preprocessed sample
"""
image = sample["image"]
water_mask = sample["mask"][0].unsqueeze(0)
flood_mask = sample["mask"][1]
flood_mask = (flood_mask > 0).long()
sample["image"] = torch.cat( # type: ignore[attr-defined]
[image, water_mask], dim=0
).float()
sample["image"] /= 255.0
sample["image"] = self.norm(sample["image"])
sample["mask"] = flood_mask
return sample
def prepare_data(self) -> None:
"""Make sure that the dataset is downloaded.
This method is only called once per run.
"""
ETCI2021(self.root_dir, checksum=False)
def setup(self, stage: Optional[str] = None) -> None:
"""Initialize the main ``Dataset`` objects.
This method is called once per GPU per run.
Args:
stage: stage to set up
"""
train_val_dataset = ETCI2021(
self.root_dir, split="train", transforms=self.preprocess
)
self.test_dataset = ETCI2021(
self.root_dir, split="val", transforms=self.preprocess
)
size_train_val = len(train_val_dataset)
size_train = int(0.8 * size_train_val)
size_val = size_train_val - size_train
self.train_dataset, self.val_dataset = random_split(
train_val_dataset,
[size_train, size_val],
generator=Generator().manual_seed(self.seed),
)
def train_dataloader(self) -> DataLoader[Any]:
"""Return a DataLoader for training.
Returns:
training data loader
"""
return DataLoader(
self.train_dataset,
batch_size=self.batch_size,
num_workers=self.num_workers,
shuffle=True,
)
def val_dataloader(self) -> DataLoader[Any]:
"""Return a DataLoader for validation.
Returns:
validation data loader
"""
return DataLoader(
self.val_dataset,
batch_size=self.batch_size,
num_workers=self.num_workers,
shuffle=False,
)
def test_dataloader(self) -> DataLoader[Any]:
"""Return a DataLoader for testing.
Returns:
testing data loader
"""
return DataLoader(
self.test_dataset,
batch_size=self.batch_size,
num_workers=self.num_workers,
shuffle=False,
)
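The net effect of ``preprocess`` on tensor shapes, assuming the dataset yields a 6-channel VV/VH composite image and a 2-channel (water, flood) mask, consistent with the 7-entry normalization statistics above:

# before: sample["image"]: (6, H, W), sample["mask"]: (2, H, W)
# after:  sample["image"]: (7, H, W)  float, normalized, water mask appended
#         sample["mask"]:  (H, W)     long, binary flood labels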


@@ -0,0 +1,148 @@
# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT License.
"""EuroSAT datamodule."""
from typing import Any, Dict, Optional
import pytorch_lightning as pl
import torch
from torch.utils.data import DataLoader
from torchvision.transforms import Compose, Normalize
from ..datasets import EuroSAT
class EuroSATDataModule(pl.LightningDataModule):
"""LightningDataModule implementation for the EuroSAT dataset.
Uses the train/val/test splits from the dataset.
.. versionadded:: 0.2
"""
band_means = torch.tensor( # type: ignore[attr-defined]
[
1354.40546513,
1118.24399958,
1042.92983953,
947.62620298,
1199.47283961,
1999.79090914,
2369.22292565,
2296.82608323,
732.08340178,
12.11327804,
1819.01027855,
1118.92391149,
2594.14080798,
]
)
band_stds = torch.tensor( # type: ignore[attr-defined]
[
245.71762908,
333.00778264,
395.09249139,
593.75055589,
566.4170017,
861.18399006,
1086.63139075,
1117.98170791,
404.91978886,
4.77584468,
1002.58768311,
761.30323499,
1231.58581042,
]
)
def __init__(
self, root_dir: str, batch_size: int = 64, num_workers: int = 0, **kwargs: Any
) -> None:
"""Initialize a LightningDataModule for EuroSAT based DataLoaders.
Args:
root_dir: The ``root`` argument to pass to the EuroSAT Dataset class
batch_size: The batch size to use in all created DataLoaders
num_workers: The number of workers to use in all created DataLoaders
"""
super().__init__() # type: ignore[no-untyped-call]
self.root_dir = root_dir
self.batch_size = batch_size
self.num_workers = num_workers
self.norm = Normalize(self.band_means, self.band_stds)
def preprocess(self, sample: Dict[str, Any]) -> Dict[str, Any]:
"""Transform a single sample from the Dataset.
Args:
sample: input image dictionary
Returns:
preprocessed sample
"""
sample["image"] = sample["image"].float()
sample["image"] = self.norm(sample["image"])
return sample
def prepare_data(self) -> None:
"""Make sure that the dataset is downloaded.
This method is only called once per run.
"""
EuroSAT(self.root_dir)
def setup(self, stage: Optional[str] = None) -> None:
"""Initialize the main ``Dataset`` objects.
This method is called once per GPU per run.
Args:
stage: stage to set up
"""
transforms = Compose([self.preprocess])
self.train_dataset = EuroSAT(self.root_dir, "train", transforms=transforms)
self.val_dataset = EuroSAT(self.root_dir, "val", transforms=transforms)
self.test_dataset = EuroSAT(self.root_dir, "test", transforms=transforms)
def train_dataloader(self) -> DataLoader[Any]:
"""Return a DataLoader for training.
Returns:
training data loader
"""
return DataLoader(
self.train_dataset,
batch_size=self.batch_size,
num_workers=self.num_workers,
shuffle=True,
)
def val_dataloader(self) -> DataLoader[Any]:
"""Return a DataLoader for validation.
Returns:
validation data loader
"""
return DataLoader(
self.val_dataset,
batch_size=self.batch_size,
num_workers=self.num_workers,
shuffle=False,
)
def test_dataloader(self) -> DataLoader[Any]:
"""Return a DataLoader for testing.
Returns:
testing data loader
"""
return DataLoader(
self.test_dataset,
batch_size=self.batch_size,
num_workers=self.num_workers,
shuffle=False,
)
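``Normalize`` standardizes all 13 Sentinel-2 channels in one call; a quick synthetic sketch (random tensor, not real data):

norm = Normalize(EuroSATDataModule.band_means, EuroSATDataModule.band_stds)
x = torch.randn(13, 64, 64)
y = norm(x)  # same shape; channel c becomes (x[c] - mean[c]) / std[c]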


@@ -0,0 +1,132 @@
# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT License.
"""FAIR1M datamodule."""
from typing import Any, Dict, List, Optional
import pytorch_lightning as pl
import torch
from torch import Tensor
from torch.utils.data import DataLoader
from torchvision.transforms import Compose
from ..datasets import FAIR1M
from .utils import dataset_split
# https://github.com/pytorch/pytorch/issues/60979
# https://github.com/pytorch/pytorch/pull/61045
DataLoader.__module__ = "torch.utils.data"
def collate_fn(batch: List[Dict[str, Tensor]]) -> Dict[str, Any]:
"""Custom object detection collate fn to handle variable number of boxes.
Args:
batch: list of sample dicts return by dataset
Returns:
batch dict output
"""
output: Dict[str, Any] = {}
output["image"] = torch.stack([sample["image"] for sample in batch])
output["boxes"] = [sample["boxes"] for sample in batch]
return output
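# Why a custom collate_fn (explanatory note): default_collate would try to
# torch.stack the per-sample "boxes" tensors, which fails whenever two images
# contain different numbers of objects. Images share a fixed shape and are
# stacked into one tensor; boxes stay a plain Python list with one tensor per
# image, whose leading dimension is that image's object count.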
class FAIR1MDataModule(pl.LightningDataModule):
"""LightningDataModule implementation for the FAIR1M dataset."""
def __init__(
self,
root_dir: str,
batch_size: int = 64,
num_workers: int = 0,
val_split_pct: float = 0.2,
test_split_pct: float = 0.2,
**kwargs: Any,
) -> None:
"""Initialize a LightningDataModule for FAIR1M based DataLoaders.
Args:
root_dir: The ``root`` argument to pass to the FAIR1M Dataset class
batch_size: The batch size to use in all created DataLoaders
num_workers: The number of workers to use in all created DataLoaders
val_split_pct: What percentage of the dataset to use as a validation set
test_split_pct: What percentage of the dataset to use as a test set
"""
super().__init__() # type: ignore[no-untyped-call]
self.root_dir = root_dir
self.batch_size = batch_size
self.num_workers = num_workers
self.val_split_pct = val_split_pct
self.test_split_pct = test_split_pct
def preprocess(self, sample: Dict[str, Any]) -> Dict[str, Any]:
"""Transform a single sample from the Dataset.
Args:
sample: input image dictionary
Returns:
preprocessed sample
"""
sample["image"] = sample["image"].float()
sample["image"] /= 255.0
return sample
def setup(self, stage: Optional[str] = None) -> None:
"""Initialize the main ``Dataset`` objects.
This method is called once per GPU per run.
Args:
stage: stage to set up
"""
transforms = Compose([self.preprocess])
dataset = FAIR1M(self.root_dir, transforms=transforms)
self.train_dataset, self.val_dataset, self.test_dataset = dataset_split(
dataset, val_pct=self.val_split_pct, test_pct=self.test_split_pct
)
def train_dataloader(self) -> DataLoader[Any]:
"""Return a DataLoader for training.
Returns:
training data loader
"""
return DataLoader(
self.train_dataset,
batch_size=self.batch_size,
num_workers=self.num_workers,
shuffle=True,
collate_fn=collate_fn,
)
def val_dataloader(self) -> DataLoader[Any]:
"""Return a DataLoader for validation.
Returns:
validation data loader
"""
return DataLoader(
self.val_dataset,
batch_size=self.batch_size,
num_workers=self.num_workers,
shuffle=False,
collate_fn=collate_fn,
)
def test_dataloader(self) -> DataLoader[Any]:
"""Return a DataLoader for testing.
Returns:
testing data loader
"""
return DataLoader(
self.test_dataset,
batch_size=self.batch_size,
num_workers=self.num_workers,
shuffle=False,
collate_fn=collate_fn,
)


@@ -0,0 +1,122 @@
# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT License.
"""LandCover.ai datamodule."""
from typing import Any, Dict, Optional
import pytorch_lightning as pl
from torch.utils.data import DataLoader
from ..datasets import LandCoverAI
# https://github.com/pytorch/pytorch/issues/60979
# https://github.com/pytorch/pytorch/pull/61045
DataLoader.__module__ = "torch.utils.data"
class LandCoverAIDataModule(pl.LightningDataModule):
"""LightningDataModule implementation for the LandCover.ai dataset.
Uses the train/val/test splits from the dataset.
"""
def __init__(
self, root_dir: str, batch_size: int = 64, num_workers: int = 0, **kwargs: Any
) -> None:
"""Initialize a LightningDataModule for LandCover.ai based DataLoaders.
Args:
root_dir: The ``root`` argument to pass to the LandCoverAI Dataset class
batch_size: The batch size to use in all created DataLoaders
num_workers: The number of workers to use in all created DataLoaders
"""
super().__init__() # type: ignore[no-untyped-call]
self.root_dir = root_dir
self.batch_size = batch_size
self.num_workers = num_workers
def preprocess(self, sample: Dict[str, Any]) -> Dict[str, Any]:
"""Transform a single sample from the Dataset.
Args:
sample: dictionary containing image and mask
Returns:
preprocessed sample
"""
sample["image"] = sample["image"] / 255.0
sample["image"] = sample["image"].float()
sample["mask"] = sample["mask"].float().unsqueeze(0) + 1
return sample
def prepare_data(self) -> None:
"""Make sure that the dataset is downloaded.
This method is only called once per run.
"""
_ = LandCoverAI(self.root_dir, download=False, checksum=False)
def setup(self, stage: Optional[str] = None) -> None:
"""Initialize the main ``Dataset`` objects.
This method is called once per GPU per run.
Args:
stage: stage to set up
"""
train_transforms = self.preprocess
val_test_transforms = self.preprocess
self.train_dataset = LandCoverAI(
self.root_dir, split="train", transforms=train_transforms
)
self.val_dataset = LandCoverAI(
self.root_dir, split="val", transforms=val_test_transforms
)
self.test_dataset = LandCoverAI(
self.root_dir, split="test", transforms=val_test_transforms
)
def train_dataloader(self) -> DataLoader[Any]:
"""Return a DataLoader for training.
Returns:
training data loader
"""
return DataLoader(
self.train_dataset,
batch_size=self.batch_size,
num_workers=self.num_workers,
shuffle=True,
)
def val_dataloader(self) -> DataLoader[Any]:
"""Return a DataLoader for validation.
Returns:
validation data loader
"""
return DataLoader(
self.val_dataset,
batch_size=self.batch_size,
num_workers=self.num_workers,
shuffle=False,
)
def test_dataloader(self) -> DataLoader[Any]:
"""Return a DataLoader for testing.
Returns:
testing data loader
"""
return DataLoader(
self.test_dataset,
batch_size=self.batch_size,
num_workers=self.num_workers,
shuffle=False,
)


@@ -0,0 +1,129 @@
# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT License.
"""LoveDA datamodule."""
from typing import Any, Dict, List, Optional
import pytorch_lightning as pl
from torch.utils.data import DataLoader
from ..datasets import LoveDA
# https://github.com/pytorch/pytorch/issues/60979
# https://github.com/pytorch/pytorch/pull/61045
DataLoader.__module__ = "torch.utils.data"
class LoveDADataModule(pl.LightningDataModule):
"""LightningDataModule implementation for the LoveDA dataset.
Uses the train/val/test splits from the dataset.
"""
def __init__(
self,
root_dir: str,
scene: List[str],
batch_size: int = 32,
num_workers: int = 0,
**kwargs: Any,
) -> None:
"""Initialize a LightningDataModule for LoveDA based DataLoaders.
Args:
root_dir: The ``root`` argument to pass to the LoveDA Dataset class
scene: specify whether to load only 'urban', only 'rural' or both
batch_size: The batch size to use in all created DataLoaders
num_workers: The number of workers to use in all created DataLoaders
"""
super().__init__() # type: ignore[no-untyped-call]
self.root_dir = root_dir
self.scene = scene
self.batch_size = batch_size
self.num_workers = num_workers
def preprocess(self, sample: Dict[str, Any]) -> Dict[str, Any]:
"""Transform a single sample from the Dataset.
Args:
sample: dictionary containing image and mask
Returns:
preprocessed sample
"""
sample["image"] = sample["image"] / 255.0
return sample
def prepare_data(self) -> None:
"""Make sure that the dataset is downloaded.
This method is only called once per run.
"""
_ = LoveDA(self.root_dir, scene=self.scene, download=False, checksum=False)
def setup(self, stage: Optional[str] = None) -> None:
"""Initialize the main ``Dataset`` objects.
This method is called once per GPU per run.
Args:
stage: stage to set up
"""
train_transforms = self.preprocess
val_test_transforms = self.preprocess
self.train_dataset = LoveDA(
self.root_dir, split="train", scene=self.scene, transforms=train_transforms
)
self.val_dataset = LoveDA(
self.root_dir, split="val", scene=self.scene, transforms=val_test_transforms
)
self.test_dataset = LoveDA(
self.root_dir,
split="test",
scene=self.scene,
transforms=val_test_transforms,
)
def train_dataloader(self) -> DataLoader[Any]:
"""Return a DataLoader for training.
Returns:
training data loader
"""
return DataLoader(
self.train_dataset,
batch_size=self.batch_size,
num_workers=self.num_workers,
shuffle=True,
)
def val_dataloader(self) -> DataLoader[Any]:
"""Return a DataLoader for validation.
Returns:
validation data loader
"""
return DataLoader(
self.val_dataset,
batch_size=self.batch_size,
num_workers=self.num_workers,
shuffle=False,
)
def test_dataloader(self) -> DataLoader[Any]:
"""Return a DataLoader for testing.
Returns:
testing data loader
"""
return DataLoader(
self.test_dataset,
batch_size=self.batch_size,
num_workers=self.num_workers,
shuffle=False,
)
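A minimal usage sketch (the path is illustrative); ``scene`` accepts any non-empty subset of ``["urban", "rural"]``:

dm = LoveDADataModule(root_dir="data/loveda", scene=["urban", "rural"], batch_size=16)
dm.prepare_data()
dm.setup()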


@@ -0,0 +1,161 @@
# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT License.
"""National Agriculture Imagery Program (NAIP) datamodule."""
from typing import Any, Dict, Optional
import pytorch_lightning as pl
from torch.utils.data import DataLoader
from ..datasets import NAIP, BoundingBox, Chesapeake13, stack_samples
from ..samplers.batch import RandomBatchGeoSampler
from ..samplers.single import GridGeoSampler
# https://github.com/pytorch/pytorch/issues/60979
# https://github.com/pytorch/pytorch/pull/61045
DataLoader.__module__ = "torch.utils.data"
class NAIPChesapeakeDataModule(pl.LightningDataModule):
"""LightningDataModule implementation for the NAIP and Chesapeake datasets.
Splits the shared spatial extent of the two datasets into geographic
train/val/test regions; see :func:`setup`.
"""
# TODO: tune these hyperparams
length = 1000
stride = 128
def __init__(
self,
naip_root_dir: str,
chesapeake_root_dir: str,
batch_size: int = 64,
num_workers: int = 0,
patch_size: int = 256,
**kwargs: Any,
) -> None:
"""Initialize a LightningDataModule for NAIP and Chesapeake based DataLoaders.
Args:
naip_root_dir: directory containing NAIP data
chesapeake_root_dir: directory containing Chesapeake data
batch_size: The batch size to use in all created DataLoaders
num_workers: The number of workers to use in all created DataLoaders
patch_size: size of patches to sample
"""
super().__init__() # type: ignore[no-untyped-call]
self.naip_root_dir = naip_root_dir
self.chesapeake_root_dir = chesapeake_root_dir
self.batch_size = batch_size
self.num_workers = num_workers
self.patch_size = patch_size
def naip_transform(self, sample: Dict[str, Any]) -> Dict[str, Any]:
"""Transform a single sample from the NAIP Dataset.
Args:
sample: NAIP image dictionary
Returns:
preprocessed NAIP data
"""
sample["image"] = sample["image"] / 255.0
sample["image"] = sample["image"].float()
return sample
def chesapeake_transform(self, sample: Dict[str, Any]) -> Dict[str, Any]:
"""Transform a single sample from the Chesapeake Dataset.
Args:
sample: Chesapeake mask dictionary
Returns:
preprocessed Chesapeake data
"""
sample["mask"] = sample["mask"].long()[0]
return sample
def prepare_data(self) -> None:
"""Make sure that the dataset is downloaded.
This method is only called once per run.
"""
Chesapeake13(self.chesapeake_root_dir, download=False, checksum=False)
def setup(self, stage: Optional[str] = None) -> None:
"""Initialize the main ``Dataset`` objects.
This method is called once per GPU per run.
Args:
stage: stage to set up
"""
# TODO: these transforms will be applied independently; this won't work if we
# add things like random horizontal flip
chesapeake = Chesapeake13(
self.chesapeake_root_dir, transforms=self.chesapeake_transform
)
naip = NAIP(
self.naip_root_dir,
chesapeake.crs,
chesapeake.res,
transforms=self.naip_transform,
)
self.dataset = chesapeake & naip
# TODO: figure out better train/val/test split
roi = self.dataset.bounds
midx = roi.minx + (roi.maxx - roi.minx) / 2
midy = roi.miny + (roi.maxy - roi.miny) / 2
train_roi = BoundingBox(roi.minx, midx, roi.miny, roi.maxy, roi.mint, roi.maxt)
val_roi = BoundingBox(midx, roi.maxx, roi.miny, midy, roi.mint, roi.maxt)
test_roi = BoundingBox(roi.minx, roi.maxx, midy, roi.maxy, roi.mint, roi.maxt)
self.train_sampler = RandomBatchGeoSampler(
naip, self.patch_size, self.batch_size, self.length, train_roi
)
self.val_sampler = GridGeoSampler(naip, self.patch_size, self.stride, val_roi)
self.test_sampler = GridGeoSampler(naip, self.patch_size, self.stride, test_roi)
def train_dataloader(self) -> DataLoader[Any]:
"""Return a DataLoader for training.
Returns:
training data loader
"""
return DataLoader(
self.dataset,
batch_sampler=self.train_sampler,
num_workers=self.num_workers,
collate_fn=stack_samples,
)
def val_dataloader(self) -> DataLoader[Any]:
"""Return a DataLoader for validation.
Returns:
validation data loader
"""
return DataLoader(
self.dataset,
batch_size=self.batch_size,
sampler=self.val_sampler,
num_workers=self.num_workers,
collate_fn=stack_samples,
)
def test_dataloader(self) -> DataLoader[Any]:
"""Return a DataLoader for testing.
Returns:
testing data loader
"""
return DataLoader(
self.dataset,
batch_size=self.batch_size,
sampler=self.test_sampler,
num_workers=self.num_workers,
collate_fn=stack_samples,
)
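A minimal sketch of wiring the two datasets together (paths are illustrative). Note that :func:`setup` builds ``chesapeake & naip``, an intersection dataset that only yields samples where both layers overlap:

dm = NAIPChesapeakeDataModule(
    naip_root_dir="data/naip",
    chesapeake_root_dir="data/chesapeake13",
    batch_size=32,
    patch_size=256,
)
dm.prepare_data()
dm.setup()
batch = next(iter(dm.train_dataloader()))  # co-registered "image" and "mask"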


@@ -0,0 +1,140 @@
# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT License.
"""NASA Marine Debris datamodule."""
from typing import Any, Dict, List, Optional
import pytorch_lightning as pl
import torch
from torch import Tensor
from torch.utils.data import DataLoader
from torchvision.transforms import Compose
from ..datasets import NASAMarineDebris
from .utils import dataset_split
# https://github.com/pytorch/pytorch/issues/60979
# https://github.com/pytorch/pytorch/pull/61045
DataLoader.__module__ = "torch.utils.data"
def collate_fn(batch: List[Dict[str, Tensor]]) -> Dict[str, Any]:
"""Custom object detection collate fn to handle variable boxes.
Args:
batch: list of sample dicts return by dataset
Returns:
batch dict output
"""
output: Dict[str, Any] = {}
output["image"] = torch.stack([sample["image"] for sample in batch])
output["boxes"] = [sample["boxes"] for sample in batch]
return output
class NASAMarineDebrisDataModule(pl.LightningDataModule):
"""LightningDataModule implementation for the NASA Marine Debris dataset."""
def __init__(
self,
root_dir: str,
batch_size: int = 64,
num_workers: int = 0,
val_split_pct: float = 0.2,
test_split_pct: float = 0.2,
**kwargs: Any,
) -> None:
"""Initialize a LightningDataModule for NASA Marine Debris based DataLoaders.
Args:
root_dir: The ``root`` argument to pass to the Dataset class
batch_size: The batch size to use in all created DataLoaders
num_workers: The number of workers to use in all created DataLoaders
val_split_pct: What percentage of the dataset to use as a validation set
test_split_pct: What percentage of the dataset to use as a test set
"""
super().__init__() # type: ignore[no-untyped-call]
self.root_dir = root_dir
self.batch_size = batch_size
self.num_workers = num_workers
self.val_split_pct = val_split_pct
self.test_split_pct = test_split_pct
def preprocess(self, sample: Dict[str, Any]) -> Dict[str, Any]:
"""Transform a single sample from the Dataset.
Args:
sample: input image dictionary
Returns:
preprocessed sample
"""
sample["image"] = sample["image"].float()
sample["image"] /= 255.0
return sample
def prepare_data(self) -> None:
"""Make sure that the dataset is downloaded.
This method is only called once per run.
"""
NASAMarineDebris(self.root_dir, checksum=False)
def setup(self, stage: Optional[str] = None) -> None:
"""Initialize the main ``Dataset`` objects.
This method is called once per GPU per run.
Args:
stage: stage to set up
"""
transforms = Compose([self.preprocess])
dataset = NASAMarineDebris(self.root_dir, transforms=transforms)
self.train_dataset, self.val_dataset, self.test_dataset = dataset_split(
dataset, val_pct=self.val_split_pct, test_pct=self.test_split_pct
)
def train_dataloader(self) -> DataLoader[Any]:
"""Return a DataLoader for training.
Returns:
training data loader
"""
return DataLoader(
self.train_dataset,
batch_size=self.batch_size,
num_workers=self.num_workers,
shuffle=True,
collate_fn=collate_fn,
)
def val_dataloader(self) -> DataLoader[Any]:
"""Return a DataLoader for validation.
Returns:
validation data loader
"""
return DataLoader(
self.val_dataset,
batch_size=self.batch_size,
num_workers=self.num_workers,
shuffle=False,
collate_fn=collate_fn,
)
def test_dataloader(self) -> DataLoader[Any]:
"""Return a DataLoader for testing.
Returns:
testing data loader
"""
return DataLoader(
self.test_dataset,
batch_size=self.batch_size,
num_workers=self.num_workers,
shuffle=False,
collate_fn=collate_fn,
)


@@ -0,0 +1,214 @@
# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT License.
"""OSCD datamodule."""
from typing import Any, Dict, List, Optional, Tuple
import kornia.augmentation as K
import pytorch_lightning as pl
import torch
from einops import repeat
from torch.utils.data import DataLoader, Dataset
from torch.utils.data._utils.collate import default_collate
from torchvision.transforms import Compose, Normalize
from ..datasets import OSCD
from .utils import dataset_split
class OSCDDataModule(pl.LightningDataModule):
"""LightningDataModule implementation for the OSCD dataset.
Uses the train/test splits from the dataset and further splits
the train split into train/val splits.
.. versionadded:: 0.2
"""
band_means = torch.tensor( # type: ignore[attr-defined]
[
1583.0741,
1374.3202,
1294.1616,
1325.6158,
1478.7408,
1933.0822,
2166.0608,
2076.4868,
2306.0652,
690.9814,
16.2360,
2080.3347,
1524.6930,
]
)
band_stds = torch.tensor( # type: ignore[attr-defined]
[
52.1937,
83.4168,
105.6966,
151.1401,
147.4615,
115.9289,
123.1974,
114.6483,
141.4530,
73.2758,
4.8368,
213.4821,
179.4793,
]
)
def __init__(
self,
root_dir: str,
bands: str = "all",
train_batch_size: int = 32,
num_workers: int = 0,
val_split_pct: float = 0.2,
patch_size: Tuple[int, int] = (64, 64),
num_patches_per_tile: int = 32,
**kwargs: Any,
) -> None:
"""Initialize a LightningDataModule for OSCD based DataLoaders.
Args:
root_dir: The ``root`` argument to pass to the OSCD Dataset class
bands: "rgb" or "all"
train_batch_size: The batch size used in the train DataLoader
(val_batch_size == test_batch_size == 1)
num_workers: The number of workers to use in all created DataLoaders
val_split_pct: What percentage of the dataset to use as a validation set
patch_size: Size of random patch from image and mask (height, width)
num_patches_per_tile: number of random patches per sample
"""
super().__init__() # type: ignore[no-untyped-call]
self.root_dir = root_dir
self.bands = bands
self.train_batch_size = train_batch_size
self.num_workers = num_workers
self.val_split_pct = val_split_pct
self.patch_size = patch_size
self.num_patches_per_tile = num_patches_per_tile
if bands == "rgb":
self.band_means = self.band_means[[3, 2, 1], None, None]
self.band_stds = self.band_stds[[3, 2, 1], None, None]
else:
self.band_means = self.band_means[:, None, None]
self.band_stds = self.band_stds[:, None, None]
self.norm = Normalize(self.band_means, self.band_stds)
self.rcrop = K.AugmentationSequential(
K.RandomCrop(patch_size), data_keys=["input", "mask"], same_on_batch=True
)
self.padto = K.PadTo((1280, 1280))
def preprocess(self, sample: Dict[str, Any]) -> Dict[str, Any]:
"""Transform a single sample from the Dataset."""
sample["image"] = sample["image"].float()
sample["mask"] = sample["mask"]
sample["image"] = self.norm(sample["image"])
sample["image"] = torch.flatten( # type: ignore[attr-defined]
sample["image"], 0, 1
)
return sample
def prepare_data(self) -> None:
"""Make sure that the dataset is downloaded.
This method is only called once per run.
"""
OSCD(self.root_dir, split="train", bands=self.bands, checksum=False)
def setup(self, stage: Optional[str] = None) -> None:
"""Initialize the main ``Dataset`` objects.
This method is called once per GPU per run.
"""
def n_random_crop(sample: Dict[str, Any]) -> Dict[str, Any]:
images, masks = [], []
for i in range(self.num_patches_per_tile):
mask = repeat(sample["mask"], "h w -> t h w", t=2).float()
image, mask = self.rcrop(sample["image"], mask)
mask = mask.squeeze()[0]
images.append(image.squeeze())
masks.append(mask.long())
sample["image"] = torch.stack(images)
sample["mask"] = torch.stack(masks)
return sample
def pad_to(sample: Dict[str, Any]) -> Dict[str, Any]:
sample["image"] = self.padto(sample["image"])[0]
sample["mask"] = self.padto(sample["mask"].float()).long()[0, 0]
return sample
train_transforms = Compose([self.preprocess, n_random_crop])
# for testing and validation we pad all inputs to a fixed size to avoid issues
# with the upsampling paths in encoder-decoder architectures
test_transforms = Compose([self.preprocess, pad_to])
train_dataset = OSCD(
self.root_dir, split="train", bands=self.bands, transforms=train_transforms
)
self.train_dataset: Dataset[Any]
self.val_dataset: Dataset[Any]
if self.val_split_pct > 0.0:
val_dataset = OSCD(
self.root_dir,
split="train",
bands=self.bands,
transforms=test_transforms,
)
self.train_dataset, self.val_dataset, _ = dataset_split(
train_dataset, val_pct=self.val_split_pct, test_pct=0.0
)
self.val_dataset.dataset = val_dataset
else:
self.train_dataset = train_dataset
self.val_dataset = train_dataset
self.test_dataset = OSCD(
self.root_dir, split="test", bands=self.bands, transforms=test_transforms
)
def train_dataloader(self) -> DataLoader[Any]:
"""Return a DataLoader for training."""
def collate_wrapper(batch: List[Dict[str, Any]]) -> Dict[str, Any]:
r_batch: Dict[str, Any] = default_collate( # type: ignore[no-untyped-call]
batch
)
r_batch["image"] = torch.flatten( # type: ignore[attr-defined]
r_batch["image"], 0, 1
)
r_batch["mask"] = torch.flatten( # type: ignore[attr-defined]
r_batch["mask"], 0, 1
)
return r_batch
return DataLoader(
self.train_dataset,
batch_size=self.train_batch_size,
num_workers=self.num_workers,
collate_fn=collate_wrapper,
shuffle=True,
)
def val_dataloader(self) -> DataLoader[Any]:
"""Return a DataLoader for validation."""
return DataLoader(
self.val_dataset, batch_size=1, num_workers=self.num_workers, shuffle=False
)
def test_dataloader(self) -> DataLoader[Any]:
"""Return a DataLoader for testing."""
return DataLoader(
self.test_dataset, batch_size=1, num_workers=self.num_workers, shuffle=False
)
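Putting the pieces of the train pipeline together, a shape walkthrough (assuming ``bands="all"``, i.e. 13 channels per date, and that the dataset yields the two dates stacked as (2, 13, H, W)):

# preprocess:      image (2, 13, H, W) -> flatten dims 0-1 -> (26, H, W)
# n_random_crop:   one tile -> image (num_patches_per_tile, 26, 64, 64)
#                           -> mask  (num_patches_per_tile, 64, 64)
# collate_wrapper: stacks tiles, then flattens tile and patch dimensions, so
#   the model sees (train_batch_size * num_patches_per_tile, 26, 64, 64)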


@@ -0,0 +1,121 @@
# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT License.
"""Potsdam datamodule."""
from typing import Any, Dict, Optional
import pytorch_lightning as pl
from torch.utils.data import DataLoader, Dataset
from torchvision.transforms import Compose
from ..datasets import Potsdam2D
from .utils import dataset_split
class Potsdam2DDataModule(pl.LightningDataModule):
"""LightningDataModule implementation for the Potsdam2D dataset.
Uses the train/test splits from the dataset.
.. versionadded:: 0.2
"""
def __init__(
self,
root_dir: str,
batch_size: int = 64,
num_workers: int = 0,
val_split_pct: float = 0.2,
**kwargs: Any,
) -> None:
"""Initialize a LightningDataModule for Potsdam2D based DataLoaders.
Args:
root_dir: The ``root`` argument to pass to the Potsdam2D Dataset classes
batch_size: The batch size to use in all created DataLoaders
num_workers: The number of workers to use in all created DataLoaders
val_split_pct: What percentage of the dataset to use as a validation set
"""
super().__init__() # type: ignore[no-untyped-call]
self.root_dir = root_dir
self.batch_size = batch_size
self.num_workers = num_workers
self.val_split_pct = val_split_pct
def preprocess(self, sample: Dict[str, Any]) -> Dict[str, Any]:
"""Transform a single sample from the Dataset.
Args:
sample: input image dictionary
Returns:
preprocessed sample
"""
sample["image"] = sample["image"].float()
sample["image"] /= 255.0
return sample
def setup(self, stage: Optional[str] = None) -> None:
"""Initialize the main ``Dataset`` objects.
This method is called once per GPU per run.
Args:
stage: stage to set up
"""
transforms = Compose([self.preprocess])
dataset = Potsdam2D(self.root_dir, "train", transforms=transforms)
self.train_dataset: Dataset[Any]
self.val_dataset: Dataset[Any]
if self.val_split_pct > 0.0:
self.train_dataset, self.val_dataset, _ = dataset_split(
dataset, val_pct=self.val_split_pct, test_pct=0.0
)
else:
self.train_dataset = dataset
self.val_dataset = dataset
self.test_dataset = Potsdam2D(self.root_dir, "test", transforms=transforms)
def train_dataloader(self) -> DataLoader[Any]:
"""Return a DataLoader for training.
Returns:
training data loader
"""
return DataLoader(
self.train_dataset,
batch_size=self.batch_size,
num_workers=self.num_workers,
shuffle=True,
)
def val_dataloader(self) -> DataLoader[Any]:
"""Return a DataLoader for validation.
Returns:
validation data loader
"""
return DataLoader(
self.val_dataset,
batch_size=self.batch_size,
num_workers=self.num_workers,
shuffle=False,
)
def test_dataloader(self) -> DataLoader[Any]:
"""Return a DataLoader for testing.
Returns:
testing data loader
"""
return DataLoader(
self.test_dataset,
batch_size=self.batch_size,
num_workers=self.num_workers,
shuffle=False,
)


@@ -0,0 +1,123 @@
# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT License.
"""RESISC45 datamodule."""
from typing import Any, Dict, Optional
import pytorch_lightning as pl
import torch
from torch.utils.data import DataLoader
from torchvision.transforms import Compose, Normalize
from ..datasets import RESISC45
# https://github.com/pytorch/pytorch/issues/60979
# https://github.com/pytorch/pytorch/pull/61045
DataLoader.__module__ = "torch.utils.data"
class RESISC45DataModule(pl.LightningDataModule):
"""LightningDataModule implementation for the RESISC45 dataset.
Uses the train/val/test splits from the dataset.
"""
band_means = torch.tensor( # type: ignore[attr-defined]
[0.36801773, 0.38097873, 0.343583]
)
band_stds = torch.tensor( # type: ignore[attr-defined]
[0.14540215, 0.13558227, 0.13203649]
)
def __init__(
self, root_dir: str, batch_size: int = 64, num_workers: int = 0, **kwargs: Any
) -> None:
"""Initialize a LightningDataModule for RESISC45 based DataLoaders.
Args:
root_dir: The ``root`` argument to pass to the RESISC45 Dataset class
batch_size: The batch size to use in all created DataLoaders
num_workers: The number of workers to use in all created DataLoaders
"""
super().__init__() # type: ignore[no-untyped-call]
self.root_dir = root_dir
self.batch_size = batch_size
self.num_workers = num_workers
self.norm = Normalize(self.band_means, self.band_stds)
def preprocess(self, sample: Dict[str, Any]) -> Dict[str, Any]:
"""Transform a single sample from the Dataset.
Args:
sample: input image dictionary
Returns:
preprocessed sample
"""
sample["image"] = sample["image"].float()
sample["image"] /= 255.0
sample["image"] = self.norm(sample["image"])
return sample
def prepare_data(self) -> None:
"""Make sure that the dataset is downloaded.
This method is only called once per run.
"""
RESISC45(self.root_dir, checksum=False)
def setup(self, stage: Optional[str] = None) -> None:
"""Initialize the main ``Dataset`` objects.
This method is called once per GPU per run.
Args:
stage: stage to set up
"""
transforms = Compose([self.preprocess])
self.train_dataset = RESISC45(self.root_dir, "train", transforms=transforms)
self.val_dataset = RESISC45(self.root_dir, "val", transforms=transforms)
self.test_dataset = RESISC45(self.root_dir, "test", transforms=transforms)
def train_dataloader(self) -> DataLoader[Any]:
"""Return a DataLoader for training.
Returns:
training data loader
"""
return DataLoader(
self.train_dataset,
batch_size=self.batch_size,
num_workers=self.num_workers,
shuffle=True,
)
def val_dataloader(self) -> DataLoader[Any]:
"""Return a DataLoader for validation.
Returns:
validation data loader
"""
return DataLoader(
self.val_dataset,
batch_size=self.batch_size,
num_workers=self.num_workers,
shuffle=False,
)
def test_dataloader(self) -> DataLoader[Any]:
"""Return a DataLoader for testing.
Returns:
testing data loader
"""
return DataLoader(
self.test_dataset,
batch_size=self.batch_size,
num_workers=self.num_workers,
shuffle=False,
)


@@ -0,0 +1,202 @@
# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT License.
"""SEN12MS datamodule."""
from typing import Any, Dict, Optional
import pytorch_lightning as pl
import torch
from sklearn.model_selection import GroupShuffleSplit
from torch.utils.data import DataLoader, Subset
from ..datasets import SEN12MS
# https://github.com/pytorch/pytorch/issues/60979
# https://github.com/pytorch/pytorch/pull/61045
DataLoader.__module__ = "torch.utils.data"
class SEN12MSDataModule(pl.LightningDataModule):
"""LightningDataModule implementation for the SEN12MS dataset.
Implements 80/20 geographic train/val splits and uses the test split from the
classification dataset definitions. See :func:`setup` for more details.
Uses the Simplified IGBP scheme defined in the 2020 Data Fusion Competition. See
https://arxiv.org/abs/2002.08254.
"""
#: Mapping from the IGBP class definitions to the DFC2020 classes, taken from
#: the dataloader here: https://github.com/lukasliebel/dfc2020_baseline.
DFC2020_CLASS_MAPPING = torch.tensor( # type: ignore[attr-defined]
[
0, # maps 0s to 0
1, # maps 1s to 1
1, # maps 2s to 1
1, # ...
1,
1,
2,
2,
3,
3,
4,
5,
6,
7,
6,
8,
9,
10,
]
)
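# How the mapping is applied (see custom_transform below): torch.take uses
# each mask value as an index into DFC2020_CLASS_MAPPING. For example,
# IGBP classes 4 and 12 remap to DFC2020 classes 1 and 6:
#   torch.take(DFC2020_CLASS_MAPPING, torch.tensor([4, 12])) -> tensor([1, 6])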
def __init__(
self,
root_dir: str,
seed: int,
band_set: str = "all",
batch_size: int = 64,
num_workers: int = 0,
**kwargs: Any,
) -> None:
"""Initialize a LightningDataModule for SEN12MS based DataLoaders.
Args:
root_dir: The ``root`` argument to pass to the SEN12MS Dataset class
seed: The seed value to use when doing the sklearn based GroupShuffleSplit
band_set: The subset of S1/S2 bands to use. Options are: "all",
"s1", "s2-all", and "s2-reduced" where the "s2-reduced" set includes:
B2, B3, B4, B8, B11, and B12.
batch_size: The batch size to use in all created DataLoaders
num_workers: The number of workers to use in all created DataLoaders
"""
super().__init__() # type: ignore[no-untyped-call]
assert band_set in SEN12MS.BAND_SETS.keys()
self.root_dir = root_dir
self.seed = seed
self.band_set = band_set
self.band_indices = SEN12MS.BAND_SETS[band_set]
self.batch_size = batch_size
self.num_workers = num_workers
def custom_transform(self, sample: Dict[str, Any]) -> Dict[str, Any]:
"""Transform a single sample from the Dataset.
Args:
sample: dictionary containing image and mask
Returns:
preprocessed sample
"""
sample["image"] = sample["image"].float()
if self.band_set == "all":
sample["image"][:2] = sample["image"][:2].clamp(-25, 0) / -25
sample["image"][2:] = sample["image"][2:].clamp(0, 10000) / 10000
elif self.band_set == "s1":
sample["image"][:2] = sample["image"][:2].clamp(-25, 0) / -25
else:
sample["image"][:] = sample["image"][:].clamp(0, 10000) / 10000
sample["mask"] = sample["mask"][0, :, :].long()
sample["mask"] = torch.take( # type: ignore[attr-defined]
self.DFC2020_CLASS_MAPPING, sample["mask"]
)
return sample
def setup(self, stage: Optional[str] = None) -> None:
"""Create the train/val/test splits based on the original Dataset objects.
The splits should be done here vs. in :func:`__init__` per the docs:
https://pytorch-lightning.readthedocs.io/en/latest/extensions/datamodules.html#setup.
We split samples between train and val geographically in an 80/20 proportion,
which mimics the geographic test set split.
Args:
stage: stage to set up
"""
season_to_int = {"winter": 0, "spring": 1000, "summer": 2000, "fall": 3000}
self.all_train_dataset = SEN12MS(
self.root_dir,
split="train",
bands=self.band_indices,
transforms=self.custom_transform,
checksum=False,
)
self.all_test_dataset = SEN12MS(
self.root_dir,
split="test",
bands=self.band_indices,
transforms=self.custom_transform,
checksum=False,
)
# A patch is a filename like: "ROIs{num}_{season}_s2_{scene_id}_p{patch_id}.tif"
# Each patch belongs to the scene uniquely identified by its
# (season, scene_id) tuple. Because the largest scene_id is 149, we can simply
# assign each season a large offset and represent a `unique_scene_id` as
# `season_id + scene_id`.
scenes = []
for scene_fn in self.all_train_dataset.ids:
parts = scene_fn.split("_")
season_id = season_to_int[parts[1]]
scene_id = int(parts[3])
scenes.append(season_id + scene_id)
train_indices, val_indices = next(
GroupShuffleSplit(test_size=0.2, n_splits=2, random_state=self.seed).split(
scenes, groups=scenes
)
)
self.train_dataset = Subset(self.all_train_dataset, train_indices)
self.val_dataset = Subset(self.all_train_dataset, val_indices)
self.test_dataset = Subset(
self.all_test_dataset, range(len(self.all_test_dataset))
)
def train_dataloader(self) -> DataLoader[Any]:
"""Return a DataLoader for training.
Returns:
training data loader
"""
return DataLoader(
self.train_dataset,
batch_size=self.batch_size,
num_workers=self.num_workers,
shuffle=True,
)
def val_dataloader(self) -> DataLoader[Any]:
"""Return a DataLoader for validation.
Returns:
validation data loader
"""
return DataLoader(
self.val_dataset,
batch_size=self.batch_size,
num_workers=self.num_workers,
shuffle=False,
)
def test_dataloader(self) -> DataLoader[Any]:
"""Return a DataLoader for testing.
Returns:
testing data loader
"""
return DataLoader(
self.test_dataset,
batch_size=self.batch_size,
num_workers=self.num_workers,
shuffle=False,
)
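The grouped split above can be exercised in isolation. A minimal sketch with synthetic unique scene ids (season offset plus scene id, mirroring setup() above; the ids are made up):

from sklearn.model_selection import GroupShuffleSplit

# Synthetic unique_scene_ids: season offset + scene id, as computed in setup().
scenes = [12, 12, 1012, 1012, 2007, 3007]
train_indices, val_indices = next(
    GroupShuffleSplit(test_size=0.2, n_splits=2, random_state=0).split(
        scenes, groups=scenes
    )
)
# Every patch of a given scene lands on exactly one side of the split.
assert {scenes[i] for i in train_indices}.isdisjoint(
    {scenes[i] for i in val_indices}
)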


@ -0,0 +1,225 @@
# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT License.
"""So2Sat datamodule."""
from typing import Any, Dict, Optional, cast
import pytorch_lightning as pl
import torch
from torch.utils.data import DataLoader
from torchvision.transforms import Compose
from ..datasets import So2Sat
# https://github.com/pytorch/pytorch/issues/60979
# https://github.com/pytorch/pytorch/pull/61045
DataLoader.__module__ = "torch.utils.data"
class So2SatDataModule(pl.LightningDataModule):
"""LightningDataModule implementation for the So2Sat dataset.
Uses the train/val/test splits from the dataset.
"""
band_means = torch.tensor( # type: ignore[attr-defined]
[
-3.591224256609313e-05,
-7.658561276843396e-06,
5.9373857475971184e-05,
2.5166231537121083e-05,
0.04420110659759328,
0.25761027084996196,
0.0007556743372573258,
0.0013503466830024448,
0.12375696117681859,
0.1092774636368323,
0.1010855203267882,
0.1142398616114001,
0.1592656692023089,
0.18147236008771792,
0.1745740312291377,
0.19501607349635292,
0.15428468872076637,
0.10905050699570007,
]
).reshape(18, 1, 1)
band_stds = torch.tensor( # type: ignore[attr-defined]
[
0.17555201137417686,
0.17556463274968204,
0.45998793417834255,
0.455988755730148,
2.8559909213125763,
8.324800606439833,
2.4498757382563103,
1.4647352984509094,
0.03958795985905458,
0.047778262752410296,
0.06636616706371974,
0.06358874912497474,
0.07744387147984592,
0.09101635085921553,
0.09218466562387101,
0.10164581233948201,
0.09991773043519253,
0.08780632509122865,
]
).reshape(18, 1, 1)
# this reorders the bands to put S2 RGB first, then the remainder of S2; the
# S1 bands are commented out below and thus dropped
reindex_to_rgb_first = [
10,
9,
8,
11,
12,
13,
14,
15,
16,
17,
# 0,
# 1,
# 2,
# 3,
# 4,
# 5,
# 6,
# 7,
]
def __init__(
self,
root_dir: str,
batch_size: int = 64,
num_workers: int = 0,
bands: str = "rgb",
unsupervised_mode: bool = False,
**kwargs: Any,
) -> None:
"""Initialize a LightningDataModule for So2Sat based DataLoaders.
Args:
root_dir: The ``root`` argument to pass to the So2Sat Dataset classes
batch_size: The batch size to use in all created DataLoaders
num_workers: The number of workers to use in all created DataLoaders
bands: Either "rgb" or "s2"
unsupervised_mode: Makes the train dataloader return imagery from the train,
val, and test sets
"""
super().__init__() # type: ignore[no-untyped-call]
self.root_dir = root_dir
self.batch_size = batch_size
self.num_workers = num_workers
self.bands = bands
self.unsupervised_mode = unsupervised_mode
def preprocess(self, sample: Dict[str, Any]) -> Dict[str, Any]:
"""Transform a single sample from the Dataset.
Args:
sample: dictionary containing image
Returns:
preprocessed sample
"""
# sample["image"] = (sample["image"] - self.band_means) / self.band_stds
sample["image"] = sample["image"].float()
sample["image"] = sample["image"][self.reindex_to_rgb_first, :, :]
if self.bands == "rgb":
sample["image"] = sample["image"][:3, :, :]
return sample
def prepare_data(self) -> None:
"""Make sure that the dataset is downloaded.
This method is only called once per run.
"""
So2Sat(self.root_dir, checksum=False)
def setup(self, stage: Optional[str] = None) -> None:
"""Initialize the main ``Dataset`` objects.
This method is called once per GPU per run.
Args:
stage: stage to set up
"""
train_transforms = Compose([self.preprocess])
val_test_transforms = self.preprocess
if not self.unsupervised_mode:
self.train_dataset = So2Sat(
self.root_dir, split="train", transforms=train_transforms
)
self.val_dataset = So2Sat(
self.root_dir, split="validation", transforms=val_test_transforms
)
self.test_dataset = So2Sat(
self.root_dir, split="test", transforms=val_test_transforms
)
else:
temp_train = So2Sat(
self.root_dir, split="train", transforms=train_transforms
)
self.val_dataset = So2Sat(
self.root_dir, split="validation", transforms=train_transforms
)
self.test_dataset = So2Sat(
self.root_dir, split="test", transforms=train_transforms
)
self.train_dataset = cast(
So2Sat, temp_train + self.val_dataset + self.test_dataset
)
def train_dataloader(self) -> DataLoader[Any]:
"""Return a DataLoader for training.
Returns:
training data loader
"""
return DataLoader(
self.train_dataset,
batch_size=self.batch_size,
num_workers=self.num_workers,
shuffle=True,
)
def val_dataloader(self) -> DataLoader[Any]:
"""Return a DataLoader for validation.
Returns:
validation data loader
"""
return DataLoader(
self.val_dataset,
batch_size=self.batch_size,
num_workers=self.num_workers,
shuffle=False,
)
def test_dataloader(self) -> DataLoader[Any]:
"""Return a DataLoader for testing.
Returns:
testing data loader
"""
return DataLoader(
self.test_dataset,
batch_size=self.batch_size,
num_workers=self.num_workers,
shuffle=False,
)
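In unsupervised mode the train set is built with the `+` operator, which torch `Dataset` objects overload to produce a `ConcatDataset`. A minimal sketch with stand-in tensor datasets (not So2Sat data):

import torch
from torch.utils.data import TensorDataset

train = TensorDataset(torch.zeros(10, 3))
val = TensorDataset(torch.zeros(2, 3))
test = TensorDataset(torch.zeros(3, 3))
combined = train + val + test  # a torch.utils.data.ConcatDataset
assert len(combined) == 15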


@ -0,0 +1,125 @@
# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT License.
"""UC Merced datamodule."""
from typing import Any, Dict, Optional
import pytorch_lightning as pl
import torch
import torchvision
from torch.utils.data import DataLoader
from torchvision.transforms import Compose, Normalize
from ..datasets import UCMerced
# https://github.com/pytorch/pytorch/issues/60979
# https://github.com/pytorch/pytorch/pull/61045
DataLoader.__module__ = "torch.utils.data"
class UCMercedDataModule(pl.LightningDataModule):
"""LightningDataModule implementation for the UC Merced dataset.
Uses the train/val/test splits from the dataset.
"""
band_means = torch.tensor([0, 0, 0]) # type: ignore[attr-defined]
band_stds = torch.tensor([1, 1, 1]) # type: ignore[attr-defined]
def __init__(
self, root_dir: str, batch_size: int = 64, num_workers: int = 0, **kwargs: Any
) -> None:
"""Initialize a LightningDataModule for UCMerced based DataLoaders.
Args:
root_dir: The ``root`` argument to pass to the UCMerced Dataset classes
batch_size: The batch size to use in all created DataLoaders
num_workers: The number of workers to use in all created DataLoaders
"""
super().__init__() # type: ignore[no-untyped-call]
self.root_dir = root_dir
self.batch_size = batch_size
self.num_workers = num_workers
self.norm = Normalize(self.band_means, self.band_stds)
def preprocess(self, sample: Dict[str, Any]) -> Dict[str, Any]:
"""Transform a single sample from the Dataset.
Args:
sample: dictionary containing image
Returns:
preprocessed sample
"""
sample["image"] = sample["image"].float()
sample["image"] /= 255.0
c, h, w = sample["image"].shape
if h != 256 or w != 256:
sample["image"] = torchvision.transforms.functional.resize(
sample["image"], size=(256, 256)
)
sample["image"] = self.norm(sample["image"])
return sample
def prepare_data(self) -> None:
"""Make sure that the dataset is downloaded.
This method is only called once per run.
"""
UCMerced(self.root_dir, download=False, checksum=False)
def setup(self, stage: Optional[str] = None) -> None:
"""Initialize the main ``Dataset`` objects.
This method is called once per GPU per run.
Args:
stage: stage to set up
"""
transforms = Compose([self.preprocess])
self.train_dataset = UCMerced(self.root_dir, "train", transforms=transforms)
self.val_dataset = UCMerced(self.root_dir, "val", transforms=transforms)
self.test_dataset = UCMerced(self.root_dir, "test", transforms=transforms)
def train_dataloader(self) -> DataLoader[Any]:
"""Return a DataLoader for training.
Returns:
training data loader
"""
return DataLoader(
self.train_dataset,
batch_size=self.batch_size,
num_workers=self.num_workers,
shuffle=True,
)
def val_dataloader(self) -> DataLoader[Any]:
"""Return a DataLoader for validation.
Returns:
validation data loader
"""
return DataLoader(
self.val_dataset,
batch_size=self.batch_size,
num_workers=self.num_workers,
shuffle=False,
)
def test_dataloader(self) -> DataLoader[Any]:
"""Return a DataLoader for testing.
Returns:
testing data loader
"""
return DataLoader(
self.test_dataset,
batch_size=self.batch_size,
num_workers=self.num_workers,
shuffle=False,
)
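The preprocess step resizes any image that is not exactly 256x256 (a few UC Merced tiles deviate by a pixel or two) before normalization. The resize branch in isolation, with a random stand-in tensor:

import torch
import torchvision

image = torch.rand(3, 253, 256)  # an off-size stand-in image
c, h, w = image.shape
if h != 256 or w != 256:
    image = torchvision.transforms.functional.resize(image, size=(256, 256))
assert image.shape[-2:] == (256, 256)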


@ -0,0 +1,33 @@
# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT License.
"""Common datamodule utilities."""
from typing import Any, List, Optional
from torch.utils.data import Dataset, Subset, random_split
def dataset_split(
dataset: Dataset[Any], val_pct: float, test_pct: Optional[float] = None
) -> List[Subset[Any]]:
"""Split a torch Dataset into train/val/test sets.
If ``test_pct`` is not set then only train and validation splits are returned.
Args:
dataset: dataset to be split into train/val or train/val/test subsets
val_pct: percentage of samples to be in validation set
test_pct: (Optional) percentage of samples to be in test set
Returns:
a list of the subset datasets. Either [train, val] or [train, val, test]
"""
if test_pct is None:
val_length = int(len(dataset) * val_pct)
train_length = len(dataset) - val_length
return random_split(dataset, [train_length, val_length])
else:
val_length = int(len(dataset) * val_pct)
test_length = int(len(dataset) * test_pct)
train_length = len(dataset) - (val_length + test_length)
return random_split(dataset, [train_length, val_length, test_length])
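A usage sketch for dataset_split with a throwaway TensorDataset; the sizes are arbitrary and chosen only to make the arithmetic visible:

import torch
from torch.utils.data import TensorDataset

dataset = TensorDataset(torch.arange(100).float())
train, val = dataset_split(dataset, val_pct=0.2)
assert (len(train), len(val)) == (80, 20)

train, val, test = dataset_split(dataset, val_pct=0.2, test_pct=0.1)
assert (len(train), len(val), len(test)) == (70, 20, 10)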


@ -0,0 +1,121 @@
# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT License.
"""Vaihingen datamodule."""
from typing import Any, Dict, Optional
import pytorch_lightning as pl
from torch.utils.data import DataLoader, Dataset
from torchvision.transforms import Compose
from ..datasets import Vaihingen2D
from .utils import dataset_split
class Vaihingen2DDataModule(pl.LightningDataModule):
"""LightningDataModule implementation for the Vaihingen2D dataset.
Uses the train/test splits from the dataset.
.. versionadded:: 0.2
"""
def __init__(
self,
root_dir: str,
batch_size: int = 64,
num_workers: int = 0,
val_split_pct: float = 0.2,
**kwargs: Any,
) -> None:
"""Initialize a LightningDataModule for Vaihingen2D based DataLoaders.
Args:
root_dir: The ``root`` argument to pass to the Vaihingen Dataset classes
batch_size: The batch size to use in all created DataLoaders
num_workers: The number of workers to use in all created DataLoaders
val_split_pct: What percentage of the dataset to use as a validation set
"""
super().__init__() # type: ignore[no-untyped-call]
self.root_dir = root_dir
self.batch_size = batch_size
self.num_workers = num_workers
self.val_split_pct = val_split_pct
def preprocess(self, sample: Dict[str, Any]) -> Dict[str, Any]:
"""Transform a single sample from the Dataset.
Args:
sample: input image dictionary
Returns:
preprocessed sample
"""
sample["image"] = sample["image"].float()
sample["image"] /= 255.0
return sample
def setup(self, stage: Optional[str] = None) -> None:
"""Initialize the main ``Dataset`` objects.
This method is called once per GPU per run.
Args:
stage: stage to set up
"""
transforms = Compose([self.preprocess])
dataset = Vaihingen2D(self.root_dir, "train", transforms=transforms)
self.train_dataset: Dataset[Any]
self.val_dataset: Dataset[Any]
if self.val_split_pct > 0.0:
self.train_dataset, self.val_dataset, _ = dataset_split(
dataset, val_pct=self.val_split_pct, test_pct=0.0
)
else:
self.train_dataset = dataset
self.val_dataset = dataset
self.test_dataset = Vaihingen2D(self.root_dir, "test", transforms=transforms)
def train_dataloader(self) -> DataLoader[Any]:
"""Return a DataLoader for training.
Returns:
training data loader
"""
return DataLoader(
self.train_dataset,
batch_size=self.batch_size,
num_workers=self.num_workers,
shuffle=True,
)
def val_dataloader(self) -> DataLoader[Any]:
"""Return a DataLoader for validation.
Returns:
validation data loader
"""
return DataLoader(
self.val_dataset,
batch_size=self.batch_size,
num_workers=self.num_workers,
shuffle=False,
)
def test_dataloader(self) -> DataLoader[Any]:
"""Return a DataLoader for testing.
Returns:
testing data loader
"""
return DataLoader(
self.test_dataset,
batch_size=self.batch_size,
num_workers=self.num_workers,
shuffle=False,
)
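Note that setup() calls dataset_split with test_pct=0.0, which takes the three-way branch and returns an empty third subset, so discarding it with `_` is safe. A quick check against the dataset_split helper shown earlier, using a synthetic dataset:

import torch
from torch.utils.data import TensorDataset

dataset = TensorDataset(torch.zeros(50, 1))
train, val, empty = dataset_split(dataset, val_pct=0.2, test_pct=0.0)
assert (len(train), len(val), len(empty)) == (40, 10, 0)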


@ -0,0 +1,121 @@
# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT License.
"""xView2 datamodule."""
from typing import Any, Dict, Optional
import pytorch_lightning as pl
from torch.utils.data import DataLoader, Dataset
from torchvision.transforms import Compose
from ..datasets import XView2
from .utils import dataset_split
class XView2DataModule(pl.LightningDataModule):
"""LightningDataModule implementation for the xView2 dataset.
Uses the train/test splits from the dataset, holding out ``val_split_pct`` of
the train split for validation.
.. versionadded:: 0.2
"""
def __init__(
self,
root_dir: str,
batch_size: int = 64,
num_workers: int = 0,
val_split_pct: float = 0.2,
**kwargs: Any,
) -> None:
"""Initialize a LightningDataModule for xView2 based DataLoaders.
Args:
root_dir: The ``root`` argument to pass to the xView2 Dataset classes
batch_size: The batch size to use in all created DataLoaders
num_workers: The number of workers to use in all created DataLoaders
val_split_pct: What percentage of the dataset to use as a validation set
"""
super().__init__() # type: ignore[no-untyped-call]
self.root_dir = root_dir
self.batch_size = batch_size
self.num_workers = num_workers
self.val_split_pct = val_split_pct
def preprocess(self, sample: Dict[str, Any]) -> Dict[str, Any]:
"""Transform a single sample from the Dataset.
Args:
sample: input image dictionary
Returns:
preprocessed sample
"""
sample["image"] = sample["image"].float()
sample["image"] /= 255.0
return sample
def setup(self, stage: Optional[str] = None) -> None:
"""Initialize the main ``Dataset`` objects.
This method is called once per GPU per run.
Args:
stage: stage to set up
"""
transforms = Compose([self.preprocess])
dataset = XView2(self.root_dir, "train", transforms=transforms)
self.train_dataset: Dataset[Any]
self.val_dataset: Dataset[Any]
if self.val_split_pct > 0.0:
self.train_dataset, self.val_dataset, _ = dataset_split(
dataset, val_pct=self.val_split_pct, test_pct=0.0
)
else:
self.train_dataset = dataset
self.val_dataset = dataset
self.test_dataset = XView2(self.root_dir, "test", transforms=transforms)
def train_dataloader(self) -> DataLoader[Any]:
"""Return a DataLoader for training.
Returns:
training data loader
"""
return DataLoader(
self.train_dataset,
batch_size=self.batch_size,
num_workers=self.num_workers,
shuffle=True,
)
def val_dataloader(self) -> DataLoader[Any]:
"""Return a DataLoader for validation.
Returns:
validation data loader
"""
return DataLoader(
self.val_dataset,
batch_size=self.batch_size,
num_workers=self.num_workers,
shuffle=False,
)
def test_dataloader(self) -> DataLoader[Any]:
"""Return a DataLoader for testing.
Returns:
testing data loader
"""
return DataLoader(
self.test_dataset,
batch_size=self.batch_size,
num_workers=self.num_workers,
shuffle=False,
)


@ -5,7 +5,7 @@
from .advance import ADVANCE
from .benin_cashews import BeninSmallHolderCashews
from .bigearthnet import BigEarthNet, BigEarthNetDataModule
from .bigearthnet import BigEarthNet
from .cbf import CanadianBuildingFootprints
from .cdl import CDL
from .chesapeake import (
@ -13,7 +13,6 @@ from .chesapeake import (
Chesapeake7,
Chesapeake13,
ChesapeakeCVPR,
ChesapeakeCVPRDataModule,
ChesapeakeDC,
ChesapeakeDE,
ChesapeakeMD,
@ -22,12 +21,12 @@ from .chesapeake import (
ChesapeakeVA,
ChesapeakeWV,
)
from .cowc import COWC, COWCCounting, COWCCountingDataModule, COWCDetection
from .cowc import COWC, COWCCounting, COWCDetection
from .cv4a_kenya_crop_type import CV4AKenyaCropType
from .cyclone import CycloneDataModule, TropicalCycloneWindEstimation
from .etci2021 import ETCI2021, ETCI2021DataModule
from .eurosat import EuroSAT, EuroSATDataModule
from .fair1m import FAIR1M, FAIR1MDataModule
from .cyclone import TropicalCycloneWindEstimation
from .etci2021 import ETCI2021
from .eurosat import EuroSAT
from .fair1m import FAIR1M
from .geo import (
GeoDataset,
IntersectionDataset,
@ -39,7 +38,7 @@ from .geo import (
)
from .gid15 import GID15
from .idtrees import IDTReeS
from .landcoverai import LandCoverAI, LandCoverAIDataModule
from .landcoverai import LandCoverAI
from .landsat import (
Landsat,
Landsat1,
@ -54,23 +53,23 @@ from .landsat import (
Landsat9,
)
from .levircd import LEVIRCDPlus
from .loveda import LoveDA, LoveDADataModule
from .naip import NAIP, NAIPChesapeakeDataModule
from .nasa_marine_debris import NASAMarineDebris, NASAMarineDebrisDataModule
from .loveda import LoveDA
from .naip import NAIP
from .nasa_marine_debris import NASAMarineDebris
from .nwpu import VHR10
from .oscd import OSCD, OSCDDataModule
from .oscd import OSCD
from .patternnet import PatternNet
from .potsdam import Potsdam2D, Potsdam2DDataModule
from .resisc45 import RESISC45, RESISC45DataModule
from .potsdam import Potsdam2D
from .resisc45 import RESISC45
from .seco import SeasonalContrastS2
from .sen12ms import SEN12MS, SEN12MSDataModule
from .sen12ms import SEN12MS
from .sentinel import Sentinel, Sentinel2
from .so2sat import So2Sat, So2SatDataModule
from .so2sat import So2Sat
from .spacenet import SpaceNet, SpaceNet1, SpaceNet2, SpaceNet4, SpaceNet5, SpaceNet7
from .ucmerced import UCMerced, UCMercedDataModule
from .ucmerced import UCMerced
from .utils import BoundingBox, concat_samples, merge_samples, stack_samples
from .vaihingen import Vaihingen2D, Vaihingen2DDataModule
from .xview import XView2, XView2DataModule
from .vaihingen import Vaihingen2D
from .xview import XView2
from .zuericrop import ZueriCrop
__all__ = (
@ -88,7 +87,6 @@ __all__ = (
"ChesapeakeVA",
"ChesapeakeWV",
"ChesapeakeCVPR",
"ChesapeakeCVPRDataModule",
"Landsat",
"Landsat1",
"Landsat2",
@ -101,46 +99,32 @@ __all__ = (
"Landsat8",
"Landsat9",
"NAIP",
"NAIPChesapeakeDataModule",
"Sentinel",
"Sentinel2",
# VisionDataset
"ADVANCE",
"BeninSmallHolderCashews",
"BigEarthNet",
"BigEarthNetDataModule",
"COWC",
"COWCCounting",
"COWCDetection",
"COWCCountingDataModule",
"CV4AKenyaCropType",
"ETCI2021",
"ETCI2021DataModule",
"EuroSAT",
"EuroSATDataModule",
"FAIR1M",
"FAIR1MDataModule",
"GID15",
"IDTReeS",
"LandCoverAI",
"LandCoverAIDataModule",
"LEVIRCDPlus",
"LoveDA",
"LoveDADataModule",
"NASAMarineDebris",
"NASAMarineDebrisDataModule",
"OSCD",
"OSCDDataModule",
"PatternNet",
"Potsdam2D",
"Potsdam2DDataModule",
"RESISC45",
"RESISC45DataModule",
"SeasonalContrastS2",
"SEN12MS",
"SEN12MSDataModule",
"So2Sat",
"So2SatDataModule",
"SpaceNet",
"SpaceNet1",
"SpaceNet2",
@ -148,14 +132,10 @@ __all__ = (
"SpaceNet5",
"SpaceNet7",
"TropicalCycloneWindEstimation",
"CycloneDataModule",
"UCMerced",
"UCMercedDataModule",
"Vaihingen2D",
"Vaihingen2DDataModule",
"VHR10",
"XView2",
"XView2DataModule",
"ZueriCrop",
# Base classes
"GeoDataset",


@ -6,24 +6,17 @@
import glob
import json
import os
from typing import Any, Callable, Dict, List, Optional
from typing import Callable, Dict, List, Optional
import numpy as np
import pytorch_lightning as pl
import rasterio
import torch
from rasterio.enums import Resampling
from torch import Tensor
from torch.utils.data import DataLoader
from torchvision.transforms import Compose
from .geo import VisionDataset
from .utils import download_url, extract_archive, sort_sentinel2_bands
# https://github.com/pytorch/pytorch/issues/60979
# https://github.com/pytorch/pytorch/pull/61045
DataLoader.__module__ = "torch.utils.data"
class BigEarthNet(VisionDataset):
"""BigEarthNet dataset.
@ -511,164 +504,3 @@ class BigEarthNet(VisionDataset):
"""
if not filepath.endswith(".csv"):
extract_archive(filepath)
class BigEarthNetDataModule(pl.LightningDataModule):
"""LightningDataModule implementation for the BigEarthNet dataset.
Uses the train/val/test splits from the dataset.
"""
# (VV, VH, B01, B02, B03, B04, B05, B06, B07, B08, B8A, B09, B11, B12)
# min/max band statistics computed on 100k random samples
band_mins_raw = torch.tensor( # type: ignore[attr-defined]
[-70.0, -72.0, 1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 1.0, 0.0, 0.0]
)
band_maxs_raw = torch.tensor( # type: ignore[attr-defined]
[
31.0,
35.0,
18556.0,
20528.0,
18976.0,
17874.0,
16611.0,
16512.0,
16394.0,
16672.0,
16141.0,
16097.0,
15336.0,
15203.0,
]
)
# min/max band statistics computed by clipping the above samples to the
# [2, 98] percentile range
band_mins = torch.tensor( # type: ignore[attr-defined]
[-48.0, -42.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0]
)
band_maxs = torch.tensor( # type: ignore[attr-defined]
[
6.0,
16.0,
9859.0,
12872.0,
13163.0,
14445.0,
12477.0,
12563.0,
12289.0,
15596.0,
12183.0,
9458.0,
5897.0,
5544.0,
]
)
def __init__(
self,
root_dir: str,
bands: str = "all",
num_classes: int = 19,
batch_size: int = 64,
num_workers: int = 0,
**kwargs: Any,
) -> None:
"""Initialize a LightningDataModule for BigEarthNet based DataLoaders.
Args:
root_dir: The ``root`` argument to pass to the BigEarthNet Dataset classes
bands: load Sentinel-1 bands, Sentinel-2 bands, or both. One of {s1, s2, all}
num_classes: number of classes to load in target. One of {19, 43}
batch_size: The batch size to use in all created DataLoaders
num_workers: The number of workers to use in all created DataLoaders
"""
super().__init__() # type: ignore[no-untyped-call]
self.root_dir = root_dir
self.bands = bands
self.num_classes = num_classes
self.batch_size = batch_size
self.num_workers = num_workers
if bands == "all":
self.mins = self.band_mins[:, None, None]
self.maxs = self.band_maxs[:, None, None]
elif bands == "s1":
self.mins = self.band_mins[:2, None, None]
self.maxs = self.band_maxs[:2, None, None]
else:
self.mins = self.band_mins[2:, None, None]
self.maxs = self.band_maxs[2:, None, None]
def preprocess(self, sample: Dict[str, Any]) -> Dict[str, Any]:
"""Transform a single sample from the Dataset."""
sample["image"] = sample["image"].float()
sample["image"] = (sample["image"] - self.mins) / (self.maxs - self.mins)
sample["image"] = torch.clip( # type: ignore[attr-defined]
sample["image"], min=0.0, max=1.0
)
return sample
def prepare_data(self) -> None:
"""Make sure that the dataset is downloaded.
This method is only called once per run.
"""
BigEarthNet(self.root_dir, split="train", bands=self.bands, checksum=False)
def setup(self, stage: Optional[str] = None) -> None:
"""Initialize the main ``Dataset`` objects.
This method is called once per GPU per run.
"""
transforms = Compose([self.preprocess])
self.train_dataset = BigEarthNet(
self.root_dir,
split="train",
bands=self.bands,
num_classes=self.num_classes,
transforms=transforms,
)
self.val_dataset = BigEarthNet(
self.root_dir,
split="val",
bands=self.bands,
num_classes=self.num_classes,
transforms=transforms,
)
self.test_dataset = BigEarthNet(
self.root_dir,
split="test",
bands=self.bands,
num_classes=self.num_classes,
transforms=transforms,
)
def train_dataloader(self) -> DataLoader[Any]:
"""Return a DataLoader for training."""
return DataLoader(
self.train_dataset,
batch_size=self.batch_size,
num_workers=self.num_workers,
shuffle=True,
)
def val_dataloader(self) -> DataLoader[Any]:
"""Return a DataLoader for validation."""
return DataLoader(
self.val_dataset,
batch_size=self.batch_size,
num_workers=self.num_workers,
shuffle=False,
)
def test_dataloader(self) -> DataLoader[Any]:
"""Return a DataLoader for testing."""
return DataLoader(
self.test_dataset,
batch_size=self.batch_size,
num_workers=self.num_workers,
shuffle=False,
)


@ -16,21 +16,10 @@ import rasterio.mask
import shapely.geometry
import shapely.ops
import torch
import torch.nn.functional as F
from pytorch_lightning.core.datamodule import LightningDataModule
from rasterio.crs import CRS
from torch import Tensor
from torch.utils.data import DataLoader
from torchvision.transforms import Compose
from ..samplers.batch import RandomBatchGeoSampler
from ..samplers.single import GridGeoSampler
from .geo import GeoDataset, RasterDataset
from .utils import BoundingBox, download_url, extract_archive, stack_samples
# https://github.com/pytorch/pytorch/issues/60979
# https://github.com/pytorch/pytorch/pull/61045
DataLoader.__module__ = "torch.utils.data"
from .utils import BoundingBox, download_url, extract_archive
class Chesapeake(RasterDataset, abc.ABC):
@ -537,294 +526,3 @@ class ChesapeakeCVPR(GeoDataset):
def _extract(self) -> None:
"""Extract the dataset."""
extract_archive(os.path.join(self.root, self.filename))
class ChesapeakeCVPRDataModule(LightningDataModule):
"""LightningDataModule implementation for the Chesapeake CVPR Land Cover dataset.
Uses the random splits defined per state to partition tiles into train, val,
and test sets.
"""
def __init__(
self,
root_dir: str,
train_splits: List[str],
val_splits: List[str],
test_splits: List[str],
patches_per_tile: int = 200,
patch_size: int = 256,
batch_size: int = 64,
num_workers: int = 0,
class_set: int = 7,
**kwargs: Any,
) -> None:
"""Initialize a LightningDataModule for Chesapeake CVPR based DataLoaders.
Args:
root_dir: The ``root`` argument to pass to the ChesapeakeCVPR Dataset
class
train_splits: The splits used to train the model, e.g. ["ny-train"]
val_splits: The splits used to validate the model, e.g. ["ny-val"]
test_splits: The splits used to test the model, e.g. ["ny-test"]
patches_per_tile: The number of patches per tile to sample
patch_size: The size of each patch in pixels (validation/test patches will be
sampled at 2 times this size)
batch_size: The batch size to use in all created DataLoaders
num_workers: The number of workers to use in all created DataLoaders
class_set: The high-resolution land cover class set to use - 5 or 7
"""
super().__init__() # type: ignore[no-untyped-call]
for state in train_splits + val_splits + test_splits:
assert state in ChesapeakeCVPR.splits
assert class_set in [5, 7]
self.root_dir = root_dir
self.train_splits = train_splits
self.val_splits = val_splits
self.test_splits = test_splits
self.layers = ["naip-new", "lc"]
self.patches_per_tile = patches_per_tile
self.patch_size = patch_size
# This is a rough estimate of how large a patch we will need to sample in
# EPSG:3857 in order to guarantee a large enough patch in the local CRS.
self.original_patch_size = int(patch_size * 2.0)
self.batch_size = batch_size
self.num_workers = num_workers
self.class_set = class_set
def pad_to(
self, size: int = 512, image_value: int = 0, mask_value: int = 0
) -> Callable[[Dict[str, Tensor]], Dict[str, Tensor]]:
"""Returns a function to perform a padding transform on a single sample.
Args:
size: output image size
image_value: value to pad image with
mask_value: value to pad mask with
Returns:
function to perform padding
"""
def pad_inner(sample: Dict[str, Tensor]) -> Dict[str, Tensor]:
_, height, width = sample["image"].shape
assert height <= size and width <= size
height_pad = size - height
width_pad = size - width
# See https://pytorch.org/docs/stable/generated/torch.nn.functional.pad.html
# for a description of the format of the padding tuple
sample["image"] = F.pad(
sample["image"],
(0, width_pad, 0, height_pad),
mode="constant",
value=image_value,
)
sample["mask"] = F.pad(
sample["mask"],
(0, width_pad, 0, height_pad),
mode="constant",
value=mask_value,
)
return sample
return pad_inner
def center_crop(
self, size: int = 512
) -> Callable[[Dict[str, Tensor]], Dict[str, Tensor]]:
"""Returns a function to perform a center crop transform on a single sample.
Args:
size: output image size
Returns:
function to perform center crop
"""
def center_crop_inner(sample: Dict[str, Tensor]) -> Dict[str, Tensor]:
_, height, width = sample["image"].shape
y1 = (height - size) // 2
x1 = (width - size) // 2
sample["image"] = sample["image"][:, y1 : y1 + size, x1 : x1 + size]
sample["mask"] = sample["mask"][:, y1 : y1 + size, x1 : x1 + size]
return sample
return center_crop_inner
def preprocess(self, sample: Dict[str, Any]) -> Dict[str, Any]:
"""Preprocesses a single sample.
Args:
sample: sample dictionary containing image and mask
Returns:
preprocessed sample
"""
sample["image"] = sample["image"] / 255.0
sample["mask"] = sample["mask"].squeeze()
if self.class_set == 5:
sample["mask"][sample["mask"] == 5] = 4
sample["mask"][sample["mask"] == 6] = 4
sample["image"] = sample["image"].float()
sample["mask"] = sample["mask"].long()
return sample
def nodata_check(
self, size: int = 512
) -> Callable[[Dict[str, Tensor]], Dict[str, Tensor]]:
"""Returns a function to check for nodata or mis-sized input.
Args:
size: output image size
Returns:
function to check for nodata values
"""
def nodata_check_inner(sample: Dict[str, Tensor]) -> Dict[str, Tensor]:
num_channels, height, width = sample["image"].shape
if height < size or width < size:
sample["image"] = torch.zeros( # type: ignore[attr-defined]
(num_channels, size, size)
)
sample["mask"] = torch.zeros((size, size)) # type: ignore[attr-defined]
return sample
return nodata_check_inner
def prepare_data(self) -> None:
"""Confirms that the dataset is downloaded on the local node.
This method is called once per node, while :func:`setup` is called once per GPU.
"""
ChesapeakeCVPR(
self.root_dir,
splits=self.train_splits,
layers=self.layers,
transforms=None,
download=False,
checksum=False,
)
def setup(self, stage: Optional[str] = None) -> None:
"""Create the train/val/test splits based on the original Dataset objects.
The splits should be done here vs. in :func:`__init__` per the docs:
https://pytorch-lightning.readthedocs.io/en/latest/extensions/datamodules.html#setup.
Args:
stage: stage to set up
"""
train_transforms = Compose(
[
self.center_crop(self.patch_size),
self.nodata_check(self.patch_size),
self.preprocess,
]
)
val_transforms = Compose(
[
self.center_crop(self.patch_size),
self.nodata_check(self.patch_size),
self.preprocess,
]
)
test_transforms = Compose(
[
self.pad_to(self.original_patch_size, image_value=0, mask_value=0),
self.preprocess,
]
)
self.train_dataset = ChesapeakeCVPR(
self.root_dir,
splits=self.train_splits,
layers=self.layers,
transforms=train_transforms,
download=False,
checksum=False,
)
self.val_dataset = ChesapeakeCVPR(
self.root_dir,
splits=self.val_splits,
layers=self.layers,
transforms=val_transforms,
download=False,
checksum=False,
)
self.test_dataset = ChesapeakeCVPR(
self.root_dir,
splits=self.test_splits,
layers=self.layers,
transforms=test_transforms,
download=False,
checksum=False,
)
def train_dataloader(self) -> DataLoader[Any]:
"""Return a DataLoader for training.
Returns:
training data loader
"""
sampler = RandomBatchGeoSampler(
self.train_dataset,
size=self.original_patch_size,
batch_size=self.batch_size,
length=self.patches_per_tile * len(self.train_dataset),
)
return DataLoader(
self.train_dataset,
batch_sampler=sampler,
num_workers=self.num_workers,
collate_fn=stack_samples,
)
def val_dataloader(self) -> DataLoader[Any]:
"""Return a DataLoader for validation.
Returns:
validation data loader
"""
sampler = GridGeoSampler(
self.val_dataset,
size=self.original_patch_size,
stride=self.original_patch_size,
)
return DataLoader(
self.val_dataset,
batch_size=self.batch_size,
sampler=sampler,
num_workers=self.num_workers,
collate_fn=stack_samples,
)
def test_dataloader(self) -> DataLoader[Any]:
"""Return a DataLoader for testing.
Returns:
testing data loader
"""
sampler = GridGeoSampler(
self.test_dataset,
size=self.original_patch_size,
stride=self.original_patch_size,
)
return DataLoader(
self.test_dataset,
batch_size=self.batch_size,
sampler=sampler,
num_workers=self.num_workers,
collate_fn=stack_samples,
)
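The pad_to transform pads only on the right and bottom: in F.pad's (left, right, top, bottom) convention for the last two dimensions, the tuple (0, width_pad, 0, height_pad) keeps the original content anchored at the top-left corner. A standalone sketch:

import torch
import torch.nn.functional as F

image = torch.ones(3, 200, 180)
size = 256
_, height, width = image.shape
padded = F.pad(
    image, (0, size - width, 0, size - height), mode="constant", value=0
)
assert padded.shape == (3, 256, 256)
assert torch.equal(padded[:, :200, :180], image)  # content stays top-left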


@ -6,22 +6,16 @@
import abc
import csv
import os
from typing import Any, Callable, Dict, List, Optional
from typing import Callable, Dict, List, Optional
import numpy as np
import pytorch_lightning as pl
import torch
from PIL import Image
from torch import Generator, Tensor # type: ignore[attr-defined]
from torch.utils.data import DataLoader, random_split
from torch import Tensor
from .geo import VisionDataset
from .utils import check_integrity, download_and_extract_archive
# https://github.com/pytorch/pytorch/issues/60979
# https://github.com/pytorch/pytorch/pull/61045
DataLoader.__module__ = "torch.utils.data"
class COWC(VisionDataset, abc.ABC):
"""Abstract base class for the COWC dataset.
@ -268,110 +262,3 @@ class COWCDetection(COWC):
# 4. Unknown
#
# May need new abstract base class. Will need subclasses for different patch sizes.
class COWCCountingDataModule(pl.LightningDataModule):
"""LightningDataModule implementation for the COWC Counting dataset."""
def __init__(
self,
root_dir: str,
seed: int,
batch_size: int = 64,
num_workers: int = 0,
**kwargs: Any,
) -> None:
"""Initialize a LightningDataModule for COWC Counting based DataLoaders.
Args:
root_dir: The ``root`` argument to pass to the COWCCounting Dataset class
seed: The seed value to use when doing the dataset random_split
batch_size: The batch size to use in all created DataLoaders
num_workers: The number of workers to use in all created DataLoaders
"""
super().__init__() # type: ignore[no-untyped-call]
self.root_dir = root_dir
self.seed = seed
self.batch_size = batch_size
self.num_workers = num_workers
def custom_transform(self, sample: Dict[str, Any]) -> Dict[str, Any]:
"""Transform a single sample from the Dataset.
Args:
sample: dictionary containing image and target
Returns:
preprocessed sample
"""
sample["image"] = sample["image"] / 255.0 # scale to [0, 1]
sample["label"] = sample["label"].float()
return sample
def prepare_data(self) -> None:
"""Initialize the main ``Dataset`` objects for use in :func:`setup`.
This includes optionally downloading the dataset. This is done once per node,
while :func:`setup` is done once per GPU.
"""
COWCCounting(self.root_dir, download=False)
def setup(self, stage: Optional[str] = None) -> None:
"""Create the train/val/test splits based on the original Dataset objects.
The splits should be done here vs. in :func:`__init__` per the docs:
https://pytorch-lightning.readthedocs.io/en/latest/extensions/datamodules.html#setup.
Args:
stage: stage to set up
"""
train_val_dataset = COWCCounting(
self.root_dir, split="train", transforms=self.custom_transform
)
self.test_dataset = COWCCounting(
self.root_dir, split="test", transforms=self.custom_transform
)
self.train_dataset, self.val_dataset = random_split(
train_val_dataset,
[len(train_val_dataset) - len(self.test_dataset), len(self.test_dataset)],
generator=Generator().manual_seed(self.seed),
)
def train_dataloader(self) -> DataLoader[Any]:
"""Return a DataLoader for training.
Returns:
training data loader
"""
return DataLoader(
self.train_dataset,
batch_size=self.batch_size,
num_workers=self.num_workers,
shuffle=True,
)
def val_dataloader(self) -> DataLoader[Any]:
"""Return a DataLoader for validation.
Returns:
validation data loader
"""
return DataLoader(
self.val_dataset,
batch_size=self.batch_size,
num_workers=self.num_workers,
shuffle=False,
)
def test_dataloader(self) -> DataLoader[Any]:
"""Return a DataLoader for testing.
Returns:
testing data loader
"""
return DataLoader(
self.test_dataset,
batch_size=self.batch_size,
num_workers=self.num_workers,
shuffle=False,
)
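Passing a seeded Generator to random_split, as setup() does above, makes the train/val partition reproducible across runs. A minimal sketch:

import torch
from torch.utils.data import TensorDataset, random_split

dataset = TensorDataset(torch.arange(10))
a, _ = random_split(dataset, [8, 2], generator=torch.Generator().manual_seed(0))
b, _ = random_split(dataset, [8, 2], generator=torch.Generator().manual_seed(0))
assert a.indices == b.indices  # identical seed, identical split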


@ -10,20 +10,13 @@ from typing import Any, Callable, Dict, Optional, cast
import matplotlib.pyplot as plt
import numpy as np
import pytorch_lightning as pl
import torch
from PIL import Image
from sklearn.model_selection import GroupShuffleSplit
from torch import Tensor
from torch.utils.data import DataLoader, Subset
from .geo import VisionDataset
from .utils import check_integrity, download_radiant_mlhub_dataset, extract_archive
# https://github.com/pytorch/pytorch/issues/60979
# https://github.com/pytorch/pytorch/pull/61045
DataLoader.__module__ = "torch.utils.data"
class TropicalCycloneWindEstimation(VisionDataset):
"""Tropical Cyclone Wind Estimation Competition dataset.
@ -254,157 +247,3 @@ class TropicalCycloneWindEstimation(VisionDataset):
plt.suptitle(suptitle)
return fig
class CycloneDataModule(pl.LightningDataModule):
"""LightningDataModule implementation for the NASA Cyclone dataset.
Implements 80/20 train/val splits based on hurricane storm ids.
See :func:`setup` for more details.
"""
def __init__(
self,
root_dir: str,
seed: int,
batch_size: int = 64,
num_workers: int = 0,
api_key: Optional[str] = None,
**kwargs: Any,
) -> None:
"""Initialize a LightningDataModule for NASA Cyclone based DataLoaders.
Args:
root_dir: The ``root`` argument to pass to the
TropicalCycloneWindEstimation Dataset class
seed: The seed value to use when doing the sklearn based GroupShuffleSplit
batch_size: The batch size to use in all created DataLoaders
num_workers: The number of workers to use in all created DataLoaders
api_key: The RadiantEarth MLHub API key to use if the dataset needs to be
downloaded
"""
super().__init__() # type: ignore[no-untyped-call]
self.root_dir = root_dir
self.seed = seed
self.batch_size = batch_size
self.num_workers = num_workers
self.api_key = api_key
def custom_transform(self, sample: Dict[str, Any]) -> Dict[str, Any]:
"""Transform a single sample from the Dataset.
Args:
sample: dictionary containing image and target
Returns:
preprocessed sample
"""
sample["image"] = sample["image"] / 255.0 # scale to [0,1]
sample["image"] = (
sample["image"].unsqueeze(0).repeat(3, 1, 1)
) # convert to 3 channel
sample["label"] = torch.as_tensor( # type: ignore[attr-defined]
sample["label"]
).float()
return sample
def prepare_data(self) -> None:
"""Initialize the main ``Dataset`` objects for use in :func:`setup`.
This includes optionally downloading the dataset. This is done once per node,
while :func:`setup` is done once per GPU.
"""
TropicalCycloneWindEstimation(
self.root_dir,
split="train",
transforms=self.custom_transform,
download=self.api_key is not None,
api_key=self.api_key,
)
def setup(self, stage: Optional[str] = None) -> None:
"""Create the train/val/test splits based on the original Dataset objects.
The splits should be done here vs. in :func:`__init__` per the docs:
https://pytorch-lightning.readthedocs.io/en/latest/extensions/datamodules.html#setup.
We split samples between train/val by the ``storm_id`` property, i.e. all
samples with the same ``storm_id`` value will be in either the train or the val
split. This tests one type of generalizability: given a new storm, can we
predict its wind speed? The test set, however, contains *some* storms from the
training set (specifically, the latter parts of the storms) as well as some
novel storms.
Args:
stage: stage to set up
"""
self.all_train_dataset = TropicalCycloneWindEstimation(
self.root_dir,
split="train",
transforms=self.custom_transform,
download=False,
)
self.all_test_dataset = TropicalCycloneWindEstimation(
self.root_dir,
split="test",
transforms=self.custom_transform,
download=False,
)
storm_ids = []
for item in self.all_train_dataset.collection:
storm_id = item["href"].split("/")[0].split("_")[-2]
storm_ids.append(storm_id)
train_indices, val_indices = next(
GroupShuffleSplit(test_size=0.2, n_splits=2, random_state=self.seed).split(
storm_ids, groups=storm_ids
)
)
self.train_dataset = Subset(self.all_train_dataset, train_indices)
self.val_dataset = Subset(self.all_train_dataset, val_indices)
self.test_dataset = Subset(
self.all_test_dataset, range(len(self.all_test_dataset))
)
def train_dataloader(self) -> DataLoader[Any]:
"""Return a DataLoader for training.
Returns:
training data loader
"""
return DataLoader(
self.train_dataset,
batch_size=self.batch_size,
num_workers=self.num_workers,
shuffle=True,
)
def val_dataloader(self) -> DataLoader[Any]:
"""Return a DataLoader for validation.
Returns:
validation data loader
"""
return DataLoader(
self.val_dataset,
batch_size=self.batch_size,
num_workers=self.num_workers,
shuffle=False,
)
def test_dataloader(self) -> DataLoader[Any]:
"""Return a DataLoader for testing.
Returns:
testing data loader
"""
return DataLoader(
self.test_dataset,
batch_size=self.batch_size,
num_workers=self.num_workers,
shuffle=False,
)
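The storm id is recovered from each collection item's href by taking the top-level directory name and splitting on underscores. A sketch with a made-up href that follows the layout parsed in setup() (the real hrefs come from the Radiant MLHub collection):

# Hypothetical href in the directory layout that setup() parses.
href = "nasa_tropical_storm_competition_train_source_abc_123/properties.json"
storm_id = href.split("/")[0].split("_")[-2]
assert storm_id == "abc"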


@ -5,16 +5,13 @@
import glob
import os
from typing import Any, Callable, Dict, List, Optional
from typing import Callable, Dict, List, Optional
import matplotlib.pyplot as plt
import numpy as np
import pytorch_lightning as pl
import torch
from PIL import Image
from torch import Generator, Tensor # type: ignore[attr-defined]
from torch.utils.data import DataLoader, random_split
from torchvision.transforms import Normalize
from torch import Tensor
from .geo import VisionDataset
from .utils import download_and_extract_archive
@ -320,140 +317,3 @@ class ETCI2021(VisionDataset):
if suptitle is not None:
plt.suptitle(suptitle)
return fig
class ETCI2021DataModule(pl.LightningDataModule):
"""LightningDataModule implementation for the ETCI2021 dataset.
Splits the existing train split from the dataset into train/val with 80/20
proportions, then uses the existing val dataset as the test data.
.. versionadded:: 0.2
"""
band_means = torch.tensor( # type: ignore[attr-defined]
[0.52253931, 0.52253931, 0.52253931, 0.61221701, 0.61221701, 0.61221701, 0]
)
band_stds = torch.tensor( # type: ignore[attr-defined]
[0.35221376, 0.35221376, 0.35221376, 0.37364622, 0.37364622, 0.37364622, 1]
)
def __init__(
self,
root_dir: str,
seed: int = 0,
batch_size: int = 64,
num_workers: int = 0,
**kwargs: Any,
) -> None:
"""Initialize a LightningDataModule for ETCI2021 based DataLoaders.
Args:
root_dir: The ``root`` argument to pass to the ETCI2021 Dataset classes
seed: The seed value to use when doing the dataset random_split
batch_size: The batch size to use in all created DataLoaders
num_workers: The number of workers to use in all created DataLoaders
"""
super().__init__() # type: ignore[no-untyped-call]
self.root_dir = root_dir
self.seed = seed
self.batch_size = batch_size
self.num_workers = num_workers
self.norm = Normalize(self.band_means, self.band_stds)
def preprocess(self, sample: Dict[str, Any]) -> Dict[str, Any]:
"""Transform a single sample from the Dataset.
Notably, moves the water mask out of the target and appends it to the image
as an extra input channel.
Args:
sample: input image dictionary
Returns:
preprocessed sample
"""
image = sample["image"]
water_mask = sample["mask"][0].unsqueeze(0)
flood_mask = sample["mask"][1]
flood_mask = (flood_mask > 0).long()
sample["image"] = torch.cat( # type: ignore[attr-defined]
[image, water_mask], dim=0
).float()
sample["image"] /= 255.0
sample["image"] = self.norm(sample["image"])
sample["mask"] = flood_mask
return sample
def prepare_data(self) -> None:
"""Make sure that the dataset is downloaded.
This method is only called once per run.
"""
ETCI2021(self.root_dir, checksum=False)
def setup(self, stage: Optional[str] = None) -> None:
"""Initialize the main ``Dataset`` objects.
This method is called once per GPU per run.
Args:
stage: stage to set up
"""
train_val_dataset = ETCI2021(
self.root_dir, split="train", transforms=self.preprocess
)
self.test_dataset = ETCI2021(
self.root_dir, split="val", transforms=self.preprocess
)
size_train_val = len(train_val_dataset)
size_train = int(0.8 * size_train_val)
size_val = size_train_val - size_train
self.train_dataset, self.val_dataset = random_split(
train_val_dataset,
[size_train, size_val],
generator=Generator().manual_seed(self.seed),
)
def train_dataloader(self) -> DataLoader[Any]:
"""Return a DataLoader for training.
Returns:
training data loader
"""
return DataLoader(
self.train_dataset,
batch_size=self.batch_size,
num_workers=self.num_workers,
shuffle=True,
)
def val_dataloader(self) -> DataLoader[Any]:
"""Return a DataLoader for validation.
Returns:
validation data loader
"""
return DataLoader(
self.val_dataset,
batch_size=self.batch_size,
num_workers=self.num_workers,
shuffle=False,
)
def test_dataloader(self) -> DataLoader[Any]:
"""Return a DataLoader for testing.
Returns:
testing data loader
"""
return DataLoader(
self.test_dataset,
batch_size=self.batch_size,
num_workers=self.num_workers,
shuffle=False,
)
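The preprocess step turns the two-layer mask into a binary flood target and appends the water mask to the image as a seventh input channel. A shape-only sketch with random stand-in tensors:

import torch

image = torch.randint(0, 256, (6, 64, 64))  # six VV/VH channels, as in ETCI2021
mask = torch.randint(0, 2, (2, 64, 64))     # [water mask, flood mask]
water_mask = mask[0].unsqueeze(0)
flood_mask = (mask[1] > 0).long()
inputs = torch.cat([image, water_mask], dim=0).float() / 255.0
assert inputs.shape == (7, 64, 64) and flood_mask.shape == (64, 64)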


@ -4,15 +4,11 @@
"""EuroSAT dataset."""
import os
from typing import Any, Callable, Dict, Optional, cast
from typing import Callable, Dict, Optional, cast
import matplotlib.pyplot as plt
import numpy as np
import pytorch_lightning as pl
import torch
from torch import Tensor
from torch.utils.data import DataLoader
from torchvision.transforms import Compose, Normalize
from .geo import VisionClassificationDataset
from .utils import check_integrity, download_url, extract_archive, rasterio_loader
@ -229,138 +225,3 @@ class EuroSAT(VisionClassificationDataset):
if suptitle is not None:
plt.suptitle(suptitle)
return fig
class EuroSATDataModule(pl.LightningDataModule):
"""LightningDataModule implementation for the EuroSAT dataset.
Uses the train/val/test splits from the dataset.
.. versionadded:: 0.2
"""
band_means = torch.tensor( # type: ignore[attr-defined]
[
1354.40546513,
1118.24399958,
1042.92983953,
947.62620298,
1199.47283961,
1999.79090914,
2369.22292565,
2296.82608323,
732.08340178,
12.11327804,
1819.01027855,
1118.92391149,
2594.14080798,
]
)
band_stds = torch.tensor( # type: ignore[attr-defined]
[
245.71762908,
333.00778264,
395.09249139,
593.75055589,
566.4170017,
861.18399006,
1086.63139075,
1117.98170791,
404.91978886,
4.77584468,
1002.58768311,
761.30323499,
1231.58581042,
]
)
def __init__(
self, root_dir: str, batch_size: int = 64, num_workers: int = 0, **kwargs: Any
) -> None:
"""Initialize a LightningDataModule for EuroSAT based DataLoaders.
Args:
root_dir: The ``root`` argument to pass to the EuroSAT Dataset classes
batch_size: The batch size to use in all created DataLoaders
num_workers: The number of workers to use in all created DataLoaders
"""
super().__init__() # type: ignore[no-untyped-call]
self.root_dir = root_dir
self.batch_size = batch_size
self.num_workers = num_workers
self.norm = Normalize(self.band_means, self.band_stds)
def preprocess(self, sample: Dict[str, Any]) -> Dict[str, Any]:
"""Transform a single sample from the Dataset.
Args:
sample: input image dictionary
Returns:
preprocessed sample
"""
sample["image"] = sample["image"].float()
sample["image"] = self.norm(sample["image"])
return sample
def prepare_data(self) -> None:
"""Make sure that the dataset is downloaded.
This method is only called once per run.
"""
EuroSAT(self.root_dir)
def setup(self, stage: Optional[str] = None) -> None:
"""Initialize the main ``Dataset`` objects.
This method is called once per GPU per run.
Args:
stage: stage to set up
"""
transforms = Compose([self.preprocess])
self.train_dataset = EuroSAT(self.root_dir, "train", transforms=transforms)
self.val_dataset = EuroSAT(self.root_dir, "val", transforms=transforms)
self.test_dataset = EuroSAT(self.root_dir, "test", transforms=transforms)
def train_dataloader(self) -> DataLoader[Any]:
"""Return a DataLoader for training.
Returns:
training data loader
"""
return DataLoader(
self.train_dataset,
batch_size=self.batch_size,
num_workers=self.num_workers,
shuffle=True,
)
def val_dataloader(self) -> DataLoader[Any]:
"""Return a DataLoader for validation.
Returns:
validation data loader
"""
return DataLoader(
self.val_dataset,
batch_size=self.batch_size,
num_workers=self.num_workers,
shuffle=False,
)
def test_dataloader(self) -> DataLoader[Any]:
"""Return a DataLoader for testing.
Returns:
testing data loader
"""
return DataLoader(
self.test_dataset,
batch_size=self.batch_size,
num_workers=self.num_workers,
shuffle=False,
)
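torchvision's Normalize broadcasts per-band statistics over a (C, H, W) tensor, so the 13 band means/stds above normalize each Sentinel-2 band independently. A toy check with made-up 3-band statistics:

import torch
from torchvision.transforms import Normalize

norm = Normalize([0.5, 0.5, 0.5], [0.25, 0.25, 0.25])  # toy statistics
image = torch.full((3, 8, 8), 0.75)
out = norm(image)
assert torch.allclose(out, torch.ones_like(out))  # (0.75 - 0.5) / 0.25 == 1.0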


@ -11,33 +11,12 @@ from xml.etree import ElementTree
import matplotlib.patches as patches
import matplotlib.pyplot as plt
import numpy as np
import pytorch_lightning as pl
import torch
from PIL import Image
from torch import Tensor
from torch.utils.data import DataLoader
from torchvision.transforms import Compose
from ..datasets.utils import check_integrity, dataset_split, extract_archive
from .geo import VisionDataset
# https://github.com/pytorch/pytorch/issues/60979
# https://github.com/pytorch/pytorch/pull/61045
DataLoader.__module__ = "torch.utils.data"
def collate_fn(batch: List[Dict[str, Tensor]]) -> Dict[str, Any]:
"""Custom object detection collate fn to handle variable number of boxes.
Args:
batch: list of sample dicts return by dataset
Returns:
batch dict output
"""
output: Dict[str, Any] = {}
output["image"] = torch.stack([sample["image"] for sample in batch])
output["boxes"] = [sample["boxes"] for sample in batch]
return output
from .utils import check_integrity, extract_archive
def parse_pascal_voc(path: str) -> Dict[str, Any]:
@ -350,102 +329,3 @@ class FAIR1M(VisionDataset):
plt.suptitle(suptitle)
return fig
class FAIR1MDataModule(pl.LightningDataModule):
"""LightningDataModule implementation for the FAIR1M dataset."""
def __init__(
self,
root_dir: str,
batch_size: int = 64,
num_workers: int = 0,
val_split_pct: float = 0.2,
test_split_pct: float = 0.2,
**kwargs: Any,
) -> None:
"""Initialize a LightningDataModule for FAIR1M based DataLoaders.
Args:
root_dir: The ``root`` argument to pass to the FAIR1M Dataset classes
batch_size: The batch size to use in all created DataLoaders
num_workers: The number of workers to use in all created DataLoaders
val_split_pct: What percentage of the dataset to use as a validation set
test_split_pct: What percentage of the dataset to use as a test set
"""
super().__init__() # type: ignore[no-untyped-call]
self.root_dir = root_dir
self.batch_size = batch_size
self.num_workers = num_workers
self.val_split_pct = val_split_pct
self.test_split_pct = test_split_pct
def preprocess(self, sample: Dict[str, Any]) -> Dict[str, Any]:
"""Transform a single sample from the Dataset.
Args:
sample: input image dictionary
Returns:
preprocessed sample
"""
sample["image"] = sample["image"].float()
sample["image"] /= 255.0
return sample
def setup(self, stage: Optional[str] = None) -> None:
"""Initialize the main ``Dataset`` objects.
This method is called once per GPU per run.
Args:
stage: stage to set up
"""
transforms = Compose([self.preprocess])
dataset = FAIR1M(self.root_dir, transforms=transforms)
self.train_dataset, self.val_dataset, self.test_dataset = dataset_split(
dataset, val_pct=self.val_split_pct, test_pct=self.test_split_pct
)
def train_dataloader(self) -> DataLoader[Any]:
"""Return a DataLoader for training.
Returns:
training data loader
"""
return DataLoader(
self.train_dataset,
batch_size=self.batch_size,
num_workers=self.num_workers,
shuffle=True,
collate_fn=collate_fn,
)
def val_dataloader(self) -> DataLoader[Any]:
"""Return a DataLoader for validation.
Returns:
validation data loader
"""
return DataLoader(
self.val_dataset,
batch_size=self.batch_size,
num_workers=self.num_workers,
shuffle=False,
collate_fn=collate_fn,
)
def test_dataloader(self) -> DataLoader[Any]:
"""Return a DataLoader for testing.
Returns:
testing data loader
"""
return DataLoader(
self.test_dataset,
batch_size=self.batch_size,
num_workers=self.num_workers,
shuffle=False,
collate_fn=collate_fn,
)


@ -6,24 +6,18 @@
import hashlib
import os
from functools import lru_cache
from typing import Any, Callable, Dict, Optional
from typing import Callable, Dict, Optional
import matplotlib.pyplot as plt
import numpy as np
import pytorch_lightning as pl
import torch
from matplotlib.colors import ListedColormap
from PIL import Image
from torch import Tensor
from torch.utils.data import DataLoader
from .geo import VisionDataset
from .utils import check_integrity, download_and_extract_archive, working_dir
# https://github.com/pytorch/pytorch/issues/60979
# https://github.com/pytorch/pytorch/pull/61045
DataLoader.__module__ = "torch.utils.data"
class LandCoverAI(VisionDataset):
r"""LandCover.ai dataset.
@ -266,110 +260,3 @@ class LandCoverAI(VisionDataset):
if suptitle is not None:
plt.suptitle(suptitle)
return fig
class LandCoverAIDataModule(pl.LightningDataModule):
"""LightningDataModule implementation for the LandCover.ai dataset.
Uses the train/val/test splits from the dataset.
"""
def __init__(
self, root_dir: str, batch_size: int = 64, num_workers: int = 0, **kwargs: Any
) -> None:
"""Initialize a LightningDataModule for LandCover.ai based DataLoaders.
Args:
root_dir: The ``root`` argument to pass to the LandCover.ai Dataset classes
batch_size: The batch size to use in all created DataLoaders
num_workers: The number of workers to use in all created DataLoaders
"""
super().__init__() # type: ignore[no-untyped-call]
self.root_dir = root_dir
self.batch_size = batch_size
self.num_workers = num_workers
def preprocess(self, sample: Dict[str, Any]) -> Dict[str, Any]:
"""Transform a single sample from the Dataset.
Args:
sample: dictionary containing image and mask
Returns:
preprocessed sample
"""
sample["image"] = sample["image"] / 255.0
sample["image"] = sample["image"].float()
sample["mask"] = sample["mask"].float().unsqueeze(0) + 1
return sample
def prepare_data(self) -> None:
"""Make sure that the dataset is downloaded.
This method is only called once per run.
"""
_ = LandCoverAI(self.root_dir, download=False, checksum=False)
def setup(self, stage: Optional[str] = None) -> None:
"""Initialize the main ``Dataset`` objects.
This method is called once per GPU per run.
Args:
stage: stage to set up
"""
train_transforms = self.preprocess
val_test_transforms = self.preprocess
self.train_dataset = LandCoverAI(
self.root_dir, split="train", transforms=train_transforms
)
self.val_dataset = LandCoverAI(
self.root_dir, split="val", transforms=val_test_transforms
)
self.test_dataset = LandCoverAI(
self.root_dir, split="test", transforms=val_test_transforms
)
def train_dataloader(self) -> DataLoader[Any]:
"""Return a DataLoader for training.
Returns:
training data loader
"""
return DataLoader(
self.train_dataset,
batch_size=self.batch_size,
num_workers=self.num_workers,
shuffle=True,
)
def val_dataloader(self) -> DataLoader[Any]:
"""Return a DataLoader for validation.
Returns:
validation data loader
"""
return DataLoader(
self.val_dataset,
batch_size=self.batch_size,
num_workers=self.num_workers,
shuffle=False,
)
def test_dataloader(self) -> DataLoader[Any]:
"""Return a DataLoader for testing.
Returns:
testing data loader
"""
return DataLoader(
self.test_dataset,
batch_size=self.batch_size,
num_workers=self.num_workers,
shuffle=False,
)


@ -5,23 +5,17 @@
import glob
import os
from typing import Any, Callable, Dict, List, Optional
from typing import Callable, Dict, List, Optional
import matplotlib.pyplot as plt
import numpy as np
import pytorch_lightning as pl
import torch
from PIL import Image
from torch import Tensor
from torch.utils.data import DataLoader
from .geo import VisionDataset
from .utils import download_and_extract_archive
# https://github.com/pytorch/pytorch/issues/60979
# https://github.com/pytorch/pytorch/pull/61045
DataLoader.__module__ = "torch.utils.data"
class LoveDA(VisionDataset):
"""LoveDA dataset.
@ -305,117 +299,3 @@ class LoveDA(VisionDataset):
plt.suptitle(suptitle)
return fig
class LoveDADataModule(pl.LightningDataModule):
"""LightningDataModule implementation for the LoveDA dataset.
Uses the train/val/test splits from the dataset.
"""
def __init__(
self,
root_dir: str,
scene: List[str],
batch_size: int = 32,
num_workers: int = 0,
**kwargs: Any,
) -> None:
"""Initialize a LightningDataModule for LoveDA based DataLoaders.
Args:
root_dir: The ``root`` argument to pass to LoveDA Dataset classes
scene: specify whether to load only 'urban', only 'rural' or both
batch_size: The batch size to use in all created DataLoaders
num_workers: The number of workers to use in all created DataLoaders
"""
super().__init__() # type: ignore[no-untyped-call]
self.root_dir = root_dir
self.scene = scene
self.batch_size = batch_size
self.num_workers = num_workers
def preprocess(self, sample: Dict[str, Any]) -> Dict[str, Any]:
"""Transform a single sample from the Dataset.
Args:
sample: dictionary containing image and mask
Returns:
preprocessed sample
"""
sample["image"] = sample["image"] / 255.0
return sample
def prepare_data(self) -> None:
"""Make sure that the dataset is downloaded.
This method is only called once per run.
"""
_ = LoveDA(self.root_dir, scene=self.scene, download=False, checksum=False)
def setup(self, stage: Optional[str] = None) -> None:
"""Initialize the main ``Dataset`` objects.
This method is called once per GPU per run.
Args:
stage: stage to set up
"""
train_transforms = self.preprocess
val_test_transforms = self.preprocess
self.train_dataset = LoveDA(
self.root_dir, split="train", scene=self.scene, transforms=train_transforms
)
self.val_dataset = LoveDA(
self.root_dir, split="val", scene=self.scene, transforms=val_test_transforms
)
self.test_dataset = LoveDA(
self.root_dir,
split="test",
scene=self.scene,
transforms=val_test_transforms,
)
def train_dataloader(self) -> DataLoader[Any]:
"""Return a DataLoader for training.
Returns:
training data loader
"""
return DataLoader(
self.train_dataset,
batch_size=self.batch_size,
num_workers=self.num_workers,
shuffle=True,
)
def val_dataloader(self) -> DataLoader[Any]:
"""Return a DataLoader for validation.
Returns:
validation data loader
"""
return DataLoader(
self.val_dataset,
batch_size=self.batch_size,
num_workers=self.num_workers,
shuffle=False,
)
def test_dataloader(self) -> DataLoader[Any]:
"""Return a DataLoader for testing.
Returns:
testing data loader
"""
return DataLoader(
self.test_dataset,
batch_size=self.batch_size,
num_workers=self.num_workers,
shuffle=False,
)
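A hedged usage sketch for ``LoveDADataModule`` (not part of this commit); the import path assumes the new ``torchgeo.datamodules`` package and the data directory is hypothetical:

from torchgeo.datamodules import LoveDADataModule

dm = LoveDADataModule(root_dir="data/loveda", scene=["urban", "rural"], batch_size=16)
dm.prepare_data()  # verifies the data are present (download=False above)
dm.setup()
train_loader = dm.train_dataloader()  # shuffled; the val/test loaders are not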


@ -3,20 +3,7 @@
"""National Agriculture Imagery Program (NAIP) dataset."""
from typing import Any, Dict, Optional
import pytorch_lightning as pl
from torch.utils.data import DataLoader
from ..samplers.batch import RandomBatchGeoSampler
from ..samplers.single import GridGeoSampler
from .chesapeake import Chesapeake13
from .geo import RasterDataset
from .utils import BoundingBox, stack_samples
# https://github.com/pytorch/pytorch/issues/60979
# https://github.com/pytorch/pytorch/pull/61045
DataLoader.__module__ = "torch.utils.data"
class NAIP(RasterDataset):
@ -55,147 +42,3 @@ class NAIP(RasterDataset):
# Plotting
all_bands = ["R", "G", "B", "NIR"]
rgb_bands = ["R", "G", "B"]
class NAIPChesapeakeDataModule(pl.LightningDataModule):
"""LightningDataModule implementation for the NAIP and Chesapeake datasets.
Uses the train/val/test splits from the dataset.
"""
# TODO: tune these hyperparams
length = 1000
stride = 128
def __init__(
self,
naip_root_dir: str,
chesapeake_root_dir: str,
batch_size: int = 64,
num_workers: int = 0,
patch_size: int = 256,
**kwargs: Any,
) -> None:
"""Initialize a LightningDataModule for NAIP and Chesapeake based DataLoaders.
Args:
naip_root_dir: directory containing NAIP data
chesapeake_root_dir: directory containing Chesapeake data
batch_size: The batch size to use in all created DataLoaders
num_workers: The number of workers to use in all created DataLoaders
patch_size: size of patches to sample
"""
super().__init__() # type: ignore[no-untyped-call]
self.naip_root_dir = naip_root_dir
self.chesapeake_root_dir = chesapeake_root_dir
self.batch_size = batch_size
self.num_workers = num_workers
self.patch_size = patch_size
def naip_transform(self, sample: Dict[str, Any]) -> Dict[str, Any]:
"""Transform a single sample from the NAIP Dataset.
Args:
sample: NAIP image dictionary
Returns:
preprocessed NAIP data
"""
sample["image"] = sample["image"] / 255.0
sample["image"] = sample["image"].float()
return sample
def chesapeake_transform(self, sample: Dict[str, Any]) -> Dict[str, Any]:
"""Transform a single sample from the Chesapeake Dataset.
Args:
sample: Chesapeake mask dictionary
Returns:
preprocessed Chesapeake data
"""
sample["mask"] = sample["mask"].long()[0]
return sample
def prepare_data(self) -> None:
"""Make sure that the dataset is downloaded.
This method is only called once per run.
"""
Chesapeake13(self.chesapeake_root_dir, download=False, checksum=False)
def setup(self, stage: Optional[str] = None) -> None:
"""Initialize the main ``Dataset`` objects.
This method is called once per GPU per run.
Args:
stage: stage to set up
"""
# TODO: these transforms will be applied independently; this won't work if we
# add things like random horizontal flip
chesapeake = Chesapeake13(
self.chesapeake_root_dir, transforms=self.chesapeake_transform
)
naip = NAIP(
self.naip_root_dir,
chesapeake.crs,
chesapeake.res,
transforms=self.naip_transform,
)
self.dataset = chesapeake & naip
# TODO: figure out better train/val/test split
roi = self.dataset.bounds
midx = roi.minx + (roi.maxx - roi.minx) / 2
midy = roi.miny + (roi.maxy - roi.miny) / 2
train_roi = BoundingBox(roi.minx, midx, roi.miny, roi.maxy, roi.mint, roi.maxt)
val_roi = BoundingBox(midx, roi.maxx, roi.miny, midy, roi.mint, roi.maxt)
test_roi = BoundingBox(roi.minx, roi.maxx, midy, roi.maxy, roi.mint, roi.maxt)
self.train_sampler = RandomBatchGeoSampler(
naip, self.patch_size, self.batch_size, self.length, train_roi
)
self.val_sampler = GridGeoSampler(naip, self.patch_size, self.stride, val_roi)
self.test_sampler = GridGeoSampler(naip, self.patch_size, self.stride, test_roi)
def train_dataloader(self) -> DataLoader[Any]:
"""Return a DataLoader for training.
Returns:
training data loader
"""
return DataLoader(
self.dataset,
batch_sampler=self.train_sampler,
num_workers=self.num_workers,
collate_fn=stack_samples,
)
def val_dataloader(self) -> DataLoader[Any]:
"""Return a DataLoader for validation.
Returns:
validation data loader
"""
return DataLoader(
self.dataset,
batch_size=self.batch_size,
sampler=self.val_sampler,
num_workers=self.num_workers,
collate_fn=stack_samples,
)
def test_dataloader(self) -> DataLoader[Any]:
"""Return a DataLoader for testing.
Returns:
testing data loader
"""
return DataLoader(
self.dataset,
batch_size=self.batch_size,
sampler=self.test_sampler,
num_workers=self.num_workers,
collate_fn=stack_samples,
)
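A sketch of driving the geospatial datamodule above (editor's illustration; the paths are hypothetical and the import path assumes this move):

from torchgeo.datamodules import NAIPChesapeakeDataModule

dm = NAIPChesapeakeDataModule(
    naip_root_dir="data/naip",
    chesapeake_root_dir="data/chesapeake",
    batch_size=32,
    patch_size=256,
)
dm.setup()
# training batches come from RandomBatchGeoSampler, val/test from GridGeoSampler,
# and every loader collates with stack_samples as shown above
batch = next(iter(dm.train_dataloader()))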


@ -4,39 +4,17 @@
"""NASA Marine Debris dataset."""
import os
from typing import Any, Callable, Dict, List, Optional
from typing import Callable, Dict, List, Optional
import matplotlib.pyplot as plt
import numpy as np
import pytorch_lightning as pl
import rasterio
import torch
from torch import Tensor
from torch.utils.data import DataLoader
from torchvision.transforms import Compose
from torchvision.utils import draw_bounding_boxes
from .geo import VisionDataset
from .utils import dataset_split, download_radiant_mlhub_dataset, extract_archive
# https://github.com/pytorch/pytorch/issues/60979
# https://github.com/pytorch/pytorch/pull/61045
DataLoader.__module__ = "torch.utils.data"
def collate_fn(batch: List[Dict[str, Tensor]]) -> Dict[str, Any]:
"""Custom object detection collate fn to handle variable boxes.
Args:
batch: list of sample dicts return by dataset
Returns:
batch dict output
"""
output: Dict[str, Any] = {}
output["image"] = torch.stack([sample["image"] for sample in batch])
output["boxes"] = [sample["boxes"] for sample in batch]
return output
from .utils import download_radiant_mlhub_dataset, extract_archive
class NASAMarineDebris(VisionDataset):
@ -279,109 +257,3 @@ class NASAMarineDebris(VisionDataset):
plt.suptitle(suptitle)
return fig
class NASAMarineDebrisDataModule(pl.LightningDataModule):
"""LightningDataModule implementation for the NASA Marine Debris dataset."""
def __init__(
self,
root_dir: str,
batch_size: int = 64,
num_workers: int = 0,
val_split_pct: float = 0.2,
test_split_pct: float = 0.2,
**kwargs: Any,
) -> None:
"""Initialize a LightningDataModule for NASA Marine Debris based DataLoaders.
Args:
root_dir: The ``root`` argument to pass to the Dataset class
batch_size: The batch size to use in all created DataLoaders
num_workers: The number of workers to use in all created DataLoaders
val_split_pct: What percentage of the dataset to use as a validation set
test_split_pct: What percentage of the dataset to use as a test set
"""
super().__init__() # type: ignore[no-untyped-call]
self.root_dir = root_dir
self.batch_size = batch_size
self.num_workers = num_workers
self.val_split_pct = val_split_pct
self.test_split_pct = test_split_pct
def preprocess(self, sample: Dict[str, Any]) -> Dict[str, Any]:
"""Transform a single sample from the Dataset.
Args:
sample: input image dictionary
Returns:
preprocessed sample
"""
sample["image"] = sample["image"].float()
sample["image"] /= 255.0
return sample
def prepare_data(self) -> None:
"""Make sure that the dataset is downloaded.
This method is only called once per run.
"""
NASAMarineDebris(self.root_dir, checksum=False)
def setup(self, stage: Optional[str] = None) -> None:
"""Initialize the main ``Dataset`` objects.
This method is called once per GPU per run.
Args:
stage: stage to set up
"""
transforms = Compose([self.preprocess])
dataset = NASAMarineDebris(self.root_dir, transforms=transforms)
self.train_dataset, self.val_dataset, self.test_dataset = dataset_split(
dataset, val_pct=self.val_split_pct, test_pct=self.test_split_pct
)
def train_dataloader(self) -> DataLoader[Any]:
"""Return a DataLoader for training.
Returns:
training data loader
"""
return DataLoader(
self.train_dataset,
batch_size=self.batch_size,
num_workers=self.num_workers,
shuffle=True,
collate_fn=collate_fn,
)
def val_dataloader(self) -> DataLoader[Any]:
"""Return a DataLoader for validation.
Returns:
validation data loader
"""
return DataLoader(
self.val_dataset,
batch_size=self.batch_size,
num_workers=self.num_workers,
shuffle=False,
collate_fn=collate_fn,
)
def test_dataloader(self) -> DataLoader[Any]:
"""Return a DataLoader for testing.
Returns:
testing data loader
"""
return DataLoader(
self.test_dataset,
batch_size=self.batch_size,
num_workers=self.num_workers,
shuffle=False,
collate_fn=collate_fn,
)
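A hedged sketch showing the effect of the custom ``collate_fn`` above (editor's illustration; the path is hypothetical):

from torchgeo.datamodules import NASAMarineDebrisDataModule

dm = NASAMarineDebrisDataModule(root_dir="data/marinedebris", batch_size=8)
dm.setup()
batch = next(iter(dm.train_dataloader()))
# batch["image"] is a stacked tensor, while batch["boxes"] is a plain Python
# list of per-sample box tensors of varying length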


@ -5,25 +5,23 @@
import glob
import os
from typing import Any, Callable, Dict, List, Optional, Sequence, Tuple, Union
from typing import Callable, Dict, List, Optional, Sequence, Union
import kornia.augmentation as K
import matplotlib.pyplot as plt
import numpy as np
import pytorch_lightning as pl
import torch
from einops import repeat
from matplotlib.figure import Figure
from numpy import ndarray as Array
from PIL import Image
from torch import Tensor
from torch.utils.data import DataLoader
from torch.utils.data._utils.collate import default_collate
from torchvision.transforms import Compose, Normalize
from ..datasets.utils import dataset_split, draw_semantic_segmentation_masks
from .geo import VisionDataset
from .utils import download_url, extract_archive, sort_sentinel2_bands
from .utils import (
download_url,
draw_semantic_segmentation_masks,
extract_archive,
sort_sentinel2_bands,
)
class OSCD(VisionDataset):
@ -317,202 +315,3 @@ class OSCD(VisionDataset):
plt.suptitle(suptitle)
return fig
class OSCDDataModule(pl.LightningDataModule):
"""LightningDataModule implementation for the OSCD dataset.
Uses the train/test splits from the dataset and further splits
the train split into train/val splits.
.. versionadded:: 0.2
"""
band_means = torch.tensor( # type: ignore[attr-defined]
[
1583.0741,
1374.3202,
1294.1616,
1325.6158,
1478.7408,
1933.0822,
2166.0608,
2076.4868,
2306.0652,
690.9814,
16.2360,
2080.3347,
1524.6930,
]
)
band_stds = torch.tensor( # type: ignore[attr-defined]
[
52.1937,
83.4168,
105.6966,
151.1401,
147.4615,
115.9289,
123.1974,
114.6483,
141.4530,
73.2758,
4.8368,
213.4821,
179.4793,
]
)
def __init__(
self,
root_dir: str,
bands: str = "all",
train_batch_size: int = 32,
num_workers: int = 0,
val_split_pct: float = 0.2,
patch_size: Tuple[int, int] = (64, 64),
num_patches_per_tile: int = 32,
**kwargs: Any,
) -> None:
"""Initialize a LightningDataModule for OSCD based DataLoaders.
Args:
root_dir: The ``root`` argument to pass to the OSCD Dataset classes
bands: "rgb" or "all"
train_batch_size: The batch size used in the train DataLoader
(val_batch_size == test_batch_size == 1)
num_workers: The number of workers to use in all created DataLoaders
val_split_pct: What percentage of the dataset to use as a validation set
patch_size: Size of random patch from image and mask (height, width)
num_patches_per_tile: number of random patches per sample
"""
super().__init__() # type: ignore[no-untyped-call]
self.root_dir = root_dir
self.bands = bands
self.train_batch_size = train_batch_size
self.num_workers = num_workers
self.val_split_pct = val_split_pct
self.patch_size = patch_size
self.num_patches_per_tile = num_patches_per_tile
if bands == "rgb":
self.band_means = self.band_means[[3, 2, 1], None, None]
self.band_stds = self.band_stds[[3, 2, 1], None, None]
else:
self.band_means = self.band_means[:, None, None]
self.band_stds = self.band_stds[:, None, None]
self.norm = Normalize(self.band_means, self.band_stds)
self.rcrop = K.AugmentationSequential(
K.RandomCrop(patch_size), data_keys=["input", "mask"], same_on_batch=True
)
self.padto = K.PadTo((1280, 1280))
def preprocess(self, sample: Dict[str, Any]) -> Dict[str, Any]:
"""Transform a single sample from the Dataset."""
sample["image"] = sample["image"].float()
sample["mask"] = sample["mask"]
sample["image"] = self.norm(sample["image"])
sample["image"] = torch.flatten( # type: ignore[attr-defined]
sample["image"], 0, 1
)
return sample
def prepare_data(self) -> None:
"""Make sure that the dataset is downloaded.
This method is only called once per run.
"""
OSCD(self.root_dir, split="train", bands=self.bands, checksum=False)
def setup(self, stage: Optional[str] = None) -> None:
"""Initialize the main ``Dataset`` objects.
This method is called once per GPU per run.
"""
def n_random_crop(sample: Dict[str, Any]) -> Dict[str, Any]:
images, masks = [], []
for i in range(self.num_patches_per_tile):
mask = repeat(sample["mask"], "h w -> t h w", t=2).float()
image, mask = self.rcrop(sample["image"], mask)
mask = mask.squeeze()[0]
images.append(image.squeeze())
masks.append(mask.long())
sample["image"] = torch.stack(images)
sample["mask"] = torch.stack(masks)
return sample
def pad_to(sample: Dict[str, Any]) -> Dict[str, Any]:
sample["image"] = self.padto(sample["image"])[0]
sample["mask"] = self.padto(sample["mask"].float()).long()[0, 0]
return sample
train_transforms = Compose([self.preprocess, n_random_crop])
# for testing and validation we pad all inputs to a fixed size to avoid issues
# with the upsampling paths in encoder-decoder architectures
test_transforms = Compose([self.preprocess, pad_to])
train_dataset = OSCD(
self.root_dir, split="train", bands=self.bands, transforms=train_transforms
)
if self.val_split_pct > 0.0:
val_dataset = OSCD(
self.root_dir,
split="train",
bands=self.bands,
transforms=test_transforms,
)
self.train_dataset, self.val_dataset, _ = dataset_split(
train_dataset, val_pct=self.val_split_pct, test_pct=0.0
)
self.val_dataset.dataset = val_dataset
else:
self.train_dataset = train_dataset # type: ignore[assignment]
self.val_dataset = None # type: ignore[assignment]
self.test_dataset = OSCD(
self.root_dir, split="test", bands=self.bands, transforms=test_transforms
)
def train_dataloader(self) -> DataLoader[Any]:
"""Return a DataLoader for training."""
def collate_wrapper(batch: List[Dict[str, Any]]) -> Dict[str, Any]:
r_batch: Dict[str, Any] = default_collate( # type: ignore[no-untyped-call]
batch
)
r_batch["image"] = torch.flatten( # type: ignore[attr-defined]
r_batch["image"], 0, 1
)
r_batch["mask"] = torch.flatten( # type: ignore[attr-defined]
r_batch["mask"], 0, 1
)
return r_batch
return DataLoader(
self.train_dataset,
batch_size=self.train_batch_size,
num_workers=self.num_workers,
collate_fn=collate_wrapper,
shuffle=True,
)
def val_dataloader(self) -> DataLoader[Any]:
"""Return a DataLoader for validation."""
if self.val_split_pct == 0.0:
return self.train_dataloader()
else:
return DataLoader(
self.val_dataset,
batch_size=1,
num_workers=self.num_workers,
shuffle=False,
)
def test_dataloader(self) -> DataLoader[Any]:
"""Return a DataLoader for testing."""
return DataLoader(
self.test_dataset, batch_size=1, num_workers=self.num_workers, shuffle=False
)
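A hedged usage sketch (not part of this commit; the path is hypothetical):

from torchgeo.datamodules import OSCDDataModule

dm = OSCDDataModule(root_dir="data/oscd", bands="rgb", train_batch_size=4, num_patches_per_tile=16)
dm.setup()
# collate_wrapper flattens the per-tile patch dimension, so one train batch holds
# train_batch_size * num_patches_per_tile patches; val/test loaders use batch_size=1
batch = next(iter(dm.train_dataloader()))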


@ -4,22 +4,23 @@
"""Potsdam dataset."""
import os
from typing import Any, Callable, Dict, Optional
from typing import Callable, Dict, Optional
import matplotlib.pyplot as plt
import numpy as np
import pytorch_lightning as pl
import rasterio
import torch
from matplotlib.figure import Figure
from PIL import Image
from torch import Tensor
from torch.utils.data import DataLoader
from torchvision.transforms import Compose
from ..datasets.utils import dataset_split, draw_semantic_segmentation_masks
from .geo import VisionDataset
from .utils import check_integrity, extract_archive, rgb_to_mask
from .utils import (
check_integrity,
draw_semantic_segmentation_masks,
extract_archive,
rgb_to_mask,
)
class Potsdam2D(VisionDataset):
@ -293,111 +294,3 @@ class Potsdam2D(VisionDataset):
plt.suptitle(suptitle)
return fig
class Potsdam2DDataModule(pl.LightningDataModule):
"""LightningDataModule implementation for the Potsdam2D dataset.
Uses the train/test splits from the dataset.
.. versionadded:: 0.2
"""
def __init__(
self,
root_dir: str,
batch_size: int = 64,
num_workers: int = 0,
val_split_pct: float = 0.2,
**kwargs: Any,
) -> None:
"""Initialize a LightningDataModule for Potsdam2D based DataLoaders.
Args:
root_dir: The ``root`` argument to pass to the Potsdam2D Dataset classes
batch_size: The batch size to use in all created DataLoaders
num_workers: The number of workers to use in all created DataLoaders
val_split_pct: What percentage of the dataset to use as a validation set
"""
super().__init__() # type: ignore[no-untyped-call]
self.root_dir = root_dir
self.batch_size = batch_size
self.num_workers = num_workers
self.val_split_pct = val_split_pct
def preprocess(self, sample: Dict[str, Any]) -> Dict[str, Any]:
"""Transform a single sample from the Dataset.
Args:
sample: input image dictionary
Returns:
preprocessed sample
"""
sample["image"] = sample["image"].float()
sample["image"] /= 255.0
return sample
def setup(self, stage: Optional[str] = None) -> None:
"""Initialize the main ``Dataset`` objects.
This method is called once per GPU per run.
Args:
stage: stage to set up
"""
transforms = Compose([self.preprocess])
dataset = Potsdam2D(self.root_dir, "train", transforms=transforms)
if self.val_split_pct > 0.0:
self.train_dataset, self.val_dataset, _ = dataset_split(
dataset, val_pct=self.val_split_pct, test_pct=0.0
)
else:
self.train_dataset = dataset # type: ignore[assignment]
self.val_dataset = None # type: ignore[assignment]
self.test_dataset = Potsdam2D(self.root_dir, "test", transforms=transforms)
def train_dataloader(self) -> DataLoader[Any]:
"""Return a DataLoader for training.
Returns:
training data loader
"""
return DataLoader(
self.train_dataset,
batch_size=self.batch_size,
num_workers=self.num_workers,
shuffle=True,
)
def val_dataloader(self) -> DataLoader[Any]:
"""Return a DataLoader for validation.
Returns:
validation data loader
"""
if self.val_split_pct == 0.0:
return self.train_dataloader()
else:
return DataLoader(
self.val_dataset,
batch_size=self.batch_size,
num_workers=self.num_workers,
shuffle=False,
)
def test_dataloader(self) -> DataLoader[Any]:
"""Return a DataLoader for testing.
Returns:
testing data loader
"""
return DataLoader(
self.test_dataset,
batch_size=self.batch_size,
num_workers=self.num_workers,
shuffle=False,
)
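A short usage sketch (editor's illustration; the path is hypothetical):

from torchgeo.datamodules import Potsdam2DDataModule

dm = Potsdam2DDataModule(root_dir="data/potsdam", batch_size=16, val_split_pct=0.2)
dm.setup()  # carves a 20% validation subset out of the train split
# with val_split_pct=0.0 the val loader falls back to the train loader, as above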


@ -4,23 +4,15 @@
"""RESISC45 dataset."""
import os
from typing import Any, Callable, Dict, Optional, cast
from typing import Callable, Dict, Optional, cast
import matplotlib.pyplot as plt
import numpy as np
import pytorch_lightning as pl
import torch
from torch import Tensor
from torch.utils.data import DataLoader
from torchvision.transforms import Compose, Normalize
from .geo import VisionClassificationDataset
from .utils import download_url, extract_archive
# https://github.com/pytorch/pytorch/issues/60979
# https://github.com/pytorch/pytorch/pull/61045
DataLoader.__module__ = "torch.utils.data"
class RESISC45(VisionClassificationDataset):
"""RESISC45 dataset.
@ -288,109 +280,3 @@ class RESISC45(VisionClassificationDataset):
if suptitle is not None:
plt.suptitle(suptitle)
return fig
class RESISC45DataModule(pl.LightningDataModule):
"""LightningDataModule implementation for the RESISC45 dataset.
Uses the train/val/test splits from the dataset.
"""
band_means = torch.tensor( # type: ignore[attr-defined]
[0.36801773, 0.38097873, 0.343583]
)
band_stds = torch.tensor( # type: ignore[attr-defined]
[0.14540215, 0.13558227, 0.13203649]
)
def __init__(
self, root_dir: str, batch_size: int = 64, num_workers: int = 0, **kwargs: Any
) -> None:
"""Initialize a LightningDataModule for RESISC45 based DataLoaders.
Args:
root_dir: The ``root`` argument to pass to the RESISC45 Dataset classes
batch_size: The batch size to use in all created DataLoaders
num_workers: The number of workers to use in all created DataLoaders
"""
super().__init__() # type: ignore[no-untyped-call]
self.root_dir = root_dir
self.batch_size = batch_size
self.num_workers = num_workers
self.norm = Normalize(self.band_means, self.band_stds)
def preprocess(self, sample: Dict[str, Any]) -> Dict[str, Any]:
"""Transform a single sample from the Dataset.
Args:
sample: input image dictionary
Returns:
preprocessed sample
"""
sample["image"] = sample["image"].float()
sample["image"] /= 255.0
sample["image"] = self.norm(sample["image"])
return sample
def prepare_data(self) -> None:
"""Make sure that the dataset is downloaded.
This method is only called once per run.
"""
RESISC45(self.root_dir, checksum=False)
def setup(self, stage: Optional[str] = None) -> None:
"""Initialize the main ``Dataset`` objects.
This method is called once per GPU per run.
Args:
stage: stage to set up
"""
transforms = Compose([self.preprocess])
self.train_dataset = RESISC45(self.root_dir, "train", transforms=transforms)
self.val_dataset = RESISC45(self.root_dir, "val", transforms=transforms)
self.test_dataset = RESISC45(self.root_dir, "test", transforms=transforms)
def train_dataloader(self) -> DataLoader[Any]:
"""Return a DataLoader for training.
Returns:
training data loader
"""
return DataLoader(
self.train_dataset,
batch_size=self.batch_size,
num_workers=self.num_workers,
shuffle=True,
)
def val_dataloader(self) -> DataLoader[Any]:
"""Return a DataLoader for validation.
Returns:
validation data loader
"""
return DataLoader(
self.val_dataset,
batch_size=self.batch_size,
num_workers=self.num_workers,
shuffle=False,
)
def test_dataloader(self) -> DataLoader[Any]:
"""Return a DataLoader for testing.
Returns:
testing data loader
"""
return DataLoader(
self.test_dataset,
batch_size=self.batch_size,
num_workers=self.num_workers,
shuffle=False,
)
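A hedged usage sketch (the path is hypothetical; the import path assumes this move):

from torchgeo.datamodules import RESISC45DataModule

dm = RESISC45DataModule(root_dir="data/resisc45", batch_size=64)
dm.setup()
batch = next(iter(dm.train_dataloader()))  # images scaled to [0, 1], then normalized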


@ -4,23 +4,16 @@
"""SEN12MS dataset."""
import os
from typing import Any, Callable, Dict, List, Optional
from typing import Callable, Dict, List, Optional
import numpy as np
import pytorch_lightning as pl
import rasterio
import torch
from sklearn.model_selection import GroupShuffleSplit
from torch import Tensor
from torch.utils.data import DataLoader, Subset
from .geo import VisionDataset
from .utils import check_integrity
# https://github.com/pytorch/pytorch/issues/60979
# https://github.com/pytorch/pytorch/pull/61045
DataLoader.__module__ = "torch.utils.data"
class SEN12MS(VisionDataset):
"""SEN12MS dataset.
@ -246,188 +239,3 @@ class SEN12MS(VisionDataset):
if not check_integrity(filepath, md5 if self.checksum else None):
return False
return True
class SEN12MSDataModule(pl.LightningDataModule):
"""LightningDataModule implementation for the SEN12MS dataset.
Implements 80/20 geographic train/val splits and uses the test split from the
classification dataset definitions. See :func:`setup` for more details.
Uses the Simplified IGBP scheme defined in the 2020 Data Fusion Competition. See
https://arxiv.org/abs/2002.08254.
"""
#: Mapping from the IGBP class definitions to the DFC2020 classes, taken from
#: the dataloader at https://github.com/lukasliebel/dfc2020_baseline.
DFC2020_CLASS_MAPPING = torch.tensor( # type: ignore[attr-defined]
[
0, # maps 0s to 0
1, # maps 1s to 1
1, # maps 2s to 1
1, # ...
1,
1,
2,
2,
3,
3,
4,
5,
6,
7,
6,
8,
9,
10,
]
)
def __init__(
self,
root_dir: str,
seed: int,
band_set: str = "all",
batch_size: int = 64,
num_workers: int = 0,
**kwargs: Any,
) -> None:
"""Initialize a LightningDataModule for SEN12MS based DataLoaders.
Args:
root_dir: The ``root`` argument to pass to the SEN12MS Dataset classes
seed: The seed value to use when doing the sklearn based ShuffleSplit
band_set: The subset of S1/S2 bands to use. Options are: "all",
"s1", "s2-all", and "s2-reduced" where the "s2-reduced" set includes:
B2, B3, B4, B8, B11, and B12.
batch_size: The batch size to use in all created DataLoaders
num_workers: The number of workers to use in all created DataLoaders
"""
super().__init__() # type: ignore[no-untyped-call]
assert band_set in SEN12MS.BAND_SETS.keys()
self.root_dir = root_dir
self.seed = seed
self.band_set = band_set
self.band_indices = SEN12MS.BAND_SETS[band_set]
self.batch_size = batch_size
self.num_workers = num_workers
def custom_transform(self, sample: Dict[str, Any]) -> Dict[str, Any]:
"""Transform a single sample from the Dataset.
Args:
sample: dictionary containing image and mask
Returns:
preprocessed sample
"""
sample["image"] = sample["image"].float()
if self.band_set == "all":
sample["image"][:2] = sample["image"][:2].clamp(-25, 0) / -25
sample["image"][2:] = sample["image"][2:].clamp(0, 10000) / 10000
elif self.band_set == "s1":
sample["image"][:2] = sample["image"][:2].clamp(-25, 0) / -25
else:
sample["image"][:] = sample["image"][:].clamp(0, 10000) / 10000
sample["mask"] = sample["mask"][0, :, :].long()
sample["mask"] = torch.take( # type: ignore[attr-defined]
self.DFC2020_CLASS_MAPPING, sample["mask"]
)
return sample
def setup(self, stage: Optional[str] = None) -> None:
"""Create the train/val/test splits based on the original Dataset objects.
The splits should be done here vs. in :func:`__init__` per the docs:
https://pytorch-lightning.readthedocs.io/en/latest/extensions/datamodules.html#setup.
We split samples between train and val geographically with proportions of 80/20.
This mimics the geographic test set split.
Args:
stage: stage to set up
"""
season_to_int = {"winter": 0, "spring": 1000, "summer": 2000, "fall": 3000}
self.all_train_dataset = SEN12MS(
self.root_dir,
split="train",
bands=self.band_indices,
transforms=self.custom_transform,
checksum=False,
)
self.all_test_dataset = SEN12MS(
self.root_dir,
split="test",
bands=self.band_indices,
transforms=self.custom_transform,
checksum=False,
)
# Patch filenames look like "ROIs{num}_{season}_s2_{scene_id}_p{patch_id}.tif".
# Each patch belongs to the scene uniquely identified by its (season, scene_id)
# tuple. Because the largest scene_id is 149, we can simply give each season a
# large offset and represent a `unique_scene_id` as `season_id + scene_id`.
scenes = []
for scene_fn in self.all_train_dataset.ids:
parts = scene_fn.split("_")
season_id = season_to_int[parts[1]]
scene_id = int(parts[3])
scenes.append(season_id + scene_id)
train_indices, val_indices = next(
GroupShuffleSplit(test_size=0.2, n_splits=2, random_state=self.seed).split(
scenes, groups=scenes
)
)
self.train_dataset = Subset(self.all_train_dataset, train_indices)
self.val_dataset = Subset(self.all_train_dataset, val_indices)
self.test_dataset = Subset(
self.all_test_dataset, range(len(self.all_test_dataset))
)
def train_dataloader(self) -> DataLoader[Any]:
"""Return a DataLoader for training.
Returns:
training data loader
"""
return DataLoader(
self.train_dataset,
batch_size=self.batch_size,
num_workers=self.num_workers,
shuffle=True,
)
def val_dataloader(self) -> DataLoader[Any]:
"""Return a DataLoader for validation.
Returns:
validation data loader
"""
return DataLoader(
self.val_dataset,
batch_size=self.batch_size,
num_workers=self.num_workers,
shuffle=False,
)
def test_dataloader(self) -> DataLoader[Any]:
"""Return a DataLoader for testing.
Returns:
testing data loader
"""
return DataLoader(
self.test_dataset,
batch_size=self.batch_size,
num_workers=self.num_workers,
shuffle=False,
)
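A hedged usage sketch (editor's illustration; the path is hypothetical):

from torchgeo.datamodules import SEN12MSDataModule

dm = SEN12MSDataModule(root_dir="data/sen12ms", seed=0, band_set="s2-reduced")
dm.setup()  # 80/20 geographic train/val split, grouped by (season, scene_id)
batch = next(iter(dm.train_dataloader()))  # masks remapped to DFC2020 classes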


@ -4,23 +4,16 @@
"""So2Sat dataset."""
import os
from typing import Any, Callable, Dict, Optional, cast
from typing import Callable, Dict, Optional, cast
import matplotlib.pyplot as plt
import numpy as np
import pytorch_lightning as pl
import torch
from torch import Tensor
from torch.utils.data import DataLoader
from torchvision.transforms import Compose
from .geo import VisionDataset
from .utils import check_integrity, percentile_normalization
# https://github.com/pytorch/pytorch/issues/60979
# https://github.com/pytorch/pytorch/pull/61045
DataLoader.__module__ = "torch.utils.data"
class So2Sat(VisionDataset):
"""So2Sat dataset.
@ -250,211 +243,3 @@ class So2Sat(VisionDataset):
if suptitle is not None:
plt.suptitle(suptitle)
return fig
class So2SatDataModule(pl.LightningDataModule):
"""LightningDataModule implementation for the So2Sat dataset.
Uses the train/val/test splits from the dataset.
"""
band_means = torch.tensor( # type: ignore[attr-defined]
[
-3.591224256609313e-05,
-7.658561276843396e-06,
5.9373857475971184e-05,
2.5166231537121083e-05,
0.04420110659759328,
0.25761027084996196,
0.0007556743372573258,
0.0013503466830024448,
0.12375696117681859,
0.1092774636368323,
0.1010855203267882,
0.1142398616114001,
0.1592656692023089,
0.18147236008771792,
0.1745740312291377,
0.19501607349635292,
0.15428468872076637,
0.10905050699570007,
]
).reshape(18, 1, 1)
band_stds = torch.tensor( # type: ignore[attr-defined]
[
0.17555201137417686,
0.17556463274968204,
0.45998793417834255,
0.455988755730148,
2.8559909213125763,
8.324800606439833,
2.4498757382563103,
1.4647352984509094,
0.03958795985905458,
0.047778262752410296,
0.06636616706371974,
0.06358874912497474,
0.07744387147984592,
0.09101635085921553,
0.09218466562387101,
0.10164581233948201,
0.09991773043519253,
0.08780632509122865,
]
).reshape(18, 1, 1)
# this reorders the bands to put S2 RGB first, then remainder of S2, then S1
reindex_to_rgb_first = [
10,
9,
8,
11,
12,
13,
14,
15,
16,
17,
# 0,
# 1,
# 2,
# 3,
# 4,
# 5,
# 6,
# 7,
]
def __init__(
self,
root_dir: str,
batch_size: int = 64,
num_workers: int = 0,
bands: str = "rgb",
unsupervised_mode: bool = False,
**kwargs: Any,
) -> None:
"""Initialize a LightningDataModule for So2Sat based DataLoaders.
Args:
root_dir: The ``root`` argument to pass to the So2Sat Dataset classes
batch_size: The batch size to use in all created DataLoaders
num_workers: The number of workers to use in all created DataLoaders
bands: Either "rgb" or "s2"
unsupervised_mode: Makes the train dataloader return imagery from the train,
val, and test sets
"""
super().__init__() # type: ignore[no-untyped-call]
self.root_dir = root_dir
self.batch_size = batch_size
self.num_workers = num_workers
self.bands = bands
self.unsupervised_mode = unsupervised_mode
def preprocess(self, sample: Dict[str, Any]) -> Dict[str, Any]:
"""Transform a single sample from the Dataset.
Args:
sample: dictionary containing image
Returns:
preprocessed sample
"""
# sample["image"] = (sample["image"] - self.band_means) / self.band_stds
sample["image"] = sample["image"].float()
sample["image"] = sample["image"][self.reindex_to_rgb_first, :, :]
if self.bands == "rgb":
sample["image"] = sample["image"][:3, :, :]
return sample
def prepare_data(self) -> None:
"""Make sure that the dataset is downloaded.
This method is only called once per run.
"""
So2Sat(self.root_dir, checksum=False)
def setup(self, stage: Optional[str] = None) -> None:
"""Initialize the main ``Dataset`` objects.
This method is called once per GPU per run.
Args:
stage: stage to set up
"""
train_transforms = Compose([self.preprocess])
val_test_transforms = self.preprocess
if not self.unsupervised_mode:
self.train_dataset = So2Sat(
self.root_dir, split="train", transforms=train_transforms
)
self.val_dataset = So2Sat(
self.root_dir, split="validation", transforms=val_test_transforms
)
self.test_dataset = So2Sat(
self.root_dir, split="test", transforms=val_test_transforms
)
else:
temp_train = So2Sat(
self.root_dir, split="train", transforms=train_transforms
)
self.val_dataset = So2Sat(
self.root_dir, split="validation", transforms=train_transforms
)
self.test_dataset = So2Sat(
self.root_dir, split="test", transforms=train_transforms
)
self.train_dataset = cast(
So2Sat, temp_train + self.val_dataset + self.test_dataset
)
def train_dataloader(self) -> DataLoader[Any]:
"""Return a DataLoader for training.
Returns:
training data loader
"""
return DataLoader(
self.train_dataset,
batch_size=self.batch_size,
num_workers=self.num_workers,
shuffle=True,
)
def val_dataloader(self) -> DataLoader[Any]:
"""Return a DataLoader for validation.
Returns:
validation data loader
"""
return DataLoader(
self.val_dataset,
batch_size=self.batch_size,
num_workers=self.num_workers,
shuffle=False,
)
def test_dataloader(self) -> DataLoader[Any]:
"""Return a DataLoader for testing.
Returns:
testing data loader
"""
return DataLoader(
self.test_dataset,
batch_size=self.batch_size,
num_workers=self.num_workers,
shuffle=False,
)
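A hedged usage sketch (not part of this commit; the path is hypothetical):

from torchgeo.datamodules import So2SatDataModule

dm = So2SatDataModule(root_dir="data/so2sat", bands="rgb")
dm.setup()
# bands="rgb" keeps only the first three reindexed channels (the S2 RGB bands);
# unsupervised_mode=True would instead train on train+validation+test imagery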


@ -24,8 +24,8 @@ from rasterio.features import rasterize
from rasterio.transform import Affine
from torch import Tensor
from torchgeo.datasets.geo import VisionDataset
from torchgeo.datasets.utils import (
from .geo import VisionDataset
from .utils import (
check_integrity,
download_radiant_mlhub_collection,
extract_archive,


@ -3,24 +3,15 @@
"""UC Merced dataset."""
import os
from typing import Any, Callable, Dict, Optional, cast
from typing import Callable, Dict, Optional, cast
import matplotlib.pyplot as plt
import numpy as np
import pytorch_lightning as pl
import torch
import torchvision
from torch import Tensor
from torch.utils.data import DataLoader
from torchvision.transforms import Compose, Normalize
from .geo import VisionClassificationDataset
from .utils import check_integrity, download_url, extract_archive
# https://github.com/pytorch/pytorch/issues/60979
# https://github.com/pytorch/pytorch/pull/61045
DataLoader.__module__ = "torch.utils.data"
class UCMerced(VisionClassificationDataset):
"""UC Merced dataset.
@ -251,110 +242,3 @@ class UCMerced(VisionClassificationDataset):
if suptitle is not None:
plt.suptitle(suptitle)
return fig
class UCMercedDataModule(pl.LightningDataModule):
"""LightningDataModule implementation for the UC Merced dataset.
Uses random train/val/test splits.
"""
band_means = torch.tensor([0, 0, 0]) # type: ignore[attr-defined]
band_stds = torch.tensor([1, 1, 1]) # type: ignore[attr-defined]
def __init__(
self, root_dir: str, batch_size: int = 64, num_workers: int = 0, **kwargs: Any
) -> None:
"""Initialize a LightningDataModule for UCMerced based DataLoaders.
Args:
root_dir: The ``root`` argument to pass to the UCMerced Dataset classes
batch_size: The batch size to use in all created DataLoaders
num_workers: The number of workers to use in all created DataLoaders
"""
super().__init__() # type: ignore[no-untyped-call]
self.root_dir = root_dir
self.batch_size = batch_size
self.num_workers = num_workers
self.norm = Normalize(self.band_means, self.band_stds)
def preprocess(self, sample: Dict[str, Any]) -> Dict[str, Any]:
"""Transform a single sample from the Dataset.
Args:
sample: dictionary containing image
Returns:
preprocessed sample
"""
sample["image"] = sample["image"].float()
sample["image"] /= 255.0
c, h, w = sample["image"].shape
if h != 256 or w != 256:
sample["image"] = torchvision.transforms.functional.resize(
sample["image"], size=(256, 256)
)
sample["image"] = self.norm(sample["image"])
return sample
def prepare_data(self) -> None:
"""Make sure that the dataset is downloaded.
This method is only called once per run.
"""
UCMerced(self.root_dir, download=False, checksum=False)
def setup(self, stage: Optional[str] = None) -> None:
"""Initialize the main ``Dataset`` objects.
This method is called once per GPU per run.
Args:
stage: stage to set up
"""
transforms = Compose([self.preprocess])
self.train_dataset = UCMerced(self.root_dir, "train", transforms=transforms)
self.val_dataset = UCMerced(self.root_dir, "val", transforms=transforms)
self.test_dataset = UCMerced(self.root_dir, "test", transforms=transforms)
def train_dataloader(self) -> DataLoader[Any]:
"""Return a DataLoader for training.
Returns:
training data loader
"""
return DataLoader(
self.train_dataset,
batch_size=self.batch_size,
num_workers=self.num_workers,
shuffle=True,
)
def val_dataloader(self) -> DataLoader[Any]:
"""Return a DataLoader for validation.
Returns:
validation data loader
"""
return DataLoader(
self.val_dataset,
batch_size=self.batch_size,
num_workers=self.num_workers,
shuffle=False,
)
def test_dataloader(self) -> DataLoader[Any]:
"""Return a DataLoader for testing.
Returns:
testing data loader
"""
return DataLoader(
self.test_dataset,
batch_size=self.batch_size,
num_workers=self.num_workers,
shuffle=False,
)
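A hedged usage sketch (editor's illustration; the path is hypothetical):

from torchgeo.datamodules import UCMercedDataModule

dm = UCMercedDataModule(root_dir="data/ucmerced", batch_size=64)
dm.setup()
batch = next(iter(dm.train_dataloader()))  # images resized to 256x256 when needed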


@ -32,7 +32,6 @@ import numpy as np
import rasterio
import torch
from torch import Tensor
from torch.utils.data import Dataset, Subset, random_split
from torchvision.datasets.utils import check_integrity, download_url
from torchvision.utils import draw_segmentation_masks
@ -48,7 +47,6 @@ __all__ = (
"concat_samples",
"merge_samples",
"rasterio_loader",
"dataset_split",
"sort_sentinel2_bands",
"draw_semantic_segmentation_masks",
"rgb_to_mask",
@ -519,31 +517,6 @@ def rasterio_loader(path: str) -> np.ndarray: # type: ignore[type-arg]
return array
def dataset_split(
dataset: Dataset[Any], val_pct: float, test_pct: Optional[float] = None
) -> List[Subset[Any]]:
"""Split a torch Dataset into train/val/test sets.
If ``test_pct`` is not set then only train and validation splits are returned.
Args:
dataset: dataset to be split into train/val or train/val/test subsets
val_pct: percentage of samples to be in validation set
test_pct: (Optional) percentage of samples to be in test set
Returns:
a list of the subset datasets. Either [train, val] or [train, val, test]
"""
if test_pct is None:
val_length = int(len(dataset) * val_pct)
train_length = len(dataset) - val_length
return random_split(dataset, [train_length, val_length])
else:
val_length = int(len(dataset) * val_pct)
test_length = int(len(dataset) * test_pct)
train_length = len(dataset) - (val_length + test_length)
return random_split(dataset, [train_length, val_length, test_length])
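# Hedged illustration (editor's addition, not part of this commit): split
# lengths are truncated with int(), so the train split takes the remainder.
from torch.utils.data import TensorDataset
_ds = TensorDataset(torch.arange(100))
_train, _val, _test = dataset_split(_ds, val_pct=0.2, test_pct=0.1)
assert (len(_train), len(_val), len(_test)) == (70, 20, 10)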
def sort_sentinel2_bands(x: str) -> str:
"""Sort Sentinel-2 band files in the correct order."""
x = os.path.basename(x).split("_")[-1]


@ -4,21 +4,22 @@
"""Vaihingen dataset."""
import os
from typing import Any, Callable, Dict, Optional
from typing import Callable, Dict, Optional
import matplotlib.pyplot as plt
import numpy as np
import pytorch_lightning as pl
import torch
from matplotlib.figure import Figure
from PIL import Image
from torch import Tensor
from torch.utils.data import DataLoader
from torchvision.transforms import Compose
from ..datasets.utils import dataset_split, draw_semantic_segmentation_masks
from .geo import VisionDataset
from .utils import check_integrity, extract_archive, rgb_to_mask
from .utils import (
check_integrity,
draw_semantic_segmentation_masks,
extract_archive,
rgb_to_mask,
)
class Vaihingen2D(VisionDataset):
@ -293,111 +294,3 @@ class Vaihingen2D(VisionDataset):
plt.suptitle(suptitle)
return fig
class Vaihingen2DDataModule(pl.LightningDataModule):
"""LightningDataModule implementation for the Vaihingen2D dataset.
Uses the train/test splits from the dataset.
.. versionadded:: 0.2
"""
def __init__(
self,
root_dir: str,
batch_size: int = 64,
num_workers: int = 0,
val_split_pct: float = 0.2,
**kwargs: Any,
) -> None:
"""Initialize a LightningDataModule for Vaihingen2D based DataLoaders.
Args:
root_dir: The ``root`` argument to pass to the Vaihingen Dataset classes
batch_size: The batch size to use in all created DataLoaders
num_workers: The number of workers to use in all created DataLoaders
val_split_pct: What percentage of the dataset to use as a validation set
"""
super().__init__() # type: ignore[no-untyped-call]
self.root_dir = root_dir
self.batch_size = batch_size
self.num_workers = num_workers
self.val_split_pct = val_split_pct
def preprocess(self, sample: Dict[str, Any]) -> Dict[str, Any]:
"""Transform a single sample from the Dataset.
Args:
sample: input image dictionary
Returns:
preprocessed sample
"""
sample["image"] = sample["image"].float()
sample["image"] /= 255.0
return sample
def setup(self, stage: Optional[str] = None) -> None:
"""Initialize the main ``Dataset`` objects.
This method is called once per GPU per run.
Args:
stage: stage to set up
"""
transforms = Compose([self.preprocess])
dataset = Vaihingen2D(self.root_dir, "train", transforms=transforms)
if self.val_split_pct > 0.0:
self.train_dataset, self.val_dataset, _ = dataset_split(
dataset, val_pct=self.val_split_pct, test_pct=0.0
)
else:
self.train_dataset = dataset # type: ignore[assignment]
self.val_dataset = None # type: ignore[assignment]
self.test_dataset = Vaihingen2D(self.root_dir, "test", transforms=transforms)
def train_dataloader(self) -> DataLoader[Any]:
"""Return a DataLoader for training.
Returns:
training data loader
"""
return DataLoader(
self.train_dataset,
batch_size=self.batch_size,
num_workers=self.num_workers,
shuffle=True,
)
def val_dataloader(self) -> DataLoader[Any]:
"""Return a DataLoader for validation.
Returns:
validation data loader
"""
if self.val_split_pct == 0.0:
return self.train_dataloader()
else:
return DataLoader(
self.val_dataset,
batch_size=self.batch_size,
num_workers=self.num_workers,
shuffle=False,
)
def test_dataloader(self) -> DataLoader[Any]:
"""Return a DataLoader for testing.
Returns:
testing data loader
"""
return DataLoader(
self.test_dataset,
batch_size=self.batch_size,
num_workers=self.num_workers,
shuffle=False,
)
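A hedged usage sketch (the path is hypothetical):

from torchgeo.datamodules import Vaihingen2DDataModule

dm = Vaihingen2DDataModule(root_dir="data/vaihingen", batch_size=16, val_split_pct=0.1)
dm.setup()  # 90/10 train/val split of the train set; the test split comes from the dataset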


@ -5,20 +5,16 @@
import glob
import os
from typing import Any, Callable, Dict, List, Optional
from typing import Callable, Dict, List, Optional
import matplotlib.pyplot as plt
import numpy as np
import pytorch_lightning as pl
import torch
from PIL import Image
from torch import Tensor
from torch.utils.data import DataLoader
from torchvision.transforms import Compose
from ..datasets.utils import dataset_split, draw_semantic_segmentation_masks
from .geo import VisionDataset
from .utils import check_integrity, extract_archive
from .utils import check_integrity, draw_semantic_segmentation_masks, extract_archive
class XView2(VisionDataset):
@ -282,111 +278,3 @@ class XView2(VisionDataset):
plt.suptitle(suptitle)
return fig
class XView2DataModule(pl.LightningDataModule):
"""LightningDataModule implementation for the xView2 dataset.
Uses the train/val/test splits from the dataset.
.. versionadded:: 0.2
"""
def __init__(
self,
root_dir: str,
batch_size: int = 64,
num_workers: int = 0,
val_split_pct: float = 0.2,
**kwargs: Any,
) -> None:
"""Initialize a LightningDataModule for xView2 based DataLoaders.
Args:
root_dir: The ``root`` argument to pass to the xView2 Dataset classes
batch_size: The batch size to use in all created DataLoaders
num_workers: The number of workers to use in all created DataLoaders
val_split_pct: What percentage of the dataset to use as a validation set
"""
super().__init__() # type: ignore[no-untyped-call]
self.root_dir = root_dir
self.batch_size = batch_size
self.num_workers = num_workers
self.val_split_pct = val_split_pct
def preprocess(self, sample: Dict[str, Any]) -> Dict[str, Any]:
"""Transform a single sample from the Dataset.
Args:
sample: input image dictionary
Returns:
preprocessed sample
"""
sample["image"] = sample["image"].float()
sample["image"] /= 255.0
return sample
def setup(self, stage: Optional[str] = None) -> None:
"""Initialize the main ``Dataset`` objects.
This method is called once per GPU per run.
Args:
stage: stage to set up
"""
transforms = Compose([self.preprocess])
dataset = XView2(self.root_dir, "train", transforms=transforms)
if self.val_split_pct > 0.0:
self.train_dataset, self.val_dataset, _ = dataset_split(
dataset, val_pct=self.val_split_pct, test_pct=0.0
)
else:
self.train_dataset = dataset # type: ignore[assignment]
self.val_dataset = None # type: ignore[assignment]
self.test_dataset = XView2(self.root_dir, "test", transforms=transforms)
def train_dataloader(self) -> DataLoader[Any]:
"""Return a DataLoader for training.
Returns:
training data loader
"""
return DataLoader(
self.train_dataset,
batch_size=self.batch_size,
num_workers=self.num_workers,
shuffle=True,
)
def val_dataloader(self) -> DataLoader[Any]:
"""Return a DataLoader for validation.
Returns:
validation data loader
"""
if self.val_split_pct == 0.0:
return self.train_dataloader()
else:
return DataLoader(
self.val_dataset,
batch_size=self.batch_size,
num_workers=self.num_workers,
shuffle=False,
)
def test_dataloader(self) -> DataLoader[Any]:
"""Return a DataLoader for testing.
Returns:
testing data loader
"""
return DataLoader(
self.test_dataset,
batch_size=self.batch_size,
num_workers=self.num_workers,
shuffle=False,
)
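A hedged end-to-end sketch (editor's illustration; ``task`` stands for a hypothetical LightningModule and the path is made up):

from torchgeo.datamodules import XView2DataModule
import pytorch_lightning as pl

dm = XView2DataModule(root_dir="data/xview2", batch_size=16)
trainer = pl.Trainer(max_epochs=1)
# trainer.fit(task, datamodule=dm)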


@ -10,9 +10,7 @@ from typing import Iterator, List, Optional, Tuple, Union
from rtree.index import Index, Property
from torch.utils.data import Sampler
from torchgeo.datasets.geo import GeoDataset
from torchgeo.datasets.utils import BoundingBox
from ..datasets import BoundingBox, GeoDataset
from .utils import _to_tuple, get_random_bounding_box
# https://github.com/pytorch/pytorch/issues/60979


@ -10,9 +10,7 @@ from typing import Iterator, Optional, Tuple, Union
from rtree.index import Index, Property
from torch.utils.data import Sampler
from torchgeo.datasets.geo import GeoDataset
from torchgeo.datasets.utils import BoundingBox
from ..datasets import BoundingBox, GeoDataset
from .utils import _to_tuple, get_random_bounding_box
# https://github.com/pytorch/pytorch/issues/60979


@ -6,7 +6,7 @@
import random
from typing import Tuple, Union
from torchgeo.datasets.utils import BoundingBox
from ..datasets import BoundingBox
def _to_tuple(value: Union[Tuple[float, float], float]) -> Tuple[float, float]: