From 71cd6e2af61b964ef786d1af643306398fa5c546 Mon Sep 17 00:00:00 2001 From: Chen Weihang Date: Mon, 11 Sep 2023 05:12:52 +0000 Subject: [PATCH 1/2] update matmul spmd rule name --- paddle/phi/infermeta/spmd_rules/matmul.cc | 18 +++++++++--------- paddle/phi/infermeta/spmd_rules/matmul.h | 14 +++++++------- paddle/phi/infermeta/spmd_rules/rules.h | 4 ++-- 3 files changed, 18 insertions(+), 18 deletions(-) diff --git a/paddle/phi/infermeta/spmd_rules/matmul.cc b/paddle/phi/infermeta/spmd_rules/matmul.cc index 088f9ab16363ad..a29f23b88038cd 100644 --- a/paddle/phi/infermeta/spmd_rules/matmul.cc +++ b/paddle/phi/infermeta/spmd_rules/matmul.cc @@ -114,10 +114,10 @@ void FillMatmulOperandNotation(const int x_ndim, ////////////////// InferMeta(Contains SPMD) Functions ////////////////// -SpmdInfo MatmulSpmdInferForward(const DistMetaTensor& x, - const DistMetaTensor& y, - bool trans_x, - bool trans_y) { +SpmdInfo MatmulInferSpmd(const DistMetaTensor& x, + const DistMetaTensor& y, + bool trans_x, + bool trans_y) { // Step0: verify input args based on matmul logic auto x_shape = phi::vectorize(x.dims()); auto y_shape = phi::vectorize(y.dims()); @@ -221,11 +221,11 @@ SpmdInfo MatmulSpmdInferForward(const DistMetaTensor& x, return {{x_dist_attr_dst, y_dist_attr_dst}, {output_dist_attr_dst}}; } -SpmdInfo MatmulSpmdInferBackward(const DistMetaTensor& x, - const DistMetaTensor& y, - const DistMetaTensor& out, - bool trans_x, - bool trans_y) { +SpmdInfo MatmulInferSpmdReverse(const DistMetaTensor& x, + const DistMetaTensor& y, + const DistMetaTensor& out, + bool trans_x, + bool trans_y) { auto out_shape = phi::vectorize(out.dims()); int out_ndim = out_shape.size(); diff --git a/paddle/phi/infermeta/spmd_rules/matmul.h b/paddle/phi/infermeta/spmd_rules/matmul.h index 64cfba26a7445c..6bb36f4bd3d34b 100644 --- a/paddle/phi/infermeta/spmd_rules/matmul.h +++ b/paddle/phi/infermeta/spmd_rules/matmul.h @@ -22,16 +22,16 @@ limitations under the License. */ namespace phi { namespace distributed { -SpmdInfo MatmulSpmdInferForward(const DistMetaTensor& x, +SpmdInfo MatmulInferSpmd(const DistMetaTensor& x, + const DistMetaTensor& y, + bool trans_x, + bool trans_y); + +SpmdInfo MatmulInferSpmdReverse(const DistMetaTensor& x, const DistMetaTensor& y, + const DistMetaTensor& out, bool trans_x, bool trans_y); -SpmdInfo MatmulSpmdInferBackward(const DistMetaTensor& x, - const DistMetaTensor& y, - const DistMetaTensor& out, - bool trans_x, - bool trans_y); - } // namespace distributed } // namespace phi diff --git a/paddle/phi/infermeta/spmd_rules/rules.h b/paddle/phi/infermeta/spmd_rules/rules.h index 5ec2f212ec65bf..7d4087a24b7007 100644 --- a/paddle/phi/infermeta/spmd_rules/rules.h +++ b/paddle/phi/infermeta/spmd_rules/rules.h @@ -40,8 +40,8 @@ namespace distributed { // matmul rule PD_REGISTER_SPMD_RULE(matmul, - PD_INFER_SPMD(phi::distributed::MatmulSpmdInferForward), - PD_INFER_SPMD(phi::distributed::MatmulSpmdInferBackward)); + PD_INFER_SPMD(phi::distributed::MatmulInferSpmd), + PD_INFER_SPMD(phi::distributed::MatmulInferSpmdReverse)); } // namespace distributed } // namespace phi From 29078465a23279ec18d036d4a01289c28b86d49c Mon Sep 17 00:00:00 2001 From: Chen Weihang Date: Mon, 11 Sep 2023 07:03:13 +0000 Subject: [PATCH 2/2] update matmul yaml config --- paddle/phi/api/yaml/legacy_ops.yaml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/paddle/phi/api/yaml/legacy_ops.yaml b/paddle/phi/api/yaml/legacy_ops.yaml index 4c151374c68936..d81f1358133d82 100755 --- a/paddle/phi/api/yaml/legacy_ops.yaml +++ b/paddle/phi/api/yaml/legacy_ops.yaml @@ -651,7 +651,7 @@ output : Tensor infer_meta : func : MatmulInferMeta - spmd_rule : MatmulSpmdInferForward + spmd_rule : MatmulInferSpmd kernel : func : matmul backward : matmul_grad