Add tf custom op conversion example #1878
Changes from 1 commit
@@ -0,0 +1,55 @@
/*
 * SPDX-License-Identifier: Apache-2.0
 */

#include "tensorflow/core/framework/op.h"
#include "tensorflow/core/framework/shape_inference.h"

using namespace tensorflow;

// Op registration: AddOne takes one int32 input and produces one int32
// output whose shape matches the input.
REGISTER_OP("AddOne")
    .Input("add_one: int32")
    .Output("oneed: int32")
    .SetShapeFn([](::tensorflow::shape_inference::InferenceContext *c) {
      c->set_output(0, c->input(0));
      return Status::OK();
    });

// Kernel definition
#include "tensorflow/core/framework/op_kernel.h"

// Forward declaration of the CUDA launcher (defined in a separate .cu file).
void AddOneKernelLauncher(const Tensor* t_in, const int n, Tensor* t_out);

class AddOneOp : public OpKernel {
 public:
  explicit AddOneOp(OpKernelConstruction* context) : OpKernel(context) {}

  void Compute(OpKernelContext* context) override {
    // Grab the input tensor.
    const Tensor& input_tensor = context->input(0);
    auto input = input_tensor.flat<int32>();

    // Allocate the output tensor with the same shape as the input.
    Tensor* output_tensor = NULL;
    OP_REQUIRES_OK(context, context->allocate_output(0, input_tensor.shape(), &output_tensor));
    auto output = output_tensor->flat<int32>();

    const int N = input.size();
#if GOOGLE_CUDA
    AddOneKernelLauncher(&input_tensor, N, output_tensor);
#else
    // Add one to every element of the input.
    for (int i = 0; i < N; i++) {
      output(i) = input(i) + 1;
    }
#endif
  }
};

REGISTER_KERNEL_BUILDER(Name("AddOne").Device(DEVICE_CPU), AddOneOp);
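
The diff does not show a build step, but `tf.load_op_library` in the script below expects the op compiled into `add_one.so` against the installed TensorFlow. A minimal sketch for generating the build command, assuming the C++ file above is saved as `add_one.cc` (that file name is an assumption, not shown in the diff):

```python
# Prints a g++ command line for building add_one.so from add_one.cc.
# add_one.cc is an assumed name for the C++ file above.
import tensorflow as tf

cflags = " ".join(tf.sysconfig.get_compile_flags())
lflags = " ".join(tf.sysconfig.get_link_flags())
print(f"g++ -std=c++14 -shared add_one.cc -o add_one.so -fPIC {cflags} {lflags} -O2")
```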
@@ -0,0 +1,47 @@
# SPDX-License-Identifier: Apache-2.0

import os

import numpy as np
import onnx
import tensorflow as tf
import tf2onnx
from tf2onnx import utils
from tf2onnx.handler import tf_op

DIR_PATH = os.path.realpath(os.path.dirname(__file__))
saved_model_path = os.path.join(DIR_PATH, "model.onnx")
tf_library_path = os.path.join(DIR_PATH, "add_one.so")


# Conversion handler: map the TF custom op "AddOne" to the standard ONNX
# "Add" op by appending a constant tensor of ones as the second input.
@tf_op("AddOne", onnx_op="Add")
class AddOne:
    @classmethod
    def version_1(cls, ctx, node, **kwargs):
        node_shape = ctx.get_shape(node.input[0])
        const_one = ctx.make_const(utils.make_name("const_one"), np.ones(node_shape, dtype=np.int32)).output[0]
        node.input.append(const_one)


tf.compat.v1.disable_eager_execution()

with tf.compat.v1.Session() as sess:
    # Review comment: don't use TF1 !!
    x = tf.compat.v1.placeholder(tf.int32, [2, 3], name="input")
    # Load the TF library; use the "--load_op_libraries" parameter for command-line conversion.
    AddOne = tf.load_op_library(tf_library_path)
    x_ = AddOne.add_one(x)
    _ = tf.identity(x_, name="output")

    # Review comment: don't use the OLD api !!!
    # Author reply: will change a new example and use tf2 api
    onnx_graph = tf2onnx.tfonnx.process_tf_graph(sess.graph,
                                                 input_names=["input:0"],
                                                 output_names=["output:0"])
    model_proto = onnx_graph.make_model("test")
    with open(saved_model_path, "wb") as f:
        f.write(model_proto.SerializeToString())

onnx_model = onnx.load(saved_model_path)
onnx.checker.check_model(onnx_model)


# Run the converted model with onnxruntime.
import onnxruntime as ort

input = np.arange(6).reshape(2, 3).astype(np.int32)
ort_session = ort.InferenceSession(saved_model_path)
ort_inputs = {ort_session.get_inputs()[0].name: input}

ort_outs = ort_session.run(None, ort_inputs)
print("ort_outs:", ort_outs)
Review comment:
I don't understand the purpose of this example. A tf2onnx user will already have the TensorFlow custom op; that is, they already have the code shown here. Should the example not be AddOne in onnxruntime?
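
Reviewers asked for the TF2 API instead of TF1 sessions, and the author replied that the example would be reworked. A sketch of what that TF2 version might look like, using `tf2onnx.convert.from_function`; this is an assumption about the rework, not code from the PR, and it presumes `add_one.so` has been built and the `@tf_op("AddOne", ...)` handler above has been registered by importing it:

```python
# Standalone sketch (assumed rework, not the PR's code): TF2-style conversion.
# Requires eager execution (do not call disable_eager_execution here) and the
# AddOne handler class from the script above to be imported first.
import os

import tensorflow as tf
from tf2onnx import convert

DIR_PATH = os.path.realpath(os.path.dirname(__file__))
tf_library_path = os.path.join(DIR_PATH, "add_one.so")
add_one_module = tf.load_op_library(tf_library_path)


@tf.function
def add_one_fn(x):
    return tf.identity(add_one_module.add_one(x), name="output")


spec = [tf.TensorSpec([2, 3], tf.int32, name="input")]
model_proto, _ = convert.from_function(
    add_one_fn,
    input_signature=spec,
    opset=13,
    output_path=os.path.join(DIR_PATH, "model_tf2.onnx"))
```

`from_function` traces the function with the given signature and runs the registered conversion handlers, so the AddOne node is rewritten to Add just as in the session-based script.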