Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat(hugr-py): builder ops separate from serialised ops #1140

Merged
merged 5 commits into from
Jun 4, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
66 changes: 13 additions & 53 deletions hugr-py/src/hugr/_hugr.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,9 +17,9 @@
from typing_extensions import Self

from hugr.serialization.serial_hugr import SerialHugr
from hugr.serialization.ops import BaseOp, OpType as SerialOp
import hugr.serialization.ops as sops
from hugr.serialization.tys import Type
from hugr.serialization.ops import OpType as SerialOp
from hugr.serialization.tys import Type, FunctionType
from hugr._ops import Op, Input, Output, DFG, Command
from hugr.utils import BiMap


Expand Down Expand Up @@ -101,35 +101,6 @@ def port(self, offset: int, direction: Direction) -> InPort | OutPort:
return self.out(offset)


class Op(Protocol):
def to_serial(self, node: Node, hugr: Hugr) -> SerialOp: ...

@classmethod
def from_serial(cls, serial: SerialOp) -> Self: ...


T = TypeVar("T", bound=BaseOp)


@dataclass()
class DummyOp(Op, Generic[T]):
_serial_op: T

def to_serial(self, node: Node, hugr: Hugr) -> SerialOp:
return SerialOp(root=self._serial_op.model_copy()) # type: ignore

@classmethod
def from_serial(cls, serial: SerialOp) -> DummyOp:
return DummyOp(serial.root)


class Command(Protocol):
def op(self) -> Op: ...
def incoming(self) -> Iterable[Wire]: ...
def num_out(self) -> int | None:
return None


@dataclass()
class NodeData:
op: Op
Expand All @@ -139,10 +110,9 @@ class NodeData:
# TODO children field?

def to_serial(self, node: Node, hugr: Hugr) -> SerialOp:
o = self.op.to_serial(node, hugr)
o.root.parent = self.parent.idx if self.parent else node.idx
o = self.op.to_serial(node, self.parent if self.parent else node, hugr)

return o
return SerialOp(root=o) # type: ignore[arg-type]


P = TypeVar("P", InPort, OutPort)
Expand Down Expand Up @@ -372,7 +342,7 @@ def from_serial(cls, serial: SerialHugr) -> Hugr:
hugr.root = Node(idx)
parent = None
serial_node.root.parent = -1
hugr._nodes.append(NodeData(DummyOp.from_serial(serial_node), parent))
hugr._nodes.append(NodeData(serial_node.root.deserialize(), parent))

for (src_node, src_offset), (dst_node, dst_offset) in serial.edges:
if src_offset is None or dst_offset is None:
Expand All @@ -396,43 +366,33 @@ def __init__(
) -> None:
input_types = list(input_types)
output_types = list(output_types)
root_op = DummyOp(sops.DFG(parent=-1))
root_op._serial_op.signature.input = input_types
root_op._serial_op.signature.output = output_types
root_op = DFG(FunctionType(input=input_types, output=output_types))
self.hugr = Hugr(root_op)
self.root = self.hugr.root
self.input_node = self.hugr.add_node(
DummyOp(sops.Input(parent=0, types=input_types)),
self.root,
len(input_types),
)
self.output_node = self.hugr.add_node(
DummyOp(sops.Output(parent=0, types=output_types)), self.root
Input(input_types), self.root, len(input_types)
)
self.output_node = self.hugr.add_node(Output(output_types), self.root)

@classmethod
def endo(cls, types: Sequence[Type]) -> Dfg:
return Dfg(types, types)

def _input_op(self) -> DummyOp[sops.Input]:
def _input_op(self) -> Input:
dop = self.hugr[self.input_node].op
assert isinstance(dop, DummyOp)
assert isinstance(dop._serial_op, sops.Input)
assert isinstance(dop, Input)
return dop

def inputs(self) -> list[OutPort]:
return [
self.input_node.out(i)
for i in range(len(self._input_op()._serial_op.types))
]
return [self.input_node.out(i) for i in range(len(self._input_op().types))]

def add_op(self, op: Op, /, *args: Wire, num_outs: int | None = None) -> Node:
new_n = self.hugr.add_node(op, self.root, num_outs=num_outs)
self._wire_up(new_n, args)
return new_n

