[TIR] Allreduce broadcast result to each thread in multi-warp case
PR apache#15327 introduced warp-level primitive support for multi-warp
allreduce. However, because of the two-stage shuffle-down reduction
used to implement allreduce in multi-warp scenarios, PR apache#15327
did not broadcast the allreduce result to each reduction thread. This
behavior does not align with the semantics of allreduce and is not
ideal for many use cases. Therefore, this PR completes the
implementation by inserting a stage that writes the reduction results
to shared memory, so that every reduction thread across all the
reduction warps can access them.

The shared-memory write-back stage is inserted only in the multi-warp
allreduce case. In single-warp allreduce, a `shfl_sync` is used to
broadcast the reduction results across the reduction threads. Since
warp-level primitives cannot broadcast a value across warps, shared
memory is the only option in the multi-warp setting.

Numerical correctness has been verified locally.
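
For intuition, below is a minimal NumPy sketch (not TVM code) of the two-stage multi-warp allreduce and the shared-memory write-back stage this commit adds. The thread/warp counts and the names `red_buf_staging` and `red_result` are illustrative assumptions mirroring the test cases in this diff.

```python
import numpy as np

WARP_SIZE = 32          # assumed warp size
N_THREADS = 128         # 4 reduction warps, as in TestMultiWarpReduce1
N_WARPS = N_THREADS // WARP_SIZE

values = np.random.rand(N_THREADS).astype("float32")

# Stage 1: shuffle-down reduction within each warp; lane 0 of every warp
# ends up holding that warp's partial sum.
warp_partials = values.reshape(N_WARPS, WARP_SIZE).sum(axis=1)

# Lane 0 of each warp writes its partial into a shared staging buffer.
red_buf_staging = warp_partials.copy()

# Stage 2: the first N_WARPS lanes of warp 0 reduce the staged partials.
stage2_result = red_buf_staging.sum()

# Write-back stage added by this commit: thread 0 stores the final result
# into a shared buffer, a barrier follows, and then every reduction thread
# loads the same value (no warp primitive can broadcast across warps).
red_result = np.empty(1, dtype="float32")
red_result[0] = stage2_result                      # if threadIdx.x == 0
# T.tvm_storage_sync("shared") in the lowered TIR
per_thread_result = np.full(N_THREADS, red_result[0], dtype="float32")

np.testing.assert_allclose(per_thread_result, values.sum(), rtol=1e-5)
```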
MasterJH5574 committed Jul 21, 2023
1 parent 03fecba commit cd1bcc4
Showing 2 changed files with 69 additions and 75 deletions.
92 changes: 37 additions & 55 deletions src/tir/transforms/lower_thread_allreduce.cc
@@ -38,27 +38,6 @@
namespace tvm {
namespace tir {

class UpdatePointerStorageScopeAllReduce final : public UpdatePointerStorageScope {
public:
explicit UpdatePointerStorageScopeAllReduce(
const std::unordered_map<const VarNode*, String>& new_storage_scopes)
: UpdatePointerStorageScope(new_storage_scopes) {}

Stmt VisitStmt_(const AllocateNode* op) final {
auto remapped = Downcast<Var>(StmtExprMutator::VisitExpr(op->buffer_var));
auto new_scope = GetPtrStorageScope(remapped);
if (new_scope != GetPtrStorageScope(op->buffer_var)) {
Stmt body = StmtExprMutator::VisitStmt(op->body);
if (new_scope == "shared") {
// use volatile access to shared buffer.
body = AttrStmt(remapped, attr::volatile_scope, 1, body);
}
return Allocate(remapped, op->dtype, op->extents, op->condition, body, op->annotations);
}
return StmtExprMutator::VisitStmt_(op);
}
};

class ThreadAllreduceBuilder final : public StmtExprMutator {
public:
explicit ThreadAllreduceBuilder(const TargetNode* target)
@@ -98,11 +77,6 @@ class ThreadAllreduceBuilder final : public StmtExprMutator {

if (auto it = alloc_remap_.find(node->buffer_var.get()); it != alloc_remap_.end()) {
const AllocateNode* repl = it->second.as<AllocateNode>();
if (warp_allocs_.count(repl)) {
new_storage_scopes_[repl->buffer_var.get()] = "local";
} else {
new_storage_scopes_[repl->buffer_var.get()] = "shared";
}
auto write_ptr = node.CopyOnWrite();
write_ptr->buffer_var = repl->buffer_var;
write_ptr->dtype = repl->dtype;
@@ -161,8 +135,6 @@ class ThreadAllreduceBuilder final : public StmtExprMutator {
return std::move(store);
}

std::unordered_map<const VarNode*, String> new_storage_scopes_;

private:
// Thread entry
struct ThreadEntry {
@@ -310,6 +282,7 @@ class ThreadAllreduceBuilder final : public StmtExprMutator {
// In the second stage we use the first 16 lanes of the first warp to reduce
// the remaining elements, and this reduction can also be optimized by
// shuffle_down warp-level primitives.
PrimExpr zero_index = make_const(reduce_index->dtype, 0);
if (IsWarpReduction(types, group_extent, reduce_extent, contiguous_reduce_extent)) {
std::vector<PrimExpr> reduce_results;
DataType mask_dtype = DataType::UInt(32);
@@ -322,6 +295,18 @@
}
std::tie(reduce_results, new_alloc_bufs) = MakeWarpAllreduce(
values, types, combiner, reduce_index, reduce_extent, group_index, mask, NullOpt, &seq);

// Broadcast the reduction result from lane 0 to all other lanes.
// This avoids emitting predicated stores, as all threads are
// uniformly writing the same result.
for (int i = 0; i < size; ++i) {
Buffer buf = Downcast<BufferLoad>(reduce_results[i])->buffer;
PrimExpr val = BufferLoad(buf, {zero_index});
ICHECK_EQ(val->dtype, types[i]);
PrimExpr splat = WarpShuffle(builtin::tvm_warp_shuffle(), new_alloc_bufs.back(), val,
reduce_extent * group_index);
seq.push_back(BufferStore(buf, splat, {zero_index}));
}
} else {
int n_warps = reduce_extent / warp_size_;
std::vector<Buffer> local_bufs;
@@ -352,7 +337,7 @@ class ThreadAllreduceBuilder final : public StmtExprMutator {
/*value=*/reduce_results[i],
/*indices=*/{group_index * n_warps + floordiv(reduce_index, warp_size_)}));
}
PrimExpr cond = floormod(reduce_index, warp_size_) == make_const(reduce_index->dtype, 0);
PrimExpr cond = floormod(reduce_index, warp_size_) == zero_index;
seq.push_back(IfThenElse(cond, SeqStmt::Flatten(write_staging_buf)));
seq.push_back(SyncThread("shared"));

@@ -369,6 +354,23 @@
/*predicate=*/reduce_index < make_const(reduce_index->dtype, group_extent * n_warps),
&seq);
new_alloc_bufs.insert(new_alloc_bufs.end(), local_bufs.begin(), local_bufs.end());

// 5. Create shared memory buffer(s) of `group_extent` elements, storing
// the allreduce results so that each thread can access them.
std::vector<Stmt> write_result;
write_result.reserve(size);
for (size_t i = 0; i < size; ++i) {
new_alloc_bufs.push_back(Downcast<BufferLoad>(reduce_results[i])->buffer);
Buffer broadcast_shared_buf = decl_buffer(
/*shape=*/{make_const(reduce_index->dtype, group_extent)},
/*dtype=*/buffers[i]->dtype, /*name=*/"red_result", /*storage_scope=*/"shared");
write_result.push_back(
BufferStore(broadcast_shared_buf, reduce_results[i], {zero_index}));
// Update `reduce_results`, pointing to the value loaded from the shared memory buffer.
reduce_results[i] = BufferLoad(broadcast_shared_buf, {zero_index});
}
seq.push_back(IfThenElse(reduce_index == zero_index, SeqStmt::Flatten(write_result)));
seq.push_back(SyncThread("shared"));
}

// Write back allreduce results and update existing allocations.
@@ -379,12 +381,10 @@
ICHECK_EQ(reduce_results[i]->dtype, types[i]);
load_remap_[buffers[i]->data.get()] = reduce_results[i];

Array<PrimExpr> extents{PrimExpr(1)};
auto node = Allocate(buf->data, types[i], extents, pred, Evaluate(0));
auto node = Allocate(buf->data, types[i], buf->shape, pred, Evaluate(0));
alloc_remap_[buffers[i]->data.get()] = node;
var_remap_[buffers[i]->data.get()] = buf->data;
buf_remap_[buffers[i].get()] = buf;
warp_allocs_.insert(node.get());
}
} else {
std::vector<Buffer> shared_bufs(size);
@@ -426,9 +426,6 @@ class ThreadAllreduceBuilder final : public StmtExprMutator {
Stmt body = SeqStmt::Flatten(seq);
for (Buffer buf : new_alloc_bufs) {
body = Allocate(buf->data, buf->dtype, buf->shape, const_true(buf->dtype.lanes()), body);
if (buf.scope() != "shared") {
new_storage_scopes_[buf->data.get()] = "local";
}
}

return body;
@@ -457,12 +454,13 @@ class ThreadAllreduceBuilder final : public StmtExprMutator {
std::vector<Stmt> load_values;
load_values.reserve(n_buffers);
for (int idx = 0; idx < n_buffers; ++idx) {
shared_bufs.push_back(decl_buffer(shape, dtypes[idx], "red_buf" + std::to_string(idx)));
shared_bufs.push_back(
decl_buffer(shape, dtypes[idx], "red_buf" + std::to_string(idx), "local"));
load_values.push_back(BufferStore(shared_bufs[idx], src_values[idx], zero_indices));

// Uses a local variable to store the shuffled data. Later
// on, an allocation will be built for this local variable.
local_bufs.push_back(decl_buffer(shape, dtypes[idx], "t" + std::to_string(idx)));
local_bufs.push_back(decl_buffer(shape, dtypes[idx], "t" + std::to_string(idx), "local"));
}

if (predicate.defined()) {
@@ -474,7 +472,7 @@ class ThreadAllreduceBuilder final : public StmtExprMutator {
// The mask for this reducer, as this reducer may sit inside
// a divergent control flow. Here it uses a variable to cache the current
// active channels.
Buffer mask_buffer = decl_buffer(shape, mask->dtype, "mask");
Buffer mask_buffer = decl_buffer(shape, mask->dtype, "mask", "local");
{
seq->emplace_back(BufferStore(mask_buffer, mask, zero_indices));
// Push the buffer description. Later this will have an
@@ -543,18 +541,6 @@ class ThreadAllreduceBuilder final : public StmtExprMutator {
}
}

// Broadcast the reduction result from lane 0 to all other lanes.
// This avoids to emit predicated stores, as all threads are
// uniformly writing the same result.
for (int i = 0; i < n_buffers; ++i) {
Buffer buf = shared_bufs[i];
PrimExpr val = BufferLoad(buf, zero_indices);
ICHECK_EQ(val->dtype, dtypes[i]);
PrimExpr splat =
WarpShuffle(builtin::tvm_warp_shuffle(), mask_buffer, val, reduce_extent * group_index);
seq->push_back(BufferStore(buf, splat, zero_indices));
}

std::vector<PrimExpr> reduce_results;
reduce_results.reserve(n_buffers);
for (int i = 0; i < n_buffers; ++i) {
@@ -791,8 +777,6 @@ class ThreadAllreduceBuilder final : public StmtExprMutator {
std::unordered_map<const VarNode*, Var> var_remap_;
// Buffer remap
std::unordered_map<const BufferNode*, Buffer> buf_remap_;
// Allocate from warp reductions
std::unordered_set<const void*> warp_allocs_;
// Internal analyzer
arith::Analyzer analyzer_;
};
@@ -806,9 +790,7 @@ Pass LowerThreadAllreduce() {
ICHECK(target.defined()) << "LowerThreadAllreduce: Require the target attribute";
const TargetNode* target_node = target.as<TargetNode>();
ThreadAllreduceBuilder thread_all_reduce(target_node);
auto reduce_body = thread_all_reduce(n->body);
n->body =
UpdatePointerStorageScopeAllReduce(thread_all_reduce.new_storage_scopes_)(reduce_body);
n->body = thread_all_reduce(n->body);
return f;
};
return CreatePrimFuncPass(pass_func, 0, "tir.LowerThreadAllreduce", {});
52 changes: 32 additions & 20 deletions tests/python/unittest/test_tir_transform_lower_thread_all_reduce.py
@@ -386,13 +386,14 @@ def expected(A: T.Buffer((128, 128), "float32"), B: T.Buffer((128,), "float32"))
T.func_attr({"target": T.target("cuda", host="llvm")})
for i in range(128):
threadIdx_x = T.launch_thread("threadIdx.x", 128)
red_buf0 = T.allocate([1], "float32", "local")
red_buf0_3 = T.Buffer((1,), data=red_buf0, scope="local")
red_result = T.allocate([1], "float32", "shared")
red_result_1 = T.Buffer((1,), data=red_result, scope="shared")
with T.attr(
T.comm_reducer(lambda x0, y0: x0 + y0, [T.float32(0)]),
"reduce_scope",
T.reinterpret("handle", T.uint64(0)),
):
red_buf0 = T.allocate([1], "float32", "local")
mask = T.allocate([1], "uint32", "local")
t0 = T.allocate([1], "float32", "local")
red_buf0_1 = T.allocate([1], "float32", "local")
@@ -415,11 +416,11 @@ def expected(A: T.Buffer((128, 128), "float32"), B: T.Buffer((128,), "float32"))
red_buf0_2[0] = red_buf0_2[0] + t0_2[0]
t0_2[0] = T.tvm_warp_shuffle_down(mask_2[0], red_buf0_2[0], 1, 32, 32)
red_buf0_2[0] = red_buf0_2[0] + t0_2[0]
red_buf0_2[0] = T.tvm_warp_shuffle(mask_2[0], red_buf0_2[0], 0, 32, 32)
red_buf_staging_1 = T.Buffer((4,), data=red_buf_staging, scope="shared")
if threadIdx_x % 32 == 0:
red_buf_staging_1[threadIdx_x // 32] = red_buf0_2[0]
T.tvm_storage_sync("shared")
red_buf0_3 = T.Buffer((1,), data=red_buf0, scope="local")
if threadIdx_x < 4:
red_buf0_3[0] = red_buf_staging_1[threadIdx_x]
mask_3 = T.Buffer((1,), "uint32", data=mask, scope="local")
@@ -429,10 +430,12 @@ def expected(A: T.Buffer((128, 128), "float32"), B: T.Buffer((128,), "float32"))
red_buf0_3[0] = red_buf0_3[0] + t0_3[0]
t0_3[0] = T.tvm_warp_shuffle_down(mask_3[0], red_buf0_3[0], 1, 32, 32)
red_buf0_3[0] = red_buf0_3[0] + t0_3[0]
red_buf0_3[0] = T.tvm_warp_shuffle(mask_3[0], red_buf0_3[0], 0, 32, 32)
if threadIdx_x == 0:
red_result_1[0] = red_buf0_3[0]
T.tvm_storage_sync("shared")
if threadIdx_x == 0:
B_1 = T.Buffer((128,), data=B.data)
B_1[i] = red_buf0_3[0]
B_1[i] = red_result_1[0]


class TestMultiWarpReduce2(BaseCompare):
Expand All @@ -459,13 +462,14 @@ def before(A: T.Buffer((1, 1024), "float32"), B: T.Buffer((1,), "float32")):
def expected(A: T.Buffer((1, 1024), "float32"), B: T.Buffer((1,), "float32")):
T.func_attr({"target": T.target("cuda", host="llvm")})
threadIdx_x = T.launch_thread("threadIdx.x", 1024)
red_buf0 = T.allocate([1], "float32", "local")
red_buf0_3 = T.Buffer((1,), data=red_buf0, scope="local")
red_result = T.allocate([1], "float32", "shared")
red_result_1 = T.Buffer((1,), data=red_result, scope="shared")
with T.attr(
T.comm_reducer(lambda x0, y0: x0 + y0, [T.float32(0)]),
"reduce_scope",
T.reinterpret("handle", T.uint64(0)),
):
red_buf0 = T.allocate([1], "float32", "local")
mask = T.allocate([1], "uint32", "local")
t0 = T.allocate([1], "float32", "local")
red_buf0_1 = T.allocate([1], "float32", "local")
@@ -488,11 +492,11 @@ def expected(A: T.Buffer((1, 1024), "float32"), B: T.Buffer((1,), "float32")):
red_buf0_2[0] = red_buf0_2[0] + t0_2[0]
t0_2[0] = T.tvm_warp_shuffle_down(mask_2[0], red_buf0_2[0], 1, 32, 32)
red_buf0_2[0] = red_buf0_2[0] + t0_2[0]
red_buf0_2[0] = T.tvm_warp_shuffle(mask_2[0], red_buf0_2[0], 0, 32, 32)
red_buf_staging_1 = T.Buffer((32,), data=red_buf_staging, scope="shared")
if threadIdx_x % 32 == 0:
red_buf_staging_1[threadIdx_x // 32] = red_buf0_2[0]
T.tvm_storage_sync("shared")
red_buf0_3 = T.Buffer((1,), data=red_buf0, scope="local")
if threadIdx_x < 32:
red_buf0_3[0] = red_buf_staging_1[threadIdx_x]
mask_3 = T.Buffer((1,), "uint32", data=mask, scope="local")
@@ -508,10 +512,12 @@ def expected(A: T.Buffer((1, 1024), "float32"), B: T.Buffer((1,), "float32")):
red_buf0_3[0] = red_buf0_3[0] + t0_3[0]
t0_3[0] = T.tvm_warp_shuffle_down(mask_3[0], red_buf0_3[0], 1, 32, 32)
red_buf0_3[0] = red_buf0_3[0] + t0_3[0]
red_buf0_3[0] = T.tvm_warp_shuffle(mask_3[0], red_buf0_3[0], 0, 32, 32)
if threadIdx_x == 0:
red_result_1[0] = red_buf0_3[0]
T.tvm_storage_sync("shared")
if threadIdx_x == 0:
B_1 = T.Buffer((1,), data=B.data)
B_1[0] = red_buf0_3[0]
B_1[0] = red_result_1[0]


class TestMultiGroupMultiWarpReduction(BaseCompare):
@@ -543,14 +549,15 @@ def before(A: T.Buffer((4, 128), "float32"), B: T.Buffer((4,), "float32")):
def expected(A: T.Buffer((4, 128), "float32"), B: T.Buffer((4,), "float32")):
T.func_attr({"target": T.target("cuda", host="llvm")})
threadIdx_y = T.launch_thread("threadIdx.y", 4)
red_buf0 = T.allocate([1], "float32", "local")
red_result = T.allocate([4], "float32", "shared")
threadIdx_x = T.launch_thread("threadIdx.x", 128)
red_buf0_3 = T.Buffer((1,), data=red_buf0, scope="local")
red_result_1 = T.Buffer((4,), data=red_result, scope="shared")
with T.attr(
T.comm_reducer(lambda x0, y0: x0 + y0, [T.float32(0)]),
"reduce_scope",
T.reinterpret("handle", T.uint64(0)),
):
red_buf0 = T.allocate([1], "float32", "local")
mask = T.allocate([1], "uint32", "local")
t0 = T.allocate([1], "float32", "local")
red_buf0_1 = T.allocate([1], "float32", "local")
@@ -573,11 +580,11 @@ def expected(A: T.Buffer((4, 128), "float32"), B: T.Buffer((4,), "float32")):
red_buf0_2[0] = red_buf0_2[0] + t0_2[0]
t0_2[0] = T.tvm_warp_shuffle_down(mask_2[0], red_buf0_2[0], 1, 32, 32)
red_buf0_2[0] = red_buf0_2[0] + t0_2[0]
red_buf0_2[0] = T.tvm_warp_shuffle(mask_2[0], red_buf0_2[0], 32 * threadIdx_y, 32, 32)
red_buf_staging_1 = T.Buffer((16,), data=red_buf_staging, scope="shared")
if threadIdx_x % 32 == 0:
red_buf_staging_1[threadIdx_y * 4 + threadIdx_x // 32] = red_buf0_2[0]
T.tvm_storage_sync("shared")
red_buf0_3 = T.Buffer((1,), data=red_buf0, scope="local")
if threadIdx_x < 16:
red_buf0_3[0] = red_buf_staging_1[threadIdx_x]
mask_3 = T.Buffer((1,), "uint32", data=mask, scope="local")
@@ -589,10 +596,12 @@ def expected(A: T.Buffer((4, 128), "float32"), B: T.Buffer((4,), "float32")):
red_buf0_3[0] = red_buf0_3[0] + t0_3[0]
t0_3[0] = T.tvm_warp_shuffle_down(mask_3[0], red_buf0_3[0], 1, 32, 32)
red_buf0_3[0] = red_buf0_3[0] + t0_3[0]
red_buf0_3[0] = T.tvm_warp_shuffle(mask_3[0], red_buf0_3[0], 4 * threadIdx_y, 32, 32)
if threadIdx_x == 0:
red_result_1[0] = red_buf0_3[0]
T.tvm_storage_sync("shared")
if threadIdx_x == 0:
B_1 = T.Buffer((4,), data=B.data)
B_1[threadIdx_y] = red_buf0_3[0]
B_1[threadIdx_y] = red_result_1[0]


class TestMultiGroupMultiWarpPredicatedReduction(BaseCompare):
@@ -626,19 +635,20 @@ def expected(A: T.Buffer((2, 70), "float32"), B: T.Buffer((2,), "float32")):
T.func_attr({"target": T.target("cuda", host="llvm")})
threadIdx_y = T.launch_thread("threadIdx.y", 2)
in_thread_B = T.allocate([1], "float32", "local")
red_buf0 = T.allocate([1], "float32", "local")
red_result = T.allocate([2], "float32", "shared")
threadIdx_x = T.launch_thread("threadIdx.x", 512)
in_thread_B_1 = T.Buffer((1,), data=in_thread_B, scope="local")
in_thread_B_1[0] = T.float32(0)
if threadIdx_x < 70:
A_1 = T.Buffer((140,), data=A.data)
in_thread_B_1[0] = in_thread_B_1[0] + A_1[threadIdx_y * 70 + threadIdx_x]
red_buf0_3 = T.Buffer((1,), data=red_buf0, scope="local")
red_result_1 = T.Buffer((2,), data=red_result, scope="shared")
with T.attr(
T.comm_reducer(lambda x0, y0: x0 + y0, [T.float32(0)]),
"reduce_scope",
T.reinterpret("handle", T.uint64(0)),
):
red_buf0 = T.allocate([1], "float32", "local")
mask = T.allocate([1], "uint32", "local")
t0 = T.allocate([1], "float32", "local")
red_buf0_1 = T.allocate([1], "float32", "local")
@@ -660,11 +670,11 @@ def expected(A: T.Buffer((2, 70), "float32"), B: T.Buffer((2,), "float32")):
red_buf0_2[0] = red_buf0_2[0] + t0_2[0]
t0_2[0] = T.tvm_warp_shuffle_down(mask_2[0], red_buf0_2[0], 1, 32, 32)
red_buf0_2[0] = red_buf0_2[0] + t0_2[0]
red_buf0_2[0] = T.tvm_warp_shuffle(mask_2[0], red_buf0_2[0], 32 * threadIdx_y, 32, 32)
red_buf_staging_1 = T.Buffer((32,), data=red_buf_staging, scope="shared")
if threadIdx_x % 32 == 0:
red_buf_staging_1[threadIdx_y * 16 + threadIdx_x // 32] = red_buf0_2[0]
T.tvm_storage_sync("shared")
red_buf0_3 = T.Buffer((1,), data=red_buf0, scope="local")
if threadIdx_x < 32:
red_buf0_3[0] = red_buf_staging_1[threadIdx_x]
mask_3 = T.Buffer((1,), "uint32", data=mask, scope="local")
@@ -680,10 +690,12 @@ def expected(A: T.Buffer((2, 70), "float32"), B: T.Buffer((2,), "float32")):
red_buf0_3[0] = red_buf0_3[0] + t0_3[0]
t0_3[0] = T.tvm_warp_shuffle_down(mask_3[0], red_buf0_3[0], 1, 32, 32)
red_buf0_3[0] = red_buf0_3[0] + t0_3[0]
red_buf0_3[0] = T.tvm_warp_shuffle(mask_3[0], red_buf0_3[0], 16 * threadIdx_y, 32, 32)
if threadIdx_x == 0:
red_result_1[0] = red_buf0_3[0]
T.tvm_storage_sync("shared")
if threadIdx_x == 0:
B_1 = T.Buffer((2,), data=B.data)
B_1[threadIdx_y] = red_buf0_3[0]
B_1[threadIdx_y] = red_result_1[0]


if __name__ == "__main__":