Implementation of Common Subexpression Elimination for TIR #9482
…pache#703)

The goal of this PR is to implement a Common Subexpression Elimination (CSE) pass for TIR, which aims at identifying redundant computations (both within statements and within expressions) and replacing them with a fresh variable, introduced before the first occurrence of the redundant computation.

Note that it does not only try to common full expressions; it is also able to do so on subexpressions. For instance, if the program computes the expression (w+x) + (y+z) and the expression (w+x)+u, it will introduce the subexpression (w+x) into a new variable.

If desired, it will be easy in the future to make the notion of equivalence between terms more flexible, allowing for instance to identify expressions modulo commutativity (identifying for instance (x+y) with (y+x)), modulo associativity (identifying for instance (x+y)+z with x+(y+z)), etc. Replacing the function bool EquivalentTerms(const PrimExpr& a, const PrimExpr& b) will be the only thing needed in order to do that. The typical way to rewrite it for such extensions would be to compute a canonical representative of a and a canonical representative of b and then compare them with strict syntactic equality.

The main CSE pass is declared and implemented respectively in the files common_subexpr_elim.h and common_subexpr_elim.cc. The function Stmt CommonSubexpressionEliminator::VisitStmt(const Stmt& stmt) is a good entry point as it contains many comments about what the pass is doing.

The general idea of this pass is that it tries to introduce at the current level (the current root) the computations that are redundant and which are possible to introduce there (they should only contain variables that are in scope). This notion of variables in scope is implemented with a context, which is a vector of pairs (var, MaybeValue). The context is not only used for checking that variables that appear in candidate computations are known at this point, but also for checking whether a computation has already been introduced into a variable.

For greater flexibility in the future, there is already a strong distinction in place between:
- Syntactic computations, which are maintained in a hashtable which associates expressions (the computations already seen) to size_int (the number of times the computation has been seen).
- Semantic entities, which are obtained from the syntactic computations by merging equivalent computations (where this notion of "equivalent" is customizable). Semantic entities are stored in a vector of pairs (expr, size_int) where, again, the number is the number of times that expr or equivalent computations have been seen.

The VisitStmt() method starts by computing the syntactic computations (implemented in an auxiliary analysis), then it merges equivalent computations to obtain the semantic computations. Then it sorts these semantic computations from biggest to smallest in order to always consider the biggest computations first. The rest is essentially a loop over all these candidates, which stay sorted.

When dealing with a candidate computation, there are three cases that can happen:

1 - Rare case where a variable in the context already contains this computation. This variable can't have been introduced by the CSE, as we would have performed the replacements at the same time (see case 2). So this is the case where the user (or a previous TIR pass) has written something like "let x = A in ...A...A..."
-> In this case, we simply perform the replacements of A with x in the current result. These replacements are done by an auxiliary transform/Mutator, declared and implemented in replace_expr_selected.h and in replace_expr_selected.cc.

2 - Case where we need to introduce the current computation inside a new variable. This is the case where all the variables used by the current computation are within scope (i.e. are present in the context) and where our internal heuristic/predicate tells us to introduce this computation into a new variable.
-> In this case, a new variable new_var_i is generated, all the locations that use this computation in result are replaced by this fresh variable (using the same auxiliary Mutator mentioned in 1.), and the current result is replaced by "let new_var_i = currentComputation in result".

3 - Case where we can't or don't want to introduce this computation inside a new variable. This is the case where we either can't introduce the current computation inside a new variable (because it contains variables that are not yet in scope there) or our internal heuristic/predicate did not want to introduce it.
-> In this case, we compute the direct sub-expressions of the current computation (implemented by an auxiliary analysis), and we add them to the vector of semantic computations so that they have a chance to be considered later. Note that they are added while still preserving the order.

Note that we do not add all the sub-expressions of the current expression but only its direct sub-expressions, given that we always consider them from biggest to smallest and that some candidates are mutually exclusive. Otherwise it would be computationally more intensive and it would pose the problem of cleaning the vector of candidate computations when one of them gets introduced into a variable. Evaluating them lazily by only looking at the direct sub-expressions is at the same time more efficient and simpler.

Once the entire vector of semantic computations has been tried, the main function VisitStmt() calls the general dispatcher, which will in turn call the appropriate handlers. The only specific task of the overridden handlers is to update the context appropriately as new variables are introduced into scope (via Let-In, via For loops, etc.) or leave the current scope. Thus, they update the context appropriately before and after the calls to VisitStmt() and VisitExpr() on the child nodes.
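The candidate loop described above (largest candidates first, with direct sub-expressions added lazily when a candidate is not introduced) can be sketched in a few lines of Python. This is only a toy model under stated assumptions: expressions are nested tuples rather than TIR PrimExpr nodes, the names `cse`, `direct_subexprs` and `cse_var_N` are hypothetical, equivalence is plain syntactic equality, and the scope check (the context), case 1 and the recursion into child nodes are all omitted.

```python
# Toy expressions: a variable is a str, an operation is a tuple (op, lhs, rhs).
# This is a hypothetical stand-in for TIR PrimExpr, not the TVM data structure.

def size(e):
    """Number of nodes in the expression, used to order candidates."""
    return 1 if isinstance(e, str) else 1 + size(e[1]) + size(e[2])

def direct_subexprs(e):
    """Direct children that are computations themselves (not bare variables)."""
    return [] if isinstance(e, str) else [c for c in e[1:] if not isinstance(c, str)]

def count_occurrences(e, target):
    """How many times `target` appears syntactically inside `e`."""
    if e == target:
        return 1
    if isinstance(e, str):
        return 0
    return count_occurrences(e[1], target) + count_occurrences(e[2], target)

def replace(e, target, var):
    """Replace every occurrence of `target` inside `e` by the variable `var`."""
    if e == target:
        return var
    if isinstance(e, str):
        return e
    return (e[0], replace(e[1], target, var), replace(e[2], target, var))

def cse(exprs):
    """Return (bindings, rewritten expressions); the outermost binding comes first."""
    result = list(exprs)
    bindings = []
    candidates = []
    for e in result:
        if not isinstance(e, str) and e not in candidates:
            candidates.append(e)
    candidates.sort(key=size, reverse=True)  # biggest computations first
    while candidates:
        cand = candidates.pop(0)
        uses = sum(count_occurrences(e, cand) for e in result)
        uses += sum(count_occurrences(val, cand) for _, val in bindings)
        if uses >= 2:
            # Case 2: introduce a fresh variable and replace all occurrences.
            var = "cse_var_%d" % (len(bindings) + 1)
            result = [replace(e, cand, var) for e in result]
            bindings = [(v, replace(val, cand, var)) for v, val in bindings]
            bindings.insert(0, (var, cand))
        else:
            # Case 3: give the direct sub-expressions a chance, keeping the order.
            for sub in direct_subexprs(cand):
                if sub not in candidates:
                    candidates.append(sub)
            candidates.sort(key=size, reverse=True)
    return bindings, result

e1 = ("+", ("+", "w", "x"), ("+", "y", "z"))
e2 = ("+", ("+", "w", "x"), "u")
print(cse([e1, e2]))
```

On the two expressions from the description, (w+x) + (y+z) and (w+x)+u, neither full expression is repeated, so their direct sub-expressions are queued, and (w+x) ends up bound to a single fresh variable.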
Hi @FranckQC, I have wanted TIR-level CSE for a long time, so I am very excited to see this! What I want to do is eliminate common expressions that span the host and the GPU - for example, in GPU
Is this PR going to solve my problem?
A quick round of review.
The changes to vta-hw are not necessary.
Hi @masahi,
Unfortunately, in the TIR code snippet that you have uploaded, it seems that the redundant subexpression is just a function call. These can't be commoned out into a new variable by the CSE pass, as it has no guarantee that the function is free of side effects and will always produce the same outputs for the same inputs. Without this guarantee, commoning out such function calls could change the program's semantics, so it is not done: preserving the semantics of the program is vital.
I can imagine that for functions that are guaranteed to have no side effects (and which are therefore "functions" in the mathematical sense of the term), we could relax this restriction, but that would be an extension to implement in the future, and it would rely on some "NoSideEffect" tag on functions.
However, please note that if you had some other redundancies, this CSE pass would common out whatever eligible redundancies you have.
Does that answer your question?
Kind regards.
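The concern about side effects can be illustrated with a small, self-contained Python sketch. The function `read_sensor` is a hypothetical impure function invented for this illustration; it is not part of TVM. It shows that replacing two calls with one shared variable changes both the result and the number of calls.

```python
calls = []

def read_sensor(x):
    """Hypothetical impure function: records each call and returns a changing value."""
    calls.append(x)
    return x + len(calls)

def original(x):
    # Two separate calls, as written in the source program.
    return read_sensor(x) + read_sensor(x)

def commoned(x):
    # What a naive CSE would produce if it treated the call as a pure expression.
    tmp = read_sensor(x)
    return tmp + tmp

def run(f, x):
    """Run f on x from a clean state; return (result, number of calls made)."""
    calls.clear()
    return f(x), len(calls)

print(run(original, 10))  # (23, 2)
print(run(commoned, 10))  # (22, 1)
```

The two versions disagree, which is why a pass without purity information has to leave such calls alone.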
Thanks a lot for the review! I'll roll back the pointer to the sub-module in another commit. Many thanks again!
The pointer to the submodule vta-hw has been rolled back to its previous state.
Hi @FranckQC, I am also very interested in the CSE pass. In our circumstances we suffer from duplicate index computations produced by loop unrolling, like
We have to depend on the target backend's optimization abilities, which may or may not optimize them out. It would be great if CSE could handle some of these things at the TIR level.
I have another three questions about the pass.
BTW, the current split host/device machinery seems to work quite well with common expression bindings!

import tvm
from tvm.script import tir as T

@T.prim_func
def func(a: T.handle, b: T.handle, n: T.int32) -> None:
    threadIdx_x = T.env_thread("threadIdx.x")
    A = T.match_buffer(a, [256], dtype="int32")
    B = T.match_buffer(b, [256], dtype="int32")
    common_expr = T.var("int32")
    # for common_expr in range(n // 8, n // 8 + 1):
    # Bind the shared expression n // 8 once, outside the loop.
    with T.let(common_expr, n // 8):
        for i in T.serial(0, common_expr):
            T.launch_thread(threadIdx_x, 8)
            T.store(B.data, i * 8 + threadIdx_x, common_expr + T.load("int32", A.data, i * 8 + threadIdx_x), True)

mod = tvm.IRModule.from_expr(func)
mod = tvm.tir.transform.Apply(lambda f: f.with_attr({"global_symbol": "main", "target": tvm.target.Target("cuda")}))(mod)
# Split the function into host and device parts and print the result.
mod = tvm.tir.transform.SplitHostDevice()(mod)
print(mod.script())
# script for result mod
@tvm.script.ir_module
class Module:
    @T.prim_func
    def main(a: T.handle, b: T.handle, n: T.int32) -> None:
        # function attr dict
        T.func_attr({"global_symbol": "main", "target": None})
        A = T.match_buffer(a, [256], dtype="int32")
        B = T.match_buffer(b, [256], dtype="int32")
        # body
        for common_expr in T.serial(n // 8, n // 8 + 1):
            for i in T.serial(0, common_expr):
                T.evaluate(T.tvm_call_packed("main_kernel0", B.data, A.data, common_expr, i, 8, dtype="int32"))

    @T.prim_func
    def main_kernel0(B_1: T.Ptr[global T.int32], A_1: T.Ptr[global T.int32], common_expr: T.int32, i: T.int32) -> None:
        # function attr dict
        T.func_attr({"target": cuda -keys=cuda,gpu -max_num_threads=1024 -thread_warp_size=32, "tir.noalias": 1, "global_symbol": "main_kernel0", "tir.device_thread_axis": [T.iter_var(threadIdx_x, [0:8], "ThreadIndex", "threadIdx.x")], "tir.is_global_func": 1, "calling_conv": 2})
        # var definition
        threadIdx_x = T.env_thread("threadIdx.x")
        # body
        T.launch_thread(threadIdx_x, 8)
        T.store(B_1, i * 8 + threadIdx_x, common_expr + T.load("int32", A_1, i * 8 + threadIdx_x), True)
@FranckQC Thanks, yes absolutely. I can work on extending this pass to support my use case after this is merged.
Hi @wrongtest,
Yes, these kinds of redundancies should definitely be commoned out by this new CSE pass.
Do not hesitate to try the pass out and to let me know whether it does what we hope. I'd be happy to help, of course, if that's needed. Best regards.
…the CSE pass for the VTA tests, as vta.build() overwrites the config
The only remaining failures are the VTA ones (both for python3: i386 and for unittest: CPU), which should now work with the trick kindly given by @masahi.
…{} for the disabled passes instead of a list with [] for some reason
Please have a look again, @Hzfengsy @wrongtest. I'm merging this this week unless there are other comments, @tqchen @junrushao1994 @vinx13.
Thanks so much for the clear and complete comments, very much appreciate that.
Echoing @Hzfengsy's suggestion to switch to TVMScript for your tests; perhaps just pretty-printing what you already construct would be a shortcut.
The code duplication across PrimExpr and Stmt in CommonSubexpressionEliminator is unfortunate, but I suspect removing it would only obscure things.
In your experience, is this mostly firing on the affine index sub-expressions, or do you see CSE over actual data sub-expressions?
Overall this looks good to me and we should merge as is. @masahi, do you mind adding some follow-up issues to track the TODOs generated during review?
Sure! In addition to #9482 (comment), I think we can refactor our |
I would be very happy to look into that as soon as this is merged. Hopefully the current run of the CI should be the last one!
Thank you so much for the compliment, I really appreciate it. It makes me happy to know that the code is easy to read!
If I recall correctly, I saw quite a lot of indices (mostly from loop unrolling), just like what @wrongtest had here #9482 (comment).
Also some indices due to the lowering of memory accesses, for instance:
And I also recall a lot of random commoning, like:
I'll post more if I can find more notes about more interesting commonings performed in test files and models.
Sorry, I forgot to answer the other parts of your message, @mbs-octoml. Many thanks for it, by the way!
Yes, I didn't know about TVMScript before. When writing the tests, I initially stared at some other test files and got inspiration from them. Unfortunately, the ones I've been looking at might not have been the most up-to-date way of doing things, sorry for that! :-(
Yes, I agree that this duplication is a little bit unfortunate. @masahi did point it out here. I was also a little bit annoyed with it at the beginning, so I tried to factor it out a few times, including an attempt described in my answer here. But all my attempts ended up with something much too complicated for what we would gain. In fact, we just happen to want almost exactly the same treatment for an expression and for a statement from an algorithmic point of view, but from a data-type point of view, quite a few things are still different. That's a pretty rare situation. In the end, I decided not to force things and to leave it like that. Many thanks again!
Thanks @FranckQC @Hzfengsy @wrongtest @mbs-octoml @jroesch, this is merged!!
Follow-up items in #10211
* Initial implementation of Common Subexpression Elimination for TIR (apache#703)
* Added empty newline at the end of every new file
* Rolled-back the pointer to the submodule vta-hw
* Improved the CSE by not commoning at the toplevel redundant computations that only appear in one of the possible execution paths (for instance, only in the then/else branch of an IF statement). Redundant computations that appear only in a specific execution path are now being commoned at the entrance of their specific execution path instead of earlier at the toplevel. Introducing them at the toplevel was an anti-optimization as the redundant computation might not have been computed at all. Added two additional tests for this too.
* Spelling and comment
* Improved the CSE by not commoning at the toplevel redundant computations that only appear in one of the possible execution paths (for instance, only in the then/else branch of an IF statement). Redundant computations that appear only in a specific execution path are now being commoned at the entrance of their specific execution path instead of earlier at the toplevel. Introducing them at the toplevel was an anti-optimization as the redundant computation might not have been computed at all. Added two additional tests for this too.
* Revert "Improved the CSE by not commoning at the toplevel redundant computations that only appear in one of the possible execution paths (for instance, only in the then/else branch of an IF statement). Redundant computations that appear only in a specific execution path are now being commoned at the entrance of their specific execution path instead of earlier at the toplevel. Introducing them at the toplevel was an anti-optimization as the redundant computation might not have been computed at all. Added two additional tests for this too." This reverts commit c4138d9.
* Fixed reference used for no reason instead of normal variable.
* Added comment explaining why we do not need the union/intersection over N tables at the moment (because we would only use it for N=3)
* Did most of the changes suggested by upstream
* Continued to work on the remarks given on the public repo.
* Final remarks addressed, small formatting things, and fixing things reported by the linter
* Last linter fix.
* Fixing newline
* Adding newline missing.
* Minor commit for style to conform with clang-format
* Removed trailing space at end of line
* And more minor style changes
* Fixing style of the python test files
* And one more for style in python tests!
* This linter is very annoying to force the style of indentation in a comment, in a test file. It makes it harder to read in this case! And that encourages people not to write comments
* Deactivate the CSE pass for the lowering tests as it would otherwise do some commoning, and improve the way the CSE recurses + test added for cascade commonings
* Fixing new lint offenses
* Removing debug statement
* Restore other test file to its previous state
* One more for the linter...
* Linter again, this time for the new test...
* again
* again...
* Deactivating the CSE pass for another lowering test as it does some commoning
* Disabling the CSE for a test for GPU too
* Trying to fix a VTA test by disabling the CSE pass for it, as it probably does some commoning
* Complying with the linter
* Restarting the CI 1/2
* Restarting the CI 2/2
* Restarting CI 1/2
* Restarting CI 2/2
* Slightly reduce size of large pretty printer test, copied from apache@ae98f9e
* Trying to resolve the problems on the weird tests
* Linter.
* Restarting CI which has skipped the MacOS build for no reason 1/2
* Restarting CI which has skipped the MacOS build for no reason 2/2
* Commented buggy tests
* Linter...
* Restore the VTA tests, and use trick kindly given by Masa to disable the CSE pass for the VTA tests, as vta.build() overwrites the config
* New fix, which this time does not break the doc (VTA uses a set with {} for the disabled passes instead of a list with [] for some reason)
* More VTA fixes
* vta tutorial fix

Co-authored-by: Masahiro Masuda
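The commit about execution paths states that introducing a branch-only computation at the top level is an anti-optimization. A minimal Python sketch of that point follows. The function `expensive` is a hypothetical stand-in for a pure arithmetic sub-expression, and the global counter exists only to observe how often it is evaluated; none of these names come from the PR.

```python
evaluations = 0

def expensive(a, b):
    """Stand-in for a pure sub-expression; the counter only observes evaluations."""
    global evaluations
    evaluations += 1
    return a * b + 1

def uncommoned(cond, a, b):
    # Original program: the redundancy lives entirely inside the 'then' branch.
    if cond:
        return expensive(a, b) + expensive(a, b)
    return 0

def hoisted_to_toplevel(cond, a, b):
    # Commoned too early: evaluated even when the branch is not taken.
    t = expensive(a, b)
    if cond:
        return t + t
    return 0

def commoned_in_branch(cond, a, b):
    # Commoned at the entrance of the execution path that actually uses it.
    if cond:
        t = expensive(a, b)
        return t + t
    return 0

def count_evaluations(f, cond):
    """Number of times the sub-expression is evaluated when running f."""
    global evaluations
    evaluations = 0
    f(cond, 3, 4)
    return evaluations

for f in (uncommoned, hoisted_to_toplevel, commoned_in_branch):
    print(f.__name__, count_evaluations(f, True), count_evaluations(f, False))
```

When the condition is false, the top-level version performs one evaluation that the original program never did, while the branch-local version performs none.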
Hi everyone,
We would like to upstream some work that we did at Qualcomm.
The goal of this PR is to implement a Common Subexpression Elimination (CSE) pass for TIR, which aims at identifying redundant computations (both within statements and within expressions), and to replace them by a new fresh variable, introduced before the first occurrence of the redundant computation.
Note that it does not only try to common full expressions; it is also able to do so on subexpressions. For instance, if the program computes the expression (w+x) + (y+z) and the expression (w+x)+u, it will introduce the subexpression (w+x) into a new variable.

If desired, it will be easy in the future to make the notion of equivalence between terms more flexible, allowing for instance to identify expressions modulo commutativity (identifying for instance (x+y) with (y+x)), modulo associativity (identifying for instance (x+y)+z with x+(y+z)), etc. Replacing the function bool EquivalentTerms(const PrimExpr& a, const PrimExpr& b) will be the only thing needed in order to do that. The typical way to rewrite this function for such extensions would be to compute a canonical representative of a and a canonical representative of b and then compare them with strict syntactic equality.

The main CSE pass is declared and implemented respectively in the files common_subexpr_elim.h and common_subexpr_elim.cc. The function Stmt CommonSubexpressionEliminator::VisitStmt(const Stmt& stmt) is a good entry point as it contains many comments about what the pass is doing.

The general idea of this pass is that it tries to introduce at the current level (the current root) the computations that are redundant and which are possible to introduce there (they should only contain variables that are in scope). This notion of variables in scope is implemented with a context, which is a vector of pairs (var, MaybeValue). The context is not only used for checking that variables that appear in candidate computations are known at this point, but also for checking whether a computation has already been introduced into a variable.

For greater flexibility in the future, there is already a strong distinction in place between:
- Syntactic computations, which are maintained in a hashtable which associates expressions (the computations already seen) to size_int (the number of times the computation has been seen).
- Semantic entities, which are obtained from the syntactic computations by merging equivalent computations (where this notion of "equivalent" is customizable). Semantic entities are stored in a vector of pairs (expr, size_int) where, again, the number is the number of times that expr or equivalent computations have been seen.

The VisitStmt() method starts by computing the syntactic computations (implemented in an auxiliary analysis), then it merges equivalent computations to obtain the semantic computations. Then it sorts these semantic computations from biggest to smallest in order to always consider the biggest computations first. The rest is essentially a loop over all these candidates, which stay sorted.

When dealing with a candidate computation, there are three cases that can happen:

1 - Rare case where a variable in the context already contains this computation. This variable can't have been introduced by the CSE, as we would have performed the replacements at the same time (see case 2). So this is the case where the user (or a previous TIR pass) has written something like "let x = A in ...A...A..."
-> In this case, we simply perform the replacements of A with x in the current result. These replacements are done by an auxiliary transform/Mutator, declared and implemented in replace_expr_selected.h and in replace_expr_selected.cc.

2 - Case where we need to introduce the current computation inside a new variable. This is the case where all the variables used by the current computation are within scope (i.e. are present in the context) and where our internal heuristic/predicate tells us to introduce this computation into a new variable.
-> In this case, a new variable new_var_i is generated, all the locations that use this computation in result are replaced by this fresh variable (using the same auxiliary Mutator mentioned in 1.), and the current result is replaced by "let new_var_i = currentComputation in result".

3 - Case where we can't or don't want to introduce this computation inside a new variable. This is the case where we either can't introduce the current computation inside a new variable (because it contains variables that are not yet in scope there) or our internal heuristic/predicate did not want to introduce it.
-> In this case, we compute the direct sub-expressions of the current computation (implemented by an auxiliary analysis), and we add them to the vector of semantic computations so that they have a chance to be considered later. Note that they are added while still preserving the order.

Note that we do not add all the sub-expressions of the current expression but only its direct sub-expressions, given that we always consider them from biggest to smallest and that some candidates are mutually exclusive. Otherwise it would be computationally more intensive and it would pose the problem of cleaning the vector of candidate computations when one of them gets introduced into a variable. Evaluating them lazily by only looking at the direct sub-expressions is at the same time more efficient and simpler.

Once the entire vector of semantic computations has been tried, the main function VisitStmt() calls the general dispatcher, which will in turn call the appropriate handlers. The only specific task of the overridden handlers is to update the context appropriately as new variables are introduced into scope (via Let-In, via For loops, etc.) or leave the current scope. Thus, they update the context appropriately before and after the calls to VisitStmt() and VisitExpr() on the child nodes.

For more details, this new pass was presented at TVM Con 2021 ( https://www.tvmcon.org/events/common-subexpression-elimination-for-tir/ ) and, thanks to the organizers of the conference, the video is now available on YouTube here: https://www.youtube.com/watch?v=Iuio-oJSOv0 .
Please do not hesitate to ask if you have any questions.
Thank you.
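The description suggests that a more flexible EquivalentTerms could compare canonical representatives instead of the raw terms. Below is a minimal Python sketch of that idea for commutativity only. It assumes toy expressions encoded as nested tuples (not TIR PrimExpr), and the names `canonical`, `equivalent_terms` and `COMMUTATIVE` are hypothetical; associativity is deliberately not handled.

```python
# Toy expressions: a variable is a str, an operation is a tuple (op, lhs, rhs).
COMMUTATIVE = {"+", "*"}

def canonical(e):
    """Canonical representative: operands of commutative operators are put
    in a fixed (arbitrary but deterministic) order, recursively."""
    if isinstance(e, str):
        return e
    op, lhs, rhs = e
    lhs, rhs = canonical(lhs), canonical(rhs)
    if op in COMMUTATIVE and repr(rhs) < repr(lhs):
        lhs, rhs = rhs, lhs
    return (op, lhs, rhs)

def equivalent_terms(a, b):
    """Equivalence modulo commutativity: strict equality of canonical forms."""
    return canonical(a) == canonical(b)

print(equivalent_terms(("+", "x", "y"), ("+", "y", "x")))  # True
print(equivalent_terms(("-", "x", "y"), ("-", "y", "x")))  # False
```

Comparing canonical forms keeps the rest of the pass unchanged: only the equivalence predicate differs, which matches the extension point named in the description.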