зеркало из https://github.com/microsoft/torchgeo.git
Ruff: enable ruff-specific rules (#2218)
* Ruff: enable ruff-specific rules * Static class variables * String colormap must be list * String colormap must be list
This commit is contained in:
Родитель
c26512e39d
Коммит
067ae1af75
|
@ -19,7 +19,7 @@ import pytorch_sphinx_theme
|
||||||
# documentation root, use os.path.abspath to make it absolute, like shown here.
|
# documentation root, use os.path.abspath to make it absolute, like shown here.
|
||||||
sys.path.insert(0, os.path.abspath('..'))
|
sys.path.insert(0, os.path.abspath('..'))
|
||||||
|
|
||||||
import torchgeo # noqa: E402
|
import torchgeo
|
||||||
|
|
||||||
# -- Project information -----------------------------------------------------
|
# -- Project information -----------------------------------------------------
|
||||||
|
|
||||||
|
|
|
@ -369,8 +369,8 @@
|
||||||
" date_format = '%Y%m%dT%H%M%S'\n",
|
" date_format = '%Y%m%dT%H%M%S'\n",
|
||||||
" is_image = True\n",
|
" is_image = True\n",
|
||||||
" separate_files = True\n",
|
" separate_files = True\n",
|
||||||
" all_bands = ['B02', 'B03', 'B04', 'B08']\n",
|
" all_bands = ('B02', 'B03', 'B04', 'B08')\n",
|
||||||
" rgb_bands = ['B04', 'B03', 'B02']"
|
" rgb_bands = ('B04', 'B03', 'B02')"
|
||||||
]
|
]
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
|
@ -432,8 +432,8 @@
|
||||||
" date_format = '%Y%m%dT%H%M%S'\n",
|
" date_format = '%Y%m%dT%H%M%S'\n",
|
||||||
" is_image = True\n",
|
" is_image = True\n",
|
||||||
" separate_files = True\n",
|
" separate_files = True\n",
|
||||||
" all_bands = ['B02', 'B03', 'B04', 'B08']\n",
|
" all_bands = ('B02', 'B03', 'B04', 'B08')\n",
|
||||||
" rgb_bands = ['B04', 'B03', 'B02']\n",
|
" rgb_bands = ('B04', 'B03', 'B02')\n",
|
||||||
"\n",
|
"\n",
|
||||||
" def plot(self, sample):\n",
|
" def plot(self, sample):\n",
|
||||||
" # Find the correct band index order\n",
|
" # Find the correct band index order\n",
|
||||||
|
|
|
@ -125,7 +125,7 @@ def filter_collection(
|
||||||
|
|
||||||
if filtered.size().getInfo() == 0:
|
if filtered.size().getInfo() == 0:
|
||||||
raise ee.EEException(
|
raise ee.EEException(
|
||||||
f'ImageCollection.filter: No suitable images found in ({coords[1]:.4f}, {coords[0]:.4f}) between {period[0]} and {period[1]}.' # noqa: E501
|
f'ImageCollection.filter: No suitable images found in ({coords[1]:.4f}, {coords[0]:.4f}) between {period[0]} and {period[1]}.'
|
||||||
)
|
)
|
||||||
return filtered
|
return filtered
|
||||||
|
|
||||||
|
|
|
@ -47,7 +47,7 @@ from tqdm import tqdm
|
||||||
def get_world_cities(
|
def get_world_cities(
|
||||||
download_root: str = 'world_cities', size: int = 10000
|
download_root: str = 'world_cities', size: int = 10000
|
||||||
) -> pd.DataFrame:
|
) -> pd.DataFrame:
|
||||||
url = 'https://simplemaps.com/static/data/world-cities/basic/simplemaps_worldcities_basicv1.71.zip' # noqa: E501
|
url = 'https://simplemaps.com/static/data/world-cities/basic/simplemaps_worldcities_basicv1.71.zip'
|
||||||
filename = 'worldcities.csv'
|
filename = 'worldcities.csv'
|
||||||
download_and_extract_archive(url, download_root)
|
download_and_extract_archive(url, download_root)
|
||||||
cols = ['city', 'lat', 'lng', 'population']
|
cols = ['city', 'lat', 'lng', 'population']
|
||||||
|
|
|
@ -268,7 +268,7 @@ quote-style = "single"
|
||||||
skip-magic-trailing-comma = true
|
skip-magic-trailing-comma = true
|
||||||
|
|
||||||
[tool.ruff.lint]
|
[tool.ruff.lint]
|
||||||
extend-select = ["ANN", "D", "I", "NPY201", "UP"]
|
extend-select = ["ANN", "D", "I", "NPY201", "RUF", "UP"]
|
||||||
ignore = ["ANN101", "ANN102", "ANN401"]
|
ignore = ["ANN101", "ANN102", "ANN401"]
|
||||||
|
|
||||||
[tool.ruff.lint.per-file-ignores]
|
[tool.ruff.lint.per-file-ignores]
|
||||||
|
|
|
@ -19,36 +19,36 @@ np.random.seed(0)
|
||||||
|
|
||||||
train_set = [
|
train_set = [
|
||||||
{
|
{
|
||||||
'image': 'labeled_train/Nantes_Saint-Nazaire/BDORTHO/44-2013-0295-6713-LA93-0M50-E080.tif', # noqa: E501
|
'image': 'labeled_train/Nantes_Saint-Nazaire/BDORTHO/44-2013-0295-6713-LA93-0M50-E080.tif',
|
||||||
'dem': 'labeled_train/Nantes_Saint-Nazaire/RGEALTI/44-2013-0295-6713-LA93-0M50-E080_RGEALTI.tif', # noqa: E501
|
'dem': 'labeled_train/Nantes_Saint-Nazaire/RGEALTI/44-2013-0295-6713-LA93-0M50-E080_RGEALTI.tif',
|
||||||
'target': 'labeled_train/Nantes_Saint-Nazaire/UrbanAtlas/44-2013-0295-6713-LA93-0M50-E080_UA2012.tif', # noqa: E501
|
'target': 'labeled_train/Nantes_Saint-Nazaire/UrbanAtlas/44-2013-0295-6713-LA93-0M50-E080_UA2012.tif',
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
'image': 'labeled_train/Nice/BDORTHO/06-2014-1007-6318-LA93-0M50-E080.tif', # noqa: E501
|
'image': 'labeled_train/Nice/BDORTHO/06-2014-1007-6318-LA93-0M50-E080.tif',
|
||||||
'dem': 'labeled_train/Nice/RGEALTI/06-2014-1007-6318-LA93-0M50-E080_RGEALTI.tif', # noqa: E501
|
'dem': 'labeled_train/Nice/RGEALTI/06-2014-1007-6318-LA93-0M50-E080_RGEALTI.tif',
|
||||||
'target': 'labeled_train/Nice/UrbanAtlas/06-2014-1007-6318-LA93-0M50-E080_UA2012.tif', # noqa: E501
|
'target': 'labeled_train/Nice/UrbanAtlas/06-2014-1007-6318-LA93-0M50-E080_UA2012.tif',
|
||||||
},
|
},
|
||||||
]
|
]
|
||||||
|
|
||||||
unlabeled_set = [
|
unlabeled_set = [
|
||||||
{
|
{
|
||||||
'image': 'unlabeled_train/Calais_Dunkerque/BDORTHO/59-2012-0650-7077-LA93-0M50-E080.tif', # noqa: E501
|
'image': 'unlabeled_train/Calais_Dunkerque/BDORTHO/59-2012-0650-7077-LA93-0M50-E080.tif',
|
||||||
'dem': 'unlabeled_train/Calais_Dunkerque/RGEALTI/59-2012-0650-7077-LA93-0M50-E080_RGEALTI.tif', # noqa: E501
|
'dem': 'unlabeled_train/Calais_Dunkerque/RGEALTI/59-2012-0650-7077-LA93-0M50-E080_RGEALTI.tif',
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
'image': 'unlabeled_train/LeMans/BDORTHO/72-2013-0469-6789-LA93-0M50-E080.tif', # noqa: E501
|
'image': 'unlabeled_train/LeMans/BDORTHO/72-2013-0469-6789-LA93-0M50-E080.tif',
|
||||||
'dem': 'unlabeled_train/LeMans/RGEALTI/72-2013-0469-6789-LA93-0M50-E080_RGEALTI.tif', # noqa: E501
|
'dem': 'unlabeled_train/LeMans/RGEALTI/72-2013-0469-6789-LA93-0M50-E080_RGEALTI.tif',
|
||||||
},
|
},
|
||||||
]
|
]
|
||||||
|
|
||||||
val_set = [
|
val_set = [
|
||||||
{
|
{
|
||||||
'image': 'val/Marseille_Martigues/BDORTHO/13-2014-0900-6268-LA93-0M50-E080.tif', # noqa: E501
|
'image': 'val/Marseille_Martigues/BDORTHO/13-2014-0900-6268-LA93-0M50-E080.tif',
|
||||||
'dem': 'val/Marseille_Martigues/RGEALTI/13-2014-0900-6268-LA93-0M50-E080_RGEALTI.tif', # noqa: E501
|
'dem': 'val/Marseille_Martigues/RGEALTI/13-2014-0900-6268-LA93-0M50-E080_RGEALTI.tif',
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
'image': 'val/Clermont-Ferrand/BDORTHO/63-2013-0711-6530-LA93-0M50-E080.tif', # noqa: E501
|
'image': 'val/Clermont-Ferrand/BDORTHO/63-2013-0711-6530-LA93-0M50-E080.tif',
|
||||||
'dem': 'val/Clermont-Ferrand/RGEALTI/63-2013-0711-6530-LA93-0M50-E080_RGEALTI.tif', # noqa: E501
|
'dem': 'val/Clermont-Ferrand/RGEALTI/63-2013-0711-6530-LA93-0M50-E080_RGEALTI.tif',
|
||||||
},
|
},
|
||||||
]
|
]
|
||||||
|
|
||||||
|
|
|
@ -112,7 +112,7 @@ for season in seasons:
|
||||||
# Compute checksums
|
# Compute checksums
|
||||||
with open(archive, 'rb') as f:
|
with open(archive, 'rb') as f:
|
||||||
md5 = hashlib.md5(f.read()).hexdigest()
|
md5 = hashlib.md5(f.read()).hexdigest()
|
||||||
print(f'{season}: {repr(md5)}')
|
print(f'{season}: {md5!r}')
|
||||||
|
|
||||||
# Write meta.csv
|
# Write meta.csv
|
||||||
with open('meta.csv', 'w') as f:
|
with open('meta.csv', 'w') as f:
|
||||||
|
@ -121,7 +121,7 @@ with open('meta.csv', 'w') as f:
|
||||||
# Compute checksums
|
# Compute checksums
|
||||||
with open('meta.csv', 'rb') as f:
|
with open('meta.csv', 'rb') as f:
|
||||||
md5 = hashlib.md5(f.read()).hexdigest()
|
md5 = hashlib.md5(f.read()).hexdigest()
|
||||||
print(f'meta.csv: {repr(md5)}')
|
print(f'meta.csv: {md5!r}')
|
||||||
|
|
||||||
os.makedirs('splits', exist_ok=True)
|
os.makedirs('splits', exist_ok=True)
|
||||||
|
|
||||||
|
@ -138,4 +138,4 @@ shutil.make_archive('splits', 'zip', '.', 'splits')
|
||||||
# Compute checksums
|
# Compute checksums
|
||||||
with open('splits.zip', 'rb') as f:
|
with open('splits.zip', 'rb') as f:
|
||||||
md5 = hashlib.md5(f.read()).hexdigest()
|
md5 = hashlib.md5(f.read()).hexdigest()
|
||||||
print(f'splits: {repr(md5)}')
|
print(f'splits: {md5!r}')
|
||||||
|
|
|
@ -83,5 +83,5 @@ class TestEuroCrops:
|
||||||
dataset[query]
|
dataset[query]
|
||||||
|
|
||||||
def test_integrity_error(self, dataset: EuroCrops) -> None:
|
def test_integrity_error(self, dataset: EuroCrops) -> None:
|
||||||
dataset.zenodo_files = [('AA.zip', 'invalid')]
|
dataset.zenodo_files = (('AA.zip', 'invalid'),)
|
||||||
assert not dataset._check_integrity()
|
assert not dataset._check_integrity()
|
||||||
|
|
|
@ -72,7 +72,7 @@ class CustomVectorDataset(VectorDataset):
|
||||||
|
|
||||||
|
|
||||||
class CustomSentinelDataset(Sentinel2):
|
class CustomSentinelDataset(Sentinel2):
|
||||||
all_bands: list[str] = []
|
all_bands: tuple[str, ...] = ()
|
||||||
separate_files = False
|
separate_files = False
|
||||||
|
|
||||||
|
|
||||||
|
@ -356,7 +356,7 @@ class TestRasterDataset:
|
||||||
|
|
||||||
def test_no_all_bands(self) -> None:
|
def test_no_all_bands(self) -> None:
|
||||||
root = os.path.join('tests', 'data', 'sentinel2')
|
root = os.path.join('tests', 'data', 'sentinel2')
|
||||||
bands = ['B04', 'B03', 'B02']
|
bands = ('B04', 'B03', 'B02')
|
||||||
transforms = nn.Identity()
|
transforms = nn.Identity()
|
||||||
cache = True
|
cache = True
|
||||||
msg = (
|
msg = (
|
||||||
|
|
|
@ -73,7 +73,7 @@ class TestBYOLTask:
|
||||||
'1',
|
'1',
|
||||||
]
|
]
|
||||||
|
|
||||||
main(['fit'] + args)
|
main(['fit', *args])
|
||||||
|
|
||||||
@pytest.fixture
|
@pytest.fixture
|
||||||
def weights(self) -> WeightsEnum:
|
def weights(self) -> WeightsEnum:
|
||||||
|
|
|
@ -103,13 +103,13 @@ class TestClassificationTask:
|
||||||
'1',
|
'1',
|
||||||
]
|
]
|
||||||
|
|
||||||
main(['fit'] + args)
|
main(['fit', *args])
|
||||||
try:
|
try:
|
||||||
main(['test'] + args)
|
main(['test', *args])
|
||||||
except MisconfigurationException:
|
except MisconfigurationException:
|
||||||
pass
|
pass
|
||||||
try:
|
try:
|
||||||
main(['predict'] + args)
|
main(['predict', *args])
|
||||||
except MisconfigurationException:
|
except MisconfigurationException:
|
||||||
pass
|
pass
|
||||||
|
|
||||||
|
@ -259,13 +259,13 @@ class TestMultiLabelClassificationTask:
|
||||||
'1',
|
'1',
|
||||||
]
|
]
|
||||||
|
|
||||||
main(['fit'] + args)
|
main(['fit', *args])
|
||||||
try:
|
try:
|
||||||
main(['test'] + args)
|
main(['test', *args])
|
||||||
except MisconfigurationException:
|
except MisconfigurationException:
|
||||||
pass
|
pass
|
||||||
try:
|
try:
|
||||||
main(['predict'] + args)
|
main(['predict', *args])
|
||||||
except MisconfigurationException:
|
except MisconfigurationException:
|
||||||
pass
|
pass
|
||||||
|
|
||||||
|
|
|
@ -97,13 +97,13 @@ class TestObjectDetectionTask:
|
||||||
'1',
|
'1',
|
||||||
]
|
]
|
||||||
|
|
||||||
main(['fit'] + args)
|
main(['fit', *args])
|
||||||
try:
|
try:
|
||||||
main(['test'] + args)
|
main(['test', *args])
|
||||||
except MisconfigurationException:
|
except MisconfigurationException:
|
||||||
pass
|
pass
|
||||||
try:
|
try:
|
||||||
main(['predict'] + args)
|
main(['predict', *args])
|
||||||
except MisconfigurationException:
|
except MisconfigurationException:
|
||||||
pass
|
pass
|
||||||
|
|
||||||
|
|
|
@ -27,12 +27,12 @@ class TestClassificationTask:
|
||||||
'1',
|
'1',
|
||||||
]
|
]
|
||||||
|
|
||||||
main(['fit'] + args)
|
main(['fit', *args])
|
||||||
try:
|
try:
|
||||||
main(['test'] + args)
|
main(['test', *args])
|
||||||
except MisconfigurationException:
|
except MisconfigurationException:
|
||||||
pass
|
pass
|
||||||
try:
|
try:
|
||||||
main(['predict'] + args)
|
main(['predict', *args])
|
||||||
except MisconfigurationException:
|
except MisconfigurationException:
|
||||||
pass
|
pass
|
||||||
|
|
|
@ -63,7 +63,7 @@ class TestMoCoTask:
|
||||||
'1',
|
'1',
|
||||||
]
|
]
|
||||||
|
|
||||||
main(['fit'] + args)
|
main(['fit', *args])
|
||||||
|
|
||||||
def test_version_warnings(self) -> None:
|
def test_version_warnings(self) -> None:
|
||||||
with pytest.warns(UserWarning, match='MoCo v1 uses a memory bank'):
|
with pytest.warns(UserWarning, match='MoCo v1 uses a memory bank'):
|
||||||
|
|
|
@ -84,13 +84,13 @@ class TestRegressionTask:
|
||||||
'1',
|
'1',
|
||||||
]
|
]
|
||||||
|
|
||||||
main(['fit'] + args)
|
main(['fit', *args])
|
||||||
try:
|
try:
|
||||||
main(['test'] + args)
|
main(['test', *args])
|
||||||
except MisconfigurationException:
|
except MisconfigurationException:
|
||||||
pass
|
pass
|
||||||
try:
|
try:
|
||||||
main(['predict'] + args)
|
main(['predict', *args])
|
||||||
except MisconfigurationException:
|
except MisconfigurationException:
|
||||||
pass
|
pass
|
||||||
|
|
||||||
|
@ -237,13 +237,13 @@ class TestPixelwiseRegressionTask:
|
||||||
'1',
|
'1',
|
||||||
]
|
]
|
||||||
|
|
||||||
main(['fit'] + args)
|
main(['fit', *args])
|
||||||
try:
|
try:
|
||||||
main(['test'] + args)
|
main(['test', *args])
|
||||||
except MisconfigurationException:
|
except MisconfigurationException:
|
||||||
pass
|
pass
|
||||||
try:
|
try:
|
||||||
main(['predict'] + args)
|
main(['predict', *args])
|
||||||
except MisconfigurationException:
|
except MisconfigurationException:
|
||||||
pass
|
pass
|
||||||
|
|
||||||
|
|
|
@ -108,13 +108,13 @@ class TestSemanticSegmentationTask:
|
||||||
'1',
|
'1',
|
||||||
]
|
]
|
||||||
|
|
||||||
main(['fit'] + args)
|
main(['fit', *args])
|
||||||
try:
|
try:
|
||||||
main(['test'] + args)
|
main(['test', *args])
|
||||||
except MisconfigurationException:
|
except MisconfigurationException:
|
||||||
pass
|
pass
|
||||||
try:
|
try:
|
||||||
main(['predict'] + args)
|
main(['predict', *args])
|
||||||
except MisconfigurationException:
|
except MisconfigurationException:
|
||||||
pass
|
pass
|
||||||
|
|
||||||
|
|
|
@ -63,7 +63,7 @@ class TestSimCLRTask:
|
||||||
'1',
|
'1',
|
||||||
]
|
]
|
||||||
|
|
||||||
main(['fit'] + args)
|
main(['fit', *args])
|
||||||
|
|
||||||
def test_version_warnings(self) -> None:
|
def test_version_warnings(self) -> None:
|
||||||
with pytest.warns(UserWarning, match='SimCLR v1 only uses 2 layers'):
|
with pytest.warns(UserWarning, match='SimCLR v1 only uses 2 layers'):
|
||||||
|
|
|
@ -37,7 +37,7 @@ class SeasonalContrastS2DataModule(NonGeoDataModule):
|
||||||
seasons = kwargs.get('seasons', 1)
|
seasons = kwargs.get('seasons', 1)
|
||||||
|
|
||||||
# Normalization only available for RGB dataset, defined here:
|
# Normalization only available for RGB dataset, defined here:
|
||||||
# https://github.com/ServiceNow/seasonal-contrast/blob/8285173ec205b64bc3e53b880344dd6c3f79fa7a/datasets/seco_dataset.py # noqa: E501
|
# https://github.com/ServiceNow/seasonal-contrast/blob/8285173ec205b64bc3e53b880344dd6c3f79fa7a/datasets/seco_dataset.py
|
||||||
if bands == SeasonalContrastS2.rgb_bands:
|
if bands == SeasonalContrastS2.rgb_bands:
|
||||||
_min = torch.tensor([3, 2, 0])
|
_min = torch.tensor([3, 2, 0])
|
||||||
_max = torch.tensor([88, 103, 129])
|
_max = torch.tensor([88, 103, 129])
|
||||||
|
|
|
@ -3,7 +3,7 @@
|
||||||
|
|
||||||
"""So2Sat datamodule."""
|
"""So2Sat datamodule."""
|
||||||
|
|
||||||
from typing import Any
|
from typing import Any, ClassVar
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
from torch import Generator, Tensor
|
from torch import Generator, Tensor
|
||||||
|
@ -21,7 +21,7 @@ class So2SatDataModule(NonGeoDataModule):
|
||||||
"train" set and use the "test" set as the test set.
|
"train" set and use the "test" set as the test set.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
means_per_version: dict[str, Tensor] = {
|
means_per_version: ClassVar[dict[str, Tensor]] = {
|
||||||
'2': torch.tensor(
|
'2': torch.tensor(
|
||||||
[
|
[
|
||||||
-0.00003591224260,
|
-0.00003591224260,
|
||||||
|
@ -91,7 +91,7 @@ class So2SatDataModule(NonGeoDataModule):
|
||||||
}
|
}
|
||||||
means_per_version['3_culture_10'] = means_per_version['2']
|
means_per_version['3_culture_10'] = means_per_version['2']
|
||||||
|
|
||||||
stds_per_version: dict[str, Tensor] = {
|
stds_per_version: ClassVar[dict[str, Tensor]] = {
|
||||||
'2': torch.tensor(
|
'2': torch.tensor(
|
||||||
[
|
[
|
||||||
0.17555201,
|
0.17555201,
|
||||||
|
|
|
@ -45,7 +45,7 @@ class SSL4EOS12DataModule(NonGeoDataModule):
|
||||||
.. versionadded:: 0.5
|
.. versionadded:: 0.5
|
||||||
"""
|
"""
|
||||||
|
|
||||||
# https://github.com/zhu-xlab/SSL4EO-S12/blob/d2868adfada65e40910bfcedfc49bc3b20df2248/src/benchmark/transfer_classification/datasets/EuroSat/eurosat_dataset.py#L97 # noqa: E501
|
# https://github.com/zhu-xlab/SSL4EO-S12/blob/d2868adfada65e40910bfcedfc49bc3b20df2248/src/benchmark/transfer_classification/datasets/EuroSat/eurosat_dataset.py#L97
|
||||||
mean = torch.tensor(0)
|
mean = torch.tensor(0)
|
||||||
std = torch.tensor(10000)
|
std = torch.tensor(10000)
|
||||||
|
|
||||||
|
|
|
@ -63,14 +63,14 @@ class ADVANCE(NonGeoDataset):
|
||||||
* `scipy <https://pypi.org/project/scipy/>`_ to load the audio files to tensors
|
* `scipy <https://pypi.org/project/scipy/>`_ to load the audio files to tensors
|
||||||
"""
|
"""
|
||||||
|
|
||||||
urls = [
|
urls = (
|
||||||
'https://zenodo.org/record/3828124/files/ADVANCE_vision.zip?download=1',
|
'https://zenodo.org/record/3828124/files/ADVANCE_vision.zip?download=1',
|
||||||
'https://zenodo.org/record/3828124/files/ADVANCE_sound.zip?download=1',
|
'https://zenodo.org/record/3828124/files/ADVANCE_sound.zip?download=1',
|
||||||
]
|
)
|
||||||
filenames = ['ADVANCE_vision.zip', 'ADVANCE_sound.zip']
|
filenames = ('ADVANCE_vision.zip', 'ADVANCE_sound.zip')
|
||||||
md5s = ['a9e8748219ef5864d3b5a8979a67b471', 'a2d12f2d2a64f5c3d3a9d8c09aaf1c31']
|
md5s = ('a9e8748219ef5864d3b5a8979a67b471', 'a2d12f2d2a64f5c3d3a9d8c09aaf1c31')
|
||||||
directories = ['vision', 'sound']
|
directories = ('vision', 'sound')
|
||||||
classes = [
|
classes: tuple[str, ...] = (
|
||||||
'airport',
|
'airport',
|
||||||
'beach',
|
'beach',
|
||||||
'bridge',
|
'bridge',
|
||||||
|
@ -84,7 +84,7 @@ class ADVANCE(NonGeoDataset):
|
||||||
'sparse shrub land',
|
'sparse shrub land',
|
||||||
'sports land',
|
'sports land',
|
||||||
'train station',
|
'train station',
|
||||||
]
|
)
|
||||||
|
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
|
@ -119,7 +119,7 @@ class ADVANCE(NonGeoDataset):
|
||||||
raise DatasetNotFoundError(self)
|
raise DatasetNotFoundError(self)
|
||||||
|
|
||||||
self.files = self._load_files(self.root)
|
self.files = self._load_files(self.root)
|
||||||
self.classes = sorted({f['cls'] for f in self.files})
|
self.classes = tuple(sorted({f['cls'] for f in self.files}))
|
||||||
self.class_to_idx: dict[str, int] = {c: i for i, c in enumerate(self.classes)}
|
self.class_to_idx: dict[str, int] = {c: i for i, c in enumerate(self.classes)}
|
||||||
|
|
||||||
def __getitem__(self, index: int) -> dict[str, Tensor]:
|
def __getitem__(self, index: int) -> dict[str, Tensor]:
|
||||||
|
|
|
@ -46,7 +46,7 @@ class AbovegroundLiveWoodyBiomassDensity(RasterDataset):
|
||||||
|
|
||||||
is_image = False
|
is_image = False
|
||||||
|
|
||||||
url = 'https://opendata.arcgis.com/api/v3/datasets/e4bdbe8d6d8d4e32ace7d36a4aec7b93_0/downloads/data?format=geojson&spatialRefId=4326' # noqa: E501
|
url = 'https://opendata.arcgis.com/api/v3/datasets/e4bdbe8d6d8d4e32ace7d36a4aec7b93_0/downloads/data?format=geojson&spatialRefId=4326'
|
||||||
|
|
||||||
base_filename = 'Aboveground_Live_Woody_Biomass_Density.geojson'
|
base_filename = 'Aboveground_Live_Woody_Biomass_Density.geojson'
|
||||||
|
|
||||||
|
|
|
@ -7,7 +7,7 @@ import os
|
||||||
import pathlib
|
import pathlib
|
||||||
import re
|
import re
|
||||||
from collections.abc import Callable, Iterable, Sequence
|
from collections.abc import Callable, Iterable, Sequence
|
||||||
from typing import Any, cast
|
from typing import Any, ClassVar, cast
|
||||||
|
|
||||||
import matplotlib.pyplot as plt
|
import matplotlib.pyplot as plt
|
||||||
import torch
|
import torch
|
||||||
|
@ -90,8 +90,8 @@ class AgriFieldNet(RasterDataset):
|
||||||
_(?P<band>B[0-9A-Z]{2})_10m
|
_(?P<band>B[0-9A-Z]{2})_10m
|
||||||
"""
|
"""
|
||||||
|
|
||||||
rgb_bands = ['B04', 'B03', 'B02']
|
rgb_bands = ('B04', 'B03', 'B02')
|
||||||
all_bands = [
|
all_bands = (
|
||||||
'B01',
|
'B01',
|
||||||
'B02',
|
'B02',
|
||||||
'B03',
|
'B03',
|
||||||
|
@ -104,9 +104,9 @@ class AgriFieldNet(RasterDataset):
|
||||||
'B09',
|
'B09',
|
||||||
'B11',
|
'B11',
|
||||||
'B12',
|
'B12',
|
||||||
]
|
)
|
||||||
|
|
||||||
cmap = {
|
cmap: ClassVar[dict[int, tuple[int, int, int, int]]] = {
|
||||||
0: (0, 0, 0, 255),
|
0: (0, 0, 0, 255),
|
||||||
1: (255, 211, 0, 255),
|
1: (255, 211, 0, 255),
|
||||||
2: (255, 37, 37, 255),
|
2: (255, 37, 37, 255),
|
||||||
|
|
|
@ -40,8 +40,8 @@ class Airphen(RasterDataset):
|
||||||
|
|
||||||
# Each camera measures a custom set of spectral bands chosen at purchase time.
|
# Each camera measures a custom set of spectral bands chosen at purchase time.
|
||||||
# Hiphen offers 8 bands to choose from, sorted from short to long wavelength.
|
# Hiphen offers 8 bands to choose from, sorted from short to long wavelength.
|
||||||
all_bands = ['B1', 'B2', 'B3', 'B4', 'B5', 'B6', 'B7', 'B8']
|
all_bands = ('B1', 'B2', 'B3', 'B4', 'B5', 'B6', 'B7', 'B8')
|
||||||
rgb_bands = ['B4', 'B3', 'B1']
|
rgb_bands = ('B4', 'B3', 'B1')
|
||||||
|
|
||||||
def plot(
|
def plot(
|
||||||
self,
|
self,
|
||||||
|
|
|
@ -147,7 +147,7 @@ class BeninSmallHolderCashews(NonGeoDataset):
|
||||||
)
|
)
|
||||||
rgb_bands = ('B04', 'B03', 'B02')
|
rgb_bands = ('B04', 'B03', 'B02')
|
||||||
|
|
||||||
classes = [
|
classes = (
|
||||||
'No data',
|
'No data',
|
||||||
'Well-managed planatation',
|
'Well-managed planatation',
|
||||||
'Poorly-managed planatation',
|
'Poorly-managed planatation',
|
||||||
|
@ -155,7 +155,7 @@ class BeninSmallHolderCashews(NonGeoDataset):
|
||||||
'Residential',
|
'Residential',
|
||||||
'Background',
|
'Background',
|
||||||
'Uncertain',
|
'Uncertain',
|
||||||
]
|
)
|
||||||
|
|
||||||
# Same for all tiles
|
# Same for all tiles
|
||||||
tile_height = 1186
|
tile_height = 1186
|
||||||
|
@ -199,11 +199,13 @@ class BeninSmallHolderCashews(NonGeoDataset):
|
||||||
|
|
||||||
# Calculate the indices that we will use over all tiles
|
# Calculate the indices that we will use over all tiles
|
||||||
self.chips_metadata = []
|
self.chips_metadata = []
|
||||||
for y in list(range(0, self.tile_height - self.chip_size, stride)) + [
|
for y in [
|
||||||
self.tile_height - self.chip_size
|
*list(range(0, self.tile_height - self.chip_size, stride)),
|
||||||
|
self.tile_height - self.chip_size,
|
||||||
]:
|
]:
|
||||||
for x in list(range(0, self.tile_width - self.chip_size, stride)) + [
|
for x in [
|
||||||
self.tile_width - self.chip_size
|
*list(range(0, self.tile_width - self.chip_size, stride)),
|
||||||
|
self.tile_width - self.chip_size,
|
||||||
]:
|
]:
|
||||||
self.chips_metadata.append((y, x))
|
self.chips_metadata.append((y, x))
|
||||||
|
|
||||||
|
|
|
@ -7,6 +7,7 @@ import glob
|
||||||
import json
|
import json
|
||||||
import os
|
import os
|
||||||
from collections.abc import Callable
|
from collections.abc import Callable
|
||||||
|
from typing import ClassVar
|
||||||
|
|
||||||
import matplotlib.pyplot as plt
|
import matplotlib.pyplot as plt
|
||||||
import numpy as np
|
import numpy as np
|
||||||
|
@ -124,9 +125,9 @@ class BigEarthNet(NonGeoDataset):
|
||||||
|
|
||||||
* https://doi.org/10.1109/IGARSS.2019.8900532
|
* https://doi.org/10.1109/IGARSS.2019.8900532
|
||||||
|
|
||||||
""" # noqa: E501
|
"""
|
||||||
|
|
||||||
class_sets = {
|
class_sets: ClassVar[dict[int, list[str]]] = {
|
||||||
19: [
|
19: [
|
||||||
'Urban fabric',
|
'Urban fabric',
|
||||||
'Industrial or commercial units',
|
'Industrial or commercial units',
|
||||||
|
@ -197,7 +198,7 @@ class BigEarthNet(NonGeoDataset):
|
||||||
],
|
],
|
||||||
}
|
}
|
||||||
|
|
||||||
label_converter = {
|
label_converter: ClassVar[dict[int, int]] = {
|
||||||
0: 0,
|
0: 0,
|
||||||
1: 0,
|
1: 0,
|
||||||
2: 1,
|
2: 1,
|
||||||
|
@ -232,24 +233,24 @@ class BigEarthNet(NonGeoDataset):
|
||||||
42: 18,
|
42: 18,
|
||||||
}
|
}
|
||||||
|
|
||||||
splits_metadata = {
|
splits_metadata: ClassVar[dict[str, dict[str, str]]] = {
|
||||||
'train': {
|
'train': {
|
||||||
'url': 'https://git.tu-berlin.de/rsim/BigEarthNet-MM_19-classes_models/-/raw/9a5be07346ab0884b2d9517475c27ef9db9b5104/splits/train.csv?inline=false', # noqa: E501
|
'url': 'https://git.tu-berlin.de/rsim/BigEarthNet-MM_19-classes_models/-/raw/9a5be07346ab0884b2d9517475c27ef9db9b5104/splits/train.csv?inline=false',
|
||||||
'filename': 'bigearthnet-train.csv',
|
'filename': 'bigearthnet-train.csv',
|
||||||
'md5': '623e501b38ab7b12fe44f0083c00986d',
|
'md5': '623e501b38ab7b12fe44f0083c00986d',
|
||||||
},
|
},
|
||||||
'val': {
|
'val': {
|
||||||
'url': 'https://git.tu-berlin.de/rsim/BigEarthNet-MM_19-classes_models/-/raw/9a5be07346ab0884b2d9517475c27ef9db9b5104/splits/val.csv?inline=false', # noqa: E501
|
'url': 'https://git.tu-berlin.de/rsim/BigEarthNet-MM_19-classes_models/-/raw/9a5be07346ab0884b2d9517475c27ef9db9b5104/splits/val.csv?inline=false',
|
||||||
'filename': 'bigearthnet-val.csv',
|
'filename': 'bigearthnet-val.csv',
|
||||||
'md5': '22efe8ed9cbd71fa10742ff7df2b7978',
|
'md5': '22efe8ed9cbd71fa10742ff7df2b7978',
|
||||||
},
|
},
|
||||||
'test': {
|
'test': {
|
||||||
'url': 'https://git.tu-berlin.de/rsim/BigEarthNet-MM_19-classes_models/-/raw/9a5be07346ab0884b2d9517475c27ef9db9b5104/splits/test.csv?inline=false', # noqa: E501
|
'url': 'https://git.tu-berlin.de/rsim/BigEarthNet-MM_19-classes_models/-/raw/9a5be07346ab0884b2d9517475c27ef9db9b5104/splits/test.csv?inline=false',
|
||||||
'filename': 'bigearthnet-test.csv',
|
'filename': 'bigearthnet-test.csv',
|
||||||
'md5': '697fb90677e30571b9ac7699b7e5b432',
|
'md5': '697fb90677e30571b9ac7699b7e5b432',
|
||||||
},
|
},
|
||||||
}
|
}
|
||||||
metadata = {
|
metadata: ClassVar[dict[str, dict[str, str]]] = {
|
||||||
's1': {
|
's1': {
|
||||||
'url': 'https://zenodo.org/records/12687186/files/BigEarthNet-S1-v1.0.tar.gz',
|
'url': 'https://zenodo.org/records/12687186/files/BigEarthNet-S1-v1.0.tar.gz',
|
||||||
'md5': '94ced73440dea8c7b9645ee738c5a172',
|
'md5': '94ced73440dea8c7b9645ee738c5a172',
|
||||||
|
|
|
@ -50,7 +50,7 @@ class BioMassters(NonGeoDataset):
|
||||||
.. versionadded:: 0.5
|
.. versionadded:: 0.5
|
||||||
"""
|
"""
|
||||||
|
|
||||||
valid_splits = ['train', 'test']
|
valid_splits = ('train', 'test')
|
||||||
valid_sensors = ('S1', 'S2')
|
valid_sensors = ('S1', 'S2')
|
||||||
|
|
||||||
metadata_filename = 'The_BioMassters_-_features_metadata.csv.csv'
|
metadata_filename = 'The_BioMassters_-_features_metadata.csv.csv'
|
||||||
|
|
|
@ -30,7 +30,7 @@ class CanadianBuildingFootprints(VectorDataset):
|
||||||
# https://github.com/microsoft/CanadianBuildingFootprints/issues/11
|
# https://github.com/microsoft/CanadianBuildingFootprints/issues/11
|
||||||
|
|
||||||
url = 'https://usbuildingdata.blob.core.windows.net/canadian-buildings-v2/'
|
url = 'https://usbuildingdata.blob.core.windows.net/canadian-buildings-v2/'
|
||||||
provinces_territories = [
|
provinces_territories = (
|
||||||
'Alberta',
|
'Alberta',
|
||||||
'BritishColumbia',
|
'BritishColumbia',
|
||||||
'Manitoba',
|
'Manitoba',
|
||||||
|
@ -44,8 +44,8 @@ class CanadianBuildingFootprints(VectorDataset):
|
||||||
'Quebec',
|
'Quebec',
|
||||||
'Saskatchewan',
|
'Saskatchewan',
|
||||||
'YukonTerritory',
|
'YukonTerritory',
|
||||||
]
|
)
|
||||||
md5s = [
|
md5s = (
|
||||||
'8b4190424e57bb0902bd8ecb95a9235b',
|
'8b4190424e57bb0902bd8ecb95a9235b',
|
||||||
'fea05d6eb0006710729c675de63db839',
|
'fea05d6eb0006710729c675de63db839',
|
||||||
'adf11187362624d68f9c69aaa693c46f',
|
'adf11187362624d68f9c69aaa693c46f',
|
||||||
|
@ -59,7 +59,7 @@ class CanadianBuildingFootprints(VectorDataset):
|
||||||
'9ff4417ae00354d39a0cf193c8df592c',
|
'9ff4417ae00354d39a0cf193c8df592c',
|
||||||
'a51078d8e60082c7d3a3818240da6dd5',
|
'a51078d8e60082c7d3a3818240da6dd5',
|
||||||
'c11f3bd914ecabd7cac2cb2871ec0261',
|
'c11f3bd914ecabd7cac2cb2871ec0261',
|
||||||
]
|
)
|
||||||
|
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
|
|
|
@ -6,7 +6,7 @@
|
||||||
import os
|
import os
|
||||||
import pathlib
|
import pathlib
|
||||||
from collections.abc import Callable, Iterable
|
from collections.abc import Callable, Iterable
|
||||||
from typing import Any
|
from typing import Any, ClassVar
|
||||||
|
|
||||||
import matplotlib.pyplot as plt
|
import matplotlib.pyplot as plt
|
||||||
import torch
|
import torch
|
||||||
|
@ -38,7 +38,7 @@ class CDL(RasterDataset):
|
||||||
If you use this dataset in your research, please cite it using the following format:
|
If you use this dataset in your research, please cite it using the following format:
|
||||||
|
|
||||||
* https://www.nass.usda.gov/Research_and_Science/Cropland/sarsfaqs2.php#Section1_14.0
|
* https://www.nass.usda.gov/Research_and_Science/Cropland/sarsfaqs2.php#Section1_14.0
|
||||||
""" # noqa: E501
|
"""
|
||||||
|
|
||||||
filename_glob = '*_30m_cdls.tif'
|
filename_glob = '*_30m_cdls.tif'
|
||||||
filename_regex = r"""
|
filename_regex = r"""
|
||||||
|
@ -49,8 +49,8 @@ class CDL(RasterDataset):
|
||||||
date_format = '%Y'
|
date_format = '%Y'
|
||||||
is_image = False
|
is_image = False
|
||||||
|
|
||||||
url = 'https://www.nass.usda.gov/Research_and_Science/Cropland/Release/datasets/{}_30m_cdls.zip' # noqa: E501
|
url = 'https://www.nass.usda.gov/Research_and_Science/Cropland/Release/datasets/{}_30m_cdls.zip'
|
||||||
md5s = {
|
md5s: ClassVar[dict[int, str]] = {
|
||||||
2023: '8c7685d6278d50c554f934b16a6076b7',
|
2023: '8c7685d6278d50c554f934b16a6076b7',
|
||||||
2022: '754cf50670cdfee511937554785de3e6',
|
2022: '754cf50670cdfee511937554785de3e6',
|
||||||
2021: '27606eab08fe975aa138baad3e5dfcd8',
|
2021: '27606eab08fe975aa138baad3e5dfcd8',
|
||||||
|
@ -69,7 +69,7 @@ class CDL(RasterDataset):
|
||||||
2008: '0610f2f17ab60a9fbb3baeb7543993a4',
|
2008: '0610f2f17ab60a9fbb3baeb7543993a4',
|
||||||
}
|
}
|
||||||
|
|
||||||
cmap = {
|
cmap: ClassVar[dict[int, tuple[int, int, int, int]]] = {
|
||||||
0: (0, 0, 0, 255),
|
0: (0, 0, 0, 255),
|
||||||
1: (255, 211, 0, 255),
|
1: (255, 211, 0, 255),
|
||||||
2: (255, 37, 37, 255),
|
2: (255, 37, 37, 255),
|
||||||
|
|
|
@ -4,7 +4,8 @@
|
||||||
"""ChaBuD dataset."""
|
"""ChaBuD dataset."""
|
||||||
|
|
||||||
import os
|
import os
|
||||||
from collections.abc import Callable
|
from collections.abc import Callable, Sequence
|
||||||
|
from typing import ClassVar
|
||||||
|
|
||||||
import matplotlib.pyplot as plt
|
import matplotlib.pyplot as plt
|
||||||
import numpy as np
|
import numpy as np
|
||||||
|
@ -53,7 +54,7 @@ class ChaBuD(NonGeoDataset):
|
||||||
.. versionadded:: 0.6
|
.. versionadded:: 0.6
|
||||||
"""
|
"""
|
||||||
|
|
||||||
all_bands = [
|
all_bands = (
|
||||||
'B01',
|
'B01',
|
||||||
'B02',
|
'B02',
|
||||||
'B03',
|
'B03',
|
||||||
|
@ -66,10 +67,10 @@ class ChaBuD(NonGeoDataset):
|
||||||
'B09',
|
'B09',
|
||||||
'B11',
|
'B11',
|
||||||
'B12',
|
'B12',
|
||||||
]
|
)
|
||||||
rgb_bands = ['B04', 'B03', 'B02']
|
rgb_bands = ('B04', 'B03', 'B02')
|
||||||
folds = {'train': [1, 2, 3, 4], 'val': [0]}
|
folds: ClassVar[dict[str, list[int]]] = {'train': [1, 2, 3, 4], 'val': [0]}
|
||||||
url = 'https://hf.co/datasets/chabud-team/chabud-ecml-pkdd2023/resolve/de222d434e26379aa3d4f3dd1b2caf502427a8b2/train_eval.hdf5' # noqa: E501
|
url = 'https://hf.co/datasets/chabud-team/chabud-ecml-pkdd2023/resolve/de222d434e26379aa3d4f3dd1b2caf502427a8b2/train_eval.hdf5'
|
||||||
filename = 'train_eval.hdf5'
|
filename = 'train_eval.hdf5'
|
||||||
md5 = '15d78fb825f9a81dad600db828d22c08'
|
md5 = '15d78fb825f9a81dad600db828d22c08'
|
||||||
|
|
||||||
|
@ -77,7 +78,7 @@ class ChaBuD(NonGeoDataset):
|
||||||
self,
|
self,
|
||||||
root: Path = 'data',
|
root: Path = 'data',
|
||||||
split: str = 'train',
|
split: str = 'train',
|
||||||
bands: list[str] = all_bands,
|
bands: Sequence[str] = all_bands,
|
||||||
transforms: Callable[[dict[str, Tensor]], dict[str, Tensor]] | None = None,
|
transforms: Callable[[dict[str, Tensor]], dict[str, Tensor]] | None = None,
|
||||||
download: bool = False,
|
download: bool = False,
|
||||||
checksum: bool = False,
|
checksum: bool = False,
|
||||||
|
|
|
@ -9,7 +9,7 @@ import pathlib
|
||||||
import sys
|
import sys
|
||||||
from abc import ABC, abstractmethod
|
from abc import ABC, abstractmethod
|
||||||
from collections.abc import Callable, Iterable, Sequence
|
from collections.abc import Callable, Iterable, Sequence
|
||||||
from typing import Any, cast
|
from typing import Any, ClassVar, cast
|
||||||
|
|
||||||
import fiona
|
import fiona
|
||||||
import matplotlib.pyplot as plt
|
import matplotlib.pyplot as plt
|
||||||
|
@ -39,7 +39,7 @@ class Chesapeake(RasterDataset, ABC):
|
||||||
|
|
||||||
The Chesapeake Bay Land Use and Land Cover Database (LULC) facilitates
|
The Chesapeake Bay Land Use and Land Cover Database (LULC) facilitates
|
||||||
characterization of the landscape and land change for and between discrete time
|
characterization of the landscape and land change for and between discrete time
|
||||||
periods. The database was developed by the University of Vermont’s Spatial Analysis
|
periods. The database was developed by the University of Vermont's Spatial Analysis
|
||||||
Laboratory in cooperation with Chesapeake Conservancy (CC) and U.S. Geological
|
Laboratory in cooperation with Chesapeake Conservancy (CC) and U.S. Geological
|
||||||
Survey (USGS) as part of a 6-year Cooperative Agreement between Chesapeake
|
Survey (USGS) as part of a 6-year Cooperative Agreement between Chesapeake
|
||||||
Conservancy and the U.S. Environmental Protection Agency (EPA) and a separate
|
Conservancy and the U.S. Environmental Protection Agency (EPA) and a separate
|
||||||
|
@ -83,7 +83,7 @@ class Chesapeake(RasterDataset, ABC):
|
||||||
"""State abbreviation."""
|
"""State abbreviation."""
|
||||||
return self.__class__.__name__[-2:].lower()
|
return self.__class__.__name__[-2:].lower()
|
||||||
|
|
||||||
cmap = {
|
cmap: ClassVar[dict[int, tuple[int, int, int, int]]] = {
|
||||||
11: (0, 92, 230, 255),
|
11: (0, 92, 230, 255),
|
||||||
12: (0, 92, 230, 255),
|
12: (0, 92, 230, 255),
|
||||||
13: (0, 92, 230, 255),
|
13: (0, 92, 230, 255),
|
||||||
|
@ -255,7 +255,7 @@ class Chesapeake(RasterDataset, ABC):
|
||||||
class ChesapeakeDC(Chesapeake):
|
class ChesapeakeDC(Chesapeake):
|
||||||
"""This subset of the dataset contains data only for Washington, D.C."""
|
"""This subset of the dataset contains data only for Washington, D.C."""
|
||||||
|
|
||||||
md5s = {
|
md5s: ClassVar[dict[int, str]] = {
|
||||||
2013: '9f1df21afbb9d5c0fcf33af7f6750a7f',
|
2013: '9f1df21afbb9d5c0fcf33af7f6750a7f',
|
||||||
2017: 'c45e4af2950e1c93ecd47b61af296d9b',
|
2017: 'c45e4af2950e1c93ecd47b61af296d9b',
|
||||||
}
|
}
|
||||||
|
@ -264,7 +264,7 @@ class ChesapeakeDC(Chesapeake):
|
||||||
class ChesapeakeDE(Chesapeake):
|
class ChesapeakeDE(Chesapeake):
|
||||||
"""This subset of the dataset contains data only for Delaware."""
|
"""This subset of the dataset contains data only for Delaware."""
|
||||||
|
|
||||||
md5s = {
|
md5s: ClassVar[dict[int, str]] = {
|
||||||
2013: '5850d96d897babba85610658aeb5951a',
|
2013: '5850d96d897babba85610658aeb5951a',
|
||||||
2018: 'ee94c8efeae423d898677104117bdebc',
|
2018: 'ee94c8efeae423d898677104117bdebc',
|
||||||
}
|
}
|
||||||
|
@ -273,7 +273,7 @@ class ChesapeakeDE(Chesapeake):
|
||||||
class ChesapeakeMD(Chesapeake):
|
class ChesapeakeMD(Chesapeake):
|
||||||
"""This subset of the dataset contains data only for Maryland."""
|
"""This subset of the dataset contains data only for Maryland."""
|
||||||
|
|
||||||
md5s = {
|
md5s: ClassVar[dict[int, str]] = {
|
||||||
2013: '9c3ca5040668d15284c1bd64b7d6c7a0',
|
2013: '9c3ca5040668d15284c1bd64b7d6c7a0',
|
||||||
2018: '0647530edf8bec6e60f82760dcc7db9c',
|
2018: '0647530edf8bec6e60f82760dcc7db9c',
|
||||||
}
|
}
|
||||||
|
@ -282,7 +282,7 @@ class ChesapeakeMD(Chesapeake):
|
||||||
class ChesapeakeNY(Chesapeake):
|
class ChesapeakeNY(Chesapeake):
|
||||||
"""This subset of the dataset contains data only for New York."""
|
"""This subset of the dataset contains data only for New York."""
|
||||||
|
|
||||||
md5s = {
|
md5s: ClassVar[dict[int, str]] = {
|
||||||
2013: '38a29b721610ba661a7f8b6ec71a48b7',
|
2013: '38a29b721610ba661a7f8b6ec71a48b7',
|
||||||
2017: '4c1b1a50fd9368cd7b8b12c4d80c63f3',
|
2017: '4c1b1a50fd9368cd7b8b12c4d80c63f3',
|
||||||
}
|
}
|
||||||
|
@ -291,7 +291,7 @@ class ChesapeakeNY(Chesapeake):
|
||||||
class ChesapeakePA(Chesapeake):
|
class ChesapeakePA(Chesapeake):
|
||||||
"""This subset of the dataset contains data only for Pennsylvania."""
|
"""This subset of the dataset contains data only for Pennsylvania."""
|
||||||
|
|
||||||
md5s = {
|
md5s: ClassVar[dict[int, str]] = {
|
||||||
2013: '86febd603a120a49ef7d23ef486152a3',
|
2013: '86febd603a120a49ef7d23ef486152a3',
|
||||||
2017: 'b11d92e4471e8cb887c790d488a338c1',
|
2017: 'b11d92e4471e8cb887c790d488a338c1',
|
||||||
}
|
}
|
||||||
|
@ -300,7 +300,7 @@ class ChesapeakePA(Chesapeake):
|
||||||
class ChesapeakeVA(Chesapeake):
|
class ChesapeakeVA(Chesapeake):
|
||||||
"""This subset of the dataset contains data only for Virginia."""
|
"""This subset of the dataset contains data only for Virginia."""
|
||||||
|
|
||||||
md5s = {
|
md5s: ClassVar[dict[int, str]] = {
|
||||||
2014: '49c9700c71854eebd00de24d8488eb7c',
|
2014: '49c9700c71854eebd00de24d8488eb7c',
|
||||||
2018: '51731c8b5632978bfd1df869ea10db5b',
|
2018: '51731c8b5632978bfd1df869ea10db5b',
|
||||||
}
|
}
|
||||||
|
@ -309,7 +309,7 @@ class ChesapeakeVA(Chesapeake):
|
||||||
class ChesapeakeWV(Chesapeake):
|
class ChesapeakeWV(Chesapeake):
|
||||||
"""This subset of the dataset contains data only for West Virginia."""
|
"""This subset of the dataset contains data only for West Virginia."""
|
||||||
|
|
||||||
md5s = {
|
md5s: ClassVar[dict[int, str]] = {
|
||||||
2014: '32fea42fae147bd58a83e3ea6cccfb94',
|
2014: '32fea42fae147bd58a83e3ea6cccfb94',
|
||||||
2018: '80f25dcba72e39685ab33215c5d97292',
|
2018: '80f25dcba72e39685ab33215c5d97292',
|
||||||
}
|
}
|
||||||
|
@ -337,16 +337,16 @@ class ChesapeakeCVPR(GeoDataset):
|
||||||
* https://doi.org/10.1109/cvpr.2019.01301
|
* https://doi.org/10.1109/cvpr.2019.01301
|
||||||
"""
|
"""
|
||||||
|
|
||||||
subdatasets = ['base', 'prior_extension']
|
subdatasets = ('base', 'prior_extension')
|
||||||
urls = {
|
urls: ClassVar[dict[str, str]] = {
|
||||||
'base': 'https://lilablobssc.blob.core.windows.net/lcmcvpr2019/cvpr_chesapeake_landcover.zip', # noqa: E501
|
'base': 'https://lilablobssc.blob.core.windows.net/lcmcvpr2019/cvpr_chesapeake_landcover.zip',
|
||||||
'prior_extension': 'https://zenodo.org/record/5866525/files/cvpr_chesapeake_landcover_prior_extension.zip?download=1', # noqa: E501
|
'prior_extension': 'https://zenodo.org/record/5866525/files/cvpr_chesapeake_landcover_prior_extension.zip?download=1',
|
||||||
}
|
}
|
||||||
filenames = {
|
filenames: ClassVar[dict[str, str]] = {
|
||||||
'base': 'cvpr_chesapeake_landcover.zip',
|
'base': 'cvpr_chesapeake_landcover.zip',
|
||||||
'prior_extension': 'cvpr_chesapeake_landcover_prior_extension.zip',
|
'prior_extension': 'cvpr_chesapeake_landcover_prior_extension.zip',
|
||||||
}
|
}
|
||||||
md5s = {
|
md5s: ClassVar[dict[str, str]] = {
|
||||||
'base': '1225ccbb9590e9396875f221e5031514',
|
'base': '1225ccbb9590e9396875f221e5031514',
|
||||||
'prior_extension': '402f41d07823c8faf7ea6960d7c4e17a',
|
'prior_extension': '402f41d07823c8faf7ea6960d7c4e17a',
|
||||||
}
|
}
|
||||||
|
@ -354,7 +354,7 @@ class ChesapeakeCVPR(GeoDataset):
|
||||||
crs = CRS.from_epsg(3857)
|
crs = CRS.from_epsg(3857)
|
||||||
res = 1
|
res = 1
|
||||||
|
|
||||||
lc_cmap = {
|
lc_cmap: ClassVar[dict[int, tuple[int, int, int, int]]] = {
|
||||||
0: (0, 0, 0, 0),
|
0: (0, 0, 0, 0),
|
||||||
1: (0, 197, 255, 255),
|
1: (0, 197, 255, 255),
|
||||||
2: (38, 115, 0, 255),
|
2: (38, 115, 0, 255),
|
||||||
|
@ -374,7 +374,7 @@ class ChesapeakeCVPR(GeoDataset):
|
||||||
]
|
]
|
||||||
)
|
)
|
||||||
|
|
||||||
valid_layers = [
|
valid_layers = (
|
||||||
'naip-new',
|
'naip-new',
|
||||||
'naip-old',
|
'naip-old',
|
||||||
'landsat-leaf-on',
|
'landsat-leaf-on',
|
||||||
|
@ -383,8 +383,8 @@ class ChesapeakeCVPR(GeoDataset):
|
||||||
'lc',
|
'lc',
|
||||||
'buildings',
|
'buildings',
|
||||||
'prior_from_cooccurrences_101_31_no_osm_no_buildings',
|
'prior_from_cooccurrences_101_31_no_osm_no_buildings',
|
||||||
]
|
)
|
||||||
states = ['de', 'md', 'va', 'wv', 'pa', 'ny']
|
states = ('de', 'md', 'va', 'wv', 'pa', 'ny')
|
||||||
splits = (
|
splits = (
|
||||||
[f'{state}-train' for state in states]
|
[f'{state}-train' for state in states]
|
||||||
+ [f'{state}-val' for state in states]
|
+ [f'{state}-val' for state in states]
|
||||||
|
@ -392,7 +392,7 @@ class ChesapeakeCVPR(GeoDataset):
|
||||||
)
|
)
|
||||||
|
|
||||||
# these are used to check the integrity of the dataset
|
# these are used to check the integrity of the dataset
|
||||||
_files = [
|
_files = (
|
||||||
'de_1m_2013_extended-debuffered-test_tiles',
|
'de_1m_2013_extended-debuffered-test_tiles',
|
||||||
'de_1m_2013_extended-debuffered-train_tiles',
|
'de_1m_2013_extended-debuffered-train_tiles',
|
||||||
'de_1m_2013_extended-debuffered-val_tiles',
|
'de_1m_2013_extended-debuffered-val_tiles',
|
||||||
|
@ -412,18 +412,18 @@ class ChesapeakeCVPR(GeoDataset):
|
||||||
'wv_1m_2014_extended-debuffered-train_tiles',
|
'wv_1m_2014_extended-debuffered-train_tiles',
|
||||||
'wv_1m_2014_extended-debuffered-val_tiles',
|
'wv_1m_2014_extended-debuffered-val_tiles',
|
||||||
'wv_1m_2014_extended-debuffered-val_tiles/m_3708035_ne_17_1_buildings.tif',
|
'wv_1m_2014_extended-debuffered-val_tiles/m_3708035_ne_17_1_buildings.tif',
|
||||||
'wv_1m_2014_extended-debuffered-val_tiles/m_3708035_ne_17_1_landsat-leaf-off.tif', # noqa: E501
|
'wv_1m_2014_extended-debuffered-val_tiles/m_3708035_ne_17_1_landsat-leaf-off.tif',
|
||||||
'wv_1m_2014_extended-debuffered-val_tiles/m_3708035_ne_17_1_landsat-leaf-on.tif', # noqa: E501
|
'wv_1m_2014_extended-debuffered-val_tiles/m_3708035_ne_17_1_landsat-leaf-on.tif',
|
||||||
'wv_1m_2014_extended-debuffered-val_tiles/m_3708035_ne_17_1_lc.tif',
|
'wv_1m_2014_extended-debuffered-val_tiles/m_3708035_ne_17_1_lc.tif',
|
||||||
'wv_1m_2014_extended-debuffered-val_tiles/m_3708035_ne_17_1_naip-new.tif',
|
'wv_1m_2014_extended-debuffered-val_tiles/m_3708035_ne_17_1_naip-new.tif',
|
||||||
'wv_1m_2014_extended-debuffered-val_tiles/m_3708035_ne_17_1_naip-old.tif',
|
'wv_1m_2014_extended-debuffered-val_tiles/m_3708035_ne_17_1_naip-old.tif',
|
||||||
'wv_1m_2014_extended-debuffered-val_tiles/m_3708035_ne_17_1_nlcd.tif',
|
'wv_1m_2014_extended-debuffered-val_tiles/m_3708035_ne_17_1_nlcd.tif',
|
||||||
'wv_1m_2014_extended-debuffered-val_tiles/m_3708035_ne_17_1_prior_from_cooccurrences_101_31_no_osm_no_buildings.tif', # noqa: E501
|
'wv_1m_2014_extended-debuffered-val_tiles/m_3708035_ne_17_1_prior_from_cooccurrences_101_31_no_osm_no_buildings.tif',
|
||||||
'spatial_index.geojson',
|
'spatial_index.geojson',
|
||||||
]
|
)
|
||||||
|
|
||||||
p_src_crs = pyproj.CRS('epsg:3857')
|
p_src_crs = pyproj.CRS('epsg:3857')
|
||||||
p_transformers = {
|
p_transformers: ClassVar[dict[str, CRS]] = {
|
||||||
'epsg:26917': pyproj.Transformer.from_crs(
|
'epsg:26917': pyproj.Transformer.from_crs(
|
||||||
p_src_crs, pyproj.CRS('epsg:26917'), always_xy=True
|
p_src_crs, pyproj.CRS('epsg:26917'), always_xy=True
|
||||||
).transform,
|
).transform,
|
||||||
|
@ -511,7 +511,7 @@ class ChesapeakeCVPR(GeoDataset):
|
||||||
'lc': row['properties']['lc'],
|
'lc': row['properties']['lc'],
|
||||||
'nlcd': row['properties']['nlcd'],
|
'nlcd': row['properties']['nlcd'],
|
||||||
'buildings': row['properties']['buildings'],
|
'buildings': row['properties']['buildings'],
|
||||||
'prior_from_cooccurrences_101_31_no_osm_no_buildings': prior_fn, # noqa: E501
|
'prior_from_cooccurrences_101_31_no_osm_no_buildings': prior_fn,
|
||||||
},
|
},
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
|
@ -5,6 +5,7 @@
|
||||||
|
|
||||||
import os
|
import os
|
||||||
from collections.abc import Callable, Sequence
|
from collections.abc import Callable, Sequence
|
||||||
|
from typing import ClassVar
|
||||||
|
|
||||||
import matplotlib.pyplot as plt
|
import matplotlib.pyplot as plt
|
||||||
import numpy as np
|
import numpy as np
|
||||||
|
@ -55,9 +56,9 @@ class CloudCoverDetection(NonGeoDataset):
|
||||||
"""
|
"""
|
||||||
|
|
||||||
url = 'https://radiantearth.blob.core.windows.net/mlhub/ref_cloud_cover_detection_challenge_v1/final'
|
url = 'https://radiantearth.blob.core.windows.net/mlhub/ref_cloud_cover_detection_challenge_v1/final'
|
||||||
all_bands = ['B02', 'B03', 'B04', 'B08']
|
all_bands = ('B02', 'B03', 'B04', 'B08')
|
||||||
rgb_bands = ['B04', 'B03', 'B02']
|
rgb_bands = ('B04', 'B03', 'B02')
|
||||||
splits = {'train': 'public', 'test': 'private'}
|
splits: ClassVar[dict[str, str]] = {'train': 'public', 'test': 'private'}
|
||||||
|
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
|
|
|
@ -42,7 +42,7 @@ class CMSGlobalMangroveCanopy(RasterDataset):
|
||||||
zipfile = 'CMS_Global_Map_Mangrove_Canopy_1665.zip'
|
zipfile = 'CMS_Global_Map_Mangrove_Canopy_1665.zip'
|
||||||
md5 = '3e7f9f23bf971c25e828b36e6c5496e3'
|
md5 = '3e7f9f23bf971c25e828b36e6c5496e3'
|
||||||
|
|
||||||
all_countries = [
|
all_countries = (
|
||||||
'AndamanAndNicobar',
|
'AndamanAndNicobar',
|
||||||
'Angola',
|
'Angola',
|
||||||
'Anguilla',
|
'Anguilla',
|
||||||
|
@ -164,9 +164,9 @@ class CMSGlobalMangroveCanopy(RasterDataset):
|
||||||
'VirginIslandsUs',
|
'VirginIslandsUs',
|
||||||
'WallisAndFutuna',
|
'WallisAndFutuna',
|
||||||
'Yemen',
|
'Yemen',
|
||||||
]
|
)
|
||||||
|
|
||||||
measurements = ['agb', 'hba95', 'hmax95']
|
measurements = ('agb', 'hba95', 'hmax95')
|
||||||
|
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
|
|
|
@ -50,12 +50,12 @@ class COWC(NonGeoDataset, abc.ABC):
|
||||||
|
|
||||||
@property
|
@property
|
||||||
@abc.abstractmethod
|
@abc.abstractmethod
|
||||||
def filenames(self) -> list[str]:
|
def filenames(self) -> tuple[str, ...]:
|
||||||
"""List of files to download."""
|
"""List of files to download."""
|
||||||
|
|
||||||
@property
|
@property
|
||||||
@abc.abstractmethod
|
@abc.abstractmethod
|
||||||
def md5s(self) -> list[str]:
|
def md5s(self) -> tuple[str, ...]:
|
||||||
"""List of MD5 checksums of files to download."""
|
"""List of MD5 checksums of files to download."""
|
||||||
|
|
||||||
@property
|
@property
|
||||||
|
@ -239,7 +239,7 @@ class COWCCounting(COWC):
|
||||||
base_url = (
|
base_url = (
|
||||||
'https://gdo152.llnl.gov/cowc/download/cowc/datasets/patch_sets/counting/'
|
'https://gdo152.llnl.gov/cowc/download/cowc/datasets/patch_sets/counting/'
|
||||||
)
|
)
|
||||||
filenames = [
|
filenames = (
|
||||||
'COWC_train_list_64_class.txt.bz2',
|
'COWC_train_list_64_class.txt.bz2',
|
||||||
'COWC_test_list_64_class.txt.bz2',
|
'COWC_test_list_64_class.txt.bz2',
|
||||||
'COWC_Counting_Toronto_ISPRS.tbz',
|
'COWC_Counting_Toronto_ISPRS.tbz',
|
||||||
|
@ -248,8 +248,8 @@ class COWCCounting(COWC):
|
||||||
'COWC_Counting_Vaihingen_ISPRS.tbz',
|
'COWC_Counting_Vaihingen_ISPRS.tbz',
|
||||||
'COWC_Counting_Columbus_CSUAV_AFRL.tbz',
|
'COWC_Counting_Columbus_CSUAV_AFRL.tbz',
|
||||||
'COWC_Counting_Utah_AGRC.tbz',
|
'COWC_Counting_Utah_AGRC.tbz',
|
||||||
]
|
)
|
||||||
md5s = [
|
md5s = (
|
||||||
'187543d20fa6d591b8da51136e8ef8fb',
|
'187543d20fa6d591b8da51136e8ef8fb',
|
||||||
'930cfd6e160a7b36db03146282178807',
|
'930cfd6e160a7b36db03146282178807',
|
||||||
'bc2613196dfa93e66d324ae43e7c1fdb',
|
'bc2613196dfa93e66d324ae43e7c1fdb',
|
||||||
|
@ -258,7 +258,7 @@ class COWCCounting(COWC):
|
||||||
'4009c1e420566390746f5b4db02afdb9',
|
'4009c1e420566390746f5b4db02afdb9',
|
||||||
'daf8033c4e8ceebbf2c3cac3fabb8b10',
|
'daf8033c4e8ceebbf2c3cac3fabb8b10',
|
||||||
'777ec107ed2a3d54597a739ce74f95ad',
|
'777ec107ed2a3d54597a739ce74f95ad',
|
||||||
]
|
)
|
||||||
filename = 'COWC_{}_list_64_class.txt'
|
filename = 'COWC_{}_list_64_class.txt'
|
||||||
|
|
||||||
|
|
||||||
|
@ -268,7 +268,7 @@ class COWCDetection(COWC):
|
||||||
base_url = (
|
base_url = (
|
||||||
'https://gdo152.llnl.gov/cowc/download/cowc/datasets/patch_sets/detection/'
|
'https://gdo152.llnl.gov/cowc/download/cowc/datasets/patch_sets/detection/'
|
||||||
)
|
)
|
||||||
filenames = [
|
filenames = (
|
||||||
'COWC_train_list_detection.txt.bz2',
|
'COWC_train_list_detection.txt.bz2',
|
||||||
'COWC_test_list_detection.txt.bz2',
|
'COWC_test_list_detection.txt.bz2',
|
||||||
'COWC_Detection_Toronto_ISPRS.tbz',
|
'COWC_Detection_Toronto_ISPRS.tbz',
|
||||||
|
@ -277,8 +277,8 @@ class COWCDetection(COWC):
|
||||||
'COWC_Detection_Vaihingen_ISPRS.tbz',
|
'COWC_Detection_Vaihingen_ISPRS.tbz',
|
||||||
'COWC_Detection_Columbus_CSUAV_AFRL.tbz',
|
'COWC_Detection_Columbus_CSUAV_AFRL.tbz',
|
||||||
'COWC_Detection_Utah_AGRC.tbz',
|
'COWC_Detection_Utah_AGRC.tbz',
|
||||||
]
|
)
|
||||||
md5s = [
|
md5s = (
|
||||||
'c954a5a3dac08c220b10cfbeec83893c',
|
'c954a5a3dac08c220b10cfbeec83893c',
|
||||||
'c6c2d0a78f12a2ad88b286b724a57c1a',
|
'c6c2d0a78f12a2ad88b286b724a57c1a',
|
||||||
'11af24f43b198b0f13c8e94814008a48',
|
'11af24f43b198b0f13c8e94814008a48',
|
||||||
|
@ -287,7 +287,7 @@ class COWCDetection(COWC):
|
||||||
'23945d5b22455450a938382ccc2a8b27',
|
'23945d5b22455450a938382ccc2a8b27',
|
||||||
'f40522dc97bea41b10117d4a5b946a6f',
|
'f40522dc97bea41b10117d4a5b946a6f',
|
||||||
'195da7c9443a939a468c9f232fd86ee3',
|
'195da7c9443a939a468c9f232fd86ee3',
|
||||||
]
|
)
|
||||||
filename = 'COWC_{}_list_detection.txt'
|
filename = 'COWC_{}_list_detection.txt'
|
||||||
|
|
||||||
|
|
||||||
|
|
|
@ -7,6 +7,7 @@ import glob
|
||||||
import json
|
import json
|
||||||
import os
|
import os
|
||||||
from collections.abc import Callable
|
from collections.abc import Callable
|
||||||
|
from typing import ClassVar
|
||||||
|
|
||||||
import matplotlib.pyplot as plt
|
import matplotlib.pyplot as plt
|
||||||
import numpy as np
|
import numpy as np
|
||||||
|
@ -55,7 +56,7 @@ class CropHarvest(NonGeoDataset):
|
||||||
"""
|
"""
|
||||||
|
|
||||||
# https://github.com/nasaharvest/cropharvest/blob/main/cropharvest/bands.py
|
# https://github.com/nasaharvest/cropharvest/blob/main/cropharvest/bands.py
|
||||||
all_bands = [
|
all_bands = (
|
||||||
'VV',
|
'VV',
|
||||||
'VH',
|
'VH',
|
||||||
'B2',
|
'B2',
|
||||||
|
@ -74,12 +75,12 @@ class CropHarvest(NonGeoDataset):
|
||||||
'elevation',
|
'elevation',
|
||||||
'slope',
|
'slope',
|
||||||
'NDVI',
|
'NDVI',
|
||||||
]
|
)
|
||||||
rgb_bands = ['B4', 'B3', 'B2']
|
rgb_bands = ('B4', 'B3', 'B2')
|
||||||
|
|
||||||
features_url = 'https://zenodo.org/records/7257688/files/features.tar.gz?download=1'
|
features_url = 'https://zenodo.org/records/7257688/files/features.tar.gz?download=1'
|
||||||
labels_url = 'https://zenodo.org/records/7257688/files/labels.geojson?download=1'
|
labels_url = 'https://zenodo.org/records/7257688/files/labels.geojson?download=1'
|
||||||
file_dict = {
|
file_dict: ClassVar[dict[str, dict[str, str]]] = {
|
||||||
'features': {
|
'features': {
|
||||||
'url': features_url,
|
'url': features_url,
|
||||||
'filename': 'features.tar.gz',
|
'filename': 'features.tar.gz',
|
||||||
|
|
|
@ -65,8 +65,8 @@ class CV4AKenyaCropType(NonGeoDataset):
|
||||||
"""
|
"""
|
||||||
|
|
||||||
url = 'https://radiantearth.blob.core.windows.net/mlhub/kenya-crop-challenge'
|
url = 'https://radiantearth.blob.core.windows.net/mlhub/kenya-crop-challenge'
|
||||||
tiles = list(map(str, range(4)))
|
tiles = tuple(map(str, range(4)))
|
||||||
dates = [
|
dates = (
|
||||||
'20190606',
|
'20190606',
|
||||||
'20190701',
|
'20190701',
|
||||||
'20190706',
|
'20190706',
|
||||||
|
@ -80,7 +80,7 @@ class CV4AKenyaCropType(NonGeoDataset):
|
||||||
'20190924',
|
'20190924',
|
||||||
'20191004',
|
'20191004',
|
||||||
'20191103',
|
'20191103',
|
||||||
]
|
)
|
||||||
all_bands = (
|
all_bands = (
|
||||||
'B01',
|
'B01',
|
||||||
'B02',
|
'B02',
|
||||||
|
@ -96,7 +96,7 @@ class CV4AKenyaCropType(NonGeoDataset):
|
||||||
'B12',
|
'B12',
|
||||||
'CLD',
|
'CLD',
|
||||||
)
|
)
|
||||||
rgb_bands = ['B04', 'B03', 'B02']
|
rgb_bands = ('B04', 'B03', 'B02')
|
||||||
|
|
||||||
# Same for all tiles
|
# Same for all tiles
|
||||||
tile_height = 3035
|
tile_height = 3035
|
||||||
|
@ -141,11 +141,13 @@ class CV4AKenyaCropType(NonGeoDataset):
|
||||||
# Calculate the indices that we will use over all tiles
|
# Calculate the indices that we will use over all tiles
|
||||||
self.chips_metadata = []
|
self.chips_metadata = []
|
||||||
for tile_index in range(len(self.tiles)):
|
for tile_index in range(len(self.tiles)):
|
||||||
for y in list(range(0, self.tile_height - self.chip_size, stride)) + [
|
for y in [
|
||||||
self.tile_height - self.chip_size
|
*list(range(0, self.tile_height - self.chip_size, stride)),
|
||||||
|
self.tile_height - self.chip_size,
|
||||||
]:
|
]:
|
||||||
for x in list(range(0, self.tile_width - self.chip_size, stride)) + [
|
for x in [
|
||||||
self.tile_width - self.chip_size
|
*list(range(0, self.tile_width - self.chip_size, stride)),
|
||||||
|
self.tile_width - self.chip_size,
|
||||||
]:
|
]:
|
||||||
self.chips_metadata.append((tile_index, y, x))
|
self.chips_metadata.append((tile_index, y, x))
|
||||||
|
|
||||||
|
|
|
@ -74,13 +74,13 @@ class DeepGlobeLandCover(NonGeoDataset):
|
||||||
$ unzip deepglobe2018-landcover-segmentation-traindataset.zip
|
$ unzip deepglobe2018-landcover-segmentation-traindataset.zip
|
||||||
|
|
||||||
.. versionadded:: 0.3
|
.. versionadded:: 0.3
|
||||||
""" # noqa: E501
|
"""
|
||||||
|
|
||||||
filename = 'data.zip'
|
filename = 'data.zip'
|
||||||
data_root = 'data'
|
data_root = 'data'
|
||||||
md5 = 'f32684b0b2bf6f8d604cd359a399c061'
|
md5 = 'f32684b0b2bf6f8d604cd359a399c061'
|
||||||
splits = ['train', 'test']
|
splits = ('train', 'test')
|
||||||
classes = [
|
classes = (
|
||||||
'Urban land',
|
'Urban land',
|
||||||
'Agriculture land',
|
'Agriculture land',
|
||||||
'Rangeland',
|
'Rangeland',
|
||||||
|
@ -88,8 +88,8 @@ class DeepGlobeLandCover(NonGeoDataset):
|
||||||
'Water',
|
'Water',
|
||||||
'Barren land',
|
'Barren land',
|
||||||
'Unknown',
|
'Unknown',
|
||||||
]
|
)
|
||||||
colormap = [
|
colormap = (
|
||||||
(0, 255, 255),
|
(0, 255, 255),
|
||||||
(255, 255, 0),
|
(255, 255, 0),
|
||||||
(255, 0, 255),
|
(255, 0, 255),
|
||||||
|
@ -97,7 +97,7 @@ class DeepGlobeLandCover(NonGeoDataset):
|
||||||
(0, 0, 255),
|
(0, 0, 255),
|
||||||
(255, 255, 255),
|
(255, 255, 255),
|
||||||
(0, 0, 0),
|
(0, 0, 0),
|
||||||
]
|
)
|
||||||
|
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
|
@ -246,12 +246,15 @@ class DeepGlobeLandCover(NonGeoDataset):
|
||||||
"""
|
"""
|
||||||
ncols = 1
|
ncols = 1
|
||||||
image1 = draw_semantic_segmentation_masks(
|
image1 = draw_semantic_segmentation_masks(
|
||||||
sample['image'], sample['mask'], alpha=alpha, colors=self.colormap
|
sample['image'], sample['mask'], alpha=alpha, colors=list(self.colormap)
|
||||||
)
|
)
|
||||||
if 'prediction' in sample:
|
if 'prediction' in sample:
|
||||||
ncols += 1
|
ncols += 1
|
||||||
image2 = draw_semantic_segmentation_masks(
|
image2 = draw_semantic_segmentation_masks(
|
||||||
sample['image'], sample['prediction'], alpha=alpha, colors=self.colormap
|
sample['image'],
|
||||||
|
sample['prediction'],
|
||||||
|
alpha=alpha,
|
||||||
|
colors=list(self.colormap),
|
||||||
)
|
)
|
||||||
|
|
||||||
fig, axs = plt.subplots(ncols=ncols, figsize=(ncols * 10, 10))
|
fig, axs = plt.subplots(ncols=ncols, figsize=(ncols * 10, 10))
|
||||||
|
|
|
@ -6,6 +6,7 @@
|
||||||
import glob
|
import glob
|
||||||
import os
|
import os
|
||||||
from collections.abc import Callable, Sequence
|
from collections.abc import Callable, Sequence
|
||||||
|
from typing import ClassVar
|
||||||
|
|
||||||
import matplotlib.pyplot as plt
|
import matplotlib.pyplot as plt
|
||||||
import numpy as np
|
import numpy as np
|
||||||
|
@ -75,9 +76,9 @@ class DFC2022(NonGeoDataset):
|
||||||
* https://doi.org/10.1007/s10994-020-05943-y
|
* https://doi.org/10.1007/s10994-020-05943-y
|
||||||
|
|
||||||
.. versionadded:: 0.3
|
.. versionadded:: 0.3
|
||||||
""" # noqa: E501
|
"""
|
||||||
|
|
||||||
classes = [
|
classes = (
|
||||||
'No information',
|
'No information',
|
||||||
'Urban fabric',
|
'Urban fabric',
|
||||||
'Industrial, commercial, public, military, private and transport units',
|
'Industrial, commercial, public, military, private and transport units',
|
||||||
|
@ -94,8 +95,8 @@ class DFC2022(NonGeoDataset):
|
||||||
'Wetlands',
|
'Wetlands',
|
||||||
'Water',
|
'Water',
|
||||||
'Clouds and Shadows',
|
'Clouds and Shadows',
|
||||||
]
|
)
|
||||||
colormap = [
|
colormap = (
|
||||||
'#231F20',
|
'#231F20',
|
||||||
'#DB5F57',
|
'#DB5F57',
|
||||||
'#DB9757',
|
'#DB9757',
|
||||||
|
@ -112,8 +113,8 @@ class DFC2022(NonGeoDataset):
|
||||||
'#579BDB',
|
'#579BDB',
|
||||||
'#0062FF',
|
'#0062FF',
|
||||||
'#231F20',
|
'#231F20',
|
||||||
]
|
)
|
||||||
metadata = {
|
metadata: ClassVar[dict[str, dict[str, str]]] = {
|
||||||
'train': {
|
'train': {
|
||||||
'filename': 'labeled_train.zip',
|
'filename': 'labeled_train.zip',
|
||||||
'md5': '2e87d6a218e466dd0566797d7298c7a9',
|
'md5': '2e87d6a218e466dd0566797d7298c7a9',
|
||||||
|
|
|
@ -6,7 +6,7 @@
|
||||||
import os
|
import os
|
||||||
import sys
|
import sys
|
||||||
from collections.abc import Callable, Sequence
|
from collections.abc import Callable, Sequence
|
||||||
from typing import Any, cast
|
from typing import Any, ClassVar, cast
|
||||||
|
|
||||||
import fiona
|
import fiona
|
||||||
import matplotlib.pyplot as plt
|
import matplotlib.pyplot as plt
|
||||||
|
@ -54,9 +54,9 @@ class EnviroAtlas(GeoDataset):
|
||||||
crs = CRS.from_epsg(3857)
|
crs = CRS.from_epsg(3857)
|
||||||
res = 1
|
res = 1
|
||||||
|
|
||||||
valid_prior_layers = ['prior', 'prior_no_osm_no_buildings']
|
valid_prior_layers = ('prior', 'prior_no_osm_no_buildings')
|
||||||
|
|
||||||
valid_layers = [
|
valid_layers = (
|
||||||
'naip',
|
'naip',
|
||||||
'nlcd',
|
'nlcd',
|
||||||
'roads',
|
'roads',
|
||||||
|
@ -65,14 +65,15 @@ class EnviroAtlas(GeoDataset):
|
||||||
'waterbodies',
|
'waterbodies',
|
||||||
'buildings',
|
'buildings',
|
||||||
'lc',
|
'lc',
|
||||||
] + valid_prior_layers
|
*valid_prior_layers,
|
||||||
|
)
|
||||||
|
|
||||||
cities = [
|
cities = (
|
||||||
'pittsburgh_pa-2010_1m',
|
'pittsburgh_pa-2010_1m',
|
||||||
'durham_nc-2012_1m',
|
'durham_nc-2012_1m',
|
||||||
'austin_tx-2012_1m',
|
'austin_tx-2012_1m',
|
||||||
'phoenix_az-2010_1m',
|
'phoenix_az-2010_1m',
|
||||||
]
|
)
|
||||||
splits = (
|
splits = (
|
||||||
[f'{state}-train' for state in cities[:1]]
|
[f'{state}-train' for state in cities[:1]]
|
||||||
+ [f'{state}-val' for state in cities[:1]]
|
+ [f'{state}-val' for state in cities[:1]]
|
||||||
|
@ -81,7 +82,7 @@ class EnviroAtlas(GeoDataset):
|
||||||
)
|
)
|
||||||
|
|
||||||
# these are used to check the integrity of the dataset
|
# these are used to check the integrity of the dataset
|
||||||
_files = [
|
_files = (
|
||||||
'austin_tx-2012_1m-test_tiles-debuffered',
|
'austin_tx-2012_1m-test_tiles-debuffered',
|
||||||
'austin_tx-2012_1m-val5_tiles-debuffered',
|
'austin_tx-2012_1m-val5_tiles-debuffered',
|
||||||
'durham_nc-2012_1m-test_tiles-debuffered',
|
'durham_nc-2012_1m-test_tiles-debuffered',
|
||||||
|
@ -100,13 +101,13 @@ class EnviroAtlas(GeoDataset):
|
||||||
'austin_tx-2012_1m-test_tiles-debuffered/3009726_sw_d_water.tif',
|
'austin_tx-2012_1m-test_tiles-debuffered/3009726_sw_d_water.tif',
|
||||||
'austin_tx-2012_1m-test_tiles-debuffered/3009726_sw_e_buildings.tif',
|
'austin_tx-2012_1m-test_tiles-debuffered/3009726_sw_e_buildings.tif',
|
||||||
'austin_tx-2012_1m-test_tiles-debuffered/3009726_sw_h_highres_labels.tif',
|
'austin_tx-2012_1m-test_tiles-debuffered/3009726_sw_h_highres_labels.tif',
|
||||||
'austin_tx-2012_1m-test_tiles-debuffered/3009726_sw_prior_from_cooccurrences_101_31.tif', # noqa: E501
|
'austin_tx-2012_1m-test_tiles-debuffered/3009726_sw_prior_from_cooccurrences_101_31.tif',
|
||||||
'austin_tx-2012_1m-test_tiles-debuffered/3009726_sw_prior_from_cooccurrences_101_31_no_osm_no_buildings.tif', # noqa: E501
|
'austin_tx-2012_1m-test_tiles-debuffered/3009726_sw_prior_from_cooccurrences_101_31_no_osm_no_buildings.tif',
|
||||||
'spatial_index.geojson',
|
'spatial_index.geojson',
|
||||||
]
|
)
|
||||||
|
|
||||||
p_src_crs = pyproj.CRS('epsg:3857')
|
p_src_crs = pyproj.CRS('epsg:3857')
|
||||||
p_transformers = {
|
p_transformers: ClassVar[dict[str, CRS]] = {
|
||||||
'epsg:26917': pyproj.Transformer.from_crs(
|
'epsg:26917': pyproj.Transformer.from_crs(
|
||||||
p_src_crs, pyproj.CRS('epsg:26917'), always_xy=True
|
p_src_crs, pyproj.CRS('epsg:26917'), always_xy=True
|
||||||
).transform,
|
).transform,
|
||||||
|
@ -222,7 +223,7 @@ class EnviroAtlas(GeoDataset):
|
||||||
dtype=np.uint8,
|
dtype=np.uint8,
|
||||||
)
|
)
|
||||||
|
|
||||||
highres_classes = [
|
highres_classes = (
|
||||||
'Unclassified',
|
'Unclassified',
|
||||||
'Water',
|
'Water',
|
||||||
'Impervious Surface',
|
'Impervious Surface',
|
||||||
|
@ -234,7 +235,7 @@ class EnviroAtlas(GeoDataset):
|
||||||
'Orchards',
|
'Orchards',
|
||||||
'Woody Wetlands',
|
'Woody Wetlands',
|
||||||
'Emergent Wetlands',
|
'Emergent Wetlands',
|
||||||
]
|
)
|
||||||
highres_cmap = ListedColormap(
|
highres_cmap = ListedColormap(
|
||||||
[
|
[
|
||||||
[1.00000000, 1.00000000, 1.00000000],
|
[1.00000000, 1.00000000, 1.00000000],
|
||||||
|
|
|
@ -6,6 +6,7 @@
|
||||||
import glob
|
import glob
|
||||||
import os
|
import os
|
||||||
from collections.abc import Callable
|
from collections.abc import Callable
|
||||||
|
from typing import ClassVar
|
||||||
|
|
||||||
import matplotlib.pyplot as plt
|
import matplotlib.pyplot as plt
|
||||||
import numpy as np
|
import numpy as np
|
||||||
|
@ -56,9 +57,9 @@ class ETCI2021(NonGeoDataset):
|
||||||
the ETCI competition.
|
the ETCI competition.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
bands = ['VV', 'VH']
|
bands = ('VV', 'VH')
|
||||||
masks = ['flood', 'water_body']
|
masks = ('flood', 'water_body')
|
||||||
metadata = {
|
metadata: ClassVar[dict[str, dict[str, str]]] = {
|
||||||
'train': {
|
'train': {
|
||||||
'filename': 'train.zip',
|
'filename': 'train.zip',
|
||||||
'md5': '1e95792fe0f6e3c9000abdeab2a8ab0f',
|
'md5': '1e95792fe0f6e3c9000abdeab2a8ab0f',
|
||||||
|
|
|
@ -7,7 +7,7 @@ import glob
|
||||||
import os
|
import os
|
||||||
import pathlib
|
import pathlib
|
||||||
from collections.abc import Callable, Iterable
|
from collections.abc import Callable, Iterable
|
||||||
from typing import Any
|
from typing import Any, ClassVar
|
||||||
|
|
||||||
import matplotlib.pyplot as plt
|
import matplotlib.pyplot as plt
|
||||||
from matplotlib.figure import Figure
|
from matplotlib.figure import Figure
|
||||||
|
@ -53,7 +53,7 @@ class EUDEM(RasterDataset):
|
||||||
zipfile_glob = 'eu_dem_v11_*[A-Z0-9].zip'
|
zipfile_glob = 'eu_dem_v11_*[A-Z0-9].zip'
|
||||||
filename_regex = '(?P<name>[eudem_v11]{10})_(?P<id>[A-Z0-9]{6})'
|
filename_regex = '(?P<name>[eudem_v11]{10})_(?P<id>[A-Z0-9]{6})'
|
||||||
|
|
||||||
md5s = {
|
md5s: ClassVar[dict[str, str]] = {
|
||||||
'eu_dem_v11_E00N20.zip': '96edc7e11bc299b994e848050d6be591',
|
'eu_dem_v11_E00N20.zip': '96edc7e11bc299b994e848050d6be591',
|
||||||
'eu_dem_v11_E10N00.zip': 'e14be147ac83eddf655f4833d55c1571',
|
'eu_dem_v11_E10N00.zip': 'e14be147ac83eddf655f4833d55c1571',
|
||||||
'eu_dem_v11_E10N10.zip': '2eb5187e4d827245b33768404529c709',
|
'eu_dem_v11_E10N10.zip': '2eb5187e4d827245b33768404529c709',
|
||||||
|
|
|
@ -61,7 +61,7 @@ class EuroCrops(VectorDataset):
|
||||||
date_format = '%Y'
|
date_format = '%Y'
|
||||||
|
|
||||||
# Filename and md5 of files in this dataset on zenodo.
|
# Filename and md5 of files in this dataset on zenodo.
|
||||||
zenodo_files = [
|
zenodo_files: tuple[tuple[str, str], ...] = (
|
||||||
('AT_2021.zip', '490241df2e3d62812e572049fc0c36c5'),
|
('AT_2021.zip', '490241df2e3d62812e572049fc0c36c5'),
|
||||||
('BE_VLG_2021.zip', 'ac4b9e12ad39b1cba47fdff1a786c2d7'),
|
('BE_VLG_2021.zip', 'ac4b9e12ad39b1cba47fdff1a786c2d7'),
|
||||||
('DE_LS_2021.zip', '6d94e663a3ff7988b32cb36ea24a724f'),
|
('DE_LS_2021.zip', '6d94e663a3ff7988b32cb36ea24a724f'),
|
||||||
|
@ -81,7 +81,7 @@ class EuroCrops(VectorDataset):
|
||||||
# Year is unknown for Romania portion (ny = no year).
|
# Year is unknown for Romania portion (ny = no year).
|
||||||
# We skip since it is inconsistent with the rest of the data.
|
# We skip since it is inconsistent with the rest of the data.
|
||||||
# ("RO_ny.zip", "648e1504097765b4b7f825decc838882"),
|
# ("RO_ny.zip", "648e1504097765b4b7f825decc838882"),
|
||||||
]
|
)
|
||||||
|
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
|
|
|
@ -5,7 +5,7 @@
|
||||||
|
|
||||||
import os
|
import os
|
||||||
from collections.abc import Callable, Sequence
|
from collections.abc import Callable, Sequence
|
||||||
from typing import cast
|
from typing import ClassVar, cast
|
||||||
|
|
||||||
import matplotlib.pyplot as plt
|
import matplotlib.pyplot as plt
|
||||||
import numpy as np
|
import numpy as np
|
||||||
|
@ -54,7 +54,7 @@ class EuroSAT(NonGeoClassificationDataset):
|
||||||
* https://ieeexplore.ieee.org/document/8519248
|
* https://ieeexplore.ieee.org/document/8519248
|
||||||
"""
|
"""
|
||||||
|
|
||||||
url = 'https://hf.co/datasets/torchgeo/eurosat/resolve/06fd1b090bceecc0ce724cd21578ba7a6664fe8d/EuroSATallBands.zip' # noqa: E501
|
url = 'https://hf.co/datasets/torchgeo/eurosat/resolve/06fd1b090bceecc0ce724cd21578ba7a6664fe8d/EuroSATallBands.zip'
|
||||||
filename = 'EuroSATallBands.zip'
|
filename = 'EuroSATallBands.zip'
|
||||||
md5 = '5ac12b3b2557aa56e1826e981e8e200e'
|
md5 = '5ac12b3b2557aa56e1826e981e8e200e'
|
||||||
|
|
||||||
|
@ -63,13 +63,13 @@ class EuroSAT(NonGeoClassificationDataset):
|
||||||
'ds', 'images', 'remote_sensing', 'otherDatasets', 'sentinel_2', 'tif'
|
'ds', 'images', 'remote_sensing', 'otherDatasets', 'sentinel_2', 'tif'
|
||||||
)
|
)
|
||||||
|
|
||||||
splits = ['train', 'val', 'test']
|
splits = ('train', 'val', 'test')
|
||||||
split_urls = {
|
split_urls: ClassVar[dict[str, str]] = {
|
||||||
'train': 'https://storage.googleapis.com/remote_sensing_representations/eurosat-train.txt', # noqa: E501
|
'train': 'https://storage.googleapis.com/remote_sensing_representations/eurosat-train.txt',
|
||||||
'val': 'https://storage.googleapis.com/remote_sensing_representations/eurosat-val.txt', # noqa: E501
|
'val': 'https://storage.googleapis.com/remote_sensing_representations/eurosat-val.txt',
|
||||||
'test': 'https://storage.googleapis.com/remote_sensing_representations/eurosat-test.txt', # noqa: E501
|
'test': 'https://storage.googleapis.com/remote_sensing_representations/eurosat-test.txt',
|
||||||
}
|
}
|
||||||
split_md5s = {
|
split_md5s: ClassVar[dict[str, str]] = {
|
||||||
'train': '908f142e73d6acdf3f482c5e80d851b1',
|
'train': '908f142e73d6acdf3f482c5e80d851b1',
|
||||||
'val': '95de90f2aa998f70a3b2416bfe0687b4',
|
'val': '95de90f2aa998f70a3b2416bfe0687b4',
|
||||||
'test': '7ae5ab94471417b6e315763121e67c5f',
|
'test': '7ae5ab94471417b6e315763121e67c5f',
|
||||||
|
@ -93,7 +93,10 @@ class EuroSAT(NonGeoClassificationDataset):
|
||||||
|
|
||||||
rgb_bands = ('B04', 'B03', 'B02')
|
rgb_bands = ('B04', 'B03', 'B02')
|
||||||
|
|
||||||
BAND_SETS = {'all': all_band_names, 'rgb': rgb_bands}
|
BAND_SETS: ClassVar[dict[str, tuple[str, ...]]] = {
|
||||||
|
'all': all_band_names,
|
||||||
|
'rgb': rgb_bands,
|
||||||
|
}
|
||||||
|
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
|
@ -302,12 +305,12 @@ class EuroSATSpatial(EuroSAT):
|
||||||
.. versionadded:: 0.6
|
.. versionadded:: 0.6
|
||||||
"""
|
"""
|
||||||
|
|
||||||
split_urls = {
|
split_urls: ClassVar[dict[str, str]] = {
|
||||||
'train': 'https://hf.co/datasets/torchgeo/eurosat/resolve/1c11c73a87b40b0485d103231a97829991b8e22f/eurosat-spatial-train.txt',
|
'train': 'https://hf.co/datasets/torchgeo/eurosat/resolve/1c11c73a87b40b0485d103231a97829991b8e22f/eurosat-spatial-train.txt',
|
||||||
'val': 'https://hf.co/datasets/torchgeo/eurosat/resolve/1c11c73a87b40b0485d103231a97829991b8e22f/eurosat-spatial-val.txt',
|
'val': 'https://hf.co/datasets/torchgeo/eurosat/resolve/1c11c73a87b40b0485d103231a97829991b8e22f/eurosat-spatial-val.txt',
|
||||||
'test': 'https://hf.co/datasets/torchgeo/eurosat/resolve/1c11c73a87b40b0485d103231a97829991b8e22f/eurosat-spatial-test.txt',
|
'test': 'https://hf.co/datasets/torchgeo/eurosat/resolve/1c11c73a87b40b0485d103231a97829991b8e22f/eurosat-spatial-test.txt',
|
||||||
}
|
}
|
||||||
split_md5s = {
|
split_md5s: ClassVar[dict[str, str]] = {
|
||||||
'train': '7be3254be39f23ce4d4d144290c93292',
|
'train': '7be3254be39f23ce4d4d144290c93292',
|
||||||
'val': 'acf392290050bb3df790dc8fc0ebf193',
|
'val': 'acf392290050bb3df790dc8fc0ebf193',
|
||||||
'test': '5ec1733f9c16116bf0aa2d921fc613ef',
|
'test': '5ec1733f9c16116bf0aa2d921fc613ef',
|
||||||
|
@ -325,16 +328,16 @@ class EuroSAT100(EuroSAT):
|
||||||
.. versionadded:: 0.5
|
.. versionadded:: 0.5
|
||||||
"""
|
"""
|
||||||
|
|
||||||
url = 'https://hf.co/datasets/torchgeo/eurosat/resolve/06fd1b090bceecc0ce724cd21578ba7a6664fe8d/EuroSAT100.zip' # noqa: E501
|
url = 'https://hf.co/datasets/torchgeo/eurosat/resolve/06fd1b090bceecc0ce724cd21578ba7a6664fe8d/EuroSAT100.zip'
|
||||||
filename = 'EuroSAT100.zip'
|
filename = 'EuroSAT100.zip'
|
||||||
md5 = 'c21c649ba747e86eda813407ef17d596'
|
md5 = 'c21c649ba747e86eda813407ef17d596'
|
||||||
|
|
||||||
split_urls = {
|
split_urls: ClassVar[dict[str, str]] = {
|
||||||
'train': 'https://hf.co/datasets/torchgeo/eurosat/resolve/06fd1b090bceecc0ce724cd21578ba7a6664fe8d/eurosat-train.txt', # noqa: E501
|
'train': 'https://hf.co/datasets/torchgeo/eurosat/resolve/06fd1b090bceecc0ce724cd21578ba7a6664fe8d/eurosat-train.txt',
|
||||||
'val': 'https://hf.co/datasets/torchgeo/eurosat/resolve/06fd1b090bceecc0ce724cd21578ba7a6664fe8d/eurosat-val.txt', # noqa: E501
|
'val': 'https://hf.co/datasets/torchgeo/eurosat/resolve/06fd1b090bceecc0ce724cd21578ba7a6664fe8d/eurosat-val.txt',
|
||||||
'test': 'https://hf.co/datasets/torchgeo/eurosat/resolve/06fd1b090bceecc0ce724cd21578ba7a6664fe8d/eurosat-test.txt', # noqa: E501
|
'test': 'https://hf.co/datasets/torchgeo/eurosat/resolve/06fd1b090bceecc0ce724cd21578ba7a6664fe8d/eurosat-test.txt',
|
||||||
}
|
}
|
||||||
split_md5s = {
|
split_md5s: ClassVar[dict[str, str]] = {
|
||||||
'train': '033d0c23e3a75e3fa79618b0e35fe1c7',
|
'train': '033d0c23e3a75e3fa79618b0e35fe1c7',
|
||||||
'val': '3e3f8b3c344182b8d126c4cc88f3f215',
|
'val': '3e3f8b3c344182b8d126c4cc88f3f215',
|
||||||
'test': 'f908f151b950f270ad18e61153579794',
|
'test': 'f908f151b950f270ad18e61153579794',
|
||||||
|
|
|
@ -6,7 +6,7 @@
|
||||||
import glob
|
import glob
|
||||||
import os
|
import os
|
||||||
from collections.abc import Callable
|
from collections.abc import Callable
|
||||||
from typing import Any, cast
|
from typing import Any, ClassVar, cast
|
||||||
from xml.etree.ElementTree import Element, parse
|
from xml.etree.ElementTree import Element, parse
|
||||||
|
|
||||||
import matplotlib.patches as patches
|
import matplotlib.patches as patches
|
||||||
|
@ -119,7 +119,7 @@ class FAIR1M(NonGeoDataset):
|
||||||
.. versionadded:: 0.2
|
.. versionadded:: 0.2
|
||||||
"""
|
"""
|
||||||
|
|
||||||
classes = {
|
classes: ClassVar[dict[str, dict[str, Any]]] = {
|
||||||
'Passenger Ship': {'id': 0, 'category': 'Ship'},
|
'Passenger Ship': {'id': 0, 'category': 'Ship'},
|
||||||
'Motorboat': {'id': 1, 'category': 'Ship'},
|
'Motorboat': {'id': 1, 'category': 'Ship'},
|
||||||
'Fishing Boat': {'id': 2, 'category': 'Ship'},
|
'Fishing Boat': {'id': 2, 'category': 'Ship'},
|
||||||
|
@ -159,12 +159,12 @@ class FAIR1M(NonGeoDataset):
|
||||||
'Bridge': {'id': 36, 'category': 'Road'},
|
'Bridge': {'id': 36, 'category': 'Road'},
|
||||||
}
|
}
|
||||||
|
|
||||||
filename_glob = {
|
filename_glob: ClassVar[dict[str, str]] = {
|
||||||
'train': os.path.join('train', '**', 'images', '*.tif'),
|
'train': os.path.join('train', '**', 'images', '*.tif'),
|
||||||
'val': os.path.join('validation', 'images', '*.tif'),
|
'val': os.path.join('validation', 'images', '*.tif'),
|
||||||
'test': os.path.join('test', 'images', '*.tif'),
|
'test': os.path.join('test', 'images', '*.tif'),
|
||||||
}
|
}
|
||||||
directories = {
|
directories: ClassVar[dict[str, tuple[str, ...]]] = {
|
||||||
'train': (
|
'train': (
|
||||||
os.path.join('train', 'part1', 'images'),
|
os.path.join('train', 'part1', 'images'),
|
||||||
os.path.join('train', 'part1', 'labelXml'),
|
os.path.join('train', 'part1', 'labelXml'),
|
||||||
|
@ -175,9 +175,9 @@ class FAIR1M(NonGeoDataset):
|
||||||
os.path.join('validation', 'images'),
|
os.path.join('validation', 'images'),
|
||||||
os.path.join('validation', 'labelXml'),
|
os.path.join('validation', 'labelXml'),
|
||||||
),
|
),
|
||||||
'test': (os.path.join('test', 'images')),
|
'test': (os.path.join('test', 'images'),),
|
||||||
}
|
}
|
||||||
paths = {
|
paths: ClassVar[dict[str, tuple[str, ...]]] = {
|
||||||
'train': (
|
'train': (
|
||||||
os.path.join('train', 'part1', 'images.zip'),
|
os.path.join('train', 'part1', 'images.zip'),
|
||||||
os.path.join('train', 'part1', 'labelXml.zip'),
|
os.path.join('train', 'part1', 'labelXml.zip'),
|
||||||
|
@ -194,7 +194,7 @@ class FAIR1M(NonGeoDataset):
|
||||||
os.path.join('test', 'images2.zip'),
|
os.path.join('test', 'images2.zip'),
|
||||||
),
|
),
|
||||||
}
|
}
|
||||||
urls = {
|
urls: ClassVar[dict[str, tuple[str, ...]]] = {
|
||||||
'train': (
|
'train': (
|
||||||
'https://drive.google.com/file/d/1LWT_ybL-s88Lzg9A9wHpj0h2rJHrqrVf',
|
'https://drive.google.com/file/d/1LWT_ybL-s88Lzg9A9wHpj0h2rJHrqrVf',
|
||||||
'https://drive.google.com/file/d/1CnOuS8oX6T9JMqQnfFsbmf7U38G6Vc8u',
|
'https://drive.google.com/file/d/1CnOuS8oX6T9JMqQnfFsbmf7U38G6Vc8u',
|
||||||
|
@ -211,7 +211,7 @@ class FAIR1M(NonGeoDataset):
|
||||||
'https://drive.google.com/file/d/1oUc25FVf8Zcp4pzJ31A1j1sOLNHu63P0',
|
'https://drive.google.com/file/d/1oUc25FVf8Zcp4pzJ31A1j1sOLNHu63P0',
|
||||||
),
|
),
|
||||||
}
|
}
|
||||||
md5s = {
|
md5s: ClassVar[dict[str, tuple[str, ...]]] = {
|
||||||
'train': (
|
'train': (
|
||||||
'a460fe6b1b5b276bf856ce9ac72d6568',
|
'a460fe6b1b5b276bf856ce9ac72d6568',
|
||||||
'80f833ff355f91445c92a0c0c1fa7414',
|
'80f833ff355f91445c92a0c0c1fa7414',
|
||||||
|
|
|
@ -55,8 +55,8 @@ class FireRisk(NonGeoClassificationDataset):
|
||||||
md5 = 'a77b9a100d51167992ae8c51d26198a6'
|
md5 = 'a77b9a100d51167992ae8c51d26198a6'
|
||||||
filename = 'FireRisk.zip'
|
filename = 'FireRisk.zip'
|
||||||
directory = 'FireRisk'
|
directory = 'FireRisk'
|
||||||
splits = ['train', 'val']
|
splits = ('train', 'val')
|
||||||
classes = [
|
classes = (
|
||||||
'High',
|
'High',
|
||||||
'Low',
|
'Low',
|
||||||
'Moderate',
|
'Moderate',
|
||||||
|
@ -64,7 +64,7 @@ class FireRisk(NonGeoClassificationDataset):
|
||||||
'Very_High',
|
'Very_High',
|
||||||
'Very_Low',
|
'Very_Low',
|
||||||
'Water',
|
'Water',
|
||||||
]
|
)
|
||||||
|
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
|
|
|
@ -96,11 +96,8 @@ class ForestDamage(NonGeoDataset):
|
||||||
.. versionadded:: 0.3
|
.. versionadded:: 0.3
|
||||||
"""
|
"""
|
||||||
|
|
||||||
classes = ['other', 'H', 'LD', 'HD']
|
classes = ('other', 'H', 'LD', 'HD')
|
||||||
url = (
|
url = 'https://lilablobssc.blob.core.windows.net/larch-casebearer/Data_Set_Larch_Casebearer.zip'
|
||||||
'https://lilablobssc.blob.core.windows.net/larch-casebearer/'
|
|
||||||
'Data_Set_Larch_Casebearer.zip'
|
|
||||||
)
|
|
||||||
data_dir = 'Data_Set_Larch_Casebearer'
|
data_dir = 'Data_Set_Larch_Casebearer'
|
||||||
md5 = '907815bcc739bff89496fac8f8ce63d7'
|
md5 = '907815bcc739bff89496fac8f8ce63d7'
|
||||||
|
|
||||||
|
|
|
@ -13,7 +13,7 @@ import re
|
||||||
import sys
|
import sys
|
||||||
import warnings
|
import warnings
|
||||||
from collections.abc import Callable, Iterable, Sequence
|
from collections.abc import Callable, Iterable, Sequence
|
||||||
from typing import Any, cast
|
from typing import Any, ClassVar, cast
|
||||||
|
|
||||||
import fiona
|
import fiona
|
||||||
import fiona.transform
|
import fiona.transform
|
||||||
|
@ -370,13 +370,13 @@ class RasterDataset(GeoDataset):
|
||||||
separate_files = False
|
separate_files = False
|
||||||
|
|
||||||
#: Names of all available bands in the dataset
|
#: Names of all available bands in the dataset
|
||||||
all_bands: list[str] = []
|
all_bands: tuple[str, ...] = ()
|
||||||
|
|
||||||
#: Names of RGB bands in the dataset, used for plotting
|
#: Names of RGB bands in the dataset, used for plotting
|
||||||
rgb_bands: list[str] = []
|
rgb_bands: tuple[str, ...] = ()
|
||||||
|
|
||||||
#: Color map for the dataset, used for plotting
|
#: Color map for the dataset, used for plotting
|
||||||
cmap: dict[int, tuple[int, int, int, int]] = {}
|
cmap: ClassVar[dict[int, tuple[int, int, int, int]]] = {}
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def dtype(self) -> torch.dtype:
|
def dtype(self) -> torch.dtype:
|
||||||
|
@ -458,7 +458,7 @@ class RasterDataset(GeoDataset):
|
||||||
# See if file has a color map
|
# See if file has a color map
|
||||||
if len(self.cmap) == 0:
|
if len(self.cmap) == 0:
|
||||||
try:
|
try:
|
||||||
self.cmap = src.colormap(1)
|
self.cmap = src.colormap(1) # type: ignore[misc]
|
||||||
except ValueError:
|
except ValueError:
|
||||||
pass
|
pass
|
||||||
|
|
||||||
|
|
|
@ -66,8 +66,8 @@ class GID15(NonGeoDataset):
|
||||||
md5 = '615682bf659c3ed981826c6122c10c83'
|
md5 = '615682bf659c3ed981826c6122c10c83'
|
||||||
filename = 'gid-15.zip'
|
filename = 'gid-15.zip'
|
||||||
directory = 'GID'
|
directory = 'GID'
|
||||||
splits = ['train', 'val', 'test']
|
splits = ('train', 'val', 'test')
|
||||||
classes = [
|
classes = (
|
||||||
'background',
|
'background',
|
||||||
'industrial_land',
|
'industrial_land',
|
||||||
'urban_residential',
|
'urban_residential',
|
||||||
|
@ -84,7 +84,7 @@ class GID15(NonGeoDataset):
|
||||||
'river',
|
'river',
|
||||||
'lake',
|
'lake',
|
||||||
'pond',
|
'pond',
|
||||||
]
|
)
|
||||||
|
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
|
|
|
@ -7,7 +7,7 @@ import glob
|
||||||
import os
|
import os
|
||||||
import pathlib
|
import pathlib
|
||||||
from collections.abc import Callable, Iterable
|
from collections.abc import Callable, Iterable
|
||||||
from typing import Any, cast
|
from typing import Any, ClassVar, cast
|
||||||
|
|
||||||
import matplotlib.pyplot as plt
|
import matplotlib.pyplot as plt
|
||||||
import torch
|
import torch
|
||||||
|
@ -73,9 +73,9 @@ class GlobBiomass(RasterDataset):
|
||||||
is_image = False
|
is_image = False
|
||||||
dtype = torch.float32 # pixelwise regression
|
dtype = torch.float32 # pixelwise regression
|
||||||
|
|
||||||
measurements = ['agb', 'gsv']
|
measurements = ('agb', 'gsv')
|
||||||
|
|
||||||
md5s = {
|
md5s: ClassVar[dict[str, str]] = {
|
||||||
'N00E020_agb.zip': 'bd83a3a4c143885d1962bde549413be6',
|
'N00E020_agb.zip': 'bd83a3a4c143885d1962bde549413be6',
|
||||||
'N00E020_gsv.zip': 'da5ddb88e369df2d781a0c6be008ae79',
|
'N00E020_gsv.zip': 'da5ddb88e369df2d781a0c6be008ae79',
|
||||||
'N00E060_agb.zip': '85eaca95b939086cc528e396b75bd097',
|
'N00E060_agb.zip': '85eaca95b939086cc528e396b75bd097',
|
||||||
|
|
|
@ -6,7 +6,7 @@
|
||||||
import glob
|
import glob
|
||||||
import os
|
import os
|
||||||
from collections.abc import Callable
|
from collections.abc import Callable
|
||||||
from typing import Any, cast, overload
|
from typing import Any, ClassVar, cast, overload
|
||||||
|
|
||||||
import fiona
|
import fiona
|
||||||
import matplotlib.pyplot as plt
|
import matplotlib.pyplot as plt
|
||||||
|
@ -100,7 +100,7 @@ class IDTReeS(NonGeoDataset):
|
||||||
.. versionadded:: 0.2
|
.. versionadded:: 0.2
|
||||||
"""
|
"""
|
||||||
|
|
||||||
classes = {
|
classes: ClassVar[dict[str, str]] = {
|
||||||
'ACPE': 'Acer pensylvanicum L.',
|
'ACPE': 'Acer pensylvanicum L.',
|
||||||
'ACRU': 'Acer rubrum L.',
|
'ACRU': 'Acer rubrum L.',
|
||||||
'ACSA3': 'Acer saccharum Marshall',
|
'ACSA3': 'Acer saccharum Marshall',
|
||||||
|
@ -135,19 +135,22 @@ class IDTReeS(NonGeoDataset):
|
||||||
'ROPS': 'Robinia pseudoacacia L.',
|
'ROPS': 'Robinia pseudoacacia L.',
|
||||||
'TSCA': 'Tsuga canadensis (L.) Carriere',
|
'TSCA': 'Tsuga canadensis (L.) Carriere',
|
||||||
}
|
}
|
||||||
metadata = {
|
metadata: ClassVar[dict[str, dict[str, str]]] = {
|
||||||
'train': {
|
'train': {
|
||||||
'url': 'https://zenodo.org/record/3934932/files/IDTREES_competition_train_v2.zip?download=1', # noqa: E501
|
'url': 'https://zenodo.org/record/3934932/files/IDTREES_competition_train_v2.zip?download=1',
|
||||||
'md5': '5ddfa76240b4bb6b4a7861d1d31c299c',
|
'md5': '5ddfa76240b4bb6b4a7861d1d31c299c',
|
||||||
'filename': 'IDTREES_competition_train_v2.zip',
|
'filename': 'IDTREES_competition_train_v2.zip',
|
||||||
},
|
},
|
||||||
'test': {
|
'test': {
|
||||||
'url': 'https://zenodo.org/record/3934932/files/IDTREES_competition_test_v2.zip?download=1', # noqa: E501
|
'url': 'https://zenodo.org/record/3934932/files/IDTREES_competition_test_v2.zip?download=1',
|
||||||
'md5': 'b108931c84a70f2a38a8234290131c9b',
|
'md5': 'b108931c84a70f2a38a8234290131c9b',
|
||||||
'filename': 'IDTREES_competition_test_v2.zip',
|
'filename': 'IDTREES_competition_test_v2.zip',
|
||||||
},
|
},
|
||||||
}
|
}
|
||||||
directories = {'train': ['train'], 'test': ['task1', 'task2']}
|
directories: ClassVar[dict[str, list[str]]] = {
|
||||||
|
'train': ['train'],
|
||||||
|
'test': ['task1', 'task2'],
|
||||||
|
}
|
||||||
image_size = (200, 200)
|
image_size = (200, 200)
|
||||||
|
|
||||||
def __init__(
|
def __init__(
|
||||||
|
|
|
@ -6,7 +6,7 @@
|
||||||
import glob
|
import glob
|
||||||
import os
|
import os
|
||||||
from collections.abc import Callable, Sequence
|
from collections.abc import Callable, Sequence
|
||||||
from typing import Any
|
from typing import Any, ClassVar
|
||||||
|
|
||||||
import matplotlib.pyplot as plt
|
import matplotlib.pyplot as plt
|
||||||
from matplotlib.figure import Figure
|
from matplotlib.figure import Figure
|
||||||
|
@ -40,9 +40,9 @@ class IOBench(IntersectionDataset):
|
||||||
.. versionadded:: 0.6
|
.. versionadded:: 0.6
|
||||||
"""
|
"""
|
||||||
|
|
||||||
url = 'https://hf.co/datasets/torchgeo/io/resolve/c9d9d268cf0b61335941bdc2b6963bf16fc3a6cf/{}.tar.gz' # noqa: E501
|
url = 'https://hf.co/datasets/torchgeo/io/resolve/c9d9d268cf0b61335941bdc2b6963bf16fc3a6cf/{}.tar.gz'
|
||||||
|
|
||||||
md5s = {
|
md5s: ClassVar[dict[str, str]] = {
|
||||||
'original': 'e3a908a0fd1c05c1af2f4c65724d59b3',
|
'original': 'e3a908a0fd1c05c1af2f4c65724d59b3',
|
||||||
'raw': 'e9603990441007ce7bba73bb8ba7d217',
|
'raw': 'e9603990441007ce7bba73bb8ba7d217',
|
||||||
'preprocessed': '9801f1240b238cb17525c865e413d1fd',
|
'preprocessed': '9801f1240b238cb17525c865e413d1fd',
|
||||||
|
@ -54,7 +54,7 @@ class IOBench(IntersectionDataset):
|
||||||
split: str = 'preprocessed',
|
split: str = 'preprocessed',
|
||||||
crs: CRS | None = None,
|
crs: CRS | None = None,
|
||||||
res: float | None = None,
|
res: float | None = None,
|
||||||
bands: Sequence[str] | None = Landsat9.default_bands + ['SR_QA_AEROSOL'],
|
bands: Sequence[str] | None = [*Landsat9.default_bands, 'SR_QA_AEROSOL'],
|
||||||
classes: list[int] = [0],
|
classes: list[int] = [0],
|
||||||
transforms: Callable[[dict[str, Any]], dict[str, Any]] | None = None,
|
transforms: Callable[[dict[str, Any]], dict[str, Any]] | None = None,
|
||||||
cache: bool = True,
|
cache: bool = True,
|
||||||
|
|
|
@ -8,7 +8,7 @@ import os
|
||||||
import pathlib
|
import pathlib
|
||||||
import re
|
import re
|
||||||
from collections.abc import Callable, Iterable, Sequence
|
from collections.abc import Callable, Iterable, Sequence
|
||||||
from typing import Any, cast
|
from typing import Any, ClassVar, cast
|
||||||
|
|
||||||
import matplotlib.pyplot as plt
|
import matplotlib.pyplot as plt
|
||||||
import torch
|
import torch
|
||||||
|
@ -43,8 +43,8 @@ class L7IrishImage(RasterDataset):
|
||||||
"""
|
"""
|
||||||
date_format = '%Y%m%d'
|
date_format = '%Y%m%d'
|
||||||
is_image = True
|
is_image = True
|
||||||
rgb_bands = ['B30', 'B20', 'B10']
|
rgb_bands = ('B30', 'B20', 'B10')
|
||||||
all_bands = ['B10', 'B20', 'B30', 'B40', 'B50', 'B61', 'B62', 'B70', 'B80']
|
all_bands = ('B10', 'B20', 'B30', 'B40', 'B50', 'B61', 'B62', 'B70', 'B80')
|
||||||
|
|
||||||
|
|
||||||
class L7IrishMask(RasterDataset):
|
class L7IrishMask(RasterDataset):
|
||||||
|
@ -59,7 +59,7 @@ class L7IrishMask(RasterDataset):
|
||||||
_newmask2015\.TIF$
|
_newmask2015\.TIF$
|
||||||
"""
|
"""
|
||||||
is_image = False
|
is_image = False
|
||||||
classes = ['Fill', 'Cloud Shadow', 'Clear', 'Thin Cloud', 'Cloud']
|
classes = ('Fill', 'Cloud Shadow', 'Clear', 'Thin Cloud', 'Cloud')
|
||||||
ordinal_map = torch.zeros(256, dtype=torch.long)
|
ordinal_map = torch.zeros(256, dtype=torch.long)
|
||||||
ordinal_map[64] = 1
|
ordinal_map[64] = 1
|
||||||
ordinal_map[128] = 2
|
ordinal_map[128] = 2
|
||||||
|
@ -158,11 +158,11 @@ class L7Irish(IntersectionDataset):
|
||||||
* https://www.sciencebase.gov/catalog/item/573ccf18e4b0dae0d5e4b109
|
* https://www.sciencebase.gov/catalog/item/573ccf18e4b0dae0d5e4b109
|
||||||
|
|
||||||
.. versionadded:: 0.5
|
.. versionadded:: 0.5
|
||||||
""" # noqa: E501
|
"""
|
||||||
|
|
||||||
url = 'https://hf.co/datasets/torchgeo/l7irish/resolve/6807e0b22eca7f9a8a3903ea673b31a115837464/{}.tar.gz' # noqa: E501
|
url = 'https://hf.co/datasets/torchgeo/l7irish/resolve/6807e0b22eca7f9a8a3903ea673b31a115837464/{}.tar.gz'
|
||||||
|
|
||||||
md5s = {
|
md5s: ClassVar[dict[str, str]] = {
|
||||||
'austral': '0a34770b992a62abeb88819feb192436',
|
'austral': '0a34770b992a62abeb88819feb192436',
|
||||||
'boreal': 'b7cfdd689a3c2fd2a8d572e1c10ed082',
|
'boreal': 'b7cfdd689a3c2fd2a8d572e1c10ed082',
|
||||||
'mid_latitude_north': 'c40abe5ad2487f8ab021cfb954982faa',
|
'mid_latitude_north': 'c40abe5ad2487f8ab021cfb954982faa',
|
||||||
|
|
|
@ -7,7 +7,7 @@ import glob
|
||||||
import os
|
import os
|
||||||
import pathlib
|
import pathlib
|
||||||
from collections.abc import Callable, Iterable, Sequence
|
from collections.abc import Callable, Iterable, Sequence
|
||||||
from typing import Any
|
from typing import Any, ClassVar
|
||||||
|
|
||||||
import matplotlib.pyplot as plt
|
import matplotlib.pyplot as plt
|
||||||
import torch
|
import torch
|
||||||
|
@ -36,8 +36,8 @@ class L8BiomeImage(RasterDataset):
|
||||||
"""
|
"""
|
||||||
date_format = '%Y%j'
|
date_format = '%Y%j'
|
||||||
is_image = True
|
is_image = True
|
||||||
rgb_bands = ['B4', 'B3', 'B2']
|
rgb_bands = ('B4', 'B3', 'B2')
|
||||||
all_bands = ['B1', 'B2', 'B3', 'B4', 'B5', 'B6', 'B7', 'B8', 'B9', 'B10', 'B11']
|
all_bands = ('B1', 'B2', 'B3', 'B4', 'B5', 'B6', 'B7', 'B8', 'B9', 'B10', 'B11')
|
||||||
|
|
||||||
|
|
||||||
class L8BiomeMask(RasterDataset):
|
class L8BiomeMask(RasterDataset):
|
||||||
|
@ -57,7 +57,7 @@ class L8BiomeMask(RasterDataset):
|
||||||
"""
|
"""
|
||||||
date_format = '%Y%j'
|
date_format = '%Y%j'
|
||||||
is_image = False
|
is_image = False
|
||||||
classes = ['Fill', 'Cloud Shadow', 'Clear', 'Thin Cloud', 'Cloud']
|
classes = ('Fill', 'Cloud Shadow', 'Clear', 'Thin Cloud', 'Cloud')
|
||||||
ordinal_map = torch.zeros(256, dtype=torch.long)
|
ordinal_map = torch.zeros(256, dtype=torch.long)
|
||||||
ordinal_map[64] = 1
|
ordinal_map[64] = 1
|
||||||
ordinal_map[128] = 2
|
ordinal_map[128] = 2
|
||||||
|
@ -116,11 +116,11 @@ class L8Biome(IntersectionDataset):
|
||||||
* https://doi.org/10.1016/j.rse.2017.03.026
|
* https://doi.org/10.1016/j.rse.2017.03.026
|
||||||
|
|
||||||
.. versionadded:: 0.5
|
.. versionadded:: 0.5
|
||||||
""" # noqa: E501
|
"""
|
||||||
|
|
||||||
url = 'https://hf.co/datasets/torchgeo/l8biome/resolve/f76df19accce34d2acc1878d88b9491bc81f94c8/{}.tar.gz' # noqa: E501
|
url = 'https://hf.co/datasets/torchgeo/l8biome/resolve/f76df19accce34d2acc1878d88b9491bc81f94c8/{}.tar.gz'
|
||||||
|
|
||||||
md5s = {
|
md5s: ClassVar[dict[str, str]] = {
|
||||||
'barren': '0eb691822d03dabd4f5ea8aadd0b41c3',
|
'barren': '0eb691822d03dabd4f5ea8aadd0b41c3',
|
||||||
'forest': '4a5645596f6bb8cea44677f746ec676e',
|
'forest': '4a5645596f6bb8cea44677f746ec676e',
|
||||||
'grass_crops': 'a69ed5d6cb227c5783f026b9303cdd3c',
|
'grass_crops': 'a69ed5d6cb227c5783f026b9303cdd3c',
|
||||||
|
|
|
@ -9,7 +9,7 @@ import hashlib
|
||||||
import os
|
import os
|
||||||
from collections.abc import Callable
|
from collections.abc import Callable
|
||||||
from functools import lru_cache
|
from functools import lru_cache
|
||||||
from typing import Any, cast
|
from typing import Any, ClassVar, cast
|
||||||
|
|
||||||
import matplotlib.pyplot as plt
|
import matplotlib.pyplot as plt
|
||||||
import numpy as np
|
import numpy as np
|
||||||
|
@ -64,8 +64,8 @@ class LandCoverAIBase(Dataset[dict[str, Any]], abc.ABC):
|
||||||
url = 'https://landcover.ai.linuxpolska.com/download/landcover.ai.v1.zip'
|
url = 'https://landcover.ai.linuxpolska.com/download/landcover.ai.v1.zip'
|
||||||
filename = 'landcover.ai.v1.zip'
|
filename = 'landcover.ai.v1.zip'
|
||||||
md5 = '3268c89070e8734b4e91d531c0617e03'
|
md5 = '3268c89070e8734b4e91d531c0617e03'
|
||||||
classes = ['Background', 'Building', 'Woodland', 'Water', 'Road']
|
classes = ('Background', 'Building', 'Woodland', 'Water', 'Road')
|
||||||
cmap = {
|
cmap: ClassVar[dict[int, tuple[int, int, int, int]]] = {
|
||||||
0: (0, 0, 0, 0),
|
0: (0, 0, 0, 0),
|
||||||
1: (97, 74, 74, 255),
|
1: (97, 74, 74, 255),
|
||||||
2: (38, 115, 0, 255),
|
2: (38, 115, 0, 255),
|
||||||
|
|
|
@ -33,7 +33,7 @@ class Landsat(RasterDataset, abc.ABC):
|
||||||
* `Surface Temperature <https://www.usgs.gov/landsat-missions/landsat-collection-2-surface-temperature>`_
|
* `Surface Temperature <https://www.usgs.gov/landsat-missions/landsat-collection-2-surface-temperature>`_
|
||||||
* `Surface Reflectance <https://www.usgs.gov/landsat-missions/landsat-collection-2-surface-reflectance>`_
|
* `Surface Reflectance <https://www.usgs.gov/landsat-missions/landsat-collection-2-surface-reflectance>`_
|
||||||
* `U.S. Analysis Ready Data <https://www.usgs.gov/landsat-missions/landsat-collection-2-us-analysis-ready-data>`_
|
* `U.S. Analysis Ready Data <https://www.usgs.gov/landsat-missions/landsat-collection-2-us-analysis-ready-data>`_
|
||||||
""" # noqa: E501
|
"""
|
||||||
|
|
||||||
# https://www.usgs.gov/landsat-missions/landsat-collection-2
|
# https://www.usgs.gov/landsat-missions/landsat-collection-2
|
||||||
filename_regex = r"""
|
filename_regex = r"""
|
||||||
|
@ -55,7 +55,7 @@ class Landsat(RasterDataset, abc.ABC):
|
||||||
|
|
||||||
@property
|
@property
|
||||||
@abc.abstractmethod
|
@abc.abstractmethod
|
||||||
def default_bands(self) -> list[str]:
|
def default_bands(self) -> tuple[str, ...]:
|
||||||
"""Bands to load by default."""
|
"""Bands to load by default."""
|
||||||
|
|
||||||
def __init__(
|
def __init__(
|
||||||
|
@ -145,8 +145,8 @@ class Landsat1(Landsat):
|
||||||
|
|
||||||
filename_glob = 'LM01_*_{}.*'
|
filename_glob = 'LM01_*_{}.*'
|
||||||
|
|
||||||
default_bands = ['B4', 'B5', 'B6', 'B7']
|
default_bands = ('B4', 'B5', 'B6', 'B7')
|
||||||
rgb_bands = ['B6', 'B5', 'B4']
|
rgb_bands = ('B6', 'B5', 'B4')
|
||||||
|
|
||||||
|
|
||||||
class Landsat2(Landsat1):
|
class Landsat2(Landsat1):
|
||||||
|
@ -166,8 +166,8 @@ class Landsat4MSS(Landsat):
|
||||||
|
|
||||||
filename_glob = 'LM04_*_{}.*'
|
filename_glob = 'LM04_*_{}.*'
|
||||||
|
|
||||||
default_bands = ['B1', 'B2', 'B3', 'B4']
|
default_bands = ('B1', 'B2', 'B3', 'B4')
|
||||||
rgb_bands = ['B3', 'B2', 'B1']
|
rgb_bands = ('B3', 'B2', 'B1')
|
||||||
|
|
||||||
|
|
||||||
class Landsat4TM(Landsat):
|
class Landsat4TM(Landsat):
|
||||||
|
@ -175,8 +175,8 @@ class Landsat4TM(Landsat):
|
||||||
|
|
||||||
filename_glob = 'LT04_*_{}.*'
|
filename_glob = 'LT04_*_{}.*'
|
||||||
|
|
||||||
default_bands = ['SR_B1', 'SR_B2', 'SR_B3', 'SR_B4', 'SR_B5', 'SR_B6', 'SR_B7']
|
default_bands = ('SR_B1', 'SR_B2', 'SR_B3', 'SR_B4', 'SR_B5', 'SR_B6', 'SR_B7')
|
||||||
rgb_bands = ['SR_B3', 'SR_B2', 'SR_B1']
|
rgb_bands = ('SR_B3', 'SR_B2', 'SR_B1')
|
||||||
|
|
||||||
|
|
||||||
class Landsat5MSS(Landsat4MSS):
|
class Landsat5MSS(Landsat4MSS):
|
||||||
|
@ -196,8 +196,8 @@ class Landsat7(Landsat):
|
||||||
|
|
||||||
filename_glob = 'LE07_*_{}.*'
|
filename_glob = 'LE07_*_{}.*'
|
||||||
|
|
||||||
default_bands = ['SR_B1', 'SR_B2', 'SR_B3', 'SR_B4', 'SR_B5', 'SR_B6', 'SR_B7']
|
default_bands = ('SR_B1', 'SR_B2', 'SR_B3', 'SR_B4', 'SR_B5', 'SR_B6', 'SR_B7')
|
||||||
rgb_bands = ['SR_B3', 'SR_B2', 'SR_B1']
|
rgb_bands = ('SR_B3', 'SR_B2', 'SR_B1')
|
||||||
|
|
||||||
|
|
||||||
class Landsat8(Landsat):
|
class Landsat8(Landsat):
|
||||||
|
@ -205,11 +205,11 @@ class Landsat8(Landsat):
|
||||||
|
|
||||||
filename_glob = 'LC08_*_{}.*'
|
filename_glob = 'LC08_*_{}.*'
|
||||||
|
|
||||||
default_bands = ['SR_B1', 'SR_B2', 'SR_B3', 'SR_B4', 'SR_B5', 'SR_B6', 'SR_B7']
|
default_bands = ('SR_B1', 'SR_B2', 'SR_B3', 'SR_B4', 'SR_B5', 'SR_B6', 'SR_B7')
|
||||||
rgb_bands = ['SR_B4', 'SR_B3', 'SR_B2']
|
rgb_bands = ('SR_B4', 'SR_B3', 'SR_B2')
|
||||||
|
|
||||||
|
|
||||||
class Landsat9(Landsat8):
|
class Landsat9(Landsat8):
|
||||||
"""Landsat 9 Operational Land Imager (OLI-2) and Thermal Infrared Sensor (TIRS-2).""" # noqa: E501
|
"""Landsat 9 Operational Land Imager (OLI-2) and Thermal Infrared Sensor (TIRS-2)."""
|
||||||
|
|
||||||
filename_glob = 'LC09_*_{}.*'
|
filename_glob = 'LC09_*_{}.*'
|
||||||
|
|
|
@ -7,6 +7,7 @@ import abc
|
||||||
import glob
|
import glob
|
||||||
import os
|
import os
|
||||||
from collections.abc import Callable
|
from collections.abc import Callable
|
||||||
|
from typing import ClassVar
|
||||||
|
|
||||||
import matplotlib.pyplot as plt
|
import matplotlib.pyplot as plt
|
||||||
import numpy as np
|
import numpy as np
|
||||||
|
@ -26,8 +27,8 @@ class LEVIRCDBase(NonGeoDataset, abc.ABC):
|
||||||
.. versionadded:: 0.6
|
.. versionadded:: 0.6
|
||||||
"""
|
"""
|
||||||
|
|
||||||
splits: list[str] | dict[str, dict[str, str]]
|
splits: ClassVar[tuple[str, ...] | dict[str, dict[str, str]]]
|
||||||
directories = ['A', 'B', 'label']
|
directories = ('A', 'B', 'label')
|
||||||
|
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
|
@ -237,7 +238,7 @@ class LEVIRCD(LEVIRCDBase):
|
||||||
.. versionadded:: 0.6
|
.. versionadded:: 0.6
|
||||||
"""
|
"""
|
||||||
|
|
||||||
splits = {
|
splits: ClassVar[dict[str, dict[str, str]]] = {
|
||||||
'train': {
|
'train': {
|
||||||
'url': 'https://drive.google.com/file/d/18GuoCuBn48oZKAlEo-LrNwABrFhVALU-',
|
'url': 'https://drive.google.com/file/d/18GuoCuBn48oZKAlEo-LrNwABrFhVALU-',
|
||||||
'filename': 'train.zip',
|
'filename': 'train.zip',
|
||||||
|
@ -336,7 +337,7 @@ class LEVIRCDPlus(LEVIRCDBase):
|
||||||
md5 = '1adf156f628aa32fb2e8fe6cada16c04'
|
md5 = '1adf156f628aa32fb2e8fe6cada16c04'
|
||||||
filename = 'LEVIR-CD+.zip'
|
filename = 'LEVIR-CD+.zip'
|
||||||
directory = 'LEVIR-CD+'
|
directory = 'LEVIR-CD+'
|
||||||
splits = ['train', 'test']
|
splits = ('train', 'test')
|
||||||
|
|
||||||
def _load_files(self, root: Path, split: str) -> list[dict[str, str]]:
|
def _load_files(self, root: Path, split: str) -> list[dict[str, str]]:
|
||||||
"""Return the paths of the files in the dataset.
|
"""Return the paths of the files in the dataset.
|
||||||
|
|
|
@ -5,7 +5,8 @@
|
||||||
|
|
||||||
import glob
|
import glob
|
||||||
import os
|
import os
|
||||||
from collections.abc import Callable
|
from collections.abc import Callable, Sequence
|
||||||
|
from typing import ClassVar
|
||||||
|
|
||||||
import matplotlib.pyplot as plt
|
import matplotlib.pyplot as plt
|
||||||
import numpy as np
|
import numpy as np
|
||||||
|
@ -57,10 +58,10 @@ class LoveDA(NonGeoDataset):
|
||||||
.. versionadded:: 0.2
|
.. versionadded:: 0.2
|
||||||
"""
|
"""
|
||||||
|
|
||||||
scenes = ['urban', 'rural']
|
scenes = ('urban', 'rural')
|
||||||
splits = ['train', 'val', 'test']
|
splits = ('train', 'val', 'test')
|
||||||
|
|
||||||
info_dict = {
|
info_dict: ClassVar[dict[str, dict[str, str]]] = {
|
||||||
'train': {
|
'train': {
|
||||||
'url': 'https://zenodo.org/record/5706578/files/Train.zip?download=1',
|
'url': 'https://zenodo.org/record/5706578/files/Train.zip?download=1',
|
||||||
'filename': 'Train.zip',
|
'filename': 'Train.zip',
|
||||||
|
@ -78,7 +79,7 @@ class LoveDA(NonGeoDataset):
|
||||||
},
|
},
|
||||||
}
|
}
|
||||||
|
|
||||||
classes = [
|
classes = (
|
||||||
'background',
|
'background',
|
||||||
'building',
|
'building',
|
||||||
'road',
|
'road',
|
||||||
|
@ -87,13 +88,13 @@ class LoveDA(NonGeoDataset):
|
||||||
'forest',
|
'forest',
|
||||||
'agriculture',
|
'agriculture',
|
||||||
'no-data',
|
'no-data',
|
||||||
]
|
)
|
||||||
|
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
root: Path = 'data',
|
root: Path = 'data',
|
||||||
split: str = 'train',
|
split: str = 'train',
|
||||||
scene: list[str] = ['urban', 'rural'],
|
scene: Sequence[str] = ['urban', 'rural'],
|
||||||
transforms: Callable[[dict[str, Tensor]], dict[str, Tensor]] | None = None,
|
transforms: Callable[[dict[str, Tensor]], dict[str, Tensor]] | None = None,
|
||||||
download: bool = False,
|
download: bool = False,
|
||||||
checksum: bool = False,
|
checksum: bool = False,
|
||||||
|
|
|
@ -7,6 +7,7 @@ import os
|
||||||
import shutil
|
import shutil
|
||||||
from collections import defaultdict
|
from collections import defaultdict
|
||||||
from collections.abc import Callable
|
from collections.abc import Callable
|
||||||
|
from typing import ClassVar
|
||||||
|
|
||||||
import matplotlib.pyplot as plt
|
import matplotlib.pyplot as plt
|
||||||
import numpy as np
|
import numpy as np
|
||||||
|
@ -36,7 +37,7 @@ class MapInWild(NonGeoDataset):
|
||||||
different RS sensors over 1018 locations: dual-pol Sentinel-1, four-season
|
different RS sensors over 1018 locations: dual-pol Sentinel-1, four-season
|
||||||
Sentinel-2 with 10 bands, ESA WorldCover map, and Visible Infrared Imaging
|
Sentinel-2 with 10 bands, ESA WorldCover map, and Visible Infrared Imaging
|
||||||
Radiometer Suite NightTime Day/Night band. The dataset consists of 8144
|
Radiometer Suite NightTime Day/Night band. The dataset consists of 8144
|
||||||
images with the shape of 1920 × 1920 pixels. The images are weakly annotated
|
images with the shape of 1920 x 1920 pixels. The images are weakly annotated
|
||||||
from the World Database of Protected Areas (WDPA).
|
from the World Database of Protected Areas (WDPA).
|
||||||
|
|
||||||
Dataset features:
|
Dataset features:
|
||||||
|
@ -54,9 +55,9 @@ class MapInWild(NonGeoDataset):
|
||||||
.. versionadded:: 0.5
|
.. versionadded:: 0.5
|
||||||
"""
|
"""
|
||||||
|
|
||||||
url = 'https://hf.co/datasets/burakekim/mapinwild/resolve/d963778e31e7e0ed2329c0f4cbe493be532f0e71/' # noqa: E501
|
url = 'https://hf.co/datasets/burakekim/mapinwild/resolve/d963778e31e7e0ed2329c0f4cbe493be532f0e71/'
|
||||||
|
|
||||||
modality_urls = {
|
modality_urls: ClassVar[dict[str, set[str]]] = {
|
||||||
'esa_wc': {'esa_wc/ESA_WC.zip'},
|
'esa_wc': {'esa_wc/ESA_WC.zip'},
|
||||||
'viirs': {'viirs/VIIRS.zip'},
|
'viirs': {'viirs/VIIRS.zip'},
|
||||||
'mask': {'mask/mask.zip'},
|
'mask': {'mask/mask.zip'},
|
||||||
|
@ -72,7 +73,7 @@ class MapInWild(NonGeoDataset):
|
||||||
'split_IDs': {'split_IDs/split_IDs.csv'},
|
'split_IDs': {'split_IDs/split_IDs.csv'},
|
||||||
}
|
}
|
||||||
|
|
||||||
md5s = {
|
md5s: ClassVar[dict[str, str]] = {
|
||||||
'ESA_WC.zip': '72b2ee578fe10f0df85bdb7f19311c92',
|
'ESA_WC.zip': '72b2ee578fe10f0df85bdb7f19311c92',
|
||||||
'VIIRS.zip': '4eff014bae127fe536f8a5f17d89ecb4',
|
'VIIRS.zip': '4eff014bae127fe536f8a5f17d89ecb4',
|
||||||
'mask.zip': '87c83a23a73998ad60d448d240b66225',
|
'mask.zip': '87c83a23a73998ad60d448d240b66225',
|
||||||
|
@ -91,9 +92,12 @@ class MapInWild(NonGeoDataset):
|
||||||
'split_IDs.csv': 'cb5c6c073702acee23544e1e6fe5856f',
|
'split_IDs.csv': 'cb5c6c073702acee23544e1e6fe5856f',
|
||||||
}
|
}
|
||||||
|
|
||||||
mask_cmap = {1: (0, 153, 0), 0: (255, 255, 255)}
|
mask_cmap: ClassVar[dict[int, tuple[int, int, int]]] = {
|
||||||
|
1: (0, 153, 0),
|
||||||
|
0: (255, 255, 255),
|
||||||
|
}
|
||||||
|
|
||||||
wc_cmap = {
|
wc_cmap: ClassVar[dict[int, tuple[int, int, int]]] = {
|
||||||
10: (0, 160, 0),
|
10: (0, 160, 0),
|
||||||
20: (150, 100, 0),
|
20: (150, 100, 0),
|
||||||
30: (255, 180, 0),
|
30: (255, 180, 0),
|
||||||
|
|
|
@ -6,7 +6,7 @@
|
||||||
import glob
|
import glob
|
||||||
import os
|
import os
|
||||||
from collections.abc import Callable
|
from collections.abc import Callable
|
||||||
from typing import Any, cast
|
from typing import Any, ClassVar, cast
|
||||||
|
|
||||||
import matplotlib.pyplot as plt
|
import matplotlib.pyplot as plt
|
||||||
import numpy as np
|
import numpy as np
|
||||||
|
@ -48,7 +48,7 @@ class MillionAID(NonGeoDataset):
|
||||||
.. versionadded:: 0.3
|
.. versionadded:: 0.3
|
||||||
"""
|
"""
|
||||||
|
|
||||||
multi_label_categories = [
|
multi_label_categories = (
|
||||||
'agriculture_land',
|
'agriculture_land',
|
||||||
'airport_area',
|
'airport_area',
|
||||||
'apartment',
|
'apartment',
|
||||||
|
@ -122,9 +122,9 @@ class MillionAID(NonGeoDataset):
|
||||||
'wind_turbine',
|
'wind_turbine',
|
||||||
'woodland',
|
'woodland',
|
||||||
'works',
|
'works',
|
||||||
]
|
)
|
||||||
|
|
||||||
multi_class_categories = [
|
multi_class_categories = (
|
||||||
'apartment',
|
'apartment',
|
||||||
'apron',
|
'apron',
|
||||||
'bare_land',
|
'bare_land',
|
||||||
|
@ -176,17 +176,17 @@ class MillionAID(NonGeoDataset):
|
||||||
'wastewater_plant',
|
'wastewater_plant',
|
||||||
'wind_turbine',
|
'wind_turbine',
|
||||||
'works',
|
'works',
|
||||||
]
|
)
|
||||||
|
|
||||||
md5s = {
|
md5s: ClassVar[dict[str, str]] = {
|
||||||
'train': '1b40503cafa9b0601653ca36cd788852',
|
'train': '1b40503cafa9b0601653ca36cd788852',
|
||||||
'test': '51a63ee3eeb1351889eacff349a983d8',
|
'test': '51a63ee3eeb1351889eacff349a983d8',
|
||||||
}
|
}
|
||||||
|
|
||||||
filenames = {'train': 'train.zip', 'test': 'test.zip'}
|
filenames: ClassVar[dict[str, str]] = {'train': 'train.zip', 'test': 'test.zip'}
|
||||||
|
|
||||||
tasks = ['multi-class', 'multi-label']
|
tasks = ('multi-class', 'multi-label')
|
||||||
splits = ['train', 'test']
|
splits = ('train', 'test')
|
||||||
|
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
|
|
|
@ -45,8 +45,8 @@ class NAIP(RasterDataset):
|
||||||
"""
|
"""
|
||||||
|
|
||||||
# Plotting
|
# Plotting
|
||||||
all_bands = ['R', 'G', 'B', 'NIR']
|
all_bands = ('R', 'G', 'B', 'NIR')
|
||||||
rgb_bands = ['R', 'G', 'B']
|
rgb_bands = ('R', 'G', 'B')
|
||||||
|
|
||||||
def plot(
|
def plot(
|
||||||
self,
|
self,
|
||||||
|
|
|
@ -4,7 +4,7 @@
|
||||||
"""Northeastern China Crop Map Dataset."""
|
"""Northeastern China Crop Map Dataset."""
|
||||||
|
|
||||||
from collections.abc import Callable, Iterable
|
from collections.abc import Callable, Iterable
|
||||||
from typing import Any
|
from typing import Any, ClassVar
|
||||||
|
|
||||||
import matplotlib.pyplot as plt
|
import matplotlib.pyplot as plt
|
||||||
import torch
|
import torch
|
||||||
|
@ -57,23 +57,23 @@ class NCCM(RasterDataset):
|
||||||
|
|
||||||
date_format = '%Y'
|
date_format = '%Y'
|
||||||
is_image = False
|
is_image = False
|
||||||
urls = {
|
urls: ClassVar[dict[int, str]] = {
|
||||||
2019: 'https://figshare.com/ndownloader/files/25070540',
|
2019: 'https://figshare.com/ndownloader/files/25070540',
|
||||||
2018: 'https://figshare.com/ndownloader/files/25070624',
|
2018: 'https://figshare.com/ndownloader/files/25070624',
|
||||||
2017: 'https://figshare.com/ndownloader/files/25070582',
|
2017: 'https://figshare.com/ndownloader/files/25070582',
|
||||||
}
|
}
|
||||||
md5s = {
|
md5s: ClassVar[dict[int, str]] = {
|
||||||
2019: '0d062bbd42e483fdc8239d22dba7020f',
|
2019: '0d062bbd42e483fdc8239d22dba7020f',
|
||||||
2018: 'b3bb4894478d10786aa798fb11693ec1',
|
2018: 'b3bb4894478d10786aa798fb11693ec1',
|
||||||
2017: 'd047fbe4a85341fa6248fd7e0badab6c',
|
2017: 'd047fbe4a85341fa6248fd7e0badab6c',
|
||||||
}
|
}
|
||||||
fnames = {
|
fnames: ClassVar[dict[int, str]] = {
|
||||||
2019: 'CDL2019_clip.tif',
|
2019: 'CDL2019_clip.tif',
|
||||||
2018: 'CDL2018_clip1.tif',
|
2018: 'CDL2018_clip1.tif',
|
||||||
2017: 'CDL2017_clip.tif',
|
2017: 'CDL2017_clip.tif',
|
||||||
}
|
}
|
||||||
|
|
||||||
cmap = {
|
cmap: ClassVar[dict[int, tuple[int, int, int, int]]] = {
|
||||||
0: (0, 255, 0, 255),
|
0: (0, 255, 0, 255),
|
||||||
1: (255, 0, 0, 255),
|
1: (255, 0, 0, 255),
|
||||||
2: (255, 255, 0, 255),
|
2: (255, 255, 0, 255),
|
||||||
|
|
|
@ -7,7 +7,7 @@ import glob
|
||||||
import os
|
import os
|
||||||
import pathlib
|
import pathlib
|
||||||
from collections.abc import Callable, Iterable
|
from collections.abc import Callable, Iterable
|
||||||
from typing import Any
|
from typing import Any, ClassVar
|
||||||
|
|
||||||
import matplotlib.pyplot as plt
|
import matplotlib.pyplot as plt
|
||||||
import torch
|
import torch
|
||||||
|
@ -67,7 +67,7 @@ class NLCD(RasterDataset):
|
||||||
* 2019: https://doi.org/10.5066/P9KZCM54
|
* 2019: https://doi.org/10.5066/P9KZCM54
|
||||||
|
|
||||||
.. versionadded:: 0.5
|
.. versionadded:: 0.5
|
||||||
""" # noqa: E501
|
"""
|
||||||
|
|
||||||
filename_glob = 'nlcd_*_land_cover_l48_*.img'
|
filename_glob = 'nlcd_*_land_cover_l48_*.img'
|
||||||
filename_regex = (
|
filename_regex = (
|
||||||
|
@ -79,7 +79,7 @@ class NLCD(RasterDataset):
|
||||||
|
|
||||||
url = 'https://s3-us-west-2.amazonaws.com/mrlc/nlcd_{}_land_cover_l48_20210604.zip'
|
url = 'https://s3-us-west-2.amazonaws.com/mrlc/nlcd_{}_land_cover_l48_20210604.zip'
|
||||||
|
|
||||||
md5s = {
|
md5s: ClassVar[dict[int, str]] = {
|
||||||
2001: '538166a4d783204764e3df3b221fc4cd',
|
2001: '538166a4d783204764e3df3b221fc4cd',
|
||||||
2006: '67454e7874a00294adb9442374d0c309',
|
2006: '67454e7874a00294adb9442374d0c309',
|
||||||
2011: 'ea524c835d173658eeb6fa3c8e6b917b',
|
2011: 'ea524c835d173658eeb6fa3c8e6b917b',
|
||||||
|
@ -87,7 +87,7 @@ class NLCD(RasterDataset):
|
||||||
2019: '82851c3f8105763b01c83b4a9e6f3961',
|
2019: '82851c3f8105763b01c83b4a9e6f3961',
|
||||||
}
|
}
|
||||||
|
|
||||||
cmap = {
|
cmap: ClassVar[dict[int, tuple[int, int, int, int]]] = {
|
||||||
0: (0, 0, 0, 0),
|
0: (0, 0, 0, 0),
|
||||||
11: (70, 107, 159, 255),
|
11: (70, 107, 159, 255),
|
||||||
12: (209, 222, 248, 255),
|
12: (209, 222, 248, 255),
|
||||||
|
|
|
@ -9,7 +9,7 @@ import os
|
||||||
import pathlib
|
import pathlib
|
||||||
import sys
|
import sys
|
||||||
from collections.abc import Callable, Iterable
|
from collections.abc import Callable, Iterable
|
||||||
from typing import Any, cast
|
from typing import Any, ClassVar, cast
|
||||||
|
|
||||||
import fiona
|
import fiona
|
||||||
import fiona.transform
|
import fiona.transform
|
||||||
|
@ -61,7 +61,7 @@ class OpenBuildings(VectorDataset):
|
||||||
.. versionadded:: 0.3
|
.. versionadded:: 0.3
|
||||||
"""
|
"""
|
||||||
|
|
||||||
md5s = {
|
md5s: ClassVar[dict[str, str]] = {
|
||||||
'025_buildings.csv.gz': '41db2572bfd08628d01475a2ee1a2f17',
|
'025_buildings.csv.gz': '41db2572bfd08628d01475a2ee1a2f17',
|
||||||
'04f_buildings.csv.gz': '3232c1c6d45c1543260b77e5689fc8b1',
|
'04f_buildings.csv.gz': '3232c1c6d45c1543260b77e5689fc8b1',
|
||||||
'05b_buildings.csv.gz': '4fc57c63bbbf9a21a3902da7adc3a670',
|
'05b_buildings.csv.gz': '4fc57c63bbbf9a21a3902da7adc3a670',
|
||||||
|
|
|
@ -6,6 +6,7 @@
|
||||||
import glob
|
import glob
|
||||||
import os
|
import os
|
||||||
from collections.abc import Callable, Sequence
|
from collections.abc import Callable, Sequence
|
||||||
|
from typing import ClassVar
|
||||||
|
|
||||||
import matplotlib.pyplot as plt
|
import matplotlib.pyplot as plt
|
||||||
import numpy as np
|
import numpy as np
|
||||||
|
@ -50,7 +51,7 @@ class OSCD(NonGeoDataset):
|
||||||
.. versionadded:: 0.2
|
.. versionadded:: 0.2
|
||||||
"""
|
"""
|
||||||
|
|
||||||
urls = {
|
urls: ClassVar[dict[str, str]] = {
|
||||||
'Onera Satellite Change Detection dataset - Images.zip': (
|
'Onera Satellite Change Detection dataset - Images.zip': (
|
||||||
'https://partage.imt.fr/index.php/s/gKRaWgRnLMfwMGo/download'
|
'https://partage.imt.fr/index.php/s/gKRaWgRnLMfwMGo/download'
|
||||||
),
|
),
|
||||||
|
@ -61,7 +62,7 @@ class OSCD(NonGeoDataset):
|
||||||
'https://partage.imt.fr/index.php/s/gpStKn4Mpgfnr63/download'
|
'https://partage.imt.fr/index.php/s/gpStKn4Mpgfnr63/download'
|
||||||
),
|
),
|
||||||
}
|
}
|
||||||
md5s = {
|
md5s: ClassVar[dict[str, str]] = {
|
||||||
'Onera Satellite Change Detection dataset - Images.zip': (
|
'Onera Satellite Change Detection dataset - Images.zip': (
|
||||||
'c50d4a2941da64e03a47ac4dec63d915'
|
'c50d4a2941da64e03a47ac4dec63d915'
|
||||||
),
|
),
|
||||||
|
@ -75,9 +76,9 @@ class OSCD(NonGeoDataset):
|
||||||
|
|
||||||
zipfile_glob = '*Onera*.zip'
|
zipfile_glob = '*Onera*.zip'
|
||||||
filename_glob = '*Onera*'
|
filename_glob = '*Onera*'
|
||||||
splits = ['train', 'test']
|
splits = ('train', 'test')
|
||||||
|
|
||||||
colormap = ['blue']
|
colormap = ('blue',)
|
||||||
|
|
||||||
all_bands = (
|
all_bands = (
|
||||||
'B01',
|
'B01',
|
||||||
|
@ -319,7 +320,7 @@ class OSCD(NonGeoDataset):
|
||||||
torch.from_numpy(rgb_img),
|
torch.from_numpy(rgb_img),
|
||||||
sample['mask'],
|
sample['mask'],
|
||||||
alpha=alpha,
|
alpha=alpha,
|
||||||
colors=self.colormap,
|
colors=list(self.colormap),
|
||||||
)
|
)
|
||||||
return array
|
return array
|
||||||
|
|
||||||
|
|
|
@ -5,6 +5,7 @@
|
||||||
|
|
||||||
import os
|
import os
|
||||||
from collections.abc import Callable, Sequence
|
from collections.abc import Callable, Sequence
|
||||||
|
from typing import ClassVar
|
||||||
|
|
||||||
import fiona
|
import fiona
|
||||||
import matplotlib.pyplot as plt
|
import matplotlib.pyplot as plt
|
||||||
|
@ -70,7 +71,7 @@ class PASTIS(NonGeoDataset):
|
||||||
.. versionadded:: 0.5
|
.. versionadded:: 0.5
|
||||||
"""
|
"""
|
||||||
|
|
||||||
classes = [
|
classes = (
|
||||||
'background', # all non-agricultural land
|
'background', # all non-agricultural land
|
||||||
'meadow',
|
'meadow',
|
||||||
'soft_winter_wheat',
|
'soft_winter_wheat',
|
||||||
|
@ -91,8 +92,8 @@ class PASTIS(NonGeoDataset):
|
||||||
'mixed_cereal',
|
'mixed_cereal',
|
||||||
'sorghum',
|
'sorghum',
|
||||||
'void_label', # for parcels mostly outside their patch
|
'void_label', # for parcels mostly outside their patch
|
||||||
]
|
)
|
||||||
cmap = {
|
cmap: ClassVar[dict[int, tuple[int, int, int, int]]] = {
|
||||||
0: (0, 0, 0, 255),
|
0: (0, 0, 0, 255),
|
||||||
1: (174, 199, 232, 255),
|
1: (174, 199, 232, 255),
|
||||||
2: (255, 127, 14, 255),
|
2: (255, 127, 14, 255),
|
||||||
|
@ -118,7 +119,7 @@ class PASTIS(NonGeoDataset):
|
||||||
filename = 'PASTIS-R.zip'
|
filename = 'PASTIS-R.zip'
|
||||||
url = 'https://zenodo.org/record/5735646/files/PASTIS-R.zip?download=1'
|
url = 'https://zenodo.org/record/5735646/files/PASTIS-R.zip?download=1'
|
||||||
md5 = '4887513d6c2d2b07fa935d325bd53e09'
|
md5 = '4887513d6c2d2b07fa935d325bd53e09'
|
||||||
prefix = {
|
prefix: ClassVar[dict[str, str]] = {
|
||||||
's2': os.path.join('DATA_S2', 'S2_'),
|
's2': os.path.join('DATA_S2', 'S2_'),
|
||||||
's1a': os.path.join('DATA_S1A', 'S1A_'),
|
's1a': os.path.join('DATA_S1A', 'S1A_'),
|
||||||
's1d': os.path.join('DATA_S1D', 'S1D_'),
|
's1d': os.path.join('DATA_S1D', 'S1D_'),
|
||||||
|
@ -232,7 +233,7 @@ class PASTIS(NonGeoDataset):
|
||||||
Returns:
|
Returns:
|
||||||
the target mask
|
the target mask
|
||||||
"""
|
"""
|
||||||
# See https://github.com/VSainteuf/pastis-benchmark/blob/main/code/dataloader.py#L201 # noqa: E501
|
# See https://github.com/VSainteuf/pastis-benchmark/blob/main/code/dataloader.py#L201
|
||||||
# even though the mask file is 3 bands, we just select the first band
|
# even though the mask file is 3 bands, we just select the first band
|
||||||
array = np.load(self.files[index]['semantic'])[0].astype(np.uint8)
|
array = np.load(self.files[index]['semantic'])[0].astype(np.uint8)
|
||||||
tensor = torch.from_numpy(array).long()
|
tensor = torch.from_numpy(array).long()
|
||||||
|
|
|
@ -5,6 +5,7 @@
|
||||||
|
|
||||||
import os
|
import os
|
||||||
from collections.abc import Callable
|
from collections.abc import Callable
|
||||||
|
from typing import ClassVar
|
||||||
|
|
||||||
import matplotlib.pyplot as plt
|
import matplotlib.pyplot as plt
|
||||||
import numpy as np
|
import numpy as np
|
||||||
|
@ -54,12 +55,12 @@ class Potsdam2D(NonGeoDataset):
|
||||||
* https://doi.org/10.5194/isprsannals-I-3-293-2012
|
* https://doi.org/10.5194/isprsannals-I-3-293-2012
|
||||||
|
|
||||||
.. versionadded:: 0.2
|
.. versionadded:: 0.2
|
||||||
""" # noqa: E501
|
"""
|
||||||
|
|
||||||
filenames = ['4_Ortho_RGBIR.zip', '5_Labels_all.zip']
|
filenames = ('4_Ortho_RGBIR.zip', '5_Labels_all.zip')
|
||||||
md5s = ['c4a8f7d8c7196dd4eba4addd0aae10c1', 'cf7403c1a97c0d279414db']
|
md5s = ('c4a8f7d8c7196dd4eba4addd0aae10c1', 'cf7403c1a97c0d279414db')
|
||||||
image_root = '4_Ortho_RGBIR'
|
image_root = '4_Ortho_RGBIR'
|
||||||
splits = {
|
splits: ClassVar[dict[str, list[str]]] = {
|
||||||
'train': [
|
'train': [
|
||||||
'top_potsdam_2_10',
|
'top_potsdam_2_10',
|
||||||
'top_potsdam_2_11',
|
'top_potsdam_2_11',
|
||||||
|
@ -103,22 +104,22 @@ class Potsdam2D(NonGeoDataset):
|
||||||
'top_potsdam_7_13',
|
'top_potsdam_7_13',
|
||||||
],
|
],
|
||||||
}
|
}
|
||||||
classes = [
|
classes = (
|
||||||
'Clutter/background',
|
'Clutter/background',
|
||||||
'Impervious surfaces',
|
'Impervious surfaces',
|
||||||
'Building',
|
'Building',
|
||||||
'Low Vegetation',
|
'Low Vegetation',
|
||||||
'Tree',
|
'Tree',
|
||||||
'Car',
|
'Car',
|
||||||
]
|
)
|
||||||
colormap = [
|
colormap = (
|
||||||
(255, 0, 0),
|
(255, 0, 0),
|
||||||
(255, 255, 255),
|
(255, 255, 255),
|
||||||
(0, 0, 255),
|
(0, 0, 255),
|
||||||
(0, 255, 255),
|
(0, 255, 255),
|
||||||
(0, 255, 0),
|
(0, 255, 0),
|
||||||
(255, 255, 0),
|
(255, 255, 0),
|
||||||
]
|
)
|
||||||
|
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
|
@ -257,7 +258,7 @@ class Potsdam2D(NonGeoDataset):
|
||||||
"""
|
"""
|
||||||
ncols = 1
|
ncols = 1
|
||||||
image1 = draw_semantic_segmentation_masks(
|
image1 = draw_semantic_segmentation_masks(
|
||||||
sample['image'][:3], sample['mask'], alpha=alpha, colors=self.colormap
|
sample['image'][:3], sample['mask'], alpha=alpha, colors=list(self.colormap)
|
||||||
)
|
)
|
||||||
if 'prediction' in sample:
|
if 'prediction' in sample:
|
||||||
ncols += 1
|
ncols += 1
|
||||||
|
@ -265,7 +266,7 @@ class Potsdam2D(NonGeoDataset):
|
||||||
sample['image'][:3],
|
sample['image'][:3],
|
||||||
sample['prediction'],
|
sample['prediction'],
|
||||||
alpha=alpha,
|
alpha=alpha,
|
||||||
colors=self.colormap,
|
colors=list(self.colormap),
|
||||||
)
|
)
|
||||||
|
|
||||||
fig, axs = plt.subplots(ncols=ncols, figsize=(ncols * 10, 10))
|
fig, axs = plt.subplots(ncols=ncols, figsize=(ncols * 10, 10))
|
||||||
|
|
|
@ -5,7 +5,7 @@
|
||||||
|
|
||||||
import os
|
import os
|
||||||
from collections.abc import Callable
|
from collections.abc import Callable
|
||||||
from typing import Any, cast
|
from typing import Any, ClassVar, cast
|
||||||
|
|
||||||
import matplotlib.pyplot as plt
|
import matplotlib.pyplot as plt
|
||||||
import numpy as np
|
import numpy as np
|
||||||
|
@ -61,8 +61,12 @@ class QuakeSet(NonGeoDataset):
|
||||||
filename = 'earthquakes.h5'
|
filename = 'earthquakes.h5'
|
||||||
url = 'https://hf.co/datasets/DarthReca/quakeset/resolve/bead1d25fb9979dbf703f9ede3e8b349f73b29f7/earthquakes.h5'
|
url = 'https://hf.co/datasets/DarthReca/quakeset/resolve/bead1d25fb9979dbf703f9ede3e8b349f73b29f7/earthquakes.h5'
|
||||||
md5 = '76fc7c76b7ca56f4844d852e175e1560'
|
md5 = '76fc7c76b7ca56f4844d852e175e1560'
|
||||||
splits = {'train': 'train', 'val': 'validation', 'test': 'test'}
|
splits: ClassVar[dict[str, str]] = {
|
||||||
classes = ['unaffected_area', 'earthquake_affected_area']
|
'train': 'train',
|
||||||
|
'val': 'validation',
|
||||||
|
'test': 'test',
|
||||||
|
}
|
||||||
|
classes = ('unaffected_area', 'earthquake_affected_area')
|
||||||
|
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
|
|
|
@ -56,7 +56,7 @@ class ReforesTree(NonGeoDataset):
|
||||||
.. versionadded:: 0.3
|
.. versionadded:: 0.3
|
||||||
"""
|
"""
|
||||||
|
|
||||||
classes = ['other', 'banana', 'cacao', 'citrus', 'fruit', 'timber']
|
classes = ('other', 'banana', 'cacao', 'citrus', 'fruit', 'timber')
|
||||||
url = 'https://zenodo.org/record/6813783/files/reforesTree.zip?download=1'
|
url = 'https://zenodo.org/record/6813783/files/reforesTree.zip?download=1'
|
||||||
|
|
||||||
md5 = 'f6a4a1d8207aeaa5fbab7b21b683a302'
|
md5 = 'f6a4a1d8207aeaa5fbab7b21b683a302'
|
||||||
|
|
|
@ -5,7 +5,7 @@
|
||||||
|
|
||||||
import os
|
import os
|
||||||
from collections.abc import Callable
|
from collections.abc import Callable
|
||||||
from typing import cast
|
from typing import ClassVar, cast
|
||||||
|
|
||||||
import matplotlib.pyplot as plt
|
import matplotlib.pyplot as plt
|
||||||
import numpy as np
|
import numpy as np
|
||||||
|
@ -98,13 +98,13 @@ class RESISC45(NonGeoClassificationDataset):
|
||||||
filename = 'NWPU-RESISC45.zip'
|
filename = 'NWPU-RESISC45.zip'
|
||||||
directory = 'NWPU-RESISC45'
|
directory = 'NWPU-RESISC45'
|
||||||
|
|
||||||
splits = ['train', 'val', 'test']
|
splits = ('train', 'val', 'test')
|
||||||
split_urls = {
|
split_urls: ClassVar[dict[str, str]] = {
|
||||||
'train': 'https://hf.co/datasets/torchgeo/resisc45/resolve/a826b44d938a883185f11ebe3d512d38b464312f/resisc45-train.txt',
|
'train': 'https://hf.co/datasets/torchgeo/resisc45/resolve/a826b44d938a883185f11ebe3d512d38b464312f/resisc45-train.txt',
|
||||||
'val': 'https://hf.co/datasets/torchgeo/resisc45/resolve/a826b44d938a883185f11ebe3d512d38b464312f/resisc45-val.txt',
|
'val': 'https://hf.co/datasets/torchgeo/resisc45/resolve/a826b44d938a883185f11ebe3d512d38b464312f/resisc45-val.txt',
|
||||||
'test': 'https://hf.co/datasets/torchgeo/resisc45/resolve/a826b44d938a883185f11ebe3d512d38b464312f/resisc45-test.txt',
|
'test': 'https://hf.co/datasets/torchgeo/resisc45/resolve/a826b44d938a883185f11ebe3d512d38b464312f/resisc45-test.txt',
|
||||||
}
|
}
|
||||||
split_md5s = {
|
split_md5s: ClassVar[dict[str, str]] = {
|
||||||
'train': 'b5a4c05a37de15e4ca886696a85c403e',
|
'train': 'b5a4c05a37de15e4ca886696a85c403e',
|
||||||
'val': 'a0770cee4c5ca20b8c32bbd61e114805',
|
'val': 'a0770cee4c5ca20b8c32bbd61e114805',
|
||||||
'test': '3dda9e4988b47eb1de9f07993653eb08',
|
'test': '3dda9e4988b47eb1de9f07993653eb08',
|
||||||
|
|
|
@ -6,6 +6,7 @@
|
||||||
import glob
|
import glob
|
||||||
import os
|
import os
|
||||||
from collections.abc import Callable, Sequence
|
from collections.abc import Callable, Sequence
|
||||||
|
from typing import ClassVar
|
||||||
|
|
||||||
import matplotlib.pyplot as plt
|
import matplotlib.pyplot as plt
|
||||||
import numpy as np
|
import numpy as np
|
||||||
|
@ -55,11 +56,11 @@ class RwandaFieldBoundary(NonGeoDataset):
|
||||||
|
|
||||||
url = 'https://radiantearth.blob.core.windows.net/mlhub/nasa_rwanda_field_boundary_competition'
|
url = 'https://radiantearth.blob.core.windows.net/mlhub/nasa_rwanda_field_boundary_competition'
|
||||||
|
|
||||||
splits = {'train': 57, 'test': 13}
|
splits: ClassVar[dict[str, int]] = {'train': 57, 'test': 13}
|
||||||
dates = ('2021_03', '2021_04', '2021_08', '2021_10', '2021_11', '2021_12')
|
dates = ('2021_03', '2021_04', '2021_08', '2021_10', '2021_11', '2021_12')
|
||||||
all_bands = ('B01', 'B02', 'B03', 'B04')
|
all_bands = ('B01', 'B02', 'B03', 'B04')
|
||||||
rgb_bands = ('B03', 'B02', 'B01')
|
rgb_bands = ('B03', 'B02', 'B01')
|
||||||
classes = ['No field-boundary', 'Field-boundary']
|
classes = ('No field-boundary', 'Field-boundary')
|
||||||
|
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
|
|
|
@ -6,6 +6,7 @@
|
||||||
import os
|
import os
|
||||||
import random
|
import random
|
||||||
from collections.abc import Callable, Collection, Iterable
|
from collections.abc import Callable, Collection, Iterable
|
||||||
|
from typing import ClassVar
|
||||||
|
|
||||||
import matplotlib.patches as mpatches
|
import matplotlib.patches as mpatches
|
||||||
import matplotlib.pyplot as plt
|
import matplotlib.pyplot as plt
|
||||||
|
@ -85,51 +86,51 @@ class SeasoNet(NonGeoDataset):
|
||||||
.. versionadded:: 0.5
|
.. versionadded:: 0.5
|
||||||
"""
|
"""
|
||||||
|
|
||||||
metadata = [
|
metadata = (
|
||||||
{
|
{
|
||||||
'name': 'spring',
|
'name': 'spring',
|
||||||
'ext': '.zip',
|
'ext': '.zip',
|
||||||
'url': 'https://zenodo.org/api/files/e2288446-9ee8-4b2e-ae76-cd80366a40e1/spring.zip', # noqa: E501
|
'url': 'https://zenodo.org/api/files/e2288446-9ee8-4b2e-ae76-cd80366a40e1/spring.zip',
|
||||||
'md5': 'de4cdba7b6196aff624073991b187561',
|
'md5': 'de4cdba7b6196aff624073991b187561',
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
'name': 'summer',
|
'name': 'summer',
|
||||||
'ext': '.zip',
|
'ext': '.zip',
|
||||||
'url': 'https://zenodo.org/api/files/e2288446-9ee8-4b2e-ae76-cd80366a40e1/summer.zip', # noqa: E501
|
'url': 'https://zenodo.org/api/files/e2288446-9ee8-4b2e-ae76-cd80366a40e1/summer.zip',
|
||||||
'md5': '6a54d4e134d27ae4eb03f180ee100550',
|
'md5': '6a54d4e134d27ae4eb03f180ee100550',
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
'name': 'fall',
|
'name': 'fall',
|
||||||
'ext': '.zip',
|
'ext': '.zip',
|
||||||
'url': 'https://zenodo.org/api/files/e2288446-9ee8-4b2e-ae76-cd80366a40e1/fall.zip', # noqa: E501
|
'url': 'https://zenodo.org/api/files/e2288446-9ee8-4b2e-ae76-cd80366a40e1/fall.zip',
|
||||||
'md5': '5f94920fe41a63c6bfbab7295f7d6b95',
|
'md5': '5f94920fe41a63c6bfbab7295f7d6b95',
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
'name': 'winter',
|
'name': 'winter',
|
||||||
'ext': '.zip',
|
'ext': '.zip',
|
||||||
'url': 'https://zenodo.org/api/files/e2288446-9ee8-4b2e-ae76-cd80366a40e1/winter.zip', # noqa: E501
|
'url': 'https://zenodo.org/api/files/e2288446-9ee8-4b2e-ae76-cd80366a40e1/winter.zip',
|
||||||
'md5': 'dc5e3e09e52ab5c72421b1e3186c9a48',
|
'md5': 'dc5e3e09e52ab5c72421b1e3186c9a48',
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
'name': 'snow',
|
'name': 'snow',
|
||||||
'ext': '.zip',
|
'ext': '.zip',
|
||||||
'url': 'https://zenodo.org/api/files/e2288446-9ee8-4b2e-ae76-cd80366a40e1/snow.zip', # noqa: E501
|
'url': 'https://zenodo.org/api/files/e2288446-9ee8-4b2e-ae76-cd80366a40e1/snow.zip',
|
||||||
'md5': 'e1b300994143f99ebb03f51d6ab1cbe6',
|
'md5': 'e1b300994143f99ebb03f51d6ab1cbe6',
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
'name': 'splits',
|
'name': 'splits',
|
||||||
'ext': '.zip',
|
'ext': '.zip',
|
||||||
'url': 'https://zenodo.org/api/files/e2288446-9ee8-4b2e-ae76-cd80366a40e1/splits.zip', # noqa: E501
|
'url': 'https://zenodo.org/api/files/e2288446-9ee8-4b2e-ae76-cd80366a40e1/splits.zip',
|
||||||
'md5': 'e4ec4a18bc4efc828f0944a7cf4d5fed',
|
'md5': 'e4ec4a18bc4efc828f0944a7cf4d5fed',
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
'name': 'meta.csv',
|
'name': 'meta.csv',
|
||||||
'ext': '',
|
'ext': '',
|
||||||
'url': 'https://zenodo.org/api/files/e2288446-9ee8-4b2e-ae76-cd80366a40e1/meta.csv', # noqa: E501
|
'url': 'https://zenodo.org/api/files/e2288446-9ee8-4b2e-ae76-cd80366a40e1/meta.csv',
|
||||||
'md5': '43ea07974936a6bf47d989c32e16afe7',
|
'md5': '43ea07974936a6bf47d989c32e16afe7',
|
||||||
},
|
},
|
||||||
]
|
)
|
||||||
classes = [
|
classes = (
|
||||||
'Continuous urban fabric',
|
'Continuous urban fabric',
|
||||||
'Discontinuous urban fabric',
|
'Discontinuous urban fabric',
|
||||||
'Industrial or commercial units',
|
'Industrial or commercial units',
|
||||||
|
@ -163,12 +164,17 @@ class SeasoNet(NonGeoDataset):
|
||||||
'Coastal lagoons',
|
'Coastal lagoons',
|
||||||
'Estuaries',
|
'Estuaries',
|
||||||
'Sea and ocean',
|
'Sea and ocean',
|
||||||
]
|
)
|
||||||
all_seasons = {'Spring', 'Summer', 'Fall', 'Winter', 'Snow'}
|
all_seasons = frozenset({'Spring', 'Summer', 'Fall', 'Winter', 'Snow'})
|
||||||
all_bands = ('10m_RGB', '10m_IR', '20m', '60m')
|
all_bands = ('10m_RGB', '10m_IR', '20m', '60m')
|
||||||
band_nums = {'10m_RGB': 3, '10m_IR': 1, '20m': 6, '60m': 2}
|
band_nums: ClassVar[dict[str, int]] = {
|
||||||
splits = ['train', 'val', 'test']
|
'10m_RGB': 3,
|
||||||
cmap = {
|
'10m_IR': 1,
|
||||||
|
'20m': 6,
|
||||||
|
'60m': 2,
|
||||||
|
}
|
||||||
|
splits = ('train', 'val', 'test')
|
||||||
|
cmap: ClassVar[dict[int, tuple[int, int, int, int]]] = {
|
||||||
0: (230, 000, 77, 255),
|
0: (230, 000, 77, 255),
|
||||||
1: (255, 000, 000, 255),
|
1: (255, 000, 000, 255),
|
||||||
2: (204, 77, 242, 255),
|
2: (204, 77, 242, 255),
|
||||||
|
@ -331,7 +337,7 @@ class SeasoNet(NonGeoDataset):
|
||||||
for band in self.bands:
|
for band in self.bands:
|
||||||
with rasterio.open(f'{path}_{band}.tif') as f:
|
with rasterio.open(f'{path}_{band}.tif') as f:
|
||||||
array = f.read(
|
array = f.read(
|
||||||
out_shape=[f.count] + list(self.image_size),
|
out_shape=[f.count, *list(self.image_size)],
|
||||||
out_dtype='int32',
|
out_dtype='int32',
|
||||||
resampling=Resampling.bilinear,
|
resampling=Resampling.bilinear,
|
||||||
)
|
)
|
||||||
|
|
|
@ -5,7 +5,8 @@
|
||||||
|
|
||||||
import os
|
import os
|
||||||
import random
|
import random
|
||||||
from collections.abc import Callable
|
from collections.abc import Callable, Sequence
|
||||||
|
from typing import ClassVar
|
||||||
|
|
||||||
import matplotlib.pyplot as plt
|
import matplotlib.pyplot as plt
|
||||||
import numpy as np
|
import numpy as np
|
||||||
|
@ -37,7 +38,7 @@ class SeasonalContrastS2(NonGeoDataset):
|
||||||
* https://arxiv.org/pdf/2103.16607.pdf
|
* https://arxiv.org/pdf/2103.16607.pdf
|
||||||
"""
|
"""
|
||||||
|
|
||||||
all_bands = [
|
all_bands = (
|
||||||
'B1',
|
'B1',
|
||||||
'B2',
|
'B2',
|
||||||
'B3',
|
'B3',
|
||||||
|
@ -50,10 +51,10 @@ class SeasonalContrastS2(NonGeoDataset):
|
||||||
'B9',
|
'B9',
|
||||||
'B11',
|
'B11',
|
||||||
'B12',
|
'B12',
|
||||||
]
|
)
|
||||||
rgb_bands = ['B4', 'B3', 'B2']
|
rgb_bands = ('B4', 'B3', 'B2')
|
||||||
|
|
||||||
metadata = {
|
metadata: ClassVar[dict[str, dict[str, str]]] = {
|
||||||
'100k': {
|
'100k': {
|
||||||
'url': 'https://zenodo.org/record/4728033/files/seco_100k.zip?download=1',
|
'url': 'https://zenodo.org/record/4728033/files/seco_100k.zip?download=1',
|
||||||
'md5': 'ebf2d5e03adc6e657f9a69a20ad863e0',
|
'md5': 'ebf2d5e03adc6e657f9a69a20ad863e0',
|
||||||
|
@ -73,7 +74,7 @@ class SeasonalContrastS2(NonGeoDataset):
|
||||||
root: Path = 'data',
|
root: Path = 'data',
|
||||||
version: str = '100k',
|
version: str = '100k',
|
||||||
seasons: int = 1,
|
seasons: int = 1,
|
||||||
bands: list[str] = rgb_bands,
|
bands: Sequence[str] = rgb_bands,
|
||||||
transforms: Callable[[dict[str, Tensor]], dict[str, Tensor]] | None = None,
|
transforms: Callable[[dict[str, Tensor]], dict[str, Tensor]] | None = None,
|
||||||
download: bool = False,
|
download: bool = False,
|
||||||
checksum: bool = False,
|
checksum: bool = False,
|
||||||
|
|
|
@ -5,6 +5,7 @@
|
||||||
|
|
||||||
import os
|
import os
|
||||||
from collections.abc import Callable, Sequence
|
from collections.abc import Callable, Sequence
|
||||||
|
from typing import ClassVar
|
||||||
|
|
||||||
import matplotlib.pyplot as plt
|
import matplotlib.pyplot as plt
|
||||||
import numpy as np
|
import numpy as np
|
||||||
|
@ -63,9 +64,9 @@ class SEN12MS(NonGeoDataset):
|
||||||
or manually downloaded from https://dataserv.ub.tum.de/s/m1474000
|
or manually downloaded from https://dataserv.ub.tum.de/s/m1474000
|
||||||
and https://github.com/schmitt-muc/SEN12MS/tree/master/splits.
|
and https://github.com/schmitt-muc/SEN12MS/tree/master/splits.
|
||||||
This download will likely take several hours.
|
This download will likely take several hours.
|
||||||
""" # noqa: E501
|
"""
|
||||||
|
|
||||||
BAND_SETS: dict[str, tuple[str, ...]] = {
|
BAND_SETS: ClassVar[dict[str, tuple[str, ...]]] = {
|
||||||
'all': (
|
'all': (
|
||||||
'VV',
|
'VV',
|
||||||
'VH',
|
'VH',
|
||||||
|
@ -120,9 +121,9 @@ class SEN12MS(NonGeoDataset):
|
||||||
'B12',
|
'B12',
|
||||||
)
|
)
|
||||||
|
|
||||||
rgb_bands = ['B04', 'B03', 'B02']
|
rgb_bands = ('B04', 'B03', 'B02')
|
||||||
|
|
||||||
filenames = [
|
filenames = (
|
||||||
'ROIs1158_spring_lc.tar.gz',
|
'ROIs1158_spring_lc.tar.gz',
|
||||||
'ROIs1158_spring_s1.tar.gz',
|
'ROIs1158_spring_s1.tar.gz',
|
||||||
'ROIs1158_spring_s2.tar.gz',
|
'ROIs1158_spring_s2.tar.gz',
|
||||||
|
@ -137,16 +138,16 @@ class SEN12MS(NonGeoDataset):
|
||||||
'ROIs2017_winter_s2.tar.gz',
|
'ROIs2017_winter_s2.tar.gz',
|
||||||
'train_list.txt',
|
'train_list.txt',
|
||||||
'test_list.txt',
|
'test_list.txt',
|
||||||
]
|
)
|
||||||
light_filenames = [
|
light_filenames = (
|
||||||
'ROIs1158_spring',
|
'ROIs1158_spring',
|
||||||
'ROIs1868_summer',
|
'ROIs1868_summer',
|
||||||
'ROIs1970_fall',
|
'ROIs1970_fall',
|
||||||
'ROIs2017_winter',
|
'ROIs2017_winter',
|
||||||
'train_list.txt',
|
'train_list.txt',
|
||||||
'test_list.txt',
|
'test_list.txt',
|
||||||
]
|
)
|
||||||
md5s = [
|
md5s = (
|
||||||
'6e2e8fa8b8cba77ddab49fd20ff5c37b',
|
'6e2e8fa8b8cba77ddab49fd20ff5c37b',
|
||||||
'fba019bb27a08c1db96b31f718c34d79',
|
'fba019bb27a08c1db96b31f718c34d79',
|
||||||
'd58af2c15a16f376eb3308dc9b685af2',
|
'd58af2c15a16f376eb3308dc9b685af2',
|
||||||
|
@ -161,7 +162,7 @@ class SEN12MS(NonGeoDataset):
|
||||||
'3807545661288dcca312c9c538537b63',
|
'3807545661288dcca312c9c538537b63',
|
||||||
'0a68d4e1eb24f128fccdb930000b2546',
|
'0a68d4e1eb24f128fccdb930000b2546',
|
||||||
'c7faad064001e646445c4c634169484d',
|
'c7faad064001e646445c4c634169484d',
|
||||||
]
|
)
|
||||||
|
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
|
|
|
@ -137,7 +137,7 @@ class Sentinel1(Sentinel):
|
||||||
\.
|
\.
|
||||||
"""
|
"""
|
||||||
date_format = '%Y%m%dT%H%M%S'
|
date_format = '%Y%m%dT%H%M%S'
|
||||||
all_bands = ['HH', 'HV', 'VV', 'VH']
|
all_bands = ('HH', 'HV', 'VV', 'VH')
|
||||||
separate_files = True
|
separate_files = True
|
||||||
|
|
||||||
def __init__(
|
def __init__(
|
||||||
|
@ -277,7 +277,7 @@ class Sentinel2(Sentinel):
|
||||||
date_format = '%Y%m%dT%H%M%S'
|
date_format = '%Y%m%dT%H%M%S'
|
||||||
|
|
||||||
# https://gisgeography.com/sentinel-2-bands-combinations/
|
# https://gisgeography.com/sentinel-2-bands-combinations/
|
||||||
all_bands = [
|
all_bands: tuple[str, ...] = (
|
||||||
'B01',
|
'B01',
|
||||||
'B02',
|
'B02',
|
||||||
'B03',
|
'B03',
|
||||||
|
@ -291,8 +291,8 @@ class Sentinel2(Sentinel):
|
||||||
'B10',
|
'B10',
|
||||||
'B11',
|
'B11',
|
||||||
'B12',
|
'B12',
|
||||||
]
|
)
|
||||||
rgb_bands = ['B04', 'B03', 'B02']
|
rgb_bands = ('B04', 'B03', 'B02')
|
||||||
|
|
||||||
separate_files = True
|
separate_files = True
|
||||||
|
|
||||||
|
|
|
@ -5,7 +5,7 @@
|
||||||
|
|
||||||
import os
|
import os
|
||||||
from collections.abc import Callable
|
from collections.abc import Callable
|
||||||
from typing import Any
|
from typing import Any, ClassVar
|
||||||
|
|
||||||
import matplotlib.pyplot as plt
|
import matplotlib.pyplot as plt
|
||||||
import numpy as np
|
import numpy as np
|
||||||
|
@ -62,8 +62,8 @@ class SKIPPD(NonGeoDataset):
|
||||||
.. versionadded:: 0.5
|
.. versionadded:: 0.5
|
||||||
"""
|
"""
|
||||||
|
|
||||||
url = 'https://hf.co/datasets/torchgeo/skippd/resolve/a16c7e200b4618cd93be3143cdb973e3f21498fa/{}' # noqa: E501
|
url = 'https://hf.co/datasets/torchgeo/skippd/resolve/a16c7e200b4618cd93be3143cdb973e3f21498fa/{}'
|
||||||
md5 = {
|
md5: ClassVar[dict[str, str]] = {
|
||||||
'forecast': 'f4f3509ddcc83a55c433be9db2e51077',
|
'forecast': 'f4f3509ddcc83a55c433be9db2e51077',
|
||||||
'nowcast': '0000761d403e45bb5f86c21d3c69aa80',
|
'nowcast': '0000761d403e45bb5f86c21d3c69aa80',
|
||||||
}
|
}
|
||||||
|
@ -71,9 +71,9 @@ class SKIPPD(NonGeoDataset):
|
||||||
data_file_name = '2017_2019_images_pv_processed_{}.hdf5'
|
data_file_name = '2017_2019_images_pv_processed_{}.hdf5'
|
||||||
zipfile_name = '2017_2019_images_pv_processed_{}.zip'
|
zipfile_name = '2017_2019_images_pv_processed_{}.zip'
|
||||||
|
|
||||||
valid_splits = ['trainval', 'test']
|
valid_splits = ('trainval', 'test')
|
||||||
|
|
||||||
valid_tasks = ['nowcast', 'forecast']
|
valid_tasks = ('nowcast', 'forecast')
|
||||||
|
|
||||||
dateformat = '%m/%d/%Y, %H:%M:%S'
|
dateformat = '%m/%d/%Y, %H:%M:%S'
|
||||||
|
|
||||||
|
|
|
@ -5,7 +5,7 @@
|
||||||
|
|
||||||
import os
|
import os
|
||||||
from collections.abc import Callable, Sequence
|
from collections.abc import Callable, Sequence
|
||||||
from typing import cast
|
from typing import ClassVar, cast
|
||||||
|
|
||||||
import matplotlib.pyplot as plt
|
import matplotlib.pyplot as plt
|
||||||
import numpy as np
|
import numpy as np
|
||||||
|
@ -103,10 +103,10 @@ class So2Sat(NonGeoDataset):
|
||||||
This dataset requires the following additional library to be installed:
|
This dataset requires the following additional library to be installed:
|
||||||
|
|
||||||
* `<https://pypi.org/project/h5py/>`_ to load the dataset
|
* `<https://pypi.org/project/h5py/>`_ to load the dataset
|
||||||
""" # noqa: E501
|
"""
|
||||||
|
|
||||||
versions = ['2', '3_random', '3_block', '3_culture_10']
|
versions = ('2', '3_random', '3_block', '3_culture_10')
|
||||||
filenames_by_version = {
|
filenames_by_version: ClassVar[dict[str, dict[str, str]]] = {
|
||||||
'2': {
|
'2': {
|
||||||
'train': 'training.h5',
|
'train': 'training.h5',
|
||||||
'validation': 'validation.h5',
|
'validation': 'validation.h5',
|
||||||
|
@ -119,7 +119,7 @@ class So2Sat(NonGeoDataset):
|
||||||
'test': 'culture_10/testing.h5',
|
'test': 'culture_10/testing.h5',
|
||||||
},
|
},
|
||||||
}
|
}
|
||||||
md5s_by_version = {
|
md5s_by_version: ClassVar[dict[str, dict[str, str]]] = {
|
||||||
'2': {
|
'2': {
|
||||||
'train': '702bc6a9368ebff4542d791e53469244',
|
'train': '702bc6a9368ebff4542d791e53469244',
|
||||||
'validation': '71cfa6795de3e22207229d06d6f8775d',
|
'validation': '71cfa6795de3e22207229d06d6f8775d',
|
||||||
|
@ -139,7 +139,7 @@ class So2Sat(NonGeoDataset):
|
||||||
},
|
},
|
||||||
}
|
}
|
||||||
|
|
||||||
classes = [
|
classes = (
|
||||||
'Compact high rise',
|
'Compact high rise',
|
||||||
'Compact mid rise',
|
'Compact mid rise',
|
||||||
'Compact low rise',
|
'Compact low rise',
|
||||||
|
@ -157,7 +157,7 @@ class So2Sat(NonGeoDataset):
|
||||||
'Bare rock or paved',
|
'Bare rock or paved',
|
||||||
'Bare soil or sand',
|
'Bare soil or sand',
|
||||||
'Water',
|
'Water',
|
||||||
]
|
)
|
||||||
|
|
||||||
all_s1_band_names = (
|
all_s1_band_names = (
|
||||||
'S1_B1',
|
'S1_B1',
|
||||||
|
@ -183,9 +183,9 @@ class So2Sat(NonGeoDataset):
|
||||||
)
|
)
|
||||||
all_band_names = all_s1_band_names + all_s2_band_names
|
all_band_names = all_s1_band_names + all_s2_band_names
|
||||||
|
|
||||||
rgb_bands = ['S2_B04', 'S2_B03', 'S2_B02']
|
rgb_bands = ('S2_B04', 'S2_B03', 'S2_B02')
|
||||||
|
|
||||||
BAND_SETS = {
|
BAND_SETS: ClassVar[dict[str, tuple[str, ...]]] = {
|
||||||
'all': all_band_names,
|
'all': all_band_names,
|
||||||
's1': all_s1_band_names,
|
's1': all_s1_band_names,
|
||||||
's2': all_s2_band_names,
|
's2': all_s2_band_names,
|
||||||
|
|
|
@ -6,8 +6,8 @@
|
||||||
import os
|
import os
|
||||||
import pathlib
|
import pathlib
|
||||||
import re
|
import re
|
||||||
from collections.abc import Callable, Iterable
|
from collections.abc import Callable, Iterable, Sequence
|
||||||
from typing import Any, cast
|
from typing import Any, ClassVar, cast
|
||||||
|
|
||||||
import matplotlib.pyplot as plt
|
import matplotlib.pyplot as plt
|
||||||
import torch
|
import torch
|
||||||
|
@ -79,9 +79,9 @@ class SouthAfricaCropType(RasterDataset):
|
||||||
_10m
|
_10m
|
||||||
"""
|
"""
|
||||||
date_format = '%Y_%m_%d'
|
date_format = '%Y_%m_%d'
|
||||||
rgb_bands = ['B04', 'B03', 'B02']
|
rgb_bands = ('B04', 'B03', 'B02')
|
||||||
s1_bands = ['VH', 'VV']
|
s1_bands = ('VH', 'VV')
|
||||||
s2_bands = [
|
s2_bands = (
|
||||||
'B01',
|
'B01',
|
||||||
'B02',
|
'B02',
|
||||||
'B03',
|
'B03',
|
||||||
|
@ -94,9 +94,9 @@ class SouthAfricaCropType(RasterDataset):
|
||||||
'B09',
|
'B09',
|
||||||
'B11',
|
'B11',
|
||||||
'B12',
|
'B12',
|
||||||
]
|
)
|
||||||
all_bands: list[str] = s1_bands + s2_bands
|
all_bands = s1_bands + s2_bands
|
||||||
cmap = {
|
cmap: ClassVar[dict[int, tuple[int, int, int, int]]] = {
|
||||||
0: (0, 0, 0, 255),
|
0: (0, 0, 0, 255),
|
||||||
1: (255, 211, 0, 255),
|
1: (255, 211, 0, 255),
|
||||||
2: (255, 37, 37, 255),
|
2: (255, 37, 37, 255),
|
||||||
|
@ -113,8 +113,8 @@ class SouthAfricaCropType(RasterDataset):
|
||||||
self,
|
self,
|
||||||
paths: Path | Iterable[Path] = 'data',
|
paths: Path | Iterable[Path] = 'data',
|
||||||
crs: CRS | None = None,
|
crs: CRS | None = None,
|
||||||
classes: list[int] = list(cmap.keys()),
|
classes: Sequence[int] = list(cmap.keys()),
|
||||||
bands: list[str] = s2_bands,
|
bands: Sequence[str] = s2_bands,
|
||||||
transforms: Callable[[dict[str, Tensor]], dict[str, Tensor]] | None = None,
|
transforms: Callable[[dict[str, Tensor]], dict[str, Tensor]] | None = None,
|
||||||
download: bool = False,
|
download: bool = False,
|
||||||
) -> None:
|
) -> None:
|
||||||
|
|
|
@ -5,7 +5,7 @@
|
||||||
|
|
||||||
import pathlib
|
import pathlib
|
||||||
from collections.abc import Callable, Iterable
|
from collections.abc import Callable, Iterable
|
||||||
from typing import Any
|
from typing import Any, ClassVar
|
||||||
|
|
||||||
import matplotlib.pyplot as plt
|
import matplotlib.pyplot as plt
|
||||||
from matplotlib.figure import Figure
|
from matplotlib.figure import Figure
|
||||||
|
@ -47,7 +47,7 @@ class SouthAmericaSoybean(RasterDataset):
|
||||||
is_image = False
|
is_image = False
|
||||||
url = 'https://glad.umd.edu/projects/AnnualClassMapsV1/SouthAmerica_Soybean_{}.tif'
|
url = 'https://glad.umd.edu/projects/AnnualClassMapsV1/SouthAmerica_Soybean_{}.tif'
|
||||||
|
|
||||||
md5s = {
|
md5s: ClassVar[dict[int, str]] = {
|
||||||
2021: 'edff3ada13a1a9910d1fe844d28ae4f',
|
2021: 'edff3ada13a1a9910d1fe844d28ae4f',
|
||||||
2020: '0709dec807f576c9707c8c7e183db31',
|
2020: '0709dec807f576c9707c8c7e183db31',
|
||||||
2019: '441836493bbcd5e123cff579a58f5a4f',
|
2019: '441836493bbcd5e123cff579a58f5a4f',
|
||||||
|
|
|
@ -8,7 +8,7 @@ import os
|
||||||
import re
|
import re
|
||||||
from abc import ABC, abstractmethod
|
from abc import ABC, abstractmethod
|
||||||
from collections.abc import Callable
|
from collections.abc import Callable
|
||||||
from typing import Any
|
from typing import Any, ClassVar
|
||||||
|
|
||||||
import fiona
|
import fiona
|
||||||
import matplotlib.pyplot as plt
|
import matplotlib.pyplot as plt
|
||||||
|
@ -55,9 +55,9 @@ class SpaceNet(NonGeoDataset, ABC):
|
||||||
image_glob = '*.tif'
|
image_glob = '*.tif'
|
||||||
mask_glob = '*.geojson'
|
mask_glob = '*.geojson'
|
||||||
file_regex = r'_img(\d+)\.'
|
file_regex = r'_img(\d+)\.'
|
||||||
chip_size: dict[str, tuple[int, int]] = {}
|
chip_size: ClassVar[dict[str, tuple[int, int]]] = {}
|
||||||
|
|
||||||
cities = {
|
cities: ClassVar[dict[int, str]] = {
|
||||||
1: 'Rio',
|
1: 'Rio',
|
||||||
2: 'Vegas',
|
2: 'Vegas',
|
||||||
3: 'Paris',
|
3: 'Paris',
|
||||||
|
@ -98,7 +98,7 @@ class SpaceNet(NonGeoDataset, ABC):
|
||||||
|
|
||||||
@property
|
@property
|
||||||
@abstractmethod
|
@abstractmethod
|
||||||
def valid_masks(self) -> list[str]:
|
def valid_masks(self) -> tuple[str, ...]:
|
||||||
"""List of valid masks."""
|
"""List of valid masks."""
|
||||||
|
|
||||||
def __init__(
|
def __init__(
|
||||||
|
@ -426,7 +426,7 @@ class SpaceNet1(SpaceNet):
|
||||||
|
|
||||||
directory_glob = '{product}'
|
directory_glob = '{product}'
|
||||||
dataset_id = 'SN1_buildings'
|
dataset_id = 'SN1_buildings'
|
||||||
tarballs = {
|
tarballs: ClassVar[dict[str, dict[int, list[str]]]] = {
|
||||||
'train': {
|
'train': {
|
||||||
1: [
|
1: [
|
||||||
'SN1_buildings_train_AOI_1_Rio_3band.tar.gz',
|
'SN1_buildings_train_AOI_1_Rio_3band.tar.gz',
|
||||||
|
@ -441,7 +441,7 @@ class SpaceNet1(SpaceNet):
|
||||||
]
|
]
|
||||||
},
|
},
|
||||||
}
|
}
|
||||||
md5s = {
|
md5s: ClassVar[dict[str, dict[int, list[str]]]] = {
|
||||||
'train': {
|
'train': {
|
||||||
1: [
|
1: [
|
||||||
'279e334a2120ecac70439ea246174516',
|
'279e334a2120ecac70439ea246174516',
|
||||||
|
@ -453,10 +453,16 @@ class SpaceNet1(SpaceNet):
|
||||||
1: ['18283d78b21c239bc1831f3bf1d2c996', '732b3a40603b76e80aac84e002e2b3e8']
|
1: ['18283d78b21c239bc1831f3bf1d2c996', '732b3a40603b76e80aac84e002e2b3e8']
|
||||||
},
|
},
|
||||||
}
|
}
|
||||||
valid_aois = {'train': [1], 'test': [1]}
|
valid_aois: ClassVar[dict[str, list[int]]] = {'train': [1], 'test': [1]}
|
||||||
valid_images = {'train': ['3band', '8band'], 'test': ['3band', '8band']}
|
valid_images: ClassVar[dict[str, list[str]]] = {
|
||||||
valid_masks = ['geojson']
|
'train': ['3band', '8band'],
|
||||||
chip_size = {'3band': (406, 439), '8band': (102, 110)}
|
'test': ['3band', '8band'],
|
||||||
|
}
|
||||||
|
valid_masks = ('geojson',)
|
||||||
|
chip_size: ClassVar[dict[str, tuple[int, int]]] = {
|
||||||
|
'3band': (406, 439),
|
||||||
|
'8band': (102, 110),
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
class SpaceNet2(SpaceNet):
|
class SpaceNet2(SpaceNet):
|
||||||
|
@ -522,7 +528,7 @@ class SpaceNet2(SpaceNet):
|
||||||
"""
|
"""
|
||||||
|
|
||||||
dataset_id = 'SN2_buildings'
|
dataset_id = 'SN2_buildings'
|
||||||
tarballs = {
|
tarballs: ClassVar[dict[str, dict[int, list[str]]]] = {
|
||||||
'train': {
|
'train': {
|
||||||
2: ['SN2_buildings_train_AOI_2_Vegas.tar.gz'],
|
2: ['SN2_buildings_train_AOI_2_Vegas.tar.gz'],
|
||||||
3: ['SN2_buildings_train_AOI_3_Paris.tar.gz'],
|
3: ['SN2_buildings_train_AOI_3_Paris.tar.gz'],
|
||||||
|
@ -536,7 +542,7 @@ class SpaceNet2(SpaceNet):
|
||||||
5: ['AOI_5_Khartoum_Test_public.tar.gz'],
|
5: ['AOI_5_Khartoum_Test_public.tar.gz'],
|
||||||
},
|
},
|
||||||
}
|
}
|
||||||
md5s = {
|
md5s: ClassVar[dict[str, dict[int, list[str]]]] = {
|
||||||
'train': {
|
'train': {
|
||||||
2: ['307da318bc43aaf9481828f92eda9126'],
|
2: ['307da318bc43aaf9481828f92eda9126'],
|
||||||
3: ['4db469e3e4e7bf025368ad730aec0888'],
|
3: ['4db469e3e4e7bf025368ad730aec0888'],
|
||||||
|
@ -550,13 +556,16 @@ class SpaceNet2(SpaceNet):
|
||||||
5: ['037d7be10530f0dd1c43d4ef79f3236e'],
|
5: ['037d7be10530f0dd1c43d4ef79f3236e'],
|
||||||
},
|
},
|
||||||
}
|
}
|
||||||
valid_aois = {'train': [2, 3, 4, 5], 'test': [2, 3, 4, 5]}
|
valid_aois: ClassVar[dict[str, list[int]]] = {
|
||||||
valid_images = {
|
'train': [2, 3, 4, 5],
|
||||||
|
'test': [2, 3, 4, 5],
|
||||||
|
}
|
||||||
|
valid_images: ClassVar[dict[str, list[str]]] = {
|
||||||
'train': ['MUL', 'MUL-PanSharpen', 'PAN', 'RGB-PanSharpen'],
|
'train': ['MUL', 'MUL-PanSharpen', 'PAN', 'RGB-PanSharpen'],
|
||||||
'test': ['MUL', 'MUL-PanSharpen', 'PAN', 'RGB-PanSharpen'],
|
'test': ['MUL', 'MUL-PanSharpen', 'PAN', 'RGB-PanSharpen'],
|
||||||
}
|
}
|
||||||
valid_masks = [os.path.join('geojson', 'buildings')]
|
valid_masks = (os.path.join('geojson', 'buildings'),)
|
||||||
chip_size = {'MUL': (163, 163)}
|
chip_size: ClassVar[dict[str, tuple[int, int]]] = {'MUL': (163, 163)}
|
||||||
|
|
||||||
|
|
||||||
class SpaceNet3(SpaceNet):
|
class SpaceNet3(SpaceNet):
|
||||||
|
@ -624,7 +633,7 @@ class SpaceNet3(SpaceNet):
|
||||||
"""
|
"""
|
||||||
|
|
||||||
dataset_id = 'SN3_roads'
|
dataset_id = 'SN3_roads'
|
||||||
tarballs = {
|
tarballs: ClassVar[dict[str, dict[int, list[str]]]] = {
|
||||||
'train': {
|
'train': {
|
||||||
2: [
|
2: [
|
||||||
'SN3_roads_train_AOI_2_Vegas.tar.gz',
|
'SN3_roads_train_AOI_2_Vegas.tar.gz',
|
||||||
|
@ -650,7 +659,7 @@ class SpaceNet3(SpaceNet):
|
||||||
5: ['SN3_roads_test_public_AOI_5_Khartoum.tar.gz'],
|
5: ['SN3_roads_test_public_AOI_5_Khartoum.tar.gz'],
|
||||||
},
|
},
|
||||||
}
|
}
|
||||||
md5s = {
|
md5s: ClassVar[dict[str, dict[int, list[str]]]] = {
|
||||||
'train': {
|
'train': {
|
||||||
2: ['06317255b5e0c6df2643efd8a50f22ae', '4acf7846ed8121db1319345cfe9fdca9'],
|
2: ['06317255b5e0c6df2643efd8a50f22ae', '4acf7846ed8121db1319345cfe9fdca9'],
|
||||||
3: ['c13baf88ee10fe47870c303223cabf82', 'abc8199d4c522d3a14328f4f514702ad'],
|
3: ['c13baf88ee10fe47870c303223cabf82', 'abc8199d4c522d3a14328f4f514702ad'],
|
||||||
|
@ -664,12 +673,15 @@ class SpaceNet3(SpaceNet):
|
||||||
5: ['f367c79fa0fc1d38e63a0fdd065ed957'],
|
5: ['f367c79fa0fc1d38e63a0fdd065ed957'],
|
||||||
},
|
},
|
||||||
}
|
}
|
||||||
valid_aois = {'train': [2, 3, 4, 5], 'test': [2, 3, 4, 5]}
|
valid_aois: ClassVar[dict[str, list[int]]] = {
|
||||||
valid_images = {
|
'train': [2, 3, 4, 5],
|
||||||
|
'test': [2, 3, 4, 5],
|
||||||
|
}
|
||||||
|
valid_images: ClassVar[dict[str, list[str]]] = {
|
||||||
'train': ['MS', 'PS-MS', 'PAN', 'PS-RGB'],
|
'train': ['MS', 'PS-MS', 'PAN', 'PS-RGB'],
|
||||||
'test': ['MUL', 'MUL-PanSharpen', 'PAN', 'RGB-PanSharpen'],
|
'test': ['MUL', 'MUL-PanSharpen', 'PAN', 'RGB-PanSharpen'],
|
||||||
}
|
}
|
||||||
valid_masks = ['geojson_roads', 'geojson_roads_speed']
|
valid_masks: tuple[str, ...] = ('geojson_roads', 'geojson_roads_speed')
|
||||||
|
|
||||||
|
|
||||||
class SpaceNet4(SpaceNet):
|
class SpaceNet4(SpaceNet):
|
||||||
|
@ -708,7 +720,7 @@ class SpaceNet4(SpaceNet):
|
||||||
directory_glob = os.path.join('**', '{product}')
|
directory_glob = os.path.join('**', '{product}')
|
||||||
file_regex = r'_(\d+_\d+)\.'
|
file_regex = r'_(\d+_\d+)\.'
|
||||||
dataset_id = 'SN4_buildings'
|
dataset_id = 'SN4_buildings'
|
||||||
tarballs = {
|
tarballs: ClassVar[dict[str, dict[int, list[str]]]] = {
|
||||||
'train': {
|
'train': {
|
||||||
6: [
|
6: [
|
||||||
'Atlanta_nadir7_catid_1030010003D22F00.tar.gz',
|
'Atlanta_nadir7_catid_1030010003D22F00.tar.gz',
|
||||||
|
@ -743,7 +755,7 @@ class SpaceNet4(SpaceNet):
|
||||||
},
|
},
|
||||||
'test': {6: ['SN4_buildings_AOI_6_Atlanta_test_public.tar.gz']},
|
'test': {6: ['SN4_buildings_AOI_6_Atlanta_test_public.tar.gz']},
|
||||||
}
|
}
|
||||||
md5s = {
|
md5s: ClassVar[dict[str, dict[int, list[str]]]] = {
|
||||||
'train': {
|
'train': {
|
||||||
6: [
|
6: [
|
||||||
'd41ab6ec087b07e1e046c55d1fa5754b',
|
'd41ab6ec087b07e1e046c55d1fa5754b',
|
||||||
|
@ -778,12 +790,12 @@ class SpaceNet4(SpaceNet):
|
||||||
},
|
},
|
||||||
'test': {6: ['0ec3874bfc19aed63b33ac47b039aace']},
|
'test': {6: ['0ec3874bfc19aed63b33ac47b039aace']},
|
||||||
}
|
}
|
||||||
valid_aois = {'train': [6], 'test': [6]}
|
valid_aois: ClassVar[dict[str, list[int]]] = {'train': [6], 'test': [6]}
|
||||||
valid_images = {
|
valid_images: ClassVar[dict[str, list[str]]] = {
|
||||||
'train': ['MS', 'PAN', 'Pan-Sharpen'],
|
'train': ['MS', 'PAN', 'Pan-Sharpen'],
|
||||||
'test': ['MS', 'PAN', 'Pan-Sharpen'],
|
'test': ['MS', 'PAN', 'Pan-Sharpen'],
|
||||||
}
|
}
|
||||||
valid_masks = [os.path.join('geojson', 'spacenet-buildings')]
|
valid_masks = (os.path.join('geojson', 'spacenet-buildings'),)
|
||||||
|
|
||||||
|
|
||||||
class SpaceNet5(SpaceNet3):
|
class SpaceNet5(SpaceNet3):
|
||||||
|
@ -850,26 +862,26 @@ class SpaceNet5(SpaceNet3):
|
||||||
|
|
||||||
file_regex = r'_chip(\d+)\.'
|
file_regex = r'_chip(\d+)\.'
|
||||||
dataset_id = 'SN5_roads'
|
dataset_id = 'SN5_roads'
|
||||||
tarballs = {
|
tarballs: ClassVar[dict[str, dict[int, list[str]]]] = {
|
||||||
'train': {
|
'train': {
|
||||||
7: ['SN5_roads_train_AOI_7_Moscow.tar.gz'],
|
7: ['SN5_roads_train_AOI_7_Moscow.tar.gz'],
|
||||||
8: ['SN5_roads_train_AOI_8_Mumbai.tar.gz'],
|
8: ['SN5_roads_train_AOI_8_Mumbai.tar.gz'],
|
||||||
},
|
},
|
||||||
'test': {9: ['SN5_roads_test_public_AOI_9_San_Juan.tar.gz']},
|
'test': {9: ['SN5_roads_test_public_AOI_9_San_Juan.tar.gz']},
|
||||||
}
|
}
|
||||||
md5s = {
|
md5s: ClassVar[dict[str, dict[int, list[str]]]] = {
|
||||||
'train': {
|
'train': {
|
||||||
7: ['03082d01081a6d8df2bc5a9645148d2a'],
|
7: ['03082d01081a6d8df2bc5a9645148d2a'],
|
||||||
8: ['1ee20ba781da6cb7696eef9a95a5bdcc'],
|
8: ['1ee20ba781da6cb7696eef9a95a5bdcc'],
|
||||||
},
|
},
|
||||||
'test': {9: ['fc45afef219dfd3a20f2d4fc597f6882']},
|
'test': {9: ['fc45afef219dfd3a20f2d4fc597f6882']},
|
||||||
}
|
}
|
||||||
valid_aois = {'train': [7, 8], 'test': [9]}
|
valid_aois: ClassVar[dict[str, list[int]]] = {'train': [7, 8], 'test': [9]}
|
||||||
valid_images = {
|
valid_images: ClassVar[dict[str, list[str]]] = {
|
||||||
'train': ['MS', 'PAN', 'PS-MS', 'PS-RGB'],
|
'train': ['MS', 'PAN', 'PS-MS', 'PS-RGB'],
|
||||||
'test': ['MS', 'PAN', 'PS-MS', 'PS-RGB'],
|
'test': ['MS', 'PAN', 'PS-MS', 'PS-RGB'],
|
||||||
}
|
}
|
||||||
valid_masks = ['geojson_roads_speed']
|
valid_masks = ('geojson_roads_speed',)
|
||||||
|
|
||||||
|
|
||||||
class SpaceNet6(SpaceNet):
|
class SpaceNet6(SpaceNet):
|
||||||
|
@ -937,20 +949,20 @@ class SpaceNet6(SpaceNet):
|
||||||
|
|
||||||
file_regex = r'_tile_(\d+)\.'
|
file_regex = r'_tile_(\d+)\.'
|
||||||
dataset_id = 'SN6_buildings'
|
dataset_id = 'SN6_buildings'
|
||||||
tarballs = {
|
tarballs: ClassVar[dict[str, dict[int, list[str]]]] = {
|
||||||
'train': {11: ['SN6_buildings_AOI_11_Rotterdam_train.tar.gz']},
|
'train': {11: ['SN6_buildings_AOI_11_Rotterdam_train.tar.gz']},
|
||||||
'test': {11: ['SN6_buildings_AOI_11_Rotterdam_test_public.tar.gz']},
|
'test': {11: ['SN6_buildings_AOI_11_Rotterdam_test_public.tar.gz']},
|
||||||
}
|
}
|
||||||
md5s = {
|
md5s: ClassVar[dict[str, dict[int, list[str]]]] = {
|
||||||
'train': {11: ['10ca26d2287716e3b6ef0cf0ad9f946e']},
|
'train': {11: ['10ca26d2287716e3b6ef0cf0ad9f946e']},
|
||||||
'test': {11: ['a07823a5e536feeb8bb6b6f0cb43cf05']},
|
'test': {11: ['a07823a5e536feeb8bb6b6f0cb43cf05']},
|
||||||
}
|
}
|
||||||
valid_aois = {'train': [11], 'test': [11]}
|
valid_aois: ClassVar[dict[str, list[int]]] = {'train': [11], 'test': [11]}
|
||||||
valid_images = {
|
valid_images: ClassVar[dict[str, list[str]]] = {
|
||||||
'train': ['PAN', 'PS-RGB', 'PS-RGBNIR', 'RGBNIR', 'SAR-Intensity'],
|
'train': ['PAN', 'PS-RGB', 'PS-RGBNIR', 'RGBNIR', 'SAR-Intensity'],
|
||||||
'test': ['SAR-Intensity'],
|
'test': ['SAR-Intensity'],
|
||||||
}
|
}
|
||||||
valid_masks = ['geojson_buildings']
|
valid_masks = ('geojson_buildings',)
|
||||||
|
|
||||||
|
|
||||||
class SpaceNet7(SpaceNet):
|
class SpaceNet7(SpaceNet):
|
||||||
|
@ -958,7 +970,7 @@ class SpaceNet7(SpaceNet):
|
||||||
|
|
||||||
`SpaceNet 7 <https://spacenet.ai/sn7-challenge/>`_ is a dataset which
|
`SpaceNet 7 <https://spacenet.ai/sn7-challenge/>`_ is a dataset which
|
||||||
consist of medium resolution (4.0m) satellite imagery mosaics acquired from
|
consist of medium resolution (4.0m) satellite imagery mosaics acquired from
|
||||||
Planet Labs’ Dove constellation between 2017 and 2020. It includes ≈ 24
|
Planet Labs' Dove constellation between 2017 and 2020. It includes ≈ 24
|
||||||
images (one per month) covering > 100 unique geographies, and comprises >
|
images (one per month) covering > 100 unique geographies, and comprises >
|
||||||
40,000 km2 of imagery and exhaustive polygon labels of building footprints
|
40,000 km2 of imagery and exhaustive polygon labels of building footprints
|
||||||
therein, totaling over 11M individual annotations.
|
therein, totaling over 11M individual annotations.
|
||||||
|
@ -993,18 +1005,24 @@ class SpaceNet7(SpaceNet):
|
||||||
mask_glob = '*_Buildings.geojson'
|
mask_glob = '*_Buildings.geojson'
|
||||||
file_regex = r'global_monthly_(\d+.*\d+)'
|
file_regex = r'global_monthly_(\d+.*\d+)'
|
||||||
dataset_id = 'SN7_buildings'
|
dataset_id = 'SN7_buildings'
|
||||||
tarballs = {
|
tarballs: ClassVar[dict[str, dict[int, list[str]]]] = {
|
||||||
'train': {0: ['SN7_buildings_train.tar.gz']},
|
'train': {0: ['SN7_buildings_train.tar.gz']},
|
||||||
'test': {0: ['SN7_buildings_test_public.tar.gz']},
|
'test': {0: ['SN7_buildings_test_public.tar.gz']},
|
||||||
}
|
}
|
||||||
md5s = {
|
md5s: ClassVar[dict[str, dict[int, list[str]]]] = {
|
||||||
'train': {0: ['6eda13b9c28f6f5cdf00a7e8e218c1b1']},
|
'train': {0: ['6eda13b9c28f6f5cdf00a7e8e218c1b1']},
|
||||||
'test': {0: ['b3bde95a0f8f32f3bfeba49464b9bc97']},
|
'test': {0: ['b3bde95a0f8f32f3bfeba49464b9bc97']},
|
||||||
}
|
}
|
||||||
valid_aois = {'train': [0], 'test': [0]}
|
valid_aois: ClassVar[dict[str, list[int]]] = {'train': [0], 'test': [0]}
|
||||||
valid_images = {'train': ['images', 'images_masked'], 'test': ['images_masked']}
|
valid_images: ClassVar[dict[str, list[str]]] = {
|
||||||
valid_masks = ['labels', 'labels_match', 'labels_match_pix']
|
'train': ['images', 'images_masked'],
|
||||||
chip_size = {'images': (1024, 1024), 'images_masked': (1024, 1024)}
|
'test': ['images_masked'],
|
||||||
|
}
|
||||||
|
valid_masks = ('labels', 'labels_match', 'labels_match_pix')
|
||||||
|
chip_size: ClassVar[dict[str, tuple[int, int]]] = {
|
||||||
|
'images': (1024, 1024),
|
||||||
|
'images_masked': (1024, 1024),
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
class SpaceNet8(SpaceNet):
|
class SpaceNet8(SpaceNet):
|
||||||
|
@ -1024,7 +1042,7 @@ class SpaceNet8(SpaceNet):
|
||||||
directory_glob = '{product}'
|
directory_glob = '{product}'
|
||||||
file_regex = r'(\d+_\d+_\d+)\.'
|
file_regex = r'(\d+_\d+_\d+)\.'
|
||||||
dataset_id = 'SN8_floods'
|
dataset_id = 'SN8_floods'
|
||||||
tarballs = {
|
tarballs: ClassVar[dict[str, dict[int, list[str]]]] = {
|
||||||
'train': {
|
'train': {
|
||||||
0: [
|
0: [
|
||||||
'Germany_Training_Public.tar.gz',
|
'Germany_Training_Public.tar.gz',
|
||||||
|
@ -1033,16 +1051,19 @@ class SpaceNet8(SpaceNet):
|
||||||
},
|
},
|
||||||
'test': {0: ['Louisiana-West_Test_Public.tar.gz']},
|
'test': {0: ['Louisiana-West_Test_Public.tar.gz']},
|
||||||
}
|
}
|
||||||
md5s = {
|
md5s: ClassVar[dict[str, dict[int, list[str]]]] = {
|
||||||
'train': {
|
'train': {
|
||||||
0: ['81383a9050b93e8f70c8557d4568e8a2', 'fa40ae3cf6ac212c90073bf93d70bd95']
|
0: ['81383a9050b93e8f70c8557d4568e8a2', 'fa40ae3cf6ac212c90073bf93d70bd95']
|
||||||
},
|
},
|
||||||
'test': {0: ['d41d8cd98f00b204e9800998ecf8427e']},
|
'test': {0: ['d41d8cd98f00b204e9800998ecf8427e']},
|
||||||
}
|
}
|
||||||
valid_aois = {'train': [0], 'test': [0]}
|
valid_aois: ClassVar[dict[str, list[int]]] = {'train': [0], 'test': [0]}
|
||||||
valid_images = {
|
valid_images: ClassVar[dict[str, list[str]]] = {
|
||||||
'train': ['PRE-event', 'POST-event'],
|
'train': ['PRE-event', 'POST-event'],
|
||||||
'test': ['PRE-event', 'POST-event'],
|
'test': ['PRE-event', 'POST-event'],
|
||||||
}
|
}
|
||||||
valid_masks = ['annotations']
|
valid_masks = ('annotations',)
|
||||||
chip_size = {'PRE-event': (1300, 1300), 'POST-event': (1300, 1300)}
|
chip_size: ClassVar[dict[str, tuple[int, int]]] = {
|
||||||
|
'PRE-event': (1300, 1300),
|
||||||
|
'POST-event': (1300, 1300),
|
||||||
|
}
|
||||||
|
|
|
@ -7,7 +7,7 @@ import glob
|
||||||
import os
|
import os
|
||||||
import random
|
import random
|
||||||
from collections.abc import Callable
|
from collections.abc import Callable
|
||||||
from typing import TypedDict
|
from typing import ClassVar, TypedDict
|
||||||
|
|
||||||
import matplotlib.pyplot as plt
|
import matplotlib.pyplot as plt
|
||||||
import numpy as np
|
import numpy as np
|
||||||
|
@ -93,13 +93,13 @@ class SSL4EOL(NonGeoDataset):
|
||||||
* https://proceedings.neurips.cc/paper_files/paper/2023/hash/bbf7ee04e2aefec136ecf60e346c2e61-Abstract-Datasets_and_Benchmarks.html
|
* https://proceedings.neurips.cc/paper_files/paper/2023/hash/bbf7ee04e2aefec136ecf60e346c2e61-Abstract-Datasets_and_Benchmarks.html
|
||||||
|
|
||||||
.. versionadded:: 0.5
|
.. versionadded:: 0.5
|
||||||
""" # noqa: E501
|
"""
|
||||||
|
|
||||||
class _Metadata(TypedDict):
|
class _Metadata(TypedDict):
|
||||||
num_bands: int
|
num_bands: int
|
||||||
rgb_bands: list[int]
|
rgb_bands: list[int]
|
||||||
|
|
||||||
metadata: dict[str, _Metadata] = {
|
metadata: ClassVar[dict[str, _Metadata]] = {
|
||||||
'tm_toa': {'num_bands': 7, 'rgb_bands': [2, 1, 0]},
|
'tm_toa': {'num_bands': 7, 'rgb_bands': [2, 1, 0]},
|
||||||
'etm_toa': {'num_bands': 9, 'rgb_bands': [2, 1, 0]},
|
'etm_toa': {'num_bands': 9, 'rgb_bands': [2, 1, 0]},
|
||||||
'etm_sr': {'num_bands': 6, 'rgb_bands': [2, 1, 0]},
|
'etm_sr': {'num_bands': 6, 'rgb_bands': [2, 1, 0]},
|
||||||
|
@ -107,8 +107,8 @@ class SSL4EOL(NonGeoDataset):
|
||||||
'oli_sr': {'num_bands': 7, 'rgb_bands': [3, 2, 1]},
|
'oli_sr': {'num_bands': 7, 'rgb_bands': [3, 2, 1]},
|
||||||
}
|
}
|
||||||
|
|
||||||
url = 'https://hf.co/datasets/torchgeo/ssl4eo_l/resolve/e2467887e6a6bcd7547d9d5999f8d9bc3323dc31/{0}/ssl4eo_l_{0}.tar.gz{1}' # noqa: E501
|
url = 'https://hf.co/datasets/torchgeo/ssl4eo_l/resolve/e2467887e6a6bcd7547d9d5999f8d9bc3323dc31/{0}/ssl4eo_l_{0}.tar.gz{1}'
|
||||||
checksums = {
|
checksums: ClassVar[dict[str, dict[str, str]]] = {
|
||||||
'tm_toa': {
|
'tm_toa': {
|
||||||
'aa': '553795b8d73aa253445b1e67c5b81f11',
|
'aa': '553795b8d73aa253445b1e67c5b81f11',
|
||||||
'ab': 'e9e0739b5171b37d16086cb89ab370e8',
|
'ab': 'e9e0739b5171b37d16086cb89ab370e8',
|
||||||
|
@ -357,7 +357,7 @@ class SSL4EOS12(NonGeoDataset):
|
||||||
md5: str
|
md5: str
|
||||||
bands: list[str]
|
bands: list[str]
|
||||||
|
|
||||||
metadata: dict[str, _Metadata] = {
|
metadata: ClassVar[dict[str, _Metadata]] = {
|
||||||
's1': {
|
's1': {
|
||||||
'filename': 's1.tar.gz',
|
'filename': 's1.tar.gz',
|
||||||
'md5': '51ee23b33eb0a2f920bda25225072f3a',
|
'md5': '51ee23b33eb0a2f920bda25225072f3a',
|
||||||
|
|
|
@ -6,6 +6,7 @@
|
||||||
import glob
|
import glob
|
||||||
import os
|
import os
|
||||||
from collections.abc import Callable
|
from collections.abc import Callable
|
||||||
|
from typing import ClassVar
|
||||||
|
|
||||||
import matplotlib.pyplot as plt
|
import matplotlib.pyplot as plt
|
||||||
import numpy as np
|
import numpy as np
|
||||||
|
@ -46,16 +47,16 @@ class SSL4EOLBenchmark(NonGeoDataset):
|
||||||
* https://proceedings.neurips.cc/paper_files/paper/2023/hash/bbf7ee04e2aefec136ecf60e346c2e61-Abstract-Datasets_and_Benchmarks.html
|
* https://proceedings.neurips.cc/paper_files/paper/2023/hash/bbf7ee04e2aefec136ecf60e346c2e61-Abstract-Datasets_and_Benchmarks.html
|
||||||
|
|
||||||
.. versionadded:: 0.5
|
.. versionadded:: 0.5
|
||||||
""" # noqa: E501
|
"""
|
||||||
|
|
||||||
url = 'https://hf.co/datasets/torchgeo/ssl4eo-l-benchmark/resolve/da96ae2b04cb509710b72fce9131c2a3d5c211c2/{}.tar.gz' # noqa: E501
|
url = 'https://hf.co/datasets/torchgeo/ssl4eo-l-benchmark/resolve/da96ae2b04cb509710b72fce9131c2a3d5c211c2/{}.tar.gz'
|
||||||
|
|
||||||
valid_sensors = ['tm_toa', 'etm_toa', 'etm_sr', 'oli_tirs_toa', 'oli_sr']
|
valid_sensors = ('tm_toa', 'etm_toa', 'etm_sr', 'oli_tirs_toa', 'oli_sr')
|
||||||
valid_products = ['cdl', 'nlcd']
|
valid_products = ('cdl', 'nlcd')
|
||||||
valid_splits = ['train', 'val', 'test']
|
valid_splits = ('train', 'val', 'test')
|
||||||
|
|
||||||
image_root = 'ssl4eo_l_{}_benchmark'
|
image_root = 'ssl4eo_l_{}_benchmark'
|
||||||
img_md5s = {
|
img_md5s: ClassVar[dict[str, str]] = {
|
||||||
'tm_toa': '8e3c5bcd56d3780a442f1332013b8d15',
|
'tm_toa': '8e3c5bcd56d3780a442f1332013b8d15',
|
||||||
'etm_toa': '1b051c7fe4d61c581b341370c9e76f1f',
|
'etm_toa': '1b051c7fe4d61c581b341370c9e76f1f',
|
||||||
'etm_sr': '34a24fa89a801654f8d01e054662c8cd',
|
'etm_sr': '34a24fa89a801654f8d01e054662c8cd',
|
||||||
|
@ -63,14 +64,14 @@ class SSL4EOLBenchmark(NonGeoDataset):
|
||||||
'oli_sr': '0700cd15cc2366fe68c2f8c02fa09a15',
|
'oli_sr': '0700cd15cc2366fe68c2f8c02fa09a15',
|
||||||
}
|
}
|
||||||
|
|
||||||
mask_dir_dict = {
|
mask_dir_dict: ClassVar[dict[str, str]] = {
|
||||||
'tm_toa': 'ssl4eo_l_tm_{}',
|
'tm_toa': 'ssl4eo_l_tm_{}',
|
||||||
'etm_toa': 'ssl4eo_l_etm_{}',
|
'etm_toa': 'ssl4eo_l_etm_{}',
|
||||||
'etm_sr': 'ssl4eo_l_etm_{}',
|
'etm_sr': 'ssl4eo_l_etm_{}',
|
||||||
'oli_tirs_toa': 'ssl4eo_l_oli_{}',
|
'oli_tirs_toa': 'ssl4eo_l_oli_{}',
|
||||||
'oli_sr': 'ssl4eo_l_oli_{}',
|
'oli_sr': 'ssl4eo_l_oli_{}',
|
||||||
}
|
}
|
||||||
mask_md5s = {
|
mask_md5s: ClassVar[dict[str, dict[str, str]]] = {
|
||||||
'tm': {
|
'tm': {
|
||||||
'cdl': '3d676770ffb56c7e222a7192a652a846',
|
'cdl': '3d676770ffb56c7e222a7192a652a846',
|
||||||
'nlcd': '261149d7614fcfdcb3be368eefa825c7',
|
'nlcd': '261149d7614fcfdcb3be368eefa825c7',
|
||||||
|
@ -85,7 +86,7 @@ class SSL4EOLBenchmark(NonGeoDataset):
|
||||||
},
|
},
|
||||||
}
|
}
|
||||||
|
|
||||||
year_dict = {
|
year_dict: ClassVar[dict[str, int]] = {
|
||||||
'tm_toa': 2011,
|
'tm_toa': 2011,
|
||||||
'etm_toa': 2019,
|
'etm_toa': 2019,
|
||||||
'etm_sr': 2019,
|
'etm_sr': 2019,
|
||||||
|
@ -93,7 +94,7 @@ class SSL4EOLBenchmark(NonGeoDataset):
|
||||||
'oli_sr': 2019,
|
'oli_sr': 2019,
|
||||||
}
|
}
|
||||||
|
|
||||||
rgb_indices = {
|
rgb_indices: ClassVar[dict[str, list[int]]] = {
|
||||||
'tm_toa': [2, 1, 0],
|
'tm_toa': [2, 1, 0],
|
||||||
'etm_toa': [2, 1, 0],
|
'etm_toa': [2, 1, 0],
|
||||||
'etm_sr': [2, 1, 0],
|
'etm_sr': [2, 1, 0],
|
||||||
|
@ -101,9 +102,12 @@ class SSL4EOLBenchmark(NonGeoDataset):
|
||||||
'oli_sr': [3, 2, 1],
|
'oli_sr': [3, 2, 1],
|
||||||
}
|
}
|
||||||
|
|
||||||
split_percentages = [0.7, 0.15, 0.15]
|
split_percentages = (0.7, 0.15, 0.15)
|
||||||
|
|
||||||
cmaps = {'nlcd': NLCD.cmap, 'cdl': CDL.cmap}
|
cmaps: ClassVar[dict[str, dict[int, tuple[int, int, int, int]]]] = {
|
||||||
|
'nlcd': NLCD.cmap,
|
||||||
|
'cdl': CDL.cmap,
|
||||||
|
}
|
||||||
|
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
|
|
|
@ -45,17 +45,17 @@ class SustainBenchCropYield(NonGeoDataset):
|
||||||
* https://doi.org/10.1609/aaai.v31i1.11172
|
* https://doi.org/10.1609/aaai.v31i1.11172
|
||||||
|
|
||||||
.. versionadded:: 0.5
|
.. versionadded:: 0.5
|
||||||
""" # noqa: E501
|
"""
|
||||||
|
|
||||||
valid_countries = ['usa', 'brazil', 'argentina']
|
valid_countries = ('usa', 'brazil', 'argentina')
|
||||||
|
|
||||||
md5 = '362bad07b51a1264172b8376b39d1fc9'
|
md5 = '362bad07b51a1264172b8376b39d1fc9'
|
||||||
|
|
||||||
url = 'https://drive.google.com/file/d/1lhbmICpmNuOBlaErywgiD6i9nHuhuv0A/view?usp=drive_link' # noqa: E501
|
url = 'https://drive.google.com/file/d/1lhbmICpmNuOBlaErywgiD6i9nHuhuv0A/view?usp=drive_link'
|
||||||
|
|
||||||
dir = 'soybeans'
|
dir = 'soybeans'
|
||||||
|
|
||||||
valid_splits = ['train', 'dev', 'test']
|
valid_splits = ('train', 'dev', 'test')
|
||||||
|
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
|
|
|
@ -5,7 +5,7 @@
|
||||||
|
|
||||||
import os
|
import os
|
||||||
from collections.abc import Callable
|
from collections.abc import Callable
|
||||||
from typing import cast
|
from typing import ClassVar, cast
|
||||||
|
|
||||||
import matplotlib.pyplot as plt
|
import matplotlib.pyplot as plt
|
||||||
import numpy as np
|
import numpy as np
|
||||||
|
@ -66,19 +66,19 @@ class UCMerced(NonGeoClassificationDataset):
|
||||||
* https://dl.acm.org/doi/10.1145/1869790.1869829
|
* https://dl.acm.org/doi/10.1145/1869790.1869829
|
||||||
"""
|
"""
|
||||||
|
|
||||||
url = 'https://hf.co/datasets/torchgeo/ucmerced/resolve/d0af6e2eeea2322af86078068bd83337148a2149/UCMerced_LandUse.zip' # noqa: E501
|
url = 'https://hf.co/datasets/torchgeo/ucmerced/resolve/d0af6e2eeea2322af86078068bd83337148a2149/UCMerced_LandUse.zip'
|
||||||
filename = 'UCMerced_LandUse.zip'
|
filename = 'UCMerced_LandUse.zip'
|
||||||
md5 = '5b7ec56793786b6dc8a908e8854ac0e4'
|
md5 = '5b7ec56793786b6dc8a908e8854ac0e4'
|
||||||
|
|
||||||
base_dir = os.path.join('UCMerced_LandUse', 'Images')
|
base_dir = os.path.join('UCMerced_LandUse', 'Images')
|
||||||
|
|
||||||
splits = ['train', 'val', 'test']
|
splits = ('train', 'val', 'test')
|
||||||
split_urls = {
|
split_urls: ClassVar[dict[str, str]] = {
|
||||||
'train': 'https://storage.googleapis.com/remote_sensing_representations/uc_merced-train.txt', # noqa: E501
|
'train': 'https://storage.googleapis.com/remote_sensing_representations/uc_merced-train.txt',
|
||||||
'val': 'https://storage.googleapis.com/remote_sensing_representations/uc_merced-val.txt', # noqa: E501
|
'val': 'https://storage.googleapis.com/remote_sensing_representations/uc_merced-val.txt',
|
||||||
'test': 'https://storage.googleapis.com/remote_sensing_representations/uc_merced-test.txt', # noqa: E501
|
'test': 'https://storage.googleapis.com/remote_sensing_representations/uc_merced-test.txt',
|
||||||
}
|
}
|
||||||
split_md5s = {
|
split_md5s: ClassVar[dict[str, str]] = {
|
||||||
'train': 'f2fb12eb2210cfb53f93f063a35ff374',
|
'train': 'f2fb12eb2210cfb53f93f063a35ff374',
|
||||||
'val': '11ecabfc52782e5ea6a9c7c0d263aca0',
|
'val': '11ecabfc52782e5ea6a9c7c0d263aca0',
|
||||||
'test': '046aff88472d8fc07c4678d03749e28d',
|
'test': '046aff88472d8fc07c4678d03749e28d',
|
||||||
|
|
|
@ -6,6 +6,7 @@
|
||||||
import glob
|
import glob
|
||||||
import os
|
import os
|
||||||
from collections.abc import Callable, Sequence
|
from collections.abc import Callable, Sequence
|
||||||
|
from typing import ClassVar
|
||||||
|
|
||||||
import matplotlib.pyplot as plt
|
import matplotlib.pyplot as plt
|
||||||
import numpy as np
|
import numpy as np
|
||||||
|
@ -49,12 +50,12 @@ class USAVars(NonGeoDataset):
|
||||||
.. versionadded:: 0.3
|
.. versionadded:: 0.3
|
||||||
"""
|
"""
|
||||||
|
|
||||||
data_url = 'https://hf.co/datasets/torchgeo/usavars/resolve/01377abfaf50c0cc8548aaafb79533666bbf288f/{}' # noqa: E501
|
data_url = 'https://hf.co/datasets/torchgeo/usavars/resolve/01377abfaf50c0cc8548aaafb79533666bbf288f/{}'
|
||||||
dirname = 'uar'
|
dirname = 'uar'
|
||||||
|
|
||||||
md5 = '677e89fd20e5dd0fe4d29b61827c2456'
|
md5 = '677e89fd20e5dd0fe4d29b61827c2456'
|
||||||
|
|
||||||
label_urls = {
|
label_urls: ClassVar[dict[str, str]] = {
|
||||||
'housing': data_url.format('housing.csv'),
|
'housing': data_url.format('housing.csv'),
|
||||||
'income': data_url.format('income.csv'),
|
'income': data_url.format('income.csv'),
|
||||||
'roads': data_url.format('roads.csv'),
|
'roads': data_url.format('roads.csv'),
|
||||||
|
@ -64,7 +65,7 @@ class USAVars(NonGeoDataset):
|
||||||
'treecover': data_url.format('treecover.csv'),
|
'treecover': data_url.format('treecover.csv'),
|
||||||
}
|
}
|
||||||
|
|
||||||
split_metadata = {
|
split_metadata: ClassVar[dict[str, dict[str, str]]] = {
|
||||||
'train': {
|
'train': {
|
||||||
'url': data_url.format('train_split.txt'),
|
'url': data_url.format('train_split.txt'),
|
||||||
'filename': 'train_split.txt',
|
'filename': 'train_split.txt',
|
||||||
|
@ -82,7 +83,7 @@ class USAVars(NonGeoDataset):
|
||||||
},
|
},
|
||||||
}
|
}
|
||||||
|
|
||||||
ALL_LABELS = ['treecover', 'elevation', 'population']
|
ALL_LABELS = ('treecover', 'elevation', 'population')
|
||||||
|
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
|
|
|
@ -86,11 +86,11 @@ class BoundingBox:
|
||||||
|
|
||||||
# https://github.com/PyCQA/pydocstyle/issues/525
|
# https://github.com/PyCQA/pydocstyle/issues/525
|
||||||
@overload
|
@overload
|
||||||
def __getitem__(self, key: int) -> float: # noqa: D105
|
def __getitem__(self, key: int) -> float:
|
||||||
pass
|
pass
|
||||||
|
|
||||||
@overload
|
@overload
|
||||||
def __getitem__(self, key: slice) -> list[float]: # noqa: D105
|
def __getitem__(self, key: slice) -> list[float]:
|
||||||
pass
|
pass
|
||||||
|
|
||||||
def __getitem__(self, key: int | slice) -> float | list[float]:
|
def __getitem__(self, key: int | slice) -> float | list[float]:
|
||||||
|
@ -289,7 +289,7 @@ class Executable:
|
||||||
The completed process.
|
The completed process.
|
||||||
"""
|
"""
|
||||||
kwargs['check'] = True
|
kwargs['check'] = True
|
||||||
return subprocess.run((self.name,) + args, **kwargs)
|
return subprocess.run((self.name, *args), **kwargs)
|
||||||
|
|
||||||
|
|
||||||
def disambiguate_timestamp(date_str: str, format: str) -> tuple[float, float]:
|
def disambiguate_timestamp(date_str: str, format: str) -> tuple[float, float]:
|
||||||
|
@ -547,7 +547,7 @@ def draw_semantic_segmentation_masks(
|
||||||
|
|
||||||
|
|
||||||
def rgb_to_mask(
|
def rgb_to_mask(
|
||||||
rgb: np.typing.NDArray[np.uint8], colors: list[tuple[int, int, int]]
|
rgb: np.typing.NDArray[np.uint8], colors: Sequence[tuple[int, int, int]]
|
||||||
) -> np.typing.NDArray[np.uint8]:
|
) -> np.typing.NDArray[np.uint8]:
|
||||||
"""Converts an RGB colormap mask to a integer mask.
|
"""Converts an RGB colormap mask to a integer mask.
|
||||||
|
|
||||||
|
|
|
@ -5,6 +5,7 @@
|
||||||
|
|
||||||
import os
|
import os
|
||||||
from collections.abc import Callable
|
from collections.abc import Callable
|
||||||
|
from typing import ClassVar
|
||||||
|
|
||||||
import matplotlib.pyplot as plt
|
import matplotlib.pyplot as plt
|
||||||
import numpy as np
|
import numpy as np
|
||||||
|
@ -55,15 +56,15 @@ class Vaihingen2D(NonGeoDataset):
|
||||||
* https://doi.org/10.5194/isprsannals-I-3-293-2012
|
* https://doi.org/10.5194/isprsannals-I-3-293-2012
|
||||||
|
|
||||||
.. versionadded:: 0.2
|
.. versionadded:: 0.2
|
||||||
""" # noqa: E501
|
"""
|
||||||
|
|
||||||
filenames = [
|
filenames = (
|
||||||
'ISPRS_semantic_labeling_Vaihingen.zip',
|
'ISPRS_semantic_labeling_Vaihingen.zip',
|
||||||
'ISPRS_semantic_labeling_Vaihingen_ground_truth_COMPLETE.zip',
|
'ISPRS_semantic_labeling_Vaihingen_ground_truth_COMPLETE.zip',
|
||||||
]
|
)
|
||||||
md5s = ['462b8dca7b6fa9eaf729840f0cdfc7f3', '4802dd6326e2727a352fb735be450277']
|
md5s = ('462b8dca7b6fa9eaf729840f0cdfc7f3', '4802dd6326e2727a352fb735be450277')
|
||||||
image_root = 'top'
|
image_root = 'top'
|
||||||
splits = {
|
splits: ClassVar[dict[str, list[str]]] = {
|
||||||
'train': [
|
'train': [
|
||||||
'top_mosaic_09cm_area1.tif',
|
'top_mosaic_09cm_area1.tif',
|
||||||
'top_mosaic_09cm_area11.tif',
|
'top_mosaic_09cm_area11.tif',
|
||||||
|
@ -102,22 +103,22 @@ class Vaihingen2D(NonGeoDataset):
|
||||||
'top_mosaic_09cm_area29.tif',
|
'top_mosaic_09cm_area29.tif',
|
||||||
],
|
],
|
||||||
}
|
}
|
||||||
classes = [
|
classes = (
|
||||||
'Clutter/background',
|
'Clutter/background',
|
||||||
'Impervious surfaces',
|
'Impervious surfaces',
|
||||||
'Building',
|
'Building',
|
||||||
'Low Vegetation',
|
'Low Vegetation',
|
||||||
'Tree',
|
'Tree',
|
||||||
'Car',
|
'Car',
|
||||||
]
|
)
|
||||||
colormap = [
|
colormap = (
|
||||||
(255, 0, 0),
|
(255, 0, 0),
|
||||||
(255, 255, 255),
|
(255, 255, 255),
|
||||||
(0, 0, 255),
|
(0, 0, 255),
|
||||||
(0, 255, 255),
|
(0, 255, 255),
|
||||||
(0, 255, 0),
|
(0, 255, 0),
|
||||||
(255, 255, 0),
|
(255, 255, 0),
|
||||||
]
|
)
|
||||||
|
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
|
@ -258,7 +259,7 @@ class Vaihingen2D(NonGeoDataset):
|
||||||
"""
|
"""
|
||||||
ncols = 1
|
ncols = 1
|
||||||
image1 = draw_semantic_segmentation_masks(
|
image1 = draw_semantic_segmentation_masks(
|
||||||
sample['image'][:3], sample['mask'], alpha=alpha, colors=self.colormap
|
sample['image'][:3], sample['mask'], alpha=alpha, colors=list(self.colormap)
|
||||||
)
|
)
|
||||||
if 'prediction' in sample:
|
if 'prediction' in sample:
|
||||||
ncols += 1
|
ncols += 1
|
||||||
|
@ -266,7 +267,7 @@ class Vaihingen2D(NonGeoDataset):
|
||||||
sample['image'][:3],
|
sample['image'][:3],
|
||||||
sample['prediction'],
|
sample['prediction'],
|
||||||
alpha=alpha,
|
alpha=alpha,
|
||||||
colors=self.colormap,
|
colors=list(self.colormap),
|
||||||
)
|
)
|
||||||
|
|
||||||
fig, axs = plt.subplots(ncols=ncols, figsize=(ncols * 10, 10))
|
fig, axs = plt.subplots(ncols=ncols, figsize=(ncols * 10, 10))
|
||||||
|
|
|
@ -5,7 +5,7 @@
|
||||||
|
|
||||||
import os
|
import os
|
||||||
from collections.abc import Callable
|
from collections.abc import Callable
|
||||||
from typing import Any
|
from typing import Any, ClassVar
|
||||||
|
|
||||||
import matplotlib.pyplot as plt
|
import matplotlib.pyplot as plt
|
||||||
import numpy as np
|
import numpy as np
|
||||||
|
@ -158,18 +158,18 @@ class VHR10(NonGeoDataset):
|
||||||
``annotations.json`` file for the "positive" image set
|
``annotations.json`` file for the "positive" image set
|
||||||
"""
|
"""
|
||||||
|
|
||||||
image_meta = {
|
image_meta: ClassVar[dict[str, str]] = {
|
||||||
'url': 'https://hf.co/datasets/torchgeo/vhr10/resolve/7e7968ad265dadc4494e0ca4a079e0b63dc6f3f8/NWPU%20VHR-10%20dataset.zip',
|
'url': 'https://hf.co/datasets/torchgeo/vhr10/resolve/7e7968ad265dadc4494e0ca4a079e0b63dc6f3f8/NWPU%20VHR-10%20dataset.zip',
|
||||||
'filename': 'NWPU VHR-10 dataset.zip',
|
'filename': 'NWPU VHR-10 dataset.zip',
|
||||||
'md5': '6add6751469c12dd8c8d6223064c6c4d',
|
'md5': '6add6751469c12dd8c8d6223064c6c4d',
|
||||||
}
|
}
|
||||||
target_meta = {
|
target_meta: ClassVar[dict[str, str]] = {
|
||||||
'url': 'https://hf.co/datasets/torchgeo/vhr10/resolve/7e7968ad265dadc4494e0ca4a079e0b63dc6f3f8/annotations.json',
|
'url': 'https://hf.co/datasets/torchgeo/vhr10/resolve/7e7968ad265dadc4494e0ca4a079e0b63dc6f3f8/annotations.json',
|
||||||
'filename': 'annotations.json',
|
'filename': 'annotations.json',
|
||||||
'md5': '7c76ec50c17a61bb0514050d20f22c08',
|
'md5': '7c76ec50c17a61bb0514050d20f22c08',
|
||||||
}
|
}
|
||||||
|
|
||||||
categories = [
|
categories = (
|
||||||
'background',
|
'background',
|
||||||
'airplane',
|
'airplane',
|
||||||
'ships',
|
'ships',
|
||||||
|
@ -181,7 +181,7 @@ class VHR10(NonGeoDataset):
|
||||||
'harbor',
|
'harbor',
|
||||||
'bridge',
|
'bridge',
|
||||||
'vehicle',
|
'vehicle',
|
||||||
]
|
)
|
||||||
|
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
|
|
|
@ -6,7 +6,7 @@
|
||||||
import glob
|
import glob
|
||||||
import json
|
import json
|
||||||
import os
|
import os
|
||||||
from collections.abc import Callable
|
from collections.abc import Callable, Iterable
|
||||||
from typing import Any
|
from typing import Any
|
||||||
|
|
||||||
import pandas as pd
|
import pandas as pd
|
||||||
|
@ -53,7 +53,7 @@ class WesternUSALiveFuelMoisture(NonGeoDataset):
|
||||||
|
|
||||||
label_name = 'percent(t)'
|
label_name = 'percent(t)'
|
||||||
|
|
||||||
all_variable_names = [
|
all_variable_names = (
|
||||||
# "date",
|
# "date",
|
||||||
'slope(t)',
|
'slope(t)',
|
||||||
'elevation(t)',
|
'elevation(t)',
|
||||||
|
@ -193,12 +193,12 @@ class WesternUSALiveFuelMoisture(NonGeoDataset):
|
||||||
'vh_vv(t-3)',
|
'vh_vv(t-3)',
|
||||||
'lat',
|
'lat',
|
||||||
'lon',
|
'lon',
|
||||||
]
|
)
|
||||||
|
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
root: Path = 'data',
|
root: Path = 'data',
|
||||||
input_features: list[str] = all_variable_names,
|
input_features: Iterable[str] = all_variable_names,
|
||||||
transforms: Callable[[dict[str, Any]], dict[str, Any]] | None = None,
|
transforms: Callable[[dict[str, Any]], dict[str, Any]] | None = None,
|
||||||
download: bool = False,
|
download: bool = False,
|
||||||
) -> None:
|
) -> None:
|
||||||
|
@ -273,7 +273,7 @@ class WesternUSALiveFuelMoisture(NonGeoDataset):
|
||||||
data_rows.append(data_dict)
|
data_rows.append(data_dict)
|
||||||
|
|
||||||
df = pd.DataFrame(data_rows)
|
df = pd.DataFrame(data_rows)
|
||||||
df = df[self.input_features + [self.label_name]]
|
df = df[[*self.input_features, self.label_name]]
|
||||||
return df
|
return df
|
||||||
|
|
||||||
def _verify(self) -> None:
|
def _verify(self) -> None:
|
||||||
|
|
|
@ -6,6 +6,7 @@
|
||||||
import glob
|
import glob
|
||||||
import os
|
import os
|
||||||
from collections.abc import Callable
|
from collections.abc import Callable
|
||||||
|
from typing import ClassVar
|
||||||
|
|
||||||
import matplotlib.pyplot as plt
|
import matplotlib.pyplot as plt
|
||||||
import numpy as np
|
import numpy as np
|
||||||
|
@ -54,7 +55,7 @@ class XView2(NonGeoDataset):
|
||||||
.. versionadded:: 0.2
|
.. versionadded:: 0.2
|
||||||
"""
|
"""
|
||||||
|
|
||||||
metadata = {
|
metadata: ClassVar[dict[str, dict[str, str]]] = {
|
||||||
'train': {
|
'train': {
|
||||||
'filename': 'train_images_labels_targets.tar.gz',
|
'filename': 'train_images_labels_targets.tar.gz',
|
||||||
'md5': 'a20ebbfb7eb3452785b63ad02ffd1e16',
|
'md5': 'a20ebbfb7eb3452785b63ad02ffd1e16',
|
||||||
|
@ -66,8 +67,8 @@ class XView2(NonGeoDataset):
|
||||||
'directory': 'test',
|
'directory': 'test',
|
||||||
},
|
},
|
||||||
}
|
}
|
||||||
classes = ['background', 'no-damage', 'minor-damage', 'major-damage', 'destroyed']
|
classes = ('background', 'no-damage', 'minor-damage', 'major-damage', 'destroyed')
|
||||||
colormap = ['green', 'blue', 'orange', 'red']
|
colormap = ('green', 'blue', 'orange', 'red')
|
||||||
|
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
|
@ -242,10 +243,16 @@ class XView2(NonGeoDataset):
|
||||||
"""
|
"""
|
||||||
ncols = 2
|
ncols = 2
|
||||||
image1 = draw_semantic_segmentation_masks(
|
image1 = draw_semantic_segmentation_masks(
|
||||||
sample['image'][0], sample['mask'][0], alpha=alpha, colors=self.colormap
|
sample['image'][0],
|
||||||
|
sample['mask'][0],
|
||||||
|
alpha=alpha,
|
||||||
|
colors=list(self.colormap),
|
||||||
)
|
)
|
||||||
image2 = draw_semantic_segmentation_masks(
|
image2 = draw_semantic_segmentation_masks(
|
||||||
sample['image'][1], sample['mask'][1], alpha=alpha, colors=self.colormap
|
sample['image'][1],
|
||||||
|
sample['mask'][1],
|
||||||
|
alpha=alpha,
|
||||||
|
colors=list(self.colormap),
|
||||||
)
|
)
|
||||||
if 'prediction' in sample: # NOTE: this assumes predictions are made for post
|
if 'prediction' in sample: # NOTE: this assumes predictions are made for post
|
||||||
ncols += 1
|
ncols += 1
|
||||||
|
@ -253,7 +260,7 @@ class XView2(NonGeoDataset):
|
||||||
sample['image'][1],
|
sample['image'][1],
|
||||||
sample['prediction'],
|
sample['prediction'],
|
||||||
alpha=alpha,
|
alpha=alpha,
|
||||||
colors=self.colormap,
|
colors=list(self.colormap),
|
||||||
)
|
)
|
||||||
|
|
||||||
fig, axs = plt.subplots(ncols=ncols, figsize=(ncols * 10, 10))
|
fig, axs = plt.subplots(ncols=ncols, figsize=(ncols * 10, 10))
|
||||||
|
|
|
@ -52,15 +52,15 @@ class ZueriCrop(NonGeoDataset):
|
||||||
* `h5py <https://pypi.org/project/h5py/>`_ to load the dataset
|
* `h5py <https://pypi.org/project/h5py/>`_ to load the dataset
|
||||||
"""
|
"""
|
||||||
|
|
||||||
urls = [
|
urls = (
|
||||||
'https://polybox.ethz.ch/index.php/s/uXfdr2AcXE3QNB6/download',
|
'https://polybox.ethz.ch/index.php/s/uXfdr2AcXE3QNB6/download',
|
||||||
'https://raw.githubusercontent.com/0zgur0/multi-stage-convSTAR-network/fa92b5b3cb77f5171c5c3be740cd6e6395cc29b6/labels.csv', # noqa: E501
|
'https://raw.githubusercontent.com/0zgur0/multi-stage-convSTAR-network/fa92b5b3cb77f5171c5c3be740cd6e6395cc29b6/labels.csv',
|
||||||
]
|
)
|
||||||
md5s = ['1635231df67f3d25f4f1e62c98e221a4', '5118398c7a5bbc246f5f6bb35d8d529b']
|
md5s = ('1635231df67f3d25f4f1e62c98e221a4', '5118398c7a5bbc246f5f6bb35d8d529b')
|
||||||
filenames = ['ZueriCrop.hdf5', 'labels.csv']
|
filenames = ('ZueriCrop.hdf5', 'labels.csv')
|
||||||
|
|
||||||
band_names = ('NIR', 'B03', 'B02', 'B04', 'B05', 'B06', 'B07', 'B11', 'B12')
|
band_names = ('NIR', 'B03', 'B02', 'B04', 'B05', 'B06', 'B07', 'B11', 'B12')
|
||||||
rgb_bands = ['B04', 'B03', 'B02']
|
rgb_bands = ('B04', 'B03', 'B02')
|
||||||
|
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
|
|
|
@ -8,7 +8,7 @@ import os
|
||||||
from lightning.pytorch.cli import ArgsType, LightningCLI
|
from lightning.pytorch.cli import ArgsType, LightningCLI
|
||||||
|
|
||||||
# Allows classes to be referenced using only the class name
|
# Allows classes to be referenced using only the class name
|
||||||
import torchgeo.datamodules # noqa: F401
|
import torchgeo.datamodules
|
||||||
import torchgeo.trainers # noqa: F401
|
import torchgeo.trainers # noqa: F401
|
||||||
from torchgeo.datamodules import BaseDataModule
|
from torchgeo.datamodules import BaseDataModule
|
||||||
from torchgeo.trainers import BaseTask
|
from torchgeo.trainers import BaseTask
|
||||||
|
|
|
@ -8,7 +8,7 @@ See the following references for design details:
|
||||||
* https://pytorch.org/blog/easily-list-and-initialize-models-with-new-apis-in-torchvision/
|
* https://pytorch.org/blog/easily-list-and-initialize-models-with-new-apis-in-torchvision/
|
||||||
* https://pytorch.org/vision/stable/models.html
|
* https://pytorch.org/vision/stable/models.html
|
||||||
* https://github.com/pytorch/vision/blob/main/torchvision/models/_api.py
|
* https://github.com/pytorch/vision/blob/main/torchvision/models/_api.py
|
||||||
""" # noqa: E501
|
"""
|
||||||
|
|
||||||
from collections.abc import Callable
|
from collections.abc import Callable
|
||||||
from typing import Any
|
from typing import Any
|
||||||
|
|
|
@ -384,7 +384,7 @@ class DOFABase16_Weights(WeightsEnum): # type: ignore[misc]
|
||||||
"""
|
"""
|
||||||
|
|
||||||
DOFA_MAE = Weights(
|
DOFA_MAE = Weights(
|
||||||
url='https://hf.co/torchgeo/dofa/resolve/ade8745c5ec6eddfe15d8c03421e8cb8f21e66ff/dofa_base_patch16_224-7cc0f413.pth', # noqa: E501
|
url='https://hf.co/torchgeo/dofa/resolve/ade8745c5ec6eddfe15d8c03421e8cb8f21e66ff/dofa_base_patch16_224-7cc0f413.pth',
|
||||||
transforms=_dofa_transforms,
|
transforms=_dofa_transforms,
|
||||||
meta={
|
meta={
|
||||||
'dataset': 'SatlasPretrain, Five-Billion-Pixels, HySpecNet-11k',
|
'dataset': 'SatlasPretrain, Five-Billion-Pixels, HySpecNet-11k',
|
||||||
|
@ -403,7 +403,7 @@ class DOFALarge16_Weights(WeightsEnum): # type: ignore[misc]
|
||||||
"""
|
"""
|
||||||
|
|
||||||
DOFA_MAE = Weights(
|
DOFA_MAE = Weights(
|
||||||
url='https://hf.co/torchgeo/dofa/resolve/ade8745c5ec6eddfe15d8c03421e8cb8f21e66ff/dofa_large_patch16_224-fbd47fa9.pth', # noqa: E501
|
url='https://hf.co/torchgeo/dofa/resolve/ade8745c5ec6eddfe15d8c03421e8cb8f21e66ff/dofa_large_patch16_224-fbd47fa9.pth',
|
||||||
transforms=_dofa_transforms,
|
transforms=_dofa_transforms,
|
||||||
meta={
|
meta={
|
||||||
'dataset': 'SatlasPretrain, Five-Billion-Pixels, HySpecNet-11k',
|
'dataset': 'SatlasPretrain, Five-Billion-Pixels, HySpecNet-11k',
|
||||||
|
|
|
@ -140,7 +140,7 @@ class RCF(Module):
|
||||||
a numpy array of size (N, C, H, W) containing the normalized patches
|
a numpy array of size (N, C, H, W) containing the normalized patches
|
||||||
|
|
||||||
.. versionadded:: 0.5
|
.. versionadded:: 0.5
|
||||||
""" # noqa: E501
|
"""
|
||||||
n_patches = patches.shape[0]
|
n_patches = patches.shape[0]
|
||||||
orig_shape = patches.shape
|
orig_shape = patches.shape
|
||||||
patches = patches.reshape(patches.shape[0], -1)
|
patches = patches.reshape(patches.shape[0], -1)
|
||||||
|
|
|
@ -11,8 +11,8 @@ import torch
|
||||||
from timm.models import ResNet
|
from timm.models import ResNet
|
||||||
from torchvision.models._api import Weights, WeightsEnum
|
from torchvision.models._api import Weights, WeightsEnum
|
||||||
|
|
||||||
# https://github.com/zhu-xlab/SSL4EO-S12/blob/d2868adfada65e40910bfcedfc49bc3b20df2248/src/benchmark/transfer_classification/linear_BE_moco.py#L167 # noqa: E501
|
# https://github.com/zhu-xlab/SSL4EO-S12/blob/d2868adfada65e40910bfcedfc49bc3b20df2248/src/benchmark/transfer_classification/linear_BE_moco.py#L167
|
||||||
# https://github.com/zhu-xlab/SSL4EO-S12/blob/d2868adfada65e40910bfcedfc49bc3b20df2248/src/benchmark/transfer_classification/datasets/EuroSat/eurosat_dataset.py#L97 # noqa: E501
|
# https://github.com/zhu-xlab/SSL4EO-S12/blob/d2868adfada65e40910bfcedfc49bc3b20df2248/src/benchmark/transfer_classification/datasets/EuroSat/eurosat_dataset.py#L97
|
||||||
# Normalization either by 10K or channel-wise with band statistics
|
# Normalization either by 10K or channel-wise with band statistics
|
||||||
_zhu_xlab_transforms = K.AugmentationSequential(
|
_zhu_xlab_transforms = K.AugmentationSequential(
|
||||||
K.Resize(256),
|
K.Resize(256),
|
||||||
|
@ -22,7 +22,7 @@ _zhu_xlab_transforms = K.AugmentationSequential(
|
||||||
)
|
)
|
||||||
|
|
||||||
# Normalization only available for RGB dataset, defined here:
|
# Normalization only available for RGB dataset, defined here:
|
||||||
# https://github.com/ServiceNow/seasonal-contrast/blob/8285173ec205b64bc3e53b880344dd6c3f79fa7a/datasets/seco_dataset.py # noqa: E501
|
# https://github.com/ServiceNow/seasonal-contrast/blob/8285173ec205b64bc3e53b880344dd6c3f79fa7a/datasets/seco_dataset.py
|
||||||
_min = torch.tensor([3, 2, 0])
|
_min = torch.tensor([3, 2, 0])
|
||||||
_max = torch.tensor([88, 103, 129])
|
_max = torch.tensor([88, 103, 129])
|
||||||
_mean = torch.tensor([0.485, 0.456, 0.406])
|
_mean = torch.tensor([0.485, 0.456, 0.406])
|
||||||
|
@ -37,7 +37,7 @@ _seco_transforms = K.AugmentationSequential(
|
||||||
)
|
)
|
||||||
|
|
||||||
# Normalization only available for RGB dataset, defined here:
|
# Normalization only available for RGB dataset, defined here:
|
||||||
# https://github.com/sustainlab-group/geography-aware-ssl/blob/main/moco_fmow/main_moco_geo%2Btp.py#L287 # noqa: E501
|
# https://github.com/sustainlab-group/geography-aware-ssl/blob/main/moco_fmow/main_moco_geo%2Btp.py#L287
|
||||||
_mean = torch.tensor([0.485, 0.456, 0.406])
|
_mean = torch.tensor([0.485, 0.456, 0.406])
|
||||||
_std = torch.tensor([0.229, 0.224, 0.225])
|
_std = torch.tensor([0.229, 0.224, 0.225])
|
||||||
_gassl_transforms = K.AugmentationSequential(
|
_gassl_transforms = K.AugmentationSequential(
|
||||||
|
@ -47,7 +47,7 @@ _gassl_transforms = K.AugmentationSequential(
|
||||||
data_keys=None,
|
data_keys=None,
|
||||||
)
|
)
|
||||||
|
|
||||||
# https://github.com/microsoft/torchgeo/blob/8b53304d42c269f9001cb4e861a126dc4b462606/torchgeo/datamodules/ssl4eo_benchmark.py#L43 # noqa: E501
|
# https://github.com/microsoft/torchgeo/blob/8b53304d42c269f9001cb4e861a126dc4b462606/torchgeo/datamodules/ssl4eo_benchmark.py#L43
|
||||||
_ssl4eo_l_transforms = K.AugmentationSequential(
|
_ssl4eo_l_transforms = K.AugmentationSequential(
|
||||||
K.Normalize(mean=torch.tensor(0), std=torch.tensor(255)),
|
K.Normalize(mean=torch.tensor(0), std=torch.tensor(255)),
|
||||||
K.CenterCrop((224, 224)),
|
K.CenterCrop((224, 224)),
|
||||||
|
@ -70,7 +70,7 @@ class ResNet18_Weights(WeightsEnum): # type: ignore[misc]
|
||||||
"""
|
"""
|
||||||
|
|
||||||
LANDSAT_TM_TOA_MOCO = Weights(
|
LANDSAT_TM_TOA_MOCO = Weights(
|
||||||
url='https://hf.co/torchgeo/ssl4eo_landsat/resolve/1c88bb51b6e17a21dde5230738fa38b74bd74f76/resnet18_landsat_tm_toa_moco-1c691b4f.pth', # noqa: E501
|
url='https://hf.co/torchgeo/ssl4eo_landsat/resolve/1c88bb51b6e17a21dde5230738fa38b74bd74f76/resnet18_landsat_tm_toa_moco-1c691b4f.pth',
|
||||||
transforms=_ssl4eo_l_transforms,
|
transforms=_ssl4eo_l_transforms,
|
||||||
meta={
|
meta={
|
||||||
'dataset': 'SSL4EO-L',
|
'dataset': 'SSL4EO-L',
|
||||||
|
@ -83,7 +83,7 @@ class ResNet18_Weights(WeightsEnum): # type: ignore[misc]
|
||||||
)
|
)
|
||||||
|
|
||||||
LANDSAT_TM_TOA_SIMCLR = Weights(
|
LANDSAT_TM_TOA_SIMCLR = Weights(
|
||||||
url='https://hf.co/torchgeo/ssl4eo_landsat/resolve/1c88bb51b6e17a21dde5230738fa38b74bd74f76/resnet18_landsat_tm_toa_simclr-d2d38ace.pth', # noqa: E501
|
url='https://hf.co/torchgeo/ssl4eo_landsat/resolve/1c88bb51b6e17a21dde5230738fa38b74bd74f76/resnet18_landsat_tm_toa_simclr-d2d38ace.pth',
|
||||||
transforms=_ssl4eo_l_transforms,
|
transforms=_ssl4eo_l_transforms,
|
||||||
meta={
|
meta={
|
||||||
'dataset': 'SSL4EO-L',
|
'dataset': 'SSL4EO-L',
|
||||||
|
@ -96,7 +96,7 @@ class ResNet18_Weights(WeightsEnum): # type: ignore[misc]
|
||||||
)
|
)
|
||||||
|
|
||||||
LANDSAT_ETM_TOA_MOCO = Weights(
|
LANDSAT_ETM_TOA_MOCO = Weights(
|
||||||
url='https://hf.co/torchgeo/ssl4eo_landsat/resolve/1c88bb51b6e17a21dde5230738fa38b74bd74f76/resnet18_landsat_etm_toa_moco-bb88689c.pth', # noqa: E501
|
url='https://hf.co/torchgeo/ssl4eo_landsat/resolve/1c88bb51b6e17a21dde5230738fa38b74bd74f76/resnet18_landsat_etm_toa_moco-bb88689c.pth',
|
||||||
transforms=_ssl4eo_l_transforms,
|
transforms=_ssl4eo_l_transforms,
|
||||||
meta={
|
meta={
|
||||||
'dataset': 'SSL4EO-L',
|
'dataset': 'SSL4EO-L',
|
||||||
|
@ -109,7 +109,7 @@ class ResNet18_Weights(WeightsEnum): # type: ignore[misc]
|
||||||
)
|
)
|
||||||
|
|
||||||
LANDSAT_ETM_TOA_SIMCLR = Weights(
|
LANDSAT_ETM_TOA_SIMCLR = Weights(
|
||||||
url='https://hf.co/torchgeo/ssl4eo_landsat/resolve/1c88bb51b6e17a21dde5230738fa38b74bd74f76/resnet18_landsat_etm_toa_simclr-4d813f79.pth', # noqa: E501
|
url='https://hf.co/torchgeo/ssl4eo_landsat/resolve/1c88bb51b6e17a21dde5230738fa38b74bd74f76/resnet18_landsat_etm_toa_simclr-4d813f79.pth',
|
||||||
transforms=_ssl4eo_l_transforms,
|
transforms=_ssl4eo_l_transforms,
|
||||||
meta={
|
meta={
|
||||||
'dataset': 'SSL4EO-L',
|
'dataset': 'SSL4EO-L',
|
||||||
|
@ -122,7 +122,7 @@ class ResNet18_Weights(WeightsEnum): # type: ignore[misc]
|
||||||
)
|
)
|
||||||
|
|
||||||
LANDSAT_ETM_SR_MOCO = Weights(
|
LANDSAT_ETM_SR_MOCO = Weights(
|
||||||
url='https://hf.co/torchgeo/ssl4eo_landsat/resolve/1c88bb51b6e17a21dde5230738fa38b74bd74f76/resnet18_landsat_etm_sr_moco-4f078acd.pth', # noqa: E501
|
url='https://hf.co/torchgeo/ssl4eo_landsat/resolve/1c88bb51b6e17a21dde5230738fa38b74bd74f76/resnet18_landsat_etm_sr_moco-4f078acd.pth',
|
||||||
transforms=_ssl4eo_l_transforms,
|
transforms=_ssl4eo_l_transforms,
|
||||||
meta={
|
meta={
|
||||||
'dataset': 'SSL4EO-L',
|
'dataset': 'SSL4EO-L',
|
||||||
|
@ -135,7 +135,7 @@ class ResNet18_Weights(WeightsEnum): # type: ignore[misc]
|
||||||
)
|
)
|
||||||
|
|
||||||
LANDSAT_ETM_SR_SIMCLR = Weights(
|
LANDSAT_ETM_SR_SIMCLR = Weights(
|
||||||
url='https://hf.co/torchgeo/ssl4eo_landsat/resolve/1c88bb51b6e17a21dde5230738fa38b74bd74f76/resnet18_landsat_etm_sr_simclr-8e8543b4.pth', # noqa: E501
|
url='https://hf.co/torchgeo/ssl4eo_landsat/resolve/1c88bb51b6e17a21dde5230738fa38b74bd74f76/resnet18_landsat_etm_sr_simclr-8e8543b4.pth',
|
||||||
transforms=_ssl4eo_l_transforms,
|
transforms=_ssl4eo_l_transforms,
|
||||||
meta={
|
meta={
|
||||||
'dataset': 'SSL4EO-L',
|
'dataset': 'SSL4EO-L',
|
||||||
|
@ -148,7 +148,7 @@ class ResNet18_Weights(WeightsEnum): # type: ignore[misc]
|
||||||
)
|
)
|
||||||
|
|
||||||
LANDSAT_OLI_TIRS_TOA_MOCO = Weights(
|
LANDSAT_OLI_TIRS_TOA_MOCO = Weights(
|
||||||
url='https://hf.co/torchgeo/ssl4eo_landsat/resolve/1c88bb51b6e17a21dde5230738fa38b74bd74f76/resnet18_landsat_oli_tirs_toa_moco-a3002f51.pth', # noqa: E501
|
url='https://hf.co/torchgeo/ssl4eo_landsat/resolve/1c88bb51b6e17a21dde5230738fa38b74bd74f76/resnet18_landsat_oli_tirs_toa_moco-a3002f51.pth',
|
||||||
transforms=_ssl4eo_l_transforms,
|
transforms=_ssl4eo_l_transforms,
|
||||||
meta={
|
meta={
|
||||||
'dataset': 'SSL4EO-L',
|
'dataset': 'SSL4EO-L',
|
||||||
|
@ -161,7 +161,7 @@ class ResNet18_Weights(WeightsEnum): # type: ignore[misc]
|
||||||
)
|
)
|
||||||
|
|
||||||
LANDSAT_OLI_TIRS_TOA_SIMCLR = Weights(
|
LANDSAT_OLI_TIRS_TOA_SIMCLR = Weights(
|
||||||
url='https://hf.co/torchgeo/ssl4eo_landsat/resolve/1c88bb51b6e17a21dde5230738fa38b74bd74f76/resnet18_landsat_oli_tirs_toa_simclr-b0635cc6.pth', # noqa: E501
|
url='https://hf.co/torchgeo/ssl4eo_landsat/resolve/1c88bb51b6e17a21dde5230738fa38b74bd74f76/resnet18_landsat_oli_tirs_toa_simclr-b0635cc6.pth',
|
||||||
transforms=_ssl4eo_l_transforms,
|
transforms=_ssl4eo_l_transforms,
|
||||||
meta={
|
meta={
|
||||||
'dataset': 'SSL4EO-L',
|
'dataset': 'SSL4EO-L',
|
||||||
|
@ -174,7 +174,7 @@ class ResNet18_Weights(WeightsEnum): # type: ignore[misc]
|
||||||
)
|
)
|
||||||
|
|
||||||
LANDSAT_OLI_SR_MOCO = Weights(
|
LANDSAT_OLI_SR_MOCO = Weights(
|
||||||
url='https://hf.co/torchgeo/ssl4eo_landsat/resolve/1c88bb51b6e17a21dde5230738fa38b74bd74f76/resnet18_landsat_oli_sr_moco-660e82ed.pth', # noqa: E501
|
url='https://hf.co/torchgeo/ssl4eo_landsat/resolve/1c88bb51b6e17a21dde5230738fa38b74bd74f76/resnet18_landsat_oli_sr_moco-660e82ed.pth',
|
||||||
transforms=_ssl4eo_l_transforms,
|
transforms=_ssl4eo_l_transforms,
|
||||||
meta={
|
meta={
|
||||||
'dataset': 'SSL4EO-L',
|
'dataset': 'SSL4EO-L',
|
||||||
|
@ -187,7 +187,7 @@ class ResNet18_Weights(WeightsEnum): # type: ignore[misc]
|
||||||
)
|
)
|
||||||
|
|
||||||
LANDSAT_OLI_SR_SIMCLR = Weights(
|
LANDSAT_OLI_SR_SIMCLR = Weights(
|
||||||
url='https://hf.co/torchgeo/ssl4eo_landsat/resolve/1c88bb51b6e17a21dde5230738fa38b74bd74f76/resnet18_landsat_oli_sr_simclr-7bced5be.pth', # noqa: E501
|
url='https://hf.co/torchgeo/ssl4eo_landsat/resolve/1c88bb51b6e17a21dde5230738fa38b74bd74f76/resnet18_landsat_oli_sr_simclr-7bced5be.pth',
|
||||||
transforms=_ssl4eo_l_transforms,
|
transforms=_ssl4eo_l_transforms,
|
||||||
meta={
|
meta={
|
||||||
'dataset': 'SSL4EO-L',
|
'dataset': 'SSL4EO-L',
|
||||||
|
@ -200,7 +200,7 @@ class ResNet18_Weights(WeightsEnum): # type: ignore[misc]
|
||||||
)
|
)
|
||||||
|
|
||||||
SENTINEL2_ALL_MOCO = Weights(
|
SENTINEL2_ALL_MOCO = Weights(
|
||||||
url='https://hf.co/torchgeo/resnet18_sentinel2_all_moco/resolve/5b8cddc9a14f3844350b7f40b85bcd32aed75918/resnet18_sentinel2_all_moco-59bfdff9.pth', # noqa: E501
|
url='https://hf.co/torchgeo/resnet18_sentinel2_all_moco/resolve/5b8cddc9a14f3844350b7f40b85bcd32aed75918/resnet18_sentinel2_all_moco-59bfdff9.pth',
|
||||||
transforms=_zhu_xlab_transforms,
|
transforms=_zhu_xlab_transforms,
|
||||||
meta={
|
meta={
|
||||||
'dataset': 'SSL4EO-S12',
|
'dataset': 'SSL4EO-S12',
|
||||||
|
@ -213,7 +213,7 @@ class ResNet18_Weights(WeightsEnum): # type: ignore[misc]
|
||||||
)
|
)
|
||||||
|
|
||||||
SENTINEL2_RGB_MOCO = Weights(
|
SENTINEL2_RGB_MOCO = Weights(
|
||||||
url='https://hf.co/torchgeo/resnet18_sentinel2_rgb_moco/resolve/e1c032e7785fd0625224cdb6699aa138bb304eec/resnet18_sentinel2_rgb_moco-e3a335e3.pth', # noqa: E501
|
url='https://hf.co/torchgeo/resnet18_sentinel2_rgb_moco/resolve/e1c032e7785fd0625224cdb6699aa138bb304eec/resnet18_sentinel2_rgb_moco-e3a335e3.pth',
|
||||||
transforms=_zhu_xlab_transforms,
|
transforms=_zhu_xlab_transforms,
|
||||||
meta={
|
meta={
|
||||||
'dataset': 'SSL4EO-S12',
|
'dataset': 'SSL4EO-S12',
|
||||||
|
@ -226,7 +226,7 @@ class ResNet18_Weights(WeightsEnum): # type: ignore[misc]
|
||||||
)
|
)
|
||||||
|
|
||||||
SENTINEL2_RGB_SECO = Weights(
|
SENTINEL2_RGB_SECO = Weights(
|
||||||
url='https://hf.co/torchgeo/resnet18_sentinel2_rgb_seco/resolve/f8dcee692cf7142163b55a5c197d981fe0e717a0/resnet18_sentinel2_rgb_seco-cefca942.pth', # noqa: E501
|
url='https://hf.co/torchgeo/resnet18_sentinel2_rgb_seco/resolve/f8dcee692cf7142163b55a5c197d981fe0e717a0/resnet18_sentinel2_rgb_seco-cefca942.pth',
|
||||||
transforms=_seco_transforms,
|
transforms=_seco_transforms,
|
||||||
meta={
|
meta={
|
||||||
'dataset': 'SeCo Dataset',
|
'dataset': 'SeCo Dataset',
|
||||||
|
@ -249,7 +249,7 @@ class ResNet50_Weights(WeightsEnum): # type: ignore[misc]
|
||||||
"""
|
"""
|
||||||
|
|
||||||
FMOW_RGB_GASSL = Weights(
|
FMOW_RGB_GASSL = Weights(
|
||||||
url='https://hf.co/torchgeo/resnet50_fmow_rgb_gassl/resolve/fe8a91026cf9104f1e884316b8e8772d7af9052c/resnet50_fmow_rgb_gassl-da43d987.pth', # noqa: E501
|
url='https://hf.co/torchgeo/resnet50_fmow_rgb_gassl/resolve/fe8a91026cf9104f1e884316b8e8772d7af9052c/resnet50_fmow_rgb_gassl-da43d987.pth',
|
||||||
transforms=_gassl_transforms,
|
transforms=_gassl_transforms,
|
||||||
meta={
|
meta={
|
||||||
'dataset': 'fMoW Dataset',
|
'dataset': 'fMoW Dataset',
|
||||||
|
@ -262,7 +262,7 @@ class ResNet50_Weights(WeightsEnum): # type: ignore[misc]
|
||||||
)
|
)
|
||||||
|
|
||||||
LANDSAT_TM_TOA_MOCO = Weights(
|
LANDSAT_TM_TOA_MOCO = Weights(
|
||||||
url='https://hf.co/torchgeo/ssl4eo_landsat/resolve/1c88bb51b6e17a21dde5230738fa38b74bd74f76/resnet50_landsat_tm_toa_moco-ba1ce753.pth', # noqa: E501
|
url='https://hf.co/torchgeo/ssl4eo_landsat/resolve/1c88bb51b6e17a21dde5230738fa38b74bd74f76/resnet50_landsat_tm_toa_moco-ba1ce753.pth',
|
||||||
transforms=_ssl4eo_l_transforms,
|
transforms=_ssl4eo_l_transforms,
|
||||||
meta={
|
meta={
|
||||||
'dataset': 'SSL4EO-L',
|
'dataset': 'SSL4EO-L',
|
||||||
|
@ -275,7 +275,7 @@ class ResNet50_Weights(WeightsEnum): # type: ignore[misc]
|
||||||
)
|
)
|
||||||
|
|
||||||
LANDSAT_TM_TOA_SIMCLR = Weights(
|
LANDSAT_TM_TOA_SIMCLR = Weights(
|
||||||
url='https://hf.co/torchgeo/ssl4eo_landsat/resolve/1c88bb51b6e17a21dde5230738fa38b74bd74f76/resnet50_landsat_tm_toa_simclr-a1c93432.pth', # noqa: E501
|
url='https://hf.co/torchgeo/ssl4eo_landsat/resolve/1c88bb51b6e17a21dde5230738fa38b74bd74f76/resnet50_landsat_tm_toa_simclr-a1c93432.pth',
|
||||||
transforms=_ssl4eo_l_transforms,
|
transforms=_ssl4eo_l_transforms,
|
||||||
meta={
|
meta={
|
||||||
'dataset': 'SSL4EO-L',
|
'dataset': 'SSL4EO-L',
|
||||||
|
@ -288,7 +288,7 @@ class ResNet50_Weights(WeightsEnum): # type: ignore[misc]
|
||||||
)
|
)
|
||||||
|
|
||||||
LANDSAT_ETM_TOA_MOCO = Weights(
|
LANDSAT_ETM_TOA_MOCO = Weights(
|
||||||
url='https://hf.co/torchgeo/ssl4eo_landsat/resolve/1c88bb51b6e17a21dde5230738fa38b74bd74f76/resnet50_landsat_etm_toa_moco-e9a84d5a.pth', # noqa: E501
|
url='https://hf.co/torchgeo/ssl4eo_landsat/resolve/1c88bb51b6e17a21dde5230738fa38b74bd74f76/resnet50_landsat_etm_toa_moco-e9a84d5a.pth',
|
||||||
transforms=_ssl4eo_l_transforms,
|
transforms=_ssl4eo_l_transforms,
|
||||||
meta={
|
meta={
|
||||||
'dataset': 'SSL4EO-L',
|
'dataset': 'SSL4EO-L',
|
||||||
|
@ -301,7 +301,7 @@ class ResNet50_Weights(WeightsEnum): # type: ignore[misc]
|
||||||
)
|
)
|
||||||
|
|
||||||
LANDSAT_ETM_TOA_SIMCLR = Weights(
|
LANDSAT_ETM_TOA_SIMCLR = Weights(
|
||||||
url='https://hf.co/torchgeo/ssl4eo_landsat/resolve/1c88bb51b6e17a21dde5230738fa38b74bd74f76/resnet50_landsat_etm_toa_simclr-70b5575f.pth', # noqa: E501
|
url='https://hf.co/torchgeo/ssl4eo_landsat/resolve/1c88bb51b6e17a21dde5230738fa38b74bd74f76/resnet50_landsat_etm_toa_simclr-70b5575f.pth',
|
||||||
transforms=_ssl4eo_l_transforms,
|
transforms=_ssl4eo_l_transforms,
|
||||||
meta={
|
meta={
|
||||||
'dataset': 'SSL4EO-L',
|
'dataset': 'SSL4EO-L',
|
||||||
|
@ -314,7 +314,7 @@ class ResNet50_Weights(WeightsEnum): # type: ignore[misc]
|
||||||
)
|
)
|
||||||
|
|
||||||
LANDSAT_ETM_SR_MOCO = Weights(
|
LANDSAT_ETM_SR_MOCO = Weights(
|
||||||
url='https://hf.co/torchgeo/ssl4eo_landsat/resolve/1c88bb51b6e17a21dde5230738fa38b74bd74f76/resnet50_landsat_etm_sr_moco-1266cde3.pth', # noqa: E501
|
url='https://hf.co/torchgeo/ssl4eo_landsat/resolve/1c88bb51b6e17a21dde5230738fa38b74bd74f76/resnet50_landsat_etm_sr_moco-1266cde3.pth',
|
||||||
transforms=_ssl4eo_l_transforms,
|
transforms=_ssl4eo_l_transforms,
|
||||||
meta={
|
meta={
|
||||||
'dataset': 'SSL4EO-L',
|
'dataset': 'SSL4EO-L',
|
||||||
|
@ -327,7 +327,7 @@ class ResNet50_Weights(WeightsEnum): # type: ignore[misc]
|
||||||
)
|
)
|
||||||
|
|
||||||
LANDSAT_ETM_SR_SIMCLR = Weights(
|
LANDSAT_ETM_SR_SIMCLR = Weights(
|
||||||
url='https://hf.co/torchgeo/ssl4eo_landsat/resolve/1c88bb51b6e17a21dde5230738fa38b74bd74f76/resnet50_landsat_etm_sr_simclr-e5d185d7.pth', # noqa: E501
|
url='https://hf.co/torchgeo/ssl4eo_landsat/resolve/1c88bb51b6e17a21dde5230738fa38b74bd74f76/resnet50_landsat_etm_sr_simclr-e5d185d7.pth',
|
||||||
transforms=_ssl4eo_l_transforms,
|
transforms=_ssl4eo_l_transforms,
|
||||||
meta={
|
meta={
|
||||||
'dataset': 'SSL4EO-L',
|
'dataset': 'SSL4EO-L',
|
||||||
|
@ -340,7 +340,7 @@ class ResNet50_Weights(WeightsEnum): # type: ignore[misc]
|
||||||
)
|
)
|
||||||
|
|
||||||
LANDSAT_OLI_TIRS_TOA_MOCO = Weights(
|
LANDSAT_OLI_TIRS_TOA_MOCO = Weights(
|
||||||
url='https://hf.co/torchgeo/ssl4eo_landsat/resolve/1c88bb51b6e17a21dde5230738fa38b74bd74f76/resnet50_landsat_oli_tirs_toa_moco-de7f5e0f.pth', # noqa: E501
|
url='https://hf.co/torchgeo/ssl4eo_landsat/resolve/1c88bb51b6e17a21dde5230738fa38b74bd74f76/resnet50_landsat_oli_tirs_toa_moco-de7f5e0f.pth',
|
||||||
transforms=_ssl4eo_l_transforms,
|
transforms=_ssl4eo_l_transforms,
|
||||||
meta={
|
meta={
|
||||||
'dataset': 'SSL4EO-L',
|
'dataset': 'SSL4EO-L',
|
||||||
|
@ -353,7 +353,7 @@ class ResNet50_Weights(WeightsEnum): # type: ignore[misc]
|
||||||
)
|
)
|
||||||
|
|
||||||
LANDSAT_OLI_TIRS_TOA_SIMCLR = Weights(
|
LANDSAT_OLI_TIRS_TOA_SIMCLR = Weights(
|
||||||
url='https://hf.co/torchgeo/ssl4eo_landsat/resolve/1c88bb51b6e17a21dde5230738fa38b74bd74f76/resnet50_landsat_oli_tirs_toa_simclr-030cebfe.pth', # noqa: E501
|
url='https://hf.co/torchgeo/ssl4eo_landsat/resolve/1c88bb51b6e17a21dde5230738fa38b74bd74f76/resnet50_landsat_oli_tirs_toa_simclr-030cebfe.pth',
|
||||||
transforms=_ssl4eo_l_transforms,
|
transforms=_ssl4eo_l_transforms,
|
||||||
meta={
|
meta={
|
||||||
'dataset': 'SSL4EO-L',
|
'dataset': 'SSL4EO-L',
|
||||||
|
@ -366,7 +366,7 @@ class ResNet50_Weights(WeightsEnum): # type: ignore[misc]
|
||||||
)
|
)
|
||||||
|
|
||||||
LANDSAT_OLI_SR_MOCO = Weights(
|
LANDSAT_OLI_SR_MOCO = Weights(
|
||||||
url='https://hf.co/torchgeo/ssl4eo_landsat/resolve/1c88bb51b6e17a21dde5230738fa38b74bd74f76/resnet50_landsat_oli_sr_moco-ff580dad.pth', # noqa: E501
|
url='https://hf.co/torchgeo/ssl4eo_landsat/resolve/1c88bb51b6e17a21dde5230738fa38b74bd74f76/resnet50_landsat_oli_sr_moco-ff580dad.pth',
|
||||||
transforms=_ssl4eo_l_transforms,
|
transforms=_ssl4eo_l_transforms,
|
||||||
meta={
|
meta={
|
||||||
'dataset': 'SSL4EO-L',
|
'dataset': 'SSL4EO-L',
|
||||||
|
@ -379,7 +379,7 @@ class ResNet50_Weights(WeightsEnum): # type: ignore[misc]
|
||||||
)
|
)
|
||||||
|
|
||||||
LANDSAT_OLI_SR_SIMCLR = Weights(
|
LANDSAT_OLI_SR_SIMCLR = Weights(
|
||||||
url='https://hf.co/torchgeo/ssl4eo_landsat/resolve/1c88bb51b6e17a21dde5230738fa38b74bd74f76/resnet50_landsat_oli_sr_simclr-94f78913.pth', # noqa: E501
|
url='https://hf.co/torchgeo/ssl4eo_landsat/resolve/1c88bb51b6e17a21dde5230738fa38b74bd74f76/resnet50_landsat_oli_sr_simclr-94f78913.pth',
|
||||||
transforms=_ssl4eo_l_transforms,
|
transforms=_ssl4eo_l_transforms,
|
||||||
meta={
|
meta={
|
||||||
'dataset': 'SSL4EO-L',
|
'dataset': 'SSL4EO-L',
|
||||||
|
@ -392,7 +392,7 @@ class ResNet50_Weights(WeightsEnum): # type: ignore[misc]
|
||||||
)
|
)
|
||||||
|
|
||||||
SENTINEL1_ALL_MOCO = Weights(
|
SENTINEL1_ALL_MOCO = Weights(
|
||||||
url='https://hf.co/torchgeo/resnet50_sentinel1_all_moco/resolve/e79862c667853c10a709bdd77ea8ffbad0e0f1cf/resnet50_sentinel1_all_moco-906e4356.pth', # noqa: E501
|
url='https://hf.co/torchgeo/resnet50_sentinel1_all_moco/resolve/e79862c667853c10a709bdd77ea8ffbad0e0f1cf/resnet50_sentinel1_all_moco-906e4356.pth',
|
||||||
transforms=_zhu_xlab_transforms,
|
transforms=_zhu_xlab_transforms,
|
||||||
meta={
|
meta={
|
||||||
'dataset': 'SSL4EO-S12',
|
'dataset': 'SSL4EO-S12',
|
||||||
|
@ -405,7 +405,7 @@ class ResNet50_Weights(WeightsEnum): # type: ignore[misc]
|
||||||
)
|
)
|
||||||
|
|
||||||
SENTINEL2_ALL_DINO = Weights(
|
SENTINEL2_ALL_DINO = Weights(
|
||||||
url='https://hf.co/torchgeo/resnet50_sentinel2_all_dino/resolve/d7f14bf5530d70ac69d763e58e77e44dbecfec7c/resnet50_sentinel2_all_dino-d6c330e9.pth', # noqa: E501
|
url='https://hf.co/torchgeo/resnet50_sentinel2_all_dino/resolve/d7f14bf5530d70ac69d763e58e77e44dbecfec7c/resnet50_sentinel2_all_dino-d6c330e9.pth',
|
||||||
transforms=_zhu_xlab_transforms,
|
transforms=_zhu_xlab_transforms,
|
||||||
meta={
|
meta={
|
||||||
'dataset': 'SSL4EO-S12',
|
'dataset': 'SSL4EO-S12',
|
||||||
|
@ -418,7 +418,7 @@ class ResNet50_Weights(WeightsEnum): # type: ignore[misc]
|
||||||
)
|
)
|
||||||
|
|
||||||
SENTINEL2_ALL_MOCO = Weights(
|
SENTINEL2_ALL_MOCO = Weights(
|
||||||
url='https://hf.co/torchgeo/resnet50_sentinel2_all_moco/resolve/da4f3c9dbe09272eb902f3b37f46635fa4726879/resnet50_sentinel2_all_moco-df8b932e.pth', # noqa: E501
|
url='https://hf.co/torchgeo/resnet50_sentinel2_all_moco/resolve/da4f3c9dbe09272eb902f3b37f46635fa4726879/resnet50_sentinel2_all_moco-df8b932e.pth',
|
||||||
transforms=_zhu_xlab_transforms,
|
transforms=_zhu_xlab_transforms,
|
||||||
meta={
|
meta={
|
||||||
'dataset': 'SSL4EO-S12',
|
'dataset': 'SSL4EO-S12',
|
||||||
|
@ -431,7 +431,7 @@ class ResNet50_Weights(WeightsEnum): # type: ignore[misc]
|
||||||
)
|
)
|
||||||
|
|
||||||
SENTINEL2_RGB_MOCO = Weights(
|
SENTINEL2_RGB_MOCO = Weights(
|
||||||
url='https://hf.co/torchgeo/resnet50_sentinel2_rgb_moco/resolve/efd9723b59a88e9dc1420dc1e96afb25b0630a3c/resnet50_sentinel2_rgb_moco-2b57ba8b.pth', # noqa: E501
|
url='https://hf.co/torchgeo/resnet50_sentinel2_rgb_moco/resolve/efd9723b59a88e9dc1420dc1e96afb25b0630a3c/resnet50_sentinel2_rgb_moco-2b57ba8b.pth',
|
||||||
transforms=_zhu_xlab_transforms,
|
transforms=_zhu_xlab_transforms,
|
||||||
meta={
|
meta={
|
||||||
'dataset': 'SSL4EO-S12',
|
'dataset': 'SSL4EO-S12',
|
||||||
|
@ -444,7 +444,7 @@ class ResNet50_Weights(WeightsEnum): # type: ignore[misc]
|
||||||
)
|
)
|
||||||
|
|
||||||
SENTINEL2_RGB_SECO = Weights(
|
SENTINEL2_RGB_SECO = Weights(
|
||||||
url='https://hf.co/torchgeo/resnet50_sentinel2_rgb_seco/resolve/fbd07b02a8edb8fc1035f7957160deed4321c145/resnet50_sentinel2_rgb_seco-018bf397.pth', # noqa: E501
|
url='https://hf.co/torchgeo/resnet50_sentinel2_rgb_seco/resolve/fbd07b02a8edb8fc1035f7957160deed4321c145/resnet50_sentinel2_rgb_seco-018bf397.pth',
|
||||||
transforms=_seco_transforms,
|
transforms=_seco_transforms,
|
||||||
meta={
|
meta={
|
||||||
'dataset': 'SeCo Dataset',
|
'dataset': 'SeCo Dataset',
|
||||||
|
|
|
@ -12,20 +12,20 @@ from kornia.contrib import Lambda
|
||||||
from torchvision.models import SwinTransformer
|
from torchvision.models import SwinTransformer
|
||||||
from torchvision.models._api import Weights, WeightsEnum
|
from torchvision.models._api import Weights, WeightsEnum
|
||||||
|
|
||||||
# https://github.com/allenai/satlas/blob/bcaa968da5395f675d067613e02613a344e81415/satlas/cmd/model/train.py#L42 # noqa: E501
|
# https://github.com/allenai/satlas/blob/bcaa968da5395f675d067613e02613a344e81415/satlas/cmd/model/train.py#L42
|
||||||
# Satlas uses the TCI product for Sentinel-2 RGB, which is in the range (0, 255).
|
# Satlas uses the TCI product for Sentinel-2 RGB, which is in the range (0, 255).
|
||||||
# See details: https://github.com/allenai/satlas/blob/main/Normalization.md#sentinel-2-images. # noqa: E501
|
# See details: https://github.com/allenai/satlas/blob/main/Normalization.md#sentinel-2-images.
|
||||||
# Satlas Sentinel-1 and RGB Sentinel-2 and NAIP imagery is uint8 and is normalized to (0, 1) by dividing by 255. # noqa: E501
|
# Satlas Sentinel-1 and RGB Sentinel-2 and NAIP imagery is uint8 and is normalized to (0, 1) by dividing by 255.
|
||||||
_satlas_transforms = K.AugmentationSequential(
|
_satlas_transforms = K.AugmentationSequential(
|
||||||
K.Normalize(mean=torch.tensor(0), std=torch.tensor(255)), data_keys=None
|
K.Normalize(mean=torch.tensor(0), std=torch.tensor(255)), data_keys=None
|
||||||
)
|
)
|
||||||
|
|
||||||
# Satlas uses the TCI product for Sentinel-2 RGB, which is in the range (0, 255).
|
# Satlas uses the TCI product for Sentinel-2 RGB, which is in the range (0, 255).
|
||||||
# See details: https://github.com/allenai/satlas/blob/main/Normalization.md#sentinel-2-images. # noqa: E501
|
# See details: https://github.com/allenai/satlas/blob/main/Normalization.md#sentinel-2-images.
|
||||||
# Satlas Sentinel-2 multispectral imagery has first 3 bands divided by 255 and the following 6 bands by 8160, both clipped to (0, 1). # noqa: E501
|
# Satlas Sentinel-2 multispectral imagery has first 3 bands divided by 255 and the following 6 bands by 8160, both clipped to (0, 1).
|
||||||
_std = torch.tensor(
|
_std = torch.tensor(
|
||||||
[255.0, 255.0, 255.0, 8160.0, 8160.0, 8160.0, 8160.0, 8160.0, 8160.0]
|
[255.0, 255.0, 255.0, 8160.0, 8160.0, 8160.0, 8160.0, 8160.0, 8160.0]
|
||||||
) # noqa: E501
|
)
|
||||||
_mean = torch.zeros_like(_std)
|
_mean = torch.zeros_like(_std)
|
||||||
_sentinel2_ms_satlas_transforms = K.AugmentationSequential(
|
_sentinel2_ms_satlas_transforms = K.AugmentationSequential(
|
||||||
K.Normalize(mean=_mean, std=_std),
|
K.Normalize(mean=_mean, std=_std),
|
||||||
|
@ -33,7 +33,7 @@ _sentinel2_ms_satlas_transforms = K.AugmentationSequential(
|
||||||
data_keys=None,
|
data_keys=None,
|
||||||
)
|
)
|
||||||
|
|
||||||
# Satlas Landsat imagery is 16-bit, normalized by clipping some pixel N with (N-4000)/16320 to (0, 1). # noqa: E501
|
# Satlas Landsat imagery is 16-bit, normalized by clipping some pixel N with (N-4000)/16320 to (0, 1).
|
||||||
_landsat_satlas_transforms = K.AugmentationSequential(
|
_landsat_satlas_transforms = K.AugmentationSequential(
|
||||||
K.Normalize(mean=torch.tensor(4000), std=torch.tensor(16320)),
|
K.Normalize(mean=torch.tensor(4000), std=torch.tensor(16320)),
|
||||||
K.ImageSequential(Lambda(lambda x: torch.clamp(x, min=0.0, max=1.0))),
|
K.ImageSequential(Lambda(lambda x: torch.clamp(x, min=0.0, max=1.0))),
|
||||||
|
@ -56,7 +56,7 @@ class Swin_V2_B_Weights(WeightsEnum): # type: ignore[misc]
|
||||||
"""
|
"""
|
||||||
|
|
||||||
NAIP_RGB_SI_SATLAS = Weights(
|
NAIP_RGB_SI_SATLAS = Weights(
|
||||||
url='https://hf.co/allenai/satlas-pretrain/resolve/daa578a4be36573d9791bf51dcd0420b8dc75732/aerial_swinb_si.pth', # noqa: E501
|
url='https://hf.co/allenai/satlas-pretrain/resolve/daa578a4be36573d9791bf51dcd0420b8dc75732/aerial_swinb_si.pth',
|
||||||
transforms=_satlas_transforms,
|
transforms=_satlas_transforms,
|
||||||
meta={
|
meta={
|
||||||
'dataset': 'Satlas',
|
'dataset': 'Satlas',
|
||||||
|
@ -68,7 +68,7 @@ class Swin_V2_B_Weights(WeightsEnum): # type: ignore[misc]
|
||||||
)
|
)
|
||||||
|
|
||||||
SENTINEL2_RGB_SI_SATLAS = Weights(
|
SENTINEL2_RGB_SI_SATLAS = Weights(
|
||||||
url='https://hf.co/allenai/satlas-pretrain/resolve/daa578a4be36573d9791bf51dcd0420b8dc75732/sentinel2_swinb_si_rgb.pth', # noqa: E501
|
url='https://hf.co/allenai/satlas-pretrain/resolve/daa578a4be36573d9791bf51dcd0420b8dc75732/sentinel2_swinb_si_rgb.pth',
|
||||||
transforms=_satlas_transforms,
|
transforms=_satlas_transforms,
|
||||||
meta={
|
meta={
|
||||||
'dataset': 'Satlas',
|
'dataset': 'Satlas',
|
||||||
|
@ -80,7 +80,7 @@ class Swin_V2_B_Weights(WeightsEnum): # type: ignore[misc]
|
||||||
)
|
)
|
||||||
|
|
||||||
SENTINEL2_MS_SI_SATLAS = Weights(
|
SENTINEL2_MS_SI_SATLAS = Weights(
|
||||||
url='https://hf.co/allenai/satlas-pretrain/resolve/daa578a4be36573d9791bf51dcd0420b8dc75732/sentinel2_swinb_si_ms.pth', # noqa: E501
|
url='https://hf.co/allenai/satlas-pretrain/resolve/daa578a4be36573d9791bf51dcd0420b8dc75732/sentinel2_swinb_si_ms.pth',
|
||||||
transforms=_sentinel2_ms_satlas_transforms,
|
transforms=_sentinel2_ms_satlas_transforms,
|
||||||
meta={
|
meta={
|
||||||
'dataset': 'Satlas',
|
'dataset': 'Satlas',
|
||||||
|
@ -93,7 +93,7 @@ class Swin_V2_B_Weights(WeightsEnum): # type: ignore[misc]
|
||||||
)
|
)
|
||||||
|
|
||||||
SENTINEL1_SI_SATLAS = Weights(
|
SENTINEL1_SI_SATLAS = Weights(
|
||||||
url='https://hf.co/allenai/satlas-pretrain/resolve/daa578a4be36573d9791bf51dcd0420b8dc75732/sentinel1_swinb_si.pth', # noqa: E501
|
url='https://hf.co/allenai/satlas-pretrain/resolve/daa578a4be36573d9791bf51dcd0420b8dc75732/sentinel1_swinb_si.pth',
|
||||||
transforms=_satlas_transforms,
|
transforms=_satlas_transforms,
|
||||||
meta={
|
meta={
|
||||||
'dataset': 'Satlas',
|
'dataset': 'Satlas',
|
||||||
|
@ -106,7 +106,7 @@ class Swin_V2_B_Weights(WeightsEnum): # type: ignore[misc]
|
||||||
)
|
)
|
||||||
|
|
||||||
LANDSAT_SI_SATLAS = Weights(
|
LANDSAT_SI_SATLAS = Weights(
|
||||||
url='https://hf.co/allenai/satlas-pretrain/resolve/daa578a4be36573d9791bf51dcd0420b8dc75732/landsat_swinb_si.pth', # noqa: E501
|
url='https://hf.co/allenai/satlas-pretrain/resolve/daa578a4be36573d9791bf51dcd0420b8dc75732/landsat_swinb_si.pth',
|
||||||
transforms=_landsat_satlas_transforms,
|
transforms=_landsat_satlas_transforms,
|
||||||
meta={
|
meta={
|
||||||
'dataset': 'Satlas',
|
'dataset': 'Satlas',
|
||||||
|
@ -126,7 +126,7 @@ class Swin_V2_B_Weights(WeightsEnum): # type: ignore[misc]
|
||||||
'B09',
|
'B09',
|
||||||
'B10',
|
'B10',
|
||||||
'B11',
|
'B11',
|
||||||
], # noqa: E501
|
],
|
||||||
},
|
},
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
|
@ -11,8 +11,8 @@ import torch
|
||||||
from timm.models.vision_transformer import VisionTransformer
|
from timm.models.vision_transformer import VisionTransformer
|
||||||
from torchvision.models._api import Weights, WeightsEnum
|
from torchvision.models._api import Weights, WeightsEnum
|
||||||
|
|
||||||
# https://github.com/zhu-xlab/SSL4EO-S12/blob/d2868adfada65e40910bfcedfc49bc3b20df2248/src/benchmark/transfer_classification/linear_BE_moco.py#L167 # noqa: E501
|
# https://github.com/zhu-xlab/SSL4EO-S12/blob/d2868adfada65e40910bfcedfc49bc3b20df2248/src/benchmark/transfer_classification/linear_BE_moco.py#L167
|
||||||
# https://github.com/zhu-xlab/SSL4EO-S12/blob/d2868adfada65e40910bfcedfc49bc3b20df2248/src/benchmark/transfer_classification/datasets/EuroSat/eurosat_dataset.py#L97 # noqa: E501
|
# https://github.com/zhu-xlab/SSL4EO-S12/blob/d2868adfada65e40910bfcedfc49bc3b20df2248/src/benchmark/transfer_classification/datasets/EuroSat/eurosat_dataset.py#L97
|
||||||
# Normalization either by 10K or channel-wise with band statistics
|
# Normalization either by 10K or channel-wise with band statistics
|
||||||
_zhu_xlab_transforms = K.AugmentationSequential(
|
_zhu_xlab_transforms = K.AugmentationSequential(
|
||||||
K.Resize(256),
|
K.Resize(256),
|
||||||
|
@ -21,7 +21,7 @@ _zhu_xlab_transforms = K.AugmentationSequential(
|
||||||
data_keys=None,
|
data_keys=None,
|
||||||
)
|
)
|
||||||
|
|
||||||
# https://github.com/microsoft/torchgeo/blob/8b53304d42c269f9001cb4e861a126dc4b462606/torchgeo/datamodules/ssl4eo_benchmark.py#L43 # noqa: E501
|
# https://github.com/microsoft/torchgeo/blob/8b53304d42c269f9001cb4e861a126dc4b462606/torchgeo/datamodules/ssl4eo_benchmark.py#L43
|
||||||
_ssl4eo_l_transforms = K.AugmentationSequential(
|
_ssl4eo_l_transforms = K.AugmentationSequential(
|
||||||
K.Normalize(mean=torch.tensor(0), std=torch.tensor(255)),
|
K.Normalize(mean=torch.tensor(0), std=torch.tensor(255)),
|
||||||
K.CenterCrop((224, 224)),
|
K.CenterCrop((224, 224)),
|
||||||
|
@ -44,7 +44,7 @@ class ViTSmall16_Weights(WeightsEnum): # type: ignore[misc]
|
||||||
"""
|
"""
|
||||||
|
|
||||||
LANDSAT_TM_TOA_MOCO = Weights(
|
LANDSAT_TM_TOA_MOCO = Weights(
|
||||||
url='https://hf.co/torchgeo/ssl4eo_landsat/resolve/1c88bb51b6e17a21dde5230738fa38b74bd74f76/vits16_landsat_tm_toa_moco-a1c967d8.pth', # noqa: E501
|
url='https://hf.co/torchgeo/ssl4eo_landsat/resolve/1c88bb51b6e17a21dde5230738fa38b74bd74f76/vits16_landsat_tm_toa_moco-a1c967d8.pth',
|
||||||
transforms=_ssl4eo_l_transforms,
|
transforms=_ssl4eo_l_transforms,
|
||||||
meta={
|
meta={
|
||||||
'dataset': 'SSL4EO-L',
|
'dataset': 'SSL4EO-L',
|
||||||
|
@ -57,7 +57,7 @@ class ViTSmall16_Weights(WeightsEnum): # type: ignore[misc]
|
||||||
)
|
)
|
||||||
|
|
||||||
LANDSAT_TM_TOA_SIMCLR = Weights(
|
LANDSAT_TM_TOA_SIMCLR = Weights(
|
||||||
url='https://hf.co/torchgeo/ssl4eo_landsat/resolve/1c88bb51b6e17a21dde5230738fa38b74bd74f76/vits16_landsat_tm_toa_simclr-7c2d9799.pth', # noqa: E501
|
url='https://hf.co/torchgeo/ssl4eo_landsat/resolve/1c88bb51b6e17a21dde5230738fa38b74bd74f76/vits16_landsat_tm_toa_simclr-7c2d9799.pth',
|
||||||
transforms=_ssl4eo_l_transforms,
|
transforms=_ssl4eo_l_transforms,
|
||||||
meta={
|
meta={
|
||||||
'dataset': 'SSL4EO-L',
|
'dataset': 'SSL4EO-L',
|
||||||
|
@ -70,7 +70,7 @@ class ViTSmall16_Weights(WeightsEnum): # type: ignore[misc]
|
||||||
)
|
)
|
||||||
|
|
||||||
LANDSAT_ETM_TOA_MOCO = Weights(
|
LANDSAT_ETM_TOA_MOCO = Weights(
|
||||||
url='https://hf.co/torchgeo/ssl4eo_landsat/resolve/1c88bb51b6e17a21dde5230738fa38b74bd74f76/vits16_landsat_etm_toa_moco-26d19bcf.pth', # noqa: E501
|
url='https://hf.co/torchgeo/ssl4eo_landsat/resolve/1c88bb51b6e17a21dde5230738fa38b74bd74f76/vits16_landsat_etm_toa_moco-26d19bcf.pth',
|
||||||
transforms=_ssl4eo_l_transforms,
|
transforms=_ssl4eo_l_transforms,
|
||||||
meta={
|
meta={
|
||||||
'dataset': 'SSL4EO-L',
|
'dataset': 'SSL4EO-L',
|
||||||
|
@ -83,7 +83,7 @@ class ViTSmall16_Weights(WeightsEnum): # type: ignore[misc]
|
||||||
)
|
)
|
||||||
|
|
||||||
LANDSAT_ETM_TOA_SIMCLR = Weights(
|
LANDSAT_ETM_TOA_SIMCLR = Weights(
|
||||||
url='https://hf.co/torchgeo/ssl4eo_landsat/resolve/1c88bb51b6e17a21dde5230738fa38b74bd74f76/vits16_landsat_etm_toa_simclr-34fb12cb.pth', # noqa: E501
|
url='https://hf.co/torchgeo/ssl4eo_landsat/resolve/1c88bb51b6e17a21dde5230738fa38b74bd74f76/vits16_landsat_etm_toa_simclr-34fb12cb.pth',
|
||||||
transforms=_ssl4eo_l_transforms,
|
transforms=_ssl4eo_l_transforms,
|
||||||
meta={
|
meta={
|
||||||
'dataset': 'SSL4EO-L',
|
'dataset': 'SSL4EO-L',
|
||||||
|
@ -96,7 +96,7 @@ class ViTSmall16_Weights(WeightsEnum): # type: ignore[misc]
|
||||||
)
|
)
|
||||||
|
|
||||||
LANDSAT_ETM_SR_MOCO = Weights(
|
LANDSAT_ETM_SR_MOCO = Weights(
|
||||||
url='https://hf.co/torchgeo/ssl4eo_landsat/resolve/1c88bb51b6e17a21dde5230738fa38b74bd74f76/vits16_landsat_etm_sr_moco-eaa4674e.pth', # noqa: E501
|
url='https://hf.co/torchgeo/ssl4eo_landsat/resolve/1c88bb51b6e17a21dde5230738fa38b74bd74f76/vits16_landsat_etm_sr_moco-eaa4674e.pth',
|
||||||
transforms=_ssl4eo_l_transforms,
|
transforms=_ssl4eo_l_transforms,
|
||||||
meta={
|
meta={
|
||||||
'dataset': 'SSL4EO-L',
|
'dataset': 'SSL4EO-L',
|
||||||
|
@ -109,7 +109,7 @@ class ViTSmall16_Weights(WeightsEnum): # type: ignore[misc]
|
||||||
)
|
)
|
||||||
|
|
||||||
LANDSAT_ETM_SR_SIMCLR = Weights(
|
LANDSAT_ETM_SR_SIMCLR = Weights(
|
||||||
url='https://hf.co/torchgeo/ssl4eo_landsat/resolve/1c88bb51b6e17a21dde5230738fa38b74bd74f76/vits16_landsat_etm_sr_simclr-a14c466a.pth', # noqa: E501
|
url='https://hf.co/torchgeo/ssl4eo_landsat/resolve/1c88bb51b6e17a21dde5230738fa38b74bd74f76/vits16_landsat_etm_sr_simclr-a14c466a.pth',
|
||||||
transforms=_ssl4eo_l_transforms,
|
transforms=_ssl4eo_l_transforms,
|
||||||
meta={
|
meta={
|
||||||
'dataset': 'SSL4EO-L',
|
'dataset': 'SSL4EO-L',
|
||||||
|
@ -122,7 +122,7 @@ class ViTSmall16_Weights(WeightsEnum): # type: ignore[misc]
|
||||||
)
|
)
|
||||||
|
|
||||||
LANDSAT_OLI_TIRS_TOA_MOCO = Weights(
|
LANDSAT_OLI_TIRS_TOA_MOCO = Weights(
|
||||||
url='https://hf.co/torchgeo/ssl4eo_landsat/resolve/1c88bb51b6e17a21dde5230738fa38b74bd74f76/vits16_landsat_oli_tirs_toa_moco-c7c2cceb.pth', # noqa: E501
|
url='https://hf.co/torchgeo/ssl4eo_landsat/resolve/1c88bb51b6e17a21dde5230738fa38b74bd74f76/vits16_landsat_oli_tirs_toa_moco-c7c2cceb.pth',
|
||||||
transforms=_ssl4eo_l_transforms,
|
transforms=_ssl4eo_l_transforms,
|
||||||
meta={
|
meta={
|
||||||
'dataset': 'SSL4EO-L',
|
'dataset': 'SSL4EO-L',
|
||||||
|
@ -135,7 +135,7 @@ class ViTSmall16_Weights(WeightsEnum): # type: ignore[misc]
|
||||||
)
|
)
|
||||||
|
|
||||||
LANDSAT_OLI_TIRS_TOA_SIMCLR = Weights(
|
LANDSAT_OLI_TIRS_TOA_SIMCLR = Weights(
|
||||||
url='https://hf.co/torchgeo/ssl4eo_landsat/resolve/1c88bb51b6e17a21dde5230738fa38b74bd74f76/vits16_landsat_oli_tirs_toa_simclr-ad43e9a4.pth', # noqa: E501
|
url='https://hf.co/torchgeo/ssl4eo_landsat/resolve/1c88bb51b6e17a21dde5230738fa38b74bd74f76/vits16_landsat_oli_tirs_toa_simclr-ad43e9a4.pth',
|
||||||
transforms=_ssl4eo_l_transforms,
|
transforms=_ssl4eo_l_transforms,
|
||||||
meta={
|
meta={
|
||||||
'dataset': 'SSL4EO-L',
|
'dataset': 'SSL4EO-L',
|
||||||
|
@ -148,7 +148,7 @@ class ViTSmall16_Weights(WeightsEnum): # type: ignore[misc]
|
||||||
)
|
)
|
||||||
|
|
||||||
LANDSAT_OLI_SR_MOCO = Weights(
|
LANDSAT_OLI_SR_MOCO = Weights(
|
||||||
url='https://hf.co/torchgeo/ssl4eo_landsat/resolve/1c88bb51b6e17a21dde5230738fa38b74bd74f76/vits16_landsat_oli_sr_moco-c9b8898d.pth', # noqa: E501
|
url='https://hf.co/torchgeo/ssl4eo_landsat/resolve/1c88bb51b6e17a21dde5230738fa38b74bd74f76/vits16_landsat_oli_sr_moco-c9b8898d.pth',
|
||||||
transforms=_ssl4eo_l_transforms,
|
transforms=_ssl4eo_l_transforms,
|
||||||
meta={
|
meta={
|
||||||
'dataset': 'SSL4EO-L',
|
'dataset': 'SSL4EO-L',
|
||||||
|
@ -161,7 +161,7 @@ class ViTSmall16_Weights(WeightsEnum): # type: ignore[misc]
|
||||||
)
|
)
|
||||||
|
|
||||||
LANDSAT_OLI_SR_SIMCLR = Weights(
|
LANDSAT_OLI_SR_SIMCLR = Weights(
|
||||||
url='https://hf.co/torchgeo/ssl4eo_landsat/resolve/1c88bb51b6e17a21dde5230738fa38b74bd74f76/vits16_landsat_oli_sr_simclr-4e8f6102.pth', # noqa: E501
|
url='https://hf.co/torchgeo/ssl4eo_landsat/resolve/1c88bb51b6e17a21dde5230738fa38b74bd74f76/vits16_landsat_oli_sr_simclr-4e8f6102.pth',
|
||||||
transforms=_ssl4eo_l_transforms,
|
transforms=_ssl4eo_l_transforms,
|
||||||
meta={
|
meta={
|
||||||
'dataset': 'SSL4EO-L',
|
'dataset': 'SSL4EO-L',
|
||||||
|
@ -174,7 +174,7 @@ class ViTSmall16_Weights(WeightsEnum): # type: ignore[misc]
|
||||||
)
|
)
|
||||||
|
|
||||||
SENTINEL2_ALL_DINO = Weights(
|
SENTINEL2_ALL_DINO = Weights(
|
||||||
url='https://hf.co/torchgeo/vit_small_patch16_224_sentinel2_all_dino/resolve/5b41dd418a79de47ac9f5be3e035405a83818a62/vit_small_patch16_224_sentinel2_all_dino-36bcc127.pth', # noqa: E501
|
url='https://hf.co/torchgeo/vit_small_patch16_224_sentinel2_all_dino/resolve/5b41dd418a79de47ac9f5be3e035405a83818a62/vit_small_patch16_224_sentinel2_all_dino-36bcc127.pth',
|
||||||
transforms=_zhu_xlab_transforms,
|
transforms=_zhu_xlab_transforms,
|
||||||
meta={
|
meta={
|
||||||
'dataset': 'SSL4EO-S12',
|
'dataset': 'SSL4EO-S12',
|
||||||
|
@ -187,7 +187,7 @@ class ViTSmall16_Weights(WeightsEnum): # type: ignore[misc]
|
||||||
)
|
)
|
||||||
|
|
||||||
SENTINEL2_ALL_MOCO = Weights(
|
SENTINEL2_ALL_MOCO = Weights(
|
||||||
url='https://hf.co/torchgeo/vit_small_patch16_224_sentinel2_all_moco/resolve/1cb683f6c14739634cdfaaceb076529adf898c74/vit_small_patch16_224_sentinel2_all_moco-67c9032d.pth', # noqa: E501
|
url='https://hf.co/torchgeo/vit_small_patch16_224_sentinel2_all_moco/resolve/1cb683f6c14739634cdfaaceb076529adf898c74/vit_small_patch16_224_sentinel2_all_moco-67c9032d.pth',
|
||||||
transforms=_zhu_xlab_transforms,
|
transforms=_zhu_xlab_transforms,
|
||||||
meta={
|
meta={
|
||||||
'dataset': 'SSL4EO-S12',
|
'dataset': 'SSL4EO-S12',
|
||||||
|
|
|
@ -53,7 +53,7 @@ class AugmentationSequential(Module):
|
||||||
else:
|
else:
|
||||||
keys.append(key)
|
keys.append(key)
|
||||||
|
|
||||||
self.augs = K.AugmentationSequential(*args, data_keys=keys, **kwargs) # type: ignore[arg-type] # noqa: E501
|
self.augs = K.AugmentationSequential(*args, data_keys=keys, **kwargs) # type: ignore[arg-type]
|
||||||
|
|
||||||
def forward(self, batch: dict[str, Tensor]) -> dict[str, Tensor]:
|
def forward(self, batch: dict[str, Tensor]) -> dict[str, Tensor]:
|
||||||
"""Perform augmentations and update data dict.
|
"""Perform augmentations and update data dict.
|
||||||
|
|
Загрузка…
Ссылка в новой задаче