PyTorch Lightning-based training framework (#42)
* Initial commit of lightning based model training framework
* Made save directories work correctly
* Add pytorch-lightning dependency and some comments
* More documentation and cosmetic tweaks
* Typo fix

  Co-authored-by: Adam J. Stewart <ajstewart426@gmail.com>

* Fix some style issues
* Fix pydocstyle
* Add missing sklearn dependency
* Try to get conda environment working
* Add documentation
* Ignore missing target reference
* Make train.py executable
* Ignore logs and output dirs
* Raise exceptions instead of returning
* Move all argparse stuff to set_up_parser
* Add tests for train.py
* Fix Python 3.6 compatibility
* Fix support for older versions of pytorch-lightning

Co-authored-by: Adam J. Stewart <ajstewart426@gmail.com>
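For reference, the new entry point added by this PR is invoked roughly like so (an illustrative command assembled from the argparse flags defined in train.py below; only --experiment_name is required, and --fast_dev_run is one of the Trainer flags pulled in via pl.Trainer.add_argparse_args, as used in the new tests):

    python train.py --experiment_name test --data_dir data --output_dir output --log_dir logs --fast_dev_run 1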
.flake8
@@ -4,8 +4,10 @@ extend-ignore =
     # See https://github.com/PyCQA/pycodestyle/issues/373
     E203,
 exclude =
-    # Data
+    # TorchGeo
     data/,
+    logs/,
+    output/,
 
     # Spack
     .spack-env/,
@@ -1,4 +1,7 @@
+# TorchGeo
 /data/
+/logs/
+/output/
 
 # Spack
 .spack-env/
@@ -50,6 +50,8 @@ nitpicky = True
 nitpick_ignore = [
     # https://github.com/sphinx-doc/sphinx/issues/8127
     ("py:class", ".."),
+    # TODO: can't figure out why this isn't found
+    ("py:class", "LightningDataModule"),
 ]
 
@@ -86,6 +88,7 @@ autodoc_typehints = "description"
 # sphinx.ext.intersphinx
 intersphinx_mapping = {
     "python": ("https://docs.python.org/3", None),
+    "pytorch-lightning": ("https://pytorch-lightning.readthedocs.io/en/latest/", None),
     "rasterio": ("https://rasterio.readthedocs.io/en/latest/", None),
     "torch": ("https://pytorch.org/docs/stable", None),
 }
@@ -15,6 +15,7 @@ architectures, and common image transformations for geospatial data.
 
    datasets
    samplers
+   trainers
    transforms
 
 .. toctree::
@@ -0,0 +1,4 @@
torchgeo.trainers
=================

.. automodule:: torchgeo.trainers
@@ -1,29 +1,33 @@
name: torchgeo
channels:
  - anaconda
  - conda-forge
  - pytorch
dependencies:
  - gdal
  - affine
  - h5py
  - numpy
  - opencv
  - pillow
  - pip
  - python
  - pytorch::pytorch>=1.7
  - pytorch::torchvision!=0.10.0
  - setuptools>=42
  - sphinx
  - pycocotools
  - pytorch>=1.7
  - pytorch-lightning
  - radiant-mlhub>=0.2.1
  - rarfile
  - rasterio>=0.3
  - rtree>=0.5.0
  - scikit-learn
  - torchvision>=0.3,!=0.10.0
  - pip:
    - affine
    - black[colorama]>=21
    - flake8
    - isort[colors]>=4.3.5
    - mypy>=0.900
    - opencv-python
    - pycocotools
    - pydocstyle[toml]>=6.1
    - pytest>=6.0
    - pytest-cov
    - pytorch-sphinx-theme
    - radiant-mlhub>=0.2.1
    - rarfile
    - rtree>=0.5.0
    - setuptools>=42
    - sphinx
    - git+https://github.com/pytorch/pytorch_sphinx_theme
@@ -10,8 +10,10 @@ target-version = ["py36", "py37", "py38", "py39"]
 color = true
 exclude = '''
 /(
-    # Data
+    # TorchGeo
     | data
+    | logs
+    | output
 
     # Spack
     | \.spack-env
 
@@ -34,14 +36,14 @@ exclude = '''
 [tool.isort]
 profile = "black"
 known_first_party = ["docs", "tests", "torchgeo"]
-extend_skip = [".spack-env/"]
+extend_skip = [".spack-env/", "data", "logs", "output"]
 skip_gitignore = true
 color_output = true
 
 [tool.mypy]
 ignore_missing_imports = true
 show_error_codes = true
-exclude = "(data|build|dist)/"
+exclude = "(build|data|dist|logs|output)/"
 
 # Strict
 warn_unused_configs = true
 
@@ -72,5 +74,7 @@ norecursedirs = [
   "data",
   "dist",
   "docs",
+  "logs",
+  "output",
   "__pycache__",
 ]
@@ -11,12 +11,14 @@ pycocotools
 pydocstyle[toml]>=6.1
 pytest>=6.0
 pytest-cov
+pytorch-lightning
 pytorch-sphinx-theme
 radiant-mlhub>=0.2.1
 rarfile
 rasterio>=0.3
 rtree>=0.5.0
+scikit-learn
 setuptools>=42
 sphinx
 torch>=1.7
-torchvision!=0.10.0
+torchvision>=0.3,!=0.10.0
@@ -32,7 +32,7 @@ install_requires =
     rasterio>=0.3
     rtree>=0.5.0
     torch>=1.7
-    torchvision!=0.10.0
+    torchvision>=0.3,!=0.10.0
 python_requires = >= 3.6
 packages = find:
 
@@ -61,3 +61,6 @@ tests =
     mypy>=0.900
     pytest>=6.0
     pytest-cov
+train =
+    pytorch-lightning
+    scikit-learn
@@ -14,13 +14,15 @@ spack:
   - "py-pydocstyle@6.1:+toml"
   - "py-pytest@6.0:"
   - py-pytest-cov
+  - py-pytorch-lightning
   - py-pytorch-sphinx-theme
   - "py-radiant-mlhub@0.2.1:"
   - py-rarfile
   - "py-rasterio@0.3:"
   - "py-rtree@0.5.0:"
+  - py-scikit-learn
   - "py-setuptools@42:"
   - py-sphinx
   - "py-torch@1.7:"
-  - "py-torchvision@:0.9,0.10.1:"
+  - "py-torchvision@0.3:0.9,0.10.1:"
   concretization: together
@@ -0,0 +1,24 @@
{
    "links": [
        {
            "href": "nasa_tropical_storm_competition_test_labels_a_000/stac.json",
            "rel": "item"
        },
        {
            "href": "nasa_tropical_storm_competition_test_labels_b_001/stac.json",
            "rel": "item"
        },
        {
            "href": "nasa_tropical_storm_competition_test_labels_c_002/stac.json",
            "rel": "item"
        },
        {
            "href": "nasa_tropical_storm_competition_test_labels_d_003/stac.json",
            "rel": "item"
        },
        {
            "href": "nasa_tropical_storm_competition_test_labels_e_004/stac.json",
            "rel": "item"
        }
    ]
}
@@ -0,0 +1 @@
{"wind_speed": "34"}

@@ -0,0 +1 @@
{"wind_speed": "34"}

@@ -0,0 +1 @@
{"wind_speed": "34"}

@@ -0,0 +1 @@
{"wind_speed": "34"}

@@ -0,0 +1 @@
{"wind_speed": "34"}
@@ -0,0 +1,24 @@
{
    "links": [
        {
            "href": "nasa_tropical_storm_competition_test_source_a_000/stac.json",
            "rel": "item"
        },
        {
            "href": "nasa_tropical_storm_competition_test_source_b_001/stac.json",
            "rel": "item"
        },
        {
            "href": "nasa_tropical_storm_competition_test_source_c_002/stac.json",
            "rel": "item"
        },
        {
            "href": "nasa_tropical_storm_competition_test_source_d_003/stac.json",
            "rel": "item"
        },
        {
            "href": "nasa_tropical_storm_competition_test_source_e_004/stac.json",
            "rel": "item"
        }
    ]
}
@@ -0,0 +1 @@
{"storm_id": "a", "relative_time": "0", "ocean": "2"}

@@ -0,0 +1 @@
{"storm_id": "b", "relative_time": "0", "ocean": "2"}

@@ -0,0 +1 @@
{"storm_id": "c", "relative_time": "0", "ocean": "2"}

@@ -0,0 +1 @@
{"storm_id": "d", "relative_time": "0", "ocean": "2"}

@@ -0,0 +1 @@
{"storm_id": "e", "relative_time": "0", "ocean": "2"}

[Each of these items also adds a small binary test image (333 to 631 B); the diff viewer showed only width/height/size for those files.]
@@ -0,0 +1,24 @@
{
    "links": [
        {
            "href": "nasa_tropical_storm_competition_train_labels_a_000/stac.json",
            "rel": "item"
        },
        {
            "href": "nasa_tropical_storm_competition_train_labels_b_001/stac.json",
            "rel": "item"
        },
        {
            "href": "nasa_tropical_storm_competition_train_labels_c_002/stac.json",
            "rel": "item"
        },
        {
            "href": "nasa_tropical_storm_competition_train_labels_d_003/stac.json",
            "rel": "item"
        },
        {
            "href": "nasa_tropical_storm_competition_train_labels_e_004/stac.json",
            "rel": "item"
        }
    ]
}
@@ -0,0 +1 @@
{"wind_speed": "34"}

@@ -0,0 +1 @@
{"wind_speed": "34"}

@@ -0,0 +1 @@
{"wind_speed": "34"}

@@ -0,0 +1 @@
{"wind_speed": "34"}

@@ -0,0 +1 @@
{"wind_speed": "34"}
@@ -0,0 +1,24 @@
{
    "links": [
        {
            "href": "nasa_tropical_storm_competition_train_source_a_000/stac.json",
            "rel": "item"
        },
        {
            "href": "nasa_tropical_storm_competition_train_source_b_001/stac.json",
            "rel": "item"
        },
        {
            "href": "nasa_tropical_storm_competition_train_source_c_002/stac.json",
            "rel": "item"
        },
        {
            "href": "nasa_tropical_storm_competition_train_source_d_003/stac.json",
            "rel": "item"
        },
        {
            "href": "nasa_tropical_storm_competition_train_source_e_004/stac.json",
            "rel": "item"
        }
    ]
}
@@ -0,0 +1 @@
{"storm_id": "a", "relative_time": "0", "ocean": "2"}

@@ -0,0 +1 @@
{"storm_id": "b", "relative_time": "0", "ocean": "2"}

@@ -0,0 +1 @@
{"storm_id": "c", "relative_time": "0", "ocean": "2"}

@@ -0,0 +1 @@
{"storm_id": "d", "relative_time": "0", "ocean": "2"}

@@ -0,0 +1 @@
{"storm_id": "e", "relative_time": "0", "ocean": "2"}

[Each of these items also adds a small binary test image (333 to 631 B); the diff viewer showed only width/height/size for those files.]
@@ -38,12 +38,12 @@ class TestTropicalCycloneWindEstimation:
         )
         md5s = {
             "train": {
-                "source": "3c9041d3a6a8178e5ed37fff3ec131b0",
-                "labels": "d8cebe3d51ef7a5d4e992b75559a0348",
+                "source": "2b818e0a0873728dabf52c7054a0ce4c",
+                "labels": "c3c2b6d02c469c5519f4add4f9132712",
             },
             "test": {
-                "source": "072c0e6e662f1f9658a47a3eee9218a1",
-                "labels": "b168c6cea0857ea41e65ebceadf7d85b",
+                "source": "bc07c519ddf3ce88857435ddddf98a16",
+                "labels": "3ca4243eff39b87c73e05ec8db1824bf",
             },
         }
         monkeypatch.setattr(  # type: ignore[attr-defined]

@@ -72,12 +72,12 @@ class TestTropicalCycloneWindEstimation:
         assert x["image"].shape == (dataset.size, dataset.size)
 
     def test_len(self, dataset: TropicalCycloneWindEstimation) -> None:
-        assert len(dataset) == 2
+        assert len(dataset) == 5
 
     def test_add(self, dataset: TropicalCycloneWindEstimation) -> None:
         ds = dataset + dataset
         assert isinstance(ds, ConcatDataset)
-        assert len(ds) == 4
+        assert len(ds) == 10
 
     def test_already_downloaded(self, dataset: TropicalCycloneWindEstimation) -> None:
         TropicalCycloneWindEstimation(root=dataset.root, download=True, api_key="")
@@ -0,0 +1,85 @@
import os
import re
import subprocess
import sys
from pathlib import Path


def test_help() -> None:
    args = [sys.executable, "train.py", "--help"]
    subprocess.run(args, check=True)


def test_required_args() -> None:
    args = [sys.executable, "train.py"]
    ps = subprocess.run(args, stdout=subprocess.PIPE, stderr=subprocess.PIPE)
    assert ps.returncode != 0
    assert b"error: the following arguments are required:" in ps.stderr


def test_output_file(tmp_path: Path) -> None:
    output_file = tmp_path / "output"
    output_file.touch()
    args = [
        sys.executable,
        "train.py",
        "--experiment_name",
        "test",
        "--output_dir",
        str(output_file),
    ]
    ps = subprocess.run(args, stdout=subprocess.PIPE, stderr=subprocess.PIPE)
    assert ps.returncode != 0
    assert b"NotADirectoryError" in ps.stderr


def test_experiment_dir_not_empty(tmp_path: Path) -> None:
    output_dir = tmp_path / "output"
    experiment_dir = output_dir / "test"
    experiment_dir.mkdir(parents=True)
    experiment_file = experiment_dir / "foo"
    experiment_file.touch()
    args = [
        sys.executable,
        "train.py",
        "--experiment_name",
        "test",
        "--output_dir",
        str(output_dir),
    ]
    ps = subprocess.run(args, stdout=subprocess.PIPE, stderr=subprocess.PIPE)
    assert ps.returncode != 0
    assert b"FileExistsError" in ps.stderr


def test_overwrite_experiment_dir(tmp_path: Path) -> None:
    experiment_name = "test"
    output_dir = tmp_path / "output"
    data_dir = os.path.join("tests", "data")
    log_dir = tmp_path / "logs"
    experiment_dir = output_dir / experiment_name
    experiment_dir.mkdir(parents=True)
    experiment_file = experiment_dir / "foo"
    experiment_file.touch()
    args = [
        sys.executable,
        "train.py",
        "--experiment_name",
        experiment_name,
        "--output_dir",
        str(output_dir),
        "--data_dir",
        data_dir,
        "--log_dir",
        str(log_dir),
        "--overwrite",
        "--fast_dev_run",
        "1",
    ]
    ps = subprocess.run(
        args, stdout=subprocess.PIPE, stderr=subprocess.PIPE, check=True
    )
    assert re.search(
        b"The experiment directory, .*, already exists, we might overwrite data in it!",
        ps.stdout,
    )
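Since these tests drive train.py through subprocess, they run like any other pytest module (a hedged example; the test module's path is not shown in this diff, so the path below is illustrative and assumes you run from the repository root):

    python -m pytest tests/test_train.py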
@@ -2,6 +2,7 @@
 
 import json
 import os
+from functools import lru_cache
 from typing import Any, Callable, Dict, Optional
 
 import numpy as np

@@ -136,6 +137,7 @@ class TropicalCycloneWindEstimation(VisionDataset):
         """
         return len(self.collection)
 
+    @lru_cache()
     def _load_image(self, directory: str) -> Tensor:
         """Load a single image.
 
@@ -0,0 +1,9 @@
"""TorchGeo trainers."""

from .cyclone import CycloneDataModule, CycloneSimpleRegressionTask

__all__ = ("CycloneDataModule", "CycloneSimpleRegressionTask")

# https://stackoverflow.com/questions/40018681
for module in __all__:
    globals()[module].__module__ = "torchgeo.trainers"
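For what the loop above buys, a quick illustrative check (hypothetical REPL session; with the reassignment, Sphinx and repr() report the public package path rather than the private submodule):

    >>> from torchgeo.trainers import CycloneDataModule
    >>> CycloneDataModule.__module__
    'torchgeo.trainers'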
@@ -0,0 +1,236 @@
"""NASA Cyclone dataset trainer."""

from typing import Any, Dict, Optional

import pytorch_lightning as pl
import torch
import torch.nn.functional as F
from sklearn.model_selection import GroupShuffleSplit
from torch import Tensor
from torch.nn.modules import Module
from torch.optim.lr_scheduler import ReduceLROnPlateau
from torch.utils.data import DataLoader, Subset

from ..datasets import TropicalCycloneWindEstimation

# https://github.com/pytorch/pytorch/issues/60979
# https://github.com/pytorch/pytorch/pull/61045
DataLoader.__module__ = "torch.utils.data"
Module.__module__ = "torch.nn"


class CycloneSimpleRegressionTask(pl.LightningModule):
    """LightningModule for training models on the NASA Cyclone Dataset using MSE loss.

    This does not take into account other per-sample features available in this
    dataset.
    """

    def __init__(self, model: Module, **kwargs: Dict[str, Any]) -> None:
        """Initialize a new LightningModule for training simple regression models.

        Args:
            model: A model (specifically, a ``nn.Module``) instance to be trained.
        """
        super().__init__()
        self.save_hyperparameters()  # creates `self.hparams` from kwargs
        self.model = model

    def forward(self, x: Tensor) -> Any:  # type: ignore[override]
        """Forward pass of the model."""
        return self.model(x)

    # NOTE: See https://github.com/PyTorchLightning/pytorch-lightning/issues/5023 for
    # why we need to tell mypy to ignore a bunch of things
    def training_step(  # type: ignore[override]
        self, batch: Dict[str, Any], batch_idx: int
    ) -> Tensor:
        """Training step with an MSE loss. Reports MSE and RMSE."""
        x = batch["image"]
        y = batch["wind_speed"].view(-1, 1)
        y_hat = self.forward(x)

        loss = F.mse_loss(y_hat, y)
        self.log("train_loss", loss)  # logging to TensorBoard

        rmse = torch.sqrt(loss)  # type: ignore[attr-defined]
        self.log("train_rmse", rmse)

        return loss

    def validation_step(  # type: ignore[override]
        self, batch: Dict[str, Any], batch_idx: int
    ) -> None:
        """Validation step - reports MSE and RMSE."""
        x = batch["image"]
        y = batch["wind_speed"].view(-1, 1)
        y_hat = self.forward(x)

        loss = F.mse_loss(y_hat, y)
        self.log("val_loss", loss)

        rmse = torch.sqrt(loss)  # type: ignore[attr-defined]
        self.log("val_rmse", rmse)

    def test_step(  # type: ignore[override]
        self, batch: Dict[str, Any], batch_idx: int
    ) -> None:
        """Test step identical to the validation step. Reports MSE and RMSE."""
        x = batch["image"]
        y = batch["wind_speed"].view(-1, 1)
        y_hat = self.forward(x)

        loss = F.mse_loss(y_hat, y)
        self.log("test_loss", loss)

        rmse = torch.sqrt(loss)  # type: ignore[attr-defined]
        self.log("test_rmse", rmse)

    def configure_optimizers(self) -> Dict[str, Any]:
        """Initialize the optimizer and learning rate scheduler."""
        optimizer = torch.optim.Adam(
            self.model.parameters(),
            lr=self.hparams["learning_rate"],  # type: ignore[index]
        )
        return {
            "optimizer": optimizer,
            "lr_scheduler": {
                "scheduler": ReduceLROnPlateau(
                    optimizer,
                    patience=self.hparams[
                        "learning_rate_schedule_patience"
                    ],  # type: ignore[index]
                ),
                "monitor": "val_loss",
            },
        }


class CycloneDataModule(pl.LightningDataModule):
    """LightningDataModule implementation for the NASA Cyclone dataset.

    Implements 80/20 train/val splits based on hurricane storm ids.
    See :func:`setup` for more details.
    """

    def __init__(
        self,
        root_dir: str,
        seed: int,
        batch_size: int = 64,
        num_workers: int = 4,
        api_key: Optional[str] = None,
    ) -> None:
        """Initialize a LightningDataModule for NASA Cyclone based DataLoaders.

        Args:
            root_dir: The ``root`` argument to pass to the
                TropicalCycloneWindEstimation Dataset classes
            seed: The seed value to use when doing the sklearn based GroupShuffleSplit
            batch_size: The batch size to use in all created DataLoaders
            num_workers: The number of workers to use in all created DataLoaders
            api_key: The RadiantEarth MLHub API key to use if the dataset needs to be
                downloaded
        """
        super().__init__()  # type: ignore[no-untyped-call]
        self.root_dir = root_dir
        self.seed = seed
        self.batch_size = batch_size
        self.num_workers = num_workers
        self.api_key = api_key

    # TODO: This needs to be converted to actual transforms instead of hacked
    def custom_transform(self, sample: Dict[str, Any]) -> Dict[str, Any]:
        """Transform a single sample from the Dataset."""
        sample["image"] = sample["image"] / 255.0  # scale to [0, 1]
        sample["image"] = (
            sample["image"].unsqueeze(0).repeat(3, 1, 1)
        )  # convert to 3 channel
        sample["wind_speed"] = torch.as_tensor(  # type: ignore[attr-defined]
            sample["wind_speed"]
        ).float()

        return sample

    def prepare_data(self) -> None:
        """Initialize the main ``Dataset`` objects for use in :func:`setup`.

        This includes optionally downloading the dataset. This is done once per node,
        while :func:`setup` is done once per GPU.
        """
        do_download = self.api_key is not None
        self.all_train_dataset = TropicalCycloneWindEstimation(
            self.root_dir,
            split="train",
            transforms=self.custom_transform,
            download=do_download,
            api_key=self.api_key,
        )

        self.all_test_dataset = TropicalCycloneWindEstimation(
            self.root_dir,
            split="test",
            transforms=self.custom_transform,
            download=do_download,
            api_key=self.api_key,
        )

    def setup(self, stage: Optional[str] = None) -> None:
        """Create the train/val/test splits based on the original Dataset objects.

        The splits should be done here vs. in :func:`__init__` per the docs:
        https://pytorch-lightning.readthedocs.io/en/latest/extensions/datamodules.html#setup.

        We split samples between train/val by the ``storm_id`` property. I.e. all
        samples with the same ``storm_id`` value will be either in the train or the
        val split. This is important to test one type of generalizability -- given a
        new storm, can we predict its wind speed. The test set, however, contains
        *some* storms from the training set (specifically, the latter parts of the
        storms) as well as some novel storms.
        """
        storm_ids = []
        for item in self.all_train_dataset.collection:
            storm_id = item["href"].split("/")[0].split("_")[-2]
            storm_ids.append(storm_id)

        train_indices, val_indices = next(
            GroupShuffleSplit(test_size=0.2, n_splits=2, random_state=self.seed).split(
                storm_ids, groups=storm_ids
            )
        )

        self.train_dataset = Subset(self.all_train_dataset, train_indices)
        self.val_dataset = Subset(self.all_train_dataset, val_indices)
        self.test_dataset = Subset(
            self.all_test_dataset, range(len(self.all_test_dataset))
        )

    def train_dataloader(self) -> DataLoader[Any]:
        """Return a DataLoader for training."""
        return DataLoader(
            self.train_dataset,
            batch_size=self.batch_size,
            num_workers=self.num_workers,
            shuffle=True,
            pin_memory=True,
        )

    def val_dataloader(self) -> DataLoader[Any]:
        """Return a DataLoader for validation."""
        return DataLoader(
            self.val_dataset,
            batch_size=self.batch_size,
            num_workers=self.num_workers,
            shuffle=False,
            pin_memory=True,
        )

    def test_dataloader(self) -> DataLoader[Any]:
        """Return a DataLoader for testing."""
        return DataLoader(
            self.test_dataset,
            batch_size=self.batch_size,
            num_workers=self.num_workers,
            shuffle=False,
            pin_memory=True,
        )
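The storm-level split in setup() leans on scikit-learn's GroupShuffleSplit; a minimal sketch of the behavior it relies on (toy storm ids, not real data):

    from sklearn.model_selection import GroupShuffleSplit

    storm_ids = ["a", "a", "b", "b", "c", "c", "d", "d", "e", "e"]
    train_idx, val_idx = next(
        GroupShuffleSplit(test_size=0.2, n_splits=2, random_state=0).split(
            storm_ids, groups=storm_ids
        )
    )
    # Every storm id lands entirely in train or entirely in val:
    assert not {storm_ids[i] for i in train_idx} & {storm_ids[i] for i in val_idx}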
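Putting the two classes together, a minimal end-to-end sketch of how the task and data module compose with a Lightning Trainer (this mirrors what train.py below does; the data path and hyperparameter values are illustrative, and fast_dev_run=True keeps the run to a single batch):

    import pytorch_lightning as pl
    from torchvision import models

    from torchgeo.trainers import CycloneDataModule, CycloneSimpleRegressionTask

    # Single-output ResNet-18 for wind speed regression, as in train.py
    model = models.resnet18(pretrained=False, num_classes=1)
    task = CycloneSimpleRegressionTask(
        model, learning_rate=1e-3, learning_rate_schedule_patience=2
    )
    datamodule = CycloneDataModule("data", seed=1337, batch_size=32, num_workers=4)

    trainer = pl.Trainer(fast_dev_run=True)
    trainer.fit(model=task, datamodule=datamodule)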
@@ -0,0 +1,181 @@
#!/usr/bin/env python3

"""torchgeo model training script."""

import argparse
import os

import pytorch_lightning as pl
from pytorch_lightning import loggers as pl_loggers
from pytorch_lightning.callbacks import EarlyStopping, ModelCheckpoint
from torchvision import models

from torchgeo.trainers import CycloneDataModule, CycloneSimpleRegressionTask


def set_up_parser() -> argparse.ArgumentParser:
    """Set up the argument parser with program level arguments.

    Returns:
        the argument parser
    """
    parser = argparse.ArgumentParser(
        description=__doc__, formatter_class=argparse.ArgumentDefaultsHelpFormatter
    )

    # Add _program_ level arguments to the parser
    parser.add_argument(
        "--batch_size",
        type=int,
        default=32,
        help="Batch size to use in training",
    )
    parser.add_argument(
        "--num_workers",
        type=int,
        default=4,
        help="Number of workers to use in the Dataloaders",
    )
    parser.add_argument(
        "--seed",
        type=int,
        default=1337,
        help="Random number generator seed for numpy and torch",
    )
    parser.add_argument(
        "--experiment_name",
        type=str,
        required=True,
        help="Name of this experiment (used in TensorBoard and as the subdirectory "
        + "name to save results)",
    )
    parser.add_argument(
        "--output_dir",
        type=str,
        default="output",
        help="Directory to store experiment results",
    )
    parser.add_argument(
        "--data_dir",
        type=str,
        default="data",
        help="Directory where datasets are/will be stored",
    )
    parser.add_argument(
        "--log_dir",
        type=str,
        default="logs",
        help="Directory where logs will be stored",
    )
    parser.add_argument(
        "--overwrite",
        action="store_true",
        default=False,
        help="Flag to enable overwriting existing output",
    )

    # TODO: may want to eventually switch to an OmegaConf based configuration system
    # See https://pytorch-lightning.readthedocs.io/en/latest/common/hyperparameters.html
    # for best practices here

    # Add _trainer_ level arguments to the parser
    parser = pl.Trainer.add_argparse_args(parser)

    # TODO: Add _task_ level arguments to the parser for each _task_ we have implemented
    parser.add_argument(
        "--learning_rate",
        type=float,
        default=1e-3,
        help="Learning rate",
    )
    parser.add_argument(
        "--learning_rate_schedule_patience",
        type=int,
        default=2,
        help="Patience factor for the ReduceLROnPlateau schedule",
    )

    return parser


def main(args: argparse.Namespace) -> None:
    """Main training loop."""
    ######################################
    # Setup output directory
    ######################################

    if os.path.isfile(args.output_dir):
        raise NotADirectoryError("`--output_dir` must be a directory")
    os.makedirs(args.output_dir, exist_ok=True)

    experiment_dir = os.path.join(args.output_dir, args.experiment_name)
    os.makedirs(experiment_dir, exist_ok=True)

    if len(os.listdir(experiment_dir)) > 0:
        if args.overwrite:
            # TODO: convert this to logging.WARNING
            print(
                f"WARNING! The experiment directory, {experiment_dir}, already exists, "
                + "we might overwrite data in it!"
            )
        else:
            raise FileExistsError(
                f"The experiment directory, {experiment_dir}, already exists and isn't "
                + "empty. We don't want to overwrite any existing results, exiting..."
            )

    ######################################
    # Choose task to run based on arguments or configuration
    ######################################
    # TODO: Logic to switch between tasks

    model = models.resnet18(pretrained=False, num_classes=1)
    datamodule = CycloneDataModule(
        args.data_dir,
        seed=args.seed,
        batch_size=args.batch_size,
        num_workers=args.num_workers,
    )

    # Convert the argparse Namespace into a dictionary so that we can pass as kwargs
    dict_args = vars(args)
    task = CycloneSimpleRegressionTask(model, **dict_args)

    ######################################
    # Setup trainer
    ######################################
    tb_logger = pl_loggers.TensorBoardLogger(args.log_dir, name=args.experiment_name)

    checkpoint_callback = ModelCheckpoint(
        monitor="val_loss",
        dirpath=experiment_dir,
        save_top_k=3,
        save_last=True,
    )
    early_stopping_callback = EarlyStopping(
        monitor="val_loss",
        min_delta=0.00,
        patience=10,
    )

    trainer = pl.Trainer.from_argparse_args(
        args, logger=tb_logger, callbacks=[checkpoint_callback, early_stopping_callback]
    )

    ######################################
    # Run experiment
    ######################################
    trainer.fit(model=task, datamodule=datamodule)
    trainer.test(model=task, datamodule=datamodule)


if __name__ == "__main__":
    parser = set_up_parser()
    args = parser.parse_args()

    # Set random seed for reproducibility
    # https://pytorch-lightning.readthedocs.io/en/latest/api/pytorch_lightning.utilities.seed.html#pytorch_lightning.utilities.seed.seed_everything
    pl.seed_everything(args.seed)

    # Main training procedure
    main(args)