This commit is contained in:
Adam J. Stewart 2024-04-04 09:15:56 +02:00 коммит произвёл GitHub
Родитель 26cab1060b
Коммит ea57469b0b
Не найден ключ, соответствующий данной подписи
Идентификатор ключа GPG: B5690EEEBB952194
128 изменённых файлов: 624 добавлений и 630 удалений

2
.github/workflows/style.yaml поставляемый
Просмотреть файл

@ -158,7 +158,7 @@ jobs:
- name: List pip dependencies
run: pip list
- name: Run pyupgrade checks
run: pyupgrade --py39-plus $(find . -path ./docs/src -prune -o -name "*.py" -print)
run: pyupgrade --py310-plus $(find . -path ./docs/src -prune -o -name "*.py" -print)
concurrency:
group: ${{ github.workflow }}-${{ github.event.pull_request.head.label || github.head_ref || github.ref }}
cancel-in-progress: true

4
.github/workflows/tests.yaml поставляемый
Просмотреть файл

@ -17,7 +17,7 @@ jobs:
strategy:
matrix:
os: [ubuntu-latest, macos-latest, windows-latest]
python-version: ['3.9', '3.10', '3.11', '3.12']
python-version: ['3.10', '3.11', '3.12']
steps:
- name: Clone repo
uses: actions/checkout@v4.1.2
@ -71,7 +71,7 @@ jobs:
- name: Set up python
uses: actions/setup-python@v5.1.0
with:
python-version: '3.9'
python-version: '3.10'
- name: Cache dependencies
uses: actions/cache@v4.0.2
id: cache

Просмотреть файл

@ -3,7 +3,7 @@ repos:
rev: v3.15.2
hooks:
- id: pyupgrade
args: [--py39-plus]
args: [--py310-plus]
- repo: https://github.com/pycqa/isort
rev: 5.13.2

Просмотреть файл

@ -103,7 +103,7 @@ All of these tools should be used from the root of the project to ensure that ou
$ black .
$ isort .
$ pyupgrade --py39-plus $(find . -name "*.py")
$ pyupgrade --py310-plus $(find . -name "*.py")
Flake8, pydocstyle, and mypy won't format your code for you, but they will warn you about potential issues with your code or docstrings:

Просмотреть файл

