forked from pytorch/pytorch
-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathScatterGatherKernel.cu
444 lines (370 loc) · 14.2 KB
/
ScatterGatherKernel.cu
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
#include <ATen/native/TensorAdvancedIndexing.h>
#include <ATen/ATen.h>
#include <ATen/Dispatch.h>
#include <ATen/MemoryOverlap.h>
#include <ATen/native/ScatterGatherChecks.h>
#include <ATen/native/ReduceOpsUtils.h>
#include <ATen/native/TensorIterator.h>
#include <ATen/native/cuda/Loops.cuh>
#include <ATen/cuda/detail/OffsetCalculator.cuh>
#include <ATen/cuda/Atomic.cuh>
#include <ATen/cuda/CUDAContext.h>

#include <cstring>
namespace at { namespace native {
// Element-wise combine operations used by the scatter/gather kernels in this
// file. Each one is invoked on the device as op(destination, source).
// They are written as functors since lambdas don't get optimized.

// In-place product of the destination with the source value, via gpuAtomicMul.
class ReduceMultiply {
public:
  template <typename scalar_t>
  constexpr C10_DEVICE void operator() (scalar_t * dst, const scalar_t * val) const {
    gpuAtomicMul(dst, *val);
  }
};
static ReduceMultiply reduce_multiply;

// In-place sum of the destination with the source value, via gpuAtomicAddNoReturn.
class ReduceAdd {
public:
  template <typename scalar_t>
  constexpr C10_DEVICE void operator() (scalar_t * dst, const scalar_t * val) const {
    gpuAtomicAddNoReturn(dst, *val);
  }
};
static ReduceAdd reduce_add;

// Plain (non-atomic) copy of the source value into the destination.
class TensorAssign {
public:
  template <typename scalar_t>
  constexpr C10_DEVICE void operator() (scalar_t * dst, const scalar_t * val) const {
    *dst = *val;
  }
};
static TensorAssign tensor_assign;
// Size-only stand-in for an element type. It has exactly `Bytes` bytes and
// `Bytes`-byte alignment, so kernels that only move data can be instantiated
// once per element width instead of once per dtype of that width.
template <int Bytes>
struct alignas(Bytes) OpaqueType {
  char data[Bytes];
};
// essentialy rewritten related to legacy::launch_kernel parts
template <int nt, int vt, typename func_t>
C10_LAUNCH_BOUNDS_2(nt, vt)
__global__ void _scatter_gather_elementwise_kernel(int N, func_t f) {
constexpr int nv = nt * vt;
int idx = nv * blockIdx.x + threadIdx.x;
#pragma unroll
for (int i = 0; i < vt; ++i) {
if (idx < N) {
f(idx);
idx += nt;
}
}
}
// Host-side launcher: runs f(i) for every i in [0, N) by launching
// _scatter_gather_elementwise_kernel on the current CUDA stream, using blocks
// of `nt` threads where each block covers nt * vt indices.
// N must lie in [0, INT32_MAX]; the callers in this file split larger
// iterators with with_32bit_indexing() first. N == 0 launches nothing.
template <int nt, int vt, typename func_t>
static void _launch_scatter_gather_kernel(int64_t N, const func_t& f) {
  TORCH_INTERNAL_ASSERT(N >= 0 && N <= std::numeric_limits<int32_t>::max());
  if (N == 0) {
    return;
  }

  constexpr int64_t work_per_block = static_cast<int64_t>(nt) * vt;
  const int64_t num_blocks = (N + work_per_block - 1) / work_per_block;  // ceil-div
  const dim3 block(nt);
  const dim3 grid(static_cast<unsigned int>(num_blocks));
  const auto stream = at::cuda::getCurrentCUDAStream();

  _scatter_gather_elementwise_kernel<nt, vt, func_t><<<grid, block, 0, stream>>>(
    static_cast<int>(N), f);
  // Surfaces launch-configuration errors immediately.
  C10_CUDA_KERNEL_LAUNCH_CHECK();
}
// Per-dtype launcher for tensor-source scatter/gather. `iter` operands are
// 0 = self (written), 1 = src (read), 2 = index (read as int64_t).
// For every element it reads idx_dim from index and calls f(self_ptr, src_ptr).
// The pointer to shift by idx_dim * index_stride depends on direction:
// self when is_scatter_like, otherwise src. idx_dim is device-asserted to be
// in [0, index_size).
template <bool is_scatter_like, typename scalar_t>
struct _cuda_scatter_gather_internal_kernel {
  template <typename func_t>
  void operator() (
    TensorIterator& iter,
    int64_t index_size,
    int64_t index_stride,
    const func_t& f
  ) {
    // The device loop uses 32-bit linear indices and offsets, so oversized
    // iterators are split and each piece is handled recursively.
    if (!iter.can_use_32bit_indexing()) {
      for (auto& piece : iter.with_32bit_indexing()) {
        _cuda_scatter_gather_internal_kernel<is_scatter_like, scalar_t>()(
          piece, index_size, index_stride, f
        );
      }
      return;
    }

    char* const self_base = static_cast<char*>(iter.data_ptr(0));
    char* const src_base = static_cast<char*>(iter.data_ptr(1));
    char* const index_base = static_cast<char*>(iter.data_ptr(2));
    auto offset_calc = make_offset_calculator<3>(iter);

    auto body = [=]C10_DEVICE(int linear_idx) {
      auto off = offset_calc.get(linear_idx);
      int64_t idx_dim = *reinterpret_cast<int64_t*>(index_base + off[2]);
      CUDA_KERNEL_ASSERT(idx_dim >= 0 && idx_dim < index_size && "index out of bounds");
      // Exactly one side is moved along `dim`: the destination for scatter,
      // the source for gather.
      const int64_t self_shift = is_scatter_like ? idx_dim * index_stride : 0;
      const int64_t src_shift = is_scatter_like ? 0 : idx_dim * index_stride;
      scalar_t* const dst = reinterpret_cast<scalar_t*>(self_base + off[0]) + self_shift;
      scalar_t* const val = reinterpret_cast<scalar_t*>(src_base + off[1]) + src_shift;
      f(dst, val);
    };

    _launch_scatter_gather_kernel<num_threads(), thread_work_size()>(iter.numel(), body);
  }
}; // struct _cuda_scatter_gather_internal_kernel
// Host-side driver shared by gather (is_scatter_like = false) and the
// tensor-source scatter variants (is_scatter_like = true).
// - `self` is the tensor written by the kernel (the iterator's output operand)
//   and `src` the tensor read; `f` is applied to one element of each.
// - `method_name` is accepted for interface compatibility but not used here.
// - With cast_to_opaque = true the dispatch runs on
//   OpaqueType<sizeof(scalar_t)>, so equal-width dtypes share one kernel.
//   The reduction callers in this file pass false.
template <bool is_scatter_like = true, bool cast_to_opaque = true>
struct cuda_scatter_gather_base_kernel {
  // Generic path: ALL_TYPES_AND_COMPLEX plus Half, Bool and BFloat16.
  template <typename func_t>
  void operator()(
    const Tensor& self, int64_t dim,
    const Tensor& index, const Tensor& src,
    const std::string& method_name,
    const func_t& f
  ) {
    int64_t index_size = 0;
    int64_t index_stride = 0;
    auto iter = _prepare_iter(self, dim, index, src, index_size, index_stride);

    AT_DISPATCH_ALL_TYPES_AND_COMPLEX_AND3(
      at::ScalarType::Half, at::ScalarType::Bool, at::ScalarType::BFloat16,
      iter.dtype(),
      "cuda_scatter_gather_base_kernel_func", [&] {
        using dtype = typename std::conditional<cast_to_opaque,
          OpaqueType<sizeof(scalar_t)>, scalar_t>::type;
        _cuda_scatter_gather_internal_kernel<is_scatter_like, dtype>()(
          iter, index_size, index_stride, f
        );
      }
    );
  }

