From a80cdc26e291abc52bbd70c950023d9e0340464d Mon Sep 17 00:00:00 2001 From: Eric Lunderberg Date: Wed, 16 Nov 2022 13:45:07 -0600 Subject: [PATCH] [TIR][Analysis][Arith] Implement basic data-flow analysis (#13130) An optional utility to track known buffer values through a TIR PrimFunc, allowing simplifications based on known values. * Updated documentation following review comments * Unit tests for rewrites, including negative numerators for div/mod * Fix linting error * Added brief description on what a control graph is * Updates based on review comments * Updated T.assume(expr) to T.evaluate(T.assume(expr)) --- include/tvm/tir/op_attr_types.h | 31 + src/arith/conjunctive_normal_form.cc | 26 +- src/arith/constraint_extract.cc | 39 +- src/arith/constraint_extract.h | 31 +- src/arith/ir_visitor_with_analyzer.h | 6 +- src/arith/rewrite_simplify.cc | 53 +- src/arith/transitive_comparison_analyzer.cc | 2 +- src/arith/unwrap_vector_expr.cc | 90 + src/arith/unwrap_vector_expr.h | 56 + src/tir/analysis/control_flow_graph.cc | 1647 +++++++++++++++++ src/tir/analysis/control_flow_graph.h | 653 +++++++ src/tir/transforms/simplify.cc | 105 +- .../unittest/test_arith_rewrite_simplify.py | 61 +- .../unittest/test_tir_transform_simplify.py | 645 ++++++- 14 files changed, 3403 insertions(+), 42 deletions(-) create mode 100644 src/arith/unwrap_vector_expr.cc create mode 100644 src/arith/unwrap_vector_expr.h create mode 100644 src/tir/analysis/control_flow_graph.cc create mode 100644 src/tir/analysis/control_flow_graph.h diff --git a/include/tvm/tir/op_attr_types.h b/include/tvm/tir/op_attr_types.h index 6b5d6c48ddd0..fa409b27d12a 100644 --- a/include/tvm/tir/op_attr_types.h +++ b/include/tvm/tir/op_attr_types.h @@ -32,6 +32,8 @@ #include #include +#include + namespace tvm { namespace tir { /*! 
@@ -92,6 +94,35 @@ enum class CallEffectKind : int { kControlJump = 6, }; +inline std::ostream& operator<<(std::ostream& os, CallEffectKind side_effect) { + switch (side_effect) { + case CallEffectKind::kExprAnnotation: + return os << "kExprAnnotation"; + + case CallEffectKind::kPure: + return os << "kPure"; + + case CallEffectKind::kReadState: + return os << "kReadState"; + + case CallEffectKind::kUpdateState: + return os << "kUpdateState"; + + case CallEffectKind::kSpecialCallArg: + return os << "kSpecialCallArg"; + + case CallEffectKind::kEmbedInfo: + return os << "kEmbedInfo"; + + case CallEffectKind::kControlJump: + return os << "kControlJump"; + + default: + LOG(FATAL) << "Unknown CallEffectKind: " << static_cast(side_effect); + return os; + } +} + /*! \brief Use integer to record the kind. */ using TCallEffectKind = Integer; diff --git a/src/arith/conjunctive_normal_form.cc b/src/arith/conjunctive_normal_form.cc index 19d6a234e6ad..1c5f31a913a1 100644 --- a/src/arith/conjunctive_normal_form.cc +++ b/src/arith/conjunctive_normal_form.cc @@ -248,14 +248,14 @@ void AndOfOrs::TrySimplifyOr(Key* a_ptr, Key* b_ptr, Analyzer* analyzer) { Key& a = *a_ptr; Key& b = *b_ptr; PrimExpr joint = GetExpr(a) || GetExpr(b); - PrimExpr simplified = analyzer->Simplify(joint); + PrimExpr simplified = analyzer->rewrite_simplify(joint); if (!ExprDeepEqual()(simplified, joint)) { if (auto* simplified_or = simplified.as()) { a = GetKey(simplified_or->a); b = GetKey(simplified_or->b); } else { - a = GetKey(simplified); - b = key_false_; + a = key_false_; + b = GetKey(simplified); } } } @@ -264,14 +264,14 @@ void AndOfOrs::TrySimplifyAnd(Key* a_ptr, Key* b_ptr, Analyzer* analyzer) { Key& a = *a_ptr; Key& b = *b_ptr; PrimExpr joint = GetExpr(a) && GetExpr(b); - PrimExpr simplified = analyzer->Simplify(joint); + PrimExpr simplified = analyzer->rewrite_simplify(joint); if (!ExprDeepEqual()(simplified, joint)) { if (auto* simplified_and = simplified.as()) { a = GetKey(simplified_and->a); 
b = GetKey(simplified_and->b); } else { - a = GetKey(simplified); - b = key_true_; + a = key_true_; + b = GetKey(simplified); } } } @@ -362,6 +362,20 @@ void AndOfOrs::SimplifyAcrossChunks(Analyzer* analyzer) { // (A or B) and (A or C) => A or (B and C) auto& key_i = i_chunk[i_distinct_index.value()]; auto& key_j = j_chunk[j_distinct_index.value()]; + + // When attempting to simplify (B and C), the analyzer may + // assume that A is false. + PrimExpr known = [&]() { + PrimExpr known = Bool(true); + for (const auto& key : i_chunk) { + if (&key != &key_i) { + known = known && analyzer->Simplify(!GetExpr(key)); + } + } + return known; + }(); + + With context(analyzer, known); TrySimplifyAnd(&key_i, &key_j, analyzer); } } diff --git a/src/arith/constraint_extract.cc b/src/arith/constraint_extract.cc index d0bf57497e63..b873adcb5ca4 100644 --- a/src/arith/constraint_extract.cc +++ b/src/arith/constraint_extract.cc @@ -31,23 +31,42 @@ namespace tvm { namespace arith { -void CollectConstraints(const PrimExpr& expr, Analyzer* analyzer, std::vector* collect) { - collect->push_back(expr); +template +void CollectConstraints(PrimExpr expr, F callback, bool keep_composite_constraints) { + if (keep_composite_constraints) { + callback(expr); + } PVar x, y; if ((x && y).Match(expr)) { - CollectConstraints(x.Eval(), analyzer, collect); - CollectConstraints(y.Eval(), analyzer, collect); - } else if ((!(x || y)).Match(expr)) { - CollectConstraints(analyzer->rewrite_simplify(tir::Not(x.Eval())), analyzer, collect); - CollectConstraints(analyzer->rewrite_simplify(tir::Not(y.Eval())), analyzer, collect); + CollectConstraints(x.Eval(), callback, keep_composite_constraints); + CollectConstraints(y.Eval(), callback, keep_composite_constraints); + } else if (!keep_composite_constraints) { + callback(expr); + } +} + +std::vector ExtractConstraints(const PrimExpr& expr, bool keep_composite_constraints) { + std::vector out; + CollectConstraints( + expr, [&](const PrimExpr& part) { 
out.push_back(part); }, keep_composite_constraints); + return out; +} + +template +void CollectComponents(PrimExpr expr, F callback) { + PVar x, y; + if ((x || y).Match(expr)) { + CollectComponents(x.Eval(), callback); + CollectComponents(y.Eval(), callback); + } else { + callback(expr); } } -std::vector ExtractConstraints(const PrimExpr& expr) { +std::vector ExtractComponents(const PrimExpr& expr) { std::vector out; - Analyzer analyzer; - CollectConstraints(expr, &analyzer, &out); + CollectComponents(expr, [&](const PrimExpr& part) { out.push_back(part); }); return out; } diff --git a/src/arith/constraint_extract.h b/src/arith/constraint_extract.h index ea6e0a74419c..815eafeebd62 100644 --- a/src/arith/constraint_extract.h +++ b/src/arith/constraint_extract.h @@ -42,6 +42,35 @@ namespace arith { * Example: `i==5 || j==3` => `[i==5 || j==3]` * Example: `!(i>5 || j==3)` => `[!(i==5 || j==3), i<=5, j!=3]` * + * If `keep_composite_constraints` is true (default), a constraint + * that can be decomposed will be included in the output. If false, + * they will be excluded. + * + * Example, removing composite: `!(i>5 || j==3)` => `[i<=5, j!=3]` + * + * Intended for use in bounds analysis or simplification within a + * conditional, or identifying independent conditionals that may be + * hoisted. + * + * \param expr The expression to be analyzers + * + * \param keep_composite_constraints Whether to include composite + * constraints in the output. + * + * \returns A vector of independent constraints + */ +std::vector ExtractConstraints(const PrimExpr& expr, + bool keep_composite_constraints = true); + +/* \brief Returns components that are false if the expression is false. + * + * Utility to break up a boolean expression into independent + * components. 
+ * + * Example: `i==5 || j==3` => `[i==5, j==3]` + * Example: `i==5 && j==3` => `[i==5 && j==3]` + * Example: `!(i>5 && j==3)` => `[i<=5, j!=3]` + * * Intended for use in bounds analysis or simplification within a * conditional, or identifying independent conditionals that may be * hoisted. @@ -50,7 +79,7 @@ namespace arith { * * \returns A vector of independent constraints */ -std::vector ExtractConstraints(const PrimExpr& expr); +std::vector ExtractComponents(const PrimExpr& expr); } // namespace arith } // namespace tvm diff --git a/src/arith/ir_visitor_with_analyzer.h b/src/arith/ir_visitor_with_analyzer.h index f41a628f3cc6..416b2af196bd 100644 --- a/src/arith/ir_visitor_with_analyzer.h +++ b/src/arith/ir_visitor_with_analyzer.h @@ -57,7 +57,11 @@ class IRVisitorWithAnalyzer : public tir::StmtExprVisitor { /*! \brief internal analyzer field. */ arith::Analyzer analyzer_; - private: + /*! \brief Extract a constraint from a conditional statement + * + * Intended for preparing argument for use in + * `With`. 
+ */ PrimExpr ExtractRealCondition(PrimExpr condition) const; }; diff --git a/src/arith/rewrite_simplify.cc b/src/arith/rewrite_simplify.cc index d0fb943334de..e6d876cf5aa8 100644 --- a/src/arith/rewrite_simplify.cc +++ b/src/arith/rewrite_simplify.cc @@ -292,7 +292,7 @@ std::function RewriteSimplifier::Impl::EnterConstraint(const PrimExpr& c // we will compare the already simplified result with the constraint, // so simplify the constraint as well PrimExpr new_constraint = operator()(constraint); - for (const PrimExpr& subconstraint : ExtractConstraints(new_constraint)) { + for (const PrimExpr& subconstraint : ExtractConstraints(new_constraint, false)) { if (SideEffect(subconstraint) <= CallEffectKind::kPure) { literal_constraints_.push_back(subconstraint); PrimExpr negation; @@ -1734,7 +1734,7 @@ PrimExpr RewriteSimplifier::Impl::VisitExpr_(const AndNode* op) { // Pattern var to match any expression PVar x, y; // Pattern var match IntImm - PVar c1, c2; + PVar c1, c2, c3; PVar lanes; if (op->dtype.lanes() != 1) { @@ -1761,6 +1761,55 @@ PrimExpr RewriteSimplifier::Impl::VisitExpr_(const AndNode* op) { TVM_TRY_REWRITE(x == c1 && x != c2, x == c1 && c1 != c2); TVM_TRY_REWRITE(x != c2 && x == c1, x == c1 && c1 != c2); + + TVM_TRY_RECURSIVE_REWRITE(floordiv(x, c2) == c1 && floormod(x, c2) == c3, x == c1 * c2 + c3); + TVM_TRY_RECURSIVE_REWRITE(floormod(x, c2) == c3 && floordiv(x, c2) == c1, x == c1 * c2 + c3); + + TVM_TRY_RECURSIVE_REWRITE_IF(0 <= x - y * c1 && + x - y * c1value> 0); + TVM_TRY_RECURSIVE_REWRITE_IF(x - y * c1 < c1 && 0 <= x - y * c1, y == floordiv(x, c1), + c1.Eval()->value > 0); + + TVM_TRY_RECURSIVE_REWRITE(c1 < x - y * c1 && x - y * c1 <= 0, y == floordiv(x, c1)); + TVM_TRY_RECURSIVE_REWRITE(x - y * c1 < c1 && 0 <= x - y * c1, y == floordiv(x, c1)); + TVM_TRY_RECURSIVE_REWRITE_IF(0 <= x + y * c2 && x + y * c2 < c1, y == floordiv(x, c1), + c2.Eval()->value == -c1.Eval()->value); + TVM_TRY_RECURSIVE_REWRITE_IF(x + y * c2 < c1 && 0 <= x + y * c2, y == 
floordiv(x, c1), + c2.Eval()->value == -c1.Eval()->value); + + TVM_TRY_RECURSIVE_REWRITE_IF(x < c1 && floormod(x, c2) < c3, + x < c1 - c2 + c3 && floormod(x, c2) < c3, + c1.Eval()->value % c2.Eval()->value == 0); + TVM_TRY_RECURSIVE_REWRITE_IF( + x < c1 && floormod(x, c2) < c3, x < c1 - floormod(c1, c2) + c3 && floormod(x, c2) < c3, + (c1.Eval()->value % c2.Eval()->value + c2.Eval()->value) % c2.Eval()->value > + c3.Eval()->value); + + TVM_TRY_RECURSIVE_REWRITE_IF(x <= c1 && floormod(x, c2) < c3, + x < c1 + 1 - c2 + c3 && floormod(x, c2) < c3, + (c1.Eval()->value + 1) % c2.Eval()->value == 0); + TVM_TRY_RECURSIVE_REWRITE_IF( + x <= c1 && floormod(x, c2) < c3, x < c1 + 1 - floormod(c1, c2) + c3 && floormod(x, c2) < c3, + (((c1.Eval()->value + 1) % c2.Eval()->value) + c2.Eval()->value) % c2.Eval()->value > + c3.Eval()->value); + + TVM_TRY_RECURSIVE_REWRITE(floordiv(x, c2) == c1 && floormod(x, c2) < c3, + c1 * c2 <= x && x < c1 * c2 + c3); + TVM_TRY_RECURSIVE_REWRITE(floormod(x, c2) < c3 && floordiv(x, c2) == c1, + c1 * c2 <= x && x < c1 * c2 + c3); + TVM_TRY_RECURSIVE_REWRITE(floordiv(x, c2) == c1 && floormod(x, c2) <= c3, + c1 * c2 <= x && x <= c1 * c2 + c3); + TVM_TRY_RECURSIVE_REWRITE(floormod(x, c2) <= c3 && floordiv(x, c2) == c1, + c1 * c2 <= x && x <= c1 * c2 + c3); + + TVM_TRY_RECURSIVE_REWRITE(floordiv(x, c2) == c1 && c3 <= floormod(x, c2), + c1 * c2 + c3 <= x && x < (c1 + 1) * c2); + TVM_TRY_RECURSIVE_REWRITE(c3 <= floormod(x, c2) && floordiv(x, c2) == c1, + c1 * c2 + c3 <= x && x < (c1 + 1) * c2); + TVM_TRY_RECURSIVE_REWRITE(floordiv(x, c2) == c1 && c3 < floormod(x, c2), + c1 * c2 + c3 < x && x < (c1 + 1) * c2); + TVM_TRY_RECURSIVE_REWRITE(c3 < floormod(x, c2) && floordiv(x, c2) == c1, + c1 * c2 + c3 < x && x < (c1 + 1) * c2); return ret; } diff --git a/src/arith/transitive_comparison_analyzer.cc b/src/arith/transitive_comparison_analyzer.cc index b71096a479b5..36c2fb77074c 100644 --- a/src/arith/transitive_comparison_analyzer.cc +++ 
b/src/arith/transitive_comparison_analyzer.cc @@ -547,7 +547,7 @@ std::function TransitiveComparisonAnalyzer::EnterConstraint(const PrimEx void TransitiveComparisonAnalyzer::Impl::AddKnown(const PrimExpr& expr, std::vector* vec) { - for (const auto& subexpr : ExtractConstraints(expr)) { + for (const auto& subexpr : ExtractConstraints(expr, false)) { if (tir::SideEffect(expr) <= tir::CallEffectKind::kPure) { if (auto cmp = FromExpr(subexpr)) { vec->push_back(cmp.value()); diff --git a/src/arith/unwrap_vector_expr.cc b/src/arith/unwrap_vector_expr.cc new file mode 100644 index 000000000000..6a3e8c3d434c --- /dev/null +++ b/src/arith/unwrap_vector_expr.cc @@ -0,0 +1,90 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +/*! 
+ * \file unwrap_vector_expr.cc + * \brief Utility for tracking currently active constraints + */ + +#include "unwrap_vector_expr.h" + +#include +#include +#include +#include +#include +#include + +#include + +namespace tvm { +namespace arith { + +using namespace tir; + +class Scalarizer : public ExprMutator { + public: + explicit Scalarizer(PrimExpr lane) : lane_(lane) {} + + PrimExpr VisitExpr_(const RampNode* op) final { return op->base + lane_ * op->stride; } + + PrimExpr VisitExpr_(const BroadcastNode* op) final { return op->value; } + + PrimExpr VisitExpr_(const VarNode* op) final { + Var var = GetRef(op); + + auto it = let_var_remap_.find(op); + if (it != let_var_remap_.end()) { + return it->second; + } else { + return ExprMutator::VisitExpr_(op); + } + } + PrimExpr VisitExpr_(const LetNode* op) final { + if (op->value.dtype().lanes() == 1) { + return ExprMutator::VisitExpr_(op); + } + + auto it = let_var_remap_.find(op->var.get()); + ICHECK(it == let_var_remap_.end()) << "Duplicate binding of variable " << op->var; + + Var new_var(op->var->name_hint + "_scalar", op->var.dtype().element_of()); + let_var_remap_[op->var.get()] = new_var; + + PrimExpr value = this->VisitExpr(op->value); + PrimExpr body = this->VisitExpr(op->body); + + let_var_remap_.erase(op->var.get()); + return Let(op->var, value, body); + } + + private: + // The lane to extract + PrimExpr lane_; + + // Let binding + std::unordered_map let_var_remap_; +}; + +PrimExpr UnwrapVectorExpr(const PrimExpr& vector_expr, const PrimExpr& lane) { + return Scalarizer(lane)(vector_expr); +} + +} // namespace arith +} // namespace tvm diff --git a/src/arith/unwrap_vector_expr.h b/src/arith/unwrap_vector_expr.h new file mode 100644 index 000000000000..9f18964043ff --- /dev/null +++ b/src/arith/unwrap_vector_expr.h @@ -0,0 +1,56 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. 
See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +/*! + * \file unwrap_vector_expr.h + * + * \brief Centralized location for extraction of constraints from a boolean expression. + */ + +#ifndef TVM_ARITH_UNWRAP_VECTOR_EXPR_H_ +#define TVM_ARITH_UNWRAP_VECTOR_EXPR_H_ + +#include + +#include + +namespace tvm { +namespace arith { + +/* \brief Unwraps a component of a vector expression + * + * Utility to break up a vector expression into a specific component + * of the expression. + * + * Example: `Ramp(start, stride, n)` => `start + stride*lane` + * Example: `Broadcast(value, n)` => `value` + * Example: `2*Ramp(start, stride, n) + Broadcast(value,n)` => `2*(start + stride*lane) + value` + * + * \param vector_expr The vectorized expression to examine + * + * \param lane Which lane of the vectorized expression to extract. 
+ * + * \returns A scalar expression + */ +PrimExpr UnwrapVectorExpr(const PrimExpr& vector_expr, const PrimExpr& lane); + +} // namespace arith +} // namespace tvm + +#endif // TVM_ARITH_UNWRAP_VECTOR_EXPR_H_ diff --git a/src/tir/analysis/control_flow_graph.cc b/src/tir/analysis/control_flow_graph.cc new file mode 100644 index 000000000000..42c5c8bb82d5 --- /dev/null +++ b/src/tir/analysis/control_flow_graph.cc @@ -0,0 +1,1647 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +/*! 
+ * \file control_flow_graph.cc + * \brief Utility to deduce bound of expression + */ + +#include "control_flow_graph.h" + +#include +#include +#include +#include +#include +#include + +#include +#include +#include +#include +#include +#include + +#include "../../arith/conjunctive_normal_form.h" +#include "../../arith/constraint_extract.h" +#include "../../arith/ir_mutator_with_analyzer.h" +#include "../../arith/ir_visitor_with_analyzer.h" +#include "../../arith/narrow_predicate_expression.h" +#include "../../arith/unwrap_vector_expr.h" + +namespace tvm { +namespace tir { + +using namespace arith; + +namespace { +bool HasBufferLoad(PrimExpr expr) { + struct Visitor : public ExprVisitor { + void VisitExpr_(const BufferLoadNode* node) override { found_buffer_load = true; } + bool found_buffer_load{false}; + }; + + Visitor visitor; + visitor(expr); + return visitor.found_buffer_load; +} + +Optional SubstituteParamValues(const Array& param_vars, + const Array& param_values, + const PrimExpr& expr) { + ICHECK_EQ(param_vars.size(), param_values.size()) + << "Expression was defined as having " << param_vars.size() << " parameters, but received " + << param_values.size() << " arguments."; + + Map var_map; + for (size_t i = 0; i < param_values.size(); i++) { + var_map.Set(param_vars[i], param_values[i]); + } + + return Substitute(expr, var_map); +} +} // namespace + +PrimExpr BufferTouch::BeforeLoopIteration() const { + PrimExpr loop_predicate = Bool(true); + for (auto it = loop_var_expressions.rbegin(); it != loop_var_expressions.rend(); it++) { + const Var& loop_var = it->first; + const PrimExpr& loop_expr = it->second; + loop_predicate = (loop_var <= loop_expr) || ((loop_var == loop_expr) && loop_predicate); + } + return loop_predicate; +} + +PrimExpr BufferTouch::AtLoopIteration() const { + PrimExpr loop_predicate = Bool(true); + for (auto it = loop_var_expressions.rbegin(); it != loop_var_expressions.rend(); it++) { + const Var& loop_var = it->first; + const PrimExpr& 
loop_expr = it->second; + loop_predicate = (loop_var == loop_expr) && loop_predicate; + } + return loop_predicate; +} + +PrimExpr BufferTouch::AfterLoopIteration() const { + PrimExpr loop_predicate = Bool(true); + for (auto it = loop_var_expressions.rbegin(); it != loop_var_expressions.rend(); it++) { + const Var& loop_var = it->first; + const PrimExpr& loop_expr = it->second; + loop_predicate = (loop_var >= loop_expr) || ((loop_var == loop_expr) && loop_predicate); + } + return loop_predicate; +} + +bool BufferTouch::IsSubsetOf(const BufferTouch& other, Analyzer* analyzer) const { + if (this->buffer.same_as(other.buffer)) { + With constraint(analyzer, predicate); + + return analyzer->CanProve(other.predicate); + } else { + return false; + } +} + +bool BufferTouch::IsDistinctFrom(const BufferTouch& other, Analyzer* analyzer) const { + if (this->buffer.same_as(other.buffer)) { + With constraint(analyzer, predicate); + + return analyzer->CanProve(!other.predicate); + } else { + return true; + } +} + +std::ostream& operator<<(std::ostream& os, const BufferTouch& tp) { + auto touch_type = [&]() { + if (tp.touch_type == BufferTouch::AccessType::Read) { + return "read"; + } else if (tp.touch_type == BufferTouch::AccessType::Write) { + return "write"; + } else if (tp.touch_type == BufferTouch::AccessType::Assume) { + return "assume"; + } else { + return "???"; + } + }(); + + os << "BufferTouch(" << tp.buffer->name << ", " << touch_type << ", " << tp.predicate + << ", value = " << tp.value << ")"; + return os; +} + +class BufferConstraintApply : public IRMutatorWithAnalyzer { + public: + using Parent = IRMutatorWithAnalyzer; + + BufferConstraintApply(const Map>& axis_var_lookup, + const std::vector& knowns, Analyzer* analyzer) + : Parent(analyzer), axis_var_lookup_(axis_var_lookup), knowns_(knowns) {} + + using Parent::VisitExpr_; + + PrimExpr VisitExpr_(const BufferLoadNode* op) override { + for (const auto& known : knowns_) { + if (!op->buffer.same_as(known.buffer)) { + 
continue; + } + + Optional lane_var = NullOpt; + IntImm num_lanes; + + Array indices = op->indices.Map([&](const auto& index) { + if (index.dtype().lanes() == 1) { + return index; + } else { + ICHECK(!lane_var) << "Multiple indices found with non-scalar values"; + lane_var = Var("lane", index.dtype().element_of()); + num_lanes = IntImm(index.dtype().element_of(), index.dtype().lanes()); + return UnwrapVectorExpr(index, lane_var.value()); + } + }); + + auto axis_vars = axis_var_lookup_.at(op->buffer); + PrimExpr predicate = SubstituteParamValues(axis_vars, indices, known.predicate).value(); + + std::optional> context; + if (lane_var.defined()) { + Var lanes = lane_var.value(); + PrimExpr known = (IntImm(lanes.dtype(), 0) <= lanes) && (lanes < num_lanes); + context.emplace(analyzer_, known); + } + + if (analyzer_->CanProve(predicate)) { + return SubstituteParamValues(axis_vars, op->indices, known.value).value(); + } + } + + return GetRef(op); + } + + private: + const Map>& axis_var_lookup_; + const std::vector& knowns_; +}; + +/*! \brief Extract the control-flow graph + * + * Walk through a statement, populating the control-flow graph. + */ +class ControlFlowGraphBuilder final : public IRVisitorWithAnalyzer { + public: + static void Build(ControlFlowGraph* out, const Stmt& stmt) { + ControlFlowGraphBuilder extractor(out); + extractor.AppendControlBlock(); + extractor(stmt); + } + + private: + ControlFlowGraphBuilder(ControlFlowGraph* out) : out_(out) {} + + using Parent = IRVisitorWithAnalyzer; + using Parent::VisitExpr_; + using Parent::VisitStmt_; + + void VisitStmt(const Stmt& stmt) override { + // Update the lookup table to determine which control-flow block + // contains the start of the specified statement. This is used + // later to determine which set of known values should be used to + // simplify a statement. 
+ out_->control_flow_lookup_[stmt.get()] = CurrentControlBlock(); + Stmt prev_stmt = current_stmt_; + current_stmt_ = stmt; + Parent::VisitStmt(stmt); + current_stmt_ = prev_stmt; + } + + void VisitStmt_(const EvaluateNode* op) override { + if (auto* call = op->value.as()) { + if (call->op.same_as(builtin::assume())) { + Assume(call->args[0], true); + return; + } + } + + Parent::VisitStmt_(op); + } + + void Assume(PrimExpr assumption, bool from_assume_statement) { + for (const auto& expr : ExtractConstraints(assumption, false)) { + AssumeConstraintComponent(expr, from_assume_statement); + } + } + + void AssumeConstraintComponent(PrimExpr assumption, bool from_assume_statement) { + PrimExpr additional_predicate = Bool(true); + + std::vector buffer_exprs; + for (const auto& expr : ExtractComponents(assumption)) { + auto side_effect = tir::SideEffect(expr); + if (side_effect <= tir::CallEffectKind::kPure) { + // Pulling out portions of the assumption that do not depend + // on a buffer value allows the following two forms to be + // treated identically. 
+ // + // Option 1: if i < 3: T.assume(buf[i] == value) + // Option 2: T.assume(i>=3 or buf[i] == value) + additional_predicate = additional_predicate && logical_not(expr); + } else if (side_effect == tir::CallEffectKind::kReadState) { + buffer_exprs.push_back(expr); + } else { + LOG(FATAL) << "Assumption must be pure or read-only, but contained expression " << expr + << " with side-effect \'" << side_effect << "\'"; + } + } + + if (buffer_exprs.empty()) { + out_->non_buffer_assumptions_.push_back(!CurrentScopePredicate() || assumption); + return; + } + + CHECK_EQ(buffer_exprs.size(), 1) << "T.assume must contain only a single buffer expression"; + + auto* as_equal_node = buffer_exprs[0].as(); + CHECK(as_equal_node || !from_assume_statement) + << "T.assume buffer constraint must be of the form 'buffer[indices] == " + "value', but received " + << assumption; + if (!as_equal_node) { + // This assumption is an inequality on a data-dependent + // conditional. Not an error for this to occur, but also not + // something that is currently supported. 
+ return; + } + + tir::BufferLoad load; + PrimExpr value; + if (auto* as_load = as_equal_node->a.as()) { + load = GetRef(as_load); + value = as_equal_node->b; + } else if (auto* as_load = as_equal_node->b.as()) { + load = GetRef(as_load); + value = as_equal_node->a; + } else if (!from_assume_statement) { + return; + } else { + LOG(FATAL) << "T.assume buffer constraint must be of the form 'buffer[indices] == value'"; + } + + auto has_side_effect = tir::SideEffect(value) > tir::CallEffectKind::kPure; + CHECK(!has_side_effect || !from_assume_statement) + << "Buffer value in constraint must be pure expression, but was " << value; + if (has_side_effect) { + return; + } + + { + InternalConstraintContext context(this, additional_predicate); + VisitAccess(load, BufferTouch::AccessType::Assume, value); + } + // Appending a control block ensures that all control blocks have + // at most one statement that changes the known buffer contents. + auto prev_block = CurrentControlBlock(); + auto new_block = AppendControlBlock(); + MarkControlFlow(prev_block, new_block); + } + + void VisitExpr_(const LetNode* op) override { + std::optional binding; + if (UsesLoopVar(op->value)) { + binding.emplace(this, op->var, op->value); + } + Parent::VisitExpr_(op); + } + + void VisitStmt_(const LetStmtNode* op) override { + std::optional binding; + if (UsesLoopVar(op->value)) { + binding.emplace(this, op->var, op->value); + } + Parent::VisitStmt_(op); + } + + void VisitExpr_(const BufferLoadNode* op) override { + Parent::VisitExpr_(op); + BufferLoad load = GetRef(op); + VisitAccess(load, BufferTouch::AccessType::Read, load); + } + + void VisitStmt_(const BufferStoreNode* op) override { + Parent::VisitStmt_(op); + VisitAccess(GetRef(op), BufferTouch::AccessType::Write, op->value); + // Appending a control block ensures that all control blocks have + // at most one statement that changes the buffer contents. 
+ auto prev_block = CurrentControlBlock(); + auto new_block = AppendControlBlock(); + MarkControlFlow(prev_block, new_block); + } + + void VisitStmt_(const ForNode* op) override { + out_->iterator_ranges_.Set(op->loop_var, Range::FromMinExtent(op->min, op->extent)); + + auto before_loop = CurrentControlBlock(); + size_t loop_start = -1; + + { + BindActiveLoopVar binding(this, op->loop_var, op->min, op->extent); + loop_start = AppendControlBlock(); + Parent::VisitStmt_(op); + } + + auto loop_end = CurrentControlBlock(); + auto after_loop = AppendControlBlock(); + PrimExpr max_iterator_value = analyzer_.Simplify(op->min + op->extent - 1); + { + auto [forward, backward] = MarkControlFlow(before_loop, loop_start); + backward.post_condition = (op->loop_var == op->min); + forward.var_remap = {{op->loop_var, op->min}}; + } + { + auto [forward, backward] = MarkControlFlow(loop_end, after_loop); + backward.var_remap = {{op->loop_var, max_iterator_value}}; + forward.post_condition = (op->loop_var == max_iterator_value); + } + { + auto [forward, backward] = MarkControlFlow(loop_end, loop_start); + backward.var_remap = {{op->loop_var, op->loop_var - 1}}; + forward.var_remap = {{op->loop_var, op->loop_var + 1}}; + backward.post_condition = (op->loop_var > op->min); + forward.post_condition = (op->loop_var < max_iterator_value); + } + } + + void VisitStmt_(const IfThenElseNode* op) override { + this->VisitExpr(op->condition); + + PrimExpr real_condition = ExtractRealCondition(op->condition); + + auto before_branching = CurrentControlBlock(); + + auto branch_start = AppendControlBlock(); + MarkControlFlow(before_branching, branch_start); + + { + InternalConstraintContext context(this, real_condition); + auto then_start = AppendControlBlock(); + if (context.assume.defined()) { + Assume(context.assume.value(), false); + } + auto [forward, backward] = MarkControlFlow(branch_start, then_start); + backward.post_condition = real_condition; + forward.post_condition = real_condition; + 
this->VisitStmt(op->then_case); + } + auto then_end = CurrentControlBlock(); + + auto negation = analyzer_.rewrite_simplify(!real_condition); + { + InternalConstraintContext context(this, negation); + auto else_start = AppendControlBlock(); + if (context.assume.defined()) { + Assume(context.assume.value(), false); + } + auto [forward, backward] = MarkControlFlow(branch_start, else_start); + backward.post_condition = negation; + forward.post_condition = negation; + + if (op->else_case.defined()) { + this->VisitStmt(op->else_case.value()); + } + } + + auto else_end = CurrentControlBlock(); + auto after_branching = AppendControlBlock(); + + if (HasBufferLoad(real_condition)) { + // The buffer value may have changed during the body of the + // condition, so we can't provide it as a post-condition. + MarkControlFlow(then_end, after_branching); + MarkControlFlow(else_end, after_branching); + } else { + { + auto [forward, backward] = MarkControlFlow(then_end, after_branching); + backward.post_condition = real_condition; + forward.post_condition = real_condition; + } + { + auto [forward, backward] = MarkControlFlow(else_end, after_branching); + backward.post_condition = negation; + forward.post_condition = negation; + } + } + } + + /*! \brief Internal utility, returns true if the expression depends + * on a loop iterator + */ + bool UsesLoopVar(const PrimExpr& expr) { + return UsesVar(expr, [&](const VarNode* expr_var) { + return loop_dependent_vars_.find(expr_var) != loop_dependent_vars_.end(); + }); + } + + /*! \brief Record the interaction with the buffer. + * + * \param node The TIR node that accesses the buffer. Should be + * either a BufferLoad or BufferStore node. + * + * \param touch_type The type of buffer access being performed. A + * BufferStore should always use AccessType::Write. A BufferLoad + * may use either AccessType::Read or AccessType::Assume, depending + * on whether the BufferLoad occurs within `builtin::assume`. 
+ * + * \param known_value_expr The value in the buffer following the access. + */ + template + void VisitAccess(const BufferAccess& node, BufferTouch::AccessType touch_type, + PrimExpr known_value_expr) { + auto& current_block = out_->control_flow_.back(); + BufferTouch buffer_touch = current_block.MakeBufferTouch(out_, node->buffer, node->indices, + touch_type, known_value_expr); + current_block.touch_points.push_back(buffer_touch); + } + + /*! \brief Return a predicate for having reached the current + * control-flow block + * + * For example, while inside an IfThenElse, will return the + * IfThenElse's condition. + */ + PrimExpr CurrentScopePredicate() const { + PrimExpr predicate = Bool(true); + for (const auto& condition : conditions_) { + predicate = predicate && condition; + } + return predicate; + } + + /* \brief Add a new control block, returning its index */ + size_t AppendControlBlock() { + size_t index = out_->control_flow_.size(); + auto& block = out_->control_flow_.emplace_back(); + block.active_loop_iterators = active_loop_iterators_; + block.let_bindings_using_loop = let_bindings_using_loop_; + block.scope_predicate = CurrentScopePredicate(); + return index; + } + + /* \brief The index of the current control block */ + size_t CurrentControlBlock() { return out_->control_flow_.size() - 1; } + + /* \brief Mark a possible control from one block to another + * + * \param from_block The block from which control leaves + * + * \param to_block The block to which control enters + * + * \param var_remap Variable replacements that should be made in + * known expression while traversing this edge. For example, + * replacing `i` with `i-1` when entering the next loop iteration, + * or replacing `i` with `n-1` when concluding a loop. 
+ */ + std::pair MarkControlFlow( + size_t from_block, size_t to_block) { + ICHECK_LE(from_block, out_->control_flow_.size()); + ICHECK_LE(to_block, out_->control_flow_.size()); + + auto& forward = out_->control_flow_[from_block].successors.emplace_back( + ControlFlowGraph::ControlFlowEdge{to_block, {}, NullOpt}); + auto& backward = out_->control_flow_[to_block].predecessors.emplace_back( + ControlFlowGraph::ControlFlowEdge{from_block, {}, NullOpt}); + return {forward, backward}; + } + + // Internal utility, context manager for entering/leaving a scoped constraint + struct InternalConstraintContext { + InternalConstraintContext(ControlFlowGraphBuilder* self, PrimExpr constraint) + : self(self), analyzer_context(&self->analyzer_, constraint) { + old_num_constraints = self->conditions_.size(); + + auto side_effect = tir::SideEffect(constraint); + if (side_effect <= tir::CallEffectKind::kPure) { + self->conditions_.push_back(constraint); + } else if (side_effect <= tir::CallEffectKind::kReadState) { + assume = constraint; + } + + new_num_constraints = self->conditions_.size(); + } + ~InternalConstraintContext() { + ICHECK_EQ(self->conditions_.size(), new_num_constraints) + << "Internal error: Each condition should only be popped once."; + self->conditions_.erase(self->conditions_.begin() + old_num_constraints, + self->conditions_.end()); + } + + ControlFlowGraphBuilder* self{nullptr}; + With analyzer_context; + size_t old_num_constraints{0}; + size_t new_num_constraints{0}; + Optional assume{NullOpt}; + + // Disable default-generated copy/move assignment and constructors + InternalConstraintContext(const InternalConstraintContext&) = delete; + InternalConstraintContext& operator=(const InternalConstraintContext&) = delete; + InternalConstraintContext(InternalConstraintContext&&) = delete; + InternalConstraintContext& operator=(InternalConstraintContext&&) = delete; + }; + + // Internal utility, context manager for tracking a loop + struct BindActiveLoopVar { + 
BindActiveLoopVar(ControlFlowGraphBuilder* self, Var var, PrimExpr loop_min, + PrimExpr loop_extent) + : self(self), var(var) { + PrimExpr loop_max = loop_min + (loop_extent - 1); + auto loop_range = Range::FromMinExtent(loop_min, loop_extent); + self->active_loop_iterators_.push_back({var, loop_min, loop_max, loop_range}); + self->loop_dependent_vars_.insert(var.get()); + } + ~BindActiveLoopVar() { self->active_loop_iterators_.pop_back(); } + + ControlFlowGraphBuilder* self; + Var var; + + // Disable default-generated copy/move assignment and constructors + BindActiveLoopVar(const BindActiveLoopVar&) = delete; + BindActiveLoopVar& operator=(const BindActiveLoopVar&) = delete; + BindActiveLoopVar(BindActiveLoopVar&&) = delete; + BindActiveLoopVar& operator=(BindActiveLoopVar&&) = delete; + }; + + // Internal utility, context manager for tracking a variable binding + struct BindLetVar { + BindLetVar(ControlFlowGraphBuilder* self, Var var, PrimExpr value) : self(self), var(var) { + self->let_bindings_using_loop_.Set(var, value); + self->loop_dependent_vars_.insert(var.get()); + } + ~BindLetVar() { + self->loop_dependent_vars_.erase(var.get()); + self->let_bindings_using_loop_.erase(var); + } + ControlFlowGraphBuilder* self; + Var var; + + // Disable default-generated copy/move assignment and constructors + BindLetVar(const BindLetVar&) = delete; + BindLetVar& operator=(const BindLetVar&) = delete; + BindLetVar(BindLetVar&&) = delete; + BindLetVar& operator=(BindLetVar&&) = delete; + }; + + struct LoopEntry { + Var loop_var; + PrimExpr loop_min; + PrimExpr loop_max; + Range loop_range; + }; + + // Track in order to know which Vars to write in terms of the buffer + // indices and substitute out of the predicate. + std::vector active_loop_iterators_; + + // Track all loop iterators, along with values derived from loop iterators. + std::unordered_set loop_dependent_vars_; + + // Any let binding that depends, directly or indirectly, on a loop + // binding. 
When making a predicate in terms of the buffer indices, + // these need to be substituted out. + // std::unordered_map let_bindings_using_loop_; + Map let_bindings_using_loop_; + + // Track in order to know what conditions limit the buffer access + std::vector conditions_; + + // Track in order to know what statement initiated the buffer access + Stmt current_stmt_; + + // Output data structure + ControlFlowGraph* out_; +}; + +std::pair> ControlFlowGraph::ControlFlowBlock::MakeBufferTouch( + const tir::Buffer& buf, Array index_variables, Array indices, + BufferTouch::AccessType touch_type, PrimExpr known_value_expr) const { + const auto& current_block = *this; + + Analyzer local_analyzer; + + Optional lane_var = NullOpt; + IntImm num_lanes; + + Array index_expressions = indices.Map([&](const auto& index) { + if (index.dtype().lanes() == 1) { + return index; + } else { + ICHECK(!lane_var) << "Multiple indices found with non-scalar values"; + lane_var = Var("lane", index.dtype().element_of()); + num_lanes = IntImm(index.dtype().element_of(), index.dtype().lanes()); + return UnwrapVectorExpr(index, lane_var.value()); + } + }); + + Array loop_vars; + + Map loop_ranges; + for (const auto& loop_entry : current_block.active_loop_iterators) { + loop_vars.push_back(loop_entry.loop_var); + loop_ranges.Set(loop_entry.loop_var, loop_entry.loop_range); + } + + // If the indices contain multiple lanes, treat the lane variable + // as an additional loop iterator to be solved for and substituted + // out. 
+ if (lane_var) { + loop_vars.push_back(lane_var.value()); + loop_ranges.Set(lane_var.value(), Range::FromMinExtent(0, num_lanes)); + } + + IntConstraintsTransform transform = [&]() { + ICHECK_EQ(index_variables.size(), index_expressions.size()); + + Array relations; + + for (size_t i = 0; i < index_expressions.size(); i++) { + PrimExpr expr = index_expressions[i]; + Var var = index_variables[i]; + + expr = Substitute(expr, current_block.let_bindings_using_loop); + relations.push_back(var == expr); + } + + IntConstraints system(loop_vars, loop_ranges, relations); + return arith::SolveLinearEquations(system); + }(); + + Map loop_var_to_axis_var = transform->src_to_dst; + Map free_params = transform->dst->ranges; + PrimExpr transform_predicate = + std::accumulate(transform->dst->relations.begin(), transform->dst->relations.end(), + PrimExpr(Bool(true)), [](PrimExpr a, PrimExpr b) { return a && b; }); + + transform_predicate = SimplifyAsAndOfOrs(transform_predicate, &local_analyzer); + + auto find_removable_params = [&]() -> Map { + Map removable_params; + + // The arith::SolveLinearEquations is more general than the + // utilities in iter_affine_map.h, but can introduce free + // parameters that could later be determined with the known + // constraints. This step removes all such free parameters. 
+ for (const auto& expr : ExtractConstraints(transform_predicate)) { + if (auto* as_equal = expr.as()) { + auto check_expr = [&](const PrimExpr& a, const PrimExpr& b) { + auto* var_ptr = a.as(); + if (!var_ptr) { + return; + } + + Var var = GetRef(var_ptr); + if (free_params.count(var) == 0) { + return; + } + + bool uses_free_param = + UsesVar(b, [&](const VarNode* v) { return free_params.count(GetRef(v)) > 0; }); + if (uses_free_param) { + return; + } + removable_params.Set(var, b); + }; + check_expr(as_equal->a, as_equal->b); + check_expr(as_equal->b, as_equal->a); + } + } + + // In addition, the arith::SolveLinearEquation can introduce + // free parameters with an extent of one. Filtering them out here + // avoids needing to track them through later simplifications. + for (const auto [var, range] : free_params) { + if (is_one(range->extent)) { + removable_params.Set(var, range->min); + } + } + + return removable_params; + }; + for (auto removable_params = find_removable_params(); removable_params.size() > 0; + removable_params = find_removable_params()) { + auto update = [&](const PrimExpr& expr) { + return local_analyzer.Simplify(Substitute(expr, removable_params)); + }; + + Map new_map; + for (const auto [loop_var, expr] : loop_var_to_axis_var) { + static_cast(expr); // gcc 7.x bug, https://gcc.gnu.org/bugzilla/show_bug.cgi?id=81767 + new_map.Set(loop_var, update(expr)); + } + loop_var_to_axis_var = new_map; + + transform_predicate = update(transform_predicate); + + for (const auto [var, expr] : removable_params) { + static_cast(expr); // gcc 7.x bug, https://gcc.gnu.org/bugzilla/show_bug.cgi?id=81767 + free_params.erase(var); + } + } + + // Normalization function, applied to both the predicate and the + // known value. Converts from an expression in terms of loop + // iterators to an expression in terms of buffer indices. 
+ auto normalize_expr = [&](PrimExpr expr) -> PrimExpr { + expr = Substitute(expr, current_block.let_bindings_using_loop); + + if (lane_var) { + expr = UnwrapVectorExpr(expr, lane_var.value()); + } + expr = Substitute(expr, loop_var_to_axis_var); + + return expr; + }; + + // Collect the current loop variables, along with an expression for + // the loop variables in terms of the buffer axis variables. This + // is used during forward/backward propagation to generate predicate + // tracking whether a loop iteration has been reached. + std::vector> loop_var_expressions; + for (const auto& entry : current_block.active_loop_iterators) { + auto expr_it = loop_var_to_axis_var.find(entry.loop_var); + ICHECK(expr_it != loop_var_to_axis_var.end()); + loop_var_expressions.push_back({entry.loop_var, (*expr_it).second}); + } + + // The full predicate is composed of the values required to reach + // the scope of the BufferStore or builtin::assume(), any bounds + // implied by solving for the axis variables, and any additional + // statements resulting from unpacking the expression contained in + // builtin::assume(). + PrimExpr scope_predicate = normalize_expr(current_block.scope_predicate); + transform_predicate = normalize_expr(transform_predicate); + + known_value_expr = local_analyzer.Simplify(normalize_expr(known_value_expr)); + + // Deliberately use an analyzer without scope-based information, + // to avoid simplifying `scope_predicate` to True. 
+ PrimExpr predicate_expr = local_analyzer.Simplify(transform_predicate && scope_predicate); + + BufferTouch buffer_touch = {buf, predicate_expr, known_value_expr, loop_var_expressions, + touch_type}; + + return {buffer_touch, free_params}; +} + +BufferTouch ControlFlowGraph::ControlFlowBlock::MakeBufferTouch(ControlFlowGraph* graph, + const tir::Buffer& buf, + const Array& indices, + BufferTouch::AccessType touch_type, + PrimExpr known_value_expr) const { + ICHECK(graph); + auto [buffer_touch, free_params] = MakeBufferTouch(buf, graph->GetIndexVariables(buf, indices), + indices, touch_type, known_value_expr); + for (const auto& pair : free_params) { + graph->free_predicate_parameters_.Set(pair.first, pair.second); + } + return buffer_touch; +} + +ControlFlowGraph::ControlFlowGraph(const tir::Stmt& stmt, size_t max_revisits) { + ControlFlowGraphBuilder::Build(this, stmt); + ForwardPropagateKnownValues(max_revisits); + BackwardPropagateUnusedValues(max_revisits); +} + +std::ostream& operator<<(std::ostream& os, const ControlFlowGraph::ControlFlowEdge& edge) { + os << edge.index; + if (edge.var_remap.size()) { + os << " with remap " << edge.var_remap; + } + if (edge.post_condition) { + os << " with postcondition " << edge.post_condition; + } + + return os; +} + +std::ostream& operator<<(std::ostream& os, const ControlFlowGraph::ControlFlowBlock& block) { + os << "Predecessors: ["; + for (size_t i = 0; i < block.predecessors.size(); i++) { + if (i) { + os << ", "; + } + os << block.predecessors[i]; + } + os << "]\n"; + + os << "Active loop iterators: ["; + for (size_t i = 0; i < block.active_loop_iterators.size(); i++) { + if (i) { + os << ", "; + } + os << block.active_loop_iterators[i].loop_var; + } + os << "]\n"; + + os << "Before block knowns: " << block.known_at_block_start << "\n"; + + os << "Before block unused: " << block.unused_at_block_start << "\n"; + + for (size_t i = 0; i < block.touch_points.size(); i++) { + os << "Touch[" << i << "] = " << 
block.touch_points[i] << "\n"; + } + os << "After block: " << block.known_at_block_end << "\n"; + + os << "After block unused: " << block.unused_at_block_end << "\n"; + + os << "Successors: ["; + for (size_t i = 0; i < block.successors.size(); i++) { + if (i) { + os << ", "; + } + os << block.successors[i]; + } + os << "]"; + return os; +} + +std::ostream& operator<<(std::ostream& os, const ControlFlowGraph& pattern) { + os << "Touch pattern contains " << pattern.control_flow_.size() << " control blocks." + << (pattern.control_flow_.size() ? "\n" : ""); + for (size_t i = 0; i < pattern.control_flow_.size(); i++) { + os << "\t" + << "ControlBlock[" << i << "] = " << pattern.control_flow_[i] << "\n"; + } + + return os; +} + +bool BufferTouch::IsEquivalentTo(const BufferTouch& other, Analyzer* analyzer) const { + // Constraints must apply to the same buffer to be equivalent + if (!buffer.same_as(other.buffer) || touch_type != other.touch_type) { + return false; + } + + ExprDeepEqual deep_equal; + + auto implies = [&](const PrimExpr& a, const PrimExpr& b) -> bool { + With context(analyzer, a); + return analyzer->CanProve(b); + }; + + // Predicates must be equivalent expressions, or must both be undefined + bool equivalent_predicates = + deep_equal(predicate, other.predicate) || + (implies(predicate, other.predicate) && implies(other.predicate, predicate)); + if (!equivalent_predicates) { + return false; + } + + // The known value must be equal + if (!deep_equal(value, other.value) && !analyzer->CanProveEqual(value, other.value)) { + return false; + } + + return true; +} + +std::ostream& operator<<(std::ostream& os, const BufferState& state) { + for (size_t i = 0; i < state.constraints_.size(); i++) { + os << "constraints[" << i << "] = " << state.constraints_[i] + << (i + 1 == state.constraints_.size() ? 
"" : "\n"); + } + return os; +} + +PrimExpr BufferState::SubstituteKnownBufferValues( + PrimExpr expr, const Map>& axis_var_lookup, + Analyzer* analyzer) const { + BufferConstraintApply mutator(axis_var_lookup, constraints_, analyzer); + return mutator(std::move(expr)); +} + +void BufferState::AddCondition(const PrimExpr& condition) { + for (auto& constraint : constraints_) { + constraint.predicate = constraint.predicate && condition; + } +} + +void BufferState::Substitute(const Map& var_remap, Analyzer* analyzer) { + if (var_remap.size()) { + for (auto& prior : constraints_) { + PrimExpr updated = tvm::tir::Substitute(prior.predicate, var_remap); + if (!updated.same_as(prior.predicate)) { + prior.predicate = SimplifyAsAndOfOrs(updated, analyzer); + } + } + } +} + +void BufferState::Simplify(Analyzer* analyzer) { + for (auto& constraint : constraints_) { + constraint.predicate = SimplifyAsAndOfOrs(constraint.predicate, analyzer); + } +} + +void BufferState::Union(const BufferState& b, Analyzer* analyzer) { + for (const auto& b_constraint : b.constraints_) { + bool used = false; + for (auto& a_constraint : constraints_) { + if (a_constraint.buffer.same_as(b_constraint.buffer) && + analyzer->CanProveEqual(a_constraint.value, b_constraint.value)) { + a_constraint.predicate = + SimplifyAsAndOfOrs(a_constraint.predicate || b_constraint.predicate, analyzer); + used = true; + break; + } + } + if (!used) { + constraints_.push_back(b_constraint); + } + } +} + +void BufferState::Intersection(const BufferState& b, Analyzer* analyzer) { + // For a constraint to be in the output, it must be present in both + // inputs. 
+ + std::vector new_constraints; + for (const auto& ai : constraints_) { + for (const auto& bi : b.constraints_) { + if (ai.buffer.same_as(bi.buffer)) { + PrimExpr predicate = SimplifyAsAndOfOrs(ai.predicate && bi.predicate, analyzer); + if (!is_zero(predicate)) { + With context(analyzer, predicate); + PrimExpr known_value_a = ai.value; + PrimExpr known_value_b = bi.value; + + bool is_consistent = analyzer->CanProveEqual(known_value_a, known_value_b); + if (is_consistent) { + new_constraints.push_back({ai.buffer, predicate, known_value_a}); + } + } + } + } + } + + constraints_ = std::move(new_constraints); +} + +class BufferRegionCollector : public ExprVisitor { + public: + struct Region { + PrimExpr region_predicate; + std::unordered_map> known_values; + }; + + static std::vector Collect(const Map>& axis_var_lookup, + const std::vector& knowns, + const std::vector>& exprs, + Analyzer* analyzer) { + BufferRegionCollector collector(axis_var_lookup, knowns, analyzer); + for (const auto& expr : exprs) { + if (expr) { + collector(expr.value()); + } + } + + return collector.regions_; + } + + private: + using Parent = ExprVisitor; + + BufferRegionCollector(const Map>& axis_var_lookup, + const std::vector& knowns, Analyzer* analyzer) + : analyzer_(analyzer), axis_var_lookup_(axis_var_lookup), knowns_(knowns) { + regions_.push_back(Region{Bool(true), {}}); + } + + using Parent::VisitExpr_; + + void VisitExpr_(const BufferLoadNode* op) override { + // Helper struct for the known values of this BufferLoad + struct Known { + PrimExpr predicate; + Optional value; + }; + + std::vector new_regions; + + PrimExpr unknown_region = Bool(true); + + for (const BufferTouch& constraint : knowns_) { + if (!op->buffer.same_as(constraint.buffer)) { + // This is a different buffer, so continue searching. 
+ continue; + } + + auto axis_vars = axis_var_lookup_.at(op->buffer); + PrimExpr touch_predicate = + SubstituteParamValues(axis_vars, op->indices, constraint.predicate).value(); + touch_predicate = SimplifyAsAndOfOrs(touch_predicate, analyzer_); + + if (!is_zero(touch_predicate)) { + Optional known_value = + SubstituteParamValues(axis_vars, op->indices, constraint.value); + new_regions.push_back(Known{touch_predicate, known_value}); + + unknown_region = unknown_region && !touch_predicate; + unknown_region = SimplifyAsAndOfOrs(unknown_region, analyzer_); + } + } + + if (new_regions.size()) { + Analyzer local_analyzer; + + if (!is_zero(unknown_region)) { + new_regions.insert(new_regions.begin(), Known{unknown_region, NullOpt}); + } + + std::vector updated_regions; + for (const auto& prev_region : regions_) { + for (const auto& new_region : new_regions) { + PrimExpr intersection = + SimplifyAsAndOfOrs(prev_region.region_predicate && new_region.predicate, analyzer_); + + if (!is_zero(intersection)) { + Region merged{intersection, prev_region.known_values}; + merged.known_values[op] = new_region.value; + updated_regions.push_back(std::move(merged)); + } + } + } + regions_ = updated_regions; + } + } + + Analyzer* analyzer_; + std::vector regions_; + const Map>& axis_var_lookup_; + const std::vector& knowns_; +}; + +class BufferRegionValueReplacer : public IRMutatorWithAnalyzer { + public: + static PrimExpr Apply( + const std::unordered_map>& known_values, + PrimExpr expr, Analyzer* analyzer) { + BufferRegionValueReplacer mutator(known_values, analyzer); + PrimExpr result = mutator(expr); + // Simplification must occur after the substitution, as known + // values may provide enable simplifications. 
Also, cannot track + // whether a BufferLoad was + result = analyzer->Simplify(result); + return result; + } + + private: + using Parent = IRMutatorWithAnalyzer; + + BufferRegionValueReplacer( + const std::unordered_map>& known_values, + Analyzer* analyzer) + : Parent(analyzer), known_values_(known_values) {} + + using Parent::VisitExpr_; + + PrimExpr VisitExpr_(const BufferLoadNode* op) override { + auto it = known_values_.find(op); + if (it != known_values_.end() && it->second) { + return it->second.value(); + } else { + return GetRef(op); + } + } + + const std::unordered_map>& known_values_; +}; + +void BufferState::ApplyTouches(const Map>& axis_var_lookup, + const std::vector& touch_points, Analyzer* analyzer) { + std::vector new_knowns; + Map keep_prior_known_at; + + for (auto& touch : touch_points) { + if (touch.touch_type == BufferTouch::AccessType::Read) { + continue; + } + + PrimExpr known_value = touch.value; + + PrimExpr predicate = touch.predicate && touch.AfterLoopIteration(); + auto regions = BufferRegionCollector::Collect(axis_var_lookup, constraints_, + {predicate, touch.value}, analyzer); + + for (const auto& region : regions) { + PrimExpr updated_predicate = BufferRegionValueReplacer::Apply( + region.known_values, region.region_predicate && predicate, analyzer); + + updated_predicate = SimplifyAsAndOfOrs(updated_predicate, analyzer); + PrimExpr updated_value = + BufferRegionValueReplacer::Apply(region.known_values, known_value, analyzer); + + if (!is_zero(updated_predicate)) { + if (auto it = keep_prior_known_at.find(touch.buffer); it != keep_prior_known_at.end()) { + keep_prior_known_at.Set(touch.buffer, (*it).second && !updated_predicate); + } else { + keep_prior_known_at.Set(touch.buffer, !updated_predicate); + } + + if (!HasBufferLoad(updated_value)) { + BufferTouch new_constraint{touch.buffer, updated_predicate, updated_value}; + new_knowns.push_back(new_constraint); + } + } + } + } + + if (keep_prior_known_at.size()) { + for (auto& 
constraint : constraints_) { + if (auto it = keep_prior_known_at.find(constraint.buffer); it != keep_prior_known_at.end()) { + constraint.predicate = SimplifyAsAndOfOrs(constraint.predicate && (*it).second, analyzer); + } + } + } + + if (new_knowns.size()) { + std::vector used(new_knowns.size(), false); + + for (auto& constraint : constraints_) { + PrimExpr expand_known_at = Bool(false); + + PrimExpr prev_value = constraint.value; + + for (size_t i = 0; i < new_knowns.size(); i++) { + if (new_knowns[i].buffer.same_as(constraint.buffer)) { + Optional overwritten_with = new_knowns[i].value; + if (overwritten_with && analyzer->CanProveEqual(prev_value, overwritten_with.value())) { + expand_known_at = + SimplifyAsAndOfOrs(expand_known_at || new_knowns[i].predicate, analyzer); + used[i] = true; + } + } + } + + if (!is_zero(expand_known_at)) { + constraint.predicate = + SimplifyAsAndOfOrs(constraint.predicate || expand_known_at, analyzer); + } + } + + for (size_t i = 0; i < new_knowns.size(); i++) { + if (!used[i]) { + constraints_.push_back(new_knowns[i]); + } + } + } + + constraints_.erase( + std::remove_if(constraints_.begin(), constraints_.end(), + [&](const auto& constraint) { return is_zero(constraint.predicate); }), + constraints_.end()); +} + +void BufferState::BackpropUnusedIndices(const Map>& axis_var_lookup, + const std::vector& touch_points, + Analyzer* analyzer) { + std::vector new_knowns; + Map keep_prior_known_at; + + Map regions_written; + Map regions_read; + for (auto it = touch_points.rbegin(); it != touch_points.rend(); it++) { + const auto& touch = *it; + + Map* to_update{nullptr}; + if (touch.touch_type == BufferTouch::AccessType::Write) { + to_update = ®ions_written; + + } else if (touch.touch_type == BufferTouch::AccessType::Read) { + to_update = ®ions_read; + } else { + continue; + } + + PrimExpr prev = to_update->Get(touch.buffer).value_or(Bool(false)); + PrimExpr new_predicate = touch.predicate && touch.BeforeLoopIteration(); + 
to_update->Set(touch.buffer, prev || new_predicate); + } + + auto update_map = [&](auto& map) { + Map new_map; + for (auto [buffer, predicate] : map) { + new_map.Set(buffer, SimplifyAsAndOfOrs(predicate, analyzer)); + } + map = std::move(new_map); + }; + update_map(regions_written); + update_map(regions_read); + + // If buffer is already in used, widen the predicate + for (auto& prev_unused : constraints_) { + if (auto opt_predicate = regions_written.Get(prev_unused.buffer)) { + PrimExpr new_predicate = prev_unused.predicate || opt_predicate.value(); + prev_unused.predicate = SimplifyAsAndOfOrs(new_predicate, analyzer); + regions_written.erase(prev_unused.buffer); + } + } + + // Otherwise, add new "touch" to represent the unused values + for (auto [buffer, predicate] : regions_written) { + constraints_.push_back( + BufferTouch{buffer, predicate, tir::Call(buffer->dtype, builtin::undef(), {})}); + } + + // If buffer is read out, narrow the predicate + for (auto& prev_unused : constraints_) { + if (auto opt_pred = regions_read.Get(prev_unused.buffer)) { + PrimExpr predicate = opt_pred.value(); + prev_unused.predicate = SimplifyAsAndOfOrs(prev_unused.predicate && !predicate, analyzer); + } + } + + // Clean-up and remove any empty constraints + constraints_.erase( + std::remove_if(constraints_.begin(), constraints_.end(), + [](const auto& constraint) { return is_zero(constraint.predicate); }), + constraints_.end()); +} + +void BufferState::RemoveFreeParameters(const Map& free_predicate_parameters, + Analyzer* analyzer) { + for (auto& known : constraints_) { + known.predicate = NarrowPredicateExpression(known.predicate, free_predicate_parameters); + known.predicate = SimplifyAsAndOfOrs(known.predicate, analyzer); + } +} + +bool BufferState::IsEquivalentTo(const BufferState& other, Analyzer* analyzer) const { + if (constraints_.size() != other.constraints_.size()) { + return false; + } + + for (size_t i = 0; i < constraints_.size(); i++) { + if 
(!constraints_[i].IsEquivalentTo(other.constraints_[i], analyzer)) { + return false; + } + } + + return true; +} + +Optional> ControlFlowGraph::GetIndexVariables(const Buffer& buf) const { + if (auto it = axis_var_lookup_.find(buf); it != axis_var_lookup_.end()) { + return (*it).second; + } else { + return NullOpt; + } +} + +Array ControlFlowGraph::GetIndexVariables(const Buffer& buf, const Array& indices) { + if (auto it = axis_var_lookup_.find(buf); it != axis_var_lookup_.end()) { + return (*it).second; + } + + Array vars; + for (size_t i = 0; i < indices.size(); i++) { + std::stringstream ss; + ss << buf->name << "_axis_" << i; + vars.push_back(Var(ss.str(), indices[i].dtype().element_of())); + } + + axis_var_lookup_.Set(buf, vars); + return vars; +} + +void ControlFlowGraph::ForwardPropagateKnownValues(size_t max_revisits) { + // Values to visit when searching. Using a std::set to + // preferentially visit nodes near the start of the control flow. + std::set to_visit; + + // Map from a block's index + std::unordered_map visit_count_lookup; + + // Initiatize the locations to search from, propagating values + // forward from all locations that have a known value. 
+ for (size_t i = 0; i < control_flow_.size(); i++) { + bool has_known_value = false; + for (const auto& touch : control_flow_[i].touch_points) { + if (!HasBufferLoad(touch.value)) { + has_known_value = true; + break; + } + } + + if (has_known_value) { + to_visit.insert(i); + } + } + + Analyzer analyzer; + analyzer.rewrite_simplify.SetEnabledExtensions(arith::RewriteSimplifier::Extension( + arith::RewriteSimplifier::kTransitivelyProveInequalities | + arith::RewriteSimplifier::kApplyConstraintsToBooleanBranches)); + + analyzer.Bind(iterator_ranges_); + analyzer.Bind(free_predicate_parameters_); + + while (to_visit.size()) { + size_t visiting = *to_visit.begin(); + to_visit.erase(visiting); + + size_t num_previous_visits = visit_count_lookup[visiting]++; + + ControlFlowBlock& block = control_flow_[visiting]; + + // Step 1: Collect known values provided from each predecessor + block.known_at_block_start = [&]() -> BufferState { + if (num_previous_visits >= max_revisits) { + return BufferState(); + } + + // Validate internal constraint. This should be true by + // construction, as ControlFlowGraphBuilder only builds graphs + // that have two or fewer predecessors. + ICHECK_LE(block.predecessors.size(), 2) + << "InternalError: Each block should have at most two predecessors. " + << "Graph constructed in ControlFlowGraphBuilder did not satisfy this constraint."; + + std::vector states; + for (const auto& pred : block.predecessors) { + const auto& pred_block = control_flow_[pred.index]; + BufferState state = pred_block.known_at_block_end; + state.Substitute(pred.var_remap, &analyzer); + states.push_back(state); + } + + if (std::all_of(block.predecessors.begin(), block.predecessors.end(), + [&](const auto& pred) { return visit_count_lookup[pred.index] == 0; })) { + // Predecessors, if any, are unvisited. 
+ return {}; + } else if (block.predecessors.size() == 1) { + // Block has only a single predecessor + return states[0]; + } + + const auto& pred_a = block.predecessors[0]; + const auto& pred_b = block.predecessors[1]; + + auto& priors_a = states[0]; + auto& priors_b = states[1]; + + // During the first visit of a block, predecessor blocks may be + // unvisited, even though we preferentially visit earlier blocks + // first. (e.g. During the first visit of the start of a For + // loop, the end of the For loop has not yet been visited.) If + // this is the case, assume the best-case scenario that all + // knowns are consistent, and rely on a later visit to + // resolve/remove any conflicts. + if (visit_count_lookup[pred_a.index] == 0) { + return priors_b; + } else if (visit_count_lookup[pred_b.index] == 0) { + return priors_a; + } + + if (pred_a.post_condition && pred_b.post_condition) { + // The predicate can identify which predecessor block applies + // (e.g. i==0 for the first loop iteration, i>0 for remaining + // loop iterations). Therefore, we can use all buffer + // constraints, conditional on having come from the + // predecessor that provides it. + priors_a.AddCondition(pred_a.post_condition.value()); + priors_b.AddCondition(pred_b.post_condition.value()); + priors_a.Union(priors_b, &analyzer); + return priors_a; + } else { + // We don't know which predecessor applies. Therefore, the + // only buffer constraints that can be used are those that + // appear in both predecessors. 
+ priors_a.Intersection(priors_b, &analyzer); + return priors_a; + } + }(); + + // Step 2: Collect knowns provided as a result of executing this block + auto post_state = [&]() { + if (num_previous_visits >= max_revisits) { + return BufferState(); + } + auto post_state = block.known_at_block_start; + post_state.ApplyTouches(axis_var_lookup_, block.touch_points, &analyzer); + post_state.RemoveFreeParameters(free_predicate_parameters_, &analyzer); + return post_state; + }(); + + // Step 3: If any changes are made to the post knowns since the + // previous time we visited this block, mark the successor block + // as needing to be visited. + if (num_previous_visits == 0 || + !post_state.IsEquivalentTo(block.known_at_block_end, &analyzer)) { + block.known_at_block_end = std::move(post_state); + for (const auto& successor : block.successors) { + to_visit.insert(successor.index); + } + } + } +} + +void ControlFlowGraph::BackwardPropagateUnusedValues(size_t max_revisits) { + // Values to visit when searching. Using a std::set to + // preferentially visit nodes near the end of the control flow. + std::set to_visit; + + // Map from a block's index + std::unordered_map visit_count_lookup; + + // Initiatize the locations to search from, propagating values + // backward from anywhere that performs a write. 
+ for (size_t i = 0; i < control_flow_.size(); i++) { + const auto& touch_points = control_flow_[i].touch_points; + bool performs_write = std::any_of( + touch_points.begin(), touch_points.end(), + [](const auto& touch) { return touch.touch_type == BufferTouch::AccessType::Write; }); + if (performs_write) { + to_visit.insert(i); + } + } + + Analyzer analyzer; + analyzer.rewrite_simplify.SetEnabledExtensions( + arith::RewriteSimplifier::kTransitivelyProveInequalities); + + analyzer.Bind(iterator_ranges_); + analyzer.Bind(free_predicate_parameters_); + + while (to_visit.size()) { + size_t visiting = *to_visit.rbegin(); + to_visit.erase(visiting); + + size_t num_previous_visits = visit_count_lookup[visiting]++; + + ControlFlowBlock& block = control_flow_[visiting]; + + // Step 1: Collect known unused indices provided by each successor + block.unused_at_block_end = [&]() -> BufferState { + if (num_previous_visits >= max_revisits) { + return BufferState(); + } + ICHECK_LE(block.successors.size(), 2) + << "Each block should have at most two successors, but block " << visiting + << " breaks this requirement"; + + std::vector states; + for (const auto& successor : block.successors) { + const auto& successor_block = control_flow_[successor.index]; + BufferState state = successor_block.unused_at_block_start; + state.Substitute(successor.var_remap, &analyzer); + states.push_back(state); + } + + if (std::all_of(block.successors.begin(), block.successors.end(), [&](const auto& successor) { + return visit_count_lookup[successor.index] == 0; + })) { + // Successors, if any, are unvisited. 
+ return {}; + } else if (block.successors.size() == 1) { + // Block has only a single successor + return states[0]; + } + + const auto& successor_a = block.successors[0]; + const auto& successor_b = block.successors[1]; + + auto& post_a = states[0]; + auto& post_b = states[1]; + + // During the first visit of a block, successor blocks may be + // unvisited, even though we preferentially visit later blocks + // first. (e.g. During the first visit of the end of a For + // loop, the start of the For loop has not yet been visited.) + // If this is the case, assume the best-case scenario that all + // knowns are consistent, and rely on a later visit to + // resolve/remove any conflicts. + if (visit_count_lookup[successor_a.index] == 0) { + return post_b; + } else if (visit_count_lookup[successor_b.index] == 0) { + return post_a; + } + + if (successor_a.post_condition && successor_b.post_condition) { + // The predicate can identify which successor block applies + // (e.g. i==n-1 for the last loop iteration, i= max_revisits) { + return BufferState(); + } + auto prior_state = block.unused_at_block_end; + prior_state.BackpropUnusedIndices(axis_var_lookup_, block.touch_points, &analyzer); + prior_state.RemoveFreeParameters(free_predicate_parameters_, &analyzer); + return prior_state; + }(); + + // Step 3: If any changes are made to the unused indices since the + // previous time we visited this block, mark the predecessor + // blocks as needing to be visited.
+ if (num_previous_visits == 0 || + !unused_at_block_start.IsEquivalentTo(block.unused_at_block_start, &analyzer)) { + block.unused_at_block_start = std::move(unused_at_block_start); + for (const auto& pred : block.predecessors) { + to_visit.insert(pred.index); + } + } + } +} + +bool ControlFlowGraph::IsOverwrittenWithoutEffect(const tir::BufferStore& store, + const Stmt& context) const { + Optional> index_variables = GetIndexVariables(store->buffer); + if (!index_variables) { + return false; + } + + auto it = control_flow_lookup_.find(context.get()); + ICHECK(it != control_flow_lookup_.end()) + << "Context " << PrettyPrint(context) << " did not occur within analyzed statement"; + const auto& context_block = control_flow_[it->second]; + + auto [store_touch, free_params] = context_block.MakeBufferTouch( + store->buffer, index_variables.value(), store->indices, BufferTouch::AccessType::Write, + BufferLoad(store->buffer, store->indices)); + + Analyzer local_analyzer; + local_analyzer.Bind(free_predicate_parameters_); + local_analyzer.Bind(iterator_ranges_); + local_analyzer.Bind(free_params); + local_analyzer.rewrite_simplify.SetEnabledExtensions( + RewriteSimplifier::kTransitivelyProveInequalities); + + PrimExpr predicate = store_touch.predicate && store_touch.AtLoopIteration(); + + predicate = SimplifyAsAndOfOrs(predicate, &local_analyzer); + + for (const auto& unused : context_block.unused_at_block_end.constraints_) { + if (store_touch.buffer.same_as(unused.buffer)) { + PrimExpr difference = SimplifyAsAndOfOrs(predicate && !unused.predicate, &local_analyzer); + if (is_zero(difference)) { + return true; + } + } + } + return false; +} + +PrimExpr ControlFlowGraph::SimplifyInContext(PrimExpr expr, const tir::Stmt& context, + Analyzer* analyzer) const { + size_t context_index = [&]() { + auto it = control_flow_lookup_.find(context.get()); + ICHECK(it != control_flow_lookup_.end()) + << "Context did not occur in the Stmt provided to BufferTouchPattern's constructor"; + 
return it->second; + }(); + + PrimExpr constraint = Bool(true); + for (const auto& known : non_buffer_assumptions_) { + constraint = constraint && known; + } + With constraint_context(analyzer, constraint); + + expr = control_flow_[context_index].known_at_block_start.SubstituteKnownBufferValues( + std::move(expr), axis_var_lookup_, analyzer); + + expr = analyzer->Simplify(std::move(expr)); + return expr; +} + +} // namespace tir +} // namespace tvm diff --git a/src/tir/analysis/control_flow_graph.h b/src/tir/analysis/control_flow_graph.h new file mode 100644 index 000000000000..aa9023ba29dd --- /dev/null +++ b/src/tir/analysis/control_flow_graph.h @@ -0,0 +1,653 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +/*! + * \file control_flow_graph.h + * \brief Utility for extracting and interacting with buffer touch points + */ + +#include +#include +#include +#include +#include +#include + +#include +#include +#include + +#ifndef TVM_TIR_ANALYSIS_CONTROL_FLOW_GRAPH_H_ +#define TVM_TIR_ANALYSIS_CONTROL_FLOW_GRAPH_H_ + +namespace tvm { +namespace tir { + +/*! \brief Represents an interaction with a buffer */ +struct BufferTouch { + enum class AccessType { + /*! 
\brief Buffer access occurs in BufferLoad */ + Read, + + /*! \brief Buffer access occurs in BufferStore */ + Write, + + /*! \brief Buffer access occurs in tir::builtin::assume() */ + Assume, + }; + + BufferTouch(Buffer buffer, PrimExpr predicate, PrimExpr value) + : buffer(buffer), + predicate(predicate), + value(value), + loop_var_expressions({}), + touch_type(AccessType::Assume) {} + + BufferTouch(Buffer buffer, PrimExpr predicate, PrimExpr value, + std::vector> loop_var_expressions, AccessType touch_type) + : buffer(buffer), + predicate(predicate), + value(value), + loop_var_expressions(loop_var_expressions), + touch_type(touch_type) {} + + /*! \brief The buffer being touched */ + Buffer buffer; + + /*! \brief A predicate that is true when this touch applies + * + * May be in terms of axis variables to indicate touches that impact + * only a portion of a buffer. + */ + PrimExpr predicate; + + /*! \brief The value in this buffer after the touch + * + * May be in terms of axis variables to indicate a known + * non-constant value. May be in terms of a BufferLoad to indicate + * an unknown value. + */ + PrimExpr value; + + /*! \brief Active loops during the buffer touch + * + * The vector contains one entry for each loop that contains the + * buffer touch. The `Var` item in each entry is the loop variable + * itself. The `PrimExpr` item is an expression for the loop + * variable in terms of the buffer axis variables in + * `ControlFlowGraph::axis_var_lookup_`. + * + * Used to construct boolean expressions indicating whether the loop + * iteration that performs this touch has been reached. + */ + std::vector> loop_var_expressions; + + /*! \brief How the buffer was interacted with + * + * When used as a constraint (e.g. in BufferState), should use + * Assume. + */ + AccessType touch_type{AccessType::Assume}; + + /*! \brief Generate a boolean expression that is true for indices + * accessed by this touch during this iteration or a previous + * loop iteration. 
+ * + * Used during forward propagation, to track known values that were + * written in the current loop iteration, or in a preceding loop + * iteration. + */ + PrimExpr BeforeLoopIteration() const; + + /*! \brief Generate a boolean expression that is true for indices + * accessed by this touch during this loop iteration. + * + * Used during speculative no-op insertion checks, to specify which + * indices must be later overwritten for a store to have no impact + * on final results. + */ + PrimExpr AtLoopIteration() const; + + /*! \brief Generate a boolean expression that is true for indices + * accessed by this touch during this loop iteration or a + * subsequent loop iteration. + * + * Used during backward propagation, to track indices that are + * overwritten in the current loop iteration or in a later loop + * iteration. + */ + PrimExpr AfterLoopIteration() const; + + /* \brief Checks if this touch affects a subset of indices of another + * + * Returns true if the indices accessed by this touch can be proven + * to be a subset of the indices accessed by the other touch. + * Returns false if it cannot be proven to be a subset of the + * other touch's indices. + */ + bool IsSubsetOf(const BufferTouch& other, arith::Analyzer* analyzer) const; + + /* \brief Checks if this touch affects distinct indices from another + * + * Returns true if it can be proven that the two predicates cannot + * be simultaneously true. Returns false if it cannot be proven + * that the two predicates are distinct. + */ + bool IsDistinctFrom(const BufferTouch& other, arith::Analyzer* analyzer) const; + + /* \brief Checks if this touch is equivalent to another + * + * Returns true if it can be proven that the two touches are + * equivalent. Returns false if it cannot be proven that the + * two touches are equivalent.
+ */ + bool IsEquivalentTo(const BufferTouch& other, arith::Analyzer* analyzer) const; + + friend std::ostream& operator<<(std::ostream& os, const BufferTouch& expr); +}; + +/*! \brief Represents the known state of buffers at a specific point */ +class BufferState { + public: + /*! Default constructor + * + * Initialize the buffer state with no known information. + */ + BufferState() {} + + /*! \brief Replace BufferLoad instances with known values + * + * \param expr The expression to be updated. + * + * \param axis_var_lookup A map from buffer to the variables + * representing positions along the buffer's axes. + * + * \param analyzer The analyzer to use when validating a + * constraint's predicate. + * + * \returns The modified expression. If no substitutions are made, + * the original expression is returned. + */ + PrimExpr SubstituteKnownBufferValues(PrimExpr expr, + const Map>& axis_var_lookup, + arith::Analyzer* analyzer) const; + + /*! \brief Apply a condition to all known constraints + * + * For example, when propagating pre-loop constraints into the body + * of a loop, add a condition that the loop iterator is zero. + * + * \param condition The condition to apply + */ + void AddCondition(const PrimExpr& condition); + + /*! \brief Perform a variable substitution for all constraints + * + * For example, when propagating constraints from the end of a loop + * to the beginning, replace `i` with `i-1`. + * + * \param var_remap The variable remapping to apply. + */ + void Substitute(const Map& var_remap, arith::Analyzer* analyzer); + + /*! \brief Simplify the predicate of all constraints + * + * \param analyzer The analyzer with which to simplify + */ + void Simplify(arith::Analyzer* analyzer); + + /*! \brief Update the known buffer values based on buffer touches + * + * For any Write or Assume touches, update the known values. For + * any Read touches, ignore. 
Used to determine known values at the + * end of a control flow block, given the known values at the start. + * + * \param axis_var_lookup A map from buffer to the variables + * representing positions along the buffer's axes. + * + * \param touch_points The buffer touch points to apply + * + * \param analyzer The analyzer to use for simplifications + */ + void ApplyTouches(const Map>& axis_var_lookup, + const std::vector& touch_points, arith::Analyzer* analyzer); + + /*! \brief Update unused buffer locations based on buffer touches + * + * For any Write, mark the written-to indices as unused. (That is, + * immediately prior to assigning `buf[i] = expr`, the value stored + * at `buf[i]` is irrelevant.) For any Read, mark the read-from + * indices as used. This method is used to determine unused buffer + * indices at the start of a control flow block, given the unused + * buffer indices values at the end. + * + * \param axis_var_lookup A map from buffer to the variables + * representing positions along the buffer's axes. + * + * \param touch_points The buffer touch points to apply + * + * \param analyzer The analyzer to use for simplifications + */ + void BackpropUnusedIndices(const Map>& axis_var_lookup, + const std::vector& touch_points, + arith::Analyzer* analyzer); + + /*! \brief Remove free parameters from the constraints + * + * \param free_predicate_parameters + * + * \param analyzer The analyzer with which to simplify after removal + */ + void RemoveFreeParameters(const Map& free_predicate_parameters, + arith::Analyzer* analyzer); + + /*! \brief Check if two buffer states are equivalent + * + * \param other + * + * \param analyzer The analyzer used to check equality of PrimExpr + * + * \return True if the two states are provably equivalent, false otherwise. 
+ */ + bool IsEquivalentTo(const BufferState& other, arith::Analyzer* analyzer) const; + + /* \brief Add known values provided by another state + * + * \param other The state with which to merge constraints + * + * \param analyzer The analyzer with which to simplify the result + */ + void Union(const BufferState& other, arith::Analyzer* analyzer); + + /* \brief Remove all known values not consistent with another state + * + * \param other The state with which to merge constraints + * + * \param analyzer The analyzer with which to simplify the result + */ + void Intersection(const BufferState& other, arith::Analyzer* analyzer); + + friend std::ostream& operator<<(std::ostream& os, const BufferState&); + + private: + friend class ControlFlowGraph; + /*! \brief The known constraints */ + std::vector constraints_; +}; + +/*! \brief Represents the flow of control through a `tir::Stmt` + * + * This class contains an internal representation of the possible + * control flow that may occur during execution of a `tir::Stmt`. It + * consists of a collection of ControlFlowBlock objects, each of which + * represents a subset of operations performed during execution, along + * with edges that represent allowed transitions between + * `ControlFlowBlock`. + * + * In addition, the following restrictions are used. + * + * 1. Each block may have at most two predecessors, and at most two + * successors. + * + * 2. Within each block, values stored in a buffer do not change. + * That is, encountering a `BufferStore` node requires creating a + * new block. + * + * For example, consider the following PrimFunc + * + * ```python + * @T.prim_func + * def func(B: T.Buffer[16, "float32"]): + * for i in T.serial(16): + * if i < 8: + * B[i] = i + * else: + * B[i] = i-8 + * ``` + * + * The control flow graph would have eight control blocks. + * + * 1. function_entry, from the start of the function through the + * evaluation of the loop's extent.
+ * + * Predecessors: n/a + * Successors: loop_start + * + * 2. loop_start, after entering the body of the loop, through the + * evaluation of the conditional `i < 8` + * + * Predecessors: function_entry, after_conditional + * Successors: then_clause_start, else_clause_start + * + * 3. then_clause_start, after entering the then_clause of `i < 8`, + * through evaluation of the value `i`. + * + * Predecessors: loop_start + * Successors: then_clause_end + * + * 4. then_clause_end, after storing to `B[i]` prior to exiting the + * then_clause. + * + * Predecessors: then_clause_start + * Successors: after_conditional + * + * 5. else_clause_start, after entering the else_clause of `i < 8`, + * through evaluation of the value `i-8`. + * + * Predecessors: loop_start + * Successors: else_clause_end + * + * 6. else_clause_end, after storing to `B[i]` prior to exiting the + * else_clause. + * + * Predecessors: else_clause_start + * Successors: after_conditional + * + * 7. after_conditional, after the end of the if/then/else, before the + * end of the loop body + * + * Predecessors: then_clause_end, else_clause_end + * Successors: loop_start, after_loop + * + * 8. after_loop, after the loop + * + * Predecessors: after_conditional + * Successors: n/a + * + * + * By identifying `BufferStore` nodes whose value does not depend on + * values stored in input buffers (e.g. initializing `buf[i] = 0.0`), + * or whose values are provided using `builtin::assume()` + * (e.g. `T.assume(buf[i] == 0.0)`), the value stored in a buffer at + * those indices may be known for a given control block. These known + * values can then be propagated forward to successor blocks, to be + * used in context-dependent simplifications. 
+ * + * In addition to the allowed transitions between control-flow + * blocks, each block also tracks the buffer touch points; which + * indices are read from a buffer, which values are written to which + * indices of a buffer, and which assumptions are provided using + * `builtin::assume()`; that occur during the control-flow block. + * + * Note: The current implementation only tracks the values of + * buffers that are constrained to a specific value, and does not + * track inequalities that may partially constrain buffer values. + * That is, entering a scoped context with a data-dependent equality + * condition (e.g. `if buf[i] == value`) is tracked, but entering a + * scoped context with a data-dependent inequality condition + * (e.g. `if buf[i] > value`) is not tracked. + */ +class ControlFlowGraph { + public: + /* \brief Extract the touch pattern from a TIR statement + */ + explicit ControlFlowGraph(const Stmt& stmt, size_t max_revisits = 5); + + /* \brief Check if a write is overwritten without impacting final results + * + * \param store The store to be examined + * + * \param context The context in which the buffer store occurs, used + * to identify the control-flow block in which the store occurs. In + * most cases, this will be the same object as the `store` itself. + * + * \note Simplifications use an analyzer constructed internally. + * + * \return True if the specified store can be proven to be + * overwritten without contributing to any later statements. + * Returns false otherwise. + */ + bool IsOverwrittenWithoutEffect(const BufferStore& store, const Stmt& context) const; + + /* \brief Simplify the expression, assuming it occurs within the given context + * + * \param expr The expression to be simplified. Does not need to + * have occurred within the statement used to construct this + * ControlFlowGraph. + * + * \param context The statement where this expression occurred, or + * is to be inserted.
Must occur within the statement used to + * construct this ControlFlowGraph. + * + * \param analyzer The analyzer to be used for simplifications + * + * \returns The simplified expression + */ + PrimExpr SimplifyInContext(PrimExpr expr, const Stmt& context, arith::Analyzer* analyzer) const; + + /*! \brief Remove the specified BufferStore from the control-flow + * graph + * + * Removes the specified store, which may reflow known values. + * This is necessary when simplifying sequential stores of the same + * value. Otherwise, the first could be removed as a no-op because + * it is overwritten by the second, and the second could be removed + * as a no-op because it is the same value as the first. + * + * \param store The store to remove + */ + void RemoveStore(const tir::BufferStore& store); + + friend std::ostream& operator<<(std::ostream& os, const ControlFlowGraph& pattern); + + private: + /*! \brief Return index variables representing locations within a + * buffer. + * + * For a given buffer, will always return the same set of variables. + * + * \param buf The buffer being accessed + * + * \param indices The indices at which the buffer is being accessed. + * These are used to set the dtype of the buffer axis variables. + * + * \returns Variables representing a position along the buffer's axis. + */ + Array GetIndexVariables(const Buffer& buf, const Array& indices); + + /*! \brief Return index variables representing locations within a + * buffer, if they have been generated before. + * + * For a given buffer, will always return the same set of variables. + * + * \param buf The buffer being accessed + * + * \returns Variables representing a position along the buffer's axis. + */ + Optional> GetIndexVariables(const Buffer& buf) const; + + /*! \brief Propagate known values from known BufferStore/assume to + * subsequent control flow blocks + */ + void ForwardPropagateKnownValues(size_t max_revisits); + + /*!
\brief Propagate overwritten/unused indices to preceding control + * flow blocks + */ + void BackwardPropagateUnusedValues(size_t max_revisits); + + struct ControlFlowEdge { + /* \brief The block at the other end of the control flow edge + * + * Lookup index into `control_flow_` + */ + size_t index; + + /*! \brief Variable remaps + * + * e.g. Replacing loop iterator `i` with `i-1` when following an + * edge from the end of a loop to the beginning of the loop. + */ + Map var_remap; + + /*! \brief Condition that must be true after following this edge + * + * This is applied after variable remapping. For example, `i > + * loop_min` when following an edge from the end of a loop to + * the beginning of the loop. + */ + Optional post_condition; + }; + friend std::ostream& operator<<(std::ostream& os, const ControlFlowEdge& edge); + + struct ControlFlowBlock { + struct LoopEntry { + Var loop_var; + PrimExpr loop_min; + PrimExpr loop_max; + Range loop_range; + }; + + /*! \brief Loop iterators that are active during this block */ + std::vector active_loop_iterators; + + /*! \brief Loop-dependent Let bindings that may appear within the block */ + Map let_bindings_using_loop; + + /*! \brief Predicate that must be true to have reached this block */ + PrimExpr scope_predicate{Bool(true)}; + + /*! \brief All known values prior to executing the block */ + BufferState known_at_block_start; + + /*! \brief All known values after executing the block */ + BufferState known_at_block_end; + + /*! \brief Indices whose value at the start of the block is known to be unused */ + BufferState unused_at_block_start; + + /*! \brief Indices whose value at the end of the block is known to be unused */ + BufferState unused_at_block_end; + + /* \brief Buffer touches that occur within the block + * + * All buffer touches within a block can be treated as occurring + * simultaneously.
+ */ + std::vector touch_points; + + /* \brief The blocks that occur after this block + * + * Lookup index into `control_flow_` + */ + std::vector successors; + + /* \brief The blocks that occur before this block */ + std::vector predecessors; + + /* \brief Construct a BufferTouch instance within this + * ControlFlowBlock + * + * \param graph The mutable ControlFlowGraph that owns the buffer + * touch. Any free parameters used in the BufferTouch's predicate + * will be tracked by the ControlFlowGraph. + * + * \param buf The Buffer being accessed + * + * \param indices The indices at which the buffer is accessed, in + * terms of the loop variables. + * + * \param touch_type The type of touch being generated + * + * \param known_expr_value The value being written to the buffer + * + * \returns The newly generated BufferTouch + */ + BufferTouch MakeBufferTouch(ControlFlowGraph* graph, const Buffer& buf, + const Array& indices, BufferTouch::AccessType touch_type, + PrimExpr known_value_expr) const; + + /* \brief Construct a BufferTouch instance as if it occurred in + * this ControlFlowBlock + * + * Used when speculative checking if a BufferStore could be + * inserted. + * + * \param buf The Buffer being accessed + * + * \param index_variables The variables representing location + * within a buffer, with one variable for each axis of the buffer. + * + * \param indices The indices at which the buffer is accessed, in + * terms of the loop variables. + * + * \param touch_type The type of touch being generated + * + * \param known_expr_value The value being written to the buffer + * + * \returns The newly generated BufferTouch, and a map specifying + * all free parameters that may occur in the BufferTouch's + * predicate. 
+ */ + std::pair> MakeBufferTouch(const Buffer& buf, + Array index_variables, + Array indices, + BufferTouch::AccessType touch_type, + PrimExpr known_value_expr) const; + }; + friend std::ostream& operator<<(std::ostream& os, const ControlFlowBlock& pattern); + + /* \brief The control flow that occurs within the analyzed statement */ + std::vector control_flow_; + + /* \brief A lookup into control_flow_ + * + * A map to look up the control flow block that contains the + * statement. + */ + std::unordered_map control_flow_lookup_; + + /*! \brief A map from free parameters to their range + * + * A BufferStore/BufferLoad has indices in terms of loop iterators, + * while the internal BufferTouch must have predicate in terms of + * the buffer's axes. While converting to the internal BufferTouch, + * reduction axes show up as free parameters. Tracking the range of + * the free parameters allows them to be removed later, by requiring + * a predicate to be true for all values of the free parameters. + */ + Map free_predicate_parameters_; + + /*! \brief Ranges of iterators found in the analyzed statement */ + Map iterator_ranges_; + + /* \brief A map from buffer to the variables representing positions + * along the buffer's axes. + * + * This is stored here, rather than as part of the BufferState or + * BufferTouch, to ensure that all access of a buffer use the same + * variables to represent the buffer's axes, reducing the amount of + * variable substitution required. + */ + Map> axis_var_lookup_; + + /* \brief Assumptions that do not depend on buffer values + * + * These may be collected as part of the handling of `builtin::assume()`, and do not depend on any + * buffer. 
Since TIR only allows mutable values as part of buffers, these assumptions may be used + * anywhere within the analyzed statement. + */ + std::vector non_buffer_assumptions_; + + friend class ControlFlowGraphBuilder; +}; + +} // namespace tir +} // namespace tvm +#endif // TVM_TIR_ANALYSIS_CONTROL_FLOW_GRAPH_H_ diff --git a/src/tir/transforms/simplify.cc b/src/tir/transforms/simplify.cc index 1dbf9e688027..49d3a9ceaef5 100644 --- a/src/tir/transforms/simplify.cc +++ b/src/tir/transforms/simplify.cc @@ -29,7 +29,10 @@ #include #include +#include + #include "../../arith/ir_mutator_with_analyzer.h" +#include "../../tir/analysis/control_flow_graph.h" namespace tvm { namespace arith { @@ -38,6 +41,8 @@ using namespace tir; struct SimplifyConfigNode : public tvm::AttrsNode { bool transitively_prove_inequalities; + bool propagate_knowns_to_prove_conditional; + bool propagate_knowns_to_simplify_expressions; bool convert_boolean_to_and_of_ors; bool apply_constraints_to_boolean_branches; @@ -47,6 +52,17 @@ struct SimplifyConfigNode : public tvm::AttrsNode { "If true, simplify conditionals with transitive combinations of scoped constraints") .set_default(false); + TVM_ATTR_FIELD(propagate_knowns_to_prove_conditional) + .describe( + "If true, known buffer values are propagated and used to statically prove conditionals") + .set_default(false); + + TVM_ATTR_FIELD(propagate_knowns_to_simplify_expressions) + .describe( + "If true, known buffer values are propagated and used to replace BufferLoad wherever " + "possible") + .set_default(false); + TVM_ATTR_FIELD(convert_boolean_to_and_of_ors) .describe("If true, simplify conditionals into an AND of ORs") .set_default(false); @@ -85,16 +101,46 @@ TVM_REGISTER_PASS_CONFIG_OPTION("tir.Simplify", SimplifyConfig); class StmtSimplifier : public IRMutatorWithAnalyzer { public: - explicit StmtSimplifier(Analyzer* analyzer) : IRMutatorWithAnalyzer(analyzer) {} + static Stmt Apply(Stmt stmt, Analyzer* analyzer, Optional config_opt = NullOpt) { 
+ auto config =
config_opt.value_or(AttrsWithDefaultValues()); + analyzer->rewrite_simplify.SetEnabledExtensions(config->GetEnabledExtensions()); + + std::optional touch_pattern = std::nullopt; + if (config->propagate_knowns_to_prove_conditional || + config->propagate_knowns_to_simplify_expressions) { + touch_pattern = ControlFlowGraph(stmt); + } + StmtSimplifier simplifier(analyzer, config, std::move(touch_pattern)); + return simplifier(std::move(stmt)); + } + + private: + explicit StmtSimplifier(Analyzer* analyzer, SimplifyConfig config, + std::optional touch_pattern) + : IRMutatorWithAnalyzer(analyzer), config_(config), touch_pattern_(touch_pattern) {} using Parent = IRMutatorWithAnalyzer; using Parent::VisitStmt; using Parent::VisitStmt_; - PrimExpr VisitExpr(const PrimExpr& expr) final { return analyzer_->Simplify(expr); } + PrimExpr VisitExpr(const PrimExpr& expr) final { + if (config_->propagate_knowns_to_simplify_expressions) { + return touch_pattern_->SimplifyInContext(expr, current_stmt_.value(), analyzer_); + } else { + return analyzer_->Simplify(expr); + } + } Stmt Simplify(Stmt stmt) { return operator()(std::move(stmt)); } + Stmt VisitStmt(const Stmt& stmt) override { + Optional cache = this->current_stmt_; + this->current_stmt_ = stmt; + Stmt output = Parent::VisitStmt(stmt); + this->current_stmt_ = std::move(cache); + return output; + } + Stmt VisitStmt_(const ForNode* op) final { analyzer_->Bind(op->loop_var, Range::FromMinExtent(op->min, op->extent)); With ctx1(analyzer_, op->loop_var >= op->min); @@ -111,7 +157,7 @@ class StmtSimplifier : public IRMutatorWithAnalyzer { return SideEffect(op->value) <= CallEffectKind::kPure; } - Stmt VisitStmt_(const LetStmtNode* op) { + Stmt VisitStmt_(const LetStmtNode* op) override { PrimExpr value = this->VisitExpr(op->value); if (CanInlineLetStmt(op)) { // it is fine to discard the let binding @@ -134,26 +180,24 @@ class StmtSimplifier : public IRMutatorWithAnalyzer { } } - Stmt VisitStmt_(const IfThenElseNode* op) { - 
PrimExpr cond = analyzer_->Simplify(Substitute(op->condition, non_inlined_bindings_)); - if (const int64_t* as_int = as_const_int(cond)) { - if (*as_int) { + Stmt VisitStmt_(const IfThenElseNode* op) override { + if (Optional cond = ProveCondition(op->condition)) { + if (cond.value()->value) { return this->VisitStmt(op->then_case); } else if (op->else_case) { return this->VisitStmt(op->else_case.value()); } else { return Evaluate(0); } + } else { + return Parent::VisitStmt_(op); } - return Parent::VisitStmt_(op); } - PrimExpr VisitExpr_(const CallNode* op) { + PrimExpr VisitExpr_(const CallNode* op) override { if (op->op.same_as(builtin::if_then_else())) { - PrimExpr cond = this->VisitExpr(op->args[0]); - cond = analyzer_->Simplify(Substitute(std::move(cond), non_inlined_bindings_)); - if (const int64_t* as_int = as_const_int(cond)) { - if (*as_int) { + if (Optional cond = ProveCondition(op->args[0])) { + if (cond.value()->value) { return this->VisitExpr(op->args[1]); } else { return this->VisitExpr(op->args[2]); @@ -196,23 +240,50 @@ class StmtSimplifier : public IRMutatorWithAnalyzer { return true; } + /* \brief Internal utility for checking conditionals + * + * Uses more aggressive optimization, such as performing additional + * inlining and tracking known buffer values. 
+ */ + Optional ProveCondition(PrimExpr condition) const { + condition = Substitute(condition, non_inlined_bindings_); + if (config_->propagate_knowns_to_prove_conditional) { + ICHECK(touch_pattern_.has_value()); + condition = touch_pattern_->SimplifyInContext(condition, current_stmt_.value(), analyzer_); + } else { + condition = analyzer_->Simplify(condition); + } + if (const int64_t* as_int = as_const_int(condition)) { + return Bool(*as_int); + } else { + return NullOpt; + } + } + + SimplifyConfig config_; + std::optional touch_pattern_; + Map non_inlined_bindings_; + Optional current_stmt_{NullOpt}; }; } // namespace arith namespace tir { + +Stmt Simplify(Stmt stmt, arith::Analyzer* analyzer) { + return arith::StmtSimplifier::Apply(stmt, analyzer); +} + namespace transform { Pass Simplify() { auto pass_func = [](PrimFunc f, IRModule m, PassContext ctx) { arith::Analyzer analyzer; - auto cfg = ctx->GetConfig("tir.Simplify") - .value_or(AttrsWithDefaultValues()); - analyzer.rewrite_simplify.SetEnabledExtensions(cfg->GetEnabledExtensions()); + auto cfg = ctx->GetConfig("tir.Simplify"); auto* n = f.CopyOnWrite(); - n->body = arith::StmtSimplifier(&analyzer).Simplify(std::move(n->body)); + n->body = arith::StmtSimplifier::Apply(std::move(n->body), &analyzer, cfg); return f; }; return CreatePrimFuncPass(pass_func, 0, "tir.Simplify", {}); diff --git a/tests/python/unittest/test_arith_rewrite_simplify.py b/tests/python/unittest/test_arith_rewrite_simplify.py index 4477e1d9c713..4199cb9a56f7 100644 --- a/tests/python/unittest/test_arith_rewrite_simplify.py +++ b/tests/python/unittest/test_arith_rewrite_simplify.py @@ -16,7 +16,7 @@ # under the License. 
import pytest import tvm -from tvm import te +from tvm import te, tir class RewriteChecker: @@ -873,6 +873,65 @@ def test_cmp_simplify(): ck.verify(fld(x + 2, 4) * 4 >= x - y, tvm.tir.LE(flm(x + 2, 4) + (-2), y)) # End DivMod Rules + # merging flm/fld into known value + ck.verify(tir.all(fld(x, 8) == 3, flm(x, 8) == 4), x == 28) + ck.verify(tir.all(flm(x, 8) == 4, fld(x, 8) == 3), x == 28) + ck.verify(tir.all(fld(x, 8) == -3, flm(x, 8) == 4), x == -20) + ck.verify(tir.all(flm(x, 8) == 4, fld(x, 8) == -3), x == -20) + + # Rewrite based on definition of integer division + ck.verify(tir.all(tvm.runtime.convert(0) <= x - y * 5, x - y * 5 < 5), y == fld(x, 5)) + ck.verify(tir.all(x - y * 5 < 5, tvm.runtime.convert(0) <= x - y * 5), y == fld(x, 5)) + + # Narrow upper bound using floormod + ck.verify(tir.all(x < 20, flm(x, 5) < 2), tir.all(x < 17, flm(x, 5) < 2)) + ck.verify(tir.all(x < 18, flm(x, 5) < 2), tir.all(x < 17, flm(x, 5) < 2)) + ck.verify(tir.all(x <= 19, flm(x, 5) < 2), tir.all(x < 17, flm(x, 5) < 2)) + ck.verify(tir.all(x <= 18, flm(x, 5) < 2), tir.all(x < 17, flm(x, 5) < 2)) + ck.verify(tir.all(x < -20, flm(x, 5) < 2), tir.all(x < -23, flm(x, 5) < 2)) + ck.verify(tir.all(x < 18 - 40, flm(x, 5) < 2), tir.all(x < 17 - 40, flm(x, 5) < 2)) + ck.verify(tir.all(x <= -21, flm(x, 5) < 2), tir.all(x < -23, flm(x, 5) < 2)) + ck.verify(tir.all(x <= -22, flm(x, 5) < 2), tir.all(x < -23, flm(x, 5) < 2)) + # No change if the floormod cannot help narrow the upper bound + ck.verify(tir.all(x < 16, flm(x, 5) < 2), tir.all(x < 16, flm(x, 5) < 2)) + ck.verify(tir.all(x <= 15, flm(x, 5) < 2), tir.all(x <= 15, flm(x, 5) < 2)) + + # Merge a known floordiv and an upper bound of floormod into a value range + ck.verify( + tir.all(fld(x, 10) == 5, flm(x, 10) < 7), + tir.all(tvm.runtime.convert(50) <= x, x < 57), + ) + ck.verify( + tir.all(fld(x, 10) == 5, flm(x, 10) <= 7), + tir.all(tvm.runtime.convert(50) <= x, x <= 57), + ) + ck.verify( + tir.all(fld(x, 10) == -5, flm(x, 10) < 7), 
+ tir.all(tvm.runtime.convert(-50) <= x, x < -43), + ) + ck.verify( + tir.all(fld(x, 10) == -5, flm(x, 10) <= 7), + tir.all(tvm.runtime.convert(-50) <= x, x <= -43), + ) + + # Merge a known floordiv and a lower bound of floormod into a value range + ck.verify( + tir.all(fld(x, 10) == 5, tvm.runtime.convert(7) < flm(x, 10)), + tir.all(tvm.runtime.convert(57) < x, x < 60), + ) + ck.verify( + tir.all(fld(x, 10) == 5, tvm.runtime.convert(7) <= flm(x, 10)), + tir.all(tvm.runtime.convert(57) <= x, x < 60), + ) + ck.verify( + tir.all(fld(x, 10) == -5, tvm.runtime.convert(7) < flm(x, 10)), + tir.all(tvm.runtime.convert(-43) < x, x < -40), + ) + ck.verify( + tir.all(fld(x, 10) == -5, tvm.runtime.convert(7) <= flm(x, 10)), + tir.all(tvm.runtime.convert(-43) <= x, x < -40), + ) + ck.verify(tvm.te.min(x, 11) < 10, x < 10) ck.verify(tvm.te.min(x, 8) < 10, tvm.tir.const(1, "bool")) ck.verify(tvm.te.max(8, x) > 10, tvm.tir.LT(10, x)) diff --git a/tests/python/unittest/test_tir_transform_simplify.py b/tests/python/unittest/test_tir_transform_simplify.py index 8d9c76c6b20d..fd98b715a4bc 100644 --- a/tests/python/unittest/test_tir_transform_simplify.py +++ b/tests/python/unittest/test_tir_transform_simplify.py @@ -140,6 +140,8 @@ class BaseBeforeAfter(tvm.testing.CompareBeforeAfter): transitively_prove_inequalities = False convert_boolean_to_and_of_ors = False apply_constraints_to_boolean_branches = False + propagate_knowns_to_prove_conditional = False + propagate_knowns_to_simplify_expressions = False def transform(self): def inner(mod): @@ -148,6 +150,8 @@ def inner(mod): "transitively_prove_inequalities": self.transitively_prove_inequalities, "convert_boolean_to_and_of_ors": self.convert_boolean_to_and_of_ors, "apply_constraints_to_boolean_branches": self.apply_constraints_to_boolean_branches, + "propagate_knowns_to_prove_conditional": self.propagate_knowns_to_prove_conditional, + "propagate_knowns_to_simplify_expressions": self.propagate_knowns_to_simplify_expressions, } } with 
tvm.transform.PassContext(config=config): @@ -777,7 +781,7 @@ def before(A: T.Buffer[1, "bool"], i: T.int32, j: T.int32, k: T.int32): A[0] = (i == 0 or j == 10 or k == 20) and (j == 10 or k != 30 or i == 0) def expected(A: T.Buffer[1, "bool"], i: T.int32, j: T.int32, k: T.int32): - A[0] = i == 0 or j == 10 or k == 20 + A[0] = j == 10 or k == 20 or i == 0 class TestRewriteAsAndOfOrUsingSimplificationAcrossAnd(BaseBeforeAfter): @@ -794,7 +798,7 @@ def before(A: T.Buffer[1, "bool"], i: T.int32, j: T.int32, k: T.int32): A[0] = (k == 20) and ((i == 0 or j == 10) and (k != 30)) def expected(A: T.Buffer[1, "bool"], i: T.int32, j: T.int32, k: T.int32): - A[0] = (k == 20) and (i == 0 or j == 10) + A[0] = (i == 0 or j == 10) and (k == 20) class TestRewriteAsAndOfOrUsingSimplificationWithinOr(BaseBeforeAfter): @@ -815,7 +819,7 @@ def before(A: T.Buffer[1, "bool"], i: T.int32, j: T.int32, k: T.int32): A[0] = (i == 20) or (j == 0) or (i != 30) def expected(A: T.Buffer[1, "bool"], i: T.int32, j: T.int32, k: T.int32): - A[0] = (i != 30) or (j == 0) + A[0] = (j == 0) or (i != 30) class TestConditionalFloorMod(BaseBeforeAfter): @@ -1049,5 +1053,640 @@ def func(A: T.Buffer[1, "bool"]): return func +class TestProvableConditionWithOffset(BaseBeforeAfter): + """Use scoped-constraint to prove inequalities""" + + transitively_prove_inequalities = False + + def before(A: T.Buffer[1, "bool"], i: T.int32, j: T.int32): + if i < j: + A[0] = i < j + 1 + + def expected(A: T.Buffer[1, "bool"], i: T.int32, j: T.int32): + if i < j: + A[0] = True + + +class TestAlteredBufferContents(BaseBeforeAfter): + """Propagation of data-dependent conditionals. + + A literal constraint must not be propagated if the values + referenced may change. TIR requires single assignment of + variables, so Var objects may be assumed constant, but BufferLoad + may not. 
+ """ + + propagate_knowns_to_prove_conditional = True + + def before(A: T.Buffer[(1,), "int32"], n: T.int32): + if A[0] == n: + A[0] = A[0] + 1 + # If the simplifier incorrectly uses the invalidated + # A[0]==n condition required to reach this point, then it + # will incorrectly simplify to the then-case. If the + # simplifier correctly determines that A[0] now contains + # n+1, then it will correctly simplify to the else-case. + if A[0] == n: + A[0] = 5 + else: + A[0] = 10 + + def expected(A: T.Buffer[(1,), "int32"], n: T.int32): + if A[0] == n: + A[0] = A[0] + 1 + A[0] = 10 + + +class TestPossiblyAlteredBufferContents(BaseBeforeAfter): + """No simplification of data-dependent conditionals. + + Like TestAlteredBufferContents, but the `m==0` conditional + prevents the value of `A[0]` from being known at the point of the + inner conditional, either as `A[0] == n` from the outer + conditional or as `A[0] == n+1` from the write statement. + """ + + propagate_knowns_to_prove_conditional = True + + def before(A: T.Buffer[(1,), "int32"], n: T.int32, m: T.int32): + if A[0] == n: + if m == 0: + A[0] = A[0] + 1 + + if A[0] == n: + A[0] = 5 + else: + A[0] = 10 + + expected = before + + +class TestSimplifyInputAssumption(BaseBeforeAfter): + """A T.assume annotation may be used to simplify""" + + propagate_knowns_to_prove_conditional = True + + def before(A: T.Buffer[1, "int32"], n: T.int32): + T.evaluate(T.assume(n == 0)) + if n == 0: + A[0] = 42 + + def expected(A: T.Buffer[1, "int32"], n: T.int32): + T.evaluate(T.assume(n == 0)) + A[0] = 42 + + +class TestSimplifyInputAssumption(BaseBeforeAfter): + """A T.assume annotation may be used to simplify""" + + propagate_knowns_to_prove_conditional = True + + def before(A: T.Buffer[1, "int32"], n: T.int32): + T.evaluate(T.assume(n == 0)) + if n == 0: + A[0] = 42 + + def expected(A: T.Buffer[1, "int32"], n: T.int32): + T.evaluate(T.assume(n == 0)) + A[0] = 42 + + +class TestNoSimplifyFromScopedInputAssumption(BaseBeforeAfter): + 
"""A T.assume inside a scope may not apply outside that scope""" + + propagate_knowns_to_prove_conditional = True + + def before(A: T.Buffer[1, "int32"], n: T.int32, m: T.int32): + if m == 0: + T.evaluate(T.assume(n == 0)) + + if n == 0: + A[0] = 42 + + expected = before + + +class TestSimplifyConditionalUsingBufferValue(BaseBeforeAfter): + """Simplify a conditional using the known value in the buffer""" + + propagate_knowns_to_prove_conditional = True + + def before(A: T.Buffer[1, "int32"]): + A[0] = 0 + + if A[0] == 0: + A[0] = 42 + + def expected(A: T.Buffer[1, "int32"]): + A[0] = 0 + A[0] = 42 + + +class TestKeepExpressionSimplifyUsingBufferValue(BaseBeforeAfter): + """Do not simplify expressions in general using known values in the buffer + + For now, because this is equivalent to inlining, preventing this + usage from occurring. Known buffer values may be used to prove + conditionals, but should not be used for other simplifications. + """ + + propagate_knowns_to_prove_conditional = True + + def before(A: T.Buffer[1, "int32"], B: T.Buffer[1, "int32"]): + A[0] = 0 + B[0] = A[0] + + expected = before + + +class TestSimplifyConditionalInLoopUsingBufferValue(BaseBeforeAfter): + """Simplify a conditional using the known value in the buffer + + Like TestSimplifyConditionalUsingBufferValue, but the value used + to simplify is set in a previous loop. 
+ """ + + propagate_knowns_to_prove_conditional = True + + def before(A: T.Buffer[16, "int32"], B: T.Buffer[16, "int32"]): + for i in T.serial(16): + A[i] = i + + for j in T.serial(16): + if A[j] == j: + B[j] = 42 + else: + B[j] = 100 + + def expected(A: T.Buffer[16, "int32"], B: T.Buffer[16, "int32"]): + for i in T.serial(16): + A[i] = i + + for j in T.serial(16): + B[j] = 42 + + +class TestSimplifyUsingBufferAssumption(BaseBeforeAfter): + """A T.assume may apply to a buffer's contents""" + + propagate_knowns_to_prove_conditional = True + + def before(A: T.Buffer[1, "int32"]): + T.evaluate(T.assume(A[0] == 0)) + + if A[0] == 0: + A[0] = 42 + + def expected(A: T.Buffer[1, "int32"]): + T.evaluate(T.assume(A[0] == 0)) + A[0] = 42 + + +class TestSimplifyUsingBufferAssumptionInLoop(BaseBeforeAfter): + """An assumption about buffer contents may apply to a range""" + + propagate_knowns_to_prove_conditional = True + + def before(A: T.Buffer[16, "int32"]): + for i in T.serial(16): + T.evaluate(T.assume(A[i] == i)) + + for i in T.serial(16): + if A[i] < 100: + A[i] = 0 + + def expected(A: T.Buffer[16, "int32"]): + for i in T.serial(16): + T.evaluate(T.assume(A[i] == i)) + + for i in T.serial(16): + A[i] = 0 + + +class TestSimplifyUsingPartiallyKnownBufferConditional(BaseBeforeAfter): + """An assumption about buffer contents may apply to only part of a buffer""" + + propagate_knowns_to_prove_conditional = True + + def before(A: T.Buffer[16, "int32"]): + for i in T.serial(16): + if 14 <= i: + T.evaluate(T.assume(A[i] == 0)) + + for i in T.serial(16): + if 14 <= i: + if A[i] == 0: + A[i] = 42 + + else: + if A[i] == 0: + A[i] = 100 + + def expected(A: T.Buffer[16, "int32"]): + for i in T.serial(16): + if 14 <= i: + T.evaluate(T.assume(A[i] == 0)) + + for i in T.serial(16): + if 14 <= i: + A[i] = 42 + + else: + if A[i] == 0: + A[i] = 100 + + +class TestSimplifyUsingPartiallyKnownBufferExpression(BaseBeforeAfter): + """An assumption about buffer contents may apply to only part of 
a buffer + + Like TestSimplifyUsingPartiallyKnownBufferConditional, but the + conditional is expressed as part of T.assume, instead of in the + control flow. + """ + + propagate_knowns_to_prove_conditional = True + + def before(A: T.Buffer[16, "int32"]): + for i in T.serial(16): + T.evaluate(T.assume(i < 14 or A[i] == 0)) + + for i in T.serial(16): + if 14 <= i: + if A[i] == 0: + A[i] = 42 + + def expected(A: T.Buffer[16, "int32"]): + for i in T.serial(16): + T.evaluate(T.assume(i < 14 or A[i] == 0)) + + for i in T.serial(16): + if 14 <= i: + A[i] = 42 + + +class TestNoSimplificationIfPredicateNotMet(BaseBeforeAfter): + """Assumptions about buffer contents must apply to all cases to be used + + Like TestSimplifyUsingPartiallyKnownBufferConditional, but the + predicate in the second loop does not match the predicate in the + first loop. Therefore, the `T.assume` refers to a different set + of indices. + """ + + propagate_knowns_to_prove_conditional = True + + def before(A: T.Buffer[16, "int32"]): + for i in T.serial(16): + if 14 <= i: + T.evaluate(T.assume(A[i] == 0)) + + for i in T.serial(16): + if i < 14: + if A[i] == 0: + A[i] = 42 + + expected = before + + +class TestNoSimplifyUsingInvalidatedScopedConstraint(BaseBeforeAfter): + """A write may not be used for proofs outside its conditional""" + + propagate_knowns_to_prove_conditional = True + + def before(A: T.Buffer[16, "int32"]): + for i in T.serial(16): + if i == 0: + A[i] = 0 + + if A[i] == 0: + A[i] = 42 + + expected = before + + +class TestNoSimplifyUsingOverwrittenValue(BaseBeforeAfter): + """A write that may have been overwritten may not be treated as known + + The appearance of "A[i] = 5" must prevent the earlier constraint + from being used for simplification. 
+ """ + + propagate_knowns_to_prove_conditional = True + + def before(A: T.Buffer[16, "int32"]): + for i in T.serial(16): + T.evaluate(T.assume(A[i] == 0)) + + for i in T.serial(16): + if i == 0: + A[i] = 5 + + if A[i] == 0: + A[i] = 42 + + expected = before + + +class TestNoSimplifyUsingLoopDependentBufferValue(BaseBeforeAfter): + """Do not simplify assuming reads are invariant + + If a buffer's value changes across loop iterations, the buffer's + value before the loop should not be used to simplify conditionals + within the loop. + """ + + propagate_knowns_to_prove_conditional = True + + def before(A: T.Buffer[16, "int32"], B: T.Buffer[1, "int32"]): + B[0] = 0 + for i in T.serial(16): + if B[0] < 10: + B[0] = A[i] * 2 + B[0] + else: + B[0] = A[i] + B[0] + + expected = before + + +class TestSimplifyPriorToOverwrittenValue(BaseBeforeAfter): + """A known value may be used until it is overwritten + + Like TestNoSimplifyUsingOverwrittenValue, but the use of the + known `A[i]` value occurs before it is overwritten. + + Like TestNoSimplifyUsingLoopDependentBufferValue, but the loop + iterations are all independent. + """ + + propagate_knowns_to_prove_conditional = True + + def before(A: T.Buffer[16, "int32"]): + for i in T.serial(16): + T.evaluate(T.assume(A[i] == 0)) + + for i in T.serial(16): + if A[i] == 0: + A[i] = 17 + + if i == 0: + A[i] = 5 + + if A[i] == 0: + A[i] = 42 + + def expected(A: T.Buffer[16, "int32"]): + for i in T.serial(16): + T.evaluate(T.assume(A[i] == 0)) + + for i in T.serial(16): + A[i] = 17 + + if i == 0: + A[i] = 5 + + if A[i] == 0: + A[i] = 42 + + +class TestSimplifyElementWiseUsingPreLoopBufferValue(BaseBeforeAfter): + """Simplify using the pre-loop buffer value in an element-wise loop + + If an element-wise loop reads and overwrites a buffer value, the + pre-loop buffer value may be used to simplify conditions that + occur prior to the write. 
+ """ + + propagate_knowns_to_prove_conditional = True + + def before(A: T.Buffer[16, "int32"], B: T.Buffer[16, "int32"]): + for i in T.serial(16): + B[i] = 0 + + for i in T.serial(16): + if B[i] < 10: + B[i] = A[i] * 2 + B[i] + else: + B[i] = A[i] + B[i] + + def expected(A: T.Buffer[16, "int32"], B: T.Buffer[16, "int32"]): + for i in T.serial(16): + B[i] = 0 + + for i in T.serial(16): + B[i] = A[i] * 2 + B[i] + + +class TestSimplifyNonConditional(BaseBeforeAfter): + """Propagate a known value to later expressions.""" + + propagate_knowns_to_simplify_expressions = True + + def before(A: T.Buffer[1, "int32"]): + A[0] = 0 + A[0] = A[0] + 1 + + def expected(A: T.Buffer[1, "int32"]): + A[0] = 0 + A[0] = 1 + + +class TestSuppressSimplifyNonConditional(BaseBeforeAfter): + """Do not propagate a known value to later expressions. + + Like TestSimplifyNonConditional, but with data-propagation turned off. + """ + + propagate_knowns_to_simplify_expressions = False + + def before(A: T.Buffer[1, "int32"]): + A[0] = 0 + A[0] = A[0] + 1 + + expected = before + + +class TestSimplifyUsingTransitiveKnownBufferValue(BaseBeforeAfter): + """Propagate known buffer values + + If a known value of a buffer depends on another known value, it + can be tracked backwards through both. 
+ """ + + propagate_knowns_to_prove_conditional = True + + def before(A: T.Buffer[1, "int32"]): + T.evaluate(T.assume(A[0] == 0)) + + A[0] = A[0] + 1 + A[0] = A[0] + 1 + A[0] = A[0] + 1 + + if A[0] == 3: + A[0] = 42 + + def expected(A: T.Buffer[1, "int32"]): + T.evaluate(T.assume(A[0] == 0)) + + A[0] = A[0] + 1 + A[0] = A[0] + 1 + A[0] = A[0] + 1 + + A[0] = 42 + + +class TestSimplifyRampIndexBroadcastValue(BaseBeforeAfter): + """Simplifications involving buffer loads with ramp indices""" + + propagate_knowns_to_prove_conditional = True + + def before(A: T.Buffer[4, "int32"]): + A[T.ramp(0, 1, 4)] = T.broadcast(0, 4) + + if A[0] == 0: + A[0] = 42 + + if A[1] == 0: + A[1] = 60 + + def expected(A: T.Buffer[4, "int32"]): + A[T.ramp(0, 1, 4)] = T.broadcast(0, 4) + + A[0] = 42 + A[1] = 60 + + +class TestSimplifyRampIndexRampValue(BaseBeforeAfter): + """Simplifications involving buffer loads with ramp indices""" + + propagate_knowns_to_prove_conditional = True + + def before(A: T.Buffer[4, "int32"]): + A[T.ramp(0, 1, 4)] = T.ramp(11, 1, 4) + + if A[0] == 11: + A[0] = 42 + + if A[1] == 12: + A[1] = 60 + + def expected(A: T.Buffer[4, "int32"]): + A[T.ramp(0, 1, 4)] = T.ramp(11, 1, 4) + + A[0] = 42 + A[1] = 60 + + +class TestSimplifyUsingPartiallyProvenBufferValueGather(BaseBeforeAfter): + """Propagate known buffer values in part of buffer. + + Even if a constraint can't be solved for all values in an + assignment, it may be provable in part of a buffer. Here, the + known 0 values in the padding of A produces known 0 values in the + padding of B. + """ + + transitively_prove_inequalities = True + propagate_knowns_to_prove_conditional = True + + def before(A: T.Buffer[24, "int32"], B: T.Buffer[24, "int32"], F: T.Buffer[3, "int32"]): + # A has non-zero values only in the range 3 <= i < 17 + for i in T.serial(24): + T.evaluate(T.assume(((3 <= i) and (i < 17)) or A[i] == 0)) + + # After convoluting with F, B has non-zero values only in the + # range 3 <= i < 19. 
+ for i in T.serial(24): + B[i] = 0 + for f in T.serial(3): + if 0 <= i - f: + B[i] = B[i] + A[i - f] * F[f] + + # Which means that this loop is unnecessary. It would be + # removed entirely in tir.transform.RemoveNoOp, but here we + # want to test that the simplification works as intended. + for i in T.serial(24): + if i < 3 or 19 <= i: + if B[i] != 0: + B[i] = 0 + + def expected(A: T.Buffer[24, "int32"], B: T.Buffer[24, "int32"], F: T.Buffer[3, "int32"]): + for i in T.serial(24): + T.evaluate(T.assume(((3 <= i) and (i < 17)) or A[i] == 0)) + + for i in T.serial(24): + B[i] = 0 + for f in T.serial(3): + if 0 <= i - f: + B[i] = B[i] + A[i - f] * F[f] + + for i in T.serial(24): + if i < 3 or 19 <= i: + T.evaluate(0) + + +class TestSimplifyUsingPartiallyProvenBufferValueScatter(BaseBeforeAfter): + """Propagate known buffer values in part of buffer. + + Like TestSimplifyUsingPartiallyProvenBufferValueGather, but the + compute loop is over the input buffer A, rather than the output + buffer B. + """ + + propagate_knowns_to_prove_conditional = True + + def before(A: T.Buffer[24, "int32"], B: T.Buffer[24, "int32"], F: T.Buffer[3, "int32"]): + # A has non-zero values only in the range 3 <= i < 17 + for i in T.serial(24): + T.evaluate(T.assume(((3 <= i) and (i < 17)) or A[i] == 0)) + + for i in T.serial(24): + B[i] = 0 + + # After convoluting with F, B has non-zero values only in the + # range 3 <= i < 19. + for i in T.serial(24): + for f in T.serial(3): + if i + f >= 0 and i + f < 24: + B[i + f] = B[i + f] + A[i] * F[f] + + # Which means that this loop is unnecessary. It actually gets + # removed in tir.transform.RemoveNoOp, but here we want to + # test that the simplification works as intended. 
+ for i in T.serial(24): + if i < 3 or 19 <= i: + if B[i] != 0: + B[i] = 0 + + def expected(A: T.Buffer[24, "int32"], B: T.Buffer[24, "int32"], F: T.Buffer[3, "int32"]): + for i in T.serial(24): + T.evaluate(T.assume(((3 <= i) and (i < 17)) or A[i] == 0)) + + for i in T.serial(24): + B[i] = 0 + + for i in T.serial(24): + for f in T.serial(3): + if i + f < 24: + B[i + f] = B[i + f] + A[i] * F[f] + + for i in T.serial(24): + if i < 3 or 19 <= i: + T.evaluate(0) + + +class TestSimplifyBufferStore(BaseBeforeAfter): + """Simplification using prior known""" + + propagate_knowns_to_simplify_expressions = True + + def before(A: T.Buffer[1, "int32"]): + A[0] = 5 + A[0] = A[0] + 7 + + def expected(A: T.Buffer[1, "int32"]): + A[0] = 5 + A[0] = 12 + + if __name__ == "__main__": tvm.testing.main()