зеркало из https://github.com/microsoft/torchgeo.git
Bump pytorch-lightning from 1.6.5 to 1.7.0 in /requirements (#697)
* Bump pytorch-lightning from 1.6.5 to 1.7.0 in /requirements Bumps [pytorch-lightning](https://github.com/Lightning-AI/lightning) from 1.6.5 to 1.7.0. - [Release notes](https://github.com/Lightning-AI/lightning/releases) - [Commits](https://github.com/Lightning-AI/lightning/compare/1.6.5...pl/1.7.0) --- updated-dependencies: - dependency-name: pytorch-lightning dependency-type: direct:production update-type: version-update:semver-minor ... Signed-off-by: dependabot[bot] <support@github.com> * Remove protobuf restrictions * LightningModule was moved * Mypy fixes * Ensure same behavior * Fix docs * Silence warnings * Change error message location Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com> Co-authored-by: Adam J. Stewart <ajstewart426@gmail.com>
This commit is contained in:
Родитель
6d40b08795
Коммит
bc23e21009
|
@ -7,7 +7,5 @@ updates:
|
|||
# Allow up to 2 open pull requests at a time
|
||||
open-pull-requests-limit: 2
|
||||
ignore:
|
||||
# torch, tensorboard require protobuf < 4
|
||||
- dependency-name: "protobuf"
|
||||
# segmentation-models-pytorch requires older timm, can't update
|
||||
- dependency-name: "timm"
|
||||
|
|
|
@ -60,6 +60,7 @@ nitpick_ignore = [
|
|||
("py:class", ".."),
|
||||
# TODO: can't figure out why this isn't found
|
||||
("py:class", "LightningDataModule"),
|
||||
("py:class", "pytorch_lightning.core.module.LightningModule"),
|
||||
# Undocumented class
|
||||
("py:class", "torchvision.models.resnet.ResNet"),
|
||||
("py:class", "segmentation_models_pytorch.base.model.SegmentationModel"),
|
||||
|
|
|
@ -89,8 +89,8 @@ def main(args: argparse.Namespace) -> None:
|
|||
trainer = pl.Trainer(
|
||||
gpus=[args.device] if torch.cuda.is_available() else None,
|
||||
logger=False,
|
||||
progress_bar_refresh_rate=0,
|
||||
checkpoint_callback=False,
|
||||
enable_progress_bar=False,
|
||||
enable_checkpointing=False,
|
||||
)
|
||||
|
||||
for experiment_dir in os.listdir(args.input_dir):
|
||||
|
|
|
@ -71,6 +71,8 @@ filterwarnings = [
|
|||
# https://github.com/PyTorchLightning/pytorch-lightning/issues/13256
|
||||
# https://github.com/PyTorchLightning/pytorch-lightning/pull/13261
|
||||
"ignore:torch.distributed._sharded_tensor will be deprecated:DeprecationWarning:torch.distributed._sharded_tensor",
|
||||
# https://github.com/Lightning-AI/lightning/issues/13989
|
||||
"ignore:SelectableGroups dict interface is deprecated. Use select.:DeprecationWarning:pytorch_lightning.trainer.connectors.callback_connector",
|
||||
# https://github.com/rasterio/rasterio/issues/1742
|
||||
# https://github.com/rasterio/rasterio/pull/1753
|
||||
"ignore:Using or importing the ABCs from 'collections' instead of from 'collections.abc' is deprecated:DeprecationWarning:rasterio.crs",
|
||||
|
|
|
@ -11,7 +11,6 @@ numpy==1.21.6;python_version=='3.7'
|
|||
omegaconf==2.2.2
|
||||
packaging==21.3
|
||||
pillow==9.2.0
|
||||
protobuf==3.20.1
|
||||
pyproj==3.3.1;python_version>='3.8'
|
||||
pyproj==3.2.0;python_version=='3.7'
|
||||
pytorch-lightning==1.6.4
|
||||
|
|
|
@ -10,9 +10,8 @@ numpy==1.23.0;python_version>='3.8'
|
|||
omegaconf==2.2.2
|
||||
packaging==21.3
|
||||
pillow==9.2.0
|
||||
protobuf==3.20.1
|
||||
pyproj==3.3.1;python_version>='3.8'
|
||||
pytorch-lightning==1.6.4
|
||||
pytorch-lightning==1.7.0
|
||||
rasterio==1.3.0;python_version>='3.8'
|
||||
rtree==1.0.0
|
||||
scikit-learn==1.1.1;python_version>='3.8'
|
||||
|
|
|
@ -90,7 +90,7 @@ class BigEarthNetDataModule(pl.LightningDataModule):
|
|||
batch_size: The batch size to use in all created DataLoaders
|
||||
num_workers: The number of workers to use in all created DataLoaders
|
||||
"""
|
||||
super().__init__() # type: ignore[no-untyped-call]
|
||||
super().__init__()
|
||||
self.root_dir = root_dir
|
||||
self.bands = bands
|
||||
self.num_classes = num_classes
|
||||
|
|
|
@ -64,7 +64,7 @@ class ChesapeakeCVPRDataModule(LightningDataModule):
|
|||
Raises:
|
||||
ValueError: if ``use_prior_labels`` is used with ``class_set==7``
|
||||
"""
|
||||
super().__init__() # type: ignore[no-untyped-call]
|
||||
super().__init__()
|
||||
for state in train_splits + val_splits + test_splits:
|
||||
assert state in ChesapeakeCVPR.splits
|
||||
assert class_set in [5, 7]
|
||||
|
|
|
@ -36,7 +36,7 @@ class COWCCountingDataModule(pl.LightningDataModule):
|
|||
batch_size: The batch size to use in all created DataLoaders
|
||||
num_workers: The number of workers to use in all created DataLoaders
|
||||
"""
|
||||
super().__init__() # type: ignore[no-untyped-call]
|
||||
super().__init__()
|
||||
self.root_dir = root_dir
|
||||
self.seed = seed
|
||||
self.batch_size = batch_size
|
||||
|
|
|
@ -44,7 +44,7 @@ class CycloneDataModule(pl.LightningDataModule):
|
|||
api_key: The RadiantEarth MLHub API key to use if the dataset needs to be
|
||||
downloaded
|
||||
"""
|
||||
super().__init__() # type: ignore[no-untyped-call]
|
||||
super().__init__()
|
||||
self.root_dir = root_dir
|
||||
self.seed = seed
|
||||
self.batch_size = batch_size
|
||||
|
|
|
@ -36,7 +36,7 @@ class DeepGlobeLandCoverDataModule(pl.LightningDataModule):
|
|||
num_workers: The number of workers to use in all created DataLoaders
|
||||
val_split_pct: What percentage of the dataset to use as a validation set
|
||||
"""
|
||||
super().__init__() # type: ignore[no-untyped-call]
|
||||
super().__init__()
|
||||
self.root_dir = root_dir
|
||||
self.batch_size = batch_size
|
||||
self.num_workers = num_workers
|
||||
|
|
|
@ -48,7 +48,7 @@ class ETCI2021DataModule(pl.LightningDataModule):
|
|||
batch_size: The batch size to use in all created DataLoaders
|
||||
num_workers: The number of workers to use in all created DataLoaders
|
||||
"""
|
||||
super().__init__() # type: ignore[no-untyped-call]
|
||||
super().__init__()
|
||||
self.root_dir = root_dir
|
||||
self.seed = seed
|
||||
self.batch_size = batch_size
|
||||
|
|
|
@ -68,7 +68,7 @@ class EuroSATDataModule(pl.LightningDataModule):
|
|||
batch_size: The batch size to use in all created DataLoaders
|
||||
num_workers: The number of workers to use in all created DataLoaders
|
||||
"""
|
||||
super().__init__() # type: ignore[no-untyped-call]
|
||||
super().__init__()
|
||||
self.root_dir = root_dir
|
||||
self.batch_size = batch_size
|
||||
self.num_workers = num_workers
|
||||
|
|
|
@ -56,7 +56,7 @@ class FAIR1MDataModule(pl.LightningDataModule):
|
|||
val_split_pct: What percentage of the dataset to use as a validation set
|
||||
test_split_pct: What percentage of the dataset to use as a test set
|
||||
"""
|
||||
super().__init__() # type: ignore[no-untyped-call]
|
||||
super().__init__()
|
||||
self.root_dir = root_dir
|
||||
self.batch_size = batch_size
|
||||
self.num_workers = num_workers
|
||||
|
|
|
@ -65,7 +65,7 @@ class InriaAerialImageLabelingDataModule(pl.LightningDataModule):
|
|||
num_patches_per_tile: Number of random patches per sample
|
||||
predict_on: Directory/Dataset of images to run inference on
|
||||
"""
|
||||
super().__init__() # type: ignore[no-untyped-call]
|
||||
super().__init__()
|
||||
self.root_dir = root_dir
|
||||
self.batch_size = batch_size
|
||||
self.num_workers = num_workers
|
||||
|
|
|
@ -33,7 +33,7 @@ class LandCoverAIDataModule(pl.LightningDataModule):
|
|||
batch_size: The batch size to use in all created DataLoaders
|
||||
num_workers: The number of workers to use in all created DataLoaders
|
||||
"""
|
||||
super().__init__() # type: ignore[no-untyped-call]
|
||||
super().__init__()
|
||||
self.root_dir = root_dir
|
||||
self.batch_size = batch_size
|
||||
self.num_workers = num_workers
|
||||
|
|
|
@ -39,7 +39,7 @@ class LoveDADataModule(pl.LightningDataModule):
|
|||
batch_size: The batch size to use in all created DataLoaders
|
||||
num_workers: The number of workers to use in all created DataLoaders
|
||||
"""
|
||||
super().__init__() # type: ignore[no-untyped-call]
|
||||
super().__init__()
|
||||
self.root_dir = root_dir
|
||||
self.scene = scene
|
||||
self.batch_size = batch_size
|
||||
|
|
|
@ -46,7 +46,7 @@ class NAIPChesapeakeDataModule(pl.LightningDataModule):
|
|||
num_workers: The number of workers to use in all created DataLoaders
|
||||
patch_size: size of patches to sample
|
||||
"""
|
||||
super().__init__() # type: ignore[no-untyped-call]
|
||||
super().__init__()
|
||||
self.naip_root_dir = naip_root_dir
|
||||
self.chesapeake_root_dir = chesapeake_root_dir
|
||||
self.batch_size = batch_size
|
||||
|
|
|
@ -57,7 +57,7 @@ class NASAMarineDebrisDataModule(pl.LightningDataModule):
|
|||
val_split_pct: What percentage of the dataset to use as a validation set
|
||||
test_split_pct: What percentage of the dataset to use as a test set
|
||||
"""
|
||||
super().__init__() # type: ignore[no-untyped-call]
|
||||
super().__init__()
|
||||
self.root_dir = root_dir
|
||||
self.batch_size = batch_size
|
||||
self.num_workers = num_workers
|
||||
|
|
|
@ -87,7 +87,7 @@ class OSCDDataModule(pl.LightningDataModule):
|
|||
num_patches_per_tile: number of random patches per sample
|
||||
pad_size: size to pad images to during val/test steps
|
||||
"""
|
||||
super().__init__() # type: ignore[no-untyped-call]
|
||||
super().__init__()
|
||||
self.root_dir = root_dir
|
||||
self.bands = bands
|
||||
self.train_batch_size = train_batch_size
|
||||
|
|
|
@ -37,7 +37,7 @@ class Potsdam2DDataModule(pl.LightningDataModule):
|
|||
num_workers: The number of workers to use in all created DataLoaders
|
||||
val_split_pct: What percentage of the dataset to use as a validation set
|
||||
"""
|
||||
super().__init__() # type: ignore[no-untyped-call]
|
||||
super().__init__()
|
||||
self.root_dir = root_dir
|
||||
self.batch_size = batch_size
|
||||
self.num_workers = num_workers
|
||||
|
|
|
@ -38,7 +38,7 @@ class RESISC45DataModule(pl.LightningDataModule):
|
|||
batch_size: The batch size to use in all created DataLoaders
|
||||
num_workers: The number of workers to use in all created DataLoaders
|
||||
"""
|
||||
super().__init__() # type: ignore[no-untyped-call]
|
||||
super().__init__()
|
||||
self.root_dir = root_dir
|
||||
self.batch_size = batch_size
|
||||
self.num_workers = num_workers
|
||||
|
|
|
@ -72,7 +72,7 @@ class SEN12MSDataModule(pl.LightningDataModule):
|
|||
batch_size: The batch size to use in all created DataLoaders
|
||||
num_workers: The number of workers to use in all created DataLoaders
|
||||
"""
|
||||
super().__init__() # type: ignore[no-untyped-call]
|
||||
super().__init__()
|
||||
assert band_set in SEN12MS.BAND_SETS.keys()
|
||||
|
||||
self.root_dir = root_dir
|
||||
|
|
|
@ -75,7 +75,7 @@ class So2SatDataModule(pl.LightningDataModule):
|
|||
unsupervised_mode: Makes the train dataloader return imagery from the train,
|
||||
val, and test sets
|
||||
"""
|
||||
super().__init__() # type: ignore[no-untyped-call]
|
||||
super().__init__()
|
||||
self.root_dir = root_dir
|
||||
self.batch_size = batch_size
|
||||
self.num_workers = num_workers
|
||||
|
|
|
@ -34,7 +34,7 @@ class UCMercedDataModule(pl.LightningDataModule):
|
|||
batch_size: The batch size to use in all created DataLoaders
|
||||
num_workers: The number of workers to use in all created DataLoaders
|
||||
"""
|
||||
super().__init__() # type: ignore[no-untyped-call]
|
||||
super().__init__()
|
||||
self.root_dir = root_dir
|
||||
self.batch_size = batch_size
|
||||
self.num_workers = num_workers
|
||||
|
|
|
@ -37,7 +37,7 @@ class Vaihingen2DDataModule(pl.LightningDataModule):
|
|||
num_workers: The number of workers to use in all created DataLoaders
|
||||
val_split_pct: What percentage of the dataset to use as a validation set
|
||||
"""
|
||||
super().__init__() # type: ignore[no-untyped-call]
|
||||
super().__init__()
|
||||
self.root_dir = root_dir
|
||||
self.batch_size = batch_size
|
||||
self.num_workers = num_workers
|
||||
|
|
|
@ -37,7 +37,7 @@ class XView2DataModule(pl.LightningDataModule):
|
|||
num_workers: The number of workers to use in all created DataLoaders
|
||||
val_split_pct: What percentage of the dataset to use as a validation set
|
||||
"""
|
||||
super().__init__() # type: ignore[no-untyped-call]
|
||||
super().__init__()
|
||||
self.root_dir = root_dir
|
||||
self.batch_size = batch_size
|
||||
self.num_workers = num_workers
|
||||
|
|
|
@ -6,6 +6,7 @@
|
|||
import random
|
||||
from typing import Any, Callable, Dict, Optional, Tuple, cast
|
||||
|
||||
import pytorch_lightning as pl
|
||||
import torch
|
||||
import torch.nn.functional as F
|
||||
import torchvision
|
||||
|
@ -13,7 +14,6 @@ from kornia import augmentation as K
|
|||
from kornia import filters
|
||||
from kornia.geometry import transform as KorniaTransform
|
||||
from packaging.version import parse
|
||||
from pytorch_lightning.core.lightning import LightningModule
|
||||
from torch import Tensor, optim
|
||||
from torch.autograd import Variable
|
||||
from torch.nn.modules import BatchNorm1d, Conv2d, Linear, Module, ReLU, Sequential
|
||||
|
@ -304,7 +304,7 @@ class BYOL(Module):
|
|||
pt.data = self.beta * pt.data + (1 - self.beta) * p.data
|
||||
|
||||
|
||||
class BYOLTask(LightningModule):
|
||||
class BYOLTask(pl.LightningModule):
|
||||
"""Class for pre-training any PyTorch model using BYOL."""
|
||||
|
||||
def config_task(self) -> None:
|
||||
|
|
|
@ -189,7 +189,7 @@ class ClassificationTask(pl.LightningModule):
|
|||
|
||||
if batch_idx < 10:
|
||||
try:
|
||||
datamodule = self.trainer.datamodule # type: ignore[union-attr]
|
||||
datamodule = self.trainer.datamodule # type: ignore[attr-defined]
|
||||
batch["prediction"] = y_hat_hard
|
||||
for key in ["image", "label", "prediction"]:
|
||||
batch[key] = batch[key].cpu()
|
||||
|
@ -358,7 +358,7 @@ class MultiLabelClassificationTask(ClassificationTask):
|
|||
|
||||
if batch_idx < 10:
|
||||
try:
|
||||
datamodule = self.trainer.datamodule # type: ignore[union-attr]
|
||||
datamodule = self.trainer.datamodule # type: ignore[attr-defined]
|
||||
batch["prediction"] = y_hat_hard
|
||||
for key in ["image", "label", "prediction"]:
|
||||
batch[key] = batch[key].cpu()
|
||||
|
|
|
@ -127,7 +127,7 @@ class RegressionTask(pl.LightningModule):
|
|||
|
||||
if batch_idx < 10:
|
||||
try:
|
||||
datamodule = self.trainer.datamodule # type: ignore[union-attr]
|
||||
datamodule = self.trainer.datamodule # type: ignore[attr-defined]
|
||||
batch["prediction"] = y_hat
|
||||
for key in ["image", "label", "prediction"]:
|
||||
batch[key] = batch[key].cpu()
|
||||
|
|
|
@ -6,10 +6,10 @@
|
|||
import warnings
|
||||
from typing import Any, Dict, cast
|
||||
|
||||
import pytorch_lightning as pl
|
||||
import segmentation_models_pytorch as smp
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
from pytorch_lightning.core.lightning import LightningModule
|
||||
from torch import Tensor
|
||||
from torch.optim.lr_scheduler import ReduceLROnPlateau
|
||||
from torch.utils.data import DataLoader
|
||||
|
@ -23,7 +23,7 @@ from ..models import FCN
|
|||
DataLoader.__module__ = "torch.utils.data"
|
||||
|
||||
|
||||
class SemanticSegmentationTask(LightningModule):
|
||||
class SemanticSegmentationTask(pl.LightningModule):
|
||||
"""LightningModule for semantic segmentation of images."""
|
||||
|
||||
def config_task(self) -> None:
|
||||
|
@ -184,7 +184,7 @@ class SemanticSegmentationTask(LightningModule):
|
|||
|
||||
if batch_idx < 10:
|
||||
try:
|
||||
datamodule = self.trainer.datamodule # type: ignore[union-attr]
|
||||
datamodule = self.trainer.datamodule # type: ignore[attr-defined]
|
||||
batch["prediction"] = y_hat_hard
|
||||
for key in ["image", "mask", "prediction"]:
|
||||
batch[key] = batch[key].cpu()
|
||||
|
|
Загрузка…
Ссылка в новой задаче