* annotate all unused vars

* rank_zero_warn

* Apply suggestions from code review

Co-authored-by: Carlos Mocholí <carlossmocholi@gmail.com>

* f1 fixed

Co-authored-by: Carlos Mocholí <carlossmocholi@gmail.com>
(cherry picked from commit dfbad656c493bd392c59e7d563125006b95d225c)
This commit is contained in:
Jirka Borovec 2020-12-19 13:53:06 +01:00 коммит произвёл Jirka Borovec
Родитель b79e64dff0
Коммит 048a415b26
5 изменённых файлов: 8 добавлений и 2 удалений

Просмотреть файл

@ -219,6 +219,9 @@ class F1(FBeta):
dist_sync_on_step: bool = False,
process_group: Optional[Any] = None,
):
if multilabel is not False:
rank_zero_warn(f'The `multilabel={multilabel}` parameter is unused and will not have any effect.')
super().__init__(
num_classes=num_classes,
beta=1.0,

Просмотреть файл

@ -188,7 +188,6 @@ def stat_scores_multiple_classes(
tps = torch.zeros((num_classes + 1,), device=pred.device)
fps = torch.zeros((num_classes + 1,), device=pred.device)
tns = torch.zeros((num_classes + 1,), device=pred.device)
fns = torch.zeros((num_classes + 1,), device=pred.device)
sups = torch.zeros((num_classes + 1,), device=pred.device)

Просмотреть файл

@ -16,6 +16,7 @@ from typing import Tuple
import torch
from pytorch_lightning.metrics.utils import _input_format_classification_one_hot, class_reduce
from pytorch_lightning.utilities import rank_zero_warn
def _fbeta_update(

Просмотреть файл

@ -74,7 +74,6 @@ def bleu_score(
assert len(translate_corpus) == len(reference_corpus)
numerator = torch.zeros(n_gram)
denominator = torch.zeros(n_gram)
precision_scores = torch.zeros(n_gram)
c = 0.0
r = 0.0

Просмотреть файл

@ -3,6 +3,8 @@ from typing import Tuple, Optional
import torch
from pytorch_lightning.utilities import rank_zero_warn
def _psnr_compute(
sum_squared_error: torch.Tensor,
@ -11,6 +13,8 @@ def _psnr_compute(
base: float = 10.0,
reduction: str = 'elementwise_mean',
) -> torch.Tensor:
if reduction != 'elementwise_mean':
rank_zero_warn(f'The `reduction={reduction}` parameter is unused and will not have any effect.')
psnr_base_e = 2 * torch.log(data_range) - torch.log(sum_squared_error / n_obs)
psnr = psnr_base_e * (10 / torch.log(torch.tensor(base)))
return psnr