Reformat iou [func] and add IoU class (#4704)
* added Iou
* Create iou.py
* Update iou.py
* Update iou.py
* Update CHANGELOG.md
* Update metrics.rst
* Update iou.py
* Update iou.py
* Update __init__.py
* Update iou.py
* Update iou.py
* Update classification.py
* Update classification.py
* Update classification.py
* Update __init__.py
* Update __init__.py
* Update iou.py
* Update classification.py
* Update metrics.rst
* Update CHANGELOG.md
* Update CHANGELOG.md
* add iou
* add test
* add test
* removed iou
* add iou
* add iou test
* add float
* reformat test_iou
* removed test_iou
* updated format
* updated format
* Update CHANGELOG.md
* updated format
* Update metrics.rst
* Apply suggestions from code review

  merge suggestions

  Co-authored-by: Nicki Skafte <skaftenicki@gmail.com>
* added equations
* reformat init
* change format
* change format
* deprecate iou and test for this
* fix changelog
* delete iou test in test_classification
* format change
* format change
* format
* format
* format
* delete white space
* delete white space
* fix tests
* Apply suggestions from code review

  Co-authored-by: Jirka Borovec <Borda@users.noreply.github.com>
* Apply suggestions from code review

  Co-authored-by: Jirka Borovec <Borda@users.noreply.github.com>
* better deprecation
* fix docs
* Apply suggestions from code review
* fix todo

Co-authored-by: Nicki Skafte <skaftenicki@gmail.com>
Co-authored-by: Jirka Borovec <Borda@users.noreply.github.com>
(cherry picked from commit 238040c9e5f2a9f03cecb171791fd8c4b88ca8b7)
Parent: bcb82d2a0c
Commit: 92c2e1923b
pytorch_lightning/metrics/__init__.py
@@ -16,6 +16,7 @@ from pytorch_lightning.metrics.metric import Metric, MetricCollection  # noqa: F
 from pytorch_lightning.metrics.classification import (  # noqa: F401
     Accuracy,
     HammingDistance,
+    IoU,
     Precision,
     Recall,
     ConfusionMatrix,
pytorch_lightning/metrics/classification/__init__.py
@@ -20,3 +20,4 @@ from pytorch_lightning.metrics.classification.precision_recall import Precision,
 from pytorch_lightning.metrics.classification.precision_recall_curve import PrecisionRecallCurve  # noqa: F401
 from pytorch_lightning.metrics.classification.roc import ROC  # noqa: F401
 from pytorch_lightning.metrics.classification.stat_scores import StatScores  # noqa: F401
+from pytorch_lightning.metrics.classification.iou import IoU  # noqa: F401
pytorch_lightning/metrics/classification/iou.py (new file)
@@ -0,0 +1,106 @@
+# Copyright The PyTorch Lightning team.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+from typing import Any, Optional
+
+import torch
+from pytorch_lightning.metrics.classification.confusion_matrix import ConfusionMatrix
+from pytorch_lightning.metrics.functional.iou import _iou_from_confmat
+
+
+class IoU(ConfusionMatrix):
+    r"""
+    Computes `Intersection over union, or Jaccard index calculation <https://en.wikipedia.org/wiki/Jaccard_index>`_:
+
+    .. math:: J(A,B) = \frac{|A\cap B|}{|A\cup B|}
+
+    Where :math:`A` and :math:`B` are both tensors of the same size, containing integer class values.
+    They may be subject to conversion from input data (see description below). Note that it is different from box IoU.
+
+    Works with binary, multiclass and multi-label data.
+    Accepts logits from a model output or integer class values in prediction.
+    Works with multi-dimensional preds and target.
+
+    Forward accepts
+
+    - ``preds`` (float or long tensor): ``(N, ...)`` or ``(N, C, ...)`` where C is the number of classes
+    - ``target`` (long tensor): ``(N, ...)``
+
+    If preds and target are the same shape and preds is a float tensor, we use the ``self.threshold`` argument.
+    This is the case for binary and multi-label logits.
+
+    If preds has an extra dimension as in the case of multi-class scores we perform an argmax on ``dim=1``.
+
+    Args:
+        num_classes: Number of classes in the dataset.
+        ignore_index: optional int specifying a target class to ignore. If given, this class index does not contribute
+            to the returned score, regardless of reduction method. Has no effect if given an int that is not in the
+            range [0, num_classes-1]. By default, no index is ignored, and all classes are used.
+        absent_score: score to use for an individual class, if no instances of the class index were present in
+            `pred` AND no instances of the class index were present in `target`. For example, if we have 3 classes,
+            [0, 0] for `pred`, and [0, 2] for `target`, then class 1 would be assigned the `absent_score`.
+        threshold:
+            Threshold value for binary or multi-label logits.
+        reduction: a method to reduce metric score over labels.
+
+            - ``'elementwise_mean'``: takes the mean (default)
+            - ``'sum'``: takes the sum
+            - ``'none'``: no reduction will be applied
+
+        compute_on_step:
+            Forward only calls ``update()`` and returns None if this is set to False.
+        dist_sync_on_step:
+            Synchronize metric state across processes at each ``forward()``
+            before returning the value at the step.
+        process_group:
+            Specify the process group on which synchronization is called. default: None (which selects the entire world)
+
+    Example:
+        >>> from pytorch_lightning.metrics import IoU
+        >>> target = torch.randint(0, 2, (10, 25, 25))
+        >>> pred = torch.tensor(target)
+        >>> pred[2:5, 7:13, 9:15] = 1 - pred[2:5, 7:13, 9:15]
+        >>> iou = IoU(num_classes=2)
+        >>> iou(pred, target)
+        tensor(0.9660)
+
+    """
+
+    def __init__(
+        self,
+        num_classes: int,
+        ignore_index: Optional[int] = None,
+        absent_score: float = 0.0,
+        threshold: float = 0.5,
+        reduction: str = 'elementwise_mean',
+        compute_on_step: bool = True,
+        dist_sync_on_step: bool = False,
+        process_group: Optional[Any] = None,
+    ):
+        super().__init__(
+            num_classes=num_classes,
+            normalize=None,
+            threshold=threshold,
+            compute_on_step=compute_on_step,
+            dist_sync_on_step=dist_sync_on_step,
+            process_group=process_group,
+        )
+        self.reduction = reduction
+        self.ignore_index = ignore_index
+        self.absent_score = absent_score
+
+    def compute(self) -> torch.Tensor:
+        """
+        Computes intersection over union (IoU)
+        """
+        return _iou_from_confmat(self.confmat, self.num_classes, self.ignore_index, self.absent_score, self.reduction)
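Because ``IoU`` subclasses ``ConfusionMatrix``, ``update()`` accumulates a single confusion matrix across batches and ``compute()`` derives the score from it at the end, so the epoch-level value is exact rather than an average of per-batch scores. A minimal sketch of that stateful use (the loop and tensor shapes here are illustrative, not part of this diff):

import torch
from pytorch_lightning.metrics import IoU

iou_metric = IoU(num_classes=3)
for _ in range(4):  # e.g. iterate over validation batches
    preds = torch.randint(0, 3, (8, 32, 32))   # integer class predictions
    target = torch.randint(0, 3, (8, 32, 32))
    iou_metric.update(preds, target)           # accumulates the inherited confusion-matrix state
epoch_iou = iou_metric.compute()               # mean IoU over the 3 classes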
pytorch_lightning/metrics/functional/__init__.py
@@ -17,7 +17,6 @@ from pytorch_lightning.metrics.functional.classification import ( # noqa: F401
     auroc,
     dice_score,
     get_num_classes,
-    iou,
     multiclass_auroc,
     precision,
     precision_recall,
@@ -32,6 +31,8 @@ from pytorch_lightning.metrics.functional.confusion_matrix import confusion_matr
 from pytorch_lightning.metrics.functional.explained_variance import explained_variance  # noqa: F401
 from pytorch_lightning.metrics.functional.f_beta import fbeta, f1  # noqa: F401
 from pytorch_lightning.metrics.functional.hamming_distance import hamming_distance  # noqa: F401
+from pytorch_lightning.metrics.functional.image_gradients import image_gradients  # noqa: F401
+from pytorch_lightning.metrics.functional.iou import iou  # noqa: F401
 from pytorch_lightning.metrics.functional.mean_absolute_error import mean_absolute_error  # noqa: F401
 from pytorch_lightning.metrics.functional.mean_squared_error import mean_squared_error  # noqa: F401
 from pytorch_lightning.metrics.functional.mean_squared_log_error import mean_squared_log_error  # noqa: F401
@@ -43,4 +44,3 @@ from pytorch_lightning.metrics.functional.roc import roc  # noqa: F401
 from pytorch_lightning.metrics.functional.self_supervised import embedding_similarity  # noqa: F401
 from pytorch_lightning.metrics.functional.ssim import ssim  # noqa: F401
 from pytorch_lightning.metrics.functional.stat_scores import stat_scores  # noqa: F401
-from pytorch_lightning.metrics.functional.image_gradients import image_gradients  # noqa: F401
pytorch_lightning/metrics/functional/classification.py
@@ -11,13 +11,13 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
+from distutils.version import LooseVersion
 from functools import wraps
 from typing import Callable, Optional, Sequence, Tuple
 
 import torch
-from distutils.version import LooseVersion
 
 from pytorch_lightning.metrics.functional.average_precision import average_precision as __ap
+from pytorch_lightning.metrics.functional.iou import iou as __iou
 from pytorch_lightning.metrics.functional.precision_recall_curve import (
     _binary_clf_curve,
     precision_recall_curve as __prc
@@ -84,7 +84,7 @@ def get_num_classes(
         " `from pytorch_lightning.metrics.utils import get_num_classes`."
         " It will be removed in v1.3.0", DeprecationWarning
     )
-    return __gnc(pred,target, num_classes)
+    return __gnc(pred, target, num_classes)
 
 
 def stat_scores(
@@ -162,8 +162,8 @@ def stat_scores_multiple_classes(
         raise ValueError("reduction type %s not supported" % reduction)
 
     if reduction == 'none':
-        pred = pred.view((-1, )).long()
-        target = target.view((-1, )).long()
+        pred = pred.view((-1,)).long()
+        target = target.view((-1,)).long()
 
         tps = torch.zeros((num_classes + 1,), device=pred.device)
         fps = torch.zeros((num_classes + 1,), device=pred.device)
@@ -687,6 +687,7 @@ def dice_score(
     return reduce(scores, reduction=reduction)
 
 
+# todo: remove in 1.4
 def iou(
     pred: torch.Tensor,
     target: torch.Tensor,
@@ -698,6 +699,10 @@ def iou(
     """
     Intersection over union, or Jaccard index calculation.
 
+    .. warning:: Deprecated in favor of
+        :func:`~pytorch_lightning.metrics.functional.iou.iou`. Will be removed in
+        v1.4.0.
+
     Args:
         pred: Tensor containing integer predictions, with shape [N, d1, d2, ...]
         target: Tensor containing integer targets, with shape [N, d1, d2, ...]
@@ -729,48 +734,20 @@ def iou(
         tensor(0.9660)
 
     """
-    if pred.size() != target.size():
-        raise ValueError(f"'pred' shape ({pred.size()}) must equal 'target' shape ({target.size()})")
-
-    if not torch.allclose(pred.float(), pred.int().float()):
-        raise ValueError("'pred' must contain integer targets.")
-
-    num_classes = get_num_classes(pred=pred, target=target, num_classes=num_classes)
-
-    tps, fps, tns, fns, sups = stat_scores_multiple_classes(pred, target, num_classes)
-
-    scores = torch.zeros(num_classes, device=pred.device, dtype=torch.float32)
-
-    for class_idx in range(num_classes):
-        if class_idx == ignore_index:
-            continue
-
-        tp = tps[class_idx]
-        fp = fps[class_idx]
-        fn = fns[class_idx]
-        sup = sups[class_idx]
-
-        # If this class is absent in the target (no support) AND absent in the pred (no true or false
-        # positives), then use the absent_score for this class.
-        if sup + tp + fp == 0:
-            scores[class_idx] = absent_score
-            continue
-
-        denom = tp + fp + fn
-        # Note that we do not need to worry about division-by-zero here since we know (sup + tp + fp != 0) from above,
-        # which means ((tp+fn) + tp + fp != 0), which means (2tp + fp + fn != 0). Since all vars are non-negative, we
-        # can conclude (tp + fp + fn > 0), meaning the denominator is non-zero for each class.
-        score = tp.to(torch.float) / denom
-        scores[class_idx] = score
-
-    # Remove the ignored class index from the scores.
-    if ignore_index is not None and ignore_index >= 0 and ignore_index < num_classes:
-        scores = torch.cat([
-            scores[:ignore_index],
-            scores[ignore_index + 1:],
-        ])
-
-    return reduce(scores, reduction=reduction)
+    rank_zero_warn(
+        "This `iou` was deprecated in v1.2.0 in favor of"
+        " `from pytorch_lightning.metrics.functional.iou import iou`."
+        " It will be removed in v1.4.0", DeprecationWarning
+    )
+    return __iou(
+        pred=pred,
+        target=target,
+        ignore_index=ignore_index,
+        absent_score=absent_score,
+        threshold=0.5,
+        num_classes=num_classes,
+        reduction=reduction
+    )
 
 
 # todo: remove in 1.3
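The body above replaces the old hand-rolled implementation with a thin shim: warn once, then forward to the new functional location with identical arguments (``threshold`` pinned to 0.5, matching the old integer-input behaviour). A sketch of what callers of the deprecated path now observe, assuming a single (rank-zero) process and no warning filters suppressing it:

import warnings

import torch
from pytorch_lightning.metrics.functional.classification import iou  # deprecated location

pred = torch.tensor([0, 1, 2])
target = torch.tensor([0, 1, 2])

with warnings.catch_warnings(record=True) as caught:
    warnings.simplefilter("always")
    score = iou(pred, target)  # delegates to pytorch_lightning.metrics.functional.iou.iou

assert any(issubclass(w.category, DeprecationWarning) for w in caught)
print(score)  # tensor(1.) - identical pred/target gives a perfect score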
pytorch_lightning/metrics/functional/iou.py (new file)
@@ -0,0 +1,110 @@
+# Copyright The PyTorch Lightning team.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+from typing import Optional
+
+import torch
+from pytorch_lightning.metrics.functional.confusion_matrix import _confusion_matrix_update
+from pytorch_lightning.metrics.functional.reduction import reduce
+from pytorch_lightning.metrics.utils import get_num_classes
+
+
+def _iou_from_confmat(
+    confmat: torch.Tensor,
+    num_classes: int,
+    ignore_index: Optional[int] = None,
+    absent_score: float = 0.0,
+    reduction: str = 'elementwise_mean',
+):
+    intersection = torch.diag(confmat)
+    union = confmat.sum(0) + confmat.sum(1) - intersection
+
+    # If this class is absent in both target AND pred (union == 0), then use the absent_score for this class.
+    scores = intersection.float() / union.float()
+    scores[union == 0] = absent_score
+
+    # Remove the ignored class index from the scores.
+    if ignore_index is not None and ignore_index >= 0 and ignore_index < num_classes:
+        scores = torch.cat([
+            scores[:ignore_index],
+            scores[ignore_index + 1:],
+        ])
+    return reduce(scores, reduction=reduction)
+
+
+def iou(
+    pred: torch.Tensor,
+    target: torch.Tensor,
+    ignore_index: Optional[int] = None,
+    absent_score: float = 0.0,
+    threshold: float = 0.5,
+    num_classes: Optional[int] = None,
+    reduction: str = 'elementwise_mean',
+) -> torch.Tensor:
+    r"""
+    Computes `Intersection over union, or Jaccard index calculation <https://en.wikipedia.org/wiki/Jaccard_index>`_:
+
+    .. math:: J(A,B) = \frac{|A\cap B|}{|A\cup B|}
+
+    Where :math:`A` and :math:`B` are both tensors of the same size,
+    containing integer class values. They may be subject to conversion from
+    input data (see description below).
+
+    Note that it is different from box IoU.
+
+    If pred and target are the same shape and pred is a float tensor,
+    we use the ``threshold`` argument. This is the case for binary and multi-label logits.
+
+    If pred has an extra dimension as in the case of multi-class scores we
+    perform an argmax on ``dim=1``.
+
+    Args:
+        pred: Tensor containing integer predictions, with shape [N, d1, d2, ...]
+        target: Tensor containing integer targets, with shape [N, d1, d2, ...]
+        ignore_index: optional int specifying a target class to ignore. If given,
+            this class index does not contribute to the returned score, regardless
+            of reduction method. Has no effect if given an int that is not in the
+            range [0, num_classes-1], where num_classes is either given or derived
+            from pred and target. By default, no index is ignored, and all classes are used.
+        absent_score: score to use for an individual class, if no instances of
+            the class index were present in `pred` AND no instances of the class
+            index were present in `target`. For example, if we have 3 classes,
+            [0, 0] for `pred`, and [0, 2] for `target`, then class 1 would be
+            assigned the `absent_score`.
+        threshold:
+            Threshold value for binary or multi-label logits. default: 0.5
+        num_classes:
+            Optionally specify the number of classes
+        reduction: a method to reduce metric score over labels.
+
+            - ``'elementwise_mean'``: takes the mean (default)
+            - ``'sum'``: takes the sum
+            - ``'none'``: no reduction will be applied
+
+    Return:
+        IoU score: Tensor containing single value if reduction is
+        'elementwise_mean', or number of classes if reduction is 'none'
+
+    Example:
+
+        >>> target = torch.randint(0, 2, (10, 25, 25))
+        >>> pred = torch.tensor(target)
+        >>> pred[2:5, 7:13, 9:15] = 1 - pred[2:5, 7:13, 9:15]
+        >>> iou(pred, target)
+        tensor(0.9660)
+
+    """
+
+    num_classes = get_num_classes(pred=pred, target=target, num_classes=num_classes)
+    confmat = _confusion_matrix_update(pred, target, num_classes, threshold)
+    return _iou_from_confmat(confmat, num_classes, ignore_index, absent_score, reduction)
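The refactored functional path is now a pure function of the confusion matrix: the diagonal gives each class's intersection, and row sum plus column sum minus the diagonal gives its union. A worked sketch on the same 5-element example used by the ignore_index tests below:

import torch
from pytorch_lightning.metrics.functional.iou import _iou_from_confmat, iou

pred = torch.tensor([0, 1, 1, 2, 2])
target = torch.tensor([0, 1, 2, 2, 2])

# Confusion matrix (rows = target, columns = pred):
#   [[1, 0, 0],
#    [0, 1, 0],
#    [0, 1, 2]]
confmat = torch.tensor([[1, 0, 0], [0, 1, 0], [0, 1, 2]])
# intersection = diag = [1, 1, 2]; union = [1, 2, 3]; per-class IoU = [1, 1/2, 2/3]
print(_iou_from_confmat(confmat, num_classes=3, reduction='none'))  # tensor([1.0000, 0.5000, 0.6667])
print(iou(pred, target, num_classes=3, reduction='none'))           # same result via the public function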
tests/metrics/classification/test_iou.py (new file)
@@ -0,0 +1,214 @@
+from functools import partial
+
+import numpy as np
+import pytest
+import torch
+from pytorch_lightning.metrics.classification.iou import IoU
+from pytorch_lightning.metrics.functional.iou import iou
+from sklearn.metrics import jaccard_score as sk_jaccard_score
+from tests.metrics.classification.inputs import (
+    _binary_inputs,
+    _binary_prob_inputs,
+    _multiclass_inputs,
+    _multiclass_prob_inputs,
+    _multidim_multiclass_inputs,
+    _multidim_multiclass_prob_inputs,
+    _multilabel_inputs,
+    _multilabel_prob_inputs
+)
+from tests.metrics.utils import NUM_CLASSES, THRESHOLD, MetricTester
+
+
+def _sk_iou_binary_prob(preds, target, average=None):
+    sk_preds = (preds.view(-1).numpy() >= THRESHOLD).astype(np.uint8)
+    sk_target = target.view(-1).numpy()
+
+    return sk_jaccard_score(y_true=sk_target, y_pred=sk_preds, average=average)
+
+
+def _sk_iou_binary(preds, target, average=None):
+    sk_preds = preds.view(-1).numpy()
+    sk_target = target.view(-1).numpy()
+
+    return sk_jaccard_score(y_true=sk_target, y_pred=sk_preds, average=average)
+
+
+def _sk_iou_multilabel_prob(preds, target, average=None):
+    sk_preds = (preds.view(-1).numpy() >= THRESHOLD).astype(np.uint8)
+    sk_target = target.view(-1).numpy()
+
+    return sk_jaccard_score(y_true=sk_target, y_pred=sk_preds, average=average)
+
+
+def _sk_iou_multilabel(preds, target, average=None):
+    sk_preds = preds.view(-1).numpy()
+    sk_target = target.view(-1).numpy()
+
+    return sk_jaccard_score(y_true=sk_target, y_pred=sk_preds, average=average)
+
+
+def _sk_iou_multiclass_prob(preds, target, average=None):
+    sk_preds = torch.argmax(preds, dim=len(preds.shape) - 1).view(-1).numpy()
+    sk_target = target.view(-1).numpy()
+
+    return sk_jaccard_score(y_true=sk_target, y_pred=sk_preds, average=average)
+
+
+def _sk_iou_multiclass(preds, target, average=None):
+    sk_preds = preds.view(-1).numpy()
+    sk_target = target.view(-1).numpy()
+
+    return sk_jaccard_score(y_true=sk_target, y_pred=sk_preds, average=average)
+
+
+def _sk_iou_multidim_multiclass_prob(preds, target, average=None):
+    sk_preds = torch.argmax(preds, dim=len(preds.shape) - 2).view(-1).numpy()
+    sk_target = target.view(-1).numpy()
+
+    return sk_jaccard_score(y_true=sk_target, y_pred=sk_preds, average=average)
+
+
+def _sk_iou_multidim_multiclass(preds, target, average=None):
+    sk_preds = preds.view(-1).numpy()
+    sk_target = target.view(-1).numpy()
+
+    return sk_jaccard_score(y_true=sk_target, y_pred=sk_preds, average=average)
+
+
+@pytest.mark.parametrize("reduction", ['elementwise_mean', 'none'])
+@pytest.mark.parametrize("preds, target, sk_metric, num_classes", [
+    (_binary_prob_inputs.preds, _binary_prob_inputs.target, _sk_iou_binary_prob, 2),
+    (_binary_inputs.preds, _binary_inputs.target, _sk_iou_binary, 2),
+    (_multilabel_prob_inputs.preds, _multilabel_prob_inputs.target, _sk_iou_multilabel_prob, 2),
+    (_multilabel_inputs.preds, _multilabel_inputs.target, _sk_iou_multilabel, 2),
+    (_multiclass_prob_inputs.preds, _multiclass_prob_inputs.target, _sk_iou_multiclass_prob, NUM_CLASSES),
+    (_multiclass_inputs.preds, _multiclass_inputs.target, _sk_iou_multiclass, NUM_CLASSES),
+    (
+        _multidim_multiclass_prob_inputs.preds,
+        _multidim_multiclass_prob_inputs.target,
+        _sk_iou_multidim_multiclass_prob,
+        NUM_CLASSES
+    ),
+    (
+        _multidim_multiclass_inputs.preds,
+        _multidim_multiclass_inputs.target,
+        _sk_iou_multidim_multiclass,
+        NUM_CLASSES
+    )
+])
+class TestIoU(MetricTester):
+    @pytest.mark.parametrize("ddp", [True, False])
+    @pytest.mark.parametrize("dist_sync_on_step", [True, False])
+    def test_confusion_matrix(self, reduction, preds, target, sk_metric, num_classes, ddp, dist_sync_on_step):
+        average = 'macro' if reduction == 'elementwise_mean' else None  # convert tags
+        self.run_class_metric_test(ddp=ddp,
+                                   preds=preds,
+                                   target=target,
+                                   metric_class=IoU,
+                                   sk_metric=partial(sk_metric, average=average),
+                                   dist_sync_on_step=dist_sync_on_step,
+                                   metric_args={"num_classes": num_classes,
+                                                "threshold": THRESHOLD,
+                                                "reduction": reduction}
+                                   )
+
+    def test_confusion_matrix_functional(self, reduction, preds, target, sk_metric, num_classes):
+        average = 'macro' if reduction == 'elementwise_mean' else None  # convert tags
+        self.run_functional_metric_test(preds,
+                                        target,
+                                        metric_functional=iou,
+                                        sk_metric=partial(sk_metric, average=average),
+                                        metric_args={"num_classes": num_classes,
+                                                     "threshold": THRESHOLD,
+                                                     "reduction": reduction}
+                                        )
+
+
+@pytest.mark.parametrize(['half_ones', 'reduction', 'ignore_index', 'expected'], [
+    pytest.param(False, 'none', None, torch.Tensor([1, 1, 1])),
+    pytest.param(False, 'elementwise_mean', None, torch.Tensor([1])),
+    pytest.param(False, 'none', 0, torch.Tensor([1, 1])),
+    pytest.param(True, 'none', None, torch.Tensor([0.5, 0.5, 0.5])),
+    pytest.param(True, 'elementwise_mean', None, torch.Tensor([0.5])),
+    pytest.param(True, 'none', 0, torch.Tensor([0.5, 0.5])),
+])
+def test_iou(half_ones, reduction, ignore_index, expected):
+    pred = (torch.arange(120) % 3).view(-1, 1)
+    target = (torch.arange(120) % 3).view(-1, 1)
+    if half_ones:
+        pred[:60] = 1
+    iou_val = iou(
+        pred=pred,
+        target=target,
+        ignore_index=ignore_index,
+        reduction=reduction,
+    )
+    assert torch.allclose(iou_val, expected, atol=1e-9)
+
+
+# test `absent_score`
+@pytest.mark.parametrize(['pred', 'target', 'ignore_index', 'absent_score', 'num_classes', 'expected'], [
+    # Note that -1 is used as the absent_score in almost all tests here to distinguish it from the range of valid
+    # scores the function can return ([0., 1.] range, inclusive).
+    # 2 classes, class 0 is correct everywhere, class 1 is absent.
+    pytest.param([0], [0], None, -1., 2, [1., -1.]),
+    pytest.param([0, 0], [0, 0], None, -1., 2, [1., -1.]),
+    # absent_score not applied if only class 0 is present and it's the only class.
+    pytest.param([0], [0], None, -1., 1, [1.]),
+    # 2 classes, class 1 is correct everywhere, class 0 is absent.
+    pytest.param([1], [1], None, -1., 2, [-1., 1.]),
+    pytest.param([1, 1], [1, 1], None, -1., 2, [-1., 1.]),
+    # When 0 index ignored, class 0 does not get a score (not even the absent_score).
+    pytest.param([1], [1], 0, -1., 2, [1.0]),
+    # 3 classes. Only 0 and 2 are present, and are perfectly predicted. 1 should get absent_score.
+    pytest.param([0, 2], [0, 2], None, -1., 3, [1., -1., 1.]),
+    pytest.param([2, 0], [2, 0], None, -1., 3, [1., -1., 1.]),
+    # 3 classes. Only 0 and 1 are present, and are perfectly predicted. 2 should get absent_score.
+    pytest.param([0, 1], [0, 1], None, -1., 3, [1., 1., -1.]),
+    pytest.param([1, 0], [1, 0], None, -1., 3, [1., 1., -1.]),
+    # 3 classes, class 0 is 0.5 IoU, class 1 is 0 IoU (in pred but not target; should not get absent_score), class
+    # 2 is absent.
+    pytest.param([0, 1], [0, 0], None, -1., 3, [0.5, 0., -1.]),
+    # 3 classes, class 0 is 0.5 IoU, class 1 is 0 IoU (in target but not pred; should not get absent_score), class
+    # 2 is absent.
+    pytest.param([0, 0], [0, 1], None, -1., 3, [0.5, 0., -1.]),
+    # Sanity checks with absent_score of 1.0.
+    pytest.param([0, 2], [0, 2], None, 1.0, 3, [1., 1., 1.]),
+    pytest.param([0, 2], [0, 2], 0, 1.0, 3, [1., 1.]),
+])
+def test_iou_absent_score(pred, target, ignore_index, absent_score, num_classes, expected):
+    iou_val = iou(
+        pred=torch.tensor(pred),
+        target=torch.tensor(target),
+        ignore_index=ignore_index,
+        absent_score=absent_score,
+        num_classes=num_classes,
+        reduction='none',
+    )
+    assert torch.allclose(iou_val, torch.tensor(expected).to(iou_val))
+
+
+# example data taken from
+# https://github.com/scikit-learn/scikit-learn/blob/master/sklearn/metrics/tests/test_ranking.py
+@pytest.mark.parametrize(['pred', 'target', 'ignore_index', 'num_classes', 'reduction', 'expected'], [
+    # Ignoring an index outside of [0, num_classes-1] should have no effect.
+    pytest.param([0, 1, 1, 2, 2], [0, 1, 2, 2, 2], None, 3, 'none', [1, 1 / 2, 2 / 3]),
+    pytest.param([0, 1, 1, 2, 2], [0, 1, 2, 2, 2], -1, 3, 'none', [1, 1 / 2, 2 / 3]),
+    pytest.param([0, 1, 1, 2, 2], [0, 1, 2, 2, 2], 255, 3, 'none', [1, 1 / 2, 2 / 3]),
+    # Ignoring a valid index drops only that index from the result.
+    pytest.param([0, 1, 1, 2, 2], [0, 1, 2, 2, 2], 0, 3, 'none', [1 / 2, 2 / 3]),
+    pytest.param([0, 1, 1, 2, 2], [0, 1, 2, 2, 2], 1, 3, 'none', [1, 2 / 3]),
+    pytest.param([0, 1, 1, 2, 2], [0, 1, 2, 2, 2], 2, 3, 'none', [1, 1 / 2]),
+    # When reducing to mean or sum, the ignored index does not contribute to the output.
+    pytest.param([0, 1, 1, 2, 2], [0, 1, 2, 2, 2], 0, 3, 'elementwise_mean', [7 / 12]),
+    pytest.param([0, 1, 1, 2, 2], [0, 1, 2, 2, 2], 0, 3, 'sum', [7 / 6]),
+])
+def test_iou_ignore_index(pred, target, ignore_index, num_classes, reduction, expected):
+    iou_val = iou(
+        pred=torch.tensor(pred),
+        target=torch.tensor(target),
+        ignore_index=ignore_index,
+        num_classes=num_classes,
+        reduction=reduction,
+    )
+    assert torch.allclose(iou_val, torch.tensor(expected).to(iou_val))
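The fixtures above funnel everything through MetricTester, which compares against ``sklearn.metrics.jaccard_score``; ``reduction='elementwise_mean'`` is paired with ``average='macro'`` (an unweighted mean over classes). The same equivalence can be checked directly; a quick sketch, assuming scikit-learn is installed:

import torch
from sklearn.metrics import jaccard_score
from pytorch_lightning.metrics.functional.iou import iou

pred = torch.tensor([0, 1, 1, 2, 2])
target = torch.tensor([0, 1, 2, 2, 2])

pl_score = iou(pred, target, num_classes=3)  # reduction defaults to 'elementwise_mean'
sk_score = jaccard_score(target.numpy(), pred.numpy(), average='macro')
# both average the per-class scores [1, 1/2, 2/3] -> 13/18
assert torch.allclose(pl_score, torch.tensor(sk_score, dtype=torch.float32))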
tests/metrics/functional/test_classification.py
@@ -4,7 +4,6 @@ import pytest
 import torch
 from distutils.version import LooseVersion
 from sklearn.metrics import (
-    jaccard_score as sk_jaccard_score,
     precision_score as sk_precision,
     recall_score as sk_recall,
     roc_auc_score as sk_roc_auc_score,
@@ -20,14 +19,12 @@ from pytorch_lightning.metrics.functional.classification import (
     auroc,
     multiclass_auroc,
     auc,
-    iou,
 )
 from pytorch_lightning.metrics.functional.precision_recall_curve import _binary_clf_curve
 from pytorch_lightning.metrics.utils import to_onehot, get_num_classes, to_categorical
 
 
 @pytest.mark.parametrize(['sklearn_metric', 'torch_metric', 'only_binary'], [
-    pytest.param(partial(sk_jaccard_score, average='macro'), iou, False, id='iou'),
     pytest.param(partial(sk_precision, average='micro'), precision, False, id='precision'),
     pytest.param(partial(sk_recall, average='micro'), recall, False, id='recall'),
     pytest.param(sk_roc_auc_score, auroc, True, id='auroc')
@@ -297,112 +294,10 @@ def test_dice_score(pred, target, expected):
     assert score == expected
 
 
-@pytest.mark.parametrize(['half_ones', 'reduction', 'ignore_index', 'expected'], [
-    pytest.param(False, 'none', None, torch.Tensor([1, 1, 1])),
-    pytest.param(False, 'elementwise_mean', None, torch.Tensor([1])),
-    pytest.param(False, 'none', 0, torch.Tensor([1, 1])),
-    pytest.param(True, 'none', None, torch.Tensor([0.5, 0.5, 0.5])),
-    pytest.param(True, 'elementwise_mean', None, torch.Tensor([0.5])),
-    pytest.param(True, 'none', 0, torch.Tensor([0.5, 0.5])),
-])
-def test_iou(half_ones, reduction, ignore_index, expected):
-    pred = (torch.arange(120) % 3).view(-1, 1)
-    target = (torch.arange(120) % 3).view(-1, 1)
-    if half_ones:
-        pred[:60] = 1
-    iou_val = iou(
-        pred=pred,
-        target=target,
-        ignore_index=ignore_index,
-        reduction=reduction,
-    )
-    assert torch.allclose(iou_val, expected, atol=1e-9)
-
-
-def test_iou_input_check():
-    with pytest.raises(ValueError, match=r"'pred' shape (.*) must equal 'target' shape (.*)"):
-        _ = iou(pred=torch.randint(0, 2, (3, 4, 3)),
-                target=torch.randint(0, 2, (3, 3)))
-
-    with pytest.raises(ValueError, match="'pred' must contain integer targets."):
-        _ = iou(pred=torch.rand((3, 3)),
-                target=torch.randint(0, 2, (3, 3)))
-
-
 @pytest.mark.parametrize('metric', [auroc])
 def test_error_on_multiclass_input(metric):
     """ check that these metrics raise an error if they are used for multiclass problems """
-    pred = torch.randint(0, 10, (100, ))
-    target = torch.randint(0, 10, (100, ))
+    pred = torch.randint(0, 10, (100,))
+    target = torch.randint(0, 10, (100,))
     with pytest.raises(ValueError, match="AUROC metric is meant for binary classification"):
         _ = metric(pred, target)
-
-
-# TODO: When the jaccard_score of the sklearn version we use accepts `zero_division` (see
-# https://github.com/scikit-learn/scikit-learn/pull/17866), consider adding a test here against our
-# `absent_score`.
-@pytest.mark.parametrize(['pred', 'target', 'ignore_index', 'absent_score', 'num_classes', 'expected'], [
-    # Note that -1 is used as the absent_score in almost all tests here to distinguish it from the range of valid
-    # scores the function can return ([0., 1.] range, inclusive).
-    # 2 classes, class 0 is correct everywhere, class 1 is absent.
-    pytest.param([0], [0], None, -1., 2, [1., -1.]),
-    pytest.param([0, 0], [0, 0], None, -1., 2, [1., -1.]),
-    # absent_score not applied if only class 0 is present and it's the only class.
-    pytest.param([0], [0], None, -1., 1, [1.]),
-    # 2 classes, class 1 is correct everywhere, class 0 is absent.
-    pytest.param([1], [1], None, -1., 2, [-1., 1.]),
-    pytest.param([1, 1], [1, 1], None, -1., 2, [-1., 1.]),
-    # When 0 index ignored, class 0 does not get a score (not even the absent_score).
-    pytest.param([1], [1], 0, -1., 2, [1.0]),
-    # 3 classes. Only 0 and 2 are present, and are perfectly predicted. 1 should get absent_score.
-    pytest.param([0, 2], [0, 2], None, -1., 3, [1., -1., 1.]),
-    pytest.param([2, 0], [2, 0], None, -1., 3, [1., -1., 1.]),
-    # 3 classes. Only 0 and 1 are present, and are perfectly predicted. 2 should get absent_score.
-    pytest.param([0, 1], [0, 1], None, -1., 3, [1., 1., -1.]),
-    pytest.param([1, 0], [1, 0], None, -1., 3, [1., 1., -1.]),
-    # 3 classes, class 0 is 0.5 IoU, class 1 is 0 IoU (in pred but not target; should not get absent_score), class
-    # 2 is absent.
-    pytest.param([0, 1], [0, 0], None, -1., 3, [0.5, 0., -1.]),
-    # 3 classes, class 0 is 0.5 IoU, class 1 is 0 IoU (in target but not pred; should not get absent_score), class
-    # 2 is absent.
-    pytest.param([0, 0], [0, 1], None, -1., 3, [0.5, 0., -1.]),
-    # Sanity checks with absent_score of 1.0.
-    pytest.param([0, 2], [0, 2], None, 1.0, 3, [1., 1., 1.]),
-    pytest.param([0, 2], [0, 2], 0, 1.0, 3, [1., 1.]),
-])
-def test_iou_absent_score(pred, target, ignore_index, absent_score, num_classes, expected):
-    iou_val = iou(
-        pred=torch.tensor(pred),
-        target=torch.tensor(target),
-        ignore_index=ignore_index,
-        absent_score=absent_score,
-        num_classes=num_classes,
-        reduction='none',
-    )
-    assert torch.allclose(iou_val, torch.tensor(expected).to(iou_val))
-
-
-# example data taken from
-# https://github.com/scikit-learn/scikit-learn/blob/master/sklearn/metrics/tests/test_ranking.py
-@pytest.mark.parametrize(['pred', 'target', 'ignore_index', 'num_classes', 'reduction', 'expected'], [
-    # Ignoring an index outside of [0, num_classes-1] should have no effect.
-    pytest.param([0, 1, 1, 2, 2], [0, 1, 2, 2, 2], None, 3, 'none', [1, 1 / 2, 2 / 3]),
-    pytest.param([0, 1, 1, 2, 2], [0, 1, 2, 2, 2], -1, 3, 'none', [1, 1 / 2, 2 / 3]),
-    pytest.param([0, 1, 1, 2, 2], [0, 1, 2, 2, 2], 255, 3, 'none', [1, 1 / 2, 2 / 3]),
-    # Ignoring a valid index drops only that index from the result.
-    pytest.param([0, 1, 1, 2, 2], [0, 1, 2, 2, 2], 0, 3, 'none', [1 / 2, 2 / 3]),
-    pytest.param([0, 1, 1, 2, 2], [0, 1, 2, 2, 2], 1, 3, 'none', [1, 2 / 3]),
-    pytest.param([0, 1, 1, 2, 2], [0, 1, 2, 2, 2], 2, 3, 'none', [1, 1 / 2]),
-    # When reducing to mean or sum, the ignored index does not contribute to the output.
-    pytest.param([0, 1, 1, 2, 2], [0, 1, 2, 2, 2], 0, 3, 'elementwise_mean', [7 / 12]),
-    pytest.param([0, 1, 1, 2, 2], [0, 1, 2, 2, 2], 0, 3, 'sum', [7 / 6]),
-])
-def test_iou_ignore_index(pred, target, ignore_index, num_classes, reduction, expected):
-    iou_val = iou(
-        pred=torch.tensor(pred),
-        target=torch.tensor(target),
-        ignore_index=ignore_index,
-        num_classes=num_classes,
-        reduction=reduction,
-    )
-    assert torch.allclose(iou_val, torch.tensor(expected).to(iou_val))