Ruff: enable ruff-specific rules (#2218)

* Ruff: enable ruff-specific rules

* Static class variables

* String colormap must be list

* String colormap must be list
This commit is contained in:
Adam J. Stewart 2024-08-19 15:07:21 +02:00 коммит произвёл GitHub
Родитель c26512e39d
Коммит 067ae1af75
Не найден ключ, соответствующий данной подписи
Идентификатор ключа GPG: B5690EEEBB952194
98 изменённых файлов: 668 добавлений и 595 удалений

Просмотреть файл

@ -19,7 +19,7 @@ import pytorch_sphinx_theme
# documentation root, use os.path.abspath to make it absolute, like shown here.
sys.path.insert(0, os.path.abspath('..'))
import torchgeo # noqa: E402
import torchgeo
# -- Project information -----------------------------------------------------

Просмотреть файл

@ -369,8 +369,8 @@
" date_format = '%Y%m%dT%H%M%S'\n",
" is_image = True\n",
" separate_files = True\n",
" all_bands = ['B02', 'B03', 'B04', 'B08']\n",
" rgb_bands = ['B04', 'B03', 'B02']"
" all_bands = ('B02', 'B03', 'B04', 'B08')\n",
" rgb_bands = ('B04', 'B03', 'B02')"
]
},
{
@ -432,8 +432,8 @@
" date_format = '%Y%m%dT%H%M%S'\n",
" is_image = True\n",
" separate_files = True\n",
" all_bands = ['B02', 'B03', 'B04', 'B08']\n",
" rgb_bands = ['B04', 'B03', 'B02']\n",
" all_bands = ('B02', 'B03', 'B04', 'B08')\n",
" rgb_bands = ('B04', 'B03', 'B02')\n",
"\n",
" def plot(self, sample):\n",
" # Find the correct band index order\n",

Просмотреть файл

@ -125,7 +125,7 @@ def filter_collection(
if filtered.size().getInfo() == 0:
raise ee.EEException(
f'ImageCollection.filter: No suitable images found in ({coords[1]:.4f}, {coords[0]:.4f}) between {period[0]} and {period[1]}.' # noqa: E501
f'ImageCollection.filter: No suitable images found in ({coords[1]:.4f}, {coords[0]:.4f}) between {period[0]} and {period[1]}.'
)
return filtered

Просмотреть файл

@ -47,7 +47,7 @@ from tqdm import tqdm
def get_world_cities(
download_root: str = 'world_cities', size: int = 10000
) -> pd.DataFrame:
url = 'https://simplemaps.com/static/data/world-cities/basic/simplemaps_worldcities_basicv1.71.zip' # noqa: E501
url = 'https://simplemaps.com/static/data/world-cities/basic/simplemaps_worldcities_basicv1.71.zip'
filename = 'worldcities.csv'
download_and_extract_archive(url, download_root)
cols = ['city', 'lat', 'lng', 'population']

Просмотреть файл

@ -268,7 +268,7 @@ quote-style = "single"
skip-magic-trailing-comma = true
[tool.ruff.lint]
extend-select = ["ANN", "D", "I", "NPY201", "UP"]
extend-select = ["ANN", "D", "I", "NPY201", "RUF", "UP"]
ignore = ["ANN101", "ANN102", "ANN401"]
[tool.ruff.lint.per-file-ignores]

Просмотреть файл

@ -19,36 +19,36 @@ np.random.seed(0)
train_set = [
{
'image': 'labeled_train/Nantes_Saint-Nazaire/BDORTHO/44-2013-0295-6713-LA93-0M50-E080.tif', # noqa: E501
'dem': 'labeled_train/Nantes_Saint-Nazaire/RGEALTI/44-2013-0295-6713-LA93-0M50-E080_RGEALTI.tif', # noqa: E501
'target': 'labeled_train/Nantes_Saint-Nazaire/UrbanAtlas/44-2013-0295-6713-LA93-0M50-E080_UA2012.tif', # noqa: E501
'image': 'labeled_train/Nantes_Saint-Nazaire/BDORTHO/44-2013-0295-6713-LA93-0M50-E080.tif',
'dem': 'labeled_train/Nantes_Saint-Nazaire/RGEALTI/44-2013-0295-6713-LA93-0M50-E080_RGEALTI.tif',
'target': 'labeled_train/Nantes_Saint-Nazaire/UrbanAtlas/44-2013-0295-6713-LA93-0M50-E080_UA2012.tif',
},
{
'image': 'labeled_train/Nice/BDORTHO/06-2014-1007-6318-LA93-0M50-E080.tif', # noqa: E501
'dem': 'labeled_train/Nice/RGEALTI/06-2014-1007-6318-LA93-0M50-E080_RGEALTI.tif', # noqa: E501
'target': 'labeled_train/Nice/UrbanAtlas/06-2014-1007-6318-LA93-0M50-E080_UA2012.tif', # noqa: E501
'image': 'labeled_train/Nice/BDORTHO/06-2014-1007-6318-LA93-0M50-E080.tif',
'dem': 'labeled_train/Nice/RGEALTI/06-2014-1007-6318-LA93-0M50-E080_RGEALTI.tif',
'target': 'labeled_train/Nice/UrbanAtlas/06-2014-1007-6318-LA93-0M50-E080_UA2012.tif',
},
]
unlabeled_set = [
{
'image': 'unlabeled_train/Calais_Dunkerque/BDORTHO/59-2012-0650-7077-LA93-0M50-E080.tif', # noqa: E501
'dem': 'unlabeled_train/Calais_Dunkerque/RGEALTI/59-2012-0650-7077-LA93-0M50-E080_RGEALTI.tif', # noqa: E501
'image': 'unlabeled_train/Calais_Dunkerque/BDORTHO/59-2012-0650-7077-LA93-0M50-E080.tif',
'dem': 'unlabeled_train/Calais_Dunkerque/RGEALTI/59-2012-0650-7077-LA93-0M50-E080_RGEALTI.tif',
},
{
'image': 'unlabeled_train/LeMans/BDORTHO/72-2013-0469-6789-LA93-0M50-E080.tif', # noqa: E501
'dem': 'unlabeled_train/LeMans/RGEALTI/72-2013-0469-6789-LA93-0M50-E080_RGEALTI.tif', # noqa: E501
'image': 'unlabeled_train/LeMans/BDORTHO/72-2013-0469-6789-LA93-0M50-E080.tif',
'dem': 'unlabeled_train/LeMans/RGEALTI/72-2013-0469-6789-LA93-0M50-E080_RGEALTI.tif',
},
]
val_set = [
{
'image': 'val/Marseille_Martigues/BDORTHO/13-2014-0900-6268-LA93-0M50-E080.tif', # noqa: E501
'dem': 'val/Marseille_Martigues/RGEALTI/13-2014-0900-6268-LA93-0M50-E080_RGEALTI.tif', # noqa: E501
'image': 'val/Marseille_Martigues/BDORTHO/13-2014-0900-6268-LA93-0M50-E080.tif',
'dem': 'val/Marseille_Martigues/RGEALTI/13-2014-0900-6268-LA93-0M50-E080_RGEALTI.tif',
},
{
'image': 'val/Clermont-Ferrand/BDORTHO/63-2013-0711-6530-LA93-0M50-E080.tif', # noqa: E501
'dem': 'val/Clermont-Ferrand/RGEALTI/63-2013-0711-6530-LA93-0M50-E080_RGEALTI.tif', # noqa: E501
'image': 'val/Clermont-Ferrand/BDORTHO/63-2013-0711-6530-LA93-0M50-E080.tif',
'dem': 'val/Clermont-Ferrand/RGEALTI/63-2013-0711-6530-LA93-0M50-E080_RGEALTI.tif',
},
]

Просмотреть файл

@ -112,7 +112,7 @@ for season in seasons:
# Compute checksums
with open(archive, 'rb') as f:
md5 = hashlib.md5(f.read()).hexdigest()
print(f'{season}: {repr(md5)}')
print(f'{season}: {md5!r}')
# Write meta.csv
with open('meta.csv', 'w') as f:
@ -121,7 +121,7 @@ with open('meta.csv', 'w') as f:
# Compute checksums
with open('meta.csv', 'rb') as f:
md5 = hashlib.md5(f.read()).hexdigest()
print(f'meta.csv: {repr(md5)}')
print(f'meta.csv: {md5!r}')
os.makedirs('splits', exist_ok=True)
@ -138,4 +138,4 @@ shutil.make_archive('splits', 'zip', '.', 'splits')
# Compute checksums
with open('splits.zip', 'rb') as f:
md5 = hashlib.md5(f.read()).hexdigest()
print(f'splits: {repr(md5)}')
print(f'splits: {md5!r}')

Просмотреть файл

@ -83,5 +83,5 @@ class TestEuroCrops:
dataset[query]
def test_integrity_error(self, dataset: EuroCrops) -> None:
dataset.zenodo_files = [('AA.zip', 'invalid')]
dataset.zenodo_files = (('AA.zip', 'invalid'),)
assert not dataset._check_integrity()

Просмотреть файл

@ -72,7 +72,7 @@ class CustomVectorDataset(VectorDataset):
class CustomSentinelDataset(Sentinel2):
all_bands: list[str] = []
all_bands: tuple[str, ...] = ()
separate_files = False
@ -356,7 +356,7 @@ class TestRasterDataset:
def test_no_all_bands(self) -> None:
root = os.path.join('tests', 'data', 'sentinel2')
bands = ['B04', 'B03', 'B02']
bands = ('B04', 'B03', 'B02')
transforms = nn.Identity()
cache = True
msg = (

Просмотреть файл

@ -73,7 +73,7 @@ class TestBYOLTask:
'1',
]
main(['fit'] + args)
main(['fit', *args])
@pytest.fixture
def weights(self) -> WeightsEnum:

Просмотреть файл

@ -103,13 +103,13 @@ class TestClassificationTask:
'1',
]
main(['fit'] + args)
main(['fit', *args])
try:
main(['test'] + args)
main(['test', *args])
except MisconfigurationException:
pass
try:
main(['predict'] + args)
main(['predict', *args])
except MisconfigurationException:
pass
@ -259,13 +259,13 @@ class TestMultiLabelClassificationTask:
'1',
]
main(['fit'] + args)
main(['fit', *args])
try:
main(['test'] + args)
main(['test', *args])
except MisconfigurationException:
pass
try:
main(['predict'] + args)
main(['predict', *args])
except MisconfigurationException:
pass

Просмотреть файл

@ -97,13 +97,13 @@ class TestObjectDetectionTask:
'1',
]
main(['fit'] + args)
main(['fit', *args])
try:
main(['test'] + args)
main(['test', *args])
except MisconfigurationException:
pass
try:
main(['predict'] + args)
main(['predict', *args])
except MisconfigurationException:
pass

Просмотреть файл

@ -27,12 +27,12 @@ class TestClassificationTask:
'1',
]
main(['fit'] + args)
main(['fit', *args])
try:
main(['test'] + args)
main(['test', *args])
except MisconfigurationException:
pass
try:
main(['predict'] + args)
main(['predict', *args])
except MisconfigurationException:
pass

Просмотреть файл

@ -63,7 +63,7 @@ class TestMoCoTask:
'1',
]
main(['fit'] + args)
main(['fit', *args])
def test_version_warnings(self) -> None:
with pytest.warns(UserWarning, match='MoCo v1 uses a memory bank'):

Просмотреть файл

@ -84,13 +84,13 @@ class TestRegressionTask:
'1',
]
main(['fit'] + args)
main(['fit', *args])
try:
main(['test'] + args)
main(['test', *args])
except MisconfigurationException:
pass
try:
main(['predict'] + args)
main(['predict', *args])
except MisconfigurationException:
pass
@ -237,13 +237,13 @@ class TestPixelwiseRegressionTask:
'1',
]
main(['fit'] + args)
main(['fit', *args])
try:
main(['test'] + args)
main(['test', *args])
except MisconfigurationException:
pass
try:
main(['predict'] + args)
main(['predict', *args])
except MisconfigurationException:
pass

Просмотреть файл

@ -108,13 +108,13 @@ class TestSemanticSegmentationTask:
'1',
]
main(['fit'] + args)
main(['fit', *args])
try:
main(['test'] + args)
main(['test', *args])
except MisconfigurationException:
pass
try:
main(['predict'] + args)
main(['predict', *args])
except MisconfigurationException:
pass

Просмотреть файл

@ -63,7 +63,7 @@ class TestSimCLRTask:
'1',
]
main(['fit'] + args)
main(['fit', *args])
def test_version_warnings(self) -> None:
with pytest.warns(UserWarning, match='SimCLR v1 only uses 2 layers'):

Просмотреть файл

@ -37,7 +37,7 @@ class SeasonalContrastS2DataModule(NonGeoDataModule):
seasons = kwargs.get('seasons', 1)
# Normalization only available for RGB dataset, defined here:
# https://github.com/ServiceNow/seasonal-contrast/blob/8285173ec205b64bc3e53b880344dd6c3f79fa7a/datasets/seco_dataset.py # noqa: E501
# https://github.com/ServiceNow/seasonal-contrast/blob/8285173ec205b64bc3e53b880344dd6c3f79fa7a/datasets/seco_dataset.py
if bands == SeasonalContrastS2.rgb_bands:
_min = torch.tensor([3, 2, 0])
_max = torch.tensor([88, 103, 129])

Просмотреть файл

@ -3,7 +3,7 @@
"""So2Sat datamodule."""
from typing import Any
from typing import Any, ClassVar
import torch
from torch import Generator, Tensor
@ -21,7 +21,7 @@ class So2SatDataModule(NonGeoDataModule):
"train" set and use the "test" set as the test set.
"""
means_per_version: dict[str, Tensor] = {
means_per_version: ClassVar[dict[str, Tensor]] = {
'2': torch.tensor(
[
-0.00003591224260,
@ -91,7 +91,7 @@ class So2SatDataModule(NonGeoDataModule):
}
means_per_version['3_culture_10'] = means_per_version['2']
stds_per_version: dict[str, Tensor] = {
stds_per_version: ClassVar[dict[str, Tensor]] = {
'2': torch.tensor(
[
0.17555201,

Просмотреть файл

@ -45,7 +45,7 @@ class SSL4EOS12DataModule(NonGeoDataModule):
.. versionadded:: 0.5
"""
# https://github.com/zhu-xlab/SSL4EO-S12/blob/d2868adfada65e40910bfcedfc49bc3b20df2248/src/benchmark/transfer_classification/datasets/EuroSat/eurosat_dataset.py#L97 # noqa: E501
# https://github.com/zhu-xlab/SSL4EO-S12/blob/d2868adfada65e40910bfcedfc49bc3b20df2248/src/benchmark/transfer_classification/datasets/EuroSat/eurosat_dataset.py#L97
mean = torch.tensor(0)
std = torch.tensor(10000)

Просмотреть файл

