* versions

* simplify Tensor

* simplify tensor

* yapf
This commit is contained in:
Jirka Borovec 2021-03-18 20:34:25 +01:00 коммит произвёл GitHub
Родитель 1fcf9fc235
Коммит a1e50ca62b
Не найден ключ, соответствующий данной подписи
Идентификатор ключа GPG: 4AEE18F83AFDEB23
72 изменённых файлов: 594 добавлений и 556 удалений

Просмотреть файл

@ -15,6 +15,7 @@ from operator import neg, pos
import pytest
import torch
from torch import tensor
from tests.helpers import _MARK_TORCH_MIN_1_4, _MARK_TORCH_MIN_1_5, _MARK_TORCH_MIN_1_6
from torchmetrics.metric import CompositionalMetric, Metric
@ -31,7 +32,7 @@ class DummyMetric(Metric):
self._num_updates += 1
def compute(self):
return torch.tensor(self._val_to_return)
return tensor(self._val_to_return)
def reset(self):
self._num_updates = 0
@ -41,10 +42,10 @@ class DummyMetric(Metric):
@pytest.mark.parametrize(
["second_operand", "expected_result"],
[
(DummyMetric(2), torch.tensor(4)),
(2, torch.tensor(4)),
(2.0, torch.tensor(4.0)),
pytest.param(torch.tensor(2), torch.tensor(4), marks=pytest.mark.skipif(**_MARK_TORCH_MIN_1_4)),
(DummyMetric(2), tensor(4)),
(2, tensor(4)),
(2.0, tensor(4.0)),
pytest.param(tensor(2), tensor(4), marks=pytest.mark.skipif(**_MARK_TORCH_MIN_1_4)),
],
)
def test_metrics_add(second_operand, expected_result):
@ -62,7 +63,7 @@ def test_metrics_add(second_operand, expected_result):
@pytest.mark.parametrize(
["second_operand", "expected_result"],
[(DummyMetric(3), torch.tensor(2)), (3, torch.tensor(2)), (3, torch.tensor(2)), (torch.tensor(3), torch.tensor(2))],
[(DummyMetric(3), tensor(2)), (3, tensor(2)), (3, tensor(2)), (tensor(3), tensor(2))],
)
@pytest.mark.skipif(**_MARK_TORCH_MIN_1_5)
def test_metrics_and(second_operand, expected_result):
@ -81,10 +82,10 @@ def test_metrics_and(second_operand, expected_result):
@pytest.mark.parametrize(
["second_operand", "expected_result"],
[
(DummyMetric(2), torch.tensor(True)),
(2, torch.tensor(True)),
(2.0, torch.tensor(True)),
(torch.tensor(2), torch.tensor(True)),
(DummyMetric(2), tensor(True)),
(2, tensor(True)),
(2.0, tensor(True)),
(tensor(2), tensor(True)),
],
)
def test_metrics_eq(second_operand, expected_result):
@ -101,10 +102,10 @@ def test_metrics_eq(second_operand, expected_result):
@pytest.mark.parametrize(
["second_operand", "expected_result"],
[
(DummyMetric(2), torch.tensor(2)),
(2, torch.tensor(2)),
(2.0, torch.tensor(2.0)),
(torch.tensor(2), torch.tensor(2)),
(DummyMetric(2), tensor(2)),
(2, tensor(2)),
(2.0, tensor(2.0)),
(tensor(2), tensor(2)),
],
)
@pytest.mark.skipif(**_MARK_TORCH_MIN_1_5)
@ -121,10 +122,10 @@ def test_metrics_floordiv(second_operand, expected_result):
@pytest.mark.parametrize(
["second_operand", "expected_result"],
[
(DummyMetric(2), torch.tensor(True)),
(2, torch.tensor(True)),
(2.0, torch.tensor(True)),
(torch.tensor(2), torch.tensor(True)),
(DummyMetric(2), tensor(True)),
(2, tensor(True)),
(2.0, tensor(True)),
(tensor(2), tensor(True)),
],
)
def test_metrics_ge(second_operand, expected_result):
@ -141,10 +142,10 @@ def test_metrics_ge(second_operand, expected_result):
@pytest.mark.parametrize(
["second_operand", "expected_result"],
[
(DummyMetric(2), torch.tensor(True)),
(2, torch.tensor(True)),
(2.0, torch.tensor(True)),
(torch.tensor(2), torch.tensor(True)),
(DummyMetric(2), tensor(True)),
(2, tensor(True)),
(2.0, tensor(True)),
(tensor(2), tensor(True)),
],
)
def test_metrics_gt(second_operand, expected_result):
@ -161,10 +162,10 @@ def test_metrics_gt(second_operand, expected_result):
@pytest.mark.parametrize(
["second_operand", "expected_result"],
[
(DummyMetric(2), torch.tensor(False)),
(2, torch.tensor(False)),
(2.0, torch.tensor(False)),
(torch.tensor(2), torch.tensor(False)),
(DummyMetric(2), tensor(False)),
(2, tensor(False)),
(2.0, tensor(False)),
(tensor(2), tensor(False)),
],
)
def test_metrics_le(second_operand, expected_result):
@ -181,10 +182,10 @@ def test_metrics_le(second_operand, expected_result):
@pytest.mark.parametrize(
["second_operand", "expected_result"],
[
(DummyMetric(2), torch.tensor(False)),
(2, torch.tensor(False)),
(2.0, torch.tensor(False)),
(torch.tensor(2), torch.tensor(False)),
(DummyMetric(2), tensor(False)),
(2, tensor(False)),
(2.0, tensor(False)),
(tensor(2), tensor(False)),
],
)
def test_metrics_lt(second_operand, expected_result):
@ -200,7 +201,7 @@ def test_metrics_lt(second_operand, expected_result):
@pytest.mark.parametrize(
["second_operand", "expected_result"],
[(DummyMetric([2, 2, 2]), torch.tensor(12)), (torch.tensor([2, 2, 2]), torch.tensor(12))],
[(DummyMetric([2, 2, 2]), tensor(12)), (tensor([2, 2, 2]), tensor(12))],
)
def test_metrics_matmul(second_operand, expected_result):
first_metric = DummyMetric([2, 2, 2])
@ -215,10 +216,10 @@ def test_metrics_matmul(second_operand, expected_result):
@pytest.mark.parametrize(
["second_operand", "expected_result"],
[
(DummyMetric(2), torch.tensor(1)),
(2, torch.tensor(1)),
(2.0, torch.tensor(1)),
(torch.tensor(2), torch.tensor(1)),
(DummyMetric(2), tensor(1)),
(2, tensor(1)),
(2.0, tensor(1)),
(tensor(2), tensor(1)),
],
)
def test_metrics_mod(second_operand, expected_result):
@ -234,10 +235,10 @@ def test_metrics_mod(second_operand, expected_result):
@pytest.mark.parametrize(
["second_operand", "expected_result"],
[
(DummyMetric(2), torch.tensor(4)),
(2, torch.tensor(4)),
(2.0, torch.tensor(4.0)),
pytest.param(torch.tensor(2), torch.tensor(4), marks=pytest.mark.skipif(**_MARK_TORCH_MIN_1_4)),
(DummyMetric(2), tensor(4)),
(2, tensor(4)),
(2.0, tensor(4.0)),
pytest.param(tensor(2), tensor(4), marks=pytest.mark.skipif(**_MARK_TORCH_MIN_1_4)),
],
)
def test_metrics_mul(second_operand, expected_result):
@ -256,10 +257,10 @@ def test_metrics_mul(second_operand, expected_result):
@pytest.mark.parametrize(
["second_operand", "expected_result"],
[
(DummyMetric(2), torch.tensor(False)),
(2, torch.tensor(False)),
(2.0, torch.tensor(False)),
(torch.tensor(2), torch.tensor(False)),
(DummyMetric(2), tensor(False)),
(2, tensor(False)),
(2.0, tensor(False)),
(tensor(2), tensor(False)),
],
)
def test_metrics_ne(second_operand, expected_result):
@ -275,7 +276,7 @@ def test_metrics_ne(second_operand, expected_result):
@pytest.mark.parametrize(
["second_operand", "expected_result"],
[(DummyMetric([1, 0, 3]), torch.tensor([-1, -2, 3])), (torch.tensor([1, 0, 3]), torch.tensor([-1, -2, 3]))],
[(DummyMetric([1, 0, 3]), tensor([-1, -2, 3])), (tensor([1, 0, 3]), tensor([-1, -2, 3]))],
)
@pytest.mark.skipif(**_MARK_TORCH_MIN_1_5)
def test_metrics_or(second_operand, expected_result):
@ -294,10 +295,10 @@ def test_metrics_or(second_operand, expected_result):
@pytest.mark.parametrize(
["second_operand", "expected_result"],
[
pytest.param(DummyMetric(2), torch.tensor(4)),
pytest.param(2, torch.tensor(4)),
pytest.param(2.0, torch.tensor(4.0), marks=pytest.mark.skipif(**_MARK_TORCH_MIN_1_6)),
pytest.param(torch.tensor(2), torch.tensor(4)),
pytest.param(DummyMetric(2), tensor(4)),
pytest.param(2, tensor(4)),
pytest.param(2.0, tensor(4.0), marks=pytest.mark.skipif(**_MARK_TORCH_MIN_1_6)),
pytest.param(tensor(2), tensor(4)),
],
)
def test_metrics_pow(second_operand, expected_result):
@ -312,7 +313,7 @@ def test_metrics_pow(second_operand, expected_result):
@pytest.mark.parametrize(
["first_operand", "expected_result"],
[(5, torch.tensor(2)), (5.0, torch.tensor(2.0)), (torch.tensor(5), torch.tensor(2))],
[(5, tensor(2)), (5.0, tensor(2.0)), (tensor(5), tensor(2))],
)
@pytest.mark.skipif(**_MARK_TORCH_MIN_1_5)
def test_metrics_rfloordiv(first_operand, expected_result):
@ -324,10 +325,8 @@ def test_metrics_rfloordiv(first_operand, expected_result):
assert torch.allclose(expected_result, final_rfloordiv.compute())
@pytest.mark.parametrize(
["first_operand", "expected_result"],
[pytest.param(torch.tensor([2, 2, 2]), torch.tensor(12), marks=pytest.mark.skipif(**_MARK_TORCH_MIN_1_4))]
)
@pytest.mark.parametrize(["first_operand", "expected_result"],
[pytest.param(tensor([2, 2, 2]), tensor(12), marks=pytest.mark.skipif(**_MARK_TORCH_MIN_1_4))])
def test_metrics_rmatmul(first_operand, expected_result):
second_operand = DummyMetric([2, 2, 2])
@ -338,10 +337,8 @@ def test_metrics_rmatmul(first_operand, expected_result):
assert torch.allclose(expected_result, final_rmatmul.compute())
@pytest.mark.parametrize(
["first_operand", "expected_result"],
[pytest.param(torch.tensor(2), torch.tensor(2), marks=pytest.mark.skipif(**_MARK_TORCH_MIN_1_4))]
)
@pytest.mark.parametrize(["first_operand", "expected_result"],
[pytest.param(tensor(2), tensor(2), marks=pytest.mark.skipif(**_MARK_TORCH_MIN_1_4))])
def test_metrics_rmod(first_operand, expected_result):
second_operand = DummyMetric(5)
@ -355,9 +352,9 @@ def test_metrics_rmod(first_operand, expected_result):
@pytest.mark.parametrize(
"first_operand,expected_result",
[
pytest.param(DummyMetric(2), torch.tensor(4)),
pytest.param(2, torch.tensor(4)),
pytest.param(2.0, torch.tensor(4.0), marks=pytest.mark.skipif(**_MARK_TORCH_MIN_1_6)),
pytest.param(DummyMetric(2), tensor(4)),
pytest.param(2, tensor(4)),
pytest.param(2.0, tensor(4.0), marks=pytest.mark.skipif(**_MARK_TORCH_MIN_1_6)),
],
)
def test_metrics_rpow(first_operand, expected_result):
@ -373,10 +370,10 @@ def test_metrics_rpow(first_operand, expected_result):
@pytest.mark.parametrize(
["first_operand", "expected_result"],
[
(DummyMetric(3), torch.tensor(1)),
(3, torch.tensor(1)),
(3.0, torch.tensor(1.0)),
pytest.param(torch.tensor(3), torch.tensor(1), marks=pytest.mark.skipif(**_MARK_TORCH_MIN_1_4)),
(DummyMetric(3), tensor(1)),
(3, tensor(1)),
(3.0, tensor(1.0)),
pytest.param(tensor(3), tensor(1), marks=pytest.mark.skipif(**_MARK_TORCH_MIN_1_4)),
],
)
def test_metrics_rsub(first_operand, expected_result):
@ -392,10 +389,10 @@ def test_metrics_rsub(first_operand, expected_result):
@pytest.mark.parametrize(
["first_operand", "expected_result"],
[
(DummyMetric(6), torch.tensor(2.0)),
(6, torch.tensor(2.0)),
(6.0, torch.tensor(2.0)),
(torch.tensor(6), torch.tensor(2.0)),
(DummyMetric(6), tensor(2.0)),
(6, tensor(2.0)),
(6.0, tensor(2.0)),
(tensor(6), tensor(2.0)),
],
)
@pytest.mark.skipif(**_MARK_TORCH_MIN_1_5)
@ -412,10 +409,10 @@ def test_metrics_rtruediv(first_operand, expected_result):
@pytest.mark.parametrize(
["second_operand", "expected_result"],
[
(DummyMetric(2), torch.tensor(1)),
(2, torch.tensor(1)),
(2.0, torch.tensor(1.0)),
(torch.tensor(2), torch.tensor(1)),
(DummyMetric(2), tensor(1)),
(2, tensor(1)),
(2.0, tensor(1.0)),
(tensor(2), tensor(1)),
],
)
def test_metrics_sub(second_operand, expected_result):
@ -431,10 +428,10 @@ def test_metrics_sub(second_operand, expected_result):
@pytest.mark.parametrize(
["second_operand", "expected_result"],
[
(DummyMetric(3), torch.tensor(2.0)),
(3, torch.tensor(2.0)),
(3.0, torch.tensor(2.0)),
(torch.tensor(3), torch.tensor(2.0)),
(DummyMetric(3), tensor(2.0)),
(3, tensor(2.0)),
(3.0, tensor(2.0)),
(tensor(3), tensor(2.0)),
],
)
@pytest.mark.skipif(**_MARK_TORCH_MIN_1_5)
@ -450,7 +447,7 @@ def test_metrics_truediv(second_operand, expected_result):
@pytest.mark.parametrize(
["second_operand", "expected_result"],
[(DummyMetric([1, 0, 3]), torch.tensor([-2, -2, 0])), (torch.tensor([1, 0, 3]), torch.tensor([-2, -2, 0]))],
[(DummyMetric([1, 0, 3]), tensor([-2, -2, 0])), (tensor([1, 0, 3]), tensor([-2, -2, 0]))],
)
@pytest.mark.skipif(**_MARK_TORCH_MIN_1_5)
def test_metrics_xor(second_operand, expected_result):
@ -473,7 +470,7 @@ def test_metrics_abs():
assert isinstance(final_abs, CompositionalMetric)
assert torch.allclose(torch.tensor(1), final_abs.compute())
assert torch.allclose(tensor(1), final_abs.compute())
def test_metrics_invert():
@ -481,7 +478,7 @@ def test_metrics_invert():
final_inverse = ~first_metric
assert isinstance(final_inverse, CompositionalMetric)
assert torch.allclose(torch.tensor(-2), final_inverse.compute())
assert torch.allclose(tensor(-2), final_inverse.compute())
def test_metrics_neg():
@ -489,7 +486,7 @@ def test_metrics_neg():
final_neg = neg(first_metric)
assert isinstance(final_neg, CompositionalMetric)
assert torch.allclose(torch.tensor(-1), final_neg.compute())
assert torch.allclose(tensor(-1), final_neg.compute())
def test_metrics_pos():
@ -497,7 +494,7 @@ def test_metrics_pos():
final_pos = pos(first_metric)
assert isinstance(final_pos, CompositionalMetric)
assert torch.allclose(torch.tensor(1), final_pos.compute())
assert torch.allclose(tensor(1), final_pos.compute())
def test_compositional_metrics_update():

Просмотреть файл

