From 37d0e8f02828f0e3b8d9c6401288e3da19383b8f Mon Sep 17 00:00:00 2001 From: "Adam J. Stewart" Date: Tue, 18 Jan 2022 14:34:15 -0600 Subject: [PATCH] torchmetrics: IoU -> JaccardIndex (#361) --- environment.yml | 2 +- evaluate.py | 12 ++++++------ setup.cfg | 3 ++- torchgeo/trainers/classification.py | 4 ++-- torchgeo/trainers/segmentation.py | 8 ++++---- 5 files changed, 15 insertions(+), 14 deletions(-) diff --git a/environment.yml b/environment.yml index 5be325da9..d60866c8c 100644 --- a/environment.yml +++ b/environment.yml @@ -45,5 +45,5 @@ dependencies: - setuptools>=42 - sphinx>=4 - timm>=0.2.1 - - torchmetrics + - torchmetrics>=0.7 - zipfile-deflate64>=0.2 diff --git a/evaluate.py b/evaluate.py index b8a7fa270..d173c7c00 100755 --- a/evaluate.py +++ b/evaluate.py @@ -12,7 +12,7 @@ from typing import Any, Dict, Union import pytorch_lightning as pl import torch -from torchmetrics import Accuracy, IoU, Metric, MetricCollection +from torchmetrics import Accuracy, JaccardIndex, Metric, MetricCollection from torchgeo.trainers import ClassificationTask, SemanticSegmentationTask from train import TASK_TO_MODULES_MAPPING @@ -185,7 +185,7 @@ def main(args: argparse.Namespace) -> None: if args.task == "etci2021": # Custom metric setup for testing ETCI2021 metrics = MetricCollection( - [Accuracy(num_classes=2), IoU(num_classes=2, reduction="none")] + [Accuracy(num_classes=2), JaccardIndex(num_classes=2, reduction="none")] ).to(device) val_results = run_eval_loop(model, dm.val_dataloader(), device, metrics) @@ -194,13 +194,13 @@ def main(args: argparse.Namespace) -> None: val_row.update( { "overall_accuracy": val_results["Accuracy"].item(), - "iou": val_results["IoU"][1].item(), + "jaccard_index": val_results["JaccardIndex"][1].item(), } ) test_row.update( { "overall_accuracy": test_results["Accuracy"].item(), - "iou": test_results["IoU"][1].item(), + "jaccard_index": test_results["JaccardIndex"][1].item(), } ) else: # Test with PyTorch Lightning as usual @@ -230,13 +230,13 @@ def main(args: argparse.Namespace) -> None: val_row.update( { "overall_accuracy": val_results["val_Accuracy"].item(), - "iou": val_results["val_IoU"].item(), + "jaccard_index": val_results["val_JaccardIndex"].item(), } ) test_row.update( { "overall_accuracy": test_results["test_Accuracy"].item(), - "iou": test_results["test_IoU"].item(), + "jaccard_index": test_results["test_JaccardIndex"].item(), } ) diff --git a/setup.cfg b/setup.cfg index b62bd4beb..fc7001efb 100644 --- a/setup.cfg +++ b/setup.cfg @@ -59,7 +59,8 @@ install_requires = timm>=0.2.1 # torch 1.7+ required for typing torch>=1.7 - torchmetrics + # torchmetrics 0.7+ required for JaccardIndex + torchmetrics>=0.7 # torchvision 0.10+ required for torchvision.utils.draw_segmentation_masks torchvision>=0.10 python_requires = >= 3.6 diff --git a/torchgeo/trainers/classification.py b/torchgeo/trainers/classification.py index 5b69f10b1..002eedc18 100644 --- a/torchgeo/trainers/classification.py +++ b/torchgeo/trainers/classification.py @@ -14,7 +14,7 @@ from segmentation_models_pytorch.losses import FocalLoss, JaccardLoss from torch import Tensor from torch.nn.modules import Conv2d, Linear from torch.optim.lr_scheduler import ReduceLROnPlateau -from torchmetrics import Accuracy, FBeta, IoU, MetricCollection +from torchmetrics import Accuracy, FBeta, JaccardIndex, MetricCollection from ..datasets.utils import unbind_samples from . import utils @@ -105,7 +105,7 @@ class ClassificationTask(pl.LightningModule): "AverageAccuracy": Accuracy( num_classes=self.hparams["num_classes"], average="macro" ), - "IoU": IoU(num_classes=self.hparams["num_classes"]), + "JaccardIndex": JaccardIndex(num_classes=self.hparams["num_classes"]), "F1Score": FBeta( num_classes=self.hparams["num_classes"], beta=1.0, average="micro" ), diff --git a/torchgeo/trainers/segmentation.py b/torchgeo/trainers/segmentation.py index 1da0cf05d..afc5937a6 100644 --- a/torchgeo/trainers/segmentation.py +++ b/torchgeo/trainers/segmentation.py @@ -12,7 +12,7 @@ from pytorch_lightning.core.lightning import LightningModule from torch import Tensor from torch.optim.lr_scheduler import ReduceLROnPlateau from torch.utils.data import DataLoader -from torchmetrics import Accuracy, IoU, MetricCollection +from torchmetrics import Accuracy, JaccardIndex, MetricCollection from ..datasets.utils import unbind_samples from ..models import FCN @@ -96,7 +96,7 @@ class SemanticSegmentationTask(LightningModule): num_classes=self.hparams["num_classes"], ignore_index=self.ignore_zeros, ), - IoU( + JaccardIndex( num_classes=self.hparams["num_classes"], ignore_index=self.ignore_zeros, ), @@ -120,7 +120,7 @@ class SemanticSegmentationTask(LightningModule): def training_step( # type: ignore[override] self, batch: Dict[str, Any], batch_idx: int ) -> Tensor: - """Training step - reports average accuracy and average IoU. + """Training step - reports average accuracy and average JaccardIndex. Args: batch: Current batch @@ -155,7 +155,7 @@ class SemanticSegmentationTask(LightningModule): def validation_step( # type: ignore[override] self, batch: Dict[str, Any], batch_idx: int ) -> None: - """Validation step - reports average accuracy and average IoU. + """Validation step - reports average accuracy and average JaccardIndex. Logs the first 10 validation samples to tensorboard as images with 3 subplots showing the image, mask, and predictions.