* Moving task specific configuration logic from train.py into respective classes

* Small fixes

* Adding basic FCN model for benchmarking

* Adding simple FCN model

* Removing OrderedDict from model definitions

* Adding torchgeo.models to docs

* Adding model tests

* Making all the formatters happy

* Adding optimizer options to landcoverai

* Fixing conda environment I think

* How do you feel about a Makefile, Adam?

* Formatting

* Adding some documentation to the readme

* Sanity check command in README

* Fixes in the landcoverai datamodule to make multi-GPU training possible

* Closing figures that we send to Tensorboard

* Fix sphinx missing target warning

* Fix pytest coverage

* Fix flake8

* Update torchgeo/models/__init__.py

Co-authored-by: Adam J. Stewart <ajstewart426@gmail.com>

Co-authored-by: Adam J. Stewart <ajstewart426@gmail.com>
This commit is contained in:
Caleb Robinson 2021-07-30 18:58:49 -07:00 коммит произвёл GitHub
Родитель 14054a15c0
Коммит a54359dcde
Не найден ключ, соответствующий данной подписи
Идентификатор ключа GPG: 4AEE18F83AFDEB23
19 изменённых файлов: 444 добавлений и 59 удалений

12
Makefile Normal file
Просмотреть файл

@ -0,0 +1,12 @@
.PHONY: tests docs
tests:
black --check .
isort . --check --diff
flake8 .
pydocstyle .
mypy .
pytest --cov=. --cov-report=term-missing
docs:
$(MAKE) -C docs html

Просмотреть файл

@ -11,8 +11,25 @@ Datasets, transforms, and models for geospatial data.
### Conda
```bash
conda config --set channel_priority false
conda env create --file environment.yml
conda activate torchgeo
# verify that the PyTorch can use the GPU
python -c "import torch; print(torch.cuda.is_available())"
```
## Example training run
```bash
# run the training script with a config file
python train.py config_file=conf/landcoverai.yaml
```
## Developing
```
make tests
```
## Datasets

Просмотреть файл

@ -1,4 +1,5 @@
task:
name: "sen12ms"
name: "cyclone"
learning_rate: 1e-3
learning_rate_schedule_patience: 2
model: "resnet18"

Просмотреть файл

@ -1,4 +1,10 @@
task:
name: "landcoverai"
optimizer: "adamw"
learning_rate: 1e-3
learning_rate_schedule_patience: 2
loss: "ce"
segmentation_model: "deeplabv3+"
encoder_name: "resnet34"
encoder_weights: "imagenet"
encoder_output_stride: 16

Просмотреть файл

@ -2,3 +2,7 @@ task:
name: "sen12ms"
learning_rate: 1e-3
learning_rate_schedule_patience: 2
loss: "ce"
segmentation_model: "unet"
encoder_name: "resnet18"
encoder_weights: "imagenet"

Просмотреть файл

@ -14,6 +14,7 @@ architectures, and common image transformations for geospatial data.
:caption: Package Reference
datasets
models
samplers
trainers
transforms

4
docs/models.rst Normal file
Просмотреть файл

@ -0,0 +1,4 @@
torchgeo.models
=================
.. automodule:: torchgeo.models

Просмотреть файл

@ -1,12 +1,37 @@
name: torchgeo
channels:
- pytorch
- conda-forge
- anaconda
dependencies:
- cudatoolkit=11.1
- cudatoolkit
- h5py
- matplotlib
- numpy
- pip
- conda-forge:pycocotools
- python>=3.7
- pytorch::pytorch>=1.7
- conda-forge::rarfile
- pytorch::torchvision>=0.3
- pycocotools
- python
- pytorch>=1.8.1
- rarfile
- rasterio>=1.0
- torchvision>=0.9.1
- pip:
- affine
- black[colorama]>=21b
- flake8
- isort[colors]>=4.3.5
- mypy>=0.900
- omegaconf
- opencv-python
- pillow
- pydocstyle[toml]>=6.1
- pytest>=6.0
- pytest-cov
- pytorch-lightning
- git+https://github.com/pytorch/pytorch_sphinx_theme
- radiant-mlhub>=0.2.1
- rtree>=0.5.0
- scikit-learn
- segmentation-models-pytorch
- setuptools>=42
- sphinx
- torchmetrics

