Skip to content

Commit

Permalink
[BugFix][TIR] Fix multi-grouped multi-warp allreduce
Browse files Browse the repository at this point in the history
PR #15327 and #15373 introduced multi-warp allreduce implementation.
At the time of the introduction, I tested the correctness numerically
via the workload of "taking a matrix of ones as input, computing the
summation over each row". Both PR passed this numerical tess, while
I didn't realize that this test is not complete and cannot guarantee
the correctness.

The previous implementation has bug which can be tested by turning
the input matrix from ones to random floating-point numbers. This will
expose the issues of the previous implementation.

Therefore, this PR fixes the issues, and add the numerical tests
for multi-warp allreduce into `test_allreduce_cuda.py`. By reducing
some of the redundant tests in that file, we hope this can reduce the
testing time a bit while still guarantee the correctness.

Sorry for not testing the implementation completely before.
  • Loading branch information
MasterJH5574 committed Jul 25, 2023
1 parent d6407be commit d59c1a8
Show file tree
Hide file tree
Showing 3 changed files with 38 additions and 28 deletions.
38 changes: 21 additions & 17 deletions src/tir/transforms/lower_thread_allreduce.cc
Original file line number Diff line number Diff line change
Expand Up @@ -76,12 +76,17 @@ class ThreadAllreduceBuilder final : public StmtExprMutator {
auto node = Downcast<Allocate>(StmtExprMutator::VisitStmt_(op));

if (auto it = alloc_remap_.find(node->buffer_var.get()); it != alloc_remap_.end()) {
const AllocateNode* repl = it->second.as<AllocateNode>();
Buffer buf = Downcast<Buffer>(it->second);
auto write_ptr = node.CopyOnWrite();
write_ptr->buffer_var = repl->buffer_var;
write_ptr->dtype = repl->dtype;
write_ptr->extents = repl->extents;
write_ptr->condition = repl->condition;
write_ptr->buffer_var = buf->data;
write_ptr->dtype = buf->dtype;
write_ptr->extents = buf->shape;
write_ptr->condition = const_true(buf->dtype.lanes());

if (buf.scope() == "shared") {
// Use volatile access to shared buffer.
write_ptr->body = AttrStmt(buf->data, attr::volatile_scope, 1, write_ptr->body);
}
}
return std::move(node);
}
Expand Down Expand Up @@ -344,15 +349,15 @@ class ThreadAllreduceBuilder final : public StmtExprMutator {
// 4. Load staging buffer.
// Second round of allreduce.
for (size_t i = 0; i < size; ++i) {
values[i] = BufferLoad(/*buffer=*/staging_shared_bufs[i], /*indices=*/{reduce_index});
values[i] = BufferLoad(/*buffer=*/staging_shared_bufs[i],
/*indices=*/{group_index * n_warps + reduce_index});
}
if (n_warps < warp_size_) {
mask = mask & (((1 << n_warps) - 1) << group_index);
mask = mask & (((1 << n_warps) - 1) << (group_index * n_warps));
}
std::tie(reduce_results, local_bufs) = MakeWarpAllreduce(
values, types, combiner, reduce_index, n_warps, group_index, mask,
/*predicate=*/reduce_index < make_const(reduce_index->dtype, group_extent * n_warps),
&seq);
/*predicate=*/reduce_index < make_const(reduce_index->dtype, n_warps), &seq);
new_alloc_bufs.insert(new_alloc_bufs.end(), local_bufs.begin(), local_bufs.end());

