-
Notifications
You must be signed in to change notification settings - Fork 434
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Add tf custom op conversion example #1878
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,46 @@ | ||
/* | ||
* SPDX-License-Identifier: Apache-2.0 | ||
*/ | ||
|
||
#include "tensorflow/core/framework/op.h" | ||
#include "tensorflow/core/framework/shape_inference.h" | ||
|
||
using namespace tensorflow; | ||
|
||
|
||
// opregister | ||
REGISTER_OP("AddOne") | ||
.Input("add_one: int32") | ||
.Output("result: int32") | ||
.SetShapeFn([](::tensorflow::shape_inference::InferenceContext *c) { | ||
c->set_output(0, c->input(0)); | ||
return Status::OK(); | ||
}); | ||
|
||
|
||
// keneldefinition | ||
#include "tensorflow/core/framework/op_kernel.h" | ||
|
||
class AddOneOp : public OpKernel { | ||
public: | ||
explicit AddOneOp(OpKernelConstruction* context) : OpKernel(context) {} | ||
|
||
void Compute(OpKernelContext* context) override { | ||
// Tensor in input | ||
const Tensor& input_tensor = context->input(0); | ||
auto input = input_tensor.flat<int32>(); | ||
|
||
// Tensor in output | ||
Tensor* output_tensor = NULL; | ||
OP_REQUIRES_OK(context, context->allocate_output(0, input_tensor.shape(), &output_tensor)); | ||
auto output = output_tensor->flat<int32>(); | ||
|
||
const int N = input.size(); | ||
for (int i = 0; i < N; i++) { | ||
output(i) += 1; | ||
} | ||
} | ||
}; | ||
|
||
|
||
REGISTER_KERNEL_BUILDER(Name("AddOne").Device(DEVICE_CPU), AddOneOp); |
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,53 @@ | ||
# SPDX-License-Identifier: Apache-2.0 | ||
|
||
|
||
import numpy as np | ||
import tensorflow as tf | ||
import tf2onnx | ||
import onnx | ||
import os | ||
from tf2onnx import utils | ||
from tf2onnx.handler import tf_op | ||
from tf2onnx.tf_loader import tf_placeholder | ||
|
||
|
||
DIR_PATH = os.path.realpath(os.path.dirname(__file__)) | ||
saved_model_path = os.path.join(DIR_PATH, "model.onnx") | ||
tf_library_path = os.path.join(DIR_PATH, "add_one.so") | ||
|
||
|
||
@tf_op("AddOne", onnx_op="Add") | ||
class AddOne: | ||
@classmethod | ||
def version_1(cls, ctx, node, **kwargs): | ||
node_shape = ctx.get_shape(node.input[0]) | ||
const_one = ctx.make_const(utils.make_name("const_one"), np.ones(node_shape, dtype = np.int32)).output[0] | ||
node.input.append(const_one) | ||
|
||
|
||
with tf.compat.v1.Session() as sess: | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. don't use TF1 !! |
||
x = tf_placeholder(tf.int32, [2, 3], name="input") | ||
AddOne = tf.load_op_library(tf_library_path) | ||
x_ = AddOne.add_one(x) | ||
_ = tf.identity(x_, name="output") | ||
|
||
onnx_graph = tf2onnx.tfonnx.process_tf_graph(sess.graph, | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. don't use the OLD api !!! There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. will change a new example and use tf2 api |
||
input_names=["input:0"], | ||
output_names=["output:0"]) | ||
model_proto = onnx_graph.make_model("test") | ||
with open(saved_model_path, "wb") as f: | ||
f.write(model_proto.SerializeToString()) | ||
|
||
onnx_model = onnx.load(saved_model_path) | ||
onnx.checker.check_model(onnx_model) | ||
|
||
|
||
|
||
## Run the model in ONNXRuntime to verify the result. | ||
import onnxruntime as ort | ||
input = np.arange(6).reshape(2,3).astype(np.int32) | ||
ort_session = ort.InferenceSession(saved_model_path) | ||
ort_inputs = {ort_session.get_inputs()[0].name: input} | ||
|
||
ort_outs = ort_session.run(None, ort_inputs) | ||
print("input:", input, "\nort_outs:", ort_outs) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I don't understand the purpose of this example.
A tf2onnx user will already have the tensorflow custom op, aka this code here he already has it.
Should the example not be AddOne in onnxruntime ?