Remove non-geospatial transforms (#198)

This commit is contained in:
Adam J. Stewart 2021-10-15 23:59:17 -05:00 коммит произвёл GitHub
Родитель 24c3f70f5f
Коммит 173de92ac8
Не найден ключ, соответствующий данной подписи
Идентификатор ключа GPG: 4AEE18F83AFDEB23
24 изменённых файлов: 47 добавлений и 180 удалений

Просмотреть файл

@ -9,11 +9,11 @@ from typing import Any, Generator
import pytest
import torch
import torch.nn as nn
from _pytest.monkeypatch import MonkeyPatch
import torchgeo.datasets.utils
from torchgeo.datasets import ADVANCE
from torchgeo.transforms import Identity
def download_url(url: str, root: str, *args: str) -> None:
@ -40,7 +40,7 @@ class TestADVANCE:
monkeypatch.setattr(ADVANCE, "urls", urls) # type: ignore[attr-defined]
monkeypatch.setattr(ADVANCE, "md5s", md5s) # type: ignore[attr-defined]
root = str(tmp_path)
transforms = Identity()
transforms = nn.Identity() # type: ignore[attr-defined]
return ADVANCE(root, transforms, download=True, checksum=True)
@pytest.fixture

Просмотреть файл

@ -9,11 +9,11 @@ from typing import Generator
import pytest
import torch
import torch.nn as nn
from _pytest.monkeypatch import MonkeyPatch
from torch.utils.data import ConcatDataset
from torchgeo.datasets import BeninSmallHolderCashews
from torchgeo.transforms import Identity
class Dataset:
@ -50,7 +50,7 @@ class TestBeninSmallHolderCashews:
BeninSmallHolderCashews, "dates", ("2019_11_05",)
)
root = str(tmp_path)
transforms = Identity()
transforms = nn.Identity() # type: ignore[attr-defined]
return BeninSmallHolderCashews(
root,
transforms=transforms,

Просмотреть файл

@ -9,12 +9,12 @@ from typing import Generator
import matplotlib.pyplot as plt
import pytest
import torch
import torch.nn as nn
from _pytest.monkeypatch import MonkeyPatch
from rasterio.crs import CRS
import torchgeo.datasets.utils
from torchgeo.datasets import BoundingBox, CanadianBuildingFootprints, ZipDataset
from torchgeo.transforms import Identity
def download_url(url: str, root: str, *args: str) -> None:
@ -57,7 +57,7 @@ class TestCanadianBuildingFootprints:
plt, "show", lambda *args: None
)
root = str(tmp_path)
transforms = Identity()
transforms = nn.Identity() # type: ignore[attr-defined]
return CanadianBuildingFootprints(
root, res=0.1, transforms=transforms, download=True, checksum=True
)

Просмотреть файл

@ -11,12 +11,12 @@ from typing import Generator
import matplotlib.pyplot as plt
import pytest
import torch
import torch.nn as nn
from _pytest.monkeypatch import MonkeyPatch
from rasterio.crs import CRS
import torchgeo.datasets.utils
from torchgeo.datasets import CDL, BoundingBox, ZipDataset
from torchgeo.transforms import Identity
def download_url(url: str, root: str, *args: str, **kwargs: str) -> None:
@ -44,7 +44,7 @@ class TestCDL:
plt, "show", lambda *args: None
)
root = str(tmp_path)
transforms = Identity()
transforms = nn.Identity() # type: ignore[attr-defined]
return CDL(root, transforms=transforms, download=True, checksum=True)
def test_getitem(self, dataset: CDL) -> None:

Просмотреть файл

