Skip to content

Commit

Permalink
decorator for easily registering operations
Browse files Browse the repository at this point in the history
  • Loading branch information
ss2165 committed Aug 12, 2024
1 parent 502f3de commit fa33e1f
Show file tree
Hide file tree
Showing 5 changed files with 87 additions and 57 deletions.
48 changes: 45 additions & 3 deletions hugr-py/src/hugr/ext.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,12 +3,12 @@
from __future__ import annotations

from dataclasses import dataclass, field
from typing import TYPE_CHECKING, Any
from typing import TYPE_CHECKING, Any, TypeVar

from semver import Version

import hugr.serialization.extension as ext_s
from hugr import tys, val
from hugr import ops, tys, val
from hugr.utils import ser_it

__all__ = [
Expand All @@ -25,7 +25,7 @@
]

if TYPE_CHECKING:
from collections.abc import Sequence
from collections.abc import Callable, Sequence

from hugr.hugr import Hugr
from hugr.tys import ExtensionId
Expand Down Expand Up @@ -151,6 +151,9 @@ def to_serial(self) -> ext_s.ExtensionValue:
)


T = TypeVar("T", bound=ops.RegisteredOp)


@dataclass
class Extension: # noqa: D101
name: ExtensionId
Expand Down Expand Up @@ -221,6 +224,45 @@ def get_value(self, name: str) -> ExtensionValue:
except KeyError as e:
raise self.ValueNotFound(name) from e

T = TypeVar("T", bound=ops.RegisteredOp)

def register_op(
self,
name: str | None = None,
signature: OpDefSig | tys.PolyFuncType | tys.FunctionType | None = None,
description: str | None = None,
misc: dict[str, Any] | None = None,
lower_funcs: list[FixedHugr] | None = None,
) -> Callable[[type[T]], type[T]]:
"""Register a class as corresponding to an operation definition.
If `name` is not provided, the class name is used.
If `signature` is not provided, a binary signature is assumed.
If `description` is not provided, the class docstring is used.
See :class:`OpDef` for other parameters.
"""
if not isinstance(signature, OpDefSig):
binary = signature is None
signature = OpDefSig(signature, binary)

def _inner(cls: type[T]) -> type[T]:
new_description = cls.__doc__ if description is None and cls.__doc__ else ""
new_name = cls.__name__ if name is None else name
op_def = self.add_op_def(
OpDef(
new_name,
signature,
new_description,
misc or {},
lower_funcs or [],
)
)
cls.const_op_def = op_def
return cls

return _inner


@dataclass
class ExtensionRegistry:
Expand Down
16 changes: 15 additions & 1 deletion hugr-py/src/hugr/ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,13 +9,14 @@
from typing_extensions import Self

import hugr.serialization.ops as sops
from hugr import ext, tys, val
from hugr import tys, val
from hugr.node_port import Direction, InPort, Node, OutPort, PortOffset, Wire
from hugr.utils import ser_it

if TYPE_CHECKING:
from collections.abc import Sequence

from hugr import ext
from hugr.serialization.ops import BaseOp


Expand Down Expand Up @@ -379,6 +380,19 @@ def outer_signature(self) -> tys.FunctionType:
return poly_func.body


class RegisteredOp(AsExtOp):
"""Base class for operations that are registered with an extension using
:meth:`Extension.register_op <hugr.ext.Extension.register_op>`.
"""

#: Known operation definition.
const_op_def: ext.OpDef # must be initialised by register_op

def op_def(self) -> ext.OpDef:
# override for AsExtOp.op_def
return self.const_op_def


@dataclass()
class MakeTuple(DataflowOp, _PartialOp):
"""Operation to create a tuple from a sequence of wires."""
Expand Down
26 changes: 9 additions & 17 deletions hugr-py/src/hugr/std/int.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@
from typing_extensions import Self

from hugr import ext, tys, val
from hugr.ops import AsExtOp, DataflowOp, ExtOp
from hugr.ops import AsExtOp, DataflowOp, ExtOp, RegisteredOp

if TYPE_CHECKING:
from hugr.ops import Command, ComWire
Expand Down Expand Up @@ -67,27 +67,19 @@ def to_value(self) -> val.Extension:

OPS_EXTENSION = ext.Extension("arithmetic.int", ext.Version(0, 1, 0))

_DivMod = OPS_EXTENSION.add_op_def(
ext.OpDef(
name="idivmod_u",
description="Unsigned integer division and modulo.",
signature=ext.OpDefSig(
tys.FunctionType([_int_tv(0), _int_tv(1)], [_int_tv(0), _int_tv(1)])
),
)
)


