From 65e30465662be65a3d32670dc124cedb52b02872 Mon Sep 17 00:00:00 2001 From: luyang Date: Sun, 26 Jan 2025 14:13:47 +0000 Subject: [PATCH] refactor ccl::AllToAll --- .../kernel/nccl_send_recv_boxing_kernel.cpp | 5 ++- .../cuda/cuda_all_to_all.cpp | 42 +++++++++++-------- .../include/all_to_all.h | 3 +- .../kernels/nccl_logical_send_recv_kernel.cpp | 4 +- 4 files changed, 31 insertions(+), 23 deletions(-) diff --git a/oneflow/core/kernel/nccl_send_recv_boxing_kernel.cpp b/oneflow/core/kernel/nccl_send_recv_boxing_kernel.cpp index 63694047bbe..825c43a36a6 100644 --- a/oneflow/core/kernel/nccl_send_recv_boxing_kernel.cpp +++ b/oneflow/core/kernel/nccl_send_recv_boxing_kernel.cpp @@ -119,13 +119,14 @@ void NcclSendRecvBoxingKernel::ForwardDataContent(KernelContext* ctx) const { } } - if (this->has_input() && this->has_output()) { + if (this->has_input() || this->has_output()) { std::unique_ptr all_to_all = ccl::NewCollectiveCommunication( ctx->stream()->device_type(), data_type, data_type, parallel_num); void* send_buf = reinterpret_cast(buf_ptr); void* recv_buf = reinterpret_cast(buf_ptr + recv_offset); all_to_all->Launch(ctx->stream(), send_buf, send_elem_cnts.data(), send_offsets.data(), - recv_buf, recv_elem_cnts.data(), recv_offsets.data(), ccl_comm); + recv_buf, recv_elem_cnts.data(), recv_offsets.data(), ccl_comm, + this->has_input(), this->has_output()); } if (!this->has_output()) { return; } diff --git a/oneflow/user/kernels/collective_communication/cuda/cuda_all_to_all.cpp b/oneflow/user/kernels/collective_communication/cuda/cuda_all_to_all.cpp index c32363d5c4d..313a14f00e0 100644 --- a/oneflow/user/kernels/collective_communication/cuda/cuda_all_to_all.cpp +++ b/oneflow/user/kernels/collective_communication/cuda/cuda_all_to_all.cpp @@ -63,31 +63,37 @@ class CudaAllToAll final : public AllToAll { void Launch(ep::Stream* stream, void* send, const void* send_counts, const void* send_offsets, void* recv, const void* recv_counts, const void* recv_offsets, - const ccl::CclComm& ccl_comm) const override { + const ccl::CclComm& ccl_comm, const bool has_input, + const bool has_output) const override { ncclComm_t* nccl_comm = reinterpret_cast(ccl_comm.getComm()); int64_t* send_counts_ptr = static_cast(const_cast(send_counts)); int64_t* recv_counts_ptr = static_cast(const_cast(recv_counts)); int64_t* send_offsets_ptr = static_cast(const_cast(send_offsets)); int64_t* recv_offsets_ptr = static_cast(const_cast(recv_offsets)); - OF_NCCL_CHECK(ncclGroupStart()); - for (int64_t i = 0; i < this->rank_count_; ++i) { - uint64_t send_offset = static_cast(send_offsets_ptr[i]); - uint64_t send_count = static_cast(send_counts_ptr[i]); - char* send_ptr = static_cast(send) + send_offset; - if (send_count > 0) { - OF_NCCL_CHECK(ncclSend(send_ptr, send_count, this->nccl_send_dtype_, i, *nccl_comm, - stream->As()->cuda_stream())); - } - - uint64_t recv_offset = static_cast(recv_offsets_ptr[i]); - uint64_t recv_count = static_cast(recv_counts_ptr[i]); - char* recv_ptr = static_cast(recv) + recv_offset; - if (recv_count > 0) { - OF_NCCL_CHECK(ncclRecv(recv_ptr, recv_count, this->nccl_recv_dtype_, i, *nccl_comm, - stream->As()->cuda_stream())); + if (has_input || has_output) { + OF_NCCL_CHECK(ncclGroupStart()); + for (int64_t i = 0; i < this->rank_count_; ++i) { + if (has_input) { + const uint64_t send_count = static_cast(send_counts_ptr[i]); + if (send_count > 0) { + uint64_t send_offset = static_cast(send_offsets_ptr[i]); + char* send_ptr = static_cast(send) + send_offset; + OF_NCCL_CHECK(ncclSend(send_ptr, send_count, this->nccl_send_dtype_, i, *nccl_comm, + stream->As()->cuda_stream())); + } + } + if (has_output) { + const uint64_t recv_count = static_cast(recv_counts_ptr[i]); + if (recv_count > 0) { + uint64_t recv_offset = static_cast(recv_offsets_ptr[i]); + char* recv_ptr = static_cast(recv) + recv_offset; + OF_NCCL_CHECK(ncclRecv(recv_ptr, recv_count, this->nccl_recv_dtype_, i, *nccl_comm, + stream->As()->cuda_stream())); + } + } } + OF_NCCL_CHECK(ncclGroupEnd()); } - OF_NCCL_CHECK(ncclGroupEnd()); } private: diff --git a/oneflow/user/kernels/collective_communication/include/all_to_all.h b/oneflow/user/kernels/collective_communication/include/all_to_all.h index af7bd3fb3b4..81c35ce80ab 100644 --- a/oneflow/user/kernels/collective_communication/include/all_to_all.h +++ b/oneflow/user/kernels/collective_communication/include/all_to_all.h @@ -37,7 +37,8 @@ class AllToAll : public CollectiveCommunication { // for unbalanced all to all(e.g. nccl all2all using send/recv; hccl HcclAlltoAllV) virtual void Launch(ep::Stream* stream, void* send, const void* send_counts, const void* send_offsets, void* recv, const void* recv_counts, - const void* recv_offsets, const ccl::CclComm& ccl_comm) const = 0; + const void* recv_offsets, const ccl::CclComm& ccl_comm, const bool has_input, + const bool has_output) const = 0; }; inline bool IsAllToAllRegistered(DeviceType device_type) { diff --git a/oneflow/user/kernels/nccl_logical_send_recv_kernel.cpp b/oneflow/user/kernels/nccl_logical_send_recv_kernel.cpp index d8e6b396cb2..fa06e22fd44 100644 --- a/oneflow/user/kernels/nccl_logical_send_recv_kernel.cpp +++ b/oneflow/user/kernels/nccl_logical_send_recv_kernel.cpp @@ -193,14 +193,14 @@ void NcclLogicalSendRecv::Compute(user_op::KernelComputeContext* ctx, user_op::O in_tensor_slice_copier_vec.at(i)->Copy(ctx->stream(), send_in_ptr.at(i), in->dptr()); } } - const int64_t parallel_id = ctx->parallel_ctx().parallel_id(); std::unique_ptr all_to_all = ccl::NewCollectiveCommunication( ctx->stream()->device_type(), data_type, data_type, parallel_num); void* send_buf = reinterpret_cast(buf_ptr); void* recv_buf = reinterpret_cast(buf_ptr + recv_offset); all_to_all->Launch(ctx->stream(), send_buf, send_elem_cnts.data(), send_offsets.data(), recv_buf, - recv_elem_cnts.data(), recv_offsets.data(), ccl_comm); + recv_elem_cnts.data(), recv_offsets.data(), ccl_comm, /*has_input=*/true, + /*has_output=*/true); const std::vector>& out_tensor_slice_copier_vec = kernel_state->out_tensor_slice_copier_vec();