@ -9,13 +9,13 @@ from typing import Generator
import matplotlib.pyplot as plt
import pytest
import torch
import torch.nn as nn
from _pytest.fixtures import SubRequest
from _pytest.monkeypatch import MonkeyPatch
from rasterio.crs import CRS
import torchgeo.datasets.utils
from torchgeo.datasets import BoundingBox, Chesapeake13, ChesapeakeCVPR, ZipDataset
from torchgeo.transforms import Identity
def download_url(url: str, root: str, *args: str, **kwargs: str) -> None:
@ -42,7 +42,7 @@ class TestChesapeake13:
plt, "show", lambda *args: None
)
root = str(tmp_path)
transforms = Identity()
transforms = nn.Identity() # type: ignore[attr-defined]
return Chesapeake13(root, transforms=transforms, download=True, checksum=True)
def test_getitem(self, dataset: Chesapeake13) -> None:
@ -108,7 +108,7 @@ class TestChesapeakeCVPR:
["de_1m_2013_extended-debuffered-test_tiles", "spatial_index.geojson"],
)
root = str(tmp_path)
transforms = Identity()
transforms = nn.Identity() # type: ignore[attr-defined]
return ChesapeakeCVPR(
root,
splits=["de-test"],

Просмотреть файл

@ -8,6 +8,7 @@ from typing import Generator
import pytest
import torch
import torch.nn as nn
from _pytest.fixtures import SubRequest
from _pytest.monkeypatch import MonkeyPatch
from torch.utils.data import ConcatDataset
@ -15,7 +16,6 @@ from torch.utils.data import ConcatDataset
import torchgeo.datasets.utils
from torchgeo.datasets import COWCCounting, COWCDetection
from torchgeo.datasets.cowc import COWC
from torchgeo.transforms import Identity
def download_url(url: str, root: str, *args: str, **kwargs: str) -> None:
@ -56,7 +56,7 @@ class TestCOWCCounting:
monkeypatch.setattr(COWCCounting, "md5s", md5s) # type: ignore[attr-defined]
root = str(tmp_path)
split = request.param
transforms = Identity()
transforms = nn.Identity() # type: ignore[attr-defined]
return COWCCounting(root, split, transforms, download=True, checksum=True)
def test_getitem(self, dataset: COWC) -> None:
@ -114,7 +114,7 @@ class TestCOWCDetection:
monkeypatch.setattr(COWCDetection, "md5s", md5s) # type: ignore[attr-defined]
root = str(tmp_path)
split = "train"
transforms = Identity()
transforms = nn.Identity() # type: ignore[attr-defined]
return COWCDetection(root, split, transforms, download=True, checksum=True)
def test_getitem(self, dataset: COWC) -> None:

Просмотреть файл

@ -9,11 +9,11 @@ from typing import Generator
import pytest
import torch
import torch.nn as nn
from _pytest.monkeypatch import MonkeyPatch
from torch.utils.data import ConcatDataset
from torchgeo.datasets import CV4AKenyaCropType
from torchgeo.transforms import Identity
class Dataset:
@ -55,7 +55,7 @@ class TestCV4AKenyaCropType:
CV4AKenyaCropType, "dates", ["20190606"]
)
root = str(tmp_path)
transforms = Identity()
transforms = nn.Identity() # type: ignore[attr-defined]
return CV4AKenyaCropType(
root,
transforms=transforms,

Просмотреть файл

@ -9,12 +9,12 @@ from typing import Generator
import pytest
import torch
import torch.nn as nn
from _pytest.fixtures import SubRequest
from _pytest.monkeypatch import MonkeyPatch
from torch.utils.data import ConcatDataset
from torchgeo.datasets import TropicalCycloneWindEstimation
from torchgeo.transforms import Identity
class Dataset:
@ -57,7 +57,7 @@ class TestTropicalCycloneWindEstimation:
)
root = str(tmp_path)
split = request.param
transforms = Identity()
transforms = nn.Identity() # type: ignore[attr-defined]
return TropicalCycloneWindEstimation(
root, split, transforms, download=True, api_key="", checksum=True
)

Просмотреть файл

