forked from pytorch/pytorch
-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathSorting.cu
281 lines (247 loc) · 9.16 KB
/
Sorting.cu
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
#define TORCH_ASSERT_NO_OPERATORS
#include <ATen/native/cuda/Sorting.h>
#include <ATen/core/TensorBase.h>
#include <ATen/ceil_div.h>
#include <ATen/Dispatch.h>
#include <ATen/NumericUtils.h>
#include <c10/macros/Macros.h>
#include <ATen/cuda/detail/TensorInfo.cuh>
#include <ATen/native/cuda/SortingCommon.cuh>
#include <ATen/native/cuda/SortingRadixSelect.cuh>
#include <c10/cuda/CUDAStream.h>
#include <cassert>
#include <cstdlib>
namespace at {
namespace native {
namespace {
// Kernel that, for every slice of `input`, asks radixSelect for the element of
// rank `k` and then records that value together with a position inside the
// slice where it occurs.
//
//   input                  - tensor that is searched
//   inputSliceSize         - number of elements visited per slice
//   k                      - rank, forwarded unchanged to radixSelect
//   numInputSlices         - slice count; blocks whose id is >= this exit early
//   inputWithinSliceStride - element stride between consecutive slice entries
//   kthValue, indices      - outputs; element [0] of the matching slice is set
//
// The slice handled by a block is the value returned by getLinearBlockId.
// NOTE(review): the literal `false` handed to radixSelect is presumably its
// "largest" flag, which would make the result the k-th *smallest* element --
// confirm against SortingRadixSelect.cuh.
template <typename scalar_t, typename index_t, int Dim>
__global__ void gatherKthValue(
    cuda::detail::TensorInfo<scalar_t, index_t> input,
    index_t inputSliceSize,
    index_t k,
    index_t numInputSlices,
    index_t inputWithinSliceStride,
    cuda::detail::TensorInfo<scalar_t, index_t> kthValue,
    cuda::detail::TensorInfo<int64_t, index_t> indices) {
  // Scratch ints consumed by radixSelect. Original author's note: indices are
  // limited to integer fp precision, so counts fit in int32 whatever index_t
  // is; one entry per warp, up to the warp limit.
  __shared__ int radixScratch[C10_WARP_SIZE];

  const index_t sliceId = getLinearBlockId<index_t>();
  if (sliceId >= numInputSlices) {
    return;
  }

  // Resolve where this slice begins in the input and in both outputs.
  scalar_t* src = input.data +
      cuda::detail::IndexToOffset<scalar_t, index_t, Dim>::get(sliceId, input);
  scalar_t* valueOut = kthValue.data +
      cuda::detail::IndexToOffset<scalar_t, index_t, Dim>::get(
          sliceId, kthValue);
  int64_t* indexOut = indices.data +
      cuda::detail::IndexToOffset<int64_t, index_t, Dim>::get(sliceId, indices);

  // Select the rank-k element of the slice.
  scalar_t selected = static_cast<scalar_t>(0);
  radixSelect<
      scalar_t,
      typename TopKTypeConfig<scalar_t>::RadixType,
      index_t>(
      src,
      k,
      false,
      inputSliceSize,
      inputWithinSliceStride,
      radixScratch,
      &selected);

  // Look for a position holding the selected value. NaN never compares equal
  // to itself, so a NaN selection is matched through the explicit NaN test.
  const bool selectedIsNan = at::_isnan(selected);
  for (index_t pos = threadIdx.x; pos < inputSliceSize; pos += blockDim.x) {
    const scalar_t candidate = doLdg(src + pos * inputWithinSliceStride);
    if (candidate == selected || (selectedIsNan && at::_isnan(candidate))) {
      // A thread stops at the first match it sees and publishes the result;
      // threads that never see a match write nothing.
      valueOut[0] = selected;
      indexOut[0] = pos;
      return;
    }
  }
}
// CUDA kernel to find the median, and its index, of the values along dimension
// dim.
//
//   values, indices        - outputs; element [0] of the matching slice is set
//   input                  - tensor that is searched
//   inputSliceSize         - number of elements visited per slice
//   numInputSlices         - slice count; blocks whose id is >= this exit early
//   inputWithinSliceStride - element stride between consecutive slice entries
//   ignore_nan             - when false, any NaN in the slice makes the kernel
//                            select the last rank; when true the rank is the
//                            middle of the non-NaN elements
//
// The slice handled by a block is the value returned by getLinearBlockId.
template <typename scalar_t, typename index_t, int Dim>
__global__ void gatherMedian(
    cuda::detail::TensorInfo<scalar_t, index_t> values,
    cuda::detail::TensorInfo<int64_t, index_t> indices,
    cuda::detail::TensorInfo<scalar_t, index_t> input,
    index_t inputSliceSize,
    index_t numInputSlices,
    index_t inputWithinSliceStride,
    bool ignore_nan) {
  // Shared memory for the subroutine RadixSelect. Note that RadixSelect
  // converts the floating point type to int with the same relative ordering.
  __shared__ int smem[C10_WARP_SIZE]; // one per each warp, up to warp limit

  index_t slice = getLinearBlockId<index_t>();
  if (slice >= numInputSlices) {
    return;
  }

  // Finds the start offset for our slice
  index_t valuesSliceStartIndex =
      cuda::detail::IndexToOffset<scalar_t, index_t, Dim>::get(slice, values);
  index_t indicesSliceStartIndex =
      cuda::detail::IndexToOffset<int64_t, index_t, Dim>::get(slice, indices);
  index_t inputSliceStartIndex =
      cuda::detail::IndexToOffset<scalar_t, index_t, Dim>::get(slice, input);

  scalar_t* valuesSliceStart = &values.data[valuesSliceStartIndex];
  int64_t* indicesSliceStart = &indices.data[indicesSliceStartIndex];
  scalar_t* inputSliceStart = &input.data[inputSliceStartIndex];

  // Each thread counts the NaNs among the elements it visits.
  index_t nan_count = 0;
  for (index_t i = threadIdx.x; i < inputSliceSize; i += blockDim.x) {
    scalar_t val = doLdg(&inputSliceStart[i * inputWithinSliceStride]);
    nan_count += at::_isnan(val) ? 1 : 0;
  }

  // Counts number of nan values
  // This code performs a parallel sum reduction (not the most efficient code)
  __shared__ int64_t num_nan;
  if (threadIdx.x == 0) {
    num_nan = 0;
  }
  // Make the zero visible before any thread adds to it.
  __syncthreads();
  if (nan_count > 0) {
    gpuAtomicAddNoReturn(&num_nan, nan_count);
  }
  // Every contribution must have landed before num_nan is read below.
  __syncthreads();

  // For torch.median, if we found nan set k to last index so the computed value
  // is nan, otherwise set k to the middle element of the non-nan values.
  //
  // The non-NaN count is formed in signed 64-bit on purpose: with an unsigned
  // 64-bit index_t the expression `inputSliceSize - num_nan - 1` is evaluated
  // unsigned and wraps to a huge rank when no non-NaN element exists, while a
  // 32-bit index_t evaluates it in int64_t and truncates -1 / 2 to 0. Clamping
  // to rank 0 keeps both instantiations on the same, in-range value.
  const int64_t non_nan_count = static_cast<int64_t>(inputSliceSize) - num_nan;
  index_t k;
  if (!ignore_nan && num_nan > 0) {
    k = inputSliceSize - 1;
  } else if (non_nan_count > 0) {
    k = static_cast<index_t>((non_nan_count - 1) / 2);
  } else {
    k = 0;
  }

  // Find the median (radixSelect is handed the rank as k + 1)
  scalar_t median = static_cast<scalar_t>(0);
  radixSelect<
      scalar_t,
      typename TopKTypeConfig<scalar_t>::RadixType,
      index_t>(
      inputSliceStart,
      k + 1,
      false,
      inputSliceSize,
      inputWithinSliceStride,
      smem,
      &median);

  // Stored by every thread that reaches this point.
  valuesSliceStart[0] = median;

  // Find the index of the median value in the slice. NaN never compares equal
  // to itself, hence the explicit NaN test. Each thread stops at its first
  // match; when several elements match, more than one thread may store here
  // and this kernel does not order those stores.
  for (index_t i = threadIdx.x; i < inputSliceSize; i += blockDim.x) {
    scalar_t val = doLdg(&inputSliceStart[i * inputWithinSliceStride]);
    if (val == median || (at::_isnan(val) && at::_isnan(median))) {
      indicesSliceStart[0] = i;
      break;
    }
  }
}
// Host-side functor that carries the rank `k` and launches gatherKthValue on
// the current CUDA stream. An instance is handed to run_launcher by
// launch_kthvalue_kernel.
struct KthValueLauncher {
  // Rank forwarded to the kernel.
  int64_t k;

  KthValueLauncher(int64_t k) : k(k) {}

  // Launches gatherKthValue<scalar_t, index_t, all_dims>.
  //
  //   values_info, indices_info - output tensor descriptors
  //   self_info                 - input tensor descriptor
  //   collapse_self_dim         - dimension of self_info whose stride is used
  //                               as the within-slice stride
  //   collapse_values_dim,
  //   collapse_indices_dim      - unused here
  //   num_slices                - number of slices; determines the grid
  //   slice_size                - elements per slice; determines the block
  //
  // Raises (via AT_ERROR) when getGridFromTiles cannot build a grid for
  // num_slices.
  template <typename scalar_t, typename index_t, int all_dims>
  inline void launch(
      cuda::detail::TensorInfo<scalar_t, index_t> values_info,
      int collapse_values_dim,
      cuda::detail::TensorInfo<int64_t, index_t> indices_info,
      int collapse_indices_dim,
      cuda::detail::TensorInfo<scalar_t, index_t> self_info,
      int collapse_self_dim,
      int64_t num_slices,
      int64_t slice_size) {
    (void)collapse_values_dim; // Suppress unused variable warning
    (void)collapse_indices_dim; // Suppress unused variable warning

    dim3 grid;
    if (!getGridFromTiles(num_slices, grid)) {
      AT_ERROR("slices are too many");
    }

    // Block width: round_up(slice_size, C10_WARP_SIZE), capped at 1024.
    dim3 block(std::min(
        round_up(slice_size, (int64_t)C10_WARP_SIZE), (int64_t)1024));
    auto stream = at::cuda::getCurrentCUDAStream();
    gatherKthValue<scalar_t, index_t, all_dims><<<grid, block, 0, stream>>>(
        self_info,
        slice_size,
        k,
        num_slices,
        /* The actual dimension that the k-selection is running in */
        /* may have changed from collapseDims() */
        self_info.strides[collapse_self_dim],
        values_info,
        indices_info);
    C10_CUDA_KERNEL_LAUNCH_CHECK();
  }
};
// Host-side functor that remembers the NaN policy and launches gatherMedian on
// the current CUDA stream. An instance is handed to run_launcher by
// launch_median_kernel.
struct MedianLauncher {
  // Forwarded verbatim to the kernel's ignore_nan argument.
  bool ignore_nan;

  MedianLauncher(bool ignore_nan) : ignore_nan(ignore_nan) {}

  // Launches gatherMedian<scalar_t, index_t, all_dims>.
  //
  //   values_info, indices_info - output tensor descriptors
  //   self_info                 - input tensor descriptor
  //   collapse_self_dim         - dimension of self_info whose stride is used
  //                               as the within-slice stride
  //   collapse_values_dim,
  //   collapse_indices_dim      - unused here
  //   num_slices                - number of slices; determines the grid
  //   slice_size                - elements per slice; determines the block
  //
  // Raises (via AT_ERROR) when getGridFromTiles cannot build a grid for
  // num_slices.
  template <typename scalar_t, typename index_t, int all_dims>
  inline void launch(
      cuda::detail::TensorInfo<scalar_t, index_t> values_info,
      int collapse_values_dim,
      cuda::detail::TensorInfo<int64_t, index_t> indices_info,
      int collapse_indices_dim,
      cuda::detail::TensorInfo<scalar_t, index_t> self_info,
      int collapse_self_dim,
      int64_t num_slices,
      int64_t slice_size) {
    // Neither output collapse dimension is needed by this launcher.
    (void)collapse_values_dim;
    (void)collapse_indices_dim;

    dim3 launch_grid;
    const bool grid_ok = getGridFromTiles(num_slices, launch_grid);
    if (!grid_ok) {
      AT_ERROR("slices are too many");
    }

    // Block width: round_up(slice_size, C10_WARP_SIZE), capped at 1024.
    const int64_t padded_slice = round_up(slice_size, (int64_t)C10_WARP_SIZE);
    dim3 launch_block(std::min(padded_slice, (int64_t)1024));

    auto current_stream = at::cuda::getCurrentCUDAStream();
    gatherMedian<scalar_t, index_t, all_dims>
        <<<launch_grid, launch_block, 0, current_stream>>>(
            values_info,
            indices_info,
            self_info,
            slice_size,
            num_slices,
            self_info.strides[collapse_self_dim],
            ignore_nan);
    C10_CUDA_KERNEL_LAUNCH_CHECK();
  }
};
} // namespace (anonymous)
// Entry point for the kthvalue CUDA path: dispatches on the dtype of `self`
// (all standard types plus Half) and on the index width, then runs
// KthValueLauncher(k) through run_launcher.
//
//   values, indices - output tensors
//   self            - input tensor
//   dim             - dimension forwarded to run_launcher
//   k               - rank forwarded to KthValueLauncher
void launch_kthvalue_kernel(
    const TensorBase &values, const TensorBase &indices,
    const TensorBase &self, int64_t dim, int64_t k) {
  AT_DISPATCH_ALL_TYPES_AND(
      at::ScalarType::Half, self.scalar_type(), "kthvalue_cuda", [&] {
        // Int indexing is chosen only if all three tensors permit 32-bit
        // index math; otherwise Long is used.
        const bool all_fit_32bit = cuda::detail::canUse32BitIndexMath(self) &&
            cuda::detail::canUse32BitIndexMath(values) &&
            cuda::detail::canUse32BitIndexMath(indices);
        const ScalarType index_dtype =
            all_fit_32bit ? ScalarType::Int : ScalarType::Long;
        AT_DISPATCH_INDEX_TYPES(index_dtype, "kth_value_launcher", [&] {
          run_launcher<scalar_t, index_t>(
              values, indices, self, dim, KthValueLauncher(k));
        });
      });
}
// Entry point for the median CUDA path: dispatches on the dtype of `self`
// (all standard types plus Half) and runs MedianLauncher(ignore_nan) through
// run_launcher with an unsigned 32-bit or 64-bit index type.
//
//   vals, inds - output tensors
//   self       - input tensor
//   dim        - dimension forwarded to run_launcher
//   ignore_nan - NaN policy forwarded to MedianLauncher
void launch_median_kernel(
    const TensorBase &vals, const TensorBase &inds,
    const TensorBase &self, int64_t dim, bool ignore_nan) {
  AT_DISPATCH_ALL_TYPES_AND(
      at::ScalarType::Half, self.scalar_type(), "median_out_impl", [&] {
        // uint32_t indexing requires every tensor to permit 32-bit index math.
        const bool all_fit_32bit = cuda::detail::canUse32BitIndexMath(vals) &&
            cuda::detail::canUse32BitIndexMath(inds) &&
            cuda::detail::canUse32BitIndexMath(self);
        if (!all_fit_32bit) {
          run_launcher<scalar_t, uint64_t>(
              vals, inds, self, dim, MedianLauncher(ignore_nan));
          return;
        }
        run_launcher<scalar_t, uint32_t>(
            vals, inds, self, dim, MedianLauncher(ignore_nan));
      });
}
} // namespace native
} // namespace at