From d554072d0266a7a584ef1c03e6fd78c9d4167933 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Agust=C3=ADn=20Borgna?= <121866228+aborgna-q@users.noreply.github.com> Date: Fri, 2 Aug 2024 17:02:44 +0100 Subject: [PATCH] feat(hugr-py): Allow defining functions, consts, and aliases inside DFGs (#1394) Move `define_function` and `add_alias_defn` from `Module` to a common root for both it and `DFG`. Move both `add_const` definitions to that common base. I had to move `Function` to `dfg.py` due to circular deps, but I added a reexport to avoid breaking changes. --- hugr-py/src/hugr/dfg.py | 99 +++++++++++++++++++++++++++++------- hugr-py/src/hugr/function.py | 78 +++------------------------- 2 files changed, 88 insertions(+), 89 deletions(-) diff --git a/hugr-py/src/hugr/dfg.py b/hugr-py/src/hugr/dfg.py index 5882eb4d9..141e2e15f 100644 --- a/hugr-py/src/hugr/dfg.py +++ b/hugr-py/src/hugr/dfg.py @@ -6,6 +6,7 @@ from dataclasses import dataclass, field, replace from typing import ( TYPE_CHECKING, + Generic, TypeVar, ) @@ -22,13 +23,66 @@ from .cfg import Cfg from .cond_loop import Conditional, If, TailLoop from .node_port import Node, OutPort, PortOffset, ToNode, Wire + from .tys import Type, TypeParam, TypeRow + +OpVar = TypeVar("OpVar", bound=ops.Op) + + +@dataclass() +class _DefinitionBuilder(Generic[OpVar]): + """Base class for builders that can define functions, constants, and aliases. + + As this class may be a root node, it does not extend `ParentBuilder`. + """ + + hugr: Hugr[OpVar] + + def define_function( + self, + name: str, + input_types: TypeRow, + type_params: list[TypeParam] | None = None, + ) -> Function: + """Start building a function definition in the graph. + + Args: + name: The name of the function. + input_types: The input types for the function. + type_params: The type parameters for the function, if polymorphic. + + Returns: + The new function builder. + """ + parent_op = ops.FuncDefn(name, input_types, type_params or []) + return Function.new_nested(parent_op, self.hugr) + + def add_const(self, value: val.Value) -> Node: + """Add a static constant to the graph. + + Args: + value: The constant value to add. + + Returns: + The node holding the :class:`Const ` operation. + + Example: + >>> dfg = Dfg() + >>> const_n = dfg.add_const(val.TRUE) + >>> dfg.hugr[const_n].op + Const(TRUE) + """ + return self.hugr.add_node(ops.Const(value), self.hugr.root) + + def add_alias_defn(self, name: str, ty: Type) -> Node: + """Add a type alias definition.""" + return self.hugr.add_node(ops.AliasDefn(name, ty), self.hugr.root) DP = TypeVar("DP", bound=ops.DfParentOp) @dataclass() -class _DfBase(ParentBuilder[DP], AbstractContextManager): +class _DfBase(ParentBuilder[DP], _DefinitionBuilder, AbstractContextManager): """Base class for dataflow graph builders. Args: @@ -428,23 +482,6 @@ def add_state_order(self, src: Node, dst: Node) -> None: # adds edge to the right of all existing edges self.hugr.add_link(src.out(-1), dst.inp(-1)) - def add_const(self, val: val.Value) -> Node: - """Add a static constant to the graph. - - Args: - val: The value to add. - - Returns: - The node holding the :class:`Const ` operation. - - Example: - >>> dfg = Dfg() - >>> const_n = dfg.add_const(val.TRUE) - >>> dfg.hugr[const_n].op - Const(TRUE) - """ - return self.hugr.add_const(val, self.parent_node) - def load(self, const: ToNode | val.Value) -> Node: """Load a constant into the graph as a dataflow value. @@ -594,3 +631,29 @@ def _ancestral_sibling(h: Hugr, src: Node, tgt: Node) -> Node | None: tgt = tgt_parent return None + + +@dataclass +class Function(_DfBase[ops.FuncDefn]): + """Build a function definition as a HUGR dataflow graph. + + Args: + name: The name of the function. + input_types: The input types for the function (output types are + computed by propagating types from input node through the graph). + type_params: The type parameters for the function, if polymorphic. + + Examples: + >>> f = Function("f", [tys.Bool]) + >>> f.parent_op + FuncDefn(name='f', inputs=[Bool], params=[]) + """ + + def __init__( + self, + name: str, + input_types: TypeRow, + type_params: list[TypeParam] | None = None, + ) -> None: + root_op = ops.FuncDefn(name, input_types, type_params or []) + super().__init__(root_op) diff --git a/hugr-py/src/hugr/function.py b/hugr-py/src/hugr/function.py index bdc388ccf..66be9c58f 100644 --- a/hugr-py/src/hugr/function.py +++ b/hugr-py/src/hugr/function.py @@ -5,45 +5,19 @@ from dataclasses import dataclass from typing import TYPE_CHECKING -from hugr import ops, val - -from .dfg import _DfBase +from . import ops +from .dfg import Function, _DefinitionBuilder from .hugr import Hugr if TYPE_CHECKING: - from hugr.node_port import Node - - from .tys import PolyFuncType, Type, TypeBound, TypeParam, TypeRow - - -@dataclass -class Function(_DfBase[ops.FuncDefn]): - """Build a function definition as a HUGR dataflow graph. - - Args: - name: The name of the function. - input_types: The input types for the function (output types are - computed by propagating types from input node through the graph). - type_params: The type parameters for the function, if polymorphic. - - Examples: - >>> f = Function("f", [tys.Bool]) - >>> f.parent_op - FuncDefn(name='f', inputs=[Bool], params=[]) - """ + from .node_port import Node + from .tys import PolyFuncType, TypeBound, TypeRow - def __init__( - self, - name: str, - input_types: TypeRow, - type_params: list[TypeParam] | None = None, - ) -> None: - root_op = ops.FuncDefn(name, input_types, type_params or []) - super().__init__(root_op) +__all__ = ["Function", "Module"] @dataclass -class Module: +class Module(_DefinitionBuilder[ops.Module]): """Build a top-level HUGR module. Examples: @@ -57,25 +31,6 @@ class Module: def __init__(self) -> None: self.hugr = Hugr(ops.Module()) - def define_function( - self, - name: str, - input_types: TypeRow, - type_params: list[TypeParam] | None = None, - ) -> Function: - """Start building a function definition in the module. - - Args: - name: The name of the function. - input_types: The input types for the function. - type_params: The type parameters for the function, if polymorphic. - - Returns: - The new function builder. - """ - parent_op = ops.FuncDefn(name, input_types, type_params or []) - return Function.new_nested(parent_op, self.hugr) - def define_main(self, input_types: TypeRow) -> Function: """Define the 'main' function in the module. See :meth:`define_function`.""" return self.define_function("main", input_types) @@ -91,6 +46,7 @@ def declare_function(self, name: str, signature: PolyFuncType) -> Node: The node representing the function declaration. Examples: + >>> from hugr.function import Module >>> m = Module() >>> sig = tys.PolyFuncType([], tys.FunctionType.empty()) >>> m.declare_function("f", sig) @@ -98,26 +54,6 @@ def declare_function(self, name: str, signature: PolyFuncType) -> Node: """ return self.hugr.add_node(ops.FuncDecl(name, signature), self.hugr.root) - def add_const(self, value: val.Value) -> Node: - """Add a static constant to the module. - - Args: - value: The constant value to add. - - Returns: - The node holding the constant. - - Examples: - >>> m = Module() - >>> m.add_const(val.FALSE) - Node(1) - """ - return self.hugr.add_node(ops.Const(value), self.hugr.root) - - def add_alias_defn(self, name: str, ty: Type) -> Node: - """Add a type alias definition.""" - return self.hugr.add_node(ops.AliasDefn(name, ty), self.hugr.root) - def add_alias_decl(self, name: str, bound: TypeBound) -> Node: """Add a type alias declaration.""" return self.hugr.add_node(ops.AliasDecl(name, bound), self.hugr.root)