-
Notifications
You must be signed in to change notification settings - Fork 617
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
[Capture] Allow higher order primitives to accept dynamically shaped arrays #6786
base: master
Are you sure you want to change the base?
Conversation
Hello. You may have forgotten to update the changelog!
|
Codecov ReportAll modified and coverable lines are covered by tests ✅
Additional details and impacted files@@ Coverage Diff @@
## master #6786 +/- ##
=======================================
Coverage 99.60% 99.60%
=======================================
Files 476 477 +1
Lines 45237 45282 +45
=======================================
+ Hits 45060 45105 +45
Misses 177 177 ☔ View full report in Codecov by Sentry. |
* The higher order primitives in program capture can now accept inputs with abstract shapes. | ||
[(#6786)](https://github.com/PennyLaneAI/pennylane/pull/6786) | ||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Don't forget to add your name to the changelog 😄
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
added. thanks for the reminder.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Left a few minor comments, but I have questions about determine_abstract_axes
😅
I need some help understanding what is going on, it would probably be easiest just to chat.
import jax | ||
``` | ||
|
||
Dynamic shapes are experimental feature of jax with limited support and feature coverage. |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Dynamic shapes are experimental feature of jax with limited support and feature coverage. | |
Dynamic shapes are an experimental feature of jax with limited support and feature coverage. |
|
||
|
||
```python | ||
jax.config.update("jax_dynamic_shapes", False) | ||
``` | ||
|
||
Without this setup, we can't create arrays whose size depends on an abstract value. |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
This really confused me for a moment because I read it as "without the above setup", and I thought it implied we need to run jax.config.update("jax_dynamic_shapes", False)
in order for the dynamic arrays to work.
```python | |
jax.config.update("jax_dynamic_shapes", False) | |
``` | |
Without this setup, we can't create arrays whose size depends on an abstract value. | |
Without the `"jax_dynamic_shapes"` feature, we can't create arrays whose size depends on an abstract value. | |
```python | |
jax.config.update("jax_dynamic_shapes", False) |
|
||
Executing with `eval_jaxpr`: | ||
|
||
No idea how to fix this right now. |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
😢
abstract_shapes.append(s) | ||
abstracted_axes.append(tuple(l_shape) if len(l_shape) != 1 else l_shape[0]) # maybe ? | ||
abstracted_axes = jax.tree_util.tree_unflatten(structure, abstracted_axes) | ||
return abstracted_axes, abstract_shapes |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I'm a little lost on the alphabet
part
|
||
|
||
|
||
We can now take these learnings a make custom higher order primitive that supports dynamically shaped inputs: |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
We can now take these learnings a make custom higher order primitive that supports dynamically shaped inputs: | |
We can now take these learnings to make a custom higher order primitive that supports dynamically shaped inputs: |
1) A bit more difficult to read and follow | ||
2) Relies on unstable componets of jax internals | ||
|
||
But why let those concerns stop us now! Let's do it. |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
🤣
*flat_args, | ||
*abstract_shapes, |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Below we have *abstract_shapes, *flat_args
in some of the primitive bind
functions, and here it's reversed; it would be more consistent to have a set order.
has_jax = False # pragma: no cover | ||
|
||
|
||
def determine_abstracted_axes(args): |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I'm having a hard time reviewing this function because I'm not confident that I understand what the expected behaviour is for various inputs. Maybe we could have a quick call/meeting and you could walk me through it?
Context:
By turning on the experimental
jax_dynamic_shapes
mode, you can capture and compile jaxpr for a series of different shapes at the same time. While this expermental feature has issues and isn't fully supported by jax yet, it is used by catalyst. To continue to support all of catalyst's features, we need to be able to capture and work with dynamic shapes as well.Description of the Change:
qml.capture.determine_abstracted_axes
function to determine the requiredabstracted_axes
and the corresponding abstract shapes.determine_abstracted_axes
function in all of our higher order primitives other thangrad
andjacobian
, asgrad
andjacobian
may prove more complicated.Benefits:
Our higher order primitives can accept inputs with abstract shapes.
Possible Drawbacks:
This jax mode is still experimental.
Related GitHub Issues:
[sc-81471]