  // ReduceMultiply path: dispatches only floating types plus Half and BFloat16
  // (presumably the dtypes gpuAtomicMul supports — NOTE(review): confirm).
  void operator()(
    const Tensor& self, int64_t dim,
    const Tensor& index, const Tensor& src,
    const std::string& method_name,
    const ReduceMultiply& f
  ) {
    int64_t index_size = 0;
    int64_t index_stride = 0;
    auto iter = _prepare_iter(self, dim, index, src, index_size, index_stride);

    AT_DISPATCH_FLOATING_TYPES_AND2(
      at::ScalarType::Half, at::ScalarType::BFloat16,
      iter.dtype(),
      "cuda_scatter_gather_base_kernel_reduce_multiply", [&] {
        using dtype = typename std::conditional<cast_to_opaque,
          OpaqueType<sizeof(scalar_t)>, scalar_t>::type;
        _cuda_scatter_gather_internal_kernel<is_scatter_like, dtype>()(
          iter, index_size, index_stride, f
        );
      }
    );
  }

 private:
  // Setup shared by both overloads, previously duplicated verbatim.
  // Steps:
  // - checks that `self` has no internal overlap;
  // - builds the (self, src, index) TensorIterator with self and src viewed
  //   in index's shape;
  // - writes the size and stride of `dim` on the tensor that index addresses
  //   (self when scattering, src when gathering) to the out-params.
  static TensorIterator _prepare_iter(
    const Tensor& self, int64_t dim,
    const Tensor& index, const Tensor& src,
    int64_t& index_size, int64_t& index_stride
  ) {
    at::assert_no_internal_overlap(self);

    auto index_sizes = ensure_nonempty_vec(index.sizes().vec());
    auto self_strides = ensure_nonempty_vec(self.strides().vec());
    auto src_strides = ensure_nonempty_vec(src.strides().vec());

    // restride self and src such that
    //   self.shape = src.shape = index.shape
    // and restride stride[dim] such that
    //   if (is_scatter_like) self.stride[dim] = 0
    //   else src.stride[dim] = 0
    auto self_restrided = is_scatter_like ?
        restride_dim(self, dim, index_sizes)
      : self.as_strided(index_sizes, self_strides);
    auto src_restrided = is_scatter_like ?
        src.as_strided(index_sizes, src_strides)
      : restride_dim(src, dim, index_sizes);

    auto iter = TensorIteratorConfig()
      .set_check_mem_overlap(false)
      .check_all_same_dtype(false)
      .resize_outputs(false)
      .add_output(self_restrided)
      .add_input(src_restrided)
      .add_input(index)
      .build();

    auto self_dim_stride = ensure_nonempty_stride(self, dim);
    auto self_dim_size = ensure_nonempty_size(self, dim);
    auto src_dim_stride = ensure_nonempty_stride(src, dim);
    auto src_dim_size = ensure_nonempty_size(src, dim);
    index_size = is_scatter_like ? self_dim_size : src_dim_size;
    index_stride = is_scatter_like ? self_dim_stride : src_dim_stride;

    return iter;
  }
}; // struct cuda_scatter_gather_base_kernel
// Per-dtype launcher for scatter with a single scalar value. `iter` operands
// are 0 = self (written) and 1 = index (read as int64_t).
// For every element it reads idx_dim from index and calls
// f(self_ptr + idx_dim * index_stride, &src_val).
// idx_dim is device-asserted to be in [0, index_size).
template <typename scalar_t>
struct _cuda_scatter_fill_internal_kernel {
  template <typename func_t>
  void operator()(
    TensorIterator& iter,
    scalar_t src_val,
    int64_t index_size,
    int64_t index_stride,
    const func_t& f
  ) {
    // Split into pieces addressable with 32-bit offsets, then recurse.
    if (!iter.can_use_32bit_indexing()) {
      for (auto& piece : iter.with_32bit_indexing()) {
        _cuda_scatter_fill_internal_kernel<scalar_t>()(
          piece, src_val, index_size, index_stride, f
        );
      }
      return;
    }

    char* const self_base = static_cast<char*>(iter.data_ptr(0));
    char* const index_base = static_cast<char*>(iter.data_ptr(1));
    auto offset_calc = make_offset_calculator<2>(iter);

    // src_val is captured by value, so every thread applies the same value.
    auto body = [=]C10_DEVICE(int linear_idx) {
      auto off = offset_calc.get(linear_idx);
      int64_t idx_dim = *reinterpret_cast<int64_t*>(index_base + off[1]);
      CUDA_KERNEL_ASSERT(idx_dim >= 0 && idx_dim < index_size && "index out of bounds");
      scalar_t* const dst = reinterpret_cast<scalar_t*>(self_base + off[0]) + idx_dim * index_stride;
      f(dst, const_cast<scalar_t*>(&src_val));
    };

    _launch_scatter_gather_kernel<num_threads(), thread_work_size()>(iter.numel(), body);
  }
}; // struct _cuda_scatter_fill_internal_kernel
// Host-side driver for scatter with a Scalar source: at every index position
// along `dim`, f(&self[...], &value) is applied.
// - `self` is viewed with index's shape and stride 0 on `dim` (restride_dim),
//   and the kernel adds idx * stride itself.
// - `method_name` is accepted for interface compatibility but not used here.
// - With cast_to_opaque = true the dispatch runs on
//   OpaqueType<sizeof(scalar_t)>, so equal-width dtypes share one kernel.
//   The reduction callers in this file pass false.
template <bool cast_to_opaque = true>
struct cuda_scatter_fill_base_kernel {
  // Generic path: ALL_TYPES_AND_COMPLEX plus Half, Bool and BFloat16.
  template <typename func_t>
  void operator()(
    const Tensor& self, int64_t dim,
    const Tensor& index, Scalar src,
    const std::string& method_name,
    const func_t& f
  ) {
    int64_t index_size = 0;
    int64_t index_stride = 0;
    auto iter = _prepare_iter(self, dim, index, index_size, index_stride);

    AT_DISPATCH_ALL_TYPES_AND_COMPLEX_AND3(
      at::ScalarType::Half, at::ScalarType::Bool, at::ScalarType::BFloat16,
      iter.dtype(),
      "cuda_scatter_fill_base_kernel_func", [&] {
        using dtype = typename std::conditional<cast_to_opaque,
          OpaqueType<sizeof(scalar_t)>, scalar_t>::type;
        auto src_val = _scalar_as_kernel_value<scalar_t, dtype>(src);
        _cuda_scatter_fill_internal_kernel<dtype>()(
          iter, src_val, index_size, index_stride, f
        );
      }
    );
  }

