-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathtensorflow_v1_print.cc
124 lines (109 loc) · 4.34 KB
/
tensorflow_v1_print.cc
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
// Example usage:
xyz_scatter = tfprint(xyz_scatter, 'xyz_scatter', 'first_xyzscatter', True)
type_embedding = tfprint(type_embedding, 'type_embedding', 'type_embedding_before_embedding_net', False)
// Print a tensor easily.
// Usage: op_module.print_identity(<tensor to print>, msg="", filename="", hash=<bool>)
// op_module: the module that contains the print_identity op. i.e. the return value of tf.load_op_library(str(module_file)).
// msg: tensors will be printed with a prefix `Tensor(msg): `.
// filename: tensors are appended to a file named `<filename>.tensorout`. If filename is empty, tensors are printed to stdout.
// hash: whether to print the hash of the tensor. It is useful when you want to compare two extremely large tensors.
// The static shape of a tensor may be lost after this op, so it is recommended to reshape the tensor after printing:
// _shape = t.shape
// t = op_module.print_identity(t, msg="t", filename="t", hash=True)
// t = tf.reshape(t, _shape)
def tfprint(tensor, msg, filename, hash):
    """Route ``tensor`` through the ``print_identity`` op and return it with its shape restored.

    The shape is read with ``tf.shape`` before the op is applied and the op's
    result is reshaped back to it afterwards.

    Args:
        tensor: tensor to print.
        msg: label forwarded to ``print_identity``.
        filename: output file stem forwarded to ``print_identity``.
        hash: flag forwarded to ``print_identity``.

    Returns:
        The result of ``print_identity`` reshaped to the shape of ``tensor``.
    """
    original_shape = tf.shape(tensor)
    printed = op_module.print_identity(tensor, msg, filename, hash)
    return tf.reshape(printed, original_shape)
//////////////////////////////////////////////////
// python file:
@ops.RegisterGradient("PrintIdentity")
def _print_identity_cc(op, dy):
    """Gradient registered for the ``PrintIdentity`` op.

    The incoming gradient ``dy`` is returned unchanged; ``op`` is not used.
    """
    return dy
// c++ op file:
#include <cstddef>
#include <cstring>
#include <fstream>
#include <functional>
#include <iomanip>
#include <iostream>
#include <locale>
#include <sstream>
#include <string>
// Registers the "PrintIdentity" op.
//   T        : element type of the tensor; one of float, double, int32.
//   data     : the single input, of type T.
//   msg      : string attribute, used by the kernel as the label in the "Tensor(<msg>): " prefix.
//   filename : string attribute; the kernel appends to "<filename>.tensorout", or prints to stdout when empty.
//   hash     : bool attribute; when true the kernel prints a hash of the values instead of the values.
//   out      : the single output, of type T.
// NOTE(review): no shape function (.SetShapeFn) is attached here; presumably this is why the
// usage notes warn that the static shape may be lost after this op -- confirm.
// NOTE(review): the Python helper passes msg, filename, hash positionally, so the declaration
// order of these attributes presumably matters to callers -- confirm before reordering.
REGISTER_OP("PrintIdentity")
.Attr("T: {float, double, int32}")
.Input("data: T")
.Attr("msg: string")
.Attr("filename: string")
.Attr("hash: bool")
.Output("out: T");
// Order-sensitive hash of a numeric array, meant for cheaply comparing large tensors.
//
// Each element is formatted in fixed notation with 6 digits after the decimal
// point (the stream rounds to that precision; integer elements are formatted as
// plain integers), the resulting text is hashed with std::hash<std::string>, and
// the per-element hashes are folded into the running seed.
//
// Formatting is pinned to the classic "C" locale so the result does not depend on
// whatever global C++ locale the host process has installed (e.g. a ',' decimal
// point or digit grouping would otherwise change the text being hashed).
//
// It is recommended to use `fmtlib` to convert float to string if it is available:
//   #include <fmt/core.h>
//   std::string s = fmt::format("{:.6f}", arr[i]);
//
// arr  : pointer to the first element; not dereferenced when size is 0.
// size : number of elements to hash.
// Returns 0 for an empty array. The value depends on the standard library's
// std::hash implementation, so only compare hashes produced by the same build.
template<typename T>
std::size_t hash_of(const T* const arr, std::size_t size) {
  std::size_t seed = 0;

  std::stringstream ss;
  ss.imbue(std::locale::classic());
  ss.precision(6);
  // Format flags survive ss.str(""), so fixed notation only needs to be set once.
  ss << std::fixed;

  const std::hash<std::string> hasher{};
  for (std::size_t i = 0; i < size; i++) {
    ss << arr[i];
    // Fold this element's hash into the seed (golden-ratio constant plus shifted seed).
    seed ^= hasher(ss.str()) + 0x9e3779b9 + (seed << 6) + (seed >> 2);
    // Empty the buffer so the next element is hashed on its own text.
    ss.str("");
  }
  return seed;
}
template <typename Device, typename FPTYPE>
class PrintIdentityOp : public OpKernel {
public:
explicit PrintIdentityOp(OpKernelConstruction* context) : OpKernel(context) {
OP_REQUIRES_OK(context,
context->GetAttr("msg", &msg));
OP_REQUIRES_OK(context,
context->GetAttr("filename", &filename));
OP_REQUIRES_OK(context,
context->GetAttr("hash", &hash));
}
void Compute(OpKernelContext* context) override {
const Tensor& input_tensor = context->input(0);
auto input = input_tensor.flat<FPTYPE>();
int precision = 10;
if (!filename.empty()) {
std::ofstream fs(filename + ".tensorout", std::ios::app);
fs << std::setprecision(precision);
fs << "Tensor(" << msg << "): ";
if (hash) {
fs << hash_of(input.data(), input.size());
} else {
for (int i = 0; i < input.size(); ++i) {
fs << input(i) << " ";
}
}
fs << "\n";
fs.close();
} else {
std::cout << std::setprecision(precision);
std::cout << "Tensor(" << msg << "): ";
if (hash) {
std::cout << hash_of(input.data(), input.size());
} else {
for (int i = 0; i < input.size(); ++i) {
std::cout << input(i) << " ";
}
}
std::cout << "\n";
}
// 将输入的Tensor作为输出返回
Tensor* output_tensor = nullptr;
OP_REQUIRES_OK(context, context->allocate_output(0, input_tensor.shape(), &output_tensor));
auto output = output_tensor->flat<FPTYPE>();
memcpy(output.data(), input.data(), sizeof(FPTYPE) * input.size());
}
private:
std::string msg;
std::string filename;
bool hash;
};
// Registers the CPU kernel of "PrintIdentity" for element type T, binding the
// op's "T" type attribute to PrintIdentityOp<CPUDevice, T>.
#define REGISTER_CPU(T)                                                 \
  REGISTER_KERNEL_BUILDER(                                              \
      Name("PrintIdentity").Device(DEVICE_CPU).TypeConstraint<T>("T"),  \
      PrintIdentityOp<CPUDevice, T>);
// One registration per element type.
REGISTER_CPU(float);
REGISTER_CPU(double);
REGISTER_CPU(int);