Add support for registering custom op libraries #68

Draft · wants to merge 6 commits into base: master
9 changes: 9 additions & 0 deletions .dockerignore
@@ -0,0 +1,9 @@
*
!/Cargo.*
!/onnxruntime/Cargo.toml
!/onnxruntime/src
!/onnxruntime/tests
!/onnxruntime-sys/Cargo.toml
!/onnxruntime-sys/build.rs
!/onnxruntime-sys/src
!/test-models/tensorflow/*.onnx
1 change: 1 addition & 0 deletions CHANGELOG.md
@@ -10,6 +10,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
### Added

- Add `String` datatype ([#58](https://github.com/nbigaouette/onnxruntime-rs/pull/58))
- Support custom operator libraries ([#68](https://github.com/nbigaouette/onnxruntime-rs/pull/68))

## [0.0.11] - 2021-02-22

118 changes: 118 additions & 0 deletions Dockerfile
@@ -0,0 +1,118 @@
# onnxruntime requires execinfo.h to build, which only works on glibc-based systems, so alpine is out...
Contributor comment: This Dockerfile sets up onnxruntime and ort-customops so that the "regex" model can run. Not sure how best to integrate this into CI.

FROM debian:bullseye-slim as base

RUN apt-get update && apt-get -y dist-upgrade

FROM base AS onnxruntime

RUN apt-get install -y \
    git \
    bash \
    python3 \
    cmake \
    build-essential \
    llvm \
    locales

# onnxruntime built in tests need en_US.UTF-8 available
# Uncomment en_US.UTF-8, then generate
RUN sed -i 's/^# *\(en_US.UTF-8\)/\1/' /etc/locale.gen && locale-gen

# build onnxruntime
RUN mkdir -p /opt/onnxruntime/tmp
# onnxruntime build relies on being in a git repo, so can't just get a tarball
# it's a big repo, so fetch shallowly
RUN cd /opt/onnxruntime/tmp && \
git clone --recursive --depth 1 --shallow-submodules https://github.com/Microsoft/onnxruntime

# use version that onnxruntime-sys expects
RUN cd /opt/onnxruntime/tmp/onnxruntime && \
git fetch --depth 1 origin tag v1.6.0 && \
git checkout v1.6.0

RUN /opt/onnxruntime/tmp/onnxruntime/build.sh --config RelWithDebInfo --build_shared_lib --parallel

# Build ort-customops, linked against the onnxruntime built above.
# No tags / releases yet - that commit is from 2021-02-16
RUN mkdir -p /opt/ort-customops/tmp && \
cd /opt/ort-customops/tmp && \
git clone --recursive https://github.com/microsoft/ort-customops.git && \
cd ort-customops && \
git checkout 92f6b51106c9e9143c452e537cb5e41d2dcaa266

RUN cd /opt/ort-customops/tmp/ort-customops && \
./build.sh -D ONNXRUNTIME_LIB_DIR=/opt/onnxruntime/tmp/onnxruntime/build/Linux/RelWithDebInfo


# install rust toolchain
FROM base AS rust-toolchain

ARG RUST_VERSION=1.50.0

RUN apt-get install -y \
curl

# install rust toolchain
RUN curl --proto '=https' --tlsv1.2 -sSf https://sh.rustup.rs | sh -s -- -y --default-toolchain $RUST_VERSION

ENV PATH $PATH:/root/.cargo/bin


# build onnxruntime-rs
FROM rust-toolchain as onnxruntime-rs
# clang & llvm needed by onnxruntime-sys
RUN apt-get install -y \
build-essential \
llvm-dev \
libclang-dev \
clang

RUN mkdir -p \
/onnxruntime-rs/build/onnxruntime-sys/src/ \
/onnxruntime-rs/build/onnxruntime/src/ \
/onnxruntime-rs/build/onnxruntime/tests/ \
/opt/onnxruntime/lib \
/opt/ort-customops/lib

COPY --from=onnxruntime /opt/onnxruntime/tmp/onnxruntime/build/Linux/RelWithDebInfo/libonnxruntime.so /opt/onnxruntime/lib/
COPY --from=onnxruntime /opt/ort-customops/tmp/ort-customops/out/Linux/libortcustomops.so /opt/ort-customops/lib/

WORKDIR /onnxruntime-rs/build

ENV ORT_STRATEGY=system
# this has /lib/ appended to it and is used as a lib search path in onnxruntime-sys's build.rs
ENV ORT_LIB_LOCATION=/opt/onnxruntime/

ENV ONNXRUNTIME_RS_TEST_ORT_CUSTOMOPS_LIB=/opt/ort-customops/lib/libortcustomops.so

# create enough of an empty project that dependencies can build
COPY /Cargo.lock /Cargo.toml /onnxruntime-rs/build/
COPY /onnxruntime/Cargo.toml /onnxruntime-rs/build/onnxruntime/
COPY /onnxruntime-sys/Cargo.toml /onnxruntime-sys/build.rs /onnxruntime-rs/build/onnxruntime-sys/

CMD cargo test

# build dependencies and clean the bogus contents of our two packages
RUN touch \
onnxruntime/src/lib.rs \
onnxruntime/tests/integration_tests.rs \
onnxruntime-sys/src/lib.rs \
&& cargo build --tests \
&& cargo clean --package onnxruntime-sys \
&& cargo clean --package onnxruntime \
&& rm -rf \
onnxruntime/src/ \
onnxruntime/tests/ \
onnxruntime-sys/src/

# now build the actual source
COPY /test-models test-models
COPY /onnxruntime-sys/src onnxruntime-sys/src
COPY /onnxruntime/src onnxruntime/src
COPY /onnxruntime/tests onnxruntime/tests

RUN ln -s /opt/onnxruntime/lib/libonnxruntime.so /opt/onnxruntime/lib/libonnxruntime.so.1.6.0
ENV LD_LIBRARY_PATH=/opt/onnxruntime/lib

RUN cargo build --tests
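The `ONNXRUNTIME_RS_TEST_ORT_CUSTOMOPS_LIB` variable set above lets the test suite find the custom op library at runtime. As a std-only sketch (the function and message strings here are illustrative, not the crate's actual test code), a test binary could gate on that variable like this:

```rust
use std::env;

// Describe what the test harness would do for a given (optional) library path.
// Separated from the env lookup so the logic is testable without mutating the
// process environment.
fn describe(lib: Option<&str>) -> String {
    match lib {
        Some(path) => format!("registering custom op library at {}", path),
        None => "no custom op library configured; skipping custom op tests".to_string(),
    }
}

fn main() {
    // The Dockerfile sets this to /opt/ort-customops/lib/libortcustomops.so.
    let lib = env::var("ONNXRUNTIME_RS_TEST_ORT_CUSTOMOPS_LIB").ok();
    println!("{}", describe(lib.as_deref()));
}
```

Skipping rather than failing when the variable is unset keeps `cargo test` usable outside the container.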
6 changes: 6 additions & 0 deletions onnxruntime/Cargo.toml
Expand Up @@ -26,6 +26,12 @@ ndarray = "0.13"
thiserror = "1.0"
tracing = "0.1"

[target.'cfg(unix)'.dependencies]
libc = "0.2.88"

[target.'cfg(windows)'.dependencies]
winapi = { version = "0.3.9", features = ["std"] }

# Enabled with 'model-fetching' feature
ureq = {version = "1.5.1", optional = true}

9 changes: 7 additions & 2 deletions onnxruntime/examples/issue22.rs
Expand Up @@ -34,7 +34,12 @@ fn main() {
let input_ids = Array2::<i64>::from_shape_vec((1, 3), vec![1, 2, 3]).unwrap();
let attention_mask = Array2::<i64>::from_shape_vec((1, 3), vec![1, 1, 1]).unwrap();

-    let outputs: Vec<OrtOwnedTensor<f32, _>> =
-        session.run(vec![input_ids, attention_mask]).unwrap();
+    let outputs: Vec<OrtOwnedTensor<f32, _>> = session
+        .run(vec![input_ids, attention_mask])
+        .unwrap()
+        .into_iter()
+        .map(|dyn_tensor| dyn_tensor.try_extract())
+        .collect::<Result<_, _>>()
+        .unwrap();
print!("outputs: {:#?}", outputs);
}
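The change above reflects the PR's new two-step API: `run` now returns dynamically typed tensors, and callers extract a concrete element type fallibly. A std-only toy model of that pattern (the `DynTensor`/`try_extract` names mimic the diff, but this is not the crate's implementation):

```rust
use std::any::Any;

// Stand-in for a dynamically typed tensor whose element type is only known at
// runtime, as with the crate's DynOrtTensor.
struct DynTensor(Box<dyn Any>);

impl DynTensor {
    // Succeeds only if the stored element type matches T.
    fn try_extract<T: 'static + Clone>(&self) -> Result<Vec<T>, String> {
        self.0
            .downcast_ref::<Vec<T>>()
            .cloned()
            .ok_or_else(|| "element type mismatch".to_string())
    }
}

fn main() {
    let outputs = vec![DynTensor(Box::new(vec![1.0f32, 2.0, 3.0]))];
    // Mirrors the diff: map try_extract over the outputs and collect into one
    // Result, so a single type mismatch fails the whole extraction.
    let extracted: Result<Vec<Vec<f32>>, _> =
        outputs.iter().map(|t| t.try_extract::<f32>()).collect();
    assert_eq!(extracted.unwrap()[0], vec![1.0, 2.0, 3.0]);
}
```

Collecting an iterator of `Result`s into `Result<Vec<_>, _>` is the idiomatic way to short-circuit on the first failure.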
13 changes: 8 additions & 5 deletions onnxruntime/examples/sample.rs
@@ -1,8 +1,10 @@
#![forbid(unsafe_code)]

use onnxruntime::{
-    environment::Environment, ndarray::Array, tensor::OrtOwnedTensor, GraphOptimizationLevel,
-    LoggingLevel,
+    environment::Environment,
+    ndarray::Array,
+    tensor::{DynOrtTensor, OrtOwnedTensor},
+    GraphOptimizationLevel, LoggingLevel,
};
use tracing::Level;
use tracing_subscriber::FmtSubscriber;
@@ -61,6 +63,12 @@ fn run() -> Result<(), Error> {
.unwrap();
let input_tensor_values = vec![array];

-    let outputs: Vec<OrtOwnedTensor<f32, _>> = session.run(input_tensor_values)?;
+    let outputs: Vec<DynOrtTensor<_>> = session.run(input_tensor_values)?;

-    assert_eq!(outputs[0].shape(), output0_shape.as_slice());
+    let output: OrtOwnedTensor<f32, _> = outputs[0].try_extract().unwrap();
+    assert_eq!(output.view().shape(), output0_shape.as_slice());
     for i in 0..5 {
-        println!("Score for class [{}] = {}", i, outputs[0][[0, i, 0, 0]]);
+        println!("Score for class [{}] = {}", i, output.view()[[0, i, 0, 0]]);
     }
}

Ok(())
14 changes: 13 additions & 1 deletion onnxruntime/src/error.rs
@@ -1,6 +1,6 @@
//! Module containing error definitions.

use std::{io, path::PathBuf};
use std::{io, path::PathBuf, string};

use thiserror::Error;

@@ -53,6 +53,12 @@
/// Error occurred when getting ONNX dimensions
#[error("Failed to get dimensions: {0}")]
GetDimensions(OrtApiError),
/// Error occurred when getting string length
#[error("Failed to get string tensor length: {0}")]
GetStringTensorDataLength(OrtApiError),
/// Error occurred when getting tensor element count
#[error("Failed to get tensor element count: {0}")]
GetTensorShapeElementCount(OrtApiError),
/// Error occurred when creating CPU memory information
#[error("Failed to create CPU memory info: {0}")]
CreateCpuMemoryInfo(OrtApiError),
@@ -77,6 +83,12 @@
/// Error occurred when extracting data from an ONNX tensor into an C array to be used as an `ndarray::ArrayView`
#[error("Failed to get tensor data: {0}")]
GetTensorMutableData(OrtApiError),
/// Error occurred when extracting string data from an ONNX tensor
#[error("Failed to get tensor string data: {0}")]
GetStringTensorContent(OrtApiError),
/// Error occurred when converting data to a String
#[error("Data was not UTF-8: {0}")]
StringFromUtf8Error(#[from] string::FromUtf8Error),

/// Error occurred when downloading a pre-trained ONNX model from the [ONNX Model Zoo](https://github.com/onnx/models)
#[error("Failed to download ONNX model: {0}")]
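The `#[from] string::FromUtf8Error` attribute on the new `StringFromUtf8Error` variant is thiserror shorthand for a `From` impl, which is what lets `?` convert UTF-8 decode failures automatically. A std-only sketch of the equivalent hand-written impl (the enum here is a cut-down illustration, not the crate's full `OrtError`):

```rust
use std::string;

#[derive(Debug)]
enum OrtError {
    // Corresponds to the variant added in the diff.
    StringFromUtf8Error(string::FromUtf8Error),
}

// This From impl is what `#[from]` generates via the thiserror derive.
impl From<string::FromUtf8Error> for OrtError {
    fn from(e: string::FromUtf8Error) -> Self {
        OrtError::StringFromUtf8Error(e)
    }
}

// With the impl in place, `?` converts the error type for us, so string
// tensor data can be decoded without an explicit map_err.
fn decode(bytes: Vec<u8>) -> Result<String, OrtError> {
    Ok(String::from_utf8(bytes)?)
}

fn main() {
    assert_eq!(decode(b"ok".to_vec()).unwrap(), "ok");
    assert!(decode(vec![0xff, 0xfe]).is_err()); // 0xff is never valid UTF-8
}
```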