torchmetrics: IoU -> JaccardIndex (#361)

This commit is contained in:
Adam J. Stewart 2022-01-18 14:34:15 -06:00
Родитель c9520aa3f1
Коммит 37d0e8f028
Не найден ключ, соответствующий данной подписи
Идентификатор ключа GPG: C66C0675661156FC
5 изменённых файлов: 15 добавлений и 14 удалений

Просмотреть файл

@ -45,5 +45,5 @@ dependencies:
- setuptools>=42 - setuptools>=42
- sphinx>=4 - sphinx>=4
- timm>=0.2.1 - timm>=0.2.1
- torchmetrics - torchmetrics>=0.7
- zipfile-deflate64>=0.2 - zipfile-deflate64>=0.2

Просмотреть файл

@ -12,7 +12,7 @@ from typing import Any, Dict, Union
import pytorch_lightning as pl import pytorch_lightning as pl
import torch import torch
from torchmetrics import Accuracy, IoU, Metric, MetricCollection from torchmetrics import Accuracy, JaccardIndex, Metric, MetricCollection
from torchgeo.trainers import ClassificationTask, SemanticSegmentationTask from torchgeo.trainers import ClassificationTask, SemanticSegmentationTask
from train import TASK_TO_MODULES_MAPPING from train import TASK_TO_MODULES_MAPPING
@ -185,7 +185,7 @@ def main(args: argparse.Namespace) -> None:
if args.task == "etci2021": # Custom metric setup for testing ETCI2021 if args.task == "etci2021": # Custom metric setup for testing ETCI2021
metrics = MetricCollection( metrics = MetricCollection(
[Accuracy(num_classes=2), IoU(num_classes=2, reduction="none")] [Accuracy(num_classes=2), JaccardIndex(num_classes=2, reduction="none")]
).to(device) ).to(device)
val_results = run_eval_loop(model, dm.val_dataloader(), device, metrics) val_results = run_eval_loop(model, dm.val_dataloader(), device, metrics)
@ -194,13 +194,13 @@ def main(args: argparse.Namespace) -> None:
val_row.update( val_row.update(
{ {
"overall_accuracy": val_results["Accuracy"].item(), "overall_accuracy": val_results["Accuracy"].item(),
"iou": val_results["IoU"][1].item(), "jaccard_index": val_results["JaccardIndex"][1].item(),
} }
) )
test_row.update( test_row.update(
{ {
"overall_accuracy": test_results["Accuracy"].item(), "overall_accuracy": test_results["Accuracy"].item(),
"iou": test_results["IoU"][1].item(), "jaccard_index": test_results["JaccardIndex"][1].item(),
} }
) )
else: # Test with PyTorch Lightning as usual else: # Test with PyTorch Lightning as usual
@ -230,13 +230,13 @@ def main(args: argparse.Namespace) -> None:
val_row.update( val_row.update(
{ {
"overall_accuracy": val_results["val_Accuracy"].item(), "overall_accuracy": val_results["val_Accuracy"].item(),
"iou": val_results["val_IoU"].item(), "jaccard_index": val_results["val_JaccardIndex"].item(),
} }
) )
test_row.update( test_row.update(
{ {
"overall_accuracy": test_results["test_Accuracy"].item(), "overall_accuracy": test_results["test_Accuracy"].item(),
"iou": test_results["test_IoU"].item(), "jaccard_index": test_results["test_JaccardIndex"].item(),
} }
) )

Просмотреть файл

@ -59,7 +59,8 @@ install_requires =
timm>=0.2.1 timm>=0.2.1
# torch 1.7+ required for typing # torch 1.7+ required for typing
torch>=1.7 torch>=1.7
torchmetrics # torchmetrics 0.7+ required for JaccardIndex
torchmetrics>=0.7
# torchvision 0.10+ required for torchvision.utils.draw_segmentation_masks # torchvision 0.10+ required for torchvision.utils.draw_segmentation_masks
torchvision>=0.10 torchvision>=0.10
python_requires = >= 3.6 python_requires = >= 3.6

Просмотреть файл

@ -14,7 +14,7 @@ from segmentation_models_pytorch.losses import FocalLoss, JaccardLoss
from torch import Tensor from torch import Tensor
from torch.nn.modules import Conv2d, Linear from torch.nn.modules import Conv2d, Linear
from torch.optim.lr_scheduler import ReduceLROnPlateau from torch.optim.lr_scheduler import ReduceLROnPlateau
from torchmetrics import Accuracy, FBeta, IoU, MetricCollection from torchmetrics import Accuracy, FBeta, JaccardIndex, MetricCollection
from ..datasets.utils import unbind_samples from ..datasets.utils import unbind_samples
from . import utils from . import utils
@ -105,7 +105,7 @@ class ClassificationTask(pl.LightningModule):
"AverageAccuracy": Accuracy( "AverageAccuracy": Accuracy(
num_classes=self.hparams["num_classes"], average="macro" num_classes=self.hparams["num_classes"], average="macro"
), ),
"IoU": IoU(num_classes=self.hparams["num_classes"]), "JaccardIndex": JaccardIndex(num_classes=self.hparams["num_classes"]),
"F1Score": FBeta( "F1Score": FBeta(
num_classes=self.hparams["num_classes"], beta=1.0, average="micro" num_classes=self.hparams["num_classes"], beta=1.0, average="micro"
), ),

Просмотреть файл

@ -12,7 +12,7 @@ from pytorch_lightning.core.lightning import LightningModule
from torch import Tensor from torch import Tensor
from torch.optim.lr_scheduler import ReduceLROnPlateau from torch.optim.lr_scheduler import ReduceLROnPlateau
from torch.utils.data import DataLoader from torch.utils.data import DataLoader
from torchmetrics import Accuracy, IoU, MetricCollection from torchmetrics import Accuracy, JaccardIndex, MetricCollection
from ..datasets.utils import unbind_samples from ..datasets.utils import unbind_samples
from ..models import FCN from ..models import FCN
@ -96,7 +96,7 @@ class SemanticSegmentationTask(LightningModule):
num_classes=self.hparams["num_classes"], num_classes=self.hparams["num_classes"],
ignore_index=self.ignore_zeros, ignore_index=self.ignore_zeros,
), ),
IoU( JaccardIndex(
num_classes=self.hparams["num_classes"], num_classes=self.hparams["num_classes"],
ignore_index=self.ignore_zeros, ignore_index=self.ignore_zeros,
), ),
@ -120,7 +120,7 @@ class SemanticSegmentationTask(LightningModule):
def training_step( # type: ignore[override] def training_step( # type: ignore[override]
self, batch: Dict[str, Any], batch_idx: int self, batch: Dict[str, Any], batch_idx: int
) -> Tensor: ) -> Tensor:
"""Training step - reports average accuracy and average IoU. """Training step - reports average accuracy and average JaccardIndex.
Args: Args:
batch: Current batch batch: Current batch
@ -155,7 +155,7 @@ class SemanticSegmentationTask(LightningModule):
def validation_step( # type: ignore[override] def validation_step( # type: ignore[override]
self, batch: Dict[str, Any], batch_idx: int self, batch: Dict[str, Any], batch_idx: int
) -> None: ) -> None:
"""Validation step - reports average accuracy and average IoU. """Validation step - reports average accuracy and average JaccardIndex.
Logs the first 10 validation samples to tensorboard as images with 3 subplots Logs the first 10 validation samples to tensorboard as images with 3 subplots
showing the image, mask, and predictions. showing the image, mask, and predictions.