Refactoring how the trainer modules are called from train.py and how the configuration files are structured

This commit is contained in:
Caleb Robinson 2021-09-04 00:51:55 +00:00
Parent 8e19755d24
Commit ccb7f1912d
14 changed files with 435 additions and 116 deletions
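As a quick orientation (an illustration, not part of the diff itself), the renamed override keys introduced by this commit look roughly like this when invoking train.py; the old names are noted in comments and mirror the updated tests further down:

args = [
    "python", "train.py",
    "experiment.name=test",                    # was: program.experiment_name=test
    "experiment.task=cyclone",                 # was: task.name=cyclone
    "experiment.module.learning_rate=1e-3",    # per-task module kwargs now live under experiment.module
    "experiment.datamodule.batch_size=32",     # datamodule kwargs now live under experiment.datamodule
    "trainer.fast_dev_run=1",
]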

View file

@ -1,21 +1,27 @@
config_file: null # The user can pass a filename here on the command line
config_file: null # This lets the user pass a config filename to load other arguments from
program: # These are the default arguments
batch_size: 32
num_workers: 4
program: # These are the arguments that define how the train.py script works
seed: 1337
experiment_name: ??? # This is OmegaConf syntax that makes this a required field
output_dir: output
data_dir: data
log_dir: logs
overwrite: False
task:
name: ??? # this must be defined so we can get the task specific arguments
experiment: # These are arguments specific to the experiment we are running
name: ??? # this is the name given to this experiment run
task: ??? # this is the type of task to use for this experiment (e.g. "landcoverai")
module: # these will be passed as kwargs to the LightningModule associated with the task
learning_rate: 1e-3
datamodule: # these will be passed as kwargs to the LightningDataModule associated with the task
root_dir: ${program.data_dir}
seed: ${program.seed}
batch_size: 32
num_workers: 4
# Taken from https://pytorch-lightning.readthedocs.io/en/1.3.8/common/trainer.html#init
# The values here are taken from the defaults at https://pytorch-lightning.readthedocs.io/en/1.3.8/common/trainer.html#init
# This should probably be made into a schema, e.g. as shown at https://omegaconf.readthedocs.io/en/2.0_branch/structured_config.html#merging-with-other-configs
trainer:
trainer: # These are the parameters passed to the pytorch lightning Trainer object
logger: True
checkpoint_callback: True
callbacks: null
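The comment above suggests turning the trainer block into an OmegaConf structured config. A minimal sketch of what that could look like (not part of this commit; the TrainerConf dataclass and the conf/defaults.yaml path are assumed for illustration):

from dataclasses import dataclass
from typing import Any, Optional

from omegaconf import OmegaConf


@dataclass
class TrainerConf:
    # Hypothetical schema; fields mirror the trainer keys shown above.
    logger: bool = True
    checkpoint_callback: bool = True
    callbacks: Optional[Any] = None


schema = OmegaConf.structured(TrainerConf)
conf = OmegaConf.load("conf/defaults.yaml")  # assumed path to the config above
# Merging against the schema rejects unknown or mistyped trainer keys at runtime.
conf.trainer = OmegaConf.merge(schema, conf.trainer)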

View file

@ -0,0 +1,13 @@
experiment:
task: "chesapeake_cvpr"
module:
loss: "ce"
segmentation_model: "unet"
encoder_name: "resnet18"
encoder_weights: "imagenet"
encoder_output_stride: 16
learning_rate: 1e-3
learning_rate_schedule_patience: 2
datamodule:
batch_size: 32
num_workers: 4
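As a rough illustration (not part of the diff), the two blocks above map directly onto constructor kwargs; this mirrors how train.py below consumes them:

from omegaconf import OmegaConf

conf = OmegaConf.load("conf/task_defaults/chesapeake_cvpr.yaml")
# "module" becomes kwargs for the LightningModule and "datamodule" becomes
# kwargs for the LightningDataModule, matching the dispatch code in train.py.
module_kwargs = OmegaConf.to_object(conf.experiment.module)
datamodule_kwargs = OmegaConf.to_object(conf.experiment.datamodule)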

View file

@ -1,5 +1,9 @@
task:
name: "cyclone"
learning_rate: 1e-3
learning_rate_schedule_patience: 2
model: "resnet18"
experiment:
task: "cyclone"
module:
model: "resnet18"
learning_rate: 1e-3
learning_rate_schedule_patience: 2
datamodule:
batch_size: 32
num_workers: 4

View file

@ -1,10 +1,14 @@
task:
name: "landcoverai"
optimizer: "adamw"
learning_rate: 1e-3
learning_rate_schedule_patience: 2
loss: "ce"
segmentation_model: "deeplabv3+"
encoder_name: "resnet34"
encoder_weights: "imagenet"
encoder_output_stride: 16
experiment:
task: "landcoverai"
module:
loss: "ce"
segmentation_model: "deeplabv3+"
encoder_name: "resnet34"
encoder_weights: "imagenet"
encoder_output_stride: 16
optimizer: "adamw"
learning_rate: 1e-3
learning_rate_schedule_patience: 2
datamodule:
batch_size: 32
num_workers: 4

View file

@ -1,10 +1,14 @@
task:
name: "naipchesapeake"
optimizer: "adamw"
learning_rate: 1e-3
learning_rate_schedule_patience: 2
loss: "ce"
segmentation_model: "deeplabv3+"
encoder_name: "resnet34"
encoder_weights: "imagenet"
encoder_output_stride: 16
experiment:
task: "naipchesapeake"
module:
loss: "ce"
segmentation_model: "deeplabv3+"
encoder_name: "resnet34"
encoder_weights: "imagenet"
encoder_output_stride: 16
optimizer: "adamw"
learning_rate: 1e-3
learning_rate_schedule_patience: 2
datamodule:
batch_size: 32
num_workers: 4

View file

@ -1,8 +1,13 @@
task:
name: "sen12ms"
learning_rate: 1e-3
learning_rate_schedule_patience: 2
loss: "ce"
segmentation_model: "unet"
encoder_name: "resnet18"
encoder_weights: "imagenet"
experiment:
task: "sen12ms"
module:
loss: "ce"
segmentation_model: "unet"
encoder_name: "resnet18"
encoder_weights: "imagenet"
encoder_output_stride: 16
learning_rate: 1e-3
learning_rate_schedule_patience: 2
datamodule:
batch_size: 32
num_workers: 4

View file

@ -25,9 +25,9 @@ def test_output_file(tmp_path: Path) -> None:
args = [
sys.executable,
"train.py",
"program.experiment_name=test",
"experiment.name=test",
"program.output_dir=" + str(output_file),
"task.name=test",
"experiment.task=test",
]
ps = subprocess.run(args, stdout=subprocess.PIPE, stderr=subprocess.PIPE)
assert ps.returncode != 0
@ -43,9 +43,9 @@ def test_experiment_dir_not_empty(tmp_path: Path) -> None:
args = [
sys.executable,
"train.py",
"program.experiment_name=test",
"experiment.name=test",
"program.output_dir=" + str(output_dir),
"task.name=test",
"experiment.task=test",
]
ps = subprocess.run(args, stdout=subprocess.PIPE, stderr=subprocess.PIPE)
assert ps.returncode != 0
@ -64,11 +64,11 @@ def test_overwrite_experiment_dir(tmp_path: Path) -> None:
args = [
sys.executable,
"train.py",
"program.experiment_name=test",
"experiment.name=test",
"program.output_dir=" + str(output_dir),
"program.data_dir=" + data_dir,
"program.log_dir=" + str(log_dir),
"task.name=cyclone",
"experiment.task=cyclone",
"program.overwrite=True",
"trainer.fast_dev_run=1",
]
@ -87,9 +87,9 @@ def test_invalid_task(task: str, tmp_path: Path) -> None:
args = [
sys.executable,
"train.py",
"program.experiment_name=test",
"experiment.name=test",
"program.output_dir=" + str(output_dir),
"task.name=" + task,
"experiment.task=" + task,
]
ps = subprocess.run(args, stdout=subprocess.PIPE, stderr=subprocess.PIPE)
assert ps.returncode != 0
@ -102,9 +102,9 @@ def test_missing_config_file(tmp_path: Path) -> None:
args = [
sys.executable,
"train.py",
"program.experiment_name=test",
"experiment.name=test",
"program.output_dir=" + str(output_dir),
"task.name=test",
"experiment.task=test",
"config_file=" + str(config_file),
]
ps = subprocess.run(args, stdout=subprocess.PIPE, stderr=subprocess.PIPE)
@ -120,12 +120,12 @@ def test_config_file(tmp_path: Path) -> None:
config_file.write_text(
f"""
program:
experiment_name: test
output_dir: {output_dir}
data_dir: {data_dir}
log_dir: {log_dir}
task:
name: cyclone
experiment:
name: test
task: cyclone
trainer:
fast_dev_run: true
"""
@ -146,12 +146,12 @@ def test_tasks(task: str, tmp_path: Path) -> None:
args = [
sys.executable,
"train.py",
"program.experiment_name=test",
"experiment.name=test",
"program.output_dir=" + str(output_dir),
"program.data_dir=" + data_dir,
"program.log_dir=" + str(log_dir),
"trainer.fast_dev_run=1",
"task.name=" + task,
"experiment.task=" + task,
"program.overwrite=True",
]
subprocess.run(args, check=True)
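A small sketch (assuming train.py parses its command line with OmegaConf.from_cli, consistent with the config merging shown below) of how these dotted overrides become nested config values:

from omegaconf import OmegaConf

# Dotted overrides like the ones passed above parse into nested config nodes.
cli_conf = OmegaConf.from_cli(
    ["experiment.name=test", "experiment.task=cyclone", "trainer.fast_dev_run=1"]
)
assert cli_conf.experiment.task == "cyclone"
assert cli_conf.trainer.fast_dev_run == 1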

View file

@ -3,12 +3,15 @@
"""TorchGeo trainers."""
from .chesapeake import ChesapeakeCVPRDataModule, ChesapeakeCVPRSegmentationTask
from .cyclone import CycloneDataModule, CycloneSimpleRegressionTask
from .landcoverai import LandcoverAIDataModule, LandcoverAISegmentationTask
from .naipchesapeake import NAIPChesapeakeDataModule, NAIPChesapeakeSegmentationTask
from .sen12ms import SEN12MSDataModule, SEN12MSSegmentationTask
__all__ = (
"ChesapeakeCVPRSegmentationTask",
"ChesapeakeCVPRDataModule",
"CycloneDataModule",
"CycloneSimpleRegressionTask",
"LandcoverAIDataModule",

View file

@ -0,0 +1,292 @@
# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT License.
"""Trainers for the Chesapeake datasets."""
from typing import Any, Dict, Optional, cast
import segmentation_models_pytorch as smp
import torch
import torch.nn as nn
from pytorch_lightning.core.datamodule import LightningDataModule
from pytorch_lightning.core.lightning import LightningModule
from torch import Tensor
from torch.nn.modules import Module
from torch.optim.lr_scheduler import ReduceLROnPlateau
from torch.utils.data import DataLoader, Subset
from torchmetrics import Accuracy
from ..datasets import ChesapeakeCVPR
# https://github.com/pytorch/pytorch/issues/60979
# https://github.com/pytorch/pytorch/pull/61045
DataLoader.__module__ = "torch.utils.data"
Module.__module__ = "torch.nn"
class ChesapeakeCVPRSegmentationTask(LightningModule):
"""LightningModule for training models on the Chesapeake CVPR Land Cover Dataset.
This allows using arbitrary models and losses from the
``segmentation_models_pytorch`` package.
"""
def config_task(self, kwargs: Dict[str, Any]) -> None:
"""Configures the task based on kwargs parameters."""
if kwargs["segmentation_model"] == "unet":
self.model = smp.Unet(
encoder_name=kwargs["encoder_name"],
encoder_weights=kwargs["encoder_weights"],
in_channels=4,
classes=6,
)
else:
raise ValueError(
f"Model type '{kwargs['segmentation_model']}' is not valid."
)
if kwargs["loss"] == "ce":
self.loss = nn.CrossEntropyLoss() # type: ignore[attr-defined]
elif kwargs["loss"] == "jaccard":
self.loss = smp.losses.JaccardLoss(mode="multiclass")
else:
raise ValueError(f"Loss type '{kwargs['loss']}' is not valid.")
def __init__(
self,
**kwargs: Any,
) -> None:
"""Initialize the LightningModule with a model and loss function.
Keyword Args:
segmentation_model: Name of the segmentation model type to use
encoder_name: Name of the encoder model backbone to use
encoder_weights: None or "imagenet" to use imagenet pretrained weights in
the encoder model
loss: Name of the loss function
"""
super().__init__()
self.save_hyperparameters() # creates `self.hparams` from kwargs
self.config_task(kwargs)
self.train_accuracy = Accuracy()
self.val_accuracy = Accuracy()
self.test_accuracy = Accuracy()
def forward(self, x: Tensor) -> Any: # type: ignore[override]
"""Forward pass of the model."""
return self.model(x)
def training_step( # type: ignore[override]
self, batch: Dict[str, Any], batch_idx: int
) -> Tensor:
"""Training step - reports average accuracy and average IoU."""
x = batch["image"]
y = batch["mask"]
y_hat = self.forward(x)
y_hat_hard = y_hat.argmax(dim=1)
loss = self.loss(y_hat, y)
self.log("train_loss", loss) # logging to TensorBoard
self.log("train_acc_step", self.train_accuracy(y_hat_hard, y))
return cast(Tensor, loss)
def training_epoch_end(self, outputs: Any) -> None:
"""Logs epoch level training metrics."""
self.log("train_acc_epoch", self.train_accuracy.compute())
def validation_step( # type: ignore[override]
self, batch: Dict[str, Any], batch_idx: int
) -> None:
"""Validation step - reports average accuracy and average IoU."""
x = batch["image"]
y = batch["mask"]
y_hat = self.forward(x)
y_hat_hard = y_hat.argmax(dim=1)
loss = self.loss(y_hat, y)
self.log("val_loss", loss)
self.log("val_acc_step", self.val_accuracy(y_hat_hard, y))
def validation_epoch_end(self, outputs: Any) -> None:
"""Logs epoch level validation metrics."""
self.log("val_acc_epoch", self.val_accuracy.compute())
def test_step( # type: ignore[override]
self, batch: Dict[str, Any], batch_idx: int
) -> None:
"""Test step identical to the validation step."""
x = batch["image"]
y = batch["mask"]
y_hat = self.forward(x)
y_hat_hard = y_hat.argmax(dim=1)
loss = self.loss(y_hat, y)
self.log("test_loss", loss)
self.log("test_acc_step", self.test_accuracy(y_hat_hard, y))
def test_epoch_end(self, outputs: Any) -> None:
"""Logs epoch level test metrics."""
self.log("test_acc_epoch", self.test_accuracy.compute())
def configure_optimizers(self) -> Dict[str, Any]:
"""Initialize the optimizer and learning rate scheduler."""
optimizer = torch.optim.Adam(
self.model.parameters(),
lr=self.hparams["learning_rate"],
)
return {
"optimizer": optimizer,
"lr_scheduler": {
"scheduler": ReduceLROnPlateau(
optimizer,
patience=self.hparams["learning_rate_schedule_patience"],
),
"monitor": "val_loss",
"verbose": True,
},
}
class ChesapeakeCVPRDataModule(LightningDataModule):
"""LightningDataModule implementation for the CVPR Chesapeake Land Cover dataset.
Uses the random spatial split defined per state to partition tiles into train, val,
and test sets.
"""
def __init__(
self,
root_dir: str,
seed: int,
band_set: str = "all",
batch_size: int = 64,
num_workers: int = 4,
**kwargs: Any,
) -> None:
"""Initialize a LightningDataModule for SEN12MS based DataLoaders.
Args:
root_dir: The ``root`` arugment to pass to the SEN12MS Dataset classes
seed: The seed value to use when doing the sklearn based ShuffleSplit
band_set: The subset of S1/S2 bands to use. Options are: "all",
"s1", "s2-all", and "s2-reduced" where the "s2-reduced" set includes:
B2, B3, B4, B8, B11, and B12.
batch_size: The batch size to use in all created DataLoaders
num_workers: The number of workers to use in all created DataLoaders
"""
super().__init__() # type: ignore[no-untyped-call]
self.root_dir = root_dir
self.seed = seed
self.band_set = band_set
self.batch_size = batch_size
self.num_workers = num_workers
def custom_transform(self, sample: Dict[str, Any]) -> Dict[str, Any]:
"""Transform a single sample from the Dataset."""
sample["image"] = sample["image"].float()
if self.band_set == "all":
sample["image"][:2] = sample["image"][:2].clip(-25, 0) / -25
sample["image"][2:] = sample["image"][2:].clip(0, 10000) / 10000
elif self.band_set == "s1":
sample["image"][:2] = sample["image"][:2].clip(-25, 0) / -25
else:
sample["image"][:] = sample["image"][:].clip(0, 10000) / 10000
sample["mask"] = sample["mask"][0, :, :].long()
sample["mask"] = torch.take( # type: ignore[attr-defined]
self.DFC2020_CLASS_MAPPING, sample["mask"]
)
return sample
def prepare_data(self) -> None:
"""Initialize the main ``Dataset`` objects for use in :func:`setup`.
This includes optionally downloading the dataset. This is done once per node,
while :func:`setup` is done once per GPU.
"""
ChesapeakeCVPR(
self.root_dir,
split="train",
bands=self.band_indices,
transforms=self.custom_transform,
download=True,
checksum=False,
)
ChesapeakeCVPR(
self.root_dir,
split="test",
bands=self.band_indices,
transforms=self.custom_transform,
download=True,
checksum=False,
)
def setup(self, stage: Optional[str] = None) -> None:
"""Create the train/val/test splits based on the original Dataset objects.
The splits should be done here vs. in :func:`__init__` per the docs:
https://pytorch-lightning.readthedocs.io/en/latest/extensions/datamodules.html#setup.
"""
self.train_dataset = ChesapeakeCVPR(
self.root_dir,
split="train",
bands=self.band_indices,
transforms=self.custom_transform,
download=True,
checksum=False,
)
self.val_dataset = ChesapeakeCVPR(
self.root_dir,
split="train",
bands=self.band_indices,
transforms=self.custom_transform,
download=True,
checksum=False,
)
self.test_dataset = ChesapeakeCVPR(
self.root_dir,
split="train",
bands=self.band_indices,
transforms=self.custom_transform,
download=True,
checksum=False,
)
def train_dataloader(self) -> DataLoader[Any]:
"""Return a DataLoader for training."""
return DataLoader(
self.train_dataset,
batch_size=self.batch_size,
num_workers=self.num_workers,
shuffle=True,
pin_memory=True,
)
def val_dataloader(self) -> DataLoader[Any]:
"""Return a DataLoader for validation."""
return DataLoader(
self.val_dataset,
batch_size=self.batch_size,
num_workers=self.num_workers,
shuffle=False,
pin_memory=True,
)
def test_dataloader(self) -> DataLoader[Any]:
"""Return a DataLoader for testing."""
return DataLoader(
self.test_dataset,
batch_size=self.batch_size,
num_workers=self.num_workers,
shuffle=False,
pin_memory=True,
)
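For context, a hedged usage sketch (not part of the commit) of how the new task and data module pair is wired together, using the defaults from conf/task_defaults/chesapeake_cvpr.yaml above; the "data" directory and fast_dev_run setting are illustrative assumptions:

import pytorch_lightning as pl

from torchgeo.trainers import ChesapeakeCVPRDataModule, ChesapeakeCVPRSegmentationTask

# Construct the datamodule and task the same way train.py does from the config.
datamodule = ChesapeakeCVPRDataModule(root_dir="data", seed=1337, batch_size=32, num_workers=4)
task = ChesapeakeCVPRSegmentationTask(
    segmentation_model="unet",
    encoder_name="resnet18",
    encoder_weights="imagenet",
    loss="ce",
    learning_rate=1e-3,
    learning_rate_schedule_patience=2,
)
trainer = pl.Trainer(fast_dev_run=True)  # fast_dev_run keeps this to a single batch
trainer.fit(model=task, datamodule=datamodule)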

View file

@ -36,7 +36,7 @@ class CycloneSimpleRegressionTask(pl.LightningModule):
else:
raise ValueError(f"Model type '{kwargs['model']}' is not valid.")
def __init__(self, **kwargs: Dict[str, Any]) -> None:
def __init__(self, **kwargs: Any) -> None:
"""Initialize a new LightningModule for training simple regression models.
Keyword Args:
@ -131,6 +131,7 @@ class CycloneDataModule(pl.LightningDataModule):
batch_size: int = 64,
num_workers: int = 4,
api_key: Optional[str] = None,
**kwargs: Any,
) -> None:
"""Initialize a LightningDataModule for NASA Cyclone based DataLoaders.

View file

@ -67,7 +67,7 @@ class LandcoverAISegmentationTask(pl.LightningModule):
def __init__(
self,
**kwargs: Dict[str, Any],
**kwargs: Any,
) -> None:
"""Initialize the LightningModule with a model and loss function.
@ -237,6 +237,7 @@ class LandcoverAIDataModule(pl.LightningDataModule):
root_dir: str,
batch_size: int = 64,
num_workers: int = 4,
**kwargs: Any,
) -> None:
"""Initialize a LightningDataModule for Landcover.AI based DataLoaders.

View file

@ -73,7 +73,7 @@ class NAIPChesapeakeSegmentationTask(pl.LightningModule):
def __init__(
self,
**kwargs: Dict[str, Any],
**kwargs: Any,
) -> None:
"""Initialize the LightningModule with a model and loss function.
@ -249,6 +249,7 @@ class NAIPChesapeakeDataModule(pl.LightningDataModule):
chesapeake_root_dir: str,
batch_size: int = 64,
num_workers: int = 4,
**kwargs: Any,
) -> None:
"""Initialize a LightningDataModule for NAIP and Chesapeake based DataLoaders.

View file

@ -54,7 +54,7 @@ class SEN12MSSegmentationTask(pl.LightningModule):
def __init__(
self,
**kwargs: Dict[str, Any],
**kwargs: Any,
) -> None:
"""Initialize the LightningModule with a model and loss function.
@ -194,6 +194,7 @@ class SEN12MSDataModule(pl.LightningDataModule):
band_set: str = "all",
batch_size: int = 64,
num_workers: int = 4,
**kwargs: Any,
) -> None:
"""Initialize a LightningDataModule for SEN12MS based DataLoaders.

View file

@ -6,16 +6,16 @@
"""torchgeo model training script."""
import os
from typing import Any, Dict, cast
from typing import Any, Dict, Tuple, Type, cast
import pytorch_lightning as pl
from omegaconf import DictConfig, OmegaConf
from pytorch_lightning import loggers as pl_loggers
from pytorch_lightning.callbacks import EarlyStopping, ModelCheckpoint
from pytorch_lightning.core.datamodule import LightningDataModule
from pytorch_lightning.core.lightning import LightningModule
from torchgeo.trainers import (
ChesapeakeCVPRDataModule,
ChesapeakeCVPRSegmentationTask,
CycloneDataModule,
CycloneSimpleRegressionTask,
LandcoverAIDataModule,
@ -26,6 +26,16 @@ from torchgeo.trainers import (
SEN12MSSegmentationTask,
)
TASK_TO_MODULES_MAPPING: Dict[
str, Tuple[Type[pl.LightningModule], Type[pl.LightningDataModule]]
] = {
"chesapeake_cvpr": (ChesapeakeCVPRSegmentationTask, ChesapeakeCVPRDataModule),
"cyclone": (CycloneSimpleRegressionTask, CycloneDataModule),
"landcoverai": (LandcoverAISegmentationTask, LandcoverAIDataModule),
"naipchesapeake": (NAIPChesapeakeSegmentationTask, NAIPChesapeakeDataModule),
"sen12ms": (SEN12MSSegmentationTask, SEN12MSDataModule),
}
def set_up_omegaconf() -> DictConfig:
"""Loads program arguments from either YAML config files or command line arguments.
@ -65,17 +75,15 @@ def set_up_omegaconf() -> DictConfig:
# These OmegaConf structured configs enforce a schema at runtime, see:
# https://omegaconf.readthedocs.io/en/2.0_branch/structured_config.html#merging-with-other-configs
if conf.task.name == "cyclone":
task_conf = OmegaConf.load("conf/task_defaults/cyclone.yaml")
elif conf.task.name == "landcoverai":
task_conf = OmegaConf.load("conf/task_defaults/landcoverai.yaml")
elif conf.task.name == "sen12ms":
task_conf = OmegaConf.load("conf/task_defaults/sen12ms.yaml")
elif conf.task.name == "test":
task_name = conf.experiment.task
task_config_fn = os.path.join("conf", "task_defaults", f"{task_name}.yaml")
if task_name == "test":
task_conf = OmegaConf.create()
elif os.path.exists(task_config_fn):
task_conf = cast(DictConfig, OmegaConf.load(task_config_fn))
else:
raise ValueError(
f"task.name={conf.task.name} is not recognized as a valid task"
f"experiment.task={task_name} is not recognized as a valid task"
)
conf = OmegaConf.merge(task_conf, conf)
@ -90,11 +98,13 @@ def main(conf: DictConfig) -> None:
# Setup output directory
######################################
experiment_name = conf.experiment.name
task_name = conf.experiment.task
if os.path.isfile(conf.program.output_dir):
raise NotADirectoryError("`program.output_dir` must be a directory")
os.makedirs(conf.program.output_dir, exist_ok=True)
experiment_dir = os.path.join(conf.program.output_dir, conf.program.experiment_name)
experiment_dir = os.path.join(conf.program.output_dir, experiment_name)
os.makedirs(experiment_dir, exist_ok=True)
if len(os.listdir(experiment_dir)) > 0:
@ -109,58 +119,33 @@ def main(conf: DictConfig) -> None:
+ "empty. We don't want to overwrite any existing results, exiting..."
)
with open(os.path.join(experiment_dir, "experiment_config.yaml"), "w") as f:
OmegaConf.save(config=conf, f=f)
######################################
# Choose task to run based on arguments or configuration
######################################
# Convert the DictConfig into a dictionary so that we can pass as kwargs. We use
# var() to convert the @dataclass from to_object() to a dictionary and to help mypy
task_args = OmegaConf.to_object(conf.task)
task_args = cast(Dict[str, Any], task_args)
# Convert the DictConfig into a dictionary so that we can pass as kwargs.
task_args = cast(Dict[str, Any], OmegaConf.to_object(conf.experiment.module))
datamodule_args = cast(
Dict[str, Any], OmegaConf.to_object(conf.experiment.datamodule)
)
datamodule: LightningDataModule
task: LightningModule
if conf.task.name == "cyclone":
datamodule = CycloneDataModule(
conf.program.data_dir,
seed=conf.program.seed,
batch_size=conf.program.batch_size,
num_workers=conf.program.num_workers,
)
task = CycloneSimpleRegressionTask(**task_args)
elif conf.task.name == "landcoverai":
datamodule = LandcoverAIDataModule(
conf.program.data_dir,
batch_size=conf.program.batch_size,
num_workers=conf.program.num_workers,
)
task = LandcoverAISegmentationTask(**task_args)
elif conf.task.name == "naipchesapeake":
datamodule = NAIPChesapeakeDataModule(
conf.program.naip_data_dir,
conf.program.chesapeake_data_dir,
batch_size=conf.program.batch_size,
num_workers=conf.program.num_workers,
)
task = NAIPChesapeakeSegmentationTask(**task_args)
elif conf.task.name == "sen12ms":
datamodule = SEN12MSDataModule(
conf.program.data_dir,
seed=conf.program.seed,
batch_size=conf.program.batch_size,
num_workers=conf.program.num_workers,
)
task = SEN12MSSegmentationTask(**task_args)
datamodule: pl.LightningDataModule
task: pl.LightningModule
if task_name in TASK_TO_MODULES_MAPPING:
task_class, datamodule_class = TASK_TO_MODULES_MAPPING[task_name]
task = task_class(**task_args)
datamodule = datamodule_class(**datamodule_args)
else:
raise ValueError(
f"task.name={conf.task.name} is not recognized as a valid task"
f"experiment.task={task_name} is not recognized as a valid task"
)
######################################
# Setup trainer
######################################
tb_logger = pl_loggers.TensorBoardLogger(
conf.program.log_dir, name=conf.program.experiment_name
)
tb_logger = pl_loggers.TensorBoardLogger(conf.program.log_dir, name=experiment_name)
checkpoint_callback = ModelCheckpoint(
monitor="val_loss",
@ -174,8 +159,7 @@ def main(conf: DictConfig) -> None:
patience=10,
)
trainer_args = OmegaConf.to_object(conf.trainer)
trainer_args = cast(Dict[str, Any], trainer_args)
trainer_args = cast(Dict[str, Any], OmegaConf.to_object(conf.trainer))
trainer_args["callbacks"] = [checkpoint_callback, early_stopping_callback]
trainer_args["logger"] = tb_logger