зеркало из https://github.com/microsoft/torchgeo.git
Bump pytorch-lightning from 1.6.5 to 1.7.0 in /requirements (#697)
* Bump pytorch-lightning from 1.6.5 to 1.7.0 in /requirements Bumps [pytorch-lightning](https://github.com/Lightning-AI/lightning) from 1.6.5 to 1.7.0. - [Release notes](https://github.com/Lightning-AI/lightning/releases) - [Commits](https://github.com/Lightning-AI/lightning/compare/1.6.5...pl/1.7.0) --- updated-dependencies: - dependency-name: pytorch-lightning dependency-type: direct:production update-type: version-update:semver-minor ... Signed-off-by: dependabot[bot] <support@github.com> * Remove protobuf restrictions * LightningModule was moved * Mypy fixes * Ensure same behavior * Fix docs * Silence warnings * Change error message location Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com> Co-authored-by: Adam J. Stewart <ajstewart426@gmail.com>
This commit is contained in:
Родитель
577142435a
Коммит
044d901dca
|
@ -7,7 +7,5 @@ updates:
|
||||||
# Allow up to 2 open pull requests at a time
|
# Allow up to 2 open pull requests at a time
|
||||||
open-pull-requests-limit: 2
|
open-pull-requests-limit: 2
|
||||||
ignore:
|
ignore:
|
||||||
# torch, tensorboard require protobuf < 4
|
|
||||||
- dependency-name: "protobuf"
|
|
||||||
# segmentation-models-pytorch requires older timm, can't update
|
# segmentation-models-pytorch requires older timm, can't update
|
||||||
- dependency-name: "timm"
|
- dependency-name: "timm"
|
||||||
|
|
|
@ -60,6 +60,7 @@ nitpick_ignore = [
|
||||||
("py:class", ".."),
|
("py:class", ".."),
|
||||||
# TODO: can't figure out why this isn't found
|
# TODO: can't figure out why this isn't found
|
||||||
("py:class", "LightningDataModule"),
|
("py:class", "LightningDataModule"),
|
||||||
|
("py:class", "pytorch_lightning.core.module.LightningModule"),
|
||||||
# Undocumented class
|
# Undocumented class
|
||||||
("py:class", "torchvision.models.resnet.ResNet"),
|
("py:class", "torchvision.models.resnet.ResNet"),
|
||||||
("py:class", "segmentation_models_pytorch.base.model.SegmentationModel"),
|
("py:class", "segmentation_models_pytorch.base.model.SegmentationModel"),
|
||||||
|
|
|
@ -89,8 +89,8 @@ def main(args: argparse.Namespace) -> None:
|
||||||
trainer = pl.Trainer(
|
trainer = pl.Trainer(
|
||||||
gpus=[args.device] if torch.cuda.is_available() else None,
|
gpus=[args.device] if torch.cuda.is_available() else None,
|
||||||
logger=False,
|
logger=False,
|
||||||
progress_bar_refresh_rate=0,
|
enable_progress_bar=False,
|
||||||
checkpoint_callback=False,
|
enable_checkpointing=False,
|
||||||
)
|
)
|
||||||
|
|
||||||
for experiment_dir in os.listdir(args.input_dir):
|
for experiment_dir in os.listdir(args.input_dir):
|
||||||
|
|
|
@ -71,6 +71,8 @@ filterwarnings = [
|
||||||
# https://github.com/PyTorchLightning/pytorch-lightning/issues/13256
|
# https://github.com/PyTorchLightning/pytorch-lightning/issues/13256
|
||||||
# https://github.com/PyTorchLightning/pytorch-lightning/pull/13261
|
# https://github.com/PyTorchLightning/pytorch-lightning/pull/13261
|
||||||
"ignore:torch.distributed._sharded_tensor will be deprecated:DeprecationWarning:torch.distributed._sharded_tensor",
|
"ignore:torch.distributed._sharded_tensor will be deprecated:DeprecationWarning:torch.distributed._sharded_tensor",
|
||||||
|
# https://github.com/Lightning-AI/lightning/issues/13989
|
||||||
|
"ignore:SelectableGroups dict interface is deprecated. Use select.:DeprecationWarning:pytorch_lightning.trainer.connectors.callback_connector",
|
||||||
# https://github.com/rasterio/rasterio/issues/1742
|
# https://github.com/rasterio/rasterio/issues/1742
|
||||||
# https://github.com/rasterio/rasterio/pull/1753
|
# https://github.com/rasterio/rasterio/pull/1753
|
||||||
"ignore:Using or importing the ABCs from 'collections' instead of from 'collections.abc' is deprecated:DeprecationWarning:rasterio.crs",
|
"ignore:Using or importing the ABCs from 'collections' instead of from 'collections.abc' is deprecated:DeprecationWarning:rasterio.crs",
|
||||||
|
|
|
@ -11,7 +11,6 @@ numpy==1.21.6;python_version=='3.7'
|
||||||
omegaconf==2.2.2
|
omegaconf==2.2.2
|
||||||
packaging==21.3
|
packaging==21.3
|
||||||
pillow==9.2.0
|
pillow==9.2.0
|
||||||
protobuf==3.20.1
|
|
||||||
pyproj==3.3.1;python_version>='3.8'
|
pyproj==3.3.1;python_version>='3.8'
|
||||||
pyproj==3.2.0;python_version=='3.7'
|
pyproj==3.2.0;python_version=='3.7'
|
||||||
pytorch-lightning==1.6.4
|
pytorch-lightning==1.6.4
|
||||||
|
|
|
@ -10,9 +10,8 @@ numpy==1.23.1;python_version>='3.8'
|
||||||
omegaconf==2.2.2
|
omegaconf==2.2.2
|
||||||
packaging==21.3
|
packaging==21.3
|
||||||
pillow==9.2.0
|
pillow==9.2.0
|
||||||
protobuf==3.20.1
|
|
||||||
pyproj==3.3.1;python_version>='3.8'
|
pyproj==3.3.1;python_version>='3.8'
|
||||||
pytorch-lightning==1.6.5
|
pytorch-lightning==1.7.0
|
||||||
rasterio==1.3.0;python_version>='3.8'
|
rasterio==1.3.0;python_version>='3.8'
|
||||||
rtree==1.0.0
|
rtree==1.0.0
|
||||||
scikit-learn==1.1.1;python_version>='3.8'
|
scikit-learn==1.1.1;python_version>='3.8'
|
||||||
|
|
|
@ -90,7 +90,7 @@ class BigEarthNetDataModule(pl.LightningDataModule):
|
||||||
batch_size: The batch size to use in all created DataLoaders
|
batch_size: The batch size to use in all created DataLoaders
|
||||||
num_workers: The number of workers to use in all created DataLoaders
|
num_workers: The number of workers to use in all created DataLoaders
|
||||||
"""
|
"""
|
||||||
super().__init__() # type: ignore[no-untyped-call]
|
super().__init__()
|
||||||
self.root_dir = root_dir
|
self.root_dir = root_dir
|
||||||
self.bands = bands
|
self.bands = bands
|
||||||
self.num_classes = num_classes
|
self.num_classes = num_classes
|
||||||
|
|
|
@ -64,7 +64,7 @@ class ChesapeakeCVPRDataModule(LightningDataModule):
|
||||||
Raises:
|
Raises:
|
||||||
ValueError: if ``use_prior_labels`` is used with ``class_set==7``
|
ValueError: if ``use_prior_labels`` is used with ``class_set==7``
|
||||||
"""
|
"""
|
||||||
super().__init__() # type: ignore[no-untyped-call]
|
super().__init__()
|
||||||
for state in train_splits + val_splits + test_splits:
|
for state in train_splits + val_splits + test_splits:
|
||||||
assert state in ChesapeakeCVPR.splits
|
assert state in ChesapeakeCVPR.splits
|
||||||
assert class_set in [5, 7]
|
assert class_set in [5, 7]
|
||||||
|
|
|
@ -36,7 +36,7 @@ class COWCCountingDataModule(pl.LightningDataModule):
|
||||||
batch_size: The batch size to use in all created DataLoaders
|
batch_size: The batch size to use in all created DataLoaders
|
||||||
num_workers: The number of workers to use in all created DataLoaders
|
num_workers: The number of workers to use in all created DataLoaders
|
||||||
"""
|
"""
|
||||||
super().__init__() # type: ignore[no-untyped-call]
|
super().__init__()
|
||||||
self.root_dir = root_dir
|
self.root_dir = root_dir
|
||||||
self.seed = seed
|
self.seed = seed
|
||||||
self.batch_size = batch_size
|
self.batch_size = batch_size
|
||||||
|
|
|
@ -44,7 +44,7 @@ class CycloneDataModule(pl.LightningDataModule):
|
||||||
api_key: The RadiantEarth MLHub API key to use if the dataset needs to be
|
api_key: The RadiantEarth MLHub API key to use if the dataset needs to be
|
||||||
downloaded
|
downloaded
|
||||||
"""
|
"""
|
||||||
super().__init__() # type: ignore[no-untyped-call]
|
super().__init__()
|
||||||
self.root_dir = root_dir
|
self.root_dir = root_dir
|
||||||
self.seed = seed
|
self.seed = seed
|
||||||
self.batch_size = batch_size
|
self.batch_size = batch_size
|
||||||
|
|
|
@ -36,7 +36,7 @@ class DeepGlobeLandCoverDataModule(pl.LightningDataModule):
|
||||||
num_workers: The number of workers to use in all created DataLoaders
|
num_workers: The number of workers to use in all created DataLoaders
|
||||||
val_split_pct: What percentage of the dataset to use as a validation set
|
val_split_pct: What percentage of the dataset to use as a validation set
|
||||||
"""
|
"""
|
||||||
super().__init__() # type: ignore[no-untyped-call]
|
super().__init__()
|
||||||
self.root_dir = root_dir
|
self.root_dir = root_dir
|
||||||
self.batch_size = batch_size
|
self.batch_size = batch_size
|
||||||
self.num_workers = num_workers
|
self.num_workers = num_workers
|
||||||
|
|
|
@ -48,7 +48,7 @@ class ETCI2021DataModule(pl.LightningDataModule):
|
||||||
batch_size: The batch size to use in all created DataLoaders
|
batch_size: The batch size to use in all created DataLoaders
|
||||||
num_workers: The number of workers to use in all created DataLoaders
|
num_workers: The number of workers to use in all created DataLoaders
|
||||||
"""
|
"""
|
||||||
super().__init__() # type: ignore[no-untyped-call]
|
super().__init__()
|
||||||
self.root_dir = root_dir
|
self.root_dir = root_dir
|
||||||
self.seed = seed
|
self.seed = seed
|
||||||
self.batch_size = batch_size
|
self.batch_size = batch_size
|
||||||
|
|
|
@ -68,7 +68,7 @@ class EuroSATDataModule(pl.LightningDataModule):
|
||||||
batch_size: The batch size to use in all created DataLoaders
|
batch_size: The batch size to use in all created DataLoaders
|
||||||
num_workers: The number of workers to use in all created DataLoaders
|
num_workers: The number of workers to use in all created DataLoaders
|
||||||
"""
|
"""
|
||||||
super().__init__() # type: ignore[no-untyped-call]
|
super().__init__()
|
||||||
self.root_dir = root_dir
|
self.root_dir = root_dir
|
||||||
self.batch_size = batch_size
|
self.batch_size = batch_size
|
||||||
self.num_workers = num_workers
|
self.num_workers = num_workers
|
||||||
|
|
|
@ -56,7 +56,7 @@ class FAIR1MDataModule(pl.LightningDataModule):
|
||||||
val_split_pct: What percentage of the dataset to use as a validation set
|
val_split_pct: What percentage of the dataset to use as a validation set
|
||||||
test_split_pct: What percentage of the dataset to use as a test set
|
test_split_pct: What percentage of the dataset to use as a test set
|
||||||
"""
|
"""
|
||||||
super().__init__() # type: ignore[no-untyped-call]
|
super().__init__()
|
||||||
self.root_dir = root_dir
|
self.root_dir = root_dir
|
||||||
self.batch_size = batch_size
|
self.batch_size = batch_size
|
||||||
self.num_workers = num_workers
|
self.num_workers = num_workers
|
||||||
|
|
|
@ -65,7 +65,7 @@ class InriaAerialImageLabelingDataModule(pl.LightningDataModule):
|
||||||
num_patches_per_tile: Number of random patches per sample
|
num_patches_per_tile: Number of random patches per sample
|
||||||
predict_on: Directory/Dataset of images to run inference on
|
predict_on: Directory/Dataset of images to run inference on
|
||||||
"""
|
"""
|
||||||
super().__init__() # type: ignore[no-untyped-call]
|
super().__init__()
|
||||||
self.root_dir = root_dir
|
self.root_dir = root_dir
|
||||||
self.batch_size = batch_size
|
self.batch_size = batch_size
|
||||||
self.num_workers = num_workers
|
self.num_workers = num_workers
|
||||||
|
|
|
@ -33,7 +33,7 @@ class LandCoverAIDataModule(pl.LightningDataModule):
|
||||||
batch_size: The batch size to use in all created DataLoaders
|
batch_size: The batch size to use in all created DataLoaders
|
||||||
num_workers: The number of workers to use in all created DataLoaders
|
num_workers: The number of workers to use in all created DataLoaders
|
||||||
"""
|
"""
|
||||||
super().__init__() # type: ignore[no-untyped-call]
|
super().__init__()
|
||||||
self.root_dir = root_dir
|
self.root_dir = root_dir
|
||||||
self.batch_size = batch_size
|
self.batch_size = batch_size
|
||||||
self.num_workers = num_workers
|
self.num_workers = num_workers
|
||||||
|
|
|
@ -39,7 +39,7 @@ class LoveDADataModule(pl.LightningDataModule):
|
||||||
batch_size: The batch size to use in all created DataLoaders
|
batch_size: The batch size to use in all created DataLoaders
|
||||||
num_workers: The number of workers to use in all created DataLoaders
|
num_workers: The number of workers to use in all created DataLoaders
|
||||||
"""
|
"""
|
||||||
super().__init__() # type: ignore[no-untyped-call]
|
super().__init__()
|
||||||
self.root_dir = root_dir
|
self.root_dir = root_dir
|
||||||
self.scene = scene
|
self.scene = scene
|
||||||
self.batch_size = batch_size
|
self.batch_size = batch_size
|
||||||
|
|
|
@ -46,7 +46,7 @@ class NAIPChesapeakeDataModule(pl.LightningDataModule):
|
||||||
num_workers: The number of workers to use in all created DataLoaders
|
num_workers: The number of workers to use in all created DataLoaders
|
||||||
patch_size: size of patches to sample
|
patch_size: size of patches to sample
|
||||||
"""
|
"""
|
||||||
super().__init__() # type: ignore[no-untyped-call]
|
super().__init__()
|
||||||
self.naip_root_dir = naip_root_dir
|
self.naip_root_dir = naip_root_dir
|
||||||
self.chesapeake_root_dir = chesapeake_root_dir
|
self.chesapeake_root_dir = chesapeake_root_dir
|
||||||
self.batch_size = batch_size
|
self.batch_size = batch_size
|
||||||
|
|
|
@ -57,7 +57,7 @@ class NASAMarineDebrisDataModule(pl.LightningDataModule):
|
||||||
val_split_pct: What percentage of the dataset to use as a validation set
|
val_split_pct: What percentage of the dataset to use as a validation set
|
||||||
test_split_pct: What percentage of the dataset to use as a test set
|
test_split_pct: What percentage of the dataset to use as a test set
|
||||||
"""
|
"""
|
||||||
super().__init__() # type: ignore[no-untyped-call]
|
super().__init__()
|
||||||
self.root_dir = root_dir
|
self.root_dir = root_dir
|
||||||
self.batch_size = batch_size
|
self.batch_size = batch_size
|
||||||
self.num_workers = num_workers
|
self.num_workers = num_workers
|
||||||
|
|
|
@ -87,7 +87,7 @@ class OSCDDataModule(pl.LightningDataModule):
|
||||||
num_patches_per_tile: number of random patches per sample
|
num_patches_per_tile: number of random patches per sample
|
||||||
pad_size: size to pad images to during val/test steps
|
pad_size: size to pad images to during val/test steps
|
||||||
"""
|
"""
|
||||||
super().__init__() # type: ignore[no-untyped-call]
|
super().__init__()
|
||||||
self.root_dir = root_dir
|
self.root_dir = root_dir
|
||||||
self.bands = bands
|
self.bands = bands
|
||||||
self.train_batch_size = train_batch_size
|
self.train_batch_size = train_batch_size
|
||||||
|
|
|
@ -37,7 +37,7 @@ class Potsdam2DDataModule(pl.LightningDataModule):
|
||||||
num_workers: The number of workers to use in all created DataLoaders
|
num_workers: The number of workers to use in all created DataLoaders
|
||||||
val_split_pct: What percentage of the dataset to use as a validation set
|
val_split_pct: What percentage of the dataset to use as a validation set
|
||||||
"""
|
"""
|
||||||
super().__init__() # type: ignore[no-untyped-call]
|
super().__init__()
|
||||||
self.root_dir = root_dir
|
self.root_dir = root_dir
|
||||||
self.batch_size = batch_size
|
self.batch_size = batch_size
|
||||||
self.num_workers = num_workers
|
self.num_workers = num_workers
|
||||||
|
|
|
@ -38,7 +38,7 @@ class RESISC45DataModule(pl.LightningDataModule):
|
||||||
batch_size: The batch size to use in all created DataLoaders
|
batch_size: The batch size to use in all created DataLoaders
|
||||||
num_workers: The number of workers to use in all created DataLoaders
|
num_workers: The number of workers to use in all created DataLoaders
|
||||||
"""
|
"""
|
||||||
super().__init__() # type: ignore[no-untyped-call]
|
super().__init__()
|
||||||
self.root_dir = root_dir
|
self.root_dir = root_dir
|
||||||
self.batch_size = batch_size
|
self.batch_size = batch_size
|
||||||
self.num_workers = num_workers
|
self.num_workers = num_workers
|
||||||
|
|
|
@ -72,7 +72,7 @@ class SEN12MSDataModule(pl.LightningDataModule):
|
||||||
batch_size: The batch size to use in all created DataLoaders
|
batch_size: The batch size to use in all created DataLoaders
|
||||||
num_workers: The number of workers to use in all created DataLoaders
|
num_workers: The number of workers to use in all created DataLoaders
|
||||||
"""
|
"""
|
||||||
super().__init__() # type: ignore[no-untyped-call]
|
super().__init__()
|
||||||
assert band_set in SEN12MS.BAND_SETS.keys()
|
assert band_set in SEN12MS.BAND_SETS.keys()
|
||||||
|
|
||||||
self.root_dir = root_dir
|
self.root_dir = root_dir
|
||||||
|
|
|
@ -75,7 +75,7 @@ class So2SatDataModule(pl.LightningDataModule):
|
||||||
unsupervised_mode: Makes the train dataloader return imagery from the train,
|
unsupervised_mode: Makes the train dataloader return imagery from the train,
|
||||||
val, and test sets
|
val, and test sets
|
||||||
"""
|
"""
|
||||||
super().__init__() # type: ignore[no-untyped-call]
|
super().__init__()
|
||||||
self.root_dir = root_dir
|
self.root_dir = root_dir
|
||||||
self.batch_size = batch_size
|
self.batch_size = batch_size
|
||||||
self.num_workers = num_workers
|
self.num_workers = num_workers
|
||||||
|
|
|
@ -34,7 +34,7 @@ class UCMercedDataModule(pl.LightningDataModule):
|
||||||
batch_size: The batch size to use in all created DataLoaders
|
batch_size: The batch size to use in all created DataLoaders
|
||||||
num_workers: The number of workers to use in all created DataLoaders
|
num_workers: The number of workers to use in all created DataLoaders
|
||||||
"""
|
"""
|
||||||
super().__init__() # type: ignore[no-untyped-call]
|
super().__init__()
|
||||||
self.root_dir = root_dir
|
self.root_dir = root_dir
|
||||||
self.batch_size = batch_size
|
self.batch_size = batch_size
|
||||||
self.num_workers = num_workers
|
self.num_workers = num_workers
|
||||||
|
|
|
@ -37,7 +37,7 @@ class Vaihingen2DDataModule(pl.LightningDataModule):
|
||||||
num_workers: The number of workers to use in all created DataLoaders
|
num_workers: The number of workers to use in all created DataLoaders
|
||||||
val_split_pct: What percentage of the dataset to use as a validation set
|
val_split_pct: What percentage of the dataset to use as a validation set
|
||||||
"""
|
"""
|
||||||
super().__init__() # type: ignore[no-untyped-call]
|
super().__init__()
|
||||||
self.root_dir = root_dir
|
self.root_dir = root_dir
|
||||||
self.batch_size = batch_size
|
self.batch_size = batch_size
|
||||||
self.num_workers = num_workers
|
self.num_workers = num_workers
|
||||||
|
|
|
@ -37,7 +37,7 @@ class XView2DataModule(pl.LightningDataModule):
|
||||||
num_workers: The number of workers to use in all created DataLoaders
|
num_workers: The number of workers to use in all created DataLoaders
|
||||||
val_split_pct: What percentage of the dataset to use as a validation set
|
val_split_pct: What percentage of the dataset to use as a validation set
|
||||||
"""
|
"""
|
||||||
super().__init__() # type: ignore[no-untyped-call]
|
super().__init__()
|
||||||
self.root_dir = root_dir
|
self.root_dir = root_dir
|
||||||
self.batch_size = batch_size
|
self.batch_size = batch_size
|
||||||
self.num_workers = num_workers
|
self.num_workers = num_workers
|
||||||
|
|
|
@ -6,6 +6,7 @@
|
||||||
import random
|
import random
|
||||||
from typing import Any, Callable, Dict, Optional, Tuple, cast
|
from typing import Any, Callable, Dict, Optional, Tuple, cast
|
||||||
|
|
||||||
|
import pytorch_lightning as pl
|
||||||
import torch
|
import torch
|
||||||
import torch.nn.functional as F
|
import torch.nn.functional as F
|
||||||
import torchvision
|
import torchvision
|
||||||
|
@ -13,7 +14,6 @@ from kornia import augmentation as K
|
||||||
from kornia import filters
|
from kornia import filters
|
||||||
from kornia.geometry import transform as KorniaTransform
|
from kornia.geometry import transform as KorniaTransform
|
||||||
from packaging.version import parse
|
from packaging.version import parse
|
||||||
from pytorch_lightning.core.lightning import LightningModule
|
|
||||||
from torch import Tensor, optim
|
from torch import Tensor, optim
|
||||||
from torch.autograd import Variable
|
from torch.autograd import Variable
|
||||||
from torch.nn.modules import BatchNorm1d, Conv2d, Linear, Module, ReLU, Sequential
|
from torch.nn.modules import BatchNorm1d, Conv2d, Linear, Module, ReLU, Sequential
|
||||||
|
@ -304,7 +304,7 @@ class BYOL(Module):
|
||||||
pt.data = self.beta * pt.data + (1 - self.beta) * p.data
|
pt.data = self.beta * pt.data + (1 - self.beta) * p.data
|
||||||
|
|
||||||
|
|
||||||
class BYOLTask(LightningModule):
|
class BYOLTask(pl.LightningModule):
|
||||||
"""Class for pre-training any PyTorch model using BYOL."""
|
"""Class for pre-training any PyTorch model using BYOL."""
|
||||||
|
|
||||||
def config_task(self) -> None:
|
def config_task(self) -> None:
|
||||||
|
|
|
@ -189,7 +189,7 @@ class ClassificationTask(pl.LightningModule):
|
||||||
|
|
||||||
if batch_idx < 10:
|
if batch_idx < 10:
|
||||||
try:
|
try:
|
||||||
datamodule = self.trainer.datamodule # type: ignore[union-attr]
|
datamodule = self.trainer.datamodule # type: ignore[attr-defined]
|
||||||
batch["prediction"] = y_hat_hard
|
batch["prediction"] = y_hat_hard
|
||||||
for key in ["image", "label", "prediction"]:
|
for key in ["image", "label", "prediction"]:
|
||||||
batch[key] = batch[key].cpu()
|
batch[key] = batch[key].cpu()
|
||||||
|
@ -358,7 +358,7 @@ class MultiLabelClassificationTask(ClassificationTask):
|
||||||
|
|
||||||
if batch_idx < 10:
|
if batch_idx < 10:
|
||||||
try:
|
try:
|
||||||
datamodule = self.trainer.datamodule # type: ignore[union-attr]
|
datamodule = self.trainer.datamodule # type: ignore[attr-defined]
|
||||||
batch["prediction"] = y_hat_hard
|
batch["prediction"] = y_hat_hard
|
||||||
for key in ["image", "label", "prediction"]:
|
for key in ["image", "label", "prediction"]:
|
||||||
batch[key] = batch[key].cpu()
|
batch[key] = batch[key].cpu()
|
||||||
|
|
|
@ -127,7 +127,7 @@ class RegressionTask(pl.LightningModule):
|
||||||
|
|
||||||
if batch_idx < 10:
|
if batch_idx < 10:
|
||||||
try:
|
try:
|
||||||
datamodule = self.trainer.datamodule # type: ignore[union-attr]
|
datamodule = self.trainer.datamodule # type: ignore[attr-defined]
|
||||||
batch["prediction"] = y_hat
|
batch["prediction"] = y_hat
|
||||||
for key in ["image", "label", "prediction"]:
|
for key in ["image", "label", "prediction"]:
|
||||||
batch[key] = batch[key].cpu()
|
batch[key] = batch[key].cpu()
|
||||||
|
|
|
@ -6,10 +6,10 @@
|
||||||
import warnings
|
import warnings
|
||||||
from typing import Any, Dict, cast
|
from typing import Any, Dict, cast
|
||||||
|
|
||||||
|
import pytorch_lightning as pl
|
||||||
import segmentation_models_pytorch as smp
|
import segmentation_models_pytorch as smp
|
||||||
import torch
|
import torch
|
||||||
import torch.nn as nn
|
import torch.nn as nn
|
||||||
from pytorch_lightning.core.lightning import LightningModule
|
|
||||||
from torch import Tensor
|
from torch import Tensor
|
||||||
from torch.optim.lr_scheduler import ReduceLROnPlateau
|
from torch.optim.lr_scheduler import ReduceLROnPlateau
|
||||||
from torch.utils.data import DataLoader
|
from torch.utils.data import DataLoader
|
||||||
|
@ -23,7 +23,7 @@ from ..models import FCN
|
||||||
DataLoader.__module__ = "torch.utils.data"
|
DataLoader.__module__ = "torch.utils.data"
|
||||||
|
|
||||||
|
|
||||||
class SemanticSegmentationTask(LightningModule):
|
class SemanticSegmentationTask(pl.LightningModule):
|
||||||
"""LightningModule for semantic segmentation of images."""
|
"""LightningModule for semantic segmentation of images."""
|
||||||
|
|
||||||
def config_task(self) -> None:
|
def config_task(self) -> None:
|
||||||
|
@ -184,7 +184,7 @@ class SemanticSegmentationTask(LightningModule):
|
||||||
|
|
||||||
if batch_idx < 10:
|
if batch_idx < 10:
|
||||||
try:
|
try:
|
||||||
datamodule = self.trainer.datamodule # type: ignore[union-attr]
|
datamodule = self.trainer.datamodule # type: ignore[attr-defined]
|
||||||
batch["prediction"] = y_hat_hard
|
batch["prediction"] = y_hat_hard
|
||||||
for key in ["image", "mask", "prediction"]:
|
for key in ["image", "mask", "prediction"]:
|
||||||
batch[key] = batch[key].cpu()
|
batch[key] = batch[key].cpu()
|
||||||
|
|
Загрузка…
Ссылка в новой задаче