docs: line 120

Parent: 219af0e42c
Commit: 21fe0ca7e1
@@ -47,7 +47,7 @@ repos:
     rev: v1.4
     hooks:
       - id: docformatter
-        args: [--in-place]
+        args: [--in-place, --wrap-summaries=115, --wrap-descriptions=120]

   - repo: https://github.com/PyCQA/isort
     rev: 5.9.2
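The effect of the new arguments, shown on a hypothetical docstring (illustrative only, not a file from this commit): docformatter still rewrites files in place, but summary lines may now run to 115 characters and description lines to 120 before wrapping, which is what collapses the two-line summaries in the hunks below into single lines.

    # Before: the summary is broken at docformatter's default width.
    def scale(x: float, factor: float) -> float:
        """Multiplies the input by the given factor and returns the result
        as a new floating point value."""
        return x * factor

    # After --wrap-summaries=115: the same summary fits on one line.
    def scale(x: float, factor: float) -> float:
        """Multiplies the input by the given factor and returns the result as a new floating point value."""
        return x * factor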
@@ -84,8 +84,7 @@ def test_metric_lightning(tmpdir):
 
 @pytest.mark.skipif(not _LIGHTNING_GREATER_EQUAL_1_3, reason="test requires lightning v1.3 or higher")
 def test_metrics_reset(tmpdir):
-    """Tests that metrics are reset correctly after the end of the
-    train/val/test epoch.
+    """Tests that metrics are reset correctly after the end of the train/val/test epoch.
 
     Taken from:
     https://github.com/PyTorchLightning/pytorch-lightning/pull/7055

@@ -51,8 +51,7 @@ def naive_implementation_pit_scipy(
     metric_func: Callable,
     eval_func: str,
 ) -> Tuple[Tensor, Tensor]:
-    """A naive implementation of `Permutation Invariant Training` based on
-    Scipy.
+    """A naive implementation of `Permutation Invariant Training` based on Scipy.
 
     Args:
         preds: predictions, shape[batch, spk, time]
@@ -105,8 +105,8 @@ def test_metric_collection_wrong_input(tmpdir):
 
 
 def test_metric_collection_args_kwargs(tmpdir):
-    """Check that args and kwargs gets passed correctly in metric collection,
-    Checks both update and forward method."""
+    """Check that args and kwargs gets passed correctly in metric collection, Checks both update and forward
+    method."""
     m1 = DummyMetricSum()
     m2 = DummyMetricDiff()
 

@@ -225,6 +225,6 @@ def _test_state_dict_is_synced(rank, worldsize, tmpdir):
 
 @pytest.mark.skipif(sys.platform == "win32", reason="DDP not available on windows")
 def test_state_dict_is_synced(tmpdir):
-    """This test asserts that metrics are synced while creating the state dict
-    but restored after to continue accumulation."""
+    """This test asserts that metrics are synced while creating the state dict but restored after to continue
+    accumulation."""
     torch.multiprocessing.spawn(_test_state_dict_is_synced, args=(2, tmpdir), nprocs=2)

@@ -270,8 +270,7 @@ def test_device_and_dtype_transfer(tmpdir):
 
 
 def test_warning_on_compute_before_update():
-    """test that an warning is raised if user tries to call compute before
-    update."""
+    """test that an warning is raised if user tries to call compute before update."""
     metric = DummyMetricSum()
 
     # make sure everything is fine with forward
@@ -327,9 +327,8 @@ def test_average_accuracy_bin(preds, target, num_classes, exp_result, average, m
     "ignore_index, expected", [(None, torch.tensor([1.0, np.nan])), (0, torch.tensor([np.nan, np.nan]))]
 )
 def test_class_not_present(metric_class, metric_fn, ignore_index, expected):
-    """This tests that when metric is computed per class and a given class is
-    not present in both the `preds` and `target`, the resulting score is
-    `nan`."""
+    """This tests that when metric is computed per class and a given class is not present in both the `preds` and
+    `target`, the resulting score is `nan`."""
     preds = torch.tensor([0, 0, 0])
     target = torch.tensor([0, 0, 0])
     num_classes = 2

@@ -168,8 +168,8 @@ class TestAUROC(MetricTester):
 
 
 def test_error_on_different_mode():
-    """test that an error is raised if the user pass in data of different modes
-    (binary, multi-label, multi-class)"""
+    """test that an error is raised if the user pass in data of different modes (binary, multi-label, multi-
+    class)"""
     metric = AUROC()
     # pass in multi-class data
     metric.update(torch.randn(10, 5).softmax(dim=-1), torch.randint(0, 5, (10,)))

@@ -186,8 +186,7 @@ def test_error_multiclass_no_num_classes():
 
 
 def test_weighted_with_empty_classes():
-    """Tests that weighted multiclass AUROC calculation yields the same results
-    if a new but empty class exists.
+    """Tests that weighted multiclass AUROC calculation yields the same results if a new but empty class exists.
 
     Tests that the proper warnings and errors are raised
     """
@@ -133,8 +133,7 @@ def test_wrong_params(metric_class, metric_fn, average, mdmc_average, num_classe
     ],
 )
 def test_zero_division(metric_class, metric_fn):
-    """Test that zero_division works correctly (currently should just set to
-    0)."""
+    """Test that zero_division works correctly (currently should just set to 0)."""
 
     preds = torch.tensor([1, 2, 1, 1])
     target = torch.tensor([2, 0, 2, 1])

@@ -184,9 +183,8 @@ def test_no_support(metric_class, metric_fn):
     "ignore_index, expected", [(None, torch.tensor([1.0, np.nan])), (0, torch.tensor([np.nan, np.nan]))]
 )
 def test_class_not_present(metric_class, metric_fn, ignore_index, expected):
-    """This tests that when metric is computed per class and a given class is
-    not present in both the `preds` and `target`, the resulting score is
-    `nan`."""
+    """This tests that when metric is computed per class and a given class is not present in both the `preds` and
+    `target`, the resulting score is `nan`."""
     preds = torch.tensor([0, 0, 0])
     target = torch.tensor([0, 0, 0])
     num_classes = 2

@@ -414,8 +412,7 @@ def test_top_k(
 ):
     """A simple test to check that top_k works as expected.
 
-    Just a sanity check, the tests in StatScores should already
-    guarantee the corectness of results.
+    Just a sanity check, the tests in StatScores should already guarantee the corectness of results.
     """
     class_metric = metric_class(top_k=k, average=average, num_classes=3)
     class_metric.update(preds, target)
@@ -133,8 +133,7 @@ def test_wrong_params(metric, fn_metric, average, mdmc_average, num_classes, ign
 
 @pytest.mark.parametrize("metric_class, metric_fn", [(Recall, recall), (Precision, precision)])
 def test_zero_division(metric_class, metric_fn):
-    """Test that zero_division works correctly (currently should just set to
-    0)."""
+    """Test that zero_division works correctly (currently should just set to 0)."""
 
     preds = tensor([0, 2, 1, 1])
     target = tensor([2, 1, 2, 1])

@@ -357,8 +356,8 @@ class TestPrecisionRecall(MetricTester):
 def test_precision_recall_joint(average):
     """A simple test of the joint precision_recall metric.
 
-    No need to test this thorougly, as it is just a combination of
-    precision and recall, which are already tested thoroughly.
+    No need to test this thorougly, as it is just a combination of precision and recall, which are already tested
+    thoroughly.
     """
 
     precision_result = precision(

@@ -404,8 +403,7 @@ def test_top_k(
 ):
     """A simple test to check that top_k works as expected.
 
-    Just a sanity check, the tests in StatScores should already
-    guarantee the correctness of results.
+    Just a sanity check, the tests in StatScores should already guarantee the correctness of results.
     """
 
     class_metric = metric_class(top_k=k, average=average, num_classes=3)

@@ -425,9 +423,8 @@ def test_top_k(
     "ignore_index, expected", [(None, torch.tensor([1.0, np.nan])), (0, torch.tensor([np.nan, np.nan]))]
 )
 def test_class_not_present(metric_class, metric_fn, ignore_index, expected):
-    """This tests that when metric is computed per class and a given class is
-    not present in both the `preds` and `target`, the resulting score is
-    `nan`."""
+    """This tests that when metric is computed per class and a given class is not present in both the `preds` and
+    `target`, the resulting score is `nan`."""
     preds = torch.tensor([0, 0, 0])
     target = torch.tensor([0, 0, 0])
     num_classes = 2
@@ -160,8 +160,7 @@ def test_wrong_params(metric, fn_metric, average, mdmc_average, num_classes, ign
 
 @pytest.mark.parametrize("metric_class, metric_fn", [(Specificity, specificity)])
 def test_zero_division(metric_class, metric_fn):
-    """Test that zero_division works correctly (currently should just set to
-    0)."""
+    """Test that zero_division works correctly (currently should just set to 0)."""
 
     preds = tensor([1, 2, 1, 1])
     target = tensor([0, 0, 0, 0])

@@ -383,8 +382,7 @@ def test_top_k(
 ):
     """A simple test to check that top_k works as expected.
 
-    Just a sanity check, the tests in StatScores should already
-    guarantee the correctness of results.
+    Just a sanity check, the tests in StatScores should already guarantee the correctness of results.
     """
 
     class_metric = metric_class(top_k=k, average=average, num_classes=3)

@@ -399,9 +397,8 @@ def test_top_k(
     "ignore_index, expected", [(None, torch.tensor([0.0, np.nan])), (0, torch.tensor([np.nan, np.nan]))]
 )
 def test_class_not_present(metric_class, metric_fn, ignore_index, expected):
-    """This tests that when metric is computed per class and a given class is
-    not present in both the `preds` and `target`, the resulting score is
-    `nan`."""
+    """This tests that when metric is computed per class and a given class is not present in both the `preds` and
+    `target`, the resulting score is `nan`."""
     preds = torch.tensor([0, 0, 0])
     target = torch.tensor([0, 0, 0])
     num_classes = 2
@@ -113,14 +113,11 @@ def _sk_stat_scores_mdim_mcls(
     ],
 )
 def test_wrong_params(reduce, mdmc_reduce, num_classes, inputs, ignore_index):
-    """Test a combination of parameters that are invalid and should raise an
-    error.
+    """Test a combination of parameters that are invalid and should raise an error.
 
-    This includes invalid ``reduce`` and ``mdmc_reduce`` parameter
-    values, not setting ``num_classes`` when ``reduce='macro'`, not
-    setting ``mdmc_reduce`` when inputs are multi-dim multi-class``,
-    setting ``ignore_index`` when inputs are binary, as well as setting
-    ``ignore_index`` to a value higher than the number of classes.
+    This includes invalid ``reduce`` and ``mdmc_reduce`` parameter values, not setting ``num_classes`` when
+    ``reduce='macro'`, not setting ``mdmc_reduce`` when inputs are multi-dim multi-class``, setting ``ignore_index``
+    when inputs are binary, as well as setting ``ignore_index`` to a value higher than the number of classes.
     """
     with pytest.raises(ValueError):
         stat_scores(
@@ -27,8 +27,7 @@ def test_invalid_input_img_type():
 
 
 def test_invalid_input_ndims():
-    """Test whether the module successfully handles invalid number of
-    dimensions of input tensor."""
+    """Test whether the module successfully handles invalid number of dimensions of input tensor."""
 
     BATCH_SIZE = 1
     HEIGHT = 5

@@ -43,9 +42,8 @@ def test_invalid_input_ndims():
 
 
 def test_multi_batch_image_gradients():
-    """Test whether the module correctly calculates gradients for known input
-    with non-unity batch size.Example input-output pair taken from TF's
-    implementation of i mage-gradients."""
+    """Test whether the module correctly calculates gradients for known input with non-unity batch size.Example
+    input-output pair taken from TF's implementation of i mage-gradients."""
 
     BATCH_SIZE = 5
     HEIGHT = 5

@@ -76,8 +74,7 @@ def test_multi_batch_image_gradients():
 def test_image_gradients():
     """Test whether the module correctly calculates gradients for known input.
 
-    Example input-output pair taken from TF's implementation of image-
-    gradients
+    Example input-output pair taken from TF's implementation of image- gradients
     """
 
     BATCH_SIZE = 1
@@ -72,13 +72,11 @@ def calibration_error(
     pos_label: Optional[Union[int, str]] = None,
     reduce_bias: bool = True,
 ) -> float:
-    """Compute calibration error of a binary classifier. Across all items in a
-    set of N predictions, the calibration error measures the aggregated
-    difference between (1) the average predicted probabilities assigned to the
-    positive class, and (2) the frequencies of the positive class in the actual
-    outcome. The calibration error is only appropriate for binary categorical
-    outcomes. Which label is considered to be the positive label is controlled
-    via the parameter pos_label, which defaults to 1.
+    """Compute calibration error of a binary classifier. Across all items in a set of N predictions, the
+    calibration error measures the aggregated difference between (1) the average predicted probabilities assigned
+    to the positive class, and (2) the frequencies of the positive class in the actual outcome. The calibration
+    error is only appropriate for binary categorical outcomes. Which label is considered to be the positive label
+    is controlled via the parameter pos_label, which defaults to 1.
 
     Args:
         y_true: array-like of shape (n_samples,)
@@ -58,8 +58,7 @@ def setup_ddp(rank, world_size):
 
 
 def _assert_allclose(pl_result: Any, sk_result: Any, atol: float = 1e-8):
-    """Utility function for recursively asserting that two results are within a
-    certain tolerance."""
+    """Utility function for recursively asserting that two results are within a certain tolerance."""
     # single output compare
     if isinstance(pl_result, Tensor):
         assert np.allclose(pl_result.cpu().numpy(), sk_result, atol=atol, equal_nan=True)

@@ -72,8 +71,7 @@ def _assert_allclose(pl_result: Any, sk_result: Any, atol: float = 1e-8):
 
 
 def _assert_tensor(pl_result: Any):
-    """Utility function for recursively checking that some input only consists
-    of torch tensors."""
+    """Utility function for recursively checking that some input only consists of torch tensors."""
     if isinstance(pl_result, Sequence):
         for plr in pl_result:
             _assert_tensor(plr)

@@ -82,8 +80,8 @@ def _assert_tensor(pl_result: Any):
 
 
 def _assert_requires_grad(metric: Metric, pl_result: Any):
-    """Utility function for recursively asserting that metric output is
-    consistent with the `is_differentiable` attribute."""
+    """Utility function for recursively asserting that metric output is consistent with the `is_differentiable`
+    attribute."""
     if isinstance(pl_result, Sequence):
         for plr in pl_result:
             _assert_requires_grad(metric, plr)

@@ -108,8 +106,7 @@ def _class_test(
     check_scriptable: bool = True,
     **kwargs_update: Any,
 ):
-    """Utility function doing the actual comparison between lightning class
-    metric and reference metric.
+    """Utility function doing the actual comparison between lightning class metric and reference metric.
 
     Args:
         rank: rank of current process

@@ -203,8 +200,7 @@ def _functional_test(
     fragment_kwargs: bool = False,
     **kwargs_update,
 ):
-    """Utility function doing the actual comparison between lightning
-    functional metric and reference metric.
+    """Utility function doing the actual comparison between lightning functional metric and reference metric.
 
     Args:
         preds: torch tensor with predictions

@@ -271,12 +267,10 @@ def _assert_half_support(
 
 
 class MetricTester:
-    """Class used for efficiently run alot of parametrized tests in ddp mode.
-    Makes sure that ddp is only setup once and that pool of processes are used
-    for all tests.
+    """Class used for efficiently run alot of parametrized tests in ddp mode. Makes sure that ddp is only setup
+    once and that pool of processes are used for all tests.
 
-    All tests should subclass from this and implement a new method
-    called `test_metric_name` where the method
+    All tests should subclass from this and implement a new method called `test_metric_name` where the method
     `self.run_metric_test` is called inside.
     """
 

@@ -285,8 +279,7 @@ class MetricTester:
     def setup_class(self):
         """Setup the metric class.
 
-        This will spawn the pool of workers that are used for metric
-        testing and setup_ddp
+        This will spawn the pool of workers that are used for metric testing and setup_ddp
         """
 
         self.poolSize = NUM_PROCESSES

@@ -308,8 +301,7 @@ class MetricTester:
         fragment_kwargs: bool = False,
         **kwargs_update,
     ):
-        """Main method that should be used for testing functions. Call this
-        inside testing method.
+        """Main method that should be used for testing functions. Call this inside testing method.
 
         Args:
             preds: torch tensor with predictions

@@ -350,8 +342,7 @@ class MetricTester:
         check_scriptable: bool = True,
         **kwargs_update,
     ):
-        """Main method that should be used for testing class. Call this inside
-        testing methods.
+        """Main method that should be used for testing class. Call this inside testing methods.
 
         Args:
             ddp: bool, if running in ddp mode or not
@@ -38,9 +38,8 @@ seed_all(42)
 
 
 def get_group_indexes(indexes: Union[Tensor, np.ndarray]) -> List[Union[Tensor, np.ndarray]]:
-    """Given an integer `torch.Tensor` or `np.ndarray` `indexes`, return a
-    `torch.Tensor` or `np.ndarray` of indexes for each different value in
-    `indexes`.
+    """Given an integer `torch.Tensor` or `np.ndarray` `indexes`, return a `torch.Tensor` or `np.ndarray` of
+    indexes for each different value in `indexes`.
 
     Args:
         indexes: a `torch.Tensor` or `np.ndarray` of integers
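A minimal sketch of the grouping behaviour this helper's docstring describes, written with plain torch ops (values are illustrative):

    import torch

    indexes = torch.tensor([0, 0, 1, 1, 1, 2])
    # One tensor of positions per distinct value in `indexes`:
    groups = [torch.nonzero(indexes == v, as_tuple=False).squeeze(1) for v in indexes.unique()]
    # groups -> [tensor([0, 1]), tensor([2, 3, 4]), tensor([5])]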
@@ -75,8 +74,7 @@ def _compute_sklearn_metric(
     reverse: bool = False,
     **kwargs,
 ) -> Tensor:
-    """Compute metric with multiple iterations over every query predictions
-    set."""
+    """Compute metric with multiple iterations over every query predictions set."""
 
     if indexes is None:
         indexes = np.full_like(preds, fill_value=0, dtype=np.int64)
@@ -34,11 +34,9 @@ seed_all(42)
 
 
 def _fallout_at_k(target: np.ndarray, preds: np.ndarray, k: int = None):
-    """Didn't find a reliable implementation of Fall-out in Information
-    Retrieval, so, reimplementing here.
+    """Didn't find a reliable implementation of Fall-out in Information Retrieval, so, reimplementing here.
 
-    See Wikipedia for `Fall-out`_ for more information about the metric
-    definition.
+    See Wikipedia for `Fall-out`_ for more information about the metric definition.
     """
     assert target.shape == preds.shape
     assert len(target.shape) == 1  # works only with single dimension inputs

@@ -35,10 +35,9 @@ seed_all(42)
 def _reciprocal_rank(target: np.ndarray, preds: np.ndarray):
     """Adaptation of `sklearn.metrics.label_ranking_average_precision_score`.
 
-    Since the original sklearn metric works as RR only when the number
-    of positive targets is exactly 1, here we remove every positive
-    target that is not the most important. Remember that in RR only the
-    positive target with the highest score is considered.
+    Since the original sklearn metric works as RR only when the number of positive targets is exactly 1, here we remove
+    every positive target that is not the most important. Remember that in RR only the positive target with the highest
+    score is considered.
     """
     assert target.shape == preds.shape
     assert len(target.shape) == 1  # works only with single dimension inputs

@@ -34,8 +34,7 @@ seed_all(42)
 
 
 def _precision_at_k(target: np.ndarray, preds: np.ndarray, k: int = None):
-    """Didn't find a reliable implementation of Precision in Information
-    Retrieval, so, reimplementing here.
+    """Didn't find a reliable implementation of Precision in Information Retrieval, so, reimplementing here.
 
     A good explanation can be found
     `here <https://web.stanford.edu/class/cs276/handouts/EvaluationNew-handout-1-per.pdf>_`.

@@ -34,8 +34,7 @@ seed_all(42)
 
 
 def _recall_at_k(target: np.ndarray, preds: np.ndarray, k: int = None):
-    """Didn't find a reliable implementation of Recall in Information
-    Retrieval, so, reimplementing here.
+    """Didn't find a reliable implementation of Recall in Information Retrieval, so, reimplementing here.
 
     See wikipedia for more information about definition.
     """
@@ -30,8 +30,7 @@ def test_wer_same(hyp, ref, score):
 )
 @pytest.mark.skipif(not _JIWER_AVAILABLE, reason="test requires jiwer")
 def test_wer_functional(ref, hyp, expected_score, expected_incorrect, expected_total):
-    """Test to ensure that the torchmetric functional WER matches the jiwer
-    reference."""
+    """Test to ensure that the torchmetric functional WER matches the jiwer reference."""
     assert wer(ref, hyp) == expected_score
 
 

@@ -44,15 +43,13 @@ def test_wer_functional(ref, hyp, expected_score, expected_incorrect, expected_t
 )
 @pytest.mark.skipif(not _JIWER_AVAILABLE, reason="test requires jiwer")
 def test_wer_reference_functional(hyp, ref):
-    """Test to ensure that the torchmetric functional WER matches the jiwer
-    reference."""
+    """Test to ensure that the torchmetric functional WER matches the jiwer reference."""
     assert wer(ref, hyp) == compute_measures(ref, hyp)["wer"]
 
 
 @pytest.mark.skipif(not _JIWER_AVAILABLE, reason="test requires jiwer")
 def test_wer_reference_functional_concatenate():
-    """Test to ensure that the torchmetric functional WER matches the jiwer
-    reference when concatenating."""
+    """Test to ensure that the torchmetric functional WER matches the jiwer reference when concatenating."""
     ref = ["hello world", "hello world"]
     hyp = ["hello world", "Firwww"]
     assert wer(ref, hyp) == compute_measures(ref, hyp)["wer"]

@@ -76,8 +73,7 @@ def test_wer_reference(hyp, ref):
 
 @pytest.mark.skipif(not _JIWER_AVAILABLE, reason="test requires jiwer")
 def test_wer_reference_batch():
-    """Test to ensure that the torchmetric WER matches the jiwer reference with
-    accumulation."""
+    """Test to ensure that the torchmetric WER matches the jiwer reference with accumulation."""
     batches = [("hello world", "Firwww"), ("hello world", "hello world")]
     metric = WER()
 
@@ -29,8 +29,8 @@ _target = torch.randint(10, (10, 32))
 
 
 class TestBootStrapper(BootStrapper):
-    """For testing purpose, we subclass the bootstrapper class so we can get
-    the exact permutation the class is creating."""
+    """For testing purpose, we subclass the bootstrapper class so we can get the exact permutation the class is
+    creating."""
 
     def update(self, *args) -> None:
         self.out = []

@@ -75,8 +75,7 @@ def test_bootstrap_sampler(sampling_strategy):
     "metric, sk_metric", [[Precision(average="micro"), precision_score], [Recall(average="micro"), recall_score]]
 )
 def test_bootstrap(sampling_strategy, metric, sk_metric):
-    """Test that the different bootstraps gets updated as we expected and that
-    the compute method works."""
+    """Test that the different bootstraps gets updated as we expected and that the compute method works."""
     _kwargs = {"base_metric": metric, "mean": True, "std": True, "raw": True, "sampling_strategy": sampling_strategy}
     if _TORCH_GREATER_EQUAL_1_7:
         _kwargs.update(dict(quantile=torch.tensor([0.05, 0.95])))
@@ -20,9 +20,9 @@ from torchmetrics.metric import Metric
 
 
 class PIT(Metric):
-    """Permutation invariant training (PIT). The PIT implements the famous
-    Permutation Invariant Training method [1] in speech separation field in
-    order to calculate audio metrics in a permutation invariant way.
+    """Permutation invariant training (PIT). The PIT implements the famous Permutation Invariant Training method.
+
+    [1] in speech separation field in order to calculate audio metrics in a permutation invariant way.
 
     Forward accepts
 

@@ -20,8 +20,8 @@ from torchmetrics.metric import Metric
 
 
 class SI_SDR(Metric):
-    """Scale-invariant signal-to-distortion ratio (SI-SDR). The SI-SDR value is
-    in general considered an overall measure of how good a source sound.
+    """Scale-invariant signal-to-distortion ratio (SI-SDR). The SI-SDR value is in general considered an overall
+    measure of how good a source sound.
 
     Forward accepts
 
@@ -266,8 +266,7 @@ class Accuracy(StatScores):
             self.fn.append(fn)
 
     def compute(self) -> Tensor:
-        """Computes accuracy based on inputs passed in to ``update``
-        previously."""
+        """Computes accuracy based on inputs passed in to ``update`` previously."""
         if not self.mode:
             raise RuntimeError("You have to have determined mode.")
         if self.subset_accuracy:

@@ -91,6 +91,6 @@ class AUC(Metric):
 
     @property
     def is_differentiable(self) -> bool:
-        """AUC metrics is considered as non differentiable so it should have
-        `false` value for `is_differentiable` property."""
+        """AUC metrics is considered as non differentiable so it should have `false` value for `is_differentiable`
+        property."""
         return False

@@ -169,8 +169,7 @@ class AUROC(Metric):
         self.mode = mode
 
     def compute(self) -> Tensor:
-        """Computes AUROC based on inputs passed in to ``update``
-        previously."""
+        """Computes AUROC based on inputs passed in to ``update`` previously."""
         if not self.mode:
             raise RuntimeError("You have to have determined mode.")
         preds = dim_zero_cat(self.preds)

@@ -187,6 +186,6 @@ class AUROC(Metric):
 
     @property
     def is_differentiable(self) -> bool:
-        """AUROC metrics is considered as non differentiable so it should have
-        `false` value for `is_differentiable` property."""
+        """AUROC metrics is considered as non differentiable so it should have `false` value for
+        `is_differentiable` property."""
         return False
@@ -26,10 +26,9 @@ from torchmetrics.utilities.data import dim_zero_cat
 
 
 class AveragePrecision(Metric):
-    """Computes the average precision score, which summarises the precision
-    recall curve into one number. Works for both binary and multiclass
-    problems. In the case of multiclass, the values will be calculated based on
-    a one-vs-the-rest approach.
+    """Computes the average precision score, which summarises the precision recall curve into one number. Works for
+    both binary and multiclass problems. In the case of multiclass, the values will be calculated based on a one-
+    vs-the-rest approach.
 
     Forward accepts
 

@@ -43,9 +43,8 @@ def _recall_at_precision(
 
 
 class BinnedPrecisionRecallCurve(Metric):
-    """Computes precision-recall pairs for different thresholds. Works for both
-    binary and multiclass problems. In the case of multiclass, the values will
-    be calculated based on a one-vs-the-rest approach.
+    """Computes precision-recall pairs for different thresholds. Works for both binary and multiclass problems. In
+    the case of multiclass, the values will be calculated based on a one-vs-the-rest approach.
 
     Computation is performed in constant-memory by computing precision and recall
     for ``thresholds`` buckets/thresholds (evenly distributed between 0 and 1).

@@ -190,10 +189,9 @@ class BinnedPrecisionRecallCurve(Metric):
 
 
 class BinnedAveragePrecision(BinnedPrecisionRecallCurve):
-    """Computes the average precision score, which summarises the precision
-    recall curve into one number. Works for both binary and multiclass
-    problems. In the case of multiclass, the values will be calculated based on
-    a one-vs-the-rest approach.
+    """Computes the average precision score, which summarises the precision recall curve into one number. Works for
+    both binary and multiclass problems. In the case of multiclass, the values will be calculated based on a one-
+    vs-the-rest approach.
 
     Computation is performed in constant-memory by computing precision and recall
     for ``thresholds`` buckets/thresholds (evenly distributed between 0 and 1).

@@ -245,8 +243,7 @@ class BinnedAveragePrecision(BinnedPrecisionRecallCurve):
 
 
 class BinnedRecallAtFixedPrecision(BinnedPrecisionRecallCurve):
-    """Computes the higest possible recall value given the minimum precision
-    thresholds provided.
+    """Computes the higest possible recall value given the minimum precision thresholds provided.
 
     Computation is performed in constant-memory by computing precision and recall
     for ``thresholds`` buckets/thresholds (evenly distributed between 0 and 1).
@@ -93,8 +93,8 @@ class CalibrationError(Metric):
         self.add_state("accuracies", [], dist_reduce_fx="cat")
 
     def update(self, preds: Tensor, target: Tensor) -> None:  # type: ignore
-        """Computes top-level confidences and accuracies for the input
-        probabilites and appends them to internal state.
+        """Computes top-level confidences and accuracies for the input probabilites and appends them to internal
+        state.
 
         Args:
             preds (Tensor): Model output probabilities.

@@ -119,7 +119,6 @@ class CohenKappa(Metric):
 
     @property
     def is_differentiable(self) -> bool:
-        """cohen kappa is not differentiable since the implementation is based
-        on calculating the confusion matrix which in general is not
-        differentiable."""
+        """cohen kappa is not differentiable since the implementation is based on calculating the confusion matrix
+        which in general is not differentiable."""
        return False
@@ -175,8 +175,7 @@ class FBeta(StatScores):
 
 
 class F1(FBeta):
-    """Computes F1 metric. F1 metrics correspond to a harmonic mean of the
-    precision and recall scores.
+    """Computes F1 metric. F1 metrics correspond to a harmonic mean of the precision and recall scores.
 
     Works with binary, multiclass, and multilabel data. Accepts logits or probabilities from a model
     output or integer class values in prediction. Works with multi-dimensional preds and target.

@@ -105,8 +105,7 @@ class HammingDistance(Metric):
         self.total += total
 
     def compute(self) -> Tensor:
-        """Computes hamming distance based on inputs passed in to ``update``
-        previously."""
+        """Computes hamming distance based on inputs passed in to ``update`` previously."""
         return _hamming_distance_compute(self.correct, self.total)
 
     @property
@@ -156,8 +156,7 @@ class Precision(StatScores):
         self.average = average
 
     def compute(self) -> Tensor:
-        """Computes the precision score based on inputs passed in to ``update``
-        previously.
+        """Computes the precision score based on inputs passed in to ``update`` previously.
 
         Return:
             The shape of the returned tensor depends on the ``average`` parameter

@@ -310,8 +309,7 @@ class Recall(StatScores):
         self.average = average
 
     def compute(self) -> Tensor:
-        """Computes the recall score based on inputs passed in to ``update``
-        previously.
+        """Computes the recall score based on inputs passed in to ``update`` previously.
 
         Return:
             The shape of the returned tensor depends on the ``average`` parameter

@@ -26,9 +26,8 @@ from torchmetrics.utilities.data import dim_zero_cat
 
 
 class PrecisionRecallCurve(Metric):
-    """Computes precision-recall pairs for different thresholds. Works for both
-    binary and multiclass problems. In the case of multiclass, the values will
-    be calculated based on a one-vs-the-rest approach.
+    """Computes precision-recall pairs for different thresholds. Works for both binary and multiclass problems. In
+    the case of multiclass, the values will be calculated based on a one-vs-the-rest approach.
 
     Forward accepts
 

@@ -22,9 +22,8 @@ from torchmetrics.utilities import rank_zero_warn
 
 
 class ROC(Metric):
-    """Computes the Receiver Operating Characteristic (ROC). Works for both
-    binary, multiclass and multilabel problems. In the case of multiclass, the
-    values will be calculated based on a one-vs-the-rest approach.
+    """Computes the Receiver Operating Characteristic (ROC). Works for both binary, multiclass and multilabel
+    problems. In the case of multiclass, the values will be calculated based on a one-vs-the-rest approach.
 
     Forward accepts
 

@@ -157,8 +157,7 @@ class Specificity(StatScores):
         self.average = average
 
     def compute(self) -> Tensor:
-        """Computes the specificity score based on inputs passed in to
-        ``update`` previously.
+        """Computes the specificity score based on inputs passed in to ``update`` previously.
 
         Return:
             The shape of the returned tensor depends on the ``average`` parameter
@@ -225,8 +225,7 @@ class StatScores(Metric):
             self.fn.append(fn)
 
     def _get_final_stats(self) -> Tuple[Tensor, Tensor, Tensor, Tensor]:
-        """Performs concatenation on the stat scores if neccesary, before
-        passing them to a compute function."""
+        """Performs concatenation on the stat scores if neccesary, before passing them to a compute function."""
         tp = torch.cat(self.tp) if isinstance(self.tp, list) else self.tp
         fp = torch.cat(self.fp) if isinstance(self.fp, list) else self.fp
         tn = torch.cat(self.tn) if isinstance(self.tn, list) else self.tn

@@ -234,8 +233,7 @@ class StatScores(Metric):
         return tp, fp, tn, fn
 
     def compute(self) -> Tensor:
-        """Computes the stat scores based on inputs passed in to ``update``
-        previously.
+        """Computes the stat scores based on inputs passed in to ``update`` previously.
 
         Return:
             The metric returns a tensor of shape ``(..., 5)``, where the last dimension corresponds
@@ -24,8 +24,7 @@ from torchmetrics.utilities import rank_zero_warn
 
 
 class MetricCollection(nn.ModuleDict):
-    """MetricCollection class can be used to chain metrics that have the same
-    call pattern into one single class.
+    """MetricCollection class can be used to chain metrics that have the same call pattern into one single class.
 
     Args:
         metrics: One of the following

@@ -105,18 +104,16 @@ class MetricCollection(nn.ModuleDict):
     def forward(self, *args: Any, **kwargs: Any) -> Dict[str, Any]:
         """Iteratively call forward for each metric.
 
-        Positional arguments (args) will be passed to every metric in
-        the collection, while keyword arguments (kwargs) will be
-        filtered based on the signature of the individual metric.
+        Positional arguments (args) will be passed to every metric in the collection, while keyword arguments (kwargs)
+        will be filtered based on the signature of the individual metric.
         """
         return {k: m(*args, **m._filter_kwargs(**kwargs)) for k, m in self.items()}
 
     def update(self, *args: Any, **kwargs: Any) -> None:
         """Iteratively call update for each metric.
 
-        Positional arguments (args) will be passed to every metric in
-        the collection, while keyword arguments (kwargs) will be
-        filtered based on the signature of the individual metric.
+        Positional arguments (args) will be passed to every metric in the collection, while keyword arguments (kwargs)
+        will be filtered based on the signature of the individual metric.
         """
         for _, m in self.items(keep_base=True):
             m_kwargs = m._filter_kwargs(**kwargs)
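A hedged usage sketch of the calling convention these docstrings describe (assumes the 0.x-era no-argument constructors for Accuracy and Precision; not part of this commit):

    import torch
    from torchmetrics import Accuracy, MetricCollection, Precision

    metrics = MetricCollection([Accuracy(), Precision()])
    preds = torch.randint(0, 2, (10,))
    target = torch.randint(0, 2, (10,))
    # Positional args reach every metric; unsupported kwargs would be filtered per metric.
    out = metrics(preds, target)  # e.g. {'Accuracy': tensor(...), 'Precision': tensor(...)}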
@@ -145,8 +142,7 @@ class MetricCollection(nn.ModuleDict):
         return mc
 
     def persistent(self, mode: bool = True) -> None:
-        """Method for post-init to change if metric states should be saved to
-        its state_dict."""
+        """Method for post-init to change if metric states should be saved to its state_dict."""
         for _, m in self.items(keep_base=True):
             m.persistent(mode)
 
@@ -30,8 +30,8 @@ def _find_best_perm_by_linear_sum_assignment(
     metric_mtx: torch.Tensor,
     eval_func: Union[torch.min, torch.max],
 ) -> Tuple[Tensor, Tensor]:
-    """Solves the linear sum assignment problem using scipy, and returns the
-    best metric values and the corresponding permutations.
+    """Solves the linear sum assignment problem using scipy, and returns the best metric values and the
+    corresponding permutations.
 
     Args:
         metric_mtx:

@@ -58,9 +58,8 @@ def _find_best_perm_by_exhuastive_method(
     metric_mtx: torch.Tensor,
     eval_func: Union[torch.min, torch.max],
 ) -> Tuple[Tensor, Tensor]:
-    """Solves the linear sum assignment problem using exhuastive method, i.e.
-    exhuastively calculates the metric values of all possible permutations, and
-    returns the best metric values and the corresponding permutations.
+    """Solves the linear sum assignment problem using exhuastive method, i.e. exhuastively calculates the metric
+    values of all possible permutations, and returns the best metric values and the corresponding permutations.
 
     Args:
         metric_mtx:

@@ -105,9 +104,9 @@ def _find_best_perm_by_exhuastive_method(
 def pit(
     preds: torch.Tensor, target: torch.Tensor, metric_func: Callable, eval_func: str = "max", **kwargs: Dict[str, Any]
 ) -> Tuple[Tensor, Tensor]:
-    """Permutation invariant training (PIT). The PIT implements the famous
-    Permutation Invariant Training method [1] in speech separation field in
-    order to calculate audio metrics in a permutation invariant way.
+    """Permutation invariant training (PIT). The PIT implements the famous Permutation Invariant Training method.
+
+    [1] in speech separation field in order to calculate audio metrics in a permutation invariant way.
 
     Args:
         target:
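A hedged usage sketch of the functional pit defined above, pairing it with si_sdr as the underlying metric (shapes are illustrative):

    import torch
    from torchmetrics.functional import pit, si_sdr

    preds = torch.randn(2, 3, 16000)   # [batch, spk, time]
    target = torch.randn(2, 3, 16000)
    # Best metric value per batch element, plus the permutation that achieves it.
    best_metric, best_perm = pit(preds, target, si_sdr, eval_func="max")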
@@ -18,9 +18,8 @@ from torchmetrics.utilities.checks import _check_same_shape
 
 
 def si_sdr(preds: Tensor, target: Tensor, zero_mean: bool = False) -> Tensor:
-    """Calculates Scale-invariant signal-to-distortion ratio (SI-SDR) metric.
-    The SI-SDR value is in general considered an overall measure of how good a
-    source sound.
+    """Calculates Scale-invariant signal-to-distortion ratio (SI-SDR) metric. The SI-SDR value is in general
+    considered an overall measure of how good a source sound.
 
     Args:
         preds:
@@ -159,8 +159,7 @@ def auroc(
     max_fpr: Optional[float] = None,
     sample_weights: Optional[Sequence] = None,
 ) -> Tensor:
-    """Compute Area Under the Receiver Operating Characteristic Curve (`ROC
-    AUC`_)
+    """Compute Area Under the Receiver Operating Characteristic Curve (`ROC AUC`_)
 
     Args:
         preds: predictions from model (logits or probabilities)

@@ -27,8 +27,7 @@ def _ce_compute(
     norm: str = "l1",
     debias: bool = False,
 ) -> Tensor:
-    """Computes the calibration error given the provided bin boundaries and
-    norm.
+    """Computes the calibration error given the provided bin boundaries and norm.
 
     Args:
         confidences (FloatTensor): The confidence (i.e. predicted prob) of the top1 prediction.

@@ -76,8 +75,8 @@ def _ce_compute(
 
 
 def _ce_update(preds: Tensor, target: Tensor) -> Tuple[FloatTensor, FloatTensor]:
-    """Given a predictions and targets tensor, computes the confidences of the
-    top-1 prediction and records their correctness.
+    """Given a predictions and targets tensor, computes the confidences of the top-1 prediction and records their
+    correctness.
 
     Args:
         preds (Tensor): Input softmaxed predictions.
@@ -26,8 +26,8 @@ def _stat_scores(
     class_index: int,
     argmax_dim: int = 1,
 ) -> Tuple[Tensor, Tensor, Tensor, Tensor, Tensor]:
-    """Calculates the number of true positive, false positive, true negative
-    and false negative for a specific class.
+    """Calculates the number of true positive, false positive, true negative and false negative for a specific
+    class.
 
     Args:
         preds: prediction tensor

@@ -229,8 +229,7 @@ def f1(
     top_k: Optional[int] = None,
     multiclass: Optional[bool] = None,
 ) -> Tensor:
-    """Computes F1 metric. F1 metrics correspond to a equally weighted average
-    of the precision and recall scores.
+    """Computes F1 metric. F1 metrics correspond to a equally weighted average of the precision and recall scores.
 
     Works with binary, multiclass, and multilabel data.
     Accepts probabilities or logits from a model output or integer class values in prediction.

@@ -26,8 +26,7 @@ def _binary_clf_curve(
     sample_weights: Optional[Sequence] = None,
     pos_label: int = 1,
 ) -> Tuple[Tensor, Tensor, Tensor]:
-    """adapted from https://github.com/scikit-learn/scikit-
-    learn/blob/master/sklearn/metrics/_ranking.py."""
+    """adapted from https://github.com/scikit-learn/scikit- learn/blob/master/sklearn/metrics/_ranking.py."""
     if sample_weights is not None and not isinstance(sample_weights, Tensor):
         sample_weights = tensor(sample_weights, device=preds.device, dtype=torch.float)
 

@@ -106,8 +106,8 @@ def roc(
     pos_label: Optional[int] = None,
     sample_weights: Optional[Sequence] = None,
 ) -> Union[Tuple[Tensor, Tensor, Tensor], Tuple[List[Tensor], List[Tensor], List[Tensor]]]:
-    """Computes the Receiver Operating Characteristic (ROC). Works with both
-    binary, multiclass and multilabel input.
+    """Computes the Receiver Operating Characteristic (ROC). Works with both binary, multiclass and multilabel
+    input.
 
     Args:
         preds: predictions from model (logits or probabilities)
@@ -46,8 +46,8 @@ def _compute_image_gradients(img: Tensor) -> Tuple[Tensor, Tensor]:
 
 
 def image_gradients(img: Tensor) -> Tuple[Tensor, Tensor]:
-    """Computes the `gradients <https://en.wikipedia.org/wiki/Image_gradient>`_
-    of a given image using finite difference.
+    """Computes the `gradients <https://en.wikipedia.org/wiki/Image_gradient>`_ of a given image using finite
+    difference.
 
     Args:
         img: An ``(N, C, H, W)`` input tensor where C is the number of image channels
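A small sketch of the finite-difference behaviour documented here, on a ramp image (assuming the TF-style convention, which the image-gradients tests above also rely on, of zero-filling the last row/column):

    import torch
    from torchmetrics.functional import image_gradients

    img = torch.arange(25, dtype=torch.float32).reshape(1, 1, 5, 5)
    dy, dx = image_gradients(img)
    # Rows step by 5 and columns by 1, so dy is 5 and dx is 1 everywhere
    # except in the zero-filled last row/column.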
@@ -60,8 +60,7 @@ def _psnr_update(
     target: Tensor,
     dim: Optional[Union[int, Tuple[int, ...]]] = None,
 ) -> Tuple[Tensor, Tensor]:
-    """Updates and returns variables required to compute peak signal-to-noise
-    ratio.
+    """Updates and returns variables required to compute peak signal-to-noise ratio.
 
     Args:
         preds: Predicted tensor

@@ -68,8 +68,8 @@ def _gaussian_kernel(
 
 
 def _ssim_update(preds: Tensor, target: Tensor) -> Tuple[Tensor, Tensor]:
-    """Updates and returns variables required to compute Structural Similarity
-    Index Measure. Checks for same shape and type of the input tensors.
+    """Updates and returns variables required to compute Structural Similarity Index Measure. Checks for same shape
+    and type of the input tensors.
 
     Args:
         preds: Predicted tensor
@@ -25,8 +25,8 @@ def bleu_score(
     n_gram: int = 4,
     smooth: bool = False,
 ) -> Tensor:
-    """Calculate `BLEU score <https://en.wikipedia.org/wiki/BLEU>`_ of machine
-    translated text with one or more references.
+    """Calculate `BLEU score <https://en.wikipedia.org/wiki/BLEU>`_ of machine translated text with one or more
+    references.
 
     Example:
         >>> from torchmetrics.functional import bleu_score

@@ -23,8 +23,7 @@ def _cosine_similarity_update(
     preds: Tensor,
     target: Tensor,
 ) -> Tuple[Tensor, Tensor]:
-    """Updates and returns variables required to compute Cosine Similarity.
-    Checks for same shape of input tensors.
+    """Updates and returns variables required to compute Cosine Similarity. Checks for same shape of input tensors.
 
     Args:
         preds: Predicted tensor
@@ -20,8 +20,8 @@ from torchmetrics.utilities.checks import _check_same_shape
 
 
 def _explained_variance_update(preds: Tensor, target: Tensor) -> Tuple[int, Tensor, Tensor, Tensor, Tensor]:
-    """Updates and returns variables required to compute Explained Variance.
-    Checks for same shape of input tensors.
+    """Updates and returns variables required to compute Explained Variance. Checks for same shape of input
+    tensors.
 
     Args:
         preds: Predicted tensor

@@ -20,8 +20,8 @@ from torchmetrics.utilities.checks import _check_same_shape
 
 
 def _mean_absolute_error_update(preds: Tensor, target: Tensor) -> Tuple[Tensor, int]:
-    """Updates and returns variables required to compute Mean Absolute Error.
-    Checks for same shape of input tensors.
+    """Updates and returns variables required to compute Mean Absolute Error. Checks for same shape of input
+    tensors.
 
     Args:
         preds: Predicted tensor

@@ -24,8 +24,8 @@ def _mean_absolute_percentage_error_update(
     target: Tensor,
     epsilon: float = 1.17e-06,
 ) -> Tuple[Tensor, int]:
-    """Updates and returns variables required to compute Mean Percentage Error.
-    Checks for same shape of input tensors.
+    """Updates and returns variables required to compute Mean Percentage Error. Checks for same shape of input
+    tensors.
 
     Args:
         preds: Predicted tensor

@@ -20,8 +20,8 @@ from torchmetrics.utilities.checks import _check_same_shape
 
 
 def _mean_squared_error_update(preds: Tensor, target: Tensor) -> Tuple[Tensor, int]:
-    """Updates and returns variables required to compute Mean Squared Error.
-    Checks for same shape of input tensors.
+    """Updates and returns variables required to compute Mean Squared Error. Checks for same shape of input
+    tensors.
 
     Args:
         preds: Predicted tensor

@@ -20,8 +20,7 @@ from torchmetrics.utilities.checks import _check_same_shape
 
 
 def _mean_squared_log_error_update(preds: Tensor, target: Tensor) -> Tuple[Tensor, int]:
-    """Returns variables required to compute Mean Squared Log Error. Checks for
-    same shape of tensors.
+    """Returns variables required to compute Mean Squared Log Error. Checks for same shape of tensors.
 
     Args:
         preds: Predicted tensor
@@ -29,8 +29,8 @@ def _pearson_corrcoef_update(
     corr_xy: Tensor,
     n_prior: Tensor,
 ) -> Tuple[Tensor, Tensor, Tensor, Tensor, Tensor, Tensor]:
-    """Updates and returns variables required to compute Pearson Correlation
-    Coefficient. Checks for same shape of input tensors.
+    """Updates and returns variables required to compute Pearson Correlation Coefficient. Checks for same shape of
+    input tensors.
 
     Args:
         mean_x: current mean estimate of x tensor

@@ -21,8 +21,7 @@ from torchmetrics.utilities.checks import _check_same_shape
 
 
 def _r2_score_update(preds: Tensor, target: Tensor) -> Tuple[Tensor, Tensor, Tensor, Tensor]:
-    """Updates and returns variables required to compute R2 score. Checks for
-    same shape and 1D/2D input tensors.
+    """Updates and returns variables required to compute R2 score. Checks for same shape and 1D/2D input tensors.
 
     Args:
         preds: Predicted tensor
@@ -20,8 +20,7 @@ from torchmetrics.utilities.checks import _check_same_shape
 
 
 def _find_repeats(data: Tensor) -> Tensor:
-    """find and return values which have repeats i.e. the same value are more
-    than once in the tensor."""
+    """find and return values which have repeats i.e. the same value are more than once in the tensor."""
     temp = data.detach().clone()
     temp = temp.sort()[0]
 

@@ -34,9 +33,9 @@ def _find_repeats(data: Tensor) -> Tensor:
 
 
 def _rank_data(data: Tensor) -> Tensor:
-    """Calculate the rank for each element of a tensor. The rank refers to the
-    indices of an element in the corresponding sorted tensor (starting from 1).
-    Duplicates of the same value will be assigned the mean of their rank.
+    """Calculate the rank for each element of a tensor. The rank refers to the indices of an element in the
+    corresponding sorted tensor (starting from 1). Duplicates of the same value will be assigned the mean of their
+    rank.
 
     Adopted from:
     https://github.com/scipy/scipy/blob/v1.6.2/scipy/stats/stats.py#L4140-L4303
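A worked example of the ranking rule in this docstring (values illustrative): for [10., 20., 20., 30.] the sort order gives ranks 1, 2, 3, 4, and the tied 20s share the mean of their ranks:

    import torch

    data = torch.tensor([10.0, 20.0, 20.0, 30.0])
    # Plain 1-based ranks by sort order ...
    ranks = torch.empty_like(data)
    ranks[data.argsort()] = torch.arange(1, len(data) + 1, dtype=data.dtype)
    # ... on top of which _rank_data averages ties: both 20s get (2 + 3) / 2 = 2.5.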
@@ -54,8 +53,8 @@ def _rank_data(data: Tensor) -> Tensor:
 
 
 def _spearman_corrcoef_update(preds: Tensor, target: Tensor) -> Tuple[Tensor, Tensor]:
-    """Updates and returns variables required to compute Spearman Correlation
-    Coefficient. Checks for same shape and type of input tensors.
+    """Updates and returns variables required to compute Spearman Correlation Coefficient. Checks for same shape
+    and type of input tensors.
 
     Args:
         preds: Predicted tensor

@@ -24,8 +24,8 @@ def _symmetric_mean_absolute_percentage_error_update(
     target: Tensor,
     epsilon: float = 1.17e-06,
 ) -> Tuple[Tensor, int]:
-    """Updates and returns variables required to compute Symmetric Mean
-    Absolute Percentage Error. Checks for same shape of input tensors.
+    """Updates and returns variables required to compute Symmetric Mean Absolute Percentage Error. Checks for same
+    shape of input tensors.
 
     Args:
         preds: Predicted tensor
@@ -18,8 +18,7 @@ from torchmetrics.utilities.checks import _check_retrieval_functional_inputs
 
 
 def retrieval_average_precision(preds: Tensor, target: Tensor) -> Tensor:
-    """Computes average precision (for information retrieval), as explained in
-    `IR Average precision`_.
+    """Computes average precision (for information retrieval), as explained in `IR Average precision`_.
 
     ``preds`` and ``target`` should be of the same shape and live on the same device. If no ``target`` is ``True``,
     ``0`` is returned. ``target`` must be either `bool` or `integers` and ``preds`` must be `float`,

@@ -18,9 +18,8 @@ from torchmetrics.utilities.checks import _check_retrieval_functional_inputs
 
 
 def retrieval_fall_out(preds: Tensor, target: Tensor, k: int = None) -> Tensor:
-    """Computes the Fall-out (for information retrieval), as explained in `IR
-    Fall-out`_ Fall-out is the fraction of non-relevant documents retrieved
-    among all the non-relevant documents.
+    """Computes the Fall-out (for information retrieval), as explained in `IR Fall-out`_ Fall-out is the fraction
+    of non-relevant documents retrieved among all the non-relevant documents.
 
     ``preds`` and ``target`` should be of the same shape and live on the same device. If no ``target`` is ``True``,
     ``0`` is returned. ``target`` must be either `bool` or `integers` and ``preds`` must be `float`,

@@ -26,8 +26,7 @@ def _dcg(target: Tensor) -> Tensor:
 
 
 def retrieval_normalized_dcg(preds: Tensor, target: Tensor, k: Optional[int] = None) -> Tensor:
-    """Computes `Normalized Discounted Cumulative Gain`_ (for information
-    retrieval).
+    """Computes `Normalized Discounted Cumulative Gain`_ (for information retrieval).
 
     ``preds`` and ``target`` should be of the same shape and live on the same device.
     ``target`` must be either `bool` or `integers` and ``preds`` must be `float`,

@@ -20,8 +20,8 @@ from torchmetrics.utilities.checks import _check_retrieval_functional_inputs
 
 
 def retrieval_precision(preds: Tensor, target: Tensor, k: Optional[int] = None) -> Tensor:
-    """Computes the precision metric (for information retrieval). Precision is
-    the fraction of relevant documents among all the retrieved documents.
+    """Computes the precision metric (for information retrieval). Precision is the fraction of relevant documents
+    among all the retrieved documents.
 
     ``preds`` and ``target`` should be of the same shape and live on the same device. If no ``target`` is ``True``,
     ``0`` is returned. ``target`` must be either `bool` or `integers` and ``preds`` must be `float`,

@@ -20,8 +20,8 @@ from torchmetrics.utilities.checks import _check_retrieval_functional_inputs
 
 
 def retrieval_recall(preds: Tensor, target: Tensor, k: Optional[int] = None) -> Tensor:
-    """Computes the recall metric (for information retrieval). Recall is the
-    fraction of relevant documents retrieved among all the relevant documents.
+    """Computes the recall metric (for information retrieval). Recall is the fraction of relevant documents
+    retrieved among all the relevant documents.
 
     ``preds`` and ``target`` should be of the same shape and live on the same device. If no ``target`` is ``True``,
     ``0`` is returned. ``target`` must be either `bool` or `integers` and ``preds`` must be `float`,
@@ -123,8 +123,8 @@ def bleu_score(
     n_gram: int = 4,
     smooth: bool = False,
 ) -> Tensor:
-    """Calculate `BLEU score <https://en.wikipedia.org/wiki/BLEU>`_ of machine
-    translated text with one or more references.
+    """Calculate `BLEU score <https://en.wikipedia.org/wiki/BLEU>`_ of machine translated text with one or more
+    references.
 
     Args:
         reference_corpus:
@ -41,8 +41,7 @@ ALLOWED_ROUGE_KEYS = (


 def add_newline_to_end_of_each_sentence(x: str) -> str:
-    """This was added to get rougeLsum scores matching published rougeL scores
-    for BART and PEGASUS."""
+    """This was added to get rougeLsum scores matching published rougeL scores for BART and PEGASUS."""
     if _NLTK_AVAILABLE:
         import nltk

@ -54,8 +53,7 @@ def add_newline_to_end_of_each_sentence(x: str) -> str:


 def format_rouge_results(result: Dict[str, AggregateScore], decimal_places: int = 4) -> Dict[str, Tensor]:
-    """Formats the computed (aggregated) rouge score to a dictionary of tensors
-    format."""
+    """Formats the computed (aggregated) rouge score to a dictionary of tensors format."""
     flattened_result = {}
     for rouge_key, rouge_aggregate_score in result.items():
         for stat in ["precision", "recall", "fmeasure"]:
@ -72,8 +70,7 @@ def _rouge_score_update(
     aggregator: BootstrapAggregator,
     newline_sep: bool = False,
 ) -> None:
-    """Update the rouge score with the current set of predicted and target
-    sentences.
+    """Update the rouge score with the current set of predicted and target sentences.

     Args:
         preds:
@ -104,8 +101,7 @@ def _rouge_score_update(


 def _rouge_score_compute(aggregator: BootstrapAggregator, decimal_places: int = 4) -> Dict[str, Tensor]:
-    """Compute the combined ROUGE metric for all the input set of predicted and
-    target sentences.
+    """Compute the combined ROUGE metric for all the input set of predicted and target sentences.

     Args:
         aggregator:
@ -125,8 +121,7 @@ def rouge_score(
     rouge_keys: Union[str, Tuple[str, ...]] = ("rouge1", "rouge2", "rougeL", "rougeLsum"),  # type: ignore
     decimal_places: int = 4,
 ) -> Dict[str, Tensor]:
-    """Calculate `ROUGE score <https://en.wikipedia.org/wiki/ROUGE_(metric)>`_,
-    used for automatic summarization.
+    """Calculate `ROUGE score <https://en.wikipedia.org/wiki/ROUGE_(metric)>`_, used for automatic summarization.

     Args:
         preds:
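A usage sketch for ``rouge_score``; the ``targets`` argument name is an assumption (only ``preds`` is visible in the hunk), and the ``nltk``/``rouge-score`` extras must be installed. The flattened keys follow the ``{rouge_key}_{stat}`` pattern seen in ``format_rouge_results`` above:

from torchmetrics.functional import rouge_score

preds = "My name is John"
targets = "Is your name John"

# Returns a dict of tensors keyed like ``rouge1_fmeasure``, ``rouge2_precision``, ...
scores = rouge_score(preds, targets)
print(scores["rouge1_fmeasure"])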
@ -25,9 +25,8 @@ def wer(
     predictions: Union[str, List[str]],
     concatenate_texts: bool = False,
 ) -> float:
-    """Word error rate (WER_) is a common metric of the performance of an
-    automatic speech recognition system. This value indicates the percentage of
-    words that were incorrectly predicted. The lower the value, the better the
+    """Word error rate (WER_) is a common metric of the performance of an automatic speech recognition system. This
+    value indicates the percentage of words that were incorrectly predicted. The lower the value, the better the
     performance of the ASR system with a WER of 0 being a perfect score.

     Args:
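A quick sketch of ``wer``; note that a ``references``-first parameter is assumed here, since only ``predictions`` and ``concatenate_texts`` are visible in the hunk:

from torchmetrics.functional import wer

references = ["hello world", "the weather is nice"]
predictions = ["hello duck", "the weather is nice"]

# Fraction of incorrectly predicted words; 0.0 would be a perfect score
print(wer(references, predictions))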
@ -47,8 +47,7 @@ class NoTrainInceptionV3(FeatureExtractorInceptionV3):
         self.eval()

     def train(self, mode: bool) -> "NoTrainInceptionV3":
-        """the inception network should not be able to be switched away from
-        evaluation mode."""
+        """the inception network should not be able to be switched away from evaluation mode."""
         return super().train(False)

     def forward(self, x: Tensor) -> Tensor:
@ -264,8 +263,7 @@ class FID(Metric):
         self.fake_features.append(features)

     def compute(self) -> Tensor:
-        """Calculate FID score based on accumulated extracted features from the
-        two distributions.
+        """Calculate FID score based on accumulated extracted features from the two distributions."""
         real_features = dim_zero_cat(self.real_features)
         fake_features = dim_zero_cat(self.fake_features)
         # computation is extremely sensitive so it needs to happen in double precision
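A usage sketch for the ``FID`` metric whose ``compute`` is reformatted above; the ``torchmetrics.image.fid`` import path and the ``feature``/``real`` arguments are assumptions based on this release, and ``torch-fidelity`` must be installed:

import torch
from torchmetrics.image.fid import FID  # requires torch-fidelity

fid = FID(feature=64)  # a small feature layer keeps the example fast

# Batches of real/fake uint8 images in (N, 3, 299, 299) layout
real_imgs = torch.randint(0, 255, (10, 3, 299, 299), dtype=torch.uint8)
fake_imgs = torch.randint(0, 255, (10, 3, 299, 299), dtype=torch.uint8)

fid.update(real_imgs, real=True)
fid.update(fake_imgs, real=False)
print(fid.compute())  # done in double precision, per the comment above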
@ -25,8 +25,7 @@ from torchmetrics.utilities.imports import _TORCH_FIDELITY_AVAILABLE


 def maximum_mean_discrepancy(k_xx: Tensor, k_xy: Tensor, k_yy: Tensor) -> Tensor:
-    """Adapted from https://github.com/toshas/torch-
-    fidelity/blob/v0.3.0/torch_fidelity/metric_kid.py."""
+    """Adapted from https://github.com/toshas/torch- fidelity/blob/v0.3.0/torch_fidelity/metric_kid.py."""
     m = k_xx.shape[0]

     diag_x = torch.diag(k_xx)
@ -46,8 +45,7 @@ def maximum_mean_discrepancy(k_xx: Tensor, k_xy: Tensor, k_yy: Tensor) -> Tensor


 def poly_kernel(f1: Tensor, f2: Tensor, degree: int = 3, gamma: Optional[float] = None, coef: float = 1.0) -> Tensor:
-    """Adapted from https://github.com/toshas/torch-
-    fidelity/blob/v0.3.0/torch_fidelity/metric_kid.py."""
+    """Adapted from https://github.com/toshas/torch- fidelity/blob/v0.3.0/torch_fidelity/metric_kid.py."""
     if gamma is None:
         gamma = 1.0 / f1.shape[1]
     kernel = (f1 @ f2.T * gamma + coef) ** degree
@ -57,8 +55,7 @@ def poly_kernel(f1: Tensor, f2: Tensor, degree: int = 3, gamma: Optional[float]
 def poly_mmd(
     f_real: Tensor, f_fake: Tensor, degree: int = 3, gamma: Optional[float] = None, coef: float = 1.0
 ) -> Tensor:
-    """Adapted from https://github.com/toshas/torch-
-    fidelity/blob/v0.3.0/torch_fidelity/metric_kid.py."""
+    """Adapted from https://github.com/toshas/torch- fidelity/blob/v0.3.0/torch_fidelity/metric_kid.py."""
     k_11 = poly_kernel(f_real, f_real, degree, gamma, coef)
     k_22 = poly_kernel(f_fake, f_fake, degree, gamma, coef)
     k_12 = poly_kernel(f_real, f_fake, degree, gamma, coef)
@ -252,9 +249,8 @@ class KID(Metric):
         self.fake_features.append(features)

     def compute(self) -> Tuple[Tensor, Tensor]:
-        """Calculate KID score based on accumulated extracted features from the
-        two distributions. Returns a tuple of mean and standard deviation of
-        KID scores calculated on subsets of extracted features.
+        """Calculate KID score based on accumulated extracted features from the two distributions. Returns a tuple
+        of mean and standard deviation of KID scores calculated on subsets of extracted features.

         Implementation inspired by https://github.com/toshas/torch-fidelity/blob/v0.3.0/torch_fidelity/metric_kid.py
         """
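The polynomial kernel in the hunks above is plain tensor algebra; a standalone sketch of the same computation (the feature matrices are hypothetical stand-ins for Inception features):

import torch

def poly_kernel(f1, f2, degree=3, gamma=None, coef=1.0):
    # k(x, y) = (gamma * <x, y> + coef) ** degree, with gamma defaulting to 1/dim
    if gamma is None:
        gamma = 1.0 / f1.shape[1]
    return (f1 @ f2.T * gamma + coef) ** degree

f_real = torch.randn(8, 2048)
f_fake = torch.randn(8, 2048)
print(poly_kernel(f_real, f_fake).shape)  # torch.Size([8, 8])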
@ -176,8 +176,7 @@ class Metric(nn.Module, ABC):
     def forward(self, *args: Any, **kwargs: Any) -> Any:
         """Automatically calls ``update()``.

-        Returns the metric value over inputs if ``compute_on_step`` is
-        True.
+        Returns the metric value over inputs if ``compute_on_step`` is True.
         """
         # add current step
         if self._is_synced:
@ -256,8 +255,7 @@ class Metric(nn.Module, ABC):
         should_sync: bool = True,
         distributed_available: Optional[Callable] = jit_distributed_available,
     ) -> None:
-        """Sync function for manually controlling when metrics states should be
-        synced across processes.
+        """Sync function for manually controlling when metrics states should be synced across processes.

         Args:
             dist_sync_fn: Function to be used to perform states synchronization
@ -287,8 +285,8 @@ class Metric(nn.Module, ABC):
         self._is_synced = True

     def unsync(self, should_unsync: bool = True) -> None:
-        """Unsync function for manually controlling when metrics states should
-        be reverted back to their local states.
+        """Unsync function for manually controlling when metrics states should be reverted back to their local
+        states.

         Args:
             should_unsync: Whether to perform unsync
@ -317,9 +315,8 @@ class Metric(nn.Module, ABC):
         should_unsync: bool = True,
         distributed_available: Optional[Callable] = jit_distributed_available,
     ) -> Generator:
-        """Context manager to synchronize the states between processes when
-        running in a distributed setting and restore the local cache states
-        after yielding.
+        """Context manager to synchronize the states between processes when running in a distributed setting and
+        restore the local cache states after yielding.

         Args:
             dist_sync_fn: Function to be used to perform states synchronization
@ -372,17 +369,15 @@ class Metric(nn.Module, ABC):

     @abstractmethod
     def update(self, *_: Any, **__: Any) -> None:
-        """Override this method to update the state variables of your metric
-        class."""
+        """Override this method to update the state variables of your metric class."""

     @abstractmethod
     def compute(self) -> Any:
-        """Override this method to compute the final metric value from state
-        variables synchronized across the distributed backend."""
+        """Override this method to compute the final metric value from state variables synchronized across the
+        distributed backend."""

     def reset(self) -> None:
-        """This method automatically resets the metric state variables to their
-        default value."""
+        """This method automatically resets the metric state variables to their default value."""
         self._update_called = False
         self._forward_cache = None
         # lower lightning versions requires this implicitly to log metric objects correctly in self.log
@ -416,8 +411,8 @@ class Metric(nn.Module, ABC):
         self.compute: Callable = self._wrap_compute(self.compute)  # type: ignore

     def _apply(self, fn: Callable) -> Module:
-        """Overwrite _apply function such that we can also move metric states
-        to the correct device when `.to`, `.cuda`, etc methods are called."""
+        """Overwrite _apply function such that we can also move metric states to the correct device when `.to`,
+        `.cuda`, etc methods are called."""
         this = super()._apply(fn)
         # Also apply fn to metric states and defaults
         for key, value in this._defaults.items():
@ -445,8 +440,7 @@ class Metric(nn.Module, ABC):
         return this

     def persistent(self, mode: bool = False) -> None:
-        """Method for post-init to change if metric states should be saved to
-        its state_dict."""
+        """Method for post-init to change if metric states should be saved to its state_dict."""
         for key in self._persistent:
             self._persistent[key] = mode

@ -491,8 +485,7 @@ class Metric(nn.Module, ABC):
         )

     def _filter_kwargs(self, **kwargs: Any) -> Dict[str, Any]:
-        """filter kwargs such that they match the update signature of the
-        metric."""
+        """filter kwargs such that they match the update signature of the metric."""

         # filter all parameters based on update signature except those of
         # type VAR_POSITIONAL (*args) and VAR_KEYWORD (**kwargs)
@ -638,8 +631,7 @@ def _neg(x: Tensor) -> Tensor:


 class CompositionalMetric(Metric):
-    """Composition of two metrics with a specific operator which will be
-    executed upon metrics compute."""
+    """Composition of two metrics with a specific operator which will be executed upon metrics compute."""

     def __init__(
         self,
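The ``update``/``compute`` hooks touched above are the extension points of ``Metric``; a minimal custom-metric sketch (the class name and states are hypothetical, not part of the library):

import torch
from torchmetrics import Metric

class MeanAbsoluteErrorSketch(Metric):
    """Toy metric: accumulates |preds - target| and averages in compute()."""

    def __init__(self):
        super().__init__()
        # states are synced across processes with the given reduce op
        self.add_state("abs_error", default=torch.tensor(0.0), dist_reduce_fx="sum")
        self.add_state("total", default=torch.tensor(0), dist_reduce_fx="sum")

    def update(self, preds: torch.Tensor, target: torch.Tensor) -> None:
        self.abs_error += torch.abs(preds - target).sum()
        self.total += target.numel()

    def compute(self) -> torch.Tensor:
        return self.abs_error / self.total

metric = MeanAbsoluteErrorSketch()
metric.update(torch.tensor([2.0, 3.0]), torch.tensor([2.5, 2.0]))
print(metric.compute())  # tensor(0.7500)
metric.reset()           # states go back to their defaults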
@ -91,14 +91,11 @@ class RetrievalFallOut(RetrievalMetric):
         self.k = k

     def compute(self) -> Tensor:
-        """First concat state `indexes`, `preds` and `target` since they were
-        stored as lists.
+        """First concat state `indexes`, `preds` and `target` since they were stored as lists.

-        After that, compute list of groups that will help in keeping
-        together predictions about the same query. Finally, for each
-        group compute the `_metric` if the number of negative targets is
-        at least 1, otherwise behave as specified by
-        `self.empty_target_action`.
+        After that, compute list of groups that will help in keeping together predictions about the same query. Finally,
+        for each group compute the `_metric` if the number of negative targets is at least 1, otherwise behave as
+        specified by `self.empty_target_action`.
         """
         indexes = torch.cat(self.indexes, dim=0)
         preds = torch.cat(self.preds, dim=0)
@ -25,8 +25,7 @@ from torchmetrics.utilities.data import get_group_indexes


 class RetrievalMetric(Metric, ABC):
-    """Works with binary target data. Accepts float predictions from a model
-    output.
+    """Works with binary target data. Accepts float predictions from a model output.

     Forward accepts

@ -96,8 +95,7 @@ class RetrievalMetric(Metric, ABC):
         self.add_state("target", default=[], dist_reduce_fx=None)

     def update(self, preds: Tensor, target: Tensor, indexes: Tensor) -> None:  # type: ignore
-        """Check shape, check and convert dtypes, flatten and add to
-        accumulators."""
+        """Check shape, check and convert dtypes, flatten and add to accumulators."""
         if indexes is None:
             raise ValueError("Argument `indexes` cannot be None")

@ -110,14 +108,11 @@ class RetrievalMetric(Metric, ABC):
         self.target.append(target)

     def compute(self) -> Tensor:
-        """First concat state ``indexes``, ``preds`` and ``target`` since they
-        were stored as lists.
+        """First concat state ``indexes``, ``preds`` and ``target`` since they were stored as lists.

-        After that, compute list of groups that will help in keeping
-        together predictions about the same query. Finally, for each
-        group compute the ``_metric`` if the number of positive targets
-        is at least 1, otherwise behave as specified by
-        ``self.empty_target_action``.
+        After that, compute list of groups that will help in keeping together predictions about the same query. Finally,
+        for each group compute the ``_metric`` if the number of positive targets is at least 1, otherwise behave as
+        specified by ``self.empty_target_action``.
         """
         indexes = torch.cat(self.indexes, dim=0)
         preds = torch.cat(self.preds, dim=0)
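As the ``update`` signature above shows, retrieval metrics group samples by query via ``indexes``; a usage sketch with ``RetrievalFallOut`` (the subclass from the earlier hunk), assuming it is importable from the package root:

import torch
from torchmetrics import RetrievalFallOut

metric = RetrievalFallOut(k=2)
# ``indexes`` maps each prediction to its query: two docs for query 0, three for query 1
indexes = torch.tensor([0, 0, 1, 1, 1])
preds = torch.tensor([0.2, 0.3, 0.5, 0.1, 0.3])
target = torch.tensor([False, True, False, True, True])

metric.update(preds, target, indexes=indexes)
print(metric.compute())  # fall-out averaged over the two query groups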
@ -26,8 +26,8 @@ from torchmetrics.functional.text.bleu import _bleu_score_compute, _bleu_score_u


 class BLEUScore(Metric):
-    """Calculate `BLEU score <https://en.wikipedia.org/wiki/BLEU>`_ of machine
-    translated text with one or more references.
+    """Calculate `BLEU score <https://en.wikipedia.org/wiki/BLEU>`_ of machine translated text with one or more
+    references.

     Args:
         n_gram:
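The modular counterpart of ``bleu_score`` accumulates over batches; a sketch assuming the same reference-first corpus format as the functional version:

from torchmetrics import BLEUScore

bleu = BLEUScore(n_gram=4)
translate_corpus = ["the cat is on the mat".split()]
reference_corpus = [["a cat is on the mat".split()]]

# update() can be called once per batch; compute() aggregates all of them
bleu.update(reference_corpus, translate_corpus)
print(bleu.compute())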
@ -27,8 +27,7 @@ else:


 class ROUGEScore(Metric):
-    """Calculate `ROUGE score <https://en.wikipedia.org/wiki/ROUGE_(metric)>`_,
-    used for automatic summarization.
+    """Calculate `ROUGE score <https://en.wikipedia.org/wiki/ROUGE_(metric)>`_, used for automatic summarization.

     Args:
         newline_sep:
@ -21,15 +21,13 @@ from torchmetrics.utilities.enums import DataType


 def _check_same_shape(preds: Tensor, target: Tensor) -> None:
-    """Check that predictions and target have the same shape, else raise
-    error."""
+    """Check that predictions and target have the same shape, else raise error."""
     if preds.shape != target.shape:
         raise RuntimeError("Predictions and targets are expected to have the same shape")


 def _basic_input_validation(preds: Tensor, target: Tensor, threshold: float, multiclass: Optional[bool]) -> None:
-    """Perform basic validation of inputs that does not require deducing any
-    information of the type of inputs."""
+    """Perform basic validation of inputs that does not require deducing any information of the type of inputs."""

     if target.is_floating_point():
         raise ValueError("The `target` has to be an integer tensor.")
@ -51,14 +49,12 @@ def _basic_input_validation(preds: Tensor, target: Tensor, threshold: float, mul


 def _check_shape_and_type_consistency(preds: Tensor, target: Tensor) -> Tuple[DataType, int]:
-    """This checks that the shape and type of inputs are consistent with each
-    other and fall into one of the allowed input types (see the documentation
-    of docstring of ``_input_format_classification``). It does not check for
-    consistency of number of classes, other functions take care of that.
+    """This checks that the shape and type of inputs are consistent with each other and fall into one of the
+    allowed input types (see the documentation of docstring of ``_input_format_classification``). It does not check
+    for consistency of number of classes, other functions take care of that.

-    It returns the name of the case in which the inputs fall, and the
-    implied number of classes (from the ``C`` dim for multi-class data,
-    or extra dim(s) for multi-label data).
+    It returns the name of the case in which the inputs fall, and the implied number of classes (from the ``C`` dim for
+    multi-class data, or extra dim(s) for multi-label data).
     """

     preds_float = preds.is_floating_point()
@ -111,8 +107,7 @@ def _check_shape_and_type_consistency(preds: Tensor, target: Tensor) -> Tuple[Da


 def _check_num_classes_binary(num_classes: int, multiclass: Optional[bool]) -> None:
-    """This checks that the consistency of `num_classes` with the data and
-    `multiclass` param for binary data."""
+    """This checks that the consistency of `num_classes` with the data and `multiclass` param for binary data."""

     if num_classes > 2:
         raise ValueError("Your data is binary, but `num_classes` is larger than 2.")
@ -136,8 +131,8 @@ def _check_num_classes_mc(
     multiclass: Optional[bool],
     implied_classes: int,
 ) -> None:
-    """This checks that the consistency of `num_classes` with the data and
-    `multiclass` param for (multi-dimensional) multi-class data."""
+    """This checks that the consistency of `num_classes` with the data and `multiclass` param for (multi-
+    dimensional) multi-class data."""

     if num_classes == 1 and multiclass is not False:
         raise ValueError(
@ -161,8 +156,8 @@ def _check_num_classes_mc(


 def _check_num_classes_ml(num_classes: int, multiclass: Optional[bool], implied_classes: int) -> None:
-    """This checks that the consistency of `num_classes` with the data and
-    `multiclass` param for multi-label data."""
+    """This checks that the consistency of `num_classes` with the data and `multiclass` param for multi-label
+    data."""

     if multiclass and num_classes != 2:
         raise ValueError(
@ -494,8 +489,7 @@ def _check_retrieval_functional_inputs(
     target: Tensor,
     allow_non_binary_target: bool = False,
 ) -> Tuple[Tensor, Tensor]:
-    """Check ``preds`` and ``target`` tensors are of the same shape and of the
-    correct dtype.
+    """Check ``preds`` and ``target`` tensors are of the same shape and of the correct dtype.

     Args:
         preds: either tensor with scores/logits
@ -535,8 +529,7 @@ def _check_retrieval_inputs(
     target: Tensor,
     allow_non_binary_target: bool = False,
 ) -> Tuple[Tensor, Tensor, Tensor]:
-    """Check ``indexes``, ``preds`` and ``target`` tensors are of the same
-    shape and of the correct dtype.
+    """Check ``indexes``, ``preds`` and ``target`` tensors are of the same shape and of the correct dtype.

     Args:
         indexes: tensor with queries indexes
@ -76,8 +76,7 @@ def to_onehot(


 def select_topk(prob_tensor: Tensor, topk: int = 1, dim: int = 1) -> Tensor:
-    """Convert a probability tensor to binary by selecting top-k highest
-    entries.
+    """Convert a probability tensor to binary by selecting top-k highest entries.

     Args:
         prob_tensor: dense tensor of shape ``[..., C, ...]``, where ``C`` is in the
@ -122,8 +121,7 @@ def get_num_classes(
     target: Tensor,
     num_classes: Optional[int] = None,
 ) -> int:
-    """Calculates the number of classes for a given prediction and target
-    tensor.
+    """Calculates the number of classes for a given prediction and target tensor.

     Args:
         preds: predicted values
@ -200,8 +198,8 @@ def apply_to_collection(


 def get_group_indexes(indexes: Tensor) -> List[Tensor]:
-    """Given an integer `torch.Tensor` `indexes`, return a `torch.Tensor` of
-    indexes for each different value in `indexes`.
+    """Given an integer `torch.Tensor` `indexes`, return a `torch.Tensor` of indexes for each different value in
+    `indexes`.

     Args:
         indexes: a `torch.Tensor`
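A quick sketch of ``select_topk`` from the first hunk above; the behavior follows its docstring (one-hot marks on the top-k entries per row):

import torch
from torchmetrics.utilities.data import select_topk

probs = torch.tensor([[1.1, 2.0, 3.0], [2.0, 1.0, 0.5]])
# Mark the two highest entries per row with 1, everything else 0
print(select_topk(probs, topk=2))
# tensor([[0, 1, 1],
#         [1, 1, 0]], dtype=torch.int32)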
@ -94,11 +94,9 @@ def _simple_gather_all_tensors(result: Tensor, group: Any, world_size: int) -> L


 def gather_all_tensors(result: Tensor, group: Optional[Any] = None) -> List[Tensor]:
-    """Function to gather all tensors from several ddp processes onto a list
-    that is broadcasted to all processes. Works on tensors that have the same
-    number of dimensions, but where each dimension may differ. In this case
-    tensors are padded, gathered and then trimmed to secure equal workload for
-    all processes.
+    """Function to gather all tensors from several ddp processes onto a list that is broadcasted to all processes.
+    Works on tensors that have the same number of dimensions, but where each dimension may differ. In this case
+    tensors are padded, gathered and then trimmed to secure equal workload for all processes.

     Args:
         result: the value to sync
@ -16,8 +16,7 @@ from typing import Optional, Union


 class EnumStr(str, Enum):
-    """Type of any enumerator with allowed comparison to string invariant to
-    cases.
+    """Type of any enumerator with allowed comparison to string invariant to cases.

     Example:
         >>> class MyEnum(EnumStr):
@ -157,9 +157,8 @@ class BootStrapper(Metric):
     def compute(self) -> Dict[str, Tensor]:
         """Computes the bootstrapped metric values.

-        Allways returns a dict of tensors, which can contain the
-        following keys: ``mean``, ``std``, ``quantile`` and ``raw``
-        depending on how the class was initialized
+        Allways returns a dict of tensors, which can contain the following keys: ``mean``, ``std``, ``quantile`` and
+        ``raw`` depending on how the class was initialized
         """
         computed_vals = torch.stack([m.compute() for m in self.metrics], dim=0)
         output_dict = {}
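A usage sketch for the wrapper above; the import path and constructor flags are assumptions based on the keys named in the docstring:

import torch
from torchmetrics import Accuracy
from torchmetrics.wrappers import BootStrapper

bootstrap = BootStrapper(Accuracy(), num_bootstraps=10, mean=True, std=True)
bootstrap.update(torch.randint(2, (100,)), torch.randint(2, (100,)))

out = bootstrap.compute()
print(out["mean"], out["std"])  # keys depend on how the class was initialized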