// UnaryLogKernels.cu (from a fork of pytorch/pytorch)
// This translation unit must not depend on the generated ATen operator headers.
#define TORCH_ASSERT_NO_OPERATORS
#include <limits>
#include <ATen/native/UnaryOps.h>
#include <ATen/native/cuda/Loops.cuh>
#include <ATen/AccumulateType.h>
#include <ATen/Dispatch.h>
#include <ATen/native/cuda/jit_utils.h>
#include <ATen/native/cuda/JitLoops.cuh>
#include <ATen/native/DispatchStub.h>
#include <ATen/native/TensorIterator.h>
#include <ATen/native/cuda/Math.cuh>
namespace at { namespace native {
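
// CUDA implementations of the pointwise logarithm kernels: log, log10, log1p,
// and log2. For complex dtypes, the kernel body is JIT-compiled through the
// jiterator when AT_USE_JITERATOR() is enabled, with a precompiled gpu_kernel
// fallback; real floating types (including Half and BFloat16) always use
// precompiled gpu_kernel lambdas. The kernels are hooked into the operator
// dispatch stubs via REGISTER_DISPATCH at the bottom of the file.
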
const char log_name[] = "log_kernel";
void log_kernel_cuda(TensorIteratorBase& iter) {
  auto common_dtype = iter.common_dtype();
  if (at::isComplexType(common_dtype)) {
#if AT_USE_JITERATOR()
    // Complex path: the kernel body is stringified and JIT-compiled through the
    // jiterator on first use; log_name is the name of the generated kernel.
    static const auto log_string = jiterator_stringify(
        template <typename T> T log_kernel(T x) { return std::log(x); });
    AT_DISPATCH_COMPLEX_TYPES(common_dtype, "log_cuda", [&]() {
      jitted_gpu_kernel<
          /*name=*/log_name,
          /*return_dtype=*/scalar_t,
          /*common_dtype=*/scalar_t,
          /*arity=*/1>(iter, log_string);
    });
#else
    // Jiterator disabled at build time: fall back to a precompiled complex kernel.
    AT_DISPATCH_COMPLEX_TYPES(iter.common_dtype(), "log_cuda", [&]() {
      gpu_kernel(
          iter, [] GPU_LAMBDA(scalar_t a) -> scalar_t { return ::log(a); });
    });
#endif
  } else {
    // Real floating types, including Half and BFloat16.
    AT_DISPATCH_FLOATING_TYPES_AND2(ScalarType::Half, ScalarType::BFloat16, iter.common_dtype(), "log_cuda", [&]() {
      gpu_kernel(iter, []GPU_LAMBDA(scalar_t a) -> scalar_t {
        return ::log(a);
      });
    });
  }
}
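
// log10_kernel_cuda mirrors log_kernel_cuda, substituting std::log10 / ::log10.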
const char log10_name[] = "log10_kernel";
void log10_kernel_cuda(TensorIteratorBase& iter) {
  auto common_dtype = iter.common_dtype();
  if (at::isComplexType(common_dtype)) {
#if AT_USE_JITERATOR()
    static const auto log10_string = jiterator_stringify(
        template <typename T> T log10_kernel(T x) { return std::log10(x); });
    AT_DISPATCH_COMPLEX_TYPES(common_dtype, "log10_cuda", [&]() {
      jitted_gpu_kernel<
          /*name=*/log10_name,
          /*return_dtype=*/scalar_t,
          /*common_dtype=*/scalar_t,
          /*arity=*/1>(iter, log10_string);
    });
#else
    AT_DISPATCH_COMPLEX_TYPES(iter.common_dtype(), "log10_cuda", [&]() {
      gpu_kernel(
          iter, [] GPU_LAMBDA(scalar_t a) -> scalar_t { return ::log10(a); });
    });
#endif
  } else {
    AT_DISPATCH_FLOATING_TYPES_AND2(ScalarType::Half, ScalarType::BFloat16, iter.common_dtype(), "log10_cuda", [&]() {
      gpu_kernel(iter, []GPU_LAMBDA(scalar_t a) -> scalar_t {
        return ::log10(a);
      });
    });
  }
}
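
// log1p is dispatched here only for real floating types (including Half and
// BFloat16); this kernel has no complex branch.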
void log1p_kernel_cuda(TensorIteratorBase& iter) {
  AT_DISPATCH_FLOATING_TYPES_AND2(ScalarType::Half, ScalarType::BFloat16, iter.common_dtype(), "log1p_cuda", [&]() {
    gpu_kernel(iter, []GPU_LAMBDA(scalar_t a) -> scalar_t {
      return ::log1p(a);
    });
  });
}
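
// log2_kernel_cuda likewise mirrors log_kernel_cuda, substituting std::log2 / ::log2.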
const char log2_name[] = "log2_kernel";
void log2_kernel_cuda(TensorIteratorBase& iter) {
  auto common_dtype = iter.common_dtype();
  if (at::isComplexType(common_dtype)) {
#if AT_USE_JITERATOR()
    static const auto log2_string = jiterator_stringify(
        template <typename T> T log2_kernel(T x) { return std::log2(x); });
    AT_DISPATCH_COMPLEX_TYPES(common_dtype, "log2_cuda", [&]() {
      jitted_gpu_kernel<
          /*name=*/log2_name,
          /*return_dtype=*/scalar_t,
          /*common_dtype=*/scalar_t,
          /*arity=*/1>(iter, log2_string);
    });
#else
    AT_DISPATCH_COMPLEX_TYPES(iter.common_dtype(), "log2_cuda", [&]() {
      gpu_kernel(
          iter, [] GPU_LAMBDA(scalar_t a) -> scalar_t { return ::log2(a); });
    });
#endif
  } else {
    AT_DISPATCH_FLOATING_TYPES_AND2(ScalarType::Half, ScalarType::BFloat16, iter.common_dtype(), "log2_cuda", [&]() {
      gpu_kernel(iter, []GPU_LAMBDA(scalar_t a) -> scalar_t {
        return ::log2(a);
      });
    });
  }
}
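
// Register the CUDA kernels with the per-operator dispatch stubs.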
REGISTER_DISPATCH(log_stub, &log_kernel_cuda);
REGISTER_DISPATCH(log10_stub, &log10_kernel_cuda);
REGISTER_DISPATCH(log2_stub, &log2_kernel_cuda);
REGISTER_DISPATCH(log1p_stub, &log1p_kernel_cuda);
}} // namespace at::native