Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[TIR][Transform] Optional data-flow analysis in RemoveNoOp #13217

Merged
merged 2 commits into from
Nov 25, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 7 additions & 0 deletions src/arith/rewrite_simplify.cc
Original file line number Diff line number Diff line change
Expand Up @@ -1644,6 +1644,11 @@ PrimExpr RewriteSimplifier::Impl::ApplyRewriteRules(LT ret) {
TVM_TRY_RECURSIVE_REWRITE(x + c1 < c2, x < c2 - c1);
TVM_TRY_RECURSIVE_REWRITE(x - c1 < c2, x < c2 + c1);
TVM_TRY_REWRITE(x - c1 < 0, x < c1);

TVM_TRY_RECURSIVE_REWRITE(x - 1 < y, x <= y);
TVM_TRY_RECURSIVE_REWRITE(x < y + 1, x <= y);
TVM_TRY_RECURSIVE_REWRITE(x + (-1) < y, x <= y);
TVM_TRY_RECURSIVE_REWRITE(x < y - (-1), x <= y);
// clang-format on
}
return std::move(ret);
Expand Down Expand Up @@ -1886,6 +1891,8 @@ PrimExpr RewriteSimplifier::Impl::VisitExpr_(const OrNode* op) {
TVM_TRY_REWRITE(x <= y || y < x, ctrue);
TVM_TRY_REWRITE(y < x || x <= y, ctrue);

TVM_TRY_REWRITE(x < y || y < x, x != y);

TVM_TRY_REWRITE_IF(x < c1 || c2 < x, ctrue, c2.Eval()->value < c1.Eval()->value);
TVM_TRY_REWRITE_IF(c2 < x || x < c1, ctrue, c2.Eval()->value < c1.Eval()->value);

Expand Down
117 changes: 77 additions & 40 deletions src/tir/analysis/control_flow_graph.cc
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,7 @@
#include <tvm/tir/op.h>
#include <tvm/tir/stmt_functor.h>

#include <algorithm>
#include <numeric>
#include <optional>
#include <queue>
Expand Down Expand Up @@ -819,10 +820,30 @@ BufferTouch ControlFlowGraph::ControlFlowBlock::MakeBufferTouch(ControlFlowGraph
return buffer_touch;
}

ControlFlowGraph::ControlFlowGraph(const tir::Stmt& stmt, size_t max_revisits) {
ControlFlowGraph::ControlFlowGraph(const tir::Stmt& stmt, size_t max_revisits)
: max_revisits_(max_revisits) {
ControlFlowGraphBuilder::Build(this, stmt);
ForwardPropagateKnownValues(max_revisits);
BackwardPropagateUnusedValues(max_revisits);
ForwardPropagateKnownValues();
BackwardPropagateUnusedValues();
}

void ControlFlowGraph::RemoveStore(const tir::BufferStore& store) {
size_t context_index = [&]() {
auto it = control_flow_lookup_.find(store.get());
ICHECK(it != control_flow_lookup_.end())
<< "BufferStore did not occur in the Stmt provided to BufferTouchPattern's constructor";
return it->second;
}();

auto& touch_points = control_flow_[context_index].touch_points;

touch_points.erase(std::remove_if(touch_points.begin(), touch_points.end(),
[](const BufferTouch& touch) {
return touch.touch_type == BufferTouch::AccessType::Write;
}),
touch_points.end());
ForwardPropagateKnownValues(context_index);
BackwardPropagateUnusedValues(context_index);
}

std::ostream& operator<<(std::ostream& os, const ControlFlowGraph::ControlFlowEdge& edge) {
Expand Down Expand Up @@ -1327,33 +1348,38 @@ Array<Var> ControlFlowGraph::GetIndexVariables(const Buffer& buf, const Array<Pr
return vars;
}

void ControlFlowGraph::ForwardPropagateKnownValues(size_t max_revisits) {
void ControlFlowGraph::ForwardPropagateKnownValues(std::optional<size_t> flow_from) {
// Values to visit when searching. Using a std::set to
// preferentially visit nodes near the start of the control flow.
std::set<size_t> to_visit;

// Map from a block's index
std::unordered_map<size_t, size_t> visit_count_lookup;

// Initiatize the locations to search from, propagating values
// forward from all locations that have a known value.
for (size_t i = 0; i < control_flow_.size(); i++) {
bool has_known_value = false;
for (const auto& touch : control_flow_[i].touch_points) {
if (!HasBufferLoad(touch.value)) {
has_known_value = true;
break;
if (flow_from.has_value()) {
to_visit.insert(flow_from.value());
} else {
// Initiatize the locations to search from, propagating values
// forward from all locations that have a known value.
for (size_t i = 0; i < control_flow_.size(); i++) {
bool has_known_value = false;
for (const auto& touch : control_flow_[i].touch_points) {
if (!HasBufferLoad(touch.value)) {
has_known_value = true;
break;
}
}
}

if (has_known_value) {
to_visit.insert(i);
if (has_known_value) {
to_visit.insert(i);
}
}
}

// Map from a block's index
std::unordered_map<size_t, size_t> visit_count_lookup;

Analyzer analyzer;
analyzer.rewrite_simplify.SetEnabledExtensions(arith::RewriteSimplifier::Extension(
arith::RewriteSimplifier::kTransitivelyProveInequalities |
arith::RewriteSimplifier::kConvertBooleanToAndOfOrs |
arith::RewriteSimplifier::kApplyConstraintsToBooleanBranches));

analyzer.Bind(iterator_ranges_);
Expand All @@ -1369,7 +1395,7 @@ void ControlFlowGraph::ForwardPropagateKnownValues(size_t max_revisits) {

// Step 1: Collect known values provided from each predecessor
block.known_at_block_start = [&]() -> BufferState {
if (num_previous_visits >= max_revisits) {
if (num_previous_visits >= max_revisits_) {
return BufferState();
}

Expand Down Expand Up @@ -1437,7 +1463,7 @@ void ControlFlowGraph::ForwardPropagateKnownValues(size_t max_revisits) {

// Step 2: Collect knowns provided as a result of executing this block
auto post_state = [&]() {
if (num_previous_visits >= max_revisits) {
if (num_previous_visits >= max_revisits_) {
return BufferState();
}
auto post_state = block.known_at_block_start;
Expand All @@ -1459,29 +1485,35 @@ void ControlFlowGraph::ForwardPropagateKnownValues(size_t max_revisits) {
}
}

void ControlFlowGraph::BackwardPropagateUnusedValues(size_t max_revisits) {
void ControlFlowGraph::BackwardPropagateUnusedValues(std::optional<size_t> flow_from) {
// Values to visit when searching. Using a std::set to
// preferentially visit nodes near the end of the control flow.
std::set<size_t> to_visit;

// Map from a block's index
std::unordered_map<size_t, size_t> visit_count_lookup;

// Initiatize the locations to search from, propagating values
// backward from anywhere that performs a write.
for (size_t i = 0; i < control_flow_.size(); i++) {
const auto& touch_points = control_flow_[i].touch_points;
bool performs_write = std::any_of(
touch_points.begin(), touch_points.end(),
[](const auto& touch) { return touch.touch_type == BufferTouch::AccessType::Write; });
if (performs_write) {
to_visit.insert(i);
if (flow_from.has_value()) {
to_visit.insert(flow_from.value());
} else {
// Initiatize the locations to search from, propagating values
// backward from anywhere that performs a write.
for (size_t i = 0; i < control_flow_.size(); i++) {
const auto& touch_points = control_flow_[i].touch_points;
bool performs_write = std::any_of(
touch_points.begin(), touch_points.end(),
[](const auto& touch) { return touch.touch_type == BufferTouch::AccessType::Write; });
if (performs_write) {
to_visit.insert(i);
}
}
}

// Map from a block's index
std::unordered_map<size_t, size_t> visit_count_lookup;

Analyzer analyzer;
analyzer.rewrite_simplify.SetEnabledExtensions(
arith::RewriteSimplifier::kTransitivelyProveInequalities);
analyzer.rewrite_simplify.SetEnabledExtensions(arith::RewriteSimplifier::Extension(
arith::RewriteSimplifier::kTransitivelyProveInequalities |
arith::RewriteSimplifier::kConvertBooleanToAndOfOrs |
arith::RewriteSimplifier::kApplyConstraintsToBooleanBranches));

analyzer.Bind(iterator_ranges_);
analyzer.Bind(free_predicate_parameters_);
Expand All @@ -1496,7 +1528,7 @@ void ControlFlowGraph::BackwardPropagateUnusedValues(size_t max_revisits) {

// Step 1: Collect known unused indices provided by each successor
block.unused_at_block_end = [&]() -> BufferState {
if (num_previous_visits >= max_revisits) {
if (num_previous_visits >= max_revisits_) {
return BufferState();
}
ICHECK_LE(block.successors.size(), 2)
Expand Down Expand Up @@ -1561,7 +1593,7 @@ void ControlFlowGraph::BackwardPropagateUnusedValues(size_t max_revisits) {

// Step 2: Collect knowns provided as a result of executing this block
auto unused_at_block_start = [&]() {
if (num_previous_visits >= max_revisits) {
if (num_previous_visits >= max_revisits_) {
return BufferState();
}
auto prior_state = block.unused_at_block_end;
Expand Down Expand Up @@ -1603,8 +1635,10 @@ bool ControlFlowGraph::IsOverwrittenWithoutEffect(const tir::BufferStore& store,
local_analyzer.Bind(free_predicate_parameters_);
local_analyzer.Bind(iterator_ranges_);
local_analyzer.Bind(free_params);
local_analyzer.rewrite_simplify.SetEnabledExtensions(
RewriteSimplifier::kTransitivelyProveInequalities);
local_analyzer.rewrite_simplify.SetEnabledExtensions(arith::RewriteSimplifier::Extension(
arith::RewriteSimplifier::kTransitivelyProveInequalities |
arith::RewriteSimplifier::kConvertBooleanToAndOfOrs |
arith::RewriteSimplifier::kApplyConstraintsToBooleanBranches));

PrimExpr predicate = store_touch.predicate && store_touch.AtLoopIteration();

Expand All @@ -1630,13 +1664,16 @@ PrimExpr ControlFlowGraph::SimplifyInContext(PrimExpr expr, const tir::Stmt& con
return it->second;
}();

const auto& control_flow_block = control_flow_[context_index];

PrimExpr constraint = Bool(true);
for (const auto& known : non_buffer_assumptions_) {
constraint = constraint && known;
}
With<ConstraintContext> constraint_context(analyzer, constraint);
With<ConstraintContext> control_flow_scope(analyzer, control_flow_block.scope_predicate);

expr = control_flow_[context_index].known_at_block_start.SubstituteKnownBufferValues(
expr = control_flow_block.known_at_block_start.SubstituteKnownBufferValues(
std::move(expr), axis_var_lookup_, analyzer);

expr = analyzer->Simplify(std::move(expr));
Expand Down
12 changes: 10 additions & 2 deletions src/tir/analysis/control_flow_graph.h
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@
#include <tvm/tir/stmt.h>
#include <tvm/tir/var.h>

#include <optional>
#include <unordered_map>
#include <utility>
#include <vector>
Expand Down Expand Up @@ -474,13 +475,17 @@ class ControlFlowGraph {

/*! \brief Propagate known values from known BufferStore/assume
* subsequent control flow blocks
*
* \param flow_from If specified, re-flow only from that block.
*/
void ForwardPropagateKnownValues(size_t max_revisits);
void ForwardPropagateKnownValues(std::optional<size_t> flow_from = std::nullopt);

/*! \brief Propagate overwritten/unused indices to preceding control
* flow blocks
*
* \param flow_from If specified, re-flow only from that block.
*/
void BackwardPropagateUnusedValues(size_t max_revisits);
void BackwardPropagateUnusedValues(std::optional<size_t> flow_from = std::nullopt);

struct ControlFlowEdge {
/* \brief The source block of the control flow edge
Expand Down Expand Up @@ -646,6 +651,9 @@ class ControlFlowGraph {
std::vector<PrimExpr> non_buffer_assumptions_;

friend class ControlFlowGraphBuilder;

/*! \brief The maximum number of revisits while flowing constraints */
size_t max_revisits_;
};

} // namespace tir
Expand Down
Loading