simplify Tensor import (#108)
* versions * simplify Tensor * simplify tensor * yapf
This commit is contained in:
Родитель
1fcf9fc235
Коммит
a1e50ca62b
|
@ -15,6 +15,7 @@ from operator import neg, pos
|
|||
|
||||
import pytest
|
||||
import torch
|
||||
from torch import tensor
|
||||
|
||||
from tests.helpers import _MARK_TORCH_MIN_1_4, _MARK_TORCH_MIN_1_5, _MARK_TORCH_MIN_1_6
|
||||
from torchmetrics.metric import CompositionalMetric, Metric
|
||||
|
@ -31,7 +32,7 @@ class DummyMetric(Metric):
|
|||
self._num_updates += 1
|
||||
|
||||
def compute(self):
|
||||
return torch.tensor(self._val_to_return)
|
||||
return tensor(self._val_to_return)
|
||||
|
||||
def reset(self):
|
||||
self._num_updates = 0
|
||||
|
@ -41,10 +42,10 @@ class DummyMetric(Metric):
|
|||
@pytest.mark.parametrize(
|
||||
["second_operand", "expected_result"],
|
||||
[
|
||||
(DummyMetric(2), torch.tensor(4)),
|
||||
(2, torch.tensor(4)),
|
||||
(2.0, torch.tensor(4.0)),
|
||||
pytest.param(torch.tensor(2), torch.tensor(4), marks=pytest.mark.skipif(**_MARK_TORCH_MIN_1_4)),
|
||||
(DummyMetric(2), tensor(4)),
|
||||
(2, tensor(4)),
|
||||
(2.0, tensor(4.0)),
|
||||
pytest.param(tensor(2), tensor(4), marks=pytest.mark.skipif(**_MARK_TORCH_MIN_1_4)),
|
||||
],
|
||||
)
|
||||
def test_metrics_add(second_operand, expected_result):
|
||||
|
@ -62,7 +63,7 @@ def test_metrics_add(second_operand, expected_result):
|
|||
|
||||
@pytest.mark.parametrize(
|
||||
["second_operand", "expected_result"],
|
||||
[(DummyMetric(3), torch.tensor(2)), (3, torch.tensor(2)), (3, torch.tensor(2)), (torch.tensor(3), torch.tensor(2))],
|
||||
[(DummyMetric(3), tensor(2)), (3, tensor(2)), (3, tensor(2)), (tensor(3), tensor(2))],
|
||||
)
|
||||
@pytest.mark.skipif(**_MARK_TORCH_MIN_1_5)
|
||||
def test_metrics_and(second_operand, expected_result):
|
||||
|
@ -81,10 +82,10 @@ def test_metrics_and(second_operand, expected_result):
|
|||
@pytest.mark.parametrize(
|
||||
["second_operand", "expected_result"],
|
||||
[
|
||||
(DummyMetric(2), torch.tensor(True)),
|
||||
(2, torch.tensor(True)),
|
||||
(2.0, torch.tensor(True)),
|
||||
(torch.tensor(2), torch.tensor(True)),
|
||||
(DummyMetric(2), tensor(True)),
|
||||
(2, tensor(True)),
|
||||
(2.0, tensor(True)),
|
||||
(tensor(2), tensor(True)),
|
||||
],
|
||||
)
|
||||
def test_metrics_eq(second_operand, expected_result):
|
||||
|
@ -101,10 +102,10 @@ def test_metrics_eq(second_operand, expected_result):
|
|||
@pytest.mark.parametrize(
|
||||
["second_operand", "expected_result"],
|
||||
[
|
||||
(DummyMetric(2), torch.tensor(2)),
|
||||
(2, torch.tensor(2)),
|
||||
(2.0, torch.tensor(2.0)),
|
||||
(torch.tensor(2), torch.tensor(2)),
|
||||
(DummyMetric(2), tensor(2)),
|
||||
(2, tensor(2)),
|
||||
(2.0, tensor(2.0)),
|
||||
(tensor(2), tensor(2)),
|
||||
],
|
||||
)
|
||||
@pytest.mark.skipif(**_MARK_TORCH_MIN_1_5)
|
||||
|
@ -121,10 +122,10 @@ def test_metrics_floordiv(second_operand, expected_result):
|
|||
@pytest.mark.parametrize(
|
||||
["second_operand", "expected_result"],
|
||||
[
|
||||
(DummyMetric(2), torch.tensor(True)),
|
||||
(2, torch.tensor(True)),
|
||||
(2.0, torch.tensor(True)),
|
||||
(torch.tensor(2), torch.tensor(True)),
|
||||
(DummyMetric(2), tensor(True)),
|
||||
(2, tensor(True)),
|
||||
(2.0, tensor(True)),
|
||||
(tensor(2), tensor(True)),
|
||||
],
|
||||
)
|
||||
def test_metrics_ge(second_operand, expected_result):
|
||||
|
@ -141,10 +142,10 @@ def test_metrics_ge(second_operand, expected_result):
|
|||
@pytest.mark.parametrize(
|
||||
["second_operand", "expected_result"],
|
||||
[
|
||||
(DummyMetric(2), torch.tensor(True)),
|
||||
(2, torch.tensor(True)),
|
||||
(2.0, torch.tensor(True)),
|
||||
(torch.tensor(2), torch.tensor(True)),
|
||||
(DummyMetric(2), tensor(True)),
|
||||
(2, tensor(True)),
|
||||
(2.0, tensor(True)),
|
||||
(tensor(2), tensor(True)),
|
||||
],
|
||||
)
|
||||
def test_metrics_gt(second_operand, expected_result):
|
||||
|
@ -161,10 +162,10 @@ def test_metrics_gt(second_operand, expected_result):
|
|||
@pytest.mark.parametrize(
|
||||
["second_operand", "expected_result"],
|
||||
[
|
||||
(DummyMetric(2), torch.tensor(False)),
|
||||
(2, torch.tensor(False)),
|
||||
(2.0, torch.tensor(False)),
|
||||
(torch.tensor(2), torch.tensor(False)),
|
||||
(DummyMetric(2), tensor(False)),
|
||||
(2, tensor(False)),
|
||||
(2.0, tensor(False)),
|
||||
(tensor(2), tensor(False)),
|
||||
],
|
||||
)
|
||||
def test_metrics_le(second_operand, expected_result):
|
||||
|
@ -181,10 +182,10 @@ def test_metrics_le(second_operand, expected_result):
|
|||
@pytest.mark.parametrize(
|
||||
["second_operand", "expected_result"],
|
||||
[
|
||||
(DummyMetric(2), torch.tensor(False)),
|
||||
(2, torch.tensor(False)),
|
||||
(2.0, torch.tensor(False)),
|
||||
(torch.tensor(2), torch.tensor(False)),
|
||||
(DummyMetric(2), tensor(False)),
|
||||
(2, tensor(False)),
|
||||
(2.0, tensor(False)),
|
||||
(tensor(2), tensor(False)),
|
||||
],
|
||||
)
|
||||
def test_metrics_lt(second_operand, expected_result):
|
||||
|
@ -200,7 +201,7 @@ def test_metrics_lt(second_operand, expected_result):
|
|||
|
||||
@pytest.mark.parametrize(
|
||||
["second_operand", "expected_result"],
|
||||
[(DummyMetric([2, 2, 2]), torch.tensor(12)), (torch.tensor([2, 2, 2]), torch.tensor(12))],
|
||||
[(DummyMetric([2, 2, 2]), tensor(12)), (tensor([2, 2, 2]), tensor(12))],
|
||||
)
|
||||
def test_metrics_matmul(second_operand, expected_result):
|
||||
first_metric = DummyMetric([2, 2, 2])
|
||||
|
@ -215,10 +216,10 @@ def test_metrics_matmul(second_operand, expected_result):
|
|||
@pytest.mark.parametrize(
|
||||
["second_operand", "expected_result"],
|
||||
[
|
||||
(DummyMetric(2), torch.tensor(1)),
|
||||
(2, torch.tensor(1)),
|
||||
(2.0, torch.tensor(1)),
|
||||
(torch.tensor(2), torch.tensor(1)),
|
||||
(DummyMetric(2), tensor(1)),
|
||||
(2, tensor(1)),
|
||||
(2.0, tensor(1)),
|
||||
(tensor(2), tensor(1)),
|
||||
],
|
||||
)
|
||||
def test_metrics_mod(second_operand, expected_result):
|
||||
|
@ -234,10 +235,10 @@ def test_metrics_mod(second_operand, expected_result):
|
|||
@pytest.mark.parametrize(
|
||||
["second_operand", "expected_result"],
|
||||
[
|
||||
(DummyMetric(2), torch.tensor(4)),
|
||||
(2, torch.tensor(4)),
|
||||
(2.0, torch.tensor(4.0)),
|
||||
pytest.param(torch.tensor(2), torch.tensor(4), marks=pytest.mark.skipif(**_MARK_TORCH_MIN_1_4)),
|
||||
(DummyMetric(2), tensor(4)),
|
||||
(2, tensor(4)),
|
||||
(2.0, tensor(4.0)),
|
||||
pytest.param(tensor(2), tensor(4), marks=pytest.mark.skipif(**_MARK_TORCH_MIN_1_4)),
|
||||
],
|
||||
)
|
||||
def test_metrics_mul(second_operand, expected_result):
|
||||
|
@ -256,10 +257,10 @@ def test_metrics_mul(second_operand, expected_result):
|
|||
@pytest.mark.parametrize(
|
||||
["second_operand", "expected_result"],
|
||||
[
|
||||
(DummyMetric(2), torch.tensor(False)),
|
||||
(2, torch.tensor(False)),
|
||||
(2.0, torch.tensor(False)),
|
||||
(torch.tensor(2), torch.tensor(False)),
|
||||
(DummyMetric(2), tensor(False)),
|
||||
(2, tensor(False)),
|
||||
(2.0, tensor(False)),
|
||||
(tensor(2), tensor(False)),
|
||||
],
|
||||
)
|
||||
def test_metrics_ne(second_operand, expected_result):
|
||||
|
@ -275,7 +276,7 @@ def test_metrics_ne(second_operand, expected_result):
|
|||
|
||||
@pytest.mark.parametrize(
|
||||
["second_operand", "expected_result"],
|
||||
[(DummyMetric([1, 0, 3]), torch.tensor([-1, -2, 3])), (torch.tensor([1, 0, 3]), torch.tensor([-1, -2, 3]))],
|
||||
[(DummyMetric([1, 0, 3]), tensor([-1, -2, 3])), (tensor([1, 0, 3]), tensor([-1, -2, 3]))],
|
||||
)
|
||||
@pytest.mark.skipif(**_MARK_TORCH_MIN_1_5)
|
||||
def test_metrics_or(second_operand, expected_result):
|
||||
|
@ -294,10 +295,10 @@ def test_metrics_or(second_operand, expected_result):
|
|||
@pytest.mark.parametrize(
|
||||
["second_operand", "expected_result"],
|
||||
[
|
||||
pytest.param(DummyMetric(2), torch.tensor(4)),
|
||||
pytest.param(2, torch.tensor(4)),
|
||||
pytest.param(2.0, torch.tensor(4.0), marks=pytest.mark.skipif(**_MARK_TORCH_MIN_1_6)),
|
||||
pytest.param(torch.tensor(2), torch.tensor(4)),
|
||||
pytest.param(DummyMetric(2), tensor(4)),
|
||||
pytest.param(2, tensor(4)),
|
||||
pytest.param(2.0, tensor(4.0), marks=pytest.mark.skipif(**_MARK_TORCH_MIN_1_6)),
|
||||
pytest.param(tensor(2), tensor(4)),
|
||||
],
|
||||
)
|
||||
def test_metrics_pow(second_operand, expected_result):
|
||||
|
@ -312,7 +313,7 @@ def test_metrics_pow(second_operand, expected_result):
|
|||
|
||||
@pytest.mark.parametrize(
|
||||
["first_operand", "expected_result"],
|
||||
[(5, torch.tensor(2)), (5.0, torch.tensor(2.0)), (torch.tensor(5), torch.tensor(2))],
|
||||
[(5, tensor(2)), (5.0, tensor(2.0)), (tensor(5), tensor(2))],
|
||||
)
|
||||
@pytest.mark.skipif(**_MARK_TORCH_MIN_1_5)
|
||||
def test_metrics_rfloordiv(first_operand, expected_result):
|
||||
|
@ -324,10 +325,8 @@ def test_metrics_rfloordiv(first_operand, expected_result):
|
|||
assert torch.allclose(expected_result, final_rfloordiv.compute())
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
["first_operand", "expected_result"],
|
||||
[pytest.param(torch.tensor([2, 2, 2]), torch.tensor(12), marks=pytest.mark.skipif(**_MARK_TORCH_MIN_1_4))]
|
||||
)
|
||||
@pytest.mark.parametrize(["first_operand", "expected_result"],
|
||||
[pytest.param(tensor([2, 2, 2]), tensor(12), marks=pytest.mark.skipif(**_MARK_TORCH_MIN_1_4))])
|
||||
def test_metrics_rmatmul(first_operand, expected_result):
|
||||
second_operand = DummyMetric([2, 2, 2])
|
||||
|
||||
|
@ -338,10 +337,8 @@ def test_metrics_rmatmul(first_operand, expected_result):
|
|||
assert torch.allclose(expected_result, final_rmatmul.compute())
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
["first_operand", "expected_result"],
|
||||
[pytest.param(torch.tensor(2), torch.tensor(2), marks=pytest.mark.skipif(**_MARK_TORCH_MIN_1_4))]
|
||||
)
|
||||
@pytest.mark.parametrize(["first_operand", "expected_result"],
|
||||
[pytest.param(tensor(2), tensor(2), marks=pytest.mark.skipif(**_MARK_TORCH_MIN_1_4))])
|
||||
def test_metrics_rmod(first_operand, expected_result):
|
||||
second_operand = DummyMetric(5)
|
||||
|
||||
|
@ -355,9 +352,9 @@ def test_metrics_rmod(first_operand, expected_result):
|
|||
@pytest.mark.parametrize(
|
||||
"first_operand,expected_result",
|
||||
[
|
||||
pytest.param(DummyMetric(2), torch.tensor(4)),
|
||||
pytest.param(2, torch.tensor(4)),
|
||||
pytest.param(2.0, torch.tensor(4.0), marks=pytest.mark.skipif(**_MARK_TORCH_MIN_1_6)),
|
||||
pytest.param(DummyMetric(2), tensor(4)),
|
||||
pytest.param(2, tensor(4)),
|
||||
pytest.param(2.0, tensor(4.0), marks=pytest.mark.skipif(**_MARK_TORCH_MIN_1_6)),
|
||||
],
|
||||
)
|
||||
def test_metrics_rpow(first_operand, expected_result):
|
||||
|
@ -373,10 +370,10 @@ def test_metrics_rpow(first_operand, expected_result):
|
|||
@pytest.mark.parametrize(
|
||||
["first_operand", "expected_result"],
|
||||
[
|
||||
(DummyMetric(3), torch.tensor(1)),
|
||||
(3, torch.tensor(1)),
|
||||
(3.0, torch.tensor(1.0)),
|
||||
pytest.param(torch.tensor(3), torch.tensor(1), marks=pytest.mark.skipif(**_MARK_TORCH_MIN_1_4)),
|
||||
(DummyMetric(3), tensor(1)),
|
||||
(3, tensor(1)),
|
||||
(3.0, tensor(1.0)),
|
||||
pytest.param(tensor(3), tensor(1), marks=pytest.mark.skipif(**_MARK_TORCH_MIN_1_4)),
|
||||
],
|
||||
)
|
||||
def test_metrics_rsub(first_operand, expected_result):
|
||||
|
@ -392,10 +389,10 @@ def test_metrics_rsub(first_operand, expected_result):
|
|||
@pytest.mark.parametrize(
|
||||
["first_operand", "expected_result"],
|
||||
[
|
||||
(DummyMetric(6), torch.tensor(2.0)),
|
||||
(6, torch.tensor(2.0)),
|
||||
(6.0, torch.tensor(2.0)),
|
||||
(torch.tensor(6), torch.tensor(2.0)),
|
||||
(DummyMetric(6), tensor(2.0)),
|
||||
(6, tensor(2.0)),
|
||||
(6.0, tensor(2.0)),
|
||||
(tensor(6), tensor(2.0)),
|
||||
],
|
||||
)
|
||||
@pytest.mark.skipif(**_MARK_TORCH_MIN_1_5)
|
||||
|
@ -412,10 +409,10 @@ def test_metrics_rtruediv(first_operand, expected_result):
|
|||
@pytest.mark.parametrize(
|
||||
["second_operand", "expected_result"],
|
||||
[
|
||||
(DummyMetric(2), torch.tensor(1)),
|
||||
(2, torch.tensor(1)),
|
||||
(2.0, torch.tensor(1.0)),
|
||||
(torch.tensor(2), torch.tensor(1)),
|
||||
(DummyMetric(2), tensor(1)),
|
||||
(2, tensor(1)),
|
||||
(2.0, tensor(1.0)),
|
||||
(tensor(2), tensor(1)),
|
||||
],
|
||||
)
|
||||
def test_metrics_sub(second_operand, expected_result):
|
||||
|
@ -431,10 +428,10 @@ def test_metrics_sub(second_operand, expected_result):
|
|||
@pytest.mark.parametrize(
|
||||
["second_operand", "expected_result"],
|
||||
[
|
||||
(DummyMetric(3), torch.tensor(2.0)),
|
||||
(3, torch.tensor(2.0)),
|
||||
(3.0, torch.tensor(2.0)),
|
||||
(torch.tensor(3), torch.tensor(2.0)),
|
||||
(DummyMetric(3), tensor(2.0)),
|
||||
(3, tensor(2.0)),
|
||||
(3.0, tensor(2.0)),
|
||||
(tensor(3), tensor(2.0)),
|
||||
],
|
||||
)
|
||||
@pytest.mark.skipif(**_MARK_TORCH_MIN_1_5)
|
||||
|
@ -450,7 +447,7 @@ def test_metrics_truediv(second_operand, expected_result):
|
|||
|
||||
@pytest.mark.parametrize(
|
||||
["second_operand", "expected_result"],
|
||||
[(DummyMetric([1, 0, 3]), torch.tensor([-2, -2, 0])), (torch.tensor([1, 0, 3]), torch.tensor([-2, -2, 0]))],
|
||||
[(DummyMetric([1, 0, 3]), tensor([-2, -2, 0])), (tensor([1, 0, 3]), tensor([-2, -2, 0]))],
|
||||
)
|
||||
@pytest.mark.skipif(**_MARK_TORCH_MIN_1_5)
|
||||
def test_metrics_xor(second_operand, expected_result):
|
||||
|
@ -473,7 +470,7 @@ def test_metrics_abs():
|
|||
|
||||
assert isinstance(final_abs, CompositionalMetric)
|
||||
|
||||
assert torch.allclose(torch.tensor(1), final_abs.compute())
|
||||
assert torch.allclose(tensor(1), final_abs.compute())
|
||||
|
||||
|
||||
def test_metrics_invert():
|
||||
|
@ -481,7 +478,7 @@ def test_metrics_invert():
|
|||
|
||||
final_inverse = ~first_metric
|
||||
assert isinstance(final_inverse, CompositionalMetric)
|
||||
assert torch.allclose(torch.tensor(-2), final_inverse.compute())
|
||||
assert torch.allclose(tensor(-2), final_inverse.compute())
|
||||
|
||||
|
||||
def test_metrics_neg():
|
||||
|
@ -489,7 +486,7 @@ def test_metrics_neg():
|
|||
|
||||
final_neg = neg(first_metric)
|
||||
assert isinstance(final_neg, CompositionalMetric)
|
||||
assert torch.allclose(torch.tensor(-1), final_neg.compute())
|
||||
assert torch.allclose(tensor(-1), final_neg.compute())
|
||||
|
||||
|
||||
def test_metrics_pos():
|
||||
|
@ -497,7 +494,7 @@ def test_metrics_pos():
|
|||
|
||||
final_pos = pos(first_metric)
|
||||
assert isinstance(final_pos, CompositionalMetric)
|
||||
assert torch.allclose(torch.tensor(1), final_pos.compute())
|
||||
assert torch.allclose(tensor(1), final_pos.compute())
|
||||
|
||||
|
||||
def test_compositional_metrics_update():
|
||||
|
|
|
@ -15,6 +15,7 @@ import sys
|
|||
|
||||
import pytest
|
||||
import torch
|
||||
from torch import tensor
|
||||
|
||||
from tests.helpers.testers import DummyMetric, setup_ddp
|
||||
from torchmetrics import Metric
|
||||
|
@ -26,7 +27,7 @@ def _test_ddp_sum(rank, worldsize):
|
|||
setup_ddp(rank, worldsize)
|
||||
dummy = DummyMetric()
|
||||
dummy._reductions = {"foo": torch.sum}
|
||||
dummy.foo = torch.tensor(1)
|
||||
dummy.foo = tensor(1)
|
||||
dummy._sync_dist()
|
||||
|
||||
assert dummy.foo == worldsize
|
||||
|
@ -36,21 +37,21 @@ def _test_ddp_cat(rank, worldsize):
|
|||
setup_ddp(rank, worldsize)
|
||||
dummy = DummyMetric()
|
||||
dummy._reductions = {"foo": torch.cat}
|
||||
dummy.foo = [torch.tensor([1])]
|
||||
dummy.foo = [tensor([1])]
|
||||
dummy._sync_dist()
|
||||
|
||||
assert torch.all(torch.eq(dummy.foo, torch.tensor([1, 1])))
|
||||
assert torch.all(torch.eq(dummy.foo, tensor([1, 1])))
|
||||
|
||||
|
||||
def _test_ddp_sum_cat(rank, worldsize):
|
||||
setup_ddp(rank, worldsize)
|
||||
dummy = DummyMetric()
|
||||
dummy._reductions = {"foo": torch.cat, "bar": torch.sum}
|
||||
dummy.foo = [torch.tensor([1])]
|
||||
dummy.bar = torch.tensor(1)
|
||||
dummy.foo = [tensor([1])]
|
||||
dummy.bar = tensor(1)
|
||||
dummy._sync_dist()
|
||||
|
||||
assert torch.all(torch.eq(dummy.foo, torch.tensor([1, 1])))
|
||||
assert torch.all(torch.eq(dummy.foo, tensor([1, 1])))
|
||||
assert dummy.bar == worldsize
|
||||
|
||||
|
||||
|
|
|
@ -13,15 +13,15 @@
|
|||
# limitations under the License.
|
||||
import pickle
|
||||
from collections import OrderedDict
|
||||
from distutils.version import LooseVersion
|
||||
|
||||
import cloudpickle
|
||||
import numpy as np
|
||||
import pytest
|
||||
import torch
|
||||
from torch import nn
|
||||
from torch import nn, tensor
|
||||
|
||||
from tests.helpers.testers import DummyListMetric, DummyMetric, DummyMetricSum
|
||||
from torchmetrics.utilities.imports import _TORCH_LOWER_1_6
|
||||
|
||||
torch.manual_seed(42)
|
||||
|
||||
|
@ -33,23 +33,23 @@ def test_inherit():
|
|||
def test_add_state():
|
||||
a = DummyMetric()
|
||||
|
||||
a.add_state("a", torch.tensor(0), "sum")
|
||||
assert a._reductions["a"](torch.tensor([1, 1])) == 2
|
||||
a.add_state("a", tensor(0), "sum")
|
||||
assert a._reductions["a"](tensor([1, 1])) == 2
|
||||
|
||||
a.add_state("b", torch.tensor(0), "mean")
|
||||
assert np.allclose(a._reductions["b"](torch.tensor([1.0, 2.0])).numpy(), 1.5)
|
||||
a.add_state("b", tensor(0), "mean")
|
||||
assert np.allclose(a._reductions["b"](tensor([1.0, 2.0])).numpy(), 1.5)
|
||||
|
||||
a.add_state("c", torch.tensor(0), "cat")
|
||||
assert a._reductions["c"]([torch.tensor([1]), torch.tensor([1])]).shape == (2, )
|
||||
a.add_state("c", tensor(0), "cat")
|
||||
assert a._reductions["c"]([tensor([1]), tensor([1])]).shape == (2, )
|
||||
|
||||
with pytest.raises(ValueError):
|
||||
a.add_state("d1", torch.tensor(0), 'xyz')
|
||||
a.add_state("d1", tensor(0), 'xyz')
|
||||
|
||||
with pytest.raises(ValueError):
|
||||
a.add_state("d2", torch.tensor(0), 42)
|
||||
a.add_state("d2", tensor(0), 42)
|
||||
|
||||
with pytest.raises(ValueError):
|
||||
a.add_state("d3", [torch.tensor(0)], 'sum')
|
||||
a.add_state("d3", [tensor(0)], 'sum')
|
||||
|
||||
with pytest.raises(ValueError):
|
||||
a.add_state("d4", 42, 'sum')
|
||||
|
@ -57,19 +57,19 @@ def test_add_state():
|
|||
def custom_fx(x):
|
||||
return -1
|
||||
|
||||
a.add_state("e", torch.tensor(0), custom_fx)
|
||||
assert a._reductions["e"](torch.tensor([1, 1])) == -1
|
||||
a.add_state("e", tensor(0), custom_fx)
|
||||
assert a._reductions["e"](tensor([1, 1])) == -1
|
||||
|
||||
|
||||
def test_add_state_persistent():
|
||||
a = DummyMetric()
|
||||
|
||||
a.add_state("a", torch.tensor(0), "sum", persistent=True)
|
||||
a.add_state("a", tensor(0), "sum", persistent=True)
|
||||
assert "a" in a.state_dict()
|
||||
|
||||
a.add_state("b", torch.tensor(0), "sum", persistent=False)
|
||||
a.add_state("b", tensor(0), "sum", persistent=False)
|
||||
|
||||
if LooseVersion(torch.__version__) >= LooseVersion("1.6.0"):
|
||||
if _TORCH_LOWER_1_6:
|
||||
assert "b" not in a.state_dict()
|
||||
|
||||
|
||||
|
@ -83,13 +83,13 @@ def test_reset():
|
|||
|
||||
a = A()
|
||||
assert a.x == 0
|
||||
a.x = torch.tensor(5)
|
||||
a.x = tensor(5)
|
||||
a.reset()
|
||||
assert a.x == 0
|
||||
|
||||
b = B()
|
||||
assert isinstance(b.x, list) and len(b.x) == 0
|
||||
b.x = torch.tensor(5)
|
||||
b.x = tensor(5)
|
||||
b.reset()
|
||||
assert isinstance(b.x, list) and len(b.x) == 0
|
||||
|
||||
|
@ -155,10 +155,10 @@ def test_hash():
|
|||
b2 = B()
|
||||
assert hash(b1) == hash(b2)
|
||||
assert isinstance(b1.x, list) and len(b1.x) == 0
|
||||
b1.x.append(torch.tensor(5))
|
||||
b1.x.append(tensor(5))
|
||||
assert isinstance(hash(b1), int) # <- check that nothing crashes
|
||||
assert isinstance(b1.x, list) and len(b1.x) == 1
|
||||
b2.x.append(torch.tensor(5))
|
||||
b2.x.append(tensor(5))
|
||||
# Sanity:
|
||||
assert isinstance(b2.x, list) and len(b2.x) == 1
|
||||
# Now that they have tensor contents, they should have different hashes:
|
||||
|
@ -222,15 +222,15 @@ def test_child_metric_state_dict():
|
|||
def __init__(self):
|
||||
super().__init__()
|
||||
self.metric = DummyMetric()
|
||||
self.metric.add_state('a', torch.tensor(0), persistent=True)
|
||||
self.metric.add_state('a', tensor(0), persistent=True)
|
||||
self.metric.add_state('b', [], persistent=True)
|
||||
self.metric.register_buffer('c', torch.tensor(0))
|
||||
self.metric.register_buffer('c', tensor(0))
|
||||
|
||||
module = TestModule()
|
||||
expected_state_dict = {
|
||||
'metric.a': torch.tensor(0),
|
||||
'metric.a': tensor(0),
|
||||
'metric.b': [],
|
||||
'metric.c': torch.tensor(0),
|
||||
'metric.c': tensor(0),
|
||||
}
|
||||
assert module.state_dict() == expected_state_dict
|
||||
|
||||
|
|
|
@ -17,6 +17,7 @@ import numpy as np
|
|||
import pytest
|
||||
import torch
|
||||
from sklearn.metrics import accuracy_score as sk_accuracy
|
||||
from torch import tensor
|
||||
|
||||
from tests.classification.inputs import _input_binary, _input_binary_prob
|
||||
from tests.classification.inputs import _input_multiclass as _input_mcls
|
||||
|
@ -109,12 +110,12 @@ _l1to4t3_mcls = [_l1to4t3.T, _l1to4t3.T, _l1to4t3.T]
|
|||
|
||||
# The preds in these examples always put highest probability on class 3, second highest on class 2,
|
||||
# third highest on class 1, and lowest on class 0
|
||||
_topk_preds_mcls = torch.tensor([_l1to4t3, _l1to4t3]).float()
|
||||
_topk_target_mcls = torch.tensor([[1, 2, 3], [2, 1, 0]])
|
||||
_topk_preds_mcls = tensor([_l1to4t3, _l1to4t3]).float()
|
||||
_topk_target_mcls = tensor([[1, 2, 3], [2, 1, 0]])
|
||||
|
||||
# This is like for MC case, but one sample in each batch is sabotaged with 0 class prediction :)
|
||||
_topk_preds_mdmc = torch.tensor([_l1to4t3_mcls, _l1to4t3_mcls]).float()
|
||||
_topk_target_mdmc = torch.tensor([[[1, 1, 0], [2, 2, 2], [3, 3, 3]], [[2, 2, 0], [1, 1, 1], [0, 0, 0]]])
|
||||
_topk_preds_mdmc = tensor([_l1to4t3_mcls, _l1to4t3_mcls]).float()
|
||||
_topk_target_mdmc = tensor([[[1, 1, 0], [2, 2, 2], [3, 3, 3]], [[2, 2, 0], [1, 1, 1], [0, 0, 0]]])
|
||||
|
||||
|
||||
# Replace with a proper sk_metric test once sklearn 0.24 hits :)
|
||||
|
|
|
@ -17,6 +17,7 @@ import numpy as np
|
|||
import pytest
|
||||
import torch
|
||||
from sklearn.metrics import auc as _sk_auc
|
||||
from torch import tensor
|
||||
|
||||
from tests.helpers.testers import NUM_BATCHES, MetricTester
|
||||
from torchmetrics.classification.auc import AUC
|
||||
|
@ -43,7 +44,7 @@ for i in range(4):
|
|||
y = y[idx] if i % 2 == 0 else x[idx[::-1]]
|
||||
x = x.reshape(NUM_BATCHES, 8)
|
||||
y = y.reshape(NUM_BATCHES, 8)
|
||||
_examples.append(Input(x=torch.tensor(x), y=torch.tensor(y)))
|
||||
_examples.append(Input(x=tensor(x), y=tensor(y)))
|
||||
|
||||
|
||||
@pytest.mark.parametrize("x, y", _examples)
|
||||
|
@ -74,4 +75,4 @@ class TestAUC(MetricTester):
|
|||
])
|
||||
def test_auc(x, y, expected):
|
||||
# Test Area Under Curve (AUC) computation
|
||||
assert auc(torch.tensor(x), torch.tensor(y)) == expected
|
||||
assert auc(tensor(x), tensor(y)) == expected
|
||||
|
|
|
@ -11,7 +11,6 @@
|
|||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
from distutils.version import LooseVersion
|
||||
from functools import partial
|
||||
|
||||
import pytest
|
||||
|
@ -26,6 +25,7 @@ from tests.classification.inputs import _input_multilabel_prob as _input_mlb_pro
|
|||
from tests.helpers.testers import NUM_CLASSES, MetricTester
|
||||
from torchmetrics.classification.auroc import AUROC
|
||||
from torchmetrics.functional import auroc
|
||||
from torchmetrics.utilities.imports import _TORCH_LOWER_1_6
|
||||
|
||||
torch.manual_seed(42)
|
||||
|
||||
|
@ -104,7 +104,7 @@ class TestAUROC(MetricTester):
|
|||
pytest.skip('max_fpr parameter not support for multi class or multi label')
|
||||
|
||||
# max_fpr only supported for torch v1.6 or higher
|
||||
if max_fpr is not None and LooseVersion(torch.__version__) < LooseVersion('1.6.0'):
|
||||
if max_fpr is not None and _TORCH_LOWER_1_6:
|
||||
pytest.skip('requires torch v1.6 or higher to test max_fpr argument')
|
||||
|
||||
self.run_class_metric_test(
|
||||
|
@ -127,7 +127,7 @@ class TestAUROC(MetricTester):
|
|||
pytest.skip('max_fpr parameter not support for multi class or multi label')
|
||||
|
||||
# max_fpr only supported for torch v1.6 or higher
|
||||
if max_fpr is not None and LooseVersion(torch.__version__) < LooseVersion('1.6.0'):
|
||||
if max_fpr is not None and _TORCH_LOWER_1_6:
|
||||
pytest.skip('requires torch v1.6 or higher to test max_fpr argument')
|
||||
|
||||
self.run_functional_metric_test(
|
||||
|
|
|
@ -17,6 +17,7 @@ import numpy as np
|
|||
import pytest
|
||||
import torch
|
||||
from sklearn.metrics import average_precision_score as sk_average_precision_score
|
||||
from torch import tensor
|
||||
|
||||
from tests.classification.inputs import _input_binary_prob
|
||||
from tests.classification.inputs import _input_multiclass_prob as _input_mcls_prob
|
||||
|
@ -101,9 +102,9 @@ class TestAveragePrecision(MetricTester):
|
|||
# And a constant score
|
||||
# The precision is then the fraction of positive whatever the recall
|
||||
# is, as there is only one threshold:
|
||||
pytest.param(torch.tensor([1, 1, 1, 1]), torch.tensor([0, 0, 0, 1]), .25),
|
||||
pytest.param(tensor([1, 1, 1, 1]), tensor([0, 0, 0, 1]), .25),
|
||||
# With threshold 0.8 : 1 TP and 2 TN and one FN
|
||||
pytest.param(torch.tensor([.6, .7, .8, 9]), torch.tensor([1, 0, 0, 1]), .75),
|
||||
pytest.param(tensor([.6, .7, .8, 9]), tensor([1, 0, 0, 1]), .75),
|
||||
]
|
||||
)
|
||||
def test_average_precision(scores, target, expected_score):
|
||||
|
|
|
@ -17,6 +17,7 @@ import numpy as np
|
|||
import pytest
|
||||
import torch
|
||||
from sklearn.metrics import fbeta_score
|
||||
from torch import tensor
|
||||
|
||||
from tests.classification.inputs import _input_binary, _input_binary_prob
|
||||
from tests.classification.inputs import _input_multiclass as _input_mcls
|
||||
|
@ -152,8 +153,8 @@ class TestFBeta(MetricTester):
|
|||
pytest.param([1., 0., 1., 0.], [0., 1., 1., 0.], 2, [0.5, 0.5]),
|
||||
])
|
||||
def test_fbeta_score(pred, target, beta, exp_score):
|
||||
score = fbeta(torch.tensor(pred), torch.tensor(target), num_classes=1, beta=beta, average='none')
|
||||
assert torch.allclose(score, torch.tensor(exp_score))
|
||||
score = fbeta(tensor(pred), tensor(target), num_classes=1, beta=beta, average='none')
|
||||
assert torch.allclose(score, tensor(exp_score))
|
||||
|
||||
|
||||
@pytest.mark.parametrize(['pred', 'target', 'exp_score'], [
|
||||
|
@ -162,5 +163,5 @@ def test_fbeta_score(pred, target, beta, exp_score):
|
|||
pytest.param([1., 0., 1., 0.], [1., 0., 1., 0.], [1.0, 1.0]),
|
||||
])
|
||||
def test_f1_score(pred, target, exp_score):
|
||||
score = f1(torch.tensor(pred), torch.tensor(target), num_classes=1, average='none')
|
||||
assert torch.allclose(score, torch.tensor(exp_score))
|
||||
score = f1(tensor(pred), tensor(target), num_classes=1, average='none')
|
||||
assert torch.allclose(score, tensor(exp_score))
|
||||
|
|
|
@ -13,7 +13,7 @@
|
|||
# limitations under the License.
|
||||
import pytest
|
||||
import torch
|
||||
from torch import rand, randint
|
||||
from torch import Tensor, rand, randint, tensor
|
||||
|
||||
from tests.classification.inputs import Input
|
||||
from tests.classification.inputs import _input_binary as _bin
|
||||
|
@ -52,7 +52,7 @@ _mdmc_prob_2cls_preds /= _mdmc_prob_2cls_preds.sum(dim=2, keepdim=True)
|
|||
_mdmc_prob_2cls = Input(_mdmc_prob_2cls_preds, randint(high=2, size=(NUM_BATCHES, BATCH_SIZE, EXTRA_DIM)))
|
||||
|
||||
# Some utils
|
||||
T = torch.Tensor
|
||||
T = Tensor
|
||||
|
||||
|
||||
def _idn(x):
|
||||
|
@ -209,7 +209,7 @@ def test_threshold():
|
|||
|
||||
preds_probs_out, _, _ = _input_format_classification(preds_probs, target, threshold=0.5)
|
||||
|
||||
assert torch.equal(torch.tensor([0, 1, 1], dtype=torch.int), preds_probs_out.squeeze().int())
|
||||
assert torch.equal(tensor([0, 1, 1], dtype=torch.int), preds_probs_out.squeeze().int())
|
||||
|
||||
|
||||
########################################################################
|
||||
|
|
|
@ -17,6 +17,7 @@ import numpy as np
|
|||
import pytest
|
||||
import torch
|
||||
from sklearn.metrics import jaccard_score as sk_jaccard_score
|
||||
from torch import Tensor, tensor
|
||||
|
||||
from tests.classification.inputs import _input_binary, _input_binary_prob
|
||||
from tests.classification.inputs import _input_multiclass as _input_mcls
|
||||
|
@ -134,12 +135,12 @@ class TestIoU(MetricTester):
|
|||
|
||||
|
||||
@pytest.mark.parametrize(['half_ones', 'reduction', 'ignore_index', 'expected'], [
|
||||
pytest.param(False, 'none', None, torch.Tensor([1, 1, 1])),
|
||||
pytest.param(False, 'elementwise_mean', None, torch.Tensor([1])),
|
||||
pytest.param(False, 'none', 0, torch.Tensor([1, 1])),
|
||||
pytest.param(True, 'none', None, torch.Tensor([0.5, 0.5, 0.5])),
|
||||
pytest.param(True, 'elementwise_mean', None, torch.Tensor([0.5])),
|
||||
pytest.param(True, 'none', 0, torch.Tensor([0.5, 0.5])),
|
||||
pytest.param(False, 'none', None, Tensor([1, 1, 1])),
|
||||
pytest.param(False, 'elementwise_mean', None, Tensor([1])),
|
||||
pytest.param(False, 'none', 0, Tensor([1, 1])),
|
||||
pytest.param(True, 'none', None, Tensor([0.5, 0.5, 0.5])),
|
||||
pytest.param(True, 'elementwise_mean', None, Tensor([0.5])),
|
||||
pytest.param(True, 'none', 0, Tensor([0.5, 0.5])),
|
||||
])
|
||||
def test_iou(half_ones, reduction, ignore_index, expected):
|
||||
pred = (torch.arange(120) % 3).view(-1, 1)
|
||||
|
@ -190,14 +191,14 @@ def test_iou(half_ones, reduction, ignore_index, expected):
|
|||
)
|
||||
def test_iou_absent_score(pred, target, ignore_index, absent_score, num_classes, expected):
|
||||
iou_val = iou(
|
||||
pred=torch.tensor(pred),
|
||||
target=torch.tensor(target),
|
||||
pred=tensor(pred),
|
||||
target=tensor(target),
|
||||
ignore_index=ignore_index,
|
||||
absent_score=absent_score,
|
||||
num_classes=num_classes,
|
||||
reduction='none',
|
||||
)
|
||||
assert torch.allclose(iou_val, torch.tensor(expected).to(iou_val))
|
||||
assert torch.allclose(iou_val, tensor(expected).to(iou_val))
|
||||
|
||||
|
||||
# example data taken from
|
||||
|
@ -220,10 +221,10 @@ def test_iou_absent_score(pred, target, ignore_index, absent_score, num_classes,
|
|||
)
|
||||
def test_iou_ignore_index(pred, target, ignore_index, num_classes, reduction, expected):
|
||||
iou_val = iou(
|
||||
pred=torch.tensor(pred),
|
||||
target=torch.tensor(target),
|
||||
pred=tensor(pred),
|
||||
target=tensor(target),
|
||||
ignore_index=ignore_index,
|
||||
num_classes=num_classes,
|
||||
reduction=reduction,
|
||||
)
|
||||
assert torch.allclose(iou_val, torch.tensor(expected).to(iou_val))
|
||||
assert torch.allclose(iou_val, tensor(expected).to(iou_val))
|
||||
|
|
|
@ -18,6 +18,7 @@ import numpy as np
|
|||
import pytest
|
||||
import torch
|
||||
from sklearn.metrics import precision_score, recall_score
|
||||
from torch import Tensor, tensor
|
||||
|
||||
from tests.classification.inputs import _input_binary, _input_binary_prob
|
||||
from tests.classification.inputs import _input_multiclass as _input_mcls
|
||||
|
@ -128,8 +129,8 @@ def test_wrong_params(metric, fn_metric, average, mdmc_average, num_classes, ign
|
|||
def test_zero_division(metric_class, metric_fn):
|
||||
""" Test that zero_division works correctly (currently should just set to 0). """
|
||||
|
||||
preds = torch.tensor([1, 2, 1, 1])
|
||||
target = torch.tensor([2, 1, 2, 1])
|
||||
preds = tensor([1, 2, 1, 1])
|
||||
target = tensor([2, 1, 2, 1])
|
||||
|
||||
cl_metric = metric_class(average="none", num_classes=3)
|
||||
cl_metric(preds, target)
|
||||
|
@ -152,8 +153,8 @@ def test_no_support(metric_class, metric_fn):
|
|||
in this case (zero_division is for now not configurable and equals 0).
|
||||
"""
|
||||
|
||||
preds = torch.tensor([1, 1, 0, 0])
|
||||
target = torch.tensor([0, 0, 0, 0])
|
||||
preds = tensor([1, 1, 0, 0])
|
||||
target = tensor([0, 0, 0, 0])
|
||||
|
||||
cl_metric = metric_class(average="weighted", num_classes=2, ignore_index=0)
|
||||
cl_metric(preds, target)
|
||||
|
@ -198,8 +199,8 @@ class TestPrecisionRecall(MetricTester):
|
|||
self,
|
||||
ddp: bool,
|
||||
dist_sync_on_step: bool,
|
||||
preds: torch.Tensor,
|
||||
target: torch.Tensor,
|
||||
preds: Tensor,
|
||||
target: Tensor,
|
||||
sk_wrapper: Callable,
|
||||
metric_class: Metric,
|
||||
metric_fn: Callable,
|
||||
|
@ -248,8 +249,8 @@ class TestPrecisionRecall(MetricTester):
|
|||
|
||||
def test_precision_recall_fn(
|
||||
self,
|
||||
preds: torch.Tensor,
|
||||
target: torch.Tensor,
|
||||
preds: Tensor,
|
||||
target: Tensor,
|
||||
sk_wrapper: Callable,
|
||||
metric_class: Metric,
|
||||
metric_fn: Callable,
|
||||
|
@ -316,31 +317,31 @@ def test_precision_recall_joint(average):
|
|||
assert torch.equal(recall_result, prec_recall_result[1])
|
||||
|
||||
|
||||
_mc_k_target = torch.tensor([0, 1, 2])
|
||||
_mc_k_preds = torch.tensor([[0.35, 0.4, 0.25], [0.1, 0.5, 0.4], [0.2, 0.1, 0.7]])
|
||||
_ml_k_target = torch.tensor([[0, 1, 0], [1, 1, 0], [0, 0, 0]])
|
||||
_ml_k_preds = torch.tensor([[0.9, 0.2, 0.75], [0.1, 0.7, 0.8], [0.6, 0.1, 0.7]])
|
||||
_mc_k_target = tensor([0, 1, 2])
|
||||
_mc_k_preds = tensor([[0.35, 0.4, 0.25], [0.1, 0.5, 0.4], [0.2, 0.1, 0.7]])
|
||||
_ml_k_target = tensor([[0, 1, 0], [1, 1, 0], [0, 0, 0]])
|
||||
_ml_k_preds = tensor([[0.9, 0.2, 0.75], [0.1, 0.7, 0.8], [0.6, 0.1, 0.7]])
|
||||
|
||||
|
||||
@pytest.mark.parametrize("metric_class, metric_fn", [(Recall, recall), (Precision, precision)])
|
||||
@pytest.mark.parametrize(
|
||||
"k, preds, target, average, expected_prec, expected_recall",
|
||||
[
|
||||
(1, _mc_k_preds, _mc_k_target, "micro", torch.tensor(2 / 3), torch.tensor(2 / 3)),
|
||||
(2, _mc_k_preds, _mc_k_target, "micro", torch.tensor(1 / 2), torch.tensor(1.0)),
|
||||
(1, _ml_k_preds, _ml_k_target, "micro", torch.tensor(0.0), torch.tensor(0.0)),
|
||||
(2, _ml_k_preds, _ml_k_target, "micro", torch.tensor(1 / 6), torch.tensor(1 / 3)),
|
||||
(1, _mc_k_preds, _mc_k_target, "micro", tensor(2 / 3), tensor(2 / 3)),
|
||||
(2, _mc_k_preds, _mc_k_target, "micro", tensor(1 / 2), tensor(1.0)),
|
||||
(1, _ml_k_preds, _ml_k_target, "micro", tensor(0.0), tensor(0.0)),
|
||||
(2, _ml_k_preds, _ml_k_target, "micro", tensor(1 / 6), tensor(1 / 3)),
|
||||
],
|
||||
)
|
||||
def test_top_k(
|
||||
metric_class,
|
||||
metric_fn,
|
||||
k: int,
|
||||
preds: torch.Tensor,
|
||||
target: torch.Tensor,
|
||||
preds: Tensor,
|
||||
target: Tensor,
|
||||
average: str,
|
||||
expected_prec: torch.Tensor,
|
||||
expected_recall: torch.Tensor,
|
||||
expected_prec: Tensor,
|
||||
expected_recall: Tensor,
|
||||
):
|
||||
"""A simple test to check that top_k works as expected.
|
||||
|
||||
|
|
|
@ -17,6 +17,7 @@ import numpy as np
|
|||
import pytest
|
||||
import torch
|
||||
from sklearn.metrics import precision_recall_curve as sk_precision_recall_curve
|
||||
from torch import tensor
|
||||
|
||||
from tests.classification.inputs import _input_binary_prob
|
||||
from tests.classification.inputs import _input_multiclass_prob as _input_mcls_prob
|
||||
|
@ -101,10 +102,10 @@ class TestPrecisionRecallCurve(MetricTester):
|
|||
[pytest.param([1, 2, 3, 4], [1, 0, 0, 1], [0.5, 1 / 3, 0.5, 1., 1.], [1, 0.5, 0.5, 0.5, 0.], [1, 2, 3, 4])]
|
||||
)
|
||||
def test_pr_curve(pred, target, expected_p, expected_r, expected_t):
|
||||
p, r, t = precision_recall_curve(torch.tensor(pred), torch.tensor(target))
|
||||
p, r, t = precision_recall_curve(tensor(pred), tensor(target))
|
||||
assert p.size() == r.size()
|
||||
assert p.size(0) == t.size(0) + 1
|
||||
|
||||
assert torch.allclose(p, torch.tensor(expected_p).to(p))
|
||||
assert torch.allclose(r, torch.tensor(expected_r).to(r))
|
||||
assert torch.allclose(t, torch.tensor(expected_t).to(t))
|
||||
assert torch.allclose(p, tensor(expected_p).to(p))
|
||||
assert torch.allclose(r, tensor(expected_r).to(r))
|
||||
assert torch.allclose(t, tensor(expected_t).to(t))
|
||||
|
|
|
@ -17,6 +17,7 @@ import numpy as np
|
|||
import pytest
|
||||
import torch
|
||||
from sklearn.metrics import roc_curve as sk_roc_curve
|
||||
from torch import tensor
|
||||
|
||||
from tests.classification.inputs import _input_binary_prob
|
||||
from tests.classification.inputs import _input_multiclass_prob as _input_mcls_prob
|
||||
|
@ -104,9 +105,9 @@ class TestROC(MetricTester):
|
|||
pytest.param([0.5, 0.5], [0, 1], [0, 1], [0, 1]),
|
||||
])
|
||||
def test_roc_curve(pred, target, expected_tpr, expected_fpr):
|
||||
fpr, tpr, thresh = roc(torch.tensor(pred), torch.tensor(target))
|
||||
fpr, tpr, thresh = roc(tensor(pred), tensor(target))
|
||||
|
||||
assert fpr.shape == tpr.shape
|
||||
assert fpr.size(0) == thresh.size(0)
|
||||
assert torch.allclose(fpr, torch.tensor(expected_fpr).to(fpr))
|
||||
assert torch.allclose(tpr, torch.tensor(expected_tpr).to(tpr))
|
||||
assert torch.allclose(fpr, tensor(expected_fpr).to(fpr))
|
||||
assert torch.allclose(tpr, tensor(expected_tpr).to(tpr))
|
||||
|
|
|
@ -18,6 +18,7 @@ import numpy as np
|
|||
import pytest
|
||||
import torch
|
||||
from sklearn.metrics import multilabel_confusion_matrix
|
||||
from torch import Tensor, tensor
|
||||
|
||||
from tests.classification.inputs import _input_binary, _input_binary_prob, _input_multiclass
|
||||
from tests.classification.inputs import _input_multiclass_prob as _input_mccls_prob
|
||||
|
@ -159,8 +160,8 @@ class TestStatScores(MetricTester):
|
|||
ddp: bool,
|
||||
dist_sync_on_step: bool,
|
||||
sk_fn: Callable,
|
||||
preds: torch.Tensor,
|
||||
target: torch.Tensor,
|
||||
preds: Tensor,
|
||||
target: Tensor,
|
||||
reduce: str,
|
||||
mdmc_reduce: Optional[str],
|
||||
num_classes: Optional[int],
|
||||
|
@ -202,8 +203,8 @@ class TestStatScores(MetricTester):
|
|||
def test_stat_scores_fn(
|
||||
self,
|
||||
sk_fn: Callable,
|
||||
preds: torch.Tensor,
|
||||
target: torch.Tensor,
|
||||
preds: Tensor,
|
||||
target: Tensor,
|
||||
reduce: str,
|
||||
mdmc_reduce: Optional[str],
|
||||
num_classes: Optional[int],
|
||||
|
@ -239,26 +240,26 @@ class TestStatScores(MetricTester):
|
|||
)
|
||||
|
||||
|
||||
_mc_k_target = torch.tensor([0, 1, 2])
|
||||
_mc_k_preds = torch.tensor([[0.35, 0.4, 0.25], [0.1, 0.5, 0.4], [0.2, 0.1, 0.7]])
|
||||
_ml_k_target = torch.tensor([[0, 1, 0], [1, 1, 0], [0, 0, 0]])
|
||||
_ml_k_preds = torch.tensor([[0.9, 0.2, 0.75], [0.1, 0.7, 0.8], [0.6, 0.1, 0.7]])
|
||||
_mc_k_target = tensor([0, 1, 2])
|
||||
_mc_k_preds = tensor([[0.35, 0.4, 0.25], [0.1, 0.5, 0.4], [0.2, 0.1, 0.7]])
|
||||
_ml_k_target = tensor([[0, 1, 0], [1, 1, 0], [0, 0, 0]])
|
||||
_ml_k_preds = tensor([[0.9, 0.2, 0.75], [0.1, 0.7, 0.8], [0.6, 0.1, 0.7]])
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"k, preds, target, reduce, expected",
|
||||
[
|
||||
(1, _mc_k_preds, _mc_k_target, "micro", torch.tensor([2, 1, 5, 1, 3])),
|
||||
(2, _mc_k_preds, _mc_k_target, "micro", torch.tensor([3, 3, 3, 0, 3])),
|
||||
(1, _ml_k_preds, _ml_k_target, "micro", torch.tensor([0, 3, 3, 3, 3])),
|
||||
(2, _ml_k_preds, _ml_k_target, "micro", torch.tensor([1, 5, 1, 2, 3])),
|
||||
(1, _mc_k_preds, _mc_k_target, "macro", torch.tensor([[0, 1, 1], [0, 1, 0], [2, 1, 2], [1, 0, 0], [1, 1, 1]])),
|
||||
(2, _mc_k_preds, _mc_k_target, "macro", torch.tensor([[1, 1, 1], [1, 1, 1], [1, 1, 1], [0, 0, 0], [1, 1, 1]])),
|
||||
(1, _ml_k_preds, _ml_k_target, "macro", torch.tensor([[0, 0, 0], [1, 0, 2], [1, 1, 1], [1, 2, 0], [1, 2, 0]])),
|
||||
(2, _ml_k_preds, _ml_k_target, "macro", torch.tensor([[0, 1, 0], [2, 0, 3], [0, 1, 0], [1, 1, 0], [1, 2, 0]])),
|
||||
(1, _mc_k_preds, _mc_k_target, "micro", tensor([2, 1, 5, 1, 3])),
|
||||
(2, _mc_k_preds, _mc_k_target, "micro", tensor([3, 3, 3, 0, 3])),
|
||||
(1, _ml_k_preds, _ml_k_target, "micro", tensor([0, 3, 3, 3, 3])),
|
||||
(2, _ml_k_preds, _ml_k_target, "micro", tensor([1, 5, 1, 2, 3])),
|
||||
(1, _mc_k_preds, _mc_k_target, "macro", tensor([[0, 1, 1], [0, 1, 0], [2, 1, 2], [1, 0, 0], [1, 1, 1]])),
|
||||
(2, _mc_k_preds, _mc_k_target, "macro", tensor([[1, 1, 1], [1, 1, 1], [1, 1, 1], [0, 0, 0], [1, 1, 1]])),
|
||||
(1, _ml_k_preds, _ml_k_target, "macro", tensor([[0, 0, 0], [1, 0, 2], [1, 1, 1], [1, 2, 0], [1, 2, 0]])),
|
||||
(2, _ml_k_preds, _ml_k_target, "macro", tensor([[0, 1, 0], [2, 0, 3], [0, 1, 0], [1, 1, 0], [1, 2, 0]])),
|
||||
],
|
||||
)
|
||||
def test_top_k(k: int, preds: torch.Tensor, target: torch.Tensor, reduce: str, expected: torch.Tensor):
|
||||
def test_top_k(k: int, preds: Tensor, target: Tensor, reduce: str, expected: Tensor):
|
||||
""" A simple test to check that top_k works as expected """
|
||||
|
||||
class_metric = StatScores(top_k=k, reduce=reduce, num_classes=3)
|
||||
|
|
|
@ -14,6 +14,7 @@
|
|||
import pytest
|
||||
import torch
|
||||
from pytorch_lightning import seed_everything
|
||||
from torch import Tensor, tensor
|
||||
|
||||
from torchmetrics.functional import dice_score
|
||||
from torchmetrics.functional.classification.precision_recall_curve import _binary_clf_curve
|
||||
|
@ -21,7 +22,7 @@ from torchmetrics.utilities.data import get_num_classes, to_categorical, to_oneh
|
|||
|
||||
|
||||
def test_onehot():
|
||||
test_tensor = torch.tensor([[0, 1, 2, 3, 4], [5, 6, 7, 8, 9]])
|
||||
test_tensor = tensor([[0, 1, 2, 3, 4], [5, 6, 7, 8, 9]])
|
||||
expected = torch.stack([
|
||||
torch.cat([torch.eye(5, dtype=int), torch.zeros((5, 5), dtype=int)]),
|
||||
torch.cat([torch.zeros((5, 5), dtype=int), torch.eye(5, dtype=int)])
|
||||
|
@ -48,7 +49,7 @@ def test_to_categorical():
|
|||
torch.cat([torch.zeros((5, 5), dtype=int), torch.eye(5, dtype=int)])
|
||||
]).to(torch.float)
|
||||
|
||||
expected = torch.tensor([[0, 1, 2, 3, 4], [5, 6, 7, 8, 9]])
|
||||
expected = tensor([[0, 1, 2, 3, 4], [5, 6, 7, 8, 9]])
|
||||
assert expected.shape == (2, 5)
|
||||
assert test_tensor.shape == (2, 10, 5)
|
||||
|
||||
|
@ -77,15 +78,15 @@ def test_binary_clf_curve(sample_weight, pos_label, exp_shape):
|
|||
# because when the array changes, you also have to fix the shape
|
||||
seed_everything(0)
|
||||
pred = torch.randint(low=51, high=99, size=(100, ), dtype=torch.float) / 100
|
||||
target = torch.tensor([0, 1] * 50, dtype=torch.int)
|
||||
target = tensor([0, 1] * 50, dtype=torch.int)
|
||||
if sample_weight is not None:
|
||||
sample_weight = torch.ones_like(pred) * sample_weight
|
||||
|
||||
fps, tps, thresh = _binary_clf_curve(preds=pred, target=target, sample_weights=sample_weight, pos_label=pos_label)
|
||||
|
||||
assert isinstance(tps, torch.Tensor)
|
||||
assert isinstance(fps, torch.Tensor)
|
||||
assert isinstance(thresh, torch.Tensor)
|
||||
assert isinstance(tps, Tensor)
|
||||
assert isinstance(fps, Tensor)
|
||||
assert isinstance(thresh, Tensor)
|
||||
assert tps.shape == (exp_shape, )
|
||||
assert fps.shape == (exp_shape, )
|
||||
assert thresh.shape == (exp_shape, )
|
||||
|
@ -98,5 +99,5 @@ def test_binary_clf_curve(sample_weight, pos_label, exp_shape):
|
|||
pytest.param([[1, 1], [0, 0]], [[1, 1], [0, 0]], 1.),
|
||||
])
|
||||
def test_dice_score(pred, target, expected):
|
||||
score = dice_score(torch.tensor(pred), torch.tensor(target))
|
||||
score = dice_score(tensor(pred), tensor(target))
|
||||
assert score == expected
|
||||
|
|
|
@ -13,6 +13,7 @@
|
|||
# limitations under the License.
|
||||
import pytest
|
||||
import torch
|
||||
from torch import Tensor
|
||||
|
||||
from torchmetrics.functional import image_gradients
|
||||
|
||||
|
@ -73,8 +74,8 @@ def test_multi_batch_image_gradients():
|
|||
[1., 1., 1., 1., 0.],
|
||||
[1., 1., 1., 1., 0.],
|
||||
]
|
||||
true_dy = torch.Tensor(true_dy)
|
||||
true_dx = torch.Tensor(true_dx)
|
||||
true_dy = Tensor(true_dy)
|
||||
true_dx = Tensor(true_dx)
|
||||
|
||||
dy, dx = image_gradients(image)
|
||||
|
||||
|
@ -113,8 +114,8 @@ def test_image_gradients():
|
|||
[1., 1., 1., 1., 0.],
|
||||
]
|
||||
|
||||
true_dy = torch.Tensor(true_dy)
|
||||
true_dx = torch.Tensor(true_dx)
|
||||
true_dy = Tensor(true_dy)
|
||||
true_dx = Tensor(true_dx)
|
||||
|
||||
dy, dx = image_gradients(image)
|
||||
|
||||
|
|
|
@ -14,6 +14,7 @@
|
|||
import pytest
|
||||
import torch
|
||||
from nltk.translate.bleu_score import SmoothingFunction, corpus_bleu, sentence_bleu
|
||||
from torch import tensor
|
||||
|
||||
from torchmetrics.functional import bleu_score
|
||||
|
||||
|
@ -62,20 +63,20 @@ def test_bleu_score(weights, n_gram, smooth_func, smooth):
|
|||
smoothing_function=smooth_func,
|
||||
)
|
||||
pl_output = bleu_score([HYPOTHESIS1], [[REFERENCE1, REFERENCE2, REFERENCE3]], n_gram=n_gram, smooth=smooth)
|
||||
assert torch.allclose(pl_output, torch.tensor(nltk_output))
|
||||
assert torch.allclose(pl_output, tensor(nltk_output))
|
||||
|
||||
nltk_output = corpus_bleu(LIST_OF_REFERENCES, HYPOTHESES, weights=weights, smoothing_function=smooth_func)
|
||||
pl_output = bleu_score(HYPOTHESES, LIST_OF_REFERENCES, n_gram=n_gram, smooth=smooth)
|
||||
assert torch.allclose(pl_output, torch.tensor(nltk_output))
|
||||
assert torch.allclose(pl_output, tensor(nltk_output))
|
||||
|
||||
|
||||
def test_bleu_empty():
|
||||
hyp = [[]]
|
||||
ref = [[[]]]
|
||||
assert bleu_score(hyp, ref) == torch.tensor(0.0)
|
||||
assert bleu_score(hyp, ref) == tensor(0.0)
|
||||
|
||||
|
||||
def test_no_4_gram():
|
||||
hyps = [["My", "full", "pytorch-lightning"]]
|
||||
refs = [[["My", "full", "pytorch-lightning", "test"], ["Completely", "Different"]]]
|
||||
assert bleu_score(hyps, refs) == torch.tensor(0.0)
|
||||
assert bleu_score(hyps, refs) == tensor(0.0)
|
||||
|
|
|
@ -14,6 +14,7 @@
|
|||
import pytest
|
||||
import torch
|
||||
from sklearn.metrics import pairwise
|
||||
from torch import tensor
|
||||
|
||||
from torchmetrics.functional import embedding_similarity
|
||||
|
||||
|
@ -40,6 +41,6 @@ def test_against_sklearn(similarity, reduction):
|
|||
return dist
|
||||
|
||||
sk_dist = sklearn_embedding_distance(batch.cpu().detach().numpy(), similarity=similarity, reduction=reduction)
|
||||
sk_dist = torch.tensor(sk_dist, dtype=torch.float, device=device)
|
||||
sk_dist = tensor(sk_dist, dtype=torch.float, device=device)
|
||||
|
||||
assert torch.allclose(sk_dist, pl_dist)
|
||||
|
|
|
@ -1,7 +1,5 @@
|
|||
from distutils.version import LooseVersion
|
||||
from torchmetrics.utilities.imports import _TORCH_LOWER_1_4, _TORCH_LOWER_1_5, _TORCH_LOWER_1_6
|
||||
|
||||
import torch
|
||||
|
||||
_MARK_TORCH_MIN_1_4 = dict(condition=LooseVersion(torch.__version__) < LooseVersion("1.4"), reason='required PT >= 1.4')
|
||||
_MARK_TORCH_MIN_1_5 = dict(condition=LooseVersion(torch.__version__) < LooseVersion("1.5"), reason='required PT >= 1.5')
|
||||
_MARK_TORCH_MIN_1_6 = dict(condition=LooseVersion(torch.__version__) < LooseVersion("1.6"), reason='required PT >= 1.6')
|
||||
_MARK_TORCH_MIN_1_4 = dict(condition=_TORCH_LOWER_1_4, reason='required PT >= 1.4')
|
||||
_MARK_TORCH_MIN_1_5 = dict(condition=_TORCH_LOWER_1_5, reason='required PT >= 1.5')
|
||||
_MARK_TORCH_MIN_1_6 = dict(condition=_TORCH_LOWER_1_6, reason='required PT >= 1.6')
|
||||
|
|
|
@ -20,6 +20,7 @@ from typing import Callable
|
|||
import numpy as np
|
||||
import pytest
|
||||
import torch
|
||||
from torch import Tensor, tensor
|
||||
from torch.multiprocessing import Pool, set_start_method
|
||||
|
||||
from torchmetrics import Metric
|
||||
|
@ -51,7 +52,7 @@ def _assert_allclose(pl_result, sk_result, atol: float = 1e-8):
|
|||
a certain tolerance
|
||||
"""
|
||||
# single output compare
|
||||
if isinstance(pl_result, torch.Tensor):
|
||||
if isinstance(pl_result, Tensor):
|
||||
assert np.allclose(pl_result.numpy(), sk_result, atol=atol, equal_nan=True)
|
||||
# multi output compare
|
||||
elif isinstance(pl_result, (tuple, list)):
|
||||
|
@ -69,14 +70,14 @@ def _assert_tensor(pl_result):
|
|||
for plr in pl_result:
|
||||
_assert_tensor(plr)
|
||||
else:
|
||||
assert isinstance(pl_result, torch.Tensor)
|
||||
assert isinstance(pl_result, Tensor)
|
||||
|
||||
|
||||
def _class_test(
|
||||
rank: int,
|
||||
worldsize: int,
|
||||
preds: torch.Tensor,
|
||||
target: torch.Tensor,
|
||||
preds: Tensor,
|
||||
target: Tensor,
|
||||
metric_class: Metric,
|
||||
sk_metric: Callable,
|
||||
dist_sync_on_step: bool,
|
||||
|
@ -140,8 +141,8 @@ def _class_test(
|
|||
|
||||
|
||||
def _functional_test(
|
||||
preds: torch.Tensor,
|
||||
target: torch.Tensor,
|
||||
preds: Tensor,
|
||||
target: Tensor,
|
||||
metric_functional: Callable,
|
||||
sk_metric: Callable,
|
||||
metric_args: dict = {},
|
||||
|
@ -195,8 +196,8 @@ class MetricTester:
|
|||
|
||||
def run_functional_metric_test(
|
||||
self,
|
||||
preds: torch.Tensor,
|
||||
target: torch.Tensor,
|
||||
preds: Tensor,
|
||||
target: Tensor,
|
||||
metric_functional: Callable,
|
||||
sk_metric: Callable,
|
||||
metric_args: dict = {},
|
||||
|
@ -223,8 +224,8 @@ class MetricTester:
|
|||
def run_class_metric_test(
|
||||
self,
|
||||
ddp: bool,
|
||||
preds: torch.Tensor,
|
||||
target: torch.Tensor,
|
||||
preds: Tensor,
|
||||
target: Tensor,
|
||||
metric_class: Metric,
|
||||
sk_metric: Callable,
|
||||
dist_sync_on_step: bool,
|
||||
|
@ -289,7 +290,7 @@ class DummyMetric(Metric):
|
|||
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
self.add_state("x", torch.tensor(0.0), dist_reduce_fx=None)
|
||||
self.add_state("x", tensor(0.0), dist_reduce_fx=None)
|
||||
|
||||
def update(self):
|
||||
pass
|
||||
|
|
|
@ -13,6 +13,7 @@
|
|||
# limitations under the License.
|
||||
import torch
|
||||
from pytorch_lightning import Trainer
|
||||
from torch import tensor
|
||||
|
||||
from tests.integrations.lightning_models import BoringModel
|
||||
from torchmetrics import Metric
|
||||
|
@ -22,7 +23,7 @@ class SumMetric(Metric):
|
|||
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
self.add_state("x", torch.tensor(0.0), dist_reduce_fx="sum")
|
||||
self.add_state("x", tensor(0.0), dist_reduce_fx="sum")
|
||||
|
||||
def update(self, x):
|
||||
self.x += x
|
||||
|
@ -35,7 +36,7 @@ class DiffMetric(Metric):
|
|||
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
self.add_state("x", torch.tensor(0.0), dist_reduce_fx="sum")
|
||||
self.add_state("x", tensor(0.0), dist_reduce_fx="sum")
|
||||
|
||||
def update(self, x):
|
||||
self.x -= x
|
||||
|
@ -116,8 +117,8 @@ def test_metric_lightning(tmpdir):
|
|||
# trainer.fit(model)
|
||||
#
|
||||
# logged = trainer.logged_metrics
|
||||
# assert torch.allclose(torch.tensor(logged["sum_step"]), model.sum)
|
||||
# assert torch.allclose(torch.tensor(logged["sum_epoch"]), model.sum)
|
||||
# assert torch.allclose(tensor(logged["sum_step"]), model.sum)
|
||||
# assert torch.allclose(tensor(logged["sum_epoch"]), model.sum)
|
||||
|
||||
# todo: need to be fixed
|
||||
# def test_scriptable(tmpdir):
|
||||
|
@ -193,5 +194,5 @@ def test_metric_lightning(tmpdir):
|
|||
# trainer.fit(model)
|
||||
#
|
||||
# logged = trainer.logged_metrics
|
||||
# assert torch.allclose(torch.tensor(logged["SumMetric_epoch"]), model.sum)
|
||||
# assert torch.allclose(torch.tensor(logged["DiffMetric_epoch"]), model.diff)
|
||||
# assert torch.allclose(tensor(logged["SumMetric_epoch"]), model.sum)
|
||||
# assert torch.allclose(tensor(logged["DiffMetric_epoch"]), model.diff)
|
||||
|
|
|
@ -14,6 +14,7 @@
|
|||
from typing import Any, Callable, Optional
|
||||
|
||||
import torch
|
||||
from torch import Tensor, tensor
|
||||
|
||||
from torchmetrics.functional.classification.accuracy import _accuracy_compute, _accuracy_update
|
||||
from torchmetrics.metric import Metric
|
||||
|
@ -112,8 +113,8 @@ class Accuracy(Metric):
|
|||
dist_sync_fn=dist_sync_fn,
|
||||
)
|
||||
|
||||
self.add_state("correct", default=torch.tensor(0), dist_reduce_fx="sum")
|
||||
self.add_state("total", default=torch.tensor(0), dist_reduce_fx="sum")
|
||||
self.add_state("correct", default=tensor(0), dist_reduce_fx="sum")
|
||||
self.add_state("total", default=tensor(0), dist_reduce_fx="sum")
|
||||
|
||||
if not 0 < threshold < 1:
|
||||
raise ValueError(f"The `threshold` should be a float in the (0,1) interval, got {threshold}")
|
||||
|
@ -125,7 +126,7 @@ class Accuracy(Metric):
|
|||
self.top_k = top_k
|
||||
self.subset_accuracy = subset_accuracy
|
||||
|
||||
def update(self, preds: torch.Tensor, target: torch.Tensor):
|
||||
def update(self, preds: Tensor, target: Tensor):
|
||||
"""
|
||||
Update state with predictions and targets. See :ref:`references/modules:input types` for more information
|
||||
on input types.
|
||||
|
@ -142,7 +143,7 @@ class Accuracy(Metric):
|
|||
self.correct += correct
|
||||
self.total += total
|
||||
|
||||
def compute(self) -> torch.Tensor:
|
||||
def compute(self) -> Tensor:
|
||||
"""
|
||||
Computes accuracy based on inputs passed in to ``update`` previously.
|
||||
"""
|
||||
|
|
|
@ -14,6 +14,7 @@
|
|||
from typing import Any, Callable, Optional
|
||||
|
||||
import torch
|
||||
from torch import Tensor
|
||||
|
||||
from torchmetrics.functional.classification.auc import _auc_compute, _auc_update
|
||||
from torchmetrics.metric import Metric
|
||||
|
@ -68,7 +69,7 @@ class AUC(Metric):
|
|||
' For large datasets this may lead to large memory footprint.'
|
||||
)
|
||||
|
||||
def update(self, x: torch.Tensor, y: torch.Tensor):
|
||||
def update(self, x: Tensor, y: Tensor):
|
||||
"""
|
||||
Update state with predictions and targets.
|
||||
|
||||
|
@ -81,7 +82,7 @@ class AUC(Metric):
|
|||
self.x.append(x)
|
||||
self.y.append(y)
|
||||
|
||||
def compute(self) -> torch.Tensor:
|
||||
def compute(self) -> Tensor:
|
||||
"""
|
||||
Computes AUC based on inputs passed in to ``update`` previously.
|
||||
"""
|
||||
|
|
|
@ -11,14 +11,15 @@
|
|||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
from distutils.version import LooseVersion
|
||||
from typing import Any, Callable, Optional
|
||||
|
||||
import torch
|
||||
from torch import Tensor
|
||||
|
||||
from torchmetrics.functional.classification.auroc import _auroc_compute, _auroc_update
|
||||
from torchmetrics.metric import Metric
|
||||
from torchmetrics.utilities import rank_zero_warn
|
||||
from torchmetrics.utilities.imports import _TORCH_LOWER_1_6
|
||||
|
||||
|
||||
class AUROC(Metric):
|
||||
|
@ -120,7 +121,7 @@ class AUROC(Metric):
|
|||
if (not isinstance(max_fpr, float) and 0 < max_fpr <= 1):
|
||||
raise ValueError(f"`max_fpr` should be a float in range (0, 1], got: {max_fpr}")
|
||||
|
||||
if LooseVersion(torch.__version__) < LooseVersion('1.6.0'):
|
||||
if _TORCH_LOWER_1_6:
|
||||
raise RuntimeError(
|
||||
'`max_fpr` argument requires `torch.bucketize` which is not available below PyTorch version 1.6'
|
||||
)
|
||||
|
@ -134,7 +135,7 @@ class AUROC(Metric):
|
|||
' For large datasets this may lead to large memory footprint.'
|
||||
)
|
||||
|
||||
def update(self, preds: torch.Tensor, target: torch.Tensor):
|
||||
def update(self, preds: Tensor, target: Tensor):
|
||||
"""
|
||||
Update state with predictions and targets.
|
||||
|
||||
|
@ -154,7 +155,7 @@ class AUROC(Metric):
|
|||
)
|
||||
self.mode = mode
|
||||
|
||||
def compute(self) -> torch.Tensor:
|
||||
def compute(self) -> Tensor:
|
||||
"""
|
||||
Computes AUROC based on inputs passed in to ``update`` previously.
|
||||
"""
|
||||
|
|
|
@ -14,6 +14,7 @@
|
|||
from typing import Any, List, Optional, Union
|
||||
|
||||
import torch
|
||||
from torch import Tensor
|
||||
|
||||
from torchmetrics.functional.classification.average_precision import (
|
||||
_average_precision_compute,
|
||||
|
@ -97,7 +98,7 @@ class AveragePrecision(Metric):
|
|||
' For large datasets this may lead to large memory footprint.'
|
||||
)
|
||||
|
||||
def update(self, preds: torch.Tensor, target: torch.Tensor):
|
||||
def update(self, preds: Tensor, target: Tensor):
|
||||
"""
|
||||
Update state with predictions and targets.
|
||||
|
||||
|
@ -113,7 +114,7 @@ class AveragePrecision(Metric):
|
|||
self.num_classes = num_classes
|
||||
self.pos_label = pos_label
|
||||
|
||||
def compute(self) -> Union[torch.Tensor, List[torch.Tensor]]:
|
||||
def compute(self) -> Union[Tensor, List[Tensor]]:
|
||||
"""
|
||||
Compute the average precision score
|
||||
|
||||
|
|
|
@ -14,6 +14,7 @@
|
|||
from typing import Any, Optional
|
||||
|
||||
import torch
|
||||
from torch import Tensor
|
||||
|
||||
from torchmetrics.functional.classification.cohen_kappa import _cohen_kappa_compute, _cohen_kappa_update
|
||||
from torchmetrics.metric import Metric
|
||||
|
@ -74,6 +75,7 @@ class CohenKappa(Metric):
|
|||
>>> cohenkappa(preds, target)
|
||||
tensor(0.5000)
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
num_classes: int,
|
||||
|
@ -99,7 +101,7 @@ class CohenKappa(Metric):
|
|||
|
||||
self.add_state("confmat", default=torch.zeros(num_classes, num_classes), dist_reduce_fx="sum")
|
||||
|
||||
def update(self, preds: torch.Tensor, target: torch.Tensor):
|
||||
def update(self, preds: Tensor, target: Tensor):
|
||||
"""
|
||||
Update state with predictions and targets.
|
||||
|
||||
|
@ -110,7 +112,7 @@ class CohenKappa(Metric):
|
|||
confmat = _cohen_kappa_update(preds, target, self.num_classes, self.threshold)
|
||||
self.confmat += confmat
|
||||
|
||||
def compute(self) -> torch.Tensor:
|
||||
def compute(self) -> Tensor:
|
||||
"""
|
||||
Computes cohen kappa score
|
||||
"""
|
||||
|
|
|
@ -14,6 +14,7 @@
|
|||
from typing import Any, Optional
|
||||
|
||||
import torch
|
||||
from torch import Tensor
|
||||
|
||||
from torchmetrics.functional.classification.confusion_matrix import _confusion_matrix_compute, _confusion_matrix_update
|
||||
from torchmetrics.metric import Metric
|
||||
|
@ -96,7 +97,7 @@ class ConfusionMatrix(Metric):
|
|||
|
||||
self.add_state("confmat", default=torch.zeros(num_classes, num_classes), dist_reduce_fx="sum")
|
||||
|
||||
def update(self, preds: torch.Tensor, target: torch.Tensor):
|
||||
def update(self, preds: Tensor, target: Tensor):
|
||||
"""
|
||||
Update state with predictions and targets.
|
||||
|
||||
|
@ -107,7 +108,7 @@ class ConfusionMatrix(Metric):
|
|||
confmat = _confusion_matrix_update(preds, target, self.num_classes, self.threshold)
|
||||
self.confmat += confmat
|
||||
|
||||
def compute(self) -> torch.Tensor:
|
||||
def compute(self) -> Tensor:
|
||||
"""
|
||||
Computes confusion matrix
|
||||
"""
|
||||
|
|
|
@ -14,6 +14,7 @@
|
|||
from typing import Any, Optional
|
||||
|
||||
import torch
|
||||
from torch import Tensor
|
||||
|
||||
from torchmetrics.functional.classification.f_beta import _fbeta_compute, _fbeta_update
|
||||
from torchmetrics.metric import Metric
|
||||
|
@ -109,7 +110,7 @@ class FBeta(Metric):
|
|||
self.add_state("predicted_positives", default=torch.zeros(num_classes), dist_reduce_fx="sum")
|
||||
self.add_state("actual_positives", default=torch.zeros(num_classes), dist_reduce_fx="sum")
|
||||
|
||||
def update(self, preds: torch.Tensor, target: torch.Tensor):
|
||||
def update(self, preds: Tensor, target: Tensor):
|
||||
"""
|
||||
Update state with predictions and targets.
|
||||
|
||||
|
@ -125,7 +126,7 @@ class FBeta(Metric):
|
|||
self.predicted_positives += predicted_positives
|
||||
self.actual_positives += actual_positives
|
||||
|
||||
def compute(self) -> torch.Tensor:
|
||||
def compute(self) -> Tensor:
|
||||
"""
|
||||
Computes fbeta over state.
|
||||
"""
|
||||
|
|
|
@ -14,6 +14,7 @@
|
|||
from typing import Any, Callable, Optional
|
||||
|
||||
import torch
|
||||
from torch import Tensor, tensor
|
||||
|
||||
from torchmetrics.functional.classification.hamming_distance import _hamming_distance_compute, _hamming_distance_update
|
||||
from torchmetrics.metric import Metric
|
||||
|
@ -79,14 +80,14 @@ class HammingDistance(Metric):
|
|||
dist_sync_fn=dist_sync_fn,
|
||||
)
|
||||
|
||||
self.add_state("correct", default=torch.tensor(0), dist_reduce_fx="sum")
|
||||
self.add_state("total", default=torch.tensor(0), dist_reduce_fx="sum")
|
||||
self.add_state("correct", default=tensor(0), dist_reduce_fx="sum")
|
||||
self.add_state("total", default=tensor(0), dist_reduce_fx="sum")
|
||||
|
||||
if not 0 < threshold < 1:
|
||||
raise ValueError("The `threshold` should lie in the (0,1) interval.")
|
||||
self.threshold = threshold
|
||||
|
||||
def update(self, preds: torch.Tensor, target: torch.Tensor):
|
||||
def update(self, preds: Tensor, target: Tensor):
|
||||
"""
|
||||
Update state with predictions and targets. See :ref:`references/modules:input types` for more information
|
||||
on input types.
|
||||
|
@ -100,7 +101,7 @@ class HammingDistance(Metric):
|
|||
self.correct += correct
|
||||
self.total += total
|
||||
|
||||
def compute(self) -> torch.Tensor:
|
||||
def compute(self) -> Tensor:
|
||||
"""
|
||||
Computes hamming distance based on inputs passed in to ``update`` previously.
|
||||
"""
|
||||
|
|
|
@ -14,6 +14,7 @@
|
|||
from typing import Any, Optional
|
||||
|
||||
import torch
|
||||
from torch import Tensor
|
||||
|
||||
from torchmetrics.classification.confusion_matrix import ConfusionMatrix
|
||||
from torchmetrics.functional.classification.iou import _iou_from_confmat
|
||||
|
@ -100,7 +101,7 @@ class IoU(ConfusionMatrix):
|
|||
self.ignore_index = ignore_index
|
||||
self.absent_score = absent_score
|
||||
|
||||
def compute(self) -> torch.Tensor:
|
||||
def compute(self) -> Tensor:
|
||||
"""
|
||||
Computes intersection over union (IoU)
|
||||
"""
|
||||
|
|
|
@ -14,6 +14,7 @@
|
|||
from typing import Any, Callable, Optional
|
||||
|
||||
import torch
|
||||
from torch import Tensor
|
||||
|
||||
from torchmetrics.classification.stat_scores import StatScores
|
||||
from torchmetrics.functional.classification.precision_recall import _precision_compute, _recall_compute
|
||||
|
@ -151,7 +152,7 @@ class Precision(StatScores):
|
|||
|
||||
self.average = average
|
||||
|
||||
def compute(self) -> torch.Tensor:
|
||||
def compute(self) -> Tensor:
|
||||
"""
|
||||
Computes the precision score based on inputs passed in to ``update`` previously.
|
||||
|
||||
|
@ -299,7 +300,7 @@ class Recall(StatScores):
|
|||
|
||||
self.average = average
|
||||
|
||||
def compute(self) -> torch.Tensor:
|
||||
def compute(self) -> Tensor:
|
||||
"""
|
||||
Computes the recall score based on inputs passed in to ``update`` previously.
|
||||
|
||||
|
|
|
@ -14,6 +14,7 @@
|
|||
from typing import Any, List, Optional, Tuple, Union
|
||||
|
||||
import torch
|
||||
from torch import Tensor
|
||||
|
||||
from torchmetrics.functional.classification.precision_recall_curve import (
|
||||
_precision_recall_curve_compute,
|
||||
|
@ -108,7 +109,7 @@ class PrecisionRecallCurve(Metric):
|
|||
' For large datasets this may lead to large memory footprint.'
|
||||
)
|
||||
|
||||
def update(self, preds: torch.Tensor, target: torch.Tensor):
|
||||
def update(self, preds: Tensor, target: Tensor):
|
||||
"""
|
||||
Update state with predictions and targets.
|
||||
|
||||
|
@ -124,10 +125,7 @@ class PrecisionRecallCurve(Metric):
|
|||
self.num_classes = num_classes
|
||||
self.pos_label = pos_label
|
||||
|
||||
def compute(
|
||||
self
|
||||
) -> Union[Tuple[torch.Tensor, torch.Tensor, torch.Tensor], Tuple[List[torch.Tensor], List[torch.Tensor],
|
||||
List[torch.Tensor]]]:
|
||||
def compute(self) -> Union[Tuple[Tensor, Tensor, Tensor], Tuple[List[Tensor], List[Tensor], List[Tensor]]]:
|
||||
"""
|
||||
Compute the precision-recall curve
|
||||
|
||||
|
|
|
@ -14,6 +14,7 @@
|
|||
from typing import Any, List, Optional, Tuple, Union
|
||||
|
||||
import torch
|
||||
from torch import Tensor
|
||||
|
||||
from torchmetrics.functional.classification.roc import _roc_compute, _roc_update
|
||||
from torchmetrics.metric import Metric
|
||||
|
@ -107,7 +108,7 @@ class ROC(Metric):
|
|||
' For large datasets this may lead to large memory footprint.'
|
||||
)
|
||||
|
||||
def update(self, preds: torch.Tensor, target: torch.Tensor):
|
||||
def update(self, preds: Tensor, target: Tensor):
|
||||
"""
|
||||
Update state with predictions and targets.
|
||||
|
||||
|
@ -121,10 +122,7 @@ class ROC(Metric):
|
|||
self.num_classes = num_classes
|
||||
self.pos_label = pos_label
|
||||
|
||||
def compute(
|
||||
self
|
||||
) -> Union[Tuple[torch.Tensor, torch.Tensor, torch.Tensor], Tuple[List[torch.Tensor], List[torch.Tensor],
|
||||
List[torch.Tensor]]]:
|
||||
def compute(self) -> Union[Tuple[Tensor, Tensor, Tensor], Tuple[List[Tensor], List[Tensor], List[Tensor]]]:
|
||||
"""
|
||||
Compute the receiver operating characteristic
|
||||
|
||||
|
|
|
@ -15,6 +15,7 @@ from typing import Any, Callable, Optional, Tuple
|
|||
|
||||
import numpy as np
|
||||
import torch
|
||||
from torch import Tensor, tensor
|
||||
|
||||
from torchmetrics.functional.classification.stat_scores import _stat_scores_compute, _stat_scores_update
|
||||
from torchmetrics.metric import Metric
|
||||
|
@ -175,7 +176,7 @@ class StatScores(Metric):
|
|||
for s in ("tp", "fp", "tn", "fn"):
|
||||
self.add_state(s, default=default(), dist_reduce_fx=reduce_fn)
|
||||
|
||||
def update(self, preds: torch.Tensor, target: torch.Tensor):
|
||||
def update(self, preds: Tensor, target: Tensor):
|
||||
"""
|
||||
Update state with predictions and targets. See :ref:`references/modules:input types` for more information
|
||||
on input types.
|
||||
|
@ -209,7 +210,7 @@ class StatScores(Metric):
|
|||
self.tn.append(tn)
|
||||
self.fn.append(fn)
|
||||
|
||||
def _get_final_stats(self) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
|
||||
def _get_final_stats(self) -> Tuple[Tensor, Tensor, Tensor, Tensor]:
|
||||
"""Performs concatenation on the stat scores if neccesary,
|
||||
before passing them to a compute function.
|
||||
"""
|
||||
|
@ -224,7 +225,7 @@ class StatScores(Metric):
|
|||
|
||||
return tp, fp, tn, fn
|
||||
|
||||
def compute(self) -> torch.Tensor:
|
||||
def compute(self) -> Tensor:
|
||||
"""
|
||||
Computes the stat scores based on inputs passed in to ``update`` previously.
|
||||
|
||||
|
@ -262,13 +263,13 @@ class StatScores(Metric):
|
|||
|
||||
|
||||
def _reduce_stat_scores(
|
||||
numerator: torch.Tensor,
|
||||
denominator: torch.Tensor,
|
||||
weights: Optional[torch.Tensor],
|
||||
numerator: Tensor,
|
||||
denominator: Tensor,
|
||||
weights: Optional[Tensor],
|
||||
average: str,
|
||||
mdmc_average: Optional[str],
|
||||
zero_division: int = 0,
|
||||
) -> torch.Tensor:
|
||||
) -> Tensor:
|
||||
"""
|
||||
Reduces scores of type ``numerator/denominator`` or
|
||||
``weights * (numerator/denominator)``, if ``average='weighted'``.
|
||||
|
@ -303,9 +304,9 @@ model_evaluation.html#multiclass-and-multilabel-classification>`__.
|
|||
else:
|
||||
weights = weights.float()
|
||||
|
||||
numerator = torch.where(zero_div_mask, torch.tensor(float(zero_division), device=numerator.device), numerator)
|
||||
denominator = torch.where(zero_div_mask | ignore_mask, torch.tensor(1.0, device=denominator.device), denominator)
|
||||
weights = torch.where(ignore_mask, torch.tensor(0.0, device=weights.device), weights)
|
||||
numerator = torch.where(zero_div_mask, tensor(float(zero_division), device=numerator.device), numerator)
|
||||
denominator = torch.where(zero_div_mask | ignore_mask, tensor(1.0, device=denominator.device), denominator)
|
||||
weights = torch.where(ignore_mask, tensor(0.0, device=weights.device), weights)
|
||||
|
||||
if average not in (AverageMethod.MICRO, AverageMethod.NONE, None):
|
||||
weights = weights / weights.sum(dim=-1, keepdim=True)
|
||||
|
@ -313,14 +314,14 @@ model_evaluation.html#multiclass-and-multilabel-classification>`__.
|
|||
scores = weights * (numerator / denominator)
|
||||
|
||||
# This is in case where sum(weights) = 0, which happens if we ignore the only present class with average='weighted'
|
||||
scores = torch.where(torch.isnan(scores), torch.tensor(float(zero_division), device=scores.device), scores)
|
||||
scores = torch.where(torch.isnan(scores), tensor(float(zero_division), device=scores.device), scores)
|
||||
|
||||
if mdmc_average == MDMCAverageMethod.SAMPLEWISE:
|
||||
scores = scores.mean(dim=0)
|
||||
ignore_mask = ignore_mask.sum(dim=0).bool()
|
||||
|
||||
if average in (AverageMethod.NONE, None):
|
||||
scores = torch.where(ignore_mask, torch.tensor(np.nan, device=scores.device), scores)
|
||||
scores = torch.where(ignore_mask, tensor(np.nan, device=scores.device), scores)
|
||||
else:
|
||||
scores = scores.sum()
|
||||
|
||||
|
|
|
@ -60,10 +60,11 @@ class MetricCollection(nn.ModuleDict):
|
|||
>>> metrics.persistent()
|
||||
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
metrics: Union[List[Metric], Tuple[Metric], Dict[str, Metric]],
|
||||
prefix: Optional[str] = None
|
||||
self,
|
||||
metrics: Union[List[Metric], Tuple[Metric], Dict[str, Metric]],
|
||||
prefix: Optional[str] = None,
|
||||
):
|
||||
super().__init__()
|
||||
if isinstance(metrics, dict):
|
||||
|
|
|
@ -14,14 +14,19 @@
|
|||
from typing import Optional, Tuple
|
||||
|
||||
import torch
|
||||
from torch import Tensor, tensor
|
||||
|
||||
from torchmetrics.utilities.checks import _input_format_classification
|
||||
from torchmetrics.utilities.enums import DataType
|
||||
|
||||
|
||||
def _accuracy_update(
|
||||
preds: torch.Tensor, target: torch.Tensor, threshold: float, top_k: Optional[int], subset_accuracy: bool
|
||||
) -> Tuple[torch.Tensor, torch.Tensor]:
|
||||
preds: Tensor,
|
||||
target: Tensor,
|
||||
threshold: float,
|
||||
top_k: Optional[int],
|
||||
subset_accuracy: bool,
|
||||
) -> Tuple[Tensor, Tensor]:
|
||||
|
||||
preds, target, mode = _input_format_classification(preds, target, threshold=threshold, top_k=top_k)
|
||||
|
||||
|
@ -30,32 +35,32 @@ def _accuracy_update(
|
|||
|
||||
if mode == DataType.BINARY or (mode == DataType.MULTILABEL and subset_accuracy):
|
||||
correct = (preds == target).all(dim=1).sum()
|
||||
total = torch.tensor(target.shape[0], device=target.device)
|
||||
total = tensor(target.shape[0], device=target.device)
|
||||
elif mode == DataType.MULTILABEL and not subset_accuracy:
|
||||
correct = (preds == target).sum()
|
||||
total = torch.tensor(target.numel(), device=target.device)
|
||||
total = tensor(target.numel(), device=target.device)
|
||||
elif mode == DataType.MULTICLASS or (mode == DataType.MULTIDIM_MULTICLASS and not subset_accuracy):
|
||||
correct = (preds * target).sum()
|
||||
total = target.sum()
|
||||
elif mode == DataType.MULTIDIM_MULTICLASS and subset_accuracy:
|
||||
sample_correct = (preds * target).sum(dim=(1, 2))
|
||||
correct = (sample_correct == target.shape[2]).sum()
|
||||
total = torch.tensor(target.shape[0], device=target.device)
|
||||
total = tensor(target.shape[0], device=target.device)
|
||||
|
||||
return correct, total
|
||||
|
||||
|
||||
def _accuracy_compute(correct: torch.Tensor, total: torch.Tensor) -> torch.Tensor:
|
||||
def _accuracy_compute(correct: Tensor, total: Tensor) -> Tensor:
|
||||
return correct.float() / total
|
||||
|
||||
|
||||
def accuracy(
|
||||
preds: torch.Tensor,
|
||||
target: torch.Tensor,
|
||||
preds: Tensor,
|
||||
target: Tensor,
|
||||
threshold: float = 0.5,
|
||||
top_k: Optional[int] = None,
|
||||
subset_accuracy: bool = False,
|
||||
) -> torch.Tensor:
|
||||
) -> Tensor:
|
||||
r"""Computes `Accuracy <https://en.wikipedia.org/wiki/Accuracy_and_precision>`_:
|
||||
|
||||
.. math::
|
||||
|
|
|
@ -14,11 +14,12 @@
|
|||
from typing import Tuple
|
||||
|
||||
import torch
|
||||
from torch import Tensor
|
||||
|
||||
from torchmetrics.utilities.data import _stable_1d_sort
|
||||
|
||||
|
||||
def _auc_update(x: torch.Tensor, y: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
|
||||
def _auc_update(x: Tensor, y: Tensor) -> Tuple[Tensor, Tensor]:
|
||||
if x.ndim > 1 or y.ndim > 1:
|
||||
raise ValueError(
|
||||
f'Expected both `x` and `y` tensor to be 1d, but got'
|
||||
|
@ -32,7 +33,7 @@ def _auc_update(x: torch.Tensor, y: torch.Tensor) -> Tuple[torch.Tensor, torch.T
|
|||
return x, y
|
||||
|
||||
|
||||
def _auc_compute(x: torch.Tensor, y: torch.Tensor, reorder: bool = False) -> torch.Tensor:
|
||||
def _auc_compute(x: Tensor, y: Tensor, reorder: bool = False) -> Tensor:
|
||||
if reorder:
|
||||
x, x_idx = _stable_1d_sort(x)
|
||||
y = y[x_idx]
|
||||
|
@ -51,7 +52,7 @@ def _auc_compute(x: torch.Tensor, y: torch.Tensor, reorder: bool = False) -> tor
|
|||
return direction * torch.trapz(y, x)
|
||||
|
||||
|
||||
def auc(x: torch.Tensor, y: torch.Tensor, reorder: bool = False) -> torch.Tensor:
|
||||
def auc(x: Tensor, y: Tensor, reorder: bool = False) -> Tensor:
|
||||
"""
|
||||
Computes Area Under the Curve (AUC) using the trapezoidal rule
|
||||
|
||||
|
|
|
@ -11,18 +11,19 @@
|
|||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
from distutils.version import LooseVersion
|
||||
from typing import Optional, Sequence, Tuple
|
||||
|
||||
import torch
|
||||
from torch import Tensor, tensor
|
||||
|
||||
from torchmetrics.functional.classification.auc import auc
|
||||
from torchmetrics.functional.classification.roc import roc
|
||||
from torchmetrics.utilities.checks import _input_format_classification
|
||||
from torchmetrics.utilities.enums import AverageMethod, DataType
|
||||
from torchmetrics.utilities.imports import _TORCH_LOWER_1_6
|
||||
|
||||
|
||||
def _auroc_update(preds: torch.Tensor, target: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor, str]:
|
||||
def _auroc_update(preds: Tensor, target: Tensor) -> Tuple[Tensor, Tensor, str]:
|
||||
# use _input_format_classification for validating the input and get the mode of data
|
||||
_, _, mode = _input_format_classification(preds, target)
|
||||
|
||||
|
@ -39,15 +40,15 @@ def _auroc_update(preds: torch.Tensor, target: torch.Tensor) -> Tuple[torch.Tens
|
|||
|
||||
|
||||
def _auroc_compute(
|
||||
preds: torch.Tensor,
|
||||
target: torch.Tensor,
|
||||
preds: Tensor,
|
||||
target: Tensor,
|
||||
mode: str,
|
||||
num_classes: Optional[int] = None,
|
||||
pos_label: Optional[int] = None,
|
||||
average: Optional[str] = 'macro',
|
||||
max_fpr: Optional[float] = None,
|
||||
sample_weights: Optional[Sequence] = None,
|
||||
) -> torch.Tensor:
|
||||
) -> Tensor:
|
||||
# binary mode override num_classes
|
||||
if mode == 'binary':
|
||||
num_classes = 1
|
||||
|
@ -57,7 +58,7 @@ def _auroc_compute(
|
|||
if (not isinstance(max_fpr, float) and 0 < max_fpr <= 1):
|
||||
raise ValueError(f"`max_fpr` should be a float in range (0, 1], got: {max_fpr}")
|
||||
|
||||
if LooseVersion(torch.__version__) < LooseVersion('1.6.0'):
|
||||
if _TORCH_LOWER_1_6:
|
||||
raise RuntimeError(
|
||||
"`max_fpr` argument requires `torch.bucketize` which"
|
||||
" is not available below PyTorch version 1.6"
|
||||
|
@ -109,7 +110,7 @@ def _auroc_compute(
|
|||
|
||||
return auc(fpr, tpr)
|
||||
|
||||
max_fpr = torch.tensor(max_fpr, device=fpr.device)
|
||||
max_fpr = tensor(max_fpr, device=fpr.device)
|
||||
# Add a single point at max_fpr and interpolate its tpr value
|
||||
stop = torch.bucketize(max_fpr, fpr, out_int32=True, right=True)
|
||||
weight = (max_fpr - fpr[stop - 1]) / (fpr[stop] - fpr[stop - 1])
|
||||
|
@ -128,14 +129,14 @@ def _auroc_compute(
|
|||
|
||||
|
||||
def auroc(
|
||||
preds: torch.Tensor,
|
||||
target: torch.Tensor,
|
||||
preds: Tensor,
|
||||
target: Tensor,
|
||||
num_classes: Optional[int] = None,
|
||||
pos_label: Optional[int] = None,
|
||||
average: Optional[str] = 'macro',
|
||||
max_fpr: Optional[float] = None,
|
||||
sample_weights: Optional[Sequence] = None,
|
||||
) -> torch.Tensor:
|
||||
) -> Tensor:
|
||||
""" Compute `Area Under the Receiver Operating Characteristic Curve (ROC AUC)
|
||||
<https://en.wikipedia.org/wiki/Receiver_operating_characteristic#Further_interpretations>`_
|
||||
|
||||
|
|
|
@ -14,6 +14,7 @@
|
|||
from typing import List, Optional, Sequence, Tuple, Union
|
||||
|
||||
import torch
|
||||
from torch import Tensor
|
||||
|
||||
from torchmetrics.functional.classification.precision_recall_curve import (
|
||||
_precision_recall_curve_compute,
|
||||
|
@ -22,21 +23,21 @@ from torchmetrics.functional.classification.precision_recall_curve import (
|
|||
|
||||
|
||||
def _average_precision_update(
|
||||
preds: torch.Tensor,
|
||||
target: torch.Tensor,
|
||||
preds: Tensor,
|
||||
target: Tensor,
|
||||
num_classes: Optional[int] = None,
|
||||
pos_label: Optional[int] = None,
|
||||
) -> Tuple[torch.Tensor, torch.Tensor, int, int]:
|
||||
) -> Tuple[Tensor, Tensor, int, int]:
|
||||
return _precision_recall_curve_update(preds, target, num_classes, pos_label)
|
||||
|
||||
|
||||
def _average_precision_compute(
|
||||
preds: torch.Tensor,
|
||||
target: torch.Tensor,
|
||||
preds: Tensor,
|
||||
target: Tensor,
|
||||
num_classes: int,
|
||||
pos_label: int,
|
||||
sample_weights: Optional[Sequence] = None
|
||||
) -> Union[List[torch.Tensor], torch.Tensor]:
|
||||
sample_weights: Optional[Sequence] = None,
|
||||
) -> Union[List[Tensor], Tensor]:
|
||||
precision, recall, _ = _precision_recall_curve_compute(preds, target, num_classes, pos_label)
|
||||
# Return the step function integral
|
||||
# The following works because the last entry of precision is
|
||||
|
@ -51,12 +52,12 @@ def _average_precision_compute(
|
|||
|
||||
|
||||
def average_precision(
|
||||
preds: torch.Tensor,
|
||||
target: torch.Tensor,
|
||||
preds: Tensor,
|
||||
target: Tensor,
|
||||
num_classes: Optional[int] = None,
|
||||
pos_label: Optional[int] = None,
|
||||
sample_weights: Optional[Sequence] = None,
|
||||
) -> Union[List[torch.Tensor], torch.Tensor]:
|
||||
) -> Union[List[Tensor], Tensor]:
|
||||
"""
|
||||
Computes the average precision score.
|
||||
|
||||
|
|
|
@ -14,13 +14,14 @@
|
|||
from typing import Optional
|
||||
|
||||
import torch
|
||||
from torch import Tensor
|
||||
|
||||
from torchmetrics.functional.classification.confusion_matrix import _confusion_matrix_compute, _confusion_matrix_update
|
||||
|
||||
_cohen_kappa_update = _confusion_matrix_update
|
||||
|
||||
|
||||
def _cohen_kappa_compute(confmat: torch.Tensor, weights: Optional[str] = None) -> torch.Tensor:
|
||||
def _cohen_kappa_compute(confmat: Tensor, weights: Optional[str] = None) -> Tensor:
|
||||
confmat = _confusion_matrix_compute(confmat)
|
||||
n_classes = confmat.shape[0]
|
||||
sum0 = confmat.sum(dim=0, keepdim=True)
|
||||
|
@ -39,20 +40,18 @@ def _cohen_kappa_compute(confmat: torch.Tensor, weights: Optional[str] = None) -
|
|||
else:
|
||||
w_mat = torch.pow(w_mat - w_mat.T, 2.0)
|
||||
else:
|
||||
raise ValueError(f"Received {weights} for argument ``weights`` but should be either"
|
||||
" None, 'linear' or 'quadratic'")
|
||||
raise ValueError(
|
||||
f"Received {weights} for argument ``weights`` but should be either"
|
||||
" None, 'linear' or 'quadratic'"
|
||||
)
|
||||
|
||||
k = torch.sum(w_mat * confmat) / torch.sum(w_mat * expected)
|
||||
return 1 - k
|
||||
|
||||
|
||||
def cohen_kappa(
|
||||
preds: torch.Tensor,
|
||||
target: torch.Tensor,
|
||||
num_classes: int,
|
||||
weights: Optional[str] = None,
|
||||
threshold: float = 0.5
|
||||
) -> torch.Tensor:
|
||||
preds: Tensor, target: Tensor, num_classes: int, weights: Optional[str] = None, threshold: float = 0.5
|
||||
) -> Tensor:
|
||||
r"""
|
||||
Calculates `Cohen's kappa score <https://en.wikipedia.org/wiki/Cohen%27s_kappa>`_ that measures
|
||||
inter-annotator agreement. It is defined as
|
||||
|
|
|
@ -14,15 +14,14 @@
|
|||
from typing import Optional
|
||||
|
||||
import torch
|
||||
from torch import Tensor
|
||||
|
||||
from torchmetrics.utilities import rank_zero_warn
|
||||
from torchmetrics.utilities.checks import _input_format_classification
|
||||
from torchmetrics.utilities.enums import DataType
|
||||
|
||||
|
||||
def _confusion_matrix_update(
|
||||
preds: torch.Tensor, target: torch.Tensor, num_classes: int, threshold: float = 0.5
|
||||
) -> torch.Tensor:
|
||||
def _confusion_matrix_update(preds: Tensor, target: Tensor, num_classes: int, threshold: float = 0.5) -> Tensor:
|
||||
preds, target, mode = _input_format_classification(preds, target, threshold)
|
||||
if mode not in (DataType.BINARY, DataType.MULTILABEL):
|
||||
preds = preds.argmax(dim=1)
|
||||
|
@ -33,7 +32,7 @@ def _confusion_matrix_update(
|
|||
return confmat
|
||||
|
||||
|
||||
def _confusion_matrix_compute(confmat: torch.Tensor, normalize: Optional[str] = None) -> torch.Tensor:
|
||||
def _confusion_matrix_compute(confmat: Tensor, normalize: Optional[str] = None) -> Tensor:
|
||||
allowed_normalize = ('true', 'pred', 'all', 'none', None)
|
||||
assert normalize in allowed_normalize, \
|
||||
f"Argument average needs to one of the following: {allowed_normalize}"
|
||||
|
@ -54,12 +53,8 @@ def _confusion_matrix_compute(confmat: torch.Tensor, normalize: Optional[str] =
|
|||
|
||||
|
||||
def confusion_matrix(
|
||||
preds: torch.Tensor,
|
||||
target: torch.Tensor,
|
||||
num_classes: int,
|
||||
normalize: Optional[str] = None,
|
||||
threshold: float = 0.5
|
||||
) -> torch.Tensor:
|
||||
preds: Tensor, target: Tensor, num_classes: int, normalize: Optional[str] = None, threshold: float = 0.5
|
||||
) -> Tensor:
|
||||
"""
|
||||
Computes the confusion matrix. Works with binary, multiclass, and multilabel data.
|
||||
Accepts probabilities from a model output or integer class values in prediction.
|
||||
|
|
|
@ -14,17 +14,18 @@
|
|||
from typing import Tuple
|
||||
|
||||
import torch
|
||||
from torch import Tensor
|
||||
|
||||
from torchmetrics.utilities.data import to_categorical
|
||||
from torchmetrics.utilities.distributed import reduce
|
||||
|
||||
|
||||
def _stat_scores(
|
||||
preds: torch.Tensor,
|
||||
target: torch.Tensor,
|
||||
preds: Tensor,
|
||||
target: Tensor,
|
||||
class_index: int,
|
||||
argmax_dim: int = 1,
|
||||
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
|
||||
) -> Tuple[Tensor, Tensor, Tensor, Tensor, Tensor]:
|
||||
"""
|
||||
Calculates the number of true positive, false positive, true negative
|
||||
and false negative for a specific class
|
||||
|
@ -61,13 +62,13 @@ def _stat_scores(
|
|||
|
||||
|
||||
def dice_score(
|
||||
pred: torch.Tensor,
|
||||
target: torch.Tensor,
|
||||
pred: Tensor,
|
||||
target: Tensor,
|
||||
bg: bool = False,
|
||||
nan_score: float = 0.0,
|
||||
no_fg_score: float = 0.0,
|
||||
reduction: str = 'elementwise_mean',
|
||||
) -> torch.Tensor:
|
||||
) -> Tensor:
|
||||
"""
|
||||
Compute dice score from prediction scores
|
||||
|
||||
|
|
|
@ -14,18 +14,19 @@
|
|||
from typing import Tuple
|
||||
|
||||
import torch
|
||||
from torch import Tensor
|
||||
|
||||
from torchmetrics.utilities.checks import _input_format_classification_one_hot
|
||||
from torchmetrics.utilities.distributed import class_reduce
|
||||
|
||||
|
||||
def _fbeta_update(
|
||||
preds: torch.Tensor,
|
||||
target: torch.Tensor,
|
||||
preds: Tensor,
|
||||
target: Tensor,
|
||||
num_classes: int,
|
||||
threshold: float = 0.5,
|
||||
multilabel: bool = False
|
||||
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
|
||||
multilabel: bool = False,
|
||||
) -> Tuple[Tensor, Tensor, Tensor]:
|
||||
preds, target = _input_format_classification_one_hot(num_classes, preds, target, threshold, multilabel)
|
||||
true_positives = torch.sum(preds * target, dim=1)
|
||||
predicted_positives = torch.sum(preds, dim=1)
|
||||
|
@ -34,12 +35,12 @@ def _fbeta_update(
|
|||
|
||||
|
||||
def _fbeta_compute(
|
||||
true_positives: torch.Tensor,
|
||||
predicted_positives: torch.Tensor,
|
||||
actual_positives: torch.Tensor,
|
||||
true_positives: Tensor,
|
||||
predicted_positives: Tensor,
|
||||
actual_positives: Tensor,
|
||||
beta: float = 1.0,
|
||||
average: str = "micro"
|
||||
) -> torch.Tensor:
|
||||
average: str = "micro",
|
||||
) -> Tensor:
|
||||
if average == "micro":
|
||||
precision = true_positives.sum().float() / predicted_positives.sum()
|
||||
recall = true_positives.sum().float() / actual_positives.sum()
|
||||
|
@ -53,14 +54,14 @@ def _fbeta_compute(
|
|||
|
||||
|
||||
def fbeta(
|
||||
preds: torch.Tensor,
|
||||
target: torch.Tensor,
|
||||
preds: Tensor,
|
||||
target: Tensor,
|
||||
num_classes: int,
|
||||
beta: float = 1.0,
|
||||
threshold: float = 0.5,
|
||||
average: str = "micro",
|
||||
multilabel: bool = False
|
||||
) -> torch.Tensor:
|
||||
) -> Tensor:
|
||||
"""
|
||||
Computes f_beta metric.
|
||||
|
||||
|
@ -106,13 +107,13 @@ def fbeta(
|
|||
|
||||
|
||||
def f1(
|
||||
preds: torch.Tensor,
|
||||
target: torch.Tensor,
|
||||
preds: Tensor,
|
||||
target: Tensor,
|
||||
num_classes: int,
|
||||
threshold: float = 0.5,
|
||||
average: str = "micro",
|
||||
multilabel: bool = False
|
||||
) -> torch.Tensor:
|
||||
) -> Tensor:
|
||||
"""
|
||||
Computes F1 metric. F1 metrics correspond to a equally weighted average of the
|
||||
precision and recall scores.
|
||||
|
|
|
@ -14,15 +14,16 @@
|
|||
from typing import Tuple, Union
|
||||
|
||||
import torch
|
||||
from torch import Tensor
|
||||
|
||||
from torchmetrics.utilities.checks import _input_format_classification
|
||||
|
||||
|
||||
def _hamming_distance_update(
|
||||
preds: torch.Tensor,
|
||||
target: torch.Tensor,
|
||||
preds: Tensor,
|
||||
target: Tensor,
|
||||
threshold: float = 0.5,
|
||||
) -> Tuple[torch.Tensor, int]:
|
||||
) -> Tuple[Tensor, int]:
|
||||
preds, target, _ = _input_format_classification(preds, target, threshold=threshold)
|
||||
|
||||
correct = (preds == target).sum()
|
||||
|
@ -31,11 +32,11 @@ def _hamming_distance_update(
|
|||
return correct, total
|
||||
|
||||
|
||||
def _hamming_distance_compute(correct: torch.Tensor, total: Union[int, torch.Tensor]) -> torch.Tensor:
|
||||
def _hamming_distance_compute(correct: Tensor, total: Union[int, Tensor]) -> Tensor:
|
||||
return 1 - correct.float() / total
|
||||
|
||||
|
||||
def hamming_distance(preds: torch.Tensor, target: torch.Tensor, threshold: float = 0.5) -> torch.Tensor:
|
||||
def hamming_distance(preds: Tensor, target: Tensor, threshold: float = 0.5) -> Tensor:
|
||||
r"""
|
||||
Computes the average `Hamming distance <https://en.wikipedia.org/wiki/Hamming_distance>`_ (also
|
||||
known as Hamming loss) between targets and predictions:
|
||||
|
|
|
@ -14,6 +14,7 @@
|
|||
from typing import Optional
|
||||
|
||||
import torch
|
||||
from torch import Tensor
|
||||
|
||||
from torchmetrics.functional.classification.confusion_matrix import _confusion_matrix_update
|
||||
from torchmetrics.utilities.data import get_num_classes
|
||||
|
@ -21,7 +22,7 @@ from torchmetrics.utilities.distributed import reduce
|
|||
|
||||
|
||||
def _iou_from_confmat(
|
||||
confmat: torch.Tensor,
|
||||
confmat: Tensor,
|
||||
num_classes: int,
|
||||
ignore_index: Optional[int] = None,
|
||||
absent_score: float = 0.0,
|
||||
|
@ -44,14 +45,14 @@ def _iou_from_confmat(
|
|||
|
||||
|
||||
def iou(
|
||||
pred: torch.Tensor,
|
||||
target: torch.Tensor,
|
||||
pred: Tensor,
|
||||
target: Tensor,
|
||||
ignore_index: Optional[int] = None,
|
||||
absent_score: float = 0.0,
|
||||
threshold: float = 0.5,
|
||||
num_classes: Optional[int] = None,
|
||||
reduction: str = 'elementwise_mean',
|
||||
) -> torch.Tensor:
|
||||
) -> Tensor:
|
||||
r"""
|
||||
Computes `Intersection over union, or Jaccard index calculation <https://en.wikipedia.org/wiki/Jaccard_index>`_:
|
||||
|
||||
|
|
|
@ -14,19 +14,20 @@
|
|||
from typing import Optional
|
||||
|
||||
import torch
|
||||
from torch import Tensor
|
||||
|
||||
from torchmetrics.classification.stat_scores import _reduce_stat_scores
|
||||
from torchmetrics.functional.classification.stat_scores import _stat_scores_update
|
||||
|
||||
|
||||
def _precision_compute(
|
||||
tp: torch.Tensor,
|
||||
fp: torch.Tensor,
|
||||
tn: torch.Tensor,
|
||||
fn: torch.Tensor,
|
||||
tp: Tensor,
|
||||
fp: Tensor,
|
||||
tn: Tensor,
|
||||
fn: Tensor,
|
||||
average: str,
|
||||
mdmc_average: Optional[str],
|
||||
) -> torch.Tensor:
|
||||
) -> Tensor:
|
||||
return _reduce_stat_scores(
|
||||
numerator=tp,
|
||||
denominator=tp + fp,
|
||||
|
@ -37,8 +38,8 @@ def _precision_compute(
|
|||
|
||||
|
||||
def precision(
|
||||
preds: torch.Tensor,
|
||||
target: torch.Tensor,
|
||||
preds: Tensor,
|
||||
target: Tensor,
|
||||
average: str = "micro",
|
||||
mdmc_average: Optional[str] = None,
|
||||
ignore_index: Optional[int] = None,
|
||||
|
@ -46,7 +47,7 @@ def precision(
|
|||
threshold: float = 0.5,
|
||||
top_k: Optional[int] = None,
|
||||
is_multiclass: Optional[bool] = None,
|
||||
) -> torch.Tensor:
|
||||
) -> Tensor:
|
||||
r"""
|
||||
Computes `Precision <https://en.wikipedia.org/wiki/Precision_and_recall>`_:
|
||||
|
||||
|
@ -170,13 +171,13 @@ def precision(
|
|||
|
||||
|
||||
def _recall_compute(
|
||||
tp: torch.Tensor,
|
||||
fp: torch.Tensor,
|
||||
tn: torch.Tensor,
|
||||
fn: torch.Tensor,
|
||||
tp: Tensor,
|
||||
fp: Tensor,
|
||||
tn: Tensor,
|
||||
fn: Tensor,
|
||||
average: str,
|
||||
mdmc_average: Optional[str],
|
||||
) -> torch.Tensor:
|
||||
) -> Tensor:
|
||||
return _reduce_stat_scores(
|
||||
numerator=tp,
|
||||
denominator=tp + fn,
|
||||
|
@ -187,8 +188,8 @@ def _recall_compute(
|
|||
|
||||
|
||||
def recall(
|
||||
preds: torch.Tensor,
|
||||
target: torch.Tensor,
|
||||
preds: Tensor,
|
||||
target: Tensor,
|
||||
average: str = "micro",
|
||||
mdmc_average: Optional[str] = None,
|
||||
ignore_index: Optional[int] = None,
|
||||
|
@ -196,7 +197,7 @@ def recall(
|
|||
threshold: float = 0.5,
|
||||
top_k: Optional[int] = None,
|
||||
is_multiclass: Optional[bool] = None,
|
||||
) -> torch.Tensor:
|
||||
) -> Tensor:
|
||||
r"""
|
||||
Computes `Recall <https://en.wikipedia.org/wiki/Precision_and_recall>`_:
|
||||
|
||||
|
@ -320,8 +321,8 @@ def recall(
|
|||
|
||||
|
||||
def precision_recall(
|
||||
preds: torch.Tensor,
|
||||
target: torch.Tensor,
|
||||
preds: Tensor,
|
||||
target: Tensor,
|
||||
average: str = "micro",
|
||||
mdmc_average: Optional[str] = None,
|
||||
ignore_index: Optional[int] = None,
|
||||
|
@ -329,7 +330,7 @@ def precision_recall(
|
|||
threshold: float = 0.5,
|
||||
top_k: Optional[int] = None,
|
||||
is_multiclass: Optional[bool] = None,
|
||||
) -> torch.Tensor:
|
||||
) -> Tensor:
|
||||
r"""
|
||||
Computes `Precision and Recall <https://en.wikipedia.org/wiki/Precision_and_recall>`_:
|
||||
|
||||
|
|
|
@ -15,21 +15,22 @@ from typing import List, Optional, Sequence, Tuple, Union
|
|||
|
||||
import torch
|
||||
import torch.nn.functional as F
|
||||
from torch import Tensor, tensor
|
||||
|
||||
from torchmetrics.utilities import rank_zero_warn
|
||||
|
||||
|
||||
def _binary_clf_curve(
|
||||
preds: torch.Tensor,
|
||||
target: torch.Tensor,
|
||||
preds: Tensor,
|
||||
target: Tensor,
|
||||
sample_weights: Optional[Sequence] = None,
|
||||
pos_label: int = 1.,
|
||||
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
|
||||
) -> Tuple[Tensor, Tensor, Tensor]:
|
||||
"""
|
||||
adapted from https://github.com/scikit-learn/scikit-learn/blob/master/sklearn/metrics/_ranking.py
|
||||
"""
|
||||
if sample_weights is not None and not isinstance(sample_weights, torch.Tensor):
|
||||
sample_weights = torch.tensor(sample_weights, device=preds.device, dtype=torch.float)
|
||||
if sample_weights is not None and not isinstance(sample_weights, Tensor):
|
||||
sample_weights = tensor(sample_weights, device=preds.device, dtype=torch.float)
|
||||
|
||||
# remove class dimension if necessary
|
||||
if preds.ndim > target.ndim:
|
||||
|
@ -63,11 +64,11 @@ def _binary_clf_curve(
|
|||
|
||||
|
||||
def _precision_recall_curve_update(
|
||||
preds: torch.Tensor,
|
||||
target: torch.Tensor,
|
||||
preds: Tensor,
|
||||
target: Tensor,
|
||||
num_classes: Optional[int] = None,
|
||||
pos_label: Optional[int] = None,
|
||||
) -> Tuple[torch.Tensor, torch.Tensor, int, int]:
|
||||
) -> Tuple[Tensor, Tensor, int, int]:
|
||||
if not (len(preds.shape) == len(target.shape) or len(preds.shape) == len(target.shape) + 1):
|
||||
raise ValueError("preds and target must have same number of dimensions, or one additional dimension for preds")
|
||||
# single class evaluation
|
||||
|
@ -99,13 +100,12 @@ def _precision_recall_curve_update(
|
|||
|
||||
|
||||
def _precision_recall_curve_compute(
|
||||
preds: torch.Tensor,
|
||||
target: torch.Tensor,
|
||||
preds: Tensor,
|
||||
target: Tensor,
|
||||
num_classes: int,
|
||||
pos_label: int,
|
||||
sample_weights: Optional[Sequence] = None,
|
||||
) -> Union[Tuple[torch.Tensor, torch.Tensor, torch.Tensor], Tuple[List[torch.Tensor], List[torch.Tensor],
|
||||
List[torch.Tensor]]]:
|
||||
) -> Union[Tuple[Tensor, Tensor, Tensor], Tuple[List[Tensor], List[Tensor], List[Tensor]]]:
|
||||
|
||||
if num_classes == 1:
|
||||
fps, tps, thresholds = _binary_clf_curve(
|
||||
|
@ -149,13 +149,12 @@ def _precision_recall_curve_compute(
|
|||
|
||||
|
||||
def precision_recall_curve(
|
||||
preds: torch.Tensor,
|
||||
target: torch.Tensor,
|
||||
preds: Tensor,
|
||||
target: Tensor,
|
||||
num_classes: Optional[int] = None,
|
||||
pos_label: Optional[int] = None,
|
||||
sample_weights: Optional[Sequence] = None,
|
||||
) -> Union[Tuple[torch.Tensor, torch.Tensor, torch.Tensor], Tuple[List[torch.Tensor], List[torch.Tensor],
|
||||
List[torch.Tensor]]]:
|
||||
) -> Union[Tuple[Tensor, Tensor, Tensor], Tuple[List[Tensor], List[Tensor], List[Tensor]]]:
|
||||
"""
|
||||
Computes precision-recall pairs for different thresholds.
|
||||
|
||||
|
|
|
@ -14,6 +14,7 @@
|
|||
from typing import List, Optional, Sequence, Tuple, Union
|
||||
|
||||
import torch
|
||||
from torch import Tensor
|
||||
|
||||
from torchmetrics.functional.classification.precision_recall_curve import (
|
||||
_binary_clf_curve,
|
||||
|
@ -22,22 +23,21 @@ from torchmetrics.functional.classification.precision_recall_curve import (
|
|||
|
||||
|
||||
def _roc_update(
|
||||
preds: torch.Tensor,
|
||||
target: torch.Tensor,
|
||||
preds: Tensor,
|
||||
target: Tensor,
|
||||
num_classes: Optional[int] = None,
|
||||
pos_label: Optional[int] = None,
|
||||
) -> Tuple[torch.Tensor, torch.Tensor, int, int]:
|
||||
) -> Tuple[Tensor, Tensor, int, int]:
|
||||
return _precision_recall_curve_update(preds, target, num_classes, pos_label)
|
||||
|
||||
|
||||
def _roc_compute(
|
||||
preds: torch.Tensor,
|
||||
target: torch.Tensor,
|
||||
preds: Tensor,
|
||||
target: Tensor,
|
||||
num_classes: int,
|
||||
pos_label: int,
|
||||
sample_weights: Optional[Sequence] = None,
|
||||
) -> Union[Tuple[torch.Tensor, torch.Tensor, torch.Tensor], Tuple[List[torch.Tensor], List[torch.Tensor],
|
||||
List[torch.Tensor]]]:
|
||||
) -> Union[Tuple[Tensor, Tensor, Tensor], Tuple[List[Tensor], List[Tensor], List[Tensor]]]:
|
||||
|
||||
if num_classes == 1:
|
||||
fps, tps, thresholds = _binary_clf_curve(
|
||||
|
@ -78,13 +78,12 @@ def _roc_compute(
|
|||
|
||||
|
||||
def roc(
|
||||
preds: torch.Tensor,
|
||||
target: torch.Tensor,
|
||||
preds: Tensor,
|
||||
target: Tensor,
|
||||
num_classes: Optional[int] = None,
|
||||
pos_label: Optional[int] = None,
|
||||
sample_weights: Optional[Sequence] = None,
|
||||
) -> Union[Tuple[torch.Tensor, torch.Tensor, torch.Tensor], Tuple[List[torch.Tensor], List[torch.Tensor],
|
||||
List[torch.Tensor]]]:
|
||||
) -> Union[Tuple[Tensor, Tensor, Tensor], Tuple[List[Tensor], List[Tensor], List[Tensor]]]:
|
||||
"""
|
||||
Computes the Receiver Operating Characteristic (ROC).
|
||||
|
||||
|
|
|
@ -14,21 +14,22 @@
|
|||
from typing import Optional, Tuple
|
||||
|
||||
import torch
|
||||
from torch import Tensor, tensor
|
||||
|
||||
from torchmetrics.utilities.checks import _input_format_classification
|
||||
|
||||
|
||||
def _del_column(tensor: torch.Tensor, index: int):
|
||||
def _del_column(tensor: Tensor, index: int):
|
||||
""" Delete the column at index."""
|
||||
|
||||
return torch.cat([tensor[:, :index], tensor[:, (index + 1):]], 1)
|
||||
|
||||
|
||||
def _stat_scores(
|
||||
preds: torch.Tensor,
|
||||
target: torch.Tensor,
|
||||
preds: Tensor,
|
||||
target: Tensor,
|
||||
reduce: str = "micro",
|
||||
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
|
||||
) -> Tuple[Tensor, Tensor, Tensor, Tensor]:
|
||||
"""Calculate the number of tp, fp, tn, fn.
|
||||
|
||||
Args:
|
||||
|
@ -74,8 +75,8 @@ def _stat_scores(
|
|||
|
||||
|
||||
def _stat_scores_update(
|
||||
preds: torch.Tensor,
|
||||
target: torch.Tensor,
|
||||
preds: Tensor,
|
||||
target: Tensor,
|
||||
reduce: str = "micro",
|
||||
mdmc_reduce: Optional[str] = None,
|
||||
num_classes: Optional[int] = None,
|
||||
|
@ -83,7 +84,7 @@ def _stat_scores_update(
|
|||
threshold: float = 0.5,
|
||||
is_multiclass: Optional[bool] = None,
|
||||
ignore_index: Optional[int] = None,
|
||||
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
|
||||
) -> Tuple[Tensor, Tensor, Tensor, Tensor]:
|
||||
|
||||
preds, target, _ = _input_format_classification(
|
||||
preds, target, threshold=threshold, num_classes=num_classes, is_multiclass=is_multiclass, top_k=top_k
|
||||
|
@ -121,7 +122,7 @@ def _stat_scores_update(
|
|||
return tp, fp, tn, fn
|
||||
|
||||
|
||||
def _stat_scores_compute(tp: torch.Tensor, fp: torch.Tensor, tn: torch.Tensor, fn: torch.Tensor) -> torch.Tensor:
|
||||
def _stat_scores_compute(tp: Tensor, fp: Tensor, tn: Tensor, fn: Tensor) -> Tensor:
|
||||
|
||||
outputs = [
|
||||
tp.unsqueeze(-1),
|
||||
|
@ -131,14 +132,14 @@ def _stat_scores_compute(tp: torch.Tensor, fp: torch.Tensor, tn: torch.Tensor, f
|
|||
tp.unsqueeze(-1) + fn.unsqueeze(-1), # support
|
||||
]
|
||||
outputs = torch.cat(outputs, -1)
|
||||
outputs = torch.where(outputs < 0, torch.tensor(-1, device=outputs.device), outputs)
|
||||
outputs = torch.where(outputs < 0, tensor(-1, device=outputs.device), outputs)
|
||||
|
||||
return outputs
|
||||
|
||||
|
||||
def stat_scores(
|
||||
preds: torch.Tensor,
|
||||
target: torch.Tensor,
|
||||
preds: Tensor,
|
||||
target: Tensor,
|
||||
reduce: str = "micro",
|
||||
mdmc_reduce: Optional[str] = None,
|
||||
num_classes: Optional[int] = None,
|
||||
|
@ -146,7 +147,7 @@ def stat_scores(
|
|||
threshold: float = 0.5,
|
||||
is_multiclass: Optional[bool] = None,
|
||||
ignore_index: Optional[int] = None,
|
||||
) -> torch.Tensor:
|
||||
) -> Tensor:
|
||||
"""Computes the number of true positives, false positives, true negatives, false negatives.
|
||||
Related to `Type I and Type II errors <https://en.wikipedia.org/wiki/Type_I_and_type_II_errors>`__
|
||||
and the `confusion matrix <https://en.wikipedia.org/wiki/Confusion_matrix#Table_of_confusion>`__.
|
||||
|
|
|
@ -14,18 +14,19 @@
|
|||
from typing import Tuple
|
||||
|
||||
import torch
|
||||
from torch import Tensor
|
||||
|
||||
|
||||
def _image_gradients_validate(img: torch.Tensor) -> torch.Tensor:
|
||||
def _image_gradients_validate(img: Tensor) -> Tensor:
|
||||
""" Validates whether img is a 4D torch Tensor """
|
||||
|
||||
if not isinstance(img, torch.Tensor):
|
||||
raise TypeError(f"The `img` expects a value of <torch.Tensor> type but got {type(img)}")
|
||||
if not isinstance(img, Tensor):
|
||||
raise TypeError(f"The `img` expects a value of <Tensor> type but got {type(img)}")
|
||||
if img.ndim != 4:
|
||||
raise RuntimeError(f"The `img` expects a 4D tensor but got {img.ndim}D tensor")
|
||||
|
||||
|
||||
def _compute_image_gradients(img: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
|
||||
def _compute_image_gradients(img: Tensor) -> Tuple[Tensor, Tensor]:
|
||||
""" Computes image gradients (dy/dx) for a given image """
|
||||
|
||||
batch_size, channels, height, width = img.shape
|
||||
|
@ -44,7 +45,7 @@ def _compute_image_gradients(img: torch.Tensor) -> Tuple[torch.Tensor, torch.Ten
|
|||
return dy, dx
|
||||
|
||||
|
||||
def image_gradients(img: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
|
||||
def image_gradients(img: Tensor) -> Tuple[Tensor, Tensor]:
|
||||
"""
|
||||
Computes the `gradients <https://en.wikipedia.org/wiki/Image_gradient>`_ of a given image using finite difference
|
||||
|
||||
|
|
|
@ -20,6 +20,7 @@ from collections import Counter
|
|||
from typing import List, Sequence
|
||||
|
||||
import torch
|
||||
from torch import Tensor, tensor
|
||||
|
||||
|
||||
def _count_ngram(ngram_input_list: List[str], n_gram: int) -> Counter:
|
||||
|
@ -45,11 +46,8 @@ def _count_ngram(ngram_input_list: List[str], n_gram: int) -> Counter:
|
|||
|
||||
|
||||
def bleu_score(
|
||||
translate_corpus: Sequence[str],
|
||||
reference_corpus: Sequence[str],
|
||||
n_gram: int = 4,
|
||||
smooth: bool = False
|
||||
) -> torch.Tensor:
|
||||
translate_corpus: Sequence[str], reference_corpus: Sequence[str], n_gram: int = 4, smooth: bool = False
|
||||
) -> Tensor:
|
||||
"""
|
||||
Calculate BLEU score of machine translated text with one or more references
|
||||
|
||||
|
@ -96,20 +94,20 @@ def bleu_score(
|
|||
for counter in translation_counter:
|
||||
denominator[len(counter) - 1] += translation_counter[counter]
|
||||
|
||||
trans_len = torch.tensor(c)
|
||||
ref_len = torch.tensor(r)
|
||||
trans_len = tensor(c)
|
||||
ref_len = tensor(r)
|
||||
|
||||
if min(numerator) == 0.0:
|
||||
return torch.tensor(0.0)
|
||||
return tensor(0.0)
|
||||
|
||||
if smooth:
|
||||
precision_scores = torch.add(numerator, torch.ones(n_gram)) / torch.add(denominator, torch.ones(n_gram))
|
||||
else:
|
||||
precision_scores = numerator / denominator
|
||||
|
||||
log_precision_scores = torch.tensor([1.0 / n_gram] * n_gram) * torch.log(precision_scores)
|
||||
log_precision_scores = tensor([1.0 / n_gram] * n_gram) * torch.log(precision_scores)
|
||||
geometric_mean = torch.exp(torch.sum(log_precision_scores))
|
||||
brevity_penalty = torch.tensor(1.0) if c > r else torch.exp(1 - (ref_len / trans_len))
|
||||
brevity_penalty = tensor(1.0) if c > r else torch.exp(1 - (ref_len / trans_len))
|
||||
bleu = brevity_penalty * geometric_mean
|
||||
|
||||
return bleu
|
||||
|
|
|
@ -14,38 +14,37 @@
|
|||
from typing import Sequence, Tuple, Union
|
||||
|
||||
import torch
|
||||
from torch import Tensor
|
||||
|
||||
from torchmetrics.utilities.checks import _check_same_shape
|
||||
|
||||
|
||||
def _explained_variance_update(
|
||||
preds: torch.Tensor, target: torch.Tensor
|
||||
) -> Tuple[int, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
|
||||
def _explained_variance_update(preds: Tensor, target: Tensor) -> Tuple[int, Tensor, Tensor, Tensor, Tensor]:
|
||||
_check_same_shape(preds, target)
|
||||
|
||||
n_obs = preds.size(0)
|
||||
sum_error = torch.sum(target - preds, dim=0)
|
||||
sum_squared_error = torch.sum((target - preds) ** 2, dim=0)
|
||||
sum_squared_error = torch.sum((target - preds)**2, dim=0)
|
||||
|
||||
sum_target = torch.sum(target, dim=0)
|
||||
sum_squared_target = torch.sum(target ** 2, dim=0)
|
||||
sum_squared_target = torch.sum(target**2, dim=0)
|
||||
|
||||
return n_obs, sum_error, sum_squared_error, sum_target, sum_squared_target
|
||||
|
||||
|
||||
def _explained_variance_compute(
|
||||
n_obs: torch.Tensor,
|
||||
sum_error: torch.Tensor,
|
||||
sum_squared_error: torch.Tensor,
|
||||
sum_target: torch.Tensor,
|
||||
sum_squared_target: torch.Tensor,
|
||||
n_obs: Tensor,
|
||||
sum_error: Tensor,
|
||||
sum_squared_error: Tensor,
|
||||
sum_target: Tensor,
|
||||
sum_squared_target: Tensor,
|
||||
multioutput: str = "uniform_average",
|
||||
) -> Union[torch.Tensor, Sequence[torch.Tensor]]:
|
||||
) -> Union[Tensor, Sequence[Tensor]]:
|
||||
diff_avg = sum_error / n_obs
|
||||
numerator = sum_squared_error / n_obs - diff_avg ** 2
|
||||
numerator = sum_squared_error / n_obs - diff_avg**2
|
||||
|
||||
target_avg = sum_target / n_obs
|
||||
denominator = sum_squared_target / n_obs - target_avg ** 2
|
||||
denominator = sum_squared_target / n_obs - target_avg**2
|
||||
|
||||
# Take care of division by zero
|
||||
nonzero_numerator = numerator != 0
|
||||
|
@ -67,10 +66,10 @@ def _explained_variance_compute(
|
|||
|
||||
|
||||
def explained_variance(
|
||||
preds: torch.Tensor,
|
||||
target: torch.Tensor,
|
||||
preds: Tensor,
|
||||
target: Tensor,
|
||||
multioutput: str = "uniform_average",
|
||||
) -> Union[torch.Tensor, Sequence[torch.Tensor]]:
|
||||
) -> Union[Tensor, Sequence[Tensor]]:
|
||||
"""
|
||||
Computes explained variance.
|
||||
|
||||
|
|
|
@ -14,22 +14,23 @@
|
|||
from typing import Tuple
|
||||
|
||||
import torch
|
||||
from torch import Tensor
|
||||
|
||||
from torchmetrics.utilities.checks import _check_same_shape
|
||||
|
||||
|
||||
def _mean_absolute_error_update(preds: torch.Tensor, target: torch.Tensor) -> Tuple[torch.Tensor, int]:
|
||||
def _mean_absolute_error_update(preds: Tensor, target: Tensor) -> Tuple[Tensor, int]:
|
||||
_check_same_shape(preds, target)
|
||||
sum_abs_error = torch.sum(torch.abs(preds - target))
|
||||
n_obs = target.numel()
|
||||
return sum_abs_error, n_obs
|
||||
|
||||
|
||||
def _mean_absolute_error_compute(sum_abs_error: torch.Tensor, n_obs: int) -> torch.Tensor:
|
||||
def _mean_absolute_error_compute(sum_abs_error: Tensor, n_obs: int) -> Tensor:
|
||||
return sum_abs_error / n_obs
|
||||
|
||||
|
||||
def mean_absolute_error(preds: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
|
||||
def mean_absolute_error(preds: Tensor, target: Tensor) -> Tensor:
|
||||
"""
|
||||
Computes mean absolute error
|
||||
|
||||
|
|
|
@ -14,11 +14,12 @@
|
|||
from typing import Tuple
|
||||
|
||||
import torch
|
||||
from torch import Tensor
|
||||
|
||||
from torchmetrics.utilities.checks import _check_same_shape
|
||||
|
||||
|
||||
def _mean_relative_error_update(preds: torch.Tensor, target: torch.Tensor) -> Tuple[torch.Tensor, int]:
|
||||
def _mean_relative_error_update(preds: Tensor, target: Tensor) -> Tuple[Tensor, int]:
|
||||
_check_same_shape(preds, target)
|
||||
target_nz = target.clone()
|
||||
target_nz[target == 0] = 1
|
||||
|
@ -27,11 +28,11 @@ def _mean_relative_error_update(preds: torch.Tensor, target: torch.Tensor) -> Tu
|
|||
return sum_rltv_error, n_obs
|
||||
|
||||
|
||||
def _mean_relative_error_compute(sum_rltv_error: torch.Tensor, n_obs: int) -> torch.Tensor:
|
||||
def _mean_relative_error_compute(sum_rltv_error: Tensor, n_obs: int) -> Tensor:
|
||||
return sum_rltv_error / n_obs
|
||||
|
||||
|
||||
def mean_relative_error(preds: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
|
||||
def mean_relative_error(preds: Tensor, target: Tensor) -> Tensor:
|
||||
"""
|
||||
Computes mean relative error
|
||||
|
||||
|
|
|
@ -14,22 +14,23 @@
|
|||
from typing import Tuple
|
||||
|
||||
import torch
|
||||
from torch import Tensor
|
||||
|
||||
from torchmetrics.utilities.checks import _check_same_shape
|
||||
|
||||
|
||||
def _mean_squared_error_update(preds: torch.Tensor, target: torch.Tensor) -> Tuple[torch.Tensor, int]:
|
||||
def _mean_squared_error_update(preds: Tensor, target: Tensor) -> Tuple[Tensor, int]:
|
||||
_check_same_shape(preds, target)
|
||||
sum_squared_error = torch.sum(torch.pow(preds - target, 2))
|
||||
n_obs = target.numel()
|
||||
return sum_squared_error, n_obs
|
||||
|
||||
|
||||
def _mean_squared_error_compute(sum_squared_error: torch.Tensor, n_obs: int) -> torch.Tensor:
|
||||
def _mean_squared_error_compute(sum_squared_error: Tensor, n_obs: int) -> Tensor:
|
||||
return sum_squared_error / n_obs
|
||||
|
||||
|
||||
def mean_squared_error(preds: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
|
||||
def mean_squared_error(preds: Tensor, target: Tensor) -> Tensor:
|
||||
"""
|
||||
Computes mean squared error
|
||||
|
||||
|
|
|
@ -14,22 +14,23 @@
|
|||
from typing import Tuple
|
||||
|
||||
import torch
|
||||
from torch import Tensor
|
||||
|
||||
from torchmetrics.utilities.checks import _check_same_shape
|
||||
|
||||
|
||||
def _mean_squared_log_error_update(preds: torch.Tensor, target: torch.Tensor) -> Tuple[torch.Tensor, int]:
|
||||
def _mean_squared_log_error_update(preds: Tensor, target: Tensor) -> Tuple[Tensor, int]:
|
||||
_check_same_shape(preds, target)
|
||||
sum_squared_log_error = torch.sum(torch.pow(torch.log1p(preds) - torch.log1p(target), 2))
|
||||
n_obs = target.numel()
|
||||
return sum_squared_log_error, n_obs
|
||||
|
||||
|
||||
def _mean_squared_log_error_compute(sum_squared_log_error: torch.Tensor, n_obs: int) -> torch.Tensor:
|
||||
def _mean_squared_log_error_compute(sum_squared_log_error: Tensor, n_obs: int) -> Tensor:
|
||||
return sum_squared_log_error / n_obs
|
||||
|
||||
|
||||
def mean_squared_log_error(preds: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
|
||||
def mean_squared_log_error(preds: Tensor, target: Tensor) -> Tensor:
|
||||
"""
|
||||
Computes mean squared log error
|
||||
|
||||
|
|
|
@ -14,28 +14,31 @@
|
|||
from typing import Optional, Tuple, Union
|
||||
|
||||
import torch
|
||||
from torch import Tensor, tensor
|
||||
|
||||
from torchmetrics.utilities import rank_zero_warn, reduce
|
||||
|
||||
|
||||
def _psnr_compute(
|
||||
sum_squared_error: torch.Tensor,
|
||||
n_obs: torch.Tensor,
|
||||
data_range: torch.Tensor,
|
||||
sum_squared_error: Tensor,
|
||||
n_obs: Tensor,
|
||||
data_range: Tensor,
|
||||
base: float = 10.0,
|
||||
reduction: str = 'elementwise_mean',
|
||||
) -> torch.Tensor:
|
||||
) -> Tensor:
|
||||
psnr_base_e = 2 * torch.log(data_range) - torch.log(sum_squared_error / n_obs)
|
||||
psnr = psnr_base_e * (10 / torch.log(torch.tensor(base)))
|
||||
psnr = psnr_base_e * (10 / torch.log(tensor(base)))
|
||||
return reduce(psnr, reduction=reduction)
|
||||
|
||||
|
||||
def _psnr_update(preds: torch.Tensor,
|
||||
target: torch.Tensor,
|
||||
dim: Optional[Union[int, Tuple[int, ...]]] = None) -> Tuple[torch.Tensor, torch.Tensor]:
|
||||
def _psnr_update(
|
||||
preds: Tensor,
|
||||
target: Tensor,
|
||||
dim: Optional[Union[int, Tuple[int, ...]]] = None,
|
||||
) -> Tuple[Tensor, Tensor]:
|
||||
if dim is None:
|
||||
sum_squared_error = torch.sum(torch.pow(preds - target, 2))
|
||||
n_obs = torch.tensor(target.numel(), device=target.device)
|
||||
n_obs = tensor(target.numel(), device=target.device)
|
||||
return sum_squared_error, n_obs
|
||||
|
||||
sum_squared_error = torch.sum(torch.pow(preds - target, 2), dim=dim)
|
||||
|
@ -45,22 +48,22 @@ def _psnr_update(preds: torch.Tensor,
|
|||
else:
|
||||
dim_list = list(dim)
|
||||
if not dim_list:
|
||||
n_obs = torch.tensor(target.numel(), device=target.device)
|
||||
n_obs = tensor(target.numel(), device=target.device)
|
||||
else:
|
||||
n_obs = torch.tensor(target.size(), device=target.device)[dim_list].prod()
|
||||
n_obs = tensor(target.size(), device=target.device)[dim_list].prod()
|
||||
n_obs = n_obs.expand_as(sum_squared_error)
|
||||
|
||||
return sum_squared_error, n_obs
|
||||
|
||||
|
||||
def psnr(
|
||||
preds: torch.Tensor,
|
||||
target: torch.Tensor,
|
||||
preds: Tensor,
|
||||
target: Tensor,
|
||||
data_range: Optional[float] = None,
|
||||
base: float = 10.0,
|
||||
reduction: str = 'elementwise_mean',
|
||||
dim: Optional[Union[int, Tuple[int, ...]]] = None,
|
||||
) -> torch.Tensor:
|
||||
) -> Tensor:
|
||||
"""
|
||||
Computes the peak signal-to-noise ratio
|
||||
|
||||
|
@ -102,6 +105,6 @@ def psnr(
|
|||
|
||||
data_range = target.max() - target.min()
|
||||
else:
|
||||
data_range = torch.tensor(float(data_range))
|
||||
data_range = tensor(float(data_range))
|
||||
sum_squared_error, n_obs = _psnr_update(preds, target, dim=dim)
|
||||
return _psnr_compute(sum_squared_error, n_obs, data_range, base=base, reduction=reduction)
|
||||
|
|
|
@ -14,15 +14,13 @@
|
|||
from typing import Tuple
|
||||
|
||||
import torch
|
||||
from torch import Tensor
|
||||
|
||||
from torchmetrics.utilities import rank_zero_warn
|
||||
from torchmetrics.utilities.checks import _check_same_shape
|
||||
|
||||
|
||||
def _r2score_update(
|
||||
preds: torch.tensor,
|
||||
target: torch.Tensor,
|
||||
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
|
||||
def _r2score_update(preds: Tensor, target: Tensor) -> Tuple[Tensor, Tensor, Tensor, Tensor]:
|
||||
_check_same_shape(preds, target)
|
||||
if preds.ndim > 2:
|
||||
raise ValueError(
|
||||
|
@ -41,13 +39,13 @@ def _r2score_update(
|
|||
|
||||
|
||||
def _r2score_compute(
|
||||
sum_squared_error: torch.Tensor,
|
||||
sum_error: torch.Tensor,
|
||||
residual: torch.Tensor,
|
||||
total: torch.Tensor,
|
||||
sum_squared_error: Tensor,
|
||||
sum_error: Tensor,
|
||||
residual: Tensor,
|
||||
total: Tensor,
|
||||
adjusted: int = 0,
|
||||
multioutput: str = "uniform_average"
|
||||
) -> torch.Tensor:
|
||||
multioutput: str = "uniform_average",
|
||||
) -> Tensor:
|
||||
mean_error = sum_error / total
|
||||
diff = sum_squared_error - sum_error * mean_error
|
||||
raw_scores = 1 - (residual / diff)
|
||||
|
@ -82,11 +80,11 @@ def _r2score_compute(
|
|||
|
||||
|
||||
def r2score(
|
||||
preds: torch.Tensor,
|
||||
target: torch.Tensor,
|
||||
preds: Tensor,
|
||||
target: Tensor,
|
||||
adjusted: int = 0,
|
||||
multioutput: str = "uniform_average",
|
||||
) -> torch.Tensor:
|
||||
) -> Tensor:
|
||||
r"""
|
||||
Computes r2 score also known as `coefficient of determination
|
||||
<https://en.wikipedia.org/wiki/Coefficient_of_determination>`_:
|
||||
|
|
|
@ -14,6 +14,7 @@
|
|||
from typing import Optional, Sequence, Tuple
|
||||
|
||||
import torch
|
||||
from torch import Tensor
|
||||
from torch.nn import functional as F
|
||||
|
||||
from torchmetrics.utilities.checks import _check_same_shape
|
||||
|
@ -36,10 +37,7 @@ def _gaussian_kernel(
|
|||
return kernel.expand(channel, 1, kernel_size[0], kernel_size[1])
|
||||
|
||||
|
||||
def _ssim_update(
|
||||
preds: torch.Tensor,
|
||||
target: torch.Tensor,
|
||||
) -> Tuple[torch.Tensor, torch.Tensor]:
|
||||
def _ssim_update(preds: Tensor, target: Tensor) -> Tuple[Tensor, Tensor]:
|
||||
if preds.dtype != target.dtype:
|
||||
raise TypeError(
|
||||
"Expected `preds` and `target` to have the same data type."
|
||||
|
@ -55,8 +53,8 @@ def _ssim_update(
|
|||
|
||||
|
||||
def _ssim_compute(
|
||||
preds: torch.Tensor,
|
||||
target: torch.Tensor,
|
||||
preds: Tensor,
|
||||
target: Tensor,
|
||||
kernel_size: Sequence[int] = (11, 11),
|
||||
sigma: Sequence[float] = (1.5, 1.5),
|
||||
reduction: str = "elementwise_mean",
|
||||
|
@ -114,15 +112,15 @@ def _ssim_compute(
|
|||
|
||||
|
||||
def ssim(
|
||||
preds: torch.Tensor,
|
||||
target: torch.Tensor,
|
||||
preds: Tensor,
|
||||
target: Tensor,
|
||||
kernel_size: Sequence[int] = (11, 11),
|
||||
sigma: Sequence[float] = (1.5, 1.5),
|
||||
reduction: str = "elementwise_mean",
|
||||
data_range: Optional[float] = None,
|
||||
k1: float = 0.01,
|
||||
k2: float = 0.03,
|
||||
) -> torch.Tensor:
|
||||
) -> Tensor:
|
||||
"""
|
||||
Computes Structual Similarity Index Measure
|
||||
|
||||
|
|
|
@ -12,14 +12,12 @@
|
|||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
import torch
|
||||
from torch import Tensor
|
||||
|
||||
|
||||
def embedding_similarity(
|
||||
batch: torch.Tensor,
|
||||
similarity: str = 'cosine',
|
||||
reduction: str = 'none',
|
||||
zero_diagonal: bool = True
|
||||
) -> torch.Tensor:
|
||||
batch: Tensor, similarity: str = 'cosine', reduction: str = 'none', zero_diagonal: bool = True
|
||||
) -> Tensor:
|
||||
"""
|
||||
Computes representation similarity
|
||||
|
||||
|
|
|
@ -19,7 +19,7 @@ from copy import deepcopy
|
|||
from typing import Any, Callable, Optional, Union
|
||||
|
||||
import torch
|
||||
from torch import nn
|
||||
from torch import Tensor, nn
|
||||
|
||||
from torchmetrics.utilities import apply_to_collection
|
||||
from torchmetrics.utilities.data import _flatten, dim_zero_cat, dim_zero_mean, dim_zero_sum
|
||||
|
@ -122,7 +122,7 @@ class Metric(nn.Module, ABC):
|
|||
|
||||
"""
|
||||
if (
|
||||
not isinstance(default, torch.Tensor) and not isinstance(default, list) # noqa: W503
|
||||
not isinstance(default, Tensor) and not isinstance(default, list) # noqa: W503
|
||||
or (isinstance(default, list) and len(default) != 0) # noqa: W503
|
||||
):
|
||||
raise ValueError("state variable must be a tensor or any empty list (where you can append tensors)")
|
||||
|
@ -175,14 +175,14 @@ class Metric(nn.Module, ABC):
|
|||
input_dict = {attr: getattr(self, attr) for attr in self._reductions.keys()}
|
||||
output_dict = apply_to_collection(
|
||||
input_dict,
|
||||
torch.Tensor,
|
||||
Tensor,
|
||||
dist_sync_fn,
|
||||
group=self.process_group,
|
||||
)
|
||||
|
||||
for attr, reduction_fn in self._reductions.items():
|
||||
# pre-processing ops (stack or flatten for inputs)
|
||||
if isinstance(output_dict[attr][0], torch.Tensor):
|
||||
if isinstance(output_dict[attr][0], Tensor):
|
||||
output_dict[attr] = torch.stack(output_dict[attr])
|
||||
elif isinstance(output_dict[attr][0], list):
|
||||
output_dict[attr] = _flatten(output_dict[attr])
|
||||
|
@ -253,7 +253,7 @@ class Metric(nn.Module, ABC):
|
|||
"""
|
||||
for attr, default in self._defaults.items():
|
||||
current_val = getattr(self, attr)
|
||||
if isinstance(default, torch.Tensor):
|
||||
if isinstance(default, Tensor):
|
||||
setattr(self, attr, deepcopy(default).to(current_val.device))
|
||||
else:
|
||||
setattr(self, attr, deepcopy(default))
|
||||
|
@ -280,14 +280,14 @@ class Metric(nn.Module, ABC):
|
|||
# Also apply fn to metric states
|
||||
for key in self._defaults.keys():
|
||||
current_val = getattr(self, key)
|
||||
if isinstance(current_val, torch.Tensor):
|
||||
if isinstance(current_val, Tensor):
|
||||
setattr(self, key, fn(current_val))
|
||||
elif isinstance(current_val, Sequence):
|
||||
setattr(self, key, [fn(cur_v) for cur_v in current_val])
|
||||
else:
|
||||
raise TypeError(
|
||||
"Expected metric state to be either a torch.Tensor"
|
||||
f"or a list of torch.Tensor, but encountered {current_val}"
|
||||
"Expected metric state to be either a Tensor"
|
||||
f"or a list of Tensor, but encountered {current_val}"
|
||||
)
|
||||
return self
|
||||
|
||||
|
@ -336,7 +336,7 @@ class Metric(nn.Module, ABC):
|
|||
val = getattr(self, key)
|
||||
# Special case: allow list values, so long
|
||||
# as their elements are hashable
|
||||
if hasattr(val, '__iter__') and not isinstance(val, torch.Tensor):
|
||||
if hasattr(val, '__iter__') and not isinstance(val, Tensor):
|
||||
hash_vals.extend(val)
|
||||
else:
|
||||
hash_vals.append(val)
|
||||
|
@ -444,7 +444,7 @@ class Metric(nn.Module, ABC):
|
|||
return CompositionalMetric(torch.abs, self, None)
|
||||
|
||||
|
||||
def _neg(tensor: torch.Tensor):
|
||||
def _neg(tensor: Tensor):
|
||||
return -torch.abs(tensor)
|
||||
|
||||
|
||||
|
@ -454,8 +454,8 @@ class CompositionalMetric(Metric):
|
|||
def __init__(
|
||||
self,
|
||||
operator: Callable,
|
||||
metric_a: Union[Metric, int, float, torch.Tensor],
|
||||
metric_b: Union[Metric, int, float, torch.Tensor, None],
|
||||
metric_a: Union[Metric, int, float, Tensor],
|
||||
metric_b: Union[Metric, int, float, Tensor, None],
|
||||
):
|
||||
"""
|
||||
Args:
|
||||
|
@ -470,12 +470,12 @@ class CompositionalMetric(Metric):
|
|||
|
||||
self.op = operator
|
||||
|
||||
if isinstance(metric_a, torch.Tensor):
|
||||
if isinstance(metric_a, Tensor):
|
||||
self.register_buffer("metric_a", metric_a)
|
||||
else:
|
||||
self.metric_a = metric_a
|
||||
|
||||
if isinstance(metric_b, torch.Tensor):
|
||||
if isinstance(metric_b, Tensor):
|
||||
self.register_buffer("metric_b", metric_b)
|
||||
else:
|
||||
self.metric_b = metric_b
|
||||
|
|
|
@ -14,6 +14,7 @@
|
|||
from typing import Any, Callable, Optional
|
||||
|
||||
import torch
|
||||
from torch import Tensor, tensor
|
||||
|
||||
from torchmetrics.functional.regression.explained_variance import (
|
||||
_explained_variance_compute,
|
||||
|
@ -94,13 +95,13 @@ class ExplainedVariance(Metric):
|
|||
f"Invalid input to argument `multioutput`. Choose one of the following: {allowed_multioutput}"
|
||||
)
|
||||
self.multioutput = multioutput
|
||||
self.add_state("sum_error", default=torch.tensor(0.0), dist_reduce_fx="sum")
|
||||
self.add_state("sum_squared_error", default=torch.tensor(0.0), dist_reduce_fx="sum")
|
||||
self.add_state("sum_target", default=torch.tensor(0.0), dist_reduce_fx="sum")
|
||||
self.add_state("sum_squared_target", default=torch.tensor(0.0), dist_reduce_fx="sum")
|
||||
self.add_state("n_obs", default=torch.tensor(0.0), dist_reduce_fx="sum")
|
||||
self.add_state("sum_error", default=tensor(0.0), dist_reduce_fx="sum")
|
||||
self.add_state("sum_squared_error", default=tensor(0.0), dist_reduce_fx="sum")
|
||||
self.add_state("sum_target", default=tensor(0.0), dist_reduce_fx="sum")
|
||||
self.add_state("sum_squared_target", default=tensor(0.0), dist_reduce_fx="sum")
|
||||
self.add_state("n_obs", default=tensor(0.0), dist_reduce_fx="sum")
|
||||
|
||||
def update(self, preds: torch.Tensor, target: torch.Tensor):
|
||||
def update(self, preds: Tensor, target: Tensor):
|
||||
"""
|
||||
Update state with predictions and targets.
|
||||
|
||||
|
|
|
@ -14,6 +14,7 @@
|
|||
from typing import Any, Callable, Optional
|
||||
|
||||
import torch
|
||||
from torch import Tensor, tensor
|
||||
|
||||
from torchmetrics.functional.regression.mean_absolute_error import (
|
||||
_mean_absolute_error_compute,
|
||||
|
@ -63,10 +64,10 @@ class MeanAbsoluteError(Metric):
|
|||
dist_sync_fn=dist_sync_fn,
|
||||
)
|
||||
|
||||
self.add_state("sum_abs_error", default=torch.tensor(0.0), dist_reduce_fx="sum")
|
||||
self.add_state("total", default=torch.tensor(0), dist_reduce_fx="sum")
|
||||
self.add_state("sum_abs_error", default=tensor(0.0), dist_reduce_fx="sum")
|
||||
self.add_state("total", default=tensor(0), dist_reduce_fx="sum")
|
||||
|
||||
def update(self, preds: torch.Tensor, target: torch.Tensor):
|
||||
def update(self, preds: Tensor, target: Tensor):
|
||||
"""
|
||||
Update state with predictions and targets.
|
||||
|
||||
|
|
|
@ -14,6 +14,7 @@
|
|||
from typing import Any, Callable, Optional
|
||||
|
||||
import torch
|
||||
from torch import Tensor
|
||||
|
||||
from torchmetrics.functional.regression.mean_squared_error import (
|
||||
_mean_squared_error_compute,
|
||||
|
@ -67,7 +68,7 @@ class MeanSquaredError(Metric):
|
|||
self.add_state("sum_squared_error", default=torch.tensor(0.0), dist_reduce_fx="sum")
|
||||
self.add_state("total", default=torch.tensor(0), dist_reduce_fx="sum")
|
||||
|
||||
def update(self, preds: torch.Tensor, target: torch.Tensor):
|
||||
def update(self, preds: Tensor, target: Tensor):
|
||||
"""
|
||||
Update state with predictions and targets.
|
||||
|
||||
|
|
|
@ -14,6 +14,7 @@
|
|||
from typing import Any, Callable, Optional
|
||||
|
||||
import torch
|
||||
from torch import Tensor, tensor
|
||||
|
||||
from torchmetrics.functional.regression.mean_squared_log_error import (
|
||||
_mean_squared_log_error_compute,
|
||||
|
@ -66,10 +67,10 @@ class MeanSquaredLogError(Metric):
|
|||
dist_sync_fn=dist_sync_fn,
|
||||
)
|
||||
|
||||
self.add_state("sum_squared_log_error", default=torch.tensor(0.0), dist_reduce_fx="sum")
|
||||
self.add_state("total", default=torch.tensor(0), dist_reduce_fx="sum")
|
||||
self.add_state("sum_squared_log_error", default=tensor(0.0), dist_reduce_fx="sum")
|
||||
self.add_state("total", default=tensor(0), dist_reduce_fx="sum")
|
||||
|
||||
def update(self, preds: torch.Tensor, target: torch.Tensor):
|
||||
def update(self, preds: Tensor, target: Tensor):
|
||||
"""
|
||||
Update state with predictions and targets.
|
||||
|
||||
|
|
|
@ -14,6 +14,7 @@
|
|||
from typing import Any, Optional, Sequence, Tuple, Union
|
||||
|
||||
import torch
|
||||
from torch import Tensor
|
||||
|
||||
from torchmetrics.functional.regression.psnr import _psnr_compute, _psnr_update
|
||||
from torchmetrics.metric import Metric
|
||||
|
@ -103,7 +104,7 @@ class PSNR(Metric):
|
|||
self.reduction = reduction
|
||||
self.dim = tuple(dim) if isinstance(dim, Sequence) else dim
|
||||
|
||||
def update(self, preds: torch.Tensor, target: torch.Tensor):
|
||||
def update(self, preds: Tensor, target: Tensor):
|
||||
"""
|
||||
Update state with predictions and targets.
|
||||
|
||||
|
|
|
@ -14,6 +14,7 @@
|
|||
from typing import Any, Callable, Optional
|
||||
|
||||
import torch
|
||||
from torch import Tensor, tensor
|
||||
|
||||
from torchmetrics.functional.regression.r2score import _r2score_compute, _r2score_update
|
||||
from torchmetrics.metric import Metric
|
||||
|
@ -115,9 +116,9 @@ class R2Score(Metric):
|
|||
self.add_state("sum_squared_error", default=torch.zeros(self.num_outputs), dist_reduce_fx="sum")
|
||||
self.add_state("sum_error", default=torch.zeros(self.num_outputs), dist_reduce_fx="sum")
|
||||
self.add_state("residual", default=torch.zeros(self.num_outputs), dist_reduce_fx="sum")
|
||||
self.add_state("total", default=torch.tensor(0), dist_reduce_fx="sum")
|
||||
self.add_state("total", default=tensor(0), dist_reduce_fx="sum")
|
||||
|
||||
def update(self, preds: torch.Tensor, target: torch.Tensor):
|
||||
def update(self, preds: Tensor, target: Tensor):
|
||||
"""
|
||||
Update state with predictions and targets.
|
||||
|
||||
|
@ -132,7 +133,7 @@ class R2Score(Metric):
|
|||
self.residual += residual
|
||||
self.total += total
|
||||
|
||||
def compute(self) -> torch.Tensor:
|
||||
def compute(self) -> Tensor:
|
||||
"""
|
||||
Computes r2 score over the metric states.
|
||||
"""
|
||||
|
|
|
@ -14,6 +14,7 @@
|
|||
from typing import Any, Optional, Sequence
|
||||
|
||||
import torch
|
||||
from torch import Tensor
|
||||
|
||||
from torchmetrics.functional.regression.ssim import _ssim_compute, _ssim_update
|
||||
from torchmetrics.metric import Metric
|
||||
|
@ -82,7 +83,7 @@ class SSIM(Metric):
|
|||
self.k2 = k2
|
||||
self.reduction = reduction
|
||||
|
||||
def update(self, preds: torch.Tensor, target: torch.Tensor):
|
||||
def update(self, preds: Tensor, target: Tensor):
|
||||
"""
|
||||
Update state with predictions and targets.
|
||||
|
||||
|
|
|
@ -14,18 +14,19 @@
|
|||
from typing import Optional, Tuple
|
||||
|
||||
import torch
|
||||
from torch import Tensor
|
||||
|
||||
from torchmetrics.utilities.data import select_topk, to_onehot
|
||||
from torchmetrics.utilities.enums import DataType
|
||||
|
||||
|
||||
def _check_same_shape(pred: torch.Tensor, target: torch.Tensor):
|
||||
def _check_same_shape(pred: Tensor, target: Tensor):
|
||||
""" Check that predictions and target have the same shape, else raise error """
|
||||
if pred.shape != target.shape:
|
||||
raise RuntimeError("Predictions and targets are expected to have the same shape")
|
||||
|
||||
|
||||
def _basic_input_validation(preds: torch.Tensor, target: torch.Tensor, threshold: float, is_multiclass: bool):
|
||||
def _basic_input_validation(preds: Tensor, target: Tensor, threshold: float, is_multiclass: bool):
|
||||
"""
|
||||
Perform basic validation of inputs that does not require deducing any information
|
||||
of the type of inputs.
|
||||
|
@ -56,7 +57,7 @@ def _basic_input_validation(preds: torch.Tensor, target: torch.Tensor, threshold
|
|||
raise ValueError("If you set `is_multiclass=False` and `preds` are integers, then `preds` should not exceed 1.")
|
||||
|
||||
|
||||
def _check_shape_and_type_consistency(preds: torch.Tensor, target: torch.Tensor) -> Tuple[str, int]:
|
||||
def _check_shape_and_type_consistency(preds: Tensor, target: Tensor) -> Tuple[str, int]:
|
||||
"""
|
||||
This checks that the shape and type of inputs are consistent with
|
||||
each other and fall into one of the allowed input types (see the
|
||||
|
@ -139,9 +140,7 @@ def _check_num_classes_binary(num_classes: int, is_multiclass: bool):
|
|||
)
|
||||
|
||||
|
||||
def _check_num_classes_mc(
|
||||
preds: torch.Tensor, target: torch.Tensor, num_classes: int, is_multiclass: bool, implied_classes: int
|
||||
):
|
||||
def _check_num_classes_mc(preds: Tensor, target: Tensor, num_classes: int, is_multiclass: bool, implied_classes: int):
|
||||
"""
|
||||
This checks that the consistency of `num_classes` with the data
|
||||
and `is_multiclass` param for (multi-dimensional) multi-class data.
|
||||
|
@ -206,8 +205,8 @@ def _check_top_k(top_k: int, case: str, implied_classes: int, is_multiclass: Opt
|
|||
|
||||
|
||||
def _check_classification_inputs(
|
||||
preds: torch.Tensor,
|
||||
target: torch.Tensor,
|
||||
preds: Tensor,
|
||||
target: Tensor,
|
||||
threshold: float,
|
||||
num_classes: Optional[int],
|
||||
is_multiclass: bool,
|
||||
|
@ -305,13 +304,13 @@ def _check_classification_inputs(
|
|||
|
||||
|
||||
def _input_format_classification(
|
||||
preds: torch.Tensor,
|
||||
target: torch.Tensor,
|
||||
preds: Tensor,
|
||||
target: Tensor,
|
||||
threshold: float = 0.5,
|
||||
top_k: Optional[int] = None,
|
||||
num_classes: Optional[int] = None,
|
||||
is_multiclass: Optional[bool] = None,
|
||||
) -> Tuple[torch.Tensor, torch.Tensor, str]:
|
||||
) -> Tuple[Tensor, Tensor, str]:
|
||||
"""Convert preds and target tensors into common format.
|
||||
|
||||
Preds and targets are supposed to fall into one of these categories (and are
|
||||
|
@ -448,11 +447,11 @@ def _input_format_classification(
|
|||
|
||||
def _input_format_classification_one_hot(
|
||||
num_classes: int,
|
||||
preds: torch.Tensor,
|
||||
target: torch.Tensor,
|
||||
preds: Tensor,
|
||||
target: Tensor,
|
||||
threshold: float = 0.5,
|
||||
multilabel: bool = False,
|
||||
) -> Tuple[torch.Tensor, torch.Tensor]:
|
||||
) -> Tuple[Tensor, Tensor]:
|
||||
"""Convert preds and target tensors into one hot spare label tensors
|
||||
|
||||
Args:
|
||||
|
|
|
@ -14,6 +14,7 @@
|
|||
from typing import Any, Callable, Mapping, Optional, Sequence, Union
|
||||
|
||||
import torch
|
||||
from torch import Tensor
|
||||
|
||||
from torchmetrics.utilities.prints import rank_zero_warn
|
||||
|
||||
|
@ -38,9 +39,9 @@ def _flatten(x):
|
|||
|
||||
|
||||
def to_onehot(
|
||||
label_tensor: torch.Tensor,
|
||||
label_tensor: Tensor,
|
||||
num_classes: Optional[int] = None,
|
||||
) -> torch.Tensor:
|
||||
) -> Tensor:
|
||||
"""
|
||||
Converts a dense label tensor to one-hot format
|
||||
|
||||
|
@ -74,7 +75,7 @@ def to_onehot(
|
|||
return tensor_onehot.scatter_(1, index, 1.0)
|
||||
|
||||
|
||||
def select_topk(prob_tensor: torch.Tensor, topk: int = 1, dim: int = 1) -> torch.Tensor:
|
||||
def select_topk(prob_tensor: Tensor, topk: int = 1, dim: int = 1) -> Tensor:
|
||||
"""
|
||||
Convert a probability tensor to binary by selecting top-k highest entries.
|
||||
|
||||
|
@ -99,7 +100,7 @@ def select_topk(prob_tensor: torch.Tensor, topk: int = 1, dim: int = 1) -> torch
|
|||
return topk_tensor.int()
|
||||
|
||||
|
||||
def to_categorical(tensor: torch.Tensor, argmax_dim: int = 1) -> torch.Tensor:
|
||||
def to_categorical(tensor: Tensor, argmax_dim: int = 1) -> Tensor:
|
||||
"""
|
||||
Converts a tensor of probabilities to a dense label tensor
|
||||
|
||||
|
@ -121,8 +122,8 @@ def to_categorical(tensor: torch.Tensor, argmax_dim: int = 1) -> torch.Tensor:
|
|||
|
||||
|
||||
def get_num_classes(
|
||||
pred: torch.Tensor,
|
||||
target: torch.Tensor,
|
||||
pred: Tensor,
|
||||
target: Tensor,
|
||||
num_classes: Optional[int] = None,
|
||||
) -> int:
|
||||
"""
|
||||
|
@ -202,7 +203,7 @@ def apply_to_collection(
|
|||
the resulting collection
|
||||
|
||||
Example:
|
||||
>>> apply_to_collection(torch.tensor([8, 0, 2, 6, 7]), dtype=torch.Tensor, function=lambda x: x ** 2)
|
||||
>>> apply_to_collection(torch.tensor([8, 0, 2, 6, 7]), dtype=Tensor, function=lambda x: x ** 2)
|
||||
tensor([64, 0, 4, 36, 49])
|
||||
>>> apply_to_collection([8, 0, 2, 6, 7], dtype=int, function=lambda x: x ** 2)
|
||||
[64, 0, 4, 36, 49]
|
||||
|
|
|
@ -14,9 +14,10 @@
|
|||
from typing import Any, Optional, Union
|
||||
|
||||
import torch
|
||||
from torch import Tensor
|
||||
|
||||
|
||||
def reduce(to_reduce: torch.Tensor, reduction: str) -> torch.Tensor:
|
||||
def reduce(to_reduce: Tensor, reduction: str) -> Tensor:
|
||||
"""
|
||||
Reduces a given tensor by a given reduction method
|
||||
|
||||
|
@ -39,9 +40,7 @@ def reduce(to_reduce: torch.Tensor, reduction: str) -> torch.Tensor:
|
|||
raise ValueError("Reduction parameter unknown.")
|
||||
|
||||
|
||||
def class_reduce(
|
||||
num: torch.Tensor, denom: torch.Tensor, weights: torch.Tensor, class_reduction: str = "none"
|
||||
) -> torch.Tensor:
|
||||
def class_reduce(num: Tensor, denom: Tensor, weights: Tensor, class_reduction: str = "none") -> Tensor:
|
||||
"""
|
||||
Function used to reduce classification metrics of the form `num / denom * weights`.
|
||||
For example for calculating standard accuracy the num would be number of
|
||||
|
@ -85,7 +84,7 @@ def class_reduce(
|
|||
)
|
||||
|
||||
|
||||
def gather_all_tensors(result: Union[torch.Tensor], group: Optional[Any] = None):
|
||||
def gather_all_tensors(result: Union[Tensor], group: Optional[Any] = None):
|
||||
"""
|
||||
Function to gather all tensors from several ddp processes onto a list that
|
||||
is broadcasted to all processes
|
||||
|
|
|
@ -0,0 +1,7 @@
|
|||
from distutils.version import LooseVersion
|
||||
|
||||
import torch
|
||||
|
||||
_TORCH_LOWER_1_4 = LooseVersion(torch.__version__) < LooseVersion("1.4.0")
|
||||
_TORCH_LOWER_1_5 = LooseVersion(torch.__version__) < LooseVersion("1.5.0")
|
||||
_TORCH_LOWER_1_6 = LooseVersion(torch.__version__) < LooseVersion("1.6.0")
|
Загрузка…
Ссылка в новой задаче