Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix: Fix array execution bugs #731

Merged
merged 18 commits into from
Dec 18, 2024
Merged
Show file tree
Hide file tree
Changes from 15 commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
19 changes: 5 additions & 14 deletions Cargo.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

6 changes: 3 additions & 3 deletions Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,6 @@ inkwell = "0.5.0"
[patch.crates-io]

# Uncomment these to test the latest dependency version during development
# hugr = { git = "https://github.com/CQCL/hugr", rev = "861183e" }
# hugr-cli = { git = "https://github.com/CQCL/hugr", rev = "861183e" }
# hugr-llvm = { git = "https://github.com/CQCL/hugr", rev = "1091755" }
hugr = { git = "https://github.com/CQCL/hugr", rev = "ab94518" }
hugr-cli = { git = "https://github.com/CQCL/hugr", rev = "ab94518" }
hugr-llvm = { git = "https://github.com/CQCL/hugr", rev = "ab94518" }
1 change: 0 additions & 1 deletion execute_llvm/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,6 @@ crate-type = ["cdylib"]

[dependencies]
hugr = {workspace = true, features = ["llvm"]}
hugr-passes = "0.14.0"
inkwell.workspace = true
pyo3.workspace = true
serde_json.workspace = true
8 changes: 5 additions & 3 deletions execute_llvm/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,6 @@
use hugr::llvm::utils::fat::FatExt;
use hugr::Hugr;
use hugr::{self, ops, std_extensions, HugrView};
use hugr_passes;
use inkwell::{context::Context, module::Module, values::GenericValue};
use pyo3::exceptions::PyValueError;
use pyo3::prelude::*;
Expand Down Expand Up @@ -40,7 +39,8 @@ fn find_funcdef_node(hugr: impl HugrView, fn_name: &str) -> PyResult<hugr::Node>
}

fn guppy_pass(hugr: Hugr) -> Hugr {
hugr_passes::monomorphize(hugr)
let hugr = hugr::algorithms::monomorphize(hugr);
hugr::algorithms::remove_polyfuncs(hugr)
}

fn compile_module<'a>(
Expand All @@ -52,6 +52,7 @@ fn compile_module<'a>(
// TODO: Handle tket2 codegen extension
let extensions = hugr::llvm::custom::CodegenExtsBuilder::default()
.add_int_extensions()
.add_logic_extensions()
.add_default_prelude_extensions()
.add_default_array_extensions()
.add_float_extensions()
Expand All @@ -69,9 +70,10 @@ fn compile_module<'a>(

