Mirror of https://github.com/microsoft/torchgeo.git
torchmetrics: IoU -> JaccardIndex (#361)
This commit is contained in:
Parent: c9520aa3f1
Commit: 37d0e8f028
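torchmetrics 0.7 renamed the IoU metric to JaccardIndex, and this commit updates the imports, metric names, version pins, and docstrings accordingly. A minimal sketch of the rename in user code, assuming torchmetrics >= 0.7 is installed:

    # Old name (torchmetrics < 0.7): from torchmetrics import IoU
    # New name (torchmetrics >= 0.7), used throughout this commit:
    from torchmetrics import JaccardIndex

    metric = JaccardIndex(num_classes=2)  # same metric, new name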
@@ -45,5 +45,5 @@ dependencies:
     - setuptools>=42
     - sphinx>=4
     - timm>=0.2.1
-    - torchmetrics
+    - torchmetrics>=0.7
     - zipfile-deflate64>=0.2
evaluate.py (12 changed lines)
@@ -12,7 +12,7 @@ from typing import Any, Dict, Union

 import pytorch_lightning as pl
 import torch
-from torchmetrics import Accuracy, IoU, Metric, MetricCollection
+from torchmetrics import Accuracy, JaccardIndex, Metric, MetricCollection

 from torchgeo.trainers import ClassificationTask, SemanticSegmentationTask
 from train import TASK_TO_MODULES_MAPPING
@@ -185,7 +185,7 @@ def main(args: argparse.Namespace) -> None:
     if args.task == "etci2021":  # Custom metric setup for testing ETCI2021

         metrics = MetricCollection(
-            [Accuracy(num_classes=2), IoU(num_classes=2, reduction="none")]
+            [Accuracy(num_classes=2), JaccardIndex(num_classes=2, reduction="none")]
         ).to(device)

         val_results = run_eval_loop(model, dm.val_dataloader(), device, metrics)
@@ -194,13 +194,13 @@ def main(args: argparse.Namespace) -> None:
         val_row.update(
             {
                 "overall_accuracy": val_results["Accuracy"].item(),
-                "iou": val_results["IoU"][1].item(),
+                "jaccard_index": val_results["JaccardIndex"][1].item(),
             }
         )
         test_row.update(
             {
                 "overall_accuracy": test_results["Accuracy"].item(),
-                "iou": test_results["IoU"][1].item(),
+                "jaccard_index": test_results["JaccardIndex"][1].item(),
             }
         )
     else:  # Test with PyTorch Lightning as usual
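For context, the ETCI2021 branch above builds the metric with reduction="none", so the collection returns one Jaccard score per class, and index [1] selects class 1 (presumably the flood class in ETCI2021). A minimal sketch of that behaviour, assuming torchmetrics >= 0.7 and integer class labels:

    import torch
    from torchmetrics import JaccardIndex

    metric = JaccardIndex(num_classes=2, reduction="none")
    preds = torch.tensor([0, 1, 1, 0])
    target = torch.tensor([0, 1, 0, 0])
    per_class = metric(preds, target)  # tensor of shape (2,), one score per class
    print(per_class[1].item())         # score for class 1 only, as indexed above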
@@ -230,13 +230,13 @@ def main(args: argparse.Namespace) -> None:
         val_row.update(
             {
                 "overall_accuracy": val_results["val_Accuracy"].item(),
-                "iou": val_results["val_IoU"].item(),
+                "jaccard_index": val_results["val_JaccardIndex"].item(),
             }
         )
         test_row.update(
             {
                 "overall_accuracy": test_results["test_Accuracy"].item(),
-                "iou": test_results["test_IoU"].item(),
+                "jaccard_index": test_results["test_JaccardIndex"].item(),
             }
         )

@@ -59,7 +59,8 @@ install_requires =
     timm>=0.2.1
     # torch 1.7+ required for typing
     torch>=1.7
-    torchmetrics
+    # torchmetrics 0.7+ required for JaccardIndex
+    torchmetrics>=0.7
     # torchvision 0.10+ required for torchvision.utils.draw_segmentation_masks
     torchvision>=0.10
 python_requires = >= 3.6
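The new comment and pin record that JaccardIndex only exists in torchmetrics 0.7 and later. This commit simply requires torchmetrics>=0.7; a hypothetical compatibility shim (not part of this change) for code that must still run against older torchmetrics could alias the old name instead:

    try:
        from torchmetrics import JaccardIndex          # torchmetrics >= 0.7
    except ImportError:
        from torchmetrics import IoU as JaccardIndex   # torchmetrics < 0.7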
@@ -14,7 +14,7 @@ from segmentation_models_pytorch.losses import FocalLoss, JaccardLoss
 from torch import Tensor
 from torch.nn.modules import Conv2d, Linear
 from torch.optim.lr_scheduler import ReduceLROnPlateau
-from torchmetrics import Accuracy, FBeta, IoU, MetricCollection
+from torchmetrics import Accuracy, FBeta, JaccardIndex, MetricCollection

 from ..datasets.utils import unbind_samples
 from . import utils
@@ -105,7 +105,7 @@ class ClassificationTask(pl.LightningModule):
                 "AverageAccuracy": Accuracy(
                     num_classes=self.hparams["num_classes"], average="macro"
                 ),
-                "IoU": IoU(num_classes=self.hparams["num_classes"]),
+                "JaccardIndex": JaccardIndex(num_classes=self.hparams["num_classes"]),
                 "F1Score": FBeta(
                     num_classes=self.hparams["num_classes"], beta=1.0, average="micro"
                 ),
@@ -12,7 +12,7 @@ from pytorch_lightning.core.lightning import LightningModule
 from torch import Tensor
 from torch.optim.lr_scheduler import ReduceLROnPlateau
 from torch.utils.data import DataLoader
-from torchmetrics import Accuracy, IoU, MetricCollection
+from torchmetrics import Accuracy, JaccardIndex, MetricCollection

 from ..datasets.utils import unbind_samples
 from ..models import FCN
@@ -96,7 +96,7 @@ class SemanticSegmentationTask(LightningModule):
                     num_classes=self.hparams["num_classes"],
                     ignore_index=self.ignore_zeros,
                 ),
-                IoU(
+                JaccardIndex(
                     num_classes=self.hparams["num_classes"],
                     ignore_index=self.ignore_zeros,
                 ),
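In the hunk above, the renamed metric keeps the ignore_index argument; the actual value of self.ignore_zeros is not shown in this diff. A minimal standalone sketch, assuming ignore_index=0 to skip a background class:

    import torch
    from torchmetrics import JaccardIndex  # torchmetrics >= 0.7

    # Assumed setup: 3 classes, class 0 treated as background and ignored.
    metric = JaccardIndex(num_classes=3, ignore_index=0)
    preds = torch.tensor([0, 1, 2, 2])
    target = torch.tensor([0, 1, 1, 2])
    print(metric(preds, target))  # mean Jaccard over the non-ignored classes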
@@ -120,7 +120,7 @@ class SemanticSegmentationTask(LightningModule):
     def training_step(  # type: ignore[override]
         self, batch: Dict[str, Any], batch_idx: int
     ) -> Tensor:
-        """Training step - reports average accuracy and average IoU.
+        """Training step - reports average accuracy and average JaccardIndex.

         Args:
             batch: Current batch
@@ -155,7 +155,7 @@ class SemanticSegmentationTask(LightningModule):
     def validation_step(  # type: ignore[override]
         self, batch: Dict[str, Any], batch_idx: int
     ) -> None:
-        """Validation step - reports average accuracy and average IoU.
+        """Validation step - reports average accuracy and average JaccardIndex.

         Logs the first 10 validation samples to tensorboard as images with 3 subplots
         showing the image, mask, and predictions.