зеркало из https://github.com/microsoft/torchgeo.git
Remove non-geospatial transforms (#198)
This commit is contained in:
Родитель
24c3f70f5f
Коммит
173de92ac8
|
@ -9,11 +9,11 @@ from typing import Any, Generator
|
|||
|
||||
import pytest
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
from _pytest.monkeypatch import MonkeyPatch
|
||||
|
||||
import torchgeo.datasets.utils
|
||||
from torchgeo.datasets import ADVANCE
|
||||
from torchgeo.transforms import Identity
|
||||
|
||||
|
||||
def download_url(url: str, root: str, *args: str) -> None:
|
||||
|
@ -40,7 +40,7 @@ class TestADVANCE:
|
|||
monkeypatch.setattr(ADVANCE, "urls", urls) # type: ignore[attr-defined]
|
||||
monkeypatch.setattr(ADVANCE, "md5s", md5s) # type: ignore[attr-defined]
|
||||
root = str(tmp_path)
|
||||
transforms = Identity()
|
||||
transforms = nn.Identity() # type: ignore[attr-defined]
|
||||
return ADVANCE(root, transforms, download=True, checksum=True)
|
||||
|
||||
@pytest.fixture
|
||||
|
|
|
@ -9,11 +9,11 @@ from typing import Generator
|
|||
|
||||
import pytest
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
from _pytest.monkeypatch import MonkeyPatch
|
||||
from torch.utils.data import ConcatDataset
|
||||
|
||||
from torchgeo.datasets import BeninSmallHolderCashews
|
||||
from torchgeo.transforms import Identity
|
||||
|
||||
|
||||
class Dataset:
|
||||
|
@ -50,7 +50,7 @@ class TestBeninSmallHolderCashews:
|
|||
BeninSmallHolderCashews, "dates", ("2019_11_05",)
|
||||
)
|
||||
root = str(tmp_path)
|
||||
transforms = Identity()
|
||||
transforms = nn.Identity() # type: ignore[attr-defined]
|
||||
return BeninSmallHolderCashews(
|
||||
root,
|
||||
transforms=transforms,
|
||||
|
|
|
@ -9,12 +9,12 @@ from typing import Generator
|
|||
import matplotlib.pyplot as plt
|
||||
import pytest
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
from _pytest.monkeypatch import MonkeyPatch
|
||||
from rasterio.crs import CRS
|
||||
|
||||
import torchgeo.datasets.utils
|
||||
from torchgeo.datasets import BoundingBox, CanadianBuildingFootprints, ZipDataset
|
||||
from torchgeo.transforms import Identity
|
||||
|
||||
|
||||
def download_url(url: str, root: str, *args: str) -> None:
|
||||
|
@ -57,7 +57,7 @@ class TestCanadianBuildingFootprints:
|
|||
plt, "show", lambda *args: None
|
||||
)
|
||||
root = str(tmp_path)
|
||||
transforms = Identity()
|
||||
transforms = nn.Identity() # type: ignore[attr-defined]
|
||||
return CanadianBuildingFootprints(
|
||||
root, res=0.1, transforms=transforms, download=True, checksum=True
|
||||
)
|
||||
|
|
|
@ -11,12 +11,12 @@ from typing import Generator
|
|||
import matplotlib.pyplot as plt
|
||||
import pytest
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
from _pytest.monkeypatch import MonkeyPatch
|
||||
from rasterio.crs import CRS
|
||||
|
||||
import torchgeo.datasets.utils
|
||||
from torchgeo.datasets import CDL, BoundingBox, ZipDataset
|
||||
from torchgeo.transforms import Identity
|
||||
|
||||
|
||||
def download_url(url: str, root: str, *args: str, **kwargs: str) -> None:
|
||||
|
@ -44,7 +44,7 @@ class TestCDL:
|
|||
plt, "show", lambda *args: None
|
||||
)
|
||||
root = str(tmp_path)
|
||||
transforms = Identity()
|
||||
transforms = nn.Identity() # type: ignore[attr-defined]
|
||||
return CDL(root, transforms=transforms, download=True, checksum=True)
|
||||
|
||||
def test_getitem(self, dataset: CDL) -> None:
|
||||
|
|
|
@ -9,13 +9,13 @@ from typing import Generator
|
|||
import matplotlib.pyplot as plt
|
||||
import pytest
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
from _pytest.fixtures import SubRequest
|
||||
from _pytest.monkeypatch import MonkeyPatch
|
||||
from rasterio.crs import CRS
|
||||
|
||||
import torchgeo.datasets.utils
|
||||
from torchgeo.datasets import BoundingBox, Chesapeake13, ChesapeakeCVPR, ZipDataset
|
||||
from torchgeo.transforms import Identity
|
||||
|
||||
|
||||
def download_url(url: str, root: str, *args: str, **kwargs: str) -> None:
|
||||
|
@ -42,7 +42,7 @@ class TestChesapeake13:
|
|||
plt, "show", lambda *args: None
|
||||
)
|
||||
root = str(tmp_path)
|
||||
transforms = Identity()
|
||||
transforms = nn.Identity() # type: ignore[attr-defined]
|
||||
return Chesapeake13(root, transforms=transforms, download=True, checksum=True)
|
||||
|
||||
def test_getitem(self, dataset: Chesapeake13) -> None:
|
||||
|
@ -108,7 +108,7 @@ class TestChesapeakeCVPR:
|
|||
["de_1m_2013_extended-debuffered-test_tiles", "spatial_index.geojson"],
|
||||
)
|
||||
root = str(tmp_path)
|
||||
transforms = Identity()
|
||||
transforms = nn.Identity() # type: ignore[attr-defined]
|
||||
return ChesapeakeCVPR(
|
||||
root,
|
||||
splits=["de-test"],
|
||||
|
|
|
@ -8,6 +8,7 @@ from typing import Generator
|
|||
|
||||
import pytest
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
from _pytest.fixtures import SubRequest
|
||||
from _pytest.monkeypatch import MonkeyPatch
|
||||
from torch.utils.data import ConcatDataset
|
||||
|
@ -15,7 +16,6 @@ from torch.utils.data import ConcatDataset
|
|||
import torchgeo.datasets.utils
|
||||
from torchgeo.datasets import COWCCounting, COWCDetection
|
||||
from torchgeo.datasets.cowc import COWC
|
||||
from torchgeo.transforms import Identity
|
||||
|
||||
|
||||
def download_url(url: str, root: str, *args: str, **kwargs: str) -> None:
|
||||
|
@ -56,7 +56,7 @@ class TestCOWCCounting:
|
|||
monkeypatch.setattr(COWCCounting, "md5s", md5s) # type: ignore[attr-defined]
|
||||
root = str(tmp_path)
|
||||
split = request.param
|
||||
transforms = Identity()
|
||||
transforms = nn.Identity() # type: ignore[attr-defined]
|
||||
return COWCCounting(root, split, transforms, download=True, checksum=True)
|
||||
|
||||
def test_getitem(self, dataset: COWC) -> None:
|
||||
|
@ -114,7 +114,7 @@ class TestCOWCDetection:
|
|||
monkeypatch.setattr(COWCDetection, "md5s", md5s) # type: ignore[attr-defined]
|
||||
root = str(tmp_path)
|
||||
split = "train"
|
||||
transforms = Identity()
|
||||
transforms = nn.Identity() # type: ignore[attr-defined]
|
||||
return COWCDetection(root, split, transforms, download=True, checksum=True)
|
||||
|
||||
def test_getitem(self, dataset: COWC) -> None:
|
||||
|
|
|
@ -9,11 +9,11 @@ from typing import Generator
|
|||
|
||||
import pytest
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
from _pytest.monkeypatch import MonkeyPatch
|
||||
from torch.utils.data import ConcatDataset
|
||||
|
||||
from torchgeo.datasets import CV4AKenyaCropType
|
||||
from torchgeo.transforms import Identity
|
||||
|
||||
|
||||
class Dataset:
|
||||
|
@ -55,7 +55,7 @@ class TestCV4AKenyaCropType:
|
|||
CV4AKenyaCropType, "dates", ["20190606"]
|
||||
)
|
||||
root = str(tmp_path)
|
||||
transforms = Identity()
|
||||
transforms = nn.Identity() # type: ignore[attr-defined]
|
||||
return CV4AKenyaCropType(
|
||||
root,
|
||||
transforms=transforms,
|
||||
|
|
|
@ -9,12 +9,12 @@ from typing import Generator
|
|||
|
||||
import pytest
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
from _pytest.fixtures import SubRequest
|
||||
from _pytest.monkeypatch import MonkeyPatch
|
||||
from torch.utils.data import ConcatDataset
|
||||
|
||||
from torchgeo.datasets import TropicalCycloneWindEstimation
|
||||
from torchgeo.transforms import Identity
|
||||
|
||||
|
||||
class Dataset:
|
||||
|
@ -57,7 +57,7 @@ class TestTropicalCycloneWindEstimation:
|
|||
)
|
||||
root = str(tmp_path)
|
||||
split = request.param
|
||||
transforms = Identity()
|
||||
transforms = nn.Identity() # type: ignore[attr-defined]
|
||||
return TropicalCycloneWindEstimation(
|
||||
root, split, transforms, download=True, api_key="", checksum=True
|
||||
)
|
||||
|
|
|
@ -8,12 +8,12 @@ from typing import Generator
|
|||
|
||||
import pytest
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
from _pytest.fixtures import SubRequest
|
||||
from _pytest.monkeypatch import MonkeyPatch
|
||||
|
||||
import torchgeo.datasets.utils
|
||||
from torchgeo.datasets import ETCI2021
|
||||
from torchgeo.transforms import Identity
|
||||
|
||||
|
||||
def download_url(url: str, root: str, *args: str) -> None:
|
||||
|
@ -55,7 +55,7 @@ class TestETCI2021:
|
|||
monkeypatch.setattr(ETCI2021, "metadata", metadata) # type: ignore[attr-defined] # noqa: E501
|
||||
root = str(tmp_path)
|
||||
split = request.param
|
||||
transforms = Identity()
|
||||
transforms = nn.Identity() # type: ignore[attr-defined]
|
||||
return ETCI2021(root, split, transforms, download=True, checksum=True)
|
||||
|
||||
def test_getitem(self, dataset: ETCI2021) -> None:
|
||||
|
|
|
@ -7,6 +7,7 @@ from typing import Dict
|
|||
|
||||
import pytest
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
from _pytest.fixtures import SubRequest
|
||||
from rasterio.crs import CRS
|
||||
from torch.utils.data import ConcatDataset
|
||||
|
@ -21,7 +22,6 @@ from torchgeo.datasets import (
|
|||
VisionDataset,
|
||||
ZipDataset,
|
||||
)
|
||||
from torchgeo.transforms import Identity
|
||||
|
||||
|
||||
class CustomGeoDataset(GeoDataset):
|
||||
|
@ -106,7 +106,7 @@ class TestRasterDataset:
|
|||
root = os.path.join("tests", "data", "landsat8")
|
||||
bands = ["B1", "B2", "B3", "B4", "B5", "B6", "B7"]
|
||||
crs = CRS.from_epsg(3005)
|
||||
transforms = Identity()
|
||||
transforms = nn.Identity() # type: ignore[attr-defined]
|
||||
cache = request.param
|
||||
return Landsat8(root, bands=bands, crs=crs, transforms=transforms, cache=cache)
|
||||
|
||||
|
|
|
@ -8,12 +8,12 @@ from typing import Generator
|
|||
|
||||
import pytest
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
from _pytest.fixtures import SubRequest
|
||||
from _pytest.monkeypatch import MonkeyPatch
|
||||
|
||||
import torchgeo.datasets.utils
|
||||
from torchgeo.datasets import GID15
|
||||
from torchgeo.transforms import Identity
|
||||
|
||||
|
||||
def download_url(url: str, root: str, *args: str) -> None:
|
||||
|
@ -37,7 +37,7 @@ class TestGID15:
|
|||
monkeypatch.setattr(GID15, "url", url) # type: ignore[attr-defined]
|
||||
root = str(tmp_path)
|
||||
split = request.param
|
||||
transforms = Identity()
|
||||
transforms = nn.Identity() # type: ignore[attr-defined]
|
||||
return GID15(root, split, transforms, download=True, checksum=True)
|
||||
|
||||
def test_getitem(self, dataset: GID15) -> None:
|
||||
|
|
|
@ -8,13 +8,13 @@ from typing import Generator
|
|||
|
||||
import pytest
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
from _pytest.fixtures import SubRequest
|
||||
from _pytest.monkeypatch import MonkeyPatch
|
||||
from torch.utils.data import ConcatDataset
|
||||
|
||||
import torchgeo.datasets.utils
|
||||
from torchgeo.datasets import LandCoverAI
|
||||
from torchgeo.transforms import Identity
|
||||
|
||||
|
||||
def download_url(url: str, root: str, *args: str) -> None:
|
||||
|
@ -40,7 +40,7 @@ class TestLandCoverAI:
|
|||
monkeypatch.setattr(LandCoverAI, "sha256", sha256) # type: ignore[attr-defined]
|
||||
root = str(tmp_path)
|
||||
split = request.param
|
||||
transforms = Identity()
|
||||
transforms = nn.Identity() # type: ignore[attr-defined]
|
||||
return LandCoverAI(root, split, transforms, download=True, checksum=True)
|
||||
|
||||
def test_getitem(self, dataset: LandCoverAI) -> None:
|
||||
|
|
|
@ -8,11 +8,11 @@ from typing import Generator
|
|||
import matplotlib.pyplot as plt
|
||||
import pytest
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
from _pytest.monkeypatch import MonkeyPatch
|
||||
from rasterio.crs import CRS
|
||||
|
||||
from torchgeo.datasets import BoundingBox, Landsat8, ZipDataset
|
||||
from torchgeo.transforms import Identity
|
||||
|
||||
|
||||
class TestLandsat8:
|
||||
|
@ -23,7 +23,7 @@ class TestLandsat8:
|
|||
)
|
||||
root = os.path.join("tests", "data", "landsat8")
|
||||
bands = ["B1", "B2", "B3", "B4", "B5", "B6", "B7"]
|
||||
transforms = Identity()
|
||||
transforms = nn.Identity() # type: ignore[attr-defined]
|
||||
return Landsat8(root, bands=bands, transforms=transforms)
|
||||
|
||||
def test_separate_files(self, dataset: Landsat8) -> None:
|
||||
|
|
|
@ -8,12 +8,12 @@ from typing import Generator
|
|||
|
||||
import pytest
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
from _pytest.fixtures import SubRequest
|
||||
from _pytest.monkeypatch import MonkeyPatch
|
||||
|
||||
import torchgeo.datasets.utils
|
||||
from torchgeo.datasets import LEVIRCDPlus
|
||||
from torchgeo.transforms import Identity
|
||||
|
||||
|
||||
def download_url(url: str, root: str, *args: str) -> None:
|
||||
|
@ -37,7 +37,7 @@ class TestLEVIRCDPlus:
|
|||
monkeypatch.setattr(LEVIRCDPlus, "url", url) # type: ignore[attr-defined]
|
||||
root = str(tmp_path)
|
||||
split = request.param
|
||||
transforms = Identity()
|
||||
transforms = nn.Identity() # type: ignore[attr-defined]
|
||||
return LEVIRCDPlus(root, split, transforms, download=True, checksum=True)
|
||||
|
||||
def test_getitem(self, dataset: LEVIRCDPlus) -> None:
|
||||
|
|
|
@ -8,11 +8,11 @@ from typing import Generator
|
|||
import matplotlib.pyplot as plt
|
||||
import pytest
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
from _pytest.monkeypatch import MonkeyPatch
|
||||
from rasterio.crs import CRS
|
||||
|
||||
from torchgeo.datasets import NAIP, BoundingBox, ZipDataset
|
||||
from torchgeo.transforms import Identity
|
||||
|
||||
|
||||
class TestNAIP:
|
||||
|
@ -22,7 +22,7 @@ class TestNAIP:
|
|||
plt, "show", lambda *args: None
|
||||
)
|
||||
root = os.path.join("tests", "data", "naip")
|
||||
transforms = Identity()
|
||||
transforms = nn.Identity() # type: ignore[attr-defined]
|
||||
return NAIP(root, transforms=transforms)
|
||||
|
||||
def test_getitem(self, dataset: NAIP) -> None:
|
||||
|
|
|
@ -9,13 +9,13 @@ from typing import Generator
|
|||
|
||||
import pytest
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
from _pytest.fixtures import SubRequest
|
||||
from _pytest.monkeypatch import MonkeyPatch
|
||||
from torch.utils.data import ConcatDataset
|
||||
|
||||
import torchgeo.datasets.utils
|
||||
from torchgeo.datasets import VHR10
|
||||
from torchgeo.transforms import Identity
|
||||
|
||||
pytest.importorskip("rarfile")
|
||||
pytest.importorskip("pycocotools")
|
||||
|
@ -51,7 +51,7 @@ class TestVHR10:
|
|||
monkeypatch.setitem(VHR10.target_meta, "md5", md5) # type: ignore[attr-defined]
|
||||
root = str(tmp_path)
|
||||
split = request.param
|
||||
transforms = Identity()
|
||||
transforms = nn.Identity() # type: ignore[attr-defined]
|
||||
return VHR10(root, split, transforms, download=True, checksum=True)
|
||||
|
||||
def test_getitem(self, dataset: VHR10) -> None:
|
||||
|
|
|
@ -7,12 +7,12 @@ from typing import Generator
|
|||
|
||||
import pytest
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
from _pytest.fixtures import SubRequest
|
||||
from _pytest.monkeypatch import MonkeyPatch
|
||||
from torch.utils.data import ConcatDataset
|
||||
|
||||
from torchgeo.datasets import SEN12MS
|
||||
from torchgeo.transforms import Identity
|
||||
|
||||
|
||||
class TestSEN12MS:
|
||||
|
@ -38,7 +38,7 @@ class TestSEN12MS:
|
|||
monkeypatch.setattr(SEN12MS, "md5s", md5s) # type: ignore[attr-defined]
|
||||
root = os.path.join("tests", "data", "sen12ms")
|
||||
split = request.param
|
||||
transforms = Identity()
|
||||
transforms = nn.Identity() # type: ignore[attr-defined]
|
||||
return SEN12MS(root, split, transforms=transforms, checksum=True)
|
||||
|
||||
def test_getitem(self, dataset: SEN12MS) -> None:
|
||||
|
|
|
@ -6,10 +6,10 @@ from pathlib import Path
|
|||
|
||||
import pytest
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
from rasterio.crs import CRS
|
||||
|
||||
from torchgeo.datasets import BoundingBox, Sentinel2, ZipDataset
|
||||
from torchgeo.transforms import Identity
|
||||
|
||||
|
||||
class TestSentinel2:
|
||||
|
@ -17,7 +17,7 @@ class TestSentinel2:
|
|||
def dataset(self) -> Sentinel2:
|
||||
root = os.path.join("tests", "data", "sentinel2")
|
||||
bands = ["B01", "B02", "B03", "B04", "B05", "B06", "B07", "B08", "B09", "B11"]
|
||||
transforms = Identity()
|
||||
transforms = nn.Identity() # type: ignore[attr-defined]
|
||||
return Sentinel2(root, bands=bands, transforms=transforms)
|
||||
|
||||
def test_separate_files(self, dataset: Sentinel2) -> None:
|
||||
|
|
|
@ -7,11 +7,11 @@ from typing import Generator
|
|||
|
||||
import pytest
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
from _pytest.fixtures import SubRequest
|
||||
from _pytest.monkeypatch import MonkeyPatch
|
||||
|
||||
from torchgeo.datasets import So2Sat
|
||||
from torchgeo.transforms import Identity
|
||||
|
||||
|
||||
class TestSo2Sat:
|
||||
|
@ -28,7 +28,7 @@ class TestSo2Sat:
|
|||
monkeypatch.setattr(So2Sat, "md5s", md5s) # type: ignore[attr-defined]
|
||||
root = os.path.join("tests", "data", "so2sat")
|
||||
split = request.param
|
||||
transforms = Identity()
|
||||
transforms = nn.Identity() # type: ignore[attr-defined]
|
||||
return So2Sat(root, split, transforms, checksum=True)
|
||||
|
||||
def test_getitem(self, dataset: So2Sat) -> None:
|
||||
|
|
|
@ -9,11 +9,11 @@ from typing import Generator
|
|||
|
||||
import pytest
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
from _pytest.fixtures import SubRequest
|
||||
from _pytest.monkeypatch import MonkeyPatch
|
||||
|
||||
from torchgeo.datasets import SpaceNet1, SpaceNet2, SpaceNet4
|
||||
from torchgeo.transforms import Identity
|
||||
|
||||
TEST_DATA_DIR = "tests/data/spacenet"
|
||||
|
||||
|
@ -51,7 +51,7 @@ class TestSpaceNet1:
|
|||
SpaceNet1, "collection_md5_dict", test_md5
|
||||
)
|
||||
root = str(tmp_path)
|
||||
transforms = Identity()
|
||||
transforms = nn.Identity() # type: ignore[attr-defined]
|
||||
return SpaceNet1(
|
||||
root,
|
||||
image=request.param,
|
||||
|
@ -104,7 +104,7 @@ class TestSpaceNet2:
|
|||
SpaceNet2, "collection_md5_dict", test_md5
|
||||
)
|
||||
root = str(tmp_path)
|
||||
transforms = Identity()
|
||||
transforms = nn.Identity() # type: ignore[attr-defined]
|
||||
return SpaceNet2(
|
||||
root,
|
||||
image=request.param,
|
||||
|
@ -165,7 +165,7 @@ class TestSpaceNet4:
|
|||
SpaceNet4, "collection_md5_dict", test_md5
|
||||
)
|
||||
root = str(tmp_path)
|
||||
transforms = Identity()
|
||||
transforms = nn.Identity() # type: ignore[attr-defined]
|
||||
return SpaceNet4(
|
||||
root,
|
||||
image=request.param,
|
||||
|
|
|
@ -9,11 +9,11 @@ from typing import Any, Generator
|
|||
|
||||
import pytest
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
from _pytest.monkeypatch import MonkeyPatch
|
||||
|
||||
import torchgeo.datasets.utils
|
||||
from torchgeo.datasets import ZueriCrop
|
||||
from torchgeo.transforms import Identity
|
||||
|
||||
pytest.importorskip("h5py")
|
||||
|
||||
|
@ -41,7 +41,7 @@ class TestZueriCrop:
|
|||
monkeypatch.setattr(ZueriCrop, "urls", urls) # type: ignore[attr-defined]
|
||||
monkeypatch.setattr(ZueriCrop, "md5s", md5s) # type: ignore[attr-defined]
|
||||
root = str(tmp_path)
|
||||
transforms = Identity()
|
||||
transforms = nn.Identity() # type: ignore[attr-defined]
|
||||
return ZueriCrop(root, transforms, download=True, checksum=True)
|
||||
|
||||
@pytest.fixture
|
||||
|
|
|
@ -105,46 +105,6 @@ def assert_matching(output: Dict[str, Tensor], expected: Dict[str, Tensor]) -> N
|
|||
assert equal, err
|
||||
|
||||
|
||||
def test_random_horizontal_flip(sample: Dict[str, Tensor]) -> None:
|
||||
tr = transforms.RandomHorizontalFlip(p=1)
|
||||
output = tr(sample)
|
||||
expected = {
|
||||
"image": torch.tensor( # type: ignore[attr-defined]
|
||||
[[[3, 2, 1], [6, 5, 4], [9, 8, 7]]]
|
||||
),
|
||||
"mask": torch.tensor( # type: ignore[attr-defined]
|
||||
[[1, 0, 0], [1, 1, 0], [1, 1, 1]]
|
||||
),
|
||||
"boxes": torch.tensor( # type: ignore[attr-defined]
|
||||
[[1, 0, 3, 2], [0, 1, 2, 3]]
|
||||
),
|
||||
}
|
||||
assert_matching(output, expected)
|
||||
|
||||
|
||||
def test_random_vertical_flip(sample: Dict[str, Tensor]) -> None:
|
||||
tr = transforms.RandomVerticalFlip(p=1)
|
||||
output = tr(sample)
|
||||
expected = {
|
||||
"image": torch.tensor( # type: ignore[attr-defined]
|
||||
[[[7, 8, 9], [4, 5, 6], [1, 2, 3]]]
|
||||
),
|
||||
"mask": torch.tensor( # type: ignore[attr-defined]
|
||||
[[1, 1, 1], [0, 1, 1], [0, 0, 1]]
|
||||
),
|
||||
"boxes": torch.tensor( # type: ignore[attr-defined]
|
||||
[[0, 1, 2, 3], [1, 0, 3, 2]]
|
||||
),
|
||||
}
|
||||
assert_matching(output, expected)
|
||||
|
||||
|
||||
def test_identity(sample: Dict[str, Tensor]) -> None:
|
||||
tr = transforms.Identity()
|
||||
output = tr(sample)
|
||||
assert_matching(output, sample)
|
||||
|
||||
|
||||
def test_augmentation_sequential_gray(batch_gray: Dict[str, Tensor]) -> None:
|
||||
expected = {
|
||||
"image": torch.tensor( # type: ignore[attr-defined]
|
||||
|
|
|
@ -4,12 +4,7 @@
|
|||
"""TorchGeo transforms."""
|
||||
|
||||
from .indices import AppendNDBI, AppendNDSI, AppendNDVI, AppendNDWI
|
||||
from .transforms import (
|
||||
AugmentationSequential,
|
||||
Identity,
|
||||
RandomHorizontalFlip,
|
||||
RandomVerticalFlip,
|
||||
)
|
||||
from .transforms import AugmentationSequential
|
||||
|
||||
__all__ = (
|
||||
"AppendNDBI",
|
||||
|
@ -17,9 +12,6 @@ __all__ = (
|
|||
"AppendNDVI",
|
||||
"AppendNDWI",
|
||||
"AugmentationSequential",
|
||||
"Identity",
|
||||
"RandomHorizontalFlip",
|
||||
"RandomVerticalFlip",
|
||||
)
|
||||
|
||||
# https://stackoverflow.com/questions/40018681
|
||||
|
|
|
@ -15,91 +15,6 @@ from torch.nn import Module # type: ignore[attr-defined]
|
|||
Module.__module__ = "torch.nn"
|
||||
|
||||
|
||||
class RandomHorizontalFlip(Module): # type: ignore[misc,name-defined]
|
||||
"""Horizontally flip the given sample randomly with a given probability."""
|
||||
|
||||
def __init__(self, p: float = 0.5) -> None:
|
||||
"""Initialize a new transform instance.
|
||||
|
||||
Args:
|
||||
p: probability of the sample being flipped
|
||||
"""
|
||||
super().__init__()
|
||||
self.p = p
|
||||
|
||||
def forward(self, sample: Dict[str, Tensor]) -> Dict[str, Tensor]:
|
||||
"""Randomly flip the image and target tensors.
|
||||
|
||||
Args:
|
||||
sample: a single data sample
|
||||
|
||||
Returns:
|
||||
a possibly flipped sample
|
||||
"""
|
||||
if torch.rand(1) < self.p:
|
||||
if "image" in sample:
|
||||
sample["image"] = sample["image"].flip(-1)
|
||||
|
||||
if "boxes" in sample:
|
||||
height, width = sample["image"].shape[-2:]
|
||||
sample["boxes"][:, [0, 2]] = width - sample["boxes"][:, [2, 0]]
|
||||
|
||||
if "mask" in sample:
|
||||
sample["mask"] = sample["mask"].flip(-1)
|
||||
|
||||
return sample
|
||||
|
||||
|
||||
class RandomVerticalFlip(Module): # type: ignore[misc,name-defined]
|
||||
"""Vertically flip the given sample randomly with a given probability."""
|
||||
|
||||
def __init__(self, p: float = 0.5) -> None:
|
||||
"""Initialize a new transform instance.
|
||||
|
||||
Args:
|
||||
p: probability of the sample being flipped
|
||||
"""
|
||||
super().__init__()
|
||||
self.p = p
|
||||
|
||||
def forward(self, sample: Dict[str, Tensor]) -> Dict[str, Tensor]:
|
||||
"""Randomly flip the image and target tensors.
|
||||
|
||||
Args:
|
||||
sample: a single data sample
|
||||
|
||||
Returns:
|
||||
a possibly flipped sample
|
||||
"""
|
||||
if torch.rand(1) < self.p:
|
||||
if "image" in sample:
|
||||
sample["image"] = sample["image"].flip(-2)
|
||||
|
||||
if "boxes" in sample:
|
||||
height, width = sample["image"].shape[-2:]
|
||||
sample["boxes"][:, [1, 3]] = height - sample["boxes"][:, [3, 1]]
|
||||
|
||||
if "mask" in sample:
|
||||
sample["mask"] = sample["mask"].flip(-2)
|
||||
|
||||
return sample
|
||||
|
||||
|
||||
class Identity(Module): # type: ignore[misc,name-defined]
|
||||
"""Identity function used for testing purposes."""
|
||||
|
||||
def forward(self, sample: Dict[str, Tensor]) -> Dict[str, Tensor]:
|
||||
"""Do nothing.
|
||||
|
||||
Args:
|
||||
sample: the input
|
||||
|
||||
Returns:
|
||||
the unchanged input
|
||||
"""
|
||||
return sample
|
||||
|
||||
|
||||
class AugmentationSequential(Module): # type: ignore[misc]
|
||||
"""Wrapper around kornia AugmentationSequential to handle input dicts."""
|
||||
|
||||
|
|
Загрузка…
Ссылка в новой задаче