From df002d2b3735e7a3e5f1e7435bf2efa5d7fe62e8 Mon Sep 17 00:00:00 2001 From: baoxinqi Date: Fri, 10 Dec 2021 14:18:35 +0800 Subject: [PATCH 1/7] Enable freevars, iter lowerbound and diagnostics in affine utility --- include/tvm/arith/iter_affine_map.h | 21 +- python/tvm/arith/iter_affine_map.py | 7 +- src/arith/int_set.cc | 10 +- src/arith/iter_affine_map.cc | 349 ++++++++++++++---- src/tir/schedule/analysis/analysis.cc | 4 +- tests/python/unittest/test_arith_intset.py | 47 +++ .../unittest/test_arith_iter_affine_map.py | 103 +++++- .../test_tir_schedule_state_cached_flags.py | 49 +++ 8 files changed, 480 insertions(+), 110 deletions(-) diff --git a/include/tvm/arith/iter_affine_map.h b/include/tvm/arith/iter_affine_map.h index 6c72cbeafdd4..889da78835f9 100644 --- a/include/tvm/arith/iter_affine_map.h +++ b/include/tvm/arith/iter_affine_map.h @@ -49,6 +49,7 @@ #define TVM_ARITH_ITER_AFFINE_MAP_H_ #include +#include #include #include @@ -82,7 +83,7 @@ class IterMapExpr : public PrimExpr { }; /*! - * \brief Mark the source as an iterator in [0, extent). + * \brief Mark the source as an iterator in [min, extent). * * IterMark is used to mark source expression as a valid * iterator to make future analysis easy. @@ -94,6 +95,10 @@ class IterMarkNode : public Object { * a IterSumExpr or a Var. */ PrimExpr source; + /*! + * \brief The min of the iteration. + */ + PrimExpr min; /*! * \brief The extent of the iteration. */ @@ -102,17 +107,19 @@ class IterMarkNode : public Object { // overrides void VisitAttrs(tvm::AttrVisitor* v) { v->Visit("source", &source); + v->Visit("min", &min); v->Visit("extent", &extent); } bool SEqualReduce(const IterMarkNode* other, SEqualReducer equal) const { equal->MarkGraphNode(); - return equal(source, other->source) && equal(extent, other->extent); + return equal(source, other->source) && equal(extent, other->extent) && equal(min, other->min); } void SHashReduce(SHashReducer hash_reduce) const { hash_reduce->MarkGraphNode(); hash_reduce(source); + hash_reduce(min); hash_reduce(extent); } @@ -131,9 +138,10 @@ class IterMark : public ObjectRef { /*! * \brief constructor. * \param source The source expression. + * \param min The min of the iterator. * \param extent The extent of the iterator. */ - TVM_DLL IterMark(PrimExpr source, PrimExpr extent); + TVM_DLL IterMark(PrimExpr source, PrimExpr min, PrimExpr extent); TVM_DEFINE_OBJECT_REF_METHODS(IterMark, ObjectRef, IterMarkNode); TVM_DEFINE_OBJECT_REF_COW_METHOD(IterMarkNode); @@ -281,7 +289,7 @@ class IterSumExpr : public IterMapExpr { */ Array DetectIterMap(const Array& indices, const Map& input_iters, const PrimExpr& predicate, bool require_bijective, - arith::Analyzer* analyzer); + arith::Analyzer* analyzer, DiagnosticContext diag_ctx); /*! * \brief Use IterVarMap detector to rewrite and simplify the indices * @@ -289,6 +297,7 @@ Array DetectIterMap(const Array& indices, const Map InverseAffineIterMap(const Array& iter_map, * \param predicate The predicate constraints on the input iterators * \param require_bijective A boolean flag that indicates whether the mapping should be bijective. * \param analyzer Analyzer used to get context information. + * \param diag_ctx Diagnostic context. * * \return The result list has length len(bindings) + 1 [0, len(bindings)): The iter map matching result. The inner list is of length 2. @@ -344,7 +354,8 @@ Map InverseAffineIterMap(const Array& iter_map, Array> SubspaceDivide(const Array& bindings, const Map& input_iters, const Array& sub_iters, const PrimExpr& predicate, - bool require_bijective, arith::Analyzer* analyzer); + bool require_bijective, arith::Analyzer* analyzer, + DiagnosticContext diag_ctx); } // namespace arith } // namespace tvm diff --git a/python/tvm/arith/iter_affine_map.py b/python/tvm/arith/iter_affine_map.py index 85513ecae5c4..74f843b8278c 100644 --- a/python/tvm/arith/iter_affine_map.py +++ b/python/tvm/arith/iter_affine_map.py @@ -34,12 +34,15 @@ class IterMark(Object): source : PrimExpr. The source expression. + min_value : PrimExpr + The min of the iterator. + extent : PrimExpr The extent of the iterator. """ - def __init__(self, source, extent): - self.__init_handle_by_constructor__(_ffi_api.IterMark, source, extent) + def __init__(self, source, min_value, extent): + self.__init_handle_by_constructor__(_ffi_api.IterMark, source, min_value, extent) @tvm._ffi.register_object("arith.IterSplitExpr") diff --git a/src/arith/int_set.cc b/src/arith/int_set.cc index e620e3bdcdec..ba1563fda137 100644 --- a/src/arith/int_set.cc +++ b/src/arith/int_set.cc @@ -835,12 +835,13 @@ Optional> EstimateRegionLowerBound(const Array& region, for (const Range& range : region) { affine_indices.push_back(range->min); } + DiagnosticContext diag_ctx(DiagnosticContext::Default(IRModule())); iter_sum_exprs = DetectIterMap( /*indices=*/affine_indices, /*input_iters=*/var_dom, - /*predicate=*/predicate, /*require_bijective=*/false, analyzer); - } - if (iter_sum_exprs.empty()) { - return NullOpt; + /*predicate=*/predicate, /*require_bijective=*/false, analyzer, diag_ctx); + if (iter_sum_exprs.empty()) { + return NullOpt; + } } ICHECK_EQ(iter_sum_exprs.size(), ndim); Array result; @@ -857,6 +858,7 @@ Optional> EstimateRegionLowerBound(const Array& region, if (!analyzer->CanProve(range->extent >= split->scale)) { return NullOpt; } + const PrimExpr& base = sum_expr->base; // IterSplitExpr: (source // lower_factor) % extent * scale // where `(source // lower_factor) % extent` is within [0, extent - 1] diff --git a/src/arith/iter_affine_map.cc b/src/arith/iter_affine_map.cc index 34c35cecdce4..bee5b9ad9530 100644 --- a/src/arith/iter_affine_map.cc +++ b/src/arith/iter_affine_map.cc @@ -36,23 +36,26 @@ namespace arith { using namespace tir; -IterMark::IterMark(PrimExpr source, PrimExpr extent) { +IterMark::IterMark(PrimExpr source, PrimExpr min, PrimExpr extent) { auto n = make_object(); n->source = std::move(source); + n->min = std::move(min); n->extent = std::move(extent); data_ = std::move(n); } -TVM_REGISTER_GLOBAL("arith.IterMark").set_body_typed([](PrimExpr source, PrimExpr extent) { - return IterMark(source, extent); -}); +TVM_REGISTER_GLOBAL("arith.IterMark") + .set_body_typed([](PrimExpr source, PrimExpr min, PrimExpr extent) { + return IterMark(source, min, extent); + }); TVM_REGISTER_NODE_TYPE(IterMarkNode); TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable) .set_dispatch([](const ObjectRef& node, ReprPrinter* p) { auto* op = static_cast(node.get()); - p->stream << "IterMark(" << op->source << ", extent=" << op->extent << ")"; + p->stream << "IterMark(" << op->source << ", min=" << op->min << ", extent=" << op->extent + << ")"; }); IterSplitExpr::IterSplitExpr(IterMark source) { @@ -72,6 +75,9 @@ IterSplitExpr::IterSplitExpr(IterMark source, PrimExpr scale) { n->dtype = source->source->dtype; n->source = std::move(source); n->extent = n->source->extent; + if (is_zero(n->source->min)) { + n->extent = n->extent + n->source->min; + } n->lower_factor = one; n->scale = std::move(scale); data_ = std::move(n); @@ -165,19 +171,20 @@ class IterMapRewriter : public ExprMutator { public: using Parent = ExprMutator; - explicit IterMapRewriter(Analyzer* analyzer, const Map& input_iters) - : analyzer_(analyzer) { + explicit IterMapRewriter(Analyzer* analyzer, const Map& input_iters, + DiagnosticContext diag_ctx) + : analyzer_(analyzer), diag_ctx_(diag_ctx) { for (auto kv : input_iters) { const Var& var = kv.first; const Range& vrng = kv.second; if (is_one(vrng->extent)) { var_map_[var] = IterSumExpr({}, vrng->min); } else if (is_zero(vrng->min)) { - IterMark mark(var, vrng->extent); + IterMark mark(var, 0, vrng->extent); var_map_[var] = IterSplitExpr(mark); input_marks_.push_back(mark); } else { - IterMark mark(var - vrng->min, vrng->extent); + IterMark mark(var - vrng->min, 0, vrng->extent); IterSumExpr sum_expr = ToIterSumExpr(IterSplitExpr(mark)); sum_expr.CopyOnWrite()->base = vrng->min; var_map_[var] = sum_expr; @@ -192,9 +199,10 @@ class IterMapRewriter : public ExprMutator { return NormalizeToIterWithOffset(ToIterSumExpr(DirectMutate(expr))); } - IterSumExpr RewriteIterConstraint(const PrimExpr& expr, - const PrimExpr& predicate_induced_extent) { - return NormalizeToIterOnBoundExpr(ToIterSumExpr(DirectMutate(expr)), predicate_induced_extent); + IterSumExpr RewriteIterConstraint(const PrimExpr& expr, const PrimExpr& predicate_induced_min, + const PrimExpr& predicate_induced_max) { + return NormalizeToIterOnBoundExpr(ToIterSumExpr(DirectMutate(expr)), predicate_induced_min, + predicate_induced_max); } /*! @@ -224,8 +232,9 @@ class IterMapRewriter : public ExprMutator { // The splits do not overlap with each other. collector.Collect(bindings); for (const IterMark& mark : collector.visited_) { - if (TryNormalizeSplits(mark, collector.mark2splits_[mark], require_bijective).empty()) + if (TryNormalizeSplits(mark, collector.mark2splits_[mark], require_bijective).empty()) { return false; + } } if (require_bijective) { // all input marks must be visited @@ -278,7 +287,7 @@ class IterMapRewriter : public ExprMutator { PrimExpr VisitExpr(const PrimExpr& input_expr) final { auto expr = ExprMutator::VisitExpr(input_expr); if (expr->IsInstance()) { - ++unresolved_count_; + Fail(Diagnostic::Error(input_expr->span)); } return expr; } @@ -328,6 +337,13 @@ class IterMapRewriter : public ExprMutator { } }; + void Fail(const Diagnostic& diagnostic) { + unresolved_count_++; + if (diag_ctx_.defined()) { + diag_ctx_.Emit(diagnostic); + } + } + // Internal analyzer Analyzer* analyzer_; // Counter to keep track of unresolved cases. @@ -352,6 +368,8 @@ class IterMapRewriter : public ExprMutator { std::unordered_map flattened_map_; // The flattened forms of constrained iters std::vector constrained_iters_flattened_; + // Diagnostic context + DiagnosticContext diag_ctx_; /*! * \brief Look for a split in splits that is not used such that its lower_factor is smallest. @@ -433,14 +451,21 @@ class IterMapRewriter : public ExprMutator { } /*! - * \brief Normalize the left hand side of iter constraint(expr < predicate_induced_extent) - * \param expr The left hand side of iter constraint. - * \param predicate_induced_extent Extent from iter constraint. + * \brief Normalize the iter expression with constraint (min <= expr < max) + * \param expr The iter expression. + * \param predicate_induced_min Closed lower bound from iter constraint, maybe undefined. + * \param predicate_induced_max Open upper bound from iter constraint, maybe undefined. * \return The Normalized expression. */ - IterSumExpr NormalizeToIterOnBoundExpr(IterSumExpr expr, - const PrimExpr& predicate_induced_extent) { - // We are normalizing the left hand side of iter constraint(iter < predicate_induced_extent) + IterSumExpr NormalizeToIterOnBoundExpr(IterSumExpr expr, PrimExpr predicate_induced_min, + PrimExpr predicate_induced_max) { + // remove base temporarily since `TryFuseIters` require zero base iter sum + PrimExpr base = expr->base; + if (!is_zero(base)) { + expr.CopyOnWrite()->base = 0; + if (predicate_induced_min.defined()) predicate_induced_min = predicate_induced_min - base; + if (predicate_induced_max.defined()) predicate_induced_max = predicate_induced_max - base; + } Optional opt = TryFuseIters(expr); // scale should be 1 if (opt.defined() && is_one(opt.value()->scale)) { @@ -453,16 +478,30 @@ class IterMapRewriter : public ExprMutator { auto it_mark = sum_fuse_map_.find(flattened_form); ICHECK(it_mark != sum_fuse_map_.end()); IterMark mark = it_mark->second; - mark.CopyOnWrite()->extent = min(predicate_induced_extent, mark->extent); + // update iter mark iter range to [0, mark->extent) ^ [pred_min, pred_max) + PrimExpr mark_min = mark->min; + PrimExpr mark_max = mark->min + mark->extent; + if (predicate_induced_min.defined()) { + mark_min = max(predicate_induced_min, mark_min); + } + if (predicate_induced_max.defined()) { + mark_max = min(predicate_induced_max, mark_max); + } + mark.CopyOnWrite()->min = mark_min; + mark.CopyOnWrite()->extent = mark_max - mark_min; + // update the bound of the lhs based on predicate_induced_extent sum_fuse_map_[flattened_form] = mark; // we need to note down the flattened form of constrained iterators // to check the validity of constraints, see also CheckConstraints() constrained_iters_flattened_.push_back(flattened_form); expr.CopyOnWrite()->args = Array({opt.value()}); + expr.CopyOnWrite()->base = base; return expr; } - ++unresolved_count_; + Fail(Diagnostic::Error(expr->span) + << "Fail to normalize " << expr << " with predicate bound [" << predicate_induced_min + << ", " << predicate_induced_max << ")"); return expr; } @@ -482,7 +521,7 @@ class IterMapRewriter : public ExprMutator { expr.CopyOnWrite()->args = Array({opt.value()}); return expr; } else { - ++unresolved_count_; + Fail(Diagnostic::Error(expr->span) << "Fail to normalize iter sum with offset: " << expr); return expr; } } @@ -534,6 +573,8 @@ class IterMapRewriter : public ExprMutator { // check if it can be remapped into a fused pattern. PrimExpr expected_scale = base_scale.value(); for (size_t i = 0; i < expr->args.size();) { + // check arg iter mark starts from zero + if (!is_zero(expr->args[i]->source->min)) return NullOpt; // find j such that expr->args[j] has expected scale size_t j = i == 0 ? base_index : 0; for (; j < expr->args.size(); ++j) { @@ -602,7 +643,7 @@ class IterMapRewriter : public ExprMutator { return IterSplitExpr(it->second, base_scale.value()); } else { // new iter, form a new mark - IterMark mark = IterMark(structured_form, div(expected_scale, base_scale.value())); + IterMark mark = IterMark(structured_form, 0, div(expected_scale, base_scale.value())); sum_fuse_map_[flattened_form] = mark; flattened_map_[structured_form] = flattened_form; return IterSplitExpr(mark, base_scale.value()); @@ -667,34 +708,126 @@ class IterMapRewriter : public ExprMutator { struct IterConstraint { // The expr of the iter PrimExpr iter; + // The expr of the lower_bound + PrimExpr lower_bound; // The expr of the upper_bound PrimExpr upper_bound; // The size of the iter, which is the number of nodes size_t expr_size = 0; - IterConstraint(PrimExpr iter, PrimExpr upper_bound, size_t size) - : iter(std::move(iter)), upper_bound(std::move(upper_bound)), expr_size(size) {} + IterConstraint(PrimExpr iter, PrimExpr lower_bound, PrimExpr upper_bound, size_t size) + : iter(std::move(iter)), + lower_bound(std::move(lower_bound)), + upper_bound(std::move(upper_bound)), + expr_size(size) {} }; /*! * \brief Split the predicate into `(a < b) && (c < d) && ...` * \param pred The predicate to be split. - * \return A list of pairs, each element of which are lhs and rhs of the '<' sign, - * empty if the split failed. + * \return A list of IterConstraint, empty if the split failed. */ -std::vector MatchUpperBoundConstraints(PrimExpr pred) { +std::vector MatchBoundConstraints(PrimExpr pred, + const Map& input_iters) { std::vector result; arith::PVar lhs, rhs, rest; for (;;) { + // try extract comparisions + bool is_finish = false; + bool is_greater = false; + bool is_equal = false; if ((rest && (lhs < rhs)).Match(pred)) { - result.emplace_back(lhs.Eval(), rhs.Eval(), 0); - pred = rest.Eval(); + // pass } else if ((lhs < rhs).Match(pred)) { - result.emplace_back(lhs.Eval(), rhs.Eval(), 0); - break; + is_finish = true; + } else if ((rest && (lhs <= rhs)).Match(pred)) { + is_equal = true; + } else if ((lhs <= rhs).Match(pred)) { + is_equal = true; + is_finish = true; + } else if ((rest && (lhs > rhs)).Match(pred)) { + is_greater = true; + } else if ((lhs > rhs).Match(pred)) { + is_greater = true; + is_finish = true; + } else if ((rest && (lhs >= rhs)).Match(pred)) { + is_greater = true; + is_equal = true; + } else if ((lhs >= rhs).Match(pred)) { + is_greater = true; + is_equal = true; + is_finish = true; } else { return std::vector(); } + PrimExpr lhs_expr = lhs.Eval(); + PrimExpr rhs_expr = rhs.Eval(); + // we only accept predicate of integers + if (!((lhs_expr->dtype.is_int() || lhs_expr->dtype.is_uint()) && + (rhs_expr->dtype.is_int() || rhs_expr->dtype.is_uint()))) { + return std::vector(); + } + // determine iter and bound, if we can not distinguish them simply, + // try divide (lhs - rhs) into itervar aware and itervar free parts + auto f_use_itervar = [&input_iters](const VarNode* v) { + return input_iters.count(GetRef(v)); + }; + bool bound_at_left; + if (is_const_int(lhs_expr) || !UsesVar(lhs_expr, f_use_itervar)) { + bound_at_left = true; + } else if (is_const_int(rhs_expr) || !UsesVar(rhs_expr, f_use_itervar)) { + bound_at_left = false; + } else { + bound_at_left = false; // accumulate bound to rhs + PrimExpr sum_parts = lhs_expr - rhs_expr; + lhs_expr = 0; + rhs_expr = 0; + std::function f_extract = + [&lhs_expr, &rhs_expr, f_use_itervar, &f_extract](const PrimExpr& part, bool sign) { + if (const AddNode* add = part.as()) { + f_extract(add->a, sign); + f_extract(add->b, sign); + } else if (const SubNode* sub = part.as()) { + f_extract(sub->a, sign); + f_extract(sub->b, !sign); + } else if (UsesVar(part, f_use_itervar)) { + lhs_expr = sign ? lhs_expr + part : lhs_expr - part; + } else { + rhs_expr = sign ? rhs_expr - part : rhs_expr + part; + } + }; + f_extract(sum_parts, true); + arith::Analyzer analyzer; + lhs_expr = analyzer.Simplify(lhs_expr); + rhs_expr = analyzer.Simplify(rhs_expr); + } + PrimExpr lower_bound, upper_bound, iter; + if (is_greater) { + if (bound_at_left) { + // bound > iter + upper_bound = is_equal ? lhs_expr + 1 : lhs_expr; + iter = rhs_expr; + } else { + // iter > bound + lower_bound = is_equal ? rhs_expr : rhs_expr + 1; + iter = lhs_expr; + } + } else { + if (bound_at_left) { + // bound < iter + lower_bound = is_equal ? lhs_expr : lhs_expr + 1; + iter = rhs_expr; + } else { + // iter < bound + upper_bound = is_equal ? rhs_expr + 1 : rhs_expr; + iter = lhs_expr; + } + } + result.emplace_back(iter, lower_bound, upper_bound, 0); + if (is_finish) { + break; + } + pred = rest.Eval(); } return result; } @@ -711,14 +844,17 @@ bool IterRangeSanityCheck(const Map& iter_ranges) { Array DetectIterMap(const Array& indices, const Map& input_iters, const PrimExpr& predicate, bool require_bijective, - arith::Analyzer* analyzer) { + arith::Analyzer* analyzer, DiagnosticContext diag_ctx) { // Overall detection algorithm is divided into two steps: // - Step0: IterMapRewriter rewrites the expression to use IterMapExpr patterns. // - Step1: IterIndependenceChecker checks if the iterator are independent. - if (!IterRangeSanityCheck(input_iters)) return Array(); - std::vector constraints = MatchUpperBoundConstraints(predicate); - if (!is_one(predicate) && constraints.empty()) return Array(); + std::vector constraints = MatchBoundConstraints(predicate, input_iters); + if (!is_one(predicate) && constraints.empty()) { + diag_ctx.Emit(Diagnostic::Error(predicate->span) + << "Fail to collect constraints from iteration predicate: " << predicate); + return Array(); + } // We have to make sure when we visit an iterator, all the constraints related with its successors // in the iter var graph has been visited, where the expression of this iterator will contain the @@ -731,13 +867,17 @@ Array DetectIterMap(const Array& indices, const Map(); } - if (!rewriter.CheckConstraints()) return Array(); + if (!rewriter.CheckConstraints()) { + diag_ctx.Emit(Diagnostic::Error(predicate->span) + << "Illegal iteration constraints: " << predicate); + return Array(); + } // Step0.1: rewrite indices Array results; for (PrimExpr value : indices) { @@ -745,7 +885,10 @@ Array DetectIterMap(const Array& indices, const Map(); } // Step1: IterIndependenceChecker checks if the iterator are independent. - if (!rewriter.CheckMapping(results, require_bijective)) return Array(); + if (!rewriter.CheckMapping(results, require_bijective)) { + diag_ctx.Emit(Diagnostic::Error(predicate->span) << "Iterators are not independent"); + return Array(); + } return results; } @@ -754,7 +897,8 @@ TVM_REGISTER_GLOBAL("arith.DetectIterMap") .set_body_typed([](const Array& indices, const Map& input_iters, const PrimExpr& input_pred, bool is_bijective) { arith::Analyzer ana; - return DetectIterMap(indices, input_iters, input_pred, is_bijective, &ana); + DiagnosticContext diag_ctx(DiagnosticContext::Default(IRModule())); + return DetectIterMap(indices, input_iters, input_pred, is_bijective, &ana, diag_ctx); }); PrimExpr IterMapRewriter::VisitExpr_(const VarNode* op) { @@ -769,8 +913,16 @@ PrimExpr IterMapRewriter::VisitExpr_(const AddNode* op) { return Parent::VisitExpr_(op); } - PrimExpr a = this->DirectMutate(op->a); - PrimExpr b = this->DirectMutate(op->b); + // skip analysis of irrelated sum parts + auto f_use_itervar = [this](const VarNode* v) { return var_map_.count(GetRef(v)); }; + PrimExpr a = op->a; + PrimExpr b = op->b; + if (!is_const_int(a) && UsesVar(a, f_use_itervar)) { + a = this->DirectMutate(op->a); + } + if (!is_const_int(b) && UsesVar(b, f_use_itervar)) { + b = this->DirectMutate(op->b); + } // const folding PrimExpr const_res = TryConstFold(a, b); @@ -858,7 +1010,7 @@ PrimExpr IterMapRewriter::VisitExpr_(const MulNode* op) { if (a->IsInstance() && b->IsInstance()) { // cannot multiply two iterators, mark as unresolved. - ++unresolved_count_; + Fail(Diagnostic::Error(op->span) << "Cannot multiply two iterators: " << GetRef(op)); return GetRef(op); } @@ -894,7 +1046,9 @@ PrimExpr IterMapRewriter::SplitFloorDivConst(IterSplitExpr lhs, PrimExpr rhs, lhs.CopyOnWrite()->scale = make_const(rhs->dtype, 1); } else { // mark as unresolved. - ++unresolved_count_; + Fail(Diagnostic::Error(orig->span) + << "Can not prove floordiv rhs " << rhs << " divisible by lhs scale " << lhs->scale + << ", lhs=" << lhs); return orig; } } @@ -916,7 +1070,8 @@ PrimExpr IterMapRewriter::SplitFloorDivConst(IterSplitExpr lhs, PrimExpr rhs, return std::move(lhs); } else { // mark as unresolved. - ++unresolved_count_; + Fail(Diagnostic::Error(orig->span) + << "Can not prove floordiv lhs extent " << lhs->extent << " divisible by rhs " << rhs); return orig; } } @@ -944,7 +1099,7 @@ PrimExpr IterMapRewriter::VisitExpr_(const FloorDivNode* op) { if (b->IsInstance()) { // cannot divide an iterator, mark as unresolved. - ++unresolved_count_; + Fail(Diagnostic::Error(op->span) << "Cannot divide an iterator: " << GetRef(op)); return GetRef(op); } @@ -953,7 +1108,7 @@ PrimExpr IterMapRewriter::VisitExpr_(const FloorDivNode* op) { if (Optional opt = TryFuseIters(ret)) { return SplitFloorDivConst(opt.value(), b, GetRef(op)); } else { - ++unresolved_count_; + Fail(Diagnostic::Error(op->span) << "Fuse IterSumExpr " << ret << " failed"); return GetRef(op); } } else { @@ -977,7 +1132,8 @@ PrimExpr IterMapRewriter::SplitFloorModConst(IterSplitExpr lhs, PrimExpr rhs, rhs = floordiv(rhs, lhs->scale); } else { // mark as unresolved. - ++unresolved_count_; + Fail(Diagnostic::Error(orig->span) << "Can not prove floormod rhs " << rhs + << " divisible by " << lhs->scale << ", lhs=" << lhs); return orig; } } @@ -991,7 +1147,8 @@ PrimExpr IterMapRewriter::SplitFloorModConst(IterSplitExpr lhs, PrimExpr rhs, return std::move(lhs); } else { // mark as unresolved. - ++unresolved_count_; + Fail(Diagnostic::Error(orig->span) + << "Can not prove floormod lhs extent " << lhs->extent << " divisible by rhs " << rhs); return orig; } } @@ -1019,7 +1176,7 @@ PrimExpr IterMapRewriter::VisitExpr_(const FloorModNode* op) { if (b->IsInstance()) { // cannot mod an iterator, mark as unresolved. - ++unresolved_count_; + Fail(Diagnostic::Error(op->span) << "Cannot mod an iterator: " << GetRef(op)); return GetRef(op); } @@ -1028,7 +1185,7 @@ PrimExpr IterMapRewriter::VisitExpr_(const FloorModNode* op) { if (Optional opt = TryFuseIters(ret)) { return SplitFloorModConst(opt.value(), b, GetRef(op)); } else { - ++unresolved_count_; + Fail(Diagnostic::Error(op->span) << "Fail to fuse iters of " << ret); return GetRef(op); } } else { @@ -1039,19 +1196,21 @@ PrimExpr IterMapRewriter::VisitExpr_(const FloorModNode* op) { } /*! * \brief Given an IterVarMapExpr, transform it to normal PrimExpr. */ -class IterMapToExprNormalizer { +class IterMapToExprNormalizer : public ExprMutator { public: explicit IterMapToExprNormalizer(Analyzer* analyzer) : analyzer_(analyzer) {} - PrimExpr Convert(const IterMapExpr& expr) { + PrimExpr Convert(const IterMapExpr& expr) { return VisitExpr(expr); } + + private: + /*! \brief Override VisitExpr for iter expr type processing */ + PrimExpr VisitExpr(const PrimExpr& expr) override { if (const auto* op = expr.as()) { return ConvertIterSplitExpr(GetRef(op)); } else if (const auto* op = expr.as()) { return ConvertIterSumExpr(GetRef(op)); } else { - ICHECK(expr.defined()); - LOG(FATAL) << "Unknown IterMapExpr type " << expr->GetTypeKey(); - return 0; + return ExprMutator::VisitExpr(expr); } } @@ -1071,7 +1230,7 @@ class IterMapToExprNormalizer { } else if (const auto* op = expr->source->source.as()) { source = ConvertIterSumExpr(GetRef(op)); } else { - LOG(FATAL) << "Unexpected source of IterSplitExpr"; + source = VisitExpr(expr->source->source); } if (analyzer_->CanProve(expr->extent == expr->source->extent) && is_one(expr->lower_factor)) { return source * expr->scale; @@ -1100,8 +1259,9 @@ Array IterMapSimplify(const Array& indices, const Map rewrite = - DetectIterMap(indices, input_iters, input_pred, require_bijective, &analyzer); + DetectIterMap(indices, input_iters, input_pred, require_bijective, &analyzer, diag_ctx); if (rewrite.empty()) { return indices; } @@ -1128,8 +1288,9 @@ Array IterMapSimplify(const Array& indices, const Map& sub_iters) - : analyzer_(analyzer), collector_(collector), sub_iters_(sub_iters) {} + const std::unordered_set& sub_iters, + DiagnosticContext diag_ctx) + : analyzer_(analyzer), collector_(collector), sub_iters_(sub_iters), diag_ctx_(diag_ctx) {} size_t unresolved_count() const { return unresolved_count_; } @@ -1175,7 +1336,7 @@ class SubspaceDivider { if (const auto* op = expr.as()) { return GetRef(op); } else if (const auto* op = expr.as()) { - return IterSplitExpr(IterMark(GetRef(op), extent)); + return IterSplitExpr(IterMark(GetRef(op), 0, extent)); } else { LOG(FATAL) << "Unknown IterMapExpr type"; return NullValue(); @@ -1190,7 +1351,10 @@ class SubspaceDivider { return DivisionResult(IterSumExpr({}, 0), 1, IterSumExpr({}, expr->base), 1); } else if (expr->args.size() == 1) { // arg + base, if arg=Y*E(X)+X, then arg+base = Y*E(X)+(X+base) - if (!is_one(expr->args[0]->scale)) return Fail(); + if (!is_one(expr->args[0]->scale)) { + return Fail(Diagnostic::Error(expr->span) + << "Expect split scale be 1, got " << expr->args[0]->scale); + } DivisionResult res = DivideIterSplitExpr(expr->args[0]); if (!is_zero(expr->base)) res = AddBase(res, expr->base); return res; @@ -1208,7 +1372,9 @@ class SubspaceDivider { DivisionResult arg_division = DivideIterSplitExpr(arg); IterSplitExpr new_arg; if (arg_division.IsInner()) { - if (!inner) return Fail(); + if (!inner) + return Fail(Diagnostic::Error(expr->span) + << "Current division is inner but outer division exists for previous args"); new_arg = arg_division.GetInnerAsSplit(); inner_args.push_back(new_arg); inner = true; @@ -1217,11 +1383,13 @@ class SubspaceDivider { outer_args.push_back(new_arg); inner = false; } else { - return Fail(); + return Fail(Diagnostic::Error(expr->span) + << "Division of " << arg << " is neither inner nor outer"); } extent *= new_arg->extent; } - if (!scale_is_one) return Fail(); + if (!scale_is_one) + return Fail(Diagnostic::Error(expr->span) << "Expect all iter sum arg's scale be 1"); bool need_predicate = !analyzer_->CanProveEqual(extent, mark_extent); const IterMark& outer_mark = MarkFromArgsAndBase(outer_args, 0); const IterMark& inner_mark = MarkFromArgsAndBase(inner_args, expr->base); @@ -1240,7 +1408,8 @@ class SubspaceDivider { inner_preds_ = inner_preds_ && (converter.Convert(inner_source) < mark_extent); return DivisionResult::Inner(inner_source, mark_extent); } else { - return Fail(); + return Fail(Diagnostic::Error(expr->span) + << "Either inner or outer args should exists if need predicate: " << expr); } } return DivisionResult(outer_source, outer_mark->extent, inner_source, inner_mark->extent); @@ -1250,8 +1419,11 @@ class SubspaceDivider { PrimExpr GetInnerPreds() const { return inner_preds_; } private: - DivisionResult Fail() { + DivisionResult Fail(const Diagnostic& diagnostic) { unresolved_count_++; + if (diag_ctx_.defined()) { + diag_ctx_.Emit(diagnostic); + } return DivisionResult(IterSumExpr({}, 0), 0, IterSumExpr({}, 0), 0); } @@ -1276,7 +1448,7 @@ class SubspaceDivider { extent *= arg->extent; res.push_back(arg); } - return IterMark(IterSumExpr(Array(res.rbegin(), res.rend()), base), extent); + return IterMark(IterSumExpr(Array(res.rbegin(), res.rend()), base), 0, extent); } DivisionResult DivideIterSplitExpr(const IterSplitExpr& expr) { @@ -1317,8 +1489,10 @@ class SubspaceDivider { if (splits.size() == 1) { return mark_division; } - IterMark outer_mark(Downcast(mark_division.outer), mark_division.outer_extent); - IterMark inner_mark(Downcast(mark_division.inner), mark_division.inner_extent); + IterMark outer_mark(Downcast(mark_division.outer), 0, + mark_division.outer_extent); + IterMark inner_mark(Downcast(mark_division.inner), 0, + mark_division.inner_extent); bool encountered_boundary = mark_division.IsOuter(); std::vector used(splits.size(), false); std::vector inner_iters, outer_iters; @@ -1330,7 +1504,10 @@ class SubspaceDivider { if (!used[j] && analyzer_->CanProveEqual(splits[j]->lower_factor, expected_lower_factor)) break; } - if (j == splits.size()) return Fail(); + if (j == splits.size()) + return Fail(Diagnostic::Error(expr->span) + << "Can not find expected lower factor " << expected_lower_factor + << " in splits of " << expr->source); used[j] = true; if (!encountered_boundary) { inner_iters.push_back(splits[j]); @@ -1341,7 +1518,9 @@ class SubspaceDivider { if (analyzer_->CanProveEqual(expected_lower_factor, mark_division.inner_extent)) encountered_boundary = true; } - if (!encountered_boundary) return Fail(); + if (!encountered_boundary) + return Fail(Diagnostic::Error(expr->span) + << "Can not find inner/outer boundary of " << expr); for (const IterSplitExpr& inner_iter : inner_iters) { IterSplitExpr new_iter = inner_iter; new_iter.CopyOnWrite()->source = inner_mark; @@ -1355,7 +1534,8 @@ class SubspaceDivider { split_map_.emplace(outer_iter, DivisionResult::Outer(new_iter, outer_iter->extent)); } } else { - return Fail(); + return Fail(Diagnostic::Error(expr->span) + << "Source expr to divide is neither var nor IterSumExpr"); } return split_map_.at(expr); } @@ -1371,15 +1551,18 @@ class SubspaceDivider { std::unordered_map split_map_; // predicate of outer space and inner space; PrimExpr outer_preds_{Bool(true)}, inner_preds_{Bool(true)}; + // diagnostic context + DiagnosticContext diag_ctx_; }; Array> SubspaceDivide(const Array& bindings, const Map& input_iters, const Array& sub_iters, const PrimExpr& predicate, - bool require_bijective, arith::Analyzer* analyzer) { + bool require_bijective, arith::Analyzer* analyzer, + DiagnosticContext diag_ctx) { if (!IterRangeSanityCheck(input_iters)) return Array>(); const Array& maps = - DetectIterMap(bindings, input_iters, predicate, require_bijective, analyzer); + DetectIterMap(bindings, input_iters, predicate, require_bijective, analyzer, diag_ctx); if (maps.empty()) return {}; std::unordered_set inner_iter_set; @@ -1389,18 +1572,18 @@ Array> SubspaceDivide(const Array& bindings, IterMarkSplitCollector collector; collector.Collect(maps); - SubspaceDivider subspace_divider(analyzer, collector, inner_iter_set); + SubspaceDivider subspace_divider(analyzer, collector, inner_iter_set, diag_ctx); std::vector> results; for (const IterSumExpr& expr : maps) { SubspaceDivider::DivisionResult res = subspace_divider.DivideIterSumExpr(expr, 0); if (subspace_divider.unresolved_count()) return {}; results.push_back( - {IterMark(res.outer, res.outer_extent), IterMark(res.inner, res.inner_extent)}); + {IterMark(res.outer, 0, res.outer_extent), IterMark(res.inner, 0, res.inner_extent)}); } - results.push_back({IterMark(IterSumExpr({}, 0), subspace_divider.GetOuterPreds()), - IterMark(IterSumExpr({}, 0), subspace_divider.GetInnerPreds())}); + results.push_back({IterMark(IterSumExpr({}, 0), 0, subspace_divider.GetOuterPreds()), + IterMark(IterSumExpr({}, 0), 0, subspace_divider.GetInnerPreds())}); return results; } @@ -1409,7 +1592,9 @@ TVM_REGISTER_GLOBAL("arith.SubspaceDivide") const Array& sub_iters, const PrimExpr& predicate, bool require_bijective) { arith::Analyzer ana; - return SubspaceDivide(bindings, root_iters, sub_iters, predicate, require_bijective, &ana); + DiagnosticContext diag_ctx(DiagnosticContext::Default(IRModule())); + return SubspaceDivide(bindings, root_iters, sub_iters, predicate, require_bijective, &ana, + diag_ctx); }); class InverseAffineIterMapTransformer { diff --git a/src/tir/schedule/analysis/analysis.cc b/src/tir/schedule/analysis/analysis.cc index 6d744a66b498..0a7d57effd0d 100644 --- a/src/tir/schedule/analysis/analysis.cc +++ b/src/tir/schedule/analysis/analysis.cc @@ -415,12 +415,14 @@ bool IsAffineBinding(const BlockRealize& realize, const Map& loop_va if (loop_var_ranges.empty()) { return true; } + DiagnosticContext diag_ctx(DiagnosticContext::Default(IRModule())); Array results = arith::DetectIterMap( /*indices=*/realize->iter_values, /*input_iters=*/loop_var_ranges, /*predicate=*/realize->predicate, /*require_bijective=*/false, - /*analyzer=*/analyzer); + /*analyzer=*/analyzer, + /*diag_ctx*/ diag_ctx); if (results.empty()) { return false; } diff --git a/tests/python/unittest/test_arith_intset.py b/tests/python/unittest/test_arith_intset.py index 618fd554c6a7..fa97b145e2d6 100644 --- a/tests/python/unittest/test_arith_intset.py +++ b/tests/python/unittest/test_arith_intset.py @@ -16,6 +16,7 @@ # under the License. import tvm from tvm import te +from tvm.ir.base import structural_equal class IntSetChecker: @@ -233,6 +234,51 @@ def test_region_lower_bound_negative_scale(): assert int_set_1.max_value.value == 35 +def test_region_lower_bound_for_non_perfect_tile(): + h1 = tvm.tir.Var("h1", "int32") + h2 = tvm.tir.Var("h2", "int32") + h3 = tvm.tir.Var("h3", "int32") + # h1, h2 are bounded, h3 is free + var_dom = { + h2: tvm.ir.Range(begin=0, end=2), + h1: tvm.ir.Range(begin=0, end=5), + } + + def do_test_point_access(point, predicates, expect): + regions = tvm.arith.estimate_region_lower_bound( + region=[ + tvm.ir.Range.from_min_extent(min_value=point, extent=1), + ], + var_dom=var_dom, + predicate=tvm.tir.all(*predicates), + ) + if expect is None: # expect a failure + assert regions is None + else: + assert len(regions) == 1 + assert structural_equal(expect[0], regions[0].min_value) + assert structural_equal(expect[1], regions[0].max_value) + + # normal case of a non-uniform tiling + # h3 == 0: region is [1, 9] + # 0 < h3 <= 8: region is [h3 * 8, h3 * 8 + 9] + # h3 > 8: region is [h3 * 8, 223] + do_test_point_access( + point=h3 * 8 + h2 * 5 + h1, + predicates=[1 <= h3 * 8 + h2 * 5 + h1, h3 * 8 + h2 * 5 + h1 < 224], + expect=( + tvm.tir.max(h3 * 8, 1), + tvm.tir.max(h3 * 8, 1) - tvm.tir.max(h3 * 8, 214) - tvm.tir.max(1 + h3 * -8, 0) + 223, + ), + ) + # shoud fail on incompatible predicates + do_test_point_access( + point=h3 * 8 + h2 * 5 + h1, + predicates=[1 <= h3 * 8 + h2 * 5 + h1, h3 * 8 + h1 * 2 + h2 < 224], + expect=None, + ) + + def test_union_lower_bound(): neg_inf = tvm.arith.int_set.neg_inf() pos_inf = tvm.arith.int_set.pos_inf() @@ -257,4 +303,5 @@ def test_union_lower_bound(): test_region_lower_bound_split_predicate() test_region_lower_bound_multiple_variables() test_region_lower_bound_negative_scale() + test_region_lower_bound_for_non_perfect_tile() test_union_lower_bound() diff --git a/tests/python/unittest/test_arith_iter_affine_map.py b/tests/python/unittest/test_arith_iter_affine_map.py index c307034c04c9..9277d935e5a1 100644 --- a/tests/python/unittest/test_arith_iter_affine_map.py +++ b/tests/python/unittest/test_arith_iter_affine_map.py @@ -44,7 +44,7 @@ def var_dom(iters): return {var: tvm.ir.Range(0, ext) for var, ext in iters} -def assert_iter_sum_pattern(sum_expr, extent, base, scale=1): +def assert_iter_sum_pattern(sum_expr, extent, base, scale=1, min=0): """Check the sum expr have the right pattern.""" assert isinstance(sum_expr, tvm.arith.IterSumExpr) if extent == 1: @@ -53,6 +53,7 @@ def assert_iter_sum_pattern(sum_expr, extent, base, scale=1): assert len(sum_expr.args) == 1 tvm.testing.assert_prim_expr_equal(sum_expr.args[0].extent, extent) tvm.testing.assert_prim_expr_equal(sum_expr.args[0].scale, scale) + tvm.testing.assert_prim_expr_equal(sum_expr.args[0].source.min, min) tvm.testing.assert_prim_expr_equal(sum_expr.base, base) @@ -178,8 +179,8 @@ def test_compound(): assert_iter_sum_pattern(res[0], 18, 0) assert_iter_sum_pattern(res[1], 5, 0) # reconstruct the pattern manually - mx = tvm.arith.IterMark(x[0], 10) - my = tvm.arith.IterMark(y[0], 9) + mx = tvm.arith.IterMark(x[0], 0, 10) + my = tvm.arith.IterMark(y[0], 0, 9) xoscale = 3 xiscale = 1 @@ -190,7 +191,7 @@ def test_compound(): myo = tvm.arith.IterSplitExpr(my, 3, 3, yoscale) myi = tvm.arith.IterSplitExpr(my, 1, 3, yiscale) - mz = tvm.arith.IterMark(tvm.arith.IterSumExpr([myo, mxo, myi], 0), 18) + mz = tvm.arith.IterMark(tvm.arith.IterSumExpr([myo, mxo, myi], 0), 0, 18) sz = tvm.arith.IterSumExpr([tvm.arith.IterSplitExpr(mz, 1, 18, 1)], 0) tvm.ir.assert_structural_equal(sz, res[0]) @@ -199,8 +200,41 @@ def test_predicate(): x = tvm.tir.Var("x", "int32"), 13 y = tvm.tir.Var("y", "int32"), 10 + # available contraints + # upper bound only res = tvm.arith.detect_iter_map([x[0] * 10 + y[0]], var_dom([x, y]), x[0] * 10 + y[0] < 128) + assert len(res) == 1 + assert_iter_sum_pattern(res[0], 128, 0) + res = tvm.arith.detect_iter_map([x[0] * 10 + y[0]], var_dom([x, y]), x[0] * 10 + y[0] <= 127) + assert len(res) == 1 + assert_iter_sum_pattern(res[0], 128, 0) + # lower bound only + res = tvm.arith.detect_iter_map([x[0] * 10 + y[0]], var_dom([x, y]), x[0] * 10 + y[0] > 5) + assert len(res) == 1 + assert_iter_sum_pattern(res[0], 124, 0, min=6) + res = tvm.arith.detect_iter_map([x[0] * 10 + y[0]], var_dom([x, y]), x[0] * 10 + y[0] >= 6) + assert len(res) == 1 + assert_iter_sum_pattern(res[0], 124, 0, min=6) + + # lower bound + upper bound + res = tvm.arith.detect_iter_map( + [x[0] * 10 + y[0]], + var_dom([x, y]), + tvm.tir.And(x[0] * 10 + y[0] > 5, x[0] * 10 + y[0] < 128), + ) + assert len(res) == 1 + assert_iter_sum_pattern(res[0], 122, 0, min=6) + res = tvm.arith.detect_iter_map( + [x[0] * 10 + y[0]], + var_dom([x, y]), + tvm.tir.And(x[0] * 10 + y[0] >= 6, x[0] * 10 + y[0] <= 127), + ) + assert len(res) == 1 + assert_iter_sum_pattern(res[0], 122, 0, min=6) + + # non-standard form of predicate + res = tvm.arith.detect_iter_map([x[0] * 10 + y[0]], var_dom([x, y]), x[0] * 10 < 128 - y[0]) assert len(res) == 1 assert_iter_sum_pattern(res[0], 128, 0) @@ -541,12 +575,12 @@ def test_complex(): ) assert len(res) == 2 - n0_mark = tvm.arith.IterMark(n0[0], n0[1]) - n1_mark = tvm.arith.IterMark(n1[0], n1[1]) - l0_mark = tvm.arith.IterMark(l0[0], l0[1]) - l1_mark = tvm.arith.IterMark(l1[0], l1[1]) - m1_mark = tvm.arith.IterMark(m1[0], m1[1]) - l3_mark = tvm.arith.IterMark(l3[0], l3[1]) + n0_mark = tvm.arith.IterMark(n0[0], 0, n0[1]) + n1_mark = tvm.arith.IterMark(n1[0], 0, n1[1]) + l0_mark = tvm.arith.IterMark(l0[0], 0, l0[1]) + l1_mark = tvm.arith.IterMark(l1[0], 0, l1[1]) + m1_mark = tvm.arith.IterMark(m1[0], 0, m1[1]) + l3_mark = tvm.arith.IterMark(l3[0], 0, l3[1]) m0_expr = tvm.arith.IterSumExpr( [ @@ -555,12 +589,12 @@ def test_complex(): ], 0, ) - m0_mark = tvm.arith.IterMark(m0_expr, 6) + m0_mark = tvm.arith.IterMark(m0_expr, 0, 6) l2_expr = tvm.arith.IterSumExpr( [tvm.arith.IterSplitExpr(m0_mark, 1, 6, 3), tvm.arith.IterSplitExpr(m1_mark, 1, m1[1], 1)], 0, ) - l2_mark = tvm.arith.IterMark(l2_expr, 16) + l2_mark = tvm.arith.IterMark(l2_expr, 0, 16) k0_expr = tvm.arith.IterSplitExpr(l0_mark, 2, 2, 4) k1_expr = tvm.arith.IterSplitExpr(l1_mark, 2, 4, 1) k2_expr = tvm.arith.IterSplitExpr(l2_mark, 4, 4, 8) @@ -571,19 +605,19 @@ def test_complex(): k7_expr = tvm.arith.IterSplitExpr(l3_mark, 1, 4, 1) j0_expr = tvm.arith.IterSumExpr([k0_expr, k1_expr], 0) - j0_mark = tvm.arith.IterMark(j0_expr, 7) + j0_mark = tvm.arith.IterMark(j0_expr, 0, 7) i0_expr = tvm.arith.IterSumExpr( [tvm.arith.IterSplitExpr(j0_mark, 1, 7, 32), k2_expr, k3_expr], 0 ) j3_expr = tvm.arith.IterSumExpr([k6_expr, k7_expr], 0) - j3_mark = tvm.arith.IterMark(j3_expr, 15) + j3_mark = tvm.arith.IterMark(j3_expr, 0, 15) i1_expr = tvm.arith.IterSumExpr( [k4_expr, k5_expr, tvm.arith.IterSplitExpr(j3_mark, 1, 15, 1)], 0 ) - i0_mark = tvm.arith.IterMark(i0_expr, i0[1]) - i1_mark = tvm.arith.IterMark(i1_expr, i1[1]) + i0_mark = tvm.arith.IterMark(i0_expr, 0, i0[1]) + i1_mark = tvm.arith.IterMark(i1_expr, 0, i1[1]) i0_final = tvm.arith.IterSumExpr([tvm.arith.IterSplitExpr(i0_mark, 1, i0[1], 1)], 0) i1_final = tvm.arith.IterSumExpr([tvm.arith.IterSplitExpr(i1_mark, 1, i1[1], 1)], 0) @@ -651,6 +685,10 @@ def test_normalize_iter_map_to_expr(): ) tvm.ir.assert_structural_equal(tvm.arith.normalize_iter_map_to_expr(res[1]), flm(x[0], 5)) + # iter mark wrap a complex expr + split = tvm.arith.IterSplitExpr(tvm.arith.IterMark(x[0] * y[0] + 1, 0, 1024), 1, 1024, 1) + tvm.ir.assert_structural_equal(tvm.arith.normalize_iter_map_to_expr(split), x[0] * y[0] + 1) + def test_inverse_affine_iter_map(): analyzer = tvm.arith.Analyzer() @@ -712,6 +750,38 @@ def test_inverse_affine_iter_map(): assert analyzer.simplify(res[l0[0]] - l0_inverse) == 0 +def test_free_variables(): + x = tvm.tir.Var("x", "int32") + y = tvm.tir.Var("y", "int32") + z = tvm.tir.Var("z", "int32") + + # illegal iter if z is within dom + res = tvm.arith.detect_iter_map([z * 19 + y * 3 + x], var_dom([(x, 3), (y, 3), (z, 3)])) + assert len(res) == 0 + + # iter is valid if z is free, even there are linear forms of z + res = tvm.arith.detect_iter_map( + [z * 19 + y * 3 + x], + var_dom( + [ + (x, 3), + (y, 3), + ] + ), + ) + assert_iter_sum_pattern(res[0], 9, z * 19) + res = tvm.arith.detect_iter_map( + [z * z + y * 3 + x], + var_dom( + [ + (x, 3), + (y, 3), + ] + ), + ) + assert_iter_sum_pattern(res[0], 9, z * z) + + if __name__ == "__main__": test_split() test_trivial() @@ -722,3 +792,4 @@ def test_inverse_affine_iter_map(): test_subspace_division() test_complex() test_inverse_affine_iter_map() + test_free_variables() diff --git a/tests/python/unittest/test_tir_schedule_state_cached_flags.py b/tests/python/unittest/test_tir_schedule_state_cached_flags.py index e3bd000c2e70..3a1dcee186fd 100644 --- a/tests/python/unittest/test_tir_schedule_state_cached_flags.py +++ b/tests/python/unittest/test_tir_schedule_state_cached_flags.py @@ -314,6 +314,39 @@ def warp_memory_negative(a: T.handle, c: T.handle) -> None: C[warp_id * 32 + lane_id, vj] = B[vj, warp_id, lane_id] + 1.0 +@T.prim_func +def non_uniform_tiling_cache(a: T.handle, b: T.handle) -> None: + X = T.match_buffer(a, [224, 224], dtype="float32") + Y = T.match_buffer(b, [224, 224], dtype="float32") + cache = T.alloc_buffer([224, 224], dtype="float32") + for hh_0, ww_0 in T.grid(28, 28): + for ax0 in T.serial(0, T.min(hh_0 * 8 + 8, 223) + 1 - T.max(hh_0 * 8 - 1, 0)): + for ax1 in T.serial(0, T.min(ww_0 * 8 + 8, 223) + 1 - T.max(ww_0 * 8 - 1, 0)): + with T.block("cache"): + h = T.axis.spatial(224, T.max(hh_0 * 8 - 1, 0) + ax0) + w = T.axis.spatial(224, T.max(ww_0 * 8 - 1, 0) + ax1) + cache[h, w] = X[h, w] + for hh_1, ww_1, khh, kww in T.grid(8, 8, 3, 3): + with T.block("compute"): + h = T.axis.spatial(224, hh_0 * 8 + hh_1) + w = T.axis.spatial(224, ww_0 * 8 + ww_1) + kh, kw = T.axis.remap("RR", [khh, kww]) + with T.init(): + Y[h, w] = 0.0 + Y[h, w] = T.max( + Y[h, w], + T.if_then_else( + T.likely(1 <= h + kh, dtype="bool") + and T.likely(h + kh < 225, dtype="bool") + and T.likely(1 <= w + kw, dtype="bool") + and T.likely(w + kw < 225, dtype="bool"), + cache[h + kh - 1, w + kw - 1], + 0.0, + dtype="float32", + ), + ) + + # pylint: enable=no-member,invalid-name,unused-variable,unexpected-keyword-arg @@ -702,5 +735,21 @@ def test_warp_memory_negative(): # pylint: enable=protected-access +def test_non_uniform_tiling(): + s = tir.ScheduleState(non_uniform_tiling_cache, debug_mask="all") + # pylint: disable=protected-access + assert s._get_cached_flags(_get_block(s, "cache")) == CachedFlags( + affine_binding=False, + region_cover=True, + stage_pipeline=True, + ) + assert s._get_cached_flags(_get_block(s, "compute")) == CachedFlags( + affine_binding=True, + region_cover=True, + stage_pipeline=True, + ) + # pylint: enable=protected-access + + if __name__ == "__main__": sys.exit(pytest.main([__file__] + sys.argv[1:])) From d86997dcd6aa113cebc9630b3882b52985fefbb1 Mon Sep 17 00:00:00 2001 From: baoxinqi Date: Fri, 10 Dec 2021 17:21:18 +0800 Subject: [PATCH 2/7] fix lint issues and compare bug --- include/tvm/arith/iter_affine_map.h | 2 +- src/arith/iter_affine_map.cc | 43 ++++++++++++++++--- tests/python/unittest/test_arith_intset.py | 2 +- .../unittest/test_arith_iter_affine_map.py | 14 +++--- .../test_tir_schedule_state_cached_flags.py | 12 ++++-- 5 files changed, 55 insertions(+), 18 deletions(-) diff --git a/include/tvm/arith/iter_affine_map.h b/include/tvm/arith/iter_affine_map.h index 889da78835f9..39a0e9c658a3 100644 --- a/include/tvm/arith/iter_affine_map.h +++ b/include/tvm/arith/iter_affine_map.h @@ -283,6 +283,7 @@ class IterSumExpr : public IterMapExpr { * \param predicate The predicate constraints on the input iterators * \param require_bijective A boolean flag that indicates whether the mapping should be bijective. * \param analyzer Analyzer used to get context information. + * \param diag_ctx Diagnostic context. * * \return The detected pattern if a match exists, * otherwise return an empty array. @@ -297,7 +298,6 @@ Array DetectIterMap(const Array& indices, const Mapdtype = source->source->dtype; n->source = std::move(source); n->extent = n->source->extent; - if (is_zero(n->source->min)) { + if (!is_zero(n->source->min)) { n->extent = n->extent + n->source->min; } n->lower_factor = one; @@ -233,13 +233,20 @@ class IterMapRewriter : public ExprMutator { collector.Collect(bindings); for (const IterMark& mark : collector.visited_) { if (TryNormalizeSplits(mark, collector.mark2splits_[mark], require_bijective).empty()) { + diag_ctx_.Emit(Diagnostic::Error(mark->source->span) + << "Fail to normalize iter mark splits: " << mark); return false; } } if (require_bijective) { // all input marks must be visited for (const IterMark& mark : input_marks_) { - if (collector.visited_.count(mark) == 0) return false; + if (collector.visited_.count(mark) == 0) { + diag_ctx_.Emit(Diagnostic::Error(mark->source->span) + << "The mapping is not bijective because input iter mark " << mark + << " is not covered, "); + return false; + } } } return true; @@ -425,26 +432,48 @@ class IterMapRewriter : public ExprMutator { } if (j == splits.size()) { // we do not allow incomplete split if the bindings should be bijective - if (require_bijective) return Array(); + if (require_bijective) { + diag_ctx_.Emit( + Diagnostic::Error(mark->source->span) + << "Do not allow incomplete split in bijective checking, expected_lower_factor=" + << expected_lower_factor); + return Array(); + } // look for the next split skipping this lower factor // For example, y \in [0, 24) has 3 splits [y / 6, (y / 2) % 6, y % 2] // It is valid to only have [y / 6, y % 2] if bijective is not required // We can skip (y / 2) % 6 j = SearchSkipLowerFactor(splits, used, expected_lower_factor); // split not found - if (j == splits.size()) return Array(); + if (j == splits.size()) { + diag_ctx_.Emit(Diagnostic::Error(mark->source->span) + << "Fail to find split skipping the lower factor in bijective-free " + "checking, expected_lower_factor=" + << expected_lower_factor); + return Array(); + } } used[j] = true; iters.push_back(splits[j]); expected_lower_factor = splits[j]->lower_factor * splits[j]->extent; } + // Case 1. bijective is required. // We check the extent we calculate is consistent with the extent of the mark // Case 2. bijective is not required. - // We check the extent we calculate is a factor of the extent of the mark + // We check either + // (1) the extent we calculate is a factor of the extent of the mark // For example, y \in [0, 24) [(y / 2) % 6, y % 2] is valid, but y \in [0, 25) is not. - if ((require_bijective && !analyzer_->CanProveEqual(expected_lower_factor, mark->extent)) || - (!require_bijective && !CanProveDivisible(mark->extent, expected_lower_factor))) { + // (2) the extent we calculate is larger than the max of the mark + // For example, y \in [1, 8] [y / 18, y % 18] is valid. + if ((require_bijective && + !(analyzer_->CanProveEqual(expected_lower_factor, mark->extent) && is_zero(mark->min))) || + (!require_bijective && + !(CanProveDivisible(mark->extent, expected_lower_factor) || + analyzer_->CanProve(mark->min + mark->extent <= expected_lower_factor)))) { + diag_ctx_.Emit(Diagnostic::Error(mark->source->span) + << "Mark extent of " << mark + << " is not compatible with expected_lower_factor=" << expected_lower_factor); return Array(); } return Array(iters.rbegin(), iters.rend()); diff --git a/tests/python/unittest/test_arith_intset.py b/tests/python/unittest/test_arith_intset.py index fa97b145e2d6..6cde357c7ef0 100644 --- a/tests/python/unittest/test_arith_intset.py +++ b/tests/python/unittest/test_arith_intset.py @@ -271,7 +271,7 @@ def do_test_point_access(point, predicates, expect): tvm.tir.max(h3 * 8, 1) - tvm.tir.max(h3 * 8, 214) - tvm.tir.max(1 + h3 * -8, 0) + 223, ), ) - # shoud fail on incompatible predicates + # should fail on incompatible predicates do_test_point_access( point=h3 * 8 + h2 * 5 + h1, predicates=[1 <= h3 * 8 + h2 * 5 + h1, h3 * 8 + h1 * 2 + h2 < 224], diff --git a/tests/python/unittest/test_arith_iter_affine_map.py b/tests/python/unittest/test_arith_iter_affine_map.py index 9277d935e5a1..492b611b4d0c 100644 --- a/tests/python/unittest/test_arith_iter_affine_map.py +++ b/tests/python/unittest/test_arith_iter_affine_map.py @@ -44,7 +44,7 @@ def var_dom(iters): return {var: tvm.ir.Range(0, ext) for var, ext in iters} -def assert_iter_sum_pattern(sum_expr, extent, base, scale=1, min=0): +def assert_iter_sum_pattern(sum_expr, extent, base, scale=1, mark_min=0, mark_extent=None): """Check the sum expr have the right pattern.""" assert isinstance(sum_expr, tvm.arith.IterSumExpr) if extent == 1: @@ -53,7 +53,9 @@ def assert_iter_sum_pattern(sum_expr, extent, base, scale=1, min=0): assert len(sum_expr.args) == 1 tvm.testing.assert_prim_expr_equal(sum_expr.args[0].extent, extent) tvm.testing.assert_prim_expr_equal(sum_expr.args[0].scale, scale) - tvm.testing.assert_prim_expr_equal(sum_expr.args[0].source.min, min) + tvm.testing.assert_prim_expr_equal(sum_expr.args[0].source.min, mark_min) + if mark_extent: + tvm.testing.assert_prim_expr_equal(sum_expr.args[0].source.extent, mark_extent) tvm.testing.assert_prim_expr_equal(sum_expr.base, base) @@ -212,10 +214,10 @@ def test_predicate(): # lower bound only res = tvm.arith.detect_iter_map([x[0] * 10 + y[0]], var_dom([x, y]), x[0] * 10 + y[0] > 5) assert len(res) == 1 - assert_iter_sum_pattern(res[0], 124, 0, min=6) + assert_iter_sum_pattern(res[0], 130, 0, mark_min=6, mark_extent=124) res = tvm.arith.detect_iter_map([x[0] * 10 + y[0]], var_dom([x, y]), x[0] * 10 + y[0] >= 6) assert len(res) == 1 - assert_iter_sum_pattern(res[0], 124, 0, min=6) + assert_iter_sum_pattern(res[0], 130, 0, mark_min=6, mark_extent=124) # lower bound + upper bound res = tvm.arith.detect_iter_map( @@ -224,14 +226,14 @@ def test_predicate(): tvm.tir.And(x[0] * 10 + y[0] > 5, x[0] * 10 + y[0] < 128), ) assert len(res) == 1 - assert_iter_sum_pattern(res[0], 122, 0, min=6) + assert_iter_sum_pattern(res[0], 128, 0, mark_min=6, mark_extent=122) res = tvm.arith.detect_iter_map( [x[0] * 10 + y[0]], var_dom([x, y]), tvm.tir.And(x[0] * 10 + y[0] >= 6, x[0] * 10 + y[0] <= 127), ) assert len(res) == 1 - assert_iter_sum_pattern(res[0], 122, 0, min=6) + assert_iter_sum_pattern(res[0], 128, 0, mark_min=6, mark_extent=122) # non-standard form of predicate res = tvm.arith.detect_iter_map([x[0] * 10 + y[0]], var_dom([x, y]), x[0] * 10 < 128 - y[0]) diff --git a/tests/python/unittest/test_tir_schedule_state_cached_flags.py b/tests/python/unittest/test_tir_schedule_state_cached_flags.py index 3a1dcee186fd..f06ca129490b 100644 --- a/tests/python/unittest/test_tir_schedule_state_cached_flags.py +++ b/tests/python/unittest/test_tir_schedule_state_cached_flags.py @@ -315,7 +315,7 @@ def warp_memory_negative(a: T.handle, c: T.handle) -> None: @T.prim_func -def non_uniform_tiling_cache(a: T.handle, b: T.handle) -> None: +def non_perfect_tiling_cache(a: T.handle, b: T.handle) -> None: X = T.match_buffer(a, [224, 224], dtype="float32") Y = T.match_buffer(b, [224, 224], dtype="float32") cache = T.alloc_buffer([224, 224], dtype="float32") @@ -325,6 +325,12 @@ def non_uniform_tiling_cache(a: T.handle, b: T.handle) -> None: with T.block("cache"): h = T.axis.spatial(224, T.max(hh_0 * 8 - 1, 0) + ax0) w = T.axis.spatial(224, T.max(ww_0 * 8 - 1, 0) + ax1) + T.where( + 1 <= hh_0 * 8 + ax0 + and hh_0 * 8 + ax0 < 225 + and 1 <= ww_0 * 8 + ax1 + and ww_0 * 8 + ax1 < 225 + ) cache[h, w] = X[h, w] for hh_1, ww_1, khh, kww in T.grid(8, 8, 3, 3): with T.block("compute"): @@ -735,8 +741,8 @@ def test_warp_memory_negative(): # pylint: enable=protected-access -def test_non_uniform_tiling(): - s = tir.ScheduleState(non_uniform_tiling_cache, debug_mask="all") +def test_non_perfect_tiling_cache(): + s = tir.ScheduleState(non_perfect_tiling_cache, debug_mask="all") # pylint: disable=protected-access assert s._get_cached_flags(_get_block(s, "cache")) == CachedFlags( affine_binding=False, From a4625a6a6c3a641decbfbb543f7fa29c98a4489c Mon Sep 17 00:00:00 2001 From: wrongtest Date: Fri, 17 Dec 2021 20:02:35 +0800 Subject: [PATCH 3/7] update to use iter shift instead of itermark min for lowerbound --- include/tvm/arith/iter_affine_map.h | 13 +- python/tvm/arith/iter_affine_map.py | 7 +- src/arith/int_set.cc | 6 +- src/arith/iter_affine_map.cc | 184 ++++++++++-------- tests/python/unittest/test_arith_intset.py | 15 +- .../unittest/test_arith_iter_affine_map.py | 45 ++--- .../test_tir_schedule_state_cached_flags.py | 8 +- 7 files changed, 145 insertions(+), 133 deletions(-) diff --git a/include/tvm/arith/iter_affine_map.h b/include/tvm/arith/iter_affine_map.h index 39a0e9c658a3..22b4cd580e18 100644 --- a/include/tvm/arith/iter_affine_map.h +++ b/include/tvm/arith/iter_affine_map.h @@ -83,7 +83,7 @@ class IterMapExpr : public PrimExpr { }; /*! - * \brief Mark the source as an iterator in [min, extent). + * \brief Mark the source as an iterator in [0, extent). * * IterMark is used to mark source expression as a valid * iterator to make future analysis easy. @@ -95,10 +95,6 @@ class IterMarkNode : public Object { * a IterSumExpr or a Var. */ PrimExpr source; - /*! - * \brief The min of the iteration. - */ - PrimExpr min; /*! * \brief The extent of the iteration. */ @@ -107,19 +103,17 @@ class IterMarkNode : public Object { // overrides void VisitAttrs(tvm::AttrVisitor* v) { v->Visit("source", &source); - v->Visit("min", &min); v->Visit("extent", &extent); } bool SEqualReduce(const IterMarkNode* other, SEqualReducer equal) const { equal->MarkGraphNode(); - return equal(source, other->source) && equal(extent, other->extent) && equal(min, other->min); + return equal(source, other->source) && equal(extent, other->extent); } void SHashReduce(SHashReducer hash_reduce) const { hash_reduce->MarkGraphNode(); hash_reduce(source); - hash_reduce(min); hash_reduce(extent); } @@ -138,10 +132,9 @@ class IterMark : public ObjectRef { /*! * \brief constructor. * \param source The source expression. - * \param min The min of the iterator. * \param extent The extent of the iterator. */ - TVM_DLL IterMark(PrimExpr source, PrimExpr min, PrimExpr extent); + TVM_DLL IterMark(PrimExpr source, PrimExpr extent); TVM_DEFINE_OBJECT_REF_METHODS(IterMark, ObjectRef, IterMarkNode); TVM_DEFINE_OBJECT_REF_COW_METHOD(IterMarkNode); diff --git a/python/tvm/arith/iter_affine_map.py b/python/tvm/arith/iter_affine_map.py index 74f843b8278c..85513ecae5c4 100644 --- a/python/tvm/arith/iter_affine_map.py +++ b/python/tvm/arith/iter_affine_map.py @@ -34,15 +34,12 @@ class IterMark(Object): source : PrimExpr. The source expression. - min_value : PrimExpr - The min of the iterator. - extent : PrimExpr The extent of the iterator. """ - def __init__(self, source, min_value, extent): - self.__init_handle_by_constructor__(_ffi_api.IterMark, source, min_value, extent) + def __init__(self, source, extent): + self.__init_handle_by_constructor__(_ffi_api.IterMark, source, extent) @tvm._ffi.register_object("arith.IterSplitExpr") diff --git a/src/arith/int_set.cc b/src/arith/int_set.cc index ba1563fda137..55a1a5a1830e 100644 --- a/src/arith/int_set.cc +++ b/src/arith/int_set.cc @@ -839,9 +839,9 @@ Optional> EstimateRegionLowerBound(const Array& region, iter_sum_exprs = DetectIterMap( /*indices=*/affine_indices, /*input_iters=*/var_dom, /*predicate=*/predicate, /*require_bijective=*/false, analyzer, diag_ctx); - if (iter_sum_exprs.empty()) { - return NullOpt; - } + } + if (iter_sum_exprs.empty()) { + return NullOpt; } ICHECK_EQ(iter_sum_exprs.size(), ndim); Array result; diff --git a/src/arith/iter_affine_map.cc b/src/arith/iter_affine_map.cc index 5743c040a8e0..2a1792b5428f 100644 --- a/src/arith/iter_affine_map.cc +++ b/src/arith/iter_affine_map.cc @@ -36,26 +36,23 @@ namespace arith { using namespace tir; -IterMark::IterMark(PrimExpr source, PrimExpr min, PrimExpr extent) { +IterMark::IterMark(PrimExpr source, PrimExpr extent) { auto n = make_object(); n->source = std::move(source); - n->min = std::move(min); n->extent = std::move(extent); data_ = std::move(n); } -TVM_REGISTER_GLOBAL("arith.IterMark") - .set_body_typed([](PrimExpr source, PrimExpr min, PrimExpr extent) { - return IterMark(source, min, extent); - }); +TVM_REGISTER_GLOBAL("arith.IterMark").set_body_typed([](PrimExpr source, PrimExpr extent) { + return IterMark(source, extent); +}); TVM_REGISTER_NODE_TYPE(IterMarkNode); TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable) .set_dispatch([](const ObjectRef& node, ReprPrinter* p) { auto* op = static_cast(node.get()); - p->stream << "IterMark(" << op->source << ", min=" << op->min << ", extent=" << op->extent - << ")"; + p->stream << "IterMark(" << op->source << ", extent=" << op->extent << ")"; }); IterSplitExpr::IterSplitExpr(IterMark source) { @@ -75,9 +72,6 @@ IterSplitExpr::IterSplitExpr(IterMark source, PrimExpr scale) { n->dtype = source->source->dtype; n->source = std::move(source); n->extent = n->source->extent; - if (!is_zero(n->source->min)) { - n->extent = n->extent + n->source->min; - } n->lower_factor = one; n->scale = std::move(scale); data_ = std::move(n); @@ -166,6 +160,14 @@ class IterMarkSplitCollector { } }; +/*! \brief Record form of IterMark(x, extent) + offset */ +struct IterMarkWithOffset { + IterMark mark; + PrimExpr offset{0}; + IterMarkWithOffset() {} + IterMarkWithOffset(IterMark mark, PrimExpr offset) : mark(mark), offset(offset) {} +}; + /*! \brief Rewriter to rewrite PrimExpr to IterMapExpr when possible */ class IterMapRewriter : public ExprMutator { public: @@ -180,11 +182,11 @@ class IterMapRewriter : public ExprMutator { if (is_one(vrng->extent)) { var_map_[var] = IterSumExpr({}, vrng->min); } else if (is_zero(vrng->min)) { - IterMark mark(var, 0, vrng->extent); + IterMark mark(var, vrng->extent); var_map_[var] = IterSplitExpr(mark); input_marks_.push_back(mark); } else { - IterMark mark(var - vrng->min, 0, vrng->extent); + IterMark mark(var - vrng->min, vrng->extent); IterSumExpr sum_expr = ToIterSumExpr(IterSplitExpr(mark)); sum_expr.CopyOnWrite()->base = vrng->min; var_map_[var] = sum_expr; @@ -359,8 +361,8 @@ class IterMapRewriter : public ExprMutator { std::unordered_map var_map_; // input iter marks std::vector input_marks_; - // The map for sum that maps flattened form to IterMark with normal form and extent - // Example: expr = i*9 + j*2 + k, i in [0, 4) j in [0, 5) k in [0, 2) + // The map for sum that maps flattened form to IterMark with normal form and extent (and possibly + // an extra offset) Example(1): expr = i*9 + j*2 + k, i in [0, 4) j in [0, 5) k in [0, 2) // predicate: j*2 + k < 9 // Then, flattened form = IterSum(IterSplit(i, scale=9), // IterSplit(j, scale=2), @@ -370,7 +372,18 @@ class IterMapRewriter : public ExprMutator { // IterSplit(k, scale=1)), // extent=9) // scale=1)) - std::unordered_map sum_fuse_map_; + // Example(2): expr = i*9 + j*2 + k, i in [0, 4) j in [0, 5) k in [0, 2) + // predicate: 1 <= j*2 + k < 9 + // Then, flattened form = IterSum(IterSplit(i, scale=9), + // IterSplit(j, scale=2), + // IterSplit(k, scale=1)) + // normal form = IterSum(IterSplit(i, scale=9), + // IterSplit(IterMark(IterSum(IterSplit(j, scale=2), + // IterSplit(k, scale=1), base=-1), + // extent=9-1) + // scale=1), + // base=1) + std::unordered_map sum_fuse_map_; // The map for sum that maps normal form to flattened form std::unordered_map flattened_map_; // The flattened forms of constrained iters @@ -461,16 +474,10 @@ class IterMapRewriter : public ExprMutator { // Case 1. bijective is required. // We check the extent we calculate is consistent with the extent of the mark // Case 2. bijective is not required. - // We check either - // (1) the extent we calculate is a factor of the extent of the mark + // We check the extent we calculate is a factor of the extent of the mark // For example, y \in [0, 24) [(y / 2) % 6, y % 2] is valid, but y \in [0, 25) is not. - // (2) the extent we calculate is larger than the max of the mark - // For example, y \in [1, 8] [y / 18, y % 18] is valid. - if ((require_bijective && - !(analyzer_->CanProveEqual(expected_lower_factor, mark->extent) && is_zero(mark->min))) || - (!require_bijective && - !(CanProveDivisible(mark->extent, expected_lower_factor) || - analyzer_->CanProve(mark->min + mark->extent <= expected_lower_factor)))) { + if ((require_bijective && !analyzer_->CanProveEqual(expected_lower_factor, mark->extent)) || + (!require_bijective && !CanProveDivisible(mark->extent, expected_lower_factor))) { diag_ctx_.Emit(Diagnostic::Error(mark->source->span) << "Mark extent of " << mark << " is not compatible with expected_lower_factor=" << expected_lower_factor); @@ -495,10 +502,12 @@ class IterMapRewriter : public ExprMutator { if (predicate_induced_min.defined()) predicate_induced_min = predicate_induced_min - base; if (predicate_induced_max.defined()) predicate_induced_max = predicate_induced_max - base; } - Optional opt = TryFuseIters(expr); + Optional opt = TryFuseIters(expr); + ICHECK(!opt.defined() || opt.value()->args.size() == 1); // scale should be 1 - if (opt.defined() && is_one(opt.value()->scale)) { - IterSumExpr sum = Downcast(opt.value()->source->source); + if (opt.defined() && is_one(opt.value()->args[0]->scale)) { + IterSplitExpr fused_split = opt.value()->args[0]; + IterSumExpr sum = Downcast(fused_split->source->source); // get the flattened form auto it = flattened_map_.find(sum); ICHECK(it != flattened_map_.end()); @@ -506,26 +515,30 @@ class IterMapRewriter : public ExprMutator { // get the mark auto it_mark = sum_fuse_map_.find(flattened_form); ICHECK(it_mark != sum_fuse_map_.end()); - IterMark mark = it_mark->second; + IterMark mark = it_mark->second.mark; + PrimExpr mark_offset = it_mark->second.offset; // update iter mark iter range to [0, mark->extent) ^ [pred_min, pred_max) - PrimExpr mark_min = mark->min; - PrimExpr mark_max = mark->min + mark->extent; + PrimExpr mark_min = 0; + PrimExpr mark_max = mark->extent; if (predicate_induced_min.defined()) { mark_min = max(predicate_induced_min, mark_min); } if (predicate_induced_max.defined()) { mark_max = min(predicate_induced_max, mark_max); } - mark.CopyOnWrite()->min = mark_min; + // mark.CopyOnWrite()->min = mark_min; + mark.CopyOnWrite()->source = mark->source - mark_min; mark.CopyOnWrite()->extent = mark_max - mark_min; + mark_offset = mark_offset + mark_min; // update the bound of the lhs based on predicate_induced_extent - sum_fuse_map_[flattened_form] = mark; + sum_fuse_map_[flattened_form] = {mark, mark_offset}; + // we need to note down the flattened form of constrained iterators // to check the validity of constraints, see also CheckConstraints() constrained_iters_flattened_.push_back(flattened_form); - expr.CopyOnWrite()->args = Array({opt.value()}); - expr.CopyOnWrite()->base = base; + expr.CopyOnWrite()->args = Array({fused_split}); + expr.CopyOnWrite()->base = base + mark_min; return expr; } Fail(Diagnostic::Error(expr->span) @@ -542,13 +555,9 @@ class IterMapRewriter : public ExprMutator { IterSumExpr NormalizeToIterWithOffset(IterSumExpr expr) { // We are normalizing a regular iter if (expr->args.size() <= 1) return expr; - PrimExpr base = expr->base; - expr.CopyOnWrite()->base = make_zero(expr->dtype); - Optional opt = TryFuseIters(expr); - expr.CopyOnWrite()->base = base; + Optional opt = TryFuseIters(expr); if (opt.defined()) { - expr.CopyOnWrite()->args = Array({opt.value()}); - return expr; + return opt.value(); } else { Fail(Diagnostic::Error(expr->span) << "Fail to normalize iter sum with offset: " << expr); return expr; @@ -572,17 +581,16 @@ class IterMapRewriter : public ExprMutator { } /*! - * \brief IterSum = x1*c1 + x2*c2 + ... + xn*cn - * = (x1*s1 + x2*s2 + ... + xn)*cn - * = y*cn (IterMark y => x1*s1 + x2*s2 + ... + xn) - * = [IterSplit(IterMark(y), scale=cn)] - * return a corresponding IterSplitExpr if needed. + * \brief IterSum = x1*c1 + x2*c2 + ... + xn*cn + base + * = (x1*s1 + x2*s2 + ... + xn)*cn + base + * = y*cn (IterMark y => x1*s1 + x2*s2 + ... + xn) + base + * = [IterSplit(IterMark(y), scale=cn)] + base + * return a corresponding IterSumExpr with extra offset if needed. * Try to normalize IterSum into a fused IterMark * \param expr The input sum. - * \return The split with the fused IterMark if succeed. + * \return The sum with the fused IterMark and extra offset if succeed. */ - Optional TryFuseIters(IterSumExpr expr) { - if (!is_zero(expr->base)) return NullOpt; + Optional TryFuseIters(IterSumExpr expr) { // select the iterators in order std::vector visited(expr->args.size(), false); std::vector flattened_iters, grouped_iters; @@ -600,10 +608,9 @@ class IterMapRewriter : public ExprMutator { } if (!base_scale) return NullOpt; // check if it can be remapped into a fused pattern. + PrimExpr expected_extra_base = 0; PrimExpr expected_scale = base_scale.value(); for (size_t i = 0; i < expr->args.size();) { - // check arg iter mark starts from zero - if (!is_zero(expr->args[i]->source->min)) return NullOpt; // find j such that expr->args[j] has expected scale size_t j = i == 0 ? base_index : 0; for (; j < expr->args.size(); ++j) { @@ -645,9 +652,10 @@ class IterMapRewriter : public ExprMutator { } auto iter = sum_fuse_map_.find(constraint_to_match.value()); ICHECK(iter != sum_fuse_map_.end()); - IterMark iter_matched = iter->second; - grouped_iters.emplace_back(iter_matched, expected_scale); - expected_scale *= iter_matched->extent; + const IterMarkWithOffset& iter_matched = iter->second; + grouped_iters.emplace_back(iter_matched.mark, expected_scale); + expected_extra_base += iter_matched.offset * expected_scale; + expected_scale *= iter_matched.mark->extent; // move forward i += constraint_to_match.value()->args.size(); } else { @@ -664,18 +672,26 @@ class IterMapRewriter : public ExprMutator { IterSumExpr structured_form = expr, flattened_form = expr; flattened_form.CopyOnWrite()->args = Array(flattened_iters.rbegin(), flattened_iters.rend()); + flattened_form.CopyOnWrite()->base = 0; structured_form.CopyOnWrite()->args = Array(grouped_iters.rbegin(), grouped_iters.rend()); + structured_form.CopyOnWrite()->base = 0; auto it = sum_fuse_map_.find(flattened_form); if (it != sum_fuse_map_.end()) { // old iter - return IterSplitExpr(it->second, base_scale.value()); + if (!analyzer_->CanProveEqual(expected_extra_base, it->second.offset * base_scale.value())) { + // the extra offset is not consistent with old + return NullOpt; + } + return IterSumExpr({IterSplitExpr(it->second.mark, base_scale.value())}, + expr->base + expected_extra_base); } else { // new iter, form a new mark - IterMark mark = IterMark(structured_form, 0, div(expected_scale, base_scale.value())); - sum_fuse_map_[flattened_form] = mark; + IterMark mark = IterMark(structured_form, div(expected_scale, base_scale.value())); + sum_fuse_map_[flattened_form] = IterMarkWithOffset(mark, 0); flattened_map_[structured_form] = flattened_form; - return IterSplitExpr(mark, base_scale.value()); + return IterSumExpr({IterSplitExpr(mark, base_scale.value())}, + expr->base + expected_extra_base); } } @@ -941,17 +957,8 @@ PrimExpr IterMapRewriter::VisitExpr_(const AddNode* op) { if (!IsIndexType(op->dtype)) { return Parent::VisitExpr_(op); } - - // skip analysis of irrelated sum parts - auto f_use_itervar = [this](const VarNode* v) { return var_map_.count(GetRef(v)); }; - PrimExpr a = op->a; - PrimExpr b = op->b; - if (!is_const_int(a) && UsesVar(a, f_use_itervar)) { - a = this->DirectMutate(op->a); - } - if (!is_const_int(b) && UsesVar(b, f_use_itervar)) { - b = this->DirectMutate(op->b); - } + PrimExpr a = this->DirectMutate(op->a); + PrimExpr b = this->DirectMutate(op->b); // const folding PrimExpr const_res = TryConstFold(a, b); @@ -1134,8 +1141,16 @@ PrimExpr IterMapRewriter::VisitExpr_(const FloorDivNode* op) { if (a->IsInstance()) { IterSumExpr ret = Downcast(a); - if (Optional opt = TryFuseIters(ret)) { - return SplitFloorDivConst(opt.value(), b, GetRef(op)); + if (Optional opt = TryFuseIters(ret)) { + IterSumExpr sum = opt.value(); + if (!is_zero(sum->base)) { + Fail(Diagnostic::Error(op->span) + << "Fuse IterSumExpr " << ret + << " failed, cannot floordiv an IterSumExpr with nonzero base"); + return GetRef(op); + } + ICHECK_EQ(sum->args.size(), 1U); + return SplitFloorDivConst(sum->args[0], b, GetRef(op)); } else { Fail(Diagnostic::Error(op->span) << "Fuse IterSumExpr " << ret << " failed"); return GetRef(op); @@ -1211,8 +1226,15 @@ PrimExpr IterMapRewriter::VisitExpr_(const FloorModNode* op) { if (a->IsInstance()) { IterSumExpr ret = Downcast(a); - if (Optional opt = TryFuseIters(ret)) { - return SplitFloorModConst(opt.value(), b, GetRef(op)); + if (Optional opt = TryFuseIters(ret)) { + IterSumExpr sum = opt.value(); + if (!is_zero(sum->base)) { + Fail(Diagnostic::Error(op->span) + << "Fuse IterSumExpr " << ret + << " failed, cannot floormod an IterSumExpr with nonzero base"); + return GetRef(op); + } + return SplitFloorModConst(sum->args[0], b, GetRef(op)); } else { Fail(Diagnostic::Error(op->span) << "Fail to fuse iters of " << ret); return GetRef(op); @@ -1365,7 +1387,7 @@ class SubspaceDivider { if (const auto* op = expr.as()) { return GetRef(op); } else if (const auto* op = expr.as()) { - return IterSplitExpr(IterMark(GetRef(op), 0, extent)); + return IterSplitExpr(IterMark(GetRef(op), extent)); } else { LOG(FATAL) << "Unknown IterMapExpr type"; return NullValue(); @@ -1477,7 +1499,7 @@ class SubspaceDivider { extent *= arg->extent; res.push_back(arg); } - return IterMark(IterSumExpr(Array(res.rbegin(), res.rend()), base), 0, extent); + return IterMark(IterSumExpr(Array(res.rbegin(), res.rend()), base), extent); } DivisionResult DivideIterSplitExpr(const IterSplitExpr& expr) { @@ -1518,10 +1540,8 @@ class SubspaceDivider { if (splits.size() == 1) { return mark_division; } - IterMark outer_mark(Downcast(mark_division.outer), 0, - mark_division.outer_extent); - IterMark inner_mark(Downcast(mark_division.inner), 0, - mark_division.inner_extent); + IterMark outer_mark(Downcast(mark_division.outer), mark_division.outer_extent); + IterMark inner_mark(Downcast(mark_division.inner), mark_division.inner_extent); bool encountered_boundary = mark_division.IsOuter(); std::vector used(splits.size(), false); std::vector inner_iters, outer_iters; @@ -1608,11 +1628,11 @@ Array> SubspaceDivide(const Array& bindings, SubspaceDivider::DivisionResult res = subspace_divider.DivideIterSumExpr(expr, 0); if (subspace_divider.unresolved_count()) return {}; results.push_back( - {IterMark(res.outer, 0, res.outer_extent), IterMark(res.inner, 0, res.inner_extent)}); + {IterMark(res.outer, res.outer_extent), IterMark(res.inner, res.inner_extent)}); } - results.push_back({IterMark(IterSumExpr({}, 0), 0, subspace_divider.GetOuterPreds()), - IterMark(IterSumExpr({}, 0), 0, subspace_divider.GetInnerPreds())}); + results.push_back({IterMark(IterSumExpr({}, 0), subspace_divider.GetOuterPreds()), + IterMark(IterSumExpr({}, 0), subspace_divider.GetInnerPreds())}); return results; } diff --git a/tests/python/unittest/test_arith_intset.py b/tests/python/unittest/test_arith_intset.py index 6cde357c7ef0..92a9e630eb08 100644 --- a/tests/python/unittest/test_arith_intset.py +++ b/tests/python/unittest/test_arith_intset.py @@ -243,6 +243,7 @@ def test_region_lower_bound_for_non_perfect_tile(): h2: tvm.ir.Range(begin=0, end=2), h1: tvm.ir.Range(begin=0, end=5), } + analyzer = tvm.arith.Analyzer() def do_test_point_access(point, predicates, expect): regions = tvm.arith.estimate_region_lower_bound( @@ -256,19 +257,23 @@ def do_test_point_access(point, predicates, expect): assert regions is None else: assert len(regions) == 1 - assert structural_equal(expect[0], regions[0].min_value) - assert structural_equal(expect[1], regions[0].max_value) + assert structural_equal( + analyzer.simplify(expect[0], 3), analyzer.simplify(regions[0].min_value, 3) + ) + assert structural_equal( + analyzer.simplify(expect[1], 3), analyzer.simplify(regions[0].max_value, 3) + ) # normal case of a non-uniform tiling # h3 == 0: region is [1, 9] - # 0 < h3 <= 8: region is [h3 * 8, h3 * 8 + 9] - # h3 > 8: region is [h3 * 8, 223] + # 0 < h3 <= 26: region is [h3 * 8, h3 * 8 + 9] + # h3 > 26: region is [h3 * 8, 223] do_test_point_access( point=h3 * 8 + h2 * 5 + h1, predicates=[1 <= h3 * 8 + h2 * 5 + h1, h3 * 8 + h2 * 5 + h1 < 224], expect=( tvm.tir.max(h3 * 8, 1), - tvm.tir.max(h3 * 8, 1) - tvm.tir.max(h3 * 8, 214) - tvm.tir.max(1 + h3 * -8, 0) + 223, + tvm.tir.max(h3 * 8, 1) - tvm.tir.max(h3 * 8, 214) - tvm.tir.max(1 - h3 * 8, 0) + 223, ), ) # should fail on incompatible predicates diff --git a/tests/python/unittest/test_arith_iter_affine_map.py b/tests/python/unittest/test_arith_iter_affine_map.py index 492b611b4d0c..ecd744ae11ac 100644 --- a/tests/python/unittest/test_arith_iter_affine_map.py +++ b/tests/python/unittest/test_arith_iter_affine_map.py @@ -44,7 +44,7 @@ def var_dom(iters): return {var: tvm.ir.Range(0, ext) for var, ext in iters} -def assert_iter_sum_pattern(sum_expr, extent, base, scale=1, mark_min=0, mark_extent=None): +def assert_iter_sum_pattern(sum_expr, extent, base, scale=1, mark_extent=None): """Check the sum expr have the right pattern.""" assert isinstance(sum_expr, tvm.arith.IterSumExpr) if extent == 1: @@ -53,9 +53,6 @@ def assert_iter_sum_pattern(sum_expr, extent, base, scale=1, mark_min=0, mark_ex assert len(sum_expr.args) == 1 tvm.testing.assert_prim_expr_equal(sum_expr.args[0].extent, extent) tvm.testing.assert_prim_expr_equal(sum_expr.args[0].scale, scale) - tvm.testing.assert_prim_expr_equal(sum_expr.args[0].source.min, mark_min) - if mark_extent: - tvm.testing.assert_prim_expr_equal(sum_expr.args[0].source.extent, mark_extent) tvm.testing.assert_prim_expr_equal(sum_expr.base, base) @@ -181,8 +178,8 @@ def test_compound(): assert_iter_sum_pattern(res[0], 18, 0) assert_iter_sum_pattern(res[1], 5, 0) # reconstruct the pattern manually - mx = tvm.arith.IterMark(x[0], 0, 10) - my = tvm.arith.IterMark(y[0], 0, 9) + mx = tvm.arith.IterMark(x[0], 10) + my = tvm.arith.IterMark(y[0], 9) xoscale = 3 xiscale = 1 @@ -193,7 +190,7 @@ def test_compound(): myo = tvm.arith.IterSplitExpr(my, 3, 3, yoscale) myi = tvm.arith.IterSplitExpr(my, 1, 3, yiscale) - mz = tvm.arith.IterMark(tvm.arith.IterSumExpr([myo, mxo, myi], 0), 0, 18) + mz = tvm.arith.IterMark(tvm.arith.IterSumExpr([myo, mxo, myi], 0), 18) sz = tvm.arith.IterSumExpr([tvm.arith.IterSplitExpr(mz, 1, 18, 1)], 0) tvm.ir.assert_structural_equal(sz, res[0]) @@ -214,10 +211,10 @@ def test_predicate(): # lower bound only res = tvm.arith.detect_iter_map([x[0] * 10 + y[0]], var_dom([x, y]), x[0] * 10 + y[0] > 5) assert len(res) == 1 - assert_iter_sum_pattern(res[0], 130, 0, mark_min=6, mark_extent=124) + assert_iter_sum_pattern(res[0], 124, 6) res = tvm.arith.detect_iter_map([x[0] * 10 + y[0]], var_dom([x, y]), x[0] * 10 + y[0] >= 6) assert len(res) == 1 - assert_iter_sum_pattern(res[0], 130, 0, mark_min=6, mark_extent=124) + assert_iter_sum_pattern(res[0], 124, 6) # lower bound + upper bound res = tvm.arith.detect_iter_map( @@ -226,14 +223,14 @@ def test_predicate(): tvm.tir.And(x[0] * 10 + y[0] > 5, x[0] * 10 + y[0] < 128), ) assert len(res) == 1 - assert_iter_sum_pattern(res[0], 128, 0, mark_min=6, mark_extent=122) + assert_iter_sum_pattern(res[0], 122, 6) res = tvm.arith.detect_iter_map( [x[0] * 10 + y[0]], var_dom([x, y]), tvm.tir.And(x[0] * 10 + y[0] >= 6, x[0] * 10 + y[0] <= 127), ) assert len(res) == 1 - assert_iter_sum_pattern(res[0], 128, 0, mark_min=6, mark_extent=122) + assert_iter_sum_pattern(res[0], 122, 6) # non-standard form of predicate res = tvm.arith.detect_iter_map([x[0] * 10 + y[0]], var_dom([x, y]), x[0] * 10 < 128 - y[0]) @@ -577,12 +574,12 @@ def test_complex(): ) assert len(res) == 2 - n0_mark = tvm.arith.IterMark(n0[0], 0, n0[1]) - n1_mark = tvm.arith.IterMark(n1[0], 0, n1[1]) - l0_mark = tvm.arith.IterMark(l0[0], 0, l0[1]) - l1_mark = tvm.arith.IterMark(l1[0], 0, l1[1]) - m1_mark = tvm.arith.IterMark(m1[0], 0, m1[1]) - l3_mark = tvm.arith.IterMark(l3[0], 0, l3[1]) + n0_mark = tvm.arith.IterMark(n0[0], n0[1]) + n1_mark = tvm.arith.IterMark(n1[0], n1[1]) + l0_mark = tvm.arith.IterMark(l0[0], l0[1]) + l1_mark = tvm.arith.IterMark(l1[0], l1[1]) + m1_mark = tvm.arith.IterMark(m1[0], m1[1]) + l3_mark = tvm.arith.IterMark(l3[0], l3[1]) m0_expr = tvm.arith.IterSumExpr( [ @@ -591,12 +588,12 @@ def test_complex(): ], 0, ) - m0_mark = tvm.arith.IterMark(m0_expr, 0, 6) + m0_mark = tvm.arith.IterMark(m0_expr, 6) l2_expr = tvm.arith.IterSumExpr( [tvm.arith.IterSplitExpr(m0_mark, 1, 6, 3), tvm.arith.IterSplitExpr(m1_mark, 1, m1[1], 1)], 0, ) - l2_mark = tvm.arith.IterMark(l2_expr, 0, 16) + l2_mark = tvm.arith.IterMark(l2_expr, 16) k0_expr = tvm.arith.IterSplitExpr(l0_mark, 2, 2, 4) k1_expr = tvm.arith.IterSplitExpr(l1_mark, 2, 4, 1) k2_expr = tvm.arith.IterSplitExpr(l2_mark, 4, 4, 8) @@ -607,19 +604,19 @@ def test_complex(): k7_expr = tvm.arith.IterSplitExpr(l3_mark, 1, 4, 1) j0_expr = tvm.arith.IterSumExpr([k0_expr, k1_expr], 0) - j0_mark = tvm.arith.IterMark(j0_expr, 0, 7) + j0_mark = tvm.arith.IterMark(j0_expr, 7) i0_expr = tvm.arith.IterSumExpr( [tvm.arith.IterSplitExpr(j0_mark, 1, 7, 32), k2_expr, k3_expr], 0 ) j3_expr = tvm.arith.IterSumExpr([k6_expr, k7_expr], 0) - j3_mark = tvm.arith.IterMark(j3_expr, 0, 15) + j3_mark = tvm.arith.IterMark(j3_expr, 15) i1_expr = tvm.arith.IterSumExpr( [k4_expr, k5_expr, tvm.arith.IterSplitExpr(j3_mark, 1, 15, 1)], 0 ) - i0_mark = tvm.arith.IterMark(i0_expr, 0, i0[1]) - i1_mark = tvm.arith.IterMark(i1_expr, 0, i1[1]) + i0_mark = tvm.arith.IterMark(i0_expr, i0[1]) + i1_mark = tvm.arith.IterMark(i1_expr, i1[1]) i0_final = tvm.arith.IterSumExpr([tvm.arith.IterSplitExpr(i0_mark, 1, i0[1], 1)], 0) i1_final = tvm.arith.IterSumExpr([tvm.arith.IterSplitExpr(i1_mark, 1, i1[1], 1)], 0) @@ -688,7 +685,7 @@ def test_normalize_iter_map_to_expr(): tvm.ir.assert_structural_equal(tvm.arith.normalize_iter_map_to_expr(res[1]), flm(x[0], 5)) # iter mark wrap a complex expr - split = tvm.arith.IterSplitExpr(tvm.arith.IterMark(x[0] * y[0] + 1, 0, 1024), 1, 1024, 1) + split = tvm.arith.IterSplitExpr(tvm.arith.IterMark(x[0] * y[0] + 1, 1024), 1, 1024, 1) tvm.ir.assert_structural_equal(tvm.arith.normalize_iter_map_to_expr(split), x[0] * y[0] + 1) diff --git a/tests/python/unittest/test_tir_schedule_state_cached_flags.py b/tests/python/unittest/test_tir_schedule_state_cached_flags.py index f06ca129490b..334723533d85 100644 --- a/tests/python/unittest/test_tir_schedule_state_cached_flags.py +++ b/tests/python/unittest/test_tir_schedule_state_cached_flags.py @@ -320,11 +320,11 @@ def non_perfect_tiling_cache(a: T.handle, b: T.handle) -> None: Y = T.match_buffer(b, [224, 224], dtype="float32") cache = T.alloc_buffer([224, 224], dtype="float32") for hh_0, ww_0 in T.grid(28, 28): - for ax0 in T.serial(0, T.min(hh_0 * 8 + 8, 223) + 1 - T.max(hh_0 * 8 - 1, 0)): - for ax1 in T.serial(0, T.min(ww_0 * 8 + 8, 223) + 1 - T.max(ww_0 * 8 - 1, 0)): + for ax0 in T.serial(0, 10): + for ax1 in T.serial(0, 10): with T.block("cache"): - h = T.axis.spatial(224, T.max(hh_0 * 8 - 1, 0) + ax0) - w = T.axis.spatial(224, T.max(ww_0 * 8 - 1, 0) + ax1) + h = T.axis.spatial(224, hh_0 * 8 - 1 + ax0) + w = T.axis.spatial(224, ww_0 * 8 - 1 + ax1) T.where( 1 <= hh_0 * 8 + ax0 and hh_0 * 8 + ax0 < 225 From 899e2eb2d48409d88b46952fe7e559c4e99039d8 Mon Sep 17 00:00:00 2001 From: wrongtest Date: Sun, 19 Dec 2021 18:33:52 +0800 Subject: [PATCH 4/7] add testcase of fused iters sum with multiple lowerbounds --- src/arith/iter_affine_map.cc | 8 ++++---- .../unittest/test_arith_iter_affine_map.py | 17 +++++++++++++++++ 2 files changed, 21 insertions(+), 4 deletions(-) diff --git a/src/arith/iter_affine_map.cc b/src/arith/iter_affine_map.cc index 2a1792b5428f..d9065f1905e6 100644 --- a/src/arith/iter_affine_map.cc +++ b/src/arith/iter_affine_map.cc @@ -781,21 +781,21 @@ std::vector MatchBoundConstraints(PrimExpr pred, bool is_finish = false; bool is_greater = false; bool is_equal = false; - if ((rest && (lhs < rhs)).Match(pred)) { + if ((rest && (lhs < rhs)).Match(pred) || ((lhs < rhs) && rest).Match(pred)) { // pass } else if ((lhs < rhs).Match(pred)) { is_finish = true; - } else if ((rest && (lhs <= rhs)).Match(pred)) { + } else if ((rest && (lhs <= rhs)).Match(pred) || ((lhs <= rhs) && rest).Match(pred)) { is_equal = true; } else if ((lhs <= rhs).Match(pred)) { is_equal = true; is_finish = true; - } else if ((rest && (lhs > rhs)).Match(pred)) { + } else if ((rest && (lhs > rhs)).Match(pred) || ((lhs > rhs) && rest).Match(pred)) { is_greater = true; } else if ((lhs > rhs).Match(pred)) { is_greater = true; is_finish = true; - } else if ((rest && (lhs >= rhs)).Match(pred)) { + } else if ((rest && (lhs >= rhs)).Match(pred) || ((lhs >= rhs) && rest).Match(pred)) { is_greater = true; is_equal = true; } else if ((lhs >= rhs).Match(pred)) { diff --git a/tests/python/unittest/test_arith_iter_affine_map.py b/tests/python/unittest/test_arith_iter_affine_map.py index ecd744ae11ac..dcd0e67f76ca 100644 --- a/tests/python/unittest/test_arith_iter_affine_map.py +++ b/tests/python/unittest/test_arith_iter_affine_map.py @@ -232,6 +232,23 @@ def test_predicate(): assert len(res) == 1 assert_iter_sum_pattern(res[0], 122, 6) + # lower bound on many fused iters + # i4 * 6 + i5 in [3, 9), extent=6 (= scale of i2) + # i2 * 30 + i3 * 15 in [30, 90), extent=60 (= scale of i1) + # i1 * 60 in [60, 240), extent=180 (= scale of i0) + i0 = tvm.tir.Var("i0", "int32"), 3 + i1 = tvm.tir.Var("i1", "int32"), 4 + i2 = tvm.tir.Var("i2", "int32"), 3 + i3 = tvm.tir.Var("i3", "int32"), 2 + i4 = tvm.tir.Var("i4", "int32"), 3 + i5 = tvm.tir.Var("i5", "int32"), 6 + res = tvm.arith.detect_iter_map( + [i0[0] * 180 + i1[0] * 60 + i2[0] * 30 + i3[0] * 15 + i4[0] * 6 + i5[0]], + var_dom([i0, i1, i2, i3, i4, i5]), + tvm.tir.And(1 <= i1[0], tvm.tir.And(2 <= i2[0] * 2 + i3[0], 3 <= i4[0] * 6 + i5[0])), + ) + assert_iter_sum_pattern(res[0], 540, 93) + # non-standard form of predicate res = tvm.arith.detect_iter_map([x[0] * 10 + y[0]], var_dom([x, y]), x[0] * 10 < 128 - y[0]) assert len(res) == 1 From 79b72777e97a734507ef66d03462052f893ffcae Mon Sep 17 00:00:00 2001 From: wrongtest Date: Wed, 22 Dec 2021 11:04:45 +0800 Subject: [PATCH 5/7] add more affine check testcases, fix bug for single iter and duplicate constraints on iter --- src/arith/iter_affine_map.cc | 89 +++++++---- tests/python/unittest/test_arith_intset.py | 71 ++++++--- .../unittest/test_arith_iter_affine_map.py | 138 ++++++++++++++++-- .../test_tir_schedule_state_cached_flags.py | 2 +- 4 files changed, 244 insertions(+), 56 deletions(-) diff --git a/src/arith/iter_affine_map.cc b/src/arith/iter_affine_map.cc index d9065f1905e6..1ec1c12d76bf 100644 --- a/src/arith/iter_affine_map.cc +++ b/src/arith/iter_affine_map.cc @@ -372,12 +372,12 @@ class IterMapRewriter : public ExprMutator { // IterSplit(k, scale=1)), // extent=9) // scale=1)) - // Example(2): expr = i*9 + j*2 + k, i in [0, 4) j in [0, 5) k in [0, 2) + // Example(2): expr = i*8 + j*2 + k, i in [0, 4) j in [0, 5) k in [0, 2) // predicate: 1 <= j*2 + k < 9 - // Then, flattened form = IterSum(IterSplit(i, scale=9), + // Then, flattened form = IterSum(IterSplit(i, scale=8), // IterSplit(j, scale=2), // IterSplit(k, scale=1)) - // normal form = IterSum(IterSplit(i, scale=9), + // normal form = IterSum(IterSplit(i, scale=8), // IterSplit(IterMark(IterSum(IterSplit(j, scale=2), // IterSplit(k, scale=1), base=-1), // extent=9-1) @@ -495,7 +495,7 @@ class IterMapRewriter : public ExprMutator { */ IterSumExpr NormalizeToIterOnBoundExpr(IterSumExpr expr, PrimExpr predicate_induced_min, PrimExpr predicate_induced_max) { - // remove base temporarily since `TryFuseIters` require zero base iter sum + // normalize to zero base PrimExpr base = expr->base; if (!is_zero(base)) { expr.CopyOnWrite()->base = 0; @@ -506,39 +506,40 @@ class IterMapRewriter : public ExprMutator { ICHECK(!opt.defined() || opt.value()->args.size() == 1); // scale should be 1 if (opt.defined() && is_one(opt.value()->args[0]->scale)) { - IterSplitExpr fused_split = opt.value()->args[0]; - IterSumExpr sum = Downcast(fused_split->source->source); + const IterSplitExpr split = opt.value()->args[0]; + IterSumExpr structured_form = Downcast(split->source->source); // get the flattened form - auto it = flattened_map_.find(sum); + auto it = flattened_map_.find(structured_form); ICHECK(it != flattened_map_.end()); IterSumExpr flattened_form = it->second; - // get the mark + // get the mark and offset of the structured_form auto it_mark = sum_fuse_map_.find(flattened_form); ICHECK(it_mark != sum_fuse_map_.end()); IterMark mark = it_mark->second.mark; PrimExpr mark_offset = it_mark->second.offset; - // update iter mark iter range to [0, mark->extent) ^ [pred_min, pred_max) - PrimExpr mark_min = 0; - PrimExpr mark_max = mark->extent; + PrimExpr iter_min = mark_offset; + PrimExpr iter_max = iter_min + mark->extent; if (predicate_induced_min.defined()) { - mark_min = max(predicate_induced_min, mark_min); + iter_min = max(predicate_induced_min, iter_min); } if (predicate_induced_max.defined()) { - mark_max = min(predicate_induced_max, mark_max); + iter_max = min(predicate_induced_max, iter_max); } - // mark.CopyOnWrite()->min = mark_min; - mark.CopyOnWrite()->source = mark->source - mark_min; - mark.CopyOnWrite()->extent = mark_max - mark_min; - mark_offset = mark_offset + mark_min; - - // update the bound of the lhs based on predicate_induced_extent - sum_fuse_map_[flattened_form] = {mark, mark_offset}; + if (!is_zero(iter_min)) { + // structured form's offset should be updated + flattened_map_.erase(structured_form); + structured_form.CopyOnWrite()->base = -iter_min; + mark.CopyOnWrite()->source = structured_form; + flattened_map_[structured_form] = flattened_form; + } + mark.CopyOnWrite()->extent = iter_max - iter_min; + sum_fuse_map_[flattened_form] = {mark, iter_min}; // we need to note down the flattened form of constrained iterators // to check the validity of constraints, see also CheckConstraints() constrained_iters_flattened_.push_back(flattened_form); - expr.CopyOnWrite()->args = Array({fused_split}); - expr.CopyOnWrite()->base = base + mark_min; + expr.CopyOnWrite()->args = Array({split}); + expr.CopyOnWrite()->base = base + iter_min; return expr; } Fail(Diagnostic::Error(expr->span) @@ -554,7 +555,7 @@ class IterMapRewriter : public ExprMutator { */ IterSumExpr NormalizeToIterWithOffset(IterSumExpr expr) { // We are normalizing a regular iter - if (expr->args.size() <= 1) return expr; + if (expr->args.size() < 1) return expr; Optional opt = TryFuseIters(expr); if (opt.defined()) { return opt.value(); @@ -593,6 +594,7 @@ class IterMapRewriter : public ExprMutator { Optional TryFuseIters(IterSumExpr expr) { // select the iterators in order std::vector visited(expr->args.size(), false); + size_t num_visited = 0; std::vector flattened_iters, grouped_iters; // canonicalize the expression into two different forms: flattened form and structured form // step0. check if find the base scale first @@ -606,7 +608,11 @@ class IterMapRewriter : public ExprMutator { } } } - if (!base_scale) return NullOpt; + if (!base_scale) { + diag_ctx_.Emit(Diagnostic::Error(expr->span) + << "Fuse iters failed, can not find a valid base scale"); + return NullOpt; + } // check if it can be remapped into a fused pattern. PrimExpr expected_extra_base = 0; PrimExpr expected_scale = base_scale.value(); @@ -616,7 +622,11 @@ class IterMapRewriter : public ExprMutator { for (; j < expr->args.size(); ++j) { if (!visited[j] && analyzer_->CanProveEqual(expr->args[j]->scale, expected_scale)) break; } - if (j == expr->args.size()) return NullOpt; + if (j == expr->args.size()) { + diag_ctx_.Emit(Diagnostic::Error(expr->span) + << "Fuse iters failed, can not find expected scale " << expected_scale); + return NullOpt; + } // look for the longest constrained iter started from expr->args[j] // Example: expr = i*9 + j*2 + k, i in [0, 4) j in [0, 5) k in [0, 2) // predicate: j*2 + k < 9 @@ -637,6 +647,7 @@ class IterMapRewriter : public ExprMutator { // Example: expr = i*9 + j*2 + k, i in [0, 4) j in [0, 5) k in [0, 2) // predicate = j*2 + k < 9 // then j*2 + k matches the lower two splits of expr + bool match_constraint_suffix = false; for (auto it = constraint_to_match.value()->args.rbegin(); it != constraint_to_match.value()->args.rend(); ++it) { size_t k = 0; @@ -646,10 +657,33 @@ class IterMapRewriter : public ExprMutator { break; } } - if (k == expr->args.size()) return NullOpt; + if (k == expr->args.size()) { + if (i == 0 && num_visited == visited.size()) { + // if match failed because of iterations are used out instead of scale mismatch, + // and all used iters are visited during current match round, fallback to skip the + // constraint. Example: exprs = [i * 2 + j, k], i in [0, 3), j in [0, 2), k in [0, 4) + // predicate = i * 8 + j * 4 + k < 10 + ICHECK_EQ(flattened_iters.size(), num_visited); + for (size_t l = 0; l < flattened_iters.size(); ++l) { + grouped_iters.push_back(flattened_iters[l]); + expected_scale *= flattened_iters[l]->extent; + } + match_constraint_suffix = true; + break; + } + diag_ctx_.Emit(Diagnostic::Error(expr->span) + << "Fuse iters failed, can not find flattened iter match constraint " + << constraint_to_match.value()); + return NullOpt; + } visited[k] = true; + num_visited += 1; flattened_iters.push_back(expr->args[k]); } + if (match_constraint_suffix) { + // all iters are used to match the constraint, but only a suffix is matched. + break; + } auto iter = sum_fuse_map_.find(constraint_to_match.value()); ICHECK(iter != sum_fuse_map_.end()); const IterMarkWithOffset& iter_matched = iter->second; @@ -661,6 +695,7 @@ class IterMapRewriter : public ExprMutator { } else { // constraint_to_match not found, skip this iterator visited[j] = true; + num_visited += 1; flattened_iters.push_back(expr->args[j]); grouped_iters.push_back(expr->args[j]); expected_scale *= expr->args[j]->extent; @@ -681,6 +716,8 @@ class IterMapRewriter : public ExprMutator { // old iter if (!analyzer_->CanProveEqual(expected_extra_base, it->second.offset * base_scale.value())) { // the extra offset is not consistent with old + diag_ctx_.Emit(Diagnostic::Error(expr->span) + << "Fuse iters failed, the extra offset is not consistent with old"); return NullOpt; } return IterSumExpr({IterSplitExpr(it->second.mark, base_scale.value())}, diff --git a/tests/python/unittest/test_arith_intset.py b/tests/python/unittest/test_arith_intset.py index 92a9e630eb08..b40f3c9f56ea 100644 --- a/tests/python/unittest/test_arith_intset.py +++ b/tests/python/unittest/test_arith_intset.py @@ -16,6 +16,7 @@ # under the License. import tvm from tvm import te +from tvm import tir from tvm.ir.base import structural_equal @@ -238,14 +239,9 @@ def test_region_lower_bound_for_non_perfect_tile(): h1 = tvm.tir.Var("h1", "int32") h2 = tvm.tir.Var("h2", "int32") h3 = tvm.tir.Var("h3", "int32") - # h1, h2 are bounded, h3 is free - var_dom = { - h2: tvm.ir.Range(begin=0, end=2), - h1: tvm.ir.Range(begin=0, end=5), - } analyzer = tvm.arith.Analyzer() - def do_test_point_access(point, predicates, expect): + def do_test_point_access(point, predicates, var_dom, expect): regions = tvm.arith.estimate_region_lower_bound( region=[ tvm.ir.Range.from_min_extent(min_value=point, extent=1), @@ -257,29 +253,68 @@ def do_test_point_access(point, predicates, expect): assert regions is None else: assert len(regions) == 1 - assert structural_equal( - analyzer.simplify(expect[0], 3), analyzer.simplify(regions[0].min_value, 3) - ) - assert structural_equal( - analyzer.simplify(expect[1], 3), analyzer.simplify(regions[0].max_value, 3) - ) - - # normal case of a non-uniform tiling + for binding, expect_min, expect_max in expect: + min_diff = expect_min - regions[0].min_value + assert analyzer.simplify(tir.stmt_functor.substitute(min_diff, binding), 3) == 0 + max_diff = expect_max - regions[0].max_value + assert analyzer.simplify(tir.stmt_functor.substitute(max_diff, binding), 3) == 0 + + # non-uniform tiling, single inner variable # h3 == 0: region is [1, 9] # 0 < h3 <= 26: region is [h3 * 8, h3 * 8 + 9] # h3 > 26: region is [h3 * 8, 223] + do_test_point_access( + point=h3 * 8 + h2, + predicates=[1 <= h3 * 8 + h2, h3 * 8 + h2 < 224], + var_dom={ + h2: tvm.ir.Range(begin=0, end=10), + }, + expect=[ + ( + {}, + tvm.tir.max(h3 * 8, 1), + tvm.tir.max(h3 * 8, 1) + - tvm.tir.max(h3 * 8, 214) + - tvm.tir.max(1 - h3 * 8, 0) + + 223, + ), + ({h3: 0}, 1, 9), + ({h3: 10}, h3 * 8, h3 * 8 + 9), + ({h3: 27}, h3 * 8, 223), + ], + ) + + # non-uniform tiling, two inner variables do_test_point_access( point=h3 * 8 + h2 * 5 + h1, predicates=[1 <= h3 * 8 + h2 * 5 + h1, h3 * 8 + h2 * 5 + h1 < 224], - expect=( - tvm.tir.max(h3 * 8, 1), - tvm.tir.max(h3 * 8, 1) - tvm.tir.max(h3 * 8, 214) - tvm.tir.max(1 - h3 * 8, 0) + 223, - ), + var_dom={ + h2: tvm.ir.Range(begin=0, end=2), + h1: tvm.ir.Range(begin=0, end=5), + }, + expect=[ + ( + {}, + tvm.tir.max(h3 * 8, 1), + tvm.tir.max(h3 * 8, 1) + - tvm.tir.max(h3 * 8, 214) + - tvm.tir.max(1 - h3 * 8, 0) + + 223, + ), + ({h3: 0}, 1, 9), + ({h3: 10}, h3 * 8, h3 * 8 + 9), + ({h3: 27}, h3 * 8, 223), + ], ) + # should fail on incompatible predicates do_test_point_access( point=h3 * 8 + h2 * 5 + h1, predicates=[1 <= h3 * 8 + h2 * 5 + h1, h3 * 8 + h1 * 2 + h2 < 224], + var_dom={ + h2: tvm.ir.Range(begin=0, end=2), + h1: tvm.ir.Range(begin=0, end=5), + }, expect=None, ) diff --git a/tests/python/unittest/test_arith_iter_affine_map.py b/tests/python/unittest/test_arith_iter_affine_map.py index dcd0e67f76ca..c40dba01f124 100644 --- a/tests/python/unittest/test_arith_iter_affine_map.py +++ b/tests/python/unittest/test_arith_iter_affine_map.py @@ -44,7 +44,7 @@ def var_dom(iters): return {var: tvm.ir.Range(0, ext) for var, ext in iters} -def assert_iter_sum_pattern(sum_expr, extent, base, scale=1, mark_extent=None): +def assert_iter_sum_pattern(sum_expr, extent, base, scale=1): """Check the sum expr have the right pattern.""" assert isinstance(sum_expr, tvm.arith.IterSumExpr) if extent == 1: @@ -232,23 +232,139 @@ def test_predicate(): assert len(res) == 1 assert_iter_sum_pattern(res[0], 122, 6) - # lower bound on many fused iters + # constraint on one fused iter + i = tvm.tir.Var("i", "int32") + j = tvm.tir.Var("j", "int32") + k = tvm.tir.Var("k", "int32") + res = tvm.arith.detect_iter_map( + [i * 8 + j * 2 + k], + var_dom([(i, 11), (j, 5), (k, 2)]), + tvm.tir.all(1 <= j * 2 + k, j * 2 + k < 9), + ) + assert_iter_sum_pattern(res[0], 88, 1) + + # constraint on single var + res = tvm.arith.detect_iter_map([i], var_dom([(i, 48)]), tvm.tir.all(i < 10)) + assert_iter_sum_pattern(res[0], 10, 0) + + # iterations are subparts of constraint, case 1 + res = tvm.arith.detect_iter_map( + [i, j, k], + var_dom([(i, 128), (j, 128), (k, 128)]), + tvm.tir.all(i * 16384 + j * 128 + k < 100), + ) + assert_iter_sum_pattern(res[0], 128, 0) + assert_iter_sum_pattern(res[1], 128, 0) + assert_iter_sum_pattern(res[2], 128, 0) + + # iterations are subparts of constraint, case 2 + res = tvm.arith.detect_iter_map( + [i * 128 + j, k], + var_dom([(i, 128), (j, 128), (k, 128)]), + tvm.tir.all(i * 16384 + j * 128 + k < 100), + ) + assert_iter_sum_pattern(res[0], 16384, 0) + assert_iter_sum_pattern(res[1], 128, 0) + + # constraint on nested fused iters + res = tvm.arith.detect_iter_map( + [i * 8 + j * 2 + k], + var_dom([(i, 11), (j, 5), (k, 2)]), + tvm.tir.all(1 <= j * 2 + k, j * 2 + k < 9, 3 <= i * 8 + j * 2 + k, i * 8 + j * 2 + k < 25), + ) + assert_iter_sum_pattern(res[0], 22, 3) + + # duplicate constraint on one fused iter + res = tvm.arith.detect_iter_map( + [i * 6 + j * 2 + k], + var_dom([(i, 11), (j, 5), (k, 2)]), + tvm.tir.all(1 <= j * 2 + k, 2 <= j * 2 + k, j * 2 + k < 8, j * 2 + k < 9), + ) + assert_iter_sum_pattern(res[0], 66, 2) + + # duplicate constraint on nested fused iters + res = tvm.arith.detect_iter_map( + [i * 6 + j * 2 + k], + var_dom([(i, 11), (j, 5), (k, 2)]), + tvm.tir.all( + 1 <= j * 2 + k, + 2 <= j * 2 + k, + j * 2 + k < 8, + j * 2 + k < 9, + 3 <= i * 6 + j * 2 + k, + i * 6 + j * 2 + k < 25, + 1 <= i * 6 + j * 2 + k, + i * 6 + j * 2 + k < 18, + ), + ) + assert_iter_sum_pattern(res[0], 15, 3) + + # constraint on non-disjoint fused iters should fail + res = tvm.arith.detect_iter_map( + [i * 8 + j * 2 + k], + var_dom([(i, 11), (j, 5), (k, 2)]), + tvm.tir.all(2 <= j * 2 + k, 0 <= i * 4 + j), + ) + assert len(res) == 0 + + # constraint on many disjoint fused iters, case 1 # i4 * 6 + i5 in [3, 9), extent=6 (= scale of i2) # i2 * 30 + i3 * 15 in [30, 90), extent=60 (= scale of i1) # i1 * 60 in [60, 240), extent=180 (= scale of i0) - i0 = tvm.tir.Var("i0", "int32"), 3 - i1 = tvm.tir.Var("i1", "int32"), 4 - i2 = tvm.tir.Var("i2", "int32"), 3 - i3 = tvm.tir.Var("i3", "int32"), 2 - i4 = tvm.tir.Var("i4", "int32"), 3 - i5 = tvm.tir.Var("i5", "int32"), 6 + i0 = tvm.tir.Var("i0", "int32") + i1 = tvm.tir.Var("i1", "int32") + i2 = tvm.tir.Var("i2", "int32") + i3 = tvm.tir.Var("i3", "int32") + i4 = tvm.tir.Var("i4", "int32") + i5 = tvm.tir.Var("i5", "int32") res = tvm.arith.detect_iter_map( - [i0[0] * 180 + i1[0] * 60 + i2[0] * 30 + i3[0] * 15 + i4[0] * 6 + i5[0]], - var_dom([i0, i1, i2, i3, i4, i5]), - tvm.tir.And(1 <= i1[0], tvm.tir.And(2 <= i2[0] * 2 + i3[0], 3 <= i4[0] * 6 + i5[0])), + [i0 * 180 + i1 * 60 + i2 * 30 + i3 * 15 + i4 * 6 + i5], + var_dom([(i0, 3), (i1, 4), (i2, 3), (i3, 2), (i4, 3), (i5, 6)]), + tvm.tir.all(1 <= i1, 2 <= i2 * 2 + i3, 3 <= i4 * 6 + i5), ) assert_iter_sum_pattern(res[0], 540, 93) + # constraint on many disjoint fused iters, case 2 + res = tvm.arith.detect_iter_map( + [i0 * 45 + i1 * 45 + i2 * 9 + i3 * 4 + i4], + var_dom([(i0, 3), (i1, 2), (i2, 5), (i3, 3), (i4, 4)]), + tvm.tir.all(3 <= i1 * 5 + i2, i1 * 5 + i2 < 8, 1 <= i3 * 4 + i4, i3 * 4 + i4 < 10), + ) + assert_iter_sum_pattern(res[0], 135, 28) + + # constraint on split iters + res = tvm.arith.detect_iter_map( + [i % 16, i // 16], + var_dom([(i, 1024)]), + tvm.tir.all(3 <= i % 16, i % 16 < 10, 4 <= i // 16, i // 16 < 12), + require_bijective=True, + ) + assert_iter_sum_pattern(res[0], 7, 3) + assert_iter_sum_pattern(res[1], 8, 4) + + # constraint on split iters, nested case 1 + res = tvm.arith.detect_iter_map( + [(i * 32 + j) % 16], + var_dom([(i, 5), (j, 32)]), + tvm.tir.all(3 <= (i * 32 + j) % 16, (i * 32 + j) % 16 < 10), + ) + assert_iter_sum_pattern(res[0], 7, 3) + + # constraint on split iters, nested case 2 + res = tvm.arith.detect_iter_map( + [(i * 32 + j) % 16], + var_dom([(i, 5), (j, 32)]), + tvm.tir.all(1 <= i * 32 + j, i * 32 + j <= 32), + ) + assert len(res) == 0 + res = tvm.arith.detect_iter_map( + [(i * 32 + j - 1) % 16, (i * 32 + j - 1) // 16], + var_dom([(i, 5), (j, 32)]), + tvm.tir.all(1 <= i * 32 + j, i * 32 + j <= 64), + ) + assert_iter_sum_pattern(res[0], 16, 0) + assert_iter_sum_pattern(res[1], 4, 0) + # non-standard form of predicate res = tvm.arith.detect_iter_map([x[0] * 10 + y[0]], var_dom([x, y]), x[0] * 10 < 128 - y[0]) assert len(res) == 1 diff --git a/tests/python/unittest/test_tir_schedule_state_cached_flags.py b/tests/python/unittest/test_tir_schedule_state_cached_flags.py index 334723533d85..d86af72fca93 100644 --- a/tests/python/unittest/test_tir_schedule_state_cached_flags.py +++ b/tests/python/unittest/test_tir_schedule_state_cached_flags.py @@ -751,7 +751,7 @@ def test_non_perfect_tiling_cache(): ) assert s._get_cached_flags(_get_block(s, "compute")) == CachedFlags( affine_binding=True, - region_cover=True, + region_cover=False, stage_pipeline=True, ) # pylint: enable=protected-access From d71bd245eb717218c3f6d3564eae7349d99360b0 Mon Sep 17 00:00:00 2001 From: wrongtest Date: Thu, 23 Dec 2021 10:49:03 +0800 Subject: [PATCH 6/7] add a newline to comment --- src/arith/iter_affine_map.cc | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/src/arith/iter_affine_map.cc b/src/arith/iter_affine_map.cc index 1ec1c12d76bf..b479ef010b16 100644 --- a/src/arith/iter_affine_map.cc +++ b/src/arith/iter_affine_map.cc @@ -362,7 +362,8 @@ class IterMapRewriter : public ExprMutator { // input iter marks std::vector input_marks_; // The map for sum that maps flattened form to IterMark with normal form and extent (and possibly - // an extra offset) Example(1): expr = i*9 + j*2 + k, i in [0, 4) j in [0, 5) k in [0, 2) + // an extra offset) + // Example(1): expr = i*9 + j*2 + k, i in [0, 4) j in [0, 5) k in [0, 2) // predicate: j*2 + k < 9 // Then, flattened form = IterSum(IterSplit(i, scale=9), // IterSplit(j, scale=2), From 0337ffae04465576a39bbbe873c394092c3544f9 Mon Sep 17 00:00:00 2001 From: wrongtest Date: Fri, 24 Dec 2021 13:38:23 +0800 Subject: [PATCH 7/7] forbidden predicate unmatch --- src/arith/iter_affine_map.cc | 21 ------------------- .../unittest/test_arith_iter_affine_map.py | 11 ++++------ .../unittest/test_tir_schedule_reorder.py | 5 ++--- .../unittest/test_tir_schedule_rfactor.py | 8 +++---- 4 files changed, 9 insertions(+), 36 deletions(-) diff --git a/src/arith/iter_affine_map.cc b/src/arith/iter_affine_map.cc index b479ef010b16..c9d4b1edc3a0 100644 --- a/src/arith/iter_affine_map.cc +++ b/src/arith/iter_affine_map.cc @@ -595,7 +595,6 @@ class IterMapRewriter : public ExprMutator { Optional TryFuseIters(IterSumExpr expr) { // select the iterators in order std::vector visited(expr->args.size(), false); - size_t num_visited = 0; std::vector flattened_iters, grouped_iters; // canonicalize the expression into two different forms: flattened form and structured form // step0. check if find the base scale first @@ -648,7 +647,6 @@ class IterMapRewriter : public ExprMutator { // Example: expr = i*9 + j*2 + k, i in [0, 4) j in [0, 5) k in [0, 2) // predicate = j*2 + k < 9 // then j*2 + k matches the lower two splits of expr - bool match_constraint_suffix = false; for (auto it = constraint_to_match.value()->args.rbegin(); it != constraint_to_match.value()->args.rend(); ++it) { size_t k = 0; @@ -659,32 +657,14 @@ class IterMapRewriter : public ExprMutator { } } if (k == expr->args.size()) { - if (i == 0 && num_visited == visited.size()) { - // if match failed because of iterations are used out instead of scale mismatch, - // and all used iters are visited during current match round, fallback to skip the - // constraint. Example: exprs = [i * 2 + j, k], i in [0, 3), j in [0, 2), k in [0, 4) - // predicate = i * 8 + j * 4 + k < 10 - ICHECK_EQ(flattened_iters.size(), num_visited); - for (size_t l = 0; l < flattened_iters.size(); ++l) { - grouped_iters.push_back(flattened_iters[l]); - expected_scale *= flattened_iters[l]->extent; - } - match_constraint_suffix = true; - break; - } diag_ctx_.Emit(Diagnostic::Error(expr->span) << "Fuse iters failed, can not find flattened iter match constraint " << constraint_to_match.value()); return NullOpt; } visited[k] = true; - num_visited += 1; flattened_iters.push_back(expr->args[k]); } - if (match_constraint_suffix) { - // all iters are used to match the constraint, but only a suffix is matched. - break; - } auto iter = sum_fuse_map_.find(constraint_to_match.value()); ICHECK(iter != sum_fuse_map_.end()); const IterMarkWithOffset& iter_matched = iter->second; @@ -696,7 +676,6 @@ class IterMapRewriter : public ExprMutator { } else { // constraint_to_match not found, skip this iterator visited[j] = true; - num_visited += 1; flattened_iters.push_back(expr->args[j]); grouped_iters.push_back(expr->args[j]); expected_scale *= expr->args[j]->extent; diff --git a/tests/python/unittest/test_arith_iter_affine_map.py b/tests/python/unittest/test_arith_iter_affine_map.py index c40dba01f124..6b3c29592eb6 100644 --- a/tests/python/unittest/test_arith_iter_affine_map.py +++ b/tests/python/unittest/test_arith_iter_affine_map.py @@ -247,24 +247,21 @@ def test_predicate(): res = tvm.arith.detect_iter_map([i], var_dom([(i, 48)]), tvm.tir.all(i < 10)) assert_iter_sum_pattern(res[0], 10, 0) - # iterations are subparts of constraint, case 1 + # iterations are subparts of constraint, invalid, case 1 res = tvm.arith.detect_iter_map( [i, j, k], var_dom([(i, 128), (j, 128), (k, 128)]), tvm.tir.all(i * 16384 + j * 128 + k < 100), ) - assert_iter_sum_pattern(res[0], 128, 0) - assert_iter_sum_pattern(res[1], 128, 0) - assert_iter_sum_pattern(res[2], 128, 0) + assert len(res) == 0 - # iterations are subparts of constraint, case 2 + # iterations are subparts of constraint, invalid, case 2 res = tvm.arith.detect_iter_map( [i * 128 + j, k], var_dom([(i, 128), (j, 128), (k, 128)]), tvm.tir.all(i * 16384 + j * 128 + k < 100), ) - assert_iter_sum_pattern(res[0], 16384, 0) - assert_iter_sum_pattern(res[1], 128, 0) + assert len(res) == 0 # constraint on nested fused iters res = tvm.arith.detect_iter_map( diff --git a/tests/python/unittest/test_tir_schedule_reorder.py b/tests/python/unittest/test_tir_schedule_reorder.py index 8267a369cf5d..fd2d82d1ff1f 100644 --- a/tests/python/unittest/test_tir_schedule_reorder.py +++ b/tests/python/unittest/test_tir_schedule_reorder.py @@ -217,9 +217,8 @@ def test_reorder_with_predicate(): sch = tir.Schedule(elementwise_predicate, debug_mask="all") block_b = sch.get_block("B") i, j, k, l = sch.get_loops(block_b) - sch.reorder(l, i) - tvm.ir.assert_structural_equal(elementwise_reordered_with_predicate, sch.mod["main"]) - verify_trace_roundtrip(sch=sch, mod=elementwise_predicate) + with pytest.raises(tvm.tir.ScheduleError): + sch.reorder(l, i) def test_reorder_fail_with_multi_appearance_loops(): diff --git a/tests/python/unittest/test_tir_schedule_rfactor.py b/tests/python/unittest/test_tir_schedule_rfactor.py index 35d4f5f4b76a..f5fc5a73d038 100644 --- a/tests/python/unittest/test_tir_schedule_rfactor.py +++ b/tests/python/unittest/test_tir_schedule_rfactor.py @@ -690,11 +690,9 @@ def test_reduction_rfactor_predicate(): # pylint: disable=invalid-name s = tir.Schedule(rowsum_predicate, debug_mask="all") B = s.get_block("B") _, ko, _ = s.get_loops(B) - rf_block = s.rfactor(ko, 1) - tvm.ir.assert_structural_equal(s.mod["main"], rowsum_predicate_rfactor) - assert s.get(rf_block).same_as(s.get(s.get_block("B_rf"))) - assert s.get(B).same_as(s.get(s.get_block("B"))) - verify_trace_roundtrip(s, mod=rowsum_predicate) + # TODO: should be a tvm.tir.ScheduleError + with pytest.raises(tvm.TVMError): + rf_block = s.rfactor(ko, 1) def test_reduction_rfactor_with_annotation():