Mirror of https://github.com/microsoft/torchgeo.git
Refactoring how the trainer modules are called from train.py and how the configuration files are structured
This commit is contained in:
Parent
8e19755d24
Commit
ccb7f1912d
@@ -1,21 +1,27 @@
-config_file: null  # The user can pass a filename here on the command line
+config_file: null  # This lets the user pass a config filename to load other arguments from
 
-program:  # These are the default arguments
-  batch_size: 32
-  num_workers: 4
+program:  # These are the arguments that define how the train.py script works
   seed: 1337
-  experiment_name: ???  # This is OmegaConf syntax that makes this a required field
   output_dir: output
   data_dir: data
   log_dir: logs
   overwrite: False
 
-task:
-  name: ???  # this must be defined so we can get the task specific arguments
+experiment:  # These are arguments specific to the experiment we are running
+  name: ???  # this is the name given to this experiment run
+  task: ???  # this is the type of task to use for this experiment (e.g. "landcoverai")
+  module:  # these will be passed as kwargs to the LightningModule associated with the task
+    learning_rate: 1e-3
+  datamodule:  # these will be passed as kwargs to the LightningDataModule associated with the task
+    root_dir: ${program.data_dir}
+    seed: ${program.seed}
+    batch_size: 32
+    num_workers: 4
 
-# Taken from https://pytorch-lightning.readthedocs.io/en/1.3.8/common/trainer.html#init
+# The values here are taken from the defaults here https://pytorch-lightning.readthedocs.io/en/1.3.8/common/trainer.html#init
+# this probably should be made into a schema, e.g. as shown https://omegaconf.readthedocs.io/en/2.0_branch/structured_config.html#merging-with-other-configs
-trainer:
+trainer:  # These are the parameters passed to the pytorch lightning Trainer object
   logger: True
   checkpoint_callback: True
   callbacks: null
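The new experiment section relies on two OmegaConf features: ${...} interpolation and ??? mandatory values. A minimal sketch of both behaviors, assuming omegaconf is installed and that the defaults file above lives at conf/defaults.yaml (an assumed path for this sketch):

from omegaconf import OmegaConf
from omegaconf.errors import MissingMandatoryValue

conf = OmegaConf.load("conf/defaults.yaml")  # assumed location of the file above
conf.program.data_dir = "my_data"
# ${program.data_dir} resolves lazily on access:
print(conf.experiment.datamodule.root_dir)  # my_data

try:
    _ = conf.experiment.name  # declared as ???, so the user must supply it
except MissingMandatoryValue:
    print("experiment.name is required")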
@@ -0,0 +1,13 @@
+experiment:
+  task: "chesapeake_cvpr"
+  module:
+    loss: "ce"
+    segmentation_model: "unet"
+    encoder_name: "resnet18"
+    encoder_weights: "imagenet"
+    encoder_output_stride: 16
+    learning_rate: 1e-3
+    learning_rate_schedule_patience: 2
+  datamodule:
+    batch_size: 32
+    num_workers: 4
@@ -1,5 +1,9 @@
-task:
-  name: "cyclone"
-  learning_rate: 1e-3
-  learning_rate_schedule_patience: 2
-  model: "resnet18"
+experiment:
+  task: "cyclone"
+  module:
+    model: "resnet18"
+    learning_rate: 1e-3
+    learning_rate_schedule_patience: 2
+  datamodule:
+    batch_size: 32
+    num_workers: 4
@@ -1,10 +1,14 @@
-task:
-  name: "landcoverai"
-  optimizer: "adamw"
-  learning_rate: 1e-3
-  learning_rate_schedule_patience: 2
-  loss: "ce"
-  segmentation_model: "deeplabv3+"
-  encoder_name: "resnet34"
-  encoder_weights: "imagenet"
-  encoder_output_stride: 16
+experiment:
+  task: "landcoverai"
+  module:
+    loss: "ce"
+    segmentation_model: "deeplabv3+"
+    encoder_name: "resnet34"
+    encoder_weights: "imagenet"
+    encoder_output_stride: 16
+    optimizer: "adamw"
+    learning_rate: 1e-3
+    learning_rate_schedule_patience: 2
+  datamodule:
+    batch_size: 32
+    num_workers: 4
@@ -1,10 +1,14 @@
-task:
-  name: "naipchesapeake"
-  optimizer: "adamw"
-  learning_rate: 1e-3
-  learning_rate_schedule_patience: 2
-  loss: "ce"
-  segmentation_model: "deeplabv3+"
-  encoder_name: "resnet34"
-  encoder_weights: "imagenet"
-  encoder_output_stride: 16
+experiment:
+  task: "naipchesapeake"
+  module:
+    loss: "ce"
+    segmentation_model: "deeplabv3+"
+    encoder_name: "resnet34"
+    encoder_weights: "imagenet"
+    encoder_output_stride: 16
+    optimizer: "adamw"
+    learning_rate: 1e-3
+    learning_rate_schedule_patience: 2
+  datamodule:
+    batch_size: 32
+    num_workers: 4
@@ -1,8 +1,13 @@
-task:
-  name: "sen12ms"
-  learning_rate: 1e-3
-  learning_rate_schedule_patience: 2
-  loss: "ce"
-  segmentation_model: "unet"
-  encoder_name: "resnet18"
-  encoder_weights: "imagenet"
+experiment:
+  task: "sen12ms"
+  module:
+    loss: "ce"
+    segmentation_model: "unet"
+    encoder_name: "resnet18"
+    encoder_weights: "imagenet"
+    encoder_output_stride: 16
+    learning_rate: 1e-3
+    learning_rate_schedule_patience: 2
+  datamodule:
+    batch_size: 32
+    num_workers: 4
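At runtime these per-task defaults are merged underneath the user's own config, so user-supplied values win on conflict. A small sketch of the precedence, using the sen12ms defaults above (assuming only omegaconf):

from omegaconf import OmegaConf

task_conf = OmegaConf.load("conf/task_defaults/sen12ms.yaml")
user_conf = OmegaConf.create({"experiment": {"module": {"learning_rate": 1e-4}}})
merged = OmegaConf.merge(task_conf, user_conf)  # later arguments take precedence
print(merged.experiment.module.learning_rate)       # 0.0001, the user override
print(merged.experiment.module.segmentation_model)  # unet, from the task defaults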
@@ -25,9 +25,9 @@ def test_output_file(tmp_path: Path) -> None:
     args = [
         sys.executable,
         "train.py",
-        "program.experiment_name=test",
+        "experiment.name=test",
         "program.output_dir=" + str(output_file),
-        "task.name=test",
+        "experiment.task=test",
     ]
     ps = subprocess.run(args, stdout=subprocess.PIPE, stderr=subprocess.PIPE)
     assert ps.returncode != 0
@@ -43,9 +43,9 @@ def test_experiment_dir_not_empty(tmp_path: Path) -> None:
     args = [
         sys.executable,
         "train.py",
-        "program.experiment_name=test",
+        "experiment.name=test",
         "program.output_dir=" + str(output_dir),
-        "task.name=test",
+        "experiment.task=test",
     ]
     ps = subprocess.run(args, stdout=subprocess.PIPE, stderr=subprocess.PIPE)
     assert ps.returncode != 0
@@ -64,11 +64,11 @@ def test_overwrite_experiment_dir(tmp_path: Path) -> None:
     args = [
         sys.executable,
         "train.py",
-        "program.experiment_name=test",
+        "experiment.name=test",
         "program.output_dir=" + str(output_dir),
         "program.data_dir=" + data_dir,
         "program.log_dir=" + str(log_dir),
-        "task.name=cyclone",
+        "experiment.task=cyclone",
         "program.overwrite=True",
         "trainer.fast_dev_run=1",
     ]
@@ -87,9 +87,9 @@ def test_invalid_task(task: str, tmp_path: Path) -> None:
     args = [
         sys.executable,
         "train.py",
-        "program.experiment_name=test",
+        "experiment.name=test",
         "program.output_dir=" + str(output_dir),
-        "task.name=" + task,
+        "experiment.task=" + task,
     ]
     ps = subprocess.run(args, stdout=subprocess.PIPE, stderr=subprocess.PIPE)
     assert ps.returncode != 0
@@ -102,9 +102,9 @@ def test_missing_config_file(tmp_path: Path) -> None:
     args = [
         sys.executable,
         "train.py",
-        "program.experiment_name=test",
+        "experiment.name=test",
         "program.output_dir=" + str(output_dir),
-        "task.name=test",
+        "experiment.task=test",
         "config_file=" + str(config_file),
     ]
     ps = subprocess.run(args, stdout=subprocess.PIPE, stderr=subprocess.PIPE)
@@ -120,12 +120,12 @@ def test_config_file(tmp_path: Path) -> None:
     config_file.write_text(
         f"""
 program:
-  experiment_name: test
   output_dir: {output_dir}
   data_dir: {data_dir}
   log_dir: {log_dir}
-task:
-  name: cyclone
+experiment:
+  name: test
+  task: cyclone
 trainer:
   fast_dev_run: true
 """
@@ -146,12 +146,12 @@ def test_tasks(task: str, tmp_path: Path) -> None:
     args = [
         sys.executable,
         "train.py",
-        "program.experiment_name=test",
+        "experiment.name=test",
         "program.output_dir=" + str(output_dir),
         "program.data_dir=" + data_dir,
         "program.log_dir=" + str(log_dir),
         "trainer.fast_dev_run=1",
-        "task.name=" + task,
+        "experiment.task=" + task,
         "program.overwrite=True",
     ]
     subprocess.run(args, check=True)
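The key=value arguments these tests pass to train.py are parsed by OmegaConf as dotted overrides. A sketch of that parsing in isolation (train.py itself may read them through its own CLI helper; this only shows the dotlist form):

from omegaconf import OmegaConf

cli_conf = OmegaConf.from_dotlist(["experiment.name=test", "experiment.task=cyclone"])
print(cli_conf.experiment.task)  # cyclone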
@@ -3,12 +3,15 @@
 
 """TorchGeo trainers."""
 
+from .chesapeake import ChesapeakeCVPRDataModule, ChesapeakeCVPRSegmentationTask
 from .cyclone import CycloneDataModule, CycloneSimpleRegressionTask
 from .landcoverai import LandcoverAIDataModule, LandcoverAISegmentationTask
 from .naipchesapeake import NAIPChesapeakeDataModule, NAIPChesapeakeSegmentationTask
 from .sen12ms import SEN12MSDataModule, SEN12MSSegmentationTask
 
 __all__ = (
+    "ChesapeakeCVPRSegmentationTask",
+    "ChesapeakeCVPRDataModule",
     "CycloneDataModule",
     "CycloneSimpleRegressionTask",
     "LandcoverAIDataModule",
@@ -0,0 +1,292 @@
+# Copyright (c) Microsoft Corporation. All rights reserved.
+# Licensed under the MIT License.
+
+"""Trainers for the Chesapeake datasets."""
+
+from typing import Any, Dict, Optional, cast
+
+import segmentation_models_pytorch as smp
+import torch
+import torch.nn as nn
+from pytorch_lightning.core.datamodule import LightningDataModule
+from pytorch_lightning.core.lightning import LightningModule
+from torch import Tensor
+from torch.nn.modules import Module
+from torch.optim.lr_scheduler import ReduceLROnPlateau
+from torch.utils.data import DataLoader, Subset
+from torchmetrics import Accuracy
+
+from ..datasets import ChesapeakeCVPR
+
+# https://github.com/pytorch/pytorch/issues/60979
+# https://github.com/pytorch/pytorch/pull/61045
+DataLoader.__module__ = "torch.utils.data"
+Module.__module__ = "torch.nn"
+
+
+class ChesapeakeCVPRSegmentationTask(LightningModule):
+    """LightningModule for training models on the Chesapeake CVPR Land Cover Dataset.
+
+    This allows using arbitrary models and losses from the
+    ``segmentation_models_pytorch`` package.
+    """
+
+    def config_task(self, kwargs: Dict[str, Any]) -> None:
+        """Configures the task based on kwargs parameters."""
+        if kwargs["segmentation_model"] == "unet":
+            self.model = smp.Unet(
+                encoder_name=kwargs["encoder_name"],
+                encoder_weights=kwargs["encoder_weights"],
+                in_channels=4,
+                classes=6,
+            )
+        else:
+            raise ValueError(
+                f"Model type '{kwargs['segmentation_model']}' is not valid."
+            )
+
+        if kwargs["loss"] == "ce":
+            self.loss = nn.CrossEntropyLoss()  # type: ignore[attr-defined]
+        elif kwargs["loss"] == "jaccard":
+            self.loss = smp.losses.JaccardLoss(mode="multiclass")
+        else:
+            raise ValueError(f"Loss type '{kwargs['loss']}' is not valid.")
+
+    def __init__(
+        self,
+        **kwargs: Any,
+    ) -> None:
+        """Initialize the LightningModule with a model and loss function.
+
+        Keyword Args:
+            segmentation_model: Name of the segmentation model type to use
+            encoder_name: Name of the encoder model backbone to use
+            encoder_weights: None or "imagenet" to use imagenet pretrained weights in
+                the encoder model
+            loss: Name of the loss function
+        """
+        super().__init__()
+        self.save_hyperparameters()  # creates `self.hparams` from kwargs
+
+        self.config_task(kwargs)
+
+        self.train_accuracy = Accuracy()
+        self.val_accuracy = Accuracy()
+        self.test_accuracy = Accuracy()
+
+    def forward(self, x: Tensor) -> Any:  # type: ignore[override]
+        """Forward pass of the model."""
+        return self.model(x)
+
+    def training_step(  # type: ignore[override]
+        self, batch: Dict[str, Any], batch_idx: int
+    ) -> Tensor:
+        """Training step - reports average accuracy and average IoU."""
+        x = batch["image"]
+        y = batch["mask"]
+        y_hat = self.forward(x)
+        y_hat_hard = y_hat.argmax(dim=1)
+
+        loss = self.loss(y_hat, y)
+
+        self.log("train_loss", loss)  # logging to TensorBoard
+        self.log("train_acc_step", self.train_accuracy(y_hat_hard, y))
+
+        return cast(Tensor, loss)
+
+    def training_epoch_end(self, outputs: Any) -> None:
+        """Logs epoch level training metrics."""
+        self.log("train_acc_epoch", self.train_accuracy.compute())
+
+    def validation_step(  # type: ignore[override]
+        self, batch: Dict[str, Any], batch_idx: int
+    ) -> None:
+        """Validation step - reports average accuracy and average IoU."""
+        x = batch["image"]
+        y = batch["mask"]
+        y_hat = self.forward(x)
+        y_hat_hard = y_hat.argmax(dim=1)
+
+        loss = self.loss(y_hat, y)
+
+        self.log("val_loss", loss)
+        self.log("val_acc_step", self.val_accuracy(y_hat_hard, y))
+
+    def validation_epoch_end(self, outputs: Any) -> None:
+        """Logs epoch level validation metrics."""
+        self.log("val_acc_epoch", self.val_accuracy.compute())
+
+    def test_step(  # type: ignore[override]
+        self, batch: Dict[str, Any], batch_idx: int
+    ) -> None:
+        """Test step identical to the validation step."""
+        x = batch["image"]
+        y = batch["mask"]
+        y_hat = self.forward(x)
+        y_hat_hard = y_hat.argmax(dim=1)
+
+        loss = self.loss(y_hat, y)
+        self.log("test_loss", loss)
+        self.log("test_acc_step", self.test_accuracy(y_hat_hard, y))
+
+    def test_epoch_end(self, outputs: Any) -> None:
+        """Logs epoch level test metrics."""
+        self.log("test_acc_epoch", self.test_accuracy.compute())
+
+    def configure_optimizers(self) -> Dict[str, Any]:
+        """Initialize the optimizer and learning rate scheduler."""
+        optimizer = torch.optim.Adam(
+            self.model.parameters(),
+            lr=self.hparams["learning_rate"],
+        )
+        return {
+            "optimizer": optimizer,
+            "lr_scheduler": {
+                "scheduler": ReduceLROnPlateau(
+                    optimizer,
+                    patience=self.hparams["learning_rate_schedule_patience"],
+                ),
+                "monitor": "val_loss",
+                "verbose": True,
+            },
+        }
+
+
+class ChesapeakeCVPRDataModule(LightningDataModule):
+    """LightningDataModule implementation for the CVPR Chesapeake Land Cover dataset.
+
+    Uses the random spatial split defined per state to partition tiles into train, val,
+    and test sets.
+    """
+
+    def __init__(
+        self,
+        root_dir: str,
+        seed: int,
+        band_set: str = "all",
+        batch_size: int = 64,
+        num_workers: int = 4,
+        **kwargs: Any,
+    ) -> None:
+        """Initialize a LightningDataModule for Chesapeake CVPR based DataLoaders.
+
+        Args:
+            root_dir: The ``root`` argument to pass to the ChesapeakeCVPR Dataset
+                classes
+            seed: The seed value to use when doing the sklearn based ShuffleSplit
+            band_set: The subset of S1/S2 bands to use. Options are: "all",
+                "s1", "s2-all", and "s2-reduced" where the "s2-reduced" set includes:
+                B2, B3, B4, B8, B11, and B12.
+            batch_size: The batch size to use in all created DataLoaders
+            num_workers: The number of workers to use in all created DataLoaders
+        """
+        super().__init__()  # type: ignore[no-untyped-call]
+
+        self.root_dir = root_dir
+        self.seed = seed
+        self.band_set = band_set
+        self.batch_size = batch_size
+        self.num_workers = num_workers
+
+    def custom_transform(self, sample: Dict[str, Any]) -> Dict[str, Any]:
+        """Transform a single sample from the Dataset."""
+        sample["image"] = sample["image"].float()
+
+        if self.band_set == "all":
+            sample["image"][:2] = sample["image"][:2].clip(-25, 0) / -25
+            sample["image"][2:] = sample["image"][2:].clip(0, 10000) / 10000
+        elif self.band_set == "s1":
+            sample["image"][:2] = sample["image"][:2].clip(-25, 0) / -25
+        else:
+            sample["image"][:] = sample["image"][:].clip(0, 10000) / 10000
+
+        sample["mask"] = sample["mask"][0, :, :].long()
+        sample["mask"] = torch.take(  # type: ignore[attr-defined]
+            self.DFC2020_CLASS_MAPPING, sample["mask"]
+        )
+
+        return sample
+
+    def prepare_data(self) -> None:
+        """Initialize the main ``Dataset`` objects for use in :func:`setup`.
+
+        This includes optionally downloading the dataset. This is done once per node,
+        while :func:`setup` is done once per GPU.
+        """
+        ChesapeakeCVPR(
+            self.root_dir,
+            split="train",
+            bands=self.band_indices,
+            transforms=self.custom_transform,
+            download=True,
+            checksum=False,
+        )
+
+        ChesapeakeCVPR(
+            self.root_dir,
+            split="test",
+            bands=self.band_indices,
+            transforms=self.custom_transform,
+            download=True,
+            checksum=False,
+        )
+
+    def setup(self, stage: Optional[str] = None) -> None:
+        """Create the train/val/test splits based on the original Dataset objects.
+
+        The splits should be done here vs. in :func:`__init__` per the docs:
+        https://pytorch-lightning.readthedocs.io/en/latest/extensions/datamodules.html#setup.
+        """
+        self.train_dataset = ChesapeakeCVPR(
+            self.root_dir,
+            split="train",
+            bands=self.band_indices,
+            transforms=self.custom_transform,
+            download=True,
+            checksum=False,
+        )
+        self.val_dataset = ChesapeakeCVPR(
+            self.root_dir,
+            split="train",
+            bands=self.band_indices,
+            transforms=self.custom_transform,
+            download=True,
+            checksum=False,
+        )
+        self.test_dataset = ChesapeakeCVPR(
+            self.root_dir,
+            split="train",
+            bands=self.band_indices,
+            transforms=self.custom_transform,
+            download=True,
+            checksum=False,
+        )
+
+    def train_dataloader(self) -> DataLoader[Any]:
+        """Return a DataLoader for training."""
+        return DataLoader(
+            self.train_dataset,
+            batch_size=self.batch_size,
+            num_workers=self.num_workers,
+            shuffle=True,
+            pin_memory=True,
+        )
+
+    def val_dataloader(self) -> DataLoader[Any]:
+        """Return a DataLoader for validation."""
+        return DataLoader(
+            self.val_dataset,
+            batch_size=self.batch_size,
+            num_workers=self.num_workers,
+            shuffle=False,
+            pin_memory=True,
+        )
+
+    def test_dataloader(self) -> DataLoader[Any]:
+        """Return a DataLoader for testing."""
+        return DataLoader(
+            self.test_dataset,
+            batch_size=self.batch_size,
+            num_workers=self.num_workers,
+            shuffle=False,
+            pin_memory=True,
+        )
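A hedged sketch of constructing the new task directly, with keyword arguments mirroring the experiment.module block of conf/task_defaults/chesapeake_cvpr.yaml (requires segmentation_models_pytorch and pytorch_lightning):

from torchgeo.trainers import ChesapeakeCVPRSegmentationTask

# save_hyperparameters() stores these kwargs in self.hparams; config_task()
# then builds the smp.Unet model and cross-entropy loss from them.
task = ChesapeakeCVPRSegmentationTask(
    segmentation_model="unet",
    encoder_name="resnet18",
    encoder_weights="imagenet",
    loss="ce",
    learning_rate=1e-3,
    learning_rate_schedule_patience=2,
)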
@@ -36,7 +36,7 @@ class CycloneSimpleRegressionTask(pl.LightningModule):
         else:
             raise ValueError(f"Model type '{kwargs['model']}' is not valid.")
 
-    def __init__(self, **kwargs: Dict[str, Any]) -> None:
+    def __init__(self, **kwargs: Any) -> None:
         """Initialize a new LightningModule for training simple regression models.
 
         Keyword Args:
@@ -131,6 +131,7 @@ class CycloneDataModule(pl.LightningDataModule):
         batch_size: int = 64,
         num_workers: int = 4,
         api_key: Optional[str] = None,
+        **kwargs: Any,
     ) -> None:
         """Initialize a LightningDataModule for NASA Cyclone based DataLoaders.
 
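The annotation change above follows PEP 484: in def f(**kwargs: T), T is the type of each keyword value, not of the kwargs mapping itself. A short illustration with hypothetical functions:

from typing import Any, Dict

def old_style(**kwargs: Dict[str, Any]) -> None:
    """mypy types every value as Dict[str, Any], so old_style(lr=1e-3) fails."""

def new_style(**kwargs: Any) -> None:
    """Every value is typed Any, so new_style(lr=1e-3) type-checks."""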
@@ -67,7 +67,7 @@ class LandcoverAISegmentationTask(pl.LightningModule):
 
     def __init__(
         self,
-        **kwargs: Dict[str, Any],
+        **kwargs: Any,
     ) -> None:
         """Initialize the LightningModule with a model and loss function.
 
@@ -237,6 +237,7 @@ class LandcoverAIDataModule(pl.LightningDataModule):
         root_dir: str,
         batch_size: int = 64,
         num_workers: int = 4,
+        **kwargs: Any,
     ) -> None:
         """Initialize a LightningDataModule for Landcover.AI based DataLoaders.
 
@@ -73,7 +73,7 @@ class NAIPChesapeakeSegmentationTask(pl.LightningModule):
 
     def __init__(
        self,
-        **kwargs: Dict[str, Any],
+        **kwargs: Any,
     ) -> None:
         """Initialize the LightningModule with a model and loss function.
 
@@ -249,6 +249,7 @@ class NAIPChesapeakeDataModule(pl.LightningDataModule):
         chesapeake_root_dir: str,
         batch_size: int = 64,
         num_workers: int = 4,
+        **kwargs: Any,
     ) -> None:
         """Initialize a LightningDataModule for NAIP and Chesapeake based DataLoaders.
 
@@ -54,7 +54,7 @@ class SEN12MSSegmentationTask(pl.LightningModule):
 
     def __init__(
         self,
-        **kwargs: Dict[str, Any],
+        **kwargs: Any,
     ) -> None:
         """Initialize the LightningModule with a model and loss function.
 
@@ -194,6 +194,7 @@ class SEN12MSDataModule(pl.LightningDataModule):
         band_set: str = "all",
         batch_size: int = 64,
         num_workers: int = 4,
+        **kwargs: Any,
     ) -> None:
         """Initialize a LightningDataModule for SEN12MS based DataLoaders.
 
train.py
@@ -6,16 +6,16 @@
 """torchgeo model training script."""
 
 import os
-from typing import Any, Dict, cast
+from typing import Any, Dict, Tuple, Type, cast
 
 import pytorch_lightning as pl
 from omegaconf import DictConfig, OmegaConf
 from pytorch_lightning import loggers as pl_loggers
 from pytorch_lightning.callbacks import EarlyStopping, ModelCheckpoint
-from pytorch_lightning.core.datamodule import LightningDataModule
-from pytorch_lightning.core.lightning import LightningModule
 
 from torchgeo.trainers import (
+    ChesapeakeCVPRDataModule,
+    ChesapeakeCVPRSegmentationTask,
     CycloneDataModule,
     CycloneSimpleRegressionTask,
     LandcoverAIDataModule,
@@ -26,6 +26,16 @@ from torchgeo.trainers import (
     SEN12MSSegmentationTask,
 )
 
+TASK_TO_MODULES_MAPPING: Dict[
+    str, Tuple[Type[pl.LightningModule], Type[pl.LightningDataModule]]
+] = {
+    "chesapeake_cvpr": (ChesapeakeCVPRSegmentationTask, ChesapeakeCVPRDataModule),
+    "cyclone": (CycloneSimpleRegressionTask, CycloneDataModule),
+    "landcoverai": (LandcoverAISegmentationTask, LandcoverAIDataModule),
+    "naipchesapeake": (NAIPChesapeakeSegmentationTask, NAIPChesapeakeDataModule),
+    "sen12ms": (SEN12MSSegmentationTask, SEN12MSDataModule),
+}
+
 
 def set_up_omegaconf() -> DictConfig:
     """Loads program arguments from either YAML config files or command line arguments.
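A minimal sketch of the dispatch this mapping enables, replacing the old per-task if/elif chain; the kwargs shown are illustrative values drawn from the cyclone config sections above:

task_name = "cyclone"  # normally conf.experiment.task

# Look up the (LightningModule, LightningDataModule) pair and instantiate both
# from the config-derived kwargs.
task_class, datamodule_class = TASK_TO_MODULES_MAPPING[task_name]
task = task_class(model="resnet18", learning_rate=1e-3, learning_rate_schedule_patience=2)
datamodule = datamodule_class(root_dir="data", seed=1337, batch_size=32, num_workers=4)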
@@ -65,17 +75,15 @@ def set_up_omegaconf() -> DictConfig:
 
     # These OmegaConf structured configs enforce a schema at runtime, see:
     # https://omegaconf.readthedocs.io/en/2.0_branch/structured_config.html#merging-with-other-configs
-    if conf.task.name == "cyclone":
-        task_conf = OmegaConf.load("conf/task_defaults/cyclone.yaml")
-    elif conf.task.name == "landcoverai":
-        task_conf = OmegaConf.load("conf/task_defaults/landcoverai.yaml")
-    elif conf.task.name == "sen12ms":
-        task_conf = OmegaConf.load("conf/task_defaults/sen12ms.yaml")
-    elif conf.task.name == "test":
+    task_name = conf.experiment.task
+    task_config_fn = os.path.join("conf", "task_defaults", f"{task_name}.yaml")
+    if task_name == "test":
         task_conf = OmegaConf.create()
+    elif os.path.exists(task_config_fn):
+        task_conf = cast(DictConfig, OmegaConf.load(task_config_fn))
     else:
         raise ValueError(
-            f"task.name={conf.task.name} is not recognized as a valid task"
+            f"experiment.task={task_name} is not recognized as a valid task"
         )
 
     conf = OmegaConf.merge(task_conf, conf)
@@ -90,11 +98,13 @@ def main(conf: DictConfig) -> None:
     # Setup output directory
     ######################################
 
+    experiment_name = conf.experiment.name
+    task_name = conf.experiment.task
     if os.path.isfile(conf.program.output_dir):
         raise NotADirectoryError("`program.output_dir` must be a directory")
     os.makedirs(conf.program.output_dir, exist_ok=True)
 
-    experiment_dir = os.path.join(conf.program.output_dir, conf.program.experiment_name)
+    experiment_dir = os.path.join(conf.program.output_dir, experiment_name)
     os.makedirs(experiment_dir, exist_ok=True)
 
     if len(os.listdir(experiment_dir)) > 0:
@@ -109,58 +119,33 @@ def main(conf: DictConfig) -> None:
             + "empty. We don't want to overwrite any existing results, exiting..."
         )
 
     with open(os.path.join(experiment_dir, "experiment_config.yaml"), "w") as f:
         OmegaConf.save(config=conf, f=f)
 
     ######################################
     # Choose task to run based on arguments or configuration
     ######################################
-    # Convert the DictConfig into a dictionary so that we can pass as kwargs. We use
-    # var() to convert the @dataclass from to_object() to a dictionary and to help mypy
-    task_args = OmegaConf.to_object(conf.task)
-    task_args = cast(Dict[str, Any], task_args)
+    # Convert the DictConfig into a dictionary so that we can pass as kwargs.
+    task_args = cast(Dict[str, Any], OmegaConf.to_object(conf.experiment.module))
+    datamodule_args = cast(
+        Dict[str, Any], OmegaConf.to_object(conf.experiment.datamodule)
+    )
 
-    datamodule: LightningDataModule
-    task: LightningModule
-    if conf.task.name == "cyclone":
-        datamodule = CycloneDataModule(
-            conf.program.data_dir,
-            seed=conf.program.seed,
-            batch_size=conf.program.batch_size,
-            num_workers=conf.program.num_workers,
-        )
-        task = CycloneSimpleRegressionTask(**task_args)
-    elif conf.task.name == "landcoverai":
-        datamodule = LandcoverAIDataModule(
-            conf.program.data_dir,
-            batch_size=conf.program.batch_size,
-            num_workers=conf.program.num_workers,
-        )
-        task = LandcoverAISegmentationTask(**task_args)
-    elif conf.task.name == "naipchesapeake":
-        datamodule = NAIPChesapeakeDataModule(
-            conf.program.naip_data_dir,
-            conf.program.chesapeake_data_dir,
-            batch_size=conf.program.batch_size,
-            num_workers=conf.program.num_workers,
-        )
-        task = NAIPChesapeakeSegmentationTask(**task_args)
-    elif conf.task.name == "sen12ms":
-        datamodule = SEN12MSDataModule(
-            conf.program.data_dir,
-            seed=conf.program.seed,
-            batch_size=conf.program.batch_size,
-            num_workers=conf.program.num_workers,
-        )
-        task = SEN12MSSegmentationTask(**task_args)
+    datamodule: pl.LightningDataModule
+    task: pl.LightningModule
+    if task_name in TASK_TO_MODULES_MAPPING:
+        task_class, datamodule_class = TASK_TO_MODULES_MAPPING[task_name]
+        task = task_class(**task_args)
+        datamodule = datamodule_class(**datamodule_args)
     else:
         raise ValueError(
-            f"task.name={conf.task.name} is not recognized as a valid task"
+            f"experiment.task={task_name} is not recognized as a valid task"
         )
 
     ######################################
     # Setup trainer
     ######################################
-    tb_logger = pl_loggers.TensorBoardLogger(
-        conf.program.log_dir, name=conf.program.experiment_name
-    )
+    tb_logger = pl_loggers.TensorBoardLogger(conf.program.log_dir, name=experiment_name)
 
     checkpoint_callback = ModelCheckpoint(
         monitor="val_loss",
@@ -174,8 +159,7 @@ def main(conf: DictConfig) -> None:
         patience=10,
     )
 
-    trainer_args = OmegaConf.to_object(conf.trainer)
-    trainer_args = cast(Dict[str, Any], trainer_args)
+    trainer_args = cast(Dict[str, Any], OmegaConf.to_object(conf.trainer))
 
     trainer_args["callbacks"] = [checkpoint_callback, early_stopping_callback]
     trainer_args["logger"] = tb_logger
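Downstream of this hunk the assembled dictionary is expanded into the Lightning Trainer; a sketch of that final step, hedged because those lines fall outside the diff:

trainer = pl.Trainer(**trainer_args)  # conf.trainer keys map onto Trainer.__init__
trainer.fit(model=task, datamodule=datamodule)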