Skip to content

Commit

Permalink
[TIR] Update block flags and simplify predicate in Reverse-Compute-In…
Browse files Browse the repository at this point in the history
…line (#14030)

* Add simplification after substitution so that the predicate is easier for arithmetic analysis to handle.
* Update block flags, since reverse-compute-inline may introduce predicates and the resulting block may no longer have an affine binding.
  • Loading branch information
vinx13 authored Feb 18, 2023
1 parent 6f232f9 commit ac57b01
Show file tree
Hide file tree
Showing 2 changed files with 83 additions and 1 deletion.
10 changes: 9 additions & 1 deletion src/tir/schedule/primitive/compute_inline.cc
Original file line number Diff line number Diff line change
Expand Up @@ -608,7 +608,7 @@ class ReverseComputeInliner : public BaseInliner {
/*indices=*/buffer_load_indices_,
/*input_iters=*/consumer_iter_doms,
/*predicate=*/true,
/*check_level=*/arith::IterMapLevel::Bijective,
/*check_level=*/arith::IterMapLevel::NoCheck,
/*analyzer=*/&analyzer_,
/*simplify_trivial_iterators=*/false);
buffer_load_iter_map_ = res->indices;
Expand Down Expand Up @@ -651,6 +651,7 @@ class ReverseComputeInliner : public BaseInliner {
// Substitute the producer block iters with their bindings, since the predicate in BlockRealize
// should not contain the block iters
predicate = Substitute(predicate, subst_map);
predicate = analyzer_.Simplify(predicate);
return predicate;
}

Expand Down Expand Up @@ -865,6 +866,13 @@ void ReverseComputeInlineImpl(ScheduleState self, const StmtSRef& consumer_block
return;
}
self->Replace(scope_root_sref, tgt_stmt, inliner.block_reuse);
// Step 8. Update the cached flags
arith::Analyzer analyzer;
BlockInfo& block_info = self->block_info[producer_block_sref];
block_info.affine_binding = IsAffineBinding(
/*realize=*/GetBlockRealize(self, producer_block_sref),
/*loop_var_ranges=*/LoopDomainOfSRefTreePath(GetRef<StmtSRef>(producer_block_sref->parent)),
/*analyzer=*/&analyzer);
}

