Multilabel support in ROC (#114)
* multilabel_roc_supp * formatting Co-authored-by: Jirka Borovec <Borda@users.noreply.github.com>
This commit is contained in:
Родитель
fc6c8ef124
Коммит
2af13fb1b3
|
@ -9,7 +9,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
|
|||
|
||||
### Added
|
||||
|
||||
- Added prefix arg to metric collection ([#70](https://github.com/PyTorchLightning/metrics/pull/70))
|
||||
- Added `prefix` argument to `MetricCollection` ([#70](https://github.com/PyTorchLightning/metrics/pull/70))
|
||||
|
||||
|
||||
- Added `CohenKappa` metric ([#69](https://github.com/PyTorchLightning/metrics/pull/69))
|
||||
|
@ -18,12 +18,14 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
|
|||
- Added `RetrievalMAP` metric for Information Retrieval ([#5032](https://github.com/PyTorchLightning/pytorch-lightning/pull/5032))
|
||||
|
||||
|
||||
- Added `average='micro'` as an option in auroc for multilabel problems ([#110](https://github.com/PyTorchLightning/metrics/pull/110))
|
||||
- Added `average='micro'` as an option in AUROC for multilabel problems ([#110](https://github.com/PyTorchLightning/metrics/pull/110))
|
||||
|
||||
|
||||
- Added `MatthewsCorrcoef` metric ([#98](https://github.com/PyTorchLightning/metrics/pull/98))
|
||||
|
||||
|
||||
- Added multilabel support to `ROC` metric ([#114](https://github.com/PyTorchLightning/metrics/pull/114))
|
||||
|
||||
### Changed
|
||||
|
||||
- Changed `ExplainedVariance` from storing all preds/targets to tracking 5 statistics ([#68](https://github.com/PyTorchLightning/metrics/pull/68))
|
||||
|
|
|
@ -22,6 +22,7 @@ from torch import tensor
|
|||
from tests.classification.inputs import _input_binary_prob
|
||||
from tests.classification.inputs import _input_multiclass_prob as _input_mcls_prob
|
||||
from tests.classification.inputs import _input_multidim_multiclass_prob as _input_mdmc_prob
|
||||
from tests.classification.inputs import _input_multilabel_multidim_prob, _input_multilabel_prob
|
||||
from tests.helpers.testers import NUM_CLASSES, MetricTester
|
||||
from torchmetrics.classification.roc import ROC
|
||||
from torchmetrics.functional import roc
|
||||
|
@ -29,15 +30,19 @@ from torchmetrics.functional import roc
|
|||
torch.manual_seed(42)
|
||||
|
||||
|
||||
def _sk_roc_curve(y_true, probas_pred, num_classes=1):
|
||||
def _sk_roc_curve(y_true, probas_pred, num_classes: int = 1, multilabel: bool = False):
|
||||
""" Adjusted comparison function that can also handles multiclass """
|
||||
if num_classes == 1:
|
||||
return sk_roc_curve(y_true, probas_pred, drop_intermediate=False)
|
||||
|
||||
fpr, tpr, thresholds = [], [], []
|
||||
for i in range(num_classes):
|
||||
y_true_temp = np.zeros_like(y_true)
|
||||
y_true_temp[y_true == i] = 1
|
||||
if multilabel:
|
||||
y_true_temp = y_true[:, i]
|
||||
else:
|
||||
y_true_temp = np.zeros_like(y_true)
|
||||
y_true_temp[y_true == i] = 1
|
||||
|
||||
res = sk_roc_curve(y_true_temp, probas_pred[:, i], drop_intermediate=False)
|
||||
fpr.append(res[0])
|
||||
tpr.append(res[1])
|
||||
|
@ -65,11 +70,40 @@ def _sk_roc_multidim_multiclass_prob(preds, target, num_classes=1):
|
|||
return _sk_roc_curve(y_true=sk_target, probas_pred=sk_preds, num_classes=num_classes)
|
||||
|
||||
|
||||
def _sk_roc_multilabel_prob(preds, target, num_classes=1):
|
||||
sk_preds = preds.numpy()
|
||||
sk_target = target.numpy()
|
||||
return _sk_roc_curve(
|
||||
y_true=sk_target,
|
||||
probas_pred=sk_preds,
|
||||
num_classes=num_classes,
|
||||
multilabel=True
|
||||
)
|
||||
|
||||
|
||||
def _sk_roc_multilabel_multidim_prob(preds, target, num_classes=1):
|
||||
sk_preds = preds.transpose(0, 1).reshape(num_classes, -1).transpose(0, 1).numpy()
|
||||
sk_target = target.transpose(0, 1).reshape(num_classes, -1).transpose(0, 1).numpy()
|
||||
return _sk_roc_curve(
|
||||
y_true=sk_target,
|
||||
probas_pred=sk_preds,
|
||||
num_classes=num_classes,
|
||||
multilabel=True
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"preds, target, sk_metric, num_classes", [
|
||||
(_input_binary_prob.preds, _input_binary_prob.target, _sk_roc_binary_prob, 1),
|
||||
(_input_mcls_prob.preds, _input_mcls_prob.target, _sk_roc_multiclass_prob, NUM_CLASSES),
|
||||
(_input_mdmc_prob.preds, _input_mdmc_prob.target, _sk_roc_multidim_multiclass_prob, NUM_CLASSES),
|
||||
(_input_multilabel_prob.preds, _input_multilabel_prob.target, _sk_roc_multilabel_prob, NUM_CLASSES),
|
||||
(
|
||||
_input_multilabel_multidim_prob.preds,
|
||||
_input_multilabel_multidim_prob.target,
|
||||
_sk_roc_multilabel_multidim_prob,
|
||||
NUM_CLASSES
|
||||
)
|
||||
]
|
||||
)
|
||||
class TestROC(MetricTester):
|
||||
|
|
|
@ -11,7 +11,7 @@
|
|||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
from typing import Any, List, Optional, Tuple, Union
|
||||
from typing import Any, Callable, List, Optional, Tuple, Union
|
||||
|
||||
import torch
|
||||
from torch import Tensor
|
||||
|
@ -24,13 +24,13 @@ from torchmetrics.utilities import rank_zero_warn
|
|||
class ROC(Metric):
|
||||
"""
|
||||
Computes the Receiver Operating Characteristic (ROC). Works for both
|
||||
binary and multiclass problems. In the case of multiclass, the values will
|
||||
binary, multiclass and multilabel problems. In the case of multiclass, the values will
|
||||
be calculated based on a one-vs-the-rest approach.
|
||||
|
||||
Forward accepts
|
||||
|
||||
- ``preds`` (float tensor): ``(N, ...)`` (binary) or ``(N, C, ...)`` (multiclass) tensor
|
||||
with probabilities, where C is the number of classes.
|
||||
- ``preds`` (float tensor): ``(N, ...)`` (binary) or ``(N, C, ...)`` (multiclass/multilabel) tensor
|
||||
with probabilities, where C is the number of classes/labels.
|
||||
|
||||
- ``target`` (long tensor): ``(N, ...)`` or ``(N, C, ...)`` with integer labels
|
||||
|
||||
|
@ -48,9 +48,12 @@ class ROC(Metric):
|
|||
before returning the value at the step. default: False
|
||||
process_group:
|
||||
Specify the process group on which synchronization is called. default: None (which selects the entire world)
|
||||
dist_sync_fn:
|
||||
Callback that performs the allgather operation on the metric state. When ``None``, DDP
|
||||
will be used to perform the allgather
|
||||
|
||||
Example (binary case):
|
||||
|
||||
Example:
|
||||
>>> # binary case
|
||||
>>> from torchmetrics import ROC
|
||||
>>> pred = torch.tensor([0, 1, 2, 3])
|
||||
>>> target = torch.tensor([0, 1, 1, 1])
|
||||
|
@ -63,7 +66,9 @@ class ROC(Metric):
|
|||
>>> thresholds
|
||||
tensor([4, 3, 2, 1, 0])
|
||||
|
||||
>>> # multiclass case
|
||||
Example (multiclass case):
|
||||
|
||||
>>> from torchmetrics import ROC
|
||||
>>> pred = torch.tensor([[0.75, 0.05, 0.05, 0.05],
|
||||
... [0.05, 0.75, 0.05, 0.05],
|
||||
... [0.05, 0.05, 0.75, 0.05],
|
||||
|
@ -81,8 +86,30 @@ class ROC(Metric):
|
|||
tensor([1.7500, 0.7500, 0.0500]),
|
||||
tensor([1.7500, 0.7500, 0.0500])]
|
||||
|
||||
"""
|
||||
Example (multilabel case):
|
||||
|
||||
>>> from torchmetrics import ROC
|
||||
>>> pred = torch.tensor([[0.8191, 0.3680, 0.1138],
|
||||
... [0.3584, 0.7576, 0.1183],
|
||||
... [0.2286, 0.3468, 0.1338],
|
||||
... [0.8603, 0.0745, 0.1837]])
|
||||
>>> target = torch.tensor([[1, 1, 0], [0, 1, 0], [0, 0, 0], [0, 1, 1]])
|
||||
>>> roc = ROC(num_classes=3, pos_label=1)
|
||||
>>> fpr, tpr, thresholds = roc(pred, target)
|
||||
>>> fpr # doctest: +NORMALIZE_WHITESPACE
|
||||
[tensor([0.0000, 0.3333, 0.3333, 0.6667, 1.0000]),
|
||||
tensor([0., 0., 0., 1., 1.]),
|
||||
tensor([0.0000, 0.0000, 0.3333, 0.6667, 1.0000])]
|
||||
>>> tpr # doctest: +NORMALIZE_WHITESPACE
|
||||
[tensor([0., 0., 1., 1., 1.]),
|
||||
tensor([0.0000, 0.3333, 0.6667, 0.6667, 1.0000]),
|
||||
tensor([0., 1., 1., 1., 1.])]
|
||||
>>> thresholds # doctest: +NORMALIZE_WHITESPACE
|
||||
[tensor([1.8603, 0.8603, 0.8191, 0.3584, 0.2286]),
|
||||
tensor([1.7576, 0.7576, 0.3680, 0.3468, 0.0745]),
|
||||
tensor([1.1837, 0.1837, 0.1338, 0.1183, 0.1138])]
|
||||
|
||||
"""
|
||||
def __init__(
|
||||
self,
|
||||
num_classes: Optional[int] = None,
|
||||
|
@ -90,11 +117,13 @@ class ROC(Metric):
|
|||
compute_on_step: bool = True,
|
||||
dist_sync_on_step: bool = False,
|
||||
process_group: Optional[Any] = None,
|
||||
dist_sync_fn: Callable = None,
|
||||
):
|
||||
super().__init__(
|
||||
compute_on_step=compute_on_step,
|
||||
dist_sync_on_step=dist_sync_on_step,
|
||||
process_group=process_group,
|
||||
dist_sync_fn=dist_sync_fn,
|
||||
)
|
||||
|
||||
self.num_classes = num_classes
|
||||
|
|
|
@ -75,7 +75,7 @@ def _auroc_compute(
|
|||
# calculate fpr, tpr
|
||||
if mode == 'multi-label':
|
||||
if average == AverageMethod.MICRO:
|
||||
fpr, tpr, _ = roc(preds.flatten(), target.flatten(), num_classes, pos_label, sample_weights)
|
||||
fpr, tpr, _ = roc(preds.flatten(), target.flatten(), 1, pos_label, sample_weights)
|
||||
else:
|
||||
# for multilabel we iteratively evaluate roc in a binary fashion
|
||||
output = [
|
||||
|
|
|
@ -71,16 +71,28 @@ def _precision_recall_curve_update(
|
|||
) -> Tuple[Tensor, Tensor, int, int]:
|
||||
if not (len(preds.shape) == len(target.shape) or len(preds.shape) == len(target.shape) + 1):
|
||||
raise ValueError("preds and target must have same number of dimensions, or one additional dimension for preds")
|
||||
# single class evaluation
|
||||
|
||||
if len(preds.shape) == len(target.shape):
|
||||
num_classes = 1
|
||||
if pos_label is None:
|
||||
rank_zero_warn('`pos_label` automatically set 1.')
|
||||
pos_label = 1
|
||||
preds = preds.flatten()
|
||||
target = target.flatten()
|
||||
if num_classes is not None and num_classes != 1:
|
||||
# multilabel problem
|
||||
if num_classes != preds.shape[1]:
|
||||
raise ValueError(
|
||||
f'Argument `num_classes` was set to {num_classes} in'
|
||||
f' metric `precision_recall_curve` but detected {preds.shape[1]}'
|
||||
' number of classes from predictions'
|
||||
)
|
||||
preds = preds.transpose(0, 1).reshape(num_classes, -1).transpose(0, 1)
|
||||
target = target.transpose(0, 1).reshape(num_classes, -1).transpose(0, 1)
|
||||
else:
|
||||
# binary problem
|
||||
preds = preds.flatten()
|
||||
target = target.flatten()
|
||||
num_classes = 1
|
||||
|
||||
# multi class evaluation
|
||||
# multi class problem
|
||||
if len(preds.shape) == len(target.shape) + 1:
|
||||
if pos_label is not None:
|
||||
rank_zero_warn(
|
||||
|
|
|
@ -27,8 +27,9 @@ def _roc_update(
|
|||
target: Tensor,
|
||||
num_classes: Optional[int] = None,
|
||||
pos_label: Optional[int] = None,
|
||||
) -> Tuple[Tensor, Tensor, int, int]:
|
||||
return _precision_recall_curve_update(preds, target, num_classes, pos_label)
|
||||
) -> Tuple[Tensor, Tensor, int, int, str]:
|
||||
preds, target, num_classes, pos_label = _precision_recall_curve_update(preds, target, num_classes, pos_label)
|
||||
return preds, target, num_classes, pos_label
|
||||
|
||||
|
||||
def _roc_compute(
|
||||
|
@ -39,7 +40,7 @@ def _roc_compute(
|
|||
sample_weights: Optional[Sequence] = None,
|
||||
) -> Union[Tuple[Tensor, Tensor, Tensor], Tuple[List[Tensor], List[Tensor], List[Tensor]]]:
|
||||
|
||||
if num_classes == 1:
|
||||
if num_classes == 1 and preds.ndim == 1: # binary
|
||||
fps, tps, thresholds = _binary_clf_curve(
|
||||
preds=preds, target=target, sample_weights=sample_weights, pos_label=pos_label
|
||||
)
|
||||
|
@ -62,12 +63,19 @@ def _roc_compute(
|
|||
# Recursively call per class
|
||||
fpr, tpr, thresholds = [], [], []
|
||||
for c in range(num_classes):
|
||||
preds_c = preds[:, c]
|
||||
if preds.shape == target.shape:
|
||||
preds_c = preds[:, c]
|
||||
target_c = target[:, c]
|
||||
pos_label = 1
|
||||
else:
|
||||
preds_c = preds[:, c]
|
||||
target_c = target
|
||||
pos_label = c
|
||||
res = roc(
|
||||
preds=preds_c,
|
||||
target=target,
|
||||
target=target_c,
|
||||
num_classes=1,
|
||||
pos_label=c,
|
||||
pos_label=pos_label,
|
||||
sample_weights=sample_weights,
|
||||
)
|
||||
fpr.append(res[0])
|
||||
|
@ -86,6 +94,7 @@ def roc(
|
|||
) -> Union[Tuple[Tensor, Tensor, Tensor], Tuple[List[Tensor], List[Tensor], List[Tensor]]]:
|
||||
"""
|
||||
Computes the Receiver Operating Characteristic (ROC).
|
||||
Works with both binary, multiclass and multilabel input.
|
||||
|
||||
Args:
|
||||
preds: predictions from model (logits or probabilities)
|
||||
|
@ -103,15 +112,16 @@ def roc(
|
|||
|
||||
fpr:
|
||||
tensor with false positive rates.
|
||||
If multiclass, this is a list of such tensors, one for each class.
|
||||
If multiclass or multilabel, this is a list of such tensors, one for each class/label.
|
||||
tpr:
|
||||
tensor with true positive rates.
|
||||
If multiclass, this is a list of such tensors, one for each class.
|
||||
If multiclass or multilabel, this is a list of such tensors, one for each class/label.
|
||||
thresholds:
|
||||
thresholds used for computing false- and true postive rates
|
||||
tensor with thresholds used for computing false- and true postive rates
|
||||
If multiclass or multilabel, this is a list of such tensors, one for each class/label.
|
||||
|
||||
Example (binary case):
|
||||
|
||||
Example:
|
||||
>>> # binary case
|
||||
>>> from torchmetrics.functional import roc
|
||||
>>> pred = torch.tensor([0, 1, 2, 3])
|
||||
>>> target = torch.tensor([0, 1, 1, 1])
|
||||
|
@ -123,7 +133,9 @@ def roc(
|
|||
>>> thresholds
|
||||
tensor([4, 3, 2, 1, 0])
|
||||
|
||||
>>> # multiclass case
|
||||
Example (multiclass case):
|
||||
|
||||
>>> from torchmetrics.functional import roc
|
||||
>>> pred = torch.tensor([[0.75, 0.05, 0.05, 0.05],
|
||||
... [0.05, 0.75, 0.05, 0.05],
|
||||
... [0.05, 0.05, 0.75, 0.05],
|
||||
|
@ -139,6 +151,27 @@ def roc(
|
|||
tensor([1.7500, 0.7500, 0.0500]),
|
||||
tensor([1.7500, 0.7500, 0.0500]),
|
||||
tensor([1.7500, 0.7500, 0.0500])]
|
||||
|
||||
Example (multilabel case):
|
||||
|
||||
>>> from torchmetrics.functional import roc
|
||||
>>> pred = torch.tensor([[0.8191, 0.3680, 0.1138],
|
||||
... [0.3584, 0.7576, 0.1183],
|
||||
... [0.2286, 0.3468, 0.1338],
|
||||
... [0.8603, 0.0745, 0.1837]])
|
||||
>>> target = torch.tensor([[1, 1, 0], [0, 1, 0], [0, 0, 0], [0, 1, 1]])
|
||||
>>> fpr, tpr, thresholds = roc(pred, target, num_classes=3, pos_label=1)
|
||||
>>> fpr # doctest: +NORMALIZE_WHITESPACE
|
||||
[tensor([0.0000, 0.3333, 0.3333, 0.6667, 1.0000]),
|
||||
tensor([0., 0., 0., 1., 1.]),
|
||||
tensor([0.0000, 0.0000, 0.3333, 0.6667, 1.0000])]
|
||||
>>> tpr
|
||||
[tensor([0., 0., 1., 1., 1.]), tensor([0.0000, 0.3333, 0.6667, 0.6667, 1.0000]), tensor([0., 1., 1., 1., 1.])]
|
||||
>>> thresholds # doctest: +NORMALIZE_WHITESPACE
|
||||
[tensor([1.8603, 0.8603, 0.8191, 0.3584, 0.2286]),
|
||||
tensor([1.7576, 0.7576, 0.3680, 0.3468, 0.0745]),
|
||||
tensor([1.1837, 0.1837, 0.1338, 0.1183, 0.1138])]
|
||||
|
||||
"""
|
||||
preds, target, num_classes, pos_label = _roc_update(preds, target, num_classes, pos_label)
|
||||
return _roc_compute(preds, target, num_classes, pos_label, sample_weights)
|
||||
|
|
Загрузка…
Ссылка в новой задаче