Skip to content

Commit

Permalink
[TIR] Output DeclBuffer in FlattenBuffer
Browse files Browse the repository at this point in the history
If a flattened buffer is produced for use in `BufferLoad` and
`BufferStore` statements, generate a `DeclBuffer`.

This is a subset of the changes made in
apache#14778, broken out for ease of
testing and review.
Lunderberg committed Apr 4, 2024
1 parent cd08356 commit 629bb32
Showing 2 changed files with 62 additions and 70 deletions.
51 changes: 40 additions & 11 deletions src/tir/transforms/flatten_buffer.cc
Original file line number Diff line number Diff line change
@@ -41,13 +41,29 @@ class BufferFlattener : public arith::IRMutatorWithAnalyzer {
static PrimFunc Flatten(PrimFunc func) {
arith::Analyzer ana;
auto pass = BufferFlattener(&ana);
auto writer = func.CopyOnWrite();
pass.MarkBufferMapShapes(func);
writer->body = pass.VisitStmt(func->body);
auto body = pass.VisitStmt(func->body);

// The buffers in func->buffer_map are deliberately left
// unflattened, as they are used for validation of user-provided
// arguments. The flattened buffers used in the updated
// function body alias the argument buffers.
for (size_t i = func->params.size(); i > 0; i--) {
auto handle = func->params[i - 1];
if (auto opt = func->buffer_map.Get(handle)) {
auto old_buf = opt.value();
if (pass.buffers_used_.count(old_buf)) {
auto new_buf = pass.GetFlattenedBuffer(old_buf);
if (!old_buf.same_as(new_buf)) {
body = DeclBuffer(new_buf, std::move(body));
}
}
}
}

if (!body.same_as(func->body)) {
func.CopyOnWrite()->body = std::move(body);
}
return func;
}

@@ -153,11 +169,14 @@ class BufferFlattener : public arith::IRMutatorWithAnalyzer {
}

Stmt VisitStmt_(const DeclBufferNode* op) final {
// TODO(rfc-70): Update the DeclBuffer node instead of
// stripping it out. Stripping it out in the current
// implementation as not all lowering passes support
// DeclBuffer.
return VisitStmt(op->body);
auto node = Downcast<DeclBuffer>(StmtExprMutator::VisitStmt_(op));

auto new_buf = GetFlattenedBuffer(node->buffer);
if (!node->buffer.same_as(new_buf)) {
node.CopyOnWrite()->buffer = new_buf;
}

return std::move(node);
}

Buffer GetFlattenedBuffer(Buffer buf) {
@@ -166,16 +185,23 @@ class BufferFlattener : public arith::IRMutatorWithAnalyzer {
return it->second;
}
auto flattened = buf.GetFlattenedBuffer();
auto writer = flattened.CopyOnWrite();

// TODO(Lunderberg): Move the handling of boolean into a
// dedicated pass.
if (flattened->dtype == DataType::Bool()) {
writer->dtype = DataType::Int(8);
flattened.CopyOnWrite()->dtype = DataType::Int(8);
}
// canonicalize shape
for (size_t i = 0; i < flattened->shape.size(); ++i) {
writer->shape.Set(i, analyzer_->canonical_simplify(flattened->shape[i]));
bool shape_is_changed = false;
Array<PrimExpr> new_shape;
for (const auto& dim : flattened->shape) {
auto new_dim = analyzer_->canonical_simplify(dim);
shape_is_changed = shape_is_changed || !StructuralEqual()(dim, new_dim);
new_shape.push_back(new_dim);
}

if (shape_is_changed) {
flattened.CopyOnWrite()->shape = std::move(new_shape);
}

buffer_remap_[buf] = flattened;
@@ -226,6 +252,7 @@ class BufferFlattener : public arith::IRMutatorWithAnalyzer {
template <typename Node>
Node VisitBufferAccess(Node node) {
ICHECK(node->buffer.defined());
buffers_used_.insert(node->buffer);
auto flattened_indices = GetSimplifiedElemOffset(node->buffer, node->indices);
Buffer flattened_buffer = GetFlattenedBuffer(node->buffer);

@@ -264,6 +291,8 @@ class BufferFlattener : public arith::IRMutatorWithAnalyzer {
/*! \brief Map of buffers being remapped. */
std::unordered_map<Buffer, Buffer, ObjectPtrHash, ObjectPtrEqual> buffer_remap_;

std::unordered_set<Buffer, ObjectPtrHash, ObjectPtrEqual> buffers_used_;

/*! \brief The updated external buffer map. */
Map<Var, Buffer> updated_extern_buffer_map_;
};
81 changes: 22 additions & 59 deletions tests/python/tir-transform/test_tir_transform_flatten_buffer.py
Original file line number Diff line number Diff line change
@@ -41,42 +41,10 @@ def before(A: T.Buffer((16, 16), "float32"), C: T.Buffer((16, 16), "float32")):
C[i, j] = B_new[0, j] * 2.0

def expected(input_A: T.Buffer((16, 16), "float32"), input_C: T.Buffer((16, 16), "float32")):
A = T.Buffer(256, dtype="float32", data=input_A.data)
C = T.Buffer(256, dtype="float32", data=input_C.data)
A = T.decl_buffer(256, dtype="float32", data=input_A.data)
C = T.decl_buffer(256, dtype="float32", data=input_C.data)
for i in T.serial(0, 16):
B_new_data = T.allocate([16], "float32", scope="global")
B_new = T.Buffer([16], "float32", scope="global", data=B_new_data)
for j in T.serial(0, 16):
B_new[j] = A[((i * 16) + j)] + 1.0
for j in T.serial(0, 16):
C[((i * 16) + j)] = B_new[j] * 2.0


class TestElementwiseWithoutDeclBuffer(BaseCompare):
"""2-d buffers are flattened to 1-d
Like TestElementwise, but the TIR doesn't have the DeclBuffer
node. The T.Buffer declaration applies only during the
parsing of the TVMScript, and doesn't occur in the TIR itself. In
this case, the allocation should be assumed to be targeting flat
memory, and should be flattened to a 1-d allocation.
"""

def before(A: T.Buffer((16, 16), "float32"), C: T.Buffer((16, 16), "float32")):
for i in T.serial(0, 16):
B_new_data = T.allocate([1, 16], "float32", "global")
B_new = T.Buffer([1, 16], "float32", data=B_new_data)
for j in T.serial(0, 16):
B_new[0, j] = A[i, j] + 1.0
for j in T.serial(0, 16):
C[i, j] = B_new[0, j] * 2.0

def expected(input_A: T.Buffer((16, 16), "float32"), input_C: T.Buffer((16, 16), "float32")):
A = T.Buffer(256, dtype="float32", data=input_A.data)
C = T.Buffer(256, dtype="float32", data=input_C.data)
for i in T.serial(0, 16):
B_new_data = T.allocate([16], "float32", "global")
B_new = T.Buffer(16, "float32", data=B_new_data)
B_new = T.decl_buffer(16, "float32", scope="global")
for j in T.serial(0, 16):
B_new[j] = A[((i * 16) + j)] + 1.0
for j in T.serial(0, 16):
@@ -101,8 +69,8 @@ def before(A: T.Buffer((16, 16), "float32"), C: T.Buffer((16, 16), "float32")):
C[i0 * 4 + i1 * 2 + i2, j] = B[0, j] * 2.0

def expected(input_A: T.Buffer((16, 16), "float32"), input_C: T.Buffer((16, 16), "float32")):
A = T.Buffer(256, dtype="float32", data=input_A.data)
C = T.Buffer(256, dtype="float32", data=input_C.data)
A = T.decl_buffer(256, dtype="float32", data=input_A.data)
C = T.decl_buffer(256, dtype="float32", data=input_C.data)

i0 = T.env_thread("blockIdx.x")
i1 = T.env_thread("threadIdx.x")
@@ -111,8 +79,7 @@ def expected(input_A: T.Buffer((16, 16), "float32"), input_C: T.Buffer((16, 16),
T.launch_thread(i0, 4)
T.launch_thread(i1, 2)
T.launch_thread(i2, 2)
B_data = T.allocate([16], "float32", scope="local")
B = T.Buffer([16], "float32", scope="local", data=B_data)
B = T.decl_buffer(16, "float32", scope="local")
for j in range(0, 16):
B[j] = A[i0 * 64 + i1 * 32 + i2 * 16 + j] + 1.0
for j in range(0, 16):
@@ -136,12 +103,11 @@ def before(a: T.handle, c: T.handle, n: T.int32, m: T.int32) -> None:
def expected(a: T.handle, c: T.handle, n: T.int32, m: T.int32) -> None:
input_A = T.match_buffer(a, (n, m), "float32")
input_C = T.match_buffer(c, (n, m), "float32")
A = T.Buffer(n * m, "float32", data=input_A.data)
C = T.Buffer(n * m, "float32", data=input_C.data)
A = T.decl_buffer(n * m, "float32", data=input_A.data)
C = T.decl_buffer(n * m, "float32", data=input_C.data)

for i in range(0, n):
B_data = T.allocate([m], "float32", scope="global")
B = T.Buffer([m], "float32", scope="global", data=B_data)
B = T.decl_buffer(m, "float32", scope="global")
for j in range(0, m):
B[j] = A[i * m + j] + 1.0
for j in range(0, m):
@@ -161,8 +127,8 @@ def before(a: T.handle, b: T.handle, n: T.int32) -> None:
def expected(a: T.handle, b: T.handle, n: T.int32) -> None:
input_A = T.match_buffer(a, (32, n, n), "float32")
input_B = T.match_buffer(b, (32, n, n), "float32")
A = T.Buffer(n * n * 32, "float32", data=input_A.data)
B = T.Buffer(n * n * 32, "float32", data=input_B.data)
A = T.decl_buffer(n * n * 32, "float32", data=input_A.data)
B = T.decl_buffer(n * n * 32, "float32", data=input_B.data)

for i in range(0, n * n * 32):
B[i] = A[i]
@@ -185,8 +151,8 @@ def before(a: T.handle, b: T.handle, n: T.int32) -> None:
def expected(a: T.handle, b: T.handle, n: T.int32) -> None:
input_A = T.match_buffer(a, (32, n, n), "float32")
input_B = T.match_buffer(b, (32, n, n), "float32")
A = T.Buffer(n * n * 32, "float32", data=input_A.data)
B = T.Buffer(n * n * 32, "float32", data=input_B.data)
A = T.decl_buffer(n * n * 32, "float32", data=input_A.data)
B = T.decl_buffer(n * n * 32, "float32", data=input_B.data)

for bx, tx in T.grid((n * n + 1) // 2, 64):
if bx * 64 + tx < n * n * 32:
@@ -205,14 +171,12 @@ def before(A: T.Buffer((4, 32), "float32"), D: T.Buffer((4, 32), "float32")):
D[i, j] = C[i, j] * 2.0

def expected(input_A: T.Buffer((4, 32), "float32"), input_D: T.Buffer((4, 32), "float32")):
A = T.Buffer(128, "float32", data=input_A.data)
D = T.Buffer(128, "float32", data=input_D.data)
A = T.decl_buffer(128, "float32", data=input_A.data)
D = T.decl_buffer(128, "float32", data=input_D.data)

for i, j in T.grid(4, 32):
B_data = T.allocate([128], "float32", scope="global")
B = T.Buffer([128], "float32", scope="global", data=B_data)
C_data = T.allocate([128], "float32", scope="global")
C = T.Buffer([128], "float32", scope="global", data=C_data)
B = T.decl_buffer(128, "float32", scope="global")
C = T.decl_buffer(128, "float32", scope="global")
B[i * 32 + j] = A[i * 32 + j] + 1.0
C[i * 32 + j] = A[i * 32 + j] + B[i * 32 + j]
D[i * 32 + j] = C[i * 32 + j] * 2.0
@@ -231,11 +195,10 @@ def before(A: T.Buffer((16, 16), "float32"), C: T.Buffer((16, 16), "float32")):
C[i0 * 4 + i1, j] = B_1[i1, j] * 2.0

def expected(input_A: T.Buffer((16, 16), "float32"), input_C: T.Buffer((16, 16), "float32")):
A = T.Buffer(256, dtype="float32", data=input_A.data)
C = T.Buffer(256, dtype="float32", data=input_C.data)
A = T.decl_buffer(256, dtype="float32", data=input_A.data)
C = T.decl_buffer(256, dtype="float32", data=input_C.data)
for i0 in T.serial(0, 4):
B_new_data = T.allocate([68], "float32", scope="global")
B_new = T.Buffer([68], "float32", scope="global", data=B_new_data)
B_new = T.decl_buffer(68, "float32", scope="global")
for i1 in T.serial(0, 4):
for j in T.serial(0, 16):
B_new[i1 * 17 + j] = A[i0 * 64 + i1 * 16 + j] + 1.0
@@ -252,8 +215,8 @@ def before(A: T.Buffer(10, "bool"), B: T.Buffer(10, "bool")) -> None:
B[i0] = A[i0]

def expected(input_A: T.Buffer(10, "bool"), input_B: T.Buffer(10, "bool")) -> None:
A = T.Buffer(10, dtype="int8", data=input_A.data)
B = T.Buffer(10, dtype="int8", data=input_B.data)
A = T.decl_buffer(10, dtype="int8", data=input_A.data)
B = T.decl_buffer(10, dtype="int8", data=input_B.data)
# body
for i0 in T.serial(10):
B[i0] = T.cast(T.cast(A[i0], "bool"), "int8")

0 comments on commit 629bb32

Please sign in to comment.