Add an inner loop for index_select_grad_init() in the index_select op when dealing with large-shape data #41563

Merged (4 commits, Apr 12, 2022)
15 changes: 5 additions & 10 deletions paddle/phi/kernels/funcs/gather.cu.h
@@ -17,6 +17,7 @@ limitations under the License. */
 #include <vector>
 #include "paddle/fluid/memory/memcpy.h"
 // TODO(paddle-dev): move gpu_primitives.h to phi
+#include "paddle/fluid/platform/device/gpu/gpu_launch_config.h"
 #include "paddle/fluid/platform/device/gpu/gpu_primitives.h"
 #include "paddle/phi/backends/gpu/gpu_launch_config.h"
 #include "paddle/phi/common/place.h"
@@ -110,11 +111,8 @@ void GPUGather(const phi::GPUContext& ctx,

   int block = 512;
   int64_t n = slice_size * index_size;
-  int64_t grid = (n + block - 1) / block;
-  unsigned int maxGridDimX = ctx.GetCUDAMaxGridDimSize()[0];
-  if (grid > maxGridDimX) {
-    grid = maxGridDimX;
-  }
+  dim3 grid = dim3((n + block - 1) / block);
+  paddle::platform::LimitGridDim(ctx, &grid);

   GatherCUDAKernel<T, IndexT><<<grid, block, 0, ctx.stream()>>>(
       p_src, p_index, p_output, index_size, slice_size);
@@ -155,11 +153,8 @@ void GPUGatherNd(const phi::GPUContext& ctx,

   int block = 512;
   int64_t n = slice_size * remain_numel;
-  int64_t grid = (n + block - 1) / block;
-  unsigned int maxGridDimX = ctx.GetCUDAMaxGridDimSize()[0];
-  if (grid > maxGridDimX) {
-    grid = maxGridDimX;
-  }
+  dim3 grid = dim3((n + block - 1) / block);
+  paddle::platform::LimitGridDim(ctx, &grid);

   GatherNdCUDAKernel<T, IndexT><<<grid, block, 0, ctx.stream()>>>(p_input,
                                                                   g_input_dims,
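Note: paddle::platform::LimitGridDim (declared in the newly included paddle/fluid/platform/device/gpu/gpu_launch_config.h) replaces the hand-rolled clamps above with one shared helper. A minimal sketch of what such a clamp does, assuming it caps each grid dimension at the device maximum; the name LimitGridDimSketch and the cudaDeviceProp-based signature are illustrative, not Paddle's actual API:

#include <algorithm>
#include <cuda_runtime.h>

// Clamp every grid dimension to the device limit, as the removed inline
// code did for the x-dimension only.
inline void LimitGridDimSketch(const cudaDeviceProp& prop, dim3* grid) {
  grid->x = std::min<unsigned int>(grid->x, prop.maxGridSize[0]);
  grid->y = std::min<unsigned int>(grid->y, prop.maxGridSize[1]);
  grid->z = std::min<unsigned int>(grid->z, prop.maxGridSize[2]);
}

Centralizing the clamp keeps the gather and scatter call sites consistent instead of repeating the min logic at each launch.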
16 changes: 7 additions & 9 deletions paddle/phi/kernels/funcs/scatter.cu.h
@@ -15,6 +15,7 @@ limitations under the License. */
 #pragma once
 #include <unordered_set>
 #include <vector>
+#include "paddle/fluid/platform/device/gpu/gpu_launch_config.h"
 #include "paddle/fluid/platform/device/gpu/gpu_primitives.h"
 #include "paddle/phi/common/place.h"
 #include "paddle/phi/core/dense_tensor.h"
@@ -155,9 +156,8 @@ void GPUScatterAssign(const phi::GPUContext& ctx,
   // set block and grid num
   int block = 512;
   int64_t n = slice_size * index_size;
-  int64_t grid = (n + block - 1) / block;
-  unsigned int maxGridDimX = ctx.GetCUDAMaxGridDimSize()[0];
-  grid = grid > maxGridDimX ? maxGridDimX : grid;
+  dim3 grid = dim3((n + block - 1) / block);
+  paddle::platform::LimitGridDim(ctx, &grid);

   // if not overwrite mode, init data
   if (!overwrite) {
@@ -188,9 +188,8 @@ void GPUScatterGradForX(const phi::GPUContext& ctx,
   int64_t block = 512;
   int64_t n = slice_size * index_size;
-  int64_t height = (n + block - 1) / block;

-  int64_t max_grid_dimx = ctx.GetCUDAMaxGridDimSize()[0];
-  int64_t grid = height < max_grid_dimx ? height : max_grid_dimx;
+  dim3 grid = dim3((n + block - 1) / block);
+  paddle::platform::LimitGridDim(ctx, &grid);

   ScatterInitCUDAKernel<T, IndexT><<<grid, block, 0, ctx.stream()>>>(
       p_index, p_output, index_size, slice_size);
@@ -230,9 +229,8 @@ void GPUScatterNdAdd(const phi::GPUContext& ctx,

   int block = 512;
   int64_t n = slice_size * remain_numel;
-  int64_t grid = (n + block - 1) / block;
-  unsigned int maxGridDimX = ctx.GetCUDAMaxGridDimSize()[0];
-  grid = grid > maxGridDimX ? maxGridDimX : grid;
+  dim3 grid = dim3((n + block - 1) / block);
+  paddle::platform::LimitGridDim(ctx, &grid);

   ScatterNdCUDAKernel<T, IndexT><<<grid, block, 0, ctx.stream()>>>(
       p_update,
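Note: capping the grid like this is only correct because the launched kernels (GatherCUDAKernel, ScatterInitCUDAKernel, etc.) iterate with a grid-stride loop, so each thread processes multiple elements whenever fewer than ceil(n / block) blocks run. A minimal sketch of the pattern, written as a zero-fill kernel of the same shape as ScatterInitCUDAKernel (whose body is not part of this diff):

// Grid-stride loop: thread i handles elements i, i + blockDim.x * gridDim.x,
// i + 2 * blockDim.x * gridDim.x, ... so all n elements are covered no
// matter how far the grid was clamped.
template <typename T>
__global__ void FillZeroSketch(T* out, int64_t n) {
  for (int64_t i = static_cast<int64_t>(blockIdx.x) * blockDim.x + threadIdx.x;
       i < n;
       i += static_cast<int64_t>(blockDim.x) * gridDim.x) {
    out[i] = static_cast<T>(0);
  }
}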
16 changes: 4 additions & 12 deletions paddle/phi/kernels/gpu/index_select_grad_kernel.cu
@@ -19,6 +19,7 @@
 #include "paddle/phi/backends/gpu/gpu_info.h"
 #include "paddle/phi/core/kernel_registry.h"
 #include "paddle/phi/core/utils/data_type.h"
+#include "paddle/phi/kernels/funcs/math_function.h"

 DECLARE_bool(cudnn_deterministic);
@@ -35,7 +36,7 @@ __global__ void index_select_grad_cuda_kernel(const T* output_grad,
                                               int64_t stride,
                                               int64_t size,
                                               int64_t delta) {
-  CUDA_KERNEL_LOOP(idx, N) {
+  CUDA_KERNEL_LOOP_TYPE(idx, N, int64_t) {
     int64_t pre_idx = idx / (stride * size);
     int64_t dim_idx = idx % (stride * size) / stride;
     IndexT src_dim_idx = index[dim_idx];
@@ -45,15 +46,6 @@ __global__ void index_select_grad_cuda_kernel(const T* output_grad,
   }
 }

-template <typename T>
-__global__ void index_select_grad_init(T* input_grad, int64_t N) {
-  int64_t idx = blockIdx.x * blockDim.x + threadIdx.x;
-  if (idx >= N) {
-    return;
-  }
-  input_grad[idx] = 0.0;
-}
-
 template <typename T, typename Context>
 void IndexSelectGradKernel(const Context& ctx,
                            const DenseTensor& x,
@@ -97,8 +89,8 @@ void IndexSelectGradKernel(const Context& ctx,
   dim3 grid_dim = dim3((numel + block_dim - 1) / block_dim);
   paddle::platform::LimitGridDim(ctx, &grid_dim);

-  index_select_grad_init<T><<<grid_dim, block_dim, 0, stream>>>(in_grad_data,
-                                                                numel);
+  phi::funcs::SetConstant<phi::GPUContext, T> index_select_grad_init;
+  index_select_grad_init(ctx, x_grad, static_cast<T>(0));

   if (FLAGS_cudnn_deterministic) {
     VLOG(2) << "Run grad kernel of index_select with single thread.";
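Note: the removed index_select_grad_init kernel is exactly the case the PR title describes: it had no inner loop, so each thread zeroed at most one element. Once LimitGridDim caps the grid for large shapes, only gridDim.x * blockDim.x elements get written and the tail of the gradient tensor keeps stale memory; delegating to phi::funcs::SetConstant avoids this. A small host-side illustration of the coverage gap (the 65535 grid cap is an assumed example value, not a measured device limit):

#include <algorithm>
#include <cstdint>
#include <cstdio>

int main() {
  const int64_t block = 512;
  const int64_t n = int64_t(1) << 33;  // ~8.6e9 elements, a "large shape"
  const int64_t grid_cap = 65535;      // assumed clamp, for illustration only
  const int64_t grid = std::min((n + block - 1) / block, grid_cap);
  // A loop-free kernel writes at most one element per launched thread:
  const int64_t covered = grid * block;
  std::printf("zeroed %lld of %lld elements (%.4f%%)\n",
              (long long)covered, (long long)n, 100.0 * covered / n);
  return 0;
}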
2 changes: 1 addition & 1 deletion paddle/phi/kernels/gpu/index_select_kernel.cu
@@ -32,7 +32,7 @@ __global__ void index_select_cuda_kernel(const T* input,
                                          int64_t stride,
                                          int64_t size,
                                          int64_t delta) {
-  CUDA_KERNEL_LOOP(idx, N) {
+  CUDA_KERNEL_LOOP_TYPE(idx, N, int64_t) {
     int64_t pre_idx = idx / (stride * size);
     int64_t dim_idx = idx % (stride * size) / stride;
     IndexT src_dim_idx = index[dim_idx];
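Note: CUDA_KERNEL_LOOP indexes with a 32-bit int, so once N exceeds 2^31 - 1 the index wraps; CUDA_KERNEL_LOOP_TYPE takes the index type as a third argument, and these kernels now pass int64_t. An illustrative expansion of such a macro (assumed shape; Paddle's exact definition lives in gpu_primitives.h):

// Assumed grid-stride expansion; the loop variable's width is chosen by
// the caller so that N > INT_MAX stays exact.
#define KERNEL_LOOP_TYPE_SKETCH(i, n, index_type)                         \
  for (index_type i = static_cast<index_type>(blockIdx.x) * blockDim.x +  \
                      threadIdx.x;                                         \
       i < (n); i += static_cast<index_type>(blockDim.x) * gridDim.x)

// Usage mirroring index_select_cuda_kernel's loop header:
template <typename T>
__global__ void ScaleSketch(T* data, int64_t N, T factor) {
  KERNEL_LOOP_TYPE_SKETCH(idx, N, int64_t) { data[idx] *= factor; }
}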