0
tests/models/__init__.py Normal file
Просмотреть файл

42
tests/models/test_fcn.py Normal file
Просмотреть файл

@ -0,0 +1,42 @@
import pytest
import torch
from torchgeo.models import FCN
class TestFCN:
def test_in_channels(self) -> None:
model = FCN(in_channels=5, classes=4, num_filters=10)
x = torch.randn(2, 5, 64, 64)
model(x)
model = FCN(in_channels=3, classes=4, num_filters=10)
match = "to have 3 channels, but got 5 channels instead"
with pytest.raises(RuntimeError, match=match):
model(x)
def test_classes(self) -> None:
model = FCN(in_channels=5, classes=4, num_filters=10)
x = torch.randn(2, 5, 64, 64)
y = model(x)
assert y.shape[1] == 4
assert model.last.out_channels == 4
def test_model_size(self) -> None:
model = FCN(in_channels=5, classes=4, num_filters=10)
assert len(model.backbone) == 10
def test_model_filters(self) -> None:
model = FCN(in_channels=5, classes=4, num_filters=10)
conv_layers = [
model.backbone[0],
model.backbone[2],
model.backbone[4],
model.backbone[6],
model.backbone[8],
]
for conv_layer in conv_layers:
assert conv_layer.out_channels == 10

Просмотреть файл

@ -0,0 +1,27 @@
from typing import Any, Dict, cast
import pytest
from omegaconf import OmegaConf
from torchvision import models
from torchgeo.trainers import CycloneSimpleRegressionTask
class TestCycloneTrainer:
@pytest.fixture
def default_config(self) -> Dict[str, Any]:
task_conf = OmegaConf.load("conf/task_defaults/cyclone.yaml")
task_args = OmegaConf.to_object(task_conf.task)
task_args = cast(Dict[str, Any], task_args)
return task_args
def test_resnet18(self, default_config: Dict[str, Any]) -> None:
default_config["model"] = "resnet18"
task = CycloneSimpleRegressionTask(**default_config)
assert isinstance(task.model, models.resnet.ResNet)
def test_invalid_model(self, default_config: Dict[str, Any]) -> None:
default_config["model"] = "invalid_model"
error_message = "Model type 'invalid_model' is not valid."
with pytest.raises(ValueError, match=error_message):
CycloneSimpleRegressionTask(**default_config)

Просмотреть файл

@ -0,0 +1,58 @@
from typing import Any, Dict, cast
import pytest
import segmentation_models_pytorch as smp
import torch.nn as nn
import torch.optim
from omegaconf import OmegaConf
import torchgeo.models
from torchgeo.trainers import LandcoverAISegmentationTask
class TestLandCoverAITrainer:
@pytest.fixture
def default_config(self) -> Dict[str, Any]:
task_conf = OmegaConf.load("conf/task_defaults/landcoverai.yaml")
task_args = OmegaConf.to_object(task_conf.task)
task_args = cast(Dict[str, Any], task_args)
return task_args
def test_unet_ce_adamw(self, default_config: Dict[str, Any]) -> None:
default_config["segmentation_model"] = "unet"
default_config["loss"] = "ce"
default_config["optimizer"] = "adamw"
task = LandcoverAISegmentationTask(**default_config)
optimizer_dict = task.configure_optimizers()
assert isinstance(task.model, smp.Unet)
assert isinstance(task.loss, nn.CrossEntropyLoss) # type: ignore[attr-defined]
assert isinstance(optimizer_dict["optimizer"], torch.optim.AdamW)
def test_fcn_jaccard_sgd(self, default_config: Dict[str, Any]) -> None:
default_config["segmentation_model"] = "fcn"
default_config["loss"] = "jaccard"
default_config["optimizer"] = "sgd"
task = LandcoverAISegmentationTask(**default_config)
optimizer_dict = task.configure_optimizers()
assert isinstance(task.model, torchgeo.models.FCN)
assert isinstance(task.loss, smp.losses.JaccardLoss)
assert isinstance(optimizer_dict["optimizer"], torch.optim.SGD)
def test_invalid_model(self, default_config: Dict[str, Any]) -> None:
default_config["segmentation_model"] = "invalid_model"
error_message = "Model type 'invalid_model' is not valid."
with pytest.raises(ValueError, match=error_message):
LandcoverAISegmentationTask(**default_config)
def test_invalid_loss(self, default_config: Dict[str, Any]) -> None:
default_config["loss"] = "invalid_loss"
error_message = "Loss type 'invalid_loss' is not valid."
with pytest.raises(ValueError, match=error_message):
LandcoverAISegmentationTask(**default_config)
def test_invalid_optimizer(self, default_config: Dict[str, Any]) -> None:
default_config["optimizer"] = "invalid_optimizer"
error_message = "Optimizer choice 'invalid_optimizer' is not valid."
task = LandcoverAISegmentationTask(**default_config)
with pytest.raises(ValueError, match=error_message):
task.configure_optimizers()

