hotfix: skip unsupported metric compostions (PL^5664)
* skip PT unsupported compositions * flake8 (cherry picked from commit 803bd25cce5998dc6cbf66befeb204aaf6819bd9)
This commit is contained in:
Родитель
ff01a0b9f2
Коммит
7904df201b
|
@ -1,3 +1,4 @@
|
|||
from distutils.version import LooseVersion
|
||||
from operator import neg, pos
|
||||
|
||||
import pytest
|
||||
|
@ -6,6 +7,11 @@ import torch
|
|||
from pytorch_lightning.metrics.compositional import CompositionalMetric
|
||||
from pytorch_lightning.metrics.metric import Metric
|
||||
|
||||
_MARK_TORCH_LOWER_1_4 = dict(condition=LooseVersion(torch.__version__) < LooseVersion("1.5.0"),
|
||||
reason='required PT >= 1.5')
|
||||
_MARK_TORCH_LOWER_1_5 = dict(condition=LooseVersion(torch.__version__) < LooseVersion("1.6.0"),
|
||||
reason='required PT >= 1.6')
|
||||
|
||||
|
||||
class DummyMetric(Metric):
|
||||
def __init__(self, val_to_return):
|
||||
|
@ -50,6 +56,7 @@ def test_metrics_add(second_operand, expected_result):
|
|||
["second_operand", "expected_result"],
|
||||
[(DummyMetric(3), torch.tensor(2)), (3, torch.tensor(2)), (3, torch.tensor(2)), (torch.tensor(3), torch.tensor(2))],
|
||||
)
|
||||
@pytest.mark.skipif(**_MARK_TORCH_LOWER_1_4)
|
||||
def test_metrics_and(second_operand, expected_result):
|
||||
first_metric = DummyMetric(2)
|
||||
|
||||
|
@ -92,6 +99,7 @@ def test_metrics_eq(second_operand, expected_result):
|
|||
(torch.tensor(2), torch.tensor(2)),
|
||||
],
|
||||
)
|
||||
@pytest.mark.skipif(**_MARK_TORCH_LOWER_1_4)
|
||||
def test_metrics_floordiv(second_operand, expected_result):
|
||||
first_metric = DummyMetric(5)
|
||||
|
||||
|
@ -261,6 +269,7 @@ def test_metrics_ne(second_operand, expected_result):
|
|||
["second_operand", "expected_result"],
|
||||
[(DummyMetric([1, 0, 3]), torch.tensor([-1, -2, 3])), (torch.tensor([1, 0, 3]), torch.tensor([-1, -2, 3]))],
|
||||
)
|
||||
@pytest.mark.skipif(**_MARK_TORCH_LOWER_1_4)
|
||||
def test_metrics_or(second_operand, expected_result):
|
||||
first_metric = DummyMetric([-1, -2, 3])
|
||||
|
||||
|
@ -277,10 +286,10 @@ def test_metrics_or(second_operand, expected_result):
|
|||
@pytest.mark.parametrize(
|
||||
["second_operand", "expected_result"],
|
||||
[
|
||||
(DummyMetric(2), torch.tensor(4)),
|
||||
(2, torch.tensor(4)),
|
||||
(2.0, torch.tensor(4.0)),
|
||||
(torch.tensor(2), torch.tensor(4)),
|
||||
pytest.param(DummyMetric(2), torch.tensor(4)),
|
||||
pytest.param(2, torch.tensor(4)),
|
||||
pytest.param(2.0, torch.tensor(4.0), marks=pytest.mark.skipif(**_MARK_TORCH_LOWER_1_5)),
|
||||
pytest.param(torch.tensor(2), torch.tensor(4)),
|
||||
],
|
||||
)
|
||||
def test_metrics_pow(second_operand, expected_result):
|
||||
|
@ -297,6 +306,7 @@ def test_metrics_pow(second_operand, expected_result):
|
|||
["first_operand", "expected_result"],
|
||||
[(5, torch.tensor(2)), (5.0, torch.tensor(2.0)), (torch.tensor(5), torch.tensor(2))],
|
||||
)
|
||||
@pytest.mark.skipif(**_MARK_TORCH_LOWER_1_4)
|
||||
def test_metrics_rfloordiv(first_operand, expected_result):
|
||||
second_operand = DummyMetric(2)
|
||||
|
||||
|
@ -329,8 +339,12 @@ def test_metrics_rmod(first_operand, expected_result):
|
|||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
["first_operand", "expected_result"],
|
||||
[(DummyMetric(2), torch.tensor(4)), (2, torch.tensor(4)), (2.0, torch.tensor(4.0))],
|
||||
"first_operand,expected_result",
|
||||
[
|
||||
pytest.param(DummyMetric(2), torch.tensor(4)),
|
||||
pytest.param(2, torch.tensor(4)),
|
||||
pytest.param(2.0, torch.tensor(4.0), marks=pytest.mark.skipif(**_MARK_TORCH_LOWER_1_5)),
|
||||
],
|
||||
)
|
||||
def test_metrics_rpow(first_operand, expected_result):
|
||||
second_operand = DummyMetric(2)
|
||||
|
@ -370,6 +384,7 @@ def test_metrics_rsub(first_operand, expected_result):
|
|||
(torch.tensor(6), torch.tensor(2.0)),
|
||||
],
|
||||
)
|
||||
@pytest.mark.skipif(**_MARK_TORCH_LOWER_1_4)
|
||||
def test_metrics_rtruediv(first_operand, expected_result):
|
||||
second_operand = DummyMetric(3)
|
||||
|
||||
|
@ -408,6 +423,7 @@ def test_metrics_sub(second_operand, expected_result):
|
|||
(torch.tensor(3), torch.tensor(2.0)),
|
||||
],
|
||||
)
|
||||
@pytest.mark.skipif(**_MARK_TORCH_LOWER_1_4)
|
||||
def test_metrics_truediv(second_operand, expected_result):
|
||||
first_metric = DummyMetric(6)
|
||||
|
||||
|
|
Загрузка…
Ссылка в новой задаче