diff --git a/include/tvm/arith/iter_affine_map.h b/include/tvm/arith/iter_affine_map.h index eb69c188abf32..8fcecb4cb4292 100644 --- a/include/tvm/arith/iter_affine_map.h +++ b/include/tvm/arith/iter_affine_map.h @@ -276,14 +276,13 @@ class IterSumExpr : public IterMapExpr { * \param predicate The predicate constraints on the input iterators * \param require_bijective A boolean flag that indicates whether the mapping should be bijective. * \param analyzer Analyzer used to get context information. - * \param diag_ctx Diagnostic context. * * \return The detected pattern if a match exists, * otherwise return an empty array. */ Array DetectIterMap(const Array& indices, const Map& input_iters, const PrimExpr& predicate, bool require_bijective, - arith::Analyzer* analyzer, DiagnosticContext diag_ctx); + arith::Analyzer* analyzer); /*! * \brief Use IterVarMap detector to rewrite and simplify the indices * @@ -335,7 +334,6 @@ Map InverseAffineIterMap(const Array& iter_map, * \param predicate The predicate constraints on the input iterators * \param require_bijective A boolean flag that indicates whether the mapping should be bijective. * \param analyzer Analyzer used to get context information. - * \param diag_ctx Diagnostic context. * * \return The result list has length len(bindings) + 1 [0, len(bindings)): The iter map matching result. The inner list is of length 2. @@ -347,8 +345,7 @@ Map InverseAffineIterMap(const Array& iter_map, Array> SubspaceDivide(const Array& bindings, const Map& input_iters, const Array& sub_iters, const PrimExpr& predicate, - bool require_bijective, arith::Analyzer* analyzer, - DiagnosticContext diag_ctx); + bool require_bijective, arith::Analyzer* analyzer); /*! * \brief Given an IterMapExpr, transform it to normal PrimExpr. diff --git a/src/arith/int_set.cc b/src/arith/int_set.cc index 9f5ef644d2bb4..a3fa879afa270 100644 --- a/src/arith/int_set.cc +++ b/src/arith/int_set.cc @@ -867,10 +867,9 @@ Optional> EstimateRegionLowerBound(const Array& region, for (const Range& range : region) { affine_indices.push_back(range->min); } - DiagnosticContext diag_ctx(DiagnosticContext::Default(IRModule())); iter_sum_exprs = DetectIterMap( /*indices=*/affine_indices, /*input_iters=*/var_dom, - /*predicate=*/predicate, /*require_bijective=*/false, analyzer, diag_ctx); + /*predicate=*/predicate, /*require_bijective=*/false, analyzer); } if (iter_sum_exprs.empty()) { return NullOpt; diff --git a/src/arith/iter_affine_map.cc b/src/arith/iter_affine_map.cc index b5e8e66468692..7694300ce043d 100644 --- a/src/arith/iter_affine_map.cc +++ b/src/arith/iter_affine_map.cc @@ -173,9 +173,8 @@ class IterMapRewriter : public ExprMutator { public: using Parent = ExprMutator; - explicit IterMapRewriter(Analyzer* analyzer, const Map& input_iters, - DiagnosticContext diag_ctx) - : analyzer_(analyzer), diag_ctx_(diag_ctx) { + explicit IterMapRewriter(Analyzer* analyzer, const Map& input_iters) + : analyzer_(analyzer) { for (auto kv : input_iters) { const Var& var = kv.first; const Range& vrng = kv.second; @@ -236,8 +235,6 @@ class IterMapRewriter : public ExprMutator { collector.Collect(bindings); for (const IterMark& mark : collector.visited_) { if (TryNormalizeSplits(mark, collector.mark2splits_[mark], require_bijective).empty()) { - diag_ctx_.Emit(Diagnostic::Error(mark->source->span) - << "Fail to normalize iter mark splits: " << mark); return false; } } @@ -245,9 +242,6 @@ class IterMapRewriter : public ExprMutator { // all input marks must be visited for (const IterMark& mark : input_marks_) { if (collector.visited_.count(mark) == 0) { - diag_ctx_.Emit(Diagnostic::Error(mark->source->span) - << "The mapping is not bijective because input iter mark " << mark - << " is not covered, "); return false; } } @@ -297,7 +291,7 @@ class IterMapRewriter : public ExprMutator { PrimExpr VisitExpr(const PrimExpr& input_expr) final { auto expr = ExprMutator::VisitExpr(input_expr); if (expr->IsInstance()) { - Fail(Diagnostic::Error(input_expr->span)); + unresolved_count_++; } return expr; } @@ -347,13 +341,6 @@ class IterMapRewriter : public ExprMutator { } }; - void Fail(const Diagnostic& diagnostic) { - unresolved_count_++; - if (diag_ctx_.defined()) { - diag_ctx_.Emit(diagnostic); - } - } - // Internal analyzer Analyzer* analyzer_; // Counter to keep track of unresolved cases. @@ -390,8 +377,6 @@ class IterMapRewriter : public ExprMutator { std::unordered_map flattened_map_; // The flattened forms of constrained iters std::vector constrained_iters_flattened_; - // Diagnostic context - DiagnosticContext diag_ctx_; /*! * \brief Look for a split in splits that is not used such that its lower_factor is smallest. @@ -448,10 +433,6 @@ class IterMapRewriter : public ExprMutator { if (j == splits.size()) { // we do not allow incomplete split if the bindings should be bijective if (require_bijective) { - diag_ctx_.Emit( - Diagnostic::Error(mark->source->span) - << "Do not allow incomplete split in bijective checking, expected_lower_factor=" - << expected_lower_factor); return Array(); } // look for the next split skipping this lower factor @@ -461,10 +442,6 @@ class IterMapRewriter : public ExprMutator { j = SearchSkipLowerFactor(splits, used, expected_lower_factor); // split not found if (j == splits.size()) { - diag_ctx_.Emit(Diagnostic::Error(mark->source->span) - << "Fail to find split skipping the lower factor in bijective-free " - "checking, expected_lower_factor=" - << expected_lower_factor); return Array(); } } @@ -480,9 +457,6 @@ class IterMapRewriter : public ExprMutator { // For example, y \in [0, 24) [(y / 2) % 6, y % 2] is valid, but y \in [0, 25) is not. if ((require_bijective && !analyzer_->CanProveEqual(expected_lower_factor, mark->extent)) || (!require_bijective && !CanProveDivisible(mark->extent, expected_lower_factor))) { - diag_ctx_.Emit(Diagnostic::Error(mark->source->span) - << "Mark extent of " << mark - << " is not compatible with expected_lower_factor=" << expected_lower_factor); return Array(); } return Array(iters.rbegin(), iters.rend()); @@ -545,9 +519,7 @@ class IterMapRewriter : public ExprMutator { expr.CopyOnWrite()->base = base + iter_min; return expr; } - Fail(Diagnostic::Error(expr->span) - << "Fail to normalize " << expr << " with predicate bound [" << predicate_induced_min - << ", " << predicate_induced_max << ")"); + unresolved_count_++; return expr; } @@ -563,7 +535,7 @@ class IterMapRewriter : public ExprMutator { if (opt.defined()) { return opt.value(); } else { - Fail(Diagnostic::Error(expr->span) << "Fail to normalize iter sum with offset: " << expr); + unresolved_count_++; return expr; } } @@ -611,8 +583,6 @@ class IterMapRewriter : public ExprMutator { } } if (!base_scale) { - diag_ctx_.Emit(Diagnostic::Error(expr->span) - << "Fuse iters failed, can not find a valid base scale"); return NullOpt; } // check if it can be remapped into a fused pattern. @@ -625,8 +595,6 @@ class IterMapRewriter : public ExprMutator { if (!visited[j] && analyzer_->CanProveEqual(expr->args[j]->scale, expected_scale)) break; } if (j == expr->args.size()) { - diag_ctx_.Emit(Diagnostic::Error(expr->span) - << "Fuse iters failed, can not find expected scale " << expected_scale); return NullOpt; } // look for the longest constrained iter started from expr->args[j] @@ -659,9 +627,6 @@ class IterMapRewriter : public ExprMutator { } } if (k == expr->args.size()) { - diag_ctx_.Emit(Diagnostic::Error(expr->span) - << "Fuse iters failed, can not find flattened iter match constraint " - << constraint_to_match.value()); return NullOpt; } visited[k] = true; @@ -701,8 +666,6 @@ class IterMapRewriter : public ExprMutator { // old iter if (!analyzer_->CanProveEqual(expected_extra_base, it->second.offset * base_scale.value())) { // the extra offset is not consistent with old - diag_ctx_.Emit(Diagnostic::Error(expr->span) - << "Fuse iters failed, the extra offset is not consistent with old"); return NullOpt; } return IterSumExpr({IterSplitExpr(it->second.mark, base_scale.value())}, @@ -929,7 +892,7 @@ bool IterRangeSanityCheck(const Map& iter_ranges) { Array DetectIterMap(const Array& indices, const Map& input_iters, const PrimExpr& predicate, bool require_bijective, - arith::Analyzer* analyzer, DiagnosticContext diag_ctx) { + arith::Analyzer* analyzer) { // Overall detection algorithm is divided into two steps: // - Step0: IterMapRewriter rewrites the expression to use IterMapExpr patterns. // - Step1: IterIndependenceChecker checks if the iterator are independent. @@ -938,8 +901,6 @@ Array DetectIterMap(const Array& indices, const Map constraints; if (!is_one(predicate) && !MatchBoundConstraints(predicate, &constrained_input_iters, &constraints)) { - diag_ctx.Emit(Diagnostic::Error(predicate->span) - << "Fail to collect constraints from iteration predicate: " << predicate); return Array(); } // We have to make sure when we visit an iterator, all the constraints related with its successors @@ -953,7 +914,7 @@ Array DetectIterMap(const Array& indices, const Map DetectIterMap(const Array& indices, const Map(); } if (!rewriter.CheckConstraints()) { - diag_ctx.Emit(Diagnostic::Error(predicate->span) - << "Illegal iteration constraints: " << predicate); return Array(); } // Step0.1: rewrite indices @@ -970,13 +929,11 @@ Array DetectIterMap(const Array& indices, const Mapspan) << "Affine mapping detection failed"); return Array(); } } // Step1: IterIndependenceChecker checks if the iterator are independent. if (!rewriter.CheckMapping(results, require_bijective)) { - diag_ctx.Emit(Diagnostic::Error(predicate->span) << "Iterators are not independent"); return Array(); } @@ -987,8 +944,7 @@ TVM_REGISTER_GLOBAL("arith.DetectIterMap") .set_body_typed([](const Array& indices, const Map& input_iters, const PrimExpr& input_pred, bool is_bijective) { arith::Analyzer ana; - DiagnosticContext diag_ctx(DiagnosticContext::Default(IRModule())); - return DetectIterMap(indices, input_iters, input_pred, is_bijective, &ana, diag_ctx); + return DetectIterMap(indices, input_iters, input_pred, is_bijective, &ana); }); PrimExpr IterMapRewriter::VisitExpr_(const VarNode* op) { @@ -1091,7 +1047,7 @@ PrimExpr IterMapRewriter::VisitExpr_(const MulNode* op) { if (a->IsInstance() && b->IsInstance()) { // cannot multiply two iterators, mark as unresolved. - Fail(Diagnostic::Error(op->span) << "Cannot multiply two iterators: " << GetRef(op)); + unresolved_count_++; return GetRef(op); } @@ -1127,9 +1083,7 @@ PrimExpr IterMapRewriter::SplitFloorDivConst(IterSplitExpr lhs, PrimExpr rhs, lhs.CopyOnWrite()->scale = make_const(rhs->dtype, 1); } else { // mark as unresolved. - Fail(Diagnostic::Error(orig->span) - << "Can not prove floordiv rhs " << rhs << " divisible by lhs scale " << lhs->scale - << ", lhs=" << lhs); + unresolved_count_++; return orig; } } @@ -1151,8 +1105,7 @@ PrimExpr IterMapRewriter::SplitFloorDivConst(IterSplitExpr lhs, PrimExpr rhs, return std::move(lhs); } else { // mark as unresolved. - Fail(Diagnostic::Error(orig->span) - << "Can not prove floordiv lhs extent " << lhs->extent << " divisible by rhs " << rhs); + unresolved_count_++; return orig; } } @@ -1180,7 +1133,7 @@ PrimExpr IterMapRewriter::VisitExpr_(const FloorDivNode* op) { if (b->IsInstance()) { // cannot divide an iterator, mark as unresolved. - Fail(Diagnostic::Error(op->span) << "Cannot divide an iterator: " << GetRef(op)); + unresolved_count_++; return GetRef(op); } @@ -1189,15 +1142,13 @@ PrimExpr IterMapRewriter::VisitExpr_(const FloorDivNode* op) { if (Optional opt = TryFuseIters(ret)) { IterSumExpr sum = opt.value(); if (!is_zero(sum->base)) { - Fail(Diagnostic::Error(op->span) - << "Fuse IterSumExpr " << ret - << " failed, cannot floordiv an IterSumExpr with nonzero base"); + unresolved_count_++; return GetRef(op); } ICHECK_EQ(sum->args.size(), 1U); return SplitFloorDivConst(sum->args[0], b, GetRef(op)); } else { - Fail(Diagnostic::Error(op->span) << "Fuse IterSumExpr " << ret << " failed"); + unresolved_count_++; return GetRef(op); } } else { @@ -1221,8 +1172,7 @@ PrimExpr IterMapRewriter::SplitFloorModConst(IterSplitExpr lhs, PrimExpr rhs, rhs = floordiv(rhs, lhs->scale); } else { // mark as unresolved. - Fail(Diagnostic::Error(orig->span) << "Can not prove floormod rhs " << rhs - << " divisible by " << lhs->scale << ", lhs=" << lhs); + unresolved_count_++; return orig; } } @@ -1236,8 +1186,7 @@ PrimExpr IterMapRewriter::SplitFloorModConst(IterSplitExpr lhs, PrimExpr rhs, return std::move(lhs); } else { // mark as unresolved. - Fail(Diagnostic::Error(orig->span) - << "Can not prove floormod lhs extent " << lhs->extent << " divisible by rhs " << rhs); + unresolved_count_++; return orig; } } @@ -1265,7 +1214,7 @@ PrimExpr IterMapRewriter::VisitExpr_(const FloorModNode* op) { if (b->IsInstance()) { // cannot mod an iterator, mark as unresolved. - Fail(Diagnostic::Error(op->span) << "Cannot mod an iterator: " << GetRef(op)); + unresolved_count_++; return GetRef(op); } @@ -1274,14 +1223,12 @@ PrimExpr IterMapRewriter::VisitExpr_(const FloorModNode* op) { if (Optional opt = TryFuseIters(ret)) { IterSumExpr sum = opt.value(); if (!is_zero(sum->base)) { - Fail(Diagnostic::Error(op->span) - << "Fuse IterSumExpr " << ret - << " failed, cannot floormod an IterSumExpr with nonzero base"); + unresolved_count_++; return GetRef(op); } return SplitFloorModConst(sum->args[0], b, GetRef(op)); } else { - Fail(Diagnostic::Error(op->span) << "Fail to fuse iters of " << ret); + unresolved_count_++; return GetRef(op); } } else { @@ -1356,9 +1303,8 @@ Array IterMapSimplify(const Array& indices, const Map rewrite = - DetectIterMap(indices, input_iters, input_pred, require_bijective, &analyzer, diag_ctx); + DetectIterMap(indices, input_iters, input_pred, require_bijective, &analyzer); if (rewrite.empty()) { return indices; } @@ -1385,9 +1331,8 @@ Array IterMapSimplify(const Array& indices, const Map& sub_iters, - DiagnosticContext diag_ctx) - : analyzer_(analyzer), collector_(collector), sub_iters_(sub_iters), diag_ctx_(diag_ctx) {} + const std::unordered_set& sub_iters) + : analyzer_(analyzer), collector_(collector), sub_iters_(sub_iters) {} size_t unresolved_count() const { return unresolved_count_; } @@ -1449,8 +1394,8 @@ class SubspaceDivider { } else if (expr->args.size() == 1) { // arg + base, if arg=Y*E(X)+X, then arg+base = Y*E(X)+(X+base) if (!is_one(expr->args[0]->scale)) { - return Fail(Diagnostic::Error(expr->span) - << "Expect split scale be 1, got " << expr->args[0]->scale); + unresolved_count_++; + return DivisionResult(IterSumExpr({}, 0), 0, IterSumExpr({}, 0), 0); } DivisionResult res = DivideIterSplitExpr(expr->args[0]); if (!is_zero(expr->base)) res = AddBase(res, expr->base); @@ -1469,9 +1414,10 @@ class SubspaceDivider { DivisionResult arg_division = DivideIterSplitExpr(arg); IterSplitExpr new_arg; if (arg_division.IsInner()) { - if (!inner) - return Fail(Diagnostic::Error(expr->span) - << "Current division is inner but outer division exists for previous args"); + if (!inner) { + unresolved_count_++; + return DivisionResult(IterSumExpr({}, 0), 0, IterSumExpr({}, 0), 0); + } new_arg = arg_division.GetInnerAsSplit(); inner_args.push_back(new_arg); inner = true; @@ -1480,13 +1426,15 @@ class SubspaceDivider { outer_args.push_back(new_arg); inner = false; } else { - return Fail(Diagnostic::Error(expr->span) - << "Division of " << arg << " is neither inner nor outer"); + unresolved_count_++; + return DivisionResult(IterSumExpr({}, 0), 0, IterSumExpr({}, 0), 0); } extent *= new_arg->extent; } - if (!scale_is_one) - return Fail(Diagnostic::Error(expr->span) << "Expect all iter sum arg's scale be 1"); + if (!scale_is_one) { + unresolved_count_++; + return DivisionResult(IterSumExpr({}, 0), 0, IterSumExpr({}, 0), 0); + } bool need_predicate = !analyzer_->CanProveEqual(extent, mark_extent); const IterMark& outer_mark = MarkFromArgsAndBase(outer_args, 0); const IterMark& inner_mark = MarkFromArgsAndBase(inner_args, expr->base); @@ -1505,8 +1453,8 @@ class SubspaceDivider { inner_preds_ = inner_preds_ && (converter.Convert(inner_source) < mark_extent); return DivisionResult::Inner(inner_source, mark_extent); } else { - return Fail(Diagnostic::Error(expr->span) - << "Either inner or outer args should exists if need predicate: " << expr); + unresolved_count_++; + return DivisionResult(IterSumExpr({}, 0), 0, IterSumExpr({}, 0), 0); } } return DivisionResult(outer_source, outer_mark->extent, inner_source, inner_mark->extent); @@ -1516,14 +1464,6 @@ class SubspaceDivider { PrimExpr GetInnerPreds() const { return inner_preds_; } private: - DivisionResult Fail(const Diagnostic& diagnostic) { - unresolved_count_++; - if (diag_ctx_.defined()) { - diag_ctx_.Emit(diagnostic); - } - return DivisionResult(IterSumExpr({}, 0), 0, IterSumExpr({}, 0), 0); - } - DivisionResult AddBase(DivisionResult division, PrimExpr base) { DivisionResult res = division; if (const auto* op = division.inner.as()) { @@ -1599,10 +1539,10 @@ class SubspaceDivider { if (!used[j] && analyzer_->CanProveEqual(splits[j]->lower_factor, expected_lower_factor)) break; } - if (j == splits.size()) - return Fail(Diagnostic::Error(expr->span) - << "Can not find expected lower factor " << expected_lower_factor - << " in splits of " << expr->source); + if (j == splits.size()) { + unresolved_count_++; + return DivisionResult(IterSumExpr({}, 0), 0, IterSumExpr({}, 0), 0); + } used[j] = true; if (!encountered_boundary) { inner_iters.push_back(splits[j]); @@ -1613,9 +1553,10 @@ class SubspaceDivider { if (analyzer_->CanProveEqual(expected_lower_factor, mark_division.inner_extent)) encountered_boundary = true; } - if (!encountered_boundary) - return Fail(Diagnostic::Error(expr->span) - << "Can not find inner/outer boundary of " << expr); + if (!encountered_boundary) { + unresolved_count_++; + return DivisionResult(IterSumExpr({}, 0), 0, IterSumExpr({}, 0), 0); + } for (const IterSplitExpr& inner_iter : inner_iters) { IterSplitExpr new_iter = inner_iter; new_iter.CopyOnWrite()->source = inner_mark; @@ -1629,8 +1570,8 @@ class SubspaceDivider { split_map_.emplace(outer_iter, DivisionResult::Outer(new_iter, outer_iter->extent)); } } else { - return Fail(Diagnostic::Error(expr->span) - << "Source expr to divide is neither var nor IterSumExpr"); + unresolved_count_++; + return DivisionResult(IterSumExpr({}, 0), 0, IterSumExpr({}, 0), 0); } return split_map_.at(expr); } @@ -1646,18 +1587,15 @@ class SubspaceDivider { std::unordered_map split_map_; // predicate of outer space and inner space; PrimExpr outer_preds_{Bool(true)}, inner_preds_{Bool(true)}; - // diagnostic context - DiagnosticContext diag_ctx_; }; Array> SubspaceDivide(const Array& bindings, const Map& input_iters, const Array& sub_iters, const PrimExpr& predicate, - bool require_bijective, arith::Analyzer* analyzer, - DiagnosticContext diag_ctx) { + bool require_bijective, arith::Analyzer* analyzer) { if (!IterRangeSanityCheck(input_iters)) return Array>(); const Array& maps = - DetectIterMap(bindings, input_iters, predicate, require_bijective, analyzer, diag_ctx); + DetectIterMap(bindings, input_iters, predicate, require_bijective, analyzer); if (maps.empty()) return {}; std::unordered_set inner_iter_set; @@ -1667,7 +1605,7 @@ Array> SubspaceDivide(const Array& bindings, IterMarkSplitCollector collector; collector.Collect(maps); - SubspaceDivider subspace_divider(analyzer, collector, inner_iter_set, diag_ctx); + SubspaceDivider subspace_divider(analyzer, collector, inner_iter_set); std::vector> results; for (const IterSumExpr& expr : maps) { @@ -1687,9 +1625,7 @@ TVM_REGISTER_GLOBAL("arith.SubspaceDivide") const Array& sub_iters, const PrimExpr& predicate, bool require_bijective) { arith::Analyzer ana; - DiagnosticContext diag_ctx(DiagnosticContext::Default(IRModule())); - return SubspaceDivide(bindings, root_iters, sub_iters, predicate, require_bijective, &ana, - diag_ctx); + return SubspaceDivide(bindings, root_iters, sub_iters, predicate, require_bijective, &ana); }); class InverseAffineIterMapTransformer { diff --git a/src/tir/ir/index_map.cc b/src/tir/ir/index_map.cc index 1e54f54a066f1..3f8f84f649d4d 100644 --- a/src/tir/ir/index_map.cc +++ b/src/tir/ir/index_map.cc @@ -76,9 +76,7 @@ IndexMap IndexMap::Inverse(Array initial_ranges) const { // Unpack the output indices into linear combinations of the initial // indices. arith::Analyzer analyzer; - auto diagnostics = DiagnosticContext::Default(IRModule()); - auto iter_map = - DetectIterMap((*this)->final_indices, input_iters, 1, true, &analyzer, diagnostics); + auto iter_map = DetectIterMap((*this)->final_indices, input_iters, 1, true, &analyzer); CHECK(iter_map.size()) << "Index transformation was not bijective."; // Determine expressions for the input variables, in terms of the diff --git a/src/tir/schedule/analysis/analysis.cc b/src/tir/schedule/analysis/analysis.cc index 868cabeed08c0..f3aa250ec86be 100644 --- a/src/tir/schedule/analysis/analysis.cc +++ b/src/tir/schedule/analysis/analysis.cc @@ -526,14 +526,12 @@ bool IsAffineBinding(const BlockRealize& realize, const Map& loop_va if (loop_var_ranges.empty()) { return true; } - DiagnosticContext diag_ctx(DiagnosticContext::Default(IRModule())); Array results = arith::DetectIterMap( /*indices=*/realize->iter_values, /*input_iters=*/loop_var_ranges, /*predicate=*/realize->predicate, /*require_bijective=*/false, - /*analyzer=*/analyzer, - /*diag_ctx*/ diag_ctx); + /*analyzer=*/analyzer); if (results.empty()) { return false; } diff --git a/src/tir/schedule/analysis/layout.cc b/src/tir/schedule/analysis/layout.cc index 144b3a55a4675..993557f8be2f8 100644 --- a/src/tir/schedule/analysis/layout.cc +++ b/src/tir/schedule/analysis/layout.cc @@ -77,9 +77,8 @@ class SplitExprCollector { const PrimExpr& predicate, // bool require_bijective, // arith::Analyzer* analyzer) { - DiagnosticContext diag_ctx(DiagnosticContext::Default(IRModule())); Array iter_sum_exprs = arith::DetectIterMap( - {analyzer->Simplify(index)}, input_iters, predicate, require_bijective, analyzer, diag_ctx); + {analyzer->Simplify(index)}, input_iters, predicate, require_bijective, analyzer); if (iter_sum_exprs.empty()) { return {}; } diff --git a/src/tir/schedule/primitive/blockize_tensorize.cc b/src/tir/schedule/primitive/blockize_tensorize.cc index 6daea391b918e..331d098347b02 100644 --- a/src/tir/schedule/primitive/blockize_tensorize.cc +++ b/src/tir/schedule/primitive/blockize_tensorize.cc @@ -257,12 +257,11 @@ Array> CheckSubspaceDivisible(const IRModule& mod, const LoopSubspaceCollector& collector, arith::Analyzer* analyzer) { const Block& block = block_realize->block; - DiagnosticContext diag_ctx(DiagnosticContext::Default(mod)); Array> division = arith::SubspaceDivide(block_realize->iter_values, collector.loop_var_domain, collector.inner_loop_vars, block_realize->predicate, - /*require_bijective=*/false, analyzer, diag_ctx); + /*require_bijective=*/false, analyzer); if (division.empty()) { // If we can't do perfect subspace division, check if it is a trivial case of subspace division.