@ -56,7 +56,7 @@ import warnings
from collections import defaultdict
from datetime import date, timedelta
from multiprocessing.dummy import Lock, Pool
from typing import Any, Optional
from typing import Any
import ee
import numpy as np
@ -168,7 +168,7 @@ def get_patch(
new_resolutions: list[int],
dtype: str = "float32",
meta_cloud_name: str = "CLOUD_COVER",
default_value: Optional[float] = None,
default_value: float | None = None,
) -> dict[str, Any]:
image = collection.sort(meta_cloud_name).first()
region = ee.Geometry.Point(center_coord).buffer(radius).bounds()
@ -214,7 +214,7 @@ def get_random_patches_match(
new_resolutions: list[int],
dtype: str,
meta_cloud_name: str,
default_value: Optional[float],
default_value: float | None,
dates: list[date],
radius: float,
debug: bool = False,

Просмотреть файл

@ -10,7 +10,7 @@ build-backend = "setuptools.build_meta"
name = "torchgeo"
description = "TorchGeo: datasets, samplers, transforms, and pre-trained models for geospatial data"
readme = "README.md"
requires-python = ">=3.9"
requires-python = ">=3.10"
license = {file = "LICENSE"}
authors = [
{name = "Adam J. Stewart", email = "ajstewart426@gmail.com"},
@ -29,7 +29,6 @@ classifiers = [
"License :: OSI Approved :: MIT License",
"Operating System :: OS Independent",
"Programming Language :: Python :: 3",
"Programming Language :: Python :: 3.9",
"Programming Language :: Python :: 3.10",
"Programming Language :: Python :: 3.11",
"Programming Language :: Python :: 3.12",
@ -39,9 +38,8 @@ classifiers = [
dependencies = [
# einops 0.3+ required for einops.repeat
"einops>=0.3",
# fiona 1.8.19+ required to fix erroneous warning
# https://github.com/Toblerity/Fiona/issues/986
"fiona>=1.8.19",
# fiona 1.8.21+ required for Python 3.10 wheels
"fiona>=1.8.21",
# kornia 0.6.9+ required for kornia.augmentation.RandomBrightness
"kornia>=0.6.9",
# lightly 1.4.4+ required for MoCo v3 support
@ -50,24 +48,24 @@ dependencies = [
"lightly>=1.4.4,!=1.4.26",
# lightning 2+ required for LightningCLI args + sys.argv support
"lightning[pytorch-extra]>=2",
# matplotlib 3.3.3+ required for Python 3.9 wheels
"matplotlib>=3.3.3",
# numpy 1.19.3+ required by Python 3.9 wheels
"numpy>=1.19.3",
# pandas 1.1.3+ required for Python 3.9 wheels
"pandas>=1.1.3",
# pillow 8+ required for Python 3.9 wheels
"pillow>=8",
# pyproj 3+ required for Python 3.9 wheels
"pyproj>=3",
# rasterio 1.2+ required for Python 3.9 wheels
"rasterio>=1.2",
# rtree 1+ required for len(index), index & index, index | index
# matplotlib 3.5+ required for Python 3.10 wheels
"matplotlib>=3.5",
# numpy 1.21.2+ required by Python 3.10 wheels
"numpy>=1.21.2",
# pandas 1.3.3+ required for Python 3.10 wheels
"pandas>=1.3.3",
# pillow 8.4+ required for Python 3.10 wheels
"pillow>=8.4",
# pyproj 3.3+ required for Python 3.10 wheels
"pyproj>=3.3",
# rasterio 1.3+ required for Python 3.10 wheels
"rasterio>=1.3",
# rtree 1+ required for Python 3.10 wheels
"rtree>=1",
# segmentation-models-pytorch 0.2+ required for smp.losses module
"segmentation-models-pytorch>=0.2",
# shapely 1.7.1+ required for Python 3.9 wheels
"shapely>=1.7.1",
# shapely 1.8+ required for Python 3.10 wheels
"shapely>=1.8",
# timm 0.4.12 required by segmentation-models-pytorch
"timm>=0.4.12",
# torch 1.13+ required by torchvision
@ -81,27 +79,25 @@ dynamic = ["version"]
[project.optional-dependencies]
datasets = [
# h5py 3+ required for Python 3.9 wheels
"h5py>=3",
# h5py 3.6+ required for Python 3.10 wheels
"h5py>=3.6",
# laspy 2+ required for laspy.read
"laspy>=2",
# opencv-python 4.4.0.46+ required for Python 3.9 wheels
"opencv-python>=4.4.0.46",
# pycocotools 2.0.5+ required for cython 3+ support
"pycocotools>=2.0.5",
# opencv-python 4.5.4+ required for Python 3.10 wheels
"opencv-python>=4.5.4",
# pycocotools 2.0.7+ required for wheels
"pycocotools>=2.0.7",
# pyvista 0.34.2+ required to avoid ImportError in CI
"pyvista>=0.34.2",
# radiant-mlhub 0.3+ required for newer tqdm support required by lightning
"radiant-mlhub>=0.3",
# rarfile 4+ required for wheels
"rarfile>=4",
# scikit-image 0.18+ required for numpy 1.17+ compatibility
# https://github.com/scikit-image/scikit-image/issues/3655
"scikit-image>=0.18",
# scipy 1.6.2+ required for scikit-image 0.18+ compatibility
"scipy>=1.6.2",
# zipfile-deflate64 0.2+ required for extraction bugfix:
# https://github.com/brianhelba/zipfile-deflate64/issues/19
# scikit-image 0.19+ required for Python 3.10 wheels
"scikit-image>=0.19",
# scipy 1.7.2+ required for Python 3.10 wheels
"scipy>=1.7.2",
# zipfile-deflate64 0.2+ required for Python 3.10 wheels
"zipfile-deflate64>=0.2",
]
docs = [
@ -126,7 +122,7 @@ style = [
"isort[colors]>=5.8",
# pydocstyle 6.1+ required for pyproject.toml support
"pydocstyle[toml]>=6.1",
# pyupgrade 2.8+ required for --py39-plus flag
# pyupgrade 2.8+ required for --py310-plus flag
"pyupgrade>=2.8",
]
tests = [
@ -151,7 +147,7 @@ Homepage = "https://github.com/microsoft/torchgeo"
Documentation = "https://torchgeo.readthedocs.io"
[tool.black]
target-version = ["py39", "py310"]
target-version = ["py310"]
color = true
skip_magic_trailing_comma = true
@ -170,7 +166,7 @@ skip_gitignore = true
color_output = true
[tool.mypy]
python_version = "3.9"
python_version = "3.10"
ignore_missing_imports = true
show_error_codes = true
exclude = "(build|data|dist|docs/src|images|logo|logs|output)/"

Просмотреть файл

@ -3,34 +3,34 @@ setuptools==61.0.0
# install
einops==0.3.0
fiona==1.8.19
fiona==1.8.21
kornia==0.6.9
lightly==1.4.4
lightning[pytorch-extra]==2.0.0
matplotlib==3.3.3
numpy==1.19.3
pandas==1.1.3
pillow==8.0.0
pyproj==3.0.0
rasterio==1.2.0
matplotlib==3.5.0
numpy==1.21.2
pandas==1.3.3
pillow==8.4.0
pyproj==3.3.0
rasterio==1.3.0.post1
rtree==1.0.0
segmentation-models-pytorch==0.2.0
shapely==1.7.1
shapely==1.8.0
timm==0.4.12
torch==1.13.0
torchmetrics==0.10.0
torchvision==0.14.0
# datasets
h5py==3.0.0
h5py==3.6.0
laspy==2.0.0
opencv-python==4.4.0.46
pycocotools==2.0.5
opencv-python==4.5.4.58
pycocotools==2.0.7
pyvista==0.34.2
radiant-mlhub==0.3.0
rarfile==4.0
scikit-image==0.18.0
scipy==1.6.2
scikit-image==0.19.0
scipy==1.7.2
zipfile-deflate64==0.2.0
# docs

Просмотреть файл

@ -2,7 +2,6 @@
# Licensed under the MIT License.
import os
from typing import Optional
import numpy as np
import rasterio as rio
@ -18,7 +17,7 @@ def write_raster(
res: int = RES[0],
epsg: int = EPSG[0],
dtype: str = "uint8",
path: Optional[str] = None,
path: str | None = None,
) -> None:
"""Write a raster file.

Просмотреть файл

@ -5,7 +5,6 @@ import pickle
import sys
from collections.abc import Iterable
from pathlib import Path
from typing import Optional, Union
import pytest
import torch
@ -35,7 +34,7 @@ class CustomGeoDataset(GeoDataset):
bounds: BoundingBox = BoundingBox(0, 1, 2, 3, 4, 5),
crs: CRS = CRS.from_epsg(4087),
res: float = 1,
paths: Optional[Union[str, Iterable[str]]] = None,
paths: str | Iterable[str] | None = None,
) -> None:
super().__init__()
self.index.insert(0, tuple(bounds))
@ -249,7 +248,7 @@ class TestRasterDataset:
},
],
)
def test_files(self, paths: Union[str, Iterable[str]]) -> None:
def test_files(self, paths: str | Iterable[str]) -> None:
assert 1 <= len(NAIP(paths).files) <= 2
def test_getitem_single_file(self, naip: NAIP) -> None:

Просмотреть файл

@ -3,7 +3,7 @@
from collections.abc import Sequence
from math import floor, isclose
from typing import Any, Union
from typing import Any
import pytest
from rasterio.crs import CRS
@ -65,7 +65,7 @@ class CustomGeoDataset(GeoDataset):
],
)
def test_random_bbox_assignment(
lengths: Sequence[Union[int, float]], expected_lengths: Sequence[int]
lengths: Sequence[int | float], expected_lengths: Sequence[int]
) -> None:
ds = CustomGeoDataset(
[
@ -255,8 +255,7 @@ def test_roi_split() -> None:
],
)
def test_time_series_split(
lengths: Sequence[Union[tuple[int, int], int, float]],
expected_lengths: Sequence[int],
lengths: Sequence[tuple[int, int] | int | float], expected_lengths: Sequence[int]
) -> None:
ds = CustomGeoDataset(
[

Просмотреть файл

@ -2,7 +2,7 @@
# Licensed under the MIT License.
import enum
from typing import Callable
from collections.abc import Callable
import pytest
import torch.nn as nn

Просмотреть файл

@ -2,7 +2,7 @@
# Licensed under the MIT License.
import math
from typing import Optional, Union
from typing import Union
import pytest
@ -33,7 +33,7 @@ MAYBE_TUPLE = Union[float, tuple[float, float]]
],
)
def test_tile_to_chips(
size: MAYBE_TUPLE, stride: Optional[MAYBE_TUPLE], expected: MAYBE_TUPLE
size: MAYBE_TUPLE, stride: MAYBE_TUPLE | None, expected: MAYBE_TUPLE
) -> None:
bounds = BoundingBox(0, 10, 20, 30, 40, 50)
size = _to_tuple(size)

Просмотреть файл

@ -3,7 +3,7 @@
"""AgriFieldNet datamodule."""
from typing import Any, Optional, Union
from typing import Any
import kornia.augmentation as K
import torch
@ -25,8 +25,8 @@ class AgriFieldNetDataModule(GeoDataModule):
def __init__(
self,
batch_size: int = 64,
patch_size: Union[int, tuple[int, int]] = 256,
length: Optional[int] = None,
patch_size: int | tuple[int, int] = 256,
length: int | None = None,
num_workers: int = 0,
**kwargs: Any,
) -> None:

Просмотреть файл

@ -3,7 +3,7 @@
"""Chesapeake Bay High-Resolution Land Cover Project datamodule."""
from typing import Any, Optional
from typing import Any
import kornia.augmentation as K
import torch.nn as nn
@ -63,7 +63,7 @@ class ChesapeakeCVPRDataModule(GeoDataModule):
test_splits: list[str],
batch_size: int = 64,
patch_size: int = 256,
length: Optional[int] = None,
length: int | None = None,
num_workers: int = 0,
class_set: int = 7,
use_prior_labels: bool = False,

Просмотреть файл

@ -3,7 +3,7 @@
"""DeepGlobe Land Cover Classification Challenge datamodule."""
from typing import Any, Union
from typing import Any
import kornia.augmentation as K
@ -24,7 +24,7 @@ class DeepGlobeLandCoverDataModule(NonGeoDataModule):
def __init__(
self,
batch_size: int = 64,
patch_size: Union[tuple[int, int], int] = 64,
patch_size: tuple[int, int] | int = 64,
val_split_pct: float = 0.2,
num_workers: int = 0,
**kwargs: Any,

Просмотреть файл

@ -3,7 +3,8 @@
"""Base classes for all :mod:`torchgeo` data modules."""
from typing import Any, Callable, Optional, Union, cast
from collections.abc import Callable
from typing import Any, cast
import kornia.augmentation as K
import torch
@ -55,27 +56,27 @@ class BaseDataModule(LightningDataModule):
self.kwargs = kwargs
# Datasets
self.dataset: Optional[Dataset[dict[str, Tensor]]] = None
self.train_dataset: Optional[Dataset[dict[str, Tensor]]] = None
self.val_dataset: Optional[Dataset[dict[str, Tensor]]] = None
self.test_dataset: Optional[Dataset[dict[str, Tensor]]] = None
self.predict_dataset: Optional[Dataset[dict[str, Tensor]]] = None
self.dataset: Dataset[dict[str, Tensor]] | None = None
self.train_dataset: Dataset[dict[str, Tensor]] | None = None
self.val_dataset: Dataset[dict[str, Tensor]] | None = None
self.test_dataset: Dataset[dict[str, Tensor]] | None = None
self.predict_dataset: Dataset[dict[str, Tensor]] | None = None
# Data loaders
self.train_batch_size: Optional[int] = None
self.val_batch_size: Optional[int] = None
self.test_batch_size: Optional[int] = None
self.predict_batch_size: Optional[int] = None
self.train_batch_size: int | None = None
self.val_batch_size: int | None = None
self.test_batch_size: int | None = None
self.predict_batch_size: int | None = None
# Data augmentation
Transform = Callable[[dict[str, Tensor]], dict[str, Tensor]]
self.aug: Transform = AugmentationSequential(
K.Normalize(mean=self.mean, std=self.std), data_keys=["image"]
)
self.train_aug: Optional[Transform] = None
self.val_aug: Optional[Transform] = None
self.test_aug: Optional[Transform] = None
self.predict_aug: Optional[Transform] = None
self.train_aug: Transform | None = None
self.val_aug: Transform | None = None
self.test_aug: Transform | None = None
self.predict_aug: Transform | None = None
def prepare_data(self) -> None:
"""Download and prepare data.
@ -141,7 +142,7 @@ class BaseDataModule(LightningDataModule):
return batch
def plot(self, *args: Any, **kwargs: Any) -> Optional[Figure]:
def plot(self, *args: Any, **kwargs: Any) -> Figure | None:
"""Run the plot method of the validation dataset if one exists.
Should only be called during 'fit' or 'validate' stages as ``val_dataset``
@ -154,7 +155,7 @@ class BaseDataModule(LightningDataModule):
Returns:
A matplotlib Figure with the image, ground truth, and predictions.
"""
fig: Optional[Figure] = None
fig: Figure | None = None
dataset = self.dataset or self.val_dataset
if dataset is not None:
if hasattr(dataset, "plot"):
@ -172,8 +173,8 @@ class GeoDataModule(BaseDataModule):
self,
dataset_class: type[GeoDataset],
batch_size: int = 1,
patch_size: Union[int, tuple[int, int]] = 64,
length: Optional[int] = None,
patch_size: int | tuple[int, int] = 64,
length: int | None = None,
num_workers: int = 0,
**kwargs: Any,
) -> None:
@ -196,18 +197,18 @@ class GeoDataModule(BaseDataModule):
self.collate_fn = stack_samples
# Samplers
self.sampler: Optional[GeoSampler] = None
self.train_sampler: Optional[GeoSampler] = None
self.val_sampler: Optional[GeoSampler] = None
self.test_sampler: Optional[GeoSampler] = None
self.predict_sampler: Optional[GeoSampler] = None
self.sampler: GeoSampler | None = None
self.train_sampler: GeoSampler | None = None
self.val_sampler: GeoSampler | None = None
self.test_sampler: GeoSampler | None = None
self.predict_sampler: GeoSampler | None = None
# Batch samplers
self.batch_sampler: Optional[BatchGeoSampler] = None
self.train_batch_sampler: Optional[BatchGeoSampler] = None
self.val_batch_sampler: Optional[BatchGeoSampler] = None
self.test_batch_sampler: Optional[BatchGeoSampler] = None
self.predict_batch_sampler: Optional[BatchGeoSampler] = None
self.batch_sampler: BatchGeoSampler | None = None
self.train_batch_sampler: BatchGeoSampler | None = None
self.val_batch_sampler: BatchGeoSampler | None = None
self.test_batch_sampler: BatchGeoSampler | None = None
self.predict_batch_sampler: BatchGeoSampler | None = None
def setup(self, stage: str) -> None:
"""Set up datasets and samplers.

Просмотреть файл

@ -3,7 +3,7 @@
"""GID-15 datamodule."""
from typing import Any, Union
from typing import Any
import kornia.augmentation as K
@ -26,7 +26,7 @@ class GID15DataModule(NonGeoDataModule):
def __init__(
self,
batch_size: int = 64,
patch_size: Union[tuple[int, int], int] = 64,
patch_size: tuple[int, int] | int = 64,
val_split_pct: float = 0.2,
num_workers: int = 0,
**kwargs: Any,

Просмотреть файл

@ -3,7 +3,7 @@
"""InriaAerialImageLabeling datamodule."""
from typing import Any, Union
from typing import Any
import kornia.augmentation as K
@ -26,7 +26,7 @@ class InriaAerialImageLabelingDataModule(NonGeoDataModule):
def __init__(
self,
batch_size: int = 64,
patch_size: Union[tuple[int, int], int] = 64,
patch_size: tuple[int, int] | int = 64,
num_workers: int = 0,
**kwargs: Any,
) -> None:

Просмотреть файл

@ -3,7 +3,7 @@
"""L7 Irish datamodule."""
from typing import Any, Optional, Union
from typing import Any
import kornia.augmentation as K
import torch
@ -25,8 +25,8 @@ class L7IrishDataModule(GeoDataModule):
def __init__(
self,
batch_size: int = 1,
patch_size: Union[int, tuple[int, int]] = 224,
length: Optional[int] = None,
patch_size: int | tuple[int, int] = 224,
length: int | None = None,
num_workers: int = 0,
**kwargs: Any,
) -> None:

Просмотреть файл

@ -3,7 +3,7 @@
"""L8 Biome datamodule."""
from typing import Any, Optional, Union
from typing import Any
import kornia.augmentation as K
import torch
@ -25,8 +25,8 @@ class L8BiomeDataModule(GeoDataModule):
def __init__(
self,
batch_size: int = 1,
patch_size: Union[int, tuple[int, int]] = 224,
length: Optional[int] = None,
patch_size: int | tuple[int, int] = 224,
length: int | None = None,
num_workers: int = 0,
**kwargs: Any,
) -> None:

Просмотреть файл

@ -3,7 +3,7 @@
"""LEVIR-CD+ datamodule."""
from typing import Any, Union
from typing import Any
import kornia.augmentation as K
@ -25,7 +25,7 @@ class LEVIRCDDataModule(NonGeoDataModule):
def __init__(
self,
batch_size: int = 8,
patch_size: Union[tuple[int, int], int] = 256,
patch_size: tuple[int, int] | int = 256,
num_workers: int = 0,
**kwargs: Any,
) -> None:
@ -70,7 +70,7 @@ class LEVIRCDPlusDataModule(NonGeoDataModule):
def __init__(
self,
batch_size: int = 8,
patch_size: Union[tuple[int, int], int] = 256,
patch_size: tuple[int, int] | int = 256,
val_split_pct: float = 0.2,
num_workers: int = 0,
**kwargs: Any,

Просмотреть файл

@ -3,7 +3,7 @@
"""National Agriculture Imagery Program (NAIP) datamodule."""
from typing import Any, Optional, Union
from typing import Any
import kornia.augmentation as K
from matplotlib.figure import Figure
@ -23,8 +23,8 @@ class NAIPChesapeakeDataModule(GeoDataModule):
def __init__(
self,
batch_size: int = 64,
patch_size: Union[int, tuple[int, int]] = 256,
length: Optional[int] = None,
patch_size: int | tuple[int, int] = 256,
length: int | None = None,
num_workers: int = 0,
**kwargs: Any,
) -> None:

Просмотреть файл

@ -3,7 +3,7 @@
"""OSCD datamodule."""
from typing import Any, Union
from typing import Any
import kornia.augmentation as K
import torch
@ -60,7 +60,7 @@ class OSCDDataModule(NonGeoDataModule):
def __init__(
self,
batch_size: int = 64,
patch_size: Union[tuple[int, int], int] = 64,
patch_size: tuple[int, int] | int = 64,
val_split_pct: float = 0.2,
num_workers: int = 0,
**kwargs: Any,

Просмотреть файл

@ -3,7 +3,7 @@
"""Potsdam datamodule."""
from typing import Any, Union
from typing import Any
import kornia.augmentation as K
@ -26,7 +26,7 @@ class Potsdam2DDataModule(NonGeoDataModule):
def __init__(
self,
batch_size: int = 64,
patch_size: Union[tuple[int, int], int] = 64,
patch_size: tuple[int, int] | int = 64,
val_split_pct: float = 0.2,
num_workers: int = 0,
**kwargs: Any,

Просмотреть файл

@ -3,7 +3,7 @@
"""Sentinel-2 and CDL datamodule."""
from typing import Any, Optional, Union
from typing import Any
import kornia.augmentation as K
import torch
@ -26,8 +26,8 @@ class Sentinel2CDLDataModule(GeoDataModule):
def __init__(
self,
batch_size: int = 64,
patch_size: Union[int, tuple[int, int]] = 64,
length: Optional[int] = None,
patch_size: int | tuple[int, int] = 64,
length: int | None = None,
num_workers: int = 0,
**kwargs: Any,
) -> None:

Просмотреть файл

@ -3,7 +3,7 @@
"""Sentinel-2 and NCCM datamodule."""
from typing import Any, Optional, Union
from typing import Any
import kornia.augmentation as K
import torch
@ -26,8 +26,8 @@ class Sentinel2NCCMDataModule(GeoDataModule):
def __init__(
self,
batch_size: int = 64,
patch_size: Union[int, tuple[int, int]] = 64,
length: Optional[int] = None,
patch_size: int | tuple[int, int] = 64,
length: int | None = None,
num_workers: int = 0,
**kwargs: Any,
) -> None:

Просмотреть файл

@ -5,7 +5,7 @@
"""South America Soybean datamodule."""
from typing import Any, Optional, Union
from typing import Any
import kornia.augmentation as K
import torch
@ -28,8 +28,8 @@ class Sentinel2SouthAmericaSoybeanDataModule(GeoDataModule):
def __init__(
self,
batch_size: int = 64,
patch_size: Union[int, tuple[int, int]] = 64,
length: Optional[int] = None,
patch_size: int | tuple[int, int] = 64,
length: int | None = None,
num_workers: int = 0,
**kwargs: Any,
) -> None:

Просмотреть файл

@ -3,7 +3,7 @@
"""SSL4EO datamodule."""
from typing import Any, Union
from typing import Any
import kornia.augmentation as K
from kornia.constants import DataKey, Resample
@ -23,7 +23,7 @@ class SSL4EOLBenchmarkDataModule(NonGeoDataModule):
def __init__(
self,
batch_size: int = 64,
patch_size: Union[int, tuple[int, int]] = 224,
patch_size: int | tuple[int, int] = 224,
num_workers: int = 0,
**kwargs: Any,
) -> None:

Просмотреть файл

@ -4,8 +4,8 @@
"""Common datamodule utilities."""
import math
from collections.abc import Iterable
from typing import Any, Callable, Optional, Union
from collections.abc import Callable, Iterable
from typing import Any
import numpy as np
import torch
@ -103,9 +103,9 @@ def collate_fn_detection(batch: list[dict[str, Tensor]]) -> dict[str, Any]:
def dataset_split(
dataset: Union[TensorDataset, NonGeoDataset],
dataset: TensorDataset | NonGeoDataset,
val_pct: float,
test_pct: Optional[float] = None,
test_pct: float | None = None,
) -> list[Subset[Any]]:
"""Split a torch Dataset into train/val/test sets.
@ -142,9 +142,9 @@ def dataset_split(
def group_shuffle_split(
groups: Iterable[Any],
train_size: Optional[float] = None,
test_size: Optional[float] = None,
random_state: Optional[int] = None,
train_size: float | None = None,
test_size: float | None = None,
random_state: int | None = None,
) -> tuple[list[int], list[int]]:
"""Method for performing a single group-wise shuffle split of data.

Просмотреть файл

@ -3,7 +3,7 @@
"""Vaihingen datamodule."""
from typing import Any, Union
from typing import Any
import kornia.augmentation as K
@ -26,7 +26,7 @@ class Vaihingen2DDataModule(NonGeoDataModule):
def __init__(
self,
batch_size: int = 64,
patch_size: Union[tuple[int, int], int] = 64,
patch_size: tuple[int, int] | int = 64,
val_split_pct: float = 0.2,
num_workers: int = 0,
**kwargs: Any,

Просмотреть файл

@ -3,7 +3,7 @@
"""NWPU VHR-10 datamodule."""
from typing import Any, Union
from typing import Any
import kornia.augmentation as K
import torch
@ -26,7 +26,7 @@ class VHR10DataModule(NonGeoDataModule):
def __init__(
self,
batch_size: int = 64,
patch_size: Union[tuple[int, int], int] = 512,
patch_size: tuple[int, int] | int = 512,
num_workers: int = 0,
val_split_pct: float = 0.2,
test_split_pct: float = 0.2,

Просмотреть файл

@ -5,7 +5,8 @@
import glob
import os
from typing import Callable, Optional, cast
from collections.abc import Callable
from typing import cast
import matplotlib.pyplot as plt
import numpy as np
@ -87,7 +88,7 @@ class ADVANCE(NonGeoDataset):
def __init__(
self,
root: str = "data",
transforms: Optional[Callable[[dict[str, Tensor]], dict[str, Tensor]]] = None,
transforms: Callable[[dict[str, Tensor]], dict[str, Tensor]] | None = None,
download: bool = False,
checksum: bool = False,
) -> None:
@ -228,7 +229,7 @@ class ADVANCE(NonGeoDataset):
self,
sample: dict[str, Tensor],
show_titles: bool = True,
suptitle: Optional[str] = None,
suptitle: str | None = None,
) -> Figure:
"""Plot a sample from the dataset.

Просмотреть файл

@ -5,8 +5,8 @@
import json
import os
from collections.abc import Iterable
from typing import Any, Callable, Optional, Union
from collections.abc import Callable, Iterable
from typing import Any
import matplotlib.pyplot as plt
from matplotlib.figure import Figure
@ -56,10 +56,10 @@ class AbovegroundLiveWoodyBiomassDensity(RasterDataset):
def __init__(
self,
paths: Union[str, Iterable[str]] = "data",
crs: Optional[CRS] = None,
res: Optional[float] = None,
transforms: Optional[Callable[[dict[str, Any]], dict[str, Any]]] = None,
paths: str | Iterable[str] = "data",
crs: CRS | None = None,
res: float | None = None,
transforms: Callable[[dict[str, Any]], dict[str, Any]] | None = None,
download: bool = False,
cache: bool = True,
) -> None:
@ -121,7 +121,7 @@ class AbovegroundLiveWoodyBiomassDensity(RasterDataset):
self,
sample: dict[str, Any],
show_titles: bool = True,
suptitle: Optional[str] = None,
suptitle: str | None = None,
) -> Figure:
"""Plot a sample from the dataset.

Просмотреть файл

@ -5,8 +5,8 @@
import os
import re
from collections.abc import Iterable, Sequence
from typing import Any, Callable, Optional, Union, cast
from collections.abc import Callable, Iterable, Sequence
from typing import Any, cast
import matplotlib.pyplot as plt
import torch
@ -113,11 +113,11 @@ class AgriFieldNet(RasterDataset):
def __init__(
self,
paths: Union[str, Iterable[str]] = "data",
crs: Optional[CRS] = None,
paths: str | Iterable[str] = "data",
crs: CRS | None = None,
classes: list[int] = list(cmap.keys()),
bands: Sequence[str] = all_bands,
transforms: Optional[Callable[[dict[str, Tensor]], dict[str, Tensor]]] = None,
transforms: Callable[[dict[str, Tensor]], dict[str, Tensor]] | None = None,
cache: bool = True,
) -> None:
"""Initialize a new AgriFieldNet dataset instance.
@ -218,7 +218,7 @@ class AgriFieldNet(RasterDataset):
self,
sample: dict[str, Tensor],
show_titles: bool = True,
suptitle: Optional[str] = None,
suptitle: str | None = None,
) -> Figure:
"""Plot a sample from the dataset.

Просмотреть файл

@ -3,7 +3,7 @@
"""Airphen dataset."""
from typing import Any, Optional
from typing import Any
import matplotlib.pyplot as plt
from matplotlib.figure import Figure
@ -46,7 +46,7 @@ class Airphen(RasterDataset):
self,
sample: dict[str, Any],
show_titles: bool = True,
suptitle: Optional[str] = None,
suptitle: str | None = None,
) -> Figure:
"""Plot a sample from the dataset.

Просмотреть файл

@ -3,7 +3,8 @@
"""Aster Global Digital Elevation Model dataset."""
from typing import Any, Callable, Optional, Union
from collections.abc import Callable
from typing import Any
import matplotlib.pyplot as plt
from matplotlib.figure import Figure
@ -46,10 +47,10 @@ class AsterGDEM(RasterDataset):
def __init__(
self,
paths: Union[str, list[str]] = "data",
crs: Optional[CRS] = None,
res: Optional[float] = None,
transforms: Optional[Callable[[dict[str, Any]], dict[str, Any]]] = None,
paths: str | list[str] = "data",
crs: CRS | None = None,
res: float | None = None,
transforms: Callable[[dict[str, Any]], dict[str, Any]] | None = None,
cache: bool = True,
) -> None:
"""Initialize a new Dataset instance.
@ -89,7 +90,7 @@ class AsterGDEM(RasterDataset):
self,
sample: dict[str, Any],
show_titles: bool = True,
suptitle: Optional[str] = None,
suptitle: str | None = None,
) -> Figure:
"""Plot a sample from the dataset.

Просмотреть файл

@ -5,8 +5,8 @@
import json
import os
from collections.abc import Callable
from functools import lru_cache
from typing import Callable, Optional
import matplotlib.pyplot as plt
import numpy as np
@ -182,9 +182,9 @@ class BeninSmallHolderCashews(NonGeoDataset):
chip_size: int = 256,
stride: int = 128,
bands: tuple[str, ...] = all_bands,
transforms: Optional[Callable[[dict[str, Tensor]], dict[str, Tensor]]] = None,
transforms: Callable[[dict[str, Tensor]], dict[str, Tensor]] | None = None,
download: bool = False,
api_key: Optional[str] = None,
api_key: str | None = None,
checksum: bool = False,
verbose: bool = False,
) -> None:
@ -408,7 +408,7 @@ class BeninSmallHolderCashews(NonGeoDataset):
return images and targets
def _download(self, api_key: Optional[str] = None) -> None:
def _download(self, api_key: str | None = None) -> None:
"""Download the dataset and extract it.
Args:
@ -434,7 +434,7 @@ class BeninSmallHolderCashews(NonGeoDataset):
sample: dict[str, Tensor],
show_titles: bool = True,
time_step: int = 0,
suptitle: Optional[str] = None,
suptitle: str | None = None,
) -> Figure:
"""Plot a sample from the dataset.

Просмотреть файл

@ -6,7 +6,7 @@
import glob
import json
import os
from typing import Callable, Optional
from collections.abc import Callable
import matplotlib.pyplot as plt
import numpy as np
@ -275,7 +275,7 @@ class BigEarthNet(NonGeoDataset):
split: str = "train",
bands: str = "all",
num_classes: int = 19,
transforms: Optional[Callable[[dict[str, Tensor]], dict[str, Tensor]]] = None,
transforms: Callable[[dict[str, Tensor]], dict[str, Tensor]] | None = None,
download: bool = False,
checksum: bool = False,
) -> None:
@ -533,7 +533,7 @@ class BigEarthNet(NonGeoDataset):
self,
sample: dict[str, Tensor],
show_titles: bool = True,
suptitle: Optional[str] = None,
suptitle: str | None = None,
) -> Figure:
"""Plot a sample from the dataset.

Просмотреть файл

@ -5,7 +5,6 @@
import os
from collections.abc import Sequence
from typing import Optional
import matplotlib.pyplot as plt
import numpy as np
@ -219,7 +218,7 @@ class BioMassters(NonGeoDataset):
self,
sample: dict[str, Tensor],
show_titles: bool = True,
suptitle: Optional[str] = None,
suptitle: str | None = None,
) -> Figure:
"""Plot a sample from the dataset.

Просмотреть файл

@ -4,8 +4,8 @@
"""Canadian Building Footprints dataset."""
import os
from collections.abc import Iterable
from typing import Any, Callable, Optional, Union
from collections.abc import Callable, Iterable
from typing import Any
import matplotlib.pyplot as plt
from matplotlib.figure import Figure
@ -61,10 +61,10 @@ class CanadianBuildingFootprints(VectorDataset):
def __init__(
self,
paths: Union[str, Iterable[str]] = "data",
crs: Optional[CRS] = None,
paths: str | Iterable[str] = "data",
crs: CRS | None = None,
res: float = 0.00001,
transforms: Optional[Callable[[dict[str, Any]], dict[str, Any]]] = None,
transforms: Callable[[dict[str, Any]], dict[str, Any]] | None = None,
download: bool = False,
checksum: bool = False,
) -> None:
@ -127,7 +127,7 @@ class CanadianBuildingFootprints(VectorDataset):
self,
sample: dict[str, Any],
show_titles: bool = True,
suptitle: Optional[str] = None,
suptitle: str | None = None,
) -> Figure:
"""Plot a sample from the dataset.

Просмотреть файл

@ -4,8 +4,8 @@
"""CDL dataset."""
import os
from collections.abc import Iterable
from typing import Any, Callable, Optional, Union
from collections.abc import Callable, Iterable
from typing import Any
import matplotlib.pyplot as plt
import torch
@ -206,12 +206,12 @@ class CDL(RasterDataset):
def __init__(
self,
paths: Union[str, Iterable[str]] = "data",
crs: Optional[CRS] = None,
res: Optional[float] = None,
paths: str | Iterable[str] = "data",
crs: CRS | None = None,
res: float | None = None,
years: list[int] = [2023],
classes: list[int] = list(cmap.keys()),
transforms: Optional[Callable[[dict[str, Any]], dict[str, Any]]] = None,
transforms: Callable[[dict[str, Any]], dict[str, Any]] | None = None,
cache: bool = True,
download: bool = False,
checksum: bool = False,
@ -336,7 +336,7 @@ class CDL(RasterDataset):
self,
sample: dict[str, Any],
show_titles: bool = True,
suptitle: Optional[str] = None,
suptitle: str | None = None,
) -> Figure:
"""Plot a sample from the dataset.

Просмотреть файл

@ -4,7 +4,7 @@
"""ChaBuD dataset."""
import os
from typing import Callable, Optional
from collections.abc import Callable
import matplotlib.pyplot as plt
import numpy as np
@ -77,7 +77,7 @@ class ChaBuD(NonGeoDataset):
root: str = "data",
split: str = "train",
bands: list[str] = all_bands,
transforms: Optional[Callable[[dict[str, Tensor]], dict[str, Tensor]]] = None,
transforms: Callable[[dict[str, Tensor]], dict[str, Tensor]] | None = None,
download: bool = False,
checksum: bool = False,
) -> None:
@ -235,7 +235,7 @@ class ChaBuD(NonGeoDataset):
self,
sample: dict[str, Tensor],
show_titles: bool = True,
suptitle: Optional[str] = None,
suptitle: str | None = None,
) -> Figure:
"""Plot a sample from the dataset.

Просмотреть файл

@ -6,8 +6,8 @@
import abc
import os
import sys
from collections.abc import Iterable, Sequence
from typing import Any, Callable, Optional, Union, cast
from collections.abc import Callable, Iterable, Sequence
from typing import Any, cast
import fiona
import matplotlib.pyplot as plt
@ -90,10 +90,10 @@ class Chesapeake(RasterDataset, abc.ABC):
def __init__(
self,
paths: Union[str, Iterable[str]] = "data",
crs: Optional[CRS] = None,
res: Optional[float] = None,
transforms: Optional[Callable[[dict[str, Any]], dict[str, Any]]] = None,
paths: str | Iterable[str] = "data",
crs: CRS | None = None,
res: float | None = None,
transforms: Callable[[dict[str, Any]], dict[str, Any]] | None = None,
cache: bool = True,
download: bool = False,
checksum: bool = False,
@ -170,7 +170,7 @@ class Chesapeake(RasterDataset, abc.ABC):
self,
sample: dict[str, Any],
show_titles: bool = True,
suptitle: Optional[str] = None,
suptitle: str | None = None,
) -> Figure:
"""Plot a sample from the dataset.
@ -512,7 +512,7 @@ class ChesapeakeCVPR(GeoDataset):
root: str = "data",
splits: Sequence[str] = ["de-train"],
layers: Sequence[str] = ["naip-new", "lc"],
transforms: Optional[Callable[[dict[str, Any]], dict[str, Any]]] = None,
transforms: Callable[[dict[str, Any]], dict[str, Any]] | None = None,
cache: bool = True,
download: bool = False,
checksum: bool = False,
@ -711,7 +711,7 @@ class ChesapeakeCVPR(GeoDataset):
self,
sample: dict[str, Tensor],
show_titles: bool = True,
suptitle: Optional[str] = None,
suptitle: str | None = None,
) -> Figure:
"""Plot a sample from the dataset.

Просмотреть файл

@ -5,8 +5,8 @@
import json
import os
from collections.abc import Sequence
from typing import Any, Callable, Optional
from collections.abc import Callable, Sequence
from typing import Any
import matplotlib.pyplot as plt
import numpy as np
@ -111,9 +111,9 @@ class CloudCoverDetection(NonGeoDataset):
root: str = "data",
split: str = "train",
bands: Sequence[str] = band_names,
transforms: Optional[Callable[[dict[str, Tensor]], dict[str, Tensor]]] = None,
transforms: Callable[[dict[str, Tensor]], dict[str, Tensor]] | None = None,
download: bool = False,
api_key: Optional[str] = None,
api_key: str | None = None,
checksum: bool = False,
) -> None:
"""Initiatlize a new Cloud Cover Detection Dataset instance.
@ -329,7 +329,7 @@ class CloudCoverDetection(NonGeoDataset):
return images and targets
def _download(self, api_key: Optional[str] = None) -> None:
def _download(self, api_key: str | None = None) -> None:
"""Download the dataset and extract it.
Args:
@ -355,7 +355,7 @@ class CloudCoverDetection(NonGeoDataset):
self,
sample: dict[str, Tensor],
show_titles: bool = True,
suptitle: Optional[str] = None,
suptitle: str | None = None,
) -> Figure:
"""Plot a sample from the dataset.

