Skip to content

Commit

Permalink
Attempt at progressively feeding the Session to bypass type checking in
Browse files Browse the repository at this point in the history
a sane way.
  • Loading branch information
Narsil committed Mar 10, 2021
1 parent d2f1ebe commit f11527a
Show file tree
Hide file tree
Showing 2 changed files with 93 additions and 63 deletions.
10 changes: 5 additions & 5 deletions onnxruntime/src/error.rs
Original file line number Diff line number Diff line change
Expand Up @@ -120,16 +120,16 @@ pub enum OrtError {
#[derive(Error, Debug)]
pub enum NonMatchingDimensionsError {
/// Number of inputs from model does not match number of inputs from inference call
#[error("Non-matching number of inputs: {inference_input_count:?} for input vs {model_input_count:?} for model (inputs: {inference_input:?}, model: {model_input:?})")]
#[error("Non-matching number of inputs: {inference_input_count:?} for input vs {model_input_count:?}")]
InputsCount {
/// Number of input dimensions used by inference call
inference_input_count: usize,
/// Number of input dimensions defined in model
model_input_count: usize,
/// Input dimensions used by inference call
inference_input: Vec<Vec<usize>>,
/// Input dimensions defined in model
model_input: Vec<Vec<Option<u32>>>,
// Input dimensions used by inference call
// inference_input: Vec<Vec<usize>>,
// Input dimensions defined in model
// model_input: Vec<Vec<Option<u32>>>,
},
}

Expand Down
146 changes: 88 additions & 58 deletions onnxruntime/src/session.rs
Original file line number Diff line number Diff line change
Expand Up @@ -215,12 +215,14 @@ impl<'a> SessionBuilder<'a> {
let outputs = (0..num_output_nodes)
.map(|i| dangerous::extract_output(session_ptr, allocator_ptr, i))
.collect::<Result<Vec<Output>>>()?;
let input_ort_values = Vec::with_capacity(num_output_nodes as usize);

Ok(Session {
env: self.env,
session_ptr,
allocator_ptr,
memory_info,
input_ort_values,
inputs,
outputs,
})
Expand Down Expand Up @@ -271,12 +273,14 @@ impl<'a> SessionBuilder<'a> {
let outputs = (0..num_output_nodes)
.map(|i| dangerous::extract_output(session_ptr, allocator_ptr, i))
.collect::<Result<Vec<Output>>>()?;
let input_ort_values = Vec::with_capacity(num_output_nodes as usize);

Ok(Session {
env: self.env,
session_ptr,
allocator_ptr,
memory_info,
input_ort_values,
inputs,
outputs,
})
Expand All @@ -290,6 +294,7 @@ pub struct Session<'a> {
session_ptr: *mut sys::OrtSession,
allocator_ptr: *mut sys::OrtAllocator,
memory_info: MemoryInfo,
input_ort_values: Vec<*const sys::OrtValue>,
/// Information about the ONNX's inputs as stored in loaded file
pub inputs: Vec<Input>,
/// Information about the ONNX's outputs as stored in loaded file
Expand Down Expand Up @@ -357,6 +362,26 @@ impl<'a> Drop for Session<'a> {
}

impl<'a> Session<'a> {
/// Validates a single input array and queues it for the next inference call.
///
/// The array is first checked with `validate_input_shapes`. On success it is
/// converted into an `OrtTensor` using this session's memory info and
/// allocator, and the tensor's raw `OrtValue` pointer is appended to
/// `self.input_ort_values`.
///
/// # Errors
///
/// Propagates any error returned by `validate_input_shapes` or
/// `OrtTensor::from_array`.
///
/// # Panics
///
/// Panics whenever `validate_input_shapes` panics (the version visible in this
/// file does so on rank or shape mismatches).
pub fn feed<'s, 't, 'm, TIn, D>(&'s mut self, input_array: Array<TIn, D>) -> Result<()>
where
    TIn: TypeToTensorElementDataType + Debug + Clone,
    D: ndarray::Dimension,
    'm: 't, // 'm outlives 't (memory info outlives tensor)
    's: 'm, // 's outlives 'm (session outlives memory info)
{
    self.validate_input_shapes(&input_array)?;

    // The C API works on raw `OrtValue` pointers, so wrap the array in an ORT tensor first.
    let tensor: OrtTensor<TIn, D> =
        OrtTensor::from_array(&self.memory_info, self.allocator_ptr, input_array)?;
    let raw_value = tensor.c_ptr as *const sys::OrtValue;

    // Skip the tensor's destructor, presumably so the stored raw pointer stays valid.
    // NOTE(review): nothing in this method releases the value afterwards; confirm where
    // it is freed.
    std::mem::forget(tensor);
    self.input_ort_values.push(raw_value);

    Ok(())
}

/// Run the input data through the ONNX graph, performing inference.
///
/// Note that ONNX models can have multiple inputs; a `Vec<_>` is thus
Expand All @@ -371,9 +396,46 @@ impl<'a> Session<'a> {
'm: 't, // 'm outlives 't (memory info outlives tensor)
's: 'm, // 's outlives 'm (session outlives memory info)
{
self.validate_input_shapes(&input_arrays)?;

input_arrays
.into_iter()
.for_each(|input_array| self.feed(input_array).unwrap());
self.inner_run()
}
/// Run the inputs previously queued with `feed` through the ONNX graph, performing inference.
///
/// Note that ONNX models can have multiple inputs; the number of queued inputs
/// must match the number of inputs declared by the model.
pub fn inner_run<'s, 't, 'm>(
&'s mut self,
// input_arrays: Vec<Array<TIn, D>>,
) -> Result<Vec<DynOrtTensor<'m, ndarray::IxDyn>>>
where
'm: 't, // 'm outlives 't (memory info outlives tensor)
's: 'm, // 's outlives 'm (session outlives memory info)
{
// Build arguments to Run()
if self.input_ort_values.len() != self.inputs.len() {
error!(
"Non-matching number of inputs: {} (inference) vs {} (model)",
self.input_ort_values.len(),
self.inputs.len()
);
return Err(OrtError::NonMatchingDimensions(
NonMatchingDimensionsError::InputsCount {
inference_input_count: 0,
model_input_count: 0,
// inference_input: input_arrays
// .iter()
// .map(|input_array| input_array.shape().to_vec())
// .collect(),
// model_input: self
// .inputs
// .iter()
// .map(|input| input.dimensions.clone())
// .collect(),
},
));
}

let input_names: Vec<String> = self.inputs.iter().map(|input| input.name.clone()).collect();
let input_names_cstring: Vec<CString> = input_names
Expand Down Expand Up @@ -403,33 +465,22 @@ impl<'a> Session<'a> {
let mut output_tensor_ptrs: Vec<*mut sys::OrtValue> =
vec![std::ptr::null_mut(); self.outputs.len()];

// The C API expects pointers for the arrays (pointers to C-arrays)
let input_ort_tensors: Vec<OrtTensor<TIn, D>> = input_arrays
.into_iter()
.map(|input_array| {
OrtTensor::from_array(&self.memory_info, self.allocator_ptr, input_array)
})
.collect::<Result<Vec<OrtTensor<TIn, D>>>>()?;
let input_ort_values: Vec<*const sys::OrtValue> = input_ort_tensors
.iter()
.map(|input_array_ort| input_array_ort.c_ptr as *const sys::OrtValue)
.collect();

let run_options_ptr: *const sys::OrtRunOptions = std::ptr::null();

let status = unsafe {
g_ort().Run.unwrap()(
self.session_ptr,
run_options_ptr,
input_names_ptr.as_ptr(),
input_ort_values.as_ptr(),
input_ort_values.len() as u64, // C API expects a u64, not isize
self.input_ort_values.as_ptr(),
self.input_ort_values.len() as u64, // C API expects a u64, not isize
output_names_ptr.as_ptr(),
output_names_ptr.len() as u64, // C API expects a u64, not isize
output_tensor_ptrs.as_mut_ptr(),
)
};
status_to_result(status).map_err(OrtError::Run)?;
self.input_ort_values.iter().for_each(std::mem::drop);

let memory_info_ref = &self.memory_info;
let outputs: Result<Vec<DynOrtTensor<ndarray::Dim<ndarray::IxDynImpl>>>> =
Expand Down Expand Up @@ -494,7 +545,7 @@ impl<'a> Session<'a> {
// Tensor::from_array(self, array)
// }

fn validate_input_shapes<TIn, D>(&mut self, input_arrays: &[Array<TIn, D>]) -> Result<()>
fn validate_input_shapes<TIn, D>(&mut self, input_array: &Array<TIn, D>) -> Result<()>
where
TIn: TypeToTensorElementDataType + Debug + Clone,
D: ndarray::Dimension,
Expand All @@ -504,62 +555,41 @@ impl<'a> Session<'a> {
// Make sure all dimensions match (except dynamic ones)

// Verify length of inputs
if input_arrays.len() != self.inputs.len() {
error!(
"Non-matching number of inputs: {} (inference) vs {} (model)",
input_arrays.len(),
self.inputs.len()
);
return Err(OrtError::NonMatchingDimensions(
NonMatchingDimensionsError::InputsCount {
inference_input_count: 0,
model_input_count: 0,
inference_input: input_arrays
.iter()
.map(|input_array| input_array.shape().to_vec())
.collect(),
model_input: self
.inputs
.iter()
.map(|input| input.dimensions.clone())
.collect(),
},
));
}

// Verify length of each individual input
let inputs_different_length = input_arrays
.iter()
.zip(self.inputs.iter())
.any(|(l, r)| l.shape().len() != r.dimensions.len());
if inputs_different_length {
let current_input = self.input_ort_values.len();
if current_input > self.inputs.len() {
error!(
"Different input lengths: {:?} vs {:?}",
self.inputs, input_arrays
"Attempting to feed too many inputs, expecting {:?} inputs",
self.inputs.len()
);
panic!(
"Different input lengths: {:?} vs {:?}",
self.inputs, input_arrays
"Attempting to feed too many inputs, expecting {:?} inputs",
self.inputs.len()
);
}
let input = &self.inputs[current_input];
if input_array.shape().len() != input.dimensions().count() {
error!("Different input lengths: {:?} vs {:?}", input, input_array);
panic!("Different input lengths: {:?} vs {:?}", input, input_array);
}

// Verify shape of each individual input
let inputs_different_shape = input_arrays.iter().zip(self.inputs.iter()).any(|(l, r)| {
let l_shape = l.shape();
let r_shape = r.dimensions.as_slice();
l_shape.iter().zip(r_shape.iter()).any(|(l2, r2)| match r2 {
Some(r3) => *r3 as usize != *l2,
None => false, // None means dynamic size; in that case shape always match
})
let l = input_array;
let r = input;
let l_shape = l.shape();
let r_shape = r.dimensions.as_slice();
let inputs_different_shape = l_shape.iter().zip(r_shape.iter()).any(|(l2, r2)| match r2 {
Some(r3) => *r3 as usize != *l2,
None => false, // None means dynamic size; in that case shape always match
});
if inputs_different_shape {
error!(
"Different input lengths: {:?} vs {:?}",
self.inputs, input_arrays
self.inputs, input_array
);
panic!(
"Different input lengths: {:?} vs {:?}",
self.inputs, input_arrays
self.inputs, input_array
);
}

Expand Down

0 comments on commit f11527a

Please sign in to comment.