  // ReduceMultiply path: dispatches only floating types plus Half and BFloat16
  // (presumably the dtypes gpuAtomicMul supports — NOTE(review): confirm).
  void operator()(
    const Tensor& self, int64_t dim,
    const Tensor& index, Scalar src,
    const std::string& method_name,
    const ReduceMultiply& f
  ) {
    int64_t index_size = 0;
    int64_t index_stride = 0;
    auto iter = _prepare_iter(self, dim, index, index_size, index_stride);

    AT_DISPATCH_FLOATING_TYPES_AND2(
      at::ScalarType::Half, at::ScalarType::BFloat16,
      iter.dtype(),
      "cuda_scatter_fill_base_kernel_reduce_multiply", [&] {
        using dtype = typename std::conditional<cast_to_opaque,
          OpaqueType<sizeof(scalar_t)>, scalar_t>::type;
        auto src_val = _scalar_as_kernel_value<scalar_t, dtype>(src);
        _cuda_scatter_fill_internal_kernel<dtype>()(
          iter, src_val, index_size, index_stride, f
        );
      }
    );
  }

 private:
  // Setup shared by both overloads, previously duplicated verbatim.
  // Steps:
  // - checks that `self` has no internal overlap;
  // - builds the (self, index) TensorIterator with self restrided to index's
  //   shape;
  // - writes the size and stride of `dim` on self to the out-params.
  static TensorIterator _prepare_iter(
    const Tensor& self, int64_t dim,
    const Tensor& index,
    int64_t& index_size, int64_t& index_stride
  ) {
    at::assert_no_internal_overlap(self);

    auto index_sizes = ensure_nonempty_vec(index.sizes().vec());
    // restride self such that
    //   self.shape = index.shape and
    //   self.stride[dim] = 0
    auto self_restrided = restride_dim(self, dim, index_sizes);

    auto iter = TensorIteratorConfig()
      .set_check_mem_overlap(false)
      .check_all_same_dtype(false)
      .resize_outputs(false)
      .add_output(self_restrided)
      .add_input(index)
      .build();

    index_size = ensure_nonempty_size(self, dim);
    index_stride = ensure_nonempty_stride(self, dim);
    return iter;
  }

