Jirka 2021-08-03 12:50:18 +02:00
Parent 219af0e42c
Commit 21fe0ca7e1
81 changed files with 233 additions and 352 deletions

View File

@ -47,7 +47,7 @@ repos:
rev: v1.4
hooks:
- id: docformatter
args: [--in-place]
args: [--in-place, --wrap-summaries=115, --wrap-descriptions=120]
- repo: https://github.com/PyCQA/isort
rev: 5.9.2
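
The hunk above only raises docformatter's wrap limits; every other hunk in this commit is the resulting docstring reflow. A minimal sketch of the effect, using a hypothetical module example.py and the exact summary that is rewritten a few hunks below (the command line is inferred from the hook arguments):

# Roughly what the updated hook runs on each file:
#   docformatter --in-place --wrap-summaries=115 --wrap-descriptions=120 example.py

# Before: the summary was broken across two lines at the previous, narrower wrap width.
def test_metrics_reset_before(tmpdir):
    """Tests that metrics are reset correctly after the end of the
    train/val/test epoch.
    """

# After: with --wrap-summaries=115 the same summary fits on a single line.
def test_metrics_reset_after(tmpdir):
    """Tests that metrics are reset correctly after the end of the train/val/test epoch."""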

View File

@ -84,8 +84,7 @@ def test_metric_lightning(tmpdir):
@pytest.mark.skipif(not _LIGHTNING_GREATER_EQUAL_1_3, reason="test requires lightning v1.3 or higher")
def test_metrics_reset(tmpdir):
"""Tests that metrics are reset correctly after the end of the
train/val/test epoch.
"""Tests that metrics are reset correctly after the end of the train/val/test epoch.
Taken from:
https://github.com/PyTorchLightning/pytorch-lightning/pull/7055

View File

@ -51,8 +51,7 @@ def naive_implementation_pit_scipy(
metric_func: Callable,
eval_func: str,
) -> Tuple[Tensor, Tensor]:
"""A naive implementation of `Permutation Invariant Training` based on
Scipy.
"""A naive implementation of `Permutation Invariant Training` based on Scipy.
Args:
preds: predictions, shape[batch, spk, time]

View File

@ -105,8 +105,8 @@ def test_metric_collection_wrong_input(tmpdir):
def test_metric_collection_args_kwargs(tmpdir):
"""Check that args and kwargs gets passed correctly in metric collection,
Checks both update and forward method."""
"""Check that args and kwargs gets passed correctly in metric collection, Checks both update and forward
method."""
m1 = DummyMetricSum()
m2 = DummyMetricDiff()

View File

@ -225,6 +225,6 @@ def _test_state_dict_is_synced(rank, worldsize, tmpdir):
@pytest.mark.skipif(sys.platform == "win32", reason="DDP not available on windows")
def test_state_dict_is_synced(tmpdir):
"""This test asserts that metrics are synced while creating the state dict
but restored after to continue accumulation."""
"""This test asserts that metrics are synced while creating the state dict but restored after to continue
accumulation."""
torch.multiprocessing.spawn(_test_state_dict_is_synced, args=(2, tmpdir), nprocs=2)

View File

@ -270,8 +270,7 @@ def test_device_and_dtype_transfer(tmpdir):
def test_warning_on_compute_before_update():
"""test that an warning is raised if user tries to call compute before
update."""
"""test that an warning is raised if user tries to call compute before update."""
metric = DummyMetricSum()
# make sure everything is fine with forward

View File

@ -327,9 +327,8 @@ def test_average_accuracy_bin(preds, target, num_classes, exp_result, average, m
"ignore_index, expected", [(None, torch.tensor([1.0, np.nan])), (0, torch.tensor([np.nan, np.nan]))]
)
def test_class_not_present(metric_class, metric_fn, ignore_index, expected):
"""This tests that when metric is computed per class and a given class is
not present in both the `preds` and `target`, the resulting score is
`nan`."""
"""This tests that when metric is computed per class and a given class is not present in both the `preds` and
`target`, the resulting score is `nan`."""
preds = torch.tensor([0, 0, 0])
target = torch.tensor([0, 0, 0])
num_classes = 2

View File

@ -168,8 +168,8 @@ class TestAUROC(MetricTester):
def test_error_on_different_mode():
"""test that an error is raised if the user pass in data of different modes
(binary, multi-label, multi-class)"""
"""test that an error is raised if the user pass in data of different modes (binary, multi-label, multi-
class)"""
metric = AUROC()
# pass in multi-class data
metric.update(torch.randn(10, 5).softmax(dim=-1), torch.randint(0, 5, (10,)))
@ -186,8 +186,7 @@ def test_error_multiclass_no_num_classes():
def test_weighted_with_empty_classes():
"""Tests that weighted multiclass AUROC calculation yields the same results
if a new but empty class exists.
"""Tests that weighted multiclass AUROC calculation yields the same results if a new but empty class exists.
Tests that the proper warnings and errors are raised
"""

View File

@ -133,8 +133,7 @@ def test_wrong_params(metric_class, metric_fn, average, mdmc_average, num_classe
],
)
def test_zero_division(metric_class, metric_fn):
"""Test that zero_division works correctly (currently should just set to
0)."""
"""Test that zero_division works correctly (currently should just set to 0)."""
preds = torch.tensor([1, 2, 1, 1])
target = torch.tensor([2, 0, 2, 1])
@ -184,9 +183,8 @@ def test_no_support(metric_class, metric_fn):
"ignore_index, expected", [(None, torch.tensor([1.0, np.nan])), (0, torch.tensor([np.nan, np.nan]))]
)
def test_class_not_present(metric_class, metric_fn, ignore_index, expected):
"""This tests that when metric is computed per class and a given class is
not present in both the `preds` and `target`, the resulting score is
`nan`."""
"""This tests that when metric is computed per class and a given class is not present in both the `preds` and
`target`, the resulting score is `nan`."""
preds = torch.tensor([0, 0, 0])
target = torch.tensor([0, 0, 0])
num_classes = 2
@ -414,8 +412,7 @@ def test_top_k(
):
"""A simple test to check that top_k works as expected.
Just a sanity check, the tests in StatScores should already
guarantee the corectness of results.
Just a sanity check, the tests in StatScores should already guarantee the corectness of results.
"""
class_metric = metric_class(top_k=k, average=average, num_classes=3)
class_metric.update(preds, target)

View File

@ -133,8 +133,7 @@ def test_wrong_params(metric, fn_metric, average, mdmc_average, num_classes, ign
@pytest.mark.parametrize("metric_class, metric_fn", [(Recall, recall), (Precision, precision)])
def test_zero_division(metric_class, metric_fn):
"""Test that zero_division works correctly (currently should just set to
0)."""
"""Test that zero_division works correctly (currently should just set to 0)."""
preds = tensor([0, 2, 1, 1])
target = tensor([2, 1, 2, 1])
@ -357,8 +356,8 @@ class TestPrecisionRecall(MetricTester):
def test_precision_recall_joint(average):
"""A simple test of the joint precision_recall metric.
No need to test this thorougly, as it is just a combination of
precision and recall, which are already tested thoroughly.
No need to test this thorougly, as it is just a combination of precision and recall, which are already tested
thoroughly.
"""
precision_result = precision(
@ -404,8 +403,7 @@ def test_top_k(
):
"""A simple test to check that top_k works as expected.
Just a sanity check, the tests in StatScores should already
guarantee the correctness of results.
Just a sanity check, the tests in StatScores should already guarantee the correctness of results.
"""
class_metric = metric_class(top_k=k, average=average, num_classes=3)
@ -425,9 +423,8 @@ def test_top_k(
"ignore_index, expected", [(None, torch.tensor([1.0, np.nan])), (0, torch.tensor([np.nan, np.nan]))]
)
def test_class_not_present(metric_class, metric_fn, ignore_index, expected):
"""This tests that when metric is computed per class and a given class is
not present in both the `preds` and `target`, the resulting score is
`nan`."""
"""This tests that when metric is computed per class and a given class is not present in both the `preds` and
`target`, the resulting score is `nan`."""
preds = torch.tensor([0, 0, 0])
target = torch.tensor([0, 0, 0])
num_classes = 2

View File

@ -160,8 +160,7 @@ def test_wrong_params(metric, fn_metric, average, mdmc_average, num_classes, ign
@pytest.mark.parametrize("metric_class, metric_fn", [(Specificity, specificity)])
def test_zero_division(metric_class, metric_fn):
"""Test that zero_division works correctly (currently should just set to
0)."""
"""Test that zero_division works correctly (currently should just set to 0)."""
preds = tensor([1, 2, 1, 1])
target = tensor([0, 0, 0, 0])
@ -383,8 +382,7 @@ def test_top_k(
):
"""A simple test to check that top_k works as expected.
Just a sanity check, the tests in StatScores should already
guarantee the correctness of results.
Just a sanity check, the tests in StatScores should already guarantee the correctness of results.
"""
class_metric = metric_class(top_k=k, average=average, num_classes=3)
@ -399,9 +397,8 @@ def test_top_k(
"ignore_index, expected", [(None, torch.tensor([0.0, np.nan])), (0, torch.tensor([np.nan, np.nan]))]
)
def test_class_not_present(metric_class, metric_fn, ignore_index, expected):
"""This tests that when metric is computed per class and a given class is
not present in both the `preds` and `target`, the resulting score is
`nan`."""
"""This tests that when metric is computed per class and a given class is not present in both the `preds` and
`target`, the resulting score is `nan`."""
preds = torch.tensor([0, 0, 0])
target = torch.tensor([0, 0, 0])
num_classes = 2

View File

@ -113,14 +113,11 @@ def _sk_stat_scores_mdim_mcls(
],
)
def test_wrong_params(reduce, mdmc_reduce, num_classes, inputs, ignore_index):
"""Test a combination of parameters that are invalid and should raise an
error.
"""Test a combination of parameters that are invalid and should raise an error.
This includes invalid ``reduce`` and ``mdmc_reduce`` parameter
values, not setting ``num_classes`` when ``reduce='macro'`, not
setting ``mdmc_reduce`` when inputs are multi-dim multi-class``,
setting ``ignore_index`` when inputs are binary, as well as setting
``ignore_index`` to a value higher than the number of classes.
This includes invalid ``reduce`` and ``mdmc_reduce`` parameter values, not setting ``num_classes`` when
``reduce='macro'`, not setting ``mdmc_reduce`` when inputs are multi-dim multi-class``, setting ``ignore_index``
when inputs are binary, as well as setting ``ignore_index`` to a value higher than the number of classes.
"""
with pytest.raises(ValueError):
stat_scores(

View File

@ -27,8 +27,7 @@ def test_invalid_input_img_type():
def test_invalid_input_ndims():
"""Test whether the module successfully handles invalid number of
dimensions of input tensor."""
"""Test whether the module successfully handles invalid number of dimensions of input tensor."""
BATCH_SIZE = 1
HEIGHT = 5
@ -43,9 +42,8 @@ def test_invalid_input_ndims():
def test_multi_batch_image_gradients():
"""Test whether the module correctly calculates gradients for known input
with non-unity batch size.Example input-output pair taken from TF's
implementation of i mage-gradients."""
"""Test whether the module correctly calculates gradients for known input with non-unity batch size.Example
input-output pair taken from TF's implementation of i mage-gradients."""
BATCH_SIZE = 5
HEIGHT = 5
@ -76,8 +74,7 @@ def test_multi_batch_image_gradients():
def test_image_gradients():
"""Test whether the module correctly calculates gradients for known input.
Example input-output pair taken from TF's implementation of image-
gradients
Example input-output pair taken from TF's implementation of image- gradients
"""
BATCH_SIZE = 1

View File

@ -72,13 +72,11 @@ def calibration_error(
pos_label: Optional[Union[int, str]] = None,
reduce_bias: bool = True,
) -> float:
"""Compute calibration error of a binary classifier. Across all items in a
set of N predictions, the calibration error measures the aggregated
difference between (1) the average predicted probabilities assigned to the
positive class, and (2) the frequencies of the positive class in the actual
outcome. The calibration error is only appropriate for binary categorical
outcomes. Which label is considered to be the positive label is controlled
via the parameter pos_label, which defaults to 1.
"""Compute calibration error of a binary classifier. Across all items in a set of N predictions, the
calibration error measures the aggregated difference between (1) the average predicted probabilities assigned
to the positive class, and (2) the frequencies of the positive class in the actual outcome. The calibration
error is only appropriate for binary categorical outcomes. Which label is considered to be the positive label
is controlled via the parameter pos_label, which defaults to 1.
Args:
y_true: array-like of shape (n_samples,)

View File

@ -58,8 +58,7 @@ def setup_ddp(rank, world_size):
def _assert_allclose(pl_result: Any, sk_result: Any, atol: float = 1e-8):
"""Utility function for recursively asserting that two results are within a
certain tolerance."""
"""Utility function for recursively asserting that two results are within a certain tolerance."""
# single output compare
if isinstance(pl_result, Tensor):
assert np.allclose(pl_result.cpu().numpy(), sk_result, atol=atol, equal_nan=True)
@ -72,8 +71,7 @@ def _assert_allclose(pl_result: Any, sk_result: Any, atol: float = 1e-8):
def _assert_tensor(pl_result: Any):
"""Utility function for recursively checking that some input only consists
of torch tensors."""
"""Utility function for recursively checking that some input only consists of torch tensors."""
if isinstance(pl_result, Sequence):
for plr in pl_result:
_assert_tensor(plr)
@ -82,8 +80,8 @@ def _assert_tensor(pl_result: Any):
def _assert_requires_grad(metric: Metric, pl_result: Any):
"""Utility function for recursively asserting that metric output is
consistent with the `is_differentiable` attribute."""
"""Utility function for recursively asserting that metric output is consistent with the `is_differentiable`
attribute."""
if isinstance(pl_result, Sequence):
for plr in pl_result:
_assert_requires_grad(metric, plr)
@ -108,8 +106,7 @@ def _class_test(
check_scriptable: bool = True,
**kwargs_update: Any,
):
"""Utility function doing the actual comparison between lightning class
metric and reference metric.
"""Utility function doing the actual comparison between lightning class metric and reference metric.
Args:
rank: rank of current process
@ -203,8 +200,7 @@ def _functional_test(
fragment_kwargs: bool = False,
**kwargs_update,
):
"""Utility function doing the actual comparison between lightning
functional metric and reference metric.
"""Utility function doing the actual comparison between lightning functional metric and reference metric.
Args:
preds: torch tensor with predictions
@ -271,12 +267,10 @@ def _assert_half_support(
class MetricTester:
"""Class used for efficiently run alot of parametrized tests in ddp mode.
Makes sure that ddp is only setup once and that pool of processes are used
for all tests.
"""Class used for efficiently run alot of parametrized tests in ddp mode. Makes sure that ddp is only setup
once and that pool of processes are used for all tests.
All tests should subclass from this and implement a new method
called `test_metric_name` where the method
All tests should subclass from this and implement a new method called `test_metric_name` where the method
`self.run_metric_test` is called inside.
"""
@ -285,8 +279,7 @@ class MetricTester:
def setup_class(self):
"""Setup the metric class.
This will spawn the pool of workers that are used for metric
testing and setup_ddp
This will spawn the pool of workers that are used for metric testing and setup_ddp
"""
self.poolSize = NUM_PROCESSES
@ -308,8 +301,7 @@ class MetricTester:
fragment_kwargs: bool = False,
**kwargs_update,
):
"""Main method that should be used for testing functions. Call this
inside testing method.
"""Main method that should be used for testing functions. Call this inside testing method.
Args:
preds: torch tensor with predictions
@ -350,8 +342,7 @@ class MetricTester:
check_scriptable: bool = True,
**kwargs_update,
):
"""Main method that should be used for testing class. Call this inside
testing methods.
"""Main method that should be used for testing class. Call this inside testing methods.
Args:
ddp: bool, if running in ddp mode or not

View File

@ -38,9 +38,8 @@ seed_all(42)
def get_group_indexes(indexes: Union[Tensor, np.ndarray]) -> List[Union[Tensor, np.ndarray]]:
"""Given an integer `torch.Tensor` or `np.ndarray` `indexes`, return a
`torch.Tensor` or `np.ndarray` of indexes for each different value in
`indexes`.
"""Given an integer `torch.Tensor` or `np.ndarray` `indexes`, return a `torch.Tensor` or `np.ndarray` of
indexes for each different value in `indexes`.
Args:
indexes: a `torch.Tensor` or `np.ndarray` of integers
@ -75,8 +74,7 @@ def _compute_sklearn_metric(
reverse: bool = False,
**kwargs,
) -> Tensor:
"""Compute metric with multiple iterations over every query predictions
set."""
"""Compute metric with multiple iterations over every query predictions set."""
if indexes is None:
indexes = np.full_like(preds, fill_value=0, dtype=np.int64)

View File

@ -34,11 +34,9 @@ seed_all(42)
def _fallout_at_k(target: np.ndarray, preds: np.ndarray, k: int = None):
"""Didn't find a reliable implementation of Fall-out in Information
Retrieval, so, reimplementing here.
"""Didn't find a reliable implementation of Fall-out in Information Retrieval, so, reimplementing here.
See Wikipedia for `Fall-out`_ for more information about the metric
definition.
See Wikipedia for `Fall-out`_ for more information about the metric definition.
"""
assert target.shape == preds.shape
assert len(target.shape) == 1 # works only with single dimension inputs

View File

@ -35,10 +35,9 @@ seed_all(42)
def _reciprocal_rank(target: np.ndarray, preds: np.ndarray):
"""Adaptation of `sklearn.metrics.label_ranking_average_precision_score`.
Since the original sklearn metric works as RR only when the number
of positive targets is exactly 1, here we remove every positive
target that is not the most important. Remember that in RR only the
positive target with the highest score is considered.
Since the original sklearn metric works as RR only when the number of positive targets is exactly 1, here we remove
every positive target that is not the most important. Remember that in RR only the positive target with the highest
score is considered.
"""
assert target.shape == preds.shape
assert len(target.shape) == 1 # works only with single dimension inputs

View File

@ -34,8 +34,7 @@ seed_all(42)
def _precision_at_k(target: np.ndarray, preds: np.ndarray, k: int = None):
"""Didn't find a reliable implementation of Precision in Information
Retrieval, so, reimplementing here.
"""Didn't find a reliable implementation of Precision in Information Retrieval, so, reimplementing here.
A good explanation can be found
`here <https://web.stanford.edu/class/cs276/handouts/EvaluationNew-handout-1-per.pdf>_`.

View File

@ -34,8 +34,7 @@ seed_all(42)
def _recall_at_k(target: np.ndarray, preds: np.ndarray, k: int = None):
"""Didn't find a reliable implementation of Recall in Information
Retrieval, so, reimplementing here.
"""Didn't find a reliable implementation of Recall in Information Retrieval, so, reimplementing here.
See wikipedia for more information about definition.
"""

View File

@ -30,8 +30,7 @@ def test_wer_same(hyp, ref, score):
)
@pytest.mark.skipif(not _JIWER_AVAILABLE, reason="test requires jiwer")
def test_wer_functional(ref, hyp, expected_score, expected_incorrect, expected_total):
"""Test to ensure that the torchmetric functional WER matches the jiwer
reference."""
"""Test to ensure that the torchmetric functional WER matches the jiwer reference."""
assert wer(ref, hyp) == expected_score
@ -44,15 +43,13 @@ def test_wer_functional(ref, hyp, expected_score, expected_incorrect, expected_t
)
@pytest.mark.skipif(not _JIWER_AVAILABLE, reason="test requires jiwer")
def test_wer_reference_functional(hyp, ref):
"""Test to ensure that the torchmetric functional WER matches the jiwer
reference."""
"""Test to ensure that the torchmetric functional WER matches the jiwer reference."""
assert wer(ref, hyp) == compute_measures(ref, hyp)["wer"]
@pytest.mark.skipif(not _JIWER_AVAILABLE, reason="test requires jiwer")
def test_wer_reference_functional_concatenate():
"""Test to ensure that the torchmetric functional WER matches the jiwer
reference when concatenating."""
"""Test to ensure that the torchmetric functional WER matches the jiwer reference when concatenating."""
ref = ["hello world", "hello world"]
hyp = ["hello world", "Firwww"]
assert wer(ref, hyp) == compute_measures(ref, hyp)["wer"]
@ -76,8 +73,7 @@ def test_wer_reference(hyp, ref):
@pytest.mark.skipif(not _JIWER_AVAILABLE, reason="test requires jiwer")
def test_wer_reference_batch():
"""Test to ensure that the torchmetric WER matches the jiwer reference with
accumulation."""
"""Test to ensure that the torchmetric WER matches the jiwer reference with accumulation."""
batches = [("hello world", "Firwww"), ("hello world", "hello world")]
metric = WER()

View File

@ -29,8 +29,8 @@ _target = torch.randint(10, (10, 32))
class TestBootStrapper(BootStrapper):
"""For testing purpose, we subclass the bootstrapper class so we can get
the exact permutation the class is creating."""
"""For testing purpose, we subclass the bootstrapper class so we can get the exact permutation the class is
creating."""
def update(self, *args) -> None:
self.out = []
@ -75,8 +75,7 @@ def test_bootstrap_sampler(sampling_strategy):
"metric, sk_metric", [[Precision(average="micro"), precision_score], [Recall(average="micro"), recall_score]]
)
def test_bootstrap(sampling_strategy, metric, sk_metric):
"""Test that the different bootstraps gets updated as we expected and that
the compute method works."""
"""Test that the different bootstraps gets updated as we expected and that the compute method works."""
_kwargs = {"base_metric": metric, "mean": True, "std": True, "raw": True, "sampling_strategy": sampling_strategy}
if _TORCH_GREATER_EQUAL_1_7:
_kwargs.update(dict(quantile=torch.tensor([0.05, 0.95])))
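
For context, a hedged usage sketch of the wrapper this test exercises, built from the keyword arguments visible in the hunk above (the import path and the shape of the returned dict are assumptions):

import torch
from torchmetrics import Precision
from torchmetrics.wrappers import BootStrapper

# wrap a base metric; each update is resampled into several bootstrapped copies of the metric
bootstrap = BootStrapper(base_metric=Precision(average="micro"), mean=True, std=True, raw=True)
preds = torch.randint(0, 2, (10,))
target = torch.randint(0, 2, (10,))
bootstrap.update(preds, target)
output = bootstrap.compute()  # assumed to be a dict with "mean", "std" and "raw" entries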

View File

@ -20,9 +20,9 @@ from torchmetrics.metric import Metric
class PIT(Metric):
"""Permutation invariant training (PIT). The PIT implements the famous
Permutation Invariant Training method [1] in speech separation field in
order to calculate audio metrics in a permutation invariant way.
"""Permutation invariant training (PIT). The PIT implements the famous Permutation Invariant Training method.
[1] in speech separation field in order to calculate audio metrics in a permutation invariant way.
Forward accepts

View File

@ -20,8 +20,8 @@ from torchmetrics.metric import Metric
class SI_SDR(Metric):
"""Scale-invariant signal-to-distortion ratio (SI-SDR). The SI-SDR value is
in general considered an overall measure of how good a source sound.
"""Scale-invariant signal-to-distortion ratio (SI-SDR). The SI-SDR value is in general considered an overall
measure of how good a source sound.
Forward accepts

View File

@ -266,8 +266,7 @@ class Accuracy(StatScores):
self.fn.append(fn)
def compute(self) -> Tensor:
"""Computes accuracy based on inputs passed in to ``update``
previously."""
"""Computes accuracy based on inputs passed in to ``update`` previously."""
if not self.mode:
raise RuntimeError("You have to have determined mode.")
if self.subset_accuracy:

View File

@ -91,6 +91,6 @@ class AUC(Metric):
@property
def is_differentiable(self) -> bool:
"""AUC metrics is considered as non differentiable so it should have
`false` value for `is_differentiable` property."""
"""AUC metrics is considered as non differentiable so it should have `false` value for `is_differentiable`
property."""
return False

View File

@ -169,8 +169,7 @@ class AUROC(Metric):
self.mode = mode
def compute(self) -> Tensor:
"""Computes AUROC based on inputs passed in to ``update``
previously."""
"""Computes AUROC based on inputs passed in to ``update`` previously."""
if not self.mode:
raise RuntimeError("You have to have determined mode.")
preds = dim_zero_cat(self.preds)
@ -187,6 +186,6 @@ class AUROC(Metric):
@property
def is_differentiable(self) -> bool:
"""AUROC metrics is considered as non differentiable so it should have
`false` value for `is_differentiable` property."""
"""AUROC metrics is considered as non differentiable so it should have `false` value for
`is_differentiable` property."""
return False

View File

@ -26,10 +26,9 @@ from torchmetrics.utilities.data import dim_zero_cat
class AveragePrecision(Metric):
"""Computes the average precision score, which summarises the precision
recall curve into one number. Works for both binary and multiclass
problems. In the case of multiclass, the values will be calculated based on
a one-vs-the-rest approach.
"""Computes the average precision score, which summarises the precision recall curve into one number. Works for
both binary and multiclass problems. In the case of multiclass, the values will be calculated based on a one-
vs-the-rest approach.
Forward accepts

View File

@ -43,9 +43,8 @@ def _recall_at_precision(
class BinnedPrecisionRecallCurve(Metric):
"""Computes precision-recall pairs for different thresholds. Works for both
binary and multiclass problems. In the case of multiclass, the values will
be calculated based on a one-vs-the-rest approach.
"""Computes precision-recall pairs for different thresholds. Works for both binary and multiclass problems. In
the case of multiclass, the values will be calculated based on a one-vs-the-rest approach.
Computation is performed in constant-memory by computing precision and recall
for ``thresholds`` buckets/thresholds (evenly distributed between 0 and 1).
@ -190,10 +189,9 @@ class BinnedPrecisionRecallCurve(Metric):
class BinnedAveragePrecision(BinnedPrecisionRecallCurve):
"""Computes the average precision score, which summarises the precision
recall curve into one number. Works for both binary and multiclass
problems. In the case of multiclass, the values will be calculated based on
a one-vs-the-rest approach.
"""Computes the average precision score, which summarises the precision recall curve into one number. Works for
both binary and multiclass problems. In the case of multiclass, the values will be calculated based on a one-
vs-the-rest approach.
Computation is performed in constant-memory by computing precision and recall
for ``thresholds`` buckets/thresholds (evenly distributed between 0 and 1).
@ -245,8 +243,7 @@ class BinnedAveragePrecision(BinnedPrecisionRecallCurve):
class BinnedRecallAtFixedPrecision(BinnedPrecisionRecallCurve):
"""Computes the higest possible recall value given the minimum precision
thresholds provided.
"""Computes the higest possible recall value given the minimum precision thresholds provided.
Computation is performed in constant-memory by computing precision and recall
for ``thresholds`` buckets/thresholds (evenly distributed between 0 and 1).

View File

@ -93,8 +93,8 @@ class CalibrationError(Metric):
self.add_state("accuracies", [], dist_reduce_fx="cat")
def update(self, preds: Tensor, target: Tensor) -> None: # type: ignore
"""Computes top-level confidences and accuracies for the input
probabilites and appends them to internal state.
"""Computes top-level confidences and accuracies for the input probabilites and appends them to internal
state.
Args:
preds (Tensor): Model output probabilities.

View File

@ -119,7 +119,6 @@ class CohenKappa(Metric):
@property
def is_differentiable(self) -> bool:
"""cohen kappa is not differentiable since the implementation is based
on calculating the confusion matrix which in general is not
differentiable."""
"""cohen kappa is not differentiable since the implementation is based on calculating the confusion matrix
which in general is not differentiable."""
return False

View File

@ -175,8 +175,7 @@ class FBeta(StatScores):
class F1(FBeta):
"""Computes F1 metric. F1 metrics correspond to a harmonic mean of the
precision and recall scores.
"""Computes F1 metric. F1 metrics correspond to a harmonic mean of the precision and recall scores.
Works with binary, multiclass, and multilabel data. Accepts logits or probabilities from a model
output or integer class values in prediction. Works with multi-dimensional preds and target.

View File

@ -105,8 +105,7 @@ class HammingDistance(Metric):
self.total += total
def compute(self) -> Tensor:
"""Computes hamming distance based on inputs passed in to ``update``
previously."""
"""Computes hamming distance based on inputs passed in to ``update`` previously."""
return _hamming_distance_compute(self.correct, self.total)
@property

View File

@ -156,8 +156,7 @@ class Precision(StatScores):
self.average = average
def compute(self) -> Tensor:
"""Computes the precision score based on inputs passed in to ``update``
previously.
"""Computes the precision score based on inputs passed in to ``update`` previously.
Return:
The shape of the returned tensor depends on the ``average`` parameter
@ -310,8 +309,7 @@ class Recall(StatScores):
self.average = average
def compute(self) -> Tensor:
"""Computes the recall score based on inputs passed in to ``update``
previously.
"""Computes the recall score based on inputs passed in to ``update`` previously.
Return:
The shape of the returned tensor depends on the ``average`` parameter

View File

@ -26,9 +26,8 @@ from torchmetrics.utilities.data import dim_zero_cat
class PrecisionRecallCurve(Metric):
"""Computes precision-recall pairs for different thresholds. Works for both
binary and multiclass problems. In the case of multiclass, the values will
be calculated based on a one-vs-the-rest approach.
"""Computes precision-recall pairs for different thresholds. Works for both binary and multiclass problems. In
the case of multiclass, the values will be calculated based on a one-vs-the-rest approach.
Forward accepts

View File

@ -22,9 +22,8 @@ from torchmetrics.utilities import rank_zero_warn
class ROC(Metric):
"""Computes the Receiver Operating Characteristic (ROC). Works for both
binary, multiclass and multilabel problems. In the case of multiclass, the
values will be calculated based on a one-vs-the-rest approach.
"""Computes the Receiver Operating Characteristic (ROC). Works for both binary, multiclass and multilabel
problems. In the case of multiclass, the values will be calculated based on a one-vs-the-rest approach.
Forward accepts

View File

@ -157,8 +157,7 @@ class Specificity(StatScores):
self.average = average
def compute(self) -> Tensor:
"""Computes the specificity score based on inputs passed in to
``update`` previously.
"""Computes the specificity score based on inputs passed in to ``update`` previously.
Return:
The shape of the returned tensor depends on the ``average`` parameter

View File

@ -225,8 +225,7 @@ class StatScores(Metric):
self.fn.append(fn)
def _get_final_stats(self) -> Tuple[Tensor, Tensor, Tensor, Tensor]:
"""Performs concatenation on the stat scores if neccesary, before
passing them to a compute function."""
"""Performs concatenation on the stat scores if neccesary, before passing them to a compute function."""
tp = torch.cat(self.tp) if isinstance(self.tp, list) else self.tp
fp = torch.cat(self.fp) if isinstance(self.fp, list) else self.fp
tn = torch.cat(self.tn) if isinstance(self.tn, list) else self.tn
@ -234,8 +233,7 @@ class StatScores(Metric):
return tp, fp, tn, fn
def compute(self) -> Tensor:
"""Computes the stat scores based on inputs passed in to ``update``
previously.
"""Computes the stat scores based on inputs passed in to ``update`` previously.
Return:
The metric returns a tensor of shape ``(..., 5)``, where the last dimension corresponds

View File

@ -24,8 +24,7 @@ from torchmetrics.utilities import rank_zero_warn
class MetricCollection(nn.ModuleDict):
"""MetricCollection class can be used to chain metrics that have the same
call pattern into one single class.
"""MetricCollection class can be used to chain metrics that have the same call pattern into one single class.
Args:
metrics: One of the following
@ -105,18 +104,16 @@ class MetricCollection(nn.ModuleDict):
def forward(self, *args: Any, **kwargs: Any) -> Dict[str, Any]:
"""Iteratively call forward for each metric.
Positional arguments (args) will be passed to every metric in
the collection, while keyword arguments (kwargs) will be
filtered based on the signature of the individual metric.
Positional arguments (args) will be passed to every metric in the collection, while keyword arguments (kwargs)
will be filtered based on the signature of the individual metric.
"""
return {k: m(*args, **m._filter_kwargs(**kwargs)) for k, m in self.items()}
def update(self, *args: Any, **kwargs: Any) -> None:
"""Iteratively call update for each metric.
Positional arguments (args) will be passed to every metric in
the collection, while keyword arguments (kwargs) will be
filtered based on the signature of the individual metric.
Positional arguments (args) will be passed to every metric in the collection, while keyword arguments (kwargs)
will be filtered based on the signature of the individual metric.
"""
for _, m in self.items(keep_base=True):
m_kwargs = m._filter_kwargs(**kwargs)
@ -145,8 +142,7 @@ class MetricCollection(nn.ModuleDict):
return mc
def persistent(self, mode: bool = True) -> None:
"""Method for post-init to change if metric states should be saved to
its state_dict."""
"""Method for post-init to change if metric states should be saved to its state_dict."""
for _, m in self.items(keep_base=True):
m.persistent(mode)
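
As a hedged illustration of the call pattern the reflowed forward/update docstrings describe (the concrete metrics below are illustrative, not part of the diff):

import torch
from torchmetrics import Accuracy, MetricCollection, Precision, Recall

collection = MetricCollection([Accuracy(), Precision(average="micro"), Recall(average="micro")])
preds = torch.randint(0, 2, (10,))
target = torch.randint(0, 2, (10,))

# positional arguments are passed to every metric in the collection; keyword arguments
# would be filtered per metric via _filter_kwargs, as the docstrings above state
results = collection(preds, target)  # e.g. {"Accuracy": ..., "Precision": ..., "Recall": ...}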

View File

@ -30,8 +30,8 @@ def _find_best_perm_by_linear_sum_assignment(
metric_mtx: torch.Tensor,
eval_func: Union[torch.min, torch.max],
) -> Tuple[Tensor, Tensor]:
"""Solves the linear sum assignment problem using scipy, and returns the
best metric values and the corresponding permutations.
"""Solves the linear sum assignment problem using scipy, and returns the best metric values and the
corresponding permutations.
Args:
metric_mtx:
@ -58,9 +58,8 @@ def _find_best_perm_by_exhuastive_method(
metric_mtx: torch.Tensor,
eval_func: Union[torch.min, torch.max],
) -> Tuple[Tensor, Tensor]:
"""Solves the linear sum assignment problem using exhuastive method, i.e.
exhuastively calculates the metric values of all possible permutations, and
returns the best metric values and the corresponding permutations.
"""Solves the linear sum assignment problem using exhuastive method, i.e. exhuastively calculates the metric
values of all possible permutations, and returns the best metric values and the corresponding permutations.
Args:
metric_mtx:
@ -105,9 +104,9 @@ def _find_best_perm_by_exhuastive_method(
def pit(
preds: torch.Tensor, target: torch.Tensor, metric_func: Callable, eval_func: str = "max", **kwargs: Dict[str, Any]
) -> Tuple[Tensor, Tensor]:
"""Permutation invariant training (PIT). The PIT implements the famous
Permutation Invariant Training method [1] in speech separation field in
order to calculate audio metrics in a permutation invariant way.
"""Permutation invariant training (PIT). The PIT implements the famous Permutation Invariant Training method.
[1] in speech separation field in order to calculate audio metrics in a permutation invariant way.
Args:
target:
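
A hedged usage sketch matching the functional signature shown in this hunk; the import path, the choice of si_sdr as the pairwise metric and the [batch, spk, time] shapes (taken from the PIT test hunk earlier) are assumptions:

import torch
from torchmetrics.functional import pit, si_sdr

preds = torch.randn(2, 3, 16000)   # [batch, spk, time]
target = torch.randn(2, 3, 16000)

# best metric value per batch element and the permutation that achieves it
best_metric, best_perm = pit(preds, target, metric_func=si_sdr, eval_func="max")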

View File

@ -18,9 +18,8 @@ from torchmetrics.utilities.checks import _check_same_shape
def si_sdr(preds: Tensor, target: Tensor, zero_mean: bool = False) -> Tensor:
"""Calculates Scale-invariant signal-to-distortion ratio (SI-SDR) metric.
The SI-SDR value is in general considered an overall measure of how good a
source sound.
"""Calculates Scale-invariant signal-to-distortion ratio (SI-SDR) metric. The SI-SDR value is in general
considered an overall measure of how good a source sound.
Args:
preds:

View File

@ -159,8 +159,7 @@ def auroc(
max_fpr: Optional[float] = None,
sample_weights: Optional[Sequence] = None,
) -> Tensor:
"""Compute Area Under the Receiver Operating Characteristic Curve (`ROC
AUC`_)
"""Compute Area Under the Receiver Operating Characteristic Curve (`ROC AUC`_)
Args:
preds: predictions from model (logits or probabilities)

View File

@ -27,8 +27,7 @@ def _ce_compute(
norm: str = "l1",
debias: bool = False,
) -> Tensor:
"""Computes the calibration error given the provided bin boundaries and
norm.
"""Computes the calibration error given the provided bin boundaries and norm.
Args:
confidences (FloatTensor): The confidence (i.e. predicted prob) of the top1 prediction.
@ -76,8 +75,8 @@ def _ce_compute(
def _ce_update(preds: Tensor, target: Tensor) -> Tuple[FloatTensor, FloatTensor]:
"""Given a predictions and targets tensor, computes the confidences of the
top-1 prediction and records their correctness.
"""Given a predictions and targets tensor, computes the confidences of the top-1 prediction and records their
correctness.
Args:
preds (Tensor): Input softmaxed predictions.

View File

@ -26,8 +26,8 @@ def _stat_scores(
class_index: int,
argmax_dim: int = 1,
) -> Tuple[Tensor, Tensor, Tensor, Tensor, Tensor]:
"""Calculates the number of true positive, false positive, true negative
and false negative for a specific class.
"""Calculates the number of true positive, false positive, true negative and false negative for a specific
class.
Args:
preds: prediction tensor

View File

@ -229,8 +229,7 @@ def f1(
top_k: Optional[int] = None,
multiclass: Optional[bool] = None,
) -> Tensor:
"""Computes F1 metric. F1 metrics correspond to a equally weighted average
of the precision and recall scores.
"""Computes F1 metric. F1 metrics correspond to a equally weighted average of the precision and recall scores.
Works with binary, multiclass, and multilabel data.
Accepts probabilities or logits from a model output or integer class values in prediction.

View File

@ -26,8 +26,7 @@ def _binary_clf_curve(
sample_weights: Optional[Sequence] = None,
pos_label: int = 1,
) -> Tuple[Tensor, Tensor, Tensor]:
"""adapted from https://github.com/scikit-learn/scikit-
learn/blob/master/sklearn/metrics/_ranking.py."""
"""adapted from https://github.com/scikit-learn/scikit- learn/blob/master/sklearn/metrics/_ranking.py."""
if sample_weights is not None and not isinstance(sample_weights, Tensor):
sample_weights = tensor(sample_weights, device=preds.device, dtype=torch.float)

View File

@ -106,8 +106,8 @@ def roc(
pos_label: Optional[int] = None,
sample_weights: Optional[Sequence] = None,
) -> Union[Tuple[Tensor, Tensor, Tensor], Tuple[List[Tensor], List[Tensor], List[Tensor]]]:
"""Computes the Receiver Operating Characteristic (ROC). Works with both
binary, multiclass and multilabel input.
"""Computes the Receiver Operating Characteristic (ROC). Works with both binary, multiclass and multilabel
input.
Args:
preds: predictions from model (logits or probabilities)

View File

@ -46,8 +46,8 @@ def _compute_image_gradients(img: Tensor) -> Tuple[Tensor, Tensor]:
def image_gradients(img: Tensor) -> Tuple[Tensor, Tensor]:
"""Computes the `gradients <https://en.wikipedia.org/wiki/Image_gradient>`_
of a given image using finite difference.
"""Computes the `gradients <https://en.wikipedia.org/wiki/Image_gradient>`_ of a given image using finite
difference.
Args:
img: An ``(N, C, H, W)`` input tensor where C is the number of image channels
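
A small sketch of the call documented above; the (dy, dx) return order is assumed from the TF implementation the docstring references:

import torch
from torchmetrics.functional import image_gradients

img = torch.rand(1, 1, 5, 5)   # (N, C, H, W), per the Args section above
dy, dx = image_gradients(img)  # per-pixel finite differences along height and width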

View File

@ -60,8 +60,7 @@ def _psnr_update(
target: Tensor,
dim: Optional[Union[int, Tuple[int, ...]]] = None,
) -> Tuple[Tensor, Tensor]:
"""Updates and returns variables required to compute peak signal-to-noise
ratio.
"""Updates and returns variables required to compute peak signal-to-noise ratio.
Args:
preds: Predicted tensor

View File

@ -68,8 +68,8 @@ def _gaussian_kernel(
def _ssim_update(preds: Tensor, target: Tensor) -> Tuple[Tensor, Tensor]:
"""Updates and returns variables required to compute Structural Similarity
Index Measure. Checks for same shape and type of the input tensors.
"""Updates and returns variables required to compute Structural Similarity Index Measure. Checks for same shape
and type of the input tensors.
Args:
preds: Predicted tensor

View File

@ -25,8 +25,8 @@ def bleu_score(
n_gram: int = 4,
smooth: bool = False,
) -> Tensor:
"""Calculate `BLEU score <https://en.wikipedia.org/wiki/BLEU>`_ of machine
translated text with one or more references.
"""Calculate `BLEU score <https://en.wikipedia.org/wiki/BLEU>`_ of machine translated text with one or more
references.
Example:
>>> from torchmetrics.functional import bleu_score

View File

@ -23,8 +23,7 @@ def _cosine_similarity_update(
preds: Tensor,
target: Tensor,
) -> Tuple[Tensor, Tensor]:
"""Updates and returns variables required to compute Cosine Similarity.
Checks for same shape of input tensors.
"""Updates and returns variables required to compute Cosine Similarity. Checks for same shape of input tensors.
Args:
preds: Predicted tensor

View File

@ -20,8 +20,8 @@ from torchmetrics.utilities.checks import _check_same_shape
def _explained_variance_update(preds: Tensor, target: Tensor) -> Tuple[int, Tensor, Tensor, Tensor, Tensor]:
"""Updates and returns variables required to compute Explained Variance.
Checks for same shape of input tensors.
"""Updates and returns variables required to compute Explained Variance. Checks for same shape of input
tensors.
Args:
preds: Predicted tensor

View File

@ -20,8 +20,8 @@ from torchmetrics.utilities.checks import _check_same_shape
def _mean_absolute_error_update(preds: Tensor, target: Tensor) -> Tuple[Tensor, int]:
"""Updates and returns variables required to compute Mean Absolute Error.
Checks for same shape of input tensors.
"""Updates and returns variables required to compute Mean Absolute Error. Checks for same shape of input
tensors.
Args:
preds: Predicted tensor

View File

@ -24,8 +24,8 @@ def _mean_absolute_percentage_error_update(
target: Tensor,
epsilon: float = 1.17e-06,
) -> Tuple[Tensor, int]:
"""Updates and returns variables required to compute Mean Percentage Error.
Checks for same shape of input tensors.
"""Updates and returns variables required to compute Mean Percentage Error. Checks for same shape of input
tensors.
Args:
preds: Predicted tensor

View File

@ -20,8 +20,8 @@ from torchmetrics.utilities.checks import _check_same_shape
def _mean_squared_error_update(preds: Tensor, target: Tensor) -> Tuple[Tensor, int]:
"""Updates and returns variables required to compute Mean Squared Error.
Checks for same shape of input tensors.
"""Updates and returns variables required to compute Mean Squared Error. Checks for same shape of input
tensors.
Args:
preds: Predicted tensor

View File

@ -20,8 +20,7 @@ from torchmetrics.utilities.checks import _check_same_shape
def _mean_squared_log_error_update(preds: Tensor, target: Tensor) -> Tuple[Tensor, int]:
"""Returns variables required to compute Mean Squared Log Error. Checks for
same shape of tensors.
"""Returns variables required to compute Mean Squared Log Error. Checks for same shape of tensors.
Args:
preds: Predicted tensor

View File

@ -29,8 +29,8 @@ def _pearson_corrcoef_update(
corr_xy: Tensor,
n_prior: Tensor,
) -> Tuple[Tensor, Tensor, Tensor, Tensor, Tensor, Tensor]:
"""Updates and returns variables required to compute Pearson Correlation
Coefficient. Checks for same shape of input tensors.
"""Updates and returns variables required to compute Pearson Correlation Coefficient. Checks for same shape of
input tensors.
Args:
mean_x: current mean estimate of x tensor

View File

@ -21,8 +21,7 @@ from torchmetrics.utilities.checks import _check_same_shape
def _r2_score_update(preds: Tensor, target: Tensor) -> Tuple[Tensor, Tensor, Tensor, Tensor]:
"""Updates and returns variables required to compute R2 score. Checks for
same shape and 1D/2D input tensors.
"""Updates and returns variables required to compute R2 score. Checks for same shape and 1D/2D input tensors.
Args:
preds: Predicted tensor

View File

@ -20,8 +20,7 @@ from torchmetrics.utilities.checks import _check_same_shape
def _find_repeats(data: Tensor) -> Tensor:
"""find and return values which have repeats i.e. the same value are more
than once in the tensor."""
"""find and return values which have repeats i.e. the same value are more than once in the tensor."""
temp = data.detach().clone()
temp = temp.sort()[0]
@ -34,9 +33,9 @@ def _find_repeats(data: Tensor) -> Tensor:
def _rank_data(data: Tensor) -> Tensor:
"""Calculate the rank for each element of a tensor. The rank refers to the
indices of an element in the corresponding sorted tensor (starting from 1).
Duplicates of the same value will be assigned the mean of their rank.
"""Calculate the rank for each element of a tensor. The rank refers to the indices of an element in the
corresponding sorted tensor (starting from 1). Duplicates of the same value will be assigned the mean of their
rank.
Adopted from:
https://github.com/scipy/scipy/blob/v1.6.2/scipy/stats/stats.py#L4140-L4303
@ -54,8 +53,8 @@ def _rank_data(data: Tensor) -> Tensor:
def _spearman_corrcoef_update(preds: Tensor, target: Tensor) -> Tuple[Tensor, Tensor]:
"""Updates and returns variables required to compute Spearman Correlation
Coefficient. Checks for same shape and type of input tensors.
"""Updates and returns variables required to compute Spearman Correlation Coefficient. Checks for same shape
and type of input tensors.
Args:
preds: Predicted tensor

View File

@ -24,8 +24,8 @@ def _symmetric_mean_absolute_percentage_error_update(
target: Tensor,
epsilon: float = 1.17e-06,
) -> Tuple[Tensor, int]:
"""Updates and returns variables required to compute Symmetric Mean
Absolute Percentage Error. Checks for same shape of input tensors.
"""Updates and returns variables required to compute Symmetric Mean Absolute Percentage Error. Checks for same
shape of input tensors.
Args:
preds: Predicted tensor

View File

@ -18,8 +18,7 @@ from torchmetrics.utilities.checks import _check_retrieval_functional_inputs
def retrieval_average_precision(preds: Tensor, target: Tensor) -> Tensor:
"""Computes average precision (for information retrieval), as explained in
`IR Average precision`_.
"""Computes average precision (for information retrieval), as explained in `IR Average precision`_.
``preds`` and ``target`` should be of the same shape and live on the same device. If no ``target`` is ``True``,
``0`` is returned. ``target`` must be either `bool` or `integers` and ``preds`` must be `float`,
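
A hedged example of the input convention restated in this hunk (float ``preds`` and boolean ``target`` of the same shape for a single query):

import torch
from torchmetrics.functional import retrieval_average_precision

preds = torch.tensor([0.2, 0.7, 0.5])       # relevance scores for one query
target = torch.tensor([False, True, True])  # which documents are actually relevant

ap = retrieval_average_precision(preds, target)  # scalar tensor with the average precision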

View File

@ -18,9 +18,8 @@ from torchmetrics.utilities.checks import _check_retrieval_functional_inputs
def retrieval_fall_out(preds: Tensor, target: Tensor, k: int = None) -> Tensor:
"""Computes the Fall-out (for information retrieval), as explained in `IR
Fall-out`_ Fall-out is the fraction of non-relevant documents retrieved
among all the non-relevant documents.
"""Computes the Fall-out (for information retrieval), as explained in `IR Fall-out`_ Fall-out is the fraction
of non-relevant documents retrieved among all the non-relevant documents.
``preds`` and ``target`` should be of the same shape and live on the same device. If no ``target`` is ``True``,
``0`` is returned. ``target`` must be either `bool` or `integers` and ``preds`` must be `float`,

View File

@ -26,8 +26,7 @@ def _dcg(target: Tensor) -> Tensor:
def retrieval_normalized_dcg(preds: Tensor, target: Tensor, k: Optional[int] = None) -> Tensor:
"""Computes `Normalized Discounted Cumulative Gain`_ (for information
retrieval).
"""Computes `Normalized Discounted Cumulative Gain`_ (for information retrieval).
``preds`` and ``target`` should be of the same shape and live on the same device.
``target`` must be either `bool` or `integers` and ``preds`` must be `float`,

View File

@ -20,8 +20,8 @@ from torchmetrics.utilities.checks import _check_retrieval_functional_inputs
def retrieval_precision(preds: Tensor, target: Tensor, k: Optional[int] = None) -> Tensor:
"""Computes the precision metric (for information retrieval). Precision is
the fraction of relevant documents among all the retrieved documents.
"""Computes the precision metric (for information retrieval). Precision is the fraction of relevant documents
among all the retrieved documents.
``preds`` and ``target`` should be of the same shape and live on the same device. If no ``target`` is ``True``,
``0`` is returned. ``target`` must be either `bool` or `integers` and ``preds`` must be `float`,

View File

@ -20,8 +20,8 @@ from torchmetrics.utilities.checks import _check_retrieval_functional_inputs
def retrieval_recall(preds: Tensor, target: Tensor, k: Optional[int] = None) -> Tensor:
"""Computes the recall metric (for information retrieval). Recall is the
fraction of relevant documents retrieved among all the relevant documents.
"""Computes the recall metric (for information retrieval). Recall is the fraction of relevant documents
retrieved among all the relevant documents.
``preds`` and ``target`` should be of the same shape and live on the same device. If no ``target`` is ``True``,
``0`` is returned. ``target`` must be either `bool` or `integers` and ``preds`` must be `float`,

View File

@ -123,8 +123,8 @@ def bleu_score(
n_gram: int = 4,
smooth: bool = False,
) -> Tensor:
"""Calculate `BLEU score <https://en.wikipedia.org/wiki/BLEU>`_ of machine
translated text with one or more references.
"""Calculate `BLEU score <https://en.wikipedia.org/wiki/BLEU>`_ of machine translated text with one or more
references.
Args:
reference_corpus:

View File

@ -41,8 +41,7 @@ ALLOWED_ROUGE_KEYS = (
def add_newline_to_end_of_each_sentence(x: str) -> str:
"""This was added to get rougeLsum scores matching published rougeL scores
for BART and PEGASUS."""
"""This was added to get rougeLsum scores matching published rougeL scores for BART and PEGASUS."""
if _NLTK_AVAILABLE:
import nltk
@ -54,8 +53,7 @@ def add_newline_to_end_of_each_sentence(x: str) -> str:
def format_rouge_results(result: Dict[str, AggregateScore], decimal_places: int = 4) -> Dict[str, Tensor]:
"""Formats the computed (aggregated) rouge score to a dictionary of tensors
format."""
"""Formats the computed (aggregated) rouge score to a dictionary of tensors format."""
flattened_result = {}
for rouge_key, rouge_aggregate_score in result.items():
for stat in ["precision", "recall", "fmeasure"]:
@ -72,8 +70,7 @@ def _rouge_score_update(
aggregator: BootstrapAggregator,
newline_sep: bool = False,
) -> None:
"""Update the rouge score with the current set of predicted and target
sentences.
"""Update the rouge score with the current set of predicted and target sentences.
Args:
preds:
@ -104,8 +101,7 @@ def _rouge_score_update(
def _rouge_score_compute(aggregator: BootstrapAggregator, decimal_places: int = 4) -> Dict[str, Tensor]:
"""Compute the combined ROUGE metric for all the input set of predicted and
target sentences.
"""Compute the combined ROUGE metric for all the input set of predicted and target sentences.
Args:
aggregator:
@ -125,8 +121,7 @@ def rouge_score(
rouge_keys: Union[str, Tuple[str, ...]] = ("rouge1", "rouge2", "rougeL", "rougeLsum"), # type: ignore
decimal_places: int = 4,
) -> Dict[str, Tensor]:
"""Calculate `ROUGE score <https://en.wikipedia.org/wiki/ROUGE_(metric)>`_,
used for automatic summarization.
"""Calculate `ROUGE score <https://en.wikipedia.org/wiki/ROUGE_(metric)>`_, used for automatic summarization.
Args:
preds:

View File

@ -25,9 +25,8 @@ def wer(
predictions: Union[str, List[str]],
concatenate_texts: bool = False,
) -> float:
"""Word error rate (WER_) is a common metric of the performance of an
automatic speech recognition system. This value indicates the percentage of
words that were incorrectly predicted. The lower the value, the better the
"""Word error rate (WER_) is a common metric of the performance of an automatic speech recognition system. This
value indicates the percentage of words that were incorrectly predicted. The lower the value, the better the
performance of the ASR system with a WER of 0 being a perfect score.
Args:

View File

@ -47,8 +47,7 @@ class NoTrainInceptionV3(FeatureExtractorInceptionV3):
self.eval()
def train(self, mode: bool) -> "NoTrainInceptionV3":
"""the inception network should not be able to be switched away from
evaluation mode."""
"""the inception network should not be able to be switched away from evaluation mode."""
return super().train(False)
def forward(self, x: Tensor) -> Tensor:
@ -264,8 +263,7 @@ class FID(Metric):
self.fake_features.append(features)
def compute(self) -> Tensor:
"""Calculate FID score based on accumulated extracted features from the
two distributions."""
"""Calculate FID score based on accumulated extracted features from the two distributions."""
real_features = dim_zero_cat(self.real_features)
fake_features = dim_zero_cat(self.fake_features)
# computation is extremely sensitive so it needs to happen in double precision

View File

@ -25,8 +25,7 @@ from torchmetrics.utilities.imports import _TORCH_FIDELITY_AVAILABLE
def maximum_mean_discrepancy(k_xx: Tensor, k_xy: Tensor, k_yy: Tensor) -> Tensor:
"""Adapted from https://github.com/toshas/torch-
fidelity/blob/v0.3.0/torch_fidelity/metric_kid.py."""
"""Adapted from https://github.com/toshas/torch- fidelity/blob/v0.3.0/torch_fidelity/metric_kid.py."""
m = k_xx.shape[0]
diag_x = torch.diag(k_xx)
@ -46,8 +45,7 @@ def maximum_mean_discrepancy(k_xx: Tensor, k_xy: Tensor, k_yy: Tensor) -> Tensor
def poly_kernel(f1: Tensor, f2: Tensor, degree: int = 3, gamma: Optional[float] = None, coef: float = 1.0) -> Tensor:
"""Adapted from https://github.com/toshas/torch-
fidelity/blob/v0.3.0/torch_fidelity/metric_kid.py."""
"""Adapted from https://github.com/toshas/torch- fidelity/blob/v0.3.0/torch_fidelity/metric_kid.py."""
if gamma is None:
gamma = 1.0 / f1.shape[1]
kernel = (f1 @ f2.T * gamma + coef) ** degree
@ -57,8 +55,7 @@ def poly_kernel(f1: Tensor, f2: Tensor, degree: int = 3, gamma: Optional[float]
def poly_mmd(
f_real: Tensor, f_fake: Tensor, degree: int = 3, gamma: Optional[float] = None, coef: float = 1.0
) -> Tensor:
"""Adapted from https://github.com/toshas/torch-
fidelity/blob/v0.3.0/torch_fidelity/metric_kid.py."""
"""Adapted from https://github.com/toshas/torch- fidelity/blob/v0.3.0/torch_fidelity/metric_kid.py."""
k_11 = poly_kernel(f_real, f_real, degree, gamma, coef)
k_22 = poly_kernel(f_fake, f_fake, degree, gamma, coef)
k_12 = poly_kernel(f_real, f_fake, degree, gamma, coef)
@ -252,9 +249,8 @@ class KID(Metric):
self.fake_features.append(features)
def compute(self) -> Tuple[Tensor, Tensor]:
"""Calculate KID score based on accumulated extracted features from the
two distributions. Returns a tuple of mean and standard deviation of
KID scores calculated on subsets of extracted features.
"""Calculate KID score based on accumulated extracted features from the two distributions. Returns a tuple
of mean and standard deviation of KID scores calculated on subsets of extracted features.
Implementation inspired by https://github.com/toshas/torch-fidelity/blob/v0.3.0/torch_fidelity/metric_kid.py
"""

View File

@ -176,8 +176,7 @@ class Metric(nn.Module, ABC):
def forward(self, *args: Any, **kwargs: Any) -> Any:
"""Automatically calls ``update()``.
Returns the metric value over inputs if ``compute_on_step`` is
True.
Returns the metric value over inputs if ``compute_on_step`` is True.
"""
# add current step
if self._is_synced:
@ -256,8 +255,7 @@ class Metric(nn.Module, ABC):
should_sync: bool = True,
distributed_available: Optional[Callable] = jit_distributed_available,
) -> None:
"""Sync function for manually controlling when metrics states should be
synced across processes.
"""Sync function for manually controlling when metrics states should be synced across processes.
Args:
dist_sync_fn: Function to be used to perform states synchronization
@ -287,8 +285,8 @@ class Metric(nn.Module, ABC):
self._is_synced = True
def unsync(self, should_unsync: bool = True) -> None:
"""Unsync function for manually controlling when metrics states should
be reverted back to their local states.
"""Unsync function for manually controlling when metrics states should be reverted back to their local
states.
Args:
should_unsync: Whether to perform unsync
@ -317,9 +315,8 @@ class Metric(nn.Module, ABC):
should_unsync: bool = True,
distributed_available: Optional[Callable] = jit_distributed_available,
) -> Generator:
"""Context manager to synchronize the states between processes when
running in a distributed setting and restore the local cache states
after yielding.
"""Context manager to synchronize the states between processes when running in a distributed setting and
restore the local cache states after yielding.
Args:
dist_sync_fn: Function to be used to perform states synchronization
@ -372,17 +369,15 @@ class Metric(nn.Module, ABC):
@abstractmethod
def update(self, *_: Any, **__: Any) -> None:
"""Override this method to update the state variables of your metric
class."""
"""Override this method to update the state variables of your metric class."""
@abstractmethod
def compute(self) -> Any:
"""Override this method to compute the final metric value from state
variables synchronized across the distributed backend."""
"""Override this method to compute the final metric value from state variables synchronized across the
distributed backend."""
def reset(self) -> None:
"""This method automatically resets the metric state variables to their
default value."""
"""This method automatically resets the metric state variables to their default value."""
self._update_called = False
self._forward_cache = None
# lower lightning versions requires this implicitly to log metric objects correctly in self.log
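To make the `update` / `compute` / `reset` contract above concrete, a minimal custom metric; this is an illustrative sketch, not part of this diff.

```python
import torch
from torchmetrics import Metric


class RunningSum(Metric):
    def __init__(self):
        super().__init__()
        # States registered via add_state are what reset() restores and what
        # gets reduced/synchronized across processes.
        self.add_state("total", default=torch.tensor(0.0), dist_reduce_fx="sum")

    def update(self, x: torch.Tensor) -> None:
        self.total += x.sum()

    def compute(self) -> torch.Tensor:
        return self.total


metric = RunningSum()
metric(torch.tensor([1.0, 2.0]))   # forward() calls update() on the current batch
metric(torch.tensor([3.0]))
print(metric.compute())            # tensor(6.)
metric.reset()                     # back to the default state
```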
@@ -416,8 +411,8 @@ class Metric(nn.Module, ABC):
self.compute: Callable = self._wrap_compute(self.compute) # type: ignore
def _apply(self, fn: Callable) -> Module:
"""Overwrite _apply function such that we can also move metric states
to the correct device when `.to`, `.cuda`, etc methods are called."""
"""Overwrite _apply function such that we can also move metric states to the correct device when `.to`,
`.cuda`, etc methods are called."""
this = super()._apply(fn)
# Also apply fn to metric states and defaults
for key, value in this._defaults.items():
@@ -445,8 +440,7 @@ class Metric(nn.Module, ABC):
return this
def persistent(self, mode: bool = False) -> None:
"""Method for post-init to change if metric states should be saved to
its state_dict."""
"""Method for post-init to change if metric states should be saved to its state_dict."""
for key in self._persistent:
self._persistent[key] = mode
@@ -491,8 +485,7 @@ class Metric(nn.Module, ABC):
)
def _filter_kwargs(self, **kwargs: Any) -> Dict[str, Any]:
"""filter kwargs such that they match the update signature of the
metric."""
"""filter kwargs such that they match the update signature of the metric."""
# filter all parameters based on update signature except those of
# type VAR_POSITIONAL (*args) and VAR_KEYWORD (**kwargs)
@@ -638,8 +631,7 @@ def _neg(x: Tensor) -> Tensor:
class CompositionalMetric(Metric):
"""Composition of two metrics with a specific operator which will be
executed upon metrics compute."""
"""Composition of two metrics with a specific operator which will be executed upon metrics compute."""
def __init__(
self,

View file

@@ -91,14 +91,11 @@ class RetrievalFallOut(RetrievalMetric):
self.k = k
def compute(self) -> Tensor:
"""First concat state `indexes`, `preds` and `target` since they were
stored as lists.
"""First concat state `indexes`, `preds` and `target` since they were stored as lists.
After that, compute list of groups that will help in keeping
together predictions about the same query. Finally, for each
group compute the `_metric` if the number of negative targets is
at least 1, otherwise behave as specified by
`self.empty_target_action`.
After that, compute list of groups that will help in keeping together predictions about the same query. Finally,
for each group compute the `_metric` if the number of negative targets is at least 1, otherwise behave as
specified by `self.empty_target_action`.
"""
indexes = torch.cat(self.indexes, dim=0)
preds = torch.cat(self.preds, dim=0)

View file

@@ -25,8 +25,7 @@ from torchmetrics.utilities.data import get_group_indexes
class RetrievalMetric(Metric, ABC):
"""Works with binary target data. Accepts float predictions from a model
output.
"""Works with binary target data. Accepts float predictions from a model output.
Forward accepts
@@ -96,8 +95,7 @@ class RetrievalMetric(Metric, ABC):
self.add_state("target", default=[], dist_reduce_fx=None)
def update(self, preds: Tensor, target: Tensor, indexes: Tensor) -> None: # type: ignore
"""Check shape, check and convert dtypes, flatten and add to
accumulators."""
"""Check shape, check and convert dtypes, flatten and add to accumulators."""
if indexes is None:
raise ValueError("Argument `indexes` cannot be None")
@@ -110,14 +108,11 @@ class RetrievalMetric(Metric, ABC):
self.target.append(target)
def compute(self) -> Tensor:
"""First concat state ``indexes``, ``preds`` and ``target`` since they
were stored as lists.
"""First concat state ``indexes``, ``preds`` and ``target`` since they were stored as lists.
After that, compute list of groups that will help in keeping
together predictions about the same query. Finally, for each
group compute the ``_metric`` if the number of positive targets
is at least 1, otherwise behave as specified by
``self.empty_target_action``.
After that, compute list of groups that will help in keeping together predictions about the same query. Finally,
for each group compute the ``_metric`` if the number of positive targets is at least 1, otherwise behave as
specified by ``self.empty_target_action``.
"""
indexes = torch.cat(self.indexes, dim=0)
preds = torch.cat(self.preds, dim=0)
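A hypothetical, simplified version of the per-query grouping loop that this compute() docstring describes, with top-1 precision standing in for `self._metric` and a constant 0.0 standing in for the `empty_target_action` handling.

```python
import torch
from torchmetrics.utilities.data import get_group_indexes

indexes = torch.tensor([0, 0, 0, 1, 1, 1])            # query id per prediction
preds = torch.tensor([0.9, 0.2, 0.4, 0.1, 0.8, 0.3])
target = torch.tensor([1, 0, 0, 0, 1, 1])

scores = []
for group in get_group_indexes(indexes):
    mini_preds, mini_target = preds[group], target[group]
    if not mini_target.sum():
        scores.append(torch.tensor(0.0))               # e.g. empty_target_action="neg"
    else:
        best = mini_preds.argmax()                     # stand-in for self._metric
        scores.append(mini_target[best].float())
print(torch.stack(scores).mean())                      # tensor(1.)
```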

View file

@@ -26,8 +26,8 @@ from torchmetrics.functional.text.bleu import _bleu_score_compute, _bleu_score_u
class BLEUScore(Metric):
"""Calculate `BLEU score <https://en.wikipedia.org/wiki/BLEU>`_ of machine
translated text with one or more references.
"""Calculate `BLEU score <https://en.wikipedia.org/wiki/BLEU>`_ of machine translated text with one or more
references.
Args:
n_gram:
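For intuition, a toy stand-alone sketch of the clipped n-gram precision that BLEU is built from: single reference, no brevity penalty, and no geometric mean over n-gram orders. This is not the class implementation.

```python
from collections import Counter


def ngram_precision(candidate, reference, n=2):
    # Clipped n-gram precision against a single tokenized reference.
    cand = [tuple(candidate[i:i + n]) for i in range(len(candidate) - n + 1)]
    ref = Counter(tuple(reference[i:i + n]) for i in range(len(reference) - n + 1))
    hits = sum(min(count, ref[gram]) for gram, count in Counter(cand).items())
    return hits / max(len(cand), 1)


print(ngram_precision("the cat sat on the mat".split(), "the cat is on the mat".split()))  # 0.6
```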

View file

@@ -27,8 +27,7 @@ else:
class ROUGEScore(Metric):
"""Calculate `ROUGE score <https://en.wikipedia.org/wiki/ROUGE_(metric)>`_,
used for automatic summarization.
"""Calculate `ROUGE score <https://en.wikipedia.org/wiki/ROUGE_(metric)>`_, used for automatic summarization.
Args:
newline_sep:
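Similarly, a toy sketch of the unigram-recall idea behind ROUGE-1; the actual class relies on an external ROUGE implementation and covers more variants.

```python
from collections import Counter


def rouge1_recall(candidate, reference):
    # Overlapping unigrams divided by the number of reference tokens.
    overlap = Counter(candidate) & Counter(reference)
    return sum(overlap.values()) / max(len(reference), 1)


print(rouge1_recall("the cat sat".split(), "the cat sat on the mat".split()))  # 0.5
```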

View file

@@ -21,15 +21,13 @@ from torchmetrics.utilities.enums import DataType
def _check_same_shape(preds: Tensor, target: Tensor) -> None:
"""Check that predictions and target have the same shape, else raise
error."""
"""Check that predictions and target have the same shape, else raise error."""
if preds.shape != target.shape:
raise RuntimeError("Predictions and targets are expected to have the same shape")
def _basic_input_validation(preds: Tensor, target: Tensor, threshold: float, multiclass: Optional[bool]) -> None:
"""Perform basic validation of inputs that does not require deducing any
information of the type of inputs."""
"""Perform basic validation of inputs that does not require deducing any information of the type of inputs."""
if target.is_floating_point():
raise ValueError("The `target` has to be an integer tensor.")
@@ -51,14 +49,12 @@ def _basic_input_validation(preds: Tensor, target: Tensor, threshold: float, mul
def _check_shape_and_type_consistency(preds: Tensor, target: Tensor) -> Tuple[DataType, int]:
"""This checks that the shape and type of inputs are consistent with each
other and fall into one of the allowed input types (see the documentation
of docstring of ``_input_format_classification``). It does not check for
consistency of number of classes, other functions take care of that.
"""This checks that the shape and type of inputs are consistent with each other and fall into one of the
allowed input types (see the documentation of docstring of ``_input_format_classification``). It does not check
for consistency of number of classes, other functions take care of that.
It returns the name of the case in which the inputs fall, and the
implied number of classes (from the ``C`` dim for multi-class data,
or extra dim(s) for multi-label data).
It returns the name of the case in which the inputs fall, and the implied number of classes (from the ``C`` dim for
multi-class data, or extra dim(s) for multi-label data).
"""
preds_float = preds.is_floating_point()
@@ -111,8 +107,7 @@ def _check_shape_and_type_consistency(preds: Tensor, target: Tensor) -> Tuple[Da
def _check_num_classes_binary(num_classes: int, multiclass: Optional[bool]) -> None:
"""This checks that the consistency of `num_classes` with the data and
`multiclass` param for binary data."""
"""This checks that the consistency of `num_classes` with the data and `multiclass` param for binary data."""
if num_classes > 2:
raise ValueError("Your data is binary, but `num_classes` is larger than 2.")
@@ -136,8 +131,8 @@ def _check_num_classes_mc(
multiclass: Optional[bool],
implied_classes: int,
) -> None:
"""This checks that the consistency of `num_classes` with the data and
`multiclass` param for (multi-dimensional) multi-class data."""
"""This checks that the consistency of `num_classes` with the data and `multiclass` param for (multi-
dimensional) multi-class data."""
if num_classes == 1 and multiclass is not False:
raise ValueError(
@@ -161,8 +156,8 @@ def _check_num_classes_mc(
def _check_num_classes_ml(num_classes: int, multiclass: Optional[bool], implied_classes: int) -> None:
"""This checks that the consistency of `num_classes` with the data and
`multiclass` param for multi-label data."""
"""This checks that the consistency of `num_classes` with the data and `multiclass` param for multi-label
data."""
if multiclass and num_classes != 2:
raise ValueError(
@@ -494,8 +489,7 @@ def _check_retrieval_functional_inputs(
target: Tensor,
allow_non_binary_target: bool = False,
) -> Tuple[Tensor, Tensor]:
"""Check ``preds`` and ``target`` tensors are of the same shape and of the
correct dtype.
"""Check ``preds`` and ``target`` tensors are of the same shape and of the correct dtype.
Args:
preds: either tensor with scores/logits
@@ -535,8 +529,7 @@ def _check_retrieval_inputs(
target: Tensor,
allow_non_binary_target: bool = False,
) -> Tuple[Tensor, Tensor, Tensor]:
"""Check ``indexes``, ``preds`` and ``target`` tensors are of the same
shape and of the correct dtype.
"""Check ``indexes``, ``preds`` and ``target`` tensors are of the same shape and of the correct dtype.
Args:
indexes: tensor with queries indexes

View file

@@ -76,8 +76,7 @@ def to_onehot(
def select_topk(prob_tensor: Tensor, topk: int = 1, dim: int = 1) -> Tensor:
"""Convert a probability tensor to binary by selecting top-k highest
entries.
"""Convert a probability tensor to binary by selecting top-k highest entries.
Args:
prob_tensor: dense tensor of shape ``[..., C, ...]``, where ``C`` is in the
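A small usage example of the top-k binarization described above, assuming `select_topk` is importable from `torchmetrics.utilities.data` like the other helpers in this file; the commented output is the expected binarized tensor.

```python
import torch
from torchmetrics.utilities.data import select_topk

probs = torch.tensor([[0.1, 0.6, 0.3],
                      [0.5, 0.2, 0.3]])
print(select_topk(probs, topk=2, dim=1))
# -> [[0, 1, 1],
#     [1, 0, 1]]
```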
@@ -122,8 +121,7 @@ def get_num_classes(
target: Tensor,
num_classes: Optional[int] = None,
) -> int:
"""Calculates the number of classes for a given prediction and target
tensor.
"""Calculates the number of classes for a given prediction and target tensor.
Args:
preds: predicted values
@@ -200,8 +198,8 @@ def apply_to_collection(
def get_group_indexes(indexes: Tensor) -> List[Tensor]:
"""Given an integer `torch.Tensor` `indexes`, return a `torch.Tensor` of
indexes for each different value in `indexes`.
"""Given an integer `torch.Tensor` `indexes`, return a `torch.Tensor` of indexes for each different value in
`indexes`.
Args:
indexes: a `torch.Tensor`
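A quick illustration of the grouping behaviour described above (expected output shown as a comment).

```python
import torch
from torchmetrics.utilities.data import get_group_indexes

indexes = torch.tensor([0, 0, 1, 1, 1, 2])
print(get_group_indexes(indexes))
# -> [tensor([0, 1]), tensor([2, 3, 4]), tensor([5])]
```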

View file

@@ -94,11 +94,9 @@ def _simple_gather_all_tensors(result: Tensor, group: Any, world_size: int) -> L
def gather_all_tensors(result: Tensor, group: Optional[Any] = None) -> List[Tensor]:
"""Function to gather all tensors from several ddp processes onto a list
that is broadcasted to all processes. Works on tensors that have the same
number of dimensions, but where each dimension may differ. In this case
tensors are padded, gathered and then trimmed to secure equal workload for
all processes.
"""Function to gather all tensors from several ddp processes onto a list that is broadcasted to all processes.
Works on tensors that have the same number of dimensions, but where each dimension may differ. In this case
tensors are padded, gathered and then trimmed to secure equal workload for all processes.
Args:
result: the value to sync
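A single-process sketch of the pad, gather, then trim idea described in this docstring; `torch.stack` stands in for the actual all-gather across ranks and all names are illustrative.

```python
import torch
import torch.nn.functional as F

# Pretend these are the per-process results, differing only in dim 0.
local_results = [torch.randn(3, 5), torch.randn(2, 5), torch.randn(4, 5)]
sizes = [t.shape[0] for t in local_results]
max_rows = max(sizes)

padded = [F.pad(t, (0, 0, 0, max_rows - t.shape[0])) for t in local_results]  # pad rows
gathered = torch.stack(padded)                       # stands in for all_gather across ranks
trimmed = [g[:n] for g, n in zip(gathered, sizes)]   # drop the padding again
```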

View file

@@ -16,8 +16,7 @@ from typing import Optional, Union
class EnumStr(str, Enum):
"""Type of any enumerator with allowed comparison to string invariant to
cases.
"""Type of any enumerator with allowed comparison to string invariant to cases.
Example:
>>> class MyEnum(EnumStr):

View file

@@ -157,9 +157,8 @@ class BootStrapper(Metric):
def compute(self) -> Dict[str, Tensor]:
"""Computes the bootstrapped metric values.
Allways returns a dict of tensors, which can contain the
following keys: ``mean``, ``std``, ``quantile`` and ``raw``
depending on how the class was initialized
Allways returns a dict of tensors, which can contain the following keys: ``mean``, ``std``, ``quantile`` and
``raw`` depending on how the class was initialized
"""
computed_vals = torch.stack([m.compute() for m in self.metrics], dim=0)
output_dict = {}
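A hedged usage sketch of the wrapper whose compute() is shown above; the constructor arguments and import path are assumptions based on the surrounding API, and the keys of the returned dict depend on how the wrapper was initialized.

```python
import torch
from torchmetrics import MeanSquaredError
from torchmetrics.wrappers import BootStrapper

bootstrap = BootStrapper(MeanSquaredError(), num_bootstraps=20, mean=True, std=True)
bootstrap.update(torch.randn(100), torch.randn(100))
print(bootstrap.compute())   # e.g. {'mean': tensor(...), 'std': tensor(...)}
```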