* add additional bigearthnet test data for train/val/test split

* update bigearthnet dataset length test

* add MultiLabelClassificationTask

* add BigEarthNet trainer and datamodule

* add bigearthnet and multilabelclassificationtask tests

* mypy and format

* add estimated band min/max values for normalization

* softmax outputs to correctly compute metrics

* update min/max stats for 100k samples

* organize imports in torchgeo.trainers.__init__.py

* clean up fixtures in test_tasks.py

* added bigearthnet to train.py

* format

* move fixtures into class methods

* consolidate bigearthnet fixtures

* refactor tasks tests

* add scope=class

* style/mypy fixes

* mypy fixes
This commit is contained in:
isaac 2021-11-02 10:45:38 -05:00 коммит произвёл GitHub
Родитель b8f5a7ce64
Коммит 3cc63def02
Не найден ключ, соответствующий данной подписи
Идентификатор ключа GPG: 4AEE18F83AFDEB23
11 изменённых файлов: 554 добавлений и 36 удалений

18
conf/bigearthnet.yaml Normal file
Просмотреть файл

@ -0,0 +1,18 @@
trainer:
gpus: 1 # single GPU training
min_epochs: 10
max_epochs: 40
benchmark: True
experiment:
task: "bigearthnet"
module:
loss: "bce"
classification_model: "resnet18"
learning_rate: 1e-3
learning_rate_schedule_patience: 6
in_channels: 14
datamodule:
batch_size: 128
num_workers: 6
bands: "all"

Просмотреть файл

@ -0,0 +1,13 @@
experiment:
task: "bigearthnet"
module:
loss: "bce"
classification_model: "resnet18"
learning_rate: 1e-3
learning_rate_schedule_patience: 6
weights: "random"
in_channels: 14
datamodule:
batch_size: 128
num_workers: 6
bands: "all"

Двоичный файл не отображается.

Двоичный файл не отображается.

Просмотреть файл

@ -71,7 +71,7 @@ class TestBigEarthNet:
assert x["image"].shape == (12, 120, 120) assert x["image"].shape == (12, 120, 120)
def test_len(self, dataset: BigEarthNet) -> None: def test_len(self, dataset: BigEarthNet) -> None:
assert len(dataset) == 2 assert len(dataset) == 4
def test_already_downloaded(self, dataset: BigEarthNet, tmp_path: Path) -> None: def test_already_downloaded(self, dataset: BigEarthNet, tmp_path: Path) -> None:
BigEarthNet(root=str(tmp_path), bands=dataset.bands, download=True) BigEarthNet(root=str(tmp_path), bands=dataset.bands, download=True)

Просмотреть файл

@ -0,0 +1,39 @@
# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT License.
import os
import pytest
from _pytest.fixtures import SubRequest
from torchgeo.trainers import BigEarthNetDataModule
class TestBigEarthNetDataModule:
@pytest.fixture(scope="class", params=zip(["s1", "s2", "all"], [True, True, False]))
def datamodule(self, request: SubRequest) -> BigEarthNetDataModule:
bands, unsupervised_mode = request.param
root = os.path.join("tests", "data", "bigearthnet")
batch_size = 1
num_workers = 0
dm = BigEarthNetDataModule(
root,
bands,
batch_size,
num_workers,
unsupervised_mode,
val_split_pct=0.3,
test_split_pct=0.3,
)
dm.prepare_data()
dm.setup()
return dm
def test_train_dataloader(self, datamodule: BigEarthNetDataModule) -> None:
next(iter(datamodule.train_dataloader()))
def test_val_dataloader(self, datamodule: BigEarthNetDataModule) -> None:
next(iter(datamodule.val_dataloader()))
def test_test_dataloader(self, datamodule: BigEarthNetDataModule) -> None:
next(iter(datamodule.test_dataloader()))

Просмотреть файл