@ -15,6 +15,7 @@ import sys
import pytest
import torch
from torch import tensor
from tests.helpers.testers import DummyMetric, setup_ddp
from torchmetrics import Metric
@ -26,7 +27,7 @@ def _test_ddp_sum(rank, worldsize):
setup_ddp(rank, worldsize)
dummy = DummyMetric()
dummy._reductions = {"foo": torch.sum}
dummy.foo = torch.tensor(1)
dummy.foo = tensor(1)
dummy._sync_dist()
assert dummy.foo == worldsize
@ -36,21 +37,21 @@ def _test_ddp_cat(rank, worldsize):
setup_ddp(rank, worldsize)
dummy = DummyMetric()
dummy._reductions = {"foo": torch.cat}
dummy.foo = [torch.tensor([1])]
dummy.foo = [tensor([1])]
dummy._sync_dist()
assert torch.all(torch.eq(dummy.foo, torch.tensor([1, 1])))
assert torch.all(torch.eq(dummy.foo, tensor([1, 1])))
def _test_ddp_sum_cat(rank, worldsize):
setup_ddp(rank, worldsize)
dummy = DummyMetric()
dummy._reductions = {"foo": torch.cat, "bar": torch.sum}
dummy.foo = [torch.tensor([1])]
dummy.bar = torch.tensor(1)
dummy.foo = [tensor([1])]
dummy.bar = tensor(1)
dummy._sync_dist()
assert torch.all(torch.eq(dummy.foo, torch.tensor([1, 1])))
assert torch.all(torch.eq(dummy.foo, tensor([1, 1])))
assert dummy.bar == worldsize

Просмотреть файл

@ -13,15 +13,15 @@
# limitations under the License.
import pickle
from collections import OrderedDict
from distutils.version import LooseVersion
import cloudpickle
import numpy as np
import pytest
import torch
from torch import nn
from torch import nn, tensor
from tests.helpers.testers import DummyListMetric, DummyMetric, DummyMetricSum
from torchmetrics.utilities.imports import _TORCH_LOWER_1_6
torch.manual_seed(42)
@ -33,23 +33,23 @@ def test_inherit():
def test_add_state():
a = DummyMetric()
a.add_state("a", torch.tensor(0), "sum")
assert a._reductions["a"](torch.tensor([1, 1])) == 2
a.add_state("a", tensor(0), "sum")
assert a._reductions["a"](tensor([1, 1])) == 2
a.add_state("b", torch.tensor(0), "mean")
assert np.allclose(a._reductions["b"](torch.tensor([1.0, 2.0])).numpy(), 1.5)
a.add_state("b", tensor(0), "mean")
assert np.allclose(a._reductions["b"](tensor([1.0, 2.0])).numpy(), 1.5)
a.add_state("c", torch.tensor(0), "cat")
assert a._reductions["c"]([torch.tensor([1]), torch.tensor([1])]).shape == (2, )
a.add_state("c", tensor(0), "cat")
assert a._reductions["c"]([tensor([1]), tensor([1])]).shape == (2, )
with pytest.raises(ValueError):
a.add_state("d1", torch.tensor(0), 'xyz')
a.add_state("d1", tensor(0), 'xyz')
with pytest.raises(ValueError):
a.add_state("d2", torch.tensor(0), 42)
a.add_state("d2", tensor(0), 42)
with pytest.raises(ValueError):
a.add_state("d3", [torch.tensor(0)], 'sum')
a.add_state("d3", [tensor(0)], 'sum')
with pytest.raises(ValueError):
a.add_state("d4", 42, 'sum')
@ -57,19 +57,19 @@ def test_add_state():
def custom_fx(x):
return -1
a.add_state("e", torch.tensor(0), custom_fx)
assert a._reductions["e"](torch.tensor([1, 1])) == -1
a.add_state("e", tensor(0), custom_fx)
assert a._reductions["e"](tensor([1, 1])) == -1
def test_add_state_persistent():
a = DummyMetric()
a.add_state("a", torch.tensor(0), "sum", persistent=True)
a.add_state("a", tensor(0), "sum", persistent=True)
assert "a" in a.state_dict()
a.add_state("b", torch.tensor(0), "sum", persistent=False)
a.add_state("b", tensor(0), "sum", persistent=False)
if LooseVersion(torch.__version__) >= LooseVersion("1.6.0"):
if _TORCH_LOWER_1_6:
assert "b" not in a.state_dict()
@ -83,13 +83,13 @@ def test_reset():
a = A()
assert a.x == 0
a.x = torch.tensor(5)
a.x = tensor(5)
a.reset()
assert a.x == 0
b = B()
assert isinstance(b.x, list) and len(b.x) == 0
b.x = torch.tensor(5)
b.x = tensor(5)
b.reset()
assert isinstance(b.x, list) and len(b.x) == 0
@ -155,10 +155,10 @@ def test_hash():
b2 = B()
assert hash(b1) == hash(b2)
assert isinstance(b1.x, list) and len(b1.x) == 0
b1.x.append(torch.tensor(5))
b1.x.append(tensor(5))
assert isinstance(hash(b1), int) # <- check that nothing crashes
assert isinstance(b1.x, list) and len(b1.x) == 1
b2.x.append(torch.tensor(5))
b2.x.append(tensor(5))
# Sanity:
assert isinstance(b2.x, list) and len(b2.x) == 1
# Now that they have tensor contents, they should have different hashes:
@ -222,15 +222,15 @@ def test_child_metric_state_dict():
def __init__(self):
super().__init__()
self.metric = DummyMetric()
self.metric.add_state('a', torch.tensor(0), persistent=True)
self.metric.add_state('a', tensor(0), persistent=True)
self.metric.add_state('b', [], persistent=True)
self.metric.register_buffer('c', torch.tensor(0))
self.metric.register_buffer('c', tensor(0))
module = TestModule()
expected_state_dict = {
'metric.a': torch.tensor(0),
'metric.a': tensor(0),
'metric.b': [],
'metric.c': torch.tensor(0),
'metric.c': tensor(0),
}
assert module.state_dict() == expected_state_dict

Просмотреть файл

@ -17,6 +17,7 @@ import numpy as np
import pytest
import torch
from sklearn.metrics import accuracy_score as sk_accuracy
from torch import tensor
from tests.classification.inputs import _input_binary, _input_binary_prob
from tests.classification.inputs import _input_multiclass as _input_mcls
@ -109,12 +110,12 @@ _l1to4t3_mcls = [_l1to4t3.T, _l1to4t3.T, _l1to4t3.T]
# The preds in these examples always put highest probability on class 3, second highest on class 2,
# third highest on class 1, and lowest on class 0
_topk_preds_mcls = torch.tensor([_l1to4t3, _l1to4t3]).float()
_topk_target_mcls = torch.tensor([[1, 2, 3], [2, 1, 0]])
_topk_preds_mcls = tensor([_l1to4t3, _l1to4t3]).float()
_topk_target_mcls = tensor([[1, 2, 3], [2, 1, 0]])
# This is like for MC case, but one sample in each batch is sabotaged with 0 class prediction :)
_topk_preds_mdmc = torch.tensor([_l1to4t3_mcls, _l1to4t3_mcls]).float()
_topk_target_mdmc = torch.tensor([[[1, 1, 0], [2, 2, 2], [3, 3, 3]], [[2, 2, 0], [1, 1, 1], [0, 0, 0]]])
_topk_preds_mdmc = tensor([_l1to4t3_mcls, _l1to4t3_mcls]).float()
_topk_target_mdmc = tensor([[[1, 1, 0], [2, 2, 2], [3, 3, 3]], [[2, 2, 0], [1, 1, 1], [0, 0, 0]]])
# Replace with a proper sk_metric test once sklearn 0.24 hits :)

Просмотреть файл

@ -17,6 +17,7 @@ import numpy as np
import pytest
import torch
from sklearn.metrics import auc as _sk_auc
from torch import tensor
from tests.helpers.testers import NUM_BATCHES, MetricTester
from torchmetrics.classification.auc import AUC
@ -43,7 +44,7 @@ for i in range(4):
y = y[idx] if i % 2 == 0 else x[idx[::-1]]
x = x.reshape(NUM_BATCHES, 8)
y = y.reshape(NUM_BATCHES, 8)
_examples.append(Input(x=torch.tensor(x), y=torch.tensor(y)))
_examples.append(Input(x=tensor(x), y=tensor(y)))
@pytest.mark.parametrize("x, y", _examples)
@ -74,4 +75,4 @@ class TestAUC(MetricTester):
])
def test_auc(x, y, expected):
# Test Area Under Curve (AUC) computation
assert auc(torch.tensor(x), torch.tensor(y)) == expected
assert auc(tensor(x), tensor(y)) == expected

Просмотреть файл

@ -11,7 +11,6 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from distutils.version import LooseVersion
from functools import partial
import pytest
@ -26,6 +25,7 @@ from tests.classification.inputs import _input_multilabel_prob as _input_mlb_pro
from tests.helpers.testers import NUM_CLASSES, MetricTester
from torchmetrics.classification.auroc import AUROC
from torchmetrics.functional import auroc
from torchmetrics.utilities.imports import _TORCH_LOWER_1_6
torch.manual_seed(42)
@ -104,7 +104,7 @@ class TestAUROC(MetricTester):
pytest.skip('max_fpr parameter not support for multi class or multi label')
# max_fpr only supported for torch v1.6 or higher
if max_fpr is not None and LooseVersion(torch.__version__) < LooseVersion('1.6.0'):
if max_fpr is not None and _TORCH_LOWER_1_6:
pytest.skip('requires torch v1.6 or higher to test max_fpr argument')
self.run_class_metric_test(
@ -127,7 +127,7 @@ class TestAUROC(MetricTester):
pytest.skip('max_fpr parameter not support for multi class or multi label')
# max_fpr only supported for torch v1.6 or higher
if max_fpr is not None and LooseVersion(torch.__version__) < LooseVersion('1.6.0'):
if max_fpr is not None and _TORCH_LOWER_1_6:
pytest.skip('requires torch v1.6 or higher to test max_fpr argument')
self.run_functional_metric_test(

Просмотреть файл

@ -17,6 +17,7 @@ import numpy as np
import pytest
import torch
from sklearn.metrics import average_precision_score as sk_average_precision_score
from torch import tensor
from tests.classification.inputs import _input_binary_prob
from tests.classification.inputs import _input_multiclass_prob as _input_mcls_prob
@ -101,9 +102,9 @@ class TestAveragePrecision(MetricTester):
# And a constant score
# The precision is then the fraction of positive whatever the recall
# is, as there is only one threshold:
pytest.param(torch.tensor([1, 1, 1, 1]), torch.tensor([0, 0, 0, 1]), .25),
pytest.param(tensor([1, 1, 1, 1]), tensor([0, 0, 0, 1]), .25),
# With threshold 0.8 : 1 TP and 2 TN and one FN
pytest.param(torch.tensor([.6, .7, .8, 9]), torch.tensor([1, 0, 0, 1]), .75),
pytest.param(tensor([.6, .7, .8, 9]), tensor([1, 0, 0, 1]), .75),
]
)
def test_average_precision(scores, target, expected_score):

Просмотреть файл

@ -17,6 +17,7 @@ import numpy as np
import pytest
import torch
from sklearn.metrics import fbeta_score
from torch import tensor
from tests.classification.inputs import _input_binary, _input_binary_prob
from tests.classification.inputs import _input_multiclass as _input_mcls
@ -152,8 +153,8 @@ class TestFBeta(MetricTester):
pytest.param([1., 0., 1., 0.], [0., 1., 1., 0.], 2, [0.5, 0.5]),
])
def test_fbeta_score(pred, target, beta, exp_score):
score = fbeta(torch.tensor(pred), torch.tensor(target), num_classes=1, beta=beta, average='none')
assert torch.allclose(score, torch.tensor(exp_score))
score = fbeta(tensor(pred), tensor(target), num_classes=1, beta=beta, average='none')
assert torch.allclose(score, tensor(exp_score))
@pytest.mark.parametrize(['pred', 'target', 'exp_score'], [
@ -162,5 +163,5 @@ def test_fbeta_score(pred, target, beta, exp_score):
pytest.param([1., 0., 1., 0.], [1., 0., 1., 0.], [1.0, 1.0]),
])
def test_f1_score(pred, target, exp_score):
score = f1(torch.tensor(pred), torch.tensor(target), num_classes=1, average='none')
assert torch.allclose(score, torch.tensor(exp_score))
score = f1(tensor(pred), tensor(target), num_classes=1, average='none')
assert torch.allclose(score, tensor(exp_score))

Просмотреть файл

@ -13,7 +13,7 @@
# limitations under the License.
import pytest
import torch
from torch import rand, randint
from torch import Tensor, rand, randint, tensor
from tests.classification.inputs import Input
from tests.classification.inputs import _input_binary as _bin
@ -52,7 +52,7 @@ _mdmc_prob_2cls_preds /= _mdmc_prob_2cls_preds.sum(dim=2, keepdim=True)
_mdmc_prob_2cls = Input(_mdmc_prob_2cls_preds, randint(high=2, size=(NUM_BATCHES, BATCH_SIZE, EXTRA_DIM)))
# Some utils
T = torch.Tensor
T = Tensor
def _idn(x):
@ -209,7 +209,7 @@ def test_threshold():
preds_probs_out, _, _ = _input_format_classification(preds_probs, target, threshold=0.5)
assert torch.equal(torch.tensor([0, 1, 1], dtype=torch.int), preds_probs_out.squeeze().int())
assert torch.equal(tensor([0, 1, 1], dtype=torch.int), preds_probs_out.squeeze().int())
########################################################################

Просмотреть файл

@ -17,6 +17,7 @@ import numpy as np
import pytest
import torch
from sklearn.metrics import jaccard_score as sk_jaccard_score
from torch import Tensor, tensor
from tests.classification.inputs import _input_binary, _input_binary_prob
from tests.classification.inputs import _input_multiclass as _input_mcls
@ -134,12 +135,12 @@ class TestIoU(MetricTester):
@pytest.mark.parametrize(['half_ones', 'reduction', 'ignore_index', 'expected'], [
pytest.param(False, 'none', None, torch.Tensor([1, 1, 1])),
pytest.param(False, 'elementwise_mean', None, torch.Tensor([1])),
pytest.param(False, 'none', 0, torch.Tensor([1, 1])),
pytest.param(True, 'none', None, torch.Tensor([0.5, 0.5, 0.5])),
pytest.param(True, 'elementwise_mean', None, torch.Tensor([0.5])),
pytest.param(True, 'none', 0, torch.Tensor([0.5, 0.5])),
pytest.param(False, 'none', None, Tensor([1, 1, 1])),
pytest.param(False, 'elementwise_mean', None, Tensor([1])),
pytest.param(False, 'none', 0, Tensor([1, 1])),
pytest.param(True, 'none', None, Tensor([0.5, 0.5, 0.5])),
pytest.param(True, 'elementwise_mean', None, Tensor([0.5])),
pytest.param(True, 'none', 0, Tensor([0.5, 0.5])),
])
def test_iou(half_ones, reduction, ignore_index, expected):
pred = (torch.arange(120) % 3).view(-1, 1)
@ -190,14 +191,14 @@ def test_iou(half_ones, reduction, ignore_index, expected):
)
def test_iou_absent_score(pred, target, ignore_index, absent_score, num_classes, expected):
iou_val = iou(
pred=torch.tensor(pred),
target=torch.tensor(target),
pred=tensor(pred),
target=tensor(target),
ignore_index=ignore_index,
absent_score=absent_score,
num_classes=num_classes,
reduction='none',
)
assert torch.allclose(iou_val, torch.tensor(expected).to(iou_val))
assert torch.allclose(iou_val, tensor(expected).to(iou_val))
# example data taken from
@ -220,10 +221,10 @@ def test_iou_absent_score(pred, target, ignore_index, absent_score, num_classes,
)
def test_iou_ignore_index(pred, target, ignore_index, num_classes, reduction, expected):
iou_val = iou(
pred=torch.tensor(pred),
target=torch.tensor(target),
pred=tensor(pred),
target=tensor(target),
ignore_index=ignore_index,
num_classes=num_classes,
reduction=reduction,
)
assert torch.allclose(iou_val, torch.tensor(expected).to(iou_val))
assert torch.allclose(iou_val, tensor(expected).to(iou_val))

Просмотреть файл

