drop duplicated metric helper (PL^5366)

* drop duplicated metric helper

* .

* fix tests

Co-authored-by: Nicki Skafte <skaftenicki@gmail.com>
(cherry picked from commit 220dfaf7dcd12ad230ea2938d8d67a119a22e143)
This commit is contained in:
Jirka Borovec 2021-01-07 00:49:59 +01:00 коммит произвёл Jirka Borovec
Родитель 261177efa2
Коммит 1cccdc73d9
2 изменённых файлов: 5 добавлений и 30 удалений

Просмотреть файл

@ -15,7 +15,7 @@ from typing import Optional
import torch
from pytorch_lightning.metrics.utils import _input_format_classification
from pytorch_lightning.metrics.classification.helpers import _input_format_classification
from pytorch_lightning.utilities import rank_zero_warn
@ -23,7 +23,10 @@ def _confusion_matrix_update(preds: torch.Tensor,
target: torch.Tensor,
num_classes: int,
threshold: float = 0.5) -> torch.Tensor:
preds, target = _input_format_classification(preds, target, threshold)
preds, target, mode = _input_format_classification(preds, target, threshold)
if mode not in ('binary', 'multi-label'):
preds = preds.argmax(dim=1)
target = target.argmax(dim=1)
unique_mapping = (target.view(-1) * num_classes + preds.view(-1)).to(torch.long)
bins = torch.bincount(unique_mapping, minlength=num_classes ** 2)
confmat = bins.reshape(num_classes, num_classes)

Просмотреть файл

@ -43,34 +43,6 @@ def _check_same_shape(pred: torch.Tensor, target: torch.Tensor):
raise RuntimeError("Predictions and targets are expected to have the same shape")
def _input_format_classification(
preds: torch.Tensor, target: torch.Tensor, threshold: float = 0.5
) -> Tuple[torch.Tensor, torch.Tensor]:
"""Convert preds and target tensors into label tensors
Args:
preds: either tensor with labels, tensor with probabilities/logits or
multilabel tensor
target: tensor with ground true labels
threshold: float used for thresholding multilabel input
Returns:
preds: tensor with labels
target: tensor with labels
"""
if not (preds.ndim == target.ndim or preds.ndim == target.ndim + 1):
raise ValueError("preds and target must have same number of dimensions, or one additional dimension for preds")
if preds.ndim == target.ndim + 1:
# multi class probabilites
preds = torch.argmax(preds, dim=1)
if preds.ndim == target.ndim and preds.is_floating_point():
# binary or multilabel probablities
preds = (preds >= threshold).long()
return preds, target
def _input_format_classification_one_hot(
num_classes: int, preds: torch.Tensor, target: torch.Tensor, threshold: float = 0.5, multilabel: bool = False
) -> Tuple[torch.Tensor, torch.Tensor]: