зеркало из https://github.com/microsoft/torchgeo.git
EuroSATDataModule: set mean/std based on bands (#1681)
* Use dicts to generate mean and std * correctly pass bands * black format * Remove unused import * Simplify * Import all bands
This commit is contained in:
Родитель
1b48cb0ef8
Коммит
a6a11fd5b8
|
@ -10,41 +10,37 @@ import torch
|
|||
from ..datasets import EuroSAT, EuroSAT100
|
||||
from .geo import NonGeoDataModule
|
||||
|
||||
MEAN = torch.tensor(
|
||||
[
|
||||
1354.40546513,
|
||||
1118.24399958,
|
||||
1042.92983953,
|
||||
947.62620298,
|
||||
1199.47283961,
|
||||
1999.79090914,
|
||||
2369.22292565,
|
||||
2296.82608323,
|
||||
732.08340178,
|
||||
12.11327804,
|
||||
1819.01027855,
|
||||
1118.92391149,
|
||||
2594.14080798,
|
||||
]
|
||||
)
|
||||
MEAN = {
|
||||
"B01": 1354.40546513,
|
||||
"B02": 1118.24399958,
|
||||
"B03": 1042.92983953,
|
||||
"B04": 947.62620298,
|
||||
"B05": 1199.47283961,
|
||||
"B06": 1999.79090914,
|
||||
"B07": 2369.22292565,
|
||||
"B08": 2296.82608323,
|
||||
"B8A": 732.08340178,
|
||||
"B09": 12.11327804,
|
||||
"B10": 1819.01027855,
|
||||
"B11": 1118.92391149,
|
||||
"B12": 2594.14080798,
|
||||
}
|
||||
|
||||
STD = torch.tensor(
|
||||
[
|
||||
245.71762908,
|
||||
333.00778264,
|
||||
395.09249139,
|
||||
593.75055589,
|
||||
566.4170017,
|
||||
861.18399006,
|
||||
1086.63139075,
|
||||
1117.98170791,
|
||||
404.91978886,
|
||||
4.77584468,
|
||||
1002.58768311,
|
||||
761.30323499,
|
||||
1231.58581042,
|
||||
]
|
||||
)
|
||||
STD = {
|
||||
"B01": 245.71762908,
|
||||
"B02": 333.00778264,
|
||||
"B03": 395.09249139,
|
||||
"B04": 593.75055589,
|
||||
"B05": 566.4170017,
|
||||
"B06": 861.18399006,
|
||||
"B07": 1086.63139075,
|
||||
"B08": 1117.98170791,
|
||||
"B8A": 404.91978886,
|
||||
"B09": 4.77584468,
|
||||
"B10": 1002.58768311,
|
||||
"B11": 761.30323499,
|
||||
"B12": 1231.58581042,
|
||||
}
|
||||
|
||||
|
||||
class EuroSATDataModule(NonGeoDataModule):
|
||||
|
@ -55,9 +51,6 @@ class EuroSATDataModule(NonGeoDataModule):
|
|||
.. versionadded:: 0.2
|
||||
"""
|
||||
|
||||
mean = MEAN
|
||||
std = STD
|
||||
|
||||
def __init__(
|
||||
self, batch_size: int = 64, num_workers: int = 0, **kwargs: Any
|
||||
) -> None:
|
||||
|
@ -71,6 +64,10 @@ class EuroSATDataModule(NonGeoDataModule):
|
|||
"""
|
||||
super().__init__(EuroSAT, batch_size, num_workers, **kwargs)
|
||||
|
||||
bands = kwargs.get("bands", EuroSAT.all_band_names)
|
||||
self.mean = torch.tensor([MEAN[b] for b in bands])
|
||||
self.std = torch.tensor([STD[b] for b in bands])
|
||||
|
||||
|
||||
class EuroSAT100DataModule(NonGeoDataModule):
|
||||
"""LightningDataModule implementation for the EuroSAT100 dataset.
|
||||
|
@ -80,9 +77,6 @@ class EuroSAT100DataModule(NonGeoDataModule):
|
|||
.. versionadded:: 0.5
|
||||
"""
|
||||
|
||||
mean = MEAN
|
||||
std = STD
|
||||
|
||||
def __init__(
|
||||
self, batch_size: int = 64, num_workers: int = 0, **kwargs: Any
|
||||
) -> None:
|
||||
|
@ -95,3 +89,7 @@ class EuroSAT100DataModule(NonGeoDataModule):
|
|||
:class:`~torchgeo.datasets.EuroSAT100`.
|
||||
"""
|
||||
super().__init__(EuroSAT100, batch_size, num_workers, **kwargs)
|
||||
|
||||
bands = kwargs.get("bands", EuroSAT.all_band_names)
|
||||
self.mean = torch.tensor([MEAN[b] for b in bands])
|
||||
self.std = torch.tensor([STD[b] for b in bands])
|
||||
|
|
Загрузка…
Ссылка в новой задаче