diff --git a/tests/metrics/test_composition.py b/tests/metrics/test_composition.py index 087dee521d6941..a9bba7d7fac7da 100644 --- a/tests/metrics/test_composition.py +++ b/tests/metrics/test_composition.py @@ -1,3 +1,4 @@ +from distutils.version import LooseVersion from operator import neg, pos import pytest @@ -6,6 +7,11 @@ from pytorch_lightning.metrics.compositional import CompositionalMetric from pytorch_lightning.metrics.metric import Metric +_MARK_TORCH_LOWER_1_4 = dict(condition=LooseVersion(torch.__version__) < LooseVersion("1.5.0"), + reason='required PT >= 1.5') +_MARK_TORCH_LOWER_1_5 = dict(condition=LooseVersion(torch.__version__) < LooseVersion("1.6.0"), + reason='required PT >= 1.6') + class DummyMetric(Metric): def __init__(self, val_to_return): @@ -50,6 +56,7 @@ def test_metrics_add(second_operand, expected_result): ["second_operand", "expected_result"], [(DummyMetric(3), torch.tensor(2)), (3, torch.tensor(2)), (3, torch.tensor(2)), (torch.tensor(3), torch.tensor(2))], ) +@pytest.mark.skipif(**_MARK_TORCH_LOWER_1_4) def test_metrics_and(second_operand, expected_result): first_metric = DummyMetric(2) @@ -92,6 +99,7 @@ def test_metrics_eq(second_operand, expected_result): (torch.tensor(2), torch.tensor(2)), ], ) +@pytest.mark.skipif(**_MARK_TORCH_LOWER_1_4) def test_metrics_floordiv(second_operand, expected_result): first_metric = DummyMetric(5) @@ -261,6 +269,7 @@ def test_metrics_ne(second_operand, expected_result): ["second_operand", "expected_result"], [(DummyMetric([1, 0, 3]), torch.tensor([-1, -2, 3])), (torch.tensor([1, 0, 3]), torch.tensor([-1, -2, 3]))], ) +@pytest.mark.skipif(**_MARK_TORCH_LOWER_1_4) def test_metrics_or(second_operand, expected_result): first_metric = DummyMetric([-1, -2, 3]) @@ -277,10 +286,10 @@ def test_metrics_or(second_operand, expected_result): @pytest.mark.parametrize( ["second_operand", "expected_result"], [ - (DummyMetric(2), torch.tensor(4)), - (2, torch.tensor(4)), - (2.0, torch.tensor(4.0)), - (torch.tensor(2), torch.tensor(4)), + pytest.param(DummyMetric(2), torch.tensor(4)), + pytest.param(2, torch.tensor(4)), + pytest.param(2.0, torch.tensor(4.0), marks=pytest.mark.skipif(**_MARK_TORCH_LOWER_1_5)), + pytest.param(torch.tensor(2), torch.tensor(4)), ], ) def test_metrics_pow(second_operand, expected_result): @@ -297,6 +306,7 @@ def test_metrics_pow(second_operand, expected_result): ["first_operand", "expected_result"], [(5, torch.tensor(2)), (5.0, torch.tensor(2.0)), (torch.tensor(5), torch.tensor(2))], ) +@pytest.mark.skipif(**_MARK_TORCH_LOWER_1_4) def test_metrics_rfloordiv(first_operand, expected_result): second_operand = DummyMetric(2) @@ -329,8 +339,12 @@ def test_metrics_rmod(first_operand, expected_result): @pytest.mark.parametrize( - ["first_operand", "expected_result"], - [(DummyMetric(2), torch.tensor(4)), (2, torch.tensor(4)), (2.0, torch.tensor(4.0))], + "first_operand,expected_result", + [ + pytest.param(DummyMetric(2), torch.tensor(4)), + pytest.param(2, torch.tensor(4)), + pytest.param(2.0, torch.tensor(4.0), marks=pytest.mark.skipif(**_MARK_TORCH_LOWER_1_5)), + ], ) def test_metrics_rpow(first_operand, expected_result): second_operand = DummyMetric(2) @@ -370,6 +384,7 @@ def test_metrics_rsub(first_operand, expected_result): (torch.tensor(6), torch.tensor(2.0)), ], ) +@pytest.mark.skipif(**_MARK_TORCH_LOWER_1_4) def test_metrics_rtruediv(first_operand, expected_result): second_operand = DummyMetric(3) @@ -408,6 +423,7 @@ def test_metrics_sub(second_operand, expected_result): (torch.tensor(3), torch.tensor(2.0)), ], ) +@pytest.mark.skipif(**_MARK_TORCH_LOWER_1_4) def test_metrics_truediv(second_operand, expected_result): first_metric = DummyMetric(6)