Skip to content

Commit

Permalink
support output assignment for function calls (#70)
Browse files Browse the repository at this point in the history
* support return assignment for function calls

* update test

---------

Co-authored-by: Phil Reinhold <[email protected]>
  • Loading branch information
yitchen-tim and PhilReinhold authored Sep 25, 2023
1 parent 9bee4c7 commit f0885b7
Show file tree
Hide file tree
Showing 2 changed files with 35 additions and 5 deletions.
17 changes: 12 additions & 5 deletions oqpy/program.py
Original file line number Diff line number Diff line change
Expand Up @@ -432,11 +432,18 @@ def barrier(self, qubits_or_frames: Iterable[AstConvertible]) -> Program:
self._add_statement(ast.QuantumBarrier(ast_qubits_or_frames))
return self

def function_call(self, name: str, args: Iterable[AstConvertible]) -> None:
"""Add a function call."""
self._add_statement(
ast.ExpressionStatement(ast.FunctionCall(ast.Identifier(name), map_to_ast(self, args)))
)
def function_call(
self,
name: str,
args: Iterable[AstConvertible],
assigns_to: AstConvertible = None,
) -> None:
"""Add a function call with an optional output assignment."""
function_call_node = ast.FunctionCall(ast.Identifier(name), map_to_ast(self, args))
if assigns_to is None:
self.do_expression(function_call_node)
else:
self._do_assignment(to_ast(self, assigns_to), "=", function_call_node)

def play(self, frame: AstConvertible, waveform: AstConvertible) -> Program:
"""Play a waveform on a particular frame."""
Expand Down
23 changes: 23 additions & 0 deletions tests/test_directives.py
Original file line number Diff line number Diff line change
Expand Up @@ -2357,3 +2357,26 @@ def test_qubit_array():
prog_with_errors.gate(q0, "h")
with pytest.raises(ValueError):
prog_with_errors.to_qasm()


@pytest.mark.parametrize(
"args,assigns_to,expected",
[
([], None, "OPENQASM 3.0;\nmy_function();"),
(
[oqpy.BitVar(name="a0"), oqpy.BitVar(name="a1")],
None,
"OPENQASM 3.0;\nbit a0;\nbit a1;\nmy_function(a0, a1);",
),
(
[oqpy.BitVar(name="a0")],
oqpy.BitVar(name="b0"),
"OPENQASM 3.0;\nbit a0;\nbit b0;\nb0 = my_function(a0);",
),
],
)
def test_function_call(args, assigns_to, expected):
prog = Program()
prog.function_call("my_function", args, assigns_to)
assert prog.to_qasm() == expected
_check_respects_type_hints(prog)

0 comments on commit f0885b7

Please sign in to comment.