Просмотреть файл

@ -4,7 +4,8 @@
"""CMS Global Mangrove Canopy dataset."""
import os
from typing import Any, Callable, Optional, Union
from collections.abc import Callable
from typing import Any
import matplotlib.pyplot as plt
from matplotlib.figure import Figure
@ -167,12 +168,12 @@ class CMSGlobalMangroveCanopy(RasterDataset):
def __init__(
self,
paths: Union[str, list[str]] = "data",
crs: Optional[CRS] = None,
res: Optional[float] = None,
paths: str | list[str] = "data",
crs: CRS | None = None,
res: float | None = None,
measurement: str = "agb",
country: str = all_countries[0],
transforms: Optional[Callable[[dict[str, Any]], dict[str, Any]]] = None,
transforms: Callable[[dict[str, Any]], dict[str, Any]] | None = None,
cache: bool = True,
checksum: bool = False,
) -> None:
@ -250,7 +251,7 @@ class CMSGlobalMangroveCanopy(RasterDataset):
self,
sample: dict[str, Any],
show_titles: bool = True,
suptitle: Optional[str] = None,
suptitle: str | None = None,
) -> Figure:
"""Plot a sample from the dataset.

Просмотреть файл

@ -6,7 +6,8 @@
import abc
import csv
import os
from typing import Callable, Optional, cast
from collections.abc import Callable
from typing import cast
import matplotlib.pyplot as plt
import numpy as np
@ -65,7 +66,7 @@ class COWC(NonGeoDataset, abc.ABC):
self,
root: str = "data",
split: str = "train",
transforms: Optional[Callable[[dict[str, Tensor]], dict[str, Tensor]]] = None,
transforms: Callable[[dict[str, Tensor]], dict[str, Tensor]] | None = None,
download: bool = False,
checksum: bool = False,
) -> None:
@ -192,7 +193,7 @@ class COWC(NonGeoDataset, abc.ABC):
self,
sample: dict[str, Tensor],
show_titles: bool = True,
suptitle: Optional[str] = None,
suptitle: str | None = None,
) -> Figure:
"""Plot a sample from the dataset.

