-
Notifications
You must be signed in to change notification settings - Fork 119
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Implement nlinalg Ops in PyTorch (#920)
- Loading branch information
Showing
3 changed files
with
215 additions
and
1 deletion.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,103 @@ | ||
import torch | ||
|
||
from pytensor.link.pytorch.dispatch import pytorch_funcify | ||
from pytensor.tensor.nlinalg import ( | ||
SVD, | ||
Det, | ||
Eig, | ||
Eigh, | ||
KroneckerProduct, | ||
MatrixInverse, | ||
MatrixPinv, | ||
QRFull, | ||
SLogDet, | ||
) | ||
|
||
|
||
@pytorch_funcify.register(SVD) | ||
def pytorch_funcify_SVD(op, **kwargs): | ||
full_matrices = op.full_matrices | ||
compute_uv = op.compute_uv | ||
|
||
def svd(x): | ||
U, S, V = torch.linalg.svd(x, full_matrices=full_matrices) | ||
if compute_uv: | ||
return U, S, V | ||
return S | ||
|
||
return svd | ||
|
||
|
||
@pytorch_funcify.register(Det) | ||
def pytorch_funcify_Det(op, **kwargs): | ||
def det(x): | ||
return torch.linalg.det(x) | ||
|
||
return det | ||
|
||
|
||
@pytorch_funcify.register(SLogDet) | ||
def pytorch_funcify_SLogDet(op, **kwargs): | ||
def slogdet(x): | ||
return torch.linalg.slogdet(x) | ||
|
||
return slogdet | ||
|
||
|
||
@pytorch_funcify.register(Eig) | ||
def pytorch_funcify_Eig(op, **kwargs): | ||
def eig(x): | ||
return torch.linalg.eig(x) | ||
|
||
return eig | ||
|
||
|
||
@pytorch_funcify.register(Eigh) | ||
def pytorch_funcify_Eigh(op, **kwargs): | ||
uplo = op.UPLO | ||
|
||
def eigh(x, uplo=uplo): | ||
return torch.linalg.eigh(x, UPLO=uplo) | ||
|
||
return eigh | ||
|
||
|
||
@pytorch_funcify.register(MatrixInverse) | ||
def pytorch_funcify_MatrixInverse(op, **kwargs): | ||
def matrix_inverse(x): | ||
return torch.linalg.inv(x) | ||
|
||
return matrix_inverse | ||
|
||
|
||
@pytorch_funcify.register(QRFull) | ||
def pytorch_funcify_QRFull(op, **kwargs): | ||
mode = op.mode | ||
if mode == "raw": | ||
raise NotImplementedError("raw mode not implemented in PyTorch") | ||
|
||
def qr_full(x): | ||
Q, R = torch.linalg.qr(x, mode=mode) | ||
if mode == "r": | ||
return R | ||
return Q, R | ||
|
||
return qr_full | ||
|
||
|
||
@pytorch_funcify.register(MatrixPinv) | ||
def pytorch_funcify_Pinv(op, **kwargs): | ||
hermitian = op.hermitian | ||
|
||
def pinv(x): | ||
return torch.linalg.pinv(x, hermitian=hermitian) | ||
|
||
return pinv | ||
|
||
|
||
@pytorch_funcify.register(KroneckerProduct) | ||
def pytorch_funcify_KroneckerProduct(op, **kwargs): | ||
def _kron(x, y): | ||
return torch.kron(x, y) | ||
|
||
return _kron |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,111 @@ | ||
import numpy as np | ||
import pytest | ||
|
||
from pytensor.compile.function import function | ||
from pytensor.configdefaults import config | ||
from pytensor.graph.fg import FunctionGraph | ||
from pytensor.tensor import nlinalg as pt_nla | ||
from pytensor.tensor.type import matrix | ||
from tests.link.pytorch.test_basic import compare_pytorch_and_py | ||
|
||
|
||
@pytest.fixture | ||
def matrix_test(): | ||
rng = np.random.default_rng(213234) | ||
|
||
M = rng.normal(size=(3, 3)) | ||
test_value = M.dot(M.T).astype(config.floatX) | ||
|
||
x = matrix("x") | ||
return (x, test_value) | ||
|
||
|
||
@pytest.mark.parametrize( | ||
"func", | ||
(pt_nla.eig, pt_nla.eigh, pt_nla.slogdet, pt_nla.inv, pt_nla.det), | ||
) | ||
def test_lin_alg_no_params(func, matrix_test): | ||
x, test_value = matrix_test | ||
|
||
out = func(x) | ||
out_fg = FunctionGraph([x], out if isinstance(out, list) else [out]) | ||
|
||
def assert_fn(x, y): | ||
np.testing.assert_allclose(x, y, rtol=1e-3) | ||
|
||
compare_pytorch_and_py(out_fg, [test_value], assert_fn=assert_fn) | ||
|
||
|
||
@pytest.mark.parametrize( | ||
"mode", | ||
( | ||
"complete", | ||
"reduced", | ||
"r", | ||
pytest.param("raw", marks=pytest.mark.xfail(raises=NotImplementedError)), | ||
), | ||
) | ||
def test_qr(mode, matrix_test): | ||
x, test_value = matrix_test | ||
outs = pt_nla.qr(x, mode=mode) | ||
out_fg = FunctionGraph([x], outs if isinstance(outs, list) else [outs]) | ||
compare_pytorch_and_py(out_fg, [test_value]) | ||
|
||
|
||
@pytest.mark.parametrize("compute_uv", [True, False]) | ||
@pytest.mark.parametrize("full_matrices", [True, False]) | ||
def test_svd(compute_uv, full_matrices, matrix_test): | ||
x, test_value = matrix_test | ||
|
||
out = pt_nla.svd(x, full_matrices=full_matrices, compute_uv=compute_uv) | ||
out_fg = FunctionGraph([x], out if isinstance(out, list) else [out]) | ||
|
||
compare_pytorch_and_py(out_fg, [test_value]) | ||
|
||
|
||
def test_pinv(): | ||
x = matrix("x") | ||
x_inv = pt_nla.pinv(x) | ||
|
||
fgraph = FunctionGraph([x], [x_inv]) | ||
x_np = np.array([[1.0, 2.0], [3.0, 4.0]], dtype=config.floatX) | ||
compare_pytorch_and_py(fgraph, [x_np]) | ||
|
||
|
||
@pytest.mark.parametrize("hermitian", [False, True]) | ||
def test_pinv_hermitian(hermitian): | ||
A = matrix("A", dtype="complex128") | ||
A_h_test = np.c_[[3, 3 + 2j], [3 - 2j, 2]] | ||
A_not_h_test = A_h_test + 0 + 1j | ||
|
||
A_inv = pt_nla.pinv(A, hermitian=hermitian) | ||
torch_fn = function([A], A_inv, mode="PYTORCH") | ||
|
||
assert np.allclose(torch_fn(A_h_test), np.linalg.pinv(A_h_test, hermitian=False)) | ||
assert np.allclose(torch_fn(A_h_test), np.linalg.pinv(A_h_test, hermitian=True)) | ||
|
||
assert ( | ||
np.allclose( | ||
torch_fn(A_not_h_test), np.linalg.pinv(A_not_h_test, hermitian=False) | ||
) | ||
is not hermitian | ||
) | ||
|
||
assert ( | ||
np.allclose( | ||
torch_fn(A_not_h_test), np.linalg.pinv(A_not_h_test, hermitian=True) | ||
) | ||
is hermitian | ||
) | ||
|
||
|
||
def test_kron(): | ||
x = matrix("x") | ||
y = matrix("y") | ||
z = pt_nla.kron(x, y) | ||
|
||
fgraph = FunctionGraph([x, y], [z]) | ||
x_np = np.array([[1.0, 2.0], [3.0, 4.0]], dtype=config.floatX) | ||
y_np = np.array([[1.0, 2.0], [3.0, 4.0]], dtype=config.floatX) | ||
|
||
compare_pytorch_and_py(fgraph, [x_np, y_np]) |