Skip to content

Commit

Permalink
feat: Add nat type
Browse files Browse the repository at this point in the history
  • Loading branch information
mark-koch committed Jun 19, 2024
1 parent 13154dd commit b664d2c
Show file tree
Hide file tree
Showing 6 changed files with 193 additions and 1 deletion.
4 changes: 4 additions & 0 deletions guppylang/checker/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
int_type_def,
linst_type_def,
list_type_def,
nat_type_def,
none_type_def,
tuple_type_def,
)
Expand Down Expand Up @@ -70,6 +71,7 @@ def default() -> "Globals":
tuple_type_def,
none_type_def,
bool_type_def,
nat_type_def,
int_type_def,
float_type_def,
list_type_def,
Expand All @@ -94,6 +96,8 @@ def get_instance_func(self, ty: Type | TypeDef, name: str) -> CallableDef | None
match kind:
case NumericType.Kind.Bool:
type_defn = bool_type_def
case NumericType.Kind.Nat:
type_defn = nat_type_def
case NumericType.Kind.Int:
type_defn = int_type_def
case NumericType.Kind.Float:
Expand Down
20 changes: 20 additions & 0 deletions guppylang/prelude/_internal.py
Original file line number Diff line number Diff line change
Expand Up @@ -278,6 +278,26 @@ def check(self, args: list[ast.expr], ty: Type) -> tuple[ast.expr, Subst]:
return self._get_func().check_call(args, ty, self.node, self.ctx)


class NatTruedivCompiler(CustomCallCompiler):
"""Compiler for the `nat.__truediv__` method."""

def compile(self, args: list[OutPortV]) -> list[OutPortV]:
from .builtins import Float, Nat

# Compile `truediv` using float arithmetic
[left, right] = args
[left] = Nat.__float__.compile_call(
[left], [], self.dfg, self.graph, self.globals, self.node
)
[right] = Nat.__float__.compile_call(
[right], [], self.dfg, self.graph, self.globals, self.node
)
[out] = Float.__truediv__.compile_call(
[left, right], [], self.dfg, self.graph, self.globals, self.node
)
return [out]


class IntTruedivCompiler(CustomCallCompiler):
"""Compiler for the `int.__truediv__` method."""

Expand Down
150 changes: 150 additions & 0 deletions guppylang/prelude/builtins.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@
FloatFloordivCompiler,
FloatModCompiler,
IntTruedivCompiler,
NatTruedivCompiler,
ReversingChecker,
UnsupportedChecker,
float_op,
Expand All @@ -34,6 +35,7 @@
int_type_def,
linst_type_def,
list_type_def,
nat_type_def,
)

builtins = GuppyModule("builtins", import_builtins=False)
Expand All @@ -51,6 +53,10 @@ def py(*_args: Any) -> Any:
raise GuppyError("`py` can only by used in a Guppy context")


class nat:
"""Class to import in order to use nats."""


@guppy.extend_type(builtins, bool_type_def)
class Bool:
@guppy.custom(builtins, NoopCompiler())
Expand Down Expand Up @@ -110,6 +116,9 @@ def __mod__(self: bool, other: bool) -> bool: ...
@guppy.custom(builtins, checker=BoolArithChecker())
def __mul__(self: bool, other: bool) -> bool: ...

@guppy.hugr_op(builtins, DummyOp("ifrombool")) # TODO: Widen to INT_WIDTH
def __nat__(self: bool) -> nat: ...

@guppy.custom(builtins, checker=BoolArithChecker())
def __ne__(self: bool, other: bool) -> bool: ...

Expand Down Expand Up @@ -194,6 +203,141 @@ def __trunc__(self: bool) -> int: ...
def __xor__(self: bool, other: bool) -> bool: ...


@guppy.extend_type(builtins, nat_type_def)
class Nat:
@guppy.custom(builtins, NoopCompiler())
def __abs__(self: nat) -> nat: ...

@guppy.hugr_op(builtins, int_op("iadd"))
def __add__(self: nat, other: nat) -> nat: ...

@guppy.hugr_op(builtins, int_op("iand"))
def __and__(self: nat, other: nat) -> nat: ...

@guppy.hugr_op(builtins, DummyOp("itobool")) # TODO: Only works with width 1 ints
def __bool__(self: nat) -> bool: ...

@guppy.custom(builtins, NoopCompiler())
def __ceil__(self: nat) -> nat: ...

