EuroSATDataModule: set mean/std based on bands (#1681)

* Use dicts to generate mean and std

* correctly pass bands

* black format

* Remove unused import

* Simplify

* Import all bands
This commit is contained in:
Robin Cole 2023-10-20 10:40:40 +01:00 коммит произвёл Nils Lehmann
Родитель 1b48cb0ef8
Коммит a6a11fd5b8
1 изменённых файлов: 38 добавлений и 40 удалений

Просмотреть файл

@ -10,41 +10,37 @@ import torch
from ..datasets import EuroSAT, EuroSAT100
from .geo import NonGeoDataModule
MEAN = torch.tensor(
[
1354.40546513,
1118.24399958,
1042.92983953,
947.62620298,
1199.47283961,
1999.79090914,
2369.22292565,
2296.82608323,
732.08340178,
12.11327804,
1819.01027855,
1118.92391149,
2594.14080798,
]
)
MEAN = {
"B01": 1354.40546513,
"B02": 1118.24399958,
"B03": 1042.92983953,
"B04": 947.62620298,
"B05": 1199.47283961,
"B06": 1999.79090914,
"B07": 2369.22292565,
"B08": 2296.82608323,
"B8A": 732.08340178,
"B09": 12.11327804,
"B10": 1819.01027855,
"B11": 1118.92391149,
"B12": 2594.14080798,
}
STD = torch.tensor(
[
245.71762908,
333.00778264,
395.09249139,
593.75055589,
566.4170017,
861.18399006,
1086.63139075,
1117.98170791,
404.91978886,
4.77584468,
1002.58768311,
761.30323499,
1231.58581042,
]
)
STD = {
"B01": 245.71762908,
"B02": 333.00778264,
"B03": 395.09249139,
"B04": 593.75055589,
"B05": 566.4170017,
"B06": 861.18399006,
"B07": 1086.63139075,
"B08": 1117.98170791,
"B8A": 404.91978886,
"B09": 4.77584468,
"B10": 1002.58768311,
"B11": 761.30323499,
"B12": 1231.58581042,
}
class EuroSATDataModule(NonGeoDataModule):
@ -55,9 +51,6 @@ class EuroSATDataModule(NonGeoDataModule):
.. versionadded:: 0.2
"""
mean = MEAN
std = STD
def __init__(
self, batch_size: int = 64, num_workers: int = 0, **kwargs: Any
) -> None:
@ -71,6 +64,10 @@ class EuroSATDataModule(NonGeoDataModule):
"""
super().__init__(EuroSAT, batch_size, num_workers, **kwargs)
bands = kwargs.get("bands", EuroSAT.all_band_names)
self.mean = torch.tensor([MEAN[b] for b in bands])
self.std = torch.tensor([STD[b] for b in bands])
class EuroSAT100DataModule(NonGeoDataModule):
"""LightningDataModule implementation for the EuroSAT100 dataset.
@ -80,9 +77,6 @@ class EuroSAT100DataModule(NonGeoDataModule):
.. versionadded:: 0.5
"""
mean = MEAN
std = STD
def __init__(
self, batch_size: int = 64, num_workers: int = 0, **kwargs: Any
) -> None:
@ -95,3 +89,7 @@ class EuroSAT100DataModule(NonGeoDataModule):
:class:`~torchgeo.datasets.EuroSAT100`.
"""
super().__init__(EuroSAT100, batch_size, num_workers, **kwargs)
bands = kwargs.get("bands", EuroSAT.all_band_names)
self.mean = torch.tensor([MEAN[b] for b in bands])
self.std = torch.tensor([STD[b] for b in bands])