Skip to content

Commit

Permalink
[Bugfix][TIR] Fix duplicate AllocateConst in CacheReadWrite schedule …
Browse files Browse the repository at this point in the history
…primitive (#16660)

* [Bugfix][TIR] Fix duplicate AllocateConst in CacheReadWrite schedule primitive

When inserting a `cache_read` / `cache_write` stage, the `tir.AllocateConst` statement would be duplicated if its body was not a `tir.SeqStmt` node (e.g. `tir.For`), leading to compilation failures. This happened because `tir.AllocateConst` and `tir.DeclBuffer` statements are always re-attached to the statement's body after the `cache_read` / `cache_write` stage is inserted in it, but the stage was being appended to the whole statement (which already contains the `tir.AllocateConst`) and not just its body, causing duplications.

This commit also adds a test where the first `cache_read` stage is inserted into a statement whose body is a `tir.For`, while the second stage is added to a body that is `tir.SeqStmt` to check for regressions.

* Improve PrimFunc readability

* Remove redundant `T.reads()`
  • Loading branch information
Anndrey24 authored Mar 7, 2024
1 parent e005f85 commit 657880c
Show file tree
Hide file tree
Showing 2 changed files with 42 additions and 2 deletions.
4 changes: 2 additions & 2 deletions src/tir/schedule/primitive/cache_read_write.cc
Original file line number Diff line number Diff line change
Expand Up @@ -483,9 +483,9 @@ Stmt InsertCacheStage(const Stmt& stmt, int pos, const Stmt& stage) {
seq.insert(seq.begin() + pos, stage);
body = SeqStmt(seq);
} else if (pos == 0) {
body = SeqStmt({stage, stmt});
body = SeqStmt({stage, body});
} else if (pos == 1) {
body = SeqStmt({stmt, stage});
body = SeqStmt({body, stage});
} else {
LOG(FATAL) << "Cannot insert at position " << pos
<< ". When inserting adjacent to non-SeqStmt, "
Expand Down
40 changes: 40 additions & 0 deletions tests/python/tir-schedule/test_tir_schedule_cache_read_write.py
Original file line number Diff line number Diff line change
Expand Up @@ -1379,6 +1379,46 @@ def test_cache_read_fail_invalid_storage_scope(use_block_name):
sch.cache_read(block_b, 0, "test_scope")


def test_cache_read_allocate_const():
@T.prim_func
def before(A: T.Buffer((8), "float32"), C: T.Buffer((8), "float32")):
B = T.allocate_const([0.0, 0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7], "float32", [8])
B_buf = T.decl_buffer((8), dtype="float32", data=B)
for i in range(8):
with T.block("C"):
vi = T.axis.spatial(8, i)
C[vi] = A[vi] + B_buf[vi]

@T.prim_func
def expected(A: T.Buffer((8), "float32"), C: T.Buffer((8), "float32")):
B_buf_global = T.alloc_buffer((8), dtype="float32")
A_global = T.alloc_buffer((8), dtype="float32")
B = T.allocate_const([0.0, 0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7], "float32", [8])
B_buf = T.decl_buffer((8), data=B)
for ax0 in range(8):
with T.block("A_global"):
v0 = T.axis.spatial(8, ax0)
A_global[v0] = A[v0]
for ax0 in range(8):
with T.block("B_buf_global"):
v0 = T.axis.spatial(8, ax0)
B_buf_global[v0] = B_buf[v0]
for i in range(8):
with T.block("C"):
vi = T.axis.spatial(8, i)
C[vi] = A_global[vi] + B_buf_global[vi]

sch = tir.Schedule(before)
block_c = sch.get_block("C")
sch.cache_read(block_c, 1, "global")
sch.cache_read(block_c, 0, "global")

after = sch.mod["main"]

assert_structural_equal_ignore_global_symbol(expected, after)
verify_trace_roundtrip(sch=sch, mod=before)


def test_inplace_cache_read():
sch = tvm.tir.Schedule(inplace_func, debug_mask="all")
block = sch.get_block("copy_in")
Expand Down

0 comments on commit 657880c

Please sign in to comment.