(cherry picked from commit cce1be3470a58aa46047d598ff00e9ad23f5f877)
This commit is contained in:
Jirka Borovec 2021-02-06 17:41:40 +01:00 committed by Jirka Borovec
Parent 8bc17d8db8
Commit a718ab3cca
28 changed files with 621 additions and 622 deletions

View File

@ -6,35 +6,31 @@ from tests.metrics.utils import BATCH_SIZE, EXTRA_DIM, NUM_BATCHES, NUM_CLASSES
Input = namedtuple('Input', ["preds", "target"])
_input_binary_prob = Input(
preds=torch.rand(NUM_BATCHES, BATCH_SIZE), target=torch.randint(high=2, size=(NUM_BATCHES, BATCH_SIZE))
)
_binary_prob_inputs = Input(
preds=torch.rand(NUM_BATCHES, BATCH_SIZE),
_input_binary = Input(
preds=torch.randint(high=2, size=(NUM_BATCHES, BATCH_SIZE)),
target=torch.randint(high=2, size=(NUM_BATCHES, BATCH_SIZE))
)
_binary_inputs = Input(
preds=torch.randint(high=2, size=(NUM_BATCHES, BATCH_SIZE,)),
target=torch.randint(high=2, size=(NUM_BATCHES, BATCH_SIZE,))
)
_multilabel_prob_inputs = Input(
_input_multilabel_prob = Input(
preds=torch.rand(NUM_BATCHES, BATCH_SIZE, NUM_CLASSES),
target=torch.randint(high=2, size=(NUM_BATCHES, BATCH_SIZE, NUM_CLASSES))
)
_multilabel_multidim_prob_inputs = Input(
_input_multilabel_multidim_prob = Input(
preds=torch.rand(NUM_BATCHES, BATCH_SIZE, NUM_CLASSES, EXTRA_DIM),
target=torch.randint(high=2, size=(NUM_BATCHES, BATCH_SIZE, NUM_CLASSES, EXTRA_DIM))
)
_multilabel_inputs = Input(
_input_multilabel = Input(
preds=torch.randint(high=2, size=(NUM_BATCHES, BATCH_SIZE, NUM_CLASSES)),
target=torch.randint(high=2, size=(NUM_BATCHES, BATCH_SIZE, NUM_CLASSES))
)
_multilabel_multidim_inputs = Input(
_input_multilabel_multidim = Input(
preds=torch.randint(high=2, size=(NUM_BATCHES, BATCH_SIZE, NUM_CLASSES, EXTRA_DIM)),
target=torch.randint(high=2, size=(NUM_BATCHES, BATCH_SIZE, NUM_CLASSES, EXTRA_DIM))
)
@ -43,21 +39,16 @@ _multilabel_multidim_inputs = Input(
__temp_preds = torch.randint(high=2, size=(NUM_BATCHES, BATCH_SIZE, NUM_CLASSES))
__temp_target = abs(__temp_preds - 1)
_multilabel_inputs_no_match = Input(
preds=__temp_preds,
target=__temp_target
)
_input_multilabel_no_match = Input(preds=__temp_preds, target=__temp_target)
__mc_prob_preds = torch.rand(NUM_BATCHES, BATCH_SIZE, NUM_CLASSES)
__mc_prob_preds = __mc_prob_preds / __mc_prob_preds.sum(dim=2, keepdim=True)
_multiclass_prob_inputs = Input(
preds=__mc_prob_preds,
target=torch.randint(high=NUM_CLASSES, size=(NUM_BATCHES, BATCH_SIZE))
_input_multiclass_prob = Input(
preds=__mc_prob_preds, target=torch.randint(high=NUM_CLASSES, size=(NUM_BATCHES, BATCH_SIZE))
)
_multiclass_inputs = Input(
_input_multiclass = Input(
preds=torch.randint(high=NUM_CLASSES, size=(NUM_BATCHES, BATCH_SIZE)),
target=torch.randint(high=NUM_CLASSES, size=(NUM_BATCHES, BATCH_SIZE))
)
@ -65,12 +56,11 @@ _multiclass_inputs = Input(
__mdmc_prob_preds = torch.rand(NUM_BATCHES, BATCH_SIZE, NUM_CLASSES, EXTRA_DIM)
__mdmc_prob_preds = __mdmc_prob_preds / __mdmc_prob_preds.sum(dim=2, keepdim=True)
_multidim_multiclass_prob_inputs = Input(
preds=__mdmc_prob_preds,
target=torch.randint(high=NUM_CLASSES, size=(NUM_BATCHES, BATCH_SIZE, EXTRA_DIM))
_input_multidim_multiclass_prob = Input(
preds=__mdmc_prob_preds, target=torch.randint(high=NUM_CLASSES, size=(NUM_BATCHES, BATCH_SIZE, EXTRA_DIM))
)
_multidim_multiclass_inputs = Input(
_input_multidim_multiclass = Input(
preds=torch.randint(high=NUM_CLASSES, size=(NUM_BATCHES, BATCH_SIZE, EXTRA_DIM)),
target=torch.randint(high=NUM_CLASSES, size=(NUM_BATCHES, BATCH_SIZE, EXTRA_DIM))
)
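For reference, a minimal sketch of how one of these fixtures is shaped and consumed (the constants here are assumed stand-ins for the ones defined in tests.metrics.utils): each Input pairs a preds tensor with a target tensor whose first dimension indexes batches, and the multiclass probability fixtures are normalized over the class dimension so every sample is a proper distribution.

from collections import namedtuple

import torch

NUM_BATCHES, BATCH_SIZE, NUM_CLASSES = 10, 32, 5  # assumed values, for illustration only
Input = namedtuple('Input', ["preds", "target"])

_probs = torch.rand(NUM_BATCHES, BATCH_SIZE, NUM_CLASSES)
_probs = _probs / _probs.sum(dim=2, keepdim=True)  # each sample now sums to 1 over the classes
_example_multiclass_prob = Input(
    preds=_probs, target=torch.randint(high=NUM_CLASSES, size=(NUM_BATCHES, BATCH_SIZE))
)

# Tests typically feed one batch at a time, e.g. metric.update(preds[i], target[i]).
preds_0, target_0 = _example_multiclass_prob.preds[0], _example_multiclass_prob.target[0]
assert preds_0.shape == (BATCH_SIZE, NUM_CLASSES) and target_0.shape == (BATCH_SIZE, )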

View File

@ -8,18 +8,15 @@ from sklearn.metrics import accuracy_score as sk_accuracy
from pytorch_lightning.metrics import Accuracy
from pytorch_lightning.metrics.classification.helpers import _input_format_classification, DataType
from pytorch_lightning.metrics.functional import accuracy
from tests.metrics.classification.inputs import (
_binary_inputs,
_binary_prob_inputs,
_multiclass_inputs,
_multiclass_prob_inputs,
_multidim_multiclass_inputs,
_multidim_multiclass_prob_inputs,
_multilabel_inputs,
_multilabel_multidim_inputs,
_multilabel_multidim_prob_inputs,
_multilabel_prob_inputs,
)
from tests.metrics.classification.inputs import _input_binary, _input_binary_prob
from tests.metrics.classification.inputs import _input_multiclass as _input_mcls
from tests.metrics.classification.inputs import _input_multiclass_prob as _input_mcls_prob
from tests.metrics.classification.inputs import _input_multidim_multiclass as _input_mdmc
from tests.metrics.classification.inputs import _input_multidim_multiclass_prob as _input_mdmc_prob
from tests.metrics.classification.inputs import _input_multilabel as _input_mlb
from tests.metrics.classification.inputs import _input_multilabel_multidim as _input_mlmd
from tests.metrics.classification.inputs import _input_multilabel_multidim_prob as _input_mlmd_prob
from tests.metrics.classification.inputs import _input_multilabel_prob as _input_mlb_prob
from tests.metrics.utils import MetricTester, THRESHOLD
torch.manual_seed(42)
@ -43,25 +40,26 @@ def _sk_accuracy(preds, target, subset_accuracy):
@pytest.mark.parametrize(
"preds, target, subset_accuracy",
[
(_binary_prob_inputs.preds, _binary_prob_inputs.target, False),
(_binary_inputs.preds, _binary_inputs.target, False),
(_multilabel_prob_inputs.preds, _multilabel_prob_inputs.target, True),
(_multilabel_prob_inputs.preds, _multilabel_prob_inputs.target, False),
(_multilabel_inputs.preds, _multilabel_inputs.target, True),
(_multilabel_inputs.preds, _multilabel_inputs.target, False),
(_multiclass_prob_inputs.preds, _multiclass_prob_inputs.target, False),
(_multiclass_inputs.preds, _multiclass_inputs.target, False),
(_multidim_multiclass_prob_inputs.preds, _multidim_multiclass_prob_inputs.target, False),
(_multidim_multiclass_prob_inputs.preds, _multidim_multiclass_prob_inputs.target, True),
(_multidim_multiclass_inputs.preds, _multidim_multiclass_inputs.target, False),
(_multidim_multiclass_inputs.preds, _multidim_multiclass_inputs.target, True),
(_multilabel_multidim_prob_inputs.preds, _multilabel_multidim_prob_inputs.target, True),
(_multilabel_multidim_prob_inputs.preds, _multilabel_multidim_prob_inputs.target, False),
(_multilabel_multidim_inputs.preds, _multilabel_multidim_inputs.target, True),
(_multilabel_multidim_inputs.preds, _multilabel_multidim_inputs.target, False),
(_input_binary_prob.preds, _input_binary_prob.target, False),
(_input_binary.preds, _input_binary.target, False),
(_input_mlb_prob.preds, _input_mlb_prob.target, True),
(_input_mlb_prob.preds, _input_mlb_prob.target, False),
(_input_mlb.preds, _input_mlb.target, True),
(_input_mlb.preds, _input_mlb.target, False),
(_input_mcls_prob.preds, _input_mcls_prob.target, False),
(_input_mcls.preds, _input_mcls.target, False),
(_input_mdmc_prob.preds, _input_mdmc_prob.target, False),
(_input_mdmc_prob.preds, _input_mdmc_prob.target, True),
(_input_mdmc.preds, _input_mdmc.target, False),
(_input_mdmc.preds, _input_mdmc.target, True),
(_input_mlmd_prob.preds, _input_mlmd_prob.target, True),
(_input_mlmd_prob.preds, _input_mlmd_prob.target, False),
(_input_mlmd.preds, _input_mlmd.target, True),
(_input_mlmd.preds, _input_mlmd.target, False),
],
)
class TestAccuracies(MetricTester):
@pytest.mark.parametrize("ddp", [False, True])
@pytest.mark.parametrize("dist_sync_on_step", [False, True])
def test_accuracy_class(self, ddp, dist_sync_on_step, preds, target, subset_accuracy):
@ -72,7 +70,10 @@ class TestAccuracies(MetricTester):
metric_class=Accuracy,
sk_metric=partial(_sk_accuracy, subset_accuracy=subset_accuracy),
dist_sync_on_step=dist_sync_on_step,
metric_args={"threshold": THRESHOLD, "subset_accuracy": subset_accuracy},
metric_args={
"threshold": THRESHOLD,
"subset_accuracy": subset_accuracy
},
)
def test_accuracy_fn(self, preds, target, subset_accuracy):
@ -81,21 +82,24 @@ class TestAccuracies(MetricTester):
target,
metric_functional=accuracy,
sk_metric=partial(_sk_accuracy, subset_accuracy=subset_accuracy),
metric_args={"threshold": THRESHOLD, "subset_accuracy": subset_accuracy},
metric_args={
"threshold": THRESHOLD,
"subset_accuracy": subset_accuracy
},
)
_l1to4 = [0.1, 0.2, 0.3, 0.4]
_l1to4t3 = np.array([_l1to4, _l1to4, _l1to4])
_l1to4t3_mc = [_l1to4t3.T, _l1to4t3.T, _l1to4t3.T]
_l1to4t3_mcls = [_l1to4t3.T, _l1to4t3.T, _l1to4t3.T]
# The preds in these examples always put highest probability on class 3, second highest on class 2,
# third highest on class 1, and lowest on class 0
_topk_preds_mc = torch.tensor([_l1to4t3, _l1to4t3]).float()
_topk_target_mc = torch.tensor([[1, 2, 3], [2, 1, 0]])
_topk_preds_mcls = torch.tensor([_l1to4t3, _l1to4t3]).float()
_topk_target_mcls = torch.tensor([[1, 2, 3], [2, 1, 0]])
# This is like for MC case, but one sample in each batch is sabotaged with 0 class prediction :)
_topk_preds_mdmc = torch.tensor([_l1to4t3_mc, _l1to4t3_mc]).float()
_topk_preds_mdmc = torch.tensor([_l1to4t3_mcls, _l1to4t3_mcls]).float()
_topk_target_mdmc = torch.tensor([[[1, 1, 0], [2, 2, 2], [3, 3, 3]], [[2, 2, 0], [1, 1, 1], [0, 0, 0]]])
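As a sanity check on the expected values used below: with the fixed ranking class 3 > 2 > 1 > 0 on every sample, the six flattened targets are [1, 2, 3, 2, 1, 0], so top-1 only hits the sample whose target is 3 (1/6), top-2 additionally hits the two targets equal to 2 (3/6), and top-3 misses only the target 0 (5/6). A small standalone sketch of that arithmetic:

import torch

preds = torch.tensor([[0.1, 0.2, 0.3, 0.4]] * 6)      # class ranking 3 > 2 > 1 > 0 on every sample
target = torch.tensor([1, 2, 3, 2, 1, 0])
for k in (1, 2, 3):
    topk = preds.topk(k, dim=-1).indices               # the k highest-scoring classes per sample
    hit = (topk == target.unsqueeze(-1)).any(dim=-1)   # correct if the target is among them
    print(k, hit.float().mean().item())                # 1/6, 3/6, 5/6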
@ -103,12 +107,12 @@ _topk_target_mdmc = torch.tensor([[[1, 1, 0], [2, 2, 2], [3, 3, 3]], [[2, 2, 0],
@pytest.mark.parametrize(
"preds, target, exp_result, k, subset_accuracy",
[
(_topk_preds_mc, _topk_target_mc, 1 / 6, 1, False),
(_topk_preds_mc, _topk_target_mc, 3 / 6, 2, False),
(_topk_preds_mc, _topk_target_mc, 5 / 6, 3, False),
(_topk_preds_mc, _topk_target_mc, 1 / 6, 1, True),
(_topk_preds_mc, _topk_target_mc, 3 / 6, 2, True),
(_topk_preds_mc, _topk_target_mc, 5 / 6, 3, True),
(_topk_preds_mcls, _topk_target_mcls, 1 / 6, 1, False),
(_topk_preds_mcls, _topk_target_mcls, 3 / 6, 2, False),
(_topk_preds_mcls, _topk_target_mcls, 5 / 6, 3, False),
(_topk_preds_mcls, _topk_target_mcls, 1 / 6, 1, True),
(_topk_preds_mcls, _topk_target_mcls, 3 / 6, 2, True),
(_topk_preds_mcls, _topk_target_mcls, 5 / 6, 3, True),
(_topk_preds_mdmc, _topk_target_mdmc, 1 / 6, 1, False),
(_topk_preds_mdmc, _topk_target_mdmc, 8 / 18, 2, False),
(_topk_preds_mdmc, _topk_target_mdmc, 13 / 18, 3, False),
@ -138,14 +142,14 @@ def test_topk_accuracy(preds, target, exp_result, k, subset_accuracy):
@pytest.mark.parametrize(
"preds, target",
[
(_binary_prob_inputs.preds, _binary_prob_inputs.target),
(_binary_inputs.preds, _binary_inputs.target),
(_multilabel_prob_inputs.preds, _multilabel_prob_inputs.target),
(_multilabel_inputs.preds, _multilabel_inputs.target),
(_multiclass_inputs.preds, _multiclass_inputs.target),
(_multidim_multiclass_inputs.preds, _multidim_multiclass_inputs.target),
(_multilabel_multidim_prob_inputs.preds, _multilabel_multidim_prob_inputs.target),
(_multilabel_multidim_inputs.preds, _multilabel_multidim_inputs.target),
(_input_binary_prob.preds, _input_binary_prob.target),
(_input_binary.preds, _input_binary.target),
(_input_mlb_prob.preds, _input_mlb_prob.target),
(_input_mlb.preds, _input_mlb.target),
(_input_mcls.preds, _input_mcls.target),
(_input_mdmc.preds, _input_mdmc.target),
(_input_mlmd_prob.preds, _input_mlmd_prob.target),
(_input_mlmd.preds, _input_mlmd.target),
],
)
def test_topk_accuracy_wrong_input_types(preds, target):
@ -160,7 +164,7 @@ def test_topk_accuracy_wrong_input_types(preds, target):
@pytest.mark.parametrize("top_k, threshold", [(0, 0.5), (None, 1.5)])
def test_wrong_params(top_k, threshold):
preds, target = _multiclass_prob_inputs.preds, _multiclass_prob_inputs.target
preds, target = _input_mcls_prob.preds, _input_mcls_prob.target
with pytest.raises(ValueError):
acc = Accuracy(threshold=threshold, top_k=top_k)

View File

@ -35,6 +35,7 @@ for i in range(4):
@pytest.mark.parametrize("x, y", _examples)
class TestAUC(MetricTester):
@pytest.mark.parametrize("ddp", [False])
@pytest.mark.parametrize("dist_sync_on_step", [True, False])
def test_auc(self, x, y, ddp, dist_sync_on_step):
@ -48,13 +49,7 @@ class TestAUC(MetricTester):
)
def test_auc_functional(self, x, y):
self.run_functional_metric_test(
x,
y,
metric_functional=auc,
sk_metric=sk_auc,
metric_args={"reorder": False}
)
self.run_functional_metric_test(x, y, metric_functional=auc, sk_metric=sk_auc, metric_args={"reorder": False})
@pytest.mark.parametrize(['x', 'y', 'expected'], [

View File

@ -7,25 +7,23 @@ from sklearn.metrics import roc_auc_score as sk_roc_auc_score
from pytorch_lightning.metrics.classification.auroc import AUROC
from pytorch_lightning.metrics.functional.auroc import auroc
from tests.metrics.classification.inputs import (
_binary_prob_inputs,
_multiclass_prob_inputs,
_multidim_multiclass_prob_inputs,
_multilabel_multidim_prob_inputs,
_multilabel_prob_inputs,
)
from tests.metrics.classification.inputs import _input_binary_prob
from tests.metrics.classification.inputs import _input_multiclass_prob as _input_mcls_prob
from tests.metrics.classification.inputs import _input_multidim_multiclass_prob as _input_mdmc_prob
from tests.metrics.classification.inputs import _input_multilabel_multidim_prob as _input_mlmd_prob
from tests.metrics.classification.inputs import _input_multilabel_prob as _input_mlb_prob
from tests.metrics.utils import MetricTester, NUM_CLASSES
torch.manual_seed(42)
def _binary_prob_sk_metric(preds, target, num_classes, average='macro', max_fpr=None, multi_class='ovr'):
def _sk_auroc_binary_prob(preds, target, num_classes, average='macro', max_fpr=None, multi_class='ovr'):
sk_preds = preds.view(-1).numpy()
sk_target = target.view(-1).numpy()
return sk_roc_auc_score(y_true=sk_target, y_score=sk_preds, average=average, max_fpr=max_fpr)
def _multiclass_prob_sk_metric(preds, target, num_classes, average='macro', max_fpr=None, multi_class='ovr'):
def _sk_auroc_multiclass_prob(preds, target, num_classes, average='macro', max_fpr=None, multi_class='ovr'):
sk_preds = preds.reshape(-1, num_classes).numpy()
sk_target = target.view(-1).numpy()
return sk_roc_auc_score(
@ -33,11 +31,11 @@ def _multiclass_prob_sk_metric(preds, target, num_classes, average='macro', max_
y_score=sk_preds,
average=average,
max_fpr=max_fpr,
multi_class=multi_class
multi_class=multi_class,
)
def _multidim_multiclass_prob_sk_metric(preds, target, num_classes, average='macro', max_fpr=None, multi_class='ovr'):
def _sk_auroc_multidim_multiclass_prob(preds, target, num_classes, average='macro', max_fpr=None, multi_class='ovr'):
sk_preds = preds.transpose(0, 1).reshape(num_classes, -1).transpose(0, 1).numpy()
sk_target = target.view(-1).numpy()
return sk_roc_auc_score(
@ -45,11 +43,11 @@ def _multidim_multiclass_prob_sk_metric(preds, target, num_classes, average='mac
y_score=sk_preds,
average=average,
max_fpr=max_fpr,
multi_class=multi_class
multi_class=multi_class,
)
def _multilabel_prob_sk_metric(preds, target, num_classes, average='macro', max_fpr=None, multi_class='ovr'):
def _sk_auroc_multilabel_prob(preds, target, num_classes, average='macro', max_fpr=None, multi_class='ovr'):
sk_preds = preds.reshape(-1, num_classes).numpy()
sk_target = target.reshape(-1, num_classes).numpy()
return sk_roc_auc_score(
@ -57,11 +55,11 @@ def _multilabel_prob_sk_metric(preds, target, num_classes, average='macro', max_
y_score=sk_preds,
average=average,
max_fpr=max_fpr,
multi_class=multi_class
multi_class=multi_class,
)
def _multilabel_multidim_prob_sk_metric(preds, target, num_classes, average='macro', max_fpr=None, multi_class='ovr'):
def _sk_auroc_multilabel_multidim_prob(preds, target, num_classes, average='macro', max_fpr=None, multi_class='ovr'):
sk_preds = preds.transpose(0, 1).reshape(num_classes, -1).transpose(0, 1).numpy()
sk_target = target.transpose(0, 1).reshape(num_classes, -1).transpose(0, 1).numpy()
return sk_roc_auc_score(
@ -69,40 +67,22 @@ def _multilabel_multidim_prob_sk_metric(preds, target, num_classes, average='mac
y_score=sk_preds,
average=average,
max_fpr=max_fpr,
multi_class=multi_class
multi_class=multi_class,
)
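The transpose/reshape pattern in the two multidim wrappers above moves the class dimension to the front before flattening, so every column of the resulting (samples, num_classes) matrix still holds the scores of exactly one class, which is the layout sk_roc_auc_score expects. A quick shape check with arbitrarily chosen dimensions:

import torch

n, c, e = 2, 3, 4                                            # batch, classes, extra dim (illustrative only)
preds = torch.rand(n, c, e)
flat = preds.transpose(0, 1).reshape(c, -1).transpose(0, 1)
assert flat.shape == (n * e, c)                              # one row per (batch, extra) position
assert torch.equal(flat[:, 1], preds[:, 1, :].reshape(-1))   # column 1 is exactly class 1's scores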
@pytest.mark.parametrize("preds, target, sk_metric, num_classes", [
(_binary_prob_inputs.preds, _binary_prob_inputs.target, _binary_prob_sk_metric, 1),
(
_multiclass_prob_inputs.preds,
_multiclass_prob_inputs.target,
_multiclass_prob_sk_metric,
NUM_CLASSES
),
(
_multidim_multiclass_prob_inputs.preds,
_multidim_multiclass_prob_inputs.target,
_multidim_multiclass_prob_sk_metric,
NUM_CLASSES
),
(
_multilabel_prob_inputs.preds,
_multilabel_prob_inputs.target,
_multilabel_prob_sk_metric,
NUM_CLASSES
),
(
_multilabel_multidim_prob_inputs.preds,
_multilabel_multidim_prob_inputs.target,
_multilabel_multidim_prob_sk_metric,
NUM_CLASSES
)
])
@pytest.mark.parametrize(
"preds, target, sk_metric, num_classes",
[(_input_binary_prob.preds, _input_binary_prob.target, _sk_auroc_binary_prob, 1),
(_input_mcls_prob.preds, _input_mcls_prob.target, _sk_auroc_multiclass_prob, NUM_CLASSES),
(_input_mdmc_prob.preds, _input_mdmc_prob.target, _sk_auroc_multidim_multiclass_prob, NUM_CLASSES),
(_input_mlb_prob.preds, _input_mlb_prob.target, _sk_auroc_multilabel_prob, NUM_CLASSES),
(_input_mlmd_prob.preds, _input_mlmd_prob.target, _sk_auroc_multilabel_multidim_prob, NUM_CLASSES)]
)
@pytest.mark.parametrize("average", ['macro', 'weighted'])
@pytest.mark.parametrize("max_fpr", [None, 0.8, 0.5])
class TestAUROC(MetricTester):
@pytest.mark.parametrize("ddp", [True, False])
@pytest.mark.parametrize("dist_sync_on_step", [True, False])
def test_auroc(self, preds, target, sk_metric, num_classes, average, max_fpr, ddp, dist_sync_on_step):
@ -121,9 +101,11 @@ class TestAUROC(MetricTester):
metric_class=AUROC,
sk_metric=partial(sk_metric, num_classes=num_classes, average=average, max_fpr=max_fpr),
dist_sync_on_step=dist_sync_on_step,
metric_args={"num_classes": num_classes,
"average": average,
"max_fpr": max_fpr},
metric_args={
"num_classes": num_classes,
"average": average,
"max_fpr": max_fpr
},
)
def test_auroc_functional(self, preds, target, sk_metric, num_classes, average, max_fpr):
@ -140,9 +122,11 @@ class TestAUROC(MetricTester):
target,
metric_functional=auroc,
sk_metric=partial(sk_metric, num_classes=num_classes, average=average, max_fpr=max_fpr),
metric_args={"num_classes": num_classes,
"average": average,
"max_fpr": max_fpr},
metric_args={
"num_classes": num_classes,
"average": average,
"max_fpr": max_fpr
},
)
@ -152,10 +136,7 @@ def test_error_on_different_mode():
"""
metric = AUROC()
# pass in multi-class data
metric.update(torch.randn(10, 5).softmax(dim=-1), torch.randint(0, 5, (10,)))
with pytest.raises(
ValueError,
match=r"The mode of data.* should be constant.*"
):
metric.update(torch.randn(10, 5).softmax(dim=-1), torch.randint(0, 5, (10, )))
with pytest.raises(ValueError, match=r"The mode of data.* should be constant.*"):
# pass in multi-label data
metric.update(torch.rand(10, 5), torch.randint(0, 2, (10,5)))
metric.update(torch.rand(10, 5), torch.randint(0, 2, (10, 5)))

View File

@ -3,67 +3,59 @@ from functools import partial
import numpy as np
import pytest
import torch
from sklearn.metrics import average_precision_score as _sk_average_precision_score
from sklearn.metrics import average_precision_score as sk_average_precision_score
from pytorch_lightning.metrics.classification.average_precision import AveragePrecision
from pytorch_lightning.metrics.functional.average_precision import average_precision
from tests.metrics.classification.inputs import (
_binary_prob_inputs,
_multiclass_prob_inputs,
_multidim_multiclass_prob_inputs,
)
from tests.metrics.classification.inputs import _input_binary_prob
from tests.metrics.classification.inputs import _input_multiclass_prob as _input_mcls_prob
from tests.metrics.classification.inputs import _input_multidim_multiclass_prob as _input_mdmc_prob
from tests.metrics.utils import MetricTester, NUM_CLASSES
torch.manual_seed(42)
def sk_average_precision_score(y_true, probas_pred, num_classes=1):
def _sk_average_precision_score(y_true, probas_pred, num_classes=1):
if num_classes == 1:
return _sk_average_precision_score(y_true, probas_pred)
return sk_average_precision_score(y_true, probas_pred)
res = []
for i in range(num_classes):
y_true_temp = np.zeros_like(y_true)
y_true_temp[y_true == i] = 1
res.append(_sk_average_precision_score(y_true_temp, probas_pred[:, i]))
res.append(sk_average_precision_score(y_true_temp, probas_pred[:, i]))
return res
def _binary_prob_sk_metric(preds, target, num_classes=1):
def _sk_avg_prec_binary_prob(preds, target, num_classes=1):
sk_preds = preds.view(-1).numpy()
sk_target = target.view(-1).numpy()
return sk_average_precision_score(y_true=sk_target, probas_pred=sk_preds, num_classes=num_classes)
return _sk_average_precision_score(y_true=sk_target, probas_pred=sk_preds, num_classes=num_classes)
def _multiclass_prob_sk_metric(preds, target, num_classes=1):
def _sk_avg_prec_multiclass_prob(preds, target, num_classes=1):
sk_preds = preds.reshape(-1, num_classes).numpy()
sk_target = target.view(-1).numpy()
return sk_average_precision_score(y_true=sk_target, probas_pred=sk_preds, num_classes=num_classes)
return _sk_average_precision_score(y_true=sk_target, probas_pred=sk_preds, num_classes=num_classes)
def _multidim_multiclass_prob_sk_metric(preds, target, num_classes=1):
def _sk_avg_prec_multidim_multiclass_prob(preds, target, num_classes=1):
sk_preds = preds.transpose(0, 1).reshape(num_classes, -1).transpose(0, 1).numpy()
sk_target = target.view(-1).numpy()
return sk_average_precision_score(y_true=sk_target, probas_pred=sk_preds, num_classes=num_classes)
return _sk_average_precision_score(y_true=sk_target, probas_pred=sk_preds, num_classes=num_classes)
@pytest.mark.parametrize("preds, target, sk_metric, num_classes", [
(_binary_prob_inputs.preds, _binary_prob_inputs.target, _binary_prob_sk_metric, 1),
(
_multiclass_prob_inputs.preds,
_multiclass_prob_inputs.target,
_multiclass_prob_sk_metric,
NUM_CLASSES),
(
_multidim_multiclass_prob_inputs.preds,
_multidim_multiclass_prob_inputs.target,
_multidim_multiclass_prob_sk_metric,
NUM_CLASSES
),
])
@pytest.mark.parametrize(
"preds, target, sk_metric, num_classes", [
(_input_binary_prob.preds, _input_binary_prob.target, _sk_avg_prec_binary_prob, 1),
(_input_mcls_prob.preds, _input_mcls_prob.target, _sk_avg_prec_multiclass_prob, NUM_CLASSES),
(_input_mdmc_prob.preds, _input_mdmc_prob.target, _sk_avg_prec_multidim_multiclass_prob, NUM_CLASSES),
]
)
class TestAveragePrecision(MetricTester):
@pytest.mark.parametrize("ddp", [True, False])
@pytest.mark.parametrize("dist_sync_on_step", [True, False])
def test_average_precision(self, preds, target, sk_metric, num_classes, ddp, dist_sync_on_step):
@ -87,16 +79,19 @@ class TestAveragePrecision(MetricTester):
)
@pytest.mark.parametrize(['scores', 'target', 'expected_score'], [
# Check the average_precision_score of a constant predictor is
# the TPR
# Generate a dataset with 25% of positives
# And a constant score
# The precision is then the fraction of positive whatever the recall
# is, as there is only one threshold:
pytest.param(torch.tensor([1, 1, 1, 1]), torch.tensor([0, 0, 0, 1]), .25),
# With threshold 0.8 : 1 TP and 2 TN and one FN
pytest.param(torch.tensor([.6, .7, .8, 9]), torch.tensor([1, 0, 0, 1]), .75),
])
@pytest.mark.parametrize(
['scores', 'target', 'expected_score'],
[
# Check the average_precision_score of a constant predictor is
# the TPR
# Generate a dataset with 25% of positives
# And a constant score
# The precision is then the fraction of positive whatever the recall
# is, as there is only one threshold:
pytest.param(torch.tensor([1, 1, 1, 1]), torch.tensor([0, 0, 0, 1]), .25),
# With threshold 0.8 : 1 TP and 2 TN and one FN
pytest.param(torch.tensor([.6, .7, .8, 9]), torch.tensor([1, 0, 0, 1]), .75),
]
)
def test_average_precision(scores, target, expected_score):
assert average_precision(scores, target) == expected_score
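Both expected scores can be reproduced directly with scikit-learn, which is what the wrappers above ultimately delegate to: a constant predictor can only attain a precision equal to the positive fraction (1/4), while the second case integrates to 0.75 (recall 0.5 at precision 1.0 from the top-scored positive, plus the remaining 0.5 recall at precision 0.5).

from sklearn.metrics import average_precision_score

print(average_precision_score([0, 0, 0, 1], [1, 1, 1, 1]))     # 0.25
print(average_precision_score([1, 0, 0, 1], [.6, .7, .8, 9]))  # 0.75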

View File

@ -7,71 +7,68 @@ from sklearn.metrics import confusion_matrix as sk_confusion_matrix
from pytorch_lightning.metrics.classification.confusion_matrix import ConfusionMatrix
from pytorch_lightning.metrics.functional.confusion_matrix import confusion_matrix
from tests.metrics.classification.inputs import (
_binary_inputs,
_binary_prob_inputs,
_multiclass_inputs,
_multiclass_prob_inputs,
_multidim_multiclass_inputs,
_multidim_multiclass_prob_inputs,
_multilabel_inputs,
_multilabel_prob_inputs,
)
from tests.metrics.classification.inputs import _input_binary, _input_binary_prob
from tests.metrics.classification.inputs import _input_multiclass as _input_mcls
from tests.metrics.classification.inputs import _input_multiclass_prob as _input_mcls_prob
from tests.metrics.classification.inputs import _input_multidim_multiclass as _input_mdmc
from tests.metrics.classification.inputs import _input_multidim_multiclass_prob as _input_mdmc_prob
from tests.metrics.classification.inputs import _input_multilabel as _input_mlb
from tests.metrics.classification.inputs import _input_multilabel_prob as _input_mlb_prob
from tests.metrics.utils import MetricTester, NUM_CLASSES, THRESHOLD
torch.manual_seed(42)
def _binary_prob_sk_metric(preds, target, normalize=None):
def _sk_cm_binary_prob(preds, target, normalize=None):
sk_preds = (preds.view(-1).numpy() >= THRESHOLD).astype(np.uint8)
sk_target = target.view(-1).numpy()
return sk_confusion_matrix(y_true=sk_target, y_pred=sk_preds, normalize=normalize)
def _binary_sk_metric(preds, target, normalize=None):
def _sk_cm_binary(preds, target, normalize=None):
sk_preds = preds.view(-1).numpy()
sk_target = target.view(-1).numpy()
return sk_confusion_matrix(y_true=sk_target, y_pred=sk_preds, normalize=normalize)
def _multilabel_prob_sk_metric(preds, target, normalize=None):
def _sk_cm_multilabel_prob(preds, target, normalize=None):
sk_preds = (preds.view(-1).numpy() >= THRESHOLD).astype(np.uint8)
sk_target = target.view(-1).numpy()
return sk_confusion_matrix(y_true=sk_target, y_pred=sk_preds, normalize=normalize)
def _multilabel_sk_metric(preds, target, normalize=None):
def _sk_cm_multilabel(preds, target, normalize=None):
sk_preds = preds.view(-1).numpy()
sk_target = target.view(-1).numpy()
return sk_confusion_matrix(y_true=sk_target, y_pred=sk_preds, normalize=normalize)
def _multiclass_prob_sk_metric(preds, target, normalize=None):
def _sk_cm_multiclass_prob(preds, target, normalize=None):
sk_preds = torch.argmax(preds, dim=len(preds.shape) - 1).view(-1).numpy()
sk_target = target.view(-1).numpy()
return sk_confusion_matrix(y_true=sk_target, y_pred=sk_preds, normalize=normalize)
def _multiclass_sk_metric(preds, target, normalize=None):
def _sk_cm_multiclass(preds, target, normalize=None):
sk_preds = preds.view(-1).numpy()
sk_target = target.view(-1).numpy()
return sk_confusion_matrix(y_true=sk_target, y_pred=sk_preds, normalize=normalize)
def _multidim_multiclass_prob_sk_metric(preds, target, normalize=None):
def _sk_cm_multidim_multiclass_prob(preds, target, normalize=None):
sk_preds = torch.argmax(preds, dim=len(preds.shape) - 2).view(-1).numpy()
sk_target = target.view(-1).numpy()
return sk_confusion_matrix(y_true=sk_target, y_pred=sk_preds, normalize=normalize)
def _multidim_multiclass_sk_metric(preds, target, normalize=None):
def _sk_cm_multidim_multiclass(preds, target, normalize=None):
sk_preds = preds.view(-1).numpy()
sk_target = target.view(-1).numpy()
@ -79,55 +76,53 @@ def _multidim_multiclass_sk_metric(preds, target, normalize=None):
@pytest.mark.parametrize("normalize", ['true', 'pred', 'all', None])
@pytest.mark.parametrize("preds, target, sk_metric, num_classes", [
(_binary_prob_inputs.preds, _binary_prob_inputs.target, _binary_prob_sk_metric, 2),
(_binary_inputs.preds, _binary_inputs.target, _binary_sk_metric, 2),
(_multilabel_prob_inputs.preds, _multilabel_prob_inputs.target, _multilabel_prob_sk_metric, 2),
(_multilabel_inputs.preds, _multilabel_inputs.target, _multilabel_sk_metric, 2),
(_multiclass_prob_inputs.preds, _multiclass_prob_inputs.target, _multiclass_prob_sk_metric, NUM_CLASSES),
(_multiclass_inputs.preds, _multiclass_inputs.target, _multiclass_sk_metric, NUM_CLASSES),
(
_multidim_multiclass_prob_inputs.preds,
_multidim_multiclass_prob_inputs.target,
_multidim_multiclass_prob_sk_metric,
NUM_CLASSES
),
(
_multidim_multiclass_inputs.preds,
_multidim_multiclass_inputs.target,
_multidim_multiclass_sk_metric,
NUM_CLASSES
)
])
@pytest.mark.parametrize(
"preds, target, sk_metric, num_classes",
[(_input_binary_prob.preds, _input_binary_prob.target, _sk_cm_binary_prob, 2),
(_input_binary.preds, _input_binary.target, _sk_cm_binary, 2),
(_input_mlb_prob.preds, _input_mlb_prob.target, _sk_cm_multilabel_prob, 2),
(_input_mlb.preds, _input_mlb.target, _sk_cm_multilabel, 2),
(_input_mcls_prob.preds, _input_mcls_prob.target, _sk_cm_multiclass_prob, NUM_CLASSES),
(_input_mcls.preds, _input_mcls.target, _sk_cm_multiclass, NUM_CLASSES),
(_input_mdmc_prob.preds, _input_mdmc_prob.target, _sk_cm_multidim_multiclass_prob, NUM_CLASSES),
(_input_mdmc.preds, _input_mdmc.target, _sk_cm_multidim_multiclass, NUM_CLASSES)]
)
class TestConfusionMatrix(MetricTester):
@pytest.mark.parametrize("ddp", [True, False])
@pytest.mark.parametrize("dist_sync_on_step", [True, False])
def test_confusion_matrix(self, normalize, preds, target, sk_metric, num_classes, ddp, dist_sync_on_step):
self.run_class_metric_test(ddp=ddp,
preds=preds,
target=target,
metric_class=ConfusionMatrix,
sk_metric=partial(sk_metric, normalize=normalize),
dist_sync_on_step=dist_sync_on_step,
metric_args={"num_classes": num_classes,
"threshold": THRESHOLD,
"normalize": normalize}
)
self.run_class_metric_test(
ddp=ddp,
preds=preds,
target=target,
metric_class=ConfusionMatrix,
sk_metric=partial(sk_metric, normalize=normalize),
dist_sync_on_step=dist_sync_on_step,
metric_args={
"num_classes": num_classes,
"threshold": THRESHOLD,
"normalize": normalize
}
)
def test_confusion_matrix_functional(self, normalize, preds, target, sk_metric, num_classes):
self.run_functional_metric_test(preds,
target,
metric_functional=confusion_matrix,
sk_metric=partial(sk_metric, normalize=normalize),
metric_args={"num_classes": num_classes,
"threshold": THRESHOLD,
"normalize": normalize}
)
self.run_functional_metric_test(
preds,
target,
metric_functional=confusion_matrix,
sk_metric=partial(sk_metric, normalize=normalize),
metric_args={
"num_classes": num_classes,
"threshold": THRESHOLD,
"normalize": normalize
}
)
def test_warning_on_nan(tmpdir):
preds = torch.randint(3, size=(20,))
target = torch.randint(3, size=(20,))
preds = torch.randint(3, size=(20, ))
target = torch.randint(3, size=(20, ))
with pytest.warns(UserWarning, match='.* nan values found in confusion matrix have been replaced with zeros.'):
confusion_matrix(preds, target, num_classes=5, normalize='true')
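The nan values this test expects come from the row normalization: preds and target only contain the labels 0-2, but the matrix is built with num_classes=5, so the rows for classes 3 and 4 are all zeros and normalize='true' divides them by a zero row sum. A minimal illustration of that division:

import torch

cm = torch.tensor([
    [3., 1., 0., 0., 0.],
    [0., 4., 2., 0., 0.],
    [1., 0., 5., 0., 0.],
    [0., 0., 0., 0., 0.],  # class 3 never occurs as a target
    [0., 0., 0., 0., 0.],  # class 4 never occurs as a target
])
normalized = cm / cm.sum(dim=1, keepdim=True)  # rows 3 and 4 become 0 / 0 -> nan
print(torch.isnan(normalized).any(dim=1))      # tensor([False, False, False,  True,  True])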

View File

@ -7,17 +7,14 @@ from sklearn.metrics import fbeta_score
from pytorch_lightning.metrics import F1, FBeta
from pytorch_lightning.metrics.functional import f1, fbeta
from tests.metrics.classification.inputs import (
_binary_inputs,
_binary_prob_inputs,
_multiclass_inputs,
_multiclass_prob_inputs,
_multidim_multiclass_inputs,
_multidim_multiclass_prob_inputs,
_multilabel_inputs,
_multilabel_inputs_no_match,
_multilabel_prob_inputs,
)
from tests.metrics.classification.inputs import _input_binary, _input_binary_prob
from tests.metrics.classification.inputs import _input_multiclass as _input_mcls
from tests.metrics.classification.inputs import _input_multiclass_prob as _input_mcls_prob
from tests.metrics.classification.inputs import _input_multidim_multiclass as _input_mdmc
from tests.metrics.classification.inputs import _input_multidim_multiclass_prob as _input_mdmc_prob
from tests.metrics.classification.inputs import _input_multilabel as _input_mlb
from tests.metrics.classification.inputs import _input_multilabel_no_match as _input_mlb_nomatch
from tests.metrics.classification.inputs import _input_multilabel_prob as _mlb_prob_inputs
from tests.metrics.utils import MetricTester, NUM_CLASSES, THRESHOLD
torch.manual_seed(42)
@ -82,28 +79,24 @@ def _sk_fbeta_multidim_multiclass(preds, target, average='micro', beta=1.0):
@pytest.mark.parametrize(
"preds, target, sk_metric, num_classes, multilabel",
[
(_binary_prob_inputs.preds, _binary_prob_inputs.target, _sk_fbeta_binary_prob, 1, False),
(_binary_inputs.preds, _binary_inputs.target, _sk_fbeta_binary, 1, False),
(_multilabel_prob_inputs.preds, _multilabel_prob_inputs.target, _sk_fbeta_multilabel_prob, NUM_CLASSES, True),
(_multilabel_inputs.preds, _multilabel_inputs.target, _sk_fbeta_multilabel, NUM_CLASSES, True),
(_multilabel_inputs_no_match.preds, _multilabel_inputs_no_match.target,
_sk_fbeta_multilabel, NUM_CLASSES, True),
(_multiclass_prob_inputs.preds, _multiclass_prob_inputs.target, _sk_fbeta_multiclass_prob, NUM_CLASSES, False),
(_multiclass_inputs.preds, _multiclass_inputs.target, _sk_fbeta_multiclass, NUM_CLASSES, False),
(_multidim_multiclass_prob_inputs.preds, _multidim_multiclass_prob_inputs.target,
_sk_fbeta_multidim_multiclass_prob, NUM_CLASSES, False),
(_multidim_multiclass_inputs.preds, _multidim_multiclass_inputs.target,
_sk_fbeta_multidim_multiclass, NUM_CLASSES, False),
(_input_binary_prob.preds, _input_binary_prob.target, _sk_fbeta_binary_prob, 1, False),
(_input_binary.preds, _input_binary.target, _sk_fbeta_binary, 1, False),
(_mlb_prob_inputs.preds, _mlb_prob_inputs.target, _sk_fbeta_multilabel_prob, NUM_CLASSES, True),
(_input_mlb.preds, _input_mlb.target, _sk_fbeta_multilabel, NUM_CLASSES, True),
(_input_mlb_nomatch.preds, _input_mlb_nomatch.target, _sk_fbeta_multilabel, NUM_CLASSES, True),
(_input_mcls_prob.preds, _input_mcls_prob.target, _sk_fbeta_multiclass_prob, NUM_CLASSES, False),
(_input_mcls.preds, _input_mcls.target, _sk_fbeta_multiclass, NUM_CLASSES, False),
(_input_mdmc_prob.preds, _input_mdmc_prob.target, _sk_fbeta_multidim_multiclass_prob, NUM_CLASSES, False),
(_input_mdmc.preds, _input_mdmc.target, _sk_fbeta_multidim_multiclass, NUM_CLASSES, False),
],
)
@pytest.mark.parametrize("average", ['micro', 'macro', 'weighted', None])
@pytest.mark.parametrize("beta", [0.5, 1.0, 2.0])
class TestFBeta(MetricTester):
@pytest.mark.parametrize("ddp", [True, False])
@pytest.mark.parametrize("dist_sync_on_step", [True, False])
def test_fbeta(
self, preds, target, sk_metric, num_classes, multilabel, average, beta, ddp, dist_sync_on_step
):
def test_fbeta(self, preds, target, sk_metric, num_classes, multilabel, average, beta, ddp, dist_sync_on_step):
metric_class = F1 if beta == 1.0 else partial(FBeta, beta=beta)
self.run_class_metric_test(
@ -123,21 +116,21 @@ class TestFBeta(MetricTester):
check_batch=False,
)
def test_fbeta_functional(
self, preds, target, sk_metric, num_classes, multilabel, average, beta
):
def test_fbeta_functional(self, preds, target, sk_metric, num_classes, multilabel, average, beta):
metric_functional = f1 if beta == 1.0 else partial(fbeta, beta=beta)
self.run_functional_metric_test(preds=preds,
target=target,
metric_functional=metric_functional,
sk_metric=partial(sk_metric, average=average, beta=beta),
metric_args={
"num_classes": num_classes,
"average": average,
"multilabel": multilabel,
"threshold": THRESHOLD}
)
self.run_functional_metric_test(
preds=preds,
target=target,
metric_functional=metric_functional,
sk_metric=partial(sk_metric, average=average, beta=beta),
metric_args={
"num_classes": num_classes,
"average": average,
"multilabel": multilabel,
"threshold": THRESHOLD
}
)
@pytest.mark.parametrize(['pred', 'target', 'beta', 'exp_score'], [

View File

@ -5,18 +5,15 @@ from sklearn.metrics import hamming_loss as sk_hamming_loss
from pytorch_lightning.metrics import HammingDistance
from pytorch_lightning.metrics.classification.helpers import _input_format_classification
from pytorch_lightning.metrics.functional import hamming_distance
from tests.metrics.classification.inputs import (
_binary_inputs,
_binary_prob_inputs,
_multiclass_inputs,
_multiclass_prob_inputs,
_multidim_multiclass_inputs,
_multidim_multiclass_prob_inputs,
_multilabel_inputs,
_multilabel_multidim_inputs,
_multilabel_multidim_prob_inputs,
_multilabel_prob_inputs,
)
from tests.metrics.classification.inputs import _input_binary, _input_binary_prob
from tests.metrics.classification.inputs import _input_multiclass as _input_mcls
from tests.metrics.classification.inputs import _input_multiclass_prob as _input_mcls_prob
from tests.metrics.classification.inputs import _input_multidim_multiclass as _input_mdmc
from tests.metrics.classification.inputs import _input_multidim_multiclass_prob as _input_mdmc_prob
from tests.metrics.classification.inputs import _input_multilabel as _input_mlb
from tests.metrics.classification.inputs import _input_multilabel_multidim as _input_mlmd
from tests.metrics.classification.inputs import _input_multilabel_multidim_prob as _input_mlmd_prob
from tests.metrics.classification.inputs import _input_multilabel_prob as _input_mlb_prob
from tests.metrics.utils import MetricTester, THRESHOLD
torch.manual_seed(42)
@ -33,19 +30,20 @@ def _sk_hamming_loss(preds, target):
@pytest.mark.parametrize(
"preds, target",
[
(_binary_prob_inputs.preds, _binary_prob_inputs.target),
(_binary_inputs.preds, _binary_inputs.target),
(_multilabel_prob_inputs.preds, _multilabel_prob_inputs.target),
(_multilabel_inputs.preds, _multilabel_inputs.target),
(_multiclass_prob_inputs.preds, _multiclass_prob_inputs.target),
(_multiclass_inputs.preds, _multiclass_inputs.target),
(_multidim_multiclass_prob_inputs.preds, _multidim_multiclass_prob_inputs.target),
(_multidim_multiclass_inputs.preds, _multidim_multiclass_inputs.target),
(_multilabel_multidim_prob_inputs.preds, _multilabel_multidim_prob_inputs.target),
(_multilabel_multidim_inputs.preds, _multilabel_multidim_inputs.target),
(_input_binary_prob.preds, _input_binary_prob.target),
(_input_binary.preds, _input_binary.target),
(_input_mlb_prob.preds, _input_mlb_prob.target),
(_input_mlb.preds, _input_mlb.target),
(_input_mcls_prob.preds, _input_mcls_prob.target),
(_input_mcls.preds, _input_mcls.target),
(_input_mdmc_prob.preds, _input_mdmc_prob.target),
(_input_mdmc.preds, _input_mdmc.target),
(_input_mlmd_prob.preds, _input_mlmd_prob.target),
(_input_mlmd.preds, _input_mlmd.target),
],
)
class TestHammingDistance(MetricTester):
@pytest.mark.parametrize("ddp", [True, False])
@pytest.mark.parametrize("dist_sync_on_step", [False, True])
def test_hamming_distance_class(self, ddp, dist_sync_on_step, preds, target):
@ -71,7 +69,7 @@ class TestHammingDistance(MetricTester):
@pytest.mark.parametrize("threshold", [1.5])
def test_wrong_params(threshold):
preds, target = _multiclass_prob_inputs.preds, _multiclass_prob_inputs.target
preds, target = _input_mcls_prob.preds, _input_mcls_prob.target
with pytest.raises(ValueError):
ham_dist = HammingDistance(threshold=threshold)

View File

@ -4,16 +4,16 @@ from torch import rand, randint
from pytorch_lightning.metrics.classification.helpers import _input_format_classification, DataType
from pytorch_lightning.metrics.utils import select_topk, to_onehot
from tests.metrics.classification.inputs import _binary_inputs as _bin
from tests.metrics.classification.inputs import _binary_prob_inputs as _bin_prob
from tests.metrics.classification.inputs import _multiclass_inputs as _mc
from tests.metrics.classification.inputs import _multiclass_prob_inputs as _mc_prob
from tests.metrics.classification.inputs import _multidim_multiclass_inputs as _mdmc
from tests.metrics.classification.inputs import _multidim_multiclass_prob_inputs as _mdmc_prob
from tests.metrics.classification.inputs import _multilabel_inputs as _ml
from tests.metrics.classification.inputs import _multilabel_multidim_inputs as _mlmd
from tests.metrics.classification.inputs import _multilabel_multidim_prob_inputs as _mlmd_prob
from tests.metrics.classification.inputs import _multilabel_prob_inputs as _ml_prob
from tests.metrics.classification.inputs import _input_binary as _bin
from tests.metrics.classification.inputs import _input_binary_prob as _bin_prob
from tests.metrics.classification.inputs import _input_multiclass as _mc
from tests.metrics.classification.inputs import _input_multiclass_prob as _mc_prob
from tests.metrics.classification.inputs import _input_multidim_multiclass as _mdmc
from tests.metrics.classification.inputs import _input_multidim_multiclass_prob as _mdmc_prob
from tests.metrics.classification.inputs import _input_multilabel as _ml
from tests.metrics.classification.inputs import _input_multilabel_multidim as _mlmd
from tests.metrics.classification.inputs import _input_multilabel_multidim_prob as _mlmd_prob
from tests.metrics.classification.inputs import _input_multilabel_prob as _ml_prob
from tests.metrics.classification.inputs import Input
from tests.metrics.utils import BATCH_SIZE, EXTRA_DIM, NUM_BATCHES, NUM_CLASSES, THRESHOLD
@ -155,6 +155,7 @@ def _mlmd_prob_to_mc_preds_tr(x):
],
)
def test_usual_cases(inputs, num_classes, is_multiclass, top_k, exp_mode, post_preds, post_target):
def __get_data_type_enum(str_exp_mode):
return next(DataType[n] for n in dir(DataType) if DataType[n] == str_exp_mode)
@ -204,7 +205,7 @@ def test_threshold():
@pytest.mark.parametrize("threshold", [-0.5, 0.0, 1.0, 1.5])
def test_incorrect_threshold(threshold):
preds, target = rand(size=(7,)), randint(high=2, size=(7,))
preds, target = rand(size=(7, )), randint(high=2, size=(7, ))
with pytest.raises(ValueError):
_input_format_classification(preds, target, threshold=threshold)
@ -213,21 +214,21 @@ def test_incorrect_threshold(threshold):
"preds, target, num_classes, is_multiclass",
[
# Target not integer
(randint(high=2, size=(7,)), randint(high=2, size=(7,)).float(), None, None),
(randint(high=2, size=(7, )), randint(high=2, size=(7, )).float(), None, None),
# Target negative
(randint(high=2, size=(7,)), -randint(high=2, size=(7,)), None, None),
(randint(high=2, size=(7, )), -randint(high=2, size=(7, )), None, None),
# Preds negative integers
(-randint(high=2, size=(7,)), randint(high=2, size=(7,)), None, None),
(-randint(high=2, size=(7, )), randint(high=2, size=(7, )), None, None),
# Negative probabilities
(-rand(size=(7,)), randint(high=2, size=(7,)), None, None),
(-rand(size=(7, )), randint(high=2, size=(7, )), None, None),
# is_multiclass=False and target > 1
(rand(size=(7,)), randint(low=2, high=4, size=(7,)), None, False),
(rand(size=(7, )), randint(low=2, high=4, size=(7, )), None, False),
# is_multiclass=False and preds integers with > 1
(randint(low=2, high=4, size=(7,)), randint(high=2, size=(7,)), None, False),
(randint(low=2, high=4, size=(7, )), randint(high=2, size=(7, )), None, False),
# Wrong batch size
(randint(high=2, size=(8,)), randint(high=2, size=(7,)), None, None),
(randint(high=2, size=(8, )), randint(high=2, size=(7, )), None, None),
# Completely wrong shape
(randint(high=2, size=(7,)), randint(high=2, size=(7, 4)), None, None),
(randint(high=2, size=(7, )), randint(high=2, size=(7, 4)), None, None),
# Same #dims, different shape
(randint(high=2, size=(7, 3)), randint(high=2, size=(7, 4)), None, None),
# Same shape and preds floats, target not binary
@ -237,11 +238,11 @@ def test_incorrect_threshold(threshold):
# #dims in preds = 1 + #dims in target, preds not float
(randint(high=2, size=(7, 3, 3, 4)), randint(high=4, size=(7, 3, 3)), None, None),
# is_multiclass=False, with C dimension > 2
(_mc_prob.preds[0], randint(high=2, size=(BATCH_SIZE,)), None, False),
(_mc_prob.preds[0], randint(high=2, size=(BATCH_SIZE, )), None, False),
# Probs of multiclass preds do not sum up to 1
(rand(size=(7, 3, 5)), randint(high=2, size=(7, 5)), None, None),
# Max target larger or equal to C dimension
(_mc_prob.preds[0], randint(low=NUM_CLASSES + 1, high=100, size=(BATCH_SIZE,)), None, None),
(_mc_prob.preds[0], randint(low=NUM_CLASSES + 1, high=100, size=(BATCH_SIZE, )), None, None),
# C dimension not equal to num_classes
(_mc_prob.preds[0], _mc_prob.target[0], NUM_CLASSES + 1, None),
# Max target larger than num_classes (with #dim preds = 1 + #dims target)
@ -251,7 +252,7 @@ def test_incorrect_threshold(threshold):
# Max preds larger than num_classes (with #dim preds = #dims target)
(randint(low=5, high=7, size=(7, 3)), randint(high=4, size=(7, 3)), 4, None),
# Num_classes=1, but is_multiclass not false
(randint(high=2, size=(7,)), randint(high=2, size=(7,)), 1, None),
(randint(high=2, size=(7, )), randint(high=2, size=(7, )), 1, None),
# is_multiclass=False, but implied class dimension (for multi-label, from shape) != num_classes
(randint(high=2, size=(7, 3, 3)), randint(high=2, size=(7, 3, 3)), 4, False),
# Multilabel input with implied class dimension != num_classes
@ -259,12 +260,12 @@ def test_incorrect_threshold(threshold):
# Multilabel input with is_multiclass=True, but num_classes != 2 (or None)
(rand(size=(7, 3)), randint(high=2, size=(7, 3)), 4, True),
# Binary input, num_classes > 2
(rand(size=(7,)), randint(high=2, size=(7,)), 4, None),
(rand(size=(7, )), randint(high=2, size=(7, )), 4, None),
# Binary input, num_classes == 2 and is_multiclass not True
(rand(size=(7,)), randint(high=2, size=(7,)), 2, None),
(rand(size=(7,)), randint(high=2, size=(7,)), 2, False),
(rand(size=(7, )), randint(high=2, size=(7, )), 2, None),
(rand(size=(7, )), randint(high=2, size=(7, )), 2, False),
# Binary input, num_classes == 1 and is_multiclass=True
(rand(size=(7,)), randint(high=2, size=(7,)), 1, True),
(rand(size=(7, )), randint(high=2, size=(7, )), 1, True),
],
)
def test_incorrect_inputs(preds, target, num_classes, is_multiclass):

View File

@ -7,16 +7,13 @@ from sklearn.metrics import jaccard_score as sk_jaccard_score
from pytorch_lightning.metrics.classification.iou import IoU
from pytorch_lightning.metrics.functional.iou import iou
from tests.metrics.classification.inputs import (
_binary_inputs,
_binary_prob_inputs,
_multiclass_inputs,
_multiclass_prob_inputs,
_multidim_multiclass_inputs,
_multidim_multiclass_prob_inputs,
_multilabel_inputs,
_multilabel_prob_inputs,
)
from tests.metrics.classification.inputs import _input_binary, _input_binary_prob
from tests.metrics.classification.inputs import _input_multiclass as _input_mcls
from tests.metrics.classification.inputs import _input_multiclass_prob as _input_mcls_prob
from tests.metrics.classification.inputs import _input_multidim_multiclass as _input_mdmc
from tests.metrics.classification.inputs import _input_multidim_multiclass_prob as _input_mdmc_prob
from tests.metrics.classification.inputs import _input_multilabel as _input_mlb
from tests.metrics.classification.inputs import _input_multilabel_prob as _input_mlb_prob
from tests.metrics.utils import MetricTester, NUM_CLASSES, THRESHOLD
@ -77,52 +74,50 @@ def _sk_iou_multidim_multiclass(preds, target, average=None):
@pytest.mark.parametrize("reduction", ['elementwise_mean', 'none'])
@pytest.mark.parametrize("preds, target, sk_metric, num_classes", [
(_binary_prob_inputs.preds, _binary_prob_inputs.target, _sk_iou_binary_prob, 2),
(_binary_inputs.preds, _binary_inputs.target, _sk_iou_binary, 2),
(_multilabel_prob_inputs.preds, _multilabel_prob_inputs.target, _sk_iou_multilabel_prob, 2),
(_multilabel_inputs.preds, _multilabel_inputs.target, _sk_iou_multilabel, 2),
(_multiclass_prob_inputs.preds, _multiclass_prob_inputs.target, _sk_iou_multiclass_prob, NUM_CLASSES),
(_multiclass_inputs.preds, _multiclass_inputs.target, _sk_iou_multiclass, NUM_CLASSES),
(
_multidim_multiclass_prob_inputs.preds,
_multidim_multiclass_prob_inputs.target,
_sk_iou_multidim_multiclass_prob,
NUM_CLASSES
),
(
_multidim_multiclass_inputs.preds,
_multidim_multiclass_inputs.target,
_sk_iou_multidim_multiclass,
NUM_CLASSES
)
])
@pytest.mark.parametrize(
"preds, target, sk_metric, num_classes",
[(_input_binary_prob.preds, _input_binary_prob.target, _sk_iou_binary_prob, 2),
(_input_binary.preds, _input_binary.target, _sk_iou_binary, 2),
(_input_mlb_prob.preds, _input_mlb_prob.target, _sk_iou_multilabel_prob, 2),
(_input_mlb.preds, _input_mlb.target, _sk_iou_multilabel, 2),
(_input_mcls_prob.preds, _input_mcls_prob.target, _sk_iou_multiclass_prob, NUM_CLASSES),
(_input_mcls.preds, _input_mcls.target, _sk_iou_multiclass, NUM_CLASSES),
(_input_mdmc_prob.preds, _input_mdmc_prob.target, _sk_iou_multidim_multiclass_prob, NUM_CLASSES),
(_input_mdmc.preds, _input_mdmc.target, _sk_iou_multidim_multiclass, NUM_CLASSES)]
)
class TestIoU(MetricTester):
@pytest.mark.parametrize("ddp", [True, False])
@pytest.mark.parametrize("dist_sync_on_step", [True, False])
def test_confusion_matrix(self, reduction, preds, target, sk_metric, num_classes, ddp, dist_sync_on_step):
average = 'macro' if reduction == 'elementwise_mean' else None # convert tags
self.run_class_metric_test(ddp=ddp,
preds=preds,
target=target,
metric_class=IoU,
sk_metric=partial(sk_metric, average=average),
dist_sync_on_step=dist_sync_on_step,
metric_args={"num_classes": num_classes,
"threshold": THRESHOLD,
"reduction": reduction}
)
self.run_class_metric_test(
ddp=ddp,
preds=preds,
target=target,
metric_class=IoU,
sk_metric=partial(sk_metric, average=average),
dist_sync_on_step=dist_sync_on_step,
metric_args={
"num_classes": num_classes,
"threshold": THRESHOLD,
"reduction": reduction
}
)
def test_confusion_matrix_functional(self, reduction, preds, target, sk_metric, num_classes):
average = 'macro' if reduction == 'elementwise_mean' else None # convert tags
self.run_functional_metric_test(preds,
target,
metric_functional=iou,
sk_metric=partial(sk_metric, average=average),
metric_args={"num_classes": num_classes,
"threshold": THRESHOLD,
"reduction": reduction}
)
self.run_functional_metric_test(
preds,
target,
metric_functional=iou,
sk_metric=partial(sk_metric, average=average),
metric_args={
"num_classes": num_classes,
"threshold": THRESHOLD,
"reduction": reduction
}
)
@pytest.mark.parametrize(['half_ones', 'reduction', 'ignore_index', 'expected'], [
@ -148,35 +143,38 @@ def test_iou(half_ones, reduction, ignore_index, expected):
# test `absent_score`
@pytest.mark.parametrize(['pred', 'target', 'ignore_index', 'absent_score', 'num_classes', 'expected'], [
# Note that -1 is used as the absent_score in almost all tests here to distinguish it from the range of valid
# scores the function can return ([0., 1.] range, inclusive).
# 2 classes, class 0 is correct everywhere, class 1 is absent.
pytest.param([0], [0], None, -1., 2, [1., -1.]),
pytest.param([0, 0], [0, 0], None, -1., 2, [1., -1.]),
# absent_score not applied if only class 0 is present and it's the only class.
pytest.param([0], [0], None, -1., 1, [1.]),
# 2 classes, class 1 is correct everywhere, class 0 is absent.
pytest.param([1], [1], None, -1., 2, [-1., 1.]),
pytest.param([1, 1], [1, 1], None, -1., 2, [-1., 1.]),
# When 0 index ignored, class 0 does not get a score (not even the absent_score).
pytest.param([1], [1], 0, -1., 2, [1.0]),
# 3 classes. Only 0 and 2 are present, and are perfectly predicted. 1 should get absent_score.
pytest.param([0, 2], [0, 2], None, -1., 3, [1., -1., 1.]),
pytest.param([2, 0], [2, 0], None, -1., 3, [1., -1., 1.]),
# 3 classes. Only 0 and 1 are present, and are perfectly predicted. 2 should get absent_score.
pytest.param([0, 1], [0, 1], None, -1., 3, [1., 1., -1.]),
pytest.param([1, 0], [1, 0], None, -1., 3, [1., 1., -1.]),
# 3 classes, class 0 is 0.5 IoU, class 1 is 0 IoU (in pred but not target; should not get absent_score), class
# 2 is absent.
pytest.param([0, 1], [0, 0], None, -1., 3, [0.5, 0., -1.]),
# 3 classes, class 0 is 0.5 IoU, class 1 is 0 IoU (in target but not pred; should not get absent_score), class
# 2 is absent.
pytest.param([0, 0], [0, 1], None, -1., 3, [0.5, 0., -1.]),
# Sanity checks with absent_score of 1.0.
pytest.param([0, 2], [0, 2], None, 1.0, 3, [1., 1., 1.]),
pytest.param([0, 2], [0, 2], 0, 1.0, 3, [1., 1.]),
])
@pytest.mark.parametrize(
['pred', 'target', 'ignore_index', 'absent_score', 'num_classes', 'expected'],
[
# Note that -1 is used as the absent_score in almost all tests here to distinguish it from the range of valid
# scores the function can return ([0., 1.] range, inclusive).
# 2 classes, class 0 is correct everywhere, class 1 is absent.
pytest.param([0], [0], None, -1., 2, [1., -1.]),
pytest.param([0, 0], [0, 0], None, -1., 2, [1., -1.]),
# absent_score not applied if only class 0 is present and it's the only class.
pytest.param([0], [0], None, -1., 1, [1.]),
# 2 classes, class 1 is correct everywhere, class 0 is absent.
pytest.param([1], [1], None, -1., 2, [-1., 1.]),
pytest.param([1, 1], [1, 1], None, -1., 2, [-1., 1.]),
# When 0 index ignored, class 0 does not get a score (not even the absent_score).
pytest.param([1], [1], 0, -1., 2, [1.0]),
# 3 classes. Only 0 and 2 are present, and are perfectly predicted. 1 should get absent_score.
pytest.param([0, 2], [0, 2], None, -1., 3, [1., -1., 1.]),
pytest.param([2, 0], [2, 0], None, -1., 3, [1., -1., 1.]),
# 3 classes. Only 0 and 1 are present, and are perfectly predicted. 2 should get absent_score.
pytest.param([0, 1], [0, 1], None, -1., 3, [1., 1., -1.]),
pytest.param([1, 0], [1, 0], None, -1., 3, [1., 1., -1.]),
# 3 classes, class 0 is 0.5 IoU, class 1 is 0 IoU (in pred but not target; should not get absent_score), class
# 2 is absent.
pytest.param([0, 1], [0, 0], None, -1., 3, [0.5, 0., -1.]),
# 3 classes, class 0 is 0.5 IoU, class 1 is 0 IoU (in target but not pred; should not get absent_score), class
# 2 is absent.
pytest.param([0, 0], [0, 1], None, -1., 3, [0.5, 0., -1.]),
# Sanity checks with absent_score of 1.0.
pytest.param([0, 2], [0, 2], None, 1.0, 3, [1., 1., 1.]),
pytest.param([0, 2], [0, 2], 0, 1.0, 3, [1., 1.]),
]
)
def test_iou_absent_score(pred, target, ignore_index, absent_score, num_classes, expected):
iou_val = iou(
pred=torch.tensor(pred),
@ -191,19 +189,22 @@ def test_iou_absent_score(pred, target, ignore_index, absent_score, num_classes,
# example data taken from
# https://github.com/scikit-learn/scikit-learn/blob/master/sklearn/metrics/tests/test_ranking.py
@pytest.mark.parametrize(['pred', 'target', 'ignore_index', 'num_classes', 'reduction', 'expected'], [
# Ignoring an index outside of [0, num_classes-1] should have no effect.
pytest.param([0, 1, 1, 2, 2], [0, 1, 2, 2, 2], None, 3, 'none', [1, 1 / 2, 2 / 3]),
pytest.param([0, 1, 1, 2, 2], [0, 1, 2, 2, 2], -1, 3, 'none', [1, 1 / 2, 2 / 3]),
pytest.param([0, 1, 1, 2, 2], [0, 1, 2, 2, 2], 255, 3, 'none', [1, 1 / 2, 2 / 3]),
# Ignoring a valid index drops only that index from the result.
pytest.param([0, 1, 1, 2, 2], [0, 1, 2, 2, 2], 0, 3, 'none', [1 / 2, 2 / 3]),
pytest.param([0, 1, 1, 2, 2], [0, 1, 2, 2, 2], 1, 3, 'none', [1, 2 / 3]),
pytest.param([0, 1, 1, 2, 2], [0, 1, 2, 2, 2], 2, 3, 'none', [1, 1 / 2]),
# When reducing to mean or sum, the ignored index does not contribute to the output.
pytest.param([0, 1, 1, 2, 2], [0, 1, 2, 2, 2], 0, 3, 'elementwise_mean', [7 / 12]),
pytest.param([0, 1, 1, 2, 2], [0, 1, 2, 2, 2], 0, 3, 'sum', [7 / 6]),
])
@pytest.mark.parametrize(
['pred', 'target', 'ignore_index', 'num_classes', 'reduction', 'expected'],
[
# Ignoring an index outside of [0, num_classes-1] should have no effect.
pytest.param([0, 1, 1, 2, 2], [0, 1, 2, 2, 2], None, 3, 'none', [1, 1 / 2, 2 / 3]),
pytest.param([0, 1, 1, 2, 2], [0, 1, 2, 2, 2], -1, 3, 'none', [1, 1 / 2, 2 / 3]),
pytest.param([0, 1, 1, 2, 2], [0, 1, 2, 2, 2], 255, 3, 'none', [1, 1 / 2, 2 / 3]),
# Ignoring a valid index drops only that index from the result.
pytest.param([0, 1, 1, 2, 2], [0, 1, 2, 2, 2], 0, 3, 'none', [1 / 2, 2 / 3]),
pytest.param([0, 1, 1, 2, 2], [0, 1, 2, 2, 2], 1, 3, 'none', [1, 2 / 3]),
pytest.param([0, 1, 1, 2, 2], [0, 1, 2, 2, 2], 2, 3, 'none', [1, 1 / 2]),
# When reducing to mean or sum, the ignored index does not contribute to the output.
pytest.param([0, 1, 1, 2, 2], [0, 1, 2, 2, 2], 0, 3, 'elementwise_mean', [7 / 12]),
pytest.param([0, 1, 1, 2, 2], [0, 1, 2, 2, 2], 0, 3, 'sum', [7 / 6]),
]
)
def test_iou_ignore_index(pred, target, ignore_index, num_classes, reduction, expected):
iou_val = iou(
pred=torch.tensor(pred),

View File

@ -9,12 +9,13 @@ from sklearn.metrics import precision_score, recall_score
from pytorch_lightning.metrics import Metric, Precision, Recall
from pytorch_lightning.metrics.classification.helpers import _input_format_classification
from pytorch_lightning.metrics.functional import precision, precision_recall, recall
from tests.metrics.classification.inputs import _binary_inputs, _binary_prob_inputs, _multiclass_inputs
from tests.metrics.classification.inputs import _multiclass_prob_inputs as _mc_prob
from tests.metrics.classification.inputs import _multidim_multiclass_inputs as _mdmc
from tests.metrics.classification.inputs import _multidim_multiclass_prob_inputs as _mdmc_prob
from tests.metrics.classification.inputs import _multilabel_inputs as _ml
from tests.metrics.classification.inputs import _multilabel_prob_inputs as _ml_prob
from tests.metrics.classification.inputs import _input_binary, _input_binary_prob
from tests.metrics.classification.inputs import _input_multiclass as _input_mcls
from tests.metrics.classification.inputs import _input_multiclass_prob as _input_mcls_prob
from tests.metrics.classification.inputs import _input_multidim_multiclass as _input_mdmc
from tests.metrics.classification.inputs import _input_multidim_multiclass_prob as _input_mdmc_prob
from tests.metrics.classification.inputs import _input_multilabel as _input_mlb
from tests.metrics.classification.inputs import _input_multilabel_prob as _input_mlb_prob
from tests.metrics.utils import MetricTester, NUM_CLASSES, THRESHOLD
torch.manual_seed(42)
@ -45,7 +46,9 @@ def _sk_prec_recall(preds, target, sk_fn, num_classes, average, is_multiclass, i
return sk_scores
def _sk_prec_recall_mdmc(preds, target, sk_fn, num_classes, average, is_multiclass, ignore_index, mdmc_average):
def _sk_prec_recall_multidim_multiclass(
preds, target, sk_fn, num_classes, average, is_multiclass, ignore_index, mdmc_average
):
preds, target, _ = _input_format_classification(
preds, target, threshold=THRESHOLD, num_classes=num_classes, is_multiclass=is_multiclass
)
@ -89,8 +92,8 @@ def test_wrong_params(metric, fn_metric, average, mdmc_average, num_classes, ign
with pytest.raises(ValueError, match=match_str):
fn_metric(
_binary_inputs.preds[0],
_binary_inputs.target[0],
_input_binary.preds[0],
_input_binary.target[0],
average=average,
mdmc_average=mdmc_average,
num_classes=num_classes,
@ -99,8 +102,8 @@ def test_wrong_params(metric, fn_metric, average, mdmc_average, num_classes, ign
with pytest.raises(ValueError, match=match_str):
precision_recall(
_binary_inputs.preds[0],
_binary_inputs.target[0],
_input_binary.preds[0],
_input_binary.target[0],
average=average,
mdmc_average=mdmc_average,
num_classes=num_classes,
@ -156,19 +159,26 @@ def test_no_support(metric_class, metric_fn):
@pytest.mark.parametrize(
"preds, target, num_classes, is_multiclass, mdmc_average, sk_wrapper",
[
(_binary_prob_inputs.preds, _binary_prob_inputs.target, 1, None, None, _sk_prec_recall),
(_binary_inputs.preds, _binary_inputs.target, 1, False, None, _sk_prec_recall),
(_ml_prob.preds, _ml_prob.target, NUM_CLASSES, None, None, _sk_prec_recall),
(_ml.preds, _ml.target, NUM_CLASSES, False, None, _sk_prec_recall),
(_mc_prob.preds, _mc_prob.target, NUM_CLASSES, None, None, _sk_prec_recall),
(_multiclass_inputs.preds, _multiclass_inputs.target, NUM_CLASSES, None, None, _sk_prec_recall),
(_mdmc.preds, _mdmc.target, NUM_CLASSES, None, "global", _sk_prec_recall_mdmc),
(_mdmc_prob.preds, _mdmc_prob.target, NUM_CLASSES, None, "global", _sk_prec_recall_mdmc),
(_mdmc.preds, _mdmc.target, NUM_CLASSES, None, "samplewise", _sk_prec_recall_mdmc),
(_mdmc_prob.preds, _mdmc_prob.target, NUM_CLASSES, None, "samplewise", _sk_prec_recall_mdmc),
(_input_binary_prob.preds, _input_binary_prob.target, 1, None, None, _sk_prec_recall),
(_input_binary.preds, _input_binary.target, 1, False, None, _sk_prec_recall),
(_input_mlb_prob.preds, _input_mlb_prob.target, NUM_CLASSES, None, None, _sk_prec_recall),
(_input_mlb.preds, _input_mlb.target, NUM_CLASSES, False, None, _sk_prec_recall),
(_input_mcls_prob.preds, _input_mcls_prob.target, NUM_CLASSES, None, None, _sk_prec_recall),
(_input_mcls.preds, _input_mcls.target, NUM_CLASSES, None, None, _sk_prec_recall),
(_input_mdmc.preds, _input_mdmc.target, NUM_CLASSES, None, "global", _sk_prec_recall_multidim_multiclass),
(
_input_mdmc_prob.preds, _input_mdmc_prob.target, NUM_CLASSES, None, "global",
_sk_prec_recall_multidim_multiclass
),
(_input_mdmc.preds, _input_mdmc.target, NUM_CLASSES, None, "samplewise", _sk_prec_recall_multidim_multiclass),
(
_input_mdmc_prob.preds, _input_mdmc_prob.target, NUM_CLASSES, None, "samplewise",
_sk_prec_recall_multidim_multiclass
),
],
)
class TestPrecisionRecall(MetricTester):
@pytest.mark.parametrize("ddp", [False])
@pytest.mark.parametrize("dist_sync_on_step", [True, False])
def test_precision_recall_class(
@ -278,11 +288,15 @@ def test_precision_recall_joint(average):
which are already tested thoroughly.
"""
precision_result = precision(_mc_prob.preds[0], _mc_prob.target[0], average=average, num_classes=NUM_CLASSES)
recall_result = recall(_mc_prob.preds[0], _mc_prob.target[0], average=average, num_classes=NUM_CLASSES)
precision_result = precision(
_input_mcls_prob.preds[0], _input_mcls_prob.target[0], average=average, num_classes=NUM_CLASSES
)
recall_result = recall(
_input_mcls_prob.preds[0], _input_mcls_prob.target[0], average=average, num_classes=NUM_CLASSES
)
prec_recall_result = precision_recall(
_mc_prob.preds[0], _mc_prob.target[0], average=average, num_classes=NUM_CLASSES
_input_mcls_prob.preds[0], _input_mcls_prob.target[0], average=average, num_classes=NUM_CLASSES
)
assert torch.equal(precision_result, prec_recall_result[0])


@ -3,71 +3,63 @@ from functools import partial
import numpy as np
import pytest
import torch
from sklearn.metrics import precision_recall_curve as _sk_precision_recall_curve
from sklearn.metrics import precision_recall_curve as sk_precision_recall_curve
from pytorch_lightning.metrics.classification.precision_recall_curve import PrecisionRecallCurve
from pytorch_lightning.metrics.functional.precision_recall_curve import precision_recall_curve
from tests.metrics.classification.inputs import (
_binary_prob_inputs,
_multiclass_prob_inputs,
_multidim_multiclass_prob_inputs,
)
from tests.metrics.classification.inputs import _input_binary_prob
from tests.metrics.classification.inputs import _input_multiclass_prob as _input_mcls_prob
from tests.metrics.classification.inputs import _input_multidim_multiclass_prob as _input_mdmc_prob
from tests.metrics.utils import MetricTester, NUM_CLASSES
torch.manual_seed(42)
def sk_precision_recall_curve(y_true, probas_pred, num_classes=1):
def _sk_precision_recall_curve(y_true, probas_pred, num_classes=1):
""" Adjusted comparison function that can also handles multiclass """
if num_classes == 1:
return _sk_precision_recall_curve(y_true, probas_pred)
return sk_precision_recall_curve(y_true, probas_pred)
precision, recall, thresholds = [], [], []
for i in range(num_classes):
y_true_temp = np.zeros_like(y_true)
y_true_temp[y_true == i] = 1
res = _sk_precision_recall_curve(y_true_temp, probas_pred[:, i])
res = sk_precision_recall_curve(y_true_temp, probas_pred[:, i])
precision.append(res[0])
recall.append(res[1])
thresholds.append(res[2])
return precision, recall, thresholds
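As a usage sketch (assuming numpy and scikit-learn are importable, as they are in this module), the multiclass branch binarizes the targets one class at a time, so the adjusted function returns one (precision, recall, thresholds) triple per class; the numbers below are made up for illustration:

y_true = np.array([0, 1, 2, 2])
probas = np.array([
    [0.7, 0.2, 0.1],
    [0.3, 0.5, 0.2],
    [0.2, 0.3, 0.5],
    [0.1, 0.1, 0.8],
])
precision, recall, thresholds = _sk_precision_recall_curve(y_true, probas, num_classes=3)
assert len(precision) == len(recall) == len(thresholds) == 3  # one curve per class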
def _binary_prob_sk_metric(preds, target, num_classes=1):
def _sk_prec_rc_binary_prob(preds, target, num_classes=1):
sk_preds = preds.view(-1).numpy()
sk_target = target.view(-1).numpy()
return sk_precision_recall_curve(y_true=sk_target, probas_pred=sk_preds, num_classes=num_classes)
return _sk_precision_recall_curve(y_true=sk_target, probas_pred=sk_preds, num_classes=num_classes)
def _multiclass_prob_sk_metric(preds, target, num_classes=1):
def _sk_prec_rc_multiclass_prob(preds, target, num_classes=1):
sk_preds = preds.reshape(-1, num_classes).numpy()
sk_target = target.view(-1).numpy()
return sk_precision_recall_curve(y_true=sk_target, probas_pred=sk_preds, num_classes=num_classes)
return _sk_precision_recall_curve(y_true=sk_target, probas_pred=sk_preds, num_classes=num_classes)
def _multidim_multiclass_prob_sk_metric(preds, target, num_classes=1):
def _sk_prec_rc_multidim_multiclass_prob(preds, target, num_classes=1):
sk_preds = preds.transpose(0, 1).reshape(num_classes, -1).transpose(0, 1).numpy()
sk_target = target.view(-1).numpy()
return sk_precision_recall_curve(y_true=sk_target, probas_pred=sk_preds, num_classes=num_classes)
return _sk_precision_recall_curve(y_true=sk_target, probas_pred=sk_preds, num_classes=num_classes)
@pytest.mark.parametrize("preds, target, sk_metric, num_classes", [
(_binary_prob_inputs.preds, _binary_prob_inputs.target, _binary_prob_sk_metric, 1),
(
_multiclass_prob_inputs.preds,
_multiclass_prob_inputs.target,
_multiclass_prob_sk_metric,
NUM_CLASSES),
(
_multidim_multiclass_prob_inputs.preds,
_multidim_multiclass_prob_inputs.target,
_multidim_multiclass_prob_sk_metric,
NUM_CLASSES
),
])
@pytest.mark.parametrize(
"preds, target, sk_metric, num_classes", [
(_input_binary_prob.preds, _input_binary_prob.target, _sk_prec_rc_binary_prob, 1),
(_input_mcls_prob.preds, _input_mcls_prob.target, _sk_prec_rc_multiclass_prob, NUM_CLASSES),
(_input_mdmc_prob.preds, _input_mdmc_prob.target, _sk_prec_rc_multidim_multiclass_prob, NUM_CLASSES),
]
)
class TestPrecisionRecallCurve(MetricTester):
@pytest.mark.parametrize("ddp", [True, False])
@pytest.mark.parametrize("dist_sync_on_step", [True, False])
def test_precision_recall_curve(self, preds, target, sk_metric, num_classes, ddp, dist_sync_on_step):
@ -91,9 +83,10 @@ class TestPrecisionRecallCurve(MetricTester):
)
@pytest.mark.parametrize(['pred', 'target', 'expected_p', 'expected_r', 'expected_t'], [
pytest.param([1, 2, 3, 4], [1, 0, 0, 1], [0.5, 1 / 3, 0.5, 1., 1.], [1, 0.5, 0.5, 0.5, 0.], [1, 2, 3, 4])
])
@pytest.mark.parametrize(
['pred', 'target', 'expected_p', 'expected_r', 'expected_t'],
[pytest.param([1, 2, 3, 4], [1, 0, 0, 1], [0.5, 1 / 3, 0.5, 1., 1.], [1, 0.5, 0.5, 0.5, 0.], [1, 2, 3, 4])]
)
def test_pr_curve(pred, target, expected_p, expected_r, expected_t):
p, r, t = precision_recall_curve(torch.tensor(pred), torch.tensor(target))
assert p.size() == r.size()


@ -3,71 +3,63 @@ from functools import partial
import numpy as np
import pytest
import torch
from sklearn.metrics import roc_curve as _sk_roc_curve
from sklearn.metrics import roc_curve as sk_roc_curve
from pytorch_lightning.metrics.classification.roc import ROC
from pytorch_lightning.metrics.functional.roc import roc
from tests.metrics.classification.inputs import (
_binary_prob_inputs,
_multiclass_prob_inputs,
_multidim_multiclass_prob_inputs,
)
from tests.metrics.classification.inputs import _input_binary_prob
from tests.metrics.classification.inputs import _input_multiclass_prob as _input_mcls_prob
from tests.metrics.classification.inputs import _input_multidim_multiclass_prob as _input_mdmc_prob
from tests.metrics.utils import MetricTester, NUM_CLASSES
torch.manual_seed(42)
def sk_roc_curve(y_true, probas_pred, num_classes=1):
def _sk_roc_curve(y_true, probas_pred, num_classes=1):
""" Adjusted comparison function that can also handles multiclass """
if num_classes == 1:
return _sk_roc_curve(y_true, probas_pred, drop_intermediate=False)
return sk_roc_curve(y_true, probas_pred, drop_intermediate=False)
fpr, tpr, thresholds = [], [], []
for i in range(num_classes):
y_true_temp = np.zeros_like(y_true)
y_true_temp[y_true == i] = 1
res = _sk_roc_curve(y_true_temp, probas_pred[:, i], drop_intermediate=False)
res = sk_roc_curve(y_true_temp, probas_pred[:, i], drop_intermediate=False)
fpr.append(res[0])
tpr.append(res[1])
thresholds.append(res[2])
return fpr, tpr, thresholds
def _binary_prob_sk_metric(preds, target, num_classes=1):
def _sk_roc_binary_prob(preds, target, num_classes=1):
sk_preds = preds.view(-1).numpy()
sk_target = target.view(-1).numpy()
return sk_roc_curve(y_true=sk_target, probas_pred=sk_preds, num_classes=num_classes)
return _sk_roc_curve(y_true=sk_target, probas_pred=sk_preds, num_classes=num_classes)
def _multiclass_prob_sk_metric(preds, target, num_classes=1):
def _sk_roc_multiclass_prob(preds, target, num_classes=1):
sk_preds = preds.reshape(-1, num_classes).numpy()
sk_target = target.view(-1).numpy()
return sk_roc_curve(y_true=sk_target, probas_pred=sk_preds, num_classes=num_classes)
return _sk_roc_curve(y_true=sk_target, probas_pred=sk_preds, num_classes=num_classes)
def _multidim_multiclass_prob_sk_metric(preds, target, num_classes=1):
def _sk_roc_multidim_multiclass_prob(preds, target, num_classes=1):
sk_preds = preds.transpose(0, 1).reshape(num_classes, -1).transpose(0, 1).numpy()
sk_target = target.view(-1).numpy()
return sk_roc_curve(y_true=sk_target, probas_pred=sk_preds, num_classes=num_classes)
return _sk_roc_curve(y_true=sk_target, probas_pred=sk_preds, num_classes=num_classes)
@pytest.mark.parametrize("preds, target, sk_metric, num_classes", [
(_binary_prob_inputs.preds, _binary_prob_inputs.target, _binary_prob_sk_metric, 1),
(
_multiclass_prob_inputs.preds,
_multiclass_prob_inputs.target,
_multiclass_prob_sk_metric,
NUM_CLASSES),
(
_multidim_multiclass_prob_inputs.preds,
_multidim_multiclass_prob_inputs.target,
_multidim_multiclass_prob_sk_metric,
NUM_CLASSES
),
])
@pytest.mark.parametrize(
"preds, target, sk_metric, num_classes", [
(_input_binary_prob.preds, _input_binary_prob.target, _sk_roc_binary_prob, 1),
(_input_mcls_prob.preds, _input_mcls_prob.target, _sk_roc_multiclass_prob, NUM_CLASSES),
(_input_mdmc_prob.preds, _input_mdmc_prob.target, _sk_roc_multidim_multiclass_prob, NUM_CLASSES),
]
)
class TestROC(MetricTester):
@pytest.mark.parametrize("ddp", [True, False])
@pytest.mark.parametrize("dist_sync_on_step", [True, False])
def test_roc(self, preds, target, sk_metric, num_classes, ddp, dist_sync_on_step):


@ -9,12 +9,12 @@ from sklearn.metrics import multilabel_confusion_matrix
from pytorch_lightning.metrics import StatScores
from pytorch_lightning.metrics.classification.helpers import _input_format_classification
from pytorch_lightning.metrics.functional import stat_scores
from tests.metrics.classification.inputs import _binary_inputs, _binary_prob_inputs, _multiclass_inputs
from tests.metrics.classification.inputs import _multiclass_prob_inputs as _mc_prob
from tests.metrics.classification.inputs import _multidim_multiclass_inputs as _mdmc
from tests.metrics.classification.inputs import _multidim_multiclass_prob_inputs as _mdmc_prob
from tests.metrics.classification.inputs import _multilabel_inputs
from tests.metrics.classification.inputs import _multilabel_prob_inputs as _ml_prob
from tests.metrics.classification.inputs import _input_binary, _input_binary_prob, _input_multiclass
from tests.metrics.classification.inputs import _input_multiclass_prob as _input_mccls_prob
from tests.metrics.classification.inputs import _input_multidim_multiclass as _input_mdmc
from tests.metrics.classification.inputs import _input_multidim_multiclass_prob as _input_mdmc_prob
from tests.metrics.classification.inputs import _input_multilabel as _input_mcls
from tests.metrics.classification.inputs import _input_multilabel_prob as _input_mlb_prob
from tests.metrics.utils import MetricTester, NUM_CLASSES, THRESHOLD
torch.manual_seed(42)
@ -57,7 +57,7 @@ def _sk_stat_scores(preds, target, reduce, num_classes, is_multiclass, ignore_in
return sk_stats
def _sk_stat_scores_mdmc(preds, target, reduce, mdmc_reduce, num_classes, is_multiclass, ignore_index, top_k):
def _sk_stat_scores_mdim_mcls(preds, target, reduce, mdmc_reduce, num_classes, is_multiclass, ignore_index, top_k):
preds, target, _ = _input_format_classification(
preds, target, threshold=THRESHOLD, num_classes=num_classes, is_multiclass=is_multiclass, top_k=top_k
)
@ -83,13 +83,13 @@ def _sk_stat_scores_mdmc(preds, target, reduce, mdmc_reduce, num_classes, is_mul
@pytest.mark.parametrize(
"reduce, mdmc_reduce, num_classes, inputs, ignore_index",
[
["unknown", None, None, _binary_inputs, None],
["micro", "unknown", None, _binary_inputs, None],
["macro", None, None, _binary_inputs, None],
["micro", None, None, _mdmc_prob, None],
["micro", None, None, _binary_prob_inputs, 0],
["micro", None, None, _mc_prob, NUM_CLASSES],
["micro", None, NUM_CLASSES, _mc_prob, NUM_CLASSES],
["unknown", None, None, _input_binary, None],
["micro", "unknown", None, _input_binary, None],
["macro", None, None, _input_binary, None],
["micro", None, None, _input_mdmc_prob, None],
["micro", None, None, _input_binary_prob, 0],
["micro", None, None, _input_mccls_prob, NUM_CLASSES],
["micro", None, NUM_CLASSES, _input_mccls_prob, NUM_CLASSES],
],
)
def test_wrong_params(reduce, mdmc_reduce, num_classes, inputs, ignore_index):
@ -120,18 +120,21 @@ def test_wrong_threshold():
@pytest.mark.parametrize(
"preds, target, sk_fn, mdmc_reduce, num_classes, is_multiclass, top_k",
[
(_binary_prob_inputs.preds, _binary_prob_inputs.target, _sk_stat_scores, None, 1, None, None),
(_binary_inputs.preds, _binary_inputs.target, _sk_stat_scores, None, 1, False, None),
(_ml_prob.preds, _ml_prob.target, _sk_stat_scores, None, NUM_CLASSES, None, None),
(_ml_prob.preds, _ml_prob.target, _sk_stat_scores, None, NUM_CLASSES, None, 2),
(_multilabel_inputs.preds, _multilabel_inputs.target, _sk_stat_scores, None, NUM_CLASSES, False, None),
(_mc_prob.preds, _mc_prob.target, _sk_stat_scores, None, NUM_CLASSES, None, None),
(_mc_prob.preds, _mc_prob.target, _sk_stat_scores, None, NUM_CLASSES, None, 2),
(_multiclass_inputs.preds, _multiclass_inputs.target, _sk_stat_scores, None, NUM_CLASSES, None, None),
(_mdmc.preds, _mdmc.target, _sk_stat_scores_mdmc, "samplewise", NUM_CLASSES, None, None),
(_mdmc_prob.preds, _mdmc_prob.target, _sk_stat_scores_mdmc, "samplewise", NUM_CLASSES, None, None),
(_mdmc.preds, _mdmc.target, _sk_stat_scores_mdmc, "global", NUM_CLASSES, None, None),
(_mdmc_prob.preds, _mdmc_prob.target, _sk_stat_scores_mdmc, "global", NUM_CLASSES, None, None),
(_input_binary_prob.preds, _input_binary_prob.target, _sk_stat_scores, None, 1, None, None),
(_input_binary.preds, _input_binary.target, _sk_stat_scores, None, 1, False, None),
(_input_mlb_prob.preds, _input_mlb_prob.target, _sk_stat_scores, None, NUM_CLASSES, None, None),
(_input_mlb_prob.preds, _input_mlb_prob.target, _sk_stat_scores, None, NUM_CLASSES, None, 2),
(_input_mcls.preds, _input_mcls.target, _sk_stat_scores, None, NUM_CLASSES, False, None),
(_input_mccls_prob.preds, _input_mccls_prob.target, _sk_stat_scores, None, NUM_CLASSES, None, None),
(_input_mccls_prob.preds, _input_mccls_prob.target, _sk_stat_scores, None, NUM_CLASSES, None, 2),
(_input_multiclass.preds, _input_multiclass.target, _sk_stat_scores, None, NUM_CLASSES, None, None),
(_input_mdmc.preds, _input_mdmc.target, _sk_stat_scores_mdim_mcls, "samplewise", NUM_CLASSES, None, None),
(
_input_mdmc_prob.preds, _input_mdmc_prob.target, _sk_stat_scores_mdim_mcls, "samplewise", NUM_CLASSES, None,
None
),
(_input_mdmc.preds, _input_mdmc.target, _sk_stat_scores_mdim_mcls, "global", NUM_CLASSES, None, None),
(_input_mdmc_prob.preds, _input_mdmc_prob.target, _sk_stat_scores_mdim_mcls, "global", NUM_CLASSES, None, None),
],
)
class TestStatScores(MetricTester):


@ -63,7 +63,7 @@ def test_binary_clf_curve(sample_weight, pos_label, exp_shape):
# if you fix the array inside the function, you'd also have to fix the shape,
# because whenever the array changes, the shape has to change with it
seed_everything(0)
pred = torch.randint(low=51, high=99, size=(100,), dtype=torch.float) / 100
pred = torch.randint(low=51, high=99, size=(100, ), dtype=torch.float) / 100
target = torch.tensor([0, 1] * 50, dtype=torch.int)
if sample_weight is not None:
sample_weight = torch.ones_like(pred) * sample_weight
@ -73,9 +73,9 @@ def test_binary_clf_curve(sample_weight, pos_label, exp_shape):
assert isinstance(tps, torch.Tensor)
assert isinstance(fps, torch.Tensor)
assert isinstance(thresh, torch.Tensor)
assert tps.shape == (exp_shape,)
assert fps.shape == (exp_shape,)
assert thresh.shape == (exp_shape,)
assert tps.shape == (exp_shape, )
assert fps.shape == (exp_shape, )
assert thresh.shape == (exp_shape, )
@pytest.mark.parametrize(['pred', 'target', 'expected'], [


@ -46,19 +46,19 @@ def test_multi_batch_image_gradients():
image = torch.stack([single_channel_img for _ in range(BATCH_SIZE)], dim=0)
true_dy = [
[5., 5., 5., 5., 5., ],
[5., 5., 5., 5., 5., ],
[5., 5., 5., 5., 5., ],
[5., 5., 5., 5., 5., ],
[0., 0., 0., 0., 0., ]
[5., 5., 5., 5., 5.],
[5., 5., 5., 5., 5.],
[5., 5., 5., 5., 5.],
[5., 5., 5., 5., 5.],
[0., 0., 0., 0., 0.],
]
true_dx = [
[1., 1., 1., 1., 0., ],
[1., 1., 1., 1., 0., ],
[1., 1., 1., 1., 0., ],
[1., 1., 1., 1., 0., ],
[1., 1., 1., 1., 0., ]
[1., 1., 1., 1., 0.],
[1., 1., 1., 1., 0.],
[1., 1., 1., 1., 0.],
[1., 1., 1., 1., 0.],
[1., 1., 1., 1., 0.],
]
true_dy = torch.Tensor(true_dy)
true_dx = torch.Tensor(true_dx)
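For context, the expected tables above correspond to a forward finite difference with zero padding on the last row and column. Assuming the fixture image is the usual 5x5 torch.arange grid (its construction sits outside this hunk, so this is an assumption), a short sketch reproduces the pattern:

import torch

img = torch.arange(25, dtype=torch.float32).reshape(1, 1, 5, 5)  # assumed fixture: values 0..24
dy = torch.zeros_like(img)
dx = torch.zeros_like(img)
dy[..., :-1, :] = img[..., 1:, :] - img[..., :-1, :]  # row-wise difference: 5 everywhere
dx[..., :, :-1] = img[..., :, 1:] - img[..., :, :-1]  # column-wise difference: 1 everywhere
# the zero padding yields the trailing row of 0. in true_dy and the trailing column of 0. in true_dx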
@ -85,19 +85,19 @@ def test_image_gradients():
image = torch.reshape(image, (BATCH_SIZE, CHANNELS, HEIGHT, WIDTH))
true_dy = [
[5., 5., 5., 5., 5., ],
[5., 5., 5., 5., 5., ],
[5., 5., 5., 5., 5., ],
[5., 5., 5., 5., 5., ],
[0., 0., 0., 0., 0., ]
[5., 5., 5., 5., 5.],
[5., 5., 5., 5., 5.],
[5., 5., 5., 5., 5.],
[5., 5., 5., 5., 5.],
[0., 0., 0., 0., 0.],
]
true_dx = [
[1., 1., 1., 1., 0., ],
[1., 1., 1., 1., 0., ],
[1., 1., 1., 1., 0., ],
[1., 1., 1., 1., 0., ],
[1., 1., 1., 1., 0., ]
[1., 1., 1., 1., 0.],
[1., 1., 1., 1., 0.],
[1., 1., 1., 1., 0.],
[1., 1., 1., 1., 0.],
[1., 1., 1., 1., 0.],
]
true_dy = torch.Tensor(true_dy)


@ -15,7 +15,6 @@ REFERENCE2 = tuple(
)
REFERENCE3 = tuple("It is the practical guide for the army always to heed the directions of the party".split())
# example taken from
# https://www.nltk.org/api/nltk.translate.html?highlight=bleu%20score#nltk.translate.bleu_score.corpus_bleu
HYP1 = "It is a guide to action which ensures that the military always obeys the commands of the party".split()
@ -44,7 +43,10 @@ smooth_func = SmoothingFunction().method2
)
def test_bleu_score(weights, n_gram, smooth_func, smooth):
nltk_output = sentence_bleu(
[REFERENCE1, REFERENCE2, REFERENCE3], HYPOTHESIS1, weights=weights, smoothing_function=smooth_func
[REFERENCE1, REFERENCE2, REFERENCE3],
HYPOTHESIS1,
weights=weights,
smoothing_function=smooth_func,
)
pl_output = bleu_score([HYPOTHESIS1], [[REFERENCE1, REFERENCE2, REFERENCE3]], n_gram=n_gram, smooth=smooth)
assert torch.allclose(pl_output, torch.tensor(nltk_output))


@ -16,15 +16,13 @@ def test_reduce():
def test_class_reduce():
num = torch.randint(1, 10, (100,)).float()
denom = torch.randint(10, 20, (100,)).float()
weights = torch.randint(1, 100, (100,)).float()
num = torch.randint(1, 10, (100, )).float()
denom = torch.randint(10, 20, (100, )).float()
weights = torch.randint(1, 100, (100, )).float()
assert torch.allclose(class_reduce(num, denom, weights, 'micro'),
torch.sum(num) / torch.sum(denom))
assert torch.allclose(class_reduce(num, denom, weights, 'macro'),
torch.mean(num / denom))
assert torch.allclose(class_reduce(num, denom, weights, 'weighted'),
torch.sum(num / denom * (weights / torch.sum(weights))))
assert torch.allclose(class_reduce(num, denom, weights, 'none'),
num / denom)
assert torch.allclose(class_reduce(num, denom, weights, 'micro'), torch.sum(num) / torch.sum(denom))
assert torch.allclose(class_reduce(num, denom, weights, 'macro'), torch.mean(num / denom))
assert torch.allclose(
class_reduce(num, denom, weights, 'weighted'), torch.sum(num / denom * (weights / torch.sum(weights)))
)
assert torch.allclose(class_reduce(num, denom, weights, 'none'), num / denom)
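The four assertions above encode the reduction formulas themselves. A small hand-checkable example (hypothetical numbers, assuming class_reduce is imported as in this module):

num = torch.tensor([1., 3.])
denom = torch.tensor([2., 4.])
weights = torch.tensor([1., 3.])
# micro:    sum(num) / sum(denom)                      = 4 / 6                 = 0.6667
# macro:    mean(num / denom)                          = mean(0.5, 0.75)       = 0.625
# weighted: sum(num / denom * weights / sum(weights))  = 0.5*0.25 + 0.75*0.75  = 0.6875
# 'none':   num / denom                                = [0.5, 0.75]
print(class_reduce(num, denom, weights, 'weighted'))  # tensor(0.6875)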


@ -13,13 +13,11 @@ def test_against_sklearn(similarity, reduction):
batch = torch.randn(5, 10, device=device) # 100 samples in 10 dimensions
pl_dist = embedding_similarity(batch, similarity=similarity,
reduction=reduction, zero_diagonal=False)
pl_dist = embedding_similarity(batch, similarity=similarity, reduction=reduction, zero_diagonal=False)
def sklearn_embedding_distance(batch, similarity, reduction):
metric_func = {'cosine': pairwise.cosine_similarity,
'dot': pairwise.linear_kernel}[similarity]
metric_func = {'cosine': pairwise.cosine_similarity, 'dot': pairwise.linear_kernel}[similarity]
dist = metric_func(batch, batch)
if reduction == 'mean':
@ -28,8 +26,7 @@ def test_against_sklearn(similarity, reduction):
return dist.sum(axis=-1)
return dist
sk_dist = sklearn_embedding_distance(batch.cpu().detach().numpy(),
similarity=similarity, reduction=reduction)
sk_dist = sklearn_embedding_distance(batch.cpu().detach().numpy(), similarity=similarity, reduction=reduction)
sk_dist = torch.tensor(sk_dist, dtype=torch.float, device=device)
assert torch.allclose(sk_dist, pl_dist)


@ -15,10 +15,14 @@ num_targets = 5
Input = namedtuple('Input', ["preds", "target"])
_single_target_inputs = Input(preds=torch.rand(NUM_BATCHES, BATCH_SIZE), target=torch.rand(NUM_BATCHES, BATCH_SIZE),)
_single_target_inputs = Input(
preds=torch.rand(NUM_BATCHES, BATCH_SIZE),
target=torch.rand(NUM_BATCHES, BATCH_SIZE),
)
_multi_target_inputs = Input(
preds=torch.rand(NUM_BATCHES, BATCH_SIZE, num_targets), target=torch.rand(NUM_BATCHES, BATCH_SIZE, num_targets),
preds=torch.rand(NUM_BATCHES, BATCH_SIZE, num_targets),
target=torch.rand(NUM_BATCHES, BATCH_SIZE, num_targets),
)
@ -43,6 +47,7 @@ def _multi_target_sk_metric(preds, target, sk_fn=explained_variance_score):
],
)
class TestExplainedVariance(MetricTester):
@pytest.mark.parametrize("ddp", [True, False])
@pytest.mark.parametrize("dist_sync_on_step", [True, False])
def test_explained_variance(self, multioutput, preds, target, sk_metric, ddp, dist_sync_on_step):
@ -69,4 +74,4 @@ class TestExplainedVariance(MetricTester):
def test_error_on_different_shape(metric_class=ExplainedVariance):
metric = metric_class()
with pytest.raises(RuntimeError, match='Predictions and targets are expected to have the same shape'):
metric(torch.randn(100,), torch.randn(50,))
metric(torch.randn(100, ), torch.randn(50, ))


@ -17,10 +17,14 @@ num_targets = 5
Input = namedtuple('Input', ["preds", "target"])
_single_target_inputs = Input(preds=torch.rand(NUM_BATCHES, BATCH_SIZE), target=torch.rand(NUM_BATCHES, BATCH_SIZE),)
_single_target_inputs = Input(
preds=torch.rand(NUM_BATCHES, BATCH_SIZE),
target=torch.rand(NUM_BATCHES, BATCH_SIZE),
)
_multi_target_inputs = Input(
preds=torch.rand(NUM_BATCHES, BATCH_SIZE, num_targets), target=torch.rand(NUM_BATCHES, BATCH_SIZE, num_targets),
preds=torch.rand(NUM_BATCHES, BATCH_SIZE, num_targets),
target=torch.rand(NUM_BATCHES, BATCH_SIZE, num_targets),
)
@ -52,10 +56,12 @@ def _multi_target_sk_metric(preds, target, sk_fn=mean_squared_error):
],
)
class TestMeanError(MetricTester):
@pytest.mark.parametrize("ddp", [True, False])
@pytest.mark.parametrize("dist_sync_on_step", [True, False])
def test_mean_error_class(self, preds, target, sk_metric, metric_class,
metric_functional, sk_fn, ddp, dist_sync_on_step):
def test_mean_error_class(
self, preds, target, sk_metric, metric_class, metric_functional, sk_fn, ddp, dist_sync_on_step
):
self.run_class_metric_test(
ddp=ddp,
preds=preds,
@ -78,4 +84,4 @@ class TestMeanError(MetricTester):
def test_error_on_different_shape(metric_class):
metric = metric_class()
with pytest.raises(RuntimeError, match='Predictions and targets are expected to have the same shape'):
metric(torch.randn(100,), torch.randn(50,))
metric(torch.randn(100, ), torch.randn(50, ))


@ -12,15 +12,13 @@ from tests.metrics.utils import BATCH_SIZE, MetricTester, NUM_BATCHES
torch.manual_seed(42)
Input = namedtuple('Input', ["preds", "target"])
_inputs = [
Input(
preds=torch.randint(n_cls_pred, (NUM_BATCHES, BATCH_SIZE), dtype=torch.float),
target=torch.randint(n_cls_target, (NUM_BATCHES, BATCH_SIZE), dtype=torch.float),
)
for n_cls_pred, n_cls_target in [(10, 10), (5, 10), (10, 5)]
) for n_cls_pred, n_cls_target in [(10, 10), (5, 10), (10, 5)]
]
@ -52,6 +50,7 @@ def _base_e_sk_metric(preds, target, data_range):
],
)
class TestPSNR(MetricTester):
@pytest.mark.parametrize("ddp", [True, False])
@pytest.mark.parametrize("dist_sync_on_step", [True, False])
def test_psnr(self, preds, target, data_range, base, sk_metric, ddp, dist_sync_on_step):
@ -61,7 +60,10 @@ class TestPSNR(MetricTester):
target,
PSNR,
partial(sk_metric, data_range=data_range),
metric_args={"data_range": data_range, "base": base},
metric_args={
"data_range": data_range,
"base": base
},
dist_sync_on_step=dist_sync_on_step,
)
@ -71,5 +73,8 @@ class TestPSNR(MetricTester):
target,
psnr,
partial(sk_metric, data_range=data_range),
metric_args={"data_range": data_range, "base": base},
metric_args={
"data_range": data_range,
"base": base
},
)


@ -15,10 +15,14 @@ num_targets = 5
Input = namedtuple('Input', ["preds", "target"])
_single_target_inputs = Input(preds=torch.rand(NUM_BATCHES, BATCH_SIZE), target=torch.rand(NUM_BATCHES, BATCH_SIZE),)
_single_target_inputs = Input(
preds=torch.rand(NUM_BATCHES, BATCH_SIZE),
target=torch.rand(NUM_BATCHES, BATCH_SIZE),
)
_multi_target_inputs = Input(
preds=torch.rand(NUM_BATCHES, BATCH_SIZE, num_targets), target=torch.rand(NUM_BATCHES, BATCH_SIZE, num_targets),
preds=torch.rand(NUM_BATCHES, BATCH_SIZE, num_targets),
target=torch.rand(NUM_BATCHES, BATCH_SIZE, num_targets),
)
@ -50,6 +54,7 @@ def _multi_target_sk_metric(preds, target, adjusted, multioutput):
],
)
class TestR2Score(MetricTester):
@pytest.mark.parametrize("ddp", [True, False])
@pytest.mark.parametrize("dist_sync_on_step", [True, False])
def test_r2(self, adjusted, multioutput, preds, target, sk_metric, num_outputs, ddp, dist_sync_on_step):
@ -60,9 +65,7 @@ class TestR2Score(MetricTester):
R2Score,
partial(sk_metric, adjusted=adjusted, multioutput=multioutput),
dist_sync_on_step,
metric_args=dict(adjusted=adjusted,
multioutput=multioutput,
num_outputs=num_outputs),
metric_args=dict(adjusted=adjusted, multioutput=multioutput, num_outputs=num_outputs),
)
def test_r2_functional(self, adjusted, multioutput, preds, target, sk_metric, num_outputs):
@ -71,39 +74,41 @@ class TestR2Score(MetricTester):
target,
r2score,
partial(sk_metric, adjusted=adjusted, multioutput=multioutput),
metric_args=dict(adjusted=adjusted,
multioutput=multioutput),
metric_args=dict(adjusted=adjusted, multioutput=multioutput),
)
def test_error_on_different_shape(metric_class=R2Score):
metric = metric_class()
with pytest.raises(RuntimeError, match='Predictions and targets are expected to have the same shape'):
metric(torch.randn(100,), torch.randn(50,))
metric(torch.randn(100, ), torch.randn(50, ))
def test_error_on_multidim_tensors(metric_class=R2Score):
metric = metric_class()
with pytest.raises(ValueError, match=r'Expected both prediction and target to be 1D or 2D tensors,'
r' but recevied tensors with dimension .'):
with pytest.raises(
ValueError,
match=r'Expected both prediction and target to be 1D or 2D tensors,'
r' but recevied tensors with dimension .'
):
metric(torch.randn(10, 20, 5), torch.randn(10, 20, 5))
def test_error_on_too_few_samples(metric_class=R2Score):
metric = metric_class()
with pytest.raises(ValueError, match='Needs atleast two samples to calculate r2 score.'):
metric(torch.randn(1,), torch.randn(1,))
metric(torch.randn(1, ), torch.randn(1, ))
def test_warning_on_too_large_adjusted(metric_class=R2Score):
metric = metric_class(adjusted=10)
with pytest.warns(UserWarning,
match="More independent regressions than datapoints in"
" adjusted r2 score. Falls back to standard r2 score."):
metric(torch.randn(10,), torch.randn(10,))
with pytest.warns(
UserWarning,
match="More independent regressions than datapoints in"
" adjusted r2 score. Falls back to standard r2 score."
):
metric(torch.randn(10, ), torch.randn(10, ))
with pytest.warns(UserWarning,
match="Division by zero in adjusted r2 score. Falls back to"
" standard r2 score."):
metric(torch.randn(11,), torch.randn(11,))
with pytest.warns(UserWarning, match="Division by zero in adjusted r2 score. Falls back to" " standard r2 score."):
metric(torch.randn(11, ), torch.randn(11, ))


@ -11,10 +11,8 @@ from tests.metrics.utils import BATCH_SIZE, MetricTester, NUM_BATCHES
torch.manual_seed(42)
Input = namedtuple('Input', ["preds", "target", "multichannel"])
_inputs = []
for size, channel, coef, multichannel, dtype in [
(12, 3, 0.9, True, torch.float),
@ -23,13 +21,11 @@ for size, channel, coef, multichannel, dtype in [
(15, 3, 0.6, True, torch.float64),
]:
preds = torch.rand(NUM_BATCHES, BATCH_SIZE, channel, size, size, dtype=dtype)
_inputs.append(
Input(
preds=preds,
target=preds * coef,
multichannel=multichannel,
)
)
_inputs.append(Input(
preds=preds,
target=preds * coef,
multichannel=multichannel,
))
def _sk_metric(preds, target, data_range, multichannel):
@ -41,8 +37,14 @@ def _sk_metric(preds, target, data_range, multichannel):
sk_target = sk_target[:, :, :, 0]
return structural_similarity(
sk_target, sk_preds, data_range=data_range, multichannel=multichannel,
gaussian_weights=True, win_size=11, sigma=1.5, use_sample_covariance=False
sk_target,
sk_preds,
data_range=data_range,
multichannel=multichannel,
gaussian_weights=True,
win_size=11,
sigma=1.5,
use_sample_covariance=False
)


@ -7,13 +7,16 @@ import torch
from pytorch_lightning.metrics.compositional import CompositionalMetric
from pytorch_lightning.metrics.metric import Metric
_MARK_TORCH_LOWER_1_4 = dict(condition=LooseVersion(torch.__version__) < LooseVersion("1.5.0"),
reason='required PT >= 1.5')
_MARK_TORCH_LOWER_1_5 = dict(condition=LooseVersion(torch.__version__) < LooseVersion("1.6.0"),
reason='required PT >= 1.6')
_MARK_TORCH_LOWER_1_4 = dict(
condition=LooseVersion(torch.__version__) < LooseVersion("1.5.0"), reason='required PT >= 1.5'
)
_MARK_TORCH_LOWER_1_5 = dict(
condition=LooseVersion(torch.__version__) < LooseVersion("1.6.0"), reason='required PT >= 1.6'
)
class DummyMetric(Metric):
def __init__(self, val_to_return):
super().__init__()
self._num_updates = 0
@ -295,7 +298,7 @@ def test_metrics_or(second_operand, expected_result):
def test_metrics_pow(second_operand, expected_result):
first_metric = DummyMetric(2)
final_pow = first_metric ** second_operand
final_pow = first_metric**second_operand
assert isinstance(final_pow, CompositionalMetric)
@ -349,7 +352,7 @@ def test_metrics_rmod(first_operand, expected_result):
def test_metrics_rpow(first_operand, expected_result):
second_operand = DummyMetric(2)
final_rpow = first_operand ** second_operand
final_rpow = first_operand**second_operand
assert isinstance(final_rpow, CompositionalMetric)


@ -43,13 +43,14 @@ def _test_ddp_sum_cat(rank, worldsize):
@pytest.mark.skipif(sys.platform == "win32", reason="DDP not available on windows")
@pytest.mark.parametrize("process", [_test_ddp_cat, _test_ddp_sum, _test_ddp_sum_cat])
def test_ddp(process):
torch.multiprocessing.spawn(process, args=(2,), nprocs=2)
torch.multiprocessing.spawn(process, args=(2, ), nprocs=2)
def _test_non_contiguous_tensors(rank, worldsize):
setup_ddp(rank, worldsize)
class DummyMetric(Metric):
def __init__(self):
super().__init__()
self.add_state("x", default=[], dist_reduce_fx=None)
@ -68,4 +69,4 @@ def _test_non_contiguous_tensors(rank, worldsize):
@pytest.mark.skipif(sys.platform == "win32", reason="DDP not available on windows")
def test_non_contiguous_tensors():
""" Test that gather_all operation works for non contiguous tensors """
torch.multiprocessing.spawn(_test_non_contiguous_tensors, args=(2,), nprocs=2)
torch.multiprocessing.spawn(_test_non_contiguous_tensors, args=(2, ), nprocs=2)
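As background on what "non-contiguous" means here, a minimal sketch independent of DDP (torch only):

x = torch.arange(6).reshape(2, 3).t()  # a transposed view is non-contiguous in memory
assert not x.is_contiguous()
y = x.contiguous()                     # distributed gather ops typically expect contiguous memory
assert y.is_contiguous() and torch.equal(x, y)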


@ -55,7 +55,7 @@ def test_add_state():
assert np.allclose(a._reductions["b"](torch.tensor([1.0, 2.0])).numpy(), 1.5)
a.add_state("c", torch.tensor(0), "cat")
assert a._reductions["c"]([torch.tensor([1]), torch.tensor([1])]).shape == (2,)
assert a._reductions["c"]([torch.tensor([1]), torch.tensor([1])]).shape == (2, )
with pytest.raises(ValueError):
a.add_state("d1", torch.tensor(0), 'xyz')
@ -89,6 +89,7 @@ def test_add_state_persistent():
def test_reset():
class A(Dummy):
pass
@ -109,7 +110,9 @@ def test_reset():
def test_update():
class A(Dummy):
def update(self, x):
self.x += x
@ -125,7 +128,9 @@ def test_update():
def test_compute():
class A(Dummy):
def update(self, x):
self.x += x
@ -150,7 +155,9 @@ def test_compute():
def test_forward():
class A(Dummy):
def update(self, x):
self.x += x
@ -168,6 +175,7 @@ def test_forward():
class DummyMetric1(Dummy):
def update(self, x):
self.x += x
@ -176,6 +184,7 @@ class DummyMetric1(Dummy):
class DummyMetric2(Dummy):
def update(self, y):
self.x -= y
@ -214,7 +223,9 @@ def test_state_dict(tmpdir):
def test_child_metric_state_dict():
""" test that child metric states will be added to parent state dict """
class TestModule(nn.Module):
def __init__(self):
super().__init__()
self.metric = Dummy()
@ -226,7 +237,7 @@ def test_child_metric_state_dict():
expected_state_dict = {
'metric.a': torch.tensor(0),
'metric.b': [],
'metric.c': torch.tensor(0)
'metric.c': torch.tensor(0),
}
assert module.state_dict() == expected_state_dict
@ -317,8 +328,7 @@ def test_metric_collection_wrong_input(tmpdir):
# Not all input are metrics (dict)
with pytest.raises(ValueError):
_ = MetricCollection({'metric1': m1,
'metric2': 5})
_ = MetricCollection({'metric1': m1, 'metric2': 5})
# Same metric passed in multiple times
with pytest.raises(ValueError, match='Encountered two metrics both named *.'):


@ -6,6 +6,7 @@ from tests.base.boring_model import BoringModel
class SumMetric(Metric):
def __init__(self):
super().__init__()
self.add_state("x", torch.tensor(0.0), dist_reduce_fx="sum")
@ -18,6 +19,7 @@ class SumMetric(Metric):
class DiffMetric(Metric):
def __init__(self):
super().__init__()
self.add_state("x", torch.tensor(0.0), dist_reduce_fx="sum")
@ -30,7 +32,9 @@ class DiffMetric(Metric):
def test_metric_lightning(tmpdir):
class TestModel(BoringModel):
def __init__(self):
super().__init__()
self.metric = SumMetric()
@ -64,7 +68,9 @@ def test_metric_lightning(tmpdir):
def test_metric_lightning_log(tmpdir):
""" Test logging a metric object and that the metric state gets reset after each epoch."""
class TestModel(BoringModel):
def __init__(self):
super().__init__()
self.metric_step = SumMetric()
@ -103,7 +109,9 @@ def test_metric_lightning_log(tmpdir):
def test_scriptable(tmpdir):
class TestModel(BoringModel):
def __init__(self):
super().__init__()
# the metric is not used in the module's `forward`
@ -141,7 +149,9 @@ def test_scriptable(tmpdir):
def test_metric_collection_lightning_log(tmpdir):
class TestModel(BoringModel):
def __init__(self):
super().__init__()
self.metric = MetricCollection([SumMetric(), DiffMetric()])