Просмотреть файл

@ -6,7 +6,7 @@
import glob
import json
import os
from typing import Callable, Optional
from collections.abc import Callable
import matplotlib.pyplot as plt
import numpy as np
@ -96,7 +96,7 @@ class CropHarvest(NonGeoDataset):
def __init__(
self,
root: str = "data",
transforms: Optional[Callable[[dict[str, Tensor]], dict[str, Tensor]]] = None,
transforms: Callable[[dict[str, Tensor]], dict[str, Tensor]] | None = None,
download: bool = False,
checksum: bool = False,
) -> None:
@ -293,7 +293,7 @@ class CropHarvest(NonGeoDataset):
features_path = os.path.join(self.root, self.file_dict["features"]["filename"])
extract_archive(features_path)
def plot(self, sample: dict[str, Tensor], subtitle: Optional[str] = None) -> Figure:
def plot(self, sample: dict[str, Tensor], subtitle: str | None = None) -> Figure:
"""Plot a sample from the dataset using bands for Agriculture RGB composite.
Args:

Просмотреть файл

@ -5,8 +5,8 @@
import csv
import os
from collections.abc import Callable
from functools import lru_cache
from typing import Callable, Optional
import matplotlib.pyplot as plt
import numpy as np
@ -125,9 +125,9 @@ class CV4AKenyaCropType(NonGeoDataset):
chip_size: int = 256,
stride: int = 128,
bands: tuple[str, ...] = band_names,
transforms: Optional[Callable[[dict[str, Tensor]], dict[str, Tensor]]] = None,
transforms: Callable[[dict[str, Tensor]], dict[str, Tensor]] | None = None,
download: bool = False,
api_key: Optional[str] = None,
api_key: str | None = None,
checksum: bool = False,
verbose: bool = False,
) -> None:
@ -388,7 +388,7 @@ class CV4AKenyaCropType(NonGeoDataset):
return train_field_ids, test_field_ids
def _download(self, api_key: Optional[str] = None) -> None:
def _download(self, api_key: str | None = None) -> None:
"""Download the dataset and extract it.
Args:
@ -411,7 +411,7 @@ class CV4AKenyaCropType(NonGeoDataset):
sample: dict[str, Tensor],
show_titles: bool = True,
time_step: int = 0,
suptitle: Optional[str] = None,
suptitle: str | None = None,
) -> Figure:
"""Plot a sample from the dataset.

Просмотреть файл

@ -5,8 +5,9 @@
import json
import os
from collections.abc import Callable
from functools import lru_cache
from typing import Any, Callable, Optional
from typing import Any
import matplotlib.pyplot as plt
import numpy as np
@ -73,9 +74,9 @@ class TropicalCyclone(NonGeoDataset):
self,
root: str = "data",
split: str = "train",
transforms: Optional[Callable[[dict[str, Any]], dict[str, Any]]] = None,
transforms: Callable[[dict[str, Any]], dict[str, Any]] | None = None,
download: bool = False,
api_key: Optional[str] = None,
api_key: str | None = None,
checksum: bool = False,
) -> None:
"""Initialize a new Tropical Cyclone Wind Estimation Competition Dataset.
@ -204,7 +205,7 @@ class TropicalCyclone(NonGeoDataset):
return False
return True
def _download(self, api_key: Optional[str] = None) -> None:
def _download(self, api_key: str | None = None) -> None:
"""Download the dataset and extract it.
Args:
@ -227,7 +228,7 @@ class TropicalCyclone(NonGeoDataset):
self,
sample: dict[str, Any],
show_titles: bool = True,
suptitle: Optional[str] = None,
suptitle: str | None = None,
) -> Figure:
"""Plot a sample from the dataset.

Просмотреть файл

@ -4,7 +4,7 @@
"""DeepGlobe Land Cover Classification Challenge dataset."""
import os
from typing import Callable, Optional
from collections.abc import Callable
import matplotlib.pyplot as plt
import numpy as np
@ -102,7 +102,7 @@ class DeepGlobeLandCover(NonGeoDataset):
self,
root: str = "data",
split: str = "train",
transforms: Optional[Callable[[dict[str, Tensor]], dict[str, Tensor]]] = None,
transforms: Callable[[dict[str, Tensor]], dict[str, Tensor]] | None = None,
checksum: bool = False,
) -> None:
"""Initialize a new DeepGlobeLandCover dataset instance.
@ -229,7 +229,7 @@ class DeepGlobeLandCover(NonGeoDataset):
self,
sample: dict[str, Tensor],
show_titles: bool = True,
suptitle: Optional[str] = None,
suptitle: str | None = None,
alpha: float = 0.5,
) -> Figure:
"""Plot a sample from the dataset.

Просмотреть файл

@ -5,8 +5,7 @@
import glob
import os
from collections.abc import Sequence
from typing import Callable, Optional
from collections.abc import Callable, Sequence
import matplotlib.pyplot as plt
import numpy as np
@ -144,7 +143,7 @@ class DFC2022(NonGeoDataset):
self,
root: str = "data",
split: str = "train",
transforms: Optional[Callable[[dict[str, Tensor]], dict[str, Tensor]]] = None,
transforms: Callable[[dict[str, Tensor]], dict[str, Tensor]] | None = None,
checksum: bool = False,
) -> None:
"""Initialize a new DFC2022 dataset instance.
@ -229,7 +228,7 @@ class DFC2022(NonGeoDataset):
return files
def _load_image(self, path: str, shape: Optional[Sequence[int]] = None) -> Tensor:
def _load_image(self, path: str, shape: Sequence[int] | None = None) -> Tensor:
"""Load a single image.
Args:
@ -296,7 +295,7 @@ class DFC2022(NonGeoDataset):
self,
sample: dict[str, Tensor],
show_titles: bool = True,
suptitle: Optional[str] = None,
suptitle: str | None = None,
) -> Figure:
"""Plot a sample from the dataset.

Просмотреть файл

@ -5,8 +5,8 @@
import os
import sys
from collections.abc import Sequence
from typing import Any, Callable, Optional, cast
from collections.abc import Callable, Sequence
from typing import Any, cast
import fiona
import matplotlib.pyplot as plt
@ -255,7 +255,7 @@ class EnviroAtlas(GeoDataset):
root: str = "data",
splits: Sequence[str] = ["pittsburgh_pa-2010_1m-train"],
layers: Sequence[str] = ["naip", "prior"],
transforms: Optional[Callable[[dict[str, Any]], dict[str, Any]]] = None,
transforms: Callable[[dict[str, Any]], dict[str, Any]] | None = None,
prior_as_input: bool = False,
cache: bool = True,
download: bool = False,
@ -445,7 +445,7 @@ class EnviroAtlas(GeoDataset):
self,
sample: dict[str, Any],
show_titles: bool = True,
suptitle: Optional[str] = None,
suptitle: str | None = None,
) -> Figure:
"""Plot a sample from the dataset.

Просмотреть файл

@ -5,8 +5,8 @@
import glob
import os
from collections.abc import Iterable
from typing import Any, Callable, Optional, Union
from collections.abc import Callable, Iterable
from typing import Any
import matplotlib.pyplot as plt
from matplotlib.figure import Figure
@ -68,10 +68,10 @@ class Esri2020(RasterDataset):
def __init__(
self,
paths: Union[str, Iterable[str]] = "data",
crs: Optional[CRS] = None,
res: Optional[float] = None,
transforms: Optional[Callable[[dict[str, Any]], dict[str, Any]]] = None,
paths: str | Iterable[str] = "data",
crs: CRS | None = None,
res: float | None = None,
transforms: Callable[[dict[str, Any]], dict[str, Any]] | None = None,
cache: bool = True,
download: bool = False,
checksum: bool = False,
@ -138,7 +138,7 @@ class Esri2020(RasterDataset):
self,
sample: dict[str, Any],
show_titles: bool = True,
suptitle: Optional[str] = None,
suptitle: str | None = None,
) -> Figure:
"""Plot a sample from the dataset.

Просмотреть файл

@ -5,7 +5,7 @@
import glob
import os
from typing import Callable, Optional
from collections.abc import Callable
import matplotlib.pyplot as plt
import numpy as np
@ -82,7 +82,7 @@ class ETCI2021(NonGeoDataset):
self,
root: str = "data",
split: str = "train",
transforms: Optional[Callable[[dict[str, Tensor]], dict[str, Tensor]]] = None,
transforms: Callable[[dict[str, Tensor]], dict[str, Tensor]] | None = None,
download: bool = False,
checksum: bool = False,
) -> None:
@ -255,7 +255,7 @@ class ETCI2021(NonGeoDataset):
self,
sample: dict[str, Tensor],
show_titles: bool = True,
suptitle: Optional[str] = None,
suptitle: str | None = None,
) -> Figure:
"""Plot a sample from the dataset.

Просмотреть файл

@ -5,8 +5,8 @@
import glob
import os
from collections.abc import Iterable
from typing import Any, Callable, Optional, Union
from collections.abc import Callable, Iterable
from typing import Any
import matplotlib.pyplot as plt
from matplotlib.figure import Figure
@ -83,10 +83,10 @@ class EUDEM(RasterDataset):
def __init__(
self,
paths: Union[str, Iterable[str]] = "data",
crs: Optional[CRS] = None,
res: Optional[float] = None,
transforms: Optional[Callable[[dict[str, Any]], dict[str, Any]]] = None,
paths: str | Iterable[str] = "data",
crs: CRS | None = None,
res: float | None = None,
transforms: Callable[[dict[str, Any]], dict[str, Any]] | None = None,
cache: bool = True,
checksum: bool = False,
) -> None:
@ -140,7 +140,7 @@ class EUDEM(RasterDataset):
self,
sample: dict[str, Any],
show_titles: bool = True,
suptitle: Optional[str] = None,
suptitle: str | None = None,
) -> Figure:
"""Plot a sample from the dataset.

Просмотреть файл

@ -5,8 +5,8 @@
import csv
import os
from collections.abc import Iterable
from typing import Any, Callable, Optional, Union
from collections.abc import Callable, Iterable
from typing import Any
import fiona
import matplotlib.pyplot as plt
@ -88,11 +88,11 @@ class EuroCrops(VectorDataset):
def __init__(
self,
paths: Union[str, Iterable[str]] = "data",
paths: str | Iterable[str] = "data",
crs: CRS = CRS.from_epsg(4326),
res: float = 0.00001,
classes: Optional[list[str]] = None,
transforms: Optional[Callable[[dict[str, Any]], dict[str, Any]]] = None,
classes: list[str] | None = None,
transforms: Callable[[dict[str, Any]], dict[str, Any]] | None = None,
download: bool = False,
checksum: bool = False,
) -> None:
@ -166,7 +166,7 @@ class EuroCrops(VectorDataset):
self.base_url + fname, self.paths, md5=md5 if self.checksum else None
)
def _load_class_map(self, classes: Optional[list[str]]) -> None:
def _load_class_map(self, classes: list[str] | None) -> None:
"""Load map from HCAT class codes to class indices.
If classes is provided, then we simply use those codes. Otherwise, we load
@ -221,7 +221,7 @@ class EuroCrops(VectorDataset):
self,
sample: dict[str, Any],
show_titles: bool = True,
suptitle: Optional[str] = None,
suptitle: str | None = None,
) -> Figure:
"""Plot a sample from the dataset.

Просмотреть файл

@ -4,8 +4,8 @@
"""EuroSAT dataset."""
import os
from collections.abc import Sequence
from typing import Callable, Optional, cast
from collections.abc import Callable, Sequence
from typing import cast
import matplotlib.pyplot as plt
import numpy as np
@ -106,7 +106,7 @@ class EuroSAT(NonGeoClassificationDataset):
root: str = "data",
split: str = "train",
bands: Sequence[str] = BAND_SETS["all"],
transforms: Optional[Callable[[dict[str, Tensor]], dict[str, Tensor]]] = None,
transforms: Callable[[dict[str, Tensor]], dict[str, Tensor]] | None = None,
download: bool = False,
checksum: bool = False,
) -> None:
@ -247,7 +247,7 @@ class EuroSAT(NonGeoClassificationDataset):
self,
sample: dict[str, Tensor],
show_titles: bool = True,
suptitle: Optional[str] = None,
suptitle: str | None = None,
) -> Figure:
"""Plot a sample from the dataset.

