Mirror of https://github.com/microsoft/torchgeo.git
Refactor trainer logic (#54)
* Moving task-specific configuration logic from train.py into the respective classes
* Small fixes
* Adding basic FCN model for benchmarking
* Adding simple FCN model
* Removing OrderedDict from model definitions
* Adding torchgeo.models to docs
* Adding model tests
* Making all the formatters happy
* Adding optimizer options to landcoverai
* Fixing conda environment, I think
* How do you feel about a Makefile, Adam?
* Formatting
* Adding some documentation to the README
* Sanity check command in README
* Fixes in the landcoverai datamodule to make multi-GPU training possible
* Closing figures that we send to Tensorboard
* Fix Sphinx missing target warning
* Fix pytest coverage
* Fix flake8
* Update torchgeo/models/__init__.py

Co-authored-by: Adam J. Stewart <ajstewart426@gmail.com>
This commit is contained in:
Parent: 14054a15c0
Commit: a54359dcde
Makefile (new file):

@@ -0,0 +1,12 @@
.PHONY: tests docs

tests:
	black --check .
	isort . --check --diff
	flake8 .
	pydocstyle .
	mypy .
	pytest --cov=. --cov-report=term-missing

docs:
	$(MAKE) -C docs html
README.md (17 lines changed):
@@ -11,8 +11,25 @@ Datasets, transforms, and models for geospatial data.
### Conda

```bash
conda config --set channel_priority false
conda env create --file environment.yml
conda activate torchgeo

# verify that PyTorch can use the GPU
python -c "import torch; print(torch.cuda.is_available())"
```

## Example training run

```bash
# run the training script with a config file
python train.py config_file=conf/landcoverai.yaml
```

## Developing

```bash
make tests
```

## Datasets
conf/task_defaults/cyclone.yaml:

@@ -1,4 +1,5 @@
 task:
-  name: "sen12ms"
+  name: "cyclone"
   learning_rate: 1e-3
   learning_rate_schedule_patience: 2
+  model: "resnet18"
conf/task_defaults/landcoverai.yaml:

@@ -1,4 +1,10 @@
 task:
   name: "landcoverai"
+  optimizer: "adamw"
   learning_rate: 1e-3
   learning_rate_schedule_patience: 2
+  loss: "ce"
+  segmentation_model: "deeplabv3+"
+  encoder_name: "resnet34"
+  encoder_weights: "imagenet"
+  encoder_output_stride: 16
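These task-default files are exactly what the new tests load to construct tasks. A minimal sketch of that pattern, mirroring the pytest fixtures added below (it assumes the repository root as the working directory):

```python
from typing import Any, Dict, cast

from omegaconf import OmegaConf

from torchgeo.trainers import LandcoverAISegmentationTask

# Load the task defaults and unpack them as keyword arguments, the same
# way the new test fixtures in this commit do.
task_conf = OmegaConf.load("conf/task_defaults/landcoverai.yaml")
task_args = cast(Dict[str, Any], OmegaConf.to_object(task_conf.task))
task = LandcoverAISegmentationTask(**task_args)
```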
conf/task_defaults/sen12ms.yaml:

@@ -2,3 +2,7 @@ task:
   name: "sen12ms"
   learning_rate: 1e-3
   learning_rate_schedule_patience: 2
+  loss: "ce"
+  segmentation_model: "unet"
+  encoder_name: "resnet18"
+  encoder_weights: "imagenet"
docs/index.rst:

@@ -14,6 +14,7 @@ architectures, and common image transformations for geospatial data.
    :caption: Package Reference

    datasets
+   models
    samplers
    trainers
    transforms
docs/models.rst (new file):

@@ -0,0 +1,4 @@
torchgeo.models
=================

.. automodule:: torchgeo.models
environment.yml:

@@ -1,12 +1,37 @@
 name: torchgeo
 channels:
   - pytorch
   - conda-forge
   - anaconda
 dependencies:
-  - cudatoolkit=11.1
+  - cudatoolkit
+  - h5py
+  - matplotlib
+  - numpy
+  - pip
-  - conda-forge:pycocotools
-  - python>=3.7
-  - pytorch::pytorch>=1.7
-  - conda-forge::rarfile
-  - pytorch::torchvision>=0.3
+  - pycocotools
+  - python
+  - pytorch>=1.8.1
+  - rarfile
+  - rasterio>=1.0
+  - torchvision>=0.9.1
+  - pip:
+    - affine
+    - black[colorama]>=21b
+    - flake8
+    - isort[colors]>=4.3.5
+    - mypy>=0.900
+    - omegaconf
+    - opencv-python
+    - pillow
+    - pydocstyle[toml]>=6.1
+    - pytest>=6.0
+    - pytest-cov
+    - pytorch-lightning
+    - git+https://github.com/pytorch/pytorch_sphinx_theme
+    - radiant-mlhub>=0.2.1
+    - rtree>=0.5.0
+    - scikit-learn
+    - segmentation-models-pytorch
+    - setuptools>=42
+    - sphinx
+    - torchmetrics
Tests for the new FCN model (new file):

@@ -0,0 +1,42 @@
import pytest
import torch

from torchgeo.models import FCN


class TestFCN:
    def test_in_channels(self) -> None:
        model = FCN(in_channels=5, classes=4, num_filters=10)
        x = torch.randn(2, 5, 64, 64)
        model(x)

        model = FCN(in_channels=3, classes=4, num_filters=10)
        match = "to have 3 channels, but got 5 channels instead"
        with pytest.raises(RuntimeError, match=match):
            model(x)

    def test_classes(self) -> None:
        model = FCN(in_channels=5, classes=4, num_filters=10)
        x = torch.randn(2, 5, 64, 64)
        y = model(x)

        assert y.shape[1] == 4
        assert model.last.out_channels == 4

    def test_model_size(self) -> None:
        model = FCN(in_channels=5, classes=4, num_filters=10)

        assert len(model.backbone) == 10

    def test_model_filters(self) -> None:
        model = FCN(in_channels=5, classes=4, num_filters=10)

        conv_layers = [
            model.backbone[0],
            model.backbone[2],
            model.backbone[4],
            model.backbone[6],
            model.backbone[8],
        ]
        for conv_layer in conv_layers:
            assert conv_layer.out_channels == 10
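The index-based checks above work because the FCN backbone (defined later in this diff) is an `nn.Sequential` that alternates Conv2d and LeakyReLU layers, so the convolutions sit at even indices 0, 2, 4, 6, 8. A quick way to see the layout:

```python
from torchgeo.models import FCN

model = FCN(in_channels=5, classes=4, num_filters=10)
# Prints: ['Conv2d', 'LeakyReLU', 'Conv2d', 'LeakyReLU', 'Conv2d',
#          'LeakyReLU', 'Conv2d', 'LeakyReLU', 'Conv2d', 'LeakyReLU']
print([type(layer).__name__ for layer in model.backbone])
```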
Tests for the cyclone regression task (new file):

@@ -0,0 +1,27 @@
from typing import Any, Dict, cast

import pytest
from omegaconf import OmegaConf
from torchvision import models

from torchgeo.trainers import CycloneSimpleRegressionTask


class TestCycloneTrainer:
    @pytest.fixture
    def default_config(self) -> Dict[str, Any]:
        task_conf = OmegaConf.load("conf/task_defaults/cyclone.yaml")
        task_args = OmegaConf.to_object(task_conf.task)
        task_args = cast(Dict[str, Any], task_args)
        return task_args

    def test_resnet18(self, default_config: Dict[str, Any]) -> None:
        default_config["model"] = "resnet18"
        task = CycloneSimpleRegressionTask(**default_config)
        assert isinstance(task.model, models.resnet.ResNet)

    def test_invalid_model(self, default_config: Dict[str, Any]) -> None:
        default_config["model"] = "invalid_model"
        error_message = "Model type 'invalid_model' is not valid."
        with pytest.raises(ValueError, match=error_message):
            CycloneSimpleRegressionTask(**default_config)
Tests for the LandCover.ai segmentation task (new file):

@@ -0,0 +1,58 @@
from typing import Any, Dict, cast

import pytest
import segmentation_models_pytorch as smp
import torch.nn as nn
import torch.optim
from omegaconf import OmegaConf

import torchgeo.models
from torchgeo.trainers import LandcoverAISegmentationTask


class TestLandCoverAITrainer:
    @pytest.fixture
    def default_config(self) -> Dict[str, Any]:
        task_conf = OmegaConf.load("conf/task_defaults/landcoverai.yaml")
        task_args = OmegaConf.to_object(task_conf.task)
        task_args = cast(Dict[str, Any], task_args)
        return task_args

    def test_unet_ce_adamw(self, default_config: Dict[str, Any]) -> None:
        default_config["segmentation_model"] = "unet"
        default_config["loss"] = "ce"
        default_config["optimizer"] = "adamw"
        task = LandcoverAISegmentationTask(**default_config)
        optimizer_dict = task.configure_optimizers()
        assert isinstance(task.model, smp.Unet)
        assert isinstance(task.loss, nn.CrossEntropyLoss)  # type: ignore[attr-defined]
        assert isinstance(optimizer_dict["optimizer"], torch.optim.AdamW)

    def test_fcn_jaccard_sgd(self, default_config: Dict[str, Any]) -> None:
        default_config["segmentation_model"] = "fcn"
        default_config["loss"] = "jaccard"
        default_config["optimizer"] = "sgd"
        task = LandcoverAISegmentationTask(**default_config)
        optimizer_dict = task.configure_optimizers()
        assert isinstance(task.model, torchgeo.models.FCN)
        assert isinstance(task.loss, smp.losses.JaccardLoss)
        assert isinstance(optimizer_dict["optimizer"], torch.optim.SGD)

    def test_invalid_model(self, default_config: Dict[str, Any]) -> None:
        default_config["segmentation_model"] = "invalid_model"
        error_message = "Model type 'invalid_model' is not valid."
        with pytest.raises(ValueError, match=error_message):
            LandcoverAISegmentationTask(**default_config)

    def test_invalid_loss(self, default_config: Dict[str, Any]) -> None:
        default_config["loss"] = "invalid_loss"
        error_message = "Loss type 'invalid_loss' is not valid."
        with pytest.raises(ValueError, match=error_message):
            LandcoverAISegmentationTask(**default_config)

    def test_invalid_optimizer(self, default_config: Dict[str, Any]) -> None:
        default_config["optimizer"] = "invalid_optimizer"
        error_message = "Optimizer choice 'invalid_optimizer' is not valid."
        task = LandcoverAISegmentationTask(**default_config)
        with pytest.raises(ValueError, match=error_message):
            task.configure_optimizers()
Tests for the SEN12MS segmentation task:

@@ -1,8 +1,48 @@
 import os
+from typing import Any, Dict, cast

 import pytest
+import segmentation_models_pytorch as smp
+import torch.nn as nn
+from omegaconf import OmegaConf

-from torchgeo.trainers import SEN12MSDataModule
+from torchgeo.trainers import SEN12MSDataModule, SEN12MSSegmentationTask
+
+
+class TestSEN12MSTrainer:
+    @pytest.fixture
+    def default_config(self) -> Dict[str, Any]:
+        task_conf = OmegaConf.load("conf/task_defaults/sen12ms.yaml")
+        task_args = OmegaConf.to_object(task_conf.task)
+        task_args = cast(Dict[str, Any], task_args)
+        return task_args
+
+    def test_unet_ce(self, default_config: Dict[str, Any]) -> None:
+        default_config["segmentation_model"] = "unet"
+        default_config["loss"] = "ce"
+        default_config["optimizer"] = "adamw"
+        task = SEN12MSSegmentationTask(**default_config)
+        assert isinstance(task.model, smp.Unet)
+        assert isinstance(task.loss, nn.CrossEntropyLoss)  # type: ignore[attr-defined]
+
+    def test_unet_jaccard(self, default_config: Dict[str, Any]) -> None:
+        default_config["segmentation_model"] = "unet"
+        default_config["loss"] = "jaccard"
+        task = SEN12MSSegmentationTask(**default_config)
+        assert isinstance(task.model, smp.Unet)
+        assert isinstance(task.loss, smp.losses.JaccardLoss)
+
+    def test_invalid_model(self, default_config: Dict[str, Any]) -> None:
+        default_config["segmentation_model"] = "invalid_model"
+        error_message = "Model type 'invalid_model' is not valid."
+        with pytest.raises(ValueError, match=error_message):
+            SEN12MSSegmentationTask(**default_config)
+
+    def test_invalid_loss(self, default_config: Dict[str, Any]) -> None:
+        default_config["loss"] = "invalid_loss"
+        error_message = "Loss type 'invalid_loss' is not valid."
+        with pytest.raises(ValueError, match=error_message):
+            SEN12MSSegmentationTask(**default_config)


 @pytest.mark.parametrize("band_set", ["all", "s1", "s2-all", "s2-reduced"])
torchgeo/models/__init__.py (new file):

@@ -0,0 +1,9 @@
"""TorchGeo models."""

from .fcn import FCN

__all__ = ("FCN",)

# https://stackoverflow.com/questions/40018681
for module in __all__:
    globals()[module].__module__ = "torchgeo.models"
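The `__module__` reassignment makes re-exported classes report the package, not the internal submodule, as their home, which is what Sphinx uses when resolving cross-references (the linked Stack Overflow thread covers this trick). A quick check of the effect:

```python
from torchgeo.models import FCN

# Without the loop above this would be "torchgeo.models.fcn"; the
# reassignment makes the class appear where users actually import it from.
assert FCN.__module__ == "torchgeo.models"
```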
torchgeo/models/fcn.py (new file):

@@ -0,0 +1,62 @@
"""Simple fully convolutional neural network (FCN) implementations."""

import torch.nn as nn
from torch import Tensor
from torch.nn.modules import Module

# https://github.com/pytorch/pytorch/issues/60979
# https://github.com/pytorch/pytorch/pull/61045
Module.__module__ = "torch.nn"


class FCN(Module):
    """A simple 5-layer FCN with leaky ReLUs and 'same' padding."""

    def __init__(self, in_channels: int, classes: int, num_filters: int = 64) -> None:
        """Initializes the 5-layer FCN model.

        Args:
            in_channels: Number of input channels that the model will expect
            classes: Number of output classes (filters in the final layer)
            num_filters: Number of filters in each convolutional layer
        """
        super(FCN, self).__init__()  # type: ignore[no-untyped-call]

        conv1 = nn.modules.Conv2d(
            in_channels, num_filters, kernel_size=3, stride=1, padding=1
        )
        conv2 = nn.modules.Conv2d(
            num_filters, num_filters, kernel_size=3, stride=1, padding=1
        )
        conv3 = nn.modules.Conv2d(
            num_filters, num_filters, kernel_size=3, stride=1, padding=1
        )
        conv4 = nn.modules.Conv2d(
            num_filters, num_filters, kernel_size=3, stride=1, padding=1
        )
        conv5 = nn.modules.Conv2d(
            num_filters, num_filters, kernel_size=3, stride=1, padding=1
        )

        self.backbone = nn.modules.Sequential(
            conv1,
            nn.modules.LeakyReLU(inplace=True),
            conv2,
            nn.modules.LeakyReLU(inplace=True),
            conv3,
            nn.modules.LeakyReLU(inplace=True),
            conv4,
            nn.modules.LeakyReLU(inplace=True),
            conv5,
            nn.modules.LeakyReLU(inplace=True),
        )

        self.last = nn.modules.Conv2d(
            num_filters, classes, kernel_size=1, stride=1, padding=0
        )

    def forward(self, x: Tensor) -> Tensor:
        """Forward pass of the model."""
        x = self.backbone(x)
        x = self.last(x)
        return x
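Because every convolution uses a 3x3 kernel with padding 1 ('same' padding) and the head is a 1x1 convolution, the output keeps the input's spatial size; only the channel dimension changes. A minimal check, using the same settings as the tests above:

```python
import torch

from torchgeo.models import FCN

model = FCN(in_channels=5, classes=4, num_filters=10)
x = torch.randn(2, 5, 64, 64)
y = model(x)
assert y.shape == (2, 4, 64, 64)  # spatial dims preserved, channels -> classes
```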
Changes to the cyclone trainer module:

@@ -10,6 +10,7 @@ from torch import Tensor
 from torch.nn.modules import Module
 from torch.optim.lr_scheduler import ReduceLROnPlateau
 from torch.utils.data import DataLoader, Subset
+from torchvision import models

 from ..datasets import TropicalCycloneWindEstimation
@@ -25,15 +26,23 @@ class CycloneSimpleRegressionTask(pl.LightningModule):
     This does not take into account other per-sample features available in this dataset.
     """

-    def __init__(self, model: Module, **kwargs: Dict[str, Any]) -> None:
+    def config_task(self, kwargs: Dict[str, Any]) -> None:
+        """Configures the task based on kwargs parameters."""
+        if kwargs["model"] == "resnet18":
+            self.model = models.resnet18(pretrained=False, num_classes=1)
+        else:
+            raise ValueError(f"Model type '{kwargs['model']}' is not valid.")
+
+    def __init__(self, **kwargs: Dict[str, Any]) -> None:
         """Initialize a new LightningModule for training simple regression models.

-        Args:
-            model: A model (specifically, a ``nn.Module``) instance to be trained.
+        Keyword Args:
+            model: Name of the model to use.
         """
         super().__init__()
         self.save_hyperparameters()  # creates `self.hparams` from kwargs
-        self.model = model
+
+        self.config_task(kwargs)

     def forward(self, x: Tensor) -> Any:  # type: ignore[override]
         """Forward pass of the model."""
@@ -100,6 +109,7 @@ class CycloneSimpleRegressionTask(pl.LightningModule):
                    patience=self.hparams["learning_rate_schedule_patience"],
                ),
                "monitor": "val_loss",
                "verbose": True,
            },
        }
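With this refactor the task constructs its own model, so callers pass only config values. A minimal sketch, using the defaults from conf/task_defaults/cyclone.yaml:

```python
from torchgeo.trainers import CycloneSimpleRegressionTask

# Keyword args mirror conf/task_defaults/cyclone.yaml; the task builds the
# ResNet-18 regression head itself inside config_task().
task = CycloneSimpleRegressionTask(
    model="resnet18",
    learning_rate=1e-3,
    learning_rate_schedule_patience=2,
)
```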
Changes to the LandCover.ai trainer module:

@@ -1,10 +1,11 @@
 """Landcover.ai trainer."""

-from typing import Any, Dict, cast
+from typing import Any, Dict, Optional, cast

 import matplotlib.pyplot as plt
 import numpy as np
 import pytorch_lightning as pl
 import segmentation_models_pytorch as smp
 import torch
 import torch.nn as nn
 from torch import Tensor
@@ -15,6 +16,7 @@ from torch.utils.tensorboard import SummaryWriter  # type: ignore[attr-defined]
 from torchmetrics import Accuracy, IoU  # type: ignore[attr-defined]

 from ..datasets import LandCoverAI
+from ..models import FCN

 # https://github.com/pytorch/pytorch/issues/60979
 # https://github.com/pytorch/pytorch/pull/61045
@@ -29,23 +31,55 @@ class LandcoverAISegmentationTask(pl.LightningModule):
     ``pytorch_segmentation_models`` package.
     """

+    def config_task(self, kwargs: Dict[str, Any]) -> None:
+        """Configures the task based on kwargs parameters."""
+        if kwargs["segmentation_model"] == "unet":
+            self.model = smp.Unet(
+                encoder_name=kwargs["encoder_name"],
+                encoder_weights=kwargs["encoder_weights"],
+                in_channels=3,
+                classes=5,
+            )
+        elif kwargs["segmentation_model"] == "deeplabv3+":
+            self.model = smp.DeepLabV3Plus(
+                encoder_name=kwargs["encoder_name"],
+                encoder_weights=kwargs["encoder_weights"],
+                encoder_output_stride=kwargs["encoder_output_stride"],
+                in_channels=3,
+                classes=5,
+            )
+        elif kwargs["segmentation_model"] == "fcn":
+            self.model = FCN(3, 5, 64)
+        else:
+            raise ValueError(
+                f"Model type '{kwargs['segmentation_model']}' is not valid."
+            )
+
+        if kwargs["loss"] == "ce":
+            self.loss = nn.CrossEntropyLoss()  # type: ignore[attr-defined]
+        elif kwargs["loss"] == "jaccard":
+            self.loss = smp.losses.JaccardLoss(mode="multiclass")
+        else:
+            raise ValueError(f"Loss type '{kwargs['loss']}' is not valid.")
+
     def __init__(
         self,
-        model: Module,
-        loss: Module = nn.CrossEntropyLoss(),  # type: ignore[attr-defined]
         **kwargs: Dict[str, Any],
     ) -> None:
         """Initialize the LightningModule with a model and loss function.

-        Args:
-            model: A model (specifically, a ``nn.Module``) instance to be trained.
-            loss: A semantic segmentation loss function to use (e.g. pixel-wise
-                crossentropy)
+        Keyword Args:
+            segmentation_model: Name of the segmentation model type to use
+            encoder_name: Name of the encoder model backbone to use
+            encoder_weights: None or "imagenet" to use imagenet pretrained weights in
+                the encoder model
+            encoder_output_stride: The output stride parameter in DeepLabV3+ models
+            loss: Name of the loss function
         """
         super().__init__()
         self.save_hyperparameters()  # creates `self.hparams` from kwargs
-        self.model = model
-        self.loss = loss
+
+        self.config_task(kwargs)

         self.train_accuracy = Accuracy()
         self.val_accuracy = Accuracy()
@@ -124,6 +158,8 @@ class LandcoverAISegmentationTask(pl.LightningModule):
                     f"image/{batch_idx}", fig, global_step=self.global_step
                 )

+                plt.close()
+
     def validation_epoch_end(self, outputs: Any) -> None:
         """Logs epoch level validation metrics."""
         self.log("val_acc", self.val_accuracy.compute())
@@ -156,10 +192,24 @@ class LandcoverAISegmentationTask(pl.LightningModule):

     def configure_optimizers(self) -> Dict[str, Any]:
         """Initialize the optimizer and learning rate scheduler."""
-        optimizer = torch.optim.Adam(
-            self.model.parameters(),
-            lr=self.hparams["learning_rate"],
-        )
+        optimizer: torch.optim.optimizer.Optimizer
+        if self.hparams["optimizer"] == "adamw":
+            optimizer = torch.optim.AdamW(
+                self.model.parameters(),
+                lr=self.hparams["learning_rate"],
+            )
+        elif self.hparams["optimizer"] == "sgd":
+            optimizer = torch.optim.SGD(
+                self.model.parameters(),
+                lr=self.hparams["learning_rate"],
+                momentum=0.9,
+                weight_decay=1e-2,
+            )
+        else:
+            raise ValueError(
+                f"Optimizer choice '{self.hparams['optimizer']}' is not valid."
+            )

         return {
             "optimizer": optimizer,
             "lr_scheduler": {
@@ -168,6 +218,7 @@ class LandcoverAISegmentationTask(pl.LightningModule):
                    patience=self.hparams["learning_rate_schedule_patience"],
                ),
                "monitor": "val_loss",
                "verbose": True,
            },
        }
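The new tests exercise this optimizer selection end to end. A minimal sketch, with config values taken from conf/task_defaults/landcoverai.yaml but switched to SGD:

```python
import torch

from torchgeo.trainers import LandcoverAISegmentationTask

# Build a task configured for SGD and confirm configure_optimizers honors it.
task = LandcoverAISegmentationTask(
    segmentation_model="unet",
    encoder_name="resnet34",
    encoder_weights="imagenet",
    encoder_output_stride=16,
    loss="jaccard",
    optimizer="sgd",
    learning_rate=1e-3,
    learning_rate_schedule_patience=2,
)
optimizer_dict = task.configure_optimizers()
assert isinstance(optimizer_dict["optimizer"], torch.optim.SGD)
```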
@@ -206,7 +257,21 @@ class LandcoverAIDataModule(pl.LightningDataModule):
         return sample

     def prepare_data(self) -> None:
-        """Initialize the main ``Dataset`` objects."""
+        """Make sure that the dataset is downloaded.
+
+        This method is only called once per run.
+        """
+        _ = LandCoverAI(
+            self.root_dir,
+            download=True,
+            checksum=False,
+        )
+
+    def setup(self, stage: Optional[str] = None) -> None:
+        """Initialize the main ``Dataset`` objects.
+
+        This method is called once per GPU per run.
+        """
         self.train_dataset = LandCoverAI(
             self.root_dir,
             split="train",
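This split is what makes multi-GPU training possible: Lightning calls prepare_data once on a single process (the only safe place for downloads), while setup runs in every GPU process. A sketch of the pattern with a hypothetical datamodule:

```python
from typing import Optional

import pytorch_lightning as pl
import torch
from torch.utils.data import Dataset, TensorDataset


class ExampleDataModule(pl.LightningDataModule):
    """Hypothetical datamodule illustrating the download/setup split."""

    def prepare_data(self) -> None:
        # Called once per run on a single process: download or unpack data
        # here, since multiple GPU processes would otherwise race.
        pass  # e.g. ExampleDataset(root_dir, download=True)

    def setup(self, stage: Optional[str] = None) -> None:
        # Called in every GPU process: construct the Dataset objects here so
        # each process holds its own handle to the already-downloaded data.
        self.train_dataset: Dataset = TensorDataset(torch.zeros(8, 3, 4, 4))
```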
Changes to the SEN12MS trainer module:

@@ -3,6 +3,7 @@
 from typing import Any, Dict, Optional, cast

 import pytorch_lightning as pl
+import segmentation_models_pytorch as smp
 import torch
 import torch.nn as nn
 from sklearn.model_selection import GroupShuffleSplit
@@ -27,23 +28,44 @@ class SEN12MSSegmentationTask(pl.LightningModule):
     ``pytorch_segmentation_models`` package.
     """

+    def config_task(self, kwargs: Dict[str, Any]) -> None:
+        """Configures the task based on kwargs parameters."""
+        if kwargs["segmentation_model"] == "unet":
+            self.model = smp.Unet(
+                encoder_name=kwargs["encoder_name"],
+                encoder_weights=kwargs["encoder_weights"],
+                in_channels=15,  # TODO: set number of input channels based on task
+                classes=11,
+            )
+        else:
+            raise ValueError(
+                f"Model type '{kwargs['segmentation_model']}' is not valid."
+            )
+
+        if kwargs["loss"] == "ce":
+            self.loss = nn.CrossEntropyLoss()  # type: ignore[attr-defined]
+        elif kwargs["loss"] == "jaccard":
+            self.loss = smp.losses.JaccardLoss(mode="multiclass")
+        else:
+            raise ValueError(f"Loss type '{kwargs['loss']}' is not valid.")
+
     def __init__(
         self,
-        model: Module,
-        loss: Module = nn.CrossEntropyLoss(),  # type: ignore[attr-defined]
         **kwargs: Dict[str, Any],
     ) -> None:
         """Initialize the LightningModule with a model and loss function.

-        Args:
-            model: A model (specifically, a ``nn.Module``) instance to be trained.
-            loss: A semantic segmentation loss function to use (e.g. pixel-wise
-                crossentropy)
+        Keyword Args:
+            segmentation_model: Name of the segmentation model type to use
+            encoder_name: Name of the encoder model backbone to use
+            encoder_weights: None or "imagenet" to use imagenet pretrained weights in
+                the encoder model
+            loss: Name of the loss function
         """
         super().__init__()
         self.save_hyperparameters()  # creates `self.hparams` from kwargs
-        self.model = model
-        self.loss = loss
+
+        self.config_task(kwargs)

         self.train_accuracy = Accuracy()
         self.val_accuracy = Accuracy()
@@ -122,6 +144,7 @@ class SEN12MSSegmentationTask(pl.LightningModule):
                    patience=self.hparams["learning_rate_schedule_patience"],
                ),
                "monitor": "val_loss",
                "verbose": True,
            },
        }
train.py (27 lines changed):
@@ -6,13 +6,11 @@ import os
 from typing import Any, Dict, cast

 import pytorch_lightning as pl
-import torch.nn as nn
 from omegaconf import DictConfig, OmegaConf
 from pytorch_lightning import loggers as pl_loggers
 from pytorch_lightning.callbacks import EarlyStopping, ModelCheckpoint
 from pytorch_lightning.core.datamodule import LightningDataModule
 from pytorch_lightning.core.lightning import LightningModule
-from torchvision import models

 from torchgeo.trainers import (
     CycloneDataModule,
@@ -124,41 +122,22 @@ def main(conf: DictConfig) -> None:
             batch_size=conf.program.batch_size,
             num_workers=conf.program.num_workers,
         )
-        model = models.resnet18(pretrained=False, num_classes=1)
-        task = CycloneSimpleRegressionTask(model, **task_args)
+        task = CycloneSimpleRegressionTask(**task_args)
     elif conf.task.name == "landcoverai":
-        import segmentation_models_pytorch as smp
-
         datamodule = LandcoverAIDataModule(
             conf.program.data_dir,
             batch_size=conf.program.batch_size,
             num_workers=conf.program.num_workers,
         )
-        model = smp.Unet(
-            encoder_name="resnet18",
-            encoder_weights=None,
-            in_channels=3,
-            classes=5,
-        )
-        loss = nn.CrossEntropyLoss()  # type: ignore[attr-defined]
-        task = LandcoverAISegmentationTask(model, loss, **task_args)
+        task = LandcoverAISegmentationTask(**task_args)
     elif conf.task.name == "sen12ms":
-        import segmentation_models_pytorch as smp
-
         datamodule = SEN12MSDataModule(
             conf.program.data_dir,
             seed=conf.program.seed,
             batch_size=conf.program.batch_size,
             num_workers=conf.program.num_workers,
         )
-        model = smp.Unet(
-            encoder_name="resnet18",
-            encoder_weights=None,
-            in_channels=15,
-            classes=11,
-        )
-        loss = nn.CrossEntropyLoss()  # type: ignore[attr-defined]
-        task = SEN12MSSegmentationTask(model, loss, **task_args)
+        task = SEN12MSSegmentationTask(**task_args)
     else:
         raise ValueError(
             f"task.name={conf.task.name} is not recognized as a valid task"