diff --git a/hugr-py/src/hugr/_hugr.py b/hugr-py/src/hugr/_hugr.py index 1cdffd808..e8fd828fc 100644 --- a/hugr-py/src/hugr/_hugr.py +++ b/hugr-py/src/hugr/_hugr.py @@ -17,9 +17,9 @@ from typing_extensions import Self from hugr.serialization.serial_hugr import SerialHugr -from hugr.serialization.ops import BaseOp, OpType as SerialOp -import hugr.serialization.ops as sops -from hugr.serialization.tys import Type +from hugr.serialization.ops import OpType as SerialOp +from hugr.serialization.tys import Type, FunctionType +from hugr._ops import Op, Input, Output, DFG from hugr.utils import BiMap @@ -43,6 +43,13 @@ class Wire(Protocol): def out_port(self) -> OutPort: ... +class Command(Protocol): + def op(self) -> Op: ... + def incoming(self) -> Iterable[Wire]: ... + def num_out(self) -> int | None: + return None + + @dataclass(frozen=True, eq=True, order=True) class OutPort(_Port, Wire): direction: ClassVar[Direction] = Direction.OUTGOING @@ -100,35 +107,6 @@ def port(self, offset: int, direction: Direction) -> InPort | OutPort: return self.out(offset) -class Op(Protocol): - def to_serial(self, node: Node, hugr: Hugr) -> SerialOp: ... - - @classmethod - def from_serial(cls, serial: SerialOp) -> Self: ... - - -T = TypeVar("T", bound=BaseOp) - - -@dataclass() -class DummyOp(Op, Generic[T]): - _serial_op: T - - def to_serial(self, node: Node, hugr: Hugr) -> SerialOp: - return SerialOp(root=self._serial_op.model_copy()) # type: ignore - - @classmethod - def from_serial(cls, serial: SerialOp) -> DummyOp: - return DummyOp(serial.root) - - -class Command(Protocol): - def op(self) -> Op: ... - def incoming(self) -> Iterable[Wire]: ... - def num_out(self) -> int | None: - return None - - @dataclass() class NodeData: op: Op @@ -138,8 +116,7 @@ class NodeData: # TODO children field? def to_serial(self, node: Node, hugr: Hugr) -> SerialOp: - o = self.op.to_serial(node, hugr) - o.root.parent = self.parent.idx if self.parent else node.idx + o = self.op.to_serial(node, self.parent if self.parent else node, hugr) return o @@ -371,7 +348,7 @@ def from_serial(cls, serial: SerialHugr) -> Hugr: hugr.root = Node(idx) parent = None serial_node.root.parent = -1 - hugr._nodes.append(NodeData(DummyOp.from_serial(serial_node), parent)) + hugr._nodes.append(NodeData(Op.from_serial(serial_node), parent)) for (src_node, src_offset), (dst_node, dst_offset) in serial.edges: if src_offset is None or dst_offset is None: @@ -395,35 +372,25 @@ def __init__( ) -> None: input_types = list(input_types) output_types = list(output_types) - root_op = DummyOp(sops.DFG(parent=-1)) - root_op._serial_op.signature.input = input_types - root_op._serial_op.signature.output = output_types + root_op = DFG(FunctionType(input=input_types, output=output_types)) self.hugr = Hugr(root_op) self.root = self.hugr.root self.input_node = self.hugr.add_node( - DummyOp(sops.Input(parent=0, types=input_types)), - self.root, - len(input_types), - ) - self.output_node = self.hugr.add_node( - DummyOp(sops.Output(parent=0, types=output_types)), self.root + Input(input_types), self.root, len(input_types) ) + self.output_node = self.hugr.add_node(Output(output_types), self.root) @classmethod def endo(cls, types: Sequence[Type]) -> Dfg: return Dfg(types, types) - def _input_op(self) -> DummyOp[sops.Input]: + def _input_op(self) -> Input: dop = self.hugr[self.input_node].op - assert isinstance(dop, DummyOp) - assert isinstance(dop._serial_op, sops.Input) + assert isinstance(dop, Input) return dop def inputs(self) -> list[OutPort]: - return [ - self.input_node.out(i) - for i in range(len(self._input_op()._serial_op.types)) - ] + return [self.input_node.out(i) for i in range(len(self._input_op().types))] def add_op(self, op: Op, /, *args: Wire, num_outs: int | None = None) -> Node: new_n = self.hugr.add_node(op, self.root, num_outs=num_outs) diff --git a/hugr-py/src/hugr/ops.py b/hugr-py/src/hugr/_ops.py similarity index 100% rename from hugr-py/src/hugr/ops.py rename to hugr-py/src/hugr/_ops.py diff --git a/hugr-py/tests/test_hugr_build.py b/hugr-py/tests/test_hugr_build.py index 522da06f6..3f1971083 100644 --- a/hugr-py/tests/test_hugr_build.py +++ b/hugr-py/tests/test_hugr_build.py @@ -3,10 +3,11 @@ import subprocess import os import pathlib -from hugr._hugr import Dfg, Hugr, DummyOp, Node, Command, Wire, Op +from hugr._hugr import Dfg, Hugr, Node, Command, Wire +from hugr._ops import Op, Custom +import hugr._ops as ops from hugr.serialization import SerialHugr import hugr.serialization.tys as stys -import hugr.serialization.ops as sops import pytest import json @@ -22,14 +23,11 @@ ) ) -NOT_OP = DummyOp( - # TODO get from YAML - sops.CustomOp( - parent=-1, - extension="logic", - op_name="Not", - signature=stys.FunctionType(input=[BOOL_T], output=[BOOL_T]), - ) +# TODO get from YAML +NOT_OP = Custom( + extension="logic", + op_name="Not", + signature=stys.FunctionType(input=[BOOL_T], output=[BOOL_T]), ) @@ -59,14 +57,11 @@ def num_out(self) -> int | None: return 2 def op(self) -> Op: - return DummyOp( - sops.CustomOp( - parent=-1, - extension="arithmetic.int", - op_name="idivmod_u", - signature=stys.FunctionType(input=[INT_T] * 2, output=[INT_T] * 2), - args=[ARG_5, ARG_5], - ) + return Custom( + extension="arithmetic.int", + op_name="idivmod_u", + signature=stys.FunctionType(input=[INT_T] * 2, output=[INT_T] * 2), + args=[ARG_5, ARG_5], ) @@ -82,7 +77,7 @@ def num_out(self) -> int | None: return 1 def op(self) -> Op: - return DummyOp(sops.MakeTuple(parent=-1, tys=self.types)) + return ops.MakeTuple(self.types) @dataclass @@ -97,7 +92,7 @@ def num_out(self) -> int | None: return len(self.types) def op(self) -> Op: - return DummyOp(sops.UnpackTuple(parent=-1, tys=self.types)) + return ops.UnpackTuple(self.types) def _validate(h: Hugr, mermaid: bool = False, roundtrip: bool = True): @@ -117,7 +112,7 @@ def _validate(h: Hugr, mermaid: bool = False, roundtrip: bool = True): def test_stable_indices(): - h = Hugr(DummyOp(sops.DFG(parent=-1))) + h = Hugr(ops.DFG()) nodes = [h.add_node(NOT_OP) for _ in range(3)] assert len(h) == 4 @@ -201,8 +196,8 @@ def test_tuple(): h1 = Dfg.endo(row) a, b = h1.inputs() - mt = h1.add_op(DummyOp(sops.MakeTuple(parent=-1, tys=row)), a, b) - a, b = h1.add_op(DummyOp(sops.UnpackTuple(parent=-1, tys=row)), mt)[0, 1] + mt = h1.add_op(ops.MakeTuple(row), a, b) + a, b = h1.add_op(ops.UnpackTuple(row), mt)[0, 1] h1.set_outputs(a, b) assert h.hugr.to_serial() == h1.hugr.to_serial() @@ -224,7 +219,7 @@ def test_insert(): assert len(h1.hugr) == 4 - new_h = Hugr(DummyOp(sops.DFG(parent=-1))) + new_h = Hugr(ops.DFG()) mapping = h1.hugr.insert_hugr(new_h, h1.hugr.root) assert mapping == {new_h.root: Node(4)}