diff --git a/paddle/phi/core/distributed/auto_parallel/reshard/nd_mesh_reshard_function.cc b/paddle/phi/core/distributed/auto_parallel/reshard/nd_mesh_reshard_function.cc
index 28d71ff93b49ca..6ba383ef4bfa1d 100644
--- a/paddle/phi/core/distributed/auto_parallel/reshard/nd_mesh_reshard_function.cc
+++ b/paddle/phi/core/distributed/auto_parallel/reshard/nd_mesh_reshard_function.cc
@@ -99,6 +99,8 @@ void SameNdMeshReshardFunction::Eval(phi::DeviceContext* dev_ctx,
   SetValue(out, in.value());
   SetDistProps(out, in.dims(), in_dist_attr);
 
+  VLOG(0) << "Same nd mesh, in_dist_attr: " << in_dist_attr;
+  VLOG(0) << "Same nd mesh, out_dist_attr: " << out_dist_attr;
   // 1. change all the partial status to replicated status if needed
   if (in_dist_attr.is_partial()) {
     // Copy in_dist_attr.partial_status to avoid overwriting the value of
diff --git a/paddle/phi/core/distributed/auto_parallel/reshard/r_to_p_reshard_function.cc b/paddle/phi/core/distributed/auto_parallel/reshard/r_to_p_reshard_function.cc
index 4d827bb9d72f6c..1a3d4015a69192 100644
--- a/paddle/phi/core/distributed/auto_parallel/reshard/r_to_p_reshard_function.cc
+++ b/paddle/phi/core/distributed/auto_parallel/reshard/r_to_p_reshard_function.cc
@@ -49,7 +49,6 @@ void RToPReshardFunction::Eval(phi::DeviceContext* dev_ctx,
   VLOG(3) << "Call RToPReshardFunction Eval";
   const auto& out_process_mesh = out_dist_attr.process_mesh();
   int64_t local_rank = GetCurRankCoordInMesh(out_process_mesh)[0];
-  IntArray shape(in.dims().Get(), in.dims().size());
   const auto& in_reduce_type = out_dist_attr.partial_status().at(0);
 
   if (local_rank != 0) {
@@ -59,6 +58,7 @@ void RToPReshardFunction::Eval(phi::DeviceContext* dev_ctx,
         dev_ctx, Assign, in.value(), GetMutableTensor(out));
   } else {
     // reset the physical tensor to zero
+    IntArray shape(in.local_dims().Get(), in.local_dims().size());
     RESHARD_FUNCTOR(
         dev_ctx, Full, in.dtype(), shape, 0, GetMutableTensor(out));
   }
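Note on the r_to_p_reshard_function.cc hunks above: `Full` zero-fills the rank-local buffer, so its shape must come from `in.local_dims()` (the physical, per-rank shape) rather than `in.dims()` (the global, logical shape); moving the `IntArray` into the zero-fill branch also stops building it on the path that never uses it. Below is a minimal standalone sketch of the distinction — `LocalDims`, the even-split assumption, and the example shapes are illustrative, not Paddle's API:

```cpp
#include <cstdint>
#include <iostream>
#include <vector>

// Hypothetical helper (not Paddle's API): derive the per-rank physical shape
// from a global shape sharded evenly along one axis.
std::vector<int64_t> LocalDims(std::vector<int64_t> global_dims,
                               int shard_axis,
                               int num_ranks) {
  global_dims[shard_axis] /= num_ranks;  // assumes an even split
  return global_dims;
}

int main() {
  // A [8, 4] tensor sharded on axis 0 across 2 ranks holds a [4, 4] buffer
  // per rank. Zero-filling with the global shape would size the buffer at
  // twice what the rank actually owns -- hence local_dims() in the fix.
  auto local = LocalDims({8, 4}, /*shard_axis=*/0, /*num_ranks=*/2);
  std::cout << local[0] << "x" << local[1] << "\n";  // prints 4x4
}
```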
diff --git a/paddle/phi/core/distributed/auto_parallel/reshard/reshard_utils.h b/paddle/phi/core/distributed/auto_parallel/reshard/reshard_utils.h
index 15ecef53d03433..efa872cf72292c 100644
--- a/paddle/phi/core/distributed/auto_parallel/reshard/reshard_utils.h
+++ b/paddle/phi/core/distributed/auto_parallel/reshard/reshard_utils.h
@@ -75,14 +75,14 @@ CommContext* CreateOrGetCommContext(const DeviceContext& dev_ctx,
 #define RESHARD_FUNCTOR_IMPL(dev_ctx, fn_name, dtype, ...)               \
   do {                                                                   \
     if (phi::CPUContext::classof(dev_ctx)) {                             \
-      VLOG(4) << "Call `" << #fn_name << "` in Resharding on GPU.";      \
+      VLOG(4) << "Call `" << #fn_name << "` in Resharding on CPU.";      \
       PD_VISIT_BOOL_AND_FLOATING_AND_INTEGRAL_TYPES(                     \
           dtype, #fn_name, ([&] {                                        \
             fn_name<data_t>(static_cast<const CPUContext&>(*dev_ctx),    \
                             __VA_ARGS__);                                \
           }));                                                           \
     } else if (phi::GPUContext::classof(dev_ctx)) {                      \
-      VLOG(4) << "Call `" << #fn_name << "` in Resharding on CPU.";      \
+      VLOG(4) << "Call `" << #fn_name << "` in Resharding on GPU.";      \
       PD_VISIT_BOOL_AND_FLOATING_AND_INTEGRAL_TYPES(                     \
           dtype, #fn_name, ([&] {                                        \
             fn_name<data_t>(static_cast<const GPUContext&>(*dev_ctx),    \
diff --git a/paddle/phi/core/distributed/auto_parallel/reshard/same_status_reshard_function.cc b/paddle/phi/core/distributed/auto_parallel/reshard/same_status_reshard_function.cc
index 5740e14ae833a4..695b38519c7ccf 100644
--- a/paddle/phi/core/distributed/auto_parallel/reshard/same_status_reshard_function.cc
+++ b/paddle/phi/core/distributed/auto_parallel/reshard/same_status_reshard_function.cc
@@ -66,7 +66,6 @@ void SameStatusReshardFunction::Eval(phi::DeviceContext* dev_ctx,
                                      const DistTensor& in,
                                      const TensorDistAttr& out_dist_attr,
                                      DistTensor* out) {
-  VLOG(3) << "Call SameStatusReshardFunction Eval";
   const auto& in_dist_attr = in.dist_attr();
   const auto& in_process_mesh = in_dist_attr.process_mesh();
   const auto& in_process_ids = in_process_mesh.process_ids();
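Note on the reshard_utils.h hunk above: the change is purely to the log text — the CPU branch previously logged "on GPU" and vice versa; the dispatch itself was already correct. A minimal sketch of the same branch-on-context-type pattern follows (simplified: phi dispatches via `classof`, not RTTI, and `DeviceContext`/`DispatchByContext` here are stand-ins, not Paddle's types):

```cpp
#include <iostream>

// Illustrative sketch of dispatching a functor on the dynamic device-context
// type while logging the branch actually taken. The bug fixed above was
// cosmetic -- swapped labels -- so each label below matches its branch.
struct DeviceContext { virtual ~DeviceContext() = default; };
struct CPUContext : DeviceContext {};
struct GPUContext : DeviceContext {};

template <typename Fn>
void DispatchByContext(DeviceContext* dev_ctx, Fn&& fn) {
  if (auto* cpu = dynamic_cast<CPUContext*>(dev_ctx)) {
    std::cout << "Resharding on CPU.\n";  // label matches the CPU branch
    fn(*cpu);
  } else if (auto* gpu = dynamic_cast<GPUContext*>(dev_ctx)) {
    std::cout << "Resharding on GPU.\n";  // label matches the GPU branch
    fn(*gpu);
  }
}

int main() {
  CPUContext cpu;
  DispatchByContext(&cpu, [](auto&) { /* run the typed kernel here */ });
}
```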
diff --git a/paddle/phi/infermeta/spmd_rules/elementwise.cc b/paddle/phi/infermeta/spmd_rules/elementwise.cc
index 9ec18bdaf50ce5..c7ca31b821cc25 100644
--- a/paddle/phi/infermeta/spmd_rules/elementwise.cc
+++ b/paddle/phi/infermeta/spmd_rules/elementwise.cc
@@ -224,8 +224,8 @@ SpmdInfo ElementwiseBinaryInferSpmd(const DistMetaTensor& x,
   out_dist_attr.set_dims_mapping(out_dims_mapping);
   // Step2.3: Update inputs' dims mapping with merged one.
-  TensorDistAttr x_dist_attr_dst(x_dist_attr_src);
-  TensorDistAttr y_dist_attr_dst(y_dist_attr_src);
+  TensorDistAttr x_dist_attr_dst = CopyTensorDistAttrForOutput(x_dist_attr_src);
+  TensorDistAttr y_dist_attr_dst = CopyTensorDistAttrForOutput(y_dist_attr_src);
   x_dist_attr_dst.set_dims_mapping(
       GetDimsMappingForAxes(x_axes, axis_to_dim_map));
   y_dist_attr_dst.set_dims_mapping(
       GetDimsMappingForAxes(y_axes, axis_to_dim_map));
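Note on the elementwise.cc hunk above: copy-constructing the destination attr from the source drags per-tensor state into a freshly inferred attr, which is what switching to `CopyTensorDistAttrForOutput` avoids before `set_dims_mapping` fills in the merged mapping. The sketch below only illustrates the idea; the field set and exactly what the real helper preserves are assumptions, not Paddle's `TensorDistAttr`:

```cpp
#include <cstdint>
#include <map>
#include <vector>

// Illustrative only: a plain copy would carry input-only state (stale dims
// mapping, partial status) into the destination attr; a dedicated "copy for
// output" helper keeps just the placement context that carries over, and the
// caller then sets dims_mapping explicitly.
struct DistAttr {
  std::vector<int64_t> process_mesh;      // shared placement context
  std::vector<int64_t> dims_mapping;      // per-tensor: must be re-derived
  std::map<int64_t, int> partial_status;  // per-tensor: must be re-derived
};

DistAttr CopyDistAttrForOutput(const DistAttr& src) {
  DistAttr dst;
  dst.process_mesh = src.process_mesh;
  // dims_mapping and partial_status are intentionally left empty; the
  // elementwise rule above fills dims_mapping via GetDimsMappingForAxes.
  return dst;
}

int main() {
  DistAttr src{{0, 1, 2, 3}, {0, -1}, {{0, 1}}};
  DistAttr dst = CopyDistAttrForOutput(src);
  return dst.dims_mapping.empty() ? 0 : 1;  // dst starts clean
}
```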
diff --git a/paddle/phi/infermeta/spmd_rules/matmul.cc b/paddle/phi/infermeta/spmd_rules/matmul.cc
index 60c7acacf0478c..c50da0aec13397 100644
--- a/paddle/phi/infermeta/spmd_rules/matmul.cc
+++ b/paddle/phi/infermeta/spmd_rules/matmul.cc
@@ -310,6 +310,22 @@ SpmdInfo MatmulGradInferSpmd(const DistMetaTensor& x,
                                      y.dist_attr()));
   };
 
+  auto confirm_dist_attr_with_arg_same_fn = [&](const ArgDistAttr& x_dist_attr,
+                                                const ArgDistAttr& y_dist_attr,
+                                                const char* debug_msg) {
+    const auto& x_single_dist_attr = get_attr(x_dist_attr);
+    const auto& y_single_dist_attr = get_attr(y_dist_attr);
+    PADDLE_ENFORCE_EQ(
+        DistAttrsAreBasicallyEqual(x_single_dist_attr, y_single_dist_attr),
+        true,
+        phi::errors::Unavailable("The matmul grad infer spmd `%s` verify "
+                                 "error: left dist attr is %s, "
+                                 "right dist attr is %s.",
+                                 debug_msg,
+                                 x_single_dist_attr,
+                                 y_single_dist_attr));
+  };
+
   // TODO(chenweihang): Now for the case where the forward input generates
   // an intermediate value through Reshard, because the intermediate value
   // is destroyed after the forward calculation is completed, the x and y
@@ -343,6 +359,9 @@ SpmdInfo MatmulGradInferSpmd(const DistMetaTensor& x,
       confirm_dist_attr_same_fn(
           dy_spmd_info.first[0], out_grad, "trans x&y: dy-out_grad");
       confirm_dist_attr_same_fn(dy_spmd_info.first[1], x, "trans x&y: dy-x");
+      return {
+          {dy_spmd_info.first[1], dx_spmd_info.first[0], dx_spmd_info.first[1]},
+          {dx_spmd_info.second[0], dy_spmd_info.second[0]}};
     } else {
       // X'Y: dX = YG', dY = XG
       dx_spmd_info =
           MatmulInferSpmd(y, out_grad, /*trans_x=*/false, /*trans_y=*/true);
       dy_spmd_info =
           MatmulInferSpmd(x, out_grad, /*trans_x=*/false, /*trans_y=*/false);
       confirm_dist_attr_same_fn(dy_spmd_info.first[0], x, "trans x: dy-x");
       confirm_dist_attr_same_fn(
           dy_spmd_info.first[1], out_grad, "trans x: dy-out_grad");
+      return {
+          {dy_spmd_info.first[0], dx_spmd_info.first[0], dx_spmd_info.first[1]},
+          {dx_spmd_info.second[0], dy_spmd_info.second[0]}};
     }
   } else {
     if (trans_y) {
@@ -369,25 +391,25 @@ SpmdInfo MatmulGradInferSpmd(const DistMetaTensor& x,
       confirm_dist_attr_same_fn(
           dy_spmd_info.first[0], out_grad, "trans y: dy-out_grad");
       confirm_dist_attr_same_fn(dy_spmd_info.first[1], x, "trans y: dy-x");
+      return {
+          {dy_spmd_info.first[1], dx_spmd_info.first[1], dx_spmd_info.first[0]},
+          {dx_spmd_info.second[0], dy_spmd_info.second[0]}};
     } else {
       // XY: dX = GY', dY = X'G
       dx_spmd_info =
           MatmulInferSpmd(out_grad, y, /*trans_x=*/false, /*trans_y=*/true);
       dy_spmd_info =
           MatmulInferSpmd(x, out_grad, /*trans_x=*/true, /*trans_y=*/false);
-      confirm_dist_attr_same_fn(
-          dx_spmd_info.first[0], out_grad, "no trans: dx-out_grad");
       confirm_dist_attr_same_fn(dx_spmd_info.first[1], y, "no trans: dx-y");
       confirm_dist_attr_same_fn(dy_spmd_info.first[0], x, "no trans: dy-x");
-      confirm_dist_attr_same_fn(
-          dy_spmd_info.first[1], out_grad, "no trans: dy-out_grad");
+      confirm_dist_attr_with_arg_same_fn(dx_spmd_info.first[0],
+                                         dy_spmd_info.first[1],
+                                         "no trans: dy-out_grad");
+      return {
+          {dy_spmd_info.first[0], dx_spmd_info.first[1], dx_spmd_info.first[0]},
+          {dx_spmd_info.second[0], dy_spmd_info.second[0]}};
     }
   }
-
-  // Here we assume that the input dist attr is unchanged after inference,
-  // and only return the gradient dist attr
-  return {{x.dist_attr(), y.dist_attr(), out_grad.dist_attr()},
-          {dx_spmd_info.second[0], dy_spmd_info.second[0]}};
 }
 
 }  // namespace distributed
diff --git a/test/auto_parallel/test_semi_auto_parallel_hybrid_strategy.py b/test/auto_parallel/test_semi_auto_parallel_hybrid_strategy.py
index eefc47d6967163..ae4dbb2dea8cf2 100644
--- a/test/auto_parallel/test_semi_auto_parallel_hybrid_strategy.py
+++ b/test/auto_parallel/test_semi_auto_parallel_hybrid_strategy.py
@@ -19,7 +19,11 @@ class TestSemiAutoParallelHybridStrategy(test_base.CommunicationTestDistBase):
 
     def setUp(self):
-        super().setUp(num_of_devices=2, timeout=120, nnode=2)
+        super().setUp(
+            num_of_devices=2,
+            timeout=120,
+            nnode=2,
+        )
         self._default_envs = {
             "dtype": "float32",
             "seed": "2023",
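Closing note on the matmul.cc change above: each branch of `MatmulGradInferSpmd` now returns the dist attrs it actually inferred for x, y, and out_grad (plus the gradient attrs) instead of echoing the unmodified input attrs, and in the no-trans case out_grad's attr is cross-checked between the dx and dy inferences via `confirm_dist_attr_with_arg_same_fn` rather than against the possibly-reshared argument. The slot ordering in each return follows from the transpose-aware gradient formulas; the sketch below merely enumerates them (G = out_grad, ' = transpose) and is not Paddle code:

```cpp
#include <cstdio>

// Illustrative sketch: the gradient formulas behind each MatmulGradInferSpmd
// branch. They determine the operand order of the two MatmulInferSpmd calls,
// and hence which slots of dx_spmd_info / dy_spmd_info hold the inferred
// attrs for x, y, and out_grad in each branch's return statement.
void PrintGradFormulas(bool trans_x, bool trans_y) {
  if (trans_x && trans_y) {
    std::printf("Z = X'Y': dX = Y'G', dY = G'X'\n");
  } else if (trans_x) {
    std::printf("Z = X'Y : dX = YG',  dY = XG\n");
  } else if (trans_y) {
    std::printf("Z = XY' : dX = GY,   dY = G'X\n");
  } else {
    std::printf("Z = XY  : dX = GY',  dY = X'G\n");
  }
}

int main() {
  PrintGradFormulas(false, false);  // the "no trans" branch in the diff
}
```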