Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix: Fix array execution bugs #731

Merged
merged 18 commits into from
Dec 18, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
18 changes: 5 additions & 13 deletions Cargo.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

6 changes: 3 additions & 3 deletions Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,6 @@ inkwell = "0.5.0"
[patch.crates-io]

# Uncomment these to test the latest dependency version during development
# hugr = { git = "https://github.com/CQCL/hugr", rev = "861183e" }
# hugr-cli = { git = "https://github.com/CQCL/hugr", rev = "861183e" }
# hugr-llvm = { git = "https://github.com/CQCL/hugr", rev = "1091755" }
hugr = { git = "https://github.com/CQCL/hugr", rev = "ab94518" }
hugr-cli = { git = "https://github.com/CQCL/hugr", rev = "ab94518" }
hugr-llvm = { git = "https://github.com/CQCL/hugr", rev = "ab94518" }
12 changes: 10 additions & 2 deletions execute_llvm/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,11 @@ fn find_funcdef_node(hugr: impl HugrView, fn_name: &str) -> PyResult<hugr::Node>
}
}

fn guppy_pass(hugr: Hugr) -> Hugr {
let hugr = hugr::algorithms::monomorphize(hugr);
hugr::algorithms::remove_polyfuncs(hugr)
}

fn compile_module<'a>(
hugr: &'a hugr::Hugr,
ctx: &'a Context,
Expand All @@ -47,6 +52,7 @@ fn compile_module<'a>(
// TODO: Handle tket2 codegen extension
let extensions = hugr::llvm::custom::CodegenExtsBuilder::default()
.add_int_extensions()
.add_logic_extensions()
.add_default_prelude_extensions()
.add_default_array_extensions()
.add_float_extensions()
Expand All @@ -64,9 +70,10 @@ fn compile_module<'a>(

