// TensorModeKernel.cuh (forked from pytorch/pytorch)
#pragma once
#include <ATen/cuda/detail/IndexUtils.cuh>
#include <ATen/native/cuda/Loops.cuh>
#include <ATen/native/cuda/SortingCommon.cuh>
#include <ATen/native/cuda/block_reduce.cuh>
namespace at {
namespace native {
// Used for a segmented reduction
struct ModeUnsignedBoolPair {
unsigned int val;
bool flag;
};
// In the kernel below, we have a common pattern of reducing (unsigned int,
// unsigned int) pairs of data
struct ModeUnsignedPair {
unsigned int val;
unsigned int index;
};
// Inclusive Scan via an upsweep/downsweep mechanism. Assumes:
//
// 1. Power2ScanSize is a power of 2. This code still works for collections that
// do not exactly contain a power of 2 number of elements, simply round up to
// the nearest power of 2 and then call.
//
// 2. That there are two elements per thread, i.e. the size of the smem storage
// is 2 * blockDim.x * sizeof(T).
//
// Consider a (+)-Scan on the following elements:
//
// Upsweep:
//
//    0  1  2  3  4  5  6  7
//       1     5     9    13
//             6          22
//                        28
//
// Downsweep:
//                  15
//          3    10    21
template <int Power2ScanSize, typename T, class BinaryOp>
__device__ void inclusivePrefixScan(T* smem, BinaryOp binop) {
// Reduce step ("upsweep")
#pragma unroll
for (int stride = 1; stride < Power2ScanSize; stride <<= 1) {
int index = (threadIdx.x + 1) * stride * 2 - 1;
if (index < Power2ScanSize) {
smem[index] = binop(smem[index], smem[index - stride]);
}
__syncthreads();
}
// Post-reduce step ("downsweep")
#pragma unroll
for (int stride = Power2ScanSize / 4; stride > 0; stride >>= 1) {
int index = (threadIdx.x + 1) * stride * 2 - 1;
if ((index + stride) < Power2ScanSize) {
smem[index + stride] = binop(smem[index + stride], smem[index]);
}
__syncthreads();
}
}
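// Illustrative usage sketch (not part of this header; the buffer name and
// sizes below are hypothetical). To scan, e.g., 100 unsigned ints, a caller
// rounds up to the next power of 2 (128), launches one thread per two slots,
// pads the tail with the identity (0 for a sum), and passes a plain binary op:
//
//   __shared__ unsigned int buf[128]; // 2 * blockDim.x slots, blockDim.x = 64
//   // ... each thread writes buf[2 * threadIdx.x] and buf[2 * threadIdx.x + 1],
//   // storing 0 into any slot >= 100 ...
//   __syncthreads();
//   inclusivePrefixScan<128>(
//       buf, [] GPU_LAMBDA(unsigned int a, unsigned int b) { return a + b; });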
// Block-wide reduction where each thread locally reduces N
// values before letting a single warp take over - assumes
// threadVals is in registers, not shared memory
//
// If smem is not used again, there is no need to __syncthreads before this
// call. However, if smem will be used, e.g., this function is called in a loop,
// then __syncthreads is needed either before or afterwards to prevent non-zero
// threads from overwriting smem in the next iteration before thread 0 has read
// from it.
template <int N, typename T, typename ReduceOp>
__device__ T reduceBlockWithNThreadLocalReductions(
T* smem,
T threadVals[N],
const unsigned int numVals,
ReduceOp reduceOp,
T init) {
int offset = threadIdx.x * N;
T local = offset < numVals ? threadVals[0] : init;
#pragma unroll
for (int i = 1; i < N; ++i) {
++offset;
T next = offset < numVals ? threadVals[i] : init;
local = reduceOp.combine(local, next);
}
return cuda_utils::BlockReduce(local, reduceOp, init, smem);
}
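// Illustrative usage sketch (not part of this header; SumOp below is
// hypothetical). With N = 2, thread t contributes elements 2 * t and 2 * t + 1
// of the logical input, and the ReduceOp must provide both combine() and
// warp_shfl_down():
//
//   struct SumOp {
//     __device__ unsigned int combine(unsigned int a, unsigned int b) const {
//       return a + b;
//     }
//     __device__ unsigned int warp_shfl_down(unsigned int v, int offset) const {
//       return WARP_SHFL_DOWN(v, offset);
//     }
//   };
//   unsigned int vals[2] = {/* two per-thread values */};
//   unsigned int total = reduceBlockWithNThreadLocalReductions<2>(
//       smem, vals, numVals, SumOp{}, 0u);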
template <typename T>
__device__ inline void swapVars(T& t1, T& t2) {
T tmp = t1;
t1 = t2;
t2 = tmp;
}
template <typename Comparator, typename K, typename V>
__device__ inline void bitonicSwap(
K& kA,
V& vA,
bool& validA,
K& kB,
V& vB,
bool& validB,
bool dir,
const Comparator& comp) {
// Invalid entries always sort to the end
bool swap = (comp(kA, kB) && validA) || !validB;
if (swap == dir) {
swapVars(kA, kB);
swapVars(vA, vB);
swapVars(validA, validB);
}
}
template <typename Comparator, typename K>
__device__ inline void bitonicSwapKeys(
K& kA,
bool& validA,
K& kB,
bool& validB,
bool dir,
const Comparator& comp) {
bool swap = (comp(kA, kB) && validA) || !validB;
if (swap == dir) {
swapVars(kA, kB);
swapVars(validA, validB);
}
}
template <
typename K,
typename IndexType,
int Power2SortSize,
typename Comparator>
__device__ inline void bitonicSortKeys(
K keys[Power2SortSize],
bool valid[Power2SortSize],
const Comparator& comp) {
#if !defined(USE_ROCM)
#pragma unroll
#endif
for (unsigned int size = 2; size < Power2SortSize; size *= 2) {
bool flag = ((threadIdx.x & (size / 2)) != 0);
#if !defined(USE_ROCM)
#pragma unroll
#endif
for (unsigned int stride = size / 2; stride > 0; stride /= 2) {
__syncthreads();
unsigned int pos = 2 * threadIdx.x - (threadIdx.x & (stride - 1));
bitonicSwapKeys<Comparator, K>(
keys[pos],
valid[pos],
keys[pos + stride],
valid[pos + stride],
flag,
comp);
}
}
#if !defined(USE_ROCM)
#pragma unroll
#endif
for (unsigned int stride = Power2SortSize / 2; stride > 0; stride /= 2) {
__syncthreads();
unsigned int pos = 2 * threadIdx.x - (threadIdx.x & (stride - 1));
bitonicSwapKeys<Comparator, K>(
keys[pos],
valid[pos],
keys[pos + stride],
valid[pos + stride],
false,
comp);
}
__syncthreads();
}
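// Note: each thread owns two positions of keys[]/valid[], so the block is
// expected to launch Power2SortSize / 2 threads, and padding slots are marked
// valid[i] = false so that they sort to the end. See the call in compute_mode
// below for a concrete instantiation.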
// The mode kernel has the following characteristics: it uses internal shared
// memory buffers of size Power2Size, which must be at least as large as the
// number of elements in the slice. Additionally, there is one block for every
// slice for which to calculate the mode, and in each block there is one thread
// for every two elements.
//
// The input is assumed to be a contiguous Tensor with the mode dimension as
// the innermost dim, such that we can get the slice for a given block via its
// linear block id * the slice size; the values and indices outputs are
// addressed through their TensorInfo descriptors instead.
template <typename T, unsigned int Power2Size>
__global__ void compute_mode(
T* input,
at::cuda::detail::TensorInfo<T, unsigned int> values,
at::cuda::detail::TensorInfo<int64_t, unsigned int> indices,
int64_t sliceSize,
int64_t slices) {
int tidx = threadIdx.x;
int stidx = blockDim.x + threadIdx.x; // Second index this thread is responsible for
// First, we need to calculate the offset into the input Tensor that
// represents the start of the slice for this block to calculate the mode for.
// This offset is simply the block's linear index multiplied by the number of
// elements in the slice.
unsigned int blockId = getLinearBlockId<unsigned int>();
unsigned int linearOffset = blockId * sliceSize;
if (blockId >= slices) {
return;
}
// shmem is a dynamically sized buffer we will use throughout the kernel to
// handle computation efficiently. The size of this shmem must be
// sizeof(T) * Power2Size + (2 * sizeof(unsigned int) * Power2Size)
//
// Initially, the buffer will be organized as follows:
//
// [smem (slice elements) | bmem (valid indices) | <scratch space>]
extern __shared__ char shmem[];
// smem represents a proportion of the shared memory buffer that is used to
// store the elements from the slice:
T* smem = reinterpret_cast<T*>(shmem);
// Each thread loads up to two elements from the Tensor into shared memory
if (tidx < sliceSize) {
smem[tidx] = input[linearOffset + tidx];
}
if (stidx < sliceSize) {
smem[stidx] = input[linearOffset + stidx];
}
// Next, we initialize a boolean region of the buffer, offset by the loaded
// element smem region
bool* bmem = reinterpret_cast<bool*>(&smem[Power2Size]);
// The first use of this region stores bmem[i] = i < sliceSize to mark the
// valid components in the smem buffer
bmem[tidx] = tidx < sliceSize;
bmem[stidx] = stidx < sliceSize;
__syncthreads(); // barrier for smem, bmem initialization
// First, sort the input slice in ascending order. smem contains the input
// elements, and bmem marks the valid indices
bitonicSortKeys<T, unsigned int, Power2Size>(
smem, bmem, [&] GPU_LAMBDA(const auto& a, const auto& b) {
return a < b;
});
__syncthreads(); // make no assumptions that the sort syncs at end
// The next step of our algorithm is performing a block-wide comparison of
// neighboring elements. In particular, given a sorted input slice A, we
// produce an output slice B, such that B[i] = 1 if A[i-1] != A[i] (and
// B[0] = 1), otherwise 0.
//
// Given the input A = [0, 0, 1, 1, 2, 2, 2, 4, 5, 6, 6, 7, 8]
//                 B = [1, 0, 1, 0, 1, 0, 0, 1, 1, 1, 0, 1, 1]
//
// We can think of B[i] = true as indicating the start of a sequence of equal
// values in the sorted list. Similarly, we will also store the negation of B,
// which we'll call C; that is, C[i] = true iff A[i-1] == A[i] in our original
// sorted slice.
//
//                 C = [0, 1, 0, 1, 0, 1, 1, 0, 0, 0, 1, 0, 0]
// We overwrite bmem, and treat the rest of shared memory as a buffer of
// (index, flag) pairs where the index represents values from C, and the flag
// represents values from B.
//
// [smem (sorted slice) | ubpmem (index, flag pairs)]
struct ModeUnsignedBoolPair* ubpmem =
reinterpret_cast<struct ModeUnsignedBoolPair*>(&smem[Power2Size]);
if (tidx == 0) {
ubpmem[0].flag = true;
ubpmem[0].val = 0;
}
// Compares elements (0, 1), (2, 3), ... and sets 1, 3, ...
ubpmem[tidx * 2 + 1].flag =
smem[tidx * 2] != smem[tidx * 2 + 1]; // (0, 1), (2, 3), etc.
ubpmem[tidx * 2 + 1].val = !ubpmem[tidx * 2 + 1].flag;
// Compares elements (1, 2), (3, 4), ... and sets 2, 4, ...
if (((tidx + 1) * 2) < Power2Size) {
ubpmem[(tidx + 1) * 2].flag =
smem[((tidx + 1) * 2) - 1] != smem[(tidx + 1) * 2];
ubpmem[(tidx + 1) * 2].val = !ubpmem[(tidx + 1) * 2].flag;
}
__syncthreads(); // barrier for ubpmem initialization
// Next, we perform a segmented prefix sum on the neighboring elements, where
// the presence of a one indicates the start of a segment. In this case B acts
// as the segment start flags, and C is the buffer to be summed:
//
// Input  (C) = [0, 1, 0, 1, 0, 1, 1, 0, 0, 0, 1, 0, 0]
// Flag   (B) = [1, 0, 1, 0, 1, 0, 0, 1, 1, 1, 0, 1, 1]
// Output (C) = [0, 1, 0, 1, 0, 1, 2, 0, 0, 0, 1, 0, 0]
//
// Afterwards, the val components of the ubpmem buffer contain the running
// counts within each segment; at the last position of each segment this is
// the segment length minus 1, i.e. the count of that element in the original
// input minus one.
inclusivePrefixScan<Power2Size>(
ubpmem, [=] GPU_LAMBDA(const auto& a, const auto& b) {
ModeUnsignedBoolPair c;
c.val = a.flag ? a.val : a.val + b.val;
c.flag = a.flag | b.flag;
return c;
});
// assumes scan syncs at the end
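// (In this binary op, `a` is the later element and `b` the earlier one,
// matching the operand order used inside inclusivePrefixScan: when `a` begins
// a new segment its running count is kept as-is, otherwise the two counts are
// summed, and the segment-start flags are OR-ed together.)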
// Next, we reinterpret the ubpmem buffer as pairs of unsigned integers (i.e.
// we treat the boolean flag regions as integers). Conceptually this gives us
// (count, index) pairs, which we'll call I; the index components themselves
// are initialized in per-thread registers just below.
struct ModeUnsignedPair* uupmem =
reinterpret_cast<struct ModeUnsignedPair*>(ubpmem);
// At this point, we need to find the maximum element in lengths buffer C.
// This element will represent the count (-1) of the mode. Because of the
// way we have set up the problem, the index where this mode occurs will
// also be the location of the mode value in the sorted array, e.g.
//
// smem = [0, 0, 1, 1, 1, 2]
//    C = [0, 1, 0, 1, 2, 0]
//    I = [0, 1, 2, 3, 4, 5]
//                      ^
//    maximum value, also aligned with mode = 1
//
// We perform a block wide max-reduction of the C buffer, but we also need the
// indices to come along with it, so we utilize the uupmem construction.
//
// At the end we need to return the ModeUnsignedPair containing index = 4, val
// = 2, which represents the max
// In practice, we will make each thread locally reduce 2 values in its
// registers prior to the global block-wide reduction. Note that instead of
// tidx/stidx, we utilize tidx * 2, tidx * 2 + 1, so each thread deals with
// adjacent elements. This is because the reduce code below relies on thread
// elements to be adjacent.
struct ModeUnsignedPair uup[2];
uup[0].index = tidx * 2;
uup[0].val = ubpmem[tidx * 2].val;
uup[1].index = tidx * 2 + 1;
uup[1].val = ubpmem[tidx * 2 + 1].val;
__syncthreads();
struct ModeUnsignedPair max = {0, 0};
struct MaxOp {
inline __device__ ModeUnsignedPair combine(ModeUnsignedPair a, ModeUnsignedPair b) const {
return b.val > a.val ? b : a;
}
inline __device__ ModeUnsignedPair warp_shfl_down(ModeUnsignedPair acc, int offset) const {
ModeUnsignedPair ret;
ret.index = WARP_SHFL_DOWN(acc.index, offset);
ret.val = WARP_SHFL_DOWN(acc.val, offset);
return ret;
}
} max_op;
max = reduceBlockWithNThreadLocalReductions<2>(
uupmem,
uup,
sliceSize,
max_op,
max);
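// Note: only thread 0 is guaranteed to hold the final block-wide result of
// this reduction, which is why the mode is extracted and broadcast through
// shared memory by thread 0 below.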
// Store the mode in shared memory for use in finding the mode in the input
// slice
__shared__ T mode;
// Given the above constraints, the mode is the value at the reduced index in
// the original sorted element buffer
if (tidx == 0) {
mode = smem[max.index];
}
__syncthreads(); // broadcast mode
// Finally, we need to find "an" index of the mode in the input
// Tensor. The API does not constrain which index we pick, but here
// we always pick the largest index. We store the index if the value
// is the mode, or 0 otherwise. Then find the maximum value.
//
// Again we reduce 2 elements in the thread's registers prior to the
// block-wide reduction
unsigned mode_index[2] = {0u, 0u};
if (tidx * 2 < sliceSize) {
const unsigned idx = tidx * 2;
mode_index[0] = input[linearOffset + idx] == mode ? idx : 0u;
}
if (tidx * 2 + 1 < sliceSize) {
const unsigned idx = tidx * 2 + 1;
mode_index[1] = input[linearOffset + idx] == mode ? idx : 0u;
}
struct MaxIndexOp {
inline __device__ unsigned combine(unsigned a, unsigned b) const {
return b > a ? b : a;
}
inline __device__ unsigned warp_shfl_down(unsigned acc, int offset) const {
return WARP_SHFL_DOWN(acc, offset);
}
} max_index_op;
int64_t index = reduceBlockWithNThreadLocalReductions<2>(
reinterpret_cast<unsigned*>(&shmem[0]),
mode_index,
sliceSize,
max_index_op,
0u);
// Finally, we have the mode, and an index where it occurs. We use a single
// thread to place this in the appropriate output position
if (tidx == 0) {
unsigned int outputOffset =
at::cuda::detail::IndexToOffset<T, unsigned int, -1>::get(
blockId, values);
values.data[outputOffset] = mode;
indices.data[outputOffset] = index;
}
}
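// Illustrative launch sketch (hypothetical; the actual dispatch lives in the
// accompanying .cu translation unit, and the names scalar_t, input_data,
// values_info, indices_info and stream are placeholders). One block per slice,
// one thread per two elements, and dynamic shared memory sized as described
// above:
//
//   const dim3 block(Power2Size / 2);
//   const dim3 grid(/* one block per slice, decoded via getLinearBlockId */);
//   const size_t smem_bytes =
//       Power2Size * sizeof(scalar_t) + 2 * Power2Size * sizeof(unsigned int);
//   compute_mode<scalar_t, Power2Size><<<grid, block, smem_bytes, stream>>>(
//       input_data, values_info, indices_info, sliceSize, slices);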
} // namespace native
} // namespace at