From 0d34d23d0b50aa664fc5b4c9e1337ea73e496541 Mon Sep 17 00:00:00 2001
From: Marshall Pierce
Date: Wed, 24 Feb 2021 09:48:34 -0700
Subject: [PATCH 1/2] Introduce a new trait to represent types that can be used
 as output from a tensor

This is prep work for string output types and for tensor types that vary
across a model's outputs. For now, the supported output types are just the
basic numeric types.

Since strings have to be copied out of a tensor, only `String` (not `str`)
makes sense as an output type. Hence the new trait, which lets the crate
support more input types than output types.

---
 onnxruntime/src/lib.rs                     | 148 --------------
 onnxruntime/src/session.rs                 | 112 ++++++----
 onnxruntime/src/tensor.rs                  | 226 +++++++++++++++++++++
 onnxruntime/src/tensor/ort_owned_tensor.rs |  83 ++------
 onnxruntime/src/tensor/ort_tensor.rs       |   8 +-
 5 files changed, 322 insertions(+), 255 deletions(-)

diff --git a/onnxruntime/src/lib.rs b/onnxruntime/src/lib.rs
index 0d575b5e..66aa7da6 100644
--- a/onnxruntime/src/lib.rs
+++ b/onnxruntime/src/lib.rs
@@ -322,154 +322,6 @@ impl Into<sys::GraphOptimizationLevel> for GraphOptimizationLevel {
     }
 }
 
-// FIXME: Use https://docs.rs/bindgen/0.54.1/bindgen/struct.Builder.html#method.rustified_enum
-// FIXME: Add tests to cover the commented out types
-/// Enum mapping ONNX Runtime's supported tensor types
-#[derive(Debug)]
-#[cfg_attr(not(windows), repr(u32))]
-#[cfg_attr(windows, repr(i32))]
-pub enum TensorElementDataType {
-    /// 32-bit floating point, equivalent to Rust's `f32`
-    Float = sys::ONNXTensorElementDataType::ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT as OnnxEnumInt,
-    /// Unsigned 8-bit int, equivalent to Rust's `u8`
-    Uint8 = sys::ONNXTensorElementDataType::ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT8 as OnnxEnumInt,
-    /// Signed 8-bit int, equivalent to Rust's `i8`
-    Int8 = sys::ONNXTensorElementDataType::ONNX_TENSOR_ELEMENT_DATA_TYPE_INT8 as OnnxEnumInt,
-    /// Unsigned 16-bit int, equivalent to Rust's `u16`
-    Uint16 = sys::ONNXTensorElementDataType::ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT16 as OnnxEnumInt,
-    /// Signed 16-bit int, equivalent to Rust's `i16`
-    Int16 = sys::ONNXTensorElementDataType::ONNX_TENSOR_ELEMENT_DATA_TYPE_INT16 as OnnxEnumInt,
-    /// Signed 32-bit int, equivalent to Rust's `i32`
-    Int32 = sys::ONNXTensorElementDataType::ONNX_TENSOR_ELEMENT_DATA_TYPE_INT32 as OnnxEnumInt,
-    /// Signed 64-bit int, equivalent to Rust's `i64`
-    Int64 = sys::ONNXTensorElementDataType::ONNX_TENSOR_ELEMENT_DATA_TYPE_INT64 as OnnxEnumInt,
-    /// String, equivalent to Rust's `String`
-    String = sys::ONNXTensorElementDataType::ONNX_TENSOR_ELEMENT_DATA_TYPE_STRING as OnnxEnumInt,
-    // /// Boolean, equivalent to Rust's `bool`
-    // Bool = sys::ONNXTensorElementDataType::ONNX_TENSOR_ELEMENT_DATA_TYPE_BOOL as OnnxEnumInt,
-    // /// 16-bit floating point, equivalent to Rust's `f16`
-    // Float16 = sys::ONNXTensorElementDataType::ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT16 as OnnxEnumInt,
-    /// 64-bit floating point, equivalent to Rust's `f64`
-    Double = sys::ONNXTensorElementDataType::ONNX_TENSOR_ELEMENT_DATA_TYPE_DOUBLE as OnnxEnumInt,
-    /// Unsigned 32-bit int, equivalent to Rust's `u32`
-    Uint32 = sys::ONNXTensorElementDataType::ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT32 as OnnxEnumInt,
-    /// Unsigned 64-bit int, equivalent to Rust's `u64`
-    Uint64 = sys::ONNXTensorElementDataType::ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT64 as OnnxEnumInt,
-    // /// Complex 64-bit floating point, equivalent to Rust's `???`
-    // Complex64 = sys::ONNXTensorElementDataType::ONNX_TENSOR_ELEMENT_DATA_TYPE_COMPLEX64 as OnnxEnumInt,
-    // /// Complex 
128-bit floating point, equivalent to Rust's `???` - // Complex128 = sys::ONNXTensorElementDataType::ONNX_TENSOR_ELEMENT_DATA_TYPE_COMPLEX128 as OnnxEnumInt, - // /// Brain 16-bit floating point - // Bfloat16 = sys::ONNXTensorElementDataType::ONNX_TENSOR_ELEMENT_DATA_TYPE_BFLOAT16 as OnnxEnumInt, -} - -impl Into for TensorElementDataType { - fn into(self) -> sys::ONNXTensorElementDataType { - use TensorElementDataType::*; - match self { - Float => sys::ONNXTensorElementDataType::ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT, - Uint8 => sys::ONNXTensorElementDataType::ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT8, - Int8 => sys::ONNXTensorElementDataType::ONNX_TENSOR_ELEMENT_DATA_TYPE_INT8, - Uint16 => sys::ONNXTensorElementDataType::ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT16, - Int16 => sys::ONNXTensorElementDataType::ONNX_TENSOR_ELEMENT_DATA_TYPE_INT16, - Int32 => sys::ONNXTensorElementDataType::ONNX_TENSOR_ELEMENT_DATA_TYPE_INT32, - Int64 => sys::ONNXTensorElementDataType::ONNX_TENSOR_ELEMENT_DATA_TYPE_INT64, - String => sys::ONNXTensorElementDataType::ONNX_TENSOR_ELEMENT_DATA_TYPE_STRING, - // Bool => { - // sys::ONNXTensorElementDataType::ONNX_TENSOR_ELEMENT_DATA_TYPE_BOOL - // } - // Float16 => { - // sys::ONNXTensorElementDataType::ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT16 - // } - Double => sys::ONNXTensorElementDataType::ONNX_TENSOR_ELEMENT_DATA_TYPE_DOUBLE, - Uint32 => sys::ONNXTensorElementDataType::ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT32, - Uint64 => sys::ONNXTensorElementDataType::ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT64, - // Complex64 => { - // sys::ONNXTensorElementDataType::ONNX_TENSOR_ELEMENT_DATA_TYPE_COMPLEX64 - // } - // Complex128 => { - // sys::ONNXTensorElementDataType::ONNX_TENSOR_ELEMENT_DATA_TYPE_COMPLEX128 - // } - // Bfloat16 => { - // sys::ONNXTensorElementDataType::ONNX_TENSOR_ELEMENT_DATA_TYPE_BFLOAT16 - // } - } - } -} - -/// Trait used to map Rust types (for example `f32`) to ONNX types (for example `Float`) -pub trait TypeToTensorElementDataType { - /// Return the ONNX type for a Rust type - fn tensor_element_data_type() -> TensorElementDataType; - - /// If the type is `String`, returns `Some` with utf8 contents, else `None`. - fn try_utf8_bytes(&self) -> Option<&[u8]>; -} - -macro_rules! impl_type_trait { - ($type_:ty, $variant:ident) => { - impl TypeToTensorElementDataType for $type_ { - fn tensor_element_data_type() -> TensorElementDataType { - // unsafe { std::mem::transmute(TensorElementDataType::$variant) } - TensorElementDataType::$variant - } - - fn try_utf8_bytes(&self) -> Option<&[u8]> { - None - } - } - }; -} - -impl_type_trait!(f32, Float); -impl_type_trait!(u8, Uint8); -impl_type_trait!(i8, Int8); -impl_type_trait!(u16, Uint16); -impl_type_trait!(i16, Int16); -impl_type_trait!(i32, Int32); -impl_type_trait!(i64, Int64); -// impl_type_trait!(bool, Bool); -// impl_type_trait!(f16, Float16); -impl_type_trait!(f64, Double); -impl_type_trait!(u32, Uint32); -impl_type_trait!(u64, Uint64); -// impl_type_trait!(, Complex64); -// impl_type_trait!(, Complex128); -// impl_type_trait!(, Bfloat16); - -/// Adapter for common Rust string types to Onnx strings. -/// -/// It should be easy to use both `String` and `&str` as [TensorElementDataType::String] data, but -/// we can't define an automatic implementation for anything that implements `AsRef` as it -/// would conflict with the implementations of [TypeToTensorElementDataType] for primitive numeric -/// types (which might implement `AsRef` at some point in the future). -pub trait Utf8Data { - /// Returns the utf8 contents. 
- fn utf8_bytes(&self) -> &[u8]; -} - -impl Utf8Data for String { - fn utf8_bytes(&self) -> &[u8] { - self.as_bytes() - } -} - -impl<'a> Utf8Data for &'a str { - fn utf8_bytes(&self) -> &[u8] { - self.as_bytes() - } -} - -impl TypeToTensorElementDataType for T { - fn tensor_element_data_type() -> TensorElementDataType { - TensorElementDataType::String - } - - fn try_utf8_bytes(&self) -> Option<&[u8]> { - Some(self.utf8_bytes()) - } -} - /// Allocator type #[derive(Debug, Clone)] #[repr(i32)] diff --git a/onnxruntime/src/session.rs b/onnxruntime/src/session.rs index 04f9cf1c..d212111e 100644 --- a/onnxruntime/src/session.rs +++ b/onnxruntime/src/session.rs @@ -18,15 +18,14 @@ use onnxruntime_sys as sys; use crate::{ char_p_to_string, environment::Environment, - error::{status_to_result, NonMatchingDimensionsError, OrtError, Result}, + error::{call_ort, status_to_result, NonMatchingDimensionsError, OrtError, Result}, g_ort, memory::MemoryInfo, tensor::{ - ort_owned_tensor::{OrtOwnedTensor, OrtOwnedTensorExtractor}, - OrtTensor, + ort_owned_tensor::OrtOwnedTensor, OrtTensor, TensorDataToType, TensorElementDataType, + TypeToTensorElementDataType, }, - AllocatorType, GraphOptimizationLevel, MemType, TensorElementDataType, - TypeToTensorElementDataType, + AllocatorType, GraphOptimizationLevel, MemType, }; #[cfg(feature = "model-fetching")] @@ -371,7 +370,7 @@ impl<'a> Session<'a> { ) -> Result>> where TIn: TypeToTensorElementDataType + Debug + Clone, - TOut: TypeToTensorElementDataType + Debug + Clone, + TOut: TensorDataToType, D: ndarray::Dimension, 'm: 't, // 'm outlives 't (memory info outlives tensor) 's: 'm, // 's outlives 'm (session outlives memory info) @@ -440,21 +439,30 @@ impl<'a> Session<'a> { let outputs: Result>>> = output_tensor_extractors_ptrs .into_iter() - .map(|ptr| { - let mut tensor_info_ptr: *mut sys::OrtTensorTypeAndShapeInfo = - std::ptr::null_mut(); - let status = unsafe { - g_ort().GetTensorTypeAndShape.unwrap()(ptr, &mut tensor_info_ptr as _) - }; - status_to_result(status).map_err(OrtError::GetTensorTypeAndShape)?; - let dims = unsafe { get_tensor_dimensions(tensor_info_ptr) }; - unsafe { g_ort().ReleaseTensorTypeAndShapeInfo.unwrap()(tensor_info_ptr) }; - let dims: Vec<_> = dims?.iter().map(|&n| n as usize).collect(); - - let mut output_tensor_extractor = - OrtOwnedTensorExtractor::new(memory_info_ref, ndarray::IxDyn(&dims)); - output_tensor_extractor.tensor_ptr = ptr; - output_tensor_extractor.extract::() + .map(|tensor_ptr| { + let dims = unsafe { + call_with_tensor_info(tensor_ptr, |tensor_info_ptr| { + get_tensor_dimensions(tensor_info_ptr) + .map(|dims| dims.iter().map(|&n| n as usize).collect::>()) + }) + }?; + + // Note: Both tensor and array will point to the same data, nothing is copied. + // As such, there is no need to free the pointer used to create the ArrayView. 
+ assert_ne!(tensor_ptr, std::ptr::null_mut()); + + let mut is_tensor = 0; + unsafe { call_ort(|ort| ort.IsTensor.unwrap()(tensor_ptr, &mut is_tensor)) } + .map_err(OrtError::IsTensor)?; + assert_eq!(is_tensor, 1); + + let array_view = TOut::extract_array(ndarray::IxDyn(&dims), tensor_ptr)?; + + Ok(OrtOwnedTensor::new( + tensor_ptr, + array_view, + &memory_info_ref, + )) }) .collect(); @@ -554,25 +562,60 @@ unsafe fn get_tensor_dimensions( tensor_info_ptr: *const sys::OrtTensorTypeAndShapeInfo, ) -> Result> { let mut num_dims = 0; - let status = g_ort().GetDimensionsCount.unwrap()(tensor_info_ptr, &mut num_dims); - status_to_result(status).map_err(OrtError::GetDimensionsCount)?; + call_ort(|ort| ort.GetDimensionsCount.unwrap()(tensor_info_ptr, &mut num_dims)) + .map_err(OrtError::GetDimensionsCount)?; assert_ne!(num_dims, 0); let mut node_dims: Vec = vec![0; num_dims as usize]; - let status = g_ort().GetDimensions.unwrap()( - tensor_info_ptr, - node_dims.as_mut_ptr(), // FIXME: UB? - num_dims, - ); - status_to_result(status).map_err(OrtError::GetDimensions)?; + call_ort(|ort| { + ort.GetDimensions.unwrap()( + tensor_info_ptr, + node_dims.as_mut_ptr(), // FIXME: UB? + num_dims, + ) + }) + .map_err(OrtError::GetDimensions)?; Ok(node_dims) } +unsafe fn extract_data_type( + tensor_info_ptr: *const sys::OrtTensorTypeAndShapeInfo, +) -> Result { + let mut type_sys = sys::ONNXTensorElementDataType::ONNX_TENSOR_ELEMENT_DATA_TYPE_UNDEFINED; + call_ort(|ort| ort.GetTensorElementType.unwrap()(tensor_info_ptr, &mut type_sys)) + .map_err(OrtError::TensorElementType)?; + assert_ne!( + type_sys, + sys::ONNXTensorElementDataType::ONNX_TENSOR_ELEMENT_DATA_TYPE_UNDEFINED + ); + // This transmute should be safe since its value is read from GetTensorElementType which we must trust. + Ok(std::mem::transmute(type_sys)) +} + +/// Calls the provided closure with the result of `GetTensorTypeAndShape`, deallocating the +/// resulting `*OrtTensorTypeAndShapeInfo` before returning. +unsafe fn call_with_tensor_info(tensor_ptr: *const sys::OrtValue, mut f: F) -> Result +where + F: FnMut(*const sys::OrtTensorTypeAndShapeInfo) -> Result, +{ + let mut tensor_info_ptr: *mut sys::OrtTensorTypeAndShapeInfo = std::ptr::null_mut(); + call_ort(|ort| ort.GetTensorTypeAndShape.unwrap()(tensor_ptr, &mut tensor_info_ptr as _)) + .map_err(OrtError::GetTensorTypeAndShape)?; + + let res = f(tensor_info_ptr); + + // no return code, so no errors to check for + g_ort().ReleaseTensorTypeAndShapeInfo.unwrap()(tensor_info_ptr); + + res +} + /// This module contains dangerous functions working on raw pointers. /// Those functions are only to be used from inside the /// `SessionBuilder::with_model_from_file()` method. mod dangerous { use super::*; + use crate::tensor::TensorElementDataType; pub(super) fn extract_inputs_count(session_ptr: *mut sys::OrtSession) -> Result { let f = g_ort().SessionGetInputCount.unwrap(); @@ -689,16 +732,7 @@ mod dangerous { status_to_result(status).map_err(OrtError::CastTypeInfoToTensorInfo)?; assert_ne!(tensor_info_ptr, std::ptr::null_mut()); - let mut type_sys = sys::ONNXTensorElementDataType::ONNX_TENSOR_ELEMENT_DATA_TYPE_UNDEFINED; - let status = - unsafe { g_ort().GetTensorElementType.unwrap()(tensor_info_ptr, &mut type_sys) }; - status_to_result(status).map_err(OrtError::TensorElementType)?; - assert_ne!( - type_sys, - sys::ONNXTensorElementDataType::ONNX_TENSOR_ELEMENT_DATA_TYPE_UNDEFINED - ); - // This transmute should be safe since its value is read from GetTensorElementType which we must trust. 
- let io_type: TensorElementDataType = unsafe { std::mem::transmute(type_sys) }; + let io_type: TensorElementDataType = unsafe { extract_data_type(tensor_info_ptr)? }; // info!("{} : type={}", i, type_); diff --git a/onnxruntime/src/tensor.rs b/onnxruntime/src/tensor.rs index 92404842..a5178c91 100644 --- a/onnxruntime/src/tensor.rs +++ b/onnxruntime/src/tensor.rs @@ -29,3 +29,229 @@ pub mod ort_tensor; pub use ort_owned_tensor::OrtOwnedTensor; pub use ort_tensor::OrtTensor; + +use crate::{OrtError, Result}; +use onnxruntime_sys::{self as sys, OnnxEnumInt}; +use std::{fmt, ptr}; + +// FIXME: Use https://docs.rs/bindgen/0.54.1/bindgen/struct.Builder.html#method.rustified_enum +// FIXME: Add tests to cover the commented out types +/// Enum mapping ONNX Runtime's supported tensor types +#[derive(Debug, PartialEq, Eq, Clone, Copy)] +#[cfg_attr(not(windows), repr(u32))] +#[cfg_attr(windows, repr(i32))] +pub enum TensorElementDataType { + /// 32-bit floating point, equivalent to Rust's `f32` + Float = sys::ONNXTensorElementDataType::ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT as OnnxEnumInt, + /// Unsigned 8-bit int, equivalent to Rust's `u8` + Uint8 = sys::ONNXTensorElementDataType::ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT8 as OnnxEnumInt, + /// Signed 8-bit int, equivalent to Rust's `i8` + Int8 = sys::ONNXTensorElementDataType::ONNX_TENSOR_ELEMENT_DATA_TYPE_INT8 as OnnxEnumInt, + /// Unsigned 16-bit int, equivalent to Rust's `u16` + Uint16 = sys::ONNXTensorElementDataType::ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT16 as OnnxEnumInt, + /// Signed 16-bit int, equivalent to Rust's `i16` + Int16 = sys::ONNXTensorElementDataType::ONNX_TENSOR_ELEMENT_DATA_TYPE_INT16 as OnnxEnumInt, + /// Signed 32-bit int, equivalent to Rust's `i32` + Int32 = sys::ONNXTensorElementDataType::ONNX_TENSOR_ELEMENT_DATA_TYPE_INT32 as OnnxEnumInt, + /// Signed 64-bit int, equivalent to Rust's `i64` + Int64 = sys::ONNXTensorElementDataType::ONNX_TENSOR_ELEMENT_DATA_TYPE_INT64 as OnnxEnumInt, + /// String, equivalent to Rust's `String` + String = sys::ONNXTensorElementDataType::ONNX_TENSOR_ELEMENT_DATA_TYPE_STRING as OnnxEnumInt, + // /// Boolean, equivalent to Rust's `bool` + // Bool = sys::ONNXTensorElementDataType::ONNX_TENSOR_ELEMENT_DATA_TYPE_BOOL as OnnxEnumInt, + // /// 16-bit floating point, equivalent to Rust's `f16` + // Float16 = sys::ONNXTensorElementDataType::ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT16 as OnnxEnumInt, + /// 64-bit floating point, equivalent to Rust's `f64` + Double = sys::ONNXTensorElementDataType::ONNX_TENSOR_ELEMENT_DATA_TYPE_DOUBLE as OnnxEnumInt, + /// Unsigned 32-bit int, equivalent to Rust's `u32` + Uint32 = sys::ONNXTensorElementDataType::ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT32 as OnnxEnumInt, + /// Unsigned 64-bit int, equivalent to Rust's `u64` + Uint64 = sys::ONNXTensorElementDataType::ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT64 as OnnxEnumInt, + // /// Complex 64-bit floating point, equivalent to Rust's `???` + // Complex64 = sys::ONNXTensorElementDataType::ONNX_TENSOR_ELEMENT_DATA_TYPE_COMPLEX64 as OnnxEnumInt, + // /// Complex 128-bit floating point, equivalent to Rust's `???` + // Complex128 = sys::ONNXTensorElementDataType::ONNX_TENSOR_ELEMENT_DATA_TYPE_COMPLEX128 as OnnxEnumInt, + // /// Brain 16-bit floating point + // Bfloat16 = sys::ONNXTensorElementDataType::ONNX_TENSOR_ELEMENT_DATA_TYPE_BFLOAT16 as OnnxEnumInt, +} + +impl Into for TensorElementDataType { + fn into(self) -> sys::ONNXTensorElementDataType { + use TensorElementDataType::*; + match self { + Float => 
sys::ONNXTensorElementDataType::ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT, + Uint8 => sys::ONNXTensorElementDataType::ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT8, + Int8 => sys::ONNXTensorElementDataType::ONNX_TENSOR_ELEMENT_DATA_TYPE_INT8, + Uint16 => sys::ONNXTensorElementDataType::ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT16, + Int16 => sys::ONNXTensorElementDataType::ONNX_TENSOR_ELEMENT_DATA_TYPE_INT16, + Int32 => sys::ONNXTensorElementDataType::ONNX_TENSOR_ELEMENT_DATA_TYPE_INT32, + Int64 => sys::ONNXTensorElementDataType::ONNX_TENSOR_ELEMENT_DATA_TYPE_INT64, + String => sys::ONNXTensorElementDataType::ONNX_TENSOR_ELEMENT_DATA_TYPE_STRING, + // Bool => { + // sys::ONNXTensorElementDataType::ONNX_TENSOR_ELEMENT_DATA_TYPE_BOOL + // } + // Float16 => { + // sys::ONNXTensorElementDataType::ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT16 + // } + Double => sys::ONNXTensorElementDataType::ONNX_TENSOR_ELEMENT_DATA_TYPE_DOUBLE, + Uint32 => sys::ONNXTensorElementDataType::ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT32, + Uint64 => sys::ONNXTensorElementDataType::ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT64, + // Complex64 => { + // sys::ONNXTensorElementDataType::ONNX_TENSOR_ELEMENT_DATA_TYPE_COMPLEX64 + // } + // Complex128 => { + // sys::ONNXTensorElementDataType::ONNX_TENSOR_ELEMENT_DATA_TYPE_COMPLEX128 + // } + // Bfloat16 => { + // sys::ONNXTensorElementDataType::ONNX_TENSOR_ELEMENT_DATA_TYPE_BFLOAT16 + // } + } + } +} + +/// Trait used to map Rust types (for example `f32`) to ONNX types (for example `Float`) +pub trait TypeToTensorElementDataType { + /// Return the ONNX type for a Rust type + fn tensor_element_data_type() -> TensorElementDataType; + + /// If the type is `String`, returns `Some` with utf8 contents, else `None`. + fn try_utf8_bytes(&self) -> Option<&[u8]>; +} + +macro_rules! impl_prim_type_to_ort_trait { + ($type_:ty, $variant:ident) => { + impl TypeToTensorElementDataType for $type_ { + fn tensor_element_data_type() -> TensorElementDataType { + // unsafe { std::mem::transmute(TensorElementDataType::$variant) } + TensorElementDataType::$variant + } + + fn try_utf8_bytes(&self) -> Option<&[u8]> { + None + } + } + }; +} + +impl_prim_type_to_ort_trait!(f32, Float); +impl_prim_type_to_ort_trait!(u8, Uint8); +impl_prim_type_to_ort_trait!(i8, Int8); +impl_prim_type_to_ort_trait!(u16, Uint16); +impl_prim_type_to_ort_trait!(i16, Int16); +impl_prim_type_to_ort_trait!(i32, Int32); +impl_prim_type_to_ort_trait!(i64, Int64); +// impl_type_trait!(bool, Bool); +// impl_type_trait!(f16, Float16); +impl_prim_type_to_ort_trait!(f64, Double); +impl_prim_type_to_ort_trait!(u32, Uint32); +impl_prim_type_to_ort_trait!(u64, Uint64); +// impl_type_trait!(, Complex64); +// impl_type_trait!(, Complex128); +// impl_type_trait!(, Bfloat16); + +/// Adapter for common Rust string types to Onnx strings. +/// +/// It should be easy to use both `String` and `&str` as [TensorElementDataType::String] data, but +/// we can't define an automatic implementation for anything that implements `AsRef` as it +/// would conflict with the implementations of [TypeToTensorElementDataType] for primitive numeric +/// types (which might implement `AsRef` at some point in the future). +pub trait Utf8Data { + /// Returns the utf8 contents. 
+ fn utf8_bytes(&self) -> &[u8]; +} + +impl Utf8Data for String { + fn utf8_bytes(&self) -> &[u8] { + self.as_bytes() + } +} + +impl<'a> Utf8Data for &'a str { + fn utf8_bytes(&self) -> &[u8] { + self.as_bytes() + } +} + +impl TypeToTensorElementDataType for T { + fn tensor_element_data_type() -> TensorElementDataType { + TensorElementDataType::String + } + + fn try_utf8_bytes(&self) -> Option<&[u8]> { + Some(self.utf8_bytes()) + } +} + +/// Trait used to map onnxruntime types to Rust types +pub trait TensorDataToType: Sized + fmt::Debug { + /// The tensor element type that this type can extract from + fn tensor_element_data_type() -> TensorElementDataType; + + /// Extract an `ArrayView` from the ort-owned tensor. + fn extract_array<'t, D>( + shape: D, + tensor: *mut sys::OrtValue, + ) -> Result> + where + D: ndarray::Dimension; +} + +/// Implements `OwnedTensorDataToType` for primitives, which can use `GetTensorMutableData` +macro_rules! impl_prim_type_from_ort_trait { + ($type_:ty, $variant:ident) => { + impl TensorDataToType for $type_ { + fn tensor_element_data_type() -> TensorElementDataType { + TensorElementDataType::$variant + } + + fn extract_array<'t, D>( + shape: D, + tensor: *mut sys::OrtValue, + ) -> Result> + where + D: ndarray::Dimension, + { + extract_primitive_array(shape, tensor) + } + } + }; +} + +/// Construct an [ndarray::ArrayView] over an Ort tensor. +/// +/// Only to be used on types whose Rust in-memory representation matches Ort's (e.g. primitive +/// numeric types like u32). +fn extract_primitive_array<'t, D, T: TensorDataToType>( + shape: D, + tensor: *mut sys::OrtValue, +) -> Result> +where + D: ndarray::Dimension, +{ + // Get pointer to output tensor float values + let mut output_array_ptr: *mut T = ptr::null_mut(); + let output_array_ptr_ptr: *mut *mut T = &mut output_array_ptr; + let output_array_ptr_ptr_void: *mut *mut std::ffi::c_void = + output_array_ptr_ptr as *mut *mut std::ffi::c_void; + unsafe { + crate::error::call_ort(|ort| { + ort.GetTensorMutableData.unwrap()(tensor, output_array_ptr_ptr_void) + }) + } + .map_err(OrtError::GetTensorMutableData)?; + assert_ne!(output_array_ptr, ptr::null_mut()); + + let array_view = unsafe { ndarray::ArrayView::from_shape_ptr(shape, output_array_ptr) }; + Ok(array_view) +} + +impl_prim_type_from_ort_trait!(f32, Float); +impl_prim_type_from_ort_trait!(u8, Uint8); +impl_prim_type_from_ort_trait!(i8, Int8); +impl_prim_type_from_ort_trait!(u16, Uint16); +impl_prim_type_from_ort_trait!(i16, Int16); +impl_prim_type_from_ort_trait!(i32, Int32); +impl_prim_type_from_ort_trait!(i64, Int64); +impl_prim_type_from_ort_trait!(f64, Double); +impl_prim_type_from_ort_trait!(u32, Uint32); +impl_prim_type_from_ort_trait!(u64, Uint64); diff --git a/onnxruntime/src/tensor/ort_owned_tensor.rs b/onnxruntime/src/tensor/ort_owned_tensor.rs index 161fe105..bfff3dea 100644 --- a/onnxruntime/src/tensor/ort_owned_tensor.rs +++ b/onnxruntime/src/tensor/ort_owned_tensor.rs @@ -7,10 +7,8 @@ use tracing::debug; use onnxruntime_sys as sys; -use crate::{ - error::status_to_result, g_ort, memory::MemoryInfo, tensor::ndarray_tensor::NdArrayTensor, - OrtError, Result, TypeToTensorElementDataType, -}; +use crate::tensor::TensorDataToType; +use crate::{g_ort, memory::MemoryInfo, tensor::ndarray_tensor::NdArrayTensor}; /// Tensor containing data owned by the ONNX Runtime C library, used to return values from inference. 
/// @@ -25,7 +23,7 @@ use crate::{ #[derive(Debug)] pub struct OrtOwnedTensor<'t, 'm, T, D> where - T: TypeToTensorElementDataType + Debug + Clone, + T: TensorDataToType, D: ndarray::Dimension, 'm: 't, // 'm outlives 't { @@ -36,7 +34,7 @@ where impl<'t, 'm, T, D> Deref for OrtOwnedTensor<'t, 'm, T, D> where - T: TypeToTensorElementDataType + Debug + Clone, + T: TensorDataToType, D: ndarray::Dimension, { type Target = ArrayView<'t, T, D>; @@ -48,9 +46,21 @@ where impl<'t, 'm, T, D> OrtOwnedTensor<'t, 'm, T, D> where - T: TypeToTensorElementDataType + Debug + Clone, + T: TensorDataToType, D: ndarray::Dimension, { + pub(crate) fn new( + tensor_ptr: *mut sys::OrtValue, + array_view: ArrayView<'t, T, D>, + memory_info: &'m MemoryInfo, + ) -> OrtOwnedTensor<'t, 'm, T, D> { + OrtOwnedTensor { + tensor_ptr, + array_view, + memory_info, + } + } + /// Apply a softmax on the specified axis pub fn softmax(&self, axis: ndarray::Axis) -> Array where @@ -61,66 +71,9 @@ where } } -#[derive(Debug)] -pub(crate) struct OrtOwnedTensorExtractor<'m, D> -where - D: ndarray::Dimension, -{ - pub(crate) tensor_ptr: *mut sys::OrtValue, - memory_info: &'m MemoryInfo, - shape: D, -} - -impl<'m, D> OrtOwnedTensorExtractor<'m, D> -where - D: ndarray::Dimension, -{ - pub(crate) fn new(memory_info: &'m MemoryInfo, shape: D) -> OrtOwnedTensorExtractor<'m, D> { - OrtOwnedTensorExtractor { - tensor_ptr: std::ptr::null_mut(), - memory_info, - shape, - } - } - - pub(crate) fn extract<'t, T>(self) -> Result> - where - T: TypeToTensorElementDataType + Debug + Clone, - { - // Note: Both tensor and array will point to the same data, nothing is copied. - // As such, there is no need too free the pointer used to create the ArrayView. - - assert_ne!(self.tensor_ptr, std::ptr::null_mut()); - - let mut is_tensor = 0; - let status = unsafe { g_ort().IsTensor.unwrap()(self.tensor_ptr, &mut is_tensor) }; - status_to_result(status).map_err(OrtError::IsTensor)?; - assert_eq!(is_tensor, 1); - - // Get pointer to output tensor float values - let mut output_array_ptr: *mut T = std::ptr::null_mut(); - let output_array_ptr_ptr: *mut *mut T = &mut output_array_ptr; - let output_array_ptr_ptr_void: *mut *mut std::ffi::c_void = - output_array_ptr_ptr as *mut *mut std::ffi::c_void; - let status = unsafe { - g_ort().GetTensorMutableData.unwrap()(self.tensor_ptr, output_array_ptr_ptr_void) - }; - status_to_result(status).map_err(OrtError::IsTensor)?; - assert_ne!(output_array_ptr, std::ptr::null_mut()); - - let array_view = unsafe { ArrayView::from_shape_ptr(self.shape, output_array_ptr) }; - - Ok(OrtOwnedTensor { - tensor_ptr: self.tensor_ptr, - array_view, - memory_info: self.memory_info, - }) - } -} - impl<'t, 'm, T, D> Drop for OrtOwnedTensor<'t, 'm, T, D> where - T: TypeToTensorElementDataType + Debug + Clone, + T: TensorDataToType, D: ndarray::Dimension, 'm: 't, // 'm outlives 't { diff --git a/onnxruntime/src/tensor/ort_tensor.rs b/onnxruntime/src/tensor/ort_tensor.rs index 437e2e86..0937afe1 100644 --- a/onnxruntime/src/tensor/ort_tensor.rs +++ b/onnxruntime/src/tensor/ort_tensor.rs @@ -8,9 +8,11 @@ use tracing::{debug, error}; use onnxruntime_sys as sys; use crate::{ - error::call_ort, error::status_to_result, g_ort, memory::MemoryInfo, - tensor::ndarray_tensor::NdArrayTensor, OrtError, Result, TensorElementDataType, - TypeToTensorElementDataType, + error::{call_ort, status_to_result}, + g_ort, + memory::MemoryInfo, + tensor::{ndarray_tensor::NdArrayTensor, TensorElementDataType, TypeToTensorElementDataType}, + OrtError, Result, }; /// Owned 
tensor, backed by an [`ndarray::Array`](https://docs.rs/ndarray/latest/ndarray/type.Array.html)

From d2663dd316e99e39bca26b83b5b33f26de443f0b Mon Sep 17 00:00:00 2001
From: Marshall Pierce <575695+marshallpierce@users.noreply.github.com>
Date: Mon, 5 Apr 2021 13:17:38 -0600
Subject: [PATCH 2/2] Add Clone back in to make it possible for users to clone
 output tensors

---
 onnxruntime/src/tensor.rs | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/onnxruntime/src/tensor.rs b/onnxruntime/src/tensor.rs
index a5178c91..f7fb7a28 100644
--- a/onnxruntime/src/tensor.rs
+++ b/onnxruntime/src/tensor.rs
@@ -183,7 +183,7 @@ impl TypeToTensorElementDataType for T {
 }
 
 /// Trait used to map onnxruntime types to Rust types
-pub trait TensorDataToType: Sized + fmt::Debug {
+pub trait TensorDataToType: Sized + fmt::Debug + Clone {
     /// The tensor element type that this type can extract from
     fn tensor_element_data_type() -> TensorElementDataType;
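Reviewer note (not part of the patch): a minimal sketch of what a call site looks like once `TOut` is bounded by the new `TensorDataToType` trait. The model path, shapes, and values are hypothetical placeholders; the builder methods and `Session::run` are the crate's existing public API, mirroring its README example.

use ndarray::Array;
use onnxruntime::{
    environment::Environment, tensor::OrtOwnedTensor, GraphOptimizationLevel, LoggingLevel,
};

fn main() -> Result<(), Box<dyn std::error::Error>> {
    let environment = Environment::builder()
        .with_name("output-trait-demo")
        .with_log_level(LoggingLevel::Warning)
        .build()?;

    // "model.onnx" is a placeholder for any model with one f32 input of shape (1, 3)
    // and f32 outputs.
    let mut session = environment
        .new_session_builder()?
        .with_optimization_level(GraphOptimizationLevel::Basic)?
        .with_model_from_file("model.onnx")?;

    // Inputs only need TypeToTensorElementDataType (+ Debug + Clone), so both numeric
    // and string element types are accepted on this side.
    let input = Array::from_shape_vec((1, 3), vec![1.0_f32, 2.0, 3.0])?;

    // The output element type must implement TensorDataToType; it is inferred from the
    // annotation below. `String` is rejected at compile time until a copying extraction
    // path exists, which is exactly the asymmetry this patch sets up.
    let outputs: Vec<OrtOwnedTensor<f32, _>> = session.run(vec![input])?;

    // OrtOwnedTensor derefs to an ndarray::ArrayView borrowing ORT-owned memory.
    println!("first output: {:?}", outputs[0]);
    Ok(())
}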
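Also not part of the patch: the moved `// FIXME: Add tests to cover the commented out types` hints at where coverage for the two traits could live. A sketch of such a unit test, assuming it sits inside onnxruntime/src/tensor.rs where both traits are in scope, might look like this; the fully qualified calls disambiguate the identically named `tensor_element_data_type` functions.

#[cfg(test)]
mod element_type_mapping_tests {
    use super::*;

    #[test]
    fn primitives_map_to_matching_element_type() {
        // Input-side and output-side traits agree on the ONNX element type for f32.
        assert_eq!(
            <f32 as TypeToTensorElementDataType>::tensor_element_data_type(),
            TensorElementDataType::Float
        );
        assert_eq!(
            <f32 as TensorDataToType>::tensor_element_data_type(),
            TensorElementDataType::Float
        );
        // Primitives carry no UTF-8 payload...
        assert_eq!(1.0_f32.try_utf8_bytes(), None);
        // ...while string-like inputs do, which is why String/&str stay input-only.
        assert_eq!("abc".try_utf8_bytes(), Some("abc".as_bytes()));
    }
}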