@guppy.hugr_op(builtins, int_op("idivmod_u", num_params=2))
def __divmod__(self: nat, other: nat) -> tuple[nat, nat]: ...

@guppy.hugr_op(builtins, int_op("ieq"))
def __eq__(self: nat, other: nat) -> bool: ...

@guppy.hugr_op(builtins, int_op("convert_u", "arithmetic.conversions"))
def __float__(self: nat) -> float: ...

@guppy.custom(builtins, NoopCompiler())
def __floor__(self: nat) -> nat: ...

@guppy.hugr_op(builtins, int_op("idiv_u", num_params=2))
def __floordiv__(self: nat, other: nat) -> nat: ...

@guppy.hugr_op(builtins, int_op("ige_u"))
def __ge__(self: nat, other: nat) -> bool: ...

@guppy.hugr_op(builtins, int_op("igt_u"))
def __gt__(self: nat, other: nat) -> bool: ...

@guppy.hugr_op(builtins, DummyOp("iu_to_s")) # TODO
def __int__(self: nat) -> int: ...

@guppy.hugr_op(builtins, int_op("inot"))
def __invert__(self: nat) -> nat: ...

@guppy.hugr_op(builtins, int_op("ile_u"))
def __le__(self: nat, other: nat) -> bool: ...

@guppy.hugr_op(builtins, int_op("ishl", num_params=2))
def __lshift__(self: nat, other: nat) -> nat: ...

@guppy.hugr_op(builtins, int_op("ilt_u"))
def __lt__(self: nat, other: nat) -> bool: ...

@guppy.hugr_op(builtins, int_op("imod_u", num_params=2))
def __mod__(self: nat, other: nat) -> int: ...

@guppy.hugr_op(builtins, int_op("imul"))
def __mul__(self: nat, other: nat) -> nat: ...

@guppy.custom(builtins, NoopCompiler())
def __nat__(self: nat) -> nat: ...

@guppy.hugr_op(builtins, int_op("ine"))
def __ne__(self: nat, other: nat) -> bool: ...

@guppy.custom(builtins, checker=DunderChecker("__nat__"), higher_order_value=False)
def __new__(x): ...

@guppy.hugr_op(builtins, int_op("ior"))
def __or__(self: nat, other: nat) -> nat: ...

@guppy.custom(builtins, NoopCompiler())
def __pos__(self: nat) -> nat: ...

@guppy.hugr_op(builtins, DummyOp("ipow")) # TODO
def __pow__(self: nat, other: nat) -> nat: ...

@guppy.hugr_op(builtins, int_op("iadd"), ReversingChecker())
def __radd__(self: nat, other: nat) -> nat: ...

@guppy.hugr_op(builtins, int_op("rand"), ReversingChecker())
def __rand__(self: nat, other: nat) -> nat: ...

@guppy.hugr_op(builtins, int_op("idivmod_u", num_params=2), ReversingChecker())
def __rdivmod__(self: nat, other: nat) -> tuple[nat, nat]: ...

@guppy.hugr_op(builtins, int_op("idiv_u", num_params=2), ReversingChecker())
def __rfloordiv__(self: nat, other: nat) -> nat: ...

@guppy.hugr_op(builtins, int_op("ishl", num_params=2), ReversingChecker())
def __rlshift__(self: nat, other: nat) -> nat: ...

@guppy.hugr_op(builtins, int_op("imod_u", num_params=2), ReversingChecker())
def __rmod__(self: nat, other: nat) -> nat: ...

@guppy.hugr_op(builtins, int_op("imul"), ReversingChecker())
def __rmul__(self: nat, other: nat) -> nat: ...

@guppy.hugr_op(builtins, int_op("ior"), ReversingChecker())
def __ror__(self: nat, other: nat) -> nat: ...

@guppy.custom(builtins, NoopCompiler())
def __round__(self: nat) -> nat: ...

@guppy.hugr_op(builtins, DummyOp("ipow"), ReversingChecker()) # TODO
def __rpow__(self: nat, other: nat) -> nat: ...

@guppy.hugr_op(builtins, int_op("ishr", num_params=2), ReversingChecker())
def __rrshift__(self: nat, other: nat) -> nat: ...

@guppy.hugr_op(builtins, int_op("ishr", num_params=2))
def __rshift__(self: nat, other: nat) -> nat: ...

