hotfix: skip unsupported metric compostions (PL^5664)

* skip PT unsupported compositions

* flake8

(cherry picked from commit 803bd25cce5998dc6cbf66befeb204aaf6819bd9)
This commit is contained in:
Jirka Borovec 2021-01-26 20:24:15 +01:00 коммит произвёл Jirka Borovec
Родитель ff01a0b9f2
Коммит 7904df201b
1 изменённых файлов: 22 добавлений и 6 удалений

Просмотреть файл

@ -1,3 +1,4 @@
from distutils.version import LooseVersion
from operator import neg, pos
import pytest
@ -6,6 +7,11 @@ import torch
from pytorch_lightning.metrics.compositional import CompositionalMetric
from pytorch_lightning.metrics.metric import Metric
_MARK_TORCH_LOWER_1_4 = dict(condition=LooseVersion(torch.__version__) < LooseVersion("1.5.0"),
reason='required PT >= 1.5')
_MARK_TORCH_LOWER_1_5 = dict(condition=LooseVersion(torch.__version__) < LooseVersion("1.6.0"),
reason='required PT >= 1.6')
class DummyMetric(Metric):
def __init__(self, val_to_return):
@ -50,6 +56,7 @@ def test_metrics_add(second_operand, expected_result):
["second_operand", "expected_result"],
[(DummyMetric(3), torch.tensor(2)), (3, torch.tensor(2)), (3, torch.tensor(2)), (torch.tensor(3), torch.tensor(2))],
)
@pytest.mark.skipif(**_MARK_TORCH_LOWER_1_4)
def test_metrics_and(second_operand, expected_result):
first_metric = DummyMetric(2)
@ -92,6 +99,7 @@ def test_metrics_eq(second_operand, expected_result):
(torch.tensor(2), torch.tensor(2)),
],
)
@pytest.mark.skipif(**_MARK_TORCH_LOWER_1_4)
def test_metrics_floordiv(second_operand, expected_result):
first_metric = DummyMetric(5)
@ -261,6 +269,7 @@ def test_metrics_ne(second_operand, expected_result):
["second_operand", "expected_result"],
[(DummyMetric([1, 0, 3]), torch.tensor([-1, -2, 3])), (torch.tensor([1, 0, 3]), torch.tensor([-1, -2, 3]))],
)
@pytest.mark.skipif(**_MARK_TORCH_LOWER_1_4)
def test_metrics_or(second_operand, expected_result):
first_metric = DummyMetric([-1, -2, 3])
@ -277,10 +286,10 @@ def test_metrics_or(second_operand, expected_result):
@pytest.mark.parametrize(
["second_operand", "expected_result"],
[
(DummyMetric(2), torch.tensor(4)),
(2, torch.tensor(4)),
(2.0, torch.tensor(4.0)),
(torch.tensor(2), torch.tensor(4)),
pytest.param(DummyMetric(2), torch.tensor(4)),
pytest.param(2, torch.tensor(4)),
pytest.param(2.0, torch.tensor(4.0), marks=pytest.mark.skipif(**_MARK_TORCH_LOWER_1_5)),
pytest.param(torch.tensor(2), torch.tensor(4)),
],
)
def test_metrics_pow(second_operand, expected_result):
@ -297,6 +306,7 @@ def test_metrics_pow(second_operand, expected_result):
["first_operand", "expected_result"],
[(5, torch.tensor(2)), (5.0, torch.tensor(2.0)), (torch.tensor(5), torch.tensor(2))],
)
@pytest.mark.skipif(**_MARK_TORCH_LOWER_1_4)
def test_metrics_rfloordiv(first_operand, expected_result):
second_operand = DummyMetric(2)
@ -329,8 +339,12 @@ def test_metrics_rmod(first_operand, expected_result):
@pytest.mark.parametrize(
["first_operand", "expected_result"],
[(DummyMetric(2), torch.tensor(4)), (2, torch.tensor(4)), (2.0, torch.tensor(4.0))],
"first_operand,expected_result",
[
pytest.param(DummyMetric(2), torch.tensor(4)),
pytest.param(2, torch.tensor(4)),
pytest.param(2.0, torch.tensor(4.0), marks=pytest.mark.skipif(**_MARK_TORCH_LOWER_1_5)),
],
)
def test_metrics_rpow(first_operand, expected_result):
second_operand = DummyMetric(2)
@ -370,6 +384,7 @@ def test_metrics_rsub(first_operand, expected_result):
(torch.tensor(6), torch.tensor(2.0)),
],
)
@pytest.mark.skipif(**_MARK_TORCH_LOWER_1_4)
def test_metrics_rtruediv(first_operand, expected_result):
second_operand = DummyMetric(3)
@ -408,6 +423,7 @@ def test_metrics_sub(second_operand, expected_result):
(torch.tensor(3), torch.tensor(2.0)),
],
)
@pytest.mark.skipif(**_MARK_TORCH_LOWER_1_4)
def test_metrics_truediv(second_operand, expected_result):
first_metric = DummyMetric(6)