From 2041d9c5117240d1e5d669378377e688b8ae1bdc Mon Sep 17 00:00:00 2001 From: wrongtest Date: Sat, 25 Dec 2021 12:13:09 +0800 Subject: [PATCH] [TIR] Affine utility support iter lowerbound and diagnostics (#9699) * Enable freevars, iter lowerbound and diagnostics in affine utility * fix lint issues and compare bug * update to use iter shift instead of itermark min for lowerbound * add testcase of fused iters sum with multiple lowerbounds * add more affine check testcases, fix bug for single iter and duplicate constraints on iter * add a newline to comment * forbidden predicate unmatch Co-authored-by: baoxinqi --- include/tvm/arith/iter_affine_map.h | 8 +- src/arith/int_set.cc | 4 +- src/arith/iter_affine_map.cc | 471 ++++++++++++++---- src/tir/schedule/analysis/analysis.cc | 4 +- tests/python/unittest/test_arith_intset.py | 87 ++++ .../unittest/test_arith_iter_affine_map.py | 200 ++++++++ .../unittest/test_tir_schedule_reorder.py | 5 +- .../unittest/test_tir_schedule_rfactor.py | 8 +- .../test_tir_schedule_state_cached_flags.py | 55 ++ 9 files changed, 720 insertions(+), 122 deletions(-) diff --git a/include/tvm/arith/iter_affine_map.h b/include/tvm/arith/iter_affine_map.h index 6c72cbeafdd4..22b4cd580e18 100644 --- a/include/tvm/arith/iter_affine_map.h +++ b/include/tvm/arith/iter_affine_map.h @@ -49,6 +49,7 @@ #define TVM_ARITH_ITER_AFFINE_MAP_H_ #include +#include #include #include @@ -275,13 +276,14 @@ class IterSumExpr : public IterMapExpr { * \param predicate The predicate constraints on the input iterators * \param require_bijective A boolean flag that indicates whether the mapping should be bijective. * \param analyzer Analyzer used to get context information. + * \param diag_ctx Diagnostic context. * * \return The detected pattern if a match exists, * otherwise return an empty array. */ Array DetectIterMap(const Array& indices, const Map& input_iters, const PrimExpr& predicate, bool require_bijective, - arith::Analyzer* analyzer); + arith::Analyzer* analyzer, DiagnosticContext diag_ctx); /*! * \brief Use IterVarMap detector to rewrite and simplify the indices * @@ -333,6 +335,7 @@ Map InverseAffineIterMap(const Array& iter_map, * \param predicate The predicate constraints on the input iterators * \param require_bijective A boolean flag that indicates whether the mapping should be bijective. * \param analyzer Analyzer used to get context information. + * \param diag_ctx Diagnostic context. * * \return The result list has length len(bindings) + 1 [0, len(bindings)): The iter map matching result. The inner list is of length 2. @@ -344,7 +347,8 @@ Map InverseAffineIterMap(const Array& iter_map, Array> SubspaceDivide(const Array& bindings, const Map& input_iters, const Array& sub_iters, const PrimExpr& predicate, - bool require_bijective, arith::Analyzer* analyzer); + bool require_bijective, arith::Analyzer* analyzer, + DiagnosticContext diag_ctx); } // namespace arith } // namespace tvm diff --git a/src/arith/int_set.cc b/src/arith/int_set.cc index e620e3bdcdec..55a1a5a1830e 100644 --- a/src/arith/int_set.cc +++ b/src/arith/int_set.cc @@ -835,9 +835,10 @@ Optional> EstimateRegionLowerBound(const Array& region, for (const Range& range : region) { affine_indices.push_back(range->min); } + DiagnosticContext diag_ctx(DiagnosticContext::Default(IRModule())); iter_sum_exprs = DetectIterMap( /*indices=*/affine_indices, /*input_iters=*/var_dom, - /*predicate=*/predicate, /*require_bijective=*/false, analyzer); + /*predicate=*/predicate, /*require_bijective=*/false, analyzer, diag_ctx); } if (iter_sum_exprs.empty()) { return NullOpt; @@ -857,6 +858,7 @@ Optional> EstimateRegionLowerBound(const Array& region, if (!analyzer->CanProve(range->extent >= split->scale)) { return NullOpt; } + const PrimExpr& base = sum_expr->base; // IterSplitExpr: (source // lower_factor) % extent * scale // where `(source // lower_factor) % extent` is within [0, extent - 1] diff --git a/src/arith/iter_affine_map.cc b/src/arith/iter_affine_map.cc index 34c35cecdce4..c9d4b1edc3a0 100644 --- a/src/arith/iter_affine_map.cc +++ b/src/arith/iter_affine_map.cc @@ -160,13 +160,22 @@ class IterMarkSplitCollector { } }; +/*! \brief Record form of IterMark(x, extent) + offset */ +struct IterMarkWithOffset { + IterMark mark; + PrimExpr offset{0}; + IterMarkWithOffset() {} + IterMarkWithOffset(IterMark mark, PrimExpr offset) : mark(mark), offset(offset) {} +}; + /*! \brief Rewriter to rewrite PrimExpr to IterMapExpr when possible */ class IterMapRewriter : public ExprMutator { public: using Parent = ExprMutator; - explicit IterMapRewriter(Analyzer* analyzer, const Map& input_iters) - : analyzer_(analyzer) { + explicit IterMapRewriter(Analyzer* analyzer, const Map& input_iters, + DiagnosticContext diag_ctx) + : analyzer_(analyzer), diag_ctx_(diag_ctx) { for (auto kv : input_iters) { const Var& var = kv.first; const Range& vrng = kv.second; @@ -192,9 +201,10 @@ class IterMapRewriter : public ExprMutator { return NormalizeToIterWithOffset(ToIterSumExpr(DirectMutate(expr))); } - IterSumExpr RewriteIterConstraint(const PrimExpr& expr, - const PrimExpr& predicate_induced_extent) { - return NormalizeToIterOnBoundExpr(ToIterSumExpr(DirectMutate(expr)), predicate_induced_extent); + IterSumExpr RewriteIterConstraint(const PrimExpr& expr, const PrimExpr& predicate_induced_min, + const PrimExpr& predicate_induced_max) { + return NormalizeToIterOnBoundExpr(ToIterSumExpr(DirectMutate(expr)), predicate_induced_min, + predicate_induced_max); } /*! @@ -224,13 +234,21 @@ class IterMapRewriter : public ExprMutator { // The splits do not overlap with each other. collector.Collect(bindings); for (const IterMark& mark : collector.visited_) { - if (TryNormalizeSplits(mark, collector.mark2splits_[mark], require_bijective).empty()) + if (TryNormalizeSplits(mark, collector.mark2splits_[mark], require_bijective).empty()) { + diag_ctx_.Emit(Diagnostic::Error(mark->source->span) + << "Fail to normalize iter mark splits: " << mark); return false; + } } if (require_bijective) { // all input marks must be visited for (const IterMark& mark : input_marks_) { - if (collector.visited_.count(mark) == 0) return false; + if (collector.visited_.count(mark) == 0) { + diag_ctx_.Emit(Diagnostic::Error(mark->source->span) + << "The mapping is not bijective because input iter mark " << mark + << " is not covered, "); + return false; + } } } return true; @@ -278,7 +296,7 @@ class IterMapRewriter : public ExprMutator { PrimExpr VisitExpr(const PrimExpr& input_expr) final { auto expr = ExprMutator::VisitExpr(input_expr); if (expr->IsInstance()) { - ++unresolved_count_; + Fail(Diagnostic::Error(input_expr->span)); } return expr; } @@ -328,6 +346,13 @@ class IterMapRewriter : public ExprMutator { } }; + void Fail(const Diagnostic& diagnostic) { + unresolved_count_++; + if (diag_ctx_.defined()) { + diag_ctx_.Emit(diagnostic); + } + } + // Internal analyzer Analyzer* analyzer_; // Counter to keep track of unresolved cases. @@ -336,8 +361,9 @@ class IterMapRewriter : public ExprMutator { std::unordered_map var_map_; // input iter marks std::vector input_marks_; - // The map for sum that maps flattened form to IterMark with normal form and extent - // Example: expr = i*9 + j*2 + k, i in [0, 4) j in [0, 5) k in [0, 2) + // The map for sum that maps flattened form to IterMark with normal form and extent (and possibly + // an extra offset) + // Example(1): expr = i*9 + j*2 + k, i in [0, 4) j in [0, 5) k in [0, 2) // predicate: j*2 + k < 9 // Then, flattened form = IterSum(IterSplit(i, scale=9), // IterSplit(j, scale=2), @@ -347,11 +373,24 @@ class IterMapRewriter : public ExprMutator { // IterSplit(k, scale=1)), // extent=9) // scale=1)) - std::unordered_map sum_fuse_map_; + // Example(2): expr = i*8 + j*2 + k, i in [0, 4) j in [0, 5) k in [0, 2) + // predicate: 1 <= j*2 + k < 9 + // Then, flattened form = IterSum(IterSplit(i, scale=8), + // IterSplit(j, scale=2), + // IterSplit(k, scale=1)) + // normal form = IterSum(IterSplit(i, scale=8), + // IterSplit(IterMark(IterSum(IterSplit(j, scale=2), + // IterSplit(k, scale=1), base=-1), + // extent=9-1) + // scale=1), + // base=1) + std::unordered_map sum_fuse_map_; // The map for sum that maps normal form to flattened form std::unordered_map flattened_map_; // The flattened forms of constrained iters std::vector constrained_iters_flattened_; + // Diagnostic context + DiagnosticContext diag_ctx_; /*! * \brief Look for a split in splits that is not used such that its lower_factor is smallest. @@ -407,19 +446,32 @@ class IterMapRewriter : public ExprMutator { } if (j == splits.size()) { // we do not allow incomplete split if the bindings should be bijective - if (require_bijective) return Array(); + if (require_bijective) { + diag_ctx_.Emit( + Diagnostic::Error(mark->source->span) + << "Do not allow incomplete split in bijective checking, expected_lower_factor=" + << expected_lower_factor); + return Array(); + } // look for the next split skipping this lower factor // For example, y \in [0, 24) has 3 splits [y / 6, (y / 2) % 6, y % 2] // It is valid to only have [y / 6, y % 2] if bijective is not required // We can skip (y / 2) % 6 j = SearchSkipLowerFactor(splits, used, expected_lower_factor); // split not found - if (j == splits.size()) return Array(); + if (j == splits.size()) { + diag_ctx_.Emit(Diagnostic::Error(mark->source->span) + << "Fail to find split skipping the lower factor in bijective-free " + "checking, expected_lower_factor=" + << expected_lower_factor); + return Array(); + } } used[j] = true; iters.push_back(splits[j]); expected_lower_factor = splits[j]->lower_factor * splits[j]->extent; } + // Case 1. bijective is required. // We check the extent we calculate is consistent with the extent of the mark // Case 2. bijective is not required. @@ -427,42 +479,73 @@ class IterMapRewriter : public ExprMutator { // For example, y \in [0, 24) [(y / 2) % 6, y % 2] is valid, but y \in [0, 25) is not. if ((require_bijective && !analyzer_->CanProveEqual(expected_lower_factor, mark->extent)) || (!require_bijective && !CanProveDivisible(mark->extent, expected_lower_factor))) { + diag_ctx_.Emit(Diagnostic::Error(mark->source->span) + << "Mark extent of " << mark + << " is not compatible with expected_lower_factor=" << expected_lower_factor); return Array(); } return Array(iters.rbegin(), iters.rend()); } /*! - * \brief Normalize the left hand side of iter constraint(expr < predicate_induced_extent) - * \param expr The left hand side of iter constraint. - * \param predicate_induced_extent Extent from iter constraint. + * \brief Normalize the iter expression with constraint (min <= expr < max) + * \param expr The iter expression. + * \param predicate_induced_min Closed lower bound from iter constraint, maybe undefined. + * \param predicate_induced_max Open upper bound from iter constraint, maybe undefined. * \return The Normalized expression. */ - IterSumExpr NormalizeToIterOnBoundExpr(IterSumExpr expr, - const PrimExpr& predicate_induced_extent) { - // We are normalizing the left hand side of iter constraint(iter < predicate_induced_extent) - Optional opt = TryFuseIters(expr); + IterSumExpr NormalizeToIterOnBoundExpr(IterSumExpr expr, PrimExpr predicate_induced_min, + PrimExpr predicate_induced_max) { + // normalize to zero base + PrimExpr base = expr->base; + if (!is_zero(base)) { + expr.CopyOnWrite()->base = 0; + if (predicate_induced_min.defined()) predicate_induced_min = predicate_induced_min - base; + if (predicate_induced_max.defined()) predicate_induced_max = predicate_induced_max - base; + } + Optional opt = TryFuseIters(expr); + ICHECK(!opt.defined() || opt.value()->args.size() == 1); // scale should be 1 - if (opt.defined() && is_one(opt.value()->scale)) { - IterSumExpr sum = Downcast(opt.value()->source->source); + if (opt.defined() && is_one(opt.value()->args[0]->scale)) { + const IterSplitExpr split = opt.value()->args[0]; + IterSumExpr structured_form = Downcast(split->source->source); // get the flattened form - auto it = flattened_map_.find(sum); + auto it = flattened_map_.find(structured_form); ICHECK(it != flattened_map_.end()); IterSumExpr flattened_form = it->second; - // get the mark + // get the mark and offset of the structured_form auto it_mark = sum_fuse_map_.find(flattened_form); ICHECK(it_mark != sum_fuse_map_.end()); - IterMark mark = it_mark->second; - mark.CopyOnWrite()->extent = min(predicate_induced_extent, mark->extent); - // update the bound of the lhs based on predicate_induced_extent - sum_fuse_map_[flattened_form] = mark; + IterMark mark = it_mark->second.mark; + PrimExpr mark_offset = it_mark->second.offset; + PrimExpr iter_min = mark_offset; + PrimExpr iter_max = iter_min + mark->extent; + if (predicate_induced_min.defined()) { + iter_min = max(predicate_induced_min, iter_min); + } + if (predicate_induced_max.defined()) { + iter_max = min(predicate_induced_max, iter_max); + } + if (!is_zero(iter_min)) { + // structured form's offset should be updated + flattened_map_.erase(structured_form); + structured_form.CopyOnWrite()->base = -iter_min; + mark.CopyOnWrite()->source = structured_form; + flattened_map_[structured_form] = flattened_form; + } + mark.CopyOnWrite()->extent = iter_max - iter_min; + sum_fuse_map_[flattened_form] = {mark, iter_min}; + // we need to note down the flattened form of constrained iterators // to check the validity of constraints, see also CheckConstraints() constrained_iters_flattened_.push_back(flattened_form); - expr.CopyOnWrite()->args = Array({opt.value()}); + expr.CopyOnWrite()->args = Array({split}); + expr.CopyOnWrite()->base = base + iter_min; return expr; } - ++unresolved_count_; + Fail(Diagnostic::Error(expr->span) + << "Fail to normalize " << expr << " with predicate bound [" << predicate_induced_min + << ", " << predicate_induced_max << ")"); return expr; } @@ -473,16 +556,12 @@ class IterMapRewriter : public ExprMutator { */ IterSumExpr NormalizeToIterWithOffset(IterSumExpr expr) { // We are normalizing a regular iter - if (expr->args.size() <= 1) return expr; - PrimExpr base = expr->base; - expr.CopyOnWrite()->base = make_zero(expr->dtype); - Optional opt = TryFuseIters(expr); - expr.CopyOnWrite()->base = base; + if (expr->args.size() < 1) return expr; + Optional opt = TryFuseIters(expr); if (opt.defined()) { - expr.CopyOnWrite()->args = Array({opt.value()}); - return expr; + return opt.value(); } else { - ++unresolved_count_; + Fail(Diagnostic::Error(expr->span) << "Fail to normalize iter sum with offset: " << expr); return expr; } } @@ -504,17 +583,16 @@ class IterMapRewriter : public ExprMutator { } /*! - * \brief IterSum = x1*c1 + x2*c2 + ... + xn*cn - * = (x1*s1 + x2*s2 + ... + xn)*cn - * = y*cn (IterMark y => x1*s1 + x2*s2 + ... + xn) - * = [IterSplit(IterMark(y), scale=cn)] - * return a corresponding IterSplitExpr if needed. + * \brief IterSum = x1*c1 + x2*c2 + ... + xn*cn + base + * = (x1*s1 + x2*s2 + ... + xn)*cn + base + * = y*cn (IterMark y => x1*s1 + x2*s2 + ... + xn) + base + * = [IterSplit(IterMark(y), scale=cn)] + base + * return a corresponding IterSumExpr with extra offset if needed. * Try to normalize IterSum into a fused IterMark * \param expr The input sum. - * \return The split with the fused IterMark if succeed. + * \return The sum with the fused IterMark and extra offset if succeed. */ - Optional TryFuseIters(IterSumExpr expr) { - if (!is_zero(expr->base)) return NullOpt; + Optional TryFuseIters(IterSumExpr expr) { // select the iterators in order std::vector visited(expr->args.size(), false); std::vector flattened_iters, grouped_iters; @@ -530,8 +608,13 @@ class IterMapRewriter : public ExprMutator { } } } - if (!base_scale) return NullOpt; + if (!base_scale) { + diag_ctx_.Emit(Diagnostic::Error(expr->span) + << "Fuse iters failed, can not find a valid base scale"); + return NullOpt; + } // check if it can be remapped into a fused pattern. + PrimExpr expected_extra_base = 0; PrimExpr expected_scale = base_scale.value(); for (size_t i = 0; i < expr->args.size();) { // find j such that expr->args[j] has expected scale @@ -539,7 +622,11 @@ class IterMapRewriter : public ExprMutator { for (; j < expr->args.size(); ++j) { if (!visited[j] && analyzer_->CanProveEqual(expr->args[j]->scale, expected_scale)) break; } - if (j == expr->args.size()) return NullOpt; + if (j == expr->args.size()) { + diag_ctx_.Emit(Diagnostic::Error(expr->span) + << "Fuse iters failed, can not find expected scale " << expected_scale); + return NullOpt; + } // look for the longest constrained iter started from expr->args[j] // Example: expr = i*9 + j*2 + k, i in [0, 4) j in [0, 5) k in [0, 2) // predicate: j*2 + k < 9 @@ -569,15 +656,21 @@ class IterMapRewriter : public ExprMutator { break; } } - if (k == expr->args.size()) return NullOpt; + if (k == expr->args.size()) { + diag_ctx_.Emit(Diagnostic::Error(expr->span) + << "Fuse iters failed, can not find flattened iter match constraint " + << constraint_to_match.value()); + return NullOpt; + } visited[k] = true; flattened_iters.push_back(expr->args[k]); } auto iter = sum_fuse_map_.find(constraint_to_match.value()); ICHECK(iter != sum_fuse_map_.end()); - IterMark iter_matched = iter->second; - grouped_iters.emplace_back(iter_matched, expected_scale); - expected_scale *= iter_matched->extent; + const IterMarkWithOffset& iter_matched = iter->second; + grouped_iters.emplace_back(iter_matched.mark, expected_scale); + expected_extra_base += iter_matched.offset * expected_scale; + expected_scale *= iter_matched.mark->extent; // move forward i += constraint_to_match.value()->args.size(); } else { @@ -594,18 +687,28 @@ class IterMapRewriter : public ExprMutator { IterSumExpr structured_form = expr, flattened_form = expr; flattened_form.CopyOnWrite()->args = Array(flattened_iters.rbegin(), flattened_iters.rend()); + flattened_form.CopyOnWrite()->base = 0; structured_form.CopyOnWrite()->args = Array(grouped_iters.rbegin(), grouped_iters.rend()); + structured_form.CopyOnWrite()->base = 0; auto it = sum_fuse_map_.find(flattened_form); if (it != sum_fuse_map_.end()) { // old iter - return IterSplitExpr(it->second, base_scale.value()); + if (!analyzer_->CanProveEqual(expected_extra_base, it->second.offset * base_scale.value())) { + // the extra offset is not consistent with old + diag_ctx_.Emit(Diagnostic::Error(expr->span) + << "Fuse iters failed, the extra offset is not consistent with old"); + return NullOpt; + } + return IterSumExpr({IterSplitExpr(it->second.mark, base_scale.value())}, + expr->base + expected_extra_base); } else { // new iter, form a new mark IterMark mark = IterMark(structured_form, div(expected_scale, base_scale.value())); - sum_fuse_map_[flattened_form] = mark; + sum_fuse_map_[flattened_form] = IterMarkWithOffset(mark, 0); flattened_map_[structured_form] = flattened_form; - return IterSplitExpr(mark, base_scale.value()); + return IterSumExpr({IterSplitExpr(mark, base_scale.value())}, + expr->base + expected_extra_base); } } @@ -667,34 +770,126 @@ class IterMapRewriter : public ExprMutator { struct IterConstraint { // The expr of the iter PrimExpr iter; + // The expr of the lower_bound + PrimExpr lower_bound; // The expr of the upper_bound PrimExpr upper_bound; // The size of the iter, which is the number of nodes size_t expr_size = 0; - IterConstraint(PrimExpr iter, PrimExpr upper_bound, size_t size) - : iter(std::move(iter)), upper_bound(std::move(upper_bound)), expr_size(size) {} + IterConstraint(PrimExpr iter, PrimExpr lower_bound, PrimExpr upper_bound, size_t size) + : iter(std::move(iter)), + lower_bound(std::move(lower_bound)), + upper_bound(std::move(upper_bound)), + expr_size(size) {} }; /*! * \brief Split the predicate into `(a < b) && (c < d) && ...` * \param pred The predicate to be split. - * \return A list of pairs, each element of which are lhs and rhs of the '<' sign, - * empty if the split failed. + * \return A list of IterConstraint, empty if the split failed. */ -std::vector MatchUpperBoundConstraints(PrimExpr pred) { +std::vector MatchBoundConstraints(PrimExpr pred, + const Map& input_iters) { std::vector result; arith::PVar lhs, rhs, rest; for (;;) { - if ((rest && (lhs < rhs)).Match(pred)) { - result.emplace_back(lhs.Eval(), rhs.Eval(), 0); - pred = rest.Eval(); + // try extract comparisions + bool is_finish = false; + bool is_greater = false; + bool is_equal = false; + if ((rest && (lhs < rhs)).Match(pred) || ((lhs < rhs) && rest).Match(pred)) { + // pass } else if ((lhs < rhs).Match(pred)) { - result.emplace_back(lhs.Eval(), rhs.Eval(), 0); - break; + is_finish = true; + } else if ((rest && (lhs <= rhs)).Match(pred) || ((lhs <= rhs) && rest).Match(pred)) { + is_equal = true; + } else if ((lhs <= rhs).Match(pred)) { + is_equal = true; + is_finish = true; + } else if ((rest && (lhs > rhs)).Match(pred) || ((lhs > rhs) && rest).Match(pred)) { + is_greater = true; + } else if ((lhs > rhs).Match(pred)) { + is_greater = true; + is_finish = true; + } else if ((rest && (lhs >= rhs)).Match(pred) || ((lhs >= rhs) && rest).Match(pred)) { + is_greater = true; + is_equal = true; + } else if ((lhs >= rhs).Match(pred)) { + is_greater = true; + is_equal = true; + is_finish = true; } else { return std::vector(); } + PrimExpr lhs_expr = lhs.Eval(); + PrimExpr rhs_expr = rhs.Eval(); + // we only accept predicate of integers + if (!((lhs_expr->dtype.is_int() || lhs_expr->dtype.is_uint()) && + (rhs_expr->dtype.is_int() || rhs_expr->dtype.is_uint()))) { + return std::vector(); + } + // determine iter and bound, if we can not distinguish them simply, + // try divide (lhs - rhs) into itervar aware and itervar free parts + auto f_use_itervar = [&input_iters](const VarNode* v) { + return input_iters.count(GetRef(v)); + }; + bool bound_at_left; + if (is_const_int(lhs_expr) || !UsesVar(lhs_expr, f_use_itervar)) { + bound_at_left = true; + } else if (is_const_int(rhs_expr) || !UsesVar(rhs_expr, f_use_itervar)) { + bound_at_left = false; + } else { + bound_at_left = false; // accumulate bound to rhs + PrimExpr sum_parts = lhs_expr - rhs_expr; + lhs_expr = 0; + rhs_expr = 0; + std::function f_extract = + [&lhs_expr, &rhs_expr, f_use_itervar, &f_extract](const PrimExpr& part, bool sign) { + if (const AddNode* add = part.as()) { + f_extract(add->a, sign); + f_extract(add->b, sign); + } else if (const SubNode* sub = part.as()) { + f_extract(sub->a, sign); + f_extract(sub->b, !sign); + } else if (UsesVar(part, f_use_itervar)) { + lhs_expr = sign ? lhs_expr + part : lhs_expr - part; + } else { + rhs_expr = sign ? rhs_expr - part : rhs_expr + part; + } + }; + f_extract(sum_parts, true); + arith::Analyzer analyzer; + lhs_expr = analyzer.Simplify(lhs_expr); + rhs_expr = analyzer.Simplify(rhs_expr); + } + PrimExpr lower_bound, upper_bound, iter; + if (is_greater) { + if (bound_at_left) { + // bound > iter + upper_bound = is_equal ? lhs_expr + 1 : lhs_expr; + iter = rhs_expr; + } else { + // iter > bound + lower_bound = is_equal ? rhs_expr : rhs_expr + 1; + iter = lhs_expr; + } + } else { + if (bound_at_left) { + // bound < iter + lower_bound = is_equal ? lhs_expr : lhs_expr + 1; + iter = rhs_expr; + } else { + // iter < bound + upper_bound = is_equal ? rhs_expr + 1 : rhs_expr; + iter = lhs_expr; + } + } + result.emplace_back(iter, lower_bound, upper_bound, 0); + if (is_finish) { + break; + } + pred = rest.Eval(); } return result; } @@ -711,14 +906,17 @@ bool IterRangeSanityCheck(const Map& iter_ranges) { Array DetectIterMap(const Array& indices, const Map& input_iters, const PrimExpr& predicate, bool require_bijective, - arith::Analyzer* analyzer) { + arith::Analyzer* analyzer, DiagnosticContext diag_ctx) { // Overall detection algorithm is divided into two steps: // - Step0: IterMapRewriter rewrites the expression to use IterMapExpr patterns. // - Step1: IterIndependenceChecker checks if the iterator are independent. - if (!IterRangeSanityCheck(input_iters)) return Array(); - std::vector constraints = MatchUpperBoundConstraints(predicate); - if (!is_one(predicate) && constraints.empty()) return Array(); + std::vector constraints = MatchBoundConstraints(predicate, input_iters); + if (!is_one(predicate) && constraints.empty()) { + diag_ctx.Emit(Diagnostic::Error(predicate->span) + << "Fail to collect constraints from iteration predicate: " << predicate); + return Array(); + } // We have to make sure when we visit an iterator, all the constraints related with its successors // in the iter var graph has been visited, where the expression of this iterator will contain the @@ -731,13 +929,17 @@ Array DetectIterMap(const Array& indices, const Map(); } - if (!rewriter.CheckConstraints()) return Array(); + if (!rewriter.CheckConstraints()) { + diag_ctx.Emit(Diagnostic::Error(predicate->span) + << "Illegal iteration constraints: " << predicate); + return Array(); + } // Step0.1: rewrite indices Array results; for (PrimExpr value : indices) { @@ -745,7 +947,10 @@ Array DetectIterMap(const Array& indices, const Map(); } // Step1: IterIndependenceChecker checks if the iterator are independent. - if (!rewriter.CheckMapping(results, require_bijective)) return Array(); + if (!rewriter.CheckMapping(results, require_bijective)) { + diag_ctx.Emit(Diagnostic::Error(predicate->span) << "Iterators are not independent"); + return Array(); + } return results; } @@ -754,7 +959,8 @@ TVM_REGISTER_GLOBAL("arith.DetectIterMap") .set_body_typed([](const Array& indices, const Map& input_iters, const PrimExpr& input_pred, bool is_bijective) { arith::Analyzer ana; - return DetectIterMap(indices, input_iters, input_pred, is_bijective, &ana); + DiagnosticContext diag_ctx(DiagnosticContext::Default(IRModule())); + return DetectIterMap(indices, input_iters, input_pred, is_bijective, &ana, diag_ctx); }); PrimExpr IterMapRewriter::VisitExpr_(const VarNode* op) { @@ -768,7 +974,6 @@ PrimExpr IterMapRewriter::VisitExpr_(const AddNode* op) { if (!IsIndexType(op->dtype)) { return Parent::VisitExpr_(op); } - PrimExpr a = this->DirectMutate(op->a); PrimExpr b = this->DirectMutate(op->b); @@ -858,7 +1063,7 @@ PrimExpr IterMapRewriter::VisitExpr_(const MulNode* op) { if (a->IsInstance() && b->IsInstance()) { // cannot multiply two iterators, mark as unresolved. - ++unresolved_count_; + Fail(Diagnostic::Error(op->span) << "Cannot multiply two iterators: " << GetRef(op)); return GetRef(op); } @@ -894,7 +1099,9 @@ PrimExpr IterMapRewriter::SplitFloorDivConst(IterSplitExpr lhs, PrimExpr rhs, lhs.CopyOnWrite()->scale = make_const(rhs->dtype, 1); } else { // mark as unresolved. - ++unresolved_count_; + Fail(Diagnostic::Error(orig->span) + << "Can not prove floordiv rhs " << rhs << " divisible by lhs scale " << lhs->scale + << ", lhs=" << lhs); return orig; } } @@ -916,7 +1123,8 @@ PrimExpr IterMapRewriter::SplitFloorDivConst(IterSplitExpr lhs, PrimExpr rhs, return std::move(lhs); } else { // mark as unresolved. - ++unresolved_count_; + Fail(Diagnostic::Error(orig->span) + << "Can not prove floordiv lhs extent " << lhs->extent << " divisible by rhs " << rhs); return orig; } } @@ -944,16 +1152,24 @@ PrimExpr IterMapRewriter::VisitExpr_(const FloorDivNode* op) { if (b->IsInstance()) { // cannot divide an iterator, mark as unresolved. - ++unresolved_count_; + Fail(Diagnostic::Error(op->span) << "Cannot divide an iterator: " << GetRef(op)); return GetRef(op); } if (a->IsInstance()) { IterSumExpr ret = Downcast(a); - if (Optional opt = TryFuseIters(ret)) { - return SplitFloorDivConst(opt.value(), b, GetRef(op)); + if (Optional opt = TryFuseIters(ret)) { + IterSumExpr sum = opt.value(); + if (!is_zero(sum->base)) { + Fail(Diagnostic::Error(op->span) + << "Fuse IterSumExpr " << ret + << " failed, cannot floordiv an IterSumExpr with nonzero base"); + return GetRef(op); + } + ICHECK_EQ(sum->args.size(), 1U); + return SplitFloorDivConst(sum->args[0], b, GetRef(op)); } else { - ++unresolved_count_; + Fail(Diagnostic::Error(op->span) << "Fuse IterSumExpr " << ret << " failed"); return GetRef(op); } } else { @@ -977,7 +1193,8 @@ PrimExpr IterMapRewriter::SplitFloorModConst(IterSplitExpr lhs, PrimExpr rhs, rhs = floordiv(rhs, lhs->scale); } else { // mark as unresolved. - ++unresolved_count_; + Fail(Diagnostic::Error(orig->span) << "Can not prove floormod rhs " << rhs + << " divisible by " << lhs->scale << ", lhs=" << lhs); return orig; } } @@ -991,7 +1208,8 @@ PrimExpr IterMapRewriter::SplitFloorModConst(IterSplitExpr lhs, PrimExpr rhs, return std::move(lhs); } else { // mark as unresolved. - ++unresolved_count_; + Fail(Diagnostic::Error(orig->span) + << "Can not prove floormod lhs extent " << lhs->extent << " divisible by rhs " << rhs); return orig; } } @@ -1019,16 +1237,23 @@ PrimExpr IterMapRewriter::VisitExpr_(const FloorModNode* op) { if (b->IsInstance()) { // cannot mod an iterator, mark as unresolved. - ++unresolved_count_; + Fail(Diagnostic::Error(op->span) << "Cannot mod an iterator: " << GetRef(op)); return GetRef(op); } if (a->IsInstance()) { IterSumExpr ret = Downcast(a); - if (Optional opt = TryFuseIters(ret)) { - return SplitFloorModConst(opt.value(), b, GetRef(op)); + if (Optional opt = TryFuseIters(ret)) { + IterSumExpr sum = opt.value(); + if (!is_zero(sum->base)) { + Fail(Diagnostic::Error(op->span) + << "Fuse IterSumExpr " << ret + << " failed, cannot floormod an IterSumExpr with nonzero base"); + return GetRef(op); + } + return SplitFloorModConst(sum->args[0], b, GetRef(op)); } else { - ++unresolved_count_; + Fail(Diagnostic::Error(op->span) << "Fail to fuse iters of " << ret); return GetRef(op); } } else { @@ -1039,19 +1264,21 @@ PrimExpr IterMapRewriter::VisitExpr_(const FloorModNode* op) { } /*! * \brief Given an IterVarMapExpr, transform it to normal PrimExpr. */ -class IterMapToExprNormalizer { +class IterMapToExprNormalizer : public ExprMutator { public: explicit IterMapToExprNormalizer(Analyzer* analyzer) : analyzer_(analyzer) {} - PrimExpr Convert(const IterMapExpr& expr) { + PrimExpr Convert(const IterMapExpr& expr) { return VisitExpr(expr); } + + private: + /*! \brief Override VisitExpr for iter expr type processing */ + PrimExpr VisitExpr(const PrimExpr& expr) override { if (const auto* op = expr.as()) { return ConvertIterSplitExpr(GetRef(op)); } else if (const auto* op = expr.as()) { return ConvertIterSumExpr(GetRef(op)); } else { - ICHECK(expr.defined()); - LOG(FATAL) << "Unknown IterMapExpr type " << expr->GetTypeKey(); - return 0; + return ExprMutator::VisitExpr(expr); } } @@ -1071,7 +1298,7 @@ class IterMapToExprNormalizer { } else if (const auto* op = expr->source->source.as()) { source = ConvertIterSumExpr(GetRef(op)); } else { - LOG(FATAL) << "Unexpected source of IterSplitExpr"; + source = VisitExpr(expr->source->source); } if (analyzer_->CanProve(expr->extent == expr->source->extent) && is_one(expr->lower_factor)) { return source * expr->scale; @@ -1100,8 +1327,9 @@ Array IterMapSimplify(const Array& indices, const Map rewrite = - DetectIterMap(indices, input_iters, input_pred, require_bijective, &analyzer); + DetectIterMap(indices, input_iters, input_pred, require_bijective, &analyzer, diag_ctx); if (rewrite.empty()) { return indices; } @@ -1128,8 +1356,9 @@ Array IterMapSimplify(const Array& indices, const Map& sub_iters) - : analyzer_(analyzer), collector_(collector), sub_iters_(sub_iters) {} + const std::unordered_set& sub_iters, + DiagnosticContext diag_ctx) + : analyzer_(analyzer), collector_(collector), sub_iters_(sub_iters), diag_ctx_(diag_ctx) {} size_t unresolved_count() const { return unresolved_count_; } @@ -1190,7 +1419,10 @@ class SubspaceDivider { return DivisionResult(IterSumExpr({}, 0), 1, IterSumExpr({}, expr->base), 1); } else if (expr->args.size() == 1) { // arg + base, if arg=Y*E(X)+X, then arg+base = Y*E(X)+(X+base) - if (!is_one(expr->args[0]->scale)) return Fail(); + if (!is_one(expr->args[0]->scale)) { + return Fail(Diagnostic::Error(expr->span) + << "Expect split scale be 1, got " << expr->args[0]->scale); + } DivisionResult res = DivideIterSplitExpr(expr->args[0]); if (!is_zero(expr->base)) res = AddBase(res, expr->base); return res; @@ -1208,7 +1440,9 @@ class SubspaceDivider { DivisionResult arg_division = DivideIterSplitExpr(arg); IterSplitExpr new_arg; if (arg_division.IsInner()) { - if (!inner) return Fail(); + if (!inner) + return Fail(Diagnostic::Error(expr->span) + << "Current division is inner but outer division exists for previous args"); new_arg = arg_division.GetInnerAsSplit(); inner_args.push_back(new_arg); inner = true; @@ -1217,11 +1451,13 @@ class SubspaceDivider { outer_args.push_back(new_arg); inner = false; } else { - return Fail(); + return Fail(Diagnostic::Error(expr->span) + << "Division of " << arg << " is neither inner nor outer"); } extent *= new_arg->extent; } - if (!scale_is_one) return Fail(); + if (!scale_is_one) + return Fail(Diagnostic::Error(expr->span) << "Expect all iter sum arg's scale be 1"); bool need_predicate = !analyzer_->CanProveEqual(extent, mark_extent); const IterMark& outer_mark = MarkFromArgsAndBase(outer_args, 0); const IterMark& inner_mark = MarkFromArgsAndBase(inner_args, expr->base); @@ -1240,7 +1476,8 @@ class SubspaceDivider { inner_preds_ = inner_preds_ && (converter.Convert(inner_source) < mark_extent); return DivisionResult::Inner(inner_source, mark_extent); } else { - return Fail(); + return Fail(Diagnostic::Error(expr->span) + << "Either inner or outer args should exists if need predicate: " << expr); } } return DivisionResult(outer_source, outer_mark->extent, inner_source, inner_mark->extent); @@ -1250,8 +1487,11 @@ class SubspaceDivider { PrimExpr GetInnerPreds() const { return inner_preds_; } private: - DivisionResult Fail() { + DivisionResult Fail(const Diagnostic& diagnostic) { unresolved_count_++; + if (diag_ctx_.defined()) { + diag_ctx_.Emit(diagnostic); + } return DivisionResult(IterSumExpr({}, 0), 0, IterSumExpr({}, 0), 0); } @@ -1330,7 +1570,10 @@ class SubspaceDivider { if (!used[j] && analyzer_->CanProveEqual(splits[j]->lower_factor, expected_lower_factor)) break; } - if (j == splits.size()) return Fail(); + if (j == splits.size()) + return Fail(Diagnostic::Error(expr->span) + << "Can not find expected lower factor " << expected_lower_factor + << " in splits of " << expr->source); used[j] = true; if (!encountered_boundary) { inner_iters.push_back(splits[j]); @@ -1341,7 +1584,9 @@ class SubspaceDivider { if (analyzer_->CanProveEqual(expected_lower_factor, mark_division.inner_extent)) encountered_boundary = true; } - if (!encountered_boundary) return Fail(); + if (!encountered_boundary) + return Fail(Diagnostic::Error(expr->span) + << "Can not find inner/outer boundary of " << expr); for (const IterSplitExpr& inner_iter : inner_iters) { IterSplitExpr new_iter = inner_iter; new_iter.CopyOnWrite()->source = inner_mark; @@ -1355,7 +1600,8 @@ class SubspaceDivider { split_map_.emplace(outer_iter, DivisionResult::Outer(new_iter, outer_iter->extent)); } } else { - return Fail(); + return Fail(Diagnostic::Error(expr->span) + << "Source expr to divide is neither var nor IterSumExpr"); } return split_map_.at(expr); } @@ -1371,15 +1617,18 @@ class SubspaceDivider { std::unordered_map split_map_; // predicate of outer space and inner space; PrimExpr outer_preds_{Bool(true)}, inner_preds_{Bool(true)}; + // diagnostic context + DiagnosticContext diag_ctx_; }; Array> SubspaceDivide(const Array& bindings, const Map& input_iters, const Array& sub_iters, const PrimExpr& predicate, - bool require_bijective, arith::Analyzer* analyzer) { + bool require_bijective, arith::Analyzer* analyzer, + DiagnosticContext diag_ctx) { if (!IterRangeSanityCheck(input_iters)) return Array>(); const Array& maps = - DetectIterMap(bindings, input_iters, predicate, require_bijective, analyzer); + DetectIterMap(bindings, input_iters, predicate, require_bijective, analyzer, diag_ctx); if (maps.empty()) return {}; std::unordered_set inner_iter_set; @@ -1389,7 +1638,7 @@ Array> SubspaceDivide(const Array& bindings, IterMarkSplitCollector collector; collector.Collect(maps); - SubspaceDivider subspace_divider(analyzer, collector, inner_iter_set); + SubspaceDivider subspace_divider(analyzer, collector, inner_iter_set, diag_ctx); std::vector> results; for (const IterSumExpr& expr : maps) { @@ -1409,7 +1658,9 @@ TVM_REGISTER_GLOBAL("arith.SubspaceDivide") const Array& sub_iters, const PrimExpr& predicate, bool require_bijective) { arith::Analyzer ana; - return SubspaceDivide(bindings, root_iters, sub_iters, predicate, require_bijective, &ana); + DiagnosticContext diag_ctx(DiagnosticContext::Default(IRModule())); + return SubspaceDivide(bindings, root_iters, sub_iters, predicate, require_bijective, &ana, + diag_ctx); }); class InverseAffineIterMapTransformer { diff --git a/src/tir/schedule/analysis/analysis.cc b/src/tir/schedule/analysis/analysis.cc index 6d744a66b498..0a7d57effd0d 100644 --- a/src/tir/schedule/analysis/analysis.cc +++ b/src/tir/schedule/analysis/analysis.cc @@ -415,12 +415,14 @@ bool IsAffineBinding(const BlockRealize& realize, const Map& loop_va if (loop_var_ranges.empty()) { return true; } + DiagnosticContext diag_ctx(DiagnosticContext::Default(IRModule())); Array results = arith::DetectIterMap( /*indices=*/realize->iter_values, /*input_iters=*/loop_var_ranges, /*predicate=*/realize->predicate, /*require_bijective=*/false, - /*analyzer=*/analyzer); + /*analyzer=*/analyzer, + /*diag_ctx*/ diag_ctx); if (results.empty()) { return false; } diff --git a/tests/python/unittest/test_arith_intset.py b/tests/python/unittest/test_arith_intset.py index 618fd554c6a7..b40f3c9f56ea 100644 --- a/tests/python/unittest/test_arith_intset.py +++ b/tests/python/unittest/test_arith_intset.py @@ -16,6 +16,8 @@ # under the License. import tvm from tvm import te +from tvm import tir +from tvm.ir.base import structural_equal class IntSetChecker: @@ -233,6 +235,90 @@ def test_region_lower_bound_negative_scale(): assert int_set_1.max_value.value == 35 +def test_region_lower_bound_for_non_perfect_tile(): + h1 = tvm.tir.Var("h1", "int32") + h2 = tvm.tir.Var("h2", "int32") + h3 = tvm.tir.Var("h3", "int32") + analyzer = tvm.arith.Analyzer() + + def do_test_point_access(point, predicates, var_dom, expect): + regions = tvm.arith.estimate_region_lower_bound( + region=[ + tvm.ir.Range.from_min_extent(min_value=point, extent=1), + ], + var_dom=var_dom, + predicate=tvm.tir.all(*predicates), + ) + if expect is None: # expect a failure + assert regions is None + else: + assert len(regions) == 1 + for binding, expect_min, expect_max in expect: + min_diff = expect_min - regions[0].min_value + assert analyzer.simplify(tir.stmt_functor.substitute(min_diff, binding), 3) == 0 + max_diff = expect_max - regions[0].max_value + assert analyzer.simplify(tir.stmt_functor.substitute(max_diff, binding), 3) == 0 + + # non-uniform tiling, single inner variable + # h3 == 0: region is [1, 9] + # 0 < h3 <= 26: region is [h3 * 8, h3 * 8 + 9] + # h3 > 26: region is [h3 * 8, 223] + do_test_point_access( + point=h3 * 8 + h2, + predicates=[1 <= h3 * 8 + h2, h3 * 8 + h2 < 224], + var_dom={ + h2: tvm.ir.Range(begin=0, end=10), + }, + expect=[ + ( + {}, + tvm.tir.max(h3 * 8, 1), + tvm.tir.max(h3 * 8, 1) + - tvm.tir.max(h3 * 8, 214) + - tvm.tir.max(1 - h3 * 8, 0) + + 223, + ), + ({h3: 0}, 1, 9), + ({h3: 10}, h3 * 8, h3 * 8 + 9), + ({h3: 27}, h3 * 8, 223), + ], + ) + + # non-uniform tiling, two inner variables + do_test_point_access( + point=h3 * 8 + h2 * 5 + h1, + predicates=[1 <= h3 * 8 + h2 * 5 + h1, h3 * 8 + h2 * 5 + h1 < 224], + var_dom={ + h2: tvm.ir.Range(begin=0, end=2), + h1: tvm.ir.Range(begin=0, end=5), + }, + expect=[ + ( + {}, + tvm.tir.max(h3 * 8, 1), + tvm.tir.max(h3 * 8, 1) + - tvm.tir.max(h3 * 8, 214) + - tvm.tir.max(1 - h3 * 8, 0) + + 223, + ), + ({h3: 0}, 1, 9), + ({h3: 10}, h3 * 8, h3 * 8 + 9), + ({h3: 27}, h3 * 8, 223), + ], + ) + + # should fail on incompatible predicates + do_test_point_access( + point=h3 * 8 + h2 * 5 + h1, + predicates=[1 <= h3 * 8 + h2 * 5 + h1, h3 * 8 + h1 * 2 + h2 < 224], + var_dom={ + h2: tvm.ir.Range(begin=0, end=2), + h1: tvm.ir.Range(begin=0, end=5), + }, + expect=None, + ) + + def test_union_lower_bound(): neg_inf = tvm.arith.int_set.neg_inf() pos_inf = tvm.arith.int_set.pos_inf() @@ -257,4 +343,5 @@ def test_union_lower_bound(): test_region_lower_bound_split_predicate() test_region_lower_bound_multiple_variables() test_region_lower_bound_negative_scale() + test_region_lower_bound_for_non_perfect_tile() test_union_lower_bound() diff --git a/tests/python/unittest/test_arith_iter_affine_map.py b/tests/python/unittest/test_arith_iter_affine_map.py index c307034c04c9..6b3c29592eb6 100644 --- a/tests/python/unittest/test_arith_iter_affine_map.py +++ b/tests/python/unittest/test_arith_iter_affine_map.py @@ -199,8 +199,171 @@ def test_predicate(): x = tvm.tir.Var("x", "int32"), 13 y = tvm.tir.Var("y", "int32"), 10 + # available contraints + # upper bound only res = tvm.arith.detect_iter_map([x[0] * 10 + y[0]], var_dom([x, y]), x[0] * 10 + y[0] < 128) + assert len(res) == 1 + assert_iter_sum_pattern(res[0], 128, 0) + res = tvm.arith.detect_iter_map([x[0] * 10 + y[0]], var_dom([x, y]), x[0] * 10 + y[0] <= 127) + assert len(res) == 1 + assert_iter_sum_pattern(res[0], 128, 0) + # lower bound only + res = tvm.arith.detect_iter_map([x[0] * 10 + y[0]], var_dom([x, y]), x[0] * 10 + y[0] > 5) + assert len(res) == 1 + assert_iter_sum_pattern(res[0], 124, 6) + res = tvm.arith.detect_iter_map([x[0] * 10 + y[0]], var_dom([x, y]), x[0] * 10 + y[0] >= 6) + assert len(res) == 1 + assert_iter_sum_pattern(res[0], 124, 6) + + # lower bound + upper bound + res = tvm.arith.detect_iter_map( + [x[0] * 10 + y[0]], + var_dom([x, y]), + tvm.tir.And(x[0] * 10 + y[0] > 5, x[0] * 10 + y[0] < 128), + ) + assert len(res) == 1 + assert_iter_sum_pattern(res[0], 122, 6) + res = tvm.arith.detect_iter_map( + [x[0] * 10 + y[0]], + var_dom([x, y]), + tvm.tir.And(x[0] * 10 + y[0] >= 6, x[0] * 10 + y[0] <= 127), + ) + assert len(res) == 1 + assert_iter_sum_pattern(res[0], 122, 6) + + # constraint on one fused iter + i = tvm.tir.Var("i", "int32") + j = tvm.tir.Var("j", "int32") + k = tvm.tir.Var("k", "int32") + res = tvm.arith.detect_iter_map( + [i * 8 + j * 2 + k], + var_dom([(i, 11), (j, 5), (k, 2)]), + tvm.tir.all(1 <= j * 2 + k, j * 2 + k < 9), + ) + assert_iter_sum_pattern(res[0], 88, 1) + + # constraint on single var + res = tvm.arith.detect_iter_map([i], var_dom([(i, 48)]), tvm.tir.all(i < 10)) + assert_iter_sum_pattern(res[0], 10, 0) + + # iterations are subparts of constraint, invalid, case 1 + res = tvm.arith.detect_iter_map( + [i, j, k], + var_dom([(i, 128), (j, 128), (k, 128)]), + tvm.tir.all(i * 16384 + j * 128 + k < 100), + ) + assert len(res) == 0 + + # iterations are subparts of constraint, invalid, case 2 + res = tvm.arith.detect_iter_map( + [i * 128 + j, k], + var_dom([(i, 128), (j, 128), (k, 128)]), + tvm.tir.all(i * 16384 + j * 128 + k < 100), + ) + assert len(res) == 0 + + # constraint on nested fused iters + res = tvm.arith.detect_iter_map( + [i * 8 + j * 2 + k], + var_dom([(i, 11), (j, 5), (k, 2)]), + tvm.tir.all(1 <= j * 2 + k, j * 2 + k < 9, 3 <= i * 8 + j * 2 + k, i * 8 + j * 2 + k < 25), + ) + assert_iter_sum_pattern(res[0], 22, 3) + + # duplicate constraint on one fused iter + res = tvm.arith.detect_iter_map( + [i * 6 + j * 2 + k], + var_dom([(i, 11), (j, 5), (k, 2)]), + tvm.tir.all(1 <= j * 2 + k, 2 <= j * 2 + k, j * 2 + k < 8, j * 2 + k < 9), + ) + assert_iter_sum_pattern(res[0], 66, 2) + + # duplicate constraint on nested fused iters + res = tvm.arith.detect_iter_map( + [i * 6 + j * 2 + k], + var_dom([(i, 11), (j, 5), (k, 2)]), + tvm.tir.all( + 1 <= j * 2 + k, + 2 <= j * 2 + k, + j * 2 + k < 8, + j * 2 + k < 9, + 3 <= i * 6 + j * 2 + k, + i * 6 + j * 2 + k < 25, + 1 <= i * 6 + j * 2 + k, + i * 6 + j * 2 + k < 18, + ), + ) + assert_iter_sum_pattern(res[0], 15, 3) + + # constraint on non-disjoint fused iters should fail + res = tvm.arith.detect_iter_map( + [i * 8 + j * 2 + k], + var_dom([(i, 11), (j, 5), (k, 2)]), + tvm.tir.all(2 <= j * 2 + k, 0 <= i * 4 + j), + ) + assert len(res) == 0 + + # constraint on many disjoint fused iters, case 1 + # i4 * 6 + i5 in [3, 9), extent=6 (= scale of i2) + # i2 * 30 + i3 * 15 in [30, 90), extent=60 (= scale of i1) + # i1 * 60 in [60, 240), extent=180 (= scale of i0) + i0 = tvm.tir.Var("i0", "int32") + i1 = tvm.tir.Var("i1", "int32") + i2 = tvm.tir.Var("i2", "int32") + i3 = tvm.tir.Var("i3", "int32") + i4 = tvm.tir.Var("i4", "int32") + i5 = tvm.tir.Var("i5", "int32") + res = tvm.arith.detect_iter_map( + [i0 * 180 + i1 * 60 + i2 * 30 + i3 * 15 + i4 * 6 + i5], + var_dom([(i0, 3), (i1, 4), (i2, 3), (i3, 2), (i4, 3), (i5, 6)]), + tvm.tir.all(1 <= i1, 2 <= i2 * 2 + i3, 3 <= i4 * 6 + i5), + ) + assert_iter_sum_pattern(res[0], 540, 93) + + # constraint on many disjoint fused iters, case 2 + res = tvm.arith.detect_iter_map( + [i0 * 45 + i1 * 45 + i2 * 9 + i3 * 4 + i4], + var_dom([(i0, 3), (i1, 2), (i2, 5), (i3, 3), (i4, 4)]), + tvm.tir.all(3 <= i1 * 5 + i2, i1 * 5 + i2 < 8, 1 <= i3 * 4 + i4, i3 * 4 + i4 < 10), + ) + assert_iter_sum_pattern(res[0], 135, 28) + + # constraint on split iters + res = tvm.arith.detect_iter_map( + [i % 16, i // 16], + var_dom([(i, 1024)]), + tvm.tir.all(3 <= i % 16, i % 16 < 10, 4 <= i // 16, i // 16 < 12), + require_bijective=True, + ) + assert_iter_sum_pattern(res[0], 7, 3) + assert_iter_sum_pattern(res[1], 8, 4) + + # constraint on split iters, nested case 1 + res = tvm.arith.detect_iter_map( + [(i * 32 + j) % 16], + var_dom([(i, 5), (j, 32)]), + tvm.tir.all(3 <= (i * 32 + j) % 16, (i * 32 + j) % 16 < 10), + ) + assert_iter_sum_pattern(res[0], 7, 3) + + # constraint on split iters, nested case 2 + res = tvm.arith.detect_iter_map( + [(i * 32 + j) % 16], + var_dom([(i, 5), (j, 32)]), + tvm.tir.all(1 <= i * 32 + j, i * 32 + j <= 32), + ) + assert len(res) == 0 + res = tvm.arith.detect_iter_map( + [(i * 32 + j - 1) % 16, (i * 32 + j - 1) // 16], + var_dom([(i, 5), (j, 32)]), + tvm.tir.all(1 <= i * 32 + j, i * 32 + j <= 64), + ) + assert_iter_sum_pattern(res[0], 16, 0) + assert_iter_sum_pattern(res[1], 4, 0) + + # non-standard form of predicate + res = tvm.arith.detect_iter_map([x[0] * 10 + y[0]], var_dom([x, y]), x[0] * 10 < 128 - y[0]) assert len(res) == 1 assert_iter_sum_pattern(res[0], 128, 0) @@ -651,6 +814,10 @@ def test_normalize_iter_map_to_expr(): ) tvm.ir.assert_structural_equal(tvm.arith.normalize_iter_map_to_expr(res[1]), flm(x[0], 5)) + # iter mark wrap a complex expr + split = tvm.arith.IterSplitExpr(tvm.arith.IterMark(x[0] * y[0] + 1, 1024), 1, 1024, 1) + tvm.ir.assert_structural_equal(tvm.arith.normalize_iter_map_to_expr(split), x[0] * y[0] + 1) + def test_inverse_affine_iter_map(): analyzer = tvm.arith.Analyzer() @@ -712,6 +879,38 @@ def test_inverse_affine_iter_map(): assert analyzer.simplify(res[l0[0]] - l0_inverse) == 0 +def test_free_variables(): + x = tvm.tir.Var("x", "int32") + y = tvm.tir.Var("y", "int32") + z = tvm.tir.Var("z", "int32") + + # illegal iter if z is within dom + res = tvm.arith.detect_iter_map([z * 19 + y * 3 + x], var_dom([(x, 3), (y, 3), (z, 3)])) + assert len(res) == 0 + + # iter is valid if z is free, even there are linear forms of z + res = tvm.arith.detect_iter_map( + [z * 19 + y * 3 + x], + var_dom( + [ + (x, 3), + (y, 3), + ] + ), + ) + assert_iter_sum_pattern(res[0], 9, z * 19) + res = tvm.arith.detect_iter_map( + [z * z + y * 3 + x], + var_dom( + [ + (x, 3), + (y, 3), + ] + ), + ) + assert_iter_sum_pattern(res[0], 9, z * z) + + if __name__ == "__main__": test_split() test_trivial() @@ -722,3 +921,4 @@ def test_inverse_affine_iter_map(): test_subspace_division() test_complex() test_inverse_affine_iter_map() + test_free_variables() diff --git a/tests/python/unittest/test_tir_schedule_reorder.py b/tests/python/unittest/test_tir_schedule_reorder.py index 8267a369cf5d..fd2d82d1ff1f 100644 --- a/tests/python/unittest/test_tir_schedule_reorder.py +++ b/tests/python/unittest/test_tir_schedule_reorder.py @@ -217,9 +217,8 @@ def test_reorder_with_predicate(): sch = tir.Schedule(elementwise_predicate, debug_mask="all") block_b = sch.get_block("B") i, j, k, l = sch.get_loops(block_b) - sch.reorder(l, i) - tvm.ir.assert_structural_equal(elementwise_reordered_with_predicate, sch.mod["main"]) - verify_trace_roundtrip(sch=sch, mod=elementwise_predicate) + with pytest.raises(tvm.tir.ScheduleError): + sch.reorder(l, i) def test_reorder_fail_with_multi_appearance_loops(): diff --git a/tests/python/unittest/test_tir_schedule_rfactor.py b/tests/python/unittest/test_tir_schedule_rfactor.py index 35d4f5f4b76a..f5fc5a73d038 100644 --- a/tests/python/unittest/test_tir_schedule_rfactor.py +++ b/tests/python/unittest/test_tir_schedule_rfactor.py @@ -690,11 +690,9 @@ def test_reduction_rfactor_predicate(): # pylint: disable=invalid-name s = tir.Schedule(rowsum_predicate, debug_mask="all") B = s.get_block("B") _, ko, _ = s.get_loops(B) - rf_block = s.rfactor(ko, 1) - tvm.ir.assert_structural_equal(s.mod["main"], rowsum_predicate_rfactor) - assert s.get(rf_block).same_as(s.get(s.get_block("B_rf"))) - assert s.get(B).same_as(s.get(s.get_block("B"))) - verify_trace_roundtrip(s, mod=rowsum_predicate) + # TODO: should be a tvm.tir.ScheduleError + with pytest.raises(tvm.TVMError): + rf_block = s.rfactor(ko, 1) def test_reduction_rfactor_with_annotation(): diff --git a/tests/python/unittest/test_tir_schedule_state_cached_flags.py b/tests/python/unittest/test_tir_schedule_state_cached_flags.py index e3bd000c2e70..d86af72fca93 100644 --- a/tests/python/unittest/test_tir_schedule_state_cached_flags.py +++ b/tests/python/unittest/test_tir_schedule_state_cached_flags.py @@ -314,6 +314,45 @@ def warp_memory_negative(a: T.handle, c: T.handle) -> None: C[warp_id * 32 + lane_id, vj] = B[vj, warp_id, lane_id] + 1.0 +@T.prim_func +def non_perfect_tiling_cache(a: T.handle, b: T.handle) -> None: + X = T.match_buffer(a, [224, 224], dtype="float32") + Y = T.match_buffer(b, [224, 224], dtype="float32") + cache = T.alloc_buffer([224, 224], dtype="float32") + for hh_0, ww_0 in T.grid(28, 28): + for ax0 in T.serial(0, 10): + for ax1 in T.serial(0, 10): + with T.block("cache"): + h = T.axis.spatial(224, hh_0 * 8 - 1 + ax0) + w = T.axis.spatial(224, ww_0 * 8 - 1 + ax1) + T.where( + 1 <= hh_0 * 8 + ax0 + and hh_0 * 8 + ax0 < 225 + and 1 <= ww_0 * 8 + ax1 + and ww_0 * 8 + ax1 < 225 + ) + cache[h, w] = X[h, w] + for hh_1, ww_1, khh, kww in T.grid(8, 8, 3, 3): + with T.block("compute"): + h = T.axis.spatial(224, hh_0 * 8 + hh_1) + w = T.axis.spatial(224, ww_0 * 8 + ww_1) + kh, kw = T.axis.remap("RR", [khh, kww]) + with T.init(): + Y[h, w] = 0.0 + Y[h, w] = T.max( + Y[h, w], + T.if_then_else( + T.likely(1 <= h + kh, dtype="bool") + and T.likely(h + kh < 225, dtype="bool") + and T.likely(1 <= w + kw, dtype="bool") + and T.likely(w + kw < 225, dtype="bool"), + cache[h + kh - 1, w + kw - 1], + 0.0, + dtype="float32", + ), + ) + + # pylint: enable=no-member,invalid-name,unused-variable,unexpected-keyword-arg @@ -702,5 +741,21 @@ def test_warp_memory_negative(): # pylint: enable=protected-access +def test_non_perfect_tiling_cache(): + s = tir.ScheduleState(non_perfect_tiling_cache, debug_mask="all") + # pylint: disable=protected-access + assert s._get_cached_flags(_get_block(s, "cache")) == CachedFlags( + affine_binding=False, + region_cover=True, + stage_pipeline=True, + ) + assert s._get_cached_flags(_get_block(s, "compute")) == CachedFlags( + affine_binding=True, + region_cover=False, + stage_pipeline=True, + ) + # pylint: enable=protected-access + + if __name__ == "__main__": sys.exit(pytest.main([__file__] + sys.argv[1:]))