Просмотреть файл

@ -1,8 +1,48 @@
import os
from typing import Any, Dict, cast
import pytest
import segmentation_models_pytorch as smp
import torch.nn as nn
from omegaconf import OmegaConf
from torchgeo.trainers import SEN12MSDataModule
from torchgeo.trainers import SEN12MSDataModule, SEN12MSSegmentationTask
class TestSEN12MSTrainer:
@pytest.fixture
def default_config(self) -> Dict[str, Any]:
task_conf = OmegaConf.load("conf/task_defaults/sen12ms.yaml")
task_args = OmegaConf.to_object(task_conf.task)
task_args = cast(Dict[str, Any], task_args)
return task_args
def test_unet_ce(self, default_config: Dict[str, Any]) -> None:
default_config["segmentation_model"] = "unet"
default_config["loss"] = "ce"
default_config["optimizer"] = "adamw"
task = SEN12MSSegmentationTask(**default_config)
assert isinstance(task.model, smp.Unet)
assert isinstance(task.loss, nn.CrossEntropyLoss) # type: ignore[attr-defined]
def test_unet_jaccard(self, default_config: Dict[str, Any]) -> None:
default_config["segmentation_model"] = "unet"
default_config["loss"] = "jaccard"
task = SEN12MSSegmentationTask(**default_config)
assert isinstance(task.model, smp.Unet)
assert isinstance(task.loss, smp.losses.JaccardLoss)
def test_invalid_model(self, default_config: Dict[str, Any]) -> None:
default_config["segmentation_model"] = "invalid_model"
error_message = "Model type 'invalid_model' is not valid."
with pytest.raises(ValueError, match=error_message):
SEN12MSSegmentationTask(**default_config)
def test_invalid_loss(self, default_config: Dict[str, Any]) -> None:
default_config["loss"] = "invalid_loss"
error_message = "Loss type 'invalid_loss' is not valid."
with pytest.raises(ValueError, match=error_message):
SEN12MSSegmentationTask(**default_config)
@pytest.mark.parametrize("band_set", ["all", "s1", "s2-all", "s2-reduced"])

Просмотреть файл

@ -0,0 +1,9 @@
"""TorchGeo models."""
from .fcn import FCN
__all__ = ("FCN",)
# https://stackoverflow.com/questions/40018681
for module in __all__:
globals()[module].__module__ = "torchgeo.models"

62
torchgeo/models/fcn.py Normal file
Просмотреть файл

