Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add support for arrays #28

Merged
merged 12 commits into from
Mar 21, 2023
82 changes: 80 additions & 2 deletions oqpy/classical_types.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,7 @@

__all__ = [
"pi",
"ArrayVar",
"BoolVar",
"IntVar",
"UintVar",
Expand All @@ -48,6 +49,7 @@
"ComplexVar",
"DurationVar",
"OQFunctionCall",
"OQIndexExpression",
"StretchVar",
"_ClassicalVar",
"duration",
Expand Down Expand Up @@ -272,18 +274,20 @@ class ComplexVar(_ClassicalVar):
"""An oqpy variable with bit type."""

type_cls = ast.ComplexType
base_type: ast.FloatType = float64

def __class_getitem__(cls, item: Type[ast.FloatType]) -> Callable[..., ComplexVar]:
def __class_getitem__(cls, item: ast.FloatType) -> Callable[..., ComplexVar]:
return functools.partial(cls, base_type=item)

def __init__(
self,
init_expression: AstConvertible | None = None,
*args: Any,
base_type: Type[ast.FloatType] = float64,
base_type: ast.FloatType = float64,
**kwargs: Any,
) -> None:
assert isinstance(base_type, ast.FloatType)
self.base_type = base_type

if not isinstance(init_expression, (complex, type(None), OQPyExpression)):
init_expression = complex(init_expression) # type: ignore[arg-type]
Expand Down Expand Up @@ -313,6 +317,80 @@ class StretchVar(_ClassicalVar):
type_cls = ast.StretchType


AllowedArrayTypes = Union[_SizedVar, DurationVar, BoolVar, ComplexVar]


class ArrayVar(_ClassicalVar):
"""An oqpy array variable."""

type_cls = ast.ArrayType
dimensions: list[int]
base_type: type[AllowedArrayTypes]

def __class_getitem__(
cls, item: tuple[type[AllowedArrayTypes], int] | type[AllowedArrayTypes]
) -> Callable[..., ArrayVar]:
# Allows usage like ArrayVar[FloatVar, 32](...) or ArrayVar[FloatVar]
if isinstance(item, tuple):
base_type = item[0]
dimensions = list(item[1:])
return functools.partial(cls, dimensions=dimensions, base_type=base_type)
else:
return functools.partial(cls, base_type=item)

def __init__(
self,
*args: Any,
dimensions: list[int],
base_type: type[AllowedArrayTypes] = IntVar,
**kwargs: Any,
) -> None:
self.dimensions = dimensions
self.base_type = base_type

# Creating a dummy variable supports IntVar[64] etc.
base_type_instance = base_type()
if isinstance(base_type_instance, _SizedVar):
array_base_type = base_type_instance.type_cls(
size=ast.IntegerLiteral(base_type_instance.size)
jcjaskula-aws marked this conversation as resolved.
Show resolved Hide resolved
)
elif isinstance(base_type_instance, ComplexVar):
array_base_type = base_type_instance.type_cls(base_type=base_type_instance.base_type)
else:
array_base_type = base_type_instance.type_cls()

# Automatically handle Duration array.
if base_type is DurationVar and kwargs["init_expression"]:
kwargs["init_expression"] = (make_duration(i) for i in kwargs["init_expression"])

super().__init__(
*args,
**kwargs,
dimensions=[ast.IntegerLiteral(dimension) for dimension in dimensions],
base_type=array_base_type,
)

def __getitem__(self, index: AstConvertible) -> OQIndexExpression:
jcjaskula-aws marked this conversation as resolved.
Show resolved Hide resolved
return OQIndexExpression(collection=self, index=index)


class OQIndexExpression(OQPyExpression):
"""An oqpy expression corresponding to an index expression."""

def __init__(self, collection: AstConvertible, index: AstConvertible):
self.collection = collection
self.index = index

if isinstance(collection, ArrayVar):
jcjaskula-aws marked this conversation as resolved.
Show resolved Hide resolved
self.type = collection.base_type().type_cls()

def to_ast(self, program: Program) -> ast.IndexExpression:
"""Converts this oqpy index expression into an ast node."""
return ast.IndexExpression(
collection=to_ast(program, self.collection), index=[to_ast(program, self.index)]
)


class OQFunctionCall(OQPyExpression):
"""An oqpy expression corresponding to a function call."""

Expand Down
6 changes: 5 additions & 1 deletion oqpy/program.py
Original file line number Diff line number Diff line change
Expand Up @@ -459,7 +459,11 @@ def _do_assignment(self, var: AstConvertible, op: str, value: AstConvertible) ->
)
)

def set(self, var: classical_types._ClassicalVar, value: AstConvertible) -> Program:
def set(
self,
var: classical_types._ClassicalVar | classical_types.OQIndexExpression,
value: AstConvertible,
) -> Program:
"""Set a variable value."""
self._do_assignment(var, "=", value)
return self
Expand Down
101 changes: 101 additions & 0 deletions tests/test_directives.py
Original file line number Diff line number Diff line change
Expand Up @@ -138,6 +138,87 @@ def test_complex_numbers_declaration():

assert prog.to_qasm() == expected

