Skip to content
This repository has been archived by the owner on May 22, 2023. It is now read-only.

[Pass] Support Function and If in Normalize pass. #268

Merged
merged 4 commits into from
Oct 12, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
32 changes: 32 additions & 0 deletions src/relax/transform/normalize.cc
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,38 @@ class NormalizeMutator : public ExprMutatorBase {
return builder_->Normalize(ExprMutatorBase::VisitExpr(expr));
}

Expr VisitExpr_(const FunctionNode* op) {
Expr body = this->VisitWithNewScope(op->body);

if (body.same_as(op->body)) {
return GetRef<Expr>(op);
} else {
return Function(op->params, body, op->ret_type, op->ret_shape, op->attrs);
}
}

Expr VisitExpr_(const IfNode* op) {
Expr guard = this->VisitExpr(op->cond);
Expr true_b = this->VisitWithNewScope(op->true_branch);
Expr false_b = this->VisitWithNewScope(op->false_branch);
if (op->cond.same_as(guard) && op->true_branch.same_as(true_b) &&
op->false_branch.same_as(false_b)) {
return GetRef<Expr>(op);
} else {
return If(guard, true_b, false_b, op->span);
}
}

Expr VisitWithNewScope(const Expr& expr) {
builder_->BeginBindingBlock();
Expr ret = this->VisitExpr(expr);
BindingBlock prologue = builder_->EndBlock();
if (!prologue->bindings.empty()) {
ret = SeqExpr({prologue}, ret);
}
return ret;
}

Expr VisitExpr_(const SeqExprNode* op) final {
bool all_blocks_unchanged = true;
Array<BindingBlock> blocks;
Expand Down
19 changes: 19 additions & 0 deletions tests/python/relax/test_parser.py
Original file line number Diff line number Diff line change
Expand Up @@ -871,5 +871,24 @@ def k(x: Tensor((32, 32), "float32"), w: Tensor((32, 32), "float32")) -> Tensor:
check_shape(gv2_bind.var, ("n", "n"))


def test_class_normalize():
@tvm.script.ir_module
class InputModule:
@R.function
def mul_add(x: Tensor) -> Tensor:
return R.multiply(R.add(x, x), R.add(x, x))

# The parser automatically normalizes the input AST to the following ANF form
@tvm.script.ir_module
class OutputModule:
@R.function
def mul_add(x: Tensor) -> Tensor:
gv = relax.add(x, x)
gv1 = relax.add(x, x)
return R.multiply(gv, gv1)

assert_structural_equal(InputModule, OutputModule)


if __name__ == "__main__":
pytest.main([__file__])
97 changes: 90 additions & 7 deletions tests/python/relax/test_transform.py
Original file line number Diff line number Diff line change
Expand Up @@ -492,24 +492,107 @@ def foo(x: Tensor((d,), "float32")):
assert cast_expr.dtype == "int64"


def test_to_anf():
def test_normalize_function():
m = tir.Var("m", "int64")
n = tir.Var("n", "int64")
type_anno = relax.DynTensorType(ndim=2, dtype="float16")
x = relax.Var("x", [m, n], type_anno)

# Note: the parser automatically normalize the IR written in TVMScript,
# so we manually construct the function here.
mul_add = relax.Function(
[x],
relax.op.multiply(relax.op.add(x, x), relax.op.add(x, x)),
ret_type=type_anno,
ret_shape=relax.RuntimeDepShape(),
)
mul_add = mul_add.with_attr("global_symbol", "mul_add")
before_mod = tvm.IRModule.from_expr(mul_add)

after_mod = relax.transform.Normalize()(before_mod)

@tvm.script.ir_module
class Expected:
@R.function
def mul_add(x: Tensor((m, n), "float16")) -> Tensor(None, "float16", ndim=2):
gv = R.add(x, x)
gv1 = relax.add(x, x)
return R.multiply(gv, gv1)

assert_structural_equal(after_mod, Expected)


def test_normalize_if():
cond = relax.Var("cond", [], type_annotation=relax.DynTensorType(0, "bool"))
x = relax.Var("x", [tir.IntImm("int64", 1)], type_annotation=relax.DynTensorType(1, "float32"))
# TODO(relax-team): add type and shape inference for IfNode
y = relax.Var("y")

# Note: the parser automatically normalize the IR written in TVMScript,
# so we manually construct the function and If here.
f = relax.Function(
[cond, x],
relax.SeqExpr(
[
relax.BindingBlock(
[
relax.VarBinding(
y,
relax.If(
cond,
relax.op.multiply(relax.op.add(x, x), relax.op.add(x, x)),
relax.op.add(relax.op.multiply(x, x), relax.op.multiply(x, x)),
),
)
]
)
],
y,
),
ret_type=relax.DynTensorType(1, "float32"),
ret_shape=relax.RuntimeDepShape(),
)

f = f.with_attr("global_symbol", "f")
before_mod = tvm.IRModule.from_expr(f)
after_mod = relax.transform.Normalize()(before_mod)

@tvm.script.ir_module
class TestNormalizeInputModule:
class Expected:
@R.function
def f(
cond: Tensor((), "bool"), x: Tensor((1,), "float32")
) -> Tensor(None, "float32", ndim=1):
if cond:
gv = R.add(x, x)
gv1 = R.add(x, x)
y = R.multiply(gv, gv1)
else:
gv = R.multiply(x, x)
gv1 = R.multiply(x, x)
y = R.add(gv, gv1)
return y

assert_structural_equal(after_mod, Expected)


def test_normalize_no_op():
# the normalize pass should be no-op for IR in ANF
@tvm.script.ir_module
class ANFMod1:
@R.function
def f(x: Tensor(_, "float32")):
gv = relax.add(x, x)
gv1 = relax.add(gv, gv)
gv2 = relax.add(gv, gv1)
return (gv, gv2)

before_mod = TestNormalizeInputModule
before_mod = ANFMod1
after_mod = relax.transform.Normalize()(before_mod)
assert_structural_equal(before_mod, after_mod, map_free_vars=True)


def test_to_anf_no_op():
@tvm.script.ir_module
class TestANFNoOp:
class ANFMod2:
@R.function
def foo(x: Tensor((m, n), "float32")):
with relax.dataflow():
Expand All @@ -518,7 +601,7 @@ def foo(x: Tensor((m, n), "float32")):
relax.output(gv0)
return gv0

mod = TestANFNoOp
mod = ANFMod2
mod_post = relax.transform.Normalize()(mod)

assert_structural_equal(mod, mod_post)
Expand Down