Просмотреть файл

@ -5,7 +5,8 @@
import glob
import os
from typing import Any, Callable, Optional, cast
from collections.abc import Callable
from typing import Any, cast
from xml.etree.ElementTree import Element, parse
import matplotlib.patches as patches
@ -230,7 +231,7 @@ class FAIR1M(NonGeoDataset):
self,
root: str = "data",
split: str = "train",
transforms: Optional[Callable[[dict[str, Tensor]], dict[str, Tensor]]] = None,
transforms: Callable[[dict[str, Tensor]], dict[str, Tensor]] | None = None,
download: bool = False,
checksum: bool = False,
) -> None:
@ -382,7 +383,7 @@ class FAIR1M(NonGeoDataset):
self,
sample: dict[str, Tensor],
show_titles: bool = True,
suptitle: Optional[str] = None,
suptitle: str | None = None,
) -> Figure:
"""Plot a sample from the dataset.

Просмотреть файл

@ -4,7 +4,8 @@
"""FireRisk dataset."""
import os
from typing import Callable, Optional, cast
from collections.abc import Callable
from typing import cast
import matplotlib.pyplot as plt
from matplotlib.figure import Figure
@ -68,7 +69,7 @@ class FireRisk(NonGeoClassificationDataset):
self,
root: str = "data",
split: str = "train",
transforms: Optional[Callable[[dict[str, Tensor]], dict[str, Tensor]]] = None,
transforms: Callable[[dict[str, Tensor]], dict[str, Tensor]] | None = None,
download: bool = False,
checksum: bool = False,
) -> None:
@ -136,7 +137,7 @@ class FireRisk(NonGeoClassificationDataset):
self,
sample: dict[str, Tensor],
show_titles: bool = True,
suptitle: Optional[str] = None,
suptitle: str | None = None,
) -> Figure:
"""Plot a sample from the dataset.

Просмотреть файл

@ -5,7 +5,8 @@
import glob
import os
from typing import Any, Callable, Optional
from collections.abc import Callable
from typing import Any
from xml.etree import ElementTree
import matplotlib.patches as patches
@ -110,7 +111,7 @@ class ForestDamage(NonGeoDataset):
def __init__(
self,
root: str = "data",
transforms: Optional[Callable[[dict[str, Tensor]], dict[str, Tensor]]] = None,
transforms: Callable[[dict[str, Tensor]], dict[str, Tensor]] | None = None,
download: bool = False,
checksum: bool = False,
) -> None:
@ -259,7 +260,7 @@ class ForestDamage(NonGeoDataset):
self,
sample: dict[str, Tensor],
show_titles: bool = True,
suptitle: Optional[str] = None,
suptitle: str | None = None,
) -> Figure:
"""Plot a sample from the dataset.

Просмотреть файл

@ -10,8 +10,8 @@ import os
import re
import sys
import warnings
from collections.abc import Iterable, Sequence
from typing import Any, Callable, Optional, Union, cast
from collections.abc import Callable, Iterable, Sequence
from typing import Any, cast
import fiona
import fiona.transform
@ -83,7 +83,7 @@ class GeoDataset(Dataset[dict[str, Any]], abc.ABC):
dataset = landsat7 | landsat8
"""
paths: Union[str, Iterable[str]]
paths: str | Iterable[str]
_crs = CRS.from_epsg(4326)
_res = 0.0
@ -108,7 +108,7 @@ class GeoDataset(Dataset[dict[str, Any]], abc.ABC):
__add__ = None # type: ignore[assignment]
def __init__(
self, transforms: Optional[Callable[[dict[str, Any]], dict[str, Any]]] = None
self, transforms: Callable[[dict[str, Any]], dict[str, Any]] | None = None
) -> None:
"""Initialize a new GeoDataset instance.
@ -190,9 +190,7 @@ class GeoDataset(Dataset[dict[str, Any]], abc.ABC):
# NOTE: This hack should be removed once the following issue is fixed:
# https://github.com/Toblerity/rtree/issues/87
def __getstate__(
self,
) -> tuple[dict[str, Any], list[tuple[Any, Any, Optional[Any]]]]:
def __getstate__(self) -> tuple[dict[str, Any], list[tuple[Any, Any, Any | None]]]:
"""Define how instances are pickled.
Returns:
@ -388,11 +386,11 @@ class RasterDataset(GeoDataset):
def __init__(
self,
paths: Union[str, Iterable[str]] = "data",
crs: Optional[CRS] = None,
res: Optional[float] = None,
bands: Optional[Sequence[str]] = None,
transforms: Optional[Callable[[dict[str, Any]], dict[str, Any]]] = None,
paths: str | Iterable[str] = "data",
crs: CRS | None = None,
res: float | None = None,
bands: Sequence[str] | None = None,
transforms: Callable[[dict[str, Any]], dict[str, Any]] | None = None,
cache: bool = True,
) -> None:
"""Initialize a new RasterDataset instance.
@ -539,7 +537,7 @@ class RasterDataset(GeoDataset):
self,
filepaths: Sequence[str],
query: BoundingBox,
band_indexes: Optional[Sequence[int]] = None,
band_indexes: Sequence[int] | None = None,
) -> Tensor:
"""Load and merge one or more files.
@ -612,11 +610,11 @@ class VectorDataset(GeoDataset):
def __init__(
self,
paths: Union[str, Iterable[str]] = "data",
crs: Optional[CRS] = None,
paths: str | Iterable[str] = "data",
crs: CRS | None = None,
res: float = 0.0001,
transforms: Optional[Callable[[dict[str, Any]], dict[str, Any]]] = None,
label_name: Optional[str] = None,
transforms: Callable[[dict[str, Any]], dict[str, Any]] | None = None,
label_name: str | None = None,
) -> None:
"""Initialize a new VectorDataset instance.
@ -809,9 +807,9 @@ class NonGeoClassificationDataset(NonGeoDataset, ImageFolder): # type: ignore[m
def __init__(
self,
root: str = "data",
transforms: Optional[Callable[[dict[str, Tensor]], dict[str, Tensor]]] = None,
loader: Optional[Callable[[str], Any]] = pil_loader,
is_valid_file: Optional[Callable[[str], bool]] = None,
transforms: Callable[[dict[str, Tensor]], dict[str, Tensor]] | None = None,
loader: Callable[[str], Any] | None = pil_loader,
is_valid_file: Callable[[str], bool] | None = None,
) -> None:
"""Initialize a new NonGeoClassificationDataset instance.
@ -909,7 +907,7 @@ class IntersectionDataset(GeoDataset):
collate_fn: Callable[
[Sequence[dict[str, Any]]], dict[str, Any]
] = concat_samples,
transforms: Optional[Callable[[dict[str, Any]], dict[str, Any]]] = None,
transforms: Callable[[dict[str, Any]], dict[str, Any]] | None = None,
) -> None:
"""Initialize a new IntersectionDataset instance.
@ -1067,7 +1065,7 @@ class UnionDataset(GeoDataset):
collate_fn: Callable[
[Sequence[dict[str, Any]]], dict[str, Any]
] = merge_samples,
transforms: Optional[Callable[[dict[str, Any]], dict[str, Any]]] = None,
transforms: Callable[[dict[str, Any]], dict[str, Any]] | None = None,
) -> None:
"""Initialize a new UnionDataset instance.

Просмотреть файл

@ -5,7 +5,7 @@
import glob
import os
from typing import Callable, Optional
from collections.abc import Callable
import matplotlib.pyplot as plt
import numpy as np
@ -89,7 +89,7 @@ class GID15(NonGeoDataset):
self,
root: str = "data",
split: str = "train",
transforms: Optional[Callable[[dict[str, Tensor]], dict[str, Tensor]]] = None,
transforms: Callable[[dict[str, Tensor]], dict[str, Tensor]] | None = None,
download: bool = False,
checksum: bool = False,
) -> None:
@ -234,7 +234,7 @@ class GID15(NonGeoDataset):
md5=self.md5 if self.checksum else None,
)
def plot(self, sample: dict[str, Tensor], suptitle: Optional[str] = None) -> Figure:
def plot(self, sample: dict[str, Tensor], suptitle: str | None = None) -> Figure:
"""Plot a sample from the dataset.
Args:

Просмотреть файл

@ -5,8 +5,8 @@
import glob
import os
from collections.abc import Iterable
from typing import Any, Callable, Optional, Union, cast
from collections.abc import Callable, Iterable
from typing import Any, cast
import matplotlib.pyplot as plt
import torch
@ -119,11 +119,11 @@ class GlobBiomass(RasterDataset):
def __init__(
self,
paths: Union[str, Iterable[str]] = "data",
crs: Optional[CRS] = None,
res: Optional[float] = None,
paths: str | Iterable[str] = "data",
crs: CRS | None = None,
res: float | None = None,
measurement: str = "agb",
transforms: Optional[Callable[[dict[str, Any]], dict[str, Any]]] = None,
transforms: Callable[[dict[str, Any]], dict[str, Any]] | None = None,
cache: bool = True,
checksum: bool = False,
) -> None:
@ -225,7 +225,7 @@ class GlobBiomass(RasterDataset):
self,
sample: dict[str, Any],
show_titles: bool = True,
suptitle: Optional[str] = None,
suptitle: str | None = None,
) -> Figure:
"""Plot a sample from the dataset.

Просмотреть файл

@ -5,7 +5,8 @@
import glob
import os
from typing import Any, Callable, Optional, cast, overload
from collections.abc import Callable
from typing import Any, cast, overload
import fiona
import matplotlib.pyplot as plt
@ -148,7 +149,7 @@ class IDTReeS(NonGeoDataset):
root: str = "data",
split: str = "train",
task: str = "task1",
transforms: Optional[Callable[[dict[str, Tensor]], dict[str, Tensor]]] = None,
transforms: Callable[[dict[str, Tensor]], dict[str, Tensor]] | None = None,
download: bool = False,
checksum: bool = False,
) -> None:
@ -333,7 +334,7 @@ class IDTReeS(NonGeoDataset):
def _load(
self, root: str
) -> tuple[list[str], Optional[dict[int, dict[str, Any]]], Any]:
) -> tuple[list[str], dict[int, dict[str, Any]] | None, Any]:
"""Load files, geometries, and labels.
Args:
@ -419,8 +420,8 @@ class IDTReeS(NonGeoDataset):
image_size: tuple[int, int],
min_size: int,
boxes: Tensor,
labels: Optional[Tensor],
) -> tuple[Tensor, Optional[Tensor]]:
labels: Tensor | None,
) -> tuple[Tensor, Tensor | None]:
"""Clip boxes to image size and filter boxes with sides less than ``min_size``.
Args:
@ -477,7 +478,7 @@ class IDTReeS(NonGeoDataset):
self,
sample: dict[str, Tensor],
show_titles: bool = True,
suptitle: Optional[str] = None,
suptitle: str | None = None,
hsi_indices: tuple[int, int, int] = (0, 1, 2),
) -> Figure:
"""Plot a sample from the dataset.

Просмотреть файл

@ -6,7 +6,8 @@
import glob
import os
import re
from typing import Any, Callable, Optional
from collections.abc import Callable
from typing import Any
import matplotlib.pyplot as plt
import numpy as np
@ -64,7 +65,7 @@ class InriaAerialImageLabeling(NonGeoDataset):
self,
root: str = "data",
split: str = "train",
transforms: Optional[Callable[[dict[str, Any]], dict[str, Any]]] = None,
transforms: Callable[[dict[str, Any]], dict[str, Any]] | None = None,
checksum: bool = False,
) -> None:
"""Initialize a new InriaAerialImageLabeling Dataset instance.
@ -200,7 +201,7 @@ class InriaAerialImageLabeling(NonGeoDataset):
self,
sample: dict[str, Tensor],
show_titles: bool = True,
suptitle: Optional[str] = None,
suptitle: str | None = None,
) -> Figure:
"""Plot a sample from the dataset.

Просмотреть файл

@ -5,8 +5,8 @@
import glob
import os
from collections.abc import Iterable, Sequence
from typing import Any, Callable, Optional, Union, cast
from collections.abc import Callable, Iterable, Sequence
from typing import Any, cast
import matplotlib.pyplot as plt
from matplotlib.figure import Figure
@ -97,11 +97,11 @@ class L7Irish(RasterDataset):
def __init__(
self,
paths: Union[str, Iterable[str]] = "data",
crs: Optional[CRS] = CRS.from_epsg(3857),
res: Optional[float] = None,
paths: str | Iterable[str] = "data",
crs: CRS | None = CRS.from_epsg(3857),
res: float | None = None,
bands: Sequence[str] = all_bands,
transforms: Optional[Callable[[dict[str, Any]], dict[str, Any]]] = None,
transforms: Callable[[dict[str, Any]], dict[str, Any]] | None = None,
cache: bool = True,
download: bool = False,
checksum: bool = False,
@ -221,7 +221,7 @@ class L7Irish(RasterDataset):
self,
sample: dict[str, Tensor],
show_titles: bool = True,
suptitle: Optional[str] = None,
suptitle: str | None = None,
) -> Figure:
"""Plot a sample from the dataset.

Просмотреть файл

