Skip to content

Commit

Permalink
Don't manually set dtype of output
Browse files Browse the repository at this point in the history
Revert change to `_solve_discrete_lyapunov`
  • Loading branch information
jessegrabowski committed Oct 24, 2024
1 parent fb35d92 commit cb809c1
Show file tree
Hide file tree
Showing 2 changed files with 33 additions and 37 deletions.
7 changes: 2 additions & 5 deletions pytensor/tensor/rewriting/linalg.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,11 +43,11 @@
register_stabilize,
)
from pytensor.tensor.slinalg import (
BilinearSolveDiscreteLyapunov,
BlockDiagonal,
Cholesky,
Solve,
SolveBase,
_bilinear_solve_discrete_lyapunov,
block_diag,
cholesky,
solve,
Expand Down Expand Up @@ -972,14 +972,11 @@ def rewrite_cholesky_diag_to_sqrt_diag(fgraph, node):
return [eye_input * (non_eye_input**0.5)]


@node_rewriter([Blockwise])
@node_rewriter([_bilinear_solve_discrete_lyapunov])
def jax_bilinaer_lyapunov_to_direct(fgraph: FunctionGraph, node: Apply):
"""
Replace BilinearSolveDiscreteLyapunov with a direct computation that is supported by JAX
"""
if not isinstance(node.op.core_op, BilinearSolveDiscreteLyapunov):
return None

A, B = (cast(TensorVariable, x) for x in node.inputs)
result = solve_discrete_lyapunov(A, B, method="direct")

Expand Down
63 changes: 31 additions & 32 deletions pytensor/tensor/slinalg.py
Original file line number Diff line number Diff line change
Expand Up @@ -797,9 +797,7 @@ def make_node(self, A, B):
def perform(self, node, inputs, output_storage):
(A, B) = inputs
X = output_storage[0]
out_dtype = node.outputs[0].type.dtype

X[0] = scipy.linalg.solve_continuous_lyapunov(A, B).astype(out_dtype)
X[0] = scipy.linalg.solve_continuous_lyapunov(A, B)

def infer_shape(self, fgraph, node, shapes):
return [shapes[0]]
Expand All @@ -820,6 +818,30 @@ def grad(self, inputs, output_grads):
return [A_bar, Q_bar]


_solve_continuous_lyapunov = Blockwise(SolveContinuousLyapunov())


def solve_continuous_lyapunov(A: TensorLike, Q: TensorLike) -> TensorVariable:
"""
Solve the continuous Lyapunov equation :math:`A X + X A^H + Q = 0`.
Parameters
----------
A: TensorLike
Square matrix of shape ``N x N``.
Q: TensorLike
Square matrix of shape ``N x N``.
Returns
-------
X: TensorVariable
Square matrix of shape ``N x N``
"""

return cast(TensorVariable, _solve_continuous_lyapunov(A, Q))


class BilinearSolveDiscreteLyapunov(Op):
"""
Solves a discrete lyapunov equation, :math:`AXA^H - X = Q`, for :math:`X.
Expand All @@ -844,10 +866,7 @@ def perform(self, node, inputs, output_storage):
(A, B) = inputs
X = output_storage[0]

out_dtype = node.outputs[0].type.dtype
X[0] = scipy.linalg.solve_discrete_lyapunov(A, B, method="bilinear").astype(
out_dtype
)
X[0] = scipy.linalg.solve_discrete_lyapunov(A, B, method="bilinear")

def infer_shape(self, fgraph, node, shapes):
return [shapes[0]]
Expand All @@ -869,6 +888,9 @@ def grad(self, inputs, output_grads):
return [A_bar, Q_bar]


_bilinear_solve_discrete_lyapunov = Blockwise(BilinearSolveDiscreteLyapunov())


def _direct_solve_discrete_lyapunov(
A: TensorVariable, Q: TensorVariable
) -> TensorVariable:
Expand Down Expand Up @@ -932,33 +954,12 @@ def solve_discrete_lyapunov(
return cast(TensorVariable, X)

elif method == "bilinear":
return cast(TensorVariable, Blockwise(BilinearSolveDiscreteLyapunov())(A, Q))
return cast(TensorVariable, _bilinear_solve_discrete_lyapunov(A, Q))

else:
raise ValueError(f"Unknown method {method}")


def solve_continuous_lyapunov(A: TensorLike, Q: TensorLike) -> TensorVariable:
"""
Solve the continuous Lyapunov equation :math:`A X + X A^H + Q = 0`.
Parameters
----------
A: TensorLike
Square matrix of shape ``N x N``.
Q: TensorLike
Square matrix of shape ``N x N``.
Returns
-------
X: TensorVariable
Square matrix of shape ``N x N``
"""

return cast(TensorVariable, Blockwise(SolveContinuousLyapunov())(A, Q))


class SolveDiscreteARE(Op):
__props__ = ("enforce_Q_symmetric",)
gufunc_signature = "(m,m),(m,n),(m,m),(n,n)->(m,m)"
Expand Down Expand Up @@ -987,9 +988,7 @@ def perform(self, node, inputs, output_storage):
if self.enforce_Q_symmetric:
Q = 0.5 * (Q + Q.T)

X[0] = scipy.linalg.solve_discrete_are(A, B, Q, R).astype(
node.outputs[0].type.dtype
)
X[0] = scipy.linalg.solve_discrete_are(A, B, Q, R)

def infer_shape(self, fgraph, node, shapes):
return [shapes[0]]
Expand Down

0 comments on commit cb809c1

Please sign in to comment.