Skip to content

Commit

Permalink
fix: Use places in BB signatures for Hugr generation (#342)
Browse files Browse the repository at this point in the history
Fixes #337.

The fix for the linearity checker is easy: When checking assignments,
only complain if a *local* unused linear variable is shadowed.

However, correctly lowering the program in #337 to Hugr requires some
additional changes to Hugr codegen. Namely, we need to consider whether a
basic block (BB) uses a *whole* struct `s` or only some of its fields. In
the latter case, we should feed only those used values into the BB. This
requires specifying the signature of BBs in the form of `Place`s rather
than `Variable`s.

Concretely, this PR does the following:

* Make `Signature` generic over the representation of program variables
so it can capture both `Variable` and `Place`
* Similarly, make `CheckedBB` and `CheckedCFG` generic
* During type checking, we first construct a `CheckedCFG[Variable]`, which
the linearity checker then turns into a `CheckedCFG[Place]`
* Add a test case
  • Loading branch information
mark-koch authored Jul 29, 2024
1 parent 528c443 commit 48b0e35
Show file tree
Hide file tree
Showing 9 changed files with 124 additions and 56 deletions.
54 changes: 30 additions & 24 deletions guppylang/checker/cfg_checker.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,45 +6,48 @@

import collections
from collections.abc import Iterator, Sequence
from dataclasses import dataclass
from typing import TypeVar
from dataclasses import dataclass, field
from typing import Generic, TypeVar

from guppylang.ast_util import line_col
from guppylang.cfg.bb import BB
from guppylang.cfg.cfg import CFG, BaseCFG
from guppylang.checker.core import Context, Globals, Locals, Variable
from guppylang.checker.core import Context, Globals, Locals, Place, V, Variable
from guppylang.checker.expr_checker import ExprSynthesizer, to_bool
from guppylang.checker.linearity_checker import check_cfg_linearity
from guppylang.checker.stmt_checker import StmtChecker
from guppylang.error import GuppyError
from guppylang.tys.ty import Type

VarRow = Sequence[Variable]
Row = Sequence[V]


@dataclass(frozen=True)
class Signature:
class Signature(Generic[V]):
"""The signature of a basic block.
Stores the input/output variables with their types.
Stores the input/output variables with their types. Generic over the representation
of program variables.
"""

input_row: VarRow
output_rows: Sequence[VarRow] # One for each successor
input_row: Row[V]
output_rows: Sequence[Row[V]] # One for each successor

@staticmethod
def empty() -> "Signature":
def empty() -> "Signature[V]":
return Signature([], [])


@dataclass(eq=False) # Disable equality to recover hash from `object`
class CheckedBB(BB):
"""Basic block annotated with an input and output type signature."""
class CheckedBB(BB, Generic[V]):
"""Basic block annotated with an input and output type signature.
sig: Signature = Signature.empty() # noqa: RUF009
The signature is generic over the representation of program variables.
"""

sig: Signature[V] = field(default_factory=Signature.empty)


class CheckedCFG(BaseCFG[CheckedBB]):
class CheckedCFG(BaseCFG[CheckedBB[V]], Generic[V]):
input_tys: list[Type]
output_ty: Type

Expand All @@ -55,19 +58,20 @@ def __init__(self, input_tys: list[Type], output_ty: Type) -> None:


def check_cfg(
cfg: CFG, inputs: VarRow, return_ty: Type, globals: Globals
) -> CheckedCFG:
cfg: CFG, inputs: Row[Variable], return_ty: Type, globals: Globals
) -> CheckedCFG[Place]:
"""Type checks a control-flow graph.
Annotates the basic blocks with input and output type signatures and removes
unreachable blocks.
unreachable blocks. Note that the inputs/outputs are annotated in the form of
*places* rather than just variables.
"""
# First, we need to run program analysis
ass_before = {v.name for v in inputs}
cfg.analyze(ass_before, ass_before)

# We start by compiling the entry BB
checked_cfg = CheckedCFG([v.ty for v in inputs], return_ty)
checked_cfg: CheckedCFG[Variable] = CheckedCFG([v.ty for v in inputs], return_ty)
checked_cfg.entry_bb = check_bb(
cfg.entry_bb, checked_cfg, inputs, return_ty, globals
)
Expand Down Expand Up @@ -117,18 +121,20 @@ def check_cfg(
}

