Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[TIR, analysis] Add CommutativeDeepEqual to handle commutativity in expression comparison #12761

Open
wants to merge 1 commit into
base: main
Choose a base branch
from

Conversation

zhangyicole
Copy link

@zhangyicole zhangyicole commented Sep 13, 2022

Refer to issue #10211. The CSE pass can't handle commutativity because the arith system may not be able to do the commutativity。

The determination of the equality of two expressions (PrimExpr) is to use the method of structured determination, that is, to traverse the hierarchical structure of the two expressions, and to judge while traversing, if the structures of the two expressions are the same, and the smallest If the child nodes (nodes that cannot be traversed, generally such as Var) are the same, the expressions are considered to be the same.

Before performing PrimExpr comparison, a series of rewrite rules will be used to rewrite expressions to solve some operational problems, such as x * y + x * z will be rewritten as x * (y + z), so that to deal with distributivity.
However, commutativity cannot be rewritten due to the characteristics of rewrite (it will fall into an infinite loop). This makes it impossible to compare the equality of some expressions, such as a * b *c != a * c * b, (a * b) * c != a * (b * c).

To solve this problem, one solution is to sort and rewrite the expressions according to the StructuralHash of the Var nodes in the expressions before comparing the expressions. If two expressions are equivalent if they satisfy the commutativity, then they will definitely produce equivalent expressions of the same structure after sorting.

Under the assumption that the two expressions have the same structure, the determination condition that the two expressions satisfying the commutativity are the same can be further refined to the same set of all elements in the sub-expressions satisfying the commutativity. The sub-expressions satisfying the commutativity in the expression can be grasped by constructing the expression syntax tree. The sub-expressions satisfying the commutativity are the sub-trees of the expression tree whose child nodes are identical (the child nodes are OP).

An example:

    Var: a, b, c, d, e

    StructuralHash(a > b >c >d > e)

    cse_var_1  = a * b *c + d + e

    cse_var_2 = b * a * c + e + d

Sort and rewrite cse_var_1 and cse_var_2, first extract their first subexpressions a * b *c and b * a * c, and rewrite them as a * b * c, and then extract the sub-expressions again ( a * b *c) + d + e and (a *b * c) + e + d, rewriting the sort as (a * b *c) + d + e, get the same expression.

cc @masahi @Hzfengsy @tqchen @FranckQC

@zhangyicole zhangyicole changed the title [TIR, analysis] Add expr hash sort in ExprDeepEqual. [TIR, analysis] Add expr hash sort in ExprDeepEqual Sep 13, 2022
@zhangyicole zhangyicole force-pushed the new_deep_equal branch 3 times, most recently from b0dd74f to 73dea99 Compare September 14, 2022 03:52
@masahi
Copy link
Member

masahi commented Sep 14, 2022

I can take a look at this tomorrow.

@tqchen
Copy link
Member

tqchen commented Sep 14, 2022

It would be useful to make it as a different equality comparator (rather than change DeepEqual's behavior, e.g. we can do CommunicativeDeepEqual as a subclass), as communicative rewrite is something that goes deeper.

Another possibility is add a canonicalization pass to canonicalize the expressions before CSE

@FranckQC
Copy link
Contributor

FranckQC commented Sep 14, 2022

Hello!

Thank you for the discussion and the PR.
Just a few thoughts here:

  • 1. Indeed, we knew that the existing Analyzer::Simplify() would not deal with commutativity, because it does not implement a normalization procedure that is guaranteed to converge towards the normal form (which would indeed imply sorting sub-terms, etc). Rather, all it does is it tries to "do its best" by rewriting some known patterns, with no guarantees to converge to a normal form. For this reason, they could not deal with commutativity in Simplify(), because that would indeed lead to non-terminating rewrite sequences (or more realistically return junk, as in practice they stop to rewrite after N rewrites are done). It often works fairly well in practice, but there is no guarantee of being complete (it behaves like a heuristic). However, it really must be correct (i.e, the result of simplify() must rely be equivalent to its input with algebraic laws).

  • 2. Before trying to make Analyzer::Simplify() able to deal with with commutativity, it could be useful to see if people are in practice facing the issue where a lot of redundant computations appear, but written differently due to the commutativity of some operators (like + and *). If so, it would be cool to see such concrete examples. To be honest, I don't even think that that many people are turning ON the already existing bool identify_equiv_terms of the CSE pass Pass CommonSubexprElimTIR, which uses Analyzer::Simplify(), which itself does what it can with associativity, neutral elements, etc. These things are probably pretty rare, and commutativity is probably too.

Although I designed the pass in a way that it can potentially identify terms that are equivalent according to any equivalence relation (instead of just the syntactical equality ExprDeepEqual), it might not necessary be often needed in practice. This design that allows to use a custom equivalence relation did not cost much more, so I went for it, in case it could become useful to someone one day for a particular case, so that it would not be needed to write another CSE pass for dealing with that. But I don't necessary think that we should do comparisons modulo equivalence often, and this is why by default bool identify_equiv_terms of the Pass CommonSubexprElimTIR is set to false.

  • 3. If we decide that Analyzer::Simplify() (or another new Analyzer!) should deal with commutativity, then it should itself deal with commutativity, rather than baking commutativity into ExprDeepEqual which is supposed to be just a deep syntactical equality, and is used as such in many many places of TVM's codebase (not just the CSE). So clearly it should not being changed (as @tqchen noted too).

  • 4. Remember that normalizing terms properly in order to deal with commutativity (which indeed includes sorting sub-terms) will likely be computationally expensive, which will make compilation of ML models longer. Actually, just the smaller work that Analyzer::Simplify() does is already time consuming, and that's probably why people leave the bool identify_equiv_terms of the CSE pass set to false (which is the default behavior, as I wrote earlier). It might make people want even less to turn ON this bool identify_equiv_terms. Perhaps the pseudo-normalization that Analyzer::Simplify() does is not too bad as a compromise: still usable in practice when needed (i.e does not take too long), and deals with most simplifications needed (although not commutativity).

  • 5. If all the other algebraic properties (associativity, simplification of neutral elements, etc) are still done by the pseudo-normalization Analyzer::Simplify() that is not guaranteed to find a normal form, I am not sure that a "normalizer for commutativity" built on top would be complete -even just in regard to commutativity. Is it worth it to then make Analyzer::Simplify() slower while still being incomplete?

Thanks!

@zhangyicole
Copy link
Author

Hello!

Thank you for the discussion and the PR. Just a few thoughts here:

  • 1. Indeed, we knew that the existing Analyzer::Simplify() would not deal with commutativity, because it does not implement a normalization procedure that is guaranteed to converge towards the normal form (which would indeed imply sorting sub-terms, etc). Rather, all it does is it tries to "do its best" by rewriting some known patterns, with no guarantees to converge to a normal form. For this reason, they could not deal with commutativity in Simplify(), because that would indeed lead to non-terminating rewrite sequences (or more realistically return junk, as in practice they stop to rewrite after N rewrites are done). It often works fairly well in practice, but there is no guarantee of being complete (it behaves like a heuristic). However, it really must be correct (i.e, the result of simplify() must rely be equivalent to its input with algebraic laws).
  • 2. Before trying to make Analyzer::Simplify() able to deal with with commutativity, it could be useful to see if people are in practice facing the issue where a lot of redundant computations appear, but written differently due to the commutativity of some operators (like + and *). If so, it would be cool to see such concrete examples. To be honest, I don't even think that that many people are turning ON the already existing bool identify_equiv_terms of the CSE pass Pass CommonSubexprElimTIR, which uses Analyzer::Simplify(), which itself does what it can with associativity, neutral elements, etc. These things are probably pretty rare, and commutativity is probably too.

Although I designed the pass in a way that it can potentially identify terms that are equivalent according to any equivalence relation (instead of just the syntactical equality ExprDeepEqual), it might not necessary be often needed in practice. This design that allows to use a custom equivalence relation did not cost much more, so I went for it, in case it could become useful to someone one day for a particular case, so that it would not be needed to write another CSE pass for dealing with that. But I don't necessary think that we should do comparisons modulo equivalence often, and this is why by default bool identify_equiv_terms of the Pass CommonSubexprElimTIR is set to false.

  • 3. If we decide that Analyzer::Simplify() (or another new Analyzer!) should deal with commutativity, then it should itself deal with commutativity, rather than baking commutativity into ExprDeepEqual which is supposed to be just a deep syntactical equality, and is used as such in many many places of TVM's codebase (not just the CSE). So clearly it should not being changed (as @tqchen noted too).
  • 4. Remember that normalizing terms properly in order to deal with commutativity (which indeed includes sorting sub-terms) will likely be computationally expensive, which will make compilation of ML models longer. Actually, just the smaller work that Analyzer::Simplify() does is already time consuming, and that's probably why people leave the bool identify_equiv_terms of the CSE pass set to false (which is the default behavior, as I wrote earlier). It might make people want even less to turn ON this bool identify_equiv_terms. Perhaps the pseudo-normalization that Analyzer::Simplify() does is not too bad as a compromise: still usable in practice when needed (i.e does not take too long), and deals with most simplifications needed (although not commutativity).
  • 5. If all the other algebraic properties (associativity, simplification of neutral elements, etc) are still done by the pseudo-normalization Analyzer::Simplify() that is not guaranteed to find a normal form, I am not sure that a "normalizer for commutativity" built on top would be complete -even just in regard to commutativity. Is it worth it to then make Analyzer::Simplify() slower while still being incomplete?

Thanks!

Hi,@FranckQC,thanks for the questions and comments。
I quite agree with what you said about the complexity of the Analyzer::Simplify(), the overhead of the Analyzer::Simplify() is very expensive, changes to the ExprDeepEqual will not improve the function and may slow down its performance, so, as suggested by tqchen, make it as A subclass, maybe a good approach.

The source of this submission is that in my previous use, there was a size judgment for the input and output size of the operator containing the reshape operation, and I rewrote deep equal to meet my needs. After seeing your issue, I think this part of the rewrite is helpful, so submit this PR.

@zhangyicole
Copy link
Author

It would be useful to make it as a different equality comparator (rather than change DeepEqual's behavior, e.g. we can do CommunicativeDeepEqual as a subclass), as communicative rewrite is something that goes deeper.

Another possibility is add a canonicalization pass to canonicalize the expressions before CSE

Hi,@tqchen,Thanks a lot for the review!
I agree that it would be useful to make it as a different equality comparator,
this will not have an impact on TVM's infrastructure. So how about making it a new subclass, as you suggested?

@zhangyicole zhangyicole force-pushed the new_deep_equal branch 3 times, most recently from 961107f to 9be7e7c Compare September 26, 2022 06:11
@zhangyicole
Copy link
Author

Hi,@tqchen. I have implemented the function as a subclass,would you like to take a look and merge it if everything looks good?

@tqchen
Copy link
Member

tqchen commented Sep 29, 2022

@FranckQC @masahi can you help to take a look

@FranckQC
Copy link
Contributor

@FranckQC @masahi can you help to take a look

Sure, will do on Monday if that's ok :)

@tqchen
Copy link
Member

tqchen commented Sep 29, 2022

of course

@zhangyicole
Copy link
Author

Hi, @FranckQC @masahi. Is there anything in the code that I need to update?

@zhangyicole
Copy link
Author

Hi, @FranckQC @masahi. If nothing needs to be changed in this PR, can you help merge it?

@tvm-bot
Copy link
Collaborator

tvm-bot commented Oct 18, 2022

Thanks for contributing to TVM! Please refer to the contributing guidelines https://tvm.apache.org/docs/contribute/ for useful information and tips. Please request code reviews from Reviewers by @-ing them in a comment.

Generated by tvm-bot

@@ -331,3 +331,32 @@ def OOBChecker():
The result pass
"""
return _ffi_api.OOBChecker() # type: ignore


def communicative_deep_equal(lhs: PrimExpr, rhs: PrimExpr) -> bool:
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Is "communicative" a word? Do you mean "commutative"?

Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks for the review!It was a typo, I have corrected it.

@masahi
Copy link
Member

masahi commented Oct 18, 2022

Please update the PR title and description to reflect the current status. In particular, please make it more concise and explain what the goal is clearly.

For example, rather than "Add expr hash sort in ExprDeepEqual", explain why you want to do this.

----

This function is an extension of py:func:`tvm.ir.expr_deep_equal`, it can
handle commutativity. The function will not return true for (x + y) vs (y + x).
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

By "The function", which function you are talking about here? If you start by saying "This function" and continue with "The function" in the next sentence, people would think that they refer to the same function.

Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks for the review!The description here is indeed misleading, I have replaced "The function" with the exact function name to avoid misunderstanding.

@zhangyicole zhangyicole changed the title [TIR, analysis] Add expr hash sort in ExprDeepEqual [TIR, analysis] Add CommutativeDeepEqual to handle commutativity in expression comparison Oct 18, 2022
@zhangyicole zhangyicole requested review from masahi and removed request for tqchen and Hzfengsy October 18, 2022 08:45
@areusch areusch added needs-triage PRs or issues that need to be investigated by maintainers to find the right assignees to address it and removed needs-triage PRs or issues that need to be investigated by maintainers to find the right assignees to address it labels Oct 19, 2022
Copy link
Member

@masahi masahi left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Can you try making each VisitExpr_ more readable?


class SortExprByHashMutator : public StmtExprMutator {
public:
void Init() {
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Replace with constructor

sort_lhs = sort.Rewrite(sort_lhs);
}
sort.pre_max_tree_idx = INT32_MAX;
auto sort_rhs = sort.Rewrite(rhs);
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Replace with SortExprByHashMutator::Rewrite(...) and avoid using the same sorter twice.

auto sort_lhs = sort.Rewrite(lhs);
while (sort.pre_max_tree_idx != -1) {
sort_lhs = sort.Rewrite(sort_lhs);
}
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

What does this loop do? Why not do it inside Rewrite?

PrimExpr a; \
PrimExpr b; \
int cur_tree_idx_temp = cur_tree_idx; \
GetRef<PrimExpr>(op); \
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Is this line necessary?

private:
std::string pre_bin_op = "null";
int stack_idx = 0;
int cur_tree_idx = 0;
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Please document these params. Otherwise it's impossible to understand your code.

std::string pre_bin_op = "null";
int stack_idx = 0;
int full_stack_size = 0;
int cur_tree_idx = 0;
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Document them.

@FranckQC
Copy link
Contributor

I was finally able to take some time this week to have a closer look at this PR. It is definitely better than it was before, as it leaves the current DeepEqual unchanged, which was very important for me. That was my initial point 3 in my earlier comment, which I consider being addressed, thank you :).

However, I agree with @masahi comments. The implementation could have a lot more comments and documentation about the variables being used, what the function do, and what parts of the algorithm do. It would help a lot reading the code, which increases the confidence one can have in the implementation.

It's a great thing that there is quite a lot of tests, thanks for taking the time to write many of them. However, I'd also like to be able to see a real usage for new equivalence relation (that was point 2 in my earlier comment). TVM is a compiler, not a tool for just doing algebraic manipulations of mathematical terms like Matlab, so I would really like to see some natural use cases for this, where this get used/integrated into a pass, or into something else that ultimately lead to improvements in the code produced by the compilation of some ML models.

More minor thing: Finally, I'd like to know how one is supposed to use this CommutativeDeepEqual along the Analyzer::Simplify() function that performs other kind of simplification (simplification of neutral elements, applying distributivity, etc, but which unfortunately can't handle commutativity, as discussed earlier in the thread), in order to have a function that uses all the algebraic properties available. I imagine it would call Simplify() on both sides and then uses this new CommutativeDeepEqual. Would that be enough for being complete?
The reason behind that is the following: I believe most of the people who could be interested in equality-modulo-commutativity will be coming here after having discovered that Analyzer::Simplify() can't do all the simplifications for them. So when they will learn that there is this CommutativeDeepEqual equivalence relation for dealing with commutativity, their first question will likely be "how do I combine both?". So I think demonstrating that could be useful.

The most important thing for me at this stage for this PR are to add comments to the code, and to show some real use case/integration for this. Thank you for your work!

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

6 participants