Skip to content

Commit

Permalink
Merge branch 'dev_refactor_xccl_primitive' of github.com:Oneflow-Inc/…
Browse files Browse the repository at this point in the history
…oneflow into dev_refactor_xccl_primitive
  • Loading branch information
Flowingsun007 committed Jan 25, 2025
2 parents 510524a + 269dd3e commit c56c527
Show file tree
Hide file tree
Showing 5 changed files with 6 additions and 6 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -41,7 +41,7 @@ class CpuRecvImpl final : public Recv {
}

void Launch(ep::Stream* stream, void* out, size_t elem_cnt, int64_t src,
CclComm ccl_comm) const override {
const CclComm& ccl_comm) const override {
Launch(stream, out, elem_cnt, src);
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,7 @@ class Recv : public CollectiveCommunication {
virtual void Launch(ep::Stream* stream, void* out, size_t elem_cnt, int64_t src) const = 0;

virtual void Launch(ep::Stream* stream, void* out, size_t elem_cnt, int64_t src,
CclComm ccl_comm) const = 0;
const CclComm& ccl_comm) const = 0;
};

inline bool IsRecvRegistered(DeviceType device_type) {
Expand Down
2 changes: 1 addition & 1 deletion oneflow/user/kernels/eager_nccl_s2s_kernel.cu
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,7 @@ class EagerNcclOpKernelCache final : public user_op::OpKernelCache {
~EagerNcclOpKernelCache() override = default;

Symbol<ParallelDesc> parallel_desc() const { return parallel_desc_; }
ccl::CclComm ccl_comm() const { return ccl_comm_; }
const ccl::CclComm& ccl_comm() const { return ccl_comm_; }

private:
void Init(user_op::KernelCacheContext* ctx) {
Expand Down
4 changes: 2 additions & 2 deletions oneflow/user/kernels/nccl_logical_2d_sbp_kernels.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,7 @@ class NcclLogical2DSameDim0KernelCommState : public user_op::OpKernelState {
}
~NcclLogical2DSameDim0KernelCommState() override = default;

ccl::CclComm ccl_comm() {
const ccl::CclComm& ccl_comm() const {
if (!is_init_) { Init(); }
return ccl_comm_;
}
Expand Down Expand Up @@ -407,7 +407,7 @@ class NcclLogical2DSameDim1KernelCommState final : public user_op::OpKernelState
}
~NcclLogical2DSameDim1KernelCommState() = default;

ccl::CclComm ccl_comm() {
const ccl::CclComm& ccl_comm() {
if (!is_init_) {
const Shape& hierarchy = *parallel_desc_.hierarchy();
CHECK_EQ(hierarchy.NumAxes(), 2);
Expand Down
2 changes: 1 addition & 1 deletion oneflow/user/kernels/nccl_logical_kernels.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,7 @@ class NcclLogicalKernelCommState : public user_op::OpKernelState {
}
~NcclLogicalKernelCommState() override = default;

ccl::CclComm ccl_comm() {
const ccl::CclComm& ccl_comm() {
if (!is_init_) {
EagerCclCommMgr* comm_mgr = CHECK_NOTNULL(Singleton<EagerCclCommMgr>::Get());
ccl_comm_ = comm_mgr->GetCclCommForParallelDescAndStreamName(parallel_desc_, stream_name_);
Expand Down

0 comments on commit c56c527

Please sign in to comment.