diff --git a/src/arith/iter_affine_map.cc b/src/arith/iter_affine_map.cc index 6fff2a23ccfe..c443c878e71b 100644 --- a/src/arith/iter_affine_map.cc +++ b/src/arith/iter_affine_map.cc @@ -201,8 +201,9 @@ class IterMapRewriter : public ExprMutator { return NormalizeToIterWithOffset(ToIterSumExpr(DirectMutate(expr))); } - IterSumExpr RewriteIterConstraint(const PrimExpr& expr, const PrimExpr& predicate_induced_min, - const PrimExpr& predicate_induced_max) { + IterSumExpr RewriteIterConstraint(const PrimExpr& expr, + const Optional& predicate_induced_min, + const Optional& predicate_induced_max) { return NormalizeToIterOnBoundExpr(ToIterSumExpr(DirectMutate(expr)), predicate_induced_min, predicate_induced_max); } @@ -494,16 +495,17 @@ class IterMapRewriter : public ExprMutator { * \param predicate_induced_max Open upper bound from iter constraint, maybe undefined. * \return The Normalized expression. */ - IterSumExpr NormalizeToIterOnBoundExpr(IterSumExpr expr, PrimExpr predicate_induced_min, - PrimExpr predicate_induced_max) { + IterSumExpr NormalizeToIterOnBoundExpr(IterSumExpr expr, Optional predicate_induced_min, + Optional predicate_induced_max) { // normalize to zero base PrimExpr base = expr->base; if (!is_zero(base)) { expr.CopyOnWrite()->base = 0; - if (predicate_induced_min.defined()) predicate_induced_min = predicate_induced_min - base; - if (predicate_induced_max.defined()) predicate_induced_max = predicate_induced_max - base; + if (predicate_induced_min.defined()) + predicate_induced_min = predicate_induced_min.value() - base; + if (predicate_induced_max.defined()) + predicate_induced_max = predicate_induced_max.value() - base; } - if (expr->args.size() < 1) return expr; Optional opt = TryFuseIters(expr); ICHECK(!opt.defined() || opt.value()->args.size() == 1); // scale should be 1 @@ -522,10 +524,10 @@ class IterMapRewriter : public ExprMutator { PrimExpr iter_min = mark_offset; PrimExpr iter_max = iter_min + mark->extent; if (predicate_induced_min.defined()) { - iter_min = max(predicate_induced_min, iter_min); + iter_min = max(predicate_induced_min.value(), iter_min); } if (predicate_induced_max.defined()) { - iter_max = min(predicate_induced_max, iter_max); + iter_max = min(predicate_induced_max.value(), iter_max); } if (!is_zero(iter_min)) { // structured form's offset should be updated @@ -536,7 +538,6 @@ class IterMapRewriter : public ExprMutator { } mark.CopyOnWrite()->extent = iter_max - iter_min; sum_fuse_map_[flattened_form] = {mark, iter_min}; - // we need to note down the flattened form of constrained iterators // to check the validity of constraints, see also CheckConstraints() constrained_iters_flattened_.push_back(flattened_form); @@ -771,14 +772,15 @@ class IterMapRewriter : public ExprMutator { struct IterConstraint { // The expr of the iter PrimExpr iter; - // The expr of the lower_bound - PrimExpr lower_bound; - // The expr of the upper_bound - PrimExpr upper_bound; + // The expr of the lower_bound, maybe undefined + Optional lower_bound; + // The expr of the upper_bound, maybe undefined + Optional upper_bound; // The size of the iter, which is the number of nodes size_t expr_size = 0; - IterConstraint(PrimExpr iter, PrimExpr lower_bound, PrimExpr upper_bound, size_t size) + IterConstraint(PrimExpr iter, Optional lower_bound, Optional upper_bound, + size_t size) : iter(std::move(iter)), lower_bound(std::move(lower_bound)), upper_bound(std::move(upper_bound)), @@ -788,11 +790,12 @@ struct IterConstraint { /*! * \brief Split the predicate into `(a < b) && (c < d) && ...` * \param pred The predicate to be split. + * \param input_iters The input iterators. + * \param result The result of predicate split. * \return A list of IterConstraint, empty if the split failed. */ -std::vector MatchBoundConstraints(PrimExpr pred, - const Map& input_iters) { - std::vector result; +bool MatchBoundConstraints(PrimExpr pred, Map* input_iters, + std::vector* result) { arith::PVar lhs, rhs, rest; for (;;) { // try extract comparisions @@ -821,78 +824,94 @@ std::vector MatchBoundConstraints(PrimExpr pred, is_equal = true; is_finish = true; } else { - return std::vector(); + return false; } PrimExpr lhs_expr = lhs.Eval(); PrimExpr rhs_expr = rhs.Eval(); // we only accept predicate of integers if (!((lhs_expr->dtype.is_int() || lhs_expr->dtype.is_uint()) && (rhs_expr->dtype.is_int() || rhs_expr->dtype.is_uint()))) { - return std::vector(); + return false; } // determine iter and bound, if we can not distinguish them simply, // try divide (lhs - rhs) into itervar aware and itervar free parts auto f_use_itervar = [&input_iters](const VarNode* v) { - return input_iters.count(GetRef(v)); + return input_iters->count(GetRef(v)); }; bool bound_at_left; - if (is_const_int(lhs_expr) || !UsesVar(lhs_expr, f_use_itervar)) { - bound_at_left = true; - } else if (is_const_int(rhs_expr) || !UsesVar(rhs_expr, f_use_itervar)) { - bound_at_left = false; - } else { - bound_at_left = false; // accumulate bound to rhs - PrimExpr sum_parts = lhs_expr - rhs_expr; - lhs_expr = 0; - rhs_expr = 0; - std::function f_extract = - [&lhs_expr, &rhs_expr, f_use_itervar, &f_extract](const PrimExpr& part, bool sign) { - if (const AddNode* add = part.as()) { - f_extract(add->a, sign); - f_extract(add->b, sign); - } else if (const SubNode* sub = part.as()) { - f_extract(sub->a, sign); - f_extract(sub->b, !sign); - } else if (UsesVar(part, f_use_itervar)) { - lhs_expr = sign ? lhs_expr + part : lhs_expr - part; - } else { - rhs_expr = sign ? rhs_expr - part : rhs_expr + part; - } - }; - f_extract(sum_parts, true); - arith::Analyzer analyzer; - lhs_expr = analyzer.Simplify(lhs_expr); - rhs_expr = analyzer.Simplify(rhs_expr); - } - PrimExpr lower_bound, upper_bound, iter; - if (is_greater) { - if (bound_at_left) { - // bound > iter - upper_bound = is_equal ? lhs_expr + 1 : lhs_expr; - iter = rhs_expr; + if (UsesVar(lhs_expr, f_use_itervar) || UsesVar(rhs_expr, f_use_itervar)) { + // At least it uses one input iter + if (is_const_int(lhs_expr) || !UsesVar(lhs_expr, f_use_itervar)) { + bound_at_left = true; + } else if (is_const_int(rhs_expr) || !UsesVar(rhs_expr, f_use_itervar)) { + bound_at_left = false; } else { - // iter > bound - lower_bound = is_equal ? rhs_expr : rhs_expr + 1; - iter = lhs_expr; + bound_at_left = false; // accumulate bound to rhs + PrimExpr sum_parts = lhs_expr - rhs_expr; + lhs_expr = 0; + rhs_expr = 0; + std::function f_extract = + [&lhs_expr, &rhs_expr, f_use_itervar, &f_extract](const PrimExpr& part, bool sign) { + if (const AddNode* add = part.as()) { + f_extract(add->a, sign); + f_extract(add->b, sign); + } else if (const SubNode* sub = part.as()) { + f_extract(sub->a, sign); + f_extract(sub->b, !sign); + } else if (UsesVar(part, f_use_itervar)) { + lhs_expr = sign ? lhs_expr + part : lhs_expr - part; + } else { + rhs_expr = sign ? rhs_expr - part : rhs_expr + part; + } + }; + f_extract(sum_parts, true); + arith::Analyzer analyzer; + lhs_expr = analyzer.Simplify(lhs_expr); + rhs_expr = analyzer.Simplify(rhs_expr); } - } else { - if (bound_at_left) { - // bound < iter - lower_bound = is_equal ? lhs_expr : lhs_expr + 1; - iter = rhs_expr; + Optional lower_bound = NullOpt, upper_bound = NullOpt; + PrimExpr iter; + if (is_greater) { + if (bound_at_left) { + // bound > iter / bound >= iter + upper_bound = is_equal ? lhs_expr + 1 : lhs_expr; + iter = rhs_expr; + } else { + // iter > bound / iter >= bound + lower_bound = is_equal ? rhs_expr : rhs_expr + 1; + iter = lhs_expr; + } } else { - // iter < bound - upper_bound = is_equal ? rhs_expr + 1 : rhs_expr; - iter = lhs_expr; + if (bound_at_left) { + // bound < iter / bound <= iter + lower_bound = is_equal ? lhs_expr : lhs_expr + 1; + iter = rhs_expr; + } else { + // iter < bound / iter <= bound + upper_bound = is_equal ? rhs_expr + 1 : rhs_expr; + iter = lhs_expr; + } + } + // If it is a predicate for a single input iter + if (const auto* var_ptr = iter.as()) { + auto it = input_iters->find(GetRef(var_ptr)); + if (it != input_iters->end()) { + PrimExpr iter_min = (*it).second->min; + PrimExpr iter_max = (*it).second->min + (*it).second->extent; + if (lower_bound.defined()) iter_min = max(iter_min, lower_bound.value()); + if (upper_bound.defined()) iter_max = min(iter_max, upper_bound.value()); + input_iters->Set(GetRef(var_ptr), Range(iter_min, iter_max)); + } + } else { + result->emplace_back(iter, lower_bound, upper_bound, 0); } } - result.emplace_back(iter, lower_bound, upper_bound, 0); if (is_finish) { break; } pred = rest.Eval(); } - return result; + return true; } bool IterRangeSanityCheck(const Map& iter_ranges) { @@ -912,13 +931,14 @@ Array DetectIterMap(const Array& indices, const Map(); - std::vector constraints = MatchBoundConstraints(predicate, input_iters); - if (!is_one(predicate) && constraints.empty()) { + Map constrained_input_iters = input_iters; + std::vector constraints; + if (!is_one(predicate) && + !MatchBoundConstraints(predicate, &constrained_input_iters, &constraints)) { diag_ctx.Emit(Diagnostic::Error(predicate->span) << "Fail to collect constraints from iteration predicate: " << predicate); return Array(); } - // We have to make sure when we visit an iterator, all the constraints related with its successors // in the iter var graph has been visited, where the expression of this iterator will contain the // expression of its successor, so we sort them by their sizes. @@ -930,10 +950,11 @@ Array DetectIterMap(const Array& indices, const Map(); } if (!rewriter.CheckConstraints()) { @@ -945,7 +966,10 @@ Array DetectIterMap(const Array& indices, const Map results; for (PrimExpr value : indices) { results.push_back(rewriter.Rewrite(value)); - if (rewriter.unresolved_count() != 0) return Array(); + if (rewriter.unresolved_count() != 0) { + diag_ctx.Emit(Diagnostic::Error(predicate->span) << "Affine mapping detection failed"); + return Array(); + } } // Step1: IterIndependenceChecker checks if the iterator are independent. if (!rewriter.CheckMapping(results, require_bijective)) { @@ -1306,7 +1330,8 @@ class IterMapToExprNormalizer : public ExprMutator { } else if (analyzer_->CanProve(expr->source->extent == expr->lower_factor * expr->extent)) { return floordiv(source, expr->lower_factor) * expr->scale; } else { - return floormod(floordiv(source, expr->lower_factor), expr->extent) * expr->scale; + return floordiv(floormod(source, expr->lower_factor * expr->extent), expr->lower_factor) * + expr->scale; } } diff --git a/src/arith/modular_set.cc b/src/arith/modular_set.cc index ac176b2623a3..99f90b9be90e 100644 --- a/src/arith/modular_set.cc +++ b/src/arith/modular_set.cc @@ -196,6 +196,18 @@ class ModularSetAnalyzer::Impl : public ExprFunctorb); + if (b.is_const()) { + int64_t c2 = b.base; + ICHECK(c2 != 0) << "MathError: the divisor is 0"; + Entry a = VisitExpr(op->a); + int64_t coeff = ZeroAwareGCD(a.coeff, c2); + return Entry(coeff, a.base % c2); + } + return Everything(); + } + Entry VisitExpr_(const MinNode* op) final { Entry a = VisitExpr(op->a); Entry b = VisitExpr(op->b); diff --git a/src/arith/rewrite_simplify.cc b/src/arith/rewrite_simplify.cc index 9f4cdc49fc96..e11bd024bb22 100644 --- a/src/arith/rewrite_simplify.cc +++ b/src/arith/rewrite_simplify.cc @@ -192,6 +192,8 @@ PrimExpr RewriteSimplifier::Impl::VisitExpr_(const AddNode* op) { TVM_TRY_REWRITE(truncdiv(x, c1) * c1 + truncmod(x, c1), x); // floor div TVM_TRY_REWRITE(floordiv(x, c1) * c1 + floormod(x, c1), x); + TVM_TRY_REWRITE_IF(floordiv(floormod(x, c2) + c1, c2) + floordiv(x, c2), floordiv(x + c1, c2), + c2.Eval()->value > 0); // canonicalization rule // will try rewrite again after canonicalization. @@ -785,6 +787,11 @@ PrimExpr RewriteSimplifier::Impl::VisitExpr_(const FloorDivNode* op) { TVM_TRY_REWRITE_IF(floordiv(x * c1 + y, c2), x * floordiv(c1, c2) + floordiv(y, c2), c2.Eval()->value > 0 && c1.Eval()->value % c2.Eval()->value == 0); + TVM_TRY_REWRITE_IF(floordiv(x * c1 + y, c2), floordiv(x, floordiv(c2, c1)), + c1.Eval()->value > 0 && c2.Eval()->value > 0 && + c2.Eval()->value % c1.Eval()->value == 0 && + CanProveEqual(floordiv(y.Eval(), c1.Eval()), 0)); + TVM_TRY_REWRITE_IF(floordiv(min(x * c1, y), c2), min(x * floordiv(c1, c2), floordiv(y, c2)), c2.Eval()->value > 0 && c1.Eval()->value % c2.Eval()->value == 0); @@ -794,6 +801,11 @@ PrimExpr RewriteSimplifier::Impl::VisitExpr_(const FloorDivNode* op) { TVM_TRY_REWRITE_IF(floordiv(y + x * c1, c2), floordiv(y, c2) + x * floordiv(c1, c2), c2.Eval()->value > 0 && c1.Eval()->value % c2.Eval()->value == 0); + TVM_TRY_REWRITE_IF(floordiv(y + x * c1, c2), floordiv(x, floordiv(c2, c1)), + c1.Eval()->value > 0 && c2.Eval()->value > 0 && + c2.Eval()->value % c1.Eval()->value == 0 && + CanProveEqual(floordiv(y.Eval(), c1.Eval()), 0)); + TVM_TRY_REWRITE_IF(floordiv(min(y, x * c1), c2), min(floordiv(y, c2), x * floordiv(c1, c2)), c2.Eval()->value > 0 && c1.Eval()->value % c2.Eval()->value == 0); diff --git a/src/tir/schedule/primitive/loop_transformation.cc b/src/tir/schedule/primitive/loop_transformation.cc index 7b9ac488b8b9..fa2a4469b8c9 100644 --- a/src/tir/schedule/primitive/loop_transformation.cc +++ b/src/tir/schedule/primitive/loop_transformation.cc @@ -413,7 +413,7 @@ Array Split(ScheduleState self, const StmtSRef& loop_sref, for (int i = 0; i < n; i++) { const PrimExpr& factor = factors[i]; Var var = loop->loop_var.copy_with_suffix("_" + std::to_string(i)); - substitute_value = substitute_value * factor + var; + if (!is_one(factor)) substitute_value = substitute_value * factor + var; analyzer.Bind(var, Range::FromMinExtent(0, factor)); new_loop_vars.emplace_back(std::move(var)); } @@ -505,11 +505,14 @@ StmtSRef Fuse(ScheduleState self, const Array& loop_srefs) { Var fused_var = loops[0]->loop_var.copy_with_suffix(suffix); Array substitute_value; substitute_value.resize(loops.size()); - PrimExpr tot = fused_var; - for (int i = static_cast(loops.size()) - 1; i >= 0; i--) { - substitute_value.Set(i, floormod(tot, loops[i]->extent)); - tot = floordiv(tot, loops[i]->extent); - } + PrimExpr lower = 1; + for (int i = static_cast(loops.size()) - 1; i > 0; i--) { + substitute_value.Set(i, is_one(loops[i]->extent) + ? 0 + : floordiv(floormod(fused_var, lower * loops[i]->extent), lower)); + lower = lower * loops[i]->extent; + } + substitute_value.Set(0, is_one(loops[0]->extent) ? 0 : floordiv(fused_var, lower)); Stmt new_stmt = loops.back()->body; Map opaque_block_reuse; auto f_substitute = [&](const Var& v) -> Optional { @@ -534,6 +537,7 @@ StmtSRef Fuse(ScheduleState self, const Array& loop_srefs) { self->Replace(loop_srefs[0], new_stmt, opaque_block_reuse); return self->stmt2ref.at(new_stmt.get()); } + /*! * \brief Collect an array of loop srefs into a set * \param self The schedule state diff --git a/tests/python/unittest/test_arith_intset.py b/tests/python/unittest/test_arith_intset.py index e6738543a6aa..e741ee88a63e 100644 --- a/tests/python/unittest/test_arith_intset.py +++ b/tests/python/unittest/test_arith_intset.py @@ -105,7 +105,10 @@ def test_mod(): ck.verify( flm(y, 8), {y: tvm.arith.IntervalSet(z * 8 + x * 4, z * 8 + x * 4 + 3)}, - (x * 4 - 8 * fld(x * 4, 8), x * 4 - 8 * fld(x * 4, 8) + 3), + ( + z * 8 + x * 4 - 8 * fld(z * 8 + x * 4, 8), + z * 8 + x * 4 + 3 - 8 * fld(z * 8 + x * 4, 8), + ), ) ck1 = IntSetChecker() ck1.analyzer.bind(x, tvm.ir.Range.from_min_extent(0, 2)) diff --git a/tests/python/unittest/test_arith_iter_affine_map.py b/tests/python/unittest/test_arith_iter_affine_map.py index 2de30eff3f5c..cb8bbd1063c9 100644 --- a/tests/python/unittest/test_arith_iter_affine_map.py +++ b/tests/python/unittest/test_arith_iter_affine_map.py @@ -633,7 +633,7 @@ def test_subspace_division(): tvm.ir.assert_structural_equal(res[0][0], j0[0]) tvm.ir.assert_structural_equal(res[0][1], floordiv(l0[0] * 6 + l1[0], 6)) tvm.ir.assert_structural_equal(res[1][0], 0) - tvm.ir.assert_structural_equal(res[1][1], floormod(floordiv(l0[0] * 6 + l1[0], 3), 2)) + tvm.ir.assert_structural_equal(res[1][1], floordiv(floormod(l0[0] * 6 + l1[0], 6), 3)) tvm.ir.assert_structural_equal(res[2][0], 0) tvm.ir.assert_structural_equal(res[2][1], (floormod(l0[0] * 6 + l1[0], 3) * 3) + j3[0]) diff --git a/tests/python/unittest/test_arith_modular_set.py b/tests/python/unittest/test_arith_modular_set.py index 4a4cd6a31ef1..0acd2f4f5f77 100644 --- a/tests/python/unittest/test_arith_modular_set.py +++ b/tests/python/unittest/test_arith_modular_set.py @@ -50,6 +50,14 @@ def test_mul(): assert m.base == 2 +def test_floormod(): + analyzer = tvm.arith.Analyzer() + x, y = te.var("x"), te.var("y") + m = analyzer.modular_set(tvm.tir.floormod(x * 128 + y * 4, 256)) + assert m.coeff == 4 + assert m.base == 0 + + def test_div_shift(): analyzer = tvm.arith.Analyzer() x, y = te.var("x"), te.var("y") @@ -175,6 +183,7 @@ def test_let(): test_add_sub() test_mul() test_div_shift() + test_floormod() test_min_max_select() test_mix_index() test_constraint_scope() diff --git a/tests/python/unittest/test_arith_rewrite_simplify.py b/tests/python/unittest/test_arith_rewrite_simplify.py index 662038e129f7..c820d003668c 100644 --- a/tests/python/unittest/test_arith_rewrite_simplify.py +++ b/tests/python/unittest/test_arith_rewrite_simplify.py @@ -80,6 +80,10 @@ def test_vector_simplify(): ck.verify(fld(tvm.tir.Ramp(x * 8 + 1, 1, 4), 8), (x).astype("int32x4")) ck.verify(fld(tvm.tir.Ramp(x * 8 + 15, 1, 4), 8), fld(tvm.tir.Ramp(x * 8 + 15, 1, 4), 8)) ck.verify(fld(tvm.tir.Ramp(x, 8, 5), tvm.tir.Broadcast(4, 5)), tvm.tir.Ramp(fld(x, 4), 2, 5)) + ck.verify( + fld(tvm.tir.Ramp(flm(x * 4, 256), 1, 4), tvm.tir.Broadcast(8, 4)), + tvm.tir.Broadcast(fld(flm(x * 4, 256), 8), 4) + ) ck.verify( fld(tvm.tir.Ramp(x, 7, 4), tvm.tir.Broadcast(4, 4)), fld(tvm.tir.Ramp(x, 7, 4), tvm.tir.Broadcast(4, 4)), @@ -277,6 +281,7 @@ def test_add_index_simplify(): flm = tvm.te.floormod ck.verify(y * flm(x, 8) + 10 * flm(x, 8), flm(x, 8) * (y + 10)) ck.verify(fld(x, 8) * 8 + flm(x, 8), x) + ck.verify(fld(flm(x, 2) + 7, 2) + fld(x, 2), fld(x + 7, 2)) def test_sub_index_simplify(): diff --git a/tests/python/unittest/test_meta_schedule_postproc_rewrite_unbound_block.py b/tests/python/unittest/test_meta_schedule_postproc_rewrite_unbound_block.py index 4ab2741da181..9b39ad1bff3e 100644 --- a/tests/python/unittest/test_meta_schedule_postproc_rewrite_unbound_block.py +++ b/tests/python/unittest/test_meta_schedule_postproc_rewrite_unbound_block.py @@ -99,7 +99,7 @@ def main(A: T.Buffer[(1, 256, 256), "float32"], D: T.Buffer[(1,), "float32"]) -> with T.block("C"): b = T.axis.S(1, 0) i, j = T.axis.remap("RR", [i1, i2]) - T.where(i0_fused_0 * 32 + i0_fused_1 < 1) + T.where(i0_fused_1 < 1) with T.init(): C[b] = T.float32(0) C[b] = C[b] + A[b, i, j] * A[b, i, j] @@ -107,7 +107,7 @@ def main(A: T.Buffer[(1, 256, 256), "float32"], D: T.Buffer[(1,), "float32"]) -> for i0_fused_1 in T.thread_binding(32, thread="threadIdx.x"): with T.block("D"): b = T.axis.S(1, 0) - T.where(i0_fused_0 * 32 + i0_fused_1 < 1) + T.where(i0_fused_1 < 1) D[b] = T.sqrt(C[b], dtype="float32") diff --git a/tests/python/unittest/test_tir_schedule_split_fuse.py b/tests/python/unittest/test_tir_schedule_split_fuse.py index 84ececebbcba..fd2115bddbed 100644 --- a/tests/python/unittest/test_tir_schedule_split_fuse.py +++ b/tests/python/unittest/test_tir_schedule_split_fuse.py @@ -66,7 +66,7 @@ def elementwise_symbolic_fused(a: T.handle, b: T.handle, n: T.int32) -> None: for i_j_k_fused in T.serial(0, (n * 16384)): with T.block("B"): vi = T.axis.S(128, T.floordiv(i_j_k_fused, n * 128)) - vj = T.axis.S(128, T.floormod(T.floordiv(i_j_k_fused, n), 128)) + vj = T.axis.S(128, T.floordiv(T.floormod(i_j_k_fused, n*128), n)) vk = T.axis.S(n, T.floormod(i_j_k_fused, n)) T.reads([A[vi, vj, vk]]) T.writes([B[vi, vj, vk]]) @@ -164,7 +164,7 @@ def elementwise_fused(a: T.handle, b: T.handle) -> None: for fused in T.serial(0, 2097152): with T.block("B"): vi = T.axis.S(128, T.floordiv(fused, 16384)) - vj = T.axis.S(128, T.floormod(T.floordiv(fused, 128), 128)) + vj = T.axis.S(128, T.floordiv(T.floormod(fused, 16384), 128)) vk = T.axis.S(128, T.floormod(fused, 128)) T.reads([A[vi, vj, vk]]) T.writes([B[vi, vj, vk]]) @@ -205,7 +205,7 @@ def elementwise_split_with_predicate(a: T.handle, b: T.handle) -> None: A = T.match_buffer(a, [128, 128, 128]) for i0, i1, i2, j0, j1, k0, k1 in T.grid(1000, 2, 3, 1, 129, 3, 43): with T.block("B"): - T.where((i0 * 2 + i1) * 3 + i2 < 128 and j0 * 129 + j1 < 128 and k0 * 43 + k1 < 128) + T.where((i0 * 2 + i1) * 3 + i2 < 128 and j1 < 128 and k0 * 43 + k1 < 128) vi = T.axis.S(128, i0 * 6 + i1 * 3 + i2) vj = T.axis.S(128, j1) vk = T.axis.S(128, k0 * 43 + k1) @@ -223,8 +223,8 @@ def elementwise_fuse_with_opaque_block(a: T.handle, b: T.handle) -> None: T.reads( [ A[ - T.floormod(T.floordiv(T.floordiv(i_j_k_fused, 128), 128), 128), - T.floormod(T.floordiv(i_j_k_fused, 128), 128), + T.floordiv(i_j_k_fused, 16384), + T.floordiv(T.floormod(i_j_k_fused, 16384), 128), T.floormod(i_j_k_fused, 128), ] ] @@ -232,15 +232,15 @@ def elementwise_fuse_with_opaque_block(a: T.handle, b: T.handle) -> None: T.writes( [ B[ - T.floormod(T.floordiv(T.floordiv(i_j_k_fused, 128), 128), 128), - T.floormod(T.floordiv(i_j_k_fused, 128), 128), + T.floordiv(i_j_k_fused, 16384), + T.floordiv(T.floormod(i_j_k_fused, 16384), 128), T.floormod(i_j_k_fused, 128), ] ] ) with T.block("B"): vi = T.axis.S(128, T.floordiv(i_j_k_fused, 16384)) - vj = T.axis.S(128, T.floormod(T.floordiv(i_j_k_fused, 128), 128)) + vj = T.axis.S(128, T.floordiv(T.floormod(i_j_k_fused, 16384), 128)) vk = T.axis.S(128, T.floormod(i_j_k_fused, 128)) T.reads([A[vi, vj, vk]]) T.writes([B[vi, vj, vk]]) @@ -343,7 +343,7 @@ def elementwise_not_affine_fused(a: T.handle, b: T.handle) -> None: with T.block("B"): vi = T.axis.S( 127, - i * 32 + T.floormod(T.floordiv(j_k_fused, 128), T.min(31, 126 - i * 32) + 1), + i * 32 + T.floordiv(j_k_fused, 128), ) vj = T.axis.S(128, T.floormod(j_k_fused, 128)) T.reads([A[vi, vj]])