Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Arm64 linux and OSX #122

Open
wants to merge 4 commits into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
28 changes: 19 additions & 9 deletions onnxruntime-sys/build.rs
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@ use std::{
/// WARNING: If version is changed, bindings for all platforms will have to be re-generated.
/// To do so, run this:
/// cargo build --package onnxruntime-sys --features generate-bindings
const ORT_VERSION: &str = "1.8.1";
const ORT_VERSION: &str = "1.14.1";

/// Base Url from which to download pre-built releases/
const ORT_RELEASE_BASE_URL: &str = "https://github.com/microsoft/onnxruntime/releases/download";
Expand Down Expand Up @@ -310,12 +310,12 @@ struct Triplet {
impl OnnxPrebuiltArchive for Triplet {
fn as_onnx_str(&self) -> Cow<str> {
match (&self.os, &self.arch, &self.accelerator) {
// onnxruntime-win-x86-1.8.1.zip
// onnxruntime-win-x64-1.8.1.zip
// onnxruntime-win-arm-1.8.1.zip
// onnxruntime-win-arm64-1.8.1.zip
// onnxruntime-linux-x64-1.8.1.tgz
// onnxruntime-osx-x64-1.8.1.tgz
// onnxruntime-win-x86-1.14.1.zip
// onnxruntime-win-x64-1.14.1.zip
// onnxruntime-win-arm-1.14.1.zip
// onnxruntime-win-arm64-1.14.1.zip
// onnxruntime-linux-x64-1.14.1.tgz
// onnxruntime-osx-x64-1.14.1.tgz
(Os::Windows, Architecture::X86, Accelerator::None)
| (Os::Windows, Architecture::X86_64, Accelerator::None)
| (Os::Windows, Architecture::Arm, Accelerator::None)
Expand All @@ -326,15 +326,25 @@ impl OnnxPrebuiltArchive for Triplet {
self.os.as_onnx_str(),
self.arch.as_onnx_str()
)),
// onnxruntime-win-gpu-x64-1.8.1.zip
(Os::Linux, Architecture::Arm64, Accelerator::None) => Cow::from(format!(
"{}-aarch64",
self.os.as_onnx_str(),
)),
// onnxruntime-osx-arm64-1.14.1.tgz
(Os::MacOs, Architecture::Arm64, Accelerator::None) => Cow::from(format!(
"{}-{}",
self.os.as_onnx_str(),
self.arch.as_onnx_str(),
)),
// onnxruntime-win-gpu-x64-1.14.1.zip
// Note how this one is inverted from the linux one next
(Os::Windows, Architecture::X86_64, Accelerator::Gpu) => Cow::from(format!(
"{}-{}-{}",
self.os.as_onnx_str(),
self.accelerator.as_onnx_str(),
self.arch.as_onnx_str(),
)),
// onnxruntime-linux-x64-gpu-1.8.1.tgz
// onnxruntime-linux-x64-gpu-1.14.1.tgz
// Note how this one is inverted from the windows one above
(Os::Linux, Architecture::X86_64, Accelerator::Gpu) => Cow::from(format!(
"{}-{}-{}",
Expand Down
20 changes: 10 additions & 10 deletions onnxruntime-sys/examples/c_api_sample.rs
Original file line number Diff line number Diff line change
Expand Up @@ -117,7 +117,7 @@ fn main() {
// iterate over all input nodes
for i in 0..num_input_nodes {
// print input node names
let mut input_name: *mut i8 = std::ptr::null_mut();
let mut input_name: *mut std::os::raw::c_char = std::ptr::null_mut();
let status = unsafe {
g_ort.as_ref().unwrap().SessionGetInputName.unwrap()(
session_ptr,
Expand Down Expand Up @@ -282,24 +282,24 @@ fn main() {
.into_iter()
.map(|n| std::ffi::CString::new(n).unwrap())
.collect();
let input_node_names_ptr: Vec<*const i8> = input_node_names_cstring
let input_node_names_ptr: Vec<*const std::os::raw::c_char> = input_node_names_cstring
.into_iter()
.map(|n| n.into_raw() as *const i8)
.map(|n| n.into_raw() as *const std::os::raw::c_char)
.collect();
let input_node_names_ptr_ptr: *const *const i8 = input_node_names_ptr.as_ptr();
let input_node_names_ptr_ptr: *const *const std::os::raw::c_char = input_node_names_ptr.as_ptr();

let output_node_names_cstring: Vec<std::ffi::CString> = output_node_names
.into_iter()
.map(|n| std::ffi::CString::new(n.clone()).unwrap())
.collect();
let output_node_names_ptr: Vec<*const i8> = output_node_names_cstring
let output_node_names_ptr: Vec<*const std::os::raw::c_char> = output_node_names_cstring
.iter()
.map(|n| n.as_ptr() as *const i8)
.map(|n| n.as_ptr() as *const std::os::raw::c_char)
.collect();
let output_node_names_ptr_ptr: *const *const i8 = output_node_names_ptr.as_ptr();
let output_node_names_ptr_ptr: *const *const std::os::raw::c_char = output_node_names_ptr.as_ptr();

let _input_node_names_cstring =
unsafe { std::ffi::CString::from_raw(input_node_names_ptr[0] as *mut i8) };
unsafe { std::ffi::CString::from_raw(input_node_names_ptr[0] as *mut std::os::raw::c_char) };
let run_options_ptr: *const OrtRunOptions = std::ptr::null();
let mut output_tensor_ptr: *mut OrtValue = std::ptr::null_mut();
let output_tensor_ptr_ptr: *mut *mut OrtValue = &mut output_tensor_ptr;
Expand Down Expand Up @@ -371,7 +371,7 @@ fn CheckStatus(g_ort: *const OrtApi, status: *const OrtStatus) -> Result<(), Str
}
}

fn char_p_to_str<'a>(raw: *const i8) -> Result<&'a str, std::str::Utf8Error> {
let c_str = unsafe { std::ffi::CStr::from_ptr(raw as *mut i8) };
fn char_p_to_str<'a>(raw: *const std::os::raw::c_char) -> Result<&'a str, std::str::Utf8Error> {
let c_str = unsafe { std::ffi::CStr::from_ptr(raw as *mut std::os::raw::c_char) };
c_str.to_str()
}
6 changes: 6 additions & 0 deletions onnxruntime-sys/src/generated/bindings.rs
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,12 @@ include!(concat!(
"/src/generated/linux/x86_64/bindings.rs"
));

#[cfg(all(target_os = "linux", target_arch = "aarch64"))]
include!(concat!(
env!("CARGO_MANIFEST_DIR"),
"/src/generated/linux/aarch64/bindings.rs"
));

#[cfg(all(target_os = "macos", target_arch = "x86_64"))]
include!(concat!(
env!("CARGO_MANIFEST_DIR"),
Expand Down
3 changes: 3 additions & 0 deletions onnxruntime-sys/src/generated/linux/aarch64/bindings.rs

Large diffs are not rendered by default.

2 changes: 0 additions & 2 deletions onnxruntime-sys/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -5,8 +5,6 @@
#![allow(clippy::all)]
#![allow(improper_ctypes)]

#[allow(clippy::all)]

include!(concat!(
env!("CARGO_MANIFEST_DIR"),
"/src/generated/bindings.rs"
Expand Down
2 changes: 1 addition & 1 deletion onnxruntime/src/error.rs
Original file line number Diff line number Diff line change
Expand Up @@ -208,7 +208,7 @@ impl From<OrtStatusWrapper> for std::result::Result<(), OrtApiError> {
if status.0.is_null() {
Ok(())
} else {
let raw: *const i8 = unsafe { g_ort().GetErrorMessage.unwrap()(status.0) };
let raw: *const std::os::raw::c_char = unsafe { g_ort().GetErrorMessage.unwrap()(status.0) };
match char_p_to_string(raw) {
Ok(msg) => Err(OrtApiError::Msg(msg)),
Err(err) => match err {
Expand Down
21 changes: 10 additions & 11 deletions onnxruntime/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -182,8 +182,8 @@ fn g_ort() -> sys::OrtApi {
unsafe { *api_ptr_mut }
}

fn char_p_to_string(raw: *const i8) -> Result<String> {
let c_string = unsafe { std::ffi::CStr::from_ptr(raw as *mut i8).to_owned() };
fn char_p_to_string(raw: *const std::os::raw::c_char) -> Result<String> {
let c_string = unsafe { std::ffi::CStr::from_ptr(raw).to_owned() };

match c_string.into_string() {
Ok(string) => Ok(string),
Expand All @@ -196,7 +196,6 @@ mod onnxruntime {
//! Module containing a custom logger, used to catch the runtime's own logging and send it
//! to Rust's tracing logging instead.

use std::ffi::CStr;
use tracing::{debug, error, info, span, trace, warn, Level};

use onnxruntime_sys as sys;
Expand Down Expand Up @@ -235,10 +234,10 @@ mod onnxruntime {
pub(crate) fn custom_logger(
_params: *mut std::ffi::c_void,
severity: sys::OrtLoggingLevel,
category: *const i8,
logid: *const i8,
code_location: *const i8,
message: *const i8,
category: *const std::os::raw::c_char,
logid: *const std::os::raw::c_char,
code_location: *const std::os::raw::c_char,
message: *const std::os::raw::c_char,
) {
let log_level = match severity {
sys::OrtLoggingLevel::ORT_LOGGING_LEVEL_VERBOSE => Level::TRACE,
Expand All @@ -249,16 +248,16 @@ mod onnxruntime {
};

assert_ne!(category, std::ptr::null());
let category = unsafe { CStr::from_ptr(category) };
let category = unsafe { std::ffi::CStr::from_ptr(category) };
assert_ne!(code_location, std::ptr::null());
let code_location = unsafe { CStr::from_ptr(code_location) }
let code_location = unsafe { std::ffi::CStr::from_ptr(code_location) }
.to_str()
.unwrap_or("unknown");
assert_ne!(message, std::ptr::null());
let message = unsafe { CStr::from_ptr(message) };
let message = unsafe { std::ffi::CStr::from_ptr(message) };

assert_ne!(logid, std::ptr::null());
let logid = unsafe { CStr::from_ptr(logid) };
let logid = unsafe { std::ffi::CStr::from_ptr(logid) };

// Parse the code location
let code_location: CodeLocation = code_location.into();
Expand Down
15 changes: 8 additions & 7 deletions onnxruntime/src/session.rs
Original file line number Diff line number Diff line change
Expand Up @@ -296,6 +296,7 @@ impl<'a> SessionBuilder<'a> {
/// Type storing the session information, built from an [`Environment`](environment/struct.Environment.html)
#[derive(Debug)]
pub struct Session<'a> {
#[allow(dead_code)]
env: &'a Environment,
session_ptr: *mut sys::OrtSession,
allocator_ptr: *mut sys::OrtAllocator,
Expand Down Expand Up @@ -390,12 +391,12 @@ impl<'a> Session<'a> {

// Build arguments to Run()

let input_names_ptr: Vec<*const i8> = self
let input_names_ptr: Vec<*const std::os::raw::c_char> = self
.inputs
.iter()
.map(|input| input.name.clone())
.map(|n| CString::new(n).unwrap())
.map(|n| n.into_raw() as *const i8)
.map(|n| n.into_raw() as *const std::os::raw::c_char)
.collect();

let output_names_cstring: Vec<CString> = self
Expand All @@ -404,9 +405,9 @@ impl<'a> Session<'a> {
.map(|output| output.name.clone())
.map(|n| CString::new(n).unwrap())
.collect();
let output_names_ptr: Vec<*const i8> = output_names_cstring
let output_names_ptr: Vec<*const std::os::raw::c_char> = output_names_cstring
.iter()
.map(|n| n.as_ptr() as *const i8)
.map(|n| n.as_ptr() as *const std::os::raw::c_char)
.collect();

let mut output_tensor_extractors_ptrs: Vec<*mut sys::OrtValue> =
Expand Down Expand Up @@ -467,7 +468,7 @@ impl<'a> Session<'a> {
.into_iter()
.map(|p| {
assert_not_null_pointer(p, "i8 for CString")?;
unsafe { Ok(CString::from_raw(p as *mut i8)) }
unsafe { Ok(CString::from_raw(p as *mut std::os::raw::c_char)) }
})
.collect();
cstrings?;
Expand Down Expand Up @@ -646,13 +647,13 @@ mod dangerous {
*const sys::OrtSession,
usize,
*mut sys::OrtAllocator,
*mut *mut i8,
*mut *mut std::os::raw::c_char,
) -> *mut sys::OrtStatus },
session_ptr: *mut sys::OrtSession,
allocator_ptr: *mut sys::OrtAllocator,
i: usize,
) -> Result<String> {
let mut name_bytes: *mut i8 = std::ptr::null_mut();
let mut name_bytes: *mut std::os::raw::c_char = std::ptr::null_mut();

let status = unsafe { f(session_ptr, i, allocator_ptr, &mut name_bytes) };
status_to_result(status).map_err(OrtError::InputName)?;
Expand Down
1 change: 1 addition & 0 deletions onnxruntime/src/tensor/ort_owned_tensor.rs
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,7 @@ where
{
pub(crate) tensor_ptr: *mut sys::OrtValue,
array_view: ArrayView<'t, T, D>,
#[allow(dead_code)]
memory_info: &'m MemoryInfo,
}

Expand Down
1 change: 1 addition & 0 deletions onnxruntime/src/tensor/ort_tensor.rs
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@ where
{
pub(crate) c_ptr: *mut sys::OrtValue,
array: Array<T, D>,
#[allow(dead_code)]
memory_info: &'t MemoryInfo,
}

Expand Down