diff --git a/.github/workflows/style.yaml b/.github/workflows/style.yaml index ddfd29c01..2f3d80ff5 100644 --- a/.github/workflows/style.yaml +++ b/.github/workflows/style.yaml @@ -158,7 +158,7 @@ jobs: - name: List pip dependencies run: pip list - name: Run pyupgrade checks - run: pyupgrade --py39-plus $(find . -path ./docs/src -prune -o -name "*.py" -print) + run: pyupgrade --py310-plus $(find . -path ./docs/src -prune -o -name "*.py" -print) concurrency: group: ${{ github.workflow }}-${{ github.event.pull_request.head.label || github.head_ref || github.ref }} cancel-in-progress: true diff --git a/.github/workflows/tests.yaml b/.github/workflows/tests.yaml index 5c64023eb..98be90158 100644 --- a/.github/workflows/tests.yaml +++ b/.github/workflows/tests.yaml @@ -17,7 +17,7 @@ jobs: strategy: matrix: os: [ubuntu-latest, macos-latest, windows-latest] - python-version: ['3.9', '3.10', '3.11', '3.12'] + python-version: ['3.10', '3.11', '3.12'] steps: - name: Clone repo uses: actions/checkout@v4.1.2 @@ -71,7 +71,7 @@ jobs: - name: Set up python uses: actions/setup-python@v5.1.0 with: - python-version: '3.9' + python-version: '3.10' - name: Cache dependencies uses: actions/cache@v4.0.2 id: cache diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index 664e66bc8..44fca1327 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -3,7 +3,7 @@ repos: rev: v3.15.2 hooks: - id: pyupgrade - args: [--py39-plus] + args: [--py310-plus] - repo: https://github.com/pycqa/isort rev: 5.13.2 diff --git a/docs/user/contributing.rst b/docs/user/contributing.rst index 50e4d2515..7e1263478 100644 --- a/docs/user/contributing.rst +++ b/docs/user/contributing.rst @@ -103,7 +103,7 @@ All of these tools should be used from the root of the project to ensure that ou $ black . $ isort . - $ pyupgrade --py39-plus $(find . -name "*.py") + $ pyupgrade --py310-plus $(find . -name "*.py") Flake8, pydocstyle, and mypy won't format your code for you, but they will warn you about potential issues with your code or docstrings: diff --git a/experiments/ssl4eo/download_ssl4eo.py b/experiments/ssl4eo/download_ssl4eo.py index 7389dbf24..3ca4fa71c 100755 --- a/experiments/ssl4eo/download_ssl4eo.py +++ b/experiments/ssl4eo/download_ssl4eo.py @@ -56,7 +56,7 @@ import warnings from collections import defaultdict from datetime import date, timedelta from multiprocessing.dummy import Lock, Pool -from typing import Any, Optional +from typing import Any import ee import numpy as np @@ -168,7 +168,7 @@ def get_patch( new_resolutions: list[int], dtype: str = "float32", meta_cloud_name: str = "CLOUD_COVER", - default_value: Optional[float] = None, + default_value: float | None = None, ) -> dict[str, Any]: image = collection.sort(meta_cloud_name).first() region = ee.Geometry.Point(center_coord).buffer(radius).bounds() @@ -214,7 +214,7 @@ def get_random_patches_match( new_resolutions: list[int], dtype: str, meta_cloud_name: str, - default_value: Optional[float], + default_value: float | None, dates: list[date], radius: float, debug: bool = False, diff --git a/pyproject.toml b/pyproject.toml index 350d9273b..7c8f87260 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -10,7 +10,7 @@ build-backend = "setuptools.build_meta" name = "torchgeo" description = "TorchGeo: datasets, samplers, transforms, and pre-trained models for geospatial data" readme = "README.md" -requires-python = ">=3.9" +requires-python = ">=3.10" license = {file = "LICENSE"} authors = [ {name = "Adam J. Stewart", email = "ajstewart426@gmail.com"}, @@ -29,7 +29,6 @@ classifiers = [ "License :: OSI Approved :: MIT License", "Operating System :: OS Independent", "Programming Language :: Python :: 3", - "Programming Language :: Python :: 3.9", "Programming Language :: Python :: 3.10", "Programming Language :: Python :: 3.11", "Programming Language :: Python :: 3.12", @@ -39,9 +38,8 @@ classifiers = [ dependencies = [ # einops 0.3+ required for einops.repeat "einops>=0.3", - # fiona 1.8.19+ required to fix erroneous warning - # https://github.com/Toblerity/Fiona/issues/986 - "fiona>=1.8.19", + # fiona 1.8.21+ required for Python 3.10 wheels + "fiona>=1.8.21", # kornia 0.6.9+ required for kornia.augmentation.RandomBrightness "kornia>=0.6.9", # lightly 1.4.4+ required for MoCo v3 support @@ -50,24 +48,24 @@ dependencies = [ "lightly>=1.4.4,!=1.4.26", # lightning 2+ required for LightningCLI args + sys.argv support "lightning[pytorch-extra]>=2", - # matplotlib 3.3.3+ required for Python 3.9 wheels - "matplotlib>=3.3.3", - # numpy 1.19.3+ required by Python 3.9 wheels - "numpy>=1.19.3", - # pandas 1.1.3+ required for Python 3.9 wheels - "pandas>=1.1.3", - # pillow 8+ required for Python 3.9 wheels - "pillow>=8", - # pyproj 3+ required for Python 3.9 wheels - "pyproj>=3", - # rasterio 1.2+ required for Python 3.9 wheels - "rasterio>=1.2", - # rtree 1+ required for len(index), index & index, index | index + # matplotlib 3.5+ required for Python 3.10 wheels + "matplotlib>=3.5", + # numpy 1.21.2+ required by Python 3.10 wheels + "numpy>=1.21.2", + # pandas 1.3.3+ required for Python 3.10 wheels + "pandas>=1.3.3", + # pillow 8.4+ required for Python 3.10 wheels + "pillow>=8.4", + # pyproj 3.3+ required for Python 3.10 wheels + "pyproj>=3.3", + # rasterio 1.3+ required for Python 3.10 wheels + "rasterio>=1.3", + # rtree 1+ required for Python 3.10 wheels "rtree>=1", # segmentation-models-pytorch 0.2+ required for smp.losses module "segmentation-models-pytorch>=0.2", - # shapely 1.7.1+ required for Python 3.9 wheels - "shapely>=1.7.1", + # shapely 1.8+ required for Python 3.10 wheels + "shapely>=1.8", # timm 0.4.12 required by segmentation-models-pytorch "timm>=0.4.12", # torch 1.13+ required by torchvision @@ -81,27 +79,25 @@ dynamic = ["version"] [project.optional-dependencies] datasets = [ - # h5py 3+ required for Python 3.9 wheels - "h5py>=3", + # h5py 3.6+ required for Python 3.10 wheels + "h5py>=3.6", # laspy 2+ required for laspy.read "laspy>=2", - # opencv-python 4.4.0.46+ required for Python 3.9 wheels - "opencv-python>=4.4.0.46", - # pycocotools 2.0.5+ required for cython 3+ support - "pycocotools>=2.0.5", + # opencv-python 4.5.4+ required for Python 3.10 wheels + "opencv-python>=4.5.4", + # pycocotools 2.0.7+ required for wheels + "pycocotools>=2.0.7", # pyvista 0.34.2+ required to avoid ImportError in CI "pyvista>=0.34.2", # radiant-mlhub 0.3+ required for newer tqdm support required by lightning "radiant-mlhub>=0.3", # rarfile 4+ required for wheels "rarfile>=4", - # scikit-image 0.18+ required for numpy 1.17+ compatibility - # https://github.com/scikit-image/scikit-image/issues/3655 - "scikit-image>=0.18", - # scipy 1.6.2+ required for scikit-image 0.18+ compatibility - "scipy>=1.6.2", - # zipfile-deflate64 0.2+ required for extraction bugfix: - # https://github.com/brianhelba/zipfile-deflate64/issues/19 + # scikit-image 0.19+ required for Python 3.10 wheels + "scikit-image>=0.19", + # scipy 1.7.2+ required for Python 3.10 wheels + "scipy>=1.7.2", + # zipfile-deflate64 0.2+ required for Python 3.10 wheels "zipfile-deflate64>=0.2", ] docs = [ @@ -126,7 +122,7 @@ style = [ "isort[colors]>=5.8", # pydocstyle 6.1+ required for pyproject.toml support "pydocstyle[toml]>=6.1", - # pyupgrade 2.8+ required for --py39-plus flag + # pyupgrade 2.8+ required for --py310-plus flag "pyupgrade>=2.8", ] tests = [ @@ -151,7 +147,7 @@ Homepage = "https://github.com/microsoft/torchgeo" Documentation = "https://torchgeo.readthedocs.io" [tool.black] -target-version = ["py39", "py310"] +target-version = ["py310"] color = true skip_magic_trailing_comma = true @@ -170,7 +166,7 @@ skip_gitignore = true color_output = true [tool.mypy] -python_version = "3.9" +python_version = "3.10" ignore_missing_imports = true show_error_codes = true exclude = "(build|data|dist|docs/src|images|logo|logs|output)/" diff --git a/requirements/min-reqs.old b/requirements/min-reqs.old index 34b22649c..8889520d1 100644 --- a/requirements/min-reqs.old +++ b/requirements/min-reqs.old @@ -3,34 +3,34 @@ setuptools==61.0.0 # install einops==0.3.0 -fiona==1.8.19 +fiona==1.8.21 kornia==0.6.9 lightly==1.4.4 lightning[pytorch-extra]==2.0.0 -matplotlib==3.3.3 -numpy==1.19.3 -pandas==1.1.3 -pillow==8.0.0 -pyproj==3.0.0 -rasterio==1.2.0 +matplotlib==3.5.0 +numpy==1.21.2 +pandas==1.3.3 +pillow==8.4.0 +pyproj==3.3.0 +rasterio==1.3.0.post1 rtree==1.0.0 segmentation-models-pytorch==0.2.0 -shapely==1.7.1 +shapely==1.8.0 timm==0.4.12 torch==1.13.0 torchmetrics==0.10.0 torchvision==0.14.0 # datasets -h5py==3.0.0 +h5py==3.6.0 laspy==2.0.0 -opencv-python==4.4.0.46 -pycocotools==2.0.5 +opencv-python==4.5.4.58 +pycocotools==2.0.7 pyvista==0.34.2 radiant-mlhub==0.3.0 rarfile==4.0 -scikit-image==0.18.0 -scipy==1.6.2 +scikit-image==0.19.0 +scipy==1.7.2 zipfile-deflate64==0.2.0 # docs diff --git a/tests/data/raster/data.py b/tests/data/raster/data.py index 517649607..304c9f96f 100755 --- a/tests/data/raster/data.py +++ b/tests/data/raster/data.py @@ -2,7 +2,6 @@ # Licensed under the MIT License. import os -from typing import Optional import numpy as np import rasterio as rio @@ -18,7 +17,7 @@ def write_raster( res: int = RES[0], epsg: int = EPSG[0], dtype: str = "uint8", - path: Optional[str] = None, + path: str | None = None, ) -> None: """Write a raster file. diff --git a/tests/datasets/test_geo.py b/tests/datasets/test_geo.py index 8af3333a8..c98f010c5 100644 --- a/tests/datasets/test_geo.py +++ b/tests/datasets/test_geo.py @@ -5,7 +5,6 @@ import pickle import sys from collections.abc import Iterable from pathlib import Path -from typing import Optional, Union import pytest import torch @@ -35,7 +34,7 @@ class CustomGeoDataset(GeoDataset): bounds: BoundingBox = BoundingBox(0, 1, 2, 3, 4, 5), crs: CRS = CRS.from_epsg(4087), res: float = 1, - paths: Optional[Union[str, Iterable[str]]] = None, + paths: str | Iterable[str] | None = None, ) -> None: super().__init__() self.index.insert(0, tuple(bounds)) @@ -249,7 +248,7 @@ class TestRasterDataset: }, ], ) - def test_files(self, paths: Union[str, Iterable[str]]) -> None: + def test_files(self, paths: str | Iterable[str]) -> None: assert 1 <= len(NAIP(paths).files) <= 2 def test_getitem_single_file(self, naip: NAIP) -> None: diff --git a/tests/datasets/test_splits.py b/tests/datasets/test_splits.py index f2f75b566..11f5838e1 100644 --- a/tests/datasets/test_splits.py +++ b/tests/datasets/test_splits.py @@ -3,7 +3,7 @@ from collections.abc import Sequence from math import floor, isclose -from typing import Any, Union +from typing import Any import pytest from rasterio.crs import CRS @@ -65,7 +65,7 @@ class CustomGeoDataset(GeoDataset): ], ) def test_random_bbox_assignment( - lengths: Sequence[Union[int, float]], expected_lengths: Sequence[int] + lengths: Sequence[int | float], expected_lengths: Sequence[int] ) -> None: ds = CustomGeoDataset( [ @@ -255,8 +255,7 @@ def test_roi_split() -> None: ], ) def test_time_series_split( - lengths: Sequence[Union[tuple[int, int], int, float]], - expected_lengths: Sequence[int], + lengths: Sequence[tuple[int, int] | int | float], expected_lengths: Sequence[int] ) -> None: ds = CustomGeoDataset( [ diff --git a/tests/models/test_api.py b/tests/models/test_api.py index fca5261cc..18d1f1a80 100644 --- a/tests/models/test_api.py +++ b/tests/models/test_api.py @@ -2,7 +2,7 @@ # Licensed under the MIT License. import enum -from typing import Callable +from collections.abc import Callable import pytest import torch.nn as nn diff --git a/tests/samplers/test_utils.py b/tests/samplers/test_utils.py index 20d828827..8480fc5cf 100644 --- a/tests/samplers/test_utils.py +++ b/tests/samplers/test_utils.py @@ -2,7 +2,7 @@ # Licensed under the MIT License. import math -from typing import Optional, Union +from typing import Union import pytest @@ -33,7 +33,7 @@ MAYBE_TUPLE = Union[float, tuple[float, float]] ], ) def test_tile_to_chips( - size: MAYBE_TUPLE, stride: Optional[MAYBE_TUPLE], expected: MAYBE_TUPLE + size: MAYBE_TUPLE, stride: MAYBE_TUPLE | None, expected: MAYBE_TUPLE ) -> None: bounds = BoundingBox(0, 10, 20, 30, 40, 50) size = _to_tuple(size) diff --git a/torchgeo/datamodules/agrifieldnet.py b/torchgeo/datamodules/agrifieldnet.py index f7b27d252..fb485eec9 100644 --- a/torchgeo/datamodules/agrifieldnet.py +++ b/torchgeo/datamodules/agrifieldnet.py @@ -3,7 +3,7 @@ """AgriFieldNet datamodule.""" -from typing import Any, Optional, Union +from typing import Any import kornia.augmentation as K import torch @@ -25,8 +25,8 @@ class AgriFieldNetDataModule(GeoDataModule): def __init__( self, batch_size: int = 64, - patch_size: Union[int, tuple[int, int]] = 256, - length: Optional[int] = None, + patch_size: int | tuple[int, int] = 256, + length: int | None = None, num_workers: int = 0, **kwargs: Any, ) -> None: diff --git a/torchgeo/datamodules/chesapeake.py b/torchgeo/datamodules/chesapeake.py index 604c0d7e3..7c1d3846b 100644 --- a/torchgeo/datamodules/chesapeake.py +++ b/torchgeo/datamodules/chesapeake.py @@ -3,7 +3,7 @@ """Chesapeake Bay High-Resolution Land Cover Project datamodule.""" -from typing import Any, Optional +from typing import Any import kornia.augmentation as K import torch.nn as nn @@ -63,7 +63,7 @@ class ChesapeakeCVPRDataModule(GeoDataModule): test_splits: list[str], batch_size: int = 64, patch_size: int = 256, - length: Optional[int] = None, + length: int | None = None, num_workers: int = 0, class_set: int = 7, use_prior_labels: bool = False, diff --git a/torchgeo/datamodules/deepglobelandcover.py b/torchgeo/datamodules/deepglobelandcover.py index 3bfa594a2..2195aca3f 100644 --- a/torchgeo/datamodules/deepglobelandcover.py +++ b/torchgeo/datamodules/deepglobelandcover.py @@ -3,7 +3,7 @@ """DeepGlobe Land Cover Classification Challenge datamodule.""" -from typing import Any, Union +from typing import Any import kornia.augmentation as K @@ -24,7 +24,7 @@ class DeepGlobeLandCoverDataModule(NonGeoDataModule): def __init__( self, batch_size: int = 64, - patch_size: Union[tuple[int, int], int] = 64, + patch_size: tuple[int, int] | int = 64, val_split_pct: float = 0.2, num_workers: int = 0, **kwargs: Any, diff --git a/torchgeo/datamodules/geo.py b/torchgeo/datamodules/geo.py index 96da7683b..06d09b37f 100644 --- a/torchgeo/datamodules/geo.py +++ b/torchgeo/datamodules/geo.py @@ -3,7 +3,8 @@ """Base classes for all :mod:`torchgeo` data modules.""" -from typing import Any, Callable, Optional, Union, cast +from collections.abc import Callable +from typing import Any, cast import kornia.augmentation as K import torch @@ -55,27 +56,27 @@ class BaseDataModule(LightningDataModule): self.kwargs = kwargs # Datasets - self.dataset: Optional[Dataset[dict[str, Tensor]]] = None - self.train_dataset: Optional[Dataset[dict[str, Tensor]]] = None - self.val_dataset: Optional[Dataset[dict[str, Tensor]]] = None - self.test_dataset: Optional[Dataset[dict[str, Tensor]]] = None - self.predict_dataset: Optional[Dataset[dict[str, Tensor]]] = None + self.dataset: Dataset[dict[str, Tensor]] | None = None + self.train_dataset: Dataset[dict[str, Tensor]] | None = None + self.val_dataset: Dataset[dict[str, Tensor]] | None = None + self.test_dataset: Dataset[dict[str, Tensor]] | None = None + self.predict_dataset: Dataset[dict[str, Tensor]] | None = None # Data loaders - self.train_batch_size: Optional[int] = None - self.val_batch_size: Optional[int] = None - self.test_batch_size: Optional[int] = None - self.predict_batch_size: Optional[int] = None + self.train_batch_size: int | None = None + self.val_batch_size: int | None = None + self.test_batch_size: int | None = None + self.predict_batch_size: int | None = None # Data augmentation Transform = Callable[[dict[str, Tensor]], dict[str, Tensor]] self.aug: Transform = AugmentationSequential( K.Normalize(mean=self.mean, std=self.std), data_keys=["image"] ) - self.train_aug: Optional[Transform] = None - self.val_aug: Optional[Transform] = None - self.test_aug: Optional[Transform] = None - self.predict_aug: Optional[Transform] = None + self.train_aug: Transform | None = None + self.val_aug: Transform | None = None + self.test_aug: Transform | None = None + self.predict_aug: Transform | None = None def prepare_data(self) -> None: """Download and prepare data. @@ -141,7 +142,7 @@ class BaseDataModule(LightningDataModule): return batch - def plot(self, *args: Any, **kwargs: Any) -> Optional[Figure]: + def plot(self, *args: Any, **kwargs: Any) -> Figure | None: """Run the plot method of the validation dataset if one exists. Should only be called during 'fit' or 'validate' stages as ``val_dataset`` @@ -154,7 +155,7 @@ class BaseDataModule(LightningDataModule): Returns: A matplotlib Figure with the image, ground truth, and predictions. """ - fig: Optional[Figure] = None + fig: Figure | None = None dataset = self.dataset or self.val_dataset if dataset is not None: if hasattr(dataset, "plot"): @@ -172,8 +173,8 @@ class GeoDataModule(BaseDataModule): self, dataset_class: type[GeoDataset], batch_size: int = 1, - patch_size: Union[int, tuple[int, int]] = 64, - length: Optional[int] = None, + patch_size: int | tuple[int, int] = 64, + length: int | None = None, num_workers: int = 0, **kwargs: Any, ) -> None: @@ -196,18 +197,18 @@ class GeoDataModule(BaseDataModule): self.collate_fn = stack_samples # Samplers - self.sampler: Optional[GeoSampler] = None - self.train_sampler: Optional[GeoSampler] = None - self.val_sampler: Optional[GeoSampler] = None - self.test_sampler: Optional[GeoSampler] = None - self.predict_sampler: Optional[GeoSampler] = None + self.sampler: GeoSampler | None = None + self.train_sampler: GeoSampler | None = None + self.val_sampler: GeoSampler | None = None + self.test_sampler: GeoSampler | None = None + self.predict_sampler: GeoSampler | None = None # Batch samplers - self.batch_sampler: Optional[BatchGeoSampler] = None - self.train_batch_sampler: Optional[BatchGeoSampler] = None - self.val_batch_sampler: Optional[BatchGeoSampler] = None - self.test_batch_sampler: Optional[BatchGeoSampler] = None - self.predict_batch_sampler: Optional[BatchGeoSampler] = None + self.batch_sampler: BatchGeoSampler | None = None + self.train_batch_sampler: BatchGeoSampler | None = None + self.val_batch_sampler: BatchGeoSampler | None = None + self.test_batch_sampler: BatchGeoSampler | None = None + self.predict_batch_sampler: BatchGeoSampler | None = None def setup(self, stage: str) -> None: """Set up datasets and samplers. diff --git a/torchgeo/datamodules/gid15.py b/torchgeo/datamodules/gid15.py index 8a40bddb0..9f6b2f5da 100644 --- a/torchgeo/datamodules/gid15.py +++ b/torchgeo/datamodules/gid15.py @@ -3,7 +3,7 @@ """GID-15 datamodule.""" -from typing import Any, Union +from typing import Any import kornia.augmentation as K @@ -26,7 +26,7 @@ class GID15DataModule(NonGeoDataModule): def __init__( self, batch_size: int = 64, - patch_size: Union[tuple[int, int], int] = 64, + patch_size: tuple[int, int] | int = 64, val_split_pct: float = 0.2, num_workers: int = 0, **kwargs: Any, diff --git a/torchgeo/datamodules/inria.py b/torchgeo/datamodules/inria.py index ca46c39dc..c8489441f 100644 --- a/torchgeo/datamodules/inria.py +++ b/torchgeo/datamodules/inria.py @@ -3,7 +3,7 @@ """InriaAerialImageLabeling datamodule.""" -from typing import Any, Union +from typing import Any import kornia.augmentation as K @@ -26,7 +26,7 @@ class InriaAerialImageLabelingDataModule(NonGeoDataModule): def __init__( self, batch_size: int = 64, - patch_size: Union[tuple[int, int], int] = 64, + patch_size: tuple[int, int] | int = 64, num_workers: int = 0, **kwargs: Any, ) -> None: diff --git a/torchgeo/datamodules/l7irish.py b/torchgeo/datamodules/l7irish.py index dbc6dddeb..82a6f10d9 100644 --- a/torchgeo/datamodules/l7irish.py +++ b/torchgeo/datamodules/l7irish.py @@ -3,7 +3,7 @@ """L7 Irish datamodule.""" -from typing import Any, Optional, Union +from typing import Any import kornia.augmentation as K import torch @@ -25,8 +25,8 @@ class L7IrishDataModule(GeoDataModule): def __init__( self, batch_size: int = 1, - patch_size: Union[int, tuple[int, int]] = 224, - length: Optional[int] = None, + patch_size: int | tuple[int, int] = 224, + length: int | None = None, num_workers: int = 0, **kwargs: Any, ) -> None: diff --git a/torchgeo/datamodules/l8biome.py b/torchgeo/datamodules/l8biome.py index f372aba05..7201e69d1 100644 --- a/torchgeo/datamodules/l8biome.py +++ b/torchgeo/datamodules/l8biome.py @@ -3,7 +3,7 @@ """L8 Biome datamodule.""" -from typing import Any, Optional, Union +from typing import Any import kornia.augmentation as K import torch @@ -25,8 +25,8 @@ class L8BiomeDataModule(GeoDataModule): def __init__( self, batch_size: int = 1, - patch_size: Union[int, tuple[int, int]] = 224, - length: Optional[int] = None, + patch_size: int | tuple[int, int] = 224, + length: int | None = None, num_workers: int = 0, **kwargs: Any, ) -> None: diff --git a/torchgeo/datamodules/levircd.py b/torchgeo/datamodules/levircd.py index 411717fed..2107928d4 100644 --- a/torchgeo/datamodules/levircd.py +++ b/torchgeo/datamodules/levircd.py @@ -3,7 +3,7 @@ """LEVIR-CD+ datamodule.""" -from typing import Any, Union +from typing import Any import kornia.augmentation as K @@ -25,7 +25,7 @@ class LEVIRCDDataModule(NonGeoDataModule): def __init__( self, batch_size: int = 8, - patch_size: Union[tuple[int, int], int] = 256, + patch_size: tuple[int, int] | int = 256, num_workers: int = 0, **kwargs: Any, ) -> None: @@ -70,7 +70,7 @@ class LEVIRCDPlusDataModule(NonGeoDataModule): def __init__( self, batch_size: int = 8, - patch_size: Union[tuple[int, int], int] = 256, + patch_size: tuple[int, int] | int = 256, val_split_pct: float = 0.2, num_workers: int = 0, **kwargs: Any, diff --git a/torchgeo/datamodules/naip.py b/torchgeo/datamodules/naip.py index 1631734d0..5fb24a43a 100644 --- a/torchgeo/datamodules/naip.py +++ b/torchgeo/datamodules/naip.py @@ -3,7 +3,7 @@ """National Agriculture Imagery Program (NAIP) datamodule.""" -from typing import Any, Optional, Union +from typing import Any import kornia.augmentation as K from matplotlib.figure import Figure @@ -23,8 +23,8 @@ class NAIPChesapeakeDataModule(GeoDataModule): def __init__( self, batch_size: int = 64, - patch_size: Union[int, tuple[int, int]] = 256, - length: Optional[int] = None, + patch_size: int | tuple[int, int] = 256, + length: int | None = None, num_workers: int = 0, **kwargs: Any, ) -> None: diff --git a/torchgeo/datamodules/oscd.py b/torchgeo/datamodules/oscd.py index 19f346770..61c5ca9aa 100644 --- a/torchgeo/datamodules/oscd.py +++ b/torchgeo/datamodules/oscd.py @@ -3,7 +3,7 @@ """OSCD datamodule.""" -from typing import Any, Union +from typing import Any import kornia.augmentation as K import torch @@ -60,7 +60,7 @@ class OSCDDataModule(NonGeoDataModule): def __init__( self, batch_size: int = 64, - patch_size: Union[tuple[int, int], int] = 64, + patch_size: tuple[int, int] | int = 64, val_split_pct: float = 0.2, num_workers: int = 0, **kwargs: Any, diff --git a/torchgeo/datamodules/potsdam.py b/torchgeo/datamodules/potsdam.py index 0bd712d46..397ef25d7 100644 --- a/torchgeo/datamodules/potsdam.py +++ b/torchgeo/datamodules/potsdam.py @@ -3,7 +3,7 @@ """Potsdam datamodule.""" -from typing import Any, Union +from typing import Any import kornia.augmentation as K @@ -26,7 +26,7 @@ class Potsdam2DDataModule(NonGeoDataModule): def __init__( self, batch_size: int = 64, - patch_size: Union[tuple[int, int], int] = 64, + patch_size: tuple[int, int] | int = 64, val_split_pct: float = 0.2, num_workers: int = 0, **kwargs: Any, diff --git a/torchgeo/datamodules/sentinel2_cdl.py b/torchgeo/datamodules/sentinel2_cdl.py index b7315cb98..b219d7315 100644 --- a/torchgeo/datamodules/sentinel2_cdl.py +++ b/torchgeo/datamodules/sentinel2_cdl.py @@ -3,7 +3,7 @@ """Sentinel-2 and CDL datamodule.""" -from typing import Any, Optional, Union +from typing import Any import kornia.augmentation as K import torch @@ -26,8 +26,8 @@ class Sentinel2CDLDataModule(GeoDataModule): def __init__( self, batch_size: int = 64, - patch_size: Union[int, tuple[int, int]] = 64, - length: Optional[int] = None, + patch_size: int | tuple[int, int] = 64, + length: int | None = None, num_workers: int = 0, **kwargs: Any, ) -> None: diff --git a/torchgeo/datamodules/sentinel2_nccm.py b/torchgeo/datamodules/sentinel2_nccm.py index feabdb4f6..952bc2812 100644 --- a/torchgeo/datamodules/sentinel2_nccm.py +++ b/torchgeo/datamodules/sentinel2_nccm.py @@ -3,7 +3,7 @@ """Sentinel-2 and NCCM datamodule.""" -from typing import Any, Optional, Union +from typing import Any import kornia.augmentation as K import torch @@ -26,8 +26,8 @@ class Sentinel2NCCMDataModule(GeoDataModule): def __init__( self, batch_size: int = 64, - patch_size: Union[int, tuple[int, int]] = 64, - length: Optional[int] = None, + patch_size: int | tuple[int, int] = 64, + length: int | None = None, num_workers: int = 0, **kwargs: Any, ) -> None: diff --git a/torchgeo/datamodules/sentinel2_south_america_soybean.py b/torchgeo/datamodules/sentinel2_south_america_soybean.py index 6ce54000c..fee2b5c64 100644 --- a/torchgeo/datamodules/sentinel2_south_america_soybean.py +++ b/torchgeo/datamodules/sentinel2_south_america_soybean.py @@ -5,7 +5,7 @@ """South America Soybean datamodule.""" -from typing import Any, Optional, Union +from typing import Any import kornia.augmentation as K import torch @@ -28,8 +28,8 @@ class Sentinel2SouthAmericaSoybeanDataModule(GeoDataModule): def __init__( self, batch_size: int = 64, - patch_size: Union[int, tuple[int, int]] = 64, - length: Optional[int] = None, + patch_size: int | tuple[int, int] = 64, + length: int | None = None, num_workers: int = 0, **kwargs: Any, ) -> None: diff --git a/torchgeo/datamodules/ssl4eo_benchmark.py b/torchgeo/datamodules/ssl4eo_benchmark.py index f02bbda6d..d1b9d5184 100644 --- a/torchgeo/datamodules/ssl4eo_benchmark.py +++ b/torchgeo/datamodules/ssl4eo_benchmark.py @@ -3,7 +3,7 @@ """SSL4EO datamodule.""" -from typing import Any, Union +from typing import Any import kornia.augmentation as K from kornia.constants import DataKey, Resample @@ -23,7 +23,7 @@ class SSL4EOLBenchmarkDataModule(NonGeoDataModule): def __init__( self, batch_size: int = 64, - patch_size: Union[int, tuple[int, int]] = 224, + patch_size: int | tuple[int, int] = 224, num_workers: int = 0, **kwargs: Any, ) -> None: diff --git a/torchgeo/datamodules/utils.py b/torchgeo/datamodules/utils.py index 7083012f8..a6069250d 100644 --- a/torchgeo/datamodules/utils.py +++ b/torchgeo/datamodules/utils.py @@ -4,8 +4,8 @@ """Common datamodule utilities.""" import math -from collections.abc import Iterable -from typing import Any, Callable, Optional, Union +from collections.abc import Callable, Iterable +from typing import Any import numpy as np import torch @@ -103,9 +103,9 @@ def collate_fn_detection(batch: list[dict[str, Tensor]]) -> dict[str, Any]: def dataset_split( - dataset: Union[TensorDataset, NonGeoDataset], + dataset: TensorDataset | NonGeoDataset, val_pct: float, - test_pct: Optional[float] = None, + test_pct: float | None = None, ) -> list[Subset[Any]]: """Split a torch Dataset into train/val/test sets. @@ -142,9 +142,9 @@ def dataset_split( def group_shuffle_split( groups: Iterable[Any], - train_size: Optional[float] = None, - test_size: Optional[float] = None, - random_state: Optional[int] = None, + train_size: float | None = None, + test_size: float | None = None, + random_state: int | None = None, ) -> tuple[list[int], list[int]]: """Method for performing a single group-wise shuffle split of data. diff --git a/torchgeo/datamodules/vaihingen.py b/torchgeo/datamodules/vaihingen.py index 883ff7781..441fafdd9 100644 --- a/torchgeo/datamodules/vaihingen.py +++ b/torchgeo/datamodules/vaihingen.py @@ -3,7 +3,7 @@ """Vaihingen datamodule.""" -from typing import Any, Union +from typing import Any import kornia.augmentation as K @@ -26,7 +26,7 @@ class Vaihingen2DDataModule(NonGeoDataModule): def __init__( self, batch_size: int = 64, - patch_size: Union[tuple[int, int], int] = 64, + patch_size: tuple[int, int] | int = 64, val_split_pct: float = 0.2, num_workers: int = 0, **kwargs: Any, diff --git a/torchgeo/datamodules/vhr10.py b/torchgeo/datamodules/vhr10.py index 0059d6c71..0bafef71b 100644 --- a/torchgeo/datamodules/vhr10.py +++ b/torchgeo/datamodules/vhr10.py @@ -3,7 +3,7 @@ """NWPU VHR-10 datamodule.""" -from typing import Any, Union +from typing import Any import kornia.augmentation as K import torch @@ -26,7 +26,7 @@ class VHR10DataModule(NonGeoDataModule): def __init__( self, batch_size: int = 64, - patch_size: Union[tuple[int, int], int] = 512, + patch_size: tuple[int, int] | int = 512, num_workers: int = 0, val_split_pct: float = 0.2, test_split_pct: float = 0.2, diff --git a/torchgeo/datasets/advance.py b/torchgeo/datasets/advance.py index 3618db0fa..139bba0f4 100644 --- a/torchgeo/datasets/advance.py +++ b/torchgeo/datasets/advance.py @@ -5,7 +5,8 @@ import glob import os -from typing import Callable, Optional, cast +from collections.abc import Callable +from typing import cast import matplotlib.pyplot as plt import numpy as np @@ -87,7 +88,7 @@ class ADVANCE(NonGeoDataset): def __init__( self, root: str = "data", - transforms: Optional[Callable[[dict[str, Tensor]], dict[str, Tensor]]] = None, + transforms: Callable[[dict[str, Tensor]], dict[str, Tensor]] | None = None, download: bool = False, checksum: bool = False, ) -> None: @@ -228,7 +229,7 @@ class ADVANCE(NonGeoDataset): self, sample: dict[str, Tensor], show_titles: bool = True, - suptitle: Optional[str] = None, + suptitle: str | None = None, ) -> Figure: """Plot a sample from the dataset. diff --git a/torchgeo/datasets/agb_live_woody_density.py b/torchgeo/datasets/agb_live_woody_density.py index 6c5959a06..b12da8f50 100644 --- a/torchgeo/datasets/agb_live_woody_density.py +++ b/torchgeo/datasets/agb_live_woody_density.py @@ -5,8 +5,8 @@ import json import os -from collections.abc import Iterable -from typing import Any, Callable, Optional, Union +from collections.abc import Callable, Iterable +from typing import Any import matplotlib.pyplot as plt from matplotlib.figure import Figure @@ -56,10 +56,10 @@ class AbovegroundLiveWoodyBiomassDensity(RasterDataset): def __init__( self, - paths: Union[str, Iterable[str]] = "data", - crs: Optional[CRS] = None, - res: Optional[float] = None, - transforms: Optional[Callable[[dict[str, Any]], dict[str, Any]]] = None, + paths: str | Iterable[str] = "data", + crs: CRS | None = None, + res: float | None = None, + transforms: Callable[[dict[str, Any]], dict[str, Any]] | None = None, download: bool = False, cache: bool = True, ) -> None: @@ -121,7 +121,7 @@ class AbovegroundLiveWoodyBiomassDensity(RasterDataset): self, sample: dict[str, Any], show_titles: bool = True, - suptitle: Optional[str] = None, + suptitle: str | None = None, ) -> Figure: """Plot a sample from the dataset. diff --git a/torchgeo/datasets/agrifieldnet.py b/torchgeo/datasets/agrifieldnet.py index 67a0c994f..3f4ceffd3 100644 --- a/torchgeo/datasets/agrifieldnet.py +++ b/torchgeo/datasets/agrifieldnet.py @@ -5,8 +5,8 @@ import os import re -from collections.abc import Iterable, Sequence -from typing import Any, Callable, Optional, Union, cast +from collections.abc import Callable, Iterable, Sequence +from typing import Any, cast import matplotlib.pyplot as plt import torch @@ -113,11 +113,11 @@ class AgriFieldNet(RasterDataset): def __init__( self, - paths: Union[str, Iterable[str]] = "data", - crs: Optional[CRS] = None, + paths: str | Iterable[str] = "data", + crs: CRS | None = None, classes: list[int] = list(cmap.keys()), bands: Sequence[str] = all_bands, - transforms: Optional[Callable[[dict[str, Tensor]], dict[str, Tensor]]] = None, + transforms: Callable[[dict[str, Tensor]], dict[str, Tensor]] | None = None, cache: bool = True, ) -> None: """Initialize a new AgriFieldNet dataset instance. @@ -218,7 +218,7 @@ class AgriFieldNet(RasterDataset): self, sample: dict[str, Tensor], show_titles: bool = True, - suptitle: Optional[str] = None, + suptitle: str | None = None, ) -> Figure: """Plot a sample from the dataset. diff --git a/torchgeo/datasets/airphen.py b/torchgeo/datasets/airphen.py index 8a2a73b8b..a38db917a 100644 --- a/torchgeo/datasets/airphen.py +++ b/torchgeo/datasets/airphen.py @@ -3,7 +3,7 @@ """Airphen dataset.""" -from typing import Any, Optional +from typing import Any import matplotlib.pyplot as plt from matplotlib.figure import Figure @@ -46,7 +46,7 @@ class Airphen(RasterDataset): self, sample: dict[str, Any], show_titles: bool = True, - suptitle: Optional[str] = None, + suptitle: str | None = None, ) -> Figure: """Plot a sample from the dataset. diff --git a/torchgeo/datasets/astergdem.py b/torchgeo/datasets/astergdem.py index 99215f774..801b24a52 100644 --- a/torchgeo/datasets/astergdem.py +++ b/torchgeo/datasets/astergdem.py @@ -3,7 +3,8 @@ """Aster Global Digital Elevation Model dataset.""" -from typing import Any, Callable, Optional, Union +from collections.abc import Callable +from typing import Any import matplotlib.pyplot as plt from matplotlib.figure import Figure @@ -46,10 +47,10 @@ class AsterGDEM(RasterDataset): def __init__( self, - paths: Union[str, list[str]] = "data", - crs: Optional[CRS] = None, - res: Optional[float] = None, - transforms: Optional[Callable[[dict[str, Any]], dict[str, Any]]] = None, + paths: str | list[str] = "data", + crs: CRS | None = None, + res: float | None = None, + transforms: Callable[[dict[str, Any]], dict[str, Any]] | None = None, cache: bool = True, ) -> None: """Initialize a new Dataset instance. @@ -89,7 +90,7 @@ class AsterGDEM(RasterDataset): self, sample: dict[str, Any], show_titles: bool = True, - suptitle: Optional[str] = None, + suptitle: str | None = None, ) -> Figure: """Plot a sample from the dataset. diff --git a/torchgeo/datasets/benin_cashews.py b/torchgeo/datasets/benin_cashews.py index 6b9f95d34..0aefc55d7 100644 --- a/torchgeo/datasets/benin_cashews.py +++ b/torchgeo/datasets/benin_cashews.py @@ -5,8 +5,8 @@ import json import os +from collections.abc import Callable from functools import lru_cache -from typing import Callable, Optional import matplotlib.pyplot as plt import numpy as np @@ -182,9 +182,9 @@ class BeninSmallHolderCashews(NonGeoDataset): chip_size: int = 256, stride: int = 128, bands: tuple[str, ...] = all_bands, - transforms: Optional[Callable[[dict[str, Tensor]], dict[str, Tensor]]] = None, + transforms: Callable[[dict[str, Tensor]], dict[str, Tensor]] | None = None, download: bool = False, - api_key: Optional[str] = None, + api_key: str | None = None, checksum: bool = False, verbose: bool = False, ) -> None: @@ -408,7 +408,7 @@ class BeninSmallHolderCashews(NonGeoDataset): return images and targets - def _download(self, api_key: Optional[str] = None) -> None: + def _download(self, api_key: str | None = None) -> None: """Download the dataset and extract it. Args: @@ -434,7 +434,7 @@ class BeninSmallHolderCashews(NonGeoDataset): sample: dict[str, Tensor], show_titles: bool = True, time_step: int = 0, - suptitle: Optional[str] = None, + suptitle: str | None = None, ) -> Figure: """Plot a sample from the dataset. diff --git a/torchgeo/datasets/bigearthnet.py b/torchgeo/datasets/bigearthnet.py index 229c546ff..3b92343d9 100644 --- a/torchgeo/datasets/bigearthnet.py +++ b/torchgeo/datasets/bigearthnet.py @@ -6,7 +6,7 @@ import glob import json import os -from typing import Callable, Optional +from collections.abc import Callable import matplotlib.pyplot as plt import numpy as np @@ -275,7 +275,7 @@ class BigEarthNet(NonGeoDataset): split: str = "train", bands: str = "all", num_classes: int = 19, - transforms: Optional[Callable[[dict[str, Tensor]], dict[str, Tensor]]] = None, + transforms: Callable[[dict[str, Tensor]], dict[str, Tensor]] | None = None, download: bool = False, checksum: bool = False, ) -> None: @@ -533,7 +533,7 @@ class BigEarthNet(NonGeoDataset): self, sample: dict[str, Tensor], show_titles: bool = True, - suptitle: Optional[str] = None, + suptitle: str | None = None, ) -> Figure: """Plot a sample from the dataset. diff --git a/torchgeo/datasets/biomassters.py b/torchgeo/datasets/biomassters.py index 970c55949..4b73f58e8 100644 --- a/torchgeo/datasets/biomassters.py +++ b/torchgeo/datasets/biomassters.py @@ -5,7 +5,6 @@ import os from collections.abc import Sequence -from typing import Optional import matplotlib.pyplot as plt import numpy as np @@ -219,7 +218,7 @@ class BioMassters(NonGeoDataset): self, sample: dict[str, Tensor], show_titles: bool = True, - suptitle: Optional[str] = None, + suptitle: str | None = None, ) -> Figure: """Plot a sample from the dataset. diff --git a/torchgeo/datasets/cbf.py b/torchgeo/datasets/cbf.py index d3010e44d..63b76e592 100644 --- a/torchgeo/datasets/cbf.py +++ b/torchgeo/datasets/cbf.py @@ -4,8 +4,8 @@ """Canadian Building Footprints dataset.""" import os -from collections.abc import Iterable -from typing import Any, Callable, Optional, Union +from collections.abc import Callable, Iterable +from typing import Any import matplotlib.pyplot as plt from matplotlib.figure import Figure @@ -61,10 +61,10 @@ class CanadianBuildingFootprints(VectorDataset): def __init__( self, - paths: Union[str, Iterable[str]] = "data", - crs: Optional[CRS] = None, + paths: str | Iterable[str] = "data", + crs: CRS | None = None, res: float = 0.00001, - transforms: Optional[Callable[[dict[str, Any]], dict[str, Any]]] = None, + transforms: Callable[[dict[str, Any]], dict[str, Any]] | None = None, download: bool = False, checksum: bool = False, ) -> None: @@ -127,7 +127,7 @@ class CanadianBuildingFootprints(VectorDataset): self, sample: dict[str, Any], show_titles: bool = True, - suptitle: Optional[str] = None, + suptitle: str | None = None, ) -> Figure: """Plot a sample from the dataset. diff --git a/torchgeo/datasets/cdl.py b/torchgeo/datasets/cdl.py index d0b9eb43a..87ebc9029 100644 --- a/torchgeo/datasets/cdl.py +++ b/torchgeo/datasets/cdl.py @@ -4,8 +4,8 @@ """CDL dataset.""" import os -from collections.abc import Iterable -from typing import Any, Callable, Optional, Union +from collections.abc import Callable, Iterable +from typing import Any import matplotlib.pyplot as plt import torch @@ -206,12 +206,12 @@ class CDL(RasterDataset): def __init__( self, - paths: Union[str, Iterable[str]] = "data", - crs: Optional[CRS] = None, - res: Optional[float] = None, + paths: str | Iterable[str] = "data", + crs: CRS | None = None, + res: float | None = None, years: list[int] = [2023], classes: list[int] = list(cmap.keys()), - transforms: Optional[Callable[[dict[str, Any]], dict[str, Any]]] = None, + transforms: Callable[[dict[str, Any]], dict[str, Any]] | None = None, cache: bool = True, download: bool = False, checksum: bool = False, @@ -336,7 +336,7 @@ class CDL(RasterDataset): self, sample: dict[str, Any], show_titles: bool = True, - suptitle: Optional[str] = None, + suptitle: str | None = None, ) -> Figure: """Plot a sample from the dataset. diff --git a/torchgeo/datasets/chabud.py b/torchgeo/datasets/chabud.py index febed7031..21362bb9f 100644 --- a/torchgeo/datasets/chabud.py +++ b/torchgeo/datasets/chabud.py @@ -4,7 +4,7 @@ """ChaBuD dataset.""" import os -from typing import Callable, Optional +from collections.abc import Callable import matplotlib.pyplot as plt import numpy as np @@ -77,7 +77,7 @@ class ChaBuD(NonGeoDataset): root: str = "data", split: str = "train", bands: list[str] = all_bands, - transforms: Optional[Callable[[dict[str, Tensor]], dict[str, Tensor]]] = None, + transforms: Callable[[dict[str, Tensor]], dict[str, Tensor]] | None = None, download: bool = False, checksum: bool = False, ) -> None: @@ -235,7 +235,7 @@ class ChaBuD(NonGeoDataset): self, sample: dict[str, Tensor], show_titles: bool = True, - suptitle: Optional[str] = None, + suptitle: str | None = None, ) -> Figure: """Plot a sample from the dataset. diff --git a/torchgeo/datasets/chesapeake.py b/torchgeo/datasets/chesapeake.py index f92d1e21d..a0509302e 100644 --- a/torchgeo/datasets/chesapeake.py +++ b/torchgeo/datasets/chesapeake.py @@ -6,8 +6,8 @@ import abc import os import sys -from collections.abc import Iterable, Sequence -from typing import Any, Callable, Optional, Union, cast +from collections.abc import Callable, Iterable, Sequence +from typing import Any, cast import fiona import matplotlib.pyplot as plt @@ -90,10 +90,10 @@ class Chesapeake(RasterDataset, abc.ABC): def __init__( self, - paths: Union[str, Iterable[str]] = "data", - crs: Optional[CRS] = None, - res: Optional[float] = None, - transforms: Optional[Callable[[dict[str, Any]], dict[str, Any]]] = None, + paths: str | Iterable[str] = "data", + crs: CRS | None = None, + res: float | None = None, + transforms: Callable[[dict[str, Any]], dict[str, Any]] | None = None, cache: bool = True, download: bool = False, checksum: bool = False, @@ -170,7 +170,7 @@ class Chesapeake(RasterDataset, abc.ABC): self, sample: dict[str, Any], show_titles: bool = True, - suptitle: Optional[str] = None, + suptitle: str | None = None, ) -> Figure: """Plot a sample from the dataset. @@ -512,7 +512,7 @@ class ChesapeakeCVPR(GeoDataset): root: str = "data", splits: Sequence[str] = ["de-train"], layers: Sequence[str] = ["naip-new", "lc"], - transforms: Optional[Callable[[dict[str, Any]], dict[str, Any]]] = None, + transforms: Callable[[dict[str, Any]], dict[str, Any]] | None = None, cache: bool = True, download: bool = False, checksum: bool = False, @@ -711,7 +711,7 @@ class ChesapeakeCVPR(GeoDataset): self, sample: dict[str, Tensor], show_titles: bool = True, - suptitle: Optional[str] = None, + suptitle: str | None = None, ) -> Figure: """Plot a sample from the dataset. diff --git a/torchgeo/datasets/cloud_cover.py b/torchgeo/datasets/cloud_cover.py index 60bbe78e2..fbe9e9e32 100644 --- a/torchgeo/datasets/cloud_cover.py +++ b/torchgeo/datasets/cloud_cover.py @@ -5,8 +5,8 @@ import json import os -from collections.abc import Sequence -from typing import Any, Callable, Optional +from collections.abc import Callable, Sequence +from typing import Any import matplotlib.pyplot as plt import numpy as np @@ -111,9 +111,9 @@ class CloudCoverDetection(NonGeoDataset): root: str = "data", split: str = "train", bands: Sequence[str] = band_names, - transforms: Optional[Callable[[dict[str, Tensor]], dict[str, Tensor]]] = None, + transforms: Callable[[dict[str, Tensor]], dict[str, Tensor]] | None = None, download: bool = False, - api_key: Optional[str] = None, + api_key: str | None = None, checksum: bool = False, ) -> None: """Initiatlize a new Cloud Cover Detection Dataset instance. @@ -329,7 +329,7 @@ class CloudCoverDetection(NonGeoDataset): return images and targets - def _download(self, api_key: Optional[str] = None) -> None: + def _download(self, api_key: str | None = None) -> None: """Download the dataset and extract it. Args: @@ -355,7 +355,7 @@ class CloudCoverDetection(NonGeoDataset): self, sample: dict[str, Tensor], show_titles: bool = True, - suptitle: Optional[str] = None, + suptitle: str | None = None, ) -> Figure: """Plot a sample from the dataset. diff --git a/torchgeo/datasets/cms_mangrove_canopy.py b/torchgeo/datasets/cms_mangrove_canopy.py index ac42c8d1e..8f2a53dfb 100644 --- a/torchgeo/datasets/cms_mangrove_canopy.py +++ b/torchgeo/datasets/cms_mangrove_canopy.py @@ -4,7 +4,8 @@ """CMS Global Mangrove Canopy dataset.""" import os -from typing import Any, Callable, Optional, Union +from collections.abc import Callable +from typing import Any import matplotlib.pyplot as plt from matplotlib.figure import Figure @@ -167,12 +168,12 @@ class CMSGlobalMangroveCanopy(RasterDataset): def __init__( self, - paths: Union[str, list[str]] = "data", - crs: Optional[CRS] = None, - res: Optional[float] = None, + paths: str | list[str] = "data", + crs: CRS | None = None, + res: float | None = None, measurement: str = "agb", country: str = all_countries[0], - transforms: Optional[Callable[[dict[str, Any]], dict[str, Any]]] = None, + transforms: Callable[[dict[str, Any]], dict[str, Any]] | None = None, cache: bool = True, checksum: bool = False, ) -> None: @@ -250,7 +251,7 @@ class CMSGlobalMangroveCanopy(RasterDataset): self, sample: dict[str, Any], show_titles: bool = True, - suptitle: Optional[str] = None, + suptitle: str | None = None, ) -> Figure: """Plot a sample from the dataset. diff --git a/torchgeo/datasets/cowc.py b/torchgeo/datasets/cowc.py index 0e0518502..902cbdcb4 100644 --- a/torchgeo/datasets/cowc.py +++ b/torchgeo/datasets/cowc.py @@ -6,7 +6,8 @@ import abc import csv import os -from typing import Callable, Optional, cast +from collections.abc import Callable +from typing import cast import matplotlib.pyplot as plt import numpy as np @@ -65,7 +66,7 @@ class COWC(NonGeoDataset, abc.ABC): self, root: str = "data", split: str = "train", - transforms: Optional[Callable[[dict[str, Tensor]], dict[str, Tensor]]] = None, + transforms: Callable[[dict[str, Tensor]], dict[str, Tensor]] | None = None, download: bool = False, checksum: bool = False, ) -> None: @@ -192,7 +193,7 @@ class COWC(NonGeoDataset, abc.ABC): self, sample: dict[str, Tensor], show_titles: bool = True, - suptitle: Optional[str] = None, + suptitle: str | None = None, ) -> Figure: """Plot a sample from the dataset. diff --git a/torchgeo/datasets/cropharvest.py b/torchgeo/datasets/cropharvest.py index 08d69f2fc..e386ac92a 100644 --- a/torchgeo/datasets/cropharvest.py +++ b/torchgeo/datasets/cropharvest.py @@ -6,7 +6,7 @@ import glob import json import os -from typing import Callable, Optional +from collections.abc import Callable import matplotlib.pyplot as plt import numpy as np @@ -96,7 +96,7 @@ class CropHarvest(NonGeoDataset): def __init__( self, root: str = "data", - transforms: Optional[Callable[[dict[str, Tensor]], dict[str, Tensor]]] = None, + transforms: Callable[[dict[str, Tensor]], dict[str, Tensor]] | None = None, download: bool = False, checksum: bool = False, ) -> None: @@ -293,7 +293,7 @@ class CropHarvest(NonGeoDataset): features_path = os.path.join(self.root, self.file_dict["features"]["filename"]) extract_archive(features_path) - def plot(self, sample: dict[str, Tensor], subtitle: Optional[str] = None) -> Figure: + def plot(self, sample: dict[str, Tensor], subtitle: str | None = None) -> Figure: """Plot a sample from the dataset using bands for Agriculture RGB composite. Args: diff --git a/torchgeo/datasets/cv4a_kenya_crop_type.py b/torchgeo/datasets/cv4a_kenya_crop_type.py index 0ea981d8a..dc4d97e2a 100644 --- a/torchgeo/datasets/cv4a_kenya_crop_type.py +++ b/torchgeo/datasets/cv4a_kenya_crop_type.py @@ -5,8 +5,8 @@ import csv import os +from collections.abc import Callable from functools import lru_cache -from typing import Callable, Optional import matplotlib.pyplot as plt import numpy as np @@ -125,9 +125,9 @@ class CV4AKenyaCropType(NonGeoDataset): chip_size: int = 256, stride: int = 128, bands: tuple[str, ...] = band_names, - transforms: Optional[Callable[[dict[str, Tensor]], dict[str, Tensor]]] = None, + transforms: Callable[[dict[str, Tensor]], dict[str, Tensor]] | None = None, download: bool = False, - api_key: Optional[str] = None, + api_key: str | None = None, checksum: bool = False, verbose: bool = False, ) -> None: @@ -388,7 +388,7 @@ class CV4AKenyaCropType(NonGeoDataset): return train_field_ids, test_field_ids - def _download(self, api_key: Optional[str] = None) -> None: + def _download(self, api_key: str | None = None) -> None: """Download the dataset and extract it. Args: @@ -411,7 +411,7 @@ class CV4AKenyaCropType(NonGeoDataset): sample: dict[str, Tensor], show_titles: bool = True, time_step: int = 0, - suptitle: Optional[str] = None, + suptitle: str | None = None, ) -> Figure: """Plot a sample from the dataset. diff --git a/torchgeo/datasets/cyclone.py b/torchgeo/datasets/cyclone.py index 6ab5a14c9..46f2e9d6f 100644 --- a/torchgeo/datasets/cyclone.py +++ b/torchgeo/datasets/cyclone.py @@ -5,8 +5,9 @@ import json import os +from collections.abc import Callable from functools import lru_cache -from typing import Any, Callable, Optional +from typing import Any import matplotlib.pyplot as plt import numpy as np @@ -73,9 +74,9 @@ class TropicalCyclone(NonGeoDataset): self, root: str = "data", split: str = "train", - transforms: Optional[Callable[[dict[str, Any]], dict[str, Any]]] = None, + transforms: Callable[[dict[str, Any]], dict[str, Any]] | None = None, download: bool = False, - api_key: Optional[str] = None, + api_key: str | None = None, checksum: bool = False, ) -> None: """Initialize a new Tropical Cyclone Wind Estimation Competition Dataset. @@ -204,7 +205,7 @@ class TropicalCyclone(NonGeoDataset): return False return True - def _download(self, api_key: Optional[str] = None) -> None: + def _download(self, api_key: str | None = None) -> None: """Download the dataset and extract it. Args: @@ -227,7 +228,7 @@ class TropicalCyclone(NonGeoDataset): self, sample: dict[str, Any], show_titles: bool = True, - suptitle: Optional[str] = None, + suptitle: str | None = None, ) -> Figure: """Plot a sample from the dataset. diff --git a/torchgeo/datasets/deepglobelandcover.py b/torchgeo/datasets/deepglobelandcover.py index c265321a3..03b9a91ff 100644 --- a/torchgeo/datasets/deepglobelandcover.py +++ b/torchgeo/datasets/deepglobelandcover.py @@ -4,7 +4,7 @@ """DeepGlobe Land Cover Classification Challenge dataset.""" import os -from typing import Callable, Optional +from collections.abc import Callable import matplotlib.pyplot as plt import numpy as np @@ -102,7 +102,7 @@ class DeepGlobeLandCover(NonGeoDataset): self, root: str = "data", split: str = "train", - transforms: Optional[Callable[[dict[str, Tensor]], dict[str, Tensor]]] = None, + transforms: Callable[[dict[str, Tensor]], dict[str, Tensor]] | None = None, checksum: bool = False, ) -> None: """Initialize a new DeepGlobeLandCover dataset instance. @@ -229,7 +229,7 @@ class DeepGlobeLandCover(NonGeoDataset): self, sample: dict[str, Tensor], show_titles: bool = True, - suptitle: Optional[str] = None, + suptitle: str | None = None, alpha: float = 0.5, ) -> Figure: """Plot a sample from the dataset. diff --git a/torchgeo/datasets/dfc2022.py b/torchgeo/datasets/dfc2022.py index 5268ea4f0..40e0ad9b6 100644 --- a/torchgeo/datasets/dfc2022.py +++ b/torchgeo/datasets/dfc2022.py @@ -5,8 +5,7 @@ import glob import os -from collections.abc import Sequence -from typing import Callable, Optional +from collections.abc import Callable, Sequence import matplotlib.pyplot as plt import numpy as np @@ -144,7 +143,7 @@ class DFC2022(NonGeoDataset): self, root: str = "data", split: str = "train", - transforms: Optional[Callable[[dict[str, Tensor]], dict[str, Tensor]]] = None, + transforms: Callable[[dict[str, Tensor]], dict[str, Tensor]] | None = None, checksum: bool = False, ) -> None: """Initialize a new DFC2022 dataset instance. @@ -229,7 +228,7 @@ class DFC2022(NonGeoDataset): return files - def _load_image(self, path: str, shape: Optional[Sequence[int]] = None) -> Tensor: + def _load_image(self, path: str, shape: Sequence[int] | None = None) -> Tensor: """Load a single image. Args: @@ -296,7 +295,7 @@ class DFC2022(NonGeoDataset): self, sample: dict[str, Tensor], show_titles: bool = True, - suptitle: Optional[str] = None, + suptitle: str | None = None, ) -> Figure: """Plot a sample from the dataset. diff --git a/torchgeo/datasets/enviroatlas.py b/torchgeo/datasets/enviroatlas.py index f6842738a..0518ffeb8 100644 --- a/torchgeo/datasets/enviroatlas.py +++ b/torchgeo/datasets/enviroatlas.py @@ -5,8 +5,8 @@ import os import sys -from collections.abc import Sequence -from typing import Any, Callable, Optional, cast +from collections.abc import Callable, Sequence +from typing import Any, cast import fiona import matplotlib.pyplot as plt @@ -255,7 +255,7 @@ class EnviroAtlas(GeoDataset): root: str = "data", splits: Sequence[str] = ["pittsburgh_pa-2010_1m-train"], layers: Sequence[str] = ["naip", "prior"], - transforms: Optional[Callable[[dict[str, Any]], dict[str, Any]]] = None, + transforms: Callable[[dict[str, Any]], dict[str, Any]] | None = None, prior_as_input: bool = False, cache: bool = True, download: bool = False, @@ -445,7 +445,7 @@ class EnviroAtlas(GeoDataset): self, sample: dict[str, Any], show_titles: bool = True, - suptitle: Optional[str] = None, + suptitle: str | None = None, ) -> Figure: """Plot a sample from the dataset. diff --git a/torchgeo/datasets/esri2020.py b/torchgeo/datasets/esri2020.py index 2d26f2456..2238f6943 100644 --- a/torchgeo/datasets/esri2020.py +++ b/torchgeo/datasets/esri2020.py @@ -5,8 +5,8 @@ import glob import os -from collections.abc import Iterable -from typing import Any, Callable, Optional, Union +from collections.abc import Callable, Iterable +from typing import Any import matplotlib.pyplot as plt from matplotlib.figure import Figure @@ -68,10 +68,10 @@ class Esri2020(RasterDataset): def __init__( self, - paths: Union[str, Iterable[str]] = "data", - crs: Optional[CRS] = None, - res: Optional[float] = None, - transforms: Optional[Callable[[dict[str, Any]], dict[str, Any]]] = None, + paths: str | Iterable[str] = "data", + crs: CRS | None = None, + res: float | None = None, + transforms: Callable[[dict[str, Any]], dict[str, Any]] | None = None, cache: bool = True, download: bool = False, checksum: bool = False, @@ -138,7 +138,7 @@ class Esri2020(RasterDataset): self, sample: dict[str, Any], show_titles: bool = True, - suptitle: Optional[str] = None, + suptitle: str | None = None, ) -> Figure: """Plot a sample from the dataset. diff --git a/torchgeo/datasets/etci2021.py b/torchgeo/datasets/etci2021.py index 7dfa50fb2..7dd9bf022 100644 --- a/torchgeo/datasets/etci2021.py +++ b/torchgeo/datasets/etci2021.py @@ -5,7 +5,7 @@ import glob import os -from typing import Callable, Optional +from collections.abc import Callable import matplotlib.pyplot as plt import numpy as np @@ -82,7 +82,7 @@ class ETCI2021(NonGeoDataset): self, root: str = "data", split: str = "train", - transforms: Optional[Callable[[dict[str, Tensor]], dict[str, Tensor]]] = None, + transforms: Callable[[dict[str, Tensor]], dict[str, Tensor]] | None = None, download: bool = False, checksum: bool = False, ) -> None: @@ -255,7 +255,7 @@ class ETCI2021(NonGeoDataset): self, sample: dict[str, Tensor], show_titles: bool = True, - suptitle: Optional[str] = None, + suptitle: str | None = None, ) -> Figure: """Plot a sample from the dataset. diff --git a/torchgeo/datasets/eudem.py b/torchgeo/datasets/eudem.py index ce82500ee..15913aeed 100644 --- a/torchgeo/datasets/eudem.py +++ b/torchgeo/datasets/eudem.py @@ -5,8 +5,8 @@ import glob import os -from collections.abc import Iterable -from typing import Any, Callable, Optional, Union +from collections.abc import Callable, Iterable +from typing import Any import matplotlib.pyplot as plt from matplotlib.figure import Figure @@ -83,10 +83,10 @@ class EUDEM(RasterDataset): def __init__( self, - paths: Union[str, Iterable[str]] = "data", - crs: Optional[CRS] = None, - res: Optional[float] = None, - transforms: Optional[Callable[[dict[str, Any]], dict[str, Any]]] = None, + paths: str | Iterable[str] = "data", + crs: CRS | None = None, + res: float | None = None, + transforms: Callable[[dict[str, Any]], dict[str, Any]] | None = None, cache: bool = True, checksum: bool = False, ) -> None: @@ -140,7 +140,7 @@ class EUDEM(RasterDataset): self, sample: dict[str, Any], show_titles: bool = True, - suptitle: Optional[str] = None, + suptitle: str | None = None, ) -> Figure: """Plot a sample from the dataset. diff --git a/torchgeo/datasets/eurocrops.py b/torchgeo/datasets/eurocrops.py index ebfe19a76..b1cfa57b2 100644 --- a/torchgeo/datasets/eurocrops.py +++ b/torchgeo/datasets/eurocrops.py @@ -5,8 +5,8 @@ import csv import os -from collections.abc import Iterable -from typing import Any, Callable, Optional, Union +from collections.abc import Callable, Iterable +from typing import Any import fiona import matplotlib.pyplot as plt @@ -88,11 +88,11 @@ class EuroCrops(VectorDataset): def __init__( self, - paths: Union[str, Iterable[str]] = "data", + paths: str | Iterable[str] = "data", crs: CRS = CRS.from_epsg(4326), res: float = 0.00001, - classes: Optional[list[str]] = None, - transforms: Optional[Callable[[dict[str, Any]], dict[str, Any]]] = None, + classes: list[str] | None = None, + transforms: Callable[[dict[str, Any]], dict[str, Any]] | None = None, download: bool = False, checksum: bool = False, ) -> None: @@ -166,7 +166,7 @@ class EuroCrops(VectorDataset): self.base_url + fname, self.paths, md5=md5 if self.checksum else None ) - def _load_class_map(self, classes: Optional[list[str]]) -> None: + def _load_class_map(self, classes: list[str] | None) -> None: """Load map from HCAT class codes to class indices. If classes is provided, then we simply use those codes. Otherwise, we load @@ -221,7 +221,7 @@ class EuroCrops(VectorDataset): self, sample: dict[str, Any], show_titles: bool = True, - suptitle: Optional[str] = None, + suptitle: str | None = None, ) -> Figure: """Plot a sample from the dataset. diff --git a/torchgeo/datasets/eurosat.py b/torchgeo/datasets/eurosat.py index 2802e6924..e7376310e 100644 --- a/torchgeo/datasets/eurosat.py +++ b/torchgeo/datasets/eurosat.py @@ -4,8 +4,8 @@ """EuroSAT dataset.""" import os -from collections.abc import Sequence -from typing import Callable, Optional, cast +from collections.abc import Callable, Sequence +from typing import cast import matplotlib.pyplot as plt import numpy as np @@ -106,7 +106,7 @@ class EuroSAT(NonGeoClassificationDataset): root: str = "data", split: str = "train", bands: Sequence[str] = BAND_SETS["all"], - transforms: Optional[Callable[[dict[str, Tensor]], dict[str, Tensor]]] = None, + transforms: Callable[[dict[str, Tensor]], dict[str, Tensor]] | None = None, download: bool = False, checksum: bool = False, ) -> None: @@ -247,7 +247,7 @@ class EuroSAT(NonGeoClassificationDataset): self, sample: dict[str, Tensor], show_titles: bool = True, - suptitle: Optional[str] = None, + suptitle: str | None = None, ) -> Figure: """Plot a sample from the dataset. diff --git a/torchgeo/datasets/fair1m.py b/torchgeo/datasets/fair1m.py index 24cc7b97d..49b9349b3 100644 --- a/torchgeo/datasets/fair1m.py +++ b/torchgeo/datasets/fair1m.py @@ -5,7 +5,8 @@ import glob import os -from typing import Any, Callable, Optional, cast +from collections.abc import Callable +from typing import Any, cast from xml.etree.ElementTree import Element, parse import matplotlib.patches as patches @@ -230,7 +231,7 @@ class FAIR1M(NonGeoDataset): self, root: str = "data", split: str = "train", - transforms: Optional[Callable[[dict[str, Tensor]], dict[str, Tensor]]] = None, + transforms: Callable[[dict[str, Tensor]], dict[str, Tensor]] | None = None, download: bool = False, checksum: bool = False, ) -> None: @@ -382,7 +383,7 @@ class FAIR1M(NonGeoDataset): self, sample: dict[str, Tensor], show_titles: bool = True, - suptitle: Optional[str] = None, + suptitle: str | None = None, ) -> Figure: """Plot a sample from the dataset. diff --git a/torchgeo/datasets/fire_risk.py b/torchgeo/datasets/fire_risk.py index d51a4384f..e10bbae8a 100644 --- a/torchgeo/datasets/fire_risk.py +++ b/torchgeo/datasets/fire_risk.py @@ -4,7 +4,8 @@ """FireRisk dataset.""" import os -from typing import Callable, Optional, cast +from collections.abc import Callable +from typing import cast import matplotlib.pyplot as plt from matplotlib.figure import Figure @@ -68,7 +69,7 @@ class FireRisk(NonGeoClassificationDataset): self, root: str = "data", split: str = "train", - transforms: Optional[Callable[[dict[str, Tensor]], dict[str, Tensor]]] = None, + transforms: Callable[[dict[str, Tensor]], dict[str, Tensor]] | None = None, download: bool = False, checksum: bool = False, ) -> None: @@ -136,7 +137,7 @@ class FireRisk(NonGeoClassificationDataset): self, sample: dict[str, Tensor], show_titles: bool = True, - suptitle: Optional[str] = None, + suptitle: str | None = None, ) -> Figure: """Plot a sample from the dataset. diff --git a/torchgeo/datasets/forestdamage.py b/torchgeo/datasets/forestdamage.py index b74f3bcd3..b20ab5296 100644 --- a/torchgeo/datasets/forestdamage.py +++ b/torchgeo/datasets/forestdamage.py @@ -5,7 +5,8 @@ import glob import os -from typing import Any, Callable, Optional +from collections.abc import Callable +from typing import Any from xml.etree import ElementTree import matplotlib.patches as patches @@ -110,7 +111,7 @@ class ForestDamage(NonGeoDataset): def __init__( self, root: str = "data", - transforms: Optional[Callable[[dict[str, Tensor]], dict[str, Tensor]]] = None, + transforms: Callable[[dict[str, Tensor]], dict[str, Tensor]] | None = None, download: bool = False, checksum: bool = False, ) -> None: @@ -259,7 +260,7 @@ class ForestDamage(NonGeoDataset): self, sample: dict[str, Tensor], show_titles: bool = True, - suptitle: Optional[str] = None, + suptitle: str | None = None, ) -> Figure: """Plot a sample from the dataset. diff --git a/torchgeo/datasets/geo.py b/torchgeo/datasets/geo.py index 6aeb1699b..f175df463 100644 --- a/torchgeo/datasets/geo.py +++ b/torchgeo/datasets/geo.py @@ -10,8 +10,8 @@ import os import re import sys import warnings -from collections.abc import Iterable, Sequence -from typing import Any, Callable, Optional, Union, cast +from collections.abc import Callable, Iterable, Sequence +from typing import Any, cast import fiona import fiona.transform @@ -83,7 +83,7 @@ class GeoDataset(Dataset[dict[str, Any]], abc.ABC): dataset = landsat7 | landsat8 """ - paths: Union[str, Iterable[str]] + paths: str | Iterable[str] _crs = CRS.from_epsg(4326) _res = 0.0 @@ -108,7 +108,7 @@ class GeoDataset(Dataset[dict[str, Any]], abc.ABC): __add__ = None # type: ignore[assignment] def __init__( - self, transforms: Optional[Callable[[dict[str, Any]], dict[str, Any]]] = None + self, transforms: Callable[[dict[str, Any]], dict[str, Any]] | None = None ) -> None: """Initialize a new GeoDataset instance. @@ -190,9 +190,7 @@ class GeoDataset(Dataset[dict[str, Any]], abc.ABC): # NOTE: This hack should be removed once the following issue is fixed: # https://github.com/Toblerity/rtree/issues/87 - def __getstate__( - self, - ) -> tuple[dict[str, Any], list[tuple[Any, Any, Optional[Any]]]]: + def __getstate__(self) -> tuple[dict[str, Any], list[tuple[Any, Any, Any | None]]]: """Define how instances are pickled. Returns: @@ -388,11 +386,11 @@ class RasterDataset(GeoDataset): def __init__( self, - paths: Union[str, Iterable[str]] = "data", - crs: Optional[CRS] = None, - res: Optional[float] = None, - bands: Optional[Sequence[str]] = None, - transforms: Optional[Callable[[dict[str, Any]], dict[str, Any]]] = None, + paths: str | Iterable[str] = "data", + crs: CRS | None = None, + res: float | None = None, + bands: Sequence[str] | None = None, + transforms: Callable[[dict[str, Any]], dict[str, Any]] | None = None, cache: bool = True, ) -> None: """Initialize a new RasterDataset instance. @@ -539,7 +537,7 @@ class RasterDataset(GeoDataset): self, filepaths: Sequence[str], query: BoundingBox, - band_indexes: Optional[Sequence[int]] = None, + band_indexes: Sequence[int] | None = None, ) -> Tensor: """Load and merge one or more files. @@ -612,11 +610,11 @@ class VectorDataset(GeoDataset): def __init__( self, - paths: Union[str, Iterable[str]] = "data", - crs: Optional[CRS] = None, + paths: str | Iterable[str] = "data", + crs: CRS | None = None, res: float = 0.0001, - transforms: Optional[Callable[[dict[str, Any]], dict[str, Any]]] = None, - label_name: Optional[str] = None, + transforms: Callable[[dict[str, Any]], dict[str, Any]] | None = None, + label_name: str | None = None, ) -> None: """Initialize a new VectorDataset instance. @@ -809,9 +807,9 @@ class NonGeoClassificationDataset(NonGeoDataset, ImageFolder): # type: ignore[m def __init__( self, root: str = "data", - transforms: Optional[Callable[[dict[str, Tensor]], dict[str, Tensor]]] = None, - loader: Optional[Callable[[str], Any]] = pil_loader, - is_valid_file: Optional[Callable[[str], bool]] = None, + transforms: Callable[[dict[str, Tensor]], dict[str, Tensor]] | None = None, + loader: Callable[[str], Any] | None = pil_loader, + is_valid_file: Callable[[str], bool] | None = None, ) -> None: """Initialize a new NonGeoClassificationDataset instance. @@ -909,7 +907,7 @@ class IntersectionDataset(GeoDataset): collate_fn: Callable[ [Sequence[dict[str, Any]]], dict[str, Any] ] = concat_samples, - transforms: Optional[Callable[[dict[str, Any]], dict[str, Any]]] = None, + transforms: Callable[[dict[str, Any]], dict[str, Any]] | None = None, ) -> None: """Initialize a new IntersectionDataset instance. @@ -1067,7 +1065,7 @@ class UnionDataset(GeoDataset): collate_fn: Callable[ [Sequence[dict[str, Any]]], dict[str, Any] ] = merge_samples, - transforms: Optional[Callable[[dict[str, Any]], dict[str, Any]]] = None, + transforms: Callable[[dict[str, Any]], dict[str, Any]] | None = None, ) -> None: """Initialize a new UnionDataset instance. diff --git a/torchgeo/datasets/gid15.py b/torchgeo/datasets/gid15.py index 6fc2b5201..fb7dd0459 100644 --- a/torchgeo/datasets/gid15.py +++ b/torchgeo/datasets/gid15.py @@ -5,7 +5,7 @@ import glob import os -from typing import Callable, Optional +from collections.abc import Callable import matplotlib.pyplot as plt import numpy as np @@ -89,7 +89,7 @@ class GID15(NonGeoDataset): self, root: str = "data", split: str = "train", - transforms: Optional[Callable[[dict[str, Tensor]], dict[str, Tensor]]] = None, + transforms: Callable[[dict[str, Tensor]], dict[str, Tensor]] | None = None, download: bool = False, checksum: bool = False, ) -> None: @@ -234,7 +234,7 @@ class GID15(NonGeoDataset): md5=self.md5 if self.checksum else None, ) - def plot(self, sample: dict[str, Tensor], suptitle: Optional[str] = None) -> Figure: + def plot(self, sample: dict[str, Tensor], suptitle: str | None = None) -> Figure: """Plot a sample from the dataset. Args: diff --git a/torchgeo/datasets/globbiomass.py b/torchgeo/datasets/globbiomass.py index c9da83b9b..6b5d1a07a 100644 --- a/torchgeo/datasets/globbiomass.py +++ b/torchgeo/datasets/globbiomass.py @@ -5,8 +5,8 @@ import glob import os -from collections.abc import Iterable -from typing import Any, Callable, Optional, Union, cast +from collections.abc import Callable, Iterable +from typing import Any, cast import matplotlib.pyplot as plt import torch @@ -119,11 +119,11 @@ class GlobBiomass(RasterDataset): def __init__( self, - paths: Union[str, Iterable[str]] = "data", - crs: Optional[CRS] = None, - res: Optional[float] = None, + paths: str | Iterable[str] = "data", + crs: CRS | None = None, + res: float | None = None, measurement: str = "agb", - transforms: Optional[Callable[[dict[str, Any]], dict[str, Any]]] = None, + transforms: Callable[[dict[str, Any]], dict[str, Any]] | None = None, cache: bool = True, checksum: bool = False, ) -> None: @@ -225,7 +225,7 @@ class GlobBiomass(RasterDataset): self, sample: dict[str, Any], show_titles: bool = True, - suptitle: Optional[str] = None, + suptitle: str | None = None, ) -> Figure: """Plot a sample from the dataset. diff --git a/torchgeo/datasets/idtrees.py b/torchgeo/datasets/idtrees.py index 715c3bfed..4460bfc08 100644 --- a/torchgeo/datasets/idtrees.py +++ b/torchgeo/datasets/idtrees.py @@ -5,7 +5,8 @@ import glob import os -from typing import Any, Callable, Optional, cast, overload +from collections.abc import Callable +from typing import Any, cast, overload import fiona import matplotlib.pyplot as plt @@ -148,7 +149,7 @@ class IDTReeS(NonGeoDataset): root: str = "data", split: str = "train", task: str = "task1", - transforms: Optional[Callable[[dict[str, Tensor]], dict[str, Tensor]]] = None, + transforms: Callable[[dict[str, Tensor]], dict[str, Tensor]] | None = None, download: bool = False, checksum: bool = False, ) -> None: @@ -333,7 +334,7 @@ class IDTReeS(NonGeoDataset): def _load( self, root: str - ) -> tuple[list[str], Optional[dict[int, dict[str, Any]]], Any]: + ) -> tuple[list[str], dict[int, dict[str, Any]] | None, Any]: """Load files, geometries, and labels. Args: @@ -419,8 +420,8 @@ class IDTReeS(NonGeoDataset): image_size: tuple[int, int], min_size: int, boxes: Tensor, - labels: Optional[Tensor], - ) -> tuple[Tensor, Optional[Tensor]]: + labels: Tensor | None, + ) -> tuple[Tensor, Tensor | None]: """Clip boxes to image size and filter boxes with sides less than ``min_size``. Args: @@ -477,7 +478,7 @@ class IDTReeS(NonGeoDataset): self, sample: dict[str, Tensor], show_titles: bool = True, - suptitle: Optional[str] = None, + suptitle: str | None = None, hsi_indices: tuple[int, int, int] = (0, 1, 2), ) -> Figure: """Plot a sample from the dataset. diff --git a/torchgeo/datasets/inria.py b/torchgeo/datasets/inria.py index b3ab0a6fd..b42a005e5 100644 --- a/torchgeo/datasets/inria.py +++ b/torchgeo/datasets/inria.py @@ -6,7 +6,8 @@ import glob import os import re -from typing import Any, Callable, Optional +from collections.abc import Callable +from typing import Any import matplotlib.pyplot as plt import numpy as np @@ -64,7 +65,7 @@ class InriaAerialImageLabeling(NonGeoDataset): self, root: str = "data", split: str = "train", - transforms: Optional[Callable[[dict[str, Any]], dict[str, Any]]] = None, + transforms: Callable[[dict[str, Any]], dict[str, Any]] | None = None, checksum: bool = False, ) -> None: """Initialize a new InriaAerialImageLabeling Dataset instance. @@ -200,7 +201,7 @@ class InriaAerialImageLabeling(NonGeoDataset): self, sample: dict[str, Tensor], show_titles: bool = True, - suptitle: Optional[str] = None, + suptitle: str | None = None, ) -> Figure: """Plot a sample from the dataset. diff --git a/torchgeo/datasets/l7irish.py b/torchgeo/datasets/l7irish.py index b2d87eade..7e12fe67d 100644 --- a/torchgeo/datasets/l7irish.py +++ b/torchgeo/datasets/l7irish.py @@ -5,8 +5,8 @@ import glob import os -from collections.abc import Iterable, Sequence -from typing import Any, Callable, Optional, Union, cast +from collections.abc import Callable, Iterable, Sequence +from typing import Any, cast import matplotlib.pyplot as plt from matplotlib.figure import Figure @@ -97,11 +97,11 @@ class L7Irish(RasterDataset): def __init__( self, - paths: Union[str, Iterable[str]] = "data", - crs: Optional[CRS] = CRS.from_epsg(3857), - res: Optional[float] = None, + paths: str | Iterable[str] = "data", + crs: CRS | None = CRS.from_epsg(3857), + res: float | None = None, bands: Sequence[str] = all_bands, - transforms: Optional[Callable[[dict[str, Any]], dict[str, Any]]] = None, + transforms: Callable[[dict[str, Any]], dict[str, Any]] | None = None, cache: bool = True, download: bool = False, checksum: bool = False, @@ -221,7 +221,7 @@ class L7Irish(RasterDataset): self, sample: dict[str, Tensor], show_titles: bool = True, - suptitle: Optional[str] = None, + suptitle: str | None = None, ) -> Figure: """Plot a sample from the dataset. diff --git a/torchgeo/datasets/l8biome.py b/torchgeo/datasets/l8biome.py index 88ef15592..f4fa80479 100644 --- a/torchgeo/datasets/l8biome.py +++ b/torchgeo/datasets/l8biome.py @@ -5,8 +5,8 @@ import glob import os -from collections.abc import Iterable, Sequence -from typing import Any, Callable, Optional, Union, cast +from collections.abc import Callable, Iterable, Sequence +from typing import Any, cast import matplotlib.pyplot as plt from matplotlib.figure import Figure @@ -96,11 +96,11 @@ class L8Biome(RasterDataset): def __init__( self, - paths: Union[str, Iterable[str]], - crs: Optional[CRS] = CRS.from_epsg(3857), - res: Optional[float] = None, + paths: str | Iterable[str], + crs: CRS | None = CRS.from_epsg(3857), + res: float | None = None, bands: Sequence[str] = all_bands, - transforms: Optional[Callable[[dict[str, Any]], dict[str, Any]]] = None, + transforms: Callable[[dict[str, Any]], dict[str, Any]] | None = None, cache: bool = True, download: bool = False, checksum: bool = False, @@ -217,7 +217,7 @@ class L8Biome(RasterDataset): self, sample: dict[str, Tensor], show_titles: bool = True, - suptitle: Optional[str] = None, + suptitle: str | None = None, ) -> Figure: """Plot a sample from the dataset. diff --git a/torchgeo/datasets/landcoverai.py b/torchgeo/datasets/landcoverai.py index 8dec50b56..d240fa892 100644 --- a/torchgeo/datasets/landcoverai.py +++ b/torchgeo/datasets/landcoverai.py @@ -6,8 +6,9 @@ import abc import glob import hashlib import os +from collections.abc import Callable from functools import lru_cache -from typing import Any, Callable, Optional, cast +from typing import Any, cast import matplotlib.pyplot as plt import numpy as np @@ -152,7 +153,7 @@ class LandCoverAIBase(Dataset[dict[str, Any]], abc.ABC): self, sample: dict[str, Tensor], show_titles: bool = True, - suptitle: Optional[str] = None, + suptitle: str | None = None, ) -> Figure: """Plot a sample from the dataset. @@ -209,9 +210,9 @@ class LandCoverAIGeo(LandCoverAIBase, RasterDataset): def __init__( self, root: str = "data", - crs: Optional[CRS] = None, - res: Optional[float] = None, - transforms: Optional[Callable[[dict[str, Any]], dict[str, Any]]] = None, + crs: CRS | None = None, + res: float | None = None, + transforms: Callable[[dict[str, Any]], dict[str, Any]] | None = None, cache: bool = True, download: bool = False, checksum: bool = False, @@ -299,7 +300,7 @@ class LandCoverAI(LandCoverAIBase, NonGeoDataset): self, root: str = "data", split: str = "train", - transforms: Optional[Callable[[dict[str, Tensor]], dict[str, Tensor]]] = None, + transforms: Callable[[dict[str, Tensor]], dict[str, Tensor]] | None = None, download: bool = False, checksum: bool = False, ) -> None: diff --git a/torchgeo/datasets/landsat.py b/torchgeo/datasets/landsat.py index b8e524905..503580cfd 100644 --- a/torchgeo/datasets/landsat.py +++ b/torchgeo/datasets/landsat.py @@ -4,8 +4,8 @@ """Landsat datasets.""" import abc -from collections.abc import Iterable, Sequence -from typing import Any, Callable, Optional, Union +from collections.abc import Callable, Iterable, Sequence +from typing import Any import matplotlib.pyplot as plt from matplotlib.figure import Figure @@ -59,11 +59,11 @@ class Landsat(RasterDataset, abc.ABC): def __init__( self, - paths: Union[str, Iterable[str]] = "data", - crs: Optional[CRS] = None, - res: Optional[float] = None, - bands: Optional[Sequence[str]] = None, - transforms: Optional[Callable[[dict[str, Any]], dict[str, Any]]] = None, + paths: str | Iterable[str] = "data", + crs: CRS | None = None, + res: float | None = None, + bands: Sequence[str] | None = None, + transforms: Callable[[dict[str, Any]], dict[str, Any]] | None = None, cache: bool = True, ) -> None: """Initialize a new Dataset instance. @@ -94,7 +94,7 @@ class Landsat(RasterDataset, abc.ABC): self, sample: dict[str, Any], show_titles: bool = True, - suptitle: Optional[str] = None, + suptitle: str | None = None, ) -> Figure: """Plot a sample from the dataset. diff --git a/torchgeo/datasets/levircd.py b/torchgeo/datasets/levircd.py index 76481fed1..b2f1891c0 100644 --- a/torchgeo/datasets/levircd.py +++ b/torchgeo/datasets/levircd.py @@ -6,7 +6,7 @@ import abc import glob import os -from typing import Callable, Optional, Union +from collections.abc import Callable import matplotlib.pyplot as plt import numpy as np @@ -29,14 +29,14 @@ class LEVIRCDBase(NonGeoDataset, abc.ABC): .. versionadded:: 0.6 """ - splits: Union[list[str], dict[str, dict[str, str]]] + splits: list[str] | dict[str, dict[str, str]] directories = ["A", "B", "label"] def __init__( self, root: str = "data", split: str = "train", - transforms: Optional[Callable[[dict[str, Tensor]], dict[str, Tensor]]] = None, + transforms: Callable[[dict[str, Tensor]], dict[str, Tensor]] | None = None, download: bool = False, checksum: bool = False, ) -> None: @@ -135,7 +135,7 @@ class LEVIRCDBase(NonGeoDataset, abc.ABC): self, sample: dict[str, Tensor], show_titles: bool = True, - suptitle: Optional[str] = None, + suptitle: str | None = None, ) -> Figure: """Plot a sample from the dataset. diff --git a/torchgeo/datasets/loveda.py b/torchgeo/datasets/loveda.py index 9c7e2aaff..cff2d1af8 100644 --- a/torchgeo/datasets/loveda.py +++ b/torchgeo/datasets/loveda.py @@ -5,7 +5,7 @@ import glob import os -from typing import Callable, Optional +from collections.abc import Callable import matplotlib.pyplot as plt import numpy as np @@ -93,7 +93,7 @@ class LoveDA(NonGeoDataset): root: str = "data", split: str = "train", scene: list[str] = ["urban", "rural"], - transforms: Optional[Callable[[dict[str, Tensor]], dict[str, Tensor]]] = None, + transforms: Callable[[dict[str, Tensor]], dict[str, Tensor]] | None = None, download: bool = False, checksum: bool = False, ) -> None: @@ -256,7 +256,7 @@ class LoveDA(NonGeoDataset): md5=self.md5 if self.checksum else None, ) - def plot(self, sample: dict[str, Tensor], suptitle: Optional[str] = None) -> Figure: + def plot(self, sample: dict[str, Tensor], suptitle: str | None = None) -> Figure: """Plot a sample from the dataset. Args: diff --git a/torchgeo/datasets/mapinwild.py b/torchgeo/datasets/mapinwild.py index 68b601bb5..849d66485 100644 --- a/torchgeo/datasets/mapinwild.py +++ b/torchgeo/datasets/mapinwild.py @@ -6,7 +6,7 @@ import os import shutil from collections import defaultdict -from typing import Callable, Optional +from collections.abc import Callable import matplotlib.pyplot as plt import numpy as np @@ -111,7 +111,7 @@ class MapInWild(NonGeoDataset): root: str = "data", modality: list[str] = ["mask", "esa_wc", "viirs", "s2_summer"], split: str = "train", - transforms: Optional[Callable[[dict[str, Tensor]], dict[str, Tensor]]] = None, + transforms: Callable[[dict[str, Tensor]], dict[str, Tensor]] | None = None, download: bool = False, checksum: bool = False, ) -> None: @@ -223,7 +223,7 @@ class MapInWild(NonGeoDataset): tensor = torch.from_numpy(array).float() return tensor - def _verify(self, url: str, md5: Optional[str] = None) -> None: + def _verify(self, url: str, md5: str | None = None) -> None: """Verify the integrity of the dataset. Args: @@ -258,7 +258,7 @@ class MapInWild(NonGeoDataset): if not url.endswith(".csv"): self._extract(url) - def _download(self, url: str, md5: Optional[str]) -> None: + def _download(self, url: str, md5: str | None) -> None: """Downloads a modality. Args: @@ -330,7 +330,7 @@ class MapInWild(NonGeoDataset): self, sample: dict[str, Tensor], show_titles: bool = True, - suptitle: Optional[str] = None, + suptitle: str | None = None, ) -> Figure: """Plot a sample from the dataset. diff --git a/torchgeo/datasets/millionaid.py b/torchgeo/datasets/millionaid.py index ed9a9c156..171ff6f3d 100644 --- a/torchgeo/datasets/millionaid.py +++ b/torchgeo/datasets/millionaid.py @@ -4,7 +4,8 @@ """Million-AID dataset.""" import glob import os -from typing import Any, Callable, Optional, cast +from collections.abc import Callable +from typing import Any, cast import matplotlib.pyplot as plt import numpy as np @@ -191,7 +192,7 @@ class MillionAID(NonGeoDataset): root: str = "data", task: str = "multi-class", split: str = "train", - transforms: Optional[Callable[[dict[str, Tensor]], dict[str, Tensor]]] = None, + transforms: Callable[[dict[str, Tensor]], dict[str, Tensor]] | None = None, checksum: bool = False, ) -> None: """Initialize a new MillionAID dataset instance. @@ -332,7 +333,7 @@ class MillionAID(NonGeoDataset): self, sample: dict[str, Tensor], show_titles: bool = True, - suptitle: Optional[str] = None, + suptitle: str | None = None, ) -> Figure: """Plot a sample from the dataset. diff --git a/torchgeo/datasets/naip.py b/torchgeo/datasets/naip.py index f70ff47a2..19f12a29f 100644 --- a/torchgeo/datasets/naip.py +++ b/torchgeo/datasets/naip.py @@ -3,7 +3,7 @@ """National Agriculture Imagery Program (NAIP) dataset.""" -from typing import Any, Optional +from typing import Any import matplotlib.pyplot as plt from matplotlib.figure import Figure @@ -52,7 +52,7 @@ class NAIP(RasterDataset): self, sample: dict[str, Any], show_titles: bool = True, - suptitle: Optional[str] = None, + suptitle: str | None = None, ) -> Figure: """Plot a sample from the dataset. diff --git a/torchgeo/datasets/nasa_marine_debris.py b/torchgeo/datasets/nasa_marine_debris.py index a1637f46e..4ebd0f6e3 100644 --- a/torchgeo/datasets/nasa_marine_debris.py +++ b/torchgeo/datasets/nasa_marine_debris.py @@ -4,7 +4,7 @@ """NASA Marine Debris dataset.""" import os -from typing import Callable, Optional +from collections.abc import Callable import matplotlib.pyplot as plt import numpy as np @@ -66,9 +66,9 @@ class NASAMarineDebris(NonGeoDataset): def __init__( self, root: str = "data", - transforms: Optional[Callable[[dict[str, Tensor]], dict[str, Tensor]]] = None, + transforms: Callable[[dict[str, Tensor]], dict[str, Tensor]] | None = None, download: bool = False, - api_key: Optional[str] = None, + api_key: str | None = None, checksum: bool = False, verbose: bool = False, ) -> None: @@ -224,7 +224,7 @@ class NASAMarineDebris(NonGeoDataset): self, sample: dict[str, Tensor], show_titles: bool = True, - suptitle: Optional[str] = None, + suptitle: str | None = None, ) -> Figure: """Plot a sample from the dataset. diff --git a/torchgeo/datasets/nccm.py b/torchgeo/datasets/nccm.py index f4728775b..48d33137d 100644 --- a/torchgeo/datasets/nccm.py +++ b/torchgeo/datasets/nccm.py @@ -3,8 +3,8 @@ """Northeastern China Crop Map Dataset.""" -from collections.abc import Iterable -from typing import Any, Callable, Optional, Union +from collections.abc import Callable, Iterable +from typing import Any import matplotlib.pyplot as plt import torch @@ -82,11 +82,11 @@ class NCCM(RasterDataset): def __init__( self, - paths: Union[str, Iterable[str]] = "data", - crs: Optional[CRS] = None, - res: Optional[float] = None, + paths: str | Iterable[str] = "data", + crs: CRS | None = None, + res: float | None = None, years: list[int] = [2019], - transforms: Optional[Callable[[dict[str, Any]], dict[str, Any]]] = None, + transforms: Callable[[dict[str, Any]], dict[str, Any]] | None = None, cache: bool = True, download: bool = False, checksum: bool = False, @@ -170,7 +170,7 @@ class NCCM(RasterDataset): self, sample: dict[str, Any], show_titles: bool = True, - suptitle: Optional[str] = None, + suptitle: str | None = None, ) -> Figure: """Plot a sample from the dataset. diff --git a/torchgeo/datasets/nlcd.py b/torchgeo/datasets/nlcd.py index 6c1243538..7cb03b55d 100644 --- a/torchgeo/datasets/nlcd.py +++ b/torchgeo/datasets/nlcd.py @@ -5,8 +5,8 @@ import glob import os -from collections.abc import Iterable -from typing import Any, Callable, Optional, Union +from collections.abc import Callable, Iterable +from typing import Any import matplotlib.pyplot as plt import torch @@ -107,12 +107,12 @@ class NLCD(RasterDataset): def __init__( self, - paths: Union[str, Iterable[str]] = "data", - crs: Optional[CRS] = None, - res: Optional[float] = None, + paths: str | Iterable[str] = "data", + crs: CRS | None = None, + res: float | None = None, years: list[int] = [2019], classes: list[int] = list(cmap.keys()), - transforms: Optional[Callable[[dict[str, Any]], dict[str, Any]]] = None, + transforms: Callable[[dict[str, Any]], dict[str, Any]] | None = None, cache: bool = True, download: bool = False, checksum: bool = False, @@ -230,7 +230,7 @@ class NLCD(RasterDataset): self, sample: dict[str, Any], show_titles: bool = True, - suptitle: Optional[str] = None, + suptitle: str | None = None, ) -> Figure: """Plot a sample from the dataset. diff --git a/torchgeo/datasets/openbuildings.py b/torchgeo/datasets/openbuildings.py index e05ea9322..988d194e8 100644 --- a/torchgeo/datasets/openbuildings.py +++ b/torchgeo/datasets/openbuildings.py @@ -7,8 +7,8 @@ import glob import json import os import sys -from collections.abc import Iterable -from typing import Any, Callable, Optional, Union, cast +from collections.abc import Callable, Iterable +from typing import Any, cast import fiona import fiona.transform @@ -206,10 +206,10 @@ class OpenBuildings(VectorDataset): def __init__( self, - paths: Union[str, Iterable[str]] = "data", - crs: Optional[CRS] = None, + paths: str | Iterable[str] = "data", + crs: CRS | None = None, res: float = 0.0001, - transforms: Optional[Callable[[dict[str, Any]], dict[str, Any]]] = None, + transforms: Callable[[dict[str, Any]], dict[str, Any]] | None = None, checksum: bool = False, ) -> None: """Initialize a new Dataset instance. @@ -413,7 +413,7 @@ class OpenBuildings(VectorDataset): self, sample: dict[str, Any], show_titles: bool = True, - suptitle: Optional[str] = None, + suptitle: str | None = None, ) -> Figure: """Plot a sample from the dataset. diff --git a/torchgeo/datasets/oscd.py b/torchgeo/datasets/oscd.py index eebe7348d..20f0a5c53 100644 --- a/torchgeo/datasets/oscd.py +++ b/torchgeo/datasets/oscd.py @@ -5,8 +5,7 @@ import glob import os -from collections.abc import Sequence -from typing import Callable, Optional, Union +from collections.abc import Callable, Sequence import matplotlib.pyplot as plt import numpy as np @@ -103,7 +102,7 @@ class OSCD(NonGeoDataset): root: str = "data", split: str = "train", bands: Sequence[str] = all_bands, - transforms: Optional[Callable[[dict[str, Tensor]], dict[str, Tensor]]] = None, + transforms: Callable[[dict[str, Tensor]], dict[str, Tensor]] | None = None, download: bool = False, checksum: bool = False, ) -> None: @@ -164,7 +163,7 @@ class OSCD(NonGeoDataset): """ return len(self.files) - def _load_files(self) -> list[dict[str, Union[str, Sequence[str]]]]: + def _load_files(self) -> list[dict[str, str | Sequence[str]]]: regions = [] labels_root = os.path.join( self.root, @@ -284,7 +283,7 @@ class OSCD(NonGeoDataset): self, sample: dict[str, Tensor], show_titles: bool = True, - suptitle: Optional[str] = None, + suptitle: str | None = None, alpha: float = 0.5, ) -> Figure: """Plot a sample from the dataset. diff --git a/torchgeo/datasets/pastis.py b/torchgeo/datasets/pastis.py index 84925a850..88d95884c 100644 --- a/torchgeo/datasets/pastis.py +++ b/torchgeo/datasets/pastis.py @@ -4,8 +4,7 @@ """PASTIS dataset.""" import os -from collections.abc import Sequence -from typing import Callable, Optional +from collections.abc import Callable, Sequence import fiona import matplotlib.pyplot as plt @@ -132,7 +131,7 @@ class PASTIS(NonGeoDataset): folds: Sequence[int] = (1, 2, 3, 4, 5), bands: str = "s2", mode: str = "semantic", - transforms: Optional[Callable[[dict[str, Tensor]], dict[str, Tensor]]] = None, + transforms: Callable[[dict[str, Tensor]], dict[str, Tensor]] | None = None, download: bool = False, checksum: bool = False, ) -> None: @@ -347,7 +346,7 @@ class PASTIS(NonGeoDataset): self, sample: dict[str, Tensor], show_titles: bool = True, - suptitle: Optional[str] = None, + suptitle: str | None = None, ) -> Figure: """Plot a sample from the dataset. diff --git a/torchgeo/datasets/patternnet.py b/torchgeo/datasets/patternnet.py index 876f14dd5..ba445d671 100644 --- a/torchgeo/datasets/patternnet.py +++ b/torchgeo/datasets/patternnet.py @@ -4,7 +4,8 @@ """PatternNet dataset.""" import os -from typing import Callable, Optional, cast +from collections.abc import Callable +from typing import cast import matplotlib.pyplot as plt from matplotlib.figure import Figure @@ -84,7 +85,7 @@ class PatternNet(NonGeoClassificationDataset): def __init__( self, root: str = "data", - transforms: Optional[Callable[[dict[str, Tensor]], dict[str, Tensor]]] = None, + transforms: Callable[[dict[str, Tensor]], dict[str, Tensor]] | None = None, download: bool = False, checksum: bool = False, ) -> None: @@ -145,7 +146,7 @@ class PatternNet(NonGeoClassificationDataset): self, sample: dict[str, Tensor], show_titles: bool = True, - suptitle: Optional[str] = None, + suptitle: str | None = None, ) -> Figure: """Plot a sample from the dataset. diff --git a/torchgeo/datasets/potsdam.py b/torchgeo/datasets/potsdam.py index 782fd7ce8..b1b5f2154 100644 --- a/torchgeo/datasets/potsdam.py +++ b/torchgeo/datasets/potsdam.py @@ -4,7 +4,7 @@ """Potsdam dataset.""" import os -from typing import Callable, Optional +from collections.abc import Callable import matplotlib.pyplot as plt import numpy as np @@ -123,7 +123,7 @@ class Potsdam2D(NonGeoDataset): self, root: str = "data", split: str = "train", - transforms: Optional[Callable[[dict[str, Tensor]], dict[str, Tensor]]] = None, + transforms: Callable[[dict[str, Tensor]], dict[str, Tensor]] | None = None, checksum: bool = False, ) -> None: """Initialize a new Potsdam dataset instance. @@ -240,7 +240,7 @@ class Potsdam2D(NonGeoDataset): self, sample: dict[str, Tensor], show_titles: bool = True, - suptitle: Optional[str] = None, + suptitle: str | None = None, alpha: float = 0.5, ) -> Figure: """Plot a sample from the dataset. diff --git a/torchgeo/datasets/prisma.py b/torchgeo/datasets/prisma.py index f6e70d4ff..442706578 100644 --- a/torchgeo/datasets/prisma.py +++ b/torchgeo/datasets/prisma.py @@ -3,7 +3,7 @@ """PRISMA datasets.""" -from typing import Any, Optional +from typing import Any import matplotlib.pyplot as plt from matplotlib.figure import Figure @@ -81,7 +81,7 @@ class PRISMA(RasterDataset): self, sample: dict[str, Any], show_titles: bool = True, - suptitle: Optional[str] = None, + suptitle: str | None = None, ) -> Figure: """Plot a sample from the dataset. diff --git a/torchgeo/datasets/reforestree.py b/torchgeo/datasets/reforestree.py index 9ee8a82d6..c345c79e7 100644 --- a/torchgeo/datasets/reforestree.py +++ b/torchgeo/datasets/reforestree.py @@ -5,7 +5,7 @@ import glob import os -from typing import Callable, Optional +from collections.abc import Callable import matplotlib.patches as patches import matplotlib.pyplot as plt @@ -69,7 +69,7 @@ class ReforesTree(NonGeoDataset): def __init__( self, root: str = "data", - transforms: Optional[Callable[[dict[str, Tensor]], dict[str, Tensor]]] = None, + transforms: Callable[[dict[str, Tensor]], dict[str, Tensor]] | None = None, download: bool = False, checksum: bool = False, ) -> None: @@ -209,7 +209,7 @@ class ReforesTree(NonGeoDataset): self, sample: dict[str, Tensor], show_titles: bool = True, - suptitle: Optional[str] = None, + suptitle: str | None = None, ) -> Figure: """Plot a sample from the dataset. diff --git a/torchgeo/datasets/resisc45.py b/torchgeo/datasets/resisc45.py index 07fdfd272..babbf880d 100644 --- a/torchgeo/datasets/resisc45.py +++ b/torchgeo/datasets/resisc45.py @@ -4,7 +4,8 @@ """RESISC45 dataset.""" import os -from typing import Callable, Optional, cast +from collections.abc import Callable +from typing import cast import matplotlib.pyplot as plt import numpy as np @@ -112,7 +113,7 @@ class RESISC45(NonGeoClassificationDataset): self, root: str = "data", split: str = "train", - transforms: Optional[Callable[[dict[str, Tensor]], dict[str, Tensor]]] = None, + transforms: Callable[[dict[str, Tensor]], dict[str, Tensor]] | None = None, download: bool = False, checksum: bool = False, ) -> None: @@ -193,7 +194,7 @@ class RESISC45(NonGeoClassificationDataset): self, sample: dict[str, Tensor], show_titles: bool = True, - suptitle: Optional[str] = None, + suptitle: str | None = None, ) -> Figure: """Plot a sample from the dataset. diff --git a/torchgeo/datasets/rwanda_field_boundary.py b/torchgeo/datasets/rwanda_field_boundary.py index 0d437134f..102067e5f 100644 --- a/torchgeo/datasets/rwanda_field_boundary.py +++ b/torchgeo/datasets/rwanda_field_boundary.py @@ -4,8 +4,7 @@ """Rwanda Field Boundary Competition dataset.""" import os -from collections.abc import Sequence -from typing import Callable, Optional +from collections.abc import Callable, Sequence import matplotlib.pyplot as plt import numpy as np @@ -91,9 +90,9 @@ class RwandaFieldBoundary(NonGeoDataset): root: str = "data", split: str = "train", bands: Sequence[str] = all_bands, - transforms: Optional[Callable[[dict[str, Tensor]], dict[str, Tensor]]] = None, + transforms: Callable[[dict[str, Tensor]], dict[str, Tensor]] | None = None, download: bool = False, - api_key: Optional[str] = None, + api_key: str | None = None, checksum: bool = False, ) -> None: """Initialize a new RwandaFieldBoundary instance. @@ -258,7 +257,7 @@ class RwandaFieldBoundary(NonGeoDataset): sample: dict[str, Tensor], show_titles: bool = True, time_step: int = 0, - suptitle: Optional[str] = None, + suptitle: str | None = None, ) -> Figure: """Plot a sample from the dataset. diff --git a/torchgeo/datasets/seasonet.py b/torchgeo/datasets/seasonet.py index b01ad3cae..5f2918f74 100644 --- a/torchgeo/datasets/seasonet.py +++ b/torchgeo/datasets/seasonet.py @@ -6,7 +6,6 @@ import os import random from collections.abc import Callable, Collection, Iterable -from typing import Optional import matplotlib.patches as mpatches import matplotlib.pyplot as plt @@ -219,7 +218,7 @@ class SeasoNet(NonGeoDataset): bands: Iterable[str] = all_bands, grids: Iterable[int] = [1, 2], concat_seasons: int = 1, - transforms: Optional[Callable[[dict[str, Tensor]], dict[str, Tensor]]] = None, + transforms: Callable[[dict[str, Tensor]], dict[str, Tensor]] | None = None, download: bool = False, checksum: bool = False, ) -> None: @@ -405,7 +404,7 @@ class SeasoNet(NonGeoDataset): sample: dict[str, Tensor], show_titles: bool = True, show_legend: bool = True, - suptitle: Optional[str] = None, + suptitle: str | None = None, ) -> Figure: """Plot a sample from the dataset. diff --git a/torchgeo/datasets/seco.py b/torchgeo/datasets/seco.py index 88a55fca5..52ffd1726 100644 --- a/torchgeo/datasets/seco.py +++ b/torchgeo/datasets/seco.py @@ -5,7 +5,7 @@ import os import random -from typing import Callable, Optional +from collections.abc import Callable import matplotlib.pyplot as plt import numpy as np @@ -79,7 +79,7 @@ class SeasonalContrastS2(NonGeoDataset): version: str = "100k", seasons: int = 1, bands: list[str] = rgb_bands, - transforms: Optional[Callable[[dict[str, Tensor]], dict[str, Tensor]]] = None, + transforms: Callable[[dict[str, Tensor]], dict[str, Tensor]] | None = None, download: bool = False, checksum: bool = False, ) -> None: @@ -231,7 +231,7 @@ class SeasonalContrastS2(NonGeoDataset): self, sample: dict[str, Tensor], show_titles: bool = True, - suptitle: Optional[str] = None, + suptitle: str | None = None, ) -> Figure: """Plot a sample from the dataset. diff --git a/torchgeo/datasets/sen12ms.py b/torchgeo/datasets/sen12ms.py index dc14fe3cc..7ff44d417 100644 --- a/torchgeo/datasets/sen12ms.py +++ b/torchgeo/datasets/sen12ms.py @@ -4,8 +4,7 @@ """SEN12MS dataset.""" import os -from collections.abc import Sequence -from typing import Callable, Optional +from collections.abc import Callable, Sequence import matplotlib.pyplot as plt import numpy as np @@ -173,7 +172,7 @@ class SEN12MS(NonGeoDataset): root: str = "data", split: str = "train", bands: Sequence[str] = BAND_SETS["all"], - transforms: Optional[Callable[[dict[str, Tensor]], dict[str, Tensor]]] = None, + transforms: Callable[[dict[str, Tensor]], dict[str, Tensor]] | None = None, checksum: bool = False, ) -> None: """Initialize a new SEN12MS dataset instance. @@ -320,7 +319,7 @@ class SEN12MS(NonGeoDataset): self, sample: dict[str, Tensor], show_titles: bool = True, - suptitle: Optional[str] = None, + suptitle: str | None = None, ) -> Figure: """Plot a sample from the dataset. diff --git a/torchgeo/datasets/sentinel.py b/torchgeo/datasets/sentinel.py index 8c4869e48..8cd7989f8 100644 --- a/torchgeo/datasets/sentinel.py +++ b/torchgeo/datasets/sentinel.py @@ -3,8 +3,8 @@ """Sentinel datasets.""" -from collections.abc import Iterable, Sequence -from typing import Any, Callable, Optional, Union +from collections.abc import Callable, Iterable, Sequence +from typing import Any import matplotlib.pyplot as plt import torch @@ -141,11 +141,11 @@ class Sentinel1(Sentinel): def __init__( self, - paths: Union[str, list[str]] = "data", - crs: Optional[CRS] = None, + paths: str | list[str] = "data", + crs: CRS | None = None, res: float = 10, bands: Sequence[str] = ["VV", "VH"], - transforms: Optional[Callable[[dict[str, Any]], dict[str, Any]]] = None, + transforms: Callable[[dict[str, Any]], dict[str, Any]] | None = None, cache: bool = True, ) -> None: """Initialize a new Dataset instance. @@ -194,7 +194,7 @@ To create a dataset containing both, use: self, sample: dict[str, Any], show_titles: bool = True, - suptitle: Optional[str] = None, + suptitle: str | None = None, ) -> Figure: """Plot a sample from the dataset. @@ -297,11 +297,11 @@ class Sentinel2(Sentinel): def __init__( self, - paths: Union[str, Iterable[str]] = "data", - crs: Optional[CRS] = None, + paths: str | Iterable[str] = "data", + crs: CRS | None = None, res: float = 10, - bands: Optional[Sequence[str]] = None, - transforms: Optional[Callable[[dict[str, Any]], dict[str, Any]]] = None, + bands: Sequence[str] | None = None, + transforms: Callable[[dict[str, Any]], dict[str, Any]] | None = None, cache: bool = True, ) -> None: """Initialize a new Dataset instance. @@ -333,7 +333,7 @@ class Sentinel2(Sentinel): self, sample: dict[str, Any], show_titles: bool = True, - suptitle: Optional[str] = None, + suptitle: str | None = None, ) -> Figure: """Plot a sample from the dataset. diff --git a/torchgeo/datasets/skippd.py b/torchgeo/datasets/skippd.py index d35897182..71ec3c4d6 100644 --- a/torchgeo/datasets/skippd.py +++ b/torchgeo/datasets/skippd.py @@ -4,7 +4,8 @@ """SKy Images and Photovoltaic Power Dataset (SKIPP'D).""" import os -from typing import Any, Callable, Optional, Union +from collections.abc import Callable +from typing import Any import matplotlib.pyplot as plt import numpy as np @@ -74,7 +75,7 @@ class SKIPPD(NonGeoDataset): root: str = "data", split: str = "trainval", task: str = "nowcast", - transforms: Optional[Callable[[dict[str, Any]], dict[str, Any]]] = None, + transforms: Callable[[dict[str, Any]], dict[str, Any]] | None = None, download: bool = False, checksum: bool = False, ) -> None: @@ -133,7 +134,7 @@ class SKIPPD(NonGeoDataset): return num_datapoints - def __getitem__(self, index: int) -> dict[str, Union[str, Tensor]]: + def __getitem__(self, index: int) -> dict[str, str | Tensor]: """Return an index within the dataset. Args: @@ -142,7 +143,7 @@ class SKIPPD(NonGeoDataset): Returns: data and label at that index """ - sample: dict[str, Union[str, Tensor]] = {"image": self._load_image(index)} + sample: dict[str, str | Tensor] = {"image": self._load_image(index)} sample.update(self._load_features(index)) if self.transforms is not None: @@ -176,7 +177,7 @@ class SKIPPD(NonGeoDataset): tensor = torch.from_numpy(arr).to(torch.float32) return tensor - def _load_features(self, index: int) -> dict[str, Union[str, Tensor]]: + def _load_features(self, index: int) -> dict[str, str | Tensor]: """Load label. Args: @@ -195,7 +196,7 @@ class SKIPPD(NonGeoDataset): path = os.path.join(self.root, f"times_{self.split}_{self.task}.npy") datestring = np.load(path, allow_pickle=True)[index].strftime(self.dateformat) - features: dict[str, Union[str, Tensor]] = { + features: dict[str, str | Tensor] = { "label": torch.tensor(label, dtype=torch.float32), "date": datestring, } @@ -241,7 +242,7 @@ class SKIPPD(NonGeoDataset): self, sample: dict[str, Any], show_titles: bool = True, - suptitle: Optional[str] = None, + suptitle: str | None = None, ) -> Figure: """Plot a sample from the dataset. diff --git a/torchgeo/datasets/so2sat.py b/torchgeo/datasets/so2sat.py index 321df4d4a..dafbb8863 100644 --- a/torchgeo/datasets/so2sat.py +++ b/torchgeo/datasets/so2sat.py @@ -4,8 +4,8 @@ """So2Sat dataset.""" import os -from collections.abc import Sequence -from typing import Callable, Optional, cast +from collections.abc import Callable, Sequence +from typing import cast import matplotlib.pyplot as plt import numpy as np @@ -196,7 +196,7 @@ class So2Sat(NonGeoDataset): version: str = "2", split: str = "train", bands: Sequence[str] = BAND_SETS["all"], - transforms: Optional[Callable[[dict[str, Tensor]], dict[str, Tensor]]] = None, + transforms: Callable[[dict[str, Tensor]], dict[str, Tensor]] | None = None, checksum: bool = False, ) -> None: """Initialize a new So2Sat dataset instance. @@ -340,7 +340,7 @@ class So2Sat(NonGeoDataset): self, sample: dict[str, Tensor], show_titles: bool = True, - suptitle: Optional[str] = None, + suptitle: str | None = None, ) -> Figure: """Plot a sample from the dataset. diff --git a/torchgeo/datasets/south_africa_crop_type.py b/torchgeo/datasets/south_africa_crop_type.py index 8165c2396..7d81a4aa8 100644 --- a/torchgeo/datasets/south_africa_crop_type.py +++ b/torchgeo/datasets/south_africa_crop_type.py @@ -5,8 +5,8 @@ import os import re -from collections.abc import Iterable -from typing import Any, Callable, Optional, Union, cast +from collections.abc import Callable, Iterable +from typing import Any, cast import matplotlib.pyplot as plt import torch @@ -98,11 +98,11 @@ class SouthAfricaCropType(RasterDataset): def __init__( self, - paths: Union[str, Iterable[str]] = "data", - crs: Optional[CRS] = None, + paths: str | Iterable[str] = "data", + crs: CRS | None = None, classes: list[int] = list(cmap.keys()), bands: list[str] = all_bands, - transforms: Optional[Callable[[dict[str, Tensor]], dict[str, Tensor]]] = None, + transforms: Callable[[dict[str, Tensor]], dict[str, Tensor]] | None = None, ) -> None: """Initialize a new South Africa Crop Type dataset instance. @@ -224,7 +224,7 @@ class SouthAfricaCropType(RasterDataset): self, sample: dict[str, Tensor], show_titles: bool = True, - suptitle: Optional[str] = None, + suptitle: str | None = None, ) -> Figure: """Plot a sample from the dataset. diff --git a/torchgeo/datasets/south_america_soybean.py b/torchgeo/datasets/south_america_soybean.py index e1d1e952c..131841d77 100644 --- a/torchgeo/datasets/south_america_soybean.py +++ b/torchgeo/datasets/south_america_soybean.py @@ -3,8 +3,8 @@ """South America Soybean Dataset.""" -from collections.abc import Iterable -from typing import Any, Callable, Optional, Union +from collections.abc import Callable, Iterable +from typing import Any import matplotlib.pyplot as plt from matplotlib.figure import Figure @@ -71,11 +71,11 @@ class SouthAmericaSoybean(RasterDataset): def __init__( self, - paths: Union[str, Iterable[str]] = "data", - crs: Optional[CRS] = None, - res: Optional[float] = None, + paths: str | Iterable[str] = "data", + crs: CRS | None = None, + res: float | None = None, years: list[int] = [2021], - transforms: Optional[Callable[[dict[str, Any]], dict[str, Any]]] = None, + transforms: Callable[[dict[str, Any]], dict[str, Any]] | None = None, cache: bool = True, download: bool = False, checksum: bool = False, @@ -133,7 +133,7 @@ class SouthAmericaSoybean(RasterDataset): self, sample: dict[str, Any], show_titles: bool = True, - suptitle: Optional[str] = None, + suptitle: str | None = None, ) -> Figure: """Plot a sample from the dataset. diff --git a/torchgeo/datasets/spacenet.py b/torchgeo/datasets/spacenet.py index c6780e197..111367dbd 100644 --- a/torchgeo/datasets/spacenet.py +++ b/torchgeo/datasets/spacenet.py @@ -9,7 +9,8 @@ import glob import math import os import re -from typing import Any, Callable, Optional +from collections.abc import Callable +from typing import Any import fiona import matplotlib.pyplot as plt @@ -81,9 +82,9 @@ class SpaceNet(NonGeoDataset, abc.ABC): root: str, image: str, collections: list[str] = [], - transforms: Optional[Callable[[dict[str, Any]], dict[str, Any]]] = None, + transforms: Callable[[dict[str, Any]], dict[str, Any]] | None = None, download: bool = False, - api_key: Optional[str] = None, + api_key: str | None = None, checksum: bool = False, ) -> None: """Initialize a new SpaceNet Dataset instance. @@ -274,7 +275,7 @@ class SpaceNet(NonGeoDataset, abc.ABC): return to_be_downloaded - def _download(self, collections: list[str], api_key: Optional[str] = None) -> None: + def _download(self, collections: list[str], api_key: str | None = None) -> None: """Download the dataset and extract it. Args: @@ -299,7 +300,7 @@ class SpaceNet(NonGeoDataset, abc.ABC): self, sample: dict[str, Tensor], show_titles: bool = True, - suptitle: Optional[str] = None, + suptitle: str | None = None, ) -> Figure: """Plot a sample from the dataset. @@ -398,9 +399,9 @@ class SpaceNet1(SpaceNet): self, root: str = "data", image: str = "rgb", - transforms: Optional[Callable[[dict[str, Any]], dict[str, Any]]] = None, + transforms: Callable[[dict[str, Any]], dict[str, Any]] | None = None, download: bool = False, - api_key: Optional[str] = None, + api_key: str | None = None, checksum: bool = False, ) -> None: """Initialize a new SpaceNet 1 Dataset instance. @@ -514,9 +515,9 @@ class SpaceNet2(SpaceNet): root: str = "data", image: str = "PS-RGB", collections: list[str] = [], - transforms: Optional[Callable[[dict[str, Any]], dict[str, Any]]] = None, + transforms: Callable[[dict[str, Any]], dict[str, Any]] | None = None, download: bool = False, - api_key: Optional[str] = None, + api_key: str | None = None, checksum: bool = False, ) -> None: """Initialize a new SpaceNet 2 Dataset instance. @@ -633,11 +634,11 @@ class SpaceNet3(SpaceNet): self, root: str = "data", image: str = "PS-RGB", - speed_mask: Optional[bool] = False, + speed_mask: bool | None = False, collections: list[str] = [], - transforms: Optional[Callable[[dict[str, Any]], dict[str, Any]]] = None, + transforms: Callable[[dict[str, Any]], dict[str, Any]] | None = None, download: bool = False, - api_key: Optional[str] = None, + api_key: str | None = None, checksum: bool = False, ) -> None: """Initialize a new SpaceNet 3 Dataset instance. @@ -733,7 +734,7 @@ class SpaceNet3(SpaceNet): self, sample: dict[str, Tensor], show_titles: bool = True, - suptitle: Optional[str] = None, + suptitle: str | None = None, ) -> Figure: """Plot a sample from the dataset. @@ -884,9 +885,9 @@ class SpaceNet4(SpaceNet): root: str = "data", image: str = "PS-RGBNIR", angles: list[str] = [], - transforms: Optional[Callable[[dict[str, Any]], dict[str, Any]]] = None, + transforms: Callable[[dict[str, Any]], dict[str, Any]] | None = None, download: bool = False, - api_key: Optional[str] = None, + api_key: str | None = None, checksum: bool = False, ) -> None: """Initialize a new SpaceNet 4 Dataset instance. @@ -1051,11 +1052,11 @@ class SpaceNet5(SpaceNet3): self, root: str = "data", image: str = "PS-RGB", - speed_mask: Optional[bool] = False, + speed_mask: bool | None = False, collections: list[str] = [], - transforms: Optional[Callable[[dict[str, Any]], dict[str, Any]]] = None, + transforms: Callable[[dict[str, Any]], dict[str, Any]] | None = None, download: bool = False, - api_key: Optional[str] = None, + api_key: str | None = None, checksum: bool = False, ) -> None: """Initialize a new SpaceNet 5 Dataset instance. @@ -1183,9 +1184,9 @@ class SpaceNet6(SpaceNet): self, root: str = "data", image: str = "PS-RGB", - transforms: Optional[Callable[[dict[str, Any]], dict[str, Any]]] = None, + transforms: Callable[[dict[str, Any]], dict[str, Any]] | None = None, download: bool = False, - api_key: Optional[str] = None, + api_key: str | None = None, ) -> None: """Initialize a new SpaceNet 6 Dataset instance. @@ -1212,7 +1213,7 @@ class SpaceNet6(SpaceNet): self.files = self._load_files(os.path.join(root, self.dataset_id)) - def __download(self, api_key: Optional[str] = None) -> None: + def __download(self, api_key: str | None = None) -> None: """Download the dataset and extract it. Args: @@ -1281,9 +1282,9 @@ class SpaceNet7(SpaceNet): self, root: str = "data", split: str = "train", - transforms: Optional[Callable[[dict[str, Any]], dict[str, Any]]] = None, + transforms: Callable[[dict[str, Any]], dict[str, Any]] | None = None, download: bool = False, - api_key: Optional[str] = None, + api_key: str | None = None, checksum: bool = False, ) -> None: """Initialize a new SpaceNet 7 Dataset instance. diff --git a/torchgeo/datasets/splits.py b/torchgeo/datasets/splits.py index 0164f5438..af02d0cc1 100644 --- a/torchgeo/datasets/splits.py +++ b/torchgeo/datasets/splits.py @@ -7,7 +7,7 @@ from collections.abc import Sequence from copy import deepcopy from itertools import accumulate from math import floor, isclose -from typing import Optional, Union, cast +from typing import cast from rtree.index import Index, Property from torch import Generator, default_generator, randint, randperm @@ -50,7 +50,7 @@ def _fractions_to_lengths(fractions: Sequence[float], total: int) -> Sequence[in def random_bbox_assignment( dataset: GeoDataset, lengths: Sequence[float], - generator: Optional[Generator] = default_generator, + generator: Generator | None = default_generator, ) -> list[GeoDataset]: """Split a GeoDataset randomly assigning its index's BoundingBoxes. @@ -104,7 +104,7 @@ def random_bbox_assignment( def random_bbox_splitting( dataset: GeoDataset, fractions: Sequence[float], - generator: Optional[Generator] = default_generator, + generator: Generator | None = default_generator, ) -> list[GeoDataset]: """Split a GeoDataset randomly splitting its index's BoundingBoxes. @@ -172,7 +172,7 @@ def random_grid_cell_assignment( dataset: GeoDataset, fractions: Sequence[float], grid_size: int = 6, - generator: Optional[Generator] = default_generator, + generator: Generator | None = default_generator, ) -> list[GeoDataset]: """Overlays a grid over a GeoDataset and randomly assigns cells to new GeoDatasets. @@ -289,7 +289,7 @@ def roi_split(dataset: GeoDataset, rois: Sequence[BoundingBox]) -> list[GeoDatas def time_series_split( - dataset: GeoDataset, lengths: Sequence[Union[float, tuple[float, float]]] + dataset: GeoDataset, lengths: Sequence[float | tuple[float, float]] ) -> list[GeoDataset]: """Split a GeoDataset on its time dimension to create non-overlapping GeoDatasets. diff --git a/torchgeo/datasets/ssl4eo.py b/torchgeo/datasets/ssl4eo.py index 7b862030e..39ddbc187 100644 --- a/torchgeo/datasets/ssl4eo.py +++ b/torchgeo/datasets/ssl4eo.py @@ -6,7 +6,8 @@ import glob import os import random -from typing import Callable, Optional, TypedDict +from collections.abc import Callable +from typing import TypedDict import matplotlib.pyplot as plt import numpy as np @@ -163,7 +164,7 @@ class SSL4EOL(NonGeoDataset): root: str = "data", split: str = "oli_sr", seasons: int = 1, - transforms: Optional[Callable[[dict[str, Tensor]], dict[str, Tensor]]] = None, + transforms: Callable[[dict[str, Tensor]], dict[str, Tensor]] | None = None, download: bool = False, checksum: bool = False, ) -> None: @@ -285,7 +286,7 @@ class SSL4EOL(NonGeoDataset): self, sample: dict[str, Tensor], show_titles: bool = True, - suptitle: Optional[str] = None, + suptitle: str | None = None, ) -> Figure: """Plot a sample from the dataset. @@ -405,7 +406,7 @@ class SSL4EOS12(NonGeoDataset): root: str = "data", split: str = "s2c", seasons: int = 1, - transforms: Optional[Callable[[dict[str, Tensor]], dict[str, Tensor]]] = None, + transforms: Callable[[dict[str, Tensor]], dict[str, Tensor]] | None = None, checksum: bool = False, ) -> None: """Initialize a new SSL4EOS12 instance. @@ -500,7 +501,7 @@ class SSL4EOS12(NonGeoDataset): self, sample: dict[str, Tensor], show_titles: bool = True, - suptitle: Optional[str] = None, + suptitle: str | None = None, ) -> Figure: """Plot a sample from the dataset. diff --git a/torchgeo/datasets/ssl4eo_benchmark.py b/torchgeo/datasets/ssl4eo_benchmark.py index d7d2444eb..d37097928 100644 --- a/torchgeo/datasets/ssl4eo_benchmark.py +++ b/torchgeo/datasets/ssl4eo_benchmark.py @@ -5,7 +5,7 @@ import glob import os -from typing import Callable, Optional +from collections.abc import Callable import matplotlib.pyplot as plt import numpy as np @@ -110,8 +110,8 @@ class SSL4EOLBenchmark(NonGeoDataset): sensor: str = "oli_sr", product: str = "cdl", split: str = "train", - classes: Optional[list[int]] = None, - transforms: Optional[Callable[[dict[str, Tensor]], dict[str, Tensor]]] = None, + classes: list[int] | None = None, + transforms: Callable[[dict[str, Tensor]], dict[str, Tensor]] | None = None, download: bool = False, checksum: bool = False, ) -> None: @@ -327,7 +327,7 @@ class SSL4EOLBenchmark(NonGeoDataset): self, sample: dict[str, Tensor], show_titles: bool = True, - suptitle: Optional[str] = None, + suptitle: str | None = None, ) -> Figure: """Plot a sample from the dataset. diff --git a/torchgeo/datasets/sustainbench_crop_yield.py b/torchgeo/datasets/sustainbench_crop_yield.py index eed76054e..03cdf06b3 100644 --- a/torchgeo/datasets/sustainbench_crop_yield.py +++ b/torchgeo/datasets/sustainbench_crop_yield.py @@ -4,7 +4,8 @@ """SustainBench Crop Yield dataset.""" import os -from typing import Any, Callable, Optional +from collections.abc import Callable +from typing import Any import matplotlib.pyplot as plt import numpy as np @@ -60,7 +61,7 @@ class SustainBenchCropYield(NonGeoDataset): root: str = "data", split: str = "train", countries: list[str] = ["usa"], - transforms: Optional[Callable[[dict[str, Any]], dict[str, Any]]] = None, + transforms: Callable[[dict[str, Any]], dict[str, Any]] | None = None, download: bool = False, checksum: bool = False, ) -> None: @@ -195,7 +196,7 @@ class SustainBenchCropYield(NonGeoDataset): sample: dict[str, Any], band_idx: int = 0, show_titles: bool = True, - suptitle: Optional[str] = None, + suptitle: str | None = None, ) -> Figure: """Plot a sample from the dataset. diff --git a/torchgeo/datasets/ucmerced.py b/torchgeo/datasets/ucmerced.py index 05ac17599..b8be66e45 100644 --- a/torchgeo/datasets/ucmerced.py +++ b/torchgeo/datasets/ucmerced.py @@ -3,7 +3,8 @@ """UC Merced dataset.""" import os -from typing import Callable, Optional, cast +from collections.abc import Callable +from typing import cast import matplotlib.pyplot as plt import numpy as np @@ -85,7 +86,7 @@ class UCMerced(NonGeoClassificationDataset): self, root: str = "data", split: str = "train", - transforms: Optional[Callable[[dict[str, Tensor]], dict[str, Tensor]]] = None, + transforms: Callable[[dict[str, Tensor]], dict[str, Tensor]] | None = None, download: bool = False, checksum: bool = False, ) -> None: @@ -190,7 +191,7 @@ class UCMerced(NonGeoClassificationDataset): self, sample: dict[str, Tensor], show_titles: bool = True, - suptitle: Optional[str] = None, + suptitle: str | None = None, ) -> Figure: """Plot a sample from the dataset. diff --git a/torchgeo/datasets/usavars.py b/torchgeo/datasets/usavars.py index a3156ab73..5b27c2e86 100644 --- a/torchgeo/datasets/usavars.py +++ b/torchgeo/datasets/usavars.py @@ -5,8 +5,7 @@ import glob import os -from collections.abc import Sequence -from typing import Callable, Optional +from collections.abc import Callable, Sequence import matplotlib.pyplot as plt import numpy as np @@ -89,7 +88,7 @@ class USAVars(NonGeoDataset): root: str = "data", split: str = "train", labels: Sequence[str] = ALL_LABELS, - transforms: Optional[Callable[[dict[str, Tensor]], dict[str, Tensor]]] = None, + transforms: Callable[[dict[str, Tensor]], dict[str, Tensor]] | None = None, download: bool = False, checksum: bool = False, ) -> None: @@ -230,7 +229,7 @@ class USAVars(NonGeoDataset): self, sample: dict[str, Tensor], show_labels: bool = True, - suptitle: Optional[str] = None, + suptitle: str | None = None, ) -> Figure: """Plot a sample from the dataset. diff --git a/torchgeo/datasets/vaihingen.py b/torchgeo/datasets/vaihingen.py index 59ecfcda6..913dad794 100644 --- a/torchgeo/datasets/vaihingen.py +++ b/torchgeo/datasets/vaihingen.py @@ -4,7 +4,7 @@ """Vaihingen dataset.""" import os -from typing import Callable, Optional +from collections.abc import Callable import matplotlib.pyplot as plt import numpy as np @@ -122,7 +122,7 @@ class Vaihingen2D(NonGeoDataset): self, root: str = "data", split: str = "train", - transforms: Optional[Callable[[dict[str, Tensor]], dict[str, Tensor]]] = None, + transforms: Callable[[dict[str, Tensor]], dict[str, Tensor]] | None = None, checksum: bool = False, ) -> None: """Initialize a new Vaihingen2D dataset instance. @@ -241,7 +241,7 @@ class Vaihingen2D(NonGeoDataset): self, sample: dict[str, Tensor], show_titles: bool = True, - suptitle: Optional[str] = None, + suptitle: str | None = None, alpha: float = 0.5, ) -> Figure: """Plot a sample from the dataset. diff --git a/torchgeo/datasets/vhr10.py b/torchgeo/datasets/vhr10.py index 43756df71..93eabdf72 100644 --- a/torchgeo/datasets/vhr10.py +++ b/torchgeo/datasets/vhr10.py @@ -4,7 +4,8 @@ """NWPU VHR-10 dataset.""" import os -from typing import Any, Callable, Optional +from collections.abc import Callable +from typing import Any import matplotlib.pyplot as plt import numpy as np @@ -183,7 +184,7 @@ class VHR10(NonGeoDataset): self, root: str = "data", split: str = "positive", - transforms: Optional[Callable[[dict[str, Any]], dict[str, Any]]] = None, + transforms: Callable[[dict[str, Any]], dict[str, Any]] | None = None, download: bool = False, checksum: bool = False, ) -> None: @@ -364,8 +365,8 @@ class VHR10(NonGeoDataset): self, sample: dict[str, Tensor], show_titles: bool = True, - suptitle: Optional[str] = None, - show_feats: Optional[str] = "both", + suptitle: str | None = None, + show_feats: str | None = "both", box_alpha: float = 0.7, mask_alpha: float = 0.7, ) -> Figure: diff --git a/torchgeo/datasets/western_usa_live_fuel_moisture.py b/torchgeo/datasets/western_usa_live_fuel_moisture.py index d782cd2b5..ec41088d3 100644 --- a/torchgeo/datasets/western_usa_live_fuel_moisture.py +++ b/torchgeo/datasets/western_usa_live_fuel_moisture.py @@ -6,7 +6,8 @@ import glob import json import os -from typing import Any, Callable, Optional +from collections.abc import Callable +from typing import Any import pandas as pd import torch @@ -204,9 +205,9 @@ class WesternUSALiveFuelMoisture(NonGeoDataset): self, root: str = "data", input_features: list[str] = all_variable_names, - transforms: Optional[Callable[[dict[str, Any]], dict[str, Any]]] = None, + transforms: Callable[[dict[str, Any]], dict[str, Any]] | None = None, download: bool = False, - api_key: Optional[str] = None, + api_key: str | None = None, checksum: bool = False, ) -> None: """Initialize a new Western USA Live Fuel Moisture Dataset. @@ -329,7 +330,7 @@ class WesternUSALiveFuelMoisture(NonGeoDataset): pathname = os.path.join(self.root, self.collection_id) + ".tar.gz" extract_archive(pathname, self.root) - def _download(self, api_key: Optional[str] = None) -> None: + def _download(self, api_key: str | None = None) -> None: """Download the dataset and extract it. Args: diff --git a/torchgeo/datasets/xview.py b/torchgeo/datasets/xview.py index 82bacc2be..7df948af3 100644 --- a/torchgeo/datasets/xview.py +++ b/torchgeo/datasets/xview.py @@ -5,7 +5,7 @@ import glob import os -from typing import Callable, Optional +from collections.abc import Callable import matplotlib.pyplot as plt import numpy as np @@ -72,7 +72,7 @@ class XView2(NonGeoDataset): self, root: str = "data", split: str = "train", - transforms: Optional[Callable[[dict[str, Tensor]], dict[str, Tensor]]] = None, + transforms: Callable[[dict[str, Tensor]], dict[str, Tensor]] | None = None, checksum: bool = False, ) -> None: """Initialize a new xView2 dataset instance. @@ -225,7 +225,7 @@ class XView2(NonGeoDataset): self, sample: dict[str, Tensor], show_titles: bool = True, - suptitle: Optional[str] = None, + suptitle: str | None = None, alpha: float = 0.5, ) -> Figure: """Plot a sample from the dataset. diff --git a/torchgeo/datasets/zuericrop.py b/torchgeo/datasets/zuericrop.py index f0fb1410f..d1f502973 100644 --- a/torchgeo/datasets/zuericrop.py +++ b/torchgeo/datasets/zuericrop.py @@ -4,8 +4,7 @@ """ZueriCrop dataset.""" import os -from collections.abc import Sequence -from typing import Callable, Optional +from collections.abc import Callable, Sequence import matplotlib.pyplot as plt import torch @@ -71,7 +70,7 @@ class ZueriCrop(NonGeoDataset): self, root: str = "data", bands: Sequence[str] = band_names, - transforms: Optional[Callable[[dict[str, Tensor]], dict[str, Tensor]]] = None, + transforms: Callable[[dict[str, Tensor]], dict[str, Tensor]] | None = None, download: bool = False, checksum: bool = False, ) -> None: @@ -264,7 +263,7 @@ class ZueriCrop(NonGeoDataset): sample: dict[str, Tensor], time_step: int = 0, show_titles: bool = True, - suptitle: Optional[str] = None, + suptitle: str | None = None, ) -> Figure: """Plot a sample from the dataset. diff --git a/torchgeo/models/api.py b/torchgeo/models/api.py index f2ab4a5d9..50c138a81 100644 --- a/torchgeo/models/api.py +++ b/torchgeo/models/api.py @@ -10,7 +10,8 @@ See the following references for design details: * https://github.com/pytorch/vision/blob/main/torchvision/models/_api.py """ # noqa: E501 -from typing import Any, Callable, Union +from collections.abc import Callable +from typing import Any import torch.nn as nn from torchvision.models._api import WeightsEnum @@ -55,7 +56,7 @@ def get_model(name: str, *args: Any, **kwargs: Any) -> nn.Module: return model -def get_model_weights(name: Union[Callable[..., nn.Module], str]) -> WeightsEnum: +def get_model_weights(name: Callable[..., nn.Module] | str) -> WeightsEnum: """Get the weights enum class associated with a given model. .. versionadded:: 0.4 diff --git a/torchgeo/models/dofa.py b/torchgeo/models/dofa.py index 0e12f27ac..bb6c04377 100644 --- a/torchgeo/models/dofa.py +++ b/torchgeo/models/dofa.py @@ -4,7 +4,7 @@ """Dynamic One-For-All (DOFA) models.""" from functools import partial -from typing import Any, Optional +from typing import Any import kornia.augmentation as K import torch @@ -441,7 +441,7 @@ def dofa_small_patch16_224(**kwargs: Any) -> DOFA: def dofa_base_patch16_224( - weights: Optional[DOFABase16_Weights] = None, **kwargs: Any + weights: DOFABase16_Weights | None = None, **kwargs: Any ) -> DOFA: """Dynamic One-For-All (DOFA) base patch size 16 model. @@ -477,7 +477,7 @@ def dofa_base_patch16_224( def dofa_large_patch16_224( - weights: Optional[DOFALarge16_Weights] = None, **kwargs: Any + weights: DOFALarge16_Weights | None = None, **kwargs: Any ) -> DOFA: """Dynamic One-For-All (DOFA) large patch size 16 model. diff --git a/torchgeo/models/fcsiam.py b/torchgeo/models/fcsiam.py index 4d95a0edd..be23e565b 100644 --- a/torchgeo/models/fcsiam.py +++ b/torchgeo/models/fcsiam.py @@ -3,8 +3,8 @@ """Fully convolutional change detection (FCCD) implementations.""" -from collections.abc import Sequence -from typing import Any, Callable, Optional, Union +from collections.abc import Callable, Sequence +from typing import Any import segmentation_models_pytorch as smp import torch @@ -25,13 +25,13 @@ class FCSiamConc(SegmentationModel): # type: ignore[misc] self, encoder_name: str = "resnet34", encoder_depth: int = 5, - encoder_weights: Optional[str] = "imagenet", + encoder_weights: str | None = "imagenet", decoder_use_batchnorm: bool = True, decoder_channels: Sequence[int] = (256, 128, 64, 32, 16), - decoder_attention_type: Optional[str] = None, + decoder_attention_type: str | None = None, in_channels: int = 3, classes: int = 1, - activation: Optional[Union[str, Callable[[Tensor], Tensor]]] = None, + activation: str | Callable[[Tensor], Tensor] | None = None, ): """Initialize a new FCSiamConc model. diff --git a/torchgeo/models/rcf.py b/torchgeo/models/rcf.py index 59f42223c..477624df0 100644 --- a/torchgeo/models/rcf.py +++ b/torchgeo/models/rcf.py @@ -3,8 +3,6 @@ """Implementation of a random convolutional feature projection model.""" -from typing import Optional - import numpy as np import torch import torch.nn.functional as F @@ -43,9 +41,9 @@ class RCF(Module): features: int = 16, kernel_size: int = 3, bias: float = -1.0, - seed: Optional[int] = None, + seed: int | None = None, mode: str = "gaussian", - dataset: Optional[NonGeoDataset] = None, + dataset: NonGeoDataset | None = None, ) -> None: """Initializes the RCF model. diff --git a/torchgeo/models/resnet.py b/torchgeo/models/resnet.py index 70c058711..e78c9394c 100644 --- a/torchgeo/models/resnet.py +++ b/torchgeo/models/resnet.py @@ -3,7 +3,7 @@ """Pre-trained ResNet models.""" -from typing import Any, Optional +from typing import Any import kornia.augmentation as K import timm @@ -463,7 +463,7 @@ class ResNet50_Weights(WeightsEnum): # type: ignore[misc] def resnet18( - weights: Optional[ResNet18_Weights] = None, *args: Any, **kwargs: Any + weights: ResNet18_Weights | None = None, *args: Any, **kwargs: Any ) -> ResNet: """ResNet-18 model. @@ -497,7 +497,7 @@ def resnet18( def resnet50( - weights: Optional[ResNet50_Weights] = None, *args: Any, **kwargs: Any + weights: ResNet50_Weights | None = None, *args: Any, **kwargs: Any ) -> ResNet: """ResNet-50 model. diff --git a/torchgeo/models/swin.py b/torchgeo/models/swin.py index 66dd0ebc0..b5c38acce 100644 --- a/torchgeo/models/swin.py +++ b/torchgeo/models/swin.py @@ -3,7 +3,7 @@ """Pre-trained Swin v2 Transformer models.""" -from typing import Any, Optional +from typing import Any import kornia.augmentation as K import torch @@ -136,7 +136,7 @@ class Swin_V2_B_Weights(WeightsEnum): # type: ignore[misc] def swin_v2_b( - weights: Optional[Swin_V2_B_Weights] = None, *args: Any, **kwargs: Any + weights: Swin_V2_B_Weights | None = None, *args: Any, **kwargs: Any ) -> SwinTransformer: """Swin Transformer v2 base model. diff --git a/torchgeo/models/vit.py b/torchgeo/models/vit.py index 85c2e4288..2c7e37fc2 100644 --- a/torchgeo/models/vit.py +++ b/torchgeo/models/vit.py @@ -3,7 +3,7 @@ """Pre-trained Vision Transformer models.""" -from typing import Any, Optional +from typing import Any import kornia.augmentation as K import timm @@ -205,7 +205,7 @@ class ViTSmall16_Weights(WeightsEnum): # type: ignore[misc] def vit_small_patch16_224( - weights: Optional[ViTSmall16_Weights] = None, *args: Any, **kwargs: Any + weights: ViTSmall16_Weights | None = None, *args: Any, **kwargs: Any ) -> VisionTransformer: """Vision Transform (ViT) small patch size 16 model. diff --git a/torchgeo/samplers/batch.py b/torchgeo/samplers/batch.py index ef9bb05f1..22726f74b 100644 --- a/torchgeo/samplers/batch.py +++ b/torchgeo/samplers/batch.py @@ -5,7 +5,6 @@ import abc from collections.abc import Iterator -from typing import Optional, Union import torch from rtree.index import Index, Property @@ -25,7 +24,7 @@ class BatchGeoSampler(Sampler[list[BoundingBox]], abc.ABC): longitude, height, width, projection, coordinate system, and time. """ - def __init__(self, dataset: GeoDataset, roi: Optional[BoundingBox] = None) -> None: + def __init__(self, dataset: GeoDataset, roi: BoundingBox | None = None) -> None: """Initialize a new Sampler instance. Args: @@ -66,10 +65,10 @@ class RandomBatchGeoSampler(BatchGeoSampler): def __init__( self, dataset: GeoDataset, - size: Union[tuple[float, float], float], + size: tuple[float, float] | float, batch_size: int, - length: Optional[int] = None, - roi: Optional[BoundingBox] = None, + length: int | None = None, + roi: BoundingBox | None = None, units: Units = Units.PIXELS, ) -> None: """Initialize a new Sampler instance. diff --git a/torchgeo/samplers/single.py b/torchgeo/samplers/single.py index 360798b58..094142cb6 100644 --- a/torchgeo/samplers/single.py +++ b/torchgeo/samplers/single.py @@ -4,8 +4,7 @@ """TorchGeo samplers.""" import abc -from collections.abc import Iterable, Iterator -from typing import Callable, Optional, Union +from collections.abc import Callable, Iterable, Iterator import torch from rtree.index import Index, Property @@ -25,7 +24,7 @@ class GeoSampler(Sampler[BoundingBox], abc.ABC): longitude, height, width, projection, coordinate system, and time. """ - def __init__(self, dataset: GeoDataset, roi: Optional[BoundingBox] = None) -> None: + def __init__(self, dataset: GeoDataset, roi: BoundingBox | None = None) -> None: """Initialize a new Sampler instance. Args: @@ -69,9 +68,9 @@ class RandomGeoSampler(GeoSampler): def __init__( self, dataset: GeoDataset, - size: Union[tuple[float, float], float], - length: Optional[int] = None, - roi: Optional[BoundingBox] = None, + size: tuple[float, float] | float, + length: int | None = None, + roi: BoundingBox | None = None, units: Units = Units.PIXELS, ) -> None: """Initialize a new Sampler instance. @@ -174,9 +173,9 @@ class GridGeoSampler(GeoSampler): def __init__( self, dataset: GeoDataset, - size: Union[tuple[float, float], float], - stride: Union[tuple[float, float], float], - roi: Optional[BoundingBox] = None, + size: tuple[float, float] | float, + stride: tuple[float, float] | float, + roi: BoundingBox | None = None, units: Units = Units.PIXELS, ) -> None: """Initialize a new Sampler instance. @@ -271,10 +270,7 @@ class PreChippedGeoSampler(GeoSampler): """ def __init__( - self, - dataset: GeoDataset, - roi: Optional[BoundingBox] = None, - shuffle: bool = False, + self, dataset: GeoDataset, roi: BoundingBox | None = None, shuffle: bool = False ) -> None: """Initialize a new Sampler instance. diff --git a/torchgeo/samplers/utils.py b/torchgeo/samplers/utils.py index 329f3677e..53061f579 100644 --- a/torchgeo/samplers/utils.py +++ b/torchgeo/samplers/utils.py @@ -4,7 +4,7 @@ """Common sampler utilities.""" import math -from typing import Optional, Union, overload +from typing import overload import torch @@ -12,14 +12,14 @@ from ..datasets import BoundingBox @overload -def _to_tuple(value: Union[tuple[int, int], int]) -> tuple[int, int]: ... +def _to_tuple(value: tuple[int, int] | int) -> tuple[int, int]: ... @overload -def _to_tuple(value: Union[tuple[float, float], float]) -> tuple[float, float]: ... +def _to_tuple(value: tuple[float, float] | float) -> tuple[float, float]: ... -def _to_tuple(value: Union[tuple[float, float], float]) -> tuple[float, float]: +def _to_tuple(value: tuple[float, float] | float) -> tuple[float, float]: """Convert value to a tuple if it is not already a tuple. Args: @@ -35,7 +35,7 @@ def _to_tuple(value: Union[tuple[float, float], float]) -> tuple[float, float]: def get_random_bounding_box( - bounds: BoundingBox, size: Union[tuple[float, float], float], res: float + bounds: BoundingBox, size: tuple[float, float] | float, res: float ) -> BoundingBox: """Returns a random bounding box within a given bounding box. @@ -80,7 +80,7 @@ def get_random_bounding_box( def tile_to_chips( bounds: BoundingBox, size: tuple[float, float], - stride: Optional[tuple[float, float]] = None, + stride: tuple[float, float] | None = None, ) -> tuple[int, int]: r"""Compute number of :term:`chips ` that can be sampled from a :term:`tile`. diff --git a/torchgeo/trainers/base.py b/torchgeo/trainers/base.py index 878034a77..338961e3c 100644 --- a/torchgeo/trainers/base.py +++ b/torchgeo/trainers/base.py @@ -5,7 +5,7 @@ from abc import ABC, abstractmethod from collections.abc import Sequence -from typing import Any, Optional, Union +from typing import Any import lightning from lightning.pytorch import LightningModule @@ -28,7 +28,7 @@ class BaseTask(LightningModule, ABC): #: Whether the goal is to minimize or maximize the performance metric to monitor. mode = "min" - def __init__(self, ignore: Optional[Union[Sequence[str], str]] = None) -> None: + def __init__(self, ignore: Sequence[str] | str | None = None) -> None: """Initialize a new BaseTask instance. Args: diff --git a/torchgeo/trainers/byol.py b/torchgeo/trainers/byol.py index 37258fb63..a46ad2dce 100644 --- a/torchgeo/trainers/byol.py +++ b/torchgeo/trainers/byol.py @@ -4,7 +4,7 @@ """BYOL trainer for self-supervised learning (SSL).""" import os -from typing import Any, Optional, Union +from typing import Any import timm import torch @@ -148,8 +148,8 @@ class BackboneWrapper(nn.Module): self.hidden_size = hidden_size self.layer = layer - self._projector: Optional[nn.Module] = None - self._projector_dim: Optional[int] = None + self._projector: nn.Module | None = None + self._projector_dim: int | None = None self._encoded = torch.empty(0) self._register_hook() @@ -223,7 +223,7 @@ class BYOL(nn.Module): in_channels: int = 4, projection_size: int = 256, hidden_size: int = 4096, - augment_fn: Optional[nn.Module] = None, + augment_fn: nn.Module | None = None, beta: float = 0.99, **kwargs: Any, ) -> None: @@ -297,7 +297,7 @@ class BYOLTask(BaseTask): def __init__( self, model: str = "resnet50", - weights: Optional[Union[WeightsEnum, str, bool]] = None, + weights: WeightsEnum | str | bool | None = None, in_channels: int = 3, lr: float = 1e-3, patience: int = 10, diff --git a/torchgeo/trainers/classification.py b/torchgeo/trainers/classification.py index 3aa06be84..9f6c828b6 100644 --- a/torchgeo/trainers/classification.py +++ b/torchgeo/trainers/classification.py @@ -4,7 +4,7 @@ """Trainers for image classification.""" import os -from typing import Any, Optional, Union +from typing import Any import matplotlib.pyplot as plt import timm @@ -35,11 +35,11 @@ class ClassificationTask(BaseTask): def __init__( self, model: str = "resnet50", - weights: Optional[Union[WeightsEnum, str, bool]] = None, + weights: WeightsEnum | str | bool | None = None, in_channels: int = 3, num_classes: int = 1000, loss: str = "ce", - class_weights: Optional[Tensor] = None, + class_weights: Tensor | None = None, lr: float = 1e-3, patience: int = 10, freeze_backbone: bool = False, @@ -220,7 +220,7 @@ class ClassificationTask(BaseTask): batch[key] = batch[key].cpu() sample = unbind_samples(batch)[0] - fig: Optional[Figure] = None + fig: Figure | None = None try: fig = datamodule.plot(sample) except RGBBandsMissingError: @@ -364,7 +364,7 @@ class MultiLabelClassificationTask(ClassificationTask): batch[key] = batch[key].cpu() sample = unbind_samples(batch)[0] - fig: Optional[Figure] = None + fig: Figure | None = None try: fig = datamodule.plot(sample) except RGBBandsMissingError: diff --git a/torchgeo/trainers/detection.py b/torchgeo/trainers/detection.py index ee44405dd..c74fd7156 100644 --- a/torchgeo/trainers/detection.py +++ b/torchgeo/trainers/detection.py @@ -4,7 +4,7 @@ """Trainers for object detection.""" from functools import partial -from typing import Any, Optional +from typing import Any import matplotlib.pyplot as plt import torch @@ -60,7 +60,7 @@ class ObjectDetectionTask(BaseTask): self, model: str = "faster-rcnn", backbone: str = "resnet50", - weights: Optional[bool] = None, + weights: bool | None = None, in_channels: int = 3, num_classes: int = 1000, trainable_layers: int = 3, @@ -109,7 +109,7 @@ class ObjectDetectionTask(BaseTask): """ backbone: str = self.hparams["backbone"] model: str = self.hparams["model"] - weights: Optional[bool] = self.hparams["weights"] + weights: bool | None = self.hparams["weights"] num_classes: int = self.hparams["num_classes"] freeze_backbone: bool = self.hparams["freeze_backbone"] @@ -289,7 +289,7 @@ class ObjectDetectionTask(BaseTask): sample["image"] *= 255 sample["image"] = sample["image"].to(torch.uint8) - fig: Optional[Figure] = None + fig: Figure | None = None try: fig = datamodule.plot(sample) except RGBBandsMissingError: diff --git a/torchgeo/trainers/moco.py b/torchgeo/trainers/moco.py index 8b1e59cf3..e06aecdf7 100644 --- a/torchgeo/trainers/moco.py +++ b/torchgeo/trainers/moco.py @@ -6,7 +6,7 @@ import os import warnings from collections.abc import Sequence -from typing import Any, Optional, Union +from typing import Any import kornia.augmentation as K import lightning @@ -141,7 +141,7 @@ class MoCoTask(BaseTask): def __init__( self, model: str = "resnet50", - weights: Optional[Union[WeightsEnum, str, bool]] = None, + weights: WeightsEnum | str | bool | None = None, in_channels: int = 3, version: int = 3, layers: int = 3, @@ -156,9 +156,9 @@ class MoCoTask(BaseTask): moco_momentum: float = 0.99, gather_distributed: bool = False, size: int = 224, - grayscale_weights: Optional[Tensor] = None, - augmentation1: Optional[nn.Module] = None, - augmentation2: Optional[nn.Module] = None, + grayscale_weights: Tensor | None = None, + augmentation1: nn.Module | None = None, + augmentation2: nn.Module | None = None, ) -> None: """Initialize a new MoCoTask instance. diff --git a/torchgeo/trainers/regression.py b/torchgeo/trainers/regression.py index 7ab19d215..4fb3a9209 100644 --- a/torchgeo/trainers/regression.py +++ b/torchgeo/trainers/regression.py @@ -4,7 +4,7 @@ """Trainers for regression.""" import os -from typing import Any, Optional, Union +from typing import Any import matplotlib.pyplot as plt import segmentation_models_pytorch as smp @@ -31,7 +31,7 @@ class RegressionTask(BaseTask): self, model: str = "resnet50", backbone: str = "resnet50", - weights: Optional[Union[WeightsEnum, str, bool]] = None, + weights: WeightsEnum | str | bool | None = None, in_channels: int = 3, num_outputs: int = 1, num_filters: int = 3, @@ -211,7 +211,7 @@ class RegressionTask(BaseTask): batch[key] = batch[key].cpu() sample = unbind_samples(batch)[0] - fig: Optional[Figure] = None + fig: Figure | None = None try: fig = datamodule.plot(sample) except RGBBandsMissingError: diff --git a/torchgeo/trainers/segmentation.py b/torchgeo/trainers/segmentation.py index 1249f3f28..502826db3 100644 --- a/torchgeo/trainers/segmentation.py +++ b/torchgeo/trainers/segmentation.py @@ -4,7 +4,7 @@ """Trainers for semantic segmentation.""" import os -from typing import Any, Optional, Union +from typing import Any import matplotlib.pyplot as plt import segmentation_models_pytorch as smp @@ -28,13 +28,13 @@ class SemanticSegmentationTask(BaseTask): self, model: str = "unet", backbone: str = "resnet50", - weights: Optional[Union[WeightsEnum, str, bool]] = None, + weights: WeightsEnum | str | bool | None = None, in_channels: int = 3, num_classes: int = 1000, num_filters: int = 3, loss: str = "ce", - class_weights: Optional[Tensor] = None, - ignore_index: Optional[int] = None, + class_weights: Tensor | None = None, + ignore_index: int | None = None, lr: float = 1e-3, patience: int = 10, freeze_backbone: bool = False, @@ -194,7 +194,7 @@ class SemanticSegmentationTask(BaseTask): for balanced performance assessment across imbalanced classes. """ num_classes: int = self.hparams["num_classes"] - ignore_index: Optional[int] = self.hparams["ignore_index"] + ignore_index: int | None = self.hparams["ignore_index"] metrics = MetricCollection( [ MulticlassAccuracy( @@ -268,7 +268,7 @@ class SemanticSegmentationTask(BaseTask): batch[key] = batch[key].cpu() sample = unbind_samples(batch)[0] - fig: Optional[Figure] = None + fig: Figure | None = None try: fig = datamodule.plot(sample) except RGBBandsMissingError: diff --git a/torchgeo/trainers/simclr.py b/torchgeo/trainers/simclr.py index b753dabcb..44e517731 100644 --- a/torchgeo/trainers/simclr.py +++ b/torchgeo/trainers/simclr.py @@ -5,7 +5,7 @@ import os import warnings -from typing import Any, Optional, Union +from typing import Any import kornia.augmentation as K import lightning @@ -73,20 +73,20 @@ class SimCLRTask(BaseTask): def __init__( self, model: str = "resnet50", - weights: Optional[Union[WeightsEnum, str, bool]] = None, + weights: WeightsEnum | str | bool | None = None, in_channels: int = 3, version: int = 2, layers: int = 3, - hidden_dim: Optional[int] = None, - output_dim: Optional[int] = None, + hidden_dim: int | None = None, + output_dim: int | None = None, lr: float = 4.8, weight_decay: float = 1e-4, temperature: float = 0.07, memory_bank_size: int = 64000, gather_distributed: bool = False, size: int = 224, - grayscale_weights: Optional[Tensor] = None, - augmentations: Optional[nn.Module] = None, + grayscale_weights: Tensor | None = None, + augmentations: nn.Module | None = None, ) -> None: """Initialize a new SimCLRTask instance. diff --git a/torchgeo/trainers/utils.py b/torchgeo/trainers/utils.py index b5cd8f1e9..7332cd1d1 100644 --- a/torchgeo/trainers/utils.py +++ b/torchgeo/trainers/utils.py @@ -5,7 +5,7 @@ import warnings from collections import OrderedDict -from typing import Optional, Union, cast +from typing import cast import torch import torch.nn as nn @@ -127,8 +127,8 @@ def reinit_initial_conv_layer( layer: Conv2d, new_in_channels: int, keep_rgb_weights: bool, - new_stride: Optional[Union[int, tuple[int, int]]] = None, - new_padding: Optional[Union[str, Union[int, tuple[int, int]]]] = None, + new_stride: int | tuple[int, int] | None = None, + new_padding: str | int | tuple[int, int] | None = None, ) -> Conv2d: """Clones a Conv2d layer while optionally retaining some of the original weights. diff --git a/torchgeo/transforms/color.py b/torchgeo/transforms/color.py index 5459fc2f8..efc7055fe 100644 --- a/torchgeo/transforms/color.py +++ b/torchgeo/transforms/color.py @@ -3,8 +3,6 @@ """TorchGeo color transforms.""" -from typing import Optional - from kornia.augmentation import IntensityAugmentationBase2D from torch import Tensor @@ -57,7 +55,7 @@ class RandomGrayscale(IntensityAugmentationBase2D): input: Tensor, params: dict[str, Tensor], flags: dict[str, Tensor], - transform: Optional[Tensor] = None, + transform: Tensor | None = None, ) -> Tensor: """Apply the transform. diff --git a/torchgeo/transforms/indices.py b/torchgeo/transforms/indices.py index eb72afff4..3f52fd796 100644 --- a/torchgeo/transforms/indices.py +++ b/torchgeo/transforms/indices.py @@ -8,8 +8,6 @@ For more information about indices see the following references: - https://github.com/awesome-spectral-indices/awesome-spectral-indices """ -from typing import Optional - import torch from kornia.augmentation import IntensityAugmentationBase2D from torch import Tensor @@ -44,7 +42,7 @@ class AppendNormalizedDifferenceIndex(IntensityAugmentationBase2D): input: Tensor, params: dict[str, Tensor], flags: dict[str, int], - transform: Optional[Tensor] = None, + transform: Tensor | None = None, ) -> Tensor: """Apply the transform. @@ -319,7 +317,7 @@ class AppendTriBandNormalizedDifferenceIndex(IntensityAugmentationBase2D): input: Tensor, params: dict[str, Tensor], flags: dict[str, int], - transform: Optional[Tensor] = None, + transform: Tensor | None = None, ) -> Tensor: """Apply the transform. diff --git a/torchgeo/transforms/transforms.py b/torchgeo/transforms/transforms.py index 19239c26b..ca95461f7 100644 --- a/torchgeo/transforms/transforms.py +++ b/torchgeo/transforms/transforms.py @@ -3,7 +3,7 @@ """TorchGeo transforms.""" -from typing import Any, Optional, Union +from typing import Any import kornia.augmentation as K import torch @@ -25,7 +25,7 @@ class AugmentationSequential(Module): def __init__( self, - *args: Union[K.base._AugmentationBase, K.ImageSequential, Lambda], + *args: K.base._AugmentationBase | K.ImageSequential | Lambda, data_keys: list[str], **kwargs: Any, ) -> None: @@ -84,7 +84,7 @@ class AugmentationSequential(Module): batch["masks"] = rearrange(batch["masks"], "c h w -> () c h w") inputs = [batch[k] for k in self.data_keys] - outputs_list: Union[Tensor, list[Tensor]] = self.augs(*inputs) + outputs_list: Tensor | list[Tensor] = self.augs(*inputs) outputs_list = ( outputs_list if isinstance(outputs_list, list) else [outputs_list] ) @@ -147,7 +147,7 @@ class _RandomNCrop(K.GeometricAugmentationBase2D): input: Tensor, params: dict[str, Tensor], flags: dict[str, Any], - transform: Optional[Tensor] = None, + transform: Tensor | None = None, ) -> Tensor: """Apply the transform. @@ -169,7 +169,7 @@ class _RandomNCrop(K.GeometricAugmentationBase2D): class _NCropGenerator(K.random_generator.CropGenerator): """Generate N random crops.""" - def __init__(self, size: Union[tuple[int, int], Tensor], num: int) -> None: + def __init__(self, size: tuple[int, int] | Tensor, num: int) -> None: """Initialize a new _NCropGenerator instance. Args: @@ -207,9 +207,9 @@ class _ExtractPatches(K.GeometricAugmentationBase2D): def __init__( self, - window_size: Union[int, tuple[int, int]], - stride: Optional[Union[int, tuple[int, int]]] = None, - padding: Optional[Union[int, tuple[int, int]]] = 0, + window_size: int | tuple[int, int], + stride: int | tuple[int, int] | None = None, + padding: int | tuple[int, int] | None = 0, keepdim: bool = True, ) -> None: """Initialize a new _ExtractPatches instance. @@ -250,7 +250,7 @@ class _ExtractPatches(K.GeometricAugmentationBase2D): input: Tensor, params: dict[str, Tensor], flags: dict[str, Any], - transform: Optional[Tensor] = None, + transform: Tensor | None = None, ) -> Tensor: """Apply the transform.