From aeb5fbab0648feceb031939597d707c2567e4e36 Mon Sep 17 00:00:00 2001 From: DesmonDay <908660116@qq.com> Date: Mon, 14 Mar 2022 06:57:00 +0000 Subject: [PATCH 01/11] add out_size shape for graph_send_recv --- paddle/fluid/operators/graph_send_recv_op.cc | 10 +++- paddle/phi/infermeta/ternary.cc | 18 ++++++- .../cpu/graph_send_recv_grad_kernel.cc | 34 +++++-------- .../phi/kernels/cpu/graph_send_recv_kernel.cc | 17 +++++-- .../gpu/graph_send_recv_grad_kernel.cu | 12 ++--- .../phi/kernels/gpu/graph_send_recv_kernel.cu | 26 ++++++++-- .../phi/kernels/graph_send_recv_grad_kernel.h | 2 +- paddle/phi/kernels/graph_send_recv_kernel.h | 1 + paddle/phi/ops/compat/graph_send_recv_sig.cc | 11 +++++ .../unittests/test_graph_send_recv_op.py | 28 +++++++++++ .../incubate/operators/graph_send_recv.py | 48 ++++++++++++++++--- 11 files changed, 162 insertions(+), 45 deletions(-) diff --git a/paddle/fluid/operators/graph_send_recv_op.cc b/paddle/fluid/operators/graph_send_recv_op.cc index f7c006dbcb1a9a..b714d3b1d96b11 100644 --- a/paddle/fluid/operators/graph_send_recv_op.cc +++ b/paddle/fluid/operators/graph_send_recv_op.cc @@ -38,7 +38,7 @@ class GraphSendRecvGradOp : public framework::OperatorWithKernel { using framework::OperatorWithKernel::OperatorWithKernel; void InferShape(framework::InferShapeContext* ctx) const override { - auto in_dims = ctx->GetInputDim(framework::GradVarName("Out")); + auto in_dims = ctx->GetInputDim("X"); ctx->SetOutputDim(framework::GradVarName("X"), in_dims); } @@ -68,6 +68,12 @@ class GraphSendRecvOpMaker : public framework::OpProtoAndCheckerMaker { "tensors of Dst_index.") .SetDefault("SUM") .InEnum({"SUM", "MEAN", "MIN", "MAX"}); + AddAttr( + "out_size", + "(int64_t, default -1)" + "Define the first dimension of Output tensor." + "If set default -1, then the shape of Out is the same with X.") + .SetDefault(-1); AddComment(R"DOC( Graph Learning Send_Recv combine operator. @@ -93,6 +99,7 @@ class GraphSendRecvGradOpMaker : public framework::SingleGradOpMaker { op->SetType("graph_send_recv_grad"); op->SetInput("Src_index", this->Input("Src_index")); op->SetInput("Dst_index", this->Input("Dst_index")); + op->SetInput("X", this->Input("X")); if (BOOST_GET_CONST(std::string, this->GetAttr("pool_type")) == "MEAN") { op->SetInput("Dst_count", this->Output("Dst_count")); @@ -100,7 +107,6 @@ class GraphSendRecvGradOpMaker : public framework::SingleGradOpMaker { if (BOOST_GET_CONST(std::string, this->GetAttr("pool_type")) == "MIN" || BOOST_GET_CONST(std::string, this->GetAttr("pool_type")) == "MAX") { - op->SetInput("X", this->Input("X")); op->SetInput("Out", this->Output("Out")); } diff --git a/paddle/phi/infermeta/ternary.cc b/paddle/phi/infermeta/ternary.cc index 235cfe368c1921..aaca3944617e78 100644 --- a/paddle/phi/infermeta/ternary.cc +++ b/paddle/phi/infermeta/ternary.cc @@ -145,6 +145,7 @@ void GraphSendRecvInferMeta(const MetaTensor& x, const MetaTensor& src_index, const MetaTensor& dst_index, const std::string& pool_type, + const int64_t& out_size, MetaTensor* out, MetaTensor* dst_count) { auto src_index_dims = src_index.dims(); @@ -187,11 +188,24 @@ void GraphSendRecvInferMeta(const MetaTensor& x, "Src_index and Dst_index should have the same shape.")); auto dims = x.dims(); - out->set_dims(dims); + if (out_size <= -1) { + out->set_dims(dims); + } else { + // std::vector dims_ = phi::vectorize(dims); + // if (dims_.size() > 0) { + // dims_[0] = out_size; + //} + // out->set_dims(phi::make_ddim(dims_)); + out->set_dims({out_size, dims[1]}); + } out->set_dtype(x.dtype()); if (pool_type == "MEAN") { - dst_count->set_dims({dims[0]}); + if (out_size <= -1) { + dst_count->set_dims({dims[0]}); + } else { + dst_count->set_dims({out_size}); + } dst_count->set_dtype(DataType::INT32); } } diff --git a/paddle/phi/kernels/cpu/graph_send_recv_grad_kernel.cc b/paddle/phi/kernels/cpu/graph_send_recv_grad_kernel.cc index 8538461b1b83b8..6a83cee1ae40d1 100644 --- a/paddle/phi/kernels/cpu/graph_send_recv_grad_kernel.cc +++ b/paddle/phi/kernels/cpu/graph_send_recv_grad_kernel.cc @@ -23,15 +23,14 @@ namespace phi { template -void GraphSendRecvCpuGradLoop(const int& input_size, - const int& index_size, +void GraphSendRecvCpuGradLoop(const int& index_size, const IndexT* s_index, const IndexT* d_index, const DenseTensor& src, + const DenseTensor& input, DenseTensor* dst, const std::string& pool_type, const int* dst_count = nullptr, - const DenseTensor* input = nullptr, const DenseTensor* output = nullptr) { if (pool_type == "SUM") { Functor functor; @@ -55,7 +54,7 @@ void GraphSendRecvCpuGradLoop(const int& input_size, for (int i = 0; i < index_size; ++i) { const IndexT& forward_src_idx = d_index[i]; const IndexT& forward_dst_idx = s_index[i]; - auto input_slice = input->Slice(forward_src_idx, forward_src_idx + 1); + auto input_slice = input.Slice(forward_src_idx, forward_src_idx + 1); auto output_slice = output->Slice(forward_dst_idx, forward_dst_idx + 1); auto eigen_input = phi::EigenVector::Flatten(input_slice); auto eigen_output = phi::EigenVector::Flatten(output_slice); @@ -73,18 +72,18 @@ template void GraphSendRecvGradOpKernelLaunchHelper( const Context& ctx, const DenseTensor& out_grad, + const DenseTensor& x, const DenseTensor& src_index, const DenseTensor& dst_index, const std::string& pool_type, DenseTensor* x_grad, const DenseTensor* dst_count = nullptr, - const DenseTensor* x = nullptr, const DenseTensor* out = nullptr) { const int& index_size = dst_index.dims()[0]; ctx.template Alloc(x_grad); T* p_output = x_grad->data(); - const auto& src_dims = out_grad.dims(); + const auto& src_dims = x.dims(); int64_t memset_size = 1; for (int i = 0; i < src_dims.size(); ++i) memset_size *= src_dims[i]; const size_t& memset_bytes = memset_size * sizeof(T); @@ -97,29 +96,22 @@ void GraphSendRecvGradOpKernelLaunchHelper( if (pool_type == "SUM") { GraphSendRecvCpuGradLoop>( - src_dims[0], index_size, d_index, s_index, out_grad, x_grad, pool_type); + index_size, d_index, s_index, out_grad, x, x_grad, pool_type); } else if (pool_type == "MEAN") { const int* s_count = dst_count->data(); // Functor not used here. - GraphSendRecvCpuGradLoop>(src_dims[0], - index_size, - d_index, - s_index, - out_grad, - x_grad, - pool_type, - s_count); + GraphSendRecvCpuGradLoop>( + index_size, d_index, s_index, out_grad, x, x_grad, pool_type, s_count); } else if (pool_type == "MIN" || pool_type == "MAX") { // Functor not used here. - GraphSendRecvCpuGradLoop>(src_dims[0], - index_size, + GraphSendRecvCpuGradLoop>(index_size, d_index, s_index, out_grad, + x, x_grad, pool_type, nullptr, - x, out); } } @@ -127,7 +119,7 @@ void GraphSendRecvGradOpKernelLaunchHelper( template void GraphSendRecvGradKernel(const Context& ctx, const DenseTensor& out_grad, - paddle::optional x, + const DenseTensor& x, paddle::optional out, const DenseTensor& src_index, const DenseTensor& dst_index, @@ -139,23 +131,23 @@ void GraphSendRecvGradKernel(const Context& ctx, GraphSendRecvGradOpKernelLaunchHelper( ctx, out_grad, + x, src_index, dst_index, pool_type, x_grad, dst_count.get_ptr(), - x.get_ptr(), out.get_ptr()); } else if (index_type == phi::DataType::INT64) { GraphSendRecvGradOpKernelLaunchHelper( ctx, out_grad, + x, src_index, dst_index, pool_type, x_grad, dst_count.get_ptr(), - x.get_ptr(), out.get_ptr()); } } diff --git a/paddle/phi/kernels/cpu/graph_send_recv_kernel.cc b/paddle/phi/kernels/cpu/graph_send_recv_kernel.cc index fecbd4b1d7aa05..a4d2ac1c9607e0 100644 --- a/paddle/phi/kernels/cpu/graph_send_recv_kernel.cc +++ b/paddle/phi/kernels/cpu/graph_send_recv_kernel.cc @@ -83,6 +83,7 @@ void GraphSendRecvOpKernelLaunchHelper(const Context& ctx, const DenseTensor& src_index, const DenseTensor& dst_index, const std::string& pool_type, + const int64_t& out_size, DenseTensor* out, DenseTensor* dst_count = nullptr) { const int& index_size = src_index.dims()[0]; @@ -91,7 +92,16 @@ void GraphSendRecvOpKernelLaunchHelper(const Context& ctx, T* p_output = out->data(); const auto& src_dims = x.dims(); int64_t memset_size = 1; - for (int i = 0; i < src_dims.size(); ++i) memset_size *= src_dims[i]; + if (out_size <= -1) { + for (int i = 0; i < src_dims.size(); ++i) { + memset_size *= src_dims[i]; + } + } else { + memset_size = out_size; + for (int i = 1; i < src_dims.size(); ++i) { + memset_size *= src_dims[i]; + } + } const size_t& memset_bytes = memset_size * sizeof(T); memset(p_output, 0, memset_bytes); @@ -129,15 +139,16 @@ void GraphSendRecvKernel(const Context& ctx, const DenseTensor& src_index, const DenseTensor& dst_index, const std::string& pool_type, + const int64_t& out_size, DenseTensor* out, DenseTensor* dst_count) { auto index_type = src_index.dtype(); if (index_type == phi::DataType::INT32) { GraphSendRecvOpKernelLaunchHelper( - ctx, x, src_index, dst_index, pool_type, out, dst_count); + ctx, x, src_index, dst_index, pool_type, out_size, out, dst_count); } else if (index_type == phi::DataType::INT64) { GraphSendRecvOpKernelLaunchHelper( - ctx, x, src_index, dst_index, pool_type, out, dst_count); + ctx, x, src_index, dst_index, pool_type, out_size, out, dst_count); } } diff --git a/paddle/phi/kernels/gpu/graph_send_recv_grad_kernel.cu b/paddle/phi/kernels/gpu/graph_send_recv_grad_kernel.cu index 75692966b4662c..8bd3337280d759 100644 --- a/paddle/phi/kernels/gpu/graph_send_recv_grad_kernel.cu +++ b/paddle/phi/kernels/gpu/graph_send_recv_grad_kernel.cu @@ -28,19 +28,19 @@ template void GraphSendRecvGradOpCUDAKernelLaunchHelper( const Context& ctx, const DenseTensor& out_grad, + const DenseTensor& x, const DenseTensor& src_index, const DenseTensor& dst_index, const std::string& pool_type, DenseTensor* x_grad, const DenseTensor* dst_count = nullptr, - const DenseTensor* x = nullptr, const DenseTensor* out = nullptr) { const int& index_size = dst_index.dims()[0]; ctx.template Alloc(x_grad); T* p_output = x_grad->data(); - const auto& src_dims = out_grad.dims(); + const auto& src_dims = x.dims(); int64_t memset_size = 1; for (int i = 0; i < src_dims.size(); ++i) { memset_size *= src_dims[i]; @@ -86,7 +86,7 @@ void GraphSendRecvGradOpCUDAKernelLaunchHelper( ManipulateMeanGradCUDAKernel<<>>( p_src, d_index, s_index, p_output, index_size, slice_size, s_count); } else if (pool_type == "MAX" || pool_type == "MIN") { - const T* ptr_input = x->data(); + const T* ptr_input = x.data(); const T* ptr_output = out->data(); ManipulateMinMaxGradCUDAKernel<<>>( p_src, @@ -103,7 +103,7 @@ void GraphSendRecvGradOpCUDAKernelLaunchHelper( template void GraphSendRecvGradKernel(const Context& ctx, const DenseTensor& out_grad, - paddle::optional x, + const DenseTensor& x, paddle::optional out, const DenseTensor& src_index, const DenseTensor& dst_index, @@ -115,23 +115,23 @@ void GraphSendRecvGradKernel(const Context& ctx, GraphSendRecvGradOpCUDAKernelLaunchHelper( ctx, out_grad, + x, src_index, dst_index, pool_type, x_grad, dst_count.get_ptr(), - x.get_ptr(), out.get_ptr()); } else if (index_type == phi::DataType::INT64) { GraphSendRecvGradOpCUDAKernelLaunchHelper( ctx, out_grad, + x, src_index, dst_index, pool_type, x_grad, dst_count.get_ptr(), - x.get_ptr(), out.get_ptr()); } } diff --git a/paddle/phi/kernels/gpu/graph_send_recv_kernel.cu b/paddle/phi/kernels/gpu/graph_send_recv_kernel.cu index fab306f831a6f4..87e32da4019dc4 100644 --- a/paddle/phi/kernels/gpu/graph_send_recv_kernel.cu +++ b/paddle/phi/kernels/gpu/graph_send_recv_kernel.cu @@ -32,6 +32,7 @@ void GraphSendRecvOpCUDAKernelLaunchHelper(const Context& ctx, const DenseTensor& src_index, const DenseTensor& dst_index, const std::string& pool_type, + const int64_t& out_size, DenseTensor* out, DenseTensor* dst_count = nullptr) { const int& index_size = src_index.dims()[0]; @@ -39,8 +40,15 @@ void GraphSendRecvOpCUDAKernelLaunchHelper(const Context& ctx, T* p_output = out->data(); const auto& src_dims = x.dims(); int64_t memset_size = 1; - for (int i = 0; i < src_dims.size(); ++i) { - memset_size *= src_dims[i]; + if (out_size <= -1) { + for (int i = 0; i < src_dims.size(); ++i) { + memset_size *= src_dims[i]; + } + } else { + memset_size = out_size; + for (int i = 1; i < src_dims.size(); ++i) { + memset_size *= src_dims[i]; + } } const size_t& memset_bytes = memset_size * sizeof(T); if (pool_type == "SUM" || pool_type == "MEAN") { @@ -100,6 +108,9 @@ void GraphSendRecvOpCUDAKernelLaunchHelper(const Context& ctx, IndexT>><<>>( p_src, s_index, d_index, p_output, index_size, slice_size, functor); + if (out_size > -1) { + input_size = out_size; + } int64_t grid_max_tmp = (input_size * slice_size + block - 1) / block; int64_t grid_max = grid_max_tmp < max_grid_dimx ? grid_max_tmp : max_grid_dimx; @@ -114,6 +125,9 @@ void GraphSendRecvOpCUDAKernelLaunchHelper(const Context& ctx, IndexT>><<>>( p_src, s_index, d_index, p_output, index_size, slice_size, functor); + if (out_size > -1) { + input_size = out_size; + } int64_t grid_min_tmp = (input_size * slice_size + block - 1) / block; int64_t grid_min = grid_min_tmp < max_grid_dimx ? grid_min_tmp : max_grid_dimx; @@ -130,6 +144,9 @@ void GraphSendRecvOpCUDAKernelLaunchHelper(const Context& ctx, ctx.template Alloc(dst_count); int32_t* p_dst_count = dst_count->data(); + if (out_size > -1) { + input_size = out_size; + } #ifdef PADDLE_WITH_HIP hipMemset(p_dst_count, 0, input_size * sizeof(int)); @@ -155,15 +172,16 @@ void GraphSendRecvKernel(const Context& ctx, const DenseTensor& src_index, const DenseTensor& dst_index, const std::string& pool_type, + const int64_t& out_size, DenseTensor* out, DenseTensor* dst_count) { auto index_type = src_index.dtype(); if (index_type == phi::DataType::INT32) { GraphSendRecvOpCUDAKernelLaunchHelper( - ctx, x, src_index, dst_index, pool_type, out, dst_count); + ctx, x, src_index, dst_index, pool_type, out_size, out, dst_count); } else if (index_type == phi::DataType::INT64) { GraphSendRecvOpCUDAKernelLaunchHelper( - ctx, x, src_index, dst_index, pool_type, out, dst_count); + ctx, x, src_index, dst_index, pool_type, out_size, out, dst_count); } } diff --git a/paddle/phi/kernels/graph_send_recv_grad_kernel.h b/paddle/phi/kernels/graph_send_recv_grad_kernel.h index d163e6e278a075..3694c8f1e6c990 100644 --- a/paddle/phi/kernels/graph_send_recv_grad_kernel.h +++ b/paddle/phi/kernels/graph_send_recv_grad_kernel.h @@ -23,7 +23,7 @@ namespace phi { template void GraphSendRecvGradKernel(const Context& ctx, const DenseTensor& out_grad, - paddle::optional x, + const DenseTensor& x, paddle::optional out, const DenseTensor& src_index, const DenseTensor& dst_index, diff --git a/paddle/phi/kernels/graph_send_recv_kernel.h b/paddle/phi/kernels/graph_send_recv_kernel.h index 95dbdc4443ad00..84faa18f97864b 100644 --- a/paddle/phi/kernels/graph_send_recv_kernel.h +++ b/paddle/phi/kernels/graph_send_recv_kernel.h @@ -25,6 +25,7 @@ void GraphSendRecvKernel(const Context& ctx, const DenseTensor& src_index, const DenseTensor& dst_index, const std::string& pool_type, + const int64_t& out_size, DenseTensor* out, DenseTensor* dst_count); diff --git a/paddle/phi/ops/compat/graph_send_recv_sig.cc b/paddle/phi/ops/compat/graph_send_recv_sig.cc index dacb8b25a89f9c..fa4da0704c9871 100644 --- a/paddle/phi/ops/compat/graph_send_recv_sig.cc +++ b/paddle/phi/ops/compat/graph_send_recv_sig.cc @@ -16,6 +16,14 @@ limitations under the License. */ namespace phi { +KernelSignature GraphSendRecvOpArgumentMapping( + const ArgumentMappingContext& ctx) { + return KernelSignature("graph_send_recv", + {"X", "Src_index", "Dst_index"}, + {"pool_type", "out_size"}, + {"Out", "Dst_count"}); +} + KernelSignature GraphSendRecvGradOpArgumentMapping( const ArgumentMappingContext& ctx) { return KernelSignature( @@ -27,5 +35,8 @@ KernelSignature GraphSendRecvGradOpArgumentMapping( } // namespace phi +PD_REGISTER_ARG_MAPPING_FN(graph_send_recv, + phi::GraphSendRecvOpArgumentMapping); + PD_REGISTER_ARG_MAPPING_FN(graph_send_recv_grad, phi::GraphSendRecvGradOpArgumentMapping); diff --git a/python/paddle/fluid/tests/unittests/test_graph_send_recv_op.py b/python/paddle/fluid/tests/unittests/test_graph_send_recv_op.py index 68b354775d13e6..3cf3e737d2cfc9 100644 --- a/python/paddle/fluid/tests/unittests/test_graph_send_recv_op.py +++ b/python/paddle/fluid/tests/unittests/test_graph_send_recv_op.py @@ -304,6 +304,34 @@ def test_int32_input(self): "two value is\ {}\n{}, check diff!".format(np_res, ret_res)) + def test_set_outsize(self): + device = paddle.CPUPlace() + with paddle.fluid.dygraph.guard(device): + x = paddle.to_tensor( + np.array([[0, 2, 3], [1, 4, 5], [2, 6, 6]]), dtype="float32") + src_index = paddle.to_tensor(np.array([0, 0, 1]), dtype="int32") + dst_index = paddle.to_tensor(np.array([0, 1, 1]), dtype="int32") + res = paddle.incubate.graph_send_recv(x, src_index, dst_index, + "sum") + out_size = paddle.max(dst_index) + 1 + res_set_outsize = paddle.incubate.graph_send_recv( + x, src_index, dst_index, "sum", out_size) + + np_res = np.array([[0, 2, 3], [1, 6, 8]], dtype="float32") + np_res_set_outsize = np.array( + [[0, 2, 3], [1, 6, 8], [0, 0, 0]], dtype="float32") + self.assertTrue( + np.allclose( + np_res, res, atol=1e-6), + "two value is\ + {}\n{}, check diff!".format(np_res, res)) + self.assertTrue( + np.allclose( + np_res_set_outsize, res_set_outsize, atol=1e-6), + "two value is\ + {}\n{}, check diff!" + .format(np_res_set_outsize, res_set_outsize)) + if __name__ == '__main__': unittest.main() diff --git a/python/paddle/incubate/operators/graph_send_recv.py b/python/paddle/incubate/operators/graph_send_recv.py index 45810621e42076..6a4d05b81c08cb 100644 --- a/python/paddle/incubate/operators/graph_send_recv.py +++ b/python/paddle/incubate/operators/graph_send_recv.py @@ -19,7 +19,12 @@ from paddle import _C_ops -def graph_send_recv(x, src_index, dst_index, pool_type="sum", name=None): +def graph_send_recv(x, + src_index, + dst_index, + pool_type="sum", + out_size=None, + name=None): r""" Graph Learning Send_Recv combine operator. @@ -27,7 +32,7 @@ def graph_send_recv(x, src_index, dst_index, pool_type="sum", name=None): This operator is mainly used in Graph Learning domain, and the main purpose is to reduce intermediate memory consumption in the process of message passing. Take `x` as the input tensor, we first use `src_index` to gather the corresponding data, and then use `dst_index` to update the corresponding position of output tensor - in different pooling types, like sum, mean, max, or min. + in different pooling types, like sum, mean, max, or min. Besides, we can set `out_size` to get necessary output shape. .. code-block:: text @@ -43,6 +48,8 @@ def graph_send_recv(x, src_index, dst_index, pool_type="sum", name=None): pool_type = "sum" + out_size = None + Then: Out = [[0, 2, 3], @@ -56,6 +63,9 @@ def graph_send_recv(x, src_index, dst_index, pool_type="sum", name=None): The available data type is int32, int64. pool_type (str): The pooling type of graph_send_recv, including `sum`, `mean`, `max`, `min`. Default value is `sum`. + out_size (int64): We can set `out_size` to get necessary output shape. If not set, then this + attribute will not be used. Default value is None, and if set, then it + should be `max(dst_index) + 1`. name (str, optional): Name for the operation (optional, default is None). For more information, please refer to :ref:`api_guide_Name`. @@ -72,9 +82,24 @@ def graph_send_recv(x, src_index, dst_index, pool_type="sum", name=None): indexes = paddle.to_tensor([[0, 1], [1, 2], [2, 1], [0, 0]], dtype="int32") src_index = indexes[:, 0] dst_index = indexes[:, 1] - out = paddle.incubate.graph_send_recv(x, src_index, dst_index, pool_type="sum") + out = paddle.incubate.graph_send_recv(x, src_index, dst_index, pool_type="sum", out_size=None) # Outputs: [[0., 2., 3.], [2., 8., 10.], [1., 4., 5.]] + x = paddle.to_tensor([[0, 2, 3], [1, 4, 5], [2, 6, 7]], dtype="float32") + indexes = paddle.to_tensor([[0, 1], [2, 1], [0, 0]], dtype="int32") + src_index = indexes[:, 0] + dst_index = indexes[:, 1] + out_size = paddle.max(dst_index) + 1 + out = paddle.incubate.graph_send_recv(x, src_index, dst_index, pool_type="sum", out_size=out_size) + # Outputs: [[0., 2., 3.], [[2., 8., 10.]]] + + x = paddle.to_tensor([[0, 2, 3], [1, 4, 5], [2, 6, 7]], dtype="float32") + indexes = paddle.to_tensor([[0, 1], [2, 1], [0, 0]], dtype="int32") + src_index = indexes[:, 0] + dst_index = indexes[:, 1] + out = paddle.incubate.graph_send_recv(x, src_index, dst_index, pool_type="sum", out_size=None) + # Outputs: [[0., 2., 3.], [2., 8., 10.], [0., 0., 0.]] + """ if pool_type not in ["sum", "mean", "max", "min"]: @@ -82,9 +107,17 @@ def graph_send_recv(x, src_index, dst_index, pool_type="sum", name=None): "pool_type should be `sum`, `mean`, `max` or `min`, but received %s" % pool_type) + # TODO(daisiming): Should we add judgement for out_size: max(dst_index) + 1. + if in_dygraph_mode(): - out, tmp = _C_ops.graph_send_recv(x, src_index, dst_index, 'pool_type', - pool_type.upper()) + if out_size is None: + out, tmp = _C_ops.graph_send_recv(x, src_index, dst_index, + 'pool_type', + pool_type.upper(), 'out_size', -1) + else: + out, tmp = _C_ops.graph_send_recv( + x, src_index, dst_index, 'pool_type', + pool_type.upper(), 'out_size', out_size) return out check_variable_and_dtype(x, "X", ("float32", "float64", "int32", "int64"), @@ -105,5 +138,8 @@ def graph_send_recv(x, src_index, dst_index, pool_type="sum", name=None): "Dst_index": dst_index}, outputs={"Out": out, "Dst_count": dst_count}, - attrs={"pool_type": pool_type.upper()}) + attrs={ + "pool_type": pool_type.upper(), + "out_size": -1 if out_size is None else out_size + }) return out From 87c8dfc573a08b7c6d69e7379cd565385b0afb18 Mon Sep 17 00:00:00 2001 From: DesmonDay <908660116@qq.com> Date: Tue, 15 Mar 2022 02:28:33 +0000 Subject: [PATCH 02/11] fix bug in register kernel: no const int& support --- paddle/phi/infermeta/ternary.cc | 13 ++++++------- paddle/phi/kernels/cpu/graph_send_recv_kernel.cc | 4 ++-- paddle/phi/kernels/gpu/graph_send_recv_kernel.cu | 4 ++-- paddle/phi/kernels/graph_send_recv_kernel.h | 2 +- 4 files changed, 11 insertions(+), 12 deletions(-) diff --git a/paddle/phi/infermeta/ternary.cc b/paddle/phi/infermeta/ternary.cc index aaca3944617e78..e60b5e8cc581e9 100644 --- a/paddle/phi/infermeta/ternary.cc +++ b/paddle/phi/infermeta/ternary.cc @@ -145,7 +145,7 @@ void GraphSendRecvInferMeta(const MetaTensor& x, const MetaTensor& src_index, const MetaTensor& dst_index, const std::string& pool_type, - const int64_t& out_size, + int64_t out_size, MetaTensor* out, MetaTensor* dst_count) { auto src_index_dims = src_index.dims(); @@ -191,12 +191,11 @@ void GraphSendRecvInferMeta(const MetaTensor& x, if (out_size <= -1) { out->set_dims(dims); } else { - // std::vector dims_ = phi::vectorize(dims); - // if (dims_.size() > 0) { - // dims_[0] = out_size; - //} - // out->set_dims(phi::make_ddim(dims_)); - out->set_dims({out_size, dims[1]}); + std::vector dims_ = phi::vectorize(dims); + if (dims_.size() > 0) { + dims_[0] = out_size; + } + out->set_dims(phi::make_ddim(dims_)); } out->set_dtype(x.dtype()); diff --git a/paddle/phi/kernels/cpu/graph_send_recv_kernel.cc b/paddle/phi/kernels/cpu/graph_send_recv_kernel.cc index a4d2ac1c9607e0..152c46398f83a3 100644 --- a/paddle/phi/kernels/cpu/graph_send_recv_kernel.cc +++ b/paddle/phi/kernels/cpu/graph_send_recv_kernel.cc @@ -83,7 +83,7 @@ void GraphSendRecvOpKernelLaunchHelper(const Context& ctx, const DenseTensor& src_index, const DenseTensor& dst_index, const std::string& pool_type, - const int64_t& out_size, + int64_t out_size, DenseTensor* out, DenseTensor* dst_count = nullptr) { const int& index_size = src_index.dims()[0]; @@ -139,7 +139,7 @@ void GraphSendRecvKernel(const Context& ctx, const DenseTensor& src_index, const DenseTensor& dst_index, const std::string& pool_type, - const int64_t& out_size, + int64_t out_size, DenseTensor* out, DenseTensor* dst_count) { auto index_type = src_index.dtype(); diff --git a/paddle/phi/kernels/gpu/graph_send_recv_kernel.cu b/paddle/phi/kernels/gpu/graph_send_recv_kernel.cu index 87e32da4019dc4..8d7f3fddad2f8c 100644 --- a/paddle/phi/kernels/gpu/graph_send_recv_kernel.cu +++ b/paddle/phi/kernels/gpu/graph_send_recv_kernel.cu @@ -32,7 +32,7 @@ void GraphSendRecvOpCUDAKernelLaunchHelper(const Context& ctx, const DenseTensor& src_index, const DenseTensor& dst_index, const std::string& pool_type, - const int64_t& out_size, + int64_t out_size, DenseTensor* out, DenseTensor* dst_count = nullptr) { const int& index_size = src_index.dims()[0]; @@ -172,7 +172,7 @@ void GraphSendRecvKernel(const Context& ctx, const DenseTensor& src_index, const DenseTensor& dst_index, const std::string& pool_type, - const int64_t& out_size, + int64_t out_size, DenseTensor* out, DenseTensor* dst_count) { auto index_type = src_index.dtype(); diff --git a/paddle/phi/kernels/graph_send_recv_kernel.h b/paddle/phi/kernels/graph_send_recv_kernel.h index 84faa18f97864b..51768fbc18f019 100644 --- a/paddle/phi/kernels/graph_send_recv_kernel.h +++ b/paddle/phi/kernels/graph_send_recv_kernel.h @@ -25,7 +25,7 @@ void GraphSendRecvKernel(const Context& ctx, const DenseTensor& src_index, const DenseTensor& dst_index, const std::string& pool_type, - const int64_t& out_size, + int64_t out_size, DenseTensor* out, DenseTensor* dst_count); From 989800e81ecbb8b6839c34c271489b05452d3cb2 Mon Sep 17 00:00:00 2001 From: DesmonDay <908660116@qq.com> Date: Tue, 15 Mar 2022 02:51:15 +0000 Subject: [PATCH 03/11] add out_size in infermeta --- paddle/phi/infermeta/ternary.h | 1 + 1 file changed, 1 insertion(+) diff --git a/paddle/phi/infermeta/ternary.h b/paddle/phi/infermeta/ternary.h index 209a07db18b5c7..aefbb264b9db65 100644 --- a/paddle/phi/infermeta/ternary.h +++ b/paddle/phi/infermeta/ternary.h @@ -49,6 +49,7 @@ void GraphSendRecvInferMeta(const MetaTensor& x, const MetaTensor& src_index, const MetaTensor& dst_index, const std::string& pool_type, + int64_t out_size, MetaTensor* out, MetaTensor* dst_count); From 5829fde623357795c1d1004ea018a586dd2ea1d3 Mon Sep 17 00:00:00 2001 From: DesmonDay <908660116@qq.com> Date: Tue, 15 Mar 2022 03:32:34 +0000 Subject: [PATCH 04/11] change unittest --- .../paddle/fluid/tests/unittests/test_graph_send_recv_op.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/python/paddle/fluid/tests/unittests/test_graph_send_recv_op.py b/python/paddle/fluid/tests/unittests/test_graph_send_recv_op.py index 3cf3e737d2cfc9..e5876697a4d618 100644 --- a/python/paddle/fluid/tests/unittests/test_graph_send_recv_op.py +++ b/python/paddle/fluid/tests/unittests/test_graph_send_recv_op.py @@ -304,9 +304,8 @@ def test_int32_input(self): "two value is\ {}\n{}, check diff!".format(np_res, ret_res)) - def test_set_outsize(self): - device = paddle.CPUPlace() - with paddle.fluid.dygraph.guard(device): + def test_set_outsize_gpu(self): + if paddle.fluid.core.is_compiled_with_cuda(): x = paddle.to_tensor( np.array([[0, 2, 3], [1, 4, 5], [2, 6, 6]]), dtype="float32") src_index = paddle.to_tensor(np.array([0, 0, 1]), dtype="int32") @@ -320,6 +319,7 @@ def test_set_outsize(self): np_res = np.array([[0, 2, 3], [1, 6, 8]], dtype="float32") np_res_set_outsize = np.array( [[0, 2, 3], [1, 6, 8], [0, 0, 0]], dtype="float32") + self.assertTrue( np.allclose( np_res, res, atol=1e-6), From bbc8d0fd313efa0392b61c58f3cb9fad606eaf19 Mon Sep 17 00:00:00 2001 From: DesmonDay <908660116@qq.com> Date: Tue, 15 Mar 2022 04:16:25 +0000 Subject: [PATCH 05/11] fix unittest --- .../paddle/fluid/tests/unittests/test_graph_send_recv_op.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/python/paddle/fluid/tests/unittests/test_graph_send_recv_op.py b/python/paddle/fluid/tests/unittests/test_graph_send_recv_op.py index e5876697a4d618..30f943e3248e90 100644 --- a/python/paddle/fluid/tests/unittests/test_graph_send_recv_op.py +++ b/python/paddle/fluid/tests/unittests/test_graph_send_recv_op.py @@ -316,9 +316,10 @@ def test_set_outsize_gpu(self): res_set_outsize = paddle.incubate.graph_send_recv( x, src_index, dst_index, "sum", out_size) - np_res = np.array([[0, 2, 3], [1, 6, 8]], dtype="float32") - np_res_set_outsize = np.array( + np_res = np.array( [[0, 2, 3], [1, 6, 8], [0, 0, 0]], dtype="float32") + np_res_set_outsize = np.array( + [[0, 2, 3], [1, 6, 8]], dtype="float32") self.assertTrue( np.allclose( From 9ce535333a87f63b65f66c32e951696e361cdd2f Mon Sep 17 00:00:00 2001 From: DesmonDay <908660116@qq.com> Date: Thu, 17 Mar 2022 13:14:27 +0000 Subject: [PATCH 06/11] fix out_size default value --- paddle/fluid/operators/graph_send_recv_op.cc | 6 +++--- .../incubate/operators/graph_send_recv.py | 17 ++++++++--------- 2 files changed, 11 insertions(+), 12 deletions(-) diff --git a/paddle/fluid/operators/graph_send_recv_op.cc b/paddle/fluid/operators/graph_send_recv_op.cc index b714d3b1d96b11..f67dea7402864b 100644 --- a/paddle/fluid/operators/graph_send_recv_op.cc +++ b/paddle/fluid/operators/graph_send_recv_op.cc @@ -70,10 +70,10 @@ class GraphSendRecvOpMaker : public framework::OpProtoAndCheckerMaker { .InEnum({"SUM", "MEAN", "MIN", "MAX"}); AddAttr( "out_size", - "(int64_t, default -1)" + "(int64_t, default 0)" "Define the first dimension of Output tensor." - "If set default -1, then the shape of Out is the same with X.") - .SetDefault(-1); + "If set default 0, then the shape of Out is the same with X.") + .SetDefault(0); AddComment(R"DOC( Graph Learning Send_Recv combine operator. diff --git a/python/paddle/incubate/operators/graph_send_recv.py b/python/paddle/incubate/operators/graph_send_recv.py index 6a4d05b81c08cb..eef705c998479c 100644 --- a/python/paddle/incubate/operators/graph_send_recv.py +++ b/python/paddle/incubate/operators/graph_send_recv.py @@ -63,9 +63,9 @@ def graph_send_recv(x, The available data type is int32, int64. pool_type (str): The pooling type of graph_send_recv, including `sum`, `mean`, `max`, `min`. Default value is `sum`. - out_size (int64): We can set `out_size` to get necessary output shape. If not set, then this - attribute will not be used. Default value is None, and if set, then it - should be `max(dst_index) + 1`. + out_size (int64|None): We can set `out_size` to get necessary output shape. If not set, then this + attribute will not be used. If set, then we will use the following rule to set + output shape: max(out_size, max(dst_index) + 1). name (str, optional): Name for the operation (optional, default is None). For more information, please refer to :ref:`api_guide_Name`. @@ -82,7 +82,7 @@ def graph_send_recv(x, indexes = paddle.to_tensor([[0, 1], [1, 2], [2, 1], [0, 0]], dtype="int32") src_index = indexes[:, 0] dst_index = indexes[:, 1] - out = paddle.incubate.graph_send_recv(x, src_index, dst_index, pool_type="sum", out_size=None) + out = paddle.incubate.graph_send_recv(x, src_index, dst_index, pool_type="sum") # Outputs: [[0., 2., 3.], [2., 8., 10.], [1., 4., 5.]] x = paddle.to_tensor([[0, 2, 3], [1, 4, 5], [2, 6, 7]], dtype="float32") @@ -97,7 +97,7 @@ def graph_send_recv(x, indexes = paddle.to_tensor([[0, 1], [2, 1], [0, 0]], dtype="int32") src_index = indexes[:, 0] dst_index = indexes[:, 1] - out = paddle.incubate.graph_send_recv(x, src_index, dst_index, pool_type="sum", out_size=None) + out = paddle.incubate.graph_send_recv(x, src_index, dst_index, pool_type="sum") # Outputs: [[0., 2., 3.], [2., 8., 10.], [0., 0., 0.]] """ @@ -110,10 +110,9 @@ def graph_send_recv(x, # TODO(daisiming): Should we add judgement for out_size: max(dst_index) + 1. if in_dygraph_mode(): - if out_size is None: + if out_size is None or out_size <= 0: out, tmp = _C_ops.graph_send_recv(x, src_index, dst_index, - 'pool_type', - pool_type.upper(), 'out_size', -1) + 'pool_type', pool_type.upper()) else: out, tmp = _C_ops.graph_send_recv( x, src_index, dst_index, 'pool_type', @@ -140,6 +139,6 @@ def graph_send_recv(x, "Dst_count": dst_count}, attrs={ "pool_type": pool_type.upper(), - "out_size": -1 if out_size is None else out_size + "out_size": 0 if out_size is None or out_size <= 0 else out_size }) return out From 0fe8fc7cedad943f46bd7b7a78617599bf37f20c Mon Sep 17 00:00:00 2001 From: DesmonDay <908660116@qq.com> Date: Thu, 17 Mar 2022 13:18:30 +0000 Subject: [PATCH 07/11] fix doc --- python/paddle/incubate/operators/graph_send_recv.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/python/paddle/incubate/operators/graph_send_recv.py b/python/paddle/incubate/operators/graph_send_recv.py index eef705c998479c..05f6a80a442f28 100644 --- a/python/paddle/incubate/operators/graph_send_recv.py +++ b/python/paddle/incubate/operators/graph_send_recv.py @@ -64,8 +64,8 @@ def graph_send_recv(x, pool_type (str): The pooling type of graph_send_recv, including `sum`, `mean`, `max`, `min`. Default value is `sum`. out_size (int64|None): We can set `out_size` to get necessary output shape. If not set, then this - attribute will not be used. If set, then we will use the following rule to set - output shape: max(out_size, max(dst_index) + 1). + attribute will not be used. If set, it should be equal with or larger than + max(dst_index) + 1. name (str, optional): Name for the operation (optional, default is None). For more information, please refer to :ref:`api_guide_Name`. From b69a56b42aacc74652085589fa8d9e52a78ef4b6 Mon Sep 17 00:00:00 2001 From: DesmonDay <908660116@qq.com> Date: Thu, 17 Mar 2022 13:29:43 +0000 Subject: [PATCH 08/11] delete arg mapping --- paddle/phi/ops/compat/graph_send_recv_sig.cc | 11 ----------- 1 file changed, 11 deletions(-) diff --git a/paddle/phi/ops/compat/graph_send_recv_sig.cc b/paddle/phi/ops/compat/graph_send_recv_sig.cc index fa4da0704c9871..dacb8b25a89f9c 100644 --- a/paddle/phi/ops/compat/graph_send_recv_sig.cc +++ b/paddle/phi/ops/compat/graph_send_recv_sig.cc @@ -16,14 +16,6 @@ limitations under the License. */ namespace phi { -KernelSignature GraphSendRecvOpArgumentMapping( - const ArgumentMappingContext& ctx) { - return KernelSignature("graph_send_recv", - {"X", "Src_index", "Dst_index"}, - {"pool_type", "out_size"}, - {"Out", "Dst_count"}); -} - KernelSignature GraphSendRecvGradOpArgumentMapping( const ArgumentMappingContext& ctx) { return KernelSignature( @@ -35,8 +27,5 @@ KernelSignature GraphSendRecvGradOpArgumentMapping( } // namespace phi -PD_REGISTER_ARG_MAPPING_FN(graph_send_recv, - phi::GraphSendRecvOpArgumentMapping); - PD_REGISTER_ARG_MAPPING_FN(graph_send_recv_grad, phi::GraphSendRecvGradOpArgumentMapping); From 4c494a27261ef87fea9d6258faa307e93a53a0d2 Mon Sep 17 00:00:00 2001 From: DesmonDay <908660116@qq.com> Date: Fri, 18 Mar 2022 02:28:31 +0000 Subject: [PATCH 09/11] add sig --- paddle/phi/ops/compat/graph_send_recv_sig.cc | 11 +++++++++++ 1 file changed, 11 insertions(+) diff --git a/paddle/phi/ops/compat/graph_send_recv_sig.cc b/paddle/phi/ops/compat/graph_send_recv_sig.cc index dacb8b25a89f9c..fa4da0704c9871 100644 --- a/paddle/phi/ops/compat/graph_send_recv_sig.cc +++ b/paddle/phi/ops/compat/graph_send_recv_sig.cc @@ -16,6 +16,14 @@ limitations under the License. */ namespace phi { +KernelSignature GraphSendRecvOpArgumentMapping( + const ArgumentMappingContext& ctx) { + return KernelSignature("graph_send_recv", + {"X", "Src_index", "Dst_index"}, + {"pool_type", "out_size"}, + {"Out", "Dst_count"}); +} + KernelSignature GraphSendRecvGradOpArgumentMapping( const ArgumentMappingContext& ctx) { return KernelSignature( @@ -27,5 +35,8 @@ KernelSignature GraphSendRecvGradOpArgumentMapping( } // namespace phi +PD_REGISTER_ARG_MAPPING_FN(graph_send_recv, + phi::GraphSendRecvOpArgumentMapping); + PD_REGISTER_ARG_MAPPING_FN(graph_send_recv_grad, phi::GraphSendRecvGradOpArgumentMapping); From a52c9ae353ab95c3ac065c66ed3032542457fff9 Mon Sep 17 00:00:00 2001 From: DesmonDay <908660116@qq.com> Date: Fri, 18 Mar 2022 02:37:17 +0000 Subject: [PATCH 10/11] move -1 to 0 --- paddle/phi/kernels/cpu/graph_send_recv_kernel.cc | 2 +- paddle/phi/kernels/gpu/graph_send_recv_kernel.cu | 8 ++++---- 2 files changed, 5 insertions(+), 5 deletions(-) diff --git a/paddle/phi/kernels/cpu/graph_send_recv_kernel.cc b/paddle/phi/kernels/cpu/graph_send_recv_kernel.cc index 152c46398f83a3..8f71ba12cc4fa2 100644 --- a/paddle/phi/kernels/cpu/graph_send_recv_kernel.cc +++ b/paddle/phi/kernels/cpu/graph_send_recv_kernel.cc @@ -92,7 +92,7 @@ void GraphSendRecvOpKernelLaunchHelper(const Context& ctx, T* p_output = out->data(); const auto& src_dims = x.dims(); int64_t memset_size = 1; - if (out_size <= -1) { + if (out_size <= 0) { for (int i = 0; i < src_dims.size(); ++i) { memset_size *= src_dims[i]; } diff --git a/paddle/phi/kernels/gpu/graph_send_recv_kernel.cu b/paddle/phi/kernels/gpu/graph_send_recv_kernel.cu index 8d7f3fddad2f8c..2826c071d6ec3e 100644 --- a/paddle/phi/kernels/gpu/graph_send_recv_kernel.cu +++ b/paddle/phi/kernels/gpu/graph_send_recv_kernel.cu @@ -40,7 +40,7 @@ void GraphSendRecvOpCUDAKernelLaunchHelper(const Context& ctx, T* p_output = out->data(); const auto& src_dims = x.dims(); int64_t memset_size = 1; - if (out_size <= -1) { + if (out_size <= 0) { for (int i = 0; i < src_dims.size(); ++i) { memset_size *= src_dims[i]; } @@ -108,7 +108,7 @@ void GraphSendRecvOpCUDAKernelLaunchHelper(const Context& ctx, IndexT>><<>>( p_src, s_index, d_index, p_output, index_size, slice_size, functor); - if (out_size > -1) { + if (out_size > 0) { input_size = out_size; } int64_t grid_max_tmp = (input_size * slice_size + block - 1) / block; @@ -125,7 +125,7 @@ void GraphSendRecvOpCUDAKernelLaunchHelper(const Context& ctx, IndexT>><<>>( p_src, s_index, d_index, p_output, index_size, slice_size, functor); - if (out_size > -1) { + if (out_size > 0) { input_size = out_size; } int64_t grid_min_tmp = (input_size * slice_size + block - 1) / block; @@ -144,7 +144,7 @@ void GraphSendRecvOpCUDAKernelLaunchHelper(const Context& ctx, ctx.template Alloc(dst_count); int32_t* p_dst_count = dst_count->data(); - if (out_size > -1) { + if (out_size > 0) { input_size = out_size; } From 7250bf0e7957568bc68a86aae7b8d82e234067cd Mon Sep 17 00:00:00 2001 From: DesmonDay <908660116@qq.com> Date: Fri, 18 Mar 2022 03:50:00 +0000 Subject: [PATCH 11/11] move -1 to 0 --- paddle/phi/infermeta/ternary.cc | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/paddle/phi/infermeta/ternary.cc b/paddle/phi/infermeta/ternary.cc index e60b5e8cc581e9..205431c2ec8731 100644 --- a/paddle/phi/infermeta/ternary.cc +++ b/paddle/phi/infermeta/ternary.cc @@ -188,7 +188,7 @@ void GraphSendRecvInferMeta(const MetaTensor& x, "Src_index and Dst_index should have the same shape.")); auto dims = x.dims(); - if (out_size <= -1) { + if (out_size <= 0) { out->set_dims(dims); } else { std::vector dims_ = phi::vectorize(dims); @@ -200,7 +200,7 @@ void GraphSendRecvInferMeta(const MetaTensor& x, out->set_dtype(x.dtype()); if (pool_type == "MEAN") { - if (out_size <= -1) { + if (out_size <= 0) { dst_count->set_dims({dims[0]}); } else { dst_count->set_dims({out_size});