Skip to content

Commit

Permalink
[TIR][Transform] Optional data-flow analysis in RemoveNoOp
Browse files Browse the repository at this point in the history
Previously, `RemoveNoOp` would remove statements that could be locally
analyzed as having no effect (e.g. `For` with empty loop extents).
This commit adds opt-in use of data-flow analysis to identify
two types of statements that are no-ops based on their context:

* Buffer stores that are overwritten without ever being read.

  ```python
  buf[i] = 5 # Overwritten by next statement
  buf[i] = 10
  ```

* Storing a value that is already known to be present.

  ```python
  buf[0:16] = T.ramp(0, 16, 1)
  buf[5] = 5 # Previous load already stored this value
  ```
  • Loading branch information
Lunderberg committed Nov 22, 2022
1 parent e662970 commit aa4206d
Show file tree
Hide file tree
Showing 6 changed files with 795 additions and 91 deletions.
7 changes: 7 additions & 0 deletions src/arith/rewrite_simplify.cc
Original file line number Diff line number Diff line change
Expand Up @@ -1644,6 +1644,11 @@ PrimExpr RewriteSimplifier::Impl::ApplyRewriteRules(LT ret) {
TVM_TRY_RECURSIVE_REWRITE(x + c1 < c2, x < c2 - c1);
TVM_TRY_RECURSIVE_REWRITE(x - c1 < c2, x < c2 + c1);
TVM_TRY_REWRITE(x - c1 < 0, x < c1);

TVM_TRY_RECURSIVE_REWRITE(x - 1 < y, x <= y);
TVM_TRY_RECURSIVE_REWRITE(x < y + 1, x <= y);
TVM_TRY_RECURSIVE_REWRITE(x + (-1) < y, x <= y);
TVM_TRY_RECURSIVE_REWRITE(x < y - (-1), x <= y);
// clang-format on
}
return std::move(ret);
Expand Down Expand Up @@ -1886,6 +1891,8 @@ PrimExpr RewriteSimplifier::Impl::VisitExpr_(const OrNode* op) {
TVM_TRY_REWRITE(x <= y || y < x, ctrue);
TVM_TRY_REWRITE(y < x || x <= y, ctrue);

TVM_TRY_REWRITE(x < y || y < x, x != y);

TVM_TRY_REWRITE_IF(x < c1 || c2 < x, ctrue, c2.Eval()->value < c1.Eval()->value);
TVM_TRY_REWRITE_IF(c2 < x || x < c1, ctrue, c2.Eval()->value < c1.Eval()->value);

Expand Down
117 changes: 77 additions & 40 deletions src/tir/analysis/control_flow_graph.cc
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,7 @@
#include <tvm/tir/op.h>
#include <tvm/tir/stmt_functor.h>

#include <algorithm>
#include <numeric>
#include <optional>
#include <queue>
Expand Down Expand Up @@ -819,10 +820,30 @@ BufferTouch ControlFlowGraph::ControlFlowBlock::MakeBufferTouch(ControlFlowGraph
return buffer_touch;
}

ControlFlowGraph::ControlFlowGraph(const tir::Stmt& stmt, size_t max_revisits) {
ControlFlowGraph::ControlFlowGraph(const tir::Stmt& stmt, size_t max_revisits)
: max_revisits_(max_revisits) {
ControlFlowGraphBuilder::Build(this, stmt);
ForwardPropagateKnownValues(max_revisits);
BackwardPropagateUnusedValues(max_revisits);
ForwardPropagateKnownValues();
BackwardPropagateUnusedValues();
}

void ControlFlowGraph::RemoveStore(const tir::BufferStore& store) {
size_t context_index = [&]() {
auto it = control_flow_lookup_.find(store.get());
ICHECK(it != control_flow_lookup_.end())
<< "BufferStore did not occur in the Stmt provided to BufferTouchPattern's constructor";
return it->second;
}();

auto& touch_points = control_flow_[context_index].touch_points;

touch_points.erase(std::remove_if(touch_points.begin(), touch_points.end(),
[](const BufferTouch& touch) {
return touch.touch_type == BufferTouch::AccessType::Write;
}),
touch_points.end());
ForwardPropagateKnownValues(context_index);
BackwardPropagateUnusedValues(context_index);
}

std::ostream& operator<<(std::ostream& os, const ControlFlowGraph::ControlFlowEdge& edge) {
Expand Down Expand Up @@ -1327,33 +1348,38 @@ Array<Var> ControlFlowGraph::GetIndexVariables(const Buffer& buf, const Array<Pr
return vars;
}

void ControlFlowGraph::ForwardPropagateKnownValues(size_t max_revisits) {
void ControlFlowGraph::ForwardPropagateKnownValues(std::optional<size_t> flow_from) {
// Values to visit when searching. Using a std::set to
// preferentially visit nodes near the start of the control flow.
std::set<size_t> to_visit;

// Map from a block's index
std::unordered_map<size_t, size_t> visit_count_lookup;

// Initiatize the locations to search from, propagating values
// forward from all locations that have a known value.
for (size_t i = 0; i < control_flow_.size(); i++) {
bool has_known_value = false;
for (const auto& touch : control_flow_[i].touch_points) {
if (!HasBufferLoad(touch.value)) {
has_known_value = true;
break;
if (flow_from.has_value()) {
to_visit.insert(flow_from.value());
} else {
// Initiatize the locations to search from, propagating values
// forward from all locations that have a known value.
for (size_t i = 0; i < control_flow_.size(); i++) {
bool has_known_value = false;
for (const auto& touch : control_flow_[i].touch_points) {
if (!HasBufferLoad(touch.value)) {
has_known_value = true;
break;
}
}
}

if (has_known_value) {
to_visit.insert(i);
if (has_known_value) {
to_visit.insert(i);
}
}
}

// Map from a block's index
std::unordered_map<size_t, size_t> visit_count_lookup;

Analyzer analyzer;
analyzer.rewrite_simplify.SetEnabledExtensions(arith::RewriteSimplifier::Extension(
arith::RewriteSimplifier::kTransitivelyProveInequalities |
arith::RewriteSimplifier::kConvertBooleanToAndOfOrs |
arith::RewriteSimplifier::kApplyConstraintsToBooleanBranches));

analyzer.Bind(iterator_ranges_);
Expand All @@ -1369,7 +1395,7 @@ void ControlFlowGraph::ForwardPropagateKnownValues(size_t max_revisits) {

// Step 1: Collect known values provided from each predecessor
block.known_at_block_start = [&]() -> BufferState {
if (num_previous_visits >= max_revisits) {
if (num_previous_visits >= max_revisits_) {
return BufferState();
}

Expand Down Expand Up @@ -1437,7 +1463,7 @@ void ControlFlowGraph::ForwardPropagateKnownValues(size_t max_revisits) {

// Step 2: Collect knowns provided as a result of executing this block
auto post_state = [&]() {
if (num_previous_visits >= max_revisits) {
if (num_previous_visits >= max_revisits_) {
return BufferState();
}
auto post_state = block.known_at_block_start;
Expand All @@ -1459,29 +1485,35 @@ void ControlFlowGraph::ForwardPropagateKnownValues(size_t max_revisits) {
}
}

void ControlFlowGraph::BackwardPropagateUnusedValues(size_t max_revisits) {
void ControlFlowGraph::BackwardPropagateUnusedValues(std::optional<size_t> flow_from) {
// Values to visit when searching. Using a std::set to
// preferentially visit nodes near the end of the control flow.
std::set<size_t> to_visit;

// Map from a block's index
std::unordered_map<size_t, size_t> visit_count_lookup;

// Initiatize the locations to search from, propagating values
// backward from anywhere that performs a write.
for (size_t i = 0; i < control_flow_.size(); i++) {
const auto& touch_points = control_flow_[i].touch_points;
bool performs_write = std::any_of(
touch_points.begin(), touch_points.end(),
[](const auto& touch) { return touch.touch_type == BufferTouch::AccessType::Write; });
if (performs_write) {
to_visit.insert(i);
if (flow_from.has_value()) {
to_visit.insert(flow_from.value());
} else {
// Initiatize the locations to search from, propagating values
// backward from anywhere that performs a write.
for (size_t i = 0; i < control_flow_.size(); i++) {
const auto& touch_points = control_flow_[i].touch_points;
bool performs_write = std::any_of(
touch_points.begin(), touch_points.end(),
[](const auto& touch) { return touch.touch_type == BufferTouch::AccessType::Write; });
if (performs_write) {
to_visit.insert(i);
}
}
}

// Map from a block's index
std::unordered_map<size_t, size_t> visit_count_lookup;

Analyzer analyzer;
analyzer.rewrite_simplify.SetEnabledExtensions(
arith::RewriteSimplifier::kTransitivelyProveInequalities);
analyzer.rewrite_simplify.SetEnabledExtensions(arith::RewriteSimplifier::Extension(
arith::RewriteSimplifier::kTransitivelyProveInequalities |
arith::RewriteSimplifier::kConvertBooleanToAndOfOrs |
arith::RewriteSimplifier::kApplyConstraintsToBooleanBranches));

analyzer.Bind(iterator_ranges_);
analyzer.Bind(free_predicate_parameters_);
Expand All @@ -1496,7 +1528,7 @@ void ControlFlowGraph::BackwardPropagateUnusedValues(size_t max_revisits) {

// Step 1: Collect known unused indices provided by each successor
block.unused_at_block_end = [&]() -> BufferState {
if (num_previous_visits >= max_revisits) {
if (num_previous_visits >= max_revisits_) {
return BufferState();
}
ICHECK_LE(block.successors.size(), 2)
Expand Down Expand Up @@ -1561,7 +1593,7 @@ void ControlFlowGraph::BackwardPropagateUnusedValues(size_t max_revisits) {

// Step 2: Collect knowns provided as a result of executing this block
auto unused_at_block_start = [&]() {
if (num_previous_visits >= max_revisits) {
if (num_previous_visits >= max_revisits_) {
return BufferState();
}
auto prior_state = block.unused_at_block_end;
Expand Down Expand Up @@ -1603,8 +1635,10 @@ bool ControlFlowGraph::IsOverwrittenWithoutEffect(const tir::BufferStore& store,
local_analyzer.Bind(free_predicate_parameters_);
local_analyzer.Bind(iterator_ranges_);
local_analyzer.Bind(free_params);
local_analyzer.rewrite_simplify.SetEnabledExtensions(
RewriteSimplifier::kTransitivelyProveInequalities);
local_analyzer.rewrite_simplify.SetEnabledExtensions(arith::RewriteSimplifier::Extension(
arith::RewriteSimplifier::kTransitivelyProveInequalities |
arith::RewriteSimplifier::kConvertBooleanToAndOfOrs |
arith::RewriteSimplifier::kApplyConstraintsToBooleanBranches));

PrimExpr predicate = store_touch.predicate && store_touch.AtLoopIteration();

Expand All @@ -1630,13 +1664,16 @@ PrimExpr ControlFlowGraph::SimplifyInContext(PrimExpr expr, const tir::Stmt& con
return it->second;
}();

const auto& control_flow_block = control_flow_[context_index];

PrimExpr constraint = Bool(true);
for (const auto& known : non_buffer_assumptions_) {
constraint = constraint && known;
}
With<ConstraintContext> constraint_context(analyzer, constraint);
With<ConstraintContext> control_flow_scope(analyzer, control_flow_block.scope_predicate);

expr = control_flow_[context_index].known_at_block_start.SubstituteKnownBufferValues(
expr = control_flow_block.known_at_block_start.SubstituteKnownBufferValues(
std::move(expr), axis_var_lookup_, analyzer);

expr = analyzer->Simplify(std::move(expr));
Expand Down
12 changes: 10 additions & 2 deletions src/tir/analysis/control_flow_graph.h
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@
#include <tvm/tir/stmt.h>
#include <tvm/tir/var.h>

#include <optional>
#include <unordered_map>
#include <utility>
#include <vector>
Expand Down Expand Up @@ -474,13 +475,17 @@ class ControlFlowGraph {

/*! \brief Propagate known values from known BufferStore/assume
* subsequent control flow blocks
*
* \param flow_from If specified, re-flow only from that block.
*/
void ForwardPropagateKnownValues(size_t max_revisits);
void ForwardPropagateKnownValues(std::optional<size_t> flow_from = std::nullopt);

/*! \brief Propagate overwritten/unused indices to preceding control
* flow blocks
*
* \param flow_from If specified, re-flow only from that block.
*/
void BackwardPropagateUnusedValues(size_t max_revisits);
void BackwardPropagateUnusedValues(std::optional<size_t> flow_from = std::nullopt);

struct ControlFlowEdge {
/* \brief The source block of the control flow edge
Expand Down Expand Up @@ -646,6 +651,9 @@ class ControlFlowGraph {
std::vector<PrimExpr> non_buffer_assumptions_;

friend class ControlFlowGraphBuilder;

/*! \brief The maximum number of revisits while flowing constraints */
size_t max_revisits_;
};

} // namespace tir
Expand Down
Loading

0 comments on commit aa4206d

Please sign in to comment.