@ -18,6 +18,7 @@ import numpy as np
import pytest
import torch
from sklearn.metrics import precision_score, recall_score
from torch import Tensor, tensor
from tests.classification.inputs import _input_binary, _input_binary_prob
from tests.classification.inputs import _input_multiclass as _input_mcls
@ -128,8 +129,8 @@ def test_wrong_params(metric, fn_metric, average, mdmc_average, num_classes, ign
def test_zero_division(metric_class, metric_fn):
""" Test that zero_division works correctly (currently should just set to 0). """
preds = torch.tensor([1, 2, 1, 1])
target = torch.tensor([2, 1, 2, 1])
preds = tensor([1, 2, 1, 1])
target = tensor([2, 1, 2, 1])
cl_metric = metric_class(average="none", num_classes=3)
cl_metric(preds, target)
@ -152,8 +153,8 @@ def test_no_support(metric_class, metric_fn):
in this case (zero_division is for now not configurable and equals 0).
"""
preds = torch.tensor([1, 1, 0, 0])
target = torch.tensor([0, 0, 0, 0])
preds = tensor([1, 1, 0, 0])
target = tensor([0, 0, 0, 0])
cl_metric = metric_class(average="weighted", num_classes=2, ignore_index=0)
cl_metric(preds, target)
@ -198,8 +199,8 @@ class TestPrecisionRecall(MetricTester):
self,
ddp: bool,
dist_sync_on_step: bool,
preds: torch.Tensor,
target: torch.Tensor,
preds: Tensor,
target: Tensor,
sk_wrapper: Callable,
metric_class: Metric,
metric_fn: Callable,
@ -248,8 +249,8 @@ class TestPrecisionRecall(MetricTester):
def test_precision_recall_fn(
self,
preds: torch.Tensor,
target: torch.Tensor,
preds: Tensor,
target: Tensor,
sk_wrapper: Callable,
metric_class: Metric,
metric_fn: Callable,
@ -316,31 +317,31 @@ def test_precision_recall_joint(average):
assert torch.equal(recall_result, prec_recall_result[1])
_mc_k_target = torch.tensor([0, 1, 2])
_mc_k_preds = torch.tensor([[0.35, 0.4, 0.25], [0.1, 0.5, 0.4], [0.2, 0.1, 0.7]])
_ml_k_target = torch.tensor([[0, 1, 0], [1, 1, 0], [0, 0, 0]])
_ml_k_preds = torch.tensor([[0.9, 0.2, 0.75], [0.1, 0.7, 0.8], [0.6, 0.1, 0.7]])
_mc_k_target = tensor([0, 1, 2])
_mc_k_preds = tensor([[0.35, 0.4, 0.25], [0.1, 0.5, 0.4], [0.2, 0.1, 0.7]])
_ml_k_target = tensor([[0, 1, 0], [1, 1, 0], [0, 0, 0]])
_ml_k_preds = tensor([[0.9, 0.2, 0.75], [0.1, 0.7, 0.8], [0.6, 0.1, 0.7]])
@pytest.mark.parametrize("metric_class, metric_fn", [(Recall, recall), (Precision, precision)])
@pytest.mark.parametrize(
"k, preds, target, average, expected_prec, expected_recall",
[
(1, _mc_k_preds, _mc_k_target, "micro", torch.tensor(2 / 3), torch.tensor(2 / 3)),
(2, _mc_k_preds, _mc_k_target, "micro", torch.tensor(1 / 2), torch.tensor(1.0)),
(1, _ml_k_preds, _ml_k_target, "micro", torch.tensor(0.0), torch.tensor(0.0)),
(2, _ml_k_preds, _ml_k_target, "micro", torch.tensor(1 / 6), torch.tensor(1 / 3)),
(1, _mc_k_preds, _mc_k_target, "micro", tensor(2 / 3), tensor(2 / 3)),
(2, _mc_k_preds, _mc_k_target, "micro", tensor(1 / 2), tensor(1.0)),
(1, _ml_k_preds, _ml_k_target, "micro", tensor(0.0), tensor(0.0)),
(2, _ml_k_preds, _ml_k_target, "micro", tensor(1 / 6), tensor(1 / 3)),
],
)
def test_top_k(
metric_class,
metric_fn,
k: int,
preds: torch.Tensor,
target: torch.Tensor,
preds: Tensor,
target: Tensor,
average: str,
expected_prec: torch.Tensor,
expected_recall: torch.Tensor,
expected_prec: Tensor,
expected_recall: Tensor,
):
"""A simple test to check that top_k works as expected.

Просмотреть файл

@ -17,6 +17,7 @@ import numpy as np
import pytest
import torch
from sklearn.metrics import precision_recall_curve as sk_precision_recall_curve
from torch import tensor
from tests.classification.inputs import _input_binary_prob
from tests.classification.inputs import _input_multiclass_prob as _input_mcls_prob
@ -101,10 +102,10 @@ class TestPrecisionRecallCurve(MetricTester):
[pytest.param([1, 2, 3, 4], [1, 0, 0, 1], [0.5, 1 / 3, 0.5, 1., 1.], [1, 0.5, 0.5, 0.5, 0.], [1, 2, 3, 4])]
)
def test_pr_curve(pred, target, expected_p, expected_r, expected_t):
p, r, t = precision_recall_curve(torch.tensor(pred), torch.tensor(target))
p, r, t = precision_recall_curve(tensor(pred), tensor(target))
assert p.size() == r.size()
assert p.size(0) == t.size(0) + 1
assert torch.allclose(p, torch.tensor(expected_p).to(p))
assert torch.allclose(r, torch.tensor(expected_r).to(r))
assert torch.allclose(t, torch.tensor(expected_t).to(t))
assert torch.allclose(p, tensor(expected_p).to(p))
assert torch.allclose(r, tensor(expected_r).to(r))
assert torch.allclose(t, tensor(expected_t).to(t))

Просмотреть файл

@ -17,6 +17,7 @@ import numpy as np
import pytest
import torch
from sklearn.metrics import roc_curve as sk_roc_curve
from torch import tensor
from tests.classification.inputs import _input_binary_prob
from tests.classification.inputs import _input_multiclass_prob as _input_mcls_prob
@ -104,9 +105,9 @@ class TestROC(MetricTester):
pytest.param([0.5, 0.5], [0, 1], [0, 1], [0, 1]),
])
def test_roc_curve(pred, target, expected_tpr, expected_fpr):
fpr, tpr, thresh = roc(torch.tensor(pred), torch.tensor(target))
fpr, tpr, thresh = roc(tensor(pred), tensor(target))
assert fpr.shape == tpr.shape
assert fpr.size(0) == thresh.size(0)
assert torch.allclose(fpr, torch.tensor(expected_fpr).to(fpr))
assert torch.allclose(tpr, torch.tensor(expected_tpr).to(tpr))
assert torch.allclose(fpr, tensor(expected_fpr).to(fpr))
assert torch.allclose(tpr, tensor(expected_tpr).to(tpr))

Просмотреть файл

@ -18,6 +18,7 @@ import numpy as np
import pytest
import torch
from sklearn.metrics import multilabel_confusion_matrix
from torch import Tensor, tensor
from tests.classification.inputs import _input_binary, _input_binary_prob, _input_multiclass
from tests.classification.inputs import _input_multiclass_prob as _input_mccls_prob
@ -159,8 +160,8 @@ class TestStatScores(MetricTester):
ddp: bool,
dist_sync_on_step: bool,
sk_fn: Callable,
preds: torch.Tensor,
target: torch.Tensor,
preds: Tensor,
target: Tensor,
reduce: str,
mdmc_reduce: Optional[str],
num_classes: Optional[int],
@ -202,8 +203,8 @@ class TestStatScores(MetricTester):
def test_stat_scores_fn(
self,
sk_fn: Callable,
preds: torch.Tensor,
target: torch.Tensor,
preds: Tensor,
target: Tensor,
reduce: str,
mdmc_reduce: Optional[str],
num_classes: Optional[int],
@ -239,26 +240,26 @@ class TestStatScores(MetricTester):
)
_mc_k_target = torch.tensor([0, 1, 2])
_mc_k_preds = torch.tensor([[0.35, 0.4, 0.25], [0.1, 0.5, 0.4], [0.2, 0.1, 0.7]])
_ml_k_target = torch.tensor([[0, 1, 0], [1, 1, 0], [0, 0, 0]])
_ml_k_preds = torch.tensor([[0.9, 0.2, 0.75], [0.1, 0.7, 0.8], [0.6, 0.1, 0.7]])
_mc_k_target = tensor([0, 1, 2])
_mc_k_preds = tensor([[0.35, 0.4, 0.25], [0.1, 0.5, 0.4], [0.2, 0.1, 0.7]])
_ml_k_target = tensor([[0, 1, 0], [1, 1, 0], [0, 0, 0]])
_ml_k_preds = tensor([[0.9, 0.2, 0.75], [0.1, 0.7, 0.8], [0.6, 0.1, 0.7]])
@pytest.mark.parametrize(
"k, preds, target, reduce, expected",
[
(1, _mc_k_preds, _mc_k_target, "micro", torch.tensor([2, 1, 5, 1, 3])),
(2, _mc_k_preds, _mc_k_target, "micro", torch.tensor([3, 3, 3, 0, 3])),
(1, _ml_k_preds, _ml_k_target, "micro", torch.tensor([0, 3, 3, 3, 3])),
(2, _ml_k_preds, _ml_k_target, "micro", torch.tensor([1, 5, 1, 2, 3])),
(1, _mc_k_preds, _mc_k_target, "macro", torch.tensor([[0, 1, 1], [0, 1, 0], [2, 1, 2], [1, 0, 0], [1, 1, 1]])),
(2, _mc_k_preds, _mc_k_target, "macro", torch.tensor([[1, 1, 1], [1, 1, 1], [1, 1, 1], [0, 0, 0], [1, 1, 1]])),
(1, _ml_k_preds, _ml_k_target, "macro", torch.tensor([[0, 0, 0], [1, 0, 2], [1, 1, 1], [1, 2, 0], [1, 2, 0]])),
(2, _ml_k_preds, _ml_k_target, "macro", torch.tensor([[0, 1, 0], [2, 0, 3], [0, 1, 0], [1, 1, 0], [1, 2, 0]])),
(1, _mc_k_preds, _mc_k_target, "micro", tensor([2, 1, 5, 1, 3])),
(2, _mc_k_preds, _mc_k_target, "micro", tensor([3, 3, 3, 0, 3])),
(1, _ml_k_preds, _ml_k_target, "micro", tensor([0, 3, 3, 3, 3])),
(2, _ml_k_preds, _ml_k_target, "micro", tensor([1, 5, 1, 2, 3])),
(1, _mc_k_preds, _mc_k_target, "macro", tensor([[0, 1, 1], [0, 1, 0], [2, 1, 2], [1, 0, 0], [1, 1, 1]])),
(2, _mc_k_preds, _mc_k_target, "macro", tensor([[1, 1, 1], [1, 1, 1], [1, 1, 1], [0, 0, 0], [1, 1, 1]])),
(1, _ml_k_preds, _ml_k_target, "macro", tensor([[0, 0, 0], [1, 0, 2], [1, 1, 1], [1, 2, 0], [1, 2, 0]])),
(2, _ml_k_preds, _ml_k_target, "macro", tensor([[0, 1, 0], [2, 0, 3], [0, 1, 0], [1, 1, 0], [1, 2, 0]])),
],
)
def test_top_k(k: int, preds: torch.Tensor, target: torch.Tensor, reduce: str, expected: torch.Tensor):
def test_top_k(k: int, preds: Tensor, target: Tensor, reduce: str, expected: Tensor):
""" A simple test to check that top_k works as expected """
class_metric = StatScores(top_k=k, reduce=reduce, num_classes=3)

Просмотреть файл

@ -14,6 +14,7 @@
import pytest
import torch
from pytorch_lightning import seed_everything
from torch import Tensor, tensor
from torchmetrics.functional import dice_score
from torchmetrics.functional.classification.precision_recall_curve import _binary_clf_curve
@ -21,7 +22,7 @@ from torchmetrics.utilities.data import get_num_classes, to_categorical, to_oneh
def test_onehot():
test_tensor = torch.tensor([[0, 1, 2, 3, 4], [5, 6, 7, 8, 9]])
test_tensor = tensor([[0, 1, 2, 3, 4], [5, 6, 7, 8, 9]])
expected = torch.stack([
torch.cat([torch.eye(5, dtype=int), torch.zeros((5, 5), dtype=int)]),
torch.cat([torch.zeros((5, 5), dtype=int), torch.eye(5, dtype=int)])
@ -48,7 +49,7 @@ def test_to_categorical():
torch.cat([torch.zeros((5, 5), dtype=int), torch.eye(5, dtype=int)])
]).to(torch.float)
expected = torch.tensor([[0, 1, 2, 3, 4], [5, 6, 7, 8, 9]])
expected = tensor([[0, 1, 2, 3, 4], [5, 6, 7, 8, 9]])
assert expected.shape == (2, 5)
assert test_tensor.shape == (2, 10, 5)
@ -77,15 +78,15 @@ def test_binary_clf_curve(sample_weight, pos_label, exp_shape):
# because when the array changes, you also have to fix the shape
seed_everything(0)
pred = torch.randint(low=51, high=99, size=(100, ), dtype=torch.float) / 100
target = torch.tensor([0, 1] * 50, dtype=torch.int)
target = tensor([0, 1] * 50, dtype=torch.int)
if sample_weight is not None:
sample_weight = torch.ones_like(pred) * sample_weight
fps, tps, thresh = _binary_clf_curve(preds=pred, target=target, sample_weights=sample_weight, pos_label=pos_label)
assert isinstance(tps, torch.Tensor)
assert isinstance(fps, torch.Tensor)
assert isinstance(thresh, torch.Tensor)
assert isinstance(tps, Tensor)
assert isinstance(fps, Tensor)
assert isinstance(thresh, Tensor)
assert tps.shape == (exp_shape, )
assert fps.shape == (exp_shape, )
assert thresh.shape == (exp_shape, )
@ -98,5 +99,5 @@ def test_binary_clf_curve(sample_weight, pos_label, exp_shape):
pytest.param([[1, 1], [0, 0]], [[1, 1], [0, 0]], 1.),
])
def test_dice_score(pred, target, expected):
score = dice_score(torch.tensor(pred), torch.tensor(target))
score = dice_score(tensor(pred), tensor(target))
assert score == expected

Просмотреть файл

@ -13,6 +13,7 @@
# limitations under the License.
import pytest
import torch
from torch import Tensor
from torchmetrics.functional import image_gradients
@ -73,8 +74,8 @@ def test_multi_batch_image_gradients():
[1., 1., 1., 1., 0.],
[1., 1., 1., 1., 0.],
]
true_dy = torch.Tensor(true_dy)
true_dx = torch.Tensor(true_dx)
true_dy = Tensor(true_dy)
true_dx = Tensor(true_dx)
dy, dx = image_gradients(image)
@ -113,8 +114,8 @@ def test_image_gradients():
[1., 1., 1., 1., 0.],
]
true_dy = torch.Tensor(true_dy)
true_dx = torch.Tensor(true_dx)
true_dy = Tensor(true_dy)
true_dx = Tensor(true_dx)
dy, dx = image_gradients(image)

Просмотреть файл

@ -14,6 +14,7 @@
import pytest
import torch
from nltk.translate.bleu_score import SmoothingFunction, corpus_bleu, sentence_bleu
from torch import tensor
from torchmetrics.functional import bleu_score
@ -62,20 +63,20 @@ def test_bleu_score(weights, n_gram, smooth_func, smooth):
smoothing_function=smooth_func,
)
pl_output = bleu_score([HYPOTHESIS1], [[REFERENCE1, REFERENCE2, REFERENCE3]], n_gram=n_gram, smooth=smooth)
assert torch.allclose(pl_output, torch.tensor(nltk_output))
assert torch.allclose(pl_output, tensor(nltk_output))
nltk_output = corpus_bleu(LIST_OF_REFERENCES, HYPOTHESES, weights=weights, smoothing_function=smooth_func)
pl_output = bleu_score(HYPOTHESES, LIST_OF_REFERENCES, n_gram=n_gram, smooth=smooth)
assert torch.allclose(pl_output, torch.tensor(nltk_output))
assert torch.allclose(pl_output, tensor(nltk_output))
def test_bleu_empty():
hyp = [[]]
ref = [[[]]]
assert bleu_score(hyp, ref) == torch.tensor(0.0)
assert bleu_score(hyp, ref) == tensor(0.0)
def test_no_4_gram():
hyps = [["My", "full", "pytorch-lightning"]]
refs = [[["My", "full", "pytorch-lightning", "test"], ["Completely", "Different"]]]
assert bleu_score(hyps, refs) == torch.tensor(0.0)
assert bleu_score(hyps, refs) == tensor(0.0)

