Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: Add nat type #254

Merged
merged 9 commits into from
Jun 24, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 4 additions & 0 deletions guppylang/checker/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
int_type_def,
linst_type_def,
list_type_def,
nat_type_def,
none_type_def,
tuple_type_def,
)
Expand Down Expand Up @@ -70,6 +71,7 @@ def default() -> "Globals":
tuple_type_def,
none_type_def,
bool_type_def,
nat_type_def,
int_type_def,
float_type_def,
list_type_def,
Expand All @@ -92,6 +94,8 @@ def get_instance_func(self, ty: Type | TypeDef, name: str) -> CallableDef | None
return None
case NumericType(kind):
match kind:
case NumericType.Kind.Nat:
type_defn = nat_type_def
case NumericType.Kind.Int:
type_defn = int_type_def
case NumericType.Kind.Float:
Expand Down
20 changes: 20 additions & 0 deletions guppylang/prelude/_internal.py
Original file line number Diff line number Diff line change
Expand Up @@ -241,6 +241,26 @@ def check(self, args: list[ast.expr], ty: Type) -> tuple[ast.expr, Subst]:
return args, subst


class NatTruedivCompiler(CustomCallCompiler):
"""Compiler for the `nat.__truediv__` method."""

def compile(self, args: list[OutPortV]) -> list[OutPortV]:
from .builtins import Float, Nat

# Compile `truediv` using float arithmetic
[left, right] = args
[left] = Nat.__float__.compile_call(
[left], [], self.dfg, self.graph, self.globals, self.node
)
[right] = Nat.__float__.compile_call(
[right], [], self.dfg, self.graph, self.globals, self.node
)
[out] = Float.__truediv__.compile_call(
[left, right], [], self.dfg, self.graph, self.globals, self.node
)
return [out]


class IntTruedivCompiler(CustomCallCompiler):
"""Compiler for the `int.__truediv__` method."""

Expand Down
150 changes: 150 additions & 0 deletions guppylang/prelude/builtins.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@
FloatFloordivCompiler,
FloatModCompiler,
IntTruedivCompiler,
NatTruedivCompiler,
ReversingChecker,
UnsupportedChecker,
float_op,
Expand All @@ -33,6 +34,7 @@
int_type_def,
linst_type_def,
list_type_def,
nat_type_def,
)

builtins = GuppyModule("builtins", import_builtins=False)
Expand All @@ -50,6 +52,10 @@ def py(*_args: Any) -> Any:
raise GuppyError("`py` can only by used in a Guppy context")


class nat:
"""Class to import in order to use nats."""


@guppy.extend_type(builtins, bool_type_def)
class Bool:
@guppy.hugr_op(builtins, logic_op("And", [tys.TypeArg(tys.BoundedNatArg(n=2))]))
Expand All @@ -61,13 +67,151 @@ def __bool__(self: bool) -> bool: ...
@guppy.hugr_op(builtins, int_op("ifrombool"))
def __int__(self: bool) -> int: ...

@guppy.hugr_op(builtins, DummyOp("ifrombool")) # TODO: Widen to INT_WIDTH
def __nat__(self: bool) -> nat: ...

@guppy.custom(builtins, checker=DunderChecker("__bool__"), higher_order_value=False)
def __new__(x): ...

@guppy.hugr_op(builtins, logic_op("Or", [tys.TypeArg(tys.BoundedNatArg(n=2))]))
def __or__(self: bool, other: bool) -> bool: ...


@guppy.extend_type(builtins, nat_type_def)
class Nat:
@guppy.custom(builtins, NoopCompiler())
def __abs__(self: nat) -> nat: ...

@guppy.hugr_op(builtins, int_op("iadd"))
def __add__(self: nat, other: nat) -> nat: ...

@guppy.hugr_op(builtins, int_op("iand"))
def __and__(self: nat, other: nat) -> nat: ...

@guppy.hugr_op(builtins, DummyOp("itobool")) # TODO: Only works with width 1 ints
def __bool__(self: nat) -> bool: ...

@guppy.custom(builtins, NoopCompiler())
def __ceil__(self: nat) -> nat: ...

@guppy.hugr_op(builtins, int_op("idivmod_u", num_params=2))
def __divmod__(self: nat, other: nat) -> tuple[nat, nat]: ...

@guppy.hugr_op(builtins, int_op("ieq"))
def __eq__(self: nat, other: nat) -> bool: ...

@guppy.hugr_op(builtins, int_op("convert_u", "arithmetic.conversions"))
def __float__(self: nat) -> float: ...

@guppy.custom(builtins, NoopCompiler())
def __floor__(self: nat) -> nat: ...

@guppy.hugr_op(builtins, int_op("idiv_u", num_params=2))
def __floordiv__(self: nat, other: nat) -> nat: ...

