зеркало из https://github.com/microsoft/torchgeo.git
BigEarthNet Trainers (#211)
* add additional bigearthnet test data for train/val/test split * update bigearthnet dataset length test * add MultiLabelClassificationTask * add BigEarthNet trainer and datamodule * add bigearthnet and multilabelclassificationtask tests * mypy and format * add estimated band min/max values for normalization * softmax outputs to correctly compute metrics * update min/max stats for 100k samples * organize imports in torchgeo.trainers.__init__.py * clean up fixtures in test_tasks.py * added bigearthnet to train.py * format * move fixtures into class methods * consolidate bigearthnet fixtures * refactor tasks tests * add scope=class * style/mypy fixes * mypy fixes
This commit is contained in:
Родитель
b8f5a7ce64
Коммит
3cc63def02
|
@ -0,0 +1,18 @@
|
|||
trainer:
|
||||
gpus: 1 # single GPU training
|
||||
min_epochs: 10
|
||||
max_epochs: 40
|
||||
benchmark: True
|
||||
|
||||
experiment:
|
||||
task: "bigearthnet"
|
||||
module:
|
||||
loss: "bce"
|
||||
classification_model: "resnet18"
|
||||
learning_rate: 1e-3
|
||||
learning_rate_schedule_patience: 6
|
||||
in_channels: 14
|
||||
datamodule:
|
||||
batch_size: 128
|
||||
num_workers: 6
|
||||
bands: "all"
|
|
@ -0,0 +1,13 @@
|
|||
experiment:
|
||||
task: "bigearthnet"
|
||||
module:
|
||||
loss: "bce"
|
||||
classification_model: "resnet18"
|
||||
learning_rate: 1e-3
|
||||
learning_rate_schedule_patience: 6
|
||||
weights: "random"
|
||||
in_channels: 14
|
||||
datamodule:
|
||||
batch_size: 128
|
||||
num_workers: 6
|
||||
bands: "all"
|
Двоичные данные
tests/data/bigearthnet/BigEarthNet-S1-v1.0.tar.gz
Двоичные данные
tests/data/bigearthnet/BigEarthNet-S1-v1.0.tar.gz
Двоичный файл не отображается.
Двоичные данные
tests/data/bigearthnet/BigEarthNet-S2-v1.0.tar.gz
Двоичные данные
tests/data/bigearthnet/BigEarthNet-S2-v1.0.tar.gz
Двоичный файл не отображается.
|
@ -71,7 +71,7 @@ class TestBigEarthNet:
|
|||
assert x["image"].shape == (12, 120, 120)
|
||||
|
||||
def test_len(self, dataset: BigEarthNet) -> None:
|
||||
assert len(dataset) == 2
|
||||
assert len(dataset) == 4
|
||||
|
||||
def test_already_downloaded(self, dataset: BigEarthNet, tmp_path: Path) -> None:
|
||||
BigEarthNet(root=str(tmp_path), bands=dataset.bands, download=True)
|
||||
|
|
|
@ -0,0 +1,39 @@
|
|||
# Copyright (c) Microsoft Corporation. All rights reserved.
|
||||
# Licensed under the MIT License.
|
||||
|
||||
import os
|
||||
|
||||
import pytest
|
||||
from _pytest.fixtures import SubRequest
|
||||
|
||||
from torchgeo.trainers import BigEarthNetDataModule
|
||||
|
||||
|
||||
class TestBigEarthNetDataModule:
|
||||
@pytest.fixture(scope="class", params=zip(["s1", "s2", "all"], [True, True, False]))
|
||||
def datamodule(self, request: SubRequest) -> BigEarthNetDataModule:
|
||||
bands, unsupervised_mode = request.param
|
||||
root = os.path.join("tests", "data", "bigearthnet")
|
||||
batch_size = 1
|
||||
num_workers = 0
|
||||
dm = BigEarthNetDataModule(
|
||||
root,
|
||||
bands,
|
||||
batch_size,
|
||||
num_workers,
|
||||
unsupervised_mode,
|
||||
val_split_pct=0.3,
|
||||
test_split_pct=0.3,
|
||||
)
|
||||
dm.prepare_data()
|
||||
dm.setup()
|
||||
return dm
|
||||
|
||||
def test_train_dataloader(self, datamodule: BigEarthNetDataModule) -> None:
|
||||
next(iter(datamodule.train_dataloader()))
|
||||
|
||||
def test_val_dataloader(self, datamodule: BigEarthNetDataModule) -> None:
|
||||
next(iter(datamodule.val_dataloader()))
|
||||
|
||||
def test_test_dataloader(self, datamodule: BigEarthNetDataModule) -> None:
|
||||
next(iter(datamodule.test_dataloader()))
|
|
@ -2,50 +2,114 @@
|
|||
# Licensed under the MIT License.
|
||||
|
||||
import os
|
||||
from typing import Any, Dict, Generator, Tuple, cast
|
||||
from typing import Any, Dict, Generator, Optional, cast
|
||||
|
||||
import pytest
|
||||
import pytorch_lightning as pl
|
||||
import torch
|
||||
import torch.nn.functional as F
|
||||
from _pytest.fixtures import SubRequest
|
||||
from _pytest.monkeypatch import MonkeyPatch
|
||||
from omegaconf import OmegaConf
|
||||
from torch import Tensor
|
||||
from torch.utils.data import DataLoader, Dataset, TensorDataset
|
||||
|
||||
from torchgeo.trainers import (
|
||||
ClassificationTask,
|
||||
CycloneDataModule,
|
||||
MultiLabelClassificationTask,
|
||||
RegressionTask,
|
||||
So2SatDataModule,
|
||||
)
|
||||
|
||||
from .test_utils import mocked_log
|
||||
|
||||
|
||||
@pytest.fixture(scope="module", params=[("rgb", 3), ("s2", 10)])
|
||||
def bands(request: SubRequest) -> Tuple[str, int]:
|
||||
return cast(Tuple[str, int], request.param)
|
||||
class DummyDataset(Dataset): # type: ignore[type-arg]
|
||||
def __init__(self, num_channels: int, num_classes: int, multilabel: bool) -> None:
|
||||
x = torch.randn(10, num_channels, 128, 128) # (b, c, h, w)
|
||||
y = torch.randint( # type: ignore[attr-defined]
|
||||
0, num_classes, size=(10,)
|
||||
) # (b,)
|
||||
|
||||
if multilabel:
|
||||
y = F.one_hot(y, num_classes=num_classes) # (b, classes)
|
||||
|
||||
self.dataset = TensorDataset(x, y)
|
||||
|
||||
def __len__(self) -> int:
|
||||
return len(self.dataset)
|
||||
|
||||
def __getitem__(self, idx: int) -> Dict[str, Tensor]:
|
||||
x, y = self.dataset[idx]
|
||||
sample = {"image": x, "label": y}
|
||||
return sample
|
||||
|
||||
|
||||
@pytest.fixture(scope="module", params=[True, False])
|
||||
def datamodule(bands: Tuple[str, int], request: SubRequest) -> So2SatDataModule:
|
||||
band_set = bands[0]
|
||||
unsupervised_mode = request.param
|
||||
root = os.path.join("tests", "data", "so2sat")
|
||||
batch_size = 2
|
||||
num_workers = 0
|
||||
dm = So2SatDataModule(root, batch_size, num_workers, band_set, unsupervised_mode)
|
||||
dm.prepare_data()
|
||||
dm.setup()
|
||||
return dm
|
||||
class DummyDataModule(pl.LightningDataModule):
|
||||
def __init__(
|
||||
self,
|
||||
num_channels: int,
|
||||
num_classes: int,
|
||||
multilabel: bool,
|
||||
batch_size: int = 1,
|
||||
num_workers: int = 0,
|
||||
) -> None:
|
||||
super().__init__() # type: ignore[no-untyped-call]
|
||||
self.num_channels = num_channels
|
||||
self.num_classes = num_classes
|
||||
self.multilabel = multilabel
|
||||
self.batch_size = batch_size
|
||||
self.num_workers = num_workers
|
||||
|
||||
def setup(self, stage: Optional[str] = None) -> None:
|
||||
self.dataset = DummyDataset(
|
||||
num_channels=self.num_channels,
|
||||
num_classes=self.num_classes,
|
||||
multilabel=self.multilabel,
|
||||
)
|
||||
|
||||
def train_dataloader(self) -> DataLoader: # type: ignore[type-arg]
|
||||
return DataLoader(
|
||||
self.dataset, batch_size=self.batch_size, num_workers=self.num_workers
|
||||
)
|
||||
|
||||
def val_dataloader(self) -> DataLoader: # type: ignore[type-arg]
|
||||
return DataLoader(
|
||||
self.dataset, batch_size=self.batch_size, num_workers=self.num_workers
|
||||
)
|
||||
|
||||
def test_dataloader(self) -> DataLoader: # type: ignore[type-arg]
|
||||
return DataLoader(
|
||||
self.dataset, batch_size=self.batch_size, num_workers=self.num_workers
|
||||
)
|
||||
|
||||
|
||||
class TestClassificationTask:
|
||||
@pytest.fixture(scope="class", params=[2, 3, 5])
|
||||
def datamodule(self, request: SubRequest) -> DummyDataModule:
|
||||
dm = DummyDataModule(
|
||||
num_channels=request.param,
|
||||
num_classes=45,
|
||||
multilabel=False,
|
||||
batch_size=2,
|
||||
num_workers=0,
|
||||
)
|
||||
dm.prepare_data()
|
||||
dm.setup()
|
||||
return dm
|
||||
|
||||
@pytest.fixture(
|
||||
params=zip(["ce", "jaccard", "focal"], ["imagenet", "random", "random"])
|
||||
scope="class",
|
||||
params=zip(["ce", "jaccard", "focal"], ["imagenet", "random", "random"]),
|
||||
)
|
||||
def config(self, request: SubRequest, bands: Tuple[str, int]) -> Dict[str, Any]:
|
||||
task_conf = OmegaConf.load(os.path.join("conf", "task_defaults", "so2sat.yaml"))
|
||||
task_args = OmegaConf.to_object(task_conf.experiment.module)
|
||||
task_args = cast(Dict[str, Any], task_args)
|
||||
task_args["in_channels"] = bands[1]
|
||||
def config(
|
||||
self, request: SubRequest, datamodule: DummyDataModule
|
||||
) -> Dict[str, Any]:
|
||||
task_args = {}
|
||||
task_args["classification_model"] = "resnet18"
|
||||
task_args["learning_rate"] = 3e-4 # type: ignore[assignment]
|
||||
task_args["learning_rate_schedule_patience"] = 6 # type: ignore[assignment]
|
||||
task_args["in_channels"] = datamodule.num_channels # type: ignore[assignment]
|
||||
loss, weights = request.param
|
||||
task_args["loss"] = loss
|
||||
task_args["weights"] = weights
|
||||
|
@ -65,20 +129,20 @@ class TestClassificationTask:
|
|||
assert "lr_scheduler" in out
|
||||
|
||||
def test_training(
|
||||
self, datamodule: So2SatDataModule, task: ClassificationTask
|
||||
self, datamodule: DummyDataModule, task: ClassificationTask
|
||||
) -> None:
|
||||
batch = next(iter(datamodule.train_dataloader()))
|
||||
task.training_step(batch, 0)
|
||||
task.training_epoch_end(0)
|
||||
|
||||
def test_validation(
|
||||
self, datamodule: So2SatDataModule, task: ClassificationTask
|
||||
self, datamodule: DummyDataModule, task: ClassificationTask
|
||||
) -> None:
|
||||
batch = next(iter(datamodule.val_dataloader()))
|
||||
task.validation_step(batch, 0)
|
||||
task.validation_epoch_end(0)
|
||||
|
||||
def test_test(self, datamodule: So2SatDataModule, task: ClassificationTask) -> None:
|
||||
def test_test(self, datamodule: DummyDataModule, task: ClassificationTask) -> None:
|
||||
batch = next(iter(datamodule.test_dataloader()))
|
||||
task.test_step(batch, 0)
|
||||
task.test_epoch_end(0)
|
||||
|
@ -99,6 +163,7 @@ class TestClassificationTask:
|
|||
|
||||
def test_invalid_loss(self, config: Dict[str, Any]) -> None:
|
||||
config["loss"] = "invalid_loss"
|
||||
config["classification_model"] = "resnet18"
|
||||
error_message = "Loss type 'invalid_loss' is not valid."
|
||||
with pytest.raises(ValueError, match=error_message):
|
||||
ClassificationTask(**config)
|
||||
|
@ -117,6 +182,68 @@ class TestClassificationTask:
|
|||
ClassificationTask(**config)
|
||||
|
||||
|
||||
class TestMultiLabelClassificationTask:
|
||||
@pytest.fixture(scope="class")
|
||||
def datamodule(self, request: SubRequest) -> DummyDataModule:
|
||||
dm = DummyDataModule(
|
||||
num_channels=3,
|
||||
num_classes=43,
|
||||
multilabel=True,
|
||||
batch_size=2,
|
||||
num_workers=0,
|
||||
)
|
||||
dm.prepare_data()
|
||||
dm.setup()
|
||||
return dm
|
||||
|
||||
@pytest.fixture(scope="class", params=zip(["bce", "bce"], ["imagenet", "random"]))
|
||||
def config(
|
||||
self, datamodule: DummyDataModule, request: SubRequest
|
||||
) -> Dict[str, Any]:
|
||||
task_args = {}
|
||||
task_args["classification_model"] = "resnet18"
|
||||
task_args["learning_rate"] = 3e-4 # type: ignore[assignment]
|
||||
task_args["learning_rate_schedule_patience"] = 6 # type: ignore[assignment]
|
||||
task_args["in_channels"] = datamodule.num_channels # type: ignore[assignment]
|
||||
loss, weights = request.param
|
||||
task_args["loss"] = loss
|
||||
task_args["weights"] = weights
|
||||
return task_args
|
||||
|
||||
@pytest.fixture
|
||||
def task(
|
||||
self, config: Dict[str, Any], monkeypatch: Generator[MonkeyPatch, None, None]
|
||||
) -> MultiLabelClassificationTask:
|
||||
task = MultiLabelClassificationTask(**config)
|
||||
monkeypatch.setattr(task, "log", mocked_log) # type: ignore[attr-defined]
|
||||
return task
|
||||
|
||||
def test_training(
|
||||
self, datamodule: DummyDataModule, task: ClassificationTask
|
||||
) -> None:
|
||||
batch = next(iter(datamodule.train_dataloader()))
|
||||
task.training_step(batch, 0)
|
||||
task.training_epoch_end(0)
|
||||
|
||||
def test_validation(
|
||||
self, datamodule: DummyDataModule, task: ClassificationTask
|
||||
) -> None:
|
||||
batch = next(iter(datamodule.val_dataloader()))
|
||||
task.validation_step(batch, 0)
|
||||
task.validation_epoch_end(0)
|
||||
|
||||
def test_test(self, datamodule: DummyDataModule, task: ClassificationTask) -> None:
|
||||
batch = next(iter(datamodule.test_dataloader()))
|
||||
task.test_step(batch, 0)
|
||||
task.test_epoch_end(0)
|
||||
|
||||
def test_invalid_loss(self, config: Dict[str, Any]) -> None:
|
||||
config["loss"] = "invalid_loss"
|
||||
error_message = "Loss type 'invalid_loss' is not valid."
|
||||
with pytest.raises(ValueError, match=error_message):
|
||||
MultiLabelClassificationTask(**config)
|
||||
|
||||
|
||||
class TestRegressionTask:
|
||||
@pytest.fixture(scope="class")
|
||||
def datamodule(self) -> CycloneDataModule:
|
||||
|
|
|
@ -3,6 +3,7 @@
|
|||
|
||||
"""TorchGeo trainers."""
|
||||
|
||||
from .bigearthnet import BigEarthNetClassificationTask, BigEarthNetDataModule
|
||||
from .byol import BYOLTask
|
||||
from .chesapeake import ChesapeakeCVPRDataModule, ChesapeakeCVPRSegmentationTask
|
||||
from .cyclone import CycloneDataModule
|
||||
|
@ -11,29 +12,35 @@ from .naipchesapeake import NAIPChesapeakeDataModule, NAIPChesapeakeSegmentation
|
|||
from .resisc45 import RESISC45ClassificationTask, RESISC45DataModule
|
||||
from .sen12ms import SEN12MSDataModule, SEN12MSSegmentationTask
|
||||
from .so2sat import So2SatClassificationTask, So2SatDataModule
|
||||
from .tasks import ClassificationTask, RegressionTask
|
||||
from .tasks import ClassificationTask, MultiLabelClassificationTask, RegressionTask
|
||||
from .ucmerced import UCMercedClassificationTask, UCMercedDataModule
|
||||
|
||||
__all__ = (
|
||||
# Tasks
|
||||
"ClassificationTask",
|
||||
"RegressionTask",
|
||||
# Trainers
|
||||
"BigEarthNetClassificationTask",
|
||||
"BYOLTask",
|
||||
"ChesapeakeCVPRSegmentationTask",
|
||||
"ChesapeakeCVPRDataModule",
|
||||
"ClassificationTask",
|
||||
"CycloneDataModule",
|
||||
"LandcoverAIDataModule",
|
||||
"LandcoverAISegmentationTask",
|
||||
"NAIPChesapeakeDataModule",
|
||||
"MultiLabelClassificationTask",
|
||||
"NAIPChesapeakeSegmentationTask",
|
||||
"RESISC45ClassificationTask",
|
||||
"RESISC45DataModule",
|
||||
"SEN12MSDataModule",
|
||||
"RegressionTask",
|
||||
"SEN12MSSegmentationTask",
|
||||
"So2SatDataModule",
|
||||
"So2SatClassificationTask",
|
||||
"UCMercedClassificationTask",
|
||||
# DataModules
|
||||
"BigEarthNetDataModule",
|
||||
"ChesapeakeCVPRDataModule",
|
||||
"CycloneDataModule",
|
||||
"LandcoverAIDataModule",
|
||||
"NAIPChesapeakeDataModule",
|
||||
"RESISC45DataModule",
|
||||
"SEN12MSDataModule",
|
||||
"So2SatDataModule",
|
||||
"UCMercedDataModule",
|
||||
)
|
||||
|
||||
|
|
|
@ -0,0 +1,193 @@
|
|||
# Copyright (c) Microsoft Corporation. All rights reserved.
|
||||
# Licensed under the MIT License.
|
||||
|
||||
"""BigEarthNet trainer."""
|
||||
|
||||
from typing import Any, Dict, Optional
|
||||
|
||||
import pytorch_lightning as pl
|
||||
import torch
|
||||
from torch.utils.data import DataLoader
|
||||
from torchvision.transforms import Compose
|
||||
|
||||
from ..datasets import BigEarthNet
|
||||
from ..datasets.utils import dataset_split
|
||||
from .tasks import MultiLabelClassificationTask
|
||||
|
||||
# https://github.com/pytorch/pytorch/issues/60979
|
||||
# https://github.com/pytorch/pytorch/pull/61045
|
||||
DataLoader.__module__ = "torch.utils.data"
|
||||
|
||||
|
||||
class BigEarthNetClassificationTask(MultiLabelClassificationTask):
|
||||
"""LightningModule for training models on the BigEarthNet Dataset."""
|
||||
|
||||
num_classes = 43
|
||||
|
||||
|
||||
class BigEarthNetDataModule(pl.LightningDataModule):
|
||||
"""LightningDataModule implementation for the BigEarthNet dataset.
|
||||
|
||||
Uses the train/val/test splits from the dataset.
|
||||
"""
|
||||
|
||||
# (VV, VH, B01, B02, B03, B04, B05, B06, B07, B08, B8A, B09, B11, B12)
|
||||
# min/max band statistics computed on 100k random samples
|
||||
band_mins_raw = torch.tensor( # type: ignore[attr-defined]
|
||||
[-70.0, -72.0, 1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 1.0, 0.0, 0.0]
|
||||
)
|
||||
band_maxs_raw = torch.tensor( # type: ignore[attr-defined]
|
||||
[
|
||||
31.0,
|
||||
35.0,
|
||||
18556.0,
|
||||
20528.0,
|
||||
18976.0,
|
||||
17874.0,
|
||||
16611.0,
|
||||
16512.0,
|
||||
16394.0,
|
||||
16672.0,
|
||||
16141.0,
|
||||
16097.0,
|
||||
15336.0,
|
||||
15203.0,
|
||||
]
|
||||
)
|
||||
|
||||
# min/max band statistics computed by percentile clipping the
|
||||
# above to samples to [2, 98]
|
||||
band_mins = torch.tensor( # type: ignore[attr-defined]
|
||||
[-48.0, -42.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0]
|
||||
)
|
||||
band_maxs = torch.tensor( # type: ignore[attr-defined]
|
||||
[
|
||||
6.0,
|
||||
16.0,
|
||||
9859.0,
|
||||
12872.0,
|
||||
13163.0,
|
||||
14445.0,
|
||||
12477.0,
|
||||
12563.0,
|
||||
12289.0,
|
||||
15596.0,
|
||||
12183.0,
|
||||
9458.0,
|
||||
5897.0,
|
||||
5544.0,
|
||||
]
|
||||
)
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
root_dir: str,
|
||||
bands: str = "all",
|
||||
batch_size: int = 64,
|
||||
num_workers: int = 4,
|
||||
unsupervised_mode: bool = False,
|
||||
val_split_pct: float = 0.2,
|
||||
test_split_pct: float = 0.2,
|
||||
**kwargs: Any,
|
||||
) -> None:
|
||||
"""Initialize a LightningDataModule for BigEarthNet based DataLoaders.
|
||||
|
||||
Args:
|
||||
root_dir: The ``root`` arugment to pass to the BigEarthNet Dataset classes
|
||||
bands: load Sentinel-1 bands, Sentinel-2, or both. one of {s1, s2, all}
|
||||
batch_size: The batch size to use in all created DataLoaders
|
||||
num_workers: The number of workers to use in all created DataLoaders
|
||||
unsupervised_mode: Makes the train dataloader return imagery from the train,
|
||||
val, and test sets
|
||||
val_split_pct: What percentage of the dataset to use as a validation set
|
||||
test_split_pct: What percentage of the dataset to use as a test set
|
||||
"""
|
||||
super().__init__() # type: ignore[no-untyped-call]
|
||||
self.root_dir = root_dir
|
||||
self.bands = bands
|
||||
self.batch_size = batch_size
|
||||
self.num_workers = num_workers
|
||||
self.unsupervised_mode = unsupervised_mode
|
||||
|
||||
self.val_split_pct = val_split_pct
|
||||
self.test_split_pct = test_split_pct
|
||||
|
||||
if bands == "all":
|
||||
self.mins = self.band_mins[:, None, None]
|
||||
self.maxs = self.band_maxs[:, None, None]
|
||||
elif bands == "s1":
|
||||
self.mins = self.band_mins[:2, None, None]
|
||||
self.maxs = self.band_maxs[:2, None, None]
|
||||
else:
|
||||
self.mins = self.band_mins[2:, None, None]
|
||||
self.maxs = self.band_maxs[2:, None, None]
|
||||
|
||||
def preprocess(self, sample: Dict[str, Any]) -> Dict[str, Any]:
|
||||
"""Transform a single sample from the Dataset."""
|
||||
sample["image"] = sample["image"].float()
|
||||
sample["image"] = (sample["image"] - self.mins) / (self.maxs - self.mins)
|
||||
sample["image"] = torch.clip( # type: ignore[attr-defined]
|
||||
sample["image"], min=0.0, max=1.0
|
||||
)
|
||||
return sample
|
||||
|
||||
def prepare_data(self) -> None:
|
||||
"""Make sure that the dataset is downloaded.
|
||||
|
||||
This method is only called once per run.
|
||||
"""
|
||||
BigEarthNet(self.root_dir, bands=self.bands, checksum=False)
|
||||
|
||||
def setup(self, stage: Optional[str] = None) -> None:
|
||||
"""Initialize the main ``Dataset`` objects.
|
||||
|
||||
This method is called once per GPU per run.
|
||||
"""
|
||||
transforms = Compose([self.preprocess])
|
||||
|
||||
if not self.unsupervised_mode:
|
||||
|
||||
dataset = BigEarthNet(
|
||||
self.root_dir, bands=self.bands, transforms=transforms
|
||||
)
|
||||
self.train_dataset, self.val_dataset, self.test_dataset = dataset_split(
|
||||
dataset, val_pct=self.val_split_pct, test_pct=self.test_split_pct
|
||||
)
|
||||
else:
|
||||
self.train_dataset = BigEarthNet( # type: ignore[assignment]
|
||||
self.root_dir, bands=self.bands, transforms=transforms
|
||||
)
|
||||
self.val_dataset, self.test_dataset = None, None # type: ignore[assignment]
|
||||
|
||||
def train_dataloader(self) -> DataLoader[Any]:
|
||||
"""Return a DataLoader for training."""
|
||||
return DataLoader(
|
||||
self.train_dataset,
|
||||
batch_size=self.batch_size,
|
||||
num_workers=self.num_workers,
|
||||
shuffle=True,
|
||||
)
|
||||
|
||||
def val_dataloader(self) -> DataLoader[Any]:
|
||||
"""Return a DataLoader for validation."""
|
||||
if self.unsupervised_mode or self.val_split_pct == 0:
|
||||
return self.train_dataloader()
|
||||
else:
|
||||
return DataLoader(
|
||||
self.val_dataset,
|
||||
batch_size=self.batch_size,
|
||||
num_workers=self.num_workers,
|
||||
shuffle=False,
|
||||
)
|
||||
|
||||
def test_dataloader(self) -> DataLoader[Any]:
|
||||
"""Return a DataLoader for testing."""
|
||||
if self.unsupervised_mode or self.test_split_pct == 0:
|
||||
return self.train_dataloader()
|
||||
else:
|
||||
return DataLoader(
|
||||
self.test_dataset,
|
||||
batch_size=self.batch_size,
|
||||
num_workers=self.num_workers,
|
||||
shuffle=False,
|
||||
)
|
|
@ -57,7 +57,7 @@ class ClassificationTask(pl.LightningModule):
|
|||
|
||||
# Update first layer
|
||||
if in_channels != 3:
|
||||
w_old = None
|
||||
w_old = torch.empty(0) # type: ignore[attr-defined]
|
||||
if pretrained:
|
||||
w_old = torch.clone( # type: ignore[attr-defined]
|
||||
self.model.conv1.weight
|
||||
|
@ -75,7 +75,11 @@ class ClassificationTask(pl.LightningModule):
|
|||
w_new = torch.clone( # type: ignore[attr-defined]
|
||||
self.model.conv1.weight
|
||||
).detach()
|
||||
w_new[:, :3, :, :] = w_old
|
||||
if in_channels > 3:
|
||||
w_new[:, :3, :, :] = w_old
|
||||
else:
|
||||
w_old = w_old[:, :in_channels, :, :]
|
||||
w_new[:, :in_channels, :, :] = w_old
|
||||
self.model.conv1.weight = nn.Parameter( # type: ignore[attr-defined] # noqa: E501
|
||||
w_new
|
||||
)
|
||||
|
@ -266,6 +270,120 @@ class ClassificationTask(pl.LightningModule):
|
|||
}
|
||||
|
||||
|
||||
class MultiLabelClassificationTask(ClassificationTask):
|
||||
"""Abstract base class for multi label image classification LightningModules."""
|
||||
|
||||
#: number of classes in dataset
|
||||
num_classes: int = 43
|
||||
|
||||
def config_task(self) -> None:
|
||||
"""Configures the task based on kwargs parameters passed to the constructor."""
|
||||
self.config_model()
|
||||
|
||||
if self.hparams["loss"] == "bce":
|
||||
self.loss = nn.BCEWithLogitsLoss() # type: ignore[attr-defined]
|
||||
else:
|
||||
raise ValueError(f"Loss type '{self.hparams['loss']}' is not valid.")
|
||||
|
||||
def __init__(self, **kwargs: Any) -> None:
|
||||
"""Initialize the LightningModule with a model and loss function.
|
||||
|
||||
Keyword Args:
|
||||
classification_model: Name of the classification model use
|
||||
loss: Name of the loss function
|
||||
weights: Either "random", "imagenet_only", "imagenet_and_random", or
|
||||
"random_rgb"
|
||||
"""
|
||||
super().__init__(**kwargs)
|
||||
self.save_hyperparameters() # creates `self.hparams` from kwargs
|
||||
|
||||
self.config_task()
|
||||
|
||||
self.train_metrics = MetricCollection(
|
||||
{
|
||||
"OverallAccuracy": Accuracy(
|
||||
num_classes=self.num_classes, average="micro", multiclass=False
|
||||
),
|
||||
"AverageAccuracy": Accuracy(
|
||||
num_classes=self.num_classes, average="macro", multiclass=False
|
||||
),
|
||||
"F1Score": FBeta(
|
||||
num_classes=self.num_classes,
|
||||
beta=1.0,
|
||||
average="micro",
|
||||
multiclass=False,
|
||||
),
|
||||
},
|
||||
prefix="train_",
|
||||
)
|
||||
self.val_metrics = self.train_metrics.clone(prefix="val_")
|
||||
self.test_metrics = self.train_metrics.clone(prefix="test_")
|
||||
|
||||
def training_step( # type: ignore[override]
|
||||
self, batch: Dict[str, Any], batch_idx: int
|
||||
) -> Tensor:
|
||||
"""Training step.
|
||||
|
||||
Args:
|
||||
batch: Current batch
|
||||
batch_idx: Index of current batch
|
||||
Returns:
|
||||
training loss
|
||||
"""
|
||||
x = batch["image"]
|
||||
y = batch["label"]
|
||||
y_hat = self.forward(x)
|
||||
y_hat_hard = torch.softmax(y_hat, dim=-1) # type: ignore[attr-defined]
|
||||
|
||||
loss = self.loss(y_hat, y.to(torch.float)) # type: ignore[attr-defined]
|
||||
|
||||
# by default, the train step logs every `log_every_n_steps` steps where
|
||||
# `log_every_n_steps` is a parameter to the `Trainer` object
|
||||
self.log("train_loss", loss, on_step=True, on_epoch=False)
|
||||
self.train_metrics(y_hat_hard, y)
|
||||
|
||||
return cast(Tensor, loss)
|
||||
|
||||
def validation_step( # type: ignore[override]
|
||||
self, batch: Dict[str, Any], batch_idx: int
|
||||
) -> None:
|
||||
"""Validation step.
|
||||
|
||||
Args:
|
||||
batch: Current batch
|
||||
batch_idx: Index of current batch
|
||||
"""
|
||||
x = batch["image"]
|
||||
y = batch["label"]
|
||||
y_hat = self.forward(x)
|
||||
y_hat_hard = torch.softmax(y_hat, dim=-1) # type: ignore[attr-defined]
|
||||
|
||||
loss = self.loss(y_hat, y.to(torch.float)) # type: ignore[attr-defined]
|
||||
|
||||
self.log("val_loss", loss, on_step=False, on_epoch=True)
|
||||
self.val_metrics(y_hat_hard, y)
|
||||
|
||||
def test_step( # type: ignore[override]
|
||||
self, batch: Dict[str, Any], batch_idx: int
|
||||
) -> None:
|
||||
"""Test step.
|
||||
|
||||
Args:
|
||||
batch: Current batch
|
||||
batch_idx: Index of current batch
|
||||
"""
|
||||
x = batch["image"]
|
||||
y = batch["label"]
|
||||
y_hat = self.forward(x)
|
||||
y_hat_hard = torch.softmax(y_hat, dim=-1) # type: ignore[attr-defined]
|
||||
|
||||
loss = self.loss(y_hat, y.to(torch.float)) # type: ignore[attr-defined]
|
||||
|
||||
# by default, the test and validation steps only log per *epoch*
|
||||
self.log("test_loss", loss, on_step=False, on_epoch=True)
|
||||
self.test_metrics(y_hat_hard, y)
|
||||
|
||||
|
||||
class RegressionTask(pl.LightningModule):
|
||||
"""LightningModule for training models on regression datasets."""
|
||||
|
||||
|
|
3
train.py
3
train.py
|
@ -14,6 +14,8 @@ from pytorch_lightning import loggers as pl_loggers
|
|||
from pytorch_lightning.callbacks import EarlyStopping, ModelCheckpoint
|
||||
|
||||
from torchgeo.trainers import (
|
||||
BigEarthNetClassificationTask,
|
||||
BigEarthNetDataModule,
|
||||
BYOLTask,
|
||||
ChesapeakeCVPRDataModule,
|
||||
ChesapeakeCVPRSegmentationTask,
|
||||
|
@ -36,6 +38,7 @@ from torchgeo.trainers import (
|
|||
TASK_TO_MODULES_MAPPING: Dict[
|
||||
str, Tuple[Type[pl.LightningModule], Type[pl.LightningDataModule]]
|
||||
] = {
|
||||
"bigearthnet": (BigEarthNetClassificationTask, BigEarthNetDataModule),
|
||||
"byol": (BYOLTask, ChesapeakeCVPRDataModule),
|
||||
"chesapeake_cvpr": (ChesapeakeCVPRSegmentationTask, ChesapeakeCVPRDataModule),
|
||||
"cyclone": (RegressionTask, CycloneDataModule),
|
||||
|
|
Загрузка…
Ссылка в новой задаче