Skip to content

Commit

Permalink
Skip functorch tests for PyTorch 1.x
Browse files Browse the repository at this point in the history
  • Loading branch information
marvinfriede committed Jan 5, 2025
1 parent e8142ab commit 05a45cb
Showing 1 changed file with 4 additions and 0 deletions.
4 changes: 4 additions & 0 deletions test/test_model/test_weights.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@
import pytest
import torch
import torch.nn.functional as F
from tad_mctc._version import __tversion__
from tad_mctc.autograd import jacrev
from tad_mctc.batch import pack
from tad_mctc.ncoord import cn_d4
Expand Down Expand Up @@ -152,6 +153,7 @@ def test_batch(name1: str, name2: str, dtype: torch.dtype) -> None:
assert pytest.approx(gwvec.cpu(), abs=tol) == ref.cpu()


@pytest.mark.skipif(__tversion__ < (2, 0, 0), reason="Requires torch>=2.0.0")
@pytest.mark.parametrize("name", ["LiH", "SiH4", "MB16_43_03"])
def test_grad_q(name: str) -> None:
dd: DD = {"device": DEVICE, "dtype": torch.float64}
Expand All @@ -177,6 +179,7 @@ def test_grad_q(name: str) -> None:
assert pytest.approx(dgwdq_auto.cpu(), abs=1e-6) == dgwdq_ana.cpu()


@pytest.mark.skipif(__tversion__ < (2, 0, 0), reason="Requires torch>=2.0.0")
@pytest.mark.parametrize("name", ["LiH", "SiH4", "MB16_43_03"])
def test_grad_cn(name: str) -> None:
dd: DD = {"device": DEVICE, "dtype": torch.float64}
Expand All @@ -202,6 +205,7 @@ def test_grad_cn(name: str) -> None:
assert pytest.approx(dgwdcn_auto.cpu(), abs=1e-6) == -dgwdq_ana.cpu()


@pytest.mark.skipif(__tversion__ < (2, 0, 0), reason="Requires torch>=2.0.0")
@pytest.mark.parametrize("name", ["LiH", "SiH4", "MB16_43_03"])
def test_grad_both(name: str) -> None:
dd: DD = {"device": DEVICE, "dtype": torch.float64}
Expand Down

0 comments on commit 05a45cb

Please sign in to comment.