# Finally, run the linearity check
check_cfg_linearity(checked_cfg)
from guppylang.checker.linearity_checker import check_cfg_linearity

linearity_checked_cfg = check_cfg_linearity(checked_cfg)

return checked_cfg
return linearity_checked_cfg


def check_bb(
bb: BB,
checked_cfg: CheckedCFG,
inputs: VarRow,
checked_cfg: CheckedCFG[Variable],
inputs: Row[Variable],
return_ty: Type,
globals: Globals,
) -> CheckedBB:
) -> CheckedBB[Variable]:
cfg = bb.containing_cfg

# For the entry BB we have to separately check that all used variables are
Expand Down Expand Up @@ -180,7 +186,7 @@ def check_bb(
return checked_bb


def check_rows_match(row1: VarRow, row2: VarRow, bb: BB) -> None:
def check_rows_match(row1: Row[Variable], row2: Row[Variable], bb: BB) -> None:
"""Checks that the types of two rows match up.
Otherwise, an error is thrown, alerting the user that a variable has different
Expand Down
4 changes: 2 additions & 2 deletions guppylang/checker/func_checker.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@
from guppylang.cfg.bb import BB
from guppylang.cfg.builder import CFGBuilder
from guppylang.checker.cfg_checker import CheckedCFG, check_cfg
from guppylang.checker.core import Context, Globals, Variable
from guppylang.checker.core import Context, Globals, Place, Variable
from guppylang.definition.common import DefId
from guppylang.error import GuppyError
from guppylang.nodes import CheckedNestedFunctionDef, NestedFunctionDef
Expand All @@ -25,7 +25,7 @@

def check_global_func_def(
func_def: ast.FunctionDef, ty: FunctionType, globals: Globals
) -> CheckedCFG:
) -> CheckedCFG[Place]:
"""Type checks a top-level function definition."""
args = func_def.args.args
returns_none = isinstance(ty.output, NoneType)
Expand Down
47 changes: 40 additions & 7 deletions guppylang/checker/linearity_checker.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,11 +6,11 @@
import ast
from collections.abc import Generator, Iterator
from contextlib import contextmanager
from typing import TYPE_CHECKING

from guppylang.ast_util import AstNode, find_nodes, get_type
from guppylang.cfg.analysis import LivenessAnalysis
from guppylang.cfg.bb import BB, VariableStats
from guppylang.checker.cfg_checker import CheckedBB, CheckedCFG, Signature
from guppylang.checker.core import (
FieldAccess,
Locals,
Expand All @@ -28,9 +28,6 @@
)
from guppylang.tys.ty import StructType

if TYPE_CHECKING:
from guppylang.checker.cfg_checker import CheckedBB, CheckedCFG


class Scope(Locals[PlaceId, Place]):
"""Scoped collection of assigned places indexed by their id.
Expand Down Expand Up @@ -89,7 +86,7 @@ class BBLinearityChecker(ast.NodeVisitor):
scope: Scope
stats: VariableStats[PlaceId]

def check(self, bb: "CheckedBB", is_entry: bool) -> Scope:
def check(self, bb: "CheckedBB[Variable]", is_entry: bool) -> Scope:
# Manufacture a scope that holds all places that are live at the start
# of this BB
input_scope = Scope()
Expand Down Expand Up @@ -179,7 +176,9 @@ def _check_assign_targets(self, targets: list[ast.expr]) -> None:
assert isinstance(tgt, PlaceNode)
for tgt_place in leaf_places(tgt.place):
x = tgt_place.id
if x in self.scope and not self.scope.used(x):
# Only check for overrides of places locally defined in this BB. Global
# checks are handled by dataflow analysis.
if x in self.scope.vars and x not in self.scope.used_local:
place = self.scope[x]
if place.ty.linear:
raise GuppyError(
Expand Down Expand Up @@ -278,10 +277,13 @@ def leaf_places(place: Place) -> Iterator[Place]:
yield place


def check_cfg_linearity(cfg: "CheckedCFG") -> None:
def check_cfg_linearity(cfg: "CheckedCFG[Variable]") -> "CheckedCFG[Place]":
"""Checks whether a CFG satisfies the linearity requirements.
Raises a user-error if linearity violations are found.
Returns a new CFG with refined basic block signatures in terms of *places* rather
than just variables.
"""
bb_checker = BBLinearityChecker()
scopes: dict[BB, Scope] = {
Expand All @@ -292,6 +294,10 @@ def check_cfg_linearity(cfg: "CheckedCFG") -> None:
stats = {bb: scope.stats() for bb, scope in scopes.items()}
live_before = LivenessAnalysis(stats).run(cfg.bbs)

# Construct a CFG that tracks places instead of just variables
result_cfg: CheckedCFG[Place] = CheckedCFG(cfg.input_tys, cfg.output_ty)
checked: dict[BB, CheckedBB[Place]] = {}

for bb, scope in scopes.items():
# We have to check that used linear variables are not being outputted
for succ in bb.successors:
Expand Down Expand Up @@ -323,3 +329,30 @@ def check_cfg_linearity(cfg: "CheckedCFG") -> None:
# more precise location
scope[x].defined_at,
)

assert isinstance(bb, CheckedBB)
sig = Signature(
input_row=[scope[x] for x in live_before[bb]]
if bb not in (cfg.entry_bb, cfg.exit_bb)
else bb.sig.input_row,
output_rows=[
[scope[x] for x in live_before[succ]] for succ in bb.successors
],
)
checked[bb] = CheckedBB(
bb.idx, result_cfg, bb.statements, branch_pred=bb.branch_pred, sig=sig
)

# Fill in missing fields of the result CFG
result_cfg.bbs = list(checked.values())
result_cfg.entry_bb = checked[cfg.entry_bb]
result_cfg.exit_bb = checked[cfg.exit_bb]
result_cfg.live_before = {checked[bb]: cfg.live_before[bb] for bb in cfg.bbs}
result_cfg.ass_before = {checked[bb]: cfg.ass_before[bb] for bb in cfg.bbs}
result_cfg.maybe_ass_before = {
checked[bb]: cfg.maybe_ass_before[bb] for bb in cfg.bbs
}
for bb in cfg.bbs:
checked[bb].predecessors = [checked[pred] for pred in bb.predecessors]
checked[bb].successors = [checked[succ] for succ in bb.successors]
return result_cfg
39 changes: 22 additions & 17 deletions guppylang/compiler/cfg_compiler.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,8 @@
import functools
from typing import TYPE_CHECKING

from guppylang.checker.cfg_checker import CheckedBB, CheckedCFG, Signature, VarRow
from guppylang.checker.core import Variable
from guppylang.checker.cfg_checker import CheckedBB, CheckedCFG, Row, Signature
from guppylang.checker.core import Place, Variable
from guppylang.compiler.core import (
CompiledGlobals,
DFContainer,
Expand All @@ -20,12 +20,12 @@


def compile_cfg(
cfg: CheckedCFG, graph: Hugr, parent: Node, globals: CompiledGlobals
cfg: CheckedCFG[Place], graph: Hugr, parent: Node, globals: CompiledGlobals
) -> None:
"""Compiles a CFG to Hugr."""
insert_return_vars(cfg)

blocks: dict[CheckedBB, CFNode] = {}
blocks: dict[CheckedBB[Place], CFNode] = {}
for bb in cfg.bbs:
blocks[bb] = compile_bb(bb, graph, parent, bb == cfg.entry_bb, globals)
for bb in cfg.bbs:
Expand All @@ -34,7 +34,11 @@ def compile_cfg(


def compile_bb(
bb: CheckedBB, graph: Hugr, parent: Node, is_entry: bool, globals: CompiledGlobals
bb: CheckedBB[Place],
graph: Hugr,
parent: Node,
is_entry: bool,
globals: CompiledGlobals,
) -> CFNode:
"""Compiles a single basic block to Hugr."""
inputs = bb.sig.input_row if is_entry else sort_vars(bb.sig.input_row)
Expand Down Expand Up @@ -65,16 +69,16 @@ def compile_bb(
).out_port(0)

# Finally, we have to add the block output.
outputs: Sequence[Variable]
outputs: Sequence[Place]
if len(bb.successors) == 1:
# The easy case is if we don't branch: We just output all variables that are
# specified by the signature
[outputs] = bb.sig.output_rows
else:
# If we branch and the branches use the same variables, then we can use a
# If we branch and the branches use the same places, then we can use a
# regular output
first, *rest = bb.sig.output_rows
if all({v.name for v in first} == {v.name for v in r} for r in rest):
if all({p.id for p in first} == {p.id for p in r} for r in rest):
outputs = first
else:
# Otherwise, we have to output a TupleSum: We put all non-linear variables
Expand All @@ -89,21 +93,21 @@ def compile_bb(
[
v
for v in sort_vars(row)
if not v.ty.linear or is_return_var(v.name)
if not v.ty.linear or is_return_var(str(v))
]
for row in bb.sig.output_rows
],
dfg=dfg,
)
outputs = [v for v in first if v.ty.linear and not is_return_var(v.name)]
outputs = [v for v in first if v.ty.linear and not is_return_var(str(v))]

graph.add_output(
inputs=[branch_port] + [dfg[v] for v in sort_vars(outputs)], parent=block
)
return block


def insert_return_vars(cfg: CheckedCFG) -> None:
def insert_return_vars(cfg: CheckedCFG[Place]) -> None:
"""Patches a CFG by annotating dummy return variables in the BB signatures.
The statement compiler turns `return` statements into assignments of dummy variables
Expand All @@ -126,7 +130,7 @@ def insert_return_vars(cfg: CheckedCFG) -> None:


def choose_vars_for_tuple_sum(
graph: Hugr, unit_sum: OutPortV, output_vars: list[VarRow], dfg: DFContainer
graph: Hugr, unit_sum: OutPortV, output_vars: list[Row[Place]], dfg: DFContainer
) -> OutPortV:
"""Selects an output based on a TupleSum.
Expand All @@ -149,20 +153,21 @@ def choose_vars_for_tuple_sum(
return conditional.add_out_port(SumType([row_to_type(row) for row in tys]))


def compare_var(x: Variable, y: Variable) -> int:
def compare_var(p1: Place, p2: Place) -> int:
"""Defines a `<` order on variables.
We use this to determine in which order variables are outputted from basic blocks.
We need to output linear variables at the end, so we do a lexicographic ordering of
linearity and name. The only exception are return vars which must be outputted in
order.
"""
if is_return_var(x.name) and is_return_var(y.name):
return -1 if x.name < y.name else 1
return -1 if (x.ty.linear, x.name) < (y.ty.linear, y.name) else 1
x, y = str(p1), str(p2)
if is_return_var(x) and is_return_var(y):
return -1 if x < y else 1
return -1 if (p1.ty.linear, x) < (p2.ty.linear, y) else 1


def sort_vars(row: VarRow) -> list[Variable]:
def sort_vars(row: Row[Place]) -> list[Place]:
"""Sorts a row of variables.
This determines the order in which they are outputted from a BB.
Expand Down
4 changes: 2 additions & 2 deletions guppylang/definition/function.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@

from guppylang.ast_util import AstNode, annotate_location, with_loc
from guppylang.checker.cfg_checker import CheckedCFG
from guppylang.checker.core import Context, Globals, PyScope
from guppylang.checker.core import Context, Globals, Place, PyScope
from guppylang.checker.expr_checker import check_call, synthesize_call
from guppylang.checker.func_checker import (
check_global_func_def,
Expand Down Expand Up @@ -111,7 +111,7 @@ class CheckedFunctionDef(ParsedFunctionDef, CompilableDef):
graph for the function body.
"""

cfg: CheckedCFG
cfg: CheckedCFG[Place]

def compile_outer(self, graph: Hugr, parent: Node) -> "CompiledFunctionDef":
"""Adds a Hugr `FuncDefn` node for this function to the Hugr.
Expand Down
4 changes: 2 additions & 2 deletions guppylang/nodes.py
Original file line number Diff line number Diff line change
Expand Up @@ -211,7 +211,7 @@ def __init__(self, cfg: "CFG", ty: FunctionType, *args: Any, **kwargs: Any) -> N

class CheckedNestedFunctionDef(ast.FunctionDef):
def_id: "DefId"
cfg: "CheckedCFG"
cfg: "CheckedCFG[Place]"
ty: FunctionType

#: Mapping from names to variables captured by this function, together with an AST
Expand All @@ -221,7 +221,7 @@ class CheckedNestedFunctionDef(ast.FunctionDef):
def __init__(
self,
def_id: "DefId",
cfg: "CheckedCFG",
cfg: "CheckedCFG[Place]",
ty: FunctionType,
captured: Mapping[str, tuple["Variable", AstNode]],
*args: Any,
Expand Down
Loading

0 comments on commit 48b0e35

Please sign in to comment.