Skip to content

Commit

Permalink
working rn
Browse files Browse the repository at this point in the history
  • Loading branch information
hypercubestart committed Aug 20, 2020
1 parent 6398f0e commit 348573e
Show file tree
Hide file tree
Showing 3 changed files with 13 additions and 33 deletions.
8 changes: 8 additions & 0 deletions include/tvm/relay/transform.h
Original file line number Diff line number Diff line change
Expand Up @@ -210,6 +210,14 @@ TVM_DLL Pass FastMath();
*/
TVM_DLL Pass InferType();

/*!
* \brief Infer the type of all functions in a module.
*
* This pass should be used when typechecking modules
* with mutually recursive functions.
*
* \return The pass.
*/
TVM_DLL Pass InferTypeAll();

/*!
Expand Down
2 changes: 1 addition & 1 deletion src/relay/transforms/type_infer.cc
Original file line number Diff line number Diff line change
Expand Up @@ -789,7 +789,7 @@ IRModule InferTypeAll(const IRModule& mod) {
mod->AddUnchecked(var, func);
}

TypeInferencer ti = TypeInferencer(mod, GlobalVar("all"));
TypeInferencer ti = TypeInferencer(mod, GlobalVar("dummy"));

// second pass, fill in constraints
for (const auto& var : globalvars) {
Expand Down
36 changes: 4 additions & 32 deletions tests/python/relay/test_type_infer.py
Original file line number Diff line number Diff line change
Expand Up @@ -364,7 +364,7 @@ def test_let_polymorphism():
int32 = relay.TensorType((), "int32")
tvm.ir.assert_structural_equal(body.checked_type, relay.TupleType([int32, relay.TupleType([])]))

def test_mutual_recursion2():
def test_mutual_recursion():
# f(x) = if x > 0 then g(x - 1) else 0
# g(y) = if y > 0 then f(y - 1) else 0
tensortype = relay.TensorType((), 'float32')
Expand Down Expand Up @@ -399,7 +399,7 @@ def body(var, call_func):
tvm.ir.assert_structural_equal(mod[f_gv].checked_type, expected)
tvm.ir.assert_structural_equal(mod[g_gv].checked_type, expected)

def test_mutual_recursion_list_sum():
def test_mutual_recursion_adt():
# f[A](x: List[A]) = match x {
# Cons(a, Nil) => a
# Cons(_, b) => g(b)
Expand Down Expand Up @@ -446,34 +446,6 @@ def body(var, call_func, type_param):
tvm.ir.assert_structural_equal(mod[f_gv].checked_type, expected)
tvm.ir.assert_structural_equal(mod[g_gv].checked_type, expected)


def test_id_mutual():
# f(x) = g(x)
# g(y) = f(y)

tx = relay.TypeVar("x")
ty = relay.TypeVar("y")

x = relay.Var("x", tx)
y = relay.Var("y", ty)

f_gv = relay.GlobalVar('f')
g_gv = relay.GlobalVar('g')

def body(var, call_func, type_param):
body = relay.Call(call_func, [var])
func = relay.Function([var], body, type_params=type_param)
return func

f = body(x, g_gv, [tx])
g = body(y, f_gv, [ty])

mod = tvm.IRModule()
mod.add_unchecked(f_gv, f)
mod.add_unchecked(g_gv, g)
mod = transform.InferTypeAll()(mod)


def test_if():
choice_t = relay.FuncType([], relay.scalar_type('bool'))
f = relay.Var('f', choice_t)
Expand Down Expand Up @@ -502,5 +474,5 @@ def test_if():
# test_adt_match()
# test_let_polymorphism()
# test_if()
# test_id_mutual()
test_mutual_recursion_list_sum()
test_mutual_recursion()
test_mutual_recursion_adt()

0 comments on commit 348573e

Please sign in to comment.