Fix `num_classes` arg in F1 metric (PL^5663)
* fix f1 metric * Apply suggestions from code review * chlog Co-authored-by: mergify[bot] <37929162+mergify[bot]@users.noreply.github.com> Co-authored-by: Jirka Borovec <Borda@users.noreply.github.com> Co-authored-by: Jirka Borovec <jirka.borovec@seznam.cz> (cherry picked from commit ef8efc611c12209ebf158aab7f444756ba35927e)
This commit is contained in:
Родитель
3dcd3cb037
Коммит
8bc17d8db8
|
@ -185,7 +185,7 @@ class F1(FBeta):
|
|||
|
||||
def __init__(
|
||||
self,
|
||||
num_classes: int = 1,
|
||||
num_classes: int,
|
||||
threshold: float = 0.5,
|
||||
average: str = "micro",
|
||||
multilabel: bool = False,
|
||||
|
@ -201,6 +201,7 @@ class F1(FBeta):
|
|||
beta=1.0,
|
||||
threshold=threshold,
|
||||
average=average,
|
||||
multilabel=multilabel,
|
||||
compute_on_step=compute_on_step,
|
||||
dist_sync_on_step=dist_sync_on_step,
|
||||
process_group=process_group,
|
||||
|
|
|
@ -5,7 +5,7 @@ import pytest
|
|||
import torch
|
||||
from sklearn.metrics import fbeta_score
|
||||
|
||||
from pytorch_lightning.metrics import FBeta
|
||||
from pytorch_lightning.metrics import F1, FBeta
|
||||
from pytorch_lightning.metrics.functional import f1, fbeta
|
||||
from tests.metrics.classification.inputs import (
|
||||
_binary_inputs,
|
||||
|
@ -97,22 +97,23 @@ def _sk_fbeta_multidim_multiclass(preds, target, average='micro', beta=1.0):
|
|||
],
|
||||
)
|
||||
@pytest.mark.parametrize("average", ['micro', 'macro', 'weighted', None])
|
||||
@pytest.mark.parametrize("beta", [0.5, 1.0])
|
||||
@pytest.mark.parametrize("beta", [0.5, 1.0, 2.0])
|
||||
class TestFBeta(MetricTester):
|
||||
@pytest.mark.parametrize("ddp", [True, False])
|
||||
@pytest.mark.parametrize("dist_sync_on_step", [True, False])
|
||||
def test_fbeta(
|
||||
self, preds, target, sk_metric, num_classes, multilabel, average, beta, ddp, dist_sync_on_step
|
||||
):
|
||||
metric_class = F1 if beta == 1.0 else partial(FBeta, beta=beta)
|
||||
|
||||
self.run_class_metric_test(
|
||||
ddp=ddp,
|
||||
preds=preds,
|
||||
target=target,
|
||||
metric_class=FBeta,
|
||||
metric_class=metric_class,
|
||||
sk_metric=partial(sk_metric, average=average, beta=beta),
|
||||
dist_sync_on_step=dist_sync_on_step,
|
||||
metric_args={
|
||||
"beta": beta,
|
||||
"num_classes": num_classes,
|
||||
"average": average,
|
||||
"multilabel": multilabel,
|
||||
|
@ -125,12 +126,13 @@ class TestFBeta(MetricTester):
|
|||
def test_fbeta_functional(
|
||||
self, preds, target, sk_metric, num_classes, multilabel, average, beta
|
||||
):
|
||||
metric_functional = f1 if beta == 1.0 else partial(fbeta, beta=beta)
|
||||
|
||||
self.run_functional_metric_test(preds=preds,
|
||||
target=target,
|
||||
metric_functional=fbeta,
|
||||
metric_functional=metric_functional,
|
||||
sk_metric=partial(sk_metric, average=average, beta=beta),
|
||||
metric_args={
|
||||
"beta": beta,
|
||||
"num_classes": num_classes,
|
||||
"average": average,
|
||||
"multilabel": multilabel,
|
||||
|
|
Загрузка…
Ссылка в новой задаче