@guppy.hugr_op(builtins, int_op("isub"), ReversingChecker())
def __rsub__(self: nat, other: nat) -> nat: ...

@guppy.custom(builtins, NatTruedivCompiler(), ReversingChecker())
def __rtruediv__(self: nat, other: nat) -> float: ...

@guppy.hugr_op(builtins, int_op("ixor"), ReversingChecker())
def __rxor__(self: nat, other: nat) -> nat: ...

@guppy.hugr_op(builtins, int_op("isub"))
def __sub__(self: nat, other: nat) -> nat: ...

@guppy.custom(builtins, NatTruedivCompiler())
def __truediv__(self: nat, other: nat) -> float: ...

@guppy.custom(builtins, NoopCompiler())
def __trunc__(self: nat) -> nat: ...

@guppy.hugr_op(builtins, int_op("ixor"))
def __xor__(self: nat, other: nat) -> nat: ...


@guppy.extend_type(builtins, int_type_def)
class Int:
@guppy.hugr_op(builtins, int_op("iabs")) # TODO: Maybe wrong? (signed vs unsigned!)
Expand Down Expand Up @@ -253,6 +397,9 @@ def __mod__(self: int, other: int) -> int: ...
@guppy.hugr_op(builtins, int_op("imul"))
def __mul__(self: int, other: int) -> int: ...

@guppy.hugr_op(builtins, DummyOp("is_to_u")) # TODO
def __nat__(self: int) -> nat: ...

@guppy.hugr_op(builtins, int_op("ine"))
def __ne__(self: int, other: int) -> bool: ...

Expand Down Expand Up @@ -385,6 +532,9 @@ def __mod__(self: float, other: float) -> float: ...
@guppy.hugr_op(builtins, float_op("fmul"), CoercingChecker())
def __mul__(self: float, other: float) -> float: ...

@guppy.hugr_op(builtins, float_op("trunc_u", "arithmetic.conversions"))
def __nat__(self: float) -> nat: ...

@guppy.hugr_op(builtins, float_op("fne"), CoercingChecker())
def __ne__(self: float, other: float) -> bool: ...

Expand Down
3 changes: 3 additions & 0 deletions guppylang/tys/builtin.py
Original file line number Diff line number Diff line change
Expand Up @@ -142,6 +142,9 @@ def _list_to_hugr(args: Sequence[Argument]) -> tys.Type:
bool_type_def = _NumericTypeDef(
DefId.fresh(), "bool", None, NumericType(NumericType.Kind.Bool)
)
nat_type_def = _NumericTypeDef(
DefId.fresh(), "nat", None, NumericType(NumericType.Kind.Nat)
)
int_type_def = _NumericTypeDef(
DefId.fresh(), "int", None, NumericType(NumericType.Kind.Int)
)
Expand Down
3 changes: 2 additions & 1 deletion guppylang/tys/ty.py
Original file line number Diff line number Diff line change
Expand Up @@ -245,6 +245,7 @@ class Kind(Enum):
"""The different kinds of numeric types."""

Bool = "bool"
Nat = "nat"
Int = "int"
Float = "float"

Expand All @@ -260,7 +261,7 @@ def to_hugr(self) -> tys.Type:
match self.kind:
case NumericType.Kind.Bool:
return SumType([NoneType(), NoneType()]).to_hugr()
case NumericType.Kind.Int:
case NumericType.Kind.Nat | NumericType.Kind.Int:
return tys.Type(
tys.Opaque(
extension="arithmetic.int.types",
Expand Down
14 changes: 14 additions & 0 deletions tests/integration/test_arithmetic.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
from guppylang.prelude.builtins import nat
from tests.util import compile_guppy


Expand Down Expand Up @@ -34,6 +35,19 @@ def add(x: bool, y: bool) -> int:
validate(add)


def test_nat(validate):
@compile_guppy
def foo(
a: nat, b: nat, c: bool, d: int, e: float
) -> tuple[nat, bool, int, float, float]:
b, c, d, e = nat(b), nat(c), nat(d), nat(e)
x = a + b * c // d - e
y = e / b
return x, bool(x), int(x), float(x), y

validate(foo)


def test_float_coercion(validate):
@compile_guppy
def coerce(x: int, y: float, z: bool) -> float:
Expand Down

0 comments on commit b664d2c

Please sign in to comment.