diff --git a/python/tvm/tir/op.py b/python/tvm/tir/op.py index cdbdb4b5424f5..378be84621ba2 100644 --- a/python/tvm/tir/op.py +++ b/python/tvm/tir/op.py @@ -616,7 +616,7 @@ def tvm_storage_sync(storage_scope): call : PrimExpr The call expression. """ - return call_intrin("handle", "tir.tvm_storage_sync", storage_scope) + return call_intrin("int32", "tir.tvm_storage_sync", storage_scope) def tvm_warp_shuffle(mask, value, warp_id, width, warp_size): diff --git a/src/te/operation/cross_thread_reduction.cc b/src/te/operation/cross_thread_reduction.cc index 8cbcfbc78f785..52e38c7ba2d8e 100644 --- a/src/te/operation/cross_thread_reduction.cc +++ b/src/te/operation/cross_thread_reduction.cc @@ -181,22 +181,23 @@ Stmt MakeCrossThreadReduction(const ComputeOpNode* self, const Stage& stage, freduce_args.push_back(dummy_load); } + // Checks for the thread. + std::vector output_preds; + if (stage->store_predicate.defined()) { + output_preds.emplace_back(stage->store_predicate); + } + for (IterVar iv : stage->leaf_iter_vars) { if (iv->iter_type == kCommReduce) { auto it = stage->iter_var_attrs.find(iv); if (it != stage->iter_var_attrs.end() && (*it).second->bind_thread.defined()) { IterVar tv = (*it).second->bind_thread; freduce_args.push_back(tv->var); + output_preds.push_back(tv->var == make_const(tv->var->dtype, 0)); } } } - // Checks for the thread. - std::vector output_preds; - if (stage->store_predicate.defined()) { - output_preds.emplace_back(stage->store_predicate); - } - // Apply the existing input predicate if any. output_preds.push_back(input_pred); diff --git a/src/tir/transforms/lower_thread_allreduce.cc b/src/tir/transforms/lower_thread_allreduce.cc index 97a34a6ede1fb..a42e985a11b8c 100644 --- a/src/tir/transforms/lower_thread_allreduce.cc +++ b/src/tir/transforms/lower_thread_allreduce.cc @@ -63,7 +63,8 @@ class ThreadAllreduceBuilder final : public StmtExprMutator { public: explicit ThreadAllreduceBuilder(const TargetNode* target) : target_(target), - warp_size_(target->GetAttr("thread_warp_size", 1).value().IntValue()) {} + warp_size_(target->GetAttr("thread_warp_size", 1).value().IntValue()), + max_num_threads_(target->GetAttr("max_num_threads", -1).value().IntValue()) {} Stmt VisitStmt_(const AttrStmtNode* op) final { if (op->attr_key == attr::thread_extent) { @@ -279,9 +280,7 @@ class ThreadAllreduceBuilder final : public StmtExprMutator { } std::vector seq; - std::vector shared_buffer_vars(size); - std::vector shared_bufs(size); - std::vector local_bufs; + std::vector new_alloc_bufs; // // This is an optimization. For small reduction sizes, it may be beneficial // for a single warp to performance the entire reduction. No trips to shared @@ -299,131 +298,87 @@ class ThreadAllreduceBuilder final : public StmtExprMutator { // broadcast results from lane 0 to all other lanes and store // the final reduction result to the proper location. // - if (is_warp_reduction(types, group_extent, reduce_extent, contiguous_reduce_extent)) { - ICHECK_LE(reduce_extent, warp_size_) << "not a warp reduction"; - // - // This is the index to the reduction variable, one reduction - // variable per warp. Local scope seems easier to reason without - // relying on a pattern match pass to fix it later. 
- Array zero_indices = {0}; - - for (size_t idx = 0; idx < size; ++idx) { - Array shape = {1}; - - Buffer buffer = decl_buffer(shape, types[idx], "red_buf" + std::to_string(idx)); - Var buffer_var = buffer->data; - - shared_buffer_vars[idx] = buffer_var; - shared_bufs[idx] = buffer; - - PrimExpr pred = const_true(types[idx].lanes()); - seq.emplace_back(BufferStore(shared_bufs[idx], values[idx], zero_indices)); - - // Uses a local variable to store the shuffled data. Later - // on, an allocation will be built for this local variable. - local_bufs.push_back(decl_buffer(shape, types[idx], "t" + std::to_string(idx))); - } - - // The mask for this reducer, as this reducer may sit inside - // a divergent control flow. Here it uses a variable to cache the current - // active channels. - // + // When the thread extent is multiple of warp size, we can use a two-stage + // warp-level reduction to optimize. This is implemented by applying the + // algorithm above twice. + // + // For example, suppose we want to use 512 threads to reduce 512 elements + // and the warp size is 32. In this case there are (512 / 32) = 16 warps. + // In the first stage, each of the 16 warps reduces 32 elements. So after + // the stage, we have 16 remaining elements to be reduced, one for each warp. + // We store the 16 elements in shared memory, and start the second stage. + // In the second stage we use the first 16 lanes of the first warp to reduce + // the remaining elements, and this reduction can also be optimized by + // shuffle_down warp-level primitives. + if (IsWarpReduction(types, group_extent, reduce_extent, contiguous_reduce_extent)) { + std::vector reduce_results; DataType mask_dtype = DataType::UInt(32); - Buffer mask_buffer = decl_buffer({1}, mask_dtype, "mask"); - { - PrimExpr mask = Call(mask_dtype, builtin::tvm_warp_activemask(), {}); - if (group_extent > 1) { - mask = mask & (make_const(mask_dtype, (1ll << reduce_extent) - 1) - << (reduce_extent * cast(mask_dtype, group_index))); + PrimExpr mask = Call(mask_dtype, builtin::tvm_warp_activemask(), {}); + + if (reduce_extent <= warp_size_) { + if (group_extent > 1 && reduce_extent < warp_size_) { + mask = mask & + (((1 << reduce_extent) - 1) << (reduce_extent * cast(mask_dtype, group_index))); } - seq.emplace_back(BufferStore(mask_buffer, mask, zero_indices)); - // Push the buffer description. Later this will have an - // allocation built for it. - local_bufs.push_back(mask_buffer); - } + std::tie(reduce_results, new_alloc_bufs) = MakeWarpAllreduce( + values, types, combiner, reduce_index, reduce_extent, group_index, mask, NullOpt, &seq); + } else { + int n_warps = reduce_extent / warp_size_; + std::vector local_bufs; - // Emit reductions within a warp. - int start_offset = 1; - while (start_offset * 2 < reduce_extent) { - start_offset *= 2; - } - for (int offset = start_offset; offset > 0; offset /= 2) { - // Load reduction values, no synchronization needed. - Array a, b; + // 1. Create the staging buffer in shared memory. + std::vector staging_shared_bufs; + staging_shared_bufs.reserve(size); for (size_t i = 0; i < size; ++i) { - Buffer shared_buf = shared_bufs[i]; - BufferLoad val(shared_buf, zero_indices); - ICHECK_EQ(val->dtype, types[i]); - a.push_back(val); - - // __shfl_*sync calls shall not appear in if_then_else expressions - // as this is causing extra divergency. E.g. - // - // v1 = (v2 < v3) ? v3 : __shfl_sync(mask, v1, 0); - // - // behaves differently from - // - // int t = __shfl_sync(mask, v1, 0); - // v1 = (v2 < v3) ? 
v3 : t; - // - // The former may cause dead lock as there is a divergent - // branch with a warp sync call inside. - // - PrimExpr other = WarpShuffle(builtin::tvm_warp_shuffle_down(), mask_buffer, val, offset); - Buffer local_buf = local_bufs[i]; - Stmt s = BufferStore(local_buf, other, zero_indices); - seq.push_back(s); - - BufferLoad load = BufferLoad(local_buf, zero_indices); - ICHECK_EQ(load->dtype, types[i]); - b.push_back(load); + Buffer staging_shared_buf = decl_buffer( + /*shape=*/{make_const(reduce_index->dtype, n_warps * group_extent)}, + /*dtype=*/buffers[i]->dtype, /*name=*/"red_buf_staging", /*storage_scope=*/"shared"); + staging_shared_bufs.push_back(staging_shared_buf); + new_alloc_bufs.push_back(staging_shared_buf); } - // Do reductions. - Array ret = (*combiner)(a, b); + // 2. First round of allreduce. + std::tie(reduce_results, local_bufs) = MakeWarpAllreduce( + values, types, combiner, reduce_index, warp_size_, group_index, mask, NullOpt, &seq); + new_alloc_bufs.insert(new_alloc_bufs.end(), local_bufs.begin(), local_bufs.end()); - // Store the reduction result to itself. - std::vector stores(size); + // 3. Write allreduce results to staging buffer. + std::vector write_staging_buf; + write_staging_buf.reserve(size); for (size_t i = 0; i < size; ++i) { - Buffer buf = shared_bufs[i]; - stores[i] = BufferStore(buf, ret[i], zero_indices); + new_alloc_bufs.push_back(Downcast(reduce_results[i])->buffer); + write_staging_buf.push_back(BufferStore( + /*buffer=*/staging_shared_bufs[i], + /*value=*/reduce_results[i], + /*indices=*/{group_index * n_warps + floordiv(reduce_index, warp_size_)})); } + PrimExpr cond = floormod(reduce_index, warp_size_) == make_const(reduce_index->dtype, 0); + seq.push_back(IfThenElse(cond, SeqStmt::Flatten(write_staging_buf))); + seq.push_back(SyncThread("shared")); - // During the sub-warp reduction, values from inactive threads could be read, - // which is an undefined behavior according to the cuda document. - // - // In practice, the return value are usually 0, which does no harm to sum reduction. - // However, the result can be incorrect in max or prod reduction. - // Therefore an additional range check has to be performed to ensure the correctness. - if (offset * 2 > reduce_extent) { - PrimExpr cond = reduce_index + offset < reduce_extent; - seq.push_back(IfThenElse(cond, SeqStmt::Flatten(stores))); - } else { - seq.push_back(SeqStmt::Flatten(stores)); + // 4. Load staging buffer. + // Second round of allreduce. + for (size_t i = 0; i < size; ++i) { + values[i] = BufferLoad(/*buffer=*/staging_shared_bufs[i], /*indices=*/{reduce_index}); } + if (n_warps < warp_size_) { + mask = mask & (((1 << n_warps) - 1) << group_index); + } + std::tie(reduce_results, local_bufs) = MakeWarpAllreduce( + values, types, combiner, reduce_index, n_warps, group_index, mask, + /*predicate=*/reduce_index < make_const(reduce_index->dtype, group_extent * n_warps), + &seq); + new_alloc_bufs.insert(new_alloc_bufs.end(), local_bufs.begin(), local_bufs.end()); } - // Broadcast the reduction result from lane 0 to all other lanes. - // This avoids to emit predicated stores, as all threads are - // uniformly writing the same result. 
- // - for (size_t i = 0; i < size; ++i) { - Buffer buf = shared_bufs[i]; - PrimExpr val = BufferLoad(buf, zero_indices); - ICHECK_EQ(val->dtype, types[i]); - PrimExpr splat = - WarpShuffle(builtin::tvm_warp_shuffle(), mask_buffer, val, reduce_extent * group_index); - seq.push_back(BufferStore(buf, splat, zero_indices)); - } - - // Update existing allocations. + // Write back allreduce results and update existing allocations. for (size_t i = 0; i < size; ++i) { ICHECK(!load_remap_.count(buffers[i]->data.get())); PrimExpr pred = const_true(types[i].lanes()); - Buffer buf = shared_bufs[i]; - PrimExpr val = BufferLoad(buf, zero_indices); - ICHECK_EQ(val->dtype, types[i]); - load_remap_[buffers[i]->data.get()] = val; + Buffer buf = Downcast(reduce_results[i])->buffer; + ICHECK_EQ(reduce_results[i]->dtype, types[i]); + load_remap_[buffers[i]->data.get()] = reduce_results[i]; + Array extents{PrimExpr(1)}; auto node = Allocate(buf->data, types[i], extents, pred, Evaluate(0)); alloc_remap_[buffers[i]->data.get()] = node; @@ -432,6 +387,7 @@ class ThreadAllreduceBuilder final : public StmtExprMutator { warp_allocs_.insert(node.get()); } } else { + std::vector shared_bufs(size); if (reduce_extent == 1) { // special case, no reduction is needed. std::vector stores; @@ -444,12 +400,7 @@ class ThreadAllreduceBuilder final : public StmtExprMutator { // previous iteration on the same buffer. seq.emplace_back(SyncThread("shared")); for (size_t idx = 0; idx < size; ++idx) { - Buffer buffer = decl_buffer({1}, types[idx], "red_buf" + std::to_string(idx)); - - shared_bufs[idx] = buffer; - shared_buffer_vars[idx] = buffer->data; - - PrimExpr pred = const_true(types[idx].lanes()); + shared_bufs[idx] = decl_buffer({1}, types[idx], "red_buf" + std::to_string(idx)); seq.emplace_back(BufferStore(shared_bufs[idx], values[idx], {BufIndex(reduce_index, group_index, reduce_extent)})); } @@ -473,14 +424,146 @@ class ThreadAllreduceBuilder final : public StmtExprMutator { // Fix all local allocations as all statements are built. Stmt body = SeqStmt::Flatten(seq); - for (Buffer buf : local_bufs) { + for (Buffer buf : new_alloc_bufs) { body = Allocate(buf->data, buf->dtype, buf->shape, const_true(buf->dtype.lanes()), body); - new_storage_scopes_[buf->data.get()] = "local"; + if (buf.scope() != "shared") { + new_storage_scopes_[buf->data.get()] = "local"; + } } return body; } + std::pair, std::vector> MakeWarpAllreduce( + std::vector src_values, // + std::vector dtypes, // + const CommReducerNode* combiner, // + PrimExpr reduce_index, int reduce_extent, // + PrimExpr group_index, // + PrimExpr mask, Optional predicate, // + std::vector* seq) { + int n_buffers = src_values.size(); + + std::vector shared_bufs; + std::vector local_bufs; + shared_bufs.reserve(n_buffers); + + // This is the index to the reduction variable, one reduction + // variable per warp. Local scope seems easier to reason without + // relying on a pattern match pass to fix it later. + Array zero_indices = {0}; + Array shape = {1}; + + std::vector load_values; + load_values.reserve(n_buffers); + for (int idx = 0; idx < n_buffers; ++idx) { + shared_bufs.push_back(decl_buffer(shape, dtypes[idx], "red_buf" + std::to_string(idx))); + load_values.push_back(BufferStore(shared_bufs[idx], src_values[idx], zero_indices)); + + // Uses a local variable to store the shuffled data. Later + // on, an allocation will be built for this local variable. 
+ local_bufs.push_back(decl_buffer(shape, dtypes[idx], "t" + std::to_string(idx))); + } + + if (predicate.defined()) { + seq->push_back(IfThenElse(predicate.value(), SeqStmt::Flatten(load_values))); + } else { + seq->insert(seq->end(), load_values.begin(), load_values.end()); + } + + // The mask for this reducer, as this reducer may sit inside + // a divergent control flow. Here it uses a variable to cache the current + // active channels. + Buffer mask_buffer = decl_buffer(shape, mask->dtype, "mask"); + { + seq->emplace_back(BufferStore(mask_buffer, mask, zero_indices)); + // Push the buffer description. Later this will have an + // allocation built for it. + local_bufs.push_back(mask_buffer); + } + + // Emit reductions within a warp. + int start_offset = 1; + while (start_offset * 2 < reduce_extent) { + start_offset *= 2; + } + for (int offset = start_offset; offset > 0; offset /= 2) { + // Load reduction values, no synchronization needed. + Array a, b; + for (int i = 0; i < n_buffers; ++i) { + Buffer shared_buf = shared_bufs[i]; + BufferLoad val(shared_buf, zero_indices); + ICHECK_EQ(val->dtype, dtypes[i]); + a.push_back(val); + + // __shfl_*sync calls shall not appear in if_then_else expressions + // as this is causing extra divergency. E.g. + // + // v1 = (v2 < v3) ? v3 : __shfl_sync(mask, v1, 0); + // + // behaves differently from + // + // int t = __shfl_sync(mask, v1, 0); + // v1 = (v2 < v3) ? v3 : t; + // + // The former may cause dead lock as there is a divergent + // branch with a warp sync call inside. + PrimExpr other = WarpShuffle(builtin::tvm_warp_shuffle_down(), mask_buffer, val, offset); + Buffer local_buf = local_bufs[i]; + Stmt s = BufferStore(local_buf, other, zero_indices); + seq->push_back(s); + + BufferLoad load = BufferLoad(local_buf, zero_indices); + ICHECK_EQ(load->dtype, dtypes[i]); + b.push_back(load); + } + + // Do reductions. + Array ret = (*combiner)(a, b); + + // Store the reduction result to itself. + std::vector stores; + stores.reserve(n_buffers); + for (int i = 0; i < n_buffers; ++i) { + Buffer buf = shared_bufs[i]; + stores.push_back(BufferStore(buf, ret[i], zero_indices)); + } + + // During the sub-warp reduction, values from inactive threads could be read, + // which is an undefined behavior according to the cuda document. + // + // In practice, the return value are usually 0, which does no harm to sum reduction. + // However, the result can be incorrect in max or prod reduction. + // Therefore an additional range check has to be performed to ensure the correctness. + if (offset * 2 > reduce_extent) { + PrimExpr cond = reduce_index + offset < reduce_extent; + seq->push_back(IfThenElse(cond, SeqStmt::Flatten(stores))); + } else { + seq->push_back(SeqStmt::Flatten(stores)); + } + } + + // Broadcast the reduction result from lane 0 to all other lanes. + // This avoids to emit predicated stores, as all threads are + // uniformly writing the same result. + for (int i = 0; i < n_buffers; ++i) { + Buffer buf = shared_bufs[i]; + PrimExpr val = BufferLoad(buf, zero_indices); + ICHECK_EQ(val->dtype, dtypes[i]); + PrimExpr splat = + WarpShuffle(builtin::tvm_warp_shuffle(), mask_buffer, val, reduce_extent * group_index); + seq->push_back(BufferStore(buf, splat, zero_indices)); + } + + std::vector reduce_results; + reduce_results.reserve(n_buffers); + for (int i = 0; i < n_buffers; ++i) { + reduce_results.push_back(BufferLoad(shared_bufs[i], zero_indices)); + } + + return {reduce_results, local_bufs}; + } + // make allreduce. 
Stmt MakeBufAllreduce(const CommReducerNode* combiner, const std::vector& types, const Array& shared_bufs, PrimExpr reduce_index, @@ -637,8 +720,8 @@ class ThreadAllreduceBuilder final : public StmtExprMutator { // // Note: The ROCm backend will only have warp reductions for now. // Also, the warp/wavefront size differs (64 on rocm, 32 on cuda). - bool is_warp_reduction(const std::vector& types, int group_extent, int reduce_extent, - int contiguous_reduce_extent) const { + bool IsWarpReduction(const std::vector& types, int group_extent, int reduce_extent, + int contiguous_reduce_extent) const { // Only cuda target supports warp reductions. if ((target_->kind->name != "cuda") && (target_->kind->name != "rocm")) return false; @@ -676,8 +759,12 @@ class ThreadAllreduceBuilder final : public StmtExprMutator { if (reduce_extent == 1) { return false; // no need to warp reduce } else { - if (warp_size_ % reduce_extent == 0) { - return true; // warp size is multiple of reduce extent + bool is_subwarp_reduction = warp_size_ % reduce_extent == 0; + bool is_multiwarp_reduction = max_num_threads_ != -1 && + max_num_threads_ <= warp_size_ * warp_size_ && + reduce_extent % warp_size_ == 0; + if (is_subwarp_reduction || is_multiwarp_reduction) { + return true; } else { return group_extent == 1 && reduce_extent <= warp_size_; } @@ -690,6 +777,8 @@ class ThreadAllreduceBuilder final : public StmtExprMutator { // The warp size of the device. int warp_size_{1}; + // The maximum number of threads of the device. "-1" denotes unknown. + int max_num_threads_{-1}; // surrounding scope of thread extent. std::vector thread_extents_; diff --git a/tests/python/unittest/test_tir_transform_lower_thread_all_reduce.py b/tests/python/unittest/test_tir_transform_lower_thread_all_reduce.py index c9e6136ca8d7a..f354dfe9cafcb 100644 --- a/tests/python/unittest/test_tir_transform_lower_thread_all_reduce.py +++ b/tests/python/unittest/test_tir_transform_lower_thread_all_reduce.py @@ -235,7 +235,7 @@ def expected(A: T.Buffer((128, 128), "float32"), B: T.Buffer(128, "float32")): B[i] = reduce[0] -class TestMultiGroupMask(BaseCompare): +class TestMultiGroupReduction(BaseCompare): @T.prim_func def before(A: T.Buffer((32, 32), "float32"), B: T.Buffer((32,), "float32")): T.func_attr({"target": T.target("cuda", host="llvm")}) @@ -278,10 +278,7 @@ def expected(A: T.Buffer((32, 32), "float32"), B: T.Buffer((32,), "float32")): red_buf0_1[0] = A_1[threadIdx_y * 32 + threadIdx_x] mask_1 = T.Buffer((1,), "uint32", data=mask, scope="local") - mask_1[0] = T.bitwise_and( - T.tvm_warp_activemask(), - T.shift_left(T.uint32(4294967295), T.uint32(32) * T.Cast("uint32", threadIdx_y)), - ) + mask_1[0] = T.tvm_warp_activemask() t0_1 = T.Buffer((1,), data=t0, scope="local") t0_1[0] = T.tvm_warp_shuffle_down(mask_1[0], red_buf0_1[0], 16, 32, 32) @@ -300,5 +297,394 @@ def expected(A: T.Buffer((32, 32), "float32"), B: T.Buffer((32,), "float32")): B_1[threadIdx_y] = red_buf0_1[0] +class TestMultiGroupMask1(BaseCompare): + @T.prim_func + def before(A: T.Buffer((32, 8), "float32"), B: T.Buffer((32,), "float32")): + T.func_attr({"target": T.target("cuda", host="llvm")}) + threadIdx_y = T.launch_thread("threadIdx.y", 32) + cross_thread_B = T.allocate([1], "float32", "local") + threadIdx_x = T.launch_thread("threadIdx.x", 8) + cross_thread_B_1 = T.Buffer((1,), data=cross_thread_B, scope="local") + with T.attr( + T.comm_reducer(lambda x0, y0: x0 + y0, [T.float32(0)]), + "reduce_scope", + T.reinterpret("handle", T.uint64(0)), + ): + A_1 = T.Buffer((256,), 
data=A.data) + T.tvm_thread_allreduce( + T.uint32(1), + A_1[threadIdx_y * 8 + threadIdx_x], + T.bool(True), + cross_thread_B_1[0], + threadIdx_x, + ) + if threadIdx_x == 0: + B_1 = T.Buffer((32,), data=B.data) + B_1[threadIdx_y] = cross_thread_B_1[0] + + @T.prim_func + def expected(A: T.Buffer((32, 8), "float32"), B: T.Buffer((32,), "float32")): + T.func_attr({"target": T.target("cuda", host="llvm")}) + threadIdx_y = T.launch_thread("threadIdx.y", 32) + red_buf0 = T.allocate([1], "float32", "local") + threadIdx_x = T.launch_thread("threadIdx.x", 8) + red_buf0_1 = T.Buffer((1,), data=red_buf0, scope="local") + with T.attr( + T.comm_reducer(lambda x0, y0: x0 + y0, [T.float32(0)]), + "reduce_scope", + T.reinterpret("handle", T.uint64(0)), + ): + mask = T.allocate([1], "uint32", "local") + t0 = T.allocate([1], "float32", "local") + A_1 = T.Buffer((256,), data=A.data) + red_buf0_1[0] = A_1[threadIdx_y * 8 + threadIdx_x] + mask_1 = T.Buffer((1,), "uint32", data=mask, scope="local") + mask_1[0] = T.bitwise_and( + T.tvm_warp_activemask(), + T.shift_left(T.uint32(255), T.uint32(8) * T.Cast("uint32", threadIdx_y)), + ) + t0_1 = T.Buffer((1,), data=t0, scope="local") + t0_1[0] = T.tvm_warp_shuffle_down(mask_1[0], red_buf0_1[0], 4, 32, 32) + red_buf0_1[0] = red_buf0_1[0] + t0_1[0] + t0_1[0] = T.tvm_warp_shuffle_down(mask_1[0], red_buf0_1[0], 2, 32, 32) + red_buf0_1[0] = red_buf0_1[0] + t0_1[0] + t0_1[0] = T.tvm_warp_shuffle_down(mask_1[0], red_buf0_1[0], 1, 32, 32) + red_buf0_1[0] = red_buf0_1[0] + t0_1[0] + red_buf0_1[0] = T.tvm_warp_shuffle(mask_1[0], red_buf0_1[0], 8 * threadIdx_y, 32, 32) + if threadIdx_x == 0: + B_1 = T.Buffer((32,), data=B.data) + B_1[threadIdx_y] = red_buf0_1[0] + + +class TestMultiWarpReduce1(BaseCompare): + @T.prim_func + def before(A: T.Buffer((128, 128), "float32"), B: T.Buffer((128,), "float32")): + T.func_attr({"target": T.target("cuda", host="llvm")}) + for i in range(128): + threadIdx_x = T.launch_thread("threadIdx.x", 128) + cross_thread_B = T.allocate([1], "float32", "local") + cross_thread_B_1 = T.Buffer((1,), data=cross_thread_B, scope="local") + with T.attr( + T.comm_reducer(lambda x0, y0: x0 + y0, [T.float32(0)]), + "reduce_scope", + T.reinterpret("handle", T.uint64(0)), + ): + A_1 = T.Buffer((16384,), data=A.data) + T.tvm_thread_allreduce( + T.uint32(1), + A_1[i * 128 + threadIdx_x], + T.bool(True), + cross_thread_B_1[0], + threadIdx_x, + ) + if threadIdx_x == 0: + B_1 = T.Buffer((128,), data=B.data) + B_1[i] = cross_thread_B_1[0] + + @T.prim_func + def expected(A: T.Buffer((128, 128), "float32"), B: T.Buffer((128,), "float32")): + T.func_attr({"target": T.target("cuda", host="llvm")}) + for i in range(128): + threadIdx_x = T.launch_thread("threadIdx.x", 128) + red_buf0 = T.allocate([1], "float32", "local") + red_buf0_3 = T.Buffer((1,), data=red_buf0, scope="local") + with T.attr( + T.comm_reducer(lambda x0, y0: x0 + y0, [T.float32(0)]), + "reduce_scope", + T.reinterpret("handle", T.uint64(0)), + ): + mask = T.allocate([1], "uint32", "local") + t0 = T.allocate([1], "float32", "local") + red_buf0_1 = T.allocate([1], "float32", "local") + mask_1 = T.allocate([1], "uint32", "local") + t0_1 = T.allocate([1], "float32", "local") + red_buf_staging = T.allocate([4], "float32", "shared") + red_buf0_2 = T.Buffer((1,), data=red_buf0_1, scope="local") + A_1 = T.Buffer((16384,), data=A.data) + red_buf0_2[0] = A_1[i * 128 + threadIdx_x] + mask_2 = T.Buffer((1,), "uint32", data=mask_1, scope="local") + mask_2[0] = T.tvm_warp_activemask() + t0_2 = T.Buffer((1,), data=t0_1, 
scope="local") + t0_2[0] = T.tvm_warp_shuffle_down(mask_2[0], red_buf0_2[0], 16, 32, 32) + red_buf0_2[0] = red_buf0_2[0] + t0_2[0] + t0_2[0] = T.tvm_warp_shuffle_down(mask_2[0], red_buf0_2[0], 8, 32, 32) + red_buf0_2[0] = red_buf0_2[0] + t0_2[0] + t0_2[0] = T.tvm_warp_shuffle_down(mask_2[0], red_buf0_2[0], 4, 32, 32) + red_buf0_2[0] = red_buf0_2[0] + t0_2[0] + t0_2[0] = T.tvm_warp_shuffle_down(mask_2[0], red_buf0_2[0], 2, 32, 32) + red_buf0_2[0] = red_buf0_2[0] + t0_2[0] + t0_2[0] = T.tvm_warp_shuffle_down(mask_2[0], red_buf0_2[0], 1, 32, 32) + red_buf0_2[0] = red_buf0_2[0] + t0_2[0] + red_buf0_2[0] = T.tvm_warp_shuffle(mask_2[0], red_buf0_2[0], 0, 32, 32) + red_buf_staging_1 = T.Buffer((4,), data=red_buf_staging, scope="shared") + if threadIdx_x % 32 == 0: + red_buf_staging_1[threadIdx_x // 32] = red_buf0_2[0] + T.tvm_storage_sync("shared") + if threadIdx_x < 4: + red_buf0_3[0] = red_buf_staging_1[threadIdx_x] + mask_3 = T.Buffer((1,), "uint32", data=mask, scope="local") + mask_3[0] = T.bitwise_and(T.tvm_warp_activemask(), T.uint32(15)) + t0_3 = T.Buffer((1,), data=t0, scope="local") + t0_3[0] = T.tvm_warp_shuffle_down(mask_3[0], red_buf0_3[0], 2, 32, 32) + red_buf0_3[0] = red_buf0_3[0] + t0_3[0] + t0_3[0] = T.tvm_warp_shuffle_down(mask_3[0], red_buf0_3[0], 1, 32, 32) + red_buf0_3[0] = red_buf0_3[0] + t0_3[0] + red_buf0_3[0] = T.tvm_warp_shuffle(mask_3[0], red_buf0_3[0], 0, 32, 32) + if threadIdx_x == 0: + B_1 = T.Buffer((128,), data=B.data) + B_1[i] = red_buf0_3[0] + + +class TestMultiWarpReduce2(BaseCompare): + @T.prim_func + def before(A: T.Buffer((1, 1024), "float32"), B: T.Buffer((1,), "float32")): + T.func_attr({"target": T.target("cuda", host="llvm")}) + threadIdx_x = T.launch_thread("threadIdx.x", 1024) + cross_thread_B = T.allocate([1], "float32", "local") + cross_thread_B_1 = T.Buffer((1,), data=cross_thread_B, scope="local") + with T.attr( + T.comm_reducer(lambda x0, y0: x0 + y0, [T.float32(0)]), + "reduce_scope", + T.reinterpret("handle", T.uint64(0)), + ): + A_1 = T.Buffer((1024,), data=A.data) + T.tvm_thread_allreduce( + T.uint32(1), A_1[threadIdx_x], T.bool(True), cross_thread_B_1[0], threadIdx_x + ) + if threadIdx_x == 0: + B_1 = T.Buffer((1,), data=B.data) + B_1[0] = cross_thread_B_1[0] + + @T.prim_func + def expected(A: T.Buffer((1, 1024), "float32"), B: T.Buffer((1,), "float32")): + T.func_attr({"target": T.target("cuda", host="llvm")}) + threadIdx_x = T.launch_thread("threadIdx.x", 1024) + red_buf0 = T.allocate([1], "float32", "local") + red_buf0_3 = T.Buffer((1,), data=red_buf0, scope="local") + with T.attr( + T.comm_reducer(lambda x0, y0: x0 + y0, [T.float32(0)]), + "reduce_scope", + T.reinterpret("handle", T.uint64(0)), + ): + mask = T.allocate([1], "uint32", "local") + t0 = T.allocate([1], "float32", "local") + red_buf0_1 = T.allocate([1], "float32", "local") + mask_1 = T.allocate([1], "uint32", "local") + t0_1 = T.allocate([1], "float32", "local") + red_buf_staging = T.allocate([32], "float32", "shared") + red_buf0_2 = T.Buffer((1,), data=red_buf0_1, scope="local") + A_1 = T.Buffer((1024,), data=A.data) + red_buf0_2[0] = A_1[threadIdx_x] + mask_2 = T.Buffer((1,), "uint32", data=mask_1, scope="local") + mask_2[0] = T.tvm_warp_activemask() + t0_2 = T.Buffer((1,), data=t0_1, scope="local") + t0_2[0] = T.tvm_warp_shuffle_down(mask_2[0], red_buf0_2[0], 16, 32, 32) + red_buf0_2[0] = red_buf0_2[0] + t0_2[0] + t0_2[0] = T.tvm_warp_shuffle_down(mask_2[0], red_buf0_2[0], 8, 32, 32) + red_buf0_2[0] = red_buf0_2[0] + t0_2[0] + t0_2[0] = T.tvm_warp_shuffle_down(mask_2[0], 
red_buf0_2[0], 4, 32, 32) + red_buf0_2[0] = red_buf0_2[0] + t0_2[0] + t0_2[0] = T.tvm_warp_shuffle_down(mask_2[0], red_buf0_2[0], 2, 32, 32) + red_buf0_2[0] = red_buf0_2[0] + t0_2[0] + t0_2[0] = T.tvm_warp_shuffle_down(mask_2[0], red_buf0_2[0], 1, 32, 32) + red_buf0_2[0] = red_buf0_2[0] + t0_2[0] + red_buf0_2[0] = T.tvm_warp_shuffle(mask_2[0], red_buf0_2[0], 0, 32, 32) + red_buf_staging_1 = T.Buffer((32,), data=red_buf_staging, scope="shared") + if threadIdx_x % 32 == 0: + red_buf_staging_1[threadIdx_x // 32] = red_buf0_2[0] + T.tvm_storage_sync("shared") + if threadIdx_x < 32: + red_buf0_3[0] = red_buf_staging_1[threadIdx_x] + mask_3 = T.Buffer((1,), "uint32", data=mask, scope="local") + mask_3[0] = T.tvm_warp_activemask() + t0_3 = T.Buffer((1,), data=t0, scope="local") + t0_3[0] = T.tvm_warp_shuffle_down(mask_3[0], red_buf0_3[0], 16, 32, 32) + red_buf0_3[0] = red_buf0_3[0] + t0_3[0] + t0_3[0] = T.tvm_warp_shuffle_down(mask_3[0], red_buf0_3[0], 8, 32, 32) + red_buf0_3[0] = red_buf0_3[0] + t0_3[0] + t0_3[0] = T.tvm_warp_shuffle_down(mask_3[0], red_buf0_3[0], 4, 32, 32) + red_buf0_3[0] = red_buf0_3[0] + t0_3[0] + t0_3[0] = T.tvm_warp_shuffle_down(mask_3[0], red_buf0_3[0], 2, 32, 32) + red_buf0_3[0] = red_buf0_3[0] + t0_3[0] + t0_3[0] = T.tvm_warp_shuffle_down(mask_3[0], red_buf0_3[0], 1, 32, 32) + red_buf0_3[0] = red_buf0_3[0] + t0_3[0] + red_buf0_3[0] = T.tvm_warp_shuffle(mask_3[0], red_buf0_3[0], 0, 32, 32) + if threadIdx_x == 0: + B_1 = T.Buffer((1,), data=B.data) + B_1[0] = red_buf0_3[0] + + +class TestMultiGroupMultiWarpReduction(BaseCompare): + @T.prim_func + def before(A: T.Buffer((4, 128), "float32"), B: T.Buffer((4,), "float32")): + T.func_attr({"target": T.target("cuda", host="llvm")}) + threadIdx_y = T.launch_thread("threadIdx.y", 4) + cross_thread_B = T.allocate([1], "float32", "local") + threadIdx_x = T.launch_thread("threadIdx.x", 128) + cross_thread_B_1 = T.Buffer((1,), data=cross_thread_B, scope="local") + with T.attr( + T.comm_reducer(lambda x0, y0: x0 + y0, [T.float32(0)]), + "reduce_scope", + T.reinterpret("handle", T.uint64(0)), + ): + A_1 = T.Buffer((512,), data=A.data) + T.tvm_thread_allreduce( + T.uint32(1), + A_1[threadIdx_y * 128 + threadIdx_x], + T.bool(True), + cross_thread_B_1[0], + threadIdx_x, + ) + if threadIdx_x == 0: + B_1 = T.Buffer((4,), data=B.data) + B_1[threadIdx_y] = cross_thread_B_1[0] + + @T.prim_func + def expected(A: T.Buffer((4, 128), "float32"), B: T.Buffer((4,), "float32")): + T.func_attr({"target": T.target("cuda", host="llvm")}) + threadIdx_y = T.launch_thread("threadIdx.y", 4) + red_buf0 = T.allocate([1], "float32", "local") + threadIdx_x = T.launch_thread("threadIdx.x", 128) + red_buf0_3 = T.Buffer((1,), data=red_buf0, scope="local") + with T.attr( + T.comm_reducer(lambda x0, y0: x0 + y0, [T.float32(0)]), + "reduce_scope", + T.reinterpret("handle", T.uint64(0)), + ): + mask = T.allocate([1], "uint32", "local") + t0 = T.allocate([1], "float32", "local") + red_buf0_1 = T.allocate([1], "float32", "local") + mask_1 = T.allocate([1], "uint32", "local") + t0_1 = T.allocate([1], "float32", "local") + red_buf_staging = T.allocate([16], "float32", "shared") + red_buf0_2 = T.Buffer((1,), data=red_buf0_1, scope="local") + A_1 = T.Buffer((512,), data=A.data) + red_buf0_2[0] = A_1[threadIdx_y * 128 + threadIdx_x] + mask_2 = T.Buffer((1,), "uint32", data=mask_1, scope="local") + mask_2[0] = T.tvm_warp_activemask() + t0_2 = T.Buffer((1,), data=t0_1, scope="local") + t0_2[0] = T.tvm_warp_shuffle_down(mask_2[0], red_buf0_2[0], 16, 32, 32) + red_buf0_2[0] = 
red_buf0_2[0] + t0_2[0] + t0_2[0] = T.tvm_warp_shuffle_down(mask_2[0], red_buf0_2[0], 8, 32, 32) + red_buf0_2[0] = red_buf0_2[0] + t0_2[0] + t0_2[0] = T.tvm_warp_shuffle_down(mask_2[0], red_buf0_2[0], 4, 32, 32) + red_buf0_2[0] = red_buf0_2[0] + t0_2[0] + t0_2[0] = T.tvm_warp_shuffle_down(mask_2[0], red_buf0_2[0], 2, 32, 32) + red_buf0_2[0] = red_buf0_2[0] + t0_2[0] + t0_2[0] = T.tvm_warp_shuffle_down(mask_2[0], red_buf0_2[0], 1, 32, 32) + red_buf0_2[0] = red_buf0_2[0] + t0_2[0] + red_buf0_2[0] = T.tvm_warp_shuffle(mask_2[0], red_buf0_2[0], 32 * threadIdx_y, 32, 32) + red_buf_staging_1 = T.Buffer((16,), data=red_buf_staging, scope="shared") + if threadIdx_x % 32 == 0: + red_buf_staging_1[threadIdx_y * 4 + threadIdx_x // 32] = red_buf0_2[0] + T.tvm_storage_sync("shared") + if threadIdx_x < 16: + red_buf0_3[0] = red_buf_staging_1[threadIdx_x] + mask_3 = T.Buffer((1,), "uint32", data=mask, scope="local") + mask_3[0] = T.bitwise_and( + T.tvm_warp_activemask(), T.Cast("uint32", T.shift_left(15, threadIdx_y)) + ) + t0_3 = T.Buffer((1,), data=t0, scope="local") + t0_3[0] = T.tvm_warp_shuffle_down(mask_3[0], red_buf0_3[0], 2, 32, 32) + red_buf0_3[0] = red_buf0_3[0] + t0_3[0] + t0_3[0] = T.tvm_warp_shuffle_down(mask_3[0], red_buf0_3[0], 1, 32, 32) + red_buf0_3[0] = red_buf0_3[0] + t0_3[0] + red_buf0_3[0] = T.tvm_warp_shuffle(mask_3[0], red_buf0_3[0], 4 * threadIdx_y, 32, 32) + if threadIdx_x == 0: + B_1 = T.Buffer((4,), data=B.data) + B_1[threadIdx_y] = red_buf0_3[0] + + +class TestMultiGroupMultiWarpPredicatedReduction(BaseCompare): + @T.prim_func + def before(A: T.Buffer((2, 70), "float32"), B: T.Buffer((2,), "float32")): + T.func_attr({"target": T.target("cuda", host="llvm")}) + threadIdx_y = T.launch_thread("threadIdx.y", 2) + in_thread_B = T.allocate([1], "float32", "local") + cross_thread_B = T.allocate([1], "float32", "local") + threadIdx_x = T.launch_thread("threadIdx.x", 512) + in_thread_B_1 = T.Buffer((1,), data=in_thread_B, scope="local") + in_thread_B_1[0] = T.float32(0) + if threadIdx_x < 70: + A_1 = T.Buffer((140,), data=A.data) + in_thread_B_1[0] = in_thread_B_1[0] + A_1[threadIdx_y * 70 + threadIdx_x] + cross_thread_B_1 = T.Buffer((1,), data=cross_thread_B, scope="local") + with T.attr( + T.comm_reducer(lambda x0, y0: x0 + y0, [T.float32(0)]), + "reduce_scope", + T.reinterpret("handle", T.uint64(0)), + ): + T.tvm_thread_allreduce( + T.uint32(1), in_thread_B_1[0], T.bool(True), cross_thread_B_1[0], threadIdx_x + ) + if threadIdx_x == 0: + B_1 = T.Buffer((2,), data=B.data) + B_1[threadIdx_y] = cross_thread_B_1[0] + + @T.prim_func + def expected(A: T.Buffer((2, 70), "float32"), B: T.Buffer((2,), "float32")): + T.func_attr({"target": T.target("cuda", host="llvm")}) + threadIdx_y = T.launch_thread("threadIdx.y", 2) + in_thread_B = T.allocate([1], "float32", "local") + red_buf0 = T.allocate([1], "float32", "local") + threadIdx_x = T.launch_thread("threadIdx.x", 512) + in_thread_B_1 = T.Buffer((1,), data=in_thread_B, scope="local") + in_thread_B_1[0] = T.float32(0) + if threadIdx_x < 70: + A_1 = T.Buffer((140,), data=A.data) + in_thread_B_1[0] = in_thread_B_1[0] + A_1[threadIdx_y * 70 + threadIdx_x] + red_buf0_3 = T.Buffer((1,), data=red_buf0, scope="local") + with T.attr( + T.comm_reducer(lambda x0, y0: x0 + y0, [T.float32(0)]), + "reduce_scope", + T.reinterpret("handle", T.uint64(0)), + ): + mask = T.allocate([1], "uint32", "local") + t0 = T.allocate([1], "float32", "local") + red_buf0_1 = T.allocate([1], "float32", "local") + mask_1 = T.allocate([1], "uint32", "local") + t0_1 = 
T.allocate([1], "float32", "local") + red_buf_staging = T.allocate([32], "float32", "shared") + red_buf0_2 = T.Buffer((1,), data=red_buf0_1, scope="local") + red_buf0_2[0] = in_thread_B_1[0] + mask_2 = T.Buffer((1,), "uint32", data=mask_1, scope="local") + mask_2[0] = T.tvm_warp_activemask() + t0_2 = T.Buffer((1,), data=t0_1, scope="local") + t0_2[0] = T.tvm_warp_shuffle_down(mask_2[0], red_buf0_2[0], 16, 32, 32) + red_buf0_2[0] = red_buf0_2[0] + t0_2[0] + t0_2[0] = T.tvm_warp_shuffle_down(mask_2[0], red_buf0_2[0], 8, 32, 32) + red_buf0_2[0] = red_buf0_2[0] + t0_2[0] + t0_2[0] = T.tvm_warp_shuffle_down(mask_2[0], red_buf0_2[0], 4, 32, 32) + red_buf0_2[0] = red_buf0_2[0] + t0_2[0] + t0_2[0] = T.tvm_warp_shuffle_down(mask_2[0], red_buf0_2[0], 2, 32, 32) + red_buf0_2[0] = red_buf0_2[0] + t0_2[0] + t0_2[0] = T.tvm_warp_shuffle_down(mask_2[0], red_buf0_2[0], 1, 32, 32) + red_buf0_2[0] = red_buf0_2[0] + t0_2[0] + red_buf0_2[0] = T.tvm_warp_shuffle(mask_2[0], red_buf0_2[0], 32 * threadIdx_y, 32, 32) + red_buf_staging_1 = T.Buffer((32,), data=red_buf_staging, scope="shared") + if threadIdx_x % 32 == 0: + red_buf_staging_1[threadIdx_y * 16 + threadIdx_x // 32] = red_buf0_2[0] + T.tvm_storage_sync("shared") + if threadIdx_x < 32: + red_buf0_3[0] = red_buf_staging_1[threadIdx_x] + mask_3 = T.Buffer((1,), "uint32", data=mask, scope="local") + mask_3[0] = T.bitwise_and( + T.tvm_warp_activemask(), T.Cast("uint32", T.shift_left(65535, threadIdx_y)) + ) + t0_3 = T.Buffer((1,), data=t0, scope="local") + t0_3[0] = T.tvm_warp_shuffle_down(mask_3[0], red_buf0_3[0], 8, 32, 32) + red_buf0_3[0] = red_buf0_3[0] + t0_3[0] + t0_3[0] = T.tvm_warp_shuffle_down(mask_3[0], red_buf0_3[0], 4, 32, 32) + red_buf0_3[0] = red_buf0_3[0] + t0_3[0] + t0_3[0] = T.tvm_warp_shuffle_down(mask_3[0], red_buf0_3[0], 2, 32, 32) + red_buf0_3[0] = red_buf0_3[0] + t0_3[0] + t0_3[0] = T.tvm_warp_shuffle_down(mask_3[0], red_buf0_3[0], 1, 32, 32) + red_buf0_3[0] = red_buf0_3[0] + t0_3[0] + red_buf0_3[0] = T.tvm_warp_shuffle(mask_3[0], red_buf0_3[0], 16 * threadIdx_y, 32, 32) + if threadIdx_x == 0: + B_1 = T.Buffer((2,), data=B.data) + B_1[threadIdx_y] = red_buf0_3[0] + + if __name__ == "__main__": tvm.testing.main()