Mirror of https://github.com/microsoft/torchgeo.git
Add RESISC45 Trainer (#179)
* add RESISC45 trainer
* update working locally
* Adding ability to choose the random split sizes via config
* If you don't have a val or test split, then return the train split so the Trainer doesn't break by default. If you actually want to train without val/test though, then you should set the appropriate Trainer args.
* RESISC experiments
* Reverting accidental changes
* mypy fix
* add dataset_split unit tests
* Document dataset_split

Co-authored-by: Caleb Robinson <calebrob6@gmail.com>
Parent: 84696d2b2c
Commit: 142835cede
@@ -0,0 +1,18 @@
trainer:
  gpus: 1 # single GPU training
  min_epochs: 10
  max_epochs: 40
  benchmark: True

experiment:
  task: "resisc45"
  module:
    loss: "ce"
    classification_model: "resnet18"
    learning_rate: 1e-3
    learning_rate_schedule_patience: 6
  datamodule:
    batch_size: 128
    num_workers: 6
    val_split_pct: 0.2
    test_split_pct: 0.2
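For context: configs like this are read with OmegaConf, the library the new unit tests use. A minimal sketch of loading and inspecting the values above (the file path is an assumption taken from the experiment script further down; the diff page itself does not show file names):

from omegaconf import OmegaConf

conf = OmegaConf.load("conf/resisc45.yaml")  # assumed path, per the run script below
print(conf.experiment.module.classification_model)  # "resnet18"
print(conf.experiment.datamodule.val_split_pct)  # 0.2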
@@ -0,0 +1,14 @@
experiment:
  task: "resisc45"
  module:
    loss: "ce"
    classification_model: "resnet18"
    learning_rate: 1e-3
    learning_rate_schedule_patience: 6
    weights: "random"
  datamodule:
    batch_size: 128
    num_workers: 6
    weights: ${experiment.module.weights}
    val_split_pct: 0.2
    test_split_pct: 0.2
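The weights: ${experiment.module.weights} entry is OmegaConf variable interpolation: the datamodule key resolves to whatever experiment.module.weights is set to, so a single override updates both. A self-contained check of that behavior:

from omegaconf import OmegaConf

conf = OmegaConf.create(
    {
        "experiment": {
            "module": {"weights": "random"},
            "datamodule": {"weights": "${experiment.module.weights}"},
        }
    }
)
assert conf.experiment.datamodule.weights == "random"  # resolved on access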
@@ -15,11 +15,10 @@ DATA_DIR = ""  # path to the LandcoverAI data directory

 # Hyperparameter options
 model_options = ["unet"]
-encoder_options = ["resnet50"]
-lr_options = [1e-4]
-loss_options = ["ce"]
-weight_init_options = ["imagenet"]
-seeds = list(range(15))
+encoder_options = ["resnet18", "resnet50"]
+lr_options = [1e-2, 1e-3, 1e-4]
+loss_options = ["ce", "jaccard"]
+weight_init_options = ["null", "imagenet"]


 def do_work(work: "Queue[str]", gpu_idx: int) -> bool:
@@ -36,18 +35,13 @@ def do_work(work: "Queue[str]", gpu_idx: int) -> bool:
 if __name__ == "__main__":
     work: "Queue[str]" = Queue()

-    for (model, encoder, lr, loss, weight_init, seed) in itertools.product(
-        model_options,
-        encoder_options,
-        lr_options,
-        loss_options,
-        weight_init_options,
-        seeds,
+    for (model, encoder, lr, loss, weight_init) in itertools.product(
+        model_options, encoder_options, lr_options, loss_options, weight_init_options
     ):

-        experiment_name = f"{model}_{encoder}_{lr}_{loss}_{weight_init}_{seed}"
+        experiment_name = f"{model}_{encoder}_{lr}_{loss}_{weight_init}"

-        output_dir = os.path.join("output", "landcoverai_seed_experiments")
+        output_dir = os.path.join("output", "landcoverai_experiments")
         log_dir = os.path.join(output_dir, "logs")
         config_file = os.path.join("conf", "landcoverai.yaml")
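For scale, this change trades the old 15-seed sweep of one configuration for a full grid over the option lists; a quick size check (a sketch mirroring the lists above):

import itertools

options = [
    ["unet"],  # model_options
    ["resnet18", "resnet50"],  # encoder_options
    [1e-2, 1e-3, 1e-4],  # lr_options
    ["ce", "jaccard"],  # loss_options
    ["null", "imagenet"],  # weight_init_options
]
n = len(list(itertools.product(*options)))
assert n == 1 * 2 * 3 * 2 * 2 == 24  # versus 15 seeded runs before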
@@ -63,7 +57,6 @@ if __name__ == "__main__":
                 + f" experiment.module.encoder_name={encoder}"
                 + f" experiment.module.encoder_weights={weight_init}"
                 + f" program.output_dir={output_dir}"
-                + f" program.seed={seed}"
                 + f" program.log_dir={log_dir}"
                 + f" program.data_dir={DATA_DIR}"
                 + " trainer.gpus=[GPU]"
@@ -0,0 +1,82 @@
#!/usr/bin/env python3
# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT License.

"""Runs the train script with a grid of hyperparameters."""
import itertools
import os
import subprocess
from multiprocessing import Process, Queue

# list of GPU IDs that we want to use, one job will be started for every ID in the list
GPUS = [0]
DRY_RUN = False  # if True then only print out the commands to be run, if False run them
DATA_DIR = ""  # path to the RESISC45 data directory

# Hyperparameter options
model_options = ["resnet18", "resnet50"]
lr_options = [1e-2, 1e-3, 1e-4]
loss_options = ["ce"]
weight_options = ["imagenet_only", "random"]


def do_work(work: "Queue[str]", gpu_idx: int) -> bool:
    """Process for each ID in GPUS."""
    while not work.empty():
        experiment = work.get()
        experiment = experiment.replace("GPU", str(gpu_idx))
        print(experiment)
        if not DRY_RUN:
            subprocess.call(experiment.split(" "))
    return True


if __name__ == "__main__":
    work: "Queue[str]" = Queue()

    for (model, lr, loss, weights) in itertools.product(
        model_options,
        lr_options,
        loss_options,
        weight_options,
    ):

        experiment_name = f"{model}_{lr}_{loss}_{weights.replace('_','-')}"

        output_dir = os.path.join("output", "resisc45_experiments")
        log_dir = os.path.join(output_dir, "logs")
        config_file = os.path.join("conf", "resisc45.yaml")

        if not os.path.exists(os.path.join(output_dir, experiment_name)):

            command = (
                "python train.py"
                + f" config_file={config_file}"
                + f" experiment.name={experiment_name}"
                + f" experiment.module.classification_model={model}"
                + f" experiment.module.learning_rate={lr}"
                + f" experiment.module.loss={loss}"
                + f" experiment.module.weights={weights}"
                + f" experiment.datamodule.weights={weights}"
                + f" program.output_dir={output_dir}"
                + f" program.log_dir={log_dir}"
                + f" program.data_dir={DATA_DIR}"
                + " trainer.gpus=[GPU]"
            )
            command = command.strip()

            work.put(command)

    processes = []
    for gpu_idx in GPUS:
        p = Process(
            target=do_work,
            args=(
                work,
                gpu_idx,
            ),
        )
        processes.append(p)
        p.start()
    for p in processes:
        p.join()
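The experiment_name f-string doubles as the output folder name that the os.path.exists check uses to skip already-finished runs. One concrete value it produces (illustrative):

model, lr, loss, weights = "resnet18", 1e-3, "ce", "imagenet_only"
name = f"{model}_{lr}_{loss}_{weights.replace('_','-')}"
print(name)  # resnet18_0.001_ce_imagenet-only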
@@ -16,11 +16,13 @@ import pytest
 import torch
 from _pytest.monkeypatch import MonkeyPatch
 from rasterio.crs import CRS
+from torch.utils.data import TensorDataset

 import torchgeo.datasets.utils
 from torchgeo.datasets.utils import (
     BoundingBox,
     collate_dict,
+    dataset_split,
     disambiguate_timestamp,
     download_and_extract_archive,
     download_radiant_mlhub_collection,
@@ -335,3 +337,21 @@ def test_nonexisting_directory(tmp_path: Path) -> None:
     with working_dir(str(subdir), create=True):
         assert subdir.cwd() == subdir
+
+
+def test_dataset_split() -> None:
+    num_samples = 24
+    x = torch.ones(num_samples, 5)  # type: ignore[attr-defined]
+    y = torch.randint(low=0, high=2, size=(num_samples,))  # type: ignore[attr-defined]
+    ds = TensorDataset(x, y)
+
+    # Test only train/val set split
+    train_ds, val_ds = dataset_split(ds, val_pct=1 / 2)
+    assert len(train_ds) == num_samples // 2
+    assert len(val_ds) == num_samples // 2
+
+    # Test train/val/test set split
+    train_ds, val_ds, test_ds = dataset_split(ds, val_pct=1 / 3, test_pct=1 / 3)
+    assert len(train_ds) == num_samples // 3
+    assert len(val_ds) == num_samples // 3
+    assert len(test_ds) == num_samples // 3
@@ -0,0 +1,39 @@
# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT License.

from typing import Any, Dict, cast

import pytest
import torch.nn as nn
import torchvision
from omegaconf import OmegaConf

from torchgeo.trainers import RESISC45ClassificationTask


class TestRESISC45Trainer:
    @pytest.fixture
    def default_config(self) -> Dict[str, Any]:
        task_conf = OmegaConf.load("conf/task_defaults/resisc45.yaml")
        task_args = OmegaConf.to_object(task_conf.experiment.module)
        task_args = cast(Dict[str, Any], task_args)
        return task_args

    def test_resnet_ce(self, default_config: Dict[str, Any]) -> None:
        default_config["classification_model"] = "resnet18"
        default_config["loss"] = "ce"
        task = RESISC45ClassificationTask(**default_config)
        assert isinstance(task.model, torchvision.models.ResNet)
        assert isinstance(task.loss, nn.CrossEntropyLoss)  # type: ignore[attr-defined]

    def test_invalid_model(self, default_config: Dict[str, Any]) -> None:
        default_config["classification_model"] = "invalid_model"
        error_message = "Model type 'invalid_model' is not valid."
        with pytest.raises(ValueError, match=error_message):
            RESISC45ClassificationTask(**default_config)

    def test_invalid_loss(self, default_config: Dict[str, Any]) -> None:
        default_config["loss"] = "invalid_loss"
        error_message = "Loss type 'invalid_loss' is not valid."
        with pytest.raises(ValueError, match=error_message):
            RESISC45ClassificationTask(**default_config)
@@ -18,6 +18,7 @@ import numpy as np
 import rasterio
 import torch
 from torch import Tensor
+from torch.utils.data import Dataset, Subset, random_split
 from torchvision.datasets.utils import check_integrity, download_url

 __all__ = (

@@ -30,6 +31,7 @@ __all__ = (
     "working_dir",
     "collate_dict",
     "rasterio_loader",
+    "dataset_split",
 )
@@ -394,3 +396,28 @@ def rasterio_loader(path: str) -> np.ndarray:  # type: ignore[type-arg]
     # VisionClassificationDataset expects images returned with channels last (HWC)
     array = array.transpose(1, 2, 0)
     return array
+
+
+def dataset_split(
+    dataset: Dataset[Any], val_pct: float, test_pct: Optional[float] = None
+) -> List[Subset[Any]]:
+    """Split a torch Dataset into train/val/test sets.
+
+    If ``test_pct`` is not set then only train and validation splits are returned.
+
+    Args:
+        dataset: dataset to be split into train/val or train/val/test subsets
+        val_pct: percentage of samples to be in validation set
+        test_pct: (Optional) percentage of samples to be in test set
+
+    Returns:
+        a list of the subset datasets. Either [train, val] or [train, val, test]
+    """
+    if test_pct is None:
+        val_length = int(len(dataset) * val_pct)  # type: ignore[arg-type]
+        train_length = len(dataset) - val_length  # type: ignore[arg-type]
+        return random_split(dataset, [train_length, val_length])
+    else:
+        val_length = int(len(dataset) * val_pct)  # type: ignore[arg-type]
+        test_length = int(len(dataset) * test_pct)  # type: ignore[arg-type]
+        train_length = len(dataset) - (val_length + test_length)  # type: ignore[arg-type]  # noqa: E501
+        return random_split(dataset, [train_length, val_length, test_length])
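Note that int() truncates, so the lengths always sum to len(dataset), with the train split absorbing any remainder. A worked example with 25 samples and val_pct = test_pct = 1/3:

n = 25
val_length = int(n * (1 / 3))  # 8
test_length = int(n * (1 / 3))  # 8
train_length = n - (val_length + test_length)  # 9
assert train_length + val_length + test_length == n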
@@ -8,6 +8,7 @@ from .chesapeake import ChesapeakeCVPRDataModule, ChesapeakeCVPRSegmentationTask
 from .cyclone import CycloneDataModule, CycloneSimpleRegressionTask
 from .landcoverai import LandcoverAIDataModule, LandcoverAISegmentationTask
 from .naipchesapeake import NAIPChesapeakeDataModule, NAIPChesapeakeSegmentationTask
+from .resisc45 import RESISC45ClassificationTask, RESISC45DataModule
 from .sen12ms import SEN12MSDataModule, SEN12MSSegmentationTask
 from .so2sat import So2SatClassificationTask, So2SatDataModule

@@ -21,6 +22,8 @@ __all__ = (
     "LandcoverAISegmentationTask",
     "NAIPChesapeakeDataModule",
     "NAIPChesapeakeSegmentationTask",
+    "RESISC45ClassificationTask",
+    "RESISC45DataModule",
     "SEN12MSDataModule",
     "SEN12MSSegmentationTask",
     "So2SatDataModule",
@@ -0,0 +1,341 @@
# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT License.

"""RESISC45 trainer."""

from typing import Any, Dict, Optional, cast

import kornia.augmentation as K
import pytorch_lightning as pl
import torch
import torch.nn as nn
import torchvision.models
from torch import Tensor
from torch.nn.modules import Conv2d, Linear, Module
from torch.optim.lr_scheduler import ReduceLROnPlateau
from torch.utils.data import DataLoader
from torchmetrics import Accuracy, IoU, MetricCollection
from torchvision.transforms import Compose, Normalize

from ..datasets import RESISC45
from ..datasets.utils import dataset_split

# https://github.com/pytorch/pytorch/issues/60979
# https://github.com/pytorch/pytorch/pull/61045
DataLoader.__module__ = "torch.utils.data"
Module.__module__ = "torch.nn"
Conv2d.__module__ = "nn.Conv2d"
Linear.__module__ = "nn.Linear"

IN_CHANNELS = 3
NUM_CLASSES = 45


class RESISC45ClassificationTask(pl.LightningModule):
    """LightningModule for training models on the RESISC45 Dataset."""

    def config_task(self) -> None:
        """Configures the task based on kwargs parameters passed to the constructor."""
        pretrained = "imagenet" in self.hparams["weights"]

        if "resnet" in self.hparams["classification_model"]:
            self.model = getattr(
                torchvision.models.resnet, self.hparams["classification_model"]
            )(pretrained=pretrained)
            in_features = self.model.fc.in_features
            self.model.fc = Linear(in_features, out_features=NUM_CLASSES)
        else:
            raise ValueError(
                f"Model type '{self.hparams['classification_model']}' is not valid."
            )

        if "resnet" in self.hparams["classification_model"]:
            if self.hparams["weights"] in ["imagenet_only", "random"]:
                pass
            else:
                raise ValueError(
                    f"Weight type '{self.hparams['weights']}' is not valid."
                )
        else:
            pass  # stub for initializing the weights of other models

        if self.hparams["loss"] == "ce":
            self.loss = nn.CrossEntropyLoss()  # type: ignore[attr-defined]
        else:
            raise ValueError(f"Loss type '{self.hparams['loss']}' is not valid.")

    def __init__(
        self,
        **kwargs: Any,
    ) -> None:
        """Initialize the LightningModule with a model and loss function.

        Keyword Args:
            classification_model: Name of the classification model to use
            loss: Name of the loss function
            weights: Either "imagenet_only" or "random" (the only values
                accepted by ``config_task``)
            learning_rate: Initial learning rate used by the optimizer
            learning_rate_schedule_patience: Patience for the LR scheduler
        """
        super().__init__()
        self.save_hyperparameters()  # creates `self.hparams` from kwargs

        self.config_task()

        self.train_metrics = MetricCollection(
            {
                "OverallAccuracy": Accuracy(num_classes=NUM_CLASSES, average="micro"),
                "AverageAccuracy": Accuracy(num_classes=NUM_CLASSES, average="macro"),
                "IoU": IoU(num_classes=NUM_CLASSES),
            },
            prefix="train_",
        )
        self.val_metrics = self.train_metrics.clone(prefix="val_")
        self.test_metrics = self.train_metrics.clone(prefix="test_")

    def forward(self, x: Tensor) -> Any:  # type: ignore[override]
        """Forward pass of the model."""
        return self.model(x)

    def training_step(  # type: ignore[override]
        self, batch: Dict[str, Any], batch_idx: int
    ) -> Tensor:
        """Training step - reports average accuracy and average IoU.

        Args:
            batch: Current batch
            batch_idx: Index of current batch

        Returns:
            training loss
        """
        x = batch["image"]
        y = batch["label"]
        y_hat = self.forward(x)
        y_hat_hard = y_hat.argmax(dim=1)

        loss = self.loss(y_hat, y)

        # by default, the train step logs every `log_every_n_steps` steps where
        # `log_every_n_steps` is a parameter to the `Trainer` object
        self.log("train_loss", loss, on_step=True, on_epoch=False)
        self.train_metrics(y_hat_hard, y)

        return cast(Tensor, loss)

    def training_epoch_end(self, outputs: Any) -> None:
        """Logs epoch-level training metrics.

        Args:
            outputs: list of items returned by training_step
        """
        self.log_dict(self.train_metrics.compute())
        self.train_metrics.reset()

    def validation_step(  # type: ignore[override]
        self, batch: Dict[str, Any], batch_idx: int
    ) -> None:
        """Validation step - reports average accuracy and average IoU.

        Args:
            batch: Current batch
            batch_idx: Index of current batch
        """
        x = batch["image"]
        y = batch["label"]
        y_hat = self.forward(x)
        y_hat_hard = y_hat.argmax(dim=1)

        loss = self.loss(y_hat, y)

        self.log("val_loss", loss, on_step=False, on_epoch=True)
        self.val_metrics(y_hat_hard, y)

    def validation_epoch_end(self, outputs: Any) -> None:
        """Logs epoch-level validation metrics.

        Args:
            outputs: list of items returned by validation_step
        """
        self.log_dict(self.val_metrics.compute())
        self.val_metrics.reset()

    def test_step(  # type: ignore[override]
        self, batch: Dict[str, Any], batch_idx: int
    ) -> None:
        """Test step identical to the validation step.

        Args:
            batch: Current batch
            batch_idx: Index of current batch
        """
        x = batch["image"]
        y = batch["label"]
        y_hat = self.forward(x)
        y_hat_hard = y_hat.argmax(dim=1)

        loss = self.loss(y_hat, y)

        # by default, the test and validation steps only log per *epoch*
        self.log("test_loss", loss, on_step=False, on_epoch=True)
        self.test_metrics(y_hat_hard, y)

    def test_epoch_end(self, outputs: Any) -> None:
        """Logs epoch-level test metrics.

        Args:
            outputs: list of items returned by test_step
        """
        self.log_dict(self.test_metrics.compute())
        self.test_metrics.reset()

    def configure_optimizers(self) -> Dict[str, Any]:
        """Initialize the optimizer and learning rate scheduler.

        Returns:
            a "lr dict" according to the pytorch lightning documentation --
            https://pytorch-lightning.readthedocs.io/en/latest/common/lightning_module.html#configure-optimizers
        """
        optimizer = torch.optim.Adam(
            self.model.parameters(),
            lr=self.hparams["learning_rate"],
        )
        return {
            "optimizer": optimizer,
            "lr_scheduler": {
                "scheduler": ReduceLROnPlateau(
                    optimizer,
                    patience=self.hparams["learning_rate_schedule_patience"],
                ),
                "monitor": "val_loss",
            },
        }


class RESISC45DataModule(pl.LightningDataModule):
    """LightningDataModule implementation for the RESISC45 dataset.

    Uses the train/val/test splits from the dataset.
    """

    band_means = torch.tensor(  # type: ignore[attr-defined]
        [0.36801773, 0.38097873, 0.343583]
    )

    band_stds = torch.tensor(  # type: ignore[attr-defined]
        [0.14540215, 0.13558227, 0.13203649]
    )

    def __init__(
        self,
        root_dir: str,
        batch_size: int = 64,
        num_workers: int = 4,
        weights: str = "random",
        unsupervised_mode: bool = False,
        **kwargs: Any,
    ) -> None:
        """Initialize a LightningDataModule for RESISC45 based DataLoaders.

        Args:
            root_dir: The ``root`` argument to pass to the RESISC45 Dataset classes
            batch_size: The batch size to use in all created DataLoaders
            num_workers: The number of workers to use in all created DataLoaders
            weights: Either "imagenet_only" or "random"
            unsupervised_mode: Makes the train dataloader return imagery from the
                train, val, and test sets
        """
        super().__init__()  # type: ignore[no-untyped-call]
        self.root_dir = root_dir
        self.batch_size = batch_size
        self.num_workers = num_workers
        self.weights = weights
        self.unsupervised_mode = unsupervised_mode

        self.val_split_pct = kwargs["val_split_pct"]
        self.test_split_pct = kwargs["test_split_pct"]

        self.norm = Normalize(self.band_means, self.band_stds)
        self.transforms = K.AugmentationSequential(
            K.RandomAffine(degrees=30),
            K.RandomHorizontalFlip(),
            K.RandomVerticalFlip(),
            data_keys=["input"],
        )

    def preprocess(self, sample: Dict[str, Any]) -> Dict[str, Any]:
        """Transform a single sample from the Dataset."""
        sample["image"] = sample["image"].float()
        sample["image"] /= 255.0
        sample["image"] = self.norm(sample["image"])
        return sample

    def kornia_pipeline(self, sample: Dict[str, Any]) -> Dict[str, Any]:
        """Transform a single sample from the Dataset with Kornia."""
        sample["image"] = self.transforms(sample["image"]).squeeze()
        return sample

    def prepare_data(self) -> None:
        """Make sure that the dataset is downloaded.

        This method is only called once per run.
        """
        RESISC45(self.root_dir, checksum=False)

    def setup(self, stage: Optional[str] = None) -> None:
        """Initialize the main ``Dataset`` objects.

        This method is called once per GPU per run.
        """
        transforms = Compose([self.preprocess])

        if not self.unsupervised_mode:
            dataset = RESISC45(
                self.root_dir,
                transforms=transforms,
            )
            self.train_dataset, self.val_dataset, self.test_dataset = dataset_split(
                dataset, val_pct=self.val_split_pct, test_pct=self.test_split_pct
            )
        else:
            self.train_dataset = RESISC45(
                self.root_dir,
                transforms=transforms,
            )
            self.val_dataset, self.test_dataset = None, None  # type: ignore[assignment]

    def train_dataloader(self) -> DataLoader[Any]:
        """Return a DataLoader for training."""
        return DataLoader(
            self.train_dataset,
            batch_size=self.batch_size,
            num_workers=self.num_workers,
            shuffle=True,
        )

    def val_dataloader(self) -> DataLoader[Any]:
        """Return a DataLoader for validation."""
        if self.val_dataset is None or len(self.val_dataset) == 0:
            return self.train_dataloader()
        else:
            return DataLoader(
                self.val_dataset,
                batch_size=self.batch_size,
                num_workers=self.num_workers,
                shuffle=False,
            )

    def test_dataloader(self) -> DataLoader[Any]:
        """Return a DataLoader for testing."""
        if self.test_dataset is None or len(self.test_dataset) == 0:
            return self.train_dataloader()
        else:
            return DataLoader(
                self.test_dataset,
                batch_size=self.batch_size,
                num_workers=self.num_workers,
                shuffle=False,
            )
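The val/test dataloader fallback above implements the behavior described in the commit message: an empty split returns the train loader so the default Trainer loop does not break. A minimal sketch of triggering it (assumes the RESISC45 data is already on disk; the path is hypothetical):

dm = RESISC45DataModule(
    root_dir="data/resisc45",  # hypothetical path
    batch_size=128,
    num_workers=6,
    val_split_pct=0.0,  # empty validation split ...
    test_split_pct=0.2,
)
dm.prepare_data()
dm.setup()
loader = dm.val_dataloader()  # ... so this falls back to the train DataLoader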
train.py
@@ -23,6 +23,8 @@ from torchgeo.trainers import (
     LandcoverAISegmentationTask,
     NAIPChesapeakeDataModule,
     NAIPChesapeakeSegmentationTask,
+    RESISC45ClassificationTask,
+    RESISC45DataModule,
     SEN12MSDataModule,
     SEN12MSSegmentationTask,
     So2SatClassificationTask,

@@ -37,6 +39,7 @@ TASK_TO_MODULES_MAPPING: Dict[
     "cyclone": (CycloneSimpleRegressionTask, CycloneDataModule),
     "landcoverai": (LandcoverAISegmentationTask, LandcoverAIDataModule),
     "naipchesapeake": (NAIPChesapeakeSegmentationTask, NAIPChesapeakeDataModule),
+    "resisc45": (RESISC45ClassificationTask, RESISC45DataModule),
     "sen12ms": (SEN12MSSegmentationTask, SEN12MSDataModule),
     "so2sat": (So2SatClassificationTask, So2SatDataModule),
 }
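A sketch of how the new mapping entry gets exercised; this is illustrative, not the literal train.py body. The task name from the config selects the (task, datamodule) pair, which train.py instantiates from the experiment.module and experiment.datamodule sections:

task_class, datamodule_class = TASK_TO_MODULES_MAPPING["resisc45"]
task = task_class(
    classification_model="resnet18",
    loss="ce",
    learning_rate=1e-3,
    learning_rate_schedule_patience=6,
    weights="random",
)
datamodule = datamodule_class(
    root_dir="data/resisc45",  # hypothetical path
    batch_size=128,
    num_workers=6,
    val_split_pct=0.2,
    test_split_pct=0.2,
)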