forked from pytorch/pytorch
-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathReflectionPad.cu
667 lines (554 loc) · 21.7 KB
/
ReflectionPad.cu
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
593
594
595
596
597
598
599
600
601
602
603
604
605
606
607
608
609
610
611
612
613
614
615
616
617
618
619
620
621
622
623
624
625
626
627
628
629
630
631
632
633
634
635
636
637
638
639
640
641
642
643
644
645
646
647
648
649
650
651
652
653
654
655
656
657
658
659
660
661
662
663
664
665
666
667
#include <ATen/ATen.h>
#include <ATen/ceil_div.h>
#include <ATen/cuda/Atomic.cuh>
#include <ATen/cuda/detail/IndexUtils.cuh>
#include <ATen/cuda/CUDAContext.h>
#include <ATen/NativeFunctions.h>
#include <ATen/TensorUtils.h>
#include <ATen/Utils.h>
#include <thrust/pair.h>
namespace at {
namespace native {
namespace {
using at::cuda::detail::canUse32BitIndexMath;
// Maps output column `output_x` of the current 1-D plane to the pair
// (flat input index, flat output index).
//
// The plane is selected by blockIdx.y + blockIdx.z * gridDim.y. Planes are
// laid out back to back with widths input_w (input) and output_w (output).
// The reflected source column is computed in closed form with abs():
//  - inside the copied region it is output_x - pad_l;
//  - to the left it mirrors about input column 0;
//  - to the right it mirrors about input column input_w - 1.
// In neither case is the edge column repeated. The two start terms together
// always contribute -pad_l, whatever the sign of pad_l.
__device__
inline thrust::pair<int64_t, int64_t> get_index_mapping1d(
    int64_t input_w, int64_t output_w,
    int64_t output_x,
    int64_t pad_l) {
  // 3D grid of 1D blocks: fold grid y/z into a single plane number.
  const int64_t plane = blockIdx.y + blockIdx.z * gridDim.y;

  const int64_t in_start = ::max(int64_t(0), -pad_l);
  const int64_t out_start = ::max(int64_t(0), pad_l);

  // Signed distances from output_x to the first and last copied columns.
  const int64_t dist_left = output_x - pad_l;
  const int64_t dist_right = output_x - (input_w + pad_l - 1);

  const int64_t input_x = ::abs(dist_left) - ::abs(dist_right) - output_x
      + 2 * pad_l + input_w - 1
      - out_start + in_start;

  return thrust::make_pair<int64_t, int64_t>(
      plane * input_w + input_x, plane * output_w + output_x);
}
// Maps a flat in-plane output position `output_xy` (row-major over an
// output_dim_y x output_dim_x plane) to the pair
// (flat input index, flat output index).
//
// The plane number is (blockIdx.y + y_shift) + (blockIdx.z + z_shift) * nplane.
// Grid y walks channels and grid z walks batch entries, each offset by the
// shift arguments. Each axis uses the same closed-form reflection:
//  - interior positions map to (output - pad);
//  - positions before the first copied element mirror about input index 0;
//  - positions past the last copied element mirror about input_size - 1.
__device__
inline thrust::pair<int64_t, int64_t> get_index_mapping2d(
    int64_t input_dim_x, int64_t input_dim_y,
    int64_t output_dim_x, int64_t output_dim_y,
    int64_t pad_l, int64_t pad_t,
    int64_t output_xy, int y_shift, int z_shift, int nplane) {
  // 3D grid of 1D blocks: flat plane number for this block.
  const int64_t plane = (blockIdx.y + y_shift) + (blockIdx.z + z_shift) * nplane;

  // Split the flat output position into column and row.
  const int64_t out_col = output_xy % output_dim_x;
  const int64_t out_row = output_xy / output_dim_x;

  // Column reflection.
  const int64_t col_in_start = ::max(int64_t(0), -pad_l);
  const int64_t col_out_start = ::max(int64_t(0), pad_l);
  const int64_t in_col = ::abs(out_col - pad_l)
      - ::abs(out_col - (input_dim_x + pad_l - 1))
      - out_col
      + 2 * pad_l + input_dim_x - 1
      - col_out_start + col_in_start;

  // Row reflection.
  const int64_t row_in_start = ::max(int64_t(0), -pad_t);
  const int64_t row_out_start = ::max(int64_t(0), pad_t);
  const int64_t in_row = ::abs(out_row - pad_t)
      - ::abs(out_row - (input_dim_y + pad_t - 1))
      - out_row
      + 2 * pad_t + input_dim_y - 1
      - row_out_start + row_in_start;

  const int64_t in_plane_base = plane * input_dim_x * input_dim_y;
  const int64_t out_plane_base = plane * output_dim_x * output_dim_y;
  return thrust::make_pair<int64_t, int64_t>(
      in_plane_base + in_row * input_dim_x + in_col,
      out_plane_base + out_row * output_dim_x + out_col);
}
// Forward 1-D reflection pad: one thread per output column.
//
// Launch layout: grid/block x cover the padded width. get_index_mapping1d
// derives the plane from blockIdx.y and blockIdx.z. `input` and `output` are
// addressed as back-to-back planes of width input_w and
// input_w + pad_l + pad_r respectively.
template<typename scalar_t>
__global__ void reflection_pad1d_out_kernel(
    scalar_t * input, scalar_t * output,
    int64_t input_w,
    int64_t pad_l, int64_t pad_r) {
  // Widen before multiplying. blockIdx.x * blockDim.x is otherwise evaluated
  // in 32-bit unsigned arithmetic and wraps once the padded width passes 2^32.
  // Only the (smaller) input is guarded by a 32-bit indexing check.
  const int64_t output_x =
      threadIdx.x + static_cast<int64_t>(blockIdx.x) * blockDim.x;
  const int64_t output_w = input_w + pad_l + pad_r;
  if (output_x < output_w) {
    auto index_pair = get_index_mapping1d(input_w, output_w, output_x, pad_l);
    output[index_pair.second] = input[index_pair.first];
  }
}
// Backward 1-D reflection pad: one thread per grad_output column.
//
// Each thread adds its grad_output element into the grad_input element that
// the forward pass copied from. Several output columns reflect onto the same
// input column, so the accumulation uses an atomic add. grad_input must
// therefore be zeroed beforehand. The plane is derived from blockIdx.y and
// blockIdx.z inside get_index_mapping1d.
template <typename scalar_t>
__global__ void reflection_pad1d_backward_out_kernel(
    scalar_t * grad_input, scalar_t * grad_output,
    int64_t input_w,
    int64_t pad_l, int64_t pad_r) {
  const unsigned int out_col = blockIdx.x * blockDim.x + threadIdx.x;
  const int64_t padded_w = input_w + pad_l + pad_r;
  if (out_col >= padded_w) {
    return;  // tail thread of the last x block
  }
  const auto src_dst = get_index_mapping1d(input_w, padded_w, out_col, pad_l);
  gpuAtomicAddNoReturn(&grad_input[src_dst.first], grad_output[src_dst.second]);
}
// Forward 2-D reflection pad: one thread per element of an output plane.
//
// Launch layout:
//  - grid/block x cover the flattened (row-major) padded plane of
//    (input_dim_y + pad_t + pad_b) x (input_dim_x + pad_l + pad_r) elements;
//  - grid y/z, offset by y_shift/z_shift, select channel and batch entry
//    (see get_index_mapping2d).
template<typename scalar_t>
__global__ void reflection_pad2d_out_kernel(
    scalar_t * input, scalar_t * output,
    int64_t input_dim_x, int64_t input_dim_y,
    int pad_t, int pad_b, int pad_l, int pad_r, int y_shift, int z_shift, int nplane) {
  // Widen before multiplying so the flat position does not wrap at 2^32 for
  // large padded planes. Padding may make the output much larger than the
  // 32-bit-checked input.
  const int64_t output_xy =
      threadIdx.x + static_cast<int64_t>(blockIdx.x) * blockDim.x;
  const int64_t output_dim_x = input_dim_x + pad_l + pad_r;
  const int64_t output_dim_y = input_dim_y + pad_t + pad_b;
  if (output_xy < output_dim_x * output_dim_y) {
    auto index_pair = get_index_mapping2d(
        input_dim_x, input_dim_y,
        output_dim_x, output_dim_y,
        pad_l, pad_t,
        output_xy, y_shift, z_shift, nplane);
    output[index_pair.second] = input[index_pair.first];
  }
}
// Backward 2-D reflection pad: one thread per element of a grad_output plane.
//
// Each thread atomically adds its grad_output element into the grad_input
// element it was reflected from. Border rows and columns reflect onto interior
// input elements that also receive their own gradient, so grad_input must be
// zeroed beforehand. Grid y/z (offset by y_shift/z_shift) select the plane
// inside get_index_mapping2d.
template <typename scalar_t>
__global__ void reflection_pad2d_backward_out_kernel(
    scalar_t * grad_input, scalar_t * grad_output,
    int64_t input_dim_x, int64_t input_dim_y,
    int pad_t, int pad_b, int pad_l, int pad_r, int y_shift, int z_shift, int nplane) {
  const unsigned int flat_out = blockIdx.x * blockDim.x + threadIdx.x;
  const int64_t out_w = input_dim_x + pad_l + pad_r;
  const int64_t out_h = input_dim_y + pad_t + pad_b;
  if (flat_out >= out_w * out_h) {
    return;  // tail thread of the last x block
  }
  const auto src_dst = get_index_mapping2d(
      input_dim_x, input_dim_y,
      out_w, out_h,
      pad_l, pad_t,
      flat_out, y_shift, z_shift, nplane);
  gpuAtomicAddNoReturn(&grad_input[src_dst.first], grad_output[src_dst.second]);
}
// Shared per-thread driver for the 3-D reflection-pad kernels.
//
// Each thread handles one element of an output (D, H, W) volume. The flat id
// threadIdx.x + blockIdx.x * blockDim.x is split into (z, y, x) using
// output.size(2), size(3) and size(4). The plane (dim 1) is
// blockIdx.y + y_shift and the batch entry (dim 0) is blockIdx.z + z_shift,
// which lets the host launch over planes/batches in chunks. For every element
// the reflected source coordinate in `input` is computed per axis, then
//   f(plane, batch, output_z, output_y, output_x, input_z, input_y, input_x)
// is invoked.
template <typename scalar_t, typename F>
__device__ inline void parallel_reflection_pad3d(
    PackedTensorAccessor64<scalar_t, 5> input,
    PackedTensorAccessor64<scalar_t, 5> output,
    int64_t pad_left,
    int64_t pad_top,
    int64_t pad_front,
    int64_t y_shift,
    int64_t z_shift,
    const F& f) {
  // Widen before multiplying. Otherwise blockIdx.x * blockDim.x is computed in
  // 32-bit unsigned arithmetic and wraps for volumes over 2^32 elements, even
  // though the result is stored in an int64_t.
  const int64_t output_id =
      threadIdx.x + static_cast<int64_t>(blockIdx.x) * blockDim.x;
  if (output_id >= (output.size(2) * output.size(3) * output.size(4))) {
    return;
  }
  int64_t output_x = output_id % output.size(4);
  int64_t output_y = (output_id / output.size(4)) % output.size(3);
  int64_t output_z = output_id / (output.size(3) * output.size(4));

  int64_t i_start_x = ::max(int64_t(0), -pad_left);
  int64_t o_start_x = ::max(int64_t(0), pad_left);
  int64_t i_start_y = ::max(int64_t(0), -pad_top);
  int64_t o_start_y = ::max(int64_t(0), pad_top);
  int64_t i_start_z = ::max(int64_t(0), -pad_front);
  int64_t o_start_z = ::max(int64_t(0), pad_front);

  // Closed-form reflection per axis: interior positions map to (out - pad),
  // positions outside the copied range mirror about the first/last input index.
  int64_t input_x = ::abs(output_x - pad_left)
      - ::abs(output_x - (input.size(4) + pad_left - 1))
      - output_x
      + 2 * pad_left + input.size(4) - 1
      - o_start_x + i_start_x;
  int64_t input_y = ::abs(output_y - pad_top)
      - ::abs(output_y - (input.size(3) + pad_top - 1))
      - output_y
      + 2 * pad_top + input.size(3) - 1
      - o_start_y + i_start_y;
  int64_t input_z = ::abs(output_z - pad_front)
      - ::abs(output_z - (input.size(2) + pad_front - 1))
      - output_z
      + 2 * pad_front + input.size(2) - 1
      - o_start_z + i_start_z;

  int64_t plane = blockIdx.y + y_shift;
  int64_t batch = blockIdx.z + z_shift;
  f(plane, batch, output_z, output_y, output_x, input_z, input_y, input_x);
}
// Forward 3-D reflection pad kernel.
//
// Thread/plane mapping is delegated to parallel_reflection_pad3d. The
// per-element action simply copies the reflected input element into its
// output slot. y_shift/z_shift are the plane/batch offsets of this launch.
template <typename scalar_t>
__global__ void reflection_pad3d_out_kernel(
    PackedTensorAccessor64<scalar_t, 5> input,
    PackedTensorAccessor64<scalar_t, 5> output,
    int64_t pad_left, int64_t pad_top, int64_t pad_front,
    int64_t y_shift, int64_t z_shift) {
  auto copy_reflected = [&] __device__(
      int64_t plane_idx,
      int64_t batch_idx,
      int64_t oz,
      int64_t oy,
      int64_t ox,
      int64_t iz,
      int64_t iy,
      int64_t ix) {
    output[batch_idx][plane_idx][oz][oy][ox] = input[batch_idx][plane_idx][iz][iy][ix];
  };
  parallel_reflection_pad3d(
      input, output,
      pad_left, pad_top, pad_front,
      y_shift, z_shift,
      copy_reflected);
}
// Backward 3-D reflection pad kernel.
//
// Thread/plane mapping is delegated to parallel_reflection_pad3d. Each
// grad_output element is atomically added into the grad_input element it was
// reflected from. Multiple output elements can hit the same input element, so
// grad_input must be zeroed before launch.
template <typename scalar_t>
__global__ void reflection_pad3d_backward_out_kernel(
    PackedTensorAccessor64<scalar_t, 5> grad_input,
    PackedTensorAccessor64<scalar_t, 5> grad_output,
    int64_t pad_left, int64_t pad_top, int64_t pad_front,
    int64_t y_shift, int64_t z_shift) {
  auto scatter_grad = [&] __device__(
      int64_t plane_idx,
      int64_t batch_idx,
      int64_t oz,
      int64_t oy,
      int64_t ox,
      int64_t iz,
      int64_t iy,
      int64_t ix) {
    gpuAtomicAddNoReturn(
        &grad_input[batch_idx][plane_idx][iz][iy][ix],
        grad_output[batch_idx][plane_idx][oz][oy][ox]);
  };
  parallel_reflection_pad3d(
      grad_input, grad_output,
      pad_left, pad_top, pad_front,
      y_shift, z_shift,
      scatter_grad);
}
// Host-side driver for the 2-D reflection-pad forward pass.
//
// Validates `input_` (3-D (C, H, W) or 4-D (N, C, H, W), per the
// plane_dim/dim_h/dim_w bookkeeping below) and `padding` = {left, right, top,
// bottom}. It then resizes `output` to the padded shape and launches
// reflection_pad2d_out_kernel over every plane. Launches are chunked to at
// most 65535 planes on grid y (channels) and grid z (batch).
void reflection_pad2d_out_template(
    Tensor &output, const Tensor &input_, IntArrayRef padding) {
  TORCH_CHECK(canUse32BitIndexMath(input_),
    "input tensor must fit into 32-bit index math");
  // padding[0..3] is read below and IntArrayRef indexing is unchecked.
  TORCH_CHECK(padding.size() == 4, "padding size is expected to be 4");

  int plane_dim = 0;
  int dim_h = 1;
  int dim_w = 2;
  int nbatch = 1;

  // Check the rank before calling size(2)/size(3). Otherwise a 1-D or 2-D
  // input raises an unrelated dimension-out-of-range error instead of the
  // message below. The leading dim (C for 3-D, N for 4-D) may be empty.
  const int64_t ndim = input_.ndimension();
  const bool valid_dims = (ndim == 3 || ndim == 4) &&
      input_.size(1) != 0 && input_.size(2) != 0 &&
      (ndim == 3 || input_.size(3) != 0);
  TORCH_CHECK(valid_dims,
    "3D or 4D (batch mode) tensor expected for input, but got: ", input_);

  if (ndim == 4) {
    nbatch = input_.size(0);
    plane_dim++;
    dim_h++;
    dim_w++;
  }

  int64_t pad_l = padding[0];
  int64_t pad_r = padding[1];
  int64_t pad_t = padding[2];
  int64_t pad_b = padding[3];

  int nplane = input_.size(plane_dim);
  int input_h = input_.size(dim_h);
  int input_w = input_.size(dim_w);

  TORCH_CHECK(pad_l < input_w && pad_r < input_w,
    "Padding size should be less than the corresponding input dimension, but "
    "got: padding (", pad_l, ", ", pad_r, ") at dimension ", dim_w,
    " of input ", input_.sizes());

  TORCH_CHECK(pad_t < input_h && pad_b < input_h,
    "Padding size should be less than the corresponding input dimension, but "
    "got: padding (", pad_t, ", ", pad_b, ") at dimension ", dim_h,
    " of input ", input_.sizes());

  // Keep the padded extents in 64 bits. Only the input is limited to 32-bit
  // indexing, and each side's padding may be nearly as large as the input
  // extent. So output_h * output_w can exceed INT_MAX, and computing it in
  // int would overflow before the result is widened.
  int64_t output_h = input_h + pad_t + pad_b;
  int64_t output_w = input_w + pad_l + pad_r;

  // NOTE(review): this only rejects the case where *both* extents are < 1. A
  // single non-positive extent falls through to resize_ below.
  TORCH_CHECK(output_w >= 1 || output_h >= 1,
    "input (H: ", input_h, ", W: ", input_w, ")is too small. Calculated "
    "output H: ", output_h, " W: ", output_w);

  if (ndim == 3) {
    output.resize_({nplane, output_h, output_w});
  } else {
    output.resize_({nbatch, nplane, output_h, output_w});
  }
  if (output.numel() == 0) {
    return;
  }

  Tensor input = input_.contiguous();

  // The kernel addresses the destination as dense row-major planes. A
  // caller-supplied out= tensor can still be non-contiguous here (presumably
  // when it already had the target shape, so resize_ left its strides alone).
  // In that case, fill a contiguous scratch tensor and copy it back.
  Tensor output_c = output.is_contiguous()
      ? output
      : at::empty_like(output, LEGACY_CONTIGUOUS_MEMORY_FORMAT);

  int64_t output_plane_size = output_h * output_w;
  dim3 block_size(output_plane_size > 256 ? 256 : output_plane_size);

  int64_t size_y = nplane;
  int64_t size_z = nbatch;

  AT_DISPATCH_FLOATING_AND_COMPLEX_TYPES_AND1(kHalf,
    input.scalar_type(), "reflection_pad2d_out_template", [&] {
      // Grid y/z are limited to 65535, so walk channels/batches in chunks and
      // pass the chunk start as the kernel's y_shift/z_shift.
      for (int64_t block_y = 0; block_y < size_y; block_y += 65535) {
        int64_t block_y_size = std::min(size_y - block_y, static_cast<int64_t>(65535));
        for (int64_t block_z = 0; block_z < size_z; block_z += 65535) {
          int64_t block_z_size = std::min(size_z - block_z, static_cast<int64_t>(65535));
          dim3 grid_size(at::ceil_div(output_plane_size, static_cast<int64_t>(256)), block_y_size, block_z_size);
          reflection_pad2d_out_kernel<<<
            grid_size, block_size, 0, at::cuda::getCurrentCUDAStream()>>>(
              input.data_ptr<scalar_t>(), output_c.data_ptr<scalar_t>(),
              input_w, input_h,
              pad_t, pad_b, pad_l, pad_r, block_y, block_z, nplane);
          C10_CUDA_KERNEL_LAUNCH_CHECK();
        }
      }
    }
  );

  if (!output_c.is_same(output)) {
    output.copy_(output_c);
  }
}
// Host-side driver for the 2-D reflection-pad backward pass.
//
// Accumulates `grad_output_` into `grad_input` (atomic adds in
// reflection_pad2d_backward_out_kernel). Callers are expected to have zeroed
// grad_input. `input` supplies the unpadded shape (3-D (C, H, W) or 4-D
// (N, C, H, W)) and `padding` = {left, right, top, bottom}. Launches are
// chunked to at most 65535 planes on grid y (channels) and grid z (batch).
void reflection_pad2d_backward_out_template(
    Tensor &grad_input, const Tensor &grad_output_,
    const Tensor &input, IntArrayRef padding) {

  if (grad_input.numel() == 0) {
    return;
  }

  TORCH_CHECK(canUse32BitIndexMath(input),
    "input tensor must fit into 32-bit index math");
  TORCH_CHECK(canUse32BitIndexMath(grad_output_),
    "output gradient tensor must fit into 32-bit index math");
  // padding[0..3] is read below and IntArrayRef indexing is unchecked.
  TORCH_CHECK(padding.size() == 4, "padding size is expected to be 4");
  TORCH_CHECK(grad_output_.ndimension() == input.ndimension(),
    "grad_output must have the same number of dimensions as input. Expected: ",
    input.ndimension(), ", Got: ", grad_output_.ndimension());

  int plane_dim = 0;
  int dim_h = 1;
  int dim_w = 2;
  int nbatch = 1;

  if (input.ndimension() == 4) {
    nbatch = input.size(0);
    plane_dim++;
    dim_h++;
    dim_w++;
  }

  int64_t pad_l = padding[0];
  int64_t pad_r = padding[1];
  int64_t pad_t = padding[2];
  int64_t pad_b = padding[3];

  int nplane = input.size(plane_dim);
  int input_h = input.size(dim_h);
  int input_w = input.size(dim_w);

  // Padding that is not smaller than the input extent makes the reflected
  // indices land outside grad_input. The kernel would then atomically add out
  // of bounds, so mirror the forward pass's validation.
  TORCH_CHECK(pad_l < input_w && pad_r < input_w,
    "Padding size should be less than the corresponding input dimension, but "
    "got: padding (", pad_l, ", ", pad_r, ") at dimension ", dim_w,
    " of input ", input.sizes());
  TORCH_CHECK(pad_t < input_h && pad_b < input_h,
    "Padding size should be less than the corresponding input dimension, but "
    "got: padding (", pad_t, ", ", pad_b, ") at dimension ", dim_h,
    " of input ", input.sizes());

  // 64-bit extents, matching the forward pass.
  int64_t output_h = input_h + pad_t + pad_b;
  int64_t output_w = input_w + pad_l + pad_r;

  TORCH_CHECK(output_w == grad_output_.size(dim_w), "grad_output width "
    "unexpected. Expected: ", output_w, ", Got: ", grad_output_.size(dim_w));
  TORCH_CHECK(output_h == grad_output_.size(dim_h), "grad_output height "
    "unexpected. Expected: ", output_h, ", Got: ", grad_output_.size(dim_h));
  // The kernel reads nbatch * nplane planes of grad_output; make sure they exist.
  TORCH_CHECK(nplane == grad_output_.size(plane_dim), "grad_output planes "
    "unexpected. Expected: ", nplane, ", Got: ", grad_output_.size(plane_dim));
  if (input.ndimension() == 4) {
    TORCH_CHECK(nbatch == grad_output_.size(0), "grad_output batch size "
      "unexpected. Expected: ", nbatch, ", Got: ", grad_output_.size(0));
  }

  Tensor grad_output = grad_output_.contiguous();

  // The kernel addresses grad_input as dense row-major planes. If the caller's
  // tensor is strided differently, accumulate into a zeroed contiguous buffer
  // and add it back, which preserves the accumulate-into semantics.
  Tensor grad_input_c = grad_input.is_contiguous()
      ? grad_input
      : at::zeros_like(grad_input, LEGACY_CONTIGUOUS_MEMORY_FORMAT);

  int64_t output_plane_size = output_h * output_w;
  dim3 block_size(output_plane_size > 256 ? 256 : output_plane_size);

  int64_t size_y = nplane;
  int64_t size_z = nbatch;

  AT_DISPATCH_FLOATING_AND_COMPLEX_TYPES_AND1(kHalf,
    input.scalar_type(), "reflection_pad2d_backward_out_template", [&] {
      // Grid y/z are limited to 65535, so walk channels/batches in chunks and
      // pass the chunk start as the kernel's y_shift/z_shift.
      for (int64_t block_y = 0; block_y < size_y; block_y += 65535) {
        int64_t block_y_size = std::min(size_y - block_y, static_cast<int64_t>(65535));
        for (int64_t block_z = 0; block_z < size_z; block_z += 65535) {
          int64_t block_z_size = std::min(size_z - block_z, static_cast<int64_t>(65535));
          dim3 grid_size(at::ceil_div(output_plane_size, static_cast<int64_t>(256)), block_y_size, block_z_size);
          reflection_pad2d_backward_out_kernel<<<
            grid_size, block_size, 0, at::cuda::getCurrentCUDAStream()>>>(
              grad_input_c.data_ptr<scalar_t>(), grad_output.data_ptr<scalar_t>(),
              input_w, input_h,
              pad_t, pad_b, pad_l, pad_r, block_y, block_z, nplane);
          C10_CUDA_KERNEL_LAUNCH_CHECK();
        }
      }
    }
  );

  if (!grad_input_c.is_same(grad_input)) {
    grad_input.add_(grad_input_c);
  }
}
} // namespace
// CUDA implementation of the 1-D reflection-pad forward pass.
//
// `input_` is 2-D (C, W) or 3-D (N, C, W). `output` is written as
// back-to-back planes of width W + padding[0] + padding[1]. Grid x covers the
// padded width. Planes are normally spread over grid y (channels) and grid z
// (batch). When either exceeds the 65535 limit on those grid dimensions (the
// same limit the 2-D/3-D launchers chunk around), the (batch, channel) planes
// are flattened and launched in chunks instead.
TORCH_IMPL_FUNC(reflection_pad1d_out_cuda)
(const Tensor& input_, IntArrayRef padding, const Tensor& output) {
  TORCH_CHECK(
      canUse32BitIndexMath(input_),
      "input tensor must fit into 32-bit index math");

  if (output.numel() == 0) {
    return;
  }

  int64_t dim_plane = 0;
  int64_t dim_w = 1;
  int64_t nbatch = 1;

  if (input_.ndimension() == 3) {
    nbatch = input_.size(0);
    dim_plane++;
    dim_w++;
  }

  int64_t pad_l = padding[0];
  int64_t pad_r = padding[1];

  int64_t nplane = input_.size(dim_plane);
  int64_t input_w = input_.size(dim_w);
  int64_t output_w = input_w + pad_l + pad_r;

  const int64_t max_grid_yz = 65535;
  dim3 block_size(output_w > 256 ? 256 : output_w);
  const int64_t grid_x = at::ceil_div(output_w, static_cast<int64_t>(256));

  Tensor input = input_.contiguous();

  AT_DISPATCH_FLOATING_AND_COMPLEX_TYPES_AND1(
      kHalf, input.scalar_type(), "reflection_pad1d_out_template", [&] {
        scalar_t* input_ptr = input.data_ptr<scalar_t>();
        scalar_t* output_ptr = output.data_ptr<scalar_t>();
        if (nplane <= max_grid_yz && nbatch <= max_grid_yz) {
          // Common case: a single launch, channels on grid y, batch on grid z.
          dim3 grid_size(grid_x, nplane, nbatch);
          reflection_pad1d_out_kernel<<<
              grid_size,
              block_size,
              0,
              at::cuda::getCurrentCUDAStream()>>>(
              input_ptr,
              output_ptr,
              input_w,
              pad_l,
              pad_r);
          C10_CUDA_KERNEL_LAUNCH_CHECK();
        } else {
          // Flatten (batch, channel) into one plane sequence and launch at
          // most 65535 planes per call on grid y. With gridDim.z == 1 the
          // kernel's plane index is just blockIdx.y, so offsetting the base
          // pointers by whole planes selects the chunk.
          const int64_t total_planes = nbatch * nplane;
          for (int64_t first = 0; first < total_planes; first += max_grid_yz) {
            const int64_t count = std::min(total_planes - first, max_grid_yz);
            dim3 grid_size(grid_x, count, 1);
            reflection_pad1d_out_kernel<<<
                grid_size,
                block_size,
                0,
                at::cuda::getCurrentCUDAStream()>>>(
                input_ptr + first * input_w,
                output_ptr + first * output_w,
                input_w,
                pad_l,
                pad_r);
            C10_CUDA_KERNEL_LAUNCH_CHECK();
          }
        }
      });
}
// CUDA implementation of the 1-D reflection-pad backward pass.
//
// Zeroes `grad_input`, then atomically accumulates every grad_output element
// into the input column it was reflected from. The launch layout matches the
// forward pass: grid x covers the padded width and planes sit on grid y/z.
// When the channel or batch count exceeds the 65535 grid y/z limit, the
// flattened planes are launched in chunks instead.
TORCH_IMPL_FUNC(reflection_pad1d_backward_out_cuda)(const Tensor& grad_output_,
    const Tensor& input,
    IntArrayRef padding,
    const Tensor& grad_input) {

  // See Note [Writing Nondeterministic Operations]
  // Nondeterministic because of atomicAdd usage
  globalContext().alertNotDeterministic("reflection_pad1d_backward_out_cuda");
  // The kernel only adds into grad_input, so it must start at zero.
  grad_input.zero_();

  if (grad_input.numel() == 0) {
    return;
  }

  TORCH_CHECK(canUse32BitIndexMath(input),
    "input tensor must fit into 32-bit index math");
  TORCH_CHECK(canUse32BitIndexMath(grad_output_),
    "output gradient tensor must fit into 32-bit index math");

  int64_t dim_plane = 0;
  int64_t dim_w = 1;
  int64_t nbatch = 1;

  if (input.ndimension() == 3) {
    nbatch = input.size(0);
    dim_plane++;
    dim_w++;
  }

  int64_t pad_l = padding[0];
  int64_t pad_r = padding[1];

  int64_t nplane = input.size(dim_plane);
  int64_t input_w = input.size(dim_w);
  int64_t output_w = input_w + pad_l + pad_r;

  Tensor grad_output = grad_output_.contiguous();

  const int64_t max_grid_yz = 65535;
  dim3 block_size(output_w > 256 ? 256 : output_w);
  const int64_t grid_x = at::ceil_div(output_w, static_cast<int64_t>(256));

  AT_DISPATCH_FLOATING_AND_COMPLEX_TYPES_AND1(kHalf,
    grad_input.scalar_type(), "reflection_pad1d_backward_out_cuda", [&] {
      scalar_t* grad_input_ptr = grad_input.data_ptr<scalar_t>();
      scalar_t* grad_output_ptr = grad_output.data_ptr<scalar_t>();
      if (nplane <= max_grid_yz && nbatch <= max_grid_yz) {
        // Common case: a single launch, channels on grid y, batch on grid z.
        dim3 grid_size(grid_x, nplane, nbatch);
        reflection_pad1d_backward_out_kernel<<<
          grid_size, block_size, 0, at::cuda::getCurrentCUDAStream()>>>(
            grad_input_ptr, grad_output_ptr,
            input_w, pad_l, pad_r);
        C10_CUDA_KERNEL_LAUNCH_CHECK();
      } else {
        // Flatten (batch, channel) and launch at most 65535 planes per call on
        // grid y. With gridDim.z == 1 the kernel's plane index is blockIdx.y,
        // so offsetting the base pointers by whole planes selects the chunk.
        const int64_t total_planes = nbatch * nplane;
        for (int64_t first = 0; first < total_planes; first += max_grid_yz) {
          const int64_t count = std::min(total_planes - first, max_grid_yz);
          dim3 grid_size(grid_x, count, 1);
          reflection_pad1d_backward_out_kernel<<<
            grid_size, block_size, 0, at::cuda::getCurrentCUDAStream()>>>(
              grad_input_ptr + first * input_w,
              grad_output_ptr + first * output_w,
              input_w, pad_l, pad_r);
          C10_CUDA_KERNEL_LAUNCH_CHECK();
        }
      }
    }
  );
}
// out= entry point for 2-D reflection padding on CUDA. The shared template
// resizes `output` to the padded shape and fills it. The same tensor is
// returned for chaining.
Tensor& reflection_pad2d_out_cuda(const Tensor& input, IntArrayRef padding,
                                  Tensor& output) {
  Tensor& result = output;
  reflection_pad2d_out_template(result, input, padding);
  return result;
}
// Functional entry point for 2-D reflection padding on CUDA. It allocates an
// empty tensor with the input's options and lets the shared template resize
// and fill it.
Tensor reflection_pad2d_cuda(const Tensor& input, IntArrayRef padding) {
  Tensor result = at::empty({0}, input.options());
  reflection_pad2d_out_template(result, input, padding);
  return result;
}
// out= entry point for the 2-D reflection-pad backward pass on CUDA.
// `grad_input` is resized to `input`'s shape and zeroed, because the kernel
// only accumulates into it. It is then filled by the shared template.
Tensor& reflection_pad2d_backward_out_cuda(const Tensor& grad_output,
                                           const Tensor& input,
                                           IntArrayRef padding,
                                           Tensor& grad_input) {
  // See Note [Writing Nondeterministic Operations]
  // Nondeterministic because of atomicAdd usage
  globalContext().alertNotDeterministic("reflection_pad2d_backward_out_cuda");
  grad_input.resize_as_(input);
  grad_input.zero_();
  reflection_pad2d_backward_out_template(grad_input, grad_output, input, padding);
  return grad_input;
}
// Functional entry point for the 2-D reflection-pad backward pass on CUDA.
// It starts from a contiguous zero tensor shaped like `input`, since the
// kernel accumulates into it, and returns the accumulated gradient.
Tensor reflection_pad2d_backward_cuda(
    const Tensor& grad_output,
    const Tensor& input,
    IntArrayRef padding) {
  // See Note [Writing Nondeterministic Operations]
  // Nondeterministic because of atomicAdd usage
  globalContext().alertNotDeterministic("reflection_pad2d_backward_cuda");
  Tensor grad_input = at::zeros_like(input, LEGACY_CONTIGUOUS_MEMORY_FORMAT);
  reflection_pad2d_backward_out_template(grad_input, grad_output, input, padding);
  return grad_input;
}
// CUDA implementation of the 3-D reflection-pad forward pass.
//
// Input and output are viewed as rank-5 (batch, plane, D, H, W) accessors.
// A 4-D input gets a leading size-1 batch dim. Launches are chunked to at
// most 65535 planes on grid y and 65535 batch entries on grid z. The chunk
// starts are passed to the kernel as y_shift/z_shift.
TORCH_IMPL_FUNC(reflection_pad3d_out_cuda) (
  const Tensor& input_, IntArrayRef padding, const Tensor& output
) {
  TORCH_CHECK(
      canUse32BitIndexMath(input_),
      "input tensor must fit into 32-bit index math");

  if (output.numel() == 0) {
    return;
  }

  // Only padding[0], [2] and [4] (left, top, front) feed the index mapping.
  const int64_t pad_left = padding[0];
  const int64_t pad_top = padding[2];
  const int64_t pad_front = padding[4];

  Tensor input = input_.contiguous();
  const bool batch_mode = input.dim() == 5;

  AT_DISPATCH_FLOATING_AND_COMPLEX_TYPES_AND1(kHalf,
    input.scalar_type(), "reflection_pad3d_out_cuda", [&] {
      // non-batch mode: add a leading batch dim so one rank-5 kernel serves both.
      Tensor input_5d = batch_mode ? input : input.unsqueeze(0);
      Tensor output_5d = batch_mode ? output : output.unsqueeze(0);

      auto in_acc = input_5d.packed_accessor64<scalar_t, 5>();
      auto out_acc = output_5d.packed_accessor64<scalar_t, 5>();

      const int64_t volume = out_acc.size(2) * out_acc.size(3) * out_acc.size(4);
      const int64_t num_planes = in_acc.size(1);
      const int64_t num_batches = in_acc.size(0);
      const dim3 threads(volume > 256 ? 256 : volume);

      for (int64_t plane_begin = 0; plane_begin < num_planes; plane_begin += 65535) {
        const int64_t plane_count =
            std::min(num_planes - plane_begin, static_cast<int64_t>(65535));
        for (int64_t batch_begin = 0; batch_begin < num_batches; batch_begin += 65535) {
          const int64_t batch_count =
              std::min(num_batches - batch_begin, static_cast<int64_t>(65535));
          const dim3 blocks(
              at::ceil_div(volume, static_cast<int64_t>(256)), plane_count, batch_count);
          reflection_pad3d_out_kernel<<<
            blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(
              in_acc, out_acc, pad_left, pad_top, pad_front,
              plane_begin, batch_begin);
          C10_CUDA_KERNEL_LAUNCH_CHECK();
        }
      }
    });
}
// CUDA implementation of the 3-D reflection-pad backward pass.
//
// Zeroes `grad_input`, then atomically accumulates every grad_output element
// into the input element it was reflected from. Tensors are viewed as rank-5
// (batch, plane, D, H, W) accessors; 4-D (non-batch) tensors get a leading
// size-1 dim. Launches are chunked to at most 65535 planes on grid y and
// 65535 batch entries on grid z.
TORCH_IMPL_FUNC(reflection_pad3d_backward_out_cuda) (
  const Tensor& grad_output, const Tensor& input, IntArrayRef padding,
  const Tensor& grad_input) {
  // See Note [Writing Nondeterministic Operations]
  // Nondeterministic because of atomicAdd usage
  globalContext().alertNotDeterministic("reflection_pad3d_backward_out_cuda");
  TORCH_CHECK(canUse32BitIndexMath(input), "input tensor must fit into 32-bit index math");
  TORCH_CHECK(canUse32BitIndexMath(grad_output), "output gradient tensor must fit into 32-bit index math");

  if (grad_input.numel() == 0) {
    return;
  }
  // The kernel only adds into grad_input, so it must start at zero.
  grad_input.zero_();

  // Only padding[0], [2] and [4] (left, top, front) feed the index mapping.
  int64_t pad_left = padding[0];
  int64_t pad_top = padding[2];
  int64_t pad_front = padding[4];

  AT_DISPATCH_FLOATING_AND_COMPLEX_TYPES_AND1(kHalf,
    input.scalar_type(), "reflection_pad3d_backward_out_cuda", [&] {
      auto grad_input_ = grad_input;
      auto grad_output_ = grad_output;
      if (input.dim() == 4) {
        // non-batch mode
        grad_input_ = grad_input.unsqueeze(0);
        grad_output_ = grad_output.unsqueeze(0);
      }

      auto grad_input_packed = grad_input_.packed_accessor64<scalar_t, 5>();
      auto grad_output_packed = grad_output_.packed_accessor64<scalar_t, 5>();

      int64_t output_plane_size = grad_output_packed.size(2) *
          grad_output_packed.size(3) * grad_output_packed.size(4);
      int64_t size_y = grad_input_packed.size(1);
      int64_t size_z = grad_input_packed.size(0);
      dim3 block_size(output_plane_size > 256 ? 256 : output_plane_size);

      // Grid y/z are limited to 65535; the chunk starts become y_shift/z_shift.
      for (int64_t block_y = 0; block_y < size_y; block_y += 65535) {
        int64_t block_y_size = std::min(size_y - block_y, static_cast<int64_t>(65535));
        for (int64_t block_z = 0; block_z < size_z; block_z += 65535) {
          int64_t block_z_size = std::min(size_z - block_z, static_cast<int64_t>(65535));

          dim3 grid_size(at::ceil_div(output_plane_size, static_cast<int64_t>(256)),
                         block_y_size, block_z_size);

          reflection_pad3d_backward_out_kernel<<<
            grid_size, block_size, 0, at::cuda::getCurrentCUDAStream()>>>(
              grad_input_packed, grad_output_packed, pad_left, pad_top, pad_front,
              block_y, block_z);
          C10_CUDA_KERNEL_LAUNCH_CHECK();
        }
      }
    });
}
} // namespace native
} // namespace at