Reformat iou [func] and add IoU class (PL^4704)

* added Iou

* Create iou.py

* Update iou.py

* Update iou.py

* Update CHANGELOG.md

* Update metrics.rst

* Update iou.py

* Update iou.py

* Update __init__.py

* Update iou.py

* Update iou.py

* Update classification.py

* Update classification.py

* Update classification.py

* Update __init__.py

* Update __init__.py

* Update iou.py

* Update classification.py

* Update metrics.rst

* Update CHANGELOG.md

* Update CHANGELOG.md

* add iou

* add test

* add test

* removed iou

* add iou

* add iou test

* add float

* reformat test_iou

* removed test_iou

* updated format

* updated format

* Update CHANGELOG.md

* updated format

* Update metrics.rst

* Apply suggestions from code review

merge suggestions

Co-authored-by: Nicki Skafte <skaftenicki@gmail.com>

* added equations

* reformat init

* change format

* change format

* deprecate iou and test for this

* fix changelog

* delete iou test in test_classification

* format change

* format change

* format

* format

* format

* delete white space

* delete white space

* fix tests

* Apply suggestions from code review

Co-authored-by: Jirka Borovec <Borda@users.noreply.github.com>

* Apply suggestions from code review

Co-authored-by: Jirka Borovec <Borda@users.noreply.github.com>

* better deprecation

* fix docs

* Apply suggestions from code review

* fix todo

Co-authored-by: Nicki Skafte <skaftenicki@gmail.com>
Co-authored-by: Jirka Borovec <Borda@users.noreply.github.com>
(cherry picked from commit 238040c9e5f2a9f03cecb171791fd8c4b88ca8b7)
deng-cy 2021-01-08 08:36:08 -05:00 committed by Jirka Borovec
Parent bcb82d2a0c
Commit 92c2e1923b
8 changed files with 460 additions and 156 deletions

View file

@@ -16,6 +16,7 @@ from pytorch_lightning.metrics.metric import Metric, MetricCollection # noqa: F
from pytorch_lightning.metrics.classification import ( # noqa: F401
Accuracy,
HammingDistance,
IoU,
Precision,
Recall,
ConfusionMatrix,

View file

@@ -20,3 +20,4 @@ from pytorch_lightning.metrics.classification.precision_recall import Precision,
from pytorch_lightning.metrics.classification.precision_recall_curve import PrecisionRecallCurve # noqa: F401
from pytorch_lightning.metrics.classification.roc import ROC # noqa: F401
from pytorch_lightning.metrics.classification.stat_scores import StatScores # noqa: F401
from pytorch_lightning.metrics.classification.iou import IoU # noqa: F401

View file

@@ -0,0 +1,106 @@
# Copyright The PyTorch Lightning team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Any, Optional
import torch
from pytorch_lightning.metrics.classification.confusion_matrix import ConfusionMatrix
from pytorch_lightning.metrics.functional.iou import _iou_from_confmat
class IoU(ConfusionMatrix):
r"""
Computes the `Jaccard index (intersection over union) <https://en.wikipedia.org/wiki/Jaccard_index>`_:
.. math:: J(A,B) = \frac{|A\cap B|}{|A\cup B|}
where :math:`A` and :math:`B` are tensors of the same size, containing integer class values.
They may be subject to conversion from input data (see description below). Note that this metric is different from box IoU.
Works with binary, multiclass and multi-label data.
Accepts logits from a model output or integer class values in prediction.
Works with multi-dimensional preds and target.
Forward accepts
- ``preds`` (float or long tensor): ``(N, ...)`` or ``(N, C, ...)`` where C is the number of classes
- ``target`` (long tensor): ``(N, ...)``
If preds and target are the same shape and preds is a float tensor, we use the ``self.threshold`` argument.
This is the case for binary and multi-label logits.
If preds has an extra dimension as in the case of multi-class scores we perform an argmax on ``dim=1``.
Args:
num_classes: Number of classes in the dataset.
ignore_index: optional int specifying a target class to ignore. If given, this class index does not contribute
to the returned score, regardless of reduction method. Has no effect if given an int that is not in the
range [0, num_classes-1]. By default, no index is ignored, and all classes are used.
absent_score: score to use for an individual class, if no instances of the class index were present in
`pred` AND no instances of the class index were present in `target`. For example, if we have 3 classes,
[0, 0] for `pred`, and [0, 2] for `target`, then class 1 would be assigned the `absent_score`.
threshold:
Threshold value for binary or multi-label logits.
reduction: a method to reduce metric score over labels.
- ``'elementwise_mean'``: takes the mean (default)
- ``'sum'``: takes the sum
- ``'none'``: no reduction will be applied
compute_on_step:
Forward only calls ``update()`` and returns ``None`` if this is set to ``False``.
dist_sync_on_step:
Synchronize metric state across processes at each ``forward()``
before returning the value at the step.
process_group:
Specify the process group on which synchronization is called. default: None (which selects the entire world)
Example:
>>> from pytorch_lightning.metrics import IoU
>>> target = torch.randint(0, 2, (10, 25, 25))
>>> pred = torch.tensor(target)
>>> pred[2:5, 7:13, 9:15] = 1 - pred[2:5, 7:13, 9:15]
>>> iou = IoU(num_classes=2)
>>> iou(pred, target)
tensor(0.9660)
"""
def __init__(
self,
num_classes: int,
ignore_index: Optional[int] = None,
absent_score: float = 0.0,
threshold: float = 0.5,
reduction: str = 'elementwise_mean',
compute_on_step: bool = True,
dist_sync_on_step: bool = False,
process_group: Optional[Any] = None,
):
super().__init__(
num_classes=num_classes,
normalize=None,
threshold=threshold,
compute_on_step=compute_on_step,
dist_sync_on_step=dist_sync_on_step,
process_group=process_group,
)
self.reduction = reduction
self.ignore_index = ignore_index
self.absent_score = absent_score
def compute(self) -> torch.Tensor:
"""
Computes intersection over union (IoU)
"""
return _iou_from_confmat(self.confmat, self.num_classes, self.ignore_index, self.absent_score, self.reduction)
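
As a quick sanity check of the new class-based API, a minimal usage sketch (it relies only on the constructor arguments documented above; the exact scores depend on the random masks):

import torch
from pytorch_lightning.metrics import IoU

# Toy binary segmentation masks, mirroring the docstring example.
target = torch.randint(0, 2, (10, 25, 25))
pred = target.clone()
pred[2:5, 7:13, 9:15] = 1 - pred[2:5, 7:13, 9:15]  # corrupt a block of predictions

# reduction='none' returns one score per class instead of the mean.
iou_metric = IoU(num_classes=2, reduction='none')
per_class = iou_metric(pred, target)  # forward() updates state and, by default, returns the batch score
print(per_class)  # two per-class scores, both close to 1 for this small corruption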

View file

@@ -17,7 +17,6 @@ from pytorch_lightning.metrics.functional.classification import ( # noqa: F401
auroc,
dice_score,
get_num_classes,
iou,
multiclass_auroc,
precision,
precision_recall,
@@ -32,6 +31,8 @@ from pytorch_lightning.metrics.functional.confusion_matrix import confusion_matr
from pytorch_lightning.metrics.functional.explained_variance import explained_variance # noqa: F401
from pytorch_lightning.metrics.functional.f_beta import fbeta, f1 # noqa: F401
from pytorch_lightning.metrics.functional.hamming_distance import hamming_distance # noqa: F401
from pytorch_lightning.metrics.functional.image_gradients import image_gradients # noqa: F401
from pytorch_lightning.metrics.functional.iou import iou # noqa: F401
from pytorch_lightning.metrics.functional.mean_absolute_error import mean_absolute_error # noqa: F401
from pytorch_lightning.metrics.functional.mean_squared_error import mean_squared_error # noqa: F401
from pytorch_lightning.metrics.functional.mean_squared_log_error import mean_squared_log_error # noqa: F401
@@ -43,4 +44,3 @@ from pytorch_lightning.metrics.functional.roc import roc # noqa: F401
from pytorch_lightning.metrics.functional.self_supervised import embedding_similarity # noqa: F401
from pytorch_lightning.metrics.functional.ssim import ssim # noqa: F401
from pytorch_lightning.metrics.functional.stat_scores import stat_scores # noqa: F401
from pytorch_lightning.metrics.functional.image_gradients import image_gradients # noqa: F401

View file

@@ -11,13 +11,13 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from distutils.version import LooseVersion
from functools import wraps
from typing import Callable, Optional, Sequence, Tuple
import torch
from distutils.version import LooseVersion
from pytorch_lightning.metrics.functional.average_precision import average_precision as __ap
from pytorch_lightning.metrics.functional.iou import iou as __iou
from pytorch_lightning.metrics.functional.precision_recall_curve import (
_binary_clf_curve,
precision_recall_curve as __prc
@@ -84,7 +84,7 @@ def get_num_classes(
" `from pytorch_lightning.metrics.utils import get_num_classes`."
" It will be removed in v1.3.0", DeprecationWarning
)
return __gnc(pred,target, num_classes)
return __gnc(pred, target, num_classes)
def stat_scores(
@@ -162,8 +162,8 @@ def stat_scores_multiple_classes(
raise ValueError("reduction type %s not supported" % reduction)
if reduction == 'none':
pred = pred.view((-1, )).long()
target = target.view((-1, )).long()
pred = pred.view((-1,)).long()
target = target.view((-1,)).long()
tps = torch.zeros((num_classes + 1,), device=pred.device)
fps = torch.zeros((num_classes + 1,), device=pred.device)
@@ -687,6 +687,7 @@ def dice_score(
return reduce(scores, reduction=reduction)
# todo: remove in 1.4
def iou(
pred: torch.Tensor,
target: torch.Tensor,
@@ -698,6 +699,10 @@ def iou(
"""
Intersection over union, or Jaccard index calculation.
.. warning:: Deprecated in favor of
:func:`~pytorch_lightning.metrics.functional.iou.iou`. Will be removed in
v1.4.0.
Args:
pred: Tensor containing integer predictions, with shape [N, d1, d2, ...]
target: Tensor containing integer targets, with shape [N, d1, d2, ...]
@@ -729,48 +734,20 @@ def iou(
tensor(0.9660)
"""
if pred.size() != target.size():
raise ValueError(f"'pred' shape ({pred.size()}) must equal 'target' shape ({target.size()})")
if not torch.allclose(pred.float(), pred.int().float()):
raise ValueError("'pred' must contain integer targets.")
num_classes = get_num_classes(pred=pred, target=target, num_classes=num_classes)
tps, fps, tns, fns, sups = stat_scores_multiple_classes(pred, target, num_classes)
scores = torch.zeros(num_classes, device=pred.device, dtype=torch.float32)
for class_idx in range(num_classes):
if class_idx == ignore_index:
continue
tp = tps[class_idx]
fp = fps[class_idx]
fn = fns[class_idx]
sup = sups[class_idx]
# If this class is absent in the target (no support) AND absent in the pred (no true or false
# positives), then use the absent_score for this class.
if sup + tp + fp == 0:
scores[class_idx] = absent_score
continue
denom = tp + fp + fn
# Note that we do not need to worry about division-by-zero here since we know (sup + tp + fp != 0) from above,
# which means ((tp+fn) + tp + fp != 0), which means (2tp + fp + fn != 0). Since all vars are non-negative, we
# can conclude (tp + fp + fn > 0), meaning the denominator is non-zero for each class.
score = tp.to(torch.float) / denom
scores[class_idx] = score
# Remove the ignored class index from the scores.
if ignore_index is not None and ignore_index >= 0 and ignore_index < num_classes:
scores = torch.cat([
scores[:ignore_index],
scores[ignore_index + 1:],
])
return reduce(scores, reduction=reduction)
rank_zero_warn(
"This `iou` was deprecated in v1.2.0 in favor of"
" `from pytorch_lightning.metrics.functional.iou import iou`."
" It will be removed in v1.4.0", DeprecationWarning
)
return __iou(
pred=pred,
target=target,
ignore_index=ignore_index,
absent_score=absent_score,
threshold=0.5,
num_classes=num_classes,
reduction=reduction
)
# todo: remove in 1.3
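
As a behavioral sketch of the shim above (assuming `rank_zero_warn` routes through the standard `warnings` machinery, which holds in a single-process run):

import warnings
import torch
from pytorch_lightning.metrics.functional.classification import iou as deprecated_iou

pred = torch.tensor([0, 1, 1, 2, 2])
target = torch.tensor([0, 1, 2, 2, 2])

with warnings.catch_warnings(record=True) as caught:
    warnings.simplefilter("always")
    score = deprecated_iou(pred, target)  # still computes, now by delegating to the new functional version

# The old entry point emits a DeprecationWarning pointing at the new import path.
assert any(issubclass(w.category, DeprecationWarning) for w in caught)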

View file

@@ -0,0 +1,110 @@
# Copyright The PyTorch Lightning team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Optional
import torch
from pytorch_lightning.metrics.functional.confusion_matrix import _confusion_matrix_update
from pytorch_lightning.metrics.functional.reduction import reduce
from pytorch_lightning.metrics.utils import get_num_classes
def _iou_from_confmat(
confmat: torch.Tensor,
num_classes: int,
ignore_index: Optional[int] = None,
absent_score: float = 0.0,
reduction: str = 'elementwise_mean',
):
intersection = torch.diag(confmat)
union = confmat.sum(0) + confmat.sum(1) - intersection
# If this class is absent in both target AND pred (union == 0), then use the absent_score for this class.
scores = intersection.float() / union.float()
scores[union == 0] = absent_score
# Remove the ignored class index from the scores.
if ignore_index is not None and ignore_index >= 0 and ignore_index < num_classes:
scores = torch.cat([
scores[:ignore_index],
scores[ignore_index + 1:],
])
return reduce(scores, reduction=reduction)
def iou(
pred: torch.Tensor,
target: torch.Tensor,
ignore_index: Optional[int] = None,
absent_score: float = 0.0,
threshold: float = 0.5,
num_classes: Optional[int] = None,
reduction: str = 'elementwise_mean',
) -> torch.Tensor:
r"""
Computes the `Jaccard index (intersection over union) <https://en.wikipedia.org/wiki/Jaccard_index>`_:
.. math:: J(A,B) = \frac{|A\cap B|}{|A\cup B|}
where :math:`A` and :math:`B` are tensors of the same size,
containing integer class values. They may be subject to conversion from
input data (see description below).
Note that this metric is different from box IoU.
If pred and target are the same shape and pred is a float tensor,
we use the ``threshold`` argument. This is the case for binary and multi-label logits.
If pred has an extra dimension as in the case of multi-class scores we
perform an argmax on ``dim=1``.
Args:
pred: Tensor containing integer predictions, with shape [N, d1, d2, ...]
target: Tensor containing integer targets, with shape [N, d1, d2, ...]
ignore_index: optional int specifying a target class to ignore. If given,
this class index does not contribute to the returned score, regardless
of reduction method. Has no effect if given an int that is not in the
range [0, num_classes-1], where num_classes is either given or derived
from pred and target. By default, no index is ignored, and all classes are used.
absent_score: score to use for an individual class, if no instances of
the class index were present in `pred` AND no instances of the class
index were present in `target`. For example, if we have 3 classes,
[0, 0] for `pred`, and [0, 2] for `target`, then class 1 would be
assigned the `absent_score`.
threshold:
Threshold value for binary or multi-label logits. default: 0.5
num_classes:
Optionally specify the number of classes
reduction: a method to reduce metric score over labels.
- ``'elementwise_mean'``: takes the mean (default)
- ``'sum'``: takes the sum
- ``'none'``: no reduction will be applied
Return:
IoU score: a tensor containing a single value if reduction is
'elementwise_mean', or one value per class if reduction is 'none'
Example:
>>> target = torch.randint(0, 2, (10, 25, 25))
>>> pred = torch.tensor(target)
>>> pred[2:5, 7:13, 9:15] = 1 - pred[2:5, 7:13, 9:15]
>>> iou(pred, target)
tensor(0.9660)
"""
num_classes = get_num_classes(pred=pred, target=target, num_classes=num_classes)
confmat = _confusion_matrix_update(pred, target, num_classes, threshold)
return _iou_from_confmat(confmat, num_classes, ignore_index, absent_score, reduction)
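
To make the confusion-matrix formulation concrete, a hand-checked worked example using only `_iou_from_confmat` as defined in this file:

import torch
from pytorch_lightning.metrics.functional.iou import _iou_from_confmat

# Confusion matrix for pred = [0, 1, 1, 2, 2] vs. target = [0, 1, 2, 2, 2]
# (rows indexed by target class, columns by predicted class):
confmat = torch.tensor([
    [1., 0., 0.],  # target 0: predicted as 0 once
    [0., 1., 0.],  # target 1: predicted as 1 once
    [0., 1., 2.],  # target 2: predicted as 1 once, as 2 twice
])
# intersection = diag(confmat)               = [1, 1, 2]
# union = col_sums + row_sums - intersection = [1, 2, 3]
scores = _iou_from_confmat(confmat, num_classes=3, reduction='none')
print(scores)  # tensor([1.0000, 0.5000, 0.6667]), matching the 'none' expectations in the tests below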

View file

@@ -0,0 +1,214 @@
from functools import partial
import numpy as np
import pytest
import torch
from pytorch_lightning.metrics.classification.iou import IoU
from pytorch_lightning.metrics.functional.iou import iou
from sklearn.metrics import jaccard_score as sk_jaccard_score
from tests.metrics.classification.inputs import (
_binary_inputs,
_binary_prob_inputs,
_multiclass_inputs,
_multiclass_prob_inputs,
_multidim_multiclass_inputs,
_multidim_multiclass_prob_inputs,
_multilabel_inputs,
_multilabel_prob_inputs
)
from tests.metrics.utils import NUM_CLASSES, THRESHOLD, MetricTester
def _sk_iou_binary_prob(preds, target, average=None):
sk_preds = (preds.view(-1).numpy() >= THRESHOLD).astype(np.uint8)
sk_target = target.view(-1).numpy()
return sk_jaccard_score(y_true=sk_target, y_pred=sk_preds, average=average)
def _sk_iou_binary(preds, target, average=None):
sk_preds = preds.view(-1).numpy()
sk_target = target.view(-1).numpy()
return sk_jaccard_score(y_true=sk_target, y_pred=sk_preds, average=average)
def _sk_iou_multilabel_prob(preds, target, average=None):
sk_preds = (preds.view(-1).numpy() >= THRESHOLD).astype(np.uint8)
sk_target = target.view(-1).numpy()
return sk_jaccard_score(y_true=sk_target, y_pred=sk_preds, average=average)
def _sk_iou_multilabel(preds, target, average=None):
sk_preds = preds.view(-1).numpy()
sk_target = target.view(-1).numpy()
return sk_jaccard_score(y_true=sk_target, y_pred=sk_preds, average=average)
def _sk_iou_multiclass_prob(preds, target, average=None):
sk_preds = torch.argmax(preds, dim=len(preds.shape) - 1).view(-1).numpy()
sk_target = target.view(-1).numpy()
return sk_jaccard_score(y_true=sk_target, y_pred=sk_preds, average=average)
def _sk_iou_multiclass(preds, target, average=None):
sk_preds = preds.view(-1).numpy()
sk_target = target.view(-1).numpy()
return sk_jaccard_score(y_true=sk_target, y_pred=sk_preds, average=average)
def _sk_iou_multidim_multiclass_prob(preds, target, average=None):
sk_preds = torch.argmax(preds, dim=len(preds.shape) - 2).view(-1).numpy()
sk_target = target.view(-1).numpy()
return sk_jaccard_score(y_true=sk_target, y_pred=sk_preds, average=average)
def _sk_iou_multidim_multiclass(preds, target, average=None):
sk_preds = preds.view(-1).numpy()
sk_target = target.view(-1).numpy()
return sk_jaccard_score(y_true=sk_target, y_pred=sk_preds, average=average)
@pytest.mark.parametrize("reduction", ['elementwise_mean', 'none'])
@pytest.mark.parametrize("preds, target, sk_metric, num_classes", [
(_binary_prob_inputs.preds, _binary_prob_inputs.target, _sk_iou_binary_prob, 2),
(_binary_inputs.preds, _binary_inputs.target, _sk_iou_binary, 2),
(_multilabel_prob_inputs.preds, _multilabel_prob_inputs.target, _sk_iou_multilabel_prob, 2),
(_multilabel_inputs.preds, _multilabel_inputs.target, _sk_iou_multilabel, 2),
(_multiclass_prob_inputs.preds, _multiclass_prob_inputs.target, _sk_iou_multiclass_prob, NUM_CLASSES),
(_multiclass_inputs.preds, _multiclass_inputs.target, _sk_iou_multiclass, NUM_CLASSES),
(
_multidim_multiclass_prob_inputs.preds,
_multidim_multiclass_prob_inputs.target,
_sk_iou_multidim_multiclass_prob,
NUM_CLASSES
),
(
_multidim_multiclass_inputs.preds,
_multidim_multiclass_inputs.target,
_sk_iou_multidim_multiclass,
NUM_CLASSES
)
])
class TestIoU(MetricTester):
@pytest.mark.parametrize("ddp", [True, False])
@pytest.mark.parametrize("dist_sync_on_step", [True, False])
def test_iou(self, reduction, preds, target, sk_metric, num_classes, ddp, dist_sync_on_step):
average = 'macro' if reduction == 'elementwise_mean' else None  # map reduction to sklearn's `average` argument
self.run_class_metric_test(ddp=ddp,
preds=preds,
target=target,
metric_class=IoU,
sk_metric=partial(sk_metric, average=average),
dist_sync_on_step=dist_sync_on_step,
metric_args={"num_classes": num_classes,
"threshold": THRESHOLD,
"reduction": reduction}
)
def test_iou_functional(self, reduction, preds, target, sk_metric, num_classes):
average = 'macro' if reduction == 'elementwise_mean' else None  # map reduction to sklearn's `average` argument
self.run_functional_metric_test(preds,
target,
metric_functional=iou,
sk_metric=partial(sk_metric, average=average),
metric_args={"num_classes": num_classes,
"threshold": THRESHOLD,
"reduction": reduction}
)
@pytest.mark.parametrize(['half_ones', 'reduction', 'ignore_index', 'expected'], [
pytest.param(False, 'none', None, torch.Tensor([1, 1, 1])),
pytest.param(False, 'elementwise_mean', None, torch.Tensor([1])),
pytest.param(False, 'none', 0, torch.Tensor([1, 1])),
pytest.param(True, 'none', None, torch.Tensor([0.5, 0.5, 0.5])),
pytest.param(True, 'elementwise_mean', None, torch.Tensor([0.5])),
pytest.param(True, 'none', 0, torch.Tensor([0.5, 0.5])),
])
def test_iou(half_ones, reduction, ignore_index, expected):
pred = (torch.arange(120) % 3).view(-1, 1)
target = (torch.arange(120) % 3).view(-1, 1)
if half_ones:
pred[:60] = 1
iou_val = iou(
pred=pred,
target=target,
ignore_index=ignore_index,
reduction=reduction,
)
assert torch.allclose(iou_val, expected, atol=1e-9)
# test `absent_score`
@pytest.mark.parametrize(['pred', 'target', 'ignore_index', 'absent_score', 'num_classes', 'expected'], [
# Note that -1 is used as the absent_score in almost all tests here to distinguish it from the range of valid
# scores the function can return ([0., 1.] range, inclusive).
# 2 classes, class 0 is correct everywhere, class 1 is absent.
pytest.param([0], [0], None, -1., 2, [1., -1.]),
pytest.param([0, 0], [0, 0], None, -1., 2, [1., -1.]),
# absent_score not applied if only class 0 is present and it's the only class.
pytest.param([0], [0], None, -1., 1, [1.]),
# 2 classes, class 1 is correct everywhere, class 0 is absent.
pytest.param([1], [1], None, -1., 2, [-1., 1.]),
pytest.param([1, 1], [1, 1], None, -1., 2, [-1., 1.]),
# When 0 index ignored, class 0 does not get a score (not even the absent_score).
pytest.param([1], [1], 0, -1., 2, [1.0]),
# 3 classes. Only 0 and 2 are present, and are perfectly predicted. 1 should get absent_score.
pytest.param([0, 2], [0, 2], None, -1., 3, [1., -1., 1.]),
pytest.param([2, 0], [2, 0], None, -1., 3, [1., -1., 1.]),
# 3 classes. Only 0 and 1 are present, and are perfectly predicted. 2 should get absent_score.
pytest.param([0, 1], [0, 1], None, -1., 3, [1., 1., -1.]),
pytest.param([1, 0], [1, 0], None, -1., 3, [1., 1., -1.]),
# 3 classes, class 0 is 0.5 IoU, class 1 is 0 IoU (in pred but not target; should not get absent_score), class
# 2 is absent.
pytest.param([0, 1], [0, 0], None, -1., 3, [0.5, 0., -1.]),
# 3 classes, class 0 is 0.5 IoU, class 1 is 0 IoU (in target but not pred; should not get absent_score), class
# 2 is absent.
pytest.param([0, 0], [0, 1], None, -1., 3, [0.5, 0., -1.]),
# Sanity checks with absent_score of 1.0.
pytest.param([0, 2], [0, 2], None, 1.0, 3, [1., 1., 1.]),
pytest.param([0, 2], [0, 2], 0, 1.0, 3, [1., 1.]),
])
def test_iou_absent_score(pred, target, ignore_index, absent_score, num_classes, expected):
iou_val = iou(
pred=torch.tensor(pred),
target=torch.tensor(target),
ignore_index=ignore_index,
absent_score=absent_score,
num_classes=num_classes,
reduction='none',
)
assert torch.allclose(iou_val, torch.tensor(expected).to(iou_val))
# example data taken from
# https://github.com/scikit-learn/scikit-learn/blob/master/sklearn/metrics/tests/test_ranking.py
@pytest.mark.parametrize(['pred', 'target', 'ignore_index', 'num_classes', 'reduction', 'expected'], [
# Ignoring an index outside of [0, num_classes-1] should have no effect.
pytest.param([0, 1, 1, 2, 2], [0, 1, 2, 2, 2], None, 3, 'none', [1, 1 / 2, 2 / 3]),
pytest.param([0, 1, 1, 2, 2], [0, 1, 2, 2, 2], -1, 3, 'none', [1, 1 / 2, 2 / 3]),
pytest.param([0, 1, 1, 2, 2], [0, 1, 2, 2, 2], 255, 3, 'none', [1, 1 / 2, 2 / 3]),
# Ignoring a valid index drops only that index from the result.
pytest.param([0, 1, 1, 2, 2], [0, 1, 2, 2, 2], 0, 3, 'none', [1 / 2, 2 / 3]),
pytest.param([0, 1, 1, 2, 2], [0, 1, 2, 2, 2], 1, 3, 'none', [1, 2 / 3]),
pytest.param([0, 1, 1, 2, 2], [0, 1, 2, 2, 2], 2, 3, 'none', [1, 1 / 2]),
# When reducing to mean or sum, the ignored index does not contribute to the output.
pytest.param([0, 1, 1, 2, 2], [0, 1, 2, 2, 2], 0, 3, 'elementwise_mean', [7 / 12]),
pytest.param([0, 1, 1, 2, 2], [0, 1, 2, 2, 2], 0, 3, 'sum', [7 / 6]),
])
def test_iou_ignore_index(pred, target, ignore_index, num_classes, reduction, expected):
iou_val = iou(
pred=torch.tensor(pred),
target=torch.tensor(target),
ignore_index=ignore_index,
num_classes=num_classes,
reduction=reduction,
)
assert torch.allclose(iou_val, torch.tensor(expected).to(iou_val))
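
For reference, the reduced expectations in the last two cases follow directly from the per-class scores: with class 0 ignored, the remaining scores are [1/2, 2/3], so 'elementwise_mean' yields (1/2 + 2/3) / 2 = 7/12 and 'sum' yields 1/2 + 2/3 = 7/6.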

View file

@@ -4,7 +4,6 @@ import pytest
import torch
from distutils.version import LooseVersion
from sklearn.metrics import (
jaccard_score as sk_jaccard_score,
precision_score as sk_precision,
recall_score as sk_recall,
roc_auc_score as sk_roc_auc_score,
@@ -20,14 +19,12 @@ from pytorch_lightning.metrics.functional.classification import (
auroc,
multiclass_auroc,
auc,
iou,
)
from pytorch_lightning.metrics.functional.precision_recall_curve import _binary_clf_curve
from pytorch_lightning.metrics.utils import to_onehot, get_num_classes, to_categorical
@pytest.mark.parametrize(['sklearn_metric', 'torch_metric', 'only_binary'], [
pytest.param(partial(sk_jaccard_score, average='macro'), iou, False, id='iou'),
pytest.param(partial(sk_precision, average='micro'), precision, False, id='precision'),
pytest.param(partial(sk_recall, average='micro'), recall, False, id='recall'),
pytest.param(sk_roc_auc_score, auroc, True, id='auroc')
@@ -297,112 +294,10 @@ def test_dice_score(pred, target, expected):
assert score == expected
@pytest.mark.parametrize(['half_ones', 'reduction', 'ignore_index', 'expected'], [
pytest.param(False, 'none', None, torch.Tensor([1, 1, 1])),
pytest.param(False, 'elementwise_mean', None, torch.Tensor([1])),
pytest.param(False, 'none', 0, torch.Tensor([1, 1])),
pytest.param(True, 'none', None, torch.Tensor([0.5, 0.5, 0.5])),
pytest.param(True, 'elementwise_mean', None, torch.Tensor([0.5])),
pytest.param(True, 'none', 0, torch.Tensor([0.5, 0.5])),
])
def test_iou(half_ones, reduction, ignore_index, expected):
pred = (torch.arange(120) % 3).view(-1, 1)
target = (torch.arange(120) % 3).view(-1, 1)
if half_ones:
pred[:60] = 1
iou_val = iou(
pred=pred,
target=target,
ignore_index=ignore_index,
reduction=reduction,
)
assert torch.allclose(iou_val, expected, atol=1e-9)
def test_iou_input_check():
with pytest.raises(ValueError, match=r"'pred' shape (.*) must equal 'target' shape (.*)"):
_ = iou(pred=torch.randint(0, 2, (3, 4, 3)),
target=torch.randint(0, 2, (3, 3)))
with pytest.raises(ValueError, match="'pred' must contain integer targets."):
_ = iou(pred=torch.rand((3, 3)),
target=torch.randint(0, 2, (3, 3)))
@pytest.mark.parametrize('metric', [auroc])
def test_error_on_multiclass_input(metric):
""" check that these metrics raise an error if they are used for multiclass problems """
pred = torch.randint(0, 10, (100, ))
target = torch.randint(0, 10, (100, ))
pred = torch.randint(0, 10, (100,))
target = torch.randint(0, 10, (100,))
with pytest.raises(ValueError, match="AUROC metric is meant for binary classification"):
_ = metric(pred, target)
# TODO: When the jaccard_score of the sklearn version we use accepts `zero_division` (see
# https://github.com/scikit-learn/scikit-learn/pull/17866), consider adding a test here against our
# `absent_score`.
@pytest.mark.parametrize(['pred', 'target', 'ignore_index', 'absent_score', 'num_classes', 'expected'], [
# Note that -1 is used as the absent_score in almost all tests here to distinguish it from the range of valid
# scores the function can return ([0., 1.] range, inclusive).
# 2 classes, class 0 is correct everywhere, class 1 is absent.
pytest.param([0], [0], None, -1., 2, [1., -1.]),
pytest.param([0, 0], [0, 0], None, -1., 2, [1., -1.]),
# absent_score not applied if only class 0 is present and it's the only class.
pytest.param([0], [0], None, -1., 1, [1.]),
# 2 classes, class 1 is correct everywhere, class 0 is absent.
pytest.param([1], [1], None, -1., 2, [-1., 1.]),
pytest.param([1, 1], [1, 1], None, -1., 2, [-1., 1.]),
# When 0 index ignored, class 0 does not get a score (not even the absent_score).
pytest.param([1], [1], 0, -1., 2, [1.0]),
# 3 classes. Only 0 and 2 are present, and are perfectly predicted. 1 should get absent_score.
pytest.param([0, 2], [0, 2], None, -1., 3, [1., -1., 1.]),
pytest.param([2, 0], [2, 0], None, -1., 3, [1., -1., 1.]),
# 3 classes. Only 0 and 1 are present, and are perfectly predicted. 2 should get absent_score.
pytest.param([0, 1], [0, 1], None, -1., 3, [1., 1., -1.]),
pytest.param([1, 0], [1, 0], None, -1., 3, [1., 1., -1.]),
# 3 classes, class 0 is 0.5 IoU, class 1 is 0 IoU (in pred but not target; should not get absent_score), class
# 2 is absent.
pytest.param([0, 1], [0, 0], None, -1., 3, [0.5, 0., -1.]),
# 3 classes, class 0 is 0.5 IoU, class 1 is 0 IoU (in target but not pred; should not get absent_score), class
# 2 is absent.
pytest.param([0, 0], [0, 1], None, -1., 3, [0.5, 0., -1.]),
# Sanity checks with absent_score of 1.0.
pytest.param([0, 2], [0, 2], None, 1.0, 3, [1., 1., 1.]),
pytest.param([0, 2], [0, 2], 0, 1.0, 3, [1., 1.]),
])
def test_iou_absent_score(pred, target, ignore_index, absent_score, num_classes, expected):
iou_val = iou(
pred=torch.tensor(pred),
target=torch.tensor(target),
ignore_index=ignore_index,
absent_score=absent_score,
num_classes=num_classes,
reduction='none',
)
assert torch.allclose(iou_val, torch.tensor(expected).to(iou_val))
# example data taken from
# https://github.com/scikit-learn/scikit-learn/blob/master/sklearn/metrics/tests/test_ranking.py
@pytest.mark.parametrize(['pred', 'target', 'ignore_index', 'num_classes', 'reduction', 'expected'], [
# Ignoring an index outside of [0, num_classes-1] should have no effect.
pytest.param([0, 1, 1, 2, 2], [0, 1, 2, 2, 2], None, 3, 'none', [1, 1 / 2, 2 / 3]),
pytest.param([0, 1, 1, 2, 2], [0, 1, 2, 2, 2], -1, 3, 'none', [1, 1 / 2, 2 / 3]),
pytest.param([0, 1, 1, 2, 2], [0, 1, 2, 2, 2], 255, 3, 'none', [1, 1 / 2, 2 / 3]),
# Ignoring a valid index drops only that index from the result.
pytest.param([0, 1, 1, 2, 2], [0, 1, 2, 2, 2], 0, 3, 'none', [1 / 2, 2 / 3]),
pytest.param([0, 1, 1, 2, 2], [0, 1, 2, 2, 2], 1, 3, 'none', [1, 2 / 3]),
pytest.param([0, 1, 1, 2, 2], [0, 1, 2, 2, 2], 2, 3, 'none', [1, 1 / 2]),
# When reducing to mean or sum, the ignored index does not contribute to the output.
pytest.param([0, 1, 1, 2, 2], [0, 1, 2, 2, 2], 0, 3, 'elementwise_mean', [7 / 12]),
pytest.param([0, 1, 1, 2, 2], [0, 1, 2, 2, 2], 0, 3, 'sum', [7 / 6]),
])
def test_iou_ignore_index(pred, target, ignore_index, num_classes, reduction, expected):
iou_val = iou(
pred=torch.tensor(pred),
target=torch.tensor(target),
ignore_index=ignore_index,
num_classes=num_classes,
reduction=reduction,
)
assert torch.allclose(iou_val, torch.tensor(expected).to(iou_val))