diff --git a/tests/datasets/test_advance.py b/tests/datasets/test_advance.py index 799b00281..762164b8e 100644 --- a/tests/datasets/test_advance.py +++ b/tests/datasets/test_advance.py @@ -9,11 +9,11 @@ from typing import Any, Generator import pytest import torch +import torch.nn as nn from _pytest.monkeypatch import MonkeyPatch import torchgeo.datasets.utils from torchgeo.datasets import ADVANCE -from torchgeo.transforms import Identity def download_url(url: str, root: str, *args: str) -> None: @@ -40,7 +40,7 @@ class TestADVANCE: monkeypatch.setattr(ADVANCE, "urls", urls) # type: ignore[attr-defined] monkeypatch.setattr(ADVANCE, "md5s", md5s) # type: ignore[attr-defined] root = str(tmp_path) - transforms = Identity() + transforms = nn.Identity() # type: ignore[attr-defined] return ADVANCE(root, transforms, download=True, checksum=True) @pytest.fixture diff --git a/tests/datasets/test_benin_cashews.py b/tests/datasets/test_benin_cashews.py index 53261ee00..c22835fea 100644 --- a/tests/datasets/test_benin_cashews.py +++ b/tests/datasets/test_benin_cashews.py @@ -9,11 +9,11 @@ from typing import Generator import pytest import torch +import torch.nn as nn from _pytest.monkeypatch import MonkeyPatch from torch.utils.data import ConcatDataset from torchgeo.datasets import BeninSmallHolderCashews -from torchgeo.transforms import Identity class Dataset: @@ -50,7 +50,7 @@ class TestBeninSmallHolderCashews: BeninSmallHolderCashews, "dates", ("2019_11_05",) ) root = str(tmp_path) - transforms = Identity() + transforms = nn.Identity() # type: ignore[attr-defined] return BeninSmallHolderCashews( root, transforms=transforms, diff --git a/tests/datasets/test_cbf.py b/tests/datasets/test_cbf.py index 6d0ca26f9..6b21794ad 100644 --- a/tests/datasets/test_cbf.py +++ b/tests/datasets/test_cbf.py @@ -9,12 +9,12 @@ from typing import Generator import matplotlib.pyplot as plt import pytest import torch +import torch.nn as nn from _pytest.monkeypatch import MonkeyPatch from rasterio.crs import CRS import torchgeo.datasets.utils from torchgeo.datasets import BoundingBox, CanadianBuildingFootprints, ZipDataset -from torchgeo.transforms import Identity def download_url(url: str, root: str, *args: str) -> None: @@ -57,7 +57,7 @@ class TestCanadianBuildingFootprints: plt, "show", lambda *args: None ) root = str(tmp_path) - transforms = Identity() + transforms = nn.Identity() # type: ignore[attr-defined] return CanadianBuildingFootprints( root, res=0.1, transforms=transforms, download=True, checksum=True ) diff --git a/tests/datasets/test_cdl.py b/tests/datasets/test_cdl.py index 42d45ba07..30c08a3a3 100644 --- a/tests/datasets/test_cdl.py +++ b/tests/datasets/test_cdl.py @@ -11,12 +11,12 @@ from typing import Generator import matplotlib.pyplot as plt import pytest import torch +import torch.nn as nn from _pytest.monkeypatch import MonkeyPatch from rasterio.crs import CRS import torchgeo.datasets.utils from torchgeo.datasets import CDL, BoundingBox, ZipDataset -from torchgeo.transforms import Identity def download_url(url: str, root: str, *args: str, **kwargs: str) -> None: @@ -44,7 +44,7 @@ class TestCDL: plt, "show", lambda *args: None ) root = str(tmp_path) - transforms = Identity() + transforms = nn.Identity() # type: ignore[attr-defined] return CDL(root, transforms=transforms, download=True, checksum=True) def test_getitem(self, dataset: CDL) -> None: diff --git a/tests/datasets/test_chesapeake.py b/tests/datasets/test_chesapeake.py index 68884e425..beb7339ec 100644 --- a/tests/datasets/test_chesapeake.py +++ b/tests/datasets/test_chesapeake.py @@ -9,13 +9,13 @@ from typing import Generator import matplotlib.pyplot as plt import pytest import torch +import torch.nn as nn from _pytest.fixtures import SubRequest from _pytest.monkeypatch import MonkeyPatch from rasterio.crs import CRS import torchgeo.datasets.utils from torchgeo.datasets import BoundingBox, Chesapeake13, ChesapeakeCVPR, ZipDataset -from torchgeo.transforms import Identity def download_url(url: str, root: str, *args: str, **kwargs: str) -> None: @@ -42,7 +42,7 @@ class TestChesapeake13: plt, "show", lambda *args: None ) root = str(tmp_path) - transforms = Identity() + transforms = nn.Identity() # type: ignore[attr-defined] return Chesapeake13(root, transforms=transforms, download=True, checksum=True) def test_getitem(self, dataset: Chesapeake13) -> None: @@ -108,7 +108,7 @@ class TestChesapeakeCVPR: ["de_1m_2013_extended-debuffered-test_tiles", "spatial_index.geojson"], ) root = str(tmp_path) - transforms = Identity() + transforms = nn.Identity() # type: ignore[attr-defined] return ChesapeakeCVPR( root, splits=["de-test"], diff --git a/tests/datasets/test_cowc.py b/tests/datasets/test_cowc.py index 6f2c56fff..24c0a4454 100644 --- a/tests/datasets/test_cowc.py +++ b/tests/datasets/test_cowc.py @@ -8,6 +8,7 @@ from typing import Generator import pytest import torch +import torch.nn as nn from _pytest.fixtures import SubRequest from _pytest.monkeypatch import MonkeyPatch from torch.utils.data import ConcatDataset @@ -15,7 +16,6 @@ from torch.utils.data import ConcatDataset import torchgeo.datasets.utils from torchgeo.datasets import COWCCounting, COWCDetection from torchgeo.datasets.cowc import COWC -from torchgeo.transforms import Identity def download_url(url: str, root: str, *args: str, **kwargs: str) -> None: @@ -56,7 +56,7 @@ class TestCOWCCounting: monkeypatch.setattr(COWCCounting, "md5s", md5s) # type: ignore[attr-defined] root = str(tmp_path) split = request.param - transforms = Identity() + transforms = nn.Identity() # type: ignore[attr-defined] return COWCCounting(root, split, transforms, download=True, checksum=True) def test_getitem(self, dataset: COWC) -> None: @@ -114,7 +114,7 @@ class TestCOWCDetection: monkeypatch.setattr(COWCDetection, "md5s", md5s) # type: ignore[attr-defined] root = str(tmp_path) split = "train" - transforms = Identity() + transforms = nn.Identity() # type: ignore[attr-defined] return COWCDetection(root, split, transforms, download=True, checksum=True) def test_getitem(self, dataset: COWC) -> None: diff --git a/tests/datasets/test_cv4a_kenya_crop_type.py b/tests/datasets/test_cv4a_kenya_crop_type.py index 49ef2fb4a..7a3052a80 100644 --- a/tests/datasets/test_cv4a_kenya_crop_type.py +++ b/tests/datasets/test_cv4a_kenya_crop_type.py @@ -9,11 +9,11 @@ from typing import Generator import pytest import torch +import torch.nn as nn from _pytest.monkeypatch import MonkeyPatch from torch.utils.data import ConcatDataset from torchgeo.datasets import CV4AKenyaCropType -from torchgeo.transforms import Identity class Dataset: @@ -55,7 +55,7 @@ class TestCV4AKenyaCropType: CV4AKenyaCropType, "dates", ["20190606"] ) root = str(tmp_path) - transforms = Identity() + transforms = nn.Identity() # type: ignore[attr-defined] return CV4AKenyaCropType( root, transforms=transforms, diff --git a/tests/datasets/test_cyclone.py b/tests/datasets/test_cyclone.py index db63e815b..c4d558bec 100644 --- a/tests/datasets/test_cyclone.py +++ b/tests/datasets/test_cyclone.py @@ -9,12 +9,12 @@ from typing import Generator import pytest import torch +import torch.nn as nn from _pytest.fixtures import SubRequest from _pytest.monkeypatch import MonkeyPatch from torch.utils.data import ConcatDataset from torchgeo.datasets import TropicalCycloneWindEstimation -from torchgeo.transforms import Identity class Dataset: @@ -57,7 +57,7 @@ class TestTropicalCycloneWindEstimation: ) root = str(tmp_path) split = request.param - transforms = Identity() + transforms = nn.Identity() # type: ignore[attr-defined] return TropicalCycloneWindEstimation( root, split, transforms, download=True, api_key="", checksum=True ) diff --git a/tests/datasets/test_etci2021.py b/tests/datasets/test_etci2021.py index de440dd7e..88481df18 100644 --- a/tests/datasets/test_etci2021.py +++ b/tests/datasets/test_etci2021.py @@ -8,12 +8,12 @@ from typing import Generator import pytest import torch +import torch.nn as nn from _pytest.fixtures import SubRequest from _pytest.monkeypatch import MonkeyPatch import torchgeo.datasets.utils from torchgeo.datasets import ETCI2021 -from torchgeo.transforms import Identity def download_url(url: str, root: str, *args: str) -> None: @@ -55,7 +55,7 @@ class TestETCI2021: monkeypatch.setattr(ETCI2021, "metadata", metadata) # type: ignore[attr-defined] # noqa: E501 root = str(tmp_path) split = request.param - transforms = Identity() + transforms = nn.Identity() # type: ignore[attr-defined] return ETCI2021(root, split, transforms, download=True, checksum=True) def test_getitem(self, dataset: ETCI2021) -> None: diff --git a/tests/datasets/test_geo.py b/tests/datasets/test_geo.py index 2c8854d66..8deb5d030 100644 --- a/tests/datasets/test_geo.py +++ b/tests/datasets/test_geo.py @@ -7,6 +7,7 @@ from typing import Dict import pytest import torch +import torch.nn as nn from _pytest.fixtures import SubRequest from rasterio.crs import CRS from torch.utils.data import ConcatDataset @@ -21,7 +22,6 @@ from torchgeo.datasets import ( VisionDataset, ZipDataset, ) -from torchgeo.transforms import Identity class CustomGeoDataset(GeoDataset): @@ -106,7 +106,7 @@ class TestRasterDataset: root = os.path.join("tests", "data", "landsat8") bands = ["B1", "B2", "B3", "B4", "B5", "B6", "B7"] crs = CRS.from_epsg(3005) - transforms = Identity() + transforms = nn.Identity() # type: ignore[attr-defined] cache = request.param return Landsat8(root, bands=bands, crs=crs, transforms=transforms, cache=cache) diff --git a/tests/datasets/test_gid15.py b/tests/datasets/test_gid15.py index 7cc90edd9..9674a9847 100644 --- a/tests/datasets/test_gid15.py +++ b/tests/datasets/test_gid15.py @@ -8,12 +8,12 @@ from typing import Generator import pytest import torch +import torch.nn as nn from _pytest.fixtures import SubRequest from _pytest.monkeypatch import MonkeyPatch import torchgeo.datasets.utils from torchgeo.datasets import GID15 -from torchgeo.transforms import Identity def download_url(url: str, root: str, *args: str) -> None: @@ -37,7 +37,7 @@ class TestGID15: monkeypatch.setattr(GID15, "url", url) # type: ignore[attr-defined] root = str(tmp_path) split = request.param - transforms = Identity() + transforms = nn.Identity() # type: ignore[attr-defined] return GID15(root, split, transforms, download=True, checksum=True) def test_getitem(self, dataset: GID15) -> None: diff --git a/tests/datasets/test_landcoverai.py b/tests/datasets/test_landcoverai.py index 559a01e7f..4b632b58a 100644 --- a/tests/datasets/test_landcoverai.py +++ b/tests/datasets/test_landcoverai.py @@ -8,13 +8,13 @@ from typing import Generator import pytest import torch +import torch.nn as nn from _pytest.fixtures import SubRequest from _pytest.monkeypatch import MonkeyPatch from torch.utils.data import ConcatDataset import torchgeo.datasets.utils from torchgeo.datasets import LandCoverAI -from torchgeo.transforms import Identity def download_url(url: str, root: str, *args: str) -> None: @@ -40,7 +40,7 @@ class TestLandCoverAI: monkeypatch.setattr(LandCoverAI, "sha256", sha256) # type: ignore[attr-defined] root = str(tmp_path) split = request.param - transforms = Identity() + transforms = nn.Identity() # type: ignore[attr-defined] return LandCoverAI(root, split, transforms, download=True, checksum=True) def test_getitem(self, dataset: LandCoverAI) -> None: diff --git a/tests/datasets/test_landsat.py b/tests/datasets/test_landsat.py index 4286791e6..5ca46bdf8 100644 --- a/tests/datasets/test_landsat.py +++ b/tests/datasets/test_landsat.py @@ -8,11 +8,11 @@ from typing import Generator import matplotlib.pyplot as plt import pytest import torch +import torch.nn as nn from _pytest.monkeypatch import MonkeyPatch from rasterio.crs import CRS from torchgeo.datasets import BoundingBox, Landsat8, ZipDataset -from torchgeo.transforms import Identity class TestLandsat8: @@ -23,7 +23,7 @@ class TestLandsat8: ) root = os.path.join("tests", "data", "landsat8") bands = ["B1", "B2", "B3", "B4", "B5", "B6", "B7"] - transforms = Identity() + transforms = nn.Identity() # type: ignore[attr-defined] return Landsat8(root, bands=bands, transforms=transforms) def test_separate_files(self, dataset: Landsat8) -> None: diff --git a/tests/datasets/test_levircd.py b/tests/datasets/test_levircd.py index 9bd77f7e4..2aca6c8b0 100644 --- a/tests/datasets/test_levircd.py +++ b/tests/datasets/test_levircd.py @@ -8,12 +8,12 @@ from typing import Generator import pytest import torch +import torch.nn as nn from _pytest.fixtures import SubRequest from _pytest.monkeypatch import MonkeyPatch import torchgeo.datasets.utils from torchgeo.datasets import LEVIRCDPlus -from torchgeo.transforms import Identity def download_url(url: str, root: str, *args: str) -> None: @@ -37,7 +37,7 @@ class TestLEVIRCDPlus: monkeypatch.setattr(LEVIRCDPlus, "url", url) # type: ignore[attr-defined] root = str(tmp_path) split = request.param - transforms = Identity() + transforms = nn.Identity() # type: ignore[attr-defined] return LEVIRCDPlus(root, split, transforms, download=True, checksum=True) def test_getitem(self, dataset: LEVIRCDPlus) -> None: diff --git a/tests/datasets/test_naip.py b/tests/datasets/test_naip.py index 838fbfca4..bde5b289b 100644 --- a/tests/datasets/test_naip.py +++ b/tests/datasets/test_naip.py @@ -8,11 +8,11 @@ from typing import Generator import matplotlib.pyplot as plt import pytest import torch +import torch.nn as nn from _pytest.monkeypatch import MonkeyPatch from rasterio.crs import CRS from torchgeo.datasets import NAIP, BoundingBox, ZipDataset -from torchgeo.transforms import Identity class TestNAIP: @@ -22,7 +22,7 @@ class TestNAIP: plt, "show", lambda *args: None ) root = os.path.join("tests", "data", "naip") - transforms = Identity() + transforms = nn.Identity() # type: ignore[attr-defined] return NAIP(root, transforms=transforms) def test_getitem(self, dataset: NAIP) -> None: diff --git a/tests/datasets/test_nwpu.py b/tests/datasets/test_nwpu.py index c8e6fb1d7..9007132c3 100644 --- a/tests/datasets/test_nwpu.py +++ b/tests/datasets/test_nwpu.py @@ -9,13 +9,13 @@ from typing import Generator import pytest import torch +import torch.nn as nn from _pytest.fixtures import SubRequest from _pytest.monkeypatch import MonkeyPatch from torch.utils.data import ConcatDataset import torchgeo.datasets.utils from torchgeo.datasets import VHR10 -from torchgeo.transforms import Identity pytest.importorskip("rarfile") pytest.importorskip("pycocotools") @@ -51,7 +51,7 @@ class TestVHR10: monkeypatch.setitem(VHR10.target_meta, "md5", md5) # type: ignore[attr-defined] root = str(tmp_path) split = request.param - transforms = Identity() + transforms = nn.Identity() # type: ignore[attr-defined] return VHR10(root, split, transforms, download=True, checksum=True) def test_getitem(self, dataset: VHR10) -> None: diff --git a/tests/datasets/test_sen12ms.py b/tests/datasets/test_sen12ms.py index 8a882d089..94c796434 100644 --- a/tests/datasets/test_sen12ms.py +++ b/tests/datasets/test_sen12ms.py @@ -7,12 +7,12 @@ from typing import Generator import pytest import torch +import torch.nn as nn from _pytest.fixtures import SubRequest from _pytest.monkeypatch import MonkeyPatch from torch.utils.data import ConcatDataset from torchgeo.datasets import SEN12MS -from torchgeo.transforms import Identity class TestSEN12MS: @@ -38,7 +38,7 @@ class TestSEN12MS: monkeypatch.setattr(SEN12MS, "md5s", md5s) # type: ignore[attr-defined] root = os.path.join("tests", "data", "sen12ms") split = request.param - transforms = Identity() + transforms = nn.Identity() # type: ignore[attr-defined] return SEN12MS(root, split, transforms=transforms, checksum=True) def test_getitem(self, dataset: SEN12MS) -> None: diff --git a/tests/datasets/test_sentinel.py b/tests/datasets/test_sentinel.py index 743e3e915..601e9d917 100644 --- a/tests/datasets/test_sentinel.py +++ b/tests/datasets/test_sentinel.py @@ -6,10 +6,10 @@ from pathlib import Path import pytest import torch +import torch.nn as nn from rasterio.crs import CRS from torchgeo.datasets import BoundingBox, Sentinel2, ZipDataset -from torchgeo.transforms import Identity class TestSentinel2: @@ -17,7 +17,7 @@ class TestSentinel2: def dataset(self) -> Sentinel2: root = os.path.join("tests", "data", "sentinel2") bands = ["B01", "B02", "B03", "B04", "B05", "B06", "B07", "B08", "B09", "B11"] - transforms = Identity() + transforms = nn.Identity() # type: ignore[attr-defined] return Sentinel2(root, bands=bands, transforms=transforms) def test_separate_files(self, dataset: Sentinel2) -> None: diff --git a/tests/datasets/test_so2sat.py b/tests/datasets/test_so2sat.py index d76c8498a..e54e0ed20 100644 --- a/tests/datasets/test_so2sat.py +++ b/tests/datasets/test_so2sat.py @@ -7,11 +7,11 @@ from typing import Generator import pytest import torch +import torch.nn as nn from _pytest.fixtures import SubRequest from _pytest.monkeypatch import MonkeyPatch from torchgeo.datasets import So2Sat -from torchgeo.transforms import Identity class TestSo2Sat: @@ -28,7 +28,7 @@ class TestSo2Sat: monkeypatch.setattr(So2Sat, "md5s", md5s) # type: ignore[attr-defined] root = os.path.join("tests", "data", "so2sat") split = request.param - transforms = Identity() + transforms = nn.Identity() # type: ignore[attr-defined] return So2Sat(root, split, transforms, checksum=True) def test_getitem(self, dataset: So2Sat) -> None: diff --git a/tests/datasets/test_spacenet.py b/tests/datasets/test_spacenet.py index b769f4385..06747f510 100644 --- a/tests/datasets/test_spacenet.py +++ b/tests/datasets/test_spacenet.py @@ -9,11 +9,11 @@ from typing import Generator import pytest import torch +import torch.nn as nn from _pytest.fixtures import SubRequest from _pytest.monkeypatch import MonkeyPatch from torchgeo.datasets import SpaceNet1, SpaceNet2, SpaceNet4 -from torchgeo.transforms import Identity TEST_DATA_DIR = "tests/data/spacenet" @@ -51,7 +51,7 @@ class TestSpaceNet1: SpaceNet1, "collection_md5_dict", test_md5 ) root = str(tmp_path) - transforms = Identity() + transforms = nn.Identity() # type: ignore[attr-defined] return SpaceNet1( root, image=request.param, @@ -104,7 +104,7 @@ class TestSpaceNet2: SpaceNet2, "collection_md5_dict", test_md5 ) root = str(tmp_path) - transforms = Identity() + transforms = nn.Identity() # type: ignore[attr-defined] return SpaceNet2( root, image=request.param, @@ -165,7 +165,7 @@ class TestSpaceNet4: SpaceNet4, "collection_md5_dict", test_md5 ) root = str(tmp_path) - transforms = Identity() + transforms = nn.Identity() # type: ignore[attr-defined] return SpaceNet4( root, image=request.param, diff --git a/tests/datasets/test_zuericrop.py b/tests/datasets/test_zuericrop.py index 926e44c7c..613efd11e 100644 --- a/tests/datasets/test_zuericrop.py +++ b/tests/datasets/test_zuericrop.py @@ -9,11 +9,11 @@ from typing import Any, Generator import pytest import torch +import torch.nn as nn from _pytest.monkeypatch import MonkeyPatch import torchgeo.datasets.utils from torchgeo.datasets import ZueriCrop -from torchgeo.transforms import Identity pytest.importorskip("h5py") @@ -41,7 +41,7 @@ class TestZueriCrop: monkeypatch.setattr(ZueriCrop, "urls", urls) # type: ignore[attr-defined] monkeypatch.setattr(ZueriCrop, "md5s", md5s) # type: ignore[attr-defined] root = str(tmp_path) - transforms = Identity() + transforms = nn.Identity() # type: ignore[attr-defined] return ZueriCrop(root, transforms, download=True, checksum=True) @pytest.fixture diff --git a/tests/transforms/test_transforms.py b/tests/transforms/test_transforms.py index e15bada1a..b45aad80a 100644 --- a/tests/transforms/test_transforms.py +++ b/tests/transforms/test_transforms.py @@ -105,46 +105,6 @@ def assert_matching(output: Dict[str, Tensor], expected: Dict[str, Tensor]) -> N assert equal, err -def test_random_horizontal_flip(sample: Dict[str, Tensor]) -> None: - tr = transforms.RandomHorizontalFlip(p=1) - output = tr(sample) - expected = { - "image": torch.tensor( # type: ignore[attr-defined] - [[[3, 2, 1], [6, 5, 4], [9, 8, 7]]] - ), - "mask": torch.tensor( # type: ignore[attr-defined] - [[1, 0, 0], [1, 1, 0], [1, 1, 1]] - ), - "boxes": torch.tensor( # type: ignore[attr-defined] - [[1, 0, 3, 2], [0, 1, 2, 3]] - ), - } - assert_matching(output, expected) - - -def test_random_vertical_flip(sample: Dict[str, Tensor]) -> None: - tr = transforms.RandomVerticalFlip(p=1) - output = tr(sample) - expected = { - "image": torch.tensor( # type: ignore[attr-defined] - [[[7, 8, 9], [4, 5, 6], [1, 2, 3]]] - ), - "mask": torch.tensor( # type: ignore[attr-defined] - [[1, 1, 1], [0, 1, 1], [0, 0, 1]] - ), - "boxes": torch.tensor( # type: ignore[attr-defined] - [[0, 1, 2, 3], [1, 0, 3, 2]] - ), - } - assert_matching(output, expected) - - -def test_identity(sample: Dict[str, Tensor]) -> None: - tr = transforms.Identity() - output = tr(sample) - assert_matching(output, sample) - - def test_augmentation_sequential_gray(batch_gray: Dict[str, Tensor]) -> None: expected = { "image": torch.tensor( # type: ignore[attr-defined] diff --git a/torchgeo/transforms/__init__.py b/torchgeo/transforms/__init__.py index 48974b061..f3c074a25 100644 --- a/torchgeo/transforms/__init__.py +++ b/torchgeo/transforms/__init__.py @@ -4,12 +4,7 @@ """TorchGeo transforms.""" from .indices import AppendNDBI, AppendNDSI, AppendNDVI, AppendNDWI -from .transforms import ( - AugmentationSequential, - Identity, - RandomHorizontalFlip, - RandomVerticalFlip, -) +from .transforms import AugmentationSequential __all__ = ( "AppendNDBI", @@ -17,9 +12,6 @@ __all__ = ( "AppendNDVI", "AppendNDWI", "AugmentationSequential", - "Identity", - "RandomHorizontalFlip", - "RandomVerticalFlip", ) # https://stackoverflow.com/questions/40018681 diff --git a/torchgeo/transforms/transforms.py b/torchgeo/transforms/transforms.py index f53c540ce..b1b8cda3a 100644 --- a/torchgeo/transforms/transforms.py +++ b/torchgeo/transforms/transforms.py @@ -15,91 +15,6 @@ from torch.nn import Module # type: ignore[attr-defined] Module.__module__ = "torch.nn" -class RandomHorizontalFlip(Module): # type: ignore[misc,name-defined] - """Horizontally flip the given sample randomly with a given probability.""" - - def __init__(self, p: float = 0.5) -> None: - """Initialize a new transform instance. - - Args: - p: probability of the sample being flipped - """ - super().__init__() - self.p = p - - def forward(self, sample: Dict[str, Tensor]) -> Dict[str, Tensor]: - """Randomly flip the image and target tensors. - - Args: - sample: a single data sample - - Returns: - a possibly flipped sample - """ - if torch.rand(1) < self.p: - if "image" in sample: - sample["image"] = sample["image"].flip(-1) - - if "boxes" in sample: - height, width = sample["image"].shape[-2:] - sample["boxes"][:, [0, 2]] = width - sample["boxes"][:, [2, 0]] - - if "mask" in sample: - sample["mask"] = sample["mask"].flip(-1) - - return sample - - -class RandomVerticalFlip(Module): # type: ignore[misc,name-defined] - """Vertically flip the given sample randomly with a given probability.""" - - def __init__(self, p: float = 0.5) -> None: - """Initialize a new transform instance. - - Args: - p: probability of the sample being flipped - """ - super().__init__() - self.p = p - - def forward(self, sample: Dict[str, Tensor]) -> Dict[str, Tensor]: - """Randomly flip the image and target tensors. - - Args: - sample: a single data sample - - Returns: - a possibly flipped sample - """ - if torch.rand(1) < self.p: - if "image" in sample: - sample["image"] = sample["image"].flip(-2) - - if "boxes" in sample: - height, width = sample["image"].shape[-2:] - sample["boxes"][:, [1, 3]] = height - sample["boxes"][:, [3, 1]] - - if "mask" in sample: - sample["mask"] = sample["mask"].flip(-2) - - return sample - - -class Identity(Module): # type: ignore[misc,name-defined] - """Identity function used for testing purposes.""" - - def forward(self, sample: Dict[str, Tensor]) -> Dict[str, Tensor]: - """Do nothing. - - Args: - sample: the input - - Returns: - the unchanged input - """ - return sample - - class AugmentationSequential(Module): # type: ignore[misc] """Wrapper around kornia AugmentationSequential to handle input dicts."""