@ -5,8 +5,8 @@
import glob
import os
from collections.abc import Iterable, Sequence
from typing import Any, Callable, Optional, Union, cast
from collections.abc import Callable, Iterable, Sequence
from typing import Any, cast
import matplotlib.pyplot as plt
from matplotlib.figure import Figure
@ -96,11 +96,11 @@ class L8Biome(RasterDataset):
def __init__(
self,
paths: Union[str, Iterable[str]],
crs: Optional[CRS] = CRS.from_epsg(3857),
res: Optional[float] = None,
paths: str | Iterable[str],
crs: CRS | None = CRS.from_epsg(3857),
res: float | None = None,
bands: Sequence[str] = all_bands,
transforms: Optional[Callable[[dict[str, Any]], dict[str, Any]]] = None,
transforms: Callable[[dict[str, Any]], dict[str, Any]] | None = None,
cache: bool = True,
download: bool = False,
checksum: bool = False,
@ -217,7 +217,7 @@ class L8Biome(RasterDataset):
self,
sample: dict[str, Tensor],
show_titles: bool = True,
suptitle: Optional[str] = None,
suptitle: str | None = None,
) -> Figure:
"""Plot a sample from the dataset.

Просмотреть файл

@ -6,8 +6,9 @@ import abc
import glob
import hashlib
import os
from collections.abc import Callable
from functools import lru_cache
from typing import Any, Callable, Optional, cast
from typing import Any, cast
import matplotlib.pyplot as plt
import numpy as np
@ -152,7 +153,7 @@ class LandCoverAIBase(Dataset[dict[str, Any]], abc.ABC):
self,
sample: dict[str, Tensor],
show_titles: bool = True,
suptitle: Optional[str] = None,
suptitle: str | None = None,
) -> Figure:
"""Plot a sample from the dataset.
@ -209,9 +210,9 @@ class LandCoverAIGeo(LandCoverAIBase, RasterDataset):
def __init__(
self,
root: str = "data",
crs: Optional[CRS] = None,
res: Optional[float] = None,
transforms: Optional[Callable[[dict[str, Any]], dict[str, Any]]] = None,
crs: CRS | None = None,
res: float | None = None,
transforms: Callable[[dict[str, Any]], dict[str, Any]] | None = None,
cache: bool = True,
download: bool = False,
checksum: bool = False,
@ -299,7 +300,7 @@ class LandCoverAI(LandCoverAIBase, NonGeoDataset):
self,
root: str = "data",
split: str = "train",
transforms: Optional[Callable[[dict[str, Tensor]], dict[str, Tensor]]] = None,
transforms: Callable[[dict[str, Tensor]], dict[str, Tensor]] | None = None,
download: bool = False,
checksum: bool = False,
) -> None:

Просмотреть файл

@ -4,8 +4,8 @@
"""Landsat datasets."""
import abc
from collections.abc import Iterable, Sequence
from typing import Any, Callable, Optional, Union
from collections.abc import Callable, Iterable, Sequence
from typing import Any
import matplotlib.pyplot as plt
from matplotlib.figure import Figure
@ -59,11 +59,11 @@ class Landsat(RasterDataset, abc.ABC):
def __init__(
self,
paths: Union[str, Iterable[str]] = "data",
crs: Optional[CRS] = None,
res: Optional[float] = None,
bands: Optional[Sequence[str]] = None,
transforms: Optional[Callable[[dict[str, Any]], dict[str, Any]]] = None,
paths: str | Iterable[str] = "data",
crs: CRS | None = None,
res: float | None = None,
bands: Sequence[str] | None = None,
transforms: Callable[[dict[str, Any]], dict[str, Any]] | None = None,
cache: bool = True,
) -> None:
"""Initialize a new Dataset instance.
@ -94,7 +94,7 @@ class Landsat(RasterDataset, abc.ABC):
self,
sample: dict[str, Any],
show_titles: bool = True,
suptitle: Optional[str] = None,
suptitle: str | None = None,
) -> Figure:
"""Plot a sample from the dataset.

Просмотреть файл

@ -6,7 +6,7 @@
import abc
import glob
import os
from typing import Callable, Optional, Union
from collections.abc import Callable
import matplotlib.pyplot as plt
import numpy as np
@ -29,14 +29,14 @@ class LEVIRCDBase(NonGeoDataset, abc.ABC):
.. versionadded:: 0.6
"""
splits: Union[list[str], dict[str, dict[str, str]]]
splits: list[str] | dict[str, dict[str, str]]
directories = ["A", "B", "label"]
def __init__(
self,
root: str = "data",
split: str = "train",
transforms: Optional[Callable[[dict[str, Tensor]], dict[str, Tensor]]] = None,
transforms: Callable[[dict[str, Tensor]], dict[str, Tensor]] | None = None,
download: bool = False,
checksum: bool = False,
) -> None:
@ -135,7 +135,7 @@ class LEVIRCDBase(NonGeoDataset, abc.ABC):
self,
sample: dict[str, Tensor],
show_titles: bool = True,
suptitle: Optional[str] = None,
suptitle: str | None = None,
) -> Figure:
"""Plot a sample from the dataset.

Просмотреть файл

@ -5,7 +5,7 @@
import glob
import os
from typing import Callable, Optional
from collections.abc import Callable
import matplotlib.pyplot as plt
import numpy as np
@ -93,7 +93,7 @@ class LoveDA(NonGeoDataset):
root: str = "data",
split: str = "train",
scene: list[str] = ["urban", "rural"],
transforms: Optional[Callable[[dict[str, Tensor]], dict[str, Tensor]]] = None,
transforms: Callable[[dict[str, Tensor]], dict[str, Tensor]] | None = None,
download: bool = False,
checksum: bool = False,
) -> None:
@ -256,7 +256,7 @@ class LoveDA(NonGeoDataset):
md5=self.md5 if self.checksum else None,
)
def plot(self, sample: dict[str, Tensor], suptitle: Optional[str] = None) -> Figure:
def plot(self, sample: dict[str, Tensor], suptitle: str | None = None) -> Figure:
"""Plot a sample from the dataset.
Args:

Просмотреть файл

@ -6,7 +6,7 @@
import os
import shutil
from collections import defaultdict
from typing import Callable, Optional
from collections.abc import Callable
import matplotlib.pyplot as plt
import numpy as np
@ -111,7 +111,7 @@ class MapInWild(NonGeoDataset):
root: str = "data",
modality: list[str] = ["mask", "esa_wc", "viirs", "s2_summer"],
split: str = "train",
transforms: Optional[Callable[[dict[str, Tensor]], dict[str, Tensor]]] = None,
transforms: Callable[[dict[str, Tensor]], dict[str, Tensor]] | None = None,
download: bool = False,
checksum: bool = False,
) -> None:
@ -223,7 +223,7 @@ class MapInWild(NonGeoDataset):
tensor = torch.from_numpy(array).float()
return tensor
def _verify(self, url: str, md5: Optional[str] = None) -> None:
def _verify(self, url: str, md5: str | None = None) -> None:
"""Verify the integrity of the dataset.
Args:
@ -258,7 +258,7 @@ class MapInWild(NonGeoDataset):
if not url.endswith(".csv"):
self._extract(url)
def _download(self, url: str, md5: Optional[str]) -> None:
def _download(self, url: str, md5: str | None) -> None:
"""Downloads a modality.
Args:
@ -330,7 +330,7 @@ class MapInWild(NonGeoDataset):
self,
sample: dict[str, Tensor],
show_titles: bool = True,
suptitle: Optional[str] = None,
suptitle: str | None = None,
) -> Figure:
"""Plot a sample from the dataset.

Просмотреть файл

@ -4,7 +4,8 @@
"""Million-AID dataset."""
import glob
import os
from typing import Any, Callable, Optional, cast
from collections.abc import Callable
from typing import Any, cast
import matplotlib.pyplot as plt
import numpy as np
@ -191,7 +192,7 @@ class MillionAID(NonGeoDataset):
root: str = "data",
task: str = "multi-class",
split: str = "train",
transforms: Optional[Callable[[dict[str, Tensor]], dict[str, Tensor]]] = None,
transforms: Callable[[dict[str, Tensor]], dict[str, Tensor]] | None = None,
checksum: bool = False,
) -> None:
"""Initialize a new MillionAID dataset instance.
@ -332,7 +333,7 @@ class MillionAID(NonGeoDataset):
self,
sample: dict[str, Tensor],
show_titles: bool = True,
suptitle: Optional[str] = None,
suptitle: str | None = None,
) -> Figure:
"""Plot a sample from the dataset.

Просмотреть файл

@ -3,7 +3,7 @@
"""National Agriculture Imagery Program (NAIP) dataset."""
from typing import Any, Optional
from typing import Any
import matplotlib.pyplot as plt
from matplotlib.figure import Figure
@ -52,7 +52,7 @@ class NAIP(RasterDataset):
self,
sample: dict[str, Any],
show_titles: bool = True,
suptitle: Optional[str] = None,
suptitle: str | None = None,
) -> Figure:
"""Plot a sample from the dataset.

Просмотреть файл

@ -4,7 +4,7 @@
"""NASA Marine Debris dataset."""
import os
from typing import Callable, Optional
from collections.abc import Callable
import matplotlib.pyplot as plt
import numpy as np
@ -66,9 +66,9 @@ class NASAMarineDebris(NonGeoDataset):
def __init__(
self,
root: str = "data",
transforms: Optional[Callable[[dict[str, Tensor]], dict[str, Tensor]]] = None,
transforms: Callable[[dict[str, Tensor]], dict[str, Tensor]] | None = None,
download: bool = False,
api_key: Optional[str] = None,
api_key: str | None = None,
checksum: bool = False,
verbose: bool = False,
) -> None:
@ -224,7 +224,7 @@ class NASAMarineDebris(NonGeoDataset):
self,
sample: dict[str, Tensor],
show_titles: bool = True,
suptitle: Optional[str] = None,
suptitle: str | None = None,
) -> Figure:
"""Plot a sample from the dataset.

Просмотреть файл

@ -3,8 +3,8 @@
"""Northeastern China Crop Map Dataset."""
from collections.abc import Iterable
from typing import Any, Callable, Optional, Union
from collections.abc import Callable, Iterable
from typing import Any
import matplotlib.pyplot as plt
import torch
@ -82,11 +82,11 @@ class NCCM(RasterDataset):
def __init__(
self,
paths: Union[str, Iterable[str]] = "data",
crs: Optional[CRS] = None,
res: Optional[float] = None,
paths: str | Iterable[str] = "data",
crs: CRS | None = None,
res: float | None = None,
years: list[int] = [2019],
transforms: Optional[Callable[[dict[str, Any]], dict[str, Any]]] = None,
transforms: Callable[[dict[str, Any]], dict[str, Any]] | None = None,
cache: bool = True,
download: bool = False,
checksum: bool = False,
@ -170,7 +170,7 @@ class NCCM(RasterDataset):
self,
sample: dict[str, Any],
show_titles: bool = True,
suptitle: Optional[str] = None,
suptitle: str | None = None,
) -> Figure:
"""Plot a sample from the dataset.

Просмотреть файл

@ -5,8 +5,8 @@
import glob
import os
from collections.abc import Iterable
from typing import Any, Callable, Optional, Union
from collections.abc import Callable, Iterable
from typing import Any
import matplotlib.pyplot as plt
import torch
@ -107,12 +107,12 @@ class NLCD(RasterDataset):
def __init__(
self,
paths: Union[str, Iterable[str]] = "data",
crs: Optional[CRS] = None,
res: Optional[float] = None,
paths: str | Iterable[str] = "data",
crs: CRS | None = None,
res: float | None = None,
years: list[int] = [2019],
classes: list[int] = list(cmap.keys()),
transforms: Optional[Callable[[dict[str, Any]], dict[str, Any]]] = None,
transforms: Callable[[dict[str, Any]], dict[str, Any]] | None = None,
cache: bool = True,
download: bool = False,
checksum: bool = False,
@ -230,7 +230,7 @@ class NLCD(RasterDataset):
self,
sample: dict[str, Any],
show_titles: bool = True,
suptitle: Optional[str] = None,
suptitle: str | None = None,
) -> Figure:
"""Plot a sample from the dataset.

Просмотреть файл

@ -7,8 +7,8 @@ import glob
import json
import os
import sys
from collections.abc import Iterable
from typing import Any, Callable, Optional, Union, cast
from collections.abc import Callable, Iterable
from typing import Any, cast
import fiona
import fiona.transform
@ -206,10 +206,10 @@ class OpenBuildings(VectorDataset):
def __init__(
self,
paths: Union[str, Iterable[str]] = "data",
crs: Optional[CRS] = None,
paths: str | Iterable[str] = "data",
crs: CRS | None = None,
res: float = 0.0001,
transforms: Optional[Callable[[dict[str, Any]], dict[str, Any]]] = None,
transforms: Callable[[dict[str, Any]], dict[str, Any]] | None = None,
checksum: bool = False,
) -> None:
"""Initialize a new Dataset instance.
@ -413,7 +413,7 @@ class OpenBuildings(VectorDataset):
self,
sample: dict[str, Any],
show_titles: bool = True,
suptitle: Optional[str] = None,
suptitle: str | None = None,
) -> Figure:
"""Plot a sample from the dataset.

Просмотреть файл

@ -5,8 +5,7 @@
import glob
import os
from collections.abc import Sequence
from typing import Callable, Optional, Union
from collections.abc import Callable, Sequence
import matplotlib.pyplot as plt
import numpy as np
@ -103,7 +102,7 @@ class OSCD(NonGeoDataset):
root: str = "data",
split: str = "train",
bands: Sequence[str] = all_bands,
transforms: Optional[Callable[[dict[str, Tensor]], dict[str, Tensor]]] = None,
transforms: Callable[[dict[str, Tensor]], dict[str, Tensor]] | None = None,
download: bool = False,
checksum: bool = False,
) -> None:
@ -164,7 +163,7 @@ class OSCD(NonGeoDataset):
"""
return len(self.files)
def _load_files(self) -> list[dict[str, Union[str, Sequence[str]]]]:
def _load_files(self) -> list[dict[str, str | Sequence[str]]]:
regions = []
labels_root = os.path.join(
self.root,
@ -284,7 +283,7 @@ class OSCD(NonGeoDataset):
self,
sample: dict[str, Tensor],
show_titles: bool = True,
suptitle: Optional[str] = None,
suptitle: str | None = None,
alpha: float = 0.5,
) -> Figure:
"""Plot a sample from the dataset.

Просмотреть файл

