-
Notifications
You must be signed in to change notification settings - Fork 617
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
[Capture] Allow higher order primitives to accept dynamically shaped arrays #6786
base: master
Are you sure you want to change the base?
Changes from all commits
54f7c67
b1303c1
1469ad2
a0c9176
2c59a64
8e6a167
5bc1a90
6fff950
b6a7eb2
4b04dae
0916fba
a5fb9ee
8604e5b
47bdf11
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,75 @@ | ||
# Copyright 2025 Xanadu Quantum Technologies Inc. | ||
|
||
# Licensed under the Apache License, Version 2.0 (the "License"); | ||
# you may not use this file except in compliance with the License. | ||
# You may obtain a copy of the License at | ||
|
||
# http://www.apache.org/licenses/LICENSE-2.0 | ||
|
||
# Unless required by applicable law or agreed to in writing, software | ||
# distributed under the License is distributed on an "AS IS" BASIS, | ||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
# See the License for the specific language governing permissions and | ||
# limitations under the License. | ||
""" | ||
Contains a utility for handling inputs with dynamically shaped arrays. | ||
""" | ||
from string import ascii_lowercase | ||
|
||
has_jax = True | ||
try: | ||
import jax | ||
except ImportError: # pragma: no cover | ||
has_jax = False # pragma: no cover | ||
|
||
|
||
def determine_abstracted_axes(args): | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I'm having a hard time reviewing this function because I'm not confident that I understand what the expected behaviour is for various inputs. Maybe we could have a quick call/meeting and you could walk me through it? There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I'll put in some more examples for how |
||
"""Computed the abstracted axes and extracing the abstract shapes from the arguments. | ||
|
||
Args: | ||
args (tuple): the arguments for a higher order primitive | ||
|
||
Returns: | ||
tuple, tuple: the corresponding abstracted axes and dynamic shapes | ||
|
||
See the ``intro_to_dynamic_shapes.md`` document for more information on how dynamic shapes work. | ||
|
||
To make jaxpr from arguments with dynamic shapes, the ``abstracted_axes`` keyword argument must be set. | ||
Then, when calling the jaxpr, variables for the dynamic shapes must be passed. | ||
|
||
``` | ||
def f(n): | ||
x = jax.numpy.ones((n,)) | ||
abstracted_axes, abstract_shapes = qml.capture.determine_abstracted_axes((x,)) | ||
jaxpr = jax.make_jaxpr(jax.numpy.sum, abstracted_axes=abstracted_axes)(x) | ||
return jax.core.eval_jaxpr(jaxpr.jaxpr, jaxpr.consts, *abstract_shapes, x) | ||
``` | ||
|
||
""" | ||
if not has_jax: # pragma: no cover | ||
raise ImportError("jax must be installed to use determine_abstracted_axes") | ||
if not jax.config.jax_dynamic_shapes: | ||
return None, tuple() | ||
|
||
args, structure = jax.tree_util.tree_flatten(args) | ||
abstracted_axes = [] | ||
abstract_shapes = [] | ||
for l in args: | ||
l_shape = {} | ||
for i, s in enumerate(getattr(l, "shape", ())): | ||
if not isinstance(s, int): # not abstract | ||
found = False | ||
for j, previous_shape in enumerate(abstract_shapes): | ||
if s is previous_shape: | ||
l_shape[i] = ascii_lowercase[j] | ||
found = True | ||
continue | ||
if not found: | ||
l_shape[i] = ascii_lowercase[len(abstract_shapes)] | ||
abstract_shapes.append(s) | ||
abstracted_axes.append(l_shape) | ||
|
||
if not abstract_shapes: | ||
return None, () | ||
abstracted_axes = jax.tree_util.tree_unflatten(structure, abstracted_axes) | ||
return abstracted_axes, abstract_shapes |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Don't forget to add your name to the changelog 😄
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
added. thanks for the reminder.