@guppy.hugr_op(builtins, int_op("ige_u"))
def __ge__(self: nat, other: nat) -> bool: ...

@guppy.hugr_op(builtins, int_op("igt_u"))
def __gt__(self: nat, other: nat) -> bool: ...

@guppy.hugr_op(builtins, DummyOp("iu_to_s")) # TODO
def __int__(self: nat) -> int: ...

@guppy.hugr_op(builtins, int_op("inot"))
def __invert__(self: nat) -> nat: ...

@guppy.hugr_op(builtins, int_op("ile_u"))
def __le__(self: nat, other: nat) -> bool: ...

@guppy.hugr_op(builtins, int_op("ishl", num_params=2))
def __lshift__(self: nat, other: nat) -> nat: ...

@guppy.hugr_op(builtins, int_op("ilt_u"))
def __lt__(self: nat, other: nat) -> bool: ...

@guppy.hugr_op(builtins, int_op("imod_u", num_params=2))
def __mod__(self: nat, other: nat) -> int: ...

@guppy.hugr_op(builtins, int_op("imul"))
def __mul__(self: nat, other: nat) -> nat: ...

@guppy.custom(builtins, NoopCompiler())
def __nat__(self: nat) -> nat: ...

@guppy.hugr_op(builtins, int_op("ine"))
def __ne__(self: nat, other: nat) -> bool: ...

@guppy.custom(builtins, checker=DunderChecker("__nat__"), higher_order_value=False)
def __new__(x): ...

@guppy.hugr_op(builtins, int_op("ior"))
def __or__(self: nat, other: nat) -> nat: ...

@guppy.custom(builtins, NoopCompiler())
def __pos__(self: nat) -> nat: ...

@guppy.hugr_op(builtins, DummyOp("ipow")) # TODO
def __pow__(self: nat, other: nat) -> nat: ...

@guppy.hugr_op(builtins, int_op("iadd"), ReversingChecker())
def __radd__(self: nat, other: nat) -> nat: ...

@guppy.hugr_op(builtins, int_op("rand"), ReversingChecker())
def __rand__(self: nat, other: nat) -> nat: ...

@guppy.hugr_op(builtins, int_op("idivmod_u", num_params=2), ReversingChecker())
def __rdivmod__(self: nat, other: nat) -> tuple[nat, nat]: ...

@guppy.hugr_op(builtins, int_op("idiv_u", num_params=2), ReversingChecker())
def __rfloordiv__(self: nat, other: nat) -> nat: ...

@guppy.hugr_op(builtins, int_op("ishl", num_params=2), ReversingChecker())
def __rlshift__(self: nat, other: nat) -> nat: ...

@guppy.hugr_op(builtins, int_op("imod_u", num_params=2), ReversingChecker())
def __rmod__(self: nat, other: nat) -> nat: ...

@guppy.hugr_op(builtins, int_op("imul"), ReversingChecker())
def __rmul__(self: nat, other: nat) -> nat: ...

@guppy.hugr_op(builtins, int_op("ior"), ReversingChecker())
def __ror__(self: nat, other: nat) -> nat: ...

@guppy.custom(builtins, NoopCompiler())
def __round__(self: nat) -> nat: ...

@guppy.hugr_op(builtins, DummyOp("ipow"), ReversingChecker()) # TODO
def __rpow__(self: nat, other: nat) -> nat: ...

@guppy.hugr_op(builtins, int_op("ishr", num_params=2), ReversingChecker())
def __rrshift__(self: nat, other: nat) -> nat: ...

@guppy.hugr_op(builtins, int_op("ishr", num_params=2))
def __rshift__(self: nat, other: nat) -> nat: ...

@guppy.hugr_op(builtins, int_op("isub"), ReversingChecker())
def __rsub__(self: nat, other: nat) -> nat: ...

@guppy.custom(builtins, NatTruedivCompiler(), ReversingChecker())
def __rtruediv__(self: nat, other: nat) -> float: ...

@guppy.hugr_op(builtins, int_op("ixor"), ReversingChecker())
def __rxor__(self: nat, other: nat) -> nat: ...

@guppy.hugr_op(builtins, int_op("isub"))
def __sub__(self: nat, other: nat) -> nat: ...

@guppy.custom(builtins, NatTruedivCompiler())
def __truediv__(self: nat, other: nat) -> float: ...

@guppy.custom(builtins, NoopCompiler())
def __trunc__(self: nat) -> nat: ...

@guppy.hugr_op(builtins, int_op("ixor"))
def __xor__(self: nat, other: nat) -> nat: ...


