Skip to content

Commit

Permalink
[TIR] CSE pass : Restrict the equivalence to be decided by a normal f…
Browse files Browse the repository at this point in the history
…orm - avoids comparison of terms (apache#11574)

The CSE pass had been designed for potentially allowing comparisons (and commonings) of equivalent terms (like (x+y)+z and x+(y+z)), where **the notion of being equivalent was customizable, and no assumption was made about it**. That means that the implementation of the equivalence test function `EquivalentTerms()` - which was at the moment just calling the syntactical equality test `EqualTerms()` - could be replaced later by a cleverer equality test.

However, having such a generic way of comparing elements meant that in the function `SyntacticToSemanticComputations()`, where we were going from a hashtable of syntactical entities to what I called a vector of "semantical entites" (which are just canonical forms/representants of classes of equivalence of terms), **the only way was to compare each pair**.
That resulted in a quadratic behavior of this function, but there was no way around it as in order to merge equivalent entities into their class of equivalence, we had to compare them.

**This PR essentially does the following:**

- When computing the classes of equivalences of terms (therefore transforming a ComputationTable (i.e. a hashtable) into a vector of classes of equivalence) : **instead of comparing each pair of terms, relies on a normalization procedure to obtain a normal form for each of them**.
That transforms a small part of the algorithm that was quadratic to n.logn. However, it's difficult to see improvements in practice, in particular for average sized programs, as that part was a "small" quadratic to a "big" n.logn (finding things in a hash-table, copying it to a vector, etc).
It was probably going from a complexity of ~O(((n²-n)/2) + n.logn) to a complexity of ~O(3n + n.logn), so potential gains would only be expected for very large programs.

- Completely gives the user the possibility to turn ON/OFF the semantical comparisons of terms. It is turned OFF by default (as it's quite longer to compile with it ON, unsurprisingly), which means that by default, the equivalence coincides with the (syntactical) equality of terms.
    As the pass was written with the possibility to do these additional commonings (like (x+y)+z and x+(y+z)), it was a good time to fully plug that completely, up to the Python user who can now turn that ON if he wants to. But again, it is OFF by default, so no real change on that.

To run it ON, simply do:
`with tvm.transform.PassContext(config={'tir.enable_equiv_terms_in_cse_tir':True}):`
before calling `build()`

- When this boolean is set to ON, it uses a simple implementation of the normalization function with equivalences that uses `arith::Analyzer::Simplify` as noted by in apache#10544 . Note that this is not a real normalization procedure as it is incomplete (i.e., it is not guarantee to converge to the normal form), but it is correct, and it works well with most properties : associativity of +, distributivity of * on +, etc.

- Clarifies and enhance the test base for the pass. In particular, it adds the tests that were written in apache#10544 but which did not make it through.

- Also add the test ( https://github.com/AndrewZhaoLuo/TVM-Sandbox/blob/19284ddbd6bb28af61c0c2aa8bb334c5c53731a7/tir/test_inconsistent_tir_lowering.py#L1 ) demonstrating the (older) non-deterministic lowering and put it into a proper test, as I found it useful for making sure that this does not happen again. It has been copied from apache#10663 and only slightly adapted (in particular for doing the comparison of hashes automatically instead of printing them and relying on a human to compare them).
  • Loading branch information
FranckQC authored and Kathryn-cat committed Jun 10, 2022
1 parent 1023870 commit 363a72b
Show file tree
Hide file tree
Showing 8 changed files with 409 additions and 123 deletions.
3 changes: 2 additions & 1 deletion include/tvm/tir/transform.h
Original file line number Diff line number Diff line change
Expand Up @@ -470,9 +470,10 @@ TVM_DLL Pass LowerVtcmAlloc();
* \brief Implements a Common Subexpression Elimination (CSE) for TIR
* which introduces let-in bindings for duplicated sub-expressions.
* \param enable_cse_tir Whether common subexpression elimination is enabled.
* \param identify_equiv_terms Whether equivalent terms should be identified.
* \return The pass.
*/
TVM_DLL Pass CommonSubexprElimTIR(bool enable_cse_tir = true);
TVM_DLL Pass CommonSubexprElimTIR(bool enable_cse_tir = true, bool identify_equiv_terms = false);

/*!
* \brief Unify all the thread bindings for "blockIdx.x/y/z", "threadIdx.x/y/z", and
Expand Down
4 changes: 2 additions & 2 deletions python/tvm/tir/transform/transform.py
Original file line number Diff line number Diff line change
Expand Up @@ -324,15 +324,15 @@ def BF16TypeLowering():
return _ffi_api.BF16TypeLowering() # type: ignore


def CommonSubexprElimTIR(enable_cse_tir: bool = True):
def CommonSubexprElimTIR(enable_cse_tir: bool = True, identify_equiv_terms: bool = False):
"""Replace redundant computations by new variables.
Returns
-------
fpass : tvm.transform.Pass
The result pass
"""
return _ffi_api.CommonSubexprElimTIR(enable_cse_tir) # type: ignore
return _ffi_api.CommonSubexprElimTIR(enable_cse_tir, identify_equiv_terms) # type: ignore


def RewriteUnsafeSelect():
Expand Down
6 changes: 5 additions & 1 deletion src/driver/driver_api.cc
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,7 @@ TVM_REGISTER_PASS_CONFIG_OPTION("tir.instrument_bound_checkers", Bool);
TVM_REGISTER_PASS_CONFIG_OPTION("tir.disable_assert", Bool);
TVM_REGISTER_PASS_CONFIG_OPTION("tir.disable_vectorize", Bool);
TVM_REGISTER_PASS_CONFIG_OPTION("tir.disable_cse_tir", Bool);
TVM_REGISTER_PASS_CONFIG_OPTION("tir.enable_equiv_terms_in_cse_tir", Bool);
TVM_REGISTER_PASS_CONFIG_OPTION("tir.disable_storage_rewrite", Bool);
TVM_REGISTER_PASS_CONFIG_OPTION("tir.is_entry_func", Bool);
TVM_REGISTER_PASS_CONFIG_OPTION("tir.add_lower_pass", Array<Array<ObjectRef>>);
Expand Down Expand Up @@ -198,6 +199,8 @@ Array<tvm::transform::Pass> CreatePassList(bool disable_loop_partition) {
bool instrument_bound_checkers =
pass_ctx->GetConfig<Bool>("tir.instrument_bound_checkers", Bool(false)).value();
bool disable_cse_tir = pass_ctx->GetConfig<Bool>("tir.disable_cse_tir", Bool(false)).value();
bool enable_equiv_terms_in_cse_tir =
pass_ctx->GetConfig<Bool>("tir.enable_equiv_terms_in_cse_tir", Bool(false)).value();

// Get any user-added passes
Array<Array<ObjectRef>> add_lower_pass =
Expand Down Expand Up @@ -289,7 +292,8 @@ Array<tvm::transform::Pass> CreatePassList(bool disable_loop_partition) {
pass_list.push_back(tir::transform::InstrumentBoundCheckers());
}

pass_list.push_back(tir::transform::CommonSubexprElimTIR(!disable_cse_tir));
pass_list.push_back(
tir::transform::CommonSubexprElimTIR(!disable_cse_tir, enable_equiv_terms_in_cse_tir));

return pass_list;
}
Expand Down
96 changes: 70 additions & 26 deletions src/tir/transforms/common_subexpr_elim.cc
Original file line number Diff line number Diff line change
Expand Up @@ -60,7 +60,7 @@ namespace tir {
to collect them for the CSE pass, but we also won't even want to collect computations
that contain them.
The reason is that reusing such computations would change the semantics of the program,
and therefore before doing any introduction of variable or any reuse of already introduced
and therefore before doing any introduction of var or any reuse of already introduced
variables, we will make sure that the computation being considered is not forbidden, and
that it does not even contain a forbidden computation.
* \param expr The expression to check
Expand Down Expand Up @@ -120,6 +120,42 @@ bool CommonSubexpressionEliminator::CanContainEligibleComputations(const PrimExp
return true;
}

/*!
* \brief Implements an order on pairs (expression,frequency). First attempts to compare them
using the size of the expression. If it is the same, decides something else still
deterministic.
* \param a The first pair
* \param b The second pair
* \return A boolean telling if the first pair `a` comes before the second pair `b`
* \note We need this order to be deterministic in order to have a fully deterministic pass,
* as we will deal with elements that are coming from a hashtable, but the order in which
* they appeared in the hashtable was based on some runtime addresses, so it can potentially
* change with every execution.
*/
bool CommonSubexpressionEliminator::OrderOnExprAndFrequency(std::pair<PrimExpr, size_t> a,
std::pair<PrimExpr, size_t> b) {
size_t a_size = CalculateExprComplexity(a.first);
size_t b_size = CalculateExprComplexity(b.first);

// Criteria 1 - Size of the expression comes first
// `a` comes before `b` if the size of `a` is bigger
if (a_size > b_size) {
return true;
}
// `a` does NOT come before `b` if the size of `b` is bigger
if (b_size > a_size) {
return false;
}

// Criteria 2 - If they had the same size, use the lexicographic order as a last resort
// as we need a deterministic order
std::stringstream a_stream;
std::stringstream b_stream;
a_stream << a.first;
b_stream << b.first;
return (a_stream.str().compare(b_stream.str()) < 0);
}

/*!
* \brief Generates a new fresh variable, whose name will be cse_var_i.
* \param type_annotation The type of the new variable to generate
Expand Down Expand Up @@ -166,10 +202,12 @@ int CommonSubexpressionEliminator::GetNbVarGenerated() { return nb_var_; }
of the function being analyzed
* \return A new statement where CSE has been performed
*/
Stmt CommonSubexpressionEliminator::PerformCSE(const Stmt& stmt, const Context& context_init) {
Stmt CommonSubexpressionEliminator::PerformCSE(const Stmt& stmt, const Context& context_init,
bool identify_equiv_terms) {
// As this function is being called for each PrimFunc definition, we create a new instance
// for the one we are having now.
CommonSubexpressionEliminator common_subexpression_eliminator(stmt, context_init);
CommonSubexpressionEliminator common_subexpression_eliminator(stmt, context_init,
identify_equiv_terms);
return common_subexpression_eliminator.VisitStmt(stmt);
}

Expand All @@ -179,8 +217,9 @@ Stmt CommonSubexpressionEliminator::PerformCSE(const Stmt& stmt, const Context&
formal parameters of the function that will be analyzed
*/
CommonSubexpressionEliminator::CommonSubexpressionEliminator(const Stmt& stmt,
const Context& context_init)
: initial_body_(stmt), context_(context_init) {}
const Context& context_init,
bool identify_equiv_terms)
: initial_body_(stmt), context_(context_init), identify_equiv_terms_(identify_equiv_terms) {}

/*!
* \brief The method which overrides the generic dispatcher of StmtExprMutator.
Expand All @@ -200,39 +239,40 @@ PrimExpr CommonSubexpressionEliminator::VisitExpr(const PrimExpr& expr) {
// Transform the hashtable of *syntactic* eligible computations into a vector of pairs
// containing *semantic* entities, i.e. where equivalent computations are merged.
std::vector<std::pair<PrimExpr, size_t>> semantic_comp_done_by_expr =
SyntacticToSemanticComputations(table_syntactic_comp_done_by_expr);
SyntacticToSemanticComputations(table_syntactic_comp_done_by_expr, identify_equiv_terms_);

// Sort the vector of semantic entities by decreasing size
std::sort(semantic_comp_done_by_expr.begin(), semantic_comp_done_by_expr.end(),
[](std::pair<PrimExpr, size_t> a, std::pair<PrimExpr, size_t> b) {
return (CalculateExprComplexity(a.first) > CalculateExprComplexity(b.first));
});
OrderOnExprAndFrequency);

// For each computation done (considering them from biggest to smallest)
for (size_t i = 0; i < semantic_comp_done_by_expr.size(); i++) {
std::pair<PrimExpr, size_t>& computation_and_nb = semantic_comp_done_by_expr[i];

bool ident_equiv_terms = identify_equiv_terms_; // To avoid the capture of "this"

// The predicate later used (when doing replacements) to select expressions that are
// equivalent to the current computation (`computation_and_nb.first`)
std::function<bool(const PrimExpr&)> predicate_selector =
[computation_and_nb](const PrimExpr& current_expr) {
[computation_and_nb, ident_equiv_terms](const PrimExpr& current_expr) {
// `current_expr` should be equivalent to `computation_and_nb.first`, but we also check
// that `current_expr` is an eligible computation even if we know that
// `computation_and_nb.first` is eligible by construction, in case that one day the
// equivalence relation would not preserve the eligibility any more (even though that
// would probably be a very weird equivalence).
return (EquivalentTerms(current_expr, computation_and_nb.first) &&
return (EquivalentTerms(current_expr, computation_and_nb.first, ident_equiv_terms) &&
IsEligibleComputation(current_expr));
};

// See if there is a pair (`var`, `value`) in the context where `value` is semantically
// equivalent to `computation_and_nb.first`
auto it_on_var = std::find_if(
context_.begin(), context_.end(),
[computation_and_nb](const std::pair<Var, MaybeValue>& var_and_value) {
[computation_and_nb, ident_equiv_terms](const std::pair<Var, MaybeValue>& var_and_value) {
// Note : safe to call value() as we check has_value() just before
return (var_and_value.second.has_value() &&
EquivalentTerms(var_and_value.second.value(), computation_and_nb.first));
EquivalentTerms(var_and_value.second.value(), computation_and_nb.first,
ident_equiv_terms));
});

// Case where we have a perfectly equivalent computation already available in a variable
Expand Down Expand Up @@ -298,7 +338,8 @@ PrimExpr CommonSubexpressionEliminator::VisitExpr(const PrimExpr& expr) {
// The following insertion will maintain `semantic_comp_done_by_expr` sorted (by
// decreasing size/complexity), and it will only insert at locations > i as the
// direct subexprs are necessarily smaller than the current computation.
InsertVectorToSortedSemanticComputations(&semantic_comp_done_by_expr, direct_subexprs);
InsertVectorToSortedSemanticComputations(&semantic_comp_done_by_expr, direct_subexprs,
identify_equiv_terms_);
}
}
// Note : we do not remove the current element, as we never look back in the local vector
Expand Down Expand Up @@ -378,39 +419,40 @@ Stmt CommonSubexpressionEliminator::VisitStmt(const Stmt& stmt) {
// Transform the hashtable of *syntactic* eligible computations into a vector of pairs
// containing *semantic* entities, i.e. where equivalent computations are merged.
std::vector<std::pair<PrimExpr, size_t>> semantic_comp_done_by_stmt =
SyntacticToSemanticComputations(table_syntactic_comp_done_by_stmt);
SyntacticToSemanticComputations(table_syntactic_comp_done_by_stmt, identify_equiv_terms_);

// Sort the vector of semantic entities by decreasing size
std::sort(semantic_comp_done_by_stmt.begin(), semantic_comp_done_by_stmt.end(),
[](std::pair<PrimExpr, size_t> a, std::pair<PrimExpr, size_t> b) {
return (CalculateExprComplexity(a.first) > CalculateExprComplexity(b.first));
});
OrderOnExprAndFrequency);

// For each computation done (considering them from biggest to smallest)
for (size_t i = 0; i < semantic_comp_done_by_stmt.size(); i++) {
std::pair<PrimExpr, size_t>& computation_and_nb = semantic_comp_done_by_stmt[i];

bool ident_equiv_terms = identify_equiv_terms_; // To avoid the capture of "this"

// The predicate later used (when doing replacements) to select expressions that are
// equivalent to the current computation (`computation_and_nb.first`)
std::function<bool(const PrimExpr&)> predicate_selector =
[computation_and_nb](const PrimExpr& current_expr) {
[computation_and_nb, ident_equiv_terms](const PrimExpr& current_expr) {
// `current_expr` should be equivalent to `computation_and_nb.first`, but we also check
// that `current_expr` is an eligible computation even if we know that
// `computation_and_nb.first` is eligible by construction, in case that one day the
// equivalence relation would not preserve the eligibility any more (even though that
// would probably be a very weird equivalence).
return (EquivalentTerms(current_expr, computation_and_nb.first) &&
return (EquivalentTerms(current_expr, computation_and_nb.first, ident_equiv_terms) &&
IsEligibleComputation(current_expr));
};

// See if there is a pair (`var`, `value`) in the context where `value` is semantically
// equivalent to `computation_and_nb.first`
auto it_on_var = std::find_if(
context_.begin(), context_.end(),
[computation_and_nb](const std::pair<Var, MaybeValue>& var_and_value) {
[computation_and_nb, ident_equiv_terms](const std::pair<Var, MaybeValue>& var_and_value) {
// Note : safe to call value() as we check has_value() just before
return (var_and_value.second.has_value() &&
EquivalentTerms(var_and_value.second.value(), computation_and_nb.first));
EquivalentTerms(var_and_value.second.value(), computation_and_nb.first,
ident_equiv_terms));
});

// Case where we have a perfectly equivalent computation already available in a variable
Expand Down Expand Up @@ -477,7 +519,8 @@ Stmt CommonSubexpressionEliminator::VisitStmt(const Stmt& stmt) {
// The following insertion will maintain `semantic_comp_done_by_stmt` sorted (by
// decreasing size/complexity), and it will only insert at locations > i as the
// direct subexprs are necessarily smaller than the current computation.
InsertVectorToSortedSemanticComputations(&semantic_comp_done_by_stmt, direct_subexprs);
InsertVectorToSortedSemanticComputations(&semantic_comp_done_by_stmt, direct_subexprs,
identify_equiv_terms_);
}
}
// Note : we do not remove the current element, as we never look back in the local vector
Expand Down Expand Up @@ -587,8 +630,8 @@ namespace transform {
* \brief The function which returns the pass for the Common Subexpression Elimination.
* \return The pass for performing CSE.
*/
Pass CommonSubexprElimTIR(bool enable_cse_tir) {
auto pass_func = [enable_cse_tir](PrimFunc f, IRModule m, PassContext ctx) {
Pass CommonSubexprElimTIR(bool enable_cse_tir, bool identify_equiv_terms) {
auto pass_func = [enable_cse_tir, identify_equiv_terms](PrimFunc f, IRModule m, PassContext ctx) {
if (enable_cse_tir) {
auto* n = f.CopyOnWrite();
Context context_init;
Expand All @@ -603,7 +646,8 @@ Pass CommonSubexprElimTIR(bool enable_cse_tir) {

// Do the Common Subexpression Elimination on the body of the function, with the initial
// context that we have prepared
n->body = CommonSubexpressionEliminator::PerformCSE(std::move(f->body), context_init);
n->body = CommonSubexpressionEliminator::PerformCSE(std::move(f->body), context_init,
identify_equiv_terms);
}

return f;
Expand Down
8 changes: 6 additions & 2 deletions src/tir/transforms/common_subexpr_elim.h
Original file line number Diff line number Diff line change
Expand Up @@ -55,7 +55,7 @@ using Context = std::vector<std::pair<Var, MaybeValue>>;
class CommonSubexpressionEliminator : public StmtExprMutator {
public:
// Toplevel (static) function
static Stmt PerformCSE(const Stmt& stmt, const Context& context_init);
static Stmt PerformCSE(const Stmt& stmt, const Context& context_init, bool identify_equiv_terms);

PrimExpr VisitExpr(const PrimExpr& expr) override;
Stmt VisitStmt(const Stmt& stmt) override;
Expand All @@ -64,7 +64,8 @@ class CommonSubexpressionEliminator : public StmtExprMutator {

protected:
// Constructor
CommonSubexpressionEliminator(const Stmt& stmt, const Context& context_init);
CommonSubexpressionEliminator(const Stmt& stmt, const Context& context_init,
bool identify_equiv_terms);

PrimExpr VisitExpr_(const LetNode* op) override;

Expand All @@ -77,9 +78,12 @@ class CommonSubexpressionEliminator : public StmtExprMutator {
int num_last_try_ = 0; // Number of the last variable tried
int nb_var_ = 0; // Number of variables introduced by the CSE pass

bool identify_equiv_terms_ = false;

static bool ForbiddenComputation(const PrimExpr& expr);
static bool IsEligibleComputation(const PrimExpr& expr);
static bool CanContainEligibleComputations(const PrimExpr& expr);
static bool OrderOnExprAndFrequency(std::pair<PrimExpr, size_t> a, std::pair<PrimExpr, size_t> b);
Var GenerateNewVar(DataType type_annotation);
};

Expand Down
Loading

0 comments on commit 363a72b

Please sign in to comment.