Skip to content

Commit

Permalink
Revert "[Unity] Avoid trivial var2 = var1 bindings in pattern match…
Browse files Browse the repository at this point in the history
…er (#15578)"

This reverts commit 567848e.
  • Loading branch information
masahi committed Sep 26, 2023
1 parent 6c5a435 commit 0a6a617
Show file tree
Hide file tree
Showing 2 changed files with 13 additions and 35 deletions.
33 changes: 9 additions & 24 deletions src/relax/ir/dataflow_matcher.cc
Original file line number Diff line number Diff line change
Expand Up @@ -951,31 +951,16 @@ class PatternRewriter : ExprMutator {
}
}

Expr VisitExpr(const Expr& expr) final {
auto node = ExprMutator::VisitExpr(expr);
if (pattern_) {
if (auto matches_opt = ExtractMatchedExpr(pattern_.value(), node, bindings_)) {
Expr rewritten_expr = rewriter_func_(node, matches_opt.value());
if (!rewritten_expr.same_as(node)) {
rewritten_expr = builder_->Normalize(rewritten_expr);

// If the rewriter returns a variable (e.g. when rewriting
// from `R.add(x, R.const(0.0))` to `x`), the variable
// should be dereferenced to avoid trivial `var_2 = var_1`
// bindings. This lookup is done using the builder_ instead
// of the bindings_, as the previous `builder_->Normalize`
// call may have introduced variable bindings.
if (auto opt_var = rewritten_expr.as<Var>()) {
if (auto binding = builder_->LookupBinding(opt_var.value())) {
rewritten_expr = binding.value();
}
}
memo_[expr.get()] = rewritten_expr;
return rewritten_expr;
}
}
Expr VisitExpr_(const CallNode* call_node) final {
auto call = ExprMutator::VisitExpr_(call_node);
if (!pattern_) {
return call;
} else if (auto matches_opt = ExtractMatchedExpr(pattern_.value(), call, bindings_)) {
auto rewriten_expr = rewriter_func_(call, matches_opt.value());
memo_[call_node] = rewriten_expr;
return rewriten_expr;
}
return node;
return call;
}

BindingBlock VisitBindingBlock_(const DataflowBlockNode* block_node) final {
Expand Down
15 changes: 4 additions & 11 deletions tests/python/relax/test_dataflow_pattern.py
Original file line number Diff line number Diff line change
Expand Up @@ -1287,8 +1287,7 @@ def before(
x: R.Tensor((1024,)),
):
with R.dataflow():
y = R.add(x, x)
out = R.add(R.const(1.0), y)
out = R.add(R.const(1.0), x)
R.output(out)
return out

Expand All @@ -1297,10 +1296,8 @@ def expected(
x: R.Tensor((1024,)),
):
with R.dataflow():
y = R.add(x, x)
out = R.add(y, R.const(2.0))
out = R.add(x, R.const(2.0))
R.output(out)

return out

pattern_add = is_op("relax.add")
Expand All @@ -1311,14 +1308,10 @@ def expected(

pattern = pattern_op(pattern_arg, pattern_const)

def rewriter(expr, matches):
def rewriter(_expr, matches):
op = matches[pattern_op]
arg = matches[pattern_arg]
const = matches[pattern_const].data.numpy()
if const.shape == tuple() and const[()] == 1.0:
return rx.Call(op, [arg, rx.const(2.0)])
else:
return expr
return rx.Call(op, [arg, rx.const(2.0)])

after = rewrite_call(pattern, rewriter, before)
tvm.ir.assert_structural_equal(after, expected)
Expand Down

0 comments on commit 0a6a617

Please sign in to comment.