Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

implementation of broadcast sub backward by reduce #37754

Merged
merged 4 commits into from
Dec 8, 2021
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
61 changes: 60 additions & 1 deletion paddle/fluid/operators/elementwise/elementwise_sub_op.cu
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,8 @@ limitations under the License. */

#include "paddle/fluid/operators/elementwise/elementwise_op_broadcast.cu.h"
#include "paddle/fluid/operators/elementwise/elementwise_sub_op.h"
#include "paddle/fluid/operators/reduce_ops/reduce_functor_op.h"
#include "paddle/fluid/operators/reduce_ops/reduce_op.cu.h"
#include "paddle/fluid/platform/complex.h"
#include "paddle/fluid/platform/float16.h"

Expand All @@ -30,12 +32,69 @@ static __global__ void SimpleElemwiseSubGradCUDAKernel(const T* dout,
int col = blockIdx.x * blockDim.x + threadIdx.x;

while (col < size) {
dx[col] = dout[col];
if (dx != nullptr) {
dx[col] = dout[col];
}
dy[col] = -dout[col];
col += blockDim.x * gridDim.x;
}
}

template <typename DeviceContext, typename T>
typename std::enable_if<
std::is_same<DeviceContext, platform::CUDADeviceContext>::value>::type
default_elementwise_sub_grad(const framework::ExecutionContext& ctx,
const framework::Tensor* x,
const framework::Tensor* y,
const framework::Tensor* out,
const framework::Tensor* dout,
framework::Tensor* dx, framework::Tensor* dy) {
int axis = ctx.Attr<int>("axis");
auto* dout_data = dout->data<T>();
// dx
if (dx != nullptr) {
auto* dx_data = dx->mutable_data<T>(ctx.GetPlace());
if (dx->dims() == dout->dims()) {
if (dx_data != dout_data) {
framework::TensorCopy(
*dout, ctx.GetPlace(),
ctx.template device_context<platform::DeviceContext>(), dx);
}
} else {
// For inplace strategy, dx will be stored in addr of dout, which makes
// the result of dy wrong.
if (dx->IsSharedBufferWith(*dout)) {
dx->clear();
dx->mutable_data<T>(x->dims(), ctx.GetPlace());
}
std::vector<int> reduce_dims = GetReduceDim(x->dims(), out->dims(), axis);
gpuStream_t stream = ctx.cuda_device_context().stream();
TensorReduceFunctorImpl<T, T, CustomSum>(*dout, dx, reduce_dims, stream);
}
}
// dy
if (dy != nullptr) {
auto* dy_data = dy->mutable_data<T>(ctx.GetPlace());
if (dy->dims() == dout->dims()) {
if (dy_data != dout_data) {
dim3 block_size = dim3(ELEMENTWISE_BLOCK_SIZE, 1);
auto size = dy->numel();
dim3 grid_size = dim3(
(size + ELEMENTWISE_BLOCK_SIZE - 1) / ELEMENTWISE_BLOCK_SIZE, 1);
SimpleElemwiseSubGradCUDAKernel<T><<<
grid_size, block_size, 0,
ctx.template device_context<plat::CUDADeviceContext>().stream()>>>(
dout->data<T>(), size, nullptr,
dy->mutable_data<T>(ctx.GetPlace()));
}
} else {
std::vector<int> reduce_dims = GetReduceDim(y->dims(), out->dims(), axis);
gpuStream_t stream = ctx.cuda_device_context().stream();
TensorReduceFunctorImpl<T, T, CustomSub>(*dout, dy, reduce_dims, stream);
}
}
}

template <typename DeviceContext, typename T>
typename std::enable_if<
std::is_same<DeviceContext, plat::CUDADeviceContext>::value>::type
Expand Down
35 changes: 28 additions & 7 deletions paddle/fluid/operators/elementwise/elementwise_sub_op.h
Original file line number Diff line number Diff line change
Expand Up @@ -71,6 +71,21 @@ struct SubGradDY {
HOSTDEVICE T operator()(T x, T y, T out, T dout) const { return -dout; }
};

