// UnaryComplexKernels.cu
#define TORCH_ASSERT_NO_OPERATORS
#include <limits>
#include <ATen/native/UnaryOps.h>
#include <ATen/native/cuda/Loops.cuh>
#include <ATen/Dispatch.h>
#include <ATen/NumericUtils.h>
#include <ATen/native/DispatchStub.h>
#include <ATen/native/TensorIterator.h>

namespace at { namespace native {
// We manually overload angle because std::arg does not work with types other than c10::complex.
template<typename scalar_t>
__host__ __device__ static inline scalar_t angle_wrapper(scalar_t v) {
  if (at::_isnan(v)) {
    return v;
  }
  return v < 0 ? M_PI : 0;
}

template<typename T>
__host__ __device__ static inline c10::complex<T> angle_wrapper(c10::complex<T> v) {
  return std::arg(v);
}

void angle_kernel_cuda(TensorIteratorBase& iter) {
  AT_DISPATCH_FLOATING_AND_COMPLEX_TYPES(iter.common_dtype(), "angle_cuda", [&]() {
    gpu_kernel(iter, []GPU_LAMBDA(scalar_t a) -> scalar_t {
      return angle_wrapper(a);
    });
  });
}
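
// Illustrative semantics (a sketch, not part of the original file): on the
// real line the wrapper computes arg() directly, propagating NaN, while the
// complex overload defers to std::arg. For example:
//   angle_wrapper(-2.0)                            // == M_PI
//   angle_wrapper(3.0)                             // == 0.0
//   angle_wrapper(c10::complex<double>(0.0, 1.0))  // == pi/2 via std::arg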
// We manually overload real because std::real does not work with types other than c10::complex.
template<typename scalar_t>
__host__ __device__ static inline scalar_t real_wrapper(scalar_t v) {
  return v;
}

template<typename T>
__host__ __device__ static inline c10::complex<T> real_wrapper(c10::complex<T> v) {
  return v.real();
}

void real_kernel_cuda(TensorIteratorBase& iter) {
  AT_DISPATCH_ALL_TYPES_AND_COMPLEX(iter.dtype(), "real_cuda", [&]() {
    gpu_kernel(iter, []GPU_LAMBDA(scalar_t a) -> scalar_t {
      return real_wrapper(a);
    });
  });
}
// We manually overload imag because std::imag does not work with types other than c10::complex.
template<typename scalar_t>
__host__ __device__ static inline scalar_t imag_wrapper(scalar_t v) {
  return 0;
}

template<typename T>
__host__ __device__ static inline c10::complex<T> imag_wrapper(c10::complex<T> v) {
  return v.imag();
}

void imag_kernel_cuda(TensorIteratorBase& iter) {
  AT_DISPATCH_ALL_TYPES_AND_COMPLEX(iter.dtype(), "imag_cuda", [&]() {
    gpu_kernel(iter, []GPU_LAMBDA(scalar_t a) -> scalar_t {
      return imag_wrapper(a);
    });
  });
}
// We manually overload conj because std::conj does not work with types other than c10::complex.
template<typename scalar_t>
__host__ __device__ static inline scalar_t conj_wrapper(scalar_t v) {
  return v;
}

template<typename T>
__host__ __device__ static inline c10::complex<T> conj_wrapper(c10::complex<T> v) {
  return std::conj(v);
}

// NB: Ignores the negative bit on tensors
void conj_kernel_cuda(TensorIteratorBase& iter) {
  AT_DISPATCH_ALL_TYPES_AND_COMPLEX_AND3(
      kBool, kBFloat16, kHalf, iter.common_dtype(), "conj_cuda", [&]() {
        gpu_kernel(iter, [] GPU_LAMBDA(scalar_t a) -> scalar_t {
          return conj_wrapper(a);
        });
      });
}
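
// Worked example (illustrative, assuming c10::complex semantics): for real
// scalar types conj_wrapper is the identity, so conj_wrapper(2.5f) == 2.5f;
// for complex inputs it negates the imaginary part:
//   conj_wrapper(c10::complex<float>(1.f, 2.f))  // == (1.f, -2.f)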
REGISTER_DISPATCH(angle_stub, &angle_kernel_cuda);
REGISTER_DISPATCH(real_stub, &real_kernel_cuda);
REGISTER_DISPATCH(imag_stub, &imag_kernel_cuda);
REGISTER_DISPATCH(conj_physical_stub, &conj_kernel_cuda);

}} // namespace at::native
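
// Usage sketch (illustrative only; the tensor shape and values below are
// hypothetical): once the stubs above are registered, a host-side call such as
//
//   #include <ATen/ATen.h>
//   at::Tensor t = at::randn({4},
//       at::TensorOptions().dtype(at::kComplexFloat).device(at::kCUDA));
//   at::Tensor c = at::conj_physical(t);  // routes through conj_kernel_cuda
//
// reaches the CUDA kernels in this file via the DispatchStub machinery.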