Просмотреть файл

@ -14,6 +14,7 @@
import pytest
import torch
from sklearn.metrics import pairwise
from torch import tensor
from torchmetrics.functional import embedding_similarity
@ -40,6 +41,6 @@ def test_against_sklearn(similarity, reduction):
return dist
sk_dist = sklearn_embedding_distance(batch.cpu().detach().numpy(), similarity=similarity, reduction=reduction)
sk_dist = torch.tensor(sk_dist, dtype=torch.float, device=device)
sk_dist = tensor(sk_dist, dtype=torch.float, device=device)
assert torch.allclose(sk_dist, pl_dist)

Просмотреть файл

@ -1,7 +1,5 @@
from distutils.version import LooseVersion
from torchmetrics.utilities.imports import _TORCH_LOWER_1_4, _TORCH_LOWER_1_5, _TORCH_LOWER_1_6
import torch
_MARK_TORCH_MIN_1_4 = dict(condition=LooseVersion(torch.__version__) < LooseVersion("1.4"), reason='required PT >= 1.4')
_MARK_TORCH_MIN_1_5 = dict(condition=LooseVersion(torch.__version__) < LooseVersion("1.5"), reason='required PT >= 1.5')
_MARK_TORCH_MIN_1_6 = dict(condition=LooseVersion(torch.__version__) < LooseVersion("1.6"), reason='required PT >= 1.6')
_MARK_TORCH_MIN_1_4 = dict(condition=_TORCH_LOWER_1_4, reason='required PT >= 1.4')
_MARK_TORCH_MIN_1_5 = dict(condition=_TORCH_LOWER_1_5, reason='required PT >= 1.5')
_MARK_TORCH_MIN_1_6 = dict(condition=_TORCH_LOWER_1_6, reason='required PT >= 1.6')

Просмотреть файл

@ -20,6 +20,7 @@ from typing import Callable
import numpy as np
import pytest
import torch
from torch import Tensor, tensor
from torch.multiprocessing import Pool, set_start_method
from torchmetrics import Metric
@ -51,7 +52,7 @@ def _assert_allclose(pl_result, sk_result, atol: float = 1e-8):
a certain tolerance
"""
# single output compare
if isinstance(pl_result, torch.Tensor):
if isinstance(pl_result, Tensor):
assert np.allclose(pl_result.numpy(), sk_result, atol=atol, equal_nan=True)
# multi output compare
elif isinstance(pl_result, (tuple, list)):
@ -69,14 +70,14 @@ def _assert_tensor(pl_result):
for plr in pl_result:
_assert_tensor(plr)
else:
assert isinstance(pl_result, torch.Tensor)
assert isinstance(pl_result, Tensor)
def _class_test(
rank: int,
worldsize: int,
preds: torch.Tensor,
target: torch.Tensor,
preds: Tensor,
target: Tensor,
metric_class: Metric,
sk_metric: Callable,
dist_sync_on_step: bool,
@ -140,8 +141,8 @@ def _class_test(
def _functional_test(
preds: torch.Tensor,
target: torch.Tensor,
preds: Tensor,
target: Tensor,
metric_functional: Callable,
sk_metric: Callable,
metric_args: dict = {},
@ -195,8 +196,8 @@ class MetricTester:
def run_functional_metric_test(
self,
preds: torch.Tensor,
target: torch.Tensor,
preds: Tensor,
target: Tensor,
metric_functional: Callable,
sk_metric: Callable,
metric_args: dict = {},
@ -223,8 +224,8 @@ class MetricTester:
def run_class_metric_test(
self,
ddp: bool,
preds: torch.Tensor,
target: torch.Tensor,
preds: Tensor,
target: Tensor,
metric_class: Metric,
sk_metric: Callable,
dist_sync_on_step: bool,
@ -289,7 +290,7 @@ class DummyMetric(Metric):
def __init__(self):
super().__init__()
self.add_state("x", torch.tensor(0.0), dist_reduce_fx=None)
self.add_state("x", tensor(0.0), dist_reduce_fx=None)
def update(self):
pass

Просмотреть файл

@ -13,6 +13,7 @@
# limitations under the License.
import torch
from pytorch_lightning import Trainer
from torch import tensor
from tests.integrations.lightning_models import BoringModel
from torchmetrics import Metric
@ -22,7 +23,7 @@ class SumMetric(Metric):
def __init__(self):
super().__init__()
self.add_state("x", torch.tensor(0.0), dist_reduce_fx="sum")
self.add_state("x", tensor(0.0), dist_reduce_fx="sum")
def update(self, x):
self.x += x
@ -35,7 +36,7 @@ class DiffMetric(Metric):
def __init__(self):
super().__init__()
self.add_state("x", torch.tensor(0.0), dist_reduce_fx="sum")
self.add_state("x", tensor(0.0), dist_reduce_fx="sum")
def update(self, x):
self.x -= x
@ -116,8 +117,8 @@ def test_metric_lightning(tmpdir):
# trainer.fit(model)
#
# logged = trainer.logged_metrics
# assert torch.allclose(torch.tensor(logged["sum_step"]), model.sum)
# assert torch.allclose(torch.tensor(logged["sum_epoch"]), model.sum)
# assert torch.allclose(tensor(logged["sum_step"]), model.sum)
# assert torch.allclose(tensor(logged["sum_epoch"]), model.sum)
# todo: need to be fixed
# def test_scriptable(tmpdir):
@ -193,5 +194,5 @@ def test_metric_lightning(tmpdir):
# trainer.fit(model)
#
# logged = trainer.logged_metrics
# assert torch.allclose(torch.tensor(logged["SumMetric_epoch"]), model.sum)
# assert torch.allclose(torch.tensor(logged["DiffMetric_epoch"]), model.diff)
# assert torch.allclose(tensor(logged["SumMetric_epoch"]), model.sum)
# assert torch.allclose(tensor(logged["DiffMetric_epoch"]), model.diff)

Просмотреть файл

@ -14,6 +14,7 @@
from typing import Any, Callable, Optional
import torch
from torch import Tensor, tensor
from torchmetrics.functional.classification.accuracy import _accuracy_compute, _accuracy_update
from torchmetrics.metric import Metric
@ -112,8 +113,8 @@ class Accuracy(Metric):
dist_sync_fn=dist_sync_fn,
)
self.add_state("correct", default=torch.tensor(0), dist_reduce_fx="sum")
self.add_state("total", default=torch.tensor(0), dist_reduce_fx="sum")
self.add_state("correct", default=tensor(0), dist_reduce_fx="sum")
self.add_state("total", default=tensor(0), dist_reduce_fx="sum")
if not 0 < threshold < 1:
raise ValueError(f"The `threshold` should be a float in the (0,1) interval, got {threshold}")
@ -125,7 +126,7 @@ class Accuracy(Metric):
self.top_k = top_k
self.subset_accuracy = subset_accuracy
def update(self, preds: torch.Tensor, target: torch.Tensor):
def update(self, preds: Tensor, target: Tensor):
"""
Update state with predictions and targets. See :ref:`references/modules:input types` for more information
on input types.
@ -142,7 +143,7 @@ class Accuracy(Metric):
self.correct += correct
self.total += total
def compute(self) -> torch.Tensor:
def compute(self) -> Tensor:
"""
Computes accuracy based on inputs passed in to ``update`` previously.
"""

Просмотреть файл

@ -14,6 +14,7 @@
from typing import Any, Callable, Optional
import torch
from torch import Tensor
from torchmetrics.functional.classification.auc import _auc_compute, _auc_update
from torchmetrics.metric import Metric
@ -68,7 +69,7 @@ class AUC(Metric):
' For large datasets this may lead to large memory footprint.'
)
def update(self, x: torch.Tensor, y: torch.Tensor):
def update(self, x: Tensor, y: Tensor):
"""
Update state with predictions and targets.
@ -81,7 +82,7 @@ class AUC(Metric):
self.x.append(x)
self.y.append(y)
def compute(self) -> torch.Tensor:
def compute(self) -> Tensor:
"""
Computes AUC based on inputs passed in to ``update`` previously.
"""

Просмотреть файл

@ -11,14 +11,15 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from distutils.version import LooseVersion
from typing import Any, Callable, Optional
import torch
from torch import Tensor
from torchmetrics.functional.classification.auroc import _auroc_compute, _auroc_update
from torchmetrics.metric import Metric
from torchmetrics.utilities import rank_zero_warn
from torchmetrics.utilities.imports import _TORCH_LOWER_1_6
class AUROC(Metric):
@ -120,7 +121,7 @@ class AUROC(Metric):
if (not isinstance(max_fpr, float) and 0 < max_fpr <= 1):
raise ValueError(f"`max_fpr` should be a float in range (0, 1], got: {max_fpr}")
if LooseVersion(torch.__version__) < LooseVersion('1.6.0'):
if _TORCH_LOWER_1_6:
raise RuntimeError(
'`max_fpr` argument requires `torch.bucketize` which is not available below PyTorch version 1.6'
)
@ -134,7 +135,7 @@ class AUROC(Metric):
' For large datasets this may lead to large memory footprint.'
)
def update(self, preds: torch.Tensor, target: torch.Tensor):
def update(self, preds: Tensor, target: Tensor):
"""
Update state with predictions and targets.
@ -154,7 +155,7 @@ class AUROC(Metric):
)
self.mode = mode
def compute(self) -> torch.Tensor:
def compute(self) -> Tensor:
"""
Computes AUROC based on inputs passed in to ``update`` previously.
"""

Просмотреть файл

@ -14,6 +14,7 @@
from typing import Any, List, Optional, Union
import torch
from torch import Tensor
from torchmetrics.functional.classification.average_precision import (
_average_precision_compute,
@ -97,7 +98,7 @@ class AveragePrecision(Metric):
' For large datasets this may lead to large memory footprint.'
)
def update(self, preds: torch.Tensor, target: torch.Tensor):
def update(self, preds: Tensor, target: Tensor):
"""
Update state with predictions and targets.
@ -113,7 +114,7 @@ class AveragePrecision(Metric):
self.num_classes = num_classes
self.pos_label = pos_label
def compute(self) -> Union[torch.Tensor, List[torch.Tensor]]:
def compute(self) -> Union[Tensor, List[Tensor]]:
"""
Compute the average precision score

Просмотреть файл

@ -14,6 +14,7 @@
from typing import Any, Optional
import torch
from torch import Tensor
from torchmetrics.functional.classification.cohen_kappa import _cohen_kappa_compute, _cohen_kappa_update
from torchmetrics.metric import Metric
@ -74,6 +75,7 @@ class CohenKappa(Metric):
>>> cohenkappa(preds, target)
tensor(0.5000)
"""
def __init__(
self,
num_classes: int,
@ -99,7 +101,7 @@ class CohenKappa(Metric):
self.add_state("confmat", default=torch.zeros(num_classes, num_classes), dist_reduce_fx="sum")
def update(self, preds: torch.Tensor, target: torch.Tensor):
def update(self, preds: Tensor, target: Tensor):
"""
Update state with predictions and targets.
@ -110,7 +112,7 @@ class CohenKappa(Metric):
confmat = _cohen_kappa_update(preds, target, self.num_classes, self.threshold)
self.confmat += confmat
def compute(self) -> torch.Tensor:
def compute(self) -> Tensor:
"""
Computes cohen kappa score
"""

Просмотреть файл

@ -14,6 +14,7 @@
from typing import Any, Optional
import torch
from torch import Tensor
from torchmetrics.functional.classification.confusion_matrix import _confusion_matrix_compute, _confusion_matrix_update
from torchmetrics.metric import Metric
@ -96,7 +97,7 @@ class ConfusionMatrix(Metric):
self.add_state("confmat", default=torch.zeros(num_classes, num_classes), dist_reduce_fx="sum")
def update(self, preds: torch.Tensor, target: torch.Tensor):
def update(self, preds: Tensor, target: Tensor):
"""
Update state with predictions and targets.
@ -107,7 +108,7 @@ class ConfusionMatrix(Metric):
confmat = _confusion_matrix_update(preds, target, self.num_classes, self.threshold)
self.confmat += confmat
def compute(self) -> torch.Tensor:
def compute(self) -> Tensor:
"""
Computes confusion matrix
"""

Просмотреть файл

@ -14,6 +14,7 @@
from typing import Any, Optional
import torch
from torch import Tensor
from torchmetrics.functional.classification.f_beta import _fbeta_compute, _fbeta_update
from torchmetrics.metric import Metric
@ -109,7 +110,7 @@ class FBeta(Metric):
self.add_state("predicted_positives", default=torch.zeros(num_classes), dist_reduce_fx="sum")
self.add_state("actual_positives", default=torch.zeros(num_classes), dist_reduce_fx="sum")
def update(self, preds: torch.Tensor, target: torch.Tensor):
def update(self, preds: Tensor, target: Tensor):
"""
Update state with predictions and targets.
@ -125,7 +126,7 @@ class FBeta(Metric):
self.predicted_positives += predicted_positives
self.actual_positives += actual_positives
def compute(self) -> torch.Tensor:
def compute(self) -> Tensor:
"""
Computes fbeta over state.
"""

Просмотреть файл

