Skip to content

Commit

Permalink
common up static input wiring
Browse files Browse the repository at this point in the history
  • Loading branch information
ss2165 committed Nov 4, 2024
1 parent 4f99254 commit 8a13c64
Showing 1 changed file with 20 additions and 11 deletions.
31 changes: 20 additions & 11 deletions hugr-py/src/hugr/build/dfg.py
Original file line number Diff line number Diff line change
Expand Up @@ -173,13 +173,19 @@ def inputs(self) -> list[OutPort]:
return [self.input_node.out(i) for i in range(len(self._input_op().types))]

def add_op(
self, op: ops.DataflowOp, /, *args: Wire, metadata: dict[str, Any] | None = None
self,
op: ops.DataflowOp,
/,
*args: Wire,
static_in: Iterable[Wire] | None = None,
metadata: dict[str, Any] | None = None,
) -> Node:
"""Add a dataflow operation to the graph, wiring in input ports.
Args:
op: The operation to add.
args: The input wires to the operation.
static_in: Any static input wires to the command.
metadata: Metadata to attach to the function definition. Defaults to None.
Returns:
Expand All @@ -191,7 +197,7 @@ def add_op(
Node(3)
"""
new_n = self.hugr.add_node(op, self.parent_node, metadata=metadata)
self._wire_up(new_n, args)
self._wire_up(new_n, args, static_in=static_in)

return replace(new_n, _num_out_ports=op.num_out)

Expand Down Expand Up @@ -225,15 +231,9 @@ def raise_no_ints():
wires = (
(w if not isinstance(w, int) else raise_no_ints()) for w in com.incoming
)
node = self.add_op(com.op, *wires, metadata=metadata)
return self.add_op(com.op, *wires, metadata=metadata, static_in=static_in)


# wire up static inputs
static_inputs = list(static_in or [])
dataflow_in = self.hugr.num_incoming(node)
for i, w in enumerate(static_inputs):
# static inputs always come after dataflow inputs
self.hugr.add_link(w.out_port(), node.inp(dataflow_in + i))
return node

def extend(self, *coms: ops.Command) -> list[Node]:
"""Add a series of commands to the DFG.
Expand Down Expand Up @@ -627,10 +627,19 @@ def _fn_sig(self, func: ToNode) -> tys.PolyFuncType:
raise ValueError(msg)
return signature

def _wire_up(self, node: Node, ports: Iterable[Wire]) -> tys.TypeRow:
def _wire_up(
self, node: Node, ports: Iterable[Wire], static_in: Iterable[Wire] | None = None
) -> tys.TypeRow:
tys = [self._wire_up_port(node, i, p) for i, p in enumerate(ports)]
if isinstance(op := self.hugr[node].op, ops._PartialOp):
op._set_in_types(tys)

# wire up static inputs
static_inputs = list(static_in or [])
dataflow_in = self.hugr.num_incoming(node)
for i, w in enumerate(static_inputs):
# static inputs always come after dataflow inputs
self.hugr.add_link(w.out_port(), node.inp(dataflow_in + i))
return tys

def _get_dataflow_type(self, wire: Wire) -> tys.Type:
Expand Down

0 comments on commit 8a13c64

Please sign in to comment.