Native torch metrics (#1488)
* Create metric.py
* Create utils.py
* Create __init__.py (x3)
* add tests for metric utils (x2)
* add docstrings for metrics utils (x2)
* add function to recursively apply other function to collection (x2)
* add tests for this function (x3)
* update test (x2)
* Update pytorch_lightning/metrics/metric.py (Co-Authored-By: Jirka Borovec <Borda@users.noreply.github.com>)
* update metric name
* remove example docs
* fix tests (x2)
* add metric tests
* fix to tensor conversion (x2)
* fix apply to collection (x2)
* Update pytorch_lightning/metrics/metric.py (Co-Authored-By: Jirka Borovec <Borda@users.noreply.github.com>)
* remove tests from init (x2)
* add missing type annotations
* rename utils to convertors (x4)
* Update pytorch_lightning/metrics/convertors.py (x6, Co-Authored-By: Jirka Borovec <Borda@users.noreply.github.com>)
* Update pytorch_lightning/metrics/metric.py (Co-Authored-By: Jirka Borovec <Borda@users.noreply.github.com>)
* Update tests/utilities/test_apply_to_collection.py (x4, Co-Authored-By: Jirka Borovec <Borda@users.noreply.github.com>)
* Update tests/metrics/convertors.py (x2, Co-Authored-By: Jirka Borovec <Borda@users.noreply.github.com>)
* Apply suggestions from code review (x3, Co-Authored-By: Jirka Borovec <Borda@users.noreply.github.com>)
* add doctest example
* rename file and fix imports (x2)
* added parametrized test (x2)
* replace lambda with inlined function
* rename apply_to_collection to apply_func (x3)
* Separated class description from init args
* Apply suggestions from code review (Co-Authored-By: Jirka Borovec <Borda@users.noreply.github.com>)
* adjust random values
* suppress output when seeding
* remove gpu from doctest
* Add requested changes and add ellipsis for doctest (x3)
* forgot to push these files... (x3)
* add explicit check for dtype to convert to (x2)
* fix ddp tests (x3)
* remove explicit ddp destruction (x2)
* New metric classes (#1326):
  * Create metrics package
  * Create metric.py
  * Create utils.py
  * Create __init__.py
  * add tests for metric utils
  * add docstrings for metrics utils
  * add function to recursively apply other function to collection
  * add tests for this function
  * update test
  * Update pytorch_lightning/metrics/metric.py (Co-Authored-By: Jirka Borovec <Borda@users.noreply.github.com>)
  * update metric name
  * remove example docs
  * fix tests
  * add metric tests
  * fix to tensor conversion
  * fix apply to collection
  * Update CHANGELOG.md
  * Update pytorch_lightning/metrics/metric.py (Co-Authored-By: Jirka Borovec <Borda@users.noreply.github.com>)
  * remove tests from init
  * add missing type annotations
  * rename utils to convertors
  * Create metrics.rst
  * Update index.rst (x2)
  * Update pytorch_lightning/metrics/convertors.py (x3, Co-Authored-By: Jirka Borovec <Borda@users.noreply.github.com>)
  * Update pytorch_lightning/metrics/metric.py (Co-Authored-By: Jirka Borovec <Borda@users.noreply.github.com>)
  * Update tests/utilities/test_apply_to_collection.py (x2, Co-Authored-By: Jirka Borovec <Borda@users.noreply.github.com>)
  * Update tests/metrics/convertors.py (Co-Authored-By: Jirka Borovec <Borda@users.noreply.github.com>)
  * Apply suggestions from code review (Co-Authored-By: Jirka Borovec <Borda@users.noreply.github.com>)
  * add doctest example
  * rename file and fix imports
  * added parametrized test
  * replace lambda with inlined function
  * rename apply_to_collection to apply_func
  * Separated class description from init args
  * Apply suggestions from code review (Co-Authored-By: Jirka Borovec <Borda@users.noreply.github.com>)
  * adjust random values
  * suppress output when seeding
  * remove gpu from doctest
  * Add requested changes and add ellipsis for doctest
  * forgot to push these files...
  * add explicit check for dtype to convert to
  * fix ddp tests
  * remove explicit ddp destruction
  Co-authored-by: Jirka Borovec <Borda@users.noreply.github.com>
* add function to reduce tensors (similar to reduction in torch.nn)
* add functionals of reduction metrics (x2)
* add more metrics
* pep8 fixes
* rename (x2)
* add reduction tests
* add first classification tests
* bugfixes (x2)
* add more unit tests
* fix roc score metric
* fix tests
* solve tests
* fix docs
* Update CHANGELOG.md
* remove binaries
* solve changes from rebase
* add eos
* test auc independently
* fix formatting
* docs (x2)
* chlog
* move
* function descriptions
* Add documentation to native metrics (#2144):
  * add docs (x2)
  * Apply suggestions from code review
  * formatting
  * add docs
  Co-authored-by: Jirka Borovec <Borda@users.noreply.github.com>
  Co-authored-by: Jirka <jirka@pytorchlightning.ai>
* Rename tests/metrics/test_classification.py to tests/metrics/functional/test_classification.py
* Rename tests/metrics/test_reduction.py to tests/metrics/functional/test_reduction.py
* Add module interface for classification metrics
* add basic tests for classification metrics' module interface
* pep8
* add additional converters
* add additional base class
* change baseclass for some metrics
* update classification tests
* update converter tests
* update metric tests
* Apply suggestions from code review
* tests-params (x2)
* imports
* pep8
* tests-params
* formatting
* fix test_metrics
* typo
* formatting
* fix dice tests
* fix decorator order
* fix tests
* seed
* dice test
* formatting
* try freeze test
* formatting
* fix tests
* try spawn
* formatting
* fix

Co-authored-by: Jirka Borovec <Borda@users.noreply.github.com>
Co-authored-by: J. Borovec <jirka.borovec@seznam.cz>
Co-authored-by: Xavier Sumba <c.uent@hotmail.com>
Co-authored-by: Jirka <jirka@pytorchlightning.ai>
Co-authored-by: Nicki Skafte <nugginea@gmail.com>
This commit is contained in:
Parent: 21aeabc092
Commit: 1fc078442b
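The commit message above references a helper that reduces tensors "similar to reduction in torch.nn". Its implementation (pytorch_lightning/metrics/functional/reduction.py) is not part of this excerpt; the sketch below is a hypothetical reconstruction that only assumes the three reduction modes documented in the metric docstrings and the reduce(tensor, reduction=...) call sites in the code further down.

import torch


def reduce(to_reduce: torch.Tensor, reduction: str) -> torch.Tensor:
    # 'elementwise_mean' -> mean over all elements, 'sum' -> sum, 'none' -> identity
    if reduction == 'elementwise_mean':
        return torch.mean(to_reduce)
    if reduction == 'none':
        return to_reduce
    if reduction == 'sum':
        return torch.sum(to_reduce)
    raise ValueError('Reduction parameter unknown.')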
@@ -23,8 +23,8 @@ inputs to and outputs from numpy as well as automated ddp syncing.
 """
 
-from pytorch_lightning.metrics.converters import numpy_metric, tensor_metric
 from pytorch_lightning.metrics.metric import Metric, TensorMetric, NumpyMetric
 from pytorch_lightning.metrics.sklearn import (
     SklearnMetric, Accuracy, AveragePrecision, AUC, ConfusionMatrix, F1, FBeta,
     Precision, Recall, PrecisionRecallCurve, ROC, AUROC)
+from pytorch_lightning.metrics.converters import numpy_metric, tensor_metric
 
@@ -0,0 +1,652 @@
from typing import Any, Optional, Sequence, Tuple

import torch

from pytorch_lightning.metrics.functional.classification import (
    accuracy,
    confusion_matrix,
    precision_recall_curve,
    precision,
    recall,
    average_precision,
    auroc,
    fbeta_score,
    f1_score,
    roc,
    multiclass_roc,
    multiclass_precision_recall_curve,
    dice_score
)
from pytorch_lightning.metrics.metric import TensorMetric, TensorCollectionMetric

__all__ = [
    'Accuracy',
    'ConfusionMatrix',
    'PrecisionRecall',
    'Precision',
    'Recall',
    'AveragePrecision',
    'AUROC',
    'FBeta',
    'F1',
    'ROC',
    'MulticlassROC',
    'MulticlassPrecisionRecall',
    'DiceCoefficient'
]


class Accuracy(TensorMetric):
    """
    Computes the accuracy classification score
    """

    def __init__(
            self,
            num_classes: Optional[int] = None,
            reduction: str = 'elementwise_mean',
            reduce_group: Any = None,
            reduce_op: Any = None,
    ):
        """
        Args:
            num_classes: number of classes
            reduction: a method for reducing accuracies over labels (default: takes the mean)
                Available reduction methods:
                - elementwise_mean: takes the mean
                - none: pass array
                - sum: add elements
            reduce_group: the process group to reduce metric results from DDP
            reduce_op: the operation to perform for DDP reduction
        """
        super().__init__(name='accuracy',
                         reduce_group=reduce_group,
                         reduce_op=reduce_op)
        self.num_classes = num_classes
        self.reduction = reduction

    def forward(self, pred: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
        """
        Actual metric computation

        Args:
            pred: predicted labels
            target: ground truth labels

        Return:
            A Tensor with the classification score.
        """
        return accuracy(pred=pred, target=target,
                        num_classes=self.num_classes, reduction=self.reduction)


class ConfusionMatrix(TensorMetric):
    """
    Computes the confusion matrix C where each entry C_{i,j} is the number of observations
    in group i that were predicted in group j.
    """

    def __init__(
            self,
            normalize: bool = False,
            reduce_group: Any = None,
            reduce_op: Any = None,
    ):
        """
        Args:
            normalize: whether to compute a normalized confusion matrix
            reduce_group: the process group to reduce metric results from DDP
            reduce_op: the operation to perform for DDP reduction
        """
        super().__init__(name='confusion_matrix',
                         reduce_group=reduce_group,
                         reduce_op=reduce_op)
        self.normalize = normalize

    def forward(self, pred: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
        """
        Actual metric computation

        Args:
            pred: predicted labels
            target: ground truth labels

        Return:
            A Tensor with the confusion matrix.
        """
        return confusion_matrix(pred=pred, target=target,
                                normalize=self.normalize)


class PrecisionRecall(TensorCollectionMetric):
    """
    Computes the precision recall curve
    """

    def __init__(
            self,
            pos_label: int = 1,
            reduce_group: Any = None,
            reduce_op: Any = None,
    ):
        """
        Args:
            pos_label: positive label indicator
            reduce_group: the process group to reduce metric results from DDP
            reduce_op: the operation to perform for DDP reduction
        """
        super().__init__(name='precision_recall_curve',
                         reduce_group=reduce_group,
                         reduce_op=reduce_op)
        self.pos_label = pos_label

    def forward(
            self,
            pred: torch.Tensor,
            target: torch.Tensor,
            sample_weight: Optional[Sequence] = None,
    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        """
        Actual metric computation

        Args:
            pred: predicted labels
            target: ground truth labels
            sample_weight: the weights per sample

        Return:
            torch.Tensor: precision values
            torch.Tensor: recall values
            torch.Tensor: threshold values
        """
        return precision_recall_curve(pred=pred, target=target,
                                      sample_weight=sample_weight,
                                      pos_label=self.pos_label)


class Precision(TensorMetric):
    """
    Computes the precision score
    """

    def __init__(
            self,
            num_classes: Optional[int] = None,
            reduction: str = 'elementwise_mean',
            reduce_group: Any = None,
            reduce_op: Any = None,
    ):
        """
        Args:
            num_classes: number of classes
            reduction: a method for reducing accuracies over labels (default: takes the mean)
                Available reduction methods:
                - elementwise_mean: takes the mean
                - none: pass array
                - sum: add elements
            reduce_group: the process group to reduce metric results from DDP
            reduce_op: the operation to perform for DDP reduction
        """
        super().__init__(name='precision',
                         reduce_group=reduce_group,
                         reduce_op=reduce_op)
        self.num_classes = num_classes
        self.reduction = reduction

    def forward(self, pred: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
        """
        Actual metric computation

        Args:
            pred: predicted labels
            target: ground truth labels

        Return:
            A Tensor with the classification score.
        """
        return precision(pred=pred, target=target,
                         num_classes=self.num_classes,
                         reduction=self.reduction)


class Recall(TensorMetric):
    """
    Computes the recall score
    """

    def __init__(
            self,
            num_classes: Optional[int] = None,
            reduction: str = 'elementwise_mean',
            reduce_group: Any = None,
            reduce_op: Any = None,
    ):
        """
        Args:
            num_classes: number of classes
            reduction: a method for reducing accuracies over labels (default: takes the mean)
                Available reduction methods:
                - elementwise_mean: takes the mean
                - none: pass array
                - sum: add elements
            reduce_group: the process group to reduce metric results from DDP
            reduce_op: the operation to perform for DDP reduction
        """
        super().__init__(name='recall',
                         reduce_group=reduce_group,
                         reduce_op=reduce_op)
        self.num_classes = num_classes
        self.reduction = reduction

    def forward(self, pred: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
        """
        Actual metric computation

        Args:
            pred: predicted labels
            target: ground truth labels

        Return:
            A Tensor with the classification score.
        """
        return recall(pred=pred,
                      target=target,
                      num_classes=self.num_classes,
                      reduction=self.reduction)


class AveragePrecision(TensorMetric):
    """
    Computes the average precision score
    """

    def __init__(
            self,
            pos_label: int = 1,
            reduce_group: Any = None,
            reduce_op: Any = None,
    ):
        """
        Args:
            pos_label: positive label indicator
            reduce_group: the process group to reduce metric results from DDP
            reduce_op: the operation to perform for DDP reduction
        """
        super().__init__(name='AP',
                         reduce_group=reduce_group,
                         reduce_op=reduce_op)
        self.pos_label = pos_label

    def forward(
            self,
            pred: torch.Tensor,
            target: torch.Tensor,
            sample_weight: Optional[Sequence] = None
    ) -> torch.Tensor:
        """
        Actual metric computation

        Args:
            pred: predicted labels
            target: ground truth labels
            sample_weight: the weights per sample

        Return:
            torch.Tensor: classification score
        """
        return average_precision(pred=pred, target=target,
                                 sample_weight=sample_weight,
                                 pos_label=self.pos_label)


class AUROC(TensorMetric):
    """
    Computes the area under curve (AUC) of the receiver operator characteristic (ROC)
    """

    def __init__(
            self,
            pos_label: int = 1,
            reduce_group: Any = None,
            reduce_op: Any = None,
    ):
        """
        Args:
            pos_label: positive label indicator
            reduce_group: the process group to reduce metric results from DDP
            reduce_op: the operation to perform for DDP reduction
        """
        super().__init__(name='auroc',
                         reduce_group=reduce_group,
                         reduce_op=reduce_op)
        self.pos_label = pos_label

    def forward(
            self,
            pred: torch.Tensor,
            target: torch.Tensor,
            sample_weight: Optional[Sequence] = None
    ) -> torch.Tensor:
        """
        Actual metric computation

        Args:
            pred: predicted labels
            target: ground truth labels
            sample_weight: the weights per sample

        Return:
            torch.Tensor: classification score
        """
        return auroc(pred=pred, target=target,
                     sample_weight=sample_weight,
                     pos_label=self.pos_label)


class FBeta(TensorMetric):
    """Computes the FBeta Score"""

    def __init__(
            self,
            beta: float,
            num_classes: Optional[int] = None,
            reduction: str = 'elementwise_mean',
            reduce_group: Any = None,
            reduce_op: Any = None,
    ):
        """
        Args:
            beta: determines the weight of recall in the combined score.
            num_classes: number of classes
            reduction: a method for reducing accuracies over labels (default: takes the mean)
                Available reduction methods:
                - elementwise_mean: takes the mean
                - none: pass array
                - sum: add elements
            reduce_group: the process group to reduce metric results from DDP
            reduce_op: the operation to perform for DDP reduction
        """
        super().__init__(name='fbeta',
                         reduce_group=reduce_group,
                         reduce_op=reduce_op)
        self.beta = beta
        self.num_classes = num_classes
        self.reduction = reduction

    def forward(self, pred: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
        """
        Actual metric computation

        Args:
            pred: predicted labels
            target: ground truth labels

        Return:
            torch.Tensor: classification score
        """
        return fbeta_score(pred=pred, target=target,
                           beta=self.beta, num_classes=self.num_classes,
                           reduction=self.reduction)


class F1(TensorMetric):
    """Computes the F1 score"""

    def __init__(
            self,
            num_classes: Optional[int] = None,
            reduction: str = 'elementwise_mean',
            reduce_group: Any = None,
            reduce_op: Any = None,
    ):
        """
        Args:
            num_classes: number of classes
            reduction: a method for reducing accuracies over labels (default: takes the mean)
                Available reduction methods:
                - elementwise_mean: takes the mean
                - none: pass array
                - sum: add elements
            reduce_group: the process group to reduce metric results from DDP
            reduce_op: the operation to perform for DDP reduction
        """
        super().__init__(name='f1',
                         reduce_group=reduce_group,
                         reduce_op=reduce_op)
        self.num_classes = num_classes
        self.reduction = reduction

    def forward(self, pred: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
        """
        Actual metric computation

        Args:
            pred: predicted labels
            target: ground truth labels

        Return:
            torch.Tensor: classification score
        """
        return f1_score(pred=pred, target=target,
                        num_classes=self.num_classes,
                        reduction=self.reduction)


class ROC(TensorCollectionMetric):
    """
    Computes the Receiver Operator Characteristic (ROC)
    """

    def __init__(
            self,
            pos_label: int = 1,
            reduce_group: Any = None,
            reduce_op: Any = None,
    ):
        """
        Args:
            pos_label: positive label indicator
            reduce_group: the process group to reduce metric results from DDP
            reduce_op: the operation to perform for DDP reduction
        """
        super().__init__(name='roc',
                         reduce_group=reduce_group,
                         reduce_op=reduce_op)
        self.pos_label = pos_label

    def forward(
            self,
            pred: torch.Tensor,
            target: torch.Tensor,
            sample_weight: Optional[Sequence] = None
    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        """
        Actual metric computation

        Args:
            pred: predicted labels
            target: ground truth labels
            sample_weight: the weights per sample

        Return:
            torch.Tensor: false positive rate
            torch.Tensor: true positive rate
            torch.Tensor: thresholds
        """
        return roc(pred=pred, target=target,
                   sample_weight=sample_weight,
                   pos_label=self.pos_label)


class MulticlassROC(TensorCollectionMetric):
    """
    Computes the multiclass ROC
    """

    def __init__(
            self,
            num_classes: Optional[int] = None,
            reduce_group: Any = None,
            reduce_op: Any = None,
    ):
        """
        Args:
            num_classes: number of classes
            reduce_group: the process group to reduce metric results from DDP
            reduce_op: the operation to perform for DDP reduction
        """
        super().__init__(name='multiclass_roc',
                         reduce_group=reduce_group,
                         reduce_op=reduce_op)
        self.num_classes = num_classes

    def forward(
            self, pred: torch.Tensor,
            target: torch.Tensor,
            sample_weight: Optional[Sequence] = None,
    ) -> Tuple[Tuple[torch.Tensor, torch.Tensor, torch.Tensor]]:
        """
        Actual metric computation

        Args:
            pred: predicted labels
            target: ground truth labels
            sample_weight: Weights for each sample defining the sample's impact on the score

        Return:
            tuple: A tuple consisting of one tuple per class,
            holding false positive rate, true positive rate and thresholds
        """
        return multiclass_roc(pred=pred,
                              target=target,
                              sample_weight=sample_weight,
                              num_classes=self.num_classes)


class MulticlassPrecisionRecall(TensorCollectionMetric):
    """Computes the multiclass PR Curve"""

    def __init__(
            self,
            num_classes: Optional[int] = None,
            reduce_group: Any = None,
            reduce_op: Any = None,
    ):
        """
        Args:
            num_classes: number of classes
            reduce_group: the process group to reduce metric results from DDP
            reduce_op: the operation to perform for DDP reduction
        """
        super().__init__(name='multiclass_precision_recall_curve',
                         reduce_group=reduce_group,
                         reduce_op=reduce_op)
        self.num_classes = num_classes

    def forward(
            self,
            pred: torch.Tensor,
            target: torch.Tensor,
            sample_weight: Optional[Sequence] = None,
    ) -> Tuple[Tuple[torch.Tensor, torch.Tensor, torch.Tensor]]:
        """
        Actual metric computation

        Args:
            pred: predicted labels
            target: ground truth labels
            sample_weight: Weights for each sample defining the sample's impact on the score

        Return:
            tuple: A tuple consisting of one tuple per class,
            holding precision, recall and thresholds
        """
        return multiclass_precision_recall_curve(pred=pred,
                                                 target=target,
                                                 sample_weight=sample_weight,
                                                 num_classes=self.num_classes)


class DiceCoefficient(TensorMetric):
    """Computes the dice coefficient"""

    def __init__(
            self,
            include_background: bool = False,
            nan_score: float = 0.0, no_fg_score: float = 0.0,
            reduction: str = 'elementwise_mean',
            reduce_group: Any = None,
            reduce_op: Any = None,
    ):
        """
        Args:
            include_background: whether to also compute dice for the background
            nan_score: score to return, if a NaN occurs during computation (denominator zero)
            no_fg_score: score to return, if no foreground pixel was found in target
            reduction: a method for reducing accuracies over labels (default: takes the mean)
                Available reduction methods:
                - elementwise_mean: takes the mean
                - none: pass array
                - sum: add elements
            reduce_group: the process group to reduce metric results from DDP
            reduce_op: the operation to perform for DDP reduction
        """
        super().__init__(name='dice',
                         reduce_group=reduce_group,
                         reduce_op=reduce_op)
        self.include_background = include_background
        self.nan_score = nan_score
        self.no_fg_score = no_fg_score
        self.reduction = reduction

    def forward(self, pred: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
        """
        Actual metric computation

        Args:
            pred: predicted labels
            target: ground truth labels

        Return:
            torch.Tensor: the calculated dice coefficient
        """
        return dice_score(pred=pred,
                          target=target,
                          bg=self.include_background,
                          nan_score=self.nan_score,
                          no_fg_score=self.no_fg_score,
                          reduction=self.reduction)
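As a usage illustration (not part of the commit), the module-interface metrics above are plain callables that wrap the functional implementations. This sketch assumes the new file is importable as pytorch_lightning.metrics.classification:

import torch
from pytorch_lightning.metrics.classification import Accuracy, F1

pred = torch.tensor([0, 1, 2, 2])    # predicted class labels
target = torch.tensor([0, 1, 1, 2])  # ground truth labels

acc = Accuracy(num_classes=3)
f1 = F1(num_classes=3)
print(acc(pred, target))  # scalar tensor: mean per-class accuracy
print(f1(pred, target))   # scalar tensor: mean per-class F1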
@@ -4,7 +4,6 @@ conversion to/from :class:`numpy.ndarray` and :class:`torch.Tensor` as well as u
 sync tensors between different processes in a DDP scenario, when needed.
 """
 
-import sys
 import numbers
 from typing import Union, Any, Callable, Optional
 

@@ -18,12 +17,13 @@ from pytorch_lightning.utilities.apply_func import apply_to_collection
 def _apply_to_inputs(func_to_apply: Callable, *dec_args, **dec_kwargs) -> Callable:
     """
     Decorator function to apply a function to all inputs of a function.
 
     Args:
         func_to_apply: the function to apply to the inputs
         *dec_args: positional arguments for the function to be applied
         **dec_kwargs: keyword arguments for the function to be applied
 
-    Returns:
+    Return:
         the decorated function
     """

@@ -42,12 +42,13 @@ def _apply_to_inputs(func_to_apply: Callable, *dec_args, **dec_kwargs) -> Callable:
 def _apply_to_outputs(func_to_apply: Callable, *dec_args, **dec_kwargs) -> Callable:
     """
     Decorator function to apply a function to all outputs of a function.
 
     Args:
         func_to_apply: the function to apply to the outputs
         *dec_args: positional arguments for the function to be applied
         **dec_kwargs: keyword arguments for the function to be applied
 
-    Returns:
+    Return:
         the decorated function
     """

@@ -69,9 +70,8 @@ def _convert_to_tensor(data: Any) -> Any:
     Args:
         data: the data to convert to tensor
 
-    Returns:
+    Return:
         the converted data
 
     """
     if isinstance(data, numbers.Number):
         return torch.tensor([data])

@@ -86,12 +86,12 @@ def _convert_to_tensor(data: Any) -> Any:
 def _convert_to_numpy(data: Union[torch.Tensor, np.ndarray, numbers.Number]) -> np.ndarray:
     """Convert all tensors and numpy arrays to numpy arrays.
 
     Args:
         data: the tensor or array to convert to numpy
 
-    Returns:
+    Return:
         the resulting numpy array
 
     """
     if isinstance(data, torch.Tensor):
         return data.cpu().detach().numpy()

@@ -103,6 +103,33 @@ def _convert_to_numpy(data: Union[torch.Tensor, np.ndarray, numbers.Number]) ->
     raise TypeError("The given type ('%s') cannot be converted to a numpy array!" % type(data).__name__)
 
 
+def _numpy_metric_input_conversion(func_to_decorate: Callable) -> Callable:
+    """
+    Decorator converting all inputs of a function to numpy
+
+    Args:
+        func_to_decorate: the function whose inputs shall be converted
+
+    Return:
+        Callable: the decorated function
+    """
+    return _apply_to_inputs(
+        apply_to_collection, (torch.Tensor, np.ndarray, numbers.Number), _convert_to_numpy)(func_to_decorate)
+
+
+def _tensor_metric_output_conversion(func_to_decorate: Callable) -> Callable:
+    """
+    Decorator converting all outputs of a function to tensors
+
+    Args:
+        func_to_decorate: the function whose outputs shall be converted
+
+    Return:
+        Callable: the decorated function
+    """
+    return _apply_to_outputs(_convert_to_tensor)(func_to_decorate)
+
+
 def _numpy_metric_conversion(func_to_decorate: Callable) -> Callable:
     """
     Decorator handling the argument conversion for metrics working on numpy.

@@ -112,19 +139,45 @@ def _numpy_metric_conversion(func_to_decorate: Callable) -> Callable:
     Args:
         func_to_decorate: the function whose inputs and outputs shall be converted
 
-    Returns:
+    Return:
         the decorated function
 
     """
     # applies collection conversion from tensor to numpy to all inputs
     # we need to include numpy arrays here, since otherwise they will also be treated as sequences
-    func_convert_inputs = _apply_to_inputs(
-        apply_to_collection, (torch.Tensor, np.ndarray, numbers.Number), _convert_to_numpy)(func_to_decorate)
+    func_convert_inputs = _numpy_metric_input_conversion(func_to_decorate)
     # converts all inputs back to tensors (device doesn't matter here, since this is handled by BaseMetric)
-    func_convert_in_out = _apply_to_outputs(_convert_to_tensor)(func_convert_inputs)
+    func_convert_in_out = _tensor_metric_output_conversion(func_convert_inputs)
     return func_convert_in_out
 
 
+def _tensor_metric_input_conversion(func_to_decorate: Callable) -> Callable:
+    """
+    Decorator converting all inputs of a function to tensors
+
+    Args:
+        func_to_decorate: the function whose inputs shall be converted
+
+    Return:
+        Callable: the decorated function
+    """
+    return _apply_to_inputs(
+        apply_to_collection, (torch.Tensor, np.ndarray, numbers.Number), _convert_to_tensor)(func_to_decorate)
+
+
+def _tensor_collection_metric_output_conversion(func_to_decorate: Callable) -> Callable:
+    """
+    Decorator converting all numpy arrays and numbers occurring in the outputs of a function to tensors
+
+    Args:
+        func_to_decorate: the function whose outputs shall be converted
+
+    Return:
+        Callable: the decorated function
+    """
+    return _apply_to_outputs(apply_to_collection, (torch.Tensor, np.ndarray, numbers.Number),
+                             _convert_to_tensor)(func_to_decorate)
+
+
 def _tensor_metric_conversion(func_to_decorate: Callable) -> Callable:
     """
     Decorator handling the argument conversion for metrics working on tensors.

@@ -133,16 +186,33 @@ def _tensor_metric_conversion(func_to_decorate: Callable) -> Callable:
     Args:
         func_to_decorate: the function whose inputs and outputs shall be converted
 
-    Returns:
+    Return:
         the decorated function
 
     """
     # converts all inputs to tensor if possible
     # we need to include tensors here, since otherwise they will also be treated as sequences
-    func_convert_inputs = _apply_to_inputs(
-        apply_to_collection, (torch.Tensor, np.ndarray, numbers.Number), _convert_to_tensor)(func_to_decorate)
+    func_convert_inputs = _tensor_metric_input_conversion(func_to_decorate)
     # convert all outputs to tensor if possible
-    return _apply_to_outputs(_convert_to_tensor)(func_convert_inputs)
+    return _tensor_metric_output_conversion(func_convert_inputs)
 
 
+def _tensor_collection_metric_conversion(func_to_decorate: Callable) -> Callable:
+    """
+    Decorator handling the argument conversion for metrics working on tensors.
+    All inputs of the decorated function and all numpy arrays and numbers in
+    its outputs will be converted to tensors
+
+    Args:
+        func_to_decorate: the function whose inputs and outputs shall be converted
+
+    Return:
+        the decorated function
+    """
+    # converts all inputs to tensor if possible
+    # we need to include tensors here, since otherwise they will also be treated as sequences
+    func_convert_inputs = _tensor_metric_input_conversion(func_to_decorate)
+    # convert all outputs to tensor if possible
+    return _tensor_collection_metric_output_conversion(func_convert_inputs)
+
+
 def _sync_ddp_if_available(result: Union[torch.Tensor],

@@ -157,9 +227,8 @@ def _sync_ddp_if_available(result: Union[torch.Tensor],
         group: the process group to gather results from. Defaults to all processes (world)
         reduce_op: the reduction operation. Defaults to sum.
 
-    Returns:
+    Return:
         reduced value
 
     """
 
     if torch.distributed.is_available() and torch.distributed.is_initialized():

@@ -177,11 +246,32 @@ def _sync_ddp_if_available(result: Union[torch.Tensor],
     return result
 
 
+def sync_ddp(group: Optional[Any] = None,
+             reduce_op: Optional[torch.distributed.ReduceOp] = None) -> Callable:
+    """
+    This decorator syncs a function's outputs across different processes for DDP.
+
+    Args:
+        group: the process group to gather results from. Defaults to all processes (world)
+        reduce_op: the reduction operation. Defaults to sum
+
+    Return:
+        the decorated function
+    """
+
+    def decorator_fn(func_to_decorate):
+        return _apply_to_outputs(apply_to_collection, torch.Tensor,
+                                 _sync_ddp_if_available, group=group,
+                                 reduce_op=reduce_op)(func_to_decorate)
+
+    return decorator_fn
+
+
 def numpy_metric(group: Optional[Any] = None,
                  reduce_op: Optional[torch.distributed.ReduceOp] = None) -> Callable:
     """
     This decorator shall be used on all function metrics working on numpy arrays.
 
     It handles the argument conversion and DDP reduction for metrics working on numpy.
     All inputs of the decorated function will be converted to numpy and all
     outputs will be converted to tensors.

@@ -191,15 +281,12 @@ def numpy_metric(group: Optional[Any] = None,
         group: the process group to gather results from. Defaults to all processes (world)
         reduce_op: the reduction operation. Defaults to sum
 
-    Returns:
+    Return:
         the decorated function
 
     """
 
     def decorator_fn(func_to_decorate):
-        return _apply_to_outputs(apply_to_collection, torch.Tensor, _sync_ddp_if_available,
-                                 group=group,
-                                 reduce_op=reduce_op)(_numpy_metric_conversion(func_to_decorate))
+        return sync_ddp(group=group, reduce_op=reduce_op)(_numpy_metric_conversion(func_to_decorate))
 
     return decorator_fn

@@ -208,7 +295,6 @@ def tensor_metric(group: Optional[Any] = None,
                   reduce_op: Optional[torch.distributed.ReduceOp] = None) -> Callable:
     """
     This decorator shall be used on all function metrics working on tensors.
 
     It handles the argument conversion and DDP reduction for metrics working on tensors.
     All inputs and outputs of the decorated function will be converted to tensors.
     In DDP Training all output tensors will be reduced according to the given rules.

@@ -217,14 +303,34 @@ def tensor_metric(group: Optional[Any] = None,
         group: the process group to gather results from. Defaults to all processes (world)
         reduce_op: the reduction operation. Defaults to sum
 
-    Returns:
+    Return:
         the decorated function
 
     """
 
     def decorator_fn(func_to_decorate):
-        return _apply_to_outputs(apply_to_collection, torch.Tensor, _sync_ddp_if_available,
-                                 group=group,
-                                 reduce_op=reduce_op)(_tensor_metric_conversion(func_to_decorate))
+        return sync_ddp(group=group, reduce_op=reduce_op)(_tensor_metric_conversion(func_to_decorate))
+
+    return decorator_fn
+
+
+def tensor_collection_metric(group: Optional[Any] = None,
+                             reduce_op: Optional[torch.distributed.ReduceOp] = None) -> Callable:
+    """
+    This decorator shall be used on all function metrics working on tensors and returning collections
+    that cannot be converted to tensors.
+
+    It handles the argument conversion and DDP reduction for metrics working on tensors.
+    All inputs and outputs of the decorated function will be converted to tensors.
+    In DDP Training all output tensors will be reduced according to the given rules.
+
+    Args:
+        group: the process group to gather results from. Defaults to all processes (world)
+        reduce_op: the reduction operation. Defaults to sum
+
+    Return:
+        the decorated function
+    """
+
+    def decorator_fn(func_to_decorate):
+        return sync_ddp(group=group, reduce_op=reduce_op)(_tensor_collection_metric_conversion(func_to_decorate))
+
     return decorator_fn
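A minimal sketch of how the public decorators from this file compose; the two metric functions below are hypothetical examples, not part of the commit. tensor_metric converts inputs and outputs to tensors and syncs the result across DDP processes when one is initialized, while numpy_metric additionally hands the wrapped function numpy arrays:

import numpy as np
import torch
from pytorch_lightning.metrics.converters import numpy_metric, tensor_metric


@tensor_metric()
def rmse(pred: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
    # all inputs arrive as tensors; the output is reduced across processes in DDP
    return torch.sqrt(torch.mean((pred - target) ** 2))


@numpy_metric()
def mae(pred: np.ndarray, target: np.ndarray) -> np.ndarray:
    # inputs are converted to numpy; the return value is converted back to a tensor
    return np.abs(pred - target).mean()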
@ -0,0 +1,693 @@
|
||||||
|
from collections import Sequence
|
||||||
|
from functools import wraps
|
||||||
|
from typing import Optional, Tuple, Callable
|
||||||
|
|
||||||
|
import torch
|
||||||
|
|
||||||
|
from pytorch_lightning.metrics.functional.reduction import reduce
|
||||||
|
|
||||||
|
|
||||||
|
def to_onehot(
|
||||||
|
tensor: torch.Tensor,
|
||||||
|
n_classes: Optional[int] = None,
|
||||||
|
) -> torch.Tensor:
|
||||||
|
"""
|
||||||
|
Converts a dense label tensor to one-hot format
|
||||||
|
|
||||||
|
Args:
|
||||||
|
tensor: dense label tensor, with shape [N, d1, d2, ...]
|
||||||
|
|
||||||
|
n_classes: number of classes C
|
||||||
|
|
||||||
|
Output:
|
||||||
|
A sparse label tensor with shape [N, C, d1, d2, ...]
|
||||||
|
"""
|
||||||
|
if n_classes is None:
|
||||||
|
n_classes = int(tensor.max().detach().item() + 1)
|
||||||
|
dtype, device, shape = tensor.dtype, tensor.device, tensor.shape
|
||||||
|
tensor_onehot = torch.zeros(shape[0], n_classes, *shape[1:],
|
||||||
|
dtype=dtype, device=device)
|
||||||
|
index = tensor.long().unsqueeze(1).expand_as(tensor_onehot)
|
||||||
|
return tensor_onehot.scatter_(1, index, 1.0)
|
||||||
|
|
||||||
|
|
||||||
|
def to_categorical(tensor: torch.Tensor, argmax_dim: int = 1) -> torch.Tensor:
|
||||||
|
"""
|
||||||
|
Converts a tensor of probabilities to a dense label tensor
|
||||||
|
|
||||||
|
Args:
|
||||||
|
tensor: probabilities to get the categorical label [N, d1, d2, ...]
|
||||||
|
argmax_dim: dimension to apply (default: 1)
|
||||||
|
|
||||||
|
Return:
|
||||||
|
A tensor with categorical labels [N, d2, ...]
|
||||||
|
"""
|
||||||
|
return torch.argmax(tensor, dim=argmax_dim)
|
||||||
|
|
||||||
|
|
||||||
|
def get_num_classes(
|
||||||
|
pred: torch.Tensor,
|
||||||
|
target: torch.Tensor,
|
||||||
|
num_classes: Optional[int],
|
||||||
|
) -> int:
|
||||||
|
"""
|
||||||
|
Returns the number of classes for a given prediction and target tensor.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
pred: predicted values
|
||||||
|
target: true labels
|
||||||
|
num_classes: number of classes if known (default: None)
|
||||||
|
|
||||||
|
Return:
|
||||||
|
An integer that represents the number of classes.
|
||||||
|
"""
|
||||||
|
if num_classes is None:
|
||||||
|
if pred.ndim > target.ndim:
|
||||||
|
num_classes = pred.size(1)
|
||||||
|
else:
|
||||||
|
num_classes = int(target.max().detach().item() + 1)
|
||||||
|
return num_classes
|
||||||
|
|
||||||
|
|
||||||
|
def stat_scores(
|
||||||
|
pred: torch.Tensor,
|
||||||
|
target: torch.Tensor,
|
||||||
|
class_index: int, argmax_dim: int = 1,
|
||||||
|
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
|
||||||
|
"""
|
||||||
|
Calculates the number of true positive, falsepositivee, true negative
|
||||||
|
and false negative for a specific class
|
||||||
|
|
||||||
|
Args:
|
||||||
|
pred: prediction tensor
|
||||||
|
|
||||||
|
target: target tensor
|
||||||
|
|
||||||
|
class_index: class to calculate over
|
||||||
|
|
||||||
|
argmax_dim: if pred is a tensor of probabilities, this indicates the
|
||||||
|
axis the argmax transformation will be applied over
|
||||||
|
|
||||||
|
Return:
|
||||||
|
Tensors in the following order: True Positive, False Positive, True Negative, False Negative
|
||||||
|
|
||||||
|
"""
|
||||||
|
if pred.ndim == target.ndim + 1:
|
||||||
|
pred = to_categorical(pred, argmax_dim=argmax_dim)
|
||||||
|
|
||||||
|
tp = ((pred == class_index) * (target == class_index)).to(torch.long).sum()
|
||||||
|
fp = ((pred == class_index) * (target != class_index)).to(torch.long).sum()
|
||||||
|
tn = ((pred != class_index) * (target != class_index)).to(torch.long).sum()
|
||||||
|
fn = ((pred != class_index) * (target == class_index)).to(torch.long).sum()
|
||||||
|
|
||||||
|
return tp, fp, tn, fn
|
||||||
|
|
||||||
|
|
||||||
|
def stat_scores_multiple_classes(
|
||||||
|
pred: torch.Tensor,
|
||||||
|
target: torch.Tensor,
|
||||||
|
num_classes: Optional[int] = None,
|
||||||
|
argmax_dim: int = 1,
|
||||||
|
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
|
||||||
|
"""
|
||||||
|
Calls the stat_scores function iteratively for all classes, thus
|
||||||
|
calculating the number of true postive, false postive, true negative
|
||||||
|
and false negative for each class
|
||||||
|
|
||||||
|
Args:
|
||||||
|
pred: prediction tensor
|
||||||
|
target: target tensor
|
||||||
|
class_index: class to calculate over
|
||||||
|
argmax_dim: if pred is a tensor of probabilities, this indicates the
|
||||||
|
axis the argmax transformation will be applied over
|
||||||
|
|
||||||
|
Return:
|
||||||
|
Returns tensors for: tp, fp, tn, fn
|
||||||
|
|
||||||
|
"""
|
||||||
|
num_classes = get_num_classes(pred=pred, target=target,
|
||||||
|
num_classes=num_classes)
|
||||||
|
|
||||||
|
if pred.ndim == target.ndim + 1:
|
||||||
|
pred = to_categorical(pred, argmax_dim=argmax_dim)
|
||||||
|
|
||||||
|
tps = torch.zeros((num_classes,), device=pred.device)
|
||||||
|
fps = torch.zeros((num_classes,), device=pred.device)
|
||||||
|
tns = torch.zeros((num_classes,), device=pred.device)
|
||||||
|
fns = torch.zeros((num_classes,), device=pred.device)
|
||||||
|
|
||||||
|
for c in range(num_classes):
|
||||||
|
tps[c], fps[c], tns[c], fns[c] = stat_scores(pred=pred, target=target,
|
||||||
|
class_index=c)
|
||||||
|
|
||||||
|
return tps, fps, tns, fns
|
||||||
|
|
||||||
|
|
||||||
|
def accuracy(
|
||||||
|
pred: torch.Tensor,
|
||||||
|
target: torch.Tensor,
|
||||||
|
num_classes: Optional[int] = None,
|
||||||
|
reduction='elementwise_mean',
|
||||||
|
) -> torch.Tensor:
|
||||||
|
"""
|
||||||
|
Computes the accuracy classification score
|
||||||
|
|
||||||
|
Args:
|
||||||
|
pred: predicted labels
|
||||||
|
target: ground truth labels
|
||||||
|
num_classes: number of classes
|
||||||
|
reduction: a method for reducing accuracies over labels (default: takes the mean)
|
||||||
|
Available reduction methods:
|
||||||
|
|
||||||
|
- elementwise_mean: takes the mean
|
||||||
|
- none: pass array
|
||||||
|
- sum: add elements
|
||||||
|
|
||||||
|
Return:
|
||||||
|
A Tensor with the classification score.
|
||||||
|
"""
|
||||||
|
tps, fps, tns, fns = stat_scores_multiple_classes(pred=pred, target=target,
|
||||||
|
num_classes=num_classes)
|
||||||
|
|
||||||
|
if not (target > 0).any() and num_classes is None:
|
||||||
|
raise RuntimeError("cannot infer num_classes when target is all zero")
|
||||||
|
|
||||||
|
accuracies = (tps + tns) / (tps + tns + fps + fns)
|
||||||
|
|
||||||
|
return reduce(accuracies, reduction=reduction)
|
||||||
|
|
||||||
|
|
||||||
|
def confusion_matrix(
|
||||||
|
pred: torch.Tensor,
|
||||||
|
target: torch.Tensor,
|
||||||
|
normalize: bool = False,
|
||||||
|
) -> torch.Tensor:
|
||||||
|
"""
|
||||||
|
Computes the confusion matrix C where each entry C_{i,j} is the number of observations
|
||||||
|
in group i that were predicted in group j.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
pred: estimated targets
|
||||||
|
target: ground truth labels
|
||||||
|
normalize: normalizes confusion matrix
|
||||||
|
|
||||||
|
Return:
|
||||||
|
Tensor, confusion matrix C [num_classes, num_classes ]
|
||||||
|
"""
|
||||||
|
num_classes = get_num_classes(pred, target, None)
|
||||||
|
|
||||||
|
d = target.size(-1)
|
||||||
|
batch_vec = torch.arange(target.size(-1))
|
||||||
|
# this will account for multilabel
|
||||||
|
unique_labels = batch_vec * num_classes ** 2 + target.view(-1) * num_classes + pred.view(-1)
|
||||||
|
|
||||||
|
bins = torch.bincount(unique_labels, minlength=d * num_classes ** 2)
|
||||||
|
cm = bins.reshape(d, num_classes, num_classes).squeeze().float()
|
||||||
|
|
||||||
|
if normalize:
|
||||||
|
cm = cm / cm.sum(-1)
|
||||||
|
|
||||||
|
return cm
|
||||||
|
|
||||||
|
|
||||||
|
def precision_recall(
|
||||||
|
pred: torch.Tensor,
|
||||||
|
target: torch.Tensor,
|
||||||
|
num_classes: Optional[int] = None,
|
||||||
|
reduction: str = 'elementwise_mean',
|
||||||
|
) -> Tuple[torch.Tensor, torch.Tensor]:
|
||||||
|
"""
|
||||||
|
Computes precision and recall for different thresholds
|
||||||
|
|
||||||
|
Args:
|
||||||
|
pred: estimated probabilities
|
||||||
|
target: ground-truth labels
|
||||||
|
num_classes: number of classes
|
||||||
|
reduction: method for reducing precision-recall values (default: takes the mean)
|
||||||
|
Available reduction methods:
|
||||||
|
|
||||||
|
- elementwise_mean: takes the mean
|
||||||
|
- none: pass array
|
||||||
|
- sum: add elements
|
||||||
|
|
||||||
|
Return:
|
||||||
|
Tensor with precision and recall
|
||||||
|
"""
|
||||||
|
tps, fps, tns, fns = stat_scores_multiple_classes(pred=pred,
|
||||||
|
target=target,
|
||||||
|
num_classes=num_classes)
|
||||||
|
|
||||||
|
tps = tps.to(torch.float)
|
||||||
|
fps = fps.to(torch.float)
|
||||||
|
fns = fns.to(torch.float)
|
||||||
|
|
||||||
|
precision = tps / (tps + fps)
|
||||||
|
recall = tps / (tps + fns)
|
||||||
|
|
||||||
|
precision = reduce(precision, reduction=reduction)
|
||||||
|
recall = reduce(recall, reduction=reduction)
|
||||||
|
return precision, recall
|
||||||
|
|
||||||
|
|
||||||
|
def precision(
        pred: torch.Tensor,
        target: torch.Tensor,
        num_classes: Optional[int] = None,
        reduction: str = 'elementwise_mean',
) -> torch.Tensor:
    """
    Computes precision score.

    Args:
        pred: estimated probabilities
        target: ground-truth labels
        num_classes: number of classes
        reduction: method for reducing precision values (default: takes the mean)
            Available reduction methods:

            - elementwise_mean: takes the mean
            - none: pass array
            - sum: add elements

    Return:
        Tensor with precision.
    """
    return precision_recall(pred=pred, target=target,
                            num_classes=num_classes, reduction=reduction)[0]

def recall(
        pred: torch.Tensor,
        target: torch.Tensor,
        num_classes: Optional[int] = None,
        reduction: str = 'elementwise_mean',
) -> torch.Tensor:
    """
    Computes recall score.

    Args:
        pred: estimated probabilities
        target: ground-truth labels
        num_classes: number of classes
        reduction: method for reducing recall values (default: takes the mean)
            Available reduction methods:

            - elementwise_mean: takes the mean
            - none: pass array
            - sum: add elements

    Return:
        Tensor with recall.
    """
    return precision_recall(pred=pred, target=target,
                            num_classes=num_classes, reduction=reduction)[1]

def fbeta_score(
        pred: torch.Tensor,
        target: torch.Tensor,
        beta: float,
        num_classes: Optional[int] = None,
        reduction: str = 'elementwise_mean',
) -> torch.Tensor:
    """
    Computes the F-beta score which is a weighted harmonic mean of precision and recall.
    It ranges between 1 and 0, where 1 is perfect and the worst value is 0.

    Args:
        pred: estimated probabilities
        target: ground-truth labels
        beta: weights recall when combining the score.
            beta < 1: more weight to precision.
            beta > 1: more weight to recall.
            beta = 0: only precision.
            beta -> inf: only recall.
        num_classes: number of classes
        reduction: method for reducing F-score (default: takes the mean)
            Available reduction methods:

            - elementwise_mean: takes the mean
            - none: pass array
            - sum: add elements.

    Return:
        Tensor with the value of F-score. It is a value between 0-1.
    """
    prec, rec = precision_recall(pred=pred, target=target,
                                 num_classes=num_classes,
                                 reduction='none')

    nom = (1 + beta ** 2) * prec * rec
    denom = (beta ** 2) * prec + rec
    fbeta = nom / denom

    return reduce(fbeta, reduction=reduction)

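Spelled out, the score above is F_beta = (1 + beta^2) * precision * recall / (beta^2 * precision + recall). A hedged sketch, with values mirroring the parametrized test further down (precision == recall == 0.5 per class):

import torch
from pytorch_lightning.metrics.functional.classification import fbeta_score

pred = torch.tensor([1., 0., 1., 0.])
target = torch.tensor([0., 1., 1., 0.])
score = fbeta_score(pred, target, beta=0.5, reduction='none')
# with P == R == 0.5, any beta gives tensor([0.5, 0.5])
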
def f1_score(
        pred: torch.Tensor,
        target: torch.Tensor,
        num_classes: Optional[int] = None,
        reduction: str = 'elementwise_mean',
) -> torch.Tensor:
    """
    Computes the F1-score (a.k.a. F-measure), the harmonic mean of precision and recall.

    Args:
        pred: estimated probabilities
        target: ground-truth labels
        num_classes: number of classes
        reduction: method for reducing F1-score (default: takes the mean)
            Available reduction methods:

            - elementwise_mean: takes the mean
            - none: pass array
            - sum: add elements.

    Return:
        Tensor containing F1-score
    """
    return fbeta_score(pred=pred, target=target, beta=1.,
                       num_classes=num_classes, reduction=reduction)

def _binary_clf_curve(
        pred: torch.Tensor,
        target: torch.Tensor,
        sample_weight: Optional[Sequence] = None,
        pos_label: int = 1,
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
    """
    Adapted from https://github.com/scikit-learn/scikit-learn/blob/master/sklearn/metrics/_ranking.py
    """
    if sample_weight is not None and not isinstance(sample_weight, torch.Tensor):
        sample_weight = torch.tensor(sample_weight, device=pred.device, dtype=torch.float)

    # remove class dimension if necessary
    if pred.ndim > target.ndim:
        pred = pred[:, 0]
    desc_score_indices = torch.argsort(pred, descending=True)

    pred = pred[desc_score_indices]
    target = target[desc_score_indices]

    if sample_weight is not None:
        weight = sample_weight[desc_score_indices]
    else:
        weight = 1.

    # pred typically has many tied values. Here we extract
    # the indices associated with the distinct values. We also
    # concatenate a value for the end of the curve.
    distinct_value_indices = torch.where(pred[1:] - pred[:-1])[0]
    threshold_idxs = torch.cat([distinct_value_indices,
                                torch.tensor([target.size(0) - 1])])

    target = (target == pos_label).to(torch.long)
    tps = torch.cumsum(target * weight, dim=0)[threshold_idxs]

    if sample_weight is not None:
        # express fps as a cumsum to ensure fps is increasing even in
        # the presence of floating point errors
        fps = torch.cumsum((1 - target) * weight, dim=0)[threshold_idxs]
    else:
        fps = 1 + threshold_idxs - tps

    return fps, tps, pred[threshold_idxs]

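A small sketch of the tie handling described in the comments above (plain torch, illustrative only):

import torch

pred_sorted = torch.tensor([0.8, 0.8, 0.4, 0.4, 0.1])
# consecutive differences are non-zero only where the sorted scores change,
# so each distinct score contributes exactly one threshold index
distinct_value_indices = torch.where(pred_sorted[1:] - pred_sorted[:-1])[0]  # tensor([1, 3])
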
def roc(
        pred: torch.Tensor,
        target: torch.Tensor,
        sample_weight: Optional[Sequence] = None,
        pos_label: int = 1,
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
    """
    Computes the Receiver Operating Characteristic (ROC). It assumes the classifier is binary.

    Args:
        pred: estimated probabilities
        target: ground-truth labels
        sample_weight: sample weights
        pos_label: the label for the positive class (default: 1)

    Return:
        [Tensor, Tensor, Tensor]: false-positive rate (fpr), true-positive rate (tpr), thresholds
    """
    fps, tps, thresholds = _binary_clf_curve(pred=pred, target=target,
                                             sample_weight=sample_weight,
                                             pos_label=pos_label)

    # Add an extra threshold position
    # to make sure that the curve starts at (0, 0)
    tps = torch.cat([torch.zeros(1, dtype=tps.dtype, device=tps.device), tps])
    fps = torch.cat([torch.zeros(1, dtype=fps.dtype, device=fps.device), fps])
    thresholds = torch.cat([thresholds[0][None] + 1, thresholds])

    if fps[-1] <= 0:
        raise ValueError("No negative samples in targets, false positive value should be meaningless")

    fpr = fps / fps[-1]

    if tps[-1] <= 0:
        raise ValueError("No positive samples in targets, true positive value should be meaningless")

    tpr = tps / tps[-1]

    return fpr, tpr, thresholds

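A minimal sketch mirroring the roc test cases further down; note the extra (0, 0) point prepended by the function:

import torch
from pytorch_lightning.metrics.functional.classification import roc

fpr, tpr, thresholds = roc(torch.tensor([0., 1.]), torch.tensor([0, 1]))
# fpr == tensor([0., 0., 1.]), tpr == tensor([0., 1., 1.])
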
def multiclass_roc(
        pred: torch.Tensor,
        target: torch.Tensor,
        sample_weight: Optional[Sequence] = None,
        num_classes: Optional[int] = None,
) -> Tuple[Tuple[torch.Tensor, torch.Tensor, torch.Tensor]]:
    """
    Computes the Receiver Operating Characteristic (ROC) for multiclass predictors,
    one ROC curve per class.

    Args:
        pred: estimated probabilities
        target: ground-truth labels
        sample_weight: sample weights
        num_classes: number of classes (default: None, computes automatically from data)

    Return:
        a tuple with one (false-positive rate, true-positive rate, thresholds) triple per class
    """
    num_classes = get_num_classes(pred, target, num_classes)

    class_roc_vals = []
    for c in range(num_classes):
        pred_c = pred[:, c]

        class_roc_vals.append(roc(pred=pred_c, target=target,
                                  sample_weight=sample_weight, pos_label=c))

    return tuple(class_roc_vals)

def precision_recall_curve(
        pred: torch.Tensor,
        target: torch.Tensor,
        sample_weight: Optional[Sequence] = None,
        pos_label: int = 1,
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
    """
    Computes precision-recall pairs for different thresholds.

    Args:
        pred: estimated probabilities
        target: ground-truth labels
        sample_weight: sample weights
        pos_label: the label for the positive class (default: 1)

    Return:
        [Tensor, Tensor, Tensor]: precision, recall, thresholds
    """
    fps, tps, thresholds = _binary_clf_curve(pred=pred, target=target,
                                             sample_weight=sample_weight,
                                             pos_label=pos_label)

    precision = tps / (tps + fps)
    recall = tps / tps[-1]

    # stop when full recall attained
    # and reverse the outputs so recall is decreasing
    last_ind = torch.where(tps == tps[-1])[0][0]
    sl = slice(0, last_ind.item() + 1)

    # need to call reversed explicitly, since including it in the slice would
    # introduce negative strides that are not yet supported in pytorch
    precision = torch.cat([reversed(precision[sl]),
                           torch.ones(1, dtype=precision.dtype,
                                      device=precision.device)])

    recall = torch.cat([reversed(recall[sl]),
                        torch.zeros(1, dtype=recall.dtype,
                                    device=recall.device)])

    thresholds = torch.tensor(reversed(thresholds[sl]))

    return precision, recall, thresholds

def multiclass_precision_recall_curve(
        pred: torch.Tensor,
        target: torch.Tensor,
        sample_weight: Optional[Sequence] = None,
        num_classes: Optional[int] = None,
) -> Tuple[Tuple[torch.Tensor, torch.Tensor, torch.Tensor]]:
    """
    Computes precision-recall pairs for different thresholds given multiclass scores.

    Args:
        pred: estimated probabilities
        target: ground-truth labels
        sample_weight: sample weights
        num_classes: number of classes

    Return:
        a tuple with one (precision, recall, thresholds) triple per class
    """
    num_classes = get_num_classes(pred, target, num_classes)

    class_pr_vals = []
    for c in range(num_classes):
        pred_c = pred[:, c]

        class_pr_vals.append(precision_recall_curve(
            pred=pred_c,
            target=target,
            sample_weight=sample_weight, pos_label=c))

    return tuple(class_pr_vals)

def auc(x: torch.Tensor, y: torch.Tensor, reorder: bool = True):
    """
    Computes Area Under the Curve (AUC) using the trapezoidal rule

    Args:
        x: x-coordinates
        y: y-coordinates
        reorder: reorder coordinates, so they are increasing

    Return:
        AUC score (float)
    """
    direction = 1.

    if reorder:
        # can't use lexsort here since it is not implemented for torch
        order = torch.argsort(x)
        x, y = x[order], y[order]
    else:
        dx = x[1:] - x[:-1]
        if (dx < 0).any():
            if (dx <= 0).all():
                direction = -1.
            else:
                raise ValueError("Reordering is not turned on, and "
                                 "the x array is not increasing: %s" % x)

    return direction * torch.trapz(y, x)

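Since the result is a plain trapezoidal integral, a two-point curve illustrates it (values taken from the auc tests further down):

import torch
from pytorch_lightning.metrics.functional.classification import auc

area = auc(torch.tensor([0., 1.]), torch.tensor([0., 1.]))  # triangle under y = x
# area == tensor(0.5); with reorder=True the x-coordinates may also be passed unsorted
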
def auc_decorator(reorder: bool = True) -> Callable:
    def wrapper(func_to_decorate: Callable) -> Callable:
        @wraps(func_to_decorate)
        def new_func(*args, **kwargs) -> torch.Tensor:
            x, y = func_to_decorate(*args, **kwargs)[:2]

            return auc(x, y, reorder=reorder)

        return new_func

    return wrapper

def multiclass_auc_decorator(reorder: bool = True) -> Callable:
    def wrapper(func_to_decorate: Callable) -> Callable:
        def new_func(*args, **kwargs) -> torch.Tensor:
            results = []
            for class_result in func_to_decorate(*args, **kwargs):
                x, y = class_result[:2]
                results.append(auc(x, y, reorder=reorder))

            # auc returns zero-dimensional tensors, so stack them into one tensor
            return torch.stack(results)

        return new_func

    return wrapper

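auroc below shows how these decorators are meant to be used; the same pattern applies to any curve function whose first two outputs are x- and y-coordinates. A hedged sketch:

# Sketch: wrap a binary curve function so it directly returns the area under its curve.
@auc_decorator(reorder=True)
def _roc_auc(pred, target):
    return roc(pred, target)  # (fpr, tpr, thresholds); only fpr/tpr feed into auc
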
def auroc(
        pred: torch.Tensor,
        target: torch.Tensor,
        sample_weight: Optional[Sequence] = None,
        pos_label: int = 1,
) -> torch.Tensor:
    """
    Computes the Area Under the Receiver Operating Characteristic Curve (ROC AUC) from prediction scores

    Args:
        pred: estimated probabilities
        target: ground-truth labels
        sample_weight: sample weights
        pos_label: the label for the positive class (default: 1)
    """

    @auc_decorator(reorder=True)
    def _auroc(pred, target, sample_weight, pos_label):
        return roc(pred, target, sample_weight, pos_label)

    return _auroc(pred=pred, target=target, sample_weight=sample_weight, pos_label=pos_label)

def average_precision(
        pred: torch.Tensor,
        target: torch.Tensor,
        sample_weight: Optional[Sequence] = None,
        pos_label: int = 1,
) -> torch.Tensor:
    precision, recall, _ = precision_recall_curve(pred=pred, target=target,
                                                  sample_weight=sample_weight,
                                                  pos_label=pos_label)
    # Return the step function integral
    # The following works because the last entry of precision is
    # guaranteed to be 1, as returned by precision_recall_curve
    return -torch.sum((recall[1:] - recall[:-1]) * precision[:-1])

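In other words, average precision is the step-function integral AP = sum_k (R_k - R_{k+1}) * P_k over the decreasing-recall curve. The constant-predictor case from the tests makes this concrete:

import torch
from pytorch_lightning.metrics.functional.classification import average_precision

target = torch.zeros(100)
target[::4] = 1            # 25% positives
pred = torch.ones(100)     # constant score, hence a single threshold
assert average_precision(pred, target).item() == 0.25  # AP equals the positive fraction
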
def dice_score(
        pred: torch.Tensor,
        target: torch.Tensor,
        bg: bool = False,
        nan_score: float = 0.0,
        no_fg_score: float = 0.0,
        reduction: str = 'elementwise_mean',
) -> torch.Tensor:
    n_classes = pred.shape[1]
    # when bg=False, skip the background class at index 0
    bg = (1 - int(bool(bg)))
    scores = torch.zeros(n_classes - bg, device=pred.device, dtype=torch.float32)
    for i in range(bg, n_classes):
        if not (target == i).any():
            # no foreground class
            scores[i - bg] += no_fg_score
            continue

        tp, fp, tn, fn = stat_scores(pred=pred, target=target, class_index=i)

        denom = (2 * tp + fp + fn).to(torch.float)

        if torch.isclose(denom, torch.zeros_like(denom)).any():
            # nan result
            score_cls = nan_score
        else:
            score_cls = (2 * tp).to(torch.float) / denom

        scores[i - bg] += score_cls
    return reduce(scores, reduction=reduction)

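Per class, the score above is the Dice coefficient 2*TP / (2*TP + FP + FN). A sketch on the perfect-overlap case from the parametrized test further down:

import torch
from pytorch_lightning.metrics.functional.classification import dice_score

pred = torch.tensor([[0, 0], [1, 1]])    # class scores, one row per sample
target = torch.tensor([[0, 0], [1, 1]])
assert dice_score(pred, target) == 1.    # perfect overlap of the foreground class
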
@ -0,0 +1,24 @@
import torch


def reduce(to_reduce: torch.Tensor, reduction: str) -> torch.Tensor:
    """
    Reduces a given tensor by a given reduction method

    Args:
        to_reduce: the tensor, which shall be reduced
        reduction: a string specifying the reduction method ('elementwise_mean', 'none', 'sum')

    Return:
        reduced Tensor

    Raise:
        ValueError if an invalid reduction parameter was given
    """
    if reduction == 'elementwise_mean':
        return torch.mean(to_reduce)
    if reduction == 'none':
        return to_reduce
    if reduction == 'sum':
        return torch.sum(to_reduce)
    raise ValueError('Reduction parameter unknown.')

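A quick sketch of the three reduction modes (mirrors tests/metrics/functional/test_reduction.py further down):

import torch
from pytorch_lightning.metrics.functional.reduction import reduce

t = torch.tensor([1., 2., 3.])
reduce(t, 'elementwise_mean')  # tensor(2.)
reduce(t, 'sum')               # tensor(6.)
reduce(t, 'none')              # tensor([1., 2., 3.]), unchanged
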
@ -3,16 +3,16 @@ from typing import Any, Optional
import torch
import torch.distributed

from pytorch_lightning.metrics.converters import (
    tensor_metric, numpy_metric, tensor_collection_metric)
from pytorch_lightning.utilities.apply_func import apply_to_collection
from pytorch_lightning.utilities.device_dtype_mixin import DeviceDtypeModuleMixin

__all__ = ['Metric', 'TensorMetric', 'NumpyMetric']


class Metric(DeviceDtypeModuleMixin, torch.nn.Module, ABC):
    """
    Abstract base class for metric implementation.

@ -20,6 +20,7 @@ class Metric(ABC, DeviceDtypeModuleMixin, Module):
    1. Return multiple Outputs
    2. Handle their own DDP sync
    """

    def __init__(self, name: str):
        """
        Args:

@ -49,6 +50,7 @@ class TensorMetric(Metric):
    All inputs and outputs will be cast to tensors if necessary.
    Already handles DDP sync and input/output conversions.
    """

    def __init__(self, name: str,
                 reduce_group: Optional[Any] = None,
                 reduce_op: Optional[Any] = None):

@ -73,6 +75,47 @@ class TensorMetric(Metric):
                                   _to_device_dtype)


class TensorCollectionMetric(Metric):
    """
    Base class for metric implementation operating directly on tensors.
    All inputs will be cast to tensors if necessary. Outputs won't be cast.
    Already handles DDP sync and input conversions.

    This class differs from :class:`TensorMetric`, as it assumes all outputs to
    be collections of tensors and does not explicitly convert them. This is
    necessary, since some collections (like for ROC, Precision-Recall Curve etc.)
    cannot be converted to tensors at the highest level.
    All numpy arrays and numbers occurring in these outputs will still be converted.

    Use this class as a baseclass whenever you want to ensure inputs are
    tensors and outputs cannot be converted to tensors automatically.
    """

    def __init__(self, name: str,
                 reduce_group: Optional[Any] = None,
                 reduce_op: Optional[Any] = None):
        """
        Args:
            name: the metric's name
            reduce_group: the process group for DDP reduces (only needed for DDP training).
                Defaults to all processes (world)
            reduce_op: the operation to perform during reduction within DDP (only needed for DDP training).
                Defaults to sum.
        """
        super().__init__(name)
        self._orig_call = tensor_collection_metric(group=reduce_group,
                                                   reduce_op=reduce_op)(super().__call__)

    def __call__(self, *args, **kwargs) -> torch.Tensor:
        def _to_device_dtype(x: torch.Tensor) -> torch.Tensor:
            return x.to(device=self.device, dtype=self.dtype, non_blocking=True)

        return apply_to_collection(self._orig_call(*args, **kwargs), torch.Tensor,
                                   _to_device_dtype)


class NumpyMetric(Metric):
    """
    Base class for metric implementation operating on numpy arrays.

@ -80,6 +123,7 @@ class NumpyMetric(Metric):
    be cast to tensors if necessary.
    Already handles DDP sync and input/output conversions.
    """

    def __init__(self, name: str,
                 reduce_group: Optional[Any] = None,
                 reduce_op: Optional[Any] = None):

@ -1,130 +0,0 @@
import numbers
from typing import Union, Any, Optional

import numpy as np
import torch
from torch.utils.data._utils.collate import default_convert

from pytorch_lightning.utilities.apply_func import apply_to_collection


def _apply_to_inputs(func_to_apply, *dec_args, **dec_kwargs):
    def decorator_fn(func_to_decorate):
        def new_func(*args, **kwargs):
            args = func_to_apply(args, *dec_args, **dec_kwargs)
            kwargs = func_to_apply(kwargs, *dec_args, **dec_kwargs)
            return func_to_decorate(*args, **kwargs)

        return new_func

    return decorator_fn


def _apply_to_outputs(func_to_apply, *dec_args, **dec_kwargs):
    def decorator_fn(function_to_decorate):
        def new_func(*args, **kwargs):
            result = function_to_decorate(*args, **kwargs)
            return func_to_apply(result, *dec_args, **dec_kwargs)

        return new_func

    return decorator_fn


def _convert_to_tensor(data: Any) -> Any:
    """
    Maps all kinds of collections and numbers to tensors

    Args:
        data: the data to convert to tensor

    Returns:
        the converted data
    """
    if isinstance(data, numbers.Number):
        return torch.tensor([data])
    else:
        return default_convert(data)


def _convert_to_numpy(data: Union[torch.Tensor, np.ndarray, numbers.Number]) -> np.ndarray:
    """
    Converts all tensors and numpy arrays to numpy arrays

    Args:
        data: the tensor or array to convert to numpy

    Returns:
        the resulting numpy array
    """
    if isinstance(data, torch.Tensor):
        return data.cpu().detach().numpy()
    elif isinstance(data, numbers.Number):
        return np.array([data])
    return data


def _numpy_metric_conversion(func_to_decorate):
    # Applies collection conversion from tensor to numpy to all inputs
    # we need to include numpy arrays here, since otherwise they will also be treated as sequences
    func_convert_inputs = _apply_to_inputs(
        apply_to_collection, (torch.Tensor, np.ndarray, numbers.Number), _convert_to_numpy)(func_to_decorate)
    # converts all inputs back to tensors (device doesn't matter here, since this is handled by BaseMetric)
    func_convert_in_out = _apply_to_outputs(_convert_to_tensor)(func_convert_inputs)
    return func_convert_in_out


def _tensor_metric_conversion(func_to_decorate):
    # Converts all inputs to tensor if possible
    func_convert_inputs = _apply_to_inputs(_convert_to_tensor)(func_to_decorate)
    # convert all outputs to tensor if possible
    return _apply_to_outputs(_convert_to_tensor)(func_convert_inputs)


def _sync_ddp(result: torch.Tensor,
              group: Any = torch.distributed.group.WORLD,
              reduce_op: torch.distributed.ReduceOp = torch.distributed.ReduceOp.SUM,
              ) -> torch.Tensor:
    """
    Function to reduce the tensors from several ddp processes to one master process

    Args:
        result: the value to sync and reduce (typically tensor or number)
        group: the process group to gather results from. Defaults to all processes (world)
        reduce_op: the reduction operation. Defaults to sum

    Returns:
        reduced value
    """
    if torch.distributed.is_available() and torch.distributed.is_initialized():
        # sync all processes before reduction
        torch.distributed.barrier(group=group)
        torch.distributed.all_reduce(result, op=reduce_op, group=group,
                                     async_op=False)

    return result


def numpy_metric(group: Any = torch.distributed.group.WORLD,
                 reduce_op: torch.distributed.ReduceOp = torch.distributed.ReduceOp.SUM):
    def decorator_fn(func_to_decorate):
        return _apply_to_outputs(apply_to_collection, torch.Tensor, _sync_ddp,
                                 group=group,
                                 reduce_op=reduce_op)(_numpy_metric_conversion(func_to_decorate))

    return decorator_fn


def tensor_metric(group: Any = torch.distributed.group.WORLD,
                  reduce_op: torch.distributed.ReduceOp = torch.distributed.ReduceOp.SUM):
    def decorator_fn(func_to_decorate):
        return _apply_to_outputs(apply_to_collection, torch.Tensor, _sync_ddp,
                                 group=group,
                                 reduce_op=reduce_op)(_tensor_metric_conversion(func_to_decorate))

    return decorator_fn

@ -0,0 +1,309 @@
import pytest
import torch

from pytorch_lightning import seed_everything
from pytorch_lightning.metrics.functional.classification import (
    to_onehot,
    to_categorical,
    get_num_classes,
    stat_scores,
    stat_scores_multiple_classes,
    accuracy,
    confusion_matrix,
    precision,
    recall,
    fbeta_score,
    f1_score,
    _binary_clf_curve,
    dice_score,
    average_precision,
    auroc,
    precision_recall_curve,
    roc,
    auc,
)


def test_onehot():
    test_tensor = torch.tensor([[0, 1, 2, 3, 4], [5, 6, 7, 8, 9]])
    expected = torch.tensor([
        [
            [1, 0, 0, 0, 0],
            [0, 1, 0, 0, 0],
            [0, 0, 1, 0, 0],
            [0, 0, 0, 1, 0],
            [0, 0, 0, 0, 1],
            [0, 0, 0, 0, 0],
            [0, 0, 0, 0, 0],
            [0, 0, 0, 0, 0],
            [0, 0, 0, 0, 0],
            [0, 0, 0, 0, 0]
        ], [
            [0, 0, 0, 0, 0],
            [0, 0, 0, 0, 0],
            [0, 0, 0, 0, 0],
            [0, 0, 0, 0, 0],
            [0, 0, 0, 0, 0],
            [1, 0, 0, 0, 0],
            [0, 1, 0, 0, 0],
            [0, 0, 1, 0, 0],
            [0, 0, 0, 1, 0],
            [0, 0, 0, 0, 1]
        ]
    ])

    assert test_tensor.shape == (2, 5)
    assert expected.shape == (2, 10, 5)

    onehot_classes = to_onehot(test_tensor, n_classes=10)
    onehot_no_classes = to_onehot(test_tensor)

    assert torch.allclose(onehot_classes, onehot_no_classes)

    assert onehot_classes.shape == expected.shape
    assert onehot_no_classes.shape == expected.shape

    assert torch.allclose(expected.to(onehot_no_classes), onehot_no_classes)
    assert torch.allclose(expected.to(onehot_classes), onehot_classes)


def test_to_categorical():
    test_tensor = torch.tensor([
        [
            [1, 0, 0, 0, 0],
            [0, 1, 0, 0, 0],
            [0, 0, 1, 0, 0],
            [0, 0, 0, 1, 0],
            [0, 0, 0, 0, 1],
            [0, 0, 0, 0, 0],
            [0, 0, 0, 0, 0],
            [0, 0, 0, 0, 0],
            [0, 0, 0, 0, 0],
            [0, 0, 0, 0, 0]
        ], [
            [0, 0, 0, 0, 0],
            [0, 0, 0, 0, 0],
            [0, 0, 0, 0, 0],
            [0, 0, 0, 0, 0],
            [0, 0, 0, 0, 0],
            [1, 0, 0, 0, 0],
            [0, 1, 0, 0, 0],
            [0, 0, 1, 0, 0],
            [0, 0, 0, 1, 0],
            [0, 0, 0, 0, 1]
        ]
    ]).to(torch.float)

    expected = torch.tensor([[0, 1, 2, 3, 4], [5, 6, 7, 8, 9]])
    assert expected.shape == (2, 5)
    assert test_tensor.shape == (2, 10, 5)

    result = to_categorical(test_tensor)

    assert result.shape == expected.shape
    assert torch.allclose(result, expected.to(result.dtype))


@pytest.mark.parametrize(['pred', 'target', 'num_classes', 'expected_num_classes'], [
    pytest.param(torch.rand(32, 10, 28, 28), torch.randint(10, (32, 28, 28)), 10, 10),
    pytest.param(torch.rand(32, 10, 28, 28), torch.randint(10, (32, 28, 28)), None, 10),
    pytest.param(torch.rand(32, 28, 28), torch.randint(10, (32, 28, 28)), None, 10),
])
def test_get_num_classes(pred, target, num_classes, expected_num_classes):
    assert get_num_classes(pred, target, num_classes) == expected_num_classes


@pytest.mark.parametrize(['pred', 'target', 'expected_tp', 'expected_fp', 'expected_tn', 'expected_fn'], [
    pytest.param(torch.tensor([0., 2., 4., 4.]), torch.tensor([0., 4., 3., 4.]), 1, 1, 1, 1),
    pytest.param(to_onehot(torch.tensor([0., 2., 4., 4.])), torch.tensor([0., 4., 3., 4.]), 1, 1, 1, 1)
])
def test_stat_scores(pred, target, expected_tp, expected_fp, expected_tn, expected_fn):
    tp, fp, tn, fn = stat_scores(pred, target, class_index=4)

    assert tp.item() == expected_tp
    assert fp.item() == expected_fp
    assert tn.item() == expected_tn
    assert fn.item() == expected_fn


@pytest.mark.parametrize(['pred', 'target', 'expected_tp', 'expected_fp', 'expected_tn', 'expected_fn'], [
    pytest.param(torch.tensor([0., 2., 4., 4.]), torch.tensor([0., 4., 3., 4.]),
                 [1, 0, 0, 0, 1], [0, 0, 1, 0, 1], [3, 4, 3, 3, 1], [0, 0, 0, 1, 1]),
    pytest.param(to_onehot(torch.tensor([0., 2., 4., 4.])), torch.tensor([0., 4., 3., 4.]),
                 [1, 0, 0, 0, 1], [0, 0, 1, 0, 1], [3, 4, 3, 3, 1], [0, 0, 0, 1, 1])
])
def test_stat_scores_multiclass(pred, target, expected_tp, expected_fp, expected_tn, expected_fn):
    tp, fp, tn, fn = stat_scores_multiple_classes(pred, target)

    assert torch.allclose(torch.tensor(expected_tp).to(tp), tp)
    assert torch.allclose(torch.tensor(expected_fp).to(fp), fp)
    assert torch.allclose(torch.tensor(expected_tn).to(tn), tn)
    assert torch.allclose(torch.tensor(expected_fn).to(fn), fn)


def test_multilabel_accuracy():
    # Dense label indicator matrix format
    y1 = torch.tensor([[0, 1, 1], [1, 0, 1]])
    y2 = torch.tensor([[0, 0, 1], [1, 0, 1]])

    assert torch.allclose(accuracy(y1, y2, reduction='none'), torch.tensor([0.8333333134651184] * 2))
    assert torch.allclose(accuracy(y1, y1, reduction='none'), torch.tensor([1., 1.]))
    assert torch.allclose(accuracy(y2, y2, reduction='none'), torch.tensor([1., 1.]))
    assert torch.allclose(accuracy(y2, torch.logical_not(y2), reduction='none'), torch.tensor([0., 0.]))
    assert torch.allclose(accuracy(y1, torch.logical_not(y1), reduction='none'), torch.tensor([0., 0.]))

    with pytest.raises(RuntimeError):
        accuracy(y2, torch.zeros_like(y2), reduction='none')


def test_confusion_matrix():
    target = (torch.arange(120) % 3).view(-1, 1)
    pred = target.clone()
    cm = confusion_matrix(pred, target, normalize=True)

    assert torch.allclose(cm, torch.tensor([[1., 0., 0.], [0., 1., 0.], [0., 0., 1.]]))

    pred = torch.zeros_like(pred)
    cm = confusion_matrix(pred, target, normalize=True)
    assert torch.allclose(cm, torch.tensor([[1., 0., 0.], [1., 0., 0.], [1., 0., 0.]]))


@pytest.mark.parametrize(['pred', 'target', 'expected_prec', 'expected_rec'], [
    pytest.param(torch.tensor([1., 0., 1., 0.]), torch.tensor([0., 1., 1., 0.]), [0.5, 0.5], [0.5, 0.5]),
    pytest.param(to_onehot(torch.tensor([1., 0., 1., 0.])), torch.tensor([0., 1., 1., 0.]), [0.5, 0.5], [0.5, 0.5])
])
def test_precision_recall(pred, target, expected_prec, expected_rec):
    prec = precision(pred, target, reduction='none')
    rec = recall(pred, target, reduction='none')

    assert torch.allclose(torch.tensor(expected_prec).to(prec), prec)
    assert torch.allclose(torch.tensor(expected_rec).to(rec), rec)


@pytest.mark.parametrize(['pred', 'target', 'beta', 'exp_score'], [
    pytest.param([1., 0., 1., 0.], [0., 1., 1., 0.], 0.5, [0.5, 0.5]),
    pytest.param([1., 0., 1., 0.], [0., 1., 1., 0.], 1, [0.5, 0.5]),
    pytest.param([1., 0., 1., 0.], [0., 1., 1., 0.], 2, [0.5, 0.5]),
])
def test_fbeta_score(pred, target, beta, exp_score):
    score = fbeta_score(torch.tensor(pred), torch.tensor(target), beta, reduction='none')
    assert torch.allclose(score, torch.tensor(exp_score))

    score = fbeta_score(to_onehot(torch.tensor(pred)), torch.tensor(target), beta, reduction='none')
    assert torch.allclose(score, torch.tensor(exp_score))


@pytest.mark.parametrize(['pred', 'target', 'exp_score'], [
    pytest.param([1., 0., 1., 0.], [0., 1., 1., 0.], [0.5, 0.5]),
])
def test_f1_score(pred, target, exp_score):
    score = f1_score(torch.tensor(pred), torch.tensor(target), reduction='none')
    assert torch.allclose(score, torch.tensor(exp_score))

    score = f1_score(to_onehot(torch.tensor(pred)), torch.tensor(target), reduction='none')
    assert torch.allclose(score, torch.tensor(exp_score))


@pytest.mark.parametrize(['sample_weight', 'pos_label', 'exp_shape'], [
    pytest.param(1, 1., 42),
    pytest.param(None, 1., 42),
])
def test_binary_clf_curve(sample_weight, pos_label, exp_shape):
    # TODO: move back the pred and target to test func arguments
    # if you fix the array inside the function, you'd also have to fix the expected shape,
    # because when the array changes, the shape changes with it
    seed_everything(0)
    pred = torch.randint(low=51, high=99, size=(100,), dtype=torch.float) / 100
    target = torch.tensor([0, 1] * 50, dtype=torch.int)
    if sample_weight is not None:
        sample_weight = torch.ones_like(pred) * sample_weight

    fps, tps, thresh = _binary_clf_curve(pred, target, sample_weight, pos_label)

    assert isinstance(tps, torch.Tensor)
    assert isinstance(fps, torch.Tensor)
    assert isinstance(thresh, torch.Tensor)
    assert tps.shape == (exp_shape,)
    assert fps.shape == (exp_shape,)
    assert thresh.shape == (exp_shape,)


@pytest.mark.parametrize(['pred', 'target', 'expected_p', 'expected_r', 'expected_t'], [
    pytest.param([1, 2, 3, 4], [1, 0, 0, 1], [0.5, 1 / 3, 0.5, 1., 1.], [1, 0.5, 0.5, 0.5, 0.], [1, 2, 3, 4])
])
def test_pr_curve(pred, target, expected_p, expected_r, expected_t):
    p, r, t = precision_recall_curve(torch.tensor(pred), torch.tensor(target))
    assert p.size() == r.size()
    assert p.size(0) == t.size(0) + 1

    assert torch.allclose(p, torch.tensor(expected_p).to(p))
    assert torch.allclose(r, torch.tensor(expected_r).to(r))
    assert torch.allclose(t, torch.tensor(expected_t).to(t))


@pytest.mark.parametrize(['pred', 'target', 'expected_tpr', 'expected_fpr'], [
    pytest.param([0, 1], [0, 1], [0, 1, 1], [0, 0, 1]),
    pytest.param([1, 0], [0, 1], [0, 0, 1], [0, 1, 1]),
    pytest.param([1, 1], [1, 0], [0, 1], [0, 1]),
    pytest.param([1, 0], [1, 0], [0, 1, 1], [0, 0, 1]),
    pytest.param([0.5, 0.5], [0, 1], [0, 1], [0, 1]),
])
def test_roc_curve(pred, target, expected_tpr, expected_fpr):
    fpr, tpr, thresh = roc(torch.tensor(pred), torch.tensor(target))

    assert fpr.shape == tpr.shape
    assert fpr.size(0) == thresh.size(0)
    assert torch.allclose(fpr, torch.tensor(expected_fpr).to(fpr))
    assert torch.allclose(tpr, torch.tensor(expected_tpr).to(tpr))


@pytest.mark.parametrize(['pred', 'target', 'expected'], [
    pytest.param([0, 0, 1, 1], [0, 0, 1, 1], 1.),
    pytest.param([1, 1, 0, 0], [0, 0, 1, 1], 0.),
    pytest.param([1, 1, 1, 1], [1, 1, 0, 0], 0.5),
    pytest.param([1, 1, 0, 0], [1, 1, 0, 0], 1.),
    pytest.param([0.5, 0.5, 0.5, 0.5], [1, 1, 0, 0], 0.5),
])
def test_auroc(pred, target, expected):
    score = auroc(torch.tensor(pred), torch.tensor(target)).item()
    assert score == expected


@pytest.mark.parametrize(['x', 'y', 'expected'], [
    pytest.param([0, 1], [0, 1], 0.5),
    pytest.param([1, 0], [0, 1], 0.5),
    pytest.param([1, 0, 0], [0, 1, 1], 0.5),
    pytest.param([0, 1], [1, 1], 1),
    pytest.param([0, 0.5, 1], [0, 0.5, 1], 0.5),
])
def test_auc(x, y, expected):
    # Test Area Under Curve (AUC) computation
    assert auc(torch.tensor(x), torch.tensor(y)) == expected


def test_average_precision_constant_values():
    # Check that the average_precision score of a constant predictor is
    # the TPR

    # Generate a dataset with 25% of positives
    target = torch.zeros(100, dtype=torch.float)
    target[::4] = 1
    # And a constant score
    pred = torch.ones(100)
    # The precision is then the fraction of positives whatever the recall
    # is, as there is only one threshold:
    assert average_precision(pred, target).item() == .25


@pytest.mark.parametrize(['pred', 'target', 'expected'], [
    pytest.param([[0, 0], [1, 1]], [[0, 0], [1, 1]], 1.),
    pytest.param([[1, 1], [0, 0]], [[0, 0], [1, 1]], 0.),
    pytest.param([[1, 1], [1, 1]], [[1, 1], [0, 0]], 2 / 3),
    pytest.param([[1, 1], [0, 0]], [[1, 1], [0, 0]], 1.),
])
def test_dice_score(pred, target, expected):
    score = dice_score(torch.tensor(pred), torch.tensor(target))
    assert score == expected


# example data taken from
# https://github.com/scikit-learn/scikit-learn/blob/master/sklearn/metrics/tests/test_ranking.py

@ -0,0 +1,15 @@
import pytest
import torch

from pytorch_lightning.metrics.functional.reduction import reduce


def test_reduce():
    start_tensor = torch.rand(50, 40, 30)

    assert torch.allclose(reduce(start_tensor, 'elementwise_mean'), torch.mean(start_tensor))
    assert torch.allclose(reduce(start_tensor, 'sum'), torch.sum(start_tensor))
    assert torch.allclose(reduce(start_tensor, 'none'), start_tensor)

    with pytest.raises(ValueError):
        reduce(start_tensor, 'error_reduction')

@ -0,0 +1,227 @@
# NOTE: This file only tests if modules with arguments are running fine.
# The actual metric implementation is tested in functional/test_classification.py
# Especially reduction and reducing across processes won't be tested here!

import pytest
import torch

from pytorch_lightning.metrics.classification import (
    Accuracy,
    ConfusionMatrix,
    PrecisionRecall,
    Precision,
    Recall,
    AveragePrecision,
    AUROC,
    FBeta,
    F1,
    ROC,
    MulticlassROC,
    MulticlassPrecisionRecall,
    DiceCoefficient,
)


@pytest.fixture
def random():
    torch.manual_seed(0)


@pytest.mark.parametrize('num_classes', [1, None])
def test_accuracy(num_classes):
    acc = Accuracy(num_classes=num_classes)

    assert acc.name == 'accuracy'

    result = acc(pred=torch.tensor([[0, 1, 1], [1, 0, 1]]),
                 target=torch.tensor([[0, 0, 1], [1, 0, 1]]))

    assert isinstance(result, torch.Tensor)


@pytest.mark.parametrize('normalize', [False, True])
def test_confusion_matrix(normalize):
    conf_matrix = ConfusionMatrix(normalize=normalize)
    assert conf_matrix.name == 'confusion_matrix'

    target = (torch.arange(120) % 3).view(-1, 1)
    pred = target.clone()

    cm = conf_matrix(pred, target)

    assert isinstance(cm, torch.Tensor)


@pytest.mark.parametrize('pos_label', [1, 2.])
def test_precision_recall(pos_label):
    pred, target = torch.tensor([1, 2, 3, 4]), torch.tensor([1, 0, 0, 1])

    pr_curve = PrecisionRecall(pos_label=pos_label)
    assert pr_curve.name == 'precision_recall_curve'

    pr = pr_curve(pred=pred, target=target, sample_weight=[0.1, 0.2, 0.3, 0.4])

    assert isinstance(pr, tuple)
    assert len(pr) == 3
    for tmp in pr:
        assert isinstance(tmp, torch.Tensor)


@pytest.mark.parametrize('num_classes', [1, None])
def test_precision(num_classes):
    precision = Precision(num_classes=num_classes)

    assert precision.name == 'precision'
    pred, target = torch.tensor([1, 2, 3, 4]), torch.tensor([1, 0, 0, 1])

    prec = precision(pred=pred, target=target)

    assert isinstance(prec, torch.Tensor)


@pytest.mark.parametrize('num_classes', [1, None])
def test_recall(num_classes):
    recall = Recall(num_classes=num_classes)

    assert recall.name == 'recall'
    pred, target = torch.tensor([1, 2, 3, 4]), torch.tensor([1, 0, 0, 1])

    rec = recall(pred=pred, target=target)

    assert isinstance(rec, torch.Tensor)


@pytest.mark.parametrize('pos_label', [1, 2])
def test_average_precision(pos_label):
    pred, target = torch.tensor([1, 2, 3, 4]), torch.tensor([1, 2, 0, 1])

    avg_prec = AveragePrecision(pos_label=pos_label)
    assert avg_prec.name == 'AP'

    ap = avg_prec(pred=pred, target=target, sample_weight=[0.1, 0.2, 0.3, 0.4])

    assert isinstance(ap, torch.Tensor)


@pytest.mark.parametrize('pos_label', [1, 2])
def test_auroc(pos_label):
    pred, target = torch.tensor([1, 2, 3, 4]), torch.tensor([1, 2, 0, 1])

    auroc = AUROC(pos_label=pos_label)
    assert auroc.name == 'auroc'

    area = auroc(pred=pred, target=target, sample_weight=[0.1, 0.2, 0.3, 0.4])

    assert isinstance(area, torch.Tensor)


@pytest.mark.parametrize(['beta', 'num_classes'], [
    pytest.param(0., 1),
    pytest.param(0.5, 1),
    pytest.param(1., 1),
    pytest.param(2., 1),
    pytest.param(0., None),
    pytest.param(0.5, None),
    pytest.param(1., None),
    pytest.param(2., None)
])
def test_fbeta(beta, num_classes):
    fbeta = FBeta(beta=beta, num_classes=num_classes)
    assert fbeta.name == 'fbeta'

    score = fbeta(pred=torch.tensor([[0, 1, 1], [1, 0, 1]]),
                  target=torch.tensor([[0, 0, 1], [1, 0, 1]]))

    assert isinstance(score, torch.Tensor)


@pytest.mark.parametrize('num_classes', [1, None])
def test_f1(num_classes):
    f1 = F1(num_classes=num_classes)
    assert f1.name == 'f1'

    score = f1(pred=torch.tensor([[0, 1, 1], [1, 0, 1]]),
               target=torch.tensor([[0, 0, 1], [1, 0, 1]]))

    assert isinstance(score, torch.Tensor)


@pytest.mark.parametrize('pos_label', [1, 2])
def test_roc(pos_label):
    pred, target = torch.tensor([1, 2, 3, 4]), torch.tensor([1, 2, 4, 3])

    roc = ROC(pos_label=pos_label)
    assert roc.name == 'roc'

    res = roc(pred=pred, target=target, sample_weight=[0.1, 0.2, 0.3, 0.4])

    assert isinstance(res, tuple)
    assert len(res) == 3
    for tmp in res:
        assert isinstance(tmp, torch.Tensor)


@pytest.mark.parametrize('num_classes', [4, None])
def test_multiclass_roc(num_classes):
    pred = torch.tensor([[0.85, 0.05, 0.05, 0.05],
                         [0.05, 0.85, 0.05, 0.05],
                         [0.05, 0.05, 0.85, 0.05],
                         [0.05, 0.05, 0.05, 0.85]])
    target = torch.tensor([0, 1, 3, 2])

    multi_roc = MulticlassROC(num_classes=num_classes)

    assert multi_roc.name == 'multiclass_roc'

    res = multi_roc(pred, target)

    assert isinstance(res, tuple)

    if num_classes is not None:
        assert len(res) == num_classes

    for tmp in res:
        assert isinstance(tmp, tuple)
        assert len(tmp) == 3

        for _tmp in tmp:
            assert isinstance(_tmp, torch.Tensor)


@pytest.mark.parametrize('num_classes', [4, None])
def test_multiclass_pr(num_classes):
    pred = torch.tensor([[0.85, 0.05, 0.05, 0.05],
                         [0.05, 0.85, 0.05, 0.05],
                         [0.05, 0.05, 0.85, 0.05],
                         [0.05, 0.05, 0.05, 0.85]])
    target = torch.tensor([0, 1, 3, 2])

    multi_pr = MulticlassPrecisionRecall(num_classes=num_classes)

    assert multi_pr.name == 'multiclass_precision_recall_curve'

    pr = multi_pr(pred, target)

    assert isinstance(pr, tuple)

    if num_classes is not None:
        assert len(pr) == num_classes

    for tmp in pr:
        assert isinstance(tmp, tuple)
        assert len(tmp) == 3

        for _tmp in tmp:
            assert isinstance(_tmp, torch.Tensor)


@pytest.mark.parametrize('include_background', [True, False])
def test_dice_coefficient(include_background):
    dice_coeff = DiceCoefficient(include_background=include_background)

    assert dice_coeff.name == 'dice'

    dice = dice_coeff(torch.randint(0, 1, (10, 25, 25)),
                      torch.randint(0, 1, (10, 25, 25)))

    assert isinstance(dice, torch.Tensor)

@ -6,16 +6,19 @@ import torch.multiprocessing as mp
import tests.base.utils as tutils
from pytorch_lightning.metrics.converters import (
    _apply_to_inputs,
    _apply_to_outputs,
    _convert_to_tensor,
    _convert_to_numpy,
    _numpy_metric_conversion,
    _tensor_metric_conversion,
    _sync_ddp_if_available,
    tensor_metric,
    numpy_metric
)


def test_apply_to_inputs():
    def apply_fn(inputs, factor):
        if isinstance(inputs, (float, int)):
            return inputs * factor

@ -25,22 +28,24 @@ def test_apply_to_inputs(args, kwargs):
        return [apply_fn(x, factor) for x in inputs]

    @_apply_to_inputs(apply_fn, factor=2.)
    def test_fn(*args, **kwargs):
        return args, kwargs

    for args in [[], [1., 2.]]:
        for kwargs in [{}, {'a': 1., 'b': 2.}]:
            result_args, result_kwargs = test_fn(*args, **kwargs)
            assert isinstance(result_args, (list, tuple))
            assert isinstance(result_kwargs, dict)
            assert len(result_args) == len(args)
            assert len(result_kwargs) == len(kwargs)
            assert all([k in result_kwargs for k in kwargs.keys()])
            for arg, result_arg in zip(args, result_args):
                assert arg * 2. == result_arg

            for key in kwargs.keys():
                arg = kwargs[key]
                result_arg = result_kwargs[key]
                assert arg * 2. == result_arg


def test_apply_to_outputs():

@ -100,7 +105,7 @@ def test_tensor_metric_conversion():
    assert result.item() == 5.


def _setup_ddp(rank, worldsize):
    import os

    os.environ['MASTER_ADDR'] = 'localhost'

@ -109,8 +114,8 @@ def setup_ddp(rank, worldsize, ):
    dist.init_process_group("gloo", rank=rank, world_size=worldsize)


def _ddp_test_fn(rank, worldsize):
    _setup_ddp(rank, worldsize)
    tensor = torch.tensor([1.], device='cuda:0')

    reduced_tensor = _sync_ddp_if_available(tensor)

@ -119,6 +124,7 @@ def ddp_test_fn(rank, worldsize):
        'Sync-Reduce does not work properly with DDP and Tensors'


@pytest.mark.spawn
@pytest.mark.skipif(torch.cuda.device_count() < 2, reason="test requires multi-GPU machine")
def test_sync_reduce_ddp():
    """Make sure sync-reduce works with DDP"""

@ -126,7 +132,9 @@ def test_sync_reduce_ddp():
    tutils.set_random_master_port()

    worldsize = 2
    mp.spawn(_ddp_test_fn, args=(worldsize,), nprocs=worldsize)

    # dist.destroy_process_group()


def test_sync_reduce_simple():

@ -161,16 +169,18 @@ def _test_tensor_metric(is_ddp: bool):
def _ddp_test_tensor_metric(rank, worldsize):
    _setup_ddp(rank, worldsize)
    _test_tensor_metric(True)


@pytest.mark.spawn
@pytest.mark.skipif(torch.cuda.device_count() < 2, reason="test requires multi-GPU machine")
def test_tensor_metric_ddp():
    tutils.reset_seed()
    tutils.set_random_master_port()

    world_size = 2
    mp.spawn(_ddp_test_tensor_metric, args=(world_size,), nprocs=world_size)
    # dist.destroy_process_group()


def test_tensor_metric_simple():

@ -199,16 +209,19 @@ def _test_numpy_metric(is_ddp: bool):
def _ddp_test_numpy_metric(rank, worldsize):
    _setup_ddp(rank, worldsize)
    _test_numpy_metric(True)


@pytest.mark.spawn
@pytest.mark.skipif(torch.cuda.device_count() < 2, reason="test requires multi-GPU machine")
def test_numpy_metric_ddp():
    tutils.reset_seed()
    tutils.set_random_master_port()
    world_size = 2
    mp.spawn(_ddp_test_numpy_metric, args=(world_size,), nprocs=world_size)
    # dist.destroy_process_group()


def test_numpy_metric_simple():
    _test_numpy_metric(False)

@@ -1,7 +1,7 @@
 import numpy as np
 import torch

-from pytorch_lightning.metrics.metric import Metric, TensorMetric, NumpyMetric
+from pytorch_lightning.metrics.metric import Metric, TensorMetric, NumpyMetric, TensorCollectionMetric


 class DummyTensorMetric(TensorMetric):
@@ -24,7 +24,65 @@ class DummyNumpyMetric(NumpyMetric):
         return 1.


+class DummyTensorCollectionMetric(TensorCollectionMetric):
+    def __init__(self):
+        super().__init__('dummy')
+
+    def forward(self, input1, input2):
+        assert isinstance(input1, torch.Tensor)
+        assert isinstance(input2, torch.Tensor)
+        return 1., 2., 3., 4.
+
+
+def _test_collection_metric(metric: Metric):
+    """ Test that metric.device, metric.dtype works for metric collection """
+    input1, input2 = torch.tensor([1.]), torch.tensor([2.])
+
+    def change_and_check_device_dtype(device, dtype):
+        metric.to(device=device, dtype=dtype)
+
+        metric_val = metric(input1, input2)
+        assert not isinstance(metric_val, torch.Tensor)
+
+        if device is not None:
+            assert metric.device in [device, torch.device(device)]
+
+        if dtype is not None:
+            assert metric.dtype == dtype
+
+    devices = [None, 'cpu']
+    if torch.cuda.is_available():
+        devices += ['cuda:0']
+
+    for device in devices:
+        for dtype in [None, torch.float32, torch.float64]:
+            change_and_check_device_dtype(device=device, dtype=dtype)
+
+    if torch.cuda.is_available():
+        metric.cuda(0)
+        assert metric.device == torch.device('cuda', index=0)
+
+    metric.cpu()
+    assert metric.device == torch.device('cpu')
+
+    metric.type(torch.int8)
+    assert metric.dtype == torch.int8
+
+    metric.float()
+    assert metric.dtype == torch.float32
+
+    metric.double()
+    assert metric.dtype == torch.float64
+    assert all(out.dtype == torch.float64 for out in metric(input1, input2))
+
+    if torch.cuda.is_available():
+        metric.cuda()
+        metric.half()
+        assert metric.dtype == torch.float16
+
+
 def _test_metric(metric: Metric):
+    """ Test that metric.device, metric.dtype works for single metric"""
     input1, input2 = torch.tensor([1.]), torch.tensor([2.])

     def change_and_check_device_dtype(device, dtype):
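The behavioral difference exercised by `_test_collection_metric` is visible in its assertions: unlike `TensorMetric`, a `TensorCollectionMetric` does not collapse its output into a single tensor — the tuple structure survives, while device/dtype bookkeeping still applies (the `metric.double()` check iterates over the elements, so each element does get converted). A short usage sketch based on the dummy class above:

metric = DummyTensorCollectionMetric()
out = metric(torch.tensor([1.]), torch.tensor([2.]))

# The test asserts the forward result is *not* a plain tensor:
# elements may be converted individually, but the collection
# (here a 4-tuple) keeps its structure.
assert not isinstance(out, torch.Tensor)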
@@ -83,3 +141,7 @@ def test_tensor_metric():

 def test_numpy_metric():
     _test_metric(DummyNumpyMetric())
+
+
+def test_tensor_collection():
+    _test_collection_metric(DummyTensorCollectionMetric())
@@ -5,13 +5,24 @@ from functools import partial
 import numpy as np
 import pytest
 import torch
-from sklearn.metrics import (accuracy_score, average_precision_score, auc, confusion_matrix, f1_score,
-                             fbeta_score, precision_score, recall_score, precision_recall_curve, roc_curve,
-                             roc_auc_score)
+from sklearn.metrics import (
+    accuracy_score,
+    average_precision_score,
+    auc,
+    confusion_matrix,
+    f1_score,
+    fbeta_score,
+    precision_score,
+    recall_score,
+    precision_recall_curve,
+    roc_curve,
+    roc_auc_score
+)

 from pytorch_lightning.metrics.converters import _convert_to_numpy
-from pytorch_lightning.metrics.sklearn import (Accuracy, AveragePrecision, AUC, ConfusionMatrix, F1, FBeta,
-                                               Precision, Recall, PrecisionRecallCurve, ROC, AUROC)
+from pytorch_lightning.metrics.sklearn import (
+    Accuracy, AveragePrecision, AUC, ConfusionMatrix, F1, FBeta,
+    Precision, Recall, PrecisionRecallCurve, ROC, AUROC)
 from pytorch_lightning.utilities.apply_func import apply_to_collection
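The next hunk's context line references a small `xy_only` helper whose definition is not shown in this excerpt. A plausible, hypothetical reconstruction — the assumption being that it truncates sklearn curve outputs such as `precision_recall_curve`, which returns `(precision, recall, thresholds)`, to the first two elements so only the curve itself is compared:

import numpy as np


def xy_only(func):
    # Hypothetical sketch: drop trailing outputs (e.g. thresholds)
    # and keep only the x/y arrays returned by an sklearn curve function.
    def new_func(*args, **kwargs):
        return np.array(func(*args, **kwargs)[:2])
    return new_func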
@@ -25,37 +36,38 @@ def xy_only(func):
 @pytest.mark.parametrize(['metric_class', 'sklearn_func', 'inputs'], [
     pytest.param(Accuracy(), accuracy_score,
                  {'y_pred': torch.randint(low=0, high=10, size=(128,)),
-                  'y_true': torch.randint(low=0, high=10, size=(128,))}, id='Accuracy'),
+                  'y_true': torch.randint(low=0, high=10, size=(128,))},
+                 id='Accuracy'),
     pytest.param(AUC(), auc, {'x': torch.arange(10, dtype=torch.float) / 10,
-                              'y': torch.tensor([0.2, 0.2, 0.2, 0.2, 0.2,
-                                                 0.2, 0.3, 0.5, 0.6, 0.7])}, id='AUC'),
+                              'y': torch.tensor([0.2, 0.2, 0.2, 0.2, 0.2, 0.2, 0.3, 0.5, 0.6, 0.7])},
+                 id='AUC'),
     pytest.param(AveragePrecision(), average_precision_score,
-                 {'y_score': torch.randint(2, size=(128,)),
-                  'y_true': torch.randint(2, size=(128,))}, id='AveragePrecision'),
+                 {'y_score': torch.randint(2, size=(128,)), 'y_true': torch.randint(2, size=(128,))},
+                 id='AveragePrecision'),
     pytest.param(ConfusionMatrix(), confusion_matrix,
-                 {'y_pred': torch.randint(10, size=(128,)),
-                  'y_true': torch.randint(10, size=(128,))}, id='ConfusionMatrix'),
+                 {'y_pred': torch.randint(10, size=(128,)), 'y_true': torch.randint(10, size=(128,))},
+                 id='ConfusionMatrix'),
     pytest.param(F1(average='macro'), partial(f1_score, average='macro'),
-                 {'y_pred': torch.randint(10, size=(128,)),
-                  'y_true': torch.randint(10, size=(128,))}, id='F1'),
+                 {'y_pred': torch.randint(10, size=(128,)), 'y_true': torch.randint(10, size=(128,))},
+                 id='F1'),
     pytest.param(FBeta(beta=0.5, average='macro'), partial(fbeta_score, beta=0.5, average='macro'),
-                 {'y_pred': torch.randint(10, size=(128,)),
-                  'y_true': torch.randint(10, size=(128,))}, id='FBeta'),
+                 {'y_pred': torch.randint(10, size=(128,)), 'y_true': torch.randint(10, size=(128,))},
+                 id='FBeta'),
     pytest.param(Precision(average='macro'), partial(precision_score, average='macro'),
-                 {'y_pred': torch.randint(10, size=(128,)),
-                  'y_true': torch.randint(10, size=(128,))}, id='Precision'),
+                 {'y_pred': torch.randint(10, size=(128,)), 'y_true': torch.randint(10, size=(128,))},
+                 id='Precision'),
     pytest.param(Recall(average='macro'), partial(recall_score, average='macro'),
-                 {'y_pred': torch.randint(10, size=(128,)),
-                  'y_true': torch.randint(10, size=(128,))}, id='Recall'),
+                 {'y_pred': torch.randint(10, size=(128,)), 'y_true': torch.randint(10, size=(128,))},
+                 id='Recall'),
     pytest.param(PrecisionRecallCurve(), xy_only(precision_recall_curve),
-                 {'probas_pred': torch.rand(size=(128,)),
-                  'y_true': torch.randint(2, size=(128,))}, id='PrecisionRecallCurve'),
+                 {'probas_pred': torch.rand(size=(128,)), 'y_true': torch.randint(2, size=(128,))},
+                 id='PrecisionRecallCurve'),
     pytest.param(ROC(), xy_only(roc_curve),
-                 {'y_score': torch.rand(size=(128,)),
-                  'y_true': torch.randint(2, size=(128,))}, id='ROC'),
+                 {'y_score': torch.rand(size=(128,)), 'y_true': torch.randint(2, size=(128,))},
+                 id='ROC'),
     pytest.param(AUROC(), roc_auc_score,
-                 {'y_score': torch.rand(size=(128,)),
-                  'y_true': torch.randint(2, size=(128,))}, id='AUROC'),
+                 {'y_score': torch.rand(size=(128,)), 'y_true': torch.randint(2, size=(128,))},
+                 id='AUROC'),
 ])
 def test_sklearn_metric(metric_class, sklearn_func, inputs: dict):
     numpy_inputs = apply_to_collection(
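The excerpt cuts off inside the `apply_to_collection(` call. Judging from the imports (`_convert_to_numpy`, `apply_to_collection`), the test presumably maps every tensor in the input dict to a numpy array before feeding it to the sklearn reference function. A hedged usage sketch of `apply_to_collection` — the lambda stands in for `_convert_to_numpy`, whose exact signature is not shown here:

import torch
from pytorch_lightning.utilities.apply_func import apply_to_collection

inputs = {'y_pred': torch.randint(10, size=(128,)),
          'y_true': torch.randint(10, size=(128,))}

# Apply the conversion to every torch.Tensor found in the
# (possibly nested) collection, leaving other values untouched.
numpy_inputs = apply_to_collection(inputs, torch.Tensor,
                                   lambda t: t.cpu().numpy())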