Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[BugFix][TIR] Fix multi-grouped multi-warp allreduce #15399

Merged
merged 1 commit into from
Jul 25, 2023
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
38 changes: 21 additions & 17 deletions src/tir/transforms/lower_thread_allreduce.cc
Original file line number Diff line number Diff line change
@@ -76,12 +76,17 @@ class ThreadAllreduceBuilder final : public StmtExprMutator {
auto node = Downcast<Allocate>(StmtExprMutator::VisitStmt_(op));

if (auto it = alloc_remap_.find(node->buffer_var.get()); it != alloc_remap_.end()) {
const AllocateNode* repl = it->second.as<AllocateNode>();
Buffer buf = Downcast<Buffer>(it->second);
auto write_ptr = node.CopyOnWrite();
write_ptr->buffer_var = repl->buffer_var;
write_ptr->dtype = repl->dtype;
write_ptr->extents = repl->extents;
write_ptr->condition = repl->condition;
write_ptr->buffer_var = buf->data;
write_ptr->dtype = buf->dtype;
write_ptr->extents = buf->shape;
write_ptr->condition = const_true(buf->dtype.lanes());

if (buf.scope() == "shared") {
// Use volatile access to shared buffer.
write_ptr->body = AttrStmt(buf->data, attr::volatile_scope, 1, write_ptr->body);
}
}
return std::move(node);
}
@@ -344,15 +349,15 @@ class ThreadAllreduceBuilder final : public StmtExprMutator {
// 4. Load staging buffer.
// Second round of allreduce.
for (size_t i = 0; i < size; ++i) {
values[i] = BufferLoad(/*buffer=*/staging_shared_bufs[i], /*indices=*/{reduce_index});
values[i] = BufferLoad(/*buffer=*/staging_shared_bufs[i],
/*indices=*/{group_index * n_warps + reduce_index});
}
if (n_warps < warp_size_) {
mask = mask & (((1 << n_warps) - 1) << group_index);
mask = mask & (((1 << n_warps) - 1) << (group_index * n_warps));
}
std::tie(reduce_results, local_bufs) = MakeWarpAllreduce(
values, types, combiner, reduce_index, n_warps, group_index, mask,
/*predicate=*/reduce_index < make_const(reduce_index->dtype, group_extent * n_warps),
&seq);
/*predicate=*/reduce_index < make_const(reduce_index->dtype, n_warps), &seq);
new_alloc_bufs.insert(new_alloc_bufs.end(), local_bufs.begin(), local_bufs.end());

