Mirror of https://github.com/microsoft/torchgeo.git
BigEarthNet Trainers (#211)
* add additional bigearthnet test data for train/val/test split
* update bigearthnet dataset length test
* add MultiLabelClassificationTask
* add BigEarthNet trainer and datamodule
* add bigearthnet and multilabelclassificationtask tests
* mypy and format
* add estimated band min/max values for normalization
* softmax outputs to correctly compute metrics
* update min/max stats for 100k samples
* organize imports in torchgeo.trainers.__init__.py
* clean up fixtures in test_tasks.py
* added bigearthnet to train.py
* format
* move fixtures into class methods
* consolidate bigearthnet fixtures
* refactor tasks tests
* add scope=class
* style/mypy fixes
* mypy fixes
This commit is contained in:
Parent
b8f5a7ce64
Commit
3cc63def02
@@ -0,0 +1,18 @@
+trainer:
+  gpus: 1  # single GPU training
+  min_epochs: 10
+  max_epochs: 40
+  benchmark: True
+
+experiment:
+  task: "bigearthnet"
+  module:
+    loss: "bce"
+    classification_model: "resnet18"
+    learning_rate: 1e-3
+    learning_rate_schedule_patience: 6
+    in_channels: 14
+  datamodule:
+    batch_size: 128
+    num_workers: 6
+    bands: "all"
@@ -0,0 +1,13 @@
+experiment:
+  task: "bigearthnet"
+  module:
+    loss: "bce"
+    classification_model: "resnet18"
+    learning_rate: 1e-3
+    learning_rate_schedule_patience: 6
+    weights: "random"
+    in_channels: 14
+  datamodule:
+    batch_size: 128
+    num_workers: 6
+    bands: "all"
Binary data
tests/data/bigearthnet/BigEarthNet-S1-v1.0.tar.gz
Binary file not shown.
Binary data
tests/data/bigearthnet/BigEarthNet-S2-v1.0.tar.gz
Binary file not shown.
@@ -71,7 +71,7 @@ class TestBigEarthNet:
         assert x["image"].shape == (12, 120, 120)
 
     def test_len(self, dataset: BigEarthNet) -> None:
-        assert len(dataset) == 2
+        assert len(dataset) == 4
 
     def test_already_downloaded(self, dataset: BigEarthNet, tmp_path: Path) -> None:
         BigEarthNet(root=str(tmp_path), bands=dataset.bands, download=True)
@@ -0,0 +1,39 @@
+# Copyright (c) Microsoft Corporation. All rights reserved.
+# Licensed under the MIT License.
+
+import os
+
+import pytest
+from _pytest.fixtures import SubRequest
+
+from torchgeo.trainers import BigEarthNetDataModule
+
+
+class TestBigEarthNetDataModule:
+    @pytest.fixture(scope="class", params=zip(["s1", "s2", "all"], [True, True, False]))
+    def datamodule(self, request: SubRequest) -> BigEarthNetDataModule:
+        bands, unsupervised_mode = request.param
+        root = os.path.join("tests", "data", "bigearthnet")
+        batch_size = 1
+        num_workers = 0
+        dm = BigEarthNetDataModule(
+            root,
+            bands,
+            batch_size,
+            num_workers,
+            unsupervised_mode,
+            val_split_pct=0.3,
+            test_split_pct=0.3,
+        )
+        dm.prepare_data()
+        dm.setup()
+        return dm
+
+    def test_train_dataloader(self, datamodule: BigEarthNetDataModule) -> None:
+        next(iter(datamodule.train_dataloader()))
+
+    def test_val_dataloader(self, datamodule: BigEarthNetDataModule) -> None:
+        next(iter(datamodule.val_dataloader()))
+
+    def test_test_dataloader(self, datamodule: BigEarthNetDataModule) -> None:
+        next(iter(datamodule.test_dataloader()))
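A note on the fixture parametrization above: params=zip(...) pairs each band set with an unsupervised_mode flag, so the suite covers exactly three configurations. A hypothetical standalone illustration:

for bands, unsupervised_mode in zip(["s1", "s2", "all"], [True, True, False]):
    print(bands, unsupervised_mode)
# s1 True
# s2 True
# all False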
@@ -2,50 +2,114 @@
 # Licensed under the MIT License.
 
 import os
-from typing import Any, Dict, Generator, Tuple, cast
+from typing import Any, Dict, Generator, Optional, cast
 
 import pytest
+import pytorch_lightning as pl
+import torch
+import torch.nn.functional as F
 from _pytest.fixtures import SubRequest
 from _pytest.monkeypatch import MonkeyPatch
 from omegaconf import OmegaConf
+from torch import Tensor
+from torch.utils.data import DataLoader, Dataset, TensorDataset
 
 from torchgeo.trainers import (
     ClassificationTask,
     CycloneDataModule,
+    MultiLabelClassificationTask,
     RegressionTask,
-    So2SatDataModule,
 )
 
 from .test_utils import mocked_log
 
 
-@pytest.fixture(scope="module", params=[("rgb", 3), ("s2", 10)])
-def bands(request: SubRequest) -> Tuple[str, int]:
-    return cast(Tuple[str, int], request.param)
+class DummyDataset(Dataset):  # type: ignore[type-arg]
+    def __init__(self, num_channels: int, num_classes: int, multilabel: bool) -> None:
+        x = torch.randn(10, num_channels, 128, 128)  # (b, c, h, w)
+        y = torch.randint(  # type: ignore[attr-defined]
+            0, num_classes, size=(10,)
+        )  # (b,)
+
+        if multilabel:
+            y = F.one_hot(y, num_classes=num_classes)  # (b, classes)
+
+        self.dataset = TensorDataset(x, y)
+
+    def __len__(self) -> int:
+        return len(self.dataset)
+
+    def __getitem__(self, idx: int) -> Dict[str, Tensor]:
+        x, y = self.dataset[idx]
+        sample = {"image": x, "label": y}
+        return sample
 
 
-@pytest.fixture(scope="module", params=[True, False])
-def datamodule(bands: Tuple[str, int], request: SubRequest) -> So2SatDataModule:
-    band_set = bands[0]
-    unsupervised_mode = request.param
-    root = os.path.join("tests", "data", "so2sat")
-    batch_size = 2
-    num_workers = 0
-    dm = So2SatDataModule(root, batch_size, num_workers, band_set, unsupervised_mode)
+class DummyDataModule(pl.LightningDataModule):
+    def __init__(
+        self,
+        num_channels: int,
+        num_classes: int,
+        multilabel: bool,
+        batch_size: int = 1,
+        num_workers: int = 0,
+    ) -> None:
+        super().__init__()  # type: ignore[no-untyped-call]
+        self.num_channels = num_channels
+        self.num_classes = num_classes
+        self.multilabel = multilabel
+        self.batch_size = batch_size
+        self.num_workers = num_workers
+
+    def setup(self, stage: Optional[str] = None) -> None:
+        self.dataset = DummyDataset(
+            num_channels=self.num_channels,
+            num_classes=self.num_classes,
+            multilabel=self.multilabel,
+        )
+
+    def train_dataloader(self) -> DataLoader:  # type: ignore[type-arg]
+        return DataLoader(
+            self.dataset, batch_size=self.batch_size, num_workers=self.num_workers
+        )
+
+    def val_dataloader(self) -> DataLoader:  # type: ignore[type-arg]
+        return DataLoader(
+            self.dataset, batch_size=self.batch_size, num_workers=self.num_workers
+        )
+
+    def test_dataloader(self) -> DataLoader:  # type: ignore[type-arg]
+        return DataLoader(
+            self.dataset, batch_size=self.batch_size, num_workers=self.num_workers
+        )
+
+
+class TestClassificationTask:
+    @pytest.fixture(scope="class", params=[2, 3, 5])
+    def datamodule(self, request: SubRequest) -> DummyDataModule:
+        dm = DummyDataModule(
+            num_channels=request.param,
+            num_classes=45,
+            multilabel=False,
+            batch_size=2,
+            num_workers=0,
+        )
         dm.prepare_data()
         dm.setup()
         return dm
 
 
-class TestClassificationTask:
     @pytest.fixture(
-        params=zip(["ce", "jaccard", "focal"], ["imagenet", "random", "random"])
+        scope="class",
+        params=zip(["ce", "jaccard", "focal"], ["imagenet", "random", "random"]),
     )
-    def config(self, request: SubRequest, bands: Tuple[str, int]) -> Dict[str, Any]:
-        task_conf = OmegaConf.load(os.path.join("conf", "task_defaults", "so2sat.yaml"))
-        task_args = OmegaConf.to_object(task_conf.experiment.module)
-        task_args = cast(Dict[str, Any], task_args)
-        task_args["in_channels"] = bands[1]
+    def config(
+        self, request: SubRequest, datamodule: DummyDataModule
+    ) -> Dict[str, Any]:
+        task_args = {}
+        task_args["classification_model"] = "resnet18"
+        task_args["learning_rate"] = 3e-4  # type: ignore[assignment]
+        task_args["learning_rate_schedule_patience"] = 6  # type: ignore[assignment]
+        task_args["in_channels"] = datamodule.num_channels  # type: ignore[assignment]
         loss, weights = request.param
         task_args["loss"] = loss
         task_args["weights"] = weights
@@ -65,20 +129,20 @@ class TestClassificationTask:
         assert "lr_scheduler" in out
 
     def test_training(
-        self, datamodule: So2SatDataModule, task: ClassificationTask
+        self, datamodule: DummyDataModule, task: ClassificationTask
    ) -> None:
         batch = next(iter(datamodule.train_dataloader()))
         task.training_step(batch, 0)
         task.training_epoch_end(0)
 
     def test_validation(
-        self, datamodule: So2SatDataModule, task: ClassificationTask
+        self, datamodule: DummyDataModule, task: ClassificationTask
     ) -> None:
         batch = next(iter(datamodule.val_dataloader()))
         task.validation_step(batch, 0)
         task.validation_epoch_end(0)
 
-    def test_test(self, datamodule: So2SatDataModule, task: ClassificationTask) -> None:
+    def test_test(self, datamodule: DummyDataModule, task: ClassificationTask) -> None:
         batch = next(iter(datamodule.test_dataloader()))
         task.test_step(batch, 0)
         task.test_epoch_end(0)
@@ -99,6 +163,7 @@ class TestClassificationTask:
 
     def test_invalid_loss(self, config: Dict[str, Any]) -> None:
         config["loss"] = "invalid_loss"
+        config["classification_model"] = "resnet18"
         error_message = "Loss type 'invalid_loss' is not valid."
         with pytest.raises(ValueError, match=error_message):
             ClassificationTask(**config)
@@ -117,6 +182,68 @@ class TestClassificationTask:
             ClassificationTask(**config)
 
 
+class TestMultiLabelClassificationTask:
+    @pytest.fixture(scope="class")
+    def datamodule(self, request: SubRequest) -> DummyDataModule:
+        dm = DummyDataModule(
+            num_channels=3,
+            num_classes=43,
+            multilabel=True,
+            batch_size=2,
+            num_workers=0,
+        )
+        dm.prepare_data()
+        dm.setup()
+        return dm
+
+    @pytest.fixture(scope="class", params=zip(["bce", "bce"], ["imagenet", "random"]))
+    def config(
+        self, datamodule: DummyDataModule, request: SubRequest
+    ) -> Dict[str, Any]:
+        task_args = {}
+        task_args["classification_model"] = "resnet18"
+        task_args["learning_rate"] = 3e-4  # type: ignore[assignment]
+        task_args["learning_rate_schedule_patience"] = 6  # type: ignore[assignment]
+        task_args["in_channels"] = datamodule.num_channels  # type: ignore[assignment]
+        loss, weights = request.param
+        task_args["loss"] = loss
+        task_args["weights"] = weights
+        return task_args
+
+    @pytest.fixture
+    def task(
+        self, config: Dict[str, Any], monkeypatch: Generator[MonkeyPatch, None, None]
+    ) -> MultiLabelClassificationTask:
+        task = MultiLabelClassificationTask(**config)
+        monkeypatch.setattr(task, "log", mocked_log)  # type: ignore[attr-defined]
+        return task
+
+    def test_training(
+        self, datamodule: DummyDataModule, task: ClassificationTask
+    ) -> None:
+        batch = next(iter(datamodule.train_dataloader()))
+        task.training_step(batch, 0)
+        task.training_epoch_end(0)
+
+    def test_validation(
+        self, datamodule: DummyDataModule, task: ClassificationTask
+    ) -> None:
+        batch = next(iter(datamodule.val_dataloader()))
+        task.validation_step(batch, 0)
+        task.validation_epoch_end(0)
+
+    def test_test(self, datamodule: DummyDataModule, task: ClassificationTask) -> None:
+        batch = next(iter(datamodule.test_dataloader()))
+        task.test_step(batch, 0)
+        task.test_epoch_end(0)
+
+    def test_invalid_loss(self, config: Dict[str, Any]) -> None:
+        config["loss"] = "invalid_loss"
+        error_message = "Loss type 'invalid_loss' is not valid."
+        with pytest.raises(ValueError, match=error_message):
+            MultiLabelClassificationTask(**config)
+
+
 class TestRegressionTask:
     @pytest.fixture(scope="class")
     def datamodule(self) -> CycloneDataModule:
@@ -3,6 +3,7 @@
 
 """TorchGeo trainers."""
 
+from .bigearthnet import BigEarthNetClassificationTask, BigEarthNetDataModule
 from .byol import BYOLTask
 from .chesapeake import ChesapeakeCVPRDataModule, ChesapeakeCVPRSegmentationTask
 from .cyclone import CycloneDataModule
@@ -11,29 +12,35 @@ from .naipchesapeake import NAIPChesapeakeDataModule, NAIPChesapeakeSegmentation
 from .resisc45 import RESISC45ClassificationTask, RESISC45DataModule
 from .sen12ms import SEN12MSDataModule, SEN12MSSegmentationTask
 from .so2sat import So2SatClassificationTask, So2SatDataModule
-from .tasks import ClassificationTask, RegressionTask
+from .tasks import ClassificationTask, MultiLabelClassificationTask, RegressionTask
 from .ucmerced import UCMercedClassificationTask, UCMercedDataModule
 
 __all__ = (
     # Tasks
-    "ClassificationTask",
-    "RegressionTask",
-    # Trainers
+    "BigEarthNetClassificationTask",
     "BYOLTask",
     "ChesapeakeCVPRSegmentationTask",
     "ChesapeakeCVPRDataModule",
+    "ClassificationTask",
     "CycloneDataModule",
     "LandcoverAIDataModule",
     "LandcoverAISegmentationTask",
-    "NAIPChesapeakeDataModule",
+    "MultiLabelClassificationTask",
     "NAIPChesapeakeSegmentationTask",
     "RESISC45ClassificationTask",
-    "RESISC45DataModule",
+    "RegressionTask",
-    "SEN12MSDataModule",
     "SEN12MSSegmentationTask",
-    "So2SatDataModule",
     "So2SatClassificationTask",
     "UCMercedClassificationTask",
+    # DataModules
+    "BigEarthNetDataModule",
+    "ChesapeakeCVPRDataModule",
+    "CycloneDataModule",
+    "LandcoverAIDataModule",
+    "NAIPChesapeakeDataModule",
+    "RESISC45DataModule",
+    "SEN12MSDataModule",
+    "So2SatDataModule",
     "UCMercedDataModule",
 )
@@ -0,0 +1,193 @@
+# Copyright (c) Microsoft Corporation. All rights reserved.
+# Licensed under the MIT License.
+
+"""BigEarthNet trainer."""
+
+from typing import Any, Dict, Optional
+
+import pytorch_lightning as pl
+import torch
+from torch.utils.data import DataLoader
+from torchvision.transforms import Compose
+
+from ..datasets import BigEarthNet
+from ..datasets.utils import dataset_split
+from .tasks import MultiLabelClassificationTask
+
+# https://github.com/pytorch/pytorch/issues/60979
+# https://github.com/pytorch/pytorch/pull/61045
+DataLoader.__module__ = "torch.utils.data"
+
+
+class BigEarthNetClassificationTask(MultiLabelClassificationTask):
+    """LightningModule for training models on the BigEarthNet Dataset."""
+
+    num_classes = 43
+
+
+class BigEarthNetDataModule(pl.LightningDataModule):
+    """LightningDataModule implementation for the BigEarthNet dataset.
+
+    Uses the train/val/test splits from the dataset.
+    """
+
+    # (VV, VH, B01, B02, B03, B04, B05, B06, B07, B08, B8A, B09, B11, B12)
+    # min/max band statistics computed on 100k random samples
+    band_mins_raw = torch.tensor(  # type: ignore[attr-defined]
+        [-70.0, -72.0, 1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 1.0, 0.0, 0.0]
+    )
+    band_maxs_raw = torch.tensor(  # type: ignore[attr-defined]
+        [
+            31.0,
+            35.0,
+            18556.0,
+            20528.0,
+            18976.0,
+            17874.0,
+            16611.0,
+            16512.0,
+            16394.0,
+            16672.0,
+            16141.0,
+            16097.0,
+            15336.0,
+            15203.0,
+        ]
+    )
+
+    # min/max band statistics computed by percentile clipping the
+    # above samples to [2, 98]
+    band_mins = torch.tensor(  # type: ignore[attr-defined]
+        [-48.0, -42.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0]
+    )
+    band_maxs = torch.tensor(  # type: ignore[attr-defined]
+        [
+            6.0,
+            16.0,
+            9859.0,
+            12872.0,
+            13163.0,
+            14445.0,
+            12477.0,
+            12563.0,
+            12289.0,
+            15596.0,
+            12183.0,
+            9458.0,
+            5897.0,
+            5544.0,
+        ]
+    )
+
+    def __init__(
+        self,
+        root_dir: str,
+        bands: str = "all",
+        batch_size: int = 64,
+        num_workers: int = 4,
+        unsupervised_mode: bool = False,
+        val_split_pct: float = 0.2,
+        test_split_pct: float = 0.2,
+        **kwargs: Any,
+    ) -> None:
+        """Initialize a LightningDataModule for BigEarthNet based DataLoaders.
+
+        Args:
+            root_dir: The ``root`` argument to pass to the BigEarthNet Dataset classes
+            bands: load Sentinel-1 bands, Sentinel-2, or both. one of {s1, s2, all}
+            batch_size: The batch size to use in all created DataLoaders
+            num_workers: The number of workers to use in all created DataLoaders
+            unsupervised_mode: Makes the train dataloader return imagery from the
+                train, val, and test sets
+            val_split_pct: What percentage of the dataset to use as a validation set
+            test_split_pct: What percentage of the dataset to use as a test set
+        """
+        super().__init__()  # type: ignore[no-untyped-call]
+        self.root_dir = root_dir
+        self.bands = bands
+        self.batch_size = batch_size
+        self.num_workers = num_workers
+        self.unsupervised_mode = unsupervised_mode
+
+        self.val_split_pct = val_split_pct
+        self.test_split_pct = test_split_pct
+
+        if bands == "all":
+            self.mins = self.band_mins[:, None, None]
+            self.maxs = self.band_maxs[:, None, None]
+        elif bands == "s1":
+            self.mins = self.band_mins[:2, None, None]
+            self.maxs = self.band_maxs[:2, None, None]
+        else:
+            self.mins = self.band_mins[2:, None, None]
+            self.maxs = self.band_maxs[2:, None, None]
+
+    def preprocess(self, sample: Dict[str, Any]) -> Dict[str, Any]:
+        """Transform a single sample from the Dataset."""
+        sample["image"] = sample["image"].float()
+        sample["image"] = (sample["image"] - self.mins) / (self.maxs - self.mins)
+        sample["image"] = torch.clip(  # type: ignore[attr-defined]
+            sample["image"], min=0.0, max=1.0
+        )
+        return sample
+
+    def prepare_data(self) -> None:
+        """Make sure that the dataset is downloaded.
+
+        This method is only called once per run.
+        """
+        BigEarthNet(self.root_dir, bands=self.bands, checksum=False)
+
+    def setup(self, stage: Optional[str] = None) -> None:
+        """Initialize the main ``Dataset`` objects.
+
+        This method is called once per GPU per run.
+        """
+        transforms = Compose([self.preprocess])
+
+        if not self.unsupervised_mode:
+            dataset = BigEarthNet(
+                self.root_dir, bands=self.bands, transforms=transforms
+            )
+            self.train_dataset, self.val_dataset, self.test_dataset = dataset_split(
+                dataset, val_pct=self.val_split_pct, test_pct=self.test_split_pct
+            )
+        else:
+            self.train_dataset = BigEarthNet(  # type: ignore[assignment]
+                self.root_dir, bands=self.bands, transforms=transforms
+            )
+            self.val_dataset, self.test_dataset = None, None  # type: ignore[assignment]
+
+    def train_dataloader(self) -> DataLoader[Any]:
+        """Return a DataLoader for training."""
+        return DataLoader(
+            self.train_dataset,
+            batch_size=self.batch_size,
+            num_workers=self.num_workers,
+            shuffle=True,
+        )
+
+    def val_dataloader(self) -> DataLoader[Any]:
+        """Return a DataLoader for validation."""
+        if self.unsupervised_mode or self.val_split_pct == 0:
+            return self.train_dataloader()
+        else:
+            return DataLoader(
+                self.val_dataset,
+                batch_size=self.batch_size,
+                num_workers=self.num_workers,
+                shuffle=False,
+            )
+
+    def test_dataloader(self) -> DataLoader[Any]:
+        """Return a DataLoader for testing."""
+        if self.unsupervised_mode or self.test_split_pct == 0:
+            return self.train_dataloader()
+        else:
+            return DataLoader(
+                self.test_dataset,
+                batch_size=self.batch_size,
+                num_workers=self.num_workers,
+                shuffle=False,
+            )
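A minimal end-to-end sketch (not from the commit) of how the new task and datamodule pair with a stock PyTorch Lightning Trainer; the root_dir path and hyperparameter values are assumptions for illustration:

import pytorch_lightning as pl

from torchgeo.trainers import BigEarthNetClassificationTask, BigEarthNetDataModule

datamodule = BigEarthNetDataModule(root_dir="data/bigearthnet", bands="all")  # path assumed
task = BigEarthNetClassificationTask(
    classification_model="resnet18",
    loss="bce",
    weights="random",
    in_channels=14,  # 2 Sentinel-1 + 12 Sentinel-2 bands
    learning_rate=1e-3,
    learning_rate_schedule_patience=6,
)
trainer = pl.Trainer(fast_dev_run=True)  # single-batch smoke run
trainer.fit(model=task, datamodule=datamodule)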
@@ -57,7 +57,7 @@ class ClassificationTask(pl.LightningModule):
 
         # Update first layer
         if in_channels != 3:
-            w_old = None
+            w_old = torch.empty(0)  # type: ignore[attr-defined]
             if pretrained:
                 w_old = torch.clone(  # type: ignore[attr-defined]
                     self.model.conv1.weight
@@ -75,7 +75,11 @@ class ClassificationTask(pl.LightningModule):
                 w_new = torch.clone(  # type: ignore[attr-defined]
                     self.model.conv1.weight
                 ).detach()
+                if in_channels > 3:
                     w_new[:, :3, :, :] = w_old
+                else:
+                    w_old = w_old[:, :in_channels, :, :]
+                    w_new[:, :in_channels, :, :] = w_old
                 self.model.conv1.weight = nn.Parameter(  # type: ignore[attr-defined]  # noqa: E501
                     w_new
                 )
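The new branch above handles models with fewer than three input channels by slicing the pretrained RGB kernels rather than copying all three. A hypothetical tensor-level illustration of both branches:

import torch

out_ch, k = 64, 7
w_old = torch.randn(out_ch, 3, k, k)  # stand-in for pretrained RGB conv1 weights

for in_channels in (1, 14):
    w_new = torch.randn(out_ch, in_channels, k, k)  # freshly initialized first layer
    if in_channels > 3:
        w_new[:, :3, :, :] = w_old  # reuse all RGB kernels, keep extras random
    else:
        w_new[:, :in_channels, :, :] = w_old[:, :in_channels, :, :]  # slice RGB kernels
    print(in_channels, tuple(w_new.shape))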
@@ -266,6 +270,120 @@ class ClassificationTask(pl.LightningModule):
         }
 
 
+class MultiLabelClassificationTask(ClassificationTask):
+    """Abstract base class for multi label image classification LightningModules."""
+
+    #: number of classes in dataset
+    num_classes: int = 43
+
+    def config_task(self) -> None:
+        """Configures the task based on kwargs parameters passed to the constructor."""
+        self.config_model()
+
+        if self.hparams["loss"] == "bce":
+            self.loss = nn.BCEWithLogitsLoss()  # type: ignore[attr-defined]
+        else:
+            raise ValueError(f"Loss type '{self.hparams['loss']}' is not valid.")
+
+    def __init__(self, **kwargs: Any) -> None:
+        """Initialize the LightningModule with a model and loss function.
+
+        Keyword Args:
+            classification_model: Name of the classification model to use
+            loss: Name of the loss function
+            weights: Either "random", "imagenet_only", "imagenet_and_random", or
+                "random_rgb"
+        """
+        super().__init__(**kwargs)
+        self.save_hyperparameters()  # creates `self.hparams` from kwargs
+
+        self.config_task()
+
+        self.train_metrics = MetricCollection(
+            {
+                "OverallAccuracy": Accuracy(
+                    num_classes=self.num_classes, average="micro", multiclass=False
+                ),
+                "AverageAccuracy": Accuracy(
+                    num_classes=self.num_classes, average="macro", multiclass=False
+                ),
+                "F1Score": FBeta(
+                    num_classes=self.num_classes,
+                    beta=1.0,
+                    average="micro",
+                    multiclass=False,
+                ),
+            },
+            prefix="train_",
+        )
+        self.val_metrics = self.train_metrics.clone(prefix="val_")
+        self.test_metrics = self.train_metrics.clone(prefix="test_")
+
+    def training_step(  # type: ignore[override]
+        self, batch: Dict[str, Any], batch_idx: int
+    ) -> Tensor:
+        """Training step.
+
+        Args:
+            batch: Current batch
+            batch_idx: Index of current batch
+
+        Returns:
+            training loss
+        """
+        x = batch["image"]
+        y = batch["label"]
+        y_hat = self.forward(x)
+        y_hat_hard = torch.softmax(y_hat, dim=-1)  # type: ignore[attr-defined]
+
+        loss = self.loss(y_hat, y.to(torch.float))  # type: ignore[attr-defined]
+
+        # by default, the train step logs every `log_every_n_steps` steps where
+        # `log_every_n_steps` is a parameter to the `Trainer` object
+        self.log("train_loss", loss, on_step=True, on_epoch=False)
+        self.train_metrics(y_hat_hard, y)
+
+        return cast(Tensor, loss)
+
+    def validation_step(  # type: ignore[override]
+        self, batch: Dict[str, Any], batch_idx: int
+    ) -> None:
+        """Validation step.
+
+        Args:
+            batch: Current batch
+            batch_idx: Index of current batch
+        """
+        x = batch["image"]
+        y = batch["label"]
+        y_hat = self.forward(x)
+        y_hat_hard = torch.softmax(y_hat, dim=-1)  # type: ignore[attr-defined]
+
+        loss = self.loss(y_hat, y.to(torch.float))  # type: ignore[attr-defined]
+
+        self.log("val_loss", loss, on_step=False, on_epoch=True)
+        self.val_metrics(y_hat_hard, y)
+
+    def test_step(  # type: ignore[override]
+        self, batch: Dict[str, Any], batch_idx: int
+    ) -> None:
+        """Test step.
+
+        Args:
+            batch: Current batch
+            batch_idx: Index of current batch
+        """
+        x = batch["image"]
+        y = batch["label"]
+        y_hat = self.forward(x)
+        y_hat_hard = torch.softmax(y_hat, dim=-1)  # type: ignore[attr-defined]
+
+        loss = self.loss(y_hat, y.to(torch.float))  # type: ignore[attr-defined]
+
+        # by default, the test and validation steps only log per *epoch*
+        self.log("test_loss", loss, on_step=False, on_epoch=True)
+        self.test_metrics(y_hat_hard, y)
+
+
 class RegressionTask(pl.LightningModule):
     """LightningModule for training models on regression datasets."""
 
train.py (3 changes)

@@ -14,6 +14,8 @@ from pytorch_lightning import loggers as pl_loggers
 from pytorch_lightning.callbacks import EarlyStopping, ModelCheckpoint
 
 from torchgeo.trainers import (
+    BigEarthNetClassificationTask,
+    BigEarthNetDataModule,
     BYOLTask,
     ChesapeakeCVPRDataModule,
     ChesapeakeCVPRSegmentationTask,
@@ -36,6 +38,7 @@ from torchgeo.trainers import (
 TASK_TO_MODULES_MAPPING: Dict[
     str, Tuple[Type[pl.LightningModule], Type[pl.LightningDataModule]]
 ] = {
+    "bigearthnet": (BigEarthNetClassificationTask, BigEarthNetDataModule),
     "byol": (BYOLTask, ChesapeakeCVPRDataModule),
     "chesapeake_cvpr": (ChesapeakeCVPRSegmentationTask, ChesapeakeCVPRDataModule),
     "cyclone": (RegressionTask, CycloneDataModule),
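With the mapping extended, train.py resolves the task name from the experiment config to the matching classes; schematically (hypothetical snippet, not from the commit):

task_class, datamodule_class = TASK_TO_MODULES_MAPPING["bigearthnet"]
# -> (BigEarthNetClassificationTask, BigEarthNetDataModule)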