Adding RESISC45 trainer with augmentations (#225)

* Removing some keys from the defaults that were removed in Lightning 1.5

* Adding RESISC45

* Experimenting with different LR decay

* Formatting

* Updating RESISC45 test data to be the same size as the original data

* RESISC45 trainer and tests

* Fix import, add deprecation

* mypy fixes

Co-authored-by: Adam J. Stewart <ajstewart426@gmail.com>
This commit is contained in:
Caleb Robinson 2021-11-06 22:17:57 -07:00 коммит произвёл GitHub
Родитель b87d2707ae
Коммит e5bbc738a3
Не найден ключ, соответствующий данной подписи
Идентификатор ключа GPG: 4AEE18F83AFDEB23
15 изменённых файлов: 119 добавлений и 5 удалений

Просмотреть файл

@ -59,7 +59,6 @@ trainer: # These are the parameters passed to the pytorch lightning Trainer obje
weights_summary: 'top'
weights_save_path: null
num_sanity_val_steps: 2
truncated_bptt_steps: null
resume_from_checkpoint: null
profiler: null
benchmark: False
@ -72,8 +71,6 @@ trainer: # These are the parameters passed to the pytorch lightning Trainer obje
prepare_data_per_node: True
plugins: null
amp_backend: 'native'
amp_level: 'O2'
distributed_backend: null
move_metrics_to_cpu: False
multiple_trainloader_mode: 'max_size_cycle'
stochastic_weight_avg: False

Двоичные данные
tests/data/resisc45/NWPU-RESISC45.rar

Двоичный файл не отображается.

Двоичный файл не отображается.

До

Ширина:  |  Высота:  |  Размер: 631 B

После

Ширина:  |  Высота:  |  Размер: 39 KiB

Двоичный файл не отображается.

До

Ширина:  |  Высота:  |  Размер: 631 B

После

Ширина:  |  Высота:  |  Размер: 39 KiB

Двоичный файл не отображается.

До

Ширина:  |  Высота:  |  Размер: 631 B

После

Ширина:  |  Высота:  |  Размер: 39 KiB

Двоичный файл не отображается.

До

Ширина:  |  Высота:  |  Размер: 631 B

После

Ширина:  |  Высота:  |  Размер: 39 KiB

Двоичный файл не отображается.

До

Ширина:  |  Высота:  |  Размер: 631 B

После

Ширина:  |  Высота:  |  Размер: 39 KiB

Двоичный файл не отображается.

До

Ширина:  |  Высота:  |  Размер: 631 B

После

Ширина:  |  Высота:  |  Размер: 39 KiB

Двоичный файл не отображается.

До

Ширина:  |  Высота:  |  Размер: 631 B

После

Ширина:  |  Высота:  |  Размер: 39 KiB

Двоичный файл не отображается.

До

Ширина:  |  Высота:  |  Размер: 631 B

После

Ширина:  |  Высота:  |  Размер: 39 KiB

Двоичный файл не отображается.

До

Ширина:  |  Высота:  |  Размер: 631 B

После

Ширина:  |  Высота:  |  Размер: 39 KiB

Просмотреть файл

@ -35,7 +35,7 @@ class TestRESISC45:
monkeypatch.setattr( # type: ignore[attr-defined]
torchgeo.datasets.resisc45, "download_url", download_url
)
md5 = "5d898bd91e3ebc64314893ff191b2f9d"
md5 = "5895dea3757ba88707d52f5521c444d3"
monkeypatch.setattr(RESISC45, "md5", md5) # type: ignore[attr-defined]
url = os.path.join("tests", "data", "resisc45", "NWPU-RESISC45.rar")
monkeypatch.setattr(RESISC45, "url", url) # type: ignore[attr-defined]

Просмотреть файл

@ -0,0 +1,46 @@
# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT License.
import os
from typing import Any, Dict
import pytest
from torchgeo.datasets import RESISC45DataModule
from torchgeo.trainers.resisc45 import RESISC45ClassificationTask
class TestRESISC45ClassificationTask:
@pytest.fixture(scope="class")
def datamodule(self) -> RESISC45DataModule:
root = os.path.join("tests", "data", "resisc45")
batch_size = 2
num_workers = 0
dm = RESISC45DataModule(root, batch_size, num_workers)
dm.prepare_data()
dm.setup()
return dm
@pytest.fixture()
def config(self) -> Dict[str, Any]:
task_args: Dict[str, Any] = {}
task_args["classification_model"] = "resnet18"
task_args["learning_rate"] = 3e-4
task_args["learning_rate_schedule_patience"] = 6
task_args["in_channels"] = 3
task_args["loss"] = "ce"
task_args["num_classes"] = 45
task_args["weights"] = "random"
return task_args
@pytest.fixture
def task(self, config: Dict[str, Any]) -> RESISC45ClassificationTask:
task = RESISC45ClassificationTask(**config)
return task
def test_training(
self, datamodule: RESISC45DataModule, task: RESISC45ClassificationTask
) -> None:
batch = next(iter(datamodule.train_dataloader()))
task.training_step(batch, 0)
task.training_epoch_end(0)

Просмотреть файл

@ -0,0 +1,70 @@
# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT License.
"""Custom trainer for the RESISC45 dataset."""
from typing import Any, Dict, cast
import kornia.augmentation as K
import torch
from torch import Tensor
from .classification import ClassificationTask
# TODO: move this functionality into ClassificationTask and remove this class
class RESISC45ClassificationTask(ClassificationTask):
"""LightningModule for training on RESISC45 with data augmentation.
.. deprecated:: 0.1
Use :class:`ClassificationTask` instead.
"""
def __init__(self, **kwargs: Any) -> None:
"""Initialize the LightningModule with a model and loss function.
Keyword Args:
classification_model: Name of the classification model use
loss: Name of the loss function
weights: Either "random", "imagenet_only", "imagenet_and_random", or
"random_rgb"
"""
super().__init__(**kwargs)
self.train_augmentations = K.AugmentationSequential(
K.RandomRotation(p=0.5, degrees=90),
K.RandomHorizontalFlip(p=0.5),
K.RandomVerticalFlip(p=0.5),
K.RandomSharpness(p=0.5),
K.RandomErasing(p=0.1),
K.ColorJitter(p=0.5, brightness=0.1, contrast=0.1, saturation=0.1, hue=0.1),
data_keys=["input"],
)
def training_step( # type: ignore[override]
self, batch: Dict[str, Any], batch_idx: int
) -> Tensor:
"""Training step - reports average accuracy and average IoU.
Args:
batch: Current batch
batch_idx: Index of current batch
Returns:
training loss
"""
x = batch["image"]
y = batch["label"]
with torch.no_grad():
x = self.train_augmentations(x)
y_hat = self.forward(x)
y_hat_hard = y_hat.argmax(dim=1)
loss = self.loss(y_hat, y)
# by default, the train step logs every `log_every_n_steps` steps where
# `log_every_n_steps` is a parameter to the `Trainer` object
self.log("train_loss", loss, on_step=True, on_epoch=False)
self.train_metrics(y_hat_hard, y)
return cast(Tensor, loss)

Просмотреть файл

@ -35,6 +35,7 @@ from torchgeo.trainers import (
from torchgeo.trainers.chesapeake import ChesapeakeCVPRSegmentationTask
from torchgeo.trainers.landcoverai import LandCoverAISegmentationTask
from torchgeo.trainers.naipchesapeake import NAIPChesapeakeSegmentationTask
from torchgeo.trainers.resisc45 import RESISC45ClassificationTask
from torchgeo.trainers.so2sat import So2SatClassificationTask
TASK_TO_MODULES_MAPPING: Dict[
@ -47,7 +48,7 @@ TASK_TO_MODULES_MAPPING: Dict[
"cyclone": (RegressionTask, CycloneDataModule),
"landcoverai": (LandCoverAISegmentationTask, LandCoverAIDataModule),
"naipchesapeake": (NAIPChesapeakeSegmentationTask, NAIPChesapeakeDataModule),
"resisc45": (ClassificationTask, RESISC45DataModule),
"resisc45": (RESISC45ClassificationTask, RESISC45DataModule),
"sen12ms": (SemanticSegmentationTask, SEN12MSDataModule),
"so2sat": (So2SatClassificationTask, So2SatDataModule),
"ucmerced": (ClassificationTask, UCMercedDataModule),