зеркало из https://github.com/microsoft/torchgeo.git
Drop Python 3.9 support (#1966)
This commit is contained in:
Родитель
26cab1060b
Коммит
ea57469b0b
|
@ -158,7 +158,7 @@ jobs:
|
|||
- name: List pip dependencies
|
||||
run: pip list
|
||||
- name: Run pyupgrade checks
|
||||
run: pyupgrade --py39-plus $(find . -path ./docs/src -prune -o -name "*.py" -print)
|
||||
run: pyupgrade --py310-plus $(find . -path ./docs/src -prune -o -name "*.py" -print)
|
||||
concurrency:
|
||||
group: ${{ github.workflow }}-${{ github.event.pull_request.head.label || github.head_ref || github.ref }}
|
||||
cancel-in-progress: true
|
||||
|
|
|
@ -17,7 +17,7 @@ jobs:
|
|||
strategy:
|
||||
matrix:
|
||||
os: [ubuntu-latest, macos-latest, windows-latest]
|
||||
python-version: ['3.9', '3.10', '3.11', '3.12']
|
||||
python-version: ['3.10', '3.11', '3.12']
|
||||
steps:
|
||||
- name: Clone repo
|
||||
uses: actions/checkout@v4.1.2
|
||||
|
@ -71,7 +71,7 @@ jobs:
|
|||
- name: Set up python
|
||||
uses: actions/setup-python@v5.1.0
|
||||
with:
|
||||
python-version: '3.9'
|
||||
python-version: '3.10'
|
||||
- name: Cache dependencies
|
||||
uses: actions/cache@v4.0.2
|
||||
id: cache
|
||||
|
|
|
@ -3,7 +3,7 @@ repos:
|
|||
rev: v3.15.2
|
||||
hooks:
|
||||
- id: pyupgrade
|
||||
args: [--py39-plus]
|
||||
args: [--py310-plus]
|
||||
|
||||
- repo: https://github.com/pycqa/isort
|
||||
rev: 5.13.2
|
||||
|
|
|
@ -103,7 +103,7 @@ All of these tools should be used from the root of the project to ensure that ou
|
|||
|
||||
$ black .
|
||||
$ isort .
|
||||
$ pyupgrade --py39-plus $(find . -name "*.py")
|
||||
$ pyupgrade --py310-plus $(find . -name "*.py")
|
||||
|
||||
|
||||
Flake8, pydocstyle, and mypy won't format your code for you, but they will warn you about potential issues with your code or docstrings:
|
||||
|
|
|
@ -56,7 +56,7 @@ import warnings
|
|||
from collections import defaultdict
|
||||
from datetime import date, timedelta
|
||||
from multiprocessing.dummy import Lock, Pool
|
||||
from typing import Any, Optional
|
||||
from typing import Any
|
||||
|
||||
import ee
|
||||
import numpy as np
|
||||
|
@ -168,7 +168,7 @@ def get_patch(
|
|||
new_resolutions: list[int],
|
||||
dtype: str = "float32",
|
||||
meta_cloud_name: str = "CLOUD_COVER",
|
||||
default_value: Optional[float] = None,
|
||||
default_value: float | None = None,
|
||||
) -> dict[str, Any]:
|
||||
image = collection.sort(meta_cloud_name).first()
|
||||
region = ee.Geometry.Point(center_coord).buffer(radius).bounds()
|
||||
|
@ -214,7 +214,7 @@ def get_random_patches_match(
|
|||
new_resolutions: list[int],
|
||||
dtype: str,
|
||||
meta_cloud_name: str,
|
||||
default_value: Optional[float],
|
||||
default_value: float | None,
|
||||
dates: list[date],
|
||||
radius: float,
|
||||
debug: bool = False,
|
||||
|
|
|
@ -10,7 +10,7 @@ build-backend = "setuptools.build_meta"
|
|||
name = "torchgeo"
|
||||
description = "TorchGeo: datasets, samplers, transforms, and pre-trained models for geospatial data"
|
||||
readme = "README.md"
|
||||
requires-python = ">=3.9"
|
||||
requires-python = ">=3.10"
|
||||
license = {file = "LICENSE"}
|
||||
authors = [
|
||||
{name = "Adam J. Stewart", email = "ajstewart426@gmail.com"},
|
||||
|
@ -29,7 +29,6 @@ classifiers = [
|
|||
"License :: OSI Approved :: MIT License",
|
||||
"Operating System :: OS Independent",
|
||||
"Programming Language :: Python :: 3",
|
||||
"Programming Language :: Python :: 3.9",
|
||||
"Programming Language :: Python :: 3.10",
|
||||
"Programming Language :: Python :: 3.11",
|
||||
"Programming Language :: Python :: 3.12",
|
||||
|
@ -39,9 +38,8 @@ classifiers = [
|
|||
dependencies = [
|
||||
# einops 0.3+ required for einops.repeat
|
||||
"einops>=0.3",
|
||||
# fiona 1.8.19+ required to fix erroneous warning
|
||||
# https://github.com/Toblerity/Fiona/issues/986
|
||||
"fiona>=1.8.19",
|
||||
# fiona 1.8.21+ required for Python 3.10 wheels
|
||||
"fiona>=1.8.21",
|
||||
# kornia 0.6.9+ required for kornia.augmentation.RandomBrightness
|
||||
"kornia>=0.6.9",
|
||||
# lightly 1.4.4+ required for MoCo v3 support
|
||||
|
@ -50,24 +48,24 @@ dependencies = [
|
|||
"lightly>=1.4.4,!=1.4.26",
|
||||
# lightning 2+ required for LightningCLI args + sys.argv support
|
||||
"lightning[pytorch-extra]>=2",
|
||||
# matplotlib 3.3.3+ required for Python 3.9 wheels
|
||||
"matplotlib>=3.3.3",
|
||||
# numpy 1.19.3+ required by Python 3.9 wheels
|
||||
"numpy>=1.19.3",
|
||||
# pandas 1.1.3+ required for Python 3.9 wheels
|
||||
"pandas>=1.1.3",
|
||||
# pillow 8+ required for Python 3.9 wheels
|
||||
"pillow>=8",
|
||||
# pyproj 3+ required for Python 3.9 wheels
|
||||
"pyproj>=3",
|
||||
# rasterio 1.2+ required for Python 3.9 wheels
|
||||
"rasterio>=1.2",
|
||||
# rtree 1+ required for len(index), index & index, index | index
|
||||
# matplotlib 3.5+ required for Python 3.10 wheels
|
||||
"matplotlib>=3.5",
|
||||
# numpy 1.21.2+ required by Python 3.10 wheels
|
||||
"numpy>=1.21.2",
|
||||
# pandas 1.3.3+ required for Python 3.10 wheels
|
||||
"pandas>=1.3.3",
|
||||
# pillow 8.4+ required for Python 3.10 wheels
|
||||
"pillow>=8.4",
|
||||
# pyproj 3.3+ required for Python 3.10 wheels
|
||||
"pyproj>=3.3",
|
||||
# rasterio 1.3+ required for Python 3.10 wheels
|
||||
"rasterio>=1.3",
|
||||
# rtree 1+ required for Python 3.10 wheels
|
||||
"rtree>=1",
|
||||
# segmentation-models-pytorch 0.2+ required for smp.losses module
|
||||
"segmentation-models-pytorch>=0.2",
|
||||
# shapely 1.7.1+ required for Python 3.9 wheels
|
||||
"shapely>=1.7.1",
|
||||
# shapely 1.8+ required for Python 3.10 wheels
|
||||
"shapely>=1.8",
|
||||
# timm 0.4.12 required by segmentation-models-pytorch
|
||||
"timm>=0.4.12",
|
||||
# torch 1.13+ required by torchvision
|
||||
|
@ -81,27 +79,25 @@ dynamic = ["version"]
|
|||
|
||||
[project.optional-dependencies]
|
||||
datasets = [
|
||||
# h5py 3+ required for Python 3.9 wheels
|
||||
"h5py>=3",
|
||||
# h5py 3.6+ required for Python 3.10 wheels
|
||||
"h5py>=3.6",
|
||||
# laspy 2+ required for laspy.read
|
||||
"laspy>=2",
|
||||
# opencv-python 4.4.0.46+ required for Python 3.9 wheels
|
||||
"opencv-python>=4.4.0.46",
|
||||
# pycocotools 2.0.5+ required for cython 3+ support
|
||||
"pycocotools>=2.0.5",
|
||||
# opencv-python 4.5.4+ required for Python 3.10 wheels
|
||||
"opencv-python>=4.5.4",
|
||||
# pycocotools 2.0.7+ required for wheels
|
||||
"pycocotools>=2.0.7",
|
||||
# pyvista 0.34.2+ required to avoid ImportError in CI
|
||||
"pyvista>=0.34.2",
|
||||
# radiant-mlhub 0.3+ required for newer tqdm support required by lightning
|
||||
"radiant-mlhub>=0.3",
|
||||
# rarfile 4+ required for wheels
|
||||
"rarfile>=4",
|
||||
# scikit-image 0.18+ required for numpy 1.17+ compatibility
|
||||
# https://github.com/scikit-image/scikit-image/issues/3655
|
||||
"scikit-image>=0.18",
|
||||
# scipy 1.6.2+ required for scikit-image 0.18+ compatibility
|
||||
"scipy>=1.6.2",
|
||||
# zipfile-deflate64 0.2+ required for extraction bugfix:
|
||||
# https://github.com/brianhelba/zipfile-deflate64/issues/19
|
||||
# scikit-image 0.19+ required for Python 3.10 wheels
|
||||
"scikit-image>=0.19",
|
||||
# scipy 1.7.2+ required for Python 3.10 wheels
|
||||
"scipy>=1.7.2",
|
||||
# zipfile-deflate64 0.2+ required for Python 3.10 wheels
|
||||
"zipfile-deflate64>=0.2",
|
||||
]
|
||||
docs = [
|
||||
|
@ -126,7 +122,7 @@ style = [
|
|||
"isort[colors]>=5.8",
|
||||
# pydocstyle 6.1+ required for pyproject.toml support
|
||||
"pydocstyle[toml]>=6.1",
|
||||
# pyupgrade 2.8+ required for --py39-plus flag
|
||||
# pyupgrade 2.8+ required for --py310-plus flag
|
||||
"pyupgrade>=2.8",
|
||||
]
|
||||
tests = [
|
||||
|
@ -151,7 +147,7 @@ Homepage = "https://github.com/microsoft/torchgeo"
|
|||
Documentation = "https://torchgeo.readthedocs.io"
|
||||
|
||||
[tool.black]
|
||||
target-version = ["py39", "py310"]
|
||||
target-version = ["py310"]
|
||||
color = true
|
||||
skip_magic_trailing_comma = true
|
||||
|
||||
|
@ -170,7 +166,7 @@ skip_gitignore = true
|
|||
color_output = true
|
||||
|
||||
[tool.mypy]
|
||||
python_version = "3.9"
|
||||
python_version = "3.10"
|
||||
ignore_missing_imports = true
|
||||
show_error_codes = true
|
||||
exclude = "(build|data|dist|docs/src|images|logo|logs|output)/"
|
||||
|
|
|
@ -3,34 +3,34 @@ setuptools==61.0.0
|
|||
|
||||
# install
|
||||
einops==0.3.0
|
||||
fiona==1.8.19
|
||||
fiona==1.8.21
|
||||
kornia==0.6.9
|
||||
lightly==1.4.4
|
||||
lightning[pytorch-extra]==2.0.0
|
||||
matplotlib==3.3.3
|
||||
numpy==1.19.3
|
||||
pandas==1.1.3
|
||||
pillow==8.0.0
|
||||
pyproj==3.0.0
|
||||
rasterio==1.2.0
|
||||
matplotlib==3.5.0
|
||||
numpy==1.21.2
|
||||
pandas==1.3.3
|
||||
pillow==8.4.0
|
||||
pyproj==3.3.0
|
||||
rasterio==1.3.0.post1
|
||||
rtree==1.0.0
|
||||
segmentation-models-pytorch==0.2.0
|
||||
shapely==1.7.1
|
||||
shapely==1.8.0
|
||||
timm==0.4.12
|
||||
torch==1.13.0
|
||||
torchmetrics==0.10.0
|
||||
torchvision==0.14.0
|
||||
|
||||
# datasets
|
||||
h5py==3.0.0
|
||||
h5py==3.6.0
|
||||
laspy==2.0.0
|
||||
opencv-python==4.4.0.46
|
||||
pycocotools==2.0.5
|
||||
opencv-python==4.5.4.58
|
||||
pycocotools==2.0.7
|
||||
pyvista==0.34.2
|
||||
radiant-mlhub==0.3.0
|
||||
rarfile==4.0
|
||||
scikit-image==0.18.0
|
||||
scipy==1.6.2
|
||||
scikit-image==0.19.0
|
||||
scipy==1.7.2
|
||||
zipfile-deflate64==0.2.0
|
||||
|
||||
# docs
|
||||
|
|
|
@ -2,7 +2,6 @@
|
|||
# Licensed under the MIT License.
|
||||
|
||||
import os
|
||||
from typing import Optional
|
||||
|
||||
import numpy as np
|
||||
import rasterio as rio
|
||||
|
@ -18,7 +17,7 @@ def write_raster(
|
|||
res: int = RES[0],
|
||||
epsg: int = EPSG[0],
|
||||
dtype: str = "uint8",
|
||||
path: Optional[str] = None,
|
||||
path: str | None = None,
|
||||
) -> None:
|
||||
"""Write a raster file.
|
||||
|
||||
|
|
|
@ -5,7 +5,6 @@ import pickle
|
|||
import sys
|
||||
from collections.abc import Iterable
|
||||
from pathlib import Path
|
||||
from typing import Optional, Union
|
||||
|
||||
import pytest
|
||||
import torch
|
||||
|
@ -35,7 +34,7 @@ class CustomGeoDataset(GeoDataset):
|
|||
bounds: BoundingBox = BoundingBox(0, 1, 2, 3, 4, 5),
|
||||
crs: CRS = CRS.from_epsg(4087),
|
||||
res: float = 1,
|
||||
paths: Optional[Union[str, Iterable[str]]] = None,
|
||||
paths: str | Iterable[str] | None = None,
|
||||
) -> None:
|
||||
super().__init__()
|
||||
self.index.insert(0, tuple(bounds))
|
||||
|
@ -249,7 +248,7 @@ class TestRasterDataset:
|
|||
},
|
||||
],
|
||||
)
|
||||
def test_files(self, paths: Union[str, Iterable[str]]) -> None:
|
||||
def test_files(self, paths: str | Iterable[str]) -> None:
|
||||
assert 1 <= len(NAIP(paths).files) <= 2
|
||||
|
||||
def test_getitem_single_file(self, naip: NAIP) -> None:
|
||||
|
|
|
@ -3,7 +3,7 @@
|
|||
|
||||
from collections.abc import Sequence
|
||||
from math import floor, isclose
|
||||
from typing import Any, Union
|
||||
from typing import Any
|
||||
|
||||
import pytest
|
||||
from rasterio.crs import CRS
|
||||
|
@ -65,7 +65,7 @@ class CustomGeoDataset(GeoDataset):
|
|||
],
|
||||
)
|
||||
def test_random_bbox_assignment(
|
||||
lengths: Sequence[Union[int, float]], expected_lengths: Sequence[int]
|
||||
lengths: Sequence[int | float], expected_lengths: Sequence[int]
|
||||
) -> None:
|
||||
ds = CustomGeoDataset(
|
||||
[
|
||||
|
@ -255,8 +255,7 @@ def test_roi_split() -> None:
|
|||
],
|
||||
)
|
||||
def test_time_series_split(
|
||||
lengths: Sequence[Union[tuple[int, int], int, float]],
|
||||
expected_lengths: Sequence[int],
|
||||
lengths: Sequence[tuple[int, int] | int | float], expected_lengths: Sequence[int]
|
||||
) -> None:
|
||||
ds = CustomGeoDataset(
|
||||
[
|
||||
|
|
|
@ -2,7 +2,7 @@
|
|||
# Licensed under the MIT License.
|
||||
|
||||
import enum
|
||||
from typing import Callable
|
||||
from collections.abc import Callable
|
||||
|
||||
import pytest
|
||||
import torch.nn as nn
|
||||
|
|
|
@ -2,7 +2,7 @@
|
|||
# Licensed under the MIT License.
|
||||
|
||||
import math
|
||||
from typing import Optional, Union
|
||||
from typing import Union
|
||||
|
||||
import pytest
|
||||
|
||||
|
@ -33,7 +33,7 @@ MAYBE_TUPLE = Union[float, tuple[float, float]]
|
|||
],
|
||||
)
|
||||
def test_tile_to_chips(
|
||||
size: MAYBE_TUPLE, stride: Optional[MAYBE_TUPLE], expected: MAYBE_TUPLE
|
||||
size: MAYBE_TUPLE, stride: MAYBE_TUPLE | None, expected: MAYBE_TUPLE
|
||||
) -> None:
|
||||
bounds = BoundingBox(0, 10, 20, 30, 40, 50)
|
||||
size = _to_tuple(size)
|
||||
|
|
|
@ -3,7 +3,7 @@
|
|||
|
||||
"""AgriFieldNet datamodule."""
|
||||
|
||||
from typing import Any, Optional, Union
|
||||
from typing import Any
|
||||
|
||||
import kornia.augmentation as K
|
||||
import torch
|
||||
|
@ -25,8 +25,8 @@ class AgriFieldNetDataModule(GeoDataModule):
|
|||
def __init__(
|
||||
self,
|
||||
batch_size: int = 64,
|
||||
patch_size: Union[int, tuple[int, int]] = 256,
|
||||
length: Optional[int] = None,
|
||||
patch_size: int | tuple[int, int] = 256,
|
||||
length: int | None = None,
|
||||
num_workers: int = 0,
|
||||
**kwargs: Any,
|
||||
) -> None:
|
||||
|
|
|
@ -3,7 +3,7 @@
|
|||
|
||||
"""Chesapeake Bay High-Resolution Land Cover Project datamodule."""
|
||||
|
||||
from typing import Any, Optional
|
||||
from typing import Any
|
||||
|
||||
import kornia.augmentation as K
|
||||
import torch.nn as nn
|
||||
|
@ -63,7 +63,7 @@ class ChesapeakeCVPRDataModule(GeoDataModule):
|
|||
test_splits: list[str],
|
||||
batch_size: int = 64,
|
||||
patch_size: int = 256,
|
||||
length: Optional[int] = None,
|
||||
length: int | None = None,
|
||||
num_workers: int = 0,
|
||||
class_set: int = 7,
|
||||
use_prior_labels: bool = False,
|
||||
|
|
|
@ -3,7 +3,7 @@
|
|||
|
||||
"""DeepGlobe Land Cover Classification Challenge datamodule."""
|
||||
|
||||
from typing import Any, Union
|
||||
from typing import Any
|
||||
|
||||
import kornia.augmentation as K
|
||||
|
||||
|
@ -24,7 +24,7 @@ class DeepGlobeLandCoverDataModule(NonGeoDataModule):
|
|||
def __init__(
|
||||
self,
|
||||
batch_size: int = 64,
|
||||
patch_size: Union[tuple[int, int], int] = 64,
|
||||
patch_size: tuple[int, int] | int = 64,
|
||||
val_split_pct: float = 0.2,
|
||||
num_workers: int = 0,
|
||||
**kwargs: Any,
|
||||
|
|
|
@ -3,7 +3,8 @@
|
|||
|
||||
"""Base classes for all :mod:`torchgeo` data modules."""
|
||||
|
||||
from typing import Any, Callable, Optional, Union, cast
|
||||
from collections.abc import Callable
|
||||
from typing import Any, cast
|
||||
|
||||
import kornia.augmentation as K
|
||||
import torch
|
||||
|
@ -55,27 +56,27 @@ class BaseDataModule(LightningDataModule):
|
|||
self.kwargs = kwargs
|
||||
|
||||
# Datasets
|
||||
self.dataset: Optional[Dataset[dict[str, Tensor]]] = None
|
||||
self.train_dataset: Optional[Dataset[dict[str, Tensor]]] = None
|
||||
self.val_dataset: Optional[Dataset[dict[str, Tensor]]] = None
|
||||
self.test_dataset: Optional[Dataset[dict[str, Tensor]]] = None
|
||||
self.predict_dataset: Optional[Dataset[dict[str, Tensor]]] = None
|
||||
self.dataset: Dataset[dict[str, Tensor]] | None = None
|
||||
self.train_dataset: Dataset[dict[str, Tensor]] | None = None
|
||||
self.val_dataset: Dataset[dict[str, Tensor]] | None = None
|
||||
self.test_dataset: Dataset[dict[str, Tensor]] | None = None
|
||||
self.predict_dataset: Dataset[dict[str, Tensor]] | None = None
|
||||
|
||||
# Data loaders
|
||||
self.train_batch_size: Optional[int] = None
|
||||
self.val_batch_size: Optional[int] = None
|
||||
self.test_batch_size: Optional[int] = None
|
||||
self.predict_batch_size: Optional[int] = None
|
||||
self.train_batch_size: int | None = None
|
||||
self.val_batch_size: int | None = None
|
||||
self.test_batch_size: int | None = None
|
||||
self.predict_batch_size: int | None = None
|
||||
|
||||
# Data augmentation
|
||||
Transform = Callable[[dict[str, Tensor]], dict[str, Tensor]]
|
||||
self.aug: Transform = AugmentationSequential(
|
||||
K.Normalize(mean=self.mean, std=self.std), data_keys=["image"]
|
||||
)
|
||||
self.train_aug: Optional[Transform] = None
|
||||
self.val_aug: Optional[Transform] = None
|
||||
self.test_aug: Optional[Transform] = None
|
||||
self.predict_aug: Optional[Transform] = None
|
||||
self.train_aug: Transform | None = None
|
||||
self.val_aug: Transform | None = None
|
||||
self.test_aug: Transform | None = None
|
||||
self.predict_aug: Transform | None = None
|
||||
|
||||
def prepare_data(self) -> None:
|
||||
"""Download and prepare data.
|
||||
|
@ -141,7 +142,7 @@ class BaseDataModule(LightningDataModule):
|
|||
|
||||
return batch
|
||||
|
||||
def plot(self, *args: Any, **kwargs: Any) -> Optional[Figure]:
|
||||
def plot(self, *args: Any, **kwargs: Any) -> Figure | None:
|
||||
"""Run the plot method of the validation dataset if one exists.
|
||||
|
||||
Should only be called during 'fit' or 'validate' stages as ``val_dataset``
|
||||
|
@ -154,7 +155,7 @@ class BaseDataModule(LightningDataModule):
|
|||
Returns:
|
||||
A matplotlib Figure with the image, ground truth, and predictions.
|
||||
"""
|
||||
fig: Optional[Figure] = None
|
||||
fig: Figure | None = None
|
||||
dataset = self.dataset or self.val_dataset
|
||||
if dataset is not None:
|
||||
if hasattr(dataset, "plot"):
|
||||
|
@ -172,8 +173,8 @@ class GeoDataModule(BaseDataModule):
|
|||
self,
|
||||
dataset_class: type[GeoDataset],
|
||||
batch_size: int = 1,
|
||||
patch_size: Union[int, tuple[int, int]] = 64,
|
||||
length: Optional[int] = None,
|
||||
patch_size: int | tuple[int, int] = 64,
|
||||
length: int | None = None,
|
||||
num_workers: int = 0,
|
||||
**kwargs: Any,
|
||||
) -> None:
|
||||
|
@ -196,18 +197,18 @@ class GeoDataModule(BaseDataModule):
|
|||
self.collate_fn = stack_samples
|
||||
|
||||
# Samplers
|
||||
self.sampler: Optional[GeoSampler] = None
|
||||
self.train_sampler: Optional[GeoSampler] = None
|
||||
self.val_sampler: Optional[GeoSampler] = None
|
||||
self.test_sampler: Optional[GeoSampler] = None
|
||||
self.predict_sampler: Optional[GeoSampler] = None
|
||||
self.sampler: GeoSampler | None = None
|
||||
self.train_sampler: GeoSampler | None = None
|
||||
self.val_sampler: GeoSampler | None = None
|
||||
self.test_sampler: GeoSampler | None = None
|
||||
self.predict_sampler: GeoSampler | None = None
|
||||
|
||||
# Batch samplers
|
||||
self.batch_sampler: Optional[BatchGeoSampler] = None
|
||||
self.train_batch_sampler: Optional[BatchGeoSampler] = None
|
||||
self.val_batch_sampler: Optional[BatchGeoSampler] = None
|
||||
self.test_batch_sampler: Optional[BatchGeoSampler] = None
|
||||
self.predict_batch_sampler: Optional[BatchGeoSampler] = None
|
||||
self.batch_sampler: BatchGeoSampler | None = None
|
||||
self.train_batch_sampler: BatchGeoSampler | None = None
|
||||
self.val_batch_sampler: BatchGeoSampler | None = None
|
||||
self.test_batch_sampler: BatchGeoSampler | None = None
|
||||
self.predict_batch_sampler: BatchGeoSampler | None = None
|
||||
|
||||
def setup(self, stage: str) -> None:
|
||||
"""Set up datasets and samplers.
|
||||
|
|
|
@ -3,7 +3,7 @@
|
|||
|
||||
"""GID-15 datamodule."""
|
||||
|
||||
from typing import Any, Union
|
||||
from typing import Any
|
||||
|
||||
import kornia.augmentation as K
|
||||
|
||||
|
@ -26,7 +26,7 @@ class GID15DataModule(NonGeoDataModule):
|
|||
def __init__(
|
||||
self,
|
||||
batch_size: int = 64,
|
||||
patch_size: Union[tuple[int, int], int] = 64,
|
||||
patch_size: tuple[int, int] | int = 64,
|
||||
val_split_pct: float = 0.2,
|
||||
num_workers: int = 0,
|
||||
**kwargs: Any,
|
||||
|
|
|
@ -3,7 +3,7 @@
|
|||
|
||||
"""InriaAerialImageLabeling datamodule."""
|
||||
|
||||
from typing import Any, Union
|
||||
from typing import Any
|
||||
|
||||
import kornia.augmentation as K
|
||||
|
||||
|
@ -26,7 +26,7 @@ class InriaAerialImageLabelingDataModule(NonGeoDataModule):
|
|||
def __init__(
|
||||
self,
|
||||
batch_size: int = 64,
|
||||
patch_size: Union[tuple[int, int], int] = 64,
|
||||
patch_size: tuple[int, int] | int = 64,
|
||||
num_workers: int = 0,
|
||||
**kwargs: Any,
|
||||
) -> None:
|
||||
|
|
|
@ -3,7 +3,7 @@
|
|||
|
||||
"""L7 Irish datamodule."""
|
||||
|
||||
from typing import Any, Optional, Union
|
||||
from typing import Any
|
||||
|
||||
import kornia.augmentation as K
|
||||
import torch
|
||||
|
@ -25,8 +25,8 @@ class L7IrishDataModule(GeoDataModule):
|
|||
def __init__(
|
||||
self,
|
||||
batch_size: int = 1,
|
||||
patch_size: Union[int, tuple[int, int]] = 224,
|
||||
length: Optional[int] = None,
|
||||
patch_size: int | tuple[int, int] = 224,
|
||||
length: int | None = None,
|
||||
num_workers: int = 0,
|
||||
**kwargs: Any,
|
||||
) -> None:
|
||||
|
|
|
@ -3,7 +3,7 @@
|
|||
|
||||
"""L8 Biome datamodule."""
|
||||
|
||||
from typing import Any, Optional, Union
|
||||
from typing import Any
|
||||
|
||||
import kornia.augmentation as K
|
||||
import torch
|
||||
|
@ -25,8 +25,8 @@ class L8BiomeDataModule(GeoDataModule):
|
|||
def __init__(
|
||||
self,
|
||||
batch_size: int = 1,
|
||||
patch_size: Union[int, tuple[int, int]] = 224,
|
||||
length: Optional[int] = None,
|
||||
patch_size: int | tuple[int, int] = 224,
|
||||
length: int | None = None,
|
||||
num_workers: int = 0,
|
||||
**kwargs: Any,
|
||||
) -> None:
|
||||
|
|
|
@ -3,7 +3,7 @@
|
|||
|
||||
"""LEVIR-CD+ datamodule."""
|
||||
|
||||
from typing import Any, Union
|
||||
from typing import Any
|
||||
|
||||
import kornia.augmentation as K
|
||||
|
||||
|
@ -25,7 +25,7 @@ class LEVIRCDDataModule(NonGeoDataModule):
|
|||
def __init__(
|
||||
self,
|
||||
batch_size: int = 8,
|
||||
patch_size: Union[tuple[int, int], int] = 256,
|
||||
patch_size: tuple[int, int] | int = 256,
|
||||
num_workers: int = 0,
|
||||
**kwargs: Any,
|
||||
) -> None:
|
||||
|
@ -70,7 +70,7 @@ class LEVIRCDPlusDataModule(NonGeoDataModule):
|
|||
def __init__(
|
||||
self,
|
||||
batch_size: int = 8,
|
||||
patch_size: Union[tuple[int, int], int] = 256,
|
||||
patch_size: tuple[int, int] | int = 256,
|
||||
val_split_pct: float = 0.2,
|
||||
num_workers: int = 0,
|
||||
**kwargs: Any,
|
||||
|
|
|
@ -3,7 +3,7 @@
|
|||
|
||||
"""National Agriculture Imagery Program (NAIP) datamodule."""
|
||||
|
||||
from typing import Any, Optional, Union
|
||||
from typing import Any
|
||||
|
||||
import kornia.augmentation as K
|
||||
from matplotlib.figure import Figure
|
||||
|
@ -23,8 +23,8 @@ class NAIPChesapeakeDataModule(GeoDataModule):
|
|||
def __init__(
|
||||
self,
|
||||
batch_size: int = 64,
|
||||
patch_size: Union[int, tuple[int, int]] = 256,
|
||||
length: Optional[int] = None,
|
||||
patch_size: int | tuple[int, int] = 256,
|
||||
length: int | None = None,
|
||||
num_workers: int = 0,
|
||||
**kwargs: Any,
|
||||
) -> None:
|
||||
|
|
|
@ -3,7 +3,7 @@
|
|||
|
||||
"""OSCD datamodule."""
|
||||
|
||||
from typing import Any, Union
|
||||
from typing import Any
|
||||
|
||||
import kornia.augmentation as K
|
||||
import torch
|
||||
|
@ -60,7 +60,7 @@ class OSCDDataModule(NonGeoDataModule):
|
|||
def __init__(
|
||||
self,
|
||||
batch_size: int = 64,
|
||||
patch_size: Union[tuple[int, int], int] = 64,
|
||||
patch_size: tuple[int, int] | int = 64,
|
||||
val_split_pct: float = 0.2,
|
||||
num_workers: int = 0,
|
||||
**kwargs: Any,
|
||||
|
|
|
@ -3,7 +3,7 @@
|
|||
|
||||
"""Potsdam datamodule."""
|
||||
|
||||
from typing import Any, Union
|
||||
from typing import Any
|
||||
|
||||
import kornia.augmentation as K
|
||||
|
||||
|
@ -26,7 +26,7 @@ class Potsdam2DDataModule(NonGeoDataModule):
|
|||
def __init__(
|
||||
self,
|
||||
batch_size: int = 64,
|
||||
patch_size: Union[tuple[int, int], int] = 64,
|
||||
patch_size: tuple[int, int] | int = 64,
|
||||
val_split_pct: float = 0.2,
|
||||
num_workers: int = 0,
|
||||
**kwargs: Any,
|
||||
|
|
|
@ -3,7 +3,7 @@
|
|||
|
||||
"""Sentinel-2 and CDL datamodule."""
|
||||
|
||||
from typing import Any, Optional, Union
|
||||
from typing import Any
|
||||
|
||||
import kornia.augmentation as K
|
||||
import torch
|
||||
|
@ -26,8 +26,8 @@ class Sentinel2CDLDataModule(GeoDataModule):
|
|||
def __init__(
|
||||
self,
|
||||
batch_size: int = 64,
|
||||
patch_size: Union[int, tuple[int, int]] = 64,
|
||||
length: Optional[int] = None,
|
||||
patch_size: int | tuple[int, int] = 64,
|
||||
length: int | None = None,
|
||||
num_workers: int = 0,
|
||||
**kwargs: Any,
|
||||
) -> None:
|
||||
|
|
|
@ -3,7 +3,7 @@
|
|||
|
||||
"""Sentinel-2 and NCCM datamodule."""
|
||||
|
||||
from typing import Any, Optional, Union
|
||||
from typing import Any
|
||||
|
||||
import kornia.augmentation as K
|
||||
import torch
|
||||
|
@ -26,8 +26,8 @@ class Sentinel2NCCMDataModule(GeoDataModule):
|
|||
def __init__(
|
||||
self,
|
||||
batch_size: int = 64,
|
||||
patch_size: Union[int, tuple[int, int]] = 64,
|
||||
length: Optional[int] = None,
|
||||
patch_size: int | tuple[int, int] = 64,
|
||||
length: int | None = None,
|
||||
num_workers: int = 0,
|
||||
**kwargs: Any,
|
||||
) -> None:
|
||||
|
|
|
@ -5,7 +5,7 @@
|
|||
"""South America Soybean datamodule."""
|
||||
|
||||
|
||||
from typing import Any, Optional, Union
|
||||
from typing import Any
|
||||
|
||||
import kornia.augmentation as K
|
||||
import torch
|
||||
|
@ -28,8 +28,8 @@ class Sentinel2SouthAmericaSoybeanDataModule(GeoDataModule):
|
|||
def __init__(
|
||||
self,
|
||||
batch_size: int = 64,
|
||||
patch_size: Union[int, tuple[int, int]] = 64,
|
||||
length: Optional[int] = None,
|
||||
patch_size: int | tuple[int, int] = 64,
|
||||
length: int | None = None,
|
||||
num_workers: int = 0,
|
||||
**kwargs: Any,
|
||||
) -> None:
|
||||
|
|
|
@ -3,7 +3,7 @@
|
|||
|
||||
"""SSL4EO datamodule."""
|
||||
|
||||
from typing import Any, Union
|
||||
from typing import Any
|
||||
|
||||
import kornia.augmentation as K
|
||||
from kornia.constants import DataKey, Resample
|
||||
|
@ -23,7 +23,7 @@ class SSL4EOLBenchmarkDataModule(NonGeoDataModule):
|
|||
def __init__(
|
||||
self,
|
||||
batch_size: int = 64,
|
||||
patch_size: Union[int, tuple[int, int]] = 224,
|
||||
patch_size: int | tuple[int, int] = 224,
|
||||
num_workers: int = 0,
|
||||
**kwargs: Any,
|
||||
) -> None:
|
||||
|
|
|
@ -4,8 +4,8 @@
|
|||
"""Common datamodule utilities."""
|
||||
|
||||
import math
|
||||
from collections.abc import Iterable
|
||||
from typing import Any, Callable, Optional, Union
|
||||
from collections.abc import Callable, Iterable
|
||||
from typing import Any
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
|
@ -103,9 +103,9 @@ def collate_fn_detection(batch: list[dict[str, Tensor]]) -> dict[str, Any]:
|
|||
|
||||
|
||||
def dataset_split(
|
||||
dataset: Union[TensorDataset, NonGeoDataset],
|
||||
dataset: TensorDataset | NonGeoDataset,
|
||||
val_pct: float,
|
||||
test_pct: Optional[float] = None,
|
||||
test_pct: float | None = None,
|
||||
) -> list[Subset[Any]]:
|
||||
"""Split a torch Dataset into train/val/test sets.
|
||||
|
||||
|
@ -142,9 +142,9 @@ def dataset_split(
|
|||
|
||||
def group_shuffle_split(
|
||||
groups: Iterable[Any],
|
||||
train_size: Optional[float] = None,
|
||||
test_size: Optional[float] = None,
|
||||
random_state: Optional[int] = None,
|
||||
train_size: float | None = None,
|
||||
test_size: float | None = None,
|
||||
random_state: int | None = None,
|
||||
) -> tuple[list[int], list[int]]:
|
||||
"""Method for performing a single group-wise shuffle split of data.
|
||||
|
||||
|
|
|
@ -3,7 +3,7 @@
|
|||
|
||||
"""Vaihingen datamodule."""
|
||||
|
||||
from typing import Any, Union
|
||||
from typing import Any
|
||||
|
||||
import kornia.augmentation as K
|
||||
|
||||
|
@ -26,7 +26,7 @@ class Vaihingen2DDataModule(NonGeoDataModule):
|
|||
def __init__(
|
||||
self,
|
||||
batch_size: int = 64,
|
||||
patch_size: Union[tuple[int, int], int] = 64,
|
||||
patch_size: tuple[int, int] | int = 64,
|
||||
val_split_pct: float = 0.2,
|
||||
num_workers: int = 0,
|
||||
**kwargs: Any,
|
||||
|
|
|
@ -3,7 +3,7 @@
|
|||
|
||||
"""NWPU VHR-10 datamodule."""
|
||||
|
||||
from typing import Any, Union
|
||||
from typing import Any
|
||||
|
||||
import kornia.augmentation as K
|
||||
import torch
|
||||
|
@ -26,7 +26,7 @@ class VHR10DataModule(NonGeoDataModule):
|
|||
def __init__(
|
||||
self,
|
||||
batch_size: int = 64,
|
||||
patch_size: Union[tuple[int, int], int] = 512,
|
||||
patch_size: tuple[int, int] | int = 512,
|
||||
num_workers: int = 0,
|
||||
val_split_pct: float = 0.2,
|
||||
test_split_pct: float = 0.2,
|
||||
|
|
|
@ -5,7 +5,8 @@
|
|||
|
||||
import glob
|
||||
import os
|
||||
from typing import Callable, Optional, cast
|
||||
from collections.abc import Callable
|
||||
from typing import cast
|
||||
|
||||
import matplotlib.pyplot as plt
|
||||
import numpy as np
|
||||
|
@ -87,7 +88,7 @@ class ADVANCE(NonGeoDataset):
|
|||
def __init__(
|
||||
self,
|
||||
root: str = "data",
|
||||
transforms: Optional[Callable[[dict[str, Tensor]], dict[str, Tensor]]] = None,
|
||||
transforms: Callable[[dict[str, Tensor]], dict[str, Tensor]] | None = None,
|
||||
download: bool = False,
|
||||
checksum: bool = False,
|
||||
) -> None:
|
||||
|
@ -228,7 +229,7 @@ class ADVANCE(NonGeoDataset):
|
|||
self,
|
||||
sample: dict[str, Tensor],
|
||||
show_titles: bool = True,
|
||||
suptitle: Optional[str] = None,
|
||||
suptitle: str | None = None,
|
||||
) -> Figure:
|
||||
"""Plot a sample from the dataset.
|
||||
|
||||
|
|
|
@ -5,8 +5,8 @@
|
|||
|
||||
import json
|
||||
import os
|
||||
from collections.abc import Iterable
|
||||
from typing import Any, Callable, Optional, Union
|
||||
from collections.abc import Callable, Iterable
|
||||
from typing import Any
|
||||
|
||||
import matplotlib.pyplot as plt
|
||||
from matplotlib.figure import Figure
|
||||
|
@ -56,10 +56,10 @@ class AbovegroundLiveWoodyBiomassDensity(RasterDataset):
|
|||
|
||||
def __init__(
|
||||
self,
|
||||
paths: Union[str, Iterable[str]] = "data",
|
||||
crs: Optional[CRS] = None,
|
||||
res: Optional[float] = None,
|
||||
transforms: Optional[Callable[[dict[str, Any]], dict[str, Any]]] = None,
|
||||
paths: str | Iterable[str] = "data",
|
||||
crs: CRS | None = None,
|
||||
res: float | None = None,
|
||||
transforms: Callable[[dict[str, Any]], dict[str, Any]] | None = None,
|
||||
download: bool = False,
|
||||
cache: bool = True,
|
||||
) -> None:
|
||||
|
@ -121,7 +121,7 @@ class AbovegroundLiveWoodyBiomassDensity(RasterDataset):
|
|||
self,
|
||||
sample: dict[str, Any],
|
||||
show_titles: bool = True,
|
||||
suptitle: Optional[str] = None,
|
||||
suptitle: str | None = None,
|
||||
) -> Figure:
|
||||
"""Plot a sample from the dataset.
|
||||
|
||||
|
|
|
@ -5,8 +5,8 @@
|
|||
|
||||
import os
|
||||
import re
|
||||
from collections.abc import Iterable, Sequence
|
||||
from typing import Any, Callable, Optional, Union, cast
|
||||
from collections.abc import Callable, Iterable, Sequence
|
||||
from typing import Any, cast
|
||||
|
||||
import matplotlib.pyplot as plt
|
||||
import torch
|
||||
|
@ -113,11 +113,11 @@ class AgriFieldNet(RasterDataset):
|
|||
|
||||
def __init__(
|
||||
self,
|
||||
paths: Union[str, Iterable[str]] = "data",
|
||||
crs: Optional[CRS] = None,
|
||||
paths: str | Iterable[str] = "data",
|
||||
crs: CRS | None = None,
|
||||
classes: list[int] = list(cmap.keys()),
|
||||
bands: Sequence[str] = all_bands,
|
||||
transforms: Optional[Callable[[dict[str, Tensor]], dict[str, Tensor]]] = None,
|
||||
transforms: Callable[[dict[str, Tensor]], dict[str, Tensor]] | None = None,
|
||||
cache: bool = True,
|
||||
) -> None:
|
||||
"""Initialize a new AgriFieldNet dataset instance.
|
||||
|
@ -218,7 +218,7 @@ class AgriFieldNet(RasterDataset):
|
|||
self,
|
||||
sample: dict[str, Tensor],
|
||||
show_titles: bool = True,
|
||||
suptitle: Optional[str] = None,
|
||||
suptitle: str | None = None,
|
||||
) -> Figure:
|
||||
"""Plot a sample from the dataset.
|
||||
|
||||
|
|
|
@ -3,7 +3,7 @@
|
|||
|
||||
"""Airphen dataset."""
|
||||
|
||||
from typing import Any, Optional
|
||||
from typing import Any
|
||||
|
||||
import matplotlib.pyplot as plt
|
||||
from matplotlib.figure import Figure
|
||||
|
@ -46,7 +46,7 @@ class Airphen(RasterDataset):
|
|||
self,
|
||||
sample: dict[str, Any],
|
||||
show_titles: bool = True,
|
||||
suptitle: Optional[str] = None,
|
||||
suptitle: str | None = None,
|
||||
) -> Figure:
|
||||
"""Plot a sample from the dataset.
|
||||
|
||||
|
|
|
@ -3,7 +3,8 @@
|
|||
|
||||
"""Aster Global Digital Elevation Model dataset."""
|
||||
|
||||
from typing import Any, Callable, Optional, Union
|
||||
from collections.abc import Callable
|
||||
from typing import Any
|
||||
|
||||
import matplotlib.pyplot as plt
|
||||
from matplotlib.figure import Figure
|
||||
|
@ -46,10 +47,10 @@ class AsterGDEM(RasterDataset):
|
|||
|
||||
def __init__(
|
||||
self,
|
||||
paths: Union[str, list[str]] = "data",
|
||||
crs: Optional[CRS] = None,
|
||||
res: Optional[float] = None,
|
||||
transforms: Optional[Callable[[dict[str, Any]], dict[str, Any]]] = None,
|
||||
paths: str | list[str] = "data",
|
||||
crs: CRS | None = None,
|
||||
res: float | None = None,
|
||||
transforms: Callable[[dict[str, Any]], dict[str, Any]] | None = None,
|
||||
cache: bool = True,
|
||||
) -> None:
|
||||
"""Initialize a new Dataset instance.
|
||||
|
@ -89,7 +90,7 @@ class AsterGDEM(RasterDataset):
|
|||
self,
|
||||
sample: dict[str, Any],
|
||||
show_titles: bool = True,
|
||||
suptitle: Optional[str] = None,
|
||||
suptitle: str | None = None,
|
||||
) -> Figure:
|
||||
"""Plot a sample from the dataset.
|
||||
|
||||
|
|
|
@ -5,8 +5,8 @@
|
|||
|
||||
import json
|
||||
import os
|
||||
from collections.abc import Callable
|
||||
from functools import lru_cache
|
||||
from typing import Callable, Optional
|
||||
|
||||
import matplotlib.pyplot as plt
|
||||
import numpy as np
|
||||
|
@ -182,9 +182,9 @@ class BeninSmallHolderCashews(NonGeoDataset):
|
|||
chip_size: int = 256,
|
||||
stride: int = 128,
|
||||
bands: tuple[str, ...] = all_bands,
|
||||
transforms: Optional[Callable[[dict[str, Tensor]], dict[str, Tensor]]] = None,
|
||||
transforms: Callable[[dict[str, Tensor]], dict[str, Tensor]] | None = None,
|
||||
download: bool = False,
|
||||
api_key: Optional[str] = None,
|
||||
api_key: str | None = None,
|
||||
checksum: bool = False,
|
||||
verbose: bool = False,
|
||||
) -> None:
|
||||
|
@ -408,7 +408,7 @@ class BeninSmallHolderCashews(NonGeoDataset):
|
|||
|
||||
return images and targets
|
||||
|
||||
def _download(self, api_key: Optional[str] = None) -> None:
|
||||
def _download(self, api_key: str | None = None) -> None:
|
||||
"""Download the dataset and extract it.
|
||||
|
||||
Args:
|
||||
|
@ -434,7 +434,7 @@ class BeninSmallHolderCashews(NonGeoDataset):
|
|||
sample: dict[str, Tensor],
|
||||
show_titles: bool = True,
|
||||
time_step: int = 0,
|
||||
suptitle: Optional[str] = None,
|
||||
suptitle: str | None = None,
|
||||
) -> Figure:
|
||||
"""Plot a sample from the dataset.
|
||||
|
||||
|
|
|
@ -6,7 +6,7 @@
|
|||
import glob
|
||||
import json
|
||||
import os
|
||||
from typing import Callable, Optional
|
||||
from collections.abc import Callable
|
||||
|
||||
import matplotlib.pyplot as plt
|
||||
import numpy as np
|
||||
|
@ -275,7 +275,7 @@ class BigEarthNet(NonGeoDataset):
|
|||
split: str = "train",
|
||||
bands: str = "all",
|
||||
num_classes: int = 19,
|
||||
transforms: Optional[Callable[[dict[str, Tensor]], dict[str, Tensor]]] = None,
|
||||
transforms: Callable[[dict[str, Tensor]], dict[str, Tensor]] | None = None,
|
||||
download: bool = False,
|
||||
checksum: bool = False,
|
||||
) -> None:
|
||||
|
@ -533,7 +533,7 @@ class BigEarthNet(NonGeoDataset):
|
|||
self,
|
||||
sample: dict[str, Tensor],
|
||||
show_titles: bool = True,
|
||||
suptitle: Optional[str] = None,
|
||||
suptitle: str | None = None,
|
||||
) -> Figure:
|
||||
"""Plot a sample from the dataset.
|
||||
|
||||
|
|
|
@ -5,7 +5,6 @@
|
|||
|
||||
import os
|
||||
from collections.abc import Sequence
|
||||
from typing import Optional
|
||||
|
||||
import matplotlib.pyplot as plt
|
||||
import numpy as np
|
||||
|
@ -219,7 +218,7 @@ class BioMassters(NonGeoDataset):
|
|||
self,
|
||||
sample: dict[str, Tensor],
|
||||
show_titles: bool = True,
|
||||
suptitle: Optional[str] = None,
|
||||
suptitle: str | None = None,
|
||||
) -> Figure:
|
||||
"""Plot a sample from the dataset.
|
||||
|
||||
|
|
|
@ -4,8 +4,8 @@
|
|||
"""Canadian Building Footprints dataset."""
|
||||
|
||||
import os
|
||||
from collections.abc import Iterable
|
||||
from typing import Any, Callable, Optional, Union
|
||||
from collections.abc import Callable, Iterable
|
||||
from typing import Any
|
||||
|
||||
import matplotlib.pyplot as plt
|
||||
from matplotlib.figure import Figure
|
||||
|
@ -61,10 +61,10 @@ class CanadianBuildingFootprints(VectorDataset):
|
|||
|
||||
def __init__(
|
||||
self,
|
||||
paths: Union[str, Iterable[str]] = "data",
|
||||
crs: Optional[CRS] = None,
|
||||
paths: str | Iterable[str] = "data",
|
||||
crs: CRS | None = None,
|
||||
res: float = 0.00001,
|
||||
transforms: Optional[Callable[[dict[str, Any]], dict[str, Any]]] = None,
|
||||
transforms: Callable[[dict[str, Any]], dict[str, Any]] | None = None,
|
||||
download: bool = False,
|
||||
checksum: bool = False,
|
||||
) -> None:
|
||||
|
@ -127,7 +127,7 @@ class CanadianBuildingFootprints(VectorDataset):
|
|||
self,
|
||||
sample: dict[str, Any],
|
||||
show_titles: bool = True,
|
||||
suptitle: Optional[str] = None,
|
||||
suptitle: str | None = None,
|
||||
) -> Figure:
|
||||
"""Plot a sample from the dataset.
|
||||
|
||||
|
|
|
@ -4,8 +4,8 @@
|
|||
"""CDL dataset."""
|
||||
|
||||
import os
|
||||
from collections.abc import Iterable
|
||||
from typing import Any, Callable, Optional, Union
|
||||
from collections.abc import Callable, Iterable
|
||||
from typing import Any
|
||||
|
||||
import matplotlib.pyplot as plt
|
||||
import torch
|
||||
|
@ -206,12 +206,12 @@ class CDL(RasterDataset):
|
|||
|
||||
def __init__(
|
||||
self,
|
||||
paths: Union[str, Iterable[str]] = "data",
|
||||
crs: Optional[CRS] = None,
|
||||
res: Optional[float] = None,
|
||||
paths: str | Iterable[str] = "data",
|
||||
crs: CRS | None = None,
|
||||
res: float | None = None,
|
||||
years: list[int] = [2023],
|
||||
classes: list[int] = list(cmap.keys()),
|
||||
transforms: Optional[Callable[[dict[str, Any]], dict[str, Any]]] = None,
|
||||
transforms: Callable[[dict[str, Any]], dict[str, Any]] | None = None,
|
||||
cache: bool = True,
|
||||
download: bool = False,
|
||||
checksum: bool = False,
|
||||
|
@ -336,7 +336,7 @@ class CDL(RasterDataset):
|
|||
self,
|
||||
sample: dict[str, Any],
|
||||
show_titles: bool = True,
|
||||
suptitle: Optional[str] = None,
|
||||
suptitle: str | None = None,
|
||||
) -> Figure:
|
||||
"""Plot a sample from the dataset.
|
||||
|
||||
|
|
|
@ -4,7 +4,7 @@
|
|||
"""ChaBuD dataset."""
|
||||
|
||||
import os
|
||||
from typing import Callable, Optional
|
||||
from collections.abc import Callable
|
||||
|
||||
import matplotlib.pyplot as plt
|
||||
import numpy as np
|
||||
|
@ -77,7 +77,7 @@ class ChaBuD(NonGeoDataset):
|
|||
root: str = "data",
|
||||
split: str = "train",
|
||||
bands: list[str] = all_bands,
|
||||
transforms: Optional[Callable[[dict[str, Tensor]], dict[str, Tensor]]] = None,
|
||||
transforms: Callable[[dict[str, Tensor]], dict[str, Tensor]] | None = None,
|
||||
download: bool = False,
|
||||
checksum: bool = False,
|
||||
) -> None:
|
||||
|
@ -235,7 +235,7 @@ class ChaBuD(NonGeoDataset):
|
|||
self,
|
||||
sample: dict[str, Tensor],
|
||||
show_titles: bool = True,
|
||||
suptitle: Optional[str] = None,
|
||||
suptitle: str | None = None,
|
||||
) -> Figure:
|
||||
"""Plot a sample from the dataset.
|
||||
|
||||
|
|
|
@ -6,8 +6,8 @@
|
|||
import abc
|
||||
import os
|
||||
import sys
|
||||
from collections.abc import Iterable, Sequence
|
||||
from typing import Any, Callable, Optional, Union, cast
|
||||
from collections.abc import Callable, Iterable, Sequence
|
||||
from typing import Any, cast
|
||||
|
||||
import fiona
|
||||
import matplotlib.pyplot as plt
|
||||
|
@ -90,10 +90,10 @@ class Chesapeake(RasterDataset, abc.ABC):
|
|||
|
||||
def __init__(
|
||||
self,
|
||||
paths: Union[str, Iterable[str]] = "data",
|
||||
crs: Optional[CRS] = None,
|
||||
res: Optional[float] = None,
|
||||
transforms: Optional[Callable[[dict[str, Any]], dict[str, Any]]] = None,
|
||||
paths: str | Iterable[str] = "data",
|
||||
crs: CRS | None = None,
|
||||
res: float | None = None,
|
||||
transforms: Callable[[dict[str, Any]], dict[str, Any]] | None = None,
|
||||
cache: bool = True,
|
||||
download: bool = False,
|
||||
checksum: bool = False,
|
||||
|
@ -170,7 +170,7 @@ class Chesapeake(RasterDataset, abc.ABC):
|
|||
self,
|
||||
sample: dict[str, Any],
|
||||
show_titles: bool = True,
|
||||
suptitle: Optional[str] = None,
|
||||
suptitle: str | None = None,
|
||||
) -> Figure:
|
||||
"""Plot a sample from the dataset.
|
||||
|
||||
|
@ -512,7 +512,7 @@ class ChesapeakeCVPR(GeoDataset):
|
|||
root: str = "data",
|
||||
splits: Sequence[str] = ["de-train"],
|
||||
layers: Sequence[str] = ["naip-new", "lc"],
|
||||
transforms: Optional[Callable[[dict[str, Any]], dict[str, Any]]] = None,
|
||||
transforms: Callable[[dict[str, Any]], dict[str, Any]] | None = None,
|
||||
cache: bool = True,
|
||||
download: bool = False,
|
||||
checksum: bool = False,
|
||||
|
@ -711,7 +711,7 @@ class ChesapeakeCVPR(GeoDataset):
|
|||
self,
|
||||
sample: dict[str, Tensor],
|
||||
show_titles: bool = True,
|
||||
suptitle: Optional[str] = None,
|
||||
suptitle: str | None = None,
|
||||
) -> Figure:
|
||||
"""Plot a sample from the dataset.
|
||||
|
||||
|
|
|
@ -5,8 +5,8 @@
|
|||
|
||||
import json
|
||||
import os
|
||||
from collections.abc import Sequence
|
||||
from typing import Any, Callable, Optional
|
||||
from collections.abc import Callable, Sequence
|
||||
from typing import Any
|
||||
|
||||
import matplotlib.pyplot as plt
|
||||
import numpy as np
|
||||
|
@ -111,9 +111,9 @@ class CloudCoverDetection(NonGeoDataset):
|
|||
root: str = "data",
|
||||
split: str = "train",
|
||||
bands: Sequence[str] = band_names,
|
||||
transforms: Optional[Callable[[dict[str, Tensor]], dict[str, Tensor]]] = None,
|
||||
transforms: Callable[[dict[str, Tensor]], dict[str, Tensor]] | None = None,
|
||||
download: bool = False,
|
||||
api_key: Optional[str] = None,
|
||||
api_key: str | None = None,
|
||||
checksum: bool = False,
|
||||
) -> None:
|
||||
"""Initiatlize a new Cloud Cover Detection Dataset instance.
|
||||
|
@ -329,7 +329,7 @@ class CloudCoverDetection(NonGeoDataset):
|
|||
|
||||
return images and targets
|
||||
|
||||
def _download(self, api_key: Optional[str] = None) -> None:
|
||||
def _download(self, api_key: str | None = None) -> None:
|
||||
"""Download the dataset and extract it.
|
||||
|
||||
Args:
|
||||
|
@ -355,7 +355,7 @@ class CloudCoverDetection(NonGeoDataset):
|
|||
self,
|
||||
sample: dict[str, Tensor],
|
||||
show_titles: bool = True,
|
||||
suptitle: Optional[str] = None,
|
||||
suptitle: str | None = None,
|
||||
) -> Figure:
|
||||
"""Plot a sample from the dataset.
|
||||
|
||||
|
|
|
@ -4,7 +4,8 @@
|
|||
"""CMS Global Mangrove Canopy dataset."""
|
||||
|
||||
import os
|
||||
from typing import Any, Callable, Optional, Union
|
||||
from collections.abc import Callable
|
||||
from typing import Any
|
||||
|
||||
import matplotlib.pyplot as plt
|
||||
from matplotlib.figure import Figure
|
||||
|
@ -167,12 +168,12 @@ class CMSGlobalMangroveCanopy(RasterDataset):
|
|||
|
||||
def __init__(
|
||||
self,
|
||||
paths: Union[str, list[str]] = "data",
|
||||
crs: Optional[CRS] = None,
|
||||
res: Optional[float] = None,
|
||||
paths: str | list[str] = "data",
|
||||
crs: CRS | None = None,
|
||||
res: float | None = None,
|
||||
measurement: str = "agb",
|
||||
country: str = all_countries[0],
|
||||
transforms: Optional[Callable[[dict[str, Any]], dict[str, Any]]] = None,
|
||||
transforms: Callable[[dict[str, Any]], dict[str, Any]] | None = None,
|
||||
cache: bool = True,
|
||||
checksum: bool = False,
|
||||
) -> None:
|
||||
|
@ -250,7 +251,7 @@ class CMSGlobalMangroveCanopy(RasterDataset):
|
|||
self,
|
||||
sample: dict[str, Any],
|
||||
show_titles: bool = True,
|
||||
suptitle: Optional[str] = None,
|
||||
suptitle: str | None = None,
|
||||
) -> Figure:
|
||||
"""Plot a sample from the dataset.
|
||||
|
||||
|
|
|
@ -6,7 +6,8 @@
|
|||
import abc
|
||||
import csv
|
||||
import os
|
||||
from typing import Callable, Optional, cast
|
||||
from collections.abc import Callable
|
||||
from typing import cast
|
||||
|
||||
import matplotlib.pyplot as plt
|
||||
import numpy as np
|
||||
|
@ -65,7 +66,7 @@ class COWC(NonGeoDataset, abc.ABC):
|
|||
self,
|
||||
root: str = "data",
|
||||
split: str = "train",
|
||||
transforms: Optional[Callable[[dict[str, Tensor]], dict[str, Tensor]]] = None,
|
||||
transforms: Callable[[dict[str, Tensor]], dict[str, Tensor]] | None = None,
|
||||
download: bool = False,
|
||||
checksum: bool = False,
|
||||
) -> None:
|
||||
|
@ -192,7 +193,7 @@ class COWC(NonGeoDataset, abc.ABC):
|
|||
self,
|
||||
sample: dict[str, Tensor],
|
||||
show_titles: bool = True,
|
||||
suptitle: Optional[str] = None,
|
||||
suptitle: str | None = None,
|
||||
) -> Figure:
|
||||
"""Plot a sample from the dataset.
|
||||
|
||||
|
|
|
@ -6,7 +6,7 @@
|
|||
import glob
|
||||
import json
|
||||
import os
|
||||
from typing import Callable, Optional
|
||||
from collections.abc import Callable
|
||||
|
||||
import matplotlib.pyplot as plt
|
||||
import numpy as np
|
||||
|
@ -96,7 +96,7 @@ class CropHarvest(NonGeoDataset):
|
|||
def __init__(
|
||||
self,
|
||||
root: str = "data",
|
||||
transforms: Optional[Callable[[dict[str, Tensor]], dict[str, Tensor]]] = None,
|
||||
transforms: Callable[[dict[str, Tensor]], dict[str, Tensor]] | None = None,
|
||||
download: bool = False,
|
||||
checksum: bool = False,
|
||||
) -> None:
|
||||
|
@ -293,7 +293,7 @@ class CropHarvest(NonGeoDataset):
|
|||
features_path = os.path.join(self.root, self.file_dict["features"]["filename"])
|
||||
extract_archive(features_path)
|
||||
|
||||
def plot(self, sample: dict[str, Tensor], subtitle: Optional[str] = None) -> Figure:
|
||||
def plot(self, sample: dict[str, Tensor], subtitle: str | None = None) -> Figure:
|
||||
"""Plot a sample from the dataset using bands for Agriculture RGB composite.
|
||||
|
||||
Args:
|
||||
|
|
|
@ -5,8 +5,8 @@
|
|||
|
||||
import csv
|
||||
import os
|
||||
from collections.abc import Callable
|
||||
from functools import lru_cache
|
||||
from typing import Callable, Optional
|
||||
|
||||
import matplotlib.pyplot as plt
|
||||
import numpy as np
|
||||
|
@ -125,9 +125,9 @@ class CV4AKenyaCropType(NonGeoDataset):
|
|||
chip_size: int = 256,
|
||||
stride: int = 128,
|
||||
bands: tuple[str, ...] = band_names,
|
||||
transforms: Optional[Callable[[dict[str, Tensor]], dict[str, Tensor]]] = None,
|
||||
transforms: Callable[[dict[str, Tensor]], dict[str, Tensor]] | None = None,
|
||||
download: bool = False,
|
||||
api_key: Optional[str] = None,
|
||||
api_key: str | None = None,
|
||||
checksum: bool = False,
|
||||
verbose: bool = False,
|
||||
) -> None:
|
||||
|
@ -388,7 +388,7 @@ class CV4AKenyaCropType(NonGeoDataset):
|
|||
|
||||
return train_field_ids, test_field_ids
|
||||
|
||||
def _download(self, api_key: Optional[str] = None) -> None:
|
||||
def _download(self, api_key: str | None = None) -> None:
|
||||
"""Download the dataset and extract it.
|
||||
|
||||
Args:
|
||||
|
@ -411,7 +411,7 @@ class CV4AKenyaCropType(NonGeoDataset):
|
|||
sample: dict[str, Tensor],
|
||||
show_titles: bool = True,
|
||||
time_step: int = 0,
|
||||
suptitle: Optional[str] = None,
|
||||
suptitle: str | None = None,
|
||||
) -> Figure:
|
||||
"""Plot a sample from the dataset.
|
||||
|
||||
|
|
|
@ -5,8 +5,9 @@
|
|||
|
||||
import json
|
||||
import os
|
||||
from collections.abc import Callable
|
||||
from functools import lru_cache
|
||||
from typing import Any, Callable, Optional
|
||||
from typing import Any
|
||||
|
||||
import matplotlib.pyplot as plt
|
||||
import numpy as np
|
||||
|
@ -73,9 +74,9 @@ class TropicalCyclone(NonGeoDataset):
|
|||
self,
|
||||
root: str = "data",
|
||||
split: str = "train",
|
||||
transforms: Optional[Callable[[dict[str, Any]], dict[str, Any]]] = None,
|
||||
transforms: Callable[[dict[str, Any]], dict[str, Any]] | None = None,
|
||||
download: bool = False,
|
||||
api_key: Optional[str] = None,
|
||||
api_key: str | None = None,
|
||||
checksum: bool = False,
|
||||
) -> None:
|
||||
"""Initialize a new Tropical Cyclone Wind Estimation Competition Dataset.
|
||||
|
@ -204,7 +205,7 @@ class TropicalCyclone(NonGeoDataset):
|
|||
return False
|
||||
return True
|
||||
|
||||
def _download(self, api_key: Optional[str] = None) -> None:
|
||||
def _download(self, api_key: str | None = None) -> None:
|
||||
"""Download the dataset and extract it.
|
||||
|
||||
Args:
|
||||
|
@ -227,7 +228,7 @@ class TropicalCyclone(NonGeoDataset):
|
|||
self,
|
||||
sample: dict[str, Any],
|
||||
show_titles: bool = True,
|
||||
suptitle: Optional[str] = None,
|
||||
suptitle: str | None = None,
|
||||
) -> Figure:
|
||||
"""Plot a sample from the dataset.
|
||||
|
||||
|
|
|
@ -4,7 +4,7 @@
|
|||
"""DeepGlobe Land Cover Classification Challenge dataset."""
|
||||
|
||||
import os
|
||||
from typing import Callable, Optional
|
||||
from collections.abc import Callable
|
||||
|
||||
import matplotlib.pyplot as plt
|
||||
import numpy as np
|
||||
|
@ -102,7 +102,7 @@ class DeepGlobeLandCover(NonGeoDataset):
|
|||
self,
|
||||
root: str = "data",
|
||||
split: str = "train",
|
||||
transforms: Optional[Callable[[dict[str, Tensor]], dict[str, Tensor]]] = None,
|
||||
transforms: Callable[[dict[str, Tensor]], dict[str, Tensor]] | None = None,
|
||||
checksum: bool = False,
|
||||
) -> None:
|
||||
"""Initialize a new DeepGlobeLandCover dataset instance.
|
||||
|
@ -229,7 +229,7 @@ class DeepGlobeLandCover(NonGeoDataset):
|
|||
self,
|
||||
sample: dict[str, Tensor],
|
||||
show_titles: bool = True,
|
||||
suptitle: Optional[str] = None,
|
||||
suptitle: str | None = None,
|
||||
alpha: float = 0.5,
|
||||
) -> Figure:
|
||||
"""Plot a sample from the dataset.
|
||||
|
|
|
@ -5,8 +5,7 @@
|
|||
|
||||
import glob
|
||||
import os
|
||||
from collections.abc import Sequence
|
||||
from typing import Callable, Optional
|
||||
from collections.abc import Callable, Sequence
|
||||
|
||||
import matplotlib.pyplot as plt
|
||||
import numpy as np
|
||||
|
@ -144,7 +143,7 @@ class DFC2022(NonGeoDataset):
|
|||
self,
|
||||
root: str = "data",
|
||||
split: str = "train",
|
||||
transforms: Optional[Callable[[dict[str, Tensor]], dict[str, Tensor]]] = None,
|
||||
transforms: Callable[[dict[str, Tensor]], dict[str, Tensor]] | None = None,
|
||||
checksum: bool = False,
|
||||
) -> None:
|
||||
"""Initialize a new DFC2022 dataset instance.
|
||||
|
@ -229,7 +228,7 @@ class DFC2022(NonGeoDataset):
|
|||
|
||||
return files
|
||||
|
||||
def _load_image(self, path: str, shape: Optional[Sequence[int]] = None) -> Tensor:
|
||||
def _load_image(self, path: str, shape: Sequence[int] | None = None) -> Tensor:
|
||||
"""Load a single image.
|
||||
|
||||
Args:
|
||||
|
@ -296,7 +295,7 @@ class DFC2022(NonGeoDataset):
|
|||
self,
|
||||
sample: dict[str, Tensor],
|
||||
show_titles: bool = True,
|
||||
suptitle: Optional[str] = None,
|
||||
suptitle: str | None = None,
|
||||
) -> Figure:
|
||||
"""Plot a sample from the dataset.
|
||||
|
||||
|
|
|
@ -5,8 +5,8 @@
|
|||
|
||||
import os
|
||||
import sys
|
||||
from collections.abc import Sequence
|
||||
from typing import Any, Callable, Optional, cast
|
||||
from collections.abc import Callable, Sequence
|
||||
from typing import Any, cast
|
||||
|
||||
import fiona
|
||||
import matplotlib.pyplot as plt
|
||||
|
@ -255,7 +255,7 @@ class EnviroAtlas(GeoDataset):
|
|||
root: str = "data",
|
||||
splits: Sequence[str] = ["pittsburgh_pa-2010_1m-train"],
|
||||
layers: Sequence[str] = ["naip", "prior"],
|
||||
transforms: Optional[Callable[[dict[str, Any]], dict[str, Any]]] = None,
|
||||
transforms: Callable[[dict[str, Any]], dict[str, Any]] | None = None,
|
||||
prior_as_input: bool = False,
|
||||
cache: bool = True,
|
||||
download: bool = False,
|
||||
|
@ -445,7 +445,7 @@ class EnviroAtlas(GeoDataset):
|
|||
self,
|
||||
sample: dict[str, Any],
|
||||
show_titles: bool = True,
|
||||
suptitle: Optional[str] = None,
|
||||
suptitle: str | None = None,
|
||||
) -> Figure:
|
||||
"""Plot a sample from the dataset.
|
||||
|
||||
|
|
|
@ -5,8 +5,8 @@
|
|||
|
||||
import glob
|
||||
import os
|
||||
from collections.abc import Iterable
|
||||
from typing import Any, Callable, Optional, Union
|
||||
from collections.abc import Callable, Iterable
|
||||
from typing import Any
|
||||
|
||||
import matplotlib.pyplot as plt
|
||||
from matplotlib.figure import Figure
|
||||
|
@ -68,10 +68,10 @@ class Esri2020(RasterDataset):
|
|||
|
||||
def __init__(
|
||||
self,
|
||||
paths: Union[str, Iterable[str]] = "data",
|
||||
crs: Optional[CRS] = None,
|
||||
res: Optional[float] = None,
|
||||
transforms: Optional[Callable[[dict[str, Any]], dict[str, Any]]] = None,
|
||||
paths: str | Iterable[str] = "data",
|
||||
crs: CRS | None = None,
|
||||
res: float | None = None,
|
||||
transforms: Callable[[dict[str, Any]], dict[str, Any]] | None = None,
|
||||
cache: bool = True,
|
||||
download: bool = False,
|
||||
checksum: bool = False,
|
||||
|
@ -138,7 +138,7 @@ class Esri2020(RasterDataset):
|
|||
self,
|
||||
sample: dict[str, Any],
|
||||
show_titles: bool = True,
|
||||
suptitle: Optional[str] = None,
|
||||
suptitle: str | None = None,
|
||||
) -> Figure:
|
||||
"""Plot a sample from the dataset.
|
||||
|
||||
|
|
|
@ -5,7 +5,7 @@
|
|||
|
||||
import glob
|
||||
import os
|
||||
from typing import Callable, Optional
|
||||
from collections.abc import Callable
|
||||
|
||||
import matplotlib.pyplot as plt
|
||||
import numpy as np
|
||||
|
@ -82,7 +82,7 @@ class ETCI2021(NonGeoDataset):
|
|||
self,
|
||||
root: str = "data",
|
||||
split: str = "train",
|
||||
transforms: Optional[Callable[[dict[str, Tensor]], dict[str, Tensor]]] = None,
|
||||
transforms: Callable[[dict[str, Tensor]], dict[str, Tensor]] | None = None,
|
||||
download: bool = False,
|
||||
checksum: bool = False,
|
||||
) -> None:
|
||||
|
@ -255,7 +255,7 @@ class ETCI2021(NonGeoDataset):
|
|||
self,
|
||||
sample: dict[str, Tensor],
|
||||
show_titles: bool = True,
|
||||
suptitle: Optional[str] = None,
|
||||
suptitle: str | None = None,
|
||||
) -> Figure:
|
||||
"""Plot a sample from the dataset.
|
||||
|
||||
|
|
|
@ -5,8 +5,8 @@
|
|||
|
||||
import glob
|
||||
import os
|
||||
from collections.abc import Iterable
|
||||
from typing import Any, Callable, Optional, Union
|
||||
from collections.abc import Callable, Iterable
|
||||
from typing import Any
|
||||
|
||||
import matplotlib.pyplot as plt
|
||||
from matplotlib.figure import Figure
|
||||
|
@ -83,10 +83,10 @@ class EUDEM(RasterDataset):
|
|||
|
||||
def __init__(
|
||||
self,
|
||||
paths: Union[str, Iterable[str]] = "data",
|
||||
crs: Optional[CRS] = None,
|
||||
res: Optional[float] = None,
|
||||
transforms: Optional[Callable[[dict[str, Any]], dict[str, Any]]] = None,
|
||||
paths: str | Iterable[str] = "data",
|
||||
crs: CRS | None = None,
|
||||
res: float | None = None,
|
||||
transforms: Callable[[dict[str, Any]], dict[str, Any]] | None = None,
|
||||
cache: bool = True,
|
||||
checksum: bool = False,
|
||||
) -> None:
|
||||
|
@ -140,7 +140,7 @@ class EUDEM(RasterDataset):
|
|||
self,
|
||||
sample: dict[str, Any],
|
||||
show_titles: bool = True,
|
||||
suptitle: Optional[str] = None,
|
||||
suptitle: str | None = None,
|
||||
) -> Figure:
|
||||
"""Plot a sample from the dataset.
|
||||
|
||||
|
|
|
@ -5,8 +5,8 @@
|
|||
|
||||
import csv
|
||||
import os
|
||||
from collections.abc import Iterable
|
||||
from typing import Any, Callable, Optional, Union
|
||||
from collections.abc import Callable, Iterable
|
||||
from typing import Any
|
||||
|
||||
import fiona
|
||||
import matplotlib.pyplot as plt
|
||||
|
@ -88,11 +88,11 @@ class EuroCrops(VectorDataset):
|
|||
|
||||
def __init__(
|
||||
self,
|
||||
paths: Union[str, Iterable[str]] = "data",
|
||||
paths: str | Iterable[str] = "data",
|
||||
crs: CRS = CRS.from_epsg(4326),
|
||||
res: float = 0.00001,
|
||||
classes: Optional[list[str]] = None,
|
||||
transforms: Optional[Callable[[dict[str, Any]], dict[str, Any]]] = None,
|
||||
classes: list[str] | None = None,
|
||||
transforms: Callable[[dict[str, Any]], dict[str, Any]] | None = None,
|
||||
download: bool = False,
|
||||
checksum: bool = False,
|
||||
) -> None:
|
||||
|
@ -166,7 +166,7 @@ class EuroCrops(VectorDataset):
|
|||
self.base_url + fname, self.paths, md5=md5 if self.checksum else None
|
||||
)
|
||||
|
||||
def _load_class_map(self, classes: Optional[list[str]]) -> None:
|
||||
def _load_class_map(self, classes: list[str] | None) -> None:
|
||||
"""Load map from HCAT class codes to class indices.
|
||||
|
||||
If classes is provided, then we simply use those codes. Otherwise, we load
|
||||
|
@ -221,7 +221,7 @@ class EuroCrops(VectorDataset):
|
|||
self,
|
||||
sample: dict[str, Any],
|
||||
show_titles: bool = True,
|
||||
suptitle: Optional[str] = None,
|
||||
suptitle: str | None = None,
|
||||
) -> Figure:
|
||||
"""Plot a sample from the dataset.
|
||||
|
||||
|
|
|
@ -4,8 +4,8 @@
|
|||
"""EuroSAT dataset."""
|
||||
|
||||
import os
|
||||
from collections.abc import Sequence
|
||||
from typing import Callable, Optional, cast
|
||||
from collections.abc import Callable, Sequence
|
||||
from typing import cast
|
||||
|
||||
import matplotlib.pyplot as plt
|
||||
import numpy as np
|
||||
|
@ -106,7 +106,7 @@ class EuroSAT(NonGeoClassificationDataset):
|
|||
root: str = "data",
|
||||
split: str = "train",
|
||||
bands: Sequence[str] = BAND_SETS["all"],
|
||||
transforms: Optional[Callable[[dict[str, Tensor]], dict[str, Tensor]]] = None,
|
||||
transforms: Callable[[dict[str, Tensor]], dict[str, Tensor]] | None = None,
|
||||
download: bool = False,
|
||||
checksum: bool = False,
|
||||
) -> None:
|
||||
|
@ -247,7 +247,7 @@ class EuroSAT(NonGeoClassificationDataset):
|
|||
self,
|
||||
sample: dict[str, Tensor],
|
||||
show_titles: bool = True,
|
||||
suptitle: Optional[str] = None,
|
||||
suptitle: str | None = None,
|
||||
) -> Figure:
|
||||
"""Plot a sample from the dataset.
|
||||
|
||||
|
|
|
@ -5,7 +5,8 @@
|
|||
|
||||
import glob
|
||||
import os
|
||||
from typing import Any, Callable, Optional, cast
|
||||
from collections.abc import Callable
|
||||
from typing import Any, cast
|
||||
from xml.etree.ElementTree import Element, parse
|
||||
|
||||
import matplotlib.patches as patches
|
||||
|
@ -230,7 +231,7 @@ class FAIR1M(NonGeoDataset):
|
|||
self,
|
||||
root: str = "data",
|
||||
split: str = "train",
|
||||
transforms: Optional[Callable[[dict[str, Tensor]], dict[str, Tensor]]] = None,
|
||||
transforms: Callable[[dict[str, Tensor]], dict[str, Tensor]] | None = None,
|
||||
download: bool = False,
|
||||
checksum: bool = False,
|
||||
) -> None:
|
||||
|
@ -382,7 +383,7 @@ class FAIR1M(NonGeoDataset):
|
|||
self,
|
||||
sample: dict[str, Tensor],
|
||||
show_titles: bool = True,
|
||||
suptitle: Optional[str] = None,
|
||||
suptitle: str | None = None,
|
||||
) -> Figure:
|
||||
"""Plot a sample from the dataset.
|
||||
|
||||
|
|
|
@ -4,7 +4,8 @@
|
|||
"""FireRisk dataset."""
|
||||
|
||||
import os
|
||||
from typing import Callable, Optional, cast
|
||||
from collections.abc import Callable
|
||||
from typing import cast
|
||||
|
||||
import matplotlib.pyplot as plt
|
||||
from matplotlib.figure import Figure
|
||||
|
@ -68,7 +69,7 @@ class FireRisk(NonGeoClassificationDataset):
|
|||
self,
|
||||
root: str = "data",
|
||||
split: str = "train",
|
||||
transforms: Optional[Callable[[dict[str, Tensor]], dict[str, Tensor]]] = None,
|
||||
transforms: Callable[[dict[str, Tensor]], dict[str, Tensor]] | None = None,
|
||||
download: bool = False,
|
||||
checksum: bool = False,
|
||||
) -> None:
|
||||
|
@ -136,7 +137,7 @@ class FireRisk(NonGeoClassificationDataset):
|
|||
self,
|
||||
sample: dict[str, Tensor],
|
||||
show_titles: bool = True,
|
||||
suptitle: Optional[str] = None,
|
||||
suptitle: str | None = None,
|
||||
) -> Figure:
|
||||
"""Plot a sample from the dataset.
|
||||
|
||||
|
|
|
@ -5,7 +5,8 @@
|
|||
|
||||
import glob
|
||||
import os
|
||||
from typing import Any, Callable, Optional
|
||||
from collections.abc import Callable
|
||||
from typing import Any
|
||||
from xml.etree import ElementTree
|
||||
|
||||
import matplotlib.patches as patches
|
||||
|
@ -110,7 +111,7 @@ class ForestDamage(NonGeoDataset):
|
|||
def __init__(
|
||||
self,
|
||||
root: str = "data",
|
||||
transforms: Optional[Callable[[dict[str, Tensor]], dict[str, Tensor]]] = None,
|
||||
transforms: Callable[[dict[str, Tensor]], dict[str, Tensor]] | None = None,
|
||||
download: bool = False,
|
||||
checksum: bool = False,
|
||||
) -> None:
|
||||
|
@ -259,7 +260,7 @@ class ForestDamage(NonGeoDataset):
|
|||
self,
|
||||
sample: dict[str, Tensor],
|
||||
show_titles: bool = True,
|
||||
suptitle: Optional[str] = None,
|
||||
suptitle: str | None = None,
|
||||
) -> Figure:
|
||||
"""Plot a sample from the dataset.
|
||||
|
||||
|
|
|
@ -10,8 +10,8 @@ import os
|
|||
import re
|
||||
import sys
|
||||
import warnings
|
||||
from collections.abc import Iterable, Sequence
|
||||
from typing import Any, Callable, Optional, Union, cast
|
||||
from collections.abc import Callable, Iterable, Sequence
|
||||
from typing import Any, cast
|
||||
|
||||
import fiona
|
||||
import fiona.transform
|
||||
|
@ -83,7 +83,7 @@ class GeoDataset(Dataset[dict[str, Any]], abc.ABC):
|
|||
dataset = landsat7 | landsat8
|
||||
"""
|
||||
|
||||
paths: Union[str, Iterable[str]]
|
||||
paths: str | Iterable[str]
|
||||
_crs = CRS.from_epsg(4326)
|
||||
_res = 0.0
|
||||
|
||||
|
@ -108,7 +108,7 @@ class GeoDataset(Dataset[dict[str, Any]], abc.ABC):
|
|||
__add__ = None # type: ignore[assignment]
|
||||
|
||||
def __init__(
|
||||
self, transforms: Optional[Callable[[dict[str, Any]], dict[str, Any]]] = None
|
||||
self, transforms: Callable[[dict[str, Any]], dict[str, Any]] | None = None
|
||||
) -> None:
|
||||
"""Initialize a new GeoDataset instance.
|
||||
|
||||
|
@ -190,9 +190,7 @@ class GeoDataset(Dataset[dict[str, Any]], abc.ABC):
|
|||
# NOTE: This hack should be removed once the following issue is fixed:
|
||||
# https://github.com/Toblerity/rtree/issues/87
|
||||
|
||||
def __getstate__(
|
||||
self,
|
||||
) -> tuple[dict[str, Any], list[tuple[Any, Any, Optional[Any]]]]:
|
||||
def __getstate__(self) -> tuple[dict[str, Any], list[tuple[Any, Any, Any | None]]]:
|
||||
"""Define how instances are pickled.
|
||||
|
||||
Returns:
|
||||
|
@ -388,11 +386,11 @@ class RasterDataset(GeoDataset):
|
|||
|
||||
def __init__(
|
||||
self,
|
||||
paths: Union[str, Iterable[str]] = "data",
|
||||
crs: Optional[CRS] = None,
|
||||
res: Optional[float] = None,
|
||||
bands: Optional[Sequence[str]] = None,
|
||||
transforms: Optional[Callable[[dict[str, Any]], dict[str, Any]]] = None,
|
||||
paths: str | Iterable[str] = "data",
|
||||
crs: CRS | None = None,
|
||||
res: float | None = None,
|
||||
bands: Sequence[str] | None = None,
|
||||
transforms: Callable[[dict[str, Any]], dict[str, Any]] | None = None,
|
||||
cache: bool = True,
|
||||
) -> None:
|
||||
"""Initialize a new RasterDataset instance.
|
||||
|
@ -539,7 +537,7 @@ class RasterDataset(GeoDataset):
|
|||
self,
|
||||
filepaths: Sequence[str],
|
||||
query: BoundingBox,
|
||||
band_indexes: Optional[Sequence[int]] = None,
|
||||
band_indexes: Sequence[int] | None = None,
|
||||
) -> Tensor:
|
||||
"""Load and merge one or more files.
|
||||
|
||||
|
@ -612,11 +610,11 @@ class VectorDataset(GeoDataset):
|
|||
|
||||
def __init__(
|
||||
self,
|
||||
paths: Union[str, Iterable[str]] = "data",
|
||||
crs: Optional[CRS] = None,
|
||||
paths: str | Iterable[str] = "data",
|
||||
crs: CRS | None = None,
|
||||
res: float = 0.0001,
|
||||
transforms: Optional[Callable[[dict[str, Any]], dict[str, Any]]] = None,
|
||||
label_name: Optional[str] = None,
|
||||
transforms: Callable[[dict[str, Any]], dict[str, Any]] | None = None,
|
||||
label_name: str | None = None,
|
||||
) -> None:
|
||||
"""Initialize a new VectorDataset instance.
|
||||
|
||||
|
@ -809,9 +807,9 @@ class NonGeoClassificationDataset(NonGeoDataset, ImageFolder): # type: ignore[m
|
|||
def __init__(
|
||||
self,
|
||||
root: str = "data",
|
||||
transforms: Optional[Callable[[dict[str, Tensor]], dict[str, Tensor]]] = None,
|
||||
loader: Optional[Callable[[str], Any]] = pil_loader,
|
||||
is_valid_file: Optional[Callable[[str], bool]] = None,
|
||||
transforms: Callable[[dict[str, Tensor]], dict[str, Tensor]] | None = None,
|
||||
loader: Callable[[str], Any] | None = pil_loader,
|
||||
is_valid_file: Callable[[str], bool] | None = None,
|
||||
) -> None:
|
||||
"""Initialize a new NonGeoClassificationDataset instance.
|
||||
|
||||
|
@ -909,7 +907,7 @@ class IntersectionDataset(GeoDataset):
|
|||
collate_fn: Callable[
|
||||
[Sequence[dict[str, Any]]], dict[str, Any]
|
||||
] = concat_samples,
|
||||
transforms: Optional[Callable[[dict[str, Any]], dict[str, Any]]] = None,
|
||||
transforms: Callable[[dict[str, Any]], dict[str, Any]] | None = None,
|
||||
) -> None:
|
||||
"""Initialize a new IntersectionDataset instance.
|
||||
|
||||
|
@ -1067,7 +1065,7 @@ class UnionDataset(GeoDataset):
|
|||
collate_fn: Callable[
|
||||
[Sequence[dict[str, Any]]], dict[str, Any]
|
||||
] = merge_samples,
|
||||
transforms: Optional[Callable[[dict[str, Any]], dict[str, Any]]] = None,
|
||||
transforms: Callable[[dict[str, Any]], dict[str, Any]] | None = None,
|
||||
) -> None:
|
||||
"""Initialize a new UnionDataset instance.
|
||||
|
||||
|
|
|
@ -5,7 +5,7 @@
|
|||
|
||||
import glob
|
||||
import os
|
||||
from typing import Callable, Optional
|
||||
from collections.abc import Callable
|
||||
|
||||
import matplotlib.pyplot as plt
|
||||
import numpy as np
|
||||
|
@ -89,7 +89,7 @@ class GID15(NonGeoDataset):
|
|||
self,
|
||||
root: str = "data",
|
||||
split: str = "train",
|
||||
transforms: Optional[Callable[[dict[str, Tensor]], dict[str, Tensor]]] = None,
|
||||
transforms: Callable[[dict[str, Tensor]], dict[str, Tensor]] | None = None,
|
||||
download: bool = False,
|
||||
checksum: bool = False,
|
||||
) -> None:
|
||||
|
@ -234,7 +234,7 @@ class GID15(NonGeoDataset):
|
|||
md5=self.md5 if self.checksum else None,
|
||||
)
|
||||
|
||||
def plot(self, sample: dict[str, Tensor], suptitle: Optional[str] = None) -> Figure:
|
||||
def plot(self, sample: dict[str, Tensor], suptitle: str | None = None) -> Figure:
|
||||
"""Plot a sample from the dataset.
|
||||
|
||||
Args:
|
||||
|
|
|
@ -5,8 +5,8 @@
|
|||
|
||||
import glob
|
||||
import os
|
||||
from collections.abc import Iterable
|
||||
from typing import Any, Callable, Optional, Union, cast
|
||||
from collections.abc import Callable, Iterable
|
||||
from typing import Any, cast
|
||||
|
||||
import matplotlib.pyplot as plt
|
||||
import torch
|
||||
|
@ -119,11 +119,11 @@ class GlobBiomass(RasterDataset):
|
|||
|
||||
def __init__(
|
||||
self,
|
||||
paths: Union[str, Iterable[str]] = "data",
|
||||
crs: Optional[CRS] = None,
|
||||
res: Optional[float] = None,
|
||||
paths: str | Iterable[str] = "data",
|
||||
crs: CRS | None = None,
|
||||
res: float | None = None,
|
||||
measurement: str = "agb",
|
||||
transforms: Optional[Callable[[dict[str, Any]], dict[str, Any]]] = None,
|
||||
transforms: Callable[[dict[str, Any]], dict[str, Any]] | None = None,
|
||||
cache: bool = True,
|
||||
checksum: bool = False,
|
||||
) -> None:
|
||||
|
@ -225,7 +225,7 @@ class GlobBiomass(RasterDataset):
|
|||
self,
|
||||
sample: dict[str, Any],
|
||||
show_titles: bool = True,
|
||||
suptitle: Optional[str] = None,
|
||||
suptitle: str | None = None,
|
||||
) -> Figure:
|
||||
"""Plot a sample from the dataset.
|
||||
|
||||
|
|
|
@ -5,7 +5,8 @@
|
|||
|
||||
import glob
|
||||
import os
|
||||
from typing import Any, Callable, Optional, cast, overload
|
||||
from collections.abc import Callable
|
||||
from typing import Any, cast, overload
|
||||
|
||||
import fiona
|
||||
import matplotlib.pyplot as plt
|
||||
|
@ -148,7 +149,7 @@ class IDTReeS(NonGeoDataset):
|
|||
root: str = "data",
|
||||
split: str = "train",
|
||||
task: str = "task1",
|
||||
transforms: Optional[Callable[[dict[str, Tensor]], dict[str, Tensor]]] = None,
|
||||
transforms: Callable[[dict[str, Tensor]], dict[str, Tensor]] | None = None,
|
||||
download: bool = False,
|
||||
checksum: bool = False,
|
||||
) -> None:
|
||||
|
@ -333,7 +334,7 @@ class IDTReeS(NonGeoDataset):
|
|||
|
||||
def _load(
|
||||
self, root: str
|
||||
) -> tuple[list[str], Optional[dict[int, dict[str, Any]]], Any]:
|
||||
) -> tuple[list[str], dict[int, dict[str, Any]] | None, Any]:
|
||||
"""Load files, geometries, and labels.
|
||||
|
||||
Args:
|
||||
|
@ -419,8 +420,8 @@ class IDTReeS(NonGeoDataset):
|
|||
image_size: tuple[int, int],
|
||||
min_size: int,
|
||||
boxes: Tensor,
|
||||
labels: Optional[Tensor],
|
||||
) -> tuple[Tensor, Optional[Tensor]]:
|
||||
labels: Tensor | None,
|
||||
) -> tuple[Tensor, Tensor | None]:
|
||||
"""Clip boxes to image size and filter boxes with sides less than ``min_size``.
|
||||
|
||||
Args:
|
||||
|
@ -477,7 +478,7 @@ class IDTReeS(NonGeoDataset):
|
|||
self,
|
||||
sample: dict[str, Tensor],
|
||||
show_titles: bool = True,
|
||||
suptitle: Optional[str] = None,
|
||||
suptitle: str | None = None,
|
||||
hsi_indices: tuple[int, int, int] = (0, 1, 2),
|
||||
) -> Figure:
|
||||
"""Plot a sample from the dataset.
|
||||
|
|
|
@ -6,7 +6,8 @@
|
|||
import glob
|
||||
import os
|
||||
import re
|
||||
from typing import Any, Callable, Optional
|
||||
from collections.abc import Callable
|
||||
from typing import Any
|
||||
|
||||
import matplotlib.pyplot as plt
|
||||
import numpy as np
|
||||
|
@ -64,7 +65,7 @@ class InriaAerialImageLabeling(NonGeoDataset):
|
|||
self,
|
||||
root: str = "data",
|
||||
split: str = "train",
|
||||
transforms: Optional[Callable[[dict[str, Any]], dict[str, Any]]] = None,
|
||||
transforms: Callable[[dict[str, Any]], dict[str, Any]] | None = None,
|
||||
checksum: bool = False,
|
||||
) -> None:
|
||||
"""Initialize a new InriaAerialImageLabeling Dataset instance.
|
||||
|
@ -200,7 +201,7 @@ class InriaAerialImageLabeling(NonGeoDataset):
|
|||
self,
|
||||
sample: dict[str, Tensor],
|
||||
show_titles: bool = True,
|
||||
suptitle: Optional[str] = None,
|
||||
suptitle: str | None = None,
|
||||
) -> Figure:
|
||||
"""Plot a sample from the dataset.
|
||||
|
||||
|
|
|
@ -5,8 +5,8 @@
|
|||
|
||||
import glob
|
||||
import os
|
||||
from collections.abc import Iterable, Sequence
|
||||
from typing import Any, Callable, Optional, Union, cast
|
||||
from collections.abc import Callable, Iterable, Sequence
|
||||
from typing import Any, cast
|
||||
|
||||
import matplotlib.pyplot as plt
|
||||
from matplotlib.figure import Figure
|
||||
|
@ -97,11 +97,11 @@ class L7Irish(RasterDataset):
|
|||
|
||||
def __init__(
|
||||
self,
|
||||
paths: Union[str, Iterable[str]] = "data",
|
||||
crs: Optional[CRS] = CRS.from_epsg(3857),
|
||||
res: Optional[float] = None,
|
||||
paths: str | Iterable[str] = "data",
|
||||
crs: CRS | None = CRS.from_epsg(3857),
|
||||
res: float | None = None,
|
||||
bands: Sequence[str] = all_bands,
|
||||
transforms: Optional[Callable[[dict[str, Any]], dict[str, Any]]] = None,
|
||||
transforms: Callable[[dict[str, Any]], dict[str, Any]] | None = None,
|
||||
cache: bool = True,
|
||||
download: bool = False,
|
||||
checksum: bool = False,
|
||||
|
@ -221,7 +221,7 @@ class L7Irish(RasterDataset):
|
|||
self,
|
||||
sample: dict[str, Tensor],
|
||||
show_titles: bool = True,
|
||||
suptitle: Optional[str] = None,
|
||||
suptitle: str | None = None,
|
||||
) -> Figure:
|
||||
"""Plot a sample from the dataset.
|
||||
|
||||
|
|
|
@ -5,8 +5,8 @@
|
|||
|
||||
import glob
|
||||
import os
|
||||
from collections.abc import Iterable, Sequence
|
||||
from typing import Any, Callable, Optional, Union, cast
|
||||
from collections.abc import Callable, Iterable, Sequence
|
||||
from typing import Any, cast
|
||||
|
||||
import matplotlib.pyplot as plt
|
||||
from matplotlib.figure import Figure
|
||||
|
@ -96,11 +96,11 @@ class L8Biome(RasterDataset):
|
|||
|
||||
def __init__(
|
||||
self,
|
||||
paths: Union[str, Iterable[str]],
|
||||
crs: Optional[CRS] = CRS.from_epsg(3857),
|
||||
res: Optional[float] = None,
|
||||
paths: str | Iterable[str],
|
||||
crs: CRS | None = CRS.from_epsg(3857),
|
||||
res: float | None = None,
|
||||
bands: Sequence[str] = all_bands,
|
||||
transforms: Optional[Callable[[dict[str, Any]], dict[str, Any]]] = None,
|
||||
transforms: Callable[[dict[str, Any]], dict[str, Any]] | None = None,
|
||||
cache: bool = True,
|
||||
download: bool = False,
|
||||
checksum: bool = False,
|
||||
|
@ -217,7 +217,7 @@ class L8Biome(RasterDataset):
|
|||
self,
|
||||
sample: dict[str, Tensor],
|
||||
show_titles: bool = True,
|
||||
suptitle: Optional[str] = None,
|
||||
suptitle: str | None = None,
|
||||
) -> Figure:
|
||||
"""Plot a sample from the dataset.
|
||||
|
||||
|
|
|
@ -6,8 +6,9 @@ import abc
|
|||
import glob
|
||||
import hashlib
|
||||
import os
|
||||
from collections.abc import Callable
|
||||
from functools import lru_cache
|
||||
from typing import Any, Callable, Optional, cast
|
||||
from typing import Any, cast
|
||||
|
||||
import matplotlib.pyplot as plt
|
||||
import numpy as np
|
||||
|
@ -152,7 +153,7 @@ class LandCoverAIBase(Dataset[dict[str, Any]], abc.ABC):
|
|||
self,
|
||||
sample: dict[str, Tensor],
|
||||
show_titles: bool = True,
|
||||
suptitle: Optional[str] = None,
|
||||
suptitle: str | None = None,
|
||||
) -> Figure:
|
||||
"""Plot a sample from the dataset.
|
||||
|
||||
|
@ -209,9 +210,9 @@ class LandCoverAIGeo(LandCoverAIBase, RasterDataset):
|
|||
def __init__(
|
||||
self,
|
||||
root: str = "data",
|
||||
crs: Optional[CRS] = None,
|
||||
res: Optional[float] = None,
|
||||
transforms: Optional[Callable[[dict[str, Any]], dict[str, Any]]] = None,
|
||||
crs: CRS | None = None,
|
||||
res: float | None = None,
|
||||
transforms: Callable[[dict[str, Any]], dict[str, Any]] | None = None,
|
||||
cache: bool = True,
|
||||
download: bool = False,
|
||||
checksum: bool = False,
|
||||
|
@ -299,7 +300,7 @@ class LandCoverAI(LandCoverAIBase, NonGeoDataset):
|
|||
self,
|
||||
root: str = "data",
|
||||
split: str = "train",
|
||||
transforms: Optional[Callable[[dict[str, Tensor]], dict[str, Tensor]]] = None,
|
||||
transforms: Callable[[dict[str, Tensor]], dict[str, Tensor]] | None = None,
|
||||
download: bool = False,
|
||||
checksum: bool = False,
|
||||
) -> None:
|
||||
|
|
|
@ -4,8 +4,8 @@
|
|||
"""Landsat datasets."""
|
||||
|
||||
import abc
|
||||
from collections.abc import Iterable, Sequence
|
||||
from typing import Any, Callable, Optional, Union
|
||||
from collections.abc import Callable, Iterable, Sequence
|
||||
from typing import Any
|
||||
|
||||
import matplotlib.pyplot as plt
|
||||
from matplotlib.figure import Figure
|
||||
|
@ -59,11 +59,11 @@ class Landsat(RasterDataset, abc.ABC):
|
|||
|
||||
def __init__(
|
||||
self,
|
||||
paths: Union[str, Iterable[str]] = "data",
|
||||
crs: Optional[CRS] = None,
|
||||
res: Optional[float] = None,
|
||||
bands: Optional[Sequence[str]] = None,
|
||||
transforms: Optional[Callable[[dict[str, Any]], dict[str, Any]]] = None,
|
||||
paths: str | Iterable[str] = "data",
|
||||
crs: CRS | None = None,
|
||||
res: float | None = None,
|
||||
bands: Sequence[str] | None = None,
|
||||
transforms: Callable[[dict[str, Any]], dict[str, Any]] | None = None,
|
||||
cache: bool = True,
|
||||
) -> None:
|
||||
"""Initialize a new Dataset instance.
|
||||
|
@ -94,7 +94,7 @@ class Landsat(RasterDataset, abc.ABC):
|
|||
self,
|
||||
sample: dict[str, Any],
|
||||
show_titles: bool = True,
|
||||
suptitle: Optional[str] = None,
|
||||
suptitle: str | None = None,
|
||||
) -> Figure:
|
||||
"""Plot a sample from the dataset.
|
||||
|
||||
|
|
|
@ -6,7 +6,7 @@
|
|||
import abc
|
||||
import glob
|
||||
import os
|
||||
from typing import Callable, Optional, Union
|
||||
from collections.abc import Callable
|
||||
|
||||
import matplotlib.pyplot as plt
|
||||
import numpy as np
|
||||
|
@ -29,14 +29,14 @@ class LEVIRCDBase(NonGeoDataset, abc.ABC):
|
|||
.. versionadded:: 0.6
|
||||
"""
|
||||
|
||||
splits: Union[list[str], dict[str, dict[str, str]]]
|
||||
splits: list[str] | dict[str, dict[str, str]]
|
||||
directories = ["A", "B", "label"]
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
root: str = "data",
|
||||
split: str = "train",
|
||||
transforms: Optional[Callable[[dict[str, Tensor]], dict[str, Tensor]]] = None,
|
||||
transforms: Callable[[dict[str, Tensor]], dict[str, Tensor]] | None = None,
|
||||
download: bool = False,
|
||||
checksum: bool = False,
|
||||
) -> None:
|
||||
|
@ -135,7 +135,7 @@ class LEVIRCDBase(NonGeoDataset, abc.ABC):
|
|||
self,
|
||||
sample: dict[str, Tensor],
|
||||
show_titles: bool = True,
|
||||
suptitle: Optional[str] = None,
|
||||
suptitle: str | None = None,
|
||||
) -> Figure:
|
||||
"""Plot a sample from the dataset.
|
||||
|
||||
|
|
|
@ -5,7 +5,7 @@
|
|||
|
||||
import glob
|
||||
import os
|
||||
from typing import Callable, Optional
|
||||
from collections.abc import Callable
|
||||
|
||||
import matplotlib.pyplot as plt
|
||||
import numpy as np
|
||||
|
@ -93,7 +93,7 @@ class LoveDA(NonGeoDataset):
|
|||
root: str = "data",
|
||||
split: str = "train",
|
||||
scene: list[str] = ["urban", "rural"],
|
||||
transforms: Optional[Callable[[dict[str, Tensor]], dict[str, Tensor]]] = None,
|
||||
transforms: Callable[[dict[str, Tensor]], dict[str, Tensor]] | None = None,
|
||||
download: bool = False,
|
||||
checksum: bool = False,
|
||||
) -> None:
|
||||
|
@ -256,7 +256,7 @@ class LoveDA(NonGeoDataset):
|
|||
md5=self.md5 if self.checksum else None,
|
||||
)
|
||||
|
||||
def plot(self, sample: dict[str, Tensor], suptitle: Optional[str] = None) -> Figure:
|
||||
def plot(self, sample: dict[str, Tensor], suptitle: str | None = None) -> Figure:
|
||||
"""Plot a sample from the dataset.
|
||||
|
||||
Args:
|
||||
|
|
|
@ -6,7 +6,7 @@
|
|||
import os
|
||||
import shutil
|
||||
from collections import defaultdict
|
||||
from typing import Callable, Optional
|
||||
from collections.abc import Callable
|
||||
|
||||
import matplotlib.pyplot as plt
|
||||
import numpy as np
|
||||
|
@ -111,7 +111,7 @@ class MapInWild(NonGeoDataset):
|
|||
root: str = "data",
|
||||
modality: list[str] = ["mask", "esa_wc", "viirs", "s2_summer"],
|
||||
split: str = "train",
|
||||
transforms: Optional[Callable[[dict[str, Tensor]], dict[str, Tensor]]] = None,
|
||||
transforms: Callable[[dict[str, Tensor]], dict[str, Tensor]] | None = None,
|
||||
download: bool = False,
|
||||
checksum: bool = False,
|
||||
) -> None:
|
||||
|
@ -223,7 +223,7 @@ class MapInWild(NonGeoDataset):
|
|||
tensor = torch.from_numpy(array).float()
|
||||
return tensor
|
||||
|
||||
def _verify(self, url: str, md5: Optional[str] = None) -> None:
|
||||
def _verify(self, url: str, md5: str | None = None) -> None:
|
||||
"""Verify the integrity of the dataset.
|
||||
|
||||
Args:
|
||||
|
@ -258,7 +258,7 @@ class MapInWild(NonGeoDataset):
|
|||
if not url.endswith(".csv"):
|
||||
self._extract(url)
|
||||
|
||||
def _download(self, url: str, md5: Optional[str]) -> None:
|
||||
def _download(self, url: str, md5: str | None) -> None:
|
||||
"""Downloads a modality.
|
||||
|
||||
Args:
|
||||
|
@ -330,7 +330,7 @@ class MapInWild(NonGeoDataset):
|
|||
self,
|
||||
sample: dict[str, Tensor],
|
||||
show_titles: bool = True,
|
||||
suptitle: Optional[str] = None,
|
||||
suptitle: str | None = None,
|
||||
) -> Figure:
|
||||
"""Plot a sample from the dataset.
|
||||
|
||||
|
|
|
@ -4,7 +4,8 @@
|
|||
"""Million-AID dataset."""
|
||||
import glob
|
||||
import os
|
||||
from typing import Any, Callable, Optional, cast
|
||||
from collections.abc import Callable
|
||||
from typing import Any, cast
|
||||
|
||||
import matplotlib.pyplot as plt
|
||||
import numpy as np
|
||||
|
@ -191,7 +192,7 @@ class MillionAID(NonGeoDataset):
|
|||
root: str = "data",
|
||||
task: str = "multi-class",
|
||||
split: str = "train",
|
||||
transforms: Optional[Callable[[dict[str, Tensor]], dict[str, Tensor]]] = None,
|
||||
transforms: Callable[[dict[str, Tensor]], dict[str, Tensor]] | None = None,
|
||||
checksum: bool = False,
|
||||
) -> None:
|
||||
"""Initialize a new MillionAID dataset instance.
|
||||
|
@ -332,7 +333,7 @@ class MillionAID(NonGeoDataset):
|
|||
self,
|
||||
sample: dict[str, Tensor],
|
||||
show_titles: bool = True,
|
||||
suptitle: Optional[str] = None,
|
||||
suptitle: str | None = None,
|
||||
) -> Figure:
|
||||
"""Plot a sample from the dataset.
|
||||
|
||||
|
|
|
@ -3,7 +3,7 @@
|
|||
|
||||
"""National Agriculture Imagery Program (NAIP) dataset."""
|
||||
|
||||
from typing import Any, Optional
|
||||
from typing import Any
|
||||
|
||||
import matplotlib.pyplot as plt
|
||||
from matplotlib.figure import Figure
|
||||
|
@ -52,7 +52,7 @@ class NAIP(RasterDataset):
|
|||
self,
|
||||
sample: dict[str, Any],
|
||||
show_titles: bool = True,
|
||||
suptitle: Optional[str] = None,
|
||||
suptitle: str | None = None,
|
||||
) -> Figure:
|
||||
"""Plot a sample from the dataset.
|
||||
|
||||
|
|
|
@ -4,7 +4,7 @@
|
|||
"""NASA Marine Debris dataset."""
|
||||
|
||||
import os
|
||||
from typing import Callable, Optional
|
||||
from collections.abc import Callable
|
||||
|
||||
import matplotlib.pyplot as plt
|
||||
import numpy as np
|
||||
|
@ -66,9 +66,9 @@ class NASAMarineDebris(NonGeoDataset):
|
|||
def __init__(
|
||||
self,
|
||||
root: str = "data",
|
||||
transforms: Optional[Callable[[dict[str, Tensor]], dict[str, Tensor]]] = None,
|
||||
transforms: Callable[[dict[str, Tensor]], dict[str, Tensor]] | None = None,
|
||||
download: bool = False,
|
||||
api_key: Optional[str] = None,
|
||||
api_key: str | None = None,
|
||||
checksum: bool = False,
|
||||
verbose: bool = False,
|
||||
) -> None:
|
||||
|
@ -224,7 +224,7 @@ class NASAMarineDebris(NonGeoDataset):
|
|||
self,
|
||||
sample: dict[str, Tensor],
|
||||
show_titles: bool = True,
|
||||
suptitle: Optional[str] = None,
|
||||
suptitle: str | None = None,
|
||||
) -> Figure:
|
||||
"""Plot a sample from the dataset.
|
||||
|
||||
|
|
|
@ -3,8 +3,8 @@
|
|||
|
||||
"""Northeastern China Crop Map Dataset."""
|
||||
|
||||
from collections.abc import Iterable
|
||||
from typing import Any, Callable, Optional, Union
|
||||
from collections.abc import Callable, Iterable
|
||||
from typing import Any
|
||||
|
||||
import matplotlib.pyplot as plt
|
||||
import torch
|
||||
|
@ -82,11 +82,11 @@ class NCCM(RasterDataset):
|
|||
|
||||
def __init__(
|
||||
self,
|
||||
paths: Union[str, Iterable[str]] = "data",
|
||||
crs: Optional[CRS] = None,
|
||||
res: Optional[float] = None,
|
||||
paths: str | Iterable[str] = "data",
|
||||
crs: CRS | None = None,
|
||||
res: float | None = None,
|
||||
years: list[int] = [2019],
|
||||
transforms: Optional[Callable[[dict[str, Any]], dict[str, Any]]] = None,
|
||||
transforms: Callable[[dict[str, Any]], dict[str, Any]] | None = None,
|
||||
cache: bool = True,
|
||||
download: bool = False,
|
||||
checksum: bool = False,
|
||||
|
@ -170,7 +170,7 @@ class NCCM(RasterDataset):
|
|||
self,
|
||||
sample: dict[str, Any],
|
||||
show_titles: bool = True,
|
||||
suptitle: Optional[str] = None,
|
||||
suptitle: str | None = None,
|
||||
) -> Figure:
|
||||
"""Plot a sample from the dataset.
|
||||
|
||||
|
|
|
@ -5,8 +5,8 @@
|
|||
|
||||
import glob
|
||||
import os
|
||||
from collections.abc import Iterable
|
||||
from typing import Any, Callable, Optional, Union
|
||||
from collections.abc import Callable, Iterable
|
||||
from typing import Any
|
||||
|
||||
import matplotlib.pyplot as plt
|
||||
import torch
|
||||
|
@ -107,12 +107,12 @@ class NLCD(RasterDataset):
|
|||
|
||||
def __init__(
|
||||
self,
|
||||
paths: Union[str, Iterable[str]] = "data",
|
||||
crs: Optional[CRS] = None,
|
||||
res: Optional[float] = None,
|
||||
paths: str | Iterable[str] = "data",
|
||||
crs: CRS | None = None,
|
||||
res: float | None = None,
|
||||
years: list[int] = [2019],
|
||||
classes: list[int] = list(cmap.keys()),
|
||||
transforms: Optional[Callable[[dict[str, Any]], dict[str, Any]]] = None,
|
||||
transforms: Callable[[dict[str, Any]], dict[str, Any]] | None = None,
|
||||
cache: bool = True,
|
||||
download: bool = False,
|
||||
checksum: bool = False,
|
||||
|
@ -230,7 +230,7 @@ class NLCD(RasterDataset):
|
|||
self,
|
||||
sample: dict[str, Any],
|
||||
show_titles: bool = True,
|
||||
suptitle: Optional[str] = None,
|
||||
suptitle: str | None = None,
|
||||
) -> Figure:
|
||||
"""Plot a sample from the dataset.
|
||||
|
||||
|
|
|
@ -7,8 +7,8 @@ import glob
|
|||
import json
|
||||
import os
|
||||
import sys
|
||||
from collections.abc import Iterable
|
||||
from typing import Any, Callable, Optional, Union, cast
|
||||
from collections.abc import Callable, Iterable
|
||||
from typing import Any, cast
|
||||
|
||||
import fiona
|
||||
import fiona.transform
|
||||
|
@ -206,10 +206,10 @@ class OpenBuildings(VectorDataset):
|
|||
|
||||
def __init__(
|
||||
self,
|
||||
paths: Union[str, Iterable[str]] = "data",
|
||||
crs: Optional[CRS] = None,
|
||||
paths: str | Iterable[str] = "data",
|
||||
crs: CRS | None = None,
|
||||
res: float = 0.0001,
|
||||
transforms: Optional[Callable[[dict[str, Any]], dict[str, Any]]] = None,
|
||||
transforms: Callable[[dict[str, Any]], dict[str, Any]] | None = None,
|
||||
checksum: bool = False,
|
||||
) -> None:
|
||||
"""Initialize a new Dataset instance.
|
||||
|
@ -413,7 +413,7 @@ class OpenBuildings(VectorDataset):
|
|||
self,
|
||||
sample: dict[str, Any],
|
||||
show_titles: bool = True,
|
||||
suptitle: Optional[str] = None,
|
||||
suptitle: str | None = None,
|
||||
) -> Figure:
|
||||
"""Plot a sample from the dataset.
|
||||
|
||||
|
|
|
@ -5,8 +5,7 @@
|
|||
|
||||
import glob
|
||||
import os
|
||||
from collections.abc import Sequence
|
||||
from typing import Callable, Optional, Union
|
||||
from collections.abc import Callable, Sequence
|
||||
|
||||
import matplotlib.pyplot as plt
|
||||
import numpy as np
|
||||
|
@ -103,7 +102,7 @@ class OSCD(NonGeoDataset):
|
|||
root: str = "data",
|
||||
split: str = "train",
|
||||
bands: Sequence[str] = all_bands,
|
||||
transforms: Optional[Callable[[dict[str, Tensor]], dict[str, Tensor]]] = None,
|
||||
transforms: Callable[[dict[str, Tensor]], dict[str, Tensor]] | None = None,
|
||||
download: bool = False,
|
||||
checksum: bool = False,
|
||||
) -> None:
|
||||
|
@ -164,7 +163,7 @@ class OSCD(NonGeoDataset):
|
|||
"""
|
||||
return len(self.files)
|
||||
|
||||
def _load_files(self) -> list[dict[str, Union[str, Sequence[str]]]]:
|
||||
def _load_files(self) -> list[dict[str, str | Sequence[str]]]:
|
||||
regions = []
|
||||
labels_root = os.path.join(
|
||||
self.root,
|
||||
|
@ -284,7 +283,7 @@ class OSCD(NonGeoDataset):
|
|||
self,
|
||||
sample: dict[str, Tensor],
|
||||
show_titles: bool = True,
|
||||
suptitle: Optional[str] = None,
|
||||
suptitle: str | None = None,
|
||||
alpha: float = 0.5,
|
||||
) -> Figure:
|
||||
"""Plot a sample from the dataset.
|
||||
|
|
|
@ -4,8 +4,7 @@
|
|||
"""PASTIS dataset."""
|
||||
|
||||
import os
|
||||
from collections.abc import Sequence
|
||||
from typing import Callable, Optional
|
||||
from collections.abc import Callable, Sequence
|
||||
|
||||
import fiona
|
||||
import matplotlib.pyplot as plt
|
||||
|
@ -132,7 +131,7 @@ class PASTIS(NonGeoDataset):
|
|||
folds: Sequence[int] = (1, 2, 3, 4, 5),
|
||||
bands: str = "s2",
|
||||
mode: str = "semantic",
|
||||
transforms: Optional[Callable[[dict[str, Tensor]], dict[str, Tensor]]] = None,
|
||||
transforms: Callable[[dict[str, Tensor]], dict[str, Tensor]] | None = None,
|
||||
download: bool = False,
|
||||
checksum: bool = False,
|
||||
) -> None:
|
||||
|
@ -347,7 +346,7 @@ class PASTIS(NonGeoDataset):
|
|||
self,
|
||||
sample: dict[str, Tensor],
|
||||
show_titles: bool = True,
|
||||
suptitle: Optional[str] = None,
|
||||
suptitle: str | None = None,
|
||||
) -> Figure:
|
||||
"""Plot a sample from the dataset.
|
||||
|
||||
|
|
|
@ -4,7 +4,8 @@
|
|||
"""PatternNet dataset."""
|
||||
|
||||
import os
|
||||
from typing import Callable, Optional, cast
|
||||
from collections.abc import Callable
|
||||
from typing import cast
|
||||
|
||||
import matplotlib.pyplot as plt
|
||||
from matplotlib.figure import Figure
|
||||
|
@ -84,7 +85,7 @@ class PatternNet(NonGeoClassificationDataset):
|
|||
def __init__(
|
||||
self,
|
||||
root: str = "data",
|
||||
transforms: Optional[Callable[[dict[str, Tensor]], dict[str, Tensor]]] = None,
|
||||
transforms: Callable[[dict[str, Tensor]], dict[str, Tensor]] | None = None,
|
||||
download: bool = False,
|
||||
checksum: bool = False,
|
||||
) -> None:
|
||||
|
@ -145,7 +146,7 @@ class PatternNet(NonGeoClassificationDataset):
|
|||
self,
|
||||
sample: dict[str, Tensor],
|
||||
show_titles: bool = True,
|
||||
suptitle: Optional[str] = None,
|
||||
suptitle: str | None = None,
|
||||
) -> Figure:
|
||||
"""Plot a sample from the dataset.
|
||||
|
||||
|
|
|
@ -4,7 +4,7 @@
|
|||
"""Potsdam dataset."""
|
||||
|
||||
import os
|
||||
from typing import Callable, Optional
|
||||
from collections.abc import Callable
|
||||
|
||||
import matplotlib.pyplot as plt
|
||||
import numpy as np
|
||||
|
@ -123,7 +123,7 @@ class Potsdam2D(NonGeoDataset):
|
|||
self,
|
||||
root: str = "data",
|
||||
split: str = "train",
|
||||
transforms: Optional[Callable[[dict[str, Tensor]], dict[str, Tensor]]] = None,
|
||||
transforms: Callable[[dict[str, Tensor]], dict[str, Tensor]] | None = None,
|
||||
checksum: bool = False,
|
||||
) -> None:
|
||||
"""Initialize a new Potsdam dataset instance.
|
||||
|
@ -240,7 +240,7 @@ class Potsdam2D(NonGeoDataset):
|
|||
self,
|
||||
sample: dict[str, Tensor],
|
||||
show_titles: bool = True,
|
||||
suptitle: Optional[str] = None,
|
||||
suptitle: str | None = None,
|
||||
alpha: float = 0.5,
|
||||
) -> Figure:
|
||||
"""Plot a sample from the dataset.
|
||||
|
|
|
@ -3,7 +3,7 @@
|
|||
|
||||
"""PRISMA datasets."""
|
||||
|
||||
from typing import Any, Optional
|
||||
from typing import Any
|
||||
|
||||
import matplotlib.pyplot as plt
|
||||
from matplotlib.figure import Figure
|
||||
|
@ -81,7 +81,7 @@ class PRISMA(RasterDataset):
|
|||
self,
|
||||
sample: dict[str, Any],
|
||||
show_titles: bool = True,
|
||||
suptitle: Optional[str] = None,
|
||||
suptitle: str | None = None,
|
||||
) -> Figure:
|
||||
"""Plot a sample from the dataset.
|
||||
|
||||
|
|
|
@ -5,7 +5,7 @@
|
|||
|
||||
import glob
|
||||
import os
|
||||
from typing import Callable, Optional
|
||||
from collections.abc import Callable
|
||||
|
||||
import matplotlib.patches as patches
|
||||
import matplotlib.pyplot as plt
|
||||
|
@ -69,7 +69,7 @@ class ReforesTree(NonGeoDataset):
|
|||
def __init__(
|
||||
self,
|
||||
root: str = "data",
|
||||
transforms: Optional[Callable[[dict[str, Tensor]], dict[str, Tensor]]] = None,
|
||||
transforms: Callable[[dict[str, Tensor]], dict[str, Tensor]] | None = None,
|
||||
download: bool = False,
|
||||
checksum: bool = False,
|
||||
) -> None:
|
||||
|
@ -209,7 +209,7 @@ class ReforesTree(NonGeoDataset):
|
|||
self,
|
||||
sample: dict[str, Tensor],
|
||||
show_titles: bool = True,
|
||||
suptitle: Optional[str] = None,
|
||||
suptitle: str | None = None,
|
||||
) -> Figure:
|
||||
"""Plot a sample from the dataset.
|
||||
|
||||
|
|
|
@ -4,7 +4,8 @@
|
|||
"""RESISC45 dataset."""
|
||||
|
||||
import os
|
||||
from typing import Callable, Optional, cast
|
||||
from collections.abc import Callable
|
||||
from typing import cast
|
||||
|
||||
import matplotlib.pyplot as plt
|
||||
import numpy as np
|
||||
|
@ -112,7 +113,7 @@ class RESISC45(NonGeoClassificationDataset):
|
|||
self,
|
||||
root: str = "data",
|
||||
split: str = "train",
|
||||
transforms: Optional[Callable[[dict[str, Tensor]], dict[str, Tensor]]] = None,
|
||||
transforms: Callable[[dict[str, Tensor]], dict[str, Tensor]] | None = None,
|
||||
download: bool = False,
|
||||
checksum: bool = False,
|
||||
) -> None:
|
||||
|
@ -193,7 +194,7 @@ class RESISC45(NonGeoClassificationDataset):
|
|||
self,
|
||||
sample: dict[str, Tensor],
|
||||
show_titles: bool = True,
|
||||
suptitle: Optional[str] = None,
|
||||
suptitle: str | None = None,
|
||||
) -> Figure:
|
||||
"""Plot a sample from the dataset.
|
||||
|
||||
|
|
|
@ -4,8 +4,7 @@
|
|||
"""Rwanda Field Boundary Competition dataset."""
|
||||
|
||||
import os
|
||||
from collections.abc import Sequence
|
||||
from typing import Callable, Optional
|
||||
from collections.abc import Callable, Sequence
|
||||
|
||||
import matplotlib.pyplot as plt
|
||||
import numpy as np
|
||||
|
@ -91,9 +90,9 @@ class RwandaFieldBoundary(NonGeoDataset):
|
|||
root: str = "data",
|
||||
split: str = "train",
|
||||
bands: Sequence[str] = all_bands,
|
||||
transforms: Optional[Callable[[dict[str, Tensor]], dict[str, Tensor]]] = None,
|
||||
transforms: Callable[[dict[str, Tensor]], dict[str, Tensor]] | None = None,
|
||||
download: bool = False,
|
||||
api_key: Optional[str] = None,
|
||||
api_key: str | None = None,
|
||||
checksum: bool = False,
|
||||
) -> None:
|
||||
"""Initialize a new RwandaFieldBoundary instance.
|
||||
|
@ -258,7 +257,7 @@ class RwandaFieldBoundary(NonGeoDataset):
|
|||
sample: dict[str, Tensor],
|
||||
show_titles: bool = True,
|
||||
time_step: int = 0,
|
||||
suptitle: Optional[str] = None,
|
||||
suptitle: str | None = None,
|
||||
) -> Figure:
|
||||
"""Plot a sample from the dataset.
|
||||
|
||||
|
|
|
@ -6,7 +6,6 @@
|
|||
import os
|
||||
import random
|
||||
from collections.abc import Callable, Collection, Iterable
|
||||
from typing import Optional
|
||||
|
||||
import matplotlib.patches as mpatches
|
||||
import matplotlib.pyplot as plt
|
||||
|
@ -219,7 +218,7 @@ class SeasoNet(NonGeoDataset):
|
|||
bands: Iterable[str] = all_bands,
|
||||
grids: Iterable[int] = [1, 2],
|
||||
concat_seasons: int = 1,
|
||||
transforms: Optional[Callable[[dict[str, Tensor]], dict[str, Tensor]]] = None,
|
||||
transforms: Callable[[dict[str, Tensor]], dict[str, Tensor]] | None = None,
|
||||
download: bool = False,
|
||||
checksum: bool = False,
|
||||
) -> None:
|
||||
|
@ -405,7 +404,7 @@ class SeasoNet(NonGeoDataset):
|
|||
sample: dict[str, Tensor],
|
||||
show_titles: bool = True,
|
||||
show_legend: bool = True,
|
||||
suptitle: Optional[str] = None,
|
||||
suptitle: str | None = None,
|
||||
) -> Figure:
|
||||
"""Plot a sample from the dataset.
|
||||
|
||||
|
|
|
@ -5,7 +5,7 @@
|
|||
|
||||
import os
|
||||
import random
|
||||
from typing import Callable, Optional
|
||||
from collections.abc import Callable
|
||||
|
||||
import matplotlib.pyplot as plt
|
||||
import numpy as np
|
||||
|
@ -79,7 +79,7 @@ class SeasonalContrastS2(NonGeoDataset):
|
|||
version: str = "100k",
|
||||
seasons: int = 1,
|
||||
bands: list[str] = rgb_bands,
|
||||
transforms: Optional[Callable[[dict[str, Tensor]], dict[str, Tensor]]] = None,
|
||||
transforms: Callable[[dict[str, Tensor]], dict[str, Tensor]] | None = None,
|
||||
download: bool = False,
|
||||
checksum: bool = False,
|
||||
) -> None:
|
||||
|
@ -231,7 +231,7 @@ class SeasonalContrastS2(NonGeoDataset):
|
|||
self,
|
||||
sample: dict[str, Tensor],
|
||||
show_titles: bool = True,
|
||||
suptitle: Optional[str] = None,
|
||||
suptitle: str | None = None,
|
||||
) -> Figure:
|
||||
"""Plot a sample from the dataset.
|
||||
|
||||
|
|
|
@ -4,8 +4,7 @@
|
|||
"""SEN12MS dataset."""
|
||||
|
||||
import os
|
||||
from collections.abc import Sequence
|
||||
from typing import Callable, Optional
|
||||
from collections.abc import Callable, Sequence
|
||||
|
||||
import matplotlib.pyplot as plt
|
||||
import numpy as np
|
||||
|
@ -173,7 +172,7 @@ class SEN12MS(NonGeoDataset):
|
|||
root: str = "data",
|
||||
split: str = "train",
|
||||
bands: Sequence[str] = BAND_SETS["all"],
|
||||
transforms: Optional[Callable[[dict[str, Tensor]], dict[str, Tensor]]] = None,
|
||||
transforms: Callable[[dict[str, Tensor]], dict[str, Tensor]] | None = None,
|
||||
checksum: bool = False,
|
||||
) -> None:
|
||||
"""Initialize a new SEN12MS dataset instance.
|
||||
|
@ -320,7 +319,7 @@ class SEN12MS(NonGeoDataset):
|
|||
self,
|
||||
sample: dict[str, Tensor],
|
||||
show_titles: bool = True,
|
||||
suptitle: Optional[str] = None,
|
||||
suptitle: str | None = None,
|
||||
) -> Figure:
|
||||
"""Plot a sample from the dataset.
|
||||
|
||||
|
|
|
@ -3,8 +3,8 @@
|
|||
|
||||
"""Sentinel datasets."""
|
||||
|
||||
from collections.abc import Iterable, Sequence
|
||||
from typing import Any, Callable, Optional, Union
|
||||
from collections.abc import Callable, Iterable, Sequence
|
||||
from typing import Any
|
||||
|
||||
import matplotlib.pyplot as plt
|
||||
import torch
|
||||
|
@ -141,11 +141,11 @@ class Sentinel1(Sentinel):
|
|||
|
||||
def __init__(
|
||||
self,
|
||||
paths: Union[str, list[str]] = "data",
|
||||
crs: Optional[CRS] = None,
|
||||
paths: str | list[str] = "data",
|
||||
crs: CRS | None = None,
|
||||
res: float = 10,
|
||||
bands: Sequence[str] = ["VV", "VH"],
|
||||
transforms: Optional[Callable[[dict[str, Any]], dict[str, Any]]] = None,
|
||||
transforms: Callable[[dict[str, Any]], dict[str, Any]] | None = None,
|
||||
cache: bool = True,
|
||||
) -> None:
|
||||
"""Initialize a new Dataset instance.
|
||||
|
@ -194,7 +194,7 @@ To create a dataset containing both, use:
|
|||
self,
|
||||
sample: dict[str, Any],
|
||||
show_titles: bool = True,
|
||||
suptitle: Optional[str] = None,
|
||||
suptitle: str | None = None,
|
||||
) -> Figure:
|
||||
"""Plot a sample from the dataset.
|
||||
|
||||
|
@ -297,11 +297,11 @@ class Sentinel2(Sentinel):
|
|||
|
||||
def __init__(
|
||||
self,
|
||||
paths: Union[str, Iterable[str]] = "data",
|
||||
crs: Optional[CRS] = None,
|
||||
paths: str | Iterable[str] = "data",
|
||||
crs: CRS | None = None,
|
||||
res: float = 10,
|
||||
bands: Optional[Sequence[str]] = None,
|
||||
transforms: Optional[Callable[[dict[str, Any]], dict[str, Any]]] = None,
|
||||
bands: Sequence[str] | None = None,
|
||||
transforms: Callable[[dict[str, Any]], dict[str, Any]] | None = None,
|
||||
cache: bool = True,
|
||||
) -> None:
|
||||
"""Initialize a new Dataset instance.
|
||||
|
@ -333,7 +333,7 @@ class Sentinel2(Sentinel):
|
|||
self,
|
||||
sample: dict[str, Any],
|
||||
show_titles: bool = True,
|
||||
suptitle: Optional[str] = None,
|
||||
suptitle: str | None = None,
|
||||
) -> Figure:
|
||||
"""Plot a sample from the dataset.
|
||||
|
||||
|
|
|
@ -4,7 +4,8 @@
|
|||
"""SKy Images and Photovoltaic Power Dataset (SKIPP'D)."""
|
||||
|
||||
import os
|
||||
from typing import Any, Callable, Optional, Union
|
||||
from collections.abc import Callable
|
||||
from typing import Any
|
||||
|
||||
import matplotlib.pyplot as plt
|
||||
import numpy as np
|
||||
|
@ -74,7 +75,7 @@ class SKIPPD(NonGeoDataset):
|
|||
root: str = "data",
|
||||
split: str = "trainval",
|
||||
task: str = "nowcast",
|
||||
transforms: Optional[Callable[[dict[str, Any]], dict[str, Any]]] = None,
|
||||
transforms: Callable[[dict[str, Any]], dict[str, Any]] | None = None,
|
||||
download: bool = False,
|
||||
checksum: bool = False,
|
||||
) -> None:
|
||||
|
@ -133,7 +134,7 @@ class SKIPPD(NonGeoDataset):
|
|||
|
||||
return num_datapoints
|
||||
|
||||
def __getitem__(self, index: int) -> dict[str, Union[str, Tensor]]:
|
||||
def __getitem__(self, index: int) -> dict[str, str | Tensor]:
|
||||
"""Return an index within the dataset.
|
||||
|
||||
Args:
|
||||
|
@ -142,7 +143,7 @@ class SKIPPD(NonGeoDataset):
|
|||
Returns:
|
||||
data and label at that index
|
||||
"""
|
||||
sample: dict[str, Union[str, Tensor]] = {"image": self._load_image(index)}
|
||||
sample: dict[str, str | Tensor] = {"image": self._load_image(index)}
|
||||
sample.update(self._load_features(index))
|
||||
|
||||
if self.transforms is not None:
|
||||
|
@ -176,7 +177,7 @@ class SKIPPD(NonGeoDataset):
|
|||
tensor = torch.from_numpy(arr).to(torch.float32)
|
||||
return tensor
|
||||
|
||||
def _load_features(self, index: int) -> dict[str, Union[str, Tensor]]:
|
||||
def _load_features(self, index: int) -> dict[str, str | Tensor]:
|
||||
"""Load label.
|
||||
|
||||
Args:
|
||||
|
@ -195,7 +196,7 @@ class SKIPPD(NonGeoDataset):
|
|||
path = os.path.join(self.root, f"times_{self.split}_{self.task}.npy")
|
||||
datestring = np.load(path, allow_pickle=True)[index].strftime(self.dateformat)
|
||||
|
||||
features: dict[str, Union[str, Tensor]] = {
|
||||
features: dict[str, str | Tensor] = {
|
||||
"label": torch.tensor(label, dtype=torch.float32),
|
||||
"date": datestring,
|
||||
}
|
||||
|
@ -241,7 +242,7 @@ class SKIPPD(NonGeoDataset):
|
|||
self,
|
||||
sample: dict[str, Any],
|
||||
show_titles: bool = True,
|
||||
suptitle: Optional[str] = None,
|
||||
suptitle: str | None = None,
|
||||
) -> Figure:
|
||||
"""Plot a sample from the dataset.
|
||||
|
||||
|
|
|
@ -4,8 +4,8 @@
|
|||
"""So2Sat dataset."""
|
||||
|
||||
import os
|
||||
from collections.abc import Sequence
|
||||
from typing import Callable, Optional, cast
|
||||
from collections.abc import Callable, Sequence
|
||||
from typing import cast
|
||||
|
||||
import matplotlib.pyplot as plt
|
||||
import numpy as np
|
||||
|
@ -196,7 +196,7 @@ class So2Sat(NonGeoDataset):
|
|||
version: str = "2",
|
||||
split: str = "train",
|
||||
bands: Sequence[str] = BAND_SETS["all"],
|
||||
transforms: Optional[Callable[[dict[str, Tensor]], dict[str, Tensor]]] = None,
|
||||
transforms: Callable[[dict[str, Tensor]], dict[str, Tensor]] | None = None,
|
||||
checksum: bool = False,
|
||||
) -> None:
|
||||
"""Initialize a new So2Sat dataset instance.
|
||||
|
@ -340,7 +340,7 @@ class So2Sat(NonGeoDataset):
|
|||
self,
|
||||
sample: dict[str, Tensor],
|
||||
show_titles: bool = True,
|
||||
suptitle: Optional[str] = None,
|
||||
suptitle: str | None = None,
|
||||
) -> Figure:
|
||||
"""Plot a sample from the dataset.
|
||||
|
||||
|
|
|
@ -5,8 +5,8 @@
|
|||
|
||||
import os
|
||||
import re
|
||||
from collections.abc import Iterable
|
||||
from typing import Any, Callable, Optional, Union, cast
|
||||
from collections.abc import Callable, Iterable
|
||||
from typing import Any, cast
|
||||
|
||||
import matplotlib.pyplot as plt
|
||||
import torch
|
||||
|
@ -98,11 +98,11 @@ class SouthAfricaCropType(RasterDataset):
|
|||
|
||||
def __init__(
|
||||
self,
|
||||
paths: Union[str, Iterable[str]] = "data",
|
||||
crs: Optional[CRS] = None,
|
||||
paths: str | Iterable[str] = "data",
|
||||
crs: CRS | None = None,
|
||||
classes: list[int] = list(cmap.keys()),
|
||||
bands: list[str] = all_bands,
|
||||
transforms: Optional[Callable[[dict[str, Tensor]], dict[str, Tensor]]] = None,
|
||||
transforms: Callable[[dict[str, Tensor]], dict[str, Tensor]] | None = None,
|
||||
) -> None:
|
||||
"""Initialize a new South Africa Crop Type dataset instance.
|
||||
|
||||
|
@ -224,7 +224,7 @@ class SouthAfricaCropType(RasterDataset):
|
|||
self,
|
||||
sample: dict[str, Tensor],
|
||||
show_titles: bool = True,
|
||||
suptitle: Optional[str] = None,
|
||||
suptitle: str | None = None,
|
||||
) -> Figure:
|
||||
"""Plot a sample from the dataset.
|
||||
|
||||
|
|
|
@ -3,8 +3,8 @@
|
|||
|
||||
"""South America Soybean Dataset."""
|
||||
|
||||
from collections.abc import Iterable
|
||||
from typing import Any, Callable, Optional, Union
|
||||
from collections.abc import Callable, Iterable
|
||||
from typing import Any
|
||||
|
||||
import matplotlib.pyplot as plt
|
||||
from matplotlib.figure import Figure
|
||||
|
@ -71,11 +71,11 @@ class SouthAmericaSoybean(RasterDataset):
|
|||
|
||||
def __init__(
|
||||
self,
|
||||
paths: Union[str, Iterable[str]] = "data",
|
||||
crs: Optional[CRS] = None,
|
||||
res: Optional[float] = None,
|
||||
paths: str | Iterable[str] = "data",
|
||||
crs: CRS | None = None,
|
||||
res: float | None = None,
|
||||
years: list[int] = [2021],
|
||||
transforms: Optional[Callable[[dict[str, Any]], dict[str, Any]]] = None,
|
||||
transforms: Callable[[dict[str, Any]], dict[str, Any]] | None = None,
|
||||
cache: bool = True,
|
||||
download: bool = False,
|
||||
checksum: bool = False,
|
||||
|
@ -133,7 +133,7 @@ class SouthAmericaSoybean(RasterDataset):
|
|||
self,
|
||||
sample: dict[str, Any],
|
||||
show_titles: bool = True,
|
||||
suptitle: Optional[str] = None,
|
||||
suptitle: str | None = None,
|
||||
) -> Figure:
|
||||
"""Plot a sample from the dataset.
|
||||
|
||||
|
|
|
@ -9,7 +9,8 @@ import glob
|
|||
import math
|
||||
import os
|
||||
import re
|
||||
from typing import Any, Callable, Optional
|
||||
from collections.abc import Callable
|
||||
from typing import Any
|
||||
|
||||
import fiona
|
||||
import matplotlib.pyplot as plt
|
||||
|
@ -81,9 +82,9 @@ class SpaceNet(NonGeoDataset, abc.ABC):
|
|||
root: str,
|
||||
image: str,
|
||||
collections: list[str] = [],
|
||||
transforms: Optional[Callable[[dict[str, Any]], dict[str, Any]]] = None,
|
||||
transforms: Callable[[dict[str, Any]], dict[str, Any]] | None = None,
|
||||
download: bool = False,
|
||||
api_key: Optional[str] = None,
|
||||
api_key: str | None = None,
|
||||
checksum: bool = False,
|
||||
) -> None:
|
||||
"""Initialize a new SpaceNet Dataset instance.
|
||||
|
@ -274,7 +275,7 @@ class SpaceNet(NonGeoDataset, abc.ABC):
|
|||
|
||||
return to_be_downloaded
|
||||
|
||||
def _download(self, collections: list[str], api_key: Optional[str] = None) -> None:
|
||||
def _download(self, collections: list[str], api_key: str | None = None) -> None:
|
||||
"""Download the dataset and extract it.
|
||||
|
||||
Args:
|
||||
|
@ -299,7 +300,7 @@ class SpaceNet(NonGeoDataset, abc.ABC):
|
|||
self,
|
||||
sample: dict[str, Tensor],
|
||||
show_titles: bool = True,
|
||||
suptitle: Optional[str] = None,
|
||||
suptitle: str | None = None,
|
||||
) -> Figure:
|
||||
"""Plot a sample from the dataset.
|
||||
|
||||
|
@ -398,9 +399,9 @@ class SpaceNet1(SpaceNet):
|
|||
self,
|
||||
root: str = "data",
|
||||
image: str = "rgb",
|
||||
transforms: Optional[Callable[[dict[str, Any]], dict[str, Any]]] = None,
|
||||
transforms: Callable[[dict[str, Any]], dict[str, Any]] | None = None,
|
||||
download: bool = False,
|
||||
api_key: Optional[str] = None,
|
||||
api_key: str | None = None,
|
||||
checksum: bool = False,
|
||||
) -> None:
|
||||
"""Initialize a new SpaceNet 1 Dataset instance.
|
||||
|
@ -514,9 +515,9 @@ class SpaceNet2(SpaceNet):
|
|||
root: str = "data",
|
||||
image: str = "PS-RGB",
|
||||
collections: list[str] = [],
|
||||
transforms: Optional[Callable[[dict[str, Any]], dict[str, Any]]] = None,
|
||||
transforms: Callable[[dict[str, Any]], dict[str, Any]] | None = None,
|
||||
download: bool = False,
|
||||
api_key: Optional[str] = None,
|
||||
api_key: str | None = None,
|
||||
checksum: bool = False,
|
||||
) -> None:
|
||||
"""Initialize a new SpaceNet 2 Dataset instance.
|
||||
|
@ -633,11 +634,11 @@ class SpaceNet3(SpaceNet):
|
|||
self,
|
||||
root: str = "data",
|
||||
image: str = "PS-RGB",
|
||||
speed_mask: Optional[bool] = False,
|
||||
speed_mask: bool | None = False,
|
||||
collections: list[str] = [],
|
||||
transforms: Optional[Callable[[dict[str, Any]], dict[str, Any]]] = None,
|
||||
transforms: Callable[[dict[str, Any]], dict[str, Any]] | None = None,
|
||||
download: bool = False,
|
||||
api_key: Optional[str] = None,
|
||||
api_key: str | None = None,
|
||||
checksum: bool = False,
|
||||
) -> None:
|
||||
"""Initialize a new SpaceNet 3 Dataset instance.
|
||||
|
@ -733,7 +734,7 @@ class SpaceNet3(SpaceNet):
|
|||
self,
|
||||
sample: dict[str, Tensor],
|
||||
show_titles: bool = True,
|
||||
suptitle: Optional[str] = None,
|
||||
suptitle: str | None = None,
|
||||
) -> Figure:
|
||||
"""Plot a sample from the dataset.
|
||||
|
||||
|
@ -884,9 +885,9 @@ class SpaceNet4(SpaceNet):
|
|||
root: str = "data",
|
||||
image: str = "PS-RGBNIR",
|
||||
angles: list[str] = [],
|
||||
transforms: Optional[Callable[[dict[str, Any]], dict[str, Any]]] = None,
|
||||
transforms: Callable[[dict[str, Any]], dict[str, Any]] | None = None,
|
||||
download: bool = False,
|
||||
api_key: Optional[str] = None,
|
||||
api_key: str | None = None,
|
||||
checksum: bool = False,
|
||||
) -> None:
|
||||
"""Initialize a new SpaceNet 4 Dataset instance.
|
||||
|
@ -1051,11 +1052,11 @@ class SpaceNet5(SpaceNet3):
|
|||
self,
|
||||
root: str = "data",
|
||||
image: str = "PS-RGB",
|
||||
speed_mask: Optional[bool] = False,
|
||||
speed_mask: bool | None = False,
|
||||
collections: list[str] = [],
|
||||
transforms: Optional[Callable[[dict[str, Any]], dict[str, Any]]] = None,
|
||||
transforms: Callable[[dict[str, Any]], dict[str, Any]] | None = None,
|
||||
download: bool = False,
|
||||
api_key: Optional[str] = None,
|
||||
api_key: str | None = None,
|
||||
checksum: bool = False,
|
||||
) -> None:
|
||||
"""Initialize a new SpaceNet 5 Dataset instance.
|
||||
|
@ -1183,9 +1184,9 @@ class SpaceNet6(SpaceNet):
|
|||
self,
|
||||
root: str = "data",
|
||||
image: str = "PS-RGB",
|
||||
transforms: Optional[Callable[[dict[str, Any]], dict[str, Any]]] = None,
|
||||
transforms: Callable[[dict[str, Any]], dict[str, Any]] | None = None,
|
||||
download: bool = False,
|
||||
api_key: Optional[str] = None,
|
||||
api_key: str | None = None,
|
||||
) -> None:
|
||||
"""Initialize a new SpaceNet 6 Dataset instance.
|
||||
|
||||
|
@ -1212,7 +1213,7 @@ class SpaceNet6(SpaceNet):
|
|||
|
||||
self.files = self._load_files(os.path.join(root, self.dataset_id))
|
||||
|
||||
def __download(self, api_key: Optional[str] = None) -> None:
|
||||
def __download(self, api_key: str | None = None) -> None:
|
||||
"""Download the dataset and extract it.
|
||||
|
||||
Args:
|
||||
|
@ -1281,9 +1282,9 @@ class SpaceNet7(SpaceNet):
|
|||
self,
|
||||
root: str = "data",
|
||||
split: str = "train",
|
||||
transforms: Optional[Callable[[dict[str, Any]], dict[str, Any]]] = None,
|
||||
transforms: Callable[[dict[str, Any]], dict[str, Any]] | None = None,
|
||||
download: bool = False,
|
||||
api_key: Optional[str] = None,
|
||||
api_key: str | None = None,
|
||||
checksum: bool = False,
|
||||
) -> None:
|
||||
"""Initialize a new SpaceNet 7 Dataset instance.
|
||||
|
|
|
@ -7,7 +7,7 @@ from collections.abc import Sequence
|
|||
from copy import deepcopy
|
||||
from itertools import accumulate
|
||||
from math import floor, isclose
|
||||
from typing import Optional, Union, cast
|
||||
from typing import cast
|
||||
|
||||
from rtree.index import Index, Property
|
||||
from torch import Generator, default_generator, randint, randperm
|
||||
|
@ -50,7 +50,7 @@ def _fractions_to_lengths(fractions: Sequence[float], total: int) -> Sequence[in
|
|||
def random_bbox_assignment(
|
||||
dataset: GeoDataset,
|
||||
lengths: Sequence[float],
|
||||
generator: Optional[Generator] = default_generator,
|
||||
generator: Generator | None = default_generator,
|
||||
) -> list[GeoDataset]:
|
||||
"""Split a GeoDataset randomly assigning its index's BoundingBoxes.
|
||||
|
||||
|
@ -104,7 +104,7 @@ def random_bbox_assignment(
|
|||
def random_bbox_splitting(
|
||||
dataset: GeoDataset,
|
||||
fractions: Sequence[float],
|
||||
generator: Optional[Generator] = default_generator,
|
||||
generator: Generator | None = default_generator,
|
||||
) -> list[GeoDataset]:
|
||||
"""Split a GeoDataset randomly splitting its index's BoundingBoxes.
|
||||
|
||||
|
@ -172,7 +172,7 @@ def random_grid_cell_assignment(
|
|||
dataset: GeoDataset,
|
||||
fractions: Sequence[float],
|
||||
grid_size: int = 6,
|
||||
generator: Optional[Generator] = default_generator,
|
||||
generator: Generator | None = default_generator,
|
||||
) -> list[GeoDataset]:
|
||||
"""Overlays a grid over a GeoDataset and randomly assigns cells to new GeoDatasets.
|
||||
|
||||
|
@ -289,7 +289,7 @@ def roi_split(dataset: GeoDataset, rois: Sequence[BoundingBox]) -> list[GeoDatas
|
|||
|
||||
|
||||
def time_series_split(
|
||||
dataset: GeoDataset, lengths: Sequence[Union[float, tuple[float, float]]]
|
||||
dataset: GeoDataset, lengths: Sequence[float | tuple[float, float]]
|
||||
) -> list[GeoDataset]:
|
||||
"""Split a GeoDataset on its time dimension to create non-overlapping GeoDatasets.
|
||||
|
||||
|
|
|
@ -6,7 +6,8 @@
|
|||
import glob
|
||||
import os
|
||||
import random
|
||||
from typing import Callable, Optional, TypedDict
|
||||
from collections.abc import Callable
|
||||
from typing import TypedDict
|
||||
|
||||
import matplotlib.pyplot as plt
|
||||
import numpy as np
|
||||
|
@ -163,7 +164,7 @@ class SSL4EOL(NonGeoDataset):
|
|||
root: str = "data",
|
||||
split: str = "oli_sr",
|
||||
seasons: int = 1,
|
||||
transforms: Optional[Callable[[dict[str, Tensor]], dict[str, Tensor]]] = None,
|
||||
transforms: Callable[[dict[str, Tensor]], dict[str, Tensor]] | None = None,
|
||||
download: bool = False,
|
||||
checksum: bool = False,
|
||||
) -> None:
|
||||
|
@ -285,7 +286,7 @@ class SSL4EOL(NonGeoDataset):
|
|||
self,
|
||||
sample: dict[str, Tensor],
|
||||
show_titles: bool = True,
|
||||
suptitle: Optional[str] = None,
|
||||
suptitle: str | None = None,
|
||||
) -> Figure:
|
||||
"""Plot a sample from the dataset.
|
||||
|
||||
|
@ -405,7 +406,7 @@ class SSL4EOS12(NonGeoDataset):
|
|||
root: str = "data",
|
||||
split: str = "s2c",
|
||||
seasons: int = 1,
|
||||
transforms: Optional[Callable[[dict[str, Tensor]], dict[str, Tensor]]] = None,
|
||||
transforms: Callable[[dict[str, Tensor]], dict[str, Tensor]] | None = None,
|
||||
checksum: bool = False,
|
||||
) -> None:
|
||||
"""Initialize a new SSL4EOS12 instance.
|
||||
|
@ -500,7 +501,7 @@ class SSL4EOS12(NonGeoDataset):
|
|||
self,
|
||||
sample: dict[str, Tensor],
|
||||
show_titles: bool = True,
|
||||
suptitle: Optional[str] = None,
|
||||
suptitle: str | None = None,
|
||||
) -> Figure:
|
||||
"""Plot a sample from the dataset.
|
||||
|
||||
|
|
|
@ -5,7 +5,7 @@
|
|||
|
||||
import glob
|
||||
import os
|
||||
from typing import Callable, Optional
|
||||
from collections.abc import Callable
|
||||
|
||||
import matplotlib.pyplot as plt
|
||||
import numpy as np
|
||||
|
@ -110,8 +110,8 @@ class SSL4EOLBenchmark(NonGeoDataset):
|
|||
sensor: str = "oli_sr",
|
||||
product: str = "cdl",
|
||||
split: str = "train",
|
||||
classes: Optional[list[int]] = None,
|
||||
transforms: Optional[Callable[[dict[str, Tensor]], dict[str, Tensor]]] = None,
|
||||
classes: list[int] | None = None,
|
||||
transforms: Callable[[dict[str, Tensor]], dict[str, Tensor]] | None = None,
|
||||
download: bool = False,
|
||||
checksum: bool = False,
|
||||
) -> None:
|
||||
|
@ -327,7 +327,7 @@ class SSL4EOLBenchmark(NonGeoDataset):
|
|||
self,
|
||||
sample: dict[str, Tensor],
|
||||
show_titles: bool = True,
|
||||
suptitle: Optional[str] = None,
|
||||
suptitle: str | None = None,
|
||||
) -> Figure:
|
||||
"""Plot a sample from the dataset.
|
||||
|
||||
|
|
|
@ -4,7 +4,8 @@
|
|||
"""SustainBench Crop Yield dataset."""
|
||||
|
||||
import os
|
||||
from typing import Any, Callable, Optional
|
||||
from collections.abc import Callable
|
||||
from typing import Any
|
||||
|
||||
import matplotlib.pyplot as plt
|
||||
import numpy as np
|
||||
|
@ -60,7 +61,7 @@ class SustainBenchCropYield(NonGeoDataset):
|
|||
root: str = "data",
|
||||
split: str = "train",
|
||||
countries: list[str] = ["usa"],
|
||||
transforms: Optional[Callable[[dict[str, Any]], dict[str, Any]]] = None,
|
||||
transforms: Callable[[dict[str, Any]], dict[str, Any]] | None = None,
|
||||
download: bool = False,
|
||||
checksum: bool = False,
|
||||
) -> None:
|
||||
|
@ -195,7 +196,7 @@ class SustainBenchCropYield(NonGeoDataset):
|
|||
sample: dict[str, Any],
|
||||
band_idx: int = 0,
|
||||
show_titles: bool = True,
|
||||
suptitle: Optional[str] = None,
|
||||
suptitle: str | None = None,
|
||||
) -> Figure:
|
||||
"""Plot a sample from the dataset.
|
||||
|
||||
|
|
|
@ -3,7 +3,8 @@
|
|||
|
||||
"""UC Merced dataset."""
|
||||
import os
|
||||
from typing import Callable, Optional, cast
|
||||
from collections.abc import Callable
|
||||
from typing import cast
|
||||
|
||||
import matplotlib.pyplot as plt
|
||||
import numpy as np
|
||||
|
@ -85,7 +86,7 @@ class UCMerced(NonGeoClassificationDataset):
|
|||
self,
|
||||
root: str = "data",
|
||||
split: str = "train",
|
||||
transforms: Optional[Callable[[dict[str, Tensor]], dict[str, Tensor]]] = None,
|
||||
transforms: Callable[[dict[str, Tensor]], dict[str, Tensor]] | None = None,
|
||||
download: bool = False,
|
||||
checksum: bool = False,
|
||||
) -> None:
|
||||
|
@ -190,7 +191,7 @@ class UCMerced(NonGeoClassificationDataset):
|
|||
self,
|
||||
sample: dict[str, Tensor],
|
||||
show_titles: bool = True,
|
||||
suptitle: Optional[str] = None,
|
||||
suptitle: str | None = None,
|
||||
) -> Figure:
|
||||
"""Plot a sample from the dataset.
|
||||
|
||||
|
|
Некоторые файлы не были показаны из-за слишком большого количества измененных файлов Показать больше
Загрузка…
Ссылка в новой задаче