diff --git a/oqpy/classical_types.py b/oqpy/classical_types.py index 7bc216f..40ce8f6 100644 --- a/oqpy/classical_types.py +++ b/oqpy/classical_types.py @@ -39,6 +39,7 @@ __all__ = [ "pi", + "ArrayVar", "BoolVar", "IntVar", "UintVar", @@ -48,6 +49,7 @@ "ComplexVar", "DurationVar", "OQFunctionCall", + "OQIndexExpression", "StretchVar", "_ClassicalVar", "duration", @@ -272,18 +274,20 @@ class ComplexVar(_ClassicalVar): """An oqpy variable with bit type.""" type_cls = ast.ComplexType + base_type: ast.FloatType = float64 - def __class_getitem__(cls, item: Type[ast.FloatType]) -> Callable[..., ComplexVar]: + def __class_getitem__(cls, item: ast.FloatType) -> Callable[..., ComplexVar]: return functools.partial(cls, base_type=item) def __init__( self, init_expression: AstConvertible | None = None, *args: Any, - base_type: Type[ast.FloatType] = float64, + base_type: ast.FloatType = float64, **kwargs: Any, ) -> None: assert isinstance(base_type, ast.FloatType) + self.base_type = base_type if not isinstance(init_expression, (complex, type(None), OQPyExpression)): init_expression = complex(init_expression) # type: ignore[arg-type] @@ -313,6 +317,80 @@ class StretchVar(_ClassicalVar): type_cls = ast.StretchType +AllowedArrayTypes = Union[_SizedVar, DurationVar, BoolVar, ComplexVar] + + +class ArrayVar(_ClassicalVar): + """An oqpy array variable.""" + + type_cls = ast.ArrayType + dimensions: list[int] + base_type: type[AllowedArrayTypes] + + def __class_getitem__( + cls, item: tuple[type[AllowedArrayTypes], int] | type[AllowedArrayTypes] + ) -> Callable[..., ArrayVar]: + # Allows usage like ArrayVar[FloatVar, 32](...) or ArrayVar[FloatVar] + if isinstance(item, tuple): + base_type = item[0] + dimensions = list(item[1:]) + return functools.partial(cls, dimensions=dimensions, base_type=base_type) + else: + return functools.partial(cls, base_type=item) + + def __init__( + self, + *args: Any, + dimensions: list[int], + base_type: type[AllowedArrayTypes] = IntVar, + **kwargs: Any, + ) -> None: + self.dimensions = dimensions + self.base_type = base_type + + # Creating a dummy variable supports IntVar[64] etc. + base_type_instance = base_type() + if isinstance(base_type_instance, _SizedVar): + array_base_type = base_type_instance.type_cls( + size=ast.IntegerLiteral(base_type_instance.size) + ) + elif isinstance(base_type_instance, ComplexVar): + array_base_type = base_type_instance.type_cls(base_type=base_type_instance.base_type) + else: + array_base_type = base_type_instance.type_cls() + + # Automatically handle Duration array. + if base_type is DurationVar and kwargs["init_expression"]: + kwargs["init_expression"] = (make_duration(i) for i in kwargs["init_expression"]) + + super().__init__( + *args, + **kwargs, + dimensions=[ast.IntegerLiteral(dimension) for dimension in dimensions], + base_type=array_base_type, + ) + + def __getitem__(self, index: AstConvertible) -> OQIndexExpression: + return OQIndexExpression(collection=self, index=index) + + +class OQIndexExpression(OQPyExpression): + """An oqpy expression corresponding to an index expression.""" + + def __init__(self, collection: AstConvertible, index: AstConvertible): + self.collection = collection + self.index = index + + if isinstance(collection, ArrayVar): + self.type = collection.base_type().type_cls() + + def to_ast(self, program: Program) -> ast.IndexExpression: + """Converts this oqpy index expression into an ast node.""" + return ast.IndexExpression( + collection=to_ast(program, self.collection), index=[to_ast(program, self.index)] + ) + + class OQFunctionCall(OQPyExpression): """An oqpy expression corresponding to a function call.""" diff --git a/oqpy/program.py b/oqpy/program.py index f877cee..5434232 100644 --- a/oqpy/program.py +++ b/oqpy/program.py @@ -459,7 +459,11 @@ def _do_assignment(self, var: AstConvertible, op: str, value: AstConvertible) -> ) ) - def set(self, var: classical_types._ClassicalVar, value: AstConvertible) -> Program: + def set( + self, + var: classical_types._ClassicalVar | classical_types.OQIndexExpression, + value: AstConvertible, + ) -> Program: """Set a variable value.""" self._do_assignment(var, "=", value) return self diff --git a/tests/test_directives.py b/tests/test_directives.py index 5df84d5..92ad181 100644 --- a/tests/test_directives.py +++ b/tests/test_directives.py @@ -138,6 +138,87 @@ def test_complex_numbers_declaration(): assert prog.to_qasm() == expected +def test_array_declaration(): + b = ArrayVar(name="b", init_expression=[True, False], dimensions=[2], base_type=BoolVar) + i = ArrayVar(name="i", init_expression=[0, 1, 2, 3, 4], dimensions=[5], base_type=IntVar) + i55 = ArrayVar(name="i55", init_expression=[0, 1, 2, 3, 4], dimensions=[5], base_type=IntVar[55]) + u = ArrayVar(name="u", init_expression=[0, 1, 2, 3, 4], dimensions=[5], base_type=UintVar) + x = ArrayVar(name="x", init_expression=[0e-9, 1e-9, 2e-9], dimensions=[3], base_type=DurationVar) + y = ArrayVar(name="y", init_expression=[0.0, 1.0, 2.0, 3.0], dimensions=[4], base_type=FloatVar) + ang = ArrayVar(name="ang", init_expression=[0.0, 1.0, 2.0, 3.0], dimensions=[4], base_type=AngleVar) + comp = ArrayVar(name="comp", init_expression=[0, 1 + 1j], dimensions=[2], base_type=ComplexVar) + comp55 = ArrayVar(name="comp55", init_expression=[0, 1 + 1j], dimensions=[2], base_type=ComplexVar[float_(55)]) + ang_partial = ArrayVar[AngleVar, 2](name="ang_part", init_expression=[oqpy.pi, oqpy.pi/2]) + simple = ArrayVar[FloatVar](name="no_init", dimensions=[5]) + multidim = ArrayVar[FloatVar[32], 3, 2](name="multiDim", init_expression=[[1.1, 1.2], [2.1, 2.2], [3.1, 3.2]]) + + vars = [b, i, i55, u, x, y, ang, comp, comp55, ang_partial, simple, multidim] + + prog = oqpy.Program(version=None) + prog.declare(vars) + prog.set(i[1], 0) # Set with literal values + idx = IntVar(name="idx", init_expression=5) + val = IntVar(name="val", init_expression=10) + prog.set(i[idx], val) + + expected = textwrap.dedent( + """ + int[32] idx = 5; + int[32] val = 10; + array[bool, 2] b = {true, false}; + array[int[32], 5] i = {0, 1, 2, 3, 4}; + array[int[55], 5] i55 = {0, 1, 2, 3, 4}; + array[uint[32], 5] u = {0, 1, 2, 3, 4}; + array[duration, 3] x = {0.0ns, 1.0ns, 2.0ns}; + array[float[64], 4] y = {0.0, 1.0, 2.0, 3.0}; + array[angle[32], 4] ang = {0.0, 1.0, 2.0, 3.0}; + array[complex[float[64]], 2] comp = {0, 1.0 + 1.0im}; + array[complex[float[55]], 2] comp55 = {0, 1.0 + 1.0im}; + array[angle[32], 2] ang_part = {pi, pi / 2}; + array[float[64], 5] no_init; + array[float[32], 3, 2] multiDim = {{1.1, 1.2}, {2.1, 2.2}, {3.1, 3.2}}; + i[1] = 0; + i[idx] = val; + """ + ).strip() + + assert prog.to_qasm() == expected + +def test_non_trivial_array_access(): + prog = oqpy.Program() + port = oqpy.PortVar(name="my_port") + frame = oqpy.FrameVar(name="my_frame", port=port, frequency=1e9, phase=0) + + zero_to_one = oqpy.ArrayVar( + name='duration_array', + init_expression=[0.0, 0.25, 0.5, 0.75, 1], + dimensions=[5], + base_type=oqpy.DurationVar + ) + one_second = oqpy.DurationVar(init_expression=1, name="one_second") + + one = oqpy.IntVar(name="one", init_expression=1) + + with oqpy.ForIn(prog, range(4), "idx") as idx: + prog.delay(zero_to_one[idx + one] + one_second, frame) + prog.set(zero_to_one[idx], 5) + + expected = textwrap.dedent( + """ + OPENQASM 3.0; + port my_port; + array[duration, 5] duration_array = {0.0ns, 250000000.0ns, 500000000.0ns, 750000000.0ns, 1000000000.0ns}; + int[32] one = 1; + duration one_second = 1000000000.0ns; + frame my_frame = newframe(my_port, 1000000000.0, 0); + for int idx in [0:3] { + delay[duration_array[idx + one] + one_second] my_frame; + duration_array[idx] = 5; + } + """ + ).strip() + + assert prog.to_qasm() == expected def test_non_trivial_variable_declaration(): prog = Program() @@ -389,6 +470,26 @@ def test_for_in_var_types(): """ ).strip() + # Test indexing over an ArrayVar + program = oqpy.Program() + pyphases = [0] + [oqpy.pi / i for i in range(10, 1, -2)] + phases = ArrayVar(name="phases", dimensions=[len(pyphases)], init_expression=pyphases, base_type=AngleVar) + + with oqpy.ForIn(program, range(len(pyphases)), "idx") as idx: + program.shift_phase(phases[idx], frame) + + expected = textwrap.dedent( + """ + OPENQASM 3.0; + port my_port; + array[angle[32], 6] phases = {0, pi / 10, pi / 8, pi / 6, pi / 4, pi / 2}; + frame my_frame = newframe(my_port, 3000000000.0, 0); + for int idx in [0:5] { + shift_phase(phases[idx], my_frame); + } + """ + ).strip() + assert program.to_qasm() == expected