// 5. Create shared memory buffer(s) of `group_extent` elements, storing
Expand All @@ -365,9 +370,9 @@ class ThreadAllreduceBuilder final : public StmtExprMutator {
/*shape=*/{make_const(reduce_index->dtype, group_extent)},
/*dtype=*/buffers[i]->dtype, /*name=*/"red_result", /*storage_scope=*/"shared");
write_result.push_back(
BufferStore(broadcast_shared_buf, reduce_results[i], {zero_index}));
BufferStore(broadcast_shared_buf, reduce_results[i], {group_index}));
// Update `reduce_results`, pointing to the value loaded from the shared memory buffer.
reduce_results[i] = BufferLoad(broadcast_shared_buf, {zero_index});
reduce_results[i] = BufferLoad(broadcast_shared_buf, {group_index});
}
seq.push_back(IfThenElse(reduce_index == zero_index, SeqStmt::Flatten(write_result)));
seq.push_back(SyncThread("shared"));
Expand All @@ -382,7 +387,7 @@ class ThreadAllreduceBuilder final : public StmtExprMutator {
load_remap_[buffers[i]->data.get()] = reduce_results[i];

auto node = Allocate(buf->data, types[i], buf->shape, pred, Evaluate(0));
alloc_remap_[buffers[i]->data.get()] = node;
alloc_remap_[buffers[i]->data.get()] = buf;
var_remap_[buffers[i]->data.get()] = buf->data;
buf_remap_[buffers[i].get()] = buf;
}
Expand All @@ -400,7 +405,8 @@ class ThreadAllreduceBuilder final : public StmtExprMutator {
// previous iteration on the same buffer.
seq.emplace_back(SyncThread("shared"));
for (size_t idx = 0; idx < size; ++idx) {
shared_bufs[idx] = decl_buffer({1}, types[idx], "red_buf" + std::to_string(idx), "shared");
shared_bufs[idx] = decl_buffer({IntImm(group_index->dtype, group_extent * reduce_extent)},
types[idx], "red_buf" + std::to_string(idx), "shared");
seq.emplace_back(BufferStore(shared_bufs[idx], values[idx],
{BufIndex(reduce_index, group_index, reduce_extent)}));
}
Expand All @@ -414,9 +420,7 @@ class ThreadAllreduceBuilder final : public StmtExprMutator {
{BufIndex(make_zero(reduce_index.dtype()), group_index, reduce_extent)});
ICHECK_EQ(load->dtype, types[idx]);
load_remap_[buffers[idx]->data.get()] = load;
alloc_remap_[buffers[idx]->data.get()] =
Allocate(shared_bufs[idx]->data, types[idx],
{PrimExpr(group_extent), PrimExpr(reduce_extent)}, pred, Evaluate(0));
alloc_remap_[buffers[idx]->data.get()] = shared_bufs[idx];
var_remap_[buffers[idx]->data.get()] = shared_bufs[idx]->data;
buf_remap_[buffers[idx].get()] = shared_bufs[idx];
}
Expand Down Expand Up @@ -772,7 +776,7 @@ class ThreadAllreduceBuilder final : public StmtExprMutator {
// The load remap
std::unordered_map<const VarNode*, PrimExpr> load_remap_;
// Allocate remap
std::unordered_map<const VarNode*, Stmt> alloc_remap_;
std::unordered_map<const VarNode*, Buffer> alloc_remap_;
// BufferVar remap
std::unordered_map<const VarNode*, Var> var_remap_;
// Buffer remap
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -95,7 +95,9 @@ def check_max(d1: int, d2: int, d3: int):

for d1 in range(1, 5):
for d2 in range(1, 5):
for d3 in range(2, 33):
for d3 in [2, 4, 8, 12, 16, 32, 48, 64, 100, 128, 201, 256, 512, 1024]:
if d1 * d2 * d3 > 1024:
continue
check_sum(d1, d2, d3)
check_max(d1, d2, d3)

Expand Down
24 changes: 14 additions & 10 deletions tests/python/unittest/test_tir_transform_lower_thread_all_reduce.py
Original file line number Diff line number Diff line change
Expand Up @@ -387,6 +387,7 @@ def expected(A: T.Buffer((128, 128), "float32"), B: T.Buffer((128,), "float32"))
for i in range(128):
threadIdx_x = T.launch_thread("threadIdx.x", 128)
red_result = T.allocate([1], "float32", "shared")
T.attr(red_result, "volatile_scope", 1)
red_result_1 = T.Buffer((1,), data=red_result, scope="shared")
with T.attr(
T.comm_reducer(lambda x0, y0: x0 + y0, [T.float32(0)]),
Expand Down Expand Up @@ -463,6 +464,7 @@ def expected(A: T.Buffer((1, 1024), "float32"), B: T.Buffer((1,), "float32")):
T.func_attr({"target": T.target("cuda", host="llvm")})
threadIdx_x = T.launch_thread("threadIdx.x", 1024)
red_result = T.allocate([1], "float32", "shared")
T.attr(red_result, "volatile_scope", 1)
red_result_1 = T.Buffer((1,), data=red_result, scope="shared")
with T.attr(
T.comm_reducer(lambda x0, y0: x0 + y0, [T.float32(0)]),
Expand Down Expand Up @@ -550,6 +552,7 @@ def expected(A: T.Buffer((4, 128), "float32"), B: T.Buffer((4,), "float32")):
T.func_attr({"target": T.target("cuda", host="llvm")})
threadIdx_y = T.launch_thread("threadIdx.y", 4)
red_result = T.allocate([4], "float32", "shared")
T.attr(red_result, "volatile_scope", 1)
threadIdx_x = T.launch_thread("threadIdx.x", 128)
red_result_1 = T.Buffer((4,), data=red_result, scope="shared")
with T.attr(
Expand Down Expand Up @@ -585,23 +588,23 @@ def expected(A: T.Buffer((4, 128), "float32"), B: T.Buffer((4,), "float32")):
red_buf_staging_1[threadIdx_y * 4 + threadIdx_x // 32] = red_buf0_2[0]
T.tvm_storage_sync("shared")
red_buf0_3 = T.Buffer((1,), data=red_buf0, scope="local")
if threadIdx_x < 16:
red_buf0_3[0] = red_buf_staging_1[threadIdx_x]
if threadIdx_x < 4:
red_buf0_3[0] = red_buf_staging_1[threadIdx_y * 4 + threadIdx_x]
mask_3 = T.Buffer((1,), "uint32", data=mask, scope="local")
mask_3[0] = T.bitwise_and(
T.tvm_warp_activemask(), T.Cast("uint32", T.shift_left(15, threadIdx_y))
T.tvm_warp_activemask(), T.Cast("uint32", T.shift_left(15, threadIdx_y * 4))
)
t0_3 = T.Buffer((1,), data=t0, scope="local")
t0_3[0] = T.tvm_warp_shuffle_down(mask_3[0], red_buf0_3[0], 2, 32, 32)
red_buf0_3[0] = red_buf0_3[0] + t0_3[0]
t0_3[0] = T.tvm_warp_shuffle_down(mask_3[0], red_buf0_3[0], 1, 32, 32)
red_buf0_3[0] = red_buf0_3[0] + t0_3[0]
if threadIdx_x == 0:
red_result_1[0] = red_buf0_3[0]
red_result_1[threadIdx_y] = red_buf0_3[0]
T.tvm_storage_sync("shared")
if threadIdx_x == 0:
B_1 = T.Buffer((4,), data=B.data)
B_1[threadIdx_y] = red_result_1[0]
B_1[threadIdx_y] = red_result_1[threadIdx_y]


class TestMultiGroupMultiWarpPredicatedReduction(BaseCompare):
Expand Down Expand Up @@ -636,6 +639,7 @@ def expected(A: T.Buffer((2, 70), "float32"), B: T.Buffer((2,), "float32")):
threadIdx_y = T.launch_thread("threadIdx.y", 2)
in_thread_B = T.allocate([1], "float32", "local")
red_result = T.allocate([2], "float32", "shared")
T.attr(red_result, "volatile_scope", 1)
threadIdx_x = T.launch_thread("threadIdx.x", 512)
in_thread_B_1 = T.Buffer((1,), data=in_thread_B, scope="local")
in_thread_B_1[0] = T.float32(0)
Expand Down Expand Up @@ -675,11 +679,11 @@ def expected(A: T.Buffer((2, 70), "float32"), B: T.Buffer((2,), "float32")):
red_buf_staging_1[threadIdx_y * 16 + threadIdx_x // 32] = red_buf0_2[0]
T.tvm_storage_sync("shared")
red_buf0_3 = T.Buffer((1,), data=red_buf0, scope="local")
if threadIdx_x < 32:
red_buf0_3[0] = red_buf_staging_1[threadIdx_x]
if threadIdx_x < 16:
red_buf0_3[0] = red_buf_staging_1[threadIdx_y * 16 + threadIdx_x]
mask_3 = T.Buffer((1,), "uint32", data=mask, scope="local")
mask_3[0] = T.bitwise_and(
T.tvm_warp_activemask(), T.Cast("uint32", T.shift_left(65535, threadIdx_y))
T.tvm_warp_activemask(), T.Cast("uint32", T.shift_left(65535, threadIdx_y * 16))
)
t0_3 = T.Buffer((1,), data=t0, scope="local")
t0_3[0] = T.tvm_warp_shuffle_down(mask_3[0], red_buf0_3[0], 8, 32, 32)
Expand All @@ -691,11 +695,11 @@ def expected(A: T.Buffer((2, 70), "float32"), B: T.Buffer((2,), "float32")):
t0_3[0] = T.tvm_warp_shuffle_down(mask_3[0], red_buf0_3[0], 1, 32, 32)
red_buf0_3[0] = red_buf0_3[0] + t0_3[0]
if threadIdx_x == 0:
red_result_1[0] = red_buf0_3[0]
red_result_1[threadIdx_y] = red_buf0_3[0]
T.tvm_storage_sync("shared")
if threadIdx_x == 0:
B_1 = T.Buffer((2,), data=B.data)
B_1[threadIdx_y] = red_result_1[0]
B_1[threadIdx_y] = red_result_1[threadIdx_y]


if __name__ == "__main__":
Expand Down

0 comments on commit d59c1a8

Please sign in to comment.