From ac57b013b259ea97948245933432b0a5eed3d707 Mon Sep 17 00:00:00 2001 From: Wuwei Lin Date: Fri, 17 Feb 2023 22:24:51 -0800 Subject: [PATCH] [TIR] Update block flags and simplify predicate in Reverse-Compute-Inline (#14030) * Add simplification after substitution to make the predicate easier for arithmetic analysis to handle. * Update block flags, since reverse-compute-inline may introduce predicates and the resulting block may not have an affine binding. --- src/tir/schedule/primitive/compute_inline.cc | 10 ++- .../test_tir_schedule_compute_inline.py | 74 +++++++++++++++++++ 2 files changed, 83 insertions(+), 1 deletion(-) diff --git a/src/tir/schedule/primitive/compute_inline.cc b/src/tir/schedule/primitive/compute_inline.cc index d21149437f08..99286c91b344 100644 --- a/src/tir/schedule/primitive/compute_inline.cc +++ b/src/tir/schedule/primitive/compute_inline.cc @@ -608,7 +608,7 @@ class ReverseComputeInliner : public BaseInliner { /*indices=*/buffer_load_indices_, /*input_iters=*/consumer_iter_doms, /*predicate=*/true, - /*check_level=*/arith::IterMapLevel::Bijective, + /*check_level=*/arith::IterMapLevel::NoCheck, /*analyzer=*/&analyzer_, /*simplify_trivial_iterators=*/false); buffer_load_iter_map_ = res->indices; @@ -651,6 +651,7 @@ class ReverseComputeInliner : public BaseInliner { // Substitute the producer block iters with the its bindings since the predicate in BlockRealize // should not contain the block iters predicate = Substitute(predicate, subst_map); + predicate = analyzer_.Simplify(predicate); return predicate; } @@ -865,6 +866,13 @@ void ReverseComputeInlineImpl(ScheduleState self, const StmtSRef& consumer_block return; } self->Replace(scope_root_sref, tgt_stmt, inliner.block_reuse); + // Step 8. 
Update the cached flags + arith::Analyzer analyzer; + BlockInfo& block_info = self->block_info[producer_block_sref]; + block_info.affine_binding = IsAffineBinding( + /*realize=*/GetBlockRealize(self, producer_block_sref), + /*loop_var_ranges=*/LoopDomainOfSRefTreePath(GetRef(producer_block_sref->parent)), + /*analyzer=*/&analyzer); } bool CanReverseComputeInline(const ScheduleState& self, const StmtSRef& block_sref) { diff --git a/tests/python/unittest/test_tir_schedule_compute_inline.py b/tests/python/unittest/test_tir_schedule_compute_inline.py index ee5e85e4f05b..a4c7344909c5 100644 --- a/tests/python/unittest/test_tir_schedule_compute_inline.py +++ b/tests/python/unittest/test_tir_schedule_compute_inline.py @@ -611,6 +611,60 @@ def elementwise_overcomputed_producer_reverse_inlined( C[vi, vj] = A[vi, vj] * 2.0 + 1.0 +@T.prim_func +def elementwise_overcomputed_producer_simplify_predicate( + A: T.Buffer((128, 128), "float32"), C: T.Buffer((127, 127), "float32") +) -> None: + B = T.alloc_buffer((128, 128)) + for i in T.grid(16384): + with T.block("B"): + vi = T.axis.spatial(128, i // 128) + vj = T.axis.spatial(128, i % 128) + B[vi, vj] = A[vi, vj] * 2.0 + for i, j in T.grid(127, 127): + with T.block("C"): + cvi, cvj = T.axis.remap("SS", [i, j]) + C[cvi, cvj] = B[cvi, cvj] + 1.0 + + +@T.prim_func +def elementwise_overcomputed_producer_simplify_predicate_reverse_inlined( + A: T.Buffer((128, 128), "float32"), C: T.Buffer((127, 127), "float32") +) -> None: + for i in T.grid(16384): + with T.block("B"): + vi = T.axis.spatial(128, i // 128) + vj = T.axis.spatial(128, i % 128) + T.where(i < 16255 and i % 128 < 127) + C[vi, vj] = A[vi, vj] * 2.0 + 1.0 + + +@T.prim_func +def elementwise_overcomputed_producer_injective_load( + A: T.Buffer((128, 128), "float32"), C: T.Buffer((127, 127), "float32") +) -> None: + B = T.alloc_buffer((8, 8, 16, 16)) + for i0, j0, i1, j1 in T.grid(8, 8, 16, 16): + with T.block("B"): + vi, vj, vm, vn = T.axis.remap("SSSS", [i0, j0, i1, j1]) + B[vi, 
vj, vm, vn] = A[vi * 16 + vm, vj * 16 + vn] * 2.0 + for i, j in T.grid(127, 127): + with T.block("C"): + cvi, cvj = T.axis.remap("SS", [i, j]) + C[cvi, cvj] = B[cvi // 16, cvj // 16, cvi % 16, cvj % 16] + 1.0 + + +@T.prim_func +def elementwise_overcomputed_producer_injective_load_reverse_inlined( + A: T.Buffer((128, 128), "float32"), C: T.Buffer((127, 127), "float32") +) -> None: + for i0, j0, i1, j1 in T.grid(8, 8, 16, 16): + with T.block("B"): + vi, vj, vm, vn = T.axis.remap("SSSS", [i0, j0, i1, j1]) + T.where(i0 * 16 + i1 < 127 and j0 * 16 + j1 < 127) + C[vm + vi * 16, vn + vj * 16] = A[vi * 16 + vm, vj * 16 + vn] * 2.0 + 1.0 + + @T.prim_func def elementwise_producer_not_cover_consumer( A: T.Buffer((128, 128), "float32"), D: T.Buffer((256, 128), "float32") @@ -1025,6 +1079,26 @@ def test_reverse_compute_inline_overcomputed_producer(use_block_name): ) +def test_reverse_compute_inline_overcomputed_producer_simplify_predicate(use_block_name): + """Test reverse compute inline overcomputed producer where the predicate should be simplified""" + sch = tir.Schedule(elementwise_overcomputed_producer_simplify_predicate, debug_mask="all") + compute = "C" if use_block_name else sch.get_block("C") + sch.reverse_compute_inline(compute) + tvm.ir.assert_structural_equal( + elementwise_overcomputed_producer_simplify_predicate_reverse_inlined, sch.mod["main"] + ) + + +def test_reverse_compute_inline_overcomputed_producer_injective_load(use_block_name): + """Test reverse compute inline overcomputed producer with injective buffer load""" + sch = tir.Schedule(elementwise_overcomputed_producer_injective_load, debug_mask="all") + compute = "C" if use_block_name else sch.get_block("C") + sch.reverse_compute_inline(compute) + tvm.ir.assert_structural_equal( + elementwise_overcomputed_producer_injective_load_reverse_inlined, sch.mod["main"] + ) + + def test_reverse_compute_inline_error_producer_not_cover_consumer(use_block_name): """Test reverse compute inline failure when the inlined 
block iter domains are not covered by its producer