【Operator Mechanism】Pr68945 _C_ops.c_concat in dynamic graph bug fix (#…
Function-Samuel authored Jan 26, 2025
1 parent 7b72b3f commit 39e87e1
Showing 8 changed files with 87 additions and 16 deletions.
1 change: 1 addition & 0 deletions paddle/fluid/framework/new_executor/pir_interpreter.cc
@@ -593,6 +593,7 @@ void PirInterpreter::UpdateNcclOpNum() {
"pd_op.barrier_grad",
"pd_op.alltoall_grad",
"pd_op.global_gather_grad",
"pd_op.c_concat_grad",
"pd_op.distributed_fused_lamb_grad",
"pd_op.margin_cross_entropy_grad",
"pd_op.sync_batch_norm_grad",
6 changes: 6 additions & 0 deletions paddle/phi/api/generator/api_gen.py
@@ -455,7 +455,13 @@ def source_include(header_file_path):
#include "paddle/phi/api/profiler/event_tracing.h"
#include "paddle/phi/api/profiler/supplement_tracing.h"
#if defined(PADDLE_WITH_NCCL) || defined(PADDLE_WITH_RCCL)
#include "paddle/phi/core/distributed/comm_context_manager.h"
#include "paddle/phi/core/distributed/nccl_comm_context.h"
#endif
#ifdef PADDLE_WITH_DISTRIBUTE
#include "paddle/phi/core/distributed/store/store_utils.h"
#include "paddle/phi/infermeta/spmd_rules/rules.h"
#include "paddle/phi/core/distributed/auto_parallel/reshard/reshard_utils.h"
#endif
64 changes: 64 additions & 0 deletions paddle/phi/api/generator/dist_api_gen.py
@@ -82,6 +82,46 @@
}}
"""

NCCL_COMMCONTEXT_INIT = """
#if defined(PADDLE_WITH_NCCL) || defined(PADDLE_WITH_RCCL)
  const auto & comm_context_manager_ = phi::distributed::CommContextManager::GetInstance();
  if (nranks > 1 && !comm_context_manager_.Has(std::to_string(ring_id))) {{
    auto store = phi::distributed::CreateOrGetGlobalTCPStore();
    phi::distributed::CommContextManager::CreateNCCLCommContext(
        store, std::to_string(ring_id), rank, nranks);
  }}
#endif
"""

SET_NCCL_COMMCONTEXT = """
#if defined(PADDLE_WITH_NCCL) || defined(PADDLE_WITH_RCCL)
  const auto & comm_context_manager = phi::distributed::CommContextManager::GetInstance();
  phi::distributed::NCCLCommContext* comm_context = nullptr;
  if (comm_context_manager.Has(std::to_string(ring_id))) {{
    comm_context = static_cast<phi::distributed::NCCLCommContext *>(
        comm_context_manager.Get(std::to_string(ring_id)));
    PADDLE_ENFORCE_NE(
        comm_context,
        nullptr,
        common::errors::Unavailable(
            "NCCLCommContext is nullptr, collective op should "
            "have ring_id(%d) attr.",
            ring_id));
    if (!comm_context->GetDevContext() || !comm_context->GetDevContext()->GetCommContext()) {{
      auto kernel_res = phi::KernelFactory::Instance().SelectKernelOrThrowError(
          "{}", {{kernel_backend, kernel_layout, kernel_data_type}}, true);
      if (FLAGS_low_precision_op_list) {{
        phi::KernelFactory::Instance().AddToLowPrecisionKernelList("{}", kernel_data_type);
      }}
      Backend act_kernel_backend = kernel_res.has_fallback_cpu ? Backend::CPU : kernel_backend;
      auto* dev_context = GetDeviceContextByBackend(act_kernel_backend);
      dev_context->SetCommContext(comm_context);
    }}
  }}
#endif
"""

# 1. InferSPMD
SINGLE_DIST_META_IN_TEMPLATE = """
auto meta_dist_input_{name} = MakeDistMetaTensor(*{name}.impl());"""
@@ -861,6 +901,24 @@ def process_data_type_args(args_item):
            input_args=input_args, mesh=mesh, kernel_code=kernel_select_code
        )

        # Current initialization only considers the case where the op's parameters
        # contain ring_id, nranks and rank. Other cases will be addressed in the future.
        if 'ring_id' in self.attrs['names']:
            if (
                'rank' in self.attrs['names']
                and 'nranks' in self.attrs['names']
            ):
                if_condition_code = (
                    if_condition_code
                    + '\n'
                    + self.generate_nccl_commcontext_init_code()
                )
            if_condition_code = (
                if_condition_code
                + '\n'
                + self.generate_set_nccl_commcontext_code()
            )

        return kernel_key_item_init + if_condition_code

    def generate_specialized_infer_spmd_code(self) -> str:
@@ -1322,6 +1380,12 @@ def generate_kernel_selection_code(self) -> str:
            self.api, self.kernel['func'][0], self.kernel['func'][0]
        )

    def generate_nccl_commcontext_init_code(self) -> str:
        return NCCL_COMMCONTEXT_INIT.format(self.kernel['func'][0])

    def generate_set_nccl_commcontext_code(self) -> str:
        return SET_NCCL_COMMCONTEXT.format(self.kernel['func'][0], self.api)

    def generate_reshard_input_code(self) -> str:
        input_reshard_code = ""
        if self.generate_infer_spmd is True:
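For reference, the two NCCL templates added above are plain Python format strings: the doubled braces ({{ and }}) survive str.format as literal braces in the emitted C++, while the bare {} placeholders receive the kernel function name (and, for SET_NCCL_COMMCONTEXT, the API name) passed in by the new generate_nccl_commcontext_init_code / generate_set_nccl_commcontext_code helpers. A minimal sketch of that mechanism, using a hypothetical shortened template rather than the full ones from this diff:

```python
# Minimal sketch (not part of the commit): how the generator fills the
# NCCL comm-context templates.  SHORT_TEMPLATE is a hypothetical, trimmed-down
# version; only the brace-escaping behaviour of str.format is being illustrated.
SHORT_TEMPLATE = """
  if (comm_context_manager.Has(std::to_string(ring_id))) {{
    auto kernel_res = phi::KernelFactory::Instance().SelectKernelOrThrowError(
        "{}", {{kernel_backend, kernel_layout, kernel_data_type}}, true);
  }}
"""

# "{}" is replaced by the kernel name; "{{" / "}}" come out as single braces
# in the generated C++ source.
print(SHORT_TEMPLATE.format("c_concat"))
```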
6 changes: 6 additions & 0 deletions paddle/phi/api/generator/dist_bw_api_gen.py
@@ -517,7 +517,13 @@ def source_include(header_file_path, fw_header_file_path):
#include "paddle/phi/api/profiler/event_tracing.h"
#include "paddle/phi/api/profiler/supplement_tracing.h"
#if defined(PADDLE_WITH_NCCL) || defined(PADDLE_WITH_RCCL)
#include "paddle/phi/core/distributed/comm_context_manager.h"
#include "paddle/phi/core/distributed/nccl_comm_context.h"
#endif
#ifdef PADDLE_WITH_DISTRIBUTE
#include "paddle/phi/core/distributed/store/store_utils.h"
#include "paddle/phi/infermeta/spmd_rules/rules.h"
#include "paddle/phi/core/distributed/auto_parallel/reshard/reshard_utils.h"
#endif
6 changes: 6 additions & 0 deletions paddle/phi/ops/yaml/backward.yaml
@@ -321,6 +321,12 @@
    data_type : out_grad
  no_need_buffer : input

- backward_op : c_concat_grad
  forward : c_concat (Tensor x, int rank, int nranks, int ring_id, bool use_calc_stream, bool use_model_parallel) -> Tensor(out)
  args : (Tensor out_grad, int rank = 0, int nranks = 1, int ring_id = 0, bool use_model_parallel = true)
  output : Tensor(x_grad)
  invoke: c_split(out_grad, rank, nranks, ring_id, use_model_parallel)

- backward_op : cast_grad
  forward : cast (Tensor x, DataType dtype) -> Tensor(out)
  args : (Tensor x, Tensor out_grad)
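The new c_concat_grad entry defines the gradient by invoking c_split: the forward op concatenates every rank's shard along the last axis, so the gradient with respect to one shard is simply that rank's slice of out_grad. A hypothetical single-process NumPy sketch of this relationship (not the actual Paddle kernels, which run collectively across ranks):

```python
import numpy as np

# Hypothetical reference implementations, single process, no communication.
def c_concat_ref(shards):                 # forward: concatenate per-rank shards
    return np.concatenate(shards, axis=-1)

def c_split_ref(full, rank, nranks):      # backward building block: take this rank's slice
    return np.split(full, nranks, axis=-1)[rank]

nranks = 2
shards = [np.full((2, 3), float(r)) for r in range(nranks)]
out = c_concat_ref(shards)                               # shape (2, 6)
out_grad = np.ones_like(out)
x_grad_rank1 = c_split_ref(out_grad, rank=1, nranks=nranks)   # shape (2, 3)
assert x_grad_rank1.shape == shards[1].shape
```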
1 change: 1 addition & 0 deletions paddle/phi/ops/yaml/op_compat.yaml
@@ -521,6 +521,7 @@
  drop_empty_grad : [input_grad]

- op : c_concat
  backward: c_concat_grad
  inputs :
    x : X
  outputs :
2 changes: 1 addition & 1 deletion paddle/phi/ops/yaml/ops.yaml
@@ -762,7 +762,7 @@
    param : [x, nranks]
  kernel :
    func : c_concat
  traits : paddle::dialect::ForwardOnlyTrait
  backward: c_concat_grad

- op : c_identity
  args : (Tensor x, int ring_id, bool use_calc_stream, bool use_model_parallel)
17 changes: 2 additions & 15 deletions python/paddle/distributed/fleet/layers/mpu/mp_ops.py
@@ -25,6 +25,7 @@
    LayerHelper,
    _create_tensor,
    in_dynamic_mode,
    in_dynamic_or_pir_mode,
    in_pir_mode,
)
from paddle.nn import Layer
@@ -139,21 +140,7 @@ def _c_concat(tensor, group=None):
    rank = group.rank
    nranks = group.nranks

    if in_dynamic_mode():
        return _legacy_C_ops.c_concat(
            tensor,
            'ring_id',
            ring_id,
            'use_calc_stream',
            True,
            'rank',
            rank,
            'nranks',
            nranks,
            'use_model_parallel',
            True,
        )
    elif in_pir_mode():
    if in_dynamic_or_pir_mode():
        return _C_ops.c_concat(tensor, rank, nranks, ring_id, True, True)
    else:
        op_type = 'c_concat'
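With the removal of the `_legacy_C_ops` branch above, eager (dynamic-graph) execution now reaches the same `_C_ops.c_concat(tensor, rank, nranks, ring_id, True, True)` call as PIR mode. A hedged usage sketch of the fixed path; the process-group setup is illustrative and assumes a multi-GPU job launched with `paddle.distributed.launch`:

```python
# Illustrative only: exercising the unified dynamic-graph path of _c_concat.
import paddle
import paddle.distributed as dist
from paddle.distributed.fleet.layers.mpu import mp_ops

dist.init_parallel_env()                        # assumes a launched collective job
group = dist.new_group(list(range(dist.get_world_size())))

local = paddle.randn([4, 8])                    # this rank's shard
full = mp_ops._c_concat(local, group=group)     # gathered along the last axis
print(full.shape)                               # [4, 8 * group.nranks]
```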
