Skip to content

Commit

Permalink
Set dtype of Op outputs
Browse files Browse the repository at this point in the history
  • Loading branch information
jessegrabowski committed Oct 24, 2024
1 parent cb809c1 commit 89d5fd0
Showing 1 changed file with 13 additions and 10 deletions.
23 changes: 13 additions & 10 deletions pytensor/tensor/slinalg.py
Original file line number Diff line number Diff line change
Expand Up @@ -797,7 +797,9 @@ def make_node(self, A, B):
def perform(self, node, inputs, output_storage):
(A, B) = inputs
X = output_storage[0]
X[0] = scipy.linalg.solve_continuous_lyapunov(A, B)

out_dtype = node.outputs[0].type.dtype
X[0] = scipy.linalg.solve_continuous_lyapunov(A, B).astype(out_dtype)

def infer_shape(self, fgraph, node, shapes):
return [shapes[0]]
Expand Down Expand Up @@ -866,7 +868,10 @@ def perform(self, node, inputs, output_storage):
(A, B) = inputs
X = output_storage[0]

X[0] = scipy.linalg.solve_discrete_lyapunov(A, B, method="bilinear")
out_dtype = node.outputs[0].type.dtype
X[0] = scipy.linalg.solve_discrete_lyapunov(A, B, method="bilinear").astype(
out_dtype
)

def infer_shape(self, fgraph, node, shapes):
return [shapes[0]]
Expand Down Expand Up @@ -964,11 +969,8 @@ class SolveDiscreteARE(Op):
__props__ = ("enforce_Q_symmetric",)
gufunc_signature = "(m,m),(m,n),(m,m),(n,n)->(m,m)"

def __init__(
self, enforce_Q_symmetric: bool = False, use_bilinear_lyapunov: bool = True
):
def __init__(self, enforce_Q_symmetric: bool = False):
self.enforce_Q_symmetric = enforce_Q_symmetric
self.use_bilinear_lyapunov = use_bilinear_lyapunov

def make_node(self, A, B, Q, R):
A = as_tensor_variable(A)
Expand All @@ -988,7 +990,8 @@ def perform(self, node, inputs, output_storage):
if self.enforce_Q_symmetric:
Q = 0.5 * (Q + Q.T)

X[0] = scipy.linalg.solve_discrete_are(A, B, Q, R)
out_dtype = node.outputs[0].type.dtype
X[0] = scipy.linalg.solve_discrete_are(A, B, Q, R).astype(out_dtype)

def infer_shape(self, fgraph, node, shapes):
return [shapes[0]]
Expand All @@ -1000,16 +1003,16 @@ def grad(self, inputs, output_grads):
(dX,) = output_grads
X = self(A, B, Q, R)

K_inner = R + pt.linalg.matrix_dot(B.T, X, B)
K_inner = R + matrix_dot(B.T, X, B)

# K_inner is guaranteed to be symmetric, because X and R are symmetric
K_inner_inv_BT = pt.linalg.solve(K_inner, B.T, assume_a="sym")
K_inner_inv_BT = solve(K_inner, B.T, assume_a="sym")
K = matrix_dot(K_inner_inv_BT, X, A)

A_tilde = A - B.dot(K)

dX_symm = 0.5 * (dX + dX.T)
S = solve_discrete_lyapunov(A_tilde, dX_symm).astype(dX.type.dtype)
S = solve_discrete_lyapunov(A_tilde, dX_symm)

A_bar = 2 * matrix_dot(X, A_tilde, S)
B_bar = -2 * matrix_dot(X, A_tilde, S, K.T)
Expand Down

0 comments on commit 89d5fd0

Please sign in to comment.