#[pyfunction]
fn compile_module_to_string(hugr_json: &str) -> PyResult<String> {
let hugr = parse_hugr(hugr_json)?;
let mut hugr = parse_hugr(hugr_json)?;
let ctx = Context::create();

hugr = guppy_pass(hugr);
let module = compile_module(&hugr, &ctx, Default::default())?;

Ok(module.print_to_string().to_str().unwrap().to_string())
Expand Down
39 changes: 17 additions & 22 deletions guppylang/compiler/expr_compiler.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
from typing import Any, TypeGuard, TypeVar

import hugr
import hugr.std.collections.array
import hugr.std.float
import hugr.std.int
import hugr.std.logic
Expand Down Expand Up @@ -124,7 +125,7 @@ def _new_dfcontainer(
def _new_loop(
self,
loop_vars: list[PlaceNode],
branch: PlaceNode,
continue_predicate: PlaceNode,
) -> Iterator[None]:
"""Context manager to build a graph inside a new `TailLoop` node.

Expand All @@ -135,13 +136,12 @@ def _new_loop(
loop = self.builder.add_tail_loop([], loop_inputs)
with self._new_dfcontainer(loop_vars, loop):
yield
# Output the branch predicate and the inputs for the next iteration
loop.set_loop_outputs(
# Note that we have to do fresh calls to `self.visit` here since we're
# in a new context
self.visit(branch),
*(self.visit(name) for name in loop_vars),
)
# Output the branch predicate and the inputs for the next iteration. Note
# that we have to do fresh calls to `self.visit` here since we're in a new
# context
do_continue = self.visit(continue_predicate)
do_break = loop.add_op(hugr.std.logic.Not, do_continue)
loop.set_loop_outputs(do_break, *(self.visit(name) for name in loop_vars))
# Update the DFG with the outputs from the loop
for node, wire in zip(loop_vars, loop, strict=True):
self.dfg[node.place] = wire
Expand Down Expand Up @@ -173,12 +173,12 @@ def _if_true(self, cond: ast.expr, inputs: list[PlaceNode]) -> Iterator[None]:
conditional = self.builder.add_conditional(
self.visit(cond), *(self.visit(inp) for inp in inputs)
)
# If the condition is true, we enter the `with` block
with self._new_case(inputs, inputs, conditional, 0):
yield
# If the condition is false, output the inputs as is
with self._new_case(inputs, inputs, conditional, 1):
with self._new_case(inputs, inputs, conditional, 0):
pass
# If the condition is true, we enter the `with` block
with self._new_case(inputs, inputs, conditional, 1):
yield
# Update the DFG with the outputs from the Conditional node
for node, wire in zip(inputs, conditional, strict=True):
self.dfg[node.place] = wire
Expand Down Expand Up @@ -610,17 +610,12 @@ def python_value_to_hugr(v: Any, exp_ty: Type) -> hv.Value | None:
return hv.Tuple(*vs)
case list(elts):
assert is_array_type(exp_ty)
vs = [python_value_to_hugr(elt, get_element_type(exp_ty)) for elt in elts]
elem_ty = get_element_type(exp_ty)
vs = [python_value_to_hugr(elt, elem_ty) for elt in elts]
if doesnt_contain_none(vs):
# TODO: Use proper array value: https://github.com/CQCL/hugr/issues/1771
return hv.Extension(
name="ArrayValue",
typ=exp_ty.to_hugr(),
# The value list must be serialized at this point, otherwise the
# `Extension` value would not be serializable.
val=[v._to_serial_root() for v in vs],
extensions=["unsupported"],
)
opt_ty = ht.Option(elem_ty.to_hugr())
opt_vs: list[hv.Value] = [hv.Sum(1, opt_ty, [v]) for v in vs]
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
opt_vs: list[hv.Value] = [hv.Sum(1, opt_ty, [v]) for v in vs]
opt_vs: list[hv.Value] = [hv.Some(v) for v in vs]

return hugr.std.collections.array.ArrayVal(opt_vs, opt_ty)
case _:
# TODO replace with hugr protocol handling: https://github.com/CQCL/guppylang/issues/563
# Pytket conversion is an experimental feature
Expand Down
20 changes: 20 additions & 0 deletions guppylang/std/_internal/compiler/array.py
Original file line number Diff line number Diff line change
Expand Up @@ -71,6 +71,26 @@ def array_set(elem_ty: ht.Type, length: ht.TypeArg) -> ops.ExtOp:
)


def array_pop(elem_ty: ht.Type, length: int, from_left: bool) -> ops.ExtOp:
"""Returns an operation that pops an element from the left of an array."""
assert length > 0
length_arg = ht.BoundedNatArg(length)
arr_ty = array_type(elem_ty, length_arg)
popped_arr_ty = array_type(elem_ty, ht.BoundedNatArg(length - 1))
op = "pop_left" if from_left else "pop_right"
return _instantiate_array_op(
op, elem_ty, length_arg, [arr_ty], [ht.Option(elem_ty, popped_arr_ty)]
)


def array_discard_empty(elem_ty: ht.Type) -> ops.ExtOp:
"""Returns an operation that discards an array of length zero."""
arr_ty = array_type(elem_ty, ht.BoundedNatArg(0))
return EXTENSION.get_op("discard_empty").instantiate(
[ht.TypeTypeArg(elem_ty)], ht.FunctionType([arr_ty], [])
)


def array_scan(
elem_ty: ht.Type,
length: ht.TypeArg,
Expand Down
4 changes: 3 additions & 1 deletion guppylang/std/builtins.py
Original file line number Diff line number Diff line change
Expand Up @@ -161,7 +161,9 @@ def __ge__(self: nat, other: nat) -> bool: ...
@guppy.hugr_op(int_op("igt_u"))
def __gt__(self: nat, other: nat) -> bool: ...

@guppy.hugr_op(int_op("iu_to_s"))
# TODO: Use "iu_to_s" once we have lowering:
# https://github.com/CQCL/hugr/issues/1806
@guppy.custom(NoopCompiler())
def __int__(self: nat) -> int: ...

@guppy.hugr_op(int_op("inot"))
Expand Down
3 changes: 1 addition & 2 deletions guppylang/tys/builtin.py
Original file line number Diff line number Diff line change
Expand Up @@ -138,8 +138,7 @@ def _array_to_hugr(args: Sequence[Argument]) -> ht.Type:
elem_ty = ht.Option(ty_arg.ty.to_hugr())
hugr_arg = len_arg.to_hugr()

# TODO remove type ignore after Array type annotation fixed to include VariableArg
return hugr.std.collections.array.Array(elem_ty, hugr_arg) # type:ignore[arg-type]
return hugr.std.collections.array.Array(elem_ty, hugr_arg)


def _sized_iter_to_hugr(args: Sequence[Argument]) -> ht.Type:
Expand Down
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -83,7 +83,7 @@ members = ["execute_llvm"]
execute-llvm = { workspace = true }

# Uncomment these to test the latest dependency version during development
# hugr = { git = "https://github.com/CQCL/hugr", subdirectory = "hugr-py", rev = "861183e" }
hugr = { git = "https://github.com/CQCL/hugr", subdirectory = "hugr-py", rev = "e40b6c7" }
# tket2-exts = { git = "https://github.com/CQCL/tket2", subdirectory = "tket2-exts", rev = "eb7cc63"}
# tket2 = { git = "https://github.com/CQCL/tket2", subdirectory = "tket2-py", rev = "eb7cc63"}

Expand Down
68 changes: 55 additions & 13 deletions tests/integration/test_array_comprehension.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,12 +9,23 @@
from tests.util import compile_guppy


def test_basic(validate):
@compile_guppy
def test_basic_exec(validate, run_int_fn):
module = GuppyModule("test")

@guppy(module)
def test() -> array[int, 10]:
return array(i + 1 for i in range(10))

validate(test)
@guppy(module)
def main() -> int:
s = 0
for x in test():
s += x
return s

package = module.compile()
validate(package)
run_int_fn(package, expected=sum(i + 1 for i in range(10)))


def test_basic_linear(validate):
Expand All @@ -29,23 +40,42 @@ def test() -> array[qubit, 42]:
validate(module.compile())


def test_zero_length(validate):
@compile_guppy
def test_zero_length(validate, run_int_fn):
module = GuppyModule("test")

@guppy(module)
def test() -> array[float, 0]:
return array(i / 0 for i in range(0))

validate(test)
@guppy(module)
def main() -> int:
test()
return 0

package = module.compile()
validate(package)
run_int_fn(package, expected=0)

def test_capture(validate):
@compile_guppy

def test_capture(validate, run_int_fn):
module = GuppyModule("test")

@guppy(module)
def test(x: int) -> array[int, 42]:
return array(i + x for i in range(42))

validate(test)
@guppy(module)
def main() -> int:
s = 0
for x in test(3):
s += x
return s

package = module.compile()
validate(package)
run_int_fn(package, expected=sum(i + 3 for i in range(42)))


@pytest.mark.skip("See https://github.com/CQCL/hugr/issues/1625")
def test_capture_struct(validate):
module = GuppyModule("test")

Expand All @@ -71,12 +101,24 @@ def test() -> float:
validate(test)


def test_nested_left(validate):
@compile_guppy
def test_nested_left(validate, run_int_fn):
module = GuppyModule("test")

@guppy(module)
def test() -> array[array[int, 10], 20]:
return array(array(x + y for y in range(10)) for x in range(20))

validate(test)
@guppy(module)
def main() -> int:
s = 0
for xs in test():
for x in xs:
s += x
return s

package = module.compile()
validate(package)
run_int_fn(package, expected=sum(x + y for y in range(10) for x in range(20)))


def test_generic(validate):
Expand Down
3 changes: 1 addition & 2 deletions tests/integration/test_unpack.py
Original file line number Diff line number Diff line change
Expand Up @@ -69,8 +69,7 @@ def main() -> int:

compiled = module.compile()
validate(compiled)
# TODO: Enable execution test once array lowering is fully supported
# run_int_fn(compiled, expected=9)
run_int_fn(compiled, expected=10)


def test_unpack_tuple_starred(validate, run_int_fn):
Expand Down
Loading
Loading