зеркало из https://github.com/microsoft/torchgeo.git
Bump lightning from 2.0.2 to 2.0.3 in /requirements (#1406)
* Bump lightning from 2.0.2 to 2.0.3 in /requirements Bumps [lightning](https://github.com/Lightning-AI/lightning) from 2.0.2 to 2.0.3. - [Release notes](https://github.com/Lightning-AI/lightning/releases) - [Commits](https://github.com/Lightning-AI/lightning/compare/2.0.2...2.0.3) --- updated-dependencies: - dependency-name: lightning dependency-type: direct:production update-type: version-update:semver-patch ... Signed-off-by: dependabot[bot] <support@github.com> * Update datamodule base class * Check attributes first * Fix remaining type hints * More fixes * Try again * Try casting --------- Signed-off-by: dependabot[bot] <support@github.com> Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com> Co-authored-by: Adam J. Stewart <ajstewart426@gmail.com>
This commit is contained in:
Родитель
072839410b
Коммит
2f0458db44
|
@ -34,5 +34,5 @@ repos:
|
|||
hooks:
|
||||
- id: mypy
|
||||
args: [--strict, --ignore-missing-imports, --show-error-codes]
|
||||
additional_dependencies: [torch>=2, torchmetrics>=0.10, lightning>=1.8, pytest>=6.1.2, pyvista>=0.29, omegaconf>=2.0.1, hydra-core>=1, kornia>=0.6.5, numpy>=1.22]
|
||||
additional_dependencies: [torch>=2, torchmetrics>=0.10, lightning>=2.0.3, pytest>=6.1.2, pyvista>=0.29, omegaconf>=2.0.1, hydra-core>=1, kornia>=0.6.5, numpy>=1.22]
|
||||
exclude: (build|data|dist|logo|logs|output)/
|
||||
|
|
|
@ -6,7 +6,7 @@ einops==0.6.1
|
|||
fiona==1.9.4.post1
|
||||
kornia==0.6.12
|
||||
lightly==1.4.7
|
||||
lightning==2.0.2
|
||||
lightning==2.0.3
|
||||
matplotlib==3.7.1
|
||||
numpy==1.24.3
|
||||
pillow==9.5.0
|
||||
|
|
|
@ -97,28 +97,32 @@ class TestGeoDataModule:
|
|||
|
||||
def test_train(self, datamodule: CustomGeoDataModule) -> None:
|
||||
datamodule.setup("fit")
|
||||
datamodule.trainer.training = True
|
||||
if datamodule.trainer:
|
||||
datamodule.trainer.training = True
|
||||
batch = next(iter(datamodule.train_dataloader()))
|
||||
batch = datamodule.transfer_batch_to_device(batch, torch.device("cpu"), 1)
|
||||
batch = datamodule.on_after_batch_transfer(batch, 0)
|
||||
|
||||
def test_val(self, datamodule: CustomGeoDataModule) -> None:
|
||||
datamodule.setup("validate")
|
||||
datamodule.trainer.validating = True
|
||||
if datamodule.trainer:
|
||||
datamodule.trainer.validating = True
|
||||
batch = next(iter(datamodule.val_dataloader()))
|
||||
batch = datamodule.transfer_batch_to_device(batch, torch.device("cpu"), 1)
|
||||
batch = datamodule.on_after_batch_transfer(batch, 0)
|
||||
|
||||
def test_test(self, datamodule: CustomGeoDataModule) -> None:
|
||||
datamodule.setup("test")
|
||||
datamodule.trainer.testing = True
|
||||
if datamodule.trainer:
|
||||
datamodule.trainer.testing = True
|
||||
batch = next(iter(datamodule.test_dataloader()))
|
||||
batch = datamodule.transfer_batch_to_device(batch, torch.device("cpu"), 1)
|
||||
batch = datamodule.on_after_batch_transfer(batch, 0)
|
||||
|
||||
def test_predict(self, datamodule: CustomGeoDataModule) -> None:
|
||||
datamodule.setup("predict")
|
||||
datamodule.trainer.predicting = True
|
||||
if datamodule.trainer:
|
||||
datamodule.trainer.predicting = True
|
||||
batch = next(iter(datamodule.predict_dataloader()))
|
||||
batch = datamodule.transfer_batch_to_device(batch, torch.device("cpu"), 1)
|
||||
batch = datamodule.on_after_batch_transfer(batch, 0)
|
||||
|
@ -156,25 +160,29 @@ class TestNonGeoDataModule:
|
|||
|
||||
def test_train(self, datamodule: CustomNonGeoDataModule) -> None:
|
||||
datamodule.setup("fit")
|
||||
datamodule.trainer.training = True
|
||||
if datamodule.trainer:
|
||||
datamodule.trainer.training = True
|
||||
batch = next(iter(datamodule.train_dataloader()))
|
||||
batch = datamodule.on_after_batch_transfer(batch, 0)
|
||||
|
||||
def test_val(self, datamodule: CustomNonGeoDataModule) -> None:
|
||||
datamodule.setup("validate")
|
||||
datamodule.trainer.validating = True
|
||||
if datamodule.trainer:
|
||||
datamodule.trainer.validating = True
|
||||
batch = next(iter(datamodule.val_dataloader()))
|
||||
batch = datamodule.on_after_batch_transfer(batch, 0)
|
||||
|
||||
def test_test(self, datamodule: CustomNonGeoDataModule) -> None:
|
||||
datamodule.setup("test")
|
||||
datamodule.trainer.testing = True
|
||||
if datamodule.trainer:
|
||||
datamodule.trainer.testing = True
|
||||
batch = next(iter(datamodule.test_dataloader()))
|
||||
batch = datamodule.on_after_batch_transfer(batch, 0)
|
||||
|
||||
def test_predict(self, datamodule: CustomNonGeoDataModule) -> None:
|
||||
datamodule.setup("predict")
|
||||
datamodule.trainer.predicting = True
|
||||
if datamodule.trainer:
|
||||
datamodule.trainer.predicting = True
|
||||
batch = next(iter(datamodule.predict_dataloader()))
|
||||
batch = datamodule.on_after_batch_transfer(batch, 0)
|
||||
|
||||
|
|
|
@ -30,7 +30,8 @@ class TestOSCDDataModule:
|
|||
|
||||
def test_train_dataloader(self, datamodule: OSCDDataModule) -> None:
|
||||
datamodule.setup("fit")
|
||||
datamodule.trainer.training = True
|
||||
if datamodule.trainer:
|
||||
datamodule.trainer.training = True
|
||||
batch = next(iter(datamodule.train_dataloader()))
|
||||
batch = datamodule.on_after_batch_transfer(batch, 0)
|
||||
assert batch["image"].shape[-2:] == batch["mask"].shape[-2:] == (2, 2)
|
||||
|
@ -42,7 +43,8 @@ class TestOSCDDataModule:
|
|||
|
||||
def test_val_dataloader(self, datamodule: OSCDDataModule) -> None:
|
||||
datamodule.setup("validate")
|
||||
datamodule.trainer.validating = True
|
||||
if datamodule.trainer:
|
||||
datamodule.trainer.validating = True
|
||||
batch = next(iter(datamodule.val_dataloader()))
|
||||
batch = datamodule.on_after_batch_transfer(batch, 0)
|
||||
if datamodule.val_split_pct > 0.0:
|
||||
|
@ -55,7 +57,8 @@ class TestOSCDDataModule:
|
|||
|
||||
def test_test_dataloader(self, datamodule: OSCDDataModule) -> None:
|
||||
datamodule.setup("test")
|
||||
datamodule.trainer.testing = True
|
||||
if datamodule.trainer:
|
||||
datamodule.trainer.testing = True
|
||||
batch = next(iter(datamodule.test_dataloader()))
|
||||
batch = datamodule.on_after_batch_transfer(batch, 0)
|
||||
assert batch["image"].shape[-2:] == batch["mask"].shape[-2:] == (2, 2)
|
||||
|
|
|
@ -3,7 +3,7 @@
|
|||
|
||||
"""Base classes for all :mod:`torchgeo` data modules."""
|
||||
|
||||
from typing import Any, Callable, Optional, Union
|
||||
from typing import Any, Callable, Optional, Union, cast
|
||||
|
||||
import kornia.augmentation as K
|
||||
import matplotlib.pyplot as plt
|
||||
|
@ -23,7 +23,7 @@ from ..transforms import AugmentationSequential
|
|||
from .utils import MisconfigurationException
|
||||
|
||||
|
||||
class BaseDataModule(LightningDataModule): # type: ignore[misc]
|
||||
class BaseDataModule(LightningDataModule):
|
||||
"""Base class for all TorchGeo data modules.
|
||||
|
||||
.. versionadded:: 0.5
|
||||
|
@ -32,6 +32,51 @@ class BaseDataModule(LightningDataModule): # type: ignore[misc]
|
|||
mean = torch.tensor(0)
|
||||
std = torch.tensor(255)
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
dataset_class: type[Dataset[dict[str, Tensor]]],
|
||||
batch_size: int = 1,
|
||||
num_workers: int = 0,
|
||||
**kwargs: Any,
|
||||
) -> None:
|
||||
"""Initialize a new BaseDataModule instance.
|
||||
|
||||
Args:
|
||||
dataset_class: Class used to instantiate a new dataset.
|
||||
batch_size: Size of each mini-batch.
|
||||
num_workers: Number of workers for parallel data loading.
|
||||
**kwargs: Additional keyword arguments passed to ``dataset_class``
|
||||
"""
|
||||
super().__init__()
|
||||
|
||||
self.dataset_class = dataset_class
|
||||
self.batch_size = batch_size
|
||||
self.num_workers = num_workers
|
||||
self.kwargs = kwargs
|
||||
|
||||
# Datasets
|
||||
self.dataset: Optional[Dataset[dict[str, Tensor]]] = None
|
||||
self.train_dataset: Optional[Dataset[dict[str, Tensor]]] = None
|
||||
self.val_dataset: Optional[Dataset[dict[str, Tensor]]] = None
|
||||
self.test_dataset: Optional[Dataset[dict[str, Tensor]]] = None
|
||||
self.predict_dataset: Optional[Dataset[dict[str, Tensor]]] = None
|
||||
|
||||
# Data loaders
|
||||
self.train_batch_size: Optional[int] = None
|
||||
self.val_batch_size: Optional[int] = None
|
||||
self.test_batch_size: Optional[int] = None
|
||||
self.predict_batch_size: Optional[int] = None
|
||||
|
||||
# Data augmentation
|
||||
Transform = Callable[[dict[str, Tensor]], dict[str, Tensor]]
|
||||
self.aug: Transform = AugmentationSequential(
|
||||
K.Normalize(mean=self.mean, std=self.std), data_keys=["image"]
|
||||
)
|
||||
self.train_aug: Optional[Transform] = None
|
||||
self.val_aug: Optional[Transform] = None
|
||||
self.test_aug: Optional[Transform] = None
|
||||
self.predict_aug: Optional[Transform] = None
|
||||
|
||||
def prepare_data(self) -> None:
|
||||
"""Download and prepare data.
|
||||
|
||||
|
@ -112,21 +157,13 @@ class GeoDataModule(BaseDataModule):
|
|||
num_workers: Number of workers for parallel data loading.
|
||||
**kwargs: Additional keyword arguments passed to ``dataset_class``
|
||||
"""
|
||||
super().__init__()
|
||||
super().__init__(dataset_class, batch_size, num_workers, **kwargs)
|
||||
|
||||
self.dataset_class = dataset_class
|
||||
self.batch_size = batch_size
|
||||
self.patch_size = patch_size
|
||||
self.length = length
|
||||
self.num_workers = num_workers
|
||||
self.kwargs = kwargs
|
||||
|
||||
# Datasets
|
||||
self.dataset: Optional[Dataset[dict[str, Tensor]]] = None
|
||||
self.train_dataset: Optional[Dataset[dict[str, Tensor]]] = None
|
||||
self.val_dataset: Optional[Dataset[dict[str, Tensor]]] = None
|
||||
self.test_dataset: Optional[Dataset[dict[str, Tensor]]] = None
|
||||
self.predict_dataset: Optional[Dataset[dict[str, Tensor]]] = None
|
||||
# Collation
|
||||
self.collate_fn = stack_samples
|
||||
|
||||
# Samplers
|
||||
self.sampler: Optional[GeoSampler] = None
|
||||
|
@ -142,25 +179,6 @@ class GeoDataModule(BaseDataModule):
|
|||
self.test_batch_sampler: Optional[BatchGeoSampler] = None
|
||||
self.predict_batch_sampler: Optional[BatchGeoSampler] = None
|
||||
|
||||
# Data loaders
|
||||
self.train_batch_size: Optional[int] = None
|
||||
self.val_batch_size: Optional[int] = None
|
||||
self.test_batch_size: Optional[int] = None
|
||||
self.predict_batch_size: Optional[int] = None
|
||||
|
||||
# Collation
|
||||
self.collate_fn = stack_samples
|
||||
|
||||
# Data augmentation
|
||||
Transform = Callable[[dict[str, Tensor]], dict[str, Tensor]]
|
||||
self.aug: Transform = AugmentationSequential(
|
||||
K.Normalize(mean=self.mean, std=self.std), data_keys=["image"]
|
||||
)
|
||||
self.train_aug: Optional[Transform] = None
|
||||
self.val_aug: Optional[Transform] = None
|
||||
self.test_aug: Optional[Transform] = None
|
||||
self.predict_aug: Optional[Transform] = None
|
||||
|
||||
def setup(self, stage: str) -> None:
|
||||
"""Set up datasets and samplers.
|
||||
|
||||
|
@ -172,22 +190,31 @@ class GeoDataModule(BaseDataModule):
|
|||
stage: Either 'fit', 'validate', 'test', or 'predict'.
|
||||
"""
|
||||
if stage in ["fit"]:
|
||||
self.train_dataset = self.dataset_class( # type: ignore[call-arg]
|
||||
split="train", **self.kwargs
|
||||
self.train_dataset = cast(
|
||||
GeoDataset,
|
||||
self.dataset_class( # type: ignore[call-arg]
|
||||
split="train", **self.kwargs
|
||||
),
|
||||
)
|
||||
self.train_batch_sampler = RandomBatchGeoSampler(
|
||||
self.train_dataset, self.patch_size, self.batch_size, self.length
|
||||
)
|
||||
if stage in ["fit", "validate"]:
|
||||
self.val_dataset = self.dataset_class( # type: ignore[call-arg]
|
||||
split="val", **self.kwargs
|
||||
self.val_dataset = cast(
|
||||
GeoDataset,
|
||||
self.dataset_class( # type: ignore[call-arg]
|
||||
split="val", **self.kwargs
|
||||
),
|
||||
)
|
||||
self.val_sampler = GridGeoSampler(
|
||||
self.val_dataset, self.patch_size, self.patch_size
|
||||
)
|
||||
if stage in ["test"]:
|
||||
self.test_dataset = self.dataset_class( # type: ignore[call-arg]
|
||||
split="test", **self.kwargs
|
||||
self.test_dataset = cast(
|
||||
GeoDataset,
|
||||
self.dataset_class( # type: ignore[call-arg]
|
||||
split="test", **self.kwargs
|
||||
),
|
||||
)
|
||||
self.test_sampler = GridGeoSampler(
|
||||
self.test_dataset, self.patch_size, self.patch_size
|
||||
|
@ -357,39 +384,11 @@ class NonGeoDataModule(BaseDataModule):
|
|||
num_workers: Number of workers for parallel data loading.
|
||||
**kwargs: Additional keyword arguments passed to ``dataset_class``
|
||||
"""
|
||||
super().__init__()
|
||||
|
||||
self.dataset_class = dataset_class
|
||||
self.batch_size = batch_size
|
||||
self.num_workers = num_workers
|
||||
self.kwargs = kwargs
|
||||
|
||||
# Datasets
|
||||
self.dataset: Optional[Dataset[dict[str, Tensor]]] = None
|
||||
self.train_dataset: Optional[Dataset[dict[str, Tensor]]] = None
|
||||
self.val_dataset: Optional[Dataset[dict[str, Tensor]]] = None
|
||||
self.test_dataset: Optional[Dataset[dict[str, Tensor]]] = None
|
||||
self.predict_dataset: Optional[Dataset[dict[str, Tensor]]] = None
|
||||
|
||||
# Data loaders
|
||||
self.train_batch_size: Optional[int] = None
|
||||
self.val_batch_size: Optional[int] = None
|
||||
self.test_batch_size: Optional[int] = None
|
||||
self.predict_batch_size: Optional[int] = None
|
||||
super().__init__(dataset_class, batch_size, num_workers, **kwargs)
|
||||
|
||||
# Collation
|
||||
self.collate_fn = default_collate
|
||||
|
||||
# Data augmentation
|
||||
Transform = Callable[[dict[str, Tensor]], dict[str, Tensor]]
|
||||
self.aug: Transform = AugmentationSequential(
|
||||
K.Normalize(mean=self.mean, std=self.std), data_keys=["image"]
|
||||
)
|
||||
self.train_aug: Optional[Transform] = None
|
||||
self.val_aug: Optional[Transform] = None
|
||||
self.test_aug: Optional[Transform] = None
|
||||
self.predict_aug: Optional[Transform] = None
|
||||
|
||||
def setup(self, stage: str) -> None:
|
||||
"""Set up datasets.
|
||||
|
||||
|
|
|
@ -278,7 +278,7 @@ class BYOL(nn.Module):
|
|||
pt.data = self.beta * pt.data + (1 - self.beta) * p.data
|
||||
|
||||
|
||||
class BYOLTask(LightningModule): # type: ignore[misc]
|
||||
class BYOLTask(LightningModule):
|
||||
"""Class for pre-training any PyTorch model using BYOL.
|
||||
|
||||
Supports any available `Timm model
|
||||
|
|
|
@ -29,7 +29,7 @@ from ..models import get_weight
|
|||
from . import utils
|
||||
|
||||
|
||||
class ClassificationTask(LightningModule): # type: ignore[misc]
|
||||
class ClassificationTask(LightningModule):
|
||||
"""LightningModule for image classification.
|
||||
|
||||
Supports any available `Timm model
|
||||
|
|
|
@ -46,7 +46,7 @@ BACKBONE_WEIGHT_MAP = {
|
|||
}
|
||||
|
||||
|
||||
class ObjectDetectionTask(LightningModule): # type: ignore[misc]
|
||||
class ObjectDetectionTask(LightningModule):
|
||||
"""LightningModule for object detection of images.
|
||||
|
||||
Currently, supports Faster R-CNN, FCOS, and RetinaNet models from
|
||||
|
|
|
@ -6,7 +6,7 @@
|
|||
import os
|
||||
import warnings
|
||||
from collections.abc import Sequence
|
||||
from typing import Optional, Union, cast
|
||||
from typing import Any, Optional, Union, cast
|
||||
|
||||
import kornia.augmentation as K
|
||||
import timm
|
||||
|
@ -118,7 +118,7 @@ def moco_augmentations(
|
|||
return aug1, aug2
|
||||
|
||||
|
||||
class MoCoTask(LightningModule): # type: ignore[misc]
|
||||
class MoCoTask(LightningModule):
|
||||
"""MoCo: Momentum Contrast.
|
||||
|
||||
Reference implementations:
|
||||
|
@ -295,12 +295,15 @@ class MoCoTask(LightningModule): # type: ignore[misc]
|
|||
k = self.projection_head_momentum(k)
|
||||
return cast(Tensor, k)
|
||||
|
||||
def training_step(self, batch: dict[str, Tensor], batch_idx: int) -> Tensor:
|
||||
def training_step(
|
||||
self, batch: Any, batch_idx: int, dataloader_idx: int = 0
|
||||
) -> Tensor:
|
||||
"""Compute the training loss and additional metrics.
|
||||
|
||||
Args:
|
||||
batch: The output of your DataLoader.
|
||||
batch_idx: Integer displaying index of this batch.
|
||||
dataloader_idx: Index of the current dataloader.
|
||||
|
||||
Returns:
|
||||
The loss tensor.
|
||||
|
@ -359,13 +362,15 @@ class MoCoTask(LightningModule): # type: ignore[misc]
|
|||
|
||||
return cast(Tensor, loss)
|
||||
|
||||
def validation_step(self, batch: dict[str, Tensor], batch_idx: int) -> None:
|
||||
def validation_step(
|
||||
self, batch: Any, batch_idx: int, dataloader_idx: int = 0
|
||||
) -> None:
|
||||
"""No-op, does nothing."""
|
||||
|
||||
def test_step(self, batch: dict[str, Tensor], batch_idx: int) -> None:
|
||||
def test_step(self, batch: Any, batch_idx: int, dataloader_idx: int = 0) -> None:
|
||||
"""No-op, does nothing."""
|
||||
|
||||
def predict_step(self, batch: dict[str, Tensor], batch_idx: int) -> None:
|
||||
def predict_step(self, batch: Any, batch_idx: int, dataloader_idx: int = 0) -> None:
|
||||
"""No-op, does nothing."""
|
||||
|
||||
def configure_optimizers(self) -> tuple[list[Optimizer], list[LRScheduler]]:
|
||||
|
@ -381,6 +386,9 @@ class MoCoTask(LightningModule): # type: ignore[misc]
|
|||
weight_decay=self.hparams["weight_decay"],
|
||||
)
|
||||
warmup_epochs = 40
|
||||
max_epochs = 200
|
||||
if self.trainer and self.trainer.max_epochs:
|
||||
max_epochs = self.trainer.max_epochs
|
||||
lr_scheduler: LRScheduler = SequentialLR(
|
||||
optimizer,
|
||||
schedulers=[
|
||||
|
@ -389,7 +397,7 @@ class MoCoTask(LightningModule): # type: ignore[misc]
|
|||
start_factor=1 / warmup_epochs,
|
||||
total_iters=warmup_epochs,
|
||||
),
|
||||
CosineAnnealingLR(optimizer, T_max=self.trainer.max_epochs),
|
||||
CosineAnnealingLR(optimizer, T_max=max_epochs),
|
||||
],
|
||||
milestones=[warmup_epochs],
|
||||
)
|
||||
|
|
|
@ -22,7 +22,7 @@ from ..models import FCN, get_weight
|
|||
from . import utils
|
||||
|
||||
|
||||
class RegressionTask(LightningModule): # type: ignore[misc]
|
||||
class RegressionTask(LightningModule):
|
||||
"""LightningModule for training models on regression datasets.
|
||||
|
||||
Supports any available `Timm model
|
||||
|
|
|
@ -23,7 +23,7 @@ from ..models import FCN, get_weight
|
|||
from . import utils
|
||||
|
||||
|
||||
class SemanticSegmentationTask(LightningModule): # type: ignore[misc]
|
||||
class SemanticSegmentationTask(LightningModule):
|
||||
"""LightningModule for semantic segmentation of images.
|
||||
|
||||
Supports `Segmentation Models Pytorch
|
||||
|
|
|
@ -5,7 +5,7 @@
|
|||
|
||||
import os
|
||||
import warnings
|
||||
from typing import Optional, Union, cast
|
||||
from typing import Any, Optional, Union, cast
|
||||
|
||||
import kornia.augmentation as K
|
||||
import timm
|
||||
|
@ -57,7 +57,7 @@ def simclr_augmentations(size: int, weights: Tensor) -> nn.Module:
|
|||
)
|
||||
|
||||
|
||||
class SimCLRTask(LightningModule): # type: ignore[misc]
|
||||
class SimCLRTask(LightningModule):
|
||||
"""SimCLR: a simple framework for contrastive learning of visual representations.
|
||||
|
||||
Reference implementation:
|
||||
|
@ -193,12 +193,15 @@ class SimCLRTask(LightningModule): # type: ignore[misc]
|
|||
z = self.projection_head(h)
|
||||
return cast(Tensor, z), cast(Tensor, h)
|
||||
|
||||
def training_step(self, batch: dict[str, Tensor], batch_idx: int) -> Tensor:
|
||||
def training_step(
|
||||
self, batch: Any, batch_idx: int, dataloader_idx: int = 0
|
||||
) -> Tensor:
|
||||
"""Compute the training loss and additional metrics.
|
||||
|
||||
Args:
|
||||
batch: The output of your DataLoader.
|
||||
batch_idx: Integer displaying index of this batch.
|
||||
dataloader_idx: Index of the current dataloader.
|
||||
|
||||
Returns:
|
||||
The loss tensor.
|
||||
|
@ -237,15 +240,17 @@ class SimCLRTask(LightningModule): # type: ignore[misc]
|
|||
|
||||
return cast(Tensor, loss)
|
||||
|
||||
def validation_step(self, batch: dict[str, Tensor], batch_idx: int) -> None:
|
||||
def validation_step(
|
||||
self, batch: Any, batch_idx: int, dataloader_idx: int = 0
|
||||
) -> None:
|
||||
"""No-op, does nothing."""
|
||||
|
||||
def test_step(self, batch: dict[str, Tensor], batch_idx: int) -> None:
|
||||
def test_step(self, batch: Any, batch_idx: int, dataloader_idx: int = 0) -> None:
|
||||
"""No-op, does nothing."""
|
||||
# TODO
|
||||
# v2: add distillation step
|
||||
|
||||
def predict_step(self, batch: dict[str, Tensor], batch_idx: int) -> None:
|
||||
def predict_step(self, batch: Any, batch_idx: int, dataloader_idx: int = 0) -> None:
|
||||
"""No-op, does nothing."""
|
||||
|
||||
def configure_optimizers(self) -> tuple[list[Optimizer], list[LRScheduler]]:
|
||||
|
@ -260,15 +265,18 @@ class SimCLRTask(LightningModule): # type: ignore[misc]
|
|||
lr=self.hparams["lr"],
|
||||
weight_decay=self.hparams["weight_decay"],
|
||||
)
|
||||
max_epochs = 200
|
||||
if self.trainer and self.trainer.max_epochs:
|
||||
max_epochs = self.trainer.max_epochs
|
||||
if self.hparams["version"] == 1:
|
||||
warmup_epochs = 10
|
||||
else:
|
||||
warmup_epochs = int(self.trainer.max_epochs * 0.05)
|
||||
warmup_epochs = int(max_epochs * 0.05)
|
||||
lr_scheduler = SequentialLR(
|
||||
optimizer,
|
||||
schedulers=[
|
||||
LinearLR(optimizer, total_iters=warmup_epochs),
|
||||
CosineAnnealingLR(optimizer, T_max=self.trainer.max_epochs),
|
||||
CosineAnnealingLR(optimizer, T_max=max_epochs),
|
||||
],
|
||||
milestones=[warmup_epochs],
|
||||
)
|
||||
|
|
Загрузка…
Ссылка в новой задаче