@guppy.extend_type(builtins, int_type_def)
class Int:
@guppy.hugr_op(builtins, int_op("iabs")) # TODO: Maybe wrong? (signed vs unsigned!)
Expand Down Expand Up @@ -127,6 +271,9 @@ def __mod__(self: int, other: int) -> int: ...
@guppy.hugr_op(builtins, int_op("imul"))
def __mul__(self: int, other: int) -> int: ...

@guppy.hugr_op(builtins, DummyOp("is_to_u")) # TODO
def __nat__(self: int) -> nat: ...

@guppy.hugr_op(builtins, int_op("ine"))
def __ne__(self: int, other: int) -> bool: ...

Expand Down Expand Up @@ -259,6 +406,9 @@ def __mod__(self: float, other: float) -> float: ...
@guppy.hugr_op(builtins, float_op("fmul"), CoercingChecker())
def __mul__(self: float, other: float) -> float: ...

@guppy.hugr_op(builtins, float_op("trunc_u", "arithmetic.conversions"))
def __nat__(self: float) -> nat: ...

@guppy.hugr_op(builtins, float_op("fne"), CoercingChecker())
def __ne__(self: float, other: float) -> bool: ...

Expand Down
7 changes: 7 additions & 0 deletions guppylang/tys/builtin.py
Original file line number Diff line number Diff line change
Expand Up @@ -147,6 +147,9 @@ def _list_to_hugr(args: Sequence[Argument]) -> tys.Type:
always_linear=False,
to_hugr=lambda _: tys.Type(tys.SumType(tys.UnitSum(size=2))),
)
nat_type_def = _NumericTypeDef(
DefId.fresh(), "nat", None, NumericType(NumericType.Kind.Nat)
)
int_type_def = _NumericTypeDef(
DefId.fresh(), "int", None, NumericType(NumericType.Kind.Int)
)
Expand Down Expand Up @@ -175,6 +178,10 @@ def bool_type() -> OpaqueType:
return OpaqueType([], bool_type_def)


def int_type() -> NumericType:
return NumericType(NumericType.Kind.Int)


def list_type(element_ty: Type) -> OpaqueType:
return OpaqueType([TypeArg(element_ty)], list_type_def)

Expand Down
3 changes: 2 additions & 1 deletion guppylang/tys/ty.py
Original file line number Diff line number Diff line change
Expand Up @@ -244,6 +244,7 @@ class NumericType(TypeBase):
class Kind(Enum):
"""The different kinds of numeric types."""

Nat = "nat"
Int = "int"
Float = "float"

Expand All @@ -257,7 +258,7 @@ def linear(self) -> bool:
def to_hugr(self) -> tys.Type:
"""Computes the Hugr representation of the type."""
match self.kind:
case NumericType.Kind.Int:
case NumericType.Kind.Nat | NumericType.Kind.Int:
return tys.Type(
tys.Opaque(
extension="arithmetic.int.types",
Expand Down
6 changes: 3 additions & 3 deletions tests/error/type_errors/invert_not_int.err
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,6 @@ Guppy compilation failed. Error in file $FILE:6

4: @compile_guppy
5: def foo() -> int:
6: return ~True
^^^^
GuppyTypeError: Unary operator `~` not defined for argument of type `bool`
6: return ~()
^^
GuppyTypeError: Unary operator `~` not defined for argument of type `()`
2 changes: 1 addition & 1 deletion tests/error/type_errors/invert_not_int.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,4 +3,4 @@

@compile_guppy
def foo() -> int:
return ~True
return ~()
6 changes: 3 additions & 3 deletions tests/error/type_errors/unary_not_arith.err
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,6 @@ Guppy compilation failed. Error in file $FILE:6

4: @compile_guppy
5: def foo() -> int:
6: return -True
^^^^
GuppyTypeError: Unary operator `-` not defined for argument of type `bool`
6: return -()
^^
GuppyTypeError: Unary operator `-` not defined for argument of type `()`
2 changes: 1 addition & 1 deletion tests/error/type_errors/unary_not_arith.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,4 +3,4 @@

@compile_guppy
def foo() -> int:
return -True
return -()
14 changes: 14 additions & 0 deletions tests/integration/test_arithmetic.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
from guppylang.prelude.builtins import nat
from tests.util import compile_guppy


Expand Down Expand Up @@ -26,6 +27,19 @@ def add(x: int) -> int:
validate(add)


def test_nat(validate):
@compile_guppy
def foo(
a: nat, b: nat, c: bool, d: int, e: float
) -> tuple[nat, bool, int, float, float]:
b, c, d, e = nat(b), nat(c), nat(d), nat(e)
x = a + b * c // d - e
y = e / b
return x, bool(x), int(x), float(x), y

validate(foo)


def test_float_coercion(validate):
@compile_guppy
def coerce(x: int, y: float) -> float:
Expand Down
Loading