diff --git a/benchmark.py b/benchmark.py index fb8c30066..e6205ee67 100755 --- a/benchmark.py +++ b/benchmark.py @@ -7,7 +7,7 @@ import csv import os import time -import lightning as pl +import lightning.pytorch as pl import torch import torch.nn as nn import torch.optim as optim diff --git a/docs/tutorials/pretrained_weights.ipynb b/docs/tutorials/pretrained_weights.ipynb index 54443734d..67e6f2344 100644 --- a/docs/tutorials/pretrained_weights.ipynb +++ b/docs/tutorials/pretrained_weights.ipynb @@ -65,9 +65,9 @@ "import matplotlib.pyplot as plt\n", "import numpy as np\n", "import timm\n", - "from lightning import Trainer\n", - "from lightning.callbacks import EarlyStopping, ModelCheckpoint\n", - "from lightning.loggers import CSVLogger\n", + "from lightning.pytorch import Trainer\n", + "from lightning.pytorch.callbacks import EarlyStopping, ModelCheckpoint\n", + "from lightning.pytorch.loggers import CSVLogger\n", "\n", "from torchgeo.datamodules import EuroSATDataModule\n", "from torchgeo.trainers import ClassificationTask\n", diff --git a/docs/tutorials/trainers.ipynb b/docs/tutorials/trainers.ipynb index 59ddd2d6e..c95bf85cc 100644 --- a/docs/tutorials/trainers.ipynb +++ b/docs/tutorials/trainers.ipynb @@ -74,9 +74,9 @@ "\n", "import numpy as np\n", "import matplotlib.pyplot as plt\n", - "from lightning import Trainer\n", - "from lightning.callbacks import EarlyStopping, ModelCheckpoint\n", - "from lightning.loggers import CSVLogger\n", + "from lightning.pytorch import Trainer\n", + "from lightning.pytorch.callbacks import EarlyStopping, ModelCheckpoint\n", + "from lightning.pytorch.loggers import CSVLogger\n", "\n", "from torchgeo.datamodules import TropicalCycloneDataModule\n", "from torchgeo.trainers import RegressionTask" diff --git a/evaluate.py b/evaluate.py index e1cf15945..f8d9cfcc7 100755 --- a/evaluate.py +++ b/evaluate.py @@ -10,7 +10,7 @@ import csv import os from typing import Any, Dict, Union, cast -import lightning as pl +import lightning.pytorch as pl import torch from torchmetrics import MetricCollection from torchmetrics.classification import BinaryAccuracy, BinaryJaccardIndex diff --git a/experiments/test_chesapeakecvpr_models.py b/experiments/test_chesapeakecvpr_models.py index 3347835f5..de2e0ef4e 100755 --- a/experiments/test_chesapeakecvpr_models.py +++ b/experiments/test_chesapeakecvpr_models.py @@ -8,7 +8,7 @@ import argparse import csv import os -from lightning import Trainer +from lightning.pytorch import Trainer from torchgeo.datamodules import ChesapeakeCVPRDataModule from torchgeo.trainers.chesapeake import ChesapeakeCVPRSegmentationTask diff --git a/tests/datamodules/test_oscd.py b/tests/datamodules/test_oscd.py index ba7ad20cd..2daf98e3f 100644 --- a/tests/datamodules/test_oscd.py +++ b/tests/datamodules/test_oscd.py @@ -5,7 +5,7 @@ import os import pytest from _pytest.fixtures import SubRequest -from lightning import Trainer +from lightning.pytorch import Trainer from torchgeo.datamodules import OSCDDataModule diff --git a/tests/trainers/test_byol.py b/tests/trainers/test_byol.py index 2689ba099..6b007e3e2 100644 --- a/tests/trainers/test_byol.py +++ b/tests/trainers/test_byol.py @@ -12,7 +12,7 @@ import torch.nn as nn import torchvision from _pytest.fixtures import SubRequest from _pytest.monkeypatch import MonkeyPatch -from lightning import LightningDataModule, Trainer +from lightning.pytorch import LightningDataModule, Trainer from omegaconf import OmegaConf from torchvision.models import resnet18 from torchvision.models._api import WeightsEnum diff --git a/tests/trainers/test_classification.py b/tests/trainers/test_classification.py index 154a2abf1..fe5fcb6ab 100644 --- a/tests/trainers/test_classification.py +++ b/tests/trainers/test_classification.py @@ -11,7 +11,7 @@ import torch import torchvision from _pytest.fixtures import SubRequest from _pytest.monkeypatch import MonkeyPatch -from lightning import LightningDataModule, Trainer +from lightning.pytorch import LightningDataModule, Trainer from omegaconf import OmegaConf from torch.nn.modules import Module from torchvision.models._api import WeightsEnum diff --git a/tests/trainers/test_detection.py b/tests/trainers/test_detection.py index 87b650eb4..73a9563e3 100644 --- a/tests/trainers/test_detection.py +++ b/tests/trainers/test_detection.py @@ -9,7 +9,7 @@ import torch import torch.nn as nn import torchvision.models.detection from _pytest.monkeypatch import MonkeyPatch -from lightning import LightningDataModule, Trainer +from lightning.pytorch import LightningDataModule, Trainer from omegaconf import OmegaConf from torch.nn.modules import Module diff --git a/tests/trainers/test_regression.py b/tests/trainers/test_regression.py index b5ff02f9b..fa890acd7 100644 --- a/tests/trainers/test_regression.py +++ b/tests/trainers/test_regression.py @@ -11,7 +11,7 @@ import torch import torchvision from _pytest.fixtures import SubRequest from _pytest.monkeypatch import MonkeyPatch -from lightning import LightningDataModule, Trainer +from lightning.pytorch import LightningDataModule, Trainer from omegaconf import OmegaConf from torchvision.models._api import WeightsEnum diff --git a/tests/trainers/test_segmentation.py b/tests/trainers/test_segmentation.py index 1d2f9aba9..f972c7e73 100644 --- a/tests/trainers/test_segmentation.py +++ b/tests/trainers/test_segmentation.py @@ -7,7 +7,7 @@ from typing import Any, Dict, Type, cast import pytest import segmentation_models_pytorch as smp from _pytest.monkeypatch import MonkeyPatch -from lightning import LightningDataModule, Trainer +from lightning.pytorch import LightningDataModule, Trainer from omegaconf import OmegaConf from torch.nn.modules import Module diff --git a/torchgeo/datamodules/geo.py b/torchgeo/datamodules/geo.py index 08f8c5575..0f5711ec1 100644 --- a/torchgeo/datamodules/geo.py +++ b/torchgeo/datamodules/geo.py @@ -8,7 +8,7 @@ from typing import Any, Callable, Dict, Optional, Tuple, Type, Union import kornia.augmentation as K import matplotlib.pyplot as plt import torch -from lightning import LightningDataModule +from lightning.pytorch import LightningDataModule from torch import Tensor from torch.utils.data import DataLoader, Dataset, default_collate diff --git a/torchgeo/trainers/byol.py b/torchgeo/trainers/byol.py index 8128bce65..219e4e73d 100644 --- a/torchgeo/trainers/byol.py +++ b/torchgeo/trainers/byol.py @@ -11,7 +11,7 @@ import torch import torch.nn as nn import torch.nn.functional as F from kornia import augmentation as K -from lightning import LightningModule +from lightning.pytorch import LightningModule from torch import Tensor, optim from torch.optim.lr_scheduler import ReduceLROnPlateau from torchvision.models._api import WeightsEnum diff --git a/torchgeo/trainers/classification.py b/torchgeo/trainers/classification.py index acc418a77..7974c4ed8 100644 --- a/torchgeo/trainers/classification.py +++ b/torchgeo/trainers/classification.py @@ -10,7 +10,7 @@ import matplotlib.pyplot as plt import timm import torch import torch.nn as nn -from lightning import LightningModule +from lightning.pytorch import LightningModule from segmentation_models_pytorch.losses import FocalLoss, JaccardLoss from torch import Tensor from torch.optim.lr_scheduler import ReduceLROnPlateau diff --git a/torchgeo/trainers/detection.py b/torchgeo/trainers/detection.py index 99f45b5ae..e61cbfa41 100644 --- a/torchgeo/trainers/detection.py +++ b/torchgeo/trainers/detection.py @@ -9,7 +9,7 @@ from typing import Any, Dict, List, cast import matplotlib.pyplot as plt import torch import torchvision.models.detection -from lightning import LightningModule +from lightning.pytorch import LightningModule from torch import Tensor from torch.optim.lr_scheduler import ReduceLROnPlateau from torchmetrics.detection.mean_ap import MeanAveragePrecision diff --git a/torchgeo/trainers/regression.py b/torchgeo/trainers/regression.py index 3d29f274f..84a4fc201 100644 --- a/torchgeo/trainers/regression.py +++ b/torchgeo/trainers/regression.py @@ -10,7 +10,7 @@ import matplotlib.pyplot as plt import timm import torch import torch.nn.functional as F -from lightning import LightningModule +from lightning.pytorch import LightningModule from torch import Tensor from torch.optim.lr_scheduler import ReduceLROnPlateau from torchmetrics import MeanAbsoluteError, MeanSquaredError, MetricCollection diff --git a/torchgeo/trainers/segmentation.py b/torchgeo/trainers/segmentation.py index c2ffac13f..3396ff7b5 100644 --- a/torchgeo/trainers/segmentation.py +++ b/torchgeo/trainers/segmentation.py @@ -10,7 +10,7 @@ import matplotlib.pyplot as plt import segmentation_models_pytorch as smp import torch import torch.nn as nn -from lightning import LightningModule +from lightning.pytorch import LightningModule from torch import Tensor from torch.optim.lr_scheduler import ReduceLROnPlateau from torchmetrics import MetricCollection diff --git a/train.py b/train.py index 998af333e..68aa6624a 100755 --- a/train.py +++ b/train.py @@ -8,9 +8,10 @@ import os from typing import Any, Dict, Tuple, Type, cast -import lightning as L -from lightning import loggers -from lightning.callbacks import EarlyStopping, ModelCheckpoint +import lightning.pytorch as pl +from lightning.pytorch import LightningDataModule, LightningModule, Trainer +from lightning.pytorch.callbacks import EarlyStopping, ModelCheckpoint +from lightning.pytorch.loggers import CSVLogger, TensorBoardLogger from omegaconf import DictConfig, OmegaConf from torchgeo.datamodules import ( @@ -45,7 +46,7 @@ from torchgeo.trainers import ( ) TASK_TO_MODULES_MAPPING: Dict[ - str, Tuple[Type[L.LightningModule], Type[L.LightningDataModule]] + str, Tuple[Type[LightningModule], Type[LightningDataModule]] ] = { "bigearthnet": (MultiLabelClassificationTask, BigEarthNetDataModule), "byol": (BYOLTask, ChesapeakeCVPRDataModule), @@ -165,8 +166,8 @@ def main(conf: DictConfig) -> None: Dict[str, Any], OmegaConf.to_object(conf.experiment.datamodule) ) - datamodule: L.LightningDataModule - task: L.LightningModule + datamodule: LightningDataModule + task: LightningModule if task_name in TASK_TO_MODULES_MAPPING: task_class, datamodule_class = TASK_TO_MODULES_MAPPING[task_name] task = task_class(**task_args) @@ -179,8 +180,8 @@ def main(conf: DictConfig) -> None: ###################################### # Setup trainer ###################################### - tb_logger = loggers.TensorBoardLogger(conf.program.log_dir, name=experiment_name) - csv_logger = loggers.CSVLogger(conf.program.log_dir, name=experiment_name) + tb_logger = TensorBoardLogger(conf.program.log_dir, name=experiment_name) + csv_logger = CSVLogger(conf.program.log_dir, name=experiment_name) if isinstance(task, ObjectDetectionTask): monitor_metric = "val_map" @@ -205,7 +206,7 @@ def main(conf: DictConfig) -> None: trainer_args["callbacks"] = [checkpoint_callback, early_stopping_callback] trainer_args["logger"] = [tb_logger, csv_logger] trainer_args["default_root_dir"] = experiment_dir - trainer = L.Trainer(**trainer_args) + trainer = Trainer(**trainer_args) ###################################### # Run experiment @@ -229,7 +230,7 @@ if __name__ == "__main__": # Set random seed for reproducibility # https://pytorch-lightning.readthedocs.io/en/latest/api/pytorch_lightning.utilities.seed.html#pytorch_lightning.utilities.seed.seed_everything - L.seed_everything(conf.program.seed) + pl.seed_everything(conf.program.seed) # Main training procedure main(conf)