
Cleanup for Optimal Control Ops #1045

Merged
merged 18 commits into pymc-devs:main from jessegrabowski:lyapunov-jax on Oct 24, 2024

Conversation

jessegrabowski (Member) commented Oct 20, 2024

Description

  • Add blockwise support for optimal control ops (SolveDiscreteLyapunov, SolveContinuousLyapunov, SolveDiscreteARE); see the usage sketch after this list
  • Add rewrite to remove BilinearSolveDiscreteLyapunov in JAX mode. JAX can't call out to the necessary LAPACK functions, but we can still fall back to the direct method
  • Simplify tests, adding batched cases
  • Add/update typehints
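
A rough usage sketch of the new batched behavior (illustrative shapes and values only; the method="direct" keyword refers to the existing solve_discrete_lyapunov helper, and this is not code taken from the PR):

    import numpy as np
    import pytensor
    import pytensor.tensor as pt
    from pytensor.tensor.slinalg import solve_discrete_lyapunov

    # Batched inputs: a leading batch dimension of 5 over 3x3 core matrices
    A = pt.tensor3("A")
    Q = pt.tensor3("Q")

    # With Blockwise support, the solve broadcasts over the leading batch dimension
    X = solve_discrete_lyapunov(A, Q, method="direct")
    fn = pytensor.function([A, Q], X)

    A_val = 0.1 * np.random.normal(size=(5, 3, 3))  # small entries keep the equation well-posed
    Q_val = np.eye(3)[None].repeat(5, axis=0)
    print(fn(A_val, Q_val).shape)  # (5, 3, 3)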

Related Issue

  • Closes #
  • Related to #

Checklist

Type of change

  • New feature / enhancement
  • Bug fix
  • Documentation
  • Maintenance
  • Other (please specify):

📚 Documentation preview 📚: https://pytensor--1045.org.readthedocs.build/en/1045/

jessegrabowski (Member, Author) commented Oct 20, 2024

The JAX backend doesn't like my new batch-compatible _direct_solve_discrete_lyapunov, I guess because of these lines:

    vec_q_shape = pt.concatenate([Q.shape[:-2], [-1]])
    vec_Q = Q.reshape(vec_q_shape)

I initially tried just wrapping the core case in pt.vectorize, but I hit shape errors going that route. Any suggestions?
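
For context, the direct method is essentially the standard vec/Kronecker trick; here is a minimal NumPy sketch of the unbatched core case (an illustrative reimplementation, not the actual PyTensor code):

    import numpy as np

    def direct_solve_discrete_lyapunov(A, Q):
        """Solve A X A^H - X + Q = 0 by flattening with the vec trick."""
        n = A.shape[-1]
        # With row-major (C-order) flattening, vec(A X A^H) = kron(A, conj(A)) @ vec(X),
        # so the equation becomes (I - kron(A, conj(A))) vec(X) = vec(Q).
        lhs = np.eye(n * n) - np.kron(A, A.conj())
        vec_X = np.linalg.solve(lhs, Q.reshape(-1))
        return vec_X.reshape(n, n)

    A = 0.5 * np.random.normal(size=(4, 4))
    Q = np.eye(4)
    X = direct_solve_discrete_lyapunov(A, Q)
    print(np.allclose(A @ X @ A.conj().T - X + Q, 0))  # True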

ricardoV94 (Member) commented Oct 20, 2024

> The JAX backend doesn't like my new batch-compatible _direct_solve_discrete_lyapunov, I guess because of these lines:
>
>     vec_q_shape = pt.concatenate([Q.shape[:-2], [-1]])
>     vec_Q = Q.reshape(vec_q_shape)
>
> I initially tried just wrapping the core case in pt.vectorize, but I hit shape errors going that route. Any suggestions?

May need something for Reshape like we do for the size argument of RVs (check size_tuple Op or whatever it is)

jessegrabowski (Member, Author)

I fixed the JAX problem by doing what I should have been doing in the first place: implementing the core case and then using Blockwise on it.

Unfortunately, this had the side effect of breaking gradients for the regular pytensor backend. Something is failing in rewrites; I think a useless blockwise isn't getting rewritten away.
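
Roughly, the core-case-plus-Blockwise approach looks like this (a hypothetical sketch: the explicit signature argument and the zero-argument SolveContinuousLyapunov() constructor are assumptions, and the public helpers are expected to do this wrapping for you):

    import numpy as np
    import pytensor
    import pytensor.tensor as pt
    from pytensor.tensor.blockwise import Blockwise
    from pytensor.tensor.slinalg import SolveContinuousLyapunov

    # The core Op only solves the (m, m) matrix case; Blockwise adds the batch loop
    solve_lyap = Blockwise(SolveContinuousLyapunov(), signature="(m,m),(m,m)->(m,m)")

    A = pt.tensor3("A")  # (batch, m, m)
    Q = pt.tensor3("Q")
    X = solve_lyap(A, Q)

    fn = pytensor.function([A, Q], X)
    A_val = -np.eye(3)[None].repeat(5, axis=0)  # a stable A for every batch member
    Q_val = np.eye(3)[None].repeat(5, axis=0)
    print(fn(A_val, Q_val).shape)  # (5, 3, 3)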

ricardoV94 (Member)

Sounds like progress, can you show the gradient graph that's failing?

jessegrabowski (Member, Author)

Graph:

Sum{axes=None} [id A] <Scalar(float64, shape=())> 10
 └─ Mul [id B] <Tensor3(float64, shape=(5, ?, ?))> 9
    ├─ random_projection [id C] <Tensor3(float64, shape=(?, ?, ?))>
    └─ SpecifyShape [id D] <Tensor3(float64, shape=(5, ?, ?))> 8
       ├─ Reshape{3} [id E] <Tensor3(float64, shape=(?, ?, ?))> 7
       │  ├─ Blockwise{Solve{assume_a='gen', lower=False, check_finite=True, b_ndim=1, overwrite_a=False, overwrite_b=False}, (m,m),(m)->(m)} [id F] <Matrix(float64, shape=(5, ?))> 6
       │  │  ├─ Sub [id G] <Tensor3(float64, shape=(5, ?, ?))> 5
       │  │  │  ├─ ExpandDims{axis=0} [id H] <Tensor3(float64, shape=(1, ?, ?))> 4
       │  │  │  │  └─ Eye{dtype='float64'} [id I] <Matrix(float64, shape=(?, ?))> 3
       │  │  │  │     ├─ Shape_i{2} [id J] <Scalar(int64, shape=())> 2
       │  │  │  │     │  └─ Blockwise{KroneckerProduct{inline=False}, (i00,i01),(i10,i11)->(o00,o01)} [id K] <Tensor3(float64, shape=(5, ?, ?))> 1
       │  │  │  │     │     ├─ input 0 [id L] <Tensor3(float64, shape=(5, 5, 5))>
       │  │  │  │     │     └─ input 0 [id L] <Tensor3(float64, shape=(5, 5, 5))>
       │  │  │  │     ├─ Shape_i{2} [id J] <Scalar(int64, shape=())> 2
       │  │  │  │     │  └─ ···
       │  │  │  │     └─ 0 [id M] <Scalar(int8, shape=())>
       │  │  │  └─ Blockwise{KroneckerProduct{inline=False}, (i00,i01),(i10,i11)->(o00,o01)} [id K] <Tensor3(float64, shape=(5, ?, ?))> 1
       │  │  │     └─ ···
       │  │  └─ Reshape{2} [id N] <Matrix(float64, shape=(?, ?))> 0
       │  │     ├─ input 1 [id O] <Tensor3(float64, shape=(5, 5, 5))>
       │  │     └─ [ 5 -1] [id P] <Vector(int64, shape=(2,))>
       │  └─ [5 5 5] [id Q] <Vector(int64, shape=(3,))>
       ├─ 5 [id R] <Scalar(int8, shape=())>
       ├─ NoneConst{None} [id S] <NoneTypeT>
       └─ NoneConst{None} [id S] <NoneTypeT>

Error:

ERROR    pytensor.graph.rewriting.basic:basic.py:1746 Rewrite failure due to: local_blockwise_alloc
ERROR    pytensor.graph.rewriting.basic:basic.py:1747 node: Blockwise{Reshape{1}, (i00,i01),(i10)->(o00)}(SpecifyShape.0, Alloc.0)
ERROR    pytensor.graph.rewriting.basic:basic.py:1748 TRACEBACK:
ERROR    pytensor.graph.rewriting.basic:basic.py:1749 Traceback (most recent call last):
  File "/Users/jessegrabowski/Documents/Python/pytensor/pytensor/graph/rewriting/basic.py", line 1909, in process_node
    replacements = node_rewriter.transform(fgraph, node)
                   ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/Users/jessegrabowski/Documents/Python/pytensor/pytensor/graph/rewriting/basic.py", line 1081, in transform
    return self.fn(fgraph, node)
           ^^^^^^^^^^^^^^^^^^^^^
  File "/Users/jessegrabowski/Documents/Python/pytensor/pytensor/tensor/rewriting/blockwise.py", line 176, in local_blockwise_alloc
    new_outs = node.op.make_node(*new_inputs).outputs
               ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/Users/jessegrabowski/Documents/Python/pytensor/pytensor/tensor/blockwise.py", line 130, in make_node
    core_node = self._create_dummy_core_node(inputs)
                ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/Users/jessegrabowski/Documents/Python/pytensor/pytensor/tensor/blockwise.py", line 103, in _create_dummy_core_node
    raise ValueError(
ValueError: Input 1 DropDims{axis=0}.0 has insufficient core dimensions for signature (i00,i01),(i10)->(o00)

ricardoV94 (Member)

Looking at the rewrite bug


codecov bot commented Oct 21, 2024

Codecov Report

Attention: Patch coverage is 95.91837% with 2 lines in your changes missing coverage. Please review.

Project coverage is 81.93%. Comparing base (dae731d) to head (89d5fd0).
Report is 88 commits behind head on main.

Files with missing lines Patch % Lines
pytensor/tensor/slinalg.py 95.00% 1 Missing and 1 partial ⚠️
Additional details and impacted files


@@            Coverage Diff             @@
##             main    #1045      +/-   ##
==========================================
+ Coverage   81.90%   81.93%   +0.02%     
==========================================
  Files         182      182              
  Lines       47872    47890      +18     
  Branches     8617     8617              
==========================================
+ Hits        39210    39239      +29     
+ Misses       6489     6481       -8     
+ Partials     2173     2170       -3     
Files with missing lines Coverage Δ
pytensor/tensor/rewriting/blockwise.py 96.40% <100.00%> (ø)
pytensor/tensor/rewriting/linalg.py 91.37% <100.00%> (+0.44%) ⬆️
pytensor/tensor/slinalg.py 93.47% <95.00%> (+1.49%) ⬆️

... and 1 file with indirect coverage changes

ricardoV94 (Member) left a comment

Small point about the casting in the perform method. Otherwise looks great

jessegrabowski (Member, Author)

I had to go back to tracking Blockwise in the rewrite. I found that when I instantiated the Op once, as in:

    _solve_lyapunov = Blockwise(SolveLyapunov)

the node.outputs[0].type.dtype was being "frozen" at the first call. If I did float64 then complex128, it was downcasting the complex output to float64. If I remake the node each time the function wrapper is called, I don't have this problem.

ricardoV94 (Member) commented Oct 22, 2024

> I had to go back to tracking Blockwise in the rewrite. I found that when I instantiated the Op once, as in:
>
>     _solve_lyapunov = Blockwise(SolveLyapunov)
>
> the node.outputs[0].type.dtype was being "frozen" at the first call. If I did float64 then complex128, it was downcasting the complex output to float64. If I remake the node each time the function wrapper is called, I don't have this problem.

I'm confused. The make_node of an Op has the flexibility (and responsibility) to create variables of any type it wants; it's not frozen unless you wrote it that way?

So you shouldn't even have to cast the output if you could predict it correctly at the time make_node is called (based on the input types). Several Ops go the lazy way and just call the scipy/numpy function with the smallest input possible and get the output type from there.
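
For example, a minimal sketch of that lazy pattern (illustrative placeholder Op and signature, not the code merged in this PR):

    import numpy as np
    import scipy.linalg
    import pytensor.tensor as pt
    from pytensor.graph.basic import Apply
    from pytensor.graph.op import Op

    class SolveDiscreteLyapunovSketch(Op):
        # gufunc-style signature that Blockwise can use for the batched case
        gufunc_signature = "(m,m),(m,m)->(m,m)"

        def make_node(self, A, Q):
            A = pt.as_tensor_variable(A)
            Q = pt.as_tensor_variable(Q)
            # "Lazy" dtype inference: call the scipy routine on the smallest
            # possible inputs, so complex128 inputs give a complex128 output
            # instead of being downcast to whatever an earlier call produced
            out_dtype = scipy.linalg.solve_discrete_lyapunov(
                np.zeros((1, 1), dtype=A.type.dtype),
                np.ones((1, 1), dtype=Q.type.dtype),
            ).dtype
            X = pt.matrix(dtype=out_dtype)
            return Apply(self, [A, Q], [X])

        def perform(self, node, inputs, output_storage):
            A, Q = inputs
            X = scipy.linalg.solve_discrete_lyapunov(A, Q)
            # Cast to the dtype promised by make_node, in case scipy returns
            # a different precision than the one advertised there
            output_storage[0][0] = X.astype(node.outputs[0].type.dtype)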

jessegrabowski merged commit fffb84c into pymc-devs:main Oct 24, 2024
60 of 61 checks passed
jessegrabowski deleted the lyapunov-jax branch October 24, 2024 11:45
Ch0ronomato pushed a commit to Ch0ronomato/pytensor that referenced this pull request Nov 2, 2024
* Blockwise optimal linear control ops

* Add jax rewrite to eliminate `BilinearSolveDiscreteLyapunov`

* set `solve_discrete_lyapunov` method default to bilinear

* Appease mypy

* restore method dispatching

* Use `pt.vectorize` on base `solve_discrete_lyapunov` case

* Apply JAX rewrite before canonicalization

* Improve tests

* Remove useless warning filters

* Fix local_blockwise_alloc rewrite

The rewrite was squeezing too many dimensions of the alloced value, when this didn't have dummy expand dims to the left.

* Fix float32 tests

* Test against complex inputs

* Appease ViPy (Vieira-py type checking)

* Remove condition from `TensorLike` import

* Infer dtype from `node.outputs.type.dtype`

* Remove unused mypy ignore

* Don't manually set dtype of output

Revert change to `_solve_discrete_lyapunov`

* Set dtype of Op outputs

---------

Co-authored-by: ricardoV94 <[email protected]>
ricardoV94 removed the enhancement (New feature or request) label Nov 8, 2024