diff --git a/pytensor/tensor/slinalg.py b/pytensor/tensor/slinalg.py index baa13cde87..802ca6e543 100644 --- a/pytensor/tensor/slinalg.py +++ b/pytensor/tensor/slinalg.py @@ -797,7 +797,9 @@ def make_node(self, A, B): def perform(self, node, inputs, output_storage): (A, B) = inputs X = output_storage[0] - X[0] = scipy.linalg.solve_continuous_lyapunov(A, B) + + out_dtype = node.outputs[0].type.dtype + X[0] = scipy.linalg.solve_continuous_lyapunov(A, B).astype(out_dtype) def infer_shape(self, fgraph, node, shapes): return [shapes[0]] @@ -866,7 +868,10 @@ def perform(self, node, inputs, output_storage): (A, B) = inputs X = output_storage[0] - X[0] = scipy.linalg.solve_discrete_lyapunov(A, B, method="bilinear") + out_dtype = node.outputs[0].type.dtype + X[0] = scipy.linalg.solve_discrete_lyapunov(A, B, method="bilinear").astype( + out_dtype + ) def infer_shape(self, fgraph, node, shapes): return [shapes[0]] @@ -964,11 +969,8 @@ class SolveDiscreteARE(Op): __props__ = ("enforce_Q_symmetric",) gufunc_signature = "(m,m),(m,n),(m,m),(n,n)->(m,m)" - def __init__( - self, enforce_Q_symmetric: bool = False, use_bilinear_lyapunov: bool = True - ): + def __init__(self, enforce_Q_symmetric: bool = False): self.enforce_Q_symmetric = enforce_Q_symmetric - self.use_bilinear_lyapunov = use_bilinear_lyapunov def make_node(self, A, B, Q, R): A = as_tensor_variable(A) @@ -988,7 +990,8 @@ def perform(self, node, inputs, output_storage): if self.enforce_Q_symmetric: Q = 0.5 * (Q + Q.T) - X[0] = scipy.linalg.solve_discrete_are(A, B, Q, R) + out_dtype = node.outputs[0].type.dtype + X[0] = scipy.linalg.solve_discrete_are(A, B, Q, R).astype(out_dtype) def infer_shape(self, fgraph, node, shapes): return [shapes[0]] @@ -1000,16 +1003,16 @@ def grad(self, inputs, output_grads): (dX,) = output_grads X = self(A, B, Q, R) - K_inner = R + pt.linalg.matrix_dot(B.T, X, B) + K_inner = R + matrix_dot(B.T, X, B) # K_inner is guaranteed to be symmetric, because X and R are symmetric - K_inner_inv_BT = pt.linalg.solve(K_inner, B.T, assume_a="sym") + K_inner_inv_BT = solve(K_inner, B.T, assume_a="sym") K = matrix_dot(K_inner_inv_BT, X, A) A_tilde = A - B.dot(K) dX_symm = 0.5 * (dX + dX.T) - S = solve_discrete_lyapunov(A_tilde, dX_symm).astype(dX.type.dtype) + S = solve_discrete_lyapunov(A_tilde, dX_symm) A_bar = 2 * matrix_dot(X, A_tilde, S) B_bar = -2 * matrix_dot(X, A_tilde, S, K.T)