Skip to content

Commit

Permalink
refactor GetCclCommForParallelDesc series functions
Browse files Browse the repository at this point in the history
  • Loading branch information
Flowingsun007 committed Jan 25, 2025
1 parent 5636f1b commit 510524a
Show file tree
Hide file tree
Showing 9 changed files with 98 additions and 97 deletions.
11 changes: 7 additions & 4 deletions oneflow/core/job/eager_ccl_comm_manager.h
Original file line number Diff line number Diff line change
Expand Up @@ -30,10 +30,13 @@ class EagerCclCommMgr {
virtual void CreateCommFromPlan(const Plan& plan) = 0;
virtual bool IsAsyncLaunchCclLogicalKernel() const = 0;
virtual void SetAsyncLaunchCclLogicalKernel(bool val) = 0;
virtual ccl::CclComm GetCclCommForDevice(
const std::set<std::pair<int64_t, int64_t>>& device_set) = 0;
virtual ccl::CclComm GetCclCommForDeviceAndStreamName(
const std::set<std::pair<int64_t, int64_t>>& device_set, const std::string& stream_name) = 0;
virtual ccl::CclComm GetCclCommForParallelDesc(const ParallelDesc& parallel_desc) = 0;
virtual ccl::CclComm GetCclCommForParallelDescAndStreamName(const ParallelDesc& parallel_desc,
const std::string& stream_name) = 0;
virtual ccl::CclComm GetCclCommForParallelDescNdHierarchy(const ParallelDesc& parallel_desc,
const std::string& stream_name,
const int64_t this_parallel_id,
const std::string& comm_key) = 0;

template<typename T>
T* As() {
Expand Down
76 changes: 72 additions & 4 deletions oneflow/core/job/eager_nccl_comm_manager.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -156,16 +156,84 @@ ncclComm_t EagerNcclCommMgr::GetCommForDeviceAndStreamName(
return comm;
}

ccl::CclComm EagerNcclCommMgr::GetCclCommForDevice(
const std::set<std::pair<int64_t, int64_t>>& device_set) {
ccl::CclComm EagerNcclCommMgr::GetCclCommForParallelDesc(const ParallelDesc& parallel_desc) {
std::set<std::pair<int64_t, int64_t>> device_set;
FOR_RANGE(int64_t, parallel_id, 0, parallel_desc.parallel_num()) {
int64_t machine_id = CHECK_JUST(parallel_desc.MachineId4ParallelId(parallel_id));
int64_t device_id = CHECK_JUST(parallel_desc.DeviceId4ParallelId(parallel_id));
device_set.emplace(std::make_pair(machine_id, device_id));
}

ncclComm_t comm = GetCommForDevice(device_set);
std::shared_ptr<ccl::CommBase> ncclCommAdapter = std::make_shared<ccl::NcclCommAdapter>(comm);
ccl::CclComm ccl_comm(ncclCommAdapter);
return ccl_comm;
}

ccl::CclComm EagerNcclCommMgr::GetCclCommForDeviceAndStreamName(
const std::set<std::pair<int64_t, int64_t>>& device_set, const std::string& stream_name) {
ccl::CclComm EagerNcclCommMgr::GetCclCommForParallelDescAndStreamName(
const ParallelDesc& parallel_desc, const std::string& stream_name) {
std::set<std::pair<int64_t, int64_t>> device_set;
FOR_RANGE(int64_t, parallel_id, 0, parallel_desc.parallel_num()) {
int64_t machine_id = CHECK_JUST(parallel_desc.MachineId4ParallelId(parallel_id));
int64_t device_id = CHECK_JUST(parallel_desc.DeviceId4ParallelId(parallel_id));
device_set.emplace(std::make_pair(machine_id, device_id));
}

ncclComm_t comm = GetCommForDeviceAndStreamName(device_set, stream_name);
std::shared_ptr<ccl::CommBase> ncclCommAdapter = std::make_shared<ccl::NcclCommAdapter>(comm);
ccl::CclComm ccl_comm(ncclCommAdapter);
return ccl_comm;
}

ccl::CclComm EagerNcclCommMgr::GetCclCommForParallelDescNdHierarchy(
const ParallelDesc& parallel_desc, const std::string& stream_name,
const int64_t this_parallel_id, const std::string& comm_key) {
std::set<std::pair<int64_t, int64_t>> device_set;
const Shape& hierarchy = *parallel_desc.hierarchy();
CHECK_LE(hierarchy.NumAxes(), 2);

// 1D
if (hierarchy.NumAxes() == 1) {
// 1D hierarchy
for (int64_t parallel_id = 0; parallel_id < parallel_desc.parallel_num(); ++parallel_id) {
int64_t machine_id = CHECK_JUST(parallel_desc.MachineId4ParallelId(parallel_id));
int64_t device_id = CHECK_JUST(parallel_desc.DeviceId4ParallelId(parallel_id));
device_set.emplace(std::make_pair(machine_id, device_id));
}
} else if (hierarchy.NumAxes() == 2) {
// 2D hierarchy
CHECK(comm_key == "SameDim0" || comm_key == "SameDim1");
if (comm_key == "SameDim0") {
const int64_t num_groups = hierarchy.At(0);
const int64_t group_size = hierarchy.At(1);
CHECK_EQ(num_groups * group_size, parallel_desc.parallel_num());
const int64_t this_group_begin_parallel_id = this_parallel_id / group_size * group_size;
CHECK_EQ(this_group_begin_parallel_id % group_size, 0);
CHECK_LE(this_group_begin_parallel_id + group_size, parallel_desc.parallel_num());
for (int64_t id_in_group = 0; id_in_group < group_size; ++id_in_group) {
const int64_t parallel_id = this_group_begin_parallel_id + id_in_group;
const int64_t machine_id = CHECK_JUST(parallel_desc.MachineId4ParallelId(parallel_id));
const int64_t device_id = CHECK_JUST(parallel_desc.DeviceId4ParallelId(parallel_id));
device_set.emplace(std::make_pair(machine_id, device_id));
}
} else if (comm_key == "SameDim1") {
const int64_t group_size = hierarchy.At(0);
const int64_t num_groups = hierarchy.At(1);
CHECK_EQ(num_groups * group_size, parallel_desc.parallel_num());
const int64_t this_group_begin_parallel_id = this_parallel_id % num_groups;
CHECK_LT(this_group_begin_parallel_id + (group_size - 1) * num_groups,
parallel_desc.parallel_num());
for (int64_t id_in_group = 0; id_in_group < group_size; ++id_in_group) {
const int64_t parallel_id = this_group_begin_parallel_id + (id_in_group * num_groups);
const int64_t machine_id = CHECK_JUST(parallel_desc.MachineId4ParallelId(parallel_id));
const int64_t device_id = CHECK_JUST(parallel_desc.DeviceId4ParallelId(parallel_id));
device_set.emplace(std::make_pair(machine_id, device_id));
}
} else {
UNIMPLEMENTED();
}
}

ncclComm_t comm = GetCommForDeviceAndStreamName(device_set, stream_name);
std::shared_ptr<ccl::CommBase> ncclCommAdapter = std::make_shared<ccl::NcclCommAdapter>(comm);
ccl::CclComm ccl_comm(ncclCommAdapter);
Expand Down
12 changes: 7 additions & 5 deletions oneflow/core/job/eager_nccl_comm_manager.h
Original file line number Diff line number Diff line change
Expand Up @@ -49,11 +49,13 @@ class EagerNcclCommMgr final : public EagerCclCommMgr {
ncclComm_t GetCommForDevice(const std::set<std::pair<int64_t, int64_t>>& device_set);
ncclComm_t GetCommForDeviceAndStreamName(const std::set<std::pair<int64_t, int64_t>>& device_set,
const std::string& stream_name);
ccl::CclComm GetCclCommForDevice(
const std::set<std::pair<int64_t, int64_t>>& device_set) override;
ccl::CclComm GetCclCommForDeviceAndStreamName(
const std::set<std::pair<int64_t, int64_t>>& device_set,
const std::string& stream_name) override;
ccl::CclComm GetCclCommForParallelDesc(const ParallelDesc& parallel_desc) override;
ccl::CclComm GetCclCommForParallelDescAndStreamName(const ParallelDesc& parallel_desc,
const std::string& stream_name) override;
ccl::CclComm GetCclCommForParallelDescNdHierarchy(const ParallelDesc& parallel_desc,
const std::string& stream_name,
const int64_t this_parallel_id,
const std::string& comm_key) override;

void CreateCommFromPlan(const Plan& plan) override;
bool IsAsyncLaunchCclLogicalKernel() const override { return async_launch_nccl_logical_kernel_; }
Expand Down
9 changes: 2 additions & 7 deletions oneflow/core/kernel/nccl_send_recv_boxing_kernel.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -52,14 +52,9 @@ class NcclSendRecvBoxingKernel final : public Kernel {

void Init() const {
ParallelDesc parallel_desc(parallel_conf_);
std::set<std::pair<int64_t, int64_t>> device_set;
for (int64_t parallel_id = 0; parallel_id < parallel_desc.parallel_num(); ++parallel_id) {
int64_t machine_id = CHECK_JUST(parallel_desc.MachineId4ParallelId(parallel_id));
int64_t device_id = CHECK_JUST(parallel_desc.DeviceId4ParallelId(parallel_id));
device_set.emplace(std::make_pair(machine_id, device_id));
}
EagerCclCommMgr* comm_mgr = CHECK_NOTNULL(Singleton<EagerCclCommMgr>::Get());
ccl::CclComm ccl_comm = comm_mgr->GetCclCommForDeviceAndStreamName(device_set, stream_name_);
ccl::CclComm ccl_comm =
comm_mgr->GetCclCommForParallelDescAndStreamName(parallel_desc, stream_name_);
ccl_comm_.reset(new Comm(ccl_comm));
}

Expand Down
8 changes: 1 addition & 7 deletions oneflow/user/kernels/eager_nccl_s2s_kernel.cu
Original file line number Diff line number Diff line change
Expand Up @@ -41,16 +41,10 @@ class EagerNcclOpKernelCache final : public user_op::OpKernelCache {
void Init(user_op::KernelCacheContext* ctx) {
const std::string& parallel_conf_txt = ctx->Attr<std::string>("parallel_conf");
ParallelConf parallel_conf;
std::set<std::pair<int64_t, int64_t>> device_set;
CHECK(TxtString2PbMessage(parallel_conf_txt, &parallel_conf));
parallel_desc_ = SymbolOf(ParallelDesc(parallel_conf));
FOR_RANGE(int64_t, parallel_id, 0, parallel_desc_->parallel_num()) {
int64_t machine_id = CHECK_JUST(parallel_desc_->MachineId4ParallelId(parallel_id));
int64_t device_id = CHECK_JUST(parallel_desc_->DeviceId4ParallelId(parallel_id));
device_set.emplace(std::make_pair(machine_id, device_id));
}
EagerCclCommMgr* comm_mgr = CHECK_NOTNULL(Singleton<EagerCclCommMgr>::Get());
ccl_comm_ = comm_mgr->GetCclCommForDevice(device_set);
ccl_comm_ = comm_mgr->GetCclCommForParallelDesc(parallel_conf);
}

Symbol<ParallelDesc> parallel_desc_;
Expand Down
31 changes: 4 additions & 27 deletions oneflow/user/kernels/nccl_logical_2d_sbp_kernels.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -57,23 +57,12 @@ class NcclLogical2DSameDim0KernelCommState : public user_op::OpKernelState {
private:
void Init() {
CHECK(!is_init_);
std::set<std::pair<int64_t, int64_t>> device_set;
const Shape& hierarchy = *parallel_desc_.hierarchy();
CHECK_EQ(hierarchy.NumAxes(), 2);
const int64_t num_groups = hierarchy.At(0);
const int64_t group_size = hierarchy.At(1);
CHECK_EQ(num_groups * group_size, parallel_desc_.parallel_num());
const int64_t this_group_begin_parallel_id = this_parallel_id_ / group_size * group_size;
CHECK_EQ(this_group_begin_parallel_id % group_size, 0);
CHECK_LE(this_group_begin_parallel_id + group_size, parallel_desc_.parallel_num());
for (int64_t id_in_group = 0; id_in_group < group_size; ++id_in_group) {
const int64_t parallel_id = this_group_begin_parallel_id + id_in_group;
const int64_t machine_id = CHECK_JUST(parallel_desc_.MachineId4ParallelId(parallel_id));
const int64_t device_id = CHECK_JUST(parallel_desc_.DeviceId4ParallelId(parallel_id));
device_set.emplace(std::make_pair(machine_id, device_id));
}
EagerCclCommMgr* comm_mgr = CHECK_NOTNULL(Singleton<EagerCclCommMgr>::Get());
ccl_comm_ = comm_mgr->GetCclCommForDeviceAndStreamName(device_set, stream_name_);
ccl_comm_ = comm_mgr->GetCclCommForParallelDescNdHierarchy(parallel_desc_, stream_name_,
this_parallel_id_, "SameDim0");
num_ranks_ = group_size;
is_init_ = true;
}
Expand Down Expand Up @@ -420,23 +409,11 @@ class NcclLogical2DSameDim1KernelCommState final : public user_op::OpKernelState

ccl::CclComm ccl_comm() {
if (!is_init_) {
std::set<std::pair<int64_t, int64_t>> device_set;
const Shape& hierarchy = *parallel_desc_.hierarchy();
CHECK_EQ(hierarchy.NumAxes(), 2);
const int64_t group_size = hierarchy.At(0);
const int64_t num_groups = hierarchy.At(1);
CHECK_EQ(num_groups * group_size, parallel_desc_.parallel_num());
const int64_t this_group_begin_parallel_id = this_parallel_id_ % num_groups;
CHECK_LT(this_group_begin_parallel_id + (group_size - 1) * num_groups,
parallel_desc_.parallel_num());
for (int64_t id_in_group = 0; id_in_group < group_size; ++id_in_group) {
const int64_t parallel_id = this_group_begin_parallel_id + (id_in_group * num_groups);
const int64_t machine_id = CHECK_JUST(parallel_desc_.MachineId4ParallelId(parallel_id));
const int64_t device_id = CHECK_JUST(parallel_desc_.DeviceId4ParallelId(parallel_id));
device_set.emplace(std::make_pair(machine_id, device_id));
}
EagerCclCommMgr* comm_mgr = CHECK_NOTNULL(Singleton<EagerCclCommMgr>::Get());
ccl_comm_ = comm_mgr->GetCclCommForDeviceAndStreamName(device_set, stream_name_);
ccl_comm_ = comm_mgr->GetCclCommForParallelDescNdHierarchy(parallel_desc_, stream_name_,
this_parallel_id_, "SameDim1");
is_init_ = true;
}
return ccl_comm_;
Expand Down
31 changes: 2 additions & 29 deletions oneflow/user/kernels/nccl_logical_fusion_kernel.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -127,45 +127,17 @@ class NcclLogicalFusionKernelState : public user_op::OpKernelState {
private:
void InitComm() {
CHECK(!is_init_);
std::set<std::pair<int64_t, int64_t>> device_set;
const Shape& hierarchy = *parallel_desc_.hierarchy();

if (hierarchy.NumAxes() == 1) {
num_ranks_ = parallel_desc_.parallel_num();
FOR_RANGE(int64_t, parallel_id, 0, parallel_desc_.parallel_num()) {
int64_t machine_id = CHECK_JUST(parallel_desc_.MachineId4ParallelId(parallel_id));
int64_t device_id = CHECK_JUST(parallel_desc_.DeviceId4ParallelId(parallel_id));
device_set.emplace(std::make_pair(machine_id, device_id));
}
} else if (hierarchy.NumAxes() == 2) {
CHECK(comm_key_ == "SameDim0" || comm_key_ == "SameDim1");
if (comm_key_ == "SameDim0") {
const int64_t num_groups = hierarchy.At(0);
const int64_t group_size = hierarchy.At(1);
CHECK_EQ(num_groups * group_size, parallel_desc_.parallel_num());
const int64_t this_group_begin_parallel_id = this_parallel_id_ / group_size * group_size;
CHECK_EQ(this_group_begin_parallel_id % group_size, 0);
CHECK_LE(this_group_begin_parallel_id + group_size, parallel_desc_.parallel_num());
for (int64_t id_in_group = 0; id_in_group < group_size; ++id_in_group) {
const int64_t parallel_id = this_group_begin_parallel_id + id_in_group;
const int64_t machine_id = CHECK_JUST(parallel_desc_.MachineId4ParallelId(parallel_id));
const int64_t device_id = CHECK_JUST(parallel_desc_.DeviceId4ParallelId(parallel_id));
device_set.emplace(std::make_pair(machine_id, device_id));
}
num_ranks_ = group_size;
} else if (comm_key_ == "SameDim1") {
const int64_t group_size = hierarchy.At(0);
const int64_t num_groups = hierarchy.At(1);
CHECK_EQ(num_groups * group_size, parallel_desc_.parallel_num());
const int64_t this_group_begin_parallel_id = this_parallel_id_ % num_groups;
CHECK_LT(this_group_begin_parallel_id + (group_size - 1) * num_groups,
parallel_desc_.parallel_num());
for (int64_t id_in_group = 0; id_in_group < group_size; ++id_in_group) {
const int64_t parallel_id = this_group_begin_parallel_id + (id_in_group * num_groups);
const int64_t machine_id = CHECK_JUST(parallel_desc_.MachineId4ParallelId(parallel_id));
const int64_t device_id = CHECK_JUST(parallel_desc_.DeviceId4ParallelId(parallel_id));
device_set.emplace(std::make_pair(machine_id, device_id));
}
num_ranks_ = group_size;
} else {
UNIMPLEMENTED();
Expand All @@ -175,7 +147,8 @@ class NcclLogicalFusionKernelState : public user_op::OpKernelState {
}

EagerCclCommMgr* comm_mgr = CHECK_NOTNULL(Singleton<EagerCclCommMgr>::Get());
ccl_comm_ = comm_mgr->GetCclCommForDeviceAndStreamName(device_set, stream_name_);
ccl_comm_ = comm_mgr->GetCclCommForParallelDescNdHierarchy(parallel_desc_, stream_name_,
this_parallel_id_, comm_key_);
is_init_ = true;
}

Expand Down
8 changes: 1 addition & 7 deletions oneflow/user/kernels/nccl_logical_kernels.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -44,14 +44,8 @@ class NcclLogicalKernelCommState : public user_op::OpKernelState {

ccl::CclComm ccl_comm() {
if (!is_init_) {
std::set<std::pair<int64_t, int64_t>> device_set;
FOR_RANGE(int64_t, parallel_id, 0, parallel_desc_.parallel_num()) {
int64_t machine_id = CHECK_JUST(parallel_desc_.MachineId4ParallelId(parallel_id));
int64_t device_id = CHECK_JUST(parallel_desc_.DeviceId4ParallelId(parallel_id));
device_set.emplace(std::make_pair(machine_id, device_id));
}
EagerCclCommMgr* comm_mgr = CHECK_NOTNULL(Singleton<EagerCclCommMgr>::Get());
ccl_comm_ = comm_mgr->GetCclCommForDeviceAndStreamName(device_set, stream_name_);
ccl_comm_ = comm_mgr->GetCclCommForParallelDescAndStreamName(parallel_desc_, stream_name_);
is_init_ = true;
}
return ccl_comm_;
Expand Down
9 changes: 2 additions & 7 deletions oneflow/user/kernels/nccl_logical_send_recv_kernel.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -126,14 +126,9 @@ NcclLogicalSendRecvState::NcclLogicalSendRecvState(user_op::KernelInitContext* c
}

void NcclLogicalSendRecvState::InitComm() const {
std::set<std::pair<int64_t, int64_t>> device_set;
for (int64_t parallel_id = 0; parallel_id < parallel_desc_->parallel_num(); ++parallel_id) {
int64_t machine_id = CHECK_JUST(parallel_desc_->MachineId4ParallelId(parallel_id));
int64_t device_id = CHECK_JUST(parallel_desc_->DeviceId4ParallelId(parallel_id));
device_set.emplace(std::make_pair(machine_id, device_id));
}
EagerCclCommMgr* comm_mgr = CHECK_NOTNULL(Singleton<EagerCclCommMgr>::Get());
ccl::CclComm ccl_comm = comm_mgr->GetCclCommForDeviceAndStreamName(device_set, stream_name_);
ccl::CclComm ccl_comm =
comm_mgr->GetCclCommForParallelDescAndStreamName(*parallel_desc_.get(), stream_name_);
ccl_comm_.reset(new Comm(ccl_comm));
}

Expand Down

0 comments on commit 510524a

Please sign in to comment.