PyTorch Lightning based training framework (#42)

* Initial commit of lightning based model training framework

* Made save directories work correctly

* Add pytorch-lightning dependency and some comments

* More documentation and cosmetic tweaks

* Typo fix

Co-authored-by: Adam J. Stewart <ajstewart426@gmail.com>

* Fix some style issues

* Fix pydocstyle

* Add missing sklearn dependency

* Try to get conda environment working

* Add documentation

* Ignore missing target reference

* Make train.py executable

* Ignore logs and output dirs

* Raise exceptions instead of returning

* Move all argparse stuff to set_up_parser

* Add tests for train.py

* Fix Python 3.6 compatibility

* Fix support for older versions of pytorch-lightning

Co-authored-by: Adam J. Stewart <ajstewart426@gmail.com>
This commit is contained in:
Caleb Robinson 2021-07-17 16:57:18 -07:00 committed by GitHub
Parent b50c3582cb
Commit e460d5af23
No known key found for this signature
GPG key ID: 4AEE18F83AFDEB23
54 changed files: 683 additions and 26 deletions

@@ -4,8 +4,10 @@ extend-ignore =
# See https://github.com/PyCQA/pycodestyle/issues/373
E203,
exclude =
# Data
# TorchGeo
data/,
logs/,
output/,
# Spack
.spack-env/,

.gitignore vendored (3 changes)

@@ -1,4 +1,7 @@
# TorchGeo
/data/
/logs/
/output/
# Spack
.spack-env/

@@ -50,6 +50,8 @@ nitpicky = True
nitpick_ignore = [
# https://github.com/sphinx-doc/sphinx/issues/8127
("py:class", ".."),
# TODO: can't figure out why this isn't found
("py:class", "LightningDataModule"),
]
@@ -86,6 +88,7 @@ autodoc_typehints = "description"
# sphinx.ext.intersphinx
intersphinx_mapping = {
"python": ("https://docs.python.org/3", None),
"pytorch-lightning": ("https://pytorch-lightning.readthedocs.io/en/latest/", None),
"rasterio": ("https://rasterio.readthedocs.io/en/latest/", None),
"torch": ("https://pytorch.org/docs/stable", None),
}

@@ -15,6 +15,7 @@ architectures, and common image transformations for geospatial data.
datasets
samplers
trainers
transforms
.. toctree::

docs/trainers.rst Normal file (4 changes)

@@ -0,0 +1,4 @@
torchgeo.trainers
=================
.. automodule:: torchgeo.trainers

@@ -1,29 +1,33 @@
name: torchgeo
channels:
- anaconda
- conda-forge
- pytorch
dependencies:
- gdal
- affine
- h5py
- numpy
- opencv
- pillow
- pip
- python
- pytorch::pytorch>=1.7
- pytorch::torchvision!=0.10.0
- setuptools>=42
- sphinx
- pycocotools
- pytorch>=1.7
- pytorch-lightning
- radiant-mlhub>=0.2.1
- rarfile
- rasterio>=0.3
- rtree>=0.5.0
- scikit-learn
- torchvision>=0.3,!=0.10.0
- pip:
- affine
- black[colorama]>=21
- flake8
- isort[colors]>=4.3.5
- mypy>=0.900
- opencv-python
- pycocotools
- pydocstyle[toml]>=6.1
- pytest>=6.0
- pytest-cov
- pytorch-sphinx-theme
- radiant-mlhub>=0.2.1
- rarfile
- rtree>=0.5.0
- setuptools>=42
- sphinx
- git+https://github.com/pytorch/pytorch_sphinx_theme

@@ -10,8 +10,10 @@ target-version = ["py36", "py37", "py38", "py39"]
color = true
exclude = '''
/(
# Data
# TorchGeo
| data
| logs
| output
# Spack
| \.spack-env
@@ -34,14 +36,14 @@ exclude = '''
[tool.isort]
profile = "black"
known_first_party = ["docs", "tests", "torchgeo"]
extend_skip = [".spack-env/"]
extend_skip = [".spack-env/", "data", "logs", "output"]
skip_gitignore = true
color_output = true
[tool.mypy]
ignore_missing_imports = true
show_error_codes = true
exclude = "(data|build|dist)/"
exclude = "(build|data|dist|logs|output)/"
# Strict
warn_unused_configs = true
@@ -72,5 +74,7 @@ norecursedirs = [
"data",
"dist",
"docs",
"logs",
"output",
"__pycache__",
]

@@ -11,12 +11,14 @@ pycocotools
pydocstyle[toml]>=6.1
pytest>=6.0
pytest-cov
pytorch-lightning
pytorch-sphinx-theme
radiant-mlhub>=0.2.1
rarfile
rasterio>=0.3
rtree>=0.5.0
scikit-learn
setuptools>=42
sphinx
torch>=1.7
torchvision!=0.10.0
torchvision>=0.3,!=0.10.0

@@ -32,7 +32,7 @@ install_requires =
rasterio>=0.3
rtree>=0.5.0
torch>=1.7
torchvision!=0.10.0
torchvision>=0.3,!=0.10.0
python_requires = >= 3.6
packages = find:
@@ -61,3 +61,6 @@ tests =
mypy>=0.900
pytest>=6.0
pytest-cov
train =
pytorch-lightning
scikit-learn

@@ -14,13 +14,15 @@ spack:
- "py-pydocstyle@6.1:+toml"
- "py-pytest@6.0:"
- py-pytest-cov
- py-pytorch-lightning
- py-pytorch-sphinx-theme
- "py-radiant-mlhub@0.2.1:"
- py-rarfile
- "py-rasterio@0.3:"
- "py-rtree@0.5.0:"
- py-scikit-learn
- "py-setuptools@42:"
- py-sphinx
- "py-torch@1.7:"
- "py-torchvision@:0.9,0.10.1:"
- "py-torchvision@0.3:0.9,0.10.1:"
concretization: together

Binary file not shown.

@@ -0,0 +1,24 @@
{
"links": [
{
"href": "nasa_tropical_storm_competition_test_labels_a_000/stac.json",
"rel": "item"
},
{
"href": "nasa_tropical_storm_competition_test_labels_b_001/stac.json",
"rel": "item"
},
{
"href": "nasa_tropical_storm_competition_test_labels_c_002/stac.json",
"rel": "item"
},
{
"href": "nasa_tropical_storm_competition_test_labels_d_003/stac.json",
"rel": "item"
},
{
"href": "nasa_tropical_storm_competition_test_labels_e_004/stac.json",
"rel": "item"
}
]
}

@@ -0,0 +1 @@
{"wind_speed": "34"}

@@ -0,0 +1 @@
{"wind_speed": "34"}

@@ -0,0 +1 @@
{"wind_speed": "34"}

@@ -0,0 +1 @@
{"wind_speed": "34"}

@@ -0,0 +1 @@
{"wind_speed": "34"}

Binary file not shown.

@@ -0,0 +1,24 @@
{
"links": [
{
"href": "nasa_tropical_storm_competition_test_source_a_000/stac.json",
"rel": "item"
},
{
"href": "nasa_tropical_storm_competition_test_source_b_001/stac.json",
"rel": "item"
},
{
"href": "nasa_tropical_storm_competition_test_source_c_002/stac.json",
"rel": "item"
},
{
"href": "nasa_tropical_storm_competition_test_source_d_003/stac.json",
"rel": "item"
},
{
"href": "nasa_tropical_storm_competition_test_source_e_004/stac.json",
"rel": "item"
}
]
}

@@ -0,0 +1 @@
{"storm_id": "a", "relative_time": "0", "ocean": "2"}

Binary file not shown (new image: 333 B).

@@ -0,0 +1 @@
{"storm_id": "b", "relative_time": "0", "ocean": "2"}

Binary file not shown (new image: 631 B).

@@ -0,0 +1 @@
{"storm_id": "c", "relative_time": "0", "ocean": "2"}

Binary file not shown (new image: 333 B).

@@ -0,0 +1 @@
{"storm_id": "d", "relative_time": "0", "ocean": "2"}

Binary file not shown (new image: 333 B).

@@ -0,0 +1 @@
{"storm_id": "e", "relative_time": "0", "ocean": "2"}

Binary file not shown (new image: 333 B).

Binary file not shown.

@@ -0,0 +1,24 @@
{
"links": [
{
"href": "nasa_tropical_storm_competition_train_labels_a_000/stac.json",
"rel": "item"
},
{
"href": "nasa_tropical_storm_competition_train_labels_b_001/stac.json",
"rel": "item"
},
{
"href": "nasa_tropical_storm_competition_train_labels_c_002/stac.json",
"rel": "item"
},
{
"href": "nasa_tropical_storm_competition_train_labels_d_003/stac.json",
"rel": "item"
},
{
"href": "nasa_tropical_storm_competition_train_labels_e_004/stac.json",
"rel": "item"
}
]
}

@@ -0,0 +1 @@
{"wind_speed": "34"}

@@ -0,0 +1 @@
{"wind_speed": "34"}

@@ -0,0 +1 @@
{"wind_speed": "34"}

@@ -0,0 +1 @@
{"wind_speed": "34"}

@@ -0,0 +1 @@
{"wind_speed": "34"}

Binary file not shown.

@@ -0,0 +1,24 @@
{
"links": [
{
"href": "nasa_tropical_storm_competition_train_source_a_000/stac.json",
"rel": "item"
},
{
"href": "nasa_tropical_storm_competition_train_source_b_001/stac.json",
"rel": "item"
},
{
"href": "nasa_tropical_storm_competition_train_source_c_002/stac.json",
"rel": "item"
},
{
"href": "nasa_tropical_storm_competition_train_source_d_003/stac.json",
"rel": "item"
},
{
"href": "nasa_tropical_storm_competition_train_source_e_004/stac.json",
"rel": "item"
}
]
}

@@ -0,0 +1 @@
{"storm_id": "a", "relative_time": "0", "ocean": "2"}

Binary file not shown (new image: 333 B).

@@ -0,0 +1 @@
{"storm_id": "b", "relative_time": "0", "ocean": "2"}

Binary file not shown (new image: 631 B).

@@ -0,0 +1 @@
{"storm_id": "c", "relative_time": "0", "ocean": "2"}

Binary file not shown (new image: 333 B).

@@ -0,0 +1 @@
{"storm_id": "d", "relative_time": "0", "ocean": "2"}

Binary file not shown (new image: 333 B).

@@ -0,0 +1 @@
{"storm_id": "e", "relative_time": "0", "ocean": "2"}

Binary file not shown (new image: 333 B).

@@ -38,12 +38,12 @@ class TestTropicalCycloneWindEstimation:
)
md5s = {
"train": {
"source": "3c9041d3a6a8178e5ed37fff3ec131b0",
"labels": "d8cebe3d51ef7a5d4e992b75559a0348",
"source": "2b818e0a0873728dabf52c7054a0ce4c",
"labels": "c3c2b6d02c469c5519f4add4f9132712",
},
"test": {
"source": "072c0e6e662f1f9658a47a3eee9218a1",
"labels": "b168c6cea0857ea41e65ebceadf7d85b",
"source": "bc07c519ddf3ce88857435ddddf98a16",
"labels": "3ca4243eff39b87c73e05ec8db1824bf",
},
}
monkeypatch.setattr( # type: ignore[attr-defined]
@@ -72,12 +72,12 @@ class TestTropicalCycloneWindEstimation:
assert x["image"].shape == (dataset.size, dataset.size)
def test_len(self, dataset: TropicalCycloneWindEstimation) -> None:
assert len(dataset) == 2
assert len(dataset) == 5
def test_add(self, dataset: TropicalCycloneWindEstimation) -> None:
ds = dataset + dataset
assert isinstance(ds, ConcatDataset)
assert len(ds) == 4
assert len(ds) == 10
def test_already_downloaded(self, dataset: TropicalCycloneWindEstimation) -> None:
TropicalCycloneWindEstimation(root=dataset.root, download=True, api_key="")

tests/test_train.py Normal file (85 changes)

@@ -0,0 +1,85 @@
import os
import re
import subprocess
import sys
from pathlib import Path
def test_help() -> None:
args = [sys.executable, "train.py", "--help"]
subprocess.run(args, check=True)
def test_required_args() -> None:
args = [sys.executable, "train.py"]
ps = subprocess.run(args, stdout=subprocess.PIPE, stderr=subprocess.PIPE)
assert ps.returncode != 0
assert b"error: the following arguments are required:" in ps.stderr
def test_output_file(tmp_path: Path) -> None:
output_file = tmp_path / "output"
output_file.touch()
args = [
sys.executable,
"train.py",
"--experiment_name",
"test",
"--output_dir",
str(output_file),
]
ps = subprocess.run(args, stdout=subprocess.PIPE, stderr=subprocess.PIPE)
assert ps.returncode != 0
assert b"NotADirectoryError" in ps.stderr
def test_experiment_dir_not_empty(tmp_path: Path) -> None:
output_dir = tmp_path / "output"
experiment_dir = output_dir / "test"
experiment_dir.mkdir(parents=True)
experiment_file = experiment_dir / "foo"
experiment_file.touch()
args = [
sys.executable,
"train.py",
"--experiment_name",
"test",
"--output_dir",
str(output_dir),
]
ps = subprocess.run(args, stdout=subprocess.PIPE, stderr=subprocess.PIPE)
assert ps.returncode != 0
assert b"FileExistsError" in ps.stderr
def test_overwrite_experiment_dir(tmp_path: Path) -> None:
experiment_name = "test"
output_dir = tmp_path / "output"
data_dir = os.path.join("tests", "data")
log_dir = tmp_path / "logs"
experiment_dir = output_dir / experiment_name
experiment_dir.mkdir(parents=True)
experiment_file = experiment_dir / "foo"
experiment_file.touch()
args = [
sys.executable,
"train.py",
"--experiment_name",
experiment_name,
"--output_dir",
str(output_dir),
"--data_dir",
data_dir,
"--log_dir",
str(log_dir),
"--overwrite",
"--fast_dev_run",
"1",
]
ps = subprocess.run(
args, stdout=subprocess.PIPE, stderr=subprocess.PIPE, check=True
)
assert re.search(
b"The experiment directory, .*, already exists, we might overwrite data in it!",
ps.stdout,
)

@@ -2,6 +2,7 @@
import json
import os
from functools import lru_cache
from typing import Any, Callable, Dict, Optional
import numpy as np
@@ -136,6 +137,7 @@ class TropicalCycloneWindEstimation(VisionDataset):
"""
return len(self.collection)
@lru_cache()
def _load_image(self, directory: str) -> Tensor:
"""Load a single image.

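The ``@lru_cache()`` decorator added above memoizes ``_load_image`` by its arguments, so repeated passes over the same samples skip redundant disk reads. A minimal, self-contained sketch of the caching behavior (``load`` below is a hypothetical stand-in, not torchgeo code):

from functools import lru_cache

@lru_cache()
def load(directory: str) -> str:
    # The body runs only on a cache miss; later calls with the same
    # argument return the cached result without re-reading from disk.
    print(f"reading {directory}")
    return directory.upper()

load("storm_a")  # prints "reading storm_a"
load("storm_a")  # cache hit: nothing printed, cached value returned
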
@@ -0,0 +1,9 @@
"""TorchGeo trainers."""
from .cyclone import CycloneDataModule, CycloneSimpleRegressionTask
__all__ = ("CycloneDataModule", "CycloneSimpleRegressionTask")
# https://stackoverflow.com/questions/40018681
for module in __all__:
globals()[module].__module__ = "torchgeo.trainers"
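
A self-contained sketch of the re-export pattern above, with a hypothetical ``_Task`` class standing in for a real trainer: rewriting ``__module__`` makes the class report the package rather than the defining submodule, which is what Sphinx and ``repr()`` use.

class _Task:
    pass

__all__ = ("_Task",)

# Same loop as in torchgeo/trainers/__init__.py above
for name in __all__:
    globals()[name].__module__ = "torchgeo.trainers"

assert _Task.__module__ == "torchgeo.trainers"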

@@ -0,0 +1,236 @@
"""NASA Cyclone dataset trainer."""
from typing import Any, Dict, Optional
import pytorch_lightning as pl
import torch
import torch.nn.functional as F
from sklearn.model_selection import GroupShuffleSplit
from torch import Tensor
from torch.nn.modules import Module
from torch.optim.lr_scheduler import ReduceLROnPlateau
from torch.utils.data import DataLoader, Subset
from ..datasets import TropicalCycloneWindEstimation
# https://github.com/pytorch/pytorch/issues/60979
# https://github.com/pytorch/pytorch/pull/61045
DataLoader.__module__ = "torch.utils.data"
Module.__module__ = "torch.nn"
class CycloneSimpleRegressionTask(pl.LightningModule):
"""LightningModule for training models on the NASA Cyclone Dataset using MSE loss.
This does not take into account other per-sample features available in this dataset.
"""
def __init__(self, model: Module, **kwargs: Dict[str, Any]) -> None:
"""Initialize a new LightningModule for training simple regression models.
Args:
model: A model (specifically, a ``nn.Module``) instance to be trained.
"""
super().__init__()
self.save_hyperparameters() # creates `self.hparams` from kwargs
self.model = model
def forward(self, x: Tensor) -> Any: # type: ignore[override]
"""Forward pass of the model."""
return self.model(x)
# NOTE: See https://github.com/PyTorchLightning/pytorch-lightning/issues/5023 for
# why we need to tell mypy to ignore a bunch of things
def training_step( # type: ignore[override]
self, batch: Dict[str, Any], batch_idx: int
) -> Tensor:
"""Training step with an MSE loss. Reports MSE and RMSE."""
x = batch["image"]
y = batch["wind_speed"].view(-1, 1)
y_hat = self.forward(x)
loss = F.mse_loss(y_hat, y)
self.log("train_loss", loss) # logging to TensorBoard
rmse = torch.sqrt(loss) # type: ignore[attr-defined]
self.log("train_rmse", rmse)
return loss
def validation_step( # type: ignore[override]
self, batch: Dict[str, Any], batch_idx: int
) -> None:
"""Validation step - reports MSE and RMSE."""
x = batch["image"]
y = batch["wind_speed"].view(-1, 1)
y_hat = self.forward(x)
loss = F.mse_loss(y_hat, y)
self.log("val_loss", loss)
rmse = torch.sqrt(loss) # type: ignore[attr-defined]
self.log("val_rmse", rmse)
def test_step( # type: ignore[override]
self, batch: Dict[str, Any], batch_idx: int
) -> None:
"""Test step identical to the validation step. Reports MSE and RMSE."""
x = batch["image"]
y = batch["wind_speed"].view(-1, 1)
y_hat = self.forward(x)
loss = F.mse_loss(y_hat, y)
self.log("test_loss", loss)
rmse = torch.sqrt(loss) # type: ignore[attr-defined]
self.log("test_rmse", rmse)
def configure_optimizers(self) -> Dict[str, Any]:
"""Initialize the optimizer and learning rate scheduler."""
optimizer = torch.optim.Adam(
self.model.parameters(),
lr=self.hparams["learning_rate"], # type: ignore[index]
)
return {
"optimizer": optimizer,
"lr_scheduler": {
"scheduler": ReduceLROnPlateau(
optimizer,
patience=self.hparams[
"learning_rate_schedule_patience"
], # type: ignore[index]
),
"monitor": "val_loss",
},
}
class CycloneDataModule(pl.LightningDataModule):
"""LightningDataModule implementation for the NASA Cyclone dataset.
Implements 80/20 train/val splits based on hurricane storm ids.
See :func:`setup` for more details.
"""
def __init__(
self,
root_dir: str,
seed: int,
batch_size: int = 64,
num_workers: int = 4,
api_key: Optional[str] = None,
) -> None:
"""Initialize a LightningDataModule for NASA Cyclone based DataLoaders.
Args:
root_dir: The ``root`` argument to pass to the
TropicalCycloneWindEstimation Dataset classes
seed: The seed value to use when doing the sklearn based GroupShuffleSplit
batch_size: The batch size to use in all created DataLoaders
num_workers: The number of workers to use in all created DataLoaders
api_key: The RadiantEarth MLHub API key to use if the dataset needs to be
downloaded
"""
super().__init__() # type: ignore[no-untyped-call]
self.root_dir = root_dir
self.seed = seed
self.batch_size = batch_size
self.num_workers = num_workers
self.api_key = api_key
# TODO: This needs to be converted to actual transforms instead of hacked
def custom_transform(self, sample: Dict[str, Any]) -> Dict[str, Any]:
"""Transform a single sample from the Dataset."""
sample["image"] = sample["image"] / 255.0 # scale to [0,1]
sample["image"] = (
sample["image"].unsqueeze(0).repeat(3, 1, 1)
) # convert to 3 channel
sample["wind_speed"] = torch.as_tensor( # type: ignore[attr-defined]
sample["wind_speed"]
).float()
return sample
def prepare_data(self) -> None:
"""Initialize the main ``Dataset`` objects for use in :func:`setup`.
This includes optionally downloading the dataset. This is done once per node,
while :func:`setup` is done once per GPU.
"""
do_download = self.api_key is not None
self.all_train_dataset = TropicalCycloneWindEstimation(
self.root_dir,
split="train",
transforms=self.custom_transform,
download=do_download,
api_key=self.api_key,
)
self.all_test_dataset = TropicalCycloneWindEstimation(
self.root_dir,
split="test",
transforms=self.custom_transform,
download=do_download,
api_key=self.api_key,
)
def setup(self, stage: Optional[str] = None) -> None:
"""Create the train/val/test splits based on the original Dataset objects.
The splits should be done here vs. in :func:`__init__` per the docs:
https://pytorch-lightning.readthedocs.io/en/latest/extensions/datamodules.html#setup.
We split samples between train and val by the ``storm_id`` property, i.e. all
samples with the same ``storm_id`` value will be in either the train or the val
split. This tests one type of generalizability: given a new storm, can we
predict its wind speed? The test set, however, contains *some* storms from the
training set (specifically, the latter parts of the storms) as well as some
novel storms.
"""
storm_ids = []
for item in self.all_train_dataset.collection:
storm_id = item["href"].split("/")[0].split("_")[-2]
storm_ids.append(storm_id)
train_indices, val_indices = next(
GroupShuffleSplit(test_size=0.2, n_splits=2, random_state=self.seed).split(
storm_ids, groups=storm_ids
)
)
self.train_dataset = Subset(self.all_train_dataset, train_indices)
self.val_dataset = Subset(self.all_train_dataset, val_indices)
self.test_dataset = Subset(
self.all_test_dataset, range(len(self.all_test_dataset))
)
def train_dataloader(self) -> DataLoader[Any]:
"""Return a DataLoader for training."""
return DataLoader(
self.train_dataset,
batch_size=self.batch_size,
num_workers=self.num_workers,
shuffle=True,
pin_memory=True,
)
def val_dataloader(self) -> DataLoader[Any]:
"""Return a DataLoader for validation."""
return DataLoader(
self.val_dataset,
batch_size=self.batch_size,
num_workers=self.num_workers,
shuffle=False,
pin_memory=True,
)
def test_dataloader(self) -> DataLoader[Any]:
"""Return a DataLoader for testing."""
return DataLoader(
self.test_dataset,
batch_size=self.batch_size,
num_workers=self.num_workers,
shuffle=False,
pin_memory=True,
)
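
For illustration, a minimal sketch of the storm-wise behavior of ``GroupShuffleSplit`` as used in ``setup`` above, with hypothetical storm ids: every sample of a given storm lands entirely in train or entirely in val, so validation measures generalization to unseen storms.

from sklearn.model_selection import GroupShuffleSplit

storm_ids = ["a", "a", "a", "b", "b", "c", "c", "c", "d", "e"]
splitter = GroupShuffleSplit(test_size=0.2, n_splits=2, random_state=0)
train_indices, val_indices = next(splitter.split(storm_ids, groups=storm_ids))

# No storm id appears on both sides of the split
train_storms = {storm_ids[i] for i in train_indices}
val_storms = {storm_ids[i] for i in val_indices}
assert train_storms.isdisjoint(val_storms)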

train.py Executable file (181 changes)

@@ -0,0 +1,181 @@
#!/usr/bin/env python3
"""torchgeo model training script."""
import argparse
import os
import pytorch_lightning as pl
from pytorch_lightning import loggers as pl_loggers
from pytorch_lightning.callbacks import EarlyStopping, ModelCheckpoint
from torchvision import models
from torchgeo.trainers import CycloneDataModule, CycloneSimpleRegressionTask
def set_up_parser() -> argparse.ArgumentParser:
"""Set up the argument parser with program level arguments.
Returns:
the argument parser
"""
parser = argparse.ArgumentParser(
description=__doc__, formatter_class=argparse.ArgumentDefaultsHelpFormatter
)
# Add _program_ level arguments to the parser
parser.add_argument(
"--batch_size",
type=int,
default=32,
help="Batch size to use in training",
)
parser.add_argument(
"--num_workers",
type=int,
default=4,
help="Number of workers to use in the Dataloaders",
)
parser.add_argument(
"--seed",
type=int,
default=1337,
help="Random number generator seed for numpy and torch",
)
parser.add_argument(
"--experiment_name",
type=str,
required=True,
help="Name of this experiment (used in TensorBoard and as the subdirectory "
+ "name to save results)",
)
parser.add_argument(
"--output_dir",
type=str,
default="output",
help="Directory to store experiment results",
)
parser.add_argument(
"--data_dir",
type=str,
default="data",
help="Directory where datasets are/will be stored",
)
parser.add_argument(
"--log_dir",
type=str,
default="logs",
help="Directory where logs will be stored.",
)
parser.add_argument(
"--overwrite",
action="store_true",
default=False,
help="Flag to enable overwriting existing output",
)
# TODO: may want to eventually switch to an OmegaConf based configuration system
# See https://pytorch-lightning.readthedocs.io/en/latest/common/hyperparameters.html
# for best practices here
# Add _trainer_ level arguments to the parser
parser = pl.Trainer.add_argparse_args(parser)
# TODO: Add _task_ level arguments to the parser for each _task_ we have implemented
parser.add_argument(
"--learning_rate",
type=float,
default=1e-3,
help="Learning rate",
)
parser.add_argument(
"--learning_rate_schedule_patience",
type=int,
default=2,
help="Patience factor for the ReduceLROnPlateau schedule",
)
return parser
def main(args: argparse.Namespace) -> None:
"""Main training loop."""
######################################
# Setup output directory
######################################
if os.path.isfile(args.output_dir):
raise NotADirectoryError("`--output_dir` must be a directory")
os.makedirs(args.output_dir, exist_ok=True)
experiment_dir = os.path.join(args.output_dir, args.experiment_name)
os.makedirs(experiment_dir, exist_ok=True)
if len(os.listdir(experiment_dir)) > 0:
if args.overwrite:
# TODO: convert this to logging.WARNING
print(
f"WARNING! The experiment directory, {experiment_dir}, already exists, "
+ "we might overwrite data in it!"
)
else:
raise FileExistsError(
f"The experiment directory, {experiment_dir}, already exists and isn't "
+ "empty. We don't want to overwrite any existing results, exiting..."
)
######################################
# Choose task to run based on arguments or configuration
######################################
# TODO: Logic to switch between tasks
model = models.resnet18(pretrained=False, num_classes=1)
datamodule = CycloneDataModule(
args.data_dir,
seed=args.seed,
batch_size=args.batch_size,
num_workers=args.num_workers,
)
# Convert the argparse Namespace into a dictionary so that we can pass as kwargs
dict_args = vars(args)
task = CycloneSimpleRegressionTask(model, **dict_args)
######################################
# Setup trainer
######################################
tb_logger = pl_loggers.TensorBoardLogger(args.log_dir, name=args.experiment_name)
checkpoint_callback = ModelCheckpoint(
monitor="val_loss",
dirpath=experiment_dir,
save_top_k=3,
save_last=True,
)
early_stopping_callback = EarlyStopping(
monitor="val_loss",
min_delta=0.00,
patience=10,
)
trainer = pl.Trainer.from_argparse_args(
args, logger=tb_logger, callbacks=[checkpoint_callback, early_stopping_callback]
)
######################################
# Run experiment
######################################
trainer.fit(model=task, datamodule=datamodule)
trainer.test(model=task, datamodule=datamodule)
if __name__ == "__main__":
parser = set_up_parser()
args = parser.parse_args()
# Set random seed for reproducibility
# https://pytorch-lightning.readthedocs.io/en/latest/api/pytorch_lightning.utilities.seed.html#pytorch_lightning.utilities.seed.seed_everything
pl.seed_everything(args.seed)
# Main training procedure
main(args)
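
For reference, a minimal sketch of driving the same components without the CLI, assuming torchgeo is installed with the new ``train`` extras; ``fast_dev_run=True`` keeps the run to a single batch.

import pytorch_lightning as pl
from torchvision import models

from torchgeo.trainers import CycloneDataModule, CycloneSimpleRegressionTask

model = models.resnet18(pretrained=False, num_classes=1)
task = CycloneSimpleRegressionTask(
    model, learning_rate=1e-3, learning_rate_schedule_patience=2
)
datamodule = CycloneDataModule("data", seed=1337, batch_size=32, num_workers=4)

# Roughly equivalent to `python train.py --experiment_name test --fast_dev_run 1`
trainer = pl.Trainer(fast_dev_run=True)
trainer.fit(model=task, datamodule=datamodule)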