* multilabel_roc_supp

* formatting

Co-authored-by: Jirka Borovec <Borda@users.noreply.github.com>
Nicki Skafte 2021-03-23 21:41:31 +01:00 committed by GitHub
Parent fc6c8ef124
Commit 2af13fb1b3
No key found matching this signature
GPG key ID: 4AEE18F83AFDEB23
6 changed files: 141 additions and 31 deletions

View file

@@ -9,7 +9,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
### Added
- Added prefix arg to metric collection ([#70](https://github.com/PyTorchLightning/metrics/pull/70))
- Added `prefix` argument to `MetricCollection` ([#70](https://github.com/PyTorchLightning/metrics/pull/70))
- Added `CohenKappa` metric ([#69](https://github.com/PyTorchLightning/metrics/pull/69))
@@ -18,12 +18,14 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
- Added `RetrievalMAP` metric for Information Retrieval ([#5032](https://github.com/PyTorchLightning/pytorch-lightning/pull/5032))
- Added `average='micro'` as an option in auroc for multilabel problems ([#110](https://github.com/PyTorchLightning/metrics/pull/110))
- Added `average='micro'` as an option in AUROC for multilabel problems ([#110](https://github.com/PyTorchLightning/metrics/pull/110))
- Added `MatthewsCorrcoef` metric ([#98](https://github.com/PyTorchLightning/metrics/pull/98))
- Added multilabel support to `ROC` metric ([#114](https://github.com/PyTorchLightning/metrics/pull/114))
### Changed
- Changed `ExplainedVariance` from storing all preds/targets to tracking 5 statistics ([#68](https://github.com/PyTorchLightning/metrics/pull/68))

View file

@@ -22,6 +22,7 @@ from torch import tensor
from tests.classification.inputs import _input_binary_prob
from tests.classification.inputs import _input_multiclass_prob as _input_mcls_prob
from tests.classification.inputs import _input_multidim_multiclass_prob as _input_mdmc_prob
from tests.classification.inputs import _input_multilabel_multidim_prob, _input_multilabel_prob
from tests.helpers.testers import NUM_CLASSES, MetricTester
from torchmetrics.classification.roc import ROC
from torchmetrics.functional import roc
@@ -29,15 +30,19 @@ from torchmetrics.functional import roc
torch.manual_seed(42)
def _sk_roc_curve(y_true, probas_pred, num_classes=1):
def _sk_roc_curve(y_true, probas_pred, num_classes: int = 1, multilabel: bool = False):
""" Adjusted comparison function that can also handles multiclass """
if num_classes == 1:
return sk_roc_curve(y_true, probas_pred, drop_intermediate=False)
fpr, tpr, thresholds = [], [], []
for i in range(num_classes):
y_true_temp = np.zeros_like(y_true)
y_true_temp[y_true == i] = 1
if multilabel:
y_true_temp = y_true[:, i]
else:
y_true_temp = np.zeros_like(y_true)
y_true_temp[y_true == i] = 1
res = sk_roc_curve(y_true_temp, probas_pred[:, i], drop_intermediate=False)
fpr.append(res[0])
tpr.append(res[1])
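The difference from the one-vs-rest multiclass path is that multilabel targets already come as one binary column per label, so the helper indexes them directly instead of binarizing class indices. A minimal sketch of the two target-selection paths (illustrative arrays, not from the test suite):
>>> import numpy as np
>>> y_true_ml = np.array([[1, 0], [0, 1], [1, 1]])  # multilabel: one binary column per label
>>> y_true_ml[:, 0]  # used as-is for label 0
array([1, 0, 1])
>>> y_true_mc = np.array([0, 1, 0])  # multiclass: class indices, binarized one-vs-rest
>>> (y_true_mc == 1).astype(int)
array([0, 1, 0])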
@@ -65,11 +70,40 @@ def _sk_roc_multidim_multiclass_prob(preds, target, num_classes=1):
return _sk_roc_curve(y_true=sk_target, probas_pred=sk_preds, num_classes=num_classes)
def _sk_roc_multilabel_prob(preds, target, num_classes=1):
sk_preds = preds.numpy()
sk_target = target.numpy()
return _sk_roc_curve(
y_true=sk_target,
probas_pred=sk_preds,
num_classes=num_classes,
multilabel=True
)
def _sk_roc_multilabel_multidim_prob(preds, target, num_classes=1):
sk_preds = preds.transpose(0, 1).reshape(num_classes, -1).transpose(0, 1).numpy()
sk_target = target.transpose(0, 1).reshape(num_classes, -1).transpose(0, 1).numpy()
return _sk_roc_curve(
y_true=sk_target,
probas_pred=sk_preds,
num_classes=num_classes,
multilabel=True
)
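For the multidim case the helper first folds any extra dimensions into the sample axis while keeping one column per label; a quick shape check of that transpose/reshape pattern (hypothetical tensor with 3 labels):
>>> import torch
>>> x = torch.arange(24).reshape(2, 3, 4)  # (N=2, C=3 labels, extra dim d=4)
>>> x.transpose(0, 1).reshape(3, -1).transpose(0, 1).shape  # -> (N*d, C)
torch.Size([8, 3])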
@pytest.mark.parametrize(
"preds, target, sk_metric, num_classes", [
(_input_binary_prob.preds, _input_binary_prob.target, _sk_roc_binary_prob, 1),
(_input_mcls_prob.preds, _input_mcls_prob.target, _sk_roc_multiclass_prob, NUM_CLASSES),
(_input_mdmc_prob.preds, _input_mdmc_prob.target, _sk_roc_multidim_multiclass_prob, NUM_CLASSES),
(_input_multilabel_prob.preds, _input_multilabel_prob.target, _sk_roc_multilabel_prob, NUM_CLASSES),
(
_input_multilabel_multidim_prob.preds,
_input_multilabel_multidim_prob.target,
_sk_roc_multilabel_multidim_prob,
NUM_CLASSES
)
]
)
class TestROC(MetricTester):

View file

@@ -11,7 +11,7 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Any, List, Optional, Tuple, Union
from typing import Any, Callable, List, Optional, Tuple, Union
import torch
from torch import Tensor
@@ -24,13 +24,13 @@ from torchmetrics.utilities import rank_zero_warn
class ROC(Metric):
"""
Computes the Receiver Operating Characteristic (ROC). Works for
binary and multiclass problems. In the case of multiclass, the values will
binary, multiclass, and multilabel problems. In the case of multiclass, the values will
be calculated based on a one-vs-the-rest approach.
Forward accepts
- ``preds`` (float tensor): ``(N, ...)`` (binary) or ``(N, C, ...)`` (multiclass) tensor
with probabilities, where C is the number of classes.
- ``preds`` (float tensor): ``(N, ...)`` (binary) or ``(N, C, ...)`` (multiclass/multilabel) tensor
with probabilities, where C is the number of classes/labels.
- ``target`` (long tensor): ``(N, ...)`` or ``(N, C, ...)`` with integer labels
@@ -48,9 +48,12 @@ class ROC(Metric):
before returning the value at the step. default: False
process_group:
Specify the process group on which synchronization is called. default: None (which selects the entire world)
dist_sync_fn:
Callback that performs the allgather operation on the metric state. When ``None``, DDP
will be used to perform the allgather
Example (binary case):
Example:
>>> # binary case
>>> from torchmetrics import ROC
>>> pred = torch.tensor([0, 1, 2, 3])
>>> target = torch.tensor([0, 1, 1, 1])
@@ -63,7 +66,9 @@ class ROC(Metric):
>>> thresholds
tensor([4, 3, 2, 1, 0])
>>> # multiclass case
Example (multiclass case):
>>> from torchmetrics import ROC
>>> pred = torch.tensor([[0.75, 0.05, 0.05, 0.05],
... [0.05, 0.75, 0.05, 0.05],
... [0.05, 0.05, 0.75, 0.05],
@@ -81,8 +86,30 @@
tensor([1.7500, 0.7500, 0.0500]),
tensor([1.7500, 0.7500, 0.0500])]
"""
Example (multilabel case):
>>> from torchmetrics import ROC
>>> pred = torch.tensor([[0.8191, 0.3680, 0.1138],
... [0.3584, 0.7576, 0.1183],
... [0.2286, 0.3468, 0.1338],
... [0.8603, 0.0745, 0.1837]])
>>> target = torch.tensor([[1, 1, 0], [0, 1, 0], [0, 0, 0], [0, 1, 1]])
>>> roc = ROC(num_classes=3, pos_label=1)
>>> fpr, tpr, thresholds = roc(pred, target)
>>> fpr # doctest: +NORMALIZE_WHITESPACE
[tensor([0.0000, 0.3333, 0.3333, 0.6667, 1.0000]),
tensor([0., 0., 0., 1., 1.]),
tensor([0.0000, 0.0000, 0.3333, 0.6667, 1.0000])]
>>> tpr # doctest: +NORMALIZE_WHITESPACE
[tensor([0., 0., 1., 1., 1.]),
tensor([0.0000, 0.3333, 0.6667, 0.6667, 1.0000]),
tensor([0., 1., 1., 1., 1.])]
>>> thresholds # doctest: +NORMALIZE_WHITESPACE
[tensor([1.8603, 0.8603, 0.8191, 0.3584, 0.2286]),
tensor([1.7576, 0.7576, 0.3680, 0.3468, 0.0745]),
tensor([1.1837, 0.1837, 0.1338, 0.1183, 0.1138])]
"""
def __init__(
self,
num_classes: Optional[int] = None,
@@ -90,11 +117,13 @@ class ROC(Metric):
compute_on_step: bool = True,
dist_sync_on_step: bool = False,
process_group: Optional[Any] = None,
dist_sync_fn: Callable = None,
):
super().__init__(
compute_on_step=compute_on_step,
dist_sync_on_step=dist_sync_on_step,
process_group=process_group,
dist_sync_fn=dist_sync_fn,
)
self.num_classes = num_classes

View file

@@ -75,7 +75,7 @@ def _auroc_compute(
# calculate fpr, tpr
if mode == 'multi-label':
if average == AverageMethod.MICRO:
fpr, tpr, _ = roc(preds.flatten(), target.flatten(), num_classes, pos_label, sample_weights)
fpr, tpr, _ = roc(preds.flatten(), target.flatten(), 1, pos_label, sample_weights)
else:
# for multilabel we iteratively evaluate roc in a binary fashion
output = [
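After flattening, the micro average reduces to a single binary curve over all N*C label decisions, which is why `num_classes` is now hard-coded to `1`. A hedged sketch of the flattened call (illustrative tensors, not from this diff):
>>> import torch
>>> from torchmetrics.functional import roc
>>> preds = torch.tensor([[0.9, 0.2], [0.3, 0.8]])
>>> target = torch.tensor([[1, 0], [0, 1]])
>>> # one curve over all label decisions, treated as a single binary problem
>>> fpr, tpr, _ = roc(preds.flatten(), target.flatten(), num_classes=1, pos_label=1)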

View file

@@ -71,16 +71,28 @@ def _precision_recall_curve_update(
) -> Tuple[Tensor, Tensor, int, int]:
if not (len(preds.shape) == len(target.shape) or len(preds.shape) == len(target.shape) + 1):
raise ValueError("preds and target must have same number of dimensions, or one additional dimension for preds")
# single class evaluation
if len(preds.shape) == len(target.shape):
num_classes = 1
if pos_label is None:
rank_zero_warn('`pos_label` automatically set to 1.')
pos_label = 1
preds = preds.flatten()
target = target.flatten()
if num_classes is not None and num_classes != 1:
# multilabel problem
if num_classes != preds.shape[1]:
raise ValueError(
f'Argument `num_classes` was set to {num_classes} in'
f' metric `precision_recall_curve` but detected {preds.shape[1]}'
' number of classes from predictions'
)
preds = preds.transpose(0, 1).reshape(num_classes, -1).transpose(0, 1)
target = target.transpose(0, 1).reshape(num_classes, -1).transpose(0, 1)
else:
# binary problem
preds = preds.flatten()
target = target.flatten()
num_classes = 1
# multi class evaluation
# multi class problem
if len(preds.shape) == len(target.shape) + 1:
if pos_label is not None:
rank_zero_warn(

View file

@@ -27,8 +27,9 @@ def _roc_update(
target: Tensor,
num_classes: Optional[int] = None,
pos_label: Optional[int] = None,
) -> Tuple[Tensor, Tensor, int, int]:
return _precision_recall_curve_update(preds, target, num_classes, pos_label)
) -> Tuple[Tensor, Tensor, int, int]:
preds, target, num_classes, pos_label = _precision_recall_curve_update(preds, target, num_classes, pos_label)
return preds, target, num_classes, pos_label
def _roc_compute(
@@ -39,7 +40,7 @@ def _roc_compute(
sample_weights: Optional[Sequence] = None,
) -> Union[Tuple[Tensor, Tensor, Tensor], Tuple[List[Tensor], List[Tensor], List[Tensor]]]:
if num_classes == 1:
if num_classes == 1 and preds.ndim == 1: # binary
fps, tps, thresholds = _binary_clf_curve(
preds=preds, target=target, sample_weights=sample_weights, pos_label=pos_label
)
@@ -62,12 +63,19 @@
# Recursively call per class
fpr, tpr, thresholds = [], [], []
for c in range(num_classes):
preds_c = preds[:, c]
if preds.shape == target.shape:
preds_c = preds[:, c]
target_c = target[:, c]
pos_label = 1
else:
preds_c = preds[:, c]
target_c = target
pos_label = c
res = roc(
preds=preds_c,
target=target,
target=target_c,
num_classes=1,
pos_label=c,
pos_label=pos_label,
sample_weights=sample_weights,
)
fpr.append(res[0])
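Since each label column is handled as an independent binary problem with `pos_label=1`, the per-label curves match a direct binary call on that column. A hedged equivalence check, reusing the values from the multilabel docstring example below:
>>> import torch
>>> from torchmetrics.functional import roc
>>> preds = torch.tensor([[0.8191, 0.3680, 0.1138],
...                       [0.3584, 0.7576, 0.1183],
...                       [0.2286, 0.3468, 0.1338],
...                       [0.8603, 0.0745, 0.1837]])
>>> target = torch.tensor([[1, 1, 0], [0, 1, 0], [0, 0, 0], [0, 1, 1]])
>>> fpr, tpr, _ = roc(preds, target, num_classes=3, pos_label=1)
>>> fpr0, tpr0, _ = roc(preds[:, 0], target[:, 0], num_classes=1, pos_label=1)
>>> torch.equal(fpr[0], fpr0) and torch.equal(tpr[0], tpr0)
True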
@@ -86,6 +94,7 @@ def roc(
) -> Union[Tuple[Tensor, Tensor, Tensor], Tuple[List[Tensor], List[Tensor], List[Tensor]]]:
"""
Computes the Receiver Operating Characteristic (ROC).
Works with binary, multiclass, and multilabel input.
Args:
preds: predictions from model (logits or probabilities)
@@ -103,15 +112,16 @@ def roc(
fpr:
tensor with false positive rates.
If multiclass, this is a list of such tensors, one for each class.
If multiclass or multilabel, this is a list of such tensors, one for each class/label.
tpr:
tensor with true positive rates.
If multiclass, this is a list of such tensors, one for each class.
If multiclass or multilabel, this is a list of such tensors, one for each class/label.
thresholds:
thresholds used for computing false- and true positive rates
tensor with thresholds used for computing false- and true positive rates
If multiclass or multilabel, this is a list of such tensors, one for each class/label.
Example (binary case):
Example:
>>> # binary case
>>> from torchmetrics.functional import roc
>>> pred = torch.tensor([0, 1, 2, 3])
>>> target = torch.tensor([0, 1, 1, 1])
@@ -123,7 +133,9 @@ def roc(
>>> thresholds
tensor([4, 3, 2, 1, 0])
>>> # multiclass case
Example (multiclass case):
>>> from torchmetrics.functional import roc
>>> pred = torch.tensor([[0.75, 0.05, 0.05, 0.05],
... [0.05, 0.75, 0.05, 0.05],
... [0.05, 0.05, 0.75, 0.05],
@@ -139,6 +151,27 @@
tensor([1.7500, 0.7500, 0.0500]),
tensor([1.7500, 0.7500, 0.0500]),
tensor([1.7500, 0.7500, 0.0500])]
Example (multilabel case):
>>> from torchmetrics.functional import roc
>>> pred = torch.tensor([[0.8191, 0.3680, 0.1138],
... [0.3584, 0.7576, 0.1183],
... [0.2286, 0.3468, 0.1338],
... [0.8603, 0.0745, 0.1837]])
>>> target = torch.tensor([[1, 1, 0], [0, 1, 0], [0, 0, 0], [0, 1, 1]])
>>> fpr, tpr, thresholds = roc(pred, target, num_classes=3, pos_label=1)
>>> fpr # doctest: +NORMALIZE_WHITESPACE
[tensor([0.0000, 0.3333, 0.3333, 0.6667, 1.0000]),
tensor([0., 0., 0., 1., 1.]),
tensor([0.0000, 0.0000, 0.3333, 0.6667, 1.0000])]
>>> tpr
[tensor([0., 0., 1., 1., 1.]), tensor([0.0000, 0.3333, 0.6667, 0.6667, 1.0000]), tensor([0., 1., 1., 1., 1.])]
>>> thresholds # doctest: +NORMALIZE_WHITESPACE
[tensor([1.8603, 0.8603, 0.8191, 0.3584, 0.2286]),
tensor([1.7576, 0.7576, 0.3680, 0.3468, 0.0745]),
tensor([1.1837, 0.1837, 0.1338, 0.1183, 0.1138])]
"""
preds, target, num_classes, pos_label = _roc_update(preds, target, num_classes, pos_label)
return _roc_compute(preds, target, num_classes, pos_label, sample_weights)