@ -8,12 +8,12 @@ from typing import Generator
import pytest
import torch
import torch.nn as nn
from _pytest.fixtures import SubRequest
from _pytest.monkeypatch import MonkeyPatch
import torchgeo.datasets.utils
from torchgeo.datasets import ETCI2021
from torchgeo.transforms import Identity
def download_url(url: str, root: str, *args: str) -> None:
@ -55,7 +55,7 @@ class TestETCI2021:
monkeypatch.setattr(ETCI2021, "metadata", metadata) # type: ignore[attr-defined] # noqa: E501
root = str(tmp_path)
split = request.param
transforms = Identity()
transforms = nn.Identity() # type: ignore[attr-defined]
return ETCI2021(root, split, transforms, download=True, checksum=True)
def test_getitem(self, dataset: ETCI2021) -> None:

Просмотреть файл

@ -7,6 +7,7 @@ from typing import Dict
import pytest
import torch
import torch.nn as nn
from _pytest.fixtures import SubRequest
from rasterio.crs import CRS
from torch.utils.data import ConcatDataset
@ -21,7 +22,6 @@ from torchgeo.datasets import (
VisionDataset,
ZipDataset,
)
from torchgeo.transforms import Identity
class CustomGeoDataset(GeoDataset):
@ -106,7 +106,7 @@ class TestRasterDataset:
root = os.path.join("tests", "data", "landsat8")
bands = ["B1", "B2", "B3", "B4", "B5", "B6", "B7"]
crs = CRS.from_epsg(3005)
transforms = Identity()
transforms = nn.Identity() # type: ignore[attr-defined]
cache = request.param
return Landsat8(root, bands=bands, crs=crs, transforms=transforms, cache=cache)

Просмотреть файл

@ -8,12 +8,12 @@ from typing import Generator
import pytest
import torch
import torch.nn as nn
from _pytest.fixtures import SubRequest
from _pytest.monkeypatch import MonkeyPatch
import torchgeo.datasets.utils
from torchgeo.datasets import GID15
from torchgeo.transforms import Identity
def download_url(url: str, root: str, *args: str) -> None:
@ -37,7 +37,7 @@ class TestGID15:
monkeypatch.setattr(GID15, "url", url) # type: ignore[attr-defined]
root = str(tmp_path)
split = request.param
transforms = Identity()
transforms = nn.Identity() # type: ignore[attr-defined]
return GID15(root, split, transforms, download=True, checksum=True)
def test_getitem(self, dataset: GID15) -> None:

Просмотреть файл

@ -8,13 +8,13 @@ from typing import Generator
import pytest
import torch
import torch.nn as nn
from _pytest.fixtures import SubRequest
from _pytest.monkeypatch import MonkeyPatch
from torch.utils.data import ConcatDataset
import torchgeo.datasets.utils
from torchgeo.datasets import LandCoverAI
from torchgeo.transforms import Identity
def download_url(url: str, root: str, *args: str) -> None:
@ -40,7 +40,7 @@ class TestLandCoverAI:
monkeypatch.setattr(LandCoverAI, "sha256", sha256) # type: ignore[attr-defined]
root = str(tmp_path)
split = request.param
transforms = Identity()
transforms = nn.Identity() # type: ignore[attr-defined]
return LandCoverAI(root, split, transforms, download=True, checksum=True)
def test_getitem(self, dataset: LandCoverAI) -> None:

Просмотреть файл

@ -8,11 +8,11 @@ from typing import Generator
import matplotlib.pyplot as plt
import pytest
import torch
import torch.nn as nn
from _pytest.monkeypatch import MonkeyPatch
from rasterio.crs import CRS
from torchgeo.datasets import BoundingBox, Landsat8, ZipDataset
from torchgeo.transforms import Identity
class TestLandsat8:
@ -23,7 +23,7 @@ class TestLandsat8:
)
root = os.path.join("tests", "data", "landsat8")
bands = ["B1", "B2", "B3", "B4", "B5", "B6", "B7"]
transforms = Identity()
transforms = nn.Identity() # type: ignore[attr-defined]
return Landsat8(root, bands=bands, transforms=transforms)
def test_separate_files(self, dataset: Landsat8) -> None:

