Adding RESISC45 trainer with augmentations (#225)
* Removing some keys from the defaults that were removed in Lightning 1.5 * Adding RESISC45 * Experimenting with different LR decay * Formatting * Updating RESISC45 test data to be the same size as the original data * RESISC45 trainer and tests * Fix import, add deprecation * mypy fixes Co-authored-by: Adam J. Stewart <ajstewart426@gmail.com>
|
@ -59,7 +59,6 @@ trainer: # These are the parameters passed to the pytorch lightning Trainer obje
|
|||
weights_summary: 'top'
|
||||
weights_save_path: null
|
||||
num_sanity_val_steps: 2
|
||||
truncated_bptt_steps: null
|
||||
resume_from_checkpoint: null
|
||||
profiler: null
|
||||
benchmark: False
|
||||
|
@ -72,8 +71,6 @@ trainer: # These are the parameters passed to the pytorch lightning Trainer obje
|
|||
prepare_data_per_node: True
|
||||
plugins: null
|
||||
amp_backend: 'native'
|
||||
amp_level: 'O2'
|
||||
distributed_backend: null
|
||||
move_metrics_to_cpu: False
|
||||
multiple_trainloader_mode: 'max_size_cycle'
|
||||
stochastic_weight_avg: False
|
||||
|
|
Двоичные данные
tests/data/resisc45/NWPU-RESISC45.rar
Двоичные данные
tests/data/resisc45/NWPU-RESISC45/airplane/airplane_001.jpg
До Ширина: | Высота: | Размер: 631 B После Ширина: | Высота: | Размер: 39 KiB |
Двоичные данные
tests/data/resisc45/NWPU-RESISC45/airplane/airplane_002.jpg
До Ширина: | Высота: | Размер: 631 B После Ширина: | Высота: | Размер: 39 KiB |
Двоичные данные
tests/data/resisc45/NWPU-RESISC45/airplane/airplane_003.jpg
До Ширина: | Высота: | Размер: 631 B После Ширина: | Высота: | Размер: 39 KiB |
Двоичные данные
tests/data/resisc45/NWPU-RESISC45/airport/airport_001.jpg
До Ширина: | Высота: | Размер: 631 B После Ширина: | Высота: | Размер: 39 KiB |
Двоичные данные
tests/data/resisc45/NWPU-RESISC45/airport/airport_002.jpg
До Ширина: | Высота: | Размер: 631 B После Ширина: | Высота: | Размер: 39 KiB |
Двоичные данные
tests/data/resisc45/NWPU-RESISC45/airport/airport_003.jpg
До Ширина: | Высота: | Размер: 631 B После Ширина: | Высота: | Размер: 39 KiB |
До Ширина: | Высота: | Размер: 631 B После Ширина: | Высота: | Размер: 39 KiB |
До Ширина: | Высота: | Размер: 631 B После Ширина: | Высота: | Размер: 39 KiB |
До Ширина: | Высота: | Размер: 631 B После Ширина: | Высота: | Размер: 39 KiB |
|
@ -35,7 +35,7 @@ class TestRESISC45:
|
|||
monkeypatch.setattr( # type: ignore[attr-defined]
|
||||
torchgeo.datasets.resisc45, "download_url", download_url
|
||||
)
|
||||
md5 = "5d898bd91e3ebc64314893ff191b2f9d"
|
||||
md5 = "5895dea3757ba88707d52f5521c444d3"
|
||||
monkeypatch.setattr(RESISC45, "md5", md5) # type: ignore[attr-defined]
|
||||
url = os.path.join("tests", "data", "resisc45", "NWPU-RESISC45.rar")
|
||||
monkeypatch.setattr(RESISC45, "url", url) # type: ignore[attr-defined]
|
||||
|
|
|
@ -0,0 +1,46 @@
|
|||
# Copyright (c) Microsoft Corporation. All rights reserved.
|
||||
# Licensed under the MIT License.
|
||||
|
||||
import os
|
||||
from typing import Any, Dict
|
||||
|
||||
import pytest
|
||||
|
||||
from torchgeo.datasets import RESISC45DataModule
|
||||
from torchgeo.trainers.resisc45 import RESISC45ClassificationTask
|
||||
|
||||
|
||||
class TestRESISC45ClassificationTask:
|
||||
@pytest.fixture(scope="class")
|
||||
def datamodule(self) -> RESISC45DataModule:
|
||||
root = os.path.join("tests", "data", "resisc45")
|
||||
batch_size = 2
|
||||
num_workers = 0
|
||||
dm = RESISC45DataModule(root, batch_size, num_workers)
|
||||
dm.prepare_data()
|
||||
dm.setup()
|
||||
return dm
|
||||
|
||||
@pytest.fixture()
|
||||
def config(self) -> Dict[str, Any]:
|
||||
task_args: Dict[str, Any] = {}
|
||||
task_args["classification_model"] = "resnet18"
|
||||
task_args["learning_rate"] = 3e-4
|
||||
task_args["learning_rate_schedule_patience"] = 6
|
||||
task_args["in_channels"] = 3
|
||||
task_args["loss"] = "ce"
|
||||
task_args["num_classes"] = 45
|
||||
task_args["weights"] = "random"
|
||||
return task_args
|
||||
|
||||
@pytest.fixture
|
||||
def task(self, config: Dict[str, Any]) -> RESISC45ClassificationTask:
|
||||
task = RESISC45ClassificationTask(**config)
|
||||
return task
|
||||
|
||||
def test_training(
|
||||
self, datamodule: RESISC45DataModule, task: RESISC45ClassificationTask
|
||||
) -> None:
|
||||
batch = next(iter(datamodule.train_dataloader()))
|
||||
task.training_step(batch, 0)
|
||||
task.training_epoch_end(0)
|
|
@ -0,0 +1,70 @@
|
|||
# Copyright (c) Microsoft Corporation. All rights reserved.
|
||||
# Licensed under the MIT License.
|
||||
|
||||
"""Custom trainer for the RESISC45 dataset."""
|
||||
|
||||
from typing import Any, Dict, cast
|
||||
|
||||
import kornia.augmentation as K
|
||||
import torch
|
||||
from torch import Tensor
|
||||
|
||||
from .classification import ClassificationTask
|
||||
|
||||
|
||||
# TODO: move this functionality into ClassificationTask and remove this class
|
||||
class RESISC45ClassificationTask(ClassificationTask):
|
||||
"""LightningModule for training on RESISC45 with data augmentation.
|
||||
|
||||
.. deprecated:: 0.1
|
||||
Use :class:`ClassificationTask` instead.
|
||||
"""
|
||||
|
||||
def __init__(self, **kwargs: Any) -> None:
|
||||
"""Initialize the LightningModule with a model and loss function.
|
||||
|
||||
Keyword Args:
|
||||
classification_model: Name of the classification model use
|
||||
loss: Name of the loss function
|
||||
weights: Either "random", "imagenet_only", "imagenet_and_random", or
|
||||
"random_rgb"
|
||||
"""
|
||||
super().__init__(**kwargs)
|
||||
|
||||
self.train_augmentations = K.AugmentationSequential(
|
||||
K.RandomRotation(p=0.5, degrees=90),
|
||||
K.RandomHorizontalFlip(p=0.5),
|
||||
K.RandomVerticalFlip(p=0.5),
|
||||
K.RandomSharpness(p=0.5),
|
||||
K.RandomErasing(p=0.1),
|
||||
K.ColorJitter(p=0.5, brightness=0.1, contrast=0.1, saturation=0.1, hue=0.1),
|
||||
data_keys=["input"],
|
||||
)
|
||||
|
||||
def training_step( # type: ignore[override]
|
||||
self, batch: Dict[str, Any], batch_idx: int
|
||||
) -> Tensor:
|
||||
"""Training step - reports average accuracy and average IoU.
|
||||
|
||||
Args:
|
||||
batch: Current batch
|
||||
batch_idx: Index of current batch
|
||||
|
||||
Returns:
|
||||
training loss
|
||||
"""
|
||||
x = batch["image"]
|
||||
y = batch["label"]
|
||||
with torch.no_grad():
|
||||
x = self.train_augmentations(x)
|
||||
y_hat = self.forward(x)
|
||||
y_hat_hard = y_hat.argmax(dim=1)
|
||||
|
||||
loss = self.loss(y_hat, y)
|
||||
|
||||
# by default, the train step logs every `log_every_n_steps` steps where
|
||||
# `log_every_n_steps` is a parameter to the `Trainer` object
|
||||
self.log("train_loss", loss, on_step=True, on_epoch=False)
|
||||
self.train_metrics(y_hat_hard, y)
|
||||
|
||||
return cast(Tensor, loss)
|
3
train.py
|
@ -35,6 +35,7 @@ from torchgeo.trainers import (
|
|||
from torchgeo.trainers.chesapeake import ChesapeakeCVPRSegmentationTask
|
||||
from torchgeo.trainers.landcoverai import LandCoverAISegmentationTask
|
||||
from torchgeo.trainers.naipchesapeake import NAIPChesapeakeSegmentationTask
|
||||
from torchgeo.trainers.resisc45 import RESISC45ClassificationTask
|
||||
from torchgeo.trainers.so2sat import So2SatClassificationTask
|
||||
|
||||
TASK_TO_MODULES_MAPPING: Dict[
|
||||
|
@ -47,7 +48,7 @@ TASK_TO_MODULES_MAPPING: Dict[
|
|||
"cyclone": (RegressionTask, CycloneDataModule),
|
||||
"landcoverai": (LandCoverAISegmentationTask, LandCoverAIDataModule),
|
||||
"naipchesapeake": (NAIPChesapeakeSegmentationTask, NAIPChesapeakeDataModule),
|
||||
"resisc45": (ClassificationTask, RESISC45DataModule),
|
||||
"resisc45": (RESISC45ClassificationTask, RESISC45DataModule),
|
||||
"sen12ms": (SemanticSegmentationTask, SEN12MSDataModule),
|
||||
"so2sat": (So2SatClassificationTask, So2SatDataModule),
|
||||
"ucmerced": (ClassificationTask, UCMercedDataModule),
|
||||
|
|