* add RESISC45 trainer

* update working locally

* Adding ability to choose the random split sizes via config

* If there is no val or test split, return the train split by default so the Trainer doesn't break. If you actually want to train without val/test, set the appropriate Trainer args instead.
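
A minimal sketch of the "appropriate Trainer args" mentioned above, using standard PyTorch Lightning flags (illustrative, not part of this diff):

    import pytorch_lightning as pl

    # skip validation entirely instead of relying on the train-split fallback
    trainer = pl.Trainer(limit_val_batches=0, num_sanity_val_steps=0)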

* RESISC experiments

* Reverting accidental changes

* mypy fix

* add dataset_split unit tests

* Document dataset_split

Co-authored-by: Caleb Robinson <calebrob6@gmail.com>
This commit is contained in:
isaac 2021-10-11 17:35:38 -05:00 committed by GitHub
Parent 84696d2b2c
Commit 142835cede
10 changed files: 555 additions and 15 deletions

conf/resisc45.yaml Normal file

@@ -0,0 +1,18 @@
trainer:
  gpus: 1 # single GPU training
  min_epochs: 10
  max_epochs: 40
  benchmark: True
experiment:
  task: "resisc45"
  module:
    loss: "ce"
    classification_model: "resnet18"
    learning_rate: 1e-3
    learning_rate_schedule_patience: 6
  datamodule:
    batch_size: 128
    num_workers: 6
    val_split_pct: 0.2
    test_split_pct: 0.2
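
This config is consumed by train.py via its config_file argument (see the experiment script below). A minimal launch sketch, with a hypothetical data path:

    python train.py config_file=conf/resisc45.yaml program.data_dir=/path/to/resisc45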

conf/task_defaults/resisc45.yaml Normal file

@@ -0,0 +1,14 @@
experiment:
  task: "resisc45"
  module:
    loss: "ce"
    classification_model: "resnet18"
    learning_rate: 1e-3
    learning_rate_schedule_patience: 6
    weights: "random"
  datamodule:
    batch_size: 128
    num_workers: 6
    weights: ${experiment.module.weights}
    val_split_pct: 0.2
    test_split_pct: 0.2


@@ -15,11 +15,10 @@ DATA_DIR = ""  # path to the LandcoverAI data directory

 # Hyperparameter options
 model_options = ["unet"]
-encoder_options = ["resnet50"]
-lr_options = [1e-4]
-loss_options = ["ce"]
-weight_init_options = ["imagenet"]
-seeds = list(range(15))
+encoder_options = ["resnet18", "resnet50"]
+lr_options = [1e-2, 1e-3, 1e-4]
+loss_options = ["ce", "jaccard"]
+weight_init_options = ["null", "imagenet"]

 def do_work(work: "Queue[str]", gpu_idx: int) -> bool:
@@ -36,18 +35,13 @@ def do_work(work: "Queue[str]", gpu_idx: int) -> bool:
 if __name__ == "__main__":
     work: "Queue[str]" = Queue()

-    for (model, encoder, lr, loss, weight_init, seed) in itertools.product(
-        model_options,
-        encoder_options,
-        lr_options,
-        loss_options,
-        weight_init_options,
-        seeds,
+    for (model, encoder, lr, loss, weight_init) in itertools.product(
+        model_options, encoder_options, lr_options, loss_options, weight_init_options
     ):
-        experiment_name = f"{model}_{encoder}_{lr}_{loss}_{weight_init}_{seed}"
+        experiment_name = f"{model}_{encoder}_{lr}_{loss}_{weight_init}"

-        output_dir = os.path.join("output", "landcoverai_seed_experiments")
+        output_dir = os.path.join("output", "landcoverai_experiments")
         log_dir = os.path.join(output_dir, "logs")
         config_file = os.path.join("conf", "landcoverai.yaml")
@@ -63,7 +57,6 @@ if __name__ == "__main__":
                 + f" experiment.module.encoder_name={encoder}"
                 + f" experiment.module.encoder_weights={weight_init}"
                 + f" program.output_dir={output_dir}"
-                + f" program.seed={seed}"
                 + f" program.log_dir={log_dir}"
                 + f" program.data_dir={DATA_DIR}"
                 + " trainer.gpus=[GPU]"


@@ -0,0 +1,82 @@
#!/usr/bin/env python3
# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT License.

"""Runs the train script with a grid of hyperparameters."""
import itertools
import os
import subprocess
from multiprocessing import Process, Queue

# list of GPU IDs that we want to use, one job will be started for every ID in the list
GPUS = [0]
DRY_RUN = False  # if True then only print out the commands to be run, if False then run them
DATA_DIR = ""  # path to the RESISC45 data directory

# Hyperparameter options
model_options = ["resnet18", "resnet50"]
lr_options = [1e-2, 1e-3, 1e-4]
loss_options = ["ce"]
weight_options = ["imagenet_only", "random"]


def do_work(work: "Queue[str]", gpu_idx: int) -> bool:
    """Process for each ID in GPUS."""
    while not work.empty():
        experiment = work.get()
        experiment = experiment.replace("GPU", str(gpu_idx))
        print(experiment)
        if not DRY_RUN:
            subprocess.call(experiment.split(" "))
    return True


if __name__ == "__main__":
    work: "Queue[str]" = Queue()

    for (model, lr, loss, weights) in itertools.product(
        model_options,
        lr_options,
        loss_options,
        weight_options,
    ):
        experiment_name = f"{model}_{lr}_{loss}_{weights.replace('_', '-')}"

        output_dir = os.path.join("output", "resisc45_experiments")
        log_dir = os.path.join(output_dir, "logs")
        config_file = os.path.join("conf", "resisc45.yaml")

        if not os.path.exists(os.path.join(output_dir, experiment_name)):
            command = (
                "python train.py"
                + f" config_file={config_file}"
                + f" experiment.name={experiment_name}"
                + f" experiment.module.classification_model={model}"
                + f" experiment.module.learning_rate={lr}"
                + f" experiment.module.loss={loss}"
                + f" experiment.module.weights={weights}"
                + f" experiment.datamodule.weights={weights}"
                + f" program.output_dir={output_dir}"
                + f" program.log_dir={log_dir}"
                + f" program.data_dir={DATA_DIR}"
                + " trainer.gpus=[GPU]"
            )
            command = command.strip()
            work.put(command)

    processes = []
    for gpu_idx in GPUS:
        p = Process(target=do_work, args=(work, gpu_idx))
        processes.append(p)
        p.start()
    for p in processes:
        p.join()

tests/datasets/test_utils.py

@@ -16,11 +16,13 @@ import pytest
 import torch
 from _pytest.monkeypatch import MonkeyPatch
 from rasterio.crs import CRS
+from torch.utils.data import TensorDataset

 import torchgeo.datasets.utils
 from torchgeo.datasets.utils import (
     BoundingBox,
     collate_dict,
+    dataset_split,
     disambiguate_timestamp,
     download_and_extract_archive,
     download_radiant_mlhub_collection,
@@ -335,3 +337,21 @@ def test_nonexisting_directory(tmp_path: Path) -> None:
     with working_dir(str(subdir), create=True):
         assert subdir.cwd() == subdir
+
+
+def test_dataset_split() -> None:
+    num_samples = 24
+    x = torch.ones(num_samples, 5)  # type: ignore[attr-defined]
+    y = torch.randint(low=0, high=2, size=(num_samples,))  # type: ignore[attr-defined]
+    ds = TensorDataset(x, y)
+
+    # Test only train/val set split
+    train_ds, val_ds = dataset_split(ds, val_pct=1 / 2)
+    assert len(train_ds) == num_samples // 2
+    assert len(val_ds) == num_samples // 2
+
+    # Test train/val/test set split
+    train_ds, val_ds, test_ds = dataset_split(ds, val_pct=1 / 3, test_pct=1 / 3)
+    assert len(train_ds) == num_samples // 3
+    assert len(val_ds) == num_samples // 3
+    assert len(test_ds) == num_samples // 3

tests/trainers/test_resisc45.py Normal file

@@ -0,0 +1,39 @@
# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT License.

from typing import Any, Dict, cast

import pytest
import torch.nn as nn
import torchvision
from omegaconf import OmegaConf

from torchgeo.trainers import RESISC45ClassificationTask


class TestRESISC45Trainer:
    @pytest.fixture
    def default_config(self) -> Dict[str, Any]:
        task_conf = OmegaConf.load("conf/task_defaults/resisc45.yaml")
        task_args = OmegaConf.to_object(task_conf.experiment.module)
        task_args = cast(Dict[str, Any], task_args)
        return task_args

    def test_resnet_ce(self, default_config: Dict[str, Any]) -> None:
        default_config["classification_model"] = "resnet18"
        default_config["loss"] = "ce"
        task = RESISC45ClassificationTask(**default_config)
        assert isinstance(task.model, torchvision.models.ResNet)
        assert isinstance(task.loss, nn.CrossEntropyLoss)  # type: ignore[attr-defined]

    def test_invalid_model(self, default_config: Dict[str, Any]) -> None:
        default_config["classification_model"] = "invalid_model"
        error_message = "Model type 'invalid_model' is not valid."
        with pytest.raises(ValueError, match=error_message):
            RESISC45ClassificationTask(**default_config)

    def test_invalid_loss(self, default_config: Dict[str, Any]) -> None:
        default_config["loss"] = "invalid_loss"
        error_message = "Loss type 'invalid_loss' is not valid."
        with pytest.raises(ValueError, match=error_message):
            RESISC45ClassificationTask(**default_config)

torchgeo/datasets/utils.py

@@ -18,6 +18,7 @@ import numpy as np
 import rasterio
 import torch
 from torch import Tensor
+from torch.utils.data import Dataset, Subset, random_split
 from torchvision.datasets.utils import check_integrity, download_url

 __all__ = (
@@ -30,6 +31,7 @@ __all__ = (
     "working_dir",
     "collate_dict",
     "rasterio_loader",
+    "dataset_split",
 )
@@ -394,3 +396,28 @@ def rasterio_loader(path: str) -> np.ndarray:  # type: ignore[type-arg]
     # VisionClassificationDataset expects images returned with channels last (HWC)
     array = array.transpose(1, 2, 0)
     return array
+
+
+def dataset_split(
+    dataset: Dataset[Any], val_pct: float, test_pct: Optional[float] = None
+) -> List[Subset[Any]]:
+    """Split a torch Dataset into train/val/test sets.
+
+    If ``test_pct`` is not set then only train and validation splits are returned.
+
+    Args:
+        dataset: dataset to be split into train/val or train/val/test subsets
+        val_pct: percentage of samples to be in validation set
+        test_pct: (Optional) percentage of samples to be in test set
+
+    Returns:
+        a list of the subset datasets. Either [train, val] or [train, val, test]
+    """
+    if test_pct is None:
+        val_length = int(len(dataset) * val_pct)  # type: ignore[arg-type]
+        train_length = len(dataset) - val_length  # type: ignore[arg-type]
+        return random_split(dataset, [train_length, val_length])
+    else:
+        val_length = int(len(dataset) * val_pct)  # type: ignore[arg-type]
+        test_length = int(len(dataset) * test_pct)  # type: ignore[arg-type]
+        train_length = len(dataset) - (val_length + test_length)  # type: ignore[arg-type]  # noqa: E501
+        return random_split(dataset, [train_length, val_length, test_length])
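
A short usage sketch (the dataset and sizes below are hypothetical, not from this diff): the val/test lengths are floored with int() and train receives the remainder, so the lengths always sum to len(dataset).

    # hypothetical 101-sample dataset:
    # val_length = int(101 * 0.2) = 20, test_length = 20, train_length = 101 - 40 = 61
    train_ds, val_ds, test_ds = dataset_split(my_dataset, val_pct=0.2, test_pct=0.2)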

torchgeo/trainers/__init__.py

@@ -8,6 +8,7 @@ from .chesapeake import ChesapeakeCVPRDataModule, ChesapeakeCVPRSegmentationTask
 from .cyclone import CycloneDataModule, CycloneSimpleRegressionTask
 from .landcoverai import LandcoverAIDataModule, LandcoverAISegmentationTask
 from .naipchesapeake import NAIPChesapeakeDataModule, NAIPChesapeakeSegmentationTask
+from .resisc45 import RESISC45ClassificationTask, RESISC45DataModule
 from .sen12ms import SEN12MSDataModule, SEN12MSSegmentationTask
 from .so2sat import So2SatClassificationTask, So2SatDataModule
@@ -21,6 +22,8 @@ __all__ = (
     "LandcoverAISegmentationTask",
     "NAIPChesapeakeDataModule",
     "NAIPChesapeakeSegmentationTask",
+    "RESISC45ClassificationTask",
+    "RESISC45DataModule",
     "SEN12MSDataModule",
     "SEN12MSSegmentationTask",
     "So2SatDataModule",

torchgeo/trainers/resisc45.py Normal file

@@ -0,0 +1,341 @@
# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT License.

"""RESISC45 trainer."""

from typing import Any, Dict, Optional, cast

import kornia.augmentation as K
import pytorch_lightning as pl
import torch
import torch.nn as nn
import torchvision.models
from torch import Tensor
from torch.nn.modules import Conv2d, Linear, Module
from torch.optim.lr_scheduler import ReduceLROnPlateau
from torch.utils.data import DataLoader
from torchmetrics import Accuracy, IoU, MetricCollection
from torchvision.transforms import Compose, Normalize

from ..datasets import RESISC45
from ..datasets.utils import dataset_split

# https://github.com/pytorch/pytorch/issues/60979
# https://github.com/pytorch/pytorch/pull/61045
DataLoader.__module__ = "torch.utils.data"
Module.__module__ = "torch.nn"
Conv2d.__module__ = "nn.Conv2d"
Linear.__module__ = "nn.Linear"

IN_CHANNELS = 3
NUM_CLASSES = 45


class RESISC45ClassificationTask(pl.LightningModule):
    """LightningModule for training models on the RESISC45 Dataset."""

    def config_task(self) -> None:
        """Configures the task based on kwargs parameters passed to the constructor."""
        pretrained = "imagenet" in self.hparams["weights"]

        if "resnet" in self.hparams["classification_model"]:
            self.model = getattr(
                torchvision.models.resnet, self.hparams["classification_model"]
            )(pretrained=pretrained)
            in_features = self.model.fc.in_features
            self.model.fc = Linear(in_features, out_features=NUM_CLASSES)
        else:
            raise ValueError(
                f"Model type '{self.hparams['classification_model']}' is not valid."
            )

        if "resnet" in self.hparams["classification_model"]:
            if self.hparams["weights"] in ["imagenet_only", "random"]:
                pass
            else:
                raise ValueError(
                    f"Weight type '{self.hparams['weights']}' is not valid."
                )
        else:
            pass  # stub for initializing the weights of other models

        if self.hparams["loss"] == "ce":
            self.loss = nn.CrossEntropyLoss()  # type: ignore[attr-defined]
        else:
            raise ValueError(f"Loss type '{self.hparams['loss']}' is not valid.")

    def __init__(
        self,
        **kwargs: Any,
    ) -> None:
        """Initialize the LightningModule with a model and loss function.

        Keyword Args:
            classification_model: Name of the classification model to use
            loss: Name of the loss function
            weights: Either "random", "imagenet_only", "imagenet_and_random", or
                "random_rgb"
        """
        super().__init__()
        self.save_hyperparameters()  # creates `self.hparams` from kwargs
        self.config_task()

        self.train_metrics = MetricCollection(
            {
                "OverallAccuracy": Accuracy(num_classes=NUM_CLASSES, average="micro"),
                "AverageAccuracy": Accuracy(num_classes=NUM_CLASSES, average="macro"),
                "IoU": IoU(num_classes=NUM_CLASSES),
            },
            prefix="train_",
        )
        self.val_metrics = self.train_metrics.clone(prefix="val_")
        self.test_metrics = self.train_metrics.clone(prefix="test_")

    def forward(self, x: Tensor) -> Any:  # type: ignore[override]
        """Forward pass of the model."""
        return self.model(x)

    def training_step(  # type: ignore[override]
        self, batch: Dict[str, Any], batch_idx: int
    ) -> Tensor:
        """Training step - reports average accuracy and average IoU.

        Args:
            batch: Current batch
            batch_idx: Index of current batch

        Returns:
            training loss
        """
        x = batch["image"]
        y = batch["label"]
        y_hat = self.forward(x)
        y_hat_hard = y_hat.argmax(dim=1)

        loss = self.loss(y_hat, y)

        # by default, the train step logs every `log_every_n_steps` steps where
        # `log_every_n_steps` is a parameter to the `Trainer` object
        self.log("train_loss", loss, on_step=True, on_epoch=False)
        self.train_metrics(y_hat_hard, y)

        return cast(Tensor, loss)

    def training_epoch_end(self, outputs: Any) -> None:
        """Logs epoch-level training metrics.

        Args:
            outputs: list of items returned by training_step
        """
        self.log_dict(self.train_metrics.compute())
        self.train_metrics.reset()

    def validation_step(  # type: ignore[override]
        self, batch: Dict[str, Any], batch_idx: int
    ) -> None:
        """Validation step - reports average accuracy and average IoU.

        Args:
            batch: Current batch
            batch_idx: Index of current batch
        """
        x = batch["image"]
        y = batch["label"]
        y_hat = self.forward(x)
        y_hat_hard = y_hat.argmax(dim=1)

        loss = self.loss(y_hat, y)

        self.log("val_loss", loss, on_step=False, on_epoch=True)
        self.val_metrics(y_hat_hard, y)

    def validation_epoch_end(self, outputs: Any) -> None:
        """Logs epoch level validation metrics.

        Args:
            outputs: list of items returned by validation_step
        """
        self.log_dict(self.val_metrics.compute())
        self.val_metrics.reset()

    def test_step(  # type: ignore[override]
        self, batch: Dict[str, Any], batch_idx: int
    ) -> None:
        """Test step identical to the validation step.

        Args:
            batch: Current batch
            batch_idx: Index of current batch
        """
        x = batch["image"]
        y = batch["label"]
        y_hat = self.forward(x)
        y_hat_hard = y_hat.argmax(dim=1)

        loss = self.loss(y_hat, y)

        # by default, the test and validation steps only log per *epoch*
        self.log("test_loss", loss, on_step=False, on_epoch=True)
        self.test_metrics(y_hat_hard, y)

    def test_epoch_end(self, outputs: Any) -> None:
        """Logs epoch level test metrics.

        Args:
            outputs: list of items returned by test_step
        """
        self.log_dict(self.test_metrics.compute())
        self.test_metrics.reset()

    def configure_optimizers(self) -> Dict[str, Any]:
        """Initialize the optimizer and learning rate scheduler.

        Returns:
            a "lr dict" according to the pytorch lightning documentation --
            https://pytorch-lightning.readthedocs.io/en/latest/common/lightning_module.html#configure-optimizers
        """
        optimizer = torch.optim.Adam(
            self.model.parameters(),
            lr=self.hparams["learning_rate"],
        )
        return {
            "optimizer": optimizer,
            "lr_scheduler": {
                "scheduler": ReduceLROnPlateau(
                    optimizer,
                    patience=self.hparams["learning_rate_schedule_patience"],
                ),
                "monitor": "val_loss",
            },
        }


class RESISC45DataModule(pl.LightningDataModule):
    """LightningDataModule implementation for the RESISC45 dataset.

    Uses the train/val/test splits from the dataset.
    """

    band_means = torch.tensor(  # type: ignore[attr-defined]
        [0.36801773, 0.38097873, 0.343583]
    )

    band_stds = torch.tensor(  # type: ignore[attr-defined]
        [0.14540215, 0.13558227, 0.13203649]
    )

    def __init__(
        self,
        root_dir: str,
        batch_size: int = 64,
        num_workers: int = 4,
        weights: str = "random",
        unsupervised_mode: bool = False,
        **kwargs: Any,
    ) -> None:
        """Initialize a LightningDataModule for RESISC45 based DataLoaders.

        Args:
            root_dir: The ``root`` argument to pass to the RESISC45 Dataset classes
            batch_size: The batch size to use in all created DataLoaders
            num_workers: The number of workers to use in all created DataLoaders
            weights: Either "random", "imagenet_only", "imagenet_and_random", or
                "random_rgb"
            unsupervised_mode: Makes the train dataloader return imagery from the train,
                val, and test sets
        """
        super().__init__()  # type: ignore[no-untyped-call]
        self.root_dir = root_dir
        self.batch_size = batch_size
        self.num_workers = num_workers
        self.weights = weights
        self.unsupervised_mode = unsupervised_mode
        self.val_split_pct = kwargs["val_split_pct"]
        self.test_split_pct = kwargs["test_split_pct"]

        self.norm = Normalize(self.band_means, self.band_stds)

        self.transforms = K.AugmentationSequential(
            K.RandomAffine(degrees=30),
            K.RandomHorizontalFlip(),
            K.RandomVerticalFlip(),
            data_keys=["input"],
        )

    def preprocess(self, sample: Dict[str, Any]) -> Dict[str, Any]:
        """Transform a single sample from the Dataset."""
        sample["image"] = sample["image"].float()
        sample["image"] /= 255.0
        sample["image"] = self.norm(sample["image"])
        return sample

    def kornia_pipeline(self, sample: Dict[str, Any]) -> Dict[str, Any]:
        """Transform a single sample from the Dataset with Kornia."""
        sample["image"] = self.transforms(sample["image"]).squeeze()
        return sample

    def prepare_data(self) -> None:
        """Make sure that the dataset is downloaded.

        This method is only called once per run.
        """
        RESISC45(self.root_dir, checksum=False)

    def setup(self, stage: Optional[str] = None) -> None:
        """Initialize the main ``Dataset`` objects.

        This method is called once per GPU per run.
        """
        transforms = Compose([self.preprocess])

        if not self.unsupervised_mode:
            dataset = RESISC45(
                self.root_dir,
                transforms=transforms,
            )
            self.train_dataset, self.val_dataset, self.test_dataset = dataset_split(
                dataset, val_pct=self.val_split_pct, test_pct=self.test_split_pct
            )
        else:
            self.train_dataset = RESISC45(
                self.root_dir,
                transforms=transforms,
            )
            self.val_dataset, self.test_dataset = None, None  # type: ignore[assignment]

    def train_dataloader(self) -> DataLoader[Any]:
        """Return a DataLoader for training."""
        return DataLoader(
            self.train_dataset,
            batch_size=self.batch_size,
            num_workers=self.num_workers,
            shuffle=True,
        )

    def val_dataloader(self) -> DataLoader[Any]:
        """Return a DataLoader for validation."""
        if self.val_dataset is None or len(self.val_dataset) == 0:
            return self.train_dataloader()
        else:
            return DataLoader(
                self.val_dataset,
                batch_size=self.batch_size,
                num_workers=self.num_workers,
                shuffle=False,
            )

    def test_dataloader(self) -> DataLoader[Any]:
        """Return a DataLoader for testing."""
        if self.test_dataset is None or len(self.test_dataset) == 0:
            return self.train_dataloader()
        else:
            return DataLoader(
                self.test_dataset,
                batch_size=self.batch_size,
                num_workers=self.num_workers,
                shuffle=False,
            )
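
A minimal end-to-end sketch wiring these two classes together. Hyperparameters are taken from conf/resisc45.yaml above; the root_dir path is hypothetical:

    import pytorch_lightning as pl

    datamodule = RESISC45DataModule(
        root_dir="data/resisc45",  # hypothetical path to the RESISC45 data
        batch_size=128,
        num_workers=6,
        weights="random",
        val_split_pct=0.2,
        test_split_pct=0.2,
    )
    task = RESISC45ClassificationTask(
        classification_model="resnet18",
        loss="ce",
        weights="random",
        learning_rate=1e-3,
        learning_rate_schedule_patience=6,
    )
    trainer = pl.Trainer(gpus=1, min_epochs=10, max_epochs=40)
    trainer.fit(model=task, datamodule=datamodule)
    trainer.test(model=task, datamodule=datamodule)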

train.py

@@ -23,6 +23,8 @@ from torchgeo.trainers import (
     LandcoverAISegmentationTask,
     NAIPChesapeakeDataModule,
     NAIPChesapeakeSegmentationTask,
+    RESISC45ClassificationTask,
+    RESISC45DataModule,
     SEN12MSDataModule,
     SEN12MSSegmentationTask,
     So2SatClassificationTask,
@@ -37,6 +39,7 @@ TASK_TO_MODULES_MAPPING: Dict[
     "cyclone": (CycloneSimpleRegressionTask, CycloneDataModule),
     "landcoverai": (LandcoverAISegmentationTask, LandcoverAIDataModule),
     "naipchesapeake": (NAIPChesapeakeSegmentationTask, NAIPChesapeakeDataModule),
+    "resisc45": (RESISC45ClassificationTask, RESISC45DataModule),
     "sen12ms": (SEN12MSSegmentationTask, SEN12MSDataModule),
     "so2sat": (So2SatClassificationTask, So2SatDataModule),
 }
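
A hedged sketch of how this mapping is typically consumed (the body of train.py is not shown in this diff; task_args and dm_args stand in for whatever config dictionaries it builds):

    task_class, datamodule_class = TASK_TO_MODULES_MAPPING["resisc45"]
    task = task_class(**task_args)            # -> RESISC45ClassificationTask
    datamodule = datamodule_class(**dm_args)  # -> RESISC45DataModule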