@ -2,50 +2,114 @@
# Licensed under the MIT License. # Licensed under the MIT License.
import os import os
from typing import Any, Dict, Generator, Tuple, cast from typing import Any, Dict, Generator, Optional, cast
import pytest import pytest
import pytorch_lightning as pl
import torch
import torch.nn.functional as F
from _pytest.fixtures import SubRequest from _pytest.fixtures import SubRequest
from _pytest.monkeypatch import MonkeyPatch from _pytest.monkeypatch import MonkeyPatch
from omegaconf import OmegaConf from omegaconf import OmegaConf
from torch import Tensor
from torch.utils.data import DataLoader, Dataset, TensorDataset
from torchgeo.trainers import ( from torchgeo.trainers import (
ClassificationTask, ClassificationTask,
CycloneDataModule, CycloneDataModule,
MultiLabelClassificationTask,
RegressionTask, RegressionTask,
So2SatDataModule,
) )
from .test_utils import mocked_log from .test_utils import mocked_log
@pytest.fixture(scope="module", params=[("rgb", 3), ("s2", 10)]) class DummyDataset(Dataset): # type: ignore[type-arg]
def bands(request: SubRequest) -> Tuple[str, int]: def __init__(self, num_channels: int, num_classes: int, multilabel: bool) -> None:
return cast(Tuple[str, int], request.param) x = torch.randn(10, num_channels, 128, 128) # (b, c, h, w)
y = torch.randint( # type: ignore[attr-defined]
0, num_classes, size=(10,)
) # (b,)
if multilabel:
y = F.one_hot(y, num_classes=num_classes) # (b, classes)
self.dataset = TensorDataset(x, y)
def __len__(self) -> int:
return len(self.dataset)
def __getitem__(self, idx: int) -> Dict[str, Tensor]:
x, y = self.dataset[idx]
sample = {"image": x, "label": y}
return sample
@pytest.fixture(scope="module", params=[True, False]) class DummyDataModule(pl.LightningDataModule):
def datamodule(bands: Tuple[str, int], request: SubRequest) -> So2SatDataModule: def __init__(
band_set = bands[0] self,
unsupervised_mode = request.param num_channels: int,
root = os.path.join("tests", "data", "so2sat") num_classes: int,
batch_size = 2 multilabel: bool,
num_workers = 0 batch_size: int = 1,
dm = So2SatDataModule(root, batch_size, num_workers, band_set, unsupervised_mode) num_workers: int = 0,
dm.prepare_data() ) -> None:
dm.setup() super().__init__() # type: ignore[no-untyped-call]
return dm self.num_channels = num_channels
self.num_classes = num_classes
self.multilabel = multilabel
self.batch_size = batch_size
self.num_workers = num_workers
def setup(self, stage: Optional[str] = None) -> None:
self.dataset = DummyDataset(
num_channels=self.num_channels,
num_classes=self.num_classes,
multilabel=self.multilabel,
)
def train_dataloader(self) -> DataLoader: # type: ignore[type-arg]
return DataLoader(
self.dataset, batch_size=self.batch_size, num_workers=self.num_workers
)
def val_dataloader(self) -> DataLoader: # type: ignore[type-arg]
return DataLoader(
self.dataset, batch_size=self.batch_size, num_workers=self.num_workers
)
def test_dataloader(self) -> DataLoader: # type: ignore[type-arg]
return DataLoader(
self.dataset, batch_size=self.batch_size, num_workers=self.num_workers
)
class TestClassificationTask: class TestClassificationTask:
@pytest.fixture(scope="class", params=[2, 3, 5])
def datamodule(self, request: SubRequest) -> DummyDataModule:
dm = DummyDataModule(
num_channels=request.param,
num_classes=45,
multilabel=False,
batch_size=2,
num_workers=0,
)
dm.prepare_data()
dm.setup()
return dm
@pytest.fixture( @pytest.fixture(
params=zip(["ce", "jaccard", "focal"], ["imagenet", "random", "random"]) scope="class",
params=zip(["ce", "jaccard", "focal"], ["imagenet", "random", "random"]),
) )
def config(self, request: SubRequest, bands: Tuple[str, int]) -> Dict[str, Any]: def config(
task_conf = OmegaConf.load(os.path.join("conf", "task_defaults", "so2sat.yaml")) self, request: SubRequest, datamodule: DummyDataModule
task_args = OmegaConf.to_object(task_conf.experiment.module) ) -> Dict[str, Any]:
task_args = cast(Dict[str, Any], task_args) task_args = {}
task_args["in_channels"] = bands[1] task_args["classification_model"] = "resnet18"
task_args["learning_rate"] = 3e-4 # type: ignore[assignment]
task_args["learning_rate_schedule_patience"] = 6 # type: ignore[assignment]
task_args["in_channels"] = datamodule.num_channels # type: ignore[assignment]
loss, weights = request.param loss, weights = request.param
task_args["loss"] = loss task_args["loss"] = loss
task_args["weights"] = weights task_args["weights"] = weights
@ -65,20 +129,20 @@ class TestClassificationTask:
assert "lr_scheduler" in out assert "lr_scheduler" in out
def test_training( def test_training(
self, datamodule: So2SatDataModule, task: ClassificationTask self, datamodule: DummyDataModule, task: ClassificationTask
) -> None: ) -> None:
batch = next(iter(datamodule.train_dataloader())) batch = next(iter(datamodule.train_dataloader()))
task.training_step(batch, 0) task.training_step(batch, 0)
task.training_epoch_end(0) task.training_epoch_end(0)
def test_validation( def test_validation(
self, datamodule: So2SatDataModule, task: ClassificationTask self, datamodule: DummyDataModule, task: ClassificationTask
) -> None: ) -> None:
batch = next(iter(datamodule.val_dataloader())) batch = next(iter(datamodule.val_dataloader()))
task.validation_step(batch, 0) task.validation_step(batch, 0)
task.validation_epoch_end(0) task.validation_epoch_end(0)
def test_test(self, datamodule: So2SatDataModule, task: ClassificationTask) -> None: def test_test(self, datamodule: DummyDataModule, task: ClassificationTask) -> None:
batch = next(iter(datamodule.test_dataloader())) batch = next(iter(datamodule.test_dataloader()))
task.test_step(batch, 0) task.test_step(batch, 0)
task.test_epoch_end(0) task.test_epoch_end(0)
@ -99,6 +163,7 @@ class TestClassificationTask:
def test_invalid_loss(self, config: Dict[str, Any]) -> None: def test_invalid_loss(self, config: Dict[str, Any]) -> None:
config["loss"] = "invalid_loss" config["loss"] = "invalid_loss"
config["classification_model"] = "resnet18"
error_message = "Loss type 'invalid_loss' is not valid." error_message = "Loss type 'invalid_loss' is not valid."
with pytest.raises(ValueError, match=error_message): with pytest.raises(ValueError, match=error_message):
ClassificationTask(**config) ClassificationTask(**config)
@ -117,6 +182,68 @@ class TestClassificationTask:
ClassificationTask(**config) ClassificationTask(**config)
class TestMultiLabelClassificationTask:
@pytest.fixture(scope="class")
def datamodule(self, request: SubRequest) -> DummyDataModule:
dm = DummyDataModule(
num_channels=3,
num_classes=43,
multilabel=True,
batch_size=2,
num_workers=0,
)
dm.prepare_data()
dm.setup()
return dm
@pytest.fixture(scope="class", params=zip(["bce", "bce"], ["imagenet", "random"]))
def config(
self, datamodule: DummyDataModule, request: SubRequest
) -> Dict[str, Any]:
task_args = {}
task_args["classification_model"] = "resnet18"
task_args["learning_rate"] = 3e-4 # type: ignore[assignment]
task_args["learning_rate_schedule_patience"] = 6 # type: ignore[assignment]
task_args["in_channels"] = datamodule.num_channels # type: ignore[assignment]
loss, weights = request.param
task_args["loss"] = loss
task_args["weights"] = weights
return task_args
@pytest.fixture
def task(
self, config: Dict[str, Any], monkeypatch: Generator[MonkeyPatch, None, None]
) -> MultiLabelClassificationTask:
task = MultiLabelClassificationTask(**config)
monkeypatch.setattr(task, "log", mocked_log) # type: ignore[attr-defined]
return task
def test_training(
self, datamodule: DummyDataModule, task: ClassificationTask
) -> None:
batch = next(iter(datamodule.train_dataloader()))
task.training_step(batch, 0)
task.training_epoch_end(0)
def test_validation(
self, datamodule: DummyDataModule, task: ClassificationTask
) -> None:
batch = next(iter(datamodule.val_dataloader()))
task.validation_step(batch, 0)
task.validation_epoch_end(0)
def test_test(self, datamodule: DummyDataModule, task: ClassificationTask) -> None:
batch = next(iter(datamodule.test_dataloader()))
task.test_step(batch, 0)
task.test_epoch_end(0)
def test_invalid_loss(self, config: Dict[str, Any]) -> None:
config["loss"] = "invalid_loss"
error_message = "Loss type 'invalid_loss' is not valid."
with pytest.raises(ValueError, match=error_message):
MultiLabelClassificationTask(**config)
class TestRegressionTask: class TestRegressionTask:
@pytest.fixture(scope="class") @pytest.fixture(scope="class")
def datamodule(self) -> CycloneDataModule: def datamodule(self) -> CycloneDataModule:

Просмотреть файл

@ -3,6 +3,7 @@
"""TorchGeo trainers.""" """TorchGeo trainers."""
from .bigearthnet import BigEarthNetClassificationTask, BigEarthNetDataModule
from .byol import BYOLTask from .byol import BYOLTask
from .chesapeake import ChesapeakeCVPRDataModule, ChesapeakeCVPRSegmentationTask from .chesapeake import ChesapeakeCVPRDataModule, ChesapeakeCVPRSegmentationTask
from .cyclone import CycloneDataModule from .cyclone import CycloneDataModule
@ -11,29 +12,35 @@ from .naipchesapeake import NAIPChesapeakeDataModule, NAIPChesapeakeSegmentation
from .resisc45 import RESISC45ClassificationTask, RESISC45DataModule from .resisc45 import RESISC45ClassificationTask, RESISC45DataModule
from .sen12ms import SEN12MSDataModule, SEN12MSSegmentationTask from .sen12ms import SEN12MSDataModule, SEN12MSSegmentationTask
from .so2sat import So2SatClassificationTask, So2SatDataModule from .so2sat import So2SatClassificationTask, So2SatDataModule
from .tasks import ClassificationTask, RegressionTask from .tasks import ClassificationTask, MultiLabelClassificationTask, RegressionTask
from .ucmerced import UCMercedClassificationTask, UCMercedDataModule from .ucmerced import UCMercedClassificationTask, UCMercedDataModule
__all__ = ( __all__ = (
# Tasks # Tasks
"ClassificationTask", "BigEarthNetClassificationTask",
"RegressionTask",
# Trainers
"BYOLTask", "BYOLTask",
"ChesapeakeCVPRSegmentationTask", "ChesapeakeCVPRSegmentationTask",
"ChesapeakeCVPRDataModule", "ChesapeakeCVPRDataModule",
"ClassificationTask",
"CycloneDataModule", "CycloneDataModule",
"LandcoverAIDataModule", "LandcoverAIDataModule",
"LandcoverAISegmentationTask", "LandcoverAISegmentationTask",
"NAIPChesapeakeDataModule", "MultiLabelClassificationTask",
"NAIPChesapeakeSegmentationTask", "NAIPChesapeakeSegmentationTask",
"RESISC45ClassificationTask", "RESISC45ClassificationTask",
"RESISC45DataModule", "RegressionTask",
"SEN12MSDataModule",
"SEN12MSSegmentationTask", "SEN12MSSegmentationTask",
"So2SatDataModule",
"So2SatClassificationTask", "So2SatClassificationTask",
"UCMercedClassificationTask", "UCMercedClassificationTask",
# DataModules
"BigEarthNetDataModule",
"ChesapeakeCVPRDataModule",
"CycloneDataModule",
"LandcoverAIDataModule",
"NAIPChesapeakeDataModule",
"RESISC45DataModule",
"SEN12MSDataModule",
"So2SatDataModule",
"UCMercedDataModule", "UCMercedDataModule",
) )

