Skip to content

Commit

Permalink
Add tf custom op conversion example (#1878)
Browse files Browse the repository at this point in the history
* add tf custom op example

Signed-off-by: hwangdeyu <[email protected]>

* fix name and remove unused code comment

Signed-off-by: hwangdeyu <[email protected]>
Co-authored-by: fatcat-z <[email protected]>
  • Loading branch information
hwangdeyu and fatcat-z authored Mar 14, 2022
1 parent 8f2e84b commit d536f04
Show file tree
Hide file tree
Showing 5 changed files with 101 additions and 2 deletions.
46 changes: 46 additions & 0 deletions examples/tf_custom_op/add_one.cc
Original file line number Diff line number Diff line change
@@ -0,0 +1,46 @@
/*
* SPDX-License-Identifier: Apache-2.0
*/

#include "tensorflow/core/framework/op.h"
#include "tensorflow/core/framework/shape_inference.h"

using namespace tensorflow;


// opregister
REGISTER_OP("AddOne")
.Input("add_one: int32")
.Output("result: int32")
.SetShapeFn([](::tensorflow::shape_inference::InferenceContext *c) {
c->set_output(0, c->input(0));
return Status::OK();
});


// keneldefinition
#include "tensorflow/core/framework/op_kernel.h"

class AddOneOp : public OpKernel {
public:
explicit AddOneOp(OpKernelConstruction* context) : OpKernel(context) {}

void Compute(OpKernelContext* context) override {
// Tensor in input
const Tensor& input_tensor = context->input(0);
auto input = input_tensor.flat<int32>();

// Tensor in output
Tensor* output_tensor = NULL;
OP_REQUIRES_OK(context, context->allocate_output(0, input_tensor.shape(), &output_tensor));
auto output = output_tensor->flat<int32>();

const int N = input.size();
for (int i = 0; i < N; i++) {
output(i) += 1;
}
}
};


REGISTER_KERNEL_BUILDER(Name("AddOne").Device(DEVICE_CPU), AddOneOp);
Binary file added examples/tf_custom_op/add_one.so
Binary file not shown.
53 changes: 53 additions & 0 deletions examples/tf_custom_op/addone_custom_op.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,53 @@
# SPDX-License-Identifier: Apache-2.0


import numpy as np
import tensorflow as tf
import tf2onnx
import onnx
import os
from tf2onnx import utils
from tf2onnx.handler import tf_op
from tf2onnx.tf_loader import tf_placeholder


DIR_PATH = os.path.realpath(os.path.dirname(__file__))
saved_model_path = os.path.join(DIR_PATH, "model.onnx")
tf_library_path = os.path.join(DIR_PATH, "add_one.so")


@tf_op("AddOne", onnx_op="Add")
class AddOne:
@classmethod
def version_1(cls, ctx, node, **kwargs):
node_shape = ctx.get_shape(node.input[0])
const_one = ctx.make_const(utils.make_name("const_one"), np.ones(node_shape, dtype = np.int32)).output[0]
node.input.append(const_one)


with tf.compat.v1.Session() as sess:
x = tf_placeholder(tf.int32, [2, 3], name="input")
AddOne = tf.load_op_library(tf_library_path)
x_ = AddOne.add_one(x)
_ = tf.identity(x_, name="output")

onnx_graph = tf2onnx.tfonnx.process_tf_graph(sess.graph,
input_names=["input:0"],
output_names=["output:0"])
model_proto = onnx_graph.make_model("test")
with open(saved_model_path, "wb") as f:
f.write(model_proto.SerializeToString())

onnx_model = onnx.load(saved_model_path)
onnx.checker.check_model(onnx_model)



## Run the model in ONNXRuntime to verify the result.
import onnxruntime as ort
input = np.arange(6).reshape(2,3).astype(np.int32)
ort_session = ort.InferenceSession(saved_model_path)
ort_inputs = {ort_session.get_inputs()[0].name: input}

ort_outs = ort_session.run(None, ort_inputs)
print("input:", input, "\nort_outs:", ort_outs)
2 changes: 1 addition & 1 deletion tf2onnx/graph.py
Original file line number Diff line number Diff line change
Expand Up @@ -676,7 +676,7 @@ def make_node(self, op_type, inputs, attr=None, output_count=1, outputs=None, sk
return node

def append_node(self, node):
"Add a node to the graph."
"""Add a node to the graph."""
output_shapes = node.output_shapes
output_dtypes = node.output_dtypes
node.graph = self
Expand Down
2 changes: 1 addition & 1 deletion tf2onnx/tfonnx.py
Original file line number Diff line number Diff line change
Expand Up @@ -518,7 +518,7 @@ def process_parsed_graph(g, custom_op_handlers, inputs_as_nchw, continue_on_erro
# or override existing ops with a custom op.
if custom_op_handlers is not None:
# below is a bit tricky since there are a few api's:
# 1. the future way we want custom ops to be registered with the @tf_op decorator. THose handlers will be
# 1. the future way we want custom ops to be registered with the @tf_op decorator. Those handlers will be
# registered via the decorator on load of the module ... nothing is required here.
# 2. the old custom op api: a dictionary of {name: (func, args[])
# We deal with this by using a compat_handler that wraps to old handler with a new style handler.
Expand Down

0 comments on commit d536f04

Please sign in to comment.