diff --git a/src/tir/transforms/lower_thread_allreduce.cc b/src/tir/transforms/lower_thread_allreduce.cc index 91a37dc35ed6..438dccff0bdb 100644 --- a/src/tir/transforms/lower_thread_allreduce.cc +++ b/src/tir/transforms/lower_thread_allreduce.cc @@ -76,12 +76,17 @@ class ThreadAllreduceBuilder final : public StmtExprMutator { auto node = Downcast(StmtExprMutator::VisitStmt_(op)); if (auto it = alloc_remap_.find(node->buffer_var.get()); it != alloc_remap_.end()) { - const AllocateNode* repl = it->second.as(); + Buffer buf = Downcast(it->second); auto write_ptr = node.CopyOnWrite(); - write_ptr->buffer_var = repl->buffer_var; - write_ptr->dtype = repl->dtype; - write_ptr->extents = repl->extents; - write_ptr->condition = repl->condition; + write_ptr->buffer_var = buf->data; + write_ptr->dtype = buf->dtype; + write_ptr->extents = buf->shape; + write_ptr->condition = const_true(buf->dtype.lanes()); + + if (buf.scope() == "shared") { + // Use volatile access to shared buffer. + write_ptr->body = AttrStmt(buf->data, attr::volatile_scope, 1, write_ptr->body); + } } return std::move(node); } @@ -344,15 +349,15 @@ class ThreadAllreduceBuilder final : public StmtExprMutator { // 4. Load staging buffer. // Second round of allreduce. for (size_t i = 0; i < size; ++i) { - values[i] = BufferLoad(/*buffer=*/staging_shared_bufs[i], /*indices=*/{reduce_index}); + values[i] = BufferLoad(/*buffer=*/staging_shared_bufs[i], + /*indices=*/{group_index * n_warps + reduce_index}); } if (n_warps < warp_size_) { - mask = mask & (((1 << n_warps) - 1) << group_index); + mask = mask & (((1 << n_warps) - 1) << (group_index * n_warps)); } std::tie(reduce_results, local_bufs) = MakeWarpAllreduce( values, types, combiner, reduce_index, n_warps, group_index, mask, - /*predicate=*/reduce_index < make_const(reduce_index->dtype, group_extent * n_warps), - &seq); + /*predicate=*/reduce_index < make_const(reduce_index->dtype, n_warps), &seq); new_alloc_bufs.insert(new_alloc_bufs.end(), local_bufs.begin(), local_bufs.end()); // 5. Create shared memory buffer(s) of `group_extent` elements, storing @@ -365,9 +370,9 @@ class ThreadAllreduceBuilder final : public StmtExprMutator { /*shape=*/{make_const(reduce_index->dtype, group_extent)}, /*dtype=*/buffers[i]->dtype, /*name=*/"red_result", /*storage_scope=*/"shared"); write_result.push_back( - BufferStore(broadcast_shared_buf, reduce_results[i], {zero_index})); + BufferStore(broadcast_shared_buf, reduce_results[i], {group_index})); // Update `reduce_results`, pointing to the value loaded from the shared memory buffer. - reduce_results[i] = BufferLoad(broadcast_shared_buf, {zero_index}); + reduce_results[i] = BufferLoad(broadcast_shared_buf, {group_index}); } seq.push_back(IfThenElse(reduce_index == zero_index, SeqStmt::Flatten(write_result))); seq.push_back(SyncThread("shared")); @@ -382,7 +387,7 @@ class ThreadAllreduceBuilder final : public StmtExprMutator { load_remap_[buffers[i]->data.get()] = reduce_results[i]; auto node = Allocate(buf->data, types[i], buf->shape, pred, Evaluate(0)); - alloc_remap_[buffers[i]->data.get()] = node; + alloc_remap_[buffers[i]->data.get()] = buf; var_remap_[buffers[i]->data.get()] = buf->data; buf_remap_[buffers[i].get()] = buf; } @@ -400,7 +405,8 @@ class ThreadAllreduceBuilder final : public StmtExprMutator { // previous iteration on the same buffer. seq.emplace_back(SyncThread("shared")); for (size_t idx = 0; idx < size; ++idx) { - shared_bufs[idx] = decl_buffer({1}, types[idx], "red_buf" + std::to_string(idx), "shared"); + shared_bufs[idx] = decl_buffer({IntImm(group_index->dtype, group_extent * reduce_extent)}, + types[idx], "red_buf" + std::to_string(idx), "shared"); seq.emplace_back(BufferStore(shared_bufs[idx], values[idx], {BufIndex(reduce_index, group_index, reduce_extent)})); } @@ -414,9 +420,7 @@ class ThreadAllreduceBuilder final : public StmtExprMutator { {BufIndex(make_zero(reduce_index.dtype()), group_index, reduce_extent)}); ICHECK_EQ(load->dtype, types[idx]); load_remap_[buffers[idx]->data.get()] = load; - alloc_remap_[buffers[idx]->data.get()] = - Allocate(shared_bufs[idx]->data, types[idx], - {PrimExpr(group_extent), PrimExpr(reduce_extent)}, pred, Evaluate(0)); + alloc_remap_[buffers[idx]->data.get()] = shared_bufs[idx]; var_remap_[buffers[idx]->data.get()] = shared_bufs[idx]->data; buf_remap_[buffers[idx].get()] = shared_bufs[idx]; } @@ -772,7 +776,7 @@ class ThreadAllreduceBuilder final : public StmtExprMutator { // The load remap std::unordered_map load_remap_; // Allocate remap - std::unordered_map alloc_remap_; + std::unordered_map alloc_remap_; // BufferVar remap std::unordered_map var_remap_; // Buffer remap diff --git a/tests/python/unittest/test_subwarp_reduction_cuda.py b/tests/python/unittest/test_allreduce_cuda.py similarity index 94% rename from tests/python/unittest/test_subwarp_reduction_cuda.py rename to tests/python/unittest/test_allreduce_cuda.py index 7a7b1b06bac9..e9a8ef81cf1b 100644 --- a/tests/python/unittest/test_subwarp_reduction_cuda.py +++ b/tests/python/unittest/test_allreduce_cuda.py @@ -48,7 +48,7 @@ def reduce_max(a: T.handle, b: T.handle, d1: T.int32, d2: T.int32, d3: T.int32) @tvm.testing.requires_gpu @tvm.testing.requires_cuda -def test_cuda_subwarp_reduction(): +def test_allreduce_cuda(): def check_sum(d1: int, d2: int, d3: int): _, _, _d1, _d2, _d3 = reduce.params mod = reduce.specialize({_d1: d1, _d2: d2, _d3: d3}) @@ -95,10 +95,12 @@ def check_max(d1: int, d2: int, d3: int): for d1 in range(1, 5): for d2 in range(1, 5): - for d3 in range(2, 33): + for d3 in [2, 4, 8, 12, 16, 32, 48, 64, 100, 128, 201, 256, 512, 1024]: + if d1 * d2 * d3 > 1024: + continue check_sum(d1, d2, d3) check_max(d1, d2, d3) if __name__ == "__main__": - test_cuda_subwarp_reduction() + test_allreduce_cuda() diff --git a/tests/python/unittest/test_tir_transform_lower_thread_all_reduce.py b/tests/python/unittest/test_tir_transform_lower_thread_all_reduce.py index 1fb8aea66ea8..9d53b1f9dfb5 100644 --- a/tests/python/unittest/test_tir_transform_lower_thread_all_reduce.py +++ b/tests/python/unittest/test_tir_transform_lower_thread_all_reduce.py @@ -387,6 +387,7 @@ def expected(A: T.Buffer((128, 128), "float32"), B: T.Buffer((128,), "float32")) for i in range(128): threadIdx_x = T.launch_thread("threadIdx.x", 128) red_result = T.allocate([1], "float32", "shared") + T.attr(red_result, "volatile_scope", 1) red_result_1 = T.Buffer((1,), data=red_result, scope="shared") with T.attr( T.comm_reducer(lambda x0, y0: x0 + y0, [T.float32(0)]), @@ -463,6 +464,7 @@ def expected(A: T.Buffer((1, 1024), "float32"), B: T.Buffer((1,), "float32")): T.func_attr({"target": T.target("cuda", host="llvm")}) threadIdx_x = T.launch_thread("threadIdx.x", 1024) red_result = T.allocate([1], "float32", "shared") + T.attr(red_result, "volatile_scope", 1) red_result_1 = T.Buffer((1,), data=red_result, scope="shared") with T.attr( T.comm_reducer(lambda x0, y0: x0 + y0, [T.float32(0)]), @@ -550,6 +552,7 @@ def expected(A: T.Buffer((4, 128), "float32"), B: T.Buffer((4,), "float32")): T.func_attr({"target": T.target("cuda", host="llvm")}) threadIdx_y = T.launch_thread("threadIdx.y", 4) red_result = T.allocate([4], "float32", "shared") + T.attr(red_result, "volatile_scope", 1) threadIdx_x = T.launch_thread("threadIdx.x", 128) red_result_1 = T.Buffer((4,), data=red_result, scope="shared") with T.attr( @@ -585,11 +588,11 @@ def expected(A: T.Buffer((4, 128), "float32"), B: T.Buffer((4,), "float32")): red_buf_staging_1[threadIdx_y * 4 + threadIdx_x // 32] = red_buf0_2[0] T.tvm_storage_sync("shared") red_buf0_3 = T.Buffer((1,), data=red_buf0, scope="local") - if threadIdx_x < 16: - red_buf0_3[0] = red_buf_staging_1[threadIdx_x] + if threadIdx_x < 4: + red_buf0_3[0] = red_buf_staging_1[threadIdx_y * 4 + threadIdx_x] mask_3 = T.Buffer((1,), "uint32", data=mask, scope="local") mask_3[0] = T.bitwise_and( - T.tvm_warp_activemask(), T.Cast("uint32", T.shift_left(15, threadIdx_y)) + T.tvm_warp_activemask(), T.Cast("uint32", T.shift_left(15, threadIdx_y * 4)) ) t0_3 = T.Buffer((1,), data=t0, scope="local") t0_3[0] = T.tvm_warp_shuffle_down(mask_3[0], red_buf0_3[0], 2, 32, 32) @@ -597,11 +600,11 @@ def expected(A: T.Buffer((4, 128), "float32"), B: T.Buffer((4,), "float32")): t0_3[0] = T.tvm_warp_shuffle_down(mask_3[0], red_buf0_3[0], 1, 32, 32) red_buf0_3[0] = red_buf0_3[0] + t0_3[0] if threadIdx_x == 0: - red_result_1[0] = red_buf0_3[0] + red_result_1[threadIdx_y] = red_buf0_3[0] T.tvm_storage_sync("shared") if threadIdx_x == 0: B_1 = T.Buffer((4,), data=B.data) - B_1[threadIdx_y] = red_result_1[0] + B_1[threadIdx_y] = red_result_1[threadIdx_y] class TestMultiGroupMultiWarpPredicatedReduction(BaseCompare): @@ -636,6 +639,7 @@ def expected(A: T.Buffer((2, 70), "float32"), B: T.Buffer((2,), "float32")): threadIdx_y = T.launch_thread("threadIdx.y", 2) in_thread_B = T.allocate([1], "float32", "local") red_result = T.allocate([2], "float32", "shared") + T.attr(red_result, "volatile_scope", 1) threadIdx_x = T.launch_thread("threadIdx.x", 512) in_thread_B_1 = T.Buffer((1,), data=in_thread_B, scope="local") in_thread_B_1[0] = T.float32(0) @@ -675,11 +679,11 @@ def expected(A: T.Buffer((2, 70), "float32"), B: T.Buffer((2,), "float32")): red_buf_staging_1[threadIdx_y * 16 + threadIdx_x // 32] = red_buf0_2[0] T.tvm_storage_sync("shared") red_buf0_3 = T.Buffer((1,), data=red_buf0, scope="local") - if threadIdx_x < 32: - red_buf0_3[0] = red_buf_staging_1[threadIdx_x] + if threadIdx_x < 16: + red_buf0_3[0] = red_buf_staging_1[threadIdx_y * 16 + threadIdx_x] mask_3 = T.Buffer((1,), "uint32", data=mask, scope="local") mask_3[0] = T.bitwise_and( - T.tvm_warp_activemask(), T.Cast("uint32", T.shift_left(65535, threadIdx_y)) + T.tvm_warp_activemask(), T.Cast("uint32", T.shift_left(65535, threadIdx_y * 16)) ) t0_3 = T.Buffer((1,), data=t0, scope="local") t0_3[0] = T.tvm_warp_shuffle_down(mask_3[0], red_buf0_3[0], 8, 32, 32) @@ -691,11 +695,11 @@ def expected(A: T.Buffer((2, 70), "float32"), B: T.Buffer((2,), "float32")): t0_3[0] = T.tvm_warp_shuffle_down(mask_3[0], red_buf0_3[0], 1, 32, 32) red_buf0_3[0] = red_buf0_3[0] + t0_3[0] if threadIdx_x == 0: - red_result_1[0] = red_buf0_3[0] + red_result_1[threadIdx_y] = red_buf0_3[0] T.tvm_storage_sync("shared") if threadIdx_x == 0: B_1 = T.Buffer((2,), data=B.data) - B_1[threadIdx_y] = red_result_1[0] + B_1[threadIdx_y] = red_result_1[threadIdx_y] if __name__ == "__main__":