Skip to content

Commit

Permalink
implemented reset for catalyst backend
Browse files Browse the repository at this point in the history
  • Loading branch information
positr0nium committed Dec 30, 2024
1 parent d281b86 commit a4b72f0
Show file tree
Hide file tree
Showing 2 changed files with 54 additions and 4 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -20,15 +20,16 @@

from jax import make_jaxpr, jit
from jax.core import ClosedJaxpr
from jax.lax import fori_loop
from jax.lax import fori_loop, cond, while_loop
from jax._src.linear_util import wrap_init
import jax.numpy as jnp

from catalyst.jax_primitives import (AbstractQreg, qinst_p, qmeasure_p,
qextract_p, qinsert_p, while_p, cond_p, func_p)

from qrisp.jasp import (QuantumPrimitive, OperationPrimitive, AbstractQuantumCircuit, AbstractQubitArray,
AbstractQubit, eval_jaxpr, Jaspr, extract_invalues, insert_outvalues)
AbstractQubit, eval_jaxpr, Jaspr, extract_invalues, insert_outvalues, Measurement_p, get_qubit_p,
get_size_p, delete_qubits_p)


# Name translator from Qrisp gate naming to Catalyst gate naming
Expand Down Expand Up @@ -98,6 +99,8 @@ def catalyst_eqn_evaluator(eqn, context_dic):
elif eqn.primitive.name == "jasp.delete_qubits":
# Not available in Catalyst
context_dic[outvars[0]] = context_dic[invars[0]]
elif eqn.primitive.name == "jasp.reset":
process_reset(eqn, context_dic)
elif isinstance(eqn.primitive, OperationPrimitive):
process_op(eqn.primitive, invars, outvars, context_dic)
else:
Expand Down Expand Up @@ -417,7 +420,6 @@ def process_cond(eqn, context_dic):
else:
unflattened_outvalues.append(outvalues.pop(0))


insert_outvalues(eqn, context_dic, unflattened_outvalues)

@lru_cache(maxsize = int(1E5))
Expand Down Expand Up @@ -464,4 +466,53 @@ def process_pjit(eqn, context_dic):
unflattened_outvalues.append(outvalues.pop(0))

insert_outvalues(eqn, context_dic, unflattened_outvalues)

# Function to reset and delete a qubit array
def reset_qubit_array(abs_qc, qb_array):
from qrisp.circuit import XGate

def body_func(arg_tuple):

abs_qc, qb_array, i = arg_tuple

abs_qb = get_qubit_p.bind(qb_array, i)
abs_qc, meas_bl = Measurement_p.bind(abs_qc, abs_qb)

def true_fun(arg_tuple):
abs_qc, qb = arg_tuple
abs_qc = OperationPrimitive(XGate()).bind(abs_qc, qb)
return (abs_qc, qb)

def false_fun(arg_tuple):
return arg_tuple

abs_qc, qb = cond(meas_bl, true_fun, false_fun, (abs_qc, abs_qb))

i += 1

return (abs_qc, qb_array, i)

def cond_fun(arg_tuple):
return arg_tuple[-1] < get_size_p.bind(arg_tuple[1])


abs_qc, qb_array, i = while_loop(cond_fun,
body_func,
(abs_qc, qb_array, jnp.array(0, dtype = jnp.int32))
)

abs_qc = delete_qubits_p.bind(abs_qc, qb_array)

return abs_qc

reset_jaxpr = make_jaxpr(reset_qubit_array)(AbstractQuantumCircuit(), AbstractQubitArray())

def process_reset(eqn, context_dic):

invalues = extract_invalues(eqn, context_dic)
outvalues = eval_jaxpr(reset_jaxpr.jaxpr, eqn_evaluator = catalyst_eqn_evaluator)(*invalues)
insert_outvalues(eqn, context_dic, outvalues)




1 change: 0 additions & 1 deletion src/qrisp/jasp/rus.py
Original file line number Diff line number Diff line change
Expand Up @@ -212,7 +212,6 @@ def extract_boolean_digit(integer, digit):
@jax.jit
def reset_qubit_array(abs_qc, qb_array):


def body_func(arg_tuple):

abs_qc, qb_array, i = arg_tuple
Expand Down

0 comments on commit a4b72f0

Please sign in to comment.