forked from pytorch/pytorch
-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathUpSampleNearest2d.cu
477 lines (403 loc) · 16.8 KB
/
UpSampleNearest2d.cu
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
#include <ATen/ATen.h>
#include <ATen/AccumulateType.h>
#include <ATen/ceil_div.h>
#include <ATen/NativeFunctions.h>
#include <ATen/TensorUtils.h>
#include <ATen/Utils.h>
#include <ATen/cuda/CUDAContext.h>
#include <ATen/native/cuda/LaunchUtils.h>
#include <ATen/native/cuda/UpSample.cuh>
#include <ATen/native/cuda/KernelUtils.cuh>
#include <ATen/cuda/detail/KernelUtils.h>
namespace at {
namespace native {
namespace {
// Upper bound on threads per block used by the contiguous (non channels-last)
// forward and backward launch configurations in this file.
#define MAX_THREADS 512
// Function-pointer type used as a template argument to select the forward
// source-index helper: nearest_neighbor_compute_source_index or
// nearest_neighbor_exact_compute_source_index.
// Call sites in this file pass (scale, destination index, source size).
typedef int (*nn_compute_source_index_fn_t)(const float, int, int);
// Function-pointer type used as a template argument to select the backward
// source-index helper: nearest_neighbor_bw_compute_source_index or
// nearest_neighbor_exact_bw_compute_source_index.
// Call sites in this file pass (scale, grad_input index, grad_output size).
typedef int (*nn_bw_compute_source_index_fn_t)(const float, int, int);
// see NOTE [ Nearest neighbor upsampling kernel implementation ]
//
// Forward nearest-neighbor upsampling over buffers indexed as
// [nc, height, width] (flat offset = (plane * height + row) * width + col).
//
// Launch layout: the x/y thread coordinates pick one output column/row; the z
// coordinate picks the first plane a thread handles, and the thread then
// advances by blockDim.z * gridDim.z planes until it runs past nc.
//
//   idata, odata      source / destination buffers
//   nc                number of planes (the host launcher passes batch * channels)
//   height1, width1   source spatial size
//   height2, width2   destination spatial size
//   height_scale,
//   width_scale       per-axis scale handed to nn_compute_source_index_fn
template <typename scalar_t, nn_compute_source_index_fn_t nn_compute_source_index_fn>
C10_LAUNCH_BOUNDS_1(1024)
__global__ void upsample_nearest2d_out_frame(
    const scalar_t* idata,
    scalar_t* odata,
    const size_t nc,
    const size_t height1,
    const size_t width1,
    const size_t height2,
    const size_t width2,
    float height_scale,
    float width_scale) {
  const size_t first_plane = threadIdx.z + blockIdx.z * blockDim.z;
  const int out_x = threadIdx.x + blockIdx.x * blockDim.x;
  const int out_y = threadIdx.y + blockIdx.y * blockDim.y;

  // Threads that fall outside the output image have nothing to copy.
  const bool inside_image = out_x < width2 && out_y < height2;
  if (!inside_image) {
    return;
  }

  const int plane_step = blockDim.z * gridDim.z;

  // When an axis keeps its size the coordinate maps onto itself; otherwise the
  // selected helper picks the source coordinate.
  size_t in_y;
  if (height1 == height2) {
    in_y = out_y;
  } else {
    in_y = nn_compute_source_index_fn(height_scale, out_y, height1);
  }

  size_t in_x;
  if (width1 == width2) {
    in_x = out_x;
  } else {
    in_x = nn_compute_source_index_fn(width_scale, out_x, width1);
  }

  // Offsets of this thread's pixel in its first plane, plus the distance (in
  // elements) between two consecutive planes handled by this thread.
  size_t in_offset = (first_plane * height1 + in_y) * width1 + in_x;
  const size_t in_offset_step = plane_step * width1 * height1;
  size_t out_offset = (first_plane * height2 + out_y) * width2 + out_x;
  const size_t out_offset_step = plane_step * width2 * height2;

  // Walk over every plane owned by this thread, copying the same pixel.
  for (size_t plane = first_plane; plane < nc; plane += plane_step) {
    odata[out_offset] = idata[in_offset];
    out_offset += out_offset_step;
    in_offset += in_offset_step;
  }
}
// Forward nearest-neighbor upsampling where the output is addressed with the
// channel as the fastest-varying component, then width, height and batch.
//
// One thread writes one output element. The flat thread index is split into
// (n, h2, w2, c), the source row/column are looked up through
// nn_compute_source_index_fn (or reused as-is when that axis keeps its size),
// and the source element is fetched at the offset returned by idx_cl.
//
//   idata, odata      source / destination buffers
//   channels          channel count shared by source and destination
//   height1, width1   source spatial size
//   height2, width2   destination spatial size
//   out_numel         number of destination elements; threads at or beyond it exit
template <typename scalar_t, nn_compute_source_index_fn_t nn_compute_source_index_fn>
C10_LAUNCH_BOUNDS_1(1024)
__global__ void upsample_nearest2d_nhwc_out_frame(
    const scalar_t* idata,
    scalar_t* odata,
    const size_t channels,
    const size_t height1,
    const size_t width1,
    const size_t height2,
    const size_t width2,
    float height_scale,
    float width_scale,
    const size_t out_numel) {
  const int index = blockIdx.x * blockDim.x + threadIdx.x;
  if (index >= out_numel) {
    return;
  }

  // Peel the flat index apart: channel first, then column, row and batch.
  const size_t pixel = index / channels;
  const size_t row = pixel / width2;
  const int c = index % channels;
  const int w2 = pixel % width2;
  const int h2 = row % height2;
  const int n = row / height2;

  size_t h1;
  if (height1 == height2) {
    h1 = h2;
  } else {
    h1 = nn_compute_source_index_fn(height_scale, h2, height1);
  }

  size_t w1;
  if (width1 == width2) {
    w1 = w2;
  } else {
    w1 = nn_compute_source_index_fn(width_scale, w2, width1);
  }

  odata[index] = idata[idx_cl(n, h1, w1, c, height1, width1, channels)];
}
// see NOTE [ Nearest neighbor upsampling kernel implementation ]
//
// Backward nearest-neighbor upsampling over buffers indexed as
// [batch, channel, height, width].
//
// Naming: "src" is grad_output (read from grad_o) and "dst" is grad_input
// (written to grad_i); the host launcher passes the output size as src_dim_*
// and the input size as dst_dim_*.
//
// Launch layout: 1-D grid; each thread owns one (channel, y, x) location of
// grad_input and loops over the batch dimension. For every batch entry it sums
// the grad_output window [src_y, src_y_up) x [src_x, src_x_up) obtained from
// nn_bw_compute_source_index_fn and stores the sum.
//
// All offsets are computed in size_t: grad_output may hold more elements than
// fit in a 32-bit int even when grad_input does not, and a narrowed offset
// would read from the wrong location.
//
//   grad_o                 grad_output buffer (source)
//   dim_b, dim_c           batch and channel counts
//   src_dim_h, src_dim_w   grad_output spatial size
//   dst_dim_h, dst_dim_w   grad_input spatial size
//   grad_i                 grad_input buffer (destination)
//   height_scale,
//   width_scale            per-axis scale handed to nn_bw_compute_source_index_fn
template <typename scalar_t, typename accscalar_t, nn_bw_compute_source_index_fn_t nn_bw_compute_source_index_fn>
C10_LAUNCH_BOUNDS_1(1024)
__global__ void upsample_nearest2d_backward_out_frame(
    const scalar_t* grad_o,
    size_t dim_b,
    size_t dim_c,
    size_t src_dim_h,
    size_t src_dim_w,
    size_t dst_dim_h,
    size_t dst_dim_w,
    scalar_t* grad_i,
    float height_scale,
    float width_scale) {
  size_t dst_idx = static_cast<size_t>(blockIdx.x) * blockDim.x + threadIdx.x;
  if (dst_idx >= dim_c * dst_dim_h * dst_dim_w)
    return;

  // Elements per channel plane on each side.
  const size_t dst_c_stride = dst_dim_h * dst_dim_w;
  const size_t src_c_stride = src_dim_h * src_dim_w;

  const size_t c = (dst_idx / dst_c_stride) % dim_c;

  const int dst_y = (dst_idx / dst_dim_w) % dst_dim_h;
  // note that we do not want to clamp src_y to src_dim_y, since we might
  // intentionally want to skip in case of scale_factor < 1.0
  const int src_y =
      nn_bw_compute_source_index_fn(height_scale, dst_y, src_dim_h);
  const int src_y_up = nn_bw_compute_source_index_fn(
      height_scale, dst_y + 1, src_dim_h);

  const int dst_x = dst_idx % dst_dim_w;
  // note that we do not want to clamp src_x to src_dim_w, since we might
  // intentionally want to skip in case of scale_factor < 1.0
  const int src_x =
      nn_bw_compute_source_index_fn(width_scale, dst_x, src_dim_w);
  const int src_x_up = nn_bw_compute_source_index_fn(
      width_scale, dst_x + 1, src_dim_w);

  for (size_t b = 0; b < dim_b; b++) {
    accscalar_t grad = 0;
    for (int y = src_y; y < src_y_up; y++) {
      for (int x = src_x; x < src_x_up; x++) {
        // 64-bit offset into grad_output; must not be narrowed to int.
        const size_t src_idx =
            b * dim_c * src_c_stride + c * src_c_stride + y * src_dim_w + x;
        grad += grad_o[src_idx];
      }
    }
    grad_i[dst_idx] = static_cast<scalar_t>(grad);
    // Same (channel, y, x) location in the next batch entry.
    dst_idx += dim_c * dst_c_stride;
  }
}
// Backward nearest-neighbor upsampling where grad_input is addressed with the
// channel as the fastest-varying component, then width, height and batch.
//
// Suffix 1 in the parameter names refers to grad_output (source, read from
// go); suffix 2 refers to grad_input (destination, written to gi).
//
// One thread produces one grad_input element: the flat thread index is split
// into (n, h2, w2, c), nn_bw_compute_source_index_fn yields the half-open
// grad_output window along each axis, and every grad_output element in that
// window (located through idx_cl) is accumulated in accscalar_t before being
// cast back to scalar_t.
//
//   gi_numel   number of grad_input elements; threads at or beyond it exit
template <typename scalar_t, typename accscalar_t, nn_bw_compute_source_index_fn_t nn_bw_compute_source_index_fn>
C10_LAUNCH_BOUNDS_1(1024)
__global__ void upsample_nearest2d_backward_nhwc_out_frame(
    const scalar_t* go,
    scalar_t* gi,
    const size_t height1,
    const size_t width1,
    const size_t height2,
    const size_t width2,
    const size_t channels,
    const float height_scale,
    const float width_scale,
    const size_t gi_numel) {
  const int index = blockIdx.x * blockDim.x + threadIdx.x;
  if (index >= gi_numel) {
    return;
  }

  // Peel the flat index apart: channel first, then column, row and batch.
  const size_t pixel = index / channels;
  const size_t row = pixel / width2;
  const int c = index % channels;
  const int w2 = pixel % width2;
  const int h2 = row % height2;
  const int n = row / height2;

  // Half-open window of grad_output rows/columns feeding this element.
  const int h_begin = nn_bw_compute_source_index_fn(height_scale, h2, height1);
  const int h_end = nn_bw_compute_source_index_fn(height_scale, h2 + 1, height1);
  const int w_begin = nn_bw_compute_source_index_fn(width_scale, w2, width1);
  const int w_end = nn_bw_compute_source_index_fn(width_scale, w2 + 1, width1);

  accscalar_t acc = 0;
  for (int src_h = h_begin; src_h < h_end; ++src_h) {
    for (int src_w = w_begin; src_w < w_end; ++src_w) {
      acc += go[idx_cl(n, src_h, src_w, c, height1, width1, channels)];
    }
  }
  gi[index] = static_cast<scalar_t>(acc);
}
// Launches the forward nearest-neighbor 2D upsampling kernels.
//
// The template argument nn_compute_source_index_fn selects the helper the
// kernels use to map an output coordinate to its source coordinate.
//
// Arguments:
//   output        destination tensor, written in place
//   input_        source tensor; dims 0..3 are read as batch, channels,
//                 height, width
//   output_size   {output_height, output_width}
//   scales_h/w    optional scales forwarded to compute_scales_value
//
// Behavior:
//   - returns immediately when the input has no elements
//   - performs a plain copy when input and output have identical sizes
//   - uses the channels-last kernel when the input suggests ChannelsLast, has
//     at least 4 channels and the output is ChannelsLast-contiguous; both
//     tensors must then hold fewer than INT_MAX elements
//   - otherwise uses the contiguous kernel, writing into a temporary that is
//     copied back when the output itself is not contiguous
template<nn_compute_source_index_fn_t nn_compute_source_index_fn>
static void upsample_nearest2d_out_cuda_template(
    const Tensor& output,
    const Tensor& input_,
    IntArrayRef output_size,
    c10::optional<double> scales_h,
    c10::optional<double> scales_w) {
  TensorArg input_arg{input_, "input_", 1}, output_arg{output, "output", 2};
  checkAllSameGPU(__func__, {input_arg, output_arg});

  if (input_.numel() == 0) {
    return;
  }

  int output_height = output_size[0];
  int output_width = output_size[1];

  int nbatch = input_.size(0);
  int channels = input_.size(1);
  int input_height = input_.size(2);
  int input_width = input_.size(3);

  const float height_scale = compute_scales_value<float>(scales_h, input_height, output_height);
  const float width_scale = compute_scales_value<float>(scales_w, input_width, output_width);

  const auto memory_format = input_.suggest_memory_format();

  if (input_.sizes() == output.sizes()) {
    output.copy_(input_);
    return;
  }

  // heuristic: only use channels_last path when it's faster than the contiguous path
  if (memory_format == at::MemoryFormat::ChannelsLast && channels >= 4 &&
      output.is_contiguous(memory_format)) {
    at::Tensor input = input_.contiguous(at::MemoryFormat::ChannelsLast);

    // The channels-last kernel keeps its flat element index in an int.
    TORCH_CHECK(input.numel() < std::numeric_limits<int>::max(),
      "upsample_nearest_nhwc only supports input tensors with less than INT_MAX elements");
    TORCH_CHECK(output.numel() < std::numeric_limits<int>::max(),
      "upsample_nearest_nhwc only supports output tensors with less than INT_MAX elements");

    // One thread per output element.
    const int num_kernels = output.numel();
    const int num_threads = std::min(at::cuda::getCurrentDeviceProperties()->maxThreadsPerBlock, 1024);

    AT_DISPATCH_FLOATING_TYPES_AND2(ScalarType::Half, ScalarType::Byte, input.scalar_type(), "upsample_nearest2d_nhwc_out_frame", [&] {
      const scalar_t* idata = input.data_ptr<scalar_t>();
      scalar_t* odata = output.data_ptr<scalar_t>();

      upsample_nearest2d_nhwc_out_frame<scalar_t, nn_compute_source_index_fn>
        <<<ceil_div(num_kernels, num_threads), num_threads, 0, at::cuda::getCurrentCUDAStream()>>>(
          idata,
          odata,
          channels,
          input_height,
          input_width,
          output_height,
          output_width,
          height_scale,
          width_scale,
          output.numel()
        );
      C10_CUDA_KERNEL_LAUNCH_CHECK();
    });
  }
  else {
    // This is needed for non-contiguous tensors.
    Tensor output_c = output.is_contiguous() ? output : at::empty(output.sizes(), output.options());
    Tensor input = input_.contiguous();

    int nc = nbatch * channels;

    const int max_threads = std::min<int>(
        at::cuda::getCurrentDeviceProperties()->maxThreadsPerBlock, MAX_THREADS);

    int* maxThreadsDim = at::cuda::getCurrentDeviceProperties()->maxThreadsDim;
    int* maxGridSize = at::cuda::getCurrentDeviceProperties()->maxGridSize;

    // upsample_nearest2d meta call makes sure input/output tensor is not empty;
    // Block shape: fill x from the output width first, then y from the output
    // height, then z from the plane count, never exceeding max_threads total.
    int block_x = std::min<int>(
        maxThreadsDim[0], std::min<int>(lastPow2(output_width), max_threads));
    int block_y = std::min<int>(
        maxThreadsDim[1],
        std::min<int>(lastPow2(output_height), max_threads / block_x));
    int block_z = std::min<int>(
        maxThreadsDim[2], std::min<int>(nc, max_threads / block_x / block_y));
    const dim3 block(block_x, block_y, block_z);

    // Grid x/y cover the output image exactly; grid z is sized for block_z * 4
    // planes per block and the kernel loops over any remaining planes.
    int grid_x = ceil_div(output_width, block_x);
    int grid_y = ceil_div(output_height, block_y);
    int grid_z = std::min<int>(
        maxGridSize[2], ceil_div(nc, block_z * 4));
    const dim3 grid(grid_x, grid_y, grid_z);
    // Error out on cases where grid_x & grid_y exceeds limit of launch config, as
    // the current kernel implementation doesn't loop over the two dimensions.
    // This is unlikely to happen.
    // TODO: kernel implementation could stride on spatial dimension. We probably
    //       need to overhaul the kernel.
    TORCH_CHECK(
        grid_x <= maxGridSize[0] && grid_y <= maxGridSize[1],
        "input tensor has spatial dimension larger than the kernel capacity");

    cudaStream_t stream = at::cuda::getCurrentCUDAStream();
    AT_DISPATCH_FLOATING_TYPES_AND2(ScalarType::Half, ScalarType::Byte, input.scalar_type(), "upsample_nearest2d_out_frame", [&] {
      auto idata = input.data_ptr<scalar_t>();
      auto odata = output_c.data_ptr<scalar_t>();

      upsample_nearest2d_out_frame<scalar_t, nn_compute_source_index_fn>
        <<<grid, block, 0, stream>>>(
          idata,
          odata,
          nc,
          input_height,
          input_width,
          output_height,
          output_width,
          height_scale,
          width_scale);
      C10_CUDA_KERNEL_LAUNCH_CHECK();
    });

    // Results were written to a contiguous temporary; copy them back.
    if (!output.is_contiguous()) {
      output.copy_(output_c);
    }
  }
}
// Launches the backward nearest-neighbor 2D upsampling kernels.
//
// The template argument nn_bw_compute_source_index_fn selects the helper the
// kernels use to find the grad_output window that feeds each grad_input
// element.
//
// Arguments:
//   grad_input    destination tensor, written in place
//   grad_output_  incoming gradient
//   output_size   {output_height, output_width} of the forward output
//   input_size    {batch, channels, input_height, input_width} of the forward
//                 input
//   scales_h/w    optional scales forwarded to compute_scales_value_backwards
//
// Behavior:
//   - returns immediately when grad_input has no elements
//   - performs a plain copy when grad_output and grad_input have identical sizes
//   - uses the channels-last kernel when grad_output suggests ChannelsLast,
//     there are at least 4 channels and grad_input is ChannelsLast-contiguous;
//     both tensors must then hold fewer than INT_MAX elements
//   - otherwise uses the contiguous kernel, writing into a temporary that is
//     copied back when grad_input itself is not contiguous; grad_input must
//     hold at most INT32_MAX elements
template<nn_bw_compute_source_index_fn_t nn_bw_compute_source_index_fn>
static void upsample_nearest2d_backward_out_cuda_template(
    const Tensor& grad_input,
    const Tensor& grad_output_,
    IntArrayRef output_size,
    IntArrayRef input_size,
    c10::optional<double> scales_h,
    c10::optional<double> scales_w) {
  TensorArg grad_input_arg{grad_input, "grad_input", 1},
      grad_output_arg{grad_output_, "grad_output_", 2};
  checkAllSameGPU(__func__, {grad_output_arg, grad_input_arg});

  if (grad_input.numel() == 0) {
    return;
  }

  int output_height = output_size[0];
  int output_width = output_size[1];

  int nbatch = input_size[0];
  int channels = input_size[1];
  int input_height = input_size[2];
  int input_width = input_size[3];

  const float height_scale = compute_scales_value_backwards<float>(scales_h, output_height, input_height);
  const float width_scale = compute_scales_value_backwards<float>(scales_w, output_width, input_width);

  auto memory_format = grad_output_.suggest_memory_format();

  if (grad_output_.sizes() == grad_input.sizes()) {
    grad_input.copy_(grad_output_);
    return;
  }

  if (memory_format == at::MemoryFormat::ChannelsLast && channels >= 4 &&
      grad_input.is_contiguous(memory_format)) {
    Tensor grad_output = grad_output_.contiguous(at::MemoryFormat::ChannelsLast);

    // The channels-last kernel keeps its flat element index in an int.
    TORCH_CHECK(grad_input.numel() < std::numeric_limits<int>::max(),
      "upsample_nearest_nhwc only supports grad_input tensors with less than INT_MAX elements");
    TORCH_CHECK(grad_output.numel() < std::numeric_limits<int>::max(),
      "upsample_nearest_nhwc only supports grad_output tensors with less than INT_MAX elements");

    // One thread per grad_input element.
    const int num_kernels = grad_input.numel();
    const int num_threads = std::min(at::cuda::getCurrentDeviceProperties()->maxThreadsPerBlock, 1024);

    AT_DISPATCH_FLOATING_TYPES_AND2(ScalarType::Half, ScalarType::Byte, grad_output.scalar_type(), "upsample_nearest2d_backward_nhwc_out_frame", [&] {
      using accscalar_t = at::acc_type<scalar_t, true>;

      const scalar_t* go = grad_output.data_ptr<scalar_t>();
      scalar_t* gi = grad_input.data_ptr<scalar_t>();

      // Size arguments: grad_output (source) first, then grad_input (destination).
      upsample_nearest2d_backward_nhwc_out_frame<scalar_t, accscalar_t, nn_bw_compute_source_index_fn>
        <<<ceil_div(num_kernels, num_threads), num_threads, 0, at::cuda::getCurrentCUDAStream()>>>(
          go,
          gi,
          output_height,
          output_width,
          input_height,
          input_width,
          channels,
          height_scale,
          width_scale,
          grad_input.numel()
        );
      C10_CUDA_KERNEL_LAUNCH_CHECK();
    });
  } else {
    // safe check for int32 indexing; implicitly restrict launch config for kernel.
    // Performed before the element count is narrowed to unsigned int below.
    TORCH_CHECK(grad_input.numel() <= std::numeric_limits<int32_t>::max(),
      "upsample_nearest2d_backward only supports grad_input tensors with at most INT_MAX elements");

    // This is needed for non-contiguous tensors.
    Tensor grad_input_c = grad_input.is_contiguous() ? grad_input : at::empty(grad_input.sizes(), grad_input.options());
    Tensor grad_output = grad_output_.contiguous();

    // upsample_nearest2d meta call makes sure `nbatch != 0`
    // One thread per grad_input element of a single batch entry; the kernel
    // loops over the batch dimension itself.
    unsigned int n = grad_input.numel() / nbatch;
    dim3 bdim{std::min<unsigned int>(
        at::cuda::getCurrentDeviceProperties()->maxThreadsPerBlock, MAX_THREADS)};
    dim3 gdim{ceil_div(n, bdim.x)};

    cudaStream_t stream = at::cuda::getCurrentCUDAStream();
    AT_DISPATCH_FLOATING_TYPES_AND2(ScalarType::Half, ScalarType::Byte, grad_output.scalar_type(), "upsample_nearest2d_backward_out_frame", [&] {
      using accscalar_t = at::acc_type<scalar_t, true>;

      auto grad_input_data = grad_input_c.data_ptr<scalar_t>();
      auto grad_output_data = grad_output.data_ptr<scalar_t>();

      // Size arguments: grad_output (source) first, then grad_input (destination).
      upsample_nearest2d_backward_out_frame<scalar_t, accscalar_t, nn_bw_compute_source_index_fn>
        <<<gdim, bdim, 0, stream>>>(
          grad_output_data,
          nbatch,
          channels,
          output_height,
          output_width,
          input_height,
          input_width,
          grad_input_data,
          height_scale,
          width_scale);
      C10_CUDA_KERNEL_LAUNCH_CHECK();
    });

    // Results were written to a contiguous temporary; copy them back.
    if (!grad_input.is_contiguous()) {
      grad_input.copy_(grad_input_c);
    }
  }
}
} // namespace
// CUDA implementation declared through TORCH_IMPL_FUNC for
// upsample_nearest2d_out_cuda. Forwards to
// upsample_nearest2d_out_cuda_template instantiated with
// nearest_neighbor_compute_source_index; note that the launcher takes
// `output` as its first argument.
TORCH_IMPL_FUNC(upsample_nearest2d_out_cuda) (
const Tensor& input,
IntArrayRef output_size,
c10::optional<double> scales_h,
c10::optional<double> scales_w,
const Tensor& output) {
upsample_nearest2d_out_cuda_template<nearest_neighbor_compute_source_index>(
output, input, output_size, scales_h, scales_w);
}
// CUDA implementation declared through TORCH_IMPL_FUNC for
// _upsample_nearest_exact2d_out_cuda. Same launcher as the non-exact forward
// variant, but instantiated with nearest_neighbor_exact_compute_source_index;
// note that the launcher takes `output` as its first argument.
TORCH_IMPL_FUNC(_upsample_nearest_exact2d_out_cuda) (
const Tensor& input,
IntArrayRef output_size,
c10::optional<double> scales_h,
c10::optional<double> scales_w,
const Tensor& output) {
upsample_nearest2d_out_cuda_template<nearest_neighbor_exact_compute_source_index>(
output, input, output_size, scales_h, scales_w);
}
// CUDA implementation declared through TORCH_IMPL_FUNC for
// upsample_nearest2d_backward_out_cuda. Forwards to
// upsample_nearest2d_backward_out_cuda_template instantiated with
// nearest_neighbor_bw_compute_source_index; note that the launcher takes
// `grad_input` before `grad_output`.
TORCH_IMPL_FUNC(upsample_nearest2d_backward_out_cuda) (
const Tensor& grad_output,
IntArrayRef output_size,
IntArrayRef input_size,
c10::optional<double> scales_h,
c10::optional<double> scales_w,
const Tensor& grad_input) {
upsample_nearest2d_backward_out_cuda_template<nearest_neighbor_bw_compute_source_index>(
grad_input, grad_output, output_size, input_size, scales_h, scales_w);
}
// CUDA implementation declared through TORCH_IMPL_FUNC for
// _upsample_nearest_exact2d_backward_out_cuda. Same launcher as the non-exact
// backward variant, but instantiated with
// nearest_neighbor_exact_bw_compute_source_index; note that the launcher takes
// `grad_input` before `grad_output`.
TORCH_IMPL_FUNC(_upsample_nearest_exact2d_backward_out_cuda) (
const Tensor& grad_output,
IntArrayRef output_size,
IntArrayRef input_size,
c10::optional<double> scales_h,
c10::optional<double> scales_w,
const Tensor& grad_input) {
upsample_nearest2d_backward_out_cuda_template<nearest_neighbor_exact_bw_compute_source_index>(
grad_input, grad_output, output_size, input_size, scales_h, scales_w);
}
} // namespace native
} // namespace at