Fix `num_classes` arg in F1 metric (#5663)

* fix f1 metric

* Apply suggestions from code review

* chlog

Co-authored-by: mergify[bot] <37929162+mergify[bot]@users.noreply.github.com>
Co-authored-by: Jirka Borovec <Borda@users.noreply.github.com>
Co-authored-by: Jirka Borovec <jirka.borovec@seznam.cz>
(cherry picked from commit ef8efc611c12209ebf158aab7f444756ba35927e)
This commit is contained in:
Nicki Skafte 2021-01-27 12:16:04 +01:00 committed by Jirka Borovec
Parent 3dcd3cb037
Commit 8bc17d8db8
2 changed files with 10 additions and 7 deletions

View file

@@ -185,7 +185,7 @@ class F1(FBeta):
def __init__(
self,
num_classes: int = 1,
num_classes: int,
threshold: float = 0.5,
average: str = "micro",
multilabel: bool = False,
@@ -201,6 +201,7 @@ class F1(FBeta):
beta=1.0,
threshold=threshold,
average=average,
multilabel=multilabel,
compute_on_step=compute_on_step,
dist_sync_on_step=dist_sync_on_step,
process_group=process_group,

View file

@@ -5,7 +5,7 @@ import pytest
import torch
from sklearn.metrics import fbeta_score
from pytorch_lightning.metrics import FBeta
from pytorch_lightning.metrics import F1, FBeta
from pytorch_lightning.metrics.functional import f1, fbeta
from tests.metrics.classification.inputs import (
_binary_inputs,
@@ -97,22 +97,23 @@ def _sk_fbeta_multidim_multiclass(preds, target, average='micro', beta=1.0):
],
)
@pytest.mark.parametrize("average", ['micro', 'macro', 'weighted', None])
@pytest.mark.parametrize("beta", [0.5, 1.0])
@pytest.mark.parametrize("beta", [0.5, 1.0, 2.0])
class TestFBeta(MetricTester):
@pytest.mark.parametrize("ddp", [True, False])
@pytest.mark.parametrize("dist_sync_on_step", [True, False])
def test_fbeta(
self, preds, target, sk_metric, num_classes, multilabel, average, beta, ddp, dist_sync_on_step
):
metric_class = F1 if beta == 1.0 else partial(FBeta, beta=beta)
self.run_class_metric_test(
ddp=ddp,
preds=preds,
target=target,
metric_class=FBeta,
metric_class=metric_class,
sk_metric=partial(sk_metric, average=average, beta=beta),
dist_sync_on_step=dist_sync_on_step,
metric_args={
"beta": beta,
"num_classes": num_classes,
"average": average,
"multilabel": multilabel,
@@ -125,12 +126,13 @@ class TestFBeta(MetricTester):
def test_fbeta_functional(
self, preds, target, sk_metric, num_classes, multilabel, average, beta
):
metric_functional = f1 if beta == 1.0 else partial(fbeta, beta=beta)
self.run_functional_metric_test(preds=preds,
target=target,
metric_functional=fbeta,
metric_functional=metric_functional,
sk_metric=partial(sk_metric, average=average, beta=beta),
metric_args={
"beta": beta,
"num_classes": num_classes,
"average": average,
"multilabel": multilabel,