torchmetrics: IoU -> JaccardIndex (#361)

This commit is contained in:
Adam J. Stewart 2022-01-18 14:34:15 -06:00
Родитель c9520aa3f1
Коммит 37d0e8f028
Не найден ключ, соответствующий данной подписи
Идентификатор ключа GPG: C66C0675661156FC
5 изменённых файлов: 15 добавлений и 14 удалений

Просмотреть файл

@ -45,5 +45,5 @@ dependencies:
- setuptools>=42
- sphinx>=4
- timm>=0.2.1
- torchmetrics
- torchmetrics>=0.7
- zipfile-deflate64>=0.2

Просмотреть файл

@ -12,7 +12,7 @@ from typing import Any, Dict, Union
import pytorch_lightning as pl
import torch
from torchmetrics import Accuracy, IoU, Metric, MetricCollection
from torchmetrics import Accuracy, JaccardIndex, Metric, MetricCollection
from torchgeo.trainers import ClassificationTask, SemanticSegmentationTask
from train import TASK_TO_MODULES_MAPPING
@ -185,7 +185,7 @@ def main(args: argparse.Namespace) -> None:
if args.task == "etci2021": # Custom metric setup for testing ETCI2021
metrics = MetricCollection(
[Accuracy(num_classes=2), IoU(num_classes=2, reduction="none")]
[Accuracy(num_classes=2), JaccardIndex(num_classes=2, reduction="none")]
).to(device)
val_results = run_eval_loop(model, dm.val_dataloader(), device, metrics)
@ -194,13 +194,13 @@ def main(args: argparse.Namespace) -> None:
val_row.update(
{
"overall_accuracy": val_results["Accuracy"].item(),
"iou": val_results["IoU"][1].item(),
"jaccard_index": val_results["JaccardIndex"][1].item(),
}
)
test_row.update(
{
"overall_accuracy": test_results["Accuracy"].item(),
"iou": test_results["IoU"][1].item(),
"jaccard_index": test_results["JaccardIndex"][1].item(),
}
)
else: # Test with PyTorch Lightning as usual
@ -230,13 +230,13 @@ def main(args: argparse.Namespace) -> None:
val_row.update(
{
"overall_accuracy": val_results["val_Accuracy"].item(),
"iou": val_results["val_IoU"].item(),
"jaccard_index": val_results["val_JaccardIndex"].item(),
}
)
test_row.update(
{
"overall_accuracy": test_results["test_Accuracy"].item(),
"iou": test_results["test_IoU"].item(),
"jaccard_index": test_results["test_JaccardIndex"].item(),
}
)

Просмотреть файл

@ -59,7 +59,8 @@ install_requires =
timm>=0.2.1
# torch 1.7+ required for typing
torch>=1.7
torchmetrics
# torchmetrics 0.7+ required for JaccardIndex
torchmetrics>=0.7
# torchvision 0.10+ required for torchvision.utils.draw_segmentation_masks
torchvision>=0.10
python_requires = >= 3.6

Просмотреть файл

@ -14,7 +14,7 @@ from segmentation_models_pytorch.losses import FocalLoss, JaccardLoss
from torch import Tensor
from torch.nn.modules import Conv2d, Linear
from torch.optim.lr_scheduler import ReduceLROnPlateau
from torchmetrics import Accuracy, FBeta, IoU, MetricCollection
from torchmetrics import Accuracy, FBeta, JaccardIndex, MetricCollection
from ..datasets.utils import unbind_samples
from . import utils
@ -105,7 +105,7 @@ class ClassificationTask(pl.LightningModule):
"AverageAccuracy": Accuracy(
num_classes=self.hparams["num_classes"], average="macro"
),
"IoU": IoU(num_classes=self.hparams["num_classes"]),
"JaccardIndex": JaccardIndex(num_classes=self.hparams["num_classes"]),
"F1Score": FBeta(
num_classes=self.hparams["num_classes"], beta=1.0, average="micro"
),

Просмотреть файл

@ -12,7 +12,7 @@ from pytorch_lightning.core.lightning import LightningModule
from torch import Tensor
from torch.optim.lr_scheduler import ReduceLROnPlateau
from torch.utils.data import DataLoader
from torchmetrics import Accuracy, IoU, MetricCollection
from torchmetrics import Accuracy, JaccardIndex, MetricCollection
from ..datasets.utils import unbind_samples
from ..models import FCN
@ -96,7 +96,7 @@ class SemanticSegmentationTask(LightningModule):
num_classes=self.hparams["num_classes"],
ignore_index=self.ignore_zeros,
),
IoU(
JaccardIndex(
num_classes=self.hparams["num_classes"],
ignore_index=self.ignore_zeros,
),
@ -120,7 +120,7 @@ class SemanticSegmentationTask(LightningModule):
def training_step( # type: ignore[override]
self, batch: Dict[str, Any], batch_idx: int
) -> Tensor:
"""Training step - reports average accuracy and average IoU.
"""Training step - reports average accuracy and average JaccardIndex.
Args:
batch: Current batch
@ -155,7 +155,7 @@ class SemanticSegmentationTask(LightningModule):
def validation_step( # type: ignore[override]
self, batch: Dict[str, Any], batch_idx: int
) -> None:
"""Validation step - reports average accuracy and average IoU.
"""Validation step - reports average accuracy and average JaccardIndex.
Logs the first 10 validation samples to tensorboard as images with 3 subplots
showing the image, mask, and predictions.