From 19978a1234ddd7515e82d5a8a449b65a86f5f9ae Mon Sep 17 00:00:00 2001 From: Jean-Christophe Jaskula Date: Fri, 11 Aug 2023 11:44:45 -0400 Subject: [PATCH] Handle cases with generators --- oqpy/program.py | 18 +++++++++++------- tests/test_directives.py | 9 +++++++++ 2 files changed, 20 insertions(+), 7 deletions(-) diff --git a/oqpy/program.py b/oqpy/program.py index 9323f90..42fd6df 100644 --- a/oqpy/program.py +++ b/oqpy/program.py @@ -389,25 +389,29 @@ def delay( qubits_or_frames: AstConvertible | Iterable[AstConvertible] | None = None, ) -> Program: """Apply a delay to a set of qubits or frames.""" - if isinstance(qubits_or_frames, Iterable) and not any(True for _ in qubits_or_frames): - return self - elif qubits_or_frames is None: + ast_duration = to_ast(self, convert_float_to_duration(time)) + + if qubits_or_frames is None: ast_qubits_or_frames = [] else: if not isinstance(qubits_or_frames, Iterable): qubits_or_frames = [qubits_or_frames] + else: + qubits_or_frames = list(qubits_or_frames) + if len(qubits_or_frames) == 0: + return self ast_qubits_or_frames = map_to_ast(self, qubits_or_frames) - ast_duration = to_ast(self, convert_float_to_duration(time)) self._add_statement(ast.DelayInstruction(ast_duration, ast_qubits_or_frames)) return self def barrier(self, qubits_or_frames: Iterable[AstConvertible] | None = None) -> Program: """Apply a barrier to a set of qubits or frames.""" - if isinstance(qubits_or_frames, Iterable) and not any(True for _ in qubits_or_frames): - return self - elif qubits_or_frames is None: + if qubits_or_frames is None: ast_qubits_or_frames = [] else: + qubits_or_frames = list(qubits_or_frames) + if len(qubits_or_frames) == 0: + return self ast_qubits_or_frames = map_to_ast(self, qubits_or_frames) self._add_statement(ast.QuantumBarrier(ast_qubits_or_frames)) return self diff --git a/tests/test_directives.py b/tests/test_directives.py index 08c757e..3c30e44 100644 --- a/tests/test_directives.py +++ b/tests/test_directives.py @@ -875,6 +875,13 @@ def test_barrier_delay_arguments(): prog.delay(3e-7, frame1) prog.delay(4e-7, [frame, frame1]) + def frame_generator(frames): + for frame in frames: + yield frame + + prog.barrier(frame_generator([frame, frame1])) + prog.delay(5e-7, frame_generator([frame, frame1])) + expected = textwrap.dedent( """ OPENQASM 3.0; @@ -886,6 +893,8 @@ def test_barrier_delay_arguments(): barrier frame0, frame1; delay[300.0ns] frame1; delay[400.0ns] frame0, frame1; + barrier frame0, frame1; + delay[500.0ns] frame0, frame1; """ ).strip()