@ -0,0 +1,62 @@
"""Simple fully convolutional neural network (FCN) implementations."""
import torch.nn as nn
from torch import Tensor
from torch.nn.modules import Module
# https://github.com/pytorch/pytorch/issues/60979
# https://github.com/pytorch/pytorch/pull/61045
Module.__module__ = "torch.nn"
class FCN(Module):
"""A simple 5 layer FCN with leaky relus and 'same' padding."""
def __init__(self, in_channels: int, classes: int, num_filters: int = 64) -> None:
"""Initializes the 5 layer FCN model.
Args:
in_channels: Number of input channels that the model will expect
classes: Number of filters in the final layer
num_filters: Number of filters in each convolutional layer
"""
super(FCN, self).__init__() # type: ignore[no-untyped-call]
conv1 = nn.modules.Conv2d(
in_channels, num_filters, kernel_size=3, stride=1, padding=1
)
conv2 = nn.modules.Conv2d(
num_filters, num_filters, kernel_size=3, stride=1, padding=1
)
conv3 = nn.modules.Conv2d(
num_filters, num_filters, kernel_size=3, stride=1, padding=1
)
conv4 = nn.modules.Conv2d(
num_filters, num_filters, kernel_size=3, stride=1, padding=1
)
conv5 = nn.modules.Conv2d(
num_filters, num_filters, kernel_size=3, stride=1, padding=1
)
self.backbone = nn.modules.Sequential(
conv1,
nn.modules.LeakyReLU(inplace=True),
conv2,
nn.modules.LeakyReLU(inplace=True),
conv3,
nn.modules.LeakyReLU(inplace=True),
conv4,
nn.modules.LeakyReLU(inplace=True),
conv5,
nn.modules.LeakyReLU(inplace=True),
)
self.last = nn.modules.Conv2d(
num_filters, classes, kernel_size=1, stride=1, padding=0
)
def forward(self, x: Tensor) -> Tensor:
"""Forward pass of the model."""
x = self.backbone(x)
x = self.last(x)
return x

Просмотреть файл