@ -4,8 +4,7 @@
"""PASTIS dataset."""
import os
from collections.abc import Sequence
from typing import Callable, Optional
from collections.abc import Callable, Sequence
import fiona
import matplotlib.pyplot as plt
@ -132,7 +131,7 @@ class PASTIS(NonGeoDataset):
folds: Sequence[int] = (1, 2, 3, 4, 5),
bands: str = "s2",
mode: str = "semantic",
transforms: Optional[Callable[[dict[str, Tensor]], dict[str, Tensor]]] = None,
transforms: Callable[[dict[str, Tensor]], dict[str, Tensor]] | None = None,
download: bool = False,
checksum: bool = False,
) -> None:
@ -347,7 +346,7 @@ class PASTIS(NonGeoDataset):
self,
sample: dict[str, Tensor],
show_titles: bool = True,
suptitle: Optional[str] = None,
suptitle: str | None = None,
) -> Figure:
"""Plot a sample from the dataset.

Просмотреть файл

@ -4,7 +4,8 @@
"""PatternNet dataset."""
import os
from typing import Callable, Optional, cast
from collections.abc import Callable
from typing import cast
import matplotlib.pyplot as plt
from matplotlib.figure import Figure
@ -84,7 +85,7 @@ class PatternNet(NonGeoClassificationDataset):
def __init__(
self,
root: str = "data",
transforms: Optional[Callable[[dict[str, Tensor]], dict[str, Tensor]]] = None,
transforms: Callable[[dict[str, Tensor]], dict[str, Tensor]] | None = None,
download: bool = False,
checksum: bool = False,
) -> None:
@ -145,7 +146,7 @@ class PatternNet(NonGeoClassificationDataset):
self,
sample: dict[str, Tensor],
show_titles: bool = True,
suptitle: Optional[str] = None,
suptitle: str | None = None,
) -> Figure:
"""Plot a sample from the dataset.

Просмотреть файл

@ -4,7 +4,7 @@
"""Potsdam dataset."""
import os
from typing import Callable, Optional
from collections.abc import Callable
import matplotlib.pyplot as plt
import numpy as np
@ -123,7 +123,7 @@ class Potsdam2D(NonGeoDataset):
self,
root: str = "data",
split: str = "train",
transforms: Optional[Callable[[dict[str, Tensor]], dict[str, Tensor]]] = None,
transforms: Callable[[dict[str, Tensor]], dict[str, Tensor]] | None = None,
checksum: bool = False,
) -> None:
"""Initialize a new Potsdam dataset instance.
@ -240,7 +240,7 @@ class Potsdam2D(NonGeoDataset):
self,
sample: dict[str, Tensor],
show_titles: bool = True,
suptitle: Optional[str] = None,
suptitle: str | None = None,
alpha: float = 0.5,
) -> Figure:
"""Plot a sample from the dataset.

Просмотреть файл

@ -3,7 +3,7 @@
"""PRISMA datasets."""
from typing import Any, Optional
from typing import Any
import matplotlib.pyplot as plt
from matplotlib.figure import Figure
@ -81,7 +81,7 @@ class PRISMA(RasterDataset):
self,
sample: dict[str, Any],
show_titles: bool = True,
suptitle: Optional[str] = None,
suptitle: str | None = None,
) -> Figure:
"""Plot a sample from the dataset.

Просмотреть файл

@ -5,7 +5,7 @@
import glob
import os
from typing import Callable, Optional
from collections.abc import Callable
import matplotlib.patches as patches
import matplotlib.pyplot as plt
@ -69,7 +69,7 @@ class ReforesTree(NonGeoDataset):
def __init__(
self,
root: str = "data",
transforms: Optional[Callable[[dict[str, Tensor]], dict[str, Tensor]]] = None,
transforms: Callable[[dict[str, Tensor]], dict[str, Tensor]] | None = None,
download: bool = False,
checksum: bool = False,
) -> None:
@ -209,7 +209,7 @@ class ReforesTree(NonGeoDataset):
self,
sample: dict[str, Tensor],
show_titles: bool = True,
suptitle: Optional[str] = None,
suptitle: str | None = None,
) -> Figure:
"""Plot a sample from the dataset.

Просмотреть файл

@ -4,7 +4,8 @@
"""RESISC45 dataset."""
import os
from typing import Callable, Optional, cast
from collections.abc import Callable
from typing import cast
import matplotlib.pyplot as plt
import numpy as np
@ -112,7 +113,7 @@ class RESISC45(NonGeoClassificationDataset):
self,
root: str = "data",
split: str = "train",
transforms: Optional[Callable[[dict[str, Tensor]], dict[str, Tensor]]] = None,
transforms: Callable[[dict[str, Tensor]], dict[str, Tensor]] | None = None,
download: bool = False,
checksum: bool = False,
) -> None:
@ -193,7 +194,7 @@ class RESISC45(NonGeoClassificationDataset):
self,
sample: dict[str, Tensor],
show_titles: bool = True,
suptitle: Optional[str] = None,
suptitle: str | None = None,
) -> Figure:
"""Plot a sample from the dataset.

Просмотреть файл

@ -4,8 +4,7 @@
"""Rwanda Field Boundary Competition dataset."""
import os
from collections.abc import Sequence
from typing import Callable, Optional
from collections.abc import Callable, Sequence
import matplotlib.pyplot as plt
import numpy as np
@ -91,9 +90,9 @@ class RwandaFieldBoundary(NonGeoDataset):
root: str = "data",
split: str = "train",
bands: Sequence[str] = all_bands,
transforms: Optional[Callable[[dict[str, Tensor]], dict[str, Tensor]]] = None,
transforms: Callable[[dict[str, Tensor]], dict[str, Tensor]] | None = None,
download: bool = False,
api_key: Optional[str] = None,
api_key: str | None = None,
checksum: bool = False,
) -> None:
"""Initialize a new RwandaFieldBoundary instance.
@ -258,7 +257,7 @@ class RwandaFieldBoundary(NonGeoDataset):
sample: dict[str, Tensor],
show_titles: bool = True,
time_step: int = 0,
suptitle: Optional[str] = None,
suptitle: str | None = None,
) -> Figure:
"""Plot a sample from the dataset.

Просмотреть файл

@ -6,7 +6,6 @@
import os
import random
from collections.abc import Callable, Collection, Iterable
from typing import Optional
import matplotlib.patches as mpatches
import matplotlib.pyplot as plt
@ -219,7 +218,7 @@ class SeasoNet(NonGeoDataset):
bands: Iterable[str] = all_bands,
grids: Iterable[int] = [1, 2],
concat_seasons: int = 1,
transforms: Optional[Callable[[dict[str, Tensor]], dict[str, Tensor]]] = None,
transforms: Callable[[dict[str, Tensor]], dict[str, Tensor]] | None = None,
download: bool = False,
checksum: bool = False,
) -> None:
@ -405,7 +404,7 @@ class SeasoNet(NonGeoDataset):
sample: dict[str, Tensor],
show_titles: bool = True,
show_legend: bool = True,
suptitle: Optional[str] = None,
suptitle: str | None = None,
) -> Figure:
"""Plot a sample from the dataset.

Просмотреть файл

@ -5,7 +5,7 @@
import os
import random
from typing import Callable, Optional
from collections.abc import Callable
import matplotlib.pyplot as plt
import numpy as np
@ -79,7 +79,7 @@ class SeasonalContrastS2(NonGeoDataset):
version: str = "100k",
seasons: int = 1,
bands: list[str] = rgb_bands,
transforms: Optional[Callable[[dict[str, Tensor]], dict[str, Tensor]]] = None,
transforms: Callable[[dict[str, Tensor]], dict[str, Tensor]] | None = None,
download: bool = False,
checksum: bool = False,
) -> None:
@ -231,7 +231,7 @@ class SeasonalContrastS2(NonGeoDataset):
self,
sample: dict[str, Tensor],
show_titles: bool = True,
suptitle: Optional[str] = None,
suptitle: str | None = None,
) -> Figure:
"""Plot a sample from the dataset.

Просмотреть файл

@ -4,8 +4,7 @@
"""SEN12MS dataset."""
import os
from collections.abc import Sequence
from typing import Callable, Optional
from collections.abc import Callable, Sequence
import matplotlib.pyplot as plt
import numpy as np
@ -173,7 +172,7 @@ class SEN12MS(NonGeoDataset):
root: str = "data",
split: str = "train",
bands: Sequence[str] = BAND_SETS["all"],
transforms: Optional[Callable[[dict[str, Tensor]], dict[str, Tensor]]] = None,
transforms: Callable[[dict[str, Tensor]], dict[str, Tensor]] | None = None,
checksum: bool = False,
) -> None:
"""Initialize a new SEN12MS dataset instance.
@ -320,7 +319,7 @@ class SEN12MS(NonGeoDataset):
self,
sample: dict[str, Tensor],
show_titles: bool = True,
suptitle: Optional[str] = None,
suptitle: str | None = None,
) -> Figure:
"""Plot a sample from the dataset.

Просмотреть файл

@ -3,8 +3,8 @@
"""Sentinel datasets."""
from collections.abc import Iterable, Sequence
from typing import Any, Callable, Optional, Union
from collections.abc import Callable, Iterable, Sequence
from typing import Any
import matplotlib.pyplot as plt
import torch
@ -141,11 +141,11 @@ class Sentinel1(Sentinel):
def __init__(
self,
paths: Union[str, list[str]] = "data",
crs: Optional[CRS] = None,
paths: str | list[str] = "data",
crs: CRS | None = None,
res: float = 10,
bands: Sequence[str] = ["VV", "VH"],
transforms: Optional[Callable[[dict[str, Any]], dict[str, Any]]] = None,
transforms: Callable[[dict[str, Any]], dict[str, Any]] | None = None,
cache: bool = True,
) -> None:
"""Initialize a new Dataset instance.
@ -194,7 +194,7 @@ To create a dataset containing both, use:
self,
sample: dict[str, Any],
show_titles: bool = True,
suptitle: Optional[str] = None,
suptitle: str | None = None,
) -> Figure:
"""Plot a sample from the dataset.
@ -297,11 +297,11 @@ class Sentinel2(Sentinel):
def __init__(
self,
paths: Union[str, Iterable[str]] = "data",
crs: Optional[CRS] = None,
paths: str | Iterable[str] = "data",
crs: CRS | None = None,
res: float = 10,
bands: Optional[Sequence[str]] = None,
transforms: Optional[Callable[[dict[str, Any]], dict[str, Any]]] = None,
bands: Sequence[str] | None = None,
transforms: Callable[[dict[str, Any]], dict[str, Any]] | None = None,
cache: bool = True,
) -> None:
"""Initialize a new Dataset instance.
@ -333,7 +333,7 @@ class Sentinel2(Sentinel):
self,
sample: dict[str, Any],
show_titles: bool = True,
suptitle: Optional[str] = None,
suptitle: str | None = None,
) -> Figure:
"""Plot a sample from the dataset.

Просмотреть файл

@ -4,7 +4,8 @@
"""SKy Images and Photovoltaic Power Dataset (SKIPP'D)."""
import os
from typing import Any, Callable, Optional, Union
from collections.abc import Callable
from typing import Any
import matplotlib.pyplot as plt
import numpy as np
@ -74,7 +75,7 @@ class SKIPPD(NonGeoDataset):
root: str = "data",
split: str = "trainval",
task: str = "nowcast",
transforms: Optional[Callable[[dict[str, Any]], dict[str, Any]]] = None,
transforms: Callable[[dict[str, Any]], dict[str, Any]] | None = None,
download: bool = False,
checksum: bool = False,
) -> None:
@ -133,7 +134,7 @@ class SKIPPD(NonGeoDataset):
return num_datapoints
def __getitem__(self, index: int) -> dict[str, Union[str, Tensor]]:
def __getitem__(self, index: int) -> dict[str, str | Tensor]:
"""Return an index within the dataset.
Args:
@ -142,7 +143,7 @@ class SKIPPD(NonGeoDataset):
Returns:
data and label at that index
"""
sample: dict[str, Union[str, Tensor]] = {"image": self._load_image(index)}
sample: dict[str, str | Tensor] = {"image": self._load_image(index)}
sample.update(self._load_features(index))
if self.transforms is not None:
@ -176,7 +177,7 @@ class SKIPPD(NonGeoDataset):
tensor = torch.from_numpy(arr).to(torch.float32)
return tensor
def _load_features(self, index: int) -> dict[str, Union[str, Tensor]]:
def _load_features(self, index: int) -> dict[str, str | Tensor]:
"""Load label.
Args:
@ -195,7 +196,7 @@ class SKIPPD(NonGeoDataset):
path = os.path.join(self.root, f"times_{self.split}_{self.task}.npy")
datestring = np.load(path, allow_pickle=True)[index].strftime(self.dateformat)
features: dict[str, Union[str, Tensor]] = {
features: dict[str, str | Tensor] = {
"label": torch.tensor(label, dtype=torch.float32),
"date": datestring,
}
@ -241,7 +242,7 @@ class SKIPPD(NonGeoDataset):
self,
sample: dict[str, Any],
show_titles: bool = True,
suptitle: Optional[str] = None,
suptitle: str | None = None,
) -> Figure:
"""Plot a sample from the dataset.

Просмотреть файл

@ -4,8 +4,8 @@
"""So2Sat dataset."""
import os
from collections.abc import Sequence
from typing import Callable, Optional, cast
from collections.abc import Callable, Sequence
from typing import cast
import matplotlib.pyplot as plt
import numpy as np
@ -196,7 +196,7 @@ class So2Sat(NonGeoDataset):
version: str = "2",
split: str = "train",
bands: Sequence[str] = BAND_SETS["all"],
transforms: Optional[Callable[[dict[str, Tensor]], dict[str, Tensor]]] = None,
transforms: Callable[[dict[str, Tensor]], dict[str, Tensor]] | None = None,
checksum: bool = False,
) -> None:
"""Initialize a new So2Sat dataset instance.
@ -340,7 +340,7 @@ class So2Sat(NonGeoDataset):
self,
sample: dict[str, Tensor],
show_titles: bool = True,
suptitle: Optional[str] = None,
suptitle: str | None = None,
) -> Figure:
"""Plot a sample from the dataset.

Просмотреть файл

