From 15c06ff39c059bc57976333a4f42f55b979281dd Mon Sep 17 00:00:00 2001 From: ricardoV94 Date: Fri, 10 Jan 2025 19:36:47 +0100 Subject: [PATCH] Use actual Solve Op to infer output dtype CholSolve outputs a different dtype than basic Solve in Scipy==1.15 --- pytensor/tensor/slinalg.py | 7 ++++--- tests/tensor/test_slinalg.py | 2 +- 2 files changed, 5 insertions(+), 4 deletions(-) diff --git a/pytensor/tensor/slinalg.py b/pytensor/tensor/slinalg.py index 4904259d25..325567918a 100644 --- a/pytensor/tensor/slinalg.py +++ b/pytensor/tensor/slinalg.py @@ -259,9 +259,10 @@ def make_node(self, A, b): raise ValueError(f"`b` must have {self.b_ndim} dims; got {b.type} instead.") # Infer dtype by solving the most simple case with 1x1 matrices - o_dtype = scipy.linalg.solve( - np.eye(1).astype(A.dtype), np.eye(1).astype(b.dtype) - ).dtype + inp_arr = [np.eye(1).astype(A.dtype), np.eye(1).astype(b.dtype)] + out_arr = [[None]] + self.perform(None, inp_arr, out_arr) + o_dtype = out_arr[0][0].dtype x = tensor(dtype=o_dtype, shape=b.type.shape) return Apply(self, [A, b], [x]) diff --git a/tests/tensor/test_slinalg.py b/tests/tensor/test_slinalg.py index 3d4b6697b8..380b88fffe 100644 --- a/tests/tensor/test_slinalg.py +++ b/tests/tensor/test_slinalg.py @@ -450,7 +450,7 @@ def test_solve_dtype(self): fn = function([A, b], x) x_result = fn(A_val.astype(A_dtype), b_val.astype(b_dtype)) - assert x.dtype == x_result.dtype + assert x.dtype == x_result.dtype, (A_dtype, b_dtype) def test_cho_solve():