Mirror of https://github.com/microsoft/torchgeo.git
Refactoring: RegressionTask trainer (#215)
* Refactoring: RegressionTask trainer
* Fix import sorting
* Update trainer tutorial
* Use torchmetrics for metric logging
Parent: 69598528e6
Commit: 3446ea5f47
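The upshot of the commit: the dataset-specific `CycloneSimpleRegressionTask` becomes a generic `RegressionTask` exported from `torchgeo.trainers`, with RMSE tracked via torchmetrics instead of hand-rolled `torch.sqrt` calls. A minimal usage sketch of the post-commit API; the task hyperparameters come from the tutorial cell changed below, while the data-module arguments and `Trainer` settings are illustrative assumptions, not part of this commit:

import pytorch_lightning as pl

from torchgeo.trainers import CycloneDataModule, RegressionTask

# Root path and loader settings are hypothetical; the argument order follows
# the test fixture in this commit.
datamodule = CycloneDataModule("data/cyclone", seed=0, batch_size=32, num_workers=4)
task = RegressionTask(
    model="resnet18",
    learning_rate=0.1,
    learning_rate_schedule_patience=5,
)
trainer = pl.Trainer(max_epochs=10)  # any standard Trainer arguments work here
trainer.fit(model=task, datamodule=datamodule)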
@@ -78,7 +78,7 @@
 "import pytorch_lightning as pl\n",
 "from pytorch_lightning.callbacks import EarlyStopping, ModelCheckpoint\n",
 "from pytorch_lightning.loggers import CSVLogger\n",
-"from torchgeo.trainers import CycloneSimpleRegressionTask, CycloneDataModule"
+"from torchgeo.trainers import CycloneDataModule, RegressionTask"
 ]
 },
 {
@@ -134,7 +134,7 @@
 "id": "HQVji2B22Qfu"
 },
 "source": [
-"Next, we create a `CycloneSimpleRegressionTask` object that holds the model object, optimizer object, and training logic."
+"Next, we create a `RegressionTask` object that holds the model object, optimizer object, and training logic."
 ]
 },
 {
@@ -150,7 +150,7 @@
 },
 "outputs": [],
 "source": [
-"task = CycloneSimpleRegressionTask(\n",
+"task = RegressionTask(\n",
 "    model=\"resnet18\",\n",
 "    learning_rate=0.1,\n",
 "    learning_rate_schedule_patience=5\n",
@@ -2,15 +2,10 @@
 # Licensed under the MIT License.

 import os
-from typing import Any, Dict, Generator, cast

 import pytest
-from _pytest.monkeypatch import MonkeyPatch
-from omegaconf import OmegaConf

-from torchgeo.trainers import CycloneDataModule, CycloneSimpleRegressionTask
-
-from .test_utils import mocked_log
+from torchgeo.trainers import CycloneDataModule


 @pytest.fixture(scope="module")
@@ -25,54 +20,6 @@ def datamodule() -> CycloneDataModule:
     return dm


-class TestCycloneSimpleRegressionTask:
-    @pytest.fixture
-    def config(self) -> Dict[str, Any]:
-        task_conf = OmegaConf.load(
-            os.path.join("conf", "task_defaults", "cyclone.yaml")
-        )
-        task_args = OmegaConf.to_object(task_conf.experiment.module)
-        task_args = cast(Dict[str, Any], task_args)
-        return task_args
-
-    @pytest.fixture
-    def task(
-        self, config: Dict[str, Any], monkeypatch: Generator[MonkeyPatch, None, None]
-    ) -> CycloneSimpleRegressionTask:
-        task = CycloneSimpleRegressionTask(**config)
-        monkeypatch.setattr(task, "log", mocked_log)  # type: ignore[attr-defined]
-        return task
-
-    def test_configure_optimizers(self, task: CycloneSimpleRegressionTask) -> None:
-        out = task.configure_optimizers()
-        assert "optimizer" in out
-        assert "lr_scheduler" in out
-
-    def test_training(
-        self, datamodule: CycloneDataModule, task: CycloneSimpleRegressionTask
-    ) -> None:
-        batch = next(iter(datamodule.train_dataloader()))
-        task.training_step(batch, 0)
-
-    def test_validation(
-        self, datamodule: CycloneDataModule, task: CycloneSimpleRegressionTask
-    ) -> None:
-        batch = next(iter(datamodule.val_dataloader()))
-        task.validation_step(batch, 0)
-
-    def test_test(
-        self, datamodule: CycloneDataModule, task: CycloneSimpleRegressionTask
-    ) -> None:
-        batch = next(iter(datamodule.test_dataloader()))
-        task.test_step(batch, 0)
-
-    def test_invalid_model(self, config: Dict[str, Any]) -> None:
-        config["model"] = "invalid_model"
-        error_message = "Model type 'invalid_model' is not valid."
-        with pytest.raises(ValueError, match=error_message):
-            CycloneSimpleRegressionTask(**config)
-
-
 class TestCycloneDataModule:
     def test_train_dataloader(self, datamodule: CycloneDataModule) -> None:
         next(iter(datamodule.train_dataloader()))
@@ -9,7 +9,12 @@ from _pytest.fixtures import SubRequest
 from _pytest.monkeypatch import MonkeyPatch
 from omegaconf import OmegaConf

-from torchgeo.trainers import ClassificationTask, So2SatDataModule
+from torchgeo.trainers import (
+    ClassificationTask,
+    CycloneDataModule,
+    RegressionTask,
+    So2SatDataModule,
+)

 from .test_utils import mocked_log

@@ -110,3 +115,63 @@ class TestClassificationTask:
         error_message = "Trying to load resnet18 weights into a resnet50"
         with pytest.raises(ValueError, match=error_message):
             ClassificationTask(**config)
+
+
+class TestRegressionTask:
+    @pytest.fixture(scope="class")
+    def datamodule(self) -> CycloneDataModule:
+        root = os.path.join("tests", "data", "cyclone")
+        seed = 0
+        batch_size = 1
+        num_workers = 0
+        dm = CycloneDataModule(root, seed, batch_size, num_workers)
+        dm.prepare_data()
+        dm.setup()
+        return dm
+
+    @pytest.fixture
+    def config(self) -> Dict[str, Any]:
+        task_conf = OmegaConf.load(
+            os.path.join("conf", "task_defaults", "cyclone.yaml")
+        )
+        task_args = OmegaConf.to_object(task_conf.experiment.module)
+        task_args = cast(Dict[str, Any], task_args)
+        return task_args
+
+    @pytest.fixture
+    def task(
+        self, config: Dict[str, Any], monkeypatch: Generator[MonkeyPatch, None, None]
+    ) -> RegressionTask:
+        task = RegressionTask(**config)
+        monkeypatch.setattr(task, "log", mocked_log)  # type: ignore[attr-defined]
+        return task
+
+    def test_configure_optimizers(self, task: RegressionTask) -> None:
+        out = task.configure_optimizers()
+        assert "optimizer" in out
+        assert "lr_scheduler" in out
+
+    def test_training(
+        self, datamodule: CycloneDataModule, task: RegressionTask
+    ) -> None:
+        batch = next(iter(datamodule.train_dataloader()))
+        task.training_step(batch, 0)
+        task.training_epoch_end(0)
+
+    def test_validation(
+        self, datamodule: CycloneDataModule, task: RegressionTask
+    ) -> None:
+        batch = next(iter(datamodule.val_dataloader()))
+        task.validation_step(batch, 0)
+        task.validation_epoch_end(0)
+
+    def test_test(self, datamodule: CycloneDataModule, task: RegressionTask) -> None:
+        batch = next(iter(datamodule.test_dataloader()))
+        task.test_step(batch, 0)
+        task.test_epoch_end(0)
+
+    def test_invalid_model(self, config: Dict[str, Any]) -> None:
+        config["model"] = "invalid_model"
+        error_message = "Model type 'invalid_model' is not valid."
+        with pytest.raises(ValueError, match=error_message):
+            RegressionTask(**config)
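The `task` fixture above exercises LightningModule hooks without attaching a Trainer by monkeypatching the module's `log` method. A rough standalone sketch of that pattern; the body of `mocked_log` is an assumption, since the real helper in `.test_utils` is not shown in this diff:

import torch

from torchgeo.trainers import RegressionTask

def mocked_log(*args: object, **kwargs: object) -> None:
    """Stand-in for the tests' mocked_log helper; its real body is not in the diff."""

task = RegressionTask(
    model="resnet18", learning_rate=0.1, learning_rate_schedule_patience=5
)
task.log = mocked_log  # shadow the bound method so *_step can run without a Trainer

# A real batch would come from CycloneDataModule; these tensors are fabricated.
batch = {"image": torch.rand(2, 3, 224, 224), "target": torch.rand(2) * 100}
loss = task.training_step(batch, 0)  # returns the MSE loss tensor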
@@ -5,24 +5,24 @@

 from .byol import BYOLTask
 from .chesapeake import ChesapeakeCVPRDataModule, ChesapeakeCVPRSegmentationTask
-from .cyclone import CycloneDataModule, CycloneSimpleRegressionTask
+from .cyclone import CycloneDataModule
 from .landcoverai import LandcoverAIDataModule, LandcoverAISegmentationTask
 from .naipchesapeake import NAIPChesapeakeDataModule, NAIPChesapeakeSegmentationTask
 from .resisc45 import RESISC45ClassificationTask, RESISC45DataModule
 from .sen12ms import SEN12MSDataModule, SEN12MSSegmentationTask
 from .so2sat import So2SatClassificationTask, So2SatDataModule
-from .tasks import ClassificationTask
+from .tasks import ClassificationTask, RegressionTask
 from .ucmerced import UCMercedClassificationTask, UCMercedDataModule

 __all__ = (
     # Tasks
     "ClassificationTask",
+    "RegressionTask",
     # Trainers
     "BYOLTask",
     "ChesapeakeCVPRSegmentationTask",
     "ChesapeakeCVPRDataModule",
     "CycloneDataModule",
-    "CycloneSimpleRegressionTask",
     "LandcoverAIDataModule",
     "LandcoverAISegmentationTask",
     "NAIPChesapeakeDataModule",
@@ -7,13 +7,9 @@ from typing import Any, Dict, Optional

 import pytorch_lightning as pl
 import torch
-import torch.nn.functional as F
 from sklearn.model_selection import GroupShuffleSplit
-from torch import Tensor
 from torch.nn.modules import Module
-from torch.optim.lr_scheduler import ReduceLROnPlateau
 from torch.utils.data import DataLoader, Subset
-from torchvision import models

 from ..datasets import TropicalCycloneWindEstimation

@@ -23,120 +19,6 @@ DataLoader.__module__ = "torch.utils.data"
 Module.__module__ = "torch.nn"


-class CycloneSimpleRegressionTask(pl.LightningModule):
-    """LightningModule for training models on the NASA Cyclone Dataset using MSE loss.
-
-    This does not take into account other per-sample features available in this dataset.
-    """
-
-    def config_task(self) -> None:
-        """Configures the task based on kwargs parameters."""
-        if self.hparams["model"] == "resnet18":
-            self.model = models.resnet18(pretrained=False, num_classes=1)
-        else:
-            raise ValueError(f"Model type '{self.hparams['model']}' is not valid.")
-
-    def __init__(self, **kwargs: Any) -> None:
-        """Initialize a new LightningModule for training simple regression models.
-
-        Keyword Args:
-            model: Name of the model to use
-            learning_rate: Initial learning rate to use in the optimizer
-            learning_rate_schedule_patience: Patience parameter for the LR scheduler
-        """
-        super().__init__()
-        self.save_hyperparameters()  # creates `self.hparams` from kwargs
-
-        self.config_task()
-
-    def forward(self, x: Tensor) -> Any:  # type: ignore[override]
-        """Forward pass of the model."""
-        return self.model(x)
-
-    def training_step(  # type: ignore[override]
-        self, batch: Dict[str, Any], batch_idx: int
-    ) -> Tensor:
-        """Training step with an MSE loss.
-
-        Args:
-            batch: Current batch
-            batch_idx: Index of current batch
-
-        Returns:
-            training loss
-        """
-        x = batch["image"]
-        y = batch["target"].view(-1, 1)
-        y_hat = self.forward(x)
-
-        loss = F.mse_loss(y_hat, y)
-
-        self.log("train_loss", loss)  # logging to TensorBoard
-
-        rmse = torch.sqrt(loss)  # type: ignore[attr-defined]
-        self.log("train_rmse", rmse)
-
-        return loss
-
-    def validation_step(  # type: ignore[override]
-        self, batch: Dict[str, Any], batch_idx: int
-    ) -> None:
-        """Validation step.
-
-        Args:
-            batch: Current batch
-            batch_idx: Index of current batch
-        """
-        x = batch["image"]
-        y = batch["target"].view(-1, 1)
-        y_hat = self.forward(x)
-
-        loss = F.mse_loss(y_hat, y)
-        self.log("val_loss", loss)
-
-        rmse = torch.sqrt(loss)  # type: ignore[attr-defined]
-        self.log("val_rmse", rmse)
-
-    def test_step(  # type: ignore[override]
-        self, batch: Dict[str, Any], batch_idx: int
-    ) -> None:
-        """Test step.
-
-        Args:
-            batch: Current batch
-            batch_idx: Index of current batch
-        """
-        x = batch["image"]
-        y = batch["target"].view(-1, 1)
-        y_hat = self.forward(x)
-
-        loss = F.mse_loss(y_hat, y)
-        self.log("test_loss", loss)
-
-        rmse = torch.sqrt(loss)  # type: ignore[attr-defined]
-        self.log("test_rmse", rmse)
-
-    def configure_optimizers(self) -> Dict[str, Any]:
-        """Initialize the optimizer and learning rate scheduler.
-
-        Returns:
-            a "lr dict" according to the pytorch lightning documentation --
-            https://pytorch-lightning.readthedocs.io/en/latest/common/lightning_module.html#configure-optimizers
-        """
-        optimizer = torch.optim.AdamW(
-            self.model.parameters(), lr=self.hparams["learning_rate"]
-        )
-        return {
-            "optimizer": optimizer,
-            "lr_scheduler": {
-                "scheduler": ReduceLROnPlateau(
-                    optimizer, patience=self.hparams["learning_rate_schedule_patience"]
-                ),
-                "monitor": "val_loss",
-            },
-        }
-
-
 class CycloneDataModule(pl.LightningDataModule):
     """LightningDataModule implementation for the NASA Cyclone dataset.

@@ -9,12 +9,14 @@ from typing import Any, Dict, cast
 import pytorch_lightning as pl
 import torch
 import torch.nn as nn
 import torch.nn.functional as F
 import torchvision.models
 from segmentation_models_pytorch.losses import FocalLoss, JaccardLoss
 from torch import Tensor
 from torch.nn.modules import Conv2d, Linear
 from torch.optim.lr_scheduler import ReduceLROnPlateau
-from torchmetrics import Accuracy, FBeta, IoU, MetricCollection
+from torchmetrics import Accuracy, FBeta, IoU, MeanSquaredError, MetricCollection
+from torchvision import models

 from . import utils

@@ -132,9 +134,7 @@ class ClassificationTask(pl.LightningModule):
                 ),
                 "IoU": IoU(num_classes=self.num_classes),
                 "F1Score": FBeta(
-                    num_classes=self.num_classes,
-                    beta=1.0,
-                    average="micro",
+                    num_classes=self.num_classes, beta=1.0, average="micro"
                 ),
             },
             prefix="train_",
@@ -264,3 +264,141 @@ class ClassificationTask(pl.LightningModule):
                 "monitor": "val_loss",
             },
         }
+
+
+class RegressionTask(pl.LightningModule):
+    """LightningModule for training models on regression datasets."""
+
+    def config_task(self) -> None:
+        """Configures the task based on kwargs parameters."""
+        if self.hparams["model"] == "resnet18":
+            self.model = models.resnet18(pretrained=False, num_classes=1)
+        else:
+            raise ValueError(f"Model type '{self.hparams['model']}' is not valid.")
+
+    def __init__(self, **kwargs: Any) -> None:
+        """Initialize a new LightningModule for training simple regression models.
+
+        Keyword Args:
+            model: Name of the model to use
+            learning_rate: Initial learning rate to use in the optimizer
+            learning_rate_schedule_patience: Patience parameter for the LR scheduler
+        """
+        super().__init__()
+        self.save_hyperparameters()  # creates `self.hparams` from kwargs
+        self.config_task()
+
+        self.train_metrics = MetricCollection(
+            {"RMSE": MeanSquaredError(squared=False)},
+            prefix="train_",
+        )
+        self.val_metrics = self.train_metrics.clone(prefix="val_")
+        self.test_metrics = self.train_metrics.clone(prefix="test_")
+
+    def forward(self, x: Tensor) -> Any:  # type: ignore[override]
+        """Forward pass of the model."""
+        return self.model(x)
+
+    def training_step(  # type: ignore[override]
+        self, batch: Dict[str, Any], batch_idx: int
+    ) -> Tensor:
+        """Training step with an MSE loss.
+
+        Args:
+            batch: Current batch
+            batch_idx: Index of current batch
+
+        Returns:
+            training loss
+        """
+        x = batch["image"]
+        y = batch["target"].view(-1, 1)
+        y_hat = self.forward(x)
+
+        loss = F.mse_loss(y_hat, y)
+
+        self.log("train_loss", loss)  # logging to TensorBoard
+        self.train_metrics(y_hat, y)
+
+        return loss
+
+    def training_epoch_end(self, outputs: Any) -> None:
+        """Logs epoch-level training metrics.
+
+        Args:
+            outputs: list of items returned by training_step
+        """
+        self.log_dict(self.train_metrics.compute())
+        self.train_metrics.reset()
+
+    def validation_step(  # type: ignore[override]
+        self, batch: Dict[str, Any], batch_idx: int
+    ) -> None:
+        """Validation step.
+
+        Args:
+            batch: Current batch
+            batch_idx: Index of current batch
+        """
+        x = batch["image"]
+        y = batch["target"].view(-1, 1)
+        y_hat = self.forward(x)
+
+        loss = F.mse_loss(y_hat, y)
+        self.log("val_loss", loss)
+        self.val_metrics(y_hat, y)
+
+    def validation_epoch_end(self, outputs: Any) -> None:
+        """Logs epoch level validation metrics.
+
+        Args:
+            outputs: list of items returned by validation_step
+        """
+        self.log_dict(self.val_metrics.compute())
+        self.val_metrics.reset()
+
+    def test_step(  # type: ignore[override]
+        self, batch: Dict[str, Any], batch_idx: int
+    ) -> None:
+        """Test step.
+
+        Args:
+            batch: Current batch
+            batch_idx: Index of current batch
+        """
+        x = batch["image"]
+        y = batch["target"].view(-1, 1)
+        y_hat = self.forward(x)
+
+        loss = F.mse_loss(y_hat, y)
+        self.log("test_loss", loss)
+        self.test_metrics(y_hat, y)
+
+    def test_epoch_end(self, outputs: Any) -> None:
+        """Logs epoch level test metrics.
+
+        Args:
+            outputs: list of items returned by test_step
+        """
+        self.log_dict(self.test_metrics.compute())
+        self.test_metrics.reset()
+
+    def configure_optimizers(self) -> Dict[str, Any]:
+        """Initialize the optimizer and learning rate scheduler.
+
+        Returns:
+            a "lr dict" according to the pytorch lightning documentation --
+            https://pytorch-lightning.readthedocs.io/en/latest/common/lightning_module.html#configure-optimizers
+        """
+        optimizer = torch.optim.AdamW(
+            self.model.parameters(), lr=self.hparams["learning_rate"]
+        )
+        return {
+            "optimizer": optimizer,
+            "lr_scheduler": {
+                "scheduler": ReduceLROnPlateau(
+                    optimizer, patience=self.hparams["learning_rate_schedule_patience"]
+                ),
+                "monitor": "val_loss",
+            },
+        }
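The metric plumbing above replaces the removed class's per-step `torch.sqrt(loss)` logging with stateful torchmetrics objects. A self-contained sketch of the pattern; the tensors are made up, and only the torchmetrics calls mirror the code in the diff:

import torch
from torchmetrics import MeanSquaredError, MetricCollection

# squared=False makes MeanSquaredError report RMSE rather than MSE.
train_metrics = MetricCollection(
    {"RMSE": MeanSquaredError(squared=False)}, prefix="train_"
)
val_metrics = train_metrics.clone(prefix="val_")  # same metrics, different log prefix

y_hat = torch.tensor([[41.0], [55.0]])
y = torch.tensor([[40.0], [50.0]])
train_metrics(y_hat, y)            # accumulates state batch by batch
print(train_metrics.compute())     # {'train_RMSE': tensor(3.6056)}
train_metrics.reset()              # clear state at epoch end

Accumulating squared errors across batches and taking the root once in `*_epoch_end` yields a correct epoch-level RMSE, which simply averaging per-batch RMSE values would not.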
train.py (4 changed lines)
@@ -18,11 +18,11 @@ from torchgeo.trainers import (
     ChesapeakeCVPRDataModule,
     ChesapeakeCVPRSegmentationTask,
     CycloneDataModule,
-    CycloneSimpleRegressionTask,
     LandcoverAIDataModule,
     LandcoverAISegmentationTask,
     NAIPChesapeakeDataModule,
     NAIPChesapeakeSegmentationTask,
+    RegressionTask,
     RESISC45ClassificationTask,
     RESISC45DataModule,
     SEN12MSDataModule,
@@ -38,7 +38,7 @@ TASK_TO_MODULES_MAPPING: Dict[
 ] = {
     "byol": (BYOLTask, ChesapeakeCVPRDataModule),
     "chesapeake_cvpr": (ChesapeakeCVPRSegmentationTask, ChesapeakeCVPRDataModule),
-    "cyclone": (CycloneSimpleRegressionTask, CycloneDataModule),
+    "cyclone": (RegressionTask, CycloneDataModule),
     "landcoverai": (LandcoverAISegmentationTask, LandcoverAIDataModule),
     "naipchesapeake": (NAIPChesapeakeSegmentationTask, NAIPChesapeakeDataModule),
     "resisc45": (RESISC45ClassificationTask, RESISC45DataModule),
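Downstream, train.py resolves an experiment name to its (task, datamodule) pair through this mapping. A hypothetical lookup, assuming train.py is importable from the repository root:

from train import TASK_TO_MODULES_MAPPING
from torchgeo.trainers import CycloneDataModule, RegressionTask

# The "cyclone" experiment now instantiates the generic regression trainer.
task_class, datamodule_class = TASK_TO_MODULES_MAPPING["cyclone"]
assert task_class is RegressionTask
assert datamodule_class is CycloneDataModule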