diff --git a/src/arith/rewrite_simplify.cc b/src/arith/rewrite_simplify.cc index e6d876cf5aa8..90c448f4ea5c 100644 --- a/src/arith/rewrite_simplify.cc +++ b/src/arith/rewrite_simplify.cc @@ -1644,6 +1644,11 @@ PrimExpr RewriteSimplifier::Impl::ApplyRewriteRules(LT ret) { TVM_TRY_RECURSIVE_REWRITE(x + c1 < c2, x < c2 - c1); TVM_TRY_RECURSIVE_REWRITE(x - c1 < c2, x < c2 + c1); TVM_TRY_REWRITE(x - c1 < 0, x < c1); + + TVM_TRY_RECURSIVE_REWRITE(x - 1 < y, x <= y); + TVM_TRY_RECURSIVE_REWRITE(x < y + 1, x <= y); + TVM_TRY_RECURSIVE_REWRITE(x + (-1) < y, x <= y); + TVM_TRY_RECURSIVE_REWRITE(x < y - (-1), x <= y); // clang-format on } return std::move(ret); @@ -1886,6 +1891,8 @@ PrimExpr RewriteSimplifier::Impl::VisitExpr_(const OrNode* op) { TVM_TRY_REWRITE(x <= y || y < x, ctrue); TVM_TRY_REWRITE(y < x || x <= y, ctrue); + TVM_TRY_REWRITE(x < y || y < x, x != y); + TVM_TRY_REWRITE_IF(x < c1 || c2 < x, ctrue, c2.Eval()->value < c1.Eval()->value); TVM_TRY_REWRITE_IF(c2 < x || x < c1, ctrue, c2.Eval()->value < c1.Eval()->value); diff --git a/src/tir/analysis/control_flow_graph.cc b/src/tir/analysis/control_flow_graph.cc index 42c5c8bb82d5..2e537450d232 100644 --- a/src/tir/analysis/control_flow_graph.cc +++ b/src/tir/analysis/control_flow_graph.cc @@ -31,6 +31,7 @@ #include #include +#include #include #include #include @@ -819,10 +820,30 @@ BufferTouch ControlFlowGraph::ControlFlowBlock::MakeBufferTouch(ControlFlowGraph return buffer_touch; } -ControlFlowGraph::ControlFlowGraph(const tir::Stmt& stmt, size_t max_revisits) { +ControlFlowGraph::ControlFlowGraph(const tir::Stmt& stmt, size_t max_revisits) + : max_revisits_(max_revisits) { ControlFlowGraphBuilder::Build(this, stmt); - ForwardPropagateKnownValues(max_revisits); - BackwardPropagateUnusedValues(max_revisits); + ForwardPropagateKnownValues(); + BackwardPropagateUnusedValues(); +} + +void ControlFlowGraph::RemoveStore(const tir::BufferStore& store) { + size_t context_index = [&]() { + auto it = control_flow_lookup_.find(store.get()); + ICHECK(it != control_flow_lookup_.end()) + << "BufferStore did not occur in the Stmt provided to BufferTouchPattern's constructor"; + return it->second; + }(); + + auto& touch_points = control_flow_[context_index].touch_points; + + touch_points.erase(std::remove_if(touch_points.begin(), touch_points.end(), + [](const BufferTouch& touch) { + return touch.touch_type == BufferTouch::AccessType::Write; + }), + touch_points.end()); + ForwardPropagateKnownValues(context_index); + BackwardPropagateUnusedValues(context_index); } std::ostream& operator<<(std::ostream& os, const ControlFlowGraph::ControlFlowEdge& edge) { @@ -1327,33 +1348,38 @@ Array ControlFlowGraph::GetIndexVariables(const Buffer& buf, const Array flow_from) { // Values to visit when searching. Using a std::set to // preferentially visit nodes near the start of the control flow. std::set to_visit; - // Map from a block's index - std::unordered_map visit_count_lookup; - - // Initiatize the locations to search from, propagating values - // forward from all locations that have a known value. - for (size_t i = 0; i < control_flow_.size(); i++) { - bool has_known_value = false; - for (const auto& touch : control_flow_[i].touch_points) { - if (!HasBufferLoad(touch.value)) { - has_known_value = true; - break; + if (flow_from.has_value()) { + to_visit.insert(flow_from.value()); + } else { + // Initiatize the locations to search from, propagating values + // forward from all locations that have a known value. + for (size_t i = 0; i < control_flow_.size(); i++) { + bool has_known_value = false; + for (const auto& touch : control_flow_[i].touch_points) { + if (!HasBufferLoad(touch.value)) { + has_known_value = true; + break; + } } - } - if (has_known_value) { - to_visit.insert(i); + if (has_known_value) { + to_visit.insert(i); + } } } + // Map from a block's index + std::unordered_map visit_count_lookup; + Analyzer analyzer; analyzer.rewrite_simplify.SetEnabledExtensions(arith::RewriteSimplifier::Extension( arith::RewriteSimplifier::kTransitivelyProveInequalities | + arith::RewriteSimplifier::kConvertBooleanToAndOfOrs | arith::RewriteSimplifier::kApplyConstraintsToBooleanBranches)); analyzer.Bind(iterator_ranges_); @@ -1369,7 +1395,7 @@ void ControlFlowGraph::ForwardPropagateKnownValues(size_t max_revisits) { // Step 1: Collect known values provided from each predecessor block.known_at_block_start = [&]() -> BufferState { - if (num_previous_visits >= max_revisits) { + if (num_previous_visits >= max_revisits_) { return BufferState(); } @@ -1437,7 +1463,7 @@ void ControlFlowGraph::ForwardPropagateKnownValues(size_t max_revisits) { // Step 2: Collect knowns provided as a result of executing this block auto post_state = [&]() { - if (num_previous_visits >= max_revisits) { + if (num_previous_visits >= max_revisits_) { return BufferState(); } auto post_state = block.known_at_block_start; @@ -1459,29 +1485,35 @@ void ControlFlowGraph::ForwardPropagateKnownValues(size_t max_revisits) { } } -void ControlFlowGraph::BackwardPropagateUnusedValues(size_t max_revisits) { +void ControlFlowGraph::BackwardPropagateUnusedValues(std::optional flow_from) { // Values to visit when searching. Using a std::set to // preferentially visit nodes near the end of the control flow. std::set to_visit; - // Map from a block's index - std::unordered_map visit_count_lookup; - - // Initiatize the locations to search from, propagating values - // backward from anywhere that performs a write. - for (size_t i = 0; i < control_flow_.size(); i++) { - const auto& touch_points = control_flow_[i].touch_points; - bool performs_write = std::any_of( - touch_points.begin(), touch_points.end(), - [](const auto& touch) { return touch.touch_type == BufferTouch::AccessType::Write; }); - if (performs_write) { - to_visit.insert(i); + if (flow_from.has_value()) { + to_visit.insert(flow_from.value()); + } else { + // Initiatize the locations to search from, propagating values + // backward from anywhere that performs a write. + for (size_t i = 0; i < control_flow_.size(); i++) { + const auto& touch_points = control_flow_[i].touch_points; + bool performs_write = std::any_of( + touch_points.begin(), touch_points.end(), + [](const auto& touch) { return touch.touch_type == BufferTouch::AccessType::Write; }); + if (performs_write) { + to_visit.insert(i); + } } } + // Map from a block's index + std::unordered_map visit_count_lookup; + Analyzer analyzer; - analyzer.rewrite_simplify.SetEnabledExtensions( - arith::RewriteSimplifier::kTransitivelyProveInequalities); + analyzer.rewrite_simplify.SetEnabledExtensions(arith::RewriteSimplifier::Extension( + arith::RewriteSimplifier::kTransitivelyProveInequalities | + arith::RewriteSimplifier::kConvertBooleanToAndOfOrs | + arith::RewriteSimplifier::kApplyConstraintsToBooleanBranches)); analyzer.Bind(iterator_ranges_); analyzer.Bind(free_predicate_parameters_); @@ -1496,7 +1528,7 @@ void ControlFlowGraph::BackwardPropagateUnusedValues(size_t max_revisits) { // Step 1: Collect known unused indices provided by each successor block.unused_at_block_end = [&]() -> BufferState { - if (num_previous_visits >= max_revisits) { + if (num_previous_visits >= max_revisits_) { return BufferState(); } ICHECK_LE(block.successors.size(), 2) @@ -1561,7 +1593,7 @@ void ControlFlowGraph::BackwardPropagateUnusedValues(size_t max_revisits) { // Step 2: Collect knowns provided as a result of executing this block auto unused_at_block_start = [&]() { - if (num_previous_visits >= max_revisits) { + if (num_previous_visits >= max_revisits_) { return BufferState(); } auto prior_state = block.unused_at_block_end; @@ -1603,8 +1635,10 @@ bool ControlFlowGraph::IsOverwrittenWithoutEffect(const tir::BufferStore& store, local_analyzer.Bind(free_predicate_parameters_); local_analyzer.Bind(iterator_ranges_); local_analyzer.Bind(free_params); - local_analyzer.rewrite_simplify.SetEnabledExtensions( - RewriteSimplifier::kTransitivelyProveInequalities); + local_analyzer.rewrite_simplify.SetEnabledExtensions(arith::RewriteSimplifier::Extension( + arith::RewriteSimplifier::kTransitivelyProveInequalities | + arith::RewriteSimplifier::kConvertBooleanToAndOfOrs | + arith::RewriteSimplifier::kApplyConstraintsToBooleanBranches)); PrimExpr predicate = store_touch.predicate && store_touch.AtLoopIteration(); @@ -1630,13 +1664,16 @@ PrimExpr ControlFlowGraph::SimplifyInContext(PrimExpr expr, const tir::Stmt& con return it->second; }(); + const auto& control_flow_block = control_flow_[context_index]; + PrimExpr constraint = Bool(true); for (const auto& known : non_buffer_assumptions_) { constraint = constraint && known; } With constraint_context(analyzer, constraint); + With control_flow_scope(analyzer, control_flow_block.scope_predicate); - expr = control_flow_[context_index].known_at_block_start.SubstituteKnownBufferValues( + expr = control_flow_block.known_at_block_start.SubstituteKnownBufferValues( std::move(expr), axis_var_lookup_, analyzer); expr = analyzer->Simplify(std::move(expr)); diff --git a/src/tir/analysis/control_flow_graph.h b/src/tir/analysis/control_flow_graph.h index aa9023ba29dd..590392cf658a 100644 --- a/src/tir/analysis/control_flow_graph.h +++ b/src/tir/analysis/control_flow_graph.h @@ -29,6 +29,7 @@ #include #include +#include #include #include #include @@ -474,13 +475,17 @@ class ControlFlowGraph { /*! \brief Propagate known values from known BufferStore/assume * subsequent control flow blocks + * + * \param flow_from If specified, re-flow only from that block. */ - void ForwardPropagateKnownValues(size_t max_revisits); + void ForwardPropagateKnownValues(std::optional flow_from = std::nullopt); /*! \brief Propagate overwritten/unused indices to preceding control * flow blocks + * + * \param flow_from If specified, re-flow only from that block. */ - void BackwardPropagateUnusedValues(size_t max_revisits); + void BackwardPropagateUnusedValues(std::optional flow_from = std::nullopt); struct ControlFlowEdge { /* \brief The source block of the control flow edge @@ -646,6 +651,9 @@ class ControlFlowGraph { std::vector non_buffer_assumptions_; friend class ControlFlowGraphBuilder; + + /*! \brief The maximum number of revisits while flowing constraints */ + size_t max_revisits_; }; } // namespace tir diff --git a/src/tir/transforms/remove_no_op.cc b/src/tir/transforms/remove_no_op.cc index 41250408a7f2..71faca9308d2 100644 --- a/src/tir/transforms/remove_no_op.cc +++ b/src/tir/transforms/remove_no_op.cc @@ -29,21 +29,71 @@ #include #include +#include #include #include "../../arith/const_fold.h" +#include "../../arith/ir_mutator_with_analyzer.h" +#include "../analysis/control_flow_graph.h" #include "ir_utils.h" namespace tvm { namespace tir { +struct RemoveNoOpConfigNode : public tvm::AttrsNode { + bool use_dataflow_analysis; + + TVM_DECLARE_ATTRS(RemoveNoOpConfigNode, "tir.transform.RemoveNoOpConfig") { + TVM_ATTR_FIELD(use_dataflow_analysis) + .describe( + "If true, known buffer values are propagated and used " + "to statically prove statements as no-ops.") + .set_default(false); + } +}; + +class RemoveNoOpConfig : public Attrs { + public: + TVM_DEFINE_NOTNULLABLE_OBJECT_REF_METHODS(RemoveNoOpConfig, Attrs, RemoveNoOpConfigNode); +}; + +TVM_REGISTER_NODE_TYPE(RemoveNoOpConfigNode); +TVM_REGISTER_PASS_CONFIG_OPTION("tir.RemoveNoOp", RemoveNoOpConfig); + // Mark the statement of each stage. -class NoOpRemover : public StmtMutator { +class NoOpRemover : public arith::IRMutatorWithAnalyzer { public: + static Stmt Apply(Stmt stmt, arith::Analyzer* analyzer, + std::optional touch_pattern, const StmtNode* context) { + NoOpRemover visitor(analyzer, touch_pattern, context); + return visitor(std::move(stmt)); + } + + private: + using Parent = IRMutatorWithAnalyzer; + using Parent::VisitStmt; + using Parent::VisitStmt_; + + NoOpRemover(arith::Analyzer* analyzer, std::optional touch_pattern, + const StmtNode* context) + : Parent(analyzer), touch_pattern_(touch_pattern), context_(context) {} + Stmt VisitStmt_(const LetStmtNode* op) final { - Stmt stmt = StmtMutator::VisitStmt_(op); + Stmt stmt = Parent::VisitStmt_(op); op = stmt.as(); - return is_no_op(op->body) ? MakeEvaluate(op->value) : stmt; + if (is_no_op(op->body)) { + return MakeEvaluate(op->value); + } + + bool body_uses_bound_variable = + !UsesVar(op->body, [&](const VarNode* var) { return var == op->var.get(); }); + if (body_uses_bound_variable && HasSideEffect(op->value)) { + return SeqStmt({MakeEvaluate(op->value), op->body}); + } else if (body_uses_bound_variable) { + return op->body; + } else { + return stmt; + } } Stmt VisitStmt_(const AttrStmtNode* op) final { if (op->attr_key == "pragma_debug_skip_region") { @@ -58,24 +108,26 @@ class NoOpRemover : public StmtMutator { // We assume that such wait is a nop. auto inner = op->body.as(); ICHECK(inner); - return StmtMutator::VisitStmt(inner->body); + return Parent::VisitStmt(inner->body); } } - Stmt stmt = StmtMutator::VisitStmt_(op); + Stmt stmt = Parent::VisitStmt_(op); op = stmt.as(); return is_no_op(op->body) ? MakeEvaluate(op->value) : stmt; } Stmt VisitStmt_(const IfThenElseNode* op) final { - Stmt stmt = StmtMutator::VisitStmt_(op); + Stmt stmt = Parent::VisitStmt_(op); op = stmt.as(); if (op->else_case) { - if (is_no_op(op->else_case.value())) { - if (is_no_op(op->then_case)) { - return MakeEvaluate(op->condition); - } else { - return IfThenElse(op->condition, op->then_case); - } + bool no_op_else = is_no_op(op->else_case.value()); + bool no_op_then = is_no_op(op->then_case); + if (no_op_else && no_op_then) { + return MakeEvaluate(op->condition); + } else if (no_op_else) { + return IfThenElse(op->condition, op->then_case); + } else if (no_op_then) { + return IfThenElse(!op->condition, op->else_case.value()); } else { return stmt; } @@ -91,10 +143,10 @@ class NoOpRemover : public StmtMutator { var_range_map_[op->loop_var.get()] = arith::IntSet::FromMinExtent(op->min, op->extent); auto extent_range = arith::EvalSet(op->extent, var_range_map_); if (!arith::is_neg_inf(extent_range.max()) && !arith::is_pos_inf(extent_range.max()) && - analyzer_.CanProve(extent_range.max() <= 0)) { + analyzer_->CanProve(extent_range.max() <= 0)) { return Evaluate(0); } - Stmt stmt = StmtMutator::VisitStmt_(op); + Stmt stmt = Parent::VisitStmt_(op); var_range_map_.erase(op->loop_var.get()); op = stmt.as(); if (is_zero(op->extent)) { @@ -114,42 +166,104 @@ class NoOpRemover : public StmtMutator { return is_no_op(op->body) ? op->body : stmt; } Stmt VisitStmt_(const EvaluateNode* op) final { - if (SideEffect(op->value) > CallEffectKind::kReadState) return GetRef(op); - return Evaluate(0); + if (HasSideEffect(op->value)) { + return GetRef(op); + } else { + return Evaluate(0); + } } Stmt VisitStmt_(const SeqStmtNode* op) final { - Stmt ret = StmtMutator::VisitSeqStmt_(op, true); - op = ret.as(); - ICHECK(op != nullptr); - bool need_compact = false; - for (size_t i = 0; i < op->size(); ++i) { - if (is_no_op(op->seq[i])) need_compact = true; - } + auto ret = Downcast(StmtMutator::VisitSeqStmt_(op, true)); + + bool need_compact = std::any_of(ret->seq.begin(), ret->seq.end(), + [](const auto& stmt) { return is_no_op(stmt); }); + if (need_compact) { - auto n = CopyOnWrite(op); - size_t top = 0; - for (size_t i = 0; i < n->seq.size(); ++i) { - if (!is_no_op(n->seq[i])) { - n->seq.Set(top++, n->seq[i]); + Array filtered; + for (Stmt stmt : ret->seq) { + if (!is_no_op(stmt)) { + filtered.push_back(std::move(stmt)); } } - if (top == 1) { - return n->seq[0]; - } else { - n->seq.resize(top); - return Stmt(n); - } + ret = SeqStmt(filtered); + } + + if (ret->size() == 0) { + return Evaluate(0); + } else if (ret->size() == 1) { + return ret->seq[0]; } else { - if (op->size() == 1) { - return op->seq[0]; - } else { - return ret; + return std::move(ret); + } + } + + Stmt VisitStmt_(const BufferStoreNode* op) final { + BufferStore store = GetRef(op); + + // Helper function that returns a statement containing only the + // side effects of evaluating this BufferStore, but not the store + // itself. + auto only_side_effects = [&]() { + Array statements; + statements.push_back(MakeEvaluate(store->value)); + for (const auto& index : store->indices) { + statements.push_back(MakeEvaluate(index)); + } + return this->VisitStmt(SeqStmt(statements)); + }; + + if (touch_pattern_.has_value()) { + // A write that is later overwritten is a no-op. + Stmt context = context_ ? GetRef(context_) : store; + if (touch_pattern_->IsOverwrittenWithoutEffect(store, context)) { + touch_pattern_->RemoveStore(store); + return only_side_effects(); + } + + // A write whose destination is known to already contain the + // values to be written is a no-op. + PrimExpr stores_existing_value = store->value == BufferLoad(store->buffer, store->indices); + + PrimExpr simplified = + touch_pattern_->SimplifyInContext(stores_existing_value, context, analyzer_); + if (auto* as_int = as_const_int(simplified); as_int && *as_int) { + return only_side_effects(); + } + } + + // If the stored value is a load from the same location, the + // statement is a no-op, regardless of contextual information. + if (const BufferLoadNode* load = store->value.as()) { + if (load->buffer->data.same_as(store->buffer->data) && + analyzer_->CanProveEqual(load->buffer->elem_offset, store->buffer->elem_offset) && + ArrayValueEqual(load->buffer->shape, store->buffer->shape) && + ArrayValueEqual(load->buffer->strides, store->buffer->strides) && + ArrayValueEqual(load->indices, store->indices)) { + return only_side_effects(); } } + + return std::move(store); } private: + bool ArrayValueEqual(const Array& a, const Array& b) { + if (a.size() != b.size()) { + return false; + } + for (size_t i = 0; i < a.size(); i++) { + if (!analyzer_->CanProveEqual(a[i], b[i])) { + return false; + } + } + return true; + } + + bool HasSideEffect(const PrimExpr& value) { + return SideEffect(value) > CallEffectKind::kReadState; + } + Stmt MakeEvaluate(PrimExpr value) { if (SideEffect(value) > CallEffectKind::kReadState) { return Evaluate(value); @@ -158,31 +272,47 @@ class NoOpRemover : public StmtMutator { } } Stmt MakeEvaluate(const Array& values) { - Stmt stmt; + Array stmts; for (PrimExpr e : values) { if (SideEffect(e) > CallEffectKind::kReadState) { - if (stmt.defined()) { - stmt = SeqStmt({stmt, Evaluate(e)}); - } else { - stmt = Evaluate(e); - } + stmts.push_back(Evaluate(e)); } } - return stmt.defined() ? stmt : Evaluate(0); + + if (stmts.size() == 0) { + return Evaluate(0); + } else if (stmts.size() == 1) { + return stmts[0]; + } else { + return SeqStmt(stmts); + } } std::unordered_map var_range_map_; - arith::Analyzer analyzer_; + std::optional touch_pattern_; + const StmtNode* context_; }; -Stmt RemoveNoOp(Stmt stmt) { return NoOpRemover()(std::move(stmt)); } - namespace transform { Pass RemoveNoOp() { auto pass_func = [](PrimFunc f, IRModule m, PassContext ctx) { + std::optional touch_pattern = std::nullopt; + + RemoveNoOpConfig config = ctx->GetConfig("tir.RemoveNoOp") + .value_or(AttrsWithDefaultValues()); + if (config->use_dataflow_analysis) { + touch_pattern.emplace(f->body); + } + + arith::Analyzer analyzer; + analyzer.rewrite_simplify.SetEnabledExtensions(arith::RewriteSimplifier::Extension( + arith::RewriteSimplifier::kTransitivelyProveInequalities | + arith::RewriteSimplifier::kConvertBooleanToAndOfOrs | + arith::RewriteSimplifier::kApplyConstraintsToBooleanBranches)); + auto* n = f.CopyOnWrite(); - n->body = NoOpRemover()(std::move(n->body)); + n->body = NoOpRemover::Apply(std::move(n->body), &analyzer, std::move(touch_pattern), nullptr); return f; }; return CreatePrimFuncPass(pass_func, 0, "tir.RemoveNoOp", {}); diff --git a/tests/python/unittest/test_tir_transform_remove_no_op.py b/tests/python/unittest/test_tir_transform_remove_no_op.py index 820e32eb7e72..ce37329b7ed3 100644 --- a/tests/python/unittest/test_tir_transform_remove_no_op.py +++ b/tests/python/unittest/test_tir_transform_remove_no_op.py @@ -19,6 +19,8 @@ from tvm.script import tir as T import tvm.testing +import pytest + def nop(): return tvm.tir.Evaluate(0) @@ -82,5 +84,524 @@ def main(A: T.Buffer[(16), "int32"], B: T.Buffer[(16), "int32"]) -> None: assert isinstance(ret, tvm.tir.Evaluate) +class BaseBeforeAfter(tvm.testing.CompareBeforeAfter): + use_dataflow_analysis = False + + def transform(self): + def inner(mod): + config = { + "tir.RemoveNoOp": { + "use_dataflow_analysis": self.use_dataflow_analysis, + } + } + with tvm.transform.PassContext(config=config): + mod = tvm.tir.transform.RemoveNoOp()(mod) + return mod + + return inner + + +class TestRemoveEmptyForLoop(BaseBeforeAfter): + """A for-loop whose body is a no-op is itself a no-op.""" + + def before(): + for i in T.serial(16): + T.evaluate(0) + + def expected(): + T.evaluate(0) + + +class TestRemoveZeroExtentLoop(BaseBeforeAfter): + """A for-loop with no extent is a no-op.""" + + def before(A: T.Buffer[16, "int32"]): + for i in T.serial(0): + A[i] = 42 + + def expected(A: T.Buffer[16, "int32"]): + T.evaluate(0) + + +class TestRemoveUnusedLet(BaseBeforeAfter): + """A let statement that is never used is a no-op.""" + + def before(A: T.Buffer[16, "int32"]): + x = 5 + for i in T.serial(16): + A[i] = 0 + + def expected(A: T.Buffer[16, "int32"]): + for i in T.serial(16): + A[i] = 0 + + +class TestRemoveLetUsedOnlyInNoOp(BaseBeforeAfter): + """A let statement that is never used is a no-op. + + Similar to TestRemoveUnusedLet, but the usage of the let binding + may have been removed by an earlier removal of another no-op. + """ + + def before(A: T.Buffer[16, "int32"]): + x = 5 + for i in T.serial(0): + A[i] = x + + def expected(A: T.Buffer[16, "int32"]): + T.evaluate(0) + + +class TestKeepSideEffectsOfLet(BaseBeforeAfter): + """The side effects of a no-op let must be kept.""" + + def before(): + x = T.call_extern("extern_func", dtype="int32") + T.evaluate(0) + + def expected(): + T.evaluate(T.call_extern("extern_func", dtype="int32")) + + +class TestRemoveEmptyThenCase(BaseBeforeAfter): + """A no-op then_case can be removed.""" + + def before(A: T.Buffer[16, "int32"]): + for i in T.serial(16): + if i < 8: + T.evaluate(0) + else: + A[i] = 42 + + def expected(A: T.Buffer[16, "int32"]): + for i in T.serial(16): + if not (i < 8): + A[i] = 42 + + +class TestRemoveEmptyElseCase(BaseBeforeAfter): + """A no-op else_case can be removed.""" + + def before(A: T.Buffer[16, "int32"]): + for i in T.serial(16): + if i < 8: + A[i] = 42 + else: + T.evaluate(0) + + def expected(A: T.Buffer[16, "int32"]): + for i in T.serial(16): + if i < 8: + A[i] = 42 + + +class TestRemoveUnusedWrite(BaseBeforeAfter): + """For two sequential writes, the first is a no-op""" + + use_dataflow_analysis = True + + def before(A: T.Buffer[16, "int32"]): + for i in T.serial(16): + A[i] = 100 + A[i] = 42 + + def expected(A: T.Buffer[16, "int32"]): + for i in T.serial(16): + A[i] = 42 + + +class TestSuppressRemovalOfUnusedWrite(BaseBeforeAfter): + """Dataflow analysis requires the config to opt-in + + Like TestRemoveUnusedWrite, but dataflow analysis isn't enabled. + """ + + use_dataflow_analysis = False + + def before(A: T.Buffer[16, "int32"]): + for i in T.serial(16): + A[i] = 100 + A[i] = 42 + + expected = before + + +class TestKeepSideEffectsOfUnusedWrite(BaseBeforeAfter): + """For two sequential writes, the first value may have side effects""" + + use_dataflow_analysis = True + + def before(A: T.Buffer[16, "int32"]): + for i in T.serial(16): + A[i] = T.call_extern("extern_func", dtype="int32") + A[i] = 42 + + def expected(A: T.Buffer[16, "int32"]): + for i in T.serial(16): + T.evaluate(T.call_extern("extern_func", dtype="int32")) + A[i] = 42 + + +class TestKeepFirstWriteWhenUsed(BaseBeforeAfter): + """For two sequential writes, keep the first if it is used""" + + def before(A: T.Buffer[16, "int32"]): + for i in T.serial(16): + A[i] = 100 + A[i] = A[i] + 1 + + expected = before + + +class TestRemoveOverwrittenLoop(BaseBeforeAfter): + """Remove repeated writes to the same region + + If two loops write to the same region, the first is a no-op. + """ + + use_dataflow_analysis = True + + def before(A: T.Buffer[16, "int32"]): + for i in T.serial(16): + A[i] = 100 + + for i in T.serial(16): + A[i] = 42 + + def expected(A: T.Buffer[16, "int32"]): + for i in T.serial(16): + A[i] = 42 + + +class TestRemoveOverwrittenSubloop(BaseBeforeAfter): + """Remove repeated writes to the same region + + If the first loop writes to a subset of the region, the first loop + is a no-op. Similar to TestRemoveOverwrittenLoop, but the first + loop's extents are a subset of the second loop. + """ + + use_dataflow_analysis = True + + def before(A: T.Buffer[16, "int32"]): + for i in T.serial(4, 12): + A[i] = 100 + + for i in T.serial(16): + A[i] = 42 + + def expected(A: T.Buffer[16, "int32"]): + for i in T.serial(16): + A[i] = 42 + + +class TestKeepPartiallyOverwrittenLoop(BaseBeforeAfter): + """Keep partially overwritten regions + + If the second loop doesn't entirely overwrite the first, the first + may not be removed be kept. + """ + + def before(A: T.Buffer[16, "int32"]): + for i in T.serial(16): + A[i] = 100 + + for i in T.serial(16): + if i < 12: + A[i] = 42 + + expected = before + + +class TestRemoveOverwrittenPredicatedLoopWithIdenticalCondition(BaseBeforeAfter): + """Remove repeated writes to the same predicated region. + + Similar to TestKeepPartiallyOverwrittenLoop, except the first loop + has the same predicate as the second, and can therefore be + removed. + """ + + use_dataflow_analysis = True + + def before(A: T.Buffer[16, "int32"]): + for i in T.serial(16): + if i < 12: + A[i] = 100 + + for i in T.serial(16): + if i < 12: + A[i] = 42 + + def expected(A: T.Buffer[16, "int32"]): + for i in T.serial(16): + if i < 12: + A[i] = 42 + + +class TestRemoveOverwrittenPredicatedLoopWithProvableCondition(BaseBeforeAfter): + """Remove repeated writes to the same predicated region. + + Similar to + TestRemoveOverwrittenPredicatedLoopWithIdenticalCondition, except + the first loop's predicate is not a precise match for the second + loop's predicate. So long as the regions written in the first + loop are a subset of those written in the second loop, they can be + removed. + """ + + use_dataflow_analysis = True + + def before(A: T.Buffer[16, "int32"]): + for i in T.serial(16): + if i < 10: + A[i] = 100 + + for i in T.serial(16): + if i // 4 < 3: + A[i] = 42 + + def expected(A: T.Buffer[16, "int32"]): + for i in T.serial(16): + if i // 4 < 3: + A[i] = 42 + + +class TestRemoveSeparatedOverwrites(BaseBeforeAfter): + """Remove repeated writes to the same predicated region. + + Similar to TestRemoveOverwrittenLoopRegion, but with an + independent loop between the first and second write of the buffer. + """ + + use_dataflow_analysis = True + + def before(A: T.Buffer[16, "int32"], B: T.Buffer[16, "int32"]): + for i in T.serial(16): + A[i] = 100 + + for i in T.serial(16): + B[i] = 0 + + for i in T.serial(16): + A[i] = 42 + + def expected(A: T.Buffer[16, "int32"], B: T.Buffer[16, "int32"]): + for i in T.serial(16): + B[i] = 0 + + for i in T.serial(16): + A[i] = 42 + + +@pytest.mark.xfail(reason="Not implemented yet") +class TestRemoveSeparatedOverwriteOfPredicatedLoop(BaseBeforeAfter): + """Remove repeated writes to the same predicated region. + + Similar to TestRemoveSeparatedOverwrites, but the independent loop + between the first and second writes writes to a different subset + of the same buffer. + """ + + use_dataflow_analysis = True + + def before(A: T.Buffer[16, "int32"]): + for i in T.serial(16): + if i < 12: + A[i] = 100 + + for i in T.serial(16): + if i > 12: + A[i] = 15 + + for i in T.serial(16): + if i < 12: + A[i] = 42 + + def expected(A: T.Buffer[16, "int32"]): + for i in T.serial(16): + if i > 12: + A[i] = 15 + + for i in T.serial(16): + if i < 12: + A[i] = 42 + + +class TestRemoveReadWrite(BaseBeforeAfter): + """Writing a value to the same location as was just read is a no-op.""" + + def before(A: T.Buffer[1, "int32"]): + A[0] = A[0] + + def expected(A: T.Buffer[1, "int32"]): + T.evaluate(0) + + +class TestKeepReadWriteToDifferentIndices(BaseBeforeAfter): + """Writing a value to a different index should not be removed""" + + def before(A: T.Buffer[16, "int32"]): + for i in T.serial(15): + A[i] = A[i + 1] + + expected = before + + +class TestRemoveReadWriteSameIndexDifferentExpression(BaseBeforeAfter): + """Writing a value to the same location as the read is a no-op. + + If the value of the index can be proven to be the same, then the + no-op can be removed, even if they have different forms of the + expression. + """ + + def before(A: T.Buffer[16, "int32"]): + for io, ii in T.grid(4, 4): + i = 4 * io + ii + A[4 * io + ii] = A[i] + + def expected(A: T.Buffer[16, "int32"]): + T.evaluate(0) + + +class TestRemoveReadWriteSameIndexUsingConstraint(BaseBeforeAfter): + """Writing a value to the same location as the read is a no-op. + + If the value of the index can be proven to be the same, then the + no-op can be removed. This may require using the a constraint + that is known from a conditional containing the read/write. + """ + + def before(A: T.Buffer[16, "int32"]): + for i in T.serial(16): + if i != 0: + A[i] = A[i - 1] + else: + A[i] = A[0] + + def expected(A: T.Buffer[16, "int32"]): + for i in T.serial(16): + if i != 0: + A[i] = A[i - 1] + + +class TestRemoveWritingOfKnownValue(BaseBeforeAfter): + """Writing a value that already exists at that index is a no-op""" + + use_dataflow_analysis = True + + def before(A: T.Buffer[16, "int32"]): + for i in T.serial(16): + A[i] = i + + A[4] = 4 + + def expected(A: T.Buffer[16, "int32"]): + for i in T.serial(16): + A[i] = i + + +class TestKeepOneOfDuplicateLoops(BaseBeforeAfter): + """Must not reason based on a touch point after removing it. + + If the first loop is removed because it is overwritten by the + second loop, and the second loop is removed because it writes the + same value as the first loop, the overall transformation is no + longer valid. In this case, only one of the two should be + removed. + """ + + use_dataflow_analysis = True + + def before(A: T.Buffer[16, "int32"]): + for i in T.serial(16): + A[i] = i + + for i in T.serial(16): + A[i] = i + + def expected(A: T.Buffer[16, "int32"]): + for i in T.serial(16): + A[i] = i + + +class TestRemoveEmptyTemporary(BaseBeforeAfter): + """An allocation with a no-op body is a no-op.""" + + def before(): + A = T.allocate([16], "int32", "local") + T.evaluate(0) + + def expected(): + T.evaluate(0) + + +@pytest.mark.xfail(reason="Not implemented yet") +class TestRemoveUnusedTemporary(BaseBeforeAfter): + """An unused allocation is a no-op.""" + + def before(A: T.Buffer[16, "int32"]): + B = T.allocate([16], "int32", "local") + for i in T.serial(16): + A[i] = 1 + + def expected(A: T.Buffer[16, "int32"]): + for i in T.serial(16): + A[i] = 1 + + +@pytest.mark.xfail(reason="Not implemented yet") +class TestRemoveUnusedWriteIntoTemporary(BaseBeforeAfter): + """A write that only impacts a temporary allocation is a no-op.""" + + def before(): + A = T.decl_buffer([16], "int32", scope="local") + for i in T.serial(16): + A[i] = 0 + + def expected(): + T.evaluate(0) + + +class TestKeepUsedWriteIntoTemporary(BaseBeforeAfter): + """A write into a temporary that is used later must be kept.""" + + def before(B: T.Buffer[16, "int32"]): + A = T.decl_buffer([16], "int32", scope="local") + for i in T.serial(16): + A[i] = 0 + + for i in T.serial(16): + B[i] = A[i] + + expected = before + + +@pytest.mark.xfail(reason="Not implemented yet") +class TestRemoveWriteIntoTemporary(BaseBeforeAfter): + """A write that only impacts a temporary allocation is a no-op.""" + + def before(A: T.Buffer[16, "int32"], C: T.Buffer[1, "int32"]): + B = T.decl_buffer([16], "int32", scope="local") + for i in T.serial(16): + B[i] = A[i] + + C[0] = 0 + for i in T.serial(16): + C[0] = C[0] + B[i] + + for i in T.serial(16): + B[i] = 0 + + def expected(A: T.Buffer[16, "int32"], C: T.Buffer[1, "int32"]): + B = T.decl_buffer([16], "int32", scope="local") + for i in T.serial(16): + B[i] = A[i] + + C[0] = 0 + for i in T.serial(16): + C[0] = C[0] + B[i] + + if __name__ == "__main__": tvm.testing.main() diff --git a/tests/python/unittest/test_tir_transform_simplify.py b/tests/python/unittest/test_tir_transform_simplify.py index fd98b715a4bc..1ddc0e50d98f 100644 --- a/tests/python/unittest/test_tir_transform_simplify.py +++ b/tests/python/unittest/test_tir_transform_simplify.py @@ -1267,6 +1267,7 @@ class TestSimplifyUsingPartiallyKnownBufferConditional(BaseBeforeAfter): """An assumption about buffer contents may apply to only part of a buffer""" propagate_knowns_to_prove_conditional = True + apply_constraints_to_boolean_branches = True def before(A: T.Buffer[16, "int32"]): for i in T.serial(16):