// SortUtils.cuh
#pragma once
#include <c10/macros/Macros.h>
#include <c10/util/Optional.h>
#include <ATen/cuda/detail/TensorInfo.cuh>
#include <ATen/cuda/CUDAContext.h>
#include <ATen/native/cuda/SortingCommon.cuh>
#include <ATen/native/cuda/Sort.h>

namespace at { namespace native {

template <typename T>
__device__ inline void swapVars(T& t1, T& t2) {
  T tmp = t1;
  t1 = t2;
  t2 = tmp;
}

template <typename Comparator, typename K, typename V>
__device__ inline void bitonicSwap(K& kA, V& vA, bool& validA,
                                   K& kB, V& vB, bool& validB,
                                   bool dir,
                                   const Comparator& comp) {
  // Invalid entries always sort to the end
  bool swap = (comp(kA, kB) && validA) || !validB;
  if (swap == dir) {
    swapVars(kA, kB);
    swapVars(vA, vB);
    swapVars(validA, validB);
  }
}
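
// Note on bitonicSwap: with dir == false the pair is left in comparator
// order, i.e. afterwards comp(kA, kB) holds unless slot B is invalid;
// dir == true yields the reverse order. A greater-than comparator thus
// produces a descending sort, with invalid (out-of-range) entries pushed
// past every valid one.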

template <typename Comparator, typename K, typename V,
          typename IndexType, int Power2SortSize>
__device__ inline void bitonicSort(K keys[Power2SortSize],
                                   V values[Power2SortSize],
                                   bool valid[Power2SortSize],
                                   const Comparator& comp) {
#if !defined(USE_ROCM)
#pragma unroll
#endif
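  // Phase 1: build progressively longer bitonic runs. For each `size`,
  // neighboring size-wide blocks are processed in opposite directions
  // (`flag`), so every stretch of 2 * size elements becomes a bitonic
  // sequence.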
  for (unsigned int size = 2; size < Power2SortSize; size *= 2) {
    bool flag = ((threadIdx.x & (size / 2)) != 0);

#if !defined(USE_ROCM)
#pragma unroll
#endif
    for (unsigned int stride = size / 2; stride > 0; stride /= 2) {
      __syncthreads();
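
      // Thread t handles the compare-exchange pair (pos, pos + stride),
      // where pos = 2t - (t mod stride) packs pairs contiguously. E.g.
      // Power2SortSize = 8, stride = 2: threads 0..3 map to pos = 0, 1,
      // 4, 5, giving the pairs (0,2), (1,3), (4,6), (5,7).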
      unsigned int pos = 2 * threadIdx.x - (threadIdx.x & (stride - 1));
      bitonicSwap<Comparator, K, V>(
        keys[pos], values[pos], valid[pos],
        keys[pos + stride], values[pos + stride], valid[pos + stride],
        flag, comp);
    }
  }

#if !defined(USE_ROCM)
#pragma unroll
#endif
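  // Phase 2: one final merge of the now-bitonic array in a single
  // direction (dir == false), leaving all Power2SortSize elements in
  // comparator order.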
  for (unsigned int stride = Power2SortSize / 2; stride > 0; stride /= 2) {
    __syncthreads();

    unsigned int pos = 2 * threadIdx.x - (threadIdx.x & (stride - 1));
    bitonicSwap<Comparator, K, V>(
      keys[pos], values[pos], valid[pos],
      keys[pos + stride], values[pos + stride], valid[pos + stride],
      false, comp);
  }

  __syncthreads();
}

// at::cuda::detail::TensorInfo version
// Sorts (key, value) pairs (in different tensors) in-place; i.e.,
// modifies the input `keys` and `values`
template <typename K, typename V,
          int KeyDims, int ValueDims,
          typename Comparator, typename IndexType, int Power2SortSize>
C10_LAUNCH_BOUNDS_1(1024)
__global__ void
bitonicSortKVInPlace(at::cuda::detail::TensorInfo<K, IndexType> keys,
                     IndexType keySlices,
                     IndexType keySliceSize,
                     IndexType keySliceStride,
                     at::cuda::detail::TensorInfo<V, IndexType> values,
                     IndexType valueSliceStride,
                     Comparator comp) {
  // Find the slice of the tensor that we are sorting
  const IndexType linearIndex = getLinearBlockId<IndexType>();
  // Tiling the slices could have us be out of bounds, if there are a
  // lot of slices to sort
  if (linearIndex >= keySlices) {
    return;
  }

  __shared__ K sharedKeys[Power2SortSize];
  __shared__ V sharedValues[Power2SortSize];
  __shared__ bool sharedValid[Power2SortSize];
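  // Each thread owns two slots of the shared tile, so the kernel assumes
  // blockDim.x == Power2SortSize / 2; with the 1024-thread cap from
  // C10_LAUNCH_BOUNDS_1 above, Power2SortSize can be at most 2048.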

  const IndexType keyStartOffset =
    at::cuda::detail::IndexToOffset<K, IndexType, KeyDims>::get(linearIndex, keys);
  const IndexType valueStartOffset =
    at::cuda::detail::IndexToOffset<V, IndexType, ValueDims>::get(linearIndex, values);

  // If the sort size is 1, the data is already sorted
  if (Power2SortSize == 1) {
    return;
  } else {
    // Otherwise, each thread is responsible for loading and storing 2
    // elements. The sort size is guaranteed to be >= 2
    const int elem1 = threadIdx.x;
    const int elem2 = threadIdx.x + (Power2SortSize / 2);

    bool valid1 = (elem1 < keySliceSize);
    K k1 = valid1 ?
      keys.data[keyStartOffset + elem1 * keySliceStride] : static_cast<K>(0);
    V v1 = valid1 ?
      values.data[valueStartOffset + elem1 * valueSliceStride] : static_cast<V>(0);

    sharedKeys[elem1] = k1;
    sharedValues[elem1] = v1;
    sharedValid[elem1] = valid1;

    bool valid2 = (elem2 < keySliceSize);
    K k2 = valid2 ?
      keys.data[keyStartOffset + elem2 * keySliceStride] : static_cast<K>(0);
    V v2 = valid2 ?
      values.data[valueStartOffset + elem2 * valueSliceStride] : static_cast<V>(0);

    sharedKeys[elem2] = k2;
    sharedValues[elem2] = v2;
    sharedValid[elem2] = valid2;
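
    // The zero placeholders written for out-of-range slots never reach
    // global memory below: the write-back is guarded by the same valid
    // flags, and invalid entries sort past every valid one.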
    // Sort!
    bitonicSort<Comparator, K, V, IndexType, Power2SortSize>(
      sharedKeys, sharedValues, sharedValid, comp);

    // elem1 and elem2 values might be out-of-range, if the data size we are
    // sorting is smaller than half the power2 size
    if (valid1) {
      keys.data[keyStartOffset + elem1 * keySliceStride] =
        sharedKeys[elem1];
      values.data[valueStartOffset + elem1 * valueSliceStride] =
        sharedValues[elem1];
    }

    if (valid2) {
      keys.data[keyStartOffset + elem2 * keySliceStride] =
        sharedKeys[elem2];
      values.data[valueStartOffset + elem2 * valueSliceStride] =
        sharedValues[elem2];
    }
  }
}
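
// Illustrative launch sketch, not part of the original header. It assumes
// one block of Power2SortSize / 2 threads per slice; `keyInfo`, `valueInfo`,
// `dim`, `numSlices`, and `sliceSize` are hypothetical caller-side names,
// and GTOp is assumed to be the descending comparator from
// SortingCommon.cuh. See Sort.cu for the dispatch PyTorch actually does.
//
//   constexpr int sortSize = 1024;        // power of 2, >= sliceSize
//   dim3 block(sortSize / 2);             // two elements per thread
//   dim3 grid(numSlices);                 // possibly tiled; getLinearBlockId
//   bitonicSortKVInPlace<K, V, -1, -1, GTOp<K>, uint32_t, sortSize>
//       <<<grid, block, 0, at::cuda::getCurrentCUDAStream()>>>(
//           keyInfo, numSlices, sliceSize, keyInfo.strides[dim],
//           valueInfo, valueInfo.strides[dim], GTOp<K>());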
}} // at::native