template <typename DeviceContext, typename T>
typename std::enable_if<
std::is_same<DeviceContext, platform::CPUDeviceContext>::value>::type
default_elementwise_sub_grad(const framework::ExecutionContext& ctx,
ZzSean marked this conversation as resolved.
Show resolved Hide resolved
const framework::Tensor* x,
const framework::Tensor* y,
const framework::Tensor* out,
const framework::Tensor* dout,
framework::Tensor* dx, framework::Tensor* dy) {
int axis = ctx.Attr<int>("axis");

ElemwiseExplicitGradCompute<DeviceContext, T, SubGradDX<T>, SubGradDY<T>>(
ctx, *x, *y, *out, *dout, axis, dx, dy, SubGradDX<T>(), SubGradDY<T>());
}

template <typename DeviceContext, typename T>
typename std::enable_if<
std::is_same<DeviceContext, platform::CPUDeviceContext>::value>::type
Expand All @@ -79,13 +94,21 @@ elementwise_sub_grad(const framework::ExecutionContext& ctx,
const framework::Tensor* out,
const framework::Tensor* dout, framework::Tensor* dx,
framework::Tensor* dy) {
int axis = ctx.Attr<int>("axis");
ElemwiseExplicitGradCompute<DeviceContext, T, SubGradDX<T>, SubGradDY<T>>(
ctx, *x, *y, *out, *dout, axis, dx, dy, SubGradDX<T>(), SubGradDY<T>());
default_elementwise_sub_grad<DeviceContext, T>(ctx, x, y, out, dout, dx, dy);
}

#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP)
// cuda definition
template <typename DeviceContext, typename T>
typename std::enable_if<
std::is_same<DeviceContext, platform::CUDADeviceContext>::value>::type
default_elementwise_sub_grad(const framework::ExecutionContext& ctx,
const framework::Tensor* x,
const framework::Tensor* y,
const framework::Tensor* out,
const framework::Tensor* dout,
framework::Tensor* dx, framework::Tensor* dy);

template <typename DeviceContext, typename T>
typename std::enable_if<
std::is_same<DeviceContext, platform::CUDADeviceContext>::value>::type
Expand All @@ -108,15 +131,13 @@ class ElementwiseSubGradKernel : public ElemwiseGradKernel<T> {
auto* dout = ctx.Input<Tensor>(framework::GradVarName("Out"));
auto* dx = ctx.Output<Tensor>(framework::GradVarName("X"));
auto* dy = ctx.Output<Tensor>(framework::GradVarName("Y"));
int axis = ctx.Attr<int>("axis");
// skip out
auto* out = dout;
if (dx != nullptr && dy != nullptr && (dx->dims() == dy->dims())) {
elementwise_sub_grad<DeviceContext, T>(ctx, x, y, out, dout, dx, dy);
} else {
ElemwiseExplicitGradCompute<DeviceContext, T, SubGradDX<T>, SubGradDY<T>>(
ctx, *x, *y, *out, *dout, axis, dx, dy, SubGradDX<T>(),
SubGradDY<T>());
default_elementwise_sub_grad<DeviceContext, T>(ctx, x, y, out, dout, dx,
dy);
}
}
};
Expand Down
14 changes: 14 additions & 0 deletions paddle/fluid/operators/kernel_primitives/functor_primitives.h
Original file line number Diff line number Diff line change
Expand Up @@ -86,6 +86,20 @@ struct DivideFunctor {
Tx n_inv;
};

/**
* @brief Default inverse functor
*/
template <typename Tx, typename Ty = Tx>
struct InverseFunctor {
HOSTDEVICE inline InverseFunctor() {}

HOSTDEVICE explicit inline InverseFunctor(int n) {}

HOSTDEVICE inline Ty operator()(const Tx& x) const {
return static_cast<Ty>(-x);
}
};

/**
* @brief Default unary square functor
*/
Expand Down
11 changes: 11 additions & 0 deletions paddle/fluid/operators/reduce_ops/reduce_functor_op.h
Original file line number Diff line number Diff line change
Expand Up @@ -64,6 +64,17 @@ struct CustomSum {
}
};

template <typename Tx, typename Ty = Tx>
struct CustomSub {
using Transformer = kps::InverseFunctor<Tx>;

inline Ty initial() { return static_cast<Ty>(0.0f); }

__device__ __forceinline__ Ty operator()(const Ty &a, const Ty &b) const {
return b + a;
}
};

template <typename Tx, typename Ty = Tx>
struct CustomMean {
using Transformer = kps::DivideFunctor<Tx>;
Expand Down