def add(self, com: Command) -> Node:
return self.add_op(com.op(), *com.incoming(), num_outs=com.num_out())
return self.add_op(com.op, *com.incoming, num_outs=com.op.num_out)

def insert_nested(self, dfg: Dfg, *args: Wire) -> Node:
mapping = self.hugr.insert_hugr(dfg.hugr, self.root)
Expand Down
135 changes: 135 additions & 0 deletions hugr-py/src/hugr/_ops.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,135 @@
from __future__ import annotations

from dataclasses import dataclass, field
from typing import Generic, Protocol, TypeVar, TYPE_CHECKING
from hugr.serialization.ops import BaseOp
import hugr.serialization.ops as sops
import hugr.serialization.tys as tys

if TYPE_CHECKING:
from hugr._hugr import Hugr, Node, Wire

Check warning on line 10 in hugr-py/src/hugr/_ops.py

View check run for this annotation

Codecov / codecov/patch

hugr-py/src/hugr/_ops.py#L10

Added line #L10 was not covered by tests


class Op(Protocol):
@property
def num_out(self) -> int | None:
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is unset for many op definitions (Input, SerWrap, Custom, DFG)

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

it can't be known in general - but can be for many of those! (which I've added). Am unsure about Custom but I've added it for now. It should be left as the default None for SerWrap.

return None

Check warning on line 16 in hugr-py/src/hugr/_ops.py

View check run for this annotation

Codecov / codecov/patch

hugr-py/src/hugr/_ops.py#L16

Added line #L16 was not covered by tests

def to_serial(self, node: Node, parent: Node, hugr: Hugr) -> BaseOp: ...

def __call__(self, *args) -> Command:
return Command(self, list(args))


@dataclass(frozen=True)
class Command:
op: Op
incoming: list[Wire]


T = TypeVar("T", bound=BaseOp)


@dataclass()
class SerWrap(Op, Generic[T]):
# catch all for serial ops that don't have a corresponding Op class
_serial_op: T

def to_serial(self, node: Node, parent: Node, hugr: Hugr) -> T:
root = self._serial_op.model_copy()
root.parent = parent.idx
return root

Check warning on line 41 in hugr-py/src/hugr/_ops.py

View check run for this annotation

Codecov / codecov/patch

hugr-py/src/hugr/_ops.py#L39-L41

Added lines #L39 - L41 were not covered by tests


@dataclass()
class Input(Op):
types: list[tys.Type]

@property
def num_out(self) -> int | None:
return len(self.types)

Check warning on line 50 in hugr-py/src/hugr/_ops.py

View check run for this annotation

Codecov / codecov/patch

hugr-py/src/hugr/_ops.py#L50

Added line #L50 was not covered by tests

def to_serial(self, node: Node, parent: Node, hugr: Hugr) -> sops.Input:
return sops.Input(parent=parent.idx, types=self.types)

def __call__(self) -> Command:
return super().__call__()

Check warning on line 56 in hugr-py/src/hugr/_ops.py

View check run for this annotation

Codecov / codecov/patch

hugr-py/src/hugr/_ops.py#L56

Added line #L56 was not covered by tests


@dataclass()
class Output(Op):
types: list[tys.Type]

def to_serial(self, node: Node, parent: Node, hugr: Hugr) -> sops.Output:
return sops.Output(parent=parent.idx, types=self.types)


@dataclass()
class Custom(Op):
op_name: str
signature: tys.FunctionType = field(default_factory=tys.FunctionType.empty)
description: str = ""
extension: tys.ExtensionId = ""
args: list[tys.TypeArg] = field(default_factory=list)

@property
def num_out(self) -> int | None:
return len(self.signature.output)

Check warning on line 77 in hugr-py/src/hugr/_ops.py

View check run for this annotation

Codecov / codecov/patch

hugr-py/src/hugr/_ops.py#L77

Added line #L77 was not covered by tests

def to_serial(self, node: Node, parent: Node, hugr: Hugr) -> sops.CustomOp:
return sops.CustomOp(
parent=parent.idx,
extension=self.extension,
op_name=self.op_name,
signature=self.signature,
description=self.description,
args=self.args,
)