@ -5,8 +5,8 @@
import os
import re
from collections.abc import Iterable
from typing import Any, Callable, Optional, Union, cast
from collections.abc import Callable, Iterable
from typing import Any, cast
import matplotlib.pyplot as plt
import torch
@ -98,11 +98,11 @@ class SouthAfricaCropType(RasterDataset):
def __init__(
self,
paths: Union[str, Iterable[str]] = "data",
crs: Optional[CRS] = None,
paths: str | Iterable[str] = "data",
crs: CRS | None = None,
classes: list[int] = list(cmap.keys()),
bands: list[str] = all_bands,
transforms: Optional[Callable[[dict[str, Tensor]], dict[str, Tensor]]] = None,
transforms: Callable[[dict[str, Tensor]], dict[str, Tensor]] | None = None,
) -> None:
"""Initialize a new South Africa Crop Type dataset instance.
@ -224,7 +224,7 @@ class SouthAfricaCropType(RasterDataset):
self,
sample: dict[str, Tensor],
show_titles: bool = True,
suptitle: Optional[str] = None,
suptitle: str | None = None,
) -> Figure:
"""Plot a sample from the dataset.

Просмотреть файл

@ -3,8 +3,8 @@
"""South America Soybean Dataset."""
from collections.abc import Iterable
from typing import Any, Callable, Optional, Union
from collections.abc import Callable, Iterable
from typing import Any
import matplotlib.pyplot as plt
from matplotlib.figure import Figure
@ -71,11 +71,11 @@ class SouthAmericaSoybean(RasterDataset):
def __init__(
self,
paths: Union[str, Iterable[str]] = "data",
crs: Optional[CRS] = None,
res: Optional[float] = None,
paths: str | Iterable[str] = "data",
crs: CRS | None = None,
res: float | None = None,
years: list[int] = [2021],
transforms: Optional[Callable[[dict[str, Any]], dict[str, Any]]] = None,
transforms: Callable[[dict[str, Any]], dict[str, Any]] | None = None,
cache: bool = True,
download: bool = False,
checksum: bool = False,
@ -133,7 +133,7 @@ class SouthAmericaSoybean(RasterDataset):
self,
sample: dict[str, Any],
show_titles: bool = True,
suptitle: Optional[str] = None,
suptitle: str | None = None,
) -> Figure:
"""Plot a sample from the dataset.

Просмотреть файл

@ -9,7 +9,8 @@ import glob
import math
import os
import re
from typing import Any, Callable, Optional
from collections.abc import Callable
from typing import Any
import fiona
import matplotlib.pyplot as plt
@ -81,9 +82,9 @@ class SpaceNet(NonGeoDataset, abc.ABC):
root: str,
image: str,
collections: list[str] = [],
transforms: Optional[Callable[[dict[str, Any]], dict[str, Any]]] = None,
transforms: Callable[[dict[str, Any]], dict[str, Any]] | None = None,
download: bool = False,
api_key: Optional[str] = None,
api_key: str | None = None,
checksum: bool = False,
) -> None:
"""Initialize a new SpaceNet Dataset instance.
@ -274,7 +275,7 @@ class SpaceNet(NonGeoDataset, abc.ABC):
return to_be_downloaded
def _download(self, collections: list[str], api_key: Optional[str] = None) -> None:
def _download(self, collections: list[str], api_key: str | None = None) -> None:
"""Download the dataset and extract it.
Args:
@ -299,7 +300,7 @@ class SpaceNet(NonGeoDataset, abc.ABC):
self,
sample: dict[str, Tensor],
show_titles: bool = True,
suptitle: Optional[str] = None,
suptitle: str | None = None,
) -> Figure:
"""Plot a sample from the dataset.
@ -398,9 +399,9 @@ class SpaceNet1(SpaceNet):
self,
root: str = "data",
image: str = "rgb",
transforms: Optional[Callable[[dict[str, Any]], dict[str, Any]]] = None,
transforms: Callable[[dict[str, Any]], dict[str, Any]] | None = None,
download: bool = False,
api_key: Optional[str] = None,
api_key: str | None = None,
checksum: bool = False,
) -> None:
"""Initialize a new SpaceNet 1 Dataset instance.
@ -514,9 +515,9 @@ class SpaceNet2(SpaceNet):
root: str = "data",
image: str = "PS-RGB",
collections: list[str] = [],
transforms: Optional[Callable[[dict[str, Any]], dict[str, Any]]] = None,
transforms: Callable[[dict[str, Any]], dict[str, Any]] | None = None,
download: bool = False,
api_key: Optional[str] = None,
api_key: str | None = None,
checksum: bool = False,
) -> None:
"""Initialize a new SpaceNet 2 Dataset instance.
@ -633,11 +634,11 @@ class SpaceNet3(SpaceNet):
self,
root: str = "data",
image: str = "PS-RGB",
speed_mask: Optional[bool] = False,
speed_mask: bool | None = False,
collections: list[str] = [],
transforms: Optional[Callable[[dict[str, Any]], dict[str, Any]]] = None,
transforms: Callable[[dict[str, Any]], dict[str, Any]] | None = None,
download: bool = False,
api_key: Optional[str] = None,
api_key: str | None = None,
checksum: bool = False,
) -> None:
"""Initialize a new SpaceNet 3 Dataset instance.
@ -733,7 +734,7 @@ class SpaceNet3(SpaceNet):
self,
sample: dict[str, Tensor],
show_titles: bool = True,
suptitle: Optional[str] = None,
suptitle: str | None = None,
) -> Figure:
"""Plot a sample from the dataset.
@ -884,9 +885,9 @@ class SpaceNet4(SpaceNet):
root: str = "data",
image: str = "PS-RGBNIR",
angles: list[str] = [],
transforms: Optional[Callable[[dict[str, Any]], dict[str, Any]]] = None,
transforms: Callable[[dict[str, Any]], dict[str, Any]] | None = None,
download: bool = False,
api_key: Optional[str] = None,
api_key: str | None = None,
checksum: bool = False,
) -> None:
"""Initialize a new SpaceNet 4 Dataset instance.
@ -1051,11 +1052,11 @@ class SpaceNet5(SpaceNet3):
self,
root: str = "data",
image: str = "PS-RGB",
speed_mask: Optional[bool] = False,
speed_mask: bool | None = False,
collections: list[str] = [],
transforms: Optional[Callable[[dict[str, Any]], dict[str, Any]]] = None,
transforms: Callable[[dict[str, Any]], dict[str, Any]] | None = None,
download: bool = False,
api_key: Optional[str] = None,
api_key: str | None = None,
checksum: bool = False,
) -> None:
"""Initialize a new SpaceNet 5 Dataset instance.
@ -1183,9 +1184,9 @@ class SpaceNet6(SpaceNet):
self,
root: str = "data",
image: str = "PS-RGB",
transforms: Optional[Callable[[dict[str, Any]], dict[str, Any]]] = None,
transforms: Callable[[dict[str, Any]], dict[str, Any]] | None = None,
download: bool = False,
api_key: Optional[str] = None,
api_key: str | None = None,
) -> None:
"""Initialize a new SpaceNet 6 Dataset instance.
@ -1212,7 +1213,7 @@ class SpaceNet6(SpaceNet):
self.files = self._load_files(os.path.join(root, self.dataset_id))
def __download(self, api_key: Optional[str] = None) -> None:
def __download(self, api_key: str | None = None) -> None:
"""Download the dataset and extract it.
Args:
@ -1281,9 +1282,9 @@ class SpaceNet7(SpaceNet):
self,
root: str = "data",
split: str = "train",
transforms: Optional[Callable[[dict[str, Any]], dict[str, Any]]] = None,
transforms: Callable[[dict[str, Any]], dict[str, Any]] | None = None,
download: bool = False,
api_key: Optional[str] = None,
api_key: str | None = None,
checksum: bool = False,
) -> None:
"""Initialize a new SpaceNet 7 Dataset instance.

Просмотреть файл

@ -7,7 +7,7 @@ from collections.abc import Sequence
from copy import deepcopy
from itertools import accumulate
from math import floor, isclose
from typing import Optional, Union, cast
from typing import cast
from rtree.index import Index, Property
from torch import Generator, default_generator, randint, randperm
@ -50,7 +50,7 @@ def _fractions_to_lengths(fractions: Sequence[float], total: int) -> Sequence[in
def random_bbox_assignment(
dataset: GeoDataset,
lengths: Sequence[float],
generator: Optional[Generator] = default_generator,
generator: Generator | None = default_generator,
) -> list[GeoDataset]:
"""Split a GeoDataset randomly assigning its index's BoundingBoxes.
@ -104,7 +104,7 @@ def random_bbox_assignment(
def random_bbox_splitting(
dataset: GeoDataset,
fractions: Sequence[float],
generator: Optional[Generator] = default_generator,
generator: Generator | None = default_generator,
) -> list[GeoDataset]:
"""Split a GeoDataset randomly splitting its index's BoundingBoxes.
@ -172,7 +172,7 @@ def random_grid_cell_assignment(
dataset: GeoDataset,
fractions: Sequence[float],
grid_size: int = 6,
generator: Optional[Generator] = default_generator,
generator: Generator | None = default_generator,
) -> list[GeoDataset]:
"""Overlays a grid over a GeoDataset and randomly assigns cells to new GeoDatasets.
@ -289,7 +289,7 @@ def roi_split(dataset: GeoDataset, rois: Sequence[BoundingBox]) -> list[GeoDatas
def time_series_split(
dataset: GeoDataset, lengths: Sequence[Union[float, tuple[float, float]]]
dataset: GeoDataset, lengths: Sequence[float | tuple[float, float]]
) -> list[GeoDataset]:
"""Split a GeoDataset on its time dimension to create non-overlapping GeoDatasets.

Просмотреть файл

@ -6,7 +6,8 @@
import glob
import os
import random
from typing import Callable, Optional, TypedDict
from collections.abc import Callable
from typing import TypedDict
import matplotlib.pyplot as plt
import numpy as np
@ -163,7 +164,7 @@ class SSL4EOL(NonGeoDataset):
root: str = "data",
split: str = "oli_sr",
seasons: int = 1,
transforms: Optional[Callable[[dict[str, Tensor]], dict[str, Tensor]]] = None,
transforms: Callable[[dict[str, Tensor]], dict[str, Tensor]] | None = None,
download: bool = False,
checksum: bool = False,
) -> None:
@ -285,7 +286,7 @@ class SSL4EOL(NonGeoDataset):
self,
sample: dict[str, Tensor],
show_titles: bool = True,
suptitle: Optional[str] = None,
suptitle: str | None = None,
) -> Figure:
"""Plot a sample from the dataset.
@ -405,7 +406,7 @@ class SSL4EOS12(NonGeoDataset):
root: str = "data",
split: str = "s2c",
seasons: int = 1,
transforms: Optional[Callable[[dict[str, Tensor]], dict[str, Tensor]]] = None,
transforms: Callable[[dict[str, Tensor]], dict[str, Tensor]] | None = None,
checksum: bool = False,
) -> None:
"""Initialize a new SSL4EOS12 instance.
@ -500,7 +501,7 @@ class SSL4EOS12(NonGeoDataset):
self,
sample: dict[str, Tensor],
show_titles: bool = True,
suptitle: Optional[str] = None,
suptitle: str | None = None,
) -> Figure:
"""Plot a sample from the dataset.

Просмотреть файл

@ -5,7 +5,7 @@
import glob
import os
from typing import Callable, Optional
from collections.abc import Callable
import matplotlib.pyplot as plt
import numpy as np
@ -110,8 +110,8 @@ class SSL4EOLBenchmark(NonGeoDataset):
sensor: str = "oli_sr",
product: str = "cdl",
split: str = "train",
classes: Optional[list[int]] = None,
transforms: Optional[Callable[[dict[str, Tensor]], dict[str, Tensor]]] = None,
classes: list[int] | None = None,
transforms: Callable[[dict[str, Tensor]], dict[str, Tensor]] | None = None,
download: bool = False,
checksum: bool = False,
) -> None:
@ -327,7 +327,7 @@ class SSL4EOLBenchmark(NonGeoDataset):
self,
sample: dict[str, Tensor],
show_titles: bool = True,
suptitle: Optional[str] = None,
suptitle: str | None = None,
) -> Figure:
"""Plot a sample from the dataset.

Просмотреть файл

@ -4,7 +4,8 @@
"""SustainBench Crop Yield dataset."""
import os
from typing import Any, Callable, Optional
from collections.abc import Callable
from typing import Any
import matplotlib.pyplot as plt
import numpy as np
@ -60,7 +61,7 @@ class SustainBenchCropYield(NonGeoDataset):
root: str = "data",
split: str = "train",
countries: list[str] = ["usa"],
transforms: Optional[Callable[[dict[str, Any]], dict[str, Any]]] = None,
transforms: Callable[[dict[str, Any]], dict[str, Any]] | None = None,
download: bool = False,
checksum: bool = False,
) -> None:
@ -195,7 +196,7 @@ class SustainBenchCropYield(NonGeoDataset):
sample: dict[str, Any],
band_idx: int = 0,
show_titles: bool = True,
suptitle: Optional[str] = None,
suptitle: str | None = None,
) -> Figure:
"""Plot a sample from the dataset.

Просмотреть файл

@ -3,7 +3,8 @@
"""UC Merced dataset."""
import os
from typing import Callable, Optional, cast
from collections.abc import Callable
from typing import cast
import matplotlib.pyplot as plt
import numpy as np
@ -85,7 +86,7 @@ class UCMerced(NonGeoClassificationDataset):
self,
root: str = "data",
split: str = "train",
transforms: Optional[Callable[[dict[str, Tensor]], dict[str, Tensor]]] = None,
transforms: Callable[[dict[str, Tensor]], dict[str, Tensor]] | None = None,
download: bool = False,
checksum: bool = False,
) -> None:
@ -190,7 +191,7 @@ class UCMerced(NonGeoClassificationDataset):
self,
sample: dict[str, Tensor],
show_titles: bool = True,
suptitle: Optional[str] = None,
suptitle: str | None = None,
) -> Figure:
"""Plot a sample from the dataset.

Некоторые файлы не были показаны из-за слишком большого количества измененных файлов Показать больше