Просмотреть файл

@ -8,12 +8,12 @@ from typing import Generator
import pytest
import torch
import torch.nn as nn
from _pytest.fixtures import SubRequest
from _pytest.monkeypatch import MonkeyPatch
import torchgeo.datasets.utils
from torchgeo.datasets import LEVIRCDPlus
from torchgeo.transforms import Identity
def download_url(url: str, root: str, *args: str) -> None:
@ -37,7 +37,7 @@ class TestLEVIRCDPlus:
monkeypatch.setattr(LEVIRCDPlus, "url", url) # type: ignore[attr-defined]
root = str(tmp_path)
split = request.param
transforms = Identity()
transforms = nn.Identity() # type: ignore[attr-defined]
return LEVIRCDPlus(root, split, transforms, download=True, checksum=True)
def test_getitem(self, dataset: LEVIRCDPlus) -> None:

Просмотреть файл

@ -8,11 +8,11 @@ from typing import Generator
import matplotlib.pyplot as plt
import pytest
import torch
import torch.nn as nn
from _pytest.monkeypatch import MonkeyPatch
from rasterio.crs import CRS
from torchgeo.datasets import NAIP, BoundingBox, ZipDataset
from torchgeo.transforms import Identity
class TestNAIP:
@ -22,7 +22,7 @@ class TestNAIP:
plt, "show", lambda *args: None
)
root = os.path.join("tests", "data", "naip")
transforms = Identity()
transforms = nn.Identity() # type: ignore[attr-defined]
return NAIP(root, transforms=transforms)
def test_getitem(self, dataset: NAIP) -> None:

Просмотреть файл

@ -9,13 +9,13 @@ from typing import Generator
import pytest
import torch
import torch.nn as nn
from _pytest.fixtures import SubRequest
from _pytest.monkeypatch import MonkeyPatch
from torch.utils.data import ConcatDataset
import torchgeo.datasets.utils
from torchgeo.datasets import VHR10
from torchgeo.transforms import Identity
pytest.importorskip("rarfile")
pytest.importorskip("pycocotools")
@ -51,7 +51,7 @@ class TestVHR10:
monkeypatch.setitem(VHR10.target_meta, "md5", md5) # type: ignore[attr-defined]
root = str(tmp_path)
split = request.param
transforms = Identity()
transforms = nn.Identity() # type: ignore[attr-defined]
return VHR10(root, split, transforms, download=True, checksum=True)
def test_getitem(self, dataset: VHR10) -> None:

Просмотреть файл

@ -7,12 +7,12 @@ from typing import Generator
import pytest
import torch
import torch.nn as nn
from _pytest.fixtures import SubRequest
from _pytest.monkeypatch import MonkeyPatch
from torch.utils.data import ConcatDataset
from torchgeo.datasets import SEN12MS
from torchgeo.transforms import Identity
class TestSEN12MS:
@ -38,7 +38,7 @@ class TestSEN12MS:
monkeypatch.setattr(SEN12MS, "md5s", md5s) # type: ignore[attr-defined]
root = os.path.join("tests", "data", "sen12ms")
split = request.param
transforms = Identity()
transforms = nn.Identity() # type: ignore[attr-defined]
return SEN12MS(root, split, transforms=transforms, checksum=True)
def test_getitem(self, dataset: SEN12MS) -> None:

Просмотреть файл

@ -6,10 +6,10 @@ from pathlib import Path
import pytest
import torch
import torch.nn as nn
from rasterio.crs import CRS
from torchgeo.datasets import BoundingBox, Sentinel2, ZipDataset
from torchgeo.transforms import Identity
class TestSentinel2:
@ -17,7 +17,7 @@ class TestSentinel2:
def dataset(self) -> Sentinel2:
root = os.path.join("tests", "data", "sentinel2")
bands = ["B01", "B02", "B03", "B04", "B05", "B06", "B07", "B08", "B09", "B11"]
transforms = Identity()
transforms = nn.Identity() # type: ignore[attr-defined]
return Sentinel2(root, bands=bands, transforms=transforms)
def test_separate_files(self, dataset: Sentinel2) -> None:

Просмотреть файл

@ -7,11 +7,11 @@ from typing import Generator
import pytest
import torch
import torch.nn as nn
from _pytest.fixtures import SubRequest
from _pytest.monkeypatch import MonkeyPatch
from torchgeo.datasets import So2Sat
from torchgeo.transforms import Identity
class TestSo2Sat:
@ -28,7 +28,7 @@ class TestSo2Sat:
monkeypatch.setattr(So2Sat, "md5s", md5s) # type: ignore[attr-defined]
root = os.path.join("tests", "data", "so2sat")
split = request.param
transforms = Identity()
transforms = nn.Identity() # type: ignore[attr-defined]
return So2Sat(root, split, transforms, checksum=True)
def test_getitem(self, dataset: So2Sat) -> None:

Просмотреть файл

@ -9,11 +9,11 @@ from typing import Generator
import pytest
import torch
import torch.nn as nn
from _pytest.fixtures import SubRequest
from _pytest.monkeypatch import MonkeyPatch
from torchgeo.datasets import SpaceNet1, SpaceNet2, SpaceNet4
from torchgeo.transforms import Identity
TEST_DATA_DIR = "tests/data/spacenet"
@ -51,7 +51,7 @@ class TestSpaceNet1:
SpaceNet1, "collection_md5_dict", test_md5
)
root = str(tmp_path)
transforms = Identity()
transforms = nn.Identity() # type: ignore[attr-defined]
return SpaceNet1(
root,
image=request.param,
@ -104,7 +104,7 @@ class TestSpaceNet2:
SpaceNet2, "collection_md5_dict", test_md5
)
root = str(tmp_path)
transforms = Identity()
transforms = nn.Identity() # type: ignore[attr-defined]
return SpaceNet2(
root,
image=request.param,
@ -165,7 +165,7 @@ class TestSpaceNet4:
SpaceNet4, "collection_md5_dict", test_md5
)
root = str(tmp_path)
transforms = Identity()
transforms = nn.Identity() # type: ignore[attr-defined]
return SpaceNet4(
root,
image=request.param,

Просмотреть файл

@ -9,11 +9,11 @@ from typing import Any, Generator
import pytest
import torch
import torch.nn as nn
from _pytest.monkeypatch import MonkeyPatch
import torchgeo.datasets.utils
from torchgeo.datasets import ZueriCrop
from torchgeo.transforms import Identity
pytest.importorskip("h5py")
@ -41,7 +41,7 @@ class TestZueriCrop:
monkeypatch.setattr(ZueriCrop, "urls", urls) # type: ignore[attr-defined]
monkeypatch.setattr(ZueriCrop, "md5s", md5s) # type: ignore[attr-defined]
root = str(tmp_path)
transforms = Identity()
transforms = nn.Identity() # type: ignore[attr-defined]
return ZueriCrop(root, transforms, download=True, checksum=True)
@pytest.fixture

Просмотреть файл

@ -105,46 +105,6 @@ def assert_matching(output: Dict[str, Tensor], expected: Dict[str, Tensor]) -> N
assert equal, err
def test_random_horizontal_flip(sample: Dict[str, Tensor]) -> None:
tr = transforms.RandomHorizontalFlip(p=1)
output = tr(sample)
expected = {
"image": torch.tensor( # type: ignore[attr-defined]
[[[3, 2, 1], [6, 5, 4], [9, 8, 7]]]
),
"mask": torch.tensor( # type: ignore[attr-defined]
[[1, 0, 0], [1, 1, 0], [1, 1, 1]]
),
"boxes": torch.tensor( # type: ignore[attr-defined]
[[1, 0, 3, 2], [0, 1, 2, 3]]
),
}
assert_matching(output, expected)
def test_random_vertical_flip(sample: Dict[str, Tensor]) -> None:
tr = transforms.RandomVerticalFlip(p=1)
output = tr(sample)
expected = {
"image": torch.tensor( # type: ignore[attr-defined]
[[[7, 8, 9], [4, 5, 6], [1, 2, 3]]]
),
"mask": torch.tensor( # type: ignore[attr-defined]
[[1, 1, 1], [0, 1, 1], [0, 0, 1]]
),
"boxes": torch.tensor( # type: ignore[attr-defined]
[[0, 1, 2, 3], [1, 0, 3, 2]]
),
}
assert_matching(output, expected)
def test_identity(sample: Dict[str, Tensor]) -> None:
tr = transforms.Identity()
output = tr(sample)
assert_matching(output, sample)
def test_augmentation_sequential_gray(batch_gray: Dict[str, Tensor]) -> None:
expected = {
"image": torch.tensor( # type: ignore[attr-defined]

Просмотреть файл

@ -4,12 +4,7 @@
"""TorchGeo transforms."""
from .indices import AppendNDBI, AppendNDSI, AppendNDVI, AppendNDWI
from .transforms import (
AugmentationSequential,
Identity,
RandomHorizontalFlip,
RandomVerticalFlip,
)
from .transforms import AugmentationSequential
__all__ = (
"AppendNDBI",
@ -17,9 +12,6 @@ __all__ = (
"AppendNDVI",
"AppendNDWI",
"AugmentationSequential",
"Identity",
"RandomHorizontalFlip",
"RandomVerticalFlip",
)
# https://stackoverflow.com/questions/40018681

Просмотреть файл

@ -15,91 +15,6 @@ from torch.nn import Module # type: ignore[attr-defined]
Module.__module__ = "torch.nn"
class RandomHorizontalFlip(Module): # type: ignore[misc,name-defined]
"""Horizontally flip the given sample randomly with a given probability."""
def __init__(self, p: float = 0.5) -> None:
"""Initialize a new transform instance.
Args:
p: probability of the sample being flipped
"""
super().__init__()
self.p = p
def forward(self, sample: Dict[str, Tensor]) -> Dict[str, Tensor]:
"""Randomly flip the image and target tensors.
Args:
sample: a single data sample
Returns:
a possibly flipped sample
"""
if torch.rand(1) < self.p:
if "image" in sample:
sample["image"] = sample["image"].flip(-1)
if "boxes" in sample:
height, width = sample["image"].shape[-2:]
sample["boxes"][:, [0, 2]] = width - sample["boxes"][:, [2, 0]]
if "mask" in sample:
sample["mask"] = sample["mask"].flip(-1)
return sample
class RandomVerticalFlip(Module): # type: ignore[misc,name-defined]
"""Vertically flip the given sample randomly with a given probability."""
def __init__(self, p: float = 0.5) -> None:
"""Initialize a new transform instance.
Args:
p: probability of the sample being flipped
"""
super().__init__()
self.p = p
def forward(self, sample: Dict[str, Tensor]) -> Dict[str, Tensor]:
"""Randomly flip the image and target tensors.
Args:
sample: a single data sample
Returns:
a possibly flipped sample
"""
if torch.rand(1) < self.p:
if "image" in sample:
sample["image"] = sample["image"].flip(-2)
if "boxes" in sample:
height, width = sample["image"].shape[-2:]
sample["boxes"][:, [1, 3]] = height - sample["boxes"][:, [3, 1]]
if "mask" in sample:
sample["mask"] = sample["mask"].flip(-2)
return sample
class Identity(Module): # type: ignore[misc,name-defined]
"""Identity function used for testing purposes."""
def forward(self, sample: Dict[str, Tensor]) -> Dict[str, Tensor]:
"""Do nothing.
Args:
sample: the input
Returns:
the unchanged input
"""
return sample
class AugmentationSequential(Module): # type: ignore[misc]
"""Wrapper around kornia AugmentationSequential to handle input dicts."""