// UnaryFractionKernels.cu (forked from pytorch/pytorch)
#define TORCH_ASSERT_NO_OPERATORS
#include <limits>
#include <ATen/native/UnaryOps.h>
#include <ATen/native/cuda/Loops.cuh>
#include <ATen/AccumulateType.h>
#include <ATen/Dispatch.h>
#include <ATen/native/DispatchStub.h>
#include <ATen/native/TensorIterator.h>
#include <ATen/native/cuda/Math.cuh>
namespace at { namespace native {
// std::ceil cannot be applied to std::complex values, so ceil is routed
// through the ceil_wrapper overload set instead of being called directly.
//
// Generic overload: returns std::ceil(a), converted back to scalar_t.
template <typename scalar_t>
__host__ __device__ static inline scalar_t ceil_wrapper(scalar_t a) {
  const scalar_t rounded_up = std::ceil(a);
  return rounded_up;
}
// std::complex overload: applies std::ceil to the real and imaginary parts
// separately and rebuilds the complex value from the two results.
template <typename T>
__host__ __device__ static inline std::complex<T> ceil_wrapper(std::complex<T> v) {
  const T re = std::ceil(v.real());
  const T im = std::ceil(v.imag());
  return std::complex<T>(re, im);
}
// CUDA ceil kernel entry point.
//
// Dispatches on iter.dtype() (floating types plus Half and BFloat16, under
// the label "ceil_cuda") and hands gpu_kernel a functor that maps each
// element `x` to ceil_wrapper(x).
void ceil_kernel_cuda(TensorIteratorBase& iter) {
  AT_DISPATCH_FLOATING_TYPES_AND2(
      ScalarType::Half,
      ScalarType::BFloat16,
      iter.dtype(),
      "ceil_cuda",
      [&]() {
        auto ceil_op = [] GPU_LAMBDA(scalar_t x) -> scalar_t {
          return ceil_wrapper(x);
        };
        gpu_kernel(iter, ceil_op);
      });
}
// CUDA frac kernel entry point.
//
// Dispatches on iter.dtype() (floating types plus Half and BFloat16, under
// the label "frac_cuda"). Each element `x` is mapped to x - trunc(x), i.e.
// the value with its truncated-toward-zero integer part subtracted.
void frac_kernel_cuda(TensorIteratorBase& iter) {
  AT_DISPATCH_FLOATING_TYPES_AND2(
      ScalarType::Half,
      ScalarType::BFloat16,
      iter.dtype(),
      "frac_cuda",
      [&]() {
        auto frac_op = [] GPU_LAMBDA(scalar_t x) -> scalar_t {
          return x - ::trunc(x);
        };
        gpu_kernel(iter, frac_op);
      });
}
// std::floor cannot be applied to std::complex values, so floor is routed
// through the floor_wrapper overload set instead of being called directly.
//
// Generic overload: returns std::floor(a), converted back to scalar_t.
template <typename scalar_t>
__host__ __device__ static inline scalar_t floor_wrapper(scalar_t a) {
  const scalar_t rounded_down = std::floor(a);
  return rounded_down;
}
// std::complex overload: applies std::floor to the real and imaginary parts
// separately and rebuilds the complex value from the two results.
template <typename T>
__host__ __device__ static inline std::complex<T> floor_wrapper(std::complex<T> v) {
  const T re = std::floor(v.real());
  const T im = std::floor(v.imag());
  return std::complex<T>(re, im);
}
// CUDA floor kernel entry point.
//
// Dispatches on iter.dtype() (floating types plus Half and BFloat16, under
// the label "floor_cuda") and hands gpu_kernel a functor that maps each
// element `x` to floor_wrapper(x).
void floor_kernel_cuda(TensorIteratorBase& iter) {
  AT_DISPATCH_FLOATING_TYPES_AND2(
      ScalarType::Half,
      ScalarType::BFloat16,
      iter.dtype(),
      "floor_cuda",
      [&]() {
        auto floor_op = [] GPU_LAMBDA(scalar_t x) -> scalar_t {
          return floor_wrapper(x);
        };
        gpu_kernel(iter, floor_op);
      });
}
// Generic reciprocal: returns 1 / a, with the numerator expressed in
// scalar_t so the division is carried out on scalar_t operands.
template <typename scalar_t>
__host__ __device__ static inline scalar_t reciprocal_wrapper(scalar_t a) {
  const scalar_t one = static_cast<scalar_t>(1);
  return one / a;
}
// Reciprocal for c10::complex<T>, with the non-finite inputs special-cased
// (the original code states this is for numpy compatibility):
//   * any NaN component, or both components infinite -> {NaN, NaN}
//   * exactly one component infinite                 -> {0, 0}
//   * otherwise                                      -> (1 + 0i) / v
template <typename T>
__host__ __device__ static inline c10::complex<T> reciprocal_wrapper(c10::complex<T> v) {
  const T re = v.real();
  const T im = v.imag();

  const bool re_is_inf = ::isinf(re);
  const bool im_is_inf = ::isinf(im);
  const bool has_nan = ::isnan(re) || ::isnan(im);

  if (has_nan || (re_is_inf && im_is_inf)) {
    const T qnan = std::numeric_limits<T>::quiet_NaN();
    return c10::complex<T>(qnan, qnan);
  }

  // Reaching here with an infinite component means exactly one is infinite,
  // because the both-infinite case already returned above.
  if (re_is_inf || im_is_inf) {
    return c10::complex<T>(T(0), T(0));
  }

  const c10::complex<T> unit = c10::complex<T>(1.0, 0);
  return unit / v;
}
// CUDA reciprocal kernel entry point.
//
// Dispatches on iter.common_dtype() (floating and complex types plus Half
// and BFloat16, under the label "reciprocal_cuda") and hands gpu_kernel a
// functor that maps each element `x` to reciprocal_wrapper(x).
void reciprocal_kernel_cuda(TensorIteratorBase& iter) {
  AT_DISPATCH_FLOATING_AND_COMPLEX_TYPES_AND2(
      ScalarType::Half,
      ScalarType::BFloat16,
      iter.common_dtype(),
      "reciprocal_cuda",
      [&]() {
        auto reciprocal_op = [] GPU_LAMBDA(scalar_t x) -> scalar_t {
          return reciprocal_wrapper(x);
        };
        gpu_kernel(iter, reciprocal_op);
      });
}
// nearbyint is routed through the nearbyint_wrapper overload set because,
// per the original note, std::nearbyint does not work with std::complex
// types and ROCm.
//
// Generic overload: converts `a` to float, rounds with ::nearbyintf, and
// converts the result back to scalar_t.
template <typename scalar_t>
__host__ __device__ static inline scalar_t nearbyint_wrapper(scalar_t a) {
  const float as_float = static_cast<float>(a);
  const float rounded = ::nearbyintf(as_float);
  return static_cast<scalar_t>(rounded);
}
// double overload: rounds with ::nearbyint directly, without going through
// float.
__host__ __device__ static inline double nearbyint_wrapper(double a) {
  const double rounded = ::nearbyint(a);
  return rounded;
}
// c10::complex<float> overload: rounds the real and imaginary parts
// independently with ::nearbyintf.
__host__ __device__ static inline c10::complex<float> nearbyint_wrapper(c10::complex<float> a) {
  const float re = ::nearbyintf(static_cast<float>(a.real()));
  const float im = ::nearbyintf(static_cast<float>(a.imag()));
  return c10::complex<float>(re, im);
}
// c10::complex<double> overload: rounds the real and imaginary parts
// independently with ::nearbyint. The surrounding pragmas silence
// diagnostic 177 for this definition only.
#pragma push
#pragma diag_suppress 177 // Function was declared but never referenced
__host__ __device__ static inline c10::complex<double> nearbyint_wrapper(c10::complex<double> a) {
  const double re = ::nearbyint(static_cast<double>(a.real()));
  const double im = ::nearbyint(static_cast<double>(a.imag()));
  return c10::complex<double>(re, im);
}
#pragma pop
// CUDA round kernel entry point.
//
// Dispatches on iter.dtype() (floating types plus Half and BFloat16, under
// the label "round_cuda") and hands gpu_kernel a functor that maps each
// element `x` to nearbyint_wrapper(x).
void round_kernel_cuda(TensorIteratorBase& iter) {
  AT_DISPATCH_FLOATING_TYPES_AND2(
      ScalarType::Half,
      ScalarType::BFloat16,
      iter.dtype(),
      "round_cuda",
      [&]() {
        // nearbyint is used rather than std::round on purpose: midway values
        // should go to the nearest even integer.
        auto round_op = [] GPU_LAMBDA(scalar_t x) -> scalar_t {
          return nearbyint_wrapper(x);
        };
        gpu_kernel(iter, round_op);
      });
}
// CUDA kernel entry point for rounding to a given number of decimals.
//
// Parameters:
//   iter     - tensor iterator describing the input/output; its dtype() picks
//              the scalar type (floating types plus Half and BFloat16).
//   decimals - number of decimal places to keep. For decimals >= 0 each
//              element x becomes nearbyint(x * 10^decimals) / 10^decimals;
//              for decimals < 0 it becomes
//              nearbyint(x / 10^|decimals|) * 10^|decimals|.
//
// The scale factor and all intermediate arithmetic use the accumulate type
// (at::acc_type<scalar_t, true>; expected to be float for Half/BFloat16 and
// the type itself for float/double - confirm against ATen/AccumulateType.h).
// Holding 10^decimals in Half overflows to inf from |decimals| = 5 onward
// (Half's largest finite value is 65504), which would turn results into
// inf/NaN. Only the final result is converted back to scalar_t.
void round_decimals_kernel_cuda(TensorIteratorBase& iter, int64_t decimals) {
  AT_DISPATCH_FLOATING_TYPES_AND2(
      ScalarType::Half, ScalarType::BFloat16,
      iter.dtype(), "round_cuda",
      [&]() {
        using accscalar_t = at::acc_type<scalar_t, /*is_cuda=*/true>;

        const bool neg_flag = decimals < 0;
        // Take the magnitude in double instead of negating the int64_t, so
        // the most negative int64_t value does not trigger signed overflow
        // and the by-value parameter is left unmodified.
        const double exponent = neg_flag ? -static_cast<double>(decimals)
                                         : static_cast<double>(decimals);
        const accscalar_t ten_pow_decimals =
            static_cast<accscalar_t>(std::pow(10.0, exponent));

        gpu_kernel(iter, [ten_pow_decimals, neg_flag]GPU_LAMBDA(scalar_t a) -> scalar_t {
          const accscalar_t x = static_cast<accscalar_t>(a);
          // nearbyint rounds midway values to the nearest even integer under
          // the default rounding mode.
          const accscalar_t rounded = neg_flag
              ? std::nearbyint(x / ten_pow_decimals) * ten_pow_decimals
              : std::nearbyint(x * ten_pow_decimals) / ten_pow_decimals;
          return static_cast<scalar_t>(rounded);
        });
      });
}
// trunc is routed through the trunc_wrapper overload set because, per the
// original note, std::trunc does not work with std::complex types and ROCm.
//
// Generic overload: converts `a` to float, truncates toward zero with
// ::truncf, and converts the result back to scalar_t.
template <typename scalar_t>
__host__ __device__ static inline scalar_t trunc_wrapper(scalar_t a) {
  const float as_float = static_cast<float>(a);
  const float truncated = ::truncf(as_float);
  return static_cast<scalar_t>(truncated);
}
// double overload: truncates toward zero with ::trunc directly, without
// going through float.
__host__ __device__ static inline double trunc_wrapper(double a) {
  const double truncated = ::trunc(a);
  return truncated;
}
// c10::complex<float> overload: truncates the real and imaginary parts
// independently with ::truncf.
__host__ __device__ static inline c10::complex<float> trunc_wrapper(c10::complex<float> a) {
  const float re = ::truncf(static_cast<float>(a.real()));
  const float im = ::truncf(static_cast<float>(a.imag()));
  return c10::complex<float>(re, im);
}
// c10::complex<double> overload: truncates the real and imaginary parts
// independently with ::trunc.
__host__ __device__ static inline c10::complex<double> trunc_wrapper(c10::complex<double> a) {
  const double re = ::trunc(static_cast<double>(a.real()));
  const double im = ::trunc(static_cast<double>(a.imag()));
  return c10::complex<double>(re, im);
}
// CUDA trunc kernel entry point.
//
// Dispatches on iter.dtype() (floating types plus Half and BFloat16, under
// the label "trunc_cuda") and hands gpu_kernel a functor that maps each
// element `x` to trunc_wrapper(x).
void trunc_kernel_cuda(TensorIteratorBase& iter) {
  AT_DISPATCH_FLOATING_TYPES_AND2(
      ScalarType::Half,
      ScalarType::BFloat16,
      iter.dtype(),
      "trunc_cuda",
      [&]() {
        auto trunc_op = [] GPU_LAMBDA(scalar_t x) -> scalar_t {
          return trunc_wrapper(x);
        };
        gpu_kernel(iter, trunc_op);
      });
}
// Pair each unary-op stub with the address of the CUDA kernel defined in
// this file. REGISTER_DISPATCH is presumably provided by
// ATen/native/DispatchStub.h (included above) - confirm there for the exact
// registration semantics.
REGISTER_DISPATCH(ceil_stub, &ceil_kernel_cuda);
REGISTER_DISPATCH(frac_stub, &frac_kernel_cuda);
REGISTER_DISPATCH(floor_stub, &floor_kernel_cuda);
REGISTER_DISPATCH(reciprocal_stub, &reciprocal_kernel_cuda);
REGISTER_DISPATCH(round_stub, &round_kernel_cuda);
// Rounding to a given number of decimals has its own stub and kernel.
REGISTER_DISPATCH(round_decimals_stub, &round_decimals_kernel_cuda);
REGISTER_DISPATCH(trunc_stub, &trunc_kernel_cuda);
}} // namespace at::native