block_reduce.cuh
#pragma once

#include <thrust/tuple.h>

#include <ATen/native/SharedReduceOps.h>
#include <ATen/cuda/DeviceUtils.cuh>

namespace at {
namespace native {
namespace cuda_utils {

constexpr int kCUDABlockReduceNumThreads = 512;

// Algorithmic limitation: BlockReduce does two WarpReduce calls, each
// of which reduces C10_WARP_SIZE elements. So, at most
// C10_WARP_SIZE**2 elements can be reduced at a time.
// NOTE: This is >= the max block size on current hardware anyway (1024).
constexpr int kCUDABlockReduceMaxThreads = C10_WARP_SIZE * C10_WARP_SIZE;
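
// A small compile-time sanity check (an illustrative addition, not part of
// the original header): with a 32-lane warp the two-pass limit above works
// out to 32 * 32 = 1024 threads, so the default block size of 512 fits
// comfortably within it.
static_assert(
    kCUDABlockReduceNumThreads <= kCUDABlockReduceMaxThreads,
    "Default block size must not exceed the two-pass WarpReduce limit.");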

// Sums `val` across all threads in a warp.
//
// Assumptions:
//   - The size of each block should be a multiple of `C10_WARP_SIZE`
template <typename T>
__inline__ __device__ T WarpReduceSum(T val) {
#pragma unroll
  for (int offset = (C10_WARP_SIZE >> 1); offset > 0; offset >>= 1) {
    val += WARP_SHFL_DOWN(val, offset);
  }
  return val;
}

// Sums `val` across all threads in a block.
//
// Assumptions:
//   - Thread blocks are a 1D set of threads (indexed with `threadIdx.x` only)
//   - The size of each block should be a multiple of `C10_WARP_SIZE`
//   - `shared` should be a pointer to shared memory with size of, at least,
//     `sizeof(T) * number_of_warps`
template <typename T>
__inline__ __device__ T BlockReduceSum(T val, T* shared) {
  const int lid = threadIdx.x % C10_WARP_SIZE;
  const int wid = threadIdx.x / C10_WARP_SIZE;
  val = WarpReduceSum(val);
  __syncthreads();
  if (lid == 0) {
    shared[wid] = val;
  }
  __syncthreads();
  val = (threadIdx.x < blockDim.x / C10_WARP_SIZE) ? shared[lid] : T(0);
  if (wid == 0) {
    val = WarpReduceSum(val);
  }
  return val;
}
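
// Illustrative usage sketch (an addition for this write-up, not part of the
// original header): a kernel that sums up to one block's worth of elements.
// The kernel name and the assumption that it is launched with
// kCUDABlockReduceNumThreads threads (a multiple of C10_WARP_SIZE) are
// example choices only.
template <typename T>
__global__ void ExampleBlockSumKernel(const T* in, T* out, int n) {
  // One shared slot per warp, as BlockReduceSum requires.
  __shared__ T shared[kCUDABlockReduceNumThreads / C10_WARP_SIZE];
  // Threads past `n` contribute the additive identity.
  T val = (static_cast<int>(threadIdx.x) < n) ? in[threadIdx.x] : T(0);
  val = BlockReduceSum(val, shared);
  if (threadIdx.x == 0) {
    *out = val; // Only thread 0 holds the final block-wide sum.
  }
}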

// Reduces `val` across all threads in a warp using the caller-provided
// reduction `op` (any type exposing `combine` and `warp_shfl_down`).
// The same block-shape assumptions as for WarpReduceSum apply.
template <typename T, class ReduceOp>
__inline__ __device__ T WarpReduce(T val, const ReduceOp& op) {
#pragma unroll
  for (int offset = (C10_WARP_SIZE >> 1); offset > 0; offset >>= 1) {
    val = op.combine(val, op.warp_shfl_down(val, offset));
  }
  return val;
}

// Reduces `val` across all threads in a block using `op`, staging per-warp
// partial results in `shared`. The same assumptions as for BlockReduceSum
// apply; `identity_element` must be the identity of `op` so that inactive
// warp slots do not affect the final result.
template <typename T, class ReduceOp>
__inline__ __device__ T
BlockReduce(T val, const ReduceOp& op, const T& identity_element, T* shared) {
  const int lid = threadIdx.x % C10_WARP_SIZE;
  const int wid = threadIdx.x / C10_WARP_SIZE;
  val = WarpReduce(val, op);
  __syncthreads();
  if (lid == 0) {
    shared[wid] = val;
  }
  __syncthreads();
  val = (threadIdx.x < blockDim.x / C10_WARP_SIZE) ? shared[lid]
                                                   : identity_element;
  if (wid == 0) {
    val = WarpReduce(val, op);
  }
  return val;
}
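
// Illustrative ReduceOp sketch (an assumption for this example; ExampleMaxOp
// and ExampleBlockMaxKernel are not part of this header or of
// SharedReduceOps.h): any type exposing `combine` and `warp_shfl_down` can
// drive WarpReduce/BlockReduce. Here, a block-wide maximum.
template <typename T>
struct ExampleMaxOp {
  __device__ T combine(T a, T b) const {
    return a > b ? a : b;
  }
  __device__ T warp_shfl_down(T val, int offset) const {
    return WARP_SHFL_DOWN(val, offset);
  }
};

template <typename T>
__global__ void ExampleBlockMaxKernel(const T* in, T* out, int n, T identity) {
  __shared__ T shared[kCUDABlockReduceNumThreads / C10_WARP_SIZE];
  // `identity` should be the smallest representable value for T (e.g.
  // std::numeric_limits<T>::lowest()) so idle threads cannot win the max.
  T val = (static_cast<int>(threadIdx.x) < n) ? in[threadIdx.x] : identity;
  val = BlockReduce(val, ExampleMaxOp<T>{}, identity, shared);
  if (threadIdx.x == 0) {
    *out = val; // Thread 0 holds the block-wide maximum.
  }
}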
} // namespace cuda_utils
} // namespace native
} // namespace at