From 7aa62c82c17b7ad382ff7f79b32d24ca9511ba61 Mon Sep 17 00:00:00 2001
From: FlyingQianMM <245467267@qq.com>
Date: Tue, 22 Mar 2022 06:37:13 +0000
Subject: [PATCH 1/2] add maximum limit for grid of reduce, elementwise and gather

---
 .../platform/device/gpu/gpu_launch_config.h   |  2 ++
 paddle/phi/backends/gpu/gpu_launch_config.h   |  2 ++
 .../phi/kernels/funcs/elementwise_grad_base.h | 20 +++++++++++++++++++
 paddle/phi/kernels/funcs/gather.cu.h          |  4 ++++
 paddle/phi/kernels/funcs/reduce_function.h    | 16 +++++++++++++--
 paddle/phi/kernels/funcs/scatter.cu.h         |  6 +++++-
 6 files changed, 47 insertions(+), 3 deletions(-)

diff --git a/paddle/fluid/platform/device/gpu/gpu_launch_config.h b/paddle/fluid/platform/device/gpu/gpu_launch_config.h
index cb0173fd6d911d..a87008d2ce0272 100644
--- a/paddle/fluid/platform/device/gpu/gpu_launch_config.h
+++ b/paddle/fluid/platform/device/gpu/gpu_launch_config.h
@@ -128,6 +128,8 @@ inline GpuLaunchConfig GetGpuLaunchConfig1D(
   // Number of threads per block shall be larger than 64.
   threads = std::max(64, threads);
   int blocks = DivUp(DivUp(numel, vec_size), threads);
+  int limit_blocks = context.GetCUDAMaxGridDimSize()[0];
+  if (blocks > limit_blocks) blocks = limit_blocks;
 
   GpuLaunchConfig config;
   config.thread_per_block.x = threads;
diff --git a/paddle/phi/backends/gpu/gpu_launch_config.h b/paddle/phi/backends/gpu/gpu_launch_config.h
index e45b4651225882..f859bf0ce24ee8 100644
--- a/paddle/phi/backends/gpu/gpu_launch_config.h
+++ b/paddle/phi/backends/gpu/gpu_launch_config.h
@@ -132,6 +132,8 @@ inline GpuLaunchConfig GetGpuLaunchConfig1D(const phi::GPUContext& context,
   // Number of threads per block shall be larger than 64.
   threads = std::max(64, threads);
   int blocks = DivUp(DivUp(numel, vec_size), threads);
+  int limit_blocks = context.GetCUDAMaxGridDimSize()[0];
+  if (blocks > limit_blocks) blocks = limit_blocks;
 
   GpuLaunchConfig config;
   config.thread_per_block.x = threads;
diff --git a/paddle/phi/kernels/funcs/elementwise_grad_base.h b/paddle/phi/kernels/funcs/elementwise_grad_base.h
index 17bf873587381c..c81f030d0dfd19 100644
--- a/paddle/phi/kernels/funcs/elementwise_grad_base.h
+++ b/paddle/phi/kernels/funcs/elementwise_grad_base.h
@@ -49,6 +49,12 @@ namespace phi {
 namespace funcs {
 using DDim = phi::DDim;
 
+template <typename T>
+void LimitGridDim(const GPUContext &ctx, T *grid_dim) {
+  auto max_grid_dim = ctx.GetCUDAMaxGridDimSize()[0];
+  if (*grid_dim > max_grid_dim) *grid_dim = max_grid_dim;
+}
+
 template
 void CommonGradBroadcastCPU(const DenseTensor &x,
                             const DenseTensor &y,
@@ -977,6 +983,10 @@ static void ElemwiseGradBroadcast1CUDA(gpuStream_t stream,
     // suppose perfoemance improves with h increased.
     dim3 block_size = dim3(BLOCK_X, BLOCK_Y);
     int grid_size = (w + BLOCK_X - 1) / BLOCK_X;
+    auto gplace = phi::GPUPlace();
+    auto *ctx = static_cast<phi::GPUContext *>(
+        paddle::platform::DeviceContextPool::Instance().Get(gplace));
+    LimitGridDim(*ctx, &grid_size);
     FastElemwiseGradBroadcast1CUDAKernel<<<grid_size, block_size, 0, stream>>>(
         x, y, out, dout, h, w, is_xsize_larger, dx_op, dy_op, dx, dy);
   }
@@ -998,6 +1008,11 @@ static void ElemwiseGradBroadcast2CUDA(gpuStream_t stream,
                                        T *dy) {
   int block_size = std::min(ELEMWISE_MAX_BLOCK_DIM, pre * post);
   int gird_size = n;
+  int grid_size = n;
+  auto gplace = phi::GPUPlace();
+  auto *ctx = static_cast<phi::GPUContext *>(
+      paddle::platform::DeviceContextPool::Instance().Get(gplace));
+  LimitGridDim(*ctx, &grid_size);
   ElemwiseGradBroadcast2CUDAKernel<<<gird_size, block_size, 0, stream>>>(
       x, y, out, dout, pre, n, post, is_xsize_larger, dx_op, dy_op, dx, dy);
 }
@@ -1200,6 +1215,7 @@ void CommonGradBroadcastCUDA(const DenseTensor &x,
     } else {
       dim3 block_size = dim3(BLOCK_X, BLOCK_Y);
       int grid_size = (w + BLOCK_X - 1) / BLOCK_X;
+      LimitGridDim(ctx, &grid_size);
       FastCommonGradBroadcastCUDAKernelHeight<<>>(
           x_data,
@@ -1373,6 +1391,7 @@ void CommonGradBroadcastCUDA(const DenseTensor &x,
                                            std::multiplies());
       int block_size = std::min(ELEMWISE_MAX_BLOCK_DIM, mid);
       int grid_size = pre * post;
+      LimitGridDim(ctx, &grid_size);
       // we need to calc y offset with blockid, so do x_pre/y_pre to get left
       // size.
       if (k_pre != pre) k_pre = pre / k_pre;
@@ -1403,6 +1422,7 @@ void CommonGradBroadcastCUDA(const DenseTensor &x,
                                            std::multiplies());
       int block_size = std::min(ELEMWISE_MAX_BLOCK_DIM, mid);
       int grid_size = pre * post;
+      LimitGridDim(ctx, &grid_size);
       if (k_pre != pre) k_pre = pre / k_pre;
       FastCommonGradBroadcastOneCUDAKernel<<
+  if (grid > maxGridDimX) grid = maxGridDimX;
 
   GatherCUDAKernel<<>>(
       p_src, p_index, p_output, index_size, slice_size);
@@ -161,6 +163,8 @@ void GPUGatherNd(const phi::GPUContext& ctx,
   int block = 512;
   int64_t n = slice_size * remain_numel;
   int64_t grid = (n + block - 1) / block;
+  unsigned int maxGridDimX = ctx.GetCUDAMaxGridDimSize()[0];
+  if (grid > maxGridDimX) grid = maxGridDimX;
 
   GatherNdCUDAKernel<<>>(p_input,
                          g_input_dims,
diff --git a/paddle/phi/kernels/funcs/reduce_function.h b/paddle/phi/kernels/funcs/reduce_function.h
index 85c371e9f9d450..dacef1516e2fa5 100644
--- a/paddle/phi/kernels/funcs/reduce_function.h
+++ b/paddle/phi/kernels/funcs/reduce_function.h
@@ -309,7 +309,7 @@ struct ReduceConfig {
       : reduce_dims_origin(origin_reduce_dims), x_dim(origin_x_dim) {}
 
   // get the parameters of reduceKernel
-  void Run() {
+  void Run(const paddle::platform::Place& place) {
     // step1: update the reduce_dim left_dim and x_dim
     SetReduceDim();
@@ -321,6 +321,9 @@ struct ReduceConfig {
 
     // step4: set the block and grid for launch kernel
     SetBlockDim();
+
+    // step5: limit the grid to prevent thead overflow
+    LimitGridDim(place);
   }
 
   // when should_reduce_again is true, we need malloc temp space for temp data
@@ -609,6 +612,15 @@ struct ReduceConfig {
     grid = grid_dim;
   }
 
+  void LimitGridDim(const paddle::platform::Place& place) {
+    auto* ctx = static_cast(
+        paddle::platform::DeviceContextPool::Instance().Get(place));
+    std::array max_grid_dim = ctx->GetCUDAMaxGridDimSize();
+    grid.x = grid.x < max_grid_dim[0] ? grid.x : max_grid_dim[0];
+    grid.y = grid.y < max_grid_dim[1] ? grid.y : max_grid_dim[1];
+    grid.z = grid.z < max_grid_dim[2] ? grid.z : max_grid_dim[2];
+  }
+
  public:
   std::vector reduce_dims_origin;
   std::vector reduce_dim;
@@ -1044,7 +1056,7 @@ void ReduceKernel(const KPDevice& dev_ctx,
   auto x_dim = phi::vectorize(x.dims());
   auto config = ReduceConfig(origin_reduce_dims, x_dim);
-  config.Run();
+  config.Run(x.place());
   int numel = x.numel();  // after config.run()
   // SetOutputData for ReduceHigherDim when should_reduce_again is true,
diff --git a/paddle/phi/kernels/funcs/scatter.cu.h b/paddle/phi/kernels/funcs/scatter.cu.h
index f87e8c882c4320..4d33c28e77f6bd 100644
--- a/paddle/phi/kernels/funcs/scatter.cu.h
+++ b/paddle/phi/kernels/funcs/scatter.cu.h
@@ -156,6 +156,8 @@ void GPUScatterAssign(const phi::GPUContext& ctx,
   int block = 512;
   int64_t n = slice_size * index_size;
   int64_t grid = (n + block - 1) / block;
+  unsigned int maxGridDimX = ctx.GetCUDAMaxGridDimSize()[0];
+  grid = grid > maxGridDimX ? maxGridDimX : grid;
 
   // if not overwrite mode, init data
   if (!overwrite) {
@@ -240,6 +242,8 @@ void GPUScatterNdAdd(const phi::GPUContext& ctx,
   int block = 512;
   int64_t n = slice_size * remain_numel;
   int64_t grid = (n + block - 1) / block;
+  unsigned int maxGridDimX = ctx.GetCUDAMaxGridDimSize()[0];
+  grid = grid > maxGridDimX ? maxGridDimX : grid;
 
   ScatterNdCUDAKernel<<>>(
       p_update,
@@ -252,4 +256,4 @@ void GPUScatterNdAdd(const phi::GPUContext& ctx,
 }
 
 }  // namespace funcs
-}  // namespace pten
+}  // namespace phi

From 669da7719c7f53259cb5255077759c3415ffe661 Mon Sep 17 00:00:00 2001
From: FlyingQianMM <245467267@qq.com>
Date: Thu, 24 Mar 2022 06:40:24 +0000
Subject: [PATCH 2/2] add {} after if

---
 paddle/fluid/platform/device/gpu/gpu_launch_config.h | 4 +++-
 paddle/phi/backends/gpu/gpu_launch_config.h          | 4 +++-
 paddle/phi/kernels/funcs/elementwise_grad_base.h     | 4 +++-
 paddle/phi/kernels/funcs/gather.cu.h                 | 8 ++++++--
 4 files changed, 15 insertions(+), 5 deletions(-)

diff --git a/paddle/fluid/platform/device/gpu/gpu_launch_config.h b/paddle/fluid/platform/device/gpu/gpu_launch_config.h
index a87008d2ce0272..4e8b790fa63d1a 100644
--- a/paddle/fluid/platform/device/gpu/gpu_launch_config.h
+++ b/paddle/fluid/platform/device/gpu/gpu_launch_config.h
@@ -129,7 +129,9 @@ inline GpuLaunchConfig GetGpuLaunchConfig1D(
   threads = std::max(64, threads);
   int blocks = DivUp(DivUp(numel, vec_size), threads);
   int limit_blocks = context.GetCUDAMaxGridDimSize()[0];
-  if (blocks > limit_blocks) blocks = limit_blocks;
+  if (blocks > limit_blocks) {
+    blocks = limit_blocks;
+  }
 
   GpuLaunchConfig config;
   config.thread_per_block.x = threads;
diff --git a/paddle/phi/backends/gpu/gpu_launch_config.h b/paddle/phi/backends/gpu/gpu_launch_config.h
index f859bf0ce24ee8..41bc6bb47c160b 100644
--- a/paddle/phi/backends/gpu/gpu_launch_config.h
+++ b/paddle/phi/backends/gpu/gpu_launch_config.h
@@ -133,7 +133,9 @@ inline GpuLaunchConfig GetGpuLaunchConfig1D(const phi::GPUContext& context,
   threads = std::max(64, threads);
   int blocks = DivUp(DivUp(numel, vec_size), threads);
   int limit_blocks = context.GetCUDAMaxGridDimSize()[0];
-  if (blocks > limit_blocks) blocks = limit_blocks;
+  if (blocks > limit_blocks) {
+    blocks = limit_blocks;
+  }
 
   GpuLaunchConfig config;
   config.thread_per_block.x = threads;
diff --git a/paddle/phi/kernels/funcs/elementwise_grad_base.h b/paddle/phi/kernels/funcs/elementwise_grad_base.h
index c81f030d0dfd19..23b8388c745899 100644
--- a/paddle/phi/kernels/funcs/elementwise_grad_base.h
+++ b/paddle/phi/kernels/funcs/elementwise_grad_base.h
@@ -52,7 +52,9 @@ using DDim = phi::DDim;
 template <typename T>
 void LimitGridDim(const GPUContext &ctx, T *grid_dim) {
   auto max_grid_dim = ctx.GetCUDAMaxGridDimSize()[0];
-  if (*grid_dim > max_grid_dim) *grid_dim = max_grid_dim;
+  if (*grid_dim > max_grid_dim) {
+    *grid_dim = max_grid_dim;
+  }
 }
 
 template
diff --git a/paddle/phi/kernels/funcs/gather.cu.h b/paddle/phi/kernels/funcs/gather.cu.h
index a6934d89d15c25..147f716c126ec5 100644
--- a/paddle/phi/kernels/funcs/gather.cu.h
+++ b/paddle/phi/kernels/funcs/gather.cu.h
@@ -113,7 +113,9 @@ void GPUGather(const phi::GPUContext& ctx,
   int64_t n = slice_size * index_size;
   int64_t grid = (n + block - 1) / block;
   unsigned int maxGridDimX = ctx.GetCUDAMaxGridDimSize()[0];
-  if (grid > maxGridDimX) grid = maxGridDimX;
+  if (grid > maxGridDimX) {
+    grid = maxGridDimX;
+  }
 
   GatherCUDAKernel<<>>(
       p_src, p_index, p_output, index_size, slice_size);
@@ -164,7 +166,9 @@ void GPUGatherNd(const phi::GPUContext& ctx,
   int64_t n = slice_size * remain_numel;
   int64_t grid = (n + block - 1) / block;
   unsigned int maxGridDimX = ctx.GetCUDAMaxGridDimSize()[0];
-  if (grid > maxGridDimX) grid = maxGridDimX;
+  if (grid > maxGridDimX) {
+    grid = maxGridDimX;
+  }
 
   GatherNdCUDAKernel<<>>(p_input,
                          g_input_dims,