#[pyfunction]
fn compile_module_to_string(hugr_json: &str) -> PyResult<String> {
let hugr = parse_hugr(hugr_json)?;
let mut hugr = parse_hugr(hugr_json)?;
let ctx = Context::create();

hugr = guppy_pass(hugr);
let module = compile_module(&hugr, &ctx, Default::default())?;

Ok(module.print_to_string().to_str().unwrap().to_string())
Expand All @@ -77,7 +84,8 @@ fn run_function<T>(
fn_name: &str,
parse_result: impl FnOnce(&Context, GenericValue) -> PyResult<T>,
) -> PyResult<T> {
let hugr = parse_hugr(hugr_json)?;
let mut hugr = parse_hugr(hugr_json)?;
hugr = guppy_pass(hugr);
let ctx = Context::create();

let namer = hugr::llvm::emit::Namer::default();
Expand Down
57 changes: 29 additions & 28 deletions guppylang/compiler/expr_compiler.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
from typing import Any, TypeGuard, TypeVar

import hugr
import hugr.std.collections.array
import hugr.std.float
import hugr.std.int
import hugr.std.logic
Expand All @@ -21,7 +22,7 @@
from guppylang.checker.errors.generic import UnsupportedError
from guppylang.checker.linearity_checker import contains_subscript
from guppylang.compiler.core import CompilerBase, DFContainer
from guppylang.compiler.hugr_extension import PartialOp, UnsupportedOp
from guppylang.compiler.hugr_extension import PartialOp
from guppylang.definition.custom import CustomFunctionDef
from guppylang.definition.value import (
CallReturnWires,
Expand All @@ -46,6 +47,7 @@
TensorCall,
TypeApply,
)
from guppylang.std._internal.compiler.arithmetic import convert_ifromusize
from guppylang.std._internal.compiler.array import array_repeat
from guppylang.std._internal.compiler.list import (
list_new,
Expand Down Expand Up @@ -123,7 +125,7 @@ def _new_dfcontainer(
def _new_loop(
self,
loop_vars: list[PlaceNode],
branch: PlaceNode,
continue_predicate: PlaceNode,
) -> Iterator[None]:
"""Context manager to build a graph inside a new `TailLoop` node.

Expand All @@ -134,13 +136,12 @@ def _new_loop(
loop = self.builder.add_tail_loop([], loop_inputs)
with self._new_dfcontainer(loop_vars, loop):
yield
# Output the branch predicate and the inputs for the next iteration
loop.set_loop_outputs(
# Note that we have to do fresh calls to `self.visit` here since we're
# in a new context
self.visit(branch),
*(self.visit(name) for name in loop_vars),
)
# Output the branch predicate and the inputs for the next iteration. Note
# that we have to do fresh calls to `self.visit` here since we're in a new
# context
do_continue = self.visit(continue_predicate)
do_break = loop.add_op(hugr.std.logic.Not, do_continue)
loop.set_loop_outputs(do_break, *(self.visit(name) for name in loop_vars))
# Update the DFG with the outputs from the loop
for node, wire in zip(loop_vars, loop, strict=True):
self.dfg[node.place] = wire
Expand Down Expand Up @@ -172,12 +173,12 @@ def _if_true(self, cond: ast.expr, inputs: list[PlaceNode]) -> Iterator[None]:
conditional = self.builder.add_conditional(
self.visit(cond), *(self.visit(inp) for inp in inputs)
)
# If the condition is true, we enter the `with` block
with self._new_case(inputs, inputs, conditional, 0):
yield
# If the condition is false, output the inputs as is
with self._new_case(inputs, inputs, conditional, 1):
with self._new_case(inputs, inputs, conditional, 0):
pass
# If the condition is true, we enter the `with` block
with self._new_case(inputs, inputs, conditional, 1):
yield
# Update the DFG with the outputs from the Conditional node
for node, wire in zip(inputs, conditional, strict=True):
self.dfg[node.place] = wire
Expand Down Expand Up @@ -206,11 +207,16 @@ def visit_GlobalName(self, node: GlobalName) -> Wire:
return defn.load(self.dfg, self.globals, node)

def visit_GenericParamValue(self, node: GenericParamValue) -> Wire:
# TODO: We need a way to look up the concrete value of a generic type arg in
# Hugr. For example, a new op that captures the value during monomorphisation
return self.builder.add_op(
UnsupportedOp("load_type_param", [], [node.param.ty.to_hugr()]).ext_op
)
match node.param.ty:
case NumericType(NumericType.Kind.Nat):
arg = node.param.to_bound().to_hugr()
load_nat = hugr.std.PRELUDE.get_op("load_nat").instantiate(
[arg], ht.FunctionType([], [ht.USize()])
)
usize = self.builder.add_op(load_nat)
return self.builder.add_op(convert_ifromusize(), usize)
case _:
raise NotImplementedError

def visit_Name(self, node: ast.Name) -> Wire:
raise InternalGuppyError("Node should have been removed during type checking.")
Expand Down Expand Up @@ -604,17 +610,12 @@ def python_value_to_hugr(v: Any, exp_ty: Type) -> hv.Value | None:
return hv.Tuple(*vs)
case list(elts):
assert is_array_type(exp_ty)
vs = [python_value_to_hugr(elt, get_element_type(exp_ty)) for elt in elts]
elem_ty = get_element_type(exp_ty)
vs = [python_value_to_hugr(elt, elem_ty) for elt in elts]
if doesnt_contain_none(vs):
# TODO: Use proper array value: https://github.com/CQCL/hugr/issues/1497
return hv.Extension(
name="ArrayValue",
typ=exp_ty.to_hugr(),
# The value list must be serialized at this point, otherwise the
# `Extension` value would not be serializable.
val=[v._to_serial_root() for v in vs],
extensions=["unsupported"],
)
opt_ty = ht.Option(elem_ty.to_hugr())
opt_vs: list[hv.Value] = [hv.Some(v) for v in vs]
return hugr.std.collections.array.ArrayVal(opt_vs, opt_ty)
case _:
# TODO replace with hugr protocol handling: https://github.com/CQCL/guppylang/issues/563
# Pytket conversion is an experimental feature
Expand Down
19 changes: 15 additions & 4 deletions guppylang/compiler/stmt_compiler.py
Original file line number Diff line number Diff line change
Expand Up @@ -115,15 +115,26 @@ def pop(
array: Wire, length: int, pats: list[ast.expr], from_left: bool
) -> tuple[Wire, int]:
err = "Internal error: unpacking of iterable failed"
for pat in pats:
num_pats = len(pats)
# Pop the number of requested elements from the array
elts = []
for i in range(num_pats):
res = self.builder.add_op(
array_pop(opt_elt_ty, length, from_left), array
array_pop(opt_elt_ty, length - i, from_left), array
)
[elt_opt, array] = build_unwrap(self.builder, res, err)
[elt] = build_unwrap(self.builder, elt_opt, err)
elts.append(elt)
# Assign elements to the given patterns
for pat, elt in zip(
pats,
# Assignments are evaluated from left to right, so we need to assign in
# reverse order if we popped from the right
elts if from_left else reversed(elts),
strict=True,
):
self._assign(pat, elt)
length -= 1
return array, length
return array, length - num_pats

self.dfg[lhs.rhs_var.place] = port
array = self.expr_compiler.visit_DesugaredArrayComp(lhs.compr)
Expand Down
43 changes: 30 additions & 13 deletions guppylang/std/_internal/compiler/array.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,6 @@
from hugr import tys as ht
from hugr.std.collections.array import EXTENSION

from guppylang.compiler.hugr_extension import UnsupportedOp
from guppylang.definition.custom import CustomCallCompiler
from guppylang.definition.value import CallReturnWires
from guppylang.error import InternalGuppyError
Expand Down Expand Up @@ -92,24 +91,42 @@ def array_discard_empty(elem_ty: ht.Type) -> ops.ExtOp:
)


def array_scan(
elem_ty: ht.Type,
length: ht.TypeArg,
new_elem_ty: ht.Type,
accumulators: list[ht.Type],
) -> ops.ExtOp:
"""Returns an operation that maps and folds a function across an array."""
ty_args = [
length,
ht.TypeTypeArg(elem_ty),
ht.TypeTypeArg(new_elem_ty),
ht.SequenceArg([ht.TypeTypeArg(acc) for acc in accumulators]),
ht.ExtensionsArg([]),
]
ins = [
array_type(elem_ty, length),
ht.FunctionType([elem_ty, *accumulators], [new_elem_ty, *accumulators]),
*accumulators,
]
outs = [array_type(new_elem_ty, length), *accumulators]
return EXTENSION.get_op("scan").instantiate(ty_args, ht.FunctionType(ins, outs))


def array_map(elem_ty: ht.Type, length: ht.TypeArg, new_elem_ty: ht.Type) -> ops.ExtOp:
"""Returns an operation that maps a function across an array."""
# TODO
return UnsupportedOp(
op_name="array_map",
inputs=[array_type(elem_ty, length), ht.FunctionType([elem_ty], [new_elem_ty])],
outputs=[array_type(new_elem_ty, length)],
).ext_op
return array_scan(elem_ty, length, new_elem_ty, accumulators=[])


def array_repeat(elem_ty: ht.Type, length: ht.TypeArg) -> ops.ExtOp:
"""Returns an array `repeat` operation."""
# TODO
return UnsupportedOp(
op_name="array.repeat",
inputs=[ht.FunctionType([], [elem_ty])],
outputs=[array_type(elem_ty, length)],
).ext_op
return EXTENSION.get_op("repeat").instantiate(
[length, ht.TypeTypeArg(elem_ty), ht.ExtensionsArg([])],
ht.FunctionType(
[ht.FunctionType([], [elem_ty])], [array_type(elem_ty, length)]
),
)


# ------------------------------------------------------
Expand Down
4 changes: 3 additions & 1 deletion guppylang/std/builtins.py
Original file line number Diff line number Diff line change
Expand Up @@ -161,7 +161,9 @@ def __ge__(self: nat, other: nat) -> bool: ...
@guppy.hugr_op(int_op("igt_u"))
def __gt__(self: nat, other: nat) -> bool: ...

@guppy.hugr_op(int_op("iu_to_s"))
# TODO: Use "iu_to_s" once we have lowering:
# https://github.com/CQCL/hugr/issues/1806
@guppy.custom(NoopCompiler())
def __int__(self: nat) -> int: ...

@guppy.hugr_op(int_op("inot"))
Expand Down
3 changes: 1 addition & 2 deletions guppylang/tys/builtin.py
Original file line number Diff line number Diff line change
Expand Up @@ -138,8 +138,7 @@ def _array_to_hugr(args: Sequence[Argument]) -> ht.Type:
elem_ty = ht.Option(ty_arg.ty.to_hugr())
hugr_arg = len_arg.to_hugr()

# TODO remove type ignore after Array type annotation fixed to include VariableArg
return hugr.std.collections.array.Array(elem_ty, hugr_arg) # type:ignore[arg-type]
return hugr.std.collections.array.Array(elem_ty, hugr_arg)


def _sized_iter_to_hugr(args: Sequence[Argument]) -> ht.Type:
Expand Down
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -83,7 +83,7 @@ members = ["execute_llvm"]
execute-llvm = { workspace = true }

# Uncomment these to test the latest dependency version during development
# hugr = { git = "https://github.com/CQCL/hugr", subdirectory = "hugr-py", rev = "861183e" }
hugr = { git = "https://github.com/CQCL/hugr", subdirectory = "hugr-py", rev = "e40b6c7" }
# tket2-exts = { git = "https://github.com/CQCL/tket2", subdirectory = "tket2-exts", rev = "eb7cc63"}
# tket2 = { git = "https://github.com/CQCL/tket2", subdirectory = "tket2-py", rev = "eb7cc63"}

Expand Down
3 changes: 1 addition & 2 deletions tests/integration/test_array.py
Original file line number Diff line number Diff line change
Expand Up @@ -291,8 +291,7 @@ def main() -> int:
package = module.compile()
validate(package)

# TODO: Enable execution once lowering for missing ops is implemented
# run_int_fn(package, expected=9)
run_int_fn(package, expected=9)


def test_mem_swap(validate):
Expand Down
Loading
Loading