def test_array_declaration():
b = ArrayVar(name="b", init_expression=[True, False], dimensions=[2], base_type=BoolVar)
i = ArrayVar(name="i", init_expression=[0, 1, 2, 3, 4], dimensions=[5], base_type=IntVar)
i55 = ArrayVar(name="i55", init_expression=[0, 1, 2, 3, 4], dimensions=[5], base_type=IntVar[55])
u = ArrayVar(name="u", init_expression=[0, 1, 2, 3, 4], dimensions=[5], base_type=UintVar)
x = ArrayVar(name="x", init_expression=[0e-9, 1e-9, 2e-9], dimensions=[3], base_type=DurationVar)
y = ArrayVar(name="y", init_expression=[0.0, 1.0, 2.0, 3.0], dimensions=[4], base_type=FloatVar)
ang = ArrayVar(name="ang", init_expression=[0.0, 1.0, 2.0, 3.0], dimensions=[4], base_type=AngleVar)
comp = ArrayVar(name="comp", init_expression=[0, 1 + 1j], dimensions=[2], base_type=ComplexVar)
comp55 = ArrayVar(name="comp55", init_expression=[0, 1 + 1j], dimensions=[2], base_type=ComplexVar[float_(55)])
ang_partial = ArrayVar[AngleVar, 2](name="ang_part", init_expression=[oqpy.pi, oqpy.pi/2])
simple = ArrayVar[FloatVar](name="no_init", dimensions=[5])
multidim = ArrayVar[FloatVar[32], 3, 2](name="multiDim", init_expression=[[1.1, 1.2], [2.1, 2.2], [3.1, 3.2]])

vars = [b, i, i55, u, x, y, ang, comp, comp55, ang_partial, simple, multidim]

prog = oqpy.Program(version=None)
prog.declare(vars)
prog.set(i[1], 0) # Set with literal values
idx = IntVar(name="idx", init_expression=5)
val = IntVar(name="val", init_expression=10)
prog.set(i[idx], val)

expected = textwrap.dedent(
"""
int[32] idx = 5;
int[32] val = 10;
array[bool, 2] b = {true, false};
array[int[32], 5] i = {0, 1, 2, 3, 4};
array[int[55], 5] i55 = {0, 1, 2, 3, 4};
array[uint[32], 5] u = {0, 1, 2, 3, 4};
array[duration, 3] x = {0.0ns, 1.0ns, 2.0ns};
array[float[64], 4] y = {0.0, 1.0, 2.0, 3.0};
array[angle[32], 4] ang = {0.0, 1.0, 2.0, 3.0};
array[complex[float[64]], 2] comp = {0, 1.0 + 1.0im};
array[complex[float[55]], 2] comp55 = {0, 1.0 + 1.0im};
array[angle[32], 2] ang_part = {pi, pi / 2};
array[float[64], 5] no_init;
array[float[32], 3, 2] multiDim = {{1.1, 1.2}, {2.1, 2.2}, {3.1, 3.2}};
i[1] = 0;
i[idx] = val;
"""
).strip()

assert prog.to_qasm() == expected

anuragm marked this conversation as resolved.
Show resolved Hide resolved
def test_non_trivial_array_access():
prog = oqpy.Program()
port = oqpy.PortVar(name="my_port")
frame = oqpy.FrameVar(name="my_frame", port=port, frequency=1e9, phase=0)

zero_to_one = oqpy.ArrayVar(
name='duration_array',
init_expression=[0.0, 0.25, 0.5, 0.75, 1],
dimensions=[5],
base_type=oqpy.DurationVar
)
one_second = oqpy.DurationVar(init_expression=1, name="one_second")

one = oqpy.IntVar(name="one", init_expression=1)

with oqpy.ForIn(prog, range(4), "idx") as idx:
prog.delay(zero_to_one[idx + one] + one_second, frame)
prog.set(zero_to_one[idx], 5)

expected = textwrap.dedent(
"""
OPENQASM 3.0;
port my_port;
array[duration, 5] duration_array = {0.0ns, 250000000.0ns, 500000000.0ns, 750000000.0ns, 1000000000.0ns};
int[32] one = 1;
duration one_second = 1000000000.0ns;
frame my_frame = newframe(my_port, 1000000000.0, 0);
for int idx in [0:3] {
delay[duration_array[idx + one] + one_second] my_frame;
duration_array[idx] = 5;
}
"""
).strip()

assert prog.to_qasm() == expected

def test_non_trivial_variable_declaration():
prog = Program()
Expand Down Expand Up @@ -389,6 +470,26 @@ def test_for_in_var_types():
"""
).strip()

# Test indexing over an ArrayVar
program = oqpy.Program()
pyphases = [0] + [oqpy.pi / i for i in range(10, 1, -2)]
phases = ArrayVar(name="phases", dimensions=[len(pyphases)], init_expression=pyphases, base_type=AngleVar)

with oqpy.ForIn(program, range(len(pyphases)), "idx") as idx:
program.shift_phase(phases[idx], frame)

expected = textwrap.dedent(
"""
OPENQASM 3.0;
port my_port;
array[angle[32], 6] phases = {0, pi / 10, pi / 8, pi / 6, pi / 4, pi / 2};
frame my_frame = newframe(my_port, 3000000000.0, 0);
for int idx in [0:5] {
shift_phase(phases[idx], my_frame);
}
"""
).strip()

assert program.to_qasm() == expected


Expand Down