// 5. Create shared memory buffer(s) of `group_extent` elements, storing
@@ -365,9 +370,9 @@ class ThreadAllreduceBuilder final : public StmtExprMutator {
/*shape=*/{make_const(reduce_index->dtype, group_extent)},
/*dtype=*/buffers[i]->dtype, /*name=*/"red_result", /*storage_scope=*/"shared");
write_result.push_back(
BufferStore(broadcast_shared_buf, reduce_results[i], {zero_index}));
BufferStore(broadcast_shared_buf, reduce_results[i], {group_index}));
// Update `reduce_results`, pointing to the value loaded from the shared memory buffer.
reduce_results[i] = BufferLoad(broadcast_shared_buf, {zero_index});
reduce_results[i] = BufferLoad(broadcast_shared_buf, {group_index});
}
seq.push_back(IfThenElse(reduce_index == zero_index, SeqStmt::Flatten(write_result)));
seq.push_back(SyncThread("shared"));
@@ -382,7 +387,7 @@ class ThreadAllreduceBuilder final : public StmtExprMutator {
load_remap_[buffers[i]->data.get()] = reduce_results[i];

auto node = Allocate(buf->data, types[i], buf->shape, pred, Evaluate(0));
alloc_remap_[buffers[i]->data.get()] = node;
alloc_remap_[buffers[i]->data.get()] = buf;
var_remap_[buffers[i]->data.get()] = buf->data;
buf_remap_[buffers[i].get()] = buf;
}
@@ -400,7 +405,8 @@ class ThreadAllreduceBuilder final : public StmtExprMutator {
// previous iteration on the same buffer.
seq.emplace_back(SyncThread("shared"));
for (size_t idx = 0; idx < size; ++idx) {
shared_bufs[idx] = decl_buffer({1}, types[idx], "red_buf" + std::to_string(idx), "shared");
shared_bufs[idx] = decl_buffer({IntImm(group_index->dtype, group_extent * reduce_extent)},
types[idx], "red_buf" + std::to_string(idx), "shared");
seq.emplace_back(BufferStore(shared_bufs[idx], values[idx],
{BufIndex(reduce_index, group_index, reduce_extent)}));
}
@@ -414,9 +420,7 @@ class ThreadAllreduceBuilder final : public StmtExprMutator {
{BufIndex(make_zero(reduce_index.dtype()), group_index, reduce_extent)});
ICHECK_EQ(load->dtype, types[idx]);
load_remap_[buffers[idx]->data.get()] = load;
alloc_remap_[buffers[idx]->data.get()] =
Allocate(shared_bufs[idx]->data, types[idx],
{PrimExpr(group_extent), PrimExpr(reduce_extent)}, pred, Evaluate(0));
alloc_remap_[buffers[idx]->data.get()] = shared_bufs[idx];
var_remap_[buffers[idx]->data.get()] = shared_bufs[idx]->data;
buf_remap_[buffers[idx].get()] = shared_bufs[idx];
}
@@ -772,7 +776,7 @@ class ThreadAllreduceBuilder final : public StmtExprMutator {
// The load remap
std::unordered_map<const VarNode*, PrimExpr> load_remap_;
// Allocate remap
std::unordered_map<const VarNode*, Stmt> alloc_remap_;
std::unordered_map<const VarNode*, Buffer> alloc_remap_;
// BufferVar remap
std::unordered_map<const VarNode*, Var> var_remap_;
// Buffer remap
Original file line number Diff line number Diff line change
@@ -48,7 +48,7 @@ def reduce_max(a: T.handle, b: T.handle, d1: T.int32, d2: T.int32, d3: T.int32)

@tvm.testing.requires_gpu
@tvm.testing.requires_cuda
def test_cuda_subwarp_reduction():
def test_allreduce_cuda():
def check_sum(d1: int, d2: int, d3: int):
_, _, _d1, _d2, _d3 = reduce.params
mod = reduce.specialize({_d1: d1, _d2: d2, _d3: d3})
@@ -95,10 +95,12 @@ def check_max(d1: int, d2: int, d3: int):

for d1 in range(1, 5):
for d2 in range(1, 5):
for d3 in range(2, 33):
for d3 in [2, 4, 8, 12, 16, 32, 48, 64, 100, 128, 201, 256, 512, 1024]:
if d1 * d2 * d3 > 1024:
continue
check_sum(d1, d2, d3)
check_max(d1, d2, d3)


if __name__ == "__main__":
test_cuda_subwarp_reduction()
test_allreduce_cuda()
24 changes: 14 additions & 10 deletions tests/python/unittest/test_tir_transform_lower_thread_all_reduce.py
Original file line number Diff line number Diff line change
@@ -387,6 +387,7 @@ def expected(A: T.Buffer((128, 128), "float32"), B: T.Buffer((128,), "float32"))
for i in range(128):
threadIdx_x = T.launch_thread("threadIdx.x", 128)
red_result = T.allocate([1], "float32", "shared")
T.attr(red_result, "volatile_scope", 1)
red_result_1 = T.Buffer((1,), data=red_result, scope="shared")
with T.attr(
T.comm_reducer(lambda x0, y0: x0 + y0, [T.float32(0)]),
@@ -463,6 +464,7 @@ def expected(A: T.Buffer((1, 1024), "float32"), B: T.Buffer((1,), "float32")):
T.func_attr({"target": T.target("cuda", host="llvm")})
threadIdx_x = T.launch_thread("threadIdx.x", 1024)
red_result = T.allocate([1], "float32", "shared")
T.attr(red_result, "volatile_scope", 1)
red_result_1 = T.Buffer((1,), data=red_result, scope="shared")
with T.attr(
T.comm_reducer(lambda x0, y0: x0 + y0, [T.float32(0)]),
@@ -550,6 +552,7 @@ def expected(A: T.Buffer((4, 128), "float32"), B: T.Buffer((4,), "float32")):
T.func_attr({"target": T.target("cuda", host="llvm")})
threadIdx_y = T.launch_thread("threadIdx.y", 4)
red_result = T.allocate([4], "float32", "shared")
T.attr(red_result, "volatile_scope", 1)
threadIdx_x = T.launch_thread("threadIdx.x", 128)
red_result_1 = T.Buffer((4,), data=red_result, scope="shared")
with T.attr(
@@ -585,23 +588,23 @@ def expected(A: T.Buffer((4, 128), "float32"), B: T.Buffer((4,), "float32")):
red_buf_staging_1[threadIdx_y * 4 + threadIdx_x // 32] = red_buf0_2[0]
T.tvm_storage_sync("shared")
red_buf0_3 = T.Buffer((1,), data=red_buf0, scope="local")
if threadIdx_x < 16:
red_buf0_3[0] = red_buf_staging_1[threadIdx_x]
if threadIdx_x < 4:
red_buf0_3[0] = red_buf_staging_1[threadIdx_y * 4 + threadIdx_x]
mask_3 = T.Buffer((1,), "uint32", data=mask, scope="local")
mask_3[0] = T.bitwise_and(
T.tvm_warp_activemask(), T.Cast("uint32", T.shift_left(15, threadIdx_y))
T.tvm_warp_activemask(), T.Cast("uint32", T.shift_left(15, threadIdx_y * 4))
)
t0_3 = T.Buffer((1,), data=t0, scope="local")
t0_3[0] = T.tvm_warp_shuffle_down(mask_3[0], red_buf0_3[0], 2, 32, 32)
red_buf0_3[0] = red_buf0_3[0] + t0_3[0]
t0_3[0] = T.tvm_warp_shuffle_down(mask_3[0], red_buf0_3[0], 1, 32, 32)
red_buf0_3[0] = red_buf0_3[0] + t0_3[0]
if threadIdx_x == 0:
red_result_1[0] = red_buf0_3[0]
red_result_1[threadIdx_y] = red_buf0_3[0]
T.tvm_storage_sync("shared")
if threadIdx_x == 0:
B_1 = T.Buffer((4,), data=B.data)
B_1[threadIdx_y] = red_result_1[0]
B_1[threadIdx_y] = red_result_1[threadIdx_y]


class TestMultiGroupMultiWarpPredicatedReduction(BaseCompare):
@@ -636,6 +639,7 @@ def expected(A: T.Buffer((2, 70), "float32"), B: T.Buffer((2,), "float32")):
threadIdx_y = T.launch_thread("threadIdx.y", 2)
in_thread_B = T.allocate([1], "float32", "local")
red_result = T.allocate([2], "float32", "shared")
T.attr(red_result, "volatile_scope", 1)
threadIdx_x = T.launch_thread("threadIdx.x", 512)
in_thread_B_1 = T.Buffer((1,), data=in_thread_B, scope="local")
in_thread_B_1[0] = T.float32(0)
@@ -675,11 +679,11 @@ def expected(A: T.Buffer((2, 70), "float32"), B: T.Buffer((2,), "float32")):
red_buf_staging_1[threadIdx_y * 16 + threadIdx_x // 32] = red_buf0_2[0]
T.tvm_storage_sync("shared")
red_buf0_3 = T.Buffer((1,), data=red_buf0, scope="local")
if threadIdx_x < 32:
red_buf0_3[0] = red_buf_staging_1[threadIdx_x]
if threadIdx_x < 16:
red_buf0_3[0] = red_buf_staging_1[threadIdx_y * 16 + threadIdx_x]
mask_3 = T.Buffer((1,), "uint32", data=mask, scope="local")
mask_3[0] = T.bitwise_and(
T.tvm_warp_activemask(), T.Cast("uint32", T.shift_left(65535, threadIdx_y))
T.tvm_warp_activemask(), T.Cast("uint32", T.shift_left(65535, threadIdx_y * 16))
)
t0_3 = T.Buffer((1,), data=t0, scope="local")
t0_3[0] = T.tvm_warp_shuffle_down(mask_3[0], red_buf0_3[0], 8, 32, 32)
@@ -691,11 +695,11 @@ def expected(A: T.Buffer((2, 70), "float32"), B: T.Buffer((2,), "float32")):
t0_3[0] = T.tvm_warp_shuffle_down(mask_3[0], red_buf0_3[0], 1, 32, 32)
red_buf0_3[0] = red_buf0_3[0] + t0_3[0]
if threadIdx_x == 0:
red_result_1[0] = red_buf0_3[0]
red_result_1[threadIdx_y] = red_buf0_3[0]
T.tvm_storage_sync("shared")
if threadIdx_x == 0:
B_1 = T.Buffer((2,), data=B.data)
B_1[threadIdx_y] = red_result_1[0]
B_1[threadIdx_y] = red_result_1[threadIdx_y]


if __name__ == "__main__":