-
Notifications
You must be signed in to change notification settings - Fork 99
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Add support for registring custom op libraries
- Loading branch information
1 parent
1871703
commit 668a0d3
Showing
9 changed files
with
286 additions
and
1 deletion.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,9 @@ | ||
* | ||
!/Cargo.* | ||
!/onnxruntime/Cargo.toml | ||
!/onnxruntime/src | ||
!/onnxruntime/tests | ||
!/onnxruntime-sys/Cargo.toml | ||
!/onnxruntime-sys/build.rs | ||
!/onnxruntime-sys/src | ||
!/test-models/tensorflow/*.onnx |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,118 @@ | ||
# onnxruntime requires execinfo.h to build, which only works on glibc-based systems, so alpine is out... | ||
FROM debian:bullseye-slim as base | ||
|
||
RUN apt-get update && apt-get -y dist-upgrade | ||
|
||
FROM base AS onnxruntime | ||
|
||
RUN apt-get install -y \ | ||
git \ | ||
bash \ | ||
python3 \ | ||
cmake \ | ||
git \ | ||
build-essential \ | ||
llvm \ | ||
locales | ||
|
||
# onnxruntime built in tests need en_US.UTF-8 available | ||
# Uncomment en_US.UTF-8, then generate | ||
RUN sed -i 's/^# *\(en_US.UTF-8\)/\1/' /etc/locale.gen && locale-gen | ||
|
||
# build onnxruntime | ||
RUN mkdir -p /opt/onnxruntime/tmp | ||
# onnxruntime build relies on being in a git repo, so can't just get a tarball | ||
# it's a big repo, so fetch shallowly | ||
RUN cd /opt/onnxruntime/tmp && \ | ||
git clone --recursive --depth 1 --shallow-submodules https://github.com/Microsoft/onnxruntime | ||
|
||
# use version that onnxruntime-sys expects | ||
RUN cd /opt/onnxruntime/tmp/onnxruntime && \ | ||
git fetch --depth 1 origin tag v1.6.0 && \ | ||
git checkout v1.6.0 | ||
|
||
RUN /opt/onnxruntime/tmp/onnxruntime/build.sh --config RelWithDebInfo --build_shared_lib --parallel | ||
|
||
# Build ort-customops, linked against the onnxruntime built above. | ||
# No tags / releases yet - that commit is from 2021-02-16 | ||
RUN mkdir -p /opt/ort-customops/tmp && \ | ||
cd /opt/ort-customops/tmp && \ | ||
git clone --recursive https://github.com/microsoft/ort-customops.git && \ | ||
cd ort-customops && \ | ||
git checkout 92f6b51106c9e9143c452e537cb5e41d2dcaa266 | ||
|
||
RUN cd /opt/ort-customops/tmp/ort-customops && \ | ||
./build.sh -D ONNXRUNTIME_LIB_DIR=/opt/onnxruntime/tmp/onnxruntime/build/Linux/RelWithDebInfo | ||
|
||
|
||
# install rust toolchain | ||
FROM base AS rust-toolchain | ||
|
||
ARG RUST_VERSION=1.50.0 | ||
|
||
RUN apt-get install -y \ | ||
curl | ||
|
||
# install rust toolchain | ||
RUN curl --proto '=https' --tlsv1.2 -sSf https://sh.rustup.rs | sh -s -- -y --default-toolchain $RUST_VERSION | ||
|
||
ENV PATH $PATH:/root/.cargo/bin | ||
|
||
|
||
# build onnxruntime-rs | ||
FROM rust-toolchain as onnxruntime-rs | ||
# clang & llvm needed by onnxruntime-sys | ||
RUN apt-get install -y \ | ||
build-essential \ | ||
llvm-dev \ | ||
libclang-dev \ | ||
clang | ||
|
||
RUN mkdir -p \ | ||
/onnxruntime-rs/build/onnxruntime-sys/src/ \ | ||
/onnxruntime-rs/build/onnxruntime/src/ \ | ||
/onnxruntime-rs/build/onnxruntime/tests/ \ | ||
/opt/onnxruntime/lib \ | ||
/opt/ort-customops/lib | ||
|
||
COPY --from=onnxruntime /opt/onnxruntime/tmp/onnxruntime/build/Linux/RelWithDebInfo/libonnxruntime.so /opt/onnxruntime/lib/ | ||
COPY --from=onnxruntime /opt/ort-customops/tmp/ort-customops/out/Linux/libortcustomops.so /opt/ort-customops/lib/ | ||
|
||
WORKDIR /onnxruntime-rs/build | ||
|
||
ENV ORT_STRATEGY=system | ||
# this has /lib/ appended to it and is used as a lib search path in onnxruntime-sys's build.rs | ||
ENV ORT_LIB_LOCATION=/opt/onnxruntime/ | ||
|
||
ENV ONNXRUNTIME_RS_TEST_ORT_CUSTOMOPS_LIB=/opt/ort-customops/lib/libortcustomops.so | ||
|
||
# create enough of an empty project that dependencies can build | ||
COPY /Cargo.lock /Cargo.toml /onnxruntime-rs/build/ | ||
COPY /onnxruntime/Cargo.toml /onnxruntime-rs/build/onnxruntime/ | ||
COPY /onnxruntime-sys/Cargo.toml /onnxruntime-sys/build.rs /onnxruntime-rs/build/onnxruntime-sys/ | ||
|
||
CMD cargo test | ||
|
||
# build dependencies and clean the bogus contents of our two packages | ||
RUN touch \ | ||
onnxruntime/src/lib.rs \ | ||
onnxruntime/tests/integration_tests.rs \ | ||
onnxruntime-sys/src/lib.rs \ | ||
&& cargo build --tests \ | ||
&& cargo clean --package onnxruntime-sys \ | ||
&& cargo clean --package onnxruntime \ | ||
&& rm -rf \ | ||
onnxruntime/src/ \ | ||
onnxruntime/tests/ \ | ||
onnxruntime-sys/src/ | ||
|
||
# now build the actual source | ||
COPY /test-models test-models | ||
COPY /onnxruntime-sys/src onnxruntime-sys/src | ||
COPY /onnxruntime/src onnxruntime/src | ||
COPY /onnxruntime/tests onnxruntime/tests | ||
|
||
RUN ln -s /opt/onnxruntime/lib/libonnxruntime.so /opt/onnxruntime/lib/libonnxruntime.so.1.6.0 | ||
ENV LD_LIBRARY_PATH=/opt/onnxruntime/lib | ||
|
||
RUN cargo build --tests |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,51 @@ | ||
use std::error::Error; | ||
|
||
use ndarray; | ||
use onnxruntime::tensor::{DynOrtTensor, OrtOwnedTensor}; | ||
use onnxruntime::{environment::Environment, LoggingLevel}; | ||
|
||
#[test] | ||
fn run_model_with_ort_customops() -> Result<(), Box<dyn Error>> { | ||
let lib_path = match std::env::var("ONNXRUNTIME_RS_TEST_ORT_CUSTOMOPS_LIB") { | ||
Ok(s) => s, | ||
Err(_e) => { | ||
println!("Skipping ort_customops test -- no lib specified"); | ||
return Ok(()); | ||
} | ||
}; | ||
|
||
let environment = Environment::builder() | ||
.with_name("test") | ||
.with_log_level(LoggingLevel::Verbose) | ||
.build()?; | ||
|
||
let mut session = environment | ||
.new_session_builder()? | ||
.with_custom_op_lib(&lib_path)? | ||
.with_model_from_file("../test-models/tensorflow/regex_model.onnx")?; | ||
|
||
//Inputs: | ||
// 0: | ||
// name = input_1:0 | ||
// type = String | ||
// dimensions = [None] | ||
// Outputs: | ||
// 0: | ||
// name = Identity:0 | ||
// type = String | ||
// dimensions = [None] | ||
|
||
let array = ndarray::Array::from(vec![String::from("Hello world!")]); | ||
let input_tensor_values = vec![array]; | ||
|
||
let outputs: Vec<DynOrtTensor<_>> = session.run(input_tensor_values)?; | ||
let strings: OrtOwnedTensor<String, _> = outputs[0].try_extract()?; | ||
|
||
// ' ' replaced with '_' | ||
assert_eq!( | ||
&[String::from("Hello_world!")], | ||
strings.view().as_slice().unwrap() | ||
); | ||
|
||
Ok(()) | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,19 @@ | ||
tf2onnx1.9.0:� | ||
| ||
input_1:0 | ||
|
||
pattern__7 | ||
|
||
rewrite__8 | ||
Identity:0)PartitionedCall/model1/StaticRegexReplace"StringRegexReplace:ai.onnx.contribtf2onnx*2_B | ||
rewrite__8*2 B | ||
pattern__7R!converted from models/regex_modelZ | ||
input_1:0 | ||
|
||
unk__9b | ||
|
||
Identity:0 | ||
unk__10B B | ||
ai.onnx.contrib |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,19 @@ | ||
import tensorflow as tf | ||
import numpy as np | ||
import tf2onnx | ||
|
||
|
||
class RegexModel(tf.keras.Model): | ||
|
||
def __init__(self, name='model1', **kwargs): | ||
super(RegexModel, self).__init__(name=name, **kwargs) | ||
|
||
def call(self, inputs): | ||
return tf.strings.regex_replace(inputs, " ", "_", replace_global=True) | ||
|
||
|
||
model1 = RegexModel() | ||
|
||
print(model1(tf.constant(["Hello world!"]))) | ||
|
||
model1.save("models/regex_model") |