Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Capture] Allow higher order primitives to accept dynamically shaped arrays #6786

Open
wants to merge 14 commits into
base: master
Choose a base branch
from

Conversation

albi3ro
Copy link
Contributor

@albi3ro albi3ro commented Jan 8, 2025

Context:

By turning on the experimental jax_dynamic_shapes mode, you can capture and compile jaxpr for a series of different shapes at the same time. While this expermental feature has issues and isn't fully supported by jax yet, it is used by catalyst. To continue to support all of catalyst's features, we need to be able to capture and work with dynamic shapes as well.

Description of the Change:

  • Adds a qml.capture.determine_abstracted_axes function to determine the required abstracted_axes and the corresponding abstract shapes.
  • Use the determine_abstracted_axes function in all of our higher order primitives other than grad and jacobian, as grad and jacobian may prove more complicated.
  • Add a document explaining abstract shapes and how we can work with them.

Benefits:

Our higher order primitives can accept inputs with abstract shapes.

Possible Drawbacks:

This jax mode is still experimental.

Related GitHub Issues:

[sc-81471]

Copy link
Contributor

github-actions bot commented Jan 8, 2025

Hello. You may have forgotten to update the changelog!
Please edit doc/releases/changelog-dev.md with:

  • A one-to-two sentence description of the change. You may include a small working example for new features.
  • A link back to this PR.
  • Your name (or GitHub username) in the contributors section.

@albi3ro albi3ro changed the title [Draft][Capture] Allow higher order primitives to accept dynamically shaped arrays [Capture] Allow higher order primitives to accept dynamically shaped arrays Jan 9, 2025
@albi3ro albi3ro marked this pull request as ready for review January 9, 2025 22:32
Copy link

codecov bot commented Jan 9, 2025

Codecov Report

All modified and coverable lines are covered by tests ✅

Project coverage is 99.60%. Comparing base (807bc4c) to head (47bdf11).
Report is 2 commits behind head on master.

Additional details and impacted files
@@           Coverage Diff           @@
##           master    #6786   +/-   ##
=======================================
  Coverage   99.60%   99.60%           
=======================================
  Files         476      477    +1     
  Lines       45237    45282   +45     
=======================================
+ Hits        45060    45105   +45     
  Misses        177      177           

☔ View full report in Codecov by Sentry.
📢 Have feedback on the report? Share it here.

Comment on lines +9 to +11
* The higher order primitives in program capture can now accept inputs with abstract shapes.
[(#6786)](https://github.com/PennyLaneAI/pennylane/pull/6786)

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Don't forget to add your name to the changelog 😄

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

added. thanks for the reminder.

Copy link
Contributor

@lillian542 lillian542 left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Left a few minor comments, but I have questions about determine_abstract_axes 😅

I need some help understanding what is going on, it would probably be easiest just to chat.

import jax
```

Dynamic shapes are experimental feature of jax with limited support and feature coverage.
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
Dynamic shapes are experimental feature of jax with limited support and feature coverage.
Dynamic shapes are an experimental feature of jax with limited support and feature coverage.

Comment on lines +9 to +15


```python
jax.config.update("jax_dynamic_shapes", False)
```

Without this setup, we can't create arrays whose size depends on an abstract value.
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This really confused me for a moment because I read it as "without the above setup", and I thought it implied we need to run jax.config.update("jax_dynamic_shapes", False) in order for the dynamic arrays to work.

Suggested change
```python
jax.config.update("jax_dynamic_shapes", False)
```
Without this setup, we can't create arrays whose size depends on an abstract value.
Without the `"jax_dynamic_shapes"` feature, we can't create arrays whose size depends on an abstract value.
```python
jax.config.update("jax_dynamic_shapes", False)


Executing with `eval_jaxpr`:

No idea how to fix this right now.
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

😢

abstract_shapes.append(s)
abstracted_axes.append(tuple(l_shape) if len(l_shape) != 1 else l_shape[0]) # maybe ?
abstracted_axes = jax.tree_util.tree_unflatten(structure, abstracted_axes)
return abstracted_axes, abstract_shapes
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I'm a little lost on the alphabet part




We can now take these learnings a make custom higher order primitive that supports dynamically shaped inputs:
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
We can now take these learnings a make custom higher order primitive that supports dynamically shaped inputs:
We can now take these learnings to make a custom higher order primitive that supports dynamically shaped inputs:

1) A bit more difficult to read and follow
2) Relies on unstable componets of jax internals

But why let those concerns stop us now! Let's do it.
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

🤣

Comment on lines 732 to +733
*flat_args,
*abstract_shapes,
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Below we have *abstract_shapes, *flat_args in some of the primitive bind functions, and here it's reversed; it would be more consistent to have a set order.

has_jax = False # pragma: no cover


def determine_abstracted_axes(args):
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I'm having a hard time reviewing this function because I'm not confident that I understand what the expected behaviour is for various inputs. Maybe we could have a quick call/meeting and you could walk me through it?

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

3 participants