-
Notifications
You must be signed in to change notification settings - Fork 3.5k
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
[TIR, analysis] Add CommutativeDeepEqual to handle commutativity in expression comparison #12761
base: main
Are you sure you want to change the base?
Conversation
bfd9cdb
to
4442813
Compare
b0dd74f
to
73dea99
Compare
I can take a look at this tomorrow. |
It would be useful to make it as a different equality comparator (rather than change DeepEqual's behavior, e.g. we can do CommunicativeDeepEqual as a subclass), as communicative rewrite is something that goes deeper. Another possibility is add a canonicalization pass to canonicalize the expressions before CSE |
Hello! Thank you for the discussion and the PR.
Although I designed the pass in a way that it can potentially identify terms that are equivalent according to any equivalence relation (instead of just the syntactical equality
Thanks! |
Hi,@FranckQC,thanks for the questions and comments。 The source of this submission is that in my previous use, there was a size judgment for the input and output size of the operator containing the reshape operation, and I rewrote deep equal to meet my needs. After seeing your issue, I think this part of the rewrite is helpful, so submit this PR. |
Hi,@tqchen,Thanks a lot for the review! |
961107f
to
9be7e7c
Compare
Hi,@tqchen. I have implemented the function as a subclass,would you like to take a look and merge it if everything looks good? |
9be7e7c
to
c21d856
Compare
of course |
c21d856
to
f600819
Compare
f600819
to
4369818
Compare
Thanks for contributing to TVM! Please refer to the contributing guidelines https://tvm.apache.org/docs/contribute/ for useful information and tips. Please request code reviews from Reviewers by @-ing them in a comment. Generated by tvm-bot |
python/tvm/tir/analysis/analysis.py
Outdated
@@ -331,3 +331,32 @@ def OOBChecker(): | |||
The result pass | |||
""" | |||
return _ffi_api.OOBChecker() # type: ignore | |||
|
|||
|
|||
def communicative_deep_equal(lhs: PrimExpr, rhs: PrimExpr) -> bool: |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Is "communicative" a word? Do you mean "commutative"?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Thanks for the review!It was a typo, I have corrected it.
Please update the PR title and description to reflect the current status. In particular, please make it more concise and explain what the goal is clearly. For example, rather than "Add expr hash sort in ExprDeepEqual", explain why you want to do this. |
python/tvm/tir/analysis/analysis.py
Outdated
---- | ||
|
||
This function is an extension of py:func:`tvm.ir.expr_deep_equal`, it can | ||
handle commutativity. The function will not return true for (x + y) vs (y + x). |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
By "The function", which function you are talking about here? If you start by saying "This function" and continue with "The function" in the next sentence, people would think that they refer to the same function.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Thanks for the review!The description here is indeed misleading, I have replaced "The function" with the exact function name to avoid misunderstanding.
…xpression comparison
4369818
to
3d9e540
Compare
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Can you try making each VisitExpr_
more readable?
|
||
class SortExprByHashMutator : public StmtExprMutator { | ||
public: | ||
void Init() { |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Replace with constructor
sort_lhs = sort.Rewrite(sort_lhs); | ||
} | ||
sort.pre_max_tree_idx = INT32_MAX; | ||
auto sort_rhs = sort.Rewrite(rhs); |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Replace with SortExprByHashMutator::Rewrite(...)
and avoid using the same sorter twice.
auto sort_lhs = sort.Rewrite(lhs); | ||
while (sort.pre_max_tree_idx != -1) { | ||
sort_lhs = sort.Rewrite(sort_lhs); | ||
} |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
What does this loop do? Why not do it inside Rewrite
?
PrimExpr a; \ | ||
PrimExpr b; \ | ||
int cur_tree_idx_temp = cur_tree_idx; \ | ||
GetRef<PrimExpr>(op); \ |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Is this line necessary?
private: | ||
std::string pre_bin_op = "null"; | ||
int stack_idx = 0; | ||
int cur_tree_idx = 0; |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Please document these params. Otherwise it's impossible to understand your code.
std::string pre_bin_op = "null"; | ||
int stack_idx = 0; | ||
int full_stack_size = 0; | ||
int cur_tree_idx = 0; |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Document them.
I was finally able to take some time this week to have a closer look at this PR. It is definitely better than it was before, as it leaves the current DeepEqual unchanged, which was very important for me. That was my initial point 3 in my earlier comment, which I consider being addressed, thank you :). However, I agree with @masahi comments. The implementation could have a lot more comments and documentation about the variables being used, what the function do, and what parts of the algorithm do. It would help a lot reading the code, which increases the confidence one can have in the implementation. It's a great thing that there is quite a lot of tests, thanks for taking the time to write many of them. However, I'd also like to be able to see a real usage for new equivalence relation (that was point 2 in my earlier comment). TVM is a compiler, not a tool for just doing algebraic manipulations of mathematical terms like Matlab, so I would really like to see some natural use cases for this, where this get used/integrated into a pass, or into something else that ultimately lead to improvements in the code produced by the compilation of some ML models. More minor thing: Finally, I'd like to know how one is supposed to use this The most important thing for me at this stage for this PR are to add comments to the code, and to show some real use case/integration for this. Thank you for your work! |
Refer to issue #10211. The CSE pass can't handle commutativity because the arith system may not be able to do the commutativity。
The determination of the equality of two expressions (PrimExpr) is to use the method of structured determination, that is, to traverse the hierarchical structure of the two expressions, and to judge while traversing, if the structures of the two expressions are the same, and the smallest If the child nodes (nodes that cannot be traversed, generally such as Var) are the same, the expressions are considered to be the same.
Before performing PrimExpr comparison, a series of rewrite rules will be used to rewrite expressions to solve some operational problems, such as x * y + x * z will be rewritten as x * (y + z), so that to deal with distributivity.
However, commutativity cannot be rewritten due to the characteristics of rewrite (it will fall into an infinite loop). This makes it impossible to compare the equality of some expressions, such as a * b *c != a * c * b, (a * b) * c != a * (b * c).
To solve this problem, one solution is to sort and rewrite the expressions according to the
StructuralHash
of the Var nodes in the expressions before comparing the expressions. If two expressions are equivalent if they satisfy the commutativity, then they will definitely produce equivalent expressions of the same structure after sorting.Under the assumption that the two expressions have the same structure, the determination condition that the two expressions satisfying the commutativity are the same can be further refined to the same set of all elements in the sub-expressions satisfying the commutativity. The sub-expressions satisfying the commutativity in the expression can be grasped by constructing the expression syntax tree. The sub-expressions satisfying the commutativity are the sub-trees of the expression tree whose child nodes are identical (the child nodes are OP).
An example:
Sort and rewrite cse_var_1 and cse_var_2, first extract their first subexpressions a * b *c and b * a * c, and rewrite them as a * b * c, and then extract the sub-expressions again ( a * b *c) + d + e and (a *b * c) + e + d, rewriting the sort as (a * b *c) + d + e, get the same expression.
cc @masahi @Hzfengsy @tqchen @FranckQC