Cleanup for Optimal Control Ops #1045
Conversation
JAX backend doesn't like my new batch-compatible
I initially tried just wrapping the core case in
May need something for Reshape like we do for the size argument of RVs (check size_tuple Op or whatever it is)
I fixed the JAX problem by doing what I should have been doing in the first place -- implementing a core case then using Blockwise on it. Unfortunately, this had the side effect of breaking gradients for the regular pytensor backend. Something is failing in rewrites; I think a useless blockwise isn't getting rewritten away.
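For context, the pattern being described (write the Op for a single core case, then let `Blockwise` add batch dimensions) looks roughly like the sketch below. The explicit signature and the direct use of `SolveDiscreteLyapunov` are illustrative assumptions, not necessarily how the PR wires it up:

```python
import pytensor.tensor as pt
from pytensor.tensor.blockwise import Blockwise
from pytensor.tensor.slinalg import SolveDiscreteLyapunov  # core (single-matrix) Op

# Wrap the core Op so any extra leading dimensions are treated as batch dims.
# The gufunc-style signature declares the core dims of each input and output;
# the one used here is an assumption for illustration.
batched_lyapunov = Blockwise(
    SolveDiscreteLyapunov(), signature="(m,m),(m,m)->(m,m)"
)

A = pt.tensor3("A")  # shape (batch, m, m)
Q = pt.tensor3("Q")  # shape (batch, m, m)
X = batched_lyapunov(A, Q)  # shape (batch, m, m)
```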
Sounds like progress, can you show the gradient graph that's failing?
Graph:
Sum{axes=None} [id A] <Scalar(float64, shape=())> 10
└─ Mul [id B] <Tensor3(float64, shape=(5, ?, ?))> 9
├─ random_projection [id C] <Tensor3(float64, shape=(?, ?, ?))>
└─ SpecifyShape [id D] <Tensor3(float64, shape=(5, ?, ?))> 8
├─ Reshape{3} [id E] <Tensor3(float64, shape=(?, ?, ?))> 7
│ ├─ Blockwise{Solve{assume_a='gen', lower=False, check_finite=True, b_ndim=1, overwrite_a=False, overwrite_b=False}, (m,m),(m)->(m)} [id F] <Matrix(float64, shape=(5, ?))> 6
│ │ ├─ Sub [id G] <Tensor3(float64, shape=(5, ?, ?))> 5
│ │ │ ├─ ExpandDims{axis=0} [id H] <Tensor3(float64, shape=(1, ?, ?))> 4
│ │ │ │ └─ Eye{dtype='float64'} [id I] <Matrix(float64, shape=(?, ?))> 3
│ │ │ │ ├─ Shape_i{2} [id J] <Scalar(int64, shape=())> 2
│ │ │ │ │ └─ Blockwise{KroneckerProduct{inline=False}, (i00,i01),(i10,i11)->(o00,o01)} [id K] <Tensor3(float64, shape=(5, ?, ?))> 1
│ │ │ │ │ ├─ input 0 [id L] <Tensor3(float64, shape=(5, 5, 5))>
│ │ │ │ │ └─ input 0 [id L] <Tensor3(float64, shape=(5, 5, 5))>
│ │ │ │ ├─ Shape_i{2} [id J] <Scalar(int64, shape=())> 2
│ │ │ │ │ └─ ···
│ │ │ │ └─ 0 [id M] <Scalar(int8, shape=())>
│ │ │ └─ Blockwise{KroneckerProduct{inline=False}, (i00,i01),(i10,i11)->(o00,o01)} [id K] <Tensor3(float64, shape=(5, ?, ?))> 1
│ │ │ └─ ···
│ │ └─ Reshape{2} [id N] <Matrix(float64, shape=(?, ?))> 0
│ │ ├─ input 1 [id O] <Tensor3(float64, shape=(5, 5, 5))>
│ │ └─ [ 5 -1] [id P] <Vector(int64, shape=(2,))>
│ └─ [5 5 5] [id Q] <Vector(int64, shape=(3,))>
├─ 5 [id R] <Scalar(int8, shape=())>
├─ NoneConst{None} [id S] <NoneTypeT>
└─ NoneConst{None} [id S] <NoneTypeT>

Error:
ERROR pytensor.graph.rewriting.basic:basic.py:1746 Rewrite failure due to: local_blockwise_alloc
ERROR pytensor.graph.rewriting.basic:basic.py:1747 node: Blockwise{Reshape{1}, (i00,i01),(i10)->(o00)}(SpecifyShape.0, Alloc.0)
ERROR pytensor.graph.rewriting.basic:basic.py:1748 TRACEBACK:
ERROR pytensor.graph.rewriting.basic:basic.py:1749 Traceback (most recent call last):
File "/Users/jessegrabowski/Documents/Python/pytensor/pytensor/graph/rewriting/basic.py", line 1909, in process_node
replacements = node_rewriter.transform(fgraph, node)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/Users/jessegrabowski/Documents/Python/pytensor/pytensor/graph/rewriting/basic.py", line 1081, in transform
return self.fn(fgraph, node)
^^^^^^^^^^^^^^^^^^^^^
File "/Users/jessegrabowski/Documents/Python/pytensor/pytensor/tensor/rewriting/blockwise.py", line 176, in local_blockwise_alloc
new_outs = node.op.make_node(*new_inputs).outputs
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/Users/jessegrabowski/Documents/Python/pytensor/pytensor/tensor/blockwise.py", line 130, in make_node
core_node = self._create_dummy_core_node(inputs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/Users/jessegrabowski/Documents/Python/pytensor/pytensor/tensor/blockwise.py", line 103, in _create_dummy_core_node
raise ValueError(
ValueError: Input 1 DropDims{axis=0}.0 has insufficient core dimensions for signature (i00,i01),(i10)->(o00)
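The failure is about core dimensions: the gufunc-style signature declares how many trailing dims of each input are "core", and here the rewrite handed `Reshape` an input with too few of them. A rough numpy analogue of the same idea (not the pytensor code path, just for intuition):

```python
import numpy as np

# np.vectorize with a gufunc signature is the numpy analogue of Blockwise:
# "(m,m),(m)->(m)" marks the trailing "core" dims of each input, and anything
# to their left is treated as a batch dimension.
batched_solve = np.vectorize(np.linalg.solve, signature="(m,m),(m)->(m)")

A = np.random.rand(5, 3, 3) + 3 * np.eye(3)  # batch of 5 well-conditioned systems
b = np.random.rand(5, 3)                     # batch of 5 right-hand sides
x = batched_solve(A, b)                      # shape (5, 3)

# An input with fewer dims than its core signature requires (e.g. a scalar
# where a vector is expected) fails with a core-dimension error, which is the
# same class of problem as the ValueError above.
```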
Looking at the rewrite bug
Codecov Report
Attention: Patch coverage is
Additional details and impacted files

@@            Coverage Diff             @@
##             main    #1045      +/-   ##
==========================================
+ Coverage   81.90%   81.93%   +0.02%
==========================================
  Files         182      182
  Lines       47872    47890      +18
  Branches     8617     8617
==========================================
+ Hits        39210    39239      +29
+ Misses       6489     6481       -8
+ Partials     2173     2170       -3
Small point about the casting in the `perform` method. Otherwise looks great.
I had to go back to tracking
I found though that the
I'm confused. The make_node of an Op has the flexibility (and responsibility) to create variables of any type it wants; it's not frozen unless you wrote it like that. So you shouldn't even have to cast the output if you can predict it correctly at the time make_node is called (based on the input types). Several Ops go the lazy way and just call the scipy/numpy function with the smallest input possible and get the output type from there.
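As an illustration of that lazy approach, here is a minimal sketch (the Op name, the tiny probe inputs, and the exact calls are assumptions for illustration, not the code in this PR):

```python
import numpy as np
import scipy.linalg
import pytensor.tensor as pt
from pytensor.graph.basic import Apply
from pytensor.graph.op import Op


class LazyDtypeSolveDiscreteLyapunov(Op):
    """Hypothetical Op showing output-dtype inference inside make_node."""

    def make_node(self, A, Q):
        A = pt.as_tensor_variable(A)
        Q = pt.as_tensor_variable(Q)
        # Let scipy decide the output dtype from the input dtypes by running
        # the routine once on tiny inputs, so perform() never has to cast.
        out_dtype = scipy.linalg.solve_discrete_lyapunov(
            (0.5 * np.eye(1)).astype(A.type.dtype),
            np.eye(1).astype(Q.type.dtype),
        ).dtype
        X = pt.tensor(dtype=out_dtype, shape=A.type.shape)
        return Apply(self, [A, Q], [X])

    def perform(self, node, inputs, output_storage):
        A, Q = inputs
        # No astype needed: the declared output dtype already matches what
        # scipy returns for these input dtypes.
        output_storage[0][0] = scipy.linalg.solve_discrete_lyapunov(A, Q)
```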
The rewrite was squeezing too many dimensions of the alloced value, when this didn't have dummy expand dims to the left.
Revert change to `_solve_discrete_lyapunov`
d5e8e45 to 89d5fd0
* Blockwise optimal linear control ops
* Add jax rewrite to eliminate `BilinearSolveDiscreteLyapunov`
* set `solve_discrete_lyapunov` method default to bilinear
* Appease mypy
* restore method dispatching
* Use `pt.vectorize` on base `solve_discrete_lyapunov` case
* Apply JAX rewrite before canonicalization
* Improve tests
* Remove useless warning filters
* Fix local_blockwise_alloc rewrite: the rewrite was squeezing too many dimensions of the alloced value, when this didn't have dummy expand dims to the left
* Fix float32 tests
* Test against complex inputs
* Appease ViPy (Vieira-py type checking)
* Remove condition from `TensorLike` import
* Infer dtype from `node.outputs.type.dtype`
* Remove unused mypy ignore
* Don't manually set dtype of output; revert change to `_solve_discrete_lyapunov`
* Set dtype of Op outputs

Co-authored-by: ricardoV94 <[email protected]>
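For reference, a rough usage sketch of what these commits enable; the batch support and the `method` default are as described in the commits above, so treat the exact call as illustrative rather than definitive:

```python
import numpy as np
import pytensor
import pytensor.tensor as pt
from pytensor.tensor.slinalg import solve_discrete_lyapunov

# A batch of 5 independent (3, 3) discrete Lyapunov problems A X A^T - X + Q = 0.
A = pt.tensor3("A")
Q = pt.tensor3("Q")
X = solve_discrete_lyapunov(A, Q, method="bilinear")  # bilinear is the default after this PR

fn = pytensor.function([A, Q], X)
A_val = (0.3 * np.random.rand(5, 3, 3)).astype(pytensor.config.floatX)   # spectral radius < 1
Q_val = np.broadcast_to(np.eye(3), (5, 3, 3)).astype(pytensor.config.floatX)
print(fn(A_val, Q_val).shape)  # (5, 3, 3)
```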
Description

* Blockwise the optimal control Ops (`SolveDiscreteLyapunov`, `SolveContinuousLyapunov`, `SolveDiscreteARE`)
* Add a JAX rewrite to eliminate `BilinearSolveDiscreteLyapunov` in JAX mode. JAX can't call out to the necessary LAPACK functions, but we can still fall back to the direct method.

Related Issue
Checklist
Type of change
📚 Documentation preview 📚: https://pytensor--1045.org.readthedocs.build/en/1045/