Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Relax] Allow composition of DFPattern replacements #16732

Merged

Conversation

Lunderberg
Copy link
Contributor

The rewrite_call function accepts a DFPattern, and a function to
rewrite expressions matching that pattern. Often, the rewriting
function will perform additional validation that cannot be expressed
within the DFPattern itself. If this additional validation fails,
the rewriter function will return the matched expression unmodified.

Prior to this commit, an OrPattern that matches on the first branch,
but whose rewriter function does not apply a modification, would
prevent the second branch from being checked. This commit updates the
ExprPatternRewriter to check both branches of a OrPattern, if the
rewriter function of the first branch does not modify the result.

@Lunderberg
Copy link
Contributor Author

This PR is currently marked as a draft, because it depends on the refactor done in #16730.

Copy link
Contributor

@slyubomirsky slyubomirsky left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The changes seem easy to understand given the PR this is built on and I think the new behavior is intuitive as well: If the first branch of the OR does not result in a rewrite, check the second. In principle, I think this should be noted in whatever top-level documentation there is for the pattern matching grammar.

ICHECK(matches_top_level);

// Special handling if the user-supplied pattern is a `OrPattern`.
// While the `ExtractMatchedExpr` can handle match the
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Looks like a typo. I assume it's supposed to be "handle matching," correct?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Whoops, this was a typo. It should be "handle matching", and I've updated the PR with the correction.

The `rewrite_call` function accepts a `DFPattern`, and a function to
rewrite expressions matching that pattern.  Often, the rewriting
function will perform additional validation that cannot be expressed
within the `DFPattern` itself.  If this additional validation fails,
the rewriter function will return the matched expression unmodified.

Prior to this commit, an `OrPattern` that matches on the first branch,
but whose rewriter function does not apply a modification, would
prevent the second branch from being checked.  This commit updates the
`ExprPatternRewriter` to check both branches of a `OrPattern`, if the
rewriter function of the first branch does not modify the result.
@Lunderberg Lunderberg force-pushed the relax_composable_dataflow_rewriter branch from 7e9054b to 5e0a54f Compare March 26, 2024 13:47
@Lunderberg Lunderberg marked this pull request as ready for review March 26, 2024 13:47
@Lunderberg
Copy link
Contributor Author

The pre-requisite PR #16730 has landed, so this PR is now rebased on top of main and marked as ready. Thank you @slyubomirsky for the review, and so I think it's just waiting on CI now.

@Lunderberg Lunderberg merged commit 86b5a13 into apache:main Mar 27, 2024
19 of 20 checks passed
@Lunderberg Lunderberg deleted the relax_composable_dataflow_rewriter branch March 27, 2024 17:50
@MasterJH5574
Copy link
Contributor

Hi @Lunderberg, it seems that this PR introduces new regression. The original issue came from CI in MLC-LLM https://ci.mlc.ai/blue/organizations/jenkins/mlc-llm/detail/PR-2068/1/pipeline/. And I spent some time adapting your test test_backtrack_if_rewriter_returns_no_op and reducing it to the minimal I can get.

Running the test below

def test_backtrack_if_rewriter_returns_no_op():
    pat_match_no_rewrite = is_op("relax.add")(wildcard(), wildcard())

    pat_arg = wildcard()
    pat_zeros = is_op("relax.zeros")(wildcard())
    pat_add = is_op("relax.add")(pat_arg, pat_zeros)

    pat = pat_match_no_rewrite | pat_add

    def rewriter(expr, matches):
        print(f"matching {pat} to {matches[pat]}")
        assert isinstance(matches[pat], rx.Call)
        if pat_match_no_rewrite in matches:
            return expr   ## <<== return the expr
        elif pat_add in matches:
            return expr   ## <<== also return the expr
        else:
            raise RuntimeError("Pattern matched, but neither branch matched")

    @R.function(private=True)
    def before():
        with R.dataflow():
            A = R.ones([64, 128], "int32")
            B = R.zeros([64, 128], "int32")
            C = R.add(A, B)

            R.output(C)
        return C

    after = rewrite_call(pat, rewriter, before)

will yield the following output:

matching OrPattern(Op(relax.add)(*, *) | Op(relax.add)(*, Op(relax.zeros)(*))) to R.add(A, B)
matching OrPattern(Op(relax.add)(*, *) | Op(relax.add)(*, Op(relax.zeros)(*))) to R.add(A, B)
matching OrPattern(Op(relax.add)(*, *) | Op(relax.add)(*, Op(relax.zeros)(*))) to C
[22:36:29] /home/ruihang/Workspace/tvm/src/relax/ir/block_builder.cc:65: Warning: BlockBuilder destroyed with remaining blocks!
Traceback (most recent call last):
  File "/home/ruihang/Workspace/tvm/tests/python/relax/test_dataflow_pattern.py", line 1949, in <module>
    test_backtrack_if_rewriter_returns_no_op()
  File "/home/ruihang/Workspace/tvm/tests/python/relax/test_dataflow_pattern.py", line 1945, in test_backtrack_if_rewriter_returns_no_op
    after = rewrite_call(pat, rewriter, before)
            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/ruihang/Workspace/tvm/python/tvm/relax/dpl/rewrite.py", line 59, in rewrite_call
    return ffi.rewrite_call(pattern, rewriter, func)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "tvm/_ffi/_cython/./packed_func.pxi", line 332, in tvm._ffi._cy3.core.PackedFuncBase.__call__
  File "tvm/_ffi/_cython/./packed_func.pxi", line 263, in tvm._ffi._cy3.core.FuncCall
  File "tvm/_ffi/_cython/./packed_func.pxi", line 252, in tvm._ffi._cy3.core.FuncCall3
  File "tvm/_ffi/_cython/./base.pxi", line 182, in tvm._ffi._cy3.core.CHECK_CALL
  File "/home/ruihang/Workspace/tvm/python/tvm/_ffi/base.py", line 481, in raise_last_ffi_error
    raise py_err
  File "tvm/_ffi/_cython/./packed_func.pxi", line 56, in tvm._ffi._cy3.core.tvm_callback
  File "/home/ruihang/Workspace/tvm/tests/python/relax/test_dataflow_pattern.py", line 1921, in rewriter
    assert isinstance(matches[pat], rx.Call)
AssertionError

As you can see in the error message, with this PR, the pattern matcher mapped an add to a single Var. This breaks our assumption that what matched is always an add. Would you mind helping dig and fix this issue? Thanks in ahead.

@MasterJH5574
Copy link
Contributor

@Lunderberg
Copy link
Contributor Author

Ooh, that's a weird one, and thank you for the test case. I can reproduce it on my side, and will look in more detail tomorrow. For now, the first thing I'm noticing is that the failure occurs on the third time that rewriter is called, where I'd expect it to only be called twice (once for each branch).

matches[pat_match_no_rewrite] is filled during the first call, and matches[pat_add] is filled during the second call. During the third call, which fails the assert, matches[pat_match_no_rewrite] is filled again, which seems rather odd.

@Lunderberg
Copy link
Contributor Author

Lunderberg commented Apr 1, 2024

This is an interesting one.

  1. The DFPatternMatcher will unwrap variables when attempting to find a match (link), as part of ExtractMatchedExpr.
  2. The ExprPatternMatcher will save the value of each variable binding is saved in order to (link).
  3. The ExprPatternMatcher checks for a match in the override for Expr VisitExpr(const Expr&).

So the call to rewriter with the assert fails is made when ExprPatternMatcher is checking the variable usage in the body of the SeqExpr. This gets unwrapped by DFPatternMatcher, and identified as an expression whose value is equal to the match itself.

This potential bug existed all the way back to #15578. However, until #16732, the relax::Var body of a SeqExpr would never be the first match encountered. Any time a relax::Var would have matched, the expression to which it was bound would have matched first, hiding the later match.

I think the resolution is to update condition (1). Where currently, we unwrap all variable bindings as part of a match, we should instead unwrap all variable bindings except for the top-level of the match. This could be most easily implemented by having the ExprPatternMatcher::VisitExpr skip the call to TryRewrite for relax::Var. I'll try this out and see if it resolves the issue.

Edit: Or even better, this should be handled when filling in the OrPattern branches for the match. I had this step to match sure that the matches[pat] was still populated, even when the OrPattern was unwrapped. However, where previously this would be populated by ExtractMatchedExpr, and would contain the expression after TryGetValOfVar, I incorrectly filled it with the expression before applying TryGetValOfVar.

Lunderberg added a commit to Lunderberg/tvm that referenced this pull request Apr 1, 2024
This resolves a bug that was introduced in
apache#16732.  If a rewriter function
returned a no-op, and the pattern-match continued, then the `matches`
provided to the rewriter function in subsequent calls would contain
a variable to which the matched expression was bound, not the matched
expression itself.  (e.g. For a match of `C = R.add(A,B)`, passing `C`
to the rewriter instead of `R.add(A,B)`.)

This bug was caused by incorrect re-wrapping of `OrPattern` in
`ExprPatternRewriter`.  Prior to
apache#16732, all pattern-match results
were populated by `ExtractMatchExpr`, and contained the result after
applying `TryGetValOfVar`.  When re-wrapping the result of an
`OrPattern`, apache#16732 populated the
additional matches with the result before applying `TryGetValOfVar`.
This commit fixes the bug by applying `TryGetValOfVar`.
@Lunderberg
Copy link
Contributor Author

@MasterJH5574 I have a fix implemented in #16828. Can you try it on your side and verify that it resolves the issue in the end-to-end flow?