@dataclass()
class MakeTuple(Op):
types: list[tys.Type]
num_out: int | None = 1

def to_serial(self, node: Node, parent: Node, hugr: Hugr) -> sops.MakeTuple:
return sops.MakeTuple(
parent=parent.idx,
tys=self.types,
)

def __call__(self, *elements: Wire) -> Command:
return super().__call__(*elements)


@dataclass()
class UnpackTuple(Op):
types: list[tys.Type]

@property
def num_out(self) -> int | None:
return len(self.types)

def to_serial(self, node: Node, parent: Node, hugr: Hugr) -> sops.UnpackTuple:
return sops.UnpackTuple(
parent=parent.idx,
tys=self.types,
)

def __call__(self, tuple_: Wire) -> Command:
return super().__call__(tuple_)


@dataclass()
class DFG(Op):
signature: tys.FunctionType = field(default_factory=tys.FunctionType.empty)

@property
def num_out(self) -> int | None:
return len(self.signature.output)

Check warning on line 129 in hugr-py/src/hugr/_ops.py

View check run for this annotation

Codecov / codecov/patch

hugr-py/src/hugr/_ops.py#L129

Added line #L129 was not covered by tests

def to_serial(self, node: Node, parent: Node, hugr: Hugr) -> sops.DFG:
return sops.DFG(
parent=parent.idx,
signature=self.signature,
)
30 changes: 30 additions & 0 deletions hugr-py/src/hugr/serialization/ops.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
from __future__ import annotations
import inspect
import sys
from abc import ABC
Expand Down Expand Up @@ -41,6 +42,10 @@
"""Name of the op for visualisation"""
return self.__class__.__name__

def deserialize(self) -> _ops.Op:
"""Deserializes the model into the corresponding Op."""
return _ops.SerWrap(self)

Check warning on line 47 in hugr-py/src/hugr/serialization/ops.py

View check run for this annotation

Codecov / codecov/patch

hugr-py/src/hugr/serialization/ops.py#L47

Added line #L47 was not covered by tests


# ----------------------------------------------------------
# --------------- Module level operations ------------------
Expand Down Expand Up @@ -209,6 +214,9 @@
assert len(in_types) == 0
self.types = list(out_types)

def deserialize(self) -> _ops.Input:
return _ops.Input(types=self.types)


class Output(DataflowOp):
"""An output node. The inputs are the outputs of the function."""
Expand All @@ -220,6 +228,9 @@
assert len(out_types) == 0
self.types = list(in_types)

def deserialize(self) -> _ops.Output:
return _ops.Output(types=self.types)


class Call(DataflowOp):
"""
Expand Down Expand Up @@ -292,6 +303,9 @@
input=list(inputs), output=list(outputs), extension_reqs=ExtensionSet([])
)

def deserialize(self) -> _ops.DFG:
return _ops.DFG(self.signature)


# ------------------------------------------------
# --------------- ControlFlowOp ------------------
Expand Down Expand Up @@ -388,6 +402,14 @@
def display_name(self) -> str:
return self.op_name

def deserialize(self) -> _ops.Custom:
return _ops.Custom(
extension=self.extension,
op_name=self.op_name,
signature=self.signature,
args=self.args,
)

model_config = ConfigDict(
# Needed to avoid random '\n's in the pydantic description
json_schema_extra={
Expand Down Expand Up @@ -424,6 +446,9 @@
in_types = []
self.tys = list(in_types)

def deserialize(self) -> _ops.MakeTuple:
return _ops.MakeTuple(self.tys)


class UnpackTuple(DataflowOp):
"""An operation that packs all its inputs into a tuple."""
Expand All @@ -434,6 +459,9 @@
def insert_port_types(self, in_types: TypeRow, out_types: TypeRow) -> None:
self.tys = list(out_types)

def deserialize(self) -> _ops.UnpackTuple:
return _ops.UnpackTuple(self.tys)


class Tag(DataflowOp):
"""An operation that creates a tagged sum value from one of its variants."""
Expand Down Expand Up @@ -529,3 +557,5 @@
)

tys_model_rebuild(dict(classes))

from hugr import _ops # noqa: E402 # needed to avoid circular imports
Loading