diff --git a/include/tvm/ir_pass.h b/include/tvm/ir_pass.h index 6fb7236157b9..a45bbbb91fd8 100644 --- a/include/tvm/ir_pass.h +++ b/include/tvm/ir_pass.h @@ -21,8 +21,21 @@ namespace tvm { namespace ir { -using Halide::Internal::equal; -using Halide::Internal::simplify; +inline bool Equal(Expr a, Expr b) { + return Halide::Internal::equal(a, b); +} + +inline bool Equal(Stmt a, Stmt b) { + return Halide::Internal::equal(a, b); +} + +inline Expr Simplify(Expr a) { + return Halide::Internal::simplify(a); +} + +inline Stmt Simplify(Stmt a) { + return Halide::Internal::simplify(a); +} /*! * \brief Schedule s' dependent operations. diff --git a/src/c_api/c_api_pass.cc b/src/c_api/c_api_pass.cc index c667069ce189..10ffe95f653d 100644 --- a/src/c_api/c_api_pass.cc +++ b/src/c_api/c_api_pass.cc @@ -17,20 +17,20 @@ TVM_REGISTER_API(_pass_Simplify) .set_body([](const ArgStack& args, RetValue *ret) { CHECK(args.at(0).type_id == kNodeHandle); if (dynamic_cast(args.at(0).sptr.get())) { - *ret = simplify(args.at(0).operator Expr()); + *ret = Simplify(args.at(0).operator Expr()); } else { - *ret = simplify(args.at(0).operator Stmt()); + *ret = Simplify(args.at(0).operator Stmt()); } }); -TVM_REGISTER_API(_pass_equal) +TVM_REGISTER_API(_pass_Equal) .set_body([](const ArgStack& args, RetValue *ret) { CHECK(args.at(0).type_id == kNodeHandle); CHECK(args.at(1).type_id == kNodeHandle); if (dynamic_cast(args.at(0).sptr.get())) { - *ret = equal(args.at(0).operator Expr(), args.at(1).operator Expr()); + *ret = Equal(args.at(0).operator Expr(), args.at(1).operator Expr()); } else { - *ret = equal(args.at(0).operator Stmt(), args.at(1).operator Stmt()); + *ret = Equal(args.at(0).operator Stmt(), args.at(1).operator Stmt()); } }); diff --git a/tests/python/test_pass_basic.py b/tests/python/test_pass_basic.py index 25ff5ea717c2..ebffc58805f3 100644 --- a/tests/python/test_pass_basic.py +++ b/tests/python/test_pass_basic.py @@ -3,11 +3,11 @@ def test_simplify(): x = tvm.Var('x') e1 = tvm.ir_pass.Simplify(x + 2 + 1) - assert(tvm.ir_pass.equal(e1, x + 3)) + assert(tvm.ir_pass.Equal(e1, x + 3)) e2 = tvm.ir_pass.Simplify(x * 3 + 5 * x) - assert(tvm.ir_pass.equal(e2, x * 8)) + assert(tvm.ir_pass.Equal(e2, x * 8)) e3 = tvm.ir_pass.Simplify(x - x / 3 * 3) - assert(tvm.ir_pass.equal(e3, tvm.make.Mod(x, 3))) + assert(tvm.ir_pass.Equal(e3, tvm.make.Mod(x, 3))) def test_verify_ssa():