зеркало из https://github.com/microsoft/torchgeo.git
Move DataModules to torchgeo.datamodules (#321)
* Move DataModules to torchgeo.datamodules * Clean up local imports
This commit is contained in:
Родитель
5a57d6c9a3
Коммит
cbebc1e0db
|
@ -0,0 +1,105 @@
|
|||
torchgeo.datamodules
|
||||
====================
|
||||
|
||||
.. module:: torchgeo.datamodules
|
||||
|
||||
Geospatial DataModules
|
||||
----------------------
|
||||
|
||||
Chesapeake Bay High-Resolution Land Cover Project
|
||||
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
|
||||
|
||||
.. autoclass:: ChesapeakeCVPRDataModule
|
||||
|
||||
National Agriculture Imagery Program (NAIP)
|
||||
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
|
||||
|
||||
.. autoclass:: NAIPChesapeakeDataModule
|
||||
|
||||
Non-geospatial DataModules
|
||||
--------------------------
|
||||
|
||||
BigEarthNet
|
||||
^^^^^^^^^^^
|
||||
|
||||
.. autoclass:: BigEarthNetDataModule
|
||||
|
||||
Cars Overhead With Context (COWC)
|
||||
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
|
||||
|
||||
.. autoclass:: COWCCountingDataModule
|
||||
|
||||
ETCI2021 Flood Detection
|
||||
^^^^^^^^^^^^^^^^^^^^^^^^
|
||||
|
||||
.. autoclass:: ETCI2021DataModule
|
||||
|
||||
EuroSAT
|
||||
^^^^^^^
|
||||
|
||||
.. autoclass:: EuroSATDataModule
|
||||
|
||||
FAIR1M (Fine-grAined object recognItion in high-Resolution imagery)
|
||||
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
|
||||
|
||||
.. autoclass:: FAIR1MDataModule
|
||||
|
||||
LandCover.ai (Land Cover from Aerial Imagery)
|
||||
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
|
||||
|
||||
.. autoclass:: LandCoverAIDataModule
|
||||
|
||||
LoveDA (Land-cOVEr Domain Adaptive semantic segmentation)
|
||||
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
|
||||
|
||||
.. autoclass:: LoveDADataModule
|
||||
|
||||
NASA Marine Debris
|
||||
^^^^^^^^^^^^^^^^^^
|
||||
|
||||
.. autoclass:: NASAMarineDebrisDataModule
|
||||
|
||||
OSCD (Onera Satellite Change Detection)
|
||||
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
|
||||
|
||||
.. autoclass:: OSCDDataModule
|
||||
|
||||
Potsdam
|
||||
^^^^^^^
|
||||
|
||||
.. autoclass:: Potsdam2DDataModule
|
||||
|
||||
RESISC45 (Remote Sensing Image Scene Classification)
|
||||
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
|
||||
|
||||
.. autoclass:: RESISC45DataModule
|
||||
|
||||
SEN12MS
|
||||
^^^^^^^
|
||||
|
||||
.. autoclass:: SEN12MSDataModule
|
||||
|
||||
So2Sat
|
||||
^^^^^^
|
||||
|
||||
.. autoclass:: So2SatDataModule
|
||||
|
||||
Tropical Cyclone Wind Estimation Competition
|
||||
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
|
||||
|
||||
.. autoclass:: CycloneDataModule
|
||||
|
||||
UC Merced
|
||||
^^^^^^^^^
|
||||
|
||||
.. autoclass:: UCMercedDataModule
|
||||
|
||||
Vaihingen
|
||||
^^^^^^^^^
|
||||
|
||||
.. autoclass:: Vaihingen2DDataModule
|
||||
|
||||
xView2
|
||||
^^^^^^
|
||||
|
||||
.. autoclass:: XView2DataModule
|
|
@ -31,7 +31,6 @@ Chesapeake Bay High-Resolution Land Cover Project
|
|||
.. autoclass:: ChesapeakeVA
|
||||
.. autoclass:: ChesapeakeWV
|
||||
.. autoclass:: ChesapeakeCVPR
|
||||
.. autoclass:: ChesapeakeCVPRDataModule
|
||||
|
||||
Cropland Data Layer (CDL)
|
||||
^^^^^^^^^^^^^^^^^^^^^^^^^
|
||||
|
@ -57,7 +56,6 @@ National Agriculture Imagery Program (NAIP)
|
|||
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
|
||||
|
||||
.. autoclass:: NAIP
|
||||
.. autoclass:: NAIPChesapeakeDataModule
|
||||
|
||||
Sentinel
|
||||
^^^^^^^^
|
||||
|
@ -86,7 +84,6 @@ BigEarthNet
|
|||
^^^^^^^^^^^
|
||||
|
||||
.. autoclass:: BigEarthNet
|
||||
.. autoclass:: BigEarthNetDataModule
|
||||
|
||||
Cars Overhead With Context (COWC)
|
||||
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
|
||||
|
@ -94,7 +91,6 @@ Cars Overhead With Context (COWC)
|
|||
.. autoclass:: COWC
|
||||
.. autoclass:: COWCCounting
|
||||
.. autoclass:: COWCDetection
|
||||
.. autoclass:: COWCCountingDataModule
|
||||
|
||||
CV4A Kenya Crop Type Competition
|
||||
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
|
||||
|
@ -105,19 +101,16 @@ ETCI2021 Flood Detection
|
|||
^^^^^^^^^^^^^^^^^^^^^^^^
|
||||
|
||||
.. autoclass:: ETCI2021
|
||||
.. autoclass:: ETCI2021DataModule
|
||||
|
||||
EuroSAT
|
||||
^^^^^^^^^^^^^^^^^^^^^^^^
|
||||
^^^^^^^
|
||||
|
||||
.. autoclass:: EuroSAT
|
||||
.. autoclass:: EuroSATDataModule
|
||||
|
||||
FAIR1M (Fine-grAined object recognItion in high-Resolution imagery)
|
||||
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
|
||||
|
||||
.. autoclass:: FAIR1M
|
||||
.. autoclass:: FAIR1MDataModule
|
||||
|
||||
GID-15 (Gaofen Image Dataset)
|
||||
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
|
||||
|
@ -133,7 +126,6 @@ LandCover.ai (Land Cover from Aerial Imagery)
|
|||
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
|
||||
|
||||
.. autoclass:: LandCoverAI
|
||||
.. autoclass:: LandCoverAIDataModule
|
||||
|
||||
LEVIR-CD+ (LEVIR Change Detection +)
|
||||
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
|
||||
|
@ -144,19 +136,16 @@ LoveDA (Land-cOVEr Domain Adaptive semantic segmentation)
|
|||
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
|
||||
|
||||
.. autoclass:: LoveDA
|
||||
.. autoclass:: LoveDADataModule
|
||||
|
||||
NASA Marine Debris
|
||||
^^^^^^^^^^^^^^^^^^
|
||||
|
||||
.. autoclass:: NASAMarineDebris
|
||||
.. autoclass:: NASAMarineDebrisDataModule
|
||||
|
||||
OSCD (Onera Satellite Change Detection)
|
||||
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
|
||||
|
||||
.. autoclass:: OSCD
|
||||
.. autoclass:: OSCDDataModule
|
||||
|
||||
PatternNet
|
||||
^^^^^^^^^^
|
||||
|
@ -167,13 +156,11 @@ Potsdam
|
|||
^^^^^^^
|
||||
|
||||
.. autoclass:: Potsdam2D
|
||||
.. autoclass:: Potsdam2DDataModule
|
||||
|
||||
RESISC45 (Remote Sensing Image Scene Classification)
|
||||
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
|
||||
|
||||
.. autoclass:: RESISC45
|
||||
.. autoclass:: RESISC45DataModule
|
||||
|
||||
Seasonal Contrast
|
||||
^^^^^^^^^^^^^^^^^
|
||||
|
@ -184,13 +171,11 @@ SEN12MS
|
|||
^^^^^^^
|
||||
|
||||
.. autoclass:: SEN12MS
|
||||
.. autoclass:: SEN12MSDataModule
|
||||
|
||||
So2Sat
|
||||
^^^^^^
|
||||
|
||||
.. autoclass:: So2Sat
|
||||
.. autoclass:: So2SatDataModule
|
||||
|
||||
SpaceNet
|
||||
^^^^^^^^
|
||||
|
@ -206,30 +191,26 @@ Tropical Cyclone Wind Estimation Competition
|
|||
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
|
||||
|
||||
.. autoclass:: TropicalCycloneWindEstimation
|
||||
.. autoclass:: CycloneDataModule
|
||||
|
||||
UC Merced
|
||||
^^^^^^^^^
|
||||
|
||||
.. autoclass:: UCMerced
|
||||
|
||||
Vaihingen
|
||||
^^^^^^^^^
|
||||
|
||||
.. autoclass:: Vaihingen2D
|
||||
.. autoclass:: Vaihingen2DDataModule
|
||||
|
||||
NWPU VHR-10
|
||||
^^^^^^^^^^^
|
||||
|
||||
.. autoclass:: VHR10
|
||||
|
||||
UC Merced
|
||||
^^^^^^^^^
|
||||
|
||||
.. autoclass:: UCMerced
|
||||
.. autoclass:: UCMercedDataModule
|
||||
|
||||
xView2
|
||||
^^^^^^
|
||||
|
||||
.. autoclass:: XView2
|
||||
.. autoclass:: XView2DataModule
|
||||
|
||||
ZueriCrop
|
||||
^^^^^^^^^
|
||||
|
|
|
@ -15,6 +15,7 @@ torchgeo
|
|||
:maxdepth: 2
|
||||
:caption: Package Reference
|
||||
|
||||
api/datamodules
|
||||
api/datasets
|
||||
api/models
|
||||
api/samplers
|
||||
|
|
|
@ -11,7 +11,7 @@ import os
|
|||
import pytorch_lightning as pl
|
||||
import torch
|
||||
|
||||
from torchgeo.datasets import ChesapeakeCVPRDataModule
|
||||
from torchgeo.datamodules import ChesapeakeCVPRDataModule
|
||||
from torchgeo.trainers.chesapeake import ChesapeakeCVPRSegmentationTask
|
||||
|
||||
ALL_TEST_SPLITS = [["de-val"], ["pa-test"], ["ny-test"], ["pa-test", "ny-test"]]
|
||||
|
|
|
@ -73,7 +73,7 @@ strict_equality = true
|
|||
|
||||
[tool.pydocstyle]
|
||||
convention = "google"
|
||||
match_dir = "(datasets|models|samplers|torchgeo|trainers|transforms)"
|
||||
match_dir = "(datamodules|datasets|models|samplers|torchgeo|trainers|transforms)"
|
||||
|
||||
[tool.pytest.ini_options]
|
||||
# Skip slow tests by default
|
||||
|
|
|
@ -0,0 +1,2 @@
|
|||
# Copyright (c) Microsoft Corporation. All rights reserved.
|
||||
# Licensed under the MIT License.
|
|
@ -0,0 +1,32 @@
|
|||
# Copyright (c) Microsoft Corporation. All rights reserved.
|
||||
# Licensed under the MIT License.
|
||||
|
||||
import os
|
||||
|
||||
import pytest
|
||||
from _pytest.fixtures import SubRequest
|
||||
|
||||
from torchgeo.datamodules import BigEarthNetDataModule
|
||||
|
||||
|
||||
class TestBigEarthNetDataModule:
    """Smoke tests for the dataloaders of ``BigEarthNetDataModule``."""

    @pytest.fixture(scope="class", params=["s1", "s2", "all"])
    def datamodule(self, request: SubRequest) -> BigEarthNetDataModule:
        """Return a set-up datamodule, built once per band option in ``params``."""
        data_dir = os.path.join("tests", "data", "bigearthnet")
        # Values after the root, in order: bands, 19 classes, batch size 1,
        # 0 workers.
        module = BigEarthNetDataModule(data_dir, request.param, 19, 1, 0)
        module.prepare_data()
        module.setup()
        return module

    def test_train_dataloader(self, datamodule: BigEarthNetDataModule) -> None:
        """Pull the first item from the training dataloader."""
        loader = datamodule.train_dataloader()
        next(iter(loader))

    def test_val_dataloader(self, datamodule: BigEarthNetDataModule) -> None:
        """Pull the first item from the validation dataloader."""
        loader = datamodule.val_dataloader()
        next(iter(loader))

    def test_test_dataloader(self, datamodule: BigEarthNetDataModule) -> None:
        """Pull the first item from the test dataloader."""
        loader = datamodule.test_dataloader()
        next(iter(loader))
|
|
@ -0,0 +1,52 @@
|
|||
# Copyright (c) Microsoft Corporation. All rights reserved.
|
||||
# Licensed under the MIT License.
|
||||
|
||||
import os
|
||||
|
||||
import pytest
|
||||
import torch
|
||||
from _pytest.fixtures import SubRequest
|
||||
|
||||
from torchgeo.datamodules import ChesapeakeCVPRDataModule
|
||||
|
||||
|
||||
class TestChesapeakeCVPRDataModule:
    """Dataloader and ``nodata_check`` tests for ``ChesapeakeCVPRDataModule``."""

    @pytest.fixture(scope="class", params=[5, 7])
    def datamodule(self, request: SubRequest) -> ChesapeakeCVPRDataModule:
        """Return a set-up datamodule for each ``class_set`` value in ``params``."""
        data_dir = os.path.join("tests", "data", "chesapeake", "cvpr")
        module = ChesapeakeCVPRDataModule(
            data_dir,
            # The same "de-test" split is supplied for all three split arguments.
            ["de-test"],
            ["de-test"],
            ["de-test"],
            patch_size=32,
            patches_per_tile=2,
            batch_size=2,
            num_workers=0,
            class_set=request.param,
        )
        module.prepare_data()
        module.setup()
        return module

    def test_train_dataloader(self, datamodule: ChesapeakeCVPRDataModule) -> None:
        """Pull the first item from the training dataloader."""
        loader = datamodule.train_dataloader()
        next(iter(loader))

    def test_val_dataloader(self, datamodule: ChesapeakeCVPRDataModule) -> None:
        """Pull the first item from the validation dataloader."""
        loader = datamodule.val_dataloader()
        next(iter(loader))

    def test_test_dataloader(self, datamodule: ChesapeakeCVPRDataModule) -> None:
        """Pull the first item from the test dataloader."""
        loader = datamodule.test_dataloader()
        next(iter(loader))

    def test_nodata_check(self, datamodule: ChesapeakeCVPRDataModule) -> None:
        """Check ``nodata_check(4)`` on an all-ones 2x2 sample."""
        check = datamodule.nodata_check(4)
        ones_sample = {
            "image": torch.ones(1, 2, 2),  # type: ignore[attr-defined]
            "mask": torch.ones(2, 2),  # type: ignore[attr-defined]
        }
        got = check(ones_sample)
        # Both outputs are expected to be all zeros with a 4x4 spatial size.
        want_image = torch.zeros(1, 4, 4)  # type: ignore[attr-defined]
        want_mask = torch.zeros(4, 4)  # type: ignore[attr-defined]
        assert torch.equal(got["image"], want_image)  # type: ignore[attr-defined]
        assert torch.equal(got["mask"], want_mask)  # type: ignore[attr-defined]
|
|
@ -0,0 +1,30 @@
|
|||
# Copyright (c) Microsoft Corporation. All rights reserved.
|
||||
# Licensed under the MIT License.
|
||||
|
||||
import os
|
||||
|
||||
import pytest
|
||||
|
||||
from torchgeo.datamodules import COWCCountingDataModule
|
||||
|
||||
|
||||
class TestCOWCCountingDataModule:
    """Smoke tests for the dataloaders of ``COWCCountingDataModule``."""

    @pytest.fixture(scope="class")
    def datamodule(self) -> COWCCountingDataModule:
        """Return a set-up datamodule shared by every test in the class."""
        data_dir = os.path.join("tests", "data", "cowc_counting")
        # Values after the root, in order: seed 0, batch size 1, 0 workers.
        module = COWCCountingDataModule(data_dir, 0, 1, 0)
        module.prepare_data()
        module.setup()
        return module

    def test_train_dataloader(self, datamodule: COWCCountingDataModule) -> None:
        """Pull the first item from the training dataloader."""
        loader = datamodule.train_dataloader()
        next(iter(loader))

    def test_val_dataloader(self, datamodule: COWCCountingDataModule) -> None:
        """Pull the first item from the validation dataloader."""
        loader = datamodule.val_dataloader()
        next(iter(loader))

    def test_test_dataloader(self, datamodule: COWCCountingDataModule) -> None:
        """Pull the first item from the test dataloader."""
        loader = datamodule.test_dataloader()
        next(iter(loader))
|
|
@ -0,0 +1,30 @@
|
|||
# Copyright (c) Microsoft Corporation. All rights reserved.
|
||||
# Licensed under the MIT License.
|
||||
|
||||
import os
|
||||
|
||||
import pytest
|
||||
|
||||
from torchgeo.datamodules import CycloneDataModule
|
||||
|
||||
|
||||
class TestCycloneDataModule:
    """Smoke tests for the dataloaders of ``CycloneDataModule``."""

    @pytest.fixture(scope="class")
    def datamodule(self) -> CycloneDataModule:
        """Return a set-up datamodule shared by every test in the class."""
        data_dir = os.path.join("tests", "data", "cyclone")
        # Values after the root, in order: seed 0, batch size 1, 0 workers.
        module = CycloneDataModule(data_dir, 0, 1, 0)
        module.prepare_data()
        module.setup()
        return module

    def test_train_dataloader(self, datamodule: CycloneDataModule) -> None:
        """Pull the first item from the training dataloader."""
        loader = datamodule.train_dataloader()
        next(iter(loader))

    def test_val_dataloader(self, datamodule: CycloneDataModule) -> None:
        """Pull the first item from the validation dataloader."""
        loader = datamodule.val_dataloader()
        next(iter(loader))

    def test_test_dataloader(self, datamodule: CycloneDataModule) -> None:
        """Pull the first item from the test dataloader."""
        loader = datamodule.test_dataloader()
        next(iter(loader))
|
|
@ -0,0 +1,30 @@
|
|||
# Copyright (c) Microsoft Corporation. All rights reserved.
|
||||
# Licensed under the MIT License.
|
||||
|
||||
import os
|
||||
|
||||
import pytest
|
||||
|
||||
from torchgeo.datamodules import ETCI2021DataModule
|
||||
|
||||
|
||||
class TestETCI2021DataModule:
    """Smoke tests for the dataloaders of ``ETCI2021DataModule``."""

    @pytest.fixture(scope="class")
    def datamodule(self) -> ETCI2021DataModule:
        """Return a set-up datamodule shared by every test in the class."""
        data_dir = os.path.join("tests", "data", "etci2021")
        # Values after the root, in order: seed 0, batch size 2, 0 workers.
        module = ETCI2021DataModule(data_dir, 0, 2, 0)
        module.prepare_data()
        module.setup()
        return module

    def test_train_dataloader(self, datamodule: ETCI2021DataModule) -> None:
        """Pull the first item from the training dataloader."""
        loader = datamodule.train_dataloader()
        next(iter(loader))

    def test_val_dataloader(self, datamodule: ETCI2021DataModule) -> None:
        """Pull the first item from the validation dataloader."""
        loader = datamodule.val_dataloader()
        next(iter(loader))

    def test_test_dataloader(self, datamodule: ETCI2021DataModule) -> None:
        """Pull the first item from the test dataloader."""
        loader = datamodule.test_dataloader()
        next(iter(loader))
|
|
@ -0,0 +1,29 @@
|
|||
# Copyright (c) Microsoft Corporation. All rights reserved.
|
||||
# Licensed under the MIT License.
|
||||
|
||||
import os
|
||||
|
||||
import pytest
|
||||
|
||||
from torchgeo.datamodules import EuroSATDataModule
|
||||
|
||||
|
||||
class TestEuroSATDataModule:
    """Smoke tests for the dataloaders of ``EuroSATDataModule``."""

    @pytest.fixture(scope="class")
    def datamodule(self) -> EuroSATDataModule:
        """Return a set-up datamodule shared by every test in the class."""
        data_dir = os.path.join("tests", "data", "eurosat")
        # Values after the root, in order: batch size 1, 0 workers.
        module = EuroSATDataModule(data_dir, 1, 0)
        module.prepare_data()
        module.setup()
        return module

    def test_train_dataloader(self, datamodule: EuroSATDataModule) -> None:
        """Pull the first item from the training dataloader."""
        loader = datamodule.train_dataloader()
        next(iter(loader))

    def test_val_dataloader(self, datamodule: EuroSATDataModule) -> None:
        """Pull the first item from the validation dataloader."""
        loader = datamodule.val_dataloader()
        next(iter(loader))

    def test_test_dataloader(self, datamodule: EuroSATDataModule) -> None:
        """Pull the first item from the test dataloader."""
        loader = datamodule.test_dataloader()
        next(iter(loader))
|
|
@ -0,0 +1,30 @@
|
|||
# Copyright (c) Microsoft Corporation. All rights reserved.
|
||||
# Licensed under the MIT License.
|
||||
|
||||
import os
|
||||
|
||||
import pytest
|
||||
|
||||
from torchgeo.datamodules import FAIR1MDataModule
|
||||
|
||||
|
||||
class TestFAIR1MDataModule:
    """Smoke tests for the dataloaders of ``FAIR1MDataModule``."""

    # The fixture was previously declared with ``params=[True, False]`` but never
    # accepted ``request`` or read ``request.param``, so each test ran twice
    # against an identically configured datamodule. The unused parametrization
    # is dropped so that every test runs exactly once.
    @pytest.fixture(scope="class")
    def datamodule(self) -> FAIR1MDataModule:
        """Return a set-up datamodule shared by every test in the class.

        A third of the data is requested for each of the validation and test
        splits via ``val_split_pct`` and ``test_split_pct``.
        """
        root = os.path.join("tests", "data", "fair1m")
        batch_size = 2
        num_workers = 0
        dm = FAIR1MDataModule(
            root, batch_size, num_workers, val_split_pct=0.33, test_split_pct=0.33
        )
        dm.setup()
        return dm

    def test_train_dataloader(self, datamodule: FAIR1MDataModule) -> None:
        """Pull the first item from the training dataloader."""
        next(iter(datamodule.train_dataloader()))

    def test_val_dataloader(self, datamodule: FAIR1MDataModule) -> None:
        """Pull the first item from the validation dataloader."""
        next(iter(datamodule.val_dataloader()))

    def test_test_dataloader(self, datamodule: FAIR1MDataModule) -> None:
        """Pull the first item from the test dataloader."""
        next(iter(datamodule.test_dataloader()))
|
|
@ -0,0 +1,29 @@
|
|||
# Copyright (c) Microsoft Corporation. All rights reserved.
|
||||
# Licensed under the MIT License.
|
||||
|
||||
import os
|
||||
|
||||
import pytest
|
||||
|
||||
from torchgeo.datamodules import LandCoverAIDataModule
|
||||
|
||||
|
||||
class TestLandCoverAIDataModule:
    """Smoke tests for the dataloaders of ``LandCoverAIDataModule``."""

    @pytest.fixture(scope="class")
    def datamodule(self) -> LandCoverAIDataModule:
        """Return a set-up datamodule shared by every test in the class."""
        data_dir = os.path.join("tests", "data", "landcoverai")
        # Values after the root, in order: batch size 2, 0 workers.
        module = LandCoverAIDataModule(data_dir, 2, 0)
        module.prepare_data()
        module.setup()
        return module

    def test_train_dataloader(self, datamodule: LandCoverAIDataModule) -> None:
        """Pull the first item from the training dataloader."""
        loader = datamodule.train_dataloader()
        next(iter(loader))

    def test_val_dataloader(self, datamodule: LandCoverAIDataModule) -> None:
        """Pull the first item from the validation dataloader."""
        loader = datamodule.val_dataloader()
        next(iter(loader))

    def test_test_dataloader(self, datamodule: LandCoverAIDataModule) -> None:
        """Pull the first item from the test dataloader."""
        loader = datamodule.test_dataloader()
        next(iter(loader))
|
|
@ -0,0 +1,34 @@
|
|||
# Copyright (c) Microsoft Corporation. All rights reserved.
|
||||
# Licensed under the MIT License.
|
||||
|
||||
import os
|
||||
|
||||
import pytest
|
||||
|
||||
from torchgeo.datamodules import LoveDADataModule
|
||||
|
||||
|
||||
class TestLoveDADataModule:
    """Smoke tests for the dataloaders of ``LoveDADataModule``."""

    @pytest.fixture(scope="class")
    def datamodule(self) -> LoveDADataModule:
        """Return a set-up datamodule covering both the rural and urban scenes."""
        module = LoveDADataModule(
            root_dir=os.path.join("tests", "data", "loveda"),
            scene=["rural", "urban"],
            batch_size=2,
            num_workers=0,
        )
        module.prepare_data()
        module.setup()
        return module

    def test_train_dataloader(self, datamodule: LoveDADataModule) -> None:
        """Pull the first item from the training dataloader."""
        loader = datamodule.train_dataloader()
        next(iter(loader))

    def test_val_dataloader(self, datamodule: LoveDADataModule) -> None:
        """Pull the first item from the validation dataloader."""
        loader = datamodule.val_dataloader()
        next(iter(loader))

    def test_test_dataloader(self, datamodule: LoveDADataModule) -> None:
        """Pull the first item from the test dataloader."""
        loader = datamodule.test_dataloader()
        next(iter(loader))
|
|
@ -0,0 +1,32 @@
|
|||
# Copyright (c) Microsoft Corporation. All rights reserved.
|
||||
# Licensed under the MIT License.
|
||||
|
||||
import os
|
||||
|
||||
import pytest
|
||||
|
||||
from torchgeo.datamodules import NAIPChesapeakeDataModule
|
||||
|
||||
|
||||
class TestNAIPChesapeakeDataModule:
    """Smoke tests for the dataloaders of ``NAIPChesapeakeDataModule``."""

    @pytest.fixture(scope="class")
    def datamodule(self) -> NAIPChesapeakeDataModule:
        """Return a set-up datamodule shared by every test in the class."""
        naip_dir = os.path.join("tests", "data", "naip")
        chesapeake_dir = os.path.join("tests", "data", "chesapeake", "BAYWIDE")
        module = NAIPChesapeakeDataModule(
            naip_dir, chesapeake_dir, batch_size=2, num_workers=0
        )
        # ``patch_size`` is overridden on the instance before preparation/setup.
        module.patch_size = 32
        module.prepare_data()
        module.setup()
        return module

    def test_train_dataloader(self, datamodule: NAIPChesapeakeDataModule) -> None:
        """Pull the first item from the training dataloader."""
        loader = datamodule.train_dataloader()
        next(iter(loader))

    def test_val_dataloader(self, datamodule: NAIPChesapeakeDataModule) -> None:
        """Pull the first item from the validation dataloader."""
        loader = datamodule.val_dataloader()
        next(iter(loader))

    def test_test_dataloader(self, datamodule: NAIPChesapeakeDataModule) -> None:
        """Pull the first item from the test dataloader."""
        loader = datamodule.test_dataloader()
        next(iter(loader))
|
|
@ -0,0 +1,33 @@
|
|||
# Copyright (c) Microsoft Corporation. All rights reserved.
|
||||
# Licensed under the MIT License.
|
||||
|
||||
import os
|
||||
|
||||
import pytest
|
||||
|
||||
from torchgeo.datamodules import NASAMarineDebrisDataModule
|
||||
|
||||
|
||||
class TestNASAMarineDebrisDataModule:
    """Smoke tests for the dataloaders of ``NASAMarineDebrisDataModule``."""

    @pytest.fixture(scope="class")
    def datamodule(self) -> NASAMarineDebrisDataModule:
        """Return a set-up datamodule shared by every test in the class."""
        data_dir = os.path.join("tests", "data", "nasa_marine_debris")
        # Values after the root, in order: batch size 2, 0 workers, validation
        # split fraction 0.3, test split fraction 0.3.
        module = NASAMarineDebrisDataModule(data_dir, 2, 0, 0.3, 0.3)
        module.prepare_data()
        module.setup()
        return module

    def test_train_dataloader(self, datamodule: NASAMarineDebrisDataModule) -> None:
        """Pull the first item from the training dataloader."""
        loader = datamodule.train_dataloader()
        next(iter(loader))

    def test_val_dataloader(self, datamodule: NASAMarineDebrisDataModule) -> None:
        """Pull the first item from the validation dataloader."""
        loader = datamodule.val_dataloader()
        next(iter(loader))

    def test_test_dataloader(self, datamodule: NASAMarineDebrisDataModule) -> None:
        """Pull the first item from the test dataloader."""
        loader = datamodule.test_dataloader()
        next(iter(loader))
|
|
@ -0,0 +1,62 @@
|
|||
# Copyright (c) Microsoft Corporation. All rights reserved.
|
||||
# Licensed under the MIT License.
|
||||
|
||||
import os
|
||||
|
||||
import pytest
|
||||
from _pytest.fixtures import SubRequest
|
||||
|
||||
from torchgeo.datamodules import OSCDDataModule
|
||||
|
||||
|
||||
class TestOSCDDataModule:
    """Shape checks on the items yielded by the ``OSCDDataModule`` dataloaders."""

    @pytest.fixture(scope="class", params=[("all", 0.0), ("rgb", 0.5)])
    def datamodule(self, request: SubRequest) -> OSCDDataModule:
        """Return a set-up datamodule for each (bands, val_split_pct) pair."""
        bands, val_split_pct = request.param
        data_dir = os.path.join("tests", "data", "oscd")
        # Values after ``bands``, in order: batch size 1, 0 workers, validation
        # split fraction, patch size (2, 2), 2 patches per tile.
        module = OSCDDataModule(data_dir, bands, 1, 0, val_split_pct, (2, 2), 2)
        module.prepare_data()
        module.setup()
        return module

    def test_train_dataloader(self, datamodule: OSCDDataModule) -> None:
        """Training items are 2x2 with a leading dimension of 2."""
        sample = next(iter(datamodule.train_dataloader()))
        image_shape = sample["image"].shape
        mask_shape = sample["mask"].shape
        assert image_shape[-2:] == mask_shape[-2:] == (2, 2)
        assert image_shape[0] == mask_shape[0] == 2
        # 26 channels are expected for "all" bands, 6 otherwise.
        expected_channels = 26 if datamodule.bands == "all" else 6
        assert image_shape[1] == expected_channels

    def test_val_dataloader(self, datamodule: OSCDDataModule) -> None:
        """Validation items are 1280x1280 when a validation split is requested."""
        sample = next(iter(datamodule.val_dataloader()))
        image_shape = sample["image"].shape
        mask_shape = sample["mask"].shape
        # NOTE(review): the source this was recovered from had lost its
        # indentation; the spatial/leading-dimension checks are assumed to sit
        # inside this ``if`` and the channel check outside it -- confirm.
        if datamodule.val_split_pct > 0.0:
            assert image_shape[-2:] == mask_shape[-2:] == (1280, 1280)
            assert image_shape[0] == mask_shape[0] == 1
        expected_channels = 26 if datamodule.bands == "all" else 6
        assert image_shape[1] == expected_channels

    def test_test_dataloader(self, datamodule: OSCDDataModule) -> None:
        """Test items are 1280x1280 with a leading dimension of 1."""
        sample = next(iter(datamodule.test_dataloader()))
        image_shape = sample["image"].shape
        mask_shape = sample["mask"].shape
        assert image_shape[-2:] == mask_shape[-2:] == (1280, 1280)
        assert image_shape[0] == mask_shape[0] == 1
        expected_channels = 26 if datamodule.bands == "all" else 6
        assert image_shape[1] == expected_channels
|
|
@ -0,0 +1,33 @@
|
|||
# Copyright (c) Microsoft Corporation. All rights reserved.
|
||||
# Licensed under the MIT License.
|
||||
|
||||
import os
|
||||
|
||||
import pytest
|
||||
from _pytest.fixtures import SubRequest
|
||||
|
||||
from torchgeo.datamodules import Potsdam2DDataModule
|
||||
|
||||
|
||||
class TestPotsdam2DDataModule:
    """Smoke tests for the dataloaders of ``Potsdam2DDataModule``."""

    @pytest.fixture(scope="class", params=[0.0, 0.5])
    def datamodule(self, request: SubRequest) -> Potsdam2DDataModule:
        """Return a set-up datamodule for each ``val_split_pct`` in ``params``."""
        data_dir = os.path.join("tests", "data", "potsdam")
        # Positional values after the root: batch size 1, 0 workers.
        module = Potsdam2DDataModule(data_dir, 1, 0, val_split_pct=request.param)
        module.prepare_data()
        module.setup()
        return module

    def test_train_dataloader(self, datamodule: Potsdam2DDataModule) -> None:
        """Pull the first item from the training dataloader."""
        loader = datamodule.train_dataloader()
        next(iter(loader))

    def test_val_dataloader(self, datamodule: Potsdam2DDataModule) -> None:
        """Pull the first item from the validation dataloader."""
        loader = datamodule.val_dataloader()
        next(iter(loader))

    def test_test_dataloader(self, datamodule: Potsdam2DDataModule) -> None:
        """Pull the first item from the test dataloader."""
        loader = datamodule.test_dataloader()
        next(iter(loader))
|
|
@ -0,0 +1,29 @@
|
|||
# Copyright (c) Microsoft Corporation. All rights reserved.
|
||||
# Licensed under the MIT License.
|
||||
|
||||
import os
|
||||
|
||||
import pytest
|
||||
|
||||
from torchgeo.datamodules import RESISC45DataModule
|
||||
|
||||
|
||||
class TestRESISC45DataModule:
    """Smoke tests for the dataloaders of ``RESISC45DataModule``."""

    @pytest.fixture(scope="class")
    def datamodule(self) -> RESISC45DataModule:
        """Return a set-up datamodule shared by every test in the class."""
        data_dir = os.path.join("tests", "data", "resisc45")
        # Values after the root, in order: batch size 2, 0 workers.
        module = RESISC45DataModule(data_dir, 2, 0)
        module.prepare_data()
        module.setup()
        return module

    def test_train_dataloader(self, datamodule: RESISC45DataModule) -> None:
        """Pull the first item from the training dataloader."""
        loader = datamodule.train_dataloader()
        next(iter(loader))

    def test_val_dataloader(self, datamodule: RESISC45DataModule) -> None:
        """Pull the first item from the validation dataloader."""
        loader = datamodule.val_dataloader()
        next(iter(loader))

    def test_test_dataloader(self, datamodule: RESISC45DataModule) -> None:
        """Pull the first item from the test dataloader."""
        loader = datamodule.test_dataloader()
        next(iter(loader))
|
|
@ -0,0 +1,32 @@
|
|||
# Copyright (c) Microsoft Corporation. All rights reserved.
|
||||
# Licensed under the MIT License.
|
||||
|
||||
import os
|
||||
|
||||
import pytest
|
||||
from _pytest.fixtures import SubRequest
|
||||
|
||||
from torchgeo.datamodules import SEN12MSDataModule
|
||||
|
||||
|
||||
class TestSEN12MSDataModule:
    """Smoke tests for the dataloaders of ``SEN12MSDataModule``."""

    @pytest.fixture(scope="class", params=["all", "s1", "s2-all", "s2-reduced"])
    def datamodule(self, request: SubRequest) -> SEN12MSDataModule:
        """Return a set-up datamodule, built once per band option in ``params``."""
        data_dir = os.path.join("tests", "data", "sen12ms")
        # Values after the root, in order: seed 0, bands, batch size 1,
        # 0 workers.
        module = SEN12MSDataModule(data_dir, 0, request.param, 1, 0)
        module.prepare_data()
        module.setup()
        return module

    def test_train_dataloader(self, datamodule: SEN12MSDataModule) -> None:
        """Pull the first item from the training dataloader."""
        loader = datamodule.train_dataloader()
        next(iter(loader))

    def test_val_dataloader(self, datamodule: SEN12MSDataModule) -> None:
        """Pull the first item from the validation dataloader."""
        loader = datamodule.val_dataloader()
        next(iter(loader))

    def test_test_dataloader(self, datamodule: SEN12MSDataModule) -> None:
        """Pull the first item from the test dataloader."""
        loader = datamodule.test_dataloader()
        next(iter(loader))
|
|
@ -0,0 +1,33 @@
|
|||
# Copyright (c) Microsoft Corporation. All rights reserved.
|
||||
# Licensed under the MIT License.
|
||||
|
||||
import os
|
||||
|
||||
import pytest
|
||||
from _pytest.fixtures import SubRequest
|
||||
|
||||
from torchgeo.datamodules import So2SatDataModule
|
||||
|
||||
pytest.importorskip("h5py")
|
||||
|
||||
|
||||
class TestSo2SatDataModule:
    """Smoke tests for the dataloaders of ``So2SatDataModule``."""

    @pytest.fixture(scope="class", params=[(True, "rgb"), (False, "s2")])
    def datamodule(self, request: SubRequest) -> So2SatDataModule:
        """Return a set-up datamodule for each (unsupervised_mode, bands) pair."""
        unsupervised_mode, bands = request.param
        data_dir = os.path.join("tests", "data", "so2sat")
        # Values after the root, in order: batch size 2, 0 workers, bands,
        # unsupervised-mode flag.
        module = So2SatDataModule(data_dir, 2, 0, bands, unsupervised_mode)
        module.prepare_data()
        module.setup()
        return module

    def test_train_dataloader(self, datamodule: So2SatDataModule) -> None:
        """Pull the first item from the training dataloader."""
        loader = datamodule.train_dataloader()
        next(iter(loader))

    def test_val_dataloader(self, datamodule: So2SatDataModule) -> None:
        """Pull the first item from the validation dataloader."""
        loader = datamodule.val_dataloader()
        next(iter(loader))

    def test_test_dataloader(self, datamodule: So2SatDataModule) -> None:
        """Pull the first item from the test dataloader."""
        loader = datamodule.test_dataloader()
        next(iter(loader))
|
|
@ -0,0 +1,29 @@
|
|||
# Copyright (c) Microsoft Corporation. All rights reserved.
|
||||
# Licensed under the MIT License.
|
||||
|
||||
import os
|
||||
|
||||
import pytest
|
||||
|
||||
from torchgeo.datamodules import UCMercedDataModule
|
||||
|
||||
|
||||
class TestUCMercedDataModule:
    """Smoke tests for the dataloaders of ``UCMercedDataModule``."""

    @pytest.fixture(scope="class")
    def datamodule(self) -> UCMercedDataModule:
        """Return a set-up datamodule shared by every test in the class."""
        data_dir = os.path.join("tests", "data", "ucmerced")
        # Values after the root, in order: batch size 2, 0 workers.
        module = UCMercedDataModule(data_dir, 2, 0)
        module.prepare_data()
        module.setup()
        return module

    def test_train_dataloader(self, datamodule: UCMercedDataModule) -> None:
        """Pull the first item from the training dataloader."""
        loader = datamodule.train_dataloader()
        next(iter(loader))

    def test_val_dataloader(self, datamodule: UCMercedDataModule) -> None:
        """Pull the first item from the validation dataloader."""
        loader = datamodule.val_dataloader()
        next(iter(loader))

    def test_test_dataloader(self, datamodule: UCMercedDataModule) -> None:
        """Pull the first item from the test dataloader."""
        loader = datamodule.test_dataloader()
        next(iter(loader))
|
|
@ -0,0 +1,25 @@
|
|||
# Copyright (c) Microsoft Corporation. All rights reserved.
|
||||
# Licensed under the MIT License.
|
||||
|
||||
import torch
|
||||
from torch.utils.data import TensorDataset
|
||||
|
||||
from torchgeo.datamodules.utils import dataset_split
|
||||
|
||||
|
||||
def test_dataset_split() -> None:
    """Check the subset sizes returned by ``dataset_split`` on 24 samples."""
    num_samples = 24
    features = torch.ones(num_samples, 5)  # type: ignore[attr-defined]
    labels = torch.randint(  # type: ignore[attr-defined]
        low=0, high=2, size=(num_samples,)
    )
    ds = TensorDataset(features, labels)

    # Two-way split: half of the samples go to validation.
    half = num_samples // 2
    train_ds, val_ds = dataset_split(ds, val_pct=1 / 2)
    assert len(train_ds) == half
    assert len(val_ds) == half

    # Three-way split: a third each for validation and test.
    third = num_samples // 3
    train_ds, val_ds, test_ds = dataset_split(ds, val_pct=1 / 3, test_pct=1 / 3)
    assert len(train_ds) == third
    assert len(val_ds) == third
    assert len(test_ds) == third
|
|
@ -0,0 +1,33 @@
|
|||
# Copyright (c) Microsoft Corporation. All rights reserved.
|
||||
# Licensed under the MIT License.
|
||||
|
||||
import os
|
||||
|
||||
import pytest
|
||||
from _pytest.fixtures import SubRequest
|
||||
|
||||
from torchgeo.datamodules import Vaihingen2DDataModule
|
||||
|
||||
|
||||
class TestVaihingen2DDataModule:
    """Smoke tests for the dataloaders of ``Vaihingen2DDataModule``."""

    @pytest.fixture(scope="class", params=[0.0, 0.5])
    def datamodule(self, request: SubRequest) -> Vaihingen2DDataModule:
        """Return a set-up datamodule for each ``val_split_pct`` in ``params``."""
        data_dir = os.path.join("tests", "data", "vaihingen")
        # Positional values after the root: batch size 1, 0 workers.
        module = Vaihingen2DDataModule(data_dir, 1, 0, val_split_pct=request.param)
        module.prepare_data()
        module.setup()
        return module

    def test_train_dataloader(self, datamodule: Vaihingen2DDataModule) -> None:
        """Pull the first item from the training dataloader."""
        loader = datamodule.train_dataloader()
        next(iter(loader))

    def test_val_dataloader(self, datamodule: Vaihingen2DDataModule) -> None:
        """Pull the first item from the validation dataloader."""
        loader = datamodule.val_dataloader()
        next(iter(loader))

    def test_test_dataloader(self, datamodule: Vaihingen2DDataModule) -> None:
        """Pull the first item from the test dataloader."""
        loader = datamodule.test_dataloader()
        next(iter(loader))
|
|
@ -0,0 +1,33 @@
|
|||
# Copyright (c) Microsoft Corporation. All rights reserved.
|
||||
# Licensed under the MIT License.
|
||||
|
||||
import os
|
||||
|
||||
import pytest
|
||||
from _pytest.fixtures import SubRequest
|
||||
|
||||
from torchgeo.datamodules import XView2DataModule
|
||||
|
||||
|
||||
class TestXView2DataModule:
    """Smoke tests for the dataloaders of ``XView2DataModule``."""

    @pytest.fixture(scope="class", params=[0.0, 0.5])
    def datamodule(self, request: SubRequest) -> XView2DataModule:
        """Return a set-up datamodule for each ``val_split_pct`` in ``params``."""
        data_dir = os.path.join("tests", "data", "xview2")
        # Positional values after the root: batch size 1, 0 workers.
        module = XView2DataModule(data_dir, 1, 0, val_split_pct=request.param)
        module.prepare_data()
        module.setup()
        return module

    def test_train_dataloader(self, datamodule: XView2DataModule) -> None:
        """Pull the first item from the training dataloader."""
        loader = datamodule.train_dataloader()
        next(iter(loader))

    def test_val_dataloader(self, datamodule: XView2DataModule) -> None:
        """Pull the first item from the validation dataloader."""
        loader = datamodule.val_dataloader()
        next(iter(loader))

    def test_test_dataloader(self, datamodule: XView2DataModule) -> None:
        """Pull the first item from the test dataloader."""
        loader = datamodule.test_dataloader()
        next(iter(loader))
|
|
@ -13,7 +13,7 @@ from _pytest.fixtures import SubRequest
|
|||
from _pytest.monkeypatch import MonkeyPatch
|
||||
|
||||
import torchgeo.datasets.utils
|
||||
from torchgeo.datasets import BigEarthNet, BigEarthNetDataModule
|
||||
from torchgeo.datasets import BigEarthNet
|
||||
|
||||
|
||||
def download_url(url: str, root: str, *args: str, **kwargs: str) -> None:
|
||||
|
@ -148,26 +148,3 @@ class TestBigEarthNet:
|
|||
"to automaticaly download the dataset."
|
||||
with pytest.raises(RuntimeError, match=err):
|
||||
BigEarthNet(str(tmp_path))
|
||||
|
||||
|
||||
class TestBigEarthNetDataModule:
    """Smoke tests for the loaders of ``BigEarthNetDataModule``."""

    @pytest.fixture(scope="class", params=["s1", "s2", "all"])
    def datamodule(self, request: SubRequest) -> BigEarthNetDataModule:
        """Create a prepared and set-up datamodule for each band selection.

        Args:
            request: pytest request whose ``param`` is the ``bands`` argument.

        Returns:
            The datamodule after ``prepare_data()`` and ``setup()`` were called.
        """
        data_dir = os.path.join("tests", "data", "bigearthnet")
        # Positional arguments: root, bands, num_classes (19),
        # batch size (1), number of workers (0).
        module = BigEarthNetDataModule(data_dir, request.param, 19, 1, 0)
        module.prepare_data()
        module.setup()
        return module

    def test_train_dataloader(self, datamodule: BigEarthNetDataModule) -> None:
        """Pull one batch from the training loader."""
        loader = datamodule.train_dataloader()
        next(iter(loader))

    def test_val_dataloader(self, datamodule: BigEarthNetDataModule) -> None:
        """Pull one batch from the validation loader."""
        loader = datamodule.val_dataloader()
        next(iter(loader))

    def test_test_dataloader(self, datamodule: BigEarthNetDataModule) -> None:
        """Pull one batch from the test loader."""
        loader = datamodule.test_dataloader()
        next(iter(loader))
|
||||
|
|
|
@ -19,7 +19,6 @@ from torchgeo.datasets import (
|
|||
BoundingBox,
|
||||
Chesapeake13,
|
||||
ChesapeakeCVPR,
|
||||
ChesapeakeCVPRDataModule,
|
||||
IntersectionDataset,
|
||||
UnionDataset,
|
||||
)
|
||||
|
@ -179,45 +178,3 @@ class TestChesapeakeCVPR:
|
|||
IndexError, match="query: .* spans multiple tiles which is not valid"
|
||||
):
|
||||
ds[dataset.bounds]
|
||||
|
||||
|
||||
class TestChesapeakeCVPRDataModule:
    """Tests for the loaders and ``nodata_check`` of ``ChesapeakeCVPRDataModule``."""

    @pytest.fixture(scope="class", params=[5, 7])
    def datamodule(self, request: SubRequest) -> ChesapeakeCVPRDataModule:
        """Create a prepared and set-up datamodule for each ``class_set``.

        Args:
            request: pytest request whose ``param`` is the ``class_set`` argument.

        Returns:
            The datamodule after ``prepare_data()`` and ``setup()`` were called.
        """
        data_dir = os.path.join("tests", "data", "chesapeake", "cvpr")
        module = ChesapeakeCVPRDataModule(
            data_dir,
            ["de-test"],
            ["de-test"],
            ["de-test"],
            patch_size=32,
            patches_per_tile=2,
            batch_size=2,
            num_workers=0,
            class_set=request.param,
        )
        module.prepare_data()
        module.setup()
        return module

    def test_train_dataloader(self, datamodule: ChesapeakeCVPRDataModule) -> None:
        """Pull one batch from the training loader."""
        loader = datamodule.train_dataloader()
        next(iter(loader))

    def test_val_dataloader(self, datamodule: ChesapeakeCVPRDataModule) -> None:
        """Pull one batch from the validation loader."""
        loader = datamodule.val_dataloader()
        next(iter(loader))

    def test_test_dataloader(self, datamodule: ChesapeakeCVPRDataModule) -> None:
        """Pull one batch from the test loader."""
        loader = datamodule.test_dataloader()
        next(iter(loader))

    def test_nodata_check(self, datamodule: ChesapeakeCVPRDataModule) -> None:
        """Check ``nodata_check(4)`` on a 2x2 all-ones sample.

        The transformed sample is expected to hold an all-zero ``(1, 4, 4)``
        image and an all-zero ``(4, 4)`` mask.
        """
        transform = datamodule.nodata_check(4)
        ones_image = torch.ones(1, 2, 2)  # type: ignore[attr-defined]
        ones_mask = torch.ones(2, 2)  # type: ignore[attr-defined]
        result = transform({"image": ones_image, "mask": ones_mask})

        expected_image = torch.zeros(1, 4, 4)  # type: ignore[attr-defined]
        expected_mask = torch.zeros(4, 4)  # type: ignore[attr-defined]
        assert torch.equal(  # type: ignore[attr-defined]
            result["image"], expected_image
        )
        assert torch.equal(  # type: ignore[attr-defined]
            result["mask"], expected_mask
        )