bool CanReverseComputeInline(const ScheduleState& self, const StmtSRef& block_sref) {
Expand Down
74 changes: 74 additions & 0 deletions tests/python/unittest/test_tir_schedule_compute_inline.py
Original file line number Diff line number Diff line change
Expand Up @@ -611,6 +611,60 @@ def elementwise_overcomputed_producer_reverse_inlined(
C[vi, vj] = A[vi, vj] * 2.0 + 1.0


# Input fixture for the "simplify predicate" reverse-compute-inline test.
# Producer block "B" runs over a single loop of extent 16384 and derives its two
# spatial axes as i // 128 and i % 128, writing the whole 128x128 buffer B.
# Consumer block "C" only reads the 127x127 region of B, so the producer computes
# more elements than the consumer uses.
@T.prim_func
def elementwise_overcomputed_producer_simplify_predicate(
A: T.Buffer((128, 128), "float32"), C: T.Buffer((127, 127), "float32")
) -> None:
# Intermediate buffer written by "B" and read by "C".
B = T.alloc_buffer((128, 128))
for i in T.grid(16384):
with T.block("B"):
# Non-trivial bindings: both block iters are derived from the single loop var i.
vi = T.axis.spatial(128, i // 128)
vj = T.axis.spatial(128, i % 128)
B[vi, vj] = A[vi, vj] * 2.0
for i, j in T.grid(127, 127):
with T.block("C"):
cvi, cvj = T.axis.remap("SS", [i, j])
C[cvi, cvj] = B[cvi, cvj] + 1.0


# Expected result of reverse-compute-inlining block "C" of
# elementwise_overcomputed_producer_simplify_predicate into block "B".
# The intermediate buffer B is gone: block "B" writes C directly and carries a
# T.where predicate expressed in terms of the loop var i.
@T.prim_func
def elementwise_overcomputed_producer_simplify_predicate_reverse_inlined(
A: T.Buffer((128, 128), "float32"), C: T.Buffer((127, 127), "float32")
) -> None:
for i in T.grid(16384):
with T.block("B"):
vi = T.axis.spatial(128, i // 128)
vj = T.axis.spatial(128, i % 128)
# Predicate in the exact form the test compares against structurally;
# it is written over the loop var i, not over the block iters vi/vj.
T.where(i < 16255 and i % 128 < 127)
C[vi, vj] = A[vi, vj] * 2.0 + 1.0


# Input fixture for the "injective load" reverse-compute-inline test.
# Producer block "B" fills a 4-D buffer B of shape (8, 8, 16, 16) from A, using
# vi * 16 + vm and vj * 16 + vn as the 2-D indices into A.
# Consumer block "C" iterates over 127x127 and reads B through the index mapping
# (cvi // 16, cvj // 16, cvi % 16, cvj % 16).
@T.prim_func
def elementwise_overcomputed_producer_injective_load(
A: T.Buffer((128, 128), "float32"), C: T.Buffer((127, 127), "float32")
) -> None:
# Intermediate buffer written by "B" and read by "C".
B = T.alloc_buffer((8, 8, 16, 16))
for i0, j0, i1, j1 in T.grid(8, 8, 16, 16):
with T.block("B"):
vi, vj, vm, vn = T.axis.remap("SSSS", [i0, j0, i1, j1])
B[vi, vj, vm, vn] = A[vi * 16 + vm, vj * 16 + vn] * 2.0
for i, j in T.grid(127, 127):
with T.block("C"):
cvi, cvj = T.axis.remap("SS", [i, j])
# The 2-D consumer index is split into (outer, outer, inner, inner) to address B.
C[cvi, cvj] = B[cvi // 16, cvj // 16, cvi % 16, cvj % 16] + 1.0


# Expected result of reverse-compute-inlining block "C" of
# elementwise_overcomputed_producer_injective_load into block "B".
# The intermediate buffer B is gone: block "B" writes C directly at
# (vm + vi * 16, vn + vj * 16) and carries a T.where predicate over the loop vars.
@T.prim_func
def elementwise_overcomputed_producer_injective_load_reverse_inlined(
A: T.Buffer((128, 128), "float32"), C: T.Buffer((127, 127), "float32")
) -> None:
for i0, j0, i1, j1 in T.grid(8, 8, 16, 16):
with T.block("B"):
vi, vj, vm, vn = T.axis.remap("SSSS", [i0, j0, i1, j1])
# Restricts execution to iterations whose reconstructed 2-D index is below 127
# in both dimensions, matching the (127, 127) shape of C.
T.where(i0 * 16 + i1 < 127 and j0 * 16 + j1 < 127)
C[vm + vi * 16, vn + vj * 16] = A[vi * 16 + vm, vj * 16 + vn] * 2.0 + 1.0


@T.prim_func
def elementwise_producer_not_cover_consumer(
A: T.Buffer((128, 128), "float32"), D: T.Buffer((256, 128), "float32")
Expand Down Expand Up @@ -1025,6 +1079,26 @@ def test_reverse_compute_inline_overcomputed_producer(use_block_name):
)


def test_reverse_compute_inline_overcomputed_producer_simplify_predicate(use_block_name):
    """Reverse-compute-inline block "C" and check the result against the reference
    function whose predicate is in simplified form.

    When ``use_block_name`` is truthy the block is addressed by its name "C";
    otherwise the handle returned by ``get_block("C")`` is used.
    """
    schedule = tir.Schedule(elementwise_overcomputed_producer_simplify_predicate, debug_mask="all")
    if use_block_name:
        consumer = "C"
    else:
        consumer = schedule.get_block("C")
    schedule.reverse_compute_inline(consumer)
    expected = elementwise_overcomputed_producer_simplify_predicate_reverse_inlined
    tvm.ir.assert_structural_equal(expected, schedule.mod["main"])


def test_reverse_compute_inline_overcomputed_producer_injective_load(use_block_name):
    """Reverse-compute-inline block "C" when it loads from the producer buffer
    through split (``//`` and ``%``) indices, and check the result against the
    reference function.

    When ``use_block_name`` is truthy the block is addressed by its name "C";
    otherwise the handle returned by ``get_block("C")`` is used.
    """
    schedule = tir.Schedule(elementwise_overcomputed_producer_injective_load, debug_mask="all")
    if use_block_name:
        consumer = "C"
    else:
        consumer = schedule.get_block("C")
    schedule.reverse_compute_inline(consumer)
    expected = elementwise_overcomputed_producer_injective_load_reverse_inlined
    tvm.ir.assert_structural_equal(expected, schedule.mod["main"])


def test_reverse_compute_inline_error_producer_not_cover_consumer(use_block_name):
"""Test reverse compute inline failure when the inlined block iter domains are not covered by
its producer
Expand Down

0 comments on commit ac57b01

Please sign in to comment.