forked from pytorch/pytorch
-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathUnarySpecialOpsKernel.cu
355 lines (334 loc) · 12.5 KB
/
UnarySpecialOpsKernel.cu
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
#define TORCH_ASSERT_NO_OPERATORS
#include <ATen/native/UnaryOps.h>
#include <limits>
#include <ATen/AccumulateType.h>
#include <ATen/Dispatch.h>
#include <ATen/native/DispatchStub.h>
#include <ATen/native/Math.h>
#include <ATen/native/TensorIterator.h>
#include <ATen/native/cuda/JitLoops.cuh>
#include <ATen/native/cuda/Loops.cuh>
#include <ATen/native/cuda/Math.cuh>
#include <ATen/native/cuda/jit_utils.h>
#include <ATen/NumericUtils.h>
#include <c10/core/Scalar.h>
#include <c10/cuda/CUDAMathCompat.h>
#include <c10/util/complex.h>
namespace at {
namespace native {
const char exp2_name[] = "exp2_kernel";
// Elementwise base-2 exponential (::exp2), registered as exp2_stub.
// Dispatches on iter.common_dtype() over float, double, Half and BFloat16.
// With AT_USE_JITERATOR() the body is a jitted kernel built from exp2_string;
// otherwise a precompiled gpu_kernel calls ::exp2 directly.
void exp2_kernel_cuda(TensorIteratorBase& iter) {
  const auto dtype = iter.common_dtype();
#if AT_USE_JITERATOR()
  AT_DISPATCH_FLOATING_TYPES_AND2(
      ScalarType::Half, ScalarType::BFloat16, dtype, "exp2_cuda", [&]() {
        jitted_gpu_kernel<
            /*name=*/exp2_name,
            /*return_dtype=*/scalar_t,
            /*common_dtype=*/scalar_t,
            /*arity=*/1>(iter, exp2_string);
      });
#else
  AT_DISPATCH_FLOATING_TYPES_AND2(
      ScalarType::Half, ScalarType::BFloat16, dtype, "exp2_cuda", [&]() {
        gpu_kernel(iter, [] GPU_LAMBDA(scalar_t x) -> scalar_t {
          return ::exp2(x);
        });
      });
#endif  // AT_USE_JITERATOR()
}
const char i0_name[] = "i0";
// Elementwise i0 kernel, registered as i0_stub (presumably torch.i0, the zeroth-order
// modified Bessel function of the first kind — computed by calc_i0 / i0_string).
// Dispatches on iter.common_dtype() over float, double, Half and BFloat16.
void i0_kernel_cuda(TensorIteratorBase& iter) {
  const auto dtype = iter.common_dtype();
#if AT_USE_JITERATOR()
  AT_DISPATCH_FLOATING_TYPES_AND2(
      ScalarType::Half, ScalarType::BFloat16, dtype, "i0_cuda", [&]() {
        jitted_gpu_kernel<
            /*name=*/i0_name,
            /*return_dtype=*/scalar_t,
            /*common_dtype=*/scalar_t,
            /*arity=*/1>(iter, i0_string);
      });
#else
  AT_DISPATCH_FLOATING_TYPES_AND2(
      ScalarType::Half, ScalarType::BFloat16, dtype, "i0_cuda", [&]() {
        gpu_kernel(iter, [] GPU_LAMBDA(scalar_t x) -> scalar_t {
          using opmath_t = at::opmath_type<scalar_t>;
          // The input is converted to opmath_t inside the lambda; because the lambda's
          // parameter is still scalar_t, TensorIterator treats this as a
          // no-dynamic-cast kernel.
          const opmath_t wide = x;
          return calc_i0<opmath_t>(wide);
        });
      });
#endif  // AT_USE_JITERATOR()
}
// See note [Jiterator]
const char i0e_name[] = "i0e";
// Elementwise i0e kernel, registered as special_i0e_stub (computed by calc_i0e /
// i0e_string). Dispatches on iter.common_dtype() over float, double, Half and BFloat16.
// The precompiled path evaluates calc_i0e in at::opmath_type<scalar_t>.
void i0e_kernel_cuda(TensorIteratorBase& iter) {
  const auto dtype = iter.common_dtype();
#if AT_USE_JITERATOR()
  AT_DISPATCH_FLOATING_TYPES_AND2(
      ScalarType::Half, ScalarType::BFloat16, dtype, "i0e_cuda", [&]() {
        jitted_gpu_kernel<
            /*name=*/i0e_name,
            /*return_dtype=*/scalar_t,
            /*common_dtype=*/scalar_t,
            /*arity=*/1>(iter, i0e_string);
      });
#else
  AT_DISPATCH_FLOATING_TYPES_AND2(
      ScalarType::Half, ScalarType::BFloat16, dtype, "i0e_cuda", [&]() {
        gpu_kernel(iter, [] GPU_LAMBDA(scalar_t x) -> scalar_t {
          using opmath_t = at::opmath_type<scalar_t>;
          const opmath_t wide = x;
          return calc_i0e<opmath_t>(wide);
        });
      });
#endif  // AT_USE_JITERATOR()
}
// See note [Jiterator]
const char i1_name[] = "i1";
// Elementwise i1 kernel, registered as special_i1_stub (computed by calc_i1 / i1_string).
// Only float and double are dispatched (AT_DISPATCH_FLOATING_TYPES).
void i1_kernel_cuda(TensorIteratorBase& iter) {
  const auto dtype = iter.common_dtype();
#if AT_USE_JITERATOR()
  AT_DISPATCH_FLOATING_TYPES(dtype, "i1_cuda", [&]() {
    jitted_gpu_kernel<
        /*name=*/i1_name,
        /*return_dtype=*/scalar_t,
        /*common_dtype=*/scalar_t,
        /*arity=*/1>(iter, i1_string);
  });
#else
  AT_DISPATCH_FLOATING_TYPES(dtype, "i1_cuda", [&]() {
    gpu_kernel(iter, [] GPU_LAMBDA(scalar_t x) -> scalar_t {
      return calc_i1(x);
    });
  });
#endif  // AT_USE_JITERATOR()
}
const char i1e_name[] = "i1e";
// Elementwise i1e kernel, registered as special_i1e_stub (computed by calc_i1e /
// i1e_string). Only float and double are dispatched (AT_DISPATCH_FLOATING_TYPES).
void i1e_kernel_cuda(TensorIteratorBase& iter) {
  const auto dtype = iter.common_dtype();
#if AT_USE_JITERATOR()
  AT_DISPATCH_FLOATING_TYPES(dtype, "i1e_cuda", [&]() {
    jitted_gpu_kernel<
        /*name=*/i1e_name,
        /*return_dtype=*/scalar_t,
        /*common_dtype=*/scalar_t,
        /*arity=*/1>(iter, i1e_string);
  });
#else
  AT_DISPATCH_FLOATING_TYPES(dtype, "i1e_cuda", [&]() {
    gpu_kernel(iter, [] GPU_LAMBDA(scalar_t x) -> scalar_t {
      return calc_i1e(x);
    });
  });
#endif  // AT_USE_JITERATOR()
}
const char sigmoid_name[] = "sigmoid";
// Elementwise logistic sigmoid, 1 / (1 + exp(-x)), registered as sigmoid_stub.
//
// Complex dtypes (AT_DISPATCH_COMPLEX_TYPES): jitted from sigmoid_string when
// AT_USE_JITERATOR() is enabled, otherwise a precompiled gpu_kernel.
// Real dtypes (float, double, Half, BFloat16): always a precompiled gpu_kernel.
// The formula is evaluated in at::opmath_type<scalar_t> and converted back to scalar_t
// once at the end. This is the same opmath pattern used by the i0/i0e/kaiser_window
// kernels in this file. It avoids depending on whichever reduced-precision overloads of
// -, exp, + and / the Half/BFloat16 types pick up, since those may round intermediates
// to the storage type. For float/double, opmath_t is scalar_t, so those results are
// unchanged.
void sigmoid_kernel_cuda(TensorIteratorBase& iter) {
  auto common_dtype = iter.common_dtype();
  if (at::isComplexType(common_dtype)) {
    // only jiterate for complex-dtype
#if AT_USE_JITERATOR()
    static const auto sigmoid_string = jiterator_stringify(
      template <typename T>
      T sigmoid(T x) {
        return T{1} / (T{1} + std::exp(-x));
      }
    ); // sigmoid_string
    AT_DISPATCH_COMPLEX_TYPES(common_dtype, "sigmoid_cuda", [&]() {
      jitted_gpu_kernel<
          /*name=*/sigmoid_name,
          /*return_dtype=*/scalar_t,
          /*common_dtype=*/scalar_t,
          /*arity=*/1>(iter, sigmoid_string);
    });
#else
    AT_DISPATCH_COMPLEX_TYPES(common_dtype, "sigmoid_cuda", [&]() {
      gpu_kernel(iter, []GPU_LAMBDA(scalar_t a) -> scalar_t {
        return scalar_t{1} / (scalar_t{1} + std::exp(-a));
      });
    });
#endif
  } else {
    AT_DISPATCH_FLOATING_TYPES_AND2(at::ScalarType::Half, at::ScalarType::BFloat16, common_dtype, "sigmoid_cuda", [&]() {
      gpu_kernel(iter, []GPU_LAMBDA(scalar_t a) -> scalar_t {
        using opmath_t = at::opmath_type<scalar_t>;
        const opmath_t one = static_cast<opmath_t>(1);
        // Single rounding back to scalar_t, after the whole expression is evaluated.
        return static_cast<scalar_t>(one / (one + std::exp(-static_cast<opmath_t>(a))));
      });
    });
  }
}
const char sinc_name[] = "sinc";
// Elementwise sinc, sin(pi * x) / (pi * x) with sinc(0) = 1, registered as sinc_stub.
// Dispatches float, double, complex<float>, complex<double>, Half and BFloat16.
//
// Jiterator path: kernel body comes from sinc_string.
// Precompiled path: pi * x, the sin and the division are evaluated in
// at::opmath_type<scalar_t> and cast back to scalar_t once. Previously the product
// pi<scalar_t>() * a was formed in Half/BFloat16, which rounded pi * x to reduced
// precision before sin was applied. For float/double/complex types opmath_t is
// scalar_t, so those results are unchanged.
void sinc_kernel_cuda(TensorIteratorBase& iter) {
#if AT_USE_JITERATOR()
  AT_DISPATCH_FLOATING_AND_COMPLEX_TYPES_AND2(
      ScalarType::Half, ScalarType::BFloat16,
      iter.common_dtype(), "sinc_cuda",
      [&]() {
        jitted_gpu_kernel</*name=*/sinc_name,
                          /*return_dtype=*/ scalar_t,
                          /*common_dtype=*/ scalar_t,
                          /*arity=*/ 1>(iter, sinc_string);
      });
#else
  AT_DISPATCH_FLOATING_AND_COMPLEX_TYPES_AND2(
      ScalarType::Half, ScalarType::BFloat16,
      iter.common_dtype(), "sinc_cuda",
      [&]() {
        gpu_kernel(iter, []GPU_LAMBDA(scalar_t a) -> scalar_t {
          if (a == scalar_t(0)) {
            return scalar_t(1);
          } else {
            using opmath_t = at::opmath_type<scalar_t>;
            // NVCC says constexpr var is not accessible from device
            const opmath_t product =
                c10::detail::pi<opmath_t>() * static_cast<opmath_t>(a);
            return static_cast<scalar_t>(std::sin(product) / product);
          }
        });
      });
#endif
}
// Elementwise logit, log(x / (1 - x)), registered as logit_stub.
// Arithmetic is done in acc_type<scalar_t, true> and the result is converted to scalar_t.
//   eps_scalar < 0 : no clamping.
//   otherwise      : x is clamped with lower bound eps and upper bound 1 - eps before
//                    the log (below-lower is checked first; NaN passes through unchanged).
// Dispatches float, double, Half and BFloat16 on iter.common_dtype().
void logit_kernel_cuda(TensorIteratorBase& iter, const Scalar& eps_scalar) {
  AT_DISPATCH_FLOATING_TYPES_AND2(
      at::ScalarType::Half, at::ScalarType::BFloat16, iter.common_dtype(), "logit_cuda", [&]() {
        using acc_t = acc_type<scalar_t, true>;
        const acc_t eps = eps_scalar.to<acc_t>();

        if (eps < acc_t(0)) {
          // Negative eps disables clamping entirely.
          gpu_kernel(iter, [] GPU_LAMBDA(scalar_t x) -> scalar_t {
            const acc_t v = static_cast<acc_t>(x);
            return c10::cuda::compat::log(v / (acc_t(1) - v));
          });
          return;
        }

        const acc_t lower = eps;
        const acc_t upper = acc_t(1) - eps;
        gpu_kernel(iter, [lower, upper] GPU_LAMBDA(scalar_t x) -> scalar_t {
          const acc_t v = static_cast<acc_t>(x);
          acc_t z = v;
          if (v < lower) {
            z = lower;
          } else if (v > upper) {
            z = upper;
          }
          return c10::cuda::compat::log(z / (acc_t(1) - z));
        });
      });
}
const char ndtri_name[] = "ndtri";
// Elementwise ndtri kernel, registered as special_ndtri_stub (computed by calc_ndtri /
// ndtri_string). Only float and double are dispatched (AT_DISPATCH_FLOATING_TYPES).
void ndtri_kernel_cuda(TensorIteratorBase& iter) {
  const auto dtype = iter.common_dtype();
#if AT_USE_JITERATOR()
  AT_DISPATCH_FLOATING_TYPES(dtype, "ndtri_cuda", [&]() {
    jitted_gpu_kernel<
        /*name=*/ndtri_name,
        /*return_dtype=*/scalar_t,
        /*common_dtype=*/scalar_t,
        /*arity=*/1>(iter, ndtri_string);
  });
#else
  AT_DISPATCH_FLOATING_TYPES(dtype, "ndtri_cuda", [&]() {
    gpu_kernel(iter, [] GPU_LAMBDA(scalar_t x) -> scalar_t {
      return calc_ndtri(x);
    });
  });
#endif  // AT_USE_JITERATOR()
}
// Elementwise error function (::erf), registered as erf_stub.
// Always a precompiled gpu_kernel (no jiterator path); dispatches float, double,
// Half and BFloat16 on iter.common_dtype().
void erf_kernel_cuda(TensorIteratorBase& iter) {
  const auto dtype = iter.common_dtype();
  AT_DISPATCH_FLOATING_TYPES_AND2(
      at::ScalarType::Half, at::ScalarType::BFloat16, dtype, "erf_cuda", [&]() {
        gpu_kernel(iter, [] GPU_LAMBDA(scalar_t x) -> scalar_t {
          return ::erf(x);
        });
      });
}
const char erfc_name[] = "erfc_kernel";
// Elementwise complementary error function (::erfc), registered as erfc_stub.
// Dispatches float, double, Half and BFloat16 on iter.common_dtype(). It is jitted
// from erfc_string when AT_USE_JITERATOR() is enabled; otherwise a precompiled
// kernel calls ::erfc.
void erfc_kernel_cuda(TensorIteratorBase& iter) {
  const auto dtype = iter.common_dtype();
#if AT_USE_JITERATOR()
  AT_DISPATCH_FLOATING_TYPES_AND2(
      ScalarType::Half, ScalarType::BFloat16, dtype, "erfc_cuda", [&]() {
        jitted_gpu_kernel<
            /*name=*/erfc_name,
            /*return_dtype=*/scalar_t,
            /*common_dtype=*/scalar_t,
            /*arity=*/1>(iter, erfc_string);
      });
#else
  AT_DISPATCH_FLOATING_TYPES_AND2(
      ScalarType::Half, ScalarType::BFloat16, dtype, "erfc_cuda", [&]() {
        gpu_kernel(iter, [] GPU_LAMBDA(scalar_t x) -> scalar_t {
          return ::erfc(x);
        });
      });
#endif  // AT_USE_JITERATOR()
}
const char erfinv_name[] = "erfinv_kernel";
// Elementwise inverse error function (::erfinv), registered as erfinv_stub.
// Dispatches float, double and Half (no BFloat16) on iter.common_dtype(). It is
// jitted from erfinv_string when AT_USE_JITERATOR() is enabled.
void erfinv_kernel_cuda(TensorIteratorBase& iter) {
  const auto dtype = iter.common_dtype();
#if AT_USE_JITERATOR()
  AT_DISPATCH_FLOATING_TYPES_AND_HALF(dtype, "erfinv_cuda", [&]() {
    jitted_gpu_kernel<
        /*name=*/erfinv_name,
        /*return_dtype=*/scalar_t,
        /*common_dtype=*/scalar_t,
        /*arity=*/1>(iter, erfinv_string);
  });
#else
  AT_DISPATCH_FLOATING_TYPES_AND_HALF(dtype, "erfinv_cuda", [&]() {
    gpu_kernel(iter, [] GPU_LAMBDA(scalar_t x) -> scalar_t {
      return ::erfinv(x);
    });
  });
#endif  // AT_USE_JITERATOR()
}
const char erfcx_name[] = "erfcx";
// Elementwise erfcx kernel, registered as special_erfcx_stub (computed by calc_erfcx /
// erfcx_string). Only float and double are dispatched (AT_DISPATCH_FLOATING_TYPES).
void erfcx_kernel_cuda(TensorIteratorBase& iter) {
  const auto dtype = iter.common_dtype();
#if AT_USE_JITERATOR()
  AT_DISPATCH_FLOATING_TYPES(dtype, "erfcx_cuda", [&]() {
    jitted_gpu_kernel<
        /*name=*/erfcx_name,
        /*return_dtype=*/scalar_t,
        /*common_dtype=*/scalar_t,
        /*arity=*/1>(iter, erfcx_string);
  });
#else
  AT_DISPATCH_FLOATING_TYPES(dtype, "erfcx_cuda", [&]() {
    gpu_kernel(iter, [] GPU_LAMBDA(scalar_t x) -> scalar_t {
      return calc_erfcx(x);
    });
  });
#endif  // AT_USE_JITERATOR()
}
// Kaiser window, registered as kaiser_window_stub. Each element a of the iterator
// (presumably the sample index prepared by the caller — TODO confirm) is mapped to
//   calc_i0(beta * sqrt(max(0, 1 - t^2))) / calc_i0(beta),  t = a * 2/(window_length-1) - 1.
// The scale, beta and 1/calc_i0(beta) are computed once on the host in opmath_t and
// captured by value. window_length == 1 would give a non-finite scale here.
// NOTE(review): presumably the caller special-cases that length; confirm.
// Dispatches on iter.dtype() over float, double, Half and BFloat16.
void kaiser_window_kernel_cuda(TensorIteratorBase& iter, int64_t window_length, double beta_){
  AT_DISPATCH_FLOATING_TYPES_AND2(
      ScalarType::Half, ScalarType::BFloat16, iter.dtype(), "kaiser_window_cuda", [&]() {
        using opmath_t = at::opmath_type<scalar_t>;
        const opmath_t beta = static_cast<opmath_t>(beta_);
        const opmath_t scale = static_cast<opmath_t>(2.0 / (window_length - 1));
        const opmath_t norm = 1.0 / calc_i0(beta);
        gpu_kernel(iter, [beta, scale, norm] GPU_LAMBDA(scalar_t a) -> scalar_t {
          const opmath_t t = static_cast<opmath_t>(a) * scale - 1;
          const opmath_t r = std::max<opmath_t>(0, 1 - t * t);
          return calc_i0(beta * ::sqrt(r)) * norm;
        });
      });
}
const char entr_name[] = "entr";
// Elementwise entr kernel, registered as special_entr_stub. Dispatches float, double,
// Half and BFloat16 on iter.common_dtype(); jitted from entr_string when
// AT_USE_JITERATOR() is enabled.
// Precompiled-path semantics:
//   NaN    -> the input, unchanged
//   x < 0  -> -inf
//   x == 0 -> 0
//   x > 0  -> -x * log(x)
void entr_kernel_cuda(TensorIteratorBase& iter) {
  const auto dtype = iter.common_dtype();
#if AT_USE_JITERATOR()
  AT_DISPATCH_FLOATING_TYPES_AND2(
      ScalarType::Half, ScalarType::BFloat16, dtype, "entr_cuda", [&]() {
        jitted_gpu_kernel<
            /*name=*/entr_name,
            /*return_dtype=*/scalar_t,
            /*common_dtype=*/scalar_t,
            /*arity=*/1>(iter, entr_string);
      });
#else
  AT_DISPATCH_FLOATING_TYPES_AND2(
      ScalarType::Half, ScalarType::BFloat16, dtype, "entr_cuda", [&]() {
        gpu_kernel(iter, [] GPU_LAMBDA(scalar_t v) -> scalar_t {
          // NaN is tested explicitly first so it never reaches the comparisons below.
          if (at::_isnan(v)) {
            return v;
          }
          if (v < 0) {
            return static_cast<scalar_t>(-INFINITY);
          }
          if (v == 0) {
            return scalar_t(0);
          }
          return -v * std::log(v);
        });
      });
#endif  // AT_USE_JITERATOR()
}
// Bind each CUDA kernel defined above to its DispatchStub. Every registration is
// independent, so they are grouped by family rather than by definition order.

// Exponential and error-function kernels.
REGISTER_DISPATCH(exp2_stub, &exp2_kernel_cuda);
REGISTER_DISPATCH(erf_stub, &erf_kernel_cuda);
REGISTER_DISPATCH(erfc_stub, &erfc_kernel_cuda);
REGISTER_DISPATCH(erfinv_stub, &erfinv_kernel_cuda);
REGISTER_DISPATCH(special_erfcx_stub, &erfcx_kernel_cuda);

// i0 / i1 family.
REGISTER_DISPATCH(i0_stub, &i0_kernel_cuda);
REGISTER_DISPATCH(special_i0e_stub, &i0e_kernel_cuda);
REGISTER_DISPATCH(special_i1_stub, &i1_kernel_cuda);
REGISTER_DISPATCH(special_i1e_stub, &i1e_kernel_cuda);

// Remaining unary special ops.
REGISTER_DISPATCH(sigmoid_stub, &sigmoid_kernel_cuda);
REGISTER_DISPATCH(logit_stub, &logit_kernel_cuda);
REGISTER_DISPATCH(sinc_stub, &sinc_kernel_cuda);
REGISTER_DISPATCH(special_ndtri_stub, &ndtri_kernel_cuda);
REGISTER_DISPATCH(special_entr_stub, &entr_kernel_cuda);
REGISTER_DISPATCH(kaiser_window_stub, &kaiser_window_kernel_cuda);
} // namespace native
} // namespace at