diff --git a/oqpy/subroutines.py b/oqpy/subroutines.py index fe0aa37..869b17b 100644 --- a/oqpy/subroutines.py +++ b/oqpy/subroutines.py @@ -97,6 +97,8 @@ def wrapper( for input_val in inputs.values(): inner_prog._mark_var_declared(input_val) output = func(inner_prog, **inputs) + inner_prog.autodeclare() + inner_prog._state.finalize_if_clause() body = inner_prog._state.body if isinstance(output, OQPyExpression): return_type = output.type @@ -115,6 +117,9 @@ def wrapper( raise ValueError( "Output type of subroutine {name} was neither oqpy expression nor None." ) + program.defcals.update(inner_prog.defcals) + program.subroutines.update(inner_prog.subroutines) + program.externs.update(inner_prog.externs) stmt = ast.SubroutineDefinition( identifier, arguments=arguments, diff --git a/tests/test_directives.py b/tests/test_directives.py index 4aa5bdc..360e63f 100644 --- a/tests/test_directives.py +++ b/tests/test_directives.py @@ -1999,3 +1999,41 @@ def test_io_declaration(): ).strip() assert prog.to_qasm() == expected _check_respects_type_hints(prog) + + +def test_nested_subroutines(): + @oqpy.subroutine + def f(prog: oqpy.Program) -> oqpy.IntVar: + i = oqpy.IntVar(name="i", init_expression=1) + with oqpy.If(prog, i == 1): + prog.increment(i, 1) + return i + + @oqpy.subroutine + def g(prog: oqpy.Program) -> oqpy.IntVar: + return f(prog) + + + prog = oqpy.Program() + x = oqpy.IntVar(name="x") + prog.set(x, g(prog)) + + expected = textwrap.dedent( + """ + OPENQASM 3.0; + def f() -> int[32] { + int[32] i = 1; + if (i == 1) { + i += 1; + } + return i; + } + def g() -> int[32] { + return f(); + } + int[32] x; + x = g(); + """ + ).strip() + + assert prog.to_qasm() == expected