yapf tests metrics (PL^5845)
(cherry picked from commit cce1be3470a58aa46047d598ff00e9ad23f5f877)
This commit is contained in:
Parent
8bc17d8db8
Commit
a718ab3cca
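
(For context: a reformat commit like this is typically produced by running yapf over the test tree, e.g. "pip install yapf" followed by "yapf --in-place --recursive tests/metrics". The exact style knobs live in the repository's own yapf config, so treat that invocation as an illustration rather than the project's actual command.)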

@@ -6,35 +6,31 @@ from tests.metrics.utils import BATCH_SIZE, EXTRA_DIM, NUM_BATCHES, NUM_CLASSES

Input = namedtuple('Input', ["preds", "target"])

_input_binary_prob = Input(
    preds=torch.rand(NUM_BATCHES, BATCH_SIZE), target=torch.randint(high=2, size=(NUM_BATCHES, BATCH_SIZE))
)
_binary_prob_inputs = Input(
    preds=torch.rand(NUM_BATCHES, BATCH_SIZE),
_input_binary = Input(
    preds=torch.randint(high=2, size=(NUM_BATCHES, BATCH_SIZE)),
    target=torch.randint(high=2, size=(NUM_BATCHES, BATCH_SIZE))
)

_binary_inputs = Input(
    preds=torch.randint(high=2, size=(NUM_BATCHES, BATCH_SIZE,)),
    target=torch.randint(high=2, size=(NUM_BATCHES, BATCH_SIZE,))
)

_multilabel_prob_inputs = Input(
_input_multilabel_prob = Input(
    preds=torch.rand(NUM_BATCHES, BATCH_SIZE, NUM_CLASSES),
    target=torch.randint(high=2, size=(NUM_BATCHES, BATCH_SIZE, NUM_CLASSES))
)

_multilabel_multidim_prob_inputs = Input(
_input_multilabel_multidim_prob = Input(
    preds=torch.rand(NUM_BATCHES, BATCH_SIZE, NUM_CLASSES, EXTRA_DIM),
    target=torch.randint(high=2, size=(NUM_BATCHES, BATCH_SIZE, NUM_CLASSES, EXTRA_DIM))
)

_multilabel_inputs = Input(
_input_multilabel = Input(
    preds=torch.randint(high=2, size=(NUM_BATCHES, BATCH_SIZE, NUM_CLASSES)),
    target=torch.randint(high=2, size=(NUM_BATCHES, BATCH_SIZE, NUM_CLASSES))
)

_multilabel_multidim_inputs = Input(
_input_multilabel_multidim = Input(
    preds=torch.randint(high=2, size=(NUM_BATCHES, BATCH_SIZE, NUM_CLASSES, EXTRA_DIM)),
    target=torch.randint(high=2, size=(NUM_BATCHES, BATCH_SIZE, NUM_CLASSES, EXTRA_DIM))
)

@@ -43,21 +39,16 @@ _multilabel_multidim_inputs = Input(

__temp_preds = torch.randint(high=2, size=(NUM_BATCHES, BATCH_SIZE, NUM_CLASSES))
__temp_target = abs(__temp_preds - 1)

_multilabel_inputs_no_match = Input(
    preds=__temp_preds,
    target=__temp_target
)
_input_multilabel_no_match = Input(preds=__temp_preds, target=__temp_target)

__mc_prob_preds = torch.rand(NUM_BATCHES, BATCH_SIZE, NUM_CLASSES)
__mc_prob_preds = __mc_prob_preds / __mc_prob_preds.sum(dim=2, keepdim=True)

_multiclass_prob_inputs = Input(
    preds=__mc_prob_preds,
    target=torch.randint(high=NUM_CLASSES, size=(NUM_BATCHES, BATCH_SIZE))
_input_multiclass_prob = Input(
    preds=__mc_prob_preds, target=torch.randint(high=NUM_CLASSES, size=(NUM_BATCHES, BATCH_SIZE))
)

_multiclass_inputs = Input(
_input_multiclass = Input(
    preds=torch.randint(high=NUM_CLASSES, size=(NUM_BATCHES, BATCH_SIZE)),
    target=torch.randint(high=NUM_CLASSES, size=(NUM_BATCHES, BATCH_SIZE))
)

@@ -65,12 +56,11 @@ _multiclass_inputs = Input(

__mdmc_prob_preds = torch.rand(NUM_BATCHES, BATCH_SIZE, NUM_CLASSES, EXTRA_DIM)
__mdmc_prob_preds = __mdmc_prob_preds / __mdmc_prob_preds.sum(dim=2, keepdim=True)

_multidim_multiclass_prob_inputs = Input(
    preds=__mdmc_prob_preds,
    target=torch.randint(high=NUM_CLASSES, size=(NUM_BATCHES, BATCH_SIZE, EXTRA_DIM))
_input_multidim_multiclass_prob = Input(
    preds=__mdmc_prob_preds, target=torch.randint(high=NUM_CLASSES, size=(NUM_BATCHES, BATCH_SIZE, EXTRA_DIM))
)

_multidim_multiclass_inputs = Input(
_input_multidim_multiclass = Input(
    preds=torch.randint(high=NUM_CLASSES, size=(NUM_BATCHES, BATCH_SIZE, EXTRA_DIM)),
    target=torch.randint(high=NUM_CLASSES, size=(NUM_BATCHES, BATCH_SIZE, EXTRA_DIM))
)
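
The rand fixtures above are renormalized along dim 2 so every sample carries a valid probability distribution over the classes. A minimal standalone check of that trick (shapes chosen arbitrarily here, not the test constants):

    import torch

    preds = torch.rand(4, 8, 5)                     # (num batches, batch size, num classes)
    preds = preds / preds.sum(dim=2, keepdim=True)  # each (batch, sample) row now sums to 1
    assert torch.allclose(preds.sum(dim=2), torch.ones(4, 8))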

@@ -8,18 +8,15 @@ from sklearn.metrics import accuracy_score as sk_accuracy

from pytorch_lightning.metrics import Accuracy
from pytorch_lightning.metrics.classification.helpers import _input_format_classification, DataType
from pytorch_lightning.metrics.functional import accuracy
from tests.metrics.classification.inputs import (
    _binary_inputs,
    _binary_prob_inputs,
    _multiclass_inputs,
    _multiclass_prob_inputs,
    _multidim_multiclass_inputs,
    _multidim_multiclass_prob_inputs,
    _multilabel_inputs,
    _multilabel_multidim_inputs,
    _multilabel_multidim_prob_inputs,
    _multilabel_prob_inputs,
)
from tests.metrics.classification.inputs import _input_binary, _input_binary_prob
from tests.metrics.classification.inputs import _input_multiclass as _input_mcls
from tests.metrics.classification.inputs import _input_multiclass_prob as _input_mcls_prob
from tests.metrics.classification.inputs import _input_multidim_multiclass as _input_mdmc
from tests.metrics.classification.inputs import _input_multidim_multiclass_prob as _input_mdmc_prob
from tests.metrics.classification.inputs import _input_multilabel as _input_mlb
from tests.metrics.classification.inputs import _input_multilabel_multidim as _input_mlmd
from tests.metrics.classification.inputs import _input_multilabel_multidim_prob as _input_mlmd_prob
from tests.metrics.classification.inputs import _input_multilabel_prob as _input_mlb_prob
from tests.metrics.utils import MetricTester, THRESHOLD

torch.manual_seed(42)

@@ -43,25 +40,26 @@ def _sk_accuracy(preds, target, subset_accuracy):

@pytest.mark.parametrize(
    "preds, target, subset_accuracy",
    [
        (_binary_prob_inputs.preds, _binary_prob_inputs.target, False),
        (_binary_inputs.preds, _binary_inputs.target, False),
        (_multilabel_prob_inputs.preds, _multilabel_prob_inputs.target, True),
        (_multilabel_prob_inputs.preds, _multilabel_prob_inputs.target, False),
        (_multilabel_inputs.preds, _multilabel_inputs.target, True),
        (_multilabel_inputs.preds, _multilabel_inputs.target, False),
        (_multiclass_prob_inputs.preds, _multiclass_prob_inputs.target, False),
        (_multiclass_inputs.preds, _multiclass_inputs.target, False),
        (_multidim_multiclass_prob_inputs.preds, _multidim_multiclass_prob_inputs.target, False),
        (_multidim_multiclass_prob_inputs.preds, _multidim_multiclass_prob_inputs.target, True),
        (_multidim_multiclass_inputs.preds, _multidim_multiclass_inputs.target, False),
        (_multidim_multiclass_inputs.preds, _multidim_multiclass_inputs.target, True),
        (_multilabel_multidim_prob_inputs.preds, _multilabel_multidim_prob_inputs.target, True),
        (_multilabel_multidim_prob_inputs.preds, _multilabel_multidim_prob_inputs.target, False),
        (_multilabel_multidim_inputs.preds, _multilabel_multidim_inputs.target, True),
        (_multilabel_multidim_inputs.preds, _multilabel_multidim_inputs.target, False),
        (_input_binary_prob.preds, _input_binary_prob.target, False),
        (_input_binary.preds, _input_binary.target, False),
        (_input_mlb_prob.preds, _input_mlb_prob.target, True),
        (_input_mlb_prob.preds, _input_mlb_prob.target, False),
        (_input_mlb.preds, _input_mlb.target, True),
        (_input_mlb.preds, _input_mlb.target, False),
        (_input_mcls_prob.preds, _input_mcls_prob.target, False),
        (_input_mcls.preds, _input_mcls.target, False),
        (_input_mdmc_prob.preds, _input_mdmc_prob.target, False),
        (_input_mdmc_prob.preds, _input_mdmc_prob.target, True),
        (_input_mdmc.preds, _input_mdmc.target, False),
        (_input_mdmc.preds, _input_mdmc.target, True),
        (_input_mlmd_prob.preds, _input_mlmd_prob.target, True),
        (_input_mlmd_prob.preds, _input_mlmd_prob.target, False),
        (_input_mlmd.preds, _input_mlmd.target, True),
        (_input_mlmd.preds, _input_mlmd.target, False),
    ],
)
class TestAccuracies(MetricTester):

    @pytest.mark.parametrize("ddp", [False, True])
    @pytest.mark.parametrize("dist_sync_on_step", [False, True])
    def test_accuracy_class(self, ddp, dist_sync_on_step, preds, target, subset_accuracy):

@@ -72,7 +70,10 @@ class TestAccuracies(MetricTester):

            metric_class=Accuracy,
            sk_metric=partial(_sk_accuracy, subset_accuracy=subset_accuracy),
            dist_sync_on_step=dist_sync_on_step,
            metric_args={"threshold": THRESHOLD, "subset_accuracy": subset_accuracy},
            metric_args={
                "threshold": THRESHOLD,
                "subset_accuracy": subset_accuracy
            },
        )

    def test_accuracy_fn(self, preds, target, subset_accuracy):

@@ -81,21 +82,24 @@ class TestAccuracies(MetricTester):

            target,
            metric_functional=accuracy,
            sk_metric=partial(_sk_accuracy, subset_accuracy=subset_accuracy),
            metric_args={"threshold": THRESHOLD, "subset_accuracy": subset_accuracy},
            metric_args={
                "threshold": THRESHOLD,
                "subset_accuracy": subset_accuracy
            },
        )

_l1to4 = [0.1, 0.2, 0.3, 0.4]
_l1to4t3 = np.array([_l1to4, _l1to4, _l1to4])
_l1to4t3_mc = [_l1to4t3.T, _l1to4t3.T, _l1to4t3.T]
_l1to4t3_mcls = [_l1to4t3.T, _l1to4t3.T, _l1to4t3.T]

# The preds in these examples always put highest probability on class 3, second highest on class 2,
# third highest on class 1, and lowest on class 0
_topk_preds_mc = torch.tensor([_l1to4t3, _l1to4t3]).float()
_topk_target_mc = torch.tensor([[1, 2, 3], [2, 1, 0]])
_topk_preds_mcls = torch.tensor([_l1to4t3, _l1to4t3]).float()
_topk_target_mcls = torch.tensor([[1, 2, 3], [2, 1, 0]])

# This is like for MC case, but one sample in each batch is sabotaged with 0 class prediction :)
_topk_preds_mdmc = torch.tensor([_l1to4t3_mc, _l1to4t3_mc]).float()
_topk_preds_mdmc = torch.tensor([_l1to4t3_mcls, _l1to4t3_mcls]).float()
_topk_target_mdmc = torch.tensor([[[1, 1, 0], [2, 2, 2], [3, 3, 3]], [[2, 2, 0], [1, 1, 1], [0, 0, 0]]])
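
To see where the expected values in the parametrization below come from: with per-class probabilities [0.1, 0.2, 0.3, 0.4], top-1 always predicts class 3, top-2 predicts {3, 2}, and top-3 predicts {3, 2, 1}. Against the six targets [1, 2, 3, 2, 1, 0] that gives 1, 3 and 5 hits out of 6. A minimal standalone check of that arithmetic (not part of the test file):

    import torch

    preds = torch.tensor([[0.1, 0.2, 0.3, 0.4]] * 6)   # highest prob always on class 3
    target = torch.tensor([1, 2, 3, 2, 1, 0])
    for k in (1, 2, 3):
        topk = preds.topk(k, dim=1).indices
        acc = (topk == target.unsqueeze(1)).any(dim=1).float().mean()
        print(k, acc.item())                           # 1/6, 3/6, 5/6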

@@ -103,12 +107,12 @@ _topk_target_mdmc = torch.tensor([[[1, 1, 0], [2, 2, 2], [3, 3, 3]], [[2, 2, 0],

@pytest.mark.parametrize(
    "preds, target, exp_result, k, subset_accuracy",
    [
        (_topk_preds_mc, _topk_target_mc, 1 / 6, 1, False),
        (_topk_preds_mc, _topk_target_mc, 3 / 6, 2, False),
        (_topk_preds_mc, _topk_target_mc, 5 / 6, 3, False),
        (_topk_preds_mc, _topk_target_mc, 1 / 6, 1, True),
        (_topk_preds_mc, _topk_target_mc, 3 / 6, 2, True),
        (_topk_preds_mc, _topk_target_mc, 5 / 6, 3, True),
        (_topk_preds_mcls, _topk_target_mcls, 1 / 6, 1, False),
        (_topk_preds_mcls, _topk_target_mcls, 3 / 6, 2, False),
        (_topk_preds_mcls, _topk_target_mcls, 5 / 6, 3, False),
        (_topk_preds_mcls, _topk_target_mcls, 1 / 6, 1, True),
        (_topk_preds_mcls, _topk_target_mcls, 3 / 6, 2, True),
        (_topk_preds_mcls, _topk_target_mcls, 5 / 6, 3, True),
        (_topk_preds_mdmc, _topk_target_mdmc, 1 / 6, 1, False),
        (_topk_preds_mdmc, _topk_target_mdmc, 8 / 18, 2, False),
        (_topk_preds_mdmc, _topk_target_mdmc, 13 / 18, 3, False),

@@ -138,14 +142,14 @@ def test_topk_accuracy(preds, target, exp_result, k, subset_accuracy):

@pytest.mark.parametrize(
    "preds, target",
    [
        (_binary_prob_inputs.preds, _binary_prob_inputs.target),
        (_binary_inputs.preds, _binary_inputs.target),
        (_multilabel_prob_inputs.preds, _multilabel_prob_inputs.target),
        (_multilabel_inputs.preds, _multilabel_inputs.target),
        (_multiclass_inputs.preds, _multiclass_inputs.target),
        (_multidim_multiclass_inputs.preds, _multidim_multiclass_inputs.target),
        (_multilabel_multidim_prob_inputs.preds, _multilabel_multidim_prob_inputs.target),
        (_multilabel_multidim_inputs.preds, _multilabel_multidim_inputs.target),
        (_input_binary_prob.preds, _input_binary_prob.target),
        (_input_binary.preds, _input_binary.target),
        (_input_mlb_prob.preds, _input_mlb_prob.target),
        (_input_mlb.preds, _input_mlb.target),
        (_input_mcls.preds, _input_mcls.target),
        (_input_mdmc.preds, _input_mdmc.target),
        (_input_mlmd_prob.preds, _input_mlmd_prob.target),
        (_input_mlmd.preds, _input_mlmd.target),
    ],
)
def test_topk_accuracy_wrong_input_types(preds, target):

@@ -160,7 +164,7 @@ def test_topk_accuracy_wrong_input_types(preds, target):

@pytest.mark.parametrize("top_k, threshold", [(0, 0.5), (None, 1.5)])
def test_wrong_params(top_k, threshold):
    preds, target = _multiclass_prob_inputs.preds, _multiclass_prob_inputs.target
    preds, target = _input_mcls_prob.preds, _input_mcls_prob.target

    with pytest.raises(ValueError):
        acc = Accuracy(threshold=threshold, top_k=top_k)


@@ -35,6 +35,7 @@ for i in range(4):

@pytest.mark.parametrize("x, y", _examples)
class TestAUC(MetricTester):

    @pytest.mark.parametrize("ddp", [False])
    @pytest.mark.parametrize("dist_sync_on_step", [True, False])
    def test_auc(self, x, y, ddp, dist_sync_on_step):

@@ -48,13 +49,7 @@ class TestAUC(MetricTester):

        )

    def test_auc_functional(self, x, y):
        self.run_functional_metric_test(
            x,
            y,
            metric_functional=auc,
            sk_metric=sk_auc,
            metric_args={"reorder": False}
        )
        self.run_functional_metric_test(x, y, metric_functional=auc, sk_metric=sk_auc, metric_args={"reorder": False})


@pytest.mark.parametrize(['x', 'y', 'expected'], [

@@ -7,25 +7,23 @@ from sklearn.metrics import roc_auc_score as sk_roc_auc_score

from pytorch_lightning.metrics.classification.auroc import AUROC
from pytorch_lightning.metrics.functional.auroc import auroc
from tests.metrics.classification.inputs import (
    _binary_prob_inputs,
    _multiclass_prob_inputs,
    _multidim_multiclass_prob_inputs,
    _multilabel_multidim_prob_inputs,
    _multilabel_prob_inputs,
)
from tests.metrics.classification.inputs import _input_binary_prob
from tests.metrics.classification.inputs import _input_multiclass_prob as _input_mcls_prob
from tests.metrics.classification.inputs import _input_multidim_multiclass_prob as _input_mdmc_prob
from tests.metrics.classification.inputs import _input_multilabel_multidim_prob as _input_mlmd_prob
from tests.metrics.classification.inputs import _input_multilabel_prob as _input_mlb_prob
from tests.metrics.utils import MetricTester, NUM_CLASSES

torch.manual_seed(42)


def _binary_prob_sk_metric(preds, target, num_classes, average='macro', max_fpr=None, multi_class='ovr'):
def _sk_auroc_binary_prob(preds, target, num_classes, average='macro', max_fpr=None, multi_class='ovr'):
    sk_preds = preds.view(-1).numpy()
    sk_target = target.view(-1).numpy()
    return sk_roc_auc_score(y_true=sk_target, y_score=sk_preds, average=average, max_fpr=max_fpr)


def _multiclass_prob_sk_metric(preds, target, num_classes, average='macro', max_fpr=None, multi_class='ovr'):
def _sk_auroc_multiclass_prob(preds, target, num_classes, average='macro', max_fpr=None, multi_class='ovr'):
    sk_preds = preds.reshape(-1, num_classes).numpy()
    sk_target = target.view(-1).numpy()
    return sk_roc_auc_score(

@@ -33,11 +31,11 @@ def _multiclass_prob_sk_metric(preds, target, num_classes, average='macro', max_

        y_score=sk_preds,
        average=average,
        max_fpr=max_fpr,
        multi_class=multi_class
        multi_class=multi_class,
    )


def _multidim_multiclass_prob_sk_metric(preds, target, num_classes, average='macro', max_fpr=None, multi_class='ovr'):
def _sk_auroc_multidim_multiclass_prob(preds, target, num_classes, average='macro', max_fpr=None, multi_class='ovr'):
    sk_preds = preds.transpose(0, 1).reshape(num_classes, -1).transpose(0, 1).numpy()
    sk_target = target.view(-1).numpy()
    return sk_roc_auc_score(

@@ -45,11 +43,11 @@ def _multidim_multiclass_prob_sk_metric(preds, target, num_classes, average='mac

        y_score=sk_preds,
        average=average,
        max_fpr=max_fpr,
        multi_class=multi_class
        multi_class=multi_class,
    )


def _multilabel_prob_sk_metric(preds, target, num_classes, average='macro', max_fpr=None, multi_class='ovr'):
def _sk_auroc_multilabel_prob(preds, target, num_classes, average='macro', max_fpr=None, multi_class='ovr'):
    sk_preds = preds.reshape(-1, num_classes).numpy()
    sk_target = target.reshape(-1, num_classes).numpy()
    return sk_roc_auc_score(

@@ -57,11 +55,11 @@ def _multilabel_prob_sk_metric(preds, target, num_classes, average='macro', max_

        y_score=sk_preds,
        average=average,
        max_fpr=max_fpr,
        multi_class=multi_class
        multi_class=multi_class,
    )


def _multilabel_multidim_prob_sk_metric(preds, target, num_classes, average='macro', max_fpr=None, multi_class='ovr'):
def _sk_auroc_multilabel_multidim_prob(preds, target, num_classes, average='macro', max_fpr=None, multi_class='ovr'):
    sk_preds = preds.transpose(0, 1).reshape(num_classes, -1).transpose(0, 1).numpy()
    sk_target = target.transpose(0, 1).reshape(num_classes, -1).transpose(0, 1).numpy()
    return sk_roc_auc_score(

@@ -69,40 +67,22 @@ def _multilabel_multidim_prob_sk_metric(preds, target, num_classes, average='mac

        y_score=sk_preds,
        average=average,
        max_fpr=max_fpr,
        multi_class=multi_class
        multi_class=multi_class,
    )
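
The transpose(0, 1).reshape(num_classes, -1).transpose(0, 1) dance in the multidim wrappers flattens every dimension except the class one, handing sklearn an (N, C) score matrix. A quick shape check of the same pattern on toy data (shapes invented for the example):

    import torch

    C = 3
    x = torch.rand(2, 4, C, 5)      # (num batches, batch size, num classes, extra dim)
    flat = x.transpose(0, 1).reshape(C, -1).transpose(0, 1)
    print(flat.shape)               # torch.Size([40, 3]): one row of class scores per sample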


@pytest.mark.parametrize("preds, target, sk_metric, num_classes", [
    (_binary_prob_inputs.preds, _binary_prob_inputs.target, _binary_prob_sk_metric, 1),
    (
        _multiclass_prob_inputs.preds,
        _multiclass_prob_inputs.target,
        _multiclass_prob_sk_metric,
        NUM_CLASSES
    ),
    (
        _multidim_multiclass_prob_inputs.preds,
        _multidim_multiclass_prob_inputs.target,
        _multidim_multiclass_prob_sk_metric,
        NUM_CLASSES
    ),
    (
        _multilabel_prob_inputs.preds,
        _multilabel_prob_inputs.target,
        _multilabel_prob_sk_metric,
        NUM_CLASSES
    ),
    (
        _multilabel_multidim_prob_inputs.preds,
        _multilabel_multidim_prob_inputs.target,
        _multilabel_multidim_prob_sk_metric,
        NUM_CLASSES
    )
])
@pytest.mark.parametrize(
    "preds, target, sk_metric, num_classes",
    [(_input_binary_prob.preds, _input_binary_prob.target, _sk_auroc_binary_prob, 1),
     (_input_mcls_prob.preds, _input_mcls_prob.target, _sk_auroc_multiclass_prob, NUM_CLASSES),
     (_input_mdmc_prob.preds, _input_mdmc_prob.target, _sk_auroc_multidim_multiclass_prob, NUM_CLASSES),
     (_input_mlb_prob.preds, _input_mlb_prob.target, _sk_auroc_multilabel_prob, NUM_CLASSES),
     (_input_mlmd_prob.preds, _input_mlmd_prob.target, _sk_auroc_multilabel_multidim_prob, NUM_CLASSES)]
)
@pytest.mark.parametrize("average", ['macro', 'weighted'])
@pytest.mark.parametrize("max_fpr", [None, 0.8, 0.5])
class TestAUROC(MetricTester):

    @pytest.mark.parametrize("ddp", [True, False])
    @pytest.mark.parametrize("dist_sync_on_step", [True, False])
    def test_auroc(self, preds, target, sk_metric, num_classes, average, max_fpr, ddp, dist_sync_on_step):

@@ -121,9 +101,11 @@ class TestAUROC(MetricTester):

            metric_class=AUROC,
            sk_metric=partial(sk_metric, num_classes=num_classes, average=average, max_fpr=max_fpr),
            dist_sync_on_step=dist_sync_on_step,
            metric_args={"num_classes": num_classes,
                         "average": average,
                         "max_fpr": max_fpr},
            metric_args={
                "num_classes": num_classes,
                "average": average,
                "max_fpr": max_fpr
            },
        )

    def test_auroc_functional(self, preds, target, sk_metric, num_classes, average, max_fpr):

@@ -140,9 +122,11 @@ class TestAUROC(MetricTester):

            target,
            metric_functional=auroc,
            sk_metric=partial(sk_metric, num_classes=num_classes, average=average, max_fpr=max_fpr),
            metric_args={"num_classes": num_classes,
                         "average": average,
                         "max_fpr": max_fpr},
            metric_args={
                "num_classes": num_classes,
                "average": average,
                "max_fpr": max_fpr
            },
        )


@@ -152,10 +136,7 @@ def test_error_on_different_mode():

    """
    metric = AUROC()
    # pass in multi-class data
    metric.update(torch.randn(10, 5).softmax(dim=-1), torch.randint(0, 5, (10,)))
    with pytest.raises(
        ValueError,
        match=r"The mode of data.* should be constant.*"
    ):
    metric.update(torch.randn(10, 5).softmax(dim=-1), torch.randint(0, 5, (10, )))
    with pytest.raises(ValueError, match=r"The mode of data.* should be constant.*"):
        # pass in multi-label data
        metric.update(torch.rand(10, 5), torch.randint(0, 2, (10,5)))
        metric.update(torch.rand(10, 5), torch.randint(0, 2, (10, 5)))

@@ -3,67 +3,59 @@ from functools import partial

import numpy as np
import pytest
import torch
from sklearn.metrics import average_precision_score as _sk_average_precision_score
from sklearn.metrics import average_precision_score as sk_average_precision_score

from pytorch_lightning.metrics.classification.average_precision import AveragePrecision
from pytorch_lightning.metrics.functional.average_precision import average_precision
from tests.metrics.classification.inputs import (
    _binary_prob_inputs,
    _multiclass_prob_inputs,
    _multidim_multiclass_prob_inputs,
)
from tests.metrics.classification.inputs import _input_binary_prob
from tests.metrics.classification.inputs import _input_multiclass_prob as _input_mcls_prob
from tests.metrics.classification.inputs import _input_multidim_multiclass_prob as _input_mdmc_prob
from tests.metrics.utils import MetricTester, NUM_CLASSES

torch.manual_seed(42)


def sk_average_precision_score(y_true, probas_pred, num_classes=1):
def _sk_average_precision_score(y_true, probas_pred, num_classes=1):
    if num_classes == 1:
        return _sk_average_precision_score(y_true, probas_pred)
        return sk_average_precision_score(y_true, probas_pred)

    res = []
    for i in range(num_classes):
        y_true_temp = np.zeros_like(y_true)
        y_true_temp[y_true == i] = 1
        res.append(_sk_average_precision_score(y_true_temp, probas_pred[:, i]))
        res.append(sk_average_precision_score(y_true_temp, probas_pred[:, i]))
    return res
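
The wrapper reduces multiclass average precision to a per-class one-vs-rest computation: binarize the targets for class i, then score column i of the probabilities. The same reduction spelled out on a toy input (values invented for the example):

    import numpy as np
    from sklearn.metrics import average_precision_score

    y_true = np.array([0, 1, 2, 1])
    scores = np.array([
        [0.7, 0.2, 0.1],
        [0.2, 0.6, 0.2],
        [0.1, 0.3, 0.6],
        [0.3, 0.5, 0.2],
    ])
    # one column of class scores per class, binarized targets per class
    per_class = [average_precision_score((y_true == i).astype(int), scores[:, i]) for i in range(3)]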


def _binary_prob_sk_metric(preds, target, num_classes=1):
def _sk_avg_prec_binary_prob(preds, target, num_classes=1):
    sk_preds = preds.view(-1).numpy()
    sk_target = target.view(-1).numpy()

    return sk_average_precision_score(y_true=sk_target, probas_pred=sk_preds, num_classes=num_classes)
    return _sk_average_precision_score(y_true=sk_target, probas_pred=sk_preds, num_classes=num_classes)


def _multiclass_prob_sk_metric(preds, target, num_classes=1):
def _sk_avg_prec_multiclass_prob(preds, target, num_classes=1):
    sk_preds = preds.reshape(-1, num_classes).numpy()
    sk_target = target.view(-1).numpy()

    return sk_average_precision_score(y_true=sk_target, probas_pred=sk_preds, num_classes=num_classes)
    return _sk_average_precision_score(y_true=sk_target, probas_pred=sk_preds, num_classes=num_classes)


def _multidim_multiclass_prob_sk_metric(preds, target, num_classes=1):
def _sk_avg_prec_multidim_multiclass_prob(preds, target, num_classes=1):
    sk_preds = preds.transpose(0, 1).reshape(num_classes, -1).transpose(0, 1).numpy()
    sk_target = target.view(-1).numpy()
    return sk_average_precision_score(y_true=sk_target, probas_pred=sk_preds, num_classes=num_classes)
    return _sk_average_precision_score(y_true=sk_target, probas_pred=sk_preds, num_classes=num_classes)


@pytest.mark.parametrize("preds, target, sk_metric, num_classes", [
    (_binary_prob_inputs.preds, _binary_prob_inputs.target, _binary_prob_sk_metric, 1),
    (
        _multiclass_prob_inputs.preds,
        _multiclass_prob_inputs.target,
        _multiclass_prob_sk_metric,
        NUM_CLASSES),
    (
        _multidim_multiclass_prob_inputs.preds,
        _multidim_multiclass_prob_inputs.target,
        _multidim_multiclass_prob_sk_metric,
        NUM_CLASSES
    ),
])
@pytest.mark.parametrize(
    "preds, target, sk_metric, num_classes", [
        (_input_binary_prob.preds, _input_binary_prob.target, _sk_avg_prec_binary_prob, 1),
        (_input_mcls_prob.preds, _input_mcls_prob.target, _sk_avg_prec_multiclass_prob, NUM_CLASSES),
        (_input_mdmc_prob.preds, _input_mdmc_prob.target, _sk_avg_prec_multidim_multiclass_prob, NUM_CLASSES),
    ]
)
class TestAveragePrecision(MetricTester):

    @pytest.mark.parametrize("ddp", [True, False])
    @pytest.mark.parametrize("dist_sync_on_step", [True, False])
    def test_average_precision(self, preds, target, sk_metric, num_classes, ddp, dist_sync_on_step):

@@ -87,16 +79,19 @@ class TestAveragePrecision(MetricTester):

    )


@pytest.mark.parametrize(['scores', 'target', 'expected_score'], [
    # Check the average_precision_score of a constant predictor is
    # the TPR
    # Generate a dataset with 25% of positives
    # And a constant score
    # The precision is then the fraction of positive whatever the recall
    # is, as there is only one threshold:
    pytest.param(torch.tensor([1, 1, 1, 1]), torch.tensor([0, 0, 0, 1]), .25),
    # With threshold 0.8 : 1 TP and 2 TN and one FN
    pytest.param(torch.tensor([.6, .7, .8, 9]), torch.tensor([1, 0, 0, 1]), .75),
])
@pytest.mark.parametrize(
    ['scores', 'target', 'expected_score'],
    [
        # Check the average_precision_score of a constant predictor is
        # the TPR
        # Generate a dataset with 25% of positives
        # And a constant score
        # The precision is then the fraction of positive whatever the recall
        # is, as there is only one threshold:
        pytest.param(torch.tensor([1, 1, 1, 1]), torch.tensor([0, 0, 0, 1]), .25),
        # With threshold 0.8 : 1 TP and 2 TN and one FN
        pytest.param(torch.tensor([.6, .7, .8, 9]), torch.tensor([1, 0, 0, 1]), .75),
    ]
)
def test_average_precision(scores, target, expected_score):
    assert average_precision(scores, target) == expected_score
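
The 0.75 case works out as follows. Sorting scores descending gives 9 (positive), .8 (negative), .7 (negative), .6 (positive). Average precision sums precision times the recall increment at each cut: the cut after 9 contributes 1.0 * 0.5 (precision 1/1, recall jumps 0 -> 0.5), the next two cuts add nothing (recall stays at 0.5), and the final cut contributes 0.5 * 0.5 (precision 2/4, recall jumps 0.5 -> 1), for 0.5 + 0.25 = 0.75.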


@@ -7,71 +7,68 @@ from sklearn.metrics import confusion_matrix as sk_confusion_matrix

from pytorch_lightning.metrics.classification.confusion_matrix import ConfusionMatrix
from pytorch_lightning.metrics.functional.confusion_matrix import confusion_matrix
from tests.metrics.classification.inputs import (
    _binary_inputs,
    _binary_prob_inputs,
    _multiclass_inputs,
    _multiclass_prob_inputs,
    _multidim_multiclass_inputs,
    _multidim_multiclass_prob_inputs,
    _multilabel_inputs,
    _multilabel_prob_inputs,
)
from tests.metrics.classification.inputs import _input_binary, _input_binary_prob
from tests.metrics.classification.inputs import _input_multiclass as _input_mcls
from tests.metrics.classification.inputs import _input_multiclass_prob as _input_mcls_prob
from tests.metrics.classification.inputs import _input_multidim_multiclass as _input_mdmc
from tests.metrics.classification.inputs import _input_multidim_multiclass_prob as _input_mdmc_prob
from tests.metrics.classification.inputs import _input_multilabel as _input_mlb
from tests.metrics.classification.inputs import _input_multilabel_prob as _input_mlb_prob
from tests.metrics.utils import MetricTester, NUM_CLASSES, THRESHOLD

torch.manual_seed(42)


def _binary_prob_sk_metric(preds, target, normalize=None):
def _sk_cm_binary_prob(preds, target, normalize=None):
    sk_preds = (preds.view(-1).numpy() >= THRESHOLD).astype(np.uint8)
    sk_target = target.view(-1).numpy()

    return sk_confusion_matrix(y_true=sk_target, y_pred=sk_preds, normalize=normalize)


def _binary_sk_metric(preds, target, normalize=None):
def _sk_cm_binary(preds, target, normalize=None):
    sk_preds = preds.view(-1).numpy()
    sk_target = target.view(-1).numpy()

    return sk_confusion_matrix(y_true=sk_target, y_pred=sk_preds, normalize=normalize)


def _multilabel_prob_sk_metric(preds, target, normalize=None):
def _sk_cm_multilabel_prob(preds, target, normalize=None):
    sk_preds = (preds.view(-1).numpy() >= THRESHOLD).astype(np.uint8)
    sk_target = target.view(-1).numpy()

    return sk_confusion_matrix(y_true=sk_target, y_pred=sk_preds, normalize=normalize)


def _multilabel_sk_metric(preds, target, normalize=None):
def _sk_cm_multilabel(preds, target, normalize=None):
    sk_preds = preds.view(-1).numpy()
    sk_target = target.view(-1).numpy()

    return sk_confusion_matrix(y_true=sk_target, y_pred=sk_preds, normalize=normalize)


def _multiclass_prob_sk_metric(preds, target, normalize=None):
def _sk_cm_multiclass_prob(preds, target, normalize=None):
    sk_preds = torch.argmax(preds, dim=len(preds.shape) - 1).view(-1).numpy()
    sk_target = target.view(-1).numpy()

    return sk_confusion_matrix(y_true=sk_target, y_pred=sk_preds, normalize=normalize)


def _multiclass_sk_metric(preds, target, normalize=None):
def _sk_cm_multiclass(preds, target, normalize=None):
    sk_preds = preds.view(-1).numpy()
    sk_target = target.view(-1).numpy()

    return sk_confusion_matrix(y_true=sk_target, y_pred=sk_preds, normalize=normalize)


def _multidim_multiclass_prob_sk_metric(preds, target, normalize=None):
def _sk_cm_multidim_multiclass_prob(preds, target, normalize=None):
    sk_preds = torch.argmax(preds, dim=len(preds.shape) - 2).view(-1).numpy()
    sk_target = target.view(-1).numpy()

    return sk_confusion_matrix(y_true=sk_target, y_pred=sk_preds, normalize=normalize)
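
The argmax dimension differs between the two prob wrappers because the class axis moves: multiclass preds are (num batches, batch size, num classes), so the class axis is last, while multidim multiclass preds are (num batches, batch size, num classes, extra dim), so it is second to last. A quick shape check (toy shapes):

    import torch

    mc = torch.rand(10, 16, 5)          # class axis last
    mdmc = torch.rand(10, 16, 5, 3)     # class axis second to last
    print(torch.argmax(mc, dim=mc.ndim - 1).shape)      # torch.Size([10, 16])
    print(torch.argmax(mdmc, dim=mdmc.ndim - 2).shape)  # torch.Size([10, 16, 3])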


def _multidim_multiclass_sk_metric(preds, target, normalize=None):
def _sk_cm_multidim_multiclass(preds, target, normalize=None):
    sk_preds = preds.view(-1).numpy()
    sk_target = target.view(-1).numpy()

@@ -79,55 +76,53 @@ def _multidim_multiclass_sk_metric(preds, target, normalize=None):


@pytest.mark.parametrize("normalize", ['true', 'pred', 'all', None])
@pytest.mark.parametrize("preds, target, sk_metric, num_classes", [
    (_binary_prob_inputs.preds, _binary_prob_inputs.target, _binary_prob_sk_metric, 2),
    (_binary_inputs.preds, _binary_inputs.target, _binary_sk_metric, 2),
    (_multilabel_prob_inputs.preds, _multilabel_prob_inputs.target, _multilabel_prob_sk_metric, 2),
    (_multilabel_inputs.preds, _multilabel_inputs.target, _multilabel_sk_metric, 2),
    (_multiclass_prob_inputs.preds, _multiclass_prob_inputs.target, _multiclass_prob_sk_metric, NUM_CLASSES),
    (_multiclass_inputs.preds, _multiclass_inputs.target, _multiclass_sk_metric, NUM_CLASSES),
    (
        _multidim_multiclass_prob_inputs.preds,
        _multidim_multiclass_prob_inputs.target,
        _multidim_multiclass_prob_sk_metric,
        NUM_CLASSES
    ),
    (
        _multidim_multiclass_inputs.preds,
        _multidim_multiclass_inputs.target,
        _multidim_multiclass_sk_metric,
        NUM_CLASSES
    )
])
@pytest.mark.parametrize(
    "preds, target, sk_metric, num_classes",
    [(_input_binary_prob.preds, _input_binary_prob.target, _sk_cm_binary_prob, 2),
     (_input_binary.preds, _input_binary.target, _sk_cm_binary, 2),
     (_input_mlb_prob.preds, _input_mlb_prob.target, _sk_cm_multilabel_prob, 2),
     (_input_mlb.preds, _input_mlb.target, _sk_cm_multilabel, 2),
     (_input_mcls_prob.preds, _input_mcls_prob.target, _sk_cm_multiclass_prob, NUM_CLASSES),
     (_input_mcls.preds, _input_mcls.target, _sk_cm_multiclass, NUM_CLASSES),
     (_input_mdmc_prob.preds, _input_mdmc_prob.target, _sk_cm_multidim_multiclass_prob, NUM_CLASSES),
     (_input_mdmc.preds, _input_mdmc.target, _sk_cm_multidim_multiclass, NUM_CLASSES)]
)
class TestConfusionMatrix(MetricTester):

    @pytest.mark.parametrize("ddp", [True, False])
    @pytest.mark.parametrize("dist_sync_on_step", [True, False])
    def test_confusion_matrix(self, normalize, preds, target, sk_metric, num_classes, ddp, dist_sync_on_step):
        self.run_class_metric_test(ddp=ddp,
                                   preds=preds,
                                   target=target,
                                   metric_class=ConfusionMatrix,
                                   sk_metric=partial(sk_metric, normalize=normalize),
                                   dist_sync_on_step=dist_sync_on_step,
                                   metric_args={"num_classes": num_classes,
                                                "threshold": THRESHOLD,
                                                "normalize": normalize}
                                   )
        self.run_class_metric_test(
            ddp=ddp,
            preds=preds,
            target=target,
            metric_class=ConfusionMatrix,
            sk_metric=partial(sk_metric, normalize=normalize),
            dist_sync_on_step=dist_sync_on_step,
            metric_args={
                "num_classes": num_classes,
                "threshold": THRESHOLD,
                "normalize": normalize
            }
        )

    def test_confusion_matrix_functional(self, normalize, preds, target, sk_metric, num_classes):
        self.run_functional_metric_test(preds,
                                        target,
                                        metric_functional=confusion_matrix,
                                        sk_metric=partial(sk_metric, normalize=normalize),
                                        metric_args={"num_classes": num_classes,
                                                     "threshold": THRESHOLD,
                                                     "normalize": normalize}
                                        )
        self.run_functional_metric_test(
            preds,
            target,
            metric_functional=confusion_matrix,
            sk_metric=partial(sk_metric, normalize=normalize),
            metric_args={
                "num_classes": num_classes,
                "threshold": THRESHOLD,
                "normalize": normalize
            }
        )


def test_warning_on_nan(tmpdir):
    preds = torch.randint(3, size=(20,))
    target = torch.randint(3, size=(20,))
    preds = torch.randint(3, size=(20, ))
    target = torch.randint(3, size=(20, ))

    with pytest.warns(UserWarning, match='.* nan values found in confusion matrix have been replaced with zeros.'):
        confusion_matrix(preds, target, num_classes=5, normalize='true')
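
Why this warns: with normalize='true' each row of the matrix is divided by that row's true-class count, and since preds/target only cover classes 0-2 while num_classes=5, rows 3 and 4 are 0/0. A hand-made illustration of the same effect:

    import torch

    cm = torch.tensor([[2., 1., 0.], [0., 3., 0.], [0., 0., 0.]])
    normalized = cm / cm.sum(dim=1, keepdim=True)   # last row becomes nan (0 / 0)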


@@ -7,17 +7,14 @@ from sklearn.metrics import fbeta_score

from pytorch_lightning.metrics import F1, FBeta
from pytorch_lightning.metrics.functional import f1, fbeta
from tests.metrics.classification.inputs import (
    _binary_inputs,
    _binary_prob_inputs,
    _multiclass_inputs,
    _multiclass_prob_inputs,
    _multidim_multiclass_inputs,
    _multidim_multiclass_prob_inputs,
    _multilabel_inputs,
    _multilabel_inputs_no_match,
    _multilabel_prob_inputs,
)
from tests.metrics.classification.inputs import _input_binary, _input_binary_prob
from tests.metrics.classification.inputs import _input_multiclass as _input_mcls
from tests.metrics.classification.inputs import _input_multiclass_prob as _input_mcls_prob
from tests.metrics.classification.inputs import _input_multidim_multiclass as _input_mdmc
from tests.metrics.classification.inputs import _input_multidim_multiclass_prob as _input_mdmc_prob
from tests.metrics.classification.inputs import _input_multilabel as _input_mlb
from tests.metrics.classification.inputs import _input_multilabel_no_match as _input_mlb_nomatch
from tests.metrics.classification.inputs import _input_multilabel_prob as _mlb_prob_inputs
from tests.metrics.utils import MetricTester, NUM_CLASSES, THRESHOLD

torch.manual_seed(42)

@@ -82,28 +79,24 @@ def _sk_fbeta_multidim_multiclass(preds, target, average='micro', beta=1.0):

@pytest.mark.parametrize(
    "preds, target, sk_metric, num_classes, multilabel",
    [
        (_binary_prob_inputs.preds, _binary_prob_inputs.target, _sk_fbeta_binary_prob, 1, False),
        (_binary_inputs.preds, _binary_inputs.target, _sk_fbeta_binary, 1, False),
        (_multilabel_prob_inputs.preds, _multilabel_prob_inputs.target, _sk_fbeta_multilabel_prob, NUM_CLASSES, True),
        (_multilabel_inputs.preds, _multilabel_inputs.target, _sk_fbeta_multilabel, NUM_CLASSES, True),
        (_multilabel_inputs_no_match.preds, _multilabel_inputs_no_match.target,
         _sk_fbeta_multilabel, NUM_CLASSES, True),
        (_multiclass_prob_inputs.preds, _multiclass_prob_inputs.target, _sk_fbeta_multiclass_prob, NUM_CLASSES, False),
        (_multiclass_inputs.preds, _multiclass_inputs.target, _sk_fbeta_multiclass, NUM_CLASSES, False),
        (_multidim_multiclass_prob_inputs.preds, _multidim_multiclass_prob_inputs.target,
         _sk_fbeta_multidim_multiclass_prob, NUM_CLASSES, False),
        (_multidim_multiclass_inputs.preds, _multidim_multiclass_inputs.target,
         _sk_fbeta_multidim_multiclass, NUM_CLASSES, False),
        (_input_binary_prob.preds, _input_binary_prob.target, _sk_fbeta_binary_prob, 1, False),
        (_input_binary.preds, _input_binary.target, _sk_fbeta_binary, 1, False),
        (_mlb_prob_inputs.preds, _mlb_prob_inputs.target, _sk_fbeta_multilabel_prob, NUM_CLASSES, True),
        (_input_mlb.preds, _input_mlb.target, _sk_fbeta_multilabel, NUM_CLASSES, True),
        (_input_mlb_nomatch.preds, _input_mlb_nomatch.target, _sk_fbeta_multilabel, NUM_CLASSES, True),
        (_input_mcls_prob.preds, _input_mcls_prob.target, _sk_fbeta_multiclass_prob, NUM_CLASSES, False),
        (_input_mcls.preds, _input_mcls.target, _sk_fbeta_multiclass, NUM_CLASSES, False),
        (_input_mdmc_prob.preds, _input_mdmc_prob.target, _sk_fbeta_multidim_multiclass_prob, NUM_CLASSES, False),
        (_input_mdmc.preds, _input_mdmc.target, _sk_fbeta_multidim_multiclass, NUM_CLASSES, False),
    ],
)
@pytest.mark.parametrize("average", ['micro', 'macro', 'weighted', None])
@pytest.mark.parametrize("beta", [0.5, 1.0, 2.0])
class TestFBeta(MetricTester):

    @pytest.mark.parametrize("ddp", [True, False])
    @pytest.mark.parametrize("dist_sync_on_step", [True, False])
    def test_fbeta(
        self, preds, target, sk_metric, num_classes, multilabel, average, beta, ddp, dist_sync_on_step
    ):
    def test_fbeta(self, preds, target, sk_metric, num_classes, multilabel, average, beta, ddp, dist_sync_on_step):
        metric_class = F1 if beta == 1.0 else partial(FBeta, beta=beta)
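
(The F1/FBeta switch works because F1 is just the beta = 1 special case of F_beta = (1 + beta^2) * precision * recall / (beta^2 * precision + recall), i.e. the harmonic mean of precision and recall.)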

        self.run_class_metric_test(

@@ -123,21 +116,21 @@ class TestFBeta(MetricTester):

            check_batch=False,
        )

    def test_fbeta_functional(
        self, preds, target, sk_metric, num_classes, multilabel, average, beta
    ):
    def test_fbeta_functional(self, preds, target, sk_metric, num_classes, multilabel, average, beta):
        metric_functional = f1 if beta == 1.0 else partial(fbeta, beta=beta)

        self.run_functional_metric_test(preds=preds,
                                        target=target,
                                        metric_functional=metric_functional,
                                        sk_metric=partial(sk_metric, average=average, beta=beta),
                                        metric_args={
                                            "num_classes": num_classes,
                                            "average": average,
                                            "multilabel": multilabel,
                                            "threshold": THRESHOLD}
                                        )
        self.run_functional_metric_test(
            preds=preds,
            target=target,
            metric_functional=metric_functional,
            sk_metric=partial(sk_metric, average=average, beta=beta),
            metric_args={
                "num_classes": num_classes,
                "average": average,
                "multilabel": multilabel,
                "threshold": THRESHOLD
            }
        )


@pytest.mark.parametrize(['pred', 'target', 'beta', 'exp_score'], [

@@ -5,18 +5,15 @@ from sklearn.metrics import hamming_loss as sk_hamming_loss

from pytorch_lightning.metrics import HammingDistance
from pytorch_lightning.metrics.classification.helpers import _input_format_classification
from pytorch_lightning.metrics.functional import hamming_distance
from tests.metrics.classification.inputs import (
    _binary_inputs,
    _binary_prob_inputs,
    _multiclass_inputs,
    _multiclass_prob_inputs,
    _multidim_multiclass_inputs,
    _multidim_multiclass_prob_inputs,
    _multilabel_inputs,
    _multilabel_multidim_inputs,
    _multilabel_multidim_prob_inputs,
    _multilabel_prob_inputs,
)
from tests.metrics.classification.inputs import _input_binary, _input_binary_prob
from tests.metrics.classification.inputs import _input_multiclass as _input_mcls
from tests.metrics.classification.inputs import _input_multiclass_prob as _input_mcls_prob
from tests.metrics.classification.inputs import _input_multidim_multiclass as _input_mdmc
from tests.metrics.classification.inputs import _input_multidim_multiclass_prob as _input_mdmc_prob
from tests.metrics.classification.inputs import _input_multilabel as _input_mlb
from tests.metrics.classification.inputs import _input_multilabel_multidim as _input_mlmd
from tests.metrics.classification.inputs import _input_multilabel_multidim_prob as _input_mlmd_prob
from tests.metrics.classification.inputs import _input_multilabel_prob as _input_mlb_prob
from tests.metrics.utils import MetricTester, THRESHOLD

torch.manual_seed(42)

@@ -33,19 +30,20 @@ def _sk_hamming_loss(preds, target):

@pytest.mark.parametrize(
    "preds, target",
    [
        (_binary_prob_inputs.preds, _binary_prob_inputs.target),
        (_binary_inputs.preds, _binary_inputs.target),
        (_multilabel_prob_inputs.preds, _multilabel_prob_inputs.target),
        (_multilabel_inputs.preds, _multilabel_inputs.target),
        (_multiclass_prob_inputs.preds, _multiclass_prob_inputs.target),
        (_multiclass_inputs.preds, _multiclass_inputs.target),
        (_multidim_multiclass_prob_inputs.preds, _multidim_multiclass_prob_inputs.target),
        (_multidim_multiclass_inputs.preds, _multidim_multiclass_inputs.target),
        (_multilabel_multidim_prob_inputs.preds, _multilabel_multidim_prob_inputs.target),
        (_multilabel_multidim_inputs.preds, _multilabel_multidim_inputs.target),
        (_input_binary_prob.preds, _input_binary_prob.target),
        (_input_binary.preds, _input_binary.target),
        (_input_mlb_prob.preds, _input_mlb_prob.target),
        (_input_mlb.preds, _input_mlb.target),
        (_input_mcls_prob.preds, _input_mcls_prob.target),
        (_input_mcls.preds, _input_mcls.target),
        (_input_mdmc_prob.preds, _input_mdmc_prob.target),
        (_input_mdmc.preds, _input_mdmc.target),
        (_input_mlmd_prob.preds, _input_mlmd_prob.target),
        (_input_mlmd.preds, _input_mlmd.target),
    ],
)
class TestHammingDistance(MetricTester):

    @pytest.mark.parametrize("ddp", [True, False])
    @pytest.mark.parametrize("dist_sync_on_step", [False, True])
    def test_hamming_distance_class(self, ddp, dist_sync_on_step, preds, target):

@@ -71,7 +69,7 @@ class TestHammingDistance(MetricTester):


@pytest.mark.parametrize("threshold", [1.5])
def test_wrong_params(threshold):
    preds, target = _multiclass_prob_inputs.preds, _multiclass_prob_inputs.target
    preds, target = _input_mcls_prob.preds, _input_mcls_prob.target

    with pytest.raises(ValueError):
        ham_dist = HammingDistance(threshold=threshold)


@@ -4,16 +4,16 @@ from torch import rand, randint

from pytorch_lightning.metrics.classification.helpers import _input_format_classification, DataType
from pytorch_lightning.metrics.utils import select_topk, to_onehot
from tests.metrics.classification.inputs import _binary_inputs as _bin
from tests.metrics.classification.inputs import _binary_prob_inputs as _bin_prob
from tests.metrics.classification.inputs import _multiclass_inputs as _mc
from tests.metrics.classification.inputs import _multiclass_prob_inputs as _mc_prob
from tests.metrics.classification.inputs import _multidim_multiclass_inputs as _mdmc
from tests.metrics.classification.inputs import _multidim_multiclass_prob_inputs as _mdmc_prob
from tests.metrics.classification.inputs import _multilabel_inputs as _ml
from tests.metrics.classification.inputs import _multilabel_multidim_inputs as _mlmd
from tests.metrics.classification.inputs import _multilabel_multidim_prob_inputs as _mlmd_prob
from tests.metrics.classification.inputs import _multilabel_prob_inputs as _ml_prob
from tests.metrics.classification.inputs import _input_binary as _bin
from tests.metrics.classification.inputs import _input_binary_prob as _bin_prob
from tests.metrics.classification.inputs import _input_multiclass as _mc
from tests.metrics.classification.inputs import _input_multiclass_prob as _mc_prob
from tests.metrics.classification.inputs import _input_multidim_multiclass as _mdmc
from tests.metrics.classification.inputs import _input_multidim_multiclass_prob as _mdmc_prob
from tests.metrics.classification.inputs import _input_multilabel as _ml
from tests.metrics.classification.inputs import _input_multilabel_multidim as _mlmd
from tests.metrics.classification.inputs import _input_multilabel_multidim_prob as _mlmd_prob
from tests.metrics.classification.inputs import _input_multilabel_prob as _ml_prob
from tests.metrics.classification.inputs import Input
from tests.metrics.utils import BATCH_SIZE, EXTRA_DIM, NUM_BATCHES, NUM_CLASSES, THRESHOLD

@@ -155,6 +155,7 @@ def _mlmd_prob_to_mc_preds_tr(x):

    ],
)
def test_usual_cases(inputs, num_classes, is_multiclass, top_k, exp_mode, post_preds, post_target):

    def __get_data_type_enum(str_exp_mode):
        return next(DataType[n] for n in dir(DataType) if DataType[n] == str_exp_mode)

@@ -204,7 +205,7 @@ def test_threshold():


@pytest.mark.parametrize("threshold", [-0.5, 0.0, 1.0, 1.5])
def test_incorrect_threshold(threshold):
    preds, target = rand(size=(7,)), randint(high=2, size=(7,))
    preds, target = rand(size=(7, )), randint(high=2, size=(7, ))
    with pytest.raises(ValueError):
        _input_format_classification(preds, target, threshold=threshold)

@@ -213,21 +214,21 @@ def test_incorrect_threshold(threshold):

    "preds, target, num_classes, is_multiclass",
    [
        # Target not integer
        (randint(high=2, size=(7,)), randint(high=2, size=(7,)).float(), None, None),
        (randint(high=2, size=(7, )), randint(high=2, size=(7, )).float(), None, None),
        # Target negative
        (randint(high=2, size=(7,)), -randint(high=2, size=(7,)), None, None),
        (randint(high=2, size=(7, )), -randint(high=2, size=(7, )), None, None),
        # Preds negative integers
        (-randint(high=2, size=(7,)), randint(high=2, size=(7,)), None, None),
        (-randint(high=2, size=(7, )), randint(high=2, size=(7, )), None, None),
        # Negative probabilities
        (-rand(size=(7,)), randint(high=2, size=(7,)), None, None),
        (-rand(size=(7, )), randint(high=2, size=(7, )), None, None),
        # is_multiclass=False and target > 1
        (rand(size=(7,)), randint(low=2, high=4, size=(7,)), None, False),
        (rand(size=(7, )), randint(low=2, high=4, size=(7, )), None, False),
        # is_multiclass=False and preds integers with > 1
        (randint(low=2, high=4, size=(7,)), randint(high=2, size=(7,)), None, False),
        (randint(low=2, high=4, size=(7, )), randint(high=2, size=(7, )), None, False),
        # Wrong batch size
        (randint(high=2, size=(8,)), randint(high=2, size=(7,)), None, None),
        (randint(high=2, size=(8, )), randint(high=2, size=(7, )), None, None),
        # Completely wrong shape
        (randint(high=2, size=(7,)), randint(high=2, size=(7, 4)), None, None),
        (randint(high=2, size=(7, )), randint(high=2, size=(7, 4)), None, None),
        # Same #dims, different shape
        (randint(high=2, size=(7, 3)), randint(high=2, size=(7, 4)), None, None),
        # Same shape and preds floats, target not binary

@@ -237,11 +238,11 @@ def test_incorrect_threshold(threshold):

        # #dims in preds = 1 + #dims in target, preds not float
        (randint(high=2, size=(7, 3, 3, 4)), randint(high=4, size=(7, 3, 3)), None, None),
        # is_multiclass=False, with C dimension > 2
        (_mc_prob.preds[0], randint(high=2, size=(BATCH_SIZE,)), None, False),
        (_mc_prob.preds[0], randint(high=2, size=(BATCH_SIZE, )), None, False),
        # Probs of multiclass preds do not sum up to 1
        (rand(size=(7, 3, 5)), randint(high=2, size=(7, 5)), None, None),
        # Max target larger or equal to C dimension
        (_mc_prob.preds[0], randint(low=NUM_CLASSES + 1, high=100, size=(BATCH_SIZE,)), None, None),
        (_mc_prob.preds[0], randint(low=NUM_CLASSES + 1, high=100, size=(BATCH_SIZE, )), None, None),
        # C dimension not equal to num_classes
        (_mc_prob.preds[0], _mc_prob.target[0], NUM_CLASSES + 1, None),
        # Max target larger than num_classes (with #dim preds = 1 + #dims target)

@@ -251,7 +252,7 @@ def test_incorrect_threshold(threshold):

        # Max preds larger than num_classes (with #dim preds = #dims target)
        (randint(low=5, high=7, size=(7, 3)), randint(high=4, size=(7, 3)), 4, None),
        # Num_classes=1, but is_multiclass not false
        (randint(high=2, size=(7,)), randint(high=2, size=(7,)), 1, None),
        (randint(high=2, size=(7, )), randint(high=2, size=(7, )), 1, None),
        # is_multiclass=False, but implied class dimension (for multi-label, from shape) != num_classes
        (randint(high=2, size=(7, 3, 3)), randint(high=2, size=(7, 3, 3)), 4, False),
        # Multilabel input with implied class dimension != num_classes

@@ -259,12 +260,12 @@ def test_incorrect_threshold(threshold):

        # Multilabel input with is_multiclass=True, but num_classes != 2 (or None)
        (rand(size=(7, 3)), randint(high=2, size=(7, 3)), 4, True),
        # Binary input, num_classes > 2
        (rand(size=(7,)), randint(high=2, size=(7,)), 4, None),
        (rand(size=(7, )), randint(high=2, size=(7, )), 4, None),
        # Binary input, num_classes == 2 and is_multiclass not True
        (rand(size=(7,)), randint(high=2, size=(7,)), 2, None),
        (rand(size=(7,)), randint(high=2, size=(7,)), 2, False),
        (rand(size=(7, )), randint(high=2, size=(7, )), 2, None),
        (rand(size=(7, )), randint(high=2, size=(7, )), 2, False),
        # Binary input, num_classes == 1 and is_multiclass=True
        (rand(size=(7,)), randint(high=2, size=(7,)), 1, True),
        (rand(size=(7, )), randint(high=2, size=(7, )), 1, True),
    ],
)
def test_incorrect_inputs(preds, target, num_classes, is_multiclass):

@@ -7,16 +7,13 @@ from sklearn.metrics import jaccard_score as sk_jaccard_score

from pytorch_lightning.metrics.classification.iou import IoU
from pytorch_lightning.metrics.functional.iou import iou
from tests.metrics.classification.inputs import (
    _binary_inputs,
    _binary_prob_inputs,
    _multiclass_inputs,
    _multiclass_prob_inputs,
    _multidim_multiclass_inputs,
    _multidim_multiclass_prob_inputs,
    _multilabel_inputs,
    _multilabel_prob_inputs,
)
from tests.metrics.classification.inputs import _input_binary, _input_binary_prob
from tests.metrics.classification.inputs import _input_multiclass as _input_mcls
from tests.metrics.classification.inputs import _input_multiclass_prob as _input_mcls_prob
from tests.metrics.classification.inputs import _input_multidim_multiclass as _input_mdmc
from tests.metrics.classification.inputs import _input_multidim_multiclass_prob as _input_mdmc_prob
from tests.metrics.classification.inputs import _input_multilabel as _input_mlb
from tests.metrics.classification.inputs import _input_multilabel_prob as _input_mlb_prob
from tests.metrics.utils import MetricTester, NUM_CLASSES, THRESHOLD


@@ -77,52 +74,50 @@ def _sk_iou_multidim_multiclass(preds, target, average=None):


@pytest.mark.parametrize("reduction", ['elementwise_mean', 'none'])
@pytest.mark.parametrize("preds, target, sk_metric, num_classes", [
    (_binary_prob_inputs.preds, _binary_prob_inputs.target, _sk_iou_binary_prob, 2),
    (_binary_inputs.preds, _binary_inputs.target, _sk_iou_binary, 2),
    (_multilabel_prob_inputs.preds, _multilabel_prob_inputs.target, _sk_iou_multilabel_prob, 2),
    (_multilabel_inputs.preds, _multilabel_inputs.target, _sk_iou_multilabel, 2),
    (_multiclass_prob_inputs.preds, _multiclass_prob_inputs.target, _sk_iou_multiclass_prob, NUM_CLASSES),
    (_multiclass_inputs.preds, _multiclass_inputs.target, _sk_iou_multiclass, NUM_CLASSES),
    (
        _multidim_multiclass_prob_inputs.preds,
        _multidim_multiclass_prob_inputs.target,
        _sk_iou_multidim_multiclass_prob,
        NUM_CLASSES
    ),
    (
        _multidim_multiclass_inputs.preds,
        _multidim_multiclass_inputs.target,
        _sk_iou_multidim_multiclass,
        NUM_CLASSES
    )
])
@pytest.mark.parametrize(
    "preds, target, sk_metric, num_classes",
    [(_input_binary_prob.preds, _input_binary_prob.target, _sk_iou_binary_prob, 2),
     (_input_binary.preds, _input_binary.target, _sk_iou_binary, 2),
     (_input_mlb_prob.preds, _input_mlb_prob.target, _sk_iou_multilabel_prob, 2),
     (_input_mlb.preds, _input_mlb.target, _sk_iou_multilabel, 2),
     (_input_mcls_prob.preds, _input_mcls_prob.target, _sk_iou_multiclass_prob, NUM_CLASSES),
     (_input_mcls.preds, _input_mcls.target, _sk_iou_multiclass, NUM_CLASSES),
     (_input_mdmc_prob.preds, _input_mdmc_prob.target, _sk_iou_multidim_multiclass_prob, NUM_CLASSES),
     (_input_mdmc.preds, _input_mdmc.target, _sk_iou_multidim_multiclass, NUM_CLASSES)]
)
class TestIoU(MetricTester):

    @pytest.mark.parametrize("ddp", [True, False])
    @pytest.mark.parametrize("dist_sync_on_step", [True, False])
    def test_confusion_matrix(self, reduction, preds, target, sk_metric, num_classes, ddp, dist_sync_on_step):
        average = 'macro' if reduction == 'elementwise_mean' else None  # convert tags
        self.run_class_metric_test(ddp=ddp,
                                   preds=preds,
                                   target=target,
                                   metric_class=IoU,
                                   sk_metric=partial(sk_metric, average=average),
                                   dist_sync_on_step=dist_sync_on_step,
                                   metric_args={"num_classes": num_classes,
                                                "threshold": THRESHOLD,
                                                "reduction": reduction}
                                   )
        self.run_class_metric_test(
            ddp=ddp,
            preds=preds,
            target=target,
            metric_class=IoU,
            sk_metric=partial(sk_metric, average=average),
            dist_sync_on_step=dist_sync_on_step,
            metric_args={
                "num_classes": num_classes,
                "threshold": THRESHOLD,
                "reduction": reduction
            }
        )

    def test_confusion_matrix_functional(self, reduction, preds, target, sk_metric, num_classes):
        average = 'macro' if reduction == 'elementwise_mean' else None  # convert tags
        self.run_functional_metric_test(preds,
                                        target,
                                        metric_functional=iou,
                                        sk_metric=partial(sk_metric, average=average),
                                        metric_args={"num_classes": num_classes,
                                                     "threshold": THRESHOLD,
                                                     "reduction": reduction}
                                        )
        self.run_functional_metric_test(
            preds,
            target,
            metric_functional=iou,
            sk_metric=partial(sk_metric, average=average),
            metric_args={
                "num_classes": num_classes,
                "threshold": THRESHOLD,
                "reduction": reduction
            }
        )


@pytest.mark.parametrize(['half_ones', 'reduction', 'ignore_index', 'expected'], [

@@ -148,35 +143,38 @@ def test_iou(half_ones, reduction, ignore_index, expected):


# test `absent_score`
@pytest.mark.parametrize(['pred', 'target', 'ignore_index', 'absent_score', 'num_classes', 'expected'], [
    # Note that -1 is used as the absent_score in almost all tests here to distinguish it from the range of valid
    # scores the function can return ([0., 1.] range, inclusive).
    # 2 classes, class 0 is correct everywhere, class 1 is absent.
    pytest.param([0], [0], None, -1., 2, [1., -1.]),
    pytest.param([0, 0], [0, 0], None, -1., 2, [1., -1.]),
    # absent_score not applied if only class 0 is present and it's the only class.
    pytest.param([0], [0], None, -1., 1, [1.]),
    # 2 classes, class 1 is correct everywhere, class 0 is absent.
    pytest.param([1], [1], None, -1., 2, [-1., 1.]),
    pytest.param([1, 1], [1, 1], None, -1., 2, [-1., 1.]),
    # When 0 index ignored, class 0 does not get a score (not even the absent_score).
    pytest.param([1], [1], 0, -1., 2, [1.0]),
    # 3 classes. Only 0 and 2 are present, and are perfectly predicted. 1 should get absent_score.
    pytest.param([0, 2], [0, 2], None, -1., 3, [1., -1., 1.]),
    pytest.param([2, 0], [2, 0], None, -1., 3, [1., -1., 1.]),
    # 3 classes. Only 0 and 1 are present, and are perfectly predicted. 2 should get absent_score.
    pytest.param([0, 1], [0, 1], None, -1., 3, [1., 1., -1.]),
    pytest.param([1, 0], [1, 0], None, -1., 3, [1., 1., -1.]),
    # 3 classes, class 0 is 0.5 IoU, class 1 is 0 IoU (in pred but not target; should not get absent_score), class
    # 2 is absent.
    pytest.param([0, 1], [0, 0], None, -1., 3, [0.5, 0., -1.]),
    # 3 classes, class 0 is 0.5 IoU, class 1 is 0 IoU (in target but not pred; should not get absent_score), class
    # 2 is absent.
    pytest.param([0, 0], [0, 1], None, -1., 3, [0.5, 0., -1.]),
    # Sanity checks with absent_score of 1.0.
    pytest.param([0, 2], [0, 2], None, 1.0, 3, [1., 1., 1.]),
    pytest.param([0, 2], [0, 2], 0, 1.0, 3, [1., 1.]),
])
@pytest.mark.parametrize(
    ['pred', 'target', 'ignore_index', 'absent_score', 'num_classes', 'expected'],
    [
        # Note that -1 is used as the absent_score in almost all tests here to distinguish it from the range of valid
        # scores the function can return ([0., 1.] range, inclusive).
        # 2 classes, class 0 is correct everywhere, class 1 is absent.
        pytest.param([0], [0], None, -1., 2, [1., -1.]),
        pytest.param([0, 0], [0, 0], None, -1., 2, [1., -1.]),
        # absent_score not applied if only class 0 is present and it's the only class.
        pytest.param([0], [0], None, -1., 1, [1.]),
        # 2 classes, class 1 is correct everywhere, class 0 is absent.
        pytest.param([1], [1], None, -1., 2, [-1., 1.]),
        pytest.param([1, 1], [1, 1], None, -1., 2, [-1., 1.]),
        # When 0 index ignored, class 0 does not get a score (not even the absent_score).
        pytest.param([1], [1], 0, -1., 2, [1.0]),
        # 3 classes. Only 0 and 2 are present, and are perfectly predicted. 1 should get absent_score.
|
||||
pytest.param([0, 2], [0, 2], None, -1., 3, [1., -1., 1.]),
|
||||
pytest.param([2, 0], [2, 0], None, -1., 3, [1., -1., 1.]),
|
||||
# 3 classes. Only 0 and 1 are present, and are perfectly predicted. 2 should get absent_score.
|
||||
pytest.param([0, 1], [0, 1], None, -1., 3, [1., 1., -1.]),
|
||||
pytest.param([1, 0], [1, 0], None, -1., 3, [1., 1., -1.]),
|
||||
# 3 classes, class 0 is 0.5 IoU, class 1 is 0 IoU (in pred but not target; should not get absent_score), class
|
||||
# 2 is absent.
|
||||
pytest.param([0, 1], [0, 0], None, -1., 3, [0.5, 0., -1.]),
|
||||
# 3 classes, class 0 is 0.5 IoU, class 1 is 0 IoU (in target but not pred; should not get absent_score), class
|
||||
# 2 is absent.
|
||||
pytest.param([0, 0], [0, 1], None, -1., 3, [0.5, 0., -1.]),
|
||||
# Sanity checks with absent_score of 1.0.
|
||||
pytest.param([0, 2], [0, 2], None, 1.0, 3, [1., 1., 1.]),
|
||||
pytest.param([0, 2], [0, 2], 0, 1.0, 3, [1., 1.]),
|
||||
]
|
||||
)
|
||||
def test_iou_absent_score(pred, target, ignore_index, absent_score, num_classes, expected):
|
||||
iou_val = iou(
|
||||
pred=torch.tensor(pred),
|
||||
|
@ -191,19 +189,22 @@ def test_iou_absent_score(pred, target, ignore_index, absent_score, num_classes,
|
|||
|
||||
# example data taken from
|
||||
# https://github.com/scikit-learn/scikit-learn/blob/master/sklearn/metrics/tests/test_ranking.py
|
||||
@pytest.mark.parametrize(['pred', 'target', 'ignore_index', 'num_classes', 'reduction', 'expected'], [
|
||||
# Ignoring an index outside of [0, num_classes-1] should have no effect.
|
||||
pytest.param([0, 1, 1, 2, 2], [0, 1, 2, 2, 2], None, 3, 'none', [1, 1 / 2, 2 / 3]),
|
||||
pytest.param([0, 1, 1, 2, 2], [0, 1, 2, 2, 2], -1, 3, 'none', [1, 1 / 2, 2 / 3]),
|
||||
pytest.param([0, 1, 1, 2, 2], [0, 1, 2, 2, 2], 255, 3, 'none', [1, 1 / 2, 2 / 3]),
|
||||
# Ignoring a valid index drops only that index from the result.
|
||||
pytest.param([0, 1, 1, 2, 2], [0, 1, 2, 2, 2], 0, 3, 'none', [1 / 2, 2 / 3]),
|
||||
pytest.param([0, 1, 1, 2, 2], [0, 1, 2, 2, 2], 1, 3, 'none', [1, 2 / 3]),
|
||||
pytest.param([0, 1, 1, 2, 2], [0, 1, 2, 2, 2], 2, 3, 'none', [1, 1 / 2]),
|
||||
# When reducing to mean or sum, the ignored index does not contribute to the output.
|
||||
pytest.param([0, 1, 1, 2, 2], [0, 1, 2, 2, 2], 0, 3, 'elementwise_mean', [7 / 12]),
|
||||
pytest.param([0, 1, 1, 2, 2], [0, 1, 2, 2, 2], 0, 3, 'sum', [7 / 6]),
|
||||
])
|
||||
@pytest.mark.parametrize(
|
||||
['pred', 'target', 'ignore_index', 'num_classes', 'reduction', 'expected'],
|
||||
[
|
||||
# Ignoring an index outside of [0, num_classes-1] should have no effect.
|
||||
pytest.param([0, 1, 1, 2, 2], [0, 1, 2, 2, 2], None, 3, 'none', [1, 1 / 2, 2 / 3]),
|
||||
pytest.param([0, 1, 1, 2, 2], [0, 1, 2, 2, 2], -1, 3, 'none', [1, 1 / 2, 2 / 3]),
|
||||
pytest.param([0, 1, 1, 2, 2], [0, 1, 2, 2, 2], 255, 3, 'none', [1, 1 / 2, 2 / 3]),
|
||||
# Ignoring a valid index drops only that index from the result.
|
||||
pytest.param([0, 1, 1, 2, 2], [0, 1, 2, 2, 2], 0, 3, 'none', [1 / 2, 2 / 3]),
|
||||
pytest.param([0, 1, 1, 2, 2], [0, 1, 2, 2, 2], 1, 3, 'none', [1, 2 / 3]),
|
||||
pytest.param([0, 1, 1, 2, 2], [0, 1, 2, 2, 2], 2, 3, 'none', [1, 1 / 2]),
|
||||
# When reducing to mean or sum, the ignored index does not contribute to the output.
|
||||
pytest.param([0, 1, 1, 2, 2], [0, 1, 2, 2, 2], 0, 3, 'elementwise_mean', [7 / 12]),
|
||||
pytest.param([0, 1, 1, 2, 2], [0, 1, 2, 2, 2], 0, 3, 'sum', [7 / 6]),
|
||||
]
|
||||
)
|
||||
def test_iou_ignore_index(pred, target, ignore_index, num_classes, reduction, expected):
|
||||
iou_val = iou(
|
||||
pred=torch.tensor(pred),
|
||||
|
|
|
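A note on the test layout above: the two stacked `parametrize` decorators generate the full cross-product, so `TestIoU` runs every (preds, target, sk_metric, num_classes) row once per reduction. A minimal, self-contained sketch of that pytest mechanic (the names here are illustrative only, not from the test suite):

import pytest

@pytest.mark.parametrize("reduction", ['elementwise_mean', 'none'])
@pytest.mark.parametrize("value", [1, 2, 3])
def test_cross_product(value, reduction):
    # pytest collects one test per (value, reduction) combination: 3 x 2 = 6.
    assert reduction in ('elementwise_mean', 'none')
    assert value > 0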
@ -9,12 +9,13 @@ from sklearn.metrics import precision_score, recall_score

from pytorch_lightning.metrics import Metric, Precision, Recall
from pytorch_lightning.metrics.classification.helpers import _input_format_classification
from pytorch_lightning.metrics.functional import precision, precision_recall, recall
from tests.metrics.classification.inputs import _binary_inputs, _binary_prob_inputs, _multiclass_inputs
from tests.metrics.classification.inputs import _multiclass_prob_inputs as _mc_prob
from tests.metrics.classification.inputs import _multidim_multiclass_inputs as _mdmc
from tests.metrics.classification.inputs import _multidim_multiclass_prob_inputs as _mdmc_prob
from tests.metrics.classification.inputs import _multilabel_inputs as _ml
from tests.metrics.classification.inputs import _multilabel_prob_inputs as _ml_prob
from tests.metrics.classification.inputs import _input_binary, _input_binary_prob
from tests.metrics.classification.inputs import _input_multiclass as _input_mcls
from tests.metrics.classification.inputs import _input_multiclass_prob as _input_mcls_prob
from tests.metrics.classification.inputs import _input_multidim_multiclass as _input_mdmc
from tests.metrics.classification.inputs import _input_multidim_multiclass_prob as _input_mdmc_prob
from tests.metrics.classification.inputs import _input_multilabel as _input_mlb
from tests.metrics.classification.inputs import _input_multilabel_prob as _input_mlb_prob
from tests.metrics.utils import MetricTester, NUM_CLASSES, THRESHOLD

torch.manual_seed(42)

@ -45,7 +46,9 @@ def _sk_prec_recall(preds, target, sk_fn, num_classes, average, is_multiclass, i

    return sk_scores

def _sk_prec_recall_mdmc(preds, target, sk_fn, num_classes, average, is_multiclass, ignore_index, mdmc_average):
def _sk_prec_recall_multidim_multiclass(
    preds, target, sk_fn, num_classes, average, is_multiclass, ignore_index, mdmc_average
):
    preds, target, _ = _input_format_classification(
        preds, target, threshold=THRESHOLD, num_classes=num_classes, is_multiclass=is_multiclass
    )

@ -89,8 +92,8 @@ def test_wrong_params(metric, fn_metric, average, mdmc_average, num_classes, ign

    with pytest.raises(ValueError, match=match_str):
        fn_metric(
            _binary_inputs.preds[0],
            _binary_inputs.target[0],
            _input_binary.preds[0],
            _input_binary.target[0],
            average=average,
            mdmc_average=mdmc_average,
            num_classes=num_classes,

@ -99,8 +102,8 @@ def test_wrong_params(metric, fn_metric, average, mdmc_average, num_classes, ign

    with pytest.raises(ValueError, match=match_str):
        precision_recall(
            _binary_inputs.preds[0],
            _binary_inputs.target[0],
            _input_binary.preds[0],
            _input_binary.target[0],
            average=average,
            mdmc_average=mdmc_average,
            num_classes=num_classes,

@ -156,19 +159,26 @@ def test_no_support(metric_class, metric_fn):

@pytest.mark.parametrize(
    "preds, target, num_classes, is_multiclass, mdmc_average, sk_wrapper",
    [
        (_binary_prob_inputs.preds, _binary_prob_inputs.target, 1, None, None, _sk_prec_recall),
        (_binary_inputs.preds, _binary_inputs.target, 1, False, None, _sk_prec_recall),
        (_ml_prob.preds, _ml_prob.target, NUM_CLASSES, None, None, _sk_prec_recall),
        (_ml.preds, _ml.target, NUM_CLASSES, False, None, _sk_prec_recall),
        (_mc_prob.preds, _mc_prob.target, NUM_CLASSES, None, None, _sk_prec_recall),
        (_multiclass_inputs.preds, _multiclass_inputs.target, NUM_CLASSES, None, None, _sk_prec_recall),
        (_mdmc.preds, _mdmc.target, NUM_CLASSES, None, "global", _sk_prec_recall_mdmc),
        (_mdmc_prob.preds, _mdmc_prob.target, NUM_CLASSES, None, "global", _sk_prec_recall_mdmc),
        (_mdmc.preds, _mdmc.target, NUM_CLASSES, None, "samplewise", _sk_prec_recall_mdmc),
        (_mdmc_prob.preds, _mdmc_prob.target, NUM_CLASSES, None, "samplewise", _sk_prec_recall_mdmc),
        (_input_binary_prob.preds, _input_binary_prob.target, 1, None, None, _sk_prec_recall),
        (_input_binary.preds, _input_binary.target, 1, False, None, _sk_prec_recall),
        (_input_mlb_prob.preds, _input_mlb_prob.target, NUM_CLASSES, None, None, _sk_prec_recall),
        (_input_mlb.preds, _input_mlb.target, NUM_CLASSES, False, None, _sk_prec_recall),
        (_input_mcls_prob.preds, _input_mcls_prob.target, NUM_CLASSES, None, None, _sk_prec_recall),
        (_input_mcls.preds, _input_mcls.target, NUM_CLASSES, None, None, _sk_prec_recall),
        (_input_mdmc.preds, _input_mdmc.target, NUM_CLASSES, None, "global", _sk_prec_recall_multidim_multiclass),
        (
            _input_mdmc_prob.preds, _input_mdmc_prob.target, NUM_CLASSES, None, "global",
            _sk_prec_recall_multidim_multiclass
        ),
        (_input_mdmc.preds, _input_mdmc.target, NUM_CLASSES, None, "samplewise", _sk_prec_recall_multidim_multiclass),
        (
            _input_mdmc_prob.preds, _input_mdmc_prob.target, NUM_CLASSES, None, "samplewise",
            _sk_prec_recall_multidim_multiclass
        ),
    ],
)
class TestPrecisionRecall(MetricTester):

    @pytest.mark.parametrize("ddp", [False])
    @pytest.mark.parametrize("dist_sync_on_step", [True, False])
    def test_precision_recall_class(

@ -278,11 +288,15 @@ def test_precision_recall_joint(average):

    which are already tested thoroughly.
    """

    precision_result = precision(_mc_prob.preds[0], _mc_prob.target[0], average=average, num_classes=NUM_CLASSES)
    recall_result = recall(_mc_prob.preds[0], _mc_prob.target[0], average=average, num_classes=NUM_CLASSES)
    precision_result = precision(
        _input_mcls_prob.preds[0], _input_mcls_prob.target[0], average=average, num_classes=NUM_CLASSES
    )
    recall_result = recall(
        _input_mcls_prob.preds[0], _input_mcls_prob.target[0], average=average, num_classes=NUM_CLASSES
    )

    prec_recall_result = precision_recall(
        _mc_prob.preds[0], _mc_prob.target[0], average=average, num_classes=NUM_CLASSES
        _input_mcls_prob.preds[0], _input_mcls_prob.target[0], average=average, num_classes=NUM_CLASSES
    )

    assert torch.equal(precision_result, prec_recall_result[0])
@ -3,71 +3,63 @@ from functools import partial

import numpy as np
import pytest
import torch
from sklearn.metrics import precision_recall_curve as _sk_precision_recall_curve
from sklearn.metrics import precision_recall_curve as sk_precision_recall_curve

from pytorch_lightning.metrics.classification.precision_recall_curve import PrecisionRecallCurve
from pytorch_lightning.metrics.functional.precision_recall_curve import precision_recall_curve
from tests.metrics.classification.inputs import (
    _binary_prob_inputs,
    _multiclass_prob_inputs,
    _multidim_multiclass_prob_inputs,
)
from tests.metrics.classification.inputs import _input_binary_prob
from tests.metrics.classification.inputs import _input_multiclass_prob as _input_mcls_prob
from tests.metrics.classification.inputs import _input_multidim_multiclass_prob as _input_mdmc_prob
from tests.metrics.utils import MetricTester, NUM_CLASSES

torch.manual_seed(42)

def sk_precision_recall_curve(y_true, probas_pred, num_classes=1):
def _sk_precision_recall_curve(y_true, probas_pred, num_classes=1):
    """ Adjusted comparison function that can also handle multiclass """
    if num_classes == 1:
        return _sk_precision_recall_curve(y_true, probas_pred)
        return sk_precision_recall_curve(y_true, probas_pred)

    precision, recall, thresholds = [], [], []
    for i in range(num_classes):
        y_true_temp = np.zeros_like(y_true)
        y_true_temp[y_true == i] = 1
        res = _sk_precision_recall_curve(y_true_temp, probas_pred[:, i])
        res = sk_precision_recall_curve(y_true_temp, probas_pred[:, i])
        precision.append(res[0])
        recall.append(res[1])
        thresholds.append(res[2])
    return precision, recall, thresholds

def _binary_prob_sk_metric(preds, target, num_classes=1):
def _sk_prec_rc_binary_prob(preds, target, num_classes=1):
    sk_preds = preds.view(-1).numpy()
    sk_target = target.view(-1).numpy()

    return sk_precision_recall_curve(y_true=sk_target, probas_pred=sk_preds, num_classes=num_classes)
    return _sk_precision_recall_curve(y_true=sk_target, probas_pred=sk_preds, num_classes=num_classes)

def _multiclass_prob_sk_metric(preds, target, num_classes=1):
def _sk_prec_rc_multiclass_prob(preds, target, num_classes=1):
    sk_preds = preds.reshape(-1, num_classes).numpy()
    sk_target = target.view(-1).numpy()

    return sk_precision_recall_curve(y_true=sk_target, probas_pred=sk_preds, num_classes=num_classes)
    return _sk_precision_recall_curve(y_true=sk_target, probas_pred=sk_preds, num_classes=num_classes)

def _multidim_multiclass_prob_sk_metric(preds, target, num_classes=1):
def _sk_prec_rc_multidim_multiclass_prob(preds, target, num_classes=1):
    sk_preds = preds.transpose(0, 1).reshape(num_classes, -1).transpose(0, 1).numpy()
    sk_target = target.view(-1).numpy()
    return sk_precision_recall_curve(y_true=sk_target, probas_pred=sk_preds, num_classes=num_classes)
    return _sk_precision_recall_curve(y_true=sk_target, probas_pred=sk_preds, num_classes=num_classes)

@pytest.mark.parametrize("preds, target, sk_metric, num_classes", [
    (_binary_prob_inputs.preds, _binary_prob_inputs.target, _binary_prob_sk_metric, 1),
    (
        _multiclass_prob_inputs.preds,
        _multiclass_prob_inputs.target,
        _multiclass_prob_sk_metric,
        NUM_CLASSES),
    (
        _multidim_multiclass_prob_inputs.preds,
        _multidim_multiclass_prob_inputs.target,
        _multidim_multiclass_prob_sk_metric,
        NUM_CLASSES
    ),
])
@pytest.mark.parametrize(
    "preds, target, sk_metric, num_classes", [
        (_input_binary_prob.preds, _input_binary_prob.target, _sk_prec_rc_binary_prob, 1),
        (_input_mcls_prob.preds, _input_mcls_prob.target, _sk_prec_rc_multiclass_prob, NUM_CLASSES),
        (_input_mdmc_prob.preds, _input_mdmc_prob.target, _sk_prec_rc_multidim_multiclass_prob, NUM_CLASSES),
    ]
)
class TestPrecisionRecallCurve(MetricTester):

    @pytest.mark.parametrize("ddp", [True, False])
    @pytest.mark.parametrize("dist_sync_on_step", [True, False])
    def test_precision_recall_curve(self, preds, target, sk_metric, num_classes, ddp, dist_sync_on_step):

@ -91,9 +83,10 @@ class TestPrecisionRecallCurve(MetricTester):

    )

@pytest.mark.parametrize(['pred', 'target', 'expected_p', 'expected_r', 'expected_t'], [
    pytest.param([1, 2, 3, 4], [1, 0, 0, 1], [0.5, 1 / 3, 0.5, 1., 1.], [1, 0.5, 0.5, 0.5, 0.], [1, 2, 3, 4])
])
@pytest.mark.parametrize(
    ['pred', 'target', 'expected_p', 'expected_r', 'expected_t'],
    [pytest.param([1, 2, 3, 4], [1, 0, 0, 1], [0.5, 1 / 3, 0.5, 1., 1.], [1, 0.5, 0.5, 0.5, 0.], [1, 2, 3, 4])]
)
def test_pr_curve(pred, target, expected_p, expected_r, expected_t):
    p, r, t = precision_recall_curve(torch.tensor(pred), torch.tensor(target))
    assert p.size() == r.size()
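The multiclass branch of `_sk_precision_recall_curve` above reduces the problem to one binary curve per class: the targets are binarized one-vs-rest and scored with that class's probability column. A self-contained sketch of the same pattern against plain scikit-learn (the data here is made up for illustration):

import numpy as np
from sklearn.metrics import precision_recall_curve

y_true = np.array([0, 1, 2, 2, 1])
probas = np.array([
    [0.7, 0.2, 0.1],
    [0.2, 0.6, 0.2],
    [0.1, 0.2, 0.7],
    [0.3, 0.3, 0.4],
    [0.5, 0.4, 0.1],
])

curves = []
for i in range(probas.shape[1]):
    # One-vs-rest: class i becomes the positive class, everything else negative.
    curves.append(precision_recall_curve((y_true == i).astype(int), probas[:, i]))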
@ -3,71 +3,63 @@ from functools import partial

import numpy as np
import pytest
import torch
from sklearn.metrics import roc_curve as _sk_roc_curve
from sklearn.metrics import roc_curve as sk_roc_curve

from pytorch_lightning.metrics.classification.roc import ROC
from pytorch_lightning.metrics.functional.roc import roc
from tests.metrics.classification.inputs import (
    _binary_prob_inputs,
    _multiclass_prob_inputs,
    _multidim_multiclass_prob_inputs,
)
from tests.metrics.classification.inputs import _input_binary_prob
from tests.metrics.classification.inputs import _input_multiclass_prob as _input_mcls_prob
from tests.metrics.classification.inputs import _input_multidim_multiclass_prob as _input_mdmc_prob
from tests.metrics.utils import MetricTester, NUM_CLASSES

torch.manual_seed(42)

def sk_roc_curve(y_true, probas_pred, num_classes=1):
def _sk_roc_curve(y_true, probas_pred, num_classes=1):
    """ Adjusted comparison function that can also handle multiclass """
    if num_classes == 1:
        return _sk_roc_curve(y_true, probas_pred, drop_intermediate=False)
        return sk_roc_curve(y_true, probas_pred, drop_intermediate=False)

    fpr, tpr, thresholds = [], [], []
    for i in range(num_classes):
        y_true_temp = np.zeros_like(y_true)
        y_true_temp[y_true == i] = 1
        res = _sk_roc_curve(y_true_temp, probas_pred[:, i], drop_intermediate=False)
        res = sk_roc_curve(y_true_temp, probas_pred[:, i], drop_intermediate=False)
        fpr.append(res[0])
        tpr.append(res[1])
        thresholds.append(res[2])
    return fpr, tpr, thresholds

def _binary_prob_sk_metric(preds, target, num_classes=1):
def _sk_roc_binary_prob(preds, target, num_classes=1):
    sk_preds = preds.view(-1).numpy()
    sk_target = target.view(-1).numpy()

    return sk_roc_curve(y_true=sk_target, probas_pred=sk_preds, num_classes=num_classes)
    return _sk_roc_curve(y_true=sk_target, probas_pred=sk_preds, num_classes=num_classes)

def _multiclass_prob_sk_metric(preds, target, num_classes=1):
def _sk_roc_multiclass_prob(preds, target, num_classes=1):
    sk_preds = preds.reshape(-1, num_classes).numpy()
    sk_target = target.view(-1).numpy()

    return sk_roc_curve(y_true=sk_target, probas_pred=sk_preds, num_classes=num_classes)
    return _sk_roc_curve(y_true=sk_target, probas_pred=sk_preds, num_classes=num_classes)

def _multidim_multiclass_prob_sk_metric(preds, target, num_classes=1):
def _sk_roc_multidim_multiclass_prob(preds, target, num_classes=1):
    sk_preds = preds.transpose(0, 1).reshape(num_classes, -1).transpose(0, 1).numpy()
    sk_target = target.view(-1).numpy()
    return sk_roc_curve(y_true=sk_target, probas_pred=sk_preds, num_classes=num_classes)
    return _sk_roc_curve(y_true=sk_target, probas_pred=sk_preds, num_classes=num_classes)

@pytest.mark.parametrize("preds, target, sk_metric, num_classes", [
    (_binary_prob_inputs.preds, _binary_prob_inputs.target, _binary_prob_sk_metric, 1),
    (
        _multiclass_prob_inputs.preds,
        _multiclass_prob_inputs.target,
        _multiclass_prob_sk_metric,
        NUM_CLASSES),
    (
        _multidim_multiclass_prob_inputs.preds,
        _multidim_multiclass_prob_inputs.target,
        _multidim_multiclass_prob_sk_metric,
        NUM_CLASSES
    ),
])
@pytest.mark.parametrize(
    "preds, target, sk_metric, num_classes", [
        (_input_binary_prob.preds, _input_binary_prob.target, _sk_roc_binary_prob, 1),
        (_input_mcls_prob.preds, _input_mcls_prob.target, _sk_roc_multiclass_prob, NUM_CLASSES),
        (_input_mdmc_prob.preds, _input_mdmc_prob.target, _sk_roc_multidim_multiclass_prob, NUM_CLASSES),
    ]
)
class TestROC(MetricTester):

    @pytest.mark.parametrize("ddp", [True, False])
    @pytest.mark.parametrize("dist_sync_on_step", [True, False])
    def test_roc(self, preds, target, sk_metric, num_classes, ddp, dist_sync_on_step):
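One detail worth calling out: the ROC wrapper passes `drop_intermediate=False` on every call. By default scikit-learn prunes thresholds that do not change the shape of the curve, which would make a point-for-point comparison against the Lightning output fragile; disabling the pruning keeps every candidate threshold. For example:

import numpy as np
from sklearn.metrics import roc_curve

y_true = np.array([0, 0, 1, 1])
y_score = np.array([0.1, 0.4, 0.35, 0.8])

# Keep every candidate threshold rather than sklearn's pruned subset.
fpr, tpr, thresholds = roc_curve(y_true, y_score, drop_intermediate=False)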
@ -9,12 +9,12 @@ from sklearn.metrics import multilabel_confusion_matrix

from pytorch_lightning.metrics import StatScores
from pytorch_lightning.metrics.classification.helpers import _input_format_classification
from pytorch_lightning.metrics.functional import stat_scores
from tests.metrics.classification.inputs import _binary_inputs, _binary_prob_inputs, _multiclass_inputs
from tests.metrics.classification.inputs import _multiclass_prob_inputs as _mc_prob
from tests.metrics.classification.inputs import _multidim_multiclass_inputs as _mdmc
from tests.metrics.classification.inputs import _multidim_multiclass_prob_inputs as _mdmc_prob
from tests.metrics.classification.inputs import _multilabel_inputs
from tests.metrics.classification.inputs import _multilabel_prob_inputs as _ml_prob
from tests.metrics.classification.inputs import _input_binary, _input_binary_prob, _input_multiclass
from tests.metrics.classification.inputs import _input_multiclass_prob as _input_mccls_prob
from tests.metrics.classification.inputs import _input_multidim_multiclass as _input_mdmc
from tests.metrics.classification.inputs import _input_multidim_multiclass_prob as _input_mdmc_prob
from tests.metrics.classification.inputs import _input_multilabel as _input_mcls
from tests.metrics.classification.inputs import _input_multilabel_prob as _input_mlb_prob
from tests.metrics.utils import MetricTester, NUM_CLASSES, THRESHOLD

torch.manual_seed(42)

@ -57,7 +57,7 @@ def _sk_stat_scores(preds, target, reduce, num_classes, is_multiclass, ignore_in

    return sk_stats

def _sk_stat_scores_mdmc(preds, target, reduce, mdmc_reduce, num_classes, is_multiclass, ignore_index, top_k):
def _sk_stat_scores_mdim_mcls(preds, target, reduce, mdmc_reduce, num_classes, is_multiclass, ignore_index, top_k):
    preds, target, _ = _input_format_classification(
        preds, target, threshold=THRESHOLD, num_classes=num_classes, is_multiclass=is_multiclass, top_k=top_k
    )

@ -83,13 +83,13 @@ def _sk_stat_scores_mdmc(preds, target, reduce, mdmc_reduce, num_classes, is_mul

@pytest.mark.parametrize(
    "reduce, mdmc_reduce, num_classes, inputs, ignore_index",
    [
        ["unknown", None, None, _binary_inputs, None],
        ["micro", "unknown", None, _binary_inputs, None],
        ["macro", None, None, _binary_inputs, None],
        ["micro", None, None, _mdmc_prob, None],
        ["micro", None, None, _binary_prob_inputs, 0],
        ["micro", None, None, _mc_prob, NUM_CLASSES],
        ["micro", None, NUM_CLASSES, _mc_prob, NUM_CLASSES],
        ["unknown", None, None, _input_binary, None],
        ["micro", "unknown", None, _input_binary, None],
        ["macro", None, None, _input_binary, None],
        ["micro", None, None, _input_mdmc_prob, None],
        ["micro", None, None, _input_binary_prob, 0],
        ["micro", None, None, _input_mccls_prob, NUM_CLASSES],
        ["micro", None, NUM_CLASSES, _input_mccls_prob, NUM_CLASSES],
    ],
)
def test_wrong_params(reduce, mdmc_reduce, num_classes, inputs, ignore_index):

@ -120,18 +120,21 @@ def test_wrong_threshold():

@pytest.mark.parametrize(
    "preds, target, sk_fn, mdmc_reduce, num_classes, is_multiclass, top_k",
    [
        (_binary_prob_inputs.preds, _binary_prob_inputs.target, _sk_stat_scores, None, 1, None, None),
        (_binary_inputs.preds, _binary_inputs.target, _sk_stat_scores, None, 1, False, None),
        (_ml_prob.preds, _ml_prob.target, _sk_stat_scores, None, NUM_CLASSES, None, None),
        (_ml_prob.preds, _ml_prob.target, _sk_stat_scores, None, NUM_CLASSES, None, 2),
        (_multilabel_inputs.preds, _multilabel_inputs.target, _sk_stat_scores, None, NUM_CLASSES, False, None),
        (_mc_prob.preds, _mc_prob.target, _sk_stat_scores, None, NUM_CLASSES, None, None),
        (_mc_prob.preds, _mc_prob.target, _sk_stat_scores, None, NUM_CLASSES, None, 2),
        (_multiclass_inputs.preds, _multiclass_inputs.target, _sk_stat_scores, None, NUM_CLASSES, None, None),
        (_mdmc.preds, _mdmc.target, _sk_stat_scores_mdmc, "samplewise", NUM_CLASSES, None, None),
        (_mdmc_prob.preds, _mdmc_prob.target, _sk_stat_scores_mdmc, "samplewise", NUM_CLASSES, None, None),
        (_mdmc.preds, _mdmc.target, _sk_stat_scores_mdmc, "global", NUM_CLASSES, None, None),
        (_mdmc_prob.preds, _mdmc_prob.target, _sk_stat_scores_mdmc, "global", NUM_CLASSES, None, None),
        (_input_binary_prob.preds, _input_binary_prob.target, _sk_stat_scores, None, 1, None, None),
        (_input_binary.preds, _input_binary.target, _sk_stat_scores, None, 1, False, None),
        (_input_mlb_prob.preds, _input_mlb_prob.target, _sk_stat_scores, None, NUM_CLASSES, None, None),
        (_input_mlb_prob.preds, _input_mlb_prob.target, _sk_stat_scores, None, NUM_CLASSES, None, 2),
        (_input_mcls.preds, _input_mcls.target, _sk_stat_scores, None, NUM_CLASSES, False, None),
        (_input_mccls_prob.preds, _input_mccls_prob.target, _sk_stat_scores, None, NUM_CLASSES, None, None),
        (_input_mccls_prob.preds, _input_mccls_prob.target, _sk_stat_scores, None, NUM_CLASSES, None, 2),
        (_input_multiclass.preds, _input_multiclass.target, _sk_stat_scores, None, NUM_CLASSES, None, None),
        (_input_mdmc.preds, _input_mdmc.target, _sk_stat_scores_mdim_mcls, "samplewise", NUM_CLASSES, None, None),
        (
            _input_mdmc_prob.preds, _input_mdmc_prob.target, _sk_stat_scores_mdim_mcls, "samplewise", NUM_CLASSES, None,
            None
        ),
        (_input_mdmc.preds, _input_mdmc.target, _sk_stat_scores_mdim_mcls, "global", NUM_CLASSES, None, None),
        (_input_mdmc_prob.preds, _input_mdmc_prob.target, _sk_stat_scores_mdim_mcls, "global", NUM_CLASSES, None, None),
    ],
)
class TestStatScores(MetricTester):
@ -63,7 +63,7 @@ def test_binary_clf_curve(sample_weight, pos_label, exp_shape):

    # if you fix the array inside the function, you'd also have to fix the shape,
    # because when the array changes, you also have to fix the shape
    seed_everything(0)
    pred = torch.randint(low=51, high=99, size=(100,), dtype=torch.float) / 100
    pred = torch.randint(low=51, high=99, size=(100, ), dtype=torch.float) / 100
    target = torch.tensor([0, 1] * 50, dtype=torch.int)
    if sample_weight is not None:
        sample_weight = torch.ones_like(pred) * sample_weight

@ -73,9 +73,9 @@ def test_binary_clf_curve(sample_weight, pos_label, exp_shape):

    assert isinstance(tps, torch.Tensor)
    assert isinstance(fps, torch.Tensor)
    assert isinstance(thresh, torch.Tensor)
    assert tps.shape == (exp_shape,)
    assert fps.shape == (exp_shape,)
    assert thresh.shape == (exp_shape,)
    assert tps.shape == (exp_shape, )
    assert fps.shape == (exp_shape, )
    assert thresh.shape == (exp_shape, )

@pytest.mark.parametrize(['pred', 'target', 'expected'], [
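Most of the remaining hunks in this commit are the same cosmetic change: yapf, as configured for this repository, writes one-element tuples as `(100, )` with a space before the closing parenthesis. The two spellings are identical at runtime:

import torch

# Whitespace inside a tuple literal is purely cosmetic.
assert (100, ) == (100,)

# Likewise for shape arguments: both build a 1-D tensor of length 100.
assert torch.zeros((100, )).shape == torch.zeros((100,)).shape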
@ -46,19 +46,19 @@ def test_multi_batch_image_gradients():

    image = torch.stack([single_channel_img for _ in range(BATCH_SIZE)], dim=0)

    true_dy = [
        [5., 5., 5., 5., 5., ],
        [5., 5., 5., 5., 5., ],
        [5., 5., 5., 5., 5., ],
        [5., 5., 5., 5., 5., ],
        [0., 0., 0., 0., 0., ]
        [5., 5., 5., 5., 5.],
        [5., 5., 5., 5., 5.],
        [5., 5., 5., 5., 5.],
        [5., 5., 5., 5., 5.],
        [0., 0., 0., 0., 0.],
    ]

    true_dx = [
        [1., 1., 1., 1., 0., ],
        [1., 1., 1., 1., 0., ],
        [1., 1., 1., 1., 0., ],
        [1., 1., 1., 1., 0., ],
        [1., 1., 1., 1., 0., ]
        [1., 1., 1., 1., 0.],
        [1., 1., 1., 1., 0.],
        [1., 1., 1., 1., 0.],
        [1., 1., 1., 1., 0.],
        [1., 1., 1., 1., 0.],
    ]
    true_dy = torch.Tensor(true_dy)
    true_dx = torch.Tensor(true_dx)

@ -85,19 +85,19 @@ def test_image_gradients():

    image = torch.reshape(image, (BATCH_SIZE, CHANNELS, HEIGHT, WIDTH))

    true_dy = [
        [5., 5., 5., 5., 5., ],
        [5., 5., 5., 5., 5., ],
        [5., 5., 5., 5., 5., ],
        [5., 5., 5., 5., 5., ],
        [0., 0., 0., 0., 0., ]
        [5., 5., 5., 5., 5.],
        [5., 5., 5., 5., 5.],
        [5., 5., 5., 5., 5.],
        [5., 5., 5., 5., 5.],
        [0., 0., 0., 0., 0.],
    ]

    true_dx = [
        [1., 1., 1., 1., 0., ],
        [1., 1., 1., 1., 0., ],
        [1., 1., 1., 1., 0., ],
        [1., 1., 1., 1., 0., ],
        [1., 1., 1., 1., 0., ]
        [1., 1., 1., 1., 0.],
        [1., 1., 1., 1., 0.],
        [1., 1., 1., 1., 0.],
        [1., 1., 1., 1., 0.],
        [1., 1., 1., 1., 0.],
    ]

    true_dy = torch.Tensor(true_dy)
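The expected tables encode the forward-difference convention of `image_gradients`: dy subtracts each row from the one below it, dx subtracts each column from the one to its right, and the last row/column is zero-padded. A short sketch that reproduces the tables, assuming the test image is `torch.arange(25).reshape(5, 5)` (an assumption; the fixture itself is outside this diff):

import torch

img = torch.arange(25, dtype=torch.float).reshape(5, 5)

dy = torch.zeros_like(img)
dx = torch.zeros_like(img)
dy[:-1, :] = img[1:, :] - img[:-1, :]  # 5 everywhere, 0 in the last row
dx[:, :-1] = img[:, 1:] - img[:, :-1]  # 1 everywhere, 0 in the last column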
@ -15,7 +15,6 @@ REFERENCE2 = tuple(

)
REFERENCE3 = tuple("It is the practical guide for the army always to heed the directions of the party".split())

# example taken from
# https://www.nltk.org/api/nltk.translate.html?highlight=bleu%20score#nltk.translate.bleu_score.corpus_bleu
HYP1 = "It is a guide to action which ensures that the military always obeys the commands of the party".split()

@ -44,7 +43,10 @@ smooth_func = SmoothingFunction().method2

)
def test_bleu_score(weights, n_gram, smooth_func, smooth):
    nltk_output = sentence_bleu(
        [REFERENCE1, REFERENCE2, REFERENCE3], HYPOTHESIS1, weights=weights, smoothing_function=smooth_func
        [REFERENCE1, REFERENCE2, REFERENCE3],
        HYPOTHESIS1,
        weights=weights,
        smoothing_function=smooth_func,
    )
    pl_output = bleu_score([HYPOTHESIS1], [[REFERENCE1, REFERENCE2, REFERENCE3]], n_gram=n_gram, smooth=smooth)
    assert torch.allclose(pl_output, torch.tensor(nltk_output))
@ -16,15 +16,13 @@ def test_reduce():

def test_class_reduce():
    num = torch.randint(1, 10, (100,)).float()
    denom = torch.randint(10, 20, (100,)).float()
    weights = torch.randint(1, 100, (100,)).float()
    num = torch.randint(1, 10, (100, )).float()
    denom = torch.randint(10, 20, (100, )).float()
    weights = torch.randint(1, 100, (100, )).float()

    assert torch.allclose(class_reduce(num, denom, weights, 'micro'),
                          torch.sum(num) / torch.sum(denom))
    assert torch.allclose(class_reduce(num, denom, weights, 'macro'),
                          torch.mean(num / denom))
    assert torch.allclose(class_reduce(num, denom, weights, 'weighted'),
                          torch.sum(num / denom * (weights / torch.sum(weights))))
    assert torch.allclose(class_reduce(num, denom, weights, 'none'),
                          num / denom)
    assert torch.allclose(class_reduce(num, denom, weights, 'micro'), torch.sum(num) / torch.sum(denom))
    assert torch.allclose(class_reduce(num, denom, weights, 'macro'), torch.mean(num / denom))
    assert torch.allclose(
        class_reduce(num, denom, weights, 'weighted'), torch.sum(num / denom * (weights / torch.sum(weights)))
    )
    assert torch.allclose(class_reduce(num, denom, weights, 'none'), num / denom)
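The four assertions above fully pin down the `class_reduce` semantics. A sketch of the behaviour they check, written from the assertions rather than the library source:

import torch

def class_reduce_sketch(num, denom, weights, reduction):
    # Mirrors exactly what the assertions in test_class_reduce verify.
    if reduction == 'micro':
        return torch.sum(num) / torch.sum(denom)
    if reduction == 'macro':
        return torch.mean(num / denom)
    if reduction == 'weighted':
        return torch.sum(num / denom * (weights / torch.sum(weights)))
    return num / denom  # 'none'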
@ -13,13 +13,11 @@ def test_against_sklearn(similarity, reduction):

    batch = torch.randn(5, 10, device=device)  # 5 samples in 10 dimensions

    pl_dist = embedding_similarity(batch, similarity=similarity,
                                   reduction=reduction, zero_diagonal=False)
    pl_dist = embedding_similarity(batch, similarity=similarity, reduction=reduction, zero_diagonal=False)

    def sklearn_embedding_distance(batch, similarity, reduction):

        metric_func = {'cosine': pairwise.cosine_similarity,
                       'dot': pairwise.linear_kernel}[similarity]
        metric_func = {'cosine': pairwise.cosine_similarity, 'dot': pairwise.linear_kernel}[similarity]

        dist = metric_func(batch, batch)
        if reduction == 'mean':

@ -28,8 +26,7 @@ def test_against_sklearn(similarity, reduction):

            return dist.sum(axis=-1)
        return dist

    sk_dist = sklearn_embedding_distance(batch.cpu().detach().numpy(),
                                         similarity=similarity, reduction=reduction)
    sk_dist = sklearn_embedding_distance(batch.cpu().detach().numpy(), similarity=similarity, reduction=reduction)
    sk_dist = torch.tensor(sk_dist, dtype=torch.float, device=device)

    assert torch.allclose(sk_dist, pl_dist)
@ -15,10 +15,14 @@ num_targets = 5

Input = namedtuple('Input', ["preds", "target"])

_single_target_inputs = Input(preds=torch.rand(NUM_BATCHES, BATCH_SIZE), target=torch.rand(NUM_BATCHES, BATCH_SIZE),)
_single_target_inputs = Input(
    preds=torch.rand(NUM_BATCHES, BATCH_SIZE),
    target=torch.rand(NUM_BATCHES, BATCH_SIZE),
)

_multi_target_inputs = Input(
    preds=torch.rand(NUM_BATCHES, BATCH_SIZE, num_targets), target=torch.rand(NUM_BATCHES, BATCH_SIZE, num_targets),
    preds=torch.rand(NUM_BATCHES, BATCH_SIZE, num_targets),
    target=torch.rand(NUM_BATCHES, BATCH_SIZE, num_targets),
)

@ -43,6 +47,7 @@ def _multi_target_sk_metric(preds, target, sk_fn=explained_variance_score):

    ],
)
class TestExplainedVariance(MetricTester):

    @pytest.mark.parametrize("ddp", [True, False])
    @pytest.mark.parametrize("dist_sync_on_step", [True, False])
    def test_explained_variance(self, multioutput, preds, target, sk_metric, ddp, dist_sync_on_step):

@ -69,4 +74,4 @@ class TestExplainedVariance(MetricTester):

def test_error_on_different_shape(metric_class=ExplainedVariance):
    metric = metric_class()
    with pytest.raises(RuntimeError, match='Predictions and targets are expected to have the same shape'):
        metric(torch.randn(100,), torch.randn(50,))
        metric(torch.randn(100, ), torch.randn(50, ))
@ -17,10 +17,14 @@ num_targets = 5

Input = namedtuple('Input', ["preds", "target"])

_single_target_inputs = Input(preds=torch.rand(NUM_BATCHES, BATCH_SIZE), target=torch.rand(NUM_BATCHES, BATCH_SIZE),)
_single_target_inputs = Input(
    preds=torch.rand(NUM_BATCHES, BATCH_SIZE),
    target=torch.rand(NUM_BATCHES, BATCH_SIZE),
)

_multi_target_inputs = Input(
    preds=torch.rand(NUM_BATCHES, BATCH_SIZE, num_targets), target=torch.rand(NUM_BATCHES, BATCH_SIZE, num_targets),
    preds=torch.rand(NUM_BATCHES, BATCH_SIZE, num_targets),
    target=torch.rand(NUM_BATCHES, BATCH_SIZE, num_targets),
)

@ -52,10 +56,12 @@ def _multi_target_sk_metric(preds, target, sk_fn=mean_squared_error):

    ],
)
class TestMeanError(MetricTester):

    @pytest.mark.parametrize("ddp", [True, False])
    @pytest.mark.parametrize("dist_sync_on_step", [True, False])
    def test_mean_error_class(self, preds, target, sk_metric, metric_class,
                              metric_functional, sk_fn, ddp, dist_sync_on_step):
    def test_mean_error_class(
        self, preds, target, sk_metric, metric_class, metric_functional, sk_fn, ddp, dist_sync_on_step
    ):
        self.run_class_metric_test(
            ddp=ddp,
            preds=preds,

@ -78,4 +84,4 @@ class TestMeanError(MetricTester):

def test_error_on_different_shape(metric_class):
    metric = metric_class()
    with pytest.raises(RuntimeError, match='Predictions and targets are expected to have the same shape'):
        metric(torch.randn(100,), torch.randn(50,))
        metric(torch.randn(100, ), torch.randn(50, ))
@ -12,15 +12,13 @@ from tests.metrics.utils import BATCH_SIZE, MetricTester, NUM_BATCHES

torch.manual_seed(42)

Input = namedtuple('Input', ["preds", "target"])

_inputs = [
    Input(
        preds=torch.randint(n_cls_pred, (NUM_BATCHES, BATCH_SIZE), dtype=torch.float),
        target=torch.randint(n_cls_target, (NUM_BATCHES, BATCH_SIZE), dtype=torch.float),
    )
    for n_cls_pred, n_cls_target in [(10, 10), (5, 10), (10, 5)]
    ) for n_cls_pred, n_cls_target in [(10, 10), (5, 10), (10, 5)]
]

@ -52,6 +50,7 @@ def _base_e_sk_metric(preds, target, data_range):

    ],
)
class TestPSNR(MetricTester):

    @pytest.mark.parametrize("ddp", [True, False])
    @pytest.mark.parametrize("dist_sync_on_step", [True, False])
    def test_psnr(self, preds, target, data_range, base, sk_metric, ddp, dist_sync_on_step):

@ -61,7 +60,10 @@ class TestPSNR(MetricTester):

            target,
            PSNR,
            partial(sk_metric, data_range=data_range),
            metric_args={"data_range": data_range, "base": base},
            metric_args={
                "data_range": data_range,
                "base": base
            },
            dist_sync_on_step=dist_sync_on_step,
        )

@ -71,5 +73,8 @@ class TestPSNR(MetricTester):

            target,
            psnr,
            partial(sk_metric, data_range=data_range),
            metric_args={"data_range": data_range, "base": base},
            metric_args={
                "data_range": data_range,
                "base": base
            },
        )
@ -15,10 +15,14 @@ num_targets = 5

Input = namedtuple('Input', ["preds", "target"])

_single_target_inputs = Input(preds=torch.rand(NUM_BATCHES, BATCH_SIZE), target=torch.rand(NUM_BATCHES, BATCH_SIZE),)
_single_target_inputs = Input(
    preds=torch.rand(NUM_BATCHES, BATCH_SIZE),
    target=torch.rand(NUM_BATCHES, BATCH_SIZE),
)

_multi_target_inputs = Input(
    preds=torch.rand(NUM_BATCHES, BATCH_SIZE, num_targets), target=torch.rand(NUM_BATCHES, BATCH_SIZE, num_targets),
    preds=torch.rand(NUM_BATCHES, BATCH_SIZE, num_targets),
    target=torch.rand(NUM_BATCHES, BATCH_SIZE, num_targets),
)

@ -50,6 +54,7 @@ def _multi_target_sk_metric(preds, target, adjusted, multioutput):

    ],
)
class TestR2Score(MetricTester):

    @pytest.mark.parametrize("ddp", [True, False])
    @pytest.mark.parametrize("dist_sync_on_step", [True, False])
    def test_r2(self, adjusted, multioutput, preds, target, sk_metric, num_outputs, ddp, dist_sync_on_step):

@ -60,9 +65,7 @@ class TestR2Score(MetricTester):

            R2Score,
            partial(sk_metric, adjusted=adjusted, multioutput=multioutput),
            dist_sync_on_step,
            metric_args=dict(adjusted=adjusted,
                             multioutput=multioutput,
                             num_outputs=num_outputs),
            metric_args=dict(adjusted=adjusted, multioutput=multioutput, num_outputs=num_outputs),
        )

    def test_r2_functional(self, adjusted, multioutput, preds, target, sk_metric, num_outputs):

@ -71,39 +74,41 @@ class TestR2Score(MetricTester):

            target,
            r2score,
            partial(sk_metric, adjusted=adjusted, multioutput=multioutput),
            metric_args=dict(adjusted=adjusted,
                             multioutput=multioutput),
            metric_args=dict(adjusted=adjusted, multioutput=multioutput),
        )

def test_error_on_different_shape(metric_class=R2Score):
    metric = metric_class()
    with pytest.raises(RuntimeError, match='Predictions and targets are expected to have the same shape'):
        metric(torch.randn(100,), torch.randn(50,))
        metric(torch.randn(100, ), torch.randn(50, ))

def test_error_on_multidim_tensors(metric_class=R2Score):
    metric = metric_class()
    with pytest.raises(ValueError, match=r'Expected both prediction and target to be 1D or 2D tensors,'
                                         r' but recevied tensors with dimension .'):
    with pytest.raises(
        ValueError,
        match=r'Expected both prediction and target to be 1D or 2D tensors,'
        r' but recevied tensors with dimension .'
    ):
        metric(torch.randn(10, 20, 5), torch.randn(10, 20, 5))

def test_error_on_too_few_samples(metric_class=R2Score):
    metric = metric_class()
    with pytest.raises(ValueError, match='Needs atleast two samples to calculate r2 score.'):
        metric(torch.randn(1,), torch.randn(1,))
        metric(torch.randn(1, ), torch.randn(1, ))

def test_warning_on_too_large_adjusted(metric_class=R2Score):
    metric = metric_class(adjusted=10)

    with pytest.warns(UserWarning,
                      match="More independent regressions than datapoints in"
                      " adjusted r2 score. Falls back to standard r2 score."):
        metric(torch.randn(10,), torch.randn(10,))
    with pytest.warns(
        UserWarning,
        match="More independent regressions than datapoints in"
        " adjusted r2 score. Falls back to standard r2 score."
    ):
        metric(torch.randn(10, ), torch.randn(10, ))

    with pytest.warns(UserWarning,
                      match="Division by zero in adjusted r2 score. Falls back to"
                      " standard r2 score."):
        metric(torch.randn(11,), torch.randn(11,))
    with pytest.warns(UserWarning, match="Division by zero in adjusted r2 score. Falls back to" " standard r2 score."):
        metric(torch.randn(11, ), torch.randn(11, ))
@ -11,10 +11,8 @@ from tests.metrics.utils import BATCH_SIZE, MetricTester, NUM_BATCHES

torch.manual_seed(42)

Input = namedtuple('Input', ["preds", "target", "multichannel"])

_inputs = []
for size, channel, coef, multichannel, dtype in [
    (12, 3, 0.9, True, torch.float),

@ -23,13 +21,11 @@ for size, channel, coef, multichannel, dtype in [

    (15, 3, 0.6, True, torch.float64),
]:
    preds = torch.rand(NUM_BATCHES, BATCH_SIZE, channel, size, size, dtype=dtype)
    _inputs.append(
        Input(
            preds=preds,
            target=preds * coef,
            multichannel=multichannel,
        )
    )
    _inputs.append(Input(
        preds=preds,
        target=preds * coef,
        multichannel=multichannel,
    ))

def _sk_metric(preds, target, data_range, multichannel):

@ -41,8 +37,14 @@ def _sk_metric(preds, target, data_range, multichannel):

        sk_target = sk_target[:, :, :, 0]

    return structural_similarity(
        sk_target, sk_preds, data_range=data_range, multichannel=multichannel,
        gaussian_weights=True, win_size=11, sigma=1.5, use_sample_covariance=False
        sk_target,
        sk_preds,
        data_range=data_range,
        multichannel=multichannel,
        gaussian_weights=True,
        win_size=11,
        sigma=1.5,
        use_sample_covariance=False
    )
@ -7,13 +7,16 @@ import torch

from pytorch_lightning.metrics.compositional import CompositionalMetric
from pytorch_lightning.metrics.metric import Metric

_MARK_TORCH_LOWER_1_4 = dict(condition=LooseVersion(torch.__version__) < LooseVersion("1.5.0"),
                             reason='required PT >= 1.5')
_MARK_TORCH_LOWER_1_5 = dict(condition=LooseVersion(torch.__version__) < LooseVersion("1.6.0"),
                             reason='required PT >= 1.6')
_MARK_TORCH_LOWER_1_4 = dict(
    condition=LooseVersion(torch.__version__) < LooseVersion("1.5.0"), reason='required PT >= 1.5'
)
_MARK_TORCH_LOWER_1_5 = dict(
    condition=LooseVersion(torch.__version__) < LooseVersion("1.6.0"), reason='required PT >= 1.6'
)

class DummyMetric(Metric):

    def __init__(self, val_to_return):
        super().__init__()
        self._num_updates = 0

@ -295,7 +298,7 @@ def test_metrics_or(second_operand, expected_result):

def test_metrics_pow(second_operand, expected_result):
    first_metric = DummyMetric(2)

    final_pow = first_metric ** second_operand
    final_pow = first_metric**second_operand

    assert isinstance(final_pow, CompositionalMetric)

@ -349,7 +352,7 @@ def test_metrics_rmod(first_operand, expected_result):

def test_metrics_rpow(first_operand, expected_result):
    second_operand = DummyMetric(2)

    final_rpow = first_operand ** second_operand
    final_rpow = first_operand**second_operand

    assert isinstance(final_rpow, CompositionalMetric)
@ -43,13 +43,14 @@ def _test_ddp_sum_cat(rank, worldsize):

@pytest.mark.skipif(sys.platform == "win32", reason="DDP not available on windows")
@pytest.mark.parametrize("process", [_test_ddp_cat, _test_ddp_sum, _test_ddp_sum_cat])
def test_ddp(process):
    torch.multiprocessing.spawn(process, args=(2,), nprocs=2)
    torch.multiprocessing.spawn(process, args=(2, ), nprocs=2)

def _test_non_contiguous_tensors(rank, worldsize):
    setup_ddp(rank, worldsize)

    class DummyMetric(Metric):

        def __init__(self):
            super().__init__()
            self.add_state("x", default=[], dist_reduce_fx=None)

@ -68,4 +69,4 @@ def _test_non_contiguous_tensors(rank, worldsize):

@pytest.mark.skipif(sys.platform == "win32", reason="DDP not available on windows")
def test_non_contiguous_tensors():
    """ Test that gather_all operation works for non contiguous tensors """
    torch.multiprocessing.spawn(_test_non_contiguous_tensors, args=(2,), nprocs=2)
    torch.multiprocessing.spawn(_test_non_contiguous_tensors, args=(2, ), nprocs=2)
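For readers unfamiliar with the spawn pattern these DDP tests use: `torch.multiprocessing.spawn` passes the process index as the first positional argument and forwards `args` after it, so each worker here receives `(rank, worldsize)`. A minimal sketch:

import torch

def _worker(rank, worldsize):
    # rank is injected by spawn; worldsize comes from args=(2, ).
    print(f"process {rank} of {worldsize}")

if __name__ == "__main__":
    torch.multiprocessing.spawn(_worker, args=(2, ), nprocs=2)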
@ -55,7 +55,7 @@ def test_add_state():

    assert np.allclose(a._reductions["b"](torch.tensor([1.0, 2.0])).numpy(), 1.5)

    a.add_state("c", torch.tensor(0), "cat")
    assert a._reductions["c"]([torch.tensor([1]), torch.tensor([1])]).shape == (2,)
    assert a._reductions["c"]([torch.tensor([1]), torch.tensor([1])]).shape == (2, )

    with pytest.raises(ValueError):
        a.add_state("d1", torch.tensor(0), 'xyz')

@ -89,6 +89,7 @@ def test_add_state_persistent():

def test_reset():

    class A(Dummy):
        pass

@ -109,7 +110,9 @@ def test_reset():

def test_update():

    class A(Dummy):

        def update(self, x):
            self.x += x

@ -125,7 +128,9 @@ def test_update():

def test_compute():

    class A(Dummy):

        def update(self, x):
            self.x += x

@ -150,7 +155,9 @@ def test_compute():

def test_forward():

    class A(Dummy):

        def update(self, x):
            self.x += x

@ -168,6 +175,7 @@ def test_forward():

class DummyMetric1(Dummy):

    def update(self, x):
        self.x += x

@ -176,6 +184,7 @@ class DummyMetric1(Dummy):

class DummyMetric2(Dummy):

    def update(self, y):
        self.x -= y

@ -214,7 +223,9 @@ def test_state_dict(tmpdir):

def test_child_metric_state_dict():
    """ test that child metric states will be added to parent state dict """

    class TestModule(nn.Module):

        def __init__(self):
            super().__init__()
            self.metric = Dummy()

@ -226,7 +237,7 @@ def test_child_metric_state_dict():

    expected_state_dict = {
        'metric.a': torch.tensor(0),
        'metric.b': [],
        'metric.c': torch.tensor(0)
        'metric.c': torch.tensor(0),
    }
    assert module.state_dict() == expected_state_dict

@ -317,8 +328,7 @@ def test_metric_collection_wrong_input(tmpdir):

    # Not all input are metrics (dict)
    with pytest.raises(ValueError):
        _ = MetricCollection({'metric1': m1,
                              'metric2': 5})
        _ = MetricCollection({'metric1': m1, 'metric2': 5})

    # Same metric passed in multiple times
    with pytest.raises(ValueError, match='Encountered two metrics both named *.'):
@ -6,6 +6,7 @@ from tests.base.boring_model import BoringModel

class SumMetric(Metric):

    def __init__(self):
        super().__init__()
        self.add_state("x", torch.tensor(0.0), dist_reduce_fx="sum")

@ -18,6 +19,7 @@ class SumMetric(Metric):

class DiffMetric(Metric):

    def __init__(self):
        super().__init__()
        self.add_state("x", torch.tensor(0.0), dist_reduce_fx="sum")

@ -30,7 +32,9 @@ class DiffMetric(Metric):

def test_metric_lightning(tmpdir):

    class TestModel(BoringModel):

        def __init__(self):
            super().__init__()
            self.metric = SumMetric()

@ -64,7 +68,9 @@ def test_metric_lightning(tmpdir):

def test_metric_lightning_log(tmpdir):
    """ Test logging a metric object and that the metric state gets reset after each epoch."""

    class TestModel(BoringModel):

        def __init__(self):
            super().__init__()
            self.metric_step = SumMetric()

@ -103,7 +109,9 @@ def test_metric_lightning_log(tmpdir):

def test_scriptable(tmpdir):

    class TestModel(BoringModel):

        def __init__(self):
            super().__init__()
            # the metric is not used in the module's `forward`

@ -141,7 +149,9 @@ def test_scriptable(tmpdir):

def test_metric_collection_lightning_log(tmpdir):

    class TestModel(BoringModel):

        def __init__(self):
            super().__init__()
            self.metric = MetricCollection([SumMetric(), DiffMetric()])