@ -10,6 +10,7 @@ from torch import Tensor
from torch.nn.modules import Module
from torch.optim.lr_scheduler import ReduceLROnPlateau
from torch.utils.data import DataLoader, Subset
from torchvision import models
from ..datasets import TropicalCycloneWindEstimation
@ -25,15 +26,23 @@ class CycloneSimpleRegressionTask(pl.LightningModule):
This does not take into account other per-sample features available in this dataset.
"""
def __init__(self, model: Module, **kwargs: Dict[str, Any]) -> None:
def config_task(self, kwargs: Dict[str, Any]) -> None:
"""Configures the task based on kwargs parameters."""
if kwargs["model"] == "resnet18":
self.model = models.resnet18(pretrained=False, num_classes=1)
else:
raise ValueError(f"Model type '{kwargs['model']}' is not valid.")
def __init__(self, **kwargs: Dict[str, Any]) -> None:
"""Initialize a new LightningModule for training simple regression models.
Args:
model: A model (specifically, a ``nn.Module``) instance to be trained.
Keyword Args:
model: Name of the model to use.
"""
super().__init__()
self.save_hyperparameters() # creates `self.hparams` from kwargs
self.model = model
self.config_task(kwargs)
def forward(self, x: Tensor) -> Any: # type: ignore[override]
"""Forward pass of the model."""
@ -100,6 +109,7 @@ class CycloneSimpleRegressionTask(pl.LightningModule):
patience=self.hparams["learning_rate_schedule_patience"],
),
"monitor": "val_loss",
"verbose": True,
},
}

Просмотреть файл

@ -1,10 +1,11 @@
"""Landcover.ai trainer."""
from typing import Any, Dict, cast
from typing import Any, Dict, Optional, cast
import matplotlib.pyplot as plt
import numpy as np
import pytorch_lightning as pl
import segmentation_models_pytorch as smp
import torch
import torch.nn as nn
from torch import Tensor
@ -15,6 +16,7 @@ from torch.utils.tensorboard import SummaryWriter # type: ignore[attr-defined]
from torchmetrics import Accuracy, IoU # type: ignore[attr-defined]
from ..datasets import LandCoverAI
from ..models import FCN
# https://github.com/pytorch/pytorch/issues/60979
# https://github.com/pytorch/pytorch/pull/61045
@ -29,23 +31,55 @@ class LandcoverAISegmentationTask(pl.LightningModule):
``pytorch_segmentation_models`` package.
"""
def config_task(self, kwargs: Dict[str, Any]) -> None:
"""Configures the task based on kwargs parameters."""
if kwargs["segmentation_model"] == "unet":
self.model = smp.Unet(
encoder_name=kwargs["encoder_name"],
encoder_weights=kwargs["encoder_weights"],
in_channels=3,
classes=5,
)
elif kwargs["segmentation_model"] == "deeplabv3+":
self.model = smp.DeepLabV3Plus(
encoder_name=kwargs["encoder_name"],
encoder_weights=kwargs["encoder_weights"],
encoder_output_stride=kwargs["encoder_output_stride"],
in_channels=3,
classes=5,
)
elif kwargs["segmentation_model"] == "fcn":
self.model = FCN(3, 5, 64)
else:
raise ValueError(
f"Model type '{kwargs['segmentation_model']}' is not valid."
)
if kwargs["loss"] == "ce":
self.loss = nn.CrossEntropyLoss() # type: ignore[attr-defined]
elif kwargs["loss"] == "jaccard":
self.loss = smp.losses.JaccardLoss(mode="multiclass")
else:
raise ValueError(f"Loss type '{kwargs['loss']}' is not valid.")
def __init__(
self,
model: Module,
loss: Module = nn.CrossEntropyLoss(), # type: ignore[attr-defined]
**kwargs: Dict[str, Any],
) -> None:
"""Initialize the LightningModule with a model and loss function.
Args:
model: A model (specifically, a ``nn.Module``) instance to be trained.
loss: A semantic segmentation loss function to use (e.g. pixel-wise
crossentropy)
Keyword Args:
segmentation_model: Name of the segmentation model type to use
encoder_name: Name of the encoder model backbone to use
encoder_weights: None or "imagenet" to use imagenet pretrained weights in
the encoder model
encoder_output_stride: The output stride parameter in DeepLabV3+ models
loss: Name of the loss function
"""
super().__init__()
self.save_hyperparameters() # creates `self.hparams` from kwargs
self.model = model
self.loss = loss
self.config_task(kwargs)
self.train_accuracy = Accuracy()
self.val_accuracy = Accuracy()
@ -124,6 +158,8 @@ class LandcoverAISegmentationTask(pl.LightningModule):
f"image/{batch_idx}", fig, global_step=self.global_step
)
plt.close()
def validation_epoch_end(self, outputs: Any) -> None:
"""Logs epoch level validation metrics."""
self.log("val_acc", self.val_accuracy.compute())
@ -156,10 +192,24 @@ class LandcoverAISegmentationTask(pl.LightningModule):
def configure_optimizers(self) -> Dict[str, Any]:
"""Initialize the optimizer and learning rate scheduler."""
optimizer = torch.optim.Adam(
self.model.parameters(),
lr=self.hparams["learning_rate"],
)
optimizer: torch.optim.optimizer.Optimizer
if self.hparams["optimizer"] == "adamw":
optimizer = torch.optim.AdamW(
self.model.parameters(),
lr=self.hparams["learning_rate"],
)
elif self.hparams["optimizer"] == "sgd":
optimizer = torch.optim.SGD(
self.model.parameters(),
lr=self.hparams["learning_rate"],
momentum=0.9,
weight_decay=1e-2,
)
else:
raise ValueError(
f"Optimizer choice '{self.hparams['optimizer']}' is not valid."
)
return {
"optimizer": optimizer,
"lr_scheduler": {
@ -168,6 +218,7 @@ class LandcoverAISegmentationTask(pl.LightningModule):
patience=self.hparams["learning_rate_schedule_patience"],
),
"monitor": "val_loss",
"verbose": True,
},
}
@ -206,7 +257,21 @@ class LandcoverAIDataModule(pl.LightningDataModule):
return sample
def prepare_data(self) -> None:
"""Initialize the main ``Dataset`` objects."""
"""Make sure that the dataset is downloaded.
This method is only called once per run.
"""
_ = LandCoverAI(
self.root_dir,
download=True,
checksum=False,
)
def setup(self, stage: Optional[str] = None) -> None:
"""Initialize the main ``Dataset`` objects.
This method is called once per GPU per run.
"""
self.train_dataset = LandCoverAI(
self.root_dir,
split="train",

Просмотреть файл

@ -3,6 +3,7 @@
from typing import Any, Dict, Optional, cast
import pytorch_lightning as pl
import segmentation_models_pytorch as smp
import torch
import torch.nn as nn
from sklearn.model_selection import GroupShuffleSplit
@ -27,23 +28,44 @@ class SEN12MSSegmentationTask(pl.LightningModule):
``pytorch_segmentation_models`` package.
"""
def config_task(self, kwargs: Dict[str, Any]) -> None:
"""Configures the task based on kwargs parameters."""
if kwargs["segmentation_model"] == "unet":
self.model = smp.Unet(
encoder_name=kwargs["encoder_name"],
encoder_weights=kwargs["encoder_weights"],
in_channels=15, # TODO: set number of input channels based on task
classes=11,
)
else:
raise ValueError(
f"Model type '{kwargs['segmentation_model']}' is not valid."
)
if kwargs["loss"] == "ce":
self.loss = nn.CrossEntropyLoss() # type: ignore[attr-defined]
elif kwargs["loss"] == "jaccard":
self.loss = smp.losses.JaccardLoss(mode="multiclass")
else:
raise ValueError(f"Loss type '{kwargs['loss']}' is not valid.")
def __init__(
self,
model: Module,
loss: Module = nn.CrossEntropyLoss(), # type: ignore[attr-defined]
**kwargs: Dict[str, Any],
) -> None:
"""Initialize the LightningModule with a model and loss function.
Args:
model: A model (specifically, a ``nn.Module``) instance to be trained.
loss: A semantic segmentation loss function to use (e.g. pixel-wise
crossentropy)
Keyword Args:
segmentation_model: Name of the segmentation model type to use
encoder_name: Name of the encoder model backbone to use
encoder_weights: None or "imagenet" to use imagenet pretrained weights in
the encoder model
loss: Name of the loss function
"""
super().__init__()
self.save_hyperparameters() # creates `self.hparams` from kwargs
self.model = model
self.loss = loss
self.config_task(kwargs)
self.train_accuracy = Accuracy()
self.val_accuracy = Accuracy()
@ -122,6 +144,7 @@ class SEN12MSSegmentationTask(pl.LightningModule):
patience=self.hparams["learning_rate_schedule_patience"],
),
"monitor": "val_loss",
"verbose": True,
},
}

Просмотреть файл

@ -6,13 +6,11 @@ import os
from typing import Any, Dict, cast
import pytorch_lightning as pl
import torch.nn as nn
from omegaconf import DictConfig, OmegaConf
from pytorch_lightning import loggers as pl_loggers
from pytorch_lightning.callbacks import EarlyStopping, ModelCheckpoint
from pytorch_lightning.core.datamodule import LightningDataModule
from pytorch_lightning.core.lightning import LightningModule
from torchvision import models
from torchgeo.trainers import (
CycloneDataModule,
@ -124,41 +122,22 @@ def main(conf: DictConfig) -> None:
batch_size=conf.program.batch_size,
num_workers=conf.program.num_workers,
)
model = models.resnet18(pretrained=False, num_classes=1)
task = CycloneSimpleRegressionTask(model, **task_args)
task = CycloneSimpleRegressionTask(**task_args)
elif conf.task.name == "landcoverai":
import segmentation_models_pytorch as smp
datamodule = LandcoverAIDataModule(
conf.program.data_dir,
batch_size=conf.program.batch_size,
num_workers=conf.program.num_workers,
)
model = smp.Unet(
encoder_name="resnet18",
encoder_weights=None,
in_channels=3,
classes=5,
)
loss = nn.CrossEntropyLoss() # type: ignore[attr-defined]
task = LandcoverAISegmentationTask(model, loss, **task_args)
task = LandcoverAISegmentationTask(**task_args)
elif conf.task.name == "sen12ms":
import segmentation_models_pytorch as smp
datamodule = SEN12MSDataModule(
conf.program.data_dir,
seed=conf.program.seed,
batch_size=conf.program.batch_size,
num_workers=conf.program.num_workers,
)
model = smp.Unet(
encoder_name="resnet18",
encoder_weights=None,
in_channels=15,
classes=11,
)
loss = nn.CrossEntropyLoss() # type: ignore[attr-defined]
task = SEN12MSSegmentationTask(model, loss, **task_args)
task = SEN12MSSegmentationTask(**task_args)
else:
raise ValueError(
f"task.name={conf.task.name} is not recognized as a valid task"