Ruff: enable ruff-specific rules (#2218)

* Ruff: enable ruff-specific rules

* Static class variables

* String colormap must be list

* String colormap must be list
Adam J. Stewart 2024-08-19 15:07:21 +02:00 committed by GitHub
Parent c26512e39d
Commit 067ae1af75
No known key found for this signature
GPG key ID: B5690EEEBB952194
98 changed files with 668 additions and 595 deletions

View File

@@ -19,7 +19,7 @@ import pytorch_sphinx_theme
 # documentation root, use os.path.abspath to make it absolute, like shown here.
 sys.path.insert(0, os.path.abspath('..'))
-import torchgeo  # noqa: E402
+import torchgeo
 # -- Project information -----------------------------------------------------

View File

@@ -369,8 +369,8 @@
 "    date_format = '%Y%m%dT%H%M%S'\n",
 "    is_image = True\n",
 "    separate_files = True\n",
-"    all_bands = ['B02', 'B03', 'B04', 'B08']\n",
-"    rgb_bands = ['B04', 'B03', 'B02']"
+"    all_bands = ('B02', 'B03', 'B04', 'B08')\n",
+"    rgb_bands = ('B04', 'B03', 'B02')"
 ]
 },
 {
@@ -432,8 +432,8 @@
 "    date_format = '%Y%m%dT%H%M%S'\n",
 "    is_image = True\n",
 "    separate_files = True\n",
-"    all_bands = ['B02', 'B03', 'B04', 'B08']\n",
-"    rgb_bands = ['B04', 'B03', 'B02']\n",
+"    all_bands = ('B02', 'B03', 'B04', 'B08')\n",
+"    rgb_bands = ('B04', 'B03', 'B02')\n",
 "\n",
 "    def plot(self, sample):\n",
 "        # Find the correct band index order\n",

View File

@@ -125,7 +125,7 @@ def filter_collection(
     if filtered.size().getInfo() == 0:
         raise ee.EEException(
-            f'ImageCollection.filter: No suitable images found in ({coords[1]:.4f}, {coords[0]:.4f}) between {period[0]} and {period[1]}.'  # noqa: E501
+            f'ImageCollection.filter: No suitable images found in ({coords[1]:.4f}, {coords[0]:.4f}) between {period[0]} and {period[1]}.'
         )
     return filtered

View File

@@ -47,7 +47,7 @@ from tqdm import tqdm
 def get_world_cities(
     download_root: str = 'world_cities', size: int = 10000
 ) -> pd.DataFrame:
-    url = 'https://simplemaps.com/static/data/world-cities/basic/simplemaps_worldcities_basicv1.71.zip'  # noqa: E501
+    url = 'https://simplemaps.com/static/data/world-cities/basic/simplemaps_worldcities_basicv1.71.zip'
     filename = 'worldcities.csv'
     download_and_extract_archive(url, download_root)
     cols = ['city', 'lat', 'lng', 'population']

View File

@@ -268,7 +268,7 @@ quote-style = "single"
 skip-magic-trailing-comma = true
 [tool.ruff.lint]
-extend-select = ["ANN", "D", "I", "NPY201", "UP"]
+extend-select = ["ANN", "D", "I", "NPY201", "RUF", "UP"]
 ignore = ["ANN101", "ANN102", "ANN401"]
 [tool.ruff.lint.per-file-ignores]
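The new "RUF" entry enables Ruff's own rule group alongside the flake8-style codes already selected. Most of the churn in the files below comes from RUF012, which asks that mutable class-level defaults be annotated with typing.ClassVar (the "Static class variables" commit), together with a general move from lists to tuples for class constants. A minimal sketch of the before/after pattern (the class name is hypothetical; the values are taken from the diff below):

    from typing import ClassVar


    class Example:
        # RUF012: a bare `md5s = {...}` is flagged as a mutable class attribute;
        # the ClassVar annotation marks it as a shared class-level constant
        md5s: ClassVar[dict[int, str]] = {2023: '8c7685d6278d50c554f934b16a6076b7'}

        # sequences that are never mutated become tuples instead of lists
        all_bands = ('B02', 'B03', 'B04', 'B08')
        rgb_bands = ('B04', 'B03', 'B02')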

View File

@@ -19,36 +19,36 @@ np.random.seed(0)
 train_set = [
     {
-        'image': 'labeled_train/Nantes_Saint-Nazaire/BDORTHO/44-2013-0295-6713-LA93-0M50-E080.tif',  # noqa: E501
-        'dem': 'labeled_train/Nantes_Saint-Nazaire/RGEALTI/44-2013-0295-6713-LA93-0M50-E080_RGEALTI.tif',  # noqa: E501
-        'target': 'labeled_train/Nantes_Saint-Nazaire/UrbanAtlas/44-2013-0295-6713-LA93-0M50-E080_UA2012.tif',  # noqa: E501
+        'image': 'labeled_train/Nantes_Saint-Nazaire/BDORTHO/44-2013-0295-6713-LA93-0M50-E080.tif',
+        'dem': 'labeled_train/Nantes_Saint-Nazaire/RGEALTI/44-2013-0295-6713-LA93-0M50-E080_RGEALTI.tif',
+        'target': 'labeled_train/Nantes_Saint-Nazaire/UrbanAtlas/44-2013-0295-6713-LA93-0M50-E080_UA2012.tif',
     },
     {
-        'image': 'labeled_train/Nice/BDORTHO/06-2014-1007-6318-LA93-0M50-E080.tif',  # noqa: E501
-        'dem': 'labeled_train/Nice/RGEALTI/06-2014-1007-6318-LA93-0M50-E080_RGEALTI.tif',  # noqa: E501
-        'target': 'labeled_train/Nice/UrbanAtlas/06-2014-1007-6318-LA93-0M50-E080_UA2012.tif',  # noqa: E501
+        'image': 'labeled_train/Nice/BDORTHO/06-2014-1007-6318-LA93-0M50-E080.tif',
+        'dem': 'labeled_train/Nice/RGEALTI/06-2014-1007-6318-LA93-0M50-E080_RGEALTI.tif',
+        'target': 'labeled_train/Nice/UrbanAtlas/06-2014-1007-6318-LA93-0M50-E080_UA2012.tif',
     },
 ]
 unlabeled_set = [
     {
-        'image': 'unlabeled_train/Calais_Dunkerque/BDORTHO/59-2012-0650-7077-LA93-0M50-E080.tif',  # noqa: E501
-        'dem': 'unlabeled_train/Calais_Dunkerque/RGEALTI/59-2012-0650-7077-LA93-0M50-E080_RGEALTI.tif',  # noqa: E501
+        'image': 'unlabeled_train/Calais_Dunkerque/BDORTHO/59-2012-0650-7077-LA93-0M50-E080.tif',
+        'dem': 'unlabeled_train/Calais_Dunkerque/RGEALTI/59-2012-0650-7077-LA93-0M50-E080_RGEALTI.tif',
     },
     {
-        'image': 'unlabeled_train/LeMans/BDORTHO/72-2013-0469-6789-LA93-0M50-E080.tif',  # noqa: E501
-        'dem': 'unlabeled_train/LeMans/RGEALTI/72-2013-0469-6789-LA93-0M50-E080_RGEALTI.tif',  # noqa: E501
+        'image': 'unlabeled_train/LeMans/BDORTHO/72-2013-0469-6789-LA93-0M50-E080.tif',
+        'dem': 'unlabeled_train/LeMans/RGEALTI/72-2013-0469-6789-LA93-0M50-E080_RGEALTI.tif',
     },
 ]
 val_set = [
     {
-        'image': 'val/Marseille_Martigues/BDORTHO/13-2014-0900-6268-LA93-0M50-E080.tif',  # noqa: E501
-        'dem': 'val/Marseille_Martigues/RGEALTI/13-2014-0900-6268-LA93-0M50-E080_RGEALTI.tif',  # noqa: E501
+        'image': 'val/Marseille_Martigues/BDORTHO/13-2014-0900-6268-LA93-0M50-E080.tif',
+        'dem': 'val/Marseille_Martigues/RGEALTI/13-2014-0900-6268-LA93-0M50-E080_RGEALTI.tif',
     },
     {
-        'image': 'val/Clermont-Ferrand/BDORTHO/63-2013-0711-6530-LA93-0M50-E080.tif',  # noqa: E501
-        'dem': 'val/Clermont-Ferrand/RGEALTI/63-2013-0711-6530-LA93-0M50-E080_RGEALTI.tif',  # noqa: E501
+        'image': 'val/Clermont-Ferrand/BDORTHO/63-2013-0711-6530-LA93-0M50-E080.tif',
+        'dem': 'val/Clermont-Ferrand/RGEALTI/63-2013-0711-6530-LA93-0M50-E080_RGEALTI.tif',
     },
 ]

View File

@@ -112,7 +112,7 @@ for season in seasons:
 # Compute checksums
 with open(archive, 'rb') as f:
     md5 = hashlib.md5(f.read()).hexdigest()
-print(f'{season}: {repr(md5)}')
+print(f'{season}: {md5!r}')
 # Write meta.csv
 with open('meta.csv', 'w') as f:
@@ -121,7 +121,7 @@ with open('meta.csv', 'w') as f:
 # Compute checksums
 with open('meta.csv', 'rb') as f:
     md5 = hashlib.md5(f.read()).hexdigest()
-print(f'meta.csv: {repr(md5)}')
+print(f'meta.csv: {md5!r}')
 os.makedirs('splits', exist_ok=True)
@@ -138,4 +138,4 @@ shutil.make_archive('splits', 'zip', '.', 'splits')
 # Compute checksums
 with open('splits.zip', 'rb') as f:
     md5 = hashlib.md5(f.read()).hexdigest()
-print(f'splits: {repr(md5)}')
+print(f'splits: {md5!r}')
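The repr() calls inside f-strings are RUF010 (explicit conversion flag): f'{md5!r}' means the same thing as f'{repr(md5)}' but states the conversion in the format spec itself. A standalone sketch:

    import hashlib

    md5 = hashlib.md5(b'example').hexdigest()
    # the !r flag applies repr() to the interpolated value
    assert f'{md5!r}' == f'{repr(md5)}'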

View File

@@ -83,5 +83,5 @@ class TestEuroCrops:
         dataset[query]
     def test_integrity_error(self, dataset: EuroCrops) -> None:
-        dataset.zenodo_files = [('AA.zip', 'invalid')]
+        dataset.zenodo_files = (('AA.zip', 'invalid'),)
         assert not dataset._check_integrity()

View File

@@ -72,7 +72,7 @@ class CustomVectorDataset(VectorDataset):
 class CustomSentinelDataset(Sentinel2):
-    all_bands: list[str] = []
+    all_bands: tuple[str, ...] = ()
     separate_files = False
@@ -356,7 +356,7 @@ class TestRasterDataset:
     def test_no_all_bands(self) -> None:
         root = os.path.join('tests', 'data', 'sentinel2')
-        bands = ['B04', 'B03', 'B02']
+        bands = ('B04', 'B03', 'B02')
         transforms = nn.Identity()
         cache = True
         msg = (

View File

@@ -73,7 +73,7 @@ class TestBYOLTask:
             '1',
         ]
-        main(['fit'] + args)
+        main(['fit', *args])
     @pytest.fixture
     def weights(self) -> WeightsEnum:
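The main(['fit'] + args) rewrites here and in the trainer tests below are RUF005 (prefer iterable unpacking over concatenation): a starred element builds the argv list in one literal instead of concatenating two lists. A sketch with hypothetical arguments:

    args = ['--config', 'config.yaml', '--trainer.max_epochs', '1']
    # RUF005: unpack into one list literal instead of concatenating
    assert ['fit', *args] == ['fit'] + args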

View File

@@ -103,13 +103,13 @@ class TestClassificationTask:
             '1',
         ]
-        main(['fit'] + args)
+        main(['fit', *args])
         try:
-            main(['test'] + args)
+            main(['test', *args])
         except MisconfigurationException:
             pass
         try:
-            main(['predict'] + args)
+            main(['predict', *args])
         except MisconfigurationException:
             pass
@@ -259,13 +259,13 @@ class TestMultiLabelClassificationTask:
             '1',
         ]
-        main(['fit'] + args)
+        main(['fit', *args])
         try:
-            main(['test'] + args)
+            main(['test', *args])
         except MisconfigurationException:
             pass
         try:
-            main(['predict'] + args)
+            main(['predict', *args])
         except MisconfigurationException:
             pass

View File

@@ -97,13 +97,13 @@ class TestObjectDetectionTask:
             '1',
         ]
-        main(['fit'] + args)
+        main(['fit', *args])
         try:
-            main(['test'] + args)
+            main(['test', *args])
         except MisconfigurationException:
             pass
         try:
-            main(['predict'] + args)
+            main(['predict', *args])
         except MisconfigurationException:
             pass

View File

@@ -27,12 +27,12 @@ class TestClassificationTask:
             '1',
         ]
-        main(['fit'] + args)
+        main(['fit', *args])
         try:
-            main(['test'] + args)
+            main(['test', *args])
         except MisconfigurationException:
             pass
         try:
-            main(['predict'] + args)
+            main(['predict', *args])
         except MisconfigurationException:
             pass

View File

@@ -63,7 +63,7 @@ class TestMoCoTask:
             '1',
         ]
-        main(['fit'] + args)
+        main(['fit', *args])
     def test_version_warnings(self) -> None:
         with pytest.warns(UserWarning, match='MoCo v1 uses a memory bank'):

View File

@@ -84,13 +84,13 @@ class TestRegressionTask:
             '1',
         ]
-        main(['fit'] + args)
+        main(['fit', *args])
         try:
-            main(['test'] + args)
+            main(['test', *args])
         except MisconfigurationException:
             pass
         try:
-            main(['predict'] + args)
+            main(['predict', *args])
         except MisconfigurationException:
             pass
@@ -237,13 +237,13 @@ class TestPixelwiseRegressionTask:
             '1',
         ]
-        main(['fit'] + args)
+        main(['fit', *args])
         try:
-            main(['test'] + args)
+            main(['test', *args])
         except MisconfigurationException:
             pass
         try:
-            main(['predict'] + args)
+            main(['predict', *args])
         except MisconfigurationException:
             pass

View File

@@ -108,13 +108,13 @@ class TestSemanticSegmentationTask:
             '1',
         ]
-        main(['fit'] + args)
+        main(['fit', *args])
         try:
-            main(['test'] + args)
+            main(['test', *args])
         except MisconfigurationException:
             pass
         try:
-            main(['predict'] + args)
+            main(['predict', *args])
         except MisconfigurationException:
             pass

View File

@@ -63,7 +63,7 @@ class TestSimCLRTask:
             '1',
         ]
-        main(['fit'] + args)
+        main(['fit', *args])
     def test_version_warnings(self) -> None:
         with pytest.warns(UserWarning, match='SimCLR v1 only uses 2 layers'):

View File

@@ -37,7 +37,7 @@ class SeasonalContrastS2DataModule(NonGeoDataModule):
         seasons = kwargs.get('seasons', 1)
         # Normalization only available for RGB dataset, defined here:
-        # https://github.com/ServiceNow/seasonal-contrast/blob/8285173ec205b64bc3e53b880344dd6c3f79fa7a/datasets/seco_dataset.py  # noqa: E501
+        # https://github.com/ServiceNow/seasonal-contrast/blob/8285173ec205b64bc3e53b880344dd6c3f79fa7a/datasets/seco_dataset.py
         if bands == SeasonalContrastS2.rgb_bands:
             _min = torch.tensor([3, 2, 0])
             _max = torch.tensor([88, 103, 129])

View File

@@ -3,7 +3,7 @@
 """So2Sat datamodule."""
-from typing import Any
+from typing import Any, ClassVar
 import torch
 from torch import Generator, Tensor
@@ -21,7 +21,7 @@ class So2SatDataModule(NonGeoDataModule):
     "train" set and use the "test" set as the test set.
     """
-    means_per_version: dict[str, Tensor] = {
+    means_per_version: ClassVar[dict[str, Tensor]] = {
         '2': torch.tensor(
             [
                 -0.00003591224260,
@@ -91,7 +91,7 @@ class So2SatDataModule(NonGeoDataModule):
     }
     means_per_version['3_culture_10'] = means_per_version['2']
-    stds_per_version: dict[str, Tensor] = {
+    stds_per_version: ClassVar[dict[str, Tensor]] = {
         '2': torch.tensor(
             [
                 0.17555201,
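ClassVar is a typing-level marker only; the annotated dict remains an ordinary class attribute, which is why the class body can still alias one version to another, as in means_per_version['3_culture_10'] = means_per_version['2'] above. A reduced sketch, assuming torch is installed:

    from typing import ClassVar

    import torch
    from torch import Tensor


    class So2SatLike:
        means_per_version: ClassVar[dict[str, Tensor]] = {'2': torch.tensor([0.0])}
        # names bound earlier in the class body are in scope, so the
        # dict can be extended while the body executes
        means_per_version['3_culture_10'] = means_per_version['2']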

View File

@@ -45,7 +45,7 @@ class SSL4EOS12DataModule(NonGeoDataModule):
     .. versionadded:: 0.5
     """
-    # https://github.com/zhu-xlab/SSL4EO-S12/blob/d2868adfada65e40910bfcedfc49bc3b20df2248/src/benchmark/transfer_classification/datasets/EuroSat/eurosat_dataset.py#L97  # noqa: E501
+    # https://github.com/zhu-xlab/SSL4EO-S12/blob/d2868adfada65e40910bfcedfc49bc3b20df2248/src/benchmark/transfer_classification/datasets/EuroSat/eurosat_dataset.py#L97
     mean = torch.tensor(0)
     std = torch.tensor(10000)

View File

@@ -63,14 +63,14 @@ class ADVANCE(NonGeoDataset):
     * `scipy <https://pypi.org/project/scipy/>`_ to load the audio files to tensors
     """
-    urls = [
+    urls = (
         'https://zenodo.org/record/3828124/files/ADVANCE_vision.zip?download=1',
         'https://zenodo.org/record/3828124/files/ADVANCE_sound.zip?download=1',
-    ]
-    filenames = ['ADVANCE_vision.zip', 'ADVANCE_sound.zip']
-    md5s = ['a9e8748219ef5864d3b5a8979a67b471', 'a2d12f2d2a64f5c3d3a9d8c09aaf1c31']
-    directories = ['vision', 'sound']
-    classes = [
+    )
+    filenames = ('ADVANCE_vision.zip', 'ADVANCE_sound.zip')
+    md5s = ('a9e8748219ef5864d3b5a8979a67b471', 'a2d12f2d2a64f5c3d3a9d8c09aaf1c31')
+    directories = ('vision', 'sound')
+    classes: tuple[str, ...] = (
         'airport',
         'beach',
         'bridge',
@@ -84,7 +84,7 @@ class ADVANCE(NonGeoDataset):
         'sparse shrub land',
         'sports land',
         'train station',
-    ]
+    )
     def __init__(
         self,
@@ -119,7 +119,7 @@ class ADVANCE(NonGeoDataset):
             raise DatasetNotFoundError(self)
         self.files = self._load_files(self.root)
-        self.classes = sorted({f['cls'] for f in self.files})
+        self.classes = tuple(sorted({f['cls'] for f in self.files}))
         self.class_to_idx: dict[str, int] = {c: i for i, c in enumerate(self.classes)}
     def __getitem__(self, index: int) -> dict[str, Tensor]:

View File

@@ -46,7 +46,7 @@ class AbovegroundLiveWoodyBiomassDensity(RasterDataset):
     is_image = False
-    url = 'https://opendata.arcgis.com/api/v3/datasets/e4bdbe8d6d8d4e32ace7d36a4aec7b93_0/downloads/data?format=geojson&spatialRefId=4326'  # noqa: E501
+    url = 'https://opendata.arcgis.com/api/v3/datasets/e4bdbe8d6d8d4e32ace7d36a4aec7b93_0/downloads/data?format=geojson&spatialRefId=4326'
     base_filename = 'Aboveground_Live_Woody_Biomass_Density.geojson'

View File

@@ -7,7 +7,7 @@ import os
 import pathlib
 import re
 from collections.abc import Callable, Iterable, Sequence
-from typing import Any, cast
+from typing import Any, ClassVar, cast
 import matplotlib.pyplot as plt
 import torch
@@ -90,8 +90,8 @@ class AgriFieldNet(RasterDataset):
     _(?P<band>B[0-9A-Z]{2})_10m
     """
-    rgb_bands = ['B04', 'B03', 'B02']
-    all_bands = [
+    rgb_bands = ('B04', 'B03', 'B02')
+    all_bands = (
         'B01',
         'B02',
         'B03',
@@ -104,9 +104,9 @@ class AgriFieldNet(RasterDataset):
         'B09',
         'B11',
         'B12',
-    ]
-    cmap = {
+    )
+    cmap: ClassVar[dict[int, tuple[int, int, int, int]]] = {
         0: (0, 0, 0, 255),
         1: (255, 211, 0, 255),
         2: (255, 37, 37, 255),

View File

@@ -40,8 +40,8 @@ class Airphen(RasterDataset):
     # Each camera measures a custom set of spectral bands chosen at purchase time.
     # Hiphen offers 8 bands to choose from, sorted from short to long wavelength.
-    all_bands = ['B1', 'B2', 'B3', 'B4', 'B5', 'B6', 'B7', 'B8']
-    rgb_bands = ['B4', 'B3', 'B1']
+    all_bands = ('B1', 'B2', 'B3', 'B4', 'B5', 'B6', 'B7', 'B8')
+    rgb_bands = ('B4', 'B3', 'B1')
     def plot(
         self,

View File

@@ -147,7 +147,7 @@ class BeninSmallHolderCashews(NonGeoDataset):
     )
     rgb_bands = ('B04', 'B03', 'B02')
-    classes = [
+    classes = (
         'No data',
         'Well-managed planatation',
         'Poorly-managed planatation',
@@ -155,7 +155,7 @@ class BeninSmallHolderCashews(NonGeoDataset):
         'Residential',
         'Background',
         'Uncertain',
-    ]
+    )
     # Same for all tiles
     tile_height = 1186
@@ -199,11 +199,13 @@ class BeninSmallHolderCashews(NonGeoDataset):
         # Calculate the indices that we will use over all tiles
         self.chips_metadata = []
-        for y in list(range(0, self.tile_height - self.chip_size, stride)) + [
-            self.tile_height - self.chip_size
+        for y in [
+            *list(range(0, self.tile_height - self.chip_size, stride)),
+            self.tile_height - self.chip_size,
         ]:
-            for x in list(range(0, self.tile_width - self.chip_size, stride)) + [
-                self.tile_width - self.chip_size
+            for x in [
+                *list(range(0, self.tile_width - self.chip_size, stride)),
+                self.tile_width - self.chip_size,
             ]:
                 self.chips_metadata.append((y, x))
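The chip-window loops trade list concatenation for unpacking inside a single list display (RUF005 again); appending the final offset still guarantees the last chip ends flush with the tile edge. A self-contained sketch with small numbers (*range(...) alone would suffice, though the diff keeps the explicit list() call):

    tile_height = tile_width = 10
    chip_size = 4
    stride = 3

    chips = []
    for y in [*range(0, tile_height - chip_size, stride), tile_height - chip_size]:
        for x in [*range(0, tile_width - chip_size, stride), tile_width - chip_size]:
            chips.append((y, x))

    # the appended bound covers the tile edge even when the stride skips it
    assert chips[-1] == (tile_height - chip_size, tile_width - chip_size)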

View File

@@ -7,6 +7,7 @@ import glob
 import json
 import os
 from collections.abc import Callable
+from typing import ClassVar
 import matplotlib.pyplot as plt
 import numpy as np
@@ -124,9 +125,9 @@ class BigEarthNet(NonGeoDataset):
     * https://doi.org/10.1109/IGARSS.2019.8900532
-    """  # noqa: E501
-    class_sets = {
+    """
+    class_sets: ClassVar[dict[int, list[str]]] = {
         19: [
             'Urban fabric',
             'Industrial or commercial units',
@@ -197,7 +198,7 @@
         ],
     }
-    label_converter = {
+    label_converter: ClassVar[dict[int, int]] = {
         0: 0,
         1: 0,
         2: 1,
@@ -232,24 +233,24 @@
         42: 18,
     }
-    splits_metadata = {
+    splits_metadata: ClassVar[dict[str, dict[str, str]]] = {
         'train': {
-            'url': 'https://git.tu-berlin.de/rsim/BigEarthNet-MM_19-classes_models/-/raw/9a5be07346ab0884b2d9517475c27ef9db9b5104/splits/train.csv?inline=false',  # noqa: E501
+            'url': 'https://git.tu-berlin.de/rsim/BigEarthNet-MM_19-classes_models/-/raw/9a5be07346ab0884b2d9517475c27ef9db9b5104/splits/train.csv?inline=false',
             'filename': 'bigearthnet-train.csv',
             'md5': '623e501b38ab7b12fe44f0083c00986d',
         },
         'val': {
-            'url': 'https://git.tu-berlin.de/rsim/BigEarthNet-MM_19-classes_models/-/raw/9a5be07346ab0884b2d9517475c27ef9db9b5104/splits/val.csv?inline=false',  # noqa: E501
+            'url': 'https://git.tu-berlin.de/rsim/BigEarthNet-MM_19-classes_models/-/raw/9a5be07346ab0884b2d9517475c27ef9db9b5104/splits/val.csv?inline=false',
             'filename': 'bigearthnet-val.csv',
             'md5': '22efe8ed9cbd71fa10742ff7df2b7978',
         },
         'test': {
-            'url': 'https://git.tu-berlin.de/rsim/BigEarthNet-MM_19-classes_models/-/raw/9a5be07346ab0884b2d9517475c27ef9db9b5104/splits/test.csv?inline=false',  # noqa: E501
+            'url': 'https://git.tu-berlin.de/rsim/BigEarthNet-MM_19-classes_models/-/raw/9a5be07346ab0884b2d9517475c27ef9db9b5104/splits/test.csv?inline=false',
             'filename': 'bigearthnet-test.csv',
             'md5': '697fb90677e30571b9ac7699b7e5b432',
         },
     }
-    metadata = {
+    metadata: ClassVar[dict[str, dict[str, str]]] = {
         's1': {
             'url': 'https://zenodo.org/records/12687186/files/BigEarthNet-S1-v1.0.tar.gz',
             'md5': '94ced73440dea8c7b9645ee738c5a172',

View File

@@ -50,7 +50,7 @@ class BioMassters(NonGeoDataset):
     .. versionadded:: 0.5
     """
-    valid_splits = ['train', 'test']
+    valid_splits = ('train', 'test')
     valid_sensors = ('S1', 'S2')
     metadata_filename = 'The_BioMassters_-_features_metadata.csv.csv'

View File

@@ -30,7 +30,7 @@ class CanadianBuildingFootprints(VectorDataset):
     # https://github.com/microsoft/CanadianBuildingFootprints/issues/11
     url = 'https://usbuildingdata.blob.core.windows.net/canadian-buildings-v2/'
-    provinces_territories = [
+    provinces_territories = (
         'Alberta',
         'BritishColumbia',
         'Manitoba',
@@ -44,8 +44,8 @@ class CanadianBuildingFootprints(VectorDataset):
         'Quebec',
         'Saskatchewan',
         'YukonTerritory',
-    ]
-    md5s = [
+    )
+    md5s = (
         '8b4190424e57bb0902bd8ecb95a9235b',
         'fea05d6eb0006710729c675de63db839',
         'adf11187362624d68f9c69aaa693c46f',
@@ -59,7 +59,7 @@ class CanadianBuildingFootprints(VectorDataset):
         '9ff4417ae00354d39a0cf193c8df592c',
         'a51078d8e60082c7d3a3818240da6dd5',
         'c11f3bd914ecabd7cac2cb2871ec0261',
-    ]
+    )
     def __init__(
         self,

View File

@@ -6,7 +6,7 @@
 import os
 import pathlib
 from collections.abc import Callable, Iterable
-from typing import Any
+from typing import Any, ClassVar
 import matplotlib.pyplot as plt
 import torch
@@ -38,7 +38,7 @@ class CDL(RasterDataset):
     If you use this dataset in your research, please cite it using the following format:
     * https://www.nass.usda.gov/Research_and_Science/Cropland/sarsfaqs2.php#Section1_14.0
-    """  # noqa: E501
+    """
     filename_glob = '*_30m_cdls.tif'
     filename_regex = r"""
@@ -49,8 +49,8 @@ class CDL(RasterDataset):
     date_format = '%Y'
     is_image = False
-    url = 'https://www.nass.usda.gov/Research_and_Science/Cropland/Release/datasets/{}_30m_cdls.zip'  # noqa: E501
-    md5s = {
+    url = 'https://www.nass.usda.gov/Research_and_Science/Cropland/Release/datasets/{}_30m_cdls.zip'
+    md5s: ClassVar[dict[int, str]] = {
        2023: '8c7685d6278d50c554f934b16a6076b7',
        2022: '754cf50670cdfee511937554785de3e6',
        2021: '27606eab08fe975aa138baad3e5dfcd8',
@@ -69,7 +69,7 @@
        2008: '0610f2f17ab60a9fbb3baeb7543993a4',
     }
-    cmap = {
+    cmap: ClassVar[dict[int, tuple[int, int, int, int]]] = {
         0: (0, 0, 0, 255),
         1: (255, 211, 0, 255),
         2: (255, 37, 37, 255),

View File

@@ -4,7 +4,8 @@
 """ChaBuD dataset."""
 import os
-from collections.abc import Callable
+from collections.abc import Callable, Sequence
+from typing import ClassVar
 import matplotlib.pyplot as plt
 import numpy as np
@@ -53,7 +54,7 @@ class ChaBuD(NonGeoDataset):
     .. versionadded:: 0.6
     """
-    all_bands = [
+    all_bands = (
         'B01',
         'B02',
         'B03',
@@ -66,10 +67,10 @@
         'B09',
         'B11',
         'B12',
-    ]
-    rgb_bands = ['B04', 'B03', 'B02']
-    folds = {'train': [1, 2, 3, 4], 'val': [0]}
-    url = 'https://hf.co/datasets/chabud-team/chabud-ecml-pkdd2023/resolve/de222d434e26379aa3d4f3dd1b2caf502427a8b2/train_eval.hdf5'  # noqa: E501
+    )
+    rgb_bands = ('B04', 'B03', 'B02')
+    folds: ClassVar[dict[str, list[int]]] = {'train': [1, 2, 3, 4], 'val': [0]}
+    url = 'https://hf.co/datasets/chabud-team/chabud-ecml-pkdd2023/resolve/de222d434e26379aa3d4f3dd1b2caf502427a8b2/train_eval.hdf5'
     filename = 'train_eval.hdf5'
     md5 = '15d78fb825f9a81dad600db828d22c08'
@@ -77,7 +78,7 @@
         self,
         root: Path = 'data',
         split: str = 'train',
-        bands: list[str] = all_bands,
+        bands: Sequence[str] = all_bands,
         transforms: Callable[[dict[str, Tensor]], dict[str, Tensor]] | None = None,
         download: bool = False,
         checksum: bool = False,
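With all_bands now a tuple, the parameter is widened to Sequence[str] so tuples and caller-supplied lists are both accepted, and the immutable default sidesteps the shared-mutable-default pitfall. A sketch with a hypothetical loader:

    from collections.abc import Sequence

    ALL_BANDS = ('B01', 'B02', 'B03', 'B04')


    def load_bands(bands: Sequence[str] = ALL_BANDS) -> list[str]:
        # a tuple default is safe to share across calls; lists and
        # tuples both satisfy Sequence[str]
        return list(bands)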

View File

@@ -9,7 +9,7 @@ import pathlib
 import sys
 from abc import ABC, abstractmethod
 from collections.abc import Callable, Iterable, Sequence
-from typing import Any, cast
+from typing import Any, ClassVar, cast
 import fiona
 import matplotlib.pyplot as plt
@@ -39,7 +39,7 @@ class Chesapeake(RasterDataset, ABC):
     The Chesapeake Bay Land Use and Land Cover Database (LULC) facilitates
     characterization of the landscape and land change for and between discrete time
-    periods. The database was developed by the University of Vermont’s Spatial Analysis
+    periods. The database was developed by the University of Vermont's Spatial Analysis
     Laboratory in cooperation with Chesapeake Conservancy (CC) and U.S. Geological
     Survey (USGS) as part of a 6-year Cooperative Agreement between Chesapeake
     Conservancy and the U.S. Environmental Protection Agency (EPA) and a separate
@@ -83,7 +83,7 @@ class Chesapeake(RasterDataset, ABC):
         """State abbreviation."""
         return self.__class__.__name__[-2:].lower()
-    cmap = {
+    cmap: ClassVar[dict[int, tuple[int, int, int, int]]] = {
         11: (0, 92, 230, 255),
         12: (0, 92, 230, 255),
         13: (0, 92, 230, 255),
@@ -255,7 +255,7 @@ class Chesapeake(RasterDataset, ABC):
 class ChesapeakeDC(Chesapeake):
     """This subset of the dataset contains data only for Washington, D.C."""
-    md5s = {
+    md5s: ClassVar[dict[int, str]] = {
         2013: '9f1df21afbb9d5c0fcf33af7f6750a7f',
         2017: 'c45e4af2950e1c93ecd47b61af296d9b',
     }
@@ -264,7 +264,7 @@ class ChesapeakeDC(Chesapeake):
 class ChesapeakeDE(Chesapeake):
     """This subset of the dataset contains data only for Delaware."""
-    md5s = {
+    md5s: ClassVar[dict[int, str]] = {
         2013: '5850d96d897babba85610658aeb5951a',
         2018: 'ee94c8efeae423d898677104117bdebc',
     }
@@ -273,7 +273,7 @@ class ChesapeakeDE(Chesapeake):
 class ChesapeakeMD(Chesapeake):
     """This subset of the dataset contains data only for Maryland."""
-    md5s = {
+    md5s: ClassVar[dict[int, str]] = {
         2013: '9c3ca5040668d15284c1bd64b7d6c7a0',
         2018: '0647530edf8bec6e60f82760dcc7db9c',
     }
@@ -282,7 +282,7 @@ class ChesapeakeMD(Chesapeake):
 class ChesapeakeNY(Chesapeake):
     """This subset of the dataset contains data only for New York."""
-    md5s = {
+    md5s: ClassVar[dict[int, str]] = {
         2013: '38a29b721610ba661a7f8b6ec71a48b7',
         2017: '4c1b1a50fd9368cd7b8b12c4d80c63f3',
     }
@@ -291,7 +291,7 @@ class ChesapeakeNY(Chesapeake):
 class ChesapeakePA(Chesapeake):
     """This subset of the dataset contains data only for Pennsylvania."""
-    md5s = {
+    md5s: ClassVar[dict[int, str]] = {
         2013: '86febd603a120a49ef7d23ef486152a3',
         2017: 'b11d92e4471e8cb887c790d488a338c1',
     }
@@ -300,7 +300,7 @@ class ChesapeakePA(Chesapeake):
 class ChesapeakeVA(Chesapeake):
     """This subset of the dataset contains data only for Virginia."""
-    md5s = {
+    md5s: ClassVar[dict[int, str]] = {
         2014: '49c9700c71854eebd00de24d8488eb7c',
         2018: '51731c8b5632978bfd1df869ea10db5b',
     }
@@ -309,7 +309,7 @@ class ChesapeakeVA(Chesapeake):
 class ChesapeakeWV(Chesapeake):
     """This subset of the dataset contains data only for West Virginia."""
-    md5s = {
+    md5s: ClassVar[dict[int, str]] = {
         2014: '32fea42fae147bd58a83e3ea6cccfb94',
         2018: '80f25dcba72e39685ab33215c5d97292',
     }
@@ -337,16 +337,16 @@ class ChesapeakeCVPR(GeoDataset):
     * https://doi.org/10.1109/cvpr.2019.01301
     """
-    subdatasets = ['base', 'prior_extension']
-    urls = {
-        'base': 'https://lilablobssc.blob.core.windows.net/lcmcvpr2019/cvpr_chesapeake_landcover.zip',  # noqa: E501
-        'prior_extension': 'https://zenodo.org/record/5866525/files/cvpr_chesapeake_landcover_prior_extension.zip?download=1',  # noqa: E501
+    subdatasets = ('base', 'prior_extension')
+    urls: ClassVar[dict[str, str]] = {
+        'base': 'https://lilablobssc.blob.core.windows.net/lcmcvpr2019/cvpr_chesapeake_landcover.zip',
+        'prior_extension': 'https://zenodo.org/record/5866525/files/cvpr_chesapeake_landcover_prior_extension.zip?download=1',
     }
-    filenames = {
+    filenames: ClassVar[dict[str, str]] = {
         'base': 'cvpr_chesapeake_landcover.zip',
         'prior_extension': 'cvpr_chesapeake_landcover_prior_extension.zip',
     }
-    md5s = {
+    md5s: ClassVar[dict[str, str]] = {
         'base': '1225ccbb9590e9396875f221e5031514',
         'prior_extension': '402f41d07823c8faf7ea6960d7c4e17a',
     }
@@ -354,7 +354,7 @@ class ChesapeakeCVPR(GeoDataset):
     crs = CRS.from_epsg(3857)
     res = 1
-    lc_cmap = {
+    lc_cmap: ClassVar[dict[int, tuple[int, int, int, int]]] = {
         0: (0, 0, 0, 0),
         1: (0, 197, 255, 255),
         2: (38, 115, 0, 255),
@@ -374,7 +374,7 @@ class ChesapeakeCVPR(GeoDataset):
         ]
     )
-    valid_layers = [
+    valid_layers = (
         'naip-new',
         'naip-old',
         'landsat-leaf-on',
@@ -383,8 +383,8 @@
         'lc',
         'buildings',
         'prior_from_cooccurrences_101_31_no_osm_no_buildings',
-    ]
-    states = ['de', 'md', 'va', 'wv', 'pa', 'ny']
+    )
+    states = ('de', 'md', 'va', 'wv', 'pa', 'ny')
     splits = (
         [f'{state}-train' for state in states]
         + [f'{state}-val' for state in states]
@@ -392,7 +392,7 @@ class ChesapeakeCVPR(GeoDataset):
     )
     # these are used to check the integrity of the dataset
-    _files = [
+    _files = (
         'de_1m_2013_extended-debuffered-test_tiles',
         'de_1m_2013_extended-debuffered-train_tiles',
         'de_1m_2013_extended-debuffered-val_tiles',
@@ -412,18 +412,18 @@
         'wv_1m_2014_extended-debuffered-train_tiles',
         'wv_1m_2014_extended-debuffered-val_tiles',
         'wv_1m_2014_extended-debuffered-val_tiles/m_3708035_ne_17_1_buildings.tif',
-        'wv_1m_2014_extended-debuffered-val_tiles/m_3708035_ne_17_1_landsat-leaf-off.tif',  # noqa: E501
-        'wv_1m_2014_extended-debuffered-val_tiles/m_3708035_ne_17_1_landsat-leaf-on.tif',  # noqa: E501
+        'wv_1m_2014_extended-debuffered-val_tiles/m_3708035_ne_17_1_landsat-leaf-off.tif',
+        'wv_1m_2014_extended-debuffered-val_tiles/m_3708035_ne_17_1_landsat-leaf-on.tif',
         'wv_1m_2014_extended-debuffered-val_tiles/m_3708035_ne_17_1_lc.tif',
         'wv_1m_2014_extended-debuffered-val_tiles/m_3708035_ne_17_1_naip-new.tif',
        'wv_1m_2014_extended-debuffered-val_tiles/m_3708035_ne_17_1_naip-old.tif',
        'wv_1m_2014_extended-debuffered-val_tiles/m_3708035_ne_17_1_nlcd.tif',
-        'wv_1m_2014_extended-debuffered-val_tiles/m_3708035_ne_17_1_prior_from_cooccurrences_101_31_no_osm_no_buildings.tif',  # noqa: E501
+        'wv_1m_2014_extended-debuffered-val_tiles/m_3708035_ne_17_1_prior_from_cooccurrences_101_31_no_osm_no_buildings.tif',
         'spatial_index.geojson',
-    ]
+    )
     p_src_crs = pyproj.CRS('epsg:3857')
-    p_transformers = {
+    p_transformers: ClassVar[dict[str, CRS]] = {
         'epsg:26917': pyproj.Transformer.from_crs(
             p_src_crs, pyproj.CRS('epsg:26917'), always_xy=True
         ).transform,
@@ -511,7 +511,7 @@ class ChesapeakeCVPR(GeoDataset):
                     'lc': row['properties']['lc'],
                     'nlcd': row['properties']['nlcd'],
                     'buildings': row['properties']['buildings'],
-                    'prior_from_cooccurrences_101_31_no_osm_no_buildings': prior_fn,  # noqa: E501
+                    'prior_from_cooccurrences_101_31_no_osm_no_buildings': prior_fn,
                 },
             )

View File

@@ -5,6 +5,7 @@
 import os
 from collections.abc import Callable, Sequence
+from typing import ClassVar
 import matplotlib.pyplot as plt
 import numpy as np
@@ -55,9 +56,9 @@ class CloudCoverDetection(NonGeoDataset):
     """
     url = 'https://radiantearth.blob.core.windows.net/mlhub/ref_cloud_cover_detection_challenge_v1/final'
-    all_bands = ['B02', 'B03', 'B04', 'B08']
-    rgb_bands = ['B04', 'B03', 'B02']
-    splits = {'train': 'public', 'test': 'private'}
+    all_bands = ('B02', 'B03', 'B04', 'B08')
+    rgb_bands = ('B04', 'B03', 'B02')
+    splits: ClassVar[dict[str, str]] = {'train': 'public', 'test': 'private'}
     def __init__(
         self,

View File

@@ -42,7 +42,7 @@ class CMSGlobalMangroveCanopy(RasterDataset):
     zipfile = 'CMS_Global_Map_Mangrove_Canopy_1665.zip'
     md5 = '3e7f9f23bf971c25e828b36e6c5496e3'
-    all_countries = [
+    all_countries = (
         'AndamanAndNicobar',
         'Angola',
         'Anguilla',
@@ -164,9 +164,9 @@ class CMSGlobalMangroveCanopy(RasterDataset):
         'VirginIslandsUs',
         'WallisAndFutuna',
         'Yemen',
-    ]
-    measurements = ['agb', 'hba95', 'hmax95']
+    )
+    measurements = ('agb', 'hba95', 'hmax95')
     def __init__(
         self,

View File

@@ -50,12 +50,12 @@ class COWC(NonGeoDataset, abc.ABC):
     @property
     @abc.abstractmethod
-    def filenames(self) -> list[str]:
+    def filenames(self) -> tuple[str, ...]:
         """List of files to download."""
     @property
     @abc.abstractmethod
-    def md5s(self) -> list[str]:
+    def md5s(self) -> tuple[str, ...]:
         """List of MD5 checksums of files to download."""
     @property
@@ -239,7 +239,7 @@ class COWCCounting(COWC):
     base_url = (
         'https://gdo152.llnl.gov/cowc/download/cowc/datasets/patch_sets/counting/'
     )
-    filenames = [
+    filenames = (
         'COWC_train_list_64_class.txt.bz2',
         'COWC_test_list_64_class.txt.bz2',
         'COWC_Counting_Toronto_ISPRS.tbz',
@@ -248,8 +248,8 @@ class COWCCounting(COWC):
         'COWC_Counting_Vaihingen_ISPRS.tbz',
         'COWC_Counting_Columbus_CSUAV_AFRL.tbz',
         'COWC_Counting_Utah_AGRC.tbz',
-    ]
-    md5s = [
+    )
+    md5s = (
         '187543d20fa6d591b8da51136e8ef8fb',
         '930cfd6e160a7b36db03146282178807',
         'bc2613196dfa93e66d324ae43e7c1fdb',
@@ -258,7 +258,7 @@ class COWCCounting(COWC):
         '4009c1e420566390746f5b4db02afdb9',
         'daf8033c4e8ceebbf2c3cac3fabb8b10',
         '777ec107ed2a3d54597a739ce74f95ad',
-    ]
+    )
     filename = 'COWC_{}_list_64_class.txt'
@@ -268,7 +268,7 @@ class COWCDetection(COWC):
     base_url = (
         'https://gdo152.llnl.gov/cowc/download/cowc/datasets/patch_sets/detection/'
     )
-    filenames = [
+    filenames = (
         'COWC_train_list_detection.txt.bz2',
         'COWC_test_list_detection.txt.bz2',
         'COWC_Detection_Toronto_ISPRS.tbz',
@@ -277,8 +277,8 @@ class COWCDetection(COWC):
         'COWC_Detection_Vaihingen_ISPRS.tbz',
         'COWC_Detection_Columbus_CSUAV_AFRL.tbz',
         'COWC_Detection_Utah_AGRC.tbz',
-    ]
-    md5s = [
+    )
+    md5s = (
         'c954a5a3dac08c220b10cfbeec83893c',
         'c6c2d0a78f12a2ad88b286b724a57c1a',
         '11af24f43b198b0f13c8e94814008a48',
@@ -287,7 +287,7 @@ class COWCDetection(COWC):
         '23945d5b22455450a938382ccc2a8b27',
         'f40522dc97bea41b10117d4a5b946a6f',
         '195da7c9443a939a468c9f232fd86ee3',
-    ]
+    )
     filename = 'COWC_{}_list_detection.txt'

View File

@@ -7,6 +7,7 @@ import glob
 import json
 import os
 from collections.abc import Callable
+from typing import ClassVar
 import matplotlib.pyplot as plt
 import numpy as np
@@ -55,7 +56,7 @@ class CropHarvest(NonGeoDataset):
     """
     # https://github.com/nasaharvest/cropharvest/blob/main/cropharvest/bands.py
-    all_bands = [
+    all_bands = (
         'VV',
         'VH',
         'B2',
@@ -74,12 +75,12 @@ class CropHarvest(NonGeoDataset):
         'elevation',
         'slope',
         'NDVI',
-    ]
-    rgb_bands = ['B4', 'B3', 'B2']
+    )
+    rgb_bands = ('B4', 'B3', 'B2')
     features_url = 'https://zenodo.org/records/7257688/files/features.tar.gz?download=1'
     labels_url = 'https://zenodo.org/records/7257688/files/labels.geojson?download=1'
-    file_dict = {
+    file_dict: ClassVar[dict[str, dict[str, str]]] = {
         'features': {
             'url': features_url,
             'filename': 'features.tar.gz',

View File

@@ -65,8 +65,8 @@ class CV4AKenyaCropType(NonGeoDataset):
     """
     url = 'https://radiantearth.blob.core.windows.net/mlhub/kenya-crop-challenge'
-    tiles = list(map(str, range(4)))
-    dates = [
+    tiles = tuple(map(str, range(4)))
+    dates = (
         '20190606',
         '20190701',
         '20190706',
@@ -80,7 +80,7 @@ class CV4AKenyaCropType(NonGeoDataset):
         '20190924',
         '20191004',
         '20191103',
-    ]
+    )
     all_bands = (
         'B01',
         'B02',
@@ -96,7 +96,7 @@ class CV4AKenyaCropType(NonGeoDataset):
         'B12',
         'CLD',
     )
-    rgb_bands = ['B04', 'B03', 'B02']
+    rgb_bands = ('B04', 'B03', 'B02')
     # Same for all tiles
     tile_height = 3035
@@ -141,11 +141,13 @@ class CV4AKenyaCropType(NonGeoDataset):
         # Calculate the indices that we will use over all tiles
         self.chips_metadata = []
         for tile_index in range(len(self.tiles)):
-            for y in list(range(0, self.tile_height - self.chip_size, stride)) + [
-                self.tile_height - self.chip_size
+            for y in [
+                *list(range(0, self.tile_height - self.chip_size, stride)),
+                self.tile_height - self.chip_size,
             ]:
-                for x in list(range(0, self.tile_width - self.chip_size, stride)) + [
-                    self.tile_width - self.chip_size
+                for x in [
+                    *list(range(0, self.tile_width - self.chip_size, stride)),
+                    self.tile_width - self.chip_size,
                 ]:
                     self.chips_metadata.append((tile_index, y, x))

View File

@@ -74,13 +74,13 @@ class DeepGlobeLandCover(NonGeoDataset):
         $ unzip deepglobe2018-landcover-segmentation-traindataset.zip
     .. versionadded:: 0.3
-    """  # noqa: E501
+    """
     filename = 'data.zip'
     data_root = 'data'
     md5 = 'f32684b0b2bf6f8d604cd359a399c061'
-    splits = ['train', 'test']
-    classes = [
+    splits = ('train', 'test')
+    classes = (
         'Urban land',
         'Agriculture land',
         'Rangeland',
@@ -88,8 +88,8 @@ class DeepGlobeLandCover(NonGeoDataset):
         'Water',
         'Barren land',
         'Unknown',
-    ]
-    colormap = [
+    )
+    colormap = (
         (0, 255, 255),
         (255, 255, 0),
         (255, 0, 255),
@@ -97,7 +97,7 @@ class DeepGlobeLandCover(NonGeoDataset):
         (0, 0, 255),
         (255, 255, 255),
         (0, 0, 0),
-    ]
+    )
     def __init__(
         self,
@@ -246,12 +246,15 @@ class DeepGlobeLandCover(NonGeoDataset):
         """
         ncols = 1
         image1 = draw_semantic_segmentation_masks(
-            sample['image'], sample['mask'], alpha=alpha, colors=self.colormap
+            sample['image'], sample['mask'], alpha=alpha, colors=list(self.colormap)
        )
         if 'prediction' in sample:
             ncols += 1
             image2 = draw_semantic_segmentation_masks(
-                sample['image'], sample['prediction'], alpha=alpha, colors=self.colormap
+                sample['image'],
+                sample['prediction'],
+                alpha=alpha,
+                colors=list(self.colormap),
             )
         fig, axs = plt.subplots(ncols=ncols, figsize=(ncols * 10, 10))
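Because colormap is now a tuple while the drawing utility expects a list of colors (the "String colormap must be list" commit), the conversion happens per call rather than by keeping the attribute mutable. The same pattern in miniature, with a hypothetical stand-in for draw_semantic_segmentation_masks:

    COLORMAP = ((0, 255, 255), (255, 255, 0), (255, 0, 255))


    def draw(colors: list[tuple[int, int, int]]) -> int:
        # stands in for an API that requires a list, not a tuple
        return len(colors)


    draw(list(COLORMAP))  # convert the immutable class constant at the call site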

View File

@@ -6,6 +6,7 @@
 import glob
 import os
 from collections.abc import Callable, Sequence
+from typing import ClassVar
 import matplotlib.pyplot as plt
 import numpy as np
@@ -75,9 +76,9 @@ class DFC2022(NonGeoDataset):
     * https://doi.org/10.1007/s10994-020-05943-y
     .. versionadded:: 0.3
-    """  # noqa: E501
-    classes = [
+    """
+    classes = (
         'No information',
         'Urban fabric',
         'Industrial, commercial, public, military, private and transport units',
@@ -94,8 +95,8 @@ class DFC2022(NonGeoDataset):
         'Wetlands',
         'Water',
         'Clouds and Shadows',
-    ]
-    colormap = [
+    )
+    colormap = (
         '#231F20',
         '#DB5F57',
         '#DB9757',
@@ -112,8 +113,8 @@ class DFC2022(NonGeoDataset):
         '#579BDB',
         '#0062FF',
         '#231F20',
-    ]
-    metadata = {
+    )
+    metadata: ClassVar[dict[str, dict[str, str]]] = {
         'train': {
             'filename': 'labeled_train.zip',
             'md5': '2e87d6a218e466dd0566797d7298c7a9',

View File

@@ -6,7 +6,7 @@
 import os
 import sys
 from collections.abc import Callable, Sequence
-from typing import Any, cast
+from typing import Any, ClassVar, cast
 import fiona
 import matplotlib.pyplot as plt
@@ -54,9 +54,9 @@ class EnviroAtlas(GeoDataset):
     crs = CRS.from_epsg(3857)
     res = 1
-    valid_prior_layers = ['prior', 'prior_no_osm_no_buildings']
-    valid_layers = [
+    valid_prior_layers = ('prior', 'prior_no_osm_no_buildings')
+    valid_layers = (
         'naip',
         'nlcd',
         'roads',
@@ -65,14 +65,15 @@ class EnviroAtlas(GeoDataset):
         'waterbodies',
         'buildings',
         'lc',
-    ] + valid_prior_layers
-    cities = [
+        *valid_prior_layers,
+    )
+    cities = (
         'pittsburgh_pa-2010_1m',
         'durham_nc-2012_1m',
         'austin_tx-2012_1m',
         'phoenix_az-2010_1m',
-    ]
+    )
     splits = (
         [f'{state}-train' for state in cities[:1]]
         + [f'{state}-val' for state in cities[:1]]
@@ -81,7 +82,7 @@ class EnviroAtlas(GeoDataset):
     )
     # these are used to check the integrity of the dataset
-    _files = [
+    _files = (
         'austin_tx-2012_1m-test_tiles-debuffered',
         'austin_tx-2012_1m-val5_tiles-debuffered',
         'durham_nc-2012_1m-test_tiles-debuffered',
@@ -100,13 +101,13 @@ class EnviroAtlas(GeoDataset):
         'austin_tx-2012_1m-test_tiles-debuffered/3009726_sw_d_water.tif',
         'austin_tx-2012_1m-test_tiles-debuffered/3009726_sw_e_buildings.tif',
         'austin_tx-2012_1m-test_tiles-debuffered/3009726_sw_h_highres_labels.tif',
-        'austin_tx-2012_1m-test_tiles-debuffered/3009726_sw_prior_from_cooccurrences_101_31.tif',  # noqa: E501
-        'austin_tx-2012_1m-test_tiles-debuffered/3009726_sw_prior_from_cooccurrences_101_31_no_osm_no_buildings.tif',  # noqa: E501
+        'austin_tx-2012_1m-test_tiles-debuffered/3009726_sw_prior_from_cooccurrences_101_31.tif',
+        'austin_tx-2012_1m-test_tiles-debuffered/3009726_sw_prior_from_cooccurrences_101_31_no_osm_no_buildings.tif',
         'spatial_index.geojson',
-    ]
+    )
     p_src_crs = pyproj.CRS('epsg:3857')
-    p_transformers = {
+    p_transformers: ClassVar[dict[str, CRS]] = {
         'epsg:26917': pyproj.Transformer.from_crs(
             p_src_crs, pyproj.CRS('epsg:26917'), always_xy=True
         ).transform,
@@ -222,7 +223,7 @@ class EnviroAtlas(GeoDataset):
         dtype=np.uint8,
     )
-    highres_classes = [
+    highres_classes = (
         'Unclassified',
         'Water',
         'Impervious Surface',
@@ -234,7 +235,7 @@ class EnviroAtlas(GeoDataset):
         'Orchards',
         'Woody Wetlands',
         'Emergent Wetlands',
-    ]
+    )
     highres_cmap = ListedColormap(
         [
             [1.00000000, 1.00000000, 1.00000000],
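valid_layers above shows the same unpacking idiom inside a class body: names assigned earlier in the body, here valid_prior_layers, are visible while the body executes, so the combined tuple is built without list concatenation. A reduced sketch:

    class Layers:
        valid_prior_layers = ('prior', 'prior_no_osm_no_buildings')
        # earlier class attributes are in scope during class-body execution
        valid_layers = ('naip', 'nlcd', 'lc', *valid_prior_layers)


    assert Layers.valid_layers[-2:] == Layers.valid_prior_layers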

View File

@@ -6,6 +6,7 @@
 import glob
 import os
 from collections.abc import Callable
+from typing import ClassVar
 import matplotlib.pyplot as plt
 import numpy as np
@@ -56,9 +57,9 @@ class ETCI2021(NonGeoDataset):
     the ETCI competition.
     """
-    bands = ['VV', 'VH']
-    masks = ['flood', 'water_body']
-    metadata = {
+    bands = ('VV', 'VH')
+    masks = ('flood', 'water_body')
+    metadata: ClassVar[dict[str, dict[str, str]]] = {
         'train': {
             'filename': 'train.zip',
             'md5': '1e95792fe0f6e3c9000abdeab2a8ab0f',

View File

@@ -7,7 +7,7 @@ import glob
 import os
 import pathlib
 from collections.abc import Callable, Iterable
-from typing import Any
+from typing import Any, ClassVar
 import matplotlib.pyplot as plt
 from matplotlib.figure import Figure
@@ -53,7 +53,7 @@ class EUDEM(RasterDataset):
     zipfile_glob = 'eu_dem_v11_*[A-Z0-9].zip'
     filename_regex = '(?P<name>[eudem_v11]{10})_(?P<id>[A-Z0-9]{6})'
-    md5s = {
+    md5s: ClassVar[dict[str, str]] = {
         'eu_dem_v11_E00N20.zip': '96edc7e11bc299b994e848050d6be591',
         'eu_dem_v11_E10N00.zip': 'e14be147ac83eddf655f4833d55c1571',
         'eu_dem_v11_E10N10.zip': '2eb5187e4d827245b33768404529c709',

View File

@@ -61,7 +61,7 @@ class EuroCrops(VectorDataset):
     date_format = '%Y'
     # Filename and md5 of files in this dataset on zenodo.
-    zenodo_files = [
+    zenodo_files: tuple[tuple[str, str], ...] = (
         ('AT_2021.zip', '490241df2e3d62812e572049fc0c36c5'),
         ('BE_VLG_2021.zip', 'ac4b9e12ad39b1cba47fdff1a786c2d7'),
         ('DE_LS_2021.zip', '6d94e663a3ff7988b32cb36ea24a724f'),
@@ -81,7 +81,7 @@ class EuroCrops(VectorDataset):
     # Year is unknown for Romania portion (ny = no year).
     # We skip since it is inconsistent with the rest of the data.
     # ("RO_ny.zip", "648e1504097765b4b7f825decc838882"),
-    ]
+    )
     def __init__(
         self,

Просмотреть файл

@ -5,7 +5,7 @@
import os import os
from collections.abc import Callable, Sequence from collections.abc import Callable, Sequence
from typing import cast from typing import ClassVar, cast
import matplotlib.pyplot as plt import matplotlib.pyplot as plt
import numpy as np import numpy as np
@ -54,7 +54,7 @@ class EuroSAT(NonGeoClassificationDataset):
* https://ieeexplore.ieee.org/document/8519248 * https://ieeexplore.ieee.org/document/8519248
""" """
url = 'https://hf.co/datasets/torchgeo/eurosat/resolve/06fd1b090bceecc0ce724cd21578ba7a6664fe8d/EuroSATallBands.zip' # noqa: E501 url = 'https://hf.co/datasets/torchgeo/eurosat/resolve/06fd1b090bceecc0ce724cd21578ba7a6664fe8d/EuroSATallBands.zip'
filename = 'EuroSATallBands.zip' filename = 'EuroSATallBands.zip'
md5 = '5ac12b3b2557aa56e1826e981e8e200e' md5 = '5ac12b3b2557aa56e1826e981e8e200e'
@ -63,13 +63,13 @@ class EuroSAT(NonGeoClassificationDataset):
'ds', 'images', 'remote_sensing', 'otherDatasets', 'sentinel_2', 'tif' 'ds', 'images', 'remote_sensing', 'otherDatasets', 'sentinel_2', 'tif'
) )
splits = ['train', 'val', 'test'] splits = ('train', 'val', 'test')
split_urls = { split_urls: ClassVar[dict[str, str]] = {
'train': 'https://storage.googleapis.com/remote_sensing_representations/eurosat-train.txt', # noqa: E501 'train': 'https://storage.googleapis.com/remote_sensing_representations/eurosat-train.txt',
'val': 'https://storage.googleapis.com/remote_sensing_representations/eurosat-val.txt', # noqa: E501 'val': 'https://storage.googleapis.com/remote_sensing_representations/eurosat-val.txt',
'test': 'https://storage.googleapis.com/remote_sensing_representations/eurosat-test.txt', # noqa: E501 'test': 'https://storage.googleapis.com/remote_sensing_representations/eurosat-test.txt',
} }
split_md5s = { split_md5s: ClassVar[dict[str, str]] = {
'train': '908f142e73d6acdf3f482c5e80d851b1', 'train': '908f142e73d6acdf3f482c5e80d851b1',
'val': '95de90f2aa998f70a3b2416bfe0687b4', 'val': '95de90f2aa998f70a3b2416bfe0687b4',
'test': '7ae5ab94471417b6e315763121e67c5f', 'test': '7ae5ab94471417b6e315763121e67c5f',
@ -93,7 +93,10 @@ class EuroSAT(NonGeoClassificationDataset):
rgb_bands = ('B04', 'B03', 'B02') rgb_bands = ('B04', 'B03', 'B02')
BAND_SETS = {'all': all_band_names, 'rgb': rgb_bands} BAND_SETS: ClassVar[dict[str, tuple[str, ...]]] = {
'all': all_band_names,
'rgb': rgb_bands,
}
def __init__( def __init__(
self, self,
@ -302,12 +305,12 @@ class EuroSATSpatial(EuroSAT):
.. versionadded:: 0.6 .. versionadded:: 0.6
""" """
split_urls = { split_urls: ClassVar[dict[str, str]] = {
'train': 'https://hf.co/datasets/torchgeo/eurosat/resolve/1c11c73a87b40b0485d103231a97829991b8e22f/eurosat-spatial-train.txt', 'train': 'https://hf.co/datasets/torchgeo/eurosat/resolve/1c11c73a87b40b0485d103231a97829991b8e22f/eurosat-spatial-train.txt',
'val': 'https://hf.co/datasets/torchgeo/eurosat/resolve/1c11c73a87b40b0485d103231a97829991b8e22f/eurosat-spatial-val.txt', 'val': 'https://hf.co/datasets/torchgeo/eurosat/resolve/1c11c73a87b40b0485d103231a97829991b8e22f/eurosat-spatial-val.txt',
'test': 'https://hf.co/datasets/torchgeo/eurosat/resolve/1c11c73a87b40b0485d103231a97829991b8e22f/eurosat-spatial-test.txt', 'test': 'https://hf.co/datasets/torchgeo/eurosat/resolve/1c11c73a87b40b0485d103231a97829991b8e22f/eurosat-spatial-test.txt',
} }
split_md5s = { split_md5s: ClassVar[dict[str, str]] = {
'train': '7be3254be39f23ce4d4d144290c93292', 'train': '7be3254be39f23ce4d4d144290c93292',
'val': 'acf392290050bb3df790dc8fc0ebf193', 'val': 'acf392290050bb3df790dc8fc0ebf193',
'test': '5ec1733f9c16116bf0aa2d921fc613ef', 'test': '5ec1733f9c16116bf0aa2d921fc613ef',
@ -325,16 +328,16 @@ class EuroSAT100(EuroSAT):
.. versionadded:: 0.5 .. versionadded:: 0.5
""" """
url = 'https://hf.co/datasets/torchgeo/eurosat/resolve/06fd1b090bceecc0ce724cd21578ba7a6664fe8d/EuroSAT100.zip' # noqa: E501 url = 'https://hf.co/datasets/torchgeo/eurosat/resolve/06fd1b090bceecc0ce724cd21578ba7a6664fe8d/EuroSAT100.zip'
filename = 'EuroSAT100.zip' filename = 'EuroSAT100.zip'
md5 = 'c21c649ba747e86eda813407ef17d596' md5 = 'c21c649ba747e86eda813407ef17d596'
split_urls = { split_urls: ClassVar[dict[str, str]] = {
'train': 'https://hf.co/datasets/torchgeo/eurosat/resolve/06fd1b090bceecc0ce724cd21578ba7a6664fe8d/eurosat-train.txt', # noqa: E501 'train': 'https://hf.co/datasets/torchgeo/eurosat/resolve/06fd1b090bceecc0ce724cd21578ba7a6664fe8d/eurosat-train.txt',
'val': 'https://hf.co/datasets/torchgeo/eurosat/resolve/06fd1b090bceecc0ce724cd21578ba7a6664fe8d/eurosat-val.txt', # noqa: E501 'val': 'https://hf.co/datasets/torchgeo/eurosat/resolve/06fd1b090bceecc0ce724cd21578ba7a6664fe8d/eurosat-val.txt',
'test': 'https://hf.co/datasets/torchgeo/eurosat/resolve/06fd1b090bceecc0ce724cd21578ba7a6664fe8d/eurosat-test.txt', # noqa: E501 'test': 'https://hf.co/datasets/torchgeo/eurosat/resolve/06fd1b090bceecc0ce724cd21578ba7a6664fe8d/eurosat-test.txt',
} }
split_md5s = { split_md5s: ClassVar[dict[str, str]] = {
'train': '033d0c23e3a75e3fa79618b0e35fe1c7', 'train': '033d0c23e3a75e3fa79618b0e35fe1c7',
'val': '3e3f8b3c344182b8d126c4cc88f3f215', 'val': '3e3f8b3c344182b8d126c4cc88f3f215',
'test': 'f908f151b950f270ad18e61153579794', 'test': 'f908f151b950f270ad18e61153579794',
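
The dominant pattern in this commit is visible above: fixed-length sequences become tuples, while attributes that must stay mutable (dicts) gain a typing.ClassVar annotation, which is what Ruff's RUF012 rule asks for. A minimal sketch of the idea, using an illustrative class name rather than a real torchgeo dataset:

from typing import ClassVar


class ExampleDataset:
    # A tuple cannot be mutated, so no annotation is needed to make this
    # shared class state safe.
    splits = ('train', 'val', 'test')

    # Dicts stay mutable, so ClassVar documents that this is class-level
    # state and lets type checkers flag accidental instance assignment.
    split_urls: ClassVar[dict[str, str]] = {
        'train': 'https://example.com/train.txt'  # placeholder URL
    }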

View file

@ -6,7 +6,7 @@
import glob import glob
import os import os
from collections.abc import Callable from collections.abc import Callable
from typing import Any, cast from typing import Any, ClassVar, cast
from xml.etree.ElementTree import Element, parse from xml.etree.ElementTree import Element, parse
import matplotlib.patches as patches import matplotlib.patches as patches
@ -119,7 +119,7 @@ class FAIR1M(NonGeoDataset):
.. versionadded:: 0.2 .. versionadded:: 0.2
""" """
classes = { classes: ClassVar[dict[str, dict[str, Any]]] = {
'Passenger Ship': {'id': 0, 'category': 'Ship'}, 'Passenger Ship': {'id': 0, 'category': 'Ship'},
'Motorboat': {'id': 1, 'category': 'Ship'}, 'Motorboat': {'id': 1, 'category': 'Ship'},
'Fishing Boat': {'id': 2, 'category': 'Ship'}, 'Fishing Boat': {'id': 2, 'category': 'Ship'},
@ -159,12 +159,12 @@ class FAIR1M(NonGeoDataset):
'Bridge': {'id': 36, 'category': 'Road'}, 'Bridge': {'id': 36, 'category': 'Road'},
} }
filename_glob = { filename_glob: ClassVar[dict[str, str]] = {
'train': os.path.join('train', '**', 'images', '*.tif'), 'train': os.path.join('train', '**', 'images', '*.tif'),
'val': os.path.join('validation', 'images', '*.tif'), 'val': os.path.join('validation', 'images', '*.tif'),
'test': os.path.join('test', 'images', '*.tif'), 'test': os.path.join('test', 'images', '*.tif'),
} }
directories = { directories: ClassVar[dict[str, tuple[str, ...]]] = {
'train': ( 'train': (
os.path.join('train', 'part1', 'images'), os.path.join('train', 'part1', 'images'),
os.path.join('train', 'part1', 'labelXml'), os.path.join('train', 'part1', 'labelXml'),
@ -175,9 +175,9 @@ class FAIR1M(NonGeoDataset):
os.path.join('validation', 'images'), os.path.join('validation', 'images'),
os.path.join('validation', 'labelXml'), os.path.join('validation', 'labelXml'),
), ),
'test': (os.path.join('test', 'images')), 'test': (os.path.join('test', 'images'),),
} }
paths = { paths: ClassVar[dict[str, tuple[str, ...]]] = {
'train': ( 'train': (
os.path.join('train', 'part1', 'images.zip'), os.path.join('train', 'part1', 'images.zip'),
os.path.join('train', 'part1', 'labelXml.zip'), os.path.join('train', 'part1', 'labelXml.zip'),
@ -194,7 +194,7 @@ class FAIR1M(NonGeoDataset):
os.path.join('test', 'images2.zip'), os.path.join('test', 'images2.zip'),
), ),
} }
urls = { urls: ClassVar[dict[str, tuple[str, ...]]] = {
'train': ( 'train': (
'https://drive.google.com/file/d/1LWT_ybL-s88Lzg9A9wHpj0h2rJHrqrVf', 'https://drive.google.com/file/d/1LWT_ybL-s88Lzg9A9wHpj0h2rJHrqrVf',
'https://drive.google.com/file/d/1CnOuS8oX6T9JMqQnfFsbmf7U38G6Vc8u', 'https://drive.google.com/file/d/1CnOuS8oX6T9JMqQnfFsbmf7U38G6Vc8u',
@ -211,7 +211,7 @@ class FAIR1M(NonGeoDataset):
'https://drive.google.com/file/d/1oUc25FVf8Zcp4pzJ31A1j1sOLNHu63P0', 'https://drive.google.com/file/d/1oUc25FVf8Zcp4pzJ31A1j1sOLNHu63P0',
), ),
} }
md5s = { md5s: ClassVar[dict[str, tuple[str, ...]]] = {
'train': ( 'train': (
'a460fe6b1b5b276bf856ce9ac72d6568', 'a460fe6b1b5b276bf856ce9ac72d6568',
'80f833ff355f91445c92a0c0c1fa7414', '80f833ff355f91445c92a0c0c1fa7414',

View file

@ -55,8 +55,8 @@ class FireRisk(NonGeoClassificationDataset):
md5 = 'a77b9a100d51167992ae8c51d26198a6' md5 = 'a77b9a100d51167992ae8c51d26198a6'
filename = 'FireRisk.zip' filename = 'FireRisk.zip'
directory = 'FireRisk' directory = 'FireRisk'
splits = ['train', 'val'] splits = ('train', 'val')
classes = [ classes = (
'High', 'High',
'Low', 'Low',
'Moderate', 'Moderate',
@ -64,7 +64,7 @@ class FireRisk(NonGeoClassificationDataset):
'Very_High', 'Very_High',
'Very_Low', 'Very_Low',
'Water', 'Water',
] )
def __init__( def __init__(
self, self,

View file

@ -96,11 +96,8 @@ class ForestDamage(NonGeoDataset):
.. versionadded:: 0.3 .. versionadded:: 0.3
""" """
classes = ['other', 'H', 'LD', 'HD'] classes = ('other', 'H', 'LD', 'HD')
url = ( url = 'https://lilablobssc.blob.core.windows.net/larch-casebearer/Data_Set_Larch_Casebearer.zip'
'https://lilablobssc.blob.core.windows.net/larch-casebearer/'
'Data_Set_Larch_Casebearer.zip'
)
data_dir = 'Data_Set_Larch_Casebearer' data_dir = 'Data_Set_Larch_Casebearer'
md5 = '907815bcc739bff89496fac8f8ce63d7' md5 = '907815bcc739bff89496fac8f8ce63d7'

View file

@ -13,7 +13,7 @@ import re
import sys import sys
import warnings import warnings
from collections.abc import Callable, Iterable, Sequence from collections.abc import Callable, Iterable, Sequence
from typing import Any, cast from typing import Any, ClassVar, cast
import fiona import fiona
import fiona.transform import fiona.transform
@ -370,13 +370,13 @@ class RasterDataset(GeoDataset):
separate_files = False separate_files = False
#: Names of all available bands in the dataset #: Names of all available bands in the dataset
all_bands: list[str] = [] all_bands: tuple[str, ...] = ()
#: Names of RGB bands in the dataset, used for plotting #: Names of RGB bands in the dataset, used for plotting
rgb_bands: list[str] = [] rgb_bands: tuple[str, ...] = ()
#: Color map for the dataset, used for plotting #: Color map for the dataset, used for plotting
cmap: dict[int, tuple[int, int, int, int]] = {} cmap: ClassVar[dict[int, tuple[int, int, int, int]]] = {}
@property @property
def dtype(self) -> torch.dtype: def dtype(self) -> torch.dtype:
@ -458,7 +458,7 @@ class RasterDataset(GeoDataset):
# See if file has a color map # See if file has a color map
if len(self.cmap) == 0: if len(self.cmap) == 0:
try: try:
self.cmap = src.colormap(1) self.cmap = src.colormap(1) # type: ignore[misc]
except ValueError: except ValueError:
pass pass
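
Declaring cmap as a ClassVar has one knock-on effect seen above: mypy reports "Cannot assign to class variable via instance" (error code misc) when instance code writes the file's colormap through self, so the assignment keeps a narrowly targeted ignore. A minimal sketch with an illustrative class:

from typing import ClassVar


class Example:
    cmap: ClassVar[dict[int, tuple[int, int, int, int]]] = {}

    def read_colormap(self) -> None:
        # mypy would reject this instance assignment to a ClassVar with
        # error code [misc]; the narrow ignore keeps the runtime behavior.
        self.cmap = {1: (255, 0, 0, 255)}  # type: ignore[misc]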

View file

@ -66,8 +66,8 @@ class GID15(NonGeoDataset):
md5 = '615682bf659c3ed981826c6122c10c83' md5 = '615682bf659c3ed981826c6122c10c83'
filename = 'gid-15.zip' filename = 'gid-15.zip'
directory = 'GID' directory = 'GID'
splits = ['train', 'val', 'test'] splits = ('train', 'val', 'test')
classes = [ classes = (
'background', 'background',
'industrial_land', 'industrial_land',
'urban_residential', 'urban_residential',
@ -84,7 +84,7 @@ class GID15(NonGeoDataset):
'river', 'river',
'lake', 'lake',
'pond', 'pond',
] )
def __init__( def __init__(
self, self,

View file

@ -7,7 +7,7 @@ import glob
import os import os
import pathlib import pathlib
from collections.abc import Callable, Iterable from collections.abc import Callable, Iterable
from typing import Any, cast from typing import Any, ClassVar, cast
import matplotlib.pyplot as plt import matplotlib.pyplot as plt
import torch import torch
@ -73,9 +73,9 @@ class GlobBiomass(RasterDataset):
is_image = False is_image = False
dtype = torch.float32 # pixelwise regression dtype = torch.float32 # pixelwise regression
measurements = ['agb', 'gsv'] measurements = ('agb', 'gsv')
md5s = { md5s: ClassVar[dict[str, str]] = {
'N00E020_agb.zip': 'bd83a3a4c143885d1962bde549413be6', 'N00E020_agb.zip': 'bd83a3a4c143885d1962bde549413be6',
'N00E020_gsv.zip': 'da5ddb88e369df2d781a0c6be008ae79', 'N00E020_gsv.zip': 'da5ddb88e369df2d781a0c6be008ae79',
'N00E060_agb.zip': '85eaca95b939086cc528e396b75bd097', 'N00E060_agb.zip': '85eaca95b939086cc528e396b75bd097',

View file

@ -6,7 +6,7 @@
import glob import glob
import os import os
from collections.abc import Callable from collections.abc import Callable
from typing import Any, cast, overload from typing import Any, ClassVar, cast, overload
import fiona import fiona
import matplotlib.pyplot as plt import matplotlib.pyplot as plt
@ -100,7 +100,7 @@ class IDTReeS(NonGeoDataset):
.. versionadded:: 0.2 .. versionadded:: 0.2
""" """
classes = { classes: ClassVar[dict[str, str]] = {
'ACPE': 'Acer pensylvanicum L.', 'ACPE': 'Acer pensylvanicum L.',
'ACRU': 'Acer rubrum L.', 'ACRU': 'Acer rubrum L.',
'ACSA3': 'Acer saccharum Marshall', 'ACSA3': 'Acer saccharum Marshall',
@ -135,19 +135,22 @@ class IDTReeS(NonGeoDataset):
'ROPS': 'Robinia pseudoacacia L.', 'ROPS': 'Robinia pseudoacacia L.',
'TSCA': 'Tsuga canadensis (L.) Carriere', 'TSCA': 'Tsuga canadensis (L.) Carriere',
} }
metadata = { metadata: ClassVar[dict[str, dict[str, str]]] = {
'train': { 'train': {
'url': 'https://zenodo.org/record/3934932/files/IDTREES_competition_train_v2.zip?download=1', # noqa: E501 'url': 'https://zenodo.org/record/3934932/files/IDTREES_competition_train_v2.zip?download=1',
'md5': '5ddfa76240b4bb6b4a7861d1d31c299c', 'md5': '5ddfa76240b4bb6b4a7861d1d31c299c',
'filename': 'IDTREES_competition_train_v2.zip', 'filename': 'IDTREES_competition_train_v2.zip',
}, },
'test': { 'test': {
'url': 'https://zenodo.org/record/3934932/files/IDTREES_competition_test_v2.zip?download=1', # noqa: E501 'url': 'https://zenodo.org/record/3934932/files/IDTREES_competition_test_v2.zip?download=1',
'md5': 'b108931c84a70f2a38a8234290131c9b', 'md5': 'b108931c84a70f2a38a8234290131c9b',
'filename': 'IDTREES_competition_test_v2.zip', 'filename': 'IDTREES_competition_test_v2.zip',
}, },
} }
directories = {'train': ['train'], 'test': ['task1', 'task2']} directories: ClassVar[dict[str, list[str]]] = {
'train': ['train'],
'test': ['task1', 'task2'],
}
image_size = (200, 200) image_size = (200, 200)
def __init__( def __init__(

View file

@ -6,7 +6,7 @@
import glob import glob
import os import os
from collections.abc import Callable, Sequence from collections.abc import Callable, Sequence
from typing import Any from typing import Any, ClassVar
import matplotlib.pyplot as plt import matplotlib.pyplot as plt
from matplotlib.figure import Figure from matplotlib.figure import Figure
@ -40,9 +40,9 @@ class IOBench(IntersectionDataset):
.. versionadded:: 0.6 .. versionadded:: 0.6
""" """
url = 'https://hf.co/datasets/torchgeo/io/resolve/c9d9d268cf0b61335941bdc2b6963bf16fc3a6cf/{}.tar.gz' # noqa: E501 url = 'https://hf.co/datasets/torchgeo/io/resolve/c9d9d268cf0b61335941bdc2b6963bf16fc3a6cf/{}.tar.gz'
md5s = { md5s: ClassVar[dict[str, str]] = {
'original': 'e3a908a0fd1c05c1af2f4c65724d59b3', 'original': 'e3a908a0fd1c05c1af2f4c65724d59b3',
'raw': 'e9603990441007ce7bba73bb8ba7d217', 'raw': 'e9603990441007ce7bba73bb8ba7d217',
'preprocessed': '9801f1240b238cb17525c865e413d1fd', 'preprocessed': '9801f1240b238cb17525c865e413d1fd',
@ -54,7 +54,7 @@ class IOBench(IntersectionDataset):
split: str = 'preprocessed', split: str = 'preprocessed',
crs: CRS | None = None, crs: CRS | None = None,
res: float | None = None, res: float | None = None,
bands: Sequence[str] | None = Landsat9.default_bands + ['SR_QA_AEROSOL'], bands: Sequence[str] | None = [*Landsat9.default_bands, 'SR_QA_AEROSOL'],
classes: list[int] = [0], classes: list[int] = [0],
transforms: Callable[[dict[str, Any]], dict[str, Any]] | None = None, transforms: Callable[[dict[str, Any]], dict[str, Any]] | None = None,
cache: bool = True, cache: bool = True,
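
The changed default above also dodges a type error: now that Landsat9.default_bands is a tuple, concatenating it with a list via + would raise TypeError, while star unpacking (the RUF005 spelling) accepts any mix of sequence types. A sketch with a stand-in tuple:

default_bands = ('SR_B1', 'SR_B2')  # stand-in for Landsat9.default_bands

# default_bands + ['SR_QA_AEROSOL'] would raise:
# TypeError: can only concatenate tuple (not "list") to tuple

bands = [*default_bands, 'SR_QA_AEROSOL']
assert bands == ['SR_B1', 'SR_B2', 'SR_QA_AEROSOL']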

View file

@ -8,7 +8,7 @@ import os
import pathlib import pathlib
import re import re
from collections.abc import Callable, Iterable, Sequence from collections.abc import Callable, Iterable, Sequence
from typing import Any, cast from typing import Any, ClassVar, cast
import matplotlib.pyplot as plt import matplotlib.pyplot as plt
import torch import torch
@ -43,8 +43,8 @@ class L7IrishImage(RasterDataset):
""" """
date_format = '%Y%m%d' date_format = '%Y%m%d'
is_image = True is_image = True
rgb_bands = ['B30', 'B20', 'B10'] rgb_bands = ('B30', 'B20', 'B10')
all_bands = ['B10', 'B20', 'B30', 'B40', 'B50', 'B61', 'B62', 'B70', 'B80'] all_bands = ('B10', 'B20', 'B30', 'B40', 'B50', 'B61', 'B62', 'B70', 'B80')
class L7IrishMask(RasterDataset): class L7IrishMask(RasterDataset):
@ -59,7 +59,7 @@ class L7IrishMask(RasterDataset):
_newmask2015\.TIF$ _newmask2015\.TIF$
""" """
is_image = False is_image = False
classes = ['Fill', 'Cloud Shadow', 'Clear', 'Thin Cloud', 'Cloud'] classes = ('Fill', 'Cloud Shadow', 'Clear', 'Thin Cloud', 'Cloud')
ordinal_map = torch.zeros(256, dtype=torch.long) ordinal_map = torch.zeros(256, dtype=torch.long)
ordinal_map[64] = 1 ordinal_map[64] = 1
ordinal_map[128] = 2 ordinal_map[128] = 2
@ -158,11 +158,11 @@ class L7Irish(IntersectionDataset):
* https://www.sciencebase.gov/catalog/item/573ccf18e4b0dae0d5e4b109 * https://www.sciencebase.gov/catalog/item/573ccf18e4b0dae0d5e4b109
.. versionadded:: 0.5 .. versionadded:: 0.5
""" # noqa: E501 """
url = 'https://hf.co/datasets/torchgeo/l7irish/resolve/6807e0b22eca7f9a8a3903ea673b31a115837464/{}.tar.gz' # noqa: E501 url = 'https://hf.co/datasets/torchgeo/l7irish/resolve/6807e0b22eca7f9a8a3903ea673b31a115837464/{}.tar.gz'
md5s = { md5s: ClassVar[dict[str, str]] = {
'austral': '0a34770b992a62abeb88819feb192436', 'austral': '0a34770b992a62abeb88819feb192436',
'boreal': 'b7cfdd689a3c2fd2a8d572e1c10ed082', 'boreal': 'b7cfdd689a3c2fd2a8d572e1c10ed082',
'mid_latitude_north': 'c40abe5ad2487f8ab021cfb954982faa', 'mid_latitude_north': 'c40abe5ad2487f8ab021cfb954982faa',

View file

@ -7,7 +7,7 @@ import glob
import os import os
import pathlib import pathlib
from collections.abc import Callable, Iterable, Sequence from collections.abc import Callable, Iterable, Sequence
from typing import Any from typing import Any, ClassVar
import matplotlib.pyplot as plt import matplotlib.pyplot as plt
import torch import torch
@ -36,8 +36,8 @@ class L8BiomeImage(RasterDataset):
""" """
date_format = '%Y%j' date_format = '%Y%j'
is_image = True is_image = True
rgb_bands = ['B4', 'B3', 'B2'] rgb_bands = ('B4', 'B3', 'B2')
all_bands = ['B1', 'B2', 'B3', 'B4', 'B5', 'B6', 'B7', 'B8', 'B9', 'B10', 'B11'] all_bands = ('B1', 'B2', 'B3', 'B4', 'B5', 'B6', 'B7', 'B8', 'B9', 'B10', 'B11')
class L8BiomeMask(RasterDataset): class L8BiomeMask(RasterDataset):
@ -57,7 +57,7 @@ class L8BiomeMask(RasterDataset):
""" """
date_format = '%Y%j' date_format = '%Y%j'
is_image = False is_image = False
classes = ['Fill', 'Cloud Shadow', 'Clear', 'Thin Cloud', 'Cloud'] classes = ('Fill', 'Cloud Shadow', 'Clear', 'Thin Cloud', 'Cloud')
ordinal_map = torch.zeros(256, dtype=torch.long) ordinal_map = torch.zeros(256, dtype=torch.long)
ordinal_map[64] = 1 ordinal_map[64] = 1
ordinal_map[128] = 2 ordinal_map[128] = 2
@ -116,11 +116,11 @@ class L8Biome(IntersectionDataset):
* https://doi.org/10.1016/j.rse.2017.03.026 * https://doi.org/10.1016/j.rse.2017.03.026
.. versionadded:: 0.5 .. versionadded:: 0.5
""" # noqa: E501 """
url = 'https://hf.co/datasets/torchgeo/l8biome/resolve/f76df19accce34d2acc1878d88b9491bc81f94c8/{}.tar.gz' # noqa: E501 url = 'https://hf.co/datasets/torchgeo/l8biome/resolve/f76df19accce34d2acc1878d88b9491bc81f94c8/{}.tar.gz'
md5s = { md5s: ClassVar[dict[str, str]] = {
'barren': '0eb691822d03dabd4f5ea8aadd0b41c3', 'barren': '0eb691822d03dabd4f5ea8aadd0b41c3',
'forest': '4a5645596f6bb8cea44677f746ec676e', 'forest': '4a5645596f6bb8cea44677f746ec676e',
'grass_crops': 'a69ed5d6cb227c5783f026b9303cdd3c', 'grass_crops': 'a69ed5d6cb227c5783f026b9303cdd3c',

View file

@ -9,7 +9,7 @@ import hashlib
import os import os
from collections.abc import Callable from collections.abc import Callable
from functools import lru_cache from functools import lru_cache
from typing import Any, cast from typing import Any, ClassVar, cast
import matplotlib.pyplot as plt import matplotlib.pyplot as plt
import numpy as np import numpy as np
@ -64,8 +64,8 @@ class LandCoverAIBase(Dataset[dict[str, Any]], abc.ABC):
url = 'https://landcover.ai.linuxpolska.com/download/landcover.ai.v1.zip' url = 'https://landcover.ai.linuxpolska.com/download/landcover.ai.v1.zip'
filename = 'landcover.ai.v1.zip' filename = 'landcover.ai.v1.zip'
md5 = '3268c89070e8734b4e91d531c0617e03' md5 = '3268c89070e8734b4e91d531c0617e03'
classes = ['Background', 'Building', 'Woodland', 'Water', 'Road'] classes = ('Background', 'Building', 'Woodland', 'Water', 'Road')
cmap = { cmap: ClassVar[dict[int, tuple[int, int, int, int]]] = {
0: (0, 0, 0, 0), 0: (0, 0, 0, 0),
1: (97, 74, 74, 255), 1: (97, 74, 74, 255),
2: (38, 115, 0, 255), 2: (38, 115, 0, 255),

View file

@ -33,7 +33,7 @@ class Landsat(RasterDataset, abc.ABC):
* `Surface Temperature <https://www.usgs.gov/landsat-missions/landsat-collection-2-surface-temperature>`_ * `Surface Temperature <https://www.usgs.gov/landsat-missions/landsat-collection-2-surface-temperature>`_
* `Surface Reflectance <https://www.usgs.gov/landsat-missions/landsat-collection-2-surface-reflectance>`_ * `Surface Reflectance <https://www.usgs.gov/landsat-missions/landsat-collection-2-surface-reflectance>`_
* `U.S. Analysis Ready Data <https://www.usgs.gov/landsat-missions/landsat-collection-2-us-analysis-ready-data>`_ * `U.S. Analysis Ready Data <https://www.usgs.gov/landsat-missions/landsat-collection-2-us-analysis-ready-data>`_
""" # noqa: E501 """
# https://www.usgs.gov/landsat-missions/landsat-collection-2 # https://www.usgs.gov/landsat-missions/landsat-collection-2
filename_regex = r""" filename_regex = r"""
@ -55,7 +55,7 @@ class Landsat(RasterDataset, abc.ABC):
@property @property
@abc.abstractmethod @abc.abstractmethod
def default_bands(self) -> list[str]: def default_bands(self) -> tuple[str, ...]:
"""Bands to load by default.""" """Bands to load by default."""
def __init__( def __init__(
@ -145,8 +145,8 @@ class Landsat1(Landsat):
filename_glob = 'LM01_*_{}.*' filename_glob = 'LM01_*_{}.*'
default_bands = ['B4', 'B5', 'B6', 'B7'] default_bands = ('B4', 'B5', 'B6', 'B7')
rgb_bands = ['B6', 'B5', 'B4'] rgb_bands = ('B6', 'B5', 'B4')
class Landsat2(Landsat1): class Landsat2(Landsat1):
@ -166,8 +166,8 @@ class Landsat4MSS(Landsat):
filename_glob = 'LM04_*_{}.*' filename_glob = 'LM04_*_{}.*'
default_bands = ['B1', 'B2', 'B3', 'B4'] default_bands = ('B1', 'B2', 'B3', 'B4')
rgb_bands = ['B3', 'B2', 'B1'] rgb_bands = ('B3', 'B2', 'B1')
class Landsat4TM(Landsat): class Landsat4TM(Landsat):
@ -175,8 +175,8 @@ class Landsat4TM(Landsat):
filename_glob = 'LT04_*_{}.*' filename_glob = 'LT04_*_{}.*'
default_bands = ['SR_B1', 'SR_B2', 'SR_B3', 'SR_B4', 'SR_B5', 'SR_B6', 'SR_B7'] default_bands = ('SR_B1', 'SR_B2', 'SR_B3', 'SR_B4', 'SR_B5', 'SR_B6', 'SR_B7')
rgb_bands = ['SR_B3', 'SR_B2', 'SR_B1'] rgb_bands = ('SR_B3', 'SR_B2', 'SR_B1')
class Landsat5MSS(Landsat4MSS): class Landsat5MSS(Landsat4MSS):
@ -196,8 +196,8 @@ class Landsat7(Landsat):
filename_glob = 'LE07_*_{}.*' filename_glob = 'LE07_*_{}.*'
default_bands = ['SR_B1', 'SR_B2', 'SR_B3', 'SR_B4', 'SR_B5', 'SR_B6', 'SR_B7'] default_bands = ('SR_B1', 'SR_B2', 'SR_B3', 'SR_B4', 'SR_B5', 'SR_B6', 'SR_B7')
rgb_bands = ['SR_B3', 'SR_B2', 'SR_B1'] rgb_bands = ('SR_B3', 'SR_B2', 'SR_B1')
class Landsat8(Landsat): class Landsat8(Landsat):
@ -205,11 +205,11 @@ class Landsat8(Landsat):
filename_glob = 'LC08_*_{}.*' filename_glob = 'LC08_*_{}.*'
default_bands = ['SR_B1', 'SR_B2', 'SR_B3', 'SR_B4', 'SR_B5', 'SR_B6', 'SR_B7'] default_bands = ('SR_B1', 'SR_B2', 'SR_B3', 'SR_B4', 'SR_B5', 'SR_B6', 'SR_B7')
rgb_bands = ['SR_B4', 'SR_B3', 'SR_B2'] rgb_bands = ('SR_B4', 'SR_B3', 'SR_B2')
class Landsat9(Landsat8): class Landsat9(Landsat8):
"""Landsat 9 Operational Land Imager (OLI-2) and Thermal Infrared Sensor (TIRS-2).""" # noqa: E501 """Landsat 9 Operational Land Imager (OLI-2) and Thermal Infrared Sensor (TIRS-2)."""
filename_glob = 'LC09_*_{}.*' filename_glob = 'LC09_*_{}.*'

View file

@ -7,6 +7,7 @@ import abc
import glob import glob
import os import os
from collections.abc import Callable from collections.abc import Callable
from typing import ClassVar
import matplotlib.pyplot as plt import matplotlib.pyplot as plt
import numpy as np import numpy as np
@ -26,8 +27,8 @@ class LEVIRCDBase(NonGeoDataset, abc.ABC):
.. versionadded:: 0.6 .. versionadded:: 0.6
""" """
splits: list[str] | dict[str, dict[str, str]] splits: ClassVar[tuple[str, ...] | dict[str, dict[str, str]]]
directories = ['A', 'B', 'label'] directories = ('A', 'B', 'label')
def __init__( def __init__(
self, self,
@ -237,7 +238,7 @@ class LEVIRCD(LEVIRCDBase):
.. versionadded:: 0.6 .. versionadded:: 0.6
""" """
splits = { splits: ClassVar[dict[str, dict[str, str]]] = {
'train': { 'train': {
'url': 'https://drive.google.com/file/d/18GuoCuBn48oZKAlEo-LrNwABrFhVALU-', 'url': 'https://drive.google.com/file/d/18GuoCuBn48oZKAlEo-LrNwABrFhVALU-',
'filename': 'train.zip', 'filename': 'train.zip',
@ -336,7 +337,7 @@ class LEVIRCDPlus(LEVIRCDBase):
md5 = '1adf156f628aa32fb2e8fe6cada16c04' md5 = '1adf156f628aa32fb2e8fe6cada16c04'
filename = 'LEVIR-CD+.zip' filename = 'LEVIR-CD+.zip'
directory = 'LEVIR-CD+' directory = 'LEVIR-CD+'
splits = ['train', 'test'] splits = ('train', 'test')
def _load_files(self, root: Path, split: str) -> list[dict[str, str]]: def _load_files(self, root: Path, split: str) -> list[dict[str, str]]:
"""Return the paths of the files in the dataset. """Return the paths of the files in the dataset.

View file

@ -5,7 +5,8 @@
import glob import glob
import os import os
from collections.abc import Callable from collections.abc import Callable, Sequence
from typing import ClassVar
import matplotlib.pyplot as plt import matplotlib.pyplot as plt
import numpy as np import numpy as np
@ -57,10 +58,10 @@ class LoveDA(NonGeoDataset):
.. versionadded:: 0.2 .. versionadded:: 0.2
""" """
scenes = ['urban', 'rural'] scenes = ('urban', 'rural')
splits = ['train', 'val', 'test'] splits = ('train', 'val', 'test')
info_dict = { info_dict: ClassVar[dict[str, dict[str, str]]] = {
'train': { 'train': {
'url': 'https://zenodo.org/record/5706578/files/Train.zip?download=1', 'url': 'https://zenodo.org/record/5706578/files/Train.zip?download=1',
'filename': 'Train.zip', 'filename': 'Train.zip',
@ -78,7 +79,7 @@ class LoveDA(NonGeoDataset):
}, },
} }
classes = [ classes = (
'background', 'background',
'building', 'building',
'road', 'road',
@ -87,13 +88,13 @@ class LoveDA(NonGeoDataset):
'forest', 'forest',
'agriculture', 'agriculture',
'no-data', 'no-data',
] )
def __init__( def __init__(
self, self,
root: Path = 'data', root: Path = 'data',
split: str = 'train', split: str = 'train',
scene: list[str] = ['urban', 'rural'], scene: Sequence[str] = ['urban', 'rural'],
transforms: Callable[[dict[str, Tensor]], dict[str, Tensor]] | None = None, transforms: Callable[[dict[str, Tensor]], dict[str, Tensor]] | None = None,
download: bool = False, download: bool = False,
checksum: bool = False, checksum: bool = False,

View file

@ -7,6 +7,7 @@ import os
import shutil import shutil
from collections import defaultdict from collections import defaultdict
from collections.abc import Callable from collections.abc import Callable
from typing import ClassVar
import matplotlib.pyplot as plt import matplotlib.pyplot as plt
import numpy as np import numpy as np
@ -36,7 +37,7 @@ class MapInWild(NonGeoDataset):
different RS sensors over 1018 locations: dual-pol Sentinel-1, four-season different RS sensors over 1018 locations: dual-pol Sentinel-1, four-season
Sentinel-2 with 10 bands, ESA WorldCover map, and Visible Infrared Imaging Sentinel-2 with 10 bands, ESA WorldCover map, and Visible Infrared Imaging
Radiometer Suite NightTime Day/Night band. The dataset consists of 8144 Radiometer Suite NightTime Day/Night band. The dataset consists of 8144
images with the shape of 1920 × 1920 pixels. The images are weakly annotated images with the shape of 1920 x 1920 pixels. The images are weakly annotated
from the World Database of Protected Areas (WDPA). from the World Database of Protected Areas (WDPA).
Dataset features: Dataset features:
@ -54,9 +55,9 @@ class MapInWild(NonGeoDataset):
.. versionadded:: 0.5 .. versionadded:: 0.5
""" """
url = 'https://hf.co/datasets/burakekim/mapinwild/resolve/d963778e31e7e0ed2329c0f4cbe493be532f0e71/' # noqa: E501 url = 'https://hf.co/datasets/burakekim/mapinwild/resolve/d963778e31e7e0ed2329c0f4cbe493be532f0e71/'
modality_urls = { modality_urls: ClassVar[dict[str, set[str]]] = {
'esa_wc': {'esa_wc/ESA_WC.zip'}, 'esa_wc': {'esa_wc/ESA_WC.zip'},
'viirs': {'viirs/VIIRS.zip'}, 'viirs': {'viirs/VIIRS.zip'},
'mask': {'mask/mask.zip'}, 'mask': {'mask/mask.zip'},
@ -72,7 +73,7 @@ class MapInWild(NonGeoDataset):
'split_IDs': {'split_IDs/split_IDs.csv'}, 'split_IDs': {'split_IDs/split_IDs.csv'},
} }
md5s = { md5s: ClassVar[dict[str, str]] = {
'ESA_WC.zip': '72b2ee578fe10f0df85bdb7f19311c92', 'ESA_WC.zip': '72b2ee578fe10f0df85bdb7f19311c92',
'VIIRS.zip': '4eff014bae127fe536f8a5f17d89ecb4', 'VIIRS.zip': '4eff014bae127fe536f8a5f17d89ecb4',
'mask.zip': '87c83a23a73998ad60d448d240b66225', 'mask.zip': '87c83a23a73998ad60d448d240b66225',
@ -91,9 +92,12 @@ class MapInWild(NonGeoDataset):
'split_IDs.csv': 'cb5c6c073702acee23544e1e6fe5856f', 'split_IDs.csv': 'cb5c6c073702acee23544e1e6fe5856f',
} }
mask_cmap = {1: (0, 153, 0), 0: (255, 255, 255)} mask_cmap: ClassVar[dict[int, tuple[int, int, int]]] = {
1: (0, 153, 0),
0: (255, 255, 255),
}
wc_cmap = { wc_cmap: ClassVar[dict[int, tuple[int, int, int]]] = {
10: (0, 160, 0), 10: (0, 160, 0),
20: (150, 100, 0), 20: (150, 100, 0),
30: (255, 180, 0), 30: (255, 180, 0),

View file

@ -6,7 +6,7 @@
import glob import glob
import os import os
from collections.abc import Callable from collections.abc import Callable
from typing import Any, cast from typing import Any, ClassVar, cast
import matplotlib.pyplot as plt import matplotlib.pyplot as plt
import numpy as np import numpy as np
@ -48,7 +48,7 @@ class MillionAID(NonGeoDataset):
.. versionadded:: 0.3 .. versionadded:: 0.3
""" """
multi_label_categories = [ multi_label_categories = (
'agriculture_land', 'agriculture_land',
'airport_area', 'airport_area',
'apartment', 'apartment',
@ -122,9 +122,9 @@ class MillionAID(NonGeoDataset):
'wind_turbine', 'wind_turbine',
'woodland', 'woodland',
'works', 'works',
] )
multi_class_categories = [ multi_class_categories = (
'apartment', 'apartment',
'apron', 'apron',
'bare_land', 'bare_land',
@ -176,17 +176,17 @@ class MillionAID(NonGeoDataset):
'wastewater_plant', 'wastewater_plant',
'wind_turbine', 'wind_turbine',
'works', 'works',
] )
md5s = { md5s: ClassVar[dict[str, str]] = {
'train': '1b40503cafa9b0601653ca36cd788852', 'train': '1b40503cafa9b0601653ca36cd788852',
'test': '51a63ee3eeb1351889eacff349a983d8', 'test': '51a63ee3eeb1351889eacff349a983d8',
} }
filenames = {'train': 'train.zip', 'test': 'test.zip'} filenames: ClassVar[dict[str, str]] = {'train': 'train.zip', 'test': 'test.zip'}
tasks = ['multi-class', 'multi-label'] tasks = ('multi-class', 'multi-label')
splits = ['train', 'test'] splits = ('train', 'test')
def __init__( def __init__(
self, self,

View file

@ -45,8 +45,8 @@ class NAIP(RasterDataset):
""" """
# Plotting # Plotting
all_bands = ['R', 'G', 'B', 'NIR'] all_bands = ('R', 'G', 'B', 'NIR')
rgb_bands = ['R', 'G', 'B'] rgb_bands = ('R', 'G', 'B')
def plot( def plot(
self, self,

View file

@ -4,7 +4,7 @@
"""Northeastern China Crop Map Dataset.""" """Northeastern China Crop Map Dataset."""
from collections.abc import Callable, Iterable from collections.abc import Callable, Iterable
from typing import Any from typing import Any, ClassVar
import matplotlib.pyplot as plt import matplotlib.pyplot as plt
import torch import torch
@ -57,23 +57,23 @@ class NCCM(RasterDataset):
date_format = '%Y' date_format = '%Y'
is_image = False is_image = False
urls = { urls: ClassVar[dict[int, str]] = {
2019: 'https://figshare.com/ndownloader/files/25070540', 2019: 'https://figshare.com/ndownloader/files/25070540',
2018: 'https://figshare.com/ndownloader/files/25070624', 2018: 'https://figshare.com/ndownloader/files/25070624',
2017: 'https://figshare.com/ndownloader/files/25070582', 2017: 'https://figshare.com/ndownloader/files/25070582',
} }
md5s = { md5s: ClassVar[dict[int, str]] = {
2019: '0d062bbd42e483fdc8239d22dba7020f', 2019: '0d062bbd42e483fdc8239d22dba7020f',
2018: 'b3bb4894478d10786aa798fb11693ec1', 2018: 'b3bb4894478d10786aa798fb11693ec1',
2017: 'd047fbe4a85341fa6248fd7e0badab6c', 2017: 'd047fbe4a85341fa6248fd7e0badab6c',
} }
fnames = { fnames: ClassVar[dict[int, str]] = {
2019: 'CDL2019_clip.tif', 2019: 'CDL2019_clip.tif',
2018: 'CDL2018_clip1.tif', 2018: 'CDL2018_clip1.tif',
2017: 'CDL2017_clip.tif', 2017: 'CDL2017_clip.tif',
} }
cmap = { cmap: ClassVar[dict[int, tuple[int, int, int, int]]] = {
0: (0, 255, 0, 255), 0: (0, 255, 0, 255),
1: (255, 0, 0, 255), 1: (255, 0, 0, 255),
2: (255, 255, 0, 255), 2: (255, 255, 0, 255),

View file

@ -7,7 +7,7 @@ import glob
import os import os
import pathlib import pathlib
from collections.abc import Callable, Iterable from collections.abc import Callable, Iterable
from typing import Any from typing import Any, ClassVar
import matplotlib.pyplot as plt import matplotlib.pyplot as plt
import torch import torch
@ -67,7 +67,7 @@ class NLCD(RasterDataset):
* 2019: https://doi.org/10.5066/P9KZCM54 * 2019: https://doi.org/10.5066/P9KZCM54
.. versionadded:: 0.5 .. versionadded:: 0.5
""" # noqa: E501 """
filename_glob = 'nlcd_*_land_cover_l48_*.img' filename_glob = 'nlcd_*_land_cover_l48_*.img'
filename_regex = ( filename_regex = (
@ -79,7 +79,7 @@ class NLCD(RasterDataset):
url = 'https://s3-us-west-2.amazonaws.com/mrlc/nlcd_{}_land_cover_l48_20210604.zip' url = 'https://s3-us-west-2.amazonaws.com/mrlc/nlcd_{}_land_cover_l48_20210604.zip'
md5s = { md5s: ClassVar[dict[int, str]] = {
2001: '538166a4d783204764e3df3b221fc4cd', 2001: '538166a4d783204764e3df3b221fc4cd',
2006: '67454e7874a00294adb9442374d0c309', 2006: '67454e7874a00294adb9442374d0c309',
2011: 'ea524c835d173658eeb6fa3c8e6b917b', 2011: 'ea524c835d173658eeb6fa3c8e6b917b',
@ -87,7 +87,7 @@ class NLCD(RasterDataset):
2019: '82851c3f8105763b01c83b4a9e6f3961', 2019: '82851c3f8105763b01c83b4a9e6f3961',
} }
cmap = { cmap: ClassVar[dict[int, tuple[int, int, int, int]]] = {
0: (0, 0, 0, 0), 0: (0, 0, 0, 0),
11: (70, 107, 159, 255), 11: (70, 107, 159, 255),
12: (209, 222, 248, 255), 12: (209, 222, 248, 255),

View file

@ -9,7 +9,7 @@ import os
import pathlib import pathlib
import sys import sys
from collections.abc import Callable, Iterable from collections.abc import Callable, Iterable
from typing import Any, cast from typing import Any, ClassVar, cast
import fiona import fiona
import fiona.transform import fiona.transform
@ -61,7 +61,7 @@ class OpenBuildings(VectorDataset):
.. versionadded:: 0.3 .. versionadded:: 0.3
""" """
md5s = { md5s: ClassVar[dict[str, str]] = {
'025_buildings.csv.gz': '41db2572bfd08628d01475a2ee1a2f17', '025_buildings.csv.gz': '41db2572bfd08628d01475a2ee1a2f17',
'04f_buildings.csv.gz': '3232c1c6d45c1543260b77e5689fc8b1', '04f_buildings.csv.gz': '3232c1c6d45c1543260b77e5689fc8b1',
'05b_buildings.csv.gz': '4fc57c63bbbf9a21a3902da7adc3a670', '05b_buildings.csv.gz': '4fc57c63bbbf9a21a3902da7adc3a670',

View file

@ -6,6 +6,7 @@
import glob import glob
import os import os
from collections.abc import Callable, Sequence from collections.abc import Callable, Sequence
from typing import ClassVar
import matplotlib.pyplot as plt import matplotlib.pyplot as plt
import numpy as np import numpy as np
@ -50,7 +51,7 @@ class OSCD(NonGeoDataset):
.. versionadded:: 0.2 .. versionadded:: 0.2
""" """
urls = { urls: ClassVar[dict[str, str]] = {
'Onera Satellite Change Detection dataset - Images.zip': ( 'Onera Satellite Change Detection dataset - Images.zip': (
'https://partage.imt.fr/index.php/s/gKRaWgRnLMfwMGo/download' 'https://partage.imt.fr/index.php/s/gKRaWgRnLMfwMGo/download'
), ),
@ -61,7 +62,7 @@ class OSCD(NonGeoDataset):
'https://partage.imt.fr/index.php/s/gpStKn4Mpgfnr63/download' 'https://partage.imt.fr/index.php/s/gpStKn4Mpgfnr63/download'
), ),
} }
md5s = { md5s: ClassVar[dict[str, str]] = {
'Onera Satellite Change Detection dataset - Images.zip': ( 'Onera Satellite Change Detection dataset - Images.zip': (
'c50d4a2941da64e03a47ac4dec63d915' 'c50d4a2941da64e03a47ac4dec63d915'
), ),
@ -75,9 +76,9 @@ class OSCD(NonGeoDataset):
zipfile_glob = '*Onera*.zip' zipfile_glob = '*Onera*.zip'
filename_glob = '*Onera*' filename_glob = '*Onera*'
splits = ['train', 'test'] splits = ('train', 'test')
colormap = ['blue'] colormap = ('blue',)
all_bands = ( all_bands = (
'B01', 'B01',
@ -319,7 +320,7 @@ class OSCD(NonGeoDataset):
torch.from_numpy(rgb_img), torch.from_numpy(rgb_img),
sample['mask'], sample['mask'],
alpha=alpha, alpha=alpha,
colors=self.colormap, colors=list(self.colormap),
) )
return array return array
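
The colormap itself moves to a tuple for immutability, but the commit message notes that a string colormap must be a list, so the call site converts with list(...) where the drawing utility consumes it. A sketch of that split between storage and use, with a hypothetical helper standing in for the real drawing function:

colormap = ('blue',)  # immutable class-level storage

def draw_masks(colors: list[str]) -> None:
    # Hypothetical stand-in for the real drawing utility, which is
    # assumed to require an actual list rather than a tuple.
    assert isinstance(colors, list)

draw_masks(colors=list(colormap))  # convert only at the point of use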

View file

@ -5,6 +5,7 @@
import os import os
from collections.abc import Callable, Sequence from collections.abc import Callable, Sequence
from typing import ClassVar
import fiona import fiona
import matplotlib.pyplot as plt import matplotlib.pyplot as plt
@ -70,7 +71,7 @@ class PASTIS(NonGeoDataset):
.. versionadded:: 0.5 .. versionadded:: 0.5
""" """
classes = [ classes = (
'background', # all non-agricultural land 'background', # all non-agricultural land
'meadow', 'meadow',
'soft_winter_wheat', 'soft_winter_wheat',
@ -91,8 +92,8 @@ class PASTIS(NonGeoDataset):
'mixed_cereal', 'mixed_cereal',
'sorghum', 'sorghum',
'void_label', # for parcels mostly outside their patch 'void_label', # for parcels mostly outside their patch
] )
cmap = { cmap: ClassVar[dict[int, tuple[int, int, int, int]]] = {
0: (0, 0, 0, 255), 0: (0, 0, 0, 255),
1: (174, 199, 232, 255), 1: (174, 199, 232, 255),
2: (255, 127, 14, 255), 2: (255, 127, 14, 255),
@ -118,7 +119,7 @@ class PASTIS(NonGeoDataset):
filename = 'PASTIS-R.zip' filename = 'PASTIS-R.zip'
url = 'https://zenodo.org/record/5735646/files/PASTIS-R.zip?download=1' url = 'https://zenodo.org/record/5735646/files/PASTIS-R.zip?download=1'
md5 = '4887513d6c2d2b07fa935d325bd53e09' md5 = '4887513d6c2d2b07fa935d325bd53e09'
prefix = { prefix: ClassVar[dict[str, str]] = {
's2': os.path.join('DATA_S2', 'S2_'), 's2': os.path.join('DATA_S2', 'S2_'),
's1a': os.path.join('DATA_S1A', 'S1A_'), 's1a': os.path.join('DATA_S1A', 'S1A_'),
's1d': os.path.join('DATA_S1D', 'S1D_'), 's1d': os.path.join('DATA_S1D', 'S1D_'),
@ -232,7 +233,7 @@ class PASTIS(NonGeoDataset):
Returns: Returns:
the target mask the target mask
""" """
# See https://github.com/VSainteuf/pastis-benchmark/blob/main/code/dataloader.py#L201 # noqa: E501 # See https://github.com/VSainteuf/pastis-benchmark/blob/main/code/dataloader.py#L201
# even though the mask file is 3 bands, we just select the first band # even though the mask file is 3 bands, we just select the first band
array = np.load(self.files[index]['semantic'])[0].astype(np.uint8) array = np.load(self.files[index]['semantic'])[0].astype(np.uint8)
tensor = torch.from_numpy(array).long() tensor = torch.from_numpy(array).long()

View file

@ -5,6 +5,7 @@
import os import os
from collections.abc import Callable from collections.abc import Callable
from typing import ClassVar
import matplotlib.pyplot as plt import matplotlib.pyplot as plt
import numpy as np import numpy as np
@ -54,12 +55,12 @@ class Potsdam2D(NonGeoDataset):
* https://doi.org/10.5194/isprsannals-I-3-293-2012 * https://doi.org/10.5194/isprsannals-I-3-293-2012
.. versionadded:: 0.2 .. versionadded:: 0.2
""" # noqa: E501 """
filenames = ['4_Ortho_RGBIR.zip', '5_Labels_all.zip'] filenames = ('4_Ortho_RGBIR.zip', '5_Labels_all.zip')
md5s = ['c4a8f7d8c7196dd4eba4addd0aae10c1', 'cf7403c1a97c0d279414db'] md5s = ('c4a8f7d8c7196dd4eba4addd0aae10c1', 'cf7403c1a97c0d279414db')
image_root = '4_Ortho_RGBIR' image_root = '4_Ortho_RGBIR'
splits = { splits: ClassVar[dict[str, list[str]]] = {
'train': [ 'train': [
'top_potsdam_2_10', 'top_potsdam_2_10',
'top_potsdam_2_11', 'top_potsdam_2_11',
@ -103,22 +104,22 @@ class Potsdam2D(NonGeoDataset):
'top_potsdam_7_13', 'top_potsdam_7_13',
], ],
} }
classes = [ classes = (
'Clutter/background', 'Clutter/background',
'Impervious surfaces', 'Impervious surfaces',
'Building', 'Building',
'Low Vegetation', 'Low Vegetation',
'Tree', 'Tree',
'Car', 'Car',
] )
colormap = [ colormap = (
(255, 0, 0), (255, 0, 0),
(255, 255, 255), (255, 255, 255),
(0, 0, 255), (0, 0, 255),
(0, 255, 255), (0, 255, 255),
(0, 255, 0), (0, 255, 0),
(255, 255, 0), (255, 255, 0),
] )
def __init__( def __init__(
self, self,
@ -257,7 +258,7 @@ class Potsdam2D(NonGeoDataset):
""" """
ncols = 1 ncols = 1
image1 = draw_semantic_segmentation_masks( image1 = draw_semantic_segmentation_masks(
sample['image'][:3], sample['mask'], alpha=alpha, colors=self.colormap sample['image'][:3], sample['mask'], alpha=alpha, colors=list(self.colormap)
) )
if 'prediction' in sample: if 'prediction' in sample:
ncols += 1 ncols += 1
@ -265,7 +266,7 @@ class Potsdam2D(NonGeoDataset):
sample['image'][:3], sample['image'][:3],
sample['prediction'], sample['prediction'],
alpha=alpha, alpha=alpha,
colors=self.colormap, colors=list(self.colormap),
) )
fig, axs = plt.subplots(ncols=ncols, figsize=(ncols * 10, 10)) fig, axs = plt.subplots(ncols=ncols, figsize=(ncols * 10, 10))

View file

@ -5,7 +5,7 @@
import os import os
from collections.abc import Callable from collections.abc import Callable
from typing import Any, cast from typing import Any, ClassVar, cast
import matplotlib.pyplot as plt import matplotlib.pyplot as plt
import numpy as np import numpy as np
@ -61,8 +61,12 @@ class QuakeSet(NonGeoDataset):
filename = 'earthquakes.h5' filename = 'earthquakes.h5'
url = 'https://hf.co/datasets/DarthReca/quakeset/resolve/bead1d25fb9979dbf703f9ede3e8b349f73b29f7/earthquakes.h5' url = 'https://hf.co/datasets/DarthReca/quakeset/resolve/bead1d25fb9979dbf703f9ede3e8b349f73b29f7/earthquakes.h5'
md5 = '76fc7c76b7ca56f4844d852e175e1560' md5 = '76fc7c76b7ca56f4844d852e175e1560'
splits = {'train': 'train', 'val': 'validation', 'test': 'test'} splits: ClassVar[dict[str, str]] = {
classes = ['unaffected_area', 'earthquake_affected_area'] 'train': 'train',
'val': 'validation',
'test': 'test',
}
classes = ('unaffected_area', 'earthquake_affected_area')
def __init__( def __init__(
self, self,

View file

@ -56,7 +56,7 @@ class ReforesTree(NonGeoDataset):
.. versionadded:: 0.3 .. versionadded:: 0.3
""" """
classes = ['other', 'banana', 'cacao', 'citrus', 'fruit', 'timber'] classes = ('other', 'banana', 'cacao', 'citrus', 'fruit', 'timber')
url = 'https://zenodo.org/record/6813783/files/reforesTree.zip?download=1' url = 'https://zenodo.org/record/6813783/files/reforesTree.zip?download=1'
md5 = 'f6a4a1d8207aeaa5fbab7b21b683a302' md5 = 'f6a4a1d8207aeaa5fbab7b21b683a302'

View file

@ -5,7 +5,7 @@
import os import os
from collections.abc import Callable from collections.abc import Callable
from typing import cast from typing import ClassVar, cast
import matplotlib.pyplot as plt import matplotlib.pyplot as plt
import numpy as np import numpy as np
@ -98,13 +98,13 @@ class RESISC45(NonGeoClassificationDataset):
filename = 'NWPU-RESISC45.zip' filename = 'NWPU-RESISC45.zip'
directory = 'NWPU-RESISC45' directory = 'NWPU-RESISC45'
splits = ['train', 'val', 'test'] splits = ('train', 'val', 'test')
split_urls = { split_urls: ClassVar[dict[str, str]] = {
'train': 'https://hf.co/datasets/torchgeo/resisc45/resolve/a826b44d938a883185f11ebe3d512d38b464312f/resisc45-train.txt', 'train': 'https://hf.co/datasets/torchgeo/resisc45/resolve/a826b44d938a883185f11ebe3d512d38b464312f/resisc45-train.txt',
'val': 'https://hf.co/datasets/torchgeo/resisc45/resolve/a826b44d938a883185f11ebe3d512d38b464312f/resisc45-val.txt', 'val': 'https://hf.co/datasets/torchgeo/resisc45/resolve/a826b44d938a883185f11ebe3d512d38b464312f/resisc45-val.txt',
'test': 'https://hf.co/datasets/torchgeo/resisc45/resolve/a826b44d938a883185f11ebe3d512d38b464312f/resisc45-test.txt', 'test': 'https://hf.co/datasets/torchgeo/resisc45/resolve/a826b44d938a883185f11ebe3d512d38b464312f/resisc45-test.txt',
} }
split_md5s = { split_md5s: ClassVar[dict[str, str]] = {
'train': 'b5a4c05a37de15e4ca886696a85c403e', 'train': 'b5a4c05a37de15e4ca886696a85c403e',
'val': 'a0770cee4c5ca20b8c32bbd61e114805', 'val': 'a0770cee4c5ca20b8c32bbd61e114805',
'test': '3dda9e4988b47eb1de9f07993653eb08', 'test': '3dda9e4988b47eb1de9f07993653eb08',

View file

@ -6,6 +6,7 @@
import glob import glob
import os import os
from collections.abc import Callable, Sequence from collections.abc import Callable, Sequence
from typing import ClassVar
import matplotlib.pyplot as plt import matplotlib.pyplot as plt
import numpy as np import numpy as np
@ -55,11 +56,11 @@ class RwandaFieldBoundary(NonGeoDataset):
url = 'https://radiantearth.blob.core.windows.net/mlhub/nasa_rwanda_field_boundary_competition' url = 'https://radiantearth.blob.core.windows.net/mlhub/nasa_rwanda_field_boundary_competition'
splits = {'train': 57, 'test': 13} splits: ClassVar[dict[str, int]] = {'train': 57, 'test': 13}
dates = ('2021_03', '2021_04', '2021_08', '2021_10', '2021_11', '2021_12') dates = ('2021_03', '2021_04', '2021_08', '2021_10', '2021_11', '2021_12')
all_bands = ('B01', 'B02', 'B03', 'B04') all_bands = ('B01', 'B02', 'B03', 'B04')
rgb_bands = ('B03', 'B02', 'B01') rgb_bands = ('B03', 'B02', 'B01')
classes = ['No field-boundary', 'Field-boundary'] classes = ('No field-boundary', 'Field-boundary')
def __init__( def __init__(
self, self,

View file

@ -6,6 +6,7 @@
import os import os
import random import random
from collections.abc import Callable, Collection, Iterable from collections.abc import Callable, Collection, Iterable
from typing import ClassVar
import matplotlib.patches as mpatches import matplotlib.patches as mpatches
import matplotlib.pyplot as plt import matplotlib.pyplot as plt
@ -85,51 +86,51 @@ class SeasoNet(NonGeoDataset):
.. versionadded:: 0.5 .. versionadded:: 0.5
""" """
metadata = [ metadata = (
{ {
'name': 'spring', 'name': 'spring',
'ext': '.zip', 'ext': '.zip',
'url': 'https://zenodo.org/api/files/e2288446-9ee8-4b2e-ae76-cd80366a40e1/spring.zip', # noqa: E501 'url': 'https://zenodo.org/api/files/e2288446-9ee8-4b2e-ae76-cd80366a40e1/spring.zip',
'md5': 'de4cdba7b6196aff624073991b187561', 'md5': 'de4cdba7b6196aff624073991b187561',
}, },
{ {
'name': 'summer', 'name': 'summer',
'ext': '.zip', 'ext': '.zip',
'url': 'https://zenodo.org/api/files/e2288446-9ee8-4b2e-ae76-cd80366a40e1/summer.zip', # noqa: E501 'url': 'https://zenodo.org/api/files/e2288446-9ee8-4b2e-ae76-cd80366a40e1/summer.zip',
'md5': '6a54d4e134d27ae4eb03f180ee100550', 'md5': '6a54d4e134d27ae4eb03f180ee100550',
}, },
{ {
'name': 'fall', 'name': 'fall',
'ext': '.zip', 'ext': '.zip',
'url': 'https://zenodo.org/api/files/e2288446-9ee8-4b2e-ae76-cd80366a40e1/fall.zip', # noqa: E501 'url': 'https://zenodo.org/api/files/e2288446-9ee8-4b2e-ae76-cd80366a40e1/fall.zip',
'md5': '5f94920fe41a63c6bfbab7295f7d6b95', 'md5': '5f94920fe41a63c6bfbab7295f7d6b95',
}, },
{ {
'name': 'winter', 'name': 'winter',
'ext': '.zip', 'ext': '.zip',
'url': 'https://zenodo.org/api/files/e2288446-9ee8-4b2e-ae76-cd80366a40e1/winter.zip', # noqa: E501 'url': 'https://zenodo.org/api/files/e2288446-9ee8-4b2e-ae76-cd80366a40e1/winter.zip',
'md5': 'dc5e3e09e52ab5c72421b1e3186c9a48', 'md5': 'dc5e3e09e52ab5c72421b1e3186c9a48',
}, },
{ {
'name': 'snow', 'name': 'snow',
'ext': '.zip', 'ext': '.zip',
'url': 'https://zenodo.org/api/files/e2288446-9ee8-4b2e-ae76-cd80366a40e1/snow.zip', # noqa: E501 'url': 'https://zenodo.org/api/files/e2288446-9ee8-4b2e-ae76-cd80366a40e1/snow.zip',
'md5': 'e1b300994143f99ebb03f51d6ab1cbe6', 'md5': 'e1b300994143f99ebb03f51d6ab1cbe6',
}, },
{ {
'name': 'splits', 'name': 'splits',
'ext': '.zip', 'ext': '.zip',
'url': 'https://zenodo.org/api/files/e2288446-9ee8-4b2e-ae76-cd80366a40e1/splits.zip', # noqa: E501 'url': 'https://zenodo.org/api/files/e2288446-9ee8-4b2e-ae76-cd80366a40e1/splits.zip',
'md5': 'e4ec4a18bc4efc828f0944a7cf4d5fed', 'md5': 'e4ec4a18bc4efc828f0944a7cf4d5fed',
}, },
{ {
'name': 'meta.csv', 'name': 'meta.csv',
'ext': '', 'ext': '',
'url': 'https://zenodo.org/api/files/e2288446-9ee8-4b2e-ae76-cd80366a40e1/meta.csv', # noqa: E501 'url': 'https://zenodo.org/api/files/e2288446-9ee8-4b2e-ae76-cd80366a40e1/meta.csv',
'md5': '43ea07974936a6bf47d989c32e16afe7', 'md5': '43ea07974936a6bf47d989c32e16afe7',
}, },
] )
classes = [ classes = (
'Continuous urban fabric', 'Continuous urban fabric',
'Discontinuous urban fabric', 'Discontinuous urban fabric',
'Industrial or commercial units', 'Industrial or commercial units',
@ -163,12 +164,17 @@ class SeasoNet(NonGeoDataset):
'Coastal lagoons', 'Coastal lagoons',
'Estuaries', 'Estuaries',
'Sea and ocean', 'Sea and ocean',
] )
all_seasons = {'Spring', 'Summer', 'Fall', 'Winter', 'Snow'} all_seasons = frozenset({'Spring', 'Summer', 'Fall', 'Winter', 'Snow'})
all_bands = ('10m_RGB', '10m_IR', '20m', '60m') all_bands = ('10m_RGB', '10m_IR', '20m', '60m')
band_nums = {'10m_RGB': 3, '10m_IR': 1, '20m': 6, '60m': 2} band_nums: ClassVar[dict[str, int]] = {
splits = ['train', 'val', 'test'] '10m_RGB': 3,
cmap = { '10m_IR': 1,
'20m': 6,
'60m': 2,
}
splits = ('train', 'val', 'test')
cmap: ClassVar[dict[int, tuple[int, int, int, int]]] = {
0: (230, 000, 77, 255), 0: (230, 000, 77, 255),
1: (255, 000, 000, 255), 1: (255, 000, 000, 255),
2: (204, 77, 242, 255), 2: (204, 77, 242, 255),
@ -331,7 +337,7 @@ class SeasoNet(NonGeoDataset):
for band in self.bands: for band in self.bands:
with rasterio.open(f'{path}_{band}.tif') as f: with rasterio.open(f'{path}_{band}.tif') as f:
array = f.read( array = f.read(
out_shape=[f.count] + list(self.image_size), out_shape=[f.count, *list(self.image_size)],
out_dtype='int32', out_dtype='int32',
resampling=Resampling.bilinear, resampling=Resampling.bilinear,
) )
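
Earlier in this file's changes the same immutability push covers sets: a plain set literal is mutable shared state, so all_seasons becomes a frozenset, which keeps membership tests and iteration while refusing mutation. A sketch with an illustrative class:

class Example:
    all_seasons = frozenset({'Spring', 'Summer', 'Fall', 'Winter', 'Snow'})

assert 'Winter' in Example.all_seasons
# Example.all_seasons.add('Monsoon') would raise AttributeError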

View file

@ -5,7 +5,8 @@
import os import os
import random import random
from collections.abc import Callable from collections.abc import Callable, Sequence
from typing import ClassVar
import matplotlib.pyplot as plt import matplotlib.pyplot as plt
import numpy as np import numpy as np
@ -37,7 +38,7 @@ class SeasonalContrastS2(NonGeoDataset):
* https://arxiv.org/pdf/2103.16607.pdf * https://arxiv.org/pdf/2103.16607.pdf
""" """
all_bands = [ all_bands = (
'B1', 'B1',
'B2', 'B2',
'B3', 'B3',
@ -50,10 +51,10 @@ class SeasonalContrastS2(NonGeoDataset):
'B9', 'B9',
'B11', 'B11',
'B12', 'B12',
] )
rgb_bands = ['B4', 'B3', 'B2'] rgb_bands = ('B4', 'B3', 'B2')
metadata = { metadata: ClassVar[dict[str, dict[str, str]]] = {
'100k': { '100k': {
'url': 'https://zenodo.org/record/4728033/files/seco_100k.zip?download=1', 'url': 'https://zenodo.org/record/4728033/files/seco_100k.zip?download=1',
'md5': 'ebf2d5e03adc6e657f9a69a20ad863e0', 'md5': 'ebf2d5e03adc6e657f9a69a20ad863e0',
@ -73,7 +74,7 @@ class SeasonalContrastS2(NonGeoDataset):
root: Path = 'data', root: Path = 'data',
version: str = '100k', version: str = '100k',
seasons: int = 1, seasons: int = 1,
bands: list[str] = rgb_bands, bands: Sequence[str] = rgb_bands,
transforms: Callable[[dict[str, Tensor]], dict[str, Tensor]] | None = None, transforms: Callable[[dict[str, Tensor]], dict[str, Tensor]] | None = None,
download: bool = False, download: bool = False,
checksum: bool = False, checksum: bool = False,
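
A side benefit shows up in this signature: because rgb_bands is now a tuple, using it directly as a default argument is safe, since the shared default can never be mutated between calls the way a list default could. A sketch with a stand-in function:

from collections.abc import Sequence

rgb_bands = ('B4', 'B3', 'B2')

def load(bands: Sequence[str] = rgb_bands) -> list[str]:
    # The tuple default is shared across calls but immutable, so there is
    # no mutable-default pitfall here.
    return list(bands)

assert load() == ['B4', 'B3', 'B2']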

View file

@ -5,6 +5,7 @@
import os import os
from collections.abc import Callable, Sequence from collections.abc import Callable, Sequence
from typing import ClassVar
import matplotlib.pyplot as plt import matplotlib.pyplot as plt
import numpy as np import numpy as np
@ -63,9 +64,9 @@ class SEN12MS(NonGeoDataset):
or manually downloaded from https://dataserv.ub.tum.de/s/m1474000 or manually downloaded from https://dataserv.ub.tum.de/s/m1474000
and https://github.com/schmitt-muc/SEN12MS/tree/master/splits. and https://github.com/schmitt-muc/SEN12MS/tree/master/splits.
This download will likely take several hours. This download will likely take several hours.
""" # noqa: E501 """
BAND_SETS: dict[str, tuple[str, ...]] = { BAND_SETS: ClassVar[dict[str, tuple[str, ...]]] = {
'all': ( 'all': (
'VV', 'VV',
'VH', 'VH',
@ -120,9 +121,9 @@ class SEN12MS(NonGeoDataset):
'B12', 'B12',
) )
rgb_bands = ['B04', 'B03', 'B02'] rgb_bands = ('B04', 'B03', 'B02')
filenames = [ filenames = (
'ROIs1158_spring_lc.tar.gz', 'ROIs1158_spring_lc.tar.gz',
'ROIs1158_spring_s1.tar.gz', 'ROIs1158_spring_s1.tar.gz',
'ROIs1158_spring_s2.tar.gz', 'ROIs1158_spring_s2.tar.gz',
@ -137,16 +138,16 @@ class SEN12MS(NonGeoDataset):
'ROIs2017_winter_s2.tar.gz', 'ROIs2017_winter_s2.tar.gz',
'train_list.txt', 'train_list.txt',
'test_list.txt', 'test_list.txt',
] )
light_filenames = [ light_filenames = (
'ROIs1158_spring', 'ROIs1158_spring',
'ROIs1868_summer', 'ROIs1868_summer',
'ROIs1970_fall', 'ROIs1970_fall',
'ROIs2017_winter', 'ROIs2017_winter',
'train_list.txt', 'train_list.txt',
'test_list.txt', 'test_list.txt',
] )
md5s = [ md5s = (
'6e2e8fa8b8cba77ddab49fd20ff5c37b', '6e2e8fa8b8cba77ddab49fd20ff5c37b',
'fba019bb27a08c1db96b31f718c34d79', 'fba019bb27a08c1db96b31f718c34d79',
'd58af2c15a16f376eb3308dc9b685af2', 'd58af2c15a16f376eb3308dc9b685af2',
@ -161,7 +162,7 @@ class SEN12MS(NonGeoDataset):
'3807545661288dcca312c9c538537b63', '3807545661288dcca312c9c538537b63',
'0a68d4e1eb24f128fccdb930000b2546', '0a68d4e1eb24f128fccdb930000b2546',
'c7faad064001e646445c4c634169484d', 'c7faad064001e646445c4c634169484d',
] )
def __init__( def __init__(
self, self,

View file

@ -137,7 +137,7 @@ class Sentinel1(Sentinel):
\. \.
""" """
date_format = '%Y%m%dT%H%M%S' date_format = '%Y%m%dT%H%M%S'
all_bands = ['HH', 'HV', 'VV', 'VH'] all_bands = ('HH', 'HV', 'VV', 'VH')
separate_files = True separate_files = True
def __init__( def __init__(
@ -277,7 +277,7 @@ class Sentinel2(Sentinel):
date_format = '%Y%m%dT%H%M%S' date_format = '%Y%m%dT%H%M%S'
# https://gisgeography.com/sentinel-2-bands-combinations/ # https://gisgeography.com/sentinel-2-bands-combinations/
all_bands = [ all_bands: tuple[str, ...] = (
'B01', 'B01',
'B02', 'B02',
'B03', 'B03',
@ -291,8 +291,8 @@ class Sentinel2(Sentinel):
'B10', 'B10',
'B11', 'B11',
'B12', 'B12',
] )
rgb_bands = ['B04', 'B03', 'B02'] rgb_bands = ('B04', 'B03', 'B02')
separate_files = True separate_files = True

View file

@ -5,7 +5,7 @@
import os import os
from collections.abc import Callable from collections.abc import Callable
from typing import Any from typing import Any, ClassVar
import matplotlib.pyplot as plt import matplotlib.pyplot as plt
import numpy as np import numpy as np
@ -62,8 +62,8 @@ class SKIPPD(NonGeoDataset):
.. versionadded:: 0.5 .. versionadded:: 0.5
""" """
url = 'https://hf.co/datasets/torchgeo/skippd/resolve/a16c7e200b4618cd93be3143cdb973e3f21498fa/{}' # noqa: E501 url = 'https://hf.co/datasets/torchgeo/skippd/resolve/a16c7e200b4618cd93be3143cdb973e3f21498fa/{}'
md5 = { md5: ClassVar[dict[str, str]] = {
'forecast': 'f4f3509ddcc83a55c433be9db2e51077', 'forecast': 'f4f3509ddcc83a55c433be9db2e51077',
'nowcast': '0000761d403e45bb5f86c21d3c69aa80', 'nowcast': '0000761d403e45bb5f86c21d3c69aa80',
} }
@ -71,9 +71,9 @@ class SKIPPD(NonGeoDataset):
data_file_name = '2017_2019_images_pv_processed_{}.hdf5' data_file_name = '2017_2019_images_pv_processed_{}.hdf5'
zipfile_name = '2017_2019_images_pv_processed_{}.zip' zipfile_name = '2017_2019_images_pv_processed_{}.zip'
valid_splits = ['trainval', 'test'] valid_splits = ('trainval', 'test')
valid_tasks = ['nowcast', 'forecast'] valid_tasks = ('nowcast', 'forecast')
dateformat = '%m/%d/%Y, %H:%M:%S' dateformat = '%m/%d/%Y, %H:%M:%S'

View file

@ -5,7 +5,7 @@
import os import os
from collections.abc import Callable, Sequence from collections.abc import Callable, Sequence
from typing import cast from typing import ClassVar, cast
import matplotlib.pyplot as plt import matplotlib.pyplot as plt
import numpy as np import numpy as np
@ -103,10 +103,10 @@ class So2Sat(NonGeoDataset):
This dataset requires the following additional library to be installed: This dataset requires the following additional library to be installed:
* `<https://pypi.org/project/h5py/>`_ to load the dataset * `<https://pypi.org/project/h5py/>`_ to load the dataset
""" # noqa: E501 """
versions = ['2', '3_random', '3_block', '3_culture_10'] versions = ('2', '3_random', '3_block', '3_culture_10')
filenames_by_version = { filenames_by_version: ClassVar[dict[str, dict[str, str]]] = {
'2': { '2': {
'train': 'training.h5', 'train': 'training.h5',
'validation': 'validation.h5', 'validation': 'validation.h5',
@ -119,7 +119,7 @@ class So2Sat(NonGeoDataset):
'test': 'culture_10/testing.h5', 'test': 'culture_10/testing.h5',
}, },
} }
md5s_by_version = { md5s_by_version: ClassVar[dict[str, dict[str, str]]] = {
'2': { '2': {
'train': '702bc6a9368ebff4542d791e53469244', 'train': '702bc6a9368ebff4542d791e53469244',
'validation': '71cfa6795de3e22207229d06d6f8775d', 'validation': '71cfa6795de3e22207229d06d6f8775d',
@ -139,7 +139,7 @@ class So2Sat(NonGeoDataset):
}, },
} }
classes = [ classes = (
'Compact high rise', 'Compact high rise',
'Compact mid rise', 'Compact mid rise',
'Compact low rise', 'Compact low rise',
@ -157,7 +157,7 @@ class So2Sat(NonGeoDataset):
'Bare rock or paved', 'Bare rock or paved',
'Bare soil or sand', 'Bare soil or sand',
'Water', 'Water',
] )
all_s1_band_names = ( all_s1_band_names = (
'S1_B1', 'S1_B1',
@ -183,9 +183,9 @@ class So2Sat(NonGeoDataset):
) )
all_band_names = all_s1_band_names + all_s2_band_names all_band_names = all_s1_band_names + all_s2_band_names
rgb_bands = ['S2_B04', 'S2_B03', 'S2_B02'] rgb_bands = ('S2_B04', 'S2_B03', 'S2_B02')
BAND_SETS = { BAND_SETS: ClassVar[dict[str, tuple[str, ...]]] = {
'all': all_band_names, 'all': all_band_names,
's1': all_s1_band_names, 's1': all_s1_band_names,
's2': all_s2_band_names, 's2': all_s2_band_names,

View file

@ -6,8 +6,8 @@
import os import os
import pathlib import pathlib
import re import re
from collections.abc import Callable, Iterable from collections.abc import Callable, Iterable, Sequence
from typing import Any, cast from typing import Any, ClassVar, cast
import matplotlib.pyplot as plt import matplotlib.pyplot as plt
import torch import torch
@ -79,9 +79,9 @@ class SouthAfricaCropType(RasterDataset):
_10m _10m
""" """
date_format = '%Y_%m_%d' date_format = '%Y_%m_%d'
rgb_bands = ['B04', 'B03', 'B02'] rgb_bands = ('B04', 'B03', 'B02')
s1_bands = ['VH', 'VV'] s1_bands = ('VH', 'VV')
s2_bands = [ s2_bands = (
'B01', 'B01',
'B02', 'B02',
'B03', 'B03',
@ -94,9 +94,9 @@ class SouthAfricaCropType(RasterDataset):
'B09', 'B09',
'B11', 'B11',
'B12', 'B12',
] )
all_bands: list[str] = s1_bands + s2_bands all_bands = s1_bands + s2_bands
cmap = { cmap: ClassVar[dict[int, tuple[int, int, int, int]]] = {
0: (0, 0, 0, 255), 0: (0, 0, 0, 255),
1: (255, 211, 0, 255), 1: (255, 211, 0, 255),
2: (255, 37, 37, 255), 2: (255, 37, 37, 255),
@ -113,8 +113,8 @@ class SouthAfricaCropType(RasterDataset):
self, self,
paths: Path | Iterable[Path] = 'data', paths: Path | Iterable[Path] = 'data',
crs: CRS | None = None, crs: CRS | None = None,
classes: list[int] = list(cmap.keys()), classes: Sequence[int] = list(cmap.keys()),
bands: list[str] = s2_bands, bands: Sequence[str] = s2_bands,
transforms: Callable[[dict[str, Tensor]], dict[str, Tensor]] | None = None, transforms: Callable[[dict[str, Tensor]], dict[str, Tensor]] | None = None,
download: bool = False, download: bool = False,
) -> None: ) -> None:
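Because the class-level band lists are tuples now, parameter defaults such as bands: Sequence[str] = s2_bands need an annotation that covers tuples, hence the widening from list[...] to the read-only collections.abc.Sequence. A sketch with hypothetical names:

from collections.abc import Sequence

S2_BANDS = ('B02', 'B03', 'B04')


def load(bands: Sequence[str] = S2_BANDS) -> list[str]:
    # Sequence accepts tuples, lists, and any other sequence type.
    return [band.lower() for band in bands]


load()                # tuple default
load(['B08', 'B11'])  # lists still work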


@ -5,7 +5,7 @@
import pathlib import pathlib
from collections.abc import Callable, Iterable from collections.abc import Callable, Iterable
from typing import Any from typing import Any, ClassVar
import matplotlib.pyplot as plt import matplotlib.pyplot as plt
from matplotlib.figure import Figure from matplotlib.figure import Figure
@ -47,7 +47,7 @@ class SouthAmericaSoybean(RasterDataset):
is_image = False is_image = False
url = 'https://glad.umd.edu/projects/AnnualClassMapsV1/SouthAmerica_Soybean_{}.tif' url = 'https://glad.umd.edu/projects/AnnualClassMapsV1/SouthAmerica_Soybean_{}.tif'
md5s = { md5s: ClassVar[dict[int, str]] = {
2021: 'edff3ada13a1a9910d1fe844d28ae4f', 2021: 'edff3ada13a1a9910d1fe844d28ae4f',
2020: '0709dec807f576c9707c8c7e183db31', 2020: '0709dec807f576c9707c8c7e183db31',
2019: '441836493bbcd5e123cff579a58f5a4f', 2019: '441836493bbcd5e123cff579a58f5a4f',


@ -8,7 +8,7 @@ import os
import re import re
from abc import ABC, abstractmethod from abc import ABC, abstractmethod
from collections.abc import Callable from collections.abc import Callable
from typing import Any from typing import Any, ClassVar
import fiona import fiona
import matplotlib.pyplot as plt import matplotlib.pyplot as plt
@ -55,9 +55,9 @@ class SpaceNet(NonGeoDataset, ABC):
image_glob = '*.tif' image_glob = '*.tif'
mask_glob = '*.geojson' mask_glob = '*.geojson'
file_regex = r'_img(\d+)\.' file_regex = r'_img(\d+)\.'
chip_size: dict[str, tuple[int, int]] = {} chip_size: ClassVar[dict[str, tuple[int, int]]] = {}
cities = { cities: ClassVar[dict[int, str]] = {
1: 'Rio', 1: 'Rio',
2: 'Vegas', 2: 'Vegas',
3: 'Paris', 3: 'Paris',
@ -98,7 +98,7 @@ class SpaceNet(NonGeoDataset, ABC):
@property @property
@abstractmethod @abstractmethod
def valid_masks(self) -> list[str]: def valid_masks(self) -> tuple[str, ...]:
"""List of valid masks.""" """List of valid masks."""
def __init__( def __init__(
@ -426,7 +426,7 @@ class SpaceNet1(SpaceNet):
directory_glob = '{product}' directory_glob = '{product}'
dataset_id = 'SN1_buildings' dataset_id = 'SN1_buildings'
tarballs = { tarballs: ClassVar[dict[str, dict[int, list[str]]]] = {
'train': { 'train': {
1: [ 1: [
'SN1_buildings_train_AOI_1_Rio_3band.tar.gz', 'SN1_buildings_train_AOI_1_Rio_3band.tar.gz',
@ -441,7 +441,7 @@ class SpaceNet1(SpaceNet):
] ]
}, },
} }
md5s = { md5s: ClassVar[dict[str, dict[int, list[str]]]] = {
'train': { 'train': {
1: [ 1: [
'279e334a2120ecac70439ea246174516', '279e334a2120ecac70439ea246174516',
@ -453,10 +453,16 @@ class SpaceNet1(SpaceNet):
1: ['18283d78b21c239bc1831f3bf1d2c996', '732b3a40603b76e80aac84e002e2b3e8'] 1: ['18283d78b21c239bc1831f3bf1d2c996', '732b3a40603b76e80aac84e002e2b3e8']
}, },
} }
valid_aois = {'train': [1], 'test': [1]} valid_aois: ClassVar[dict[str, list[int]]] = {'train': [1], 'test': [1]}
valid_images = {'train': ['3band', '8band'], 'test': ['3band', '8band']} valid_images: ClassVar[dict[str, list[str]]] = {
valid_masks = ['geojson'] 'train': ['3band', '8band'],
chip_size = {'3band': (406, 439), '8band': (102, 110)} 'test': ['3band', '8band'],
}
valid_masks = ('geojson',)
chip_size: ClassVar[dict[str, tuple[int, int]]] = {
'3band': (406, 439),
'8band': (102, 110),
}
class SpaceNet2(SpaceNet): class SpaceNet2(SpaceNet):
@ -522,7 +528,7 @@ class SpaceNet2(SpaceNet):
""" """
dataset_id = 'SN2_buildings' dataset_id = 'SN2_buildings'
tarballs = { tarballs: ClassVar[dict[str, dict[int, list[str]]]] = {
'train': { 'train': {
2: ['SN2_buildings_train_AOI_2_Vegas.tar.gz'], 2: ['SN2_buildings_train_AOI_2_Vegas.tar.gz'],
3: ['SN2_buildings_train_AOI_3_Paris.tar.gz'], 3: ['SN2_buildings_train_AOI_3_Paris.tar.gz'],
@ -536,7 +542,7 @@ class SpaceNet2(SpaceNet):
5: ['AOI_5_Khartoum_Test_public.tar.gz'], 5: ['AOI_5_Khartoum_Test_public.tar.gz'],
}, },
} }
md5s = { md5s: ClassVar[dict[str, dict[int, list[str]]]] = {
'train': { 'train': {
2: ['307da318bc43aaf9481828f92eda9126'], 2: ['307da318bc43aaf9481828f92eda9126'],
3: ['4db469e3e4e7bf025368ad730aec0888'], 3: ['4db469e3e4e7bf025368ad730aec0888'],
@ -550,13 +556,16 @@ class SpaceNet2(SpaceNet):
5: ['037d7be10530f0dd1c43d4ef79f3236e'], 5: ['037d7be10530f0dd1c43d4ef79f3236e'],
}, },
} }
valid_aois = {'train': [2, 3, 4, 5], 'test': [2, 3, 4, 5]} valid_aois: ClassVar[dict[str, list[int]]] = {
valid_images = { 'train': [2, 3, 4, 5],
'test': [2, 3, 4, 5],
}
valid_images: ClassVar[dict[str, list[str]]] = {
'train': ['MUL', 'MUL-PanSharpen', 'PAN', 'RGB-PanSharpen'], 'train': ['MUL', 'MUL-PanSharpen', 'PAN', 'RGB-PanSharpen'],
'test': ['MUL', 'MUL-PanSharpen', 'PAN', 'RGB-PanSharpen'], 'test': ['MUL', 'MUL-PanSharpen', 'PAN', 'RGB-PanSharpen'],
} }
valid_masks = [os.path.join('geojson', 'buildings')] valid_masks = (os.path.join('geojson', 'buildings'),)
chip_size = {'MUL': (163, 163)} chip_size: ClassVar[dict[str, tuple[int, int]]] = {'MUL': (163, 163)}
class SpaceNet3(SpaceNet): class SpaceNet3(SpaceNet):
@ -624,7 +633,7 @@ class SpaceNet3(SpaceNet):
""" """
dataset_id = 'SN3_roads' dataset_id = 'SN3_roads'
tarballs = { tarballs: ClassVar[dict[str, dict[int, list[str]]]] = {
'train': { 'train': {
2: [ 2: [
'SN3_roads_train_AOI_2_Vegas.tar.gz', 'SN3_roads_train_AOI_2_Vegas.tar.gz',
@ -650,7 +659,7 @@ class SpaceNet3(SpaceNet):
5: ['SN3_roads_test_public_AOI_5_Khartoum.tar.gz'], 5: ['SN3_roads_test_public_AOI_5_Khartoum.tar.gz'],
}, },
} }
md5s = { md5s: ClassVar[dict[str, dict[int, list[str]]]] = {
'train': { 'train': {
2: ['06317255b5e0c6df2643efd8a50f22ae', '4acf7846ed8121db1319345cfe9fdca9'], 2: ['06317255b5e0c6df2643efd8a50f22ae', '4acf7846ed8121db1319345cfe9fdca9'],
3: ['c13baf88ee10fe47870c303223cabf82', 'abc8199d4c522d3a14328f4f514702ad'], 3: ['c13baf88ee10fe47870c303223cabf82', 'abc8199d4c522d3a14328f4f514702ad'],
@ -664,12 +673,15 @@ class SpaceNet3(SpaceNet):
5: ['f367c79fa0fc1d38e63a0fdd065ed957'], 5: ['f367c79fa0fc1d38e63a0fdd065ed957'],
}, },
} }
valid_aois = {'train': [2, 3, 4, 5], 'test': [2, 3, 4, 5]} valid_aois: ClassVar[dict[str, list[int]]] = {
valid_images = { 'train': [2, 3, 4, 5],
'test': [2, 3, 4, 5],
}
valid_images: ClassVar[dict[str, list[str]]] = {
'train': ['MS', 'PS-MS', 'PAN', 'PS-RGB'], 'train': ['MS', 'PS-MS', 'PAN', 'PS-RGB'],
'test': ['MUL', 'MUL-PanSharpen', 'PAN', 'RGB-PanSharpen'], 'test': ['MUL', 'MUL-PanSharpen', 'PAN', 'RGB-PanSharpen'],
} }
valid_masks = ['geojson_roads', 'geojson_roads_speed'] valid_masks: tuple[str, ...] = ('geojson_roads', 'geojson_roads_speed')
class SpaceNet4(SpaceNet): class SpaceNet4(SpaceNet):
@ -708,7 +720,7 @@ class SpaceNet4(SpaceNet):
directory_glob = os.path.join('**', '{product}') directory_glob = os.path.join('**', '{product}')
file_regex = r'_(\d+_\d+)\.' file_regex = r'_(\d+_\d+)\.'
dataset_id = 'SN4_buildings' dataset_id = 'SN4_buildings'
tarballs = { tarballs: ClassVar[dict[str, dict[int, list[str]]]] = {
'train': { 'train': {
6: [ 6: [
'Atlanta_nadir7_catid_1030010003D22F00.tar.gz', 'Atlanta_nadir7_catid_1030010003D22F00.tar.gz',
@ -743,7 +755,7 @@ class SpaceNet4(SpaceNet):
}, },
'test': {6: ['SN4_buildings_AOI_6_Atlanta_test_public.tar.gz']}, 'test': {6: ['SN4_buildings_AOI_6_Atlanta_test_public.tar.gz']},
} }
md5s = { md5s: ClassVar[dict[str, dict[int, list[str]]]] = {
'train': { 'train': {
6: [ 6: [
'd41ab6ec087b07e1e046c55d1fa5754b', 'd41ab6ec087b07e1e046c55d1fa5754b',
@ -778,12 +790,12 @@ class SpaceNet4(SpaceNet):
}, },
'test': {6: ['0ec3874bfc19aed63b33ac47b039aace']}, 'test': {6: ['0ec3874bfc19aed63b33ac47b039aace']},
} }
valid_aois = {'train': [6], 'test': [6]} valid_aois: ClassVar[dict[str, list[int]]] = {'train': [6], 'test': [6]}
valid_images = { valid_images: ClassVar[dict[str, list[str]]] = {
'train': ['MS', 'PAN', 'Pan-Sharpen'], 'train': ['MS', 'PAN', 'Pan-Sharpen'],
'test': ['MS', 'PAN', 'Pan-Sharpen'], 'test': ['MS', 'PAN', 'Pan-Sharpen'],
} }
valid_masks = [os.path.join('geojson', 'spacenet-buildings')] valid_masks = (os.path.join('geojson', 'spacenet-buildings'),)
class SpaceNet5(SpaceNet3): class SpaceNet5(SpaceNet3):
@ -850,26 +862,26 @@ class SpaceNet5(SpaceNet3):
file_regex = r'_chip(\d+)\.' file_regex = r'_chip(\d+)\.'
dataset_id = 'SN5_roads' dataset_id = 'SN5_roads'
tarballs = { tarballs: ClassVar[dict[str, dict[int, list[str]]]] = {
'train': { 'train': {
7: ['SN5_roads_train_AOI_7_Moscow.tar.gz'], 7: ['SN5_roads_train_AOI_7_Moscow.tar.gz'],
8: ['SN5_roads_train_AOI_8_Mumbai.tar.gz'], 8: ['SN5_roads_train_AOI_8_Mumbai.tar.gz'],
}, },
'test': {9: ['SN5_roads_test_public_AOI_9_San_Juan.tar.gz']}, 'test': {9: ['SN5_roads_test_public_AOI_9_San_Juan.tar.gz']},
} }
md5s = { md5s: ClassVar[dict[str, dict[int, list[str]]]] = {
'train': { 'train': {
7: ['03082d01081a6d8df2bc5a9645148d2a'], 7: ['03082d01081a6d8df2bc5a9645148d2a'],
8: ['1ee20ba781da6cb7696eef9a95a5bdcc'], 8: ['1ee20ba781da6cb7696eef9a95a5bdcc'],
}, },
'test': {9: ['fc45afef219dfd3a20f2d4fc597f6882']}, 'test': {9: ['fc45afef219dfd3a20f2d4fc597f6882']},
} }
valid_aois = {'train': [7, 8], 'test': [9]} valid_aois: ClassVar[dict[str, list[int]]] = {'train': [7, 8], 'test': [9]}
valid_images = { valid_images: ClassVar[dict[str, list[str]]] = {
'train': ['MS', 'PAN', 'PS-MS', 'PS-RGB'], 'train': ['MS', 'PAN', 'PS-MS', 'PS-RGB'],
'test': ['MS', 'PAN', 'PS-MS', 'PS-RGB'], 'test': ['MS', 'PAN', 'PS-MS', 'PS-RGB'],
} }
valid_masks = ['geojson_roads_speed'] valid_masks = ('geojson_roads_speed',)
class SpaceNet6(SpaceNet): class SpaceNet6(SpaceNet):
@ -937,20 +949,20 @@ class SpaceNet6(SpaceNet):
file_regex = r'_tile_(\d+)\.' file_regex = r'_tile_(\d+)\.'
dataset_id = 'SN6_buildings' dataset_id = 'SN6_buildings'
tarballs = { tarballs: ClassVar[dict[str, dict[int, list[str]]]] = {
'train': {11: ['SN6_buildings_AOI_11_Rotterdam_train.tar.gz']}, 'train': {11: ['SN6_buildings_AOI_11_Rotterdam_train.tar.gz']},
'test': {11: ['SN6_buildings_AOI_11_Rotterdam_test_public.tar.gz']}, 'test': {11: ['SN6_buildings_AOI_11_Rotterdam_test_public.tar.gz']},
} }
md5s = { md5s: ClassVar[dict[str, dict[int, list[str]]]] = {
'train': {11: ['10ca26d2287716e3b6ef0cf0ad9f946e']}, 'train': {11: ['10ca26d2287716e3b6ef0cf0ad9f946e']},
'test': {11: ['a07823a5e536feeb8bb6b6f0cb43cf05']}, 'test': {11: ['a07823a5e536feeb8bb6b6f0cb43cf05']},
} }
valid_aois = {'train': [11], 'test': [11]} valid_aois: ClassVar[dict[str, list[int]]] = {'train': [11], 'test': [11]}
valid_images = { valid_images: ClassVar[dict[str, list[str]]] = {
'train': ['PAN', 'PS-RGB', 'PS-RGBNIR', 'RGBNIR', 'SAR-Intensity'], 'train': ['PAN', 'PS-RGB', 'PS-RGBNIR', 'RGBNIR', 'SAR-Intensity'],
'test': ['SAR-Intensity'], 'test': ['SAR-Intensity'],
} }
valid_masks = ['geojson_buildings'] valid_masks = ('geojson_buildings',)
class SpaceNet7(SpaceNet): class SpaceNet7(SpaceNet):
@ -958,7 +970,7 @@ class SpaceNet7(SpaceNet):
`SpaceNet 7 <https://spacenet.ai/sn7-challenge/>`_ is a dataset which `SpaceNet 7 <https://spacenet.ai/sn7-challenge/>`_ is a dataset which
consists of medium resolution (4.0m) satellite imagery mosaics acquired from consists of medium resolution (4.0m) satellite imagery mosaics acquired from
Planet Labs' Dove constellation between 2017 and 2020. It includes ≈ 24 Planet Labs' Dove constellation between 2017 and 2020. It includes ≈ 24
images (one per month) covering > 100 unique geographies, and comprises > images (one per month) covering > 100 unique geographies, and comprises >
40,000 km2 of imagery and exhaustive polygon labels of building footprints 40,000 km2 of imagery and exhaustive polygon labels of building footprints
therein, totaling over 11M individual annotations. therein, totaling over 11M individual annotations.
@ -993,18 +1005,24 @@ class SpaceNet7(SpaceNet):
mask_glob = '*_Buildings.geojson' mask_glob = '*_Buildings.geojson'
file_regex = r'global_monthly_(\d+.*\d+)' file_regex = r'global_monthly_(\d+.*\d+)'
dataset_id = 'SN7_buildings' dataset_id = 'SN7_buildings'
tarballs = { tarballs: ClassVar[dict[str, dict[int, list[str]]]] = {
'train': {0: ['SN7_buildings_train.tar.gz']}, 'train': {0: ['SN7_buildings_train.tar.gz']},
'test': {0: ['SN7_buildings_test_public.tar.gz']}, 'test': {0: ['SN7_buildings_test_public.tar.gz']},
} }
md5s = { md5s: ClassVar[dict[str, dict[int, list[str]]]] = {
'train': {0: ['6eda13b9c28f6f5cdf00a7e8e218c1b1']}, 'train': {0: ['6eda13b9c28f6f5cdf00a7e8e218c1b1']},
'test': {0: ['b3bde95a0f8f32f3bfeba49464b9bc97']}, 'test': {0: ['b3bde95a0f8f32f3bfeba49464b9bc97']},
} }
valid_aois = {'train': [0], 'test': [0]} valid_aois: ClassVar[dict[str, list[int]]] = {'train': [0], 'test': [0]}
valid_images = {'train': ['images', 'images_masked'], 'test': ['images_masked']} valid_images: ClassVar[dict[str, list[str]]] = {
valid_masks = ['labels', 'labels_match', 'labels_match_pix'] 'train': ['images', 'images_masked'],
chip_size = {'images': (1024, 1024), 'images_masked': (1024, 1024)} 'test': ['images_masked'],
}
valid_masks = ('labels', 'labels_match', 'labels_match_pix')
chip_size: ClassVar[dict[str, tuple[int, int]]] = {
'images': (1024, 1024),
'images_masked': (1024, 1024),
}
class SpaceNet8(SpaceNet): class SpaceNet8(SpaceNet):
@ -1024,7 +1042,7 @@ class SpaceNet8(SpaceNet):
directory_glob = '{product}' directory_glob = '{product}'
file_regex = r'(\d+_\d+_\d+)\.' file_regex = r'(\d+_\d+_\d+)\.'
dataset_id = 'SN8_floods' dataset_id = 'SN8_floods'
tarballs = { tarballs: ClassVar[dict[str, dict[int, list[str]]]] = {
'train': { 'train': {
0: [ 0: [
'Germany_Training_Public.tar.gz', 'Germany_Training_Public.tar.gz',
@ -1033,16 +1051,19 @@ class SpaceNet8(SpaceNet):
}, },
'test': {0: ['Louisiana-West_Test_Public.tar.gz']}, 'test': {0: ['Louisiana-West_Test_Public.tar.gz']},
} }
md5s = { md5s: ClassVar[dict[str, dict[int, list[str]]]] = {
'train': { 'train': {
0: ['81383a9050b93e8f70c8557d4568e8a2', 'fa40ae3cf6ac212c90073bf93d70bd95'] 0: ['81383a9050b93e8f70c8557d4568e8a2', 'fa40ae3cf6ac212c90073bf93d70bd95']
}, },
'test': {0: ['d41d8cd98f00b204e9800998ecf8427e']}, 'test': {0: ['d41d8cd98f00b204e9800998ecf8427e']},
} }
valid_aois = {'train': [0], 'test': [0]} valid_aois: ClassVar[dict[str, list[int]]] = {'train': [0], 'test': [0]}
valid_images = { valid_images: ClassVar[dict[str, list[str]]] = {
'train': ['PRE-event', 'POST-event'], 'train': ['PRE-event', 'POST-event'],
'test': ['PRE-event', 'POST-event'], 'test': ['PRE-event', 'POST-event'],
} }
valid_masks = ['annotations'] valid_masks = ('annotations',)
chip_size = {'PRE-event': (1300, 1300), 'POST-event': (1300, 1300)} chip_size: ClassVar[dict[str, tuple[int, int]]] = {
'PRE-event': (1300, 1300),
'POST-event': (1300, 1300),
}
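One detail worth noting in the valid_masks conversions above: a single-element tuple requires its trailing comma; without it the parentheses are plain grouping and the value is just a string.

assert ('annotations',)[0] == 'annotations'  # 1-tuple
assert ('annotations') == 'annotations'      # no comma: still a str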


@ -7,7 +7,7 @@ import glob
import os import os
import random import random
from collections.abc import Callable from collections.abc import Callable
from typing import TypedDict from typing import ClassVar, TypedDict
import matplotlib.pyplot as plt import matplotlib.pyplot as plt
import numpy as np import numpy as np
@ -93,13 +93,13 @@ class SSL4EOL(NonGeoDataset):
* https://proceedings.neurips.cc/paper_files/paper/2023/hash/bbf7ee04e2aefec136ecf60e346c2e61-Abstract-Datasets_and_Benchmarks.html * https://proceedings.neurips.cc/paper_files/paper/2023/hash/bbf7ee04e2aefec136ecf60e346c2e61-Abstract-Datasets_and_Benchmarks.html
.. versionadded:: 0.5 .. versionadded:: 0.5
""" # noqa: E501 """
class _Metadata(TypedDict): class _Metadata(TypedDict):
num_bands: int num_bands: int
rgb_bands: list[int] rgb_bands: list[int]
metadata: dict[str, _Metadata] = { metadata: ClassVar[dict[str, _Metadata]] = {
'tm_toa': {'num_bands': 7, 'rgb_bands': [2, 1, 0]}, 'tm_toa': {'num_bands': 7, 'rgb_bands': [2, 1, 0]},
'etm_toa': {'num_bands': 9, 'rgb_bands': [2, 1, 0]}, 'etm_toa': {'num_bands': 9, 'rgb_bands': [2, 1, 0]},
'etm_sr': {'num_bands': 6, 'rgb_bands': [2, 1, 0]}, 'etm_sr': {'num_bands': 6, 'rgb_bands': [2, 1, 0]},
@ -107,8 +107,8 @@ class SSL4EOL(NonGeoDataset):
'oli_sr': {'num_bands': 7, 'rgb_bands': [3, 2, 1]}, 'oli_sr': {'num_bands': 7, 'rgb_bands': [3, 2, 1]},
} }
url = 'https://hf.co/datasets/torchgeo/ssl4eo_l/resolve/e2467887e6a6bcd7547d9d5999f8d9bc3323dc31/{0}/ssl4eo_l_{0}.tar.gz{1}' # noqa: E501 url = 'https://hf.co/datasets/torchgeo/ssl4eo_l/resolve/e2467887e6a6bcd7547d9d5999f8d9bc3323dc31/{0}/ssl4eo_l_{0}.tar.gz{1}'
checksums = { checksums: ClassVar[dict[str, dict[str, str]]] = {
'tm_toa': { 'tm_toa': {
'aa': '553795b8d73aa253445b1e67c5b81f11', 'aa': '553795b8d73aa253445b1e67c5b81f11',
'ab': 'e9e0739b5171b37d16086cb89ab370e8', 'ab': 'e9e0739b5171b37d16086cb89ab370e8',
@ -357,7 +357,7 @@ class SSL4EOS12(NonGeoDataset):
md5: str md5: str
bands: list[str] bands: list[str]
metadata: dict[str, _Metadata] = { metadata: ClassVar[dict[str, _Metadata]] = {
's1': { 's1': {
'filename': 's1.tar.gz', 'filename': 's1.tar.gz',
'md5': '51ee23b33eb0a2f920bda25225072f3a', 'md5': '51ee23b33eb0a2f920bda25225072f3a',
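SSL4EOL.metadata combines two typing features: TypedDict pins down the shape of each per-sensor entry, while ClassVar marks the containing dict as class-level. A minimal sketch:

from typing import ClassVar, TypedDict


class _Metadata(TypedDict):
    num_bands: int
    rgb_bands: list[int]


class Sensor:
    metadata: ClassVar[dict[str, _Metadata]] = {
        'tm_toa': {'num_bands': 7, 'rgb_bands': [2, 1, 0]}
    }

A type checker now rejects entries with missing or misspelled keys.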


@ -6,6 +6,7 @@
import glob import glob
import os import os
from collections.abc import Callable from collections.abc import Callable
from typing import ClassVar
import matplotlib.pyplot as plt import matplotlib.pyplot as plt
import numpy as np import numpy as np
@ -46,16 +47,16 @@ class SSL4EOLBenchmark(NonGeoDataset):
* https://proceedings.neurips.cc/paper_files/paper/2023/hash/bbf7ee04e2aefec136ecf60e346c2e61-Abstract-Datasets_and_Benchmarks.html * https://proceedings.neurips.cc/paper_files/paper/2023/hash/bbf7ee04e2aefec136ecf60e346c2e61-Abstract-Datasets_and_Benchmarks.html
.. versionadded:: 0.5 .. versionadded:: 0.5
""" # noqa: E501 """
url = 'https://hf.co/datasets/torchgeo/ssl4eo-l-benchmark/resolve/da96ae2b04cb509710b72fce9131c2a3d5c211c2/{}.tar.gz' # noqa: E501 url = 'https://hf.co/datasets/torchgeo/ssl4eo-l-benchmark/resolve/da96ae2b04cb509710b72fce9131c2a3d5c211c2/{}.tar.gz'
valid_sensors = ['tm_toa', 'etm_toa', 'etm_sr', 'oli_tirs_toa', 'oli_sr'] valid_sensors = ('tm_toa', 'etm_toa', 'etm_sr', 'oli_tirs_toa', 'oli_sr')
valid_products = ['cdl', 'nlcd'] valid_products = ('cdl', 'nlcd')
valid_splits = ['train', 'val', 'test'] valid_splits = ('train', 'val', 'test')
image_root = 'ssl4eo_l_{}_benchmark' image_root = 'ssl4eo_l_{}_benchmark'
img_md5s = { img_md5s: ClassVar[dict[str, str]] = {
'tm_toa': '8e3c5bcd56d3780a442f1332013b8d15', 'tm_toa': '8e3c5bcd56d3780a442f1332013b8d15',
'etm_toa': '1b051c7fe4d61c581b341370c9e76f1f', 'etm_toa': '1b051c7fe4d61c581b341370c9e76f1f',
'etm_sr': '34a24fa89a801654f8d01e054662c8cd', 'etm_sr': '34a24fa89a801654f8d01e054662c8cd',
@ -63,14 +64,14 @@ class SSL4EOLBenchmark(NonGeoDataset):
'oli_sr': '0700cd15cc2366fe68c2f8c02fa09a15', 'oli_sr': '0700cd15cc2366fe68c2f8c02fa09a15',
} }
mask_dir_dict = { mask_dir_dict: ClassVar[dict[str, str]] = {
'tm_toa': 'ssl4eo_l_tm_{}', 'tm_toa': 'ssl4eo_l_tm_{}',
'etm_toa': 'ssl4eo_l_etm_{}', 'etm_toa': 'ssl4eo_l_etm_{}',
'etm_sr': 'ssl4eo_l_etm_{}', 'etm_sr': 'ssl4eo_l_etm_{}',
'oli_tirs_toa': 'ssl4eo_l_oli_{}', 'oli_tirs_toa': 'ssl4eo_l_oli_{}',
'oli_sr': 'ssl4eo_l_oli_{}', 'oli_sr': 'ssl4eo_l_oli_{}',
} }
mask_md5s = { mask_md5s: ClassVar[dict[str, dict[str, str]]] = {
'tm': { 'tm': {
'cdl': '3d676770ffb56c7e222a7192a652a846', 'cdl': '3d676770ffb56c7e222a7192a652a846',
'nlcd': '261149d7614fcfdcb3be368eefa825c7', 'nlcd': '261149d7614fcfdcb3be368eefa825c7',
@ -85,7 +86,7 @@ class SSL4EOLBenchmark(NonGeoDataset):
}, },
} }
year_dict = { year_dict: ClassVar[dict[str, int]] = {
'tm_toa': 2011, 'tm_toa': 2011,
'etm_toa': 2019, 'etm_toa': 2019,
'etm_sr': 2019, 'etm_sr': 2019,
@ -93,7 +94,7 @@ class SSL4EOLBenchmark(NonGeoDataset):
'oli_sr': 2019, 'oli_sr': 2019,
} }
rgb_indices = { rgb_indices: ClassVar[dict[str, list[int]]] = {
'tm_toa': [2, 1, 0], 'tm_toa': [2, 1, 0],
'etm_toa': [2, 1, 0], 'etm_toa': [2, 1, 0],
'etm_sr': [2, 1, 0], 'etm_sr': [2, 1, 0],
@ -101,9 +102,12 @@ class SSL4EOLBenchmark(NonGeoDataset):
'oli_sr': [3, 2, 1], 'oli_sr': [3, 2, 1],
} }
split_percentages = [0.7, 0.15, 0.15] split_percentages = (0.7, 0.15, 0.15)
cmaps = {'nlcd': NLCD.cmap, 'cdl': CDL.cmap} cmaps: ClassVar[dict[str, dict[int, tuple[int, int, int, int]]]] = {
'nlcd': NLCD.cmap,
'cdl': CDL.cmap,
}
def __init__( def __init__(
self, self,


@ -45,17 +45,17 @@ class SustainBenchCropYield(NonGeoDataset):
* https://doi.org/10.1609/aaai.v31i1.11172 * https://doi.org/10.1609/aaai.v31i1.11172
.. versionadded:: 0.5 .. versionadded:: 0.5
""" # noqa: E501 """
valid_countries = ['usa', 'brazil', 'argentina'] valid_countries = ('usa', 'brazil', 'argentina')
md5 = '362bad07b51a1264172b8376b39d1fc9' md5 = '362bad07b51a1264172b8376b39d1fc9'
url = 'https://drive.google.com/file/d/1lhbmICpmNuOBlaErywgiD6i9nHuhuv0A/view?usp=drive_link' # noqa: E501 url = 'https://drive.google.com/file/d/1lhbmICpmNuOBlaErywgiD6i9nHuhuv0A/view?usp=drive_link'
dir = 'soybeans' dir = 'soybeans'
valid_splits = ['train', 'dev', 'test'] valid_splits = ('train', 'dev', 'test')
def __init__( def __init__(
self, self,


@ -5,7 +5,7 @@
import os import os
from collections.abc import Callable from collections.abc import Callable
from typing import cast from typing import ClassVar, cast
import matplotlib.pyplot as plt import matplotlib.pyplot as plt
import numpy as np import numpy as np
@ -66,19 +66,19 @@ class UCMerced(NonGeoClassificationDataset):
* https://dl.acm.org/doi/10.1145/1869790.1869829 * https://dl.acm.org/doi/10.1145/1869790.1869829
""" """
url = 'https://hf.co/datasets/torchgeo/ucmerced/resolve/d0af6e2eeea2322af86078068bd83337148a2149/UCMerced_LandUse.zip' # noqa: E501 url = 'https://hf.co/datasets/torchgeo/ucmerced/resolve/d0af6e2eeea2322af86078068bd83337148a2149/UCMerced_LandUse.zip'
filename = 'UCMerced_LandUse.zip' filename = 'UCMerced_LandUse.zip'
md5 = '5b7ec56793786b6dc8a908e8854ac0e4' md5 = '5b7ec56793786b6dc8a908e8854ac0e4'
base_dir = os.path.join('UCMerced_LandUse', 'Images') base_dir = os.path.join('UCMerced_LandUse', 'Images')
splits = ['train', 'val', 'test'] splits = ('train', 'val', 'test')
split_urls = { split_urls: ClassVar[dict[str, str]] = {
'train': 'https://storage.googleapis.com/remote_sensing_representations/uc_merced-train.txt', # noqa: E501 'train': 'https://storage.googleapis.com/remote_sensing_representations/uc_merced-train.txt',
'val': 'https://storage.googleapis.com/remote_sensing_representations/uc_merced-val.txt', # noqa: E501 'val': 'https://storage.googleapis.com/remote_sensing_representations/uc_merced-val.txt',
'test': 'https://storage.googleapis.com/remote_sensing_representations/uc_merced-test.txt', # noqa: E501 'test': 'https://storage.googleapis.com/remote_sensing_representations/uc_merced-test.txt',
} }
split_md5s = { split_md5s: ClassVar[dict[str, str]] = {
'train': 'f2fb12eb2210cfb53f93f063a35ff374', 'train': 'f2fb12eb2210cfb53f93f063a35ff374',
'val': '11ecabfc52782e5ea6a9c7c0d263aca0', 'val': '11ecabfc52782e5ea6a9c7c0d263aca0',
'test': '046aff88472d8fc07c4678d03749e28d', 'test': '046aff88472d8fc07c4678d03749e28d',


@ -6,6 +6,7 @@
import glob import glob
import os import os
from collections.abc import Callable, Sequence from collections.abc import Callable, Sequence
from typing import ClassVar
import matplotlib.pyplot as plt import matplotlib.pyplot as plt
import numpy as np import numpy as np
@ -49,12 +50,12 @@ class USAVars(NonGeoDataset):
.. versionadded:: 0.3 .. versionadded:: 0.3
""" """
data_url = 'https://hf.co/datasets/torchgeo/usavars/resolve/01377abfaf50c0cc8548aaafb79533666bbf288f/{}' # noqa: E501 data_url = 'https://hf.co/datasets/torchgeo/usavars/resolve/01377abfaf50c0cc8548aaafb79533666bbf288f/{}'
dirname = 'uar' dirname = 'uar'
md5 = '677e89fd20e5dd0fe4d29b61827c2456' md5 = '677e89fd20e5dd0fe4d29b61827c2456'
label_urls = { label_urls: ClassVar[dict[str, str]] = {
'housing': data_url.format('housing.csv'), 'housing': data_url.format('housing.csv'),
'income': data_url.format('income.csv'), 'income': data_url.format('income.csv'),
'roads': data_url.format('roads.csv'), 'roads': data_url.format('roads.csv'),
@ -64,7 +65,7 @@ class USAVars(NonGeoDataset):
'treecover': data_url.format('treecover.csv'), 'treecover': data_url.format('treecover.csv'),
} }
split_metadata = { split_metadata: ClassVar[dict[str, dict[str, str]]] = {
'train': { 'train': {
'url': data_url.format('train_split.txt'), 'url': data_url.format('train_split.txt'),
'filename': 'train_split.txt', 'filename': 'train_split.txt',
@ -82,7 +83,7 @@ class USAVars(NonGeoDataset):
}, },
} }
ALL_LABELS = ['treecover', 'elevation', 'population'] ALL_LABELS = ('treecover', 'elevation', 'population')
def __init__( def __init__(
self, self,


@ -86,11 +86,11 @@ class BoundingBox:
# https://github.com/PyCQA/pydocstyle/issues/525 # https://github.com/PyCQA/pydocstyle/issues/525
@overload @overload
def __getitem__(self, key: int) -> float: # noqa: D105 def __getitem__(self, key: int) -> float:
pass pass
@overload @overload
def __getitem__(self, key: slice) -> list[float]: # noqa: D105 def __getitem__(self, key: slice) -> list[float]:
pass pass
def __getitem__(self, key: int | slice) -> float | list[float]: def __getitem__(self, key: int | slice) -> float | list[float]:
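The removed `# noqa: D105` markers date from the pydocstyle bug linked in the comment above, where @overload stubs were flagged for missing docstrings. Ruff's implementation of the D rules skips overloads, so RUF100 reports the directives as unused. A sketch of the now-clean shape:

from typing import overload


class Box:
    @overload
    def __getitem__(self, key: int) -> float: ...  # not flagged by D105
    @overload
    def __getitem__(self, key: slice) -> list[float]: ...

    def __getitem__(self, key: int | slice) -> float | list[float]:
        """Only the implementation needs a docstring."""
        raise NotImplementedError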
@ -289,7 +289,7 @@ class Executable:
The completed process. The completed process.
""" """
kwargs['check'] = True kwargs['check'] = True
return subprocess.run((self.name,) + args, **kwargs) return subprocess.run((self.name, *args), **kwargs)
def disambiguate_timestamp(date_str: str, format: str) -> tuple[float, float]: def disambiguate_timestamp(date_str: str, format: str) -> tuple[float, float]:
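The subprocess.run change above applies RUF005 (collection-literal-concatenation): build the sequence by unpacking rather than `+`. The two spellings are equivalent for tuples, and unpacking also generalizes to arbitrary iterables:

args = ('--stats', 'input.tif')
assert ('gdalinfo',) + args == ('gdalinfo', *args)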
@ -547,7 +547,7 @@ def draw_semantic_segmentation_masks(
def rgb_to_mask( def rgb_to_mask(
rgb: np.typing.NDArray[np.uint8], colors: list[tuple[int, int, int]] rgb: np.typing.NDArray[np.uint8], colors: Sequence[tuple[int, int, int]]
) -> np.typing.NDArray[np.uint8]: ) -> np.typing.NDArray[np.uint8]:
"""Converts an RGB colormap mask to a integer mask. """Converts an RGB colormap mask to a integer mask.


@ -5,6 +5,7 @@
import os import os
from collections.abc import Callable from collections.abc import Callable
from typing import ClassVar
import matplotlib.pyplot as plt import matplotlib.pyplot as plt
import numpy as np import numpy as np
@ -55,15 +56,15 @@ class Vaihingen2D(NonGeoDataset):
* https://doi.org/10.5194/isprsannals-I-3-293-2012 * https://doi.org/10.5194/isprsannals-I-3-293-2012
.. versionadded:: 0.2 .. versionadded:: 0.2
""" # noqa: E501 """
filenames = [ filenames = (
'ISPRS_semantic_labeling_Vaihingen.zip', 'ISPRS_semantic_labeling_Vaihingen.zip',
'ISPRS_semantic_labeling_Vaihingen_ground_truth_COMPLETE.zip', 'ISPRS_semantic_labeling_Vaihingen_ground_truth_COMPLETE.zip',
] )
md5s = ['462b8dca7b6fa9eaf729840f0cdfc7f3', '4802dd6326e2727a352fb735be450277'] md5s = ('462b8dca7b6fa9eaf729840f0cdfc7f3', '4802dd6326e2727a352fb735be450277')
image_root = 'top' image_root = 'top'
splits = { splits: ClassVar[dict[str, list[str]]] = {
'train': [ 'train': [
'top_mosaic_09cm_area1.tif', 'top_mosaic_09cm_area1.tif',
'top_mosaic_09cm_area11.tif', 'top_mosaic_09cm_area11.tif',
@ -102,22 +103,22 @@ class Vaihingen2D(NonGeoDataset):
'top_mosaic_09cm_area29.tif', 'top_mosaic_09cm_area29.tif',
], ],
} }
classes = [ classes = (
'Clutter/background', 'Clutter/background',
'Impervious surfaces', 'Impervious surfaces',
'Building', 'Building',
'Low Vegetation', 'Low Vegetation',
'Tree', 'Tree',
'Car', 'Car',
] )
colormap = [ colormap = (
(255, 0, 0), (255, 0, 0),
(255, 255, 255), (255, 255, 255),
(0, 0, 255), (0, 0, 255),
(0, 255, 255), (0, 255, 255),
(0, 255, 0), (0, 255, 0),
(255, 255, 0), (255, 255, 0),
] )
def __init__( def __init__(
self, self,
@ -258,7 +259,7 @@ class Vaihingen2D(NonGeoDataset):
""" """
ncols = 1 ncols = 1
image1 = draw_semantic_segmentation_masks( image1 = draw_semantic_segmentation_masks(
sample['image'][:3], sample['mask'], alpha=alpha, colors=self.colormap sample['image'][:3], sample['mask'], alpha=alpha, colors=list(self.colormap)
) )
if 'prediction' in sample: if 'prediction' in sample:
ncols += 1 ncols += 1
@ -266,7 +267,7 @@ class Vaihingen2D(NonGeoDataset):
sample['image'][:3], sample['image'][:3],
sample['prediction'], sample['prediction'],
alpha=alpha, alpha=alpha,
colors=self.colormap, colors=list(self.colormap),
) )
fig, axs = plt.subplots(ncols=ncols, figsize=(ncols * 10, 10)) fig, axs = plt.subplots(ncols=ncols, figsize=(ncols * 10, 10))
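Since colormap is now a tuple, the plot method converts it at the call boundary; a minimal sketch, assuming the drawing helper is annotated to take a list (as draw_semantic_segmentation_masks is here):

colormap = ((255, 0, 0), (255, 255, 255), (0, 0, 255))
colors = list(colormap)  # tuple -> list where the helper requires one
assert colors[0] == (255, 0, 0)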


@ -5,7 +5,7 @@
import os import os
from collections.abc import Callable from collections.abc import Callable
from typing import Any from typing import Any, ClassVar
import matplotlib.pyplot as plt import matplotlib.pyplot as plt
import numpy as np import numpy as np
@ -158,18 +158,18 @@ class VHR10(NonGeoDataset):
``annotations.json`` file for the "positive" image set ``annotations.json`` file for the "positive" image set
""" """
image_meta = { image_meta: ClassVar[dict[str, str]] = {
'url': 'https://hf.co/datasets/torchgeo/vhr10/resolve/7e7968ad265dadc4494e0ca4a079e0b63dc6f3f8/NWPU%20VHR-10%20dataset.zip', 'url': 'https://hf.co/datasets/torchgeo/vhr10/resolve/7e7968ad265dadc4494e0ca4a079e0b63dc6f3f8/NWPU%20VHR-10%20dataset.zip',
'filename': 'NWPU VHR-10 dataset.zip', 'filename': 'NWPU VHR-10 dataset.zip',
'md5': '6add6751469c12dd8c8d6223064c6c4d', 'md5': '6add6751469c12dd8c8d6223064c6c4d',
} }
target_meta = { target_meta: ClassVar[dict[str, str]] = {
'url': 'https://hf.co/datasets/torchgeo/vhr10/resolve/7e7968ad265dadc4494e0ca4a079e0b63dc6f3f8/annotations.json', 'url': 'https://hf.co/datasets/torchgeo/vhr10/resolve/7e7968ad265dadc4494e0ca4a079e0b63dc6f3f8/annotations.json',
'filename': 'annotations.json', 'filename': 'annotations.json',
'md5': '7c76ec50c17a61bb0514050d20f22c08', 'md5': '7c76ec50c17a61bb0514050d20f22c08',
} }
categories = [ categories = (
'background', 'background',
'airplane', 'airplane',
'ships', 'ships',
@ -181,7 +181,7 @@ class VHR10(NonGeoDataset):
'harbor', 'harbor',
'bridge', 'bridge',
'vehicle', 'vehicle',
] )
def __init__( def __init__(
self, self,


@ -6,7 +6,7 @@
import glob import glob
import json import json
import os import os
from collections.abc import Callable from collections.abc import Callable, Iterable
from typing import Any from typing import Any
import pandas as pd import pandas as pd
@ -53,7 +53,7 @@ class WesternUSALiveFuelMoisture(NonGeoDataset):
label_name = 'percent(t)' label_name = 'percent(t)'
all_variable_names = [ all_variable_names = (
# "date", # "date",
'slope(t)', 'slope(t)',
'elevation(t)', 'elevation(t)',
@ -193,12 +193,12 @@ class WesternUSALiveFuelMoisture(NonGeoDataset):
'vh_vv(t-3)', 'vh_vv(t-3)',
'lat', 'lat',
'lon', 'lon',
] )
def __init__( def __init__(
self, self,
root: Path = 'data', root: Path = 'data',
input_features: list[str] = all_variable_names, input_features: Iterable[str] = all_variable_names,
transforms: Callable[[dict[str, Any]], dict[str, Any]] | None = None, transforms: Callable[[dict[str, Any]], dict[str, Any]] | None = None,
download: bool = False, download: bool = False,
) -> None: ) -> None:
@ -273,7 +273,7 @@ class WesternUSALiveFuelMoisture(NonGeoDataset):
data_rows.append(data_dict) data_rows.append(data_dict)
df = pd.DataFrame(data_rows) df = pd.DataFrame(data_rows)
df = df[self.input_features + [self.label_name]] df = df[[*self.input_features, self.label_name]]
return df return df
def _verify(self) -> None: def _verify(self) -> None:
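input_features is typed Iterable[str] now, and a bare iterable does not support `+`, so the column selection unpacks into a fresh list instead (the RUF005 spelling again). A sketch with made-up data:

import pandas as pd

df = pd.DataFrame({'slope(t)': [0.1], 'lat': [44.0], 'percent(t)': [87.0]})
input_features = ('slope(t)', 'lat')  # any iterable of column names works
df = df[[*input_features, 'percent(t)']]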


@ -6,6 +6,7 @@
import glob import glob
import os import os
from collections.abc import Callable from collections.abc import Callable
from typing import ClassVar
import matplotlib.pyplot as plt import matplotlib.pyplot as plt
import numpy as np import numpy as np
@ -54,7 +55,7 @@ class XView2(NonGeoDataset):
.. versionadded:: 0.2 .. versionadded:: 0.2
""" """
metadata = { metadata: ClassVar[dict[str, dict[str, str]]] = {
'train': { 'train': {
'filename': 'train_images_labels_targets.tar.gz', 'filename': 'train_images_labels_targets.tar.gz',
'md5': 'a20ebbfb7eb3452785b63ad02ffd1e16', 'md5': 'a20ebbfb7eb3452785b63ad02ffd1e16',
@ -66,8 +67,8 @@ class XView2(NonGeoDataset):
'directory': 'test', 'directory': 'test',
}, },
} }
classes = ['background', 'no-damage', 'minor-damage', 'major-damage', 'destroyed'] classes = ('background', 'no-damage', 'minor-damage', 'major-damage', 'destroyed')
colormap = ['green', 'blue', 'orange', 'red'] colormap = ('green', 'blue', 'orange', 'red')
def __init__( def __init__(
self, self,
@ -242,10 +243,16 @@ class XView2(NonGeoDataset):
""" """
ncols = 2 ncols = 2
image1 = draw_semantic_segmentation_masks( image1 = draw_semantic_segmentation_masks(
sample['image'][0], sample['mask'][0], alpha=alpha, colors=self.colormap sample['image'][0],
sample['mask'][0],
alpha=alpha,
colors=list(self.colormap),
) )
image2 = draw_semantic_segmentation_masks( image2 = draw_semantic_segmentation_masks(
sample['image'][1], sample['mask'][1], alpha=alpha, colors=self.colormap sample['image'][1],
sample['mask'][1],
alpha=alpha,
colors=list(self.colormap),
) )
if 'prediction' in sample: # NOTE: this assumes predictions are made for post if 'prediction' in sample: # NOTE: this assumes predictions are made for post
ncols += 1 ncols += 1
@ -253,7 +260,7 @@ class XView2(NonGeoDataset):
sample['image'][1], sample['image'][1],
sample['prediction'], sample['prediction'],
alpha=alpha, alpha=alpha,
colors=self.colormap, colors=list(self.colormap),
) )
fig, axs = plt.subplots(ncols=ncols, figsize=(ncols * 10, 10)) fig, axs = plt.subplots(ncols=ncols, figsize=(ncols * 10, 10))


@ -52,15 +52,15 @@ class ZueriCrop(NonGeoDataset):
* `h5py <https://pypi.org/project/h5py/>`_ to load the dataset * `h5py <https://pypi.org/project/h5py/>`_ to load the dataset
""" """
urls = [ urls = (
'https://polybox.ethz.ch/index.php/s/uXfdr2AcXE3QNB6/download', 'https://polybox.ethz.ch/index.php/s/uXfdr2AcXE3QNB6/download',
'https://raw.githubusercontent.com/0zgur0/multi-stage-convSTAR-network/fa92b5b3cb77f5171c5c3be740cd6e6395cc29b6/labels.csv', # noqa: E501 'https://raw.githubusercontent.com/0zgur0/multi-stage-convSTAR-network/fa92b5b3cb77f5171c5c3be740cd6e6395cc29b6/labels.csv',
] )
md5s = ['1635231df67f3d25f4f1e62c98e221a4', '5118398c7a5bbc246f5f6bb35d8d529b'] md5s = ('1635231df67f3d25f4f1e62c98e221a4', '5118398c7a5bbc246f5f6bb35d8d529b')
filenames = ['ZueriCrop.hdf5', 'labels.csv'] filenames = ('ZueriCrop.hdf5', 'labels.csv')
band_names = ('NIR', 'B03', 'B02', 'B04', 'B05', 'B06', 'B07', 'B11', 'B12') band_names = ('NIR', 'B03', 'B02', 'B04', 'B05', 'B06', 'B07', 'B11', 'B12')
rgb_bands = ['B04', 'B03', 'B02'] rgb_bands = ('B04', 'B03', 'B02')
def __init__( def __init__(
self, self,


@ -8,7 +8,7 @@ import os
from lightning.pytorch.cli import ArgsType, LightningCLI from lightning.pytorch.cli import ArgsType, LightningCLI
# Allows classes to be referenced using only the class name # Allows classes to be referenced using only the class name
import torchgeo.datamodules # noqa: F401 import torchgeo.datamodules
import torchgeo.trainers # noqa: F401 import torchgeo.trainers # noqa: F401
from torchgeo.datamodules import BaseDataModule from torchgeo.datamodules import BaseDataModule
from torchgeo.trainers import BaseTask from torchgeo.trainers import BaseTask


@ -8,7 +8,7 @@ See the following references for design details:
* https://pytorch.org/blog/easily-list-and-initialize-models-with-new-apis-in-torchvision/ * https://pytorch.org/blog/easily-list-and-initialize-models-with-new-apis-in-torchvision/
* https://pytorch.org/vision/stable/models.html * https://pytorch.org/vision/stable/models.html
* https://github.com/pytorch/vision/blob/main/torchvision/models/_api.py * https://github.com/pytorch/vision/blob/main/torchvision/models/_api.py
""" # noqa: E501 """
from collections.abc import Callable from collections.abc import Callable
from typing import Any from typing import Any


@ -384,7 +384,7 @@ class DOFABase16_Weights(WeightsEnum): # type: ignore[misc]
""" """
DOFA_MAE = Weights( DOFA_MAE = Weights(
url='https://hf.co/torchgeo/dofa/resolve/ade8745c5ec6eddfe15d8c03421e8cb8f21e66ff/dofa_base_patch16_224-7cc0f413.pth', # noqa: E501 url='https://hf.co/torchgeo/dofa/resolve/ade8745c5ec6eddfe15d8c03421e8cb8f21e66ff/dofa_base_patch16_224-7cc0f413.pth',
transforms=_dofa_transforms, transforms=_dofa_transforms,
meta={ meta={
'dataset': 'SatlasPretrain, Five-Billion-Pixels, HySpecNet-11k', 'dataset': 'SatlasPretrain, Five-Billion-Pixels, HySpecNet-11k',
@ -403,7 +403,7 @@ class DOFALarge16_Weights(WeightsEnum): # type: ignore[misc]
""" """
DOFA_MAE = Weights( DOFA_MAE = Weights(
url='https://hf.co/torchgeo/dofa/resolve/ade8745c5ec6eddfe15d8c03421e8cb8f21e66ff/dofa_large_patch16_224-fbd47fa9.pth', # noqa: E501 url='https://hf.co/torchgeo/dofa/resolve/ade8745c5ec6eddfe15d8c03421e8cb8f21e66ff/dofa_large_patch16_224-fbd47fa9.pth',
transforms=_dofa_transforms, transforms=_dofa_transforms,
meta={ meta={
'dataset': 'SatlasPretrain, Five-Billion-Pixels, HySpecNet-11k', 'dataset': 'SatlasPretrain, Five-Billion-Pixels, HySpecNet-11k',


@ -140,7 +140,7 @@ class RCF(Module):
a numpy array of size (N, C, H, W) containing the normalized patches a numpy array of size (N, C, H, W) containing the normalized patches
.. versionadded:: 0.5 .. versionadded:: 0.5
""" # noqa: E501 """
n_patches = patches.shape[0] n_patches = patches.shape[0]
orig_shape = patches.shape orig_shape = patches.shape
patches = patches.reshape(patches.shape[0], -1) patches = patches.reshape(patches.shape[0], -1)


@ -11,8 +11,8 @@ import torch
from timm.models import ResNet from timm.models import ResNet
from torchvision.models._api import Weights, WeightsEnum from torchvision.models._api import Weights, WeightsEnum
# https://github.com/zhu-xlab/SSL4EO-S12/blob/d2868adfada65e40910bfcedfc49bc3b20df2248/src/benchmark/transfer_classification/linear_BE_moco.py#L167 # noqa: E501 # https://github.com/zhu-xlab/SSL4EO-S12/blob/d2868adfada65e40910bfcedfc49bc3b20df2248/src/benchmark/transfer_classification/linear_BE_moco.py#L167
# https://github.com/zhu-xlab/SSL4EO-S12/blob/d2868adfada65e40910bfcedfc49bc3b20df2248/src/benchmark/transfer_classification/datasets/EuroSat/eurosat_dataset.py#L97 # noqa: E501 # https://github.com/zhu-xlab/SSL4EO-S12/blob/d2868adfada65e40910bfcedfc49bc3b20df2248/src/benchmark/transfer_classification/datasets/EuroSat/eurosat_dataset.py#L97
# Normalization either by 10K or channel-wise with band statistics # Normalization either by 10K or channel-wise with band statistics
_zhu_xlab_transforms = K.AugmentationSequential( _zhu_xlab_transforms = K.AugmentationSequential(
K.Resize(256), K.Resize(256),
@ -22,7 +22,7 @@ _zhu_xlab_transforms = K.AugmentationSequential(
) )
# Normalization only available for RGB dataset, defined here: # Normalization only available for RGB dataset, defined here:
# https://github.com/ServiceNow/seasonal-contrast/blob/8285173ec205b64bc3e53b880344dd6c3f79fa7a/datasets/seco_dataset.py # noqa: E501 # https://github.com/ServiceNow/seasonal-contrast/blob/8285173ec205b64bc3e53b880344dd6c3f79fa7a/datasets/seco_dataset.py
_min = torch.tensor([3, 2, 0]) _min = torch.tensor([3, 2, 0])
_max = torch.tensor([88, 103, 129]) _max = torch.tensor([88, 103, 129])
_mean = torch.tensor([0.485, 0.456, 0.406]) _mean = torch.tensor([0.485, 0.456, 0.406])
@ -37,7 +37,7 @@ _seco_transforms = K.AugmentationSequential(
) )
# Normalization only available for RGB dataset, defined here: # Normalization only available for RGB dataset, defined here:
# https://github.com/sustainlab-group/geography-aware-ssl/blob/main/moco_fmow/main_moco_geo%2Btp.py#L287 # noqa: E501 # https://github.com/sustainlab-group/geography-aware-ssl/blob/main/moco_fmow/main_moco_geo%2Btp.py#L287
_mean = torch.tensor([0.485, 0.456, 0.406]) _mean = torch.tensor([0.485, 0.456, 0.406])
_std = torch.tensor([0.229, 0.224, 0.225]) _std = torch.tensor([0.229, 0.224, 0.225])
_gassl_transforms = K.AugmentationSequential( _gassl_transforms = K.AugmentationSequential(
@ -47,7 +47,7 @@ _gassl_transforms = K.AugmentationSequential(
data_keys=None, data_keys=None,
) )
# https://github.com/microsoft/torchgeo/blob/8b53304d42c269f9001cb4e861a126dc4b462606/torchgeo/datamodules/ssl4eo_benchmark.py#L43 # noqa: E501 # https://github.com/microsoft/torchgeo/blob/8b53304d42c269f9001cb4e861a126dc4b462606/torchgeo/datamodules/ssl4eo_benchmark.py#L43
_ssl4eo_l_transforms = K.AugmentationSequential( _ssl4eo_l_transforms = K.AugmentationSequential(
K.Normalize(mean=torch.tensor(0), std=torch.tensor(255)), K.Normalize(mean=torch.tensor(0), std=torch.tensor(255)),
K.CenterCrop((224, 224)), K.CenterCrop((224, 224)),
@ -70,7 +70,7 @@ class ResNet18_Weights(WeightsEnum): # type: ignore[misc]
""" """
LANDSAT_TM_TOA_MOCO = Weights( LANDSAT_TM_TOA_MOCO = Weights(
url='https://hf.co/torchgeo/ssl4eo_landsat/resolve/1c88bb51b6e17a21dde5230738fa38b74bd74f76/resnet18_landsat_tm_toa_moco-1c691b4f.pth', # noqa: E501 url='https://hf.co/torchgeo/ssl4eo_landsat/resolve/1c88bb51b6e17a21dde5230738fa38b74bd74f76/resnet18_landsat_tm_toa_moco-1c691b4f.pth',
transforms=_ssl4eo_l_transforms, transforms=_ssl4eo_l_transforms,
meta={ meta={
'dataset': 'SSL4EO-L', 'dataset': 'SSL4EO-L',
@ -83,7 +83,7 @@ class ResNet18_Weights(WeightsEnum): # type: ignore[misc]
) )
LANDSAT_TM_TOA_SIMCLR = Weights( LANDSAT_TM_TOA_SIMCLR = Weights(
url='https://hf.co/torchgeo/ssl4eo_landsat/resolve/1c88bb51b6e17a21dde5230738fa38b74bd74f76/resnet18_landsat_tm_toa_simclr-d2d38ace.pth', # noqa: E501 url='https://hf.co/torchgeo/ssl4eo_landsat/resolve/1c88bb51b6e17a21dde5230738fa38b74bd74f76/resnet18_landsat_tm_toa_simclr-d2d38ace.pth',
transforms=_ssl4eo_l_transforms, transforms=_ssl4eo_l_transforms,
meta={ meta={
'dataset': 'SSL4EO-L', 'dataset': 'SSL4EO-L',
@ -96,7 +96,7 @@ class ResNet18_Weights(WeightsEnum): # type: ignore[misc]
) )
LANDSAT_ETM_TOA_MOCO = Weights( LANDSAT_ETM_TOA_MOCO = Weights(
url='https://hf.co/torchgeo/ssl4eo_landsat/resolve/1c88bb51b6e17a21dde5230738fa38b74bd74f76/resnet18_landsat_etm_toa_moco-bb88689c.pth', # noqa: E501 url='https://hf.co/torchgeo/ssl4eo_landsat/resolve/1c88bb51b6e17a21dde5230738fa38b74bd74f76/resnet18_landsat_etm_toa_moco-bb88689c.pth',
transforms=_ssl4eo_l_transforms, transforms=_ssl4eo_l_transforms,
meta={ meta={
'dataset': 'SSL4EO-L', 'dataset': 'SSL4EO-L',
@ -109,7 +109,7 @@ class ResNet18_Weights(WeightsEnum): # type: ignore[misc]
) )
LANDSAT_ETM_TOA_SIMCLR = Weights( LANDSAT_ETM_TOA_SIMCLR = Weights(
url='https://hf.co/torchgeo/ssl4eo_landsat/resolve/1c88bb51b6e17a21dde5230738fa38b74bd74f76/resnet18_landsat_etm_toa_simclr-4d813f79.pth', # noqa: E501 url='https://hf.co/torchgeo/ssl4eo_landsat/resolve/1c88bb51b6e17a21dde5230738fa38b74bd74f76/resnet18_landsat_etm_toa_simclr-4d813f79.pth',
transforms=_ssl4eo_l_transforms, transforms=_ssl4eo_l_transforms,
meta={ meta={
'dataset': 'SSL4EO-L', 'dataset': 'SSL4EO-L',
@ -122,7 +122,7 @@ class ResNet18_Weights(WeightsEnum): # type: ignore[misc]
) )
LANDSAT_ETM_SR_MOCO = Weights( LANDSAT_ETM_SR_MOCO = Weights(
url='https://hf.co/torchgeo/ssl4eo_landsat/resolve/1c88bb51b6e17a21dde5230738fa38b74bd74f76/resnet18_landsat_etm_sr_moco-4f078acd.pth', # noqa: E501 url='https://hf.co/torchgeo/ssl4eo_landsat/resolve/1c88bb51b6e17a21dde5230738fa38b74bd74f76/resnet18_landsat_etm_sr_moco-4f078acd.pth',
transforms=_ssl4eo_l_transforms, transforms=_ssl4eo_l_transforms,
meta={ meta={
'dataset': 'SSL4EO-L', 'dataset': 'SSL4EO-L',
@ -135,7 +135,7 @@ class ResNet18_Weights(WeightsEnum): # type: ignore[misc]
) )
LANDSAT_ETM_SR_SIMCLR = Weights( LANDSAT_ETM_SR_SIMCLR = Weights(
url='https://hf.co/torchgeo/ssl4eo_landsat/resolve/1c88bb51b6e17a21dde5230738fa38b74bd74f76/resnet18_landsat_etm_sr_simclr-8e8543b4.pth', # noqa: E501 url='https://hf.co/torchgeo/ssl4eo_landsat/resolve/1c88bb51b6e17a21dde5230738fa38b74bd74f76/resnet18_landsat_etm_sr_simclr-8e8543b4.pth',
transforms=_ssl4eo_l_transforms, transforms=_ssl4eo_l_transforms,
meta={ meta={
'dataset': 'SSL4EO-L', 'dataset': 'SSL4EO-L',
@ -148,7 +148,7 @@ class ResNet18_Weights(WeightsEnum): # type: ignore[misc]
) )
LANDSAT_OLI_TIRS_TOA_MOCO = Weights( LANDSAT_OLI_TIRS_TOA_MOCO = Weights(
url='https://hf.co/torchgeo/ssl4eo_landsat/resolve/1c88bb51b6e17a21dde5230738fa38b74bd74f76/resnet18_landsat_oli_tirs_toa_moco-a3002f51.pth', # noqa: E501 url='https://hf.co/torchgeo/ssl4eo_landsat/resolve/1c88bb51b6e17a21dde5230738fa38b74bd74f76/resnet18_landsat_oli_tirs_toa_moco-a3002f51.pth',
transforms=_ssl4eo_l_transforms, transforms=_ssl4eo_l_transforms,
meta={ meta={
'dataset': 'SSL4EO-L', 'dataset': 'SSL4EO-L',
@ -161,7 +161,7 @@ class ResNet18_Weights(WeightsEnum): # type: ignore[misc]
) )
LANDSAT_OLI_TIRS_TOA_SIMCLR = Weights( LANDSAT_OLI_TIRS_TOA_SIMCLR = Weights(
url='https://hf.co/torchgeo/ssl4eo_landsat/resolve/1c88bb51b6e17a21dde5230738fa38b74bd74f76/resnet18_landsat_oli_tirs_toa_simclr-b0635cc6.pth', # noqa: E501 url='https://hf.co/torchgeo/ssl4eo_landsat/resolve/1c88bb51b6e17a21dde5230738fa38b74bd74f76/resnet18_landsat_oli_tirs_toa_simclr-b0635cc6.pth',
transforms=_ssl4eo_l_transforms, transforms=_ssl4eo_l_transforms,
meta={ meta={
'dataset': 'SSL4EO-L', 'dataset': 'SSL4EO-L',
@ -174,7 +174,7 @@ class ResNet18_Weights(WeightsEnum): # type: ignore[misc]
) )
LANDSAT_OLI_SR_MOCO = Weights( LANDSAT_OLI_SR_MOCO = Weights(
url='https://hf.co/torchgeo/ssl4eo_landsat/resolve/1c88bb51b6e17a21dde5230738fa38b74bd74f76/resnet18_landsat_oli_sr_moco-660e82ed.pth', # noqa: E501 url='https://hf.co/torchgeo/ssl4eo_landsat/resolve/1c88bb51b6e17a21dde5230738fa38b74bd74f76/resnet18_landsat_oli_sr_moco-660e82ed.pth',
transforms=_ssl4eo_l_transforms, transforms=_ssl4eo_l_transforms,
meta={ meta={
'dataset': 'SSL4EO-L', 'dataset': 'SSL4EO-L',
@ -187,7 +187,7 @@ class ResNet18_Weights(WeightsEnum): # type: ignore[misc]
) )
LANDSAT_OLI_SR_SIMCLR = Weights( LANDSAT_OLI_SR_SIMCLR = Weights(
url='https://hf.co/torchgeo/ssl4eo_landsat/resolve/1c88bb51b6e17a21dde5230738fa38b74bd74f76/resnet18_landsat_oli_sr_simclr-7bced5be.pth', # noqa: E501 url='https://hf.co/torchgeo/ssl4eo_landsat/resolve/1c88bb51b6e17a21dde5230738fa38b74bd74f76/resnet18_landsat_oli_sr_simclr-7bced5be.pth',
transforms=_ssl4eo_l_transforms, transforms=_ssl4eo_l_transforms,
meta={ meta={
'dataset': 'SSL4EO-L', 'dataset': 'SSL4EO-L',
@ -200,7 +200,7 @@ class ResNet18_Weights(WeightsEnum): # type: ignore[misc]
) )
SENTINEL2_ALL_MOCO = Weights( SENTINEL2_ALL_MOCO = Weights(
url='https://hf.co/torchgeo/resnet18_sentinel2_all_moco/resolve/5b8cddc9a14f3844350b7f40b85bcd32aed75918/resnet18_sentinel2_all_moco-59bfdff9.pth', # noqa: E501 url='https://hf.co/torchgeo/resnet18_sentinel2_all_moco/resolve/5b8cddc9a14f3844350b7f40b85bcd32aed75918/resnet18_sentinel2_all_moco-59bfdff9.pth',
transforms=_zhu_xlab_transforms, transforms=_zhu_xlab_transforms,
meta={ meta={
'dataset': 'SSL4EO-S12', 'dataset': 'SSL4EO-S12',
@ -213,7 +213,7 @@ class ResNet18_Weights(WeightsEnum): # type: ignore[misc]
) )
SENTINEL2_RGB_MOCO = Weights( SENTINEL2_RGB_MOCO = Weights(
url='https://hf.co/torchgeo/resnet18_sentinel2_rgb_moco/resolve/e1c032e7785fd0625224cdb6699aa138bb304eec/resnet18_sentinel2_rgb_moco-e3a335e3.pth', # noqa: E501 url='https://hf.co/torchgeo/resnet18_sentinel2_rgb_moco/resolve/e1c032e7785fd0625224cdb6699aa138bb304eec/resnet18_sentinel2_rgb_moco-e3a335e3.pth',
transforms=_zhu_xlab_transforms, transforms=_zhu_xlab_transforms,
meta={ meta={
'dataset': 'SSL4EO-S12', 'dataset': 'SSL4EO-S12',
@ -226,7 +226,7 @@ class ResNet18_Weights(WeightsEnum): # type: ignore[misc]
) )
SENTINEL2_RGB_SECO = Weights( SENTINEL2_RGB_SECO = Weights(
url='https://hf.co/torchgeo/resnet18_sentinel2_rgb_seco/resolve/f8dcee692cf7142163b55a5c197d981fe0e717a0/resnet18_sentinel2_rgb_seco-cefca942.pth', # noqa: E501 url='https://hf.co/torchgeo/resnet18_sentinel2_rgb_seco/resolve/f8dcee692cf7142163b55a5c197d981fe0e717a0/resnet18_sentinel2_rgb_seco-cefca942.pth',
transforms=_seco_transforms, transforms=_seco_transforms,
meta={ meta={
'dataset': 'SeCo Dataset', 'dataset': 'SeCo Dataset',
@ -249,7 +249,7 @@ class ResNet50_Weights(WeightsEnum): # type: ignore[misc]
""" """
FMOW_RGB_GASSL = Weights( FMOW_RGB_GASSL = Weights(
url='https://hf.co/torchgeo/resnet50_fmow_rgb_gassl/resolve/fe8a91026cf9104f1e884316b8e8772d7af9052c/resnet50_fmow_rgb_gassl-da43d987.pth', # noqa: E501 url='https://hf.co/torchgeo/resnet50_fmow_rgb_gassl/resolve/fe8a91026cf9104f1e884316b8e8772d7af9052c/resnet50_fmow_rgb_gassl-da43d987.pth',
transforms=_gassl_transforms, transforms=_gassl_transforms,
meta={ meta={
'dataset': 'fMoW Dataset', 'dataset': 'fMoW Dataset',
@ -262,7 +262,7 @@ class ResNet50_Weights(WeightsEnum): # type: ignore[misc]
) )
LANDSAT_TM_TOA_MOCO = Weights( LANDSAT_TM_TOA_MOCO = Weights(
url='https://hf.co/torchgeo/ssl4eo_landsat/resolve/1c88bb51b6e17a21dde5230738fa38b74bd74f76/resnet50_landsat_tm_toa_moco-ba1ce753.pth', # noqa: E501 url='https://hf.co/torchgeo/ssl4eo_landsat/resolve/1c88bb51b6e17a21dde5230738fa38b74bd74f76/resnet50_landsat_tm_toa_moco-ba1ce753.pth',
transforms=_ssl4eo_l_transforms, transforms=_ssl4eo_l_transforms,
meta={ meta={
'dataset': 'SSL4EO-L', 'dataset': 'SSL4EO-L',
@ -275,7 +275,7 @@ class ResNet50_Weights(WeightsEnum): # type: ignore[misc]
) )
LANDSAT_TM_TOA_SIMCLR = Weights( LANDSAT_TM_TOA_SIMCLR = Weights(
url='https://hf.co/torchgeo/ssl4eo_landsat/resolve/1c88bb51b6e17a21dde5230738fa38b74bd74f76/resnet50_landsat_tm_toa_simclr-a1c93432.pth', # noqa: E501 url='https://hf.co/torchgeo/ssl4eo_landsat/resolve/1c88bb51b6e17a21dde5230738fa38b74bd74f76/resnet50_landsat_tm_toa_simclr-a1c93432.pth',
transforms=_ssl4eo_l_transforms, transforms=_ssl4eo_l_transforms,
meta={ meta={
'dataset': 'SSL4EO-L', 'dataset': 'SSL4EO-L',
@ -288,7 +288,7 @@ class ResNet50_Weights(WeightsEnum): # type: ignore[misc]
) )
LANDSAT_ETM_TOA_MOCO = Weights( LANDSAT_ETM_TOA_MOCO = Weights(
url='https://hf.co/torchgeo/ssl4eo_landsat/resolve/1c88bb51b6e17a21dde5230738fa38b74bd74f76/resnet50_landsat_etm_toa_moco-e9a84d5a.pth', # noqa: E501 url='https://hf.co/torchgeo/ssl4eo_landsat/resolve/1c88bb51b6e17a21dde5230738fa38b74bd74f76/resnet50_landsat_etm_toa_moco-e9a84d5a.pth',
transforms=_ssl4eo_l_transforms, transforms=_ssl4eo_l_transforms,
meta={ meta={
'dataset': 'SSL4EO-L', 'dataset': 'SSL4EO-L',
@ -301,7 +301,7 @@ class ResNet50_Weights(WeightsEnum): # type: ignore[misc]
) )
LANDSAT_ETM_TOA_SIMCLR = Weights( LANDSAT_ETM_TOA_SIMCLR = Weights(
url='https://hf.co/torchgeo/ssl4eo_landsat/resolve/1c88bb51b6e17a21dde5230738fa38b74bd74f76/resnet50_landsat_etm_toa_simclr-70b5575f.pth', # noqa: E501 url='https://hf.co/torchgeo/ssl4eo_landsat/resolve/1c88bb51b6e17a21dde5230738fa38b74bd74f76/resnet50_landsat_etm_toa_simclr-70b5575f.pth',
transforms=_ssl4eo_l_transforms, transforms=_ssl4eo_l_transforms,
meta={ meta={
'dataset': 'SSL4EO-L', 'dataset': 'SSL4EO-L',
@ -314,7 +314,7 @@ class ResNet50_Weights(WeightsEnum): # type: ignore[misc]
) )
LANDSAT_ETM_SR_MOCO = Weights( LANDSAT_ETM_SR_MOCO = Weights(
url='https://hf.co/torchgeo/ssl4eo_landsat/resolve/1c88bb51b6e17a21dde5230738fa38b74bd74f76/resnet50_landsat_etm_sr_moco-1266cde3.pth', # noqa: E501 url='https://hf.co/torchgeo/ssl4eo_landsat/resolve/1c88bb51b6e17a21dde5230738fa38b74bd74f76/resnet50_landsat_etm_sr_moco-1266cde3.pth',
transforms=_ssl4eo_l_transforms, transforms=_ssl4eo_l_transforms,
meta={ meta={
'dataset': 'SSL4EO-L', 'dataset': 'SSL4EO-L',
@ -327,7 +327,7 @@ class ResNet50_Weights(WeightsEnum): # type: ignore[misc]
) )
LANDSAT_ETM_SR_SIMCLR = Weights( LANDSAT_ETM_SR_SIMCLR = Weights(
url='https://hf.co/torchgeo/ssl4eo_landsat/resolve/1c88bb51b6e17a21dde5230738fa38b74bd74f76/resnet50_landsat_etm_sr_simclr-e5d185d7.pth', # noqa: E501 url='https://hf.co/torchgeo/ssl4eo_landsat/resolve/1c88bb51b6e17a21dde5230738fa38b74bd74f76/resnet50_landsat_etm_sr_simclr-e5d185d7.pth',
transforms=_ssl4eo_l_transforms, transforms=_ssl4eo_l_transforms,
meta={ meta={
'dataset': 'SSL4EO-L', 'dataset': 'SSL4EO-L',
@ -340,7 +340,7 @@ class ResNet50_Weights(WeightsEnum): # type: ignore[misc]
) )
LANDSAT_OLI_TIRS_TOA_MOCO = Weights( LANDSAT_OLI_TIRS_TOA_MOCO = Weights(
url='https://hf.co/torchgeo/ssl4eo_landsat/resolve/1c88bb51b6e17a21dde5230738fa38b74bd74f76/resnet50_landsat_oli_tirs_toa_moco-de7f5e0f.pth', # noqa: E501 url='https://hf.co/torchgeo/ssl4eo_landsat/resolve/1c88bb51b6e17a21dde5230738fa38b74bd74f76/resnet50_landsat_oli_tirs_toa_moco-de7f5e0f.pth',
transforms=_ssl4eo_l_transforms, transforms=_ssl4eo_l_transforms,
meta={ meta={
'dataset': 'SSL4EO-L', 'dataset': 'SSL4EO-L',
@ -353,7 +353,7 @@ class ResNet50_Weights(WeightsEnum): # type: ignore[misc]
) )
LANDSAT_OLI_TIRS_TOA_SIMCLR = Weights( LANDSAT_OLI_TIRS_TOA_SIMCLR = Weights(
-        url='https://hf.co/torchgeo/ssl4eo_landsat/resolve/1c88bb51b6e17a21dde5230738fa38b74bd74f76/resnet50_landsat_oli_tirs_toa_simclr-030cebfe.pth',  # noqa: E501
+        url='https://hf.co/torchgeo/ssl4eo_landsat/resolve/1c88bb51b6e17a21dde5230738fa38b74bd74f76/resnet50_landsat_oli_tirs_toa_simclr-030cebfe.pth',
         transforms=_ssl4eo_l_transforms,
         meta={
             'dataset': 'SSL4EO-L',
@@ -366,7 +366,7 @@ class ResNet50_Weights(WeightsEnum):  # type: ignore[misc]
     )
     LANDSAT_OLI_SR_MOCO = Weights(
-        url='https://hf.co/torchgeo/ssl4eo_landsat/resolve/1c88bb51b6e17a21dde5230738fa38b74bd74f76/resnet50_landsat_oli_sr_moco-ff580dad.pth',  # noqa: E501
+        url='https://hf.co/torchgeo/ssl4eo_landsat/resolve/1c88bb51b6e17a21dde5230738fa38b74bd74f76/resnet50_landsat_oli_sr_moco-ff580dad.pth',
         transforms=_ssl4eo_l_transforms,
         meta={
             'dataset': 'SSL4EO-L',
@@ -379,7 +379,7 @@ class ResNet50_Weights(WeightsEnum):  # type: ignore[misc]
     )
     LANDSAT_OLI_SR_SIMCLR = Weights(
-        url='https://hf.co/torchgeo/ssl4eo_landsat/resolve/1c88bb51b6e17a21dde5230738fa38b74bd74f76/resnet50_landsat_oli_sr_simclr-94f78913.pth',  # noqa: E501
+        url='https://hf.co/torchgeo/ssl4eo_landsat/resolve/1c88bb51b6e17a21dde5230738fa38b74bd74f76/resnet50_landsat_oli_sr_simclr-94f78913.pth',
         transforms=_ssl4eo_l_transforms,
         meta={
             'dataset': 'SSL4EO-L',
@@ -392,7 +392,7 @@ class ResNet50_Weights(WeightsEnum):  # type: ignore[misc]
     )
     SENTINEL1_ALL_MOCO = Weights(
-        url='https://hf.co/torchgeo/resnet50_sentinel1_all_moco/resolve/e79862c667853c10a709bdd77ea8ffbad0e0f1cf/resnet50_sentinel1_all_moco-906e4356.pth',  # noqa: E501
+        url='https://hf.co/torchgeo/resnet50_sentinel1_all_moco/resolve/e79862c667853c10a709bdd77ea8ffbad0e0f1cf/resnet50_sentinel1_all_moco-906e4356.pth',
         transforms=_zhu_xlab_transforms,
         meta={
             'dataset': 'SSL4EO-S12',
@@ -405,7 +405,7 @@ class ResNet50_Weights(WeightsEnum):  # type: ignore[misc]
     )
     SENTINEL2_ALL_DINO = Weights(
-        url='https://hf.co/torchgeo/resnet50_sentinel2_all_dino/resolve/d7f14bf5530d70ac69d763e58e77e44dbecfec7c/resnet50_sentinel2_all_dino-d6c330e9.pth',  # noqa: E501
+        url='https://hf.co/torchgeo/resnet50_sentinel2_all_dino/resolve/d7f14bf5530d70ac69d763e58e77e44dbecfec7c/resnet50_sentinel2_all_dino-d6c330e9.pth',
         transforms=_zhu_xlab_transforms,
         meta={
             'dataset': 'SSL4EO-S12',
@@ -418,7 +418,7 @@ class ResNet50_Weights(WeightsEnum):  # type: ignore[misc]
     )
     SENTINEL2_ALL_MOCO = Weights(
-        url='https://hf.co/torchgeo/resnet50_sentinel2_all_moco/resolve/da4f3c9dbe09272eb902f3b37f46635fa4726879/resnet50_sentinel2_all_moco-df8b932e.pth',  # noqa: E501
+        url='https://hf.co/torchgeo/resnet50_sentinel2_all_moco/resolve/da4f3c9dbe09272eb902f3b37f46635fa4726879/resnet50_sentinel2_all_moco-df8b932e.pth',
         transforms=_zhu_xlab_transforms,
         meta={
             'dataset': 'SSL4EO-S12',
@@ -431,7 +431,7 @@ class ResNet50_Weights(WeightsEnum):  # type: ignore[misc]
     )
     SENTINEL2_RGB_MOCO = Weights(
-        url='https://hf.co/torchgeo/resnet50_sentinel2_rgb_moco/resolve/efd9723b59a88e9dc1420dc1e96afb25b0630a3c/resnet50_sentinel2_rgb_moco-2b57ba8b.pth',  # noqa: E501
+        url='https://hf.co/torchgeo/resnet50_sentinel2_rgb_moco/resolve/efd9723b59a88e9dc1420dc1e96afb25b0630a3c/resnet50_sentinel2_rgb_moco-2b57ba8b.pth',
         transforms=_zhu_xlab_transforms,
         meta={
             'dataset': 'SSL4EO-S12',
@@ -444,7 +444,7 @@ class ResNet50_Weights(WeightsEnum):  # type: ignore[misc]
     )
     SENTINEL2_RGB_SECO = Weights(
-        url='https://hf.co/torchgeo/resnet50_sentinel2_rgb_seco/resolve/fbd07b02a8edb8fc1035f7957160deed4321c145/resnet50_sentinel2_rgb_seco-018bf397.pth',  # noqa: E501
+        url='https://hf.co/torchgeo/resnet50_sentinel2_rgb_seco/resolve/fbd07b02a8edb8fc1035f7957160deed4321c145/resnet50_sentinel2_rgb_seco-018bf397.pth',
         transforms=_seco_transforms,
         meta={
             'dataset': 'SeCo Dataset',
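
Each Weights entry above only bundles a checkpoint URL, a preprocessing pipeline, and a meta dict. For reference, a minimal usage sketch, assuming torchgeo's resnet50 factory and a 13-band Sentinel-2 input, and assuming the Kornia pipeline accepts a sample dict (as its data_keys=None construction suggests); the tensor is random placeholder data:

import torch
from torchgeo.models import ResNet50_Weights, resnet50

# The enum value carries url, transforms, and meta for one checkpoint.
weights = ResNet50_Weights.SENTINEL2_ALL_MOCO
model = resnet50(weights=weights)
model.eval()

# Apply the bundled normalization, then run a dummy 13-band patch through.
batch = weights.transforms({'image': torch.rand(1, 13, 224, 224)})
with torch.no_grad():
    features = model(batch['image'])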
@@ -12,20 +12,20 @@
 from kornia.contrib import Lambda
 from torchvision.models import SwinTransformer
 from torchvision.models._api import Weights, WeightsEnum
-# https://github.com/allenai/satlas/blob/bcaa968da5395f675d067613e02613a344e81415/satlas/cmd/model/train.py#L42  # noqa: E501
+# https://github.com/allenai/satlas/blob/bcaa968da5395f675d067613e02613a344e81415/satlas/cmd/model/train.py#L42
 # Satlas uses the TCI product for Sentinel-2 RGB, which is in the range (0, 255).
-# See details: https://github.com/allenai/satlas/blob/main/Normalization.md#sentinel-2-images.  # noqa: E501
+# See details: https://github.com/allenai/satlas/blob/main/Normalization.md#sentinel-2-images.
-# Satlas Sentinel-1 and RGB Sentinel-2 and NAIP imagery is uint8 and is normalized to (0, 1) by dividing by 255.  # noqa: E501
+# Satlas Sentinel-1 and RGB Sentinel-2 and NAIP imagery is uint8 and is normalized to (0, 1) by dividing by 255.
 _satlas_transforms = K.AugmentationSequential(
     K.Normalize(mean=torch.tensor(0), std=torch.tensor(255)), data_keys=None
 )
 # Satlas uses the TCI product for Sentinel-2 RGB, which is in the range (0, 255).
-# See details: https://github.com/allenai/satlas/blob/main/Normalization.md#sentinel-2-images.  # noqa: E501
+# See details: https://github.com/allenai/satlas/blob/main/Normalization.md#sentinel-2-images.
-# Satlas Sentinel-2 multispectral imagery has first 3 bands divided by 255 and the following 6 bands by 8160, both clipped to (0, 1).  # noqa: E501
+# Satlas Sentinel-2 multispectral imagery has first 3 bands divided by 255 and the following 6 bands by 8160, both clipped to (0, 1).
 _std = torch.tensor(
     [255.0, 255.0, 255.0, 8160.0, 8160.0, 8160.0, 8160.0, 8160.0, 8160.0]
-)  # noqa: E501
+)
 _mean = torch.zeros_like(_std)
 _sentinel2_ms_satlas_transforms = K.AugmentationSequential(
     K.Normalize(mean=_mean, std=_std),
@@ -33,7 +33,7 @@ _sentinel2_ms_satlas_transforms = K.AugmentationSequential(
     data_keys=None,
 )
-# Satlas Landsat imagery is 16-bit, normalized by clipping some pixel N with (N-4000)/16320 to (0, 1).  # noqa: E501
+# Satlas Landsat imagery is 16-bit, normalized by clipping some pixel N with (N-4000)/16320 to (0, 1).
 _landsat_satlas_transforms = K.AugmentationSequential(
     K.Normalize(mean=torch.tensor(4000), std=torch.tensor(16320)),
     K.ImageSequential(Lambda(lambda x: torch.clamp(x, min=0.0, max=1.0))),
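
The three Satlas pipelines above are plain affine rescalings, optionally followed by a clamp. A worked sketch of the Landsat rule, (N - 4000) / 16320 clipped to (0, 1), using ordinary tensors with illustrative values:

import torch

raw = torch.tensor([2000.0, 4000.0, 12160.0, 30000.0])  # 16-bit pixel values
scaled = (raw - 4000.0) / 16320.0        # what K.Normalize(4000, 16320) computes
clipped = torch.clamp(scaled, 0.0, 1.0)  # what the Lambda clamp adds
print(clipped)  # tensor([0.0000, 0.0000, 0.5000, 1.0000])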
@@ -56,7 +56,7 @@ class Swin_V2_B_Weights(WeightsEnum):  # type: ignore[misc]
     """
     NAIP_RGB_SI_SATLAS = Weights(
-        url='https://hf.co/allenai/satlas-pretrain/resolve/daa578a4be36573d9791bf51dcd0420b8dc75732/aerial_swinb_si.pth',  # noqa: E501
+        url='https://hf.co/allenai/satlas-pretrain/resolve/daa578a4be36573d9791bf51dcd0420b8dc75732/aerial_swinb_si.pth',
         transforms=_satlas_transforms,
         meta={
             'dataset': 'Satlas',
@@ -68,7 +68,7 @@ class Swin_V2_B_Weights(WeightsEnum):  # type: ignore[misc]
     )
     SENTINEL2_RGB_SI_SATLAS = Weights(
-        url='https://hf.co/allenai/satlas-pretrain/resolve/daa578a4be36573d9791bf51dcd0420b8dc75732/sentinel2_swinb_si_rgb.pth',  # noqa: E501
+        url='https://hf.co/allenai/satlas-pretrain/resolve/daa578a4be36573d9791bf51dcd0420b8dc75732/sentinel2_swinb_si_rgb.pth',
         transforms=_satlas_transforms,
         meta={
             'dataset': 'Satlas',
@@ -80,7 +80,7 @@ class Swin_V2_B_Weights(WeightsEnum):  # type: ignore[misc]
     )
     SENTINEL2_MS_SI_SATLAS = Weights(
-        url='https://hf.co/allenai/satlas-pretrain/resolve/daa578a4be36573d9791bf51dcd0420b8dc75732/sentinel2_swinb_si_ms.pth',  # noqa: E501
+        url='https://hf.co/allenai/satlas-pretrain/resolve/daa578a4be36573d9791bf51dcd0420b8dc75732/sentinel2_swinb_si_ms.pth',
         transforms=_sentinel2_ms_satlas_transforms,
         meta={
             'dataset': 'Satlas',
@@ -93,7 +93,7 @@ class Swin_V2_B_Weights(WeightsEnum):  # type: ignore[misc]
     )
     SENTINEL1_SI_SATLAS = Weights(
-        url='https://hf.co/allenai/satlas-pretrain/resolve/daa578a4be36573d9791bf51dcd0420b8dc75732/sentinel1_swinb_si.pth',  # noqa: E501
+        url='https://hf.co/allenai/satlas-pretrain/resolve/daa578a4be36573d9791bf51dcd0420b8dc75732/sentinel1_swinb_si.pth',
         transforms=_satlas_transforms,
         meta={
             'dataset': 'Satlas',
@@ -106,7 +106,7 @@ class Swin_V2_B_Weights(WeightsEnum):  # type: ignore[misc]
     )
     LANDSAT_SI_SATLAS = Weights(
-        url='https://hf.co/allenai/satlas-pretrain/resolve/daa578a4be36573d9791bf51dcd0420b8dc75732/landsat_swinb_si.pth',  # noqa: E501
+        url='https://hf.co/allenai/satlas-pretrain/resolve/daa578a4be36573d9791bf51dcd0420b8dc75732/landsat_swinb_si.pth',
         transforms=_landsat_satlas_transforms,
         meta={
             'dataset': 'Satlas',
@@ -126,7 +126,7 @@ class Swin_V2_B_Weights(WeightsEnum):  # type: ignore[misc]
                 'B09',
                 'B10',
                 'B11',
-            ],  # noqa: E501
+            ],
         },
     )
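
The Satlas entries differ mainly in their normalization pipeline and their meta (sensor, bands). A sketch of picking one through torchgeo's swin_v2_b factory; the 'bands' meta key is an assumption inferred from the band list edited above, hence the defensive .get:

from torchgeo.models import Swin_V2_B_Weights, swin_v2_b

weights = Swin_V2_B_Weights.SENTINEL2_MS_SI_SATLAS
model = swin_v2_b(weights=weights)

# meta records what the checkpoint was trained on, e.g. dataset and band order.
print(weights.meta['dataset'], weights.meta.get('bands'))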
@@ -11,8 +11,8 @@ import torch
 from timm.models.vision_transformer import VisionTransformer
 from torchvision.models._api import Weights, WeightsEnum
-# https://github.com/zhu-xlab/SSL4EO-S12/blob/d2868adfada65e40910bfcedfc49bc3b20df2248/src/benchmark/transfer_classification/linear_BE_moco.py#L167  # noqa: E501
+# https://github.com/zhu-xlab/SSL4EO-S12/blob/d2868adfada65e40910bfcedfc49bc3b20df2248/src/benchmark/transfer_classification/linear_BE_moco.py#L167
-# https://github.com/zhu-xlab/SSL4EO-S12/blob/d2868adfada65e40910bfcedfc49bc3b20df2248/src/benchmark/transfer_classification/datasets/EuroSat/eurosat_dataset.py#L97  # noqa: E501
+# https://github.com/zhu-xlab/SSL4EO-S12/blob/d2868adfada65e40910bfcedfc49bc3b20df2248/src/benchmark/transfer_classification/datasets/EuroSat/eurosat_dataset.py#L97
 # Normalization either by 10K or channel-wise with band statistics
 _zhu_xlab_transforms = K.AugmentationSequential(
     K.Resize(256),
@@ -21,7 +21,7 @@ _zhu_xlab_transforms = K.AugmentationSequential(
     data_keys=None,
 )
-# https://github.com/microsoft/torchgeo/blob/8b53304d42c269f9001cb4e861a126dc4b462606/torchgeo/datamodules/ssl4eo_benchmark.py#L43  # noqa: E501
+# https://github.com/microsoft/torchgeo/blob/8b53304d42c269f9001cb4e861a126dc4b462606/torchgeo/datamodules/ssl4eo_benchmark.py#L43
 _ssl4eo_l_transforms = K.AugmentationSequential(
     K.Normalize(mean=torch.tensor(0), std=torch.tensor(255)),
     K.CenterCrop((224, 224)),
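
The SSL4EO-L pipeline above only rescales by 255 and center-crops to 224. A quick equivalence sketch with plain tensor ops (shapes are illustrative):

import torch

x = torch.randint(0, 256, (1, 7, 264, 264)).float()  # 7-band Landsat-style patch
y = (x - 0.0) / 255.0        # what K.Normalize(mean=0, std=255) computes
y = y[..., 20:244, 20:244]   # the center 224x224 crop, as K.CenterCrop((224, 224))
assert y.shape == (1, 7, 224, 224)
assert 0.0 <= float(y.min()) and float(y.max()) <= 1.0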
@@ -44,7 +44,7 @@ class ViTSmall16_Weights(WeightsEnum):  # type: ignore[misc]
     """
     LANDSAT_TM_TOA_MOCO = Weights(
-        url='https://hf.co/torchgeo/ssl4eo_landsat/resolve/1c88bb51b6e17a21dde5230738fa38b74bd74f76/vits16_landsat_tm_toa_moco-a1c967d8.pth',  # noqa: E501
+        url='https://hf.co/torchgeo/ssl4eo_landsat/resolve/1c88bb51b6e17a21dde5230738fa38b74bd74f76/vits16_landsat_tm_toa_moco-a1c967d8.pth',
         transforms=_ssl4eo_l_transforms,
         meta={
             'dataset': 'SSL4EO-L',
@@ -57,7 +57,7 @@ class ViTSmall16_Weights(WeightsEnum):  # type: ignore[misc]
     )
     LANDSAT_TM_TOA_SIMCLR = Weights(
-        url='https://hf.co/torchgeo/ssl4eo_landsat/resolve/1c88bb51b6e17a21dde5230738fa38b74bd74f76/vits16_landsat_tm_toa_simclr-7c2d9799.pth',  # noqa: E501
+        url='https://hf.co/torchgeo/ssl4eo_landsat/resolve/1c88bb51b6e17a21dde5230738fa38b74bd74f76/vits16_landsat_tm_toa_simclr-7c2d9799.pth',
         transforms=_ssl4eo_l_transforms,
         meta={
             'dataset': 'SSL4EO-L',
@@ -70,7 +70,7 @@ class ViTSmall16_Weights(WeightsEnum):  # type: ignore[misc]
     )
     LANDSAT_ETM_TOA_MOCO = Weights(
-        url='https://hf.co/torchgeo/ssl4eo_landsat/resolve/1c88bb51b6e17a21dde5230738fa38b74bd74f76/vits16_landsat_etm_toa_moco-26d19bcf.pth',  # noqa: E501
+        url='https://hf.co/torchgeo/ssl4eo_landsat/resolve/1c88bb51b6e17a21dde5230738fa38b74bd74f76/vits16_landsat_etm_toa_moco-26d19bcf.pth',
         transforms=_ssl4eo_l_transforms,
         meta={
             'dataset': 'SSL4EO-L',
@@ -83,7 +83,7 @@ class ViTSmall16_Weights(WeightsEnum):  # type: ignore[misc]
     )
     LANDSAT_ETM_TOA_SIMCLR = Weights(
-        url='https://hf.co/torchgeo/ssl4eo_landsat/resolve/1c88bb51b6e17a21dde5230738fa38b74bd74f76/vits16_landsat_etm_toa_simclr-34fb12cb.pth',  # noqa: E501
+        url='https://hf.co/torchgeo/ssl4eo_landsat/resolve/1c88bb51b6e17a21dde5230738fa38b74bd74f76/vits16_landsat_etm_toa_simclr-34fb12cb.pth',
         transforms=_ssl4eo_l_transforms,
         meta={
             'dataset': 'SSL4EO-L',
@@ -96,7 +96,7 @@ class ViTSmall16_Weights(WeightsEnum):  # type: ignore[misc]
     )
     LANDSAT_ETM_SR_MOCO = Weights(
-        url='https://hf.co/torchgeo/ssl4eo_landsat/resolve/1c88bb51b6e17a21dde5230738fa38b74bd74f76/vits16_landsat_etm_sr_moco-eaa4674e.pth',  # noqa: E501
+        url='https://hf.co/torchgeo/ssl4eo_landsat/resolve/1c88bb51b6e17a21dde5230738fa38b74bd74f76/vits16_landsat_etm_sr_moco-eaa4674e.pth',
         transforms=_ssl4eo_l_transforms,
         meta={
             'dataset': 'SSL4EO-L',
@@ -109,7 +109,7 @@ class ViTSmall16_Weights(WeightsEnum):  # type: ignore[misc]
     )
     LANDSAT_ETM_SR_SIMCLR = Weights(
-        url='https://hf.co/torchgeo/ssl4eo_landsat/resolve/1c88bb51b6e17a21dde5230738fa38b74bd74f76/vits16_landsat_etm_sr_simclr-a14c466a.pth',  # noqa: E501
+        url='https://hf.co/torchgeo/ssl4eo_landsat/resolve/1c88bb51b6e17a21dde5230738fa38b74bd74f76/vits16_landsat_etm_sr_simclr-a14c466a.pth',
         transforms=_ssl4eo_l_transforms,
         meta={
             'dataset': 'SSL4EO-L',
@@ -122,7 +122,7 @@ class ViTSmall16_Weights(WeightsEnum):  # type: ignore[misc]
     )
     LANDSAT_OLI_TIRS_TOA_MOCO = Weights(
-        url='https://hf.co/torchgeo/ssl4eo_landsat/resolve/1c88bb51b6e17a21dde5230738fa38b74bd74f76/vits16_landsat_oli_tirs_toa_moco-c7c2cceb.pth',  # noqa: E501
+        url='https://hf.co/torchgeo/ssl4eo_landsat/resolve/1c88bb51b6e17a21dde5230738fa38b74bd74f76/vits16_landsat_oli_tirs_toa_moco-c7c2cceb.pth',
         transforms=_ssl4eo_l_transforms,
         meta={
             'dataset': 'SSL4EO-L',
@@ -135,7 +135,7 @@ class ViTSmall16_Weights(WeightsEnum):  # type: ignore[misc]
     )
     LANDSAT_OLI_TIRS_TOA_SIMCLR = Weights(
-        url='https://hf.co/torchgeo/ssl4eo_landsat/resolve/1c88bb51b6e17a21dde5230738fa38b74bd74f76/vits16_landsat_oli_tirs_toa_simclr-ad43e9a4.pth',  # noqa: E501
+        url='https://hf.co/torchgeo/ssl4eo_landsat/resolve/1c88bb51b6e17a21dde5230738fa38b74bd74f76/vits16_landsat_oli_tirs_toa_simclr-ad43e9a4.pth',
         transforms=_ssl4eo_l_transforms,
         meta={
             'dataset': 'SSL4EO-L',
@@ -148,7 +148,7 @@ class ViTSmall16_Weights(WeightsEnum):  # type: ignore[misc]
     )
     LANDSAT_OLI_SR_MOCO = Weights(
-        url='https://hf.co/torchgeo/ssl4eo_landsat/resolve/1c88bb51b6e17a21dde5230738fa38b74bd74f76/vits16_landsat_oli_sr_moco-c9b8898d.pth',  # noqa: E501
+        url='https://hf.co/torchgeo/ssl4eo_landsat/resolve/1c88bb51b6e17a21dde5230738fa38b74bd74f76/vits16_landsat_oli_sr_moco-c9b8898d.pth',
         transforms=_ssl4eo_l_transforms,
         meta={
             'dataset': 'SSL4EO-L',
@@ -161,7 +161,7 @@ class ViTSmall16_Weights(WeightsEnum):  # type: ignore[misc]
     )
     LANDSAT_OLI_SR_SIMCLR = Weights(
-        url='https://hf.co/torchgeo/ssl4eo_landsat/resolve/1c88bb51b6e17a21dde5230738fa38b74bd74f76/vits16_landsat_oli_sr_simclr-4e8f6102.pth',  # noqa: E501
+        url='https://hf.co/torchgeo/ssl4eo_landsat/resolve/1c88bb51b6e17a21dde5230738fa38b74bd74f76/vits16_landsat_oli_sr_simclr-4e8f6102.pth',
         transforms=_ssl4eo_l_transforms,
         meta={
             'dataset': 'SSL4EO-L',
@@ -174,7 +174,7 @@ class ViTSmall16_Weights(WeightsEnum):  # type: ignore[misc]
     )
     SENTINEL2_ALL_DINO = Weights(
-        url='https://hf.co/torchgeo/vit_small_patch16_224_sentinel2_all_dino/resolve/5b41dd418a79de47ac9f5be3e035405a83818a62/vit_small_patch16_224_sentinel2_all_dino-36bcc127.pth',  # noqa: E501
+        url='https://hf.co/torchgeo/vit_small_patch16_224_sentinel2_all_dino/resolve/5b41dd418a79de47ac9f5be3e035405a83818a62/vit_small_patch16_224_sentinel2_all_dino-36bcc127.pth',
         transforms=_zhu_xlab_transforms,
         meta={
             'dataset': 'SSL4EO-S12',
@@ -187,7 +187,7 @@ class ViTSmall16_Weights(WeightsEnum):  # type: ignore[misc]
     )
     SENTINEL2_ALL_MOCO = Weights(
-        url='https://hf.co/torchgeo/vit_small_patch16_224_sentinel2_all_moco/resolve/1cb683f6c14739634cdfaaceb076529adf898c74/vit_small_patch16_224_sentinel2_all_moco-67c9032d.pth',  # noqa: E501
+        url='https://hf.co/torchgeo/vit_small_patch16_224_sentinel2_all_moco/resolve/1cb683f6c14739634cdfaaceb076529adf898c74/vit_small_patch16_224_sentinel2_all_moco-67c9032d.pth',
         transforms=_zhu_xlab_transforms,
         meta={
             'dataset': 'SSL4EO-S12',
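
The ViT checkpoints follow the same pattern and can also be loaded directly through timm. A sketch, assuming the meta dict exposes an 'in_chans' entry as torchgeo's model factories expect:

import timm
from torchgeo.models import ViTSmall16_Weights

weights = ViTSmall16_Weights.LANDSAT_ETM_SR_MOCO
# Build a matching backbone and load the checkpoint without a classification head.
model = timm.create_model(
    'vit_small_patch16_224', in_chans=weights.meta['in_chans'], num_classes=0
)
model.load_state_dict(weights.get_state_dict(progress=True), strict=False)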
@@ -53,7 +53,7 @@ class AugmentationSequential(Module):
             else:
                 keys.append(key)
-        self.augs = K.AugmentationSequential(*args, data_keys=keys, **kwargs)  # type: ignore[arg-type]  # noqa: E501
+        self.augs = K.AugmentationSequential(*args, data_keys=keys, **kwargs)  # type: ignore[arg-type]
 
     def forward(self, batch: dict[str, Tensor]) -> dict[str, Tensor]:
         """Perform augmentations and update data dict.