@ -14,6 +14,7 @@
from typing import Any, Callable, Optional
import torch
from torch import Tensor, tensor
from torchmetrics.functional.classification.hamming_distance import _hamming_distance_compute, _hamming_distance_update
from torchmetrics.metric import Metric
@ -79,14 +80,14 @@ class HammingDistance(Metric):
dist_sync_fn=dist_sync_fn,
)
self.add_state("correct", default=torch.tensor(0), dist_reduce_fx="sum")
self.add_state("total", default=torch.tensor(0), dist_reduce_fx="sum")
self.add_state("correct", default=tensor(0), dist_reduce_fx="sum")
self.add_state("total", default=tensor(0), dist_reduce_fx="sum")
if not 0 < threshold < 1:
raise ValueError("The `threshold` should lie in the (0,1) interval.")
self.threshold = threshold
def update(self, preds: torch.Tensor, target: torch.Tensor):
def update(self, preds: Tensor, target: Tensor):
"""
Update state with predictions and targets. See :ref:`references/modules:input types` for more information
on input types.
@ -100,7 +101,7 @@ class HammingDistance(Metric):
self.correct += correct
self.total += total
def compute(self) -> torch.Tensor:
def compute(self) -> Tensor:
"""
Computes hamming distance based on inputs passed in to ``update`` previously.
"""

Просмотреть файл

@ -14,6 +14,7 @@
from typing import Any, Optional
import torch
from torch import Tensor
from torchmetrics.classification.confusion_matrix import ConfusionMatrix
from torchmetrics.functional.classification.iou import _iou_from_confmat
@ -100,7 +101,7 @@ class IoU(ConfusionMatrix):
self.ignore_index = ignore_index
self.absent_score = absent_score
def compute(self) -> torch.Tensor:
def compute(self) -> Tensor:
"""
Computes intersection over union (IoU)
"""

Просмотреть файл

@ -14,6 +14,7 @@
from typing import Any, Callable, Optional
import torch
from torch import Tensor
from torchmetrics.classification.stat_scores import StatScores
from torchmetrics.functional.classification.precision_recall import _precision_compute, _recall_compute
@ -151,7 +152,7 @@ class Precision(StatScores):
self.average = average
def compute(self) -> torch.Tensor:
def compute(self) -> Tensor:
"""
Computes the precision score based on inputs passed in to ``update`` previously.
@ -299,7 +300,7 @@ class Recall(StatScores):
self.average = average
def compute(self) -> torch.Tensor:
def compute(self) -> Tensor:
"""
Computes the recall score based on inputs passed in to ``update`` previously.

Просмотреть файл

@ -14,6 +14,7 @@
from typing import Any, List, Optional, Tuple, Union
import torch
from torch import Tensor
from torchmetrics.functional.classification.precision_recall_curve import (
_precision_recall_curve_compute,
@ -108,7 +109,7 @@ class PrecisionRecallCurve(Metric):
' For large datasets this may lead to large memory footprint.'
)
def update(self, preds: torch.Tensor, target: torch.Tensor):
def update(self, preds: Tensor, target: Tensor):
"""
Update state with predictions and targets.
@ -124,10 +125,7 @@ class PrecisionRecallCurve(Metric):
self.num_classes = num_classes
self.pos_label = pos_label
def compute(
self
) -> Union[Tuple[torch.Tensor, torch.Tensor, torch.Tensor], Tuple[List[torch.Tensor], List[torch.Tensor],
List[torch.Tensor]]]:
def compute(self) -> Union[Tuple[Tensor, Tensor, Tensor], Tuple[List[Tensor], List[Tensor], List[Tensor]]]:
"""
Compute the precision-recall curve

Просмотреть файл

@ -14,6 +14,7 @@
from typing import Any, List, Optional, Tuple, Union
import torch
from torch import Tensor
from torchmetrics.functional.classification.roc import _roc_compute, _roc_update
from torchmetrics.metric import Metric
@ -107,7 +108,7 @@ class ROC(Metric):
' For large datasets this may lead to large memory footprint.'
)
def update(self, preds: torch.Tensor, target: torch.Tensor):
def update(self, preds: Tensor, target: Tensor):
"""
Update state with predictions and targets.
@ -121,10 +122,7 @@ class ROC(Metric):
self.num_classes = num_classes
self.pos_label = pos_label
def compute(
self
) -> Union[Tuple[torch.Tensor, torch.Tensor, torch.Tensor], Tuple[List[torch.Tensor], List[torch.Tensor],
List[torch.Tensor]]]:
def compute(self) -> Union[Tuple[Tensor, Tensor, Tensor], Tuple[List[Tensor], List[Tensor], List[Tensor]]]:
"""
Compute the receiver operating characteristic

Просмотреть файл

@ -15,6 +15,7 @@ from typing import Any, Callable, Optional, Tuple
import numpy as np
import torch
from torch import Tensor, tensor
from torchmetrics.functional.classification.stat_scores import _stat_scores_compute, _stat_scores_update
from torchmetrics.metric import Metric
@ -175,7 +176,7 @@ class StatScores(Metric):
for s in ("tp", "fp", "tn", "fn"):
self.add_state(s, default=default(), dist_reduce_fx=reduce_fn)
def update(self, preds: torch.Tensor, target: torch.Tensor):
def update(self, preds: Tensor, target: Tensor):
"""
Update state with predictions and targets. See :ref:`references/modules:input types` for more information
on input types.
@ -209,7 +210,7 @@ class StatScores(Metric):
self.tn.append(tn)
self.fn.append(fn)
def _get_final_stats(self) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
def _get_final_stats(self) -> Tuple[Tensor, Tensor, Tensor, Tensor]:
"""Performs concatenation on the stat scores if neccesary,
before passing them to a compute function.
"""
@ -224,7 +225,7 @@ class StatScores(Metric):
return tp, fp, tn, fn
def compute(self) -> torch.Tensor:
def compute(self) -> Tensor:
"""
Computes the stat scores based on inputs passed in to ``update`` previously.
@ -262,13 +263,13 @@ class StatScores(Metric):
def _reduce_stat_scores(
numerator: torch.Tensor,
denominator: torch.Tensor,
weights: Optional[torch.Tensor],
numerator: Tensor,
denominator: Tensor,
weights: Optional[Tensor],
average: str,
mdmc_average: Optional[str],
zero_division: int = 0,
) -> torch.Tensor:
) -> Tensor:
"""
Reduces scores of type ``numerator/denominator`` or
``weights * (numerator/denominator)``, if ``average='weighted'``.
@ -303,9 +304,9 @@ model_evaluation.html#multiclass-and-multilabel-classification>`__.
else:
weights = weights.float()
numerator = torch.where(zero_div_mask, torch.tensor(float(zero_division), device=numerator.device), numerator)
denominator = torch.where(zero_div_mask | ignore_mask, torch.tensor(1.0, device=denominator.device), denominator)
weights = torch.where(ignore_mask, torch.tensor(0.0, device=weights.device), weights)
numerator = torch.where(zero_div_mask, tensor(float(zero_division), device=numerator.device), numerator)
denominator = torch.where(zero_div_mask | ignore_mask, tensor(1.0, device=denominator.device), denominator)
weights = torch.where(ignore_mask, tensor(0.0, device=weights.device), weights)
if average not in (AverageMethod.MICRO, AverageMethod.NONE, None):
weights = weights / weights.sum(dim=-1, keepdim=True)
@ -313,14 +314,14 @@ model_evaluation.html#multiclass-and-multilabel-classification>`__.
scores = weights * (numerator / denominator)
# This is in case where sum(weights) = 0, which happens if we ignore the only present class with average='weighted'
scores = torch.where(torch.isnan(scores), torch.tensor(float(zero_division), device=scores.device), scores)
scores = torch.where(torch.isnan(scores), tensor(float(zero_division), device=scores.device), scores)
if mdmc_average == MDMCAverageMethod.SAMPLEWISE:
scores = scores.mean(dim=0)
ignore_mask = ignore_mask.sum(dim=0).bool()
if average in (AverageMethod.NONE, None):
scores = torch.where(ignore_mask, torch.tensor(np.nan, device=scores.device), scores)
scores = torch.where(ignore_mask, tensor(np.nan, device=scores.device), scores)
else:
scores = scores.sum()

Просмотреть файл

@ -60,10 +60,11 @@ class MetricCollection(nn.ModuleDict):
>>> metrics.persistent()
"""
def __init__(
self,
metrics: Union[List[Metric], Tuple[Metric], Dict[str, Metric]],
prefix: Optional[str] = None
prefix: Optional[str] = None,
):
super().__init__()
if isinstance(metrics, dict):

Просмотреть файл

@ -14,14 +14,19 @@
from typing import Optional, Tuple
import torch
from torch import Tensor, tensor
from torchmetrics.utilities.checks import _input_format_classification
from torchmetrics.utilities.enums import DataType
def _accuracy_update(
preds: torch.Tensor, target: torch.Tensor, threshold: float, top_k: Optional[int], subset_accuracy: bool
) -> Tuple[torch.Tensor, torch.Tensor]:
preds: Tensor,
target: Tensor,
threshold: float,
top_k: Optional[int],
subset_accuracy: bool,
) -> Tuple[Tensor, Tensor]:
preds, target, mode = _input_format_classification(preds, target, threshold=threshold, top_k=top_k)
@ -30,32 +35,32 @@ def _accuracy_update(
if mode == DataType.BINARY or (mode == DataType.MULTILABEL and subset_accuracy):
correct = (preds == target).all(dim=1).sum()
total = torch.tensor(target.shape[0], device=target.device)
total = tensor(target.shape[0], device=target.device)
elif mode == DataType.MULTILABEL and not subset_accuracy:
correct = (preds == target).sum()
total = torch.tensor(target.numel(), device=target.device)
total = tensor(target.numel(), device=target.device)
elif mode == DataType.MULTICLASS or (mode == DataType.MULTIDIM_MULTICLASS and not subset_accuracy):
correct = (preds * target).sum()
total = target.sum()
elif mode == DataType.MULTIDIM_MULTICLASS and subset_accuracy:
sample_correct = (preds * target).sum(dim=(1, 2))
correct = (sample_correct == target.shape[2]).sum()
total = torch.tensor(target.shape[0], device=target.device)
total = tensor(target.shape[0], device=target.device)
return correct, total
def _accuracy_compute(correct: torch.Tensor, total: torch.Tensor) -> torch.Tensor:
def _accuracy_compute(correct: Tensor, total: Tensor) -> Tensor:
return correct.float() / total
def accuracy(
preds: torch.Tensor,
target: torch.Tensor,
preds: Tensor,
target: Tensor,
threshold: float = 0.5,
top_k: Optional[int] = None,
subset_accuracy: bool = False,
) -> torch.Tensor:
) -> Tensor:
r"""Computes `Accuracy <https://en.wikipedia.org/wiki/Accuracy_and_precision>`_:
.. math::

Просмотреть файл

@ -14,11 +14,12 @@
from typing import Tuple
import torch
from torch import Tensor
from torchmetrics.utilities.data import _stable_1d_sort
def _auc_update(x: torch.Tensor, y: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
def _auc_update(x: Tensor, y: Tensor) -> Tuple[Tensor, Tensor]:
if x.ndim > 1 or y.ndim > 1:
raise ValueError(
f'Expected both `x` and `y` tensor to be 1d, but got'
@ -32,7 +33,7 @@ def _auc_update(x: torch.Tensor, y: torch.Tensor) -> Tuple[torch.Tensor, torch.T
return x, y
def _auc_compute(x: torch.Tensor, y: torch.Tensor, reorder: bool = False) -> torch.Tensor:
def _auc_compute(x: Tensor, y: Tensor, reorder: bool = False) -> Tensor:
if reorder:
x, x_idx = _stable_1d_sort(x)
y = y[x_idx]
@ -51,7 +52,7 @@ def _auc_compute(x: torch.Tensor, y: torch.Tensor, reorder: bool = False) -> tor
return direction * torch.trapz(y, x)
def auc(x: torch.Tensor, y: torch.Tensor, reorder: bool = False) -> torch.Tensor:
def auc(x: Tensor, y: Tensor, reorder: bool = False) -> Tensor:
"""
Computes Area Under the Curve (AUC) using the trapezoidal rule

Просмотреть файл

@ -11,18 +11,19 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from distutils.version import LooseVersion
from typing import Optional, Sequence, Tuple
import torch
from torch import Tensor, tensor
from torchmetrics.functional.classification.auc import auc
from torchmetrics.functional.classification.roc import roc
from torchmetrics.utilities.checks import _input_format_classification
from torchmetrics.utilities.enums import AverageMethod, DataType
from torchmetrics.utilities.imports import _TORCH_LOWER_1_6
def _auroc_update(preds: torch.Tensor, target: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor, str]:
def _auroc_update(preds: Tensor, target: Tensor) -> Tuple[Tensor, Tensor, str]:
# use _input_format_classification for validating the input and get the mode of data
_, _, mode = _input_format_classification(preds, target)
@ -39,15 +40,15 @@ def _auroc_update(preds: torch.Tensor, target: torch.Tensor) -> Tuple[torch.Tens
def _auroc_compute(
preds: torch.Tensor,
target: torch.Tensor,
preds: Tensor,
target: Tensor,
mode: str,
num_classes: Optional[int] = None,
pos_label: Optional[int] = None,
average: Optional[str] = 'macro',
max_fpr: Optional[float] = None,
sample_weights: Optional[Sequence] = None,
) -> torch.Tensor:
) -> Tensor:
# binary mode override num_classes
if mode == 'binary':
num_classes = 1
@ -57,7 +58,7 @@ def _auroc_compute(
if (not isinstance(max_fpr, float) and 0 < max_fpr <= 1):
raise ValueError(f"`max_fpr` should be a float in range (0, 1], got: {max_fpr}")
if LooseVersion(torch.__version__) < LooseVersion('1.6.0'):
if _TORCH_LOWER_1_6:
raise RuntimeError(
"`max_fpr` argument requires `torch.bucketize` which"
" is not available below PyTorch version 1.6"
@ -109,7 +110,7 @@ def _auroc_compute(
return auc(fpr, tpr)
max_fpr = torch.tensor(max_fpr, device=fpr.device)
max_fpr = tensor(max_fpr, device=fpr.device)
# Add a single point at max_fpr and interpolate its tpr value
stop = torch.bucketize(max_fpr, fpr, out_int32=True, right=True)
weight = (max_fpr - fpr[stop - 1]) / (fpr[stop] - fpr[stop - 1])
@ -128,14 +129,14 @@ def _auroc_compute(
def auroc(
preds: torch.Tensor,
target: torch.Tensor,
preds: Tensor,
target: Tensor,
num_classes: Optional[int] = None,
pos_label: Optional[int] = None,
average: Optional[str] = 'macro',
max_fpr: Optional[float] = None,
sample_weights: Optional[Sequence] = None,
) -> torch.Tensor:
) -> Tensor:
""" Compute `Area Under the Receiver Operating Characteristic Curve (ROC AUC)
<https://en.wikipedia.org/wiki/Receiver_operating_characteristic#Further_interpretations>`_

Просмотреть файл

@ -14,6 +14,7 @@
from typing import List, Optional, Sequence, Tuple, Union
import torch
from torch import Tensor
from torchmetrics.functional.classification.precision_recall_curve import (
_precision_recall_curve_compute,
@ -22,21 +23,21 @@ from torchmetrics.functional.classification.precision_recall_curve import (
def _average_precision_update(
preds: torch.Tensor,
target: torch.Tensor,
preds: Tensor,
target: Tensor,
num_classes: Optional[int] = None,
pos_label: Optional[int] = None,
) -> Tuple[torch.Tensor, torch.Tensor, int, int]:
) -> Tuple[Tensor, Tensor, int, int]:
return _precision_recall_curve_update(preds, target, num_classes, pos_label)
def _average_precision_compute(
preds: torch.Tensor,
target: torch.Tensor,
preds: Tensor,
target: Tensor,
num_classes: int,
pos_label: int,
sample_weights: Optional[Sequence] = None
) -> Union[List[torch.Tensor], torch.Tensor]:
sample_weights: Optional[Sequence] = None,
) -> Union[List[Tensor], Tensor]:
precision, recall, _ = _precision_recall_curve_compute(preds, target, num_classes, pos_label)
# Return the step function integral
# The following works because the last entry of precision is
@ -51,12 +52,12 @@ def _average_precision_compute(
def average_precision(
preds: torch.Tensor,
target: torch.Tensor,
preds: Tensor,
target: Tensor,
num_classes: Optional[int] = None,
pos_label: Optional[int] = None,
sample_weights: Optional[Sequence] = None,
) -> Union[List[torch.Tensor], torch.Tensor]:
) -> Union[List[Tensor], Tensor]:
"""
Computes the average precision score.

Просмотреть файл

@ -14,13 +14,14 @@
from typing import Optional
import torch
from torch import Tensor
from torchmetrics.functional.classification.confusion_matrix import _confusion_matrix_compute, _confusion_matrix_update
_cohen_kappa_update = _confusion_matrix_update
def _cohen_kappa_compute(confmat: torch.Tensor, weights: Optional[str] = None) -> torch.Tensor:
def _cohen_kappa_compute(confmat: Tensor, weights: Optional[str] = None) -> Tensor:
confmat = _confusion_matrix_compute(confmat)
n_classes = confmat.shape[0]
sum0 = confmat.sum(dim=0, keepdim=True)
@ -39,20 +40,18 @@ def _cohen_kappa_compute(confmat: torch.Tensor, weights: Optional[str] = None) -
else:
w_mat = torch.pow(w_mat - w_mat.T, 2.0)
else:
raise ValueError(f"Received {weights} for argument ``weights`` but should be either"
" None, 'linear' or 'quadratic'")
raise ValueError(
f"Received {weights} for argument ``weights`` but should be either"
" None, 'linear' or 'quadratic'"
)
k = torch.sum(w_mat * confmat) / torch.sum(w_mat * expected)
return 1 - k
def cohen_kappa(
preds: torch.Tensor,
target: torch.Tensor,
num_classes: int,
weights: Optional[str] = None,
threshold: float = 0.5
) -> torch.Tensor:
preds: Tensor, target: Tensor, num_classes: int, weights: Optional[str] = None, threshold: float = 0.5
) -> Tensor:
r"""
Calculates `Cohen's kappa score <https://en.wikipedia.org/wiki/Cohen%27s_kappa>`_ that measures
inter-annotator agreement. It is defined as

Просмотреть файл

@ -14,15 +14,14 @@
from typing import Optional
import torch
from torch import Tensor
from torchmetrics.utilities import rank_zero_warn
from torchmetrics.utilities.checks import _input_format_classification
from torchmetrics.utilities.enums import DataType
def _confusion_matrix_update(
preds: torch.Tensor, target: torch.Tensor, num_classes: int, threshold: float = 0.5
) -> torch.Tensor:
def _confusion_matrix_update(preds: Tensor, target: Tensor, num_classes: int, threshold: float = 0.5) -> Tensor:
preds, target, mode = _input_format_classification(preds, target, threshold)
if mode not in (DataType.BINARY, DataType.MULTILABEL):
preds = preds.argmax(dim=1)
@ -33,7 +32,7 @@ def _confusion_matrix_update(
return confmat
def _confusion_matrix_compute(confmat: torch.Tensor, normalize: Optional[str] = None) -> torch.Tensor:
def _confusion_matrix_compute(confmat: Tensor, normalize: Optional[str] = None) -> Tensor:
allowed_normalize = ('true', 'pred', 'all', 'none', None)
assert normalize in allowed_normalize, \
f"Argument average needs to one of the following: {allowed_normalize}"
@ -54,12 +53,8 @@ def _confusion_matrix_compute(confmat: torch.Tensor, normalize: Optional[str] =
def confusion_matrix(
preds: torch.Tensor,
target: torch.Tensor,
num_classes: int,
normalize: Optional[str] = None,
threshold: float = 0.5
) -> torch.Tensor:
preds: Tensor, target: Tensor, num_classes: int, normalize: Optional[str] = None, threshold: float = 0.5
) -> Tensor:
"""
Computes the confusion matrix. Works with binary, multiclass, and multilabel data.
Accepts probabilities from a model output or integer class values in prediction.

Просмотреть файл

@ -14,17 +14,18 @@
from typing import Tuple
import torch
from torch import Tensor
from torchmetrics.utilities.data import to_categorical
from torchmetrics.utilities.distributed import reduce
def _stat_scores(
preds: torch.Tensor,
target: torch.Tensor,
preds: Tensor,
target: Tensor,
class_index: int,
argmax_dim: int = 1,
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
) -> Tuple[Tensor, Tensor, Tensor, Tensor, Tensor]:
"""
Calculates the number of true positive, false positive, true negative
and false negative for a specific class
@ -61,13 +62,13 @@ def _stat_scores(
def dice_score(
pred: torch.Tensor,
target: torch.Tensor,
pred: Tensor,
target: Tensor,
bg: bool = False,
nan_score: float = 0.0,
no_fg_score: float = 0.0,
reduction: str = 'elementwise_mean',
) -> torch.Tensor:
) -> Tensor:
"""
Compute dice score from prediction scores

Просмотреть файл

@ -14,18 +14,19 @@
from typing import Tuple
import torch
from torch import Tensor
from torchmetrics.utilities.checks import _input_format_classification_one_hot
from torchmetrics.utilities.distributed import class_reduce
def _fbeta_update(
preds: torch.Tensor,
target: torch.Tensor,
preds: Tensor,
target: Tensor,
num_classes: int,
threshold: float = 0.5,
multilabel: bool = False
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
multilabel: bool = False,
) -> Tuple[Tensor, Tensor, Tensor]:
preds, target = _input_format_classification_one_hot(num_classes, preds, target, threshold, multilabel)
true_positives = torch.sum(preds * target, dim=1)
predicted_positives = torch.sum(preds, dim=1)
@ -34,12 +35,12 @@ def _fbeta_update(
def _fbeta_compute(
true_positives: torch.Tensor,
predicted_positives: torch.Tensor,
actual_positives: torch.Tensor,
true_positives: Tensor,
predicted_positives: Tensor,
actual_positives: Tensor,
beta: float = 1.0,
average: str = "micro"
) -> torch.Tensor:
average: str = "micro",
) -> Tensor:
if average == "micro":
precision = true_positives.sum().float() / predicted_positives.sum()
recall = true_positives.sum().float() / actual_positives.sum()
@ -53,14 +54,14 @@ def _fbeta_compute(
def fbeta(
preds: torch.Tensor,
target: torch.Tensor,
preds: Tensor,
target: Tensor,
num_classes: int,
beta: float = 1.0,
threshold: float = 0.5,
average: str = "micro",
multilabel: bool = False
) -> torch.Tensor:
) -> Tensor:
"""
Computes f_beta metric.
@ -106,13 +107,13 @@ def fbeta(
def f1(
preds: torch.Tensor,
target: torch.Tensor,
preds: Tensor,
target: Tensor,
num_classes: int,
threshold: float = 0.5,
average: str = "micro",
multilabel: bool = False
) -> torch.Tensor:
) -> Tensor:
"""
Computes F1 metric. F1 metrics correspond to a equally weighted average of the
precision and recall scores.

Просмотреть файл

@ -14,15 +14,16 @@
from typing import Tuple, Union
import torch
from torch import Tensor
from torchmetrics.utilities.checks import _input_format_classification
def _hamming_distance_update(
preds: torch.Tensor,
target: torch.Tensor,
preds: Tensor,
target: Tensor,
threshold: float = 0.5,
) -> Tuple[torch.Tensor, int]:
) -> Tuple[Tensor, int]:
preds, target, _ = _input_format_classification(preds, target, threshold=threshold)
correct = (preds == target).sum()
@ -31,11 +32,11 @@ def _hamming_distance_update(
return correct, total
def _hamming_distance_compute(correct: torch.Tensor, total: Union[int, torch.Tensor]) -> torch.Tensor:
def _hamming_distance_compute(correct: Tensor, total: Union[int, Tensor]) -> Tensor:
return 1 - correct.float() / total
def hamming_distance(preds: torch.Tensor, target: torch.Tensor, threshold: float = 0.5) -> torch.Tensor:
def hamming_distance(preds: Tensor, target: Tensor, threshold: float = 0.5) -> Tensor:
r"""
Computes the average `Hamming distance <https://en.wikipedia.org/wiki/Hamming_distance>`_ (also
known as Hamming loss) between targets and predictions:

Просмотреть файл

@ -14,6 +14,7 @@
from typing import Optional
import torch
from torch import Tensor
from torchmetrics.functional.classification.confusion_matrix import _confusion_matrix_update
from torchmetrics.utilities.data import get_num_classes
@ -21,7 +22,7 @@ from torchmetrics.utilities.distributed import reduce
def _iou_from_confmat(
confmat: torch.Tensor,
confmat: Tensor,
num_classes: int,
ignore_index: Optional[int] = None,
absent_score: float = 0.0,
@ -44,14 +45,14 @@ def _iou_from_confmat(
def iou(
pred: torch.Tensor,
target: torch.Tensor,
pred: Tensor,
target: Tensor,
ignore_index: Optional[int] = None,
absent_score: float = 0.0,
threshold: float = 0.5,
num_classes: Optional[int] = None,
reduction: str = 'elementwise_mean',
) -> torch.Tensor:
) -> Tensor:
r"""
Computes `Intersection over union, or Jaccard index calculation <https://en.wikipedia.org/wiki/Jaccard_index>`_:

Просмотреть файл

@ -14,19 +14,20 @@
from typing import Optional
import torch
from torch import Tensor
from torchmetrics.classification.stat_scores import _reduce_stat_scores
from torchmetrics.functional.classification.stat_scores import _stat_scores_update
def _precision_compute(
tp: torch.Tensor,
fp: torch.Tensor,
tn: torch.Tensor,
fn: torch.Tensor,
tp: Tensor,
fp: Tensor,
tn: Tensor,
fn: Tensor,
average: str,
mdmc_average: Optional[str],
) -> torch.Tensor:
) -> Tensor:
return _reduce_stat_scores(
numerator=tp,
denominator=tp + fp,
@ -37,8 +38,8 @@ def _precision_compute(
def precision(
preds: torch.Tensor,
target: torch.Tensor,
preds: Tensor,
target: Tensor,
average: str = "micro",
mdmc_average: Optional[str] = None,
ignore_index: Optional[int] = None,
@ -46,7 +47,7 @@ def precision(
threshold: float = 0.5,
top_k: Optional[int] = None,
is_multiclass: Optional[bool] = None,
) -> torch.Tensor:
) -> Tensor:
r"""
Computes `Precision <https://en.wikipedia.org/wiki/Precision_and_recall>`_:
@ -170,13 +171,13 @@ def precision(
def _recall_compute(
tp: torch.Tensor,
fp: torch.Tensor,
tn: torch.Tensor,
fn: torch.Tensor,
tp: Tensor,
fp: Tensor,
tn: Tensor,
fn: Tensor,
average: str,
mdmc_average: Optional[str],
) -> torch.Tensor:
) -> Tensor:
return _reduce_stat_scores(
numerator=tp,
denominator=tp + fn,
@ -187,8 +188,8 @@ def _recall_compute(
def recall(
preds: torch.Tensor,
target: torch.Tensor,
preds: Tensor,
target: Tensor,
average: str = "micro",
mdmc_average: Optional[str] = None,
ignore_index: Optional[int] = None,
@ -196,7 +197,7 @@ def recall(
threshold: float = 0.5,
top_k: Optional[int] = None,
is_multiclass: Optional[bool] = None,
) -> torch.Tensor:
) -> Tensor:
r"""
Computes `Recall <https://en.wikipedia.org/wiki/Precision_and_recall>`_:
@ -320,8 +321,8 @@ def recall(
def precision_recall(
preds: torch.Tensor,
target: torch.Tensor,
preds: Tensor,
target: Tensor,
average: str = "micro",
mdmc_average: Optional[str] = None,
ignore_index: Optional[int] = None,
@ -329,7 +330,7 @@ def precision_recall(
threshold: float = 0.5,
top_k: Optional[int] = None,
is_multiclass: Optional[bool] = None,
) -> torch.Tensor:
) -> Tensor:
r"""
Computes `Precision and Recall <https://en.wikipedia.org/wiki/Precision_and_recall>`_:

Просмотреть файл

@ -15,21 +15,22 @@ from typing import List, Optional, Sequence, Tuple, Union
import torch
import torch.nn.functional as F
from torch import Tensor, tensor
from torchmetrics.utilities import rank_zero_warn
def _binary_clf_curve(
preds: torch.Tensor,
target: torch.Tensor,
preds: Tensor,
target: Tensor,
sample_weights: Optional[Sequence] = None,
pos_label: int = 1.,
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
) -> Tuple[Tensor, Tensor, Tensor]:
"""
adapted from https://github.com/scikit-learn/scikit-learn/blob/master/sklearn/metrics/_ranking.py
"""
if sample_weights is not None and not isinstance(sample_weights, torch.Tensor):
sample_weights = torch.tensor(sample_weights, device=preds.device, dtype=torch.float)
if sample_weights is not None and not isinstance(sample_weights, Tensor):
sample_weights = tensor(sample_weights, device=preds.device, dtype=torch.float)
# remove class dimension if necessary
if preds.ndim > target.ndim:
@ -63,11 +64,11 @@ def _binary_clf_curve(
def _precision_recall_curve_update(
preds: torch.Tensor,
target: torch.Tensor,
preds: Tensor,
target: Tensor,
num_classes: Optional[int] = None,
pos_label: Optional[int] = None,
) -> Tuple[torch.Tensor, torch.Tensor, int, int]:
) -> Tuple[Tensor, Tensor, int, int]:
if not (len(preds.shape) == len(target.shape) or len(preds.shape) == len(target.shape) + 1):
raise ValueError("preds and target must have same number of dimensions, or one additional dimension for preds")
# single class evaluation
@ -99,13 +100,12 @@ def _precision_recall_curve_update(
def _precision_recall_curve_compute(
preds: torch.Tensor,
target: torch.Tensor,
preds: Tensor,
target: Tensor,
num_classes: int,
pos_label: int,
sample_weights: Optional[Sequence] = None,
) -> Union[Tuple[torch.Tensor, torch.Tensor, torch.Tensor], Tuple[List[torch.Tensor], List[torch.Tensor],
List[torch.Tensor]]]:
) -> Union[Tuple[Tensor, Tensor, Tensor], Tuple[List[Tensor], List[Tensor], List[Tensor]]]:
if num_classes == 1:
fps, tps, thresholds = _binary_clf_curve(
@ -149,13 +149,12 @@ def _precision_recall_curve_compute(
def precision_recall_curve(
preds: torch.Tensor,
target: torch.Tensor,
preds: Tensor,
target: Tensor,
num_classes: Optional[int] = None,
pos_label: Optional[int] = None,
sample_weights: Optional[Sequence] = None,
) -> Union[Tuple[torch.Tensor, torch.Tensor, torch.Tensor], Tuple[List[torch.Tensor], List[torch.Tensor],
List[torch.Tensor]]]:
) -> Union[Tuple[Tensor, Tensor, Tensor], Tuple[List[Tensor], List[Tensor], List[Tensor]]]:
"""
Computes precision-recall pairs for different thresholds.

Просмотреть файл

@ -14,6 +14,7 @@
from typing import List, Optional, Sequence, Tuple, Union
import torch
from torch import Tensor
from torchmetrics.functional.classification.precision_recall_curve import (
_binary_clf_curve,
@ -22,22 +23,21 @@ from torchmetrics.functional.classification.precision_recall_curve import (
def _roc_update(
preds: torch.Tensor,
target: torch.Tensor,
preds: Tensor,
target: Tensor,
num_classes: Optional[int] = None,
pos_label: Optional[int] = None,
) -> Tuple[torch.Tensor, torch.Tensor, int, int]:
) -> Tuple[Tensor, Tensor, int, int]:
return _precision_recall_curve_update(preds, target, num_classes, pos_label)
def _roc_compute(
preds: torch.Tensor,
target: torch.Tensor,
preds: Tensor,
target: Tensor,
num_classes: int,
pos_label: int,
sample_weights: Optional[Sequence] = None,
) -> Union[Tuple[torch.Tensor, torch.Tensor, torch.Tensor], Tuple[List[torch.Tensor], List[torch.Tensor],
List[torch.Tensor]]]:
) -> Union[Tuple[Tensor, Tensor, Tensor], Tuple[List[Tensor], List[Tensor], List[Tensor]]]:
if num_classes == 1:
fps, tps, thresholds = _binary_clf_curve(
@ -78,13 +78,12 @@ def _roc_compute(
def roc(
preds: torch.Tensor,
target: torch.Tensor,
preds: Tensor,
target: Tensor,
num_classes: Optional[int] = None,
pos_label: Optional[int] = None,
sample_weights: Optional[Sequence] = None,
) -> Union[Tuple[torch.Tensor, torch.Tensor, torch.Tensor], Tuple[List[torch.Tensor], List[torch.Tensor],
List[torch.Tensor]]]:
) -> Union[Tuple[Tensor, Tensor, Tensor], Tuple[List[Tensor], List[Tensor], List[Tensor]]]:
"""
Computes the Receiver Operating Characteristic (ROC).

Просмотреть файл

@ -14,21 +14,22 @@
from typing import Optional, Tuple
import torch
from torch import Tensor, tensor
from torchmetrics.utilities.checks import _input_format_classification
def _del_column(tensor: torch.Tensor, index: int):
def _del_column(tensor: Tensor, index: int):
""" Delete the column at index."""
return torch.cat([tensor[:, :index], tensor[:, (index + 1):]], 1)
def _stat_scores(
preds: torch.Tensor,
target: torch.Tensor,
preds: Tensor,
target: Tensor,
reduce: str = "micro",
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
) -> Tuple[Tensor, Tensor, Tensor, Tensor]:
"""Calculate the number of tp, fp, tn, fn.
Args:
@ -74,8 +75,8 @@ def _stat_scores(
def _stat_scores_update(
preds: torch.Tensor,
target: torch.Tensor,
preds: Tensor,
target: Tensor,
reduce: str = "micro",
mdmc_reduce: Optional[str] = None,
num_classes: Optional[int] = None,
@ -83,7 +84,7 @@ def _stat_scores_update(
threshold: float = 0.5,
is_multiclass: Optional[bool] = None,
ignore_index: Optional[int] = None,
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
) -> Tuple[Tensor, Tensor, Tensor, Tensor]:
preds, target, _ = _input_format_classification(
preds, target, threshold=threshold, num_classes=num_classes, is_multiclass=is_multiclass, top_k=top_k
@ -121,7 +122,7 @@ def _stat_scores_update(
return tp, fp, tn, fn
def _stat_scores_compute(tp: torch.Tensor, fp: torch.Tensor, tn: torch.Tensor, fn: torch.Tensor) -> torch.Tensor:
def _stat_scores_compute(tp: Tensor, fp: Tensor, tn: Tensor, fn: Tensor) -> Tensor:
outputs = [
tp.unsqueeze(-1),
@ -131,14 +132,14 @@ def _stat_scores_compute(tp: torch.Tensor, fp: torch.Tensor, tn: torch.Tensor, f
tp.unsqueeze(-1) + fn.unsqueeze(-1), # support
]
outputs = torch.cat(outputs, -1)
outputs = torch.where(outputs < 0, torch.tensor(-1, device=outputs.device), outputs)
outputs = torch.where(outputs < 0, tensor(-1, device=outputs.device), outputs)
return outputs
def stat_scores(
preds: torch.Tensor,
target: torch.Tensor,
preds: Tensor,
target: Tensor,
reduce: str = "micro",
mdmc_reduce: Optional[str] = None,
num_classes: Optional[int] = None,
@ -146,7 +147,7 @@ def stat_scores(
threshold: float = 0.5,
is_multiclass: Optional[bool] = None,
ignore_index: Optional[int] = None,
) -> torch.Tensor:
) -> Tensor:
"""Computes the number of true positives, false positives, true negatives, false negatives.
Related to `Type I and Type II errors <https://en.wikipedia.org/wiki/Type_I_and_type_II_errors>`__
and the `confusion matrix <https://en.wikipedia.org/wiki/Confusion_matrix#Table_of_confusion>`__.

Просмотреть файл

@ -14,18 +14,19 @@
from typing import Tuple
import torch
from torch import Tensor
def _image_gradients_validate(img: torch.Tensor) -> torch.Tensor:
def _image_gradients_validate(img: Tensor) -> Tensor:
""" Validates whether img is a 4D torch Tensor """
if not isinstance(img, torch.Tensor):
raise TypeError(f"The `img` expects a value of <torch.Tensor> type but got {type(img)}")
if not isinstance(img, Tensor):
raise TypeError(f"The `img` expects a value of <Tensor> type but got {type(img)}")
if img.ndim != 4:
raise RuntimeError(f"The `img` expects a 4D tensor but got {img.ndim}D tensor")
def _compute_image_gradients(img: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
def _compute_image_gradients(img: Tensor) -> Tuple[Tensor, Tensor]:
""" Computes image gradients (dy/dx) for a given image """
batch_size, channels, height, width = img.shape
@ -44,7 +45,7 @@ def _compute_image_gradients(img: torch.Tensor) -> Tuple[torch.Tensor, torch.Ten
return dy, dx
def image_gradients(img: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
def image_gradients(img: Tensor) -> Tuple[Tensor, Tensor]:
"""
Computes the `gradients <https://en.wikipedia.org/wiki/Image_gradient>`_ of a given image using finite difference

Просмотреть файл

@ -20,6 +20,7 @@ from collections import Counter
from typing import List, Sequence
import torch
from torch import Tensor, tensor
def _count_ngram(ngram_input_list: List[str], n_gram: int) -> Counter:
@ -45,11 +46,8 @@ def _count_ngram(ngram_input_list: List[str], n_gram: int) -> Counter:
def bleu_score(
translate_corpus: Sequence[str],
reference_corpus: Sequence[str],
n_gram: int = 4,
smooth: bool = False
) -> torch.Tensor:
translate_corpus: Sequence[str], reference_corpus: Sequence[str], n_gram: int = 4, smooth: bool = False
) -> Tensor:
"""
Calculate BLEU score of machine translated text with one or more references
@ -96,20 +94,20 @@ def bleu_score(
for counter in translation_counter:
denominator[len(counter) - 1] += translation_counter[counter]
trans_len = torch.tensor(c)
ref_len = torch.tensor(r)
trans_len = tensor(c)
ref_len = tensor(r)
if min(numerator) == 0.0:
return torch.tensor(0.0)
return tensor(0.0)
if smooth:
precision_scores = torch.add(numerator, torch.ones(n_gram)) / torch.add(denominator, torch.ones(n_gram))
else:
precision_scores = numerator / denominator
log_precision_scores = torch.tensor([1.0 / n_gram] * n_gram) * torch.log(precision_scores)
log_precision_scores = tensor([1.0 / n_gram] * n_gram) * torch.log(precision_scores)
geometric_mean = torch.exp(torch.sum(log_precision_scores))
brevity_penalty = torch.tensor(1.0) if c > r else torch.exp(1 - (ref_len / trans_len))
brevity_penalty = tensor(1.0) if c > r else torch.exp(1 - (ref_len / trans_len))
bleu = brevity_penalty * geometric_mean
return bleu

Просмотреть файл

@ -14,13 +14,12 @@
from typing import Sequence, Tuple, Union
import torch
from torch import Tensor
from torchmetrics.utilities.checks import _check_same_shape
def _explained_variance_update(
preds: torch.Tensor, target: torch.Tensor
) -> Tuple[int, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
def _explained_variance_update(preds: Tensor, target: Tensor) -> Tuple[int, Tensor, Tensor, Tensor, Tensor]:
_check_same_shape(preds, target)
n_obs = preds.size(0)
@ -34,13 +33,13 @@ def _explained_variance_update(
def _explained_variance_compute(
n_obs: torch.Tensor,
sum_error: torch.Tensor,
sum_squared_error: torch.Tensor,
sum_target: torch.Tensor,
sum_squared_target: torch.Tensor,
n_obs: Tensor,
sum_error: Tensor,
sum_squared_error: Tensor,
sum_target: Tensor,
sum_squared_target: Tensor,
multioutput: str = "uniform_average",
) -> Union[torch.Tensor, Sequence[torch.Tensor]]:
) -> Union[Tensor, Sequence[Tensor]]:
diff_avg = sum_error / n_obs
numerator = sum_squared_error / n_obs - diff_avg**2
@ -67,10 +66,10 @@ def _explained_variance_compute(
def explained_variance(
preds: torch.Tensor,
target: torch.Tensor,
preds: Tensor,
target: Tensor,
multioutput: str = "uniform_average",
) -> Union[torch.Tensor, Sequence[torch.Tensor]]:
) -> Union[Tensor, Sequence[Tensor]]:
"""
Computes explained variance.

Просмотреть файл

@ -14,22 +14,23 @@
from typing import Tuple
import torch
from torch import Tensor
from torchmetrics.utilities.checks import _check_same_shape
def _mean_absolute_error_update(preds: torch.Tensor, target: torch.Tensor) -> Tuple[torch.Tensor, int]:
def _mean_absolute_error_update(preds: Tensor, target: Tensor) -> Tuple[Tensor, int]:
_check_same_shape(preds, target)
sum_abs_error = torch.sum(torch.abs(preds - target))
n_obs = target.numel()
return sum_abs_error, n_obs
def _mean_absolute_error_compute(sum_abs_error: torch.Tensor, n_obs: int) -> torch.Tensor:
def _mean_absolute_error_compute(sum_abs_error: Tensor, n_obs: int) -> Tensor:
return sum_abs_error / n_obs
def mean_absolute_error(preds: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
def mean_absolute_error(preds: Tensor, target: Tensor) -> Tensor:
"""
Computes mean absolute error

Просмотреть файл

@ -14,11 +14,12 @@
from typing import Tuple
import torch
from torch import Tensor
from torchmetrics.utilities.checks import _check_same_shape
def _mean_relative_error_update(preds: torch.Tensor, target: torch.Tensor) -> Tuple[torch.Tensor, int]:
def _mean_relative_error_update(preds: Tensor, target: Tensor) -> Tuple[Tensor, int]:
_check_same_shape(preds, target)
target_nz = target.clone()
target_nz[target == 0] = 1
@ -27,11 +28,11 @@ def _mean_relative_error_update(preds: torch.Tensor, target: torch.Tensor) -> Tu
return sum_rltv_error, n_obs
def _mean_relative_error_compute(sum_rltv_error: torch.Tensor, n_obs: int) -> torch.Tensor:
def _mean_relative_error_compute(sum_rltv_error: Tensor, n_obs: int) -> Tensor:
return sum_rltv_error / n_obs
def mean_relative_error(preds: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
def mean_relative_error(preds: Tensor, target: Tensor) -> Tensor:
"""
Computes mean relative error

Просмотреть файл

@ -14,22 +14,23 @@
from typing import Tuple
import torch
from torch import Tensor
from torchmetrics.utilities.checks import _check_same_shape
def _mean_squared_error_update(preds: torch.Tensor, target: torch.Tensor) -> Tuple[torch.Tensor, int]:
def _mean_squared_error_update(preds: Tensor, target: Tensor) -> Tuple[Tensor, int]:
_check_same_shape(preds, target)
sum_squared_error = torch.sum(torch.pow(preds - target, 2))
n_obs = target.numel()
return sum_squared_error, n_obs
def _mean_squared_error_compute(sum_squared_error: torch.Tensor, n_obs: int) -> torch.Tensor:
def _mean_squared_error_compute(sum_squared_error: Tensor, n_obs: int) -> Tensor:
return sum_squared_error / n_obs
def mean_squared_error(preds: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
def mean_squared_error(preds: Tensor, target: Tensor) -> Tensor:
"""
Computes mean squared error

Просмотреть файл

@ -14,22 +14,23 @@
from typing import Tuple
import torch
from torch import Tensor
from torchmetrics.utilities.checks import _check_same_shape
def _mean_squared_log_error_update(preds: torch.Tensor, target: torch.Tensor) -> Tuple[torch.Tensor, int]:
def _mean_squared_log_error_update(preds: Tensor, target: Tensor) -> Tuple[Tensor, int]:
_check_same_shape(preds, target)
sum_squared_log_error = torch.sum(torch.pow(torch.log1p(preds) - torch.log1p(target), 2))
n_obs = target.numel()
return sum_squared_log_error, n_obs
def _mean_squared_log_error_compute(sum_squared_log_error: torch.Tensor, n_obs: int) -> torch.Tensor:
def _mean_squared_log_error_compute(sum_squared_log_error: Tensor, n_obs: int) -> Tensor:
return sum_squared_log_error / n_obs
def mean_squared_log_error(preds: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
def mean_squared_log_error(preds: Tensor, target: Tensor) -> Tensor:
"""
Computes mean squared log error

Просмотреть файл

@ -14,28 +14,31 @@
from typing import Optional, Tuple, Union
import torch
from torch import Tensor, tensor
from torchmetrics.utilities import rank_zero_warn, reduce
def _psnr_compute(
sum_squared_error: torch.Tensor,
n_obs: torch.Tensor,
data_range: torch.Tensor,
sum_squared_error: Tensor,
n_obs: Tensor,
data_range: Tensor,
base: float = 10.0,
reduction: str = 'elementwise_mean',
) -> torch.Tensor:
) -> Tensor:
psnr_base_e = 2 * torch.log(data_range) - torch.log(sum_squared_error / n_obs)
psnr = psnr_base_e * (10 / torch.log(torch.tensor(base)))
psnr = psnr_base_e * (10 / torch.log(tensor(base)))
return reduce(psnr, reduction=reduction)
def _psnr_update(preds: torch.Tensor,
target: torch.Tensor,
dim: Optional[Union[int, Tuple[int, ...]]] = None) -> Tuple[torch.Tensor, torch.Tensor]:
def _psnr_update(
preds: Tensor,
target: Tensor,
dim: Optional[Union[int, Tuple[int, ...]]] = None,
) -> Tuple[Tensor, Tensor]:
if dim is None:
sum_squared_error = torch.sum(torch.pow(preds - target, 2))
n_obs = torch.tensor(target.numel(), device=target.device)
n_obs = tensor(target.numel(), device=target.device)
return sum_squared_error, n_obs
sum_squared_error = torch.sum(torch.pow(preds - target, 2), dim=dim)
@ -45,22 +48,22 @@ def _psnr_update(preds: torch.Tensor,
else:
dim_list = list(dim)
if not dim_list:
n_obs = torch.tensor(target.numel(), device=target.device)
n_obs = tensor(target.numel(), device=target.device)
else:
n_obs = torch.tensor(target.size(), device=target.device)[dim_list].prod()
n_obs = tensor(target.size(), device=target.device)[dim_list].prod()
n_obs = n_obs.expand_as(sum_squared_error)
return sum_squared_error, n_obs
def psnr(
preds: torch.Tensor,
target: torch.Tensor,
preds: Tensor,
target: Tensor,
data_range: Optional[float] = None,
base: float = 10.0,
reduction: str = 'elementwise_mean',
dim: Optional[Union[int, Tuple[int, ...]]] = None,
) -> torch.Tensor:
) -> Tensor:
"""
Computes the peak signal-to-noise ratio
@ -102,6 +105,6 @@ def psnr(
data_range = target.max() - target.min()
else:
data_range = torch.tensor(float(data_range))
data_range = tensor(float(data_range))
sum_squared_error, n_obs = _psnr_update(preds, target, dim=dim)
return _psnr_compute(sum_squared_error, n_obs, data_range, base=base, reduction=reduction)

Просмотреть файл

@ -14,15 +14,13 @@
from typing import Tuple
import torch
from torch import Tensor
from torchmetrics.utilities import rank_zero_warn
from torchmetrics.utilities.checks import _check_same_shape
def _r2score_update(
preds: torch.tensor,
target: torch.Tensor,
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
def _r2score_update(preds: Tensor, target: Tensor) -> Tuple[Tensor, Tensor, Tensor, Tensor]:
_check_same_shape(preds, target)
if preds.ndim > 2:
raise ValueError(
@ -41,13 +39,13 @@ def _r2score_update(
def _r2score_compute(
sum_squared_error: torch.Tensor,
sum_error: torch.Tensor,
residual: torch.Tensor,
total: torch.Tensor,
sum_squared_error: Tensor,
sum_error: Tensor,
residual: Tensor,
total: Tensor,
adjusted: int = 0,
multioutput: str = "uniform_average"
) -> torch.Tensor:
multioutput: str = "uniform_average",
) -> Tensor:
mean_error = sum_error / total
diff = sum_squared_error - sum_error * mean_error
raw_scores = 1 - (residual / diff)
@ -82,11 +80,11 @@ def _r2score_compute(
def r2score(
preds: torch.Tensor,
target: torch.Tensor,
preds: Tensor,
target: Tensor,
adjusted: int = 0,
multioutput: str = "uniform_average",
) -> torch.Tensor:
) -> Tensor:
r"""
Computes r2 score also known as `coefficient of determination
<https://en.wikipedia.org/wiki/Coefficient_of_determination>`_:

Просмотреть файл

@ -14,6 +14,7 @@
from typing import Optional, Sequence, Tuple
import torch
from torch import Tensor
from torch.nn import functional as F
from torchmetrics.utilities.checks import _check_same_shape
@ -36,10 +37,7 @@ def _gaussian_kernel(
return kernel.expand(channel, 1, kernel_size[0], kernel_size[1])
def _ssim_update(
preds: torch.Tensor,
target: torch.Tensor,
) -> Tuple[torch.Tensor, torch.Tensor]:
def _ssim_update(preds: Tensor, target: Tensor) -> Tuple[Tensor, Tensor]:
if preds.dtype != target.dtype:
raise TypeError(
"Expected `preds` and `target` to have the same data type."
@ -55,8 +53,8 @@ def _ssim_update(
def _ssim_compute(
preds: torch.Tensor,
target: torch.Tensor,
preds: Tensor,
target: Tensor,
kernel_size: Sequence[int] = (11, 11),
sigma: Sequence[float] = (1.5, 1.5),
reduction: str = "elementwise_mean",
@ -114,15 +112,15 @@ def _ssim_compute(
def ssim(
preds: torch.Tensor,
target: torch.Tensor,
preds: Tensor,
target: Tensor,
kernel_size: Sequence[int] = (11, 11),
sigma: Sequence[float] = (1.5, 1.5),
reduction: str = "elementwise_mean",
data_range: Optional[float] = None,
k1: float = 0.01,
k2: float = 0.03,
) -> torch.Tensor:
) -> Tensor:
"""
Computes Structual Similarity Index Measure

Просмотреть файл

@ -12,14 +12,12 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import torch
from torch import Tensor
def embedding_similarity(
batch: torch.Tensor,
similarity: str = 'cosine',
reduction: str = 'none',
zero_diagonal: bool = True
) -> torch.Tensor:
batch: Tensor, similarity: str = 'cosine', reduction: str = 'none', zero_diagonal: bool = True
) -> Tensor:
"""
Computes representation similarity

Просмотреть файл

@ -19,7 +19,7 @@ from copy import deepcopy
from typing import Any, Callable, Optional, Union
import torch
from torch import nn
from torch import Tensor, nn
from torchmetrics.utilities import apply_to_collection
from torchmetrics.utilities.data import _flatten, dim_zero_cat, dim_zero_mean, dim_zero_sum
@ -122,7 +122,7 @@ class Metric(nn.Module, ABC):
"""
if (
not isinstance(default, torch.Tensor) and not isinstance(default, list) # noqa: W503
not isinstance(default, Tensor) and not isinstance(default, list) # noqa: W503
or (isinstance(default, list) and len(default) != 0) # noqa: W503
):
raise ValueError("state variable must be a tensor or any empty list (where you can append tensors)")
@ -175,14 +175,14 @@ class Metric(nn.Module, ABC):
input_dict = {attr: getattr(self, attr) for attr in self._reductions.keys()}
output_dict = apply_to_collection(
input_dict,
torch.Tensor,
Tensor,
dist_sync_fn,
group=self.process_group,
)
for attr, reduction_fn in self._reductions.items():
# pre-processing ops (stack or flatten for inputs)
if isinstance(output_dict[attr][0], torch.Tensor):
if isinstance(output_dict[attr][0], Tensor):
output_dict[attr] = torch.stack(output_dict[attr])
elif isinstance(output_dict[attr][0], list):
output_dict[attr] = _flatten(output_dict[attr])
@ -253,7 +253,7 @@ class Metric(nn.Module, ABC):
"""
for attr, default in self._defaults.items():
current_val = getattr(self, attr)
if isinstance(default, torch.Tensor):
if isinstance(default, Tensor):
setattr(self, attr, deepcopy(default).to(current_val.device))
else:
setattr(self, attr, deepcopy(default))
@ -280,14 +280,14 @@ class Metric(nn.Module, ABC):
# Also apply fn to metric states
for key in self._defaults.keys():
current_val = getattr(self, key)
if isinstance(current_val, torch.Tensor):
if isinstance(current_val, Tensor):
setattr(self, key, fn(current_val))
elif isinstance(current_val, Sequence):
setattr(self, key, [fn(cur_v) for cur_v in current_val])
else:
raise TypeError(
"Expected metric state to be either a torch.Tensor"
f"or a list of torch.Tensor, but encountered {current_val}"
"Expected metric state to be either a Tensor"
f"or a list of Tensor, but encountered {current_val}"
)
return self
@ -336,7 +336,7 @@ class Metric(nn.Module, ABC):
val = getattr(self, key)
# Special case: allow list values, so long
# as their elements are hashable
if hasattr(val, '__iter__') and not isinstance(val, torch.Tensor):
if hasattr(val, '__iter__') and not isinstance(val, Tensor):
hash_vals.extend(val)
else:
hash_vals.append(val)
@ -444,7 +444,7 @@ class Metric(nn.Module, ABC):
return CompositionalMetric(torch.abs, self, None)
def _neg(tensor: torch.Tensor):
def _neg(tensor: Tensor):
return -torch.abs(tensor)
@ -454,8 +454,8 @@ class CompositionalMetric(Metric):
def __init__(
self,
operator: Callable,
metric_a: Union[Metric, int, float, torch.Tensor],
metric_b: Union[Metric, int, float, torch.Tensor, None],
metric_a: Union[Metric, int, float, Tensor],
metric_b: Union[Metric, int, float, Tensor, None],
):
"""
Args:
@ -470,12 +470,12 @@ class CompositionalMetric(Metric):
self.op = operator
if isinstance(metric_a, torch.Tensor):
if isinstance(metric_a, Tensor):
self.register_buffer("metric_a", metric_a)
else:
self.metric_a = metric_a
if isinstance(metric_b, torch.Tensor):
if isinstance(metric_b, Tensor):
self.register_buffer("metric_b", metric_b)
else:
self.metric_b = metric_b

Просмотреть файл

@ -14,6 +14,7 @@
from typing import Any, Callable, Optional
import torch
from torch import Tensor, tensor
from torchmetrics.functional.regression.explained_variance import (
_explained_variance_compute,
@ -94,13 +95,13 @@ class ExplainedVariance(Metric):
f"Invalid input to argument `multioutput`. Choose one of the following: {allowed_multioutput}"
)
self.multioutput = multioutput
self.add_state("sum_error", default=torch.tensor(0.0), dist_reduce_fx="sum")
self.add_state("sum_squared_error", default=torch.tensor(0.0), dist_reduce_fx="sum")
self.add_state("sum_target", default=torch.tensor(0.0), dist_reduce_fx="sum")
self.add_state("sum_squared_target", default=torch.tensor(0.0), dist_reduce_fx="sum")
self.add_state("n_obs", default=torch.tensor(0.0), dist_reduce_fx="sum")
self.add_state("sum_error", default=tensor(0.0), dist_reduce_fx="sum")
self.add_state("sum_squared_error", default=tensor(0.0), dist_reduce_fx="sum")
self.add_state("sum_target", default=tensor(0.0), dist_reduce_fx="sum")
self.add_state("sum_squared_target", default=tensor(0.0), dist_reduce_fx="sum")
self.add_state("n_obs", default=tensor(0.0), dist_reduce_fx="sum")
def update(self, preds: torch.Tensor, target: torch.Tensor):
def update(self, preds: Tensor, target: Tensor):
"""
Update state with predictions and targets.

Просмотреть файл

@ -14,6 +14,7 @@
from typing import Any, Callable, Optional
import torch
from torch import Tensor, tensor
from torchmetrics.functional.regression.mean_absolute_error import (
_mean_absolute_error_compute,
@ -63,10 +64,10 @@ class MeanAbsoluteError(Metric):
dist_sync_fn=dist_sync_fn,
)
self.add_state("sum_abs_error", default=torch.tensor(0.0), dist_reduce_fx="sum")
self.add_state("total", default=torch.tensor(0), dist_reduce_fx="sum")
self.add_state("sum_abs_error", default=tensor(0.0), dist_reduce_fx="sum")
self.add_state("total", default=tensor(0), dist_reduce_fx="sum")
def update(self, preds: torch.Tensor, target: torch.Tensor):
def update(self, preds: Tensor, target: Tensor):
"""
Update state with predictions and targets.

Просмотреть файл

@ -14,6 +14,7 @@
from typing import Any, Callable, Optional
import torch
from torch import Tensor
from torchmetrics.functional.regression.mean_squared_error import (
_mean_squared_error_compute,
@ -67,7 +68,7 @@ class MeanSquaredError(Metric):
self.add_state("sum_squared_error", default=torch.tensor(0.0), dist_reduce_fx="sum")
self.add_state("total", default=torch.tensor(0), dist_reduce_fx="sum")
def update(self, preds: torch.Tensor, target: torch.Tensor):
def update(self, preds: Tensor, target: Tensor):
"""
Update state with predictions and targets.

Просмотреть файл

@ -14,6 +14,7 @@
from typing import Any, Callable, Optional
import torch
from torch import Tensor, tensor
from torchmetrics.functional.regression.mean_squared_log_error import (
_mean_squared_log_error_compute,
@ -66,10 +67,10 @@ class MeanSquaredLogError(Metric):
dist_sync_fn=dist_sync_fn,
)
self.add_state("sum_squared_log_error", default=torch.tensor(0.0), dist_reduce_fx="sum")
self.add_state("total", default=torch.tensor(0), dist_reduce_fx="sum")
self.add_state("sum_squared_log_error", default=tensor(0.0), dist_reduce_fx="sum")
self.add_state("total", default=tensor(0), dist_reduce_fx="sum")
def update(self, preds: torch.Tensor, target: torch.Tensor):
def update(self, preds: Tensor, target: Tensor):
"""
Update state with predictions and targets.

Просмотреть файл

@ -14,6 +14,7 @@
from typing import Any, Optional, Sequence, Tuple, Union
import torch
from torch import Tensor
from torchmetrics.functional.regression.psnr import _psnr_compute, _psnr_update
from torchmetrics.metric import Metric
@ -103,7 +104,7 @@ class PSNR(Metric):
self.reduction = reduction
self.dim = tuple(dim) if isinstance(dim, Sequence) else dim
def update(self, preds: torch.Tensor, target: torch.Tensor):
def update(self, preds: Tensor, target: Tensor):
"""
Update state with predictions and targets.

Просмотреть файл

@ -14,6 +14,7 @@
from typing import Any, Callable, Optional
import torch
from torch import Tensor, tensor
from torchmetrics.functional.regression.r2score import _r2score_compute, _r2score_update
from torchmetrics.metric import Metric
@ -115,9 +116,9 @@ class R2Score(Metric):
self.add_state("sum_squared_error", default=torch.zeros(self.num_outputs), dist_reduce_fx="sum")
self.add_state("sum_error", default=torch.zeros(self.num_outputs), dist_reduce_fx="sum")
self.add_state("residual", default=torch.zeros(self.num_outputs), dist_reduce_fx="sum")
self.add_state("total", default=torch.tensor(0), dist_reduce_fx="sum")
self.add_state("total", default=tensor(0), dist_reduce_fx="sum")
def update(self, preds: torch.Tensor, target: torch.Tensor):
def update(self, preds: Tensor, target: Tensor):
"""
Update state with predictions and targets.
@ -132,7 +133,7 @@ class R2Score(Metric):
self.residual += residual
self.total += total
def compute(self) -> torch.Tensor:
def compute(self) -> Tensor:
"""
Computes r2 score over the metric states.
"""

Просмотреть файл

@ -14,6 +14,7 @@
from typing import Any, Optional, Sequence
import torch
from torch import Tensor
from torchmetrics.functional.regression.ssim import _ssim_compute, _ssim_update
from torchmetrics.metric import Metric
@ -82,7 +83,7 @@ class SSIM(Metric):
self.k2 = k2
self.reduction = reduction
def update(self, preds: torch.Tensor, target: torch.Tensor):
def update(self, preds: Tensor, target: Tensor):
"""
Update state with predictions and targets.

Просмотреть файл

@ -14,18 +14,19 @@
from typing import Optional, Tuple
import torch
from torch import Tensor
from torchmetrics.utilities.data import select_topk, to_onehot
from torchmetrics.utilities.enums import DataType
def _check_same_shape(pred: torch.Tensor, target: torch.Tensor):
def _check_same_shape(pred: Tensor, target: Tensor):
""" Check that predictions and target have the same shape, else raise error """
if pred.shape != target.shape:
raise RuntimeError("Predictions and targets are expected to have the same shape")
def _basic_input_validation(preds: torch.Tensor, target: torch.Tensor, threshold: float, is_multiclass: bool):
def _basic_input_validation(preds: Tensor, target: Tensor, threshold: float, is_multiclass: bool):
"""
Perform basic validation of inputs that does not require deducing any information
of the type of inputs.
@ -56,7 +57,7 @@ def _basic_input_validation(preds: torch.Tensor, target: torch.Tensor, threshold
raise ValueError("If you set `is_multiclass=False` and `preds` are integers, then `preds` should not exceed 1.")
def _check_shape_and_type_consistency(preds: torch.Tensor, target: torch.Tensor) -> Tuple[str, int]:
def _check_shape_and_type_consistency(preds: Tensor, target: Tensor) -> Tuple[str, int]:
"""
This checks that the shape and type of inputs are consistent with
each other and fall into one of the allowed input types (see the
@ -139,9 +140,7 @@ def _check_num_classes_binary(num_classes: int, is_multiclass: bool):
)
def _check_num_classes_mc(
preds: torch.Tensor, target: torch.Tensor, num_classes: int, is_multiclass: bool, implied_classes: int
):
def _check_num_classes_mc(preds: Tensor, target: Tensor, num_classes: int, is_multiclass: bool, implied_classes: int):
"""
This checks that the consistency of `num_classes` with the data
and `is_multiclass` param for (multi-dimensional) multi-class data.
@ -206,8 +205,8 @@ def _check_top_k(top_k: int, case: str, implied_classes: int, is_multiclass: Opt
def _check_classification_inputs(
preds: torch.Tensor,
target: torch.Tensor,
preds: Tensor,
target: Tensor,
threshold: float,
num_classes: Optional[int],
is_multiclass: bool,
@ -305,13 +304,13 @@ def _check_classification_inputs(
def _input_format_classification(
preds: torch.Tensor,
target: torch.Tensor,
preds: Tensor,
target: Tensor,
threshold: float = 0.5,
top_k: Optional[int] = None,
num_classes: Optional[int] = None,
is_multiclass: Optional[bool] = None,
) -> Tuple[torch.Tensor, torch.Tensor, str]:
) -> Tuple[Tensor, Tensor, str]:
"""Convert preds and target tensors into common format.
Preds and targets are supposed to fall into one of these categories (and are
@ -448,11 +447,11 @@ def _input_format_classification(
def _input_format_classification_one_hot(
num_classes: int,
preds: torch.Tensor,
target: torch.Tensor,
preds: Tensor,
target: Tensor,
threshold: float = 0.5,
multilabel: bool = False,
) -> Tuple[torch.Tensor, torch.Tensor]:
) -> Tuple[Tensor, Tensor]:
"""Convert preds and target tensors into one hot spare label tensors
Args:

Просмотреть файл

@ -14,6 +14,7 @@
from typing import Any, Callable, Mapping, Optional, Sequence, Union
import torch
from torch import Tensor
from torchmetrics.utilities.prints import rank_zero_warn
@ -38,9 +39,9 @@ def _flatten(x):
def to_onehot(
label_tensor: torch.Tensor,
label_tensor: Tensor,
num_classes: Optional[int] = None,
) -> torch.Tensor:
) -> Tensor:
"""
Converts a dense label tensor to one-hot format
@ -74,7 +75,7 @@ def to_onehot(
return tensor_onehot.scatter_(1, index, 1.0)
def select_topk(prob_tensor: torch.Tensor, topk: int = 1, dim: int = 1) -> torch.Tensor:
def select_topk(prob_tensor: Tensor, topk: int = 1, dim: int = 1) -> Tensor:
"""
Convert a probability tensor to binary by selecting top-k highest entries.
@ -99,7 +100,7 @@ def select_topk(prob_tensor: torch.Tensor, topk: int = 1, dim: int = 1) -> torch
return topk_tensor.int()
def to_categorical(tensor: torch.Tensor, argmax_dim: int = 1) -> torch.Tensor:
def to_categorical(tensor: Tensor, argmax_dim: int = 1) -> Tensor:
"""
Converts a tensor of probabilities to a dense label tensor
@ -121,8 +122,8 @@ def to_categorical(tensor: torch.Tensor, argmax_dim: int = 1) -> torch.Tensor:
def get_num_classes(
pred: torch.Tensor,
target: torch.Tensor,
pred: Tensor,
target: Tensor,
num_classes: Optional[int] = None,
) -> int:
"""
@ -202,7 +203,7 @@ def apply_to_collection(
the resulting collection
Example:
>>> apply_to_collection(torch.tensor([8, 0, 2, 6, 7]), dtype=torch.Tensor, function=lambda x: x ** 2)
>>> apply_to_collection(torch.tensor([8, 0, 2, 6, 7]), dtype=Tensor, function=lambda x: x ** 2)
tensor([64, 0, 4, 36, 49])
>>> apply_to_collection([8, 0, 2, 6, 7], dtype=int, function=lambda x: x ** 2)
[64, 0, 4, 36, 49]

Просмотреть файл

@ -14,9 +14,10 @@
from typing import Any, Optional, Union
import torch
from torch import Tensor
def reduce(to_reduce: torch.Tensor, reduction: str) -> torch.Tensor:
def reduce(to_reduce: Tensor, reduction: str) -> Tensor:
"""
Reduces a given tensor by a given reduction method
@ -39,9 +40,7 @@ def reduce(to_reduce: torch.Tensor, reduction: str) -> torch.Tensor:
raise ValueError("Reduction parameter unknown.")
def class_reduce(
num: torch.Tensor, denom: torch.Tensor, weights: torch.Tensor, class_reduction: str = "none"
) -> torch.Tensor:
def class_reduce(num: Tensor, denom: Tensor, weights: Tensor, class_reduction: str = "none") -> Tensor:
"""
Function used to reduce classification metrics of the form `num / denom * weights`.
For example for calculating standard accuracy the num would be number of
@ -85,7 +84,7 @@ def class_reduce(
)
def gather_all_tensors(result: Union[torch.Tensor], group: Optional[Any] = None):
def gather_all_tensors(result: Union[Tensor], group: Optional[Any] = None):
"""
Function to gather all tensors from several ddp processes onto a list that
is broadcasted to all processes

Просмотреть файл

@ -0,0 +1,7 @@
from distutils.version import LooseVersion
import torch
_TORCH_LOWER_1_4 = LooseVersion(torch.__version__) < LooseVersion("1.4.0")
_TORCH_LOWER_1_5 = LooseVersion(torch.__version__) < LooseVersion("1.5.0")
_TORCH_LOWER_1_6 = LooseVersion(torch.__version__) < LooseVersion("1.6.0")