Implement partial auroc metric (PL #3790)
* Implement partial auroc metric
* Add pycodestyle changes
* Added tests for max_fpr
* changelog
* version for tests
* fix imports
* fix tests
* fix tests
* Added more thresholds in (0,1] to test max_fpr
* Removed deprecated 'reorder' param from auroc
* changelog
* Apply suggestions from code review (Co-authored-by: Rohit Gupta <rohitgr1998@gmail.com>)
* remove old structure
* Apply suggestions from code review (Co-authored-by: Jirka Borovec <Borda@users.noreply.github.com>)
* fix test error

Co-authored-by: Nicki Skafte <nugginea@gmail.com>
Co-authored-by: Nicki Skafte <skaftenicki@gmail.com>
Co-authored-by: Jirka Borovec <Borda@users.noreply.github.com>
Co-authored-by: Rohit Gupta <rohitgr1998@gmail.com>

(cherry picked from commit 36ea27be9d02ed4a702f0d5711808fb51bf13976)
Parent: 3ac52f16b5
Commit: 527fa67a7b
@@ -15,6 +15,7 @@ from functools import wraps
 from typing import Callable, Optional, Sequence, Tuple
 
 import torch
+from distutils.version import LooseVersion
 
 from pytorch_lightning.metrics.functional.average_precision import average_precision as __ap
 from pytorch_lightning.metrics.functional.f_beta import fbeta as __fb, f1 as __f1
@@ -544,6 +545,7 @@ def auroc(
         target: torch.Tensor,
         sample_weight: Optional[Sequence] = None,
         pos_label: int = 1.,
+        max_fpr: float = None,
 ) -> torch.Tensor:
     """
     Compute Area Under the Receiver Operating Characteristic Curve (ROC AUC) from prediction scores
@@ -553,6 +555,8 @@ def auroc(
         target: ground-truth labels
         sample_weight: sample weights
         pos_label: the label for the positive class
+        max_fpr: If not ``None``, calculates standardized partial AUC over the
+            range [0, max_fpr]. Should be a float between 0 and 1.
 
     Return:
         Tensor containing ROCAUC score
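For readers skimming the diff, a minimal usage sketch of the new argument follows. It is not part of the change itself; the import path is assumed to be the functional classification module edited here, and the tensors are illustrative values only.

import torch
from pytorch_lightning.metrics.functional.classification import auroc

pred = torch.tensor([0.1, 0.4, 0.35, 0.8])
target = torch.tensor([0, 0, 1, 1])

full_auc = auroc(pred, target)                    # standard ROC AUC, behaviour unchanged
partial_auc = auroc(pred, target, max_fpr=0.25)   # standardized partial AUC over FPR in [0, 0.25]

Note that passing max_fpr requires PyTorch >= 1.6, because the implementation below relies on torch.bucketize.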
@@ -569,11 +573,32 @@ def auroc(
                          ' target tensor contains value different from 0 and 1.'
                          ' Use `multiclass_auroc` for multi class classification.')
 
-    @auc_decorator()
-    def _auroc(pred, target, sample_weight, pos_label):
-        return _roc(pred, target, sample_weight, pos_label)
+    if max_fpr is None or max_fpr == 1:
+        fpr, tpr, _ = __roc(pred, target, sample_weight, pos_label)
+        return auc(fpr, tpr)
+    if not (isinstance(max_fpr, float) and 0 < max_fpr <= 1):
+        raise ValueError(f"`max_fpr` should be a float in range (0, 1], got: {max_fpr}")
+    if LooseVersion(torch.__version__) < LooseVersion('1.6.0'):
+        raise RuntimeError('`max_fpr` argument requires `torch.bucketize` which'
+                           ' is not available below PyTorch version 1.6')
 
-    return _auroc(pred=pred, target=target, sample_weight=sample_weight, pos_label=pos_label)
+    fpr, tpr, _ = __roc(pred, target, sample_weight, pos_label)
+    max_fpr = torch.tensor(max_fpr, device=fpr.device)
+    # Add a single point at max_fpr and interpolate its tpr value
+    stop = torch.bucketize(max_fpr, fpr, out_int32=True, right=True)
+    weight = (max_fpr - fpr[stop - 1]) / (fpr[stop] - fpr[stop - 1])
+    interp_tpr = torch.lerp(tpr[stop - 1], tpr[stop], weight)
+    tpr = torch.cat([tpr[:stop], interp_tpr.view(1)])
+    fpr = torch.cat([fpr[:stop], max_fpr.view(1)])
+
+    # Compute partial AUC
+    partial_auc = auc(fpr, tpr)
+
+    # McClish correction: standardize result to be 0.5 if non-discriminant
+    # and 1 if maximal
+    min_area = 0.5 * max_fpr ** 2
+    max_area = max_fpr
+    return 0.5 * (1 + (partial_auc - min_area) / (max_area - min_area))
 
 
 def multiclass_auroc(
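An aside on the two non-obvious steps above (a reader's sketch, not part of the diff): torch.bucketize(max_fpr, fpr, right=True) returns the index of the first FPR grid point strictly greater than max_fpr, and torch.lerp interpolates the TPR at exactly max_fpr so the ROC curve can be truncated there. The trapezoidal area of the truncated curve is then rescaled with the McClish correction, which maps a chance-level classifier to 0.5 and a perfect one to 1.0. A small standalone check of both steps, using assumed toy ROC points:

import torch

# Toy ROC curve points (illustrative values only)
fpr = torch.tensor([0.0, 0.2, 0.6, 1.0])
tpr = torch.tensor([0.0, 0.5, 0.9, 1.0])
max_fpr = torch.tensor(0.25)

# Index of the first fpr entry strictly greater than max_fpr -> 2
stop = torch.bucketize(max_fpr, fpr, out_int32=True, right=True)
# Fractional position of max_fpr between fpr[1] = 0.2 and fpr[2] = 0.6 -> 0.125
weight = (max_fpr - fpr[stop - 1]) / (fpr[stop] - fpr[stop - 1])
# TPR interpolated at max_fpr: 0.5 + 0.125 * (0.9 - 0.5) = 0.55
interp_tpr = torch.lerp(tpr[stop - 1], tpr[stop], weight)

# McClish rescaling: min_area is the area under the chance diagonal on
# [0, max_fpr], max_area the area under a perfect curve on the same range.
min_area = 0.5 * 0.25 ** 2
max_area = 0.25

def mcclish(partial_auc):
    return 0.5 * (1 + (partial_auc - min_area) / (max_area - min_area))

assert mcclish(min_area) == 0.5   # non-discriminant classifier -> 0.5
assert mcclish(max_area) == 1.0   # perfect classifier -> 1.0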
@@ -2,6 +2,7 @@ from functools import partial
 
 import pytest
 import torch
+from distutils.version import LooseVersion
 from sklearn.metrics import (
     jaccard_score as sk_jaccard_score,
     precision_score as sk_precision,
@@ -197,18 +198,41 @@ def test_binary_clf_curve(sample_weight, pos_label, exp_shape):
     assert thresh.shape == (exp_shape,)
 
 
-@pytest.mark.parametrize(['pred', 'target', 'expected'], [
-    pytest.param([0, 1, 0, 1], [0, 1, 0, 1], 1.),
-    pytest.param([1, 1, 0, 0], [0, 0, 1, 1], 0.),
-    pytest.param([1, 1, 1, 1], [1, 1, 0, 0], 0.5),
-    pytest.param([1, 1, 0, 0], [1, 1, 0, 0], 1.),
-    pytest.param([0.5, 0.5, 0.5, 0.5], [1, 1, 0, 0], 0.5),
+@pytest.mark.parametrize(['pred', 'target', 'max_fpr', 'expected'], [
+    pytest.param([0, 1, 0, 1], [0, 1, 0, 1], None, 1.),
+    pytest.param([1, 1, 0, 0], [0, 0, 1, 1], None, 0.),
+    pytest.param([1, 1, 1, 1], [1, 1, 0, 0], 0.8, 0.5),
+    pytest.param([0.5, 0.5, 0.5, 0.5], [1, 1, 0, 0], 0.2, 0.5),
+    pytest.param([1, 1, 0, 0], [1, 1, 0, 0], 0.5, 1.),
 ])
-def test_auroc(pred, target, expected):
-    score = auroc(torch.tensor(pred), torch.tensor(target)).item()
+def test_auroc(pred, target, max_fpr, expected):
+    if max_fpr is not None and LooseVersion(torch.__version__) < LooseVersion('1.6.0'):
+        pytest.skip('requires torch v1.6 or higher to test max_fpr argument')
+
+    score = auroc(torch.tensor(pred), torch.tensor(target), max_fpr=max_fpr).item()
     assert score == expected
 
 
+@pytest.mark.skipif(LooseVersion(torch.__version__) < LooseVersion('1.6.0'),
+                    reason='requires torch v1.6 or higher to test max_fpr argument')
+@pytest.mark.parametrize('max_fpr', [
+    None, 1, 0.99, 0.9, 0.75, 0.5, 0.25, 0.1, 0.01, 0.001,
+])
+def test_auroc_with_max_fpr_against_sklearn(max_fpr):
+    device = 'cuda' if torch.cuda.is_available() else 'cpu'
+
+    pred = torch.rand((300,), device=device)
+    # Supports only binary classification
+    target = torch.randint(2, (300,), dtype=torch.float64, device=device)
+    sk_score = sk_roc_auc_score(target.cpu().detach().numpy(),
+                                pred.cpu().detach().numpy(),
+                                max_fpr=max_fpr)
+    pl_score = auroc(pred, target, max_fpr=max_fpr)
+
+    sk_score = torch.tensor(sk_score, dtype=torch.float, device=device)
+    assert torch.allclose(sk_score, pl_score)
+
+
 def test_multiclass_auroc():
     with pytest.raises(ValueError,
                        match=r".*probabilities, i.e. they should sum up to 1.0 over classes"):