@OPS_EXTENSION.register_op(
signature=ext.OpDefSig(
tys.FunctionType([_int_tv(0), _int_tv(1)], [_int_tv(0), _int_tv(1)])
),
)
@dataclass(frozen=True)
class _DivModDef(AsExtOp):
class idivmod_u(RegisteredOp):
"""DivMod operation, has two inputs and two outputs."""

arg1: int = 5
arg2: int = 5

def op_def(self) -> ext.OpDef:
return _DivMod

def type_args(self) -> list[tys.TypeArg]:
return [tys.BoundedNatArg(n=self.arg1), tys.BoundedNatArg(n=self.arg2)]

Expand All @@ -97,7 +89,7 @@ def cached_signature(self) -> tys.FunctionType | None:

@classmethod
def from_ext(cls, custom: ExtOp) -> Self | None:
if custom.op_def() != _DivMod:
if custom.op_def() != cls.const_op_def:
return None
match custom.args:
case [tys.BoundedNatArg(n=a1), tys.BoundedNatArg(n=a2)]:
Expand All @@ -111,4 +103,4 @@ def __call__(self, a: ComWire, b: ComWire) -> Command:


#: DivMod operation.
DivMod = _DivModDef()
DivMod = idivmod_u()
19 changes: 7 additions & 12 deletions hugr-py/src/hugr/std/logic.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,27 +6,22 @@
from typing import TYPE_CHECKING

from hugr import ext, tys
from hugr.ops import AsExtOp, Command, DataflowOp
from hugr.ops import Command, DataflowOp, RegisteredOp

if TYPE_CHECKING:
from hugr.ops import ComWire


EXTENSION = ext.Extension("logic", ext.Version(0, 1, 0))

_NotDef = EXTENSION.add_op_def(
ext.OpDef(
name="Not",
description="Logical NOT operation.",
signature=ext.OpDefSig(tys.FunctionType.endo([tys.Bool])),
)
)


@EXTENSION.register_op(
name="Not",
signature=ext.OpDefSig(tys.FunctionType.endo([tys.Bool])),
)
@dataclass(frozen=True)
class _NotOp(AsExtOp):
def op_def(self) -> ext.OpDef:
return _NotDef
class _NotOp(RegisteredOp):
"""Logical NOT operation."""

def __call__(self, a: ComWire) -> Command:
return DataflowOp.__call__(self, a)
Expand Down
35 changes: 11 additions & 24 deletions hugr-py/tests/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@

from hugr import ext, tys
from hugr.hugr import Hugr
from hugr.ops import AsExtOp, Command, DataflowOp, ExtOp
from hugr.ops import AsExtOp, Command, DataflowOp, ExtOp, RegisteredOp
from hugr.serialization.serial_hugr import SerialHugr
from hugr.std.float import FLOAT_T

Expand All @@ -36,21 +36,6 @@
)
)

EXTENSION.add_op_def(
ext.OpDef(
name="Measure",
description="Measurement operation",
signature=ext.OpDefSig(tys.FunctionType([tys.Qubit], [tys.Qubit, tys.Bool])),
)
)

EXTENSION.add_op_def(
ext.OpDef(
name="Rz",
description="Rotation around the z-axis",
signature=ext.OpDefSig(tys.FunctionType([tys.Qubit, FLOAT_T], [tys.Qubit])),
)
)

E = TypeVar("E", bound=Enum)

Expand Down Expand Up @@ -106,23 +91,25 @@ def __call__(self, q0: ComWire, q1: ComWire) -> Command:
CX = TwoQbGate(TwoQbGate._Enum.CX)


@EXTENSION.register_op(
"Measure",
signature=tys.FunctionType([tys.Qubit], [tys.Qubit, tys.Bool]),
)
@dataclass(frozen=True)
class MeasureDef(AsExtOp):
def op_def(self) -> ext.OpDef:
return EXTENSION.operations["Measure"]

class MeasureDef(RegisteredOp):
def __call__(self, q: ComWire) -> Command:
return super().__call__(q)


Measure = MeasureDef()


@EXTENSION.register_op(
"Rz",
signature=tys.FunctionType([tys.Qubit, FLOAT_T], [tys.Qubit]),
)
@dataclass(frozen=True)
class RzDef(AsExtOp):
def op_def(self) -> ext.OpDef:
return EXTENSION.operations["Rz"]

class RzDef(RegisteredOp):
def __call__(self, q: ComWire, fl_wire: ComWire) -> Command:
return super().__call__(q, fl_wire)

Expand Down

0 comments on commit fa33e1f

Please sign in to comment.