jit_utils.h (forked from pytorch/pytorch)
#pragma once
#include <string>
#include <sstream>
#include <unordered_map>
#include <vector>
#include <c10/util/irange.h>
#include <ATen/jit_macros.h>
#include <ATen/cuda/detail/LazyNVRTC.h>
namespace at { namespace cuda { namespace jit {
enum class BinaryFuncVariant {NoScalar, RhsScalar, LhsScalar};
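// Handle returned by jit_pwise_function below: the compiled NVRTC module
// and the kernel entry point inside it.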
struct NvrtcFunction {
  CUmodule module = CUmodule();
  CUfunction function = nullptr;
};
std::string generate_code(
    int nTensors,
    const std::string& func,
    const std::string& name,
    const std::string& f_input_type,
    const std::string& compute_type,
    const std::string& result_type,
    bool contiguous,
    bool dynamic_casting,
    BinaryFuncVariant scalar_pos,
    c10::SmallVector<std::string>& extra_args_typenames,
    bool vectorized=false,
    int vec_size=0);
NvrtcFunction jit_pwise_function(
    const std::string& code,
    const std::string& kernel_name);
void launch_jitted_pwise_function(
    NvrtcFunction function,
    void* args[],
    const int nBlocks,
    const int kBlockSize);
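// Typical usage of the three functions above (illustrative sketch only;
// "my_op", func_src, and the argument values below are placeholders, not
// part of this API):
//   std::string code = generate_code(
//       /*nTensors=*/2, func_src, "my_op", "float", "float", "float",
//       /*contiguous=*/true, /*dynamic_casting=*/false,
//       BinaryFuncVariant::NoScalar, extra_args_typenames);
//   NvrtcFunction fn = jit_pwise_function(code, "my_op");
//   launch_jitted_pwise_function(fn, kernel_args, nBlocks, kBlockSize);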
template <typename T>
struct delayed_false : std::false_type {
};
// Defines type names
// NOTE: General case is instantiated only for invalid types.
// All the valid types have specialization using the TYPE_NAME_FN
// macro below.
template <typename T>
inline std::string typeName() {
  // We can't use static_assert(false) directly, because the
  // program would fail to compile even if this template is never
  // instantiated, so we use `delayed_false` to make sure the
  // compiler doesn't eagerly fail this assertion.
  static_assert(delayed_false<T>::value, "invalid type for jiterator");
  return "void";
}
#define TYPE_NAME_FN(ctype, name) \
template <> inline std::string typeName<ctype>(){ \
  return std::string(#ctype); \
}
AT_FORALL_SCALAR_TYPES(TYPE_NAME_FN)
#undef TYPE_NAME_FN
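// For example, since TYPE_NAME_FN stringizes ctype, the expansion above
// yields typeName<float>() == "float", typeName<int64_t>() == "int64_t",
// and so on for every type covered by AT_FORALL_SCALAR_TYPES.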
// JIT uses std::complex directly, because nvRTC compiles programs
// with -default-device, so there is no issue like
// "std::sin(complex) is __host__ only".
template <> inline std::string typeName<c10::complex<float>>(){
  return "std::complex<float>";
}
template <> inline std::string typeName<c10::complex<double>>(){
  return "std::complex<double>";
}
template <> inline std::string typeName<at::Half>(){
  return "at::Half";
}
template <> inline std::string typeName<at::BFloat16>(){
  return "at::BFloat16";
}
}}} // namespace at::cuda::jit