Commit

Improved the CSE by not commoning at the toplevel redundant computations that only appear in one of the possible execution paths (for instance, only in the then/else branch of an IF statement). Redundant computations that appear only in a specific execution path are now being commoned at the entrance of their specific execution path instead of earlier at the toplevel. Introducing them at the toplevel was an anti-optimization as the redundant computation might not have been computed at all. Added two additional tests for this too.
FranckQC committed Jan 20, 2022
1 parent 08c3cb1 commit 7afac3f
Showing 4 changed files with 368 additions and 26 deletions.
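The anti-optimization described in the commit message can be illustrated with a minimal sketch. It is not taken from the commit: `Expensive`, `g_evals`, `CommonedAtToplevel` and `CommonedInBranch` are hypothetical names, and plain C++ functions stand in for TIR statements. The sketch only counts how often the redundant computation is evaluated under each placement of the introduced variable.

```cpp
#include <cassert>

// Hypothetical stand-in for a redundant computation; counts its evaluations.
static int g_evals = 0;
int Expensive(int x, int y) {
  ++g_evals;
  return x * y + x;
}

// Commoning at the toplevel: the computation is evaluated on every call,
// even when the only branch that uses it is not taken.
int CommonedAtToplevel(bool cond, int x, int y) {
  int cse_var = Expensive(x, y);
  if (cond) {
    return cse_var + cse_var;
  }
  return 0;
}

// Commoning at the entrance of the branch: the computation is evaluated
// once, and only on the execution path that actually needs it.
int CommonedInBranch(bool cond, int x, int y) {
  if (cond) {
    int cse_var = Expensive(x, y);
    return cse_var + cse_var;
  }
  return 0;
}
```

When `cond` is false, the toplevel version still pays for one evaluation while the in-branch version pays for none; when `cond` is true, both evaluate it exactly once.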
4 changes: 2 additions & 2 deletions src/tir/transforms/common_subexpr_elim.cc
@@ -207,7 +207,7 @@ PrimExpr CommonSubexpressionEliminator::VisitExpr(const PrimExpr& expr) {
});

// For each computation done (considering them from biggest to smallest)
for (int i = 0; i < semantic_comp_done_by_expr.size(); i++) {
for (size_t i = 0; i < semantic_comp_done_by_expr.size(); i++) {
std::pair<PrimExpr, size_t>& computation_and_nb = semantic_comp_done_by_expr[i];

// The predicate later used (when doing replacements) to select expressions that are
@@ -377,7 +377,7 @@ Stmt CommonSubexpressionEliminator::VisitStmt(const Stmt& stmt) {
});

// For each computation done (considering them from biggest to smallest)
for (int i = 0; i < semantic_comp_done_by_stmt.size(); i++) {
for (size_t i = 0; i < semantic_comp_done_by_stmt.size(); i++) {
std::pair<PrimExpr, size_t>& computation_and_nb = semantic_comp_done_by_stmt[i];

// The predicate later used (when doing replacements) to select expressions that are
270 changes: 254 additions & 16 deletions src/tir/transforms/common_subexpr_elim_tools.cc
@@ -51,10 +51,10 @@ CacheOfComputations ComputationsDoneBy::cache_;
/* ********************************** Class ComputationsDoneBy **********************************
*********************************************************************************************** */

/* This utility class of the CSE pass offers a way of knowing the computations done by a given
/* This utility class of the CSE pass offers a way of knowing the eligible computations done by a
statement or expression. A "computation" here is a syntactical entity, represented by a PrimExpr.
This analysis returns a hashtable associating PrimExpr (a computation done) to a number (which
is the number of time that this computation is being computed).
is the number of times that this computation is being seen).
This analysis is used by the CSE pass in order to find potential candidates for being introduced
into new variables (after having merged semantically equivalent computations).
@@ -65,10 +65,15 @@ CacheOfComputations ComputationsDoneBy::cache_;
analysis can recurse). The user of the class must define these notions of "eligible computation"
and of "nodes that can contain eligible computations" for their own use case.
- On an statement, this analysis returns the union of all the computations that appear in its
child nodes (ie, the union of the results of the recursive calls).
For instance, on the input statement [let a = x+y in Mem[i1+i2] = a+b] it will return (x+y),
(i1+i2) and (a+b) when used with typical predicates.
- On a statement, this analysis often returns the union of all the computations that appear in
its child nodes (i.e., the union of the results of the recursive calls).
For instance, on the input statement [let a = x+y in Mem[i1+i2] = a+b] it will report (x+y)
seen once, (i1+i2) seen once, and (a+b) also seen once when used with typical predicates.
On some nodes, it will return something more complicated that uses the intersection of the
computations done by the children nodes.
For instance, on the input statement [if (x+y>z) then a = x+y else a = b-x] it will return
(x+y) seen twice but it won't report b-x as it is seen only in the else branch.
- On an expression, this analysis returns the expression itself, except if it is not eligible
for being introduced by the CSE pass into a variable according to `is_eligible_computation_`
(often because it's a load node or a function call node for instance), in which case it will
@@ -100,18 +105,140 @@ CacheOfComputations ComputationsDoneBy::cache_;
*/

/*!
* \brief Does the union of two table of computations.
* \param tableMain One of the two tables. The union will be written into it.
* \param tableAux The other table, which won't change.
* \brief Does the union of two tables of computations.
* \param table_main One of the two tables. The union will be written into it.
* \param table_aux The other table, which won't change.
* \note Does it directly in the first argument A for efficiency, as the union of A and B
* necessarily gives something which contains A, so we avoid its copy.
*/
void UnionOfTablesOfComputations(TableOfComputations& table_main,
const TableOfComputations& table_aux) {
void UnionOf2TablesOfComputations(TableOfComputations& table_main,
const TableOfComputations& table_aux) {
// Adds each element of the second table to the first one
for (const auto& current : table_aux) {
table_main[current.first] += current.second;
}
}

/*!
* \brief Does the union of three tables of computations.
* \param table1 One of the three tables, which won't change.
* \param table2 One of the three tables, which won't change.
* \param table3 One of the three tables, which won't change.
*/
TableOfComputations UnionOf3TablesOfComputations(const TableOfComputations& table1,
const TableOfComputations& table2, const TableOfComputations& table3) {
TableOfComputations result = table1; // Copy needed as the union of 2 writes into its first arg
UnionOf2TablesOfComputations(result, table2);
UnionOf2TablesOfComputations(result, table3);

return result;
}

/*!
* \brief Does the intersection of two tables of computations.
* \param table1 One of the two tables, which won't change.
* \param table2 The other table, which also won't change.
*/
TableOfComputations IntersectionOf2TablesOfComputations(const TableOfComputations& table1,
const TableOfComputations& table2) {
TableOfComputations result;
for (const auto& current : table1) {
auto it = table2.find(current.first);
if (it != table2.end()) {
result[current.first] = current.second + it->second;
}
}
return result;
}
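The counter behaviour of the union and intersection helpers above can be sketched in isolation. This is a simplified illustration, not the code of the pass: `Table`, `UnionInto` and `Intersection` are hypothetical names, and `std::string` keys stand in for the `PrimExpr` keys of the real `TableOfComputations`. The point it shows is that both operations add up the counters of the entries they keep.

```cpp
#include <cassert>
#include <cstddef>
#include <string>
#include <unordered_map>

// Simplified stand-in for TableOfComputations (assumption: string keys
// replace PrimExpr keys, so no custom hash or equality is needed).
using Table = std::unordered_map<std::string, std::size_t>;

// Union written into the first argument: counters of shared keys are summed,
// keys found only in `aux` are inserted with their own counter.
void UnionInto(Table& main, const Table& aux) {
  for (const auto& kv : aux) {
    main[kv.first] += kv.second;
  }
}

// Intersection: only keys present in both tables are kept, and their
// counters are summed.
Table Intersection(const Table& a, const Table& b) {
  Table result;
  for (const auto& kv : a) {
    auto it = b.find(kv.first);
    if (it != b.end()) {
      result[kv.first] = kv.second + it->second;
    }
  }
  return result;
}
```

With a then-branch table `{x+y: 1, a*b: 2}` and an else-branch table `{x+y: 1, b-x: 1}`, the intersection keeps only `x+y` with a counter of 2, whereas the union keeps all three keys.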

/*!
* \brief Does the intersection of three tables of computations.
* \param table1 One of the three tables, which won't change.
* \param table2 One of the three tables, which won't change.
* \param table3 One of the three tables, which won't change.
*/
TableOfComputations IntersectionOf3TablesOfComputations(const TableOfComputations& table1,
const TableOfComputations& table2, const TableOfComputations& table3) {
TableOfComputations result = IntersectionOf2TablesOfComputations(table1, table2);
result = IntersectionOf2TablesOfComputations(result, table3);
return result;
}

/*!
* \brief Recompute the number of times that each computation in table_main
is being seen in table_bloc1, table_bloc2 and table_bloc3. It sets
each element to the sum of the times it is seen in each individual bloc.
* \param table_main The main table, for which we recompute the counters.
* \param table_bloc1 One of the three tables, which won't change.
* \param table_bloc2 One of the three tables, which won't change.
* \param table_bloc3 One of the three tables, which won't change.
* \note This function is needed because both the intersection (A Inter B) and the union
* (A U B U C) adds the individual counters found in A, B and C. So when we treat for
* instance an If (which contains a Cond, a Then branch and an Else branch),
* it will compute (Then Inter Else) U (Cond Inter Then) U (Cond Inter Else).
* In order to get back to the appropriate number (for instance, 3 if seen one time in each
* bloc), it is therefore necessary to recompute the counters afterwards, which is what this
* function does.
*/
void RecomputeNbTimesSeenInThreeBlocs(TableOfComputations& table_main,
const TableOfComputations& table_bloc1, const TableOfComputations& table_bloc2,
const TableOfComputations& table_bloc3) {
// For each element in the main table (taken by reference so that the update is kept)
for (auto& current : table_main) {
// Reset the counter, as the intersections and the union may have inflated it
current.second = 0;

// Try to find it in the first bloc
auto it1 = table_bloc1.find(current.first);
if (it1 != table_bloc1.end()) {
// If found, increase its value by the value found in the first bloc
current.second += it1->second;
}

// Try to find it in the second bloc
auto it2 = table_bloc2.find(current.first);
if (it2 != table_bloc2.end()) {
// If found, increase its value by the value found in the second bloc
current.second += it2->second;
}

// Try to find it in the third bloc
auto it3 = table_bloc3.find(current.first);
if (it3 != table_bloc3.end()) {
// If found, increase its value by the value found in the third bloc
current.second += it3->second;
}
}
}

/*!
* \brief Builds a table for a node that has three children. A computation will be reported
as being computed if it appears in at least two of the children, i.e. if it will always be
computed, regardless of the execution path.
* \param table_child1 The table of computations done by the first child.
* \param table_child2 The table of computations done by the second child.
* \param table_child3 The table of computations done by the third child.
* \note This function will be used for obtaining the computations done by If nodes and by For
* nodes, which both have three children.
*/
TableOfComputations BuildTableForThreeChildrenNode(const TableOfComputations& table_child1,
const TableOfComputations& table_child2, const TableOfComputations& table_child3) {
TableOfComputations result;
// We look at what the children have in common
TableOfComputations child2_inter_child3 =
IntersectionOf2TablesOfComputations(table_child2, table_child3);
TableOfComputations child1_inter_child2 =
IntersectionOf2TablesOfComputations(table_child1, table_child2);
TableOfComputations child1_inter_child3 =
IntersectionOf2TablesOfComputations(table_child1, table_child3);

// We do the union of all the things they have in common
result = UnionOf3TablesOfComputations(child2_inter_child3, child1_inter_child2,
child1_inter_child3);

// Now we need to recompute the numbers associated with each computation, because both the
// intersections and the union might have increased the counters which can now be wrong.
RecomputeNbTimesSeenInThreeBlocs(result, table_child1, table_child2, table_child3);

return result;
}
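The intended result of the three-children table can be sketched with a different but, under the stated intent, equivalent formulation: keep every computation that appears in at least two children, and report the sum of its counters over all three. This is an illustration only; `Table` and `BuildTableForThreeChildren` are hypothetical names, `std::string` keys stand in for `PrimExpr`, and the single-pass counting replaces the pairwise intersections, union and counter recomputation used in the diff.

```cpp
#include <cassert>
#include <cstddef>
#include <initializer_list>
#include <string>
#include <unordered_map>

// Simplified stand-in for TableOfComputations (assumption: string keys).
using Table = std::unordered_map<std::string, std::size_t>;

Table BuildTableForThreeChildren(const Table& c1, const Table& c2, const Table& c3) {
  Table total;                                   // counters summed over the children
  std::unordered_map<std::string, int> seen_in;  // number of children holding each key
  for (const Table* child : {&c1, &c2, &c3}) {
    for (const auto& kv : *child) {
      total[kv.first] += kv.second;
      ++seen_in[kv.first];
    }
  }

  // Keep only what at least two children have in common.
  Table result;
  for (const auto& kv : total) {
    if (seen_in[kv.first] >= 2) {
      result[kv.first] = kv.second;
    }
  }
  return result;
}
```

For the statement `if (x+y>z) then a = x+y else a = b-x`, taking the condition table as `{x+y: 1}`, the then table as `{x+y: 1}` and the else table as `{b-x: 1}`, the sketch reports `x+y` seen twice and does not report `b-x`, which matches the behaviour described in the class comment.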

/*!
* \brief Toplevel (static) method for a PrimExpr
* \param expr The expr for which we want to know the computations done
@@ -215,7 +342,7 @@ void ComputationsDoneBy::VisitExpr(const PrimExpr& expr) {
// We need to do the union with `table_of_computations_` instead of just writing into it,
// because some other childs might have added things into it too. The reason for that is
// that `table_of_computations_` is shared between the child nodes of a given expression.
UnionOfTablesOfComputations(table_of_computations_, it_table_expr->second);
UnionOf2TablesOfComputations(table_of_computations_, it_table_expr->second);
return;
}

@@ -238,7 +365,7 @@ void ComputationsDoneBy::VisitExpr(const PrimExpr& expr) {
// We need to do the union with `table_of_computations_` instead of just writing into it,
// because some other childs might have added things into it too. The reason for that is
// that `table_of_computations_` is shared between the child nodes of a given expression.
UnionOfTablesOfComputations(table_of_computations_, temp);
UnionOf2TablesOfComputations(table_of_computations_, temp);
return;
}

@@ -257,7 +384,7 @@ void ComputationsDoneBy::VisitStmt(const Stmt& stmt) {
// We need to do the union with `table_of_computations_` instead of just writing into it,
// because some other childs might have added things into it too. The reason for that is
// that `table_of_computations_` is shared between the child nodes of a given statement.
UnionOfTablesOfComputations(table_of_computations_, it_table_stmt->second);
UnionOf2TablesOfComputations(table_of_computations_, it_table_stmt->second);
return;
}

@@ -270,7 +397,109 @@ void ComputationsDoneBy::VisitStmt(const Stmt& stmt) {
// We need to do the union with `table_of_computations_` instead of just writing into it,
// because some other childs might have added things into it too. The reason for that is
// that `table_of_computations_` is shared between the child nodes of a given expression.
UnionOfTablesOfComputations(table_of_computations_, temp);
UnionOf2TablesOfComputations(table_of_computations_, temp);
}

/*!
* \brief The method which overrides the specific treatment for an IfThenElseNode
*/
void ComputationsDoneBy::VisitStmt_(const IfThenElseNode* op) {
// We build the computations done by each of its child, but unlike the overridden method we will
// remember each table of computations so that we can at the end compute the needed intersections

// Calls the VisitExpr() method on the `condition` child
VisitExpr(op->condition);
TableOfComputations computations_done_by_cond = table_of_computations_;
// Clear it for not importing the computations of the condition in the computations of the then
table_of_computations_.clear();

// Then calls the VisitStmt() method on the `then_case` child
VisitStmt(op->then_case);
TableOfComputations computations_done_by_then = table_of_computations_;
// Clear it for not importing the computations of the then in the computations of the else
table_of_computations_.clear();

TableOfComputations computations_done_by_else;
if (op->else_case.defined()) {
// And finally calls the VisitStmt() method on the `else_case` child
VisitStmt(op->else_case);
computations_done_by_else = table_of_computations_;
table_of_computations_.clear();
}

// Build a table of computations for this node with three children
table_of_computations_ = BuildTableForThreeChildrenNode(computations_done_by_cond,
computations_done_by_then, computations_done_by_else);

// Copy the `table_of_computations_` into the cache
// for the future queries
const Stmt& ref_to_op = GetRef<Stmt>(op);
cache_.cache_stmt_table_computations_[ref_to_op] = table_of_computations_;
}

/*!
* \brief The method which overrides the specific treatment for a ForNode
*/
void ComputationsDoneBy::VisitStmt_(const ForNode* op) {
// We build the computations done by each of its child, but unlike the overridden method we will
// remember each table of computations so that we can at the end compute the needed intersections

// Calls the VisitExpr() method on the `min` child
VisitExpr(op->min);
TableOfComputations computations_done_by_min = table_of_computations_;
// Clear it for not importing the computations of the min in the computations of the extent
table_of_computations_.clear();

// Then calls the VisitExpr() method on the `extent` child
VisitExpr(op->extent);
TableOfComputations computations_done_by_extent = table_of_computations_;
// Clear it for not importing the computations of the extent in the computations of the body
table_of_computations_.clear();

TableOfComputations computations_done_by_body;
// And finally calls the VisitStmt() method on the `body` child
VisitStmt(op->body);
computations_done_by_body = table_of_computations_;
table_of_computations_.clear();

// Build a table of computations for this node with three children
table_of_computations_ = BuildTableForThreeChildrenNode(computations_done_by_min,
computations_done_by_extent, computations_done_by_body);

// Copy the `table_of_computations_` into the cache
// for the future queries
const Stmt& ref_to_op = GetRef<Stmt>(op);
cache_.cache_stmt_table_computations_[ref_to_op] = table_of_computations_;
}

/*!
* \brief The method which overrides the specific treatment for a WhileNode
*/
void ComputationsDoneBy::VisitStmt_(const WhileNode* op) {
// We build the computations done by each of its child, but unlike the overridden method we will
// remember each table of computations so that we can at the end compute the needed intersection

// Calls the VisitExpr() method on the `condition` child
VisitExpr(op->condition);
TableOfComputations computations_done_by_condition = table_of_computations_;
// Clear it for not importing the computations of the condition in the computations of the body
table_of_computations_.clear();

// Then calls the VisitStmt() method on the `body` child
VisitStmt(op->body);
TableOfComputations computations_done_by_body = table_of_computations_;
// Clear it before building the table of computations for this node
table_of_computations_.clear();

// Build a table of computations for this node with two children by computing what is
// common between the two children, i.e. computing their intersection
table_of_computations_ = IntersectionOf2TablesOfComputations(computations_done_by_condition,
computations_done_by_body);

// Copy the `table_of_computations_` into the cache
// for the future queries
const Stmt& ref_to_op = GetRef<Stmt>(op);
cache_.cache_stmt_table_computations_[ref_to_op] = table_of_computations_;
}

/*!
@@ -316,7 +545,7 @@ TableOfComputations ComputationsDoneBy::ComputationsDoneByChildrenOf(
// Calls the *dispatcher* (not the overriden method)
computations_done_by.StmtExprVisitor::VisitStmt(stmt);
// So now we can copy table_of_computations_ into the cache for the future queries
// Note : in the table, the computations done by `stmt` is set the the computations done by its
// Note : in the table, the computations done by `stmt` is set to the computations done by its
// children, because that's exactly what we mean by "the computations of a statement".
cache_.cache_stmt_table_computations_[stmt] = computations_done_by.table_of_computations_;

@@ -459,6 +688,15 @@ void UsesVarName::VisitStmt(const Stmt& stmt) {
/* ********************************** Utility functions for CSE *********************************
*********************************************************************************************** */

void PrintTableOfComputations(const TableOfComputations& table)
{
std::cout << "{" << std::endl;
for(const auto& current : table) {
std::cout << "(" << current.first << ", " << current.second << ")" << std::endl;
}
std::cout << "}" << std::endl;
}

/*!
* \brief Decides if two terms are equal syntactically
*/