зеркало из https://github.com/microsoft/torchgeo.git
torchmetrics: IoU -> JaccardIndex (#361)
This commit is contained in:
Родитель
c9520aa3f1
Коммит
37d0e8f028
|
@ -45,5 +45,5 @@ dependencies:
|
||||||
- setuptools>=42
|
- setuptools>=42
|
||||||
- sphinx>=4
|
- sphinx>=4
|
||||||
- timm>=0.2.1
|
- timm>=0.2.1
|
||||||
- torchmetrics
|
- torchmetrics>=0.7
|
||||||
- zipfile-deflate64>=0.2
|
- zipfile-deflate64>=0.2
|
||||||
|
|
12
evaluate.py
12
evaluate.py
|
@ -12,7 +12,7 @@ from typing import Any, Dict, Union
|
||||||
|
|
||||||
import pytorch_lightning as pl
|
import pytorch_lightning as pl
|
||||||
import torch
|
import torch
|
||||||
from torchmetrics import Accuracy, IoU, Metric, MetricCollection
|
from torchmetrics import Accuracy, JaccardIndex, Metric, MetricCollection
|
||||||
|
|
||||||
from torchgeo.trainers import ClassificationTask, SemanticSegmentationTask
|
from torchgeo.trainers import ClassificationTask, SemanticSegmentationTask
|
||||||
from train import TASK_TO_MODULES_MAPPING
|
from train import TASK_TO_MODULES_MAPPING
|
||||||
|
@ -185,7 +185,7 @@ def main(args: argparse.Namespace) -> None:
|
||||||
if args.task == "etci2021": # Custom metric setup for testing ETCI2021
|
if args.task == "etci2021": # Custom metric setup for testing ETCI2021
|
||||||
|
|
||||||
metrics = MetricCollection(
|
metrics = MetricCollection(
|
||||||
[Accuracy(num_classes=2), IoU(num_classes=2, reduction="none")]
|
[Accuracy(num_classes=2), JaccardIndex(num_classes=2, reduction="none")]
|
||||||
).to(device)
|
).to(device)
|
||||||
|
|
||||||
val_results = run_eval_loop(model, dm.val_dataloader(), device, metrics)
|
val_results = run_eval_loop(model, dm.val_dataloader(), device, metrics)
|
||||||
|
@ -194,13 +194,13 @@ def main(args: argparse.Namespace) -> None:
|
||||||
val_row.update(
|
val_row.update(
|
||||||
{
|
{
|
||||||
"overall_accuracy": val_results["Accuracy"].item(),
|
"overall_accuracy": val_results["Accuracy"].item(),
|
||||||
"iou": val_results["IoU"][1].item(),
|
"jaccard_index": val_results["JaccardIndex"][1].item(),
|
||||||
}
|
}
|
||||||
)
|
)
|
||||||
test_row.update(
|
test_row.update(
|
||||||
{
|
{
|
||||||
"overall_accuracy": test_results["Accuracy"].item(),
|
"overall_accuracy": test_results["Accuracy"].item(),
|
||||||
"iou": test_results["IoU"][1].item(),
|
"jaccard_index": test_results["JaccardIndex"][1].item(),
|
||||||
}
|
}
|
||||||
)
|
)
|
||||||
else: # Test with PyTorch Lightning as usual
|
else: # Test with PyTorch Lightning as usual
|
||||||
|
@ -230,13 +230,13 @@ def main(args: argparse.Namespace) -> None:
|
||||||
val_row.update(
|
val_row.update(
|
||||||
{
|
{
|
||||||
"overall_accuracy": val_results["val_Accuracy"].item(),
|
"overall_accuracy": val_results["val_Accuracy"].item(),
|
||||||
"iou": val_results["val_IoU"].item(),
|
"jaccard_index": val_results["val_JaccardIndex"].item(),
|
||||||
}
|
}
|
||||||
)
|
)
|
||||||
test_row.update(
|
test_row.update(
|
||||||
{
|
{
|
||||||
"overall_accuracy": test_results["test_Accuracy"].item(),
|
"overall_accuracy": test_results["test_Accuracy"].item(),
|
||||||
"iou": test_results["test_IoU"].item(),
|
"jaccard_index": test_results["test_JaccardIndex"].item(),
|
||||||
}
|
}
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
|
@ -59,7 +59,8 @@ install_requires =
|
||||||
timm>=0.2.1
|
timm>=0.2.1
|
||||||
# torch 1.7+ required for typing
|
# torch 1.7+ required for typing
|
||||||
torch>=1.7
|
torch>=1.7
|
||||||
torchmetrics
|
# torchmetrics 0.7+ required for JaccardIndex
|
||||||
|
torchmetrics>=0.7
|
||||||
# torchvision 0.10+ required for torchvision.utils.draw_segmentation_masks
|
# torchvision 0.10+ required for torchvision.utils.draw_segmentation_masks
|
||||||
torchvision>=0.10
|
torchvision>=0.10
|
||||||
python_requires = >= 3.6
|
python_requires = >= 3.6
|
||||||
|
|
|
@ -14,7 +14,7 @@ from segmentation_models_pytorch.losses import FocalLoss, JaccardLoss
|
||||||
from torch import Tensor
|
from torch import Tensor
|
||||||
from torch.nn.modules import Conv2d, Linear
|
from torch.nn.modules import Conv2d, Linear
|
||||||
from torch.optim.lr_scheduler import ReduceLROnPlateau
|
from torch.optim.lr_scheduler import ReduceLROnPlateau
|
||||||
from torchmetrics import Accuracy, FBeta, IoU, MetricCollection
|
from torchmetrics import Accuracy, FBeta, JaccardIndex, MetricCollection
|
||||||
|
|
||||||
from ..datasets.utils import unbind_samples
|
from ..datasets.utils import unbind_samples
|
||||||
from . import utils
|
from . import utils
|
||||||
|
@ -105,7 +105,7 @@ class ClassificationTask(pl.LightningModule):
|
||||||
"AverageAccuracy": Accuracy(
|
"AverageAccuracy": Accuracy(
|
||||||
num_classes=self.hparams["num_classes"], average="macro"
|
num_classes=self.hparams["num_classes"], average="macro"
|
||||||
),
|
),
|
||||||
"IoU": IoU(num_classes=self.hparams["num_classes"]),
|
"JaccardIndex": JaccardIndex(num_classes=self.hparams["num_classes"]),
|
||||||
"F1Score": FBeta(
|
"F1Score": FBeta(
|
||||||
num_classes=self.hparams["num_classes"], beta=1.0, average="micro"
|
num_classes=self.hparams["num_classes"], beta=1.0, average="micro"
|
||||||
),
|
),
|
||||||
|
|
|
@ -12,7 +12,7 @@ from pytorch_lightning.core.lightning import LightningModule
|
||||||
from torch import Tensor
|
from torch import Tensor
|
||||||
from torch.optim.lr_scheduler import ReduceLROnPlateau
|
from torch.optim.lr_scheduler import ReduceLROnPlateau
|
||||||
from torch.utils.data import DataLoader
|
from torch.utils.data import DataLoader
|
||||||
from torchmetrics import Accuracy, IoU, MetricCollection
|
from torchmetrics import Accuracy, JaccardIndex, MetricCollection
|
||||||
|
|
||||||
from ..datasets.utils import unbind_samples
|
from ..datasets.utils import unbind_samples
|
||||||
from ..models import FCN
|
from ..models import FCN
|
||||||
|
@ -96,7 +96,7 @@ class SemanticSegmentationTask(LightningModule):
|
||||||
num_classes=self.hparams["num_classes"],
|
num_classes=self.hparams["num_classes"],
|
||||||
ignore_index=self.ignore_zeros,
|
ignore_index=self.ignore_zeros,
|
||||||
),
|
),
|
||||||
IoU(
|
JaccardIndex(
|
||||||
num_classes=self.hparams["num_classes"],
|
num_classes=self.hparams["num_classes"],
|
||||||
ignore_index=self.ignore_zeros,
|
ignore_index=self.ignore_zeros,
|
||||||
),
|
),
|
||||||
|
@ -120,7 +120,7 @@ class SemanticSegmentationTask(LightningModule):
|
||||||
def training_step( # type: ignore[override]
|
def training_step( # type: ignore[override]
|
||||||
self, batch: Dict[str, Any], batch_idx: int
|
self, batch: Dict[str, Any], batch_idx: int
|
||||||
) -> Tensor:
|
) -> Tensor:
|
||||||
"""Training step - reports average accuracy and average IoU.
|
"""Training step - reports average accuracy and average JaccardIndex.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
batch: Current batch
|
batch: Current batch
|
||||||
|
@ -155,7 +155,7 @@ class SemanticSegmentationTask(LightningModule):
|
||||||
def validation_step( # type: ignore[override]
|
def validation_step( # type: ignore[override]
|
||||||
self, batch: Dict[str, Any], batch_idx: int
|
self, batch: Dict[str, Any], batch_idx: int
|
||||||
) -> None:
|
) -> None:
|
||||||
"""Validation step - reports average accuracy and average IoU.
|
"""Validation step - reports average accuracy and average JaccardIndex.
|
||||||
|
|
||||||
Logs the first 10 validation samples to tensorboard as images with 3 subplots
|
Logs the first 10 validation samples to tensorboard as images with 3 subplots
|
||||||
showing the image, mask, and predictions.
|
showing the image, mask, and predictions.
|
||||||
|
|
Загрузка…
Ссылка в новой задаче