From 0d34d23d0b50aa664fc5b4c9e1337ea73e496541 Mon Sep 17 00:00:00 2001 From: Marshall Pierce Date: Wed, 24 Feb 2021 09:48:34 -0700 Subject: [PATCH 1/6] Introduce a new trait to represent types that can be used as output from a tensor This is some prep work for string output types and tensor types that vary across the model outputs. For now, the supported types are just the basic numeric types. Since strings have to be copied out of a tensor, it only makes sense to have `String` be an output type, not `str`, hence the new type so that we can have more input types supported than output types. --- onnxruntime/src/lib.rs | 148 -------------- onnxruntime/src/session.rs | 112 ++++++---- onnxruntime/src/tensor.rs | 226 +++++++++++++++++++++ onnxruntime/src/tensor/ort_owned_tensor.rs | 83 ++------ onnxruntime/src/tensor/ort_tensor.rs | 8 +- 5 files changed, 322 insertions(+), 255 deletions(-) diff --git a/onnxruntime/src/lib.rs b/onnxruntime/src/lib.rs index 0d575b5e..66aa7da6 100644 --- a/onnxruntime/src/lib.rs +++ b/onnxruntime/src/lib.rs @@ -322,154 +322,6 @@ impl Into for GraphOptimizationLevel { } } -// FIXME: Use https://docs.rs/bindgen/0.54.1/bindgen/struct.Builder.html#method.rustified_enum -// FIXME: Add tests to cover the commented out types -/// Enum mapping ONNX Runtime's supported tensor types -#[derive(Debug)] -#[cfg_attr(not(windows), repr(u32))] -#[cfg_attr(windows, repr(i32))] -pub enum TensorElementDataType { - /// 32-bit floating point, equivalent to Rust's `f32` - Float = sys::ONNXTensorElementDataType::ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT as OnnxEnumInt, - /// Unsigned 8-bit int, equivalent to Rust's `u8` - Uint8 = sys::ONNXTensorElementDataType::ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT8 as OnnxEnumInt, - /// Signed 8-bit int, equivalent to Rust's `i8` - Int8 = sys::ONNXTensorElementDataType::ONNX_TENSOR_ELEMENT_DATA_TYPE_INT8 as OnnxEnumInt, - /// Unsigned 16-bit int, equivalent to Rust's `u16` - Uint16 = sys::ONNXTensorElementDataType::ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT16 as OnnxEnumInt, - /// Signed 16-bit int, equivalent to Rust's `i16` - Int16 = sys::ONNXTensorElementDataType::ONNX_TENSOR_ELEMENT_DATA_TYPE_INT16 as OnnxEnumInt, - /// Signed 32-bit int, equivalent to Rust's `i32` - Int32 = sys::ONNXTensorElementDataType::ONNX_TENSOR_ELEMENT_DATA_TYPE_INT32 as OnnxEnumInt, - /// Signed 64-bit int, equivalent to Rust's `i64` - Int64 = sys::ONNXTensorElementDataType::ONNX_TENSOR_ELEMENT_DATA_TYPE_INT64 as OnnxEnumInt, - /// String, equivalent to Rust's `String` - String = sys::ONNXTensorElementDataType::ONNX_TENSOR_ELEMENT_DATA_TYPE_STRING as OnnxEnumInt, - // /// Boolean, equivalent to Rust's `bool` - // Bool = sys::ONNXTensorElementDataType::ONNX_TENSOR_ELEMENT_DATA_TYPE_BOOL as OnnxEnumInt, - // /// 16-bit floating point, equivalent to Rust's `f16` - // Float16 = sys::ONNXTensorElementDataType::ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT16 as OnnxEnumInt, - /// 64-bit floating point, equivalent to Rust's `f64` - Double = sys::ONNXTensorElementDataType::ONNX_TENSOR_ELEMENT_DATA_TYPE_DOUBLE as OnnxEnumInt, - /// Unsigned 32-bit int, equivalent to Rust's `u32` - Uint32 = sys::ONNXTensorElementDataType::ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT32 as OnnxEnumInt, - /// Unsigned 64-bit int, equivalent to Rust's `u64` - Uint64 = sys::ONNXTensorElementDataType::ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT64 as OnnxEnumInt, - // /// Complex 64-bit floating point, equivalent to Rust's `???` - // Complex64 = sys::ONNXTensorElementDataType::ONNX_TENSOR_ELEMENT_DATA_TYPE_COMPLEX64 as OnnxEnumInt, - // /// Complex 128-bit floating point, equivalent to Rust's `???` - // Complex128 = sys::ONNXTensorElementDataType::ONNX_TENSOR_ELEMENT_DATA_TYPE_COMPLEX128 as OnnxEnumInt, - // /// Brain 16-bit floating point - // Bfloat16 = sys::ONNXTensorElementDataType::ONNX_TENSOR_ELEMENT_DATA_TYPE_BFLOAT16 as OnnxEnumInt, -} - -impl Into for TensorElementDataType { - fn into(self) -> sys::ONNXTensorElementDataType { - use TensorElementDataType::*; - match self { - Float => sys::ONNXTensorElementDataType::ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT, - Uint8 => sys::ONNXTensorElementDataType::ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT8, - Int8 => sys::ONNXTensorElementDataType::ONNX_TENSOR_ELEMENT_DATA_TYPE_INT8, - Uint16 => sys::ONNXTensorElementDataType::ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT16, - Int16 => sys::ONNXTensorElementDataType::ONNX_TENSOR_ELEMENT_DATA_TYPE_INT16, - Int32 => sys::ONNXTensorElementDataType::ONNX_TENSOR_ELEMENT_DATA_TYPE_INT32, - Int64 => sys::ONNXTensorElementDataType::ONNX_TENSOR_ELEMENT_DATA_TYPE_INT64, - String => sys::ONNXTensorElementDataType::ONNX_TENSOR_ELEMENT_DATA_TYPE_STRING, - // Bool => { - // sys::ONNXTensorElementDataType::ONNX_TENSOR_ELEMENT_DATA_TYPE_BOOL - // } - // Float16 => { - // sys::ONNXTensorElementDataType::ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT16 - // } - Double => sys::ONNXTensorElementDataType::ONNX_TENSOR_ELEMENT_DATA_TYPE_DOUBLE, - Uint32 => sys::ONNXTensorElementDataType::ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT32, - Uint64 => sys::ONNXTensorElementDataType::ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT64, - // Complex64 => { - // sys::ONNXTensorElementDataType::ONNX_TENSOR_ELEMENT_DATA_TYPE_COMPLEX64 - // } - // Complex128 => { - // sys::ONNXTensorElementDataType::ONNX_TENSOR_ELEMENT_DATA_TYPE_COMPLEX128 - // } - // Bfloat16 => { - // sys::ONNXTensorElementDataType::ONNX_TENSOR_ELEMENT_DATA_TYPE_BFLOAT16 - // } - } - } -} - -/// Trait used to map Rust types (for example `f32`) to ONNX types (for example `Float`) -pub trait TypeToTensorElementDataType { - /// Return the ONNX type for a Rust type - fn tensor_element_data_type() -> TensorElementDataType; - - /// If the type is `String`, returns `Some` with utf8 contents, else `None`. - fn try_utf8_bytes(&self) -> Option<&[u8]>; -} - -macro_rules! impl_type_trait { - ($type_:ty, $variant:ident) => { - impl TypeToTensorElementDataType for $type_ { - fn tensor_element_data_type() -> TensorElementDataType { - // unsafe { std::mem::transmute(TensorElementDataType::$variant) } - TensorElementDataType::$variant - } - - fn try_utf8_bytes(&self) -> Option<&[u8]> { - None - } - } - }; -} - -impl_type_trait!(f32, Float); -impl_type_trait!(u8, Uint8); -impl_type_trait!(i8, Int8); -impl_type_trait!(u16, Uint16); -impl_type_trait!(i16, Int16); -impl_type_trait!(i32, Int32); -impl_type_trait!(i64, Int64); -// impl_type_trait!(bool, Bool); -// impl_type_trait!(f16, Float16); -impl_type_trait!(f64, Double); -impl_type_trait!(u32, Uint32); -impl_type_trait!(u64, Uint64); -// impl_type_trait!(, Complex64); -// impl_type_trait!(, Complex128); -// impl_type_trait!(, Bfloat16); - -/// Adapter for common Rust string types to Onnx strings. -/// -/// It should be easy to use both `String` and `&str` as [TensorElementDataType::String] data, but -/// we can't define an automatic implementation for anything that implements `AsRef` as it -/// would conflict with the implementations of [TypeToTensorElementDataType] for primitive numeric -/// types (which might implement `AsRef` at some point in the future). -pub trait Utf8Data { - /// Returns the utf8 contents. - fn utf8_bytes(&self) -> &[u8]; -} - -impl Utf8Data for String { - fn utf8_bytes(&self) -> &[u8] { - self.as_bytes() - } -} - -impl<'a> Utf8Data for &'a str { - fn utf8_bytes(&self) -> &[u8] { - self.as_bytes() - } -} - -impl TypeToTensorElementDataType for T { - fn tensor_element_data_type() -> TensorElementDataType { - TensorElementDataType::String - } - - fn try_utf8_bytes(&self) -> Option<&[u8]> { - Some(self.utf8_bytes()) - } -} - /// Allocator type #[derive(Debug, Clone)] #[repr(i32)] diff --git a/onnxruntime/src/session.rs b/onnxruntime/src/session.rs index 04f9cf1c..d212111e 100644 --- a/onnxruntime/src/session.rs +++ b/onnxruntime/src/session.rs @@ -18,15 +18,14 @@ use onnxruntime_sys as sys; use crate::{ char_p_to_string, environment::Environment, - error::{status_to_result, NonMatchingDimensionsError, OrtError, Result}, + error::{call_ort, status_to_result, NonMatchingDimensionsError, OrtError, Result}, g_ort, memory::MemoryInfo, tensor::{ - ort_owned_tensor::{OrtOwnedTensor, OrtOwnedTensorExtractor}, - OrtTensor, + ort_owned_tensor::OrtOwnedTensor, OrtTensor, TensorDataToType, TensorElementDataType, + TypeToTensorElementDataType, }, - AllocatorType, GraphOptimizationLevel, MemType, TensorElementDataType, - TypeToTensorElementDataType, + AllocatorType, GraphOptimizationLevel, MemType, }; #[cfg(feature = "model-fetching")] @@ -371,7 +370,7 @@ impl<'a> Session<'a> { ) -> Result>> where TIn: TypeToTensorElementDataType + Debug + Clone, - TOut: TypeToTensorElementDataType + Debug + Clone, + TOut: TensorDataToType, D: ndarray::Dimension, 'm: 't, // 'm outlives 't (memory info outlives tensor) 's: 'm, // 's outlives 'm (session outlives memory info) @@ -440,21 +439,30 @@ impl<'a> Session<'a> { let outputs: Result>>> = output_tensor_extractors_ptrs .into_iter() - .map(|ptr| { - let mut tensor_info_ptr: *mut sys::OrtTensorTypeAndShapeInfo = - std::ptr::null_mut(); - let status = unsafe { - g_ort().GetTensorTypeAndShape.unwrap()(ptr, &mut tensor_info_ptr as _) - }; - status_to_result(status).map_err(OrtError::GetTensorTypeAndShape)?; - let dims = unsafe { get_tensor_dimensions(tensor_info_ptr) }; - unsafe { g_ort().ReleaseTensorTypeAndShapeInfo.unwrap()(tensor_info_ptr) }; - let dims: Vec<_> = dims?.iter().map(|&n| n as usize).collect(); - - let mut output_tensor_extractor = - OrtOwnedTensorExtractor::new(memory_info_ref, ndarray::IxDyn(&dims)); - output_tensor_extractor.tensor_ptr = ptr; - output_tensor_extractor.extract::() + .map(|tensor_ptr| { + let dims = unsafe { + call_with_tensor_info(tensor_ptr, |tensor_info_ptr| { + get_tensor_dimensions(tensor_info_ptr) + .map(|dims| dims.iter().map(|&n| n as usize).collect::>()) + }) + }?; + + // Note: Both tensor and array will point to the same data, nothing is copied. + // As such, there is no need to free the pointer used to create the ArrayView. + assert_ne!(tensor_ptr, std::ptr::null_mut()); + + let mut is_tensor = 0; + unsafe { call_ort(|ort| ort.IsTensor.unwrap()(tensor_ptr, &mut is_tensor)) } + .map_err(OrtError::IsTensor)?; + assert_eq!(is_tensor, 1); + + let array_view = TOut::extract_array(ndarray::IxDyn(&dims), tensor_ptr)?; + + Ok(OrtOwnedTensor::new( + tensor_ptr, + array_view, + &memory_info_ref, + )) }) .collect(); @@ -554,25 +562,60 @@ unsafe fn get_tensor_dimensions( tensor_info_ptr: *const sys::OrtTensorTypeAndShapeInfo, ) -> Result> { let mut num_dims = 0; - let status = g_ort().GetDimensionsCount.unwrap()(tensor_info_ptr, &mut num_dims); - status_to_result(status).map_err(OrtError::GetDimensionsCount)?; + call_ort(|ort| ort.GetDimensionsCount.unwrap()(tensor_info_ptr, &mut num_dims)) + .map_err(OrtError::GetDimensionsCount)?; assert_ne!(num_dims, 0); let mut node_dims: Vec = vec![0; num_dims as usize]; - let status = g_ort().GetDimensions.unwrap()( - tensor_info_ptr, - node_dims.as_mut_ptr(), // FIXME: UB? - num_dims, - ); - status_to_result(status).map_err(OrtError::GetDimensions)?; + call_ort(|ort| { + ort.GetDimensions.unwrap()( + tensor_info_ptr, + node_dims.as_mut_ptr(), // FIXME: UB? + num_dims, + ) + }) + .map_err(OrtError::GetDimensions)?; Ok(node_dims) } +unsafe fn extract_data_type( + tensor_info_ptr: *const sys::OrtTensorTypeAndShapeInfo, +) -> Result { + let mut type_sys = sys::ONNXTensorElementDataType::ONNX_TENSOR_ELEMENT_DATA_TYPE_UNDEFINED; + call_ort(|ort| ort.GetTensorElementType.unwrap()(tensor_info_ptr, &mut type_sys)) + .map_err(OrtError::TensorElementType)?; + assert_ne!( + type_sys, + sys::ONNXTensorElementDataType::ONNX_TENSOR_ELEMENT_DATA_TYPE_UNDEFINED + ); + // This transmute should be safe since its value is read from GetTensorElementType which we must trust. + Ok(std::mem::transmute(type_sys)) +} + +/// Calls the provided closure with the result of `GetTensorTypeAndShape`, deallocating the +/// resulting `*OrtTensorTypeAndShapeInfo` before returning. +unsafe fn call_with_tensor_info(tensor_ptr: *const sys::OrtValue, mut f: F) -> Result +where + F: FnMut(*const sys::OrtTensorTypeAndShapeInfo) -> Result, +{ + let mut tensor_info_ptr: *mut sys::OrtTensorTypeAndShapeInfo = std::ptr::null_mut(); + call_ort(|ort| ort.GetTensorTypeAndShape.unwrap()(tensor_ptr, &mut tensor_info_ptr as _)) + .map_err(OrtError::GetTensorTypeAndShape)?; + + let res = f(tensor_info_ptr); + + // no return code, so no errors to check for + g_ort().ReleaseTensorTypeAndShapeInfo.unwrap()(tensor_info_ptr); + + res +} + /// This module contains dangerous functions working on raw pointers. /// Those functions are only to be used from inside the /// `SessionBuilder::with_model_from_file()` method. mod dangerous { use super::*; + use crate::tensor::TensorElementDataType; pub(super) fn extract_inputs_count(session_ptr: *mut sys::OrtSession) -> Result { let f = g_ort().SessionGetInputCount.unwrap(); @@ -689,16 +732,7 @@ mod dangerous { status_to_result(status).map_err(OrtError::CastTypeInfoToTensorInfo)?; assert_ne!(tensor_info_ptr, std::ptr::null_mut()); - let mut type_sys = sys::ONNXTensorElementDataType::ONNX_TENSOR_ELEMENT_DATA_TYPE_UNDEFINED; - let status = - unsafe { g_ort().GetTensorElementType.unwrap()(tensor_info_ptr, &mut type_sys) }; - status_to_result(status).map_err(OrtError::TensorElementType)?; - assert_ne!( - type_sys, - sys::ONNXTensorElementDataType::ONNX_TENSOR_ELEMENT_DATA_TYPE_UNDEFINED - ); - // This transmute should be safe since its value is read from GetTensorElementType which we must trust. - let io_type: TensorElementDataType = unsafe { std::mem::transmute(type_sys) }; + let io_type: TensorElementDataType = unsafe { extract_data_type(tensor_info_ptr)? }; // info!("{} : type={}", i, type_); diff --git a/onnxruntime/src/tensor.rs b/onnxruntime/src/tensor.rs index 92404842..a5178c91 100644 --- a/onnxruntime/src/tensor.rs +++ b/onnxruntime/src/tensor.rs @@ -29,3 +29,229 @@ pub mod ort_tensor; pub use ort_owned_tensor::OrtOwnedTensor; pub use ort_tensor::OrtTensor; + +use crate::{OrtError, Result}; +use onnxruntime_sys::{self as sys, OnnxEnumInt}; +use std::{fmt, ptr}; + +// FIXME: Use https://docs.rs/bindgen/0.54.1/bindgen/struct.Builder.html#method.rustified_enum +// FIXME: Add tests to cover the commented out types +/// Enum mapping ONNX Runtime's supported tensor types +#[derive(Debug, PartialEq, Eq, Clone, Copy)] +#[cfg_attr(not(windows), repr(u32))] +#[cfg_attr(windows, repr(i32))] +pub enum TensorElementDataType { + /// 32-bit floating point, equivalent to Rust's `f32` + Float = sys::ONNXTensorElementDataType::ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT as OnnxEnumInt, + /// Unsigned 8-bit int, equivalent to Rust's `u8` + Uint8 = sys::ONNXTensorElementDataType::ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT8 as OnnxEnumInt, + /// Signed 8-bit int, equivalent to Rust's `i8` + Int8 = sys::ONNXTensorElementDataType::ONNX_TENSOR_ELEMENT_DATA_TYPE_INT8 as OnnxEnumInt, + /// Unsigned 16-bit int, equivalent to Rust's `u16` + Uint16 = sys::ONNXTensorElementDataType::ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT16 as OnnxEnumInt, + /// Signed 16-bit int, equivalent to Rust's `i16` + Int16 = sys::ONNXTensorElementDataType::ONNX_TENSOR_ELEMENT_DATA_TYPE_INT16 as OnnxEnumInt, + /// Signed 32-bit int, equivalent to Rust's `i32` + Int32 = sys::ONNXTensorElementDataType::ONNX_TENSOR_ELEMENT_DATA_TYPE_INT32 as OnnxEnumInt, + /// Signed 64-bit int, equivalent to Rust's `i64` + Int64 = sys::ONNXTensorElementDataType::ONNX_TENSOR_ELEMENT_DATA_TYPE_INT64 as OnnxEnumInt, + /// String, equivalent to Rust's `String` + String = sys::ONNXTensorElementDataType::ONNX_TENSOR_ELEMENT_DATA_TYPE_STRING as OnnxEnumInt, + // /// Boolean, equivalent to Rust's `bool` + // Bool = sys::ONNXTensorElementDataType::ONNX_TENSOR_ELEMENT_DATA_TYPE_BOOL as OnnxEnumInt, + // /// 16-bit floating point, equivalent to Rust's `f16` + // Float16 = sys::ONNXTensorElementDataType::ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT16 as OnnxEnumInt, + /// 64-bit floating point, equivalent to Rust's `f64` + Double = sys::ONNXTensorElementDataType::ONNX_TENSOR_ELEMENT_DATA_TYPE_DOUBLE as OnnxEnumInt, + /// Unsigned 32-bit int, equivalent to Rust's `u32` + Uint32 = sys::ONNXTensorElementDataType::ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT32 as OnnxEnumInt, + /// Unsigned 64-bit int, equivalent to Rust's `u64` + Uint64 = sys::ONNXTensorElementDataType::ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT64 as OnnxEnumInt, + // /// Complex 64-bit floating point, equivalent to Rust's `???` + // Complex64 = sys::ONNXTensorElementDataType::ONNX_TENSOR_ELEMENT_DATA_TYPE_COMPLEX64 as OnnxEnumInt, + // /// Complex 128-bit floating point, equivalent to Rust's `???` + // Complex128 = sys::ONNXTensorElementDataType::ONNX_TENSOR_ELEMENT_DATA_TYPE_COMPLEX128 as OnnxEnumInt, + // /// Brain 16-bit floating point + // Bfloat16 = sys::ONNXTensorElementDataType::ONNX_TENSOR_ELEMENT_DATA_TYPE_BFLOAT16 as OnnxEnumInt, +} + +impl Into for TensorElementDataType { + fn into(self) -> sys::ONNXTensorElementDataType { + use TensorElementDataType::*; + match self { + Float => sys::ONNXTensorElementDataType::ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT, + Uint8 => sys::ONNXTensorElementDataType::ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT8, + Int8 => sys::ONNXTensorElementDataType::ONNX_TENSOR_ELEMENT_DATA_TYPE_INT8, + Uint16 => sys::ONNXTensorElementDataType::ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT16, + Int16 => sys::ONNXTensorElementDataType::ONNX_TENSOR_ELEMENT_DATA_TYPE_INT16, + Int32 => sys::ONNXTensorElementDataType::ONNX_TENSOR_ELEMENT_DATA_TYPE_INT32, + Int64 => sys::ONNXTensorElementDataType::ONNX_TENSOR_ELEMENT_DATA_TYPE_INT64, + String => sys::ONNXTensorElementDataType::ONNX_TENSOR_ELEMENT_DATA_TYPE_STRING, + // Bool => { + // sys::ONNXTensorElementDataType::ONNX_TENSOR_ELEMENT_DATA_TYPE_BOOL + // } + // Float16 => { + // sys::ONNXTensorElementDataType::ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT16 + // } + Double => sys::ONNXTensorElementDataType::ONNX_TENSOR_ELEMENT_DATA_TYPE_DOUBLE, + Uint32 => sys::ONNXTensorElementDataType::ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT32, + Uint64 => sys::ONNXTensorElementDataType::ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT64, + // Complex64 => { + // sys::ONNXTensorElementDataType::ONNX_TENSOR_ELEMENT_DATA_TYPE_COMPLEX64 + // } + // Complex128 => { + // sys::ONNXTensorElementDataType::ONNX_TENSOR_ELEMENT_DATA_TYPE_COMPLEX128 + // } + // Bfloat16 => { + // sys::ONNXTensorElementDataType::ONNX_TENSOR_ELEMENT_DATA_TYPE_BFLOAT16 + // } + } + } +} + +/// Trait used to map Rust types (for example `f32`) to ONNX types (for example `Float`) +pub trait TypeToTensorElementDataType { + /// Return the ONNX type for a Rust type + fn tensor_element_data_type() -> TensorElementDataType; + + /// If the type is `String`, returns `Some` with utf8 contents, else `None`. + fn try_utf8_bytes(&self) -> Option<&[u8]>; +} + +macro_rules! impl_prim_type_to_ort_trait { + ($type_:ty, $variant:ident) => { + impl TypeToTensorElementDataType for $type_ { + fn tensor_element_data_type() -> TensorElementDataType { + // unsafe { std::mem::transmute(TensorElementDataType::$variant) } + TensorElementDataType::$variant + } + + fn try_utf8_bytes(&self) -> Option<&[u8]> { + None + } + } + }; +} + +impl_prim_type_to_ort_trait!(f32, Float); +impl_prim_type_to_ort_trait!(u8, Uint8); +impl_prim_type_to_ort_trait!(i8, Int8); +impl_prim_type_to_ort_trait!(u16, Uint16); +impl_prim_type_to_ort_trait!(i16, Int16); +impl_prim_type_to_ort_trait!(i32, Int32); +impl_prim_type_to_ort_trait!(i64, Int64); +// impl_type_trait!(bool, Bool); +// impl_type_trait!(f16, Float16); +impl_prim_type_to_ort_trait!(f64, Double); +impl_prim_type_to_ort_trait!(u32, Uint32); +impl_prim_type_to_ort_trait!(u64, Uint64); +// impl_type_trait!(, Complex64); +// impl_type_trait!(, Complex128); +// impl_type_trait!(, Bfloat16); + +/// Adapter for common Rust string types to Onnx strings. +/// +/// It should be easy to use both `String` and `&str` as [TensorElementDataType::String] data, but +/// we can't define an automatic implementation for anything that implements `AsRef` as it +/// would conflict with the implementations of [TypeToTensorElementDataType] for primitive numeric +/// types (which might implement `AsRef` at some point in the future). +pub trait Utf8Data { + /// Returns the utf8 contents. + fn utf8_bytes(&self) -> &[u8]; +} + +impl Utf8Data for String { + fn utf8_bytes(&self) -> &[u8] { + self.as_bytes() + } +} + +impl<'a> Utf8Data for &'a str { + fn utf8_bytes(&self) -> &[u8] { + self.as_bytes() + } +} + +impl TypeToTensorElementDataType for T { + fn tensor_element_data_type() -> TensorElementDataType { + TensorElementDataType::String + } + + fn try_utf8_bytes(&self) -> Option<&[u8]> { + Some(self.utf8_bytes()) + } +} + +/// Trait used to map onnxruntime types to Rust types +pub trait TensorDataToType: Sized + fmt::Debug { + /// The tensor element type that this type can extract from + fn tensor_element_data_type() -> TensorElementDataType; + + /// Extract an `ArrayView` from the ort-owned tensor. + fn extract_array<'t, D>( + shape: D, + tensor: *mut sys::OrtValue, + ) -> Result> + where + D: ndarray::Dimension; +} + +/// Implements `OwnedTensorDataToType` for primitives, which can use `GetTensorMutableData` +macro_rules! impl_prim_type_from_ort_trait { + ($type_:ty, $variant:ident) => { + impl TensorDataToType for $type_ { + fn tensor_element_data_type() -> TensorElementDataType { + TensorElementDataType::$variant + } + + fn extract_array<'t, D>( + shape: D, + tensor: *mut sys::OrtValue, + ) -> Result> + where + D: ndarray::Dimension, + { + extract_primitive_array(shape, tensor) + } + } + }; +} + +/// Construct an [ndarray::ArrayView] over an Ort tensor. +/// +/// Only to be used on types whose Rust in-memory representation matches Ort's (e.g. primitive +/// numeric types like u32). +fn extract_primitive_array<'t, D, T: TensorDataToType>( + shape: D, + tensor: *mut sys::OrtValue, +) -> Result> +where + D: ndarray::Dimension, +{ + // Get pointer to output tensor float values + let mut output_array_ptr: *mut T = ptr::null_mut(); + let output_array_ptr_ptr: *mut *mut T = &mut output_array_ptr; + let output_array_ptr_ptr_void: *mut *mut std::ffi::c_void = + output_array_ptr_ptr as *mut *mut std::ffi::c_void; + unsafe { + crate::error::call_ort(|ort| { + ort.GetTensorMutableData.unwrap()(tensor, output_array_ptr_ptr_void) + }) + } + .map_err(OrtError::GetTensorMutableData)?; + assert_ne!(output_array_ptr, ptr::null_mut()); + + let array_view = unsafe { ndarray::ArrayView::from_shape_ptr(shape, output_array_ptr) }; + Ok(array_view) +} + +impl_prim_type_from_ort_trait!(f32, Float); +impl_prim_type_from_ort_trait!(u8, Uint8); +impl_prim_type_from_ort_trait!(i8, Int8); +impl_prim_type_from_ort_trait!(u16, Uint16); +impl_prim_type_from_ort_trait!(i16, Int16); +impl_prim_type_from_ort_trait!(i32, Int32); +impl_prim_type_from_ort_trait!(i64, Int64); +impl_prim_type_from_ort_trait!(f64, Double); +impl_prim_type_from_ort_trait!(u32, Uint32); +impl_prim_type_from_ort_trait!(u64, Uint64); diff --git a/onnxruntime/src/tensor/ort_owned_tensor.rs b/onnxruntime/src/tensor/ort_owned_tensor.rs index 161fe105..bfff3dea 100644 --- a/onnxruntime/src/tensor/ort_owned_tensor.rs +++ b/onnxruntime/src/tensor/ort_owned_tensor.rs @@ -7,10 +7,8 @@ use tracing::debug; use onnxruntime_sys as sys; -use crate::{ - error::status_to_result, g_ort, memory::MemoryInfo, tensor::ndarray_tensor::NdArrayTensor, - OrtError, Result, TypeToTensorElementDataType, -}; +use crate::tensor::TensorDataToType; +use crate::{g_ort, memory::MemoryInfo, tensor::ndarray_tensor::NdArrayTensor}; /// Tensor containing data owned by the ONNX Runtime C library, used to return values from inference. /// @@ -25,7 +23,7 @@ use crate::{ #[derive(Debug)] pub struct OrtOwnedTensor<'t, 'm, T, D> where - T: TypeToTensorElementDataType + Debug + Clone, + T: TensorDataToType, D: ndarray::Dimension, 'm: 't, // 'm outlives 't { @@ -36,7 +34,7 @@ where impl<'t, 'm, T, D> Deref for OrtOwnedTensor<'t, 'm, T, D> where - T: TypeToTensorElementDataType + Debug + Clone, + T: TensorDataToType, D: ndarray::Dimension, { type Target = ArrayView<'t, T, D>; @@ -48,9 +46,21 @@ where impl<'t, 'm, T, D> OrtOwnedTensor<'t, 'm, T, D> where - T: TypeToTensorElementDataType + Debug + Clone, + T: TensorDataToType, D: ndarray::Dimension, { + pub(crate) fn new( + tensor_ptr: *mut sys::OrtValue, + array_view: ArrayView<'t, T, D>, + memory_info: &'m MemoryInfo, + ) -> OrtOwnedTensor<'t, 'm, T, D> { + OrtOwnedTensor { + tensor_ptr, + array_view, + memory_info, + } + } + /// Apply a softmax on the specified axis pub fn softmax(&self, axis: ndarray::Axis) -> Array where @@ -61,66 +71,9 @@ where } } -#[derive(Debug)] -pub(crate) struct OrtOwnedTensorExtractor<'m, D> -where - D: ndarray::Dimension, -{ - pub(crate) tensor_ptr: *mut sys::OrtValue, - memory_info: &'m MemoryInfo, - shape: D, -} - -impl<'m, D> OrtOwnedTensorExtractor<'m, D> -where - D: ndarray::Dimension, -{ - pub(crate) fn new(memory_info: &'m MemoryInfo, shape: D) -> OrtOwnedTensorExtractor<'m, D> { - OrtOwnedTensorExtractor { - tensor_ptr: std::ptr::null_mut(), - memory_info, - shape, - } - } - - pub(crate) fn extract<'t, T>(self) -> Result> - where - T: TypeToTensorElementDataType + Debug + Clone, - { - // Note: Both tensor and array will point to the same data, nothing is copied. - // As such, there is no need too free the pointer used to create the ArrayView. - - assert_ne!(self.tensor_ptr, std::ptr::null_mut()); - - let mut is_tensor = 0; - let status = unsafe { g_ort().IsTensor.unwrap()(self.tensor_ptr, &mut is_tensor) }; - status_to_result(status).map_err(OrtError::IsTensor)?; - assert_eq!(is_tensor, 1); - - // Get pointer to output tensor float values - let mut output_array_ptr: *mut T = std::ptr::null_mut(); - let output_array_ptr_ptr: *mut *mut T = &mut output_array_ptr; - let output_array_ptr_ptr_void: *mut *mut std::ffi::c_void = - output_array_ptr_ptr as *mut *mut std::ffi::c_void; - let status = unsafe { - g_ort().GetTensorMutableData.unwrap()(self.tensor_ptr, output_array_ptr_ptr_void) - }; - status_to_result(status).map_err(OrtError::IsTensor)?; - assert_ne!(output_array_ptr, std::ptr::null_mut()); - - let array_view = unsafe { ArrayView::from_shape_ptr(self.shape, output_array_ptr) }; - - Ok(OrtOwnedTensor { - tensor_ptr: self.tensor_ptr, - array_view, - memory_info: self.memory_info, - }) - } -} - impl<'t, 'm, T, D> Drop for OrtOwnedTensor<'t, 'm, T, D> where - T: TypeToTensorElementDataType + Debug + Clone, + T: TensorDataToType, D: ndarray::Dimension, 'm: 't, // 'm outlives 't { diff --git a/onnxruntime/src/tensor/ort_tensor.rs b/onnxruntime/src/tensor/ort_tensor.rs index 437e2e86..0937afe1 100644 --- a/onnxruntime/src/tensor/ort_tensor.rs +++ b/onnxruntime/src/tensor/ort_tensor.rs @@ -8,9 +8,11 @@ use tracing::{debug, error}; use onnxruntime_sys as sys; use crate::{ - error::call_ort, error::status_to_result, g_ort, memory::MemoryInfo, - tensor::ndarray_tensor::NdArrayTensor, OrtError, Result, TensorElementDataType, - TypeToTensorElementDataType, + error::{call_ort, status_to_result}, + g_ort, + memory::MemoryInfo, + tensor::{ndarray_tensor::NdArrayTensor, TensorElementDataType, TypeToTensorElementDataType}, + OrtError, Result, }; /// Owned tensor, backed by an [`ndarray::Array`](https://docs.rs/ndarray/latest/ndarray/type.Array.html) From 555bec708e30009f2d5cc4a27e34a3830b15dd2a Mon Sep 17 00:00:00 2001 From: Marshall Pierce Date: Fri, 26 Feb 2021 09:37:52 -0700 Subject: [PATCH 2/6] Use `DynOrtTensor` for model output tensors Outputs aren't all the same type for a single model, so this allows extracting different types per tensor. --- onnxruntime/examples/issue22.rs | 9 +- onnxruntime/examples/sample.rs | 13 +- onnxruntime/src/lib.rs | 30 ++-- onnxruntime/src/session.rs | 42 +++--- onnxruntime/src/tensor.rs | 2 +- onnxruntime/src/tensor/ort_owned_tensor.rs | 153 ++++++++++++++++++--- onnxruntime/tests/integration_tests.rs | 25 ++-- onnxruntime/tests/string_type.rs | 48 +++++++ test-models/tensorflow/.gitignore | 2 + test-models/tensorflow/Pipfile | 13 ++ test-models/tensorflow/README.md | 18 +++ test-models/tensorflow/src/unique_model.py | 19 +++ test-models/tensorflow/unique_model.onnx | Bin 0 -> 424 bytes 13 files changed, 295 insertions(+), 79 deletions(-) create mode 100644 onnxruntime/tests/string_type.rs create mode 100644 test-models/tensorflow/.gitignore create mode 100644 test-models/tensorflow/Pipfile create mode 100644 test-models/tensorflow/README.md create mode 100644 test-models/tensorflow/src/unique_model.py create mode 100644 test-models/tensorflow/unique_model.onnx diff --git a/onnxruntime/examples/issue22.rs b/onnxruntime/examples/issue22.rs index b2879b91..9dbd5d5b 100644 --- a/onnxruntime/examples/issue22.rs +++ b/onnxruntime/examples/issue22.rs @@ -34,7 +34,12 @@ fn main() { let input_ids = Array2::::from_shape_vec((1, 3), vec![1, 2, 3]).unwrap(); let attention_mask = Array2::::from_shape_vec((1, 3), vec![1, 1, 1]).unwrap(); - let outputs: Vec> = - session.run(vec![input_ids, attention_mask]).unwrap(); + let outputs: Vec> = session + .run(vec![input_ids, attention_mask]) + .unwrap() + .into_iter() + .map(|dyn_tensor| dyn_tensor.try_extract()) + .collect::>() + .unwrap(); print!("outputs: {:#?}", outputs); } diff --git a/onnxruntime/examples/sample.rs b/onnxruntime/examples/sample.rs index d16d08da..3fbc2670 100644 --- a/onnxruntime/examples/sample.rs +++ b/onnxruntime/examples/sample.rs @@ -1,8 +1,10 @@ #![forbid(unsafe_code)] use onnxruntime::{ - environment::Environment, ndarray::Array, tensor::OrtOwnedTensor, GraphOptimizationLevel, - LoggingLevel, + environment::Environment, + ndarray::Array, + tensor::{DynOrtTensor, OrtOwnedTensor}, + GraphOptimizationLevel, LoggingLevel, }; use tracing::Level; use tracing_subscriber::FmtSubscriber; @@ -61,11 +63,12 @@ fn run() -> Result<(), Error> { .unwrap(); let input_tensor_values = vec![array]; - let outputs: Vec> = session.run(input_tensor_values)?; + let outputs: Vec> = session.run(input_tensor_values)?; - assert_eq!(outputs[0].shape(), output0_shape.as_slice()); + let output: OrtOwnedTensor = outputs[0].try_extract().unwrap(); + assert_eq!(output.shape(), output0_shape.as_slice()); for i in 0..5 { - println!("Score for class [{}] = {}", i, outputs[0][[0, i, 0, 0]]); + println!("Score for class [{}] = {}", i, output[[0, i, 0, 0]]); } Ok(()) diff --git a/onnxruntime/src/lib.rs b/onnxruntime/src/lib.rs index 66aa7da6..6ae7c333 100644 --- a/onnxruntime/src/lib.rs +++ b/onnxruntime/src/lib.rs @@ -104,7 +104,10 @@ to download. //! let array = ndarray::Array::linspace(0.0_f32, 1.0, 100); //! // Multiple inputs and outputs are possible //! let input_tensor = vec![array]; -//! let outputs: Vec> = session.run(input_tensor)?; +//! let outputs: Vec> = session.run(input_tensor)? +//! .into_iter() +//! .map(|dyn_tensor| dyn_tensor.try_extract()) +//! .collect::>()?; //! # Ok(()) //! # } //! ``` @@ -115,7 +118,10 @@ to download. //! See the [`sample.rs`](https://github.com/nbigaouette/onnxruntime-rs/blob/master/onnxruntime/examples/sample.rs) //! example for more details. -use std::sync::{atomic::AtomicPtr, Arc, Mutex}; +use std::{ + ffi, ptr, + sync::{atomic::AtomicPtr, Arc, Mutex}, +}; use lazy_static::lazy_static; @@ -142,7 +148,7 @@ lazy_static! { // } as *mut sys::OrtApi))); static ref G_ORT_API: Arc>> = { let base: *const sys::OrtApiBase = unsafe { sys::OrtGetApiBase() }; - assert_ne!(base, std::ptr::null()); + assert_ne!(base, ptr::null()); let get_api: unsafe extern "C" fn(u32) -> *const onnxruntime_sys::OrtApi = unsafe { (*base).GetApi.unwrap() }; let api: *const sys::OrtApi = unsafe { get_api(sys::ORT_API_VERSION) }; @@ -157,13 +163,13 @@ fn g_ort() -> sys::OrtApi { let api_ref_mut: &mut *mut sys::OrtApi = api_ref.get_mut(); let api_ptr_mut: *mut sys::OrtApi = *api_ref_mut; - assert_ne!(api_ptr_mut, std::ptr::null_mut()); + assert_ne!(api_ptr_mut, ptr::null_mut()); unsafe { *api_ptr_mut } } fn char_p_to_string(raw: *const i8) -> Result { - let c_string = unsafe { std::ffi::CStr::from_ptr(raw as *mut i8).to_owned() }; + let c_string = unsafe { ffi::CStr::from_ptr(raw as *mut i8).to_owned() }; match c_string.into_string() { Ok(string) => Ok(string), @@ -176,7 +182,7 @@ mod onnxruntime { //! Module containing a custom logger, used to catch the runtime's own logging and send it //! to Rust's tracing logging instead. - use std::ffi::CStr; + use std::{ffi, ffi::CStr, ptr}; use tracing::{debug, error, info, span, trace, warn, Level}; use onnxruntime_sys as sys; @@ -212,7 +218,7 @@ mod onnxruntime { /// Callback from C that will handle the logging, forwarding the runtime's logs to the tracing crate. pub(crate) extern "C" fn custom_logger( - _params: *mut std::ffi::c_void, + _params: *mut ffi::c_void, severity: sys::OrtLoggingLevel, category: *const i8, logid: *const i8, @@ -227,16 +233,16 @@ mod onnxruntime { sys::OrtLoggingLevel::ORT_LOGGING_LEVEL_FATAL => Level::ERROR, }; - assert_ne!(category, std::ptr::null()); + assert_ne!(category, ptr::null()); let category = unsafe { CStr::from_ptr(category) }; - assert_ne!(code_location, std::ptr::null()); + assert_ne!(code_location, ptr::null()); let code_location = unsafe { CStr::from_ptr(code_location) } .to_str() .unwrap_or("unknown"); - assert_ne!(message, std::ptr::null()); + assert_ne!(message, ptr::null()); let message = unsafe { CStr::from_ptr(message) }; - assert_ne!(logid, std::ptr::null()); + assert_ne!(logid, ptr::null()); let logid = unsafe { CStr::from_ptr(logid) }; // Parse the code location @@ -376,7 +382,7 @@ mod test { #[test] fn test_char_p_to_string() { - let s = std::ffi::CString::new("foo").unwrap(); + let s = ffi::CString::new("foo").unwrap(); let ptr = s.as_c_str().as_ptr(); assert_eq!("foo", char_p_to_string(ptr).unwrap()); } diff --git a/onnxruntime/src/session.rs b/onnxruntime/src/session.rs index d212111e..232d188d 100644 --- a/onnxruntime/src/session.rs +++ b/onnxruntime/src/session.rs @@ -21,10 +21,7 @@ use crate::{ error::{call_ort, status_to_result, NonMatchingDimensionsError, OrtError, Result}, g_ort, memory::MemoryInfo, - tensor::{ - ort_owned_tensor::OrtOwnedTensor, OrtTensor, TensorDataToType, TensorElementDataType, - TypeToTensorElementDataType, - }, + tensor::{DynOrtTensor, OrtTensor, TensorElementDataType, TypeToTensorElementDataType}, AllocatorType, GraphOptimizationLevel, MemType, }; @@ -364,13 +361,12 @@ impl<'a> Session<'a> { /// /// Note that ONNX models can have multiple inputs; a `Vec<_>` is thus /// used for the input data here. - pub fn run<'s, 't, 'm, TIn, TOut, D>( + pub fn run<'s, 't, 'm, TIn, D>( &'s mut self, input_arrays: Vec>, - ) -> Result>> + ) -> Result>> where TIn: TypeToTensorElementDataType + Debug + Clone, - TOut: TensorDataToType, D: ndarray::Dimension, 'm: 't, // 'm outlives 't (memory info outlives tensor) 's: 'm, // 's outlives 'm (session outlives memory info) @@ -404,7 +400,7 @@ impl<'a> Session<'a> { .map(|n| n.as_ptr() as *const i8) .collect(); - let mut output_tensor_extractors_ptrs: Vec<*mut sys::OrtValue> = + let mut output_tensor_ptrs: Vec<*mut sys::OrtValue> = vec![std::ptr::null_mut(); self.outputs.len()]; // The C API expects pointers for the arrays (pointers to C-arrays) @@ -430,38 +426,32 @@ impl<'a> Session<'a> { input_ort_values.len() as u64, // C API expects a u64, not isize output_names_ptr.as_ptr(), output_names_ptr.len() as u64, // C API expects a u64, not isize - output_tensor_extractors_ptrs.as_mut_ptr(), + output_tensor_ptrs.as_mut_ptr(), ) }; status_to_result(status).map_err(OrtError::Run)?; let memory_info_ref = &self.memory_info; - let outputs: Result>>> = - output_tensor_extractors_ptrs + let outputs: Result>>> = + output_tensor_ptrs .into_iter() .map(|tensor_ptr| { - let dims = unsafe { + let (dims, data_type) = unsafe { call_with_tensor_info(tensor_ptr, |tensor_info_ptr| { get_tensor_dimensions(tensor_info_ptr) .map(|dims| dims.iter().map(|&n| n as usize).collect::>()) + .and_then(|dims| { + extract_data_type(tensor_info_ptr) + .map(|data_type| (dims, data_type)) + }) }) }?; - // Note: Both tensor and array will point to the same data, nothing is copied. - // As such, there is no need to free the pointer used to create the ArrayView. - assert_ne!(tensor_ptr, std::ptr::null_mut()); - - let mut is_tensor = 0; - unsafe { call_ort(|ort| ort.IsTensor.unwrap()(tensor_ptr, &mut is_tensor)) } - .map_err(OrtError::IsTensor)?; - assert_eq!(is_tensor, 1); - - let array_view = TOut::extract_array(ndarray::IxDyn(&dims), tensor_ptr)?; - - Ok(OrtOwnedTensor::new( + Ok(DynOrtTensor::new( tensor_ptr, - array_view, - &memory_info_ref, + memory_info_ref, + ndarray::IxDyn(&dims), + data_type, )) }) .collect(); diff --git a/onnxruntime/src/tensor.rs b/onnxruntime/src/tensor.rs index a5178c91..df85e1ed 100644 --- a/onnxruntime/src/tensor.rs +++ b/onnxruntime/src/tensor.rs @@ -27,7 +27,7 @@ pub mod ndarray_tensor; pub mod ort_owned_tensor; pub mod ort_tensor; -pub use ort_owned_tensor::OrtOwnedTensor; +pub use ort_owned_tensor::{DynOrtTensor, OrtOwnedTensor}; pub use ort_tensor::OrtTensor; use crate::{OrtError, Result}; diff --git a/onnxruntime/src/tensor/ort_owned_tensor.rs b/onnxruntime/src/tensor/ort_owned_tensor.rs index bfff3dea..48f48308 100644 --- a/onnxruntime/src/tensor/ort_owned_tensor.rs +++ b/onnxruntime/src/tensor/ort_owned_tensor.rs @@ -1,14 +1,122 @@ //! Module containing tensor with memory owned by the ONNX Runtime -use std::{fmt::Debug, ops::Deref}; +use std::{fmt::Debug, ops::Deref, ptr, rc, result}; use ndarray::{Array, ArrayView}; +use thiserror::Error; use tracing::debug; use onnxruntime_sys as sys; -use crate::tensor::TensorDataToType; -use crate::{g_ort, memory::MemoryInfo, tensor::ndarray_tensor::NdArrayTensor}; +use crate::{ + error::call_ort, + g_ort, + memory::MemoryInfo, + tensor::{ndarray_tensor::NdArrayTensor, TensorDataToType, TensorElementDataType}, + OrtError, +}; + +/// Errors that can occur while extracting a tensor from ort output. +#[derive(Error, Debug)] +pub enum TensorExtractError { + /// The user tried to extract the wrong type of tensor from the underlying data + #[error( + "Data type mismatch: was {:?}, tried to convert to {:?}", + actual, + requested + )] + DataTypeMismatch { + /// The actual type of the ort output + actual: TensorElementDataType, + /// The type corresponding to the attempted conversion into a Rust type, not equal to `actual` + requested: TensorElementDataType, + }, + /// An onnxruntime error occurred + #[error("Onnxruntime error: {:?}", 0)] + OrtError(#[from] OrtError), +} + +/// A wrapper around a tensor produced by onnxruntime inference. +/// +/// Since different outputs for the same model can have different types, this type is used to allow +/// the user to dynamically query each output's type and extract the appropriate tensor type with +/// [try_extract]. +#[derive(Debug)] +pub struct DynOrtTensor<'m, D> +where + D: ndarray::Dimension, +{ + tensor_ptr_holder: rc::Rc, + memory_info: &'m MemoryInfo, + shape: D, + data_type: TensorElementDataType, +} + +impl<'m, D> DynOrtTensor<'m, D> +where + D: ndarray::Dimension, +{ + pub(crate) fn new( + tensor_ptr: *mut sys::OrtValue, + memory_info: &'m MemoryInfo, + shape: D, + data_type: TensorElementDataType, + ) -> DynOrtTensor<'m, D> { + DynOrtTensor { + tensor_ptr_holder: rc::Rc::from(TensorPointerDropper { tensor_ptr }), + memory_info, + shape, + data_type, + } + } + + /// The ONNX data type this tensor contains. + pub fn data_type(&self) -> TensorElementDataType { + self.data_type + } + + /// Extract a tensor containing `T`. + /// + /// Where the type permits it, the tensor will be a view into existing memory. + /// + /// # Errors + /// + /// An error will be returned if `T`'s ONNX type doesn't match this tensor's type, or if an + /// onnxruntime error occurs. + pub fn try_extract<'t, T>(&self) -> result::Result, TensorExtractError> + where + T: TensorDataToType + Clone + Debug, + 'm: 't, // mem info outlives tensor + { + if self.data_type != T::tensor_element_data_type() { + Err(TensorExtractError::DataTypeMismatch { + actual: self.data_type, + requested: T::tensor_element_data_type(), + }) + } else { + // Note: Both tensor and array will point to the same data, nothing is copied. + // As such, there is no need to free the pointer used to create the ArrayView. + assert_ne!(self.tensor_ptr_holder.tensor_ptr, ptr::null_mut()); + + let mut is_tensor = 0; + unsafe { + call_ort(|ort| { + ort.IsTensor.unwrap()(self.tensor_ptr_holder.tensor_ptr, &mut is_tensor) + }) + } + .map_err(OrtError::IsTensor)?; + assert_eq!(is_tensor, 1); + + let array_view = + T::extract_array(self.shape.clone(), self.tensor_ptr_holder.tensor_ptr)?; + + Ok(OrtOwnedTensor::new( + self.tensor_ptr_holder.clone(), + array_view, + )) + } + } +} /// Tensor containing data owned by the ONNX Runtime C library, used to return values from inference. /// @@ -21,18 +129,17 @@ use crate::{g_ort, memory::MemoryInfo, tensor::ndarray_tensor::NdArrayTensor}; /// `OrtOwnedTensor` implements the [`std::deref::Deref`](#impl-Deref) trait for ergonomic access to /// the underlying [`ndarray::ArrayView`](https://docs.rs/ndarray/latest/ndarray/type.ArrayView.html). #[derive(Debug)] -pub struct OrtOwnedTensor<'t, 'm, T, D> +pub struct OrtOwnedTensor<'t, T, D> where T: TensorDataToType, D: ndarray::Dimension, - 'm: 't, // 'm outlives 't { - pub(crate) tensor_ptr: *mut sys::OrtValue, + /// Keep the pointer alive + tensor_ptr_holder: rc::Rc, array_view: ArrayView<'t, T, D>, - memory_info: &'m MemoryInfo, } -impl<'t, 'm, T, D> Deref for OrtOwnedTensor<'t, 'm, T, D> +impl<'t, T, D> Deref for OrtOwnedTensor<'t, T, D> where T: TensorDataToType, D: ndarray::Dimension, @@ -44,20 +151,18 @@ where } } -impl<'t, 'm, T, D> OrtOwnedTensor<'t, 'm, T, D> +impl<'t, T, D> OrtOwnedTensor<'t, T, D> where T: TensorDataToType, D: ndarray::Dimension, { pub(crate) fn new( - tensor_ptr: *mut sys::OrtValue, + tensor_ptr_holder: rc::Rc, array_view: ArrayView<'t, T, D>, - memory_info: &'m MemoryInfo, - ) -> OrtOwnedTensor<'t, 'm, T, D> { + ) -> OrtOwnedTensor<'t, T, D> { OrtOwnedTensor { - tensor_ptr, + tensor_ptr_holder, array_view, - memory_info, } } @@ -71,17 +176,23 @@ where } } -impl<'t, 'm, T, D> Drop for OrtOwnedTensor<'t, 'm, T, D> -where - T: TensorDataToType, - D: ndarray::Dimension, - 'm: 't, // 'm outlives 't -{ +/// Holds on to a tensor pointer until dropped. +/// +/// This allows creating an [OrtOwnedTensor] from a [DynOrtTensor] without consuming `self`, which +/// would prevent retrying extraction and also make interacting with outputs `Vec` awkward. +/// It also avoids needing `OrtOwnedTensor` to keep a reference to `DynOrtTensor`, which would be +/// inconvenient. +#[derive(Debug)] +pub(crate) struct TensorPointerDropper { + tensor_ptr: *mut sys::OrtValue, +} + +impl Drop for TensorPointerDropper { #[tracing::instrument] fn drop(&mut self) { debug!("Dropping OrtOwnedTensor."); unsafe { g_ort().ReleaseValue.unwrap()(self.tensor_ptr) } - self.tensor_ptr = std::ptr::null_mut(); + self.tensor_ptr = ptr::null_mut(); } } diff --git a/onnxruntime/tests/integration_tests.rs b/onnxruntime/tests/integration_tests.rs index ee531feb..2a2ea164 100644 --- a/onnxruntime/tests/integration_tests.rs +++ b/onnxruntime/tests/integration_tests.rs @@ -15,6 +15,7 @@ mod download { use onnxruntime::{ download::vision::{DomainBasedImageClassification, ImageClassification}, environment::Environment, + tensor::{DynOrtTensor, OrtOwnedTensor}, GraphOptimizationLevel, LoggingLevel, }; @@ -93,13 +94,13 @@ mod download { let input_tensor_values = vec![array]; // Perform the inference - let outputs: Vec< - onnxruntime::tensor::OrtOwnedTensor>, - > = session.run(input_tensor_values).unwrap(); + let outputs: Vec>> = + session.run(input_tensor_values).unwrap(); // Downloaded model does not have a softmax as final layer; call softmax on second axis // and iterate on resulting probabilities, creating an index to later access labels. - let mut probabilities: Vec<(usize, f32)> = outputs[0] + let output: OrtOwnedTensor<_, _> = outputs[0].try_extract().unwrap(); + let mut probabilities: Vec<(usize, f32)> = output .softmax(ndarray::Axis(1)) .into_iter() .copied() @@ -184,11 +185,11 @@ mod download { let input_tensor_values = vec![array]; // Perform the inference - let outputs: Vec< - onnxruntime::tensor::OrtOwnedTensor>, - > = session.run(input_tensor_values).unwrap(); + let outputs: Vec>> = + session.run(input_tensor_values).unwrap(); - let mut probabilities: Vec<(usize, f32)> = outputs[0] + let output: OrtOwnedTensor<_, _> = outputs[0].try_extract().unwrap(); + let mut probabilities: Vec<(usize, f32)> = output .softmax(ndarray::Axis(1)) .into_iter() .copied() @@ -282,12 +283,12 @@ mod download { let input_tensor_values = vec![array]; // Perform the inference - let outputs: Vec< - onnxruntime::tensor::OrtOwnedTensor>, - > = session.run(input_tensor_values).unwrap(); + let outputs: Vec>> = + session.run(input_tensor_values).unwrap(); assert_eq!(outputs.len(), 1); - let output = &outputs[0]; + let output: OrtOwnedTensor<'_, f32, ndarray::Dim> = + outputs[0].try_extract().unwrap(); // The image should have doubled in size assert_eq!(output.shape(), [1, 448, 448, 3]); diff --git a/onnxruntime/tests/string_type.rs b/onnxruntime/tests/string_type.rs new file mode 100644 index 00000000..fe4c0da9 --- /dev/null +++ b/onnxruntime/tests/string_type.rs @@ -0,0 +1,48 @@ +use std::error::Error; + +use ndarray; +use onnxruntime::tensor::{OrtOwnedTensor, TensorElementDataType}; +use onnxruntime::{environment::Environment, tensor::DynOrtTensor, LoggingLevel}; + +#[test] +fn run_model_with_string_input_output() -> Result<(), Box> { + let environment = Environment::builder() + .with_name("test") + .with_log_level(LoggingLevel::Verbose) + .build()?; + + let mut session = environment + .new_session_builder()? + .with_model_from_file("../test-models/tensorflow/unique_model.onnx")?; + + // Inputs: + // 0: + // name = input_1:0 + // type = String + // dimensions = [None] + // Outputs: + // 0: + // name = Identity:0 + // type = Int32 + // dimensions = [None] + // 1: + // name = Identity_1:0 + // type = String + // dimensions = [None] + + let array = ndarray::Array::from(vec!["foo", "bar", "foo", "foo"]); + let input_tensor_values = vec![array]; + + let outputs: Vec> = session.run(input_tensor_values)?; + + assert_eq!(TensorElementDataType::Int32, outputs[0].data_type()); + assert_eq!(TensorElementDataType::String, outputs[1].data_type()); + + let int_output: OrtOwnedTensor = outputs[0].try_extract()?; + + assert_eq!(&[0, 1, 0, 0], int_output.as_slice().unwrap()); + + // TODO get the string output once string extraction is implemented + + Ok(()) +} diff --git a/test-models/tensorflow/.gitignore b/test-models/tensorflow/.gitignore new file mode 100644 index 00000000..aea6a084 --- /dev/null +++ b/test-models/tensorflow/.gitignore @@ -0,0 +1,2 @@ +/Pipfile.lock +/models diff --git a/test-models/tensorflow/Pipfile b/test-models/tensorflow/Pipfile new file mode 100644 index 00000000..a7b370ab --- /dev/null +++ b/test-models/tensorflow/Pipfile @@ -0,0 +1,13 @@ +[[source]] +url = "https://pypi.org/simple" +verify_ssl = true +name = "pypi" + +[packages] +tensorflow = "==2.4.1" +tf2onnx = "==1.8.3" + +[dev-packages] + +[requires] +python_version = "3.8" diff --git a/test-models/tensorflow/README.md b/test-models/tensorflow/README.md new file mode 100644 index 00000000..4f2e68f2 --- /dev/null +++ b/test-models/tensorflow/README.md @@ -0,0 +1,18 @@ +# Setup + +Have Pipenv make the virtualenv for you: + +``` +pipenv install +``` + +# Model: Unique + +A TensorFlow model that removes duplicate tensor elements. + +This supports strings, and doesn't require custom operators. + +``` +pipenv run python src/unique_model.py +pipenv run python -m tf2onnx.convert --saved-model models/unique_model --output unique_model.onnx --opset 11 +``` diff --git a/test-models/tensorflow/src/unique_model.py b/test-models/tensorflow/src/unique_model.py new file mode 100644 index 00000000..fb79dc8b --- /dev/null +++ b/test-models/tensorflow/src/unique_model.py @@ -0,0 +1,19 @@ +import tensorflow as tf +import numpy as np +import tf2onnx + + +class UniqueModel(tf.keras.Model): + + def __init__(self, name='model1', **kwargs): + super(UniqueModel, self).__init__(name=name, **kwargs) + + def call(self, inputs): + return tf.unique(inputs) + + +model1 = UniqueModel() + +print(model1(tf.constant(["foo", "bar", "foo", "baz"]))) + +model1.save("models/unique_model") diff --git a/test-models/tensorflow/unique_model.onnx b/test-models/tensorflow/unique_model.onnx new file mode 100644 index 0000000000000000000000000000000000000000..320f620042f475636b519f6647861c249970b126 GIT binary patch literal 424 zcmd;J6Jjq(Gs@4)tB_(f)U(txu#n#{<>nVDBmS`u$)Wgx`knUb1Ul37v-=E(;n z76Cb#`FW`+&WSlW`nmZjsX2!Fp?R5wrKwhiLIe~TNy*|hL5U4wk`^x)TXB9-NotA& z!vaPoD+aDKE<#2NalxEwWgw-7*P8fvL&Nyw#NrYq7H1%-#mU7~k}tsqvxAZeOjzv3ifF{m=`6798!W@{2UxY++3VO?4^0xKr4)s Date: Mon, 8 Mar 2021 11:54:50 -0700 Subject: [PATCH 3/6] Support the onnx string type in output tensors This approach allocates owned Strings for each element, which works, but stresses the allocator, and incurs unnecessary copying. Part of the complication stems from the limitation that in Rust, a field can't be a reference to another field in the same struct. This means that having a Vec of copied data, referred to by a Vec<&str>, which is then referred to by an ArrayView, requires a sequence of 3 structs to express. Building a Vec gets rid of the references, but also loses the efficiency of 1 allocation with strs pointing into it. --- onnxruntime/examples/sample.rs | 4 +- onnxruntime/src/error.rs | 14 ++- onnxruntime/src/session.rs | 23 +++- onnxruntime/src/tensor.rs | 119 +++++++++++++++++++-- onnxruntime/src/tensor/ort_owned_tensor.rs | 105 +++++++++++------- onnxruntime/tests/integration_tests.rs | 11 +- onnxruntime/tests/string_type.rs | 21 ++-- 7 files changed, 236 insertions(+), 61 deletions(-) diff --git a/onnxruntime/examples/sample.rs b/onnxruntime/examples/sample.rs index 3fbc2670..5563ab4b 100644 --- a/onnxruntime/examples/sample.rs +++ b/onnxruntime/examples/sample.rs @@ -66,9 +66,9 @@ fn run() -> Result<(), Error> { let outputs: Vec> = session.run(input_tensor_values)?; let output: OrtOwnedTensor = outputs[0].try_extract().unwrap(); - assert_eq!(output.shape(), output0_shape.as_slice()); + assert_eq!(output.view().shape(), output0_shape.as_slice()); for i in 0..5 { - println!("Score for class [{}] = {}", i, output[[0, i, 0, 0]]); + println!("Score for class [{}] = {}", i, output.view()[[0, i, 0, 0]]); } Ok(()) diff --git a/onnxruntime/src/error.rs b/onnxruntime/src/error.rs index f49613fe..86280f74 100644 --- a/onnxruntime/src/error.rs +++ b/onnxruntime/src/error.rs @@ -1,6 +1,6 @@ //! Module containing error definitions. -use std::{io, path::PathBuf}; +use std::{io, path::PathBuf, string}; use thiserror::Error; @@ -53,6 +53,12 @@ pub enum OrtError { /// Error occurred when getting ONNX dimensions #[error("Failed to get dimensions: {0}")] GetDimensions(OrtApiError), + /// Error occurred when getting string length + #[error("Failed to get string tensor length: {0}")] + GetStringTensorDataLength(OrtApiError), + /// Error occurred when getting tensor element count + #[error("Failed to get tensor element count: {0}")] + GetTensorShapeElementCount(OrtApiError), /// Error occurred when creating CPU memory information #[error("Failed to get dimensions: {0}")] CreateCpuMemoryInfo(OrtApiError), @@ -77,6 +83,12 @@ pub enum OrtError { /// Error occurred when extracting data from an ONNX tensor into an C array to be used as an `ndarray::ArrayView` #[error("Failed to get tensor data: {0}")] GetTensorMutableData(OrtApiError), + /// Error occurred when extracting string data from an ONNX tensor + #[error("Failed to get tensor string data: {0}")] + GetStringTensorContent(OrtApiError), + /// Error occurred when converting data to a String + #[error("Data was not UTF-8: {0}")] + StringFromUtf8Error(#[from] string::FromUtf8Error), /// Error occurred when downloading a pre-trained ONNX model from the [ONNX Model Zoo](https://github.com/onnx/models) #[error("Failed to download ONNX model: {0}")] diff --git a/onnxruntime/src/session.rs b/onnxruntime/src/session.rs index 232d188d..98099fd6 100644 --- a/onnxruntime/src/session.rs +++ b/onnxruntime/src/session.rs @@ -1,6 +1,6 @@ //! Module containing session types -use std::{ffi::CString, fmt::Debug, path::Path}; +use std::{convert::TryInto as _, ffi::CString, fmt::Debug, path::Path}; #[cfg(not(target_family = "windows"))] use std::os::unix::ffi::OsStrExt; @@ -436,7 +436,7 @@ impl<'a> Session<'a> { output_tensor_ptrs .into_iter() .map(|tensor_ptr| { - let (dims, data_type) = unsafe { + let (dims, data_type, len) = unsafe { call_with_tensor_info(tensor_ptr, |tensor_info_ptr| { get_tensor_dimensions(tensor_info_ptr) .map(|dims| dims.iter().map(|&n| n as usize).collect::>()) @@ -444,6 +444,24 @@ impl<'a> Session<'a> { extract_data_type(tensor_info_ptr) .map(|data_type| (dims, data_type)) }) + .and_then(|(dims, data_type)| { + let mut len = 0_u64; + + call_ort(|ort| { + ort.GetTensorShapeElementCount.unwrap()( + tensor_info_ptr, + &mut len, + ) + }) + .map_err(OrtError::GetTensorShapeElementCount)?; + + Ok(( + dims, + data_type, + len.try_into() + .expect("u64 length could not fit into usize"), + )) + }) }) }?; @@ -451,6 +469,7 @@ impl<'a> Session<'a> { tensor_ptr, memory_info_ref, ndarray::IxDyn(&dims), + len, data_type, )) }) diff --git a/onnxruntime/src/tensor.rs b/onnxruntime/src/tensor.rs index df85e1ed..74e8329c 100644 --- a/onnxruntime/src/tensor.rs +++ b/onnxruntime/src/tensor.rs @@ -30,9 +30,10 @@ pub mod ort_tensor; pub use ort_owned_tensor::{DynOrtTensor, OrtOwnedTensor}; pub use ort_tensor::OrtTensor; -use crate::{OrtError, Result}; +use crate::tensor::ort_owned_tensor::TensorPointerHolder; +use crate::{error::call_ort, OrtError, Result}; use onnxruntime_sys::{self as sys, OnnxEnumInt}; -use std::{fmt, ptr}; +use std::{convert::TryInto as _, ffi, fmt, ptr, rc, result, string}; // FIXME: Use https://docs.rs/bindgen/0.54.1/bindgen/struct.Builder.html#method.rustified_enum // FIXME: Add tests to cover the commented out types @@ -188,14 +189,41 @@ pub trait TensorDataToType: Sized + fmt::Debug { fn tensor_element_data_type() -> TensorElementDataType; /// Extract an `ArrayView` from the ort-owned tensor. - fn extract_array<'t, D>( + fn extract_data<'t, D>( shape: D, - tensor: *mut sys::OrtValue, - ) -> Result> + tensor_element_len: usize, + tensor_ptr: rc::Rc, + ) -> Result> where D: ndarray::Dimension; } +/// Represents the possible ways tensor data can be accessed. +/// +/// This should only be used internally. +#[derive(Debug)] +pub enum TensorData<'t, T, D> +where + D: ndarray::Dimension, +{ + /// Data resides in ort's tensor, in which case the 't lifetime is what makes this valid. + /// This is used for data types whose in-memory form from ort is compatible with Rust's, like + /// primitive numeric types. + TensorPtr { + /// The pointer ort produced. Kept alive so that `array_view` is valid. + ptr: rc::Rc, + /// A view into `ptr` + array_view: ndarray::ArrayView<'t, T, D>, + }, + /// String data is output differently by ort, and of course is also variable size, so it cannot + /// use the same simple pointer representation. + // Since 't outlives this struct, the 't lifetime is more than we need, but no harm done. + Strings { + /// Owned Strings copied out of ort's output + strings: ndarray::Array, + }, +} + /// Implements `OwnedTensorDataToType` for primitives, which can use `GetTensorMutableData` macro_rules! impl_prim_type_from_ort_trait { ($type_:ty, $variant:ident) => { @@ -204,14 +232,20 @@ macro_rules! impl_prim_type_from_ort_trait { TensorElementDataType::$variant } - fn extract_array<'t, D>( + fn extract_data<'t, D>( shape: D, - tensor: *mut sys::OrtValue, - ) -> Result> + _tensor_element_len: usize, + tensor_ptr: rc::Rc, + ) -> Result> where D: ndarray::Dimension, { - extract_primitive_array(shape, tensor) + extract_primitive_array(shape, tensor_ptr.tensor_ptr).map(|v| { + TensorData::TensorPtr { + ptr: tensor_ptr, + array_view: v, + } + }) } } }; @@ -255,3 +289,70 @@ impl_prim_type_from_ort_trait!(i64, Int64); impl_prim_type_from_ort_trait!(f64, Double); impl_prim_type_from_ort_trait!(u32, Uint32); impl_prim_type_from_ort_trait!(u64, Uint64); + +impl TensorDataToType for String { + fn tensor_element_data_type() -> TensorElementDataType { + TensorElementDataType::String + } + + fn extract_data<'t, D: ndarray::Dimension>( + shape: D, + tensor_element_len: usize, + tensor_ptr: rc::Rc, + ) -> Result> { + // Total length of string data, not including \0 suffix + let mut total_length = 0_u64; + unsafe { + call_ort(|ort| { + ort.GetStringTensorDataLength.unwrap()(tensor_ptr.tensor_ptr, &mut total_length) + }) + .map_err(OrtError::GetStringTensorDataLength)? + } + + // In the JNI impl of this, tensor_element_len was included in addition to total_length, + // but that seems contrary to the docs of GetStringTensorDataLength, and those extra bytes + // don't seem to be written to in practice either. + // If the string data actually did go farther, it would panic below when using the offset + // data to get slices for each string. + let mut string_contents = vec![0_u8; total_length as usize]; + // one extra slot so that the total length can go in the last one, making all per-string + // length calculations easy + let mut offsets = vec![0_u64; tensor_element_len as usize + 1]; + + unsafe { + call_ort(|ort| { + ort.GetStringTensorContent.unwrap()( + tensor_ptr.tensor_ptr, + string_contents.as_mut_ptr() as *mut ffi::c_void, + total_length, + offsets.as_mut_ptr(), + tensor_element_len as u64, + ) + }) + .map_err(OrtError::GetStringTensorContent)? + } + + // final offset = overall length so that per-string length calculations work for the last + // string + debug_assert_eq!(0, offsets[tensor_element_len]); + offsets[tensor_element_len] = total_length; + + let strings = offsets + // offsets has 1 extra offset past the end so that all windows work + .windows(2) + .map(|w| { + let start: usize = w[0].try_into().expect("Offset didn't fit into usize"); + let next_start: usize = w[1].try_into().expect("Offset didn't fit into usize"); + + let slice = &string_contents[start..next_start]; + String::from_utf8(slice.into()) + }) + .collect::, string::FromUtf8Error>>() + .map_err(OrtError::StringFromUtf8Error)?; + + let array = ndarray::Array::from_shape_vec(shape, strings) + .expect("Shape extracted from tensor didn't match tensor contents"); + + Ok(TensorData::Strings { strings: array }) + } +} diff --git a/onnxruntime/src/tensor/ort_owned_tensor.rs b/onnxruntime/src/tensor/ort_owned_tensor.rs index 48f48308..f782df1b 100644 --- a/onnxruntime/src/tensor/ort_owned_tensor.rs +++ b/onnxruntime/src/tensor/ort_owned_tensor.rs @@ -2,7 +2,7 @@ use std::{fmt::Debug, ops::Deref, ptr, rc, result}; -use ndarray::{Array, ArrayView}; +use ndarray::ArrayView; use thiserror::Error; use tracing::debug; @@ -12,7 +12,7 @@ use crate::{ error::call_ort, g_ort, memory::MemoryInfo, - tensor::{ndarray_tensor::NdArrayTensor, TensorDataToType, TensorElementDataType}, + tensor::{TensorData, TensorDataToType, TensorElementDataType}, OrtError, }; @@ -46,9 +46,12 @@ pub struct DynOrtTensor<'m, D> where D: ndarray::Dimension, { - tensor_ptr_holder: rc::Rc, + // TODO could this also hold a Vec for strings so that the extracted tensor could then + // hold a Vec<&str>? + tensor_ptr_holder: rc::Rc, memory_info: &'m MemoryInfo, shape: D, + tensor_element_len: usize, data_type: TensorElementDataType, } @@ -60,12 +63,14 @@ where tensor_ptr: *mut sys::OrtValue, memory_info: &'m MemoryInfo, shape: D, + tensor_element_len: usize, data_type: TensorElementDataType, ) -> DynOrtTensor<'m, D> { DynOrtTensor { - tensor_ptr_holder: rc::Rc::from(TensorPointerDropper { tensor_ptr }), + tensor_ptr_holder: rc::Rc::from(TensorPointerHolder { tensor_ptr }), memory_info, shape, + tensor_element_len, data_type, } } @@ -87,6 +92,8 @@ where where T: TensorDataToType + Clone + Debug, 'm: 't, // mem info outlives tensor + D: 't, // not clear why this is needed since we clone the shape, but it doesn't make + // a difference in practice since the shape is extracted from the tensor { if self.data_type != T::tensor_element_data_type() { Err(TensorExtractError::DataTypeMismatch { @@ -107,13 +114,13 @@ where .map_err(OrtError::IsTensor)?; assert_eq!(is_tensor, 1); - let array_view = - T::extract_array(self.shape.clone(), self.tensor_ptr_holder.tensor_ptr)?; + let data = T::extract_data( + self.shape.clone(), + self.tensor_element_len, + rc::Rc::clone(&self.tensor_ptr_holder), + )?; - Ok(OrtOwnedTensor::new( - self.tensor_ptr_holder.clone(), - array_view, - )) + Ok(OrtOwnedTensor { data }) } } } @@ -134,45 +141,69 @@ where T: TensorDataToType, D: ndarray::Dimension, { - /// Keep the pointer alive - tensor_ptr_holder: rc::Rc, - array_view: ArrayView<'t, T, D>, + data: TensorData<'t, T, D>, } -impl<'t, T, D> Deref for OrtOwnedTensor<'t, T, D> +impl<'t, T, D> OrtOwnedTensor<'t, T, D> where T: TensorDataToType, - D: ndarray::Dimension, + D: ndarray::Dimension + 't, { - type Target = ArrayView<'t, T, D>; - - fn deref(&self) -> &Self::Target { - &self.array_view + /// Produce a [ViewHolder] for the underlying data, which + pub fn view<'s>(&'s self) -> ViewHolder<'s, T, D> + where + 't: 's, // tensor ptr can outlive the TensorData + { + ViewHolder::new(&self.data) } } -impl<'t, T, D> OrtOwnedTensor<'t, T, D> +/// An intermediate step on the way to an ArrayView. +// Since Deref has to produce a reference, and the referent can't be a local in deref(), it must +// be a field in a struct. This struct exists only to hold that field. +// Its lifetime 's is bound to the TensorData its view was created around, not the underlying tensor +// pointer, since in the case of strings the data is the Array in the TensorData, not the pointer. +pub struct ViewHolder<'s, T, D> where T: TensorDataToType, D: ndarray::Dimension, { - pub(crate) fn new( - tensor_ptr_holder: rc::Rc, - array_view: ArrayView<'t, T, D>, - ) -> OrtOwnedTensor<'t, T, D> { - OrtOwnedTensor { - tensor_ptr_holder, - array_view, - } - } + array_view: ndarray::ArrayView<'s, T, D>, +} - /// Apply a softmax on the specified axis - pub fn softmax(&self, axis: ndarray::Axis) -> Array +impl<'s, T, D> ViewHolder<'s, T, D> +where + T: TensorDataToType, + D: ndarray::Dimension, +{ + fn new<'t>(data: &'s TensorData<'t, T, D>) -> ViewHolder<'s, T, D> where - D: ndarray::RemoveAxis, - T: ndarray::NdFloat + std::ops::SubAssign + std::ops::DivAssign, + 't: 's, // underlying tensor ptr lives at least as long as TensorData { - self.array_view.softmax(axis) + match data { + TensorData::TensorPtr { array_view, .. } => ViewHolder { + // we already have a view, but creating a view from a view is cheap + array_view: array_view.view(), + }, + TensorData::Strings { strings } => ViewHolder { + // This view creation has to happen here, not at new()'s callsite, because + // a field can't be a reference to another field in the same struct. Thus, we have + // this separate struct to hold the view that refers to the `Array`. + array_view: strings.view(), + }, + } + } +} + +impl<'t, T, D> Deref for ViewHolder<'t, T, D> +where + T: TensorDataToType, + D: ndarray::Dimension, +{ + type Target = ArrayView<'t, T, D>; + + fn deref(&self) -> &Self::Target { + &self.array_view } } @@ -183,11 +214,11 @@ where /// It also avoids needing `OrtOwnedTensor` to keep a reference to `DynOrtTensor`, which would be /// inconvenient. #[derive(Debug)] -pub(crate) struct TensorPointerDropper { - tensor_ptr: *mut sys::OrtValue, +pub struct TensorPointerHolder { + pub(crate) tensor_ptr: *mut sys::OrtValue, } -impl Drop for TensorPointerDropper { +impl Drop for TensorPointerHolder { #[tracing::instrument] fn drop(&mut self) { debug!("Dropping OrtOwnedTensor."); diff --git a/onnxruntime/tests/integration_tests.rs b/onnxruntime/tests/integration_tests.rs index 2a2ea164..c332e7ce 100644 --- a/onnxruntime/tests/integration_tests.rs +++ b/onnxruntime/tests/integration_tests.rs @@ -12,6 +12,7 @@ mod download { use ndarray::s; use test_env_log::test; + use onnxruntime::tensor::ndarray_tensor::NdArrayTensor; use onnxruntime::{ download::vision::{DomainBasedImageClassification, ImageClassification}, environment::Environment, @@ -63,7 +64,7 @@ mod download { input0_shape[3] as u32, FilterType::Nearest, ) - .to_rgb(); + .to_rgb8(); // Python: // # image[y, x, RGB] @@ -101,6 +102,7 @@ mod download { // and iterate on resulting probabilities, creating an index to later access labels. let output: OrtOwnedTensor<_, _> = outputs[0].try_extract().unwrap(); let mut probabilities: Vec<(usize, f32)> = output + .view() .softmax(ndarray::Axis(1)) .into_iter() .copied() @@ -171,7 +173,7 @@ mod download { input0_shape[3] as u32, FilterType::Nearest, ) - .to_luma(); + .to_luma8(); let array = ndarray::Array::from_shape_fn((1, 1, 28, 28), |(_, c, j, i)| { let pixel = image_buffer.get_pixel(i as u32, j as u32); @@ -190,6 +192,7 @@ mod download { let output: OrtOwnedTensor<_, _> = outputs[0].try_extract().unwrap(); let mut probabilities: Vec<(usize, f32)> = output + .view() .softmax(ndarray::Axis(1)) .into_iter() .copied() @@ -269,7 +272,7 @@ mod download { .join(IMAGE_TO_LOAD), ) .unwrap() - .to_rgb(); + .to_rgb8(); let array = ndarray::Array::from_shape_fn((1, 224, 224, 3), |(_, j, i, c)| { let pixel = image_buffer.get_pixel(i as u32, j as u32); @@ -291,7 +294,7 @@ mod download { outputs[0].try_extract().unwrap(); // The image should have doubled in size - assert_eq!(output.shape(), [1, 448, 448, 3]); + assert_eq!(output.view().shape(), [1, 448, 448, 3]); } } diff --git a/onnxruntime/tests/string_type.rs b/onnxruntime/tests/string_type.rs index fe4c0da9..e07fe5f6 100644 --- a/onnxruntime/tests/string_type.rs +++ b/onnxruntime/tests/string_type.rs @@ -5,7 +5,7 @@ use onnxruntime::tensor::{OrtOwnedTensor, TensorElementDataType}; use onnxruntime::{environment::Environment, tensor::DynOrtTensor, LoggingLevel}; #[test] -fn run_model_with_string_input_output() -> Result<(), Box> { +fn run_model_with_string_1d_input_output() -> Result<(), Box> { let environment = Environment::builder() .with_name("test") .with_log_level(LoggingLevel::Verbose) @@ -30,7 +30,7 @@ fn run_model_with_string_input_output() -> Result<(), Box> { // type = String // dimensions = [None] - let array = ndarray::Array::from(vec!["foo", "bar", "foo", "foo"]); + let array = ndarray::Array::from(vec!["foo", "bar", "foo", "foo", "baz"]); let input_tensor_values = vec![array]; let outputs: Vec> = session.run(input_tensor_values)?; @@ -39,10 +39,19 @@ fn run_model_with_string_input_output() -> Result<(), Box> { assert_eq!(TensorElementDataType::String, outputs[1].data_type()); let int_output: OrtOwnedTensor = outputs[0].try_extract()?; - - assert_eq!(&[0, 1, 0, 0], int_output.as_slice().unwrap()); - - // TODO get the string output once string extraction is implemented + let string_output: OrtOwnedTensor = outputs[1].try_extract()?; + + assert_eq!(&[5], int_output.view().shape()); + assert_eq!(&[3], string_output.view().shape()); + + assert_eq!(&[0, 1, 0, 0, 2], int_output.view().as_slice().unwrap()); + assert_eq!( + vec!["foo", "bar", "baz"] + .into_iter() + .map(|s| s.to_owned()) + .collect::>(), + string_output.view().as_slice().unwrap() + ); Ok(()) } From 495ecb4c685af06a136206bb753babd9554f8d10 Mon Sep 17 00:00:00 2001 From: Marshall Pierce Date: Mon, 8 Mar 2021 11:54:50 -0700 Subject: [PATCH 4/6] Support the onnx string type in output tensors This approach allocates owned Strings for each element, which works, but stresses the allocator, and incurs unnecessary copying. Part of the complication stems from the limitation that in Rust, a field can't be a reference to another field in the same struct. This means that having a Vec of copied data, referred to by a Vec<&str>, which is then referred to by an ArrayView, requires a sequence of 3 structs to express. Building a Vec gets rid of the references, but also loses the efficiency of 1 allocation with strs pointing into it. --- onnxruntime/tests/integration_tests.rs | 1 + 1 file changed, 1 insertion(+) diff --git a/onnxruntime/tests/integration_tests.rs b/onnxruntime/tests/integration_tests.rs index c332e7ce..5152532c 100644 --- a/onnxruntime/tests/integration_tests.rs +++ b/onnxruntime/tests/integration_tests.rs @@ -19,6 +19,7 @@ mod download { tensor::{DynOrtTensor, OrtOwnedTensor}, GraphOptimizationLevel, LoggingLevel, }; + use onnxruntime::tensor::ndarray_tensor::NdArrayTensor; #[test] fn squeezenet_mushroom() { From 6cf6d0e17c6a6b18c48fe05cb26499c46f65d9b2 Mon Sep 17 00:00:00 2001 From: Marshall Pierce Date: Thu, 18 Feb 2021 14:03:47 -0700 Subject: [PATCH 5/6] Add support for registring custom op libraries --- .dockerignore | 9 ++ CHANGELOG.md | 1 + Dockerfile | 118 ++++++++++++++++++++++ onnxruntime/Cargo.toml | 6 ++ onnxruntime/src/session.rs | 55 +++++++++- onnxruntime/tests/custom_ops.rs | 51 ++++++++++ test-models/tensorflow/README.md | 9 ++ test-models/tensorflow/regex_model.onnx | 19 ++++ test-models/tensorflow/src/regex_model.py | 19 ++++ 9 files changed, 286 insertions(+), 1 deletion(-) create mode 100644 .dockerignore create mode 100644 Dockerfile create mode 100644 onnxruntime/tests/custom_ops.rs create mode 100644 test-models/tensorflow/regex_model.onnx create mode 100644 test-models/tensorflow/src/regex_model.py diff --git a/.dockerignore b/.dockerignore new file mode 100644 index 00000000..1b6f6f4f --- /dev/null +++ b/.dockerignore @@ -0,0 +1,9 @@ +* +!/Cargo.* +!/onnxruntime/Cargo.toml +!/onnxruntime/src +!/onnxruntime/tests +!/onnxruntime-sys/Cargo.toml +!/onnxruntime-sys/build.rs +!/onnxruntime-sys/src +!/test-models/tensorflow/*.onnx diff --git a/CHANGELOG.md b/CHANGELOG.md index 24aef3bb..872cdc56 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -10,6 +10,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 ### Added - Add `String` datatype ([#58](https://github.com/nbigaouette/onnxruntime-rs/pull/58)) +- Support custom operator libraries ## [0.0.11] - 2021-02-22 diff --git a/Dockerfile b/Dockerfile new file mode 100644 index 00000000..4b8eace7 --- /dev/null +++ b/Dockerfile @@ -0,0 +1,118 @@ +# onnxruntime requires execinfo.h to build, which only works on glibc-based systems, so alpine is out... +FROM debian:bullseye-slim as base + +RUN apt-get update && apt-get -y dist-upgrade + +FROM base AS onnxruntime + +RUN apt-get install -y \ + git \ + bash \ + python3 \ + cmake \ + git \ + build-essential \ + llvm \ + locales + +# onnxruntime built in tests need en_US.UTF-8 available +# Uncomment en_US.UTF-8, then generate +RUN sed -i 's/^# *\(en_US.UTF-8\)/\1/' /etc/locale.gen && locale-gen + +# build onnxruntime +RUN mkdir -p /opt/onnxruntime/tmp +# onnxruntime build relies on being in a git repo, so can't just get a tarball +# it's a big repo, so fetch shallowly +RUN cd /opt/onnxruntime/tmp && \ + git clone --recursive --depth 1 --shallow-submodules https://github.com/Microsoft/onnxruntime + +# use version that onnxruntime-sys expects +RUN cd /opt/onnxruntime/tmp/onnxruntime && \ + git fetch --depth 1 origin tag v1.6.0 && \ + git checkout v1.6.0 + +RUN /opt/onnxruntime/tmp/onnxruntime/build.sh --config RelWithDebInfo --build_shared_lib --parallel + +# Build ort-customops, linked against the onnxruntime built above. +# No tags / releases yet - that commit is from 2021-02-16 +RUN mkdir -p /opt/ort-customops/tmp && \ + cd /opt/ort-customops/tmp && \ + git clone --recursive https://github.com/microsoft/ort-customops.git && \ + cd ort-customops && \ + git checkout 92f6b51106c9e9143c452e537cb5e41d2dcaa266 + +RUN cd /opt/ort-customops/tmp/ort-customops && \ + ./build.sh -D ONNXRUNTIME_LIB_DIR=/opt/onnxruntime/tmp/onnxruntime/build/Linux/RelWithDebInfo + + +# install rust toolchain +FROM base AS rust-toolchain + +ARG RUST_VERSION=1.50.0 + +RUN apt-get install -y \ + curl + +# install rust toolchain +RUN curl --proto '=https' --tlsv1.2 -sSf https://sh.rustup.rs | sh -s -- -y --default-toolchain $RUST_VERSION + +ENV PATH $PATH:/root/.cargo/bin + + +# build onnxruntime-rs +FROM rust-toolchain as onnxruntime-rs +# clang & llvm needed by onnxruntime-sys +RUN apt-get install -y \ + build-essential \ + llvm-dev \ + libclang-dev \ + clang + +RUN mkdir -p \ + /onnxruntime-rs/build/onnxruntime-sys/src/ \ + /onnxruntime-rs/build/onnxruntime/src/ \ + /onnxruntime-rs/build/onnxruntime/tests/ \ + /opt/onnxruntime/lib \ + /opt/ort-customops/lib + +COPY --from=onnxruntime /opt/onnxruntime/tmp/onnxruntime/build/Linux/RelWithDebInfo/libonnxruntime.so /opt/onnxruntime/lib/ +COPY --from=onnxruntime /opt/ort-customops/tmp/ort-customops/out/Linux/libortcustomops.so /opt/ort-customops/lib/ + +WORKDIR /onnxruntime-rs/build + +ENV ORT_STRATEGY=system +# this has /lib/ appended to it and is used as a lib search path in onnxruntime-sys's build.rs +ENV ORT_LIB_LOCATION=/opt/onnxruntime/ + +ENV ONNXRUNTIME_RS_TEST_ORT_CUSTOMOPS_LIB=/opt/ort-customops/lib/libortcustomops.so + +# create enough of an empty project that dependencies can build +COPY /Cargo.lock /Cargo.toml /onnxruntime-rs/build/ +COPY /onnxruntime/Cargo.toml /onnxruntime-rs/build/onnxruntime/ +COPY /onnxruntime-sys/Cargo.toml /onnxruntime-sys/build.rs /onnxruntime-rs/build/onnxruntime-sys/ + +CMD cargo test + +# build dependencies and clean the bogus contents of our two packages +RUN touch \ + onnxruntime/src/lib.rs \ + onnxruntime/tests/integration_tests.rs \ + onnxruntime-sys/src/lib.rs \ + && cargo build --tests \ + && cargo clean --package onnxruntime-sys \ + && cargo clean --package onnxruntime \ + && rm -rf \ + onnxruntime/src/ \ + onnxruntime/tests/ \ + onnxruntime-sys/src/ + +# now build the actual source +COPY /test-models test-models +COPY /onnxruntime-sys/src onnxruntime-sys/src +COPY /onnxruntime/src onnxruntime/src +COPY /onnxruntime/tests onnxruntime/tests + +RUN ln -s /opt/onnxruntime/lib/libonnxruntime.so /opt/onnxruntime/lib/libonnxruntime.so.1.6.0 +ENV LD_LIBRARY_PATH=/opt/onnxruntime/lib + +RUN cargo build --tests diff --git a/onnxruntime/Cargo.toml b/onnxruntime/Cargo.toml index 9ceec820..88a0114c 100644 --- a/onnxruntime/Cargo.toml +++ b/onnxruntime/Cargo.toml @@ -26,6 +26,12 @@ ndarray = "0.13" thiserror = "1.0" tracing = "0.1" +[target.'cfg(unix)'.dependencies] +libc = "0.2.88" + +[target.'cfg(windows)'.dependencies] +winapi = { version = "0.3.9", features = ["std"] } + # Enabled with 'model-fetching' feature ureq = {version = "1.5.1", optional = true} diff --git a/onnxruntime/src/session.rs b/onnxruntime/src/session.rs index 98099fd6..0e59ef6a 100644 --- a/onnxruntime/src/session.rs +++ b/onnxruntime/src/session.rs @@ -1,6 +1,6 @@ //! Module containing session types -use std::{convert::TryInto as _, ffi::CString, fmt::Debug, path::Path}; +use std::{convert::TryInto as _, ffi, ffi::CString, fmt::Debug, path::Path}; #[cfg(not(target_family = "windows"))] use std::os::unix::ffi::OsStrExt; @@ -64,11 +64,16 @@ pub struct SessionBuilder<'a> { allocator: AllocatorType, memory_type: MemType, + custom_runtime_handles: Vec<*mut ::std::os::raw::c_void>, } impl<'a> Drop for SessionBuilder<'a> { #[tracing::instrument] fn drop(&mut self) { + for &handle in self.custom_runtime_handles.iter() { + close_lib_handle(handle); + } + debug!("Dropping the session options."); assert_ne!(self.session_options_ptr, std::ptr::null_mut()); unsafe { g_ort().ReleaseSessionOptions.unwrap()(self.session_options_ptr) }; @@ -89,6 +94,7 @@ impl<'a> SessionBuilder<'a> { session_options_ptr, allocator: AllocatorType::Arena, memory_type: MemType::Default, + custom_runtime_handles: Vec::new(), }) } @@ -136,6 +142,39 @@ impl<'a> SessionBuilder<'a> { Ok(self) } + /// Registers a custom ops library with the given library path in the session. + pub fn with_custom_op_lib(mut self, lib_path: &str) -> Result> { + let path_cstr = ffi::CString::new(lib_path)?; + + let mut handle: *mut ::std::os::raw::c_void = std::ptr::null_mut(); + + let status = unsafe { + g_ort().RegisterCustomOpsLibrary.unwrap()( + self.session_options_ptr, + path_cstr.as_ptr(), + &mut handle, + ) + }; + + // per RegisterCustomOpsLibrary docs, release handle if there was an error and the handle + // is non-null + match status_to_result(status).map_err(OrtError::SessionOptions) { + Ok(_) => {} + Err(e) => { + if handle != std::ptr::null_mut() { + // handle was written to, should release it + close_lib_handle(handle); + } + + return Err(e); + } + } + + self.custom_runtime_handles.push(handle); + + Ok(self) + } + /// Download an ONNX pre-trained model from the [ONNX Model Zoo](https://github.com/onnx/models) and commit the session #[cfg(feature = "model-fetching")] pub fn with_model_downloaded(self, model: M) -> Result> @@ -619,6 +658,20 @@ where res } +#[cfg(unix)] +fn close_lib_handle(handle: *mut ::std::os::raw::c_void) { + unsafe { + libc::dlclose(handle); + } +} + +#[cfg(windows)] +fn close_lib_handle(handle: *mut ::std::os::raw::c_void) { + unsafe { + winapi::um::libloaderapi::FreeLibrary(handle as winapi::shared::minwindef::HINSTANCE) + }; +} + /// This module contains dangerous functions working on raw pointers. /// Those functions are only to be used from inside the /// `SessionBuilder::with_model_from_file()` method. diff --git a/onnxruntime/tests/custom_ops.rs b/onnxruntime/tests/custom_ops.rs new file mode 100644 index 00000000..c4d62c46 --- /dev/null +++ b/onnxruntime/tests/custom_ops.rs @@ -0,0 +1,51 @@ +use std::error::Error; + +use ndarray; +use onnxruntime::tensor::{DynOrtTensor, OrtOwnedTensor}; +use onnxruntime::{environment::Environment, LoggingLevel}; + +#[test] +fn run_model_with_ort_customops() -> Result<(), Box> { + let lib_path = match std::env::var("ONNXRUNTIME_RS_TEST_ORT_CUSTOMOPS_LIB") { + Ok(s) => s, + Err(_e) => { + println!("Skipping ort_customops test -- no lib specified"); + return Ok(()); + } + }; + + let environment = Environment::builder() + .with_name("test") + .with_log_level(LoggingLevel::Verbose) + .build()?; + + let mut session = environment + .new_session_builder()? + .with_custom_op_lib(&lib_path)? + .with_model_from_file("../test-models/tensorflow/regex_model.onnx")?; + + //Inputs: + // 0: + // name = input_1:0 + // type = String + // dimensions = [None] + // Outputs: + // 0: + // name = Identity:0 + // type = String + // dimensions = [None] + + let array = ndarray::Array::from(vec![String::from("Hello world!")]); + let input_tensor_values = vec![array]; + + let outputs: Vec> = session.run(input_tensor_values)?; + let strings: OrtOwnedTensor = outputs[0].try_extract()?; + + // ' ' replaced with '_' + assert_eq!( + &[String::from("Hello_world!")], + strings.view().as_slice().unwrap() + ); + + Ok(()) +} diff --git a/test-models/tensorflow/README.md b/test-models/tensorflow/README.md index 4f2e68f2..6421fb6e 100644 --- a/test-models/tensorflow/README.md +++ b/test-models/tensorflow/README.md @@ -16,3 +16,12 @@ This supports strings, and doesn't require custom operators. pipenv run python src/unique_model.py pipenv run python -m tf2onnx.convert --saved-model models/unique_model --output unique_model.onnx --opset 11 ``` + +# Model: Regex (uses `ort_customops`) + +A TensorFlow model that applies a regex, which requires the onnxruntime custom ops in `ort-customops`. + +``` +pipenv run python src/regex_model.py +pipenv run python -m tf2onnx.convert --saved-model models/regex_model --output regex_model.onnx --extra_opset ai.onnx.contrib:1 +``` diff --git a/test-models/tensorflow/regex_model.onnx b/test-models/tensorflow/regex_model.onnx new file mode 100644 index 00000000..3b4390df --- /dev/null +++ b/test-models/tensorflow/regex_model.onnx @@ -0,0 +1,19 @@ +tf2onnx1.9.0:� + + input_1:0 + +pattern__7 + +rewrite__8 +Identity:0)PartitionedCall/model1/StaticRegexReplace"StringRegexReplace:ai.onnx.contribtf2onnx*2_B +rewrite__8*2 B +pattern__7R!converted from models/regex_modelZ + input_1:0 + + +unk__9b + +Identity:0 + + unk__10B B +ai.onnx.contrib \ No newline at end of file diff --git a/test-models/tensorflow/src/regex_model.py b/test-models/tensorflow/src/regex_model.py new file mode 100644 index 00000000..5958a631 --- /dev/null +++ b/test-models/tensorflow/src/regex_model.py @@ -0,0 +1,19 @@ +import tensorflow as tf +import numpy as np +import tf2onnx + + +class RegexModel(tf.keras.Model): + + def __init__(self, name='model1', **kwargs): + super(RegexModel, self).__init__(name=name, **kwargs) + + def call(self, inputs): + return tf.strings.regex_replace(inputs, " ", "_", replace_global=True) + + +model1 = RegexModel() + +print(model1(tf.constant(["Hello world!"]))) + +model1.save("models/regex_model") From 457012b788471beb8f1921e938656c66af713aac Mon Sep 17 00:00:00 2001 From: Marshall Pierce Date: Wed, 10 Mar 2021 09:59:43 -0700 Subject: [PATCH 6/6] Remove stray --- onnxruntime/tests/integration_tests.rs | 1 - 1 file changed, 1 deletion(-) diff --git a/onnxruntime/tests/integration_tests.rs b/onnxruntime/tests/integration_tests.rs index 5152532c..c332e7ce 100644 --- a/onnxruntime/tests/integration_tests.rs +++ b/onnxruntime/tests/integration_tests.rs @@ -19,7 +19,6 @@ mod download { tensor::{DynOrtTensor, OrtOwnedTensor}, GraphOptimizationLevel, LoggingLevel, }; - use onnxruntime::tensor::ndarray_tensor::NdArrayTensor; #[test] fn squeezenet_mushroom() {