  // Converts `src` to scalar_t and returns its bytes reinterpreted as `dtype`
  // (scalar_t itself, or the same-size OpaqueType).
  // memcpy is used instead of the pointer pun `*(dtype*)&value`: reading a
  // scalar_t object through an OpaqueType lvalue violates strict aliasing.
  template <typename scalar_t, typename dtype>
  static dtype _scalar_as_kernel_value(const Scalar& src) {
    static_assert(sizeof(dtype) == sizeof(scalar_t),
                  "kernel value type must have the same size as scalar_t");
    scalar_t value = src.to<scalar_t>();
    dtype bits;
    std::memcpy(&bits, &value, sizeof(dtype));
    return bits;
  }
}; // struct cuda_scatter_fill_base_kernel
void gather_cuda_kernel(const Tensor& result, const Tensor& self, int64_t dim, const Tensor& index) {
  // gather reads `self` at index positions along `dim` and writes `result`.
  // The base kernel therefore receives `result` in its written (first) slot
  // and `self` in its read slot.
  cuda_scatter_gather_base_kernel</*is_scatter_like=*/false> gather_op;
  gather_op(result, dim, index, self, "gather_out_cuda", tensor_assign);
}
void scatter_cuda_kernel(const Tensor& self, int64_t dim, const Tensor& index, const Tensor& src) {
  // Plain scatter (assignment). Only bytes are moved, so the kernel is
  // dispatched on the size-only opaque element type.
  cuda_scatter_gather_base_kernel</*is_scatter_like=*/true, /*cast_to_opaque=*/true> scatter_op;
  scatter_op(self, dim, index, src, "scatter_cuda_", tensor_assign);
}
void scatter_fill_cuda_kernel(const Tensor& self, int64_t dim, const Tensor& index, const Scalar& src) {
  // Scatter of one scalar value (assignment). The kernel is dispatched on the
  // size-only opaque element type.
  cuda_scatter_fill_base_kernel</*cast_to_opaque=*/true> fill_op;
  fill_op(self, dim, index, src, "scatter_fill_cuda_", tensor_assign);
}
void scatter_add_cuda_kernel(const Tensor& self, int64_t dim, const Tensor& index, const Tensor& src) {
  // See Note [Writing Nondeterministic Operations]
  // Nondeterministic because of atomicAdd usage: when several index entries
  // target the same element, their additions land in an unspecified order.
  globalContext().alertNotDeterministic("scatter_add_cuda_kernel");
  // Addition needs the real element type, so opaque casting is disabled.
  cuda_scatter_gather_base_kernel</*is_scatter_like=*/true, /*cast_to_opaque=*/false> add_op;
  add_op(self, dim, index, src, "scatter_add_cuda_", reduce_add);
}
// Tensor-source scatter with a reduction (`reduce` selects add or multiply).
// The reductions need real arithmetic on the element type, so both branches
// use cast_to_opaque = false.
void scatter_reduce_cuda_kernel(const Tensor& self, const int64_t dim, const Tensor& index,
                               const Tensor& src, const SCATTER_GATHER_OP& reduce) {
  // See Note [Writing Nondeterministic Operations]
  // Nondeterministic because both reductions below accumulate through atomics
  // (gpuAtomicAddNoReturn / gpuAtomicMul). When several index entries target
  // the same element of self, the order of the floating-point updates is
  // unspecified. This mirrors the alert in scatter_add_cuda_kernel.
  globalContext().alertNotDeterministic("scatter_reduce_cuda_kernel");
  switch (reduce) {
  case SCATTER_GATHER_OP::REDUCE_ADD :
    cuda_scatter_gather_base_kernel<true, false>()(self, dim, index, src,
                                       "scatter_reduce_cuda_add_", reduce_add);
    break;
  case SCATTER_GATHER_OP::REDUCE_MULTIPLY :
    cuda_scatter_gather_base_kernel<true, false>()(self, dim, index, src,
                                       "scatter_reduce_cuda_multiply_", reduce_multiply);
    break;
  }
}
// Scalar-source scatter with a reduction (`reduce` selects add or multiply).
// Every update applies the same scalar value, so the order in which the
// atomics land does not change the result.
void scatter_scalar_reduce_cuda_kernel(const Tensor& self, const int64_t dim, const Tensor& index,
                               const Scalar& value, const SCATTER_GATHER_OP& reduce) {
  // Reductions need the real element type, so opaque casting is disabled.
  cuda_scatter_fill_base_kernel</*cast_to_opaque=*/false> fill_op;
  if (reduce == SCATTER_GATHER_OP::REDUCE_ADD) {
    fill_op(self, dim, index, value, "scatter_fill_cuda_add_", reduce_add);
  } else if (reduce == SCATTER_GATHER_OP::REDUCE_MULTIPLY) {
    fill_op(self, dim, index, value, "scatter_fill_cuda_multiply_", reduce_multiply);
  }
}
// Hook each CUDA implementation defined above into its dispatch stub. The
// stub objects are declared outside this file (presumably in the included
// ATen/native headers — NOTE(review): confirm).
REGISTER_DISPATCH(gather_stub, &gather_cuda_kernel);
REGISTER_DISPATCH(scatter_stub, &scatter_cuda_kernel);
REGISTER_DISPATCH(scatter_fill_stub, &scatter_fill_cuda_kernel);
REGISTER_DISPATCH(scatter_add_stub, &scatter_add_cuda_kernel);
REGISTER_DISPATCH(scatter_reduce_stub, &scatter_reduce_cuda_kernel);
REGISTER_DISPATCH(scatter_scalar_reduce_stub, &scatter_scalar_reduce_cuda_kernel);
}} // namespace at::native