Skip to content

Commit

Permalink
add prelude test
Browse files Browse the repository at this point in the history
  • Loading branch information
hypercubestart committed Aug 20, 2020
1 parent bc8383d commit 6398f0e
Showing 1 changed file with 49 additions and 1 deletion.
50 changes: 49 additions & 1 deletion tests/python/relay/test_type_infer.py
Original file line number Diff line number Diff line change
Expand Up @@ -399,6 +399,54 @@ def body(var, call_func):
tvm.ir.assert_structural_equal(mod[f_gv].checked_type, expected)
tvm.ir.assert_structural_equal(mod[g_gv].checked_type, expected)

def test_mutual_recursion_list_sum():
# f[A](x: List[A]) = match x {
# Cons(a, Nil) => a
# Cons(_, b) => g(b)
# }
# g[B](y: List[B]) = match y {
# Cons(a, Nil) => a
# Cons(_, b) => f(b)
# }
p = Prelude()
l = p.l
A = relay.TypeVar("x")
B = relay.TypeVar("y")

x = relay.Var("x", l(A))
y = relay.Var("y", l(B))

f_gv = relay.GlobalVar('f')
g_gv = relay.GlobalVar('g')

def body(var, call_func, type_param):
a = relay.Var("a")
b = relay.Var("b")
body = relay.Match(
var,
[
relay.Clause(relay.PatternConstructor(p.cons, [relay.PatternVar(a), relay.PatternConstructor(p.nil)]), a),
relay.Clause(relay.PatternConstructor(p.cons, [relay.PatternWildcard(), relay.PatternVar(b)]), relay.Call(call_func, [b]))
],
complete=False
)
func = relay.Function([var], body, type_params=type_param)
return func

f = body(x, g_gv, [A])
g = body(y, f_gv, [B])

mod = p.mod
mod.add_unchecked(f_gv, f)
mod.add_unchecked(g_gv, g)
mod = transform.InferTypeAll()(mod)

tv = relay.TypeVar("test")
expected = relay.FuncType([l(tv)], tv, [tv])
tvm.ir.assert_structural_equal(mod[f_gv].checked_type, expected)
tvm.ir.assert_structural_equal(mod[g_gv].checked_type, expected)


def test_id_mutual():
# f(x) = g(x)
# g(y) = f(y)
Expand Down Expand Up @@ -455,4 +503,4 @@ def test_if():
# test_let_polymorphism()
# test_if()
# test_id_mutual()
test_mutual_recursion2()
test_mutual_recursion_list_sum()

0 comments on commit 6398f0e

Please sign in to comment.