diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index 0387a44ca..ba1822abf 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -34,5 +34,5 @@ repos: hooks: - id: mypy args: [--strict, --ignore-missing-imports, --show-error-codes] - additional_dependencies: [torch>=2, torchmetrics>=0.10, lightning>=1.8, pytest>=6.1.2, pyvista>=0.29, omegaconf>=2.0.1, hydra-core>=1, kornia>=0.6.5, numpy>=1.22] + additional_dependencies: [torch>=2, torchmetrics>=0.10, lightning>=2.0.3, pytest>=6.1.2, pyvista>=0.29, omegaconf>=2.0.1, hydra-core>=1, kornia>=0.6.5, numpy>=1.22] exclude: (build|data|dist|logo|logs|output)/ diff --git a/requirements/required.txt b/requirements/required.txt index 3e0fe7d63..00d6ed181 100644 --- a/requirements/required.txt +++ b/requirements/required.txt @@ -6,7 +6,7 @@ einops==0.6.1 fiona==1.9.4.post1 kornia==0.6.12 lightly==1.4.7 -lightning==2.0.2 +lightning==2.0.3 matplotlib==3.7.1 numpy==1.24.3 pillow==9.5.0 diff --git a/tests/datamodules/test_geo.py b/tests/datamodules/test_geo.py index b3f1bad3f..e68f45164 100644 --- a/tests/datamodules/test_geo.py +++ b/tests/datamodules/test_geo.py @@ -97,28 +97,32 @@ class TestGeoDataModule: def test_train(self, datamodule: CustomGeoDataModule) -> None: datamodule.setup("fit") - datamodule.trainer.training = True + if datamodule.trainer: + datamodule.trainer.training = True batch = next(iter(datamodule.train_dataloader())) batch = datamodule.transfer_batch_to_device(batch, torch.device("cpu"), 1) batch = datamodule.on_after_batch_transfer(batch, 0) def test_val(self, datamodule: CustomGeoDataModule) -> None: datamodule.setup("validate") - datamodule.trainer.validating = True + if datamodule.trainer: + datamodule.trainer.validating = True batch = next(iter(datamodule.val_dataloader())) batch = datamodule.transfer_batch_to_device(batch, torch.device("cpu"), 1) batch = datamodule.on_after_batch_transfer(batch, 0) def test_test(self, datamodule: CustomGeoDataModule) -> None: datamodule.setup("test") - datamodule.trainer.testing = True + if datamodule.trainer: + datamodule.trainer.testing = True batch = next(iter(datamodule.test_dataloader())) batch = datamodule.transfer_batch_to_device(batch, torch.device("cpu"), 1) batch = datamodule.on_after_batch_transfer(batch, 0) def test_predict(self, datamodule: CustomGeoDataModule) -> None: datamodule.setup("predict") - datamodule.trainer.predicting = True + if datamodule.trainer: + datamodule.trainer.predicting = True batch = next(iter(datamodule.predict_dataloader())) batch = datamodule.transfer_batch_to_device(batch, torch.device("cpu"), 1) batch = datamodule.on_after_batch_transfer(batch, 0) @@ -156,25 +160,29 @@ class TestNonGeoDataModule: def test_train(self, datamodule: CustomNonGeoDataModule) -> None: datamodule.setup("fit") - datamodule.trainer.training = True + if datamodule.trainer: + datamodule.trainer.training = True batch = next(iter(datamodule.train_dataloader())) batch = datamodule.on_after_batch_transfer(batch, 0) def test_val(self, datamodule: CustomNonGeoDataModule) -> None: datamodule.setup("validate") - datamodule.trainer.validating = True + if datamodule.trainer: + datamodule.trainer.validating = True batch = next(iter(datamodule.val_dataloader())) batch = datamodule.on_after_batch_transfer(batch, 0) def test_test(self, datamodule: CustomNonGeoDataModule) -> None: datamodule.setup("test") - datamodule.trainer.testing = True + if datamodule.trainer: + datamodule.trainer.testing = True batch = next(iter(datamodule.test_dataloader())) batch = datamodule.on_after_batch_transfer(batch, 0) def test_predict(self, datamodule: CustomNonGeoDataModule) -> None: datamodule.setup("predict") - datamodule.trainer.predicting = True + if datamodule.trainer: + datamodule.trainer.predicting = True batch = next(iter(datamodule.predict_dataloader())) batch = datamodule.on_after_batch_transfer(batch, 0) diff --git a/tests/datamodules/test_oscd.py b/tests/datamodules/test_oscd.py index 6f6403aa2..10e890d04 100644 --- a/tests/datamodules/test_oscd.py +++ b/tests/datamodules/test_oscd.py @@ -30,7 +30,8 @@ class TestOSCDDataModule: def test_train_dataloader(self, datamodule: OSCDDataModule) -> None: datamodule.setup("fit") - datamodule.trainer.training = True + if datamodule.trainer: + datamodule.trainer.training = True batch = next(iter(datamodule.train_dataloader())) batch = datamodule.on_after_batch_transfer(batch, 0) assert batch["image"].shape[-2:] == batch["mask"].shape[-2:] == (2, 2) @@ -42,7 +43,8 @@ class TestOSCDDataModule: def test_val_dataloader(self, datamodule: OSCDDataModule) -> None: datamodule.setup("validate") - datamodule.trainer.validating = True + if datamodule.trainer: + datamodule.trainer.validating = True batch = next(iter(datamodule.val_dataloader())) batch = datamodule.on_after_batch_transfer(batch, 0) if datamodule.val_split_pct > 0.0: @@ -55,7 +57,8 @@ class TestOSCDDataModule: def test_test_dataloader(self, datamodule: OSCDDataModule) -> None: datamodule.setup("test") - datamodule.trainer.testing = True + if datamodule.trainer: + datamodule.trainer.testing = True batch = next(iter(datamodule.test_dataloader())) batch = datamodule.on_after_batch_transfer(batch, 0) assert batch["image"].shape[-2:] == batch["mask"].shape[-2:] == (2, 2) diff --git a/torchgeo/datamodules/geo.py b/torchgeo/datamodules/geo.py index dfcc06b87..f733467f0 100644 --- a/torchgeo/datamodules/geo.py +++ b/torchgeo/datamodules/geo.py @@ -3,7 +3,7 @@ """Base classes for all :mod:`torchgeo` data modules.""" -from typing import Any, Callable, Optional, Union +from typing import Any, Callable, Optional, Union, cast import kornia.augmentation as K import matplotlib.pyplot as plt @@ -23,7 +23,7 @@ from ..transforms import AugmentationSequential from .utils import MisconfigurationException -class BaseDataModule(LightningDataModule): # type: ignore[misc] +class BaseDataModule(LightningDataModule): """Base class for all TorchGeo data modules. .. versionadded:: 0.5 @@ -32,6 +32,51 @@ class BaseDataModule(LightningDataModule): # type: ignore[misc] mean = torch.tensor(0) std = torch.tensor(255) + def __init__( + self, + dataset_class: type[Dataset[dict[str, Tensor]]], + batch_size: int = 1, + num_workers: int = 0, + **kwargs: Any, + ) -> None: + """Initialize a new BaseDataModule instance. + + Args: + dataset_class: Class used to instantiate a new dataset. + batch_size: Size of each mini-batch. + num_workers: Number of workers for parallel data loading. + **kwargs: Additional keyword arguments passed to ``dataset_class`` + """ + super().__init__() + + self.dataset_class = dataset_class + self.batch_size = batch_size + self.num_workers = num_workers + self.kwargs = kwargs + + # Datasets + self.dataset: Optional[Dataset[dict[str, Tensor]]] = None + self.train_dataset: Optional[Dataset[dict[str, Tensor]]] = None + self.val_dataset: Optional[Dataset[dict[str, Tensor]]] = None + self.test_dataset: Optional[Dataset[dict[str, Tensor]]] = None + self.predict_dataset: Optional[Dataset[dict[str, Tensor]]] = None + + # Data loaders + self.train_batch_size: Optional[int] = None + self.val_batch_size: Optional[int] = None + self.test_batch_size: Optional[int] = None + self.predict_batch_size: Optional[int] = None + + # Data augmentation + Transform = Callable[[dict[str, Tensor]], dict[str, Tensor]] + self.aug: Transform = AugmentationSequential( + K.Normalize(mean=self.mean, std=self.std), data_keys=["image"] + ) + self.train_aug: Optional[Transform] = None + self.val_aug: Optional[Transform] = None + self.test_aug: Optional[Transform] = None + self.predict_aug: Optional[Transform] = None + def prepare_data(self) -> None: """Download and prepare data. @@ -112,21 +157,13 @@ class GeoDataModule(BaseDataModule): num_workers: Number of workers for parallel data loading. **kwargs: Additional keyword arguments passed to ``dataset_class`` """ - super().__init__() + super().__init__(dataset_class, batch_size, num_workers, **kwargs) - self.dataset_class = dataset_class - self.batch_size = batch_size self.patch_size = patch_size self.length = length - self.num_workers = num_workers - self.kwargs = kwargs - # Datasets - self.dataset: Optional[Dataset[dict[str, Tensor]]] = None - self.train_dataset: Optional[Dataset[dict[str, Tensor]]] = None - self.val_dataset: Optional[Dataset[dict[str, Tensor]]] = None - self.test_dataset: Optional[Dataset[dict[str, Tensor]]] = None - self.predict_dataset: Optional[Dataset[dict[str, Tensor]]] = None + # Collation + self.collate_fn = stack_samples # Samplers self.sampler: Optional[GeoSampler] = None @@ -142,25 +179,6 @@ class GeoDataModule(BaseDataModule): self.test_batch_sampler: Optional[BatchGeoSampler] = None self.predict_batch_sampler: Optional[BatchGeoSampler] = None - # Data loaders - self.train_batch_size: Optional[int] = None - self.val_batch_size: Optional[int] = None - self.test_batch_size: Optional[int] = None - self.predict_batch_size: Optional[int] = None - - # Collation - self.collate_fn = stack_samples - - # Data augmentation - Transform = Callable[[dict[str, Tensor]], dict[str, Tensor]] - self.aug: Transform = AugmentationSequential( - K.Normalize(mean=self.mean, std=self.std), data_keys=["image"] - ) - self.train_aug: Optional[Transform] = None - self.val_aug: Optional[Transform] = None - self.test_aug: Optional[Transform] = None - self.predict_aug: Optional[Transform] = None - def setup(self, stage: str) -> None: """Set up datasets and samplers. @@ -172,22 +190,31 @@ class GeoDataModule(BaseDataModule): stage: Either 'fit', 'validate', 'test', or 'predict'. """ if stage in ["fit"]: - self.train_dataset = self.dataset_class( # type: ignore[call-arg] - split="train", **self.kwargs + self.train_dataset = cast( + GeoDataset, + self.dataset_class( # type: ignore[call-arg] + split="train", **self.kwargs + ), ) self.train_batch_sampler = RandomBatchGeoSampler( self.train_dataset, self.patch_size, self.batch_size, self.length ) if stage in ["fit", "validate"]: - self.val_dataset = self.dataset_class( # type: ignore[call-arg] - split="val", **self.kwargs + self.val_dataset = cast( + GeoDataset, + self.dataset_class( # type: ignore[call-arg] + split="val", **self.kwargs + ), ) self.val_sampler = GridGeoSampler( self.val_dataset, self.patch_size, self.patch_size ) if stage in ["test"]: - self.test_dataset = self.dataset_class( # type: ignore[call-arg] - split="test", **self.kwargs + self.test_dataset = cast( + GeoDataset, + self.dataset_class( # type: ignore[call-arg] + split="test", **self.kwargs + ), ) self.test_sampler = GridGeoSampler( self.test_dataset, self.patch_size, self.patch_size @@ -357,39 +384,11 @@ class NonGeoDataModule(BaseDataModule): num_workers: Number of workers for parallel data loading. **kwargs: Additional keyword arguments passed to ``dataset_class`` """ - super().__init__() - - self.dataset_class = dataset_class - self.batch_size = batch_size - self.num_workers = num_workers - self.kwargs = kwargs - - # Datasets - self.dataset: Optional[Dataset[dict[str, Tensor]]] = None - self.train_dataset: Optional[Dataset[dict[str, Tensor]]] = None - self.val_dataset: Optional[Dataset[dict[str, Tensor]]] = None - self.test_dataset: Optional[Dataset[dict[str, Tensor]]] = None - self.predict_dataset: Optional[Dataset[dict[str, Tensor]]] = None - - # Data loaders - self.train_batch_size: Optional[int] = None - self.val_batch_size: Optional[int] = None - self.test_batch_size: Optional[int] = None - self.predict_batch_size: Optional[int] = None + super().__init__(dataset_class, batch_size, num_workers, **kwargs) # Collation self.collate_fn = default_collate - # Data augmentation - Transform = Callable[[dict[str, Tensor]], dict[str, Tensor]] - self.aug: Transform = AugmentationSequential( - K.Normalize(mean=self.mean, std=self.std), data_keys=["image"] - ) - self.train_aug: Optional[Transform] = None - self.val_aug: Optional[Transform] = None - self.test_aug: Optional[Transform] = None - self.predict_aug: Optional[Transform] = None - def setup(self, stage: str) -> None: """Set up datasets. diff --git a/torchgeo/trainers/byol.py b/torchgeo/trainers/byol.py index eb5907841..00315f402 100644 --- a/torchgeo/trainers/byol.py +++ b/torchgeo/trainers/byol.py @@ -278,7 +278,7 @@ class BYOL(nn.Module): pt.data = self.beta * pt.data + (1 - self.beta) * p.data -class BYOLTask(LightningModule): # type: ignore[misc] +class BYOLTask(LightningModule): """Class for pre-training any PyTorch model using BYOL. Supports any available `Timm model diff --git a/torchgeo/trainers/classification.py b/torchgeo/trainers/classification.py index 99c06dcaa..29f96f2f9 100644 --- a/torchgeo/trainers/classification.py +++ b/torchgeo/trainers/classification.py @@ -29,7 +29,7 @@ from ..models import get_weight from . import utils -class ClassificationTask(LightningModule): # type: ignore[misc] +class ClassificationTask(LightningModule): """LightningModule for image classification. Supports any available `Timm model diff --git a/torchgeo/trainers/detection.py b/torchgeo/trainers/detection.py index 60c1754d2..dfa2bd924 100644 --- a/torchgeo/trainers/detection.py +++ b/torchgeo/trainers/detection.py @@ -46,7 +46,7 @@ BACKBONE_WEIGHT_MAP = { } -class ObjectDetectionTask(LightningModule): # type: ignore[misc] +class ObjectDetectionTask(LightningModule): """LightningModule for object detection of images. Currently, supports Faster R-CNN, FCOS, and RetinaNet models from diff --git a/torchgeo/trainers/moco.py b/torchgeo/trainers/moco.py index 570dec336..f5425390b 100644 --- a/torchgeo/trainers/moco.py +++ b/torchgeo/trainers/moco.py @@ -6,7 +6,7 @@ import os import warnings from collections.abc import Sequence -from typing import Optional, Union, cast +from typing import Any, Optional, Union, cast import kornia.augmentation as K import timm @@ -118,7 +118,7 @@ def moco_augmentations( return aug1, aug2 -class MoCoTask(LightningModule): # type: ignore[misc] +class MoCoTask(LightningModule): """MoCo: Momentum Contrast. Reference implementations: @@ -295,12 +295,15 @@ class MoCoTask(LightningModule): # type: ignore[misc] k = self.projection_head_momentum(k) return cast(Tensor, k) - def training_step(self, batch: dict[str, Tensor], batch_idx: int) -> Tensor: + def training_step( + self, batch: Any, batch_idx: int, dataloader_idx: int = 0 + ) -> Tensor: """Compute the training loss and additional metrics. Args: batch: The output of your DataLoader. batch_idx: Integer displaying index of this batch. + dataloader_idx: Index of the current dataloader. Returns: The loss tensor. @@ -359,13 +362,15 @@ class MoCoTask(LightningModule): # type: ignore[misc] return cast(Tensor, loss) - def validation_step(self, batch: dict[str, Tensor], batch_idx: int) -> None: + def validation_step( + self, batch: Any, batch_idx: int, dataloader_idx: int = 0 + ) -> None: """No-op, does nothing.""" - def test_step(self, batch: dict[str, Tensor], batch_idx: int) -> None: + def test_step(self, batch: Any, batch_idx: int, dataloader_idx: int = 0) -> None: """No-op, does nothing.""" - def predict_step(self, batch: dict[str, Tensor], batch_idx: int) -> None: + def predict_step(self, batch: Any, batch_idx: int, dataloader_idx: int = 0) -> None: """No-op, does nothing.""" def configure_optimizers(self) -> tuple[list[Optimizer], list[LRScheduler]]: @@ -381,6 +386,9 @@ class MoCoTask(LightningModule): # type: ignore[misc] weight_decay=self.hparams["weight_decay"], ) warmup_epochs = 40 + max_epochs = 200 + if self.trainer and self.trainer.max_epochs: + max_epochs = self.trainer.max_epochs lr_scheduler: LRScheduler = SequentialLR( optimizer, schedulers=[ @@ -389,7 +397,7 @@ class MoCoTask(LightningModule): # type: ignore[misc] start_factor=1 / warmup_epochs, total_iters=warmup_epochs, ), - CosineAnnealingLR(optimizer, T_max=self.trainer.max_epochs), + CosineAnnealingLR(optimizer, T_max=max_epochs), ], milestones=[warmup_epochs], ) diff --git a/torchgeo/trainers/regression.py b/torchgeo/trainers/regression.py index c8dee0c50..ce62a6494 100644 --- a/torchgeo/trainers/regression.py +++ b/torchgeo/trainers/regression.py @@ -22,7 +22,7 @@ from ..models import FCN, get_weight from . import utils -class RegressionTask(LightningModule): # type: ignore[misc] +class RegressionTask(LightningModule): """LightningModule for training models on regression datasets. Supports any available `Timm model diff --git a/torchgeo/trainers/segmentation.py b/torchgeo/trainers/segmentation.py index 826f18241..af20fb868 100644 --- a/torchgeo/trainers/segmentation.py +++ b/torchgeo/trainers/segmentation.py @@ -23,7 +23,7 @@ from ..models import FCN, get_weight from . import utils -class SemanticSegmentationTask(LightningModule): # type: ignore[misc] +class SemanticSegmentationTask(LightningModule): """LightningModule for semantic segmentation of images. Supports `Segmentation Models Pytorch diff --git a/torchgeo/trainers/simclr.py b/torchgeo/trainers/simclr.py index b878b7823..1dd546f7c 100644 --- a/torchgeo/trainers/simclr.py +++ b/torchgeo/trainers/simclr.py @@ -5,7 +5,7 @@ import os import warnings -from typing import Optional, Union, cast +from typing import Any, Optional, Union, cast import kornia.augmentation as K import timm @@ -57,7 +57,7 @@ def simclr_augmentations(size: int, weights: Tensor) -> nn.Module: ) -class SimCLRTask(LightningModule): # type: ignore[misc] +class SimCLRTask(LightningModule): """SimCLR: a simple framework for contrastive learning of visual representations. Reference implementation: @@ -193,12 +193,15 @@ class SimCLRTask(LightningModule): # type: ignore[misc] z = self.projection_head(h) return cast(Tensor, z), cast(Tensor, h) - def training_step(self, batch: dict[str, Tensor], batch_idx: int) -> Tensor: + def training_step( + self, batch: Any, batch_idx: int, dataloader_idx: int = 0 + ) -> Tensor: """Compute the training loss and additional metrics. Args: batch: The output of your DataLoader. batch_idx: Integer displaying index of this batch. + dataloader_idx: Index of the current dataloader. Returns: The loss tensor. @@ -237,15 +240,17 @@ class SimCLRTask(LightningModule): # type: ignore[misc] return cast(Tensor, loss) - def validation_step(self, batch: dict[str, Tensor], batch_idx: int) -> None: + def validation_step( + self, batch: Any, batch_idx: int, dataloader_idx: int = 0 + ) -> None: """No-op, does nothing.""" - def test_step(self, batch: dict[str, Tensor], batch_idx: int) -> None: + def test_step(self, batch: Any, batch_idx: int, dataloader_idx: int = 0) -> None: """No-op, does nothing.""" # TODO # v2: add distillation step - def predict_step(self, batch: dict[str, Tensor], batch_idx: int) -> None: + def predict_step(self, batch: Any, batch_idx: int, dataloader_idx: int = 0) -> None: """No-op, does nothing.""" def configure_optimizers(self) -> tuple[list[Optimizer], list[LRScheduler]]: @@ -260,15 +265,18 @@ class SimCLRTask(LightningModule): # type: ignore[misc] lr=self.hparams["lr"], weight_decay=self.hparams["weight_decay"], ) + max_epochs = 200 + if self.trainer and self.trainer.max_epochs: + max_epochs = self.trainer.max_epochs if self.hparams["version"] == 1: warmup_epochs = 10 else: - warmup_epochs = int(self.trainer.max_epochs * 0.05) + warmup_epochs = int(max_epochs * 0.05) lr_scheduler = SequentialLR( optimizer, schedulers=[ LinearLR(optimizer, total_iters=warmup_epochs), - CosineAnnealingLR(optimizer, T_max=self.trainer.max_epochs), + CosineAnnealingLR(optimizer, T_max=max_epochs), ], milestones=[warmup_epochs], )