From d59c1a83c3536500685befe968f38d6982d3d388 Mon Sep 17 00:00:00 2001 From: Ruihang Lai Date: Tue, 25 Jul 2023 02:40:48 -0400 Subject: [PATCH] [BugFix][TIR] Fix multi-grouped multi-warp allreduce PR #15327 and #15373 introduced multi-warp allreduce implementation. At the time of the introduction, I tested the correctness numerically via the workload of "taking a matrix of ones as input, computing the summation over each row". Both PR passed this numerical tess, while I didn't realize that this test is not complete and cannot guarantee the correctness. The previous implementation has bug which can be tested by turning the input matrix from ones to random floating-point numbers. This will expose the issues of the previous implementation. Therefore, this PR fixes the issues, and add the numerical tests for multi-warp allreduce into `test_allreduce_cuda.py`. By reducing some of the redundant tests in that file, we hope this can reduce the testing time a bit while still guarantee the correctness. Sorry for not testing the implementation completely before. --- src/tir/transforms/lower_thread_allreduce.cc | 38 ++++++++++--------- ...duction_cuda.py => test_allreduce_cuda.py} | 4 +- ...t_tir_transform_lower_thread_all_reduce.py | 24 +++++++----- 3 files changed, 38 insertions(+), 28 deletions(-) rename tests/python/unittest/{test_subwarp_reduction_cuda.py => test_allreduce_cuda.py} (96%) diff --git a/src/tir/transforms/lower_thread_allreduce.cc b/src/tir/transforms/lower_thread_allreduce.cc index 91a37dc35ed63..438dccff0bdbe 100644 --- a/src/tir/transforms/lower_thread_allreduce.cc +++ b/src/tir/transforms/lower_thread_allreduce.cc @@ -76,12 +76,17 @@ class ThreadAllreduceBuilder final : public StmtExprMutator { auto node = Downcast(StmtExprMutator::VisitStmt_(op)); if (auto it = alloc_remap_.find(node->buffer_var.get()); it != alloc_remap_.end()) { - const AllocateNode* repl = it->second.as(); + Buffer buf = Downcast(it->second); auto write_ptr = node.CopyOnWrite(); - write_ptr->buffer_var = repl->buffer_var; - write_ptr->dtype = repl->dtype; - write_ptr->extents = repl->extents; - write_ptr->condition = repl->condition; + write_ptr->buffer_var = buf->data; + write_ptr->dtype = buf->dtype; + write_ptr->extents = buf->shape; + write_ptr->condition = const_true(buf->dtype.lanes()); + + if (buf.scope() == "shared") { + // Use volatile access to shared buffer. + write_ptr->body = AttrStmt(buf->data, attr::volatile_scope, 1, write_ptr->body); + } } return std::move(node); } @@ -344,15 +349,15 @@ class ThreadAllreduceBuilder final : public StmtExprMutator { // 4. Load staging buffer. // Second round of allreduce. for (size_t i = 0; i < size; ++i) { - values[i] = BufferLoad(/*buffer=*/staging_shared_bufs[i], /*indices=*/{reduce_index}); + values[i] = BufferLoad(/*buffer=*/staging_shared_bufs[i], + /*indices=*/{group_index * n_warps + reduce_index}); } if (n_warps < warp_size_) { - mask = mask & (((1 << n_warps) - 1) << group_index); + mask = mask & (((1 << n_warps) - 1) << (group_index * n_warps)); } std::tie(reduce_results, local_bufs) = MakeWarpAllreduce( values, types, combiner, reduce_index, n_warps, group_index, mask, - /*predicate=*/reduce_index < make_const(reduce_index->dtype, group_extent * n_warps), - &seq); + /*predicate=*/reduce_index < make_const(reduce_index->dtype, n_warps), &seq); new_alloc_bufs.insert(new_alloc_bufs.end(), local_bufs.begin(), local_bufs.end()); // 5. Create shared memory buffer(s) of `group_extent` elements, storing @@ -365,9 +370,9 @@ class ThreadAllreduceBuilder final : public StmtExprMutator { /*shape=*/{make_const(reduce_index->dtype, group_extent)}, /*dtype=*/buffers[i]->dtype, /*name=*/"red_result", /*storage_scope=*/"shared"); write_result.push_back( - BufferStore(broadcast_shared_buf, reduce_results[i], {zero_index})); + BufferStore(broadcast_shared_buf, reduce_results[i], {group_index})); // Update `reduce_results`, pointing to the value loaded from the shared memory buffer. - reduce_results[i] = BufferLoad(broadcast_shared_buf, {zero_index}); + reduce_results[i] = BufferLoad(broadcast_shared_buf, {group_index}); } seq.push_back(IfThenElse(reduce_index == zero_index, SeqStmt::Flatten(write_result))); seq.push_back(SyncThread("shared")); @@ -382,7 +387,7 @@ class ThreadAllreduceBuilder final : public StmtExprMutator { load_remap_[buffers[i]->data.get()] = reduce_results[i]; auto node = Allocate(buf->data, types[i], buf->shape, pred, Evaluate(0)); - alloc_remap_[buffers[i]->data.get()] = node; + alloc_remap_[buffers[i]->data.get()] = buf; var_remap_[buffers[i]->data.get()] = buf->data; buf_remap_[buffers[i].get()] = buf; } @@ -400,7 +405,8 @@ class ThreadAllreduceBuilder final : public StmtExprMutator { // previous iteration on the same buffer. seq.emplace_back(SyncThread("shared")); for (size_t idx = 0; idx < size; ++idx) { - shared_bufs[idx] = decl_buffer({1}, types[idx], "red_buf" + std::to_string(idx), "shared"); + shared_bufs[idx] = decl_buffer({IntImm(group_index->dtype, group_extent * reduce_extent)}, + types[idx], "red_buf" + std::to_string(idx), "shared"); seq.emplace_back(BufferStore(shared_bufs[idx], values[idx], {BufIndex(reduce_index, group_index, reduce_extent)})); } @@ -414,9 +420,7 @@ class ThreadAllreduceBuilder final : public StmtExprMutator { {BufIndex(make_zero(reduce_index.dtype()), group_index, reduce_extent)}); ICHECK_EQ(load->dtype, types[idx]); load_remap_[buffers[idx]->data.get()] = load; - alloc_remap_[buffers[idx]->data.get()] = - Allocate(shared_bufs[idx]->data, types[idx], - {PrimExpr(group_extent), PrimExpr(reduce_extent)}, pred, Evaluate(0)); + alloc_remap_[buffers[idx]->data.get()] = shared_bufs[idx]; var_remap_[buffers[idx]->data.get()] = shared_bufs[idx]->data; buf_remap_[buffers[idx].get()] = shared_bufs[idx]; } @@ -772,7 +776,7 @@ class ThreadAllreduceBuilder final : public StmtExprMutator { // The load remap std::unordered_map load_remap_; // Allocate remap - std::unordered_map alloc_remap_; + std::unordered_map alloc_remap_; // BufferVar remap std::unordered_map var_remap_; // Buffer remap diff --git a/tests/python/unittest/test_subwarp_reduction_cuda.py b/tests/python/unittest/test_allreduce_cuda.py similarity index 96% rename from tests/python/unittest/test_subwarp_reduction_cuda.py rename to tests/python/unittest/test_allreduce_cuda.py index 7a7b1b06bac94..a24d97a84d99d 100644 --- a/tests/python/unittest/test_subwarp_reduction_cuda.py +++ b/tests/python/unittest/test_allreduce_cuda.py @@ -95,7 +95,9 @@ def check_max(d1: int, d2: int, d3: int): for d1 in range(1, 5): for d2 in range(1, 5): - for d3 in range(2, 33): + for d3 in [2, 4, 8, 12, 16, 32, 48, 64, 100, 128, 201, 256, 512, 1024]: + if d1 * d2 * d3 > 1024: + continue check_sum(d1, d2, d3) check_max(d1, d2, d3) diff --git a/tests/python/unittest/test_tir_transform_lower_thread_all_reduce.py b/tests/python/unittest/test_tir_transform_lower_thread_all_reduce.py index 1fb8aea66ea8a..9d53b1f9dfb54 100644 --- a/tests/python/unittest/test_tir_transform_lower_thread_all_reduce.py +++ b/tests/python/unittest/test_tir_transform_lower_thread_all_reduce.py @@ -387,6 +387,7 @@ def expected(A: T.Buffer((128, 128), "float32"), B: T.Buffer((128,), "float32")) for i in range(128): threadIdx_x = T.launch_thread("threadIdx.x", 128) red_result = T.allocate([1], "float32", "shared") + T.attr(red_result, "volatile_scope", 1) red_result_1 = T.Buffer((1,), data=red_result, scope="shared") with T.attr( T.comm_reducer(lambda x0, y0: x0 + y0, [T.float32(0)]), @@ -463,6 +464,7 @@ def expected(A: T.Buffer((1, 1024), "float32"), B: T.Buffer((1,), "float32")): T.func_attr({"target": T.target("cuda", host="llvm")}) threadIdx_x = T.launch_thread("threadIdx.x", 1024) red_result = T.allocate([1], "float32", "shared") + T.attr(red_result, "volatile_scope", 1) red_result_1 = T.Buffer((1,), data=red_result, scope="shared") with T.attr( T.comm_reducer(lambda x0, y0: x0 + y0, [T.float32(0)]), @@ -550,6 +552,7 @@ def expected(A: T.Buffer((4, 128), "float32"), B: T.Buffer((4,), "float32")): T.func_attr({"target": T.target("cuda", host="llvm")}) threadIdx_y = T.launch_thread("threadIdx.y", 4) red_result = T.allocate([4], "float32", "shared") + T.attr(red_result, "volatile_scope", 1) threadIdx_x = T.launch_thread("threadIdx.x", 128) red_result_1 = T.Buffer((4,), data=red_result, scope="shared") with T.attr( @@ -585,11 +588,11 @@ def expected(A: T.Buffer((4, 128), "float32"), B: T.Buffer((4,), "float32")): red_buf_staging_1[threadIdx_y * 4 + threadIdx_x // 32] = red_buf0_2[0] T.tvm_storage_sync("shared") red_buf0_3 = T.Buffer((1,), data=red_buf0, scope="local") - if threadIdx_x < 16: - red_buf0_3[0] = red_buf_staging_1[threadIdx_x] + if threadIdx_x < 4: + red_buf0_3[0] = red_buf_staging_1[threadIdx_y * 4 + threadIdx_x] mask_3 = T.Buffer((1,), "uint32", data=mask, scope="local") mask_3[0] = T.bitwise_and( - T.tvm_warp_activemask(), T.Cast("uint32", T.shift_left(15, threadIdx_y)) + T.tvm_warp_activemask(), T.Cast("uint32", T.shift_left(15, threadIdx_y * 4)) ) t0_3 = T.Buffer((1,), data=t0, scope="local") t0_3[0] = T.tvm_warp_shuffle_down(mask_3[0], red_buf0_3[0], 2, 32, 32) @@ -597,11 +600,11 @@ def expected(A: T.Buffer((4, 128), "float32"), B: T.Buffer((4,), "float32")): t0_3[0] = T.tvm_warp_shuffle_down(mask_3[0], red_buf0_3[0], 1, 32, 32) red_buf0_3[0] = red_buf0_3[0] + t0_3[0] if threadIdx_x == 0: - red_result_1[0] = red_buf0_3[0] + red_result_1[threadIdx_y] = red_buf0_3[0] T.tvm_storage_sync("shared") if threadIdx_x == 0: B_1 = T.Buffer((4,), data=B.data) - B_1[threadIdx_y] = red_result_1[0] + B_1[threadIdx_y] = red_result_1[threadIdx_y] class TestMultiGroupMultiWarpPredicatedReduction(BaseCompare): @@ -636,6 +639,7 @@ def expected(A: T.Buffer((2, 70), "float32"), B: T.Buffer((2,), "float32")): threadIdx_y = T.launch_thread("threadIdx.y", 2) in_thread_B = T.allocate([1], "float32", "local") red_result = T.allocate([2], "float32", "shared") + T.attr(red_result, "volatile_scope", 1) threadIdx_x = T.launch_thread("threadIdx.x", 512) in_thread_B_1 = T.Buffer((1,), data=in_thread_B, scope="local") in_thread_B_1[0] = T.float32(0) @@ -675,11 +679,11 @@ def expected(A: T.Buffer((2, 70), "float32"), B: T.Buffer((2,), "float32")): red_buf_staging_1[threadIdx_y * 16 + threadIdx_x // 32] = red_buf0_2[0] T.tvm_storage_sync("shared") red_buf0_3 = T.Buffer((1,), data=red_buf0, scope="local") - if threadIdx_x < 32: - red_buf0_3[0] = red_buf_staging_1[threadIdx_x] + if threadIdx_x < 16: + red_buf0_3[0] = red_buf_staging_1[threadIdx_y * 16 + threadIdx_x] mask_3 = T.Buffer((1,), "uint32", data=mask, scope="local") mask_3[0] = T.bitwise_and( - T.tvm_warp_activemask(), T.Cast("uint32", T.shift_left(65535, threadIdx_y)) + T.tvm_warp_activemask(), T.Cast("uint32", T.shift_left(65535, threadIdx_y * 16)) ) t0_3 = T.Buffer((1,), data=t0, scope="local") t0_3[0] = T.tvm_warp_shuffle_down(mask_3[0], red_buf0_3[0], 8, 32, 32) @@ -691,11 +695,11 @@ def expected(A: T.Buffer((2, 70), "float32"), B: T.Buffer((2,), "float32")): t0_3[0] = T.tvm_warp_shuffle_down(mask_3[0], red_buf0_3[0], 1, 32, 32) red_buf0_3[0] = red_buf0_3[0] + t0_3[0] if threadIdx_x == 0: - red_result_1[0] = red_buf0_3[0] + red_result_1[threadIdx_y] = red_buf0_3[0] T.tvm_storage_sync("shared") if threadIdx_x == 0: B_1 = T.Buffer((2,), data=B.data) - B_1[threadIdx_y] = red_result_1[0] + B_1[threadIdx_y] = red_result_1[threadIdx_y] if __name__ == "__main__":