drop duplicated metric helper (PL^5366)
* drop duplicated metric helper * . * fix tests Co-authored-by: Nicki Skafte <skaftenicki@gmail.com> (cherry picked from commit 220dfaf7dcd12ad230ea2938d8d67a119a22e143)
This commit is contained in:
Родитель
261177efa2
Коммит
1cccdc73d9
|
@ -15,7 +15,7 @@ from typing import Optional
|
|||
|
||||
import torch
|
||||
|
||||
from pytorch_lightning.metrics.utils import _input_format_classification
|
||||
from pytorch_lightning.metrics.classification.helpers import _input_format_classification
|
||||
from pytorch_lightning.utilities import rank_zero_warn
|
||||
|
||||
|
||||
|
@ -23,7 +23,10 @@ def _confusion_matrix_update(preds: torch.Tensor,
|
|||
target: torch.Tensor,
|
||||
num_classes: int,
|
||||
threshold: float = 0.5) -> torch.Tensor:
|
||||
preds, target = _input_format_classification(preds, target, threshold)
|
||||
preds, target, mode = _input_format_classification(preds, target, threshold)
|
||||
if mode not in ('binary', 'multi-label'):
|
||||
preds = preds.argmax(dim=1)
|
||||
target = target.argmax(dim=1)
|
||||
unique_mapping = (target.view(-1) * num_classes + preds.view(-1)).to(torch.long)
|
||||
bins = torch.bincount(unique_mapping, minlength=num_classes ** 2)
|
||||
confmat = bins.reshape(num_classes, num_classes)
|
||||
|
|
|
@ -43,34 +43,6 @@ def _check_same_shape(pred: torch.Tensor, target: torch.Tensor):
|
|||
raise RuntimeError("Predictions and targets are expected to have the same shape")
|
||||
|
||||
|
||||
def _input_format_classification(
|
||||
preds: torch.Tensor, target: torch.Tensor, threshold: float = 0.5
|
||||
) -> Tuple[torch.Tensor, torch.Tensor]:
|
||||
"""Convert preds and target tensors into label tensors
|
||||
|
||||
Args:
|
||||
preds: either tensor with labels, tensor with probabilities/logits or
|
||||
multilabel tensor
|
||||
target: tensor with ground true labels
|
||||
threshold: float used for thresholding multilabel input
|
||||
|
||||
Returns:
|
||||
preds: tensor with labels
|
||||
target: tensor with labels
|
||||
"""
|
||||
if not (preds.ndim == target.ndim or preds.ndim == target.ndim + 1):
|
||||
raise ValueError("preds and target must have same number of dimensions, or one additional dimension for preds")
|
||||
|
||||
if preds.ndim == target.ndim + 1:
|
||||
# multi class probabilites
|
||||
preds = torch.argmax(preds, dim=1)
|
||||
|
||||
if preds.ndim == target.ndim and preds.is_floating_point():
|
||||
# binary or multilabel probablities
|
||||
preds = (preds >= threshold).long()
|
||||
return preds, target
|
||||
|
||||
|
||||
def _input_format_classification_one_hot(
|
||||
num_classes: int, preds: torch.Tensor, target: torch.Tensor, threshold: float = 0.5, multilabel: bool = False
|
||||
) -> Tuple[torch.Tensor, torch.Tensor]:
|
||||
|
|
Загрузка…
Ссылка в новой задаче