Skip to content

Commit

Permalink
Implementation of Common Subexpression Elimination for TIR (#9482)
Browse files Browse the repository at this point in the history
* Initial implementation of Common Subexpression Elimination for TIR (#703)

The goal of this PR is to implement a Common Subexpression Elimination (CSE) pass for TIR, which aims at identifying redundant computations (both within statements and within expressions), and to replace them by a new fresh variable, introduced before the first occurrence of the redundant computation.

Note that it does not only try to do commoning on full expressions, but it is also able to do it on subexpressions. For instance, if the program computes the expression (w+x) + (y+z) and the expression (w+x)+u, it will introduce the subexpression (w+x) into a new variable.

If we want so, it will be easily possible in the future to make the notion of equivalence between terms more flexible, allowing for instance to identify expressions modulo commutativity (identifying for instance (x+y) with (y+x)), modulo associativity (identifying for instance (x+y)+z with x+(y+z)), etc. Replacing only the function bool EquivalentTerms(const PrimExpr& a, const PrimExpr& b) will be the only thing needed in order to do that. The typical way to rewrite it for such extensions would be to compute a canonical representant of a and a canonical representant of b and to then compare them with the strict syntactical equality.

The main CSE pass is declared and implemented respectively in the files common_subexpr_elim.h and common_subexpr_elim.cc.
The function Stmt CommonSubexpressionEliminator::VisitStmt(const Stmt& stmt) is a good entry point as it contains many comments about what the pass is doing.

The general idea of this pass is that it tries to introduce at the current level (the current root) the computations that are redundant and which are possible to introduce there (they should only contain variables that are in scope). This notion of variables in scope is implemented with a context, which is a vector of pairs (var, MaybeValue). The context is not only used for checking that variables that appear in candidate computations are known at this point, but also for checking if a computation has already been introduced into a variable.

For a greater flexibility in the future, there is a strong distinction already in place between :

    - Syntactic computations, which are maintained in a hashtable which associates expressions (the computations already seen) to size_int (the number of times the computation has been seen).
    - Semantic entities, which are obtained from the syntactic computations by merging equivalent computations (where this notion of "equivalent" is customizable). Semantic entities are stored into a vector of pairs (expr, size_int) where, again, the number is the number of times that expr or equivalent computations have been seen.

The VisitStmt() method starts by computing the syntactic computations (implemented in an auxiliary analysis), then it merges equivalent computations to obtain the semantic computations. Then it sorts these semantic computations from biggest to smallest in order to always consider first the biggest computations. The rest will essentially be a loop over all these candidates, which will stay sorted.

When dealing with a candidate computation, there are three cases that can happen:

    1 - Rare case A variable in the context already contains this computation. This variable can't have been introduced by the CSE, as we would have performed the replacements at the same time (see case 2). So this is the case where the user himself (or the previous TIR passes) has written something like "let x = A in ...A...A...)"
    -> In this case, we simply perform the replacements of A with x in the current result. These replacements are done by an auxiliary transform/Mutator, declared and implemented in replace_expr_selected.h and in replace_expr_selected.cc.

    2 - Case where we need to introduce the current computation inside a new variable This is the case where all the variables used by the current computation are within scope (i.e. are present in the context) and where our internal heuristic/predicate tells us to introduce this computation into a new variable.
    -> In this case, a new variable new_var_i is generated, all the locations that use this computation in result are replaced by this fresh variable (using the same auxiliary Mutator mentioned in 1.), and the current result is replaced by let new_var_i = currentComputation in result.

    3 - Case where we can't or don't want to introduce this computation inside a new variable This is the case where we either can't introduce the current computation inside a new variable (because it contains variables that are not yet in scope there) or because our internal heuristic/predicate did not want to introduce it.
    -> In this case, we will compute the direct sub-expressions of the current computation (implemented by an auxiliary analysis), and we will add them to the vector of semantic computations so that they have a chance to be considered later. Note that they are added while still preserving the order.
    Note that we do not add all the sub-expressions of the current expression but only its direct subexpressions given the fact that we always consider them from biggest to smallest, and given that some candidates are mutually exclusive. Otherwise it would be computationally more intensive and it would pose the problem of cleaning the vector of candidate computations when one of them gets introduced into a variable. Evaluating them lazily by only looking at the direct sub-expressions is at the same time more efficient and simpler.

Once the entire vector of semantic computations has been tried, the main function VisitStmt() calls the general dispatcher , which will in turn call the appropriate handlers. The only specific task of overridden handlers will be to update the context appropriately as new variables are introduced into scope (via Let-In, via For loop, etc) or leave the current scope. Thus, they will update the context appropriately before and after the calls to VisitStmt() and VisitExpr() on the child nodes.

* Added empty newline at the end of every new file

* Rolled-back the pointer to the submodule vta-hw

* Improved the CSE by not commoning at the toplevel redundant computations that only appear in one of the possible execution path (for instance, only in the then/else branch of an IF statement). Redundant computations that appear only in a specific execution path are now being commoned at the entrance of their specific execution path instead of earlier at the toplevel. Introducing them at the toplevel was an anti-optimization as the redundant computation might not have been comptued at all. Added two additional tests for this too.

* Spelling and comment

* Improved the CSE by not commoning at the toplevel redundant computations that only appear in one of the possible execution path (for instance, only in the then/else branch of an IF statement). Redundant computations that appear only in a specific execution path are now being commoned at the entrance of their specific execution path instead of earlier at the toplevel. Introducing them at the toplevel was an anti-optimization as the redundant computation might not have been comptued at all. Added two additional tests for this too.

* Revert "Improved the CSE by not commoning at the toplevel redundant computations that only appear in one of the possible execution path (for instance, only in the then/else branch of an IF statement). Redundant computations that appear only in a specific execution path are now being commoned at the entrance of their specific execution path instead of earlier at the toplevel. Introducing them at the toplevel was an anti-optimization as the redundant computation might not have been comptued at all. Added two additional tests for this too."

This reverts commit c4138d9.

* Fixed reference used for no reason instead of normal variable.

* Added comment explaning why we do not need the union/intersection over N tables at the moment (because we would only use it for N=3)

* Did most of the changes suggested by upstream

* Continued to work on the remarks given on the public repo.

* Final remarks addressed, small formatting things, and fixing things reported by the linter

* Last linter fix.

* Fixing newline

* Adding newline missing.

* Minor commit for style fo conform with clang-format

* Removed trailing space at end of line

* And more minor style changes

* Fixing style of the python test files

* And one more for style in python tests!

* This linter is very annoying to force the style of indentation in a comment, in a test file. It makes it harder to read in this case! And that incitates people to not write comments

* Deactivate the CSE pass for the lowering tests as it would otherwise do some commoning, and improve the way the CSE recurse + test added for cascade commonings

* Fixing new lint offenses

* Removing debug statement

* Restore other test file to its previous state

* One more for the linter...

* Linter again, this time for the new test...

* again

* again...

* Deactivating the CSE pass for another lowering test as it does some commoning

* Disabling the CSE for the a test for GPU too

* Trying to fix a VTA test by disabling the CSE pass for it, as it probably does some commoning

* Complying with the linter

* Restarting the CI 1/2

* Restarting the CI 2/2

* Restarting CI 1/2

* Restarting CI 2/2

* Slightly reduce size of large pretty printer test, copied from ae98f9e

* Trying to resolve the problems on the weird tests

* Linter.

* Restarting CI which has skipped the MacOS build for no reason 1/2

* Restarting CI which has skipped the MacOS build for no reason 2/2

* Commented buggy tests

* Linter...

* Restore the VTA tests, and use trick kindly given  by Masa to disable the CSE pass for the VTA tests, as vta.build() overwrittes the config

* New fix, which this time does not break the doc (VTA uses a set with {} for the disabled passes instead of a list with [] for some reason

* More VTA fixes

* vta tutorial fix

Co-authored-by: Masahiro Masuda <[email protected]>
  • Loading branch information
FranckQC and masahi authored Feb 10, 2022
1 parent 222152b commit 09f7be2
Show file tree
Hide file tree
Showing 25 changed files with 2,544 additions and 65 deletions.
8 changes: 8 additions & 0 deletions include/tvm/tir/transform.h
Original file line number Diff line number Diff line change
Expand Up @@ -458,6 +458,14 @@ TVM_DLL Pass FlattenBuffer();
*/
TVM_DLL Pass TextureFlatten();

/*!
* \brief Implements a Common Subexpression Elimination (CSE) for TIR
* which introduces let-in bindings for duplicated sub-expressions.
* \param enable_cse_tir Whether common subexpression elimination is enabled.
* \return The pass.
*/
TVM_DLL Pass CommonSubexprElimTIR(bool enable_cse_tir = true);

/*!
* \brief Unify all the thread bindings for "blockIdx.x/y/z", "threadIdx.x/y/z", and
* "vthread.x/y/z". Before the unification, two vars that are bound to a thread axis (e.g.,
Expand Down
11 changes: 11 additions & 0 deletions python/tvm/tir/transform/transform.py
Original file line number Diff line number Diff line change
Expand Up @@ -311,6 +311,17 @@ def BF16TypeLowering():
return _ffi_api.BF16TypeLowering() # type: ignore


def CommonSubexprElimTIR(enable_cse_tir: bool = True):
"""Replace redundant computations by new variables.
Returns
-------
fpass : tvm.transform.Pass
The result pass
"""
return _ffi_api.CommonSubexprElimTIR(enable_cse_tir) # type: ignore


def RewriteUnsafeSelect():
"""Detect and rewrite unsafe select that contains memory access.
Expand Down
5 changes: 5 additions & 0 deletions src/driver/driver_api.cc
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,7 @@ TVM_REGISTER_PASS_CONFIG_OPTION("tir.detect_global_barrier", Bool);
TVM_REGISTER_PASS_CONFIG_OPTION("tir.instrument_bound_checkers", Bool);
TVM_REGISTER_PASS_CONFIG_OPTION("tir.disable_assert", Bool);
TVM_REGISTER_PASS_CONFIG_OPTION("tir.disable_vectorize", Bool);
TVM_REGISTER_PASS_CONFIG_OPTION("tir.disable_cse_tir", Bool);
TVM_REGISTER_PASS_CONFIG_OPTION("tir.disable_storage_rewrite", Bool);
TVM_REGISTER_PASS_CONFIG_OPTION("tir.is_entry_func", Bool);
TVM_REGISTER_PASS_CONFIG_OPTION("tir.add_lower_pass", Array<Array<ObjectRef>>);
Expand Down Expand Up @@ -196,6 +197,7 @@ Array<tvm::transform::Pass> CreatePassList(bool disable_loop_partition) {
pass_ctx->GetConfig<Bool>("tir.disable_storage_rewrite", Bool(false)).value();
bool instrument_bound_checkers =
pass_ctx->GetConfig<Bool>("tir.instrument_bound_checkers", Bool(false)).value();
bool disable_cse_tir = pass_ctx->GetConfig<Bool>("tir.disable_cse_tir", Bool(false)).value();

// Get any user-added passes
Array<Array<ObjectRef>> add_lower_pass =
Expand Down Expand Up @@ -283,6 +285,9 @@ Array<tvm::transform::Pass> CreatePassList(bool disable_loop_partition) {
if (instrument_bound_checkers) {
pass_list.push_back(tir::transform::InstrumentBoundCheckers());
}

pass_list.push_back(tir::transform::CommonSubexprElimTIR(!disable_cse_tir));

return pass_list;
}

Expand Down
98 changes: 98 additions & 0 deletions src/tir/analysis/check_contains.cc
Original file line number Diff line number Diff line change
@@ -0,0 +1,98 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/

/*!
* \file check_contains.cc
* \brief Implementation of the analysis that tells if an expression contains
a node that satisfies a given predicate.
*/

#include "check_contains.h"

#include <tvm/tir/expr.h>

#include <vector>

namespace tvm {
namespace tir {

/*!
* \brief Toplevel (static) function that tells if an expression contains a subexpression that
satisfies a given predicate.
* \param expr The expression to check
* \param predicate The predicate that must be satisfied
* \return Whether `expr` contains a subexpression that satisfies `predicate`
*/
bool CheckContains::ExprContains(const PrimExpr& expr,
std::function<bool(const PrimExpr&)> predicate) {
CheckContains check_contains(predicate);
check_contains.VisitExpr(expr);
return check_contains.contains_it_;
}

/*!
* \brief Toplevel (static) function that tells if a statement contains a subexpression that
satisfies a given predicate.
* \param stmt The statement to check
* \param predicate The predicate that must be satisfied
* \return Whether `stmt` contains a subexpression that satisfies `predicate`
*/
bool CheckContains::StmtContains(const Stmt& stmt, std::function<bool(const PrimExpr&)> predicate) {
CheckContains check_contains(predicate);
check_contains.VisitStmt(stmt);
return check_contains.contains_it_;
}

/*!
* \brief Protected constructor of CheckContains.
* \param predicate The predicate that must be satisfied
*/
CheckContains::CheckContains(std::function<bool(const PrimExpr&)> predicate)
: predicate_(predicate) {}

/*!
* \brief The method which overrides the generic dispatcher of StmtExprVisitor for expressions.
* \param expr The expression to visit
*/
void CheckContains::VisitExpr(const PrimExpr& expr) {
// If the predicate holds on `expr`, we know `expr` contains something which makes
// the predicate hold
if (predicate_(expr)) {
contains_it_ = true;
} else {
// Otherwise we continue to look for it recursively by calling the dispatcher
StmtExprVisitor::VisitExpr(expr);
}
}

/*!
* \brief The method which overrides the generic dispatcher of StmtExprVisitor for statements.
* \param stmt The statement to visit
*/
void CheckContains::VisitStmt(const Stmt& stmt) {
// We keep exploring only if `contains_it_` is false
if (!contains_it_) {
// and in order to do that we call the general dispatcher
StmtExprVisitor::VisitStmt(stmt);
}
// As otherwise we already have our answer
}

} // namespace tir
} // namespace tvm
60 changes: 60 additions & 0 deletions src/tir/analysis/check_contains.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,60 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/

/*!
* \file check_contains.h
* \brief Interface of the analysis that tells if an expression contains
a node that satisfies a given predicate.
*/

#ifndef TVM_TIR_ANALYSIS_CHECK_CONTAINS_H_
#define TVM_TIR_ANALYSIS_CHECK_CONTAINS_H_

#include <tvm/tir/expr.h>
#include <tvm/tir/stmt_functor.h> // For the class StmtExprVisitor

namespace tvm {
namespace tir {

/*!
* \brief Visitor which tells if a given expression or statement contains a subexpression
that satisfies a given predicate
*/
class CheckContains : public StmtExprVisitor {
public:
// Toplevel (static) functions
static bool ExprContains(const PrimExpr& expr, std::function<bool(const PrimExpr&)> predicate);
static bool StmtContains(const Stmt& stmt, std::function<bool(const PrimExpr&)> predicate);

protected:
// Constructor
explicit CheckContains(std::function<bool(const PrimExpr&)> predicate);

void VisitExpr(const PrimExpr& expr) override;
void VisitStmt(const Stmt& stmt) override;

private:
std::function<bool(const PrimExpr&)> predicate_;
bool contains_it_ = false;
};

} // namespace tir
} // namespace tvm

#endif // TVM_TIR_ANALYSIS_CHECK_CONTAINS_H_
Loading

0 comments on commit 09f7be2

Please sign in to comment.