|
||||
|
|
|
@ -14,7 +14,7 @@ from _pytest.monkeypatch import MonkeyPatch
|
|||
from torch.utils.data import ConcatDataset
|
||||
|
||||
import torchgeo.datasets.utils
|
||||
from torchgeo.datasets import COWCCounting, COWCCountingDataModule, COWCDetection
|
||||
from torchgeo.datasets import COWCCounting, COWCDetection
|
||||
from torchgeo.datasets.cowc import COWC
|
||||
|
||||
|
||||
|
@ -148,25 +148,3 @@ class TestCOWCDetection:
|
|||
def test_not_downloaded(self, tmp_path: Path) -> None:
|
||||
with pytest.raises(RuntimeError, match="Dataset not found or corrupted."):
|
||||
COWCDetection(str(tmp_path))
|
||||
|
||||
|
||||
class TestCOWCCountingDataModule:
    """Smoke tests for the loaders of ``COWCCountingDataModule``."""

    @pytest.fixture(scope="class")
    def datamodule(self) -> COWCCountingDataModule:
        """Create a prepared and set-up COWC counting datamodule.

        Returns:
            The datamodule after ``prepare_data()`` and ``setup()`` were called.
        """
        data_dir = os.path.join("tests", "data", "cowc_counting")
        # Positional arguments: root, seed (0), batch size (1), workers (0).
        module = COWCCountingDataModule(data_dir, 0, 1, 0)
        module.prepare_data()
        module.setup()
        return module

    def test_train_dataloader(self, datamodule: COWCCountingDataModule) -> None:
        """Pull one batch from the training loader."""
        loader = datamodule.train_dataloader()
        next(iter(loader))

    def test_val_dataloader(self, datamodule: COWCCountingDataModule) -> None:
        """Pull one batch from the validation loader."""
        loader = datamodule.val_dataloader()
        next(iter(loader))

    def test_test_dataloader(self, datamodule: COWCCountingDataModule) -> None:
        """Pull one batch from the test loader."""
        loader = datamodule.test_dataloader()
        next(iter(loader))
|
||||
|
|
|
@ -15,7 +15,7 @@ from _pytest.fixtures import SubRequest
|
|||
from _pytest.monkeypatch import MonkeyPatch
|
||||
from torch.utils.data import ConcatDataset
|
||||
|
||||
from torchgeo.datasets import CycloneDataModule, TropicalCycloneWindEstimation
|
||||
from torchgeo.datasets import TropicalCycloneWindEstimation
|
||||
|
||||
|
||||
class Dataset:
|
||||
|
@ -103,25 +103,3 @@ class TestTropicalCycloneWindEstimation:
|
|||
)
|
||||
dataset.plot(sample)
|
||||
plt.close()
|
||||
|
||||
|
||||
class TestCycloneDataModule:
    """Smoke tests for the loaders of ``CycloneDataModule``."""

    @pytest.fixture(scope="class")
    def datamodule(self) -> CycloneDataModule:
        """Create a prepared and set-up cyclone datamodule.

        Returns:
            The datamodule after ``prepare_data()`` and ``setup()`` were called.
        """
        data_dir = os.path.join("tests", "data", "cyclone")
        # Positional arguments: root, seed (0), batch size (1), workers (0).
        module = CycloneDataModule(data_dir, 0, 1, 0)
        module.prepare_data()
        module.setup()
        return module

    def test_train_dataloader(self, datamodule: CycloneDataModule) -> None:
        """Pull one batch from the training loader."""
        loader = datamodule.train_dataloader()
        next(iter(loader))

    def test_val_dataloader(self, datamodule: CycloneDataModule) -> None:
        """Pull one batch from the validation loader."""
        loader = datamodule.val_dataloader()
        next(iter(loader))

    def test_test_dataloader(self, datamodule: CycloneDataModule) -> None:
        """Pull one batch from the test loader."""
        loader = datamodule.test_dataloader()
        next(iter(loader))
|
||||
|
|
|
@ -14,7 +14,7 @@ from _pytest.fixtures import SubRequest
|
|||
from _pytest.monkeypatch import MonkeyPatch
|
||||
|
||||
import torchgeo.datasets.utils
|
||||
from torchgeo.datasets import ETCI2021, ETCI2021DataModule
|
||||
from torchgeo.datasets import ETCI2021
|
||||
|
||||
|
||||
def download_url(url: str, root: str, *args: str) -> None:
|
||||
|
@ -95,25 +95,3 @@ class TestETCI2021:
|
|||
x["prediction"] = x["mask"][0].clone()
|
||||
dataset.plot(x)
|
||||
plt.close()
|
||||
|
||||
|
||||
class TestETCI2021DataModule:
    """Smoke tests for the loaders of ``ETCI2021DataModule``."""

    @pytest.fixture(scope="class")
    def datamodule(self) -> ETCI2021DataModule:
        """Create a prepared and set-up ETCI2021 datamodule.

        Returns:
            The datamodule after ``prepare_data()`` and ``setup()`` were called.
        """
        data_dir = os.path.join("tests", "data", "etci2021")
        # Positional arguments: root, seed (0), batch size (2), workers (0).
        module = ETCI2021DataModule(data_dir, 0, 2, 0)
        module.prepare_data()
        module.setup()
        return module

    def test_train_dataloader(self, datamodule: ETCI2021DataModule) -> None:
        """Pull one batch from the training loader."""
        loader = datamodule.train_dataloader()
        next(iter(loader))

    def test_val_dataloader(self, datamodule: ETCI2021DataModule) -> None:
        """Pull one batch from the validation loader."""
        loader = datamodule.val_dataloader()
        next(iter(loader))

    def test_test_dataloader(self, datamodule: ETCI2021DataModule) -> None:
        """Pull one batch from the test loader."""
        loader = datamodule.test_dataloader()
        next(iter(loader))
|
||||
|
|
|
@ -15,7 +15,7 @@ from _pytest.monkeypatch import MonkeyPatch
|
|||
from torch.utils.data import ConcatDataset
|
||||
|
||||
import torchgeo.datasets.utils
|
||||
from torchgeo.datasets import EuroSAT, EuroSATDataModule
|
||||
from torchgeo.datasets import EuroSAT
|
||||
|
||||
|
||||
def download_url(url: str, root: str, *args: str, **kwargs: str) -> None:
|
||||
|
@ -100,24 +100,3 @@ class TestEuroSAT:
|
|||
x["prediction"] = x["label"].clone()
|
||||
dataset.plot(x)
|
||||
plt.close()
|
||||
|
||||
|
||||
class TestEuroSATDataModule:
    """Smoke tests for the loaders of ``EuroSATDataModule``."""

    @pytest.fixture(scope="class")
    def datamodule(self) -> EuroSATDataModule:
        """Create a prepared and set-up EuroSAT datamodule.

        Returns:
            The datamodule after ``prepare_data()`` and ``setup()`` were called.
        """
        data_dir = os.path.join("tests", "data", "eurosat")
        # Positional arguments: root, batch size (1), number of workers (0).
        module = EuroSATDataModule(data_dir, 1, 0)
        module.prepare_data()
        module.setup()
        return module

    def test_train_dataloader(self, datamodule: EuroSATDataModule) -> None:
        """Pull one batch from the training loader."""
        loader = datamodule.train_dataloader()
        next(iter(loader))

    def test_val_dataloader(self, datamodule: EuroSATDataModule) -> None:
        """Pull one batch from the validation loader."""
        loader = datamodule.val_dataloader()
        next(iter(loader))

    def test_test_dataloader(self, datamodule: EuroSATDataModule) -> None:
        """Pull one batch from the test loader."""
        loader = datamodule.test_dataloader()
        next(iter(loader))
|
||||
|
|
|
@ -12,7 +12,7 @@ import torch
|
|||
import torch.nn as nn
|
||||
from _pytest.monkeypatch import MonkeyPatch
|
||||
|
||||
from torchgeo.datasets import FAIR1M, FAIR1MDataModule
|
||||
from torchgeo.datasets import FAIR1M
|
||||
|
||||
|
||||
class TestFAIR1M:
|
||||
|
@ -73,25 +73,3 @@ class TestFAIR1M:
|
|||
x["prediction_boxes"] = x["boxes"].clone()
|
||||
dataset.plot(x)
|
||||
plt.close()
|
||||
|
||||
|
||||
class TestFAIR1MDataModule:
    """Smoke tests for the train/val/test loaders of ``FAIR1MDataModule``."""

    # The fixture takes no ``request`` and never reads a parameter, so it is
    # deliberately not parametrized: doing so would only repeat identical runs.
    @pytest.fixture(scope="class")
    def datamodule(self) -> FAIR1MDataModule:
        """Create a set-up FAIR1M datamodule with 33% val and 33% test splits.

        Returns:
            The datamodule after ``setup()`` was called.
        """
        root = os.path.join("tests", "data", "fair1m")
        batch_size = 2
        num_workers = 0
        dm = FAIR1MDataModule(
            root, batch_size, num_workers, val_split_pct=0.33, test_split_pct=0.33
        )
        dm.setup()
        return dm

    def test_train_dataloader(self, datamodule: FAIR1MDataModule) -> None:
        """Pull one batch from the training loader."""
        next(iter(datamodule.train_dataloader()))

    def test_val_dataloader(self, datamodule: FAIR1MDataModule) -> None:
        """Pull one batch from the validation loader."""
        next(iter(datamodule.val_dataloader()))

    def test_test_dataloader(self, datamodule: FAIR1MDataModule) -> None:
        """Pull one batch from the test loader."""
        next(iter(datamodule.test_dataloader()))
|
||||
|
|
|
@ -15,7 +15,7 @@ from _pytest.monkeypatch import MonkeyPatch
|
|||
from torch.utils.data import ConcatDataset
|
||||
|
||||
import torchgeo.datasets.utils
|
||||
from torchgeo.datasets import LandCoverAI, LandCoverAIDataModule
|
||||
from torchgeo.datasets import LandCoverAI
|
||||
|
||||
|
||||
def download_url(url: str, root: str, *args: str) -> None:
|
||||
|
@ -78,24 +78,3 @@ class TestLandCoverAI:
|
|||
x["prediction"] = x["mask"].clone()
|
||||
dataset.plot(x)
|
||||
plt.close()
|
||||
|
||||
|
||||
class TestLandCoverAIDataModule:
    """Smoke tests for the loaders of ``LandCoverAIDataModule``."""

    @pytest.fixture(scope="class")
    def datamodule(self) -> LandCoverAIDataModule:
        """Create a prepared and set-up LandCover.ai datamodule.

        Returns:
            The datamodule after ``prepare_data()`` and ``setup()`` were called.
        """
        data_dir = os.path.join("tests", "data", "landcoverai")
        # Positional arguments: root, batch size (2), number of workers (0).
        module = LandCoverAIDataModule(data_dir, 2, 0)
        module.prepare_data()
        module.setup()
        return module

    def test_train_dataloader(self, datamodule: LandCoverAIDataModule) -> None:
        """Pull one batch from the training loader."""
        loader = datamodule.train_dataloader()
        next(iter(loader))

    def test_val_dataloader(self, datamodule: LandCoverAIDataModule) -> None:
        """Pull one batch from the validation loader."""
        loader = datamodule.val_dataloader()
        next(iter(loader))

    def test_test_dataloader(self, datamodule: LandCoverAIDataModule) -> None:
        """Pull one batch from the test loader."""
        loader = datamodule.test_dataloader()
        next(iter(loader))
|
||||
|
|
|
@ -14,7 +14,7 @@ from _pytest.fixtures import SubRequest
|
|||
from _pytest.monkeypatch import MonkeyPatch
|
||||
|
||||
import torchgeo.datasets.utils
|
||||
from torchgeo.datasets import LoveDA, LoveDADataModule
|
||||
from torchgeo.datasets import LoveDA
|
||||
|
||||
|
||||
def download_url(url: str, root: str, *args: str) -> None:
|
||||
|
@ -99,29 +99,3 @@ class TestLoveDA:
|
|||
def test_plot(self, dataset: LoveDA) -> None:
|
||||
dataset.plot(dataset[0], suptitle="Test")
|
||||
plt.close()
|
||||
|
||||
|
||||
class TestLoveDADataModule:
    """Smoke tests for the loaders of ``LoveDADataModule``."""

    @pytest.fixture(scope="class")
    def datamodule(self) -> LoveDADataModule:
        """Create a prepared and set-up LoveDA datamodule for both scenes.

        Returns:
            The datamodule after ``prepare_data()`` and ``setup()`` were called.
        """
        module = LoveDADataModule(
            root_dir=os.path.join("tests", "data", "loveda"),
            scene=["rural", "urban"],
            batch_size=2,
            num_workers=0,
        )
        module.prepare_data()
        module.setup()
        return module

    def test_train_dataloader(self, datamodule: LoveDADataModule) -> None:
        """Pull one batch from the training loader."""
        loader = datamodule.train_dataloader()
        next(iter(loader))

    def test_val_dataloader(self, datamodule: LoveDADataModule) -> None:
        """Pull one batch from the validation loader."""
        loader = datamodule.val_dataloader()
        next(iter(loader))

    def test_test_dataloader(self, datamodule: LoveDADataModule) -> None:
        """Pull one batch from the test loader."""
        loader = datamodule.test_dataloader()
        next(iter(loader))
|
||||
|
|
|
@ -12,13 +12,7 @@ import torch.nn as nn
|
|||
from _pytest.monkeypatch import MonkeyPatch
|
||||
from rasterio.crs import CRS
|
||||
|
||||
from torchgeo.datasets import (
|
||||
NAIP,
|
||||
BoundingBox,
|
||||
IntersectionDataset,
|
||||
NAIPChesapeakeDataModule,
|
||||
UnionDataset,
|
||||
)
|
||||
from torchgeo.datasets import NAIP, BoundingBox, IntersectionDataset, UnionDataset
|
||||
|
||||
|
||||
class TestNAIP:
|
||||
|
@ -60,27 +54,3 @@ class TestNAIP:
|
|||
IndexError, match="query: .* not found in index with bounds:"
|
||||
):
|
||||
dataset[query]
|
||||
|
||||
|
||||
class TestNAIPChesapeakeDataModule:
    """Smoke tests for the loaders of ``NAIPChesapeakeDataModule``."""

    @pytest.fixture(scope="class")
    def datamodule(self) -> NAIPChesapeakeDataModule:
        """Create a prepared and set-up NAIP + Chesapeake datamodule.

        Returns:
            The datamodule after ``prepare_data()`` and ``setup()`` were called.
        """
        naip_dir = os.path.join("tests", "data", "naip")
        chesapeake_dir = os.path.join("tests", "data", "chesapeake", "BAYWIDE")
        module = NAIPChesapeakeDataModule(
            naip_dir, chesapeake_dir, batch_size=2, num_workers=0
        )
        # Override the patch size before the datasets are prepared and set up.
        module.patch_size = 32
        module.prepare_data()
        module.setup()
        return module

    def test_train_dataloader(self, datamodule: NAIPChesapeakeDataModule) -> None:
        """Pull one batch from the training loader."""
        loader = datamodule.train_dataloader()
        next(iter(loader))

    def test_val_dataloader(self, datamodule: NAIPChesapeakeDataModule) -> None:
        """Pull one batch from the validation loader."""
        loader = datamodule.val_dataloader()
        next(iter(loader))

    def test_test_dataloader(self, datamodule: NAIPChesapeakeDataModule) -> None:
        """Pull one batch from the test loader."""
        loader = datamodule.test_dataloader()
        next(iter(loader))
|
||||
|
|
|
@ -13,7 +13,7 @@ import torch
|
|||
import torch.nn as nn
|
||||
from _pytest.monkeypatch import MonkeyPatch
|
||||
|
||||
from torchgeo.datasets import NASAMarineDebris, NASAMarineDebrisDataModule
|
||||
from torchgeo.datasets import NASAMarineDebris
|
||||
|
||||
|
||||
class Dataset:
|
||||
|
@ -85,28 +85,3 @@ class TestNASAMarineDebris:
|
|||
x["prediction_boxes"] = x["boxes"].clone()
|
||||
dataset.plot(x)
|
||||
plt.close()
|
||||
|
||||
|
||||
class TestNASAMarineDebrisDataModule:
    """Smoke tests for the loaders of ``NASAMarineDebrisDataModule``."""

    @pytest.fixture(scope="class")
    def datamodule(self) -> NASAMarineDebrisDataModule:
        """Create a prepared and set-up NASA Marine Debris datamodule.

        Returns:
            The datamodule after ``prepare_data()`` and ``setup()`` were called.
        """
        data_dir = os.path.join("tests", "data", "nasa_marine_debris")
        # Positional arguments: root, batch size (2), number of workers (0),
        # val_split_pct (0.3), test_split_pct (0.3).
        module = NASAMarineDebrisDataModule(data_dir, 2, 0, 0.3, 0.3)
        module.prepare_data()
        module.setup()
        return module

    def test_train_dataloader(self, datamodule: NASAMarineDebrisDataModule) -> None:
        """Pull one batch from the training loader."""
        loader = datamodule.train_dataloader()
        next(iter(loader))

    def test_val_dataloader(self, datamodule: NASAMarineDebrisDataModule) -> None:
        """Pull one batch from the validation loader."""
        loader = datamodule.val_dataloader()
        next(iter(loader))

    def test_test_dataloader(self, datamodule: NASAMarineDebrisDataModule) -> None:
        """Pull one batch from the test loader."""
        loader = datamodule.test_dataloader()
        next(iter(loader))
|
||||
|
|
|
@ -16,7 +16,7 @@ from matplotlib import pyplot as plt
|
|||
from torch.utils.data import ConcatDataset
|
||||
|
||||
import torchgeo.datasets.utils
|
||||
from torchgeo.datasets import OSCD, OSCDDataModule
|
||||
from torchgeo.datasets import OSCD
|
||||
|
||||
|
||||
def download_url(url: str, root: str, *args: str, **kwargs: str) -> None:
|
||||
|
@ -105,56 +105,3 @@ class TestOSCD:
|
|||
def test_plot(self, dataset: OSCD) -> None:
|
||||
dataset.plot(dataset[0], suptitle="Test")
|
||||
plt.close()
|
||||
|
||||
|
||||
class TestOSCDDataModule:
    """Shape checks for the batches produced by ``OSCDDataModule``."""

    @pytest.fixture(scope="class", params=[("all", 0.0), ("rgb", 0.5)])
    def datamodule(self, request: SubRequest) -> OSCDDataModule:
        """Create a prepared and set-up OSCD datamodule.

        Args:
            request: pytest request whose ``param`` is a
                ``(bands, val_split_pct)`` pair.

        Returns:
            The datamodule after ``prepare_data()`` and ``setup()`` were called.
        """
        bands, val_split_pct = request.param
        data_dir = os.path.join("tests", "data", "oscd")
        # Positional arguments: root, bands, batch size (1), workers (0),
        # val_split_pct, patch size (2, 2), patches per tile (2).
        module = OSCDDataModule(data_dir, bands, 1, 0, val_split_pct, (2, 2), 2)
        module.prepare_data()
        module.setup()
        return module

    def test_train_dataloader(self, datamodule: OSCDDataModule) -> None:
        """Training batches hold two 2x2 patches with 26 or 6 channels."""
        batch = next(iter(datamodule.train_dataloader()))
        image, mask = batch["image"], batch["mask"]
        assert image.shape[-2:] == mask.shape[-2:] == (2, 2)
        assert image.shape[0] == mask.shape[0] == 2
        expected_channels = 26 if datamodule.bands == "all" else 6
        assert image.shape[1] == expected_channels

    def test_val_dataloader(self, datamodule: OSCDDataModule) -> None:
        """Validation batches are checked only when a validation split exists."""
        batch = next(iter(datamodule.val_dataloader()))
        if not datamodule.val_split_pct > 0.0:
            return
        image, mask = batch["image"], batch["mask"]
        assert image.shape[-2:] == mask.shape[-2:] == (1280, 1280)
        assert image.shape[0] == mask.shape[0] == 1
        expected_channels = 26 if datamodule.bands == "all" else 6
        assert image.shape[1] == expected_channels

    def test_test_dataloader(self, datamodule: OSCDDataModule) -> None:
        """Test batches hold one 1280x1280 tile with 26 or 6 channels."""
        batch = next(iter(datamodule.test_dataloader()))
        image, mask = batch["image"], batch["mask"]
        assert image.shape[-2:] == mask.shape[-2:] == (1280, 1280)
        assert image.shape[0] == mask.shape[0] == 1
        expected_channels = 26 if datamodule.bands == "all" else 6
        assert image.shape[1] == expected_channels
|
||||
|
|
|
@ -13,7 +13,7 @@ import torch.nn as nn
|
|||
from _pytest.fixtures import SubRequest
|
||||
from _pytest.monkeypatch import MonkeyPatch
|
||||
|
||||
from torchgeo.datasets import Potsdam2D, Potsdam2DDataModule
|
||||
from torchgeo.datasets import Potsdam2D
|
||||
|
||||
|
||||
class TestPotsdam2D:
|
||||
|
@ -75,27 +75,3 @@ class TestPotsdam2D:
|
|||
x["prediction"] = x["mask"].clone()
|
||||
dataset.plot(x)
|
||||
plt.close()
|
||||
|
||||
|
||||
class TestPotsdam2DDataModule:
    """Smoke tests for the loaders of ``Potsdam2DDataModule``."""

    @pytest.fixture(scope="class", params=[0.0, 0.5])
    def datamodule(self, request: SubRequest) -> Potsdam2DDataModule:
        """Create a prepared and set-up datamodule for each ``val_split_pct``.

        Args:
            request: pytest request whose ``param`` is used as ``val_split_pct``.

        Returns:
            The datamodule after ``prepare_data()`` and ``setup()`` were called.
        """
        data_dir = os.path.join("tests", "data", "potsdam")
        # Positional arguments: root, batch size (1), number of workers (0).
        module = Potsdam2DDataModule(data_dir, 1, 0, val_split_pct=request.param)
        module.prepare_data()
        module.setup()
        return module

    def test_train_dataloader(self, datamodule: Potsdam2DDataModule) -> None:
        """Pull one batch from the training loader."""
        loader = datamodule.train_dataloader()
        next(iter(loader))

    def test_val_dataloader(self, datamodule: Potsdam2DDataModule) -> None:
        """Pull one batch from the validation loader."""
        loader = datamodule.val_dataloader()
        next(iter(loader))

    def test_test_dataloader(self, datamodule: Potsdam2DDataModule) -> None:
        """Pull one batch from the test loader."""
        loader = datamodule.test_dataloader()
        next(iter(loader))
|
||||
|
|
|
@ -15,7 +15,7 @@ from _pytest.fixtures import SubRequest
|
|||
from _pytest.monkeypatch import MonkeyPatch
|
||||
|
||||
import torchgeo.datasets.utils
|
||||
from torchgeo.datasets import RESISC45, RESISC45DataModule
|
||||
from torchgeo.datasets import RESISC45
|
||||
|
||||
|
||||
def download_url(url: str, root: str, *args: str, **kwargs: str) -> None:
|
||||
|
@ -101,24 +101,3 @@ class TestRESISC45:
|
|||
x["prediction"] = x["label"].clone()
|
||||
dataset.plot(x)
|
||||
plt.close()
|
||||
|
||||
|
||||
class TestRESISC45DataModule:
    """Smoke tests for the loaders of ``RESISC45DataModule``."""

    @pytest.fixture(scope="class")
    def datamodule(self) -> RESISC45DataModule:
        """Create a prepared and set-up RESISC45 datamodule.

        Returns:
            The datamodule after ``prepare_data()`` and ``setup()`` were called.
        """
        data_dir = os.path.join("tests", "data", "resisc45")
        # Positional arguments: root, batch size (2), number of workers (0).
        module = RESISC45DataModule(data_dir, 2, 0)
        module.prepare_data()
        module.setup()
        return module

    def test_train_dataloader(self, datamodule: RESISC45DataModule) -> None:
        """Pull one batch from the training loader."""
        loader = datamodule.train_dataloader()
        next(iter(loader))

    def test_val_dataloader(self, datamodule: RESISC45DataModule) -> None:
        """Pull one batch from the validation loader."""
        loader = datamodule.val_dataloader()
        next(iter(loader))

    def test_test_dataloader(self, datamodule: RESISC45DataModule) -> None:
        """Pull one batch from the test loader."""
        loader = datamodule.test_dataloader()
        next(iter(loader))
|
||||
|
|
|
@ -12,7 +12,7 @@ from _pytest.fixtures import SubRequest
|
|||
from _pytest.monkeypatch import MonkeyPatch
|
||||
from torch.utils.data import ConcatDataset
|
||||
|
||||
from torchgeo.datasets import SEN12MS, SEN12MSDataModule
|
||||
from torchgeo.datasets import SEN12MS
|
||||
|
||||
|
||||
class TestSEN12MS:
|
||||
|
@ -82,26 +82,3 @@ class TestSEN12MS:
|
|||
ds = SEN12MS(root, bands=bands, checksum=False)
|
||||
x = ds[0]["image"]
|
||||
assert x.shape[0] == len(bands)
|
||||
|
||||
|
||||
class TestSEN12MSDataModule:
    """Smoke tests for the loaders of ``SEN12MSDataModule``."""

    @pytest.fixture(scope="class", params=["all", "s1", "s2-all", "s2-reduced"])
    def datamodule(self, request: SubRequest) -> SEN12MSDataModule:
        """Create a prepared and set-up datamodule for each band selection.

        Args:
            request: pytest request whose ``param`` is the ``bands`` argument.

        Returns:
            The datamodule after ``prepare_data()`` and ``setup()`` were called.
        """
        data_dir = os.path.join("tests", "data", "sen12ms")
        # Positional arguments: root, seed (0), bands, batch size (1),
        # number of workers (0).
        module = SEN12MSDataModule(data_dir, 0, request.param, 1, 0)
        module.prepare_data()
        module.setup()
        return module

    def test_train_dataloader(self, datamodule: SEN12MSDataModule) -> None:
        """Pull one batch from the training loader."""
        loader = datamodule.train_dataloader()
        next(iter(loader))

    def test_val_dataloader(self, datamodule: SEN12MSDataModule) -> None:
        """Pull one batch from the validation loader."""
        loader = datamodule.val_dataloader()
        next(iter(loader))

    def test_test_dataloader(self, datamodule: SEN12MSDataModule) -> None:
        """Pull one batch from the test loader."""
        loader = datamodule.test_dataloader()
        next(iter(loader))
|
||||
|
|
|
@ -13,7 +13,7 @@ import torch.nn as nn
|
|||
from _pytest.fixtures import SubRequest
|
||||
from _pytest.monkeypatch import MonkeyPatch
|
||||
|
||||
from torchgeo.datasets import So2Sat, So2SatDataModule
|
||||
from torchgeo.datasets import So2Sat
|
||||
|
||||
pytest.importorskip("h5py")
|
||||
|
||||
|
@ -91,25 +91,3 @@ class TestSo2Sat:
|
|||
match="h5py is not installed and is required to use this dataset",
|
||||
):
|
||||
So2Sat(dataset.root)
|
||||
|
||||
|
||||
class TestSo2SatDataModule:
    """Smoke tests for the loaders of ``So2SatDataModule``."""

    @pytest.fixture(scope="class", params=[(True, "rgb"), (False, "s2")])
    def datamodule(self, request: SubRequest) -> So2SatDataModule:
        """Create a prepared and set-up So2Sat datamodule.

        Args:
            request: pytest request whose ``param`` is an
                ``(unsupervised_mode, bands)`` pair.

        Returns:
            The datamodule after ``prepare_data()`` and ``setup()`` were called.
        """
        unsupervised_mode, bands = request.param
        data_dir = os.path.join("tests", "data", "so2sat")
        # Positional arguments: root, batch size (2), number of workers (0),
        # bands, unsupervised_mode.
        module = So2SatDataModule(data_dir, 2, 0, bands, unsupervised_mode)
        module.prepare_data()
        module.setup()
        return module

    def test_train_dataloader(self, datamodule: So2SatDataModule) -> None:
        """Pull one batch from the training loader."""
        loader = datamodule.train_dataloader()
        next(iter(loader))

    def test_val_dataloader(self, datamodule: So2SatDataModule) -> None:
        """Pull one batch from the validation loader."""
        loader = datamodule.val_dataloader()
        next(iter(loader))

    def test_test_dataloader(self, datamodule: So2SatDataModule) -> None:
        """Pull one batch from the test loader."""
        loader = datamodule.test_dataloader()
        next(iter(loader))
|
||||
|
|
|
@ -15,7 +15,7 @@ from _pytest.monkeypatch import MonkeyPatch
|
|||
from torch.utils.data import ConcatDataset
|
||||
|
||||
import torchgeo.datasets.utils
|
||||
from torchgeo.datasets import UCMerced, UCMercedDataModule
|
||||
from torchgeo.datasets import UCMerced
|
||||
|
||||
|
||||
def download_url(url: str, root: str, *args: str, **kwargs: str) -> None:
|
||||
|
@ -102,24 +102,3 @@ class TestUCMerced:
|
|||
x["prediction"] = x["label"].clone()
|
||||
dataset.plot(x)
|
||||
plt.close()
|
||||
|
||||
|
||||
class TestUCMercedDataModule:
    """Smoke tests for the loaders of ``UCMercedDataModule``."""

    @pytest.fixture(scope="class")
    def datamodule(self) -> UCMercedDataModule:
        """Create a prepared and set-up UC Merced datamodule.

        Returns:
            The datamodule after ``prepare_data()`` and ``setup()`` were called.
        """
        data_dir = os.path.join("tests", "data", "ucmerced")
        # Positional arguments: root, batch size (2), number of workers (0).
        module = UCMercedDataModule(data_dir, 2, 0)
        module.prepare_data()
        module.setup()
        return module

    def test_train_dataloader(self, datamodule: UCMercedDataModule) -> None:
        """Pull one batch from the training loader."""
        loader = datamodule.train_dataloader()
        next(iter(loader))

    def test_val_dataloader(self, datamodule: UCMercedDataModule) -> None:
        """Pull one batch from the validation loader."""
        loader = datamodule.val_dataloader()
        next(iter(loader))

    def test_test_dataloader(self, datamodule: UCMercedDataModule) -> None:
        """Pull one batch from the test loader."""
        loader = datamodule.test_dataloader()
        next(iter(loader))
