Skip to content

Commit

Permalink
Add customized_op.
Browse files Browse the repository at this point in the history
  • Loading branch information
Mansterteddy committed Oct 14, 2021
1 parent dc2c057 commit f8537f7
Show file tree
Hide file tree
Showing 4 changed files with 61 additions and 0 deletions.
27 changes: 27 additions & 0 deletions onnxruntime/customized_op/export.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,27 @@
import torch

class CustomInverse(torch.nn.Module):
def forward(self, x):
return torch.inverse(x) + x

from torch.onnx import register_custom_op_symbolic

def my_inverse(g, input):
return g.op("ai.onnx.contrib::Inverse", input)

register_custom_op_symbolic('::inverse', my_inverse, 1)

x0 = torch.Tensor([[-3.7806, 1.0857, -0.8645],
[-0.0398, 0.3996, -0.7268],
[ 0.3433, 0.6064, 0.0934]])
t_model = CustomInverse()
print(t_model(x0))

torch.onnx.export(t_model,
(x0, ),
"export.onnx",
opset_version=12,
verbose=True,
input_names= ["input"],
output_names= ["output"]
)
3 changes: 3 additions & 0 deletions onnxruntime/customized_op/lib_path.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
from onnxruntime_extensions import get_library_path as _lib_path

print(_lib_path())
17 changes: 17 additions & 0 deletions onnxruntime/customized_op/run_ort.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,17 @@
import onnx
import torch
import numpy as np
import onnxruntime as ort
from onnxruntime_extensions import onnx_op, PyOp, PyOrtFunction

@onnx_op(op_type="Inverse")
def inverse(x):
return np.linalg.inv(x)

x = [[-3.7806, 1.0857, -0.8645], [-0.0398, 0.3996, -0.7268], [ 0.3433, 0.6064, 0.0934]]
x = np.asarray(x, dtype=np.float32)

onnx_model = onnx.load("export.onnx")
onnx_fn = PyOrtFunction.from_model(onnx_model)
y = onnx_fn(x)
print(y)
14 changes: 14 additions & 0 deletions onnxruntime/customized_op/run_torch.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,14 @@
import torch

class CustomInverse(torch.nn.Module):
def forward(self, x):
return torch.inverse(x) + x

x = torch.Tensor([[-3.7806, 1.0857, -0.8645],
[-0.0398, 0.3996, -0.7268],
[ 0.3433, 0.6064, 0.0934]])
model = CustomInverse()
y = model(x)
print(y)
print(x.dtype)
#torch.onnx.export(model, (x), "export_torch.onnx")

0 comments on commit f8537f7

Please sign in to comment.