@ -63,14 +63,14 @@ class ADVANCE(NonGeoDataset):
* `scipy <https://pypi.org/project/scipy/>`_ to load the audio files to tensors
"""
urls = [
urls = (
'https://zenodo.org/record/3828124/files/ADVANCE_vision.zip?download=1',
'https://zenodo.org/record/3828124/files/ADVANCE_sound.zip?download=1',
]
filenames = ['ADVANCE_vision.zip', 'ADVANCE_sound.zip']
md5s = ['a9e8748219ef5864d3b5a8979a67b471', 'a2d12f2d2a64f5c3d3a9d8c09aaf1c31']
directories = ['vision', 'sound']
classes = [
)
filenames = ('ADVANCE_vision.zip', 'ADVANCE_sound.zip')
md5s = ('a9e8748219ef5864d3b5a8979a67b471', 'a2d12f2d2a64f5c3d3a9d8c09aaf1c31')
directories = ('vision', 'sound')
classes: tuple[str, ...] = (
'airport',
'beach',
'bridge',
@ -84,7 +84,7 @@ class ADVANCE(NonGeoDataset):
'sparse shrub land',
'sports land',
'train station',
]
)
def __init__(
self,
@ -119,7 +119,7 @@ class ADVANCE(NonGeoDataset):
raise DatasetNotFoundError(self)
self.files = self._load_files(self.root)
self.classes = sorted({f['cls'] for f in self.files})
self.classes = tuple(sorted({f['cls'] for f in self.files}))
self.class_to_idx: dict[str, int] = {c: i for i, c in enumerate(self.classes)}
def __getitem__(self, index: int) -> dict[str, Tensor]:

Просмотреть файл

@ -46,7 +46,7 @@ class AbovegroundLiveWoodyBiomassDensity(RasterDataset):
is_image = False
url = 'https://opendata.arcgis.com/api/v3/datasets/e4bdbe8d6d8d4e32ace7d36a4aec7b93_0/downloads/data?format=geojson&spatialRefId=4326' # noqa: E501
url = 'https://opendata.arcgis.com/api/v3/datasets/e4bdbe8d6d8d4e32ace7d36a4aec7b93_0/downloads/data?format=geojson&spatialRefId=4326'
base_filename = 'Aboveground_Live_Woody_Biomass_Density.geojson'

Просмотреть файл

@ -7,7 +7,7 @@ import os
import pathlib
import re
from collections.abc import Callable, Iterable, Sequence
from typing import Any, cast
from typing import Any, ClassVar, cast
import matplotlib.pyplot as plt
import torch
@ -90,8 +90,8 @@ class AgriFieldNet(RasterDataset):
_(?P<band>B[0-9A-Z]{2})_10m
"""
rgb_bands = ['B04', 'B03', 'B02']
all_bands = [
rgb_bands = ('B04', 'B03', 'B02')
all_bands = (
'B01',
'B02',
'B03',
@ -104,9 +104,9 @@ class AgriFieldNet(RasterDataset):
'B09',
'B11',
'B12',
]
)
cmap = {
cmap: ClassVar[dict[int, tuple[int, int, int, int]]] = {
0: (0, 0, 0, 255),
1: (255, 211, 0, 255),
2: (255, 37, 37, 255),

Просмотреть файл

@ -40,8 +40,8 @@ class Airphen(RasterDataset):
# Each camera measures a custom set of spectral bands chosen at purchase time.
# Hiphen offers 8 bands to choose from, sorted from short to long wavelength.
all_bands = ['B1', 'B2', 'B3', 'B4', 'B5', 'B6', 'B7', 'B8']
rgb_bands = ['B4', 'B3', 'B1']
all_bands = ('B1', 'B2', 'B3', 'B4', 'B5', 'B6', 'B7', 'B8')
rgb_bands = ('B4', 'B3', 'B1')
def plot(
self,

Просмотреть файл

@ -147,7 +147,7 @@ class BeninSmallHolderCashews(NonGeoDataset):
)
rgb_bands = ('B04', 'B03', 'B02')
classes = [
classes = (
'No data',
'Well-managed planatation',
'Poorly-managed planatation',
@ -155,7 +155,7 @@ class BeninSmallHolderCashews(NonGeoDataset):
'Residential',
'Background',
'Uncertain',
]
)
# Same for all tiles
tile_height = 1186
@ -199,11 +199,13 @@ class BeninSmallHolderCashews(NonGeoDataset):
# Calculate the indices that we will use over all tiles
self.chips_metadata = []
for y in list(range(0, self.tile_height - self.chip_size, stride)) + [
self.tile_height - self.chip_size
for y in [
*list(range(0, self.tile_height - self.chip_size, stride)),
self.tile_height - self.chip_size,
]:
for x in list(range(0, self.tile_width - self.chip_size, stride)) + [
self.tile_width - self.chip_size
for x in [
*list(range(0, self.tile_width - self.chip_size, stride)),
self.tile_width - self.chip_size,
]:
self.chips_metadata.append((y, x))

Просмотреть файл

@ -7,6 +7,7 @@ import glob
import json
import os
from collections.abc import Callable
from typing import ClassVar
import matplotlib.pyplot as plt
import numpy as np
@ -124,9 +125,9 @@ class BigEarthNet(NonGeoDataset):
* https://doi.org/10.1109/IGARSS.2019.8900532
""" # noqa: E501
"""
class_sets = {
class_sets: ClassVar[dict[int, list[str]]] = {
19: [
'Urban fabric',
'Industrial or commercial units',
@ -197,7 +198,7 @@ class BigEarthNet(NonGeoDataset):
],
}
label_converter = {
label_converter: ClassVar[dict[int, int]] = {
0: 0,
1: 0,
2: 1,
@ -232,24 +233,24 @@ class BigEarthNet(NonGeoDataset):
42: 18,
}
splits_metadata = {
splits_metadata: ClassVar[dict[str, dict[str, str]]] = {
'train': {
'url': 'https://git.tu-berlin.de/rsim/BigEarthNet-MM_19-classes_models/-/raw/9a5be07346ab0884b2d9517475c27ef9db9b5104/splits/train.csv?inline=false', # noqa: E501
'url': 'https://git.tu-berlin.de/rsim/BigEarthNet-MM_19-classes_models/-/raw/9a5be07346ab0884b2d9517475c27ef9db9b5104/splits/train.csv?inline=false',
'filename': 'bigearthnet-train.csv',
'md5': '623e501b38ab7b12fe44f0083c00986d',
},
'val': {
'url': 'https://git.tu-berlin.de/rsim/BigEarthNet-MM_19-classes_models/-/raw/9a5be07346ab0884b2d9517475c27ef9db9b5104/splits/val.csv?inline=false', # noqa: E501
'url': 'https://git.tu-berlin.de/rsim/BigEarthNet-MM_19-classes_models/-/raw/9a5be07346ab0884b2d9517475c27ef9db9b5104/splits/val.csv?inline=false',
'filename': 'bigearthnet-val.csv',
'md5': '22efe8ed9cbd71fa10742ff7df2b7978',
},
'test': {
'url': 'https://git.tu-berlin.de/rsim/BigEarthNet-MM_19-classes_models/-/raw/9a5be07346ab0884b2d9517475c27ef9db9b5104/splits/test.csv?inline=false', # noqa: E501
'url': 'https://git.tu-berlin.de/rsim/BigEarthNet-MM_19-classes_models/-/raw/9a5be07346ab0884b2d9517475c27ef9db9b5104/splits/test.csv?inline=false',
'filename': 'bigearthnet-test.csv',
'md5': '697fb90677e30571b9ac7699b7e5b432',
},
}
metadata = {
metadata: ClassVar[dict[str, dict[str, str]]] = {
's1': {
'url': 'https://zenodo.org/records/12687186/files/BigEarthNet-S1-v1.0.tar.gz',
'md5': '94ced73440dea8c7b9645ee738c5a172',

Просмотреть файл

@ -50,7 +50,7 @@ class BioMassters(NonGeoDataset):
.. versionadded:: 0.5
"""
valid_splits = ['train', 'test']
valid_splits = ('train', 'test')
valid_sensors = ('S1', 'S2')
metadata_filename = 'The_BioMassters_-_features_metadata.csv.csv'

Просмотреть файл

@ -30,7 +30,7 @@ class CanadianBuildingFootprints(VectorDataset):
# https://github.com/microsoft/CanadianBuildingFootprints/issues/11
url = 'https://usbuildingdata.blob.core.windows.net/canadian-buildings-v2/'
provinces_territories = [
provinces_territories = (
'Alberta',
'BritishColumbia',
'Manitoba',
@ -44,8 +44,8 @@ class CanadianBuildingFootprints(VectorDataset):
'Quebec',
'Saskatchewan',
'YukonTerritory',
]
md5s = [
)
md5s = (
'8b4190424e57bb0902bd8ecb95a9235b',
'fea05d6eb0006710729c675de63db839',
'adf11187362624d68f9c69aaa693c46f',
@ -59,7 +59,7 @@ class CanadianBuildingFootprints(VectorDataset):
'9ff4417ae00354d39a0cf193c8df592c',
'a51078d8e60082c7d3a3818240da6dd5',
'c11f3bd914ecabd7cac2cb2871ec0261',
]
)
def __init__(
self,

Просмотреть файл

@ -6,7 +6,7 @@
import os
import pathlib
from collections.abc import Callable, Iterable
from typing import Any
from typing import Any, ClassVar
import matplotlib.pyplot as plt
import torch
@ -38,7 +38,7 @@ class CDL(RasterDataset):
If you use this dataset in your research, please cite it using the following format:
* https://www.nass.usda.gov/Research_and_Science/Cropland/sarsfaqs2.php#Section1_14.0
""" # noqa: E501
"""
filename_glob = '*_30m_cdls.tif'
filename_regex = r"""
@ -49,8 +49,8 @@ class CDL(RasterDataset):
date_format = '%Y'
is_image = False
url = 'https://www.nass.usda.gov/Research_and_Science/Cropland/Release/datasets/{}_30m_cdls.zip' # noqa: E501
md5s = {
url = 'https://www.nass.usda.gov/Research_and_Science/Cropland/Release/datasets/{}_30m_cdls.zip'
md5s: ClassVar[dict[int, str]] = {
2023: '8c7685d6278d50c554f934b16a6076b7',
2022: '754cf50670cdfee511937554785de3e6',
2021: '27606eab08fe975aa138baad3e5dfcd8',
@ -69,7 +69,7 @@ class CDL(RasterDataset):
2008: '0610f2f17ab60a9fbb3baeb7543993a4',
}
cmap = {
cmap: ClassVar[dict[int, tuple[int, int, int, int]]] = {
0: (0, 0, 0, 255),
1: (255, 211, 0, 255),
2: (255, 37, 37, 255),

Просмотреть файл

@ -4,7 +4,8 @@
"""ChaBuD dataset."""
import os
from collections.abc import Callable
from collections.abc import Callable, Sequence
from typing import ClassVar
import matplotlib.pyplot as plt
import numpy as np
@ -53,7 +54,7 @@ class ChaBuD(NonGeoDataset):
.. versionadded:: 0.6
"""
all_bands = [
all_bands = (
'B01',
'B02',
'B03',
@ -66,10 +67,10 @@ class ChaBuD(NonGeoDataset):
'B09',
'B11',
'B12',
]
rgb_bands = ['B04', 'B03', 'B02']
folds = {'train': [1, 2, 3, 4], 'val': [0]}
url = 'https://hf.co/datasets/chabud-team/chabud-ecml-pkdd2023/resolve/de222d434e26379aa3d4f3dd1b2caf502427a8b2/train_eval.hdf5' # noqa: E501
)
rgb_bands = ('B04', 'B03', 'B02')
folds: ClassVar[dict[str, list[int]]] = {'train': [1, 2, 3, 4], 'val': [0]}
url = 'https://hf.co/datasets/chabud-team/chabud-ecml-pkdd2023/resolve/de222d434e26379aa3d4f3dd1b2caf502427a8b2/train_eval.hdf5'
filename = 'train_eval.hdf5'
md5 = '15d78fb825f9a81dad600db828d22c08'
@ -77,7 +78,7 @@ class ChaBuD(NonGeoDataset):
self,
root: Path = 'data',
split: str = 'train',
bands: list[str] = all_bands,
bands: Sequence[str] = all_bands,
transforms: Callable[[dict[str, Tensor]], dict[str, Tensor]] | None = None,
download: bool = False,
checksum: bool = False,

Просмотреть файл

@ -9,7 +9,7 @@ import pathlib
import sys
from abc import ABC, abstractmethod
from collections.abc import Callable, Iterable, Sequence
from typing import Any, cast
from typing import Any, ClassVar, cast
import fiona
import matplotlib.pyplot as plt
@ -39,7 +39,7 @@ class Chesapeake(RasterDataset, ABC):
The Chesapeake Bay Land Use and Land Cover Database (LULC) facilitates
characterization of the landscape and land change for and between discrete time
periods. The database was developed by the University of Vermonts Spatial Analysis
periods. The database was developed by the University of Vermont's Spatial Analysis
Laboratory in cooperation with Chesapeake Conservancy (CC) and U.S. Geological
Survey (USGS) as part of a 6-year Cooperative Agreement between Chesapeake
Conservancy and the U.S. Environmental Protection Agency (EPA) and a separate
@ -83,7 +83,7 @@ class Chesapeake(RasterDataset, ABC):
"""State abbreviation."""
return self.__class__.__name__[-2:].lower()
cmap = {
cmap: ClassVar[dict[int, tuple[int, int, int, int]]] = {
11: (0, 92, 230, 255),
12: (0, 92, 230, 255),
13: (0, 92, 230, 255),
@ -255,7 +255,7 @@ class Chesapeake(RasterDataset, ABC):
class ChesapeakeDC(Chesapeake):
"""This subset of the dataset contains data only for Washington, D.C."""
md5s = {
md5s: ClassVar[dict[int, str]] = {
2013: '9f1df21afbb9d5c0fcf33af7f6750a7f',
2017: 'c45e4af2950e1c93ecd47b61af296d9b',
}
@ -264,7 +264,7 @@ class ChesapeakeDC(Chesapeake):
class ChesapeakeDE(Chesapeake):
"""This subset of the dataset contains data only for Delaware."""
md5s = {
md5s: ClassVar[dict[int, str]] = {
2013: '5850d96d897babba85610658aeb5951a',
2018: 'ee94c8efeae423d898677104117bdebc',
}
@ -273,7 +273,7 @@ class ChesapeakeDE(Chesapeake):
class ChesapeakeMD(Chesapeake):
"""This subset of the dataset contains data only for Maryland."""
md5s = {
md5s: ClassVar[dict[int, str]] = {
2013: '9c3ca5040668d15284c1bd64b7d6c7a0',
2018: '0647530edf8bec6e60f82760dcc7db9c',
}
@ -282,7 +282,7 @@ class ChesapeakeMD(Chesapeake):
class ChesapeakeNY(Chesapeake):
"""This subset of the dataset contains data only for New York."""
md5s = {
md5s: ClassVar[dict[int, str]] = {
2013: '38a29b721610ba661a7f8b6ec71a48b7',
2017: '4c1b1a50fd9368cd7b8b12c4d80c63f3',
}
@ -291,7 +291,7 @@ class ChesapeakeNY(Chesapeake):
class ChesapeakePA(Chesapeake):
"""This subset of the dataset contains data only for Pennsylvania."""
md5s = {
md5s: ClassVar[dict[int, str]] = {
2013: '86febd603a120a49ef7d23ef486152a3',
2017: 'b11d92e4471e8cb887c790d488a338c1',
}
@ -300,7 +300,7 @@ class ChesapeakePA(Chesapeake):
class ChesapeakeVA(Chesapeake):
"""This subset of the dataset contains data only for Virginia."""
md5s = {
md5s: ClassVar[dict[int, str]] = {
2014: '49c9700c71854eebd00de24d8488eb7c',
2018: '51731c8b5632978bfd1df869ea10db5b',
}
@ -309,7 +309,7 @@ class ChesapeakeVA(Chesapeake):
class ChesapeakeWV(Chesapeake):
"""This subset of the dataset contains data only for West Virginia."""
md5s = {
md5s: ClassVar[dict[int, str]] = {
2014: '32fea42fae147bd58a83e3ea6cccfb94',
2018: '80f25dcba72e39685ab33215c5d97292',
}
@ -337,16 +337,16 @@ class ChesapeakeCVPR(GeoDataset):
* https://doi.org/10.1109/cvpr.2019.01301
"""
subdatasets = ['base', 'prior_extension']
urls = {
'base': 'https://lilablobssc.blob.core.windows.net/lcmcvpr2019/cvpr_chesapeake_landcover.zip', # noqa: E501
'prior_extension': 'https://zenodo.org/record/5866525/files/cvpr_chesapeake_landcover_prior_extension.zip?download=1', # noqa: E501
subdatasets = ('base', 'prior_extension')
urls: ClassVar[dict[str, str]] = {
'base': 'https://lilablobssc.blob.core.windows.net/lcmcvpr2019/cvpr_chesapeake_landcover.zip',
'prior_extension': 'https://zenodo.org/record/5866525/files/cvpr_chesapeake_landcover_prior_extension.zip?download=1',
}
filenames = {
filenames: ClassVar[dict[str, str]] = {
'base': 'cvpr_chesapeake_landcover.zip',
'prior_extension': 'cvpr_chesapeake_landcover_prior_extension.zip',
}
md5s = {
md5s: ClassVar[dict[str, str]] = {
'base': '1225ccbb9590e9396875f221e5031514',
'prior_extension': '402f41d07823c8faf7ea6960d7c4e17a',
}
@ -354,7 +354,7 @@ class ChesapeakeCVPR(GeoDataset):
crs = CRS.from_epsg(3857)
res = 1
lc_cmap = {
lc_cmap: ClassVar[dict[int, tuple[int, int, int, int]]] = {
0: (0, 0, 0, 0),
1: (0, 197, 255, 255),
2: (38, 115, 0, 255),
@ -374,7 +374,7 @@ class ChesapeakeCVPR(GeoDataset):
]
)
valid_layers = [
valid_layers = (
'naip-new',
'naip-old',
'landsat-leaf-on',
@ -383,8 +383,8 @@ class ChesapeakeCVPR(GeoDataset):
'lc',
'buildings',
'prior_from_cooccurrences_101_31_no_osm_no_buildings',
]
states = ['de', 'md', 'va', 'wv', 'pa', 'ny']
)
states = ('de', 'md', 'va', 'wv', 'pa', 'ny')
splits = (
[f'{state}-train' for state in states]
+ [f'{state}-val' for state in states]
@ -392,7 +392,7 @@ class ChesapeakeCVPR(GeoDataset):
)
# these are used to check the integrity of the dataset
_files = [
_files = (
'de_1m_2013_extended-debuffered-test_tiles',
'de_1m_2013_extended-debuffered-train_tiles',
'de_1m_2013_extended-debuffered-val_tiles',
@ -412,18 +412,18 @@ class ChesapeakeCVPR(GeoDataset):
'wv_1m_2014_extended-debuffered-train_tiles',
'wv_1m_2014_extended-debuffered-val_tiles',
'wv_1m_2014_extended-debuffered-val_tiles/m_3708035_ne_17_1_buildings.tif',
'wv_1m_2014_extended-debuffered-val_tiles/m_3708035_ne_17_1_landsat-leaf-off.tif', # noqa: E501
'wv_1m_2014_extended-debuffered-val_tiles/m_3708035_ne_17_1_landsat-leaf-on.tif', # noqa: E501
'wv_1m_2014_extended-debuffered-val_tiles/m_3708035_ne_17_1_landsat-leaf-off.tif',
'wv_1m_2014_extended-debuffered-val_tiles/m_3708035_ne_17_1_landsat-leaf-on.tif',
'wv_1m_2014_extended-debuffered-val_tiles/m_3708035_ne_17_1_lc.tif',
'wv_1m_2014_extended-debuffered-val_tiles/m_3708035_ne_17_1_naip-new.tif',
'wv_1m_2014_extended-debuffered-val_tiles/m_3708035_ne_17_1_naip-old.tif',
'wv_1m_2014_extended-debuffered-val_tiles/m_3708035_ne_17_1_nlcd.tif',
'wv_1m_2014_extended-debuffered-val_tiles/m_3708035_ne_17_1_prior_from_cooccurrences_101_31_no_osm_no_buildings.tif', # noqa: E501
'wv_1m_2014_extended-debuffered-val_tiles/m_3708035_ne_17_1_prior_from_cooccurrences_101_31_no_osm_no_buildings.tif',
'spatial_index.geojson',
]
)
p_src_crs = pyproj.CRS('epsg:3857')
p_transformers = {
p_transformers: ClassVar[dict[str, CRS]] = {
'epsg:26917': pyproj.Transformer.from_crs(
p_src_crs, pyproj.CRS('epsg:26917'), always_xy=True
).transform,
@ -511,7 +511,7 @@ class ChesapeakeCVPR(GeoDataset):
'lc': row['properties']['lc'],
'nlcd': row['properties']['nlcd'],
'buildings': row['properties']['buildings'],
'prior_from_cooccurrences_101_31_no_osm_no_buildings': prior_fn, # noqa: E501
'prior_from_cooccurrences_101_31_no_osm_no_buildings': prior_fn,
},
)

Просмотреть файл

@ -5,6 +5,7 @@
import os
from collections.abc import Callable, Sequence
from typing import ClassVar
import matplotlib.pyplot as plt
import numpy as np
@ -55,9 +56,9 @@ class CloudCoverDetection(NonGeoDataset):
"""
url = 'https://radiantearth.blob.core.windows.net/mlhub/ref_cloud_cover_detection_challenge_v1/final'
all_bands = ['B02', 'B03', 'B04', 'B08']
rgb_bands = ['B04', 'B03', 'B02']
splits = {'train': 'public', 'test': 'private'}
all_bands = ('B02', 'B03', 'B04', 'B08')
rgb_bands = ('B04', 'B03', 'B02')
splits: ClassVar[dict[str, str]] = {'train': 'public', 'test': 'private'}
def __init__(
self,

Просмотреть файл

@ -42,7 +42,7 @@ class CMSGlobalMangroveCanopy(RasterDataset):
zipfile = 'CMS_Global_Map_Mangrove_Canopy_1665.zip'
md5 = '3e7f9f23bf971c25e828b36e6c5496e3'
all_countries = [
all_countries = (
'AndamanAndNicobar',
'Angola',
'Anguilla',
@ -164,9 +164,9 @@ class CMSGlobalMangroveCanopy(RasterDataset):
'VirginIslandsUs',
'WallisAndFutuna',
'Yemen',
]
)
measurements = ['agb', 'hba95', 'hmax95']
measurements = ('agb', 'hba95', 'hmax95')
def __init__(
self,

Просмотреть файл

@ -50,12 +50,12 @@ class COWC(NonGeoDataset, abc.ABC):
@property
@abc.abstractmethod
def filenames(self) -> list[str]:
def filenames(self) -> tuple[str, ...]:
"""List of files to download."""
@property
@abc.abstractmethod
def md5s(self) -> list[str]:
def md5s(self) -> tuple[str, ...]:
"""List of MD5 checksums of files to download."""
@property
@ -239,7 +239,7 @@ class COWCCounting(COWC):
base_url = (
'https://gdo152.llnl.gov/cowc/download/cowc/datasets/patch_sets/counting/'
)
filenames = [
filenames = (
'COWC_train_list_64_class.txt.bz2',
'COWC_test_list_64_class.txt.bz2',
'COWC_Counting_Toronto_ISPRS.tbz',
@ -248,8 +248,8 @@ class COWCCounting(COWC):
'COWC_Counting_Vaihingen_ISPRS.tbz',
'COWC_Counting_Columbus_CSUAV_AFRL.tbz',
'COWC_Counting_Utah_AGRC.tbz',
]
md5s = [
)
md5s = (
'187543d20fa6d591b8da51136e8ef8fb',
'930cfd6e160a7b36db03146282178807',
'bc2613196dfa93e66d324ae43e7c1fdb',
@ -258,7 +258,7 @@ class COWCCounting(COWC):
'4009c1e420566390746f5b4db02afdb9',
'daf8033c4e8ceebbf2c3cac3fabb8b10',
'777ec107ed2a3d54597a739ce74f95ad',
]
)
filename = 'COWC_{}_list_64_class.txt'
@ -268,7 +268,7 @@ class COWCDetection(COWC):
base_url = (
'https://gdo152.llnl.gov/cowc/download/cowc/datasets/patch_sets/detection/'
)
filenames = [
filenames = (
'COWC_train_list_detection.txt.bz2',
'COWC_test_list_detection.txt.bz2',
'COWC_Detection_Toronto_ISPRS.tbz',
@ -277,8 +277,8 @@ class COWCDetection(COWC):
'COWC_Detection_Vaihingen_ISPRS.tbz',
'COWC_Detection_Columbus_CSUAV_AFRL.tbz',
'COWC_Detection_Utah_AGRC.tbz',
]
md5s = [
)
md5s = (
'c954a5a3dac08c220b10cfbeec83893c',
'c6c2d0a78f12a2ad88b286b724a57c1a',
'11af24f43b198b0f13c8e94814008a48',
@ -287,7 +287,7 @@ class COWCDetection(COWC):
'23945d5b22455450a938382ccc2a8b27',
'f40522dc97bea41b10117d4a5b946a6f',
'195da7c9443a939a468c9f232fd86ee3',
]
)
filename = 'COWC_{}_list_detection.txt'

Просмотреть файл

@ -7,6 +7,7 @@ import glob
import json
import os
from collections.abc import Callable
from typing import ClassVar
import matplotlib.pyplot as plt
import numpy as np
@ -55,7 +56,7 @@ class CropHarvest(NonGeoDataset):
"""
# https://github.com/nasaharvest/cropharvest/blob/main/cropharvest/bands.py
all_bands = [
all_bands = (
'VV',
'VH',
'B2',
@ -74,12 +75,12 @@ class CropHarvest(NonGeoDataset):
'elevation',
'slope',
'NDVI',
]
rgb_bands = ['B4', 'B3', 'B2']
)
rgb_bands = ('B4', 'B3', 'B2')
features_url = 'https://zenodo.org/records/7257688/files/features.tar.gz?download=1'
labels_url = 'https://zenodo.org/records/7257688/files/labels.geojson?download=1'
file_dict = {
file_dict: ClassVar[dict[str, dict[str, str]]] = {
'features': {
'url': features_url,
'filename': 'features.tar.gz',

Просмотреть файл

@ -65,8 +65,8 @@ class CV4AKenyaCropType(NonGeoDataset):
"""
url = 'https://radiantearth.blob.core.windows.net/mlhub/kenya-crop-challenge'
tiles = list(map(str, range(4)))
dates = [
tiles = tuple(map(str, range(4)))
dates = (
'20190606',
'20190701',
'20190706',
@ -80,7 +80,7 @@ class CV4AKenyaCropType(NonGeoDataset):
'20190924',
'20191004',
'20191103',
]
)
all_bands = (
'B01',
'B02',
@ -96,7 +96,7 @@ class CV4AKenyaCropType(NonGeoDataset):
'B12',
'CLD',
)
rgb_bands = ['B04', 'B03', 'B02']
rgb_bands = ('B04', 'B03', 'B02')
# Same for all tiles
tile_height = 3035
@ -141,11 +141,13 @@ class CV4AKenyaCropType(NonGeoDataset):
# Calculate the indices that we will use over all tiles
self.chips_metadata = []
for tile_index in range(len(self.tiles)):
for y in list(range(0, self.tile_height - self.chip_size, stride)) + [
self.tile_height - self.chip_size
for y in [
*list(range(0, self.tile_height - self.chip_size, stride)),
self.tile_height - self.chip_size,
]:
for x in list(range(0, self.tile_width - self.chip_size, stride)) + [
self.tile_width - self.chip_size
for x in [
*list(range(0, self.tile_width - self.chip_size, stride)),
self.tile_width - self.chip_size,
]:
self.chips_metadata.append((tile_index, y, x))

Просмотреть файл

@ -74,13 +74,13 @@ class DeepGlobeLandCover(NonGeoDataset):
$ unzip deepglobe2018-landcover-segmentation-traindataset.zip
.. versionadded:: 0.3
""" # noqa: E501
"""
filename = 'data.zip'
data_root = 'data'
md5 = 'f32684b0b2bf6f8d604cd359a399c061'
splits = ['train', 'test']
classes = [
splits = ('train', 'test')
classes = (
'Urban land',
'Agriculture land',
'Rangeland',
@ -88,8 +88,8 @@ class DeepGlobeLandCover(NonGeoDataset):
'Water',
'Barren land',
'Unknown',
]
colormap = [
)
colormap = (
(0, 255, 255),
(255, 255, 0),
(255, 0, 255),
@ -97,7 +97,7 @@ class DeepGlobeLandCover(NonGeoDataset):
(0, 0, 255),
(255, 255, 255),
(0, 0, 0),
]
)
def __init__(
self,
@ -246,12 +246,15 @@ class DeepGlobeLandCover(NonGeoDataset):
"""
ncols = 1
image1 = draw_semantic_segmentation_masks(
sample['image'], sample['mask'], alpha=alpha, colors=self.colormap
sample['image'], sample['mask'], alpha=alpha, colors=list(self.colormap)
)
if 'prediction' in sample:
ncols += 1
image2 = draw_semantic_segmentation_masks(
sample['image'], sample['prediction'], alpha=alpha, colors=self.colormap
sample['image'],
sample['prediction'],
alpha=alpha,
colors=list(self.colormap),
)
fig, axs = plt.subplots(ncols=ncols, figsize=(ncols * 10, 10))

Просмотреть файл

@ -6,6 +6,7 @@
import glob
import os
from collections.abc import Callable, Sequence
from typing import ClassVar
import matplotlib.pyplot as plt
import numpy as np
@ -75,9 +76,9 @@ class DFC2022(NonGeoDataset):
* https://doi.org/10.1007/s10994-020-05943-y
.. versionadded:: 0.3
""" # noqa: E501
"""
classes = [
classes = (
'No information',
'Urban fabric',
'Industrial, commercial, public, military, private and transport units',
@ -94,8 +95,8 @@ class DFC2022(NonGeoDataset):
'Wetlands',
'Water',
'Clouds and Shadows',
]
colormap = [
)
colormap = (
'#231F20',
'#DB5F57',
'#DB9757',
@ -112,8 +113,8 @@ class DFC2022(NonGeoDataset):
'#579BDB',
'#0062FF',
'#231F20',
]
metadata = {
)
metadata: ClassVar[dict[str, dict[str, str]]] = {
'train': {
'filename': 'labeled_train.zip',
'md5': '2e87d6a218e466dd0566797d7298c7a9',

Просмотреть файл

@ -6,7 +6,7 @@
import os
import sys
from collections.abc import Callable, Sequence
from typing import Any, cast
from typing import Any, ClassVar, cast
import fiona
import matplotlib.pyplot as plt
@ -54,9 +54,9 @@ class EnviroAtlas(GeoDataset):
crs = CRS.from_epsg(3857)
res = 1
valid_prior_layers = ['prior', 'prior_no_osm_no_buildings']
valid_prior_layers = ('prior', 'prior_no_osm_no_buildings')
valid_layers = [
valid_layers = (
'naip',
'nlcd',
'roads',
@ -65,14 +65,15 @@ class EnviroAtlas(GeoDataset):
'waterbodies',
'buildings',
'lc',
] + valid_prior_layers
*valid_prior_layers,
)
cities = [
cities = (
'pittsburgh_pa-2010_1m',
'durham_nc-2012_1m',
'austin_tx-2012_1m',
'phoenix_az-2010_1m',
]
)
splits = (
[f'{state}-train' for state in cities[:1]]
+ [f'{state}-val' for state in cities[:1]]
@ -81,7 +82,7 @@ class EnviroAtlas(GeoDataset):
)
# these are used to check the integrity of the dataset
_files = [
_files = (
'austin_tx-2012_1m-test_tiles-debuffered',
'austin_tx-2012_1m-val5_tiles-debuffered',
'durham_nc-2012_1m-test_tiles-debuffered',
@ -100,13 +101,13 @@ class EnviroAtlas(GeoDataset):
'austin_tx-2012_1m-test_tiles-debuffered/3009726_sw_d_water.tif',
'austin_tx-2012_1m-test_tiles-debuffered/3009726_sw_e_buildings.tif',
'austin_tx-2012_1m-test_tiles-debuffered/3009726_sw_h_highres_labels.tif',
'austin_tx-2012_1m-test_tiles-debuffered/3009726_sw_prior_from_cooccurrences_101_31.tif', # noqa: E501
'austin_tx-2012_1m-test_tiles-debuffered/3009726_sw_prior_from_cooccurrences_101_31_no_osm_no_buildings.tif', # noqa: E501
'austin_tx-2012_1m-test_tiles-debuffered/3009726_sw_prior_from_cooccurrences_101_31.tif',
'austin_tx-2012_1m-test_tiles-debuffered/3009726_sw_prior_from_cooccurrences_101_31_no_osm_no_buildings.tif',
'spatial_index.geojson',
]
)
p_src_crs = pyproj.CRS('epsg:3857')
p_transformers = {
p_transformers: ClassVar[dict[str, CRS]] = {
'epsg:26917': pyproj.Transformer.from_crs(
p_src_crs, pyproj.CRS('epsg:26917'), always_xy=True
).transform,
@ -222,7 +223,7 @@ class EnviroAtlas(GeoDataset):
dtype=np.uint8,
)
highres_classes = [
highres_classes = (
'Unclassified',
'Water',
'Impervious Surface',
@ -234,7 +235,7 @@ class EnviroAtlas(GeoDataset):
'Orchards',
'Woody Wetlands',
'Emergent Wetlands',
]
)
highres_cmap = ListedColormap(
[
[1.00000000, 1.00000000, 1.00000000],

Просмотреть файл

@ -6,6 +6,7 @@
import glob
import os
from collections.abc import Callable
from typing import ClassVar
import matplotlib.pyplot as plt
import numpy as np
@ -56,9 +57,9 @@ class ETCI2021(NonGeoDataset):
the ETCI competition.
"""
bands = ['VV', 'VH']
masks = ['flood', 'water_body']
metadata = {
bands = ('VV', 'VH')
masks = ('flood', 'water_body')
metadata: ClassVar[dict[str, dict[str, str]]] = {
'train': {
'filename': 'train.zip',
'md5': '1e95792fe0f6e3c9000abdeab2a8ab0f',

Просмотреть файл

@ -7,7 +7,7 @@ import glob
import os
import pathlib
from collections.abc import Callable, Iterable
from typing import Any
from typing import Any, ClassVar
import matplotlib.pyplot as plt
from matplotlib.figure import Figure
@ -53,7 +53,7 @@ class EUDEM(RasterDataset):
zipfile_glob = 'eu_dem_v11_*[A-Z0-9].zip'
filename_regex = '(?P<name>[eudem_v11]{10})_(?P<id>[A-Z0-9]{6})'
md5s = {
md5s: ClassVar[dict[str, str]] = {
'eu_dem_v11_E00N20.zip': '96edc7e11bc299b994e848050d6be591',
'eu_dem_v11_E10N00.zip': 'e14be147ac83eddf655f4833d55c1571',
'eu_dem_v11_E10N10.zip': '2eb5187e4d827245b33768404529c709',

Просмотреть файл

@ -61,7 +61,7 @@ class EuroCrops(VectorDataset):
date_format = '%Y'
# Filename and md5 of files in this dataset on zenodo.
zenodo_files = [
zenodo_files: tuple[tuple[str, str], ...] = (
('AT_2021.zip', '490241df2e3d62812e572049fc0c36c5'),
('BE_VLG_2021.zip', 'ac4b9e12ad39b1cba47fdff1a786c2d7'),
('DE_LS_2021.zip', '6d94e663a3ff7988b32cb36ea24a724f'),
@ -81,7 +81,7 @@ class EuroCrops(VectorDataset):
# Year is unknown for Romania portion (ny = no year).
# We skip since it is inconsistent with the rest of the data.
# ("RO_ny.zip", "648e1504097765b4b7f825decc838882"),
]
)
def __init__(
self,

Просмотреть файл

@ -5,7 +5,7 @@
import os
from collections.abc import Callable, Sequence
from typing import cast
from typing import ClassVar, cast
import matplotlib.pyplot as plt
import numpy as np
@ -54,7 +54,7 @@ class EuroSAT(NonGeoClassificationDataset):
* https://ieeexplore.ieee.org/document/8519248
"""
url = 'https://hf.co/datasets/torchgeo/eurosat/resolve/06fd1b090bceecc0ce724cd21578ba7a6664fe8d/EuroSATallBands.zip' # noqa: E501
url = 'https://hf.co/datasets/torchgeo/eurosat/resolve/06fd1b090bceecc0ce724cd21578ba7a6664fe8d/EuroSATallBands.zip'
filename = 'EuroSATallBands.zip'
md5 = '5ac12b3b2557aa56e1826e981e8e200e'
@ -63,13 +63,13 @@ class EuroSAT(NonGeoClassificationDataset):
'ds', 'images', 'remote_sensing', 'otherDatasets', 'sentinel_2', 'tif'
)
splits = ['train', 'val', 'test']
split_urls = {
'train': 'https://storage.googleapis.com/remote_sensing_representations/eurosat-train.txt', # noqa: E501
'val': 'https://storage.googleapis.com/remote_sensing_representations/eurosat-val.txt', # noqa: E501
'test': 'https://storage.googleapis.com/remote_sensing_representations/eurosat-test.txt', # noqa: E501
splits = ('train', 'val', 'test')
split_urls: ClassVar[dict[str, str]] = {
'train': 'https://storage.googleapis.com/remote_sensing_representations/eurosat-train.txt',
'val': 'https://storage.googleapis.com/remote_sensing_representations/eurosat-val.txt',
'test': 'https://storage.googleapis.com/remote_sensing_representations/eurosat-test.txt',
}
split_md5s = {
split_md5s: ClassVar[dict[str, str]] = {
'train': '908f142e73d6acdf3f482c5e80d851b1',
'val': '95de90f2aa998f70a3b2416bfe0687b4',
'test': '7ae5ab94471417b6e315763121e67c5f',
@ -93,7 +93,10 @@ class EuroSAT(NonGeoClassificationDataset):
rgb_bands = ('B04', 'B03', 'B02')
BAND_SETS = {'all': all_band_names, 'rgb': rgb_bands}
BAND_SETS: ClassVar[dict[str, tuple[str, ...]]] = {
'all': all_band_names,
'rgb': rgb_bands,
}
def __init__(
self,
@ -302,12 +305,12 @@ class EuroSATSpatial(EuroSAT):
.. versionadded:: 0.6
"""
split_urls = {
split_urls: ClassVar[dict[str, str]] = {
'train': 'https://hf.co/datasets/torchgeo/eurosat/resolve/1c11c73a87b40b0485d103231a97829991b8e22f/eurosat-spatial-train.txt',
'val': 'https://hf.co/datasets/torchgeo/eurosat/resolve/1c11c73a87b40b0485d103231a97829991b8e22f/eurosat-spatial-val.txt',
'test': 'https://hf.co/datasets/torchgeo/eurosat/resolve/1c11c73a87b40b0485d103231a97829991b8e22f/eurosat-spatial-test.txt',
}
split_md5s = {
split_md5s: ClassVar[dict[str, str]] = {
'train': '7be3254be39f23ce4d4d144290c93292',
'val': 'acf392290050bb3df790dc8fc0ebf193',
'test': '5ec1733f9c16116bf0aa2d921fc613ef',
@ -325,16 +328,16 @@ class EuroSAT100(EuroSAT):
.. versionadded:: 0.5
"""
url = 'https://hf.co/datasets/torchgeo/eurosat/resolve/06fd1b090bceecc0ce724cd21578ba7a6664fe8d/EuroSAT100.zip' # noqa: E501
url = 'https://hf.co/datasets/torchgeo/eurosat/resolve/06fd1b090bceecc0ce724cd21578ba7a6664fe8d/EuroSAT100.zip'
filename = 'EuroSAT100.zip'
md5 = 'c21c649ba747e86eda813407ef17d596'
split_urls = {
'train': 'https://hf.co/datasets/torchgeo/eurosat/resolve/06fd1b090bceecc0ce724cd21578ba7a6664fe8d/eurosat-train.txt', # noqa: E501
'val': 'https://hf.co/datasets/torchgeo/eurosat/resolve/06fd1b090bceecc0ce724cd21578ba7a6664fe8d/eurosat-val.txt', # noqa: E501
'test': 'https://hf.co/datasets/torchgeo/eurosat/resolve/06fd1b090bceecc0ce724cd21578ba7a6664fe8d/eurosat-test.txt', # noqa: E501
split_urls: ClassVar[dict[str, str]] = {
'train': 'https://hf.co/datasets/torchgeo/eurosat/resolve/06fd1b090bceecc0ce724cd21578ba7a6664fe8d/eurosat-train.txt',
'val': 'https://hf.co/datasets/torchgeo/eurosat/resolve/06fd1b090bceecc0ce724cd21578ba7a6664fe8d/eurosat-val.txt',
'test': 'https://hf.co/datasets/torchgeo/eurosat/resolve/06fd1b090bceecc0ce724cd21578ba7a6664fe8d/eurosat-test.txt',
}
split_md5s = {
split_md5s: ClassVar[dict[str, str]] = {
'train': '033d0c23e3a75e3fa79618b0e35fe1c7',
'val': '3e3f8b3c344182b8d126c4cc88f3f215',
'test': 'f908f151b950f270ad18e61153579794',

Просмотреть файл

@ -6,7 +6,7 @@
import glob
import os
from collections.abc import Callable
from typing import Any, cast
from typing import Any, ClassVar, cast
from xml.etree.ElementTree import Element, parse
import matplotlib.patches as patches
@ -119,7 +119,7 @@ class FAIR1M(NonGeoDataset):
.. versionadded:: 0.2
"""
classes = {
classes: ClassVar[dict[str, dict[str, Any]]] = {
'Passenger Ship': {'id': 0, 'category': 'Ship'},
'Motorboat': {'id': 1, 'category': 'Ship'},
'Fishing Boat': {'id': 2, 'category': 'Ship'},
@ -159,12 +159,12 @@ class FAIR1M(NonGeoDataset):
'Bridge': {'id': 36, 'category': 'Road'},
}
filename_glob = {
filename_glob: ClassVar[dict[str, str]] = {
'train': os.path.join('train', '**', 'images', '*.tif'),
'val': os.path.join('validation', 'images', '*.tif'),
'test': os.path.join('test', 'images', '*.tif'),
}
directories = {
directories: ClassVar[dict[str, tuple[str, ...]]] = {
'train': (
os.path.join('train', 'part1', 'images'),
os.path.join('train', 'part1', 'labelXml'),
@ -175,9 +175,9 @@ class FAIR1M(NonGeoDataset):
os.path.join('validation', 'images'),
os.path.join('validation', 'labelXml'),
),
'test': (os.path.join('test', 'images')),
'test': (os.path.join('test', 'images'),),
}
paths = {
paths: ClassVar[dict[str, tuple[str, ...]]] = {
'train': (
os.path.join('train', 'part1', 'images.zip'),
os.path.join('train', 'part1', 'labelXml.zip'),
@ -194,7 +194,7 @@ class FAIR1M(NonGeoDataset):
os.path.join('test', 'images2.zip'),
),
}
urls = {
urls: ClassVar[dict[str, tuple[str, ...]]] = {
'train': (
'https://drive.google.com/file/d/1LWT_ybL-s88Lzg9A9wHpj0h2rJHrqrVf',
'https://drive.google.com/file/d/1CnOuS8oX6T9JMqQnfFsbmf7U38G6Vc8u',
@ -211,7 +211,7 @@ class FAIR1M(NonGeoDataset):
'https://drive.google.com/file/d/1oUc25FVf8Zcp4pzJ31A1j1sOLNHu63P0',
),
}
md5s = {
md5s: ClassVar[dict[str, tuple[str, ...]]] = {
'train': (
'a460fe6b1b5b276bf856ce9ac72d6568',
'80f833ff355f91445c92a0c0c1fa7414',

Просмотреть файл

@ -55,8 +55,8 @@ class FireRisk(NonGeoClassificationDataset):
md5 = 'a77b9a100d51167992ae8c51d26198a6'
filename = 'FireRisk.zip'
directory = 'FireRisk'
splits = ['train', 'val']
classes = [
splits = ('train', 'val')
classes = (
'High',
'Low',
'Moderate',
@ -64,7 +64,7 @@ class FireRisk(NonGeoClassificationDataset):
'Very_High',
'Very_Low',
'Water',
]
)
def __init__(
self,

Просмотреть файл

@ -96,11 +96,8 @@ class ForestDamage(NonGeoDataset):
.. versionadded:: 0.3
"""
classes = ['other', 'H', 'LD', 'HD']
url = (
'https://lilablobssc.blob.core.windows.net/larch-casebearer/'
'Data_Set_Larch_Casebearer.zip'
)
classes = ('other', 'H', 'LD', 'HD')
url = 'https://lilablobssc.blob.core.windows.net/larch-casebearer/Data_Set_Larch_Casebearer.zip'
data_dir = 'Data_Set_Larch_Casebearer'
md5 = '907815bcc739bff89496fac8f8ce63d7'

Просмотреть файл

@ -13,7 +13,7 @@ import re
import sys
import warnings
from collections.abc import Callable, Iterable, Sequence
from typing import Any, cast
from typing import Any, ClassVar, cast
import fiona
import fiona.transform
@ -370,13 +370,13 @@ class RasterDataset(GeoDataset):
separate_files = False
#: Names of all available bands in the dataset
all_bands: list[str] = []
all_bands: tuple[str, ...] = ()
#: Names of RGB bands in the dataset, used for plotting
rgb_bands: list[str] = []
rgb_bands: tuple[str, ...] = ()
#: Color map for the dataset, used for plotting
cmap: dict[int, tuple[int, int, int, int]] = {}
cmap: ClassVar[dict[int, tuple[int, int, int, int]]] = {}
@property
def dtype(self) -> torch.dtype:
@ -458,7 +458,7 @@ class RasterDataset(GeoDataset):
# See if file has a color map
if len(self.cmap) == 0:
try:
self.cmap = src.colormap(1)
self.cmap = src.colormap(1) # type: ignore[misc]
except ValueError:
pass

Просмотреть файл

@ -66,8 +66,8 @@ class GID15(NonGeoDataset):
md5 = '615682bf659c3ed981826c6122c10c83'
filename = 'gid-15.zip'
directory = 'GID'
splits = ['train', 'val', 'test']
classes = [
splits = ('train', 'val', 'test')
classes = (
'background',
'industrial_land',
'urban_residential',
@ -84,7 +84,7 @@ class GID15(NonGeoDataset):
'river',
'lake',
'pond',
]
)
def __init__(
self,

Просмотреть файл

@ -7,7 +7,7 @@ import glob
import os
import pathlib
from collections.abc import Callable, Iterable
from typing import Any, cast
from typing import Any, ClassVar, cast
import matplotlib.pyplot as plt
import torch
@ -73,9 +73,9 @@ class GlobBiomass(RasterDataset):
is_image = False
dtype = torch.float32 # pixelwise regression
measurements = ['agb', 'gsv']
measurements = ('agb', 'gsv')
md5s = {
md5s: ClassVar[dict[str, str]] = {
'N00E020_agb.zip': 'bd83a3a4c143885d1962bde549413be6',
'N00E020_gsv.zip': 'da5ddb88e369df2d781a0c6be008ae79',
'N00E060_agb.zip': '85eaca95b939086cc528e396b75bd097',

Просмотреть файл

@ -6,7 +6,7 @@
import glob
import os
from collections.abc import Callable
from typing import Any, cast, overload
from typing import Any, ClassVar, cast, overload
import fiona
import matplotlib.pyplot as plt
@ -100,7 +100,7 @@ class IDTReeS(NonGeoDataset):
.. versionadded:: 0.2
"""
classes = {
classes: ClassVar[dict[str, str]] = {
'ACPE': 'Acer pensylvanicum L.',
'ACRU': 'Acer rubrum L.',
'ACSA3': 'Acer saccharum Marshall',
@ -135,19 +135,22 @@ class IDTReeS(NonGeoDataset):
'ROPS': 'Robinia pseudoacacia L.',
'TSCA': 'Tsuga canadensis (L.) Carriere',
}
metadata = {
metadata: ClassVar[dict[str, dict[str, str]]] = {
'train': {
'url': 'https://zenodo.org/record/3934932/files/IDTREES_competition_train_v2.zip?download=1', # noqa: E501
'url': 'https://zenodo.org/record/3934932/files/IDTREES_competition_train_v2.zip?download=1',
'md5': '5ddfa76240b4bb6b4a7861d1d31c299c',
'filename': 'IDTREES_competition_train_v2.zip',
},
'test': {
'url': 'https://zenodo.org/record/3934932/files/IDTREES_competition_test_v2.zip?download=1', # noqa: E501
'url': 'https://zenodo.org/record/3934932/files/IDTREES_competition_test_v2.zip?download=1',
'md5': 'b108931c84a70f2a38a8234290131c9b',
'filename': 'IDTREES_competition_test_v2.zip',
},
}
directories = {'train': ['train'], 'test': ['task1', 'task2']}
directories: ClassVar[dict[str, list[str]]] = {
'train': ['train'],
'test': ['task1', 'task2'],
}
image_size = (200, 200)
def __init__(

Просмотреть файл

@ -6,7 +6,7 @@
import glob
import os
from collections.abc import Callable, Sequence
from typing import Any
from typing import Any, ClassVar
import matplotlib.pyplot as plt
from matplotlib.figure import Figure
@ -40,9 +40,9 @@ class IOBench(IntersectionDataset):
.. versionadded:: 0.6
"""
url = 'https://hf.co/datasets/torchgeo/io/resolve/c9d9d268cf0b61335941bdc2b6963bf16fc3a6cf/{}.tar.gz' # noqa: E501
url = 'https://hf.co/datasets/torchgeo/io/resolve/c9d9d268cf0b61335941bdc2b6963bf16fc3a6cf/{}.tar.gz'
md5s = {
md5s: ClassVar[dict[str, str]] = {
'original': 'e3a908a0fd1c05c1af2f4c65724d59b3',
'raw': 'e9603990441007ce7bba73bb8ba7d217',
'preprocessed': '9801f1240b238cb17525c865e413d1fd',
@ -54,7 +54,7 @@ class IOBench(IntersectionDataset):
split: str = 'preprocessed',
crs: CRS | None = None,
res: float | None = None,
bands: Sequence[str] | None = Landsat9.default_bands + ['SR_QA_AEROSOL'],
bands: Sequence[str] | None = [*Landsat9.default_bands, 'SR_QA_AEROSOL'],
classes: list[int] = [0],
transforms: Callable[[dict[str, Any]], dict[str, Any]] | None = None,
cache: bool = True,

Просмотреть файл

@ -8,7 +8,7 @@ import os
import pathlib
import re
from collections.abc import Callable, Iterable, Sequence
from typing import Any, cast
from typing import Any, ClassVar, cast
import matplotlib.pyplot as plt
import torch
@ -43,8 +43,8 @@ class L7IrishImage(RasterDataset):
"""
date_format = '%Y%m%d'
is_image = True
rgb_bands = ['B30', 'B20', 'B10']
all_bands = ['B10', 'B20', 'B30', 'B40', 'B50', 'B61', 'B62', 'B70', 'B80']
rgb_bands = ('B30', 'B20', 'B10')
all_bands = ('B10', 'B20', 'B30', 'B40', 'B50', 'B61', 'B62', 'B70', 'B80')
class L7IrishMask(RasterDataset):
@ -59,7 +59,7 @@ class L7IrishMask(RasterDataset):
_newmask2015\.TIF$
"""
is_image = False
classes = ['Fill', 'Cloud Shadow', 'Clear', 'Thin Cloud', 'Cloud']
classes = ('Fill', 'Cloud Shadow', 'Clear', 'Thin Cloud', 'Cloud')
ordinal_map = torch.zeros(256, dtype=torch.long)
ordinal_map[64] = 1
ordinal_map[128] = 2
@ -158,11 +158,11 @@ class L7Irish(IntersectionDataset):
* https://www.sciencebase.gov/catalog/item/573ccf18e4b0dae0d5e4b109
.. versionadded:: 0.5
""" # noqa: E501
"""
url = 'https://hf.co/datasets/torchgeo/l7irish/resolve/6807e0b22eca7f9a8a3903ea673b31a115837464/{}.tar.gz' # noqa: E501
url = 'https://hf.co/datasets/torchgeo/l7irish/resolve/6807e0b22eca7f9a8a3903ea673b31a115837464/{}.tar.gz'
md5s = {
md5s: ClassVar[dict[str, str]] = {
'austral': '0a34770b992a62abeb88819feb192436',
'boreal': 'b7cfdd689a3c2fd2a8d572e1c10ed082',
'mid_latitude_north': 'c40abe5ad2487f8ab021cfb954982faa',

Просмотреть файл

@ -7,7 +7,7 @@ import glob
import os
import pathlib
from collections.abc import Callable, Iterable, Sequence
from typing import Any
from typing import Any, ClassVar
import matplotlib.pyplot as plt
import torch
@ -36,8 +36,8 @@ class L8BiomeImage(RasterDataset):
"""
date_format = '%Y%j'
is_image = True
rgb_bands = ['B4', 'B3', 'B2']
all_bands = ['B1', 'B2', 'B3', 'B4', 'B5', 'B6', 'B7', 'B8', 'B9', 'B10', 'B11']
rgb_bands = ('B4', 'B3', 'B2')
all_bands = ('B1', 'B2', 'B3', 'B4', 'B5', 'B6', 'B7', 'B8', 'B9', 'B10', 'B11')
class L8BiomeMask(RasterDataset):
@ -57,7 +57,7 @@ class L8BiomeMask(RasterDataset):
"""
date_format = '%Y%j'
is_image = False
classes = ['Fill', 'Cloud Shadow', 'Clear', 'Thin Cloud', 'Cloud']
classes = ('Fill', 'Cloud Shadow', 'Clear', 'Thin Cloud', 'Cloud')
ordinal_map = torch.zeros(256, dtype=torch.long)
ordinal_map[64] = 1
ordinal_map[128] = 2
@ -116,11 +116,11 @@ class L8Biome(IntersectionDataset):
* https://doi.org/10.1016/j.rse.2017.03.026
.. versionadded:: 0.5
""" # noqa: E501
"""
url = 'https://hf.co/datasets/torchgeo/l8biome/resolve/f76df19accce34d2acc1878d88b9491bc81f94c8/{}.tar.gz' # noqa: E501
url = 'https://hf.co/datasets/torchgeo/l8biome/resolve/f76df19accce34d2acc1878d88b9491bc81f94c8/{}.tar.gz'
md5s = {
md5s: ClassVar[dict[str, str]] = {
'barren': '0eb691822d03dabd4f5ea8aadd0b41c3',
'forest': '4a5645596f6bb8cea44677f746ec676e',
'grass_crops': 'a69ed5d6cb227c5783f026b9303cdd3c',

Просмотреть файл

@ -9,7 +9,7 @@ import hashlib
import os
from collections.abc import Callable
from functools import lru_cache
from typing import Any, cast
from typing import Any, ClassVar, cast
import matplotlib.pyplot as plt
import numpy as np
@ -64,8 +64,8 @@ class LandCoverAIBase(Dataset[dict[str, Any]], abc.ABC):
url = 'https://landcover.ai.linuxpolska.com/download/landcover.ai.v1.zip'
filename = 'landcover.ai.v1.zip'
md5 = '3268c89070e8734b4e91d531c0617e03'
classes = ['Background', 'Building', 'Woodland', 'Water', 'Road']
cmap = {
classes = ('Background', 'Building', 'Woodland', 'Water', 'Road')
cmap: ClassVar[dict[int, tuple[int, int, int, int]]] = {
0: (0, 0, 0, 0),
1: (97, 74, 74, 255),
2: (38, 115, 0, 255),

Просмотреть файл

@ -33,7 +33,7 @@ class Landsat(RasterDataset, abc.ABC):
* `Surface Temperature <https://www.usgs.gov/landsat-missions/landsat-collection-2-surface-temperature>`_
* `Surface Reflectance <https://www.usgs.gov/landsat-missions/landsat-collection-2-surface-reflectance>`_
* `U.S. Analysis Ready Data <https://www.usgs.gov/landsat-missions/landsat-collection-2-us-analysis-ready-data>`_
""" # noqa: E501
"""
# https://www.usgs.gov/landsat-missions/landsat-collection-2
filename_regex = r"""
@ -55,7 +55,7 @@ class Landsat(RasterDataset, abc.ABC):
@property
@abc.abstractmethod
def default_bands(self) -> list[str]:
def default_bands(self) -> tuple[str, ...]:
"""Bands to load by default."""
def __init__(
@ -145,8 +145,8 @@ class Landsat1(Landsat):
filename_glob = 'LM01_*_{}.*'
default_bands = ['B4', 'B5', 'B6', 'B7']
rgb_bands = ['B6', 'B5', 'B4']
default_bands = ('B4', 'B5', 'B6', 'B7')
rgb_bands = ('B6', 'B5', 'B4')
class Landsat2(Landsat1):
@ -166,8 +166,8 @@ class Landsat4MSS(Landsat):
filename_glob = 'LM04_*_{}.*'
default_bands = ['B1', 'B2', 'B3', 'B4']
rgb_bands = ['B3', 'B2', 'B1']
default_bands = ('B1', 'B2', 'B3', 'B4')
rgb_bands = ('B3', 'B2', 'B1')
class Landsat4TM(Landsat):
@ -175,8 +175,8 @@ class Landsat4TM(Landsat):
filename_glob = 'LT04_*_{}.*'
default_bands = ['SR_B1', 'SR_B2', 'SR_B3', 'SR_B4', 'SR_B5', 'SR_B6', 'SR_B7']
rgb_bands = ['SR_B3', 'SR_B2', 'SR_B1']
default_bands = ('SR_B1', 'SR_B2', 'SR_B3', 'SR_B4', 'SR_B5', 'SR_B6', 'SR_B7')
rgb_bands = ('SR_B3', 'SR_B2', 'SR_B1')
class Landsat5MSS(Landsat4MSS):
@ -196,8 +196,8 @@ class Landsat7(Landsat):
filename_glob = 'LE07_*_{}.*'
default_bands = ['SR_B1', 'SR_B2', 'SR_B3', 'SR_B4', 'SR_B5', 'SR_B6', 'SR_B7']
rgb_bands = ['SR_B3', 'SR_B2', 'SR_B1']
default_bands = ('SR_B1', 'SR_B2', 'SR_B3', 'SR_B4', 'SR_B5', 'SR_B6', 'SR_B7')
rgb_bands = ('SR_B3', 'SR_B2', 'SR_B1')
class Landsat8(Landsat):
@ -205,11 +205,11 @@ class Landsat8(Landsat):
filename_glob = 'LC08_*_{}.*'
default_bands = ['SR_B1', 'SR_B2', 'SR_B3', 'SR_B4', 'SR_B5', 'SR_B6', 'SR_B7']
rgb_bands = ['SR_B4', 'SR_B3', 'SR_B2']
default_bands = ('SR_B1', 'SR_B2', 'SR_B3', 'SR_B4', 'SR_B5', 'SR_B6', 'SR_B7')
rgb_bands = ('SR_B4', 'SR_B3', 'SR_B2')
class Landsat9(Landsat8):
"""Landsat 9 Operational Land Imager (OLI-2) and Thermal Infrared Sensor (TIRS-2).""" # noqa: E501
"""Landsat 9 Operational Land Imager (OLI-2) and Thermal Infrared Sensor (TIRS-2)."""
filename_glob = 'LC09_*_{}.*'

Просмотреть файл

@ -7,6 +7,7 @@ import abc
import glob
import os
from collections.abc import Callable
from typing import ClassVar
import matplotlib.pyplot as plt
import numpy as np
@ -26,8 +27,8 @@ class LEVIRCDBase(NonGeoDataset, abc.ABC):
.. versionadded:: 0.6
"""
splits: list[str] | dict[str, dict[str, str]]
directories = ['A', 'B', 'label']
splits: ClassVar[tuple[str, ...] | dict[str, dict[str, str]]]
directories = ('A', 'B', 'label')
def __init__(
self,
@ -237,7 +238,7 @@ class LEVIRCD(LEVIRCDBase):
.. versionadded:: 0.6
"""
splits = {
splits: ClassVar[dict[str, dict[str, str]]] = {
'train': {
'url': 'https://drive.google.com/file/d/18GuoCuBn48oZKAlEo-LrNwABrFhVALU-',
'filename': 'train.zip',
@ -336,7 +337,7 @@ class LEVIRCDPlus(LEVIRCDBase):
md5 = '1adf156f628aa32fb2e8fe6cada16c04'
filename = 'LEVIR-CD+.zip'
directory = 'LEVIR-CD+'
splits = ['train', 'test']
splits = ('train', 'test')
def _load_files(self, root: Path, split: str) -> list[dict[str, str]]:
"""Return the paths of the files in the dataset.

Просмотреть файл

@ -5,7 +5,8 @@
import glob
import os
from collections.abc import Callable
from collections.abc import Callable, Sequence
from typing import ClassVar
import matplotlib.pyplot as plt
import numpy as np
@ -57,10 +58,10 @@ class LoveDA(NonGeoDataset):
.. versionadded:: 0.2
"""
scenes = ['urban', 'rural']
splits = ['train', 'val', 'test']
scenes = ('urban', 'rural')
splits = ('train', 'val', 'test')
info_dict = {
info_dict: ClassVar[dict[str, dict[str, str]]] = {
'train': {
'url': 'https://zenodo.org/record/5706578/files/Train.zip?download=1',
'filename': 'Train.zip',
@ -78,7 +79,7 @@ class LoveDA(NonGeoDataset):
},
}
classes = [
classes = (
'background',
'building',
'road',
@ -87,13 +88,13 @@ class LoveDA(NonGeoDataset):
'forest',
'agriculture',
'no-data',
]
)
def __init__(
self,
root: Path = 'data',
split: str = 'train',
scene: list[str] = ['urban', 'rural'],
scene: Sequence[str] = ['urban', 'rural'],
transforms: Callable[[dict[str, Tensor]], dict[str, Tensor]] | None = None,
download: bool = False,
checksum: bool = False,

Просмотреть файл

@ -7,6 +7,7 @@ import os
import shutil
from collections import defaultdict
from collections.abc import Callable
from typing import ClassVar
import matplotlib.pyplot as plt
import numpy as np
@ -36,7 +37,7 @@ class MapInWild(NonGeoDataset):
different RS sensors over 1018 locations: dual-pol Sentinel-1, four-season
Sentinel-2 with 10 bands, ESA WorldCover map, and Visible Infrared Imaging
Radiometer Suite NightTime Day/Night band. The dataset consists of 8144
images with the shape of 1920 × 1920 pixels. The images are weakly annotated
images with the shape of 1920 x 1920 pixels. The images are weakly annotated
from the World Database of Protected Areas (WDPA).
Dataset features:
@ -54,9 +55,9 @@ class MapInWild(NonGeoDataset):
.. versionadded:: 0.5
"""
url = 'https://hf.co/datasets/burakekim/mapinwild/resolve/d963778e31e7e0ed2329c0f4cbe493be532f0e71/' # noqa: E501
url = 'https://hf.co/datasets/burakekim/mapinwild/resolve/d963778e31e7e0ed2329c0f4cbe493be532f0e71/'
modality_urls = {
modality_urls: ClassVar[dict[str, set[str]]] = {
'esa_wc': {'esa_wc/ESA_WC.zip'},
'viirs': {'viirs/VIIRS.zip'},
'mask': {'mask/mask.zip'},
@ -72,7 +73,7 @@ class MapInWild(NonGeoDataset):
'split_IDs': {'split_IDs/split_IDs.csv'},
}
md5s = {
md5s: ClassVar[dict[str, str]] = {
'ESA_WC.zip': '72b2ee578fe10f0df85bdb7f19311c92',
'VIIRS.zip': '4eff014bae127fe536f8a5f17d89ecb4',
'mask.zip': '87c83a23a73998ad60d448d240b66225',
@ -91,9 +92,12 @@ class MapInWild(NonGeoDataset):
'split_IDs.csv': 'cb5c6c073702acee23544e1e6fe5856f',
}
mask_cmap = {1: (0, 153, 0), 0: (255, 255, 255)}
mask_cmap: ClassVar[dict[int, tuple[int, int, int]]] = {
1: (0, 153, 0),
0: (255, 255, 255),
}
wc_cmap = {
wc_cmap: ClassVar[dict[int, tuple[int, int, int]]] = {
10: (0, 160, 0),
20: (150, 100, 0),
30: (255, 180, 0),

Просмотреть файл

@ -6,7 +6,7 @@
import glob
import os
from collections.abc import Callable
from typing import Any, cast
from typing import Any, ClassVar, cast
import matplotlib.pyplot as plt
import numpy as np
@ -48,7 +48,7 @@ class MillionAID(NonGeoDataset):
.. versionadded:: 0.3
"""
multi_label_categories = [
multi_label_categories = (
'agriculture_land',
'airport_area',
'apartment',
@ -122,9 +122,9 @@ class MillionAID(NonGeoDataset):
'wind_turbine',
'woodland',
'works',
]
)
multi_class_categories = [
multi_class_categories = (
'apartment',
'apron',
'bare_land',
@ -176,17 +176,17 @@ class MillionAID(NonGeoDataset):
'wastewater_plant',
'wind_turbine',
'works',
]
)
md5s = {
md5s: ClassVar[dict[str, str]] = {
'train': '1b40503cafa9b0601653ca36cd788852',
'test': '51a63ee3eeb1351889eacff349a983d8',
}
filenames = {'train': 'train.zip', 'test': 'test.zip'}
filenames: ClassVar[dict[str, str]] = {'train': 'train.zip', 'test': 'test.zip'}
tasks = ['multi-class', 'multi-label']
splits = ['train', 'test']
tasks = ('multi-class', 'multi-label')
splits = ('train', 'test')
def __init__(
self,

Просмотреть файл

@ -45,8 +45,8 @@ class NAIP(RasterDataset):
"""
# Plotting
all_bands = ['R', 'G', 'B', 'NIR']
rgb_bands = ['R', 'G', 'B']
all_bands = ('R', 'G', 'B', 'NIR')
rgb_bands = ('R', 'G', 'B')
def plot(
self,

Просмотреть файл

@ -4,7 +4,7 @@
"""Northeastern China Crop Map Dataset."""
from collections.abc import Callable, Iterable
from typing import Any
from typing import Any, ClassVar
import matplotlib.pyplot as plt
import torch
@ -57,23 +57,23 @@ class NCCM(RasterDataset):
date_format = '%Y'
is_image = False
urls = {
urls: ClassVar[dict[int, str]] = {
2019: 'https://figshare.com/ndownloader/files/25070540',
2018: 'https://figshare.com/ndownloader/files/25070624',
2017: 'https://figshare.com/ndownloader/files/25070582',
}
md5s = {
md5s: ClassVar[dict[int, str]] = {
2019: '0d062bbd42e483fdc8239d22dba7020f',
2018: 'b3bb4894478d10786aa798fb11693ec1',
2017: 'd047fbe4a85341fa6248fd7e0badab6c',
}
fnames = {
fnames: ClassVar[dict[int, str]] = {
2019: 'CDL2019_clip.tif',
2018: 'CDL2018_clip1.tif',
2017: 'CDL2017_clip.tif',
}
cmap = {
cmap: ClassVar[dict[int, tuple[int, int, int, int]]] = {
0: (0, 255, 0, 255),
1: (255, 0, 0, 255),
2: (255, 255, 0, 255),

Просмотреть файл

@ -7,7 +7,7 @@ import glob
import os
import pathlib
from collections.abc import Callable, Iterable
from typing import Any
from typing import Any, ClassVar
import matplotlib.pyplot as plt
import torch
@ -67,7 +67,7 @@ class NLCD(RasterDataset):
* 2019: https://doi.org/10.5066/P9KZCM54
.. versionadded:: 0.5
""" # noqa: E501
"""
filename_glob = 'nlcd_*_land_cover_l48_*.img'
filename_regex = (
@ -79,7 +79,7 @@ class NLCD(RasterDataset):
url = 'https://s3-us-west-2.amazonaws.com/mrlc/nlcd_{}_land_cover_l48_20210604.zip'
md5s = {
md5s: ClassVar[dict[int, str]] = {
2001: '538166a4d783204764e3df3b221fc4cd',
2006: '67454e7874a00294adb9442374d0c309',
2011: 'ea524c835d173658eeb6fa3c8e6b917b',
@ -87,7 +87,7 @@ class NLCD(RasterDataset):
2019: '82851c3f8105763b01c83b4a9e6f3961',
}
cmap = {
cmap: ClassVar[dict[int, tuple[int, int, int, int]]] = {
0: (0, 0, 0, 0),
11: (70, 107, 159, 255),
12: (209, 222, 248, 255),

Просмотреть файл

@ -9,7 +9,7 @@ import os
import pathlib
import sys
from collections.abc import Callable, Iterable
from typing import Any, cast
from typing import Any, ClassVar, cast
import fiona
import fiona.transform
@ -61,7 +61,7 @@ class OpenBuildings(VectorDataset):
.. versionadded:: 0.3
"""
md5s = {
md5s: ClassVar[dict[str, str]] = {
'025_buildings.csv.gz': '41db2572bfd08628d01475a2ee1a2f17',
'04f_buildings.csv.gz': '3232c1c6d45c1543260b77e5689fc8b1',
'05b_buildings.csv.gz': '4fc57c63bbbf9a21a3902da7adc3a670',

Просмотреть файл

@ -6,6 +6,7 @@
import glob
import os
from collections.abc import Callable, Sequence
from typing import ClassVar
import matplotlib.pyplot as plt
import numpy as np
@ -50,7 +51,7 @@ class OSCD(NonGeoDataset):
.. versionadded:: 0.2
"""
urls = {
urls: ClassVar[dict[str, str]] = {
'Onera Satellite Change Detection dataset - Images.zip': (
'https://partage.imt.fr/index.php/s/gKRaWgRnLMfwMGo/download'
),
@ -61,7 +62,7 @@ class OSCD(NonGeoDataset):
'https://partage.imt.fr/index.php/s/gpStKn4Mpgfnr63/download'
),
}
md5s = {
md5s: ClassVar[dict[str, str]] = {
'Onera Satellite Change Detection dataset - Images.zip': (
'c50d4a2941da64e03a47ac4dec63d915'
),
@ -75,9 +76,9 @@ class OSCD(NonGeoDataset):
zipfile_glob = '*Onera*.zip'
filename_glob = '*Onera*'
splits = ['train', 'test']
splits = ('train', 'test')
colormap = ['blue']
colormap = ('blue',)
all_bands = (
'B01',
@ -319,7 +320,7 @@ class OSCD(NonGeoDataset):
torch.from_numpy(rgb_img),
sample['mask'],
alpha=alpha,
colors=self.colormap,
colors=list(self.colormap),
)
return array

Просмотреть файл

@ -5,6 +5,7 @@
import os
from collections.abc import Callable, Sequence
from typing import ClassVar
import fiona
import matplotlib.pyplot as plt
@ -70,7 +71,7 @@ class PASTIS(NonGeoDataset):
.. versionadded:: 0.5
"""
classes = [
classes = (
'background', # all non-agricultural land
'meadow',
'soft_winter_wheat',
@ -91,8 +92,8 @@ class PASTIS(NonGeoDataset):
'mixed_cereal',
'sorghum',
'void_label', # for parcels mostly outside their patch
]
cmap = {
)
cmap: ClassVar[dict[int, tuple[int, int, int, int]]] = {
0: (0, 0, 0, 255),
1: (174, 199, 232, 255),
2: (255, 127, 14, 255),
@ -118,7 +119,7 @@ class PASTIS(NonGeoDataset):
filename = 'PASTIS-R.zip'
url = 'https://zenodo.org/record/5735646/files/PASTIS-R.zip?download=1'
md5 = '4887513d6c2d2b07fa935d325bd53e09'
prefix = {
prefix: ClassVar[dict[str, str]] = {
's2': os.path.join('DATA_S2', 'S2_'),
's1a': os.path.join('DATA_S1A', 'S1A_'),
's1d': os.path.join('DATA_S1D', 'S1D_'),
@ -232,7 +233,7 @@ class PASTIS(NonGeoDataset):
Returns:
the target mask
"""
# See https://github.com/VSainteuf/pastis-benchmark/blob/main/code/dataloader.py#L201 # noqa: E501
# See https://github.com/VSainteuf/pastis-benchmark/blob/main/code/dataloader.py#L201
# even though the mask file is 3 bands, we just select the first band
array = np.load(self.files[index]['semantic'])[0].astype(np.uint8)
tensor = torch.from_numpy(array).long()

Просмотреть файл

@ -5,6 +5,7 @@
import os
from collections.abc import Callable
from typing import ClassVar
import matplotlib.pyplot as plt
import numpy as np
@ -54,12 +55,12 @@ class Potsdam2D(NonGeoDataset):
* https://doi.org/10.5194/isprsannals-I-3-293-2012
.. versionadded:: 0.2
""" # noqa: E501
"""
filenames = ['4_Ortho_RGBIR.zip', '5_Labels_all.zip']
md5s = ['c4a8f7d8c7196dd4eba4addd0aae10c1', 'cf7403c1a97c0d279414db']
filenames = ('4_Ortho_RGBIR.zip', '5_Labels_all.zip')
md5s = ('c4a8f7d8c7196dd4eba4addd0aae10c1', 'cf7403c1a97c0d279414db')
image_root = '4_Ortho_RGBIR'
splits = {
splits: ClassVar[dict[str, list[str]]] = {
'train': [
'top_potsdam_2_10',
'top_potsdam_2_11',
@ -103,22 +104,22 @@ class Potsdam2D(NonGeoDataset):
'top_potsdam_7_13',
],
}
classes = [
classes = (
'Clutter/background',
'Impervious surfaces',
'Building',
'Low Vegetation',
'Tree',
'Car',
]
colormap = [
)
colormap = (
(255, 0, 0),
(255, 255, 255),
(0, 0, 255),
(0, 255, 255),
(0, 255, 0),
(255, 255, 0),
]
)
def __init__(
self,
@ -257,7 +258,7 @@ class Potsdam2D(NonGeoDataset):
"""
ncols = 1
image1 = draw_semantic_segmentation_masks(
sample['image'][:3], sample['mask'], alpha=alpha, colors=self.colormap
sample['image'][:3], sample['mask'], alpha=alpha, colors=list(self.colormap)
)
if 'prediction' in sample:
ncols += 1
@ -265,7 +266,7 @@ class Potsdam2D(NonGeoDataset):
sample['image'][:3],
sample['prediction'],
alpha=alpha,
colors=self.colormap,
colors=list(self.colormap),
)
fig, axs = plt.subplots(ncols=ncols, figsize=(ncols * 10, 10))

Просмотреть файл

@ -5,7 +5,7 @@
import os
from collections.abc import Callable
from typing import Any, cast
from typing import Any, ClassVar, cast
import matplotlib.pyplot as plt
import numpy as np
@ -61,8 +61,12 @@ class QuakeSet(NonGeoDataset):
filename = 'earthquakes.h5'
url = 'https://hf.co/datasets/DarthReca/quakeset/resolve/bead1d25fb9979dbf703f9ede3e8b349f73b29f7/earthquakes.h5'
md5 = '76fc7c76b7ca56f4844d852e175e1560'
splits = {'train': 'train', 'val': 'validation', 'test': 'test'}
classes = ['unaffected_area', 'earthquake_affected_area']
splits: ClassVar[dict[str, str]] = {
'train': 'train',
'val': 'validation',
'test': 'test',
}
classes = ('unaffected_area', 'earthquake_affected_area')
def __init__(
self,

Просмотреть файл

@ -56,7 +56,7 @@ class ReforesTree(NonGeoDataset):
.. versionadded:: 0.3
"""
classes = ['other', 'banana', 'cacao', 'citrus', 'fruit', 'timber']
classes = ('other', 'banana', 'cacao', 'citrus', 'fruit', 'timber')
url = 'https://zenodo.org/record/6813783/files/reforesTree.zip?download=1'
md5 = 'f6a4a1d8207aeaa5fbab7b21b683a302'

Просмотреть файл

@ -5,7 +5,7 @@
import os
from collections.abc import Callable
from typing import cast
from typing import ClassVar, cast
import matplotlib.pyplot as plt
import numpy as np
@ -98,13 +98,13 @@ class RESISC45(NonGeoClassificationDataset):
filename = 'NWPU-RESISC45.zip'
directory = 'NWPU-RESISC45'
splits = ['train', 'val', 'test']
split_urls = {
splits = ('train', 'val', 'test')
split_urls: ClassVar[dict[str, str]] = {
'train': 'https://hf.co/datasets/torchgeo/resisc45/resolve/a826b44d938a883185f11ebe3d512d38b464312f/resisc45-train.txt',
'val': 'https://hf.co/datasets/torchgeo/resisc45/resolve/a826b44d938a883185f11ebe3d512d38b464312f/resisc45-val.txt',
'test': 'https://hf.co/datasets/torchgeo/resisc45/resolve/a826b44d938a883185f11ebe3d512d38b464312f/resisc45-test.txt',
}
split_md5s = {
split_md5s: ClassVar[dict[str, str]] = {
'train': 'b5a4c05a37de15e4ca886696a85c403e',
'val': 'a0770cee4c5ca20b8c32bbd61e114805',
'test': '3dda9e4988b47eb1de9f07993653eb08',

Просмотреть файл

@ -6,6 +6,7 @@
import glob
import os
from collections.abc import Callable, Sequence
from typing import ClassVar
import matplotlib.pyplot as plt
import numpy as np
@ -55,11 +56,11 @@ class RwandaFieldBoundary(NonGeoDataset):
url = 'https://radiantearth.blob.core.windows.net/mlhub/nasa_rwanda_field_boundary_competition'
splits = {'train': 57, 'test': 13}
splits: ClassVar[dict[str, int]] = {'train': 57, 'test': 13}
dates = ('2021_03', '2021_04', '2021_08', '2021_10', '2021_11', '2021_12')
all_bands = ('B01', 'B02', 'B03', 'B04')
rgb_bands = ('B03', 'B02', 'B01')
classes = ['No field-boundary', 'Field-boundary']
classes = ('No field-boundary', 'Field-boundary')
def __init__(
self,

Просмотреть файл

@ -6,6 +6,7 @@
import os
import random
from collections.abc import Callable, Collection, Iterable
from typing import ClassVar
import matplotlib.patches as mpatches
import matplotlib.pyplot as plt
@ -85,51 +86,51 @@ class SeasoNet(NonGeoDataset):
.. versionadded:: 0.5
"""
metadata = [
metadata = (
{
'name': 'spring',
'ext': '.zip',
'url': 'https://zenodo.org/api/files/e2288446-9ee8-4b2e-ae76-cd80366a40e1/spring.zip', # noqa: E501
'url': 'https://zenodo.org/api/files/e2288446-9ee8-4b2e-ae76-cd80366a40e1/spring.zip',
'md5': 'de4cdba7b6196aff624073991b187561',
},
{
'name': 'summer',
'ext': '.zip',
'url': 'https://zenodo.org/api/files/e2288446-9ee8-4b2e-ae76-cd80366a40e1/summer.zip', # noqa: E501
'url': 'https://zenodo.org/api/files/e2288446-9ee8-4b2e-ae76-cd80366a40e1/summer.zip',
'md5': '6a54d4e134d27ae4eb03f180ee100550',
},
{
'name': 'fall',
'ext': '.zip',
'url': 'https://zenodo.org/api/files/e2288446-9ee8-4b2e-ae76-cd80366a40e1/fall.zip', # noqa: E501
'url': 'https://zenodo.org/api/files/e2288446-9ee8-4b2e-ae76-cd80366a40e1/fall.zip',
'md5': '5f94920fe41a63c6bfbab7295f7d6b95',
},
{
'name': 'winter',
'ext': '.zip',
'url': 'https://zenodo.org/api/files/e2288446-9ee8-4b2e-ae76-cd80366a40e1/winter.zip', # noqa: E501
'url': 'https://zenodo.org/api/files/e2288446-9ee8-4b2e-ae76-cd80366a40e1/winter.zip',
'md5': 'dc5e3e09e52ab5c72421b1e3186c9a48',
},
{
'name': 'snow',
'ext': '.zip',
'url': 'https://zenodo.org/api/files/e2288446-9ee8-4b2e-ae76-cd80366a40e1/snow.zip', # noqa: E501
'url': 'https://zenodo.org/api/files/e2288446-9ee8-4b2e-ae76-cd80366a40e1/snow.zip',
'md5': 'e1b300994143f99ebb03f51d6ab1cbe6',
},
{
'name': 'splits',
'ext': '.zip',
'url': 'https://zenodo.org/api/files/e2288446-9ee8-4b2e-ae76-cd80366a40e1/splits.zip', # noqa: E501
'url': 'https://zenodo.org/api/files/e2288446-9ee8-4b2e-ae76-cd80366a40e1/splits.zip',
'md5': 'e4ec4a18bc4efc828f0944a7cf4d5fed',
},
{
'name': 'meta.csv',
'ext': '',
'url': 'https://zenodo.org/api/files/e2288446-9ee8-4b2e-ae76-cd80366a40e1/meta.csv', # noqa: E501
'url': 'https://zenodo.org/api/files/e2288446-9ee8-4b2e-ae76-cd80366a40e1/meta.csv',
'md5': '43ea07974936a6bf47d989c32e16afe7',
},
]
classes = [
)
classes = (
'Continuous urban fabric',
'Discontinuous urban fabric',
'Industrial or commercial units',
@ -163,12 +164,17 @@ class SeasoNet(NonGeoDataset):
'Coastal lagoons',
'Estuaries',
'Sea and ocean',
]
all_seasons = {'Spring', 'Summer', 'Fall', 'Winter', 'Snow'}
)
all_seasons = frozenset({'Spring', 'Summer', 'Fall', 'Winter', 'Snow'})
all_bands = ('10m_RGB', '10m_IR', '20m', '60m')
band_nums = {'10m_RGB': 3, '10m_IR': 1, '20m': 6, '60m': 2}
splits = ['train', 'val', 'test']
cmap = {
band_nums: ClassVar[dict[str, int]] = {
'10m_RGB': 3,
'10m_IR': 1,
'20m': 6,
'60m': 2,
}
splits = ('train', 'val', 'test')
cmap: ClassVar[dict[int, tuple[int, int, int, int]]] = {
0: (230, 000, 77, 255),
1: (255, 000, 000, 255),
2: (204, 77, 242, 255),
@ -331,7 +337,7 @@ class SeasoNet(NonGeoDataset):
for band in self.bands:
with rasterio.open(f'{path}_{band}.tif') as f:
array = f.read(
out_shape=[f.count] + list(self.image_size),
out_shape=[f.count, *list(self.image_size)],
out_dtype='int32',
resampling=Resampling.bilinear,
)

Просмотреть файл

@ -5,7 +5,8 @@
import os
import random
from collections.abc import Callable
from collections.abc import Callable, Sequence
from typing import ClassVar
import matplotlib.pyplot as plt
import numpy as np
@ -37,7 +38,7 @@ class SeasonalContrastS2(NonGeoDataset):
* https://arxiv.org/pdf/2103.16607.pdf
"""
all_bands = [
all_bands = (
'B1',
'B2',
'B3',
@ -50,10 +51,10 @@ class SeasonalContrastS2(NonGeoDataset):
'B9',
'B11',
'B12',
]
rgb_bands = ['B4', 'B3', 'B2']
)
rgb_bands = ('B4', 'B3', 'B2')
metadata = {
metadata: ClassVar[dict[str, dict[str, str]]] = {
'100k': {
'url': 'https://zenodo.org/record/4728033/files/seco_100k.zip?download=1',
'md5': 'ebf2d5e03adc6e657f9a69a20ad863e0',
@ -73,7 +74,7 @@ class SeasonalContrastS2(NonGeoDataset):
root: Path = 'data',
version: str = '100k',
seasons: int = 1,
bands: list[str] = rgb_bands,
bands: Sequence[str] = rgb_bands,
transforms: Callable[[dict[str, Tensor]], dict[str, Tensor]] | None = None,
download: bool = False,
checksum: bool = False,

Просмотреть файл

@ -5,6 +5,7 @@
import os
from collections.abc import Callable, Sequence
from typing import ClassVar
import matplotlib.pyplot as plt
import numpy as np
@ -63,9 +64,9 @@ class SEN12MS(NonGeoDataset):
or manually downloaded from https://dataserv.ub.tum.de/s/m1474000
and https://github.com/schmitt-muc/SEN12MS/tree/master/splits.
This download will likely take several hours.
""" # noqa: E501
"""
BAND_SETS: dict[str, tuple[str, ...]] = {
BAND_SETS: ClassVar[dict[str, tuple[str, ...]]] = {
'all': (
'VV',
'VH',
@ -120,9 +121,9 @@ class SEN12MS(NonGeoDataset):
'B12',
)
rgb_bands = ['B04', 'B03', 'B02']
rgb_bands = ('B04', 'B03', 'B02')
filenames = [
filenames = (
'ROIs1158_spring_lc.tar.gz',
'ROIs1158_spring_s1.tar.gz',
'ROIs1158_spring_s2.tar.gz',
@ -137,16 +138,16 @@ class SEN12MS(NonGeoDataset):
'ROIs2017_winter_s2.tar.gz',
'train_list.txt',
'test_list.txt',
]
light_filenames = [
)
light_filenames = (
'ROIs1158_spring',
'ROIs1868_summer',
'ROIs1970_fall',
'ROIs2017_winter',
'train_list.txt',
'test_list.txt',
]
md5s = [
)
md5s = (
'6e2e8fa8b8cba77ddab49fd20ff5c37b',
'fba019bb27a08c1db96b31f718c34d79',
'd58af2c15a16f376eb3308dc9b685af2',
@ -161,7 +162,7 @@ class SEN12MS(NonGeoDataset):
'3807545661288dcca312c9c538537b63',
'0a68d4e1eb24f128fccdb930000b2546',
'c7faad064001e646445c4c634169484d',
]
)
def __init__(
self,

Просмотреть файл

@ -137,7 +137,7 @@ class Sentinel1(Sentinel):
\.
"""
date_format = '%Y%m%dT%H%M%S'
all_bands = ['HH', 'HV', 'VV', 'VH']
all_bands = ('HH', 'HV', 'VV', 'VH')
separate_files = True
def __init__(
@ -277,7 +277,7 @@ class Sentinel2(Sentinel):
date_format = '%Y%m%dT%H%M%S'
# https://gisgeography.com/sentinel-2-bands-combinations/
all_bands = [
all_bands: tuple[str, ...] = (
'B01',
'B02',
'B03',
@ -291,8 +291,8 @@ class Sentinel2(Sentinel):
'B10',
'B11',
'B12',
]
rgb_bands = ['B04', 'B03', 'B02']
)
rgb_bands = ('B04', 'B03', 'B02')
separate_files = True

Просмотреть файл

@ -5,7 +5,7 @@
import os
from collections.abc import Callable
from typing import Any
from typing import Any, ClassVar
import matplotlib.pyplot as plt
import numpy as np
@ -62,8 +62,8 @@ class SKIPPD(NonGeoDataset):
.. versionadded:: 0.5
"""
url = 'https://hf.co/datasets/torchgeo/skippd/resolve/a16c7e200b4618cd93be3143cdb973e3f21498fa/{}' # noqa: E501
md5 = {
url = 'https://hf.co/datasets/torchgeo/skippd/resolve/a16c7e200b4618cd93be3143cdb973e3f21498fa/{}'
md5: ClassVar[dict[str, str]] = {
'forecast': 'f4f3509ddcc83a55c433be9db2e51077',
'nowcast': '0000761d403e45bb5f86c21d3c69aa80',
}
@ -71,9 +71,9 @@ class SKIPPD(NonGeoDataset):
data_file_name = '2017_2019_images_pv_processed_{}.hdf5'
zipfile_name = '2017_2019_images_pv_processed_{}.zip'
valid_splits = ['trainval', 'test']
valid_splits = ('trainval', 'test')
valid_tasks = ['nowcast', 'forecast']
valid_tasks = ('nowcast', 'forecast')
dateformat = '%m/%d/%Y, %H:%M:%S'

Просмотреть файл

@ -5,7 +5,7 @@
import os
from collections.abc import Callable, Sequence
from typing import cast
from typing import ClassVar, cast
import matplotlib.pyplot as plt
import numpy as np
@ -103,10 +103,10 @@ class So2Sat(NonGeoDataset):
This dataset requires the following additional library to be installed:
* `<https://pypi.org/project/h5py/>`_ to load the dataset
""" # noqa: E501
"""
versions = ['2', '3_random', '3_block', '3_culture_10']
filenames_by_version = {
versions = ('2', '3_random', '3_block', '3_culture_10')
filenames_by_version: ClassVar[dict[str, dict[str, str]]] = {
'2': {
'train': 'training.h5',
'validation': 'validation.h5',
@ -119,7 +119,7 @@ class So2Sat(NonGeoDataset):
'test': 'culture_10/testing.h5',
},
}
md5s_by_version = {
md5s_by_version: ClassVar[dict[str, dict[str, str]]] = {
'2': {
'train': '702bc6a9368ebff4542d791e53469244',
'validation': '71cfa6795de3e22207229d06d6f8775d',
@ -139,7 +139,7 @@ class So2Sat(NonGeoDataset):
},
}
classes = [
classes = (
'Compact high rise',
'Compact mid rise',
'Compact low rise',
@ -157,7 +157,7 @@ class So2Sat(NonGeoDataset):
'Bare rock or paved',
'Bare soil or sand',
'Water',
]
)
all_s1_band_names = (
'S1_B1',
@ -183,9 +183,9 @@ class So2Sat(NonGeoDataset):
)
all_band_names = all_s1_band_names + all_s2_band_names
rgb_bands = ['S2_B04', 'S2_B03', 'S2_B02']
rgb_bands = ('S2_B04', 'S2_B03', 'S2_B02')
BAND_SETS = {
BAND_SETS: ClassVar[dict[str, tuple[str, ...]]] = {
'all': all_band_names,
's1': all_s1_band_names,
's2': all_s2_band_names,

Просмотреть файл

@ -6,8 +6,8 @@
import os
import pathlib
import re
from collections.abc import Callable, Iterable
from typing import Any, cast
from collections.abc import Callable, Iterable, Sequence
from typing import Any, ClassVar, cast
import matplotlib.pyplot as plt
import torch
@ -79,9 +79,9 @@ class SouthAfricaCropType(RasterDataset):
_10m
"""
date_format = '%Y_%m_%d'
rgb_bands = ['B04', 'B03', 'B02']
s1_bands = ['VH', 'VV']
s2_bands = [
rgb_bands = ('B04', 'B03', 'B02')
s1_bands = ('VH', 'VV')
s2_bands = (
'B01',
'B02',
'B03',
@ -94,9 +94,9 @@ class SouthAfricaCropType(RasterDataset):
'B09',
'B11',
'B12',
]
all_bands: list[str] = s1_bands + s2_bands
cmap = {
)
all_bands = s1_bands + s2_bands
cmap: ClassVar[dict[int, tuple[int, int, int, int]]] = {
0: (0, 0, 0, 255),
1: (255, 211, 0, 255),
2: (255, 37, 37, 255),
@ -113,8 +113,8 @@ class SouthAfricaCropType(RasterDataset):
self,
paths: Path | Iterable[Path] = 'data',
crs: CRS | None = None,
classes: list[int] = list(cmap.keys()),
bands: list[str] = s2_bands,
classes: Sequence[int] = list(cmap.keys()),
bands: Sequence[str] = s2_bands,
transforms: Callable[[dict[str, Tensor]], dict[str, Tensor]] | None = None,
download: bool = False,
) -> None:

Просмотреть файл

@ -5,7 +5,7 @@
import pathlib
from collections.abc import Callable, Iterable
from typing import Any
from typing import Any, ClassVar
import matplotlib.pyplot as plt
from matplotlib.figure import Figure
@ -47,7 +47,7 @@ class SouthAmericaSoybean(RasterDataset):
is_image = False
url = 'https://glad.umd.edu/projects/AnnualClassMapsV1/SouthAmerica_Soybean_{}.tif'
md5s = {
md5s: ClassVar[dict[int, str]] = {
2021: 'edff3ada13a1a9910d1fe844d28ae4f',
2020: '0709dec807f576c9707c8c7e183db31',
2019: '441836493bbcd5e123cff579a58f5a4f',

Просмотреть файл

@ -8,7 +8,7 @@ import os
import re
from abc import ABC, abstractmethod
from collections.abc import Callable
from typing import Any
from typing import Any, ClassVar
import fiona
import matplotlib.pyplot as plt
@ -55,9 +55,9 @@ class SpaceNet(NonGeoDataset, ABC):
image_glob = '*.tif'
mask_glob = '*.geojson'
file_regex = r'_img(\d+)\.'
chip_size: dict[str, tuple[int, int]] = {}
chip_size: ClassVar[dict[str, tuple[int, int]]] = {}
cities = {
cities: ClassVar[dict[int, str]] = {
1: 'Rio',
2: 'Vegas',
3: 'Paris',
@ -98,7 +98,7 @@ class SpaceNet(NonGeoDataset, ABC):
@property
@abstractmethod
def valid_masks(self) -> list[str]:
def valid_masks(self) -> tuple[str, ...]:
"""List of valid masks."""
def __init__(
@ -426,7 +426,7 @@ class SpaceNet1(SpaceNet):
directory_glob = '{product}'
dataset_id = 'SN1_buildings'
tarballs = {
tarballs: ClassVar[dict[str, dict[int, list[str]]]] = {
'train': {
1: [
'SN1_buildings_train_AOI_1_Rio_3band.tar.gz',
@ -441,7 +441,7 @@ class SpaceNet1(SpaceNet):
]
},
}
md5s = {
md5s: ClassVar[dict[str, dict[int, list[str]]]] = {
'train': {
1: [
'279e334a2120ecac70439ea246174516',
@ -453,10 +453,16 @@ class SpaceNet1(SpaceNet):
1: ['18283d78b21c239bc1831f3bf1d2c996', '732b3a40603b76e80aac84e002e2b3e8']
},
}
valid_aois = {'train': [1], 'test': [1]}
valid_images = {'train': ['3band', '8band'], 'test': ['3band', '8band']}
valid_masks = ['geojson']
chip_size = {'3band': (406, 439), '8band': (102, 110)}
valid_aois: ClassVar[dict[str, list[int]]] = {'train': [1], 'test': [1]}
valid_images: ClassVar[dict[str, list[str]]] = {
'train': ['3band', '8band'],
'test': ['3band', '8band'],
}
valid_masks = ('geojson',)
chip_size: ClassVar[dict[str, tuple[int, int]]] = {
'3band': (406, 439),
'8band': (102, 110),
}
class SpaceNet2(SpaceNet):
@ -522,7 +528,7 @@ class SpaceNet2(SpaceNet):
"""
dataset_id = 'SN2_buildings'
tarballs = {
tarballs: ClassVar[dict[str, dict[int, list[str]]]] = {
'train': {
2: ['SN2_buildings_train_AOI_2_Vegas.tar.gz'],
3: ['SN2_buildings_train_AOI_3_Paris.tar.gz'],
@ -536,7 +542,7 @@ class SpaceNet2(SpaceNet):
5: ['AOI_5_Khartoum_Test_public.tar.gz'],
},
}
md5s = {
md5s: ClassVar[dict[str, dict[int, list[str]]]] = {
'train': {
2: ['307da318bc43aaf9481828f92eda9126'],
3: ['4db469e3e4e7bf025368ad730aec0888'],
@ -550,13 +556,16 @@ class SpaceNet2(SpaceNet):
5: ['037d7be10530f0dd1c43d4ef79f3236e'],
},
}
valid_aois = {'train': [2, 3, 4, 5], 'test': [2, 3, 4, 5]}
valid_images = {
valid_aois: ClassVar[dict[str, list[int]]] = {
'train': [2, 3, 4, 5],
'test': [2, 3, 4, 5],
}
valid_images: ClassVar[dict[str, list[str]]] = {
'train': ['MUL', 'MUL-PanSharpen', 'PAN', 'RGB-PanSharpen'],
'test': ['MUL', 'MUL-PanSharpen', 'PAN', 'RGB-PanSharpen'],
}
valid_masks = [os.path.join('geojson', 'buildings')]
chip_size = {'MUL': (163, 163)}
valid_masks = (os.path.join('geojson', 'buildings'),)
chip_size: ClassVar[dict[str, tuple[int, int]]] = {'MUL': (163, 163)}
class SpaceNet3(SpaceNet):
@ -624,7 +633,7 @@ class SpaceNet3(SpaceNet):
"""
dataset_id = 'SN3_roads'
tarballs = {
tarballs: ClassVar[dict[str, dict[int, list[str]]]] = {
'train': {
2: [
'SN3_roads_train_AOI_2_Vegas.tar.gz',
@ -650,7 +659,7 @@ class SpaceNet3(SpaceNet):
5: ['SN3_roads_test_public_AOI_5_Khartoum.tar.gz'],
},
}
md5s = {
md5s: ClassVar[dict[str, dict[int, list[str]]]] = {
'train': {
2: ['06317255b5e0c6df2643efd8a50f22ae', '4acf7846ed8121db1319345cfe9fdca9'],
3: ['c13baf88ee10fe47870c303223cabf82', 'abc8199d4c522d3a14328f4f514702ad'],
@ -664,12 +673,15 @@ class SpaceNet3(SpaceNet):
5: ['f367c79fa0fc1d38e63a0fdd065ed957'],
},
}
valid_aois = {'train': [2, 3, 4, 5], 'test': [2, 3, 4, 5]}
valid_images = {
valid_aois: ClassVar[dict[str, list[int]]] = {
'train': [2, 3, 4, 5],
'test': [2, 3, 4, 5],
}
valid_images: ClassVar[dict[str, list[str]]] = {
'train': ['MS', 'PS-MS', 'PAN', 'PS-RGB'],
'test': ['MUL', 'MUL-PanSharpen', 'PAN', 'RGB-PanSharpen'],
}
valid_masks = ['geojson_roads', 'geojson_roads_speed']
valid_masks: tuple[str, ...] = ('geojson_roads', 'geojson_roads_speed')
class SpaceNet4(SpaceNet):
@ -708,7 +720,7 @@ class SpaceNet4(SpaceNet):
directory_glob = os.path.join('**', '{product}')
file_regex = r'_(\d+_\d+)\.'
dataset_id = 'SN4_buildings'
tarballs = {
tarballs: ClassVar[dict[str, dict[int, list[str]]]] = {
'train': {
6: [
'Atlanta_nadir7_catid_1030010003D22F00.tar.gz',
@ -743,7 +755,7 @@ class SpaceNet4(SpaceNet):
},
'test': {6: ['SN4_buildings_AOI_6_Atlanta_test_public.tar.gz']},
}
md5s = {
md5s: ClassVar[dict[str, dict[int, list[str]]]] = {
'train': {
6: [
'd41ab6ec087b07e1e046c55d1fa5754b',
@ -778,12 +790,12 @@ class SpaceNet4(SpaceNet):
},
'test': {6: ['0ec3874bfc19aed63b33ac47b039aace']},
}
valid_aois = {'train': [6], 'test': [6]}
valid_images = {
valid_aois: ClassVar[dict[str, list[int]]] = {'train': [6], 'test': [6]}
valid_images: ClassVar[dict[str, list[str]]] = {
'train': ['MS', 'PAN', 'Pan-Sharpen'],
'test': ['MS', 'PAN', 'Pan-Sharpen'],
}
valid_masks = [os.path.join('geojson', 'spacenet-buildings')]
valid_masks = (os.path.join('geojson', 'spacenet-buildings'),)
class SpaceNet5(SpaceNet3):
@ -850,26 +862,26 @@ class SpaceNet5(SpaceNet3):
file_regex = r'_chip(\d+)\.'
dataset_id = 'SN5_roads'
tarballs = {
tarballs: ClassVar[dict[str, dict[int, list[str]]]] = {
'train': {
7: ['SN5_roads_train_AOI_7_Moscow.tar.gz'],
8: ['SN5_roads_train_AOI_8_Mumbai.tar.gz'],
},
'test': {9: ['SN5_roads_test_public_AOI_9_San_Juan.tar.gz']},
}
md5s = {
md5s: ClassVar[dict[str, dict[int, list[str]]]] = {
'train': {
7: ['03082d01081a6d8df2bc5a9645148d2a'],
8: ['1ee20ba781da6cb7696eef9a95a5bdcc'],
},
'test': {9: ['fc45afef219dfd3a20f2d4fc597f6882']},
}
valid_aois = {'train': [7, 8], 'test': [9]}
valid_images = {
valid_aois: ClassVar[dict[str, list[int]]] = {'train': [7, 8], 'test': [9]}
valid_images: ClassVar[dict[str, list[str]]] = {
'train': ['MS', 'PAN', 'PS-MS', 'PS-RGB'],
'test': ['MS', 'PAN', 'PS-MS', 'PS-RGB'],
}
valid_masks = ['geojson_roads_speed']
valid_masks = ('geojson_roads_speed',)
class SpaceNet6(SpaceNet):
@ -937,20 +949,20 @@ class SpaceNet6(SpaceNet):
file_regex = r'_tile_(\d+)\.'
dataset_id = 'SN6_buildings'
tarballs = {
tarballs: ClassVar[dict[str, dict[int, list[str]]]] = {
'train': {11: ['SN6_buildings_AOI_11_Rotterdam_train.tar.gz']},
'test': {11: ['SN6_buildings_AOI_11_Rotterdam_test_public.tar.gz']},
}
md5s = {
md5s: ClassVar[dict[str, dict[int, list[str]]]] = {
'train': {11: ['10ca26d2287716e3b6ef0cf0ad9f946e']},
'test': {11: ['a07823a5e536feeb8bb6b6f0cb43cf05']},
}
valid_aois = {'train': [11], 'test': [11]}
valid_images = {
valid_aois: ClassVar[dict[str, list[int]]] = {'train': [11], 'test': [11]}
valid_images: ClassVar[dict[str, list[str]]] = {
'train': ['PAN', 'PS-RGB', 'PS-RGBNIR', 'RGBNIR', 'SAR-Intensity'],
'test': ['SAR-Intensity'],
}
valid_masks = ['geojson_buildings']
valid_masks = ('geojson_buildings',)
class SpaceNet7(SpaceNet):
@ -958,7 +970,7 @@ class SpaceNet7(SpaceNet):
`SpaceNet 7 <https://spacenet.ai/sn7-challenge/>`_ is a dataset which
consist of medium resolution (4.0m) satellite imagery mosaics acquired from
Planet Labs Dove constellation between 2017 and 2020. It includes 24
Planet Labs' Dove constellation between 2017 and 2020. It includes ≈ 24
images (one per month) covering > 100 unique geographies, and comprises >
40,000 km2 of imagery and exhaustive polygon labels of building footprints
therein, totaling over 11M individual annotations.
@ -993,18 +1005,24 @@ class SpaceNet7(SpaceNet):
mask_glob = '*_Buildings.geojson'
file_regex = r'global_monthly_(\d+.*\d+)'
dataset_id = 'SN7_buildings'
tarballs = {
tarballs: ClassVar[dict[str, dict[int, list[str]]]] = {
'train': {0: ['SN7_buildings_train.tar.gz']},
'test': {0: ['SN7_buildings_test_public.tar.gz']},
}
md5s = {
md5s: ClassVar[dict[str, dict[int, list[str]]]] = {
'train': {0: ['6eda13b9c28f6f5cdf00a7e8e218c1b1']},
'test': {0: ['b3bde95a0f8f32f3bfeba49464b9bc97']},
}
valid_aois = {'train': [0], 'test': [0]}
valid_images = {'train': ['images', 'images_masked'], 'test': ['images_masked']}
valid_masks = ['labels', 'labels_match', 'labels_match_pix']
chip_size = {'images': (1024, 1024), 'images_masked': (1024, 1024)}
valid_aois: ClassVar[dict[str, list[int]]] = {'train': [0], 'test': [0]}
valid_images: ClassVar[dict[str, list[str]]] = {
'train': ['images', 'images_masked'],
'test': ['images_masked'],
}
valid_masks = ('labels', 'labels_match', 'labels_match_pix')
chip_size: ClassVar[dict[str, tuple[int, int]]] = {
'images': (1024, 1024),
'images_masked': (1024, 1024),
}
class SpaceNet8(SpaceNet):
@ -1024,7 +1042,7 @@ class SpaceNet8(SpaceNet):
directory_glob = '{product}'
file_regex = r'(\d+_\d+_\d+)\.'
dataset_id = 'SN8_floods'
tarballs = {
tarballs: ClassVar[dict[str, dict[int, list[str]]]] = {
'train': {
0: [
'Germany_Training_Public.tar.gz',
@ -1033,16 +1051,19 @@ class SpaceNet8(SpaceNet):
},
'test': {0: ['Louisiana-West_Test_Public.tar.gz']},
}
md5s = {
md5s: ClassVar[dict[str, dict[int, list[str]]]] = {
'train': {
0: ['81383a9050b93e8f70c8557d4568e8a2', 'fa40ae3cf6ac212c90073bf93d70bd95']
},
'test': {0: ['d41d8cd98f00b204e9800998ecf8427e']},
}
valid_aois = {'train': [0], 'test': [0]}
valid_images = {
valid_aois: ClassVar[dict[str, list[int]]] = {'train': [0], 'test': [0]}
valid_images: ClassVar[dict[str, list[str]]] = {
'train': ['PRE-event', 'POST-event'],
'test': ['PRE-event', 'POST-event'],
}
valid_masks = ['annotations']
chip_size = {'PRE-event': (1300, 1300), 'POST-event': (1300, 1300)}
valid_masks = ('annotations',)
chip_size: ClassVar[dict[str, tuple[int, int]]] = {
'PRE-event': (1300, 1300),
'POST-event': (1300, 1300),
}

Просмотреть файл

@ -7,7 +7,7 @@ import glob
import os
import random
from collections.abc import Callable
from typing import TypedDict
from typing import ClassVar, TypedDict
import matplotlib.pyplot as plt
import numpy as np
@ -93,13 +93,13 @@ class SSL4EOL(NonGeoDataset):
* https://proceedings.neurips.cc/paper_files/paper/2023/hash/bbf7ee04e2aefec136ecf60e346c2e61-Abstract-Datasets_and_Benchmarks.html
.. versionadded:: 0.5
""" # noqa: E501
"""
class _Metadata(TypedDict):
num_bands: int
rgb_bands: list[int]
metadata: dict[str, _Metadata] = {
metadata: ClassVar[dict[str, _Metadata]] = {
'tm_toa': {'num_bands': 7, 'rgb_bands': [2, 1, 0]},
'etm_toa': {'num_bands': 9, 'rgb_bands': [2, 1, 0]},
'etm_sr': {'num_bands': 6, 'rgb_bands': [2, 1, 0]},
@ -107,8 +107,8 @@ class SSL4EOL(NonGeoDataset):
'oli_sr': {'num_bands': 7, 'rgb_bands': [3, 2, 1]},
}
url = 'https://hf.co/datasets/torchgeo/ssl4eo_l/resolve/e2467887e6a6bcd7547d9d5999f8d9bc3323dc31/{0}/ssl4eo_l_{0}.tar.gz{1}' # noqa: E501
checksums = {
url = 'https://hf.co/datasets/torchgeo/ssl4eo_l/resolve/e2467887e6a6bcd7547d9d5999f8d9bc3323dc31/{0}/ssl4eo_l_{0}.tar.gz{1}'
checksums: ClassVar[dict[str, dict[str, str]]] = {
'tm_toa': {
'aa': '553795b8d73aa253445b1e67c5b81f11',
'ab': 'e9e0739b5171b37d16086cb89ab370e8',
@ -357,7 +357,7 @@ class SSL4EOS12(NonGeoDataset):
md5: str
bands: list[str]
metadata: dict[str, _Metadata] = {
metadata: ClassVar[dict[str, _Metadata]] = {
's1': {
'filename': 's1.tar.gz',
'md5': '51ee23b33eb0a2f920bda25225072f3a',

Просмотреть файл

@ -6,6 +6,7 @@
import glob
import os
from collections.abc import Callable
from typing import ClassVar
import matplotlib.pyplot as plt
import numpy as np
@ -46,16 +47,16 @@ class SSL4EOLBenchmark(NonGeoDataset):
* https://proceedings.neurips.cc/paper_files/paper/2023/hash/bbf7ee04e2aefec136ecf60e346c2e61-Abstract-Datasets_and_Benchmarks.html
.. versionadded:: 0.5
""" # noqa: E501
"""
url = 'https://hf.co/datasets/torchgeo/ssl4eo-l-benchmark/resolve/da96ae2b04cb509710b72fce9131c2a3d5c211c2/{}.tar.gz' # noqa: E501
url = 'https://hf.co/datasets/torchgeo/ssl4eo-l-benchmark/resolve/da96ae2b04cb509710b72fce9131c2a3d5c211c2/{}.tar.gz'
valid_sensors = ['tm_toa', 'etm_toa', 'etm_sr', 'oli_tirs_toa', 'oli_sr']
valid_products = ['cdl', 'nlcd']
valid_splits = ['train', 'val', 'test']
valid_sensors = ('tm_toa', 'etm_toa', 'etm_sr', 'oli_tirs_toa', 'oli_sr')
valid_products = ('cdl', 'nlcd')
valid_splits = ('train', 'val', 'test')
image_root = 'ssl4eo_l_{}_benchmark'
img_md5s = {
img_md5s: ClassVar[dict[str, str]] = {
'tm_toa': '8e3c5bcd56d3780a442f1332013b8d15',
'etm_toa': '1b051c7fe4d61c581b341370c9e76f1f',
'etm_sr': '34a24fa89a801654f8d01e054662c8cd',
@ -63,14 +64,14 @@ class SSL4EOLBenchmark(NonGeoDataset):
'oli_sr': '0700cd15cc2366fe68c2f8c02fa09a15',
}
mask_dir_dict = {
mask_dir_dict: ClassVar[dict[str, str]] = {
'tm_toa': 'ssl4eo_l_tm_{}',
'etm_toa': 'ssl4eo_l_etm_{}',
'etm_sr': 'ssl4eo_l_etm_{}',
'oli_tirs_toa': 'ssl4eo_l_oli_{}',
'oli_sr': 'ssl4eo_l_oli_{}',
}
mask_md5s = {
mask_md5s: ClassVar[dict[str, dict[str, str]]] = {
'tm': {
'cdl': '3d676770ffb56c7e222a7192a652a846',
'nlcd': '261149d7614fcfdcb3be368eefa825c7',
@ -85,7 +86,7 @@ class SSL4EOLBenchmark(NonGeoDataset):
},
}
year_dict = {
year_dict: ClassVar[dict[str, int]] = {
'tm_toa': 2011,
'etm_toa': 2019,
'etm_sr': 2019,
@ -93,7 +94,7 @@ class SSL4EOLBenchmark(NonGeoDataset):
'oli_sr': 2019,
}
rgb_indices = {
rgb_indices: ClassVar[dict[str, list[int]]] = {
'tm_toa': [2, 1, 0],
'etm_toa': [2, 1, 0],
'etm_sr': [2, 1, 0],
@ -101,9 +102,12 @@ class SSL4EOLBenchmark(NonGeoDataset):
'oli_sr': [3, 2, 1],
}
split_percentages = [0.7, 0.15, 0.15]
split_percentages = (0.7, 0.15, 0.15)
cmaps = {'nlcd': NLCD.cmap, 'cdl': CDL.cmap}
cmaps: ClassVar[dict[str, dict[int, tuple[int, int, int, int]]]] = {
'nlcd': NLCD.cmap,
'cdl': CDL.cmap,
}
def __init__(
self,

Просмотреть файл

@ -45,17 +45,17 @@ class SustainBenchCropYield(NonGeoDataset):
* https://doi.org/10.1609/aaai.v31i1.11172
.. versionadded:: 0.5
""" # noqa: E501
"""
valid_countries = ['usa', 'brazil', 'argentina']
valid_countries = ('usa', 'brazil', 'argentina')
md5 = '362bad07b51a1264172b8376b39d1fc9'
url = 'https://drive.google.com/file/d/1lhbmICpmNuOBlaErywgiD6i9nHuhuv0A/view?usp=drive_link' # noqa: E501
url = 'https://drive.google.com/file/d/1lhbmICpmNuOBlaErywgiD6i9nHuhuv0A/view?usp=drive_link'
dir = 'soybeans'
valid_splits = ['train', 'dev', 'test']
valid_splits = ('train', 'dev', 'test')
def __init__(
self,

Просмотреть файл

@ -5,7 +5,7 @@
import os
from collections.abc import Callable
from typing import cast
from typing import ClassVar, cast
import matplotlib.pyplot as plt
import numpy as np
@ -66,19 +66,19 @@ class UCMerced(NonGeoClassificationDataset):
* https://dl.acm.org/doi/10.1145/1869790.1869829
"""
url = 'https://hf.co/datasets/torchgeo/ucmerced/resolve/d0af6e2eeea2322af86078068bd83337148a2149/UCMerced_LandUse.zip' # noqa: E501
url = 'https://hf.co/datasets/torchgeo/ucmerced/resolve/d0af6e2eeea2322af86078068bd83337148a2149/UCMerced_LandUse.zip'
filename = 'UCMerced_LandUse.zip'
md5 = '5b7ec56793786b6dc8a908e8854ac0e4'
base_dir = os.path.join('UCMerced_LandUse', 'Images')
splits = ['train', 'val', 'test']
split_urls = {
'train': 'https://storage.googleapis.com/remote_sensing_representations/uc_merced-train.txt', # noqa: E501
'val': 'https://storage.googleapis.com/remote_sensing_representations/uc_merced-val.txt', # noqa: E501
'test': 'https://storage.googleapis.com/remote_sensing_representations/uc_merced-test.txt', # noqa: E501
splits = ('train', 'val', 'test')
split_urls: ClassVar[dict[str, str]] = {
'train': 'https://storage.googleapis.com/remote_sensing_representations/uc_merced-train.txt',
'val': 'https://storage.googleapis.com/remote_sensing_representations/uc_merced-val.txt',
'test': 'https://storage.googleapis.com/remote_sensing_representations/uc_merced-test.txt',
}
split_md5s = {
split_md5s: ClassVar[dict[str, str]] = {
'train': 'f2fb12eb2210cfb53f93f063a35ff374',
'val': '11ecabfc52782e5ea6a9c7c0d263aca0',
'test': '046aff88472d8fc07c4678d03749e28d',

Просмотреть файл

@ -6,6 +6,7 @@
import glob
import os
from collections.abc import Callable, Sequence
from typing import ClassVar
import matplotlib.pyplot as plt
import numpy as np
@ -49,12 +50,12 @@ class USAVars(NonGeoDataset):
.. versionadded:: 0.3
"""
data_url = 'https://hf.co/datasets/torchgeo/usavars/resolve/01377abfaf50c0cc8548aaafb79533666bbf288f/{}' # noqa: E501
data_url = 'https://hf.co/datasets/torchgeo/usavars/resolve/01377abfaf50c0cc8548aaafb79533666bbf288f/{}'
dirname = 'uar'
md5 = '677e89fd20e5dd0fe4d29b61827c2456'
label_urls = {
label_urls: ClassVar[dict[str, str]] = {
'housing': data_url.format('housing.csv'),
'income': data_url.format('income.csv'),
'roads': data_url.format('roads.csv'),
@ -64,7 +65,7 @@ class USAVars(NonGeoDataset):
'treecover': data_url.format('treecover.csv'),
}
split_metadata = {
split_metadata: ClassVar[dict[str, dict[str, str]]] = {
'train': {
'url': data_url.format('train_split.txt'),
'filename': 'train_split.txt',
@ -82,7 +83,7 @@ class USAVars(NonGeoDataset):
},
}
ALL_LABELS = ['treecover', 'elevation', 'population']
ALL_LABELS = ('treecover', 'elevation', 'population')
def __init__(
self,

Просмотреть файл

@ -86,11 +86,11 @@ class BoundingBox:
# https://github.com/PyCQA/pydocstyle/issues/525
@overload
def __getitem__(self, key: int) -> float: # noqa: D105
def __getitem__(self, key: int) -> float:
pass
@overload
def __getitem__(self, key: slice) -> list[float]: # noqa: D105
def __getitem__(self, key: slice) -> list[float]:
pass
def __getitem__(self, key: int | slice) -> float | list[float]:
@ -289,7 +289,7 @@ class Executable:
The completed process.
"""
kwargs['check'] = True
return subprocess.run((self.name,) + args, **kwargs)
return subprocess.run((self.name, *args), **kwargs)
def disambiguate_timestamp(date_str: str, format: str) -> tuple[float, float]:
@ -547,7 +547,7 @@ def draw_semantic_segmentation_masks(
def rgb_to_mask(
rgb: np.typing.NDArray[np.uint8], colors: list[tuple[int, int, int]]
rgb: np.typing.NDArray[np.uint8], colors: Sequence[tuple[int, int, int]]
) -> np.typing.NDArray[np.uint8]:
"""Converts an RGB colormap mask to a integer mask.

Просмотреть файл

@ -5,6 +5,7 @@
import os
from collections.abc import Callable
from typing import ClassVar
import matplotlib.pyplot as plt
import numpy as np
@ -55,15 +56,15 @@ class Vaihingen2D(NonGeoDataset):
* https://doi.org/10.5194/isprsannals-I-3-293-2012
.. versionadded:: 0.2
""" # noqa: E501
"""
filenames = [
filenames = (
'ISPRS_semantic_labeling_Vaihingen.zip',
'ISPRS_semantic_labeling_Vaihingen_ground_truth_COMPLETE.zip',
]
md5s = ['462b8dca7b6fa9eaf729840f0cdfc7f3', '4802dd6326e2727a352fb735be450277']
)
md5s = ('462b8dca7b6fa9eaf729840f0cdfc7f3', '4802dd6326e2727a352fb735be450277')
image_root = 'top'
splits = {
splits: ClassVar[dict[str, list[str]]] = {
'train': [
'top_mosaic_09cm_area1.tif',
'top_mosaic_09cm_area11.tif',
@ -102,22 +103,22 @@ class Vaihingen2D(NonGeoDataset):
'top_mosaic_09cm_area29.tif',
],
}
classes = [
classes = (
'Clutter/background',
'Impervious surfaces',
'Building',
'Low Vegetation',
'Tree',
'Car',
]
colormap = [
)
colormap = (
(255, 0, 0),
(255, 255, 255),
(0, 0, 255),
(0, 255, 255),
(0, 255, 0),
(255, 255, 0),
]
)
def __init__(
self,
@ -258,7 +259,7 @@ class Vaihingen2D(NonGeoDataset):
"""
ncols = 1
image1 = draw_semantic_segmentation_masks(
sample['image'][:3], sample['mask'], alpha=alpha, colors=self.colormap
sample['image'][:3], sample['mask'], alpha=alpha, colors=list(self.colormap)
)
if 'prediction' in sample:
ncols += 1
@ -266,7 +267,7 @@ class Vaihingen2D(NonGeoDataset):
sample['image'][:3],
sample['prediction'],
alpha=alpha,
colors=self.colormap,
colors=list(self.colormap),
)
fig, axs = plt.subplots(ncols=ncols, figsize=(ncols * 10, 10))

Просмотреть файл

@ -5,7 +5,7 @@
import os
from collections.abc import Callable
from typing import Any
from typing import Any, ClassVar
import matplotlib.pyplot as plt
import numpy as np
@ -158,18 +158,18 @@ class VHR10(NonGeoDataset):
``annotations.json`` file for the "positive" image set
"""
image_meta = {
image_meta: ClassVar[dict[str, str]] = {
'url': 'https://hf.co/datasets/torchgeo/vhr10/resolve/7e7968ad265dadc4494e0ca4a079e0b63dc6f3f8/NWPU%20VHR-10%20dataset.zip',
'filename': 'NWPU VHR-10 dataset.zip',
'md5': '6add6751469c12dd8c8d6223064c6c4d',
}
target_meta = {
target_meta: ClassVar[dict[str, str]] = {
'url': 'https://hf.co/datasets/torchgeo/vhr10/resolve/7e7968ad265dadc4494e0ca4a079e0b63dc6f3f8/annotations.json',
'filename': 'annotations.json',
'md5': '7c76ec50c17a61bb0514050d20f22c08',
}
categories = [
categories = (
'background',
'airplane',
'ships',
@ -181,7 +181,7 @@ class VHR10(NonGeoDataset):
'harbor',
'bridge',
'vehicle',
]
)
def __init__(
self,

Просмотреть файл

@ -6,7 +6,7 @@
import glob
import json
import os
from collections.abc import Callable
from collections.abc import Callable, Iterable
from typing import Any
import pandas as pd
@ -53,7 +53,7 @@ class WesternUSALiveFuelMoisture(NonGeoDataset):
label_name = 'percent(t)'
all_variable_names = [
all_variable_names = (
# "date",
'slope(t)',
'elevation(t)',
@ -193,12 +193,12 @@ class WesternUSALiveFuelMoisture(NonGeoDataset):
'vh_vv(t-3)',
'lat',
'lon',
]
)
def __init__(
self,
root: Path = 'data',
input_features: list[str] = all_variable_names,
input_features: Iterable[str] = all_variable_names,
transforms: Callable[[dict[str, Any]], dict[str, Any]] | None = None,
download: bool = False,
) -> None:
@ -273,7 +273,7 @@ class WesternUSALiveFuelMoisture(NonGeoDataset):
data_rows.append(data_dict)
df = pd.DataFrame(data_rows)
df = df[self.input_features + [self.label_name]]
df = df[[*self.input_features, self.label_name]]
return df
def _verify(self) -> None:

Просмотреть файл

@ -6,6 +6,7 @@
import glob
import os
from collections.abc import Callable
from typing import ClassVar
import matplotlib.pyplot as plt
import numpy as np
@ -54,7 +55,7 @@ class XView2(NonGeoDataset):
.. versionadded:: 0.2
"""
metadata = {
metadata: ClassVar[dict[str, dict[str, str]]] = {
'train': {
'filename': 'train_images_labels_targets.tar.gz',
'md5': 'a20ebbfb7eb3452785b63ad02ffd1e16',
@ -66,8 +67,8 @@ class XView2(NonGeoDataset):
'directory': 'test',
},
}
classes = ['background', 'no-damage', 'minor-damage', 'major-damage', 'destroyed']
colormap = ['green', 'blue', 'orange', 'red']
classes = ('background', 'no-damage', 'minor-damage', 'major-damage', 'destroyed')
colormap = ('green', 'blue', 'orange', 'red')
def __init__(
self,
@ -242,10 +243,16 @@ class XView2(NonGeoDataset):
"""
ncols = 2
image1 = draw_semantic_segmentation_masks(
sample['image'][0], sample['mask'][0], alpha=alpha, colors=self.colormap
sample['image'][0],
sample['mask'][0],
alpha=alpha,
colors=list(self.colormap),
)
image2 = draw_semantic_segmentation_masks(
sample['image'][1], sample['mask'][1], alpha=alpha, colors=self.colormap
sample['image'][1],
sample['mask'][1],
alpha=alpha,
colors=list(self.colormap),
)
if 'prediction' in sample: # NOTE: this assumes predictions are made for post
ncols += 1
@ -253,7 +260,7 @@ class XView2(NonGeoDataset):
sample['image'][1],
sample['prediction'],
alpha=alpha,
colors=self.colormap,
colors=list(self.colormap),
)
fig, axs = plt.subplots(ncols=ncols, figsize=(ncols * 10, 10))

Просмотреть файл

@ -52,15 +52,15 @@ class ZueriCrop(NonGeoDataset):
* `h5py <https://pypi.org/project/h5py/>`_ to load the dataset
"""
urls = [
urls = (
'https://polybox.ethz.ch/index.php/s/uXfdr2AcXE3QNB6/download',
'https://raw.githubusercontent.com/0zgur0/multi-stage-convSTAR-network/fa92b5b3cb77f5171c5c3be740cd6e6395cc29b6/labels.csv', # noqa: E501
]
md5s = ['1635231df67f3d25f4f1e62c98e221a4', '5118398c7a5bbc246f5f6bb35d8d529b']
filenames = ['ZueriCrop.hdf5', 'labels.csv']
'https://raw.githubusercontent.com/0zgur0/multi-stage-convSTAR-network/fa92b5b3cb77f5171c5c3be740cd6e6395cc29b6/labels.csv',
)
md5s = ('1635231df67f3d25f4f1e62c98e221a4', '5118398c7a5bbc246f5f6bb35d8d529b')
filenames = ('ZueriCrop.hdf5', 'labels.csv')
band_names = ('NIR', 'B03', 'B02', 'B04', 'B05', 'B06', 'B07', 'B11', 'B12')
rgb_bands = ['B04', 'B03', 'B02']
rgb_bands = ('B04', 'B03', 'B02')
def __init__(
self,

Просмотреть файл

@ -8,7 +8,7 @@ import os
from lightning.pytorch.cli import ArgsType, LightningCLI
# Allows classes to be referenced using only the class name
import torchgeo.datamodules # noqa: F401
import torchgeo.datamodules
import torchgeo.trainers # noqa: F401
from torchgeo.datamodules import BaseDataModule
from torchgeo.trainers import BaseTask

Просмотреть файл

@ -8,7 +8,7 @@ See the following references for design details:
* https://pytorch.org/blog/easily-list-and-initialize-models-with-new-apis-in-torchvision/
* https://pytorch.org/vision/stable/models.html
* https://github.com/pytorch/vision/blob/main/torchvision/models/_api.py
""" # noqa: E501
"""
from collections.abc import Callable
from typing import Any

Просмотреть файл

@ -384,7 +384,7 @@ class DOFABase16_Weights(WeightsEnum): # type: ignore[misc]
"""
DOFA_MAE = Weights(
url='https://hf.co/torchgeo/dofa/resolve/ade8745c5ec6eddfe15d8c03421e8cb8f21e66ff/dofa_base_patch16_224-7cc0f413.pth', # noqa: E501
url='https://hf.co/torchgeo/dofa/resolve/ade8745c5ec6eddfe15d8c03421e8cb8f21e66ff/dofa_base_patch16_224-7cc0f413.pth',
transforms=_dofa_transforms,
meta={
'dataset': 'SatlasPretrain, Five-Billion-Pixels, HySpecNet-11k',
@ -403,7 +403,7 @@ class DOFALarge16_Weights(WeightsEnum): # type: ignore[misc]
"""
DOFA_MAE = Weights(
url='https://hf.co/torchgeo/dofa/resolve/ade8745c5ec6eddfe15d8c03421e8cb8f21e66ff/dofa_large_patch16_224-fbd47fa9.pth', # noqa: E501
url='https://hf.co/torchgeo/dofa/resolve/ade8745c5ec6eddfe15d8c03421e8cb8f21e66ff/dofa_large_patch16_224-fbd47fa9.pth',
transforms=_dofa_transforms,
meta={
'dataset': 'SatlasPretrain, Five-Billion-Pixels, HySpecNet-11k',

Просмотреть файл

@ -140,7 +140,7 @@ class RCF(Module):
a numpy array of size (N, C, H, W) containing the normalized patches
.. versionadded:: 0.5
""" # noqa: E501
"""
n_patches = patches.shape[0]
orig_shape = patches.shape
patches = patches.reshape(patches.shape[0], -1)

Просмотреть файл

@ -11,8 +11,8 @@ import torch
from timm.models import ResNet
from torchvision.models._api import Weights, WeightsEnum
# https://github.com/zhu-xlab/SSL4EO-S12/blob/d2868adfada65e40910bfcedfc49bc3b20df2248/src/benchmark/transfer_classification/linear_BE_moco.py#L167 # noqa: E501
# https://github.com/zhu-xlab/SSL4EO-S12/blob/d2868adfada65e40910bfcedfc49bc3b20df2248/src/benchmark/transfer_classification/datasets/EuroSat/eurosat_dataset.py#L97 # noqa: E501
# https://github.com/zhu-xlab/SSL4EO-S12/blob/d2868adfada65e40910bfcedfc49bc3b20df2248/src/benchmark/transfer_classification/linear_BE_moco.py#L167
# https://github.com/zhu-xlab/SSL4EO-S12/blob/d2868adfada65e40910bfcedfc49bc3b20df2248/src/benchmark/transfer_classification/datasets/EuroSat/eurosat_dataset.py#L97
# Normalization either by 10K or channel-wise with band statistics
_zhu_xlab_transforms = K.AugmentationSequential(
K.Resize(256),
@ -22,7 +22,7 @@ _zhu_xlab_transforms = K.AugmentationSequential(
)
# Normalization only available for RGB dataset, defined here:
# https://github.com/ServiceNow/seasonal-contrast/blob/8285173ec205b64bc3e53b880344dd6c3f79fa7a/datasets/seco_dataset.py # noqa: E501
# https://github.com/ServiceNow/seasonal-contrast/blob/8285173ec205b64bc3e53b880344dd6c3f79fa7a/datasets/seco_dataset.py
_min = torch.tensor([3, 2, 0])
_max = torch.tensor([88, 103, 129])
_mean = torch.tensor([0.485, 0.456, 0.406])
@ -37,7 +37,7 @@ _seco_transforms = K.AugmentationSequential(
)
# Normalization only available for RGB dataset, defined here:
# https://github.com/sustainlab-group/geography-aware-ssl/blob/main/moco_fmow/main_moco_geo%2Btp.py#L287 # noqa: E501
# https://github.com/sustainlab-group/geography-aware-ssl/blob/main/moco_fmow/main_moco_geo%2Btp.py#L287
_mean = torch.tensor([0.485, 0.456, 0.406])
_std = torch.tensor([0.229, 0.224, 0.225])
_gassl_transforms = K.AugmentationSequential(
@ -47,7 +47,7 @@ _gassl_transforms = K.AugmentationSequential(
data_keys=None,
)
# https://github.com/microsoft/torchgeo/blob/8b53304d42c269f9001cb4e861a126dc4b462606/torchgeo/datamodules/ssl4eo_benchmark.py#L43 # noqa: E501
# https://github.com/microsoft/torchgeo/blob/8b53304d42c269f9001cb4e861a126dc4b462606/torchgeo/datamodules/ssl4eo_benchmark.py#L43
_ssl4eo_l_transforms = K.AugmentationSequential(
K.Normalize(mean=torch.tensor(0), std=torch.tensor(255)),
K.CenterCrop((224, 224)),
@ -70,7 +70,7 @@ class ResNet18_Weights(WeightsEnum): # type: ignore[misc]
"""
LANDSAT_TM_TOA_MOCO = Weights(
url='https://hf.co/torchgeo/ssl4eo_landsat/resolve/1c88bb51b6e17a21dde5230738fa38b74bd74f76/resnet18_landsat_tm_toa_moco-1c691b4f.pth', # noqa: E501
url='https://hf.co/torchgeo/ssl4eo_landsat/resolve/1c88bb51b6e17a21dde5230738fa38b74bd74f76/resnet18_landsat_tm_toa_moco-1c691b4f.pth',
transforms=_ssl4eo_l_transforms,
meta={
'dataset': 'SSL4EO-L',
@ -83,7 +83,7 @@ class ResNet18_Weights(WeightsEnum): # type: ignore[misc]
)
LANDSAT_TM_TOA_SIMCLR = Weights(
url='https://hf.co/torchgeo/ssl4eo_landsat/resolve/1c88bb51b6e17a21dde5230738fa38b74bd74f76/resnet18_landsat_tm_toa_simclr-d2d38ace.pth', # noqa: E501
url='https://hf.co/torchgeo/ssl4eo_landsat/resolve/1c88bb51b6e17a21dde5230738fa38b74bd74f76/resnet18_landsat_tm_toa_simclr-d2d38ace.pth',
transforms=_ssl4eo_l_transforms,
meta={
'dataset': 'SSL4EO-L',
@ -96,7 +96,7 @@ class ResNet18_Weights(WeightsEnum): # type: ignore[misc]
)
LANDSAT_ETM_TOA_MOCO = Weights(
url='https://hf.co/torchgeo/ssl4eo_landsat/resolve/1c88bb51b6e17a21dde5230738fa38b74bd74f76/resnet18_landsat_etm_toa_moco-bb88689c.pth', # noqa: E501
url='https://hf.co/torchgeo/ssl4eo_landsat/resolve/1c88bb51b6e17a21dde5230738fa38b74bd74f76/resnet18_landsat_etm_toa_moco-bb88689c.pth',
transforms=_ssl4eo_l_transforms,
meta={
'dataset': 'SSL4EO-L',
@ -109,7 +109,7 @@ class ResNet18_Weights(WeightsEnum): # type: ignore[misc]
)
LANDSAT_ETM_TOA_SIMCLR = Weights(
url='https://hf.co/torchgeo/ssl4eo_landsat/resolve/1c88bb51b6e17a21dde5230738fa38b74bd74f76/resnet18_landsat_etm_toa_simclr-4d813f79.pth', # noqa: E501
url='https://hf.co/torchgeo/ssl4eo_landsat/resolve/1c88bb51b6e17a21dde5230738fa38b74bd74f76/resnet18_landsat_etm_toa_simclr-4d813f79.pth',
transforms=_ssl4eo_l_transforms,
meta={
'dataset': 'SSL4EO-L',
@ -122,7 +122,7 @@ class ResNet18_Weights(WeightsEnum): # type: ignore[misc]
)
LANDSAT_ETM_SR_MOCO = Weights(
url='https://hf.co/torchgeo/ssl4eo_landsat/resolve/1c88bb51b6e17a21dde5230738fa38b74bd74f76/resnet18_landsat_etm_sr_moco-4f078acd.pth', # noqa: E501
url='https://hf.co/torchgeo/ssl4eo_landsat/resolve/1c88bb51b6e17a21dde5230738fa38b74bd74f76/resnet18_landsat_etm_sr_moco-4f078acd.pth',
transforms=_ssl4eo_l_transforms,
meta={
'dataset': 'SSL4EO-L',
@ -135,7 +135,7 @@ class ResNet18_Weights(WeightsEnum): # type: ignore[misc]
)
LANDSAT_ETM_SR_SIMCLR = Weights(
url='https://hf.co/torchgeo/ssl4eo_landsat/resolve/1c88bb51b6e17a21dde5230738fa38b74bd74f76/resnet18_landsat_etm_sr_simclr-8e8543b4.pth', # noqa: E501
url='https://hf.co/torchgeo/ssl4eo_landsat/resolve/1c88bb51b6e17a21dde5230738fa38b74bd74f76/resnet18_landsat_etm_sr_simclr-8e8543b4.pth',
transforms=_ssl4eo_l_transforms,
meta={
'dataset': 'SSL4EO-L',
@ -148,7 +148,7 @@ class ResNet18_Weights(WeightsEnum): # type: ignore[misc]
)
LANDSAT_OLI_TIRS_TOA_MOCO = Weights(
url='https://hf.co/torchgeo/ssl4eo_landsat/resolve/1c88bb51b6e17a21dde5230738fa38b74bd74f76/resnet18_landsat_oli_tirs_toa_moco-a3002f51.pth', # noqa: E501
url='https://hf.co/torchgeo/ssl4eo_landsat/resolve/1c88bb51b6e17a21dde5230738fa38b74bd74f76/resnet18_landsat_oli_tirs_toa_moco-a3002f51.pth',
transforms=_ssl4eo_l_transforms,
meta={
'dataset': 'SSL4EO-L',
@ -161,7 +161,7 @@ class ResNet18_Weights(WeightsEnum): # type: ignore[misc]
)
LANDSAT_OLI_TIRS_TOA_SIMCLR = Weights(
url='https://hf.co/torchgeo/ssl4eo_landsat/resolve/1c88bb51b6e17a21dde5230738fa38b74bd74f76/resnet18_landsat_oli_tirs_toa_simclr-b0635cc6.pth', # noqa: E501
url='https://hf.co/torchgeo/ssl4eo_landsat/resolve/1c88bb51b6e17a21dde5230738fa38b74bd74f76/resnet18_landsat_oli_tirs_toa_simclr-b0635cc6.pth',
transforms=_ssl4eo_l_transforms,
meta={
'dataset': 'SSL4EO-L',
@ -174,7 +174,7 @@ class ResNet18_Weights(WeightsEnum): # type: ignore[misc]
)
LANDSAT_OLI_SR_MOCO = Weights(
url='https://hf.co/torchgeo/ssl4eo_landsat/resolve/1c88bb51b6e17a21dde5230738fa38b74bd74f76/resnet18_landsat_oli_sr_moco-660e82ed.pth', # noqa: E501
url='https://hf.co/torchgeo/ssl4eo_landsat/resolve/1c88bb51b6e17a21dde5230738fa38b74bd74f76/resnet18_landsat_oli_sr_moco-660e82ed.pth',
transforms=_ssl4eo_l_transforms,
meta={
'dataset': 'SSL4EO-L',
@ -187,7 +187,7 @@ class ResNet18_Weights(WeightsEnum): # type: ignore[misc]
)
LANDSAT_OLI_SR_SIMCLR = Weights(
url='https://hf.co/torchgeo/ssl4eo_landsat/resolve/1c88bb51b6e17a21dde5230738fa38b74bd74f76/resnet18_landsat_oli_sr_simclr-7bced5be.pth', # noqa: E501
url='https://hf.co/torchgeo/ssl4eo_landsat/resolve/1c88bb51b6e17a21dde5230738fa38b74bd74f76/resnet18_landsat_oli_sr_simclr-7bced5be.pth',
transforms=_ssl4eo_l_transforms,
meta={
'dataset': 'SSL4EO-L',
@ -200,7 +200,7 @@ class ResNet18_Weights(WeightsEnum): # type: ignore[misc]
)
SENTINEL2_ALL_MOCO = Weights(
url='https://hf.co/torchgeo/resnet18_sentinel2_all_moco/resolve/5b8cddc9a14f3844350b7f40b85bcd32aed75918/resnet18_sentinel2_all_moco-59bfdff9.pth', # noqa: E501
url='https://hf.co/torchgeo/resnet18_sentinel2_all_moco/resolve/5b8cddc9a14f3844350b7f40b85bcd32aed75918/resnet18_sentinel2_all_moco-59bfdff9.pth',
transforms=_zhu_xlab_transforms,
meta={
'dataset': 'SSL4EO-S12',
@ -213,7 +213,7 @@ class ResNet18_Weights(WeightsEnum): # type: ignore[misc]
)
SENTINEL2_RGB_MOCO = Weights(
url='https://hf.co/torchgeo/resnet18_sentinel2_rgb_moco/resolve/e1c032e7785fd0625224cdb6699aa138bb304eec/resnet18_sentinel2_rgb_moco-e3a335e3.pth', # noqa: E501
url='https://hf.co/torchgeo/resnet18_sentinel2_rgb_moco/resolve/e1c032e7785fd0625224cdb6699aa138bb304eec/resnet18_sentinel2_rgb_moco-e3a335e3.pth',
transforms=_zhu_xlab_transforms,
meta={
'dataset': 'SSL4EO-S12',
@ -226,7 +226,7 @@ class ResNet18_Weights(WeightsEnum): # type: ignore[misc]
)
SENTINEL2_RGB_SECO = Weights(
url='https://hf.co/torchgeo/resnet18_sentinel2_rgb_seco/resolve/f8dcee692cf7142163b55a5c197d981fe0e717a0/resnet18_sentinel2_rgb_seco-cefca942.pth', # noqa: E501
url='https://hf.co/torchgeo/resnet18_sentinel2_rgb_seco/resolve/f8dcee692cf7142163b55a5c197d981fe0e717a0/resnet18_sentinel2_rgb_seco-cefca942.pth',
transforms=_seco_transforms,
meta={
'dataset': 'SeCo Dataset',
@ -249,7 +249,7 @@ class ResNet50_Weights(WeightsEnum): # type: ignore[misc]
"""
FMOW_RGB_GASSL = Weights(
url='https://hf.co/torchgeo/resnet50_fmow_rgb_gassl/resolve/fe8a91026cf9104f1e884316b8e8772d7af9052c/resnet50_fmow_rgb_gassl-da43d987.pth', # noqa: E501
url='https://hf.co/torchgeo/resnet50_fmow_rgb_gassl/resolve/fe8a91026cf9104f1e884316b8e8772d7af9052c/resnet50_fmow_rgb_gassl-da43d987.pth',
transforms=_gassl_transforms,
meta={
'dataset': 'fMoW Dataset',
@ -262,7 +262,7 @@ class ResNet50_Weights(WeightsEnum): # type: ignore[misc]
)
LANDSAT_TM_TOA_MOCO = Weights(
url='https://hf.co/torchgeo/ssl4eo_landsat/resolve/1c88bb51b6e17a21dde5230738fa38b74bd74f76/resnet50_landsat_tm_toa_moco-ba1ce753.pth', # noqa: E501
url='https://hf.co/torchgeo/ssl4eo_landsat/resolve/1c88bb51b6e17a21dde5230738fa38b74bd74f76/resnet50_landsat_tm_toa_moco-ba1ce753.pth',
transforms=_ssl4eo_l_transforms,
meta={
'dataset': 'SSL4EO-L',
@ -275,7 +275,7 @@ class ResNet50_Weights(WeightsEnum): # type: ignore[misc]
)
LANDSAT_TM_TOA_SIMCLR = Weights(
url='https://hf.co/torchgeo/ssl4eo_landsat/resolve/1c88bb51b6e17a21dde5230738fa38b74bd74f76/resnet50_landsat_tm_toa_simclr-a1c93432.pth', # noqa: E501
url='https://hf.co/torchgeo/ssl4eo_landsat/resolve/1c88bb51b6e17a21dde5230738fa38b74bd74f76/resnet50_landsat_tm_toa_simclr-a1c93432.pth',
transforms=_ssl4eo_l_transforms,
meta={
'dataset': 'SSL4EO-L',
@ -288,7 +288,7 @@ class ResNet50_Weights(WeightsEnum): # type: ignore[misc]
)
LANDSAT_ETM_TOA_MOCO = Weights(
url='https://hf.co/torchgeo/ssl4eo_landsat/resolve/1c88bb51b6e17a21dde5230738fa38b74bd74f76/resnet50_landsat_etm_toa_moco-e9a84d5a.pth', # noqa: E501
url='https://hf.co/torchgeo/ssl4eo_landsat/resolve/1c88bb51b6e17a21dde5230738fa38b74bd74f76/resnet50_landsat_etm_toa_moco-e9a84d5a.pth',
transforms=_ssl4eo_l_transforms,
meta={
'dataset': 'SSL4EO-L',
@ -301,7 +301,7 @@ class ResNet50_Weights(WeightsEnum): # type: ignore[misc]
)
LANDSAT_ETM_TOA_SIMCLR = Weights(
url='https://hf.co/torchgeo/ssl4eo_landsat/resolve/1c88bb51b6e17a21dde5230738fa38b74bd74f76/resnet50_landsat_etm_toa_simclr-70b5575f.pth', # noqa: E501
url='https://hf.co/torchgeo/ssl4eo_landsat/resolve/1c88bb51b6e17a21dde5230738fa38b74bd74f76/resnet50_landsat_etm_toa_simclr-70b5575f.pth',
transforms=_ssl4eo_l_transforms,
meta={
'dataset': 'SSL4EO-L',
@ -314,7 +314,7 @@ class ResNet50_Weights(WeightsEnum): # type: ignore[misc]
)
LANDSAT_ETM_SR_MOCO = Weights(
url='https://hf.co/torchgeo/ssl4eo_landsat/resolve/1c88bb51b6e17a21dde5230738fa38b74bd74f76/resnet50_landsat_etm_sr_moco-1266cde3.pth', # noqa: E501
url='https://hf.co/torchgeo/ssl4eo_landsat/resolve/1c88bb51b6e17a21dde5230738fa38b74bd74f76/resnet50_landsat_etm_sr_moco-1266cde3.pth',
transforms=_ssl4eo_l_transforms,
meta={
'dataset': 'SSL4EO-L',
@ -327,7 +327,7 @@ class ResNet50_Weights(WeightsEnum): # type: ignore[misc]
)
LANDSAT_ETM_SR_SIMCLR = Weights(
url='https://hf.co/torchgeo/ssl4eo_landsat/resolve/1c88bb51b6e17a21dde5230738fa38b74bd74f76/resnet50_landsat_etm_sr_simclr-e5d185d7.pth', # noqa: E501
url='https://hf.co/torchgeo/ssl4eo_landsat/resolve/1c88bb51b6e17a21dde5230738fa38b74bd74f76/resnet50_landsat_etm_sr_simclr-e5d185d7.pth',
transforms=_ssl4eo_l_transforms,
meta={
'dataset': 'SSL4EO-L',
@ -340,7 +340,7 @@ class ResNet50_Weights(WeightsEnum): # type: ignore[misc]
)
LANDSAT_OLI_TIRS_TOA_MOCO = Weights(
url='https://hf.co/torchgeo/ssl4eo_landsat/resolve/1c88bb51b6e17a21dde5230738fa38b74bd74f76/resnet50_landsat_oli_tirs_toa_moco-de7f5e0f.pth', # noqa: E501
url='https://hf.co/torchgeo/ssl4eo_landsat/resolve/1c88bb51b6e17a21dde5230738fa38b74bd74f76/resnet50_landsat_oli_tirs_toa_moco-de7f5e0f.pth',
transforms=_ssl4eo_l_transforms,
meta={
'dataset': 'SSL4EO-L',
@ -353,7 +353,7 @@ class ResNet50_Weights(WeightsEnum): # type: ignore[misc]
)
LANDSAT_OLI_TIRS_TOA_SIMCLR = Weights(
url='https://hf.co/torchgeo/ssl4eo_landsat/resolve/1c88bb51b6e17a21dde5230738fa38b74bd74f76/resnet50_landsat_oli_tirs_toa_simclr-030cebfe.pth', # noqa: E501
url='https://hf.co/torchgeo/ssl4eo_landsat/resolve/1c88bb51b6e17a21dde5230738fa38b74bd74f76/resnet50_landsat_oli_tirs_toa_simclr-030cebfe.pth',
transforms=_ssl4eo_l_transforms,
meta={
'dataset': 'SSL4EO-L',
@ -366,7 +366,7 @@ class ResNet50_Weights(WeightsEnum): # type: ignore[misc]
)
LANDSAT_OLI_SR_MOCO = Weights(
url='https://hf.co/torchgeo/ssl4eo_landsat/resolve/1c88bb51b6e17a21dde5230738fa38b74bd74f76/resnet50_landsat_oli_sr_moco-ff580dad.pth', # noqa: E501
url='https://hf.co/torchgeo/ssl4eo_landsat/resolve/1c88bb51b6e17a21dde5230738fa38b74bd74f76/resnet50_landsat_oli_sr_moco-ff580dad.pth',
transforms=_ssl4eo_l_transforms,
meta={
'dataset': 'SSL4EO-L',
@ -379,7 +379,7 @@ class ResNet50_Weights(WeightsEnum): # type: ignore[misc]
)
LANDSAT_OLI_SR_SIMCLR = Weights(
url='https://hf.co/torchgeo/ssl4eo_landsat/resolve/1c88bb51b6e17a21dde5230738fa38b74bd74f76/resnet50_landsat_oli_sr_simclr-94f78913.pth', # noqa: E501
url='https://hf.co/torchgeo/ssl4eo_landsat/resolve/1c88bb51b6e17a21dde5230738fa38b74bd74f76/resnet50_landsat_oli_sr_simclr-94f78913.pth',
transforms=_ssl4eo_l_transforms,
meta={
'dataset': 'SSL4EO-L',
@ -392,7 +392,7 @@ class ResNet50_Weights(WeightsEnum): # type: ignore[misc]
)
SENTINEL1_ALL_MOCO = Weights(
url='https://hf.co/torchgeo/resnet50_sentinel1_all_moco/resolve/e79862c667853c10a709bdd77ea8ffbad0e0f1cf/resnet50_sentinel1_all_moco-906e4356.pth', # noqa: E501
url='https://hf.co/torchgeo/resnet50_sentinel1_all_moco/resolve/e79862c667853c10a709bdd77ea8ffbad0e0f1cf/resnet50_sentinel1_all_moco-906e4356.pth',
transforms=_zhu_xlab_transforms,
meta={
'dataset': 'SSL4EO-S12',
@ -405,7 +405,7 @@ class ResNet50_Weights(WeightsEnum): # type: ignore[misc]
)
SENTINEL2_ALL_DINO = Weights(
url='https://hf.co/torchgeo/resnet50_sentinel2_all_dino/resolve/d7f14bf5530d70ac69d763e58e77e44dbecfec7c/resnet50_sentinel2_all_dino-d6c330e9.pth', # noqa: E501
url='https://hf.co/torchgeo/resnet50_sentinel2_all_dino/resolve/d7f14bf5530d70ac69d763e58e77e44dbecfec7c/resnet50_sentinel2_all_dino-d6c330e9.pth',
transforms=_zhu_xlab_transforms,
meta={
'dataset': 'SSL4EO-S12',
@ -418,7 +418,7 @@ class ResNet50_Weights(WeightsEnum): # type: ignore[misc]
)
SENTINEL2_ALL_MOCO = Weights(
url='https://hf.co/torchgeo/resnet50_sentinel2_all_moco/resolve/da4f3c9dbe09272eb902f3b37f46635fa4726879/resnet50_sentinel2_all_moco-df8b932e.pth', # noqa: E501
url='https://hf.co/torchgeo/resnet50_sentinel2_all_moco/resolve/da4f3c9dbe09272eb902f3b37f46635fa4726879/resnet50_sentinel2_all_moco-df8b932e.pth',
transforms=_zhu_xlab_transforms,
meta={
'dataset': 'SSL4EO-S12',
@ -431,7 +431,7 @@ class ResNet50_Weights(WeightsEnum): # type: ignore[misc]
)
SENTINEL2_RGB_MOCO = Weights(
url='https://hf.co/torchgeo/resnet50_sentinel2_rgb_moco/resolve/efd9723b59a88e9dc1420dc1e96afb25b0630a3c/resnet50_sentinel2_rgb_moco-2b57ba8b.pth', # noqa: E501
url='https://hf.co/torchgeo/resnet50_sentinel2_rgb_moco/resolve/efd9723b59a88e9dc1420dc1e96afb25b0630a3c/resnet50_sentinel2_rgb_moco-2b57ba8b.pth',
transforms=_zhu_xlab_transforms,
meta={
'dataset': 'SSL4EO-S12',
@ -444,7 +444,7 @@ class ResNet50_Weights(WeightsEnum): # type: ignore[misc]
)
SENTINEL2_RGB_SECO = Weights(
url='https://hf.co/torchgeo/resnet50_sentinel2_rgb_seco/resolve/fbd07b02a8edb8fc1035f7957160deed4321c145/resnet50_sentinel2_rgb_seco-018bf397.pth', # noqa: E501
url='https://hf.co/torchgeo/resnet50_sentinel2_rgb_seco/resolve/fbd07b02a8edb8fc1035f7957160deed4321c145/resnet50_sentinel2_rgb_seco-018bf397.pth',
transforms=_seco_transforms,
meta={
'dataset': 'SeCo Dataset',

Просмотреть файл

@ -12,20 +12,20 @@ from kornia.contrib import Lambda
from torchvision.models import SwinTransformer
from torchvision.models._api import Weights, WeightsEnum
# https://github.com/allenai/satlas/blob/bcaa968da5395f675d067613e02613a344e81415/satlas/cmd/model/train.py#L42 # noqa: E501
# https://github.com/allenai/satlas/blob/bcaa968da5395f675d067613e02613a344e81415/satlas/cmd/model/train.py#L42
# Satlas uses the TCI product for Sentinel-2 RGB, which is in the range (0, 255).
# See details: https://github.com/allenai/satlas/blob/main/Normalization.md#sentinel-2-images. # noqa: E501
# Satlas Sentinel-1 and RGB Sentinel-2 and NAIP imagery is uint8 and is normalized to (0, 1) by dividing by 255. # noqa: E501
# See details: https://github.com/allenai/satlas/blob/main/Normalization.md#sentinel-2-images.
# Satlas Sentinel-1 and RGB Sentinel-2 and NAIP imagery is uint8 and is normalized to (0, 1) by dividing by 255.
_satlas_transforms = K.AugmentationSequential(
K.Normalize(mean=torch.tensor(0), std=torch.tensor(255)), data_keys=None
)
# Satlas uses the TCI product for Sentinel-2 RGB, which is in the range (0, 255).
# See details: https://github.com/allenai/satlas/blob/main/Normalization.md#sentinel-2-images. # noqa: E501
# Satlas Sentinel-2 multispectral imagery has first 3 bands divided by 255 and the following 6 bands by 8160, both clipped to (0, 1). # noqa: E501
# See details: https://github.com/allenai/satlas/blob/main/Normalization.md#sentinel-2-images.
# Satlas Sentinel-2 multispectral imagery has first 3 bands divided by 255 and the following 6 bands by 8160, both clipped to (0, 1).
_std = torch.tensor(
[255.0, 255.0, 255.0, 8160.0, 8160.0, 8160.0, 8160.0, 8160.0, 8160.0]
) # noqa: E501
)
_mean = torch.zeros_like(_std)
_sentinel2_ms_satlas_transforms = K.AugmentationSequential(
K.Normalize(mean=_mean, std=_std),
@ -33,7 +33,7 @@ _sentinel2_ms_satlas_transforms = K.AugmentationSequential(
data_keys=None,
)
# Satlas Landsat imagery is 16-bit, normalized by clipping some pixel N with (N-4000)/16320 to (0, 1). # noqa: E501
# Satlas Landsat imagery is 16-bit, normalized by clipping some pixel N with (N-4000)/16320 to (0, 1).
_landsat_satlas_transforms = K.AugmentationSequential(
K.Normalize(mean=torch.tensor(4000), std=torch.tensor(16320)),
K.ImageSequential(Lambda(lambda x: torch.clamp(x, min=0.0, max=1.0))),
@ -56,7 +56,7 @@ class Swin_V2_B_Weights(WeightsEnum): # type: ignore[misc]
"""
NAIP_RGB_SI_SATLAS = Weights(
url='https://hf.co/allenai/satlas-pretrain/resolve/daa578a4be36573d9791bf51dcd0420b8dc75732/aerial_swinb_si.pth', # noqa: E501
url='https://hf.co/allenai/satlas-pretrain/resolve/daa578a4be36573d9791bf51dcd0420b8dc75732/aerial_swinb_si.pth',
transforms=_satlas_transforms,
meta={
'dataset': 'Satlas',
@ -68,7 +68,7 @@ class Swin_V2_B_Weights(WeightsEnum): # type: ignore[misc]
)
SENTINEL2_RGB_SI_SATLAS = Weights(
url='https://hf.co/allenai/satlas-pretrain/resolve/daa578a4be36573d9791bf51dcd0420b8dc75732/sentinel2_swinb_si_rgb.pth', # noqa: E501
url='https://hf.co/allenai/satlas-pretrain/resolve/daa578a4be36573d9791bf51dcd0420b8dc75732/sentinel2_swinb_si_rgb.pth',
transforms=_satlas_transforms,
meta={
'dataset': 'Satlas',
@ -80,7 +80,7 @@ class Swin_V2_B_Weights(WeightsEnum): # type: ignore[misc]
)
SENTINEL2_MS_SI_SATLAS = Weights(
url='https://hf.co/allenai/satlas-pretrain/resolve/daa578a4be36573d9791bf51dcd0420b8dc75732/sentinel2_swinb_si_ms.pth', # noqa: E501
url='https://hf.co/allenai/satlas-pretrain/resolve/daa578a4be36573d9791bf51dcd0420b8dc75732/sentinel2_swinb_si_ms.pth',
transforms=_sentinel2_ms_satlas_transforms,
meta={
'dataset': 'Satlas',
@ -93,7 +93,7 @@ class Swin_V2_B_Weights(WeightsEnum): # type: ignore[misc]
)
SENTINEL1_SI_SATLAS = Weights(
url='https://hf.co/allenai/satlas-pretrain/resolve/daa578a4be36573d9791bf51dcd0420b8dc75732/sentinel1_swinb_si.pth', # noqa: E501
url='https://hf.co/allenai/satlas-pretrain/resolve/daa578a4be36573d9791bf51dcd0420b8dc75732/sentinel1_swinb_si.pth',
transforms=_satlas_transforms,
meta={
'dataset': 'Satlas',
@ -106,7 +106,7 @@ class Swin_V2_B_Weights(WeightsEnum): # type: ignore[misc]
)
LANDSAT_SI_SATLAS = Weights(
url='https://hf.co/allenai/satlas-pretrain/resolve/daa578a4be36573d9791bf51dcd0420b8dc75732/landsat_swinb_si.pth', # noqa: E501
url='https://hf.co/allenai/satlas-pretrain/resolve/daa578a4be36573d9791bf51dcd0420b8dc75732/landsat_swinb_si.pth',
transforms=_landsat_satlas_transforms,
meta={
'dataset': 'Satlas',
@ -126,7 +126,7 @@ class Swin_V2_B_Weights(WeightsEnum): # type: ignore[misc]
'B09',
'B10',
'B11',
], # noqa: E501
],
},
)

Просмотреть файл

@ -11,8 +11,8 @@ import torch
from timm.models.vision_transformer import VisionTransformer
from torchvision.models._api import Weights, WeightsEnum
# https://github.com/zhu-xlab/SSL4EO-S12/blob/d2868adfada65e40910bfcedfc49bc3b20df2248/src/benchmark/transfer_classification/linear_BE_moco.py#L167 # noqa: E501
# https://github.com/zhu-xlab/SSL4EO-S12/blob/d2868adfada65e40910bfcedfc49bc3b20df2248/src/benchmark/transfer_classification/datasets/EuroSat/eurosat_dataset.py#L97 # noqa: E501
# https://github.com/zhu-xlab/SSL4EO-S12/blob/d2868adfada65e40910bfcedfc49bc3b20df2248/src/benchmark/transfer_classification/linear_BE_moco.py#L167
# https://github.com/zhu-xlab/SSL4EO-S12/blob/d2868adfada65e40910bfcedfc49bc3b20df2248/src/benchmark/transfer_classification/datasets/EuroSat/eurosat_dataset.py#L97
# Normalization either by 10K or channel-wise with band statistics
_zhu_xlab_transforms = K.AugmentationSequential(
K.Resize(256),
@ -21,7 +21,7 @@ _zhu_xlab_transforms = K.AugmentationSequential(
data_keys=None,
)
# https://github.com/microsoft/torchgeo/blob/8b53304d42c269f9001cb4e861a126dc4b462606/torchgeo/datamodules/ssl4eo_benchmark.py#L43 # noqa: E501
# https://github.com/microsoft/torchgeo/blob/8b53304d42c269f9001cb4e861a126dc4b462606/torchgeo/datamodules/ssl4eo_benchmark.py#L43
_ssl4eo_l_transforms = K.AugmentationSequential(
K.Normalize(mean=torch.tensor(0), std=torch.tensor(255)),
K.CenterCrop((224, 224)),
@ -44,7 +44,7 @@ class ViTSmall16_Weights(WeightsEnum): # type: ignore[misc]
"""
LANDSAT_TM_TOA_MOCO = Weights(
url='https://hf.co/torchgeo/ssl4eo_landsat/resolve/1c88bb51b6e17a21dde5230738fa38b74bd74f76/vits16_landsat_tm_toa_moco-a1c967d8.pth', # noqa: E501
url='https://hf.co/torchgeo/ssl4eo_landsat/resolve/1c88bb51b6e17a21dde5230738fa38b74bd74f76/vits16_landsat_tm_toa_moco-a1c967d8.pth',
transforms=_ssl4eo_l_transforms,
meta={
'dataset': 'SSL4EO-L',
@ -57,7 +57,7 @@ class ViTSmall16_Weights(WeightsEnum): # type: ignore[misc]
)
LANDSAT_TM_TOA_SIMCLR = Weights(
url='https://hf.co/torchgeo/ssl4eo_landsat/resolve/1c88bb51b6e17a21dde5230738fa38b74bd74f76/vits16_landsat_tm_toa_simclr-7c2d9799.pth', # noqa: E501
url='https://hf.co/torchgeo/ssl4eo_landsat/resolve/1c88bb51b6e17a21dde5230738fa38b74bd74f76/vits16_landsat_tm_toa_simclr-7c2d9799.pth',
transforms=_ssl4eo_l_transforms,
meta={
'dataset': 'SSL4EO-L',
@ -70,7 +70,7 @@ class ViTSmall16_Weights(WeightsEnum): # type: ignore[misc]
)
LANDSAT_ETM_TOA_MOCO = Weights(
url='https://hf.co/torchgeo/ssl4eo_landsat/resolve/1c88bb51b6e17a21dde5230738fa38b74bd74f76/vits16_landsat_etm_toa_moco-26d19bcf.pth', # noqa: E501
url='https://hf.co/torchgeo/ssl4eo_landsat/resolve/1c88bb51b6e17a21dde5230738fa38b74bd74f76/vits16_landsat_etm_toa_moco-26d19bcf.pth',
transforms=_ssl4eo_l_transforms,
meta={
'dataset': 'SSL4EO-L',
@ -83,7 +83,7 @@ class ViTSmall16_Weights(WeightsEnum): # type: ignore[misc]
)
LANDSAT_ETM_TOA_SIMCLR = Weights(
url='https://hf.co/torchgeo/ssl4eo_landsat/resolve/1c88bb51b6e17a21dde5230738fa38b74bd74f76/vits16_landsat_etm_toa_simclr-34fb12cb.pth', # noqa: E501
url='https://hf.co/torchgeo/ssl4eo_landsat/resolve/1c88bb51b6e17a21dde5230738fa38b74bd74f76/vits16_landsat_etm_toa_simclr-34fb12cb.pth',
transforms=_ssl4eo_l_transforms,
meta={
'dataset': 'SSL4EO-L',
@ -96,7 +96,7 @@ class ViTSmall16_Weights(WeightsEnum): # type: ignore[misc]
)
LANDSAT_ETM_SR_MOCO = Weights(
url='https://hf.co/torchgeo/ssl4eo_landsat/resolve/1c88bb51b6e17a21dde5230738fa38b74bd74f76/vits16_landsat_etm_sr_moco-eaa4674e.pth', # noqa: E501
url='https://hf.co/torchgeo/ssl4eo_landsat/resolve/1c88bb51b6e17a21dde5230738fa38b74bd74f76/vits16_landsat_etm_sr_moco-eaa4674e.pth',
transforms=_ssl4eo_l_transforms,
meta={
'dataset': 'SSL4EO-L',
@ -109,7 +109,7 @@ class ViTSmall16_Weights(WeightsEnum): # type: ignore[misc]
)
LANDSAT_ETM_SR_SIMCLR = Weights(
url='https://hf.co/torchgeo/ssl4eo_landsat/resolve/1c88bb51b6e17a21dde5230738fa38b74bd74f76/vits16_landsat_etm_sr_simclr-a14c466a.pth', # noqa: E501
url='https://hf.co/torchgeo/ssl4eo_landsat/resolve/1c88bb51b6e17a21dde5230738fa38b74bd74f76/vits16_landsat_etm_sr_simclr-a14c466a.pth',
transforms=_ssl4eo_l_transforms,
meta={
'dataset': 'SSL4EO-L',
@ -122,7 +122,7 @@ class ViTSmall16_Weights(WeightsEnum): # type: ignore[misc]
)
LANDSAT_OLI_TIRS_TOA_MOCO = Weights(
url='https://hf.co/torchgeo/ssl4eo_landsat/resolve/1c88bb51b6e17a21dde5230738fa38b74bd74f76/vits16_landsat_oli_tirs_toa_moco-c7c2cceb.pth', # noqa: E501
url='https://hf.co/torchgeo/ssl4eo_landsat/resolve/1c88bb51b6e17a21dde5230738fa38b74bd74f76/vits16_landsat_oli_tirs_toa_moco-c7c2cceb.pth',
transforms=_ssl4eo_l_transforms,
meta={
'dataset': 'SSL4EO-L',
@ -135,7 +135,7 @@ class ViTSmall16_Weights(WeightsEnum): # type: ignore[misc]
)
LANDSAT_OLI_TIRS_TOA_SIMCLR = Weights(
url='https://hf.co/torchgeo/ssl4eo_landsat/resolve/1c88bb51b6e17a21dde5230738fa38b74bd74f76/vits16_landsat_oli_tirs_toa_simclr-ad43e9a4.pth', # noqa: E501
url='https://hf.co/torchgeo/ssl4eo_landsat/resolve/1c88bb51b6e17a21dde5230738fa38b74bd74f76/vits16_landsat_oli_tirs_toa_simclr-ad43e9a4.pth',
transforms=_ssl4eo_l_transforms,
meta={
'dataset': 'SSL4EO-L',
@ -148,7 +148,7 @@ class ViTSmall16_Weights(WeightsEnum): # type: ignore[misc]
)
LANDSAT_OLI_SR_MOCO = Weights(
url='https://hf.co/torchgeo/ssl4eo_landsat/resolve/1c88bb51b6e17a21dde5230738fa38b74bd74f76/vits16_landsat_oli_sr_moco-c9b8898d.pth', # noqa: E501
url='https://hf.co/torchgeo/ssl4eo_landsat/resolve/1c88bb51b6e17a21dde5230738fa38b74bd74f76/vits16_landsat_oli_sr_moco-c9b8898d.pth',
transforms=_ssl4eo_l_transforms,
meta={
'dataset': 'SSL4EO-L',
@ -161,7 +161,7 @@ class ViTSmall16_Weights(WeightsEnum): # type: ignore[misc]
)
LANDSAT_OLI_SR_SIMCLR = Weights(
url='https://hf.co/torchgeo/ssl4eo_landsat/resolve/1c88bb51b6e17a21dde5230738fa38b74bd74f76/vits16_landsat_oli_sr_simclr-4e8f6102.pth', # noqa: E501
url='https://hf.co/torchgeo/ssl4eo_landsat/resolve/1c88bb51b6e17a21dde5230738fa38b74bd74f76/vits16_landsat_oli_sr_simclr-4e8f6102.pth',
transforms=_ssl4eo_l_transforms,
meta={
'dataset': 'SSL4EO-L',
@ -174,7 +174,7 @@ class ViTSmall16_Weights(WeightsEnum): # type: ignore[misc]
)
SENTINEL2_ALL_DINO = Weights(
url='https://hf.co/torchgeo/vit_small_patch16_224_sentinel2_all_dino/resolve/5b41dd418a79de47ac9f5be3e035405a83818a62/vit_small_patch16_224_sentinel2_all_dino-36bcc127.pth', # noqa: E501
url='https://hf.co/torchgeo/vit_small_patch16_224_sentinel2_all_dino/resolve/5b41dd418a79de47ac9f5be3e035405a83818a62/vit_small_patch16_224_sentinel2_all_dino-36bcc127.pth',
transforms=_zhu_xlab_transforms,
meta={
'dataset': 'SSL4EO-S12',
@ -187,7 +187,7 @@ class ViTSmall16_Weights(WeightsEnum): # type: ignore[misc]
)
SENTINEL2_ALL_MOCO = Weights(
url='https://hf.co/torchgeo/vit_small_patch16_224_sentinel2_all_moco/resolve/1cb683f6c14739634cdfaaceb076529adf898c74/vit_small_patch16_224_sentinel2_all_moco-67c9032d.pth', # noqa: E501
url='https://hf.co/torchgeo/vit_small_patch16_224_sentinel2_all_moco/resolve/1cb683f6c14739634cdfaaceb076529adf898c74/vit_small_patch16_224_sentinel2_all_moco-67c9032d.pth',
transforms=_zhu_xlab_transforms,
meta={
'dataset': 'SSL4EO-S12',

Просмотреть файл

@ -53,7 +53,7 @@ class AugmentationSequential(Module):
else:
keys.append(key)
self.augs = K.AugmentationSequential(*args, data_keys=keys, **kwargs) # type: ignore[arg-type] # noqa: E501
self.augs = K.AugmentationSequential(*args, data_keys=keys, **kwargs) # type: ignore[arg-type]
def forward(self, batch: dict[str, Tensor]) -> dict[str, Tensor]:
"""Perform augmentations and update data dict.