|
||||
|
|
|
@ -18,13 +18,11 @@ import pytest
|
|||
import torch
|
||||
from _pytest.monkeypatch import MonkeyPatch
|
||||
from rasterio.crs import CRS
|
||||
from torch.utils.data import TensorDataset
|
||||
|
||||
import torchgeo.datasets.utils
|
||||
from torchgeo.datasets.utils import (
|
||||
BoundingBox,
|
||||
concat_samples,
|
||||
dataset_split,
|
||||
disambiguate_timestamp,
|
||||
download_and_extract_archive,
|
||||
download_radiant_mlhub_collection,
|
||||
|
@ -563,24 +561,6 @@ def test_nonexisting_directory(tmp_path: Path) -> None:
|
|||
assert subdir.cwd() == subdir
|
||||
|
||||
|
||||
def test_dataset_split() -> None:
    """Check the subset sizes returned by ``dataset_split`` on 24 samples."""
    num_samples = 24
    features = torch.ones(num_samples, 5)  # type: ignore[attr-defined]
    labels = torch.randint(  # type: ignore[attr-defined]
        low=0, high=2, size=(num_samples,)
    )
    ds = TensorDataset(features, labels)

    # Two-way split: half for training, half for validation.
    half = num_samples // 2
    train_ds, val_ds = dataset_split(ds, val_pct=1 / 2)
    assert len(train_ds) == half
    assert len(val_ds) == half

    # Three-way split: one third each for training, validation and test.
    third = num_samples // 3
    train_ds, val_ds, test_ds = dataset_split(ds, val_pct=1 / 3, test_pct=1 / 3)
    assert len(train_ds) == third
    assert len(val_ds) == third
    assert len(test_ds) == third
|
||||
|
||||
|
||||
def test_percentile_normalization() -> None:
|
||||
img = np.array([[1, 2], [98, 100]])
|
||||
|
||||
|
|
|
@ -13,7 +13,7 @@ import torch.nn as nn
|
|||
from _pytest.fixtures import SubRequest
|
||||
from _pytest.monkeypatch import MonkeyPatch
|
||||
|
||||
from torchgeo.datasets import Vaihingen2D, Vaihingen2DDataModule
|
||||
from torchgeo.datasets import Vaihingen2D
|
||||
|
||||
|
||||
class TestVaihingen2D:
|
||||
|
@ -84,27 +84,3 @@ class TestVaihingen2D:
|
|||
x["prediction"] = x["mask"].clone()
|
||||
dataset.plot(x)
|
||||
plt.close()
|
||||
|
||||
|
||||
class TestVaihingen2DDataModule:
    """Smoke tests for the loaders of ``Vaihingen2DDataModule``."""

    @pytest.fixture(scope="class", params=[0.0, 0.5])
    def datamodule(self, request: SubRequest) -> Vaihingen2DDataModule:
        """Create a prepared and set-up datamodule for each ``val_split_pct``.

        Args:
            request: pytest request whose ``param`` is used as ``val_split_pct``.

        Returns:
            The datamodule after ``prepare_data()`` and ``setup()`` were called.
        """
        data_dir = os.path.join("tests", "data", "vaihingen")
        # Positional arguments: root, batch size (1), number of workers (0).
        module = Vaihingen2DDataModule(data_dir, 1, 0, val_split_pct=request.param)
        module.prepare_data()
        module.setup()
        return module

    def test_train_dataloader(self, datamodule: Vaihingen2DDataModule) -> None:
        """Pull one batch from the training loader."""
        loader = datamodule.train_dataloader()
        next(iter(loader))

    def test_val_dataloader(self, datamodule: Vaihingen2DDataModule) -> None:
        """Pull one batch from the validation loader."""
        loader = datamodule.val_dataloader()
        next(iter(loader))

    def test_test_dataloader(self, datamodule: Vaihingen2DDataModule) -> None:
        """Pull one batch from the test loader."""
        loader = datamodule.test_dataloader()
        next(iter(loader))
|
||||
|
|
|
@ -13,7 +13,7 @@ import torch.nn as nn
|
|||
from _pytest.fixtures import SubRequest
|
||||
from _pytest.monkeypatch import MonkeyPatch
|
||||
|
||||
from torchgeo.datasets import XView2, XView2DataModule
|
||||
from torchgeo.datasets import XView2
|
||||
|
||||
|
||||
class TestXView2:
|
||||
|
@ -95,27 +95,3 @@ class TestXView2:
|
|||
x["prediction"] = x["mask"][0].clone()
|
||||
dataset.plot(x)
|
||||
plt.close()
|
||||
|
||||
|
||||
class TestXView2DataModule:
    """Smoke tests for the loaders of ``XView2DataModule``."""

    @pytest.fixture(scope="class", params=[0.0, 0.5])
    def datamodule(self, request: SubRequest) -> XView2DataModule:
        """Create a prepared and set-up datamodule for each ``val_split_pct``.

        Args:
            request: pytest request whose ``param`` is used as ``val_split_pct``.

        Returns:
            The datamodule after ``prepare_data()`` and ``setup()`` were called.
        """
        data_dir = os.path.join("tests", "data", "xview2")
        # Positional arguments: root, batch size (1), number of workers (0).
        module = XView2DataModule(data_dir, 1, 0, val_split_pct=request.param)
        module.prepare_data()
        module.setup()
        return module

    def test_train_dataloader(self, datamodule: XView2DataModule) -> None:
        """Pull one batch from the training loader."""
        loader = datamodule.train_dataloader()
        next(iter(loader))

    def test_val_dataloader(self, datamodule: XView2DataModule) -> None:
        """Pull one batch from the validation loader."""
        loader = datamodule.val_dataloader()
        next(iter(loader))

    def test_test_dataloader(self, datamodule: XView2DataModule) -> None:
        """Pull one batch from the test loader."""
        loader = datamodule.test_dataloader()
        next(iter(loader))
|
||||
|
|
|
@ -12,7 +12,7 @@ from omegaconf import OmegaConf
|
|||
from pytorch_lightning.core.lightning import LightningModule
|
||||
from torchvision.models import resnet18
|
||||
|
||||
from torchgeo.datasets import ChesapeakeCVPRDataModule
|
||||
from torchgeo.datamodules import ChesapeakeCVPRDataModule
|
||||
from torchgeo.trainers import BYOLTask
|
||||
from torchgeo.trainers.byol import BYOL, SimCLRAugmentation
|
||||
|
||||
|
|
|
@ -9,7 +9,7 @@ from _pytest.fixtures import SubRequest
|
|||
from _pytest.monkeypatch import MonkeyPatch
|
||||
from omegaconf import OmegaConf
|
||||
|
||||
from torchgeo.datasets import ChesapeakeCVPRDataModule
|
||||
from torchgeo.datamodules import ChesapeakeCVPRDataModule
|
||||
from torchgeo.trainers.chesapeake import ChesapeakeCVPRSegmentationTask
|
||||
|
||||
from .test_utils import FakeTrainer, mocked_log
|
||||
|
|
|
@ -8,7 +8,7 @@ import pytest
|
|||
from _pytest.monkeypatch import MonkeyPatch
|
||||
from omegaconf import OmegaConf
|
||||
|
||||
from torchgeo.datasets import LandCoverAIDataModule
|
||||
from torchgeo.datamodules import LandCoverAIDataModule
|
||||
from torchgeo.trainers.landcoverai import LandCoverAISegmentationTask
|
||||
|
||||
from .test_utils import FakeTrainer, mocked_log
|
||||
|
|
|
@ -8,7 +8,7 @@ import pytest
|
|||
from _pytest.monkeypatch import MonkeyPatch
|
||||
from omegaconf import OmegaConf
|
||||
|
||||
from torchgeo.datasets import NAIPChesapeakeDataModule
|
||||
from torchgeo.datamodules import NAIPChesapeakeDataModule
|
||||
from torchgeo.trainers.naipchesapeake import NAIPChesapeakeSegmentationTask
|
||||
|
||||
from .test_utils import FakeTrainer, mocked_log
|
||||
|
|
|
@ -8,7 +8,7 @@ import pytest
|
|||
from _pytest.monkeypatch import MonkeyPatch
|
||||
from omegaconf import OmegaConf
|
||||
|
||||
from torchgeo.datasets import CycloneDataModule
|
||||
from torchgeo.datamodules import CycloneDataModule
|
||||
from torchgeo.trainers import RegressionTask
|
||||
|
||||
from .test_utils import mocked_log
|
||||
|
|
|
@ -7,7 +7,7 @@ from typing import Any, Dict, Generator
|
|||
import pytest
|
||||
from _pytest.monkeypatch import MonkeyPatch
|
||||
|
||||
from torchgeo.datasets import RESISC45DataModule
|
||||
from torchgeo.datamodules import RESISC45DataModule
|
||||
from torchgeo.trainers.resisc45 import RESISC45ClassificationTask
|
||||
|
||||
from .test_utils import FakeTrainer, mocked_log
|
||||
|
|
|
@ -9,7 +9,7 @@ from _pytest.fixtures import SubRequest
|
|||
from _pytest.monkeypatch import MonkeyPatch
|
||||
from omegaconf import OmegaConf
|
||||
|
||||
from torchgeo.datasets import ChesapeakeCVPRDataModule
|
||||
from torchgeo.datamodules import ChesapeakeCVPRDataModule
|
||||
from torchgeo.trainers import SemanticSegmentationTask
|
||||
|
||||
from .test_utils import FakeTrainer, mocked_log
|
||||
|
|
|
@ -13,7 +13,7 @@ from typing import Dict, Tuple, Type
|
|||
|
||||
import pytorch_lightning as pl
|
||||
|
||||
from .datasets import (
|
||||
from .datamodules import (
|
||||
BigEarthNetDataModule,
|
||||
ChesapeakeCVPRDataModule,
|
||||
COWCCountingDataModule,
|
||||
|
|
|
@ -0,0 +1,52 @@
|
|||
# Copyright (c) Microsoft Corporation. All rights reserved.
|
||||
# Licensed under the MIT License.
|
||||
|
||||
"""TorchGeo datamodules."""
|
||||
|
||||
from .bigearthnet import BigEarthNetDataModule
|
||||
from .chesapeake import ChesapeakeCVPRDataModule
|
||||
from .cowc import COWCCountingDataModule
|
||||
from .cyclone import CycloneDataModule
|
||||
from .etci2021 import ETCI2021DataModule
|
||||
from .eurosat import EuroSATDataModule
|
||||
from .fair1m import FAIR1MDataModule
|
||||
from .landcoverai import LandCoverAIDataModule
|
||||
from .loveda import LoveDADataModule
|
||||
from .naip import NAIPChesapeakeDataModule
|
||||
from .nasa_marine_debris import NASAMarineDebrisDataModule
|
||||
from .oscd import OSCDDataModule
|
||||
from .potsdam import Potsdam2DDataModule
|
||||
from .resisc45 import RESISC45DataModule
|
||||
from .sen12ms import SEN12MSDataModule
|
||||
from .so2sat import So2SatDataModule
|
||||
from .ucmerced import UCMercedDataModule
|
||||
from .vaihingen import Vaihingen2DDataModule
|
||||
from .xview import XView2DataModule
|
||||
|
||||
# Public API of the package, grouped by the kind of dataset each datamodule wraps.
__all__ = (
    # GeoDataset
    "ChesapeakeCVPRDataModule",
    "NAIPChesapeakeDataModule",
    # VisionDataset
    "BigEarthNetDataModule",
    "COWCCountingDataModule",
    "ETCI2021DataModule",
    "EuroSATDataModule",
    "FAIR1MDataModule",
    "LandCoverAIDataModule",
    "LoveDADataModule",
    "NASAMarineDebrisDataModule",
    "OSCDDataModule",
    "Potsdam2DDataModule",
    "RESISC45DataModule",
    "SEN12MSDataModule",
    "So2SatDataModule",
    "CycloneDataModule",
    "UCMercedDataModule",
    "Vaihingen2DDataModule",
    "XView2DataModule",
)

# Rewrite ``__module__`` on every exported class so that it reports the package
# rather than the submodule it is defined in.
# https://stackoverflow.com/questions/40018681
for module in __all__:
    setattr(globals()[module], "__module__", "torchgeo.datamodules")
|
|
@ -0,0 +1,178 @@
|
|||
# Copyright (c) Microsoft Corporation. All rights reserved.
|
||||
# Licensed under the MIT License.
|
||||
|
||||
"""BigEarthNet datamodule."""
|
||||
|
||||
from typing import Any, Dict, Optional
|
||||
|
||||
import pytorch_lightning as pl
|
||||
import torch
|
||||
from torch.utils.data import DataLoader
|
||||
from torchvision.transforms import Compose
|
||||
|
||||
from ..datasets import BigEarthNet
|
||||
|
||||
# https://github.com/pytorch/pytorch/issues/60979
|
||||
# https://github.com/pytorch/pytorch/pull/61045
|
||||
DataLoader.__module__ = "torch.utils.data"
|
||||
|
||||
|
||||
class BigEarthNetDataModule(pl.LightningDataModule):
    """LightningDataModule wrapping the BigEarthNet dataset.

    The train/val/test splits are the ones provided by the dataset itself.
    """

    # Band order shared by every statistic below:
    # (VV, VH, B01, B02, B03, B04, B05, B06, B07, B08, B8A, B09, B11, B12)

    # Raw min/max band statistics computed on 100k random samples.
    band_mins_raw = torch.tensor(  # type: ignore[attr-defined]
        [-70.0, -72.0, 1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 1.0, 0.0, 0.0]
    )
    band_maxs_raw = torch.tensor(  # type: ignore[attr-defined]
        [
            31.0,
            35.0,
            18556.0,
            20528.0,
            18976.0,
            17874.0,
            16611.0,
            16512.0,
            16394.0,
            16672.0,
            16141.0,
            16097.0,
            15336.0,
            15203.0,
        ]
    )

    # Min/max band statistics computed by percentile clipping the above samples
    # to [2, 98]. These are the values used for normalization in ``preprocess``.
    band_mins = torch.tensor(  # type: ignore[attr-defined]
        [-48.0, -42.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0]
    )
    band_maxs = torch.tensor(  # type: ignore[attr-defined]
        [
            6.0,
            16.0,
            9859.0,
            12872.0,
            13163.0,
            14445.0,
            12477.0,
            12563.0,
            12289.0,
            15596.0,
            12183.0,
            9458.0,
            5897.0,
            5544.0,
        ]
    )

    def __init__(
        self,
        root_dir: str,
        bands: str = "all",
        num_classes: int = 19,
        batch_size: int = 64,
        num_workers: int = 0,
        **kwargs: Any,
    ) -> None:
        """Initialize a LightningDataModule for BigEarthNet based DataLoaders.

        Args:
            root_dir: The ``root`` argument to pass to the BigEarthNet Dataset classes
            bands: load Sentinel-1 bands, Sentinel-2, or both. one of {s1, s2, all}
            num_classes: number of classes to load in target. one of {19, 43}
            batch_size: The batch size to use in all created DataLoaders
            num_workers: The number of workers to use in all created DataLoaders
            **kwargs: accepted for call-site compatibility; not used here
        """
        super().__init__()  # type: ignore[no-untyped-call]
        self.root_dir = root_dir
        self.bands = bands
        self.num_classes = num_classes
        self.batch_size = batch_size
        self.num_workers = num_workers

        # Select the statistics for the requested bands: the first two entries
        # are used for "s1", the remaining ones for anything that is neither
        # "all" nor "s1".
        if bands == "all":
            band_slice = slice(None)
        elif bands == "s1":
            band_slice = slice(None, 2)
        else:
            band_slice = slice(2, None)

        # Trailing singleton axes let the statistics broadcast in ``preprocess``.
        self.mins = self.band_mins[band_slice, None, None]
        self.maxs = self.band_maxs[band_slice, None, None]

    def preprocess(self, sample: Dict[str, Any]) -> Dict[str, Any]:
        """Min-max normalize the image of a single sample and clip it to [0, 1].

        Args:
            sample: sample dictionary; its ``"image"`` entry is replaced

        Returns:
            the same dictionary with the normalized image
        """
        image = sample["image"].float()
        image = (image - self.mins) / (self.maxs - self.mins)
        sample["image"] = torch.clip(  # type: ignore[attr-defined]
            image, min=0.0, max=1.0
        )
        return sample

    def prepare_data(self) -> None:
        """Make sure that the dataset is downloaded.

        This method is only called once per run.
        """
        BigEarthNet(self.root_dir, split="train", bands=self.bands, checksum=False)

    def setup(self, stage: Optional[str] = None) -> None:
        """Initialize the main ``Dataset`` objects.

        This method is called once per GPU per run.

        Args:
            stage: stage to set up (not used)
        """
        transforms = Compose([self.preprocess])
        # Every split shares the same bands, class count and transform.
        shared = dict(
            bands=self.bands, num_classes=self.num_classes, transforms=transforms
        )
        self.train_dataset = BigEarthNet(self.root_dir, split="train", **shared)
        self.val_dataset = BigEarthNet(self.root_dir, split="val", **shared)
        self.test_dataset = BigEarthNet(self.root_dir, split="test", **shared)

    def train_dataloader(self) -> DataLoader[Any]:
        """Return a shuffling DataLoader over the training dataset."""
        return DataLoader(
            self.train_dataset,
            shuffle=True,
            batch_size=self.batch_size,
            num_workers=self.num_workers,
        )

    def val_dataloader(self) -> DataLoader[Any]:
        """Return a non-shuffling DataLoader over the validation dataset."""
        return DataLoader(
            self.val_dataset,
            shuffle=False,
            batch_size=self.batch_size,
            num_workers=self.num_workers,
        )

    def test_dataloader(self) -> DataLoader[Any]:
        """Return a non-shuffling DataLoader over the test dataset."""
        return DataLoader(
            self.test_dataset,
            shuffle=False,
            batch_size=self.batch_size,
            num_workers=self.num_workers,
        )
|
|
@ -0,0 +1,312 @@
|
|||
# Copyright (c) Microsoft Corporation. All rights reserved.
|
||||
# Licensed under the MIT License.
|
||||
|
||||
"""Chesapeake Bay High-Resolution Land Cover Project datamodule."""
|
||||
|
||||
from typing import Any, Callable, Dict, List, Optional
|
||||
|
||||
import torch
|
||||
import torch.nn.functional as F
|
||||
from pytorch_lightning.core.datamodule import LightningDataModule
|
||||
from torch import Tensor
|
||||
from torch.utils.data import DataLoader
|
||||
from torchvision.transforms import Compose
|
||||
|
||||
from ..datasets import ChesapeakeCVPR, stack_samples
|
||||
from ..samplers.batch import RandomBatchGeoSampler
|
||||
from ..samplers.single import GridGeoSampler
|
||||
|
||||
# https://github.com/pytorch/pytorch/issues/60979
|
||||
# https://github.com/pytorch/pytorch/pull/61045
|
||||
DataLoader.__module__ = "torch.utils.data"
|
||||
|
||||
|
||||
class ChesapeakeCVPRDataModule(LightningDataModule):
    """LightningDataModule implementation for the Chesapeake CVPR Land Cover dataset.

    Uses the random splits defined per state to partition tiles into train, val,
    and test sets.
    """

    def __init__(
        self,
        root_dir: str,
        train_splits: List[str],
        val_splits: List[str],
        test_splits: List[str],
        patches_per_tile: int = 200,
        patch_size: int = 256,
        batch_size: int = 64,
        num_workers: int = 0,
        class_set: int = 7,
        **kwargs: Any,
    ) -> None:
        """Initialize a LightningDataModule for Chesapeake CVPR based DataLoaders.

        Args:
            root_dir: The ``root`` argument to pass to the ChesapeakeCVPR Dataset
                classes
            train_splits: The splits used to train the model, e.g. ["ny-train"]
            val_splits: The splits used to validate the model, e.g. ["ny-val"]
            test_splits: The splits used to test the model, e.g. ["ny-test"]
            patches_per_tile: The number of patches per tile to sample
            patch_size: The size of each patch in pixels (patches are sampled at
                twice this size; train/val patches are center cropped back to
                ``patch_size`` while test patches are padded to the doubled size)
            batch_size: The batch size to use in all created DataLoaders
            num_workers: The number of workers to use in all created DataLoaders
            class_set: The high-resolution land cover class set to use - 5 or 7
            **kwargs: accepted for call-site compatibility; not used here

        Raises:
            AssertionError: if a split name is not in ``ChesapeakeCVPR.splits`` or
                ``class_set`` is not 5 or 7
        """
        super().__init__()  # type: ignore[no-untyped-call]
        for state in train_splits + val_splits + test_splits:
            assert state in ChesapeakeCVPR.splits
        assert class_set in [5, 7]

        self.root_dir = root_dir
        self.train_splits = train_splits
        self.val_splits = val_splits
        self.test_splits = test_splits
        self.layers = ["naip-new", "lc"]
        self.patches_per_tile = patches_per_tile
        self.patch_size = patch_size
        # This is a rough estimate of how large of a patch we will need to sample in
        # EPSG:3857 in order to guarantee a large enough patch in the local CRS.
        self.original_patch_size = int(patch_size * 2.0)
        self.batch_size = batch_size
        self.num_workers = num_workers
        self.class_set = class_set

    def pad_to(
        self, size: int = 512, image_value: int = 0, mask_value: int = 0
    ) -> Callable[[Dict[str, Tensor]], Dict[str, Tensor]]:
        """Returns a function to perform a padding transform on a single sample.

        Padding is added on the right and bottom edges only.

        Args:
            size: output image size
            image_value: value to pad image with
            mask_value: value to pad mask with

        Returns:
            function to perform padding
        """

        def pad_inner(sample: Dict[str, Tensor]) -> Dict[str, Tensor]:
            _, height, width = sample["image"].shape
            assert height <= size and width <= size

            height_pad = size - height
            width_pad = size - width

            # See https://pytorch.org/docs/stable/generated/torch.nn.functional.pad.html
            # for a description of the format of the padding tuple
            padding = (0, width_pad, 0, height_pad)
            sample["image"] = F.pad(
                sample["image"], padding, mode="constant", value=image_value
            )
            sample["mask"] = F.pad(
                sample["mask"], padding, mode="constant", value=mask_value
            )
            return sample

        return pad_inner

    def center_crop(
        self, size: int = 512
    ) -> Callable[[Dict[str, Tensor]], Dict[str, Tensor]]:
        """Returns a function to perform a center crop transform on a single sample.

        Args:
            size: output image size

        Returns:
            function to perform center crop
        """

        def center_crop_inner(sample: Dict[str, Tensor]) -> Dict[str, Tensor]:
            _, height, width = sample["image"].shape

            y1 = (height - size) // 2
            x1 = (width - size) // 2
            # The same window is applied to image and mask so they stay aligned.
            sample["image"] = sample["image"][:, y1 : y1 + size, x1 : x1 + size]
            sample["mask"] = sample["mask"][:, y1 : y1 + size, x1 : x1 + size]

            return sample

        return center_crop_inner

    def preprocess(self, sample: Dict[str, Any]) -> Dict[str, Any]:
        """Preprocesses a single sample.

        Scales the image by 1/255, squeezes the mask, optionally remaps the mask
        to the 5 class set, and casts the image to float and the mask to long.

        Args:
            sample: sample dictionary containing image and mask

        Returns:
            preprocessed sample
        """
        sample["image"] = sample["image"] / 255.0
        sample["mask"] = sample["mask"].squeeze()

        if self.class_set == 5:
            # Classes 5 and 6 are merged into class 4 for the 5 class set.
            sample["mask"][sample["mask"] == 5] = 4
            sample["mask"][sample["mask"] == 6] = 4

        sample["image"] = sample["image"].float()
        sample["mask"] = sample["mask"].long()

        return sample

    def nodata_check(
        self, size: int = 512
    ) -> Callable[[Dict[str, Tensor]], Dict[str, Tensor]]:
        """Returns a function to check for nodata or mis-sized input.

        Samples whose image is smaller than ``size`` in either spatial dimension
        have their image and mask replaced with all-zero tensors.

        Args:
            size: output image size

        Returns:
            function to check for nodata values
        """

        def nodata_check_inner(sample: Dict[str, Tensor]) -> Dict[str, Tensor]:
            num_channels, height, width = sample["image"].shape

            if height < size or width < size:
                sample["image"] = torch.zeros(  # type: ignore[attr-defined]
                    (num_channels, size, size)
                )
                sample["mask"] = torch.zeros((size, size))  # type: ignore[attr-defined]

            return sample

        return nodata_check_inner

    def prepare_data(self) -> None:
        """Confirms that the dataset is downloaded on the local node.

        Instantiates the training dataset with ``download=False`` and
        ``checksum=False``; presumably the dataset constructor raises if the data
        is missing (confirm against ``ChesapeakeCVPR``).

        This method is called once per node, while :func:`setup` is called once per GPU.
        """
        ChesapeakeCVPR(
            self.root_dir,
            splits=self.train_splits,
            layers=self.layers,
            transforms=None,
            download=False,
            checksum=False,
        )

    def setup(self, stage: Optional[str] = None) -> None:
        """Create the train/val/test splits based on the original Dataset objects.

        The splits should be done here vs. in :func:`__init__` per the docs:
        https://pytorch-lightning.readthedocs.io/en/latest/extensions/datamodules.html#setup.

        Args:
            stage: stage to set up
        """

        def crop_pipeline() -> Compose:
            # Train and val share the same steps: crop the oversized sample back
            # to ``patch_size``, blank out mis-sized samples, then preprocess.
            return Compose(
                [
                    self.center_crop(self.patch_size),
                    self.nodata_check(self.patch_size),
                    self.preprocess,
                ]
            )

        train_transforms = crop_pipeline()
        val_transforms = crop_pipeline()
        # Test samples are not cropped; they are padded up to the sampled size.
        test_transforms = Compose(
            [
                self.pad_to(self.original_patch_size, image_value=0, mask_value=0),
                self.preprocess,
            ]
        )

        self.train_dataset = ChesapeakeCVPR(
            self.root_dir,
            splits=self.train_splits,
            layers=self.layers,
            transforms=train_transforms,
            download=False,
            checksum=False,
        )
        self.val_dataset = ChesapeakeCVPR(
            self.root_dir,
            splits=self.val_splits,
            layers=self.layers,
            transforms=val_transforms,
            download=False,
            checksum=False,
        )
        self.test_dataset = ChesapeakeCVPR(
            self.root_dir,
            splits=self.test_splits,
            layers=self.layers,
            transforms=test_transforms,
            download=False,
            checksum=False,
        )

    def train_dataloader(self) -> DataLoader[Any]:
        """Return a DataLoader for training.

        Batches are drawn by a ``RandomBatchGeoSampler`` whose length is
        ``patches_per_tile * len(train_dataset)``.

        Returns:
            training data loader
        """
        sampler = RandomBatchGeoSampler(
            self.train_dataset,
            size=self.original_patch_size,
            batch_size=self.batch_size,
            length=self.patches_per_tile * len(self.train_dataset),
        )
        return DataLoader(
            self.train_dataset,
            batch_sampler=sampler,
            num_workers=self.num_workers,
            collate_fn=stack_samples,
        )

    def val_dataloader(self) -> DataLoader[Any]:
        """Return a DataLoader for validation.

        Samples are drawn by a ``GridGeoSampler`` whose stride equals its size.

        Returns:
            validation data loader
        """
        sampler = GridGeoSampler(
            self.val_dataset,
            size=self.original_patch_size,
            stride=self.original_patch_size,
        )
        return DataLoader(
            self.val_dataset,
            batch_size=self.batch_size,
            sampler=sampler,
            num_workers=self.num_workers,
            collate_fn=stack_samples,
        )

    def test_dataloader(self) -> DataLoader[Any]:
        """Return a DataLoader for testing.

        Samples are drawn by a ``GridGeoSampler`` whose stride equals its size.

        Returns:
            testing data loader
        """
        sampler = GridGeoSampler(
            self.test_dataset,
            size=self.original_patch_size,
            stride=self.original_patch_size,
        )
        return DataLoader(
            self.test_dataset,
            batch_size=self.batch_size,
            sampler=sampler,
            num_workers=self.num_workers,
            collate_fn=stack_samples,
        )
|
|
@ -0,0 +1,123 @@
|
|||
# Copyright (c) Microsoft Corporation. All rights reserved.
|
||||
# Licensed under the MIT License.
|
||||
|
||||
"""COWC datamodule."""
|
||||
|
||||
from typing import Any, Dict, Optional
|
||||
|
||||
import pytorch_lightning as pl
|
||||
from torch import Generator # type: ignore[attr-defined]
|
||||
from torch.utils.data import DataLoader, random_split
|
||||
|
||||
from ..datasets import COWCCounting
|
||||
|
||||
# https://github.com/pytorch/pytorch/issues/60979
|
||||
# https://github.com/pytorch/pytorch/pull/61045
|
||||
DataLoader.__module__ = "torch.utils.data"
|
||||
|
||||
|
||||
class COWCCountingDataModule(pl.LightningDataModule):
    """LightningDataModule implementation for the COWC Counting dataset."""

    def __init__(
        self,
        root_dir: str,
        seed: int,
        batch_size: int = 64,
        num_workers: int = 0,
        **kwargs: Any,
    ) -> None:
        """Initialize a LightningDataModule for COWC Counting based DataLoaders.

        Args:
            root_dir: The ``root`` argument to pass to the COWCCounting Dataset class
            seed: The seed value to use when doing the dataset random_split
            batch_size: The batch size to use in all created DataLoaders
            num_workers: The number of workers to use in all created DataLoaders
            **kwargs: accepted for call-site compatibility; not used here
        """
        super().__init__()  # type: ignore[no-untyped-call]
        self.root_dir = root_dir
        self.seed = seed
        self.batch_size = batch_size
        self.num_workers = num_workers

    def custom_transform(self, sample: Dict[str, Any]) -> Dict[str, Any]:
        """Transform a single sample from the Dataset.

        Args:
            sample: dictionary containing image and target

        Returns:
            preprocessed sample
        """
        # Divide the image by 255 (scale to [0, 1]) and cast the label to float.
        sample.update(
            {"image": sample["image"] / 255.0, "label": sample["label"].float()}
        ) if False else None
        sample["image"] = sample["image"] / 255.0  # scale to [0, 1]
        sample["label"] = sample["label"].float()
        return sample

    def prepare_data(self) -> None:
        """Instantiate the ``Dataset`` once with downloading disabled.

        This is done once per node, while :func:`setup` is done once per GPU.
        """
        COWCCounting(self.root_dir, download=False)

    def setup(self, stage: Optional[str] = None) -> None:
        """Create the train/val/test splits based on the original Dataset objects.

        The splits should be done here vs. in :func:`__init__` per the docs:
        https://pytorch-lightning.readthedocs.io/en/latest/extensions/datamodules.html#setup.

        The validation set is split off the dataset's "train" split and has the
        same number of samples as the test set.

        Args:
            stage: stage to set up
        """
        train_val_dataset = COWCCounting(
            self.root_dir, split="train", transforms=self.custom_transform
        )
        self.test_dataset = COWCCounting(
            self.root_dir, split="test", transforms=self.custom_transform
        )

        num_val = len(self.test_dataset)
        num_train = len(train_val_dataset) - num_val
        # A seeded generator keeps the train/val partition reproducible.
        rng = Generator().manual_seed(self.seed)
        self.train_dataset, self.val_dataset = random_split(
            train_val_dataset, [num_train, num_val], generator=rng
        )

    def train_dataloader(self) -> DataLoader[Any]:
        """Return a DataLoader for training.

        Returns:
            training data loader
        """
        return DataLoader(
            self.train_dataset,
            shuffle=True,
            batch_size=self.batch_size,
            num_workers=self.num_workers,
        )

    def val_dataloader(self) -> DataLoader[Any]:
        """Return a DataLoader for validation.

        Returns:
            validation data loader
        """
        return DataLoader(
            self.val_dataset,
            shuffle=False,
            batch_size=self.batch_size,
            num_workers=self.num_workers,
        )

    def test_dataloader(self) -> DataLoader[Any]:
        """Return a DataLoader for testing.

        Returns:
            testing data loader
        """
        return DataLoader(
            self.test_dataset,
            shuffle=False,
            batch_size=self.batch_size,
            num_workers=self.num_workers,
        )
|
|
@ -0,0 +1,171 @@
|
|||
# Copyright (c) Microsoft Corporation. All rights reserved.
|
||||
# Licensed under the MIT License.
|
||||
|
||||
"""Tropical Cyclone Wind Estimation Competition datamodule."""
|
||||
|
||||
from typing import Any, Dict, Optional
|
||||
|
||||
import pytorch_lightning as pl
|
||||
import torch
|
||||
from sklearn.model_selection import GroupShuffleSplit
|
||||
from torch.utils.data import DataLoader, Subset
|
||||
|
||||
from ..datasets import TropicalCycloneWindEstimation
|
||||
|
||||
# https://github.com/pytorch/pytorch/issues/60979
|
||||
# https://github.com/pytorch/pytorch/pull/61045
|
||||
DataLoader.__module__ = "torch.utils.data"
|
||||
|
||||
|
||||
class CycloneDataModule(pl.LightningDataModule):
    """LightningDataModule implementation for the NASA Cyclone dataset.

    Implements 80/20 train/val splits based on hurricane storm ids.
    See :func:`setup` for more details.
    """

    def __init__(
        self,
        root_dir: str,
        seed: int,
        batch_size: int = 64,
        num_workers: int = 0,
        api_key: Optional[str] = None,
        **kwargs: Any,
    ) -> None:
        """Initialize a LightningDataModule for NASA Cyclone based DataLoaders.

        Args:
            root_dir: The ``root`` argument to pass to the
                TropicalCycloneWindEstimation Datasets classes
            seed: The seed value to use when doing the sklearn based GroupShuffleSplit
            batch_size: The batch size to use in all created DataLoaders
            num_workers: The number of workers to use in all created DataLoaders
            api_key: The RadiantEarth MLHub API key to use if the dataset needs to be
                downloaded
            **kwargs: accepted for call-site compatibility; not used here
        """
        super().__init__()  # type: ignore[no-untyped-call]
        self.root_dir = root_dir
        self.seed = seed
        self.batch_size = batch_size
        self.num_workers = num_workers
        self.api_key = api_key

    def custom_transform(self, sample: Dict[str, Any]) -> Dict[str, Any]:
        """Transform a single sample from the Dataset.

        Args:
            sample: dictionary containing image and target

        Returns:
            preprocessed sample
        """
        scaled = sample["image"] / 255.0  # scale to [0,1]
        # Add a leading axis and repeat it three times to get a 3 channel image.
        sample["image"] = scaled.unsqueeze(0).repeat(3, 1, 1)

        label = torch.as_tensor(sample["label"])  # type: ignore[attr-defined]
        sample["label"] = label.float()

        return sample

    def prepare_data(self) -> None:
        """Initialize the main ``Dataset`` objects for use in :func:`setup`.

        The dataset is downloaded only when an ``api_key`` was given. This is done
        once per node, while :func:`setup` is done once per GPU.
        """
        should_download = self.api_key is not None
        TropicalCycloneWindEstimation(
            self.root_dir,
            split="train",
            transforms=self.custom_transform,
            download=should_download,
            api_key=self.api_key,
        )

    def setup(self, stage: Optional[str] = None) -> None:
        """Create the train/val/test splits based on the original Dataset objects.

        The splits should be done here vs. in :func:`__init__` per the docs:
        https://pytorch-lightning.readthedocs.io/en/latest/extensions/datamodules.html#setup.

        We split samples between train/val by the ``storm_id`` property. I.e. all
        samples with the same ``storm_id`` value will be either in the train or the val
        split. This is important to test one type of generalizability -- given a new
        storm, can we predict its windspeed. The test set, however, contains *some*
        storms from the training set (specifically, the latter parts of the storms) as
        well as some novel storms.

        Args:
            stage: stage to set up
        """
        self.all_train_dataset = TropicalCycloneWindEstimation(
            self.root_dir,
            split="train",
            transforms=self.custom_transform,
            download=False,
        )
        self.all_test_dataset = TropicalCycloneWindEstimation(
            self.root_dir,
            split="test",
            transforms=self.custom_transform,
            download=False,
        )

        # The storm id is the second-to-last "_"-separated token of the first
        # path component of each collection item's href.
        storm_ids = [
            item["href"].split("/")[0].split("_")[-2]
            for item in self.all_train_dataset.collection
        ]

        # Only the first of the generated group splits is used.
        splitter = GroupShuffleSplit(
            test_size=0.2, n_splits=2, random_state=self.seed
        )
        train_indices, val_indices = next(splitter.split(storm_ids, groups=storm_ids))

        self.train_dataset = Subset(self.all_train_dataset, train_indices)
        self.val_dataset = Subset(self.all_train_dataset, val_indices)
        # The test subset covers every sample of the test split.
        every_test_index = range(len(self.all_test_dataset))
        self.test_dataset = Subset(self.all_test_dataset, every_test_index)

    def train_dataloader(self) -> DataLoader[Any]:
        """Return a DataLoader for training.

        Returns:
            training data loader
        """
        return DataLoader(
            self.train_dataset,
            shuffle=True,
            batch_size=self.batch_size,
            num_workers=self.num_workers,
        )

    def val_dataloader(self) -> DataLoader[Any]:
        """Return a DataLoader for validation.

        Returns:
            validation data loader
        """
        return DataLoader(
            self.val_dataset,
            shuffle=False,
            batch_size=self.batch_size,
            num_workers=self.num_workers,
        )

    def test_dataloader(self) -> DataLoader[Any]:
        """Return a DataLoader for testing.

        Returns:
            testing data loader
        """
        return DataLoader(
            self.test_dataset,
            shuffle=False,
            batch_size=self.batch_size,
            num_workers=self.num_workers,
        )
|
|
@ -0,0 +1,151 @@
|
|||
# Copyright (c) Microsoft Corporation. All rights reserved.
|
||||
# Licensed under the MIT License.
|
||||
|
||||
"""ETCI 2021 datamodule."""
|
||||
|
||||
from typing import Any, Dict, Optional
|
||||
|
||||
import pytorch_lightning as pl
|
||||
import torch
|
||||
from torch import Generator # type: ignore[attr-defined]
|
||||
from torch.utils.data import DataLoader, random_split
|
||||
from torchvision.transforms import Normalize
|
||||
|
||||
from ..datasets import ETCI2021
|
||||
|
||||
|
||||
class ETCI2021DataModule(pl.LightningDataModule):
    """LightningDataModule implementation for the ETCI2021 dataset.

    Splits the existing train split from the dataset into train/val with 80/20
    proportions, then uses the existing val dataset as the test data.

    .. versionadded:: 0.2
    """

    # Per-channel normalization statistics. The final (mean 0, std 1) entry
    # leaves the last channel unchanged by normalization; ``preprocess``
    # concatenates the water mask after the image channels.
    band_means = torch.tensor(  # type: ignore[attr-defined]
        [0.52253931, 0.52253931, 0.52253931, 0.61221701, 0.61221701, 0.61221701, 0]
    )

    band_stds = torch.tensor(  # type: ignore[attr-defined]
        [0.35221376, 0.35221376, 0.35221376, 0.37364622, 0.37364622, 0.37364622, 1]
    )

    def __init__(
        self,
        root_dir: str,
        seed: int = 0,
        batch_size: int = 64,
        num_workers: int = 0,
        **kwargs: Any,
    ) -> None:
        """Initialize a LightningDataModule for ETCI2021 based DataLoaders.

        Args:
            root_dir: The ``root`` argument to pass to the ETCI2021 Dataset classes
            seed: The seed value to use when doing the dataset random_split
            batch_size: The batch size to use in all created DataLoaders
            num_workers: The number of workers to use in all created DataLoaders
            **kwargs: accepted for call-site compatibility; not used here
        """
        super().__init__()  # type: ignore[no-untyped-call]
        self.root_dir = root_dir
        self.seed = seed
        self.batch_size = batch_size
        self.num_workers = num_workers

        self.norm = Normalize(self.band_means, self.band_stds)

    def preprocess(self, sample: Dict[str, Any]) -> Dict[str, Any]:
        """Transform a single sample from the Dataset.

        Notably, moves the given water mask to act as an input layer.

        Args:
            sample: input image dictionary

        Returns:
            preprocessed sample
        """
        masks = sample["mask"]

        # Entry 0 of the mask is the water mask; it becomes an extra input channel.
        stacked = torch.cat(  # type: ignore[attr-defined]
            [sample["image"], masks[0].unsqueeze(0)], dim=0
        ).float()
        stacked /= 255.0
        sample["image"] = self.norm(stacked)

        # Entry 1 of the mask is the flood mask; binarize it as the target.
        sample["mask"] = (masks[1] > 0).long()
        return sample

    def prepare_data(self) -> None:
        """Make sure that the dataset is downloaded.

        This method is only called once per run.
        """
        ETCI2021(self.root_dir, checksum=False)

    def setup(self, stage: Optional[str] = None) -> None:
        """Initialize the main ``Dataset`` objects.

        This method is called once per GPU per run.

        Args:
            stage: stage to set up
        """
        train_val_dataset = ETCI2021(
            self.root_dir, split="train", transforms=self.preprocess
        )
        # The dataset's own "val" split serves as the test data.
        self.test_dataset = ETCI2021(
            self.root_dir, split="val", transforms=self.preprocess
        )

        total = len(train_val_dataset)
        num_train = int(0.8 * total)
        lengths = [num_train, total - num_train]

        # A seeded generator keeps the train/val partition reproducible.
        rng = Generator().manual_seed(self.seed)
        self.train_dataset, self.val_dataset = random_split(
            train_val_dataset, lengths, generator=rng
        )

    def train_dataloader(self) -> DataLoader[Any]:
        """Return a DataLoader for training.

        Returns:
            training data loader
        """
        return DataLoader(
            self.train_dataset,
            shuffle=True,
            batch_size=self.batch_size,
            num_workers=self.num_workers,
        )

    def val_dataloader(self) -> DataLoader[Any]:
        """Return a DataLoader for validation.

        Returns:
            validation data loader
        """
        return DataLoader(
            self.val_dataset,
            shuffle=False,
            batch_size=self.batch_size,
            num_workers=self.num_workers,
        )

    def test_dataloader(self) -> DataLoader[Any]:
        """Return a DataLoader for testing.

        Returns:
            testing data loader
        """
        return DataLoader(
            self.test_dataset,
            shuffle=False,
            batch_size=self.batch_size,
            num_workers=self.num_workers,
        )
|
|
@ -0,0 +1,148 @@
|
|||
# Copyright (c) Microsoft Corporation. All rights reserved.
|
||||
# Licensed under the MIT License.
|
||||
|
||||
"""EuroSAT datamodule."""
|
||||
|
||||
from typing import Any, Dict, Optional
|
||||
|
||||
import pytorch_lightning as pl
|
||||
import torch
|
||||
from torch.utils.data import DataLoader
|
||||
from torchvision.transforms import Compose, Normalize
|
||||
|
||||
from ..datasets import EuroSAT
|
||||
|
||||
|
||||
class EuroSATDataModule(pl.LightningDataModule):
    """LightningDataModule implementation for the EuroSAT dataset.

    Relies on the "train", "val" and "test" splits defined by the dataset
    itself. Every image is cast to float and normalized with the per-band
    statistics stored on this class.

    .. versionadded:: 0.2
    """

    # Per-band means handed to ``Normalize`` (13 entries).
    band_means = torch.tensor(  # type: ignore[attr-defined]
        [
            1354.40546513,
            1118.24399958,
            1042.92983953,
            947.62620298,
            1199.47283961,
            1999.79090914,
            2369.22292565,
            2296.82608323,
            732.08340178,
            12.11327804,
            1819.01027855,
            1118.92391149,
            2594.14080798,
        ]
    )

    # Per-band standard deviations handed to ``Normalize`` (13 entries).
    band_stds = torch.tensor(  # type: ignore[attr-defined]
        [
            245.71762908,
            333.00778264,
            395.09249139,
            593.75055589,
            566.4170017,
            861.18399006,
            1086.63139075,
            1117.98170791,
            404.91978886,
            4.77584468,
            1002.58768311,
            761.30323499,
            1231.58581042,
        ]
    )

    def __init__(
        self, root_dir: str, batch_size: int = 64, num_workers: int = 0, **kwargs: Any
    ) -> None:
        """Initialize a LightningDataModule for EuroSAT based DataLoaders.

        Args:
            root_dir: The ``root`` argument to pass to the EuroSAT Dataset classes
            batch_size: The batch size to use in all created DataLoaders
            num_workers: The number of workers to use in all created DataLoaders
            **kwargs: Extra keyword arguments; accepted but not used here
        """
        super().__init__()  # type: ignore[no-untyped-call]
        self.norm = Normalize(self.band_means, self.band_stds)
        self.num_workers = num_workers
        self.batch_size = batch_size
        self.root_dir = root_dir

    def preprocess(self, sample: Dict[str, Any]) -> Dict[str, Any]:
        """Cast the sample's image to float and normalize it per band.

        Args:
            sample: input image dictionary

        Returns:
            preprocessed sample (the same dictionary, updated in place)
        """
        sample["image"] = self.norm(sample["image"].float())
        return sample

    def prepare_data(self) -> None:
        """Make sure that the dataset is downloaded.

        Instantiates ``EuroSAT`` once with default arguments and discards it.

        This method is only called once per run.
        """
        EuroSAT(self.root_dir)

    def setup(self, stage: Optional[str] = None) -> None:
        """Initialize the main ``Dataset`` objects.

        This method is called once per GPU per run.

        Args:
            stage: stage to set up (not used by this implementation)
        """
        transforms = Compose([self.preprocess])

        # Datasets are built in train, val, test order.
        self.train_dataset, self.val_dataset, self.test_dataset = [
            EuroSAT(self.root_dir, split, transforms=transforms)
            for split in ("train", "val", "test")
        ]

    def _make_loader(self, dataset: Any, shuffle: bool) -> DataLoader[Any]:
        """Wrap ``dataset`` in a DataLoader using this module's settings."""
        return DataLoader(
            dataset,
            batch_size=self.batch_size,
            num_workers=self.num_workers,
            shuffle=shuffle,
        )

    def train_dataloader(self) -> DataLoader[Any]:
        """Return a DataLoader for training.

        Returns:
            training data loader (shuffled)
        """
        return self._make_loader(self.train_dataset, shuffle=True)

    def val_dataloader(self) -> DataLoader[Any]:
        """Return a DataLoader for validation.

        Returns:
            validation data loader (not shuffled)
        """
        return self._make_loader(self.val_dataset, shuffle=False)

    def test_dataloader(self) -> DataLoader[Any]:
        """Return a DataLoader for testing.

        Returns:
            testing data loader (not shuffled)
        """
        return self._make_loader(self.test_dataset, shuffle=False)
|
|
@ -0,0 +1,132 @@
|
|||
# Copyright (c) Microsoft Corporation. All rights reserved.
|
||||
# Licensed under the MIT License.
|
||||
|
||||
"""FAIR1M datamodule."""
|
||||
|
||||
from typing import Any, Dict, List, Optional
|
||||
|
||||
import pytorch_lightning as pl
|
||||
import torch
|
||||
from torch import Tensor
|
||||
from torch.utils.data import DataLoader
|
||||
from torchvision.transforms import Compose
|
||||
|
||||
from ..datasets import FAIR1M
|
||||
from .utils import dataset_split
|
||||
|
||||
# Point DataLoader.__module__ at the public ``torch.utils.data`` package.
# NOTE(review): presumably a workaround for the PyTorch issue/PR linked
# below (e.g. so documentation tooling can resolve the class) -- confirm.
# https://github.com/pytorch/pytorch/issues/60979
|
||||
# https://github.com/pytorch/pytorch/pull/61045
|
||||
DataLoader.__module__ = "torch.utils.data"
|
||||
|
||||
|
||||
def collate_fn(batch: List[Dict[str, Tensor]]) -> Dict[str, Any]:
    """Collate detection samples whose number of boxes varies per image.

    Images are stacked into a single tensor along a new leading dimension,
    while the per-sample box tensors are kept in a plain list so that they
    do not need to share a shape.

    Args:
        batch: list of sample dicts returned by the dataset; each must hold
            the keys ``"image"`` and ``"boxes"``

    Returns:
        batch dict with a stacked ``"image"`` tensor and a ``"boxes"`` list
    """
    images = torch.stack([item["image"] for item in batch])
    boxes = [item["boxes"] for item in batch]
    return {"image": images, "boxes": boxes}
|
||||
|
||||
|
||||
class FAIR1MDataModule(pl.LightningDataModule):
    """LightningDataModule implementation for the FAIR1M dataset.

    A single ``FAIR1M`` dataset is created and divided into train, validation
    and test subsets with ``dataset_split``. Batches are assembled with the
    module-level ``collate_fn``.
    """

    def __init__(
        self,
        root_dir: str,
        batch_size: int = 64,
        num_workers: int = 0,
        val_split_pct: float = 0.2,
        test_split_pct: float = 0.2,
        **kwargs: Any,
    ) -> None:
        """Initialize a LightningDataModule for FAIR1M based DataLoaders.

        Args:
            root_dir: The ``root`` argument to pass to the FAIR1M Dataset classes
            batch_size: The batch size to use in all created DataLoaders
            num_workers: The number of workers to use in all created DataLoaders
            val_split_pct: What percentage of the dataset to use as a validation
                set (forwarded to ``dataset_split`` as ``val_pct``)
            test_split_pct: What percentage of the dataset to use as a test set
                (forwarded to ``dataset_split`` as ``test_pct``)
            **kwargs: Extra keyword arguments; accepted but not used here
        """
        super().__init__()  # type: ignore[no-untyped-call]
        self.test_split_pct = test_split_pct
        self.val_split_pct = val_split_pct
        self.num_workers = num_workers
        self.batch_size = batch_size
        self.root_dir = root_dir

    def preprocess(self, sample: Dict[str, Any]) -> Dict[str, Any]:
        """Cast the sample's image to float and divide it by 255.

        Args:
            sample: input image dictionary

        Returns:
            preprocessed sample (the same dictionary, updated in place)
        """
        image = sample["image"].float()
        # In-place division on the float tensor, like the augmented assignment
        # it replaces.
        image /= 255.0
        sample["image"] = image
        return sample

    def setup(self, stage: Optional[str] = None) -> None:
        """Initialize the main ``Dataset`` objects.

        This method is called once per GPU per run.

        Args:
            stage: stage to set up (not used by this implementation)
        """
        dataset = FAIR1M(self.root_dir, transforms=Compose([self.preprocess]))
        splits = dataset_split(
            dataset, val_pct=self.val_split_pct, test_pct=self.test_split_pct
        )
        self.train_dataset, self.val_dataset, self.test_dataset = splits

    def _make_loader(self, dataset: Any, shuffle: bool) -> DataLoader[Any]:
        """Wrap ``dataset`` in a DataLoader that batches with ``collate_fn``."""
        return DataLoader(
            dataset,
            batch_size=self.batch_size,
            num_workers=self.num_workers,
            shuffle=shuffle,
            collate_fn=collate_fn,
        )

    def train_dataloader(self) -> DataLoader[Any]:
        """Return a DataLoader for training.

        Returns:
            training data loader (shuffled)
        """
        return self._make_loader(self.train_dataset, shuffle=True)

    def val_dataloader(self) -> DataLoader[Any]:
        """Return a DataLoader for validation.

        Returns:
            validation data loader (not shuffled)
        """
        return self._make_loader(self.val_dataset, shuffle=False)

    def test_dataloader(self) -> DataLoader[Any]:
        """Return a DataLoader for testing.

        Returns:
            testing data loader (not shuffled)
        """
        return self._make_loader(self.test_dataset, shuffle=False)
|
|
@ -0,0 +1,122 @@
|
|||
# Copyright (c) Microsoft Corporation. All rights reserved.
|
||||
# Licensed under the MIT License.
|
||||
|
||||
"""LandCover.ai datamodule."""
|
||||
|
||||
from typing import Any, Dict, Optional
|
||||
|
||||
import pytorch_lightning as pl
|
||||
from torch.utils.data import DataLoader
|
||||
|
||||
from ..datasets import LandCoverAI
|
||||
|
||||
# Point DataLoader.__module__ at the public ``torch.utils.data`` package.
# NOTE(review): presumably a workaround for the PyTorch issue/PR linked
# below (e.g. so documentation tooling can resolve the class) -- confirm.
# https://github.com/pytorch/pytorch/issues/60979
|
||||
# https://github.com/pytorch/pytorch/pull/61045
|
||||
DataLoader.__module__ = "torch.utils.data"
|
||||
|
||||
|
||||
class LandCoverAIDataModule(pl.LightningDataModule):
    """LightningDataModule implementation for the LandCover.ai dataset.

    Uses the "train", "val" and "test" splits from the dataset, applying the
    same :meth:`preprocess` transform to each of them.
    """

    def __init__(
        self, root_dir: str, batch_size: int = 64, num_workers: int = 0, **kwargs: Any
    ) -> None:
        """Initialize a LightningDataModule for LandCover.ai based DataLoaders.

        Args:
            root_dir: The ``root`` argument to pass to the Landcover.AI Dataset classes
            batch_size: The batch size to use in all created DataLoaders
            num_workers: The number of workers to use in all created DataLoaders
            **kwargs: Extra keyword arguments; accepted but not used here
        """
        super().__init__()  # type: ignore[no-untyped-call]
        self.num_workers = num_workers
        self.batch_size = batch_size
        self.root_dir = root_dir

    def preprocess(self, sample: Dict[str, Any]) -> Dict[str, Any]:
        """Scale the image by 1/255 and convert the mask for training.

        Args:
            sample: dictionary containing image and mask

        Returns:
            preprocessed sample (the same dictionary, updated in place)
        """
        sample["image"] = (sample["image"] / 255.0).float()

        # Float mask with a new leading dimension; every value is shifted by +1.
        mask = sample["mask"].float().unsqueeze(0)
        sample["mask"] = mask + 1

        return sample

    def prepare_data(self) -> None:
        """Make sure that the dataset is downloaded.

        Instantiates ``LandCoverAI`` with ``download=False`` and
        ``checksum=False`` and discards the instance.

        This method is only called once per run.
        """
        LandCoverAI(self.root_dir, download=False, checksum=False)

    def setup(self, stage: Optional[str] = None) -> None:
        """Initialize the main ``Dataset`` objects.

        This method is called once per GPU per run.

        Args:
            stage: stage to set up (not used by this implementation)
        """
        # Training and evaluation currently share one transform.
        train_transforms = val_test_transforms = self.preprocess

        self.train_dataset = LandCoverAI(
            self.root_dir, split="train", transforms=train_transforms
        )
        self.val_dataset = LandCoverAI(
            self.root_dir, split="val", transforms=val_test_transforms
        )
        self.test_dataset = LandCoverAI(
            self.root_dir, split="test", transforms=val_test_transforms
        )

    def _make_loader(self, dataset: Any, shuffle: bool) -> DataLoader[Any]:
        """Wrap ``dataset`` in a DataLoader using this module's settings."""
        return DataLoader(
            dataset,
            batch_size=self.batch_size,
            num_workers=self.num_workers,
            shuffle=shuffle,
        )

    def train_dataloader(self) -> DataLoader[Any]:
        """Return a DataLoader for training.

        Returns:
            training data loader (shuffled)
        """
        return self._make_loader(self.train_dataset, shuffle=True)

    def val_dataloader(self) -> DataLoader[Any]:
        """Return a DataLoader for validation.

        Returns:
            validation data loader (not shuffled)
        """
        return self._make_loader(self.val_dataset, shuffle=False)

    def test_dataloader(self) -> DataLoader[Any]:
        """Return a DataLoader for testing.

        Returns:
            testing data loader (not shuffled)
        """
        return self._make_loader(self.test_dataset, shuffle=False)
|
|
@ -0,0 +1,129 @@
|
|||
# Copyright (c) Microsoft Corporation. All rights reserved.
|
||||
# Licensed under the MIT License.
|
||||
|
||||
"""LoveDA datamodule."""
|
||||
|
||||
from typing import Any, Dict, List, Optional
|
||||
|
||||
import pytorch_lightning as pl
|
||||
from torch.utils.data import DataLoader
|
||||
|
||||
from ..datasets import LoveDA
|
||||
|
||||
# Point DataLoader.__module__ at the public ``torch.utils.data`` package.
# NOTE(review): presumably a workaround for the PyTorch issue/PR linked
# below (e.g. so documentation tooling can resolve the class) -- confirm.
# https://github.com/pytorch/pytorch/issues/60979
|
||||
# https://github.com/pytorch/pytorch/pull/61045
|
||||
DataLoader.__module__ = "torch.utils.data"
|
||||
|
||||
|
||||
class LoveDADataModule(pl.LightningDataModule):
    """LightningDataModule implementation for the LoveDA dataset.

    Uses the "train", "val" and "test" splits from the dataset, applying the
    same :meth:`preprocess` transform to each of them.
    """

    def __init__(
        self,
        root_dir: str,
        scene: List[str],
        batch_size: int = 32,
        num_workers: int = 0,
        **kwargs: Any,
    ) -> None:
        """Initialize a LightningDataModule for LoveDA based DataLoaders.

        Args:
            root_dir: The ``root`` argument to pass to LoveDA Dataset classes
            scene: specify whether to load only 'urban', only 'rural' or both
            batch_size: The batch size to use in all created DataLoaders
            num_workers: The number of workers to use in all created DataLoaders
            **kwargs: Extra keyword arguments; accepted but not used here
        """
        super().__init__()  # type: ignore[no-untyped-call]
        self.num_workers = num_workers
        self.batch_size = batch_size
        self.scene = scene
        self.root_dir = root_dir

    def preprocess(self, sample: Dict[str, Any]) -> Dict[str, Any]:
        """Divide the sample's image by 255.

        Args:
            sample: dictionary containing image and mask

        Returns:
            preprocessed sample (the same dictionary, updated in place)
        """
        scaled = sample["image"] / 255.0
        sample["image"] = scaled

        return sample

    def prepare_data(self) -> None:
        """Make sure that the dataset is downloaded.

        Instantiates ``LoveDA`` with ``download=False`` and ``checksum=False``
        and discards the instance.

        This method is only called once per run.
        """
        LoveDA(self.root_dir, scene=self.scene, download=False, checksum=False)

    def setup(self, stage: Optional[str] = None) -> None:
        """Initialize the main ``Dataset`` objects.

        This method is called once per GPU per run.

        Args:
            stage: stage to set up (not used by this implementation)
        """
        # Training and evaluation currently share one transform.
        train_transforms = val_test_transforms = self.preprocess

        self.train_dataset = LoveDA(
            self.root_dir,
            split="train",
            scene=self.scene,
            transforms=train_transforms,
        )
        self.val_dataset = LoveDA(
            self.root_dir,
            split="val",
            scene=self.scene,
            transforms=val_test_transforms,
        )
        self.test_dataset = LoveDA(
            self.root_dir,
            split="test",
            scene=self.scene,
            transforms=val_test_transforms,
        )

    def _make_loader(self, dataset: Any, shuffle: bool) -> DataLoader[Any]:
        """Wrap ``dataset`` in a DataLoader using this module's settings."""
        return DataLoader(
            dataset,
            batch_size=self.batch_size,
            num_workers=self.num_workers,
            shuffle=shuffle,
        )

    def train_dataloader(self) -> DataLoader[Any]:
        """Return a DataLoader for training.

        Returns:
            training data loader (shuffled)
        """
        return self._make_loader(self.train_dataset, shuffle=True)

    def val_dataloader(self) -> DataLoader[Any]:
        """Return a DataLoader for validation.

        Returns:
            validation data loader (not shuffled)
        """
        return self._make_loader(self.val_dataset, shuffle=False)

    def test_dataloader(self) -> DataLoader[Any]:
        """Return a DataLoader for testing.

        Returns:
            testing data loader (not shuffled)
        """
        return self._make_loader(self.test_dataset, shuffle=False)
|
|
@ -0,0 +1,161 @@
|
|||
# Copyright (c) Microsoft Corporation. All rights reserved.
|
||||
# Licensed under the MIT License.
|
||||
|
||||
"""National Agriculture Imagery Program (NAIP) datamodule."""
|
||||
|
||||
from typing import Any, Dict, Optional
|
||||
|
||||
import pytorch_lightning as pl
|
||||
from torch.utils.data import DataLoader
|
||||
|
||||
from ..datasets import NAIP, BoundingBox, Chesapeake13, stack_samples
|
||||
from ..samplers.batch import RandomBatchGeoSampler
|
||||
from ..samplers.single import GridGeoSampler
|
||||
|
||||
# Point DataLoader.__module__ at the public ``torch.utils.data`` package.
# NOTE(review): presumably a workaround for the PyTorch issue/PR linked
# below (e.g. so documentation tooling can resolve the class) -- confirm.
# https://github.com/pytorch/pytorch/issues/60979
|
||||
# https://github.com/pytorch/pytorch/pull/61045
|
||||
DataLoader.__module__ = "torch.utils.data"
|
||||
|
||||
|
||||
class NAIPChesapeakeDataModule(pl.LightningDataModule):
    """LightningDataModule implementation for the NAIP and Chesapeake datasets.

    NAIP imagery and Chesapeake13 masks are combined with the ``&`` operator,
    and the bounds of the combined dataset are divided spatially into
    non-overlapping train/val/test regions (see :meth:`setup`).
    """

    # TODO: tune these hyperparams
    length = 1000  # handed to RandomBatchGeoSampler for the training sampler
    stride = 128  # handed to GridGeoSampler for the val/test samplers

    def __init__(
        self,
        naip_root_dir: str,
        chesapeake_root_dir: str,
        batch_size: int = 64,
        num_workers: int = 0,
        patch_size: int = 256,
        **kwargs: Any,
    ) -> None:
        """Initialize a LightningDataModule for NAIP and Chesapeake based DataLoaders.

        Args:
            naip_root_dir: directory containing NAIP data
            chesapeake_root_dir: directory containing Chesapeake data
            batch_size: The batch size to use in all created DataLoaders
            num_workers: The number of workers to use in all created DataLoaders
            patch_size: size of patches to sample
            **kwargs: Extra keyword arguments; accepted but not used here
        """
        super().__init__()  # type: ignore[no-untyped-call]
        self.naip_root_dir = naip_root_dir
        self.chesapeake_root_dir = chesapeake_root_dir
        self.batch_size = batch_size
        self.num_workers = num_workers
        self.patch_size = patch_size

    def naip_transform(self, sample: Dict[str, Any]) -> Dict[str, Any]:
        """Transform a single sample from the NAIP Dataset.

        Divides the image by 255 and casts it to float.

        Args:
            sample: NAIP image dictionary

        Returns:
            preprocessed NAIP data
        """
        sample["image"] = sample["image"] / 255.0
        sample["image"] = sample["image"].float()
        return sample

    def chesapeake_transform(self, sample: Dict[str, Any]) -> Dict[str, Any]:
        """Transform a single sample from the Chesapeake Dataset.

        Casts the mask to long and keeps only index 0 of its first dimension.

        Args:
            sample: Chesapeake mask dictionary

        Returns:
            preprocessed Chesapeake data
        """
        sample["mask"] = sample["mask"].long()[0]
        return sample

    def prepare_data(self) -> None:
        """Make sure that the dataset is downloaded.

        Only the Chesapeake13 dataset is instantiated here (with
        ``download=False`` and ``checksum=False``); NAIP is not touched.

        This method is only called once per run.
        """
        Chesapeake13(self.chesapeake_root_dir, download=False, checksum=False)

    def setup(self, stage: Optional[str] = None) -> None:
        """Initialize the main ``Dataset`` objects.

        The combined dataset's bounds are split at their x/y midpoints:

        * train: left half (all y)
        * val: lower-right quadrant
        * test: upper-right quadrant

        This method is called once per GPU per run.

        Args:
            stage: stage to set up (not used by this implementation)
        """
        # TODO: these transforms will be applied independently, this won't work if we
        # add things like random horizontal flip
        chesapeake = Chesapeake13(
            self.chesapeake_root_dir, transforms=self.chesapeake_transform
        )
        # NAIP is created with the CRS and resolution of the Chesapeake dataset.
        naip = NAIP(
            self.naip_root_dir,
            chesapeake.crs,
            chesapeake.res,
            transforms=self.naip_transform,
        )
        self.dataset = chesapeake & naip

        # TODO: figure out better train/val/test split
        roi = self.dataset.bounds
        midx = roi.minx + (roi.maxx - roi.minx) / 2
        midy = roi.miny + (roi.maxy - roi.miny) / 2
        train_roi = BoundingBox(roi.minx, midx, roi.miny, roi.maxy, roi.mint, roi.maxt)
        val_roi = BoundingBox(midx, roi.maxx, roi.miny, midy, roi.mint, roi.maxt)
        # The test region starts at midx rather than roi.minx: spanning the
        # full x-range would overlap the training region in the upper-left
        # quadrant and leak training area into the test set.
        test_roi = BoundingBox(midx, roi.maxx, midy, roi.maxy, roi.mint, roi.maxt)

        self.train_sampler = RandomBatchGeoSampler(
            naip, self.patch_size, self.batch_size, self.length, train_roi
        )
        self.val_sampler = GridGeoSampler(naip, self.patch_size, self.stride, val_roi)
        self.test_sampler = GridGeoSampler(naip, self.patch_size, self.stride, test_roi)

    def train_dataloader(self) -> DataLoader[Any]:
        """Return a DataLoader for training.

        Batching is delegated to ``self.train_sampler`` (a batch sampler), so
        no ``batch_size`` is passed to the DataLoader.

        Returns:
            training data loader
        """
        return DataLoader(
            self.dataset,
            batch_sampler=self.train_sampler,
            num_workers=self.num_workers,
            collate_fn=stack_samples,
        )

    def val_dataloader(self) -> DataLoader[Any]:
        """Return a DataLoader for validation.

        Returns:
            validation data loader
        """
        return DataLoader(
            self.dataset,
            batch_size=self.batch_size,
            sampler=self.val_sampler,
            num_workers=self.num_workers,
            collate_fn=stack_samples,
        )

    def test_dataloader(self) -> DataLoader[Any]:
        """Return a DataLoader for testing.

        Returns:
            testing data loader
        """
        return DataLoader(
            self.dataset,
            batch_size=self.batch_size,
            sampler=self.test_sampler,
            num_workers=self.num_workers,
            collate_fn=stack_samples,
        )
|
|
@ -0,0 +1,140 @@
|
|||
# Copyright (c) Microsoft Corporation. All rights reserved.
|
||||
# Licensed under the MIT License.
|
||||
|
||||
"""NASA Marine Debris datamodule."""
|
||||
|
||||
from typing import Any, Dict, List, Optional
|
||||
|
||||
import pytorch_lightning as pl
|
||||
import torch
|
||||
from torch import Tensor
|
||||
from torch.utils.data import DataLoader
|
||||
from torchvision.transforms import Compose
|
||||
|
||||
from ..datasets import NASAMarineDebris
|
||||
from .utils import dataset_split
|
||||
|
||||
# Point DataLoader.__module__ at the public ``torch.utils.data`` package.
# NOTE(review): presumably a workaround for the PyTorch issue/PR linked
# below (e.g. so documentation tooling can resolve the class) -- confirm.
# https://github.com/pytorch/pytorch/issues/60979
|
||||
# https://github.com/pytorch/pytorch/pull/61045
|
||||
DataLoader.__module__ = "torch.utils.data"
|
||||
|
||||
|
||||
def collate_fn(batch: List[Dict[str, Tensor]]) -> Dict[str, Any]:
    """Collate detection samples whose number of boxes varies per image.

    Images are stacked into a single tensor along a new leading dimension,
    while the per-sample box tensors are kept in a plain list so that they
    do not need to share a shape.

    Args:
        batch: list of sample dicts returned by the dataset; each must hold
            the keys ``"image"`` and ``"boxes"``

    Returns:
        batch dict with a stacked ``"image"`` tensor and a ``"boxes"`` list
    """
    images = torch.stack([item["image"] for item in batch])
    boxes = [item["boxes"] for item in batch]
    return {"image": images, "boxes": boxes}
|
||||
|
||||
|
||||
class NASAMarineDebrisDataModule(pl.LightningDataModule):
    """LightningDataModule implementation for the NASA Marine Debris dataset.

    A single ``NASAMarineDebris`` dataset is created and divided into train,
    validation and test subsets with ``dataset_split``. Batches are assembled
    with the module-level ``collate_fn``.
    """

    def __init__(
        self,
        root_dir: str,
        batch_size: int = 64,
        num_workers: int = 0,
        val_split_pct: float = 0.2,
        test_split_pct: float = 0.2,
        **kwargs: Any,
    ) -> None:
        """Initialize a LightningDataModule for NASA Marine Debris based DataLoaders.

        Args:
            root_dir: The ``root`` argument to pass to the Dataset class
            batch_size: The batch size to use in all created DataLoaders
            num_workers: The number of workers to use in all created DataLoaders
            val_split_pct: What percentage of the dataset to use as a validation
                set (forwarded to ``dataset_split`` as ``val_pct``)
            test_split_pct: What percentage of the dataset to use as a test set
                (forwarded to ``dataset_split`` as ``test_pct``)
            **kwargs: Extra keyword arguments; accepted but not used here
        """
        super().__init__()  # type: ignore[no-untyped-call]
        self.test_split_pct = test_split_pct
        self.val_split_pct = val_split_pct
        self.num_workers = num_workers
        self.batch_size = batch_size
        self.root_dir = root_dir

    def preprocess(self, sample: Dict[str, Any]) -> Dict[str, Any]:
        """Cast the sample's image to float and divide it by 255.

        Args:
            sample: input image dictionary

        Returns:
            preprocessed sample (the same dictionary, updated in place)
        """
        image = sample["image"].float()
        # In-place division on the float tensor, like the augmented assignment
        # it replaces.
        image /= 255.0
        sample["image"] = image
        return sample

    def prepare_data(self) -> None:
        """Make sure that the dataset is downloaded.

        Instantiates ``NASAMarineDebris`` with ``checksum=False`` and discards
        the instance.

        This method is only called once per run.
        """
        NASAMarineDebris(self.root_dir, checksum=False)

    def setup(self, stage: Optional[str] = None) -> None:
        """Initialize the main ``Dataset`` objects.

        This method is called once per GPU per run.

        Args:
            stage: stage to set up (not used by this implementation)
        """
        dataset = NASAMarineDebris(
            self.root_dir, transforms=Compose([self.preprocess])
        )
        splits = dataset_split(
            dataset, val_pct=self.val_split_pct, test_pct=self.test_split_pct
        )
        self.train_dataset, self.val_dataset, self.test_dataset = splits

    def _make_loader(self, dataset: Any, shuffle: bool) -> DataLoader[Any]:
        """Wrap ``dataset`` in a DataLoader that batches with ``collate_fn``."""
        return DataLoader(
            dataset,
            batch_size=self.batch_size,
            num_workers=self.num_workers,
            shuffle=shuffle,
            collate_fn=collate_fn,
        )

    def train_dataloader(self) -> DataLoader[Any]:
        """Return a DataLoader for training.

        Returns:
            training data loader (shuffled)
        """
        return self._make_loader(self.train_dataset, shuffle=True)

    def val_dataloader(self) -> DataLoader[Any]:
        """Return a DataLoader for validation.

        Returns:
            validation data loader (not shuffled)
        """
        return self._make_loader(self.val_dataset, shuffle=False)

    def test_dataloader(self) -> DataLoader[Any]:
        """Return a DataLoader for testing.

        Returns:
            testing data loader (not shuffled)
        """
        return self._make_loader(self.test_dataset, shuffle=False)
|
|
@ -0,0 +1,214 @@
|
|||
# Copyright (c) Microsoft Corporation. All rights reserved.
|
||||
# Licensed under the MIT License.
|
||||
|
||||
"""OSCD datamodule."""
|
||||
|
||||
from typing import Any, Dict, List, Optional, Tuple
|
||||
|
||||
import kornia.augmentation as K
|
||||
import pytorch_lightning as pl
|
||||
import torch
|
||||
from einops import repeat
|
||||
from torch.utils.data import DataLoader, Dataset
|
||||
from torch.utils.data._utils.collate import default_collate
|
||||
from torchvision.transforms import Compose, Normalize
|
||||
|
||||
from ..datasets import OSCD
|
||||
from .utils import dataset_split
|
||||
|
||||
|
||||
class OSCDDataModule(pl.LightningDataModule):
|
||||
"""LightningDataModule implementation for the OSCD dataset.
|
||||
|
||||
Uses the train/test splits from the dataset and further splits
|
||||
the train split into train/val splits.
|
||||
|
||||
.. versionadded: 0.2
|
||||
"""
|
||||
|
||||
band_means = torch.tensor( # type: ignore[attr-defined]
|
||||
[
|
||||
1583.0741,
|
||||
1374.3202,
|
||||
1294.1616,
|
||||
1325.6158,
|
||||
1478.7408,
|
||||
1933.0822,
|
||||
2166.0608,
|
||||
2076.4868,
|
||||
2306.0652,
|
||||
690.9814,
|
||||
16.2360,
|
||||
2080.3347,
|
||||
1524.6930,
|
||||
]
|
||||
)
|
||||
|
||||
band_stds = torch.tensor( # type: ignore[attr-defined]
|
||||
[
|
||||
52.1937,
|
||||
83.4168,
|
||||
105.6966,
|
||||
151.1401,
|
||||
147.4615,
|
||||
115.9289,
|
||||
123.1974,
|
||||
114.6483,
|
||||
141.4530,
|
||||
73.2758,
|
||||
4.8368,
|
||||
213.4821,
|
||||
179.4793,
|
||||
]
|
||||
)
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
root_dir: str,
|
||||
bands: str = "all",
|
||||
train_batch_size: int = 32,
|
||||
num_workers: int = 0,
|
||||
val_split_pct: float = 0.2,
|
||||
patch_size: Tuple[int, int] = (64, 64),
|
||||
num_patches_per_tile: int = 32,
|
||||
**kwargs: Any,
|
||||
) -> None:
|
||||
"""Initialize a LightningDataModule for OSCD based DataLoaders.
|
||||
|
||||
Args:
|
||||
root_dir: The ``root`` arugment to pass to the OSCD Dataset classes
|
||||
bands: "rgb" or "all"
|
||||
train_batch_size: The batch size used in the train DataLoader
|
||||
(val_batch_size == test_batch_size == 1)
|
||||
num_workers: The number of workers to use in all created DataLoaders
|
||||
val_split_pct: What percentage of the dataset to use as a validation set
|
||||
patch_size: Size of random patch from image and mask (height, width)
|
||||
num_patches_per_tile: number of random patches per sample
|
||||
"""
|
||||
super().__init__() # type: ignore[no-untyped-call]
|
||||
self.root_dir = root_dir
|
||||
self.bands = bands
|
||||
self.train_batch_size = train_batch_size
|
||||
self.num_workers = num_workers
|
||||
self.val_split_pct = val_split_pct
|
||||
self.patch_size = patch_size
|
||||
self.num_patches_per_tile = num_patches_per_tile
|
||||
|
||||
if bands == "rgb":
|
||||
self.band_means = self.band_means[[3, 2, 1], None, None]
|
||||
self.band_stds = self.band_stds[[3, 2, 1], None, None]
|
||||
else:
|
||||
self.band_means = self.band_means[:, None, None]
|
||||
self.band_stds = self.band_stds[:, None, None]
|
||||
|
||||
self.norm = Normalize(self.band_means, self.band_stds)
|
||||
self.rcrop = K.AugmentationSequential(
|
||||
K.RandomCrop(patch_size), data_keys=["input", "mask"], same_on_batch=True
|
||||
)
|
||||
self.padto = K.PadTo((1280, 1280))
|
||||
|
||||
def preprocess(self, sample: Dict[str, Any]) -> Dict[str, Any]:
|
||||
"""Transform a single sample from the Dataset."""
|
||||
sample["image"] = sample["image"].float()
|
||||
sample["mask"] = sample["mask"]
|
||||
sample["image"] = self.norm(sample["image"])
|
||||
sample["image"] = torch.flatten( # type: ignore[attr-defined]
|
||||
sample["image"], 0, 1
|
||||
)
|
||||
return sample
|
||||
|
||||
def prepare_data(self) -> None:
|
||||
"""Make sure that the dataset is downloaded.
|
||||
|
||||
This method is only called once per run.
|
||||
"""
|
||||
OSCD(self.root_dir, split="train", bands=self.bands, checksum=False)
|
||||
|
||||
def setup(self, stage: Optional[str] = None) -> None:
|
||||
"""Initialize the main ``Dataset`` objects.
|
||||
|
||||
This method is called once per GPU per run.
|
||||
"""
|
||||
|
||||
def n_random_crop(sample: Dict[str, Any]) -> Dict[str, Any]:
|
||||
images, masks = [], []
|
||||
for i in range(self.num_patches_per_tile):
|
||||
mask = repeat(sample["mask"], "h w -> t h w", t=2).float()
|
||||
image, mask = self.rcrop(sample["image"], mask)
|
||||
mask = mask.squeeze()[0]
|
||||
images.append(image.squeeze())
|
||||
masks.append(mask.long())
|
||||
sample["image"] = torch.stack(images)
|
||||
sample["mask"] = torch.stack(masks)
|
||||
return sample
|
||||
|
||||
def pad_to(sample: Dict[str, Any]) -> Dict[str, Any]:
|
||||
sample["image"] = self.padto(sample["image"])[0]
|
||||
sample["mask"] = self.padto(sample["mask"].float()).long()[0, 0]
|
||||
return sample
|
||||
|
||||
train_transforms = Compose([self.preprocess, n_random_crop])
|
||||
# for testing and validation we pad all inputs to a fixed size to avoid issues
|
||||
# with the upsampling paths in encoder-decoder architectures
|
||||
test_transforms = Compose([self.preprocess, pad_to])
|
||||
|
||||
train_dataset = OSCD(
|
||||
self.root_dir, split="train", bands=self.bands, transforms=train_transforms
|
||||
)
|
||||
|
||||
self.train_dataset: Dataset[Any]
|
||||
self.val_dataset: Dataset[Any]
|
||||
|
||||
if self.val_split_pct > 0.0:
|
||||
val_dataset = OSCD(
|
||||
self.root_dir,
|
||||
split="train",
|
||||
bands=self.bands,
|
||||
transforms=test_transforms,
|
||||
)
|
||||
self.train_dataset, self.val_dataset, _ = dataset_split(
|
||||
train_dataset, val_pct=self.val_split_pct, test_pct=0.0
|
||||
)
|
||||
self.val_dataset.dataset = val_dataset
|
||||
else:
|
||||
self.train_dataset = train_dataset
|
||||
self.val_dataset = train_dataset
|
||||
|
||||
self.test_dataset = OSCD(
|
||||
self.root_dir, split="test", bands=self.bands, transforms=test_transforms
|
||||
)
|
||||
|
||||
def train_dataloader(self) -> DataLoader[Any]:
    """Build the shuffled DataLoader used for training.

    Returns:
        training data loader; in every batch the first two dimensions of the
        ``"image"`` and ``"mask"`` entries are merged into a single one
    """

    def collate_wrapper(batch: List[Dict[str, Any]]) -> Dict[str, Any]:
        # Apply the default collation first, then fold dimension 1 into
        # dimension 0 for both tensors.
        collated: Dict[str, Any] = default_collate(  # type: ignore[no-untyped-call]
            batch
        )
        for key in ("image", "mask"):
            collated[key] = torch.flatten(  # type: ignore[attr-defined]
                collated[key], 0, 1
            )
        return collated

    loader: DataLoader[Any] = DataLoader(
        self.train_dataset,
        shuffle=True,
        collate_fn=collate_wrapper,
        num_workers=self.num_workers,
        batch_size=self.train_batch_size,
    )
    return loader
|
||||
|
||||
def val_dataloader(self) -> DataLoader[Any]:
    """Build the DataLoader used for validation.

    Returns:
        validation data loader that yields one sample per batch, unshuffled
    """
    loader: DataLoader[Any] = DataLoader(
        self.val_dataset,
        shuffle=False,
        num_workers=self.num_workers,
        batch_size=1,
    )
    return loader
|
||||
|
||||
def test_dataloader(self) -> DataLoader[Any]:
|
||||
"""Return a DataLoader for testing."""
|
||||
return DataLoader(
|
||||
self.test_dataset, batch_size=1, num_workers=self.num_workers, shuffle=False
|
||||
)
|
|
@ -0,0 +1,121 @@
|
|||
# Copyright (c) Microsoft Corporation. All rights reserved.
|
||||
# Licensed under the MIT License.
|
||||
|
||||
"""Potsdam datamodule."""
|
||||
|
||||
from typing import Any, Dict, Optional
|
||||
|
||||
import pytorch_lightning as pl
|
||||
from torch.utils.data import DataLoader, Dataset
|
||||
from torchvision.transforms import Compose
|
||||
|
||||
from ..datasets import Potsdam2D
|
||||
from .utils import dataset_split
|
||||
|
||||
|
||||
class Potsdam2DDataModule(pl.LightningDataModule):
    """LightningDataModule implementation for the Potsdam2D dataset.

    Uses the train/test splits from the dataset. The validation set is held
    out from the ``"train"`` split according to ``val_split_pct``.

    .. versionadded:: 0.2
    """

    def __init__(
        self,
        root_dir: str,
        batch_size: int = 64,
        num_workers: int = 0,
        val_split_pct: float = 0.2,
        **kwargs: Any,
    ) -> None:
        """Initialize a LightningDataModule for Potsdam2D based DataLoaders.

        Args:
            root_dir: The ``root`` argument to pass to the Potsdam2D Dataset classes
            batch_size: The batch size to use in all created DataLoaders
            num_workers: The number of workers to use in all created DataLoaders
            val_split_pct: Fraction of the train split to hold out as a
                validation set (forwarded to ``dataset_split`` as ``val_pct``)
            **kwargs: Additional keyword arguments; accepted but not used here
        """
        super().__init__()  # type: ignore[no-untyped-call]
        self.root_dir = root_dir
        self.batch_size = batch_size
        self.num_workers = num_workers
        self.val_split_pct = val_split_pct

    def preprocess(self, sample: Dict[str, Any]) -> Dict[str, Any]:
        """Transform a single sample from the Dataset.

        Casts the image to float and divides it by 255. The ``sample``
        dictionary is updated in place and also returned.

        Args:
            sample: input image dictionary

        Returns:
            preprocessed sample
        """
        sample["image"] = sample["image"].float()
        sample["image"] /= 255.0
        return sample

    def setup(self, stage: Optional[str] = None) -> None:
        """Initialize the main ``Dataset`` objects.

        This method is called once per GPU per run.

        Args:
            stage: stage to set up
        """
        transforms = Compose([self.preprocess])

        dataset = Potsdam2D(self.root_dir, "train", transforms=transforms)

        self.train_dataset: Dataset[Any]
        self.val_dataset: Dataset[Any]

        if self.val_split_pct > 0.0:
            # ``test_pct=0.0`` makes ``dataset_split`` return three subsets;
            # the third (empty) one is discarded.
            self.train_dataset, self.val_dataset, _ = dataset_split(
                dataset, val_pct=self.val_split_pct, test_pct=0.0
            )
        else:
            # No hold-out requested: validation reuses the full train split.
            self.train_dataset = dataset
            self.val_dataset = dataset

        self.test_dataset = Potsdam2D(self.root_dir, "test", transforms=transforms)

    def train_dataloader(self) -> DataLoader[Any]:
        """Return a DataLoader for training.

        Returns:
            training data loader
        """
        return DataLoader(
            self.train_dataset,
            batch_size=self.batch_size,
            num_workers=self.num_workers,
            shuffle=True,
        )

    def val_dataloader(self) -> DataLoader[Any]:
        """Return a DataLoader for validation.

        Returns:
            validation data loader
        """
        return DataLoader(
            self.val_dataset,
            batch_size=self.batch_size,
            num_workers=self.num_workers,
            shuffle=False,
        )

    def test_dataloader(self) -> DataLoader[Any]:
        """Return a DataLoader for testing.

        Returns:
            testing data loader
        """
        return DataLoader(
            self.test_dataset,
            batch_size=self.batch_size,
            num_workers=self.num_workers,
            shuffle=False,
        )
|
|
@ -0,0 +1,123 @@
|
|||
# Copyright (c) Microsoft Corporation. All rights reserved.
|
||||
# Licensed under the MIT License.
|
||||
|
||||
"""RESISC45 datamodule."""
|
||||
|
||||
from typing import Any, Dict, Optional
|
||||
|
||||
import pytorch_lightning as pl
|
||||
import torch
|
||||
from torch.utils.data import DataLoader
|
||||
from torchvision.transforms import Compose, Normalize
|
||||
|
||||
from ..datasets import RESISC45
|
||||
|
||||
# https://github.com/pytorch/pytorch/issues/60979
|
||||
# https://github.com/pytorch/pytorch/pull/61045
|
||||
DataLoader.__module__ = "torch.utils.data"
|
||||
|
||||
|
||||
class RESISC45DataModule(pl.LightningDataModule):
    """LightningDataModule wrapping the RESISC45 dataset.

    Relies on the train/val/test splits defined by the dataset.
    """

    # Per-channel statistics handed to ``Normalize`` in ``__init__``.
    band_means = torch.tensor(  # type: ignore[attr-defined]
        [0.36801773, 0.38097873, 0.343583]
    )

    band_stds = torch.tensor(  # type: ignore[attr-defined]
        [0.14540215, 0.13558227, 0.13203649]
    )

    def __init__(
        self, root_dir: str, batch_size: int = 64, num_workers: int = 0, **kwargs: Any
    ) -> None:
        """Set up a LightningDataModule producing RESISC45 DataLoaders.

        Args:
            root_dir: The ``root`` argument to pass to the RESISC45 Dataset classes
            batch_size: The batch size to use in all created DataLoaders
            num_workers: The number of workers to use in all created DataLoaders
            **kwargs: Additional keyword arguments; accepted but not used here
        """
        super().__init__()  # type: ignore[no-untyped-call]
        self.root_dir = root_dir
        self.batch_size = batch_size
        self.num_workers = num_workers

        self.norm = Normalize(self.band_means, self.band_stds)

    def preprocess(self, sample: Dict[str, Any]) -> Dict[str, Any]:
        """Scale and normalize the image of a single sample.

        Args:
            sample: input image dictionary

        Returns:
            the same dictionary with ``"image"`` cast to float, divided by
            255 and passed through ``self.norm``
        """
        scaled = sample["image"].float()
        sample["image"] = scaled
        scaled /= 255.0
        sample["image"] = self.norm(scaled)
        return sample

    def prepare_data(self) -> None:
        """Instantiate the dataset once so that it is available on disk.

        This method is only called once per run.
        """
        RESISC45(self.root_dir, checksum=False)

    def setup(self, stage: Optional[str] = None) -> None:
        """Create the train, validation and test ``Dataset`` objects.

        This method is called once per GPU per run.

        Args:
            stage: stage to set up
        """
        pipeline = Compose([self.preprocess])

        self.train_dataset = RESISC45(self.root_dir, "train", transforms=pipeline)
        self.val_dataset = RESISC45(self.root_dir, "val", transforms=pipeline)
        self.test_dataset = RESISC45(self.root_dir, "test", transforms=pipeline)

    def train_dataloader(self) -> DataLoader[Any]:
        """Build the shuffled DataLoader used for training.

        Returns:
            training data loader
        """
        loader: DataLoader[Any] = DataLoader(
            self.train_dataset,
            shuffle=True,
            num_workers=self.num_workers,
            batch_size=self.batch_size,
        )
        return loader

    def val_dataloader(self) -> DataLoader[Any]:
        """Build the unshuffled DataLoader used for validation.

        Returns:
            validation data loader
        """
        loader: DataLoader[Any] = DataLoader(
            self.val_dataset,
            shuffle=False,
            num_workers=self.num_workers,
            batch_size=self.batch_size,
        )
        return loader

    def test_dataloader(self) -> DataLoader[Any]:
        """Build the unshuffled DataLoader used for testing.

        Returns:
            testing data loader
        """
        loader: DataLoader[Any] = DataLoader(
            self.test_dataset,
            shuffle=False,
            num_workers=self.num_workers,
            batch_size=self.batch_size,
        )
        return loader
|
|
@ -0,0 +1,202 @@
|
|||
# Copyright (c) Microsoft Corporation. All rights reserved.
|
||||
# Licensed under the MIT License.
|
||||
|
||||
"""SEN12MS datamodule."""
|
||||
|
||||
from typing import Any, Dict, Optional
|
||||
|
||||
import pytorch_lightning as pl
|
||||
import torch
|
||||
from sklearn.model_selection import GroupShuffleSplit
|
||||
from torch.utils.data import DataLoader, Subset
|
||||
|
||||
from ..datasets import SEN12MS
|
||||
|
||||
# https://github.com/pytorch/pytorch/issues/60979
|
||||
# https://github.com/pytorch/pytorch/pull/61045
|
||||
DataLoader.__module__ = "torch.utils.data"
|
||||
|
||||
|
||||
class SEN12MSDataModule(pl.LightningDataModule):
    """LightningDataModule implementation for the SEN12MS dataset.

    Implements 80/20 geographic train/val splits and uses the test split from the
    classification dataset definitions. See :func:`setup` for more details.

    Uses the Simplified IGBP scheme defined in the 2020 Data Fusion Competition. See
    https://arxiv.org/abs/2002.08254.
    """

    #: Mapping from the IGBP class definitions to the DFC2020, taken from the dataloader
    #: here https://github.com/lukasliebel/dfc2020_baseline.
    DFC2020_CLASS_MAPPING = torch.tensor(  # type: ignore[attr-defined]
        [
            0,  # maps 0s to 0
            1,  # maps 1s to 1
            1,  # maps 2s to 1
            1,  # ...
            1,
            1,
            2,
            2,
            3,
            3,
            4,
            5,
            6,
            7,
            6,
            8,
            9,
            10,
        ]
    )

    def __init__(
        self,
        root_dir: str,
        seed: int,
        band_set: str = "all",
        batch_size: int = 64,
        num_workers: int = 0,
        **kwargs: Any,
    ) -> None:
        """Initialize a LightningDataModule for SEN12MS based DataLoaders.

        Args:
            root_dir: The ``root`` argument to pass to the SEN12MS Dataset classes
            seed: The seed value to use when doing the sklearn based ShuffleSplit
            band_set: The subset of S1/S2 bands to use. Options are: "all",
                "s1", "s2-all", and "s2-reduced" where the "s2-reduced" set includes:
                B2, B3, B4, B8, B11, and B12.
            batch_size: The batch size to use in all created DataLoaders
            num_workers: The number of workers to use in all created DataLoaders
            **kwargs: Additional keyword arguments; accepted but not used here

        Raises:
            AssertionError: if ``band_set`` is not a key of ``SEN12MS.BAND_SETS``
        """
        super().__init__()  # type: ignore[no-untyped-call]
        assert (
            band_set in SEN12MS.BAND_SETS
        ), f"Unknown band_set '{band_set}'"

        self.root_dir = root_dir
        self.seed = seed
        self.band_set = band_set
        self.band_indices = SEN12MS.BAND_SETS[band_set]
        self.batch_size = batch_size
        self.num_workers = num_workers

    def custom_transform(self, sample: Dict[str, Any]) -> Dict[str, Any]:
        """Transform a single sample from the Dataset.

        The image is cast to float and rescaled depending on ``band_set``:
        with "all" or "s1" the first two channels are clamped to [-25, 0] and
        divided by -25; every other channel (all channels for the remaining
        band sets) is clamped to [0, 10000] and divided by 10000. The first
        channel of the mask is remapped through ``DFC2020_CLASS_MAPPING``.

        Args:
            sample: dictionary containing image and mask

        Returns:
            preprocessed sample
        """
        sample["image"] = sample["image"].float()

        if self.band_set == "all":
            # NOTE(review): the first two channels are presumably the S1
            # bands, the rest S2 - confirm against SEN12MS.BAND_SETS.
            sample["image"][:2] = sample["image"][:2].clamp(-25, 0) / -25
            sample["image"][2:] = sample["image"][2:].clamp(0, 10000) / 10000
        elif self.band_set == "s1":
            sample["image"][:2] = sample["image"][:2].clamp(-25, 0) / -25
        else:
            sample["image"][:] = sample["image"][:].clamp(0, 10000) / 10000

        # Keep only the first mask channel, then translate each label through
        # the lookup table.
        sample["mask"] = sample["mask"][0, :, :].long()
        sample["mask"] = torch.take(  # type: ignore[attr-defined]
            self.DFC2020_CLASS_MAPPING, sample["mask"]
        )

        return sample

    def setup(self, stage: Optional[str] = None) -> None:
        """Create the train/val/test splits based on the original Dataset objects.

        The splits should be done here vs. in :func:`__init__` per the docs:
        https://pytorch-lightning.readthedocs.io/en/latest/extensions/datamodules.html#setup.

        We split samples between train and val geographically with proportions of 80/20.
        This mimics the geographic test set split.

        Args:
            stage: stage to set up
        """
        season_to_int = {"winter": 0, "spring": 1000, "summer": 2000, "fall": 3000}

        self.all_train_dataset = SEN12MS(
            self.root_dir,
            split="train",
            bands=self.band_indices,
            transforms=self.custom_transform,
            checksum=False,
        )

        self.all_test_dataset = SEN12MS(
            self.root_dir,
            split="test",
            bands=self.band_indices,
            transforms=self.custom_transform,
            checksum=False,
        )

        # A patch is a filename like: "ROIs{num}_{season}_s2_{scene_id}_p{patch_id}.tif"
        # This patch will belong to the scene that is uniquely identified by its
        # (season, scene_id) tuple. Because the largest scene_id is 149, we can simply
        # give each season a large number and represent a `unique_scene_id` as
        # `season_id + scene_id`.
        scenes = []
        for scene_fn in self.all_train_dataset.ids:
            parts = scene_fn.split("_")
            season_id = season_to_int[parts[1]]
            scene_id = int(parts[3])
            scenes.append(season_id + scene_id)

        # Only the first generated split is consumed; grouping by scene keeps
        # all patches of one scene on the same side of the split.
        train_indices, val_indices = next(
            GroupShuffleSplit(test_size=0.2, n_splits=2, random_state=self.seed).split(
                scenes, groups=scenes
            )
        )

        self.train_dataset = Subset(self.all_train_dataset, train_indices)
        self.val_dataset = Subset(self.all_train_dataset, val_indices)
        self.test_dataset = Subset(
            self.all_test_dataset, range(len(self.all_test_dataset))
        )

    def train_dataloader(self) -> DataLoader[Any]:
        """Return a DataLoader for training.

        Returns:
            training data loader
        """
        return DataLoader(
            self.train_dataset,
            batch_size=self.batch_size,
            num_workers=self.num_workers,
            shuffle=True,
        )

    def val_dataloader(self) -> DataLoader[Any]:
        """Return a DataLoader for validation.

        Returns:
            validation data loader
        """
        return DataLoader(
            self.val_dataset,
            batch_size=self.batch_size,
            num_workers=self.num_workers,
            shuffle=False,
        )

    def test_dataloader(self) -> DataLoader[Any]:
        """Return a DataLoader for testing.

        Returns:
            testing data loader
        """
        return DataLoader(
            self.test_dataset,
            batch_size=self.batch_size,
            num_workers=self.num_workers,
            shuffle=False,
        )
|
|
@ -0,0 +1,225 @@
|
|||
# Copyright (c) Microsoft Corporation. All rights reserved.
|
||||
# Licensed under the MIT License.
|
||||
|
||||
"""So2Sat datamodule."""
|
||||
|
||||
from typing import Any, Dict, Optional, cast
|
||||
|
||||
import pytorch_lightning as pl
|
||||
import torch
|
||||
from torch.utils.data import DataLoader
|
||||
from torchvision.transforms import Compose
|
||||
|
||||
from ..datasets import So2Sat
|
||||
|
||||
# https://github.com/pytorch/pytorch/issues/60979
|
||||
# https://github.com/pytorch/pytorch/pull/61045
|
||||
DataLoader.__module__ = "torch.utils.data"
|
||||
|
||||
|
||||
class So2SatDataModule(pl.LightningDataModule):
    """LightningDataModule implementation for the So2Sat dataset.

    Uses the train/val/test splits from the dataset.
    """

    # Per-band statistics for all 18 bands, shaped (18, 1, 1).
    # NOTE: ``preprocess`` does not currently apply them.
    band_means = torch.tensor(  # type: ignore[attr-defined]
        [
            -3.591224256609313e-05,
            -7.658561276843396e-06,
            5.9373857475971184e-05,
            2.5166231537121083e-05,
            0.04420110659759328,
            0.25761027084996196,
            0.0007556743372573258,
            0.0013503466830024448,
            0.12375696117681859,
            0.1092774636368323,
            0.1010855203267882,
            0.1142398616114001,
            0.1592656692023089,
            0.18147236008771792,
            0.1745740312291377,
            0.19501607349635292,
            0.15428468872076637,
            0.10905050699570007,
        ]
    ).reshape(18, 1, 1)

    band_stds = torch.tensor(  # type: ignore[attr-defined]
        [
            0.17555201137417686,
            0.17556463274968204,
            0.45998793417834255,
            0.455988755730148,
            2.8559909213125763,
            8.324800606439833,
            2.4498757382563103,
            1.4647352984509094,
            0.03958795985905458,
            0.047778262752410296,
            0.06636616706371974,
            0.06358874912497474,
            0.07744387147984592,
            0.09101635085921553,
            0.09218466562387101,
            0.10164581233948201,
            0.09991773043519253,
            0.08780632509122865,
        ]
    ).reshape(18, 1, 1)

    # This reorders the bands to put S2 RGB first, then the remainder of S2.
    # Indices 0-7 (the S1 bands under that ordering) are deliberately left out,
    # so ``preprocess`` drops them and keeps 10 channels.
    reindex_to_rgb_first = [
        10,
        9,
        8,
        11,
        12,
        13,
        14,
        15,
        16,
        17,
    ]

    def __init__(
        self,
        root_dir: str,
        batch_size: int = 64,
        num_workers: int = 0,
        bands: str = "rgb",
        unsupervised_mode: bool = False,
        **kwargs: Any,
    ) -> None:
        """Initialize a LightningDataModule for So2Sat based DataLoaders.

        Args:
            root_dir: The ``root`` argument to pass to the So2Sat Dataset classes
            batch_size: The batch size to use in all created DataLoaders
            num_workers: The number of workers to use in all created DataLoaders
            bands: Either "rgb" or "s2"; "rgb" keeps only the first three
                reordered channels, any other value keeps all ten
            unsupervised_mode: Makes the train dataloader return imagery from the train,
                val, and test sets
            **kwargs: Additional keyword arguments; accepted but not used here
        """
        super().__init__()  # type: ignore[no-untyped-call]
        self.root_dir = root_dir
        self.batch_size = batch_size
        self.num_workers = num_workers
        self.bands = bands
        self.unsupervised_mode = unsupervised_mode

    def preprocess(self, sample: Dict[str, Any]) -> Dict[str, Any]:
        """Transform a single sample from the Dataset.

        Casts the image to float, reorders/selects channels with
        ``reindex_to_rgb_first`` and, when ``bands == "rgb"``, keeps only the
        first three of them.

        Args:
            sample: dictionary containing image

        Returns:
            preprocessed sample
        """
        # NOTE: normalization with band_means / band_stds is not applied here.
        sample["image"] = sample["image"].float()
        sample["image"] = sample["image"][self.reindex_to_rgb_first, :, :]

        if self.bands == "rgb":
            sample["image"] = sample["image"][:3, :, :]

        return sample

    def prepare_data(self) -> None:
        """Make sure that the dataset is downloaded.

        This method is only called once per run.
        """
        So2Sat(self.root_dir, checksum=False)

    def setup(self, stage: Optional[str] = None) -> None:
        """Initialize the main ``Dataset`` objects.

        This method is called once per GPU per run.

        In unsupervised mode the training dataset is the concatenation of the
        train, validation and test splits.

        Args:
            stage: stage to set up
        """
        train_transforms = Compose([self.preprocess])
        val_test_transforms = self.preprocess

        if not self.unsupervised_mode:

            self.train_dataset = So2Sat(
                self.root_dir, split="train", transforms=train_transforms
            )

            self.val_dataset = So2Sat(
                self.root_dir, split="validation", transforms=val_test_transforms
            )

            self.test_dataset = So2Sat(
                self.root_dir, split="test", transforms=val_test_transforms
            )

        else:

            temp_train = So2Sat(
                self.root_dir, split="train", transforms=train_transforms
            )

            self.val_dataset = So2Sat(
                self.root_dir, split="validation", transforms=train_transforms
            )

            self.test_dataset = So2Sat(
                self.root_dir, split="test", transforms=train_transforms
            )

            # ``+`` concatenates the datasets; the cast only keeps the
            # attribute's static type consistent with the other branch.
            self.train_dataset = cast(
                So2Sat, temp_train + self.val_dataset + self.test_dataset
            )

    def train_dataloader(self) -> DataLoader[Any]:
        """Return a DataLoader for training.

        Returns:
            training data loader
        """
        return DataLoader(
            self.train_dataset,
            batch_size=self.batch_size,
            num_workers=self.num_workers,
            shuffle=True,
        )

    def val_dataloader(self) -> DataLoader[Any]:
        """Return a DataLoader for validation.

        Returns:
            validation data loader
        """
        return DataLoader(
            self.val_dataset,
            batch_size=self.batch_size,
            num_workers=self.num_workers,
            shuffle=False,
        )

    def test_dataloader(self) -> DataLoader[Any]:
        """Return a DataLoader for testing.

        Returns:
            testing data loader
        """
        return DataLoader(
            self.test_dataset,
            batch_size=self.batch_size,
            num_workers=self.num_workers,
            shuffle=False,
        )
|
|
@ -0,0 +1,125 @@
|
|||
# Copyright (c) Microsoft Corporation. All rights reserved.
|
||||
# Licensed under the MIT License.
|
||||
|
||||
"""UC Merced datamodule."""
|
||||
|
||||
from typing import Any, Dict, Optional
|
||||
|
||||
import pytorch_lightning as pl
|
||||
import torch
|
||||
import torchvision
|
||||
from torch.utils.data import DataLoader
|
||||
from torchvision.transforms import Compose, Normalize
|
||||
|
||||
from ..datasets import UCMerced
|
||||
|
||||
# https://github.com/pytorch/pytorch/issues/60979
|
||||
# https://github.com/pytorch/pytorch/pull/61045
|
||||
DataLoader.__module__ = "torch.utils.data"
|
||||
|
||||
|
||||
class UCMercedDataModule(pl.LightningDataModule):
    """LightningDataModule implementation for the UC Merced dataset.

    Requests the "train", "val" and "test" splits from
    :class:`~torchgeo.datasets.UCMerced`.
    """

    # Zero mean / unit std: the ``Normalize`` step built from these values
    # leaves the (already rescaled) image unchanged.
    band_means = torch.tensor([0, 0, 0])  # type: ignore[attr-defined]

    band_stds = torch.tensor([1, 1, 1])  # type: ignore[attr-defined]

    def __init__(
        self, root_dir: str, batch_size: int = 64, num_workers: int = 0, **kwargs: Any
    ) -> None:
        """Initialize a LightningDataModule for UCMerced based DataLoaders.

        Args:
            root_dir: The ``root`` argument to pass to the UCMerced Dataset classes
            batch_size: The batch size to use in all created DataLoaders
            num_workers: The number of workers to use in all created DataLoaders
            **kwargs: Additional keyword arguments; accepted but not used here
        """
        super().__init__()  # type: ignore[no-untyped-call]
        self.root_dir = root_dir
        self.batch_size = batch_size
        self.num_workers = num_workers

        self.norm = Normalize(self.band_means, self.band_stds)

    def preprocess(self, sample: Dict[str, Any]) -> Dict[str, Any]:
        """Transform a single sample from the Dataset.

        Casts the image to float, divides it by 255, resizes it to 256x256
        when its height or width differs, and applies ``self.norm``.

        Args:
            sample: dictionary containing image

        Returns:
            preprocessed sample
        """
        sample["image"] = sample["image"].float()
        sample["image"] /= 255.0
        # The image must be 3-dimensional; only the last two sizes are needed.
        _, h, w = sample["image"].shape
        if h != 256 or w != 256:
            sample["image"] = torchvision.transforms.functional.resize(
                sample["image"], size=(256, 256)
            )
        sample["image"] = self.norm(sample["image"])
        return sample

    def prepare_data(self) -> None:
        """Make sure that the dataset is downloaded.

        This method is only called once per run.
        """
        # ``download=False``: the dataset is only instantiated here, never fetched.
        UCMerced(self.root_dir, download=False, checksum=False)

    def setup(self, stage: Optional[str] = None) -> None:
        """Initialize the main ``Dataset`` objects.

        This method is called once per GPU per run.

        Args:
            stage: stage to set up
        """
        transforms = Compose([self.preprocess])

        self.train_dataset = UCMerced(self.root_dir, "train", transforms=transforms)
        self.val_dataset = UCMerced(self.root_dir, "val", transforms=transforms)
        self.test_dataset = UCMerced(self.root_dir, "test", transforms=transforms)

    def train_dataloader(self) -> DataLoader[Any]:
        """Return a DataLoader for training.

        Returns:
            training data loader
        """
        return DataLoader(
            self.train_dataset,
            batch_size=self.batch_size,
            num_workers=self.num_workers,
            shuffle=True,
        )

    def val_dataloader(self) -> DataLoader[Any]:
        """Return a DataLoader for validation.

        Returns:
            validation data loader
        """
        return DataLoader(
            self.val_dataset,
            batch_size=self.batch_size,
            num_workers=self.num_workers,
            shuffle=False,
        )

    def test_dataloader(self) -> DataLoader[Any]:
        """Return a DataLoader for testing.

        Returns:
            testing data loader
        """
        return DataLoader(
            self.test_dataset,
            batch_size=self.batch_size,
            num_workers=self.num_workers,
            shuffle=False,
        )
|
|
@ -0,0 +1,33 @@
|
|||
# Copyright (c) Microsoft Corporation. All rights reserved.
|
||||
# Licensed under the MIT License.
|
||||
|
||||
"""Common datamodule utilities."""
|
||||
|
||||
from typing import Any, List, Optional
|
||||
|
||||
from torch.utils.data import Dataset, Subset, random_split
|
||||
|
||||
|
||||
def dataset_split(
    dataset: Dataset[Any], val_pct: float, test_pct: Optional[float] = None
) -> List[Subset[Any]]:
    """Split a torch Dataset into train/val/test sets.

    If ``test_pct`` is not set then only train and validation splits are returned.
    Split sizes are ``int(len(dataset) * pct)``; the train split receives all
    remaining samples.

    Args:
        dataset: dataset to be split into train/val or train/val/test subsets
        val_pct: percentage of samples to be in validation set
        test_pct: (Optional) percentage of samples to be in test set

    Returns:
        a list of the subset datasets. Either [train, val] or [train, val, test]

    Raises:
        ValueError: if the percentages are negative or leave a negative number
            of samples for the train split
    """
    num_samples = len(dataset)
    val_length = int(num_samples * val_pct)
    lengths = [num_samples - val_length, val_length]

    if test_pct is not None:
        test_length = int(num_samples * test_pct)
        lengths[0] -= test_length
        lengths.append(test_length)

    # A negative length still sums to ``len(dataset)``, so ``random_split``
    # would accept it and silently return wrong subsets; reject it here.
    if min(lengths) < 0:
        raise ValueError(
            f"Invalid split percentages (val_pct={val_pct}, test_pct={test_pct}) "
            f"for a dataset of {num_samples} samples"
        )

    return random_split(dataset, lengths)
|
|
@ -0,0 +1,121 @@
|
|||
# Copyright (c) Microsoft Corporation. All rights reserved.
|
||||
# Licensed under the MIT License.
|
||||
|
||||
"""Vaihingen datamodule."""
|
||||
|
||||
from typing import Any, Dict, Optional
|
||||
|
||||
import pytorch_lightning as pl
|
||||
from torch.utils.data import DataLoader, Dataset
|
||||
from torchvision.transforms import Compose
|
||||
|
||||
from ..datasets import Vaihingen2D
|
||||
from .utils import dataset_split
|
||||
|
||||
|
||||
class Vaihingen2DDataModule(pl.LightningDataModule):
    """LightningDataModule implementation for the Vaihingen2D dataset.

    Uses the train/test splits from the dataset. The validation set is held
    out from the ``"train"`` split according to ``val_split_pct``.

    .. versionadded:: 0.2
    """

    def __init__(
        self,
        root_dir: str,
        batch_size: int = 64,
        num_workers: int = 0,
        val_split_pct: float = 0.2,
        **kwargs: Any,
    ) -> None:
        """Initialize a LightningDataModule for Vaihingen2D based DataLoaders.

        Args:
            root_dir: The ``root`` argument to pass to the Vaihingen Dataset classes
            batch_size: The batch size to use in all created DataLoaders
            num_workers: The number of workers to use in all created DataLoaders
            val_split_pct: Fraction of the train split to hold out as a
                validation set (forwarded to ``dataset_split`` as ``val_pct``)
            **kwargs: Additional keyword arguments; accepted but not used here
        """
        super().__init__()  # type: ignore[no-untyped-call]
        self.root_dir = root_dir
        self.batch_size = batch_size
        self.num_workers = num_workers
        self.val_split_pct = val_split_pct

    def preprocess(self, sample: Dict[str, Any]) -> Dict[str, Any]:
        """Transform a single sample from the Dataset.

        Casts the image to float and divides it by 255. The ``sample``
        dictionary is updated in place and also returned.

        Args:
            sample: input image dictionary

        Returns:
            preprocessed sample
        """
        sample["image"] = sample["image"].float()
        sample["image"] /= 255.0
        return sample

    def setup(self, stage: Optional[str] = None) -> None:
        """Initialize the main ``Dataset`` objects.

        This method is called once per GPU per run.

        Args:
            stage: stage to set up
        """
        transforms = Compose([self.preprocess])

        dataset = Vaihingen2D(self.root_dir, "train", transforms=transforms)

        self.train_dataset: Dataset[Any]
        self.val_dataset: Dataset[Any]

        if self.val_split_pct > 0.0:
            # ``test_pct=0.0`` makes ``dataset_split`` return three subsets;
            # the third (empty) one is discarded.
            self.train_dataset, self.val_dataset, _ = dataset_split(
                dataset, val_pct=self.val_split_pct, test_pct=0.0
            )
        else:
            # No hold-out requested: validation reuses the full train split.
            self.train_dataset = dataset
            self.val_dataset = dataset

        self.test_dataset = Vaihingen2D(self.root_dir, "test", transforms=transforms)

    def train_dataloader(self) -> DataLoader[Any]:
        """Return a DataLoader for training.

        Returns:
            training data loader
        """
        return DataLoader(
            self.train_dataset,
            batch_size=self.batch_size,
            num_workers=self.num_workers,
            shuffle=True,
        )

    def val_dataloader(self) -> DataLoader[Any]:
        """Return a DataLoader for validation.

        Returns:
            validation data loader
        """
        return DataLoader(
            self.val_dataset,
            batch_size=self.batch_size,
            num_workers=self.num_workers,
            shuffle=False,
        )

    def test_dataloader(self) -> DataLoader[Any]:
        """Return a DataLoader for testing.

        Returns:
            testing data loader
        """
        return DataLoader(
            self.test_dataset,
            batch_size=self.batch_size,
            num_workers=self.num_workers,
            shuffle=False,
        )
|
|
@ -0,0 +1,121 @@
|
|||
# Copyright (c) Microsoft Corporation. All rights reserved.
|
||||
# Licensed under the MIT License.
|
||||
|
||||
"""xView2 datamodule."""
|
||||
|
||||
from typing import Any, Dict, Optional
|
||||
|
||||
import pytorch_lightning as pl
|
||||
from torch.utils.data import DataLoader, Dataset
|
||||
from torchvision.transforms import Compose
|
||||
|
||||
from ..datasets import XView2
|
||||
from .utils import dataset_split
|
||||
|
||||
|
||||
class XView2DataModule(pl.LightningDataModule):
    """LightningDataModule implementation for the xView2 dataset.

    Uses the "train" and "test" splits from the dataset. The validation set is
    carved out of the "train" split with ``dataset_split`` according to
    ``val_split_pct``.

    .. versionadded:: 0.2
    """

    def __init__(
        self,
        root_dir: str,
        batch_size: int = 64,
        num_workers: int = 0,
        val_split_pct: float = 0.2,
        **kwargs: Any,
    ) -> None:
        """Initialize a LightningDataModule for xView2 based DataLoaders.

        Args:
            root_dir: The ``root`` argument to pass to the xView2 Dataset classes
            batch_size: The batch size to use in all created DataLoaders
            num_workers: The number of workers to use in all created DataLoaders
            val_split_pct: Value forwarded to ``dataset_split`` as ``val_pct``;
                if it is not greater than 0, no split is performed and the
                validation set is the full "train" split
            **kwargs: Additional keyword arguments; accepted but not used here
        """
        super().__init__()  # type: ignore[no-untyped-call]
        self.root_dir = root_dir
        self.batch_size = batch_size
        self.num_workers = num_workers
        self.val_split_pct = val_split_pct

    def preprocess(self, sample: Dict[str, Any]) -> Dict[str, Any]:
        """Transform a single sample from the Dataset.

        Args:
            sample: input image dictionary

        Returns:
            the same dictionary with ``sample["image"]`` converted to float and
            divided by 255
        """
        sample["image"] = sample["image"].float()
        sample["image"] /= 255.0
        return sample

    def setup(self, stage: Optional[str] = None) -> None:
        """Initialize the main ``Dataset`` objects.

        This method is called once per GPU per run.

        Args:
            stage: stage to set up (not used by this implementation)
        """
        transforms = Compose([self.preprocess])

        dataset = XView2(self.root_dir, "train", transforms=transforms)

        self.train_dataset: Dataset[Any]
        self.val_dataset: Dataset[Any]

        if self.val_split_pct > 0.0:
            # test_pct=0.0: the third returned subset is discarded because the
            # test data comes from the dataset's own "test" split below.
            self.train_dataset, self.val_dataset, _ = dataset_split(
                dataset, val_pct=self.val_split_pct, test_pct=0.0
            )
        else:
            # No validation fraction requested: validate on the training data.
            self.train_dataset = dataset
            self.val_dataset = dataset

        self.test_dataset = XView2(self.root_dir, "test", transforms=transforms)

    def train_dataloader(self) -> DataLoader[Any]:
        """Return a DataLoader for training.

        Returns:
            training data loader
        """
        return DataLoader(
            self.train_dataset,
            batch_size=self.batch_size,
            num_workers=self.num_workers,
            shuffle=True,
        )

    def val_dataloader(self) -> DataLoader[Any]:
        """Return a DataLoader for validation.

        Returns:
            validation data loader
        """
        return DataLoader(
            self.val_dataset,
            batch_size=self.batch_size,
            num_workers=self.num_workers,
            shuffle=False,
        )

    def test_dataloader(self) -> DataLoader[Any]:
        """Return a DataLoader for testing.

        Returns:
            testing data loader
        """
        return DataLoader(
            self.test_dataset,
            batch_size=self.batch_size,
            num_workers=self.num_workers,
            shuffle=False,
        )
|
|
@ -5,7 +5,7 @@
|
|||
|
||||
from .advance import ADVANCE
|
||||
from .benin_cashews import BeninSmallHolderCashews
|
||||
from .bigearthnet import BigEarthNet, BigEarthNetDataModule
|
||||
from .bigearthnet import BigEarthNet
|
||||
from .cbf import CanadianBuildingFootprints
|
||||
from .cdl import CDL
|
||||
from .chesapeake import (
|
||||
|
@ -13,7 +13,6 @@ from .chesapeake import (
|
|||
Chesapeake7,
|
||||
Chesapeake13,
|
||||
ChesapeakeCVPR,
|
||||
ChesapeakeCVPRDataModule,
|
||||
ChesapeakeDC,
|
||||
ChesapeakeDE,
|
||||
ChesapeakeMD,
|
||||
|
@ -22,12 +21,12 @@ from .chesapeake import (
|
|||
ChesapeakeVA,
|
||||
ChesapeakeWV,
|
||||
)
|
||||
from .cowc import COWC, COWCCounting, COWCCountingDataModule, COWCDetection
|
||||
from .cowc import COWC, COWCCounting, COWCDetection
|
||||
from .cv4a_kenya_crop_type import CV4AKenyaCropType
|
||||
from .cyclone import CycloneDataModule, TropicalCycloneWindEstimation
|
||||
from .etci2021 import ETCI2021, ETCI2021DataModule
|
||||
from .eurosat import EuroSAT, EuroSATDataModule
|
||||
from .fair1m import FAIR1M, FAIR1MDataModule
|
||||
from .cyclone import TropicalCycloneWindEstimation
|
||||
from .etci2021 import ETCI2021
|
||||
from .eurosat import EuroSAT
|
||||
from .fair1m import FAIR1M
|
||||
from .geo import (
|
||||
GeoDataset,
|
||||
IntersectionDataset,
|
||||
|
@ -39,7 +38,7 @@ from .geo import (
|
|||
)
|
||||
from .gid15 import GID15
|
||||
from .idtrees import IDTReeS
|
||||
from .landcoverai import LandCoverAI, LandCoverAIDataModule
|
||||
from .landcoverai import LandCoverAI
|
||||
from .landsat import (
|
||||
Landsat,
|
||||
Landsat1,
|
||||
|
@ -54,23 +53,23 @@ from .landsat import (
|
|||
Landsat9,
|
||||
)
|
||||
from .levircd import LEVIRCDPlus
|
||||
from .loveda import LoveDA, LoveDADataModule
|
||||
from .naip import NAIP, NAIPChesapeakeDataModule
|
||||
from .nasa_marine_debris import NASAMarineDebris, NASAMarineDebrisDataModule
|
||||
from .loveda import LoveDA
|
||||
from .naip import NAIP
|
||||
from .nasa_marine_debris import NASAMarineDebris
|
||||
from .nwpu import VHR10
|
||||
from .oscd import OSCD, OSCDDataModule
|
||||
from .oscd import OSCD
|
||||
from .patternnet import PatternNet
|
||||
from .potsdam import Potsdam2D, Potsdam2DDataModule
|
||||
from .resisc45 import RESISC45, RESISC45DataModule
|
||||
from .potsdam import Potsdam2D
|
||||
from .resisc45 import RESISC45
|
||||
from .seco import SeasonalContrastS2
|
||||
from .sen12ms import SEN12MS, SEN12MSDataModule
|
||||
from .sen12ms import SEN12MS
|
||||
from .sentinel import Sentinel, Sentinel2
|
||||
from .so2sat import So2Sat, So2SatDataModule
|
||||
from .so2sat import So2Sat
|
||||
from .spacenet import SpaceNet, SpaceNet1, SpaceNet2, SpaceNet4, SpaceNet5, SpaceNet7
|
||||
from .ucmerced import UCMerced, UCMercedDataModule
|
||||
from .ucmerced import UCMerced
|
||||
from .utils import BoundingBox, concat_samples, merge_samples, stack_samples
|
||||
from .vaihingen import Vaihingen2D, Vaihingen2DDataModule
|
||||
from .xview import XView2, XView2DataModule
|
||||
from .vaihingen import Vaihingen2D
|
||||
from .xview import XView2
|
||||
from .zuericrop import ZueriCrop
|
||||
|
||||
__all__ = (
|
||||
|
@ -88,7 +87,6 @@ __all__ = (
|
|||
"ChesapeakeVA",
|
||||
"ChesapeakeWV",
|
||||
"ChesapeakeCVPR",
|
||||
"ChesapeakeCVPRDataModule",
|
||||
"Landsat",
|
||||
"Landsat1",
|
||||
"Landsat2",
|
||||
|
@ -101,46 +99,32 @@ __all__ = (
|
|||
"Landsat8",
|
||||
"Landsat9",
|
||||
"NAIP",
|
||||
"NAIPChesapeakeDataModule",
|
||||
"Sentinel",
|
||||
"Sentinel2",
|
||||
# VisionDataset
|
||||
"ADVANCE",
|
||||
"BeninSmallHolderCashews",
|
||||
"BigEarthNet",
|
||||
"BigEarthNetDataModule",
|
||||
"COWC",
|
||||
"COWCCounting",
|
||||
"COWCDetection",
|
||||
"COWCCountingDataModule",
|
||||
"CV4AKenyaCropType",
|
||||
"ETCI2021",
|
||||
"ETCI2021DataModule",
|
||||
"EuroSAT",
|
||||
"EuroSATDataModule",
|
||||
"FAIR1M",
|
||||
"FAIR1MDataModule",
|
||||
"GID15",
|
||||
"IDTReeS",
|
||||
"LandCoverAI",
|
||||
"LandCoverAIDataModule",
|
||||
"LEVIRCDPlus",
|
||||
"LoveDA",
|
||||
"LoveDADataModule",
|
||||
"NASAMarineDebris",
|
||||
"NASAMarineDebrisDataModule",
|
||||
"OSCD",
|
||||
"OSCDDataModule",
|
||||
"PatternNet",
|
||||
"Potsdam2D",
|
||||
"Potsdam2DDataModule",
|
||||
"RESISC45",
|
||||
"RESISC45DataModule",
|
||||
"SeasonalContrastS2",
|
||||
"SEN12MS",
|
||||
"SEN12MSDataModule",
|
||||
"So2Sat",
|
||||
"So2SatDataModule",
|
||||
"SpaceNet",
|
||||
"SpaceNet1",
|
||||
"SpaceNet2",
|
||||
|
@ -148,14 +132,10 @@ __all__ = (
|
|||
"SpaceNet5",
|
||||
"SpaceNet7",
|
||||
"TropicalCycloneWindEstimation",
|
||||
"CycloneDataModule",
|
||||
"UCMerced",
|
||||
"UCMercedDataModule",
|
||||
"Vaihingen2D",
|
||||
"Vaihingen2DDataModule",
|
||||
"VHR10",
|
||||
"XView2",
|
||||
"XView2DataModule",
|
||||
"ZueriCrop",
|
||||
# Base classes
|
||||
"GeoDataset",
|
||||
|
|
|
@ -6,24 +6,17 @@
|
|||
import glob
|
||||
import json
|
||||
import os
|
||||
from typing import Any, Callable, Dict, List, Optional
|
||||
from typing import Callable, Dict, List, Optional
|
||||
|
||||
import numpy as np
|
||||
import pytorch_lightning as pl
|
||||
import rasterio
|
||||
import torch
|
||||
from rasterio.enums import Resampling
|
||||
from torch import Tensor
|
||||
from torch.utils.data import DataLoader
|
||||
from torchvision.transforms import Compose
|
||||
|
||||
from .geo import VisionDataset
|
||||
from .utils import download_url, extract_archive, sort_sentinel2_bands
|
||||
|
||||
# https://github.com/pytorch/pytorch/issues/60979
|
||||
# https://github.com/pytorch/pytorch/pull/61045
|
||||
DataLoader.__module__ = "torch.utils.data"
|
||||
|
||||
|
||||
class BigEarthNet(VisionDataset):
|
||||
"""BigEarthNet dataset.
|
||||
|
@ -511,164 +504,3 @@ class BigEarthNet(VisionDataset):
|
|||
"""
|
||||
if not filepath.endswith(".csv"):
|
||||
extract_archive(filepath)
|
||||
|
||||
|
||||
class BigEarthNetDataModule(pl.LightningDataModule):
    """LightningDataModule wrapping the BigEarthNet dataset.

    Relies on the train/val/test splits that the dataset itself provides.
    """

    # Statistics below are listed per band in this order:
    # (VV, VH, B01, B02, B03, B04, B05, B06, B07, B08, B8A, B09, B11, B12)
    # Raw min/max band statistics computed on 100k random samples.
    band_mins_raw = torch.tensor(  # type: ignore[attr-defined]
        [-70.0, -72.0, 1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 1.0, 0.0, 0.0]
    )
    band_maxs_raw = torch.tensor(  # type: ignore[attr-defined]
        [
            31.0, 35.0, 18556.0, 20528.0, 18976.0, 17874.0, 16611.0,
            16512.0, 16394.0, 16672.0, 16141.0, 16097.0, 15336.0, 15203.0,
        ]  # fmt: skip
    )

    # min/max band statistics obtained by percentile clipping the samples
    # above to [2, 98]; these are the values used for normalization.
    band_mins = torch.tensor(  # type: ignore[attr-defined]
        [-48.0, -42.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0]
    )
    band_maxs = torch.tensor(  # type: ignore[attr-defined]
        [
            6.0, 16.0, 9859.0, 12872.0, 13163.0, 14445.0, 12477.0,
            12563.0, 12289.0, 15596.0, 12183.0, 9458.0, 5897.0, 5544.0,
        ]  # fmt: skip
    )

    def __init__(
        self,
        root_dir: str,
        bands: str = "all",
        num_classes: int = 19,
        batch_size: int = 64,
        num_workers: int = 0,
        **kwargs: Any,
    ) -> None:
        """Set up a LightningDataModule producing BigEarthNet DataLoaders.

        Args:
            root_dir: The ``root`` argument to pass to the BigEarthNet Dataset classes
            bands: load Sentinel-1 bands, Sentinel-2, or both. one of {s1, s2, all}
            num_classes: number of classes to load in target. one of {19, 43}
            batch_size: The batch size to use in all created DataLoaders
            num_workers: The number of workers to use in all created DataLoaders
            **kwargs: Additional keyword arguments; accepted but not used here
        """
        super().__init__()  # type: ignore[no-untyped-call]
        self.root_dir = root_dir
        self.bands = bands
        self.num_classes = num_classes
        self.batch_size = batch_size
        self.num_workers = num_workers

        # Pick the statistics matching the requested bands: the first two
        # entries are used for "s1", the remaining twelve for anything that is
        # neither "all" nor "s1".
        if bands == "all":
            selected = slice(None)
        elif bands == "s1":
            selected = slice(None, 2)
        else:
            selected = slice(2, None)

        # Trailing ``None`` indices add two singleton dimensions so the
        # statistics broadcast against the image in ``preprocess``.
        self.mins = self.band_mins[selected, None, None]
        self.maxs = self.band_maxs[selected, None, None]

    def preprocess(self, sample: Dict[str, Any]) -> Dict[str, Any]:
        """Min/max normalize the image of a single sample and clip it to [0, 1].

        Args:
            sample: dictionary holding at least an ``"image"`` entry

        Returns:
            the same dictionary with the normalized image stored back
        """
        image = sample["image"].float()
        value_range = self.maxs - self.mins
        image = (image - self.mins) / value_range
        sample["image"] = torch.clip(  # type: ignore[attr-defined]
            image, min=0.0, max=1.0
        )
        return sample

    def prepare_data(self) -> None:
        """Make sure that the dataset is downloaded.

        This method is only called once per run.
        """
        # Instantiated only for its side effects; the object is discarded.
        BigEarthNet(self.root_dir, split="train", bands=self.bands, checksum=False)

    def setup(self, stage: Optional[str] = None) -> None:
        """Create the train, val and test ``Dataset`` objects.

        This method is called once per GPU per run.

        Args:
            stage: stage to set up (not used by this implementation)
        """
        transforms = Compose([self.preprocess])

        def build(split: str) -> BigEarthNet:
            # All three datasets share everything except the split name.
            return BigEarthNet(
                self.root_dir,
                split=split,
                bands=self.bands,
                num_classes=self.num_classes,
                transforms=transforms,
            )

        self.train_dataset = build("train")
        self.val_dataset = build("val")
        self.test_dataset = build("test")

    def train_dataloader(self) -> DataLoader[Any]:
        """Build the shuffled DataLoader used for training."""
        dataset = self.train_dataset
        return DataLoader(
            dataset,
            shuffle=True,
            num_workers=self.num_workers,
            batch_size=self.batch_size,
        )

    def val_dataloader(self) -> DataLoader[Any]:
        """Build the unshuffled DataLoader used for validation."""
        dataset = self.val_dataset
        return DataLoader(
            dataset,
            shuffle=False,
            num_workers=self.num_workers,
            batch_size=self.batch_size,
        )

    def test_dataloader(self) -> DataLoader[Any]:
        """Build the unshuffled DataLoader used for testing."""
        dataset = self.test_dataset
        return DataLoader(
            dataset,
            shuffle=False,
            num_workers=self.num_workers,
            batch_size=self.batch_size,
        )
|
||||
|
|
|
@ -16,21 +16,10 @@ import rasterio.mask
|
|||
import shapely.geometry
|
||||
import shapely.ops
|
||||
import torch
|
||||
import torch.nn.functional as F
|
||||
from pytorch_lightning.core.datamodule import LightningDataModule
|
||||
from rasterio.crs import CRS
|
||||
from torch import Tensor
|
||||
from torch.utils.data import DataLoader
|
||||
from torchvision.transforms import Compose
|
||||
|
||||
from ..samplers.batch import RandomBatchGeoSampler
|
||||
from ..samplers.single import GridGeoSampler
|
||||
from .geo import GeoDataset, RasterDataset
|
||||
from .utils import BoundingBox, download_url, extract_archive, stack_samples
|
||||
|
||||
# https://github.com/pytorch/pytorch/issues/60979
|
||||
# https://github.com/pytorch/pytorch/pull/61045
|
||||
DataLoader.__module__ = "torch.utils.data"
|
||||
from .utils import BoundingBox, download_url, extract_archive
|
||||
|
||||
|
||||
class Chesapeake(RasterDataset, abc.ABC):
|
||||
|
@ -537,294 +526,3 @@ class ChesapeakeCVPR(GeoDataset):
|
|||
def _extract(self) -> None:
|
||||
"""Extract the dataset."""
|
||||
extract_archive(os.path.join(self.root, self.filename))
|
||||
|
||||
|
||||
class ChesapeakeCVPRDataModule(LightningDataModule):
    """LightningDataModule implementation for the Chesapeake CVPR Land Cover dataset.

    Uses the random splits defined per state to partition tiles into train, val,
    and test sets.
    """

    def __init__(
        self,
        root_dir: str,
        train_splits: List[str],
        val_splits: List[str],
        test_splits: List[str],
        patches_per_tile: int = 200,
        patch_size: int = 256,
        batch_size: int = 64,
        num_workers: int = 0,
        class_set: int = 7,
        **kwargs: Any,
    ) -> None:
        """Initialize a LightningDataModule for Chesapeake CVPR based DataLoaders.

        Args:
            root_dir: The ``root`` argument to pass to the ChesapeakeCVPR Dataset
                classes
            train_splits: The splits used to train the model, e.g. ["ny-train"]
            val_splits: The splits used to validate the model, e.g. ["ny-val"]
            test_splits: The splits used to test the model, e.g. ["ny-test"]
            patches_per_tile: The number of patches per tile to sample
            patch_size: The size of each patch in pixels. Patches are sampled at
                ``int(patch_size * 2.0)``; train/val patches are then center
                cropped to ``patch_size`` while test patches are padded to the
                sampled size
            batch_size: The batch size to use in all created DataLoaders
            num_workers: The number of workers to use in all created DataLoaders
            class_set: The high-resolution land cover class set to use - 5 or 7
            **kwargs: Additional keyword arguments; accepted but not used here

        Raises:
            AssertionError: if a split is not in ``ChesapeakeCVPR.splits`` or
                ``class_set`` is not 5 or 7
        """
        super().__init__()  # type: ignore[no-untyped-call]
        for state in train_splits + val_splits + test_splits:
            assert state in ChesapeakeCVPR.splits
        assert class_set in [5, 7]

        self.root_dir = root_dir
        self.train_splits = train_splits
        self.val_splits = val_splits
        self.test_splits = test_splits
        self.layers = ["naip-new", "lc"]
        self.patches_per_tile = patches_per_tile
        self.patch_size = patch_size
        # This is a rough estimate of how large of a patch we will need to sample in
        # EPSG:3857 in order to guarantee a large enough patch in the local CRS.
        self.original_patch_size = int(patch_size * 2.0)
        self.batch_size = batch_size
        self.num_workers = num_workers
        self.class_set = class_set

    def pad_to(
        self, size: int = 512, image_value: int = 0, mask_value: int = 0
    ) -> Callable[[Dict[str, Tensor]], Dict[str, Tensor]]:
        """Returns a function to perform a padding transform on a single sample.

        Args:
            size: output image size
            image_value: value to pad image with
            mask_value: value to pad mask with

        Returns:
            function to perform padding
        """

        def pad_inner(sample: Dict[str, Tensor]) -> Dict[str, Tensor]:
            _, height, width = sample["image"].shape
            assert height <= size and width <= size

            height_pad = size - height
            width_pad = size - width

            # See https://pytorch.org/docs/stable/generated/torch.nn.functional.pad.html
            # for a description of the format of the padding tuple
            sample["image"] = F.pad(
                sample["image"],
                (0, width_pad, 0, height_pad),
                mode="constant",
                value=image_value,
            )
            sample["mask"] = F.pad(
                sample["mask"],
                (0, width_pad, 0, height_pad),
                mode="constant",
                value=mask_value,
            )
            return sample

        return pad_inner

    def center_crop(
        self, size: int = 512
    ) -> Callable[[Dict[str, Tensor]], Dict[str, Tensor]]:
        """Returns a function to perform a center crop transform on a single sample.

        Args:
            size: output image size

        Returns:
            function to perform center crop
        """

        def center_crop_inner(sample: Dict[str, Tensor]) -> Dict[str, Tensor]:
            _, height, width = sample["image"].shape

            y1 = (height - size) // 2
            x1 = (width - size) // 2
            # The mask is indexed with a leading dimension as well, so it is
            # expected to have three dimensions like the image here.
            sample["image"] = sample["image"][:, y1 : y1 + size, x1 : x1 + size]
            sample["mask"] = sample["mask"][:, y1 : y1 + size, x1 : x1 + size]

            return sample

        return center_crop_inner

    def preprocess(self, sample: Dict[str, Any]) -> Dict[str, Any]:
        """Preprocesses a single sample.

        Divides the image by 255, squeezes the mask, optionally merges mask
        classes 5 and 6 into class 4 (when ``class_set`` is 5), and casts the
        image to float and the mask to long.

        Args:
            sample: sample dictionary containing image and mask

        Returns:
            preprocessed sample
        """
        sample["image"] = sample["image"] / 255.0
        sample["mask"] = sample["mask"].squeeze()

        if self.class_set == 5:
            sample["mask"][sample["mask"] == 5] = 4
            sample["mask"][sample["mask"] == 6] = 4

        sample["image"] = sample["image"].float()
        sample["mask"] = sample["mask"].long()

        return sample

    def nodata_check(
        self, size: int = 512
    ) -> Callable[[Dict[str, Tensor]], Dict[str, Tensor]]:
        """Returns a function to check for nodata or mis-sized input.

        Args:
            size: output image size

        Returns:
            function that replaces the image and mask of any sample smaller
            than ``size`` with all-zero tensors
        """

        def nodata_check_inner(sample: Dict[str, Tensor]) -> Dict[str, Tensor]:
            num_channels, height, width = sample["image"].shape

            if height < size or width < size:
                sample["image"] = torch.zeros(  # type: ignore[attr-defined]
                    (num_channels, size, size)
                )
                sample["mask"] = torch.zeros((size, size))  # type: ignore[attr-defined]

            return sample

        return nodata_check_inner

    def prepare_data(self) -> None:
        """Confirms that the dataset is downloaded on the local node.

        The dataset is instantiated with ``download=False`` and
        ``checksum=False``, so nothing is fetched here.

        This method is called once per node, while :func:`setup` is called once per GPU.
        """
        ChesapeakeCVPR(
            self.root_dir,
            splits=self.train_splits,
            layers=self.layers,
            transforms=None,
            download=False,
            checksum=False,
        )

    def setup(self, stage: Optional[str] = None) -> None:
        """Create the train/val/test splits based on the original Dataset objects.

        The splits should be done here vs. in :func:`__init__` per the docs:
        https://pytorch-lightning.readthedocs.io/en/latest/extensions/datamodules.html#setup.

        Args:
            stage: stage to set up (not used by this implementation)
        """
        train_transforms = Compose(
            [
                self.center_crop(self.patch_size),
                self.nodata_check(self.patch_size),
                self.preprocess,
            ]
        )
        val_transforms = Compose(
            [
                self.center_crop(self.patch_size),
                self.nodata_check(self.patch_size),
                self.preprocess,
            ]
        )
        # Test samples are not cropped; they are padded up to the sampled size.
        test_transforms = Compose(
            [
                self.pad_to(self.original_patch_size, image_value=0, mask_value=0),
                self.preprocess,
            ]
        )

        self.train_dataset = ChesapeakeCVPR(
            self.root_dir,
            splits=self.train_splits,
            layers=self.layers,
            transforms=train_transforms,
            download=False,
            checksum=False,
        )
        self.val_dataset = ChesapeakeCVPR(
            self.root_dir,
            splits=self.val_splits,
            layers=self.layers,
            transforms=val_transforms,
            download=False,
            checksum=False,
        )
        self.test_dataset = ChesapeakeCVPR(
            self.root_dir,
            splits=self.test_splits,
            layers=self.layers,
            transforms=test_transforms,
            download=False,
            checksum=False,
        )

    def train_dataloader(self) -> DataLoader[Any]:
        """Return a DataLoader for training.

        Returns:
            training data loader
        """
        sampler = RandomBatchGeoSampler(
            self.train_dataset,
            size=self.original_patch_size,
            batch_size=self.batch_size,
            length=self.patches_per_tile * len(self.train_dataset),
        )
        return DataLoader(
            self.train_dataset,
            batch_sampler=sampler,
            num_workers=self.num_workers,
            collate_fn=stack_samples,
        )

    def val_dataloader(self) -> DataLoader[Any]:
        """Return a DataLoader for validation.

        Returns:
            validation data loader
        """
        sampler = GridGeoSampler(
            self.val_dataset,
            size=self.original_patch_size,
            stride=self.original_patch_size,
        )
        return DataLoader(
            self.val_dataset,
            batch_size=self.batch_size,
            sampler=sampler,
            num_workers=self.num_workers,
            collate_fn=stack_samples,
        )

    def test_dataloader(self) -> DataLoader[Any]:
        """Return a DataLoader for testing.

        Returns:
            testing data loader
        """
        sampler = GridGeoSampler(
            self.test_dataset,
            size=self.original_patch_size,
            stride=self.original_patch_size,
        )
        return DataLoader(
            self.test_dataset,
            batch_size=self.batch_size,
            sampler=sampler,
            num_workers=self.num_workers,
            collate_fn=stack_samples,
        )
|
||||
|
|
|
@ -6,22 +6,16 @@
|
|||
import abc
|
||||
import csv
|
||||
import os
|
||||
from typing import Any, Callable, Dict, List, Optional
|
||||
from typing import Callable, Dict, List, Optional
|
||||
|
||||
import numpy as np
|
||||
import pytorch_lightning as pl
|
||||
import torch
|
||||
from PIL import Image
|
||||
from torch import Generator, Tensor # type: ignore[attr-defined]
|
||||
from torch.utils.data import DataLoader, random_split
|
||||
from torch import Tensor
|
||||
|
||||
from .geo import VisionDataset
|
||||
from .utils import check_integrity, download_and_extract_archive
|
||||
|
||||
# https://github.com/pytorch/pytorch/issues/60979
|
||||
# https://github.com/pytorch/pytorch/pull/61045
|
||||
DataLoader.__module__ = "torch.utils.data"
|
||||
|
||||
|
||||
class COWC(VisionDataset, abc.ABC):
|
||||
"""Abstract base class for the COWC dataset.
|
||||
|
@ -268,110 +262,3 @@ class COWCDetection(COWC):
|
|||
# 4. Unknown
|
||||
#
|
||||
# May need new abstract base class. Will need subclasses for different patch sizes.
|
||||
|
||||
|
||||
class COWCCountingDataModule(pl.LightningDataModule):
    """LightningDataModule wrapping the COWC Counting dataset."""

    def __init__(
        self,
        root_dir: str,
        seed: int,
        batch_size: int = 64,
        num_workers: int = 0,
        **kwargs: Any,
    ) -> None:
        """Set up a LightningDataModule producing COWC Counting DataLoaders.

        Args:
            root_dir: The ``root`` argument to pass to the COWCCounting Dataset class
            seed: The seed value to use when doing the dataset random_split
            batch_size: The batch size to use in all created DataLoaders
            num_workers: The number of workers to use in all created DataLoaders
            **kwargs: Additional keyword arguments; accepted but not used here
        """
        super().__init__()  # type: ignore[no-untyped-call]
        self.num_workers = num_workers
        self.batch_size = batch_size
        self.seed = seed
        self.root_dir = root_dir

    def custom_transform(self, sample: Dict[str, Any]) -> Dict[str, Any]:
        """Rescale the image and cast the label of a single sample.

        Args:
            sample: dictionary containing image and target

        Returns:
            the same dictionary with the image divided by 255 and the label
            converted to float
        """
        scaled_image = sample["image"] / 255.0  # scale to [0, 1]
        sample["image"] = scaled_image
        sample["label"] = sample["label"].float()
        return sample

    def prepare_data(self) -> None:
        """Instantiate the dataset once before :func:`setup` runs.

        The dataset is created with ``download=False``, so nothing is fetched
        here. This is done once per node, while :func:`setup` is done once per
        GPU.
        """
        COWCCounting(self.root_dir, download=False)

    def setup(self, stage: Optional[str] = None) -> None:
        """Create the train/val/test splits based on the original Dataset objects.

        The splits should be done here vs. in :func:`__init__` per the docs:
        https://pytorch-lightning.readthedocs.io/en/latest/extensions/datamodules.html#setup.

        The dataset's "train" split is randomly divided into a training part
        and a validation part; the validation part has as many samples as the
        "test" split.

        Args:
            stage: stage to set up (not used by this implementation)
        """
        transform = self.custom_transform
        train_val_dataset = COWCCounting(
            self.root_dir, split="train", transforms=transform
        )
        self.test_dataset = COWCCounting(
            self.root_dir, split="test", transforms=transform
        )

        num_val = len(self.test_dataset)
        num_train = len(train_val_dataset) - num_val
        # Seeded generator keeps the train/val partition reproducible.
        rng = Generator().manual_seed(self.seed)
        self.train_dataset, self.val_dataset = random_split(
            train_val_dataset, [num_train, num_val], generator=rng
        )

    def train_dataloader(self) -> DataLoader[Any]:
        """Build the shuffled DataLoader used for training.

        Returns:
            training data loader
        """
        dataset = self.train_dataset
        return DataLoader(
            dataset,
            shuffle=True,
            num_workers=self.num_workers,
            batch_size=self.batch_size,
        )

    def val_dataloader(self) -> DataLoader[Any]:
        """Build the unshuffled DataLoader used for validation.

        Returns:
            validation data loader
        """
        dataset = self.val_dataset
        return DataLoader(
            dataset,
            shuffle=False,
            num_workers=self.num_workers,
            batch_size=self.batch_size,
        )

    def test_dataloader(self) -> DataLoader[Any]:
        """Build the unshuffled DataLoader used for testing.

        Returns:
            testing data loader
        """
        dataset = self.test_dataset
        return DataLoader(
            dataset,
            shuffle=False,
            num_workers=self.num_workers,
            batch_size=self.batch_size,
        )
|
||||
|
|
|
@ -10,20 +10,13 @@ from typing import Any, Callable, Dict, Optional, cast
|
|||
|
||||
import matplotlib.pyplot as plt
|
||||
import numpy as np
|
||||
import pytorch_lightning as pl
|
||||
import torch
|
||||
from PIL import Image
|
||||
from sklearn.model_selection import GroupShuffleSplit
|
||||
from torch import Tensor
|
||||
from torch.utils.data import DataLoader, Subset
|
||||
|
||||
from .geo import VisionDataset
|
||||
from .utils import check_integrity, download_radiant_mlhub_dataset, extract_archive
|
||||
|
||||
# https://github.com/pytorch/pytorch/issues/60979
|
||||
# https://github.com/pytorch/pytorch/pull/61045
|
||||
DataLoader.__module__ = "torch.utils.data"
|
||||
|
||||
|
||||
class TropicalCycloneWindEstimation(VisionDataset):
|
||||
"""Tropical Cyclone Wind Estimation Competition dataset.
|
||||
|
@ -254,157 +247,3 @@ class TropicalCycloneWindEstimation(VisionDataset):
|
|||
plt.suptitle(suptitle)
|
||||
|
||||
return fig
|
||||
|
||||
|
||||
class CycloneDataModule(pl.LightningDataModule):
|
||||
"""LightningDataModule implementation for the NASA Cyclone dataset.
|
||||
|
||||
Implements 80/20 train/val splits based on hurricane storm ids.
|
||||
See :func:`setup` for more details.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
root_dir: str,
|
||||
seed: int,
|
||||
batch_size: int = 64,
|
||||
num_workers: int = 0,
|
||||
api_key: Optional[str] = None,
|
||||
**kwargs: Any,
|
||||
) -> None:
|
||||
"""Initialize a LightningDataModule for NASA Cyclone based DataLoaders.
|
||||
|
||||
Args:
|
||||
root_dir: The ``root`` arugment to pass to the
|
||||
TropicalCycloneWindEstimation Datasets classes
|
||||
seed: The seed value to use when doing the sklearn based GroupShuffleSplit
|
||||
batch_size: The batch size to use in all created DataLoaders
|
||||
num_workers: The number of workers to use in all created DataLoaders
|
||||
api_key: The RadiantEarth MLHub API key to use if the dataset needs to be
|
||||
downloaded
|
||||
"""
|
||||
super().__init__() # type: ignore[no-untyped-call]
|
||||
self.root_dir = root_dir
|
||||
self.seed = seed
|
||||
self.batch_size = batch_size
|
||||
self.num_workers = num_workers
|
||||
self.api_key = api_key
|
||||
|
||||
def custom_transform(self, sample: Dict[str, Any]) -> Dict[str, Any]:
|
||||
"""Transform a single sample from the Dataset.
|
||||
|
||||
Args:
|
||||
sample: dictionary containing image and target
|
||||
|
||||
Returns:
|
||||
preprocessed sample
|
||||
"""
|
||||
sample["image"] = sample["image"] / 255.0 # scale to [0,1]
|
||||
sample["image"] = (
|
||||
sample["image"].unsqueeze(0).repeat(3, 1, 1)
|
||||
) # convert to 3 channel
|
||||
sample["label"] = torch.as_tensor( # type: ignore[attr-defined]
|
||||
sample["label"]
|
||||
).float()
|
||||
|
||||
return sample
|
||||
|
||||
def prepare_data(self) -> None:
|
||||
"""Initialize the main ``Dataset`` objects for use in :func:`setup`.
|
||||
|
||||
This includes optionally downloading the dataset. This is done once per node,
|
||||
while :func:`setup` is done once per GPU.
|
||||
"""
|
||||
TropicalCycloneWindEstimation(
|
||||
self.root_dir,
|
||||
split="train",
|
||||
transforms=self.custom_transform,
|
||||
download=self.api_key is not None,
|
||||
api_key=self.api_key,
|
||||
)
|
||||
|
||||
def setup(self, stage: Optional[str] = None) -> None:
    """Create the train/val/test splits based on the original Dataset objects.

    The splits should be done here vs. in :func:`__init__` per the docs:
    https://pytorch-lightning.readthedocs.io/en/latest/extensions/datamodules.html#setup.

    Train and val are separated by the ``storm_id`` property, i.e. every
    sample that shares a ``storm_id`` lands in exactly one of the two splits.
    This is important to test one type of generalizability -- given a new
    storm, can we predict its windspeed. The test set, however, contains
    *some* storms from the training set (specifically, the latter parts of the
    storms) as well as some novel storms.

    Args:
        stage: stage to set up
    """
    self.all_train_dataset = TropicalCycloneWindEstimation(
        self.root_dir,
        split="train",
        transforms=self.custom_transform,
        download=False,
    )
    self.all_test_dataset = TropicalCycloneWindEstimation(
        self.root_dir,
        split="test",
        transforms=self.custom_transform,
        download=False,
    )

    # The storm id is the second-to-last "_"-separated token of the first
    # path component of each item's href.
    storm_ids = [
        item["href"].split("/")[0].split("_")[-2]
        for item in self.all_train_dataset.collection
    ]

    # Only the first of the generated group-aware splits is consumed.
    splitter = GroupShuffleSplit(
        test_size=0.2, n_splits=2, random_state=self.seed
    )
    train_indices, val_indices = next(splitter.split(storm_ids, groups=storm_ids))

    self.train_dataset = Subset(self.all_train_dataset, train_indices)
    self.val_dataset = Subset(self.all_train_dataset, val_indices)

    every_test_index = range(len(self.all_test_dataset))
    self.test_dataset = Subset(self.all_test_dataset, every_test_index)
|
||||
|
||||
def train_dataloader(self) -> DataLoader[Any]:
    """Build the shuffled DataLoader over the training subset.

    Returns:
        training data loader
    """
    loader: DataLoader[Any] = DataLoader(
        self.train_dataset,
        shuffle=True,
        num_workers=self.num_workers,
        batch_size=self.batch_size,
    )
    return loader
|
||||
|
||||
def val_dataloader(self) -> DataLoader[Any]:
    """Build the unshuffled DataLoader over the validation subset.

    Returns:
        validation data loader
    """
    loader: DataLoader[Any] = DataLoader(
        self.val_dataset,
        shuffle=False,
        num_workers=self.num_workers,
        batch_size=self.batch_size,
    )
    return loader
|
||||
|
||||
def test_dataloader(self) -> DataLoader[Any]:
|
||||
"""Return a DataLoader for testing.
|
||||
|
||||
Returns:
|
||||
testing data loader
|
||||
"""
|
||||
return DataLoader(
|
||||
self.test_dataset,
|
||||
batch_size=self.batch_size,
|
||||
num_workers=self.num_workers,
|
||||
shuffle=False,
|
||||
)
|
||||
|
|
|
@ -5,16 +5,13 @@
|
|||
|
||||
import glob
|
||||
import os
|
||||
from typing import Any, Callable, Dict, List, Optional
|
||||
from typing import Callable, Dict, List, Optional
|
||||
|
||||
import matplotlib.pyplot as plt
|
||||
import numpy as np
|
||||
import pytorch_lightning as pl
|
||||
import torch
|
||||
from PIL import Image
|
||||
from torch import Generator, Tensor # type: ignore[attr-defined]
|
||||
from torch.utils.data import DataLoader, random_split
|
||||
from torchvision.transforms import Normalize
|
||||
from torch import Tensor
|
||||
|
||||
from .geo import VisionDataset
|
||||
from .utils import download_and_extract_archive
|
||||
|
@ -320,140 +317,3 @@ class ETCI2021(VisionDataset):
|
|||
if suptitle is not None:
|
||||
plt.suptitle(suptitle)
|
||||
return fig
|
||||
|
||||
|
||||
class ETCI2021DataModule(pl.LightningDataModule):
    """LightningDataModule implementation for the ETCI2021 dataset.

    The dataset's train split is randomly divided 80/20 into train/val, and the
    dataset's own val split is used as the test data.

    .. versionadded:: 0.2
    """

    # Per-channel statistics handed to ``Normalize``. The last entry (mean 0,
    # std 1) is an identity normalization; ``preprocess`` appends the water
    # mask as the last channel of the image.
    band_means = torch.tensor(  # type: ignore[attr-defined]
        [0.52253931, 0.52253931, 0.52253931, 0.61221701, 0.61221701, 0.61221701, 0]
    )

    band_stds = torch.tensor(  # type: ignore[attr-defined]
        [0.35221376, 0.35221376, 0.35221376, 0.37364622, 0.37364622, 0.37364622, 1]
    )

    def __init__(
        self,
        root_dir: str,
        seed: int = 0,
        batch_size: int = 64,
        num_workers: int = 0,
        **kwargs: Any,
    ) -> None:
        """Initialize a LightningDataModule for ETCI2021 based DataLoaders.

        Args:
            root_dir: The ``root`` argument to pass to the ETCI2021 Dataset classes
            seed: The seed value to use when doing the dataset random_split
            batch_size: The batch size to use in all created DataLoaders
            num_workers: The number of workers to use in all created DataLoaders
            **kwargs: Accepted for call-site compatibility; not used here
        """
        super().__init__()  # type: ignore[no-untyped-call]
        self.root_dir = root_dir
        self.seed = seed
        self.batch_size = batch_size
        self.num_workers = num_workers

        self.norm = Normalize(self.band_means, self.band_stds)

    def preprocess(self, sample: Dict[str, Any]) -> Dict[str, Any]:
        """Transform a single sample from the Dataset.

        Notably, moves the given water mask to act as an input layer: the first
        mask plane is appended to the image, and the second mask plane,
        binarized with ``> 0``, becomes the new ``"mask"``.

        Args:
            sample: input image dictionary

        Returns:
            preprocessed sample
        """
        masks = sample["mask"]
        water_mask = masks[0].unsqueeze(0)
        flood_mask = (masks[1] > 0).long()

        stacked = torch.cat(  # type: ignore[attr-defined]
            [sample["image"], water_mask], dim=0
        ).float()
        stacked /= 255.0

        sample["image"] = self.norm(stacked)
        sample["mask"] = flood_mask
        return sample

    def prepare_data(self) -> None:
        """Make sure that the dataset is downloaded.

        This method is only called once per run.
        """
        ETCI2021(self.root_dir, checksum=False)

    def setup(self, stage: Optional[str] = None) -> None:
        """Initialize the main ``Dataset`` objects.

        This method is called once per GPU per run.

        Args:
            stage: stage to set up
        """
        train_val_dataset = ETCI2021(
            self.root_dir, split="train", transforms=self.preprocess
        )
        self.test_dataset = ETCI2021(
            self.root_dir, split="val", transforms=self.preprocess
        )

        # 80% (rounded down) goes to train, the remainder to val.
        total = len(train_val_dataset)
        n_train = int(0.8 * total)
        lengths = [n_train, total - n_train]

        rng = Generator().manual_seed(self.seed)
        self.train_dataset, self.val_dataset = random_split(
            train_val_dataset, lengths, generator=rng
        )

    def train_dataloader(self) -> DataLoader[Any]:
        """Build the shuffled DataLoader over the training data.

        Returns:
            training data loader
        """
        loader: DataLoader[Any] = DataLoader(
            self.train_dataset,
            shuffle=True,
            num_workers=self.num_workers,
            batch_size=self.batch_size,
        )
        return loader

    def val_dataloader(self) -> DataLoader[Any]:
        """Build the unshuffled DataLoader over the validation data.

        Returns:
            validation data loader
        """
        loader: DataLoader[Any] = DataLoader(
            self.val_dataset,
            shuffle=False,
            num_workers=self.num_workers,
            batch_size=self.batch_size,
        )
        return loader

    def test_dataloader(self) -> DataLoader[Any]:
        """Build the unshuffled DataLoader over the test data.

        Returns:
            testing data loader
        """
        loader: DataLoader[Any] = DataLoader(
            self.test_dataset,
            shuffle=False,
            num_workers=self.num_workers,
            batch_size=self.batch_size,
        )
        return loader
|
||||
|
|
|
@ -4,15 +4,11 @@
|
|||
"""EuroSAT dataset."""
|
||||
|
||||
import os
|
||||
from typing import Any, Callable, Dict, Optional, cast
|
||||
from typing import Callable, Dict, Optional, cast
|
||||
|
||||
import matplotlib.pyplot as plt
|
||||
import numpy as np
|
||||
import pytorch_lightning as pl
|
||||
import torch
|
||||
from torch import Tensor
|
||||
from torch.utils.data import DataLoader
|
||||
from torchvision.transforms import Compose, Normalize
|
||||
|
||||
from .geo import VisionClassificationDataset
|
||||
from .utils import check_integrity, download_url, extract_archive, rasterio_loader
|
||||
|
@ -229,138 +225,3 @@ class EuroSAT(VisionClassificationDataset):
|
|||
if suptitle is not None:
|
||||
plt.suptitle(suptitle)
|
||||
return fig
|
||||
|
||||
|
||||
class EuroSATDataModule(pl.LightningDataModule):
    """LightningDataModule implementation for the EuroSAT dataset.

    Uses the train/val/test splits from the dataset.

    .. versionadded:: 0.2
    """

    # Per-band statistics handed to ``Normalize`` (13 entries each).
    band_means = torch.tensor(  # type: ignore[attr-defined]
        [
            1354.40546513,
            1118.24399958,
            1042.92983953,
            947.62620298,
            1199.47283961,
            1999.79090914,
            2369.22292565,
            2296.82608323,
            732.08340178,
            12.11327804,
            1819.01027855,
            1118.92391149,
            2594.14080798,
        ]
    )

    band_stds = torch.tensor(  # type: ignore[attr-defined]
        [
            245.71762908,
            333.00778264,
            395.09249139,
            593.75055589,
            566.4170017,
            861.18399006,
            1086.63139075,
            1117.98170791,
            404.91978886,
            4.77584468,
            1002.58768311,
            761.30323499,
            1231.58581042,
        ]
    )

    def __init__(
        self, root_dir: str, batch_size: int = 64, num_workers: int = 0, **kwargs: Any
    ) -> None:
        """Initialize a LightningDataModule for EuroSAT based DataLoaders.

        Args:
            root_dir: The ``root`` argument to pass to the EuroSAT Dataset classes
            batch_size: The batch size to use in all created DataLoaders
            num_workers: The number of workers to use in all created DataLoaders
            **kwargs: Accepted for call-site compatibility; not used here
        """
        super().__init__()  # type: ignore[no-untyped-call]
        self.root_dir = root_dir
        self.batch_size = batch_size
        self.num_workers = num_workers

        self.norm = Normalize(self.band_means, self.band_stds)

    def preprocess(self, sample: Dict[str, Any]) -> Dict[str, Any]:
        """Cast the image to float and normalize it with the band statistics.

        Args:
            sample: input image dictionary

        Returns:
            preprocessed sample
        """
        as_float = sample["image"].float()
        sample["image"] = self.norm(as_float)
        return sample

    def prepare_data(self) -> None:
        """Make sure that the dataset is downloaded.

        This method is only called once per run.
        """
        EuroSAT(self.root_dir)

    def setup(self, stage: Optional[str] = None) -> None:
        """Initialize the main ``Dataset`` objects.

        This method is called once per GPU per run.

        Args:
            stage: stage to set up
        """
        pipeline = Compose([self.preprocess])

        self.train_dataset = EuroSAT(self.root_dir, "train", transforms=pipeline)
        self.val_dataset = EuroSAT(self.root_dir, "val", transforms=pipeline)
        self.test_dataset = EuroSAT(self.root_dir, "test", transforms=pipeline)

    def train_dataloader(self) -> DataLoader[Any]:
        """Build the shuffled DataLoader over the training split.

        Returns:
            training data loader
        """
        loader: DataLoader[Any] = DataLoader(
            self.train_dataset,
            shuffle=True,
            num_workers=self.num_workers,
            batch_size=self.batch_size,
        )
        return loader

    def val_dataloader(self) -> DataLoader[Any]:
        """Build the unshuffled DataLoader over the validation split.

        Returns:
            validation data loader
        """
        loader: DataLoader[Any] = DataLoader(
            self.val_dataset,
            shuffle=False,
            num_workers=self.num_workers,
            batch_size=self.batch_size,
        )
        return loader

    def test_dataloader(self) -> DataLoader[Any]:
        """Build the unshuffled DataLoader over the test split.

        Returns:
            testing data loader
        """
        loader: DataLoader[Any] = DataLoader(
            self.test_dataset,
            shuffle=False,
            num_workers=self.num_workers,
            batch_size=self.batch_size,
        )
        return loader
|
||||
|
|
|
@ -11,33 +11,12 @@ from xml.etree import ElementTree
|
|||
import matplotlib.patches as patches
|
||||
import matplotlib.pyplot as plt
|
||||
import numpy as np
|
||||
import pytorch_lightning as pl
|
||||
import torch
|
||||
from PIL import Image
|
||||
from torch import Tensor
|
||||
from torch.utils.data import DataLoader
|
||||
from torchvision.transforms import Compose
|
||||
|
||||
from ..datasets.utils import check_integrity, dataset_split, extract_archive
|
||||
from .geo import VisionDataset
|
||||
|
||||
# https://github.com/pytorch/pytorch/issues/60979
|
||||
# https://github.com/pytorch/pytorch/pull/61045
|
||||
DataLoader.__module__ = "torch.utils.data"
|
||||
|
||||
|
||||
def collate_fn(batch: List[Dict[str, Tensor]]) -> Dict[str, Any]:
    """Custom object detection collate fn to handle variable number of boxes.

    Images are stacked into one tensor along a new leading axis, while the
    per-sample boxes are kept as a plain list.

    Args:
        batch: list of sample dicts return by dataset

    Returns:
        batch dict output
    """
    images = torch.stack([item["image"] for item in batch])
    boxes = [item["boxes"] for item in batch]
    collated: Dict[str, Any] = {"image": images, "boxes": boxes}
    return collated
|
||||
from .utils import check_integrity, extract_archive
|
||||
|
||||
|
||||
def parse_pascal_voc(path: str) -> Dict[str, Any]:
|
||||
|
@ -350,102 +329,3 @@ class FAIR1M(VisionDataset):
|
|||
plt.suptitle(suptitle)
|
||||
|
||||
return fig
|
||||
|
||||
|
||||
class FAIR1MDataModule(pl.LightningDataModule):
    """LightningDataModule implementation for the FAIR1M dataset."""

    def __init__(
        self,
        root_dir: str,
        batch_size: int = 64,
        num_workers: int = 0,
        val_split_pct: float = 0.2,
        test_split_pct: float = 0.2,
        **kwargs: Any,
    ) -> None:
        """Initialize a LightningDataModule for FAIR1M based DataLoaders.

        Args:
            root_dir: The ``root`` argument to pass to the FAIR1M Dataset classes
            batch_size: The batch size to use in all created DataLoaders
            num_workers: The number of workers to use in all created DataLoaders
            val_split_pct: What percentage of the dataset to use as a validation set
            test_split_pct: What percentage of the dataset to use as a test set
            **kwargs: Accepted for call-site compatibility; not used here
        """
        super().__init__()  # type: ignore[no-untyped-call]
        self.root_dir = root_dir
        self.batch_size = batch_size
        self.num_workers = num_workers
        self.val_split_pct = val_split_pct
        self.test_split_pct = test_split_pct

    def preprocess(self, sample: Dict[str, Any]) -> Dict[str, Any]:
        """Cast the image to float and divide it by 255.

        Args:
            sample: input image dictionary

        Returns:
            preprocessed sample
        """
        image = sample["image"].float()
        sample["image"] = image
        # In-place division, so an already-float input tensor is modified too.
        image /= 255.0
        return sample

    def setup(self, stage: Optional[str] = None) -> None:
        """Initialize the main ``Dataset`` objects.

        This method is called once per GPU per run.

        Args:
            stage: stage to set up
        """
        full_dataset = FAIR1M(self.root_dir, transforms=Compose([self.preprocess]))
        splits = dataset_split(
            full_dataset, val_pct=self.val_split_pct, test_pct=self.test_split_pct
        )
        self.train_dataset, self.val_dataset, self.test_dataset = splits

    def train_dataloader(self) -> DataLoader[Any]:
        """Build the shuffled DataLoader over the training split.

        Returns:
            training data loader
        """
        loader: DataLoader[Any] = DataLoader(
            self.train_dataset,
            shuffle=True,
            collate_fn=collate_fn,
            num_workers=self.num_workers,
            batch_size=self.batch_size,
        )
        return loader

    def val_dataloader(self) -> DataLoader[Any]:
        """Build the unshuffled DataLoader over the validation split.

        Returns:
            validation data loader
        """
        loader: DataLoader[Any] = DataLoader(
            self.val_dataset,
            shuffle=False,
            collate_fn=collate_fn,
            num_workers=self.num_workers,
            batch_size=self.batch_size,
        )
        return loader

    def test_dataloader(self) -> DataLoader[Any]:
        """Build the unshuffled DataLoader over the test split.

        Returns:
            testing data loader
        """
        loader: DataLoader[Any] = DataLoader(
            self.test_dataset,
            shuffle=False,
            collate_fn=collate_fn,
            num_workers=self.num_workers,
            batch_size=self.batch_size,
        )
        return loader
|
||||
|
|
|
@ -6,24 +6,18 @@
|
|||
import hashlib
|
||||
import os
|
||||
from functools import lru_cache
|
||||
from typing import Any, Callable, Dict, Optional
|
||||
from typing import Callable, Dict, Optional
|
||||
|
||||
import matplotlib.pyplot as plt
|
||||
import numpy as np
|
||||
import pytorch_lightning as pl
|
||||
import torch
|
||||
from matplotlib.colors import ListedColormap
|
||||
from PIL import Image
|
||||
from torch import Tensor
|
||||
from torch.utils.data import DataLoader
|
||||
|
||||
from .geo import VisionDataset
|
||||
from .utils import check_integrity, download_and_extract_archive, working_dir
|
||||
|
||||
# https://github.com/pytorch/pytorch/issues/60979
|
||||
# https://github.com/pytorch/pytorch/pull/61045
|
||||
DataLoader.__module__ = "torch.utils.data"
|
||||
|
||||
|
||||
class LandCoverAI(VisionDataset):
|
||||
r"""LandCover.ai dataset.
|
||||
|
@ -266,110 +260,3 @@ class LandCoverAI(VisionDataset):
|
|||
if suptitle is not None:
|
||||
plt.suptitle(suptitle)
|
||||
return fig
|
||||
|
||||
|
||||
class LandCoverAIDataModule(pl.LightningDataModule):
    """LightningDataModule implementation for the LandCover.ai dataset.

    Uses the train/val/test splits from the dataset.
    """

    def __init__(
        self, root_dir: str, batch_size: int = 64, num_workers: int = 0, **kwargs: Any
    ) -> None:
        """Initialize a LightningDataModule for LandCover.ai based DataLoaders.

        Args:
            root_dir: The ``root`` argument to pass to the Landcover.AI Dataset classes
            batch_size: The batch size to use in all created DataLoaders
            num_workers: The number of workers to use in all created DataLoaders
            **kwargs: Accepted for call-site compatibility; not used here
        """
        super().__init__()  # type: ignore[no-untyped-call]
        self.root_dir = root_dir
        self.batch_size = batch_size
        self.num_workers = num_workers

    def preprocess(self, sample: Dict[str, Any]) -> Dict[str, Any]:
        """Transform a single sample from the Dataset.

        The image is divided by 255 and cast to float. The mask is cast to
        float, given a new leading axis, and shifted up by one.

        Args:
            sample: dictionary containing image and mask

        Returns:
            preprocessed sample
        """
        scaled = sample["image"] / 255.0
        sample["image"] = scaled.float()

        mask = sample["mask"].float().unsqueeze(0)
        sample["mask"] = mask + 1

        return sample

    def prepare_data(self) -> None:
        """Make sure that the dataset is downloaded.

        This method is only called once per run.
        """
        _ = LandCoverAI(self.root_dir, download=False, checksum=False)

    def setup(self, stage: Optional[str] = None) -> None:
        """Initialize the main ``Dataset`` objects.

        This method is called once per GPU per run.

        Args:
            stage: stage to set up
        """
        # Train and val/test currently share the same transform.
        train_transforms = val_test_transforms = self.preprocess
        root = self.root_dir

        self.train_dataset = LandCoverAI(
            root, split="train", transforms=train_transforms
        )
        self.val_dataset = LandCoverAI(
            root, split="val", transforms=val_test_transforms
        )
        self.test_dataset = LandCoverAI(
            root, split="test", transforms=val_test_transforms
        )

    def train_dataloader(self) -> DataLoader[Any]:
        """Build the shuffled DataLoader over the training split.

        Returns:
            training data loader
        """
        loader: DataLoader[Any] = DataLoader(
            self.train_dataset,
            shuffle=True,
            num_workers=self.num_workers,
            batch_size=self.batch_size,
        )
        return loader

    def val_dataloader(self) -> DataLoader[Any]:
        """Build the unshuffled DataLoader over the validation split.

        Returns:
            validation data loader
        """
        loader: DataLoader[Any] = DataLoader(
            self.val_dataset,
            shuffle=False,
            num_workers=self.num_workers,
            batch_size=self.batch_size,
        )
        return loader

    def test_dataloader(self) -> DataLoader[Any]:
        """Build the unshuffled DataLoader over the test split.

        Returns:
            testing data loader
        """
        loader: DataLoader[Any] = DataLoader(
            self.test_dataset,
            shuffle=False,
            num_workers=self.num_workers,
            batch_size=self.batch_size,
        )
        return loader
|
||||
|
|
|
@ -5,23 +5,17 @@
|
|||
|
||||
import glob
|
||||
import os
|
||||
from typing import Any, Callable, Dict, List, Optional
|
||||
from typing import Callable, Dict, List, Optional
|
||||
|
||||
import matplotlib.pyplot as plt
|
||||
import numpy as np
|
||||
import pytorch_lightning as pl
|
||||
import torch
|
||||
from PIL import Image
|
||||
from torch import Tensor
|
||||
from torch.utils.data import DataLoader
|
||||
|
||||
from .geo import VisionDataset
|
||||
from .utils import download_and_extract_archive
|
||||
|
||||
# https://github.com/pytorch/pytorch/issues/60979
|
||||
# https://github.com/pytorch/pytorch/pull/61045
|
||||
DataLoader.__module__ = "torch.utils.data"
|
||||
|
||||
|
||||
class LoveDA(VisionDataset):
|
||||
"""LoveDA dataset.
|
||||
|
@ -305,117 +299,3 @@ class LoveDA(VisionDataset):
|
|||
plt.suptitle(suptitle)
|
||||
|
||||
return fig
|
||||
|
||||
|
||||
class LoveDADataModule(pl.LightningDataModule):
    """LightningDataModule implementation for the LoveDA dataset.

    Uses the train/val/test splits from the dataset.
    """

    def __init__(
        self,
        root_dir: str,
        scene: List[str],
        batch_size: int = 32,
        num_workers: int = 0,
        **kwargs: Any,
    ) -> None:
        """Initialize a LightningDataModule for LoveDA based DataLoaders.

        Args:
            root_dir: The ``root`` argument to pass to LoveDA Dataset classes
            scene: specify whether to load only 'urban', only 'rural' or both
            batch_size: The batch size to use in all created DataLoaders
            num_workers: The number of workers to use in all created DataLoaders
            **kwargs: Accepted for call-site compatibility; not used here
        """
        super().__init__()  # type: ignore[no-untyped-call]
        self.root_dir = root_dir
        self.scene = scene
        self.batch_size = batch_size
        self.num_workers = num_workers

    def preprocess(self, sample: Dict[str, Any]) -> Dict[str, Any]:
        """Divide the image of a single sample by 255.

        Args:
            sample: dictionary containing image and mask

        Returns:
            preprocessed sample
        """
        scaled = sample["image"] / 255.0
        sample["image"] = scaled

        return sample

    def prepare_data(self) -> None:
        """Make sure that the dataset is downloaded.

        This method is only called once per run.
        """
        _ = LoveDA(self.root_dir, scene=self.scene, download=False, checksum=False)

    def setup(self, stage: Optional[str] = None) -> None:
        """Initialize the main ``Dataset`` objects.

        This method is called once per GPU per run.

        Args:
            stage: stage to set up
        """
        # Train and val/test currently share the same transform.
        train_transforms = val_test_transforms = self.preprocess
        root = self.root_dir

        self.train_dataset = LoveDA(
            root, split="train", scene=self.scene, transforms=train_transforms
        )
        self.val_dataset = LoveDA(
            root, split="val", scene=self.scene, transforms=val_test_transforms
        )
        self.test_dataset = LoveDA(
            root, split="test", scene=self.scene, transforms=val_test_transforms
        )

    def train_dataloader(self) -> DataLoader[Any]:
        """Build the shuffled DataLoader over the training split.

        Returns:
            training data loader
        """
        loader: DataLoader[Any] = DataLoader(
            self.train_dataset,
            shuffle=True,
            num_workers=self.num_workers,
            batch_size=self.batch_size,
        )
        return loader

    def val_dataloader(self) -> DataLoader[Any]:
        """Build the unshuffled DataLoader over the validation split.

        Returns:
            validation data loader
        """
        loader: DataLoader[Any] = DataLoader(
            self.val_dataset,
            shuffle=False,
            num_workers=self.num_workers,
            batch_size=self.batch_size,
        )
        return loader

    def test_dataloader(self) -> DataLoader[Any]:
        """Build the unshuffled DataLoader over the test split.

        Returns:
            testing data loader
        """
        loader: DataLoader[Any] = DataLoader(
            self.test_dataset,
            shuffle=False,
            num_workers=self.num_workers,
            batch_size=self.batch_size,
        )
        return loader
|
||||
|
|
|
@ -3,20 +3,7 @@
|
|||
|
||||
"""National Agriculture Imagery Program (NAIP) dataset."""
|
||||
|
||||
from typing import Any, Dict, Optional
|
||||
|
||||
import pytorch_lightning as pl
|
||||
from torch.utils.data import DataLoader
|
||||
|
||||
from ..samplers.batch import RandomBatchGeoSampler
|
||||
from ..samplers.single import GridGeoSampler
|
||||
from .chesapeake import Chesapeake13
|
||||
from .geo import RasterDataset
|
||||
from .utils import BoundingBox, stack_samples
|
||||
|
||||
# https://github.com/pytorch/pytorch/issues/60979
|
||||
# https://github.com/pytorch/pytorch/pull/61045
|
||||
DataLoader.__module__ = "torch.utils.data"
|
||||
|
||||
|
||||
class NAIP(RasterDataset):
|
||||
|
@ -55,147 +42,3 @@ class NAIP(RasterDataset):
|
|||
# Plotting
|
||||
all_bands = ["R", "G", "B", "NIR"]
|
||||
rgb_bands = ["R", "G", "B"]
|
||||
|
||||
|
||||
class NAIPChesapeakeDataModule(pl.LightningDataModule):
    """LightningDataModule implementation for the NAIP and Chesapeake datasets.

    The intersection of the two datasets is split spatially: the left half of
    the bounding box is used for training, the bottom-right quadrant for
    validation and the top-right quadrant for testing.
    """

    # TODO: tune these hyperparams
    length = 1000
    stride = 128

    def __init__(
        self,
        naip_root_dir: str,
        chesapeake_root_dir: str,
        batch_size: int = 64,
        num_workers: int = 0,
        patch_size: int = 256,
        **kwargs: Any,
    ) -> None:
        """Initialize a LightningDataModule for NAIP and Chesapeake based DataLoaders.

        Args:
            naip_root_dir: directory containing NAIP data
            chesapeake_root_dir: directory containing Chesapeake data
            batch_size: The batch size to use in all created DataLoaders
            num_workers: The number of workers to use in all created DataLoaders
            patch_size: size of patches to sample
            **kwargs: Accepted for call-site compatibility; not used here
        """
        super().__init__()  # type: ignore[no-untyped-call]
        self.naip_root_dir = naip_root_dir
        self.chesapeake_root_dir = chesapeake_root_dir
        self.batch_size = batch_size
        self.num_workers = num_workers
        self.patch_size = patch_size

    def naip_transform(self, sample: Dict[str, Any]) -> Dict[str, Any]:
        """Transform a single sample from the NAIP Dataset.

        The image is divided by 255 and cast to float.

        Args:
            sample: NAIP image dictionary

        Returns:
            preprocessed NAIP data
        """
        sample["image"] = sample["image"] / 255.0
        sample["image"] = sample["image"].float()
        return sample

    def chesapeake_transform(self, sample: Dict[str, Any]) -> Dict[str, Any]:
        """Transform a single sample from the Chesapeake Dataset.

        The mask is cast to long and its first plane is kept.

        Args:
            sample: Chesapeake mask dictionary

        Returns:
            preprocessed Chesapeake data
        """
        sample["mask"] = sample["mask"].long()[0]
        return sample

    def prepare_data(self) -> None:
        """Make sure that the dataset is downloaded.

        This method is only called once per run.
        """
        Chesapeake13(self.chesapeake_root_dir, download=False, checksum=False)

    def setup(self, stage: Optional[str] = None) -> None:
        """Initialize the main ``Dataset`` objects.

        This method is called once per GPU per run.

        Args:
            stage: stage to set up
        """
        # TODO: these transforms will be applied independently, this won't work if we
        # add things like random horizontal flip
        chesapeake = Chesapeake13(
            self.chesapeake_root_dir, transforms=self.chesapeake_transform
        )
        naip = NAIP(
            self.naip_root_dir,
            chesapeake.crs,
            chesapeake.res,
            transforms=self.naip_transform,
        )
        self.dataset = chesapeake & naip

        # TODO: figure out better train/val/test split
        roi = self.dataset.bounds
        midx = roi.minx + (roi.maxx - roi.minx) / 2
        midy = roi.miny + (roi.maxy - roi.miny) / 2

        # Left half -> train, bottom-right quadrant -> val, top-right
        # quadrant -> test. The test region previously spanned the full x-range
        # of the top half, which overlapped the training region in the top-left
        # quadrant and leaked training area into the test set.
        train_roi = BoundingBox(roi.minx, midx, roi.miny, roi.maxy, roi.mint, roi.maxt)
        val_roi = BoundingBox(midx, roi.maxx, roi.miny, midy, roi.mint, roi.maxt)
        test_roi = BoundingBox(midx, roi.maxx, midy, roi.maxy, roi.mint, roi.maxt)

        self.train_sampler = RandomBatchGeoSampler(
            naip, self.patch_size, self.batch_size, self.length, train_roi
        )
        self.val_sampler = GridGeoSampler(naip, self.patch_size, self.stride, val_roi)
        self.test_sampler = GridGeoSampler(naip, self.patch_size, self.stride, test_roi)

    def train_dataloader(self) -> DataLoader[Any]:
        """Return a DataLoader for training.

        Batching is delegated to the random batch sampler created in
        :meth:`setup`, so no ``batch_size`` is passed here.

        Returns:
            training data loader
        """
        return DataLoader(
            self.dataset,
            batch_sampler=self.train_sampler,
            num_workers=self.num_workers,
            collate_fn=stack_samples,
        )

    def val_dataloader(self) -> DataLoader[Any]:
        """Return a DataLoader for validation.

        Returns:
            validation data loader
        """
        return DataLoader(
            self.dataset,
            batch_size=self.batch_size,
            sampler=self.val_sampler,
            num_workers=self.num_workers,
            collate_fn=stack_samples,
        )

    def test_dataloader(self) -> DataLoader[Any]:
        """Return a DataLoader for testing.

        Returns:
            testing data loader
        """
        return DataLoader(
            self.dataset,
            batch_size=self.batch_size,
            sampler=self.test_sampler,
            num_workers=self.num_workers,
            collate_fn=stack_samples,
        )
|
||||
|
|
|
@ -4,39 +4,17 @@
|
|||
"""NASA Marine Debris dataset."""
|
||||
|
||||
import os
|
||||
from typing import Any, Callable, Dict, List, Optional
|
||||
from typing import Callable, Dict, List, Optional
|
||||
|
||||
import matplotlib.pyplot as plt
|
||||
import numpy as np
|
||||
import pytorch_lightning as pl
|
||||
import rasterio
|
||||
import torch
|
||||
from torch import Tensor
|
||||
from torch.utils.data import DataLoader
|
||||
from torchvision.transforms import Compose
|
||||
from torchvision.utils import draw_bounding_boxes
|
||||
|
||||
from .geo import VisionDataset
|
||||
from .utils import dataset_split, download_radiant_mlhub_dataset, extract_archive
|
||||
|
||||
# https://github.com/pytorch/pytorch/issues/60979
|
||||
# https://github.com/pytorch/pytorch/pull/61045
|
||||
DataLoader.__module__ = "torch.utils.data"
|
||||
|
||||
|
||||
def collate_fn(batch: List[Dict[str, Tensor]]) -> Dict[str, Any]:
    """Custom object detection collate fn to handle variable boxes.

    Images are stacked into one tensor along a new leading axis, while the
    per-sample boxes are kept as a plain list.

    Args:
        batch: list of sample dicts return by dataset

    Returns:
        batch dict output
    """
    images = torch.stack([item["image"] for item in batch])
    boxes = [item["boxes"] for item in batch]
    collated: Dict[str, Any] = {"image": images, "boxes": boxes}
    return collated
|
||||
from .utils import download_radiant_mlhub_dataset, extract_archive
|
||||
|
||||
|
||||
class NASAMarineDebris(VisionDataset):
|
||||
|
@ -279,109 +257,3 @@ class NASAMarineDebris(VisionDataset):
|
|||
plt.suptitle(suptitle)
|
||||
|
||||
return fig
|
||||
|
||||
|
||||
class NASAMarineDebrisDataModule(pl.LightningDataModule):
    """LightningDataModule for the NASA Marine Debris dataset.

    One ``NASAMarineDebris`` dataset is divided into train, validation and test
    subsets with ``dataset_split``; every loader batches with ``collate_fn``.
    """

    def __init__(
        self,
        root_dir: str,
        batch_size: int = 64,
        num_workers: int = 0,
        val_split_pct: float = 0.2,
        test_split_pct: float = 0.2,
        **kwargs: Any,
    ) -> None:
        """Store the configuration used to build the dataset and its loaders.

        Args:
            root_dir: The ``root`` argument to pass to the Dataset class
            batch_size: The batch size to use in all created DataLoaders
            num_workers: The number of workers to use in all created DataLoaders
            val_split_pct: What percentage of the dataset to use as a validation set
            test_split_pct: What percentage of the dataset to use as a test set
            **kwargs: accepted for call compatibility; not used by this class
        """
        super().__init__()  # type: ignore[no-untyped-call]

        # Settings shared by the train/val/test DataLoaders.
        self.batch_size = batch_size
        self.num_workers = num_workers

        # Dataset location and split proportions, consumed in setup().
        self.root_dir = root_dir
        self.val_split_pct = val_split_pct
        self.test_split_pct = test_split_pct

    def preprocess(self, sample: Dict[str, Any]) -> Dict[str, Any]:
        """Convert the image of one sample to float and divide it by 255.

        Args:
            sample: input image dictionary

        Returns:
            the same dictionary with its ``"image"`` entry replaced
        """
        image = sample["image"].float()
        image /= 255.0
        sample["image"] = image
        return sample

    def prepare_data(self) -> None:
        """Make sure that the dataset is downloaded.

        Constructs the dataset once with ``checksum=False`` and discards the
        instance. This method is only called once per run.
        """
        NASAMarineDebris(self.root_dir, checksum=False)

    def setup(self, stage: Optional[str] = None) -> None:
        """Create the dataset and split it into train/val/test subsets.

        This method is called once per GPU per run.

        Args:
            stage: stage to set up
        """
        full_dataset = NASAMarineDebris(
            self.root_dir, transforms=Compose([self.preprocess])
        )
        splits = dataset_split(
            full_dataset, val_pct=self.val_split_pct, test_pct=self.test_split_pct
        )
        self.train_dataset, self.val_dataset, self.test_dataset = splits

    def train_dataloader(self) -> DataLoader[Any]:
        """Build the shuffled DataLoader over the training subset.

        Returns:
            training data loader
        """
        train_loader: DataLoader[Any] = DataLoader(
            self.train_dataset,
            batch_size=self.batch_size,
            num_workers=self.num_workers,
            shuffle=True,
            collate_fn=collate_fn,
        )
        return train_loader

    def val_dataloader(self) -> DataLoader[Any]:
        """Build the unshuffled DataLoader over the validation subset.

        Returns:
            validation data loader
        """
        val_loader: DataLoader[Any] = DataLoader(
            self.val_dataset,
            batch_size=self.batch_size,
            num_workers=self.num_workers,
            shuffle=False,
            collate_fn=collate_fn,
        )
        return val_loader

    def test_dataloader(self) -> DataLoader[Any]:
        """Build the unshuffled DataLoader over the test subset.

        Returns:
            testing data loader
        """
        test_loader: DataLoader[Any] = DataLoader(
            self.test_dataset,
            batch_size=self.batch_size,
            num_workers=self.num_workers,
            shuffle=False,
            collate_fn=collate_fn,
        )
        return test_loader
|
||||
|
|
|
@ -5,25 +5,23 @@
|
|||
|
||||
import glob
|
||||
import os
|
||||
from typing import Any, Callable, Dict, List, Optional, Sequence, Tuple, Union
|
||||
from typing import Callable, Dict, List, Optional, Sequence, Union
|
||||
|
||||
import kornia.augmentation as K
|
||||
import matplotlib.pyplot as plt
|
||||
import numpy as np
|
||||
import pytorch_lightning as pl
|
||||
import torch
|
||||
from einops import repeat
|
||||
from matplotlib.figure import Figure
|
||||
from numpy import ndarray as Array
|
||||
from PIL import Image
|
||||
from torch import Tensor
|
||||
from torch.utils.data import DataLoader
|
||||
from torch.utils.data._utils.collate import default_collate
|
||||
from torchvision.transforms import Compose, Normalize
|
||||
|
||||
from ..datasets.utils import dataset_split, draw_semantic_segmentation_masks
|
||||
from .geo import VisionDataset
|
||||
from .utils import download_url, extract_archive, sort_sentinel2_bands
|
||||
from .utils import (
|
||||
download_url,
|
||||
draw_semantic_segmentation_masks,
|
||||
extract_archive,
|
||||
sort_sentinel2_bands,
|
||||
)
|
||||
|
||||
|
||||
class OSCD(VisionDataset):
|
||||
|
@ -317,202 +315,3 @@ class OSCD(VisionDataset):
|
|||
plt.suptitle(suptitle)
|
||||
|
||||
return fig
|
||||
|
||||
|
||||
class OSCDDataModule(pl.LightningDataModule):
    """LightningDataModule implementation for the OSCD dataset.

    Uses the train/test splits from the dataset and further splits
    the train split into train/val splits.

    .. versionadded:: 0.2
    """

    band_means = torch.tensor(  # type: ignore[attr-defined]
        [
            1583.0741,
            1374.3202,
            1294.1616,
            1325.6158,
            1478.7408,
            1933.0822,
            2166.0608,
            2076.4868,
            2306.0652,
            690.9814,
            16.2360,
            2080.3347,
            1524.6930,
        ]
    )

    band_stds = torch.tensor(  # type: ignore[attr-defined]
        [
            52.1937,
            83.4168,
            105.6966,
            151.1401,
            147.4615,
            115.9289,
            123.1974,
            114.6483,
            141.4530,
            73.2758,
            4.8368,
            213.4821,
            179.4793,
        ]
    )

    def __init__(
        self,
        root_dir: str,
        bands: str = "all",
        train_batch_size: int = 32,
        num_workers: int = 0,
        val_split_pct: float = 0.2,
        patch_size: Tuple[int, int] = (64, 64),
        num_patches_per_tile: int = 32,
        **kwargs: Any,
    ) -> None:
        """Initialize a LightningDataModule for OSCD based DataLoaders.

        Args:
            root_dir: The ``root`` argument to pass to the OSCD Dataset classes
            bands: "rgb" or "all"
            train_batch_size: The batch size used in the train DataLoader
                (val_batch_size == test_batch_size == 1)
            num_workers: The number of workers to use in all created DataLoaders
            val_split_pct: What percentage of the dataset to use as a validation set
            patch_size: Size of random patch from image and mask (height, width)
            num_patches_per_tile: number of random patches per sample
            **kwargs: accepted for call compatibility; not used by this class
        """
        super().__init__()  # type: ignore[no-untyped-call]
        self.root_dir = root_dir
        self.bands = bands
        self.train_batch_size = train_batch_size
        self.num_workers = num_workers
        self.val_split_pct = val_split_pct
        self.patch_size = patch_size
        self.num_patches_per_tile = num_patches_per_tile

        # Pick the statistics for the requested bands and append two singleton
        # axes. Assigning on ``self`` shadows the class-level tensors rather
        # than modifying them, so other instances keep the full statistics.
        if bands == "rgb":
            self.band_means = self.band_means[[3, 2, 1], None, None]
            self.band_stds = self.band_stds[[3, 2, 1], None, None]
        else:
            self.band_means = self.band_means[:, None, None]
            self.band_stds = self.band_stds[:, None, None]

        self.norm = Normalize(self.band_means, self.band_stds)
        self.rcrop = K.AugmentationSequential(
            K.RandomCrop(patch_size), data_keys=["input", "mask"], same_on_batch=True
        )
        self.padto = K.PadTo((1280, 1280))

    def preprocess(self, sample: Dict[str, Any]) -> Dict[str, Any]:
        """Transform a single sample from the Dataset.

        The image is cast to float, normalized with the per-band statistics and
        then has its first two dimensions merged into one. The mask is left
        untouched.

        Args:
            sample: dictionary containing the image (and mask)

        Returns:
            the same dictionary with its ``"image"`` entry replaced
        """
        sample["image"] = sample["image"].float()
        sample["image"] = self.norm(sample["image"])
        sample["image"] = torch.flatten(  # type: ignore[attr-defined]
            sample["image"], 0, 1
        )
        return sample

    def prepare_data(self) -> None:
        """Make sure that the dataset is downloaded.

        This method is only called once per run.
        """
        OSCD(self.root_dir, split="train", bands=self.bands, checksum=False)

    def setup(self, stage: Optional[str] = None) -> None:
        """Initialize the main ``Dataset`` objects.

        This method is called once per GPU per run.

        Args:
            stage: stage to set up
        """

        def n_random_crop(sample: Dict[str, Any]) -> Dict[str, Any]:
            """Replace image and mask by ``num_patches_per_tile`` random crops."""
            images, masks = [], []
            # The crop input for the mask does not change between iterations:
            # the mask duplicated along a new leading axis and cast to float.
            mask_input = repeat(sample["mask"], "h w -> t h w", t=2).float()
            for _ in range(self.num_patches_per_tile):
                image, mask = self.rcrop(sample["image"], mask_input)
                # Only the first of the two duplicated mask copies is kept.
                mask = mask.squeeze()[0]
                images.append(image.squeeze())
                masks.append(mask.long())
            sample["image"] = torch.stack(images)
            sample["mask"] = torch.stack(masks)
            return sample

        def pad_to(sample: Dict[str, Any]) -> Dict[str, Any]:
            """Pad image and mask with ``self.padto`` and index the result."""
            sample["image"] = self.padto(sample["image"])[0]
            sample["mask"] = self.padto(sample["mask"].float()).long()[0, 0]
            return sample

        train_transforms = Compose([self.preprocess, n_random_crop])
        # for testing and validation we pad all inputs to a fixed size to avoid issues
        # with the upsampling paths in encoder-decoder architectures
        test_transforms = Compose([self.preprocess, pad_to])

        train_dataset = OSCD(
            self.root_dir, split="train", bands=self.bands, transforms=train_transforms
        )
        if self.val_split_pct > 0.0:
            val_dataset = OSCD(
                self.root_dir,
                split="train",
                bands=self.bands,
                transforms=test_transforms,
            )
            self.train_dataset, self.val_dataset, _ = dataset_split(
                train_dataset, val_pct=self.val_split_pct, test_pct=0.0
            )
            # Keep the validation indices chosen by the split, but read those
            # samples from the dataset built with the padding transforms
            # instead of the random-crop transforms.
            self.val_dataset.dataset = val_dataset
        else:
            self.train_dataset = train_dataset  # type: ignore[assignment]
            self.val_dataset = None  # type: ignore[assignment]

        self.test_dataset = OSCD(
            self.root_dir, split="test", bands=self.bands, transforms=test_transforms
        )

    def train_dataloader(self) -> DataLoader[Any]:
        """Return a DataLoader for training.

        Returns:
            training data loader
        """

        def collate_wrapper(batch: List[Dict[str, Any]]) -> Dict[str, Any]:
            """Collate with ``default_collate`` and merge the two leading dims.

            Each training sample holds a stack of patches, so after collation
            image and mask have a batch and a patch dimension in front; these
            are flattened into one.
            """
            r_batch: Dict[str, Any] = default_collate(  # type: ignore[no-untyped-call]
                batch
            )
            r_batch["image"] = torch.flatten(  # type: ignore[attr-defined]
                r_batch["image"], 0, 1
            )
            r_batch["mask"] = torch.flatten(  # type: ignore[attr-defined]
                r_batch["mask"], 0, 1
            )
            return r_batch

        return DataLoader(
            self.train_dataset,
            batch_size=self.train_batch_size,
            num_workers=self.num_workers,
            collate_fn=collate_wrapper,
            shuffle=True,
        )

    def val_dataloader(self) -> DataLoader[Any]:
        """Return a DataLoader for validation.

        Returns:
            validation data loader, or the training data loader when no
            validation split was created in :meth:`setup`
        """
        # Mirror the ``val_split_pct > 0.0`` test in setup() exactly, so that
        # every value for which ``self.val_dataset`` is None falls back to the
        # training loader instead of building a DataLoader around None.
        if not self.val_split_pct > 0.0:
            return self.train_dataloader()
        return DataLoader(
            self.val_dataset,
            batch_size=1,
            num_workers=self.num_workers,
            shuffle=False,
        )

    def test_dataloader(self) -> DataLoader[Any]:
        """Return a DataLoader for testing.

        Returns:
            testing data loader
        """
        return DataLoader(
            self.test_dataset, batch_size=1, num_workers=self.num_workers, shuffle=False
        )
|
||||
|
|
|
@ -4,22 +4,23 @@
|
|||
"""Potsdam dataset."""
|
||||
|
||||
import os
|
||||
from typing import Any, Callable, Dict, Optional
|
||||
from typing import Callable, Dict, Optional
|
||||
|
||||
import matplotlib.pyplot as plt
|
||||
import numpy as np
|
||||
import pytorch_lightning as pl
|
||||
import rasterio
|
||||
import torch
|
||||
from matplotlib.figure import Figure
|
||||
from PIL import Image
|
||||
from torch import Tensor
|
||||
from torch.utils.data import DataLoader
|
||||
from torchvision.transforms import Compose
|
||||
|
||||
from ..datasets.utils import dataset_split, draw_semantic_segmentation_masks
|
||||
from .geo import VisionDataset
|
||||
from .utils import check_integrity, extract_archive, rgb_to_mask
|
||||
from .utils import (
|
||||
check_integrity,
|
||||
draw_semantic_segmentation_masks,
|
||||
extract_archive,
|
||||
rgb_to_mask,
|
||||
)
|
||||
|
||||
|
||||
class Potsdam2D(VisionDataset):
|
||||
|
@ -293,111 +294,3 @@ class Potsdam2D(VisionDataset):
|
|||
plt.suptitle(suptitle)
|
||||
|
||||
return fig
|
||||
|
||||
|
||||
class Potsdam2DDataModule(pl.LightningDataModule):
    """LightningDataModule implementation for the Potsdam2D dataset.

    Uses the train/test splits from the dataset.

    .. versionadded:: 0.2
    """

    def __init__(
        self,
        root_dir: str,
        batch_size: int = 64,
        num_workers: int = 0,
        val_split_pct: float = 0.2,
        **kwargs: Any,
    ) -> None:
        """Initialize a LightningDataModule for Potsdam2D based DataLoaders.

        Args:
            root_dir: The ``root`` argument to pass to the Potsdam2D Dataset classes
            batch_size: The batch size to use in all created DataLoaders
            num_workers: The number of workers to use in all created DataLoaders
            val_split_pct: What percentage of the dataset to use as a validation set
            **kwargs: accepted for call compatibility; not used by this class
        """
        super().__init__()  # type: ignore[no-untyped-call]
        self.root_dir = root_dir
        self.batch_size = batch_size
        self.num_workers = num_workers
        self.val_split_pct = val_split_pct

    def preprocess(self, sample: Dict[str, Any]) -> Dict[str, Any]:
        """Transform a single sample from the Dataset.

        The image is cast to float and divided by 255.

        Args:
            sample: input image dictionary

        Returns:
            preprocessed sample
        """
        sample["image"] = sample["image"].float()
        sample["image"] /= 255.0
        return sample

    def setup(self, stage: Optional[str] = None) -> None:
        """Initialize the main ``Dataset`` objects.

        This method is called once per GPU per run.

        Args:
            stage: stage to set up
        """
        transforms = Compose([self.preprocess])

        dataset = Potsdam2D(self.root_dir, "train", transforms=transforms)

        if self.val_split_pct > 0.0:
            self.train_dataset, self.val_dataset, _ = dataset_split(
                dataset, val_pct=self.val_split_pct, test_pct=0.0
            )
        else:
            # No validation split requested: train on the whole "train" split.
            self.train_dataset = dataset  # type: ignore[assignment]
            self.val_dataset = None  # type: ignore[assignment]

        self.test_dataset = Potsdam2D(self.root_dir, "test", transforms=transforms)

    def train_dataloader(self) -> DataLoader[Any]:
        """Return a DataLoader for training.

        Returns:
            training data loader
        """
        return DataLoader(
            self.train_dataset,
            batch_size=self.batch_size,
            num_workers=self.num_workers,
            shuffle=True,
        )

    def val_dataloader(self) -> DataLoader[Any]:
        """Return a DataLoader for validation.

        Returns:
            validation data loader, or the training data loader when no
            validation split was created in :meth:`setup`
        """
        # Mirror the ``val_split_pct > 0.0`` test in setup() exactly, so that
        # every value for which ``self.val_dataset`` is None falls back to the
        # training loader instead of building a DataLoader around None.
        if not self.val_split_pct > 0.0:
            return self.train_dataloader()
        return DataLoader(
            self.val_dataset,
            batch_size=self.batch_size,
            num_workers=self.num_workers,
            shuffle=False,
        )

    def test_dataloader(self) -> DataLoader[Any]:
        """Return a DataLoader for testing.

        Returns:
            testing data loader
        """
        return DataLoader(
            self.test_dataset,
            batch_size=self.batch_size,
            num_workers=self.num_workers,
            shuffle=False,
        )
|
||||
|
|
|
@ -4,23 +4,15 @@
|
|||
"""RESISC45 dataset."""
|
||||
|
||||
import os
|
||||
from typing import Any, Callable, Dict, Optional, cast
|
||||
from typing import Callable, Dict, Optional, cast
|
||||
|
||||
import matplotlib.pyplot as plt
|
||||
import numpy as np
|
||||
import pytorch_lightning as pl
|
||||
import torch
|
||||
from torch import Tensor
|
||||
from torch.utils.data import DataLoader
|
||||
from torchvision.transforms import Compose, Normalize
|
||||
|
||||
from .geo import VisionClassificationDataset
|
||||
from .utils import download_url, extract_archive
|
||||
|
||||
# https://github.com/pytorch/pytorch/issues/60979
|
||||
# https://github.com/pytorch/pytorch/pull/61045
|
||||
DataLoader.__module__ = "torch.utils.data"
|
||||
|
||||
|
||||
class RESISC45(VisionClassificationDataset):
|
||||
"""RESISC45 dataset.
|
||||
|
@ -288,109 +280,3 @@ class RESISC45(VisionClassificationDataset):
|
|||
if suptitle is not None:
|
||||
plt.suptitle(suptitle)
|
||||
return fig
|
||||
|
||||
|
||||
class RESISC45DataModule(pl.LightningDataModule):
    """LightningDataModule for the RESISC45 dataset.

    Relies on the "train", "val" and "test" splits provided by the dataset.
    """

    band_means = torch.tensor(  # type: ignore[attr-defined]
        [0.36801773, 0.38097873, 0.343583]
    )

    band_stds = torch.tensor(  # type: ignore[attr-defined]
        [0.14540215, 0.13558227, 0.13203649]
    )

    def __init__(
        self, root_dir: str, batch_size: int = 64, num_workers: int = 0, **kwargs: Any
    ) -> None:
        """Store the loader configuration and build the normalization transform.

        Args:
            root_dir: The ``root`` argument to pass to the RESISC45 Dataset classes
            batch_size: The batch size to use in all created DataLoaders
            num_workers: The number of workers to use in all created DataLoaders
            **kwargs: accepted for call compatibility; not used by this class
        """
        super().__init__()  # type: ignore[no-untyped-call]

        # Settings shared by the train/val/test DataLoaders.
        self.batch_size = batch_size
        self.num_workers = num_workers
        self.root_dir = root_dir

        # Per-band normalization applied at the end of preprocess().
        self.norm = Normalize(self.band_means, self.band_stds)

    def preprocess(self, sample: Dict[str, Any]) -> Dict[str, Any]:
        """Cast the image to float, divide it by 255 and normalize it.

        Args:
            sample: input image dictionary

        Returns:
            the same dictionary with its ``"image"`` entry replaced
        """
        image = sample["image"].float()
        image /= 255.0
        sample["image"] = image
        sample["image"] = self.norm(image)
        return sample

    def prepare_data(self) -> None:
        """Make sure that the dataset is downloaded.

        Constructs the dataset once with ``checksum=False`` and discards the
        instance. This method is only called once per run.
        """
        RESISC45(self.root_dir, checksum=False)

    def setup(self, stage: Optional[str] = None) -> None:
        """Create one dataset per split, all sharing the same transform.

        This method is called once per GPU per run.

        Args:
            stage: stage to set up
        """
        shared_transforms = Compose([self.preprocess])

        self.train_dataset = RESISC45(
            self.root_dir, "train", transforms=shared_transforms
        )
        self.val_dataset = RESISC45(self.root_dir, "val", transforms=shared_transforms)
        self.test_dataset = RESISC45(
            self.root_dir, "test", transforms=shared_transforms
        )

    def train_dataloader(self) -> DataLoader[Any]:
        """Build the shuffled DataLoader over the training split.

        Returns:
            training data loader
        """
        train_loader: DataLoader[Any] = DataLoader(
            self.train_dataset,
            batch_size=self.batch_size,
            num_workers=self.num_workers,
            shuffle=True,
        )
        return train_loader

    def val_dataloader(self) -> DataLoader[Any]:
        """Build the unshuffled DataLoader over the validation split.

        Returns:
            validation data loader
        """
        val_loader: DataLoader[Any] = DataLoader(
            self.val_dataset,
            batch_size=self.batch_size,
            num_workers=self.num_workers,
            shuffle=False,
        )
        return val_loader

    def test_dataloader(self) -> DataLoader[Any]:
        """Build the unshuffled DataLoader over the test split.

        Returns:
            testing data loader
        """
        test_loader: DataLoader[Any] = DataLoader(
            self.test_dataset,
            batch_size=self.batch_size,
            num_workers=self.num_workers,
            shuffle=False,
        )
        return test_loader
|
||||
|
|
|
@ -4,23 +4,16 @@
|
|||
"""SEN12MS dataset."""
|
||||
|
||||
import os
|
||||
from typing import Any, Callable, Dict, List, Optional
|
||||
from typing import Callable, Dict, List, Optional
|
||||
|
||||
import numpy as np
|
||||
import pytorch_lightning as pl
|
||||
import rasterio
|
||||
import torch
|
||||
from sklearn.model_selection import GroupShuffleSplit
|
||||
from torch import Tensor
|
||||
from torch.utils.data import DataLoader, Subset
|
||||
|
||||
from .geo import VisionDataset
|
||||
from .utils import check_integrity
|
||||
|
||||
# https://github.com/pytorch/pytorch/issues/60979
|
||||
# https://github.com/pytorch/pytorch/pull/61045
|
||||
DataLoader.__module__ = "torch.utils.data"
|
||||
|
||||
|
||||
class SEN12MS(VisionDataset):
|
||||
"""SEN12MS dataset.
|
||||
|
@ -246,188 +239,3 @@ class SEN12MS(VisionDataset):
|
|||
if not check_integrity(filepath, md5 if self.checksum else None):
|
||||
return False
|
||||
return True
|
||||
|
||||
|
||||
class SEN12MSDataModule(pl.LightningDataModule):
    """LightningDataModule for the SEN12MS dataset.

    Implements 80/20 geographic train/val splits and uses the test split from the
    classification dataset definitions. See :func:`setup` for more details.

    Uses the Simplified IGBP scheme defined in the 2020 Data Fusion Competition. See
    https://arxiv.org/abs/2002.08254.
    """

    #: Mapping from the IGBP class definitions to the DFC2020, taken from the dataloader
    #: here https://github.com/lukasliebel/dfc2020_baseline.
    #: The value stored at index ``i`` is the class that IGBP class ``i`` maps to.
    DFC2020_CLASS_MAPPING = torch.tensor(  # type: ignore[attr-defined]
        [0, 1, 1, 1, 1, 1, 2, 2, 3, 3, 4, 5, 6, 7, 6, 8, 9, 10]
    )

    def __init__(
        self,
        root_dir: str,
        seed: int,
        band_set: str = "all",
        batch_size: int = 64,
        num_workers: int = 0,
        **kwargs: Any,
    ) -> None:
        """Store the configuration used to build the datasets and their loaders.

        Args:
            root_dir: The ``root`` argument to pass to the SEN12MS Dataset classes
            seed: The seed value to use when doing the sklearn based ShuffleSplit
            band_set: The subset of S1/S2 bands to use. Options are: "all",
                "s1", "s2-all", and "s2-reduced" where the "s2-reduced" set includes:
                B2, B3, B4, B8, B11, and B12.
            batch_size: The batch size to use in all created DataLoaders
            num_workers: The number of workers to use in all created DataLoaders
            **kwargs: accepted for call compatibility; not used by this class
        """
        super().__init__()  # type: ignore[no-untyped-call]
        assert band_set in SEN12MS.BAND_SETS

        # Band selection: the name is kept for custom_transform(), the indices
        # are handed to the dataset in setup().
        self.band_set = band_set
        self.band_indices = SEN12MS.BAND_SETS[band_set]

        self.root_dir = root_dir
        self.seed = seed
        self.batch_size = batch_size
        self.num_workers = num_workers

    def custom_transform(self, sample: Dict[str, Any]) -> Dict[str, Any]:
        """Rescale the image bands and remap the mask classes of one sample.

        For the "all" and "s1" band sets the first two bands are clamped to
        [-25, 0] and divided by -25; for "all" the remaining bands are clamped
        to [0, 10000] and divided by 10000. For every other band set all bands
        get the [0, 10000] treatment. The first band of the mask is then mapped
        through :attr:`DFC2020_CLASS_MAPPING`.

        Args:
            sample: dictionary containing image and mask

        Returns:
            the same dictionary with its ``"image"`` and ``"mask"`` replaced
        """
        image = sample["image"].float()
        sample["image"] = image

        if self.band_set in ("all", "s1"):
            image[:2] = image[:2].clamp(-25, 0) / -25
            if self.band_set == "all":
                image[2:] = image[2:].clamp(0, 10000) / 10000
        else:
            image[:] = image[:].clamp(0, 10000) / 10000

        first_mask_band = sample["mask"][0, :, :].long()
        sample["mask"] = torch.take(  # type: ignore[attr-defined]
            self.DFC2020_CLASS_MAPPING, first_mask_band
        )

        return sample

    def setup(self, stage: Optional[str] = None) -> None:
        """Create the train/val/test splits based on the original Dataset objects.

        The splits should be done here vs. in :func:`__init__` per the docs:
        https://pytorch-lightning.readthedocs.io/en/latest/extensions/datamodules.html#setup.

        We split samples between train and val geographically with proportions of 80/20.
        This mimics the geographic test set split.

        Args:
            stage: stage to set up
        """
        season_offsets = {"winter": 0, "spring": 1000, "summer": 2000, "fall": 3000}

        self.all_train_dataset = SEN12MS(
            self.root_dir,
            split="train",
            bands=self.band_indices,
            transforms=self.custom_transform,
            checksum=False,
        )

        self.all_test_dataset = SEN12MS(
            self.root_dir,
            split="test",
            bands=self.band_indices,
            transforms=self.custom_transform,
            checksum=False,
        )

        # A patch is a filename like: "ROIs{num}_{season}_s2_{scene_id}_p{patch_id}.tif"
        # This patch will belong to the scene that is uniquely identified by its
        # (season, scene_id) tuple. Because the largest scene_id is 149, we can simply
        # give each season a large number and represent a `unique_scene_id` as
        # `season_id + scene_id`.
        scenes = []
        for filename in self.all_train_dataset.ids:
            fields = filename.split("_")
            scenes.append(season_offsets[fields[1]] + int(fields[3]))

        # Patches of the same scene share a group id, so a scene ends up either
        # entirely in train or entirely in val. Only the first split is used.
        splitter = GroupShuffleSplit(
            test_size=0.2, n_splits=2, random_state=self.seed
        )
        train_indices, val_indices = next(splitter.split(scenes, groups=scenes))

        self.train_dataset = Subset(self.all_train_dataset, train_indices)
        self.val_dataset = Subset(self.all_train_dataset, val_indices)
        self.test_dataset = Subset(
            self.all_test_dataset, range(len(self.all_test_dataset))
        )

    def train_dataloader(self) -> DataLoader[Any]:
        """Build the shuffled DataLoader over the training subset.

        Returns:
            training data loader
        """
        train_loader: DataLoader[Any] = DataLoader(
            self.train_dataset,
            batch_size=self.batch_size,
            num_workers=self.num_workers,
            shuffle=True,
        )
        return train_loader

    def val_dataloader(self) -> DataLoader[Any]:
        """Build the unshuffled DataLoader over the validation subset.

        Returns:
            validation data loader
        """
        val_loader: DataLoader[Any] = DataLoader(
            self.val_dataset,
            batch_size=self.batch_size,
            num_workers=self.num_workers,
            shuffle=False,
        )
        return val_loader

    def test_dataloader(self) -> DataLoader[Any]:
        """Build the unshuffled DataLoader over the test subset.

        Returns:
            testing data loader
        """
        test_loader: DataLoader[Any] = DataLoader(
            self.test_dataset,
            batch_size=self.batch_size,
            num_workers=self.num_workers,
            shuffle=False,
        )
        return test_loader
|
||||
|
|
|
@ -4,23 +4,16 @@
|
|||
"""So2Sat dataset."""
|
||||
|
||||
import os
|
||||
from typing import Any, Callable, Dict, Optional, cast
|
||||
from typing import Callable, Dict, Optional, cast
|
||||
|
||||
import matplotlib.pyplot as plt
|
||||
import numpy as np
|
||||
import pytorch_lightning as pl
|
||||
import torch
|
||||
from torch import Tensor
|
||||
from torch.utils.data import DataLoader
|
||||
from torchvision.transforms import Compose
|
||||
|
||||
from .geo import VisionDataset
|
||||
from .utils import check_integrity, percentile_normalization
|
||||
|
||||
# https://github.com/pytorch/pytorch/issues/60979
|
||||
# https://github.com/pytorch/pytorch/pull/61045
|
||||
DataLoader.__module__ = "torch.utils.data"
|
||||
|
||||
|
||||
class So2Sat(VisionDataset):
|
||||
"""So2Sat dataset.
|
||||
|
@ -250,211 +243,3 @@ class So2Sat(VisionDataset):
|
|||
if suptitle is not None:
|
||||
plt.suptitle(suptitle)
|
||||
return fig
|
||||
|
||||
|
||||
class So2SatDataModule(pl.LightningDataModule):
|
||||
"""LightningDataModule implementation for the So2Sat dataset.
|
||||
|
||||
Uses the train/val/test splits from the dataset.
|
||||
"""
|
||||
|
||||
band_means = torch.tensor( # type: ignore[attr-defined]
|
||||
[
|
||||
-3.591224256609313e-05,
|
||||
-7.658561276843396e-06,
|
||||
5.9373857475971184e-05,
|
||||
2.5166231537121083e-05,
|
||||
0.04420110659759328,
|
||||
0.25761027084996196,
|
||||
0.0007556743372573258,
|
||||
0.0013503466830024448,
|
||||
0.12375696117681859,
|
||||
0.1092774636368323,
|
||||
0.1010855203267882,
|
||||
0.1142398616114001,
|
||||
0.1592656692023089,
|
||||
0.18147236008771792,
|
||||
0.1745740312291377,
|
||||
0.19501607349635292,
|
||||
0.15428468872076637,
|
||||
0.10905050699570007,
|
||||
]
|
||||
).reshape(18, 1, 1)
|
||||
|
||||
band_stds = torch.tensor( # type: ignore[attr-defined]
|
||||
[
|
||||
0.17555201137417686,
|
||||
0.17556463274968204,
|
||||
0.45998793417834255,
|
||||
0.455988755730148,
|
||||
2.8559909213125763,
|
||||
8.324800606439833,
|
||||
2.4498757382563103,
|
||||
1.4647352984509094,
|
||||
0.03958795985905458,
|
||||
0.047778262752410296,
|
||||
0.06636616706371974,
|
||||
0.06358874912497474,
|
||||
0.07744387147984592,
|
||||
0.09101635085921553,
|
||||
0.09218466562387101,
|
||||
0.10164581233948201,
|
||||
0.09991773043519253,
|
||||
0.08780632509122865,
|
||||
]
|
||||
).reshape(18, 1, 1)
|
||||
|
||||
# this reorders the bands to put S2 RGB first, then remainder of S2, then S1
|
||||
reindex_to_rgb_first = [
|
||||
10,
|
||||
9,
|
||||
8,
|
||||
11,
|
||||
12,
|
||||
13,
|
||||
14,
|
||||
15,
|
||||
16,
|
||||
17,
|
||||
# 0,
|
||||
# 1,
|
||||
# 2,
|
||||
# 3,
|
||||
# 4,
|
||||
# 5,
|
||||
# 6,
|
||||
# 7,
|
||||
]
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
root_dir: str,
|
||||
batch_size: int = 64,
|
||||
num_workers: int = 0,
|
||||
bands: str = "rgb",
|
||||
unsupervised_mode: bool = False,
|
||||
**kwargs: Any,
|
||||
) -> None:
|
||||
"""Initialize a LightningDataModule for So2Sat based DataLoaders.
|
||||
|
||||
Args:
|
||||
root_dir: The ``root`` arugment to pass to the So2Sat Dataset classes
|
||||
batch_size: The batch size to use in all created DataLoaders
|
||||
num_workers: The number of workers to use in all created DataLoaders
|
||||
bands: Either "rgb" or "s2"
|
||||
unsupervised_mode: Makes the train dataloader return imagery from the train,
|
||||
val, and test sets
|
||||
"""
|
||||
super().__init__() # type: ignore[no-untyped-call]
|
||||
self.root_dir = root_dir
|
||||
self.batch_size = batch_size
|
||||
self.num_workers = num_workers
|
||||
self.bands = bands
|
||||
self.unsupervised_mode = unsupervised_mode
|
||||
|
||||
def preprocess(self, sample: Dict[str, Any]) -> Dict[str, Any]:
|
||||
"""Transform a single sample from the Dataset.
|
||||
|
||||
Args:
|
||||
sample: dictionary containing image
|
||||
|
||||
Returns:
|
||||
preprocessed sample
|
||||
"""
|
||||
# sample["image"] = (sample["image"] - self.band_means) / self.band_stds
|
||||
sample["image"] = sample["image"].float()
|
||||
sample["image"] = sample["image"][self.reindex_to_rgb_first, :, :]
|
||||
|
||||
if self.bands == "rgb":
|
||||
sample["image"] = sample["image"][:3, :, :]
|
||||
|
||||
return sample
|
||||
|
||||
def prepare_data(self) -> None:
|
||||
"""Make sure that the dataset is downloaded.
|
||||
|
||||
This method is only called once per run.
|
||||
"""
|
||||
So2Sat(self.root_dir, checksum=False)
|
||||
|
||||
def setup(self, stage: Optional[str] = None) -> None:
    """Create the train, validation and test ``Dataset`` objects.

    This method is called once per GPU per run.

    Args:
        stage: stage to set up
    """
    composed = Compose([self.preprocess])

    # Supervised mode hands the bare ``preprocess`` callable to the
    # validation/test datasets; unsupervised mode reuses the composed
    # training transform for every split.
    eval_transforms = composed if self.unsupervised_mode else self.preprocess

    train_split = So2Sat(self.root_dir, split="train", transforms=composed)
    self.val_dataset = So2Sat(
        self.root_dir, split="validation", transforms=eval_transforms
    )
    self.test_dataset = So2Sat(
        self.root_dir, split="test", transforms=eval_transforms
    )

    if self.unsupervised_mode:
        # Training draws from all three splits concatenated together.
        self.train_dataset = cast(
            So2Sat, train_split + self.val_dataset + self.test_dataset
        )
    else:
        self.train_dataset = train_split
|
||||
|
||||
def train_dataloader(self) -> DataLoader[Any]:
    """Build the DataLoader used for training.

    Returns:
        shuffling data loader over ``self.train_dataset``
    """
    loader: DataLoader[Any] = DataLoader(
        dataset=self.train_dataset,
        shuffle=True,
        num_workers=self.num_workers,
        batch_size=self.batch_size,
    )
    return loader
|
||||
|
||||
def val_dataloader(self) -> DataLoader[Any]:
    """Build the DataLoader used for validation.

    Returns:
        non-shuffling data loader over ``self.val_dataset``
    """
    loader: DataLoader[Any] = DataLoader(
        dataset=self.val_dataset,
        shuffle=False,
        num_workers=self.num_workers,
        batch_size=self.batch_size,
    )
    return loader
|
||||
|
||||
def test_dataloader(self) -> DataLoader[Any]:
|
||||
"""Return a DataLoader for testing.
|
||||
|
||||
Returns:
|
||||
testing data loader
|
||||
"""
|
||||
return DataLoader(
|
||||
self.test_dataset,
|
||||
batch_size=self.batch_size,
|
||||
num_workers=self.num_workers,
|
||||
shuffle=False,
|
||||
)
|
||||
|
|
|
@ -24,8 +24,8 @@ from rasterio.features import rasterize
|
|||
from rasterio.transform import Affine
|
||||
from torch import Tensor
|
||||
|
||||
from torchgeo.datasets.geo import VisionDataset
|
||||
from torchgeo.datasets.utils import (
|
||||
from .geo import VisionDataset
|
||||
from .utils import (
|
||||
check_integrity,
|
||||
download_radiant_mlhub_collection,
|
||||
extract_archive,
|
||||
|
|
|
@ -3,24 +3,15 @@
|
|||
|
||||
"""UC Merced dataset."""
|
||||
import os
|
||||
from typing import Any, Callable, Dict, Optional, cast
|
||||
from typing import Callable, Dict, Optional, cast
|
||||
|
||||
import matplotlib.pyplot as plt
|
||||
import numpy as np
|
||||
import pytorch_lightning as pl
|
||||
import torch
|
||||
import torchvision
|
||||
from torch import Tensor
|
||||
from torch.utils.data import DataLoader
|
||||
from torchvision.transforms import Compose, Normalize
|
||||
|
||||
from .geo import VisionClassificationDataset
|
||||
from .utils import check_integrity, download_url, extract_archive
|
||||
|
||||
# https://github.com/pytorch/pytorch/issues/60979
|
||||
# https://github.com/pytorch/pytorch/pull/61045
|
||||
DataLoader.__module__ = "torch.utils.data"
|
||||
|
||||
|
||||
class UCMerced(VisionClassificationDataset):
|
||||
"""UC Merced dataset.
|
||||
|
@ -251,110 +242,3 @@ class UCMerced(VisionClassificationDataset):
|
|||
if suptitle is not None:
|
||||
plt.suptitle(suptitle)
|
||||
return fig
|
||||
|
||||
|
||||
class UCMercedDataModule(pl.LightningDataModule):
    """LightningDataModule implementation for the UC Merced dataset.

    Uses random train/val/test splits.
    """

    # Per-channel statistics fed to ``Normalize``; with zeros and ones the
    # normalization step leaves the values unchanged.
    band_means = torch.tensor([0, 0, 0])  # type: ignore[attr-defined]

    band_stds = torch.tensor([1, 1, 1])  # type: ignore[attr-defined]

    def __init__(
        self, root_dir: str, batch_size: int = 64, num_workers: int = 0, **kwargs: Any
    ) -> None:
        """Store the configuration used to build the UCMerced DataLoaders.

        Args:
            root_dir: value handed to the UCMerced dataset classes as their
                ``root`` argument
            batch_size: batch size shared by every DataLoader this module creates
            num_workers: worker count shared by every DataLoader this module
                creates
            **kwargs: accepted for call-site compatibility; not used here
        """
        super().__init__()  # type: ignore[no-untyped-call]
        self.root_dir = root_dir
        self.batch_size, self.num_workers = batch_size, num_workers

        self.norm = Normalize(self.band_means, self.band_stds)

    def preprocess(self, sample: Dict[str, Any]) -> Dict[str, Any]:
        """Scale, resize (when needed) and normalize the image of one sample.

        Args:
            sample: dictionary holding the image under the ``"image"`` key

        Returns:
            the same dictionary with ``"image"`` replaced by the processed tensor
        """
        image = sample["image"].float()
        # In-place division, so ``sample["image"]`` must be rebound below.
        image /= 255.0
        sample["image"] = image

        # The image is expected to have exactly three dimensions; only the
        # last two (height, width) decide whether a resize is required.
        _, height, width = image.shape
        if (height, width) != (256, 256):
            image = torchvision.transforms.functional.resize(image, size=(256, 256))
            sample["image"] = image

        sample["image"] = self.norm(image)
        return sample

    def prepare_data(self) -> None:
        """Instantiate the dataset once so that missing data is detected early.

        This method is only called once per run.
        """
        # Neither download nor checksum verification is requested here.
        _ = UCMerced(self.root_dir, download=False, checksum=False)

    def setup(self, stage: Optional[str] = None) -> None:
        """Create the train, validation and test ``Dataset`` objects.

        This method is called once per GPU per run.

        Args:
            stage: stage to set up
        """
        pipeline = Compose([self.preprocess])

        # One dataset per split, all sharing the same transform pipeline.
        self.train_dataset, self.val_dataset, self.test_dataset = [
            UCMerced(self.root_dir, split_name, transforms=pipeline)
            for split_name in ("train", "val", "test")
        ]

    def _make_loader(self, dataset: Any, shuffle: bool) -> DataLoader[Any]:
        """Wrap ``dataset`` in a DataLoader using this module's settings."""
        return DataLoader(
            dataset,
            batch_size=self.batch_size,
            num_workers=self.num_workers,
            shuffle=shuffle,
        )

    def train_dataloader(self) -> DataLoader[Any]:
        """Build the DataLoader used for training.

        Returns:
            shuffling data loader over the training dataset
        """
        return self._make_loader(self.train_dataset, shuffle=True)

    def val_dataloader(self) -> DataLoader[Any]:
        """Build the DataLoader used for validation.

        Returns:
            non-shuffling data loader over the validation dataset
        """
        return self._make_loader(self.val_dataset, shuffle=False)

    def test_dataloader(self) -> DataLoader[Any]:
        """Build the DataLoader used for testing.

        Returns:
            non-shuffling data loader over the test dataset
        """
        return self._make_loader(self.test_dataset, shuffle=False)
|
||||
|
|
|
@ -32,7 +32,6 @@ import numpy as np
|
|||
import rasterio
|
||||
import torch
|
||||
from torch import Tensor
|
||||
from torch.utils.data import Dataset, Subset, random_split
|
||||
from torchvision.datasets.utils import check_integrity, download_url
|
||||
from torchvision.utils import draw_segmentation_masks
|
||||
|
||||
|
@ -48,7 +47,6 @@ __all__ = (
|
|||
"concat_samples",
|
||||
"merge_samples",
|
||||
"rasterio_loader",
|
||||
"dataset_split",
|
||||
"sort_sentinel2_bands",
|
||||
"draw_semantic_segmentation_masks",
|
||||
"rgb_to_mask",
|
||||
|
@ -519,31 +517,6 @@ def rasterio_loader(path: str) -> np.ndarray: # type: ignore[type-arg]
|
|||
return array
|
||||
|
||||
|
||||
def dataset_split(
    dataset: Dataset[Any], val_pct: float, test_pct: Optional[float] = None
) -> List[Subset[Any]]:
    """Split a torch Dataset into train/val/test sets.

    If ``test_pct`` is not set then only train and validation splits are returned.

    Args:
        dataset: dataset to be split into train/val or train/val/test subsets
        val_pct: fraction of samples to be in validation set
        test_pct: (Optional) fraction of samples to be in test set

    Returns:
        a list of the subset datasets. Either [train, val] or [train, val, test]

    Raises:
        ValueError: if a fraction is negative, or if the validation and test
            subsets together would need more samples than ``dataset`` holds
    """
    if val_pct < 0 or (test_pct is not None and test_pct < 0):
        raise ValueError("val_pct and test_pct must be non-negative")

    total = len(dataset)

    # Sizes of the held-out subsets, in the order they are returned.
    held_out = [int(total * val_pct)]
    if test_pct is not None:
        held_out.append(int(total * test_pct))

    # Whatever is not held out is used for training. A negative remainder
    # would still sum to ``total`` and could slip through ``random_split`` as
    # a silently wrong split, so it is rejected explicitly.
    train_length = total - sum(held_out)
    if train_length < 0:
        raise ValueError(
            "val_pct and test_pct request more samples than the dataset contains"
        )

    return random_split(dataset, [train_length, *held_out])
|
||||
|
||||
|
||||
def sort_sentinel2_bands(x: str) -> str:
|
||||
"""Sort Sentinel-2 band files in the correct order."""
|
||||
x = os.path.basename(x).split("_")[-1]
|
||||
|
|
|
@ -4,21 +4,22 @@
|
|||
"""Vaihingen dataset."""
|
||||
|
||||
import os
|
||||
from typing import Any, Callable, Dict, Optional
|
||||
from typing import Callable, Dict, Optional
|
||||
|
||||
import matplotlib.pyplot as plt
|
||||
import numpy as np
|
||||
import pytorch_lightning as pl
|
||||
import torch
|
||||
from matplotlib.figure import Figure
|
||||
from PIL import Image
|
||||
from torch import Tensor
|
||||
from torch.utils.data import DataLoader
|
||||
from torchvision.transforms import Compose
|
||||
|
||||
from ..datasets.utils import dataset_split, draw_semantic_segmentation_masks
|
||||
from .geo import VisionDataset
|
||||
from .utils import check_integrity, extract_archive, rgb_to_mask
|
||||
from .utils import (
|
||||
check_integrity,
|
||||
draw_semantic_segmentation_masks,
|
||||
extract_archive,
|
||||
rgb_to_mask,
|
||||
)
|
||||
|
||||
|
||||
class Vaihingen2D(VisionDataset):
|
||||
|
@ -293,111 +294,3 @@ class Vaihingen2D(VisionDataset):
|
|||
plt.suptitle(suptitle)
|
||||
|
||||
return fig
|
||||
|
||||
|
||||
class Vaihingen2DDataModule(pl.LightningDataModule):
    """LightningDataModule implementation for the Vaihingen2D dataset.

    Uses the train/test splits from the dataset. The validation set, when
    requested, is carved out of the train split.

    .. versionadded:: 0.2
    """

    def __init__(
        self,
        root_dir: str,
        batch_size: int = 64,
        num_workers: int = 0,
        val_split_pct: float = 0.2,
        **kwargs: Any,
    ) -> None:
        """Initialize a LightningDataModule for Vaihingen2D based DataLoaders.

        Args:
            root_dir: The ``root`` argument to pass to the Vaihingen Dataset classes
            batch_size: The batch size to use in all created DataLoaders
            num_workers: The number of workers to use in all created DataLoaders
            val_split_pct: What percentage of the dataset to use as a validation
                set; a value of zero or less disables the validation split
            **kwargs: accepted for call-site compatibility; not used here
        """
        super().__init__()  # type: ignore[no-untyped-call]
        self.root_dir = root_dir
        self.batch_size = batch_size
        self.num_workers = num_workers
        self.val_split_pct = val_split_pct

    def preprocess(self, sample: Dict[str, Any]) -> Dict[str, Any]:
        """Transform a single sample from the Dataset.

        Args:
            sample: input image dictionary

        Returns:
            preprocessed sample
        """
        # Cast to float, then divide by 255 in place.
        sample["image"] = sample["image"].float()
        sample["image"] /= 255.0
        return sample

    def setup(self, stage: Optional[str] = None) -> None:
        """Initialize the main ``Dataset`` objects.

        This method is called once per GPU per run.

        Args:
            stage: stage to set up
        """
        transforms = Compose([self.preprocess])

        dataset = Vaihingen2D(self.root_dir, "train", transforms=transforms)

        if self.val_split_pct > 0.0:
            # A two-way split yields the same train/val sizes as a three-way
            # split with ``test_pct=0.0`` without building an empty subset.
            self.train_dataset, self.val_dataset = dataset_split(
                dataset, val_pct=self.val_split_pct
            )
        else:
            self.train_dataset = dataset  # type: ignore[assignment]
            self.val_dataset = None  # type: ignore[assignment]

        self.test_dataset = Vaihingen2D(self.root_dir, "test", transforms=transforms)

    def train_dataloader(self) -> DataLoader[Any]:
        """Return a DataLoader for training.

        Returns:
            training data loader
        """
        return DataLoader(
            self.train_dataset,
            batch_size=self.batch_size,
            num_workers=self.num_workers,
            shuffle=True,
        )

    def val_dataloader(self) -> DataLoader[Any]:
        """Return a DataLoader for validation.

        Returns:
            validation data loader, or the training data loader when no
            validation split was requested
        """
        # Mirrors the ``> 0.0`` guard in ``setup``: whenever no validation
        # subset is created, fall back to the training loader instead of
        # wrapping ``None`` in a DataLoader.
        if self.val_split_pct <= 0.0:
            return self.train_dataloader()

        return DataLoader(
            self.val_dataset,
            batch_size=self.batch_size,
            num_workers=self.num_workers,
            shuffle=False,
        )

    def test_dataloader(self) -> DataLoader[Any]:
        """Return a DataLoader for testing.

        Returns:
            testing data loader
        """
        return DataLoader(
            self.test_dataset,
            batch_size=self.batch_size,
            num_workers=self.num_workers,
            shuffle=False,
        )
|
||||
|
|
|
@ -5,20 +5,16 @@
|
|||
|
||||
import glob
|
||||
import os
|
||||
from typing import Any, Callable, Dict, List, Optional
|
||||
from typing import Callable, Dict, List, Optional
|
||||
|
||||
import matplotlib.pyplot as plt
|
||||
import numpy as np
|
||||
import pytorch_lightning as pl
|
||||
import torch
|
||||
from PIL import Image
|
||||
from torch import Tensor
|
||||
from torch.utils.data import DataLoader
|
||||
from torchvision.transforms import Compose
|
||||
|
||||
from ..datasets.utils import dataset_split, draw_semantic_segmentation_masks
|
||||
from .geo import VisionDataset
|
||||
from .utils import check_integrity, extract_archive
|
||||
from .utils import check_integrity, draw_semantic_segmentation_masks, extract_archive
|
||||
|
||||
|
||||
class XView2(VisionDataset):
|
||||
|
@ -282,111 +278,3 @@ class XView2(VisionDataset):
|
|||
plt.suptitle(suptitle)
|
||||
|
||||
return fig
|
||||
|
||||
|
||||
class XView2DataModule(pl.LightningDataModule):
    """LightningDataModule implementation for the xView2 dataset.

    Uses the train/test splits from the dataset. The validation set, when
    requested, is carved out of the train split.

    .. versionadded:: 0.2
    """

    def __init__(
        self,
        root_dir: str,
        batch_size: int = 64,
        num_workers: int = 0,
        val_split_pct: float = 0.2,
        **kwargs: Any,
    ) -> None:
        """Initialize a LightningDataModule for xView2 based DataLoaders.

        Args:
            root_dir: The ``root`` argument to pass to the xView2 Dataset classes
            batch_size: The batch size to use in all created DataLoaders
            num_workers: The number of workers to use in all created DataLoaders
            val_split_pct: What percentage of the dataset to use as a validation
                set; a value of zero or less disables the validation split
            **kwargs: accepted for call-site compatibility; not used here
        """
        super().__init__()  # type: ignore[no-untyped-call]
        self.root_dir = root_dir
        self.batch_size = batch_size
        self.num_workers = num_workers
        self.val_split_pct = val_split_pct

    def preprocess(self, sample: Dict[str, Any]) -> Dict[str, Any]:
        """Transform a single sample from the Dataset.

        Args:
            sample: input image dictionary

        Returns:
            preprocessed sample
        """
        # Cast to float, then divide by 255 in place.
        sample["image"] = sample["image"].float()
        sample["image"] /= 255.0
        return sample

    def setup(self, stage: Optional[str] = None) -> None:
        """Initialize the main ``Dataset`` objects.

        This method is called once per GPU per run.

        Args:
            stage: stage to set up
        """
        transforms = Compose([self.preprocess])

        dataset = XView2(self.root_dir, "train", transforms=transforms)

        if self.val_split_pct > 0.0:
            # A two-way split yields the same train/val sizes as a three-way
            # split with ``test_pct=0.0`` without building an empty subset.
            self.train_dataset, self.val_dataset = dataset_split(
                dataset, val_pct=self.val_split_pct
            )
        else:
            self.train_dataset = dataset  # type: ignore[assignment]
            self.val_dataset = None  # type: ignore[assignment]

        self.test_dataset = XView2(self.root_dir, "test", transforms=transforms)

    def train_dataloader(self) -> DataLoader[Any]:
        """Return a DataLoader for training.

        Returns:
            training data loader
        """
        return DataLoader(
            self.train_dataset,
            batch_size=self.batch_size,
            num_workers=self.num_workers,
            shuffle=True,
        )

    def val_dataloader(self) -> DataLoader[Any]:
        """Return a DataLoader for validation.

        Returns:
            validation data loader, or the training data loader when no
            validation split was requested
        """
        # Mirrors the ``> 0.0`` guard in ``setup``: whenever no validation
        # subset is created, fall back to the training loader instead of
        # wrapping ``None`` in a DataLoader.
        if self.val_split_pct <= 0.0:
            return self.train_dataloader()

        return DataLoader(
            self.val_dataset,
            batch_size=self.batch_size,
            num_workers=self.num_workers,
            shuffle=False,
        )

    def test_dataloader(self) -> DataLoader[Any]:
        """Return a DataLoader for testing.

        Returns:
            testing data loader
        """
        return DataLoader(
            self.test_dataset,
            batch_size=self.batch_size,
            num_workers=self.num_workers,
            shuffle=False,
        )
|
||||
|
|
|
@ -10,9 +10,7 @@ from typing import Iterator, List, Optional, Tuple, Union
|
|||
from rtree.index import Index, Property
|
||||
from torch.utils.data import Sampler
|
||||
|
||||
from torchgeo.datasets.geo import GeoDataset
|
||||
from torchgeo.datasets.utils import BoundingBox
|
||||
|
||||
from ..datasets import BoundingBox, GeoDataset
|
||||
from .utils import _to_tuple, get_random_bounding_box
|
||||
|
||||
# https://github.com/pytorch/pytorch/issues/60979
|
||||
|
|
|
@ -10,9 +10,7 @@ from typing import Iterator, Optional, Tuple, Union
|
|||
from rtree.index import Index, Property
|
||||
from torch.utils.data import Sampler
|
||||
|
||||
from torchgeo.datasets.geo import GeoDataset
|
||||
from torchgeo.datasets.utils import BoundingBox
|
||||
|
||||
from ..datasets import BoundingBox, GeoDataset
|
||||
from .utils import _to_tuple, get_random_bounding_box
|
||||
|
||||
# https://github.com/pytorch/pytorch/issues/60979
|
||||
|
|
|
@ -6,7 +6,7 @@
|
|||
import random
|
||||
from typing import Tuple, Union
|
||||
|
||||
from torchgeo.datasets.utils import BoundingBox
|
||||
from ..datasets import BoundingBox
|
||||
|
||||
|
||||
def _to_tuple(value: Union[Tuple[float, float], float]) -> Tuple[float, float]:
|
||||
|
|
Загрузка…
Ссылка в новой задаче