From cc2c8a4f09f43f8913c49ff0dfe0da601a85b7c6 Mon Sep 17 00:00:00 2001 From: Mark Koch <48097969+mark-koch@users.noreply.github.com> Date: Tue, 17 Sep 2024 10:22:30 +0100 Subject: [PATCH] feat: Only lower definitions to Hugr if they are used (#496) Closes #434 and closes #470. --------- Co-authored-by: Alan Lawrence Co-authored-by: Douglas Wilson --- guppylang/compiler/core.py | 41 ++++++++++++++++++++++++++--- guppylang/compiler/func_compiler.py | 31 +++++++++++----------- guppylang/module.py | 32 +++++++--------------- tests/integration/test_array.py | 6 +---- tests/integration/test_basic.py | 6 +---- tests/integration/test_extern.py | 10 ++----- 6 files changed, 67 insertions(+), 59 deletions(-) diff --git a/guppylang/compiler/core.py b/guppylang/compiler/core.py index d5f96e57..0a75d1d9 100644 --- a/guppylang/compiler/core.py +++ b/guppylang/compiler/core.py @@ -3,17 +3,52 @@ from typing import cast from hugr import Wire, ops -from hugr.build.dfg import DP, DfBase +from hugr.build.dfg import DP, DefinitionBuilder, DfBase from guppylang.checker.core import FieldAccess, Place, PlaceId, Variable -from guppylang.definition.common import CompiledDef, DefId +from guppylang.definition.common import CheckedDef, CompilableDef, CompiledDef, DefId from guppylang.error import InternalGuppyError from guppylang.tys.ty import StructType -CompiledGlobals = dict[DefId, CompiledDef] CompiledLocals = dict[PlaceId, Wire] +class CompiledGlobals: + """Compilation context containing all available definitions. + + Maintains a `worklist` of definitions which have been used by other compiled code + (i.e. `compile_outer` has been called) but have not yet been compiled/lowered + themselves (i.e. `compile_inner` has not yet been called). + """ + + module: DefinitionBuilder[ops.Module] + checked: dict[DefId, CheckedDef] + compiled: dict[DefId, CompiledDef] + worklist: set[DefId] + + def __init__( + self, + checked: dict[DefId, CheckedDef], + module: DefinitionBuilder[ops.Module], + ) -> None: + self.module = module + self.checked = checked + self.worklist = set() + self.compiled = {} + + def __getitem__(self, def_id: DefId) -> CompiledDef: + if def_id not in self.compiled: + defn = self.checked[def_id] + self.compiled[def_id] = self._compile(defn) + self.worklist.add(def_id) + return self.compiled[def_id] + + def _compile(self, defn: CheckedDef) -> CompiledDef: + if isinstance(defn, CompilableDef): + return defn.compile_outer(self.module) + return defn + + @dataclass class DFContainer: """A dataflow graph under construction. diff --git a/guppylang/compiler/func_compiler.py b/guppylang/compiler/func_compiler.py index faa349e0..23d307c5 100644 --- a/guppylang/compiler/func_compiler.py +++ b/guppylang/compiler/func_compiler.py @@ -62,26 +62,25 @@ def compile_local_func_def( call_args.append(partial) func.cfg.input_tys.append(func.ty) + + # Compile the CFG + cfg = compile_cfg(func.cfg, func_builder, call_args, globals) + func_builder.set_outputs(*cfg) else: # Otherwise, we treat the function like a normal global variable from guppylang.definition.function import CompiledFunctionDef - globals = globals | { - func.def_id: CompiledFunctionDef( - func.def_id, - func.name, - func, - func.ty, - {}, - None, - func.cfg, - func_builder, - ) - } - - # Compile the CFG - cfg = compile_cfg(func.cfg, func_builder, call_args, globals) - func_builder.set_outputs(*cfg) + globals.compiled[func.def_id] = CompiledFunctionDef( + func.def_id, + func.name, + func, + func.ty, + {}, + None, + func.cfg, + func_builder, + ) + globals.worklist.add(func.def_id) # will compile the CFG later # Finally, load the function into the local data-flow graph loaded = dfg.builder.load_function(func_builder, closure_ty) diff --git a/guppylang/module.py b/guppylang/module.py index fc7d7537..ed6822ff 100644 --- a/guppylang/module.py +++ b/guppylang/module.py @@ -11,11 +11,10 @@ import guppylang.compiler.hugr_extension from guppylang.checker.core import Globals, PyScope +from guppylang.compiler.core import CompiledGlobals from guppylang.definition.common import ( CheckableDef, CheckedDef, - CompilableDef, - CompiledDef, DefId, Definition, ParsableDef, @@ -267,18 +266,6 @@ def _check_defs( for def_id, defn in parsed.items() } - @staticmethod - def _compile_defs( - checked_defs: Mapping[DefId, CheckedDef], hugr_module: Module - ) -> dict[DefId, CompiledDef]: - """Helper method to compile checked definitions to Hugr.""" - return { - def_id: defn.compile_outer(hugr_module) - if isinstance(defn, CompilableDef) - else defn - for def_id, defn in checked_defs.items() - } - def check(self) -> None: """Type-checks the module.""" if self.checked: @@ -329,19 +316,20 @@ def compile(self) -> Package: return self._compiled_hugr self.check() + checked_defs = self._imported_checked_defs | self._checked_defs # Prepare Hugr for this module graph = Module() graph.metadata["name"] = self.name - # Compile definitions to Hugr - compiled_defs = self._compile_defs(self._imported_checked_defs, graph) - compiled_defs |= self._compile_defs(self._checked_defs, graph) - - # Finally, compile the definition contents to Hugr. For example, this compiles - # the bodies of functions. - for defn in compiled_defs.values(): - defn.compile_inner(compiled_defs) + # Lower definitions to Hugr + required = set(self._checked_defs.keys()) + ctx = CompiledGlobals(checked_defs, graph) + _request_compilation = [ctx[def_id] for def_id in required] + while ctx.worklist: + next_id = ctx.worklist.pop() + next_def = ctx[next_id] + next_def.compile_inner(ctx) hugr = graph.hugr diff --git a/tests/integration/test_array.py b/tests/integration/test_array.py index 685c7db1..38ca00d3 100644 --- a/tests/integration/test_array.py +++ b/tests/integration/test_array.py @@ -1,4 +1,3 @@ -import pytest from hugr import ops from hugr.std.int import IntVal @@ -22,10 +21,7 @@ def main(xs: array[float, 42]) -> int: validate(package) hg = package.modules[0] - vals = [data.op for node, data in hg.nodes() if isinstance(data.op, ops.Const)] - if len(vals) > 1: - pytest.xfail(reason="hugr-includes-whole-stdlib") - [val] = vals + [val] = [data.op for node, data in hg.nodes() if isinstance(data.op, ops.Const)] assert isinstance(val, ops.Const) assert isinstance(val.val, IntVal) assert val.val.v == 42 diff --git a/tests/integration/test_basic.py b/tests/integration/test_basic.py index 9fe3b11d..6a7108d0 100644 --- a/tests/integration/test_basic.py +++ b/tests/integration/test_basic.py @@ -1,4 +1,3 @@ -import pytest from hugr import ops from guppylang.decorator import guppy @@ -69,14 +68,11 @@ def test_func_def_name(): def func_name() -> None: return - defs = [ + [def_op] = [ data.op for n, data in func_name.modules[0].nodes() if isinstance(data.op, ops.FuncDefn) ] - if len(defs) > 1: - pytest.xfail(reason="hugr-includes-whole-stdlib") - [def_op] = defs assert isinstance(def_op, ops.FuncDefn) assert def_op.f_name == "func_name" diff --git a/tests/integration/test_extern.py b/tests/integration/test_extern.py index 83379546..e8e56e30 100644 --- a/tests/integration/test_extern.py +++ b/tests/integration/test_extern.py @@ -18,10 +18,7 @@ def main() -> float: validate(package) hg = package.modules[0] - consts = [data.op for n, data in hg.nodes() if isinstance(data.op, ops.Const)] - if len(consts) > 1: - pytest.xfail(reason="hugr-includes-whole-stdlib") - [c] = consts + [c] = [data.op for n, data in hg.nodes() if isinstance(data.op, ops.Const)] assert isinstance(c.val, val.Extension) assert c.val.val["symbol"] == "ext" @@ -39,10 +36,7 @@ def main() -> int: validate(package) hg = package.modules[0] - consts = [data.op for n, data in hg.nodes() if isinstance(data.op, ops.Const)] - if len(consts) > 1: - pytest.xfail(reason="hugr-includes-whole-stdlib") - [c] = consts + [c] = [data.op for n, data in hg.nodes() if isinstance(data.op, ops.Const)] assert isinstance(c.val, val.Extension) assert c.val.val["symbol"] == "foo"