[phi] Update graph_send_recv OP #40509

Merged · 11 commits · Mar 22, 2022
10 changes: 8 additions & 2 deletions paddle/fluid/operators/graph_send_recv_op.cc
@@ -38,7 +38,7 @@ class GraphSendRecvGradOp : public framework::OperatorWithKernel {
   using framework::OperatorWithKernel::OperatorWithKernel;

   void InferShape(framework::InferShapeContext* ctx) const override {
-    auto in_dims = ctx->GetInputDim(framework::GradVarName("Out"));
+    auto in_dims = ctx->GetInputDim("X");
     ctx->SetOutputDim(framework::GradVarName("X"), in_dims);
   }

@@ -68,6 +68,12 @@ class GraphSendRecvOpMaker : public framework::OpProtoAndCheckerMaker {
                 "tensors of Dst_index.")
         .SetDefault("SUM")
         .InEnum({"SUM", "MEAN", "MIN", "MAX"});
+    AddAttr<int64_t>(
+        "out_size",
+        "(int64_t, default 0) "
+        "Define the first dimension of the output tensor. "
+        "If set to the default 0, Out has the same shape as X.")
+        .SetDefault(0);
     AddComment(R"DOC(
 Graph Learning Send_Recv combine operator.

@@ -93,14 +99,14 @@ class GraphSendRecvGradOpMaker : public framework::SingleGradOpMaker<T> {
     op->SetType("graph_send_recv_grad");
     op->SetInput("Src_index", this->Input("Src_index"));
     op->SetInput("Dst_index", this->Input("Dst_index"));
+    op->SetInput("X", this->Input("X"));

     if (BOOST_GET_CONST(std::string, this->GetAttr("pool_type")) == "MEAN") {
       op->SetInput("Dst_count", this->Output("Dst_count"));
     }

     if (BOOST_GET_CONST(std::string, this->GetAttr("pool_type")) == "MIN" ||
         BOOST_GET_CONST(std::string, this->GetAttr("pool_type")) == "MAX") {
-      op->SetInput("X", this->Input("X"));
       op->SetInput("Out", this->Output("Out"));
     }
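These changes work together with the new out_size attribute: once Out's first dimension can be smaller than X's, the gradient with respect to X can no longer take its shape from out_grad, so the grad op now reads the shape of X directly and always receives X as an input. A minimal standalone sketch of that rule, with an assumed helper name (not part of the PR):

#include <cstdint>
#include <vector>

// Sketch: x_grad must mirror X, not out_grad, because out_size can shrink
// Out's first dimension (hypothetical helper, not Paddle API).
std::vector<int64_t> XGradDims(const std::vector<int64_t>& x_dims,
                               const std::vector<int64_t>& out_grad_dims) {
  (void)out_grad_dims;  // e.g. X: {100, 128}, out_grad: {61, 128}
  return x_dims;        // x_grad: {100, 128}
}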
17 changes: 15 additions & 2 deletions paddle/phi/infermeta/ternary.cc
@@ -145,6 +145,7 @@ void GraphSendRecvInferMeta(const MetaTensor& x,
                             const MetaTensor& src_index,
                             const MetaTensor& dst_index,
                             const std::string& pool_type,
+                            int64_t out_size,
                             MetaTensor* out,
                             MetaTensor* dst_count) {
   auto src_index_dims = src_index.dims();

@@ -187,11 +188,23 @@ void GraphSendRecvInferMeta(const MetaTensor& x,
                         "Src_index and Dst_index should have the same shape."));

   auto dims = x.dims();
-  out->set_dims(dims);
+  if (out_size <= 0) {
+    out->set_dims(dims);
+  } else {
+    std::vector<int64_t> dims_ = phi::vectorize(dims);
+    if (dims_.size() > 0) {
+      dims_[0] = out_size;
+    }
+    out->set_dims(phi::make_ddim(dims_));
+  }
   out->set_dtype(x.dtype());

   if (pool_type == "MEAN") {
-    dst_count->set_dims({dims[0]});
+    if (out_size <= 0) {
+      dst_count->set_dims({dims[0]});
+    } else {
+      dst_count->set_dims({out_size});
+    }
     dst_count->set_dtype(DataType::INT32);
   }
 }
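The shape rule above is easiest to see with concrete numbers. A minimal standalone sketch that mirrors the infermeta logic (the helper name InferOutDims is ours, not Paddle API):

#include <cstdint>
#include <vector>

// Mirrors GraphSendRecvInferMeta's shape rule (sketch only).
std::vector<int64_t> InferOutDims(std::vector<int64_t> x_dims, int64_t out_size) {
  if (out_size > 0 && !x_dims.empty()) {
    x_dims[0] = out_size;  // only the leading dimension is overridden
  }
  return x_dims;
}

// InferOutDims({100, 128}, 0)  -> {100, 128}  (Out keeps X's shape)
// InferOutDims({100, 128}, 61) -> {61, 128}   (first dim replaced by out_size)
// For pool_type == "MEAN", Dst_count's length follows the same first dim.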
1 change: 1 addition & 0 deletions paddle/phi/infermeta/ternary.h
@@ -49,6 +49,7 @@ void GraphSendRecvInferMeta(const MetaTensor& x,
                             const MetaTensor& src_index,
                             const MetaTensor& dst_index,
                             const std::string& pool_type,
+                            int64_t out_size,
                             MetaTensor* out,
                             MetaTensor* dst_count);
34 changes: 13 additions & 21 deletions paddle/phi/kernels/cpu/graph_send_recv_grad_kernel.cc
@@ -23,15 +23,14 @@
 namespace phi {

 template <typename T, typename IndexT, typename Functor>
-void GraphSendRecvCpuGradLoop(const int& input_size,
-                              const int& index_size,
+void GraphSendRecvCpuGradLoop(const int& index_size,
                               const IndexT* s_index,
                               const IndexT* d_index,
                               const DenseTensor& src,
+                              const DenseTensor& input,
                               DenseTensor* dst,
                               const std::string& pool_type,
                               const int* dst_count = nullptr,
-                              const DenseTensor* input = nullptr,
                               const DenseTensor* output = nullptr) {
   if (pool_type == "SUM") {
     Functor functor;

@@ -55,7 +54,7 @@ void GraphSendRecvCpuGradLoop(const int& input_size,
     for (int i = 0; i < index_size; ++i) {
       const IndexT& forward_src_idx = d_index[i];
       const IndexT& forward_dst_idx = s_index[i];
-      auto input_slice = input->Slice(forward_src_idx, forward_src_idx + 1);
+      auto input_slice = input.Slice(forward_src_idx, forward_src_idx + 1);
       auto output_slice = output->Slice(forward_dst_idx, forward_dst_idx + 1);
       auto eigen_input = phi::EigenVector<T>::Flatten(input_slice);
       auto eigen_output = phi::EigenVector<T>::Flatten(output_slice);

@@ -73,18 +72,18 @@ template <typename Context, typename T, typename IndexT>
 void GraphSendRecvGradOpKernelLaunchHelper(
     const Context& ctx,
     const DenseTensor& out_grad,
+    const DenseTensor& x,
     const DenseTensor& src_index,
     const DenseTensor& dst_index,
     const std::string& pool_type,
     DenseTensor* x_grad,
     const DenseTensor* dst_count = nullptr,
-    const DenseTensor* x = nullptr,
     const DenseTensor* out = nullptr) {
   const int& index_size = dst_index.dims()[0];

   ctx.template Alloc<T>(x_grad);
   T* p_output = x_grad->data<T>();
-  const auto& src_dims = out_grad.dims();
+  const auto& src_dims = x.dims();
   int64_t memset_size = 1;
   for (int i = 0; i < src_dims.size(); ++i) memset_size *= src_dims[i];
   const size_t& memset_bytes = memset_size * sizeof(T);

@@ -97,37 +96,30 @@ void GraphSendRecvGradOpKernelLaunchHelper(

   if (pool_type == "SUM") {
     GraphSendRecvCpuGradLoop<T, IndexT, GraphSendRecvSumFunctor<T>>(
-        src_dims[0], index_size, d_index, s_index, out_grad, x_grad, pool_type);
+        index_size, d_index, s_index, out_grad, x, x_grad, pool_type);
   } else if (pool_type == "MEAN") {
     const int* s_count = dst_count->data<int>();
     // Functor not used here.
-    GraphSendRecvCpuGradLoop<T, IndexT, GraphSendRecvSumFunctor<T>>(src_dims[0],
-                                                                    index_size,
-                                                                    d_index,
-                                                                    s_index,
-                                                                    out_grad,
-                                                                    x_grad,
-                                                                    pool_type,
-                                                                    s_count);
+    GraphSendRecvCpuGradLoop<T, IndexT, GraphSendRecvSumFunctor<T>>(
+        index_size, d_index, s_index, out_grad, x, x_grad, pool_type, s_count);
   } else if (pool_type == "MIN" || pool_type == "MAX") {
     // Functor not used here.
-    GraphSendRecvCpuGradLoop<T, IndexT, GraphSendRecvMinFunctor<T>>(src_dims[0],
-                                                                    index_size,
+    GraphSendRecvCpuGradLoop<T, IndexT, GraphSendRecvMinFunctor<T>>(index_size,
                                                                     d_index,
                                                                     s_index,
                                                                     out_grad,
+                                                                    x,
                                                                     x_grad,
                                                                     pool_type,
                                                                     nullptr,
-                                                                    x,
                                                                     out);
   }
 }

 template <typename T, typename Context>
 void GraphSendRecvGradKernel(const Context& ctx,
                              const DenseTensor& out_grad,
-                             paddle::optional<const DenseTensor&> x,
+                             const DenseTensor& x,
                              paddle::optional<const DenseTensor&> out,
                              const DenseTensor& src_index,
                              const DenseTensor& dst_index,

@@ -139,23 +131,23 @@ void GraphSendRecvGradKernel(const Context& ctx,
     GraphSendRecvGradOpKernelLaunchHelper<Context, T, int32_t>(
         ctx,
         out_grad,
+        x,
         src_index,
         dst_index,
         pool_type,
         x_grad,
         dst_count.get_ptr(),
-        x.get_ptr(),
         out.get_ptr());
   } else if (index_type == phi::DataType::INT64) {
     GraphSendRecvGradOpKernelLaunchHelper<Context, T, int64_t>(
         ctx,
         out_grad,
+        x,
         src_index,
         dst_index,
         pool_type,
         x_grad,
         dst_count.get_ptr(),
-        x.get_ptr(),
         out.get_ptr());
   }
 }
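For MIN and MAX pooling, the backward loop above routes gradient only to the positions whose values produced the extremum. A self-contained sketch of that rule, simplified to scalar features and assuming the usual send_recv semantics, where the forward pass computes Out[dst_index[i]] = reduce(X[src_index[i]]) (the helper name is ours, not Paddle's):

#include <cstdint>
#include <vector>

// Sketch of the MIN/MAX backward rule (1-D features for clarity; the real
// kernel works on row slices of the tensors).
void MinMaxGrad(const std::vector<float>& out_grad,
                const std::vector<float>& x,
                const std::vector<float>& out,
                const std::vector<int64_t>& src_index,
                const std::vector<int64_t>& dst_index,
                std::vector<float>* x_grad) {
  x_grad->assign(x.size(), 0.0f);
  for (size_t i = 0; i < src_index.size(); ++i) {
    const int64_t src = src_index[i];
    const int64_t dst = dst_index[i];
    // Only an element that attained the min/max receives gradient.
    if (x[src] == out[dst]) (*x_grad)[src] += out_grad[dst];
  }
}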
17 changes: 14 additions & 3 deletions paddle/phi/kernels/cpu/graph_send_recv_kernel.cc
@@ -83,6 +83,7 @@ void GraphSendRecvOpKernelLaunchHelper(const Context& ctx,
                                        const DenseTensor& src_index,
                                        const DenseTensor& dst_index,
                                        const std::string& pool_type,
+                                       int64_t out_size,
                                        DenseTensor* out,
                                        DenseTensor* dst_count = nullptr) {
   const int& index_size = src_index.dims()[0];

@@ -91,7 +92,16 @@ void GraphSendRecvOpKernelLaunchHelper(const Context& ctx,
   T* p_output = out->data<T>();
   const auto& src_dims = x.dims();
   int64_t memset_size = 1;
-  for (int i = 0; i < src_dims.size(); ++i) memset_size *= src_dims[i];
+  if (out_size <= 0) {
+    for (int i = 0; i < src_dims.size(); ++i) {
+      memset_size *= src_dims[i];
+    }
+  } else {
+    memset_size = out_size;
+    for (int i = 1; i < src_dims.size(); ++i) {
+      memset_size *= src_dims[i];
+    }
+  }
   const size_t& memset_bytes = memset_size * sizeof(T);
   memset(p_output, 0, memset_bytes);

@@ -129,15 +139,16 @@ void GraphSendRecvKernel(const Context& ctx,
                          const DenseTensor& src_index,
                          const DenseTensor& dst_index,
                          const std::string& pool_type,
+                         int64_t out_size,
                          DenseTensor* out,
                          DenseTensor* dst_count) {
   auto index_type = src_index.dtype();
   if (index_type == phi::DataType::INT32) {
     GraphSendRecvOpKernelLaunchHelper<Context, T, int32_t>(
-        ctx, x, src_index, dst_index, pool_type, out, dst_count);
+        ctx, x, src_index, dst_index, pool_type, out_size, out, dst_count);
   } else if (index_type == phi::DataType::INT64) {
     GraphSendRecvOpKernelLaunchHelper<Context, T, int64_t>(
-        ctx, x, src_index, dst_index, pool_type, out, dst_count);
+        ctx, x, src_index, dst_index, pool_type, out_size, out, dst_count);
   }
 }
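The zero-fill size computed above is simply the element count of Out under the out_size rule. A standalone sketch (assumed helper name, not Paddle API):

#include <cstdint>
#include <vector>

// Number of elements zero-initialized in Out (sketch of the loop above).
int64_t NumOutElements(const std::vector<int64_t>& x_dims, int64_t out_size) {
  if (x_dims.empty()) return 0;
  int64_t n = (out_size > 0) ? out_size : x_dims[0];  // leading dimension
  for (size_t i = 1; i < x_dims.size(); ++i) {
    n *= x_dims[i];  // trailing dimensions always come from X
  }
  return n;
}

// NumOutElements({100, 128}, 0)  -> 12800 elements zeroed
// NumOutElements({100, 128}, 61) -> 7808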
12 changes: 6 additions & 6 deletions paddle/phi/kernels/gpu/graph_send_recv_grad_kernel.cu
@@ -28,19 +28,19 @@ template <typename Context, typename T, typename IndexT>
 void GraphSendRecvGradOpCUDAKernelLaunchHelper(
     const Context& ctx,
     const DenseTensor& out_grad,
+    const DenseTensor& x,
     const DenseTensor& src_index,
     const DenseTensor& dst_index,
     const std::string& pool_type,
     DenseTensor* x_grad,
     const DenseTensor* dst_count = nullptr,
-    const DenseTensor* x = nullptr,
     const DenseTensor* out = nullptr) {
   const int& index_size = dst_index.dims()[0];

   ctx.template Alloc<T>(x_grad);
   T* p_output = x_grad->data<T>();

-  const auto& src_dims = out_grad.dims();
+  const auto& src_dims = x.dims();
   int64_t memset_size = 1;
   for (int i = 0; i < src_dims.size(); ++i) {
     memset_size *= src_dims[i];

@@ -86,7 +86,7 @@ void GraphSendRecvGradOpCUDAKernelLaunchHelper(
     ManipulateMeanGradCUDAKernel<T, IndexT><<<grid, block, 0, ctx.stream()>>>(
         p_src, d_index, s_index, p_output, index_size, slice_size, s_count);
   } else if (pool_type == "MAX" || pool_type == "MIN") {
-    const T* ptr_input = x->data<T>();
+    const T* ptr_input = x.data<T>();
     const T* ptr_output = out->data<T>();
     ManipulateMinMaxGradCUDAKernel<T, IndexT><<<grid, block, 0, ctx.stream()>>>(
         p_src,

@@ -103,7 +103,7 @@ void GraphSendRecvGradOpCUDAKernelLaunchHelper(
 template <typename T, typename Context>
 void GraphSendRecvGradKernel(const Context& ctx,
                              const DenseTensor& out_grad,
-                             paddle::optional<const DenseTensor&> x,
+                             const DenseTensor& x,
                              paddle::optional<const DenseTensor&> out,
                              const DenseTensor& src_index,
                              const DenseTensor& dst_index,

@@ -115,23 +115,23 @@ void GraphSendRecvGradKernel(const Context& ctx,
     GraphSendRecvGradOpCUDAKernelLaunchHelper<Context, T, int32_t>(
         ctx,
         out_grad,
+        x,
         src_index,
         dst_index,
         pool_type,
         x_grad,
         dst_count.get_ptr(),
-        x.get_ptr(),
         out.get_ptr());
   } else if (index_type == phi::DataType::INT64) {
     GraphSendRecvGradOpCUDAKernelLaunchHelper<Context, T, int64_t>(
         ctx,
         out_grad,
+        x,
         src_index,
         dst_index,
         pool_type,
         x_grad,
         dst_count.get_ptr(),
-        x.get_ptr(),
         out.get_ptr());
   }
 }
26 changes: 22 additions & 4 deletions paddle/phi/kernels/gpu/graph_send_recv_kernel.cu
@@ -32,15 +32,23 @@ void GraphSendRecvOpCUDAKernelLaunchHelper(const Context& ctx,
                                            const DenseTensor& src_index,
                                            const DenseTensor& dst_index,
                                            const std::string& pool_type,
+                                           int64_t out_size,
                                            DenseTensor* out,
                                            DenseTensor* dst_count = nullptr) {
   const int& index_size = src_index.dims()[0];
   ctx.template Alloc<T>(out);
   T* p_output = out->data<T>();
   const auto& src_dims = x.dims();
   int64_t memset_size = 1;

A review thread on this hunk:

Contributor: A question here: what is the purpose of presetting the shape? And should the out_size that callers set be checked against the actual size?

Contributor (Author): Presetting the shape lets us dynamically shrink the actual output shape. For example, if the input X has shape (100, 128) but the largest node index touched by message passing is 60, the output shape can be set to (61, 128) and the remaining rows dropped dynamically.

I plan to handle that check on the Python side. Alternatively, the Python-side function could take a flag deciding whether to compress dynamically, and we would then set out_size ourselves to max(dst_index) + 1.

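A minimal sketch of the fallback suggested in the reply above, deriving out_size from the destination indices (hypothetical helper, not part of this PR):

#include <algorithm>
#include <cstdint>
#include <vector>

// out_size = max(dst_index) + 1, as suggested in the author's reply.
int64_t InferOutSize(const std::vector<int64_t>& dst_index) {
  if (dst_index.empty()) return 0;  // fall back to "Out has X's shape"
  return *std::max_element(dst_index.begin(), dst_index.end()) + 1;
}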
-  for (int i = 0; i < src_dims.size(); ++i) {
-    memset_size *= src_dims[i];
+  if (out_size <= 0) {
+    for (int i = 0; i < src_dims.size(); ++i) {
+      memset_size *= src_dims[i];
+    }
+  } else {
+    memset_size = out_size;
+    for (int i = 1; i < src_dims.size(); ++i) {
+      memset_size *= src_dims[i];
+    }
   }
   const size_t& memset_bytes = memset_size * sizeof(T);
   if (pool_type == "SUM" || pool_type == "MEAN") {

@@ -100,6 +108,9 @@ void GraphSendRecvOpCUDAKernelLaunchHelper(const Context& ctx,
                    IndexT>><<<grid, block, 0, ctx.stream()>>>(
         p_src, s_index, d_index, p_output, index_size, slice_size, functor);

+    if (out_size > 0) {
+      input_size = out_size;
+    }
     int64_t grid_max_tmp = (input_size * slice_size + block - 1) / block;
     int64_t grid_max =
         grid_max_tmp < max_grid_dimx ? grid_max_tmp : max_grid_dimx;

@@ -114,6 +125,9 @@ void GraphSendRecvOpCUDAKernelLaunchHelper(const Context& ctx,
                    IndexT>><<<grid, block, 0, ctx.stream()>>>(
         p_src, s_index, d_index, p_output, index_size, slice_size, functor);

+    if (out_size > 0) {
+      input_size = out_size;
+    }
     int64_t grid_min_tmp = (input_size * slice_size + block - 1) / block;
     int64_t grid_min =
         grid_min_tmp < max_grid_dimx ? grid_min_tmp : max_grid_dimx;

@@ -130,6 +144,9 @@ void GraphSendRecvOpCUDAKernelLaunchHelper(const Context& ctx,

     ctx.template Alloc<int32_t>(dst_count);
     int32_t* p_dst_count = dst_count->data<int32_t>();
+    if (out_size > 0) {
+      input_size = out_size;
+    }

 #ifdef PADDLE_WITH_HIP
     hipMemset(p_dst_count, 0, input_size * sizeof(int));

@@ -155,15 +172,16 @@ void GraphSendRecvKernel(const Context& ctx,
                          const DenseTensor& src_index,
                          const DenseTensor& dst_index,
                          const std::string& pool_type,
+                         int64_t out_size,
                          DenseTensor* out,
                          DenseTensor* dst_count) {
   auto index_type = src_index.dtype();
   if (index_type == phi::DataType::INT32) {
     GraphSendRecvOpCUDAKernelLaunchHelper<Context, T, int32_t>(
-        ctx, x, src_index, dst_index, pool_type, out, dst_count);
+        ctx, x, src_index, dst_index, pool_type, out_size, out, dst_count);
   } else if (index_type == phi::DataType::INT64) {
     GraphSendRecvOpCUDAKernelLaunchHelper<Context, T, int64_t>(
-        ctx, x, src_index, dst_index, pool_type, out, dst_count);
+        ctx, x, src_index, dst_index, pool_type, out_size, out, dst_count);
   }
 }
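The repeated "if (out_size > 0) input_size = out_size;" guards appear because the post-processing kernels (re-initializing untouched MIN/MAX rows, zeroing and applying Dst_count) iterate over Out's rows, and Out no longer always has x.dims()[0] of them. A condensed sketch of the launch-size pattern (assumed helper, not the PR's code):

#include <cstdint>

// Grid size for the post-processing kernels, with out_size folded in
// (sketch of the pattern repeated above).
int64_t PostProcessGridSize(int64_t input_size, int64_t out_size,
                            int64_t slice_size, int64_t block,
                            int64_t max_grid_dimx) {
  if (out_size > 0) {
    input_size = out_size;  // Out has out_size rows, not x.dims()[0]
  }
  const int64_t grid = (input_size * slice_size + block - 1) / block;
  return grid < max_grid_dimx ? grid : max_grid_dimx;
}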
2 changes: 1 addition & 1 deletion paddle/phi/kernels/graph_send_recv_grad_kernel.h
@@ -23,7 +23,7 @@ namespace phi {
 template <typename T, typename Context>
 void GraphSendRecvGradKernel(const Context& ctx,
                              const DenseTensor& out_grad,
-                             paddle::optional<const DenseTensor&> x,
+                             const DenseTensor& x,
                              paddle::optional<const DenseTensor&> out,
                              const DenseTensor& src_index,
                              const DenseTensor& dst_index,