From 668a0d3a6164d45383428fb60fa7919d78ab4b34 Mon Sep 17 00:00:00 2001 From: Marshall Pierce Date: Thu, 18 Feb 2021 14:03:47 -0700 Subject: [PATCH] Add support for registering custom op libraries --- .dockerignore | 9 ++ CHANGELOG.md | 1 + Dockerfile | 118 ++++++++++++++++++++++ onnxruntime/Cargo.toml | 6 ++ onnxruntime/src/session.rs | 55 +++++++++- onnxruntime/tests/custom_ops.rs | 51 ++++++++++ test-models/tensorflow/README.md | 9 ++ test-models/tensorflow/regex_model.onnx | 19 ++++ test-models/tensorflow/src/regex_model.py | 19 ++++ 9 files changed, 286 insertions(+), 1 deletion(-) create mode 100644 .dockerignore create mode 100644 Dockerfile create mode 100644 onnxruntime/tests/custom_ops.rs create mode 100644 test-models/tensorflow/regex_model.onnx create mode 100644 test-models/tensorflow/src/regex_model.py diff --git a/.dockerignore b/.dockerignore new file mode 100644 index 00000000..1b6f6f4f --- /dev/null +++ b/.dockerignore @@ -0,0 +1,9 @@ +* +!/Cargo.* +!/onnxruntime/Cargo.toml +!/onnxruntime/src +!/onnxruntime/tests +!/onnxruntime-sys/Cargo.toml +!/onnxruntime-sys/build.rs +!/onnxruntime-sys/src +!/test-models/tensorflow/*.onnx diff --git a/CHANGELOG.md b/CHANGELOG.md index 24aef3bb..872cdc56 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -10,6 +10,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 ### Added - Add `String` datatype ([#58](https://github.com/nbigaouette/onnxruntime-rs/pull/58)) +- Support custom operator libraries ## [0.0.11] - 2021-02-22 diff --git a/Dockerfile b/Dockerfile new file mode 100644 index 00000000..4b8eace7 --- /dev/null +++ b/Dockerfile @@ -0,0 +1,118 @@ +# onnxruntime requires execinfo.h to build, which only works on glibc-based systems, so alpine is out... 
+FROM debian:bullseye-slim as base + +RUN apt-get update && apt-get -y dist-upgrade + +FROM base AS onnxruntime + +RUN apt-get install -y \ + git \ + bash \ + python3 \ + cmake \ + git \ + build-essential \ + llvm \ + locales + +# onnxruntime built in tests need en_US.UTF-8 available +# Uncomment en_US.UTF-8, then generate +RUN sed -i 's/^# *\(en_US.UTF-8\)/\1/' /etc/locale.gen && locale-gen + +# build onnxruntime +RUN mkdir -p /opt/onnxruntime/tmp +# onnxruntime build relies on being in a git repo, so can't just get a tarball +# it's a big repo, so fetch shallowly +RUN cd /opt/onnxruntime/tmp && \ + git clone --recursive --depth 1 --shallow-submodules https://github.com/Microsoft/onnxruntime + +# use version that onnxruntime-sys expects +RUN cd /opt/onnxruntime/tmp/onnxruntime && \ + git fetch --depth 1 origin tag v1.6.0 && \ + git checkout v1.6.0 + +RUN /opt/onnxruntime/tmp/onnxruntime/build.sh --config RelWithDebInfo --build_shared_lib --parallel + +# Build ort-customops, linked against the onnxruntime built above. 
+# No tags / releases yet - that commit is from 2021-02-16 +RUN mkdir -p /opt/ort-customops/tmp && \ + cd /opt/ort-customops/tmp && \ + git clone --recursive https://github.com/microsoft/ort-customops.git && \ + cd ort-customops && \ + git checkout 92f6b51106c9e9143c452e537cb5e41d2dcaa266 + +RUN cd /opt/ort-customops/tmp/ort-customops && \ + ./build.sh -D ONNXRUNTIME_LIB_DIR=/opt/onnxruntime/tmp/onnxruntime/build/Linux/RelWithDebInfo + + +# install rust toolchain +FROM base AS rust-toolchain + +ARG RUST_VERSION=1.50.0 + +RUN apt-get install -y \ + curl + +# install rust toolchain +RUN curl --proto '=https' --tlsv1.2 -sSf https://sh.rustup.rs | sh -s -- -y --default-toolchain $RUST_VERSION + +ENV PATH $PATH:/root/.cargo/bin + + +# build onnxruntime-rs +FROM rust-toolchain as onnxruntime-rs +# clang & llvm needed by onnxruntime-sys +RUN apt-get install -y \ + build-essential \ + llvm-dev \ + libclang-dev \ + clang + +RUN mkdir -p \ + /onnxruntime-rs/build/onnxruntime-sys/src/ \ + /onnxruntime-rs/build/onnxruntime/src/ \ + /onnxruntime-rs/build/onnxruntime/tests/ \ + /opt/onnxruntime/lib \ + /opt/ort-customops/lib + +COPY --from=onnxruntime /opt/onnxruntime/tmp/onnxruntime/build/Linux/RelWithDebInfo/libonnxruntime.so /opt/onnxruntime/lib/ +COPY --from=onnxruntime /opt/ort-customops/tmp/ort-customops/out/Linux/libortcustomops.so /opt/ort-customops/lib/ + +WORKDIR /onnxruntime-rs/build + +ENV ORT_STRATEGY=system +# this has /lib/ appended to it and is used as a lib search path in onnxruntime-sys's build.rs +ENV ORT_LIB_LOCATION=/opt/onnxruntime/ + +ENV ONNXRUNTIME_RS_TEST_ORT_CUSTOMOPS_LIB=/opt/ort-customops/lib/libortcustomops.so + +# create enough of an empty project that dependencies can build +COPY /Cargo.lock /Cargo.toml /onnxruntime-rs/build/ +COPY /onnxruntime/Cargo.toml /onnxruntime-rs/build/onnxruntime/ +COPY /onnxruntime-sys/Cargo.toml /onnxruntime-sys/build.rs /onnxruntime-rs/build/onnxruntime-sys/ + +CMD cargo test + +# build dependencies and clean the bogus 
contents of our two packages +RUN touch \ + onnxruntime/src/lib.rs \ + onnxruntime/tests/integration_tests.rs \ + onnxruntime-sys/src/lib.rs \ + && cargo build --tests \ + && cargo clean --package onnxruntime-sys \ + && cargo clean --package onnxruntime \ + && rm -rf \ + onnxruntime/src/ \ + onnxruntime/tests/ \ + onnxruntime-sys/src/ + +# now build the actual source +COPY /test-models test-models +COPY /onnxruntime-sys/src onnxruntime-sys/src +COPY /onnxruntime/src onnxruntime/src +COPY /onnxruntime/tests onnxruntime/tests + +RUN ln -s /opt/onnxruntime/lib/libonnxruntime.so /opt/onnxruntime/lib/libonnxruntime.so.1.6.0 +ENV LD_LIBRARY_PATH=/opt/onnxruntime/lib + +RUN cargo build --tests diff --git a/onnxruntime/Cargo.toml b/onnxruntime/Cargo.toml index 9ceec820..88a0114c 100644 --- a/onnxruntime/Cargo.toml +++ b/onnxruntime/Cargo.toml @@ -26,6 +26,12 @@ ndarray = "0.13" thiserror = "1.0" tracing = "0.1" +[target.'cfg(unix)'.dependencies] +libc = "0.2.88" + +[target.'cfg(windows)'.dependencies] +winapi = { version = "0.3.9", features = ["std"] } + # Enabled with 'model-fetching' feature ureq = {version = "1.5.1", optional = true} diff --git a/onnxruntime/src/session.rs b/onnxruntime/src/session.rs index 98099fd6..0e59ef6a 100644 --- a/onnxruntime/src/session.rs +++ b/onnxruntime/src/session.rs @@ -1,6 +1,6 @@ //! 
Module containing session types -use std::{convert::TryInto as _, ffi::CString, fmt::Debug, path::Path}; +use std::{convert::TryInto as _, ffi, ffi::CString, fmt::Debug, path::Path}; #[cfg(not(target_family = "windows"))] use std::os::unix::ffi::OsStrExt; @@ -64,11 +64,16 @@ pub struct SessionBuilder<'a> { allocator: AllocatorType, memory_type: MemType, + custom_runtime_handles: Vec<*mut ::std::os::raw::c_void>, } impl<'a> Drop for SessionBuilder<'a> { #[tracing::instrument] fn drop(&mut self) { + for &handle in self.custom_runtime_handles.iter() { + close_lib_handle(handle); + } + debug!("Dropping the session options."); assert_ne!(self.session_options_ptr, std::ptr::null_mut()); unsafe { g_ort().ReleaseSessionOptions.unwrap()(self.session_options_ptr) }; @@ -89,6 +94,7 @@ impl<'a> SessionBuilder<'a> { session_options_ptr, allocator: AllocatorType::Arena, memory_type: MemType::Default, + custom_runtime_handles: Vec::new(), }) } @@ -136,6 +142,39 @@ impl<'a> SessionBuilder<'a> { Ok(self) } + /// Registers a custom ops library with the given library path in the session. 
+ pub fn with_custom_op_lib(mut self, lib_path: &str) -> Result> { + let path_cstr = ffi::CString::new(lib_path)?; + + let mut handle: *mut ::std::os::raw::c_void = std::ptr::null_mut(); + + let status = unsafe { + g_ort().RegisterCustomOpsLibrary.unwrap()( + self.session_options_ptr, + path_cstr.as_ptr(), + &mut handle, + ) + }; + + // per RegisterCustomOpsLibrary docs, release handle if there was an error and the handle + // is non-null + match status_to_result(status).map_err(OrtError::SessionOptions) { + Ok(_) => {} + Err(e) => { + if handle != std::ptr::null_mut() { + // handle was written to, should release it + close_lib_handle(handle); + } + + return Err(e); + } + } + + self.custom_runtime_handles.push(handle); + + Ok(self) + } + /// Download an ONNX pre-trained model from the [ONNX Model Zoo](https://github.com/onnx/models) and commit the session #[cfg(feature = "model-fetching")] pub fn with_model_downloaded(self, model: M) -> Result> @@ -619,6 +658,20 @@ where res } +#[cfg(unix)] +fn close_lib_handle(handle: *mut ::std::os::raw::c_void) { + unsafe { + libc::dlclose(handle); + } +} + +#[cfg(windows)] +fn close_lib_handle(handle: *mut ::std::os::raw::c_void) { + unsafe { + winapi::um::libloaderapi::FreeLibrary(handle as winapi::shared::minwindef::HINSTANCE) + }; +} + /// This module contains dangerous functions working on raw pointers. /// Those functions are only to be used from inside the /// `SessionBuilder::with_model_from_file()` method. 
diff --git a/onnxruntime/tests/custom_ops.rs b/onnxruntime/tests/custom_ops.rs new file mode 100644 index 00000000..c4d62c46 --- /dev/null +++ b/onnxruntime/tests/custom_ops.rs @@ -0,0 +1,51 @@ +use std::error::Error; + +use ndarray; +use onnxruntime::tensor::{DynOrtTensor, OrtOwnedTensor}; +use onnxruntime::{environment::Environment, LoggingLevel}; + +#[test] +fn run_model_with_ort_customops() -> Result<(), Box> { + let lib_path = match std::env::var("ONNXRUNTIME_RS_TEST_ORT_CUSTOMOPS_LIB") { + Ok(s) => s, + Err(_e) => { + println!("Skipping ort_customops test -- no lib specified"); + return Ok(()); + } + }; + + let environment = Environment::builder() + .with_name("test") + .with_log_level(LoggingLevel::Verbose) + .build()?; + + let mut session = environment + .new_session_builder()? + .with_custom_op_lib(&lib_path)? + .with_model_from_file("../test-models/tensorflow/regex_model.onnx")?; + + //Inputs: + // 0: + // name = input_1:0 + // type = String + // dimensions = [None] + // Outputs: + // 0: + // name = Identity:0 + // type = String + // dimensions = [None] + + let array = ndarray::Array::from(vec![String::from("Hello world!")]); + let input_tensor_values = vec![array]; + + let outputs: Vec> = session.run(input_tensor_values)?; + let strings: OrtOwnedTensor = outputs[0].try_extract()?; + + // ' ' replaced with '_' + assert_eq!( + &[String::from("Hello_world!")], + strings.view().as_slice().unwrap() + ); + + Ok(()) +} diff --git a/test-models/tensorflow/README.md b/test-models/tensorflow/README.md index 4f2e68f2..6421fb6e 100644 --- a/test-models/tensorflow/README.md +++ b/test-models/tensorflow/README.md @@ -16,3 +16,12 @@ This supports strings, and doesn't require custom operators. 
pipenv run python src/unique_model.py pipenv run python -m tf2onnx.convert --saved-model models/unique_model --output unique_model.onnx --opset 11 ``` + +# Model: Regex (uses `ort_customops`) + +A TensorFlow model that applies a regex, which requires the onnxruntime custom ops in `ort-customops`. + +``` +pipenv run python src/regex_model.py +pipenv run python -m tf2onnx.convert --saved-model models/regex_model --output regex_model.onnx --extra_opset ai.onnx.contrib:1 +``` diff --git a/test-models/tensorflow/regex_model.onnx b/test-models/tensorflow/regex_model.onnx new file mode 100644 index 00000000..3b4390df --- /dev/null +++ b/test-models/tensorflow/regex_model.onnx @@ -0,0 +1,19 @@ +tf2onnx1.9.0:� + + input_1:0 + +pattern__7 + +rewrite__8 +Identity:0)PartitionedCall/model1/StaticRegexReplace"StringRegexReplace:ai.onnx.contribtf2onnx*2_B +rewrite__8*2 B +pattern__7R!converted from models/regex_modelZ + input_1:0 + + +unk__9b + +Identity:0 + + unk__10B B +ai.onnx.contrib \ No newline at end of file diff --git a/test-models/tensorflow/src/regex_model.py b/test-models/tensorflow/src/regex_model.py new file mode 100644 index 00000000..5958a631 --- /dev/null +++ b/test-models/tensorflow/src/regex_model.py @@ -0,0 +1,19 @@ +import tensorflow as tf +import numpy as np +import tf2onnx + + +class RegexModel(tf.keras.Model): + + def __init__(self, name='model1', **kwargs): + super(RegexModel, self).__init__(name=name, **kwargs) + + def call(self, inputs): + return tf.strings.regex_replace(inputs, " ", "_", replace_global=True) + + +model1 = RegexModel() + +print(model1(tf.constant(["Hello world!"]))) + +model1.save("models/regex_model")