Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Capture] Allow higher order primitives to accept dynamically shaped arrays #6786

Open
wants to merge 14 commits into
base: master
Choose a base branch
from
Open
7 changes: 6 additions & 1 deletion doc/releases/changelog-dev.md
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,9 @@

<h3>Improvements 🛠</h3>

* The higher order primitives in program capture can now accept inputs with abstract shapes.
[(#6786)](https://github.com/PennyLaneAI/pennylane/pull/6786)

Comment on lines +9 to +11
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Don't forget to add your name to the changelog 😄

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

added. thanks for the reminder.

<h3>Breaking changes 💔</h3>

<h3>Deprecations 👋</h3>
Expand All @@ -20,4 +23,6 @@
<h3>Contributors ✍️</h3>

This release contains contributions from (in alphabetical order):
Diksha Dhawan

Diksha Dhawan
Christina Lee
1 change: 1 addition & 0 deletions pennylane/capture/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -170,6 +170,7 @@ def _(*args, **kwargs):
)
from .flatfn import FlatFn
from .make_plxpr import make_plxpr, run_autograph
from .dynamic_shapes import determine_abstracted_axes

# by defining this here, we avoid
# E0611: No name 'AbstractOperator' in module 'pennylane.capture' (no-name-in-module)
Expand Down
48 changes: 38 additions & 10 deletions pennylane/capture/base_interpreter.py
Original file line number Diff line number Diff line change
Expand Up @@ -385,12 +385,15 @@ def handle_ctrl_transform(self, *invals, n_control, jaxpr, control_values, work_


@PlxprInterpreter.register_primitive(for_loop_prim)
def handle_for_loop(self, start, stop, step, *args, jaxpr_body_fn, consts_slice, args_slice):
def handle_for_loop(
self, start, stop, step, *args, jaxpr_body_fn, consts_slice, args_slice, abstract_shapes_slice
):
"""Handle a for loop primitive."""
init_state = args[args_slice]
abstract_shapes = args[abstract_shapes_slice]

new_jaxpr_body_fn = jaxpr_to_jaxpr(
copy(self), jaxpr_body_fn, args[consts_slice], start, *init_state
copy(self), jaxpr_body_fn, args[consts_slice], *abstract_shapes, start, *init_state
)

return for_loop_prim.bind(
Expand All @@ -401,6 +404,7 @@ def handle_for_loop(self, start, stop, step, *args, jaxpr_body_fn, consts_slice,
jaxpr_body_fn=new_jaxpr_body_fn,
consts_slice=consts_slice,
args_slice=args_slice,
abstract_shapes_slice=abstract_shapes_slice,
)


Expand All @@ -424,15 +428,27 @@ def handle_cond(self, *invals, jaxpr_branches, consts_slices, args_slice):

@PlxprInterpreter.register_primitive(while_loop_prim)
def handle_while_loop(
self, *invals, jaxpr_body_fn, jaxpr_cond_fn, body_slice, cond_slice, args_slice
self,
*invals,
jaxpr_body_fn,
jaxpr_cond_fn,
body_slice,
cond_slice,
args_slice,
abstract_shapes_slice,
):
"""Handle a while loop primitive."""
consts_body = invals[body_slice]
consts_cond = invals[cond_slice]
init_state = invals[args_slice]
abstract_shapes = invals[abstract_shapes_slice]

new_jaxpr_body_fn = jaxpr_to_jaxpr(copy(self), jaxpr_body_fn, consts_body, *init_state)
new_jaxpr_cond_fn = jaxpr_to_jaxpr(copy(self), jaxpr_cond_fn, consts_cond, *init_state)
new_jaxpr_body_fn = jaxpr_to_jaxpr(
copy(self), jaxpr_body_fn, consts_body, *abstract_shapes, *init_state
)
new_jaxpr_cond_fn = jaxpr_to_jaxpr(
copy(self), jaxpr_cond_fn, consts_cond, *abstract_shapes, *init_state
)

return while_loop_prim.bind(
*invals,
Expand All @@ -441,6 +457,7 @@ def handle_while_loop(
body_slice=body_slice,
cond_slice=cond_slice,
args_slice=args_slice,
abstract_shapes_slice=abstract_shapes_slice,
)


Expand Down Expand Up @@ -482,16 +499,24 @@ def handle_jacobian(self, *invals, jaxpr, n_consts, **params):


def flatten_while_loop(
self, *invals, jaxpr_body_fn, jaxpr_cond_fn, body_slice, cond_slice, args_slice
self,
*invals,
jaxpr_body_fn,
jaxpr_cond_fn,
body_slice,
cond_slice,
args_slice,
abstract_shapes_slice,
):
"""Handle the while loop by a flattened python strategy."""
consts_body = invals[body_slice]
consts_cond = invals[cond_slice]
init_state = invals[args_slice]
abstract_shapes = invals[abstract_shapes_slice]

fn_res = init_state
while copy(self).eval(jaxpr_cond_fn, consts_cond, *fn_res)[0]:
fn_res = copy(self).eval(jaxpr_body_fn, consts_body, *fn_res)
while copy(self).eval(jaxpr_cond_fn, consts_cond, *abstract_shapes, *fn_res)[0]:
fn_res = copy(self).eval(jaxpr_body_fn, consts_body, *abstract_shapes, *fn_res)

return fn_res

Expand All @@ -515,14 +540,17 @@ def flattened_cond(self, *invals, jaxpr_branches, consts_slices, args_slice):
FlattenedHigherOrderPrimitives[cond_prim] = flattened_cond


def flattened_for(self, start, stop, step, *invals, jaxpr_body_fn, consts_slice, args_slice):
def flattened_for(
self, start, stop, step, *invals, jaxpr_body_fn, consts_slice, args_slice, abstract_shapes_slice
):
"""Handle the for loop by a flattened python strategy."""
consts = invals[consts_slice]
init_state = invals[args_slice]
abstract_shapes = invals[abstract_shapes_slice]

res = init_state
for i in range(start, stop, step):
res = copy(self).eval(jaxpr_body_fn, consts, i, *res)
res = copy(self).eval(jaxpr_body_fn, consts, *abstract_shapes, i, *res)

return res

Expand Down
75 changes: 75 additions & 0 deletions pennylane/capture/dynamic_shapes.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,75 @@
# Copyright 2025 Xanadu Quantum Technologies Inc.

# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at

# http://www.apache.org/licenses/LICENSE-2.0

# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
Contains a utility for handling inputs with dynamically shaped arrays.
"""
from string import ascii_lowercase

has_jax = True
try:
import jax
except ImportError: # pragma: no cover
has_jax = False # pragma: no cover


def determine_abstracted_axes(args):
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I'm having a hard time reviewing this function because I'm not confident that I understand what the expected behaviour is for various inputs. Maybe we could have a quick call/meeting and you could walk me through it?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I'll put in some more examples for how abstracted_axes work.

"""Computed the abstracted axes and extracing the abstract shapes from the arguments.

Args:
args (tuple): the arguments for a higher order primitive

Returns:
tuple, tuple: the corresponding abstracted axes and dynamic shapes

See the ``intro_to_dynamic_shapes.md`` document for more information on how dynamic shapes work.

To make jaxpr from arguments with dynamic shapes, the ``abstracted_axes`` keyword argument must be set.
Then, when calling the jaxpr, variables for the dynamic shapes must be passed.

```
def f(n):
x = jax.numpy.ones((n,))
abstracted_axes, abstract_shapes = qml.capture.determine_abstracted_axes((x,))
jaxpr = jax.make_jaxpr(jax.numpy.sum, abstracted_axes=abstracted_axes)(x)
return jax.core.eval_jaxpr(jaxpr.jaxpr, jaxpr.consts, *abstract_shapes, x)
```

"""
if not has_jax: # pragma: no cover
raise ImportError("jax must be installed to use determine_abstracted_axes")
if not jax.config.jax_dynamic_shapes:
return None, tuple()

args, structure = jax.tree_util.tree_flatten(args)
abstracted_axes = []
abstract_shapes = []
for l in args:
l_shape = {}
for i, s in enumerate(getattr(l, "shape", ())):
if not isinstance(s, int): # not abstract
found = False
for j, previous_shape in enumerate(abstract_shapes):
if s is previous_shape:
l_shape[i] = ascii_lowercase[j]
found = True
continue
if not found:
l_shape[i] = ascii_lowercase[len(abstract_shapes)]
abstract_shapes.append(s)
abstracted_axes.append(l_shape)

if not abstract_shapes:
return None, ()
abstracted_axes = jax.tree_util.tree_unflatten(structure, abstracted_axes)
return abstracted_axes, abstract_shapes
Loading
Loading