Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merge latest update from openqasm:main #1

Merged
merged 5 commits into from
May 2, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion CODE_OF_CONDUCT.md
Original file line number Diff line number Diff line change
Expand Up @@ -2,4 +2,4 @@

All members of this project agree to adhere to the [Qiskit Code of Conduct][qiskit-coc].

[qiskit-coc]: https://github.com/Qiskit/qiskit/blob/main/CODE_OF_CONDUCT.md
[qiskit-coc]: https://github.com/Qiskit/qiskit/blob/master/CODE_OF_CONDUCT.md
43 changes: 43 additions & 0 deletions oqpy/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -107,6 +107,36 @@ def __pow__(self, other: AstConvertible) -> OQPyBinaryExpression:
def __rpow__(self, other: AstConvertible) -> OQPyBinaryExpression:
return self._to_binary("**", other, self)

def __lshift__(self, other: AstConvertible) -> OQPyBinaryExpression:
return self._to_binary("<<", self, other)

def __rlshift__(self, other: AstConvertible) -> OQPyBinaryExpression:
return self._to_binary("<<", other, self)

def __rshift__(self, other: AstConvertible) -> OQPyBinaryExpression:
return self._to_binary(">>", self, other)

def __rrshift__(self, other: AstConvertible) -> OQPyBinaryExpression:
return self._to_binary(">>", other, self)

def __and__(self, other: AstConvertible) -> OQPyBinaryExpression:
return self._to_binary("&", self, other)

def __rand__(self, other: AstConvertible) -> OQPyBinaryExpression:
return self._to_binary("&", other, self)

def __or__(self, other: AstConvertible) -> OQPyBinaryExpression:
return self._to_binary("|", self, other)

def __ror__(self, other: AstConvertible) -> OQPyBinaryExpression:
return self._to_binary("|", other, self)

def __xor__(self, other: AstConvertible) -> OQPyBinaryExpression:
return self._to_binary("^", self, other)

def __rxor__(self, other: AstConvertible) -> OQPyBinaryExpression:
return self._to_binary("^", other, self)

def __eq__(self, other: AstConvertible) -> OQPyBinaryExpression: # type: ignore[override]
return self._to_binary("==", self, other)

Expand All @@ -125,13 +155,26 @@ def __ge__(self, other: AstConvertible) -> OQPyBinaryExpression:
def __le__(self, other: AstConvertible) -> OQPyBinaryExpression:
return self._to_binary("<=", self, other)

def __invert__(self) -> OQPyUnaryExpression:
return self._to_unary("~", self)

def __bool__(self) -> bool:
raise RuntimeError(
"OQPy expressions cannot be converted to bool. This can occur if you try to check "
"the equality of expressions using == instead of expr_matches."
)


def logical_and(first: AstConvertible, second: AstConvertible) -> OQPyBinaryExpression:
"""Logical AND."""
return OQPyBinaryExpression(ast.BinaryOperator["&&"], first, second)


def logical_or(first: AstConvertible, second: AstConvertible) -> OQPyBinaryExpression:
"""Logical OR."""
return OQPyBinaryExpression(ast.BinaryOperator["||"], first, second)


def expr_matches(a: Any, b: Any) -> bool:
"""Check equality of the given objects.

Expand Down
82 changes: 80 additions & 2 deletions oqpy/classical_types.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,7 @@

__all__ = [
"pi",
"ArrayVar",
"BoolVar",
"IntVar",
"UintVar",
Expand All @@ -48,6 +49,7 @@
"ComplexVar",
"DurationVar",
"OQFunctionCall",
"OQIndexExpression",
"StretchVar",
"_ClassicalVar",
"duration",
Expand Down Expand Up @@ -272,18 +274,20 @@ class ComplexVar(_ClassicalVar):
"""An oqpy variable with bit type."""

type_cls = ast.ComplexType
base_type: ast.FloatType = float64

def __class_getitem__(cls, item: Type[ast.FloatType]) -> Callable[..., ComplexVar]:
def __class_getitem__(cls, item: ast.FloatType) -> Callable[..., ComplexVar]:
return functools.partial(cls, base_type=item)

def __init__(
self,
init_expression: AstConvertible | None = None,
*args: Any,
base_type: Type[ast.FloatType] = float64,
base_type: ast.FloatType = float64,
**kwargs: Any,
) -> None:
assert isinstance(base_type, ast.FloatType)
self.base_type = base_type

if not isinstance(init_expression, (complex, type(None), OQPyExpression)):
init_expression = complex(init_expression) # type: ignore[arg-type]
Expand Down Expand Up @@ -313,6 +317,80 @@ class StretchVar(_ClassicalVar):
type_cls = ast.StretchType


AllowedArrayTypes = Union[_SizedVar, DurationVar, BoolVar, ComplexVar]


class ArrayVar(_ClassicalVar):
"""An oqpy array variable."""

type_cls = ast.ArrayType
dimensions: list[int]
base_type: type[AllowedArrayTypes]

def __class_getitem__(
cls, item: tuple[type[AllowedArrayTypes], int] | type[AllowedArrayTypes]
) -> Callable[..., ArrayVar]:
# Allows usage like ArrayVar[FloatVar, 32](...) or ArrayVar[FloatVar]
if isinstance(item, tuple):
base_type = item[0]
dimensions = list(item[1:])
return functools.partial(cls, dimensions=dimensions, base_type=base_type)
else:
return functools.partial(cls, base_type=item)

def __init__(
self,
*args: Any,
dimensions: list[int],
base_type: type[AllowedArrayTypes] = IntVar,
**kwargs: Any,
) -> None:
self.dimensions = dimensions
self.base_type = base_type

# Creating a dummy variable supports IntVar[64] etc.
base_type_instance = base_type()
if isinstance(base_type_instance, _SizedVar):
array_base_type = base_type_instance.type_cls(
size=ast.IntegerLiteral(base_type_instance.size)
)
elif isinstance(base_type_instance, ComplexVar):
array_base_type = base_type_instance.type_cls(base_type=base_type_instance.base_type)
else:
array_base_type = base_type_instance.type_cls()

# Automatically handle Duration array.
if base_type is DurationVar and kwargs["init_expression"]:
kwargs["init_expression"] = (make_duration(i) for i in kwargs["init_expression"])

super().__init__(
*args,
**kwargs,
dimensions=[ast.IntegerLiteral(dimension) for dimension in dimensions],
base_type=array_base_type,
)

def __getitem__(self, index: AstConvertible) -> OQIndexExpression:
return OQIndexExpression(collection=self, index=index)


class OQIndexExpression(OQPyExpression):
"""An oqpy expression corresponding to an index expression."""

def __init__(self, collection: AstConvertible, index: AstConvertible):
self.collection = collection
self.index = index

if isinstance(collection, ArrayVar):
self.type = collection.base_type().type_cls()

def to_ast(self, program: Program) -> ast.IndexExpression:
"""Converts this oqpy index expression into an ast node."""
return ast.IndexExpression(
collection=to_ast(program, self.collection), index=[to_ast(program, self.index)]
)


class OQFunctionCall(OQPyExpression):
"""An oqpy expression corresponding to a function call."""

Expand Down
6 changes: 5 additions & 1 deletion oqpy/program.py
Original file line number Diff line number Diff line change
Expand Up @@ -459,7 +459,11 @@ def _do_assignment(self, var: AstConvertible, op: str, value: AstConvertible) ->
)
)

def set(self, var: classical_types._ClassicalVar, value: AstConvertible) -> Program:
def set(
self,
var: classical_types._ClassicalVar | classical_types.OQIndexExpression,
value: AstConvertible,
) -> Program:
"""Set a variable value."""
self._do_assignment(var, "=", value)
return self
Expand Down
Loading