Refactoring: RegressionTask trainer (#215)

* Refactoring: RegressionTask trainer

* Fix import sorting

* Update trainer tutorial

* Use torchmetrics for metric logging (see the sketch below)
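The torchmetrics change is the substantive one: rather than hand-computing RMSE with torch.sqrt on the MSE loss, the task now accumulates a MetricCollection per split and logs it at epoch end. A minimal sketch of that pattern, separate from the diff itself (the prediction/target tensors are made-up examples):

    import torch
    from torchmetrics import MeanSquaredError, MetricCollection

    # MeanSquaredError(squared=False) reports RMSE; clone() gives each split
    # its own copy of the collection with a different key prefix.
    train_metrics = MetricCollection(
        {"RMSE": MeanSquaredError(squared=False)}, prefix="train_"
    )
    val_metrics = train_metrics.clone(prefix="val_")

    preds = torch.tensor([[42.0], [55.0]])   # hypothetical wind speeds
    target = torch.tensor([[40.0], [60.0]])  # hypothetical ground truth
    train_metrics(preds, target)             # update running state each step
    print(train_metrics.compute())           # {'train_RMSE': tensor(3.8079)}
    train_metrics.reset()                    # clear state at epoch end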
Adam J. Stewart 2021-11-01 12:53:09 -05:00 committed by GitHub
Parent 69598528e6
Commit 3446ea5f47
No key found matching this signature
GPG key ID: 4AEE18F83AFDEB23
7 changed files with 217 additions and 185 deletions

View file

@@ -78,7 +78,7 @@
"import pytorch_lightning as pl\n",
"from pytorch_lightning.callbacks import EarlyStopping, ModelCheckpoint\n",
"from pytorch_lightning.loggers import CSVLogger\n",
"from torchgeo.trainers import CycloneSimpleRegressionTask, CycloneDataModule"
"from torchgeo.trainers import CycloneDataModule, RegressionTask"
]
},
{
@@ -134,7 +134,7 @@
"id": "HQVji2B22Qfu"
},
"source": [
"Next, we create a `CycloneSimpleRegressionTask` object that holds the model object, optimizer object, and training logic."
"Next, we create a `RegressionTask` object that holds the model object, optimizer object, and training logic."
]
},
{
@@ -150,7 +150,7 @@
},
"outputs": [],
"source": [
"task = CycloneSimpleRegressionTask(\n",
"task = RegressionTask(\n",
" model=\"resnet18\",\n",
" learning_rate=0.1,\n",
" learning_rate_schedule_patience=5\n",

View file

@@ -2,15 +2,10 @@
# Licensed under the MIT License.

import os
-from typing import Any, Dict, Generator, cast

import pytest
-from _pytest.monkeypatch import MonkeyPatch
-from omegaconf import OmegaConf

-from torchgeo.trainers import CycloneDataModule, CycloneSimpleRegressionTask
-
-from .test_utils import mocked_log
+from torchgeo.trainers import CycloneDataModule


@pytest.fixture(scope="module")
@@ -25,54 +20,6 @@ def datamodule() -> CycloneDataModule:
    return dm


-class TestCycloneSimpleRegressionTask:
-    @pytest.fixture
-    def config(self) -> Dict[str, Any]:
-        task_conf = OmegaConf.load(
-            os.path.join("conf", "task_defaults", "cyclone.yaml")
-        )
-        task_args = OmegaConf.to_object(task_conf.experiment.module)
-        task_args = cast(Dict[str, Any], task_args)
-        return task_args
-
-    @pytest.fixture
-    def task(
-        self, config: Dict[str, Any], monkeypatch: Generator[MonkeyPatch, None, None]
-    ) -> CycloneSimpleRegressionTask:
-        task = CycloneSimpleRegressionTask(**config)
-        monkeypatch.setattr(task, "log", mocked_log)  # type: ignore[attr-defined]
-        return task
-
-    def test_configure_optimizers(self, task: CycloneSimpleRegressionTask) -> None:
-        out = task.configure_optimizers()
-        assert "optimizer" in out
-        assert "lr_scheduler" in out
-
-    def test_training(
-        self, datamodule: CycloneDataModule, task: CycloneSimpleRegressionTask
-    ) -> None:
-        batch = next(iter(datamodule.train_dataloader()))
-        task.training_step(batch, 0)
-
-    def test_validation(
-        self, datamodule: CycloneDataModule, task: CycloneSimpleRegressionTask
-    ) -> None:
-        batch = next(iter(datamodule.val_dataloader()))
-        task.validation_step(batch, 0)
-
-    def test_test(
-        self, datamodule: CycloneDataModule, task: CycloneSimpleRegressionTask
-    ) -> None:
-        batch = next(iter(datamodule.test_dataloader()))
-        task.test_step(batch, 0)
-
-    def test_invalid_model(self, config: Dict[str, Any]) -> None:
-        config["model"] = "invalid_model"
-        error_message = "Model type 'invalid_model' is not valid."
-        with pytest.raises(ValueError, match=error_message):
-            CycloneSimpleRegressionTask(**config)
-
-
class TestCycloneDataModule:
    def test_train_dataloader(self, datamodule: CycloneDataModule) -> None:
        next(iter(datamodule.train_dataloader()))

View file

@@ -9,7 +9,12 @@ from _pytest.fixtures import SubRequest
from _pytest.monkeypatch import MonkeyPatch
from omegaconf import OmegaConf
-from torchgeo.trainers import ClassificationTask, So2SatDataModule
+from torchgeo.trainers import (
+    ClassificationTask,
+    CycloneDataModule,
+    RegressionTask,
+    So2SatDataModule,
+)
from .test_utils import mocked_log
@@ -110,3 +115,63 @@ class TestClassificationTask:
error_message = "Trying to load resnet18 weights into a resnet50"
with pytest.raises(ValueError, match=error_message):
ClassificationTask(**config)
class TestRegressionTask:
@pytest.fixture(scope="class")
def datamodule(self) -> CycloneDataModule:
root = os.path.join("tests", "data", "cyclone")
seed = 0
batch_size = 1
num_workers = 0
dm = CycloneDataModule(root, seed, batch_size, num_workers)
dm.prepare_data()
dm.setup()
return dm
@pytest.fixture
def config(self) -> Dict[str, Any]:
task_conf = OmegaConf.load(
os.path.join("conf", "task_defaults", "cyclone.yaml")
)
task_args = OmegaConf.to_object(task_conf.experiment.module)
task_args = cast(Dict[str, Any], task_args)
return task_args
@pytest.fixture
def task(
self, config: Dict[str, Any], monkeypatch: Generator[MonkeyPatch, None, None]
) -> RegressionTask:
task = RegressionTask(**config)
monkeypatch.setattr(task, "log", mocked_log) # type: ignore[attr-defined]
return task
def test_configure_optimizers(self, task: RegressionTask) -> None:
out = task.configure_optimizers()
assert "optimizer" in out
assert "lr_scheduler" in out
def test_training(
self, datamodule: CycloneDataModule, task: RegressionTask
) -> None:
batch = next(iter(datamodule.train_dataloader()))
task.training_step(batch, 0)
task.training_epoch_end(0)
def test_validation(
self, datamodule: CycloneDataModule, task: RegressionTask
) -> None:
batch = next(iter(datamodule.val_dataloader()))
task.validation_step(batch, 0)
task.validation_epoch_end(0)
def test_test(self, datamodule: CycloneDataModule, task: RegressionTask) -> None:
batch = next(iter(datamodule.test_dataloader()))
task.test_step(batch, 0)
task.test_epoch_end(0)
def test_invalid_model(self, config: Dict[str, Any]) -> None:
config["model"] = "invalid_model"
error_message = "Model type 'invalid_model' is not valid."
with pytest.raises(ValueError, match=error_message):
RegressionTask(**config)
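These tests replace the task's log method with mocked_log from .test_utils, which this diff does not show. A plausible minimal stand-in, assuming its only job is to absorb logging calls that would otherwise fail without an attached pl.Trainer (the body here is hypothetical):

    from typing import Any

    def mocked_log(*args: Any, **kwargs: Any) -> None:
        # No-op stand-in for LightningModule.log when no Trainer is attached.
        pass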

View file

@@ -5,24 +5,24 @@
from .byol import BYOLTask
from .chesapeake import ChesapeakeCVPRDataModule, ChesapeakeCVPRSegmentationTask
-from .cyclone import CycloneDataModule, CycloneSimpleRegressionTask
+from .cyclone import CycloneDataModule
from .landcoverai import LandcoverAIDataModule, LandcoverAISegmentationTask
from .naipchesapeake import NAIPChesapeakeDataModule, NAIPChesapeakeSegmentationTask
from .resisc45 import RESISC45ClassificationTask, RESISC45DataModule
from .sen12ms import SEN12MSDataModule, SEN12MSSegmentationTask
from .so2sat import So2SatClassificationTask, So2SatDataModule
-from .tasks import ClassificationTask
+from .tasks import ClassificationTask, RegressionTask
from .ucmerced import UCMercedClassificationTask, UCMercedDataModule

__all__ = (
    # Tasks
    "ClassificationTask",
+    "RegressionTask",
    # Trainers
    "BYOLTask",
    "ChesapeakeCVPRSegmentationTask",
    "ChesapeakeCVPRDataModule",
    "CycloneDataModule",
-    "CycloneSimpleRegressionTask",
    "LandcoverAIDataModule",
    "LandcoverAISegmentationTask",
    "NAIPChesapeakeDataModule",

View file

@@ -7,13 +7,9 @@ from typing import Any, Dict, Optional

import pytorch_lightning as pl
import torch
-import torch.nn.functional as F
from sklearn.model_selection import GroupShuffleSplit
-from torch import Tensor
from torch.nn.modules import Module
-from torch.optim.lr_scheduler import ReduceLROnPlateau
from torch.utils.data import DataLoader, Subset
-from torchvision import models

from ..datasets import TropicalCycloneWindEstimation

@@ -23,120 +19,6 @@ DataLoader.__module__ = "torch.utils.data"
Module.__module__ = "torch.nn"
-class CycloneSimpleRegressionTask(pl.LightningModule):
-    """LightningModule for training models on the NASA Cyclone Dataset using MSE loss.
-
-    This does not take into account other per-sample features available in this dataset.
-    """
-
-    def config_task(self) -> None:
-        """Configures the task based on kwargs parameters."""
-        if self.hparams["model"] == "resnet18":
-            self.model = models.resnet18(pretrained=False, num_classes=1)
-        else:
-            raise ValueError(f"Model type '{self.hparams['model']}' is not valid.")
-
-    def __init__(self, **kwargs: Any) -> None:
-        """Initialize a new LightningModule for training simple regression models.
-
-        Keyword Args:
-            model: Name of the model to use
-            learning_rate: Initial learning rate to use in the optimizer
-            learning_rate_schedule_patience: Patience parameter for the LR scheduler
-        """
-        super().__init__()
-        self.save_hyperparameters()  # creates `self.hparams` from kwargs
-        self.config_task()
-
-    def forward(self, x: Tensor) -> Any:  # type: ignore[override]
-        """Forward pass of the model."""
-        return self.model(x)
-
-    def training_step(  # type: ignore[override]
-        self, batch: Dict[str, Any], batch_idx: int
-    ) -> Tensor:
-        """Training step with an MSE loss.
-
-        Args:
-            batch: Current batch
-            batch_idx: Index of current batch
-
-        Returns:
-            training loss
-        """
-        x = batch["image"]
-        y = batch["target"].view(-1, 1)
-        y_hat = self.forward(x)
-
-        loss = F.mse_loss(y_hat, y)
-        self.log("train_loss", loss)  # logging to TensorBoard
-
-        rmse = torch.sqrt(loss)  # type: ignore[attr-defined]
-        self.log("train_rmse", rmse)
-
-        return loss
-
-    def validation_step(  # type: ignore[override]
-        self, batch: Dict[str, Any], batch_idx: int
-    ) -> None:
-        """Validation step.
-
-        Args:
-            batch: Current batch
-            batch_idx: Index of current batch
-        """
-        x = batch["image"]
-        y = batch["target"].view(-1, 1)
-        y_hat = self.forward(x)
-
-        loss = F.mse_loss(y_hat, y)
-        self.log("val_loss", loss)
-
-        rmse = torch.sqrt(loss)  # type: ignore[attr-defined]
-        self.log("val_rmse", rmse)
-
-    def test_step(  # type: ignore[override]
-        self, batch: Dict[str, Any], batch_idx: int
-    ) -> None:
-        """Test step.
-
-        Args:
-            batch: Current batch
-            batch_idx: Index of current batch
-        """
-        x = batch["image"]
-        y = batch["target"].view(-1, 1)
-        y_hat = self.forward(x)
-
-        loss = F.mse_loss(y_hat, y)
-        self.log("test_loss", loss)
-
-        rmse = torch.sqrt(loss)  # type: ignore[attr-defined]
-        self.log("test_rmse", rmse)
-
-    def configure_optimizers(self) -> Dict[str, Any]:
-        """Initialize the optimizer and learning rate scheduler.
-
-        Returns:
-            a "lr dict" according to the pytorch lightning documentation --
-            https://pytorch-lightning.readthedocs.io/en/latest/common/lightning_module.html#configure-optimizers
-        """
-        optimizer = torch.optim.AdamW(
-            self.model.parameters(), lr=self.hparams["learning_rate"]
-        )
-        return {
-            "optimizer": optimizer,
-            "lr_scheduler": {
-                "scheduler": ReduceLROnPlateau(
-                    optimizer, patience=self.hparams["learning_rate_schedule_patience"]
-                ),
-                "monitor": "val_loss",
-            },
-        }
-
-
class CycloneDataModule(pl.LightningDataModule):
    """LightningDataModule implementation for the NASA Cyclone dataset.

View file

@@ -9,12 +9,14 @@ from typing import Any, Dict, cast

import pytorch_lightning as pl
import torch
import torch.nn as nn
+import torch.nn.functional as F
import torchvision.models
from segmentation_models_pytorch.losses import FocalLoss, JaccardLoss
from torch import Tensor
from torch.nn.modules import Conv2d, Linear
from torch.optim.lr_scheduler import ReduceLROnPlateau
-from torchmetrics import Accuracy, FBeta, IoU, MetricCollection
+from torchmetrics import Accuracy, FBeta, IoU, MeanSquaredError, MetricCollection
+from torchvision import models

from . import utils
@@ -132,9 +134,7 @@ class ClassificationTask(pl.LightningModule):
                ),
                "IoU": IoU(num_classes=self.num_classes),
                "F1Score": FBeta(
-                    num_classes=self.num_classes,
-                    beta=1.0,
-                    average="micro",
+                    num_classes=self.num_classes, beta=1.0, average="micro"
                ),
            },
            prefix="train_",
@@ -264,3 +264,141 @@ class ClassificationTask(pl.LightningModule):
                "monitor": "val_loss",
            },
        }
+
+
+class RegressionTask(pl.LightningModule):
+    """LightningModule for training models on regression datasets."""
+
+    def config_task(self) -> None:
+        """Configures the task based on kwargs parameters."""
+        if self.hparams["model"] == "resnet18":
+            self.model = models.resnet18(pretrained=False, num_classes=1)
+        else:
+            raise ValueError(f"Model type '{self.hparams['model']}' is not valid.")
+
+    def __init__(self, **kwargs: Any) -> None:
+        """Initialize a new LightningModule for training simple regression models.
+
+        Keyword Args:
+            model: Name of the model to use
+            learning_rate: Initial learning rate to use in the optimizer
+            learning_rate_schedule_patience: Patience parameter for the LR scheduler
+        """
+        super().__init__()
+        self.save_hyperparameters()  # creates `self.hparams` from kwargs
+        self.config_task()
+
+        self.train_metrics = MetricCollection(
+            {"RMSE": MeanSquaredError(squared=False)},
+            prefix="train_",
+        )
+        self.val_metrics = self.train_metrics.clone(prefix="val_")
+        self.test_metrics = self.train_metrics.clone(prefix="test_")
+
+    def forward(self, x: Tensor) -> Any:  # type: ignore[override]
+        """Forward pass of the model."""
+        return self.model(x)
+
+    def training_step(  # type: ignore[override]
+        self, batch: Dict[str, Any], batch_idx: int
+    ) -> Tensor:
+        """Training step with an MSE loss.
+
+        Args:
+            batch: Current batch
+            batch_idx: Index of current batch
+
+        Returns:
+            training loss
+        """
+        x = batch["image"]
+        y = batch["target"].view(-1, 1)
+        y_hat = self.forward(x)
+
+        loss = F.mse_loss(y_hat, y)
+        self.log("train_loss", loss)  # logging to TensorBoard
+        self.train_metrics(y_hat, y)
+
+        return loss
+
+    def training_epoch_end(self, outputs: Any) -> None:
+        """Logs epoch-level training metrics.
+
+        Args:
+            outputs: list of items returned by training_step
+        """
+        self.log_dict(self.train_metrics.compute())
+        self.train_metrics.reset()
+
+    def validation_step(  # type: ignore[override]
+        self, batch: Dict[str, Any], batch_idx: int
+    ) -> None:
+        """Validation step.
+
+        Args:
+            batch: Current batch
+            batch_idx: Index of current batch
+        """
+        x = batch["image"]
+        y = batch["target"].view(-1, 1)
+        y_hat = self.forward(x)
+
+        loss = F.mse_loss(y_hat, y)
+        self.log("val_loss", loss)
+        self.val_metrics(y_hat, y)
+
+    def validation_epoch_end(self, outputs: Any) -> None:
+        """Logs epoch-level validation metrics.
+
+        Args:
+            outputs: list of items returned by validation_step
+        """
+        self.log_dict(self.val_metrics.compute())
+        self.val_metrics.reset()
+
+    def test_step(  # type: ignore[override]
+        self, batch: Dict[str, Any], batch_idx: int
+    ) -> None:
+        """Test step.
+
+        Args:
+            batch: Current batch
+            batch_idx: Index of current batch
+        """
+        x = batch["image"]
+        y = batch["target"].view(-1, 1)
+        y_hat = self.forward(x)
+
+        loss = F.mse_loss(y_hat, y)
+        self.log("test_loss", loss)
+        self.test_metrics(y_hat, y)
+
+    def test_epoch_end(self, outputs: Any) -> None:
+        """Logs epoch-level test metrics.
+
+        Args:
+            outputs: list of items returned by test_step
+        """
+        self.log_dict(self.test_metrics.compute())
+        self.test_metrics.reset()
+
+    def configure_optimizers(self) -> Dict[str, Any]:
+        """Initialize the optimizer and learning rate scheduler.
+
+        Returns:
+            a "lr dict" according to the pytorch lightning documentation --
+            https://pytorch-lightning.readthedocs.io/en/latest/common/lightning_module.html#configure-optimizers
+        """
+        optimizer = torch.optim.AdamW(
+            self.model.parameters(), lr=self.hparams["learning_rate"]
+        )
+        return {
+            "optimizer": optimizer,
+            "lr_scheduler": {
+                "scheduler": ReduceLROnPlateau(
+                    optimizer, patience=self.hparams["learning_rate_schedule_patience"]
+                ),
+                "monitor": "val_loss",
+            },
+        }
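End to end, the new task plugs into Lightning the same way the old one did; a minimal usage sketch, not part of the diff (the data root and trainer settings are illustrative; the task hyperparameters mirror the tutorial hunk above, and the datamodule arguments follow the test fixture's positional order):

    import pytorch_lightning as pl
    from torchgeo.trainers import CycloneDataModule, RegressionTask

    datamodule = CycloneDataModule(
        "data/cyclone",  # hypothetical data root
        0,   # seed
        32,  # batch size
        4,   # dataloader workers
    )
    task = RegressionTask(
        model="resnet18",  # the only backbone config_task currently accepts
        learning_rate=0.1,
        learning_rate_schedule_patience=5,
    )
    trainer = pl.Trainer(max_epochs=10)  # illustrative settings
    trainer.fit(model=task, datamodule=datamodule)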

View file

@@ -18,11 +18,11 @@ from torchgeo.trainers import (
    ChesapeakeCVPRDataModule,
    ChesapeakeCVPRSegmentationTask,
    CycloneDataModule,
-    CycloneSimpleRegressionTask,
    LandcoverAIDataModule,
    LandcoverAISegmentationTask,
    NAIPChesapeakeDataModule,
    NAIPChesapeakeSegmentationTask,
+    RegressionTask,
    RESISC45ClassificationTask,
    RESISC45DataModule,
    SEN12MSDataModule,
@@ -38,7 +38,7 @@ TASK_TO_MODULES_MAPPING: Dict[
] = {
    "byol": (BYOLTask, ChesapeakeCVPRDataModule),
    "chesapeake_cvpr": (ChesapeakeCVPRSegmentationTask, ChesapeakeCVPRDataModule),
-    "cyclone": (CycloneSimpleRegressionTask, CycloneDataModule),
+    "cyclone": (RegressionTask, CycloneDataModule),
    "landcoverai": (LandcoverAISegmentationTask, LandcoverAIDataModule),
    "naipchesapeake": (NAIPChesapeakeSegmentationTask, NAIPChesapeakeDataModule),
    "resisc45": (RESISC45ClassificationTask, RESISC45DataModule),