refactor ccl::AllToAll
Flowingsun007 committed Jan 26, 2025
1 parent 0dc6cbc commit 65e3046
Showing 4 changed files with 31 additions and 23 deletions.
oneflow/core/kernel/nccl_send_recv_boxing_kernel.cpp (5 changes: 3 additions & 2 deletions)
@@ -119,13 +119,14 @@ void NcclSendRecvBoxingKernel::ForwardDataContent(KernelContext* ctx) const {
     }
   }
 
-  if (this->has_input() && this->has_output()) {
+  if (this->has_input() || this->has_output()) {
     std::unique_ptr<ccl::AllToAll> all_to_all = ccl::NewCollectiveCommunication<ccl::AllToAll>(
         ctx->stream()->device_type(), data_type, data_type, parallel_num);
     void* send_buf = reinterpret_cast<void*>(buf_ptr);
     void* recv_buf = reinterpret_cast<void*>(buf_ptr + recv_offset);
     all_to_all->Launch(ctx->stream(), send_buf, send_elem_cnts.data(), send_offsets.data(),
-                       recv_buf, recv_elem_cnts.data(), recv_offsets.data(), ccl_comm);
+                       recv_buf, recv_elem_cnts.data(), recv_offsets.data(), ccl_comm,
+                       this->has_input(), this->has_output());
   }
 
   if (!this->has_output()) { return; }
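The widened guard is the heart of the refactor: in send/recv boxing the input and output placements can be disjoint, so a rank may hold only one side of the transfer. Under the old `&&` test such a rank skipped `Launch` entirely while its peers posted matching receives for it. A hedged sketch of the input-only case (the call mirrors the kernel above; the scenario itself is illustrative):

```cpp
// Illustrative only: a rank whose placement holds the input but not the output.
// has_input() == true, has_output() == false, so the old `&&` guard was false
// and this rank never joined the grouped NCCL calls, stranding peers that had
// posted ncclRecv from it. With `||` plus the explicit flags, it enters Launch
// and posts only its sends.
all_to_all->Launch(ctx->stream(), send_buf, send_elem_cnts.data(), send_offsets.data(),
                   recv_buf, recv_elem_cnts.data(), recv_offsets.data(), ccl_comm,
                   /*has_input=*/true, /*has_output=*/false);
```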
@@ -63,31 +63,37 @@ class CudaAllToAll final : public AllToAll {

   void Launch(ep::Stream* stream, void* send, const void* send_counts, const void* send_offsets,
               void* recv, const void* recv_counts, const void* recv_offsets,
-              const ccl::CclComm& ccl_comm) const override {
+              const ccl::CclComm& ccl_comm, const bool has_input,
+              const bool has_output) const override {
     ncclComm_t* nccl_comm = reinterpret_cast<ncclComm_t*>(ccl_comm.getComm());
     int64_t* send_counts_ptr = static_cast<int64_t*>(const_cast<void*>(send_counts));
     int64_t* recv_counts_ptr = static_cast<int64_t*>(const_cast<void*>(recv_counts));
     int64_t* send_offsets_ptr = static_cast<int64_t*>(const_cast<void*>(send_offsets));
     int64_t* recv_offsets_ptr = static_cast<int64_t*>(const_cast<void*>(recv_offsets));
-    OF_NCCL_CHECK(ncclGroupStart());
-    for (int64_t i = 0; i < this->rank_count_; ++i) {
-      uint64_t send_offset = static_cast<uint64_t>(send_offsets_ptr[i]);
-      uint64_t send_count = static_cast<uint64_t>(send_counts_ptr[i]);
-      char* send_ptr = static_cast<char*>(send) + send_offset;
-      if (send_count > 0) {
-        OF_NCCL_CHECK(ncclSend(send_ptr, send_count, this->nccl_send_dtype_, i, *nccl_comm,
-                               stream->As<ep::CudaStream>()->cuda_stream()));
-      }
-
-      uint64_t recv_offset = static_cast<uint64_t>(recv_offsets_ptr[i]);
-      uint64_t recv_count = static_cast<uint64_t>(recv_counts_ptr[i]);
-      char* recv_ptr = static_cast<char*>(recv) + recv_offset;
-      if (recv_count > 0) {
-        OF_NCCL_CHECK(ncclRecv(recv_ptr, recv_count, this->nccl_recv_dtype_, i, *nccl_comm,
-                               stream->As<ep::CudaStream>()->cuda_stream()));
-      }
-    }
-    OF_NCCL_CHECK(ncclGroupEnd());
+    if (has_input || has_output) {
+      OF_NCCL_CHECK(ncclGroupStart());
+      for (int64_t i = 0; i < this->rank_count_; ++i) {
+        if (has_input) {
+          const uint64_t send_count = static_cast<uint64_t>(send_counts_ptr[i]);
+          if (send_count > 0) {
+            uint64_t send_offset = static_cast<uint64_t>(send_offsets_ptr[i]);
+            char* send_ptr = static_cast<char*>(send) + send_offset;
+            OF_NCCL_CHECK(ncclSend(send_ptr, send_count, this->nccl_send_dtype_, i, *nccl_comm,
+                                   stream->As<ep::CudaStream>()->cuda_stream()));
+          }
+        }
+        if (has_output) {
+          const uint64_t recv_count = static_cast<uint64_t>(recv_counts_ptr[i]);
+          if (recv_count > 0) {
+            uint64_t recv_offset = static_cast<uint64_t>(recv_offsets_ptr[i]);
+            char* recv_ptr = static_cast<char*>(recv) + recv_offset;
+            OF_NCCL_CHECK(ncclRecv(recv_ptr, recv_count, this->nccl_recv_dtype_, i, *nccl_comm,
+                                   stream->As<ep::CudaStream>()->cuda_stream()));
+          }
+        }
+      }
+      OF_NCCL_CHECK(ncclGroupEnd());
+    }
   }
 
  private:
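For reference, a minimal usage sketch of the new signature, assuming the caller has already packed per-peer buffers the way the kernels in this commit do (counts are element counts, offsets are byte offsets into the packed buffer, both arrays of length `parallel_num`):

```cpp
// Sketch under assumed setup: stream, ccl_comm, data_type, parallel_num, the
// packed send_buf/recv_buf, and the count/offset vectors all come from the
// calling kernel, as in the two kernels touched by this commit.
std::unique_ptr<ccl::AllToAll> all_to_all = ccl::NewCollectiveCommunication<ccl::AllToAll>(
    stream->device_type(), data_type, data_type, parallel_num);
all_to_all->Launch(stream, send_buf, send_elem_cnts.data(), send_offsets.data(),
                   recv_buf, recv_elem_cnts.data(), recv_offsets.data(), ccl_comm,
                   /*has_input=*/true, /*has_output=*/true);
```

Bracketing the per-peer `ncclSend`/`ncclRecv` pairs with `ncclGroupStart()`/`ncclGroupEnd()` lets NCCL fuse them into a single transfer and removes any dependence on call order, which is why the implementation can simply omit one side when `has_input` or `has_output` is false.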
@@ -37,7 +37,8 @@ class AllToAll : public CollectiveCommunication {
   // for unbalanced all to all(e.g. nccl all2all using send/recv; hccl HcclAlltoAllV)
   virtual void Launch(ep::Stream* stream, void* send, const void* send_counts,
                       const void* send_offsets, void* recv, const void* recv_counts,
-                      const void* recv_offsets, const ccl::CclComm& ccl_comm) const = 0;
+                      const void* recv_offsets, const ccl::CclComm& ccl_comm, const bool has_input,
+                      const bool has_output) const = 0;
 };
 
 inline bool IsAllToAllRegistered(DeviceType device_type) {
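Since every backend implements this pure virtual, the flags propagate uniformly. A hypothetical skeleton for a new device (class name and body are illustrative; construction and registration are omitted):

```cpp
// Hypothetical backend sketch, not part of this commit. Only the Launch
// signature is taken from the interface above; the body states the contract.
class MyDeviceAllToAll final : public AllToAll {
 public:
  void Launch(ep::Stream* stream, void* send, const void* send_counts, const void* send_offsets,
              void* recv, const void* recv_counts, const void* recv_offsets,
              const ccl::CclComm& ccl_comm, const bool has_input,
              const bool has_output) const override {
    // Post a send to each peer only when has_input is true, and a receive from
    // each peer only when has_output is true, mirroring CudaAllToAll above.
  }
};
```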
oneflow/user/kernels/nccl_logical_send_recv_kernel.cpp (4 changes: 2 additions & 2 deletions)
@@ -193,14 +193,14 @@ void NcclLogicalSendRecv::Compute(user_op::KernelComputeContext* ctx, user_op::O
       in_tensor_slice_copier_vec.at(i)->Copy(ctx->stream(), send_in_ptr.at(i), in->dptr());
     }
   }
-  const int64_t parallel_id = ctx->parallel_ctx().parallel_id();
 
   std::unique_ptr<ccl::AllToAll> all_to_all = ccl::NewCollectiveCommunication<ccl::AllToAll>(
       ctx->stream()->device_type(), data_type, data_type, parallel_num);
   void* send_buf = reinterpret_cast<void*>(buf_ptr);
   void* recv_buf = reinterpret_cast<void*>(buf_ptr + recv_offset);
   all_to_all->Launch(ctx->stream(), send_buf, send_elem_cnts.data(), send_offsets.data(), recv_buf,
-                     recv_elem_cnts.data(), recv_offsets.data(), ccl_comm);
+                     recv_elem_cnts.data(), recv_offsets.data(), ccl_comm, /*has_input=*/true,
+                     /*has_output=*/true);
 
   const std::vector<std::shared_ptr<TensorSliceCopier>>& out_tensor_slice_copier_vec =
       kernel_state->out_tensor_slice_copier_vec();
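This call site pins both flags to true, which matches the visible kernel body: every rank copies input slices into the send buffer and output slices out of the recv buffer. The count/offset convention is the one the CUDA implementation consumes, counts in elements and offsets in bytes. A hypothetical helper showing how such byte offsets could be derived from element counts (not code from this commit; assumes OneFlow's `GetSizeOfDataType`):

```cpp
// Hypothetical helper: prefix-sum per-peer element counts into byte offsets
// for the packed all-to-all buffer.
std::vector<int64_t> send_offsets(parallel_num, 0);
int64_t byte_acc = 0;
for (int64_t i = 0; i < parallel_num; ++i) {
  send_offsets[i] = byte_acc;
  byte_acc += send_elem_cnts[i] * GetSizeOfDataType(data_type);
}
```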
