Refactor unnecessary `else` / `elif` when `if` block has a `return` statement (#285)
* Refactor unnecessary `else` / `elif` when `if` block has a `return` statement * Apply suggestions from code review * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci Co-authored-by: deepsource-autofix[bot] <62050782+deepsource-autofix[bot]@users.noreply.github.com> Co-authored-by: Jirka Borovec <Borda@users.noreply.github.com> Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
This commit is contained in:
Родитель
51a37b5239
Коммит
a150ab1acd
|
@ -75,7 +75,7 @@ def _sk_fbeta_f1_multidim_multiclass(
|
|||
target = torch.transpose(target, 1, 2).reshape(-1, target.shape[1])
|
||||
|
||||
return _sk_fbeta_f1(preds, target, sk_fn, num_classes, average, False, ignore_index)
|
||||
elif mdmc_average == "samplewise":
|
||||
if mdmc_average == "samplewise":
|
||||
scores = []
|
||||
|
||||
for i in range(preds.shape[0]):
|
||||
|
|
|
@ -65,14 +65,13 @@ def _sk_hinge(preds, target, squared, multiclass_mode):
|
|||
if squared:
|
||||
measures = measures**2
|
||||
return measures.mean(axis=0)
|
||||
else:
|
||||
if multiclass_mode == MulticlassMode.ONE_VS_ALL:
|
||||
result = np.zeros(sk_preds.shape[1])
|
||||
for i in range(result.shape[0]):
|
||||
result[i] = sk_hinge(y_true=sk_target[:, i], pred_decision=sk_preds[:, i])
|
||||
return result
|
||||
if multiclass_mode == MulticlassMode.ONE_VS_ALL:
|
||||
result = np.zeros(sk_preds.shape[1])
|
||||
for i in range(result.shape[0]):
|
||||
result[i] = sk_hinge(y_true=sk_target[:, i], pred_decision=sk_preds[:, i])
|
||||
return result
|
||||
|
||||
return sk_hinge(y_true=sk_target, pred_decision=sk_preds)
|
||||
return sk_hinge(y_true=sk_target, pred_decision=sk_preds)
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
|
|
|
@ -76,7 +76,7 @@ def _sk_prec_recall_multidim_multiclass(
|
|||
target = torch.transpose(target, 1, 2).reshape(-1, target.shape[1])
|
||||
|
||||
return _sk_prec_recall(preds, target, sk_fn, num_classes, average, False, ignore_index)
|
||||
elif mdmc_average == "samplewise":
|
||||
if mdmc_average == "samplewise":
|
||||
scores = []
|
||||
|
||||
for i in range(preds.shape[0]):
|
||||
|
|
|
@ -112,28 +112,19 @@ def _sk_spec_mdim_mcls(preds, target, reduce, mdmc_reduce, num_classes, multicla
|
|||
preds = torch.transpose(preds, 1, 2).reshape(-1, preds.shape[1])
|
||||
target = torch.transpose(target, 1, 2).reshape(-1, target.shape[1])
|
||||
return _sk_spec(preds, target, reduce, num_classes, False, ignore_index, top_k, mdmc_reduce)
|
||||
else:
|
||||
fp, tn = [], []
|
||||
stats = []
|
||||
fp, tn = [], []
|
||||
stats = []
|
||||
|
||||
for i in range(preds.shape[0]):
|
||||
pred_i = preds[i, ...].T
|
||||
target_i = target[i, ...].T
|
||||
fp_i, tn_i = _sk_stats_score(
|
||||
pred_i,
|
||||
target_i,
|
||||
reduce,
|
||||
num_classes,
|
||||
False,
|
||||
ignore_index,
|
||||
top_k,
|
||||
)
|
||||
fp.append(fp_i)
|
||||
tn.append(tn_i)
|
||||
for i in range(preds.shape[0]):
|
||||
pred_i = preds[i, ...].T
|
||||
target_i = target[i, ...].T
|
||||
fp_i, tn_i = _sk_stats_score(pred_i, target_i, reduce, num_classes, False, ignore_index, top_k)
|
||||
fp.append(fp_i)
|
||||
tn.append(tn_i)
|
||||
|
||||
stats.append(fp)
|
||||
stats.append(tn)
|
||||
return _sk_spec(preds[0], target[0], reduce, num_classes, multiclass, ignore_index, top_k, mdmc_reduce, stats)
|
||||
stats.append(fp)
|
||||
stats.append(tn)
|
||||
return _sk_spec(preds[0], target[0], reduce, num_classes, multiclass, ignore_index, top_k, mdmc_reduce, stats)
|
||||
|
||||
|
||||
@pytest.mark.parametrize("metric, fn_metric", [(Specificity, specificity)])
|
||||
|
|
|
@ -85,7 +85,7 @@ def _sk_stat_scores_mdim_mcls(preds, target, reduce, mdmc_reduce, num_classes, m
|
|||
target = torch.transpose(target, 1, 2).reshape(-1, target.shape[1])
|
||||
|
||||
return _sk_stat_scores(preds, target, reduce, None, False, ignore_index, top_k)
|
||||
elif mdmc_reduce == "samplewise":
|
||||
if mdmc_reduce == "samplewise":
|
||||
scores = []
|
||||
|
||||
for i in range(preds.shape[0]):
|
||||
|
|
|
@ -50,8 +50,7 @@ def _fallout_at_k(target: np.ndarray, preds: np.ndarray, k: int = None):
|
|||
order_indexes = np.argsort(preds, axis=0)[::-1]
|
||||
relevant = np.sum(target[order_indexes][:k])
|
||||
return relevant * 1.0 / target.sum()
|
||||
else:
|
||||
return np.NaN
|
||||
return np.NaN
|
||||
|
||||
|
||||
class TestFallOut(RetrievalMetricTester):
|
||||
|
|
|
@ -50,8 +50,7 @@ def _reciprocal_rank(target: np.ndarray, preds: np.ndarray):
|
|||
if target.sum() > 0:
|
||||
# sklearn `label_ranking_average_precision_score` requires at most 2 dims
|
||||
return label_ranking_average_precision_score(np.expand_dims(target, axis=0), np.expand_dims(preds, axis=0))
|
||||
else:
|
||||
return 0.0
|
||||
return 0.0
|
||||
|
||||
|
||||
class TestMRR(RetrievalMetricTester):
|
||||
|
|
|
@ -49,8 +49,7 @@ def _precision_at_k(target: np.ndarray, preds: np.ndarray, k: int = None):
|
|||
order_indexes = np.argsort(preds, axis=0)[::-1]
|
||||
relevant = np.sum(target[order_indexes][:k])
|
||||
return relevant * 1.0 / k
|
||||
else:
|
||||
return np.NaN
|
||||
return np.NaN
|
||||
|
||||
|
||||
class TestPrecision(RetrievalMetricTester):
|
||||
|
|
|
@ -48,8 +48,7 @@ def _recall_at_k(target: np.ndarray, preds: np.ndarray, k: int = None):
|
|||
order_indexes = np.argsort(preds, axis=0)[::-1]
|
||||
relevant = np.sum(target[order_indexes][:k])
|
||||
return relevant * 1.0 / target.sum()
|
||||
else:
|
||||
return np.NaN
|
||||
return np.NaN
|
||||
|
||||
|
||||
class TestRecall(RetrievalMetricTester):
|
||||
|
|
|
@ -268,9 +268,8 @@ class Accuracy(StatScores):
|
|||
"""
|
||||
if self.subset_accuracy:
|
||||
return _subset_accuracy_compute(self.correct, self.total)
|
||||
else:
|
||||
tp, fp, tn, fn = self._get_final_stats()
|
||||
return _accuracy_compute(tp, fp, tn, fn, self.average, self.mdmc_reduce, self.mode)
|
||||
tp, fp, tn, fn = self._get_final_stats()
|
||||
return _accuracy_compute(tp, fp, tn, fn, self.average, self.mdmc_reduce, self.mode)
|
||||
|
||||
@property
|
||||
def is_differentiable(self) -> bool:
|
||||
|
|
|
@ -163,8 +163,7 @@ class BinnedPrecisionRecallCurve(Metric):
|
|||
recalls = torch.cat([recalls, t_zeros], dim=1)
|
||||
if self.num_classes == 1:
|
||||
return (precisions[0, :], recalls[0, :], self.thresholds)
|
||||
else:
|
||||
return (list(precisions), list(recalls), [self.thresholds for _ in range(self.num_classes)])
|
||||
return (list(precisions), list(recalls), [self.thresholds for _ in range(self.num_classes)])
|
||||
|
||||
|
||||
class BinnedAveragePrecision(BinnedPrecisionRecallCurve):
|
||||
|
|
|
@ -100,9 +100,9 @@ def _auroc_compute(
|
|||
# calculate average
|
||||
if average == AverageMethod.NONE:
|
||||
return auc_scores
|
||||
elif average == AverageMethod.MACRO:
|
||||
if average == AverageMethod.MACRO:
|
||||
return torch.mean(torch.stack(auc_scores))
|
||||
elif average == AverageMethod.WEIGHTED:
|
||||
if average == AverageMethod.WEIGHTED:
|
||||
if mode == DataType.MULTILABEL:
|
||||
support = torch.sum(target, dim=0)
|
||||
else:
|
||||
|
|
|
@ -76,11 +76,11 @@ def class_reduce(num: Tensor, denom: Tensor, weights: Tensor, class_reduction: s
|
|||
|
||||
if class_reduction == "micro":
|
||||
return fraction
|
||||
elif class_reduction == "macro":
|
||||
if class_reduction == "macro":
|
||||
return torch.mean(fraction)
|
||||
elif class_reduction == "weighted":
|
||||
if class_reduction == "weighted":
|
||||
return torch.sum(fraction * (weights.float() / torch.sum(weights)))
|
||||
elif class_reduction == "none" or class_reduction is None:
|
||||
if class_reduction == "none" or class_reduction is None:
|
||||
return fraction
|
||||
|
||||
raise ValueError(
|
||||
|
|
|
@ -40,7 +40,7 @@ def _bootstrap_sampler(
|
|||
p = torch.distributions.Poisson(1)
|
||||
n = p.sample((size, ))
|
||||
return torch.arange(size).repeat_interleave(n.long(), dim=0)
|
||||
elif sampling_strategy == 'multinomial':
|
||||
if sampling_strategy == 'multinomial':
|
||||
idx = torch.multinomial(torch.ones(size), num_samples=size, replacement=True)
|
||||
return idx
|
||||
raise ValueError('Unknown sampling strategy')
|
||||
|
|
Загрузка…
Ссылка в новой задаче