@MasterJH5574
Copy link
Contributor

@Lunderberg Thanks for the swift fix! It works for e2e compilation now.

MasterJH5574 pushed a commit to MasterJH5574/tvm that referenced this pull request Apr 1, 2024
This resolves a bug that was introduced in
apache#16732.  If a rewriter function
returned a no-op, and the pattern-match continued, then the `matches`
provided to the rewriter function in subsequent calls would contain
a variable to which the matched expression was bound, not the matched
expression itself.  (e.g. For a match of `C = R.add(A,B)`, passing `C`
to the rewriter instead of `R.add(A,B)`.)

This bug was caused by incorrect re-wrapping of `OrPattern` in
`ExprPatternRewriter`.  Prior to
apache#16732, all pattern-match results
were populated by `ExtractMatchExpr`, and contained the result after
applying `TryGetValOfVar`.  When re-wrapping the result of an
`OrPattern`, apache#16732 populated the
additional matches with the result before applying `TryGetValOfVar`.
This commit fixes the bug by applying `TryGetValOfVar`.
MasterJH5574 pushed a commit to MasterJH5574/tvm that referenced this pull request Apr 1, 2024
This resolves a bug that was introduced in
apache#16732.  If a rewriter function
returned a no-op, and the pattern-match continued, then the `matches`
provided to the rewriter function in subsequent calls would contain
a variable to which the matched expression was bound, not the matched
expression itself.  (e.g. For a match of `C = R.add(A,B)`, passing `C`
to the rewriter instead of `R.add(A,B)`.)

This bug was caused by incorrect re-wrapping of `OrPattern` in
`ExprPatternRewriter`.  Prior to
apache#16732, all pattern-match results
were populated by `ExtractMatchExpr`, and contained the result after
applying `TryGetValOfVar`.  When re-wrapping the result of an
`OrPattern`, apache#16732 populated the
additional matches with the result before applying `TryGetValOfVar`.
This commit fixes the bug by applying `TryGetValOfVar`.
tqchen pushed a commit that referenced this pull request Apr 1, 2024
* [Relax][Bugfix] Provide the full Expr to pattern-match rewriter

This resolves a bug that was introduced in
#16732.  If a rewriter function
returned a no-op, and the pattern-match continued, then the `matches`
provided to the rewriter function in subsequent calls would contain
a variable to which the matched expression was bound, not the matched
expression itself.  (e.g. For a match of `C = R.add(A,B)`, passing `C`
to the rewriter instead of `R.add(A,B)`.)

This bug was caused by incorrect re-wrapping of `OrPattern` in
`ExprPatternRewriter`.  Prior to
#16732, all pattern-match results
were populated by `ExtractMatchExpr`, and contained the result after
applying `TryGetValOfVar`.  When re-wrapping the result of an
`OrPattern`, #16732 populated the
additional matches with the result before applying `TryGetValOfVar`.
This commit fixes the bug by applying `TryGetValOfVar`.

* Update with PR link of bugfix
thaisacs pushed a commit to thaisacs/tvm that referenced this pull request Apr 3, 2024
[Relax] Allow composition of DFPattern replacements

The `rewrite_call` function accepts a `DFPattern`, and a function to
rewrite expressions matching that pattern.  Often, the rewriting
function will perform additional validation that cannot be expressed
within the `DFPattern` itself.  If this additional validation fails,
the rewriter function will return the matched expression unmodified.

Prior to this commit, an `OrPattern` that matches on the first branch,
but whose rewriter function does not apply a modification, would
prevent the second branch from being checked.  This commit updates the
`ExprPatternRewriter` to check both branches of a `OrPattern`, if the
rewriter function of the first branch does not modify the result.
thaisacs pushed a commit to thaisacs/tvm that referenced this pull request Apr 3, 2024
…he#16828)

* [Relax][Bugfix] Provide the full Expr to pattern-match rewriter

This resolves a bug that was introduced in
apache#16732.  If a rewriter function
returned a no-op, and the pattern-match continued, then the `matches`
provided to the rewriter function in subsequent calls would contain
a variable to which the matched expression was bound, not the matched
expression itself.  (e.g. For a match of `C = R.add(A,B)`, passing `C`
to the rewriter instead of `R.add(A,B)`.)

This bug was caused by incorrect re-wrapping of `OrPattern` in
`ExprPatternRewriter`.  Prior to
apache#16732, all pattern-match results
were populated by `ExtractMatchExpr`, and contained the result after
applying `TryGetValOfVar`.  When re-wrapping the result of an
`OrPattern`, apache#16732 populated the
additional matches with the result before applying `TryGetValOfVar`.
This commit fixes the bug by applying `TryGetValOfVar`.

* Update with PR link of bugfix
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

3 participants