From 50f860b51a33cc7de2960d7e6c2fca48e061a6dd Mon Sep 17 00:00:00 2001
From: YuchenJin
Date: Sat, 8 Oct 2022 22:49:34 -0700
Subject: [PATCH 1/4] Support Function and If in Normalize pass.

---
 src/relax/transform/normalize.cc     | 32 ++++++++++++++++++++++
 tests/python/relax/test_parser.py    | 19 +++++++++++++
 tests/python/relax/test_transform.py | 40 +++++++++++++++++++++++-----
 3 files changed, 84 insertions(+), 7 deletions(-)

diff --git a/src/relax/transform/normalize.cc b/src/relax/transform/normalize.cc
index 5dcc003d39..bc5faf96b1 100644
--- a/src/relax/transform/normalize.cc
+++ b/src/relax/transform/normalize.cc
@@ -40,6 +40,38 @@ class NormalizeMutator : public ExprMutatorBase {
     return builder_->Normalize(ExprMutatorBase::VisitExpr(expr));
   }
 
+  Expr VisitExpr_(const FunctionNode* op) {
+    Expr body = this->VisitWithNewScope(op->body);
+
+    if (body.same_as(op->body)) {
+      return GetRef<Expr>(op);
+    } else {
+      return Function(op->params, body, op->ret_type, op->attrs);
+    }
+  }
+
+  Expr VisitExpr_(const IfNode* op) {
+    Expr guard = this->VisitExpr(op->cond);
+    Expr true_b = this->VisitWithNewScope(op->true_branch);
+    Expr false_b = this->VisitWithNewScope(op->false_branch);
+    if (op->cond.same_as(guard) && op->true_branch.same_as(true_b) &&
+        op->false_branch.same_as(false_b)) {
+      return GetRef<Expr>(op);
+    } else {
+      return If(guard, true_b, false_b, op->span);
+    }
+  }
+
+  // Visit expr inside a fresh binding scope; any bindings it emits become
+  // the prologue of a new SeqExpr wrapping the result.
+  Expr VisitWithNewScope(const Expr& expr) {
+    builder_->BeginBindingBlock();
+    Expr ret = this->VisitExpr(expr);
+    BindingBlock prologue = builder_->EndBlock();
+    if (!prologue->bindings.empty()) {
+      ret = SeqExpr({prologue}, ret);
+    }
+    return ret;
+  }
+
   Expr VisitExpr_(const SeqExprNode* op) final {
     bool all_blocks_unchanged = true;
     Array<BindingBlock> blocks;
diff --git a/tests/python/relax/test_parser.py b/tests/python/relax/test_parser.py
index aec36836b1..b7f196d399 100644
--- a/tests/python/relax/test_parser.py
+++ b/tests/python/relax/test_parser.py
@@ -871,5 +871,24 @@ def k(x: Tensor((32, 32), "float32"), w: Tensor((32, 32), "float32")) -> Tensor:
     check_shape(gv2_bind.var, ("n", "n"))
 
 
+def test_class_normalize():
+    @tvm.script.ir_module
+    class InputModule:
+        @R.function
+        def mul_add(x: Tensor) -> Tensor:
+            return R.multiply(R.add(x, x), R.add(x, x))
+
+    # The parser automatically normalizes the input AST to the following ANF form
+    @tvm.script.ir_module
+    class OutputModule:
+        @R.function
+        def mul_add(x: Tensor) -> Tensor:
+            gv = relax.add(x, x)
+            gv1 = relax.add(x, x)
+            return R.multiply(gv, gv1)
+
+    assert_structural_equal(InputModule, OutputModule)
+
+
 if __name__ == "__main__":
     pytest.main([__file__])
diff --git a/tests/python/relax/test_transform.py b/tests/python/relax/test_transform.py
index 7faba3ad67..ed3cb04148 100644
--- a/tests/python/relax/test_transform.py
+++ b/tests/python/relax/test_transform.py
@@ -492,9 +492,37 @@ def foo(x: Tensor((d,), "float32")):
     assert cast_expr.dtype == "int64"
 
 
-def test_to_anf():
+def test_normalize():
+    m = tir.Var("m", "int64")
+    n = tir.Var("n", "int64")
+    type_anno = relax.DynTensorType(ndim=2, dtype="float16")
+    x = relax.Var("x", [m, n], type_anno)
+
+    mul_add = relax.Function(
+        [x],
+        relax.op.multiply(relax.op.add(x, x), relax.op.add(x, x)),
+        ret_type=type_anno,
+    )
+    mul_add = mul_add.with_attr("global_symbol", "mul_add")
+    before_mod = tvm.IRModule.from_expr(mul_add)
+
+    after_mod = relax.transform.Normalize()(before_mod)
+
     @tvm.script.ir_module
-    class TestNormalizeInputModule:
+    class Expected:
+        @R.function
+        def mul_add(x: Tensor((m, n), "float16")) -> Tensor(None, "float16", ndim=2):
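+            # Both adds were nested inside the multiply before the pass ran;
+            # Normalize emits one binding per nested call, in evaluation order.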
+            gv = relax.add(x, x)
+            gv1 = relax.add(x, x)
+            return R.multiply(gv, gv1)
+
+    assert_structural_equal(after_mod, Expected)
+
+
+def test_normalize_no_op():
+    # the Normalize pass should be a no-op for IR that is already in ANF
+    @tvm.script.ir_module
+    class ANFMod1:
         @R.function
         def f(x: Tensor(_, "float32")):
             gv = relax.add(x, x)
@@ -502,14 +530,12 @@ def f(x: Tensor(_, "float32")):
             gv1 = relax.add(gv, gv)
             gv2 = relax.add(gv, gv1)
             return (gv, gv2)
 
-    before_mod = TestNormalizeInputModule
+    before_mod = ANFMod1
     after_mod = relax.transform.Normalize()(before_mod)
     assert_structural_equal(before_mod, after_mod, map_free_vars=True)
 
-
-def test_to_anf_no_op():
     @tvm.script.ir_module
-    class TestANFNoOp:
+    class ANFMod2:
         @R.function
         def foo(x: Tensor((m, n), "float32")):
             with relax.dataflow():
                 relax.output(gv0)
             return gv0
 
-    mod = TestANFNoOp
+    mod = ANFMod2
     mod_post = relax.transform.Normalize()(mod)
     assert_structural_equal(mod, mod_post)

From a734d66b08e6f7402cd75337d41b16177978a744 Mon Sep 17 00:00:00 2001
From: YuchenJin
Date: Mon, 10 Oct 2022 08:30:23 -0700
Subject: [PATCH 2/4] Use structural equality for expr_memo_.

---
 src/relax/ir/block_builder.cc        | 2 +-
 tests/python/relax/test_parser.py    | 3 +--
 tests/python/relax/test_transform.py | 3 +--
 3 files changed, 3 insertions(+), 5 deletions(-)

diff --git a/src/relax/ir/block_builder.cc b/src/relax/ir/block_builder.cc
index 5ceca53a11..3398266c13 100644
--- a/src/relax/ir/block_builder.cc
+++ b/src/relax/ir/block_builder.cc
@@ -355,7 +355,7 @@ class BlockBuilderNode::ExprNormalizer : public ExprFunctor<Expr(const Expr&)> {
  private:
   std::unordered_map<Var, Var, ObjectPtrHash, ObjectPtrEqual> var_memo_;
-  std::unordered_map<Expr, Var, ObjectPtrHash, ObjectPtrEqual> expr_memo_;
+  std::unordered_map<Expr, Var, StructuralHash, StructuralEqual> expr_memo_;
 };
 
 // Helper function to check if a ShapeExpr is constant shape or tuple of constant shape
diff --git a/tests/python/relax/test_parser.py b/tests/python/relax/test_parser.py
index b7f196d399..380bb0e630 100644
--- a/tests/python/relax/test_parser.py
+++ b/tests/python/relax/test_parser.py
@@ -884,8 +884,7 @@ class OutputModule:
         @R.function
         def mul_add(x: Tensor) -> Tensor:
             gv = relax.add(x, x)
-            gv1 = relax.add(x, x)
-            return R.multiply(gv, gv1)
+            return R.multiply(gv, gv)
 
     assert_structural_equal(InputModule, OutputModule)
diff --git a/tests/python/relax/test_transform.py b/tests/python/relax/test_transform.py
index ed3cb04148..7204becf65 100644
--- a/tests/python/relax/test_transform.py
+++ b/tests/python/relax/test_transform.py
@@ -513,8 +513,7 @@ class Expected:
         @R.function
         def mul_add(x: Tensor((m, n), "float16")) -> Tensor(None, "float16", ndim=2):
             gv = relax.add(x, x)
-            gv1 = relax.add(x, x)
-            return R.multiply(gv, gv1)
+            return R.multiply(gv, gv)
 
     assert_structural_equal(after_mod, Expected)

From 8f3eb44bcfd47e7eac32988854bc808e0b406091 Mon Sep 17 00:00:00 2001
From: YuchenJin
Date: Mon, 10 Oct 2022 12:50:27 -0700
Subject: [PATCH 3/4] Change back to pointer equality for expr_memo_; Add more
 tests.
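
The structural-equality memo from the previous commit made the normalizer
deduplicate syntactically identical bindings (the updated tests folded the
two add(x, x) calls into a single gv). Reverting to pointer equality keeps
one binding per call site, presumably leaving that kind of deduplication to
a dedicated pass. A minimal sketch of the distinction the memo key relies
on (illustrative only, not part of this patch; assumes a relax build
matching this series):

    import tvm
    from tvm import relax

    x = relax.Var("x", [], relax.DynTensorType(0, "float32"))
    a = relax.op.add(x, x)
    b = relax.op.add(x, x)

    assert a is not b                     # distinct nodes: a pointer-keyed memo misses
    assert tvm.ir.structural_equal(a, b)  # same structure: a structural memo would hit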
---
 src/relax/ir/block_builder.cc        |  2 +-
 tests/python/relax/test_parser.py    |  3 +-
 tests/python/relax/test_transform.py | 62 ++++++++++++++++++++++++++--
 3 files changed, 62 insertions(+), 5 deletions(-)

diff --git a/src/relax/ir/block_builder.cc b/src/relax/ir/block_builder.cc
index 3398266c13..5ceca53a11 100644
--- a/src/relax/ir/block_builder.cc
+++ b/src/relax/ir/block_builder.cc
@@ -355,7 +355,7 @@ class BlockBuilderNode::ExprNormalizer : public ExprFunctor<Expr(const Expr&)> {
  private:
   std::unordered_map<Var, Var, ObjectPtrHash, ObjectPtrEqual> var_memo_;
-  std::unordered_map<Expr, Var, StructuralHash, StructuralEqual> expr_memo_;
+  std::unordered_map<Expr, Var, ObjectPtrHash, ObjectPtrEqual> expr_memo_;
 };
 
 // Helper function to check if a ShapeExpr is constant shape or tuple of constant shape
diff --git a/tests/python/relax/test_parser.py b/tests/python/relax/test_parser.py
index 380bb0e630..b7f196d399 100644
--- a/tests/python/relax/test_parser.py
+++ b/tests/python/relax/test_parser.py
@@ -884,7 +884,8 @@ class OutputModule:
         @R.function
         def mul_add(x: Tensor) -> Tensor:
             gv = relax.add(x, x)
-            return R.multiply(gv, gv)
+            gv1 = relax.add(x, x)
+            return R.multiply(gv, gv1)
 
     assert_structural_equal(InputModule, OutputModule)
diff --git a/tests/python/relax/test_transform.py b/tests/python/relax/test_transform.py
index 7204becf65..a28d60baa1 100644
--- a/tests/python/relax/test_transform.py
+++ b/tests/python/relax/test_transform.py
@@ -492,12 +492,14 @@ def foo(x: Tensor((d,), "float32")):
     assert cast_expr.dtype == "int64"
 
 
-def test_normalize():
+def test_normalize_function():
     m = tir.Var("m", "int64")
     n = tir.Var("n", "int64")
     type_anno = relax.DynTensorType(ndim=2, dtype="float16")
     x = relax.Var("x", [m, n], type_anno)
 
+    # Note: the parser automatically normalizes the IR written in TVMScript,
+    # so we manually construct the function here.
     mul_add = relax.Function(
         [x],
         relax.op.multiply(relax.op.add(x, x), relax.op.add(x, x)),
@@ -512,8 +514,62 @@ def test_normalize():
     class Expected:
         @R.function
         def mul_add(x: Tensor((m, n), "float16")) -> Tensor(None, "float16", ndim=2):
-            gv = relax.add(x, x)
-            return R.multiply(gv, gv)
+            gv = R.add(x, x)
+            gv1 = relax.add(x, x)
+            return R.multiply(gv, gv1)
+
+    assert_structural_equal(after_mod, Expected)
+
+
+def test_normalize_if():
+    cond = relax.Var("cond", [], type_annotation=relax.DynTensorType(0, "bool"))
+    x = relax.Var("x", [tir.IntImm("int64", 1)], type_annotation=relax.DynTensorType(1, "float32"))
+    # TODO(relax-team): add type and shape inference for IfNode
+    y = relax.Var("y")
+
+    # Note: the parser automatically normalizes the IR written in TVMScript,
+    # so we manually construct the function and If here.
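+    # The If below is deliberately non-ANF: each branch nests two calls
+    # inside a third, so Normalize should give every branch its own
+    # binding block ending in that branch's value (see Expected below).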
+    f = relax.Function(
+        [cond, x],
+        relax.SeqExpr(
+            [
+                relax.BindingBlock(
+                    [
+                        relax.VarBinding(
+                            y,
+                            relax.If(
+                                cond,
+                                relax.op.multiply(relax.op.add(x, x), relax.op.add(x, x)),
+                                relax.op.add(relax.op.multiply(x, x), relax.op.multiply(x, x)),
+                            ),
+                        )
+                    ]
+                )
+            ],
+            y,
+        ),
+        ret_type=relax.DynTensorType(1, "float32"),
+    )
+
+    f = f.with_attr("global_symbol", "f")
+    before_mod = tvm.IRModule.from_expr(f)
+    after_mod = relax.transform.Normalize()(before_mod)
+
+    @tvm.script.ir_module
+    class Expected:
+        @R.function
+        def f(
+            cond: Tensor((), "bool"), x: Tensor((1,), "float32")
+        ) -> Tensor(None, "float32", ndim=1):
+            if cond:
+                gv = R.add(x, x)
+                gv1 = R.add(x, x)
+                y = R.multiply(gv, gv1)
+            else:
+                gv = R.multiply(x, x)
+                gv1 = R.multiply(x, x)
+                y = R.add(gv, gv1)
+            return y
+
+    assert_structural_equal(after_mod, Expected)

From 1171e37a00bdca0deb606b3a263ba1ea04c5639a Mon Sep 17 00:00:00 2001
From: YuchenJin
Date: Mon, 10 Oct 2022 17:51:06 -0700
Subject: [PATCH 4/4] rebase.

---
 src/relax/transform/normalize.cc     | 2 +-
 tests/python/relax/test_transform.py | 2 ++
 2 files changed, 3 insertions(+), 1 deletion(-)

diff --git a/src/relax/transform/normalize.cc b/src/relax/transform/normalize.cc
index bc5faf96b1..8beb2b6b5a 100644
--- a/src/relax/transform/normalize.cc
+++ b/src/relax/transform/normalize.cc
@@ -46,7 +46,7 @@ class NormalizeMutator : public ExprMutatorBase {
     if (body.same_as(op->body)) {
       return GetRef<Expr>(op);
     } else {
-      return Function(op->params, body, op->ret_type, op->attrs);
+      return Function(op->params, body, op->ret_type, op->ret_shape, op->attrs);
     }
   }
 
diff --git a/tests/python/relax/test_transform.py b/tests/python/relax/test_transform.py
index a28d60baa1..0a3272bfe8 100644
--- a/tests/python/relax/test_transform.py
+++ b/tests/python/relax/test_transform.py
@@ -504,6 +504,7 @@ def test_normalize_function():
         [x],
         relax.op.multiply(relax.op.add(x, x), relax.op.add(x, x)),
         ret_type=type_anno,
+        ret_shape=relax.RuntimeDepShape(),
     )
     mul_add = mul_add.with_attr("global_symbol", "mul_add")
     before_mod = tvm.IRModule.from_expr(mul_add)
@@ -549,6 +550,7 @@ def test_normalize_if():
             y,
         ),
         ret_type=relax.DynTensorType(1, "float32"),
+        ret_shape=relax.RuntimeDepShape(),
     )
 
     f = f.with_attr("global_symbol", "f")
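
For reference, a minimal way to exercise the pass after this series
(illustrative sketch only, not part of the patches; the function name and
shapes are arbitrary, and it mirrors test_normalize_function above):

    import tvm
    from tvm import relax, tir

    m = tir.Var("m", "int64")
    n = tir.Var("n", "int64")
    x = relax.Var("x", [m, n], relax.DynTensorType(2, "float32"))

    # Nested calls: not in ANF until the Normalize pass runs.
    body = relax.op.multiply(relax.op.add(x, x), relax.op.add(x, x))
    f = relax.Function(
        [x],
        body,
        ret_type=relax.DynTensorType(2, "float32"),
        ret_shape=relax.RuntimeDepShape(),  # required after PATCH 4/4
    )
    f = f.with_attr("global_symbol", "main")
    mod = tvm.IRModule.from_expr(f)

    # Each nested call now gets its own binding inside the function body.
    print(relax.transform.Normalize()(mod))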