Skip to content

Commit

Permalink
Merge branch 'release/1.2-dev' into refactor/legacy-accel-plug
Browse files Browse the repository at this point in the history
  • Loading branch information
mergify[bot] authored Jan 26, 2021
2 parents 9edf083 + 86d905c commit 34cceb7
Showing 1 changed file with 22 additions and 6 deletions.
28 changes: 22 additions & 6 deletions tests/metrics/test_composition.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
from distutils.version import LooseVersion
from operator import neg, pos

import pytest
Expand All @@ -6,6 +7,11 @@
from pytorch_lightning.metrics.compositional import CompositionalMetric
from pytorch_lightning.metrics.metric import Metric

_MARK_TORCH_LOWER_1_4 = dict(condition=LooseVersion(torch.__version__) < LooseVersion("1.5.0"),
reason='required PT >= 1.5')
_MARK_TORCH_LOWER_1_5 = dict(condition=LooseVersion(torch.__version__) < LooseVersion("1.6.0"),
reason='required PT >= 1.6')


class DummyMetric(Metric):
def __init__(self, val_to_return):
Expand Down Expand Up @@ -50,6 +56,7 @@ def test_metrics_add(second_operand, expected_result):
["second_operand", "expected_result"],
[(DummyMetric(3), torch.tensor(2)), (3, torch.tensor(2)), (3, torch.tensor(2)), (torch.tensor(3), torch.tensor(2))],
)
@pytest.mark.skipif(**_MARK_TORCH_LOWER_1_4)
def test_metrics_and(second_operand, expected_result):
first_metric = DummyMetric(2)

Expand Down Expand Up @@ -92,6 +99,7 @@ def test_metrics_eq(second_operand, expected_result):
(torch.tensor(2), torch.tensor(2)),
],
)
@pytest.mark.skipif(**_MARK_TORCH_LOWER_1_4)
def test_metrics_floordiv(second_operand, expected_result):
first_metric = DummyMetric(5)

Expand Down Expand Up @@ -261,6 +269,7 @@ def test_metrics_ne(second_operand, expected_result):
["second_operand", "expected_result"],
[(DummyMetric([1, 0, 3]), torch.tensor([-1, -2, 3])), (torch.tensor([1, 0, 3]), torch.tensor([-1, -2, 3]))],
)
@pytest.mark.skipif(**_MARK_TORCH_LOWER_1_4)
def test_metrics_or(second_operand, expected_result):
first_metric = DummyMetric([-1, -2, 3])

Expand All @@ -277,10 +286,10 @@ def test_metrics_or(second_operand, expected_result):
@pytest.mark.parametrize(
["second_operand", "expected_result"],
[
(DummyMetric(2), torch.tensor(4)),
(2, torch.tensor(4)),
(2.0, torch.tensor(4.0)),
(torch.tensor(2), torch.tensor(4)),
pytest.param(DummyMetric(2), torch.tensor(4)),
pytest.param(2, torch.tensor(4)),
pytest.param(2.0, torch.tensor(4.0), marks=pytest.mark.skipif(**_MARK_TORCH_LOWER_1_5)),
pytest.param(torch.tensor(2), torch.tensor(4)),
],
)
def test_metrics_pow(second_operand, expected_result):
Expand All @@ -297,6 +306,7 @@ def test_metrics_pow(second_operand, expected_result):
["first_operand", "expected_result"],
[(5, torch.tensor(2)), (5.0, torch.tensor(2.0)), (torch.tensor(5), torch.tensor(2))],
)
@pytest.mark.skipif(**_MARK_TORCH_LOWER_1_4)
def test_metrics_rfloordiv(first_operand, expected_result):
second_operand = DummyMetric(2)

Expand Down Expand Up @@ -329,8 +339,12 @@ def test_metrics_rmod(first_operand, expected_result):


@pytest.mark.parametrize(
["first_operand", "expected_result"],
[(DummyMetric(2), torch.tensor(4)), (2, torch.tensor(4)), (2.0, torch.tensor(4.0))],
"first_operand,expected_result",
[
pytest.param(DummyMetric(2), torch.tensor(4)),
pytest.param(2, torch.tensor(4)),
pytest.param(2.0, torch.tensor(4.0), marks=pytest.mark.skipif(**_MARK_TORCH_LOWER_1_5)),
],
)
def test_metrics_rpow(first_operand, expected_result):
second_operand = DummyMetric(2)
Expand Down Expand Up @@ -370,6 +384,7 @@ def test_metrics_rsub(first_operand, expected_result):
(torch.tensor(6), torch.tensor(2.0)),
],
)
@pytest.mark.skipif(**_MARK_TORCH_LOWER_1_4)
def test_metrics_rtruediv(first_operand, expected_result):
second_operand = DummyMetric(3)

Expand Down Expand Up @@ -408,6 +423,7 @@ def test_metrics_sub(second_operand, expected_result):
(torch.tensor(3), torch.tensor(2.0)),
],
)
@pytest.mark.skipif(**_MARK_TORCH_LOWER_1_4)
def test_metrics_truediv(second_operand, expected_result):
first_metric = DummyMetric(6)

Expand Down

0 comments on commit 34cceb7

Please sign in to comment.