Просмотреть файл

@ -0,0 +1,193 @@
# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT License.
"""BigEarthNet trainer."""
from typing import Any, Dict, Optional
import pytorch_lightning as pl
import torch
from torch.utils.data import DataLoader
from torchvision.transforms import Compose
from ..datasets import BigEarthNet
from ..datasets.utils import dataset_split
from .tasks import MultiLabelClassificationTask
# https://github.com/pytorch/pytorch/issues/60979
# https://github.com/pytorch/pytorch/pull/61045
DataLoader.__module__ = "torch.utils.data"
class BigEarthNetClassificationTask(MultiLabelClassificationTask):
"""LightningModule for training models on the BigEarthNet Dataset."""
num_classes = 43
class BigEarthNetDataModule(pl.LightningDataModule):
"""LightningDataModule implementation for the BigEarthNet dataset.
Uses the train/val/test splits from the dataset.
"""
# (VV, VH, B01, B02, B03, B04, B05, B06, B07, B08, B8A, B09, B11, B12)
# min/max band statistics computed on 100k random samples
band_mins_raw = torch.tensor( # type: ignore[attr-defined]
[-70.0, -72.0, 1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 1.0, 0.0, 0.0]
)
band_maxs_raw = torch.tensor( # type: ignore[attr-defined]
[
31.0,
35.0,
18556.0,
20528.0,
18976.0,
17874.0,
16611.0,
16512.0,
16394.0,
16672.0,
16141.0,
16097.0,
15336.0,
15203.0,
]
)
# min/max band statistics computed by percentile clipping the
# above to samples to [2, 98]
band_mins = torch.tensor( # type: ignore[attr-defined]
[-48.0, -42.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0]
)
band_maxs = torch.tensor( # type: ignore[attr-defined]
[
6.0,
16.0,
9859.0,
12872.0,
13163.0,
14445.0,
12477.0,
12563.0,
12289.0,
15596.0,
12183.0,
9458.0,
5897.0,
5544.0,
]
)
def __init__(
self,
root_dir: str,
bands: str = "all",
batch_size: int = 64,
num_workers: int = 4,
unsupervised_mode: bool = False,
val_split_pct: float = 0.2,
test_split_pct: float = 0.2,
**kwargs: Any,
) -> None:
"""Initialize a LightningDataModule for BigEarthNet based DataLoaders.
Args:
root_dir: The ``root`` arugment to pass to the BigEarthNet Dataset classes
bands: load Sentinel-1 bands, Sentinel-2, or both. one of {s1, s2, all}
batch_size: The batch size to use in all created DataLoaders
num_workers: The number of workers to use in all created DataLoaders
unsupervised_mode: Makes the train dataloader return imagery from the train,
val, and test sets
val_split_pct: What percentage of the dataset to use as a validation set
test_split_pct: What percentage of the dataset to use as a test set
"""
super().__init__() # type: ignore[no-untyped-call]
self.root_dir = root_dir
self.bands = bands
self.batch_size = batch_size
self.num_workers = num_workers
self.unsupervised_mode = unsupervised_mode
self.val_split_pct = val_split_pct
self.test_split_pct = test_split_pct
if bands == "all":
self.mins = self.band_mins[:, None, None]
self.maxs = self.band_maxs[:, None, None]
elif bands == "s1":
self.mins = self.band_mins[:2, None, None]
self.maxs = self.band_maxs[:2, None, None]
else:
self.mins = self.band_mins[2:, None, None]
self.maxs = self.band_maxs[2:, None, None]
def preprocess(self, sample: Dict[str, Any]) -> Dict[str, Any]:
"""Transform a single sample from the Dataset."""
sample["image"] = sample["image"].float()
sample["image"] = (sample["image"] - self.mins) / (self.maxs - self.mins)
sample["image"] = torch.clip( # type: ignore[attr-defined]
sample["image"], min=0.0, max=1.0
)
return sample
def prepare_data(self) -> None:
"""Make sure that the dataset is downloaded.
This method is only called once per run.
"""
BigEarthNet(self.root_dir, bands=self.bands, checksum=False)
def setup(self, stage: Optional[str] = None) -> None:
"""Initialize the main ``Dataset`` objects.
This method is called once per GPU per run.
"""
transforms = Compose([self.preprocess])
if not self.unsupervised_mode:
dataset = BigEarthNet(
self.root_dir, bands=self.bands, transforms=transforms
)
self.train_dataset, self.val_dataset, self.test_dataset = dataset_split(
dataset, val_pct=self.val_split_pct, test_pct=self.test_split_pct
)
else:
self.train_dataset = BigEarthNet( # type: ignore[assignment]
self.root_dir, bands=self.bands, transforms=transforms
)
self.val_dataset, self.test_dataset = None, None # type: ignore[assignment]
def train_dataloader(self) -> DataLoader[Any]:
"""Return a DataLoader for training."""
return DataLoader(
self.train_dataset,
batch_size=self.batch_size,
num_workers=self.num_workers,
shuffle=True,
)
def val_dataloader(self) -> DataLoader[Any]:
"""Return a DataLoader for validation."""
if self.unsupervised_mode or self.val_split_pct == 0:
return self.train_dataloader()
else:
return DataLoader(
self.val_dataset,
batch_size=self.batch_size,
num_workers=self.num_workers,
shuffle=False,
)
def test_dataloader(self) -> DataLoader[Any]:
"""Return a DataLoader for testing."""
if self.unsupervised_mode or self.test_split_pct == 0:
return self.train_dataloader()
else:
return DataLoader(
self.test_dataset,
batch_size=self.batch_size,
num_workers=self.num_workers,
shuffle=False,
)

Просмотреть файл

@ -57,7 +57,7 @@ class ClassificationTask(pl.LightningModule):
# Update first layer # Update first layer
if in_channels != 3: if in_channels != 3:
w_old = None w_old = torch.empty(0) # type: ignore[attr-defined]
if pretrained: if pretrained:
w_old = torch.clone( # type: ignore[attr-defined] w_old = torch.clone( # type: ignore[attr-defined]
self.model.conv1.weight self.model.conv1.weight
@ -75,7 +75,11 @@ class ClassificationTask(pl.LightningModule):
w_new = torch.clone( # type: ignore[attr-defined] w_new = torch.clone( # type: ignore[attr-defined]
self.model.conv1.weight self.model.conv1.weight
).detach() ).detach()
w_new[:, :3, :, :] = w_old if in_channels > 3:
w_new[:, :3, :, :] = w_old
else:
w_old = w_old[:, :in_channels, :, :]
w_new[:, :in_channels, :, :] = w_old
self.model.conv1.weight = nn.Parameter( # type: ignore[attr-defined] # noqa: E501 self.model.conv1.weight = nn.Parameter( # type: ignore[attr-defined] # noqa: E501
w_new w_new
) )
@ -266,6 +270,120 @@ class ClassificationTask(pl.LightningModule):
} }
class MultiLabelClassificationTask(ClassificationTask):
"""Abstract base class for multi label image classification LightningModules."""
#: number of classes in dataset
num_classes: int = 43
def config_task(self) -> None:
"""Configures the task based on kwargs parameters passed to the constructor."""
self.config_model()
if self.hparams["loss"] == "bce":
self.loss = nn.BCEWithLogitsLoss() # type: ignore[attr-defined]
else:
raise ValueError(f"Loss type '{self.hparams['loss']}' is not valid.")
def __init__(self, **kwargs: Any) -> None:
"""Initialize the LightningModule with a model and loss function.
Keyword Args:
classification_model: Name of the classification model use
loss: Name of the loss function
weights: Either "random", "imagenet_only", "imagenet_and_random", or
"random_rgb"
"""
super().__init__(**kwargs)
self.save_hyperparameters() # creates `self.hparams` from kwargs
self.config_task()
self.train_metrics = MetricCollection(
{
"OverallAccuracy": Accuracy(
num_classes=self.num_classes, average="micro", multiclass=False
),
"AverageAccuracy": Accuracy(
num_classes=self.num_classes, average="macro", multiclass=False
),
"F1Score": FBeta(
num_classes=self.num_classes,
beta=1.0,
average="micro",
multiclass=False,
),
},
prefix="train_",
)
self.val_metrics = self.train_metrics.clone(prefix="val_")
self.test_metrics = self.train_metrics.clone(prefix="test_")
def training_step( # type: ignore[override]
self, batch: Dict[str, Any], batch_idx: int
) -> Tensor:
"""Training step.
Args:
batch: Current batch
batch_idx: Index of current batch
Returns:
training loss
"""
x = batch["image"]
y = batch["label"]
y_hat = self.forward(x)
y_hat_hard = torch.softmax(y_hat, dim=-1) # type: ignore[attr-defined]
loss = self.loss(y_hat, y.to(torch.float)) # type: ignore[attr-defined]
# by default, the train step logs every `log_every_n_steps` steps where
# `log_every_n_steps` is a parameter to the `Trainer` object
self.log("train_loss", loss, on_step=True, on_epoch=False)
self.train_metrics(y_hat_hard, y)
return cast(Tensor, loss)
def validation_step( # type: ignore[override]
self, batch: Dict[str, Any], batch_idx: int
) -> None:
"""Validation step.
Args:
batch: Current batch
batch_idx: Index of current batch
"""
x = batch["image"]
y = batch["label"]
y_hat = self.forward(x)
y_hat_hard = torch.softmax(y_hat, dim=-1) # type: ignore[attr-defined]
loss = self.loss(y_hat, y.to(torch.float)) # type: ignore[attr-defined]
self.log("val_loss", loss, on_step=False, on_epoch=True)
self.val_metrics(y_hat_hard, y)
def test_step( # type: ignore[override]
self, batch: Dict[str, Any], batch_idx: int
) -> None:
"""Test step.
Args:
batch: Current batch
batch_idx: Index of current batch
"""
x = batch["image"]
y = batch["label"]
y_hat = self.forward(x)
y_hat_hard = torch.softmax(y_hat, dim=-1) # type: ignore[attr-defined]
loss = self.loss(y_hat, y.to(torch.float)) # type: ignore[attr-defined]
# by default, the test and validation steps only log per *epoch*
self.log("test_loss", loss, on_step=False, on_epoch=True)
self.test_metrics(y_hat_hard, y)
class RegressionTask(pl.LightningModule): class RegressionTask(pl.LightningModule):
"""LightningModule for training models on regression datasets.""" """LightningModule for training models on regression datasets."""

Просмотреть файл

@ -14,6 +14,8 @@ from pytorch_lightning import loggers as pl_loggers
from pytorch_lightning.callbacks import EarlyStopping, ModelCheckpoint from pytorch_lightning.callbacks import EarlyStopping, ModelCheckpoint
from torchgeo.trainers import ( from torchgeo.trainers import (
BigEarthNetClassificationTask,
BigEarthNetDataModule,
BYOLTask, BYOLTask,
ChesapeakeCVPRDataModule, ChesapeakeCVPRDataModule,
ChesapeakeCVPRSegmentationTask, ChesapeakeCVPRSegmentationTask,
@ -36,6 +38,7 @@ from torchgeo.trainers import (
TASK_TO_MODULES_MAPPING: Dict[ TASK_TO_MODULES_MAPPING: Dict[
str, Tuple[Type[pl.LightningModule], Type[pl.LightningDataModule]] str, Tuple[Type[pl.LightningModule], Type[pl.LightningDataModule]]
] = { ] = {
"bigearthnet": (BigEarthNetClassificationTask, BigEarthNetDataModule),
"byol": (BYOLTask, ChesapeakeCVPRDataModule), "byol": (BYOLTask, ChesapeakeCVPRDataModule),
"chesapeake_cvpr": (ChesapeakeCVPRSegmentationTask, ChesapeakeCVPRDataModule), "chesapeake_cvpr": (ChesapeakeCVPRSegmentationTask, ChesapeakeCVPRDataModule),
"cyclone": (RegressionTask, CycloneDataModule), "cyclone": (RegressionTask, CycloneDataModule),