зеркало из https://github.com/microsoft/torchgeo.git
Родитель
8efda7845c
Коммит
3b4436f7e7
|
@ -7,7 +7,7 @@ import csv
|
|||
import os
|
||||
import time
|
||||
|
||||
import lightning as pl
|
||||
import lightning.pytorch as pl
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.optim as optim
|
||||
|
|
|
@ -65,9 +65,9 @@
|
|||
"import matplotlib.pyplot as plt\n",
|
||||
"import numpy as np\n",
|
||||
"import timm\n",
|
||||
"from lightning import Trainer\n",
|
||||
"from lightning.callbacks import EarlyStopping, ModelCheckpoint\n",
|
||||
"from lightning.loggers import CSVLogger\n",
|
||||
"from lightning.pytorch import Trainer\n",
|
||||
"from lightning.pytorch.callbacks import EarlyStopping, ModelCheckpoint\n",
|
||||
"from lightning.pytorch.loggers import CSVLogger\n",
|
||||
"\n",
|
||||
"from torchgeo.datamodules import EuroSATDataModule\n",
|
||||
"from torchgeo.trainers import ClassificationTask\n",
|
||||
|
|
|
@ -74,9 +74,9 @@
|
|||
"\n",
|
||||
"import numpy as np\n",
|
||||
"import matplotlib.pyplot as plt\n",
|
||||
"from lightning import Trainer\n",
|
||||
"from lightning.callbacks import EarlyStopping, ModelCheckpoint\n",
|
||||
"from lightning.loggers import CSVLogger\n",
|
||||
"from lightning.pytorch import Trainer\n",
|
||||
"from lightning.pytorch.callbacks import EarlyStopping, ModelCheckpoint\n",
|
||||
"from lightning.pytorch.loggers import CSVLogger\n",
|
||||
"\n",
|
||||
"from torchgeo.datamodules import TropicalCycloneDataModule\n",
|
||||
"from torchgeo.trainers import RegressionTask"
|
||||
|
|
|
@ -10,7 +10,7 @@ import csv
|
|||
import os
|
||||
from typing import Any, Dict, Union, cast
|
||||
|
||||
import lightning as pl
|
||||
import lightning.pytorch as pl
|
||||
import torch
|
||||
from torchmetrics import MetricCollection
|
||||
from torchmetrics.classification import BinaryAccuracy, BinaryJaccardIndex
|
||||
|
|
|
@ -8,7 +8,7 @@ import argparse
|
|||
import csv
|
||||
import os
|
||||
|
||||
from lightning import Trainer
|
||||
from lightning.pytorch import Trainer
|
||||
|
||||
from torchgeo.datamodules import ChesapeakeCVPRDataModule
|
||||
from torchgeo.trainers.chesapeake import ChesapeakeCVPRSegmentationTask
|
||||
|
|
|
@ -5,7 +5,7 @@ import os
|
|||
|
||||
import pytest
|
||||
from _pytest.fixtures import SubRequest
|
||||
from lightning import Trainer
|
||||
from lightning.pytorch import Trainer
|
||||
|
||||
from torchgeo.datamodules import OSCDDataModule
|
||||
|
||||
|
|
|
@ -12,7 +12,7 @@ import torch.nn as nn
|
|||
import torchvision
|
||||
from _pytest.fixtures import SubRequest
|
||||
from _pytest.monkeypatch import MonkeyPatch
|
||||
from lightning import LightningDataModule, Trainer
|
||||
from lightning.pytorch import LightningDataModule, Trainer
|
||||
from omegaconf import OmegaConf
|
||||
from torchvision.models import resnet18
|
||||
from torchvision.models._api import WeightsEnum
|
||||
|
|
|
@ -11,7 +11,7 @@ import torch
|
|||
import torchvision
|
||||
from _pytest.fixtures import SubRequest
|
||||
from _pytest.monkeypatch import MonkeyPatch
|
||||
from lightning import LightningDataModule, Trainer
|
||||
from lightning.pytorch import LightningDataModule, Trainer
|
||||
from omegaconf import OmegaConf
|
||||
from torch.nn.modules import Module
|
||||
from torchvision.models._api import WeightsEnum
|
||||
|
|
|
@ -9,7 +9,7 @@ import torch
|
|||
import torch.nn as nn
|
||||
import torchvision.models.detection
|
||||
from _pytest.monkeypatch import MonkeyPatch
|
||||
from lightning import LightningDataModule, Trainer
|
||||
from lightning.pytorch import LightningDataModule, Trainer
|
||||
from omegaconf import OmegaConf
|
||||
from torch.nn.modules import Module
|
||||
|
||||
|
|
|
@ -11,7 +11,7 @@ import torch
|
|||
import torchvision
|
||||
from _pytest.fixtures import SubRequest
|
||||
from _pytest.monkeypatch import MonkeyPatch
|
||||
from lightning import LightningDataModule, Trainer
|
||||
from lightning.pytorch import LightningDataModule, Trainer
|
||||
from omegaconf import OmegaConf
|
||||
from torchvision.models._api import WeightsEnum
|
||||
|
||||
|
|
|
@ -7,7 +7,7 @@ from typing import Any, Dict, Type, cast
|
|||
import pytest
|
||||
import segmentation_models_pytorch as smp
|
||||
from _pytest.monkeypatch import MonkeyPatch
|
||||
from lightning import LightningDataModule, Trainer
|
||||
from lightning.pytorch import LightningDataModule, Trainer
|
||||
from omegaconf import OmegaConf
|
||||
from torch.nn.modules import Module
|
||||
|
||||
|
|
|
@ -8,7 +8,7 @@ from typing import Any, Callable, Dict, Optional, Tuple, Type, Union
|
|||
import kornia.augmentation as K
|
||||
import matplotlib.pyplot as plt
|
||||
import torch
|
||||
from lightning import LightningDataModule
|
||||
from lightning.pytorch import LightningDataModule
|
||||
from torch import Tensor
|
||||
from torch.utils.data import DataLoader, Dataset, default_collate
|
||||
|
||||
|
|
|
@ -11,7 +11,7 @@ import torch
|
|||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
from kornia import augmentation as K
|
||||
from lightning import LightningModule
|
||||
from lightning.pytorch import LightningModule
|
||||
from torch import Tensor, optim
|
||||
from torch.optim.lr_scheduler import ReduceLROnPlateau
|
||||
from torchvision.models._api import WeightsEnum
|
||||
|
|
|
@ -10,7 +10,7 @@ import matplotlib.pyplot as plt
|
|||
import timm
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
from lightning import LightningModule
|
||||
from lightning.pytorch import LightningModule
|
||||
from segmentation_models_pytorch.losses import FocalLoss, JaccardLoss
|
||||
from torch import Tensor
|
||||
from torch.optim.lr_scheduler import ReduceLROnPlateau
|
||||
|
|
|
@ -9,7 +9,7 @@ from typing import Any, Dict, List, cast
|
|||
import matplotlib.pyplot as plt
|
||||
import torch
|
||||
import torchvision.models.detection
|
||||
from lightning import LightningModule
|
||||
from lightning.pytorch import LightningModule
|
||||
from torch import Tensor
|
||||
from torch.optim.lr_scheduler import ReduceLROnPlateau
|
||||
from torchmetrics.detection.mean_ap import MeanAveragePrecision
|
||||
|
|
|
@ -10,7 +10,7 @@ import matplotlib.pyplot as plt
|
|||
import timm
|
||||
import torch
|
||||
import torch.nn.functional as F
|
||||
from lightning import LightningModule
|
||||
from lightning.pytorch import LightningModule
|
||||
from torch import Tensor
|
||||
from torch.optim.lr_scheduler import ReduceLROnPlateau
|
||||
from torchmetrics import MeanAbsoluteError, MeanSquaredError, MetricCollection
|
||||
|
|
|
@ -10,7 +10,7 @@ import matplotlib.pyplot as plt
|
|||
import segmentation_models_pytorch as smp
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
from lightning import LightningModule
|
||||
from lightning.pytorch import LightningModule
|
||||
from torch import Tensor
|
||||
from torch.optim.lr_scheduler import ReduceLROnPlateau
|
||||
from torchmetrics import MetricCollection
|
||||
|
|
21
train.py
21
train.py
|
@ -8,9 +8,10 @@
|
|||
import os
|
||||
from typing import Any, Dict, Tuple, Type, cast
|
||||
|
||||
import lightning as L
|
||||
from lightning import loggers
|
||||
from lightning.callbacks import EarlyStopping, ModelCheckpoint
|
||||
import lightning.pytorch as pl
|
||||
from lightning.pytorch import LightningDataModule, LightningModule, Trainer
|
||||
from lightning.pytorch.callbacks import EarlyStopping, ModelCheckpoint
|
||||
from lightning.pytorch.loggers import CSVLogger, TensorBoardLogger
|
||||
from omegaconf import DictConfig, OmegaConf
|
||||
|
||||
from torchgeo.datamodules import (
|
||||
|
@ -45,7 +46,7 @@ from torchgeo.trainers import (
|
|||
)
|
||||
|
||||
TASK_TO_MODULES_MAPPING: Dict[
|
||||
str, Tuple[Type[L.LightningModule], Type[L.LightningDataModule]]
|
||||
str, Tuple[Type[LightningModule], Type[LightningDataModule]]
|
||||
] = {
|
||||
"bigearthnet": (MultiLabelClassificationTask, BigEarthNetDataModule),
|
||||
"byol": (BYOLTask, ChesapeakeCVPRDataModule),
|
||||
|
@ -165,8 +166,8 @@ def main(conf: DictConfig) -> None:
|
|||
Dict[str, Any], OmegaConf.to_object(conf.experiment.datamodule)
|
||||
)
|
||||
|
||||
datamodule: L.LightningDataModule
|
||||
task: L.LightningModule
|
||||
datamodule: LightningDataModule
|
||||
task: LightningModule
|
||||
if task_name in TASK_TO_MODULES_MAPPING:
|
||||
task_class, datamodule_class = TASK_TO_MODULES_MAPPING[task_name]
|
||||
task = task_class(**task_args)
|
||||
|
@ -179,8 +180,8 @@ def main(conf: DictConfig) -> None:
|
|||
######################################
|
||||
# Setup trainer
|
||||
######################################
|
||||
tb_logger = loggers.TensorBoardLogger(conf.program.log_dir, name=experiment_name)
|
||||
csv_logger = loggers.CSVLogger(conf.program.log_dir, name=experiment_name)
|
||||
tb_logger = TensorBoardLogger(conf.program.log_dir, name=experiment_name)
|
||||
csv_logger = CSVLogger(conf.program.log_dir, name=experiment_name)
|
||||
|
||||
if isinstance(task, ObjectDetectionTask):
|
||||
monitor_metric = "val_map"
|
||||
|
@ -205,7 +206,7 @@ def main(conf: DictConfig) -> None:
|
|||
trainer_args["callbacks"] = [checkpoint_callback, early_stopping_callback]
|
||||
trainer_args["logger"] = [tb_logger, csv_logger]
|
||||
trainer_args["default_root_dir"] = experiment_dir
|
||||
trainer = L.Trainer(**trainer_args)
|
||||
trainer = Trainer(**trainer_args)
|
||||
|
||||
######################################
|
||||
# Run experiment
|
||||
|
@ -229,7 +230,7 @@ if __name__ == "__main__":
|
|||
|
||||
# Set random seed for reproducibility
|
||||
# https://pytorch-lightning.readthedocs.io/en/latest/api/pytorch_lightning.utilities.seed.html#pytorch_lightning.utilities.seed.seed_everything
|
||||
L.seed_everything(conf.program.seed)
|
||||
pl.seed_everything(conf.program.seed)
|
||||
|
||||
# Main training procedure
|
||||
main(conf)
|
||||
|
|
Загрузка…
Ссылка в новой задаче