* Fix lightning imports

* Style fixes
This commit is contained in:
Adam J. Stewart 2023-03-19 10:48:55 -06:00 коммит произвёл Caleb Robinson
Родитель 8efda7845c
Коммит 3b4436f7e7
18 изменённых файлов: 32 добавлений и 31 удалений

Просмотреть файл

@ -7,7 +7,7 @@ import csv
import os
import time
import lightning as pl
import lightning.pytorch as pl
import torch
import torch.nn as nn
import torch.optim as optim

Просмотреть файл

@ -65,9 +65,9 @@
"import matplotlib.pyplot as plt\n",
"import numpy as np\n",
"import timm\n",
"from lightning import Trainer\n",
"from lightning.callbacks import EarlyStopping, ModelCheckpoint\n",
"from lightning.loggers import CSVLogger\n",
"from lightning.pytorch import Trainer\n",
"from lightning.pytorch.callbacks import EarlyStopping, ModelCheckpoint\n",
"from lightning.pytorch.loggers import CSVLogger\n",
"\n",
"from torchgeo.datamodules import EuroSATDataModule\n",
"from torchgeo.trainers import ClassificationTask\n",

Просмотреть файл

@ -74,9 +74,9 @@
"\n",
"import numpy as np\n",
"import matplotlib.pyplot as plt\n",
"from lightning import Trainer\n",
"from lightning.callbacks import EarlyStopping, ModelCheckpoint\n",
"from lightning.loggers import CSVLogger\n",
"from lightning.pytorch import Trainer\n",
"from lightning.pytorch.callbacks import EarlyStopping, ModelCheckpoint\n",
"from lightning.pytorch.loggers import CSVLogger\n",
"\n",
"from torchgeo.datamodules import TropicalCycloneDataModule\n",
"from torchgeo.trainers import RegressionTask"

Просмотреть файл

@ -10,7 +10,7 @@ import csv
import os
from typing import Any, Dict, Union, cast
import lightning as pl
import lightning.pytorch as pl
import torch
from torchmetrics import MetricCollection
from torchmetrics.classification import BinaryAccuracy, BinaryJaccardIndex

Просмотреть файл

@ -8,7 +8,7 @@ import argparse
import csv
import os
from lightning import Trainer
from lightning.pytorch import Trainer
from torchgeo.datamodules import ChesapeakeCVPRDataModule
from torchgeo.trainers.chesapeake import ChesapeakeCVPRSegmentationTask

Просмотреть файл

@ -5,7 +5,7 @@ import os
import pytest
from _pytest.fixtures import SubRequest
from lightning import Trainer
from lightning.pytorch import Trainer
from torchgeo.datamodules import OSCDDataModule

Просмотреть файл

@ -12,7 +12,7 @@ import torch.nn as nn
import torchvision
from _pytest.fixtures import SubRequest
from _pytest.monkeypatch import MonkeyPatch
from lightning import LightningDataModule, Trainer
from lightning.pytorch import LightningDataModule, Trainer
from omegaconf import OmegaConf
from torchvision.models import resnet18
from torchvision.models._api import WeightsEnum

Просмотреть файл

@ -11,7 +11,7 @@ import torch
import torchvision
from _pytest.fixtures import SubRequest
from _pytest.monkeypatch import MonkeyPatch
from lightning import LightningDataModule, Trainer
from lightning.pytorch import LightningDataModule, Trainer
from omegaconf import OmegaConf
from torch.nn.modules import Module
from torchvision.models._api import WeightsEnum

Просмотреть файл

@ -9,7 +9,7 @@ import torch
import torch.nn as nn
import torchvision.models.detection
from _pytest.monkeypatch import MonkeyPatch
from lightning import LightningDataModule, Trainer
from lightning.pytorch import LightningDataModule, Trainer
from omegaconf import OmegaConf
from torch.nn.modules import Module

Просмотреть файл

@ -11,7 +11,7 @@ import torch
import torchvision
from _pytest.fixtures import SubRequest
from _pytest.monkeypatch import MonkeyPatch
from lightning import LightningDataModule, Trainer
from lightning.pytorch import LightningDataModule, Trainer
from omegaconf import OmegaConf
from torchvision.models._api import WeightsEnum

Просмотреть файл

@ -7,7 +7,7 @@ from typing import Any, Dict, Type, cast
import pytest
import segmentation_models_pytorch as smp
from _pytest.monkeypatch import MonkeyPatch
from lightning import LightningDataModule, Trainer
from lightning.pytorch import LightningDataModule, Trainer
from omegaconf import OmegaConf
from torch.nn.modules import Module

Просмотреть файл

@ -8,7 +8,7 @@ from typing import Any, Callable, Dict, Optional, Tuple, Type, Union
import kornia.augmentation as K
import matplotlib.pyplot as plt
import torch
from lightning import LightningDataModule
from lightning.pytorch import LightningDataModule
from torch import Tensor
from torch.utils.data import DataLoader, Dataset, default_collate

Просмотреть файл

@ -11,7 +11,7 @@ import torch
import torch.nn as nn
import torch.nn.functional as F
from kornia import augmentation as K
from lightning import LightningModule
from lightning.pytorch import LightningModule
from torch import Tensor, optim
from torch.optim.lr_scheduler import ReduceLROnPlateau
from torchvision.models._api import WeightsEnum

Просмотреть файл

@ -10,7 +10,7 @@ import matplotlib.pyplot as plt
import timm
import torch
import torch.nn as nn
from lightning import LightningModule
from lightning.pytorch import LightningModule
from segmentation_models_pytorch.losses import FocalLoss, JaccardLoss
from torch import Tensor
from torch.optim.lr_scheduler import ReduceLROnPlateau

Просмотреть файл

@ -9,7 +9,7 @@ from typing import Any, Dict, List, cast
import matplotlib.pyplot as plt
import torch
import torchvision.models.detection
from lightning import LightningModule
from lightning.pytorch import LightningModule
from torch import Tensor
from torch.optim.lr_scheduler import ReduceLROnPlateau
from torchmetrics.detection.mean_ap import MeanAveragePrecision

Просмотреть файл

@ -10,7 +10,7 @@ import matplotlib.pyplot as plt
import timm
import torch
import torch.nn.functional as F
from lightning import LightningModule
from lightning.pytorch import LightningModule
from torch import Tensor
from torch.optim.lr_scheduler import ReduceLROnPlateau
from torchmetrics import MeanAbsoluteError, MeanSquaredError, MetricCollection

Просмотреть файл

@ -10,7 +10,7 @@ import matplotlib.pyplot as plt
import segmentation_models_pytorch as smp
import torch
import torch.nn as nn
from lightning import LightningModule
from lightning.pytorch import LightningModule
from torch import Tensor
from torch.optim.lr_scheduler import ReduceLROnPlateau
from torchmetrics import MetricCollection

Просмотреть файл

@ -8,9 +8,10 @@
import os
from typing import Any, Dict, Tuple, Type, cast
import lightning as L
from lightning import loggers
from lightning.callbacks import EarlyStopping, ModelCheckpoint
import lightning.pytorch as pl
from lightning.pytorch import LightningDataModule, LightningModule, Trainer
from lightning.pytorch.callbacks import EarlyStopping, ModelCheckpoint
from lightning.pytorch.loggers import CSVLogger, TensorBoardLogger
from omegaconf import DictConfig, OmegaConf
from torchgeo.datamodules import (
@ -45,7 +46,7 @@ from torchgeo.trainers import (
)
TASK_TO_MODULES_MAPPING: Dict[
str, Tuple[Type[L.LightningModule], Type[L.LightningDataModule]]
str, Tuple[Type[LightningModule], Type[LightningDataModule]]
] = {
"bigearthnet": (MultiLabelClassificationTask, BigEarthNetDataModule),
"byol": (BYOLTask, ChesapeakeCVPRDataModule),
@ -165,8 +166,8 @@ def main(conf: DictConfig) -> None:
Dict[str, Any], OmegaConf.to_object(conf.experiment.datamodule)
)
datamodule: L.LightningDataModule
task: L.LightningModule
datamodule: LightningDataModule
task: LightningModule
if task_name in TASK_TO_MODULES_MAPPING:
task_class, datamodule_class = TASK_TO_MODULES_MAPPING[task_name]
task = task_class(**task_args)
@ -179,8 +180,8 @@ def main(conf: DictConfig) -> None:
######################################
# Setup trainer
######################################
tb_logger = loggers.TensorBoardLogger(conf.program.log_dir, name=experiment_name)
csv_logger = loggers.CSVLogger(conf.program.log_dir, name=experiment_name)
tb_logger = TensorBoardLogger(conf.program.log_dir, name=experiment_name)
csv_logger = CSVLogger(conf.program.log_dir, name=experiment_name)
if isinstance(task, ObjectDetectionTask):
monitor_metric = "val_map"
@ -205,7 +206,7 @@ def main(conf: DictConfig) -> None:
trainer_args["callbacks"] = [checkpoint_callback, early_stopping_callback]
trainer_args["logger"] = [tb_logger, csv_logger]
trainer_args["default_root_dir"] = experiment_dir
trainer = L.Trainer(**trainer_args)
trainer = Trainer(**trainer_args)
######################################
# Run experiment
@ -229,7 +230,7 @@ if __name__ == "__main__":
# Set random seed for reproducibility
# https://pytorch-lightning.readthedocs.io/en/latest/api/pytorch_lightning.utilities.seed.html#pytorch_lightning.utilities.seed.seed_everything
L.seed_everything(conf.program.seed)
pl.seed_everything(conf.program.seed)
# Main training procedure
main(conf)