forked from pytorch/pytorch
-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathSegmentReduce.cu
407 lines (369 loc) · 13.8 KB
/
SegmentReduce.cu
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
#include <ATen/native/SegmentReduce.h>
#include <ATen/ATen.h>
#include <ATen/NumericUtils.h>
#include <ATen/cuda/CUDAContext.h>
#include <ATen/cuda/detail/KernelUtils.h>
#include <ATen/cuda/cub.cuh>
namespace at {
namespace native {
namespace {
// Binary "max" functor that propagates NaN.
// If `a` is NaN it is returned; otherwise if `b` is NaN, `b` is returned;
// otherwise the larger of the two operands is returned. Passed as the
// reduction operator to cub::DeviceSegmentedReduce::Reduce in this file.
struct CustomMax {
  template <typename OutputT>
  __host__ __device__ __forceinline__ OutputT
  operator()(const OutputT& a, const OutputT& b) const {
    // NaN checks come first, left operand before right, so that a NaN in
    // either position wins over any ordinary value.
    return at::_isnan(a)
        ? a
        : (at::_isnan(b) ? b : std::max<OutputT>(a, b));
  }
};
// Binary addition functor. Passed as the reduction operator to
// cub::DeviceSegmentedReduce::Reduce for the SUM and MEAN reductions.
struct CustomSum {
  template <typename OutputT>
  __host__ __device__ __forceinline__ OutputT
  operator()(const OutputT& a, const OutputT& b) const {
    // The sum is converted back to OutputT before being returned.
    OutputT total = a + b;
    return total;
  }
};
// Binary "min" functor that propagates NaN.
// If `a` is NaN it is returned; otherwise if `b` is NaN, `b` is returned;
// otherwise the smaller of the two operands is returned. Passed as the
// reduction operator to cub::DeviceSegmentedReduce::Reduce in this file.
struct CustomMin {
  template <typename OutputT>
  __host__ __device__ __forceinline__ OutputT
  operator()(const OutputT& a, const OutputT& b) const {
    // NaN checks come first, left operand before right, so that a NaN in
    // either position wins over any ordinary value.
    return at::_isnan(a)
        ? a
        : (at::_isnan(b) ? b : std::min<OutputT>(a, b));
  }
};
// Builds the segment offsets from per-segment lengths.
//
// Returns a tensor with `lengths.numel() + 1` entries and the same options
// as `lengths`: entry 0 is zero and entry i + 1 is the inclusive prefix sum
// lengths[0] + ... + lengths[i]. Segment i therefore covers the half-open
// range [offsets[i], offsets[i + 1]).
//
// Raises (via TORCH_CHECK) if the number of segments is not below INT_MAX.
// NOTE(review): presumably because the segment count is later handed to CUB
// entry points that take an `int` -- confirm against the CUB signatures.
Tensor _get_complete_sum(const Tensor& lengths) {
  int64_t segment_count = lengths.numel();
  TORCH_CHECK(
      segment_count < INT_MAX,
      "segment_reduce: number of segments (",
      segment_count,
      ") must be less than INT_MAX");
  auto offsets = at::empty({segment_count + 1}, lengths.options());
  // Only the leading element needs explicit zeroing; the scan below writes
  // every remaining element.
  offsets[0].zero_();

  // scalar_type() replaces the deprecated Tensor::type(). The dispatch label
  // names this helper because it is shared by the forward and backward paths.
  AT_DISPATCH_INDEX_TYPES(
      lengths.scalar_type(), "_get_complete_sum", ([&] {
        auto* lengths_data_ptr = lengths.data_ptr<index_t>();
        auto* offsets_data_ptr = offsets.data_ptr<index_t>();
        // Write the scan starting at offsets[1] so offsets[0] stays zero.
        at::cuda::cub::inclusive_sum(
            lengths_data_ptr,
            offsets_data_ptr + 1,
            segment_count);
      }));
  return offsets;
}
// Converts per-segment sums stored in `output_data` into means, in place.
//
//   output_data    - one value per segment; overwritten with the result
//   lengths_data   - number of elements in each segment (asserted >= 0)
//   segment_count  - number of segments to process
//   is_initial_set - whether the caller supplied an initial value
//   initial        - value written for empty segments when is_initial_set
//
// For an empty segment the result is `initial` if one was supplied and NaN
// otherwise. For a non-empty segment the sum is divided by the segment
// length unless the sum is already NaN, in which case it is left untouched.
template <typename scalar_t, typename index_t>
__global__ static void post_sum_div_kernel(
    scalar_t* output_data,
    const index_t* lengths_data,
    const int64_t segment_count,
    bool is_initial_set,
    scalar_t initial) {
  CUDA_KERNEL_LOOP(index, segment_count) {
    const index_t segment_length = lengths_data[index];
    CUDA_KERNEL_ASSERT(segment_length >= 0);

    if (segment_length != 0) {
      // Non-empty segment: turn the sum into a mean, leaving NaN as is.
      const scalar_t segment_sum = output_data[index];
      if (!at::_isnan(segment_sum)) {
        output_data[index] = segment_sum / segment_length;
      }
      continue;
    }

    // Empty segment: fall back to the initial value, or NaN without one.
    if (is_initial_set) {
      output_data[index] = initial;
    } else {
      output_data[index] = NAN;
    }
  }
}
// Forward segment reduction.
//
// Each thread owns one (segment, lane) pair, derived from its flat global
// index as row_id = idx / stride_count and lane_id = idx % stride_count.
// The thread folds values_data[j * stride_count + lane_id] for every j in
// [lengths_cumsum_data[row_id], lengths_cumsum_data[row_id + 1]) into
// `initial_value` and stores the result at
// output_data[row_id * stride_count + lane_id].
//
//   reduction           - MAX, MIN, SUM or MEAN
//   output_data         - destination, indexed by (row_id, lane_id)
//   values_data         - source values
//   lengths_data        - per-segment lengths (asserted >= 0)
//   lengths_cumsum_data - segment offsets; entry row_id + 1 is read, so it
//                         must hold segment_count + 1 entries
//   segment_count       - number of segments
//   stride_count        - number of lanes per row
//   is_initial_set      - whether the caller supplied an initial value; only
//                         consulted for empty segments under MEAN
//   initial_value       - starting accumulator for the reduction
//
// Launch layout: a 1-D grid of 1-D blocks with at least
// segment_count * stride_count threads; surplus threads exit immediately.
// MAX/MIN propagate NaN; MEAN of an empty segment yields NaN unless an
// initial value was supplied.
template <typename scalar_t, typename index_t>
__global__ void segment_reduce_forward_kernel(
    SegmentReductionType reduction,
    scalar_t* output_data,
    scalar_t* values_data,
    const index_t* lengths_data,
    const index_t* lengths_cumsum_data,
    const int64_t segment_count,
    const int64_t stride_count,
    bool is_initial_set,
    scalar_t initial_value) {
  // Widen blockIdx.x before multiplying: blockIdx.x * blockDim.x is otherwise
  // evaluated in 32-bit unsigned arithmetic and wraps for large problems.
  const int64_t idx =
      static_cast<int64_t>(blockIdx.x) * blockDim.x + threadIdx.x;
  // Guard first so that the division below is never reached when
  // stride_count is zero (the product is then zero and every thread exits).
  if (idx >= (segment_count * stride_count)) {
    return;
  }
  const int64_t row_id = idx / stride_count;
  const int64_t lane_id = idx % stride_count;

  const int64_t offset_start = lengths_cumsum_data[row_id];
  const int64_t offset_end = lengths_cumsum_data[row_id + 1];

  // ===== step2: apply reduction
  for (int64_t j = offset_start; j < offset_end; ++j) {
    const int64_t starting_index = (j * stride_count) + lane_id;
    const auto data = values_data[starting_index];
    // TODO: There is no need to branch with every element
    if (reduction == SegmentReductionType::MAX) {
      initial_value =
          at::_isnan(data) ? data : std::max<scalar_t>(initial_value, data);
    } else if (
        reduction == SegmentReductionType::MEAN ||
        reduction == SegmentReductionType::SUM) {
      initial_value = initial_value + data;
    } else if (reduction == SegmentReductionType::MIN) {
      initial_value =
          at::_isnan(data) ? data : std::min<scalar_t>(initial_value, data);
    }
  }

  // ===== step3: finalize reduction
  const index_t segment_length = lengths_data[row_id];
  CUDA_KERNEL_ASSERT(segment_length >= 0);
  if (reduction == SegmentReductionType::MEAN) {
    if (segment_length == 0 && !is_initial_set) {
      // Mean of nothing with no fallback value is undefined.
      initial_value = static_cast<scalar_t>(NAN);
    } else if (segment_length > 0 && !at::_isnan(initial_value)) {
      initial_value = initial_value / segment_length;
    }
  }

  const int64_t output_index = (row_id * stride_count) + lane_id;
  output_data[output_index] = initial_value;
}
// Backward pass of segment reduction.
//
// Each thread owns one (segment, lane) pair, derived from its flat global
// index as row_id = idx / stride_count and lane_id = idx % stride_count, and
// writes the gradient for every input element of that segment/lane into
// grad_input_data[j * stride_count + lane_id].
//
//   reduction           - MAX, MIN, SUM or MEAN
//   grad_input_data     - destination gradients w.r.t. the input values;
//                         entries this kernel does not select are not
//                         written, so the caller must pre-initialize them
//   grad_data           - upstream gradient, indexed by (row_id, lane_id)
//   output_data         - forward result, indexed by (row_id, lane_id)
//   values_data         - forward input values
//   lengths_data        - per-segment lengths; empty segments are skipped
//   lengths_cumsum_data - segment offsets; entry row_id + 1 is read, so it
//                         must hold segment_count + 1 entries
//   segment_count       - number of segments
//   stride_count        - number of lanes per row
//
// Gradient rules:
//   MAX / MIN - the upstream gradient is shared equally among all elements
//               that are NaN or equal to the forward output.
//   MEAN      - every element receives grad / segment length.
//   SUM       - every element receives grad.
//
// Launch layout: a 1-D grid of 1-D blocks with at least
// segment_count * stride_count threads; surplus threads exit immediately.
template <typename scalar_t, typename index_t>
__global__ void segment_reduce_backward_kernel(
    SegmentReductionType reduction,
    scalar_t* grad_input_data,
    scalar_t* grad_data,
    scalar_t* output_data,
    const scalar_t* values_data,
    const index_t* lengths_data,
    const index_t* lengths_cumsum_data,
    const int64_t segment_count,
    const int64_t stride_count) {
  // Widen blockIdx.x before multiplying: blockIdx.x * blockDim.x is otherwise
  // evaluated in 32-bit unsigned arithmetic and wraps for large problems.
  const int64_t idx =
      static_cast<int64_t>(blockIdx.x) * blockDim.x + threadIdx.x;
  // Guard first so that the division below is never reached when
  // stride_count is zero (the product is then zero and every thread exits).
  if (idx >= (segment_count * stride_count)) {
    return;
  }
  const int64_t row_id = idx / stride_count;
  const int64_t lane_id = idx % stride_count;
  if (lengths_data[row_id] == 0) {
    return;
  }

  const int64_t offset_start = lengths_cumsum_data[row_id];
  const int64_t offset_end = lengths_cumsum_data[row_id + 1];
  const int64_t output_index = (row_id * stride_count) + lane_id;

  if (reduction == SegmentReductionType::MAX ||
      reduction == SegmentReductionType::MIN) {
    // Pass 1: count the elements that produced the forward output.
    int64_t counter = 0;
    for (int64_t j = offset_start; j < offset_end; ++j) {
      const int64_t starting_index = (j * stride_count) + lane_id;
      if (at::_isnan(values_data[starting_index]) ||
          values_data[starting_index] == output_data[output_index]) {
        counter++;
      }
    }
    if (counter == 0) {
      return;
    }

    // Average gradient based on number of maximum elements in the
    // segment. The share is computed from the upstream gradient directly
    // rather than by re-reading grad_input and testing its sign, so that
    // negative gradients are split among tied elements as well.
    scalar_t grad_share = grad_data[output_index];
    if (counter > 1) {
      grad_share = grad_share / counter;
    }

    // Pass 2: hand the share to exactly the elements counted above.
    for (int64_t j = offset_start; j < offset_end; ++j) {
      const int64_t starting_index = (j * stride_count) + lane_id;
      if (at::_isnan(values_data[starting_index]) ||
          values_data[starting_index] == output_data[output_index]) {
        grad_input_data[starting_index] = grad_share;
      }
    }
  } else if (reduction == SegmentReductionType::MEAN) {
    auto grad_val = grad_data[output_index] / lengths_data[row_id];
    for (int64_t j = offset_start; j < offset_end; ++j) {
      const int64_t starting_index = (j * stride_count) + lane_id;
      grad_input_data[starting_index] = grad_val;
    }
  } else if (reduction == SegmentReductionType::SUM) {
    const auto& grad_val = grad_data[output_index];
    for (int64_t j = offset_start; j < offset_end; ++j) {
      const int64_t starting_index = (j * stride_count) + lane_id;
      grad_input_data[starting_index] = grad_val;
    }
  }
}
} // namespace
// CUDA backward entry point for segment_reduce.
//
//   grad_contig    - upstream gradient w.r.t. the forward output
//   output_contig  - forward output
//   data_contig    - forward input values
//   reduction      - reduction kind used in the forward pass
//   lengths_contig - per-segment lengths
//   axis           - dimension of `data_contig` that was segmented
//
// Returns a tensor with the sizes of `data_contig` and the options of
// `grad_contig`, zero-initialized and then filled by
// segment_reduce_backward_kernel on the current CUDA stream.
//
// NOTE(review): the raw data pointers are indexed linearly, so the tensors
// are presumably expected to be contiguous (as the `_contig` names
// suggest) -- confirm against the caller.
Tensor _segment_reduce_cuda_backward_kernel(
    const Tensor& grad_contig,
    const Tensor& output_contig,
    const Tensor& data_contig,
    SegmentReductionType reduction,
    const Tensor& lengths_contig,
    int64_t axis) {
  auto grad_input = at::zeros(data_contig.sizes(), grad_contig.options());
  // Nothing to differentiate for empty input; returning here also avoids a
  // division by zero below when data_contig.size(axis) is 0.
  if (data_contig.numel() == 0) {
    return grad_input;
  }

  int64_t segment_count = lengths_contig.numel();
  int64_t stride_count = data_contig.numel() / data_contig.size(axis);

  auto offsets = _get_complete_sum(lengths_contig);

  // One thread per (segment, lane) pair; always launch at least one block.
  constexpr int threads_per_block = 256;
  int64_t num_blocks =
      ((segment_count * stride_count) + threads_per_block - 1) /
      threads_per_block;
  num_blocks = std::max(num_blocks, (int64_t)1);

  // scalar_type() replaces the deprecated Tensor::type().
  AT_DISPATCH_INDEX_TYPES(
      lengths_contig.scalar_type(),
      "_segment_reduce_cuda_backward_kernel1",
      ([&] {
        const auto* lengths_data = lengths_contig.data_ptr<index_t>();
        auto* offsets_data = offsets.data_ptr<index_t>();

        // TODO: Switch to TensorIterator for better maintainability and
        // readability
        AT_DISPATCH_FLOATING_TYPES_AND2(
            kBFloat16,
            kHalf,
            data_contig.scalar_type(),
            "_segment_reduce_cuda_backward_kernel",
            ([&]() {
              auto* output_data = output_contig.data_ptr<scalar_t>();
              auto* grad_data = grad_contig.data_ptr<scalar_t>();
              auto* grad_input_data = grad_input.data_ptr<scalar_t>();
              const auto* values_data = data_contig.data_ptr<scalar_t>();

              segment_reduce_backward_kernel<scalar_t>
                  <<<num_blocks,
                     threads_per_block,
                     0,
                     at::cuda::getCurrentCUDAStream()>>>(
                      reduction,
                      grad_input_data,
                      grad_data,
                      output_data,
                      values_data,
                      lengths_data,
                      offsets_data,
                      segment_count,
                      stride_count);
              C10_CUDA_KERNEL_LAUNCH_CHECK();
            }));
      }));
  return grad_input;
}
// CUDA forward entry point for segment_reduce.
//
//   reduction - MAX, MIN, SUM or MEAN
//   data      - input values
//   lengths   - per-segment lengths
//   axis      - dimension of `data` that is segmented
//   initial   - optional starting value for the reduction
//
// Returns a tensor shaped like `data` except that dimension `axis` has
// `lengths.numel()` entries. When `initial` is absent the accumulator starts
// at -infinity for MAX, +infinity for MIN and 0 for SUM / MEAN.
//
// Multi-dimensional outputs go through segment_reduce_forward_kernel;
// one-dimensional outputs go through cub::DeviceSegmentedReduce, followed by
// post_sum_div_kernel for MEAN. All work is queued on the current CUDA
// stream.
//
// NOTE(review): the raw data pointers are indexed linearly, so `data` and
// `lengths` are presumably expected to be contiguous -- confirm against the
// caller.
Tensor _segment_reduce_cuda_kernel(
    SegmentReductionType reduction,
    const Tensor& data,
    const Tensor& lengths,
    int64_t axis,
    const c10::optional<Scalar>& initial) {
  int64_t segment_count = lengths.numel();
  auto output_shape = data.sizes().vec();
  output_shape[axis] = segment_count;
  auto output = at::empty(output_shape, data.options());

  int64_t stride_count = data.numel() / data.size(axis);

  auto offsets = _get_complete_sum(lengths);

  // One thread per (segment, lane) pair; always launch at least one block.
  constexpr int threads_per_block = 256;
  int64_t num_blocks =
      ((segment_count * stride_count) + threads_per_block - 1) /
      threads_per_block;
  num_blocks = std::max(num_blocks, (int64_t)1);

  // scalar_type() replaces the deprecated Tensor::type().
  AT_DISPATCH_INDEX_TYPES(
      lengths.scalar_type(), "_segment_reduce_cuda_kernel1", ([&] {
        auto* offsets_data_ptr = offsets.data_ptr<index_t>();
        auto* lengths_data_ptr = lengths.data_ptr<index_t>();
        AT_DISPATCH_FLOATING_TYPES_AND2(
            at::ScalarType::Half,
            at::ScalarType::BFloat16,
            data.scalar_type(),
            "segment_reduce_cuda",
            [&]() {
              auto* data_data_ptr = data.data_ptr<scalar_t>();
              auto* output_data_ptr = output.data_ptr<scalar_t>();

              // initialize starting value; defaults to the SUM / MEAN
              // identity so the variable is never read uninitialized.
              scalar_t initial_value = static_cast<scalar_t>(0);
              if (initial.has_value()) {
                initial_value = initial.value().to<scalar_t>();
              } else if (reduction == SegmentReductionType::MAX) {
                initial_value = -std::numeric_limits<scalar_t>::infinity();
              } else if (reduction == SegmentReductionType::MIN) {
                initial_value = std::numeric_limits<scalar_t>::infinity();
              }

              if (output_shape.size() > 1) {
                segment_reduce_forward_kernel<scalar_t>
                    <<<num_blocks,
                       threads_per_block,
                       0,
                       at::cuda::getCurrentCUDAStream()>>>(
                        reduction,
                        output_data_ptr,
                        data_data_ptr,
                        lengths_data_ptr,
                        offsets_data_ptr,
                        segment_count,
                        stride_count,
                        initial.has_value(),
                        initial_value);
                C10_CUDA_KERNEL_LAUNCH_CHECK();
              } else {
                // 1-D case: segment i spans
                // [offsets_data_ptr[i], offsets_data_ptr[i + 1]).
                if (reduction == SegmentReductionType::MAX) {
                  CustomMax max_op{};
                  CUB_WRAPPER(
                      cub::DeviceSegmentedReduce::Reduce,
                      data_data_ptr,
                      output_data_ptr,
                      segment_count,
                      offsets_data_ptr,
                      offsets_data_ptr + 1,
                      max_op,
                      initial_value,
                      at::cuda::getCurrentCUDAStream());
                } else if (reduction == SegmentReductionType::MIN) {
                  CustomMin min_op{};
                  CUB_WRAPPER(
                      cub::DeviceSegmentedReduce::Reduce,
                      data_data_ptr,
                      output_data_ptr,
                      segment_count,
                      offsets_data_ptr,
                      offsets_data_ptr + 1,
                      min_op,
                      initial_value,
                      at::cuda::getCurrentCUDAStream());
                } else if (
                    reduction == SegmentReductionType::MEAN ||
                    reduction == SegmentReductionType::SUM) {
                  // MEAN is a segmented SUM followed by a per-segment
                  // division.
                  CustomSum sum_op{};
                  CUB_WRAPPER(
                      cub::DeviceSegmentedReduce::Reduce,
                      data_data_ptr,
                      output_data_ptr,
                      segment_count,
                      offsets_data_ptr,
                      offsets_data_ptr + 1,
                      sum_op,
                      initial_value,
                      at::cuda::getCurrentCUDAStream());

                  if (reduction == SegmentReductionType::MEAN) {
                    post_sum_div_kernel<scalar_t>
                        <<<num_blocks,
                           threads_per_block,
                           0,
                           at::cuda::getCurrentCUDAStream()>>>(
                            output_data_ptr,
                            lengths_data_ptr,
                            segment_count,
                            initial.has_value(),
                            initial_value);
                    C10_CUDA_KERNEL_LAUNCH_CHECK();
                  }
                }
              }
            });
      }));

  return output;
}
// Register the CUDA implementations defined in this file with the
// segment-reduce forward and backward dispatch stubs.
REGISTER_DISPATCH(_segment_reduce_stub, &_segment_reduce_cuda_kernel);
REGISTER_DISPATCH(
_segment_reduce_backward_stub,
&_segment_reduce_cuda_backward_kernel);
} // namespace native
} // namespace at