Skip to content

Commit

Permalink
Merge branch 'master' into pause-capture-context
Browse files Browse the repository at this point in the history
  • Loading branch information
albi3ro authored Feb 4, 2025
2 parents a27311d + d1de032 commit a823f80
Show file tree
Hide file tree
Showing 5 changed files with 252 additions and 27 deletions.
7 changes: 7 additions & 0 deletions doc/releases/changelog-dev.md
Original file line number Diff line number Diff line change
Expand Up @@ -258,6 +258,13 @@

<h3>Bug fixes 🐛</h3>

* `qml.capture.PlxprInterpreter` now correctly handles propagation of constants when interpreting higher-order primitives
[(#6913)](https://github.com/PennyLaneAI/pennylane/pull/6913)

* `qml.capture.PlxprInterpreter` now uses `Primitive.get_bind_params` to resolve primitive calling signatures before binding
primitives.
[(#6913)](https://github.com/PennyLaneAI/pennylane/pull/6913)

* The interface is now detected from the data in the circuit, not the arguments to the `QNode`. This allows
interface data to be strictly passed as closure variables and still be detected.
[(#6892)](https://github.com/PennyLaneAI/pennylane/pull/6892)
Expand Down
87 changes: 63 additions & 24 deletions pennylane/capture/base_interpreter.py
Original file line number Diff line number Diff line change
Expand Up @@ -97,7 +97,7 @@ def jaxpr_to_jaxpr(

f = partial(interpreter.eval, jaxpr, consts)

return jax.make_jaxpr(f)(*args).jaxpr
return jax.make_jaxpr(f)(*args)


class PlxprInterpreter:
Expand Down Expand Up @@ -368,7 +368,8 @@ def eval(self, jaxpr: "jax.core.Jaxpr", consts: Sequence, *args) -> list:
outvals = self.interpret_measurement_eqn(eqn)
else:
invals = [self.read(invar) for invar in eqn.invars]
outvals = primitive.bind(*invals, **eqn.params)
subfuns, params = primitive.get_bind_params(eqn.params)
outvals = primitive.bind(*subfuns, *invals, **params)

if not primitive.multiple_results:
outvals = [outvals]
Expand Down Expand Up @@ -454,9 +455,11 @@ def handle_adjoint_transform(self, *invals, jaxpr, lazy, n_consts):
"""Interpret an adjoint transform primitive."""
consts = invals[:n_consts]
args = invals[n_consts:]

jaxpr = jaxpr_to_jaxpr(copy(self), jaxpr, consts, *args)
return adjoint_transform_prim.bind(*invals, jaxpr=jaxpr, lazy=lazy, n_consts=n_consts)

return adjoint_transform_prim.bind(
*jaxpr.consts, *args, jaxpr=jaxpr.jaxpr, lazy=lazy, n_consts=len(jaxpr.consts)
)


# pylint: disable=too-many-arguments
Expand All @@ -468,12 +471,14 @@ def handle_ctrl_transform(self, *invals, n_control, jaxpr, control_values, work_
jaxpr = jaxpr_to_jaxpr(copy(self), jaxpr, consts, *args)

return ctrl_transform_prim.bind(
*invals,
*jaxpr.consts,
*args,
*invals[-n_control:],
n_control=n_control,
jaxpr=jaxpr,
jaxpr=jaxpr.jaxpr,
control_values=control_values,
work_wires=work_wires,
n_consts=n_consts,
n_consts=len(jaxpr.consts),
)


Expand All @@ -482,19 +487,24 @@ def handle_for_loop(
self, start, stop, step, *args, jaxpr_body_fn, consts_slice, args_slice, abstract_shapes_slice
):
"""Handle a for loop primitive."""
consts = args[consts_slice]
init_state = args[args_slice]
abstract_shapes = args[abstract_shapes_slice]

new_jaxpr_body_fn = jaxpr_to_jaxpr(
copy(self), jaxpr_body_fn, args[consts_slice], *abstract_shapes, start, *init_state
copy(self), jaxpr_body_fn, consts, *abstract_shapes, start, *init_state
)

consts_slice = slice(0, len(new_jaxpr_body_fn.consts))
abstract_shapes_slice = slice(consts_slice.stop, consts_slice.stop + len(abstract_shapes))
args_slice = slice(abstract_shapes_slice.stop, None)
return for_loop_prim.bind(
start,
stop,
step,
*args,
jaxpr_body_fn=new_jaxpr_body_fn,
*new_jaxpr_body_fn.consts,
*abstract_shapes,
*init_state,
jaxpr_body_fn=new_jaxpr_body_fn.jaxpr,
consts_slice=consts_slice,
args_slice=args_slice,
abstract_shapes_slice=abstract_shapes_slice,
Expand All @@ -507,15 +517,30 @@ def handle_cond(self, *invals, jaxpr_branches, consts_slices, args_slice):
args = invals[args_slice]

new_jaxprs = []
new_consts = []
new_consts_slices = []
end_const_ind = len(jaxpr_branches)

for const_slice, jaxpr in zip(consts_slices, jaxpr_branches):
consts = invals[const_slice]
if jaxpr is None:
new_jaxprs.append(None)
new_consts_slices.append(slice(0, 0))
else:
new_jaxprs.append(jaxpr_to_jaxpr(copy(self), jaxpr, consts, *args))
new_jaxpr = jaxpr_to_jaxpr(copy(self), jaxpr, consts, *args)
new_jaxprs.append(new_jaxpr.jaxpr)
new_consts.extend(new_jaxpr.consts)
new_consts_slices.append(slice(end_const_ind, end_const_ind + len(new_jaxpr.consts)))
end_const_ind += len(new_jaxpr.consts)

new_args_slice = slice(end_const_ind, None)
return cond_prim.bind(
*invals, jaxpr_branches=new_jaxprs, consts_slices=consts_slices, args_slice=args_slice
*invals[: len(jaxpr_branches)],
*new_consts,
*args,
jaxpr_branches=new_jaxprs,
consts_slices=new_consts_slices,
args_slice=new_args_slice,
)


Expand Down Expand Up @@ -543,12 +568,20 @@ def handle_while_loop(
copy(self), jaxpr_cond_fn, consts_cond, *abstract_shapes, *init_state
)

body_consts = slice(0, len(new_jaxpr_body_fn.consts))
cond_consts = slice(body_consts.stop, body_consts.stop + len(new_jaxpr_cond_fn.consts))
abstract_shapes_slice = slice(cond_consts.stop, cond_consts.stop + len(abstract_shapes))
args_slice = slice(abstract_shapes_slice.stop, None)

return while_loop_prim.bind(
*invals,
jaxpr_body_fn=new_jaxpr_body_fn,
jaxpr_cond_fn=new_jaxpr_cond_fn,
body_slice=body_slice,
cond_slice=cond_slice,
*new_jaxpr_body_fn.consts,
*new_jaxpr_cond_fn.consts,
*abstract_shapes,
*init_state,
jaxpr_body_fn=new_jaxpr_body_fn.jaxpr,
jaxpr_cond_fn=new_jaxpr_cond_fn.jaxpr,
body_slice=body_consts,
cond_slice=cond_consts,
args_slice=args_slice,
abstract_shapes_slice=abstract_shapes_slice,
)
Expand All @@ -559,17 +592,19 @@ def handle_while_loop(
def handle_qnode(self, *invals, shots, qnode, device, qnode_kwargs, qfunc_jaxpr, n_consts):
"""Handle a qnode primitive."""
consts = invals[:n_consts]
args = invals[n_consts:]

new_qfunc_jaxpr = jaxpr_to_jaxpr(copy(self), qfunc_jaxpr, consts, *invals[n_consts:])
new_qfunc_jaxpr = jaxpr_to_jaxpr(copy(self), qfunc_jaxpr, consts, *args)

return qnode_prim.bind(
*invals,
*new_qfunc_jaxpr.consts,
*args,
shots=shots,
qnode=qnode,
device=device,
qnode_kwargs=qnode_kwargs,
qfunc_jaxpr=new_qfunc_jaxpr,
n_consts=n_consts,
qfunc_jaxpr=new_qfunc_jaxpr.jaxpr,
n_consts=len(new_qfunc_jaxpr.consts),
)


Expand All @@ -579,7 +614,9 @@ def handle_grad(self, *invals, jaxpr, n_consts, **params):
consts = invals[:n_consts]
args = invals[n_consts:]
new_jaxpr = jaxpr_to_jaxpr(copy(self), jaxpr, consts, *args)
return grad_prim.bind(*invals, jaxpr=new_jaxpr, n_consts=n_consts, **params)
return grad_prim.bind(
*new_jaxpr.consts, *args, jaxpr=new_jaxpr.jaxpr, n_consts=len(new_jaxpr.consts), **params
)


@PlxprInterpreter.register_primitive(jacobian_prim)
Expand All @@ -588,7 +625,9 @@ def handle_jacobian(self, *invals, jaxpr, n_consts, **params):
consts = invals[:n_consts]
args = invals[n_consts:]
new_jaxpr = jaxpr_to_jaxpr(copy(self), jaxpr, consts, *args)
return jacobian_prim.bind(*invals, jaxpr=new_jaxpr, n_consts=n_consts, **params)
return jacobian_prim.bind(
*new_jaxpr.consts, *args, jaxpr=new_jaxpr.jaxpr, n_consts=len(new_jaxpr.consts), **params
)


def flatten_while_loop(
Expand Down
3 changes: 2 additions & 1 deletion pennylane/transforms/optimization/cancel_inverses.py
Original file line number Diff line number Diff line change
Expand Up @@ -215,7 +215,8 @@ def eval(self, jaxpr: "jax.core.Jaxpr", consts: list, *args) -> list:
if getattr(eqn.primitive, "prim_type", "") == "transform":
self.interpret_all_previous_ops()
invals = [self.read(invar) for invar in eqn.invars]
outvals = eqn.primitive.bind(*invals, **eqn.params)
subfuns, params = eqn.primitive.get_bind_params(eqn.params)
outvals = eqn.primitive.bind(*subfuns, *invals, **params)

if not eqn.primitive.multiple_results:
outvals = [outvals]
Expand Down
Loading

0 comments on commit a823f80

Please sign in to comment.