Skip to content

Commit

Permalink
Modified tests of nlinalg in pytorch implementation
Browse files Browse the repository at this point in the history
Replaced instances using Blockwise by the Op constructor.
  • Loading branch information
twaclaw committed Jul 12, 2024
1 parent 4e1391c commit 2854c37
Show file tree
Hide file tree
Showing 2 changed files with 11 additions and 20 deletions.
4 changes: 3 additions & 1 deletion pytensor/link/pytorch/dispatch/nlinalg.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,9 @@ def pytorch_funcify_SVD(op, **kwargs):

def svd(x):
U, S, V = torch.linalg.svd(x, full_matrices=full_matrices)

Check warning on line 23 in pytensor/link/pytorch/dispatch/nlinalg.py

View check run for this annotation

Codecov / codecov/patch

pytensor/link/pytorch/dispatch/nlinalg.py#L23

Added line #L23 was not covered by tests
return U, S, V if compute_uv else S
if compute_uv:
return U, S, V
return S

Check warning on line 26 in pytensor/link/pytorch/dispatch/nlinalg.py

View check run for this annotation

Codecov / codecov/patch

pytensor/link/pytorch/dispatch/nlinalg.py#L25-L26

Added lines #L25 - L26 were not covered by tests

return svd

Expand Down
27 changes: 8 additions & 19 deletions tests/link/pytorch/test_nlinalg.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,23 +22,13 @@ def matrix_test():

@pytest.mark.parametrize(
"func",
(
pt_nla.eig,
pt_nla.eigh,
pt_nla.slogdet,
pytest.param(
pt_nla.inv, marks=pytest.mark.xfail(reason="Blockwise not implemented")
),
pytest.param(
pt_nla.det, marks=pytest.mark.xfail(reason="Blockwise not implemented")
),
),
(pt_nla.eig, pt_nla.eigh, pt_nla.slogdet, pt_nla.MatrixInverse(), pt_nla.Det()),
)
def test_lin_alg_no_params(func, matrix_test):
x, test_value = matrix_test

outs = func(x)
out_fg = FunctionGraph([x], outs)
out = func(x)
out_fg = FunctionGraph([x], out if isinstance(out, list) else [out])

def assert_fn(x, y):
np.testing.assert_allclose(x, y, rtol=1e-3)
Expand All @@ -58,18 +48,17 @@ def assert_fn(x, y):
def test_qr(mode, matrix_test):
x, test_value = matrix_test
outs = pt_nla.qr(x, mode=mode)
out_fg = FunctionGraph([x], [outs] if mode == "r" else outs)
out_fg = FunctionGraph([x], outs if isinstance(outs, list) else [outs])
compare_pytorch_and_py(out_fg, [test_value])


@pytest.mark.xfail(reason="Blockwise not implemented")
@pytest.mark.parametrize("compute_uv", [False, True])
@pytest.mark.parametrize("full_matrices", [False, True])
@pytest.mark.parametrize("compute_uv", [True, False])
@pytest.mark.parametrize("full_matrices", [True, False])
def test_svd(compute_uv, full_matrices, matrix_test):
x, test_value = matrix_test

outs = pt_nla.svd(x, full_matrices=full_matrices, compute_uv=compute_uv)
out_fg = FunctionGraph([x], outs)
out = pt_nla.SVD(full_matrices=full_matrices, compute_uv=compute_uv)(x)
out_fg = FunctionGraph([x], out if isinstance(out, list) else [out])

def assert_fn(x, y):
np.testing.assert_allclose(x, y, rtol=1e-3)
Expand Down

0 comments on commit 2854c37

Please sign in to comment.