diff --git a/src/tir/analysis/control_flow_graph.cc b/src/tir/analysis/control_flow_graph.cc index 1065ad3bf1e0..6c63c8d693b4 100644 --- a/src/tir/analysis/control_flow_graph.cc +++ b/src/tir/analysis/control_flow_graph.cc @@ -473,8 +473,9 @@ class ControlFlowGraphBuilder final : public IRVisitorWithAnalyzer { void VisitAccess(const BufferAccess& node, BufferTouch::AccessType touch_type, PrimExpr known_value_expr) { auto& current_block = out_->control_flow_.back(); - BufferTouch buffer_touch = current_block.MakeBufferTouch(out_, node->buffer, node->indices, - touch_type, known_value_expr); + PrimExpr current_predicate = CurrentScopePredicate(); + BufferTouch buffer_touch = current_block.MakeBufferTouch( + out_, node->buffer, node->indices, touch_type, known_value_expr, current_predicate); current_block.touch_points.push_back(buffer_touch); } @@ -637,7 +638,8 @@ class ControlFlowGraphBuilder final : public IRVisitorWithAnalyzer { std::pair> ControlFlowGraph::ControlFlowBlock::MakeBufferTouch( const tir::Buffer& buf, Array index_variables, Array indices, - BufferTouch::AccessType touch_type, PrimExpr known_value_expr) const { + BufferTouch::AccessType touch_type, PrimExpr known_value_expr, + PrimExpr current_predicate) const { const auto& current_block = *this; Analyzer local_analyzer; @@ -797,9 +799,16 @@ std::pair> ControlFlowGraph::ControlFlowBlock::Make known_value_expr = local_analyzer.Simplify(normalize_expr(known_value_expr)); // Deliberately use an analyzer without scope-based information, - // to avoid simplifying `scope_predicate` to True. - PrimExpr predicate_expr = local_analyzer.Simplify(transform_predicate && scope_predicate); - + // to avoid simplifying `scope_predicate` or `current_predicate` to True. + PrimExpr scope_additional_predicate; + if (touch_type == BufferTouch::AccessType::Assume) { + // Consider the expression (additional_predicate) in T.assume to be included in `predicate_expr` + scope_additional_predicate = normalize_expr(current_predicate); + } else { + scope_additional_predicate = scope_predicate; + } + PrimExpr predicate_expr = + local_analyzer.Simplify(transform_predicate && scope_additional_predicate); BufferTouch buffer_touch = {buf, predicate_expr, known_value_expr, loop_var_expressions, touch_type}; @@ -810,10 +819,12 @@ BufferTouch ControlFlowGraph::ControlFlowBlock::MakeBufferTouch(ControlFlowGraph const tir::Buffer& buf, const Array& indices, BufferTouch::AccessType touch_type, - PrimExpr known_value_expr) const { + PrimExpr known_value_expr, + PrimExpr current_predicate) const { ICHECK(graph); - auto [buffer_touch, free_params] = MakeBufferTouch(buf, graph->GetIndexVariables(buf, indices), - indices, touch_type, known_value_expr); + auto [buffer_touch, free_params] = + MakeBufferTouch(buf, graph->GetIndexVariables(buf, indices), indices, touch_type, + known_value_expr, current_predicate); for (const auto& pair : free_params) { graph->free_predicate_parameters_.Set(pair.first, pair.second); } diff --git a/src/tir/analysis/control_flow_graph.h b/src/tir/analysis/control_flow_graph.h index 543feeecfea1..e3491d772a0c 100644 --- a/src/tir/analysis/control_flow_graph.h +++ b/src/tir/analysis/control_flow_graph.h @@ -574,11 +574,15 @@ class ControlFlowGraph { * * \param known_expr_value The value being written to the buffer * + * \param current_predicate The aggregate of scope predicate and + * additional predicate introduced due to assume statement + * * \returns The newly generated BufferTouch */ BufferTouch MakeBufferTouch(ControlFlowGraph* graph, const Buffer& buf, const Array& indices, BufferTouch::AccessType touch_type, - PrimExpr known_value_expr) const; + PrimExpr known_value_expr, + PrimExpr current_predicate = Bool(true)) const; /* \brief Construct a BufferTouch instance as if it occurred in * this ControlFlowBlock @@ -598,15 +602,17 @@ class ControlFlowGraph { * * \param known_expr_value The value being written to the buffer * + * \param current_predicate The aggregate of scope predicate and + * additional predicate introduced due to assume statement + * * \returns The newly generated BufferTouch, and a map specifying * all free parameters that may occur in the BufferTouch's * predicate. */ - std::pair> MakeBufferTouch(const Buffer& buf, - Array index_variables, - Array indices, - BufferTouch::AccessType touch_type, - PrimExpr known_value_expr) const; + std::pair> MakeBufferTouch( + const Buffer& buf, Array index_variables, Array indices, + BufferTouch::AccessType touch_type, PrimExpr known_value_expr, + PrimExpr current_predicate = Bool(true)) const; }; friend std::ostream& operator<<(std::ostream& os, const ControlFlowBlock& pattern); diff --git a/tests/python/tir-transform/test_tir_transform_simplify.py b/tests/python/tir-transform/test_tir_transform_simplify.py index 6bad817c4955..4889e2a59d08 100644 --- a/tests/python/tir-transform/test_tir_transform_simplify.py +++ b/tests/python/tir-transform/test_tir_transform_simplify.py @@ -17,6 +17,7 @@ import tvm import tvm.testing +import pytest from tvm import te from tvm.script import tir as T @@ -1325,6 +1326,27 @@ def expected(A: T.Buffer(16, "int32")): A[i] = 42 +class TestAssumeMayContainAdditionalPredicate(BaseBeforeAfter): + """An assumption about buffer contents may apply to only part of a buffer + Like TestSimplifyUsingPartiallyKnownBufferConditional, but the + conditional is expressed as part of T.assume, instead of in the + control flow. + """ + + propagate_knowns_to_simplify_expressions = True + + def before(A: T.Buffer(16, "int32")): + for i in T.serial(16): + T.evaluate(T.assume(i < 14 or A[i] == 0)) + + for i in T.serial(16): + if i < 14: + if A[i] == 0: + A[i] = 42 + + expected = before + + class TestNoSimplificationIfPredicateNotMet(BaseBeforeAfter): """Assumptions about buffer contents must apply to all cases to be used @@ -1625,6 +1647,7 @@ def expected(A: T.Buffer(24, "int32"), B: T.Buffer(24, "int32"), F: T.Buffer(3, T.evaluate(0) +@pytest.mark.skip("Skipping because this test will hang") class TestSimplifyUsingPartiallyProvenBufferValueScatter(BaseBeforeAfter): """Propagate known buffer values in part of buffer.