From 8c36f889c224c0f040b7f57630ac8b2cdd14f5f8 Mon Sep 17 00:00:00 2001 From: Jared Roesch Date: Tue, 25 Aug 2020 11:47:26 -0700 Subject: [PATCH 01/50] WIP WIP --- rust/tvm-rt/src/lib.rs | 1 + rust/tvm-rt/src/ndarray2.rs | 440 ++++++++++++++++++++++++++++++++++++ 2 files changed, 441 insertions(+) create mode 100644 rust/tvm-rt/src/ndarray2.rs diff --git a/rust/tvm-rt/src/lib.rs b/rust/tvm-rt/src/lib.rs index 84951f4c8e67..e32877a85d98 100644 --- a/rust/tvm-rt/src/lib.rs +++ b/rust/tvm-rt/src/lib.rs @@ -98,6 +98,7 @@ pub mod function; pub mod map; pub mod module; pub mod ndarray; +pub mod ndarray2; mod to_function; pub mod value; diff --git a/rust/tvm-rt/src/ndarray2.rs b/rust/tvm-rt/src/ndarray2.rs new file mode 100644 index 000000000000..d4b965b0fea8 --- /dev/null +++ b/rust/tvm-rt/src/ndarray2.rs @@ -0,0 +1,440 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +//! This module implements the [`NDArray`] type for working with *TVM tensors* or +//! coverting from a Rust's ndarray to TVM `NDArray`. +//! +//! One can create an empty NDArray given the shape, context and dtype using [`empty`]. +//! To create an NDArray from a mutable buffer in cpu use [`copy_from_buffer`]. +//! To copy an NDArray to different context use [`copy_to_ctx`]. +//! +//! Given a [`Rust's dynamic ndarray`], one can convert it to TVM NDArray as follows: +//! +//! # Example +//! +//! ``` +//! # use tvm_rt::{NDArray, Context, DataType}; +//! # use ndarray::{Array, ArrayD}; +//! # use std::str::FromStr; +//! use std::convert::TryFrom; +//! +//! let a = Array::from_shape_vec((2, 2), vec![1f32, 2., 3., 4.]) +//! .unwrap() +//! .into_dyn(); // Rust's ndarray +//! let nd = NDArray::from_rust_ndarray(&a, Context::cpu(0), DataType::from_str("float32").unwrap()).unwrap(); +//! assert_eq!(nd.shape(), Some(&mut [2, 2][..])); +//! let rnd: ArrayD = ArrayD::try_from(&nd).unwrap(); +//! assert!(rnd.all_close(&a, 1e-8f32)); +//! ``` +//! +//! [`Rust's dynamic ndarray`]:https://docs.rs/ndarray/0.12.1/ndarray/ +//! [`copy_from_buffer`]:struct.NDArray.html#method.copy_from_buffer +//! [`copy_to_ctx`]:struct.NDArray.html#method.copy_to_ctx + +use std::convert::TryInto; +use std::ffi::c_void; +use std::{convert::TryFrom, mem, os::raw::c_int, ptr, slice, str::FromStr}; + +use tvm_sys::ffi::DLTensor; +use tvm_sys::{ffi, ByteArray, Context, DataType}; +use tvm_macros::Object; + +use ndarray::{Array, ArrayD}; +use num_traits::Num; + +use crate::object::{Object, ObjectPtr} + +/// See the [`module-level documentation`](../ndarray/index.html) for more details. +#[repr(C)] +#[derive(Object)] +#[ref_name = "NDArray"] +#[type_key = "runtime.NDArray"] +pub struct NDArrayContainer { + base: Object, + dl_tensor: *mut DLTensor, + manager_ctx: *mut c_void, +} + + +impl NDArray { + pub(crate) fn from_raw(handle: ffi::TVMArrayHandle) -> Self { + let object: *mut NDArrayContainer = unsafe { std::mem::transmute(handle) }; + let object_ptr = ObjectPtr::from_raw(object); + NDArray(Some(object_ptr)) + } + + pub fn as_dltensor(&self) -> &DLTensor { + let ptr: *mut DLTensor = match self { + NDArray::Borrowed { ref handle } => *handle, + NDArray::Owned { ref handle } => *handle as *mut DLTensor, + }; + + unsafe { std::mem::transmute(ptr) } + } + + pub(crate) fn as_raw_dltensor(&self) -> *mut DLTensor { + match self { + NDArray::Borrowed { handle } => *handle, + NDArray::Owned { handle } => *handle as *mut DLTensor, + } + } + + pub fn is_view(&self) -> bool { + if let &NDArray::Borrowed { .. } = self { + true + } else { + false + } + } + + /// Returns the shape of the NDArray. + pub fn shape(&self) -> Option<&mut [usize]> { + let arr = self.as_dltensor(); + if arr.shape.is_null() || arr.data.is_null() { + return None; + }; + let slc = unsafe { slice::from_raw_parts_mut(arr.shape as *mut usize, arr.ndim as usize) }; + Some(slc) + } + + /// Returns the total number of entries of the NDArray. + pub fn size(&self) -> Option { + self.shape().map(|v| v.iter().product()) + } + + /// Returns the context which the NDArray was defined. + pub fn ctx(&self) -> Context { + self.as_dltensor().ctx.into() + } + + /// Returns the type of the entries of the NDArray. + pub fn dtype(&self) -> DataType { + self.as_dltensor().dtype.into() + } + + /// Returns the number of dimensions of the NDArray. + pub fn ndim(&self) -> usize { + self.as_dltensor() + .ndim + .try_into() + .expect("number of dimensions must always be positive") + } + + /// Returns the strides of the underlying NDArray. + pub fn strides(&self) -> Option<&[usize]> { + unsafe { + let sz = self.ndim() * mem::size_of::(); + let strides_ptr = self.as_dltensor().strides as *const usize; + let slc = slice::from_raw_parts(strides_ptr, sz); + Some(slc) + } + } + + /// Shows whether the underlying ndarray is contiguous in memory or not. + pub fn is_contiguous(&self) -> Result { + Ok(match self.strides() { + None => true, + Some(strides) => { + // NDArrayError::MissingShape in case shape is not determined + self.shape() + .ok_or(NDArrayError::MissingShape)? + .iter() + .zip(strides) + .rfold( + (true, 1), + |(is_contig, expected_stride), (shape, stride)| { + ( + is_contig && *stride == expected_stride, + expected_stride * (*shape as usize), + ) + }, + ) + .0 + } + }) + } + + pub fn byte_offset(&self) -> isize { + self.as_dltensor().byte_offset as isize + } + + /// Flattens the NDArray to a `Vec` of the same type in cpu. + /// + /// ## Example + /// + /// ``` + /// # use tvm_rt::{Context, DataType, NDArray}; + /// # use std::str::FromStr; + /// let mut shape = [4]; + /// let mut data = vec![1i32, 2, 3, 4]; + /// let ctx = Context::cpu(0); + /// let mut ndarray = NDArray::empty(&mut shape, ctx, DataType::from_str("int32").unwrap()); + /// ndarray.copy_from_buffer(&mut data); + /// assert_eq!(ndarray.shape(), Some(&mut shape[..])); + /// assert_eq!(ndarray.to_vec::().unwrap(), data); + /// ``` + pub fn to_vec(&self) -> Result, NDArrayError> { + if !self.shape().is_some() { + return Err(NDArrayError::EmptyArray); + } + let earr = NDArray::empty( + self.shape().ok_or(NDArrayError::MissingShape)?, + Context::cpu(0), + self.dtype(), + ); + let target = self.copy_to_ndarray(earr)?; + let arr = target.as_dltensor(); + let sz = self.size().ok_or(NDArrayError::MissingShape)?; + let mut v: Vec = Vec::with_capacity(sz * mem::size_of::()); + unsafe { + v.as_mut_ptr() + .copy_from_nonoverlapping(arr.data as *const T, sz); + v.set_len(sz); + } + Ok(v) + } + + /// Converts the NDArray to [`ByteArray`]. + pub fn to_bytearray(&self) -> Result { + let v = self.to_vec::()?; + Ok(ByteArray::from(v)) + } + + /// Creates an NDArray from a mutable buffer of types i32, u32 or f32 in cpu. + /// + /// ## Example + /// + /// ``` + /// # use tvm_rt::{Context, DataType, NDArray}; + /// # use std::str::FromStr; + /// let shape = &mut [2]; + /// let mut data = vec![1f32, 2.0]; + /// let ctx = Context::cpu(0); + /// let mut ndarray = NDArray::empty(shape, ctx, DataType::from_str("int32").unwrap()); + /// ndarray.copy_from_buffer(&mut data); + /// ``` + /// + /// *Note*: if something goes wrong during the copy, it will panic + /// from TVM side. See `TVMArrayCopyFromBytes` in `include/tvm/runtime/c_runtime_api.h`. + pub fn copy_from_buffer(&mut self, data: &mut [T]) { + check_call!(ffi::TVMArrayCopyFromBytes( + self.as_raw_dltensor(), + data.as_ptr() as *mut _, + data.len() * mem::size_of::() + )); + } + + /// Copies the NDArray to another target NDArray. + pub fn copy_to_ndarray(&self, target: NDArray) -> Result { + if self.dtype() != target.dtype() { + return Err(NDArrayError::DataTypeMismatch { + expected: self.dtype(), + actual: target.dtype(), + }); + } + + check_call!(ffi::TVMArrayCopyFromTo( + self.as_raw_dltensor(), + target.as_raw_dltensor(), + ptr::null_mut() as ffi::TVMStreamHandle + )); + + Ok(target) + } + + /// Copies the NDArray to a target context. + pub fn copy_to_ctx(&self, target: &Context) -> Result { + let tmp = NDArray::empty( + self.shape().ok_or(NDArrayError::MissingShape)?, + *target, + self.dtype(), + ); + let copy = self.copy_to_ndarray(tmp)?; + Ok(copy) + } + + /// Converts a Rust's ndarray to TVM NDArray. + pub fn from_rust_ndarray( + rnd: &ArrayD, + ctx: Context, + dtype: DataType, + ) -> Result { + let shape = rnd.shape().to_vec(); + let mut nd = NDArray::empty(&shape, ctx, dtype); + let mut buf = Array::from_iter(rnd.into_iter().map(|&v| v as T)); + nd.copy_from_buffer( + buf.as_slice_mut() + .expect("Array from iter must be contiguous."), + ); + Ok(nd) + } + + /// Allocates and creates an empty NDArray given the shape, context and dtype. + pub fn empty(shape: &[usize], ctx: Context, dtype: DataType) -> NDArray { + let mut handle = ptr::null_mut() as ffi::TVMArrayHandle; + let dtype: tvm_sys::ffi::DLDataType = dtype.into(); + check_call!(ffi::TVMArrayAlloc( + shape.as_ptr() as *const i64, + shape.len() as c_int, + i32::from(dtype.code) as c_int, + i32::from(dtype.bits) as c_int, + i32::from(dtype.lanes) as c_int, + ctx.device_type as c_int, + ctx.device_id as c_int, + &mut handle as *mut _, + )); + NDArray::Borrowed { handle: handle } + } +} + +macro_rules! impl_from_ndarray_rustndarray { + ($type:ty, $type_name:tt) => { + impl<'a> TryFrom<&'a NDArray> for ArrayD<$type> { + type Error = NDArrayError; + + fn try_from(nd: &NDArray) -> Result, Self::Error> { + if !nd.shape().is_some() { + return Err(NDArrayError::MissingShape); + } + assert_eq!(nd.dtype(), DataType::from_str($type_name)?, "Type mismatch"); + Ok(Array::from_shape_vec( + &*nd.shape().ok_or(NDArrayError::MissingShape)?, + nd.to_vec::<$type>()?, + )?) + } + } + + impl<'a> TryFrom<&'a mut NDArray> for ArrayD<$type> { + type Error = NDArrayError; + + fn try_from(nd: &mut NDArray) -> Result, Self::Error> { + if !nd.shape().is_some() { + return Err(NDArrayError::MissingShape); + }; + assert_eq!(nd.dtype(), DataType::from_str($type_name)?, "Type mismatch"); + Ok(Array::from_shape_vec( + &*nd.shape().ok_or(NDArrayError::MissingShape)?, + nd.to_vec::<$type>()?, + )?) + } + } + }; +} + +impl_from_ndarray_rustndarray!(i32, "int"); +impl_from_ndarray_rustndarray!(u32, "uint"); +impl_from_ndarray_rustndarray!(f32, "float"); + +impl Drop for NDArray { + fn drop(&mut self) { + if let &mut NDArray::Owned { .. } = self { + check_call!(ffi::TVMArrayFree(self.as_raw_dltensor())); + } + } +} + +mod sealed { + /// Private trait to prevent other traits from being implemeneted in downstream crates. + pub trait Sealed {} +} + +/// A trait for the supported 32-bits numerical types in frontend. +pub trait Num32: Num + sealed::Sealed { + const BITS: u8 = 32; +} + +macro_rules! impl_num32 { + ($($type:ty),+) => { + $( + impl sealed::Sealed for $type {} + impl Num32 for $type {} + )+ + }; +} + +impl_num32!(i32, u32, f32); + +#[cfg(test)] +mod tests { + // use super::*; + + // #[test] + // fn basics() { + // let shape = &mut [1, 2, 3]; + // let ctx = Context::cpu(0); + // let ndarray = NDArray::empty(shape, ctx, DataType::from_str("int32").unwrap()); + // assert_eq!(ndarray.shape().unwrap(), shape); + // assert_eq!( + // ndarray.size().unwrap(), + // shape.to_vec().into_iter().product() + // ); + // assert_eq!(ndarray.ndim(), 3); + // assert!(ndarray.strides().is_none()); + // assert_eq!(ndarray.byte_offset(), 0); + // } + + // #[test] + // fn copy() { + // let shape = &mut [4]; + // let mut data = vec![1i32, 2, 3, 4]; + // let ctx = Context::cpu(0); + // let mut ndarray = NDArray::empty(shape, ctx, DataType::from_str("int32").unwrap()); + // assert!(ndarray.to_vec::().is_ok()); + // ndarray.copy_from_buffer(&mut data); + // assert_eq!(ndarray.shape().unwrap(), shape); + // assert_eq!(ndarray.to_vec::().unwrap(), data); + // assert_eq!(ndarray.ndim(), 1); + // assert!(ndarray.is_contiguous().is_ok()); + // assert_eq!(ndarray.byte_offset(), 0); + // let shape = vec![4]; + // let e = NDArray::empty( + // &shape, + // Context::cpu(0), + // DataType::from_str("int32").unwrap(), + // ); + // let nd = ndarray.copy_to_ndarray(e); + // assert!(nd.is_ok()); + // assert_eq!(nd.unwrap().to_vec::().unwrap(), data); + // } + + // // #[test] + // // #[should_panic(expected = "called `Result::unwrap()` on an `Err`")] + // // fn copy_wrong_dtype() { + // // let shape = vec![4]; + // // let mut data = vec![1f32, 2., 3., 4.]; + // // let ctx = Context::cpu(0); + // // let mut nd_float = NDArray::empty(&shape, ctx, DataType::from_str("float32").unwrap()); + // // nd_float.copy_from_buffer(&mut data); + // // let empty_int = NDArray::empty(&shape, ctx, DataType::from_str("int32").unwrap()); + // // nd_float.copy_to_ndarray(empty_int).unwrap(); + // // } + + // #[test] + // fn rust_ndarray() { + // let a = Array::from_shape_vec((2, 2), vec![1f32, 2., 3., 4.]) + // .unwrap() + // .into_dyn(); + // let nd = + // NDArray::from_rust_ndarray(&a, Context::cpu(0), DataType::from_str("float32").unwrap()) + // .unwrap(); + // assert_eq!(nd.shape().unwrap(), &mut [2, 2]); + // let rnd: ArrayD = ArrayD::try_from(&nd).unwrap(); + // assert!(rnd.all_close(&a, 1e-8f32)); + // } +} From b0b2b5907850a025b19709f3db5fbfdfb116b2ea Mon Sep 17 00:00:00 2001 From: Jared Roesch Date: Mon, 31 Aug 2020 13:32:01 -0700 Subject: [PATCH 02/50] Add support for loading Python packed functions --- rust/tvm/Cargo.toml | 2 ++ rust/tvm/src/lib.rs | 2 ++ rust/tvm/src/python.rs | 39 +++++++++++++++++++++++++++++++++++++++ 3 files changed, 43 insertions(+) create mode 100644 rust/tvm/src/python.rs diff --git a/rust/tvm/Cargo.toml b/rust/tvm/Cargo.toml index ebfb5e64a4a7..c42336cd382c 100644 --- a/rust/tvm/Cargo.toml +++ b/rust/tvm/Cargo.toml @@ -40,6 +40,8 @@ tvm-macros = { version = "*", path = "../tvm-macros/" } paste = "0.1" mashup = "0.1" once_cell = "^1.3.1" +pyo3 = { version = "0.11.1", optional = true } [features] blas = ["ndarray/blas"] +python = ["pyo3"] diff --git a/rust/tvm/src/lib.rs b/rust/tvm/src/lib.rs index 64252a4f9c6f..36c750328249 100644 --- a/rust/tvm/src/lib.rs +++ b/rust/tvm/src/lib.rs @@ -41,6 +41,8 @@ pub use tvm_rt::module; pub use tvm_rt::ndarray; pub use tvm_rt::value; pub mod ir; +#[cfg(feature = "python")] +pub mod python; pub mod runtime; pub mod transform; diff --git a/rust/tvm/src/python.rs b/rust/tvm/src/python.rs new file mode 100644 index 000000000000..e5c5784e3f97 --- /dev/null +++ b/rust/tvm/src/python.rs @@ -0,0 +1,39 @@ +use pyo3::prelude::*; + +/// Load the Python interpreter into the address space. +/// +/// This enables the ability for Rust code to call TVM +/// functionality defined in Python. +/// +/// For example registered TVM functions can now be +/// obtained via `Function::get`. +pub fn load() -> Result { + let gil = Python::acquire_gil(); + let py = gil.python(); + load_python_tvm_(py).map_err(|e| { + // We can't display Python exceptions via std::fmt::Display, + // so print the error here manually. + e.print_and_set_sys_last_vars(py); + }) +} + +// const TVMC_CODE: &'static str = include_str!("tvmc.py"); + +fn load_python_tvm_(py: Python) -> PyResult { + let sys = py.import("tvm")?; + let version: String = sys.get("__version__")?.extract()?; + // py.run(TVMC_CODE, None, None)?; + Ok(version) +} + +#[cfg(test)] +mod tests { + use super::load_python_tvm; + use anyhow::Result; + + #[test] + fn test_run() -> Result<()> { + load_python_tvm().unwrap(); + Ok(()) + } +} From d282ef56424acd6e4be025d3dac9de1860227209 Mon Sep 17 00:00:00 2001 From: Jared Roesch Date: Sun, 6 Sep 2020 23:10:01 -0700 Subject: [PATCH 03/50] Flesh out Relay AST in Rust --- include/tvm/relay/adt.h | 1 + rust/tvm-rt/src/array.rs | 21 + rust/tvm-rt/src/ndarray2.rs | 880 +++++++++++++++++------------------ rust/tvm/Cargo.toml | 2 + rust/tvm/src/ir/arith.rs | 2 +- rust/tvm/src/ir/expr.rs | 81 ++++ rust/tvm/src/ir/function.rs | 27 ++ rust/tvm/src/ir/mod.rs | 58 +-- rust/tvm/src/ir/module.rs | 49 ++ rust/tvm/src/ir/relay/mod.rs | 432 +++++++++++++---- rust/tvm/src/ir/tir.rs | 4 +- rust/tvm/src/ir/ty.rs | 3 + 12 files changed, 970 insertions(+), 590 deletions(-) create mode 100644 rust/tvm/src/ir/expr.rs create mode 100644 rust/tvm/src/ir/function.rs create mode 100644 rust/tvm/src/ir/module.rs create mode 100644 rust/tvm/src/ir/ty.rs diff --git a/include/tvm/relay/adt.h b/include/tvm/relay/adt.h index 37182abb2681..b5dcab5e0bfc 100644 --- a/include/tvm/relay/adt.h +++ b/include/tvm/relay/adt.h @@ -190,6 +190,7 @@ class PatternTuple; /*! \brief PatternVar container node */ class PatternTupleNode : public PatternNode { public: + /* TODO(@jroesch): rename to field_pats */ /*! Sub-patterns to match against each value of the tuple. */ tvm::Array patterns; diff --git a/rust/tvm-rt/src/array.rs b/rust/tvm-rt/src/array.rs index d2c82fce0b33..213d7ee8a9c2 100644 --- a/rust/tvm-rt/src/array.rs +++ b/rust/tvm-rt/src/array.rs @@ -117,3 +117,24 @@ impl<'a, T: IsObjectRef> TryFrom for Array { }) } } + +#[cfg(test)] +mod tests { + use super::Array; + use crate::string::String; + use crate::function::Result; + + #[test] + fn create_array_and_get() -> Result<()> { + let vec: Vec = vec![ + "foo".into(), + "bar".into(), + "baz".into(), + ]; + let array = Array::from_vec(vec)?; + assert_eq!(array.get(0)?.to_string(), "foo"); + assert_eq!(array.get(1)?.to_string(), "bar"); + assert_eq!(array.get(1)?.to_string(), "baz"); + Ok(()) + } +} diff --git a/rust/tvm-rt/src/ndarray2.rs b/rust/tvm-rt/src/ndarray2.rs index d4b965b0fea8..c7dfb2c79491 100644 --- a/rust/tvm-rt/src/ndarray2.rs +++ b/rust/tvm-rt/src/ndarray2.rs @@ -1,440 +1,440 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one - * or more contributor license agreements. See the NOTICE file - * distributed with this work for additional information - * regarding copyright ownership. The ASF licenses this file - * to you under the Apache License, Version 2.0 (the - * "License"); you may not use this file except in compliance - * with the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, - * software distributed under the License is distributed on an - * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY - * KIND, either express or implied. See the License for the - * specific language governing permissions and limitations - * under the License. - */ - -//! This module implements the [`NDArray`] type for working with *TVM tensors* or -//! coverting from a Rust's ndarray to TVM `NDArray`. -//! -//! One can create an empty NDArray given the shape, context and dtype using [`empty`]. -//! To create an NDArray from a mutable buffer in cpu use [`copy_from_buffer`]. -//! To copy an NDArray to different context use [`copy_to_ctx`]. -//! -//! Given a [`Rust's dynamic ndarray`], one can convert it to TVM NDArray as follows: -//! -//! # Example -//! -//! ``` -//! # use tvm_rt::{NDArray, Context, DataType}; -//! # use ndarray::{Array, ArrayD}; -//! # use std::str::FromStr; -//! use std::convert::TryFrom; -//! -//! let a = Array::from_shape_vec((2, 2), vec![1f32, 2., 3., 4.]) -//! .unwrap() -//! .into_dyn(); // Rust's ndarray -//! let nd = NDArray::from_rust_ndarray(&a, Context::cpu(0), DataType::from_str("float32").unwrap()).unwrap(); -//! assert_eq!(nd.shape(), Some(&mut [2, 2][..])); -//! let rnd: ArrayD = ArrayD::try_from(&nd).unwrap(); -//! assert!(rnd.all_close(&a, 1e-8f32)); -//! ``` -//! -//! [`Rust's dynamic ndarray`]:https://docs.rs/ndarray/0.12.1/ndarray/ -//! [`copy_from_buffer`]:struct.NDArray.html#method.copy_from_buffer -//! [`copy_to_ctx`]:struct.NDArray.html#method.copy_to_ctx - -use std::convert::TryInto; -use std::ffi::c_void; -use std::{convert::TryFrom, mem, os::raw::c_int, ptr, slice, str::FromStr}; - -use tvm_sys::ffi::DLTensor; -use tvm_sys::{ffi, ByteArray, Context, DataType}; -use tvm_macros::Object; - -use ndarray::{Array, ArrayD}; -use num_traits::Num; - -use crate::object::{Object, ObjectPtr} - -/// See the [`module-level documentation`](../ndarray/index.html) for more details. -#[repr(C)] -#[derive(Object)] -#[ref_name = "NDArray"] -#[type_key = "runtime.NDArray"] -pub struct NDArrayContainer { - base: Object, - dl_tensor: *mut DLTensor, - manager_ctx: *mut c_void, -} - - -impl NDArray { - pub(crate) fn from_raw(handle: ffi::TVMArrayHandle) -> Self { - let object: *mut NDArrayContainer = unsafe { std::mem::transmute(handle) }; - let object_ptr = ObjectPtr::from_raw(object); - NDArray(Some(object_ptr)) - } - - pub fn as_dltensor(&self) -> &DLTensor { - let ptr: *mut DLTensor = match self { - NDArray::Borrowed { ref handle } => *handle, - NDArray::Owned { ref handle } => *handle as *mut DLTensor, - }; - - unsafe { std::mem::transmute(ptr) } - } - - pub(crate) fn as_raw_dltensor(&self) -> *mut DLTensor { - match self { - NDArray::Borrowed { handle } => *handle, - NDArray::Owned { handle } => *handle as *mut DLTensor, - } - } - - pub fn is_view(&self) -> bool { - if let &NDArray::Borrowed { .. } = self { - true - } else { - false - } - } - - /// Returns the shape of the NDArray. - pub fn shape(&self) -> Option<&mut [usize]> { - let arr = self.as_dltensor(); - if arr.shape.is_null() || arr.data.is_null() { - return None; - }; - let slc = unsafe { slice::from_raw_parts_mut(arr.shape as *mut usize, arr.ndim as usize) }; - Some(slc) - } - - /// Returns the total number of entries of the NDArray. - pub fn size(&self) -> Option { - self.shape().map(|v| v.iter().product()) - } - - /// Returns the context which the NDArray was defined. - pub fn ctx(&self) -> Context { - self.as_dltensor().ctx.into() - } - - /// Returns the type of the entries of the NDArray. - pub fn dtype(&self) -> DataType { - self.as_dltensor().dtype.into() - } - - /// Returns the number of dimensions of the NDArray. - pub fn ndim(&self) -> usize { - self.as_dltensor() - .ndim - .try_into() - .expect("number of dimensions must always be positive") - } - - /// Returns the strides of the underlying NDArray. - pub fn strides(&self) -> Option<&[usize]> { - unsafe { - let sz = self.ndim() * mem::size_of::(); - let strides_ptr = self.as_dltensor().strides as *const usize; - let slc = slice::from_raw_parts(strides_ptr, sz); - Some(slc) - } - } - - /// Shows whether the underlying ndarray is contiguous in memory or not. - pub fn is_contiguous(&self) -> Result { - Ok(match self.strides() { - None => true, - Some(strides) => { - // NDArrayError::MissingShape in case shape is not determined - self.shape() - .ok_or(NDArrayError::MissingShape)? - .iter() - .zip(strides) - .rfold( - (true, 1), - |(is_contig, expected_stride), (shape, stride)| { - ( - is_contig && *stride == expected_stride, - expected_stride * (*shape as usize), - ) - }, - ) - .0 - } - }) - } - - pub fn byte_offset(&self) -> isize { - self.as_dltensor().byte_offset as isize - } - - /// Flattens the NDArray to a `Vec` of the same type in cpu. - /// - /// ## Example - /// - /// ``` - /// # use tvm_rt::{Context, DataType, NDArray}; - /// # use std::str::FromStr; - /// let mut shape = [4]; - /// let mut data = vec![1i32, 2, 3, 4]; - /// let ctx = Context::cpu(0); - /// let mut ndarray = NDArray::empty(&mut shape, ctx, DataType::from_str("int32").unwrap()); - /// ndarray.copy_from_buffer(&mut data); - /// assert_eq!(ndarray.shape(), Some(&mut shape[..])); - /// assert_eq!(ndarray.to_vec::().unwrap(), data); - /// ``` - pub fn to_vec(&self) -> Result, NDArrayError> { - if !self.shape().is_some() { - return Err(NDArrayError::EmptyArray); - } - let earr = NDArray::empty( - self.shape().ok_or(NDArrayError::MissingShape)?, - Context::cpu(0), - self.dtype(), - ); - let target = self.copy_to_ndarray(earr)?; - let arr = target.as_dltensor(); - let sz = self.size().ok_or(NDArrayError::MissingShape)?; - let mut v: Vec = Vec::with_capacity(sz * mem::size_of::()); - unsafe { - v.as_mut_ptr() - .copy_from_nonoverlapping(arr.data as *const T, sz); - v.set_len(sz); - } - Ok(v) - } - - /// Converts the NDArray to [`ByteArray`]. - pub fn to_bytearray(&self) -> Result { - let v = self.to_vec::()?; - Ok(ByteArray::from(v)) - } - - /// Creates an NDArray from a mutable buffer of types i32, u32 or f32 in cpu. - /// - /// ## Example - /// - /// ``` - /// # use tvm_rt::{Context, DataType, NDArray}; - /// # use std::str::FromStr; - /// let shape = &mut [2]; - /// let mut data = vec![1f32, 2.0]; - /// let ctx = Context::cpu(0); - /// let mut ndarray = NDArray::empty(shape, ctx, DataType::from_str("int32").unwrap()); - /// ndarray.copy_from_buffer(&mut data); - /// ``` - /// - /// *Note*: if something goes wrong during the copy, it will panic - /// from TVM side. See `TVMArrayCopyFromBytes` in `include/tvm/runtime/c_runtime_api.h`. - pub fn copy_from_buffer(&mut self, data: &mut [T]) { - check_call!(ffi::TVMArrayCopyFromBytes( - self.as_raw_dltensor(), - data.as_ptr() as *mut _, - data.len() * mem::size_of::() - )); - } - - /// Copies the NDArray to another target NDArray. - pub fn copy_to_ndarray(&self, target: NDArray) -> Result { - if self.dtype() != target.dtype() { - return Err(NDArrayError::DataTypeMismatch { - expected: self.dtype(), - actual: target.dtype(), - }); - } - - check_call!(ffi::TVMArrayCopyFromTo( - self.as_raw_dltensor(), - target.as_raw_dltensor(), - ptr::null_mut() as ffi::TVMStreamHandle - )); - - Ok(target) - } - - /// Copies the NDArray to a target context. - pub fn copy_to_ctx(&self, target: &Context) -> Result { - let tmp = NDArray::empty( - self.shape().ok_or(NDArrayError::MissingShape)?, - *target, - self.dtype(), - ); - let copy = self.copy_to_ndarray(tmp)?; - Ok(copy) - } - - /// Converts a Rust's ndarray to TVM NDArray. - pub fn from_rust_ndarray( - rnd: &ArrayD, - ctx: Context, - dtype: DataType, - ) -> Result { - let shape = rnd.shape().to_vec(); - let mut nd = NDArray::empty(&shape, ctx, dtype); - let mut buf = Array::from_iter(rnd.into_iter().map(|&v| v as T)); - nd.copy_from_buffer( - buf.as_slice_mut() - .expect("Array from iter must be contiguous."), - ); - Ok(nd) - } - - /// Allocates and creates an empty NDArray given the shape, context and dtype. - pub fn empty(shape: &[usize], ctx: Context, dtype: DataType) -> NDArray { - let mut handle = ptr::null_mut() as ffi::TVMArrayHandle; - let dtype: tvm_sys::ffi::DLDataType = dtype.into(); - check_call!(ffi::TVMArrayAlloc( - shape.as_ptr() as *const i64, - shape.len() as c_int, - i32::from(dtype.code) as c_int, - i32::from(dtype.bits) as c_int, - i32::from(dtype.lanes) as c_int, - ctx.device_type as c_int, - ctx.device_id as c_int, - &mut handle as *mut _, - )); - NDArray::Borrowed { handle: handle } - } -} - -macro_rules! impl_from_ndarray_rustndarray { - ($type:ty, $type_name:tt) => { - impl<'a> TryFrom<&'a NDArray> for ArrayD<$type> { - type Error = NDArrayError; - - fn try_from(nd: &NDArray) -> Result, Self::Error> { - if !nd.shape().is_some() { - return Err(NDArrayError::MissingShape); - } - assert_eq!(nd.dtype(), DataType::from_str($type_name)?, "Type mismatch"); - Ok(Array::from_shape_vec( - &*nd.shape().ok_or(NDArrayError::MissingShape)?, - nd.to_vec::<$type>()?, - )?) - } - } - - impl<'a> TryFrom<&'a mut NDArray> for ArrayD<$type> { - type Error = NDArrayError; - - fn try_from(nd: &mut NDArray) -> Result, Self::Error> { - if !nd.shape().is_some() { - return Err(NDArrayError::MissingShape); - }; - assert_eq!(nd.dtype(), DataType::from_str($type_name)?, "Type mismatch"); - Ok(Array::from_shape_vec( - &*nd.shape().ok_or(NDArrayError::MissingShape)?, - nd.to_vec::<$type>()?, - )?) - } - } - }; -} - -impl_from_ndarray_rustndarray!(i32, "int"); -impl_from_ndarray_rustndarray!(u32, "uint"); -impl_from_ndarray_rustndarray!(f32, "float"); - -impl Drop for NDArray { - fn drop(&mut self) { - if let &mut NDArray::Owned { .. } = self { - check_call!(ffi::TVMArrayFree(self.as_raw_dltensor())); - } - } -} - -mod sealed { - /// Private trait to prevent other traits from being implemeneted in downstream crates. - pub trait Sealed {} -} - -/// A trait for the supported 32-bits numerical types in frontend. -pub trait Num32: Num + sealed::Sealed { - const BITS: u8 = 32; -} - -macro_rules! impl_num32 { - ($($type:ty),+) => { - $( - impl sealed::Sealed for $type {} - impl Num32 for $type {} - )+ - }; -} - -impl_num32!(i32, u32, f32); - -#[cfg(test)] -mod tests { - // use super::*; - - // #[test] - // fn basics() { - // let shape = &mut [1, 2, 3]; - // let ctx = Context::cpu(0); - // let ndarray = NDArray::empty(shape, ctx, DataType::from_str("int32").unwrap()); - // assert_eq!(ndarray.shape().unwrap(), shape); - // assert_eq!( - // ndarray.size().unwrap(), - // shape.to_vec().into_iter().product() - // ); - // assert_eq!(ndarray.ndim(), 3); - // assert!(ndarray.strides().is_none()); - // assert_eq!(ndarray.byte_offset(), 0); - // } - - // #[test] - // fn copy() { - // let shape = &mut [4]; - // let mut data = vec![1i32, 2, 3, 4]; - // let ctx = Context::cpu(0); - // let mut ndarray = NDArray::empty(shape, ctx, DataType::from_str("int32").unwrap()); - // assert!(ndarray.to_vec::().is_ok()); - // ndarray.copy_from_buffer(&mut data); - // assert_eq!(ndarray.shape().unwrap(), shape); - // assert_eq!(ndarray.to_vec::().unwrap(), data); - // assert_eq!(ndarray.ndim(), 1); - // assert!(ndarray.is_contiguous().is_ok()); - // assert_eq!(ndarray.byte_offset(), 0); - // let shape = vec![4]; - // let e = NDArray::empty( - // &shape, - // Context::cpu(0), - // DataType::from_str("int32").unwrap(), - // ); - // let nd = ndarray.copy_to_ndarray(e); - // assert!(nd.is_ok()); - // assert_eq!(nd.unwrap().to_vec::().unwrap(), data); - // } - - // // #[test] - // // #[should_panic(expected = "called `Result::unwrap()` on an `Err`")] - // // fn copy_wrong_dtype() { - // // let shape = vec![4]; - // // let mut data = vec![1f32, 2., 3., 4.]; - // // let ctx = Context::cpu(0); - // // let mut nd_float = NDArray::empty(&shape, ctx, DataType::from_str("float32").unwrap()); - // // nd_float.copy_from_buffer(&mut data); - // // let empty_int = NDArray::empty(&shape, ctx, DataType::from_str("int32").unwrap()); - // // nd_float.copy_to_ndarray(empty_int).unwrap(); - // // } - - // #[test] - // fn rust_ndarray() { - // let a = Array::from_shape_vec((2, 2), vec![1f32, 2., 3., 4.]) - // .unwrap() - // .into_dyn(); - // let nd = - // NDArray::from_rust_ndarray(&a, Context::cpu(0), DataType::from_str("float32").unwrap()) - // .unwrap(); - // assert_eq!(nd.shape().unwrap(), &mut [2, 2]); - // let rnd: ArrayD = ArrayD::try_from(&nd).unwrap(); - // assert!(rnd.all_close(&a, 1e-8f32)); - // } -} +// /* +// * Licensed to the Apache Software Foundation (ASF) under one +// * or more contributor license agreements. See the NOTICE file +// * distributed with this work for additional information +// * regarding copyright ownership. The ASF licenses this file +// * to you under the Apache License, Version 2.0 (the +// * "License"); you may not use this file except in compliance +// * with the License. You may obtain a copy of the License at +// * +// * http://www.apache.org/licenses/LICENSE-2.0 +// * +// * Unless required by applicable law or agreed to in writing, +// * software distributed under the License is distributed on an +// * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// * KIND, either express or implied. See the License for the +// * specific language governing permissions and limitations +// * under the License. +// */ + +// //! This module implements the [`NDArray`] type for working with *TVM tensors* or +// //! coverting from a Rust's ndarray to TVM `NDArray`. +// //! +// //! One can create an empty NDArray given the shape, context and dtype using [`empty`]. +// //! To create an NDArray from a mutable buffer in cpu use [`copy_from_buffer`]. +// //! To copy an NDArray to different context use [`copy_to_ctx`]. +// //! +// //! Given a [`Rust's dynamic ndarray`], one can convert it to TVM NDArray as follows: +// //! +// //! # Example +// //! +// //! ``` +// //! # use tvm_rt::{NDArray, Context, DataType}; +// //! # use ndarray::{Array, ArrayD}; +// //! # use std::str::FromStr; +// //! use std::convert::TryFrom; +// //! +// //! let a = Array::from_shape_vec((2, 2), vec![1f32, 2., 3., 4.]) +// //! .unwrap() +// //! .into_dyn(); // Rust's ndarray +// //! let nd = NDArray::from_rust_ndarray(&a, Context::cpu(0), DataType::from_str("float32").unwrap()).unwrap(); +// //! assert_eq!(nd.shape(), Some(&mut [2, 2][..])); +// //! let rnd: ArrayD = ArrayD::try_from(&nd).unwrap(); +// //! assert!(rnd.all_close(&a, 1e-8f32)); +// //! ``` +// //! +// //! [`Rust's dynamic ndarray`]:https://docs.rs/ndarray/0.12.1/ndarray/ +// //! [`copy_from_buffer`]:struct.NDArray.html#method.copy_from_buffer +// //! [`copy_to_ctx`]:struct.NDArray.html#method.copy_to_ctx + +// use std::convert::TryInto; +// use std::ffi::c_void; +// use std::{convert::TryFrom, mem, os::raw::c_int, ptr, slice, str::FromStr}; + +// use tvm_sys::ffi::DLTensor; +// use tvm_sys::{ffi, ByteArray, Context, DataType}; +// use tvm_macros::Object; + +// use ndarray::{Array, ArrayD}; +// use num_traits::Num; + +// use crate::object::{Object, ObjectPtr} + +// /// See the [`module-level documentation`](../ndarray/index.html) for more details. +// #[repr(C)] +// #[derive(Object)] +// #[ref_name = "NDArray"] +// #[type_key = "runtime.NDArray"] +// pub struct NDArrayContainer { +// base: Object, +// dl_tensor: *mut DLTensor, +// manager_ctx: *mut c_void, +// } + + +// impl NDArray { +// pub(crate) fn from_raw(handle: ffi::TVMArrayHandle) -> Self { +// let object: *mut NDArrayContainer = unsafe { std::mem::transmute(handle) }; +// let object_ptr = ObjectPtr::from_raw(object); +// NDArray(Some(object_ptr)) +// } + +// pub fn as_dltensor(&self) -> &DLTensor { +// let ptr: *mut DLTensor = match self { +// NDArray::Borrowed { ref handle } => *handle, +// NDArray::Owned { ref handle } => *handle as *mut DLTensor, +// }; + +// unsafe { std::mem::transmute(ptr) } +// } + +// pub(crate) fn as_raw_dltensor(&self) -> *mut DLTensor { +// match self { +// NDArray::Borrowed { handle } => *handle, +// NDArray::Owned { handle } => *handle as *mut DLTensor, +// } +// } + +// pub fn is_view(&self) -> bool { +// if let &NDArray::Borrowed { .. } = self { +// true +// } else { +// false +// } +// } + +// /// Returns the shape of the NDArray. +// pub fn shape(&self) -> Option<&mut [usize]> { +// let arr = self.as_dltensor(); +// if arr.shape.is_null() || arr.data.is_null() { +// return None; +// }; +// let slc = unsafe { slice::from_raw_parts_mut(arr.shape as *mut usize, arr.ndim as usize) }; +// Some(slc) +// } + +// /// Returns the total number of entries of the NDArray. +// pub fn size(&self) -> Option { +// self.shape().map(|v| v.iter().product()) +// } + +// /// Returns the context which the NDArray was defined. +// pub fn ctx(&self) -> Context { +// self.as_dltensor().ctx.into() +// } + +// /// Returns the type of the entries of the NDArray. +// pub fn dtype(&self) -> DataType { +// self.as_dltensor().dtype.into() +// } + +// /// Returns the number of dimensions of the NDArray. +// pub fn ndim(&self) -> usize { +// self.as_dltensor() +// .ndim +// .try_into() +// .expect("number of dimensions must always be positive") +// } + +// /// Returns the strides of the underlying NDArray. +// pub fn strides(&self) -> Option<&[usize]> { +// unsafe { +// let sz = self.ndim() * mem::size_of::(); +// let strides_ptr = self.as_dltensor().strides as *const usize; +// let slc = slice::from_raw_parts(strides_ptr, sz); +// Some(slc) +// } +// } + +// /// Shows whether the underlying ndarray is contiguous in memory or not. +// pub fn is_contiguous(&self) -> Result { +// Ok(match self.strides() { +// None => true, +// Some(strides) => { +// // NDArrayError::MissingShape in case shape is not determined +// self.shape() +// .ok_or(NDArrayError::MissingShape)? +// .iter() +// .zip(strides) +// .rfold( +// (true, 1), +// |(is_contig, expected_stride), (shape, stride)| { +// ( +// is_contig && *stride == expected_stride, +// expected_stride * (*shape as usize), +// ) +// }, +// ) +// .0 +// } +// }) +// } + +// pub fn byte_offset(&self) -> isize { +// self.as_dltensor().byte_offset as isize +// } + +// /// Flattens the NDArray to a `Vec` of the same type in cpu. +// /// +// /// ## Example +// /// +// /// ``` +// /// # use tvm_rt::{Context, DataType, NDArray}; +// /// # use std::str::FromStr; +// /// let mut shape = [4]; +// /// let mut data = vec![1i32, 2, 3, 4]; +// /// let ctx = Context::cpu(0); +// /// let mut ndarray = NDArray::empty(&mut shape, ctx, DataType::from_str("int32").unwrap()); +// /// ndarray.copy_from_buffer(&mut data); +// /// assert_eq!(ndarray.shape(), Some(&mut shape[..])); +// /// assert_eq!(ndarray.to_vec::().unwrap(), data); +// /// ``` +// pub fn to_vec(&self) -> Result, NDArrayError> { +// if !self.shape().is_some() { +// return Err(NDArrayError::EmptyArray); +// } +// let earr = NDArray::empty( +// self.shape().ok_or(NDArrayError::MissingShape)?, +// Context::cpu(0), +// self.dtype(), +// ); +// let target = self.copy_to_ndarray(earr)?; +// let arr = target.as_dltensor(); +// let sz = self.size().ok_or(NDArrayError::MissingShape)?; +// let mut v: Vec = Vec::with_capacity(sz * mem::size_of::()); +// unsafe { +// v.as_mut_ptr() +// .copy_from_nonoverlapping(arr.data as *const T, sz); +// v.set_len(sz); +// } +// Ok(v) +// } + +// /// Converts the NDArray to [`ByteArray`]. +// pub fn to_bytearray(&self) -> Result { +// let v = self.to_vec::()?; +// Ok(ByteArray::from(v)) +// } + +// /// Creates an NDArray from a mutable buffer of types i32, u32 or f32 in cpu. +// /// +// /// ## Example +// /// +// /// ``` +// /// # use tvm_rt::{Context, DataType, NDArray}; +// /// # use std::str::FromStr; +// /// let shape = &mut [2]; +// /// let mut data = vec![1f32, 2.0]; +// /// let ctx = Context::cpu(0); +// /// let mut ndarray = NDArray::empty(shape, ctx, DataType::from_str("int32").unwrap()); +// /// ndarray.copy_from_buffer(&mut data); +// /// ``` +// /// +// /// *Note*: if something goes wrong during the copy, it will panic +// /// from TVM side. See `TVMArrayCopyFromBytes` in `include/tvm/runtime/c_runtime_api.h`. +// pub fn copy_from_buffer(&mut self, data: &mut [T]) { +// check_call!(ffi::TVMArrayCopyFromBytes( +// self.as_raw_dltensor(), +// data.as_ptr() as *mut _, +// data.len() * mem::size_of::() +// )); +// } + +// /// Copies the NDArray to another target NDArray. +// pub fn copy_to_ndarray(&self, target: NDArray) -> Result { +// if self.dtype() != target.dtype() { +// return Err(NDArrayError::DataTypeMismatch { +// expected: self.dtype(), +// actual: target.dtype(), +// }); +// } + +// check_call!(ffi::TVMArrayCopyFromTo( +// self.as_raw_dltensor(), +// target.as_raw_dltensor(), +// ptr::null_mut() as ffi::TVMStreamHandle +// )); + +// Ok(target) +// } + +// /// Copies the NDArray to a target context. +// pub fn copy_to_ctx(&self, target: &Context) -> Result { +// let tmp = NDArray::empty( +// self.shape().ok_or(NDArrayError::MissingShape)?, +// *target, +// self.dtype(), +// ); +// let copy = self.copy_to_ndarray(tmp)?; +// Ok(copy) +// } + +// /// Converts a Rust's ndarray to TVM NDArray. +// pub fn from_rust_ndarray( +// rnd: &ArrayD, +// ctx: Context, +// dtype: DataType, +// ) -> Result { +// let shape = rnd.shape().to_vec(); +// let mut nd = NDArray::empty(&shape, ctx, dtype); +// let mut buf = Array::from_iter(rnd.into_iter().map(|&v| v as T)); +// nd.copy_from_buffer( +// buf.as_slice_mut() +// .expect("Array from iter must be contiguous."), +// ); +// Ok(nd) +// } + +// /// Allocates and creates an empty NDArray given the shape, context and dtype. +// pub fn empty(shape: &[usize], ctx: Context, dtype: DataType) -> NDArray { +// let mut handle = ptr::null_mut() as ffi::TVMArrayHandle; +// let dtype: tvm_sys::ffi::DLDataType = dtype.into(); +// check_call!(ffi::TVMArrayAlloc( +// shape.as_ptr() as *const i64, +// shape.len() as c_int, +// i32::from(dtype.code) as c_int, +// i32::from(dtype.bits) as c_int, +// i32::from(dtype.lanes) as c_int, +// ctx.device_type as c_int, +// ctx.device_id as c_int, +// &mut handle as *mut _, +// )); +// NDArray::Borrowed { handle: handle } +// } +// } + +// macro_rules! impl_from_ndarray_rustndarray { +// ($type:ty, $type_name:tt) => { +// impl<'a> TryFrom<&'a NDArray> for ArrayD<$type> { +// type Error = NDArrayError; + +// fn try_from(nd: &NDArray) -> Result, Self::Error> { +// if !nd.shape().is_some() { +// return Err(NDArrayError::MissingShape); +// } +// assert_eq!(nd.dtype(), DataType::from_str($type_name)?, "Type mismatch"); +// Ok(Array::from_shape_vec( +// &*nd.shape().ok_or(NDArrayError::MissingShape)?, +// nd.to_vec::<$type>()?, +// )?) +// } +// } + +// impl<'a> TryFrom<&'a mut NDArray> for ArrayD<$type> { +// type Error = NDArrayError; + +// fn try_from(nd: &mut NDArray) -> Result, Self::Error> { +// if !nd.shape().is_some() { +// return Err(NDArrayError::MissingShape); +// }; +// assert_eq!(nd.dtype(), DataType::from_str($type_name)?, "Type mismatch"); +// Ok(Array::from_shape_vec( +// &*nd.shape().ok_or(NDArrayError::MissingShape)?, +// nd.to_vec::<$type>()?, +// )?) +// } +// } +// }; +// } + +// impl_from_ndarray_rustndarray!(i32, "int"); +// impl_from_ndarray_rustndarray!(u32, "uint"); +// impl_from_ndarray_rustndarray!(f32, "float"); + +// impl Drop for NDArray { +// fn drop(&mut self) { +// if let &mut NDArray::Owned { .. } = self { +// check_call!(ffi::TVMArrayFree(self.as_raw_dltensor())); +// } +// } +// } + +// mod sealed { +// /// Private trait to prevent other traits from being implemeneted in downstream crates. +// pub trait Sealed {} +// } + +// /// A trait for the supported 32-bits numerical types in frontend. +// pub trait Num32: Num + sealed::Sealed { +// const BITS: u8 = 32; +// } + +// macro_rules! impl_num32 { +// ($($type:ty),+) => { +// $( +// impl sealed::Sealed for $type {} +// impl Num32 for $type {} +// )+ +// }; +// } + +// impl_num32!(i32, u32, f32); + +// #[cfg(test)] +// mod tests { +// // use super::*; + +// // #[test] +// // fn basics() { +// // let shape = &mut [1, 2, 3]; +// // let ctx = Context::cpu(0); +// // let ndarray = NDArray::empty(shape, ctx, DataType::from_str("int32").unwrap()); +// // assert_eq!(ndarray.shape().unwrap(), shape); +// // assert_eq!( +// // ndarray.size().unwrap(), +// // shape.to_vec().into_iter().product() +// // ); +// // assert_eq!(ndarray.ndim(), 3); +// // assert!(ndarray.strides().is_none()); +// // assert_eq!(ndarray.byte_offset(), 0); +// // } + +// // #[test] +// // fn copy() { +// // let shape = &mut [4]; +// // let mut data = vec![1i32, 2, 3, 4]; +// // let ctx = Context::cpu(0); +// // let mut ndarray = NDArray::empty(shape, ctx, DataType::from_str("int32").unwrap()); +// // assert!(ndarray.to_vec::().is_ok()); +// // ndarray.copy_from_buffer(&mut data); +// // assert_eq!(ndarray.shape().unwrap(), shape); +// // assert_eq!(ndarray.to_vec::().unwrap(), data); +// // assert_eq!(ndarray.ndim(), 1); +// // assert!(ndarray.is_contiguous().is_ok()); +// // assert_eq!(ndarray.byte_offset(), 0); +// // let shape = vec![4]; +// // let e = NDArray::empty( +// // &shape, +// // Context::cpu(0), +// // DataType::from_str("int32").unwrap(), +// // ); +// // let nd = ndarray.copy_to_ndarray(e); +// // assert!(nd.is_ok()); +// // assert_eq!(nd.unwrap().to_vec::().unwrap(), data); +// // } + +// // // #[test] +// // // #[should_panic(expected = "called `Result::unwrap()` on an `Err`")] +// // // fn copy_wrong_dtype() { +// // // let shape = vec![4]; +// // // let mut data = vec![1f32, 2., 3., 4.]; +// // // let ctx = Context::cpu(0); +// // // let mut nd_float = NDArray::empty(&shape, ctx, DataType::from_str("float32").unwrap()); +// // // nd_float.copy_from_buffer(&mut data); +// // // let empty_int = NDArray::empty(&shape, ctx, DataType::from_str("int32").unwrap()); +// // // nd_float.copy_to_ndarray(empty_int).unwrap(); +// // // } + +// // #[test] +// // fn rust_ndarray() { +// // let a = Array::from_shape_vec((2, 2), vec![1f32, 2., 3., 4.]) +// // .unwrap() +// // .into_dyn(); +// // let nd = +// // NDArray::from_rust_ndarray(&a, Context::cpu(0), DataType::from_str("float32").unwrap()) +// // .unwrap(); +// // assert_eq!(nd.shape().unwrap(), &mut [2, 2]); +// // let rnd: ArrayD = ArrayD::try_from(&nd).unwrap(); +// // assert!(rnd.all_close(&a, 1e-8f32)); +// // } +// } diff --git a/rust/tvm/Cargo.toml b/rust/tvm/Cargo.toml index c42336cd382c..55fc1790604e 100644 --- a/rust/tvm/Cargo.toml +++ b/rust/tvm/Cargo.toml @@ -43,5 +43,7 @@ once_cell = "^1.3.1" pyo3 = { version = "0.11.1", optional = true } [features] +default = ["python"] + blas = ["ndarray/blas"] python = ["pyo3"] diff --git a/rust/tvm/src/ir/arith.rs b/rust/tvm/src/ir/arith.rs index c2de24a299f7..f589f2ac25c6 100644 --- a/rust/tvm/src/ir/arith.rs +++ b/rust/tvm/src/ir/arith.rs @@ -19,7 +19,7 @@ use crate::runtime::{Object, ObjectPtr}; -use super::*; +use tvm_macros::Object; macro_rules! define_node { ($name:ident, $ref:expr, $typekey:expr; $node:ident { $($id:ident : $t:ty),*}) => { diff --git a/rust/tvm/src/ir/expr.rs b/rust/tvm/src/ir/expr.rs new file mode 100644 index 000000000000..20e92f9f5925 --- /dev/null +++ b/rust/tvm/src/ir/expr.rs @@ -0,0 +1,81 @@ +use crate::runtime::String as TString; +use crate::runtime::{self, external, IsObject, IsObjectRef, Object, ObjectPtr, ObjectRef}; +use crate::DataType; +use super::relay; + +use tvm_macros::Object; + +#[repr(C)] +#[derive(Object)] +#[ref_name = "BaseExpr"] +#[type_key = "Expr"] +pub struct BaseExprNode { + pub base: Object, +} + +impl BaseExprNode { + pub fn base() -> BaseExprNode { + BaseExprNode { + base: Object::base_object::(), + } + } +} + +#[repr(C)] +#[derive(Object)] +#[ref_name = "PrimExpr"] +#[type_key = "PrimExpr"] +pub struct PrimExprNode { + pub base: BaseExprNode, + pub datatype: DataType, +} + +impl PrimExprNode { + pub fn base(datatype: DataType) -> PrimExprNode { + PrimExprNode { + base: BaseExprNode::base::(), + datatype, + } + } +} + +#[repr(C)] +#[derive(Object)] +#[ref_name = "GlobalVar"] +#[type_key = "GlobalVar"] +pub struct GlobalVarNode { + pub base: relay::ExprNode, + pub name_hint: TString, +} + +impl GlobalVar { + pub fn new(name_hint: String, _span: ObjectRef) -> GlobalVar { + let node = GlobalVarNode { + base: relay::ExprNode::base::(), + name_hint: name_hint.into(), + }; + GlobalVar(Some(ObjectPtr::new(node))) + } +} + +// TODO(@jroesch): update to match TVM +// Move IntImm +// Define FloatImm +// Define Bool +// Define tvm::Integer? +// Define RangeNode + +// TODO: figure out how to type the last argument runtime::TypedPackedFunc annotate) +external! { + #[name("ir.AsText")] + fn _as_text(object: ObjectRef, show_meta_data: i32, annotate: runtime::Function) -> TString; +} + +pub fn as_text(object: T) -> String { + let no_func = unsafe { runtime::Function::null() }; + _as_text(object.upcast(), 0, no_func) + .unwrap() + .as_str() + .unwrap() + .into() +} diff --git a/rust/tvm/src/ir/function.rs b/rust/tvm/src/ir/function.rs new file mode 100644 index 000000000000..e1294f1311de --- /dev/null +++ b/rust/tvm/src/ir/function.rs @@ -0,0 +1,27 @@ +use crate::runtime::{IsObjectRef, IsObject, ObjectRef}; +use crate::ir::relay::ExprNode; + +use tvm_macros::Object; + +// Define Calling Convention. + +// TODO(@jroesch): define DictAttrs +pub type DictAttrs = ObjectRef; + +#[repr(C)] +#[derive(Object)] +#[ref_name = "BaseFunc"] +#[type_key = "BaseFunc"] +pub struct BaseFuncNode { + pub base: ExprNode, + pub attrs: DictAttrs, +} + +impl BaseFuncNode { + pub fn base() -> BaseFuncNode { + BaseFuncNode { + base: ExprNode::base::(), + attrs: ::null(), + } + } +} diff --git a/rust/tvm/src/ir/mod.rs b/rust/tvm/src/ir/mod.rs index b615c1ec588e..08624af6aad1 100644 --- a/rust/tvm/src/ir/mod.rs +++ b/rust/tvm/src/ir/mod.rs @@ -17,60 +17,12 @@ * under the License. */ -use crate::runtime::String as TString; -use crate::runtime::{self, external, IsObject, IsObjectRef, Object, ObjectRef}; -use crate::DataType; -use tvm_macros::Object; - pub mod arith; +pub mod expr; +pub mod function; +pub mod module; pub mod relay; pub mod tir; +pub mod ty; -// TODO: figure out how to type the last argument runtime::TypedPackedFunc annotate) -external! { - #[name("ir.AsText")] - fn _as_text(object: ObjectRef, show_meta_data: i32, annotate: runtime::Function) -> TString; -} - -pub fn as_text(object: T) -> String { - let no_func = unsafe { runtime::Function::null() }; - _as_text(object.upcast(), 0, no_func) - .unwrap() - .as_str() - .unwrap() - .into() -} - -#[repr(C)] -#[derive(Object)] -#[ref_name = "BaseExpr"] -#[type_key = "Expr"] -pub struct BaseExprNode { - pub base: Object, -} - -impl BaseExprNode { - fn base() -> BaseExprNode { - BaseExprNode { - base: Object::base_object::(), - } - } -} - -#[repr(C)] -#[derive(Object)] -#[ref_name = "PrimExpr"] -#[type_key = "PrimExpr"] -pub struct PrimExprNode { - pub base: BaseExprNode, - pub datatype: DataType, -} - -impl PrimExprNode { - pub fn base(datatype: DataType) -> PrimExprNode { - PrimExprNode { - base: BaseExprNode::base::(), - datatype, - } - } -} +pub use expr::*; diff --git a/rust/tvm/src/ir/module.rs b/rust/tvm/src/ir/module.rs new file mode 100644 index 000000000000..ecd374f6f420 --- /dev/null +++ b/rust/tvm/src/ir/module.rs @@ -0,0 +1,49 @@ +use crate::runtime::{external, Object, ObjectRef}; +use crate::runtime::{string::String as TVMString}; +use crate::runtime::map::Map; + +use super::expr::GlobalVar; +use super::function::BaseFunc; + +use std::io::Result; +use std::path::Path; + +use tvm_macros::Object; + +// TODO(@jroesch): define type +type TypeData = ObjectRef; +type GlobalTypeVar = ObjectRef; + +#[repr(C)] +#[derive(Object)] +#[ref_name = "IRModule"] +#[type_key = "IRModule"] +pub struct IRModuleNode { + pub base: Object, + pub functions: Map, + pub type_definitions: Map, +} + + +external! { + #[name("parser.ParseModule")] + fn parse_module(file_name: TVMString, source: TVMString) -> IRModule; + #[name("parser.ParseExpr")] + fn parse_expression(file_name: TVMString, source: TVMString) -> IRModule; +} + +impl IRModule { + pub fn parse(file_name: N, source: S) -> IRModule + where N: Into, S: Into { + parse_module(file_name.into(), source.into()) + .expect("failed to call parser") + } + + pub fn parse_file>(file_path: P) -> Result { + let file_path = file_path.as_ref(); + let file_path_as_str = file_path.to_str().unwrap().to_string(); + let source = std::fs::read_to_string(file_path)?; + let module = IRModule::parse(file_path_as_str, source); + Ok(module) + } +} diff --git a/rust/tvm/src/ir/relay/mod.rs b/rust/tvm/src/ir/relay/mod.rs index 4f4497ea0fce..e1f0ed483887 100644 --- a/rust/tvm/src/ir/relay/mod.rs +++ b/rust/tvm/src/ir/relay/mod.rs @@ -19,63 +19,30 @@ use crate::runtime::array::Array; use crate::runtime::{object::*, String as TString}; -use crate::DataType; -use tvm_macros::Object; -#[repr(C)] -#[derive(Object)] -#[ref_name = "Id"] -#[type_key = "relay.Id"] -pub struct IdNode { - pub base: Object, - pub name_hint: TString, -} - -impl Id { - fn new(name_hint: TString) -> Id { - let node = IdNode { - base: Object::base_object::(), - name_hint: name_hint, - }; - Id(Some(ObjectPtr::new(node))) - } -} +use super::expr::{BaseExprNode}; +use super::function::BaseFuncNode; +use super::ty::Type; -#[repr(C)] -#[derive(Object)] -#[ref_name = "BaseExpr"] -#[type_key = "Expr"] -pub struct BaseExprNode { - pub base: Object, -} +use tvm_macros::Object; -#[repr(C)] -pub struct PrimExprNode { - pub base: BaseExprNode, - pub datatype: DataType, -} +pub use super::expr::{GlobalVarNode, GlobalVar}; -impl BaseExprNode { - fn base() -> BaseExprNode { - BaseExprNode { - base: Object::base_object::(), - } - } -} +pub type Attrs = ObjectRef; #[repr(C)] #[derive(Object)] #[ref_name = "Expr"] #[type_key = "relay.Expr"] -pub struct RelayExpr { +pub struct ExprNode { pub base: BaseExprNode, pub span: ObjectRef, pub checked_type: ObjectRef, } -impl RelayExpr { - fn base() -> RelayExpr { - RelayExpr { +impl ExprNode { + pub fn base() -> ExprNode { + ExprNode { base: BaseExprNode::base::(), span: ObjectRef::null(), checked_type: ObjectRef::null(), @@ -83,60 +50,81 @@ impl RelayExpr { } } + #[repr(C)] #[derive(Object)] -#[ref_name = "GlobalVar"] -#[type_key = "GlobalVar"] -pub struct GlobalVarNode { - pub base: RelayExpr, +#[ref_name = "Id"] +#[type_key = "relay.Id"] +pub struct IdNode { + pub base: Object, pub name_hint: TString, } -impl GlobalVar { - pub fn new(name_hint: String, _span: ObjectRef) -> GlobalVar { - let node = GlobalVarNode { - base: RelayExpr::base::(), - name_hint: name_hint.into(), +impl Id { + fn new(name_hint: TString) -> Id { + let node = IdNode { + base: Object::base_object::(), + name_hint: name_hint, }; - GlobalVar(Some(ObjectPtr::new(node))) + Id(Some(ObjectPtr::new(node))) } } + #[repr(C)] #[derive(Object)] #[ref_name = "Constant"] #[type_key = "relay.Constant"] pub struct ConstantNode { - pub base: RelayExpr, + pub base: ExprNode, pub data: ObjectRef, // make this NDArray. } impl Constant { pub fn new(data: ObjectRef, _span: ObjectRef) -> Constant { let node = ConstantNode { - base: RelayExpr::base::(), + base: ExprNode::base::(), data: data, }; Constant(Some(ObjectPtr::new(node))) } } +#[repr(C)] +#[derive(Object)] +#[ref_name = "Tuple"] +#[type_key = "relay.Tuple"] +pub struct TupleNode { + pub base: ExprNode, + pub fields: Array, +} + +impl Tuple { + pub fn new(fields: Array, _span: ObjectRef) -> Tuple { + let node = TupleNode { + base: ExprNode::base::(), + fields, + }; + Tuple(Some(ObjectPtr::new(node))) + } +} + #[repr(C)] #[derive(Object)] #[ref_name = "Var"] #[type_key = "relay.Var"] pub struct VarNode { - pub base: RelayExpr, + pub base: ExprNode, pub vid: Id, - pub type_annotation: ObjectRef, + pub type_annotation: Type, } impl Var { - pub fn new(name_hint: String, _span: ObjectRef) -> Var { + pub fn new(name_hint: String, type_annotation: Type, _span: ObjectRef) -> Var { let node = VarNode { - base: RelayExpr::base::(), + base: ExprNode::base::(), vid: Id::new(name_hint.into()), - type_annotation: ObjectRef::null(), + type_annotation, }; Var(Some(ObjectPtr::new(node))) } @@ -150,19 +138,17 @@ impl Var { } } -pub type Type = ObjectRef; -pub type Attrs = ObjectRef; #[repr(C)] #[derive(Object)] #[ref_name = "Call"] #[type_key = "relay.Call"] pub struct CallNode { - pub base: RelayExpr, + pub base: ExprNode, pub op: Expr, pub args: Array, - pub attrs: ObjectRef, - pub type_args: Array, + pub attrs: Attrs, + pub type_args: Array, } impl Call { @@ -170,11 +156,11 @@ impl Call { op: Expr, args: Array, attrs: Attrs, - type_args: Array, + type_args: Array, _span: ObjectRef, ) -> Call { let node = CallNode { - base: RelayExpr::base::(), + base: ExprNode::base::(), op: op, args: args, attrs: attrs, @@ -186,22 +172,294 @@ impl Call { #[repr(C)] #[derive(Object)] -#[ref_name = "BaseFunc"] -#[type_key = "BaseFunc"] -pub struct BaseFuncNode { - pub base: RelayExpr, - pub attrs: ObjectRef, -} - -impl BaseFuncNode { - fn base() -> BaseFuncNode { - BaseFuncNode { - base: RelayExpr::base::(), - attrs: ObjectRef::null(), +#[ref_name = "Let"] +#[type_key = "relay.Let"] +pub struct LetNode { + pub base: ExprNode, + pub var: Var, + pub value: Expr, + pub body: Expr, +} + +impl Let { + pub fn new(var: Var, value: Expr, body: Expr, _span: ObjectRef) -> Let { + let node = LetNode { + base: ExprNode::base::(), + var, + value, + body + }; + Let(Some(ObjectPtr::new(node))) + } +} + +#[repr(C)] +#[derive(Object)] +#[ref_name = "If"] +#[type_key = "relay.If"] +pub struct IfNode { + pub base: ExprNode, + pub cond: Expr, + pub true_branch: Expr, + pub false_branch: Expr, +} + +impl If { + pub fn new(cond: Expr, true_branch: Expr, false_branch: Expr, _span: ObjectRef) -> If { + let node = IfNode { + base: ExprNode::base::(), + cond, + true_branch, + false_branch, + }; + If(Some(ObjectPtr::new(node))) + } +} + +#[repr(C)] +#[derive(Object)] +#[ref_name = "TupleGetItem"] +#[type_key = "relay.TupleGetItem"] +pub struct TupleGetItemNode { + pub base: ExprNode, + pub tuple: Expr, + pub index: i32, +} + +impl TupleGetItem { + pub fn new(tuple: Expr, index: i32, _span: ObjectRef) -> TupleGetItem { + let node = TupleGetItemNode { + base: ExprNode::base::(), + tuple, + index, + }; + TupleGetItem(Some(ObjectPtr::new(node))) + } +} + +#[repr(C)] +#[derive(Object)] +#[ref_name = "RefCreate"] +#[type_key = "relay.RefCreate"] +pub struct RefCreateNode { + pub base: ExprNode, + pub value: Expr, +} + +impl RefCreate { + pub fn new(value: Expr, _span: ObjectRef) -> RefCreate { + let node = RefCreateNode { + base: ExprNode::base::(), + value, + }; + RefCreate(Some(ObjectPtr::new(node))) + } +} + +#[repr(C)] +#[derive(Object)] +#[ref_name = "RefRead"] +#[type_key = "relay.RefRead"] +pub struct RefReadNode { + pub base: ExprNode, + pub ref_value: Expr, +} + +impl RefRead { + pub fn new(ref_value: Expr, _span: ObjectRef) -> RefRead { + let node = RefReadNode { + base: ExprNode::base::(), + ref_value + }; + RefRead(Some(ObjectPtr::new(node))) + } +} + +#[repr(C)] +#[derive(Object)] +#[ref_name = "RefWrite"] +#[type_key = "relay.RefWrite"] +pub struct RefWriteNode { + pub base: ExprNode, + pub ref_value: Expr, + pub value: Expr, +} + +impl RefWrite { + pub fn new(ref_value: Expr, value: Expr, _span: ObjectRef) -> RefWrite { + let node = RefWriteNode { + base: ExprNode::base::(), + ref_value, + value, + }; + RefWrite(Some(ObjectPtr::new(node))) + } +} + +#[repr(C)] +#[derive(Object)] +#[ref_name = "Constructor"] +#[type_key = "relay.Constructor"] +pub struct ConstructorNode { + pub base: ExprNode, + pub name_hint: String, + pub inputs: Array, + pub tag: i32, +} + +impl Constructor { + pub fn new(name_hint: String, inputs: Array, tag: i32, _span: ObjectRef) -> Constructor { + let node = ConstructorNode { + base: ExprNode::base::(), + name_hint, + inputs, + tag, + }; + Constructor(Some(ObjectPtr::new(node))) + } +} + +// TODO(@jroesch): define the type data + + +#[repr(C)] +#[derive(Object)] +#[ref_name = "Pattern"] +#[type_key = "relay.Pattern"] +pub struct PatternNode { + pub base: Object, + pub span: ObjectRef +} + +impl PatternNode { + pub fn base() -> PatternNode { + PatternNode { + base: Object::base_object::(), + span: ObjectRef::null(), } } } +#[repr(C)] +#[derive(Object)] +#[ref_name = "PatternWildcard"] +#[type_key = "relay.PatternWildcard"] +pub struct PatternWildcardNode { + pub base: PatternNode, +} + +impl PatternWildcard { + pub fn new(_span: ObjectRef) -> PatternWildcard { + let node = PatternWildcardNode { + base: PatternNode::base::(), + }; + PatternWildcard(Some(ObjectPtr::new(node))) + } +} + +#[repr(C)] +#[derive(Object)] +#[ref_name = "PatternVar"] +#[type_key = "relay.PatternVar"] +pub struct PatternVarNode { + pub base: PatternNode, + pub var: Var, +} + +impl PatternVar { + pub fn new(var: Var, _span: ObjectRef) -> PatternVar { + let node = PatternVarNode { + base: PatternNode::base::(), + var: var, + }; + PatternVar(Some(ObjectPtr::new(node))) + } +} + + +#[repr(C)] +#[derive(Object)] +#[ref_name = "PatternConstructor"] +#[type_key = "relay.PatternConstructor"] +pub struct PatternConstructorNode { + pub base: PatternNode, + pub constructor: Constructor, + pub patterns: Array, +} + +impl PatternConstructor { + pub fn new(constructor: Constructor, patterns: Array, _span: ObjectRef) -> PatternConstructor { + let node = PatternConstructorNode { + base: PatternNode::base::(), + constructor, + patterns, + }; + PatternConstructor(Some(ObjectPtr::new(node))) + } +} + +#[repr(C)] +#[derive(Object)] +#[ref_name = "PatternTuple"] +#[type_key = "relay.PatternTuple"] +pub struct PatternTupleNode { + pub base: PatternNode, + pub patterns: Array, +} + +impl PatternTuple { + pub fn new(patterns: Array, _span: ObjectRef) -> PatternTuple { + let node = PatternTupleNode { + base: PatternNode::base::(), + patterns, + }; + PatternTuple(Some(ObjectPtr::new(node))) + } +} + +#[repr(C)] +#[derive(Object)] +#[ref_name = "Clause"] +#[type_key = "relay.Clause"] +pub struct ClauseNode { + pub base: Object, + pub lhs: Pattern, + pub rhs: Expr, +} + +impl Clause { + pub fn new(lhs: Pattern, rhs: Expr, _span: ObjectRef) -> Clause { + let node = ClauseNode { + base: Object::base_object::(), + lhs, rhs, + }; + Clause(Some(ObjectPtr::new(node))) + } +} + +#[repr(C)] +#[derive(Object)] +#[ref_name = "Match"] +#[type_key = "relay.Match"] +pub struct MatchNode { + pub base: ExprNode, + pub data: Expr, + pub clauses: Array, + pub complete: bool, +} + +impl Match { + pub fn new(data: Expr, clauses: Array, complete: bool, _span: ObjectRef) -> Match { + let node = MatchNode { + base: ExprNode::base::(), + data, + clauses, + complete, + }; + Match(Some(ObjectPtr::new(node))) + } +} + #[repr(C)] #[derive(Object)] #[ref_name = "Function"] @@ -263,20 +521,4 @@ mod tests { assert!(text.contains("%local")); Ok(()) } - - use super::Array; - use crate::ir::relay::Var; - use crate::runtime::object::ObjectRef; - - #[test] - fn create_array_and_get() -> Result<()> { - let vec = vec![ - Var::new("foo".into(), ObjectRef::null()), - Var::new("bar".into(), ObjectRef::null()), - ]; - let array = Array::from_vec(vec)?; - assert_eq!(array.get(0)?.name_hint().to_string(), "foo"); - assert_eq!(array.get(1)?.name_hint().to_string(), "bar"); - Ok(()) - } } diff --git a/rust/tvm/src/ir/tir.rs b/rust/tvm/src/ir/tir.rs index ee30c513e9f0..a19f0cbc7869 100644 --- a/rust/tvm/src/ir/tir.rs +++ b/rust/tvm/src/ir/tir.rs @@ -19,8 +19,9 @@ use crate::runtime::String as TVMString; use crate::DataType; +use super::{PrimExprNode, PrimExpr}; -use super::*; +use tvm_macros::Object; macro_rules! define_node { ($name:ident, $ref:expr, $typekey:expr; $node:ident { $($id:ident : $t:ty),*}) => { @@ -43,6 +44,7 @@ macro_rules! define_node { } } +// TODO(@jroesch): should move up to expr.rs to mirror TVM. define_node!(IntImm, "IntImm", "IntImm"; IntImmNode { value: i64 }); define_node!(Var, "Var", "tir.Var"; diff --git a/rust/tvm/src/ir/ty.rs b/rust/tvm/src/ir/ty.rs new file mode 100644 index 000000000000..c73f1d586edd --- /dev/null +++ b/rust/tvm/src/ir/ty.rs @@ -0,0 +1,3 @@ +use crate::runtime::ObjectRef; + +pub type Type = ObjectRef; From a80a9e425694941e56fb2b7b431f3cb50e2c891a Mon Sep 17 00:00:00 2001 From: Jared Roesch Date: Mon, 7 Sep 2020 00:48:47 -0700 Subject: [PATCH 04/50] More tweeks for getting functions out --- rust/tvm-sys/src/packed_func.rs | 15 ++++++ rust/tvm/src/ir/mod.rs | 1 + rust/tvm/src/ir/module.rs | 87 ++++++++++++++++++++++++++++++++- 3 files changed, 101 insertions(+), 2 deletions(-) diff --git a/rust/tvm-sys/src/packed_func.rs b/rust/tvm-sys/src/packed_func.rs index 358853951fda..f7b289c59675 100644 --- a/rust/tvm-sys/src/packed_func.rs +++ b/rust/tvm-sys/src/packed_func.rs @@ -409,3 +409,18 @@ impl<'a> TryFrom> for bool { try_downcast!(val -> bool, |ArgValue::Int(val)| { !(val == 0) }) } } + +impl From<()> for RetValue { + fn from(_: ()) -> Self { + RetValue::Null + } +} + +impl TryFrom for () { + type Error = ValueDowncastError; + + fn try_from(val: RetValue) -> Result<(), Self::Error> { + try_downcast!(val -> bool, + |RetValue::Null| { () }) + } +} diff --git a/rust/tvm/src/ir/mod.rs b/rust/tvm/src/ir/mod.rs index 08624af6aad1..0620a10ab1ad 100644 --- a/rust/tvm/src/ir/mod.rs +++ b/rust/tvm/src/ir/mod.rs @@ -26,3 +26,4 @@ pub mod tir; pub mod ty; pub use expr::*; +pub use module::IRModule; diff --git a/rust/tvm/src/ir/module.rs b/rust/tvm/src/ir/module.rs index ecd374f6f420..eee828f7632a 100644 --- a/rust/tvm/src/ir/module.rs +++ b/rust/tvm/src/ir/module.rs @@ -1,11 +1,13 @@ use crate::runtime::{external, Object, ObjectRef}; use crate::runtime::{string::String as TVMString}; +use crate::runtime::function::Result; +use crate::runtime::array::Array; use crate::runtime::map::Map; use super::expr::GlobalVar; use super::function::BaseFunc; -use std::io::Result; +use std::io::{Result as IOResult}; use std::path::Path; use tvm_macros::Object; @@ -26,12 +28,72 @@ pub struct IRModuleNode { external! { + // Parser functions #[name("parser.ParseModule")] fn parse_module(file_name: TVMString, source: TVMString) -> IRModule; #[name("parser.ParseExpr")] fn parse_expression(file_name: TVMString, source: TVMString) -> IRModule; + // Module methods + #[name("ir.Module_AddDef")] + fn module_add_def(module: IRModule, type_name: GlobalTypeVar, type_data: TypeData, update: bool) -> (); + #[name("ir.Module_GetGlobalVar")] + fn module_get_global_var(module: IRModule, name: TVMString) -> GlobalVar; + #[name("ir.Module_GetGlobalVars")] + fn module_get_global_vars(module: IRModule) -> Array; + #[name("ir.Module_Lookup")] + fn module_lookup(module: IRModule, var: GlobalVar) -> BaseFunc; + #[name("ir.Module_Lookup_str")] + fn module_lookup_str(module: IRModule, name: TVMString) -> BaseFunc; } +// TVM_REGISTER_GLOBAL("ir.Module_GetGlobalTypeVars") +// .set_body_method(&IRModuleNode::GetGlobalTypeVars); + +// TVM_REGISTER_GLOBAL("ir.Module_ContainGlobalVar") +// .set_body_method(&IRModuleNode::ContainGlobalVar); + +// TVM_REGISTER_GLOBAL("ir.Module_GetGlobalTypeVar") +// .set_body_method(&IRModuleNode::GetGlobalTypeVar); + +// TVM_REGISTER_GLOBAL("ir.Module_LookupDef").set_body_typed([](IRModule mod, GlobalTypeVar var) { +// return mod->LookupTypeDef(var); +// }); + +// TVM_REGISTER_GLOBAL("ir.Module_LookupDef_str").set_body_typed([](IRModule mod, String var) { +// return mod->LookupTypeDef(var); +// }); + +// TVM_REGISTER_GLOBAL("ir.Module_LookupTag").set_body_typed([](IRModule mod, int32_t tag) { +// return mod->LookupTag(tag); +// }); + +// TVM_REGISTER_GLOBAL("ir.Module_FromExpr") +// .set_body_typed([](RelayExpr e, tvm::Map funcs, +// tvm::Map type_defs) { +// return IRModule::FromExpr(e, funcs, type_defs); +// }); + +// TVM_REGISTER_GLOBAL("ir.Module_Update").set_body_typed([](IRModule mod, IRModule from) { +// mod->Update(from); +// }); + +// TVM_REGISTER_GLOBAL("ir.Module_UpdateFunction") +// .set_body_typed([](IRModule mod, GlobalVar gv, BaseFunc func) { mod->Update(gv, func); }); + +// TVM_REGISTER_GLOBAL("ir.Module_Import").set_body_typed([](IRModule mod, String path) { +// mod->Import(path); +// }); + +// TVM_REGISTER_GLOBAL("ir.Module_ImportFromStd").set_body_typed([](IRModule mod, String path) { +// mod->ImportFromStd(path); +// }); + +// TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable) +// .set_dispatch([](const ObjectRef& ref, ReprPrinter* p) { +// auto* node = static_cast(ref.get()); +// p->stream << "IRModuleNode( " << node->functions << ")"; +// }); + impl IRModule { pub fn parse(file_name: N, source: S) -> IRModule where N: Into, S: Into { @@ -39,11 +101,32 @@ impl IRModule { .expect("failed to call parser") } - pub fn parse_file>(file_path: P) -> Result { + pub fn parse_file>(file_path: P) -> IOResult { let file_path = file_path.as_ref(); let file_path_as_str = file_path.to_str().unwrap().to_string(); let source = std::fs::read_to_string(file_path)?; let module = IRModule::parse(file_path_as_str, source); Ok(module) } + + pub fn add_def(&mut self, type_name: GlobalTypeVar, type_data: TypeData, update: bool) -> Result<()> { + module_add_def(self.clone(), type_name, type_data, update) + } + + pub fn get_global_var(&self, name: TVMString) -> Result { + module_get_global_var(self.clone(), name) + } + + pub fn get_global_vars(&self) -> Result> { + module_get_global_vars(self.clone()) + } + + pub fn lookup(&self, var: GlobalVar) -> Result { + module_lookup(self.clone(), var) + } + + pub fn lookup_str(&self, name: S) -> Result + where S: Into { + module_lookup_str(self.clone(), name.into()) + } } From 5f9b5216355066baa242134d790dcb1e9f90579d Mon Sep 17 00:00:00 2001 From: Jared Roesch Date: Thu, 10 Sep 2020 00:10:08 -0700 Subject: [PATCH 05/50] Deploy Rust docs as part of build --- tests/scripts/task_python_docs.sh | 6 ++++++ 1 file changed, 6 insertions(+) diff --git a/tests/scripts/task_python_docs.sh b/tests/scripts/task_python_docs.sh index 29166c627663..71bb92250a00 100755 --- a/tests/scripts/task_python_docs.sh +++ b/tests/scripts/task_python_docs.sh @@ -68,6 +68,11 @@ npm install npm run typedoc cd .. +# Rust doc +cd rust +cargo doc --workspace --no-deps +cd .. + # Prepare the doc dir rm -rf _docs mv docs/_build/html _docs @@ -75,6 +80,7 @@ rm -f _docs/.buildinfo mkdir -p _docs/api mv docs/doxygen/html _docs/api/doxygen mv jvm/core/target/site/apidocs _docs/api/javadoc +mv rust/target/doc _docs/api/rust mv web/dist/docs _docs/api/typedoc echo "Start creating the docs tarball.." From 3343f61676a8f17a1a1b01c27d903f3c19c76a94 Mon Sep 17 00:00:00 2001 From: Max Willsey Date: Mon, 21 Sep 2020 12:03:28 -0700 Subject: [PATCH 06/50] Add some more types --- rust/tvm/src/ir/ty.rs | 33 +++++++++++++++++++++++++++++++-- 1 file changed, 31 insertions(+), 2 deletions(-) diff --git a/rust/tvm/src/ir/ty.rs b/rust/tvm/src/ir/ty.rs index c73f1d586edd..a323d71aede0 100644 --- a/rust/tvm/src/ir/ty.rs +++ b/rust/tvm/src/ir/ty.rs @@ -1,3 +1,32 @@ -use crate::runtime::ObjectRef; +use tvm_macros::Object; +use tvm_rt::{array::Array, DataType}; +use crate::runtime::{ObjectRef, Object}; -pub type Type = ObjectRef; +use super::PrimExpr; + +#[repr(C)] +#[derive(Object)] +#[ref_name = "Type"] +#[type_key = "Type"] +pub struct TypeNode { + pub base: Object, + pub span: ObjectRef, +} + +#[repr(C)] +#[derive(Object)] +#[ref_name = "BaseTensorType"] +#[type_key = "relay.BaseTensorType"] +pub struct BaseTensorTypeNode { + pub base: TypeNode, +} + +#[repr(C)] +#[derive(Object)] +#[ref_name = "TensorType"] +#[type_key = "relay.TensorType"] +pub struct TensorTypeNode { + pub base: TypeNode, + pub shape: Array, + pub dtype: DataType, +} From 772375f0807ca25aa72f216efd80f948b888c10c Mon Sep 17 00:00:00 2001 From: Jared Roesch Date: Fri, 11 Sep 2020 02:55:01 -0700 Subject: [PATCH 07/50] Introduce NDArray 2.0 --- rust/tvm-rt/src/lib.rs | 1 - rust/tvm-rt/src/ndarray.rs | 209 ++++++++--------- rust/tvm-rt/src/ndarray2.rs | 440 ------------------------------------ 3 files changed, 100 insertions(+), 550 deletions(-) delete mode 100644 rust/tvm-rt/src/ndarray2.rs diff --git a/rust/tvm-rt/src/lib.rs b/rust/tvm-rt/src/lib.rs index e32877a85d98..84951f4c8e67 100644 --- a/rust/tvm-rt/src/lib.rs +++ b/rust/tvm-rt/src/lib.rs @@ -98,7 +98,6 @@ pub mod function; pub mod map; pub mod module; pub mod ndarray; -pub mod ndarray2; mod to_function; pub mod value; diff --git a/rust/tvm-rt/src/ndarray.rs b/rust/tvm-rt/src/ndarray.rs index 24fa5e0dfcbc..4836490dcb5c 100644 --- a/rust/tvm-rt/src/ndarray.rs +++ b/rust/tvm-rt/src/ndarray.rs @@ -51,54 +51,52 @@ use std::convert::TryInto; use std::ffi::c_void; use std::{convert::TryFrom, mem, os::raw::c_int, ptr, slice, str::FromStr}; -use crate::errors::NDArrayError; - use tvm_sys::ffi::DLTensor; use tvm_sys::{ffi, ByteArray, Context, DataType}; +use tvm_macros::Object; use ndarray::{Array, ArrayD}; use num_traits::Num; +use crate::errors::NDArrayError; + +use crate::object::{Object, ObjectPtr}; + /// See the [`module-level documentation`](../ndarray/index.html) for more details. -/// -/// Wrapper around TVM array handle. -#[derive(Debug)] -pub enum NDArray { - Borrowed { handle: ffi::TVMArrayHandle }, - Owned { handle: *mut c_void }, +#[repr(C)] +#[derive(Object)] +#[ref_name = "NDArray"] +#[type_key = "runtime.NDArray"] +pub struct NDArrayContainer { + base: Object, + dl_tensor: *mut DLTensor, + manager_ctx: *mut c_void, } -impl NDArray { - pub(crate) fn new(handle: ffi::TVMArrayHandle) -> Self { - NDArray::Borrowed { handle } - } - pub(crate) fn from_ndarray_handle(handle: *mut c_void) -> Self { - NDArray::Owned { handle } +impl NDArray { + pub(crate) fn from_raw(handle: ffi::TVMArrayHandle) -> Self { + let object: *mut Object = unsafe { std::mem::transmute(handle) }; + let object_ptr = ObjectPtr::from_raw(object); + let ptr = object_ptr + .map(|ptr| + ptr.downcast::() + .expect("we know this is an NDArray container")); + NDArray(ptr) } pub fn as_dltensor(&self) -> &DLTensor { - let ptr: *mut DLTensor = match self { - NDArray::Borrowed { ref handle } => *handle, - NDArray::Owned { ref handle } => *handle as *mut DLTensor, - }; - - unsafe { std::mem::transmute(ptr) } + unsafe { + std::mem::transmute(self.0.as_ref().unwrap().dl_tensor) + } } pub(crate) fn as_raw_dltensor(&self) -> *mut DLTensor { - match self { - NDArray::Borrowed { handle } => *handle, - NDArray::Owned { handle } => *handle as *mut DLTensor, - } + self.0.as_ref().unwrap().dl_tensor } pub fn is_view(&self) -> bool { - if let &NDArray::Borrowed { .. } = self { - true - } else { - false - } + false } /// Returns the shape of the NDArray. @@ -285,19 +283,20 @@ impl NDArray { /// Allocates and creates an empty NDArray given the shape, context and dtype. pub fn empty(shape: &[usize], ctx: Context, dtype: DataType) -> NDArray { - let mut handle = ptr::null_mut() as ffi::TVMArrayHandle; - let dtype: tvm_sys::ffi::DLDataType = dtype.into(); - check_call!(ffi::TVMArrayAlloc( - shape.as_ptr() as *const i64, - shape.len() as c_int, - i32::from(dtype.code) as c_int, - i32::from(dtype.bits) as c_int, - i32::from(dtype.lanes) as c_int, - ctx.device_type as c_int, - ctx.device_id as c_int, - &mut handle as *mut _, - )); - NDArray::Borrowed { handle: handle } + // let mut handle = ptr::null_mut() as ffi::TVMArrayHandle; + // let dtype: tvm_sys::ffi::DLDataType = dtype.into(); + // check_call!(ffi::TVMArrayAlloc( + // shape.as_ptr() as *const i64, + // shape.len() as c_int, + // i32::from(dtype.code) as c_int, + // i32::from(dtype.bits) as c_int, + // i32::from(dtype.lanes) as c_int, + // ctx.device_type as c_int, + // ctx.device_id as c_int, + // &mut handle as *mut _, + // )); + // NDArray::Borrowed { handle: handle } + panic!() } } @@ -339,14 +338,6 @@ impl_from_ndarray_rustndarray!(i32, "int"); impl_from_ndarray_rustndarray!(u32, "uint"); impl_from_ndarray_rustndarray!(f32, "float"); -impl Drop for NDArray { - fn drop(&mut self) { - if let &mut NDArray::Owned { .. } = self { - check_call!(ffi::TVMArrayFree(self.as_raw_dltensor())); - } - } -} - mod sealed { /// Private trait to prevent other traits from being implemeneted in downstream crates. pub trait Sealed {} @@ -370,69 +361,69 @@ impl_num32!(i32, u32, f32); #[cfg(test)] mod tests { - use super::*; - - #[test] - fn basics() { - let shape = &mut [1, 2, 3]; - let ctx = Context::cpu(0); - let ndarray = NDArray::empty(shape, ctx, DataType::from_str("int32").unwrap()); - assert_eq!(ndarray.shape().unwrap(), shape); - assert_eq!( - ndarray.size().unwrap(), - shape.to_vec().into_iter().product() - ); - assert_eq!(ndarray.ndim(), 3); - assert!(ndarray.strides().is_none()); - assert_eq!(ndarray.byte_offset(), 0); - } + // use super::*; - #[test] - fn copy() { - let shape = &mut [4]; - let mut data = vec![1i32, 2, 3, 4]; - let ctx = Context::cpu(0); - let mut ndarray = NDArray::empty(shape, ctx, DataType::from_str("int32").unwrap()); - assert!(ndarray.to_vec::().is_ok()); - ndarray.copy_from_buffer(&mut data); - assert_eq!(ndarray.shape().unwrap(), shape); - assert_eq!(ndarray.to_vec::().unwrap(), data); - assert_eq!(ndarray.ndim(), 1); - assert!(ndarray.is_contiguous().is_ok()); - assert_eq!(ndarray.byte_offset(), 0); - let shape = vec![4]; - let e = NDArray::empty( - &shape, - Context::cpu(0), - DataType::from_str("int32").unwrap(), - ); - let nd = ndarray.copy_to_ndarray(e); - assert!(nd.is_ok()); - assert_eq!(nd.unwrap().to_vec::().unwrap(), data); - } + // #[test] + // fn basics() { + // let shape = &mut [1, 2, 3]; + // let ctx = Context::cpu(0); + // let ndarray = NDArray::empty(shape, ctx, DataType::from_str("int32").unwrap()); + // assert_eq!(ndarray.shape().unwrap(), shape); + // assert_eq!( + // ndarray.size().unwrap(), + // shape.to_vec().into_iter().product() + // ); + // assert_eq!(ndarray.ndim(), 3); + // assert!(ndarray.strides().is_none()); + // assert_eq!(ndarray.byte_offset(), 0); + // } // #[test] - // #[should_panic(expected = "called `Result::unwrap()` on an `Err`")] - // fn copy_wrong_dtype() { - // let shape = vec![4]; - // let mut data = vec![1f32, 2., 3., 4.]; + // fn copy() { + // let shape = &mut [4]; + // let mut data = vec![1i32, 2, 3, 4]; // let ctx = Context::cpu(0); - // let mut nd_float = NDArray::empty(&shape, ctx, DataType::from_str("float32").unwrap()); - // nd_float.copy_from_buffer(&mut data); - // let empty_int = NDArray::empty(&shape, ctx, DataType::from_str("int32").unwrap()); - // nd_float.copy_to_ndarray(empty_int).unwrap(); + // let mut ndarray = NDArray::empty(shape, ctx, DataType::from_str("int32").unwrap()); + // assert!(ndarray.to_vec::().is_ok()); + // ndarray.copy_from_buffer(&mut data); + // assert_eq!(ndarray.shape().unwrap(), shape); + // assert_eq!(ndarray.to_vec::().unwrap(), data); + // assert_eq!(ndarray.ndim(), 1); + // assert!(ndarray.is_contiguous().is_ok()); + // assert_eq!(ndarray.byte_offset(), 0); + // let shape = vec![4]; + // let e = NDArray::empty( + // &shape, + // Context::cpu(0), + // DataType::from_str("int32").unwrap(), + // ); + // let nd = ndarray.copy_to_ndarray(e); + // assert!(nd.is_ok()); + // assert_eq!(nd.unwrap().to_vec::().unwrap(), data); // } - #[test] - fn rust_ndarray() { - let a = Array::from_shape_vec((2, 2), vec![1f32, 2., 3., 4.]) - .unwrap() - .into_dyn(); - let nd = - NDArray::from_rust_ndarray(&a, Context::cpu(0), DataType::from_str("float32").unwrap()) - .unwrap(); - assert_eq!(nd.shape().unwrap(), &mut [2, 2]); - let rnd: ArrayD = ArrayD::try_from(&nd).unwrap(); - assert!(rnd.all_close(&a, 1e-8f32)); - } + // // #[test] + // // #[should_panic(expected = "called `Result::unwrap()` on an `Err`")] + // // fn copy_wrong_dtype() { + // // let shape = vec![4]; + // // let mut data = vec![1f32, 2., 3., 4.]; + // // let ctx = Context::cpu(0); + // // let mut nd_float = NDArray::empty(&shape, ctx, DataType::from_str("float32").unwrap()); + // // nd_float.copy_from_buffer(&mut data); + // // let empty_int = NDArray::empty(&shape, ctx, DataType::from_str("int32").unwrap()); + // // nd_float.copy_to_ndarray(empty_int).unwrap(); + // // } + + // #[test] + // fn rust_ndarray() { + // let a = Array::from_shape_vec((2, 2), vec![1f32, 2., 3., 4.]) + // .unwrap() + // .into_dyn(); + // let nd = + // NDArray::from_rust_ndarray(&a, Context::cpu(0), DataType::from_str("float32").unwrap()) + // .unwrap(); + // assert_eq!(nd.shape().unwrap(), &mut [2, 2]); + // let rnd: ArrayD = ArrayD::try_from(&nd).unwrap(); + // assert!(rnd.all_close(&a, 1e-8f32)); + // } } diff --git a/rust/tvm-rt/src/ndarray2.rs b/rust/tvm-rt/src/ndarray2.rs deleted file mode 100644 index c7dfb2c79491..000000000000 --- a/rust/tvm-rt/src/ndarray2.rs +++ /dev/null @@ -1,440 +0,0 @@ -// /* -// * Licensed to the Apache Software Foundation (ASF) under one -// * or more contributor license agreements. See the NOTICE file -// * distributed with this work for additional information -// * regarding copyright ownership. The ASF licenses this file -// * to you under the Apache License, Version 2.0 (the -// * "License"); you may not use this file except in compliance -// * with the License. You may obtain a copy of the License at -// * -// * http://www.apache.org/licenses/LICENSE-2.0 -// * -// * Unless required by applicable law or agreed to in writing, -// * software distributed under the License is distributed on an -// * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -// * KIND, either express or implied. See the License for the -// * specific language governing permissions and limitations -// * under the License. -// */ - -// //! This module implements the [`NDArray`] type for working with *TVM tensors* or -// //! coverting from a Rust's ndarray to TVM `NDArray`. -// //! -// //! One can create an empty NDArray given the shape, context and dtype using [`empty`]. -// //! To create an NDArray from a mutable buffer in cpu use [`copy_from_buffer`]. -// //! To copy an NDArray to different context use [`copy_to_ctx`]. -// //! -// //! Given a [`Rust's dynamic ndarray`], one can convert it to TVM NDArray as follows: -// //! -// //! # Example -// //! -// //! ``` -// //! # use tvm_rt::{NDArray, Context, DataType}; -// //! # use ndarray::{Array, ArrayD}; -// //! # use std::str::FromStr; -// //! use std::convert::TryFrom; -// //! -// //! let a = Array::from_shape_vec((2, 2), vec![1f32, 2., 3., 4.]) -// //! .unwrap() -// //! .into_dyn(); // Rust's ndarray -// //! let nd = NDArray::from_rust_ndarray(&a, Context::cpu(0), DataType::from_str("float32").unwrap()).unwrap(); -// //! assert_eq!(nd.shape(), Some(&mut [2, 2][..])); -// //! let rnd: ArrayD = ArrayD::try_from(&nd).unwrap(); -// //! assert!(rnd.all_close(&a, 1e-8f32)); -// //! ``` -// //! -// //! [`Rust's dynamic ndarray`]:https://docs.rs/ndarray/0.12.1/ndarray/ -// //! [`copy_from_buffer`]:struct.NDArray.html#method.copy_from_buffer -// //! [`copy_to_ctx`]:struct.NDArray.html#method.copy_to_ctx - -// use std::convert::TryInto; -// use std::ffi::c_void; -// use std::{convert::TryFrom, mem, os::raw::c_int, ptr, slice, str::FromStr}; - -// use tvm_sys::ffi::DLTensor; -// use tvm_sys::{ffi, ByteArray, Context, DataType}; -// use tvm_macros::Object; - -// use ndarray::{Array, ArrayD}; -// use num_traits::Num; - -// use crate::object::{Object, ObjectPtr} - -// /// See the [`module-level documentation`](../ndarray/index.html) for more details. -// #[repr(C)] -// #[derive(Object)] -// #[ref_name = "NDArray"] -// #[type_key = "runtime.NDArray"] -// pub struct NDArrayContainer { -// base: Object, -// dl_tensor: *mut DLTensor, -// manager_ctx: *mut c_void, -// } - - -// impl NDArray { -// pub(crate) fn from_raw(handle: ffi::TVMArrayHandle) -> Self { -// let object: *mut NDArrayContainer = unsafe { std::mem::transmute(handle) }; -// let object_ptr = ObjectPtr::from_raw(object); -// NDArray(Some(object_ptr)) -// } - -// pub fn as_dltensor(&self) -> &DLTensor { -// let ptr: *mut DLTensor = match self { -// NDArray::Borrowed { ref handle } => *handle, -// NDArray::Owned { ref handle } => *handle as *mut DLTensor, -// }; - -// unsafe { std::mem::transmute(ptr) } -// } - -// pub(crate) fn as_raw_dltensor(&self) -> *mut DLTensor { -// match self { -// NDArray::Borrowed { handle } => *handle, -// NDArray::Owned { handle } => *handle as *mut DLTensor, -// } -// } - -// pub fn is_view(&self) -> bool { -// if let &NDArray::Borrowed { .. } = self { -// true -// } else { -// false -// } -// } - -// /// Returns the shape of the NDArray. -// pub fn shape(&self) -> Option<&mut [usize]> { -// let arr = self.as_dltensor(); -// if arr.shape.is_null() || arr.data.is_null() { -// return None; -// }; -// let slc = unsafe { slice::from_raw_parts_mut(arr.shape as *mut usize, arr.ndim as usize) }; -// Some(slc) -// } - -// /// Returns the total number of entries of the NDArray. -// pub fn size(&self) -> Option { -// self.shape().map(|v| v.iter().product()) -// } - -// /// Returns the context which the NDArray was defined. -// pub fn ctx(&self) -> Context { -// self.as_dltensor().ctx.into() -// } - -// /// Returns the type of the entries of the NDArray. -// pub fn dtype(&self) -> DataType { -// self.as_dltensor().dtype.into() -// } - -// /// Returns the number of dimensions of the NDArray. -// pub fn ndim(&self) -> usize { -// self.as_dltensor() -// .ndim -// .try_into() -// .expect("number of dimensions must always be positive") -// } - -// /// Returns the strides of the underlying NDArray. -// pub fn strides(&self) -> Option<&[usize]> { -// unsafe { -// let sz = self.ndim() * mem::size_of::(); -// let strides_ptr = self.as_dltensor().strides as *const usize; -// let slc = slice::from_raw_parts(strides_ptr, sz); -// Some(slc) -// } -// } - -// /// Shows whether the underlying ndarray is contiguous in memory or not. -// pub fn is_contiguous(&self) -> Result { -// Ok(match self.strides() { -// None => true, -// Some(strides) => { -// // NDArrayError::MissingShape in case shape is not determined -// self.shape() -// .ok_or(NDArrayError::MissingShape)? -// .iter() -// .zip(strides) -// .rfold( -// (true, 1), -// |(is_contig, expected_stride), (shape, stride)| { -// ( -// is_contig && *stride == expected_stride, -// expected_stride * (*shape as usize), -// ) -// }, -// ) -// .0 -// } -// }) -// } - -// pub fn byte_offset(&self) -> isize { -// self.as_dltensor().byte_offset as isize -// } - -// /// Flattens the NDArray to a `Vec` of the same type in cpu. -// /// -// /// ## Example -// /// -// /// ``` -// /// # use tvm_rt::{Context, DataType, NDArray}; -// /// # use std::str::FromStr; -// /// let mut shape = [4]; -// /// let mut data = vec![1i32, 2, 3, 4]; -// /// let ctx = Context::cpu(0); -// /// let mut ndarray = NDArray::empty(&mut shape, ctx, DataType::from_str("int32").unwrap()); -// /// ndarray.copy_from_buffer(&mut data); -// /// assert_eq!(ndarray.shape(), Some(&mut shape[..])); -// /// assert_eq!(ndarray.to_vec::().unwrap(), data); -// /// ``` -// pub fn to_vec(&self) -> Result, NDArrayError> { -// if !self.shape().is_some() { -// return Err(NDArrayError::EmptyArray); -// } -// let earr = NDArray::empty( -// self.shape().ok_or(NDArrayError::MissingShape)?, -// Context::cpu(0), -// self.dtype(), -// ); -// let target = self.copy_to_ndarray(earr)?; -// let arr = target.as_dltensor(); -// let sz = self.size().ok_or(NDArrayError::MissingShape)?; -// let mut v: Vec = Vec::with_capacity(sz * mem::size_of::()); -// unsafe { -// v.as_mut_ptr() -// .copy_from_nonoverlapping(arr.data as *const T, sz); -// v.set_len(sz); -// } -// Ok(v) -// } - -// /// Converts the NDArray to [`ByteArray`]. -// pub fn to_bytearray(&self) -> Result { -// let v = self.to_vec::()?; -// Ok(ByteArray::from(v)) -// } - -// /// Creates an NDArray from a mutable buffer of types i32, u32 or f32 in cpu. -// /// -// /// ## Example -// /// -// /// ``` -// /// # use tvm_rt::{Context, DataType, NDArray}; -// /// # use std::str::FromStr; -// /// let shape = &mut [2]; -// /// let mut data = vec![1f32, 2.0]; -// /// let ctx = Context::cpu(0); -// /// let mut ndarray = NDArray::empty(shape, ctx, DataType::from_str("int32").unwrap()); -// /// ndarray.copy_from_buffer(&mut data); -// /// ``` -// /// -// /// *Note*: if something goes wrong during the copy, it will panic -// /// from TVM side. See `TVMArrayCopyFromBytes` in `include/tvm/runtime/c_runtime_api.h`. -// pub fn copy_from_buffer(&mut self, data: &mut [T]) { -// check_call!(ffi::TVMArrayCopyFromBytes( -// self.as_raw_dltensor(), -// data.as_ptr() as *mut _, -// data.len() * mem::size_of::() -// )); -// } - -// /// Copies the NDArray to another target NDArray. -// pub fn copy_to_ndarray(&self, target: NDArray) -> Result { -// if self.dtype() != target.dtype() { -// return Err(NDArrayError::DataTypeMismatch { -// expected: self.dtype(), -// actual: target.dtype(), -// }); -// } - -// check_call!(ffi::TVMArrayCopyFromTo( -// self.as_raw_dltensor(), -// target.as_raw_dltensor(), -// ptr::null_mut() as ffi::TVMStreamHandle -// )); - -// Ok(target) -// } - -// /// Copies the NDArray to a target context. -// pub fn copy_to_ctx(&self, target: &Context) -> Result { -// let tmp = NDArray::empty( -// self.shape().ok_or(NDArrayError::MissingShape)?, -// *target, -// self.dtype(), -// ); -// let copy = self.copy_to_ndarray(tmp)?; -// Ok(copy) -// } - -// /// Converts a Rust's ndarray to TVM NDArray. -// pub fn from_rust_ndarray( -// rnd: &ArrayD, -// ctx: Context, -// dtype: DataType, -// ) -> Result { -// let shape = rnd.shape().to_vec(); -// let mut nd = NDArray::empty(&shape, ctx, dtype); -// let mut buf = Array::from_iter(rnd.into_iter().map(|&v| v as T)); -// nd.copy_from_buffer( -// buf.as_slice_mut() -// .expect("Array from iter must be contiguous."), -// ); -// Ok(nd) -// } - -// /// Allocates and creates an empty NDArray given the shape, context and dtype. -// pub fn empty(shape: &[usize], ctx: Context, dtype: DataType) -> NDArray { -// let mut handle = ptr::null_mut() as ffi::TVMArrayHandle; -// let dtype: tvm_sys::ffi::DLDataType = dtype.into(); -// check_call!(ffi::TVMArrayAlloc( -// shape.as_ptr() as *const i64, -// shape.len() as c_int, -// i32::from(dtype.code) as c_int, -// i32::from(dtype.bits) as c_int, -// i32::from(dtype.lanes) as c_int, -// ctx.device_type as c_int, -// ctx.device_id as c_int, -// &mut handle as *mut _, -// )); -// NDArray::Borrowed { handle: handle } -// } -// } - -// macro_rules! impl_from_ndarray_rustndarray { -// ($type:ty, $type_name:tt) => { -// impl<'a> TryFrom<&'a NDArray> for ArrayD<$type> { -// type Error = NDArrayError; - -// fn try_from(nd: &NDArray) -> Result, Self::Error> { -// if !nd.shape().is_some() { -// return Err(NDArrayError::MissingShape); -// } -// assert_eq!(nd.dtype(), DataType::from_str($type_name)?, "Type mismatch"); -// Ok(Array::from_shape_vec( -// &*nd.shape().ok_or(NDArrayError::MissingShape)?, -// nd.to_vec::<$type>()?, -// )?) -// } -// } - -// impl<'a> TryFrom<&'a mut NDArray> for ArrayD<$type> { -// type Error = NDArrayError; - -// fn try_from(nd: &mut NDArray) -> Result, Self::Error> { -// if !nd.shape().is_some() { -// return Err(NDArrayError::MissingShape); -// }; -// assert_eq!(nd.dtype(), DataType::from_str($type_name)?, "Type mismatch"); -// Ok(Array::from_shape_vec( -// &*nd.shape().ok_or(NDArrayError::MissingShape)?, -// nd.to_vec::<$type>()?, -// )?) -// } -// } -// }; -// } - -// impl_from_ndarray_rustndarray!(i32, "int"); -// impl_from_ndarray_rustndarray!(u32, "uint"); -// impl_from_ndarray_rustndarray!(f32, "float"); - -// impl Drop for NDArray { -// fn drop(&mut self) { -// if let &mut NDArray::Owned { .. } = self { -// check_call!(ffi::TVMArrayFree(self.as_raw_dltensor())); -// } -// } -// } - -// mod sealed { -// /// Private trait to prevent other traits from being implemeneted in downstream crates. -// pub trait Sealed {} -// } - -// /// A trait for the supported 32-bits numerical types in frontend. -// pub trait Num32: Num + sealed::Sealed { -// const BITS: u8 = 32; -// } - -// macro_rules! impl_num32 { -// ($($type:ty),+) => { -// $( -// impl sealed::Sealed for $type {} -// impl Num32 for $type {} -// )+ -// }; -// } - -// impl_num32!(i32, u32, f32); - -// #[cfg(test)] -// mod tests { -// // use super::*; - -// // #[test] -// // fn basics() { -// // let shape = &mut [1, 2, 3]; -// // let ctx = Context::cpu(0); -// // let ndarray = NDArray::empty(shape, ctx, DataType::from_str("int32").unwrap()); -// // assert_eq!(ndarray.shape().unwrap(), shape); -// // assert_eq!( -// // ndarray.size().unwrap(), -// // shape.to_vec().into_iter().product() -// // ); -// // assert_eq!(ndarray.ndim(), 3); -// // assert!(ndarray.strides().is_none()); -// // assert_eq!(ndarray.byte_offset(), 0); -// // } - -// // #[test] -// // fn copy() { -// // let shape = &mut [4]; -// // let mut data = vec![1i32, 2, 3, 4]; -// // let ctx = Context::cpu(0); -// // let mut ndarray = NDArray::empty(shape, ctx, DataType::from_str("int32").unwrap()); -// // assert!(ndarray.to_vec::().is_ok()); -// // ndarray.copy_from_buffer(&mut data); -// // assert_eq!(ndarray.shape().unwrap(), shape); -// // assert_eq!(ndarray.to_vec::().unwrap(), data); -// // assert_eq!(ndarray.ndim(), 1); -// // assert!(ndarray.is_contiguous().is_ok()); -// // assert_eq!(ndarray.byte_offset(), 0); -// // let shape = vec![4]; -// // let e = NDArray::empty( -// // &shape, -// // Context::cpu(0), -// // DataType::from_str("int32").unwrap(), -// // ); -// // let nd = ndarray.copy_to_ndarray(e); -// // assert!(nd.is_ok()); -// // assert_eq!(nd.unwrap().to_vec::().unwrap(), data); -// // } - -// // // #[test] -// // // #[should_panic(expected = "called `Result::unwrap()` on an `Err`")] -// // // fn copy_wrong_dtype() { -// // // let shape = vec![4]; -// // // let mut data = vec![1f32, 2., 3., 4.]; -// // // let ctx = Context::cpu(0); -// // // let mut nd_float = NDArray::empty(&shape, ctx, DataType::from_str("float32").unwrap()); -// // // nd_float.copy_from_buffer(&mut data); -// // // let empty_int = NDArray::empty(&shape, ctx, DataType::from_str("int32").unwrap()); -// // // nd_float.copy_to_ndarray(empty_int).unwrap(); -// // // } - -// // #[test] -// // fn rust_ndarray() { -// // let a = Array::from_shape_vec((2, 2), vec![1f32, 2., 3., 4.]) -// // .unwrap() -// // .into_dyn(); -// // let nd = -// // NDArray::from_rust_ndarray(&a, Context::cpu(0), DataType::from_str("float32").unwrap()) -// // .unwrap(); -// // assert_eq!(nd.shape().unwrap(), &mut [2, 2]); -// // let rnd: ArrayD = ArrayD::try_from(&nd).unwrap(); -// // assert!(rnd.all_close(&a, 1e-8f32)); -// // } -// } From 2b982c1d88e279f2d5143c3c36a83c3a9233c1e7 Mon Sep 17 00:00:00 2001 From: Jared Roesch Date: Fri, 11 Sep 2020 03:21:51 -0700 Subject: [PATCH 08/50] Work on NDArray 2.0 before restoring tests --- rust/tvm-rt/src/object/object_ptr.rs | 17 +++++++-- rust/tvm-rt/src/value.rs | 56 +--------------------------- 2 files changed, 15 insertions(+), 58 deletions(-) diff --git a/rust/tvm-rt/src/object/object_ptr.rs b/rust/tvm-rt/src/object/object_ptr.rs index 1388d3c96d02..792b14917fec 100644 --- a/rust/tvm-rt/src/object/object_ptr.rs +++ b/rust/tvm-rt/src/object/object_ptr.rs @@ -291,9 +291,14 @@ impl<'a, T: IsObject> TryFrom for ObjectPtr { impl<'a, T: IsObject> From> for ArgValue<'a> { fn from(object_ptr: ObjectPtr) -> ArgValue<'a> { debug_assert!(object_ptr.count() >= 1); - let raw_ptr = ObjectPtr::leak(object_ptr) as *mut T as *mut std::ffi::c_void; + let object_ptr = object_ptr.upcast::(); + let index = object_ptr.type_index; + let raw_ptr = ObjectPtr::leak(object_ptr) as *mut Object as *mut std::ffi::c_void; assert!(!raw_ptr.is_null()); - ArgValue::ObjectHandle(raw_ptr) + match index { + tvm_sys::ffi::TVMArgTypeCode_kTVMNDArrayHandle => ArgValue::NDArrayHandle(raw_ptr), + _ => ArgValue::ObjectHandle(raw_ptr) + } } } @@ -307,7 +312,13 @@ impl<'a, T: IsObject> TryFrom> for ObjectPtr { debug_assert!(optr.count() >= 1); // println!("count: {}", optr.count()); optr.downcast() - } + }, + ArgValue::NDArrayHandle(handle) => { + let optr = ObjectPtr::from_raw(handle as *mut Object).ok_or(Error::Null)?; + debug_assert!(optr.count() >= 1); + // println!("count: {}", optr.count()); + optr.downcast() + }, _ => Err(Error::downcast(format!("{:?}", arg_value), "ObjectHandle")), } } diff --git a/rust/tvm-rt/src/value.rs b/rust/tvm-rt/src/value.rs index 1812c0cfbe45..c49944dc7e33 100644 --- a/rust/tvm-rt/src/value.rs +++ b/rust/tvm-rt/src/value.rs @@ -24,7 +24,7 @@ use std::convert::TryFrom; // use std::ffi::c_void; -use crate::{ArgValue, Module, NDArray, RetValue}; +use crate::{ArgValue, Module, RetValue}; use tvm_sys::{errors::ValueDowncastError, ffi::TVMModuleHandle, try_downcast}; macro_rules! impl_handle_val { @@ -72,60 +72,6 @@ macro_rules! impl_handle_val { impl_handle_val!(Module, ModuleHandle, TVMModuleHandle, Module::new); -impl<'a> From<&'a NDArray> for ArgValue<'a> { - fn from(arg: &'a NDArray) -> Self { - match arg { - &NDArray::Borrowed { handle } => ArgValue::ArrayHandle(handle), - &NDArray::Owned { handle } => ArgValue::NDArrayHandle(handle), - } - } -} - -impl<'a> From<&'a mut NDArray> for ArgValue<'a> { - fn from(arg: &'a mut NDArray) -> Self { - match arg { - &mut NDArray::Borrowed { handle } => ArgValue::ArrayHandle(handle), - &mut NDArray::Owned { handle } => ArgValue::NDArrayHandle(handle), - } - } -} - -impl<'a> TryFrom> for NDArray { - type Error = ValueDowncastError; - fn try_from(val: ArgValue<'a>) -> Result { - try_downcast!(val -> NDArray, - |ArgValue::NDArrayHandle(val)| { NDArray::from_ndarray_handle(val) }, - |ArgValue::ArrayHandle(val)| { NDArray::new(val) }) - } -} - -impl<'a, 'v> TryFrom<&'a ArgValue<'v>> for NDArray { - type Error = ValueDowncastError; - fn try_from(val: &'a ArgValue<'v>) -> Result { - try_downcast!(val -> NDArray, - |ArgValue::NDArrayHandle(val)| { NDArray::from_ndarray_handle(*val) }, - |ArgValue::ArrayHandle(val)| { NDArray::new(*val) }) - } -} - -impl From for RetValue { - fn from(val: NDArray) -> RetValue { - match val { - NDArray::Owned { handle } => RetValue::NDArrayHandle(handle), - _ => panic!("NYI"), - } - } -} - -impl TryFrom for NDArray { - type Error = ValueDowncastError; - fn try_from(val: RetValue) -> Result { - try_downcast!(val -> NDArray, - |RetValue::NDArrayHandle(val)| { NDArray::from_ndarray_handle(val) }, - |RetValue::ArrayHandle(val)| { NDArray::new(val) }) - } -} - #[cfg(test)] mod tests { use std::{convert::TryInto, str::FromStr}; From dbe32f1f341c7b1c5f171b402f07c9a3cb0ec56b Mon Sep 17 00:00:00 2001 From: Gus Smith Date: Fri, 18 Sep 2020 10:35:00 -0700 Subject: [PATCH 09/50] Formatting and code fixes to get it to compile --- rust/tvm-graph-rt/src/threading.rs | 4 ++-- rust/tvm-rt/src/array.rs | 10 +++------- rust/tvm/src/ir/expr.rs | 2 +- rust/tvm/src/ir/function.rs | 2 +- rust/tvm/src/ir/module.rs | 28 ++++++++++++++++++---------- rust/tvm/src/ir/relay/mod.rs | 26 +++++++++++++------------- rust/tvm/src/ir/tir.rs | 2 +- rust/tvm/src/python.rs | 6 ++++-- 8 files changed, 43 insertions(+), 37 deletions(-) diff --git a/rust/tvm-graph-rt/src/threading.rs b/rust/tvm-graph-rt/src/threading.rs index cbb3bf14c31c..03765e0a049b 100644 --- a/rust/tvm-graph-rt/src/threading.rs +++ b/rust/tvm-graph-rt/src/threading.rs @@ -215,7 +215,7 @@ pub unsafe extern "C" fn TVMBackendParallelBarrier( #[cfg(test)] mod tests { - use std::{ptr, thread, time::Duration}; + use std::{thread, time::Duration}; use super::*; @@ -228,7 +228,7 @@ mod tests { assert_eq!(max_concurrency(), 24); } - extern "C" fn flambda( + extern "C" fn _flambda( task_id: usize, penv: *const TVMParallelGroupEnv, cdata: *const c_void, diff --git a/rust/tvm-rt/src/array.rs b/rust/tvm-rt/src/array.rs index 213d7ee8a9c2..5e19cefd8e97 100644 --- a/rust/tvm-rt/src/array.rs +++ b/rust/tvm-rt/src/array.rs @@ -121,20 +121,16 @@ impl<'a, T: IsObjectRef> TryFrom for Array { #[cfg(test)] mod tests { use super::Array; - use crate::string::String; use crate::function::Result; + use crate::string::String; #[test] fn create_array_and_get() -> Result<()> { - let vec: Vec = vec![ - "foo".into(), - "bar".into(), - "baz".into(), - ]; + let vec: Vec = vec!["foo".into(), "bar".into(), "baz".into()]; let array = Array::from_vec(vec)?; assert_eq!(array.get(0)?.to_string(), "foo"); assert_eq!(array.get(1)?.to_string(), "bar"); - assert_eq!(array.get(1)?.to_string(), "baz"); + assert_eq!(array.get(2)?.to_string(), "baz"); Ok(()) } } diff --git a/rust/tvm/src/ir/expr.rs b/rust/tvm/src/ir/expr.rs index 20e92f9f5925..a8a188e39ae2 100644 --- a/rust/tvm/src/ir/expr.rs +++ b/rust/tvm/src/ir/expr.rs @@ -1,7 +1,7 @@ +use super::relay; use crate::runtime::String as TString; use crate::runtime::{self, external, IsObject, IsObjectRef, Object, ObjectPtr, ObjectRef}; use crate::DataType; -use super::relay; use tvm_macros::Object; diff --git a/rust/tvm/src/ir/function.rs b/rust/tvm/src/ir/function.rs index e1294f1311de..e6a1d3d9d620 100644 --- a/rust/tvm/src/ir/function.rs +++ b/rust/tvm/src/ir/function.rs @@ -1,5 +1,5 @@ -use crate::runtime::{IsObjectRef, IsObject, ObjectRef}; use crate::ir::relay::ExprNode; +use crate::runtime::{IsObject, IsObjectRef, ObjectRef}; use tvm_macros::Object; diff --git a/rust/tvm/src/ir/module.rs b/rust/tvm/src/ir/module.rs index eee828f7632a..365680160c33 100644 --- a/rust/tvm/src/ir/module.rs +++ b/rust/tvm/src/ir/module.rs @@ -1,13 +1,13 @@ -use crate::runtime::{external, Object, ObjectRef}; -use crate::runtime::{string::String as TVMString}; -use crate::runtime::function::Result; use crate::runtime::array::Array; +use crate::runtime::function::Result; use crate::runtime::map::Map; +use crate::runtime::string::String as TVMString; +use crate::runtime::{external, Object, ObjectRef}; use super::expr::GlobalVar; use super::function::BaseFunc; -use std::io::{Result as IOResult}; +use std::io::Result as IOResult; use std::path::Path; use tvm_macros::Object; @@ -26,7 +26,6 @@ pub struct IRModuleNode { pub type_definitions: Map, } - external! { // Parser functions #[name("parser.ParseModule")] @@ -96,9 +95,11 @@ external! { impl IRModule { pub fn parse(file_name: N, source: S) -> IRModule - where N: Into, S: Into { - parse_module(file_name.into(), source.into()) - .expect("failed to call parser") + where + N: Into, + S: Into, + { + parse_module(file_name.into(), source.into()).expect("failed to call parser") } pub fn parse_file>(file_path: P) -> IOResult { @@ -109,7 +110,12 @@ impl IRModule { Ok(module) } - pub fn add_def(&mut self, type_name: GlobalTypeVar, type_data: TypeData, update: bool) -> Result<()> { + pub fn add_def( + &mut self, + type_name: GlobalTypeVar, + type_data: TypeData, + update: bool, + ) -> Result<()> { module_add_def(self.clone(), type_name, type_data, update) } @@ -126,7 +132,9 @@ impl IRModule { } pub fn lookup_str(&self, name: S) -> Result - where S: Into { + where + S: Into, + { module_lookup_str(self.clone(), name.into()) } } diff --git a/rust/tvm/src/ir/relay/mod.rs b/rust/tvm/src/ir/relay/mod.rs index e1f0ed483887..df3b3d500c60 100644 --- a/rust/tvm/src/ir/relay/mod.rs +++ b/rust/tvm/src/ir/relay/mod.rs @@ -20,13 +20,13 @@ use crate::runtime::array::Array; use crate::runtime::{object::*, String as TString}; -use super::expr::{BaseExprNode}; +use super::expr::BaseExprNode; use super::function::BaseFuncNode; use super::ty::Type; use tvm_macros::Object; -pub use super::expr::{GlobalVarNode, GlobalVar}; +pub use super::expr::{GlobalVar, GlobalVarNode}; pub type Attrs = ObjectRef; @@ -50,7 +50,6 @@ impl ExprNode { } } - #[repr(C)] #[derive(Object)] #[ref_name = "Id"] @@ -70,7 +69,6 @@ impl Id { } } - #[repr(C)] #[derive(Object)] #[ref_name = "Constant"] @@ -138,7 +136,6 @@ impl Var { } } - #[repr(C)] #[derive(Object)] #[ref_name = "Call"] @@ -187,7 +184,7 @@ impl Let { base: ExprNode::base::(), var, value, - body + body, }; Let(Some(ObjectPtr::new(node))) } @@ -269,7 +266,7 @@ impl RefRead { pub fn new(ref_value: Expr, _span: ObjectRef) -> RefRead { let node = RefReadNode { base: ExprNode::base::(), - ref_value + ref_value, }; RefRead(Some(ObjectPtr::new(node))) } @@ -321,14 +318,13 @@ impl Constructor { // TODO(@jroesch): define the type data - #[repr(C)] #[derive(Object)] #[ref_name = "Pattern"] #[type_key = "relay.Pattern"] pub struct PatternNode { pub base: Object, - pub span: ObjectRef + pub span: ObjectRef, } impl PatternNode { @@ -376,7 +372,6 @@ impl PatternVar { } } - #[repr(C)] #[derive(Object)] #[ref_name = "PatternConstructor"] @@ -388,7 +383,11 @@ pub struct PatternConstructorNode { } impl PatternConstructor { - pub fn new(constructor: Constructor, patterns: Array, _span: ObjectRef) -> PatternConstructor { + pub fn new( + constructor: Constructor, + patterns: Array, + _span: ObjectRef, + ) -> PatternConstructor { let node = PatternConstructorNode { base: PatternNode::base::(), constructor, @@ -431,7 +430,8 @@ impl Clause { pub fn new(lhs: Pattern, rhs: Expr, _span: ObjectRef) -> Clause { let node = ClauseNode { base: Object::base_object::(), - lhs, rhs, + lhs, + rhs, }; Clause(Some(ObjectPtr::new(node))) } @@ -516,7 +516,7 @@ mod tests { #[test] fn test_var() -> Result<()> { - let var = Var::new("local".to_string(), ObjectRef::null()); + let var = Var::new("local".to_string(), Type::null(), ObjectRef::null()); let text = as_text(var.clone()); assert!(text.contains("%local")); Ok(()) diff --git a/rust/tvm/src/ir/tir.rs b/rust/tvm/src/ir/tir.rs index a19f0cbc7869..22d4e02054e1 100644 --- a/rust/tvm/src/ir/tir.rs +++ b/rust/tvm/src/ir/tir.rs @@ -17,9 +17,9 @@ * under the License. */ +use super::{PrimExpr, PrimExprNode}; use crate::runtime::String as TVMString; use crate::DataType; -use super::{PrimExprNode, PrimExpr}; use tvm_macros::Object; diff --git a/rust/tvm/src/python.rs b/rust/tvm/src/python.rs index e5c5784e3f97..87cc6cd2be79 100644 --- a/rust/tvm/src/python.rs +++ b/rust/tvm/src/python.rs @@ -28,12 +28,14 @@ fn load_python_tvm_(py: Python) -> PyResult { #[cfg(test)] mod tests { - use super::load_python_tvm; + use super::load_python_tvm_; use anyhow::Result; + use pyo3::prelude::*; + #[ignore] #[test] fn test_run() -> Result<()> { - load_python_tvm().unwrap(); + load_python_tvm_(Python::acquire_gil().python()).unwrap(); Ok(()) } } From f7c36ca327d439da3d92c83569e01a8259c331d5 Mon Sep 17 00:00:00 2001 From: Gus Smith Date: Mon, 21 Sep 2020 18:34:05 -0700 Subject: [PATCH 10/50] Add more Rust bindings - Converts Conv2d attrs to use tvm::String, so that we can add Rust binding - Uses Type for checked_type in Rust bindings - Fix type key in Rust bindings - Make data field contain NDArray in Rust bindings --- include/tvm/relay/attrs/nn.h | 7 ++--- include/tvm/tir/data_layout.h | 3 +++ rust/tvm/src/ir/attrs.rs | 10 +++++++ rust/tvm/src/ir/mod.rs | 2 ++ rust/tvm/src/ir/op.rs | 24 +++++++++++++++++ rust/tvm/src/ir/relay/attrs/mod.rs | 1 + rust/tvm/src/ir/relay/attrs/nn.rs | 27 +++++++++++++++++++ rust/tvm/src/ir/relay/mod.rs | 21 +++++++++------ src/relay/qnn/op/convolution.cc | 6 ++--- .../transforms/combine_parallel_conv2d.cc | 2 +- src/relay/transforms/pattern_util.h | 2 +- 11 files changed, 89 insertions(+), 16 deletions(-) create mode 100644 rust/tvm/src/ir/attrs.rs create mode 100644 rust/tvm/src/ir/op.rs create mode 100644 rust/tvm/src/ir/relay/attrs/mod.rs create mode 100644 rust/tvm/src/ir/relay/attrs/nn.rs diff --git a/include/tvm/relay/attrs/nn.h b/include/tvm/relay/attrs/nn.h index fbe31a305ea5..60c37aff2e90 100644 --- a/include/tvm/relay/attrs/nn.h +++ b/include/tvm/relay/attrs/nn.h @@ -28,6 +28,7 @@ #include #include +#include "tvm/runtime/container.h" namespace tvm { namespace relay { @@ -115,9 +116,9 @@ struct Conv2DAttrs : public tvm::AttrsNode { int groups; IndexExpr channels; Array kernel_size; - std::string data_layout; - std::string kernel_layout; - std::string out_layout; + tvm::String data_layout; + tvm::String kernel_layout; + tvm::String out_layout; DataType out_dtype; TVM_DECLARE_ATTRS(Conv2DAttrs, "relay.attrs.Conv2DAttrs") { diff --git a/include/tvm/tir/data_layout.h b/include/tvm/tir/data_layout.h index af384f9b67f9..ee93a0675470 100644 --- a/include/tvm/tir/data_layout.h +++ b/include/tvm/tir/data_layout.h @@ -124,6 +124,9 @@ class Layout : public ObjectRef { public: explicit Layout(const Array& axes); + /*! \brief construct from a string */ + Layout(const tvm::String& name) : Layout(name.operator std::string()) {} // NOLINT(*) + /*! \brief construct from a string */ Layout(const char* name) : Layout(std::string(name)) {} // NOLINT(*) diff --git a/rust/tvm/src/ir/attrs.rs b/rust/tvm/src/ir/attrs.rs new file mode 100644 index 000000000000..883ee7d699e1 --- /dev/null +++ b/rust/tvm/src/ir/attrs.rs @@ -0,0 +1,10 @@ +use crate::runtime::Object; +use tvm_macros::Object; + +#[repr(C)] +#[derive(Object)] +#[ref_name = "Attrs"] +#[type_key = "Attrs"] +pub struct BaseAttrsNode { + pub base: Object, +} diff --git a/rust/tvm/src/ir/mod.rs b/rust/tvm/src/ir/mod.rs index 0620a10ab1ad..2379e12df3fb 100644 --- a/rust/tvm/src/ir/mod.rs +++ b/rust/tvm/src/ir/mod.rs @@ -18,9 +18,11 @@ */ pub mod arith; +pub mod attrs; pub mod expr; pub mod function; pub mod module; +pub mod op; pub mod relay; pub mod tir; pub mod ty; diff --git a/rust/tvm/src/ir/op.rs b/rust/tvm/src/ir/op.rs new file mode 100644 index 000000000000..4ab74c4c6625 --- /dev/null +++ b/rust/tvm/src/ir/op.rs @@ -0,0 +1,24 @@ +use crate::ir::relay::ExprNode; +use crate::runtime::array::Array; +use crate::runtime::ObjectRef; +use crate::runtime::String as TString; +use tvm_macros::Object; + +type FuncType = ObjectRef; +type AttrFieldInfo = ObjectRef; + +#[repr(C)] +#[derive(Object)] +#[ref_name = "Op"] +#[type_key = "Op"] +pub struct OpNode { + pub base: ExprNode, + pub name: TString, + pub op_type: FuncType, + pub description: TString, + pub arguments: Array, + pub attrs_type_key: TString, + pub attrs_type_index: u32, + pub num_inputs: i32, + pub support_level: i32, +} diff --git a/rust/tvm/src/ir/relay/attrs/mod.rs b/rust/tvm/src/ir/relay/attrs/mod.rs new file mode 100644 index 000000000000..cb1fa9728ae1 --- /dev/null +++ b/rust/tvm/src/ir/relay/attrs/mod.rs @@ -0,0 +1 @@ +pub mod nn; diff --git a/rust/tvm/src/ir/relay/attrs/nn.rs b/rust/tvm/src/ir/relay/attrs/nn.rs new file mode 100644 index 000000000000..42260a1ec8e3 --- /dev/null +++ b/rust/tvm/src/ir/relay/attrs/nn.rs @@ -0,0 +1,27 @@ +use crate::ir::attrs::BaseAttrsNode; +use crate::ir::PrimExpr; +use crate::runtime::array::Array; +use crate::runtime::DataType; +use crate::runtime::String as TString; +use tvm_macros::Object; + +type IndexExpr = PrimExpr; + +#[repr(C)] +#[derive(Object)] +#[ref_name = "Conv2DAttrs"] +#[type_key = "relay.attrs.Conv2DAttrs"] +pub struct Conv2DAttrsNode { + pub base: BaseAttrsNode, + pub strides: Array, + pub padding: Array, + pub dilation: Array, + // TODO(@gussmith23) groups is "int", what should it be here? + pub groups: i32, + pub channels: IndexExpr, + pub kernel_size: Array, + pub data_layout: TString, + pub kernel_layout: TString, + pub out_layout: TString, + pub out_dtype: DataType, +} diff --git a/rust/tvm/src/ir/relay/mod.rs b/rust/tvm/src/ir/relay/mod.rs index df3b3d500c60..a5decfd1df7d 100644 --- a/rust/tvm/src/ir/relay/mod.rs +++ b/rust/tvm/src/ir/relay/mod.rs @@ -17,27 +17,29 @@ * under the License. */ +pub mod attrs; + use crate::runtime::array::Array; use crate::runtime::{object::*, String as TString}; +use super::attrs::Attrs; use super::expr::BaseExprNode; use super::function::BaseFuncNode; -use super::ty::Type; +use super::ty::{Type, TypeNode}; use tvm_macros::Object; +use tvm_rt::NDArray; pub use super::expr::{GlobalVar, GlobalVarNode}; -pub type Attrs = ObjectRef; - #[repr(C)] #[derive(Object)] #[ref_name = "Expr"] -#[type_key = "relay.Expr"] +#[type_key = "RelayExpr"] pub struct ExprNode { pub base: BaseExprNode, pub span: ObjectRef, - pub checked_type: ObjectRef, + pub checked_type: Type, } impl ExprNode { @@ -45,7 +47,10 @@ impl ExprNode { ExprNode { base: BaseExprNode::base::(), span: ObjectRef::null(), - checked_type: ObjectRef::null(), + checked_type: Type::from(TypeNode { + base: Object::base_object::(), + span: ObjectRef::null(), + }), } } } @@ -75,11 +80,11 @@ impl Id { #[type_key = "relay.Constant"] pub struct ConstantNode { pub base: ExprNode, - pub data: ObjectRef, // make this NDArray. + pub data: NDArray, } impl Constant { - pub fn new(data: ObjectRef, _span: ObjectRef) -> Constant { + pub fn new(data: NDArray, _span: ObjectRef) -> Constant { let node = ConstantNode { base: ExprNode::base::(), data: data, diff --git a/src/relay/qnn/op/convolution.cc b/src/relay/qnn/op/convolution.cc index 847f81f72a04..f112a7259552 100644 --- a/src/relay/qnn/op/convolution.cc +++ b/src/relay/qnn/op/convolution.cc @@ -62,13 +62,13 @@ bool QnnConv2DRel(const Array& types, int num_inputs, const Attrs& attrs, CHECK(IsScalarType(types[4], DataType::Float(32))); // input_scale // Kernel scale can be a vector of length output_channels or a scalar. if (param->groups == 1) { - size_t axis = param->kernel_layout.find('O'); + size_t axis = param->kernel_layout.operator std::string().find('O'); CHECK(axis != std::string::npos) << "Kernel layout attribute is not defined"; AssignType(types[5], DataType::Float(32), weight->shape[axis], reporter); // kernel scale } else { // Here, total number of output channels depend on depth multiplier. - size_t o_axis = param->kernel_layout.find('O'); - size_t i_axis = param->kernel_layout.find('I'); + size_t o_axis = param->kernel_layout.operator std::string().find('O'); + size_t i_axis = param->kernel_layout.operator std::string().find('I'); CHECK(o_axis != std::string::npos || i_axis != std::string::npos) << "Kernel layout attribute is not defined"; AssignType(types[5], DataType::Float(32), weight->shape[i_axis] * weight->shape[o_axis], diff --git a/src/relay/transforms/combine_parallel_conv2d.cc b/src/relay/transforms/combine_parallel_conv2d.cc index a639fcd60af6..68520efe2bbd 100644 --- a/src/relay/transforms/combine_parallel_conv2d.cc +++ b/src/relay/transforms/combine_parallel_conv2d.cc @@ -196,7 +196,7 @@ class ParallelConv2DCombiner : public ParallelOpCombiner { auto channels = GetConv2DSuperChannelsDim(conv2d); num_filters += channels; } - auto index = branches[0][0]->attrs.as()->kernel_layout.find('O'); + auto index = branches[0][0]->attrs.as()->kernel_layout.operator std::string().find('O'); CHECK_NE(index, std::string::npos); return std::make_tuple(MakeConcatenate(Tuple(weights), index), tir::make_const(DataType::Int(32), num_filters)); diff --git a/src/relay/transforms/pattern_util.h b/src/relay/transforms/pattern_util.h index 39fbec584e4a..d3107ad07a20 100644 --- a/src/relay/transforms/pattern_util.h +++ b/src/relay/transforms/pattern_util.h @@ -193,7 +193,7 @@ inline bool IsDepthwiseConv2D(const Call& call, const Conv2DAttrs* param, inline int64_t GetConv2DSuperChannelsDim(const CallNode* call) { auto param = call->attrs.as(); auto tweight = call->args[1]->type_as(); - auto index = param->kernel_layout.find('O'); + auto index = param->kernel_layout.operator std::string().find('O'); CHECK_NE(index, std::string::npos); auto channels = tir::as_const_int(tweight->shape[index]); return *channels; From 5363ff4da77d0aafb23c65d7dea781e695639458 Mon Sep 17 00:00:00 2001 From: Jared Roesch Date: Tue, 22 Sep 2020 14:32:17 -0700 Subject: [PATCH 11/50] Clean up object ptr passing. --- rust/tvm-rt/src/object/object_ptr.rs | 10 ++-------- 1 file changed, 2 insertions(+), 8 deletions(-) diff --git a/rust/tvm-rt/src/object/object_ptr.rs b/rust/tvm-rt/src/object/object_ptr.rs index 792b14917fec..5cb330c29ab4 100644 --- a/rust/tvm-rt/src/object/object_ptr.rs +++ b/rust/tvm-rt/src/object/object_ptr.rs @@ -277,7 +277,7 @@ impl<'a, T: IsObject> TryFrom for ObjectPtr { fn try_from(ret_value: RetValue) -> Result, Self::Error> { match ret_value { - RetValue::ObjectHandle(handle) => { + RetValue::ObjectHandle(handle) | RetValue::NDArrayHandle(handle){ let optr = ObjectPtr::from_raw(handle as *mut Object).ok_or(Error::Null)?; debug_assert!(optr.count() >= 1); // println!("back to type {}", optr.count()); @@ -307,13 +307,7 @@ impl<'a, T: IsObject> TryFrom> for ObjectPtr { fn try_from(arg_value: ArgValue<'a>) -> Result, Self::Error> { match arg_value { - ArgValue::ObjectHandle(handle) => { - let optr = ObjectPtr::from_raw(handle as *mut Object).ok_or(Error::Null)?; - debug_assert!(optr.count() >= 1); - // println!("count: {}", optr.count()); - optr.downcast() - }, - ArgValue::NDArrayHandle(handle) => { + ArgValue::ObjectHandle(handle) | ArgValue::NDArrayHandle(handle) => { let optr = ObjectPtr::from_raw(handle as *mut Object).ok_or(Error::Null)?; debug_assert!(optr.count() >= 1); // println!("count: {}", optr.count()); From 930e29e0dd7fb25a03b816c82c0372fc6a82365c Mon Sep 17 00:00:00 2001 From: Jared Roesch Date: Tue, 22 Sep 2020 15:19:57 -0700 Subject: [PATCH 12/50] WIP --- rust/tvm-rt/src/ndarray.rs | 62 ++++++++++++++-------------- rust/tvm-rt/src/object/object_ptr.rs | 33 +-------------- 2 files changed, 34 insertions(+), 61 deletions(-) diff --git a/rust/tvm-rt/src/ndarray.rs b/rust/tvm-rt/src/ndarray.rs index 4836490dcb5c..196c1d3074de 100644 --- a/rust/tvm-rt/src/ndarray.rs +++ b/rust/tvm-rt/src/ndarray.rs @@ -283,20 +283,22 @@ impl NDArray { /// Allocates and creates an empty NDArray given the shape, context and dtype. pub fn empty(shape: &[usize], ctx: Context, dtype: DataType) -> NDArray { - // let mut handle = ptr::null_mut() as ffi::TVMArrayHandle; - // let dtype: tvm_sys::ffi::DLDataType = dtype.into(); - // check_call!(ffi::TVMArrayAlloc( - // shape.as_ptr() as *const i64, - // shape.len() as c_int, - // i32::from(dtype.code) as c_int, - // i32::from(dtype.bits) as c_int, - // i32::from(dtype.lanes) as c_int, - // ctx.device_type as c_int, - // ctx.device_id as c_int, - // &mut handle as *mut _, - // )); - // NDArray::Borrowed { handle: handle } - panic!() + let mut handle = ptr::null_mut() as ffi::TVMArrayHandle; + let dtype: tvm_sys::ffi::DLDataType = dtype.into(); + check_call!(ffi::TVMArrayAlloc( + shape.as_ptr() as *const i64, + shape.len() as c_int, + i32::from(dtype.code) as c_int, + i32::from(dtype.bits) as c_int, + i32::from(dtype.lanes) as c_int, + ctx.device_type as c_int, + ctx.device_id as c_int, + &mut handle as *mut _, + )); + let ptr = + ObjectPtr::from_raw(handle as *mut Object).map(|o| + o.downcast().expect("this should never fail")); + NDArray(ptr) } } @@ -361,22 +363,22 @@ impl_num32!(i32, u32, f32); #[cfg(test)] mod tests { - // use super::*; - - // #[test] - // fn basics() { - // let shape = &mut [1, 2, 3]; - // let ctx = Context::cpu(0); - // let ndarray = NDArray::empty(shape, ctx, DataType::from_str("int32").unwrap()); - // assert_eq!(ndarray.shape().unwrap(), shape); - // assert_eq!( - // ndarray.size().unwrap(), - // shape.to_vec().into_iter().product() - // ); - // assert_eq!(ndarray.ndim(), 3); - // assert!(ndarray.strides().is_none()); - // assert_eq!(ndarray.byte_offset(), 0); - // } + use super::*; + + #[test] + fn basics() { + let shape = &mut [1, 2, 3]; + let ctx = Context::cpu(0); + let ndarray = NDArray::empty(shape, ctx, DataType::from_str("int32").unwrap()); + assert_eq!(ndarray.shape().unwrap(), shape); + assert_eq!( + ndarray.size().unwrap(), + shape.to_vec().into_iter().product() + ); + assert_eq!(ndarray.ndim(), 3); + assert!(ndarray.strides().is_none()); + assert_eq!(ndarray.byte_offset(), 0); + } // #[test] // fn copy() { diff --git a/rust/tvm-rt/src/object/object_ptr.rs b/rust/tvm-rt/src/object/object_ptr.rs index 5cb330c29ab4..34bea1190139 100644 --- a/rust/tvm-rt/src/object/object_ptr.rs +++ b/rust/tvm-rt/src/object/object_ptr.rs @@ -277,7 +277,7 @@ impl<'a, T: IsObject> TryFrom for ObjectPtr { fn try_from(ret_value: RetValue) -> Result, Self::Error> { match ret_value { - RetValue::ObjectHandle(handle) | RetValue::NDArrayHandle(handle){ + RetValue::ObjectHandle(handle) | RetValue::NDArrayHandle(handle) => { let optr = ObjectPtr::from_raw(handle as *mut Object).ok_or(Error::Null)?; debug_assert!(optr.count() >= 1); // println!("back to type {}", optr.count()); @@ -407,37 +407,8 @@ mod tests { return o; } - // #[test] - // fn test_ref_count_boundary() { - // use super::*; - // use crate::function::{register, Function, Result}; - // // 1 - // let ptr = ObjectPtr::new(Object::base_object::()); - // assert_eq!(ptr.count(), 1); - // // 2 - // let stay = ptr.clone(); - // assert_eq!(ptr.count(), 2); - // register(test_fn, "my_func").unwrap(); - // let func = Function::get("my_func").unwrap(); - // let func = func.to_boxed_fn::) -> Result>>(); - // let same = func(ptr).unwrap(); - // drop(func); - // assert_eq!(stay.count(), 4); - // assert_eq!(same.count(), 4); - // drop(same); - // assert_eq!(stay.count(), 3); - // } - - // fn test_fn2(o: ArgValue<'static>) -> RetValue { - // // The call machinery adds at least 1 extra count while inside the call. - // match o { - // ArgValue::ObjectHandle(ptr) => RetValue::ObjectHandle(ptr), - // _ => panic!() - // } - // } - #[test] - fn test_ref_count_boundary2() { + fn test_ref_count_boundary3() { use super::*; use crate::function::{register, Function}; let ptr = ObjectPtr::new(Object::base_object::()); From d35c03cc033c6bcb0a9a228c7cc739453718a85b Mon Sep 17 00:00:00 2001 From: Jared Roesch Date: Tue, 22 Sep 2020 18:25:21 -0700 Subject: [PATCH 13/50] Add debugging for NDArray and fix all test cases --- include/tvm/runtime/ndarray.h | 5 +- rust/tvm-rt/Cargo.toml | 1 + rust/tvm-rt/src/ndarray.rs | 147 ++++++++++++++++----------- rust/tvm-rt/src/object/object_ptr.rs | 41 ++++++-- src/runtime/ndarray.cc | 6 +- 5 files changed, 128 insertions(+), 72 deletions(-) diff --git a/include/tvm/runtime/ndarray.h b/include/tvm/runtime/ndarray.h index e5f6fac8725e..de6e78f2dde9 100644 --- a/include/tvm/runtime/ndarray.h +++ b/include/tvm/runtime/ndarray.h @@ -373,8 +373,11 @@ inline ObjectPtr NDArray::FFIDataFromHandle(TVMArrayHandle handle) { inline TVMArrayHandle NDArray::FFIGetHandle(const ObjectRef& nd) { // NOTE: it is necessary to cast to container then to base // so that the FFI handle uses the ContainerBase address. - return reinterpret_cast(static_cast( + std::cout << "Object: " << const_cast(nd.get()) << std::endl; + auto ptr = reinterpret_cast(static_cast( static_cast(const_cast(nd.get())))); + std::cout << "Ptr: " << ptr << std::endl; + return ptr; } inline void NDArray::FFIDecRef(TVMArrayHandle handle) { diff --git a/rust/tvm-rt/Cargo.toml b/rust/tvm-rt/Cargo.toml index 465ae583ab6c..acece5aeec48 100644 --- a/rust/tvm-rt/Cargo.toml +++ b/rust/tvm-rt/Cargo.toml @@ -37,6 +37,7 @@ tvm-macros = { version = "0.1", path = "../tvm-macros" } paste = "0.1" mashup = "0.1" once_cell = "^1.3.1" +memoffset = "0.5.6" [dev-dependencies] anyhow = "^1.0" diff --git a/rust/tvm-rt/src/ndarray.rs b/rust/tvm-rt/src/ndarray.rs index 196c1d3074de..8c38791c26f5 100644 --- a/rust/tvm-rt/src/ndarray.rs +++ b/rust/tvm-rt/src/ndarray.rs @@ -60,7 +60,7 @@ use num_traits::Num; use crate::errors::NDArrayError; -use crate::object::{Object, ObjectPtr}; +use crate::object::{Object, ObjectPtr, IsObjectRef}; /// See the [`module-level documentation`](../ndarray/index.html) for more details. #[repr(C)] @@ -69,30 +69,52 @@ use crate::object::{Object, ObjectPtr}; #[type_key = "runtime.NDArray"] pub struct NDArrayContainer { base: Object, - dl_tensor: *mut DLTensor, + // Container Base + dl_tensor: DLTensor, manager_ctx: *mut c_void, + // TOOD: shape? } - -impl NDArray { - pub(crate) fn from_raw(handle: ffi::TVMArrayHandle) -> Self { - let object: *mut Object = unsafe { std::mem::transmute(handle) }; +impl NDArrayContainer { + pub(crate) fn from_raw(handle: ffi::TVMArrayHandle) -> Option> { + let base_offset = memoffset::offset_of!(NDArrayContainer, dl_tensor) as isize; + println!("Base Object {:?}", base_offset); + let base_ptr = unsafe { (handle as *mut i8).offset(-base_offset) }; + println!("Base Ptr {:?}", base_ptr); + let object: *mut Object = unsafe { std::mem::transmute(base_ptr) }; + println!("Rust Object {:?}", object); let object_ptr = ObjectPtr::from_raw(object); - let ptr = object_ptr + println!("{:?}", crate::object::debug_print(IsObjectRef::from_ptr(object_ptr.clone()))); + object_ptr .map(|ptr| ptr.downcast::() - .expect("we know this is an NDArray container")); + .expect("we know this is an NDArray container")) + } + + pub fn leak<'a>(object_ptr: ObjectPtr) -> &'a mut NDArrayContainer + where + NDArrayContainer: 'a, + { + let base_offset = memoffset::offset_of!(NDArrayContainer, dl_tensor) as isize; + unsafe { &mut *std::mem::ManuallyDrop::new(object_ptr).ptr.as_ptr().offset(base_offset) } + } +} + +impl NDArray { + pub(crate) fn from_raw(handle: ffi::TVMArrayHandle) -> Self { + let ptr = NDArrayContainer::from_raw(handle); NDArray(ptr) } + // I think these should be marked as unsafe functions? projecting a reference is bad news. pub fn as_dltensor(&self) -> &DLTensor { - unsafe { - std::mem::transmute(self.0.as_ref().unwrap().dl_tensor) - } + &self.0.as_ref().unwrap().dl_tensor } pub(crate) fn as_raw_dltensor(&self) -> *mut DLTensor { - self.0.as_ref().unwrap().dl_tensor + unsafe { + std::mem::transmute(&self.0.as_ref().unwrap().dl_tensor) + } } pub fn is_view(&self) -> bool { @@ -295,8 +317,9 @@ impl NDArray { ctx.device_id as c_int, &mut handle as *mut _, )); + println!("{:?}", handle); let ptr = - ObjectPtr::from_raw(handle as *mut Object).map(|o| + NDArrayContainer::from_raw(handle).map(|o| o.downcast().expect("this should never fail")); NDArray(ptr) } @@ -369,7 +392,9 @@ mod tests { fn basics() { let shape = &mut [1, 2, 3]; let ctx = Context::cpu(0); + println!("before empty"); let ndarray = NDArray::empty(shape, ctx, DataType::from_str("int32").unwrap()); + println!("after empty"); assert_eq!(ndarray.shape().unwrap(), shape); assert_eq!( ndarray.size().unwrap(), @@ -380,52 +405,52 @@ mod tests { assert_eq!(ndarray.byte_offset(), 0); } - // #[test] - // fn copy() { - // let shape = &mut [4]; - // let mut data = vec![1i32, 2, 3, 4]; - // let ctx = Context::cpu(0); - // let mut ndarray = NDArray::empty(shape, ctx, DataType::from_str("int32").unwrap()); - // assert!(ndarray.to_vec::().is_ok()); - // ndarray.copy_from_buffer(&mut data); - // assert_eq!(ndarray.shape().unwrap(), shape); - // assert_eq!(ndarray.to_vec::().unwrap(), data); - // assert_eq!(ndarray.ndim(), 1); - // assert!(ndarray.is_contiguous().is_ok()); - // assert_eq!(ndarray.byte_offset(), 0); - // let shape = vec![4]; - // let e = NDArray::empty( - // &shape, - // Context::cpu(0), - // DataType::from_str("int32").unwrap(), - // ); - // let nd = ndarray.copy_to_ndarray(e); - // assert!(nd.is_ok()); - // assert_eq!(nd.unwrap().to_vec::().unwrap(), data); - // } - - // // #[test] - // // #[should_panic(expected = "called `Result::unwrap()` on an `Err`")] - // // fn copy_wrong_dtype() { - // // let shape = vec![4]; - // // let mut data = vec![1f32, 2., 3., 4.]; - // // let ctx = Context::cpu(0); - // // let mut nd_float = NDArray::empty(&shape, ctx, DataType::from_str("float32").unwrap()); - // // nd_float.copy_from_buffer(&mut data); - // // let empty_int = NDArray::empty(&shape, ctx, DataType::from_str("int32").unwrap()); - // // nd_float.copy_to_ndarray(empty_int).unwrap(); - // // } - - // #[test] - // fn rust_ndarray() { - // let a = Array::from_shape_vec((2, 2), vec![1f32, 2., 3., 4.]) - // .unwrap() - // .into_dyn(); - // let nd = - // NDArray::from_rust_ndarray(&a, Context::cpu(0), DataType::from_str("float32").unwrap()) - // .unwrap(); - // assert_eq!(nd.shape().unwrap(), &mut [2, 2]); - // let rnd: ArrayD = ArrayD::try_from(&nd).unwrap(); - // assert!(rnd.all_close(&a, 1e-8f32)); - // } + #[test] + fn copy() { + let shape = &mut [4]; + let mut data = vec![1i32, 2, 3, 4]; + let ctx = Context::cpu(0); + let mut ndarray = NDArray::empty(shape, ctx, DataType::from_str("int32").unwrap()); + assert!(ndarray.to_vec::().is_ok()); + ndarray.copy_from_buffer(&mut data); + assert_eq!(ndarray.shape().unwrap(), shape); + assert_eq!(ndarray.to_vec::().unwrap(), data); + assert_eq!(ndarray.ndim(), 1); + assert!(ndarray.is_contiguous().is_ok()); + assert_eq!(ndarray.byte_offset(), 0); + let shape = vec![4]; + let e = NDArray::empty( + &shape, + Context::cpu(0), + DataType::from_str("int32").unwrap(), + ); + let nd = ndarray.copy_to_ndarray(e); + assert!(nd.is_ok()); + assert_eq!(nd.unwrap().to_vec::().unwrap(), data); + } + + #[test] + #[should_panic(expected = "called `Result::unwrap()` on an `Err`")] + fn copy_wrong_dtype() { + let shape = vec![4]; + let mut data = vec![1f32, 2., 3., 4.]; + let ctx = Context::cpu(0); + let mut nd_float = NDArray::empty(&shape, ctx, DataType::from_str("float32").unwrap()); + nd_float.copy_from_buffer(&mut data); + let empty_int = NDArray::empty(&shape, ctx, DataType::from_str("int32").unwrap()); + nd_float.copy_to_ndarray(empty_int).unwrap(); + } + + #[test] + fn rust_ndarray() { + let a = Array::from_shape_vec((2, 2), vec![1f32, 2., 3., 4.]) + .unwrap() + .into_dyn(); + let nd = + NDArray::from_rust_ndarray(&a, Context::cpu(0), DataType::from_str("float32").unwrap()) + .unwrap(); + assert_eq!(nd.shape().unwrap(), &mut [2, 2]); + let rnd: ArrayD = ArrayD::try_from(&nd).unwrap(); + assert!(rnd.all_close(&a, 1e-8f32)); + } } diff --git a/rust/tvm-rt/src/object/object_ptr.rs b/rust/tvm-rt/src/object/object_ptr.rs index 34bea1190139..f87c9280b209 100644 --- a/rust/tvm-rt/src/object/object_ptr.rs +++ b/rust/tvm-rt/src/object/object_ptr.rs @@ -276,13 +276,22 @@ impl<'a, T: IsObject> TryFrom for ObjectPtr { type Error = Error; fn try_from(ret_value: RetValue) -> Result, Self::Error> { + use crate::ndarray::NDArrayContainer; + use crate::ffi::DLTensor; + match ret_value { - RetValue::ObjectHandle(handle) | RetValue::NDArrayHandle(handle) => { + RetValue::ObjectHandle(handle) => { let optr = ObjectPtr::from_raw(handle as *mut Object).ok_or(Error::Null)?; debug_assert!(optr.count() >= 1); - // println!("back to type {}", optr.count()); optr.downcast() } + RetValue::NDArrayHandle(handle) => { + let optr = NDArrayContainer::from_raw(handle as *mut DLTensor).ok_or(Error::Null)?; + debug_assert!(optr.count() >= 1); + // (@mwillsey): can we remove this? + optr.upcast::().downcast() + } + // TODO(@mxwillsey, jared): ObjectHandle is wrong here. _ => Err(Error::downcast(format!("{:?}", ret_value), "ObjectHandle")), } } @@ -293,11 +302,19 @@ impl<'a, T: IsObject> From> for ArgValue<'a> { debug_assert!(object_ptr.count() >= 1); let object_ptr = object_ptr.upcast::(); let index = object_ptr.type_index; - let raw_ptr = ObjectPtr::leak(object_ptr) as *mut Object as *mut std::ffi::c_void; - assert!(!raw_ptr.is_null()); match index { - tvm_sys::ffi::TVMArgTypeCode_kTVMNDArrayHandle => ArgValue::NDArrayHandle(raw_ptr), - _ => ArgValue::ObjectHandle(raw_ptr) + tvm_sys::ffi::TVMArgTypeCode_kTVMNDArrayHandle => { + use crate::ndarray::NDArrayContainer; + // TODO(this is probably not optimal) + let raw_ptr = NDArrayContainer::leak(object_ptr.downcast().unwrap()) as *mut NDArrayContainer as *mut std::ffi::c_void; + assert!(!raw_ptr.is_null()); + ArgValue::NDArrayHandle(raw_ptr) + }, + _ => { + let raw_ptr = ObjectPtr::leak(object_ptr) as *mut Object as *mut std::ffi::c_void; + assert!(!raw_ptr.is_null()); + ArgValue::ObjectHandle(raw_ptr) + } } } } @@ -306,13 +323,21 @@ impl<'a, T: IsObject> TryFrom> for ObjectPtr { type Error = Error; fn try_from(arg_value: ArgValue<'a>) -> Result, Self::Error> { + use crate::ndarray::NDArrayContainer; + use crate::ffi::DLTensor; + match arg_value { - ArgValue::ObjectHandle(handle) | ArgValue::NDArrayHandle(handle) => { + ArgValue::ObjectHandle(handle) => { let optr = ObjectPtr::from_raw(handle as *mut Object).ok_or(Error::Null)?; debug_assert!(optr.count() >= 1); - // println!("count: {}", optr.count()); optr.downcast() }, + ArgValue::NDArrayHandle(handle) => { + let optr = NDArrayContainer::from_raw(handle as *mut DLTensor).ok_or(Error::Null)?; + debug_assert!(optr.count() >= 1); + // (@mwillsey): can we remove this? + optr.upcast::().downcast() + } _ => Err(Error::downcast(format!("{:?}", arg_value), "ObjectHandle")), } } diff --git a/src/runtime/ndarray.cc b/src/runtime/ndarray.cc index 800a9167dadc..b176b88985b6 100644 --- a/src/runtime/ndarray.cc +++ b/src/runtime/ndarray.cc @@ -267,8 +267,10 @@ int TVMArrayAlloc(const tvm_index_t* shape, int ndim, int dtype_code, int dtype_ DLContext ctx; ctx.device_type = static_cast(device_type); ctx.device_id = device_id; - *out = NDArray::Internal::MoveToFFIHandle( - NDArray::Empty(std::vector(shape, shape + ndim), dtype, ctx)); + auto ndarray = NDArray::Empty(std::vector(shape, shape + ndim), dtype, ctx); + + *out = NDArray::Internal::MoveToFFIHandle(ndarray); + std::flush(std::cout); API_END(); } From 5eb46c218fec808d073e1e748e5bcdb4e24a5ad4 Mon Sep 17 00:00:00 2001 From: Gus Smith Date: Wed, 23 Sep 2020 09:24:22 -0700 Subject: [PATCH 14/50] Add breaking test --- rust/tvm/src/ir/relay/mod.rs | 45 ++++++++++++++++++++++++++++++++++++ 1 file changed, 45 insertions(+) diff --git a/rust/tvm/src/ir/relay/mod.rs b/rust/tvm/src/ir/relay/mod.rs index a5decfd1df7d..02b7955835f3 100644 --- a/rust/tvm/src/ir/relay/mod.rs +++ b/rust/tvm/src/ir/relay/mod.rs @@ -526,4 +526,49 @@ mod tests { assert!(text.contains("%local")); Ok(()) } + + #[test] + fn test_parse_constant() -> Result<()> { + let module = crate::ir::module::IRModule::parse("", r#" +#[version = "0.0.5"] +def @main() -> float32 { + 0.01639530062675476f +} +"#); + let main = module + .lookup(module.get_global_var("main".to_string().into()).unwrap()) + .unwrap(); + let func = main.downcast::().unwrap(); + let constant = func.body.clone().downcast::().unwrap(); + let tuple_type = constant + .clone() + .upcast::() + .checked_type + .clone() + .downcast::() + .unwrap(); + // Test type + assert_eq!( + tuple_type.shape.len(), + 0, + ); + assert_eq!( + tuple_type.dtype, + "float32".parse().unwrap(), + ); + // Check that actual data matches up with type + assert_eq!( + constant.data.dtype(), + "float32".parse().unwrap(), + ); + assert_eq!( + constant.data.size(), + Some(1), + ); + assert_eq!( + constant.data.shape().unwrap().len(), + 0, + ); + Ok(()) + } } From 518b230a857af0a6df739f0b34a8339e7aa356b5 Mon Sep 17 00:00:00 2001 From: Max Willsey Date: Wed, 23 Sep 2020 09:26:12 -0700 Subject: [PATCH 15/50] Dispatch some todos --- rust/tvm-rt/src/object/object_ptr.rs | 22 +++++++++++----------- rust/tvm-rt/src/string.rs | 4 +--- 2 files changed, 12 insertions(+), 14 deletions(-) diff --git a/rust/tvm-rt/src/object/object_ptr.rs b/rust/tvm-rt/src/object/object_ptr.rs index f87c9280b209..aa65b193931f 100644 --- a/rust/tvm-rt/src/object/object_ptr.rs +++ b/rust/tvm-rt/src/object/object_ptr.rs @@ -276,8 +276,8 @@ impl<'a, T: IsObject> TryFrom for ObjectPtr { type Error = Error; fn try_from(ret_value: RetValue) -> Result, Self::Error> { - use crate::ndarray::NDArrayContainer; use crate::ffi::DLTensor; + use crate::ndarray::NDArrayContainer; match ret_value { RetValue::ObjectHandle(handle) => { @@ -286,13 +286,12 @@ impl<'a, T: IsObject> TryFrom for ObjectPtr { optr.downcast() } RetValue::NDArrayHandle(handle) => { - let optr = NDArrayContainer::from_raw(handle as *mut DLTensor).ok_or(Error::Null)?; + let optr: ObjectPtr = + NDArrayContainer::from_raw(handle as *mut DLTensor).ok_or(Error::Null)?; debug_assert!(optr.count() >= 1); - // (@mwillsey): can we remove this? optr.upcast::().downcast() } - // TODO(@mxwillsey, jared): ObjectHandle is wrong here. - _ => Err(Error::downcast(format!("{:?}", ret_value), "ObjectHandle")), + _ => Err(Error::downcast(format!("{:?}", ret_value), T::TYPE_KEY)), } } } @@ -306,10 +305,11 @@ impl<'a, T: IsObject> From> for ArgValue<'a> { tvm_sys::ffi::TVMArgTypeCode_kTVMNDArrayHandle => { use crate::ndarray::NDArrayContainer; // TODO(this is probably not optimal) - let raw_ptr = NDArrayContainer::leak(object_ptr.downcast().unwrap()) as *mut NDArrayContainer as *mut std::ffi::c_void; + let raw_ptr = NDArrayContainer::leak(object_ptr.downcast().unwrap()) + as *mut NDArrayContainer as *mut std::ffi::c_void; assert!(!raw_ptr.is_null()); ArgValue::NDArrayHandle(raw_ptr) - }, + } _ => { let raw_ptr = ObjectPtr::leak(object_ptr) as *mut Object as *mut std::ffi::c_void; assert!(!raw_ptr.is_null()); @@ -323,19 +323,19 @@ impl<'a, T: IsObject> TryFrom> for ObjectPtr { type Error = Error; fn try_from(arg_value: ArgValue<'a>) -> Result, Self::Error> { - use crate::ndarray::NDArrayContainer; use crate::ffi::DLTensor; + use crate::ndarray::NDArrayContainer; match arg_value { ArgValue::ObjectHandle(handle) => { let optr = ObjectPtr::from_raw(handle as *mut Object).ok_or(Error::Null)?; debug_assert!(optr.count() >= 1); optr.downcast() - }, + } ArgValue::NDArrayHandle(handle) => { - let optr = NDArrayContainer::from_raw(handle as *mut DLTensor).ok_or(Error::Null)?; + let optr = + NDArrayContainer::from_raw(handle as *mut DLTensor).ok_or(Error::Null)?; debug_assert!(optr.count() >= 1); - // (@mwillsey): can we remove this? optr.upcast::().downcast() } _ => Err(Error::downcast(format!("{:?}", arg_value), "ObjectHandle")), diff --git a/rust/tvm-rt/src/string.rs b/rust/tvm-rt/src/string.rs index a5ee1f183389..6ff24bef3a60 100644 --- a/rust/tvm-rt/src/string.rs +++ b/rust/tvm-rt/src/string.rs @@ -114,9 +114,7 @@ impl Hash for String { impl std::fmt::Debug for String { fn fmt(&self, formatter: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { - // TODO(@mwillsey): remove this clone? - let string: String = self.clone().into(); - formatter.write_fmt(format_args!("{:?}", string)) + formatter.write_fmt(format_args!("{:?}", self.to_string_lossy())) } } From ba92c43eb833da715ca330ad044b905ad50dfc11 Mon Sep 17 00:00:00 2001 From: Max Willsey Date: Wed, 23 Sep 2020 09:26:24 -0700 Subject: [PATCH 16/50] Format --- rust/tvm-rt/src/ndarray.rs | 33 ++++++++++++++++++------------- rust/tvm/src/ir/relay/mod.rs | 38 ++++++++++++++---------------------- rust/tvm/src/ir/ty.rs | 2 +- 3 files changed, 35 insertions(+), 38 deletions(-) diff --git a/rust/tvm-rt/src/ndarray.rs b/rust/tvm-rt/src/ndarray.rs index 8c38791c26f5..ea14e76197c7 100644 --- a/rust/tvm-rt/src/ndarray.rs +++ b/rust/tvm-rt/src/ndarray.rs @@ -51,16 +51,16 @@ use std::convert::TryInto; use std::ffi::c_void; use std::{convert::TryFrom, mem, os::raw::c_int, ptr, slice, str::FromStr}; +use tvm_macros::Object; use tvm_sys::ffi::DLTensor; use tvm_sys::{ffi, ByteArray, Context, DataType}; -use tvm_macros::Object; use ndarray::{Array, ArrayD}; use num_traits::Num; use crate::errors::NDArrayError; -use crate::object::{Object, ObjectPtr, IsObjectRef}; +use crate::object::{IsObjectRef, Object, ObjectPtr}; /// See the [`module-level documentation`](../ndarray/index.html) for more details. #[repr(C)] @@ -84,11 +84,14 @@ impl NDArrayContainer { let object: *mut Object = unsafe { std::mem::transmute(base_ptr) }; println!("Rust Object {:?}", object); let object_ptr = ObjectPtr::from_raw(object); - println!("{:?}", crate::object::debug_print(IsObjectRef::from_ptr(object_ptr.clone()))); - object_ptr - .map(|ptr| - ptr.downcast::() - .expect("we know this is an NDArray container")) + println!( + "{:?}", + crate::object::debug_print(IsObjectRef::from_ptr(object_ptr.clone())) + ); + object_ptr.map(|ptr| { + ptr.downcast::() + .expect("we know this is an NDArray container") + }) } pub fn leak<'a>(object_ptr: ObjectPtr) -> &'a mut NDArrayContainer @@ -96,7 +99,12 @@ impl NDArrayContainer { NDArrayContainer: 'a, { let base_offset = memoffset::offset_of!(NDArrayContainer, dl_tensor) as isize; - unsafe { &mut *std::mem::ManuallyDrop::new(object_ptr).ptr.as_ptr().offset(base_offset) } + unsafe { + &mut *std::mem::ManuallyDrop::new(object_ptr) + .ptr + .as_ptr() + .offset(base_offset) + } } } @@ -112,9 +120,7 @@ impl NDArray { } pub(crate) fn as_raw_dltensor(&self) -> *mut DLTensor { - unsafe { - std::mem::transmute(&self.0.as_ref().unwrap().dl_tensor) - } + unsafe { std::mem::transmute(&self.0.as_ref().unwrap().dl_tensor) } } pub fn is_view(&self) -> bool { @@ -318,9 +324,8 @@ impl NDArray { &mut handle as *mut _, )); println!("{:?}", handle); - let ptr = - NDArrayContainer::from_raw(handle).map(|o| - o.downcast().expect("this should never fail")); + let ptr = NDArrayContainer::from_raw(handle) + .map(|o| o.downcast().expect("this should never fail")); NDArray(ptr) } } diff --git a/rust/tvm/src/ir/relay/mod.rs b/rust/tvm/src/ir/relay/mod.rs index 02b7955835f3..5981e0bd1b1d 100644 --- a/rust/tvm/src/ir/relay/mod.rs +++ b/rust/tvm/src/ir/relay/mod.rs @@ -529,17 +529,24 @@ mod tests { #[test] fn test_parse_constant() -> Result<()> { - let module = crate::ir::module::IRModule::parse("", r#" + let module = crate::ir::module::IRModule::parse( + "", + r#" #[version = "0.0.5"] def @main() -> float32 { 0.01639530062675476f } -"#); +"#, + ); let main = module .lookup(module.get_global_var("main".to_string().into()).unwrap()) .unwrap(); let func = main.downcast::().unwrap(); - let constant = func.body.clone().downcast::().unwrap(); + let constant = func + .body + .clone() + .downcast::() + .unwrap(); let tuple_type = constant .clone() .upcast::() @@ -548,27 +555,12 @@ def @main() -> float32 { .downcast::() .unwrap(); // Test type - assert_eq!( - tuple_type.shape.len(), - 0, - ); - assert_eq!( - tuple_type.dtype, - "float32".parse().unwrap(), - ); + assert_eq!(tuple_type.shape.len(), 0,); + assert_eq!(tuple_type.dtype, "float32".parse().unwrap(),); // Check that actual data matches up with type - assert_eq!( - constant.data.dtype(), - "float32".parse().unwrap(), - ); - assert_eq!( - constant.data.size(), - Some(1), - ); - assert_eq!( - constant.data.shape().unwrap().len(), - 0, - ); + assert_eq!(constant.data.dtype(), "float32".parse().unwrap(),); + assert_eq!(constant.data.size(), Some(1),); + assert_eq!(constant.data.shape().unwrap().len(), 0,); Ok(()) } } diff --git a/rust/tvm/src/ir/ty.rs b/rust/tvm/src/ir/ty.rs index a323d71aede0..80cb11c4b965 100644 --- a/rust/tvm/src/ir/ty.rs +++ b/rust/tvm/src/ir/ty.rs @@ -1,6 +1,6 @@ +use crate::runtime::{Object, ObjectRef}; use tvm_macros::Object; use tvm_rt::{array::Array, DataType}; -use crate::runtime::{ObjectRef, Object}; use super::PrimExpr; From f99154d5e87577d54fd2418c8a6d36c5dc65d785 Mon Sep 17 00:00:00 2001 From: Max Willsey Date: Wed, 23 Sep 2020 12:20:55 -0700 Subject: [PATCH 17/50] Fix ndarray size and len --- rust/tvm-rt/src/ndarray.rs | 24 +++++++++++++++++------- rust/tvm/src/ir/relay/mod.rs | 5 +++-- rust/tvm/tests/basics/src/main.rs | 2 +- 3 files changed, 21 insertions(+), 10 deletions(-) diff --git a/rust/tvm-rt/src/ndarray.rs b/rust/tvm-rt/src/ndarray.rs index ea14e76197c7..74a9b83a2404 100644 --- a/rust/tvm-rt/src/ndarray.rs +++ b/rust/tvm-rt/src/ndarray.rs @@ -116,11 +116,11 @@ impl NDArray { // I think these should be marked as unsafe functions? projecting a reference is bad news. pub fn as_dltensor(&self) -> &DLTensor { - &self.0.as_ref().unwrap().dl_tensor + &self.dl_tensor } pub(crate) fn as_raw_dltensor(&self) -> *mut DLTensor { - unsafe { std::mem::transmute(&self.0.as_ref().unwrap().dl_tensor) } + unsafe { std::mem::transmute(self.as_dltensor()) } } pub fn is_view(&self) -> bool { @@ -137,9 +137,19 @@ impl NDArray { Some(slc) } + /// Returns true if the tensor is empty + pub fn is_empty(&self) -> bool { + self.as_dltensor().data.is_null() + } + /// Returns the total number of entries of the NDArray. - pub fn size(&self) -> Option { - self.shape().map(|v| v.iter().product()) + pub fn len(&self) -> usize { + self.shape().unwrap_or(&mut []).iter().product() + } + + /// Returns the total bytes taken up by the data. + pub fn size(&self) -> usize { + self.len() * self.dtype().itemsize() } /// Returns the context which the NDArray was defined. @@ -224,8 +234,8 @@ impl NDArray { ); let target = self.copy_to_ndarray(earr)?; let arr = target.as_dltensor(); - let sz = self.size().ok_or(NDArrayError::MissingShape)?; - let mut v: Vec = Vec::with_capacity(sz * mem::size_of::()); + let sz = self.size(); + let mut v: Vec = Vec::with_capacity(sz / mem::size_of::()); unsafe { v.as_mut_ptr() .copy_from_nonoverlapping(arr.data as *const T, sz); @@ -402,7 +412,7 @@ mod tests { println!("after empty"); assert_eq!(ndarray.shape().unwrap(), shape); assert_eq!( - ndarray.size().unwrap(), + ndarray.size(), shape.to_vec().into_iter().product() ); assert_eq!(ndarray.ndim(), 3); diff --git a/rust/tvm/src/ir/relay/mod.rs b/rust/tvm/src/ir/relay/mod.rs index 5981e0bd1b1d..b68602b2ee17 100644 --- a/rust/tvm/src/ir/relay/mod.rs +++ b/rust/tvm/src/ir/relay/mod.rs @@ -559,8 +559,9 @@ def @main() -> float32 { assert_eq!(tuple_type.dtype, "float32".parse().unwrap(),); // Check that actual data matches up with type assert_eq!(constant.data.dtype(), "float32".parse().unwrap(),); - assert_eq!(constant.data.size(), Some(1),); - assert_eq!(constant.data.shape().unwrap().len(), 0,); + assert_eq!(constant.data.len(), 1); + assert_eq!(constant.data.size(), 4); + assert_eq!(constant.data.shape(), None); Ok(()) } } diff --git a/rust/tvm/tests/basics/src/main.rs b/rust/tvm/tests/basics/src/main.rs index 04d8382d3c1f..9a21ff02b2f2 100644 --- a/rust/tvm/tests/basics/src/main.rs +++ b/rust/tvm/tests/basics/src/main.rs @@ -44,7 +44,7 @@ fn main() { fadd.entry() .expect("module must have entry point") - .invoke(vec![(&arr).into(), (&arr).into(), (&mut ret).into()]) + .invoke(vec![(&arr).into(), (&arr).into(), ret.into()]) .unwrap(); assert_eq!(ret.to_vec::().unwrap(), vec![6f32, 8.0]); From 4830621d9a5595e6a7410a60c16b1766068bb682 Mon Sep 17 00:00:00 2001 From: Gus Smith Date: Wed, 23 Sep 2020 14:44:12 -0700 Subject: [PATCH 18/50] Add BiasAddAttrs rust bindings --- rust/tvm/src/ir/relay/attrs/nn.rs | 9 +++++++++ 1 file changed, 9 insertions(+) diff --git a/rust/tvm/src/ir/relay/attrs/nn.rs b/rust/tvm/src/ir/relay/attrs/nn.rs index 42260a1ec8e3..6f0977066f45 100644 --- a/rust/tvm/src/ir/relay/attrs/nn.rs +++ b/rust/tvm/src/ir/relay/attrs/nn.rs @@ -25,3 +25,12 @@ pub struct Conv2DAttrsNode { pub out_layout: TString, pub out_dtype: DataType, } + +#[repr(C)] +#[derive(Object)] +#[ref_name = "BiasAddAttrs"] +#[type_key = "relay.attrs.BiasAddAttrs"] +pub struct BiasAddAttrsNode { + pub base: BaseAttrsNode, + pub axis: i32, +} From 53d83771bc3bef4183d899193d321b11fb7a29ba Mon Sep 17 00:00:00 2001 From: Gus Smith Date: Wed, 23 Sep 2020 15:45:30 -0700 Subject: [PATCH 19/50] Add DenseAttrs rust bindings --- rust/tvm/src/ir/relay/attrs/nn.rs | 10 ++++++++++ 1 file changed, 10 insertions(+) diff --git a/rust/tvm/src/ir/relay/attrs/nn.rs b/rust/tvm/src/ir/relay/attrs/nn.rs index 6f0977066f45..223d1452e8f5 100644 --- a/rust/tvm/src/ir/relay/attrs/nn.rs +++ b/rust/tvm/src/ir/relay/attrs/nn.rs @@ -34,3 +34,13 @@ pub struct BiasAddAttrsNode { pub base: BaseAttrsNode, pub axis: i32, } + +#[repr(C)] +#[derive(Object)] +#[ref_name = "DenseAttrs"] +#[type_key = "relay.attrs.DenseAttrs"] +pub struct DenseAttrsNode { + pub base: BaseAttrsNode, + pub units: IndexExpr, + pub out_dtype: DataType, +} From 5062ecbfb6a36a91d0e3040f868ca9e9176f53d4 Mon Sep 17 00:00:00 2001 From: Gus Smith Date: Wed, 23 Sep 2020 20:47:58 -0700 Subject: [PATCH 20/50] Change to TVM string --- include/tvm/relay/attrs/nn.h | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/include/tvm/relay/attrs/nn.h b/include/tvm/relay/attrs/nn.h index 60c37aff2e90..e1d7a18eb8be 100644 --- a/include/tvm/relay/attrs/nn.h +++ b/include/tvm/relay/attrs/nn.h @@ -682,7 +682,7 @@ struct MaxPool2DAttrs : public tvm::AttrsNode { Array pool_size; Array strides; Array padding; - std::string layout; + tvm::String layout; bool ceil_mode; TVM_DECLARE_ATTRS(MaxPool2DAttrs, "relay.attrs.MaxPool2DAttrs") { @@ -745,7 +745,7 @@ struct AvgPool2DAttrs : public tvm::AttrsNode { /*! \brief Attributes for global pool operator */ struct GlobalPool2DAttrs : public tvm::AttrsNode { - std::string layout; + tvm::String layout; TVM_DECLARE_ATTRS(GlobalPool2DAttrs, "relay.attrs.GlobalPool2DAttrs") { TVM_ATTR_FIELD(layout).set_default("NCHW").describe( From 66fcc918e74ae4713d34f877288d4a3150430d42 Mon Sep 17 00:00:00 2001 From: Gus Smith Date: Wed, 23 Sep 2020 20:48:09 -0700 Subject: [PATCH 21/50] Add more Rust bindings Add GlobalPool2DAttrs Rust binding Add ExpandDimsAttrs Rust bindings Add MaxPool2DAttrs rust bindings --- rust/tvm/src/ir/relay/attrs/mod.rs | 1 + rust/tvm/src/ir/relay/attrs/nn.rs | 22 ++++++++++++++++++++++ rust/tvm/src/ir/relay/attrs/transform.rs | 12 ++++++++++++ 3 files changed, 35 insertions(+) create mode 100644 rust/tvm/src/ir/relay/attrs/transform.rs diff --git a/rust/tvm/src/ir/relay/attrs/mod.rs b/rust/tvm/src/ir/relay/attrs/mod.rs index cb1fa9728ae1..459cef0ed76b 100644 --- a/rust/tvm/src/ir/relay/attrs/mod.rs +++ b/rust/tvm/src/ir/relay/attrs/mod.rs @@ -1 +1,2 @@ pub mod nn; +pub mod transform; diff --git a/rust/tvm/src/ir/relay/attrs/nn.rs b/rust/tvm/src/ir/relay/attrs/nn.rs index 223d1452e8f5..d648baf7eb10 100644 --- a/rust/tvm/src/ir/relay/attrs/nn.rs +++ b/rust/tvm/src/ir/relay/attrs/nn.rs @@ -44,3 +44,25 @@ pub struct DenseAttrsNode { pub units: IndexExpr, pub out_dtype: DataType, } + +#[repr(C)] +#[derive(Object)] +#[ref_name = "GlobalPool2DAttrs"] +#[type_key = "relay.attrs.GlobalPool2DAttrs"] +pub struct GlobalPool2DAttrsNode { + pub base: BaseAttrsNode, + pub layout: TString, +} + +#[repr(C)] +#[derive(Object)] +#[ref_name = "MaxPool2DAttrs"] +#[type_key = "relay.attrs.MaxPool2DAttrs"] +pub struct MaxPool2DAttrsNode { + pub base: BaseAttrsNode, + pub pool_size: Array, + pub strides: Array, + pub padding: Array, + pub layout: TString, + pub ceil_mode: bool, +} diff --git a/rust/tvm/src/ir/relay/attrs/transform.rs b/rust/tvm/src/ir/relay/attrs/transform.rs new file mode 100644 index 000000000000..c7c90cd30682 --- /dev/null +++ b/rust/tvm/src/ir/relay/attrs/transform.rs @@ -0,0 +1,12 @@ +use crate::ir::attrs::BaseAttrsNode; +use tvm_macros::Object; + +#[repr(C)] +#[derive(Object)] +#[ref_name = "ExpandDimsAttrs"] +#[type_key = "relay.attrs.ExpandDimsAttrs"] +pub struct ExpandDimsAttrsNode { + pub base: BaseAttrsNode, + pub axis: i32, + pub num_newaxis: i32, +} From 4519acdfc709cecbb440969956ce1751ca8251c2 Mon Sep 17 00:00:00 2001 From: Max Willsey Date: Thu, 24 Sep 2020 12:23:19 -0700 Subject: [PATCH 22/50] Fix some test attributes --- rust/tvm-graph-rt/src/graph.rs | 8 +++++--- 1 file changed, 5 insertions(+), 3 deletions(-) diff --git a/rust/tvm-graph-rt/src/graph.rs b/rust/tvm-graph-rt/src/graph.rs index 91021dd12bb7..87dd4a76d5e4 100644 --- a/rust/tvm-graph-rt/src/graph.rs +++ b/rust/tvm-graph-rt/src/graph.rs @@ -46,8 +46,10 @@ const _NDARRAY_LIST_MAGIC: u64 = 0xF7E5_8D4F_0504_9CB7; /// /// # Examples /// -/// ```norun -/// let graph_json = fs::read_to_string("graph.json").unwrap(); +/// ```no_run +/// use tvm_graph_rt::Graph; +/// use std::convert::TryFrom; +/// let graph_json = std::fs::read_to_string("graph.json").unwrap(); /// let graph = Graph::try_from(&graph_json).unwrap(); /// ``` #[derive(Serialize, Deserialize, Debug)] @@ -147,7 +149,7 @@ impl<'a> TryFrom<&'a str> for Graph { /// /// # Examples /// -/// ```norun +/// ```no_compile /// use ndarray::Array; /// /// let syslib = SystemLibModule::default(); // a provider of TVM functions From 0b922f2cce1768e4a3e0961a49ab17e1786a3528 Mon Sep 17 00:00:00 2001 From: Max Willsey Date: Thu, 24 Sep 2020 12:27:53 -0700 Subject: [PATCH 23/50] Improve the NDArray api --- rust/tvm-rt/src/errors.rs | 2 - rust/tvm-rt/src/ndarray.rs | 112 +++++++++-------------- rust/tvm/examples/resnet/src/main.rs | 2 +- rust/tvm/src/ir/relay/mod.rs | 2 +- rust/tvm/tests/basics/src/main.rs | 4 +- rust/tvm/tests/callback/src/bin/array.rs | 4 +- 6 files changed, 49 insertions(+), 77 deletions(-) diff --git a/rust/tvm-rt/src/errors.rs b/rust/tvm-rt/src/errors.rs index e194bfa9febd..c884c56fed44 100644 --- a/rust/tvm-rt/src/errors.rs +++ b/rust/tvm-rt/src/errors.rs @@ -33,8 +33,6 @@ pub struct TypeMismatchError { #[derive(Debug, Error)] pub enum NDArrayError { - #[error("Missing NDArray shape.")] - MissingShape, #[error("Cannot convert from an empty array.")] EmptyArray, #[error("Invalid datatype when attempting to convert ndarray.")] diff --git a/rust/tvm-rt/src/ndarray.rs b/rust/tvm-rt/src/ndarray.rs index 74a9b83a2404..1b764270d48f 100644 --- a/rust/tvm-rt/src/ndarray.rs +++ b/rust/tvm-rt/src/ndarray.rs @@ -38,7 +38,7 @@ //! .unwrap() //! .into_dyn(); // Rust's ndarray //! let nd = NDArray::from_rust_ndarray(&a, Context::cpu(0), DataType::from_str("float32").unwrap()).unwrap(); -//! assert_eq!(nd.shape(), Some(&mut [2, 2][..])); +//! assert_eq!(nd.shape(), &[2, 2]); //! let rnd: ArrayD = ArrayD::try_from(&nd).unwrap(); //! assert!(rnd.all_close(&a, 1e-8f32)); //! ``` @@ -60,7 +60,7 @@ use num_traits::Num; use crate::errors::NDArrayError; -use crate::object::{IsObjectRef, Object, ObjectPtr}; +use crate::object::{Object, ObjectPtr}; /// See the [`module-level documentation`](../ndarray/index.html) for more details. #[repr(C)] @@ -78,16 +78,9 @@ pub struct NDArrayContainer { impl NDArrayContainer { pub(crate) fn from_raw(handle: ffi::TVMArrayHandle) -> Option> { let base_offset = memoffset::offset_of!(NDArrayContainer, dl_tensor) as isize; - println!("Base Object {:?}", base_offset); let base_ptr = unsafe { (handle as *mut i8).offset(-base_offset) }; - println!("Base Ptr {:?}", base_ptr); let object: *mut Object = unsafe { std::mem::transmute(base_ptr) }; - println!("Rust Object {:?}", object); let object_ptr = ObjectPtr::from_raw(object); - println!( - "{:?}", - crate::object::debug_print(IsObjectRef::from_ptr(object_ptr.clone())) - ); object_ptr.map(|ptr| { ptr.downcast::() .expect("we know this is an NDArray container") @@ -128,13 +121,31 @@ impl NDArray { } /// Returns the shape of the NDArray. - pub fn shape(&self) -> Option<&mut [usize]> { + pub fn shape(&self) -> &[usize] { let arr = self.as_dltensor(); if arr.shape.is_null() || arr.data.is_null() { - return None; - }; - let slc = unsafe { slice::from_raw_parts_mut(arr.shape as *mut usize, arr.ndim as usize) }; - Some(slc) + &[] + } else { + let s = unsafe { slice::from_raw_parts(arr.shape.cast(), self.ndim()) }; + debug_assert!(s.iter().all(|&x| x as i64 >= 0), "negative shape: {:?}", s); + s + } + } + + /// Returns the strides of the underlying NDArray. + pub fn strides(&self) -> Option<&[usize]> { + let arr = self.as_dltensor(); + if arr.strides.is_null() { + None + } else { + let s = unsafe { slice::from_raw_parts(arr.strides.cast(), self.ndim()) }; + debug_assert!( + s.iter().all(|&x| x as i64 >= 0), + "negative strides: {:?}", + s + ); + Some(s) + } } /// Returns true if the tensor is empty @@ -144,10 +155,11 @@ impl NDArray { /// Returns the total number of entries of the NDArray. pub fn len(&self) -> usize { - self.shape().unwrap_or(&mut []).iter().product() + self.shape().iter().product() } /// Returns the total bytes taken up by the data. + /// This is equal to `nd.len() * nd.dtype().itemsize()` pub fn size(&self) -> usize { self.len() * self.dtype().itemsize() } @@ -170,16 +182,6 @@ impl NDArray { .expect("number of dimensions must always be positive") } - /// Returns the strides of the underlying NDArray. - pub fn strides(&self) -> Option<&[usize]> { - unsafe { - let sz = self.ndim() * mem::size_of::(); - let strides_ptr = self.as_dltensor().strides as *const usize; - let slc = slice::from_raw_parts(strides_ptr, sz); - Some(slc) - } - } - /// Shows whether the underlying ndarray is contiguous in memory or not. pub fn is_contiguous(&self) -> Result { Ok(match self.strides() { @@ -187,7 +189,6 @@ impl NDArray { Some(strides) => { // NDArrayError::MissingShape in case shape is not determined self.shape() - .ok_or(NDArrayError::MissingShape)? .iter() .zip(strides) .rfold( @@ -195,7 +196,7 @@ impl NDArray { |(is_contig, expected_stride), (shape, stride)| { ( is_contig && *stride == expected_stride, - expected_stride * (*shape as usize), + expected_stride * shape, ) }, ) @@ -220,26 +221,19 @@ impl NDArray { /// let ctx = Context::cpu(0); /// let mut ndarray = NDArray::empty(&mut shape, ctx, DataType::from_str("int32").unwrap()); /// ndarray.copy_from_buffer(&mut data); - /// assert_eq!(ndarray.shape(), Some(&mut shape[..])); + /// assert_eq!(ndarray.shape(), shape); /// assert_eq!(ndarray.to_vec::().unwrap(), data); /// ``` pub fn to_vec(&self) -> Result, NDArrayError> { - if !self.shape().is_some() { - return Err(NDArrayError::EmptyArray); - } - let earr = NDArray::empty( - self.shape().ok_or(NDArrayError::MissingShape)?, - Context::cpu(0), - self.dtype(), - ); + let earr = NDArray::empty(self.shape(), Context::cpu(0), self.dtype()); let target = self.copy_to_ndarray(earr)?; let arr = target.as_dltensor(); - let sz = self.size(); - let mut v: Vec = Vec::with_capacity(sz / mem::size_of::()); + let len = self.len(); + let mut v: Vec = Vec::with_capacity(len); unsafe { v.as_mut_ptr() - .copy_from_nonoverlapping(arr.data as *const T, sz); - v.set_len(sz); + .copy_from_nonoverlapping(arr.data as *const T, len); + v.set_len(len); } Ok(v) } @@ -294,11 +288,7 @@ impl NDArray { /// Copies the NDArray to a target context. pub fn copy_to_ctx(&self, target: &Context) -> Result { - let tmp = NDArray::empty( - self.shape().ok_or(NDArrayError::MissingShape)?, - *target, - self.dtype(), - ); + let tmp = NDArray::empty(self.shape(), *target, self.dtype()); let copy = self.copy_to_ndarray(tmp)?; Ok(copy) } @@ -333,7 +323,6 @@ impl NDArray { ctx.device_id as c_int, &mut handle as *mut _, )); - println!("{:?}", handle); let ptr = NDArrayContainer::from_raw(handle) .map(|o| o.downcast().expect("this should never fail")); NDArray(ptr) @@ -346,14 +335,8 @@ macro_rules! impl_from_ndarray_rustndarray { type Error = NDArrayError; fn try_from(nd: &NDArray) -> Result, Self::Error> { - if !nd.shape().is_some() { - return Err(NDArrayError::MissingShape); - } assert_eq!(nd.dtype(), DataType::from_str($type_name)?, "Type mismatch"); - Ok(Array::from_shape_vec( - &*nd.shape().ok_or(NDArrayError::MissingShape)?, - nd.to_vec::<$type>()?, - )?) + Ok(Array::from_shape_vec(&*nd.shape(), nd.to_vec::<$type>()?)?) } } @@ -361,14 +344,8 @@ macro_rules! impl_from_ndarray_rustndarray { type Error = NDArrayError; fn try_from(nd: &mut NDArray) -> Result, Self::Error> { - if !nd.shape().is_some() { - return Err(NDArrayError::MissingShape); - }; assert_eq!(nd.dtype(), DataType::from_str($type_name)?, "Type mismatch"); - Ok(Array::from_shape_vec( - &*nd.shape().ok_or(NDArrayError::MissingShape)?, - nd.to_vec::<$type>()?, - )?) + Ok(Array::from_shape_vec(&*nd.shape(), nd.to_vec::<$type>()?)?) } } }; @@ -405,16 +382,13 @@ mod tests { #[test] fn basics() { - let shape = &mut [1, 2, 3]; + let shape = &[1, 2, 3]; let ctx = Context::cpu(0); println!("before empty"); let ndarray = NDArray::empty(shape, ctx, DataType::from_str("int32").unwrap()); println!("after empty"); - assert_eq!(ndarray.shape().unwrap(), shape); - assert_eq!( - ndarray.size(), - shape.to_vec().into_iter().product() - ); + assert_eq!(ndarray.shape(), shape); + assert_eq!(ndarray.len(), shape.iter().product()); assert_eq!(ndarray.ndim(), 3); assert!(ndarray.strides().is_none()); assert_eq!(ndarray.byte_offset(), 0); @@ -422,13 +396,13 @@ mod tests { #[test] fn copy() { - let shape = &mut [4]; + let shape = &[4]; let mut data = vec![1i32, 2, 3, 4]; let ctx = Context::cpu(0); let mut ndarray = NDArray::empty(shape, ctx, DataType::from_str("int32").unwrap()); assert!(ndarray.to_vec::().is_ok()); ndarray.copy_from_buffer(&mut data); - assert_eq!(ndarray.shape().unwrap(), shape); + assert_eq!(ndarray.shape(), shape); assert_eq!(ndarray.to_vec::().unwrap(), data); assert_eq!(ndarray.ndim(), 1); assert!(ndarray.is_contiguous().is_ok()); @@ -464,7 +438,7 @@ mod tests { let nd = NDArray::from_rust_ndarray(&a, Context::cpu(0), DataType::from_str("float32").unwrap()) .unwrap(); - assert_eq!(nd.shape().unwrap(), &mut [2, 2]); + assert_eq!(nd.shape(), &[2, 2]); let rnd: ArrayD = ArrayD::try_from(&nd).unwrap(); assert!(rnd.all_close(&a, 1e-8f32)); } diff --git a/rust/tvm/examples/resnet/src/main.rs b/rust/tvm/examples/resnet/src/main.rs index 16ca8c7386f1..bd1554a6f63d 100644 --- a/rust/tvm/examples/resnet/src/main.rs +++ b/rust/tvm/examples/resnet/src/main.rs @@ -60,7 +60,7 @@ fn main() { let input = NDArray::from_rust_ndarray(&arr, Context::cpu(0), DataType::float(32, 1)).unwrap(); println!( "input size is {:?}", - input.shape().expect("cannot get the input shape") + input.shape(), ); let graph = fs::read_to_string(concat!(env!("CARGO_MANIFEST_DIR"), "/deploy_graph.json")).unwrap(); diff --git a/rust/tvm/src/ir/relay/mod.rs b/rust/tvm/src/ir/relay/mod.rs index b68602b2ee17..4f542b4d8785 100644 --- a/rust/tvm/src/ir/relay/mod.rs +++ b/rust/tvm/src/ir/relay/mod.rs @@ -561,7 +561,7 @@ def @main() -> float32 { assert_eq!(constant.data.dtype(), "float32".parse().unwrap(),); assert_eq!(constant.data.len(), 1); assert_eq!(constant.data.size(), 4); - assert_eq!(constant.data.shape(), None); + assert_eq!(constant.data.shape(), &[]); Ok(()) } } diff --git a/rust/tvm/tests/basics/src/main.rs b/rust/tvm/tests/basics/src/main.rs index 9a21ff02b2f2..e4249a491746 100644 --- a/rust/tvm/tests/basics/src/main.rs +++ b/rust/tvm/tests/basics/src/main.rs @@ -33,7 +33,7 @@ fn main() { let dtype = DataType::from_str("float32").unwrap(); let mut arr = NDArray::empty(shape, ctx, dtype); arr.copy_from_buffer(data.as_mut_slice()); - let mut ret = NDArray::empty(shape, ctx, dtype); + let ret = NDArray::empty(shape, ctx, dtype); let mut fadd = Module::load(&concat!(env!("OUT_DIR"), "/test_add.so")).unwrap(); if !fadd.enabled(ctx_name) { return; @@ -44,7 +44,7 @@ fn main() { fadd.entry() .expect("module must have entry point") - .invoke(vec![(&arr).into(), (&arr).into(), ret.into()]) + .invoke(vec![(&arr).into(), (&arr).into(), (&ret).into()]) .unwrap(); assert_eq!(ret.to_vec::().unwrap(), vec![6f32, 8.0]); diff --git a/rust/tvm/tests/callback/src/bin/array.rs b/rust/tvm/tests/callback/src/bin/array.rs index ad41bd18ec8b..f9db91881e1e 100644 --- a/rust/tvm/tests/callback/src/bin/array.rs +++ b/rust/tvm/tests/callback/src/bin/array.rs @@ -37,8 +37,8 @@ use tvm::{ fn main() { fn sum(args: Vec>) -> Result { let mut ret = 0f32; - let shape = &mut [2]; - for arg in args.iter() { + let shape = &[2]; + for arg in args { let e = NDArray::empty(shape, Context::cpu(0), DataType::float(32, 1)); let arg: NDArray = arg.try_into()?; let arr = arg.copy_to_ndarray(e)?; From 15f88fdb1d8e49a1649cd19e0a154eabe57fa6c4 Mon Sep 17 00:00:00 2001 From: Max Willsey Date: Fri, 25 Sep 2020 10:29:06 -0700 Subject: [PATCH 24/50] Fix some more ndarray stuff --- rust/tvm-rt/src/ndarray.rs | 100 +++++++++++++++++---------- rust/tvm-rt/src/object/object_ptr.rs | 5 +- 2 files changed, 66 insertions(+), 39 deletions(-) diff --git a/rust/tvm-rt/src/ndarray.rs b/rust/tvm-rt/src/ndarray.rs index 1b764270d48f..c0c8637267dc 100644 --- a/rust/tvm-rt/src/ndarray.rs +++ b/rust/tvm-rt/src/ndarray.rs @@ -47,12 +47,13 @@ //! [`copy_from_buffer`]:struct.NDArray.html#method.copy_from_buffer //! [`copy_to_ctx`]:struct.NDArray.html#method.copy_to_ctx -use std::convert::TryInto; use std::ffi::c_void; +use std::{borrow::Cow, convert::TryInto}; use std::{convert::TryFrom, mem, os::raw::c_int, ptr, slice, str::FromStr}; use tvm_macros::Object; use tvm_sys::ffi::DLTensor; +use mem::size_of; use tvm_sys::{ffi, ByteArray, Context, DataType}; use ndarray::{Array, ArrayD}; @@ -79,8 +80,7 @@ impl NDArrayContainer { pub(crate) fn from_raw(handle: ffi::TVMArrayHandle) -> Option> { let base_offset = memoffset::offset_of!(NDArrayContainer, dl_tensor) as isize; let base_ptr = unsafe { (handle as *mut i8).offset(-base_offset) }; - let object: *mut Object = unsafe { std::mem::transmute(base_ptr) }; - let object_ptr = ObjectPtr::from_raw(object); + let object_ptr = ObjectPtr::from_raw(base_ptr.cast()); object_ptr.map(|ptr| { ptr.downcast::() .expect("we know this is an NDArray container") @@ -96,11 +96,27 @@ impl NDArrayContainer { &mut *std::mem::ManuallyDrop::new(object_ptr) .ptr .as_ptr() + .cast::() .offset(base_offset) + .cast::() } } } +fn cow_usize<'a>(slice: &[i64]) -> Cow<'a, [usize]> { + if std::mem::size_of::() == 64 { + debug_assert!(slice.iter().all(|&x| x >= 0)); + let shape: &[usize] = unsafe { std::mem::transmute(slice) }; + Cow::Borrowed(shape) + } else { + let shape: Vec = slice + .iter() + .map(|&x| usize::try_from(x).unwrap_or_else(|_| panic!("Cannot fit into usize: {}", x))) + .collect(); + Cow::Owned(shape) + } +} + impl NDArray { pub(crate) fn from_raw(handle: ffi::TVMArrayHandle) -> Self { let ptr = NDArrayContainer::from_raw(handle); @@ -121,33 +137,41 @@ impl NDArray { } /// Returns the shape of the NDArray. - pub fn shape(&self) -> &[usize] { + pub fn shape(&self) -> &[i64] { let arr = self.as_dltensor(); if arr.shape.is_null() || arr.data.is_null() { &[] } else { - let s = unsafe { slice::from_raw_parts(arr.shape.cast(), self.ndim()) }; - debug_assert!(s.iter().all(|&x| x as i64 >= 0), "negative shape: {:?}", s); - s + unsafe { slice::from_raw_parts(arr.shape, self.ndim()) } } } + /// Returns the shape of the NDArray as a &[usize] + /// + /// On 64-bit platforms, this is zero-cost and uses the shape from the DLTensor. + /// On other platforms, this copies into a buffer. + pub fn shape_usize(&self) -> Cow<[usize]> { + cow_usize(self.shape()) + } + /// Returns the strides of the underlying NDArray. - pub fn strides(&self) -> Option<&[usize]> { + pub fn strides(&self) -> Option<&[i64]> { let arr = self.as_dltensor(); if arr.strides.is_null() { None } else { - let s = unsafe { slice::from_raw_parts(arr.strides.cast(), self.ndim()) }; - debug_assert!( - s.iter().all(|&x| x as i64 >= 0), - "negative strides: {:?}", - s - ); - Some(s) + Some(unsafe { slice::from_raw_parts(arr.strides, self.ndim()) }) } } + /// Returns the strides of the NDArray as a &[usize] + /// + /// On 64-bit platforms, this is zero-cost and uses the strides from the DLTensor. + /// On other platforms, this copies into a buffer. + pub fn strides_usize(&self) -> Option> { + self.strides().map(cow_usize) + } + /// Returns true if the tensor is empty pub fn is_empty(&self) -> bool { self.as_dltensor().data.is_null() @@ -155,7 +179,8 @@ impl NDArray { /// Returns the total number of entries of the NDArray. pub fn len(&self) -> usize { - self.shape().iter().product() + let len: i64 = self.shape().iter().product(); + usize::try_from(len).unwrap_or_else(|_| panic!("bad len: {}", len)) } /// Returns the total bytes taken up by the data. @@ -225,17 +250,14 @@ impl NDArray { /// assert_eq!(ndarray.to_vec::().unwrap(), data); /// ``` pub fn to_vec(&self) -> Result, NDArrayError> { - let earr = NDArray::empty(self.shape(), Context::cpu(0), self.dtype()); - let target = self.copy_to_ndarray(earr)?; - let arr = target.as_dltensor(); - let len = self.len(); - let mut v: Vec = Vec::with_capacity(len); + let n = self.size() / size_of::(); + let mut vec: Vec = Vec::with_capacity(n); + let ptr = vec.as_mut_ptr(); unsafe { - v.as_mut_ptr() - .copy_from_nonoverlapping(arr.data as *const T, len); - v.set_len(len); - } - Ok(v) + ptr.copy_from_nonoverlapping(self.as_dltensor().data.cast(), n); + vec.set_len(n) + }; + Ok(vec) } /// Converts the NDArray to [`ByteArray`]. @@ -260,7 +282,7 @@ impl NDArray { /// /// *Note*: if something goes wrong during the copy, it will panic /// from TVM side. See `TVMArrayCopyFromBytes` in `include/tvm/runtime/c_runtime_api.h`. - pub fn copy_from_buffer(&mut self, data: &mut [T]) { + pub fn copy_from_buffer(&mut self, data: &[T]) { check_call!(ffi::TVMArrayCopyFromBytes( self.as_raw_dltensor(), data.as_ptr() as *mut _, @@ -299,7 +321,7 @@ impl NDArray { ctx: Context, dtype: DataType, ) -> Result { - let shape = rnd.shape().to_vec(); + let shape: Vec = rnd.shape().iter().map(|&x| x as i64).collect(); let mut nd = NDArray::empty(&shape, ctx, dtype); let mut buf = Array::from_iter(rnd.into_iter().map(|&v| v as T)); nd.copy_from_buffer( @@ -310,11 +332,11 @@ impl NDArray { } /// Allocates and creates an empty NDArray given the shape, context and dtype. - pub fn empty(shape: &[usize], ctx: Context, dtype: DataType) -> NDArray { + pub fn empty(shape: &[i64], ctx: Context, dtype: DataType) -> NDArray { let mut handle = ptr::null_mut() as ffi::TVMArrayHandle; let dtype: tvm_sys::ffi::DLDataType = dtype.into(); check_call!(ffi::TVMArrayAlloc( - shape.as_ptr() as *const i64, + shape.as_ptr(), shape.len() as c_int, i32::from(dtype.code) as c_int, i32::from(dtype.bits) as c_int, @@ -336,7 +358,10 @@ macro_rules! impl_from_ndarray_rustndarray { fn try_from(nd: &NDArray) -> Result, Self::Error> { assert_eq!(nd.dtype(), DataType::from_str($type_name)?, "Type mismatch"); - Ok(Array::from_shape_vec(&*nd.shape(), nd.to_vec::<$type>()?)?) + Ok(Array::from_shape_vec( + &*nd.shape_usize(), + nd.to_vec::<$type>()?, + )?) } } @@ -345,7 +370,10 @@ macro_rules! impl_from_ndarray_rustndarray { fn try_from(nd: &mut NDArray) -> Result, Self::Error> { assert_eq!(nd.dtype(), DataType::from_str($type_name)?, "Type mismatch"); - Ok(Array::from_shape_vec(&*nd.shape(), nd.to_vec::<$type>()?)?) + Ok(Array::from_shape_vec( + &*nd.shape_usize(), + nd.to_vec::<$type>()?, + )?) } } }; @@ -388,7 +416,7 @@ mod tests { let ndarray = NDArray::empty(shape, ctx, DataType::from_str("int32").unwrap()); println!("after empty"); assert_eq!(ndarray.shape(), shape); - assert_eq!(ndarray.len(), shape.iter().product()); + assert_eq!(ndarray.len(), shape.iter().product::() as usize); assert_eq!(ndarray.ndim(), 3); assert!(ndarray.strides().is_none()); assert_eq!(ndarray.byte_offset(), 0); @@ -397,11 +425,11 @@ mod tests { #[test] fn copy() { let shape = &[4]; - let mut data = vec![1i32, 2, 3, 4]; + let data = vec![1i32, 2, 3, 4]; let ctx = Context::cpu(0); let mut ndarray = NDArray::empty(shape, ctx, DataType::from_str("int32").unwrap()); - assert!(ndarray.to_vec::().is_ok()); - ndarray.copy_from_buffer(&mut data); + assert_eq!(ndarray.to_vec::().unwrap(), vec![0, 0, 0, 0]); + ndarray.copy_from_buffer(&data); assert_eq!(ndarray.shape(), shape); assert_eq!(ndarray.to_vec::().unwrap(), data); assert_eq!(ndarray.ndim(), 1); diff --git a/rust/tvm-rt/src/object/object_ptr.rs b/rust/tvm-rt/src/object/object_ptr.rs index aa65b193931f..77254d2fbca2 100644 --- a/rust/tvm-rt/src/object/object_ptr.rs +++ b/rust/tvm-rt/src/object/object_ptr.rs @@ -300,9 +300,8 @@ impl<'a, T: IsObject> From> for ArgValue<'a> { fn from(object_ptr: ObjectPtr) -> ArgValue<'a> { debug_assert!(object_ptr.count() >= 1); let object_ptr = object_ptr.upcast::(); - let index = object_ptr.type_index; - match index { - tvm_sys::ffi::TVMArgTypeCode_kTVMNDArrayHandle => { + match T::TYPE_KEY { + "runtime.NDArray" => { use crate::ndarray::NDArrayContainer; // TODO(this is probably not optimal) let raw_ptr = NDArrayContainer::leak(object_ptr.downcast().unwrap()) From a435cb82b86792d3837c33fd12d99e102b8a1ada Mon Sep 17 00:00:00 2001 From: Max Willsey Date: Fri, 25 Sep 2020 10:31:29 -0700 Subject: [PATCH 25/50] Get the resnet demo kinda working --- rust/tvm/examples/resnet/build.rs | 1 + rust/tvm/examples/resnet/src/main.rs | 42 +++++++++++++++------------- 2 files changed, 24 insertions(+), 19 deletions(-) diff --git a/rust/tvm/examples/resnet/build.rs b/rust/tvm/examples/resnet/build.rs index b259a626eb5e..d89be896dfcc 100644 --- a/rust/tvm/examples/resnet/build.rs +++ b/rust/tvm/examples/resnet/build.rs @@ -24,6 +24,7 @@ fn main() -> Result<()> { let output = Command::new("python3") .arg(concat!(env!("CARGO_MANIFEST_DIR"), "/src/build_resnet.py")) .arg(&format!("--build-dir={}", env!("CARGO_MANIFEST_DIR"))) + .arg("--pretrained") .output() .with_context(|| anyhow::anyhow!("failed to run python3"))?; diff --git a/rust/tvm/examples/resnet/src/main.rs b/rust/tvm/examples/resnet/src/main.rs index bd1554a6f63d..c8c8aa0faff8 100644 --- a/rust/tvm/examples/resnet/src/main.rs +++ b/rust/tvm/examples/resnet/src/main.rs @@ -32,6 +32,7 @@ use tvm::*; fn main() { let ctx = Context::cpu(0); + println!("{}", concat!(env!("CARGO_MANIFEST_DIR"), "/cat.png")); let img = image::open(concat!(env!("CARGO_MANIFEST_DIR"), "/cat.png")).unwrap(); println!("original image dimensions: {:?}", img.dimensions()); // for bigger size images, one needs to first resize to 256x256 @@ -59,8 +60,10 @@ fn main() { // create input tensor from rust's ndarray let input = NDArray::from_rust_ndarray(&arr, Context::cpu(0), DataType::float(32, 1)).unwrap(); println!( - "input size is {:?}", + "input shape is {:?}, len: {}, size: {}", input.shape(), + input.len(), + input.size(), ); let graph = fs::read_to_string(concat!(env!("CARGO_MANIFEST_DIR"), "/deploy_graph.json")).unwrap(); @@ -76,7 +79,8 @@ fn main() { graph.into(), (&lib).into(), (&ctx.device_type).into(), - (&ctx.device_id).into(), + // NOTE you must pass the device id in as i32 because that's what TVM expects + (ctx.device_id as i32).into(), ]); // get graph runtime module @@ -89,6 +93,7 @@ fn main() { // parse parameters and convert to TVMByteArray let params: Vec = fs::read(concat!(env!("CARGO_MANIFEST_DIR"), "/deploy_param.params")).unwrap(); + println!("param bytes: {}", params.len()); let barr = ByteArray::from(¶ms); // load the parameters load_param_fn.invoke(vec![(&barr).into()]).unwrap(); @@ -98,7 +103,7 @@ fn main() { .unwrap(); set_input_fn - .invoke(vec!["data".into(), (&input).into()]) + .invoke(vec!["data".into(), input.into()]) .unwrap(); // get `run` function from runtime module @@ -106,7 +111,7 @@ fn main() { // execute the run function. Note that it has no argument run_fn.invoke(vec![]).unwrap(); // prepare to get the output - let output_shape = &mut [1, 1000]; + let output_shape = &[1, 1000]; let output = NDArray::empty(output_shape, Context::cpu(0), DataType::float(32, 1)); // get the `get_output` function from runtime module let ref get_output_fn = graph_runtime_module @@ -114,21 +119,20 @@ fn main() { .unwrap(); // execute the get output function get_output_fn - .invoke(vec![(&0).into(), (&output).into()]) + .invoke(vec![0.into(), (&output).into()]) .unwrap(); // flatten the output as Vec let output = output.to_vec::().unwrap(); // find the maximum entry in the output and its index - let mut argmax = -1; - let mut max_prob = 0.; - for i in 0..output.len() { - if output[i] > max_prob { - max_prob = output[i]; - argmax = i as i32; - } - } + let (argmax, max_prob) = output + .iter() + .copied() + .enumerate() + .max_by(|a, b| a.1.partial_cmp(&b.1).unwrap()) + .unwrap(); + // create a hash map of (class id, class name) - let mut synset: HashMap = HashMap::new(); + let mut synset: HashMap = HashMap::new(); let file = File::open("synset.csv").unwrap(); let mut rdr = csv::ReaderBuilder::new() .has_headers(true) @@ -136,16 +140,16 @@ fn main() { for result in rdr.records() { let record = result.unwrap(); - let id: i32 = record[0].parse().unwrap(); + let id: usize = record[0].parse().unwrap(); let cls = record[1].to_string(); synset.insert(id, cls); } + let label = synset + .get(&argmax) + .expect("cannot find the class id for argmax"); println!( "input image belongs to the class `{}` with probability {}", - synset - .get(&argmax) - .expect("cannot find the class id for argmax"), - max_prob + label, max_prob ); } From d119570dabe22428f9bfe279f2af5c6a7e69c804 Mon Sep 17 00:00:00 2001 From: Gus Smith Date: Thu, 24 Sep 2020 15:11:31 -0700 Subject: [PATCH 26/50] Add SoftmaxAttrs Rust bindings --- rust/tvm/src/ir/relay/attrs/nn.rs | 9 +++++++++ 1 file changed, 9 insertions(+) diff --git a/rust/tvm/src/ir/relay/attrs/nn.rs b/rust/tvm/src/ir/relay/attrs/nn.rs index d648baf7eb10..1e2a9bffc20b 100644 --- a/rust/tvm/src/ir/relay/attrs/nn.rs +++ b/rust/tvm/src/ir/relay/attrs/nn.rs @@ -66,3 +66,12 @@ pub struct MaxPool2DAttrsNode { pub layout: TString, pub ceil_mode: bool, } + +#[repr(C)] +#[derive(Object)] +#[ref_name = "SoftmaxAttrs"] +#[type_key = "relay.attrs.SoftmaxAttrs"] +pub struct SoftmaxAttrsNode { + pub base: BaseAttrsNode, + pub axis: i32, +} From f48282ae62ee38eb6f36d6124fe8dffbf0c88311 Mon Sep 17 00:00:00 2001 From: Gus Smith Date: Fri, 25 Sep 2020 15:34:11 -0700 Subject: [PATCH 27/50] Implement Hash and Eq for Relay Exprs --- rust/tvm/src/ir/relay/mod.rs | 16 ++++++++++++++++ 1 file changed, 16 insertions(+) diff --git a/rust/tvm/src/ir/relay/mod.rs b/rust/tvm/src/ir/relay/mod.rs index 4f542b4d8785..e539221d1db6 100644 --- a/rust/tvm/src/ir/relay/mod.rs +++ b/rust/tvm/src/ir/relay/mod.rs @@ -19,6 +19,8 @@ pub mod attrs; +use std::hash::Hash; + use crate::runtime::array::Array; use crate::runtime::{object::*, String as TString}; @@ -55,6 +57,20 @@ impl ExprNode { } } +impl Hash for Expr { + fn hash(&self, state: &mut H) { + self.as_ptr().unwrap().ptr.hash(state) + } +} + +impl PartialEq for Expr { + fn eq(&self, other: &Self) -> bool { + self.as_ptr().unwrap().ptr.eq(&other.as_ptr().unwrap().ptr) + } +} + +impl Eq for Expr {} + #[repr(C)] #[derive(Object)] #[ref_name = "Id"] From 497904126b219ca1b5ab52cdea5410250712a9cd Mon Sep 17 00:00:00 2001 From: Gus Smith Date: Fri, 25 Sep 2020 15:41:01 -0700 Subject: [PATCH 28/50] Add underscore to unused function --- rust/tvm-rt/src/ndarray.rs | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/rust/tvm-rt/src/ndarray.rs b/rust/tvm-rt/src/ndarray.rs index c0c8637267dc..6845689067af 100644 --- a/rust/tvm-rt/src/ndarray.rs +++ b/rust/tvm-rt/src/ndarray.rs @@ -118,7 +118,7 @@ fn cow_usize<'a>(slice: &[i64]) -> Cow<'a, [usize]> { } impl NDArray { - pub(crate) fn from_raw(handle: ffi::TVMArrayHandle) -> Self { + pub(crate) fn _from_raw(handle: ffi::TVMArrayHandle) -> Self { let ptr = NDArrayContainer::from_raw(handle); NDArray(ptr) } From 0ef7e3e6d47f078595b848ab36e3f284728d3ee2 Mon Sep 17 00:00:00 2001 From: Jared Roesch Date: Fri, 25 Sep 2020 16:44:05 -0700 Subject: [PATCH 29/50] Fix broken ass resnet script --- include/tvm/runtime/ndarray.h | 2 - rust/tvm/examples/resnet/build.rs | 1 - rust/tvm/examples/resnet/src/build_resnet.py | 142 +++++++++---------- src/runtime/ndarray.cc | 1 - 4 files changed, 66 insertions(+), 80 deletions(-) diff --git a/include/tvm/runtime/ndarray.h b/include/tvm/runtime/ndarray.h index de6e78f2dde9..f6f74f1aa1d2 100644 --- a/include/tvm/runtime/ndarray.h +++ b/include/tvm/runtime/ndarray.h @@ -373,10 +373,8 @@ inline ObjectPtr NDArray::FFIDataFromHandle(TVMArrayHandle handle) { inline TVMArrayHandle NDArray::FFIGetHandle(const ObjectRef& nd) { // NOTE: it is necessary to cast to container then to base // so that the FFI handle uses the ContainerBase address. - std::cout << "Object: " << const_cast(nd.get()) << std::endl; auto ptr = reinterpret_cast(static_cast( static_cast(const_cast(nd.get())))); - std::cout << "Ptr: " << ptr << std::endl; return ptr; } diff --git a/rust/tvm/examples/resnet/build.rs b/rust/tvm/examples/resnet/build.rs index d89be896dfcc..b259a626eb5e 100644 --- a/rust/tvm/examples/resnet/build.rs +++ b/rust/tvm/examples/resnet/build.rs @@ -24,7 +24,6 @@ fn main() -> Result<()> { let output = Command::new("python3") .arg(concat!(env!("CARGO_MANIFEST_DIR"), "/src/build_resnet.py")) .arg(&format!("--build-dir={}", env!("CARGO_MANIFEST_DIR"))) - .arg("--pretrained") .output() .with_context(|| anyhow::anyhow!("failed to run python3"))?; diff --git a/rust/tvm/examples/resnet/src/build_resnet.py b/rust/tvm/examples/resnet/src/build_resnet.py index 904f244e0a9a..5a9f9d829c15 100644 --- a/rust/tvm/examples/resnet/src/build_resnet.py +++ b/rust/tvm/examples/resnet/src/build_resnet.py @@ -21,6 +21,7 @@ import logging from os import path as osp import sys +import shutil import numpy as np @@ -29,121 +30,110 @@ from tvm import relay from tvm.relay import testing from tvm.contrib import graph_runtime, cc +from PIL import Image +from tvm.contrib.download import download_testdata +from mxnet.gluon.model_zoo.vision import get_model -logging.basicConfig( - level=logging.INFO, format="%(asctime)s - %(name)s - %(levelname)s - %(message)s" -) +logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(name)s - %(levelname)s - %(message)s') logger = logging.getLogger(__name__) -parser = argparse.ArgumentParser(description="Resnet build example") +parser = argparse.ArgumentParser(description='Resnet build example') aa = parser.add_argument -aa("--build-dir", type=str, required=True, help="directory to put the build artifacts") -aa("--pretrained", action="store_true", help="use a pretrained resnet") -aa("--batch-size", type=int, default=1, help="input image batch size") -aa( - "--opt-level", - type=int, - default=3, - help="level of optimization. 0 is unoptimized and 3 is the highest level", -) -aa("--target", type=str, default="llvm", help="target context for compilation") -aa("--image-shape", type=str, default="3,224,224", help="input image dimensions") -aa("--image-name", type=str, default="cat.png", help="name of input image to download") +aa('--build-dir', type=str, required=True, help='directory to put the build artifacts') +aa('--batch-size', type=int, default=1, help='input image batch size') +aa('--opt-level', type=int, default=3, + help='level of optimization. 0 is unoptimized and 3 is the highest level') +aa('--target', type=str, default='llvm', help='target context for compilation') +aa('--image-shape', type=str, default='3,224,224', help='input image dimensions') +aa('--image-name', type=str, default='cat.png', help='name of input image to download') args = parser.parse_args() build_dir = args.build_dir batch_size = args.batch_size opt_level = args.opt_level -target = tvm.target.Target(args.target) +target = tvm.target.create(args.target) image_shape = tuple(map(int, args.image_shape.split(","))) data_shape = (batch_size,) + image_shape - def build(target_dir): """ Compiles resnet18 with TVM""" - deploy_lib = osp.join(target_dir, "deploy_lib.o") - if osp.exists(deploy_lib): - return - - if args.pretrained: - # needs mxnet installed - from mxnet.gluon.model_zoo.vision import get_model - - # if `--pretrained` is enabled, it downloads a pretrained - # resnet18 trained on imagenet1k dataset for image classification task - block = get_model("resnet18_v1", pretrained=True) - net, params = relay.frontend.from_mxnet(block, {"data": data_shape}) - # we want a probability so add a softmax operator - func = net["main"] - net = relay.Function( - func.params, relay.nn.softmax(func.body), None, func.type_params, func.attrs - ) - else: - # use random weights from relay.testing - net, params = relay.testing.resnet.get_workload( - num_layers=18, batch_size=batch_size, image_shape=image_shape - ) - - # compile the model - with tvm.transform.PassContext(opt_level=opt_level): - graph, lib, params = relay.build_module.build(net, target, params=params) + # Download the pretrained model in MxNet's format. + block = get_model("resnet18_v1", pretrained=True) + + shape_dict = {"data": (1, 3, 224, 224) } + mod, params = relay.frontend.from_mxnet(block, shape_dict) + # Add softmax to do classification in last layer. + func = mod["main"] + func = relay.Function(func.params, relay.nn.softmax(func.body), None, func.type_params, func.attrs) + + target = "llvm" + with tvm.transform.PassContext(opt_level=3): + graph, lib, params = relay.build(func, target, params=params) # save the model artifacts + deploy_lib = osp.join(target_dir, 'deploy_lib.o') lib.save(deploy_lib) - cc.create_shared(osp.join(target_dir, "deploy_lib.so"), [osp.join(target_dir, "deploy_lib.o")]) + cc.create_shared(osp.join(target_dir, "deploy_lib.so"), + [osp.join(target_dir, "deploy_lib.o")]) with open(osp.join(target_dir, "deploy_graph.json"), "w") as fo: fo.write(graph) - with open(osp.join(target_dir, "deploy_param.params"), "wb") as fo: + with open(osp.join(target_dir,"deploy_param.params"), "wb") as fo: fo.write(relay.save_param_dict(params)) - def download_img_labels(): """ Download an image and imagenet1k class labels for test""" from mxnet.gluon.utils import download - img_name = "cat.png" - synset_url = "".join( - [ - "https://gist.githubusercontent.com/zhreshold/", - "4d0b62f3d01426887599d4f7ede23ee5/raw/", - "596b27d23537e5a1b5751d2b0481ef172f58b539/", - "imagenet1000_clsid_to_human.txt", - ] - ) - synset_name = "synset.txt" - download("https://github.com/dmlc/mxnet.js/blob/master/data/cat.png?raw=true", img_name) - download(synset_url, synset_name) - - with open(synset_name) as fin: + synset_url = ''.join(['https://gist.githubusercontent.com/zhreshold/', + '4d0b62f3d01426887599d4f7ede23ee5/raw/', + '596b27d23537e5a1b5751d2b0481ef172f58b539/', + 'imagenet1000_clsid_to_human.txt']) + synset_name = 'synset.txt' + synset_path = download_testdata(synset_url, synset_name, module="data") + + with open(synset_path) as fin: synset = eval(fin.read()) - with open("synset.csv", "w") as fout: - w = csv.writer(fout) - w.writerows(synset.items()) + with open(synset_name, 'w') as f: + for key in synset: + f.write(synset[key]) + f.write("\n") + + return synset +def transform_image(image): + image = np.array(image) - np.array([123.0, 117.0, 104.0]) + image /= np.array([58.395, 57.12, 57.375]) + image = image.transpose((2, 0, 1)) + image = image[np.newaxis, :] + return image + +def get_cat_image(): + img_url = "https://github.com/dmlc/mxnet.js/blob/master/data/cat.png?raw=true" + img_path = download_testdata(img_url, "cat.png", module="data") + shutil.copyfile(img_path, "cat.png") + img = Image.open(img_path).resize((224, 224)) + return transform_image(img) def test_build(build_dir): - """ Sanity check with random input""" + """ Sanity check with the cat image we download.""" graph = open(osp.join(build_dir, "deploy_graph.json")).read() lib = tvm.runtime.load_module(osp.join(build_dir, "deploy_lib.so")) - params = bytearray(open(osp.join(build_dir, "deploy_param.params"), "rb").read()) - input_data = tvm.nd.array(np.random.uniform(size=data_shape).astype("float32")) + params = bytearray(open(osp.join(build_dir,"deploy_param.params"), "rb").read()) + input_data = get_cat_image() ctx = tvm.cpu() module = graph_runtime.create(graph, lib, ctx) module.load_params(params) module.run(data=input_data) out = module.get_output(0).asnumpy() + top1 = np.argmax(out[0]) + synset = download_img_labels() + print("TVM prediction top-1:", top1, synset[top1]) - -if __name__ == "__main__": - logger.info("building the model") +if __name__ == '__main__': + logger.info("Compiling the model to graph runtime.") build(build_dir) - logger.info("build was successful") - logger.info("test the build artifacts") + logger.info("Testing the model's predication on test data.") test_build(build_dir) - logger.info("test was successful") - if args.pretrained: - download_img_labels() - logger.info("image and synset downloads are successful") diff --git a/src/runtime/ndarray.cc b/src/runtime/ndarray.cc index b176b88985b6..c08d36ed5e79 100644 --- a/src/runtime/ndarray.cc +++ b/src/runtime/ndarray.cc @@ -270,7 +270,6 @@ int TVMArrayAlloc(const tvm_index_t* shape, int ndim, int dtype_code, int dtype_ auto ndarray = NDArray::Empty(std::vector(shape, shape + ndim), dtype, ctx); *out = NDArray::Internal::MoveToFFIHandle(ndarray); - std::flush(std::cout); API_END(); } From f01dcfc3341c04303401b89ff576bddf428c7c55 Mon Sep 17 00:00:00 2001 From: Max Willsey Date: Fri, 25 Sep 2020 16:19:45 -0700 Subject: [PATCH 30/50] Improve some ndarray conversions --- rust/tvm-rt/src/ndarray.rs | 50 ++++++++++++++++++++++++++------------ 1 file changed, 35 insertions(+), 15 deletions(-) diff --git a/rust/tvm-rt/src/ndarray.rs b/rust/tvm-rt/src/ndarray.rs index 6845689067af..9d4954889085 100644 --- a/rust/tvm-rt/src/ndarray.rs +++ b/rust/tvm-rt/src/ndarray.rs @@ -51,9 +51,9 @@ use std::ffi::c_void; use std::{borrow::Cow, convert::TryInto}; use std::{convert::TryFrom, mem, os::raw::c_int, ptr, slice, str::FromStr}; +use mem::size_of; use tvm_macros::Object; use tvm_sys::ffi::DLTensor; -use mem::size_of; use tvm_sys::{ffi, ByteArray, Context, DataType}; use ndarray::{Array, ArrayD}; @@ -208,8 +208,8 @@ impl NDArray { } /// Shows whether the underlying ndarray is contiguous in memory or not. - pub fn is_contiguous(&self) -> Result { - Ok(match self.strides() { + pub fn is_contiguous(&self) -> bool { + match self.strides() { None => true, Some(strides) => { // NDArrayError::MissingShape in case shape is not determined @@ -227,7 +227,7 @@ impl NDArray { ) .0 } - }) + } } pub fn byte_offset(&self) -> isize { @@ -252,11 +252,12 @@ impl NDArray { pub fn to_vec(&self) -> Result, NDArrayError> { let n = self.size() / size_of::(); let mut vec: Vec = Vec::with_capacity(n); + let ptr = vec.as_mut_ptr(); - unsafe { - ptr.copy_from_nonoverlapping(self.as_dltensor().data.cast(), n); - vec.set_len(n) - }; + let slice = unsafe { slice::from_raw_parts_mut(ptr, n) }; + self.copy_to_buffer(slice); + + unsafe { vec.set_len(n) }; Ok(vec) } @@ -290,6 +291,29 @@ impl NDArray { )); } + pub fn copy_to_buffer(&self, data: &mut [T]) { + assert_eq!(self.size(), data.len() * size_of::()); + check_call!(ffi::TVMArrayCopyToBytes( + self.as_raw_dltensor(), + data.as_ptr() as *mut _, + self.size(), + )); + } + + pub fn fill_from_iter(&mut self, iter: I) + where + T: Num32, + I: ExactSizeIterator, + { + assert!(self.is_contiguous()); + assert_eq!(self.size(), size_of::() * iter.len()); + let mut ptr: *mut T = self.as_dltensor().data.cast(); + iter.for_each(|x| unsafe { + ptr.write(x); + ptr = ptr.add(1); + }) + } + /// Copies the NDArray to another target NDArray. pub fn copy_to_ndarray(&self, target: NDArray) -> Result { if self.dtype() != target.dtype() { @@ -317,17 +341,13 @@ impl NDArray { /// Converts a Rust's ndarray to TVM NDArray. pub fn from_rust_ndarray( - rnd: &ArrayD, + input_nd: &ArrayD, ctx: Context, dtype: DataType, ) -> Result { - let shape: Vec = rnd.shape().iter().map(|&x| x as i64).collect(); + let shape: Vec = input_nd.shape().iter().map(|&x| x as i64).collect(); let mut nd = NDArray::empty(&shape, ctx, dtype); - let mut buf = Array::from_iter(rnd.into_iter().map(|&v| v as T)); - nd.copy_from_buffer( - buf.as_slice_mut() - .expect("Array from iter must be contiguous."), - ); + nd.fill_from_iter(input_nd.iter().copied()); Ok(nd) } From 84f864e4895d96813c8f6b002f8c56c43306ef83 Mon Sep 17 00:00:00 2001 From: Max Willsey Date: Fri, 25 Sep 2020 16:58:49 -0700 Subject: [PATCH 31/50] Make sure the build script runs correctly --- rust/tvm/examples/resnet/build.rs | 9 +++++++-- 1 file changed, 7 insertions(+), 2 deletions(-) diff --git a/rust/tvm/examples/resnet/build.rs b/rust/tvm/examples/resnet/build.rs index b259a626eb5e..9306b1e4928b 100644 --- a/rust/tvm/examples/resnet/build.rs +++ b/rust/tvm/examples/resnet/build.rs @@ -18,7 +18,7 @@ */ use anyhow::{Context, Result}; -use std::{path::Path, process::Command}; +use std::{io::Write, path::Path, process::Command}; fn main() -> Result<()> { let output = Command::new("python3") @@ -26,7 +26,12 @@ fn main() -> Result<()> { .arg(&format!("--build-dir={}", env!("CARGO_MANIFEST_DIR"))) .output() .with_context(|| anyhow::anyhow!("failed to run python3"))?; - + if !output.status.success() { + std::io::stdout() + .write_all(&output.stderr) + .expect("Failed to write error"); + panic!("Failed to execute build script"); + } assert!( Path::new(&format!("{}/deploy_lib.o", env!("CARGO_MANIFEST_DIR"))).exists(), "Could not prepare demo: {}", From c702cf4802700d9a6b298f3f84e0168c89595731 Mon Sep 17 00:00:00 2001 From: Jared Roesch Date: Fri, 25 Sep 2020 18:11:21 -0700 Subject: [PATCH 32/50] Clean up ResNet example tremedously Expose C++ graph runtime via cleaner Rust API rewrite example. --- rust/tvm/examples/resnet/build.rs | 8 +-- rust/tvm/examples/resnet/src/main.rs | 103 +++++++++------------------ rust/tvm/src/runtime/graph_rt.rs | 78 ++++++++++++++++++++ rust/tvm/src/runtime/mod.rs | 2 + 4 files changed, 117 insertions(+), 74 deletions(-) create mode 100644 rust/tvm/src/runtime/graph_rt.rs diff --git a/rust/tvm/examples/resnet/build.rs b/rust/tvm/examples/resnet/build.rs index 9306b1e4928b..d1f710567e04 100644 --- a/rust/tvm/examples/resnet/build.rs +++ b/rust/tvm/examples/resnet/build.rs @@ -17,15 +17,14 @@ * under the License. */ -use anyhow::{Context, Result}; use std::{io::Write, path::Path, process::Command}; -fn main() -> Result<()> { +fn main() { let output = Command::new("python3") .arg(concat!(env!("CARGO_MANIFEST_DIR"), "/src/build_resnet.py")) .arg(&format!("--build-dir={}", env!("CARGO_MANIFEST_DIR"))) .output() - .with_context(|| anyhow::anyhow!("failed to run python3"))?; + .expect("Failed to execute command"); if !output.status.success() { std::io::stdout() .write_all(&output.stderr) @@ -42,11 +41,8 @@ fn main() -> Result<()> { .last() .unwrap_or("") ); - println!( "cargo:rustc-link-search=native={}", env!("CARGO_MANIFEST_DIR") ); - - Ok(()) } diff --git a/rust/tvm/examples/resnet/src/main.rs b/rust/tvm/examples/resnet/src/main.rs index c8c8aa0faff8..039d5a1844f8 100644 --- a/rust/tvm/examples/resnet/src/main.rs +++ b/rust/tvm/examples/resnet/src/main.rs @@ -18,22 +18,24 @@ */ use std::{ - collections::HashMap, - convert::TryInto, fs::{self, File}, + io::{BufRead, BufReader}, path::Path, }; use ::ndarray::{Array, ArrayD, Axis}; use image::{FilterType, GenericImageView}; -use tvm::runtime::ByteArray; +use anyhow::Context as _; +use tvm::runtime::graph_rt::GraphRt; use tvm::*; -fn main() { +fn main() -> anyhow::Result<()> { let ctx = Context::cpu(0); println!("{}", concat!(env!("CARGO_MANIFEST_DIR"), "/cat.png")); - let img = image::open(concat!(env!("CARGO_MANIFEST_DIR"), "/cat.png")).unwrap(); + let img = image::open(concat!(env!("CARGO_MANIFEST_DIR"), "/cat.png")) + .context("Failed to open cat.png")?; + println!("original image dimensions: {:?}", img.dimensions()); // for bigger size images, one needs to first resize to 256x256 // with `img.resize_exact` method and then `image.crop` to 224x224 @@ -53,76 +55,47 @@ fn main() { } } - let arr = Array::from_shape_vec((224, 224, 3), pixels).unwrap(); + let arr = Array::from_shape_vec((224, 224, 3), pixels)?; let arr: ArrayD = arr.permuted_axes([2, 0, 1]).into_dyn(); // make arr shape as [1, 3, 224, 224] acceptable to resnet let arr = arr.insert_axis(Axis(0)); // create input tensor from rust's ndarray - let input = NDArray::from_rust_ndarray(&arr, Context::cpu(0), DataType::float(32, 1)).unwrap(); + let input = NDArray::from_rust_ndarray(&arr, Context::cpu(0), DataType::float(32, 1))?; println!( "input shape is {:?}, len: {}, size: {}", input.shape(), input.len(), input.size(), ); - let graph = - fs::read_to_string(concat!(env!("CARGO_MANIFEST_DIR"), "/deploy_graph.json")).unwrap(); + + let graph = fs::read_to_string(concat!(env!("CARGO_MANIFEST_DIR"), "/deploy_graph.json")) + .context("Failed to open graph")?; + // load the built module let lib = Module::load(&Path::new(concat!( env!("CARGO_MANIFEST_DIR"), "/deploy_lib.so" - ))) - .unwrap(); - // get the global TVM graph runtime function - let runtime_create_fn = Function::get("tvm.graph_runtime.create").unwrap(); - let runtime_create_fn_ret = runtime_create_fn.invoke(vec![ - graph.into(), - (&lib).into(), - (&ctx.device_type).into(), - // NOTE you must pass the device id in as i32 because that's what TVM expects - (ctx.device_id as i32).into(), - ]); - - // get graph runtime module - let graph_runtime_module: Module = runtime_create_fn_ret.unwrap().try_into().unwrap(); - - // get the registered `load_params` from runtime module - let ref load_param_fn = graph_runtime_module - .get_function("load_params", false) - .unwrap(); + )))?; + + let mut graph_rt = GraphRt::create_from_parts(&graph, lib, ctx)?; + // parse parameters and convert to TVMByteArray - let params: Vec = - fs::read(concat!(env!("CARGO_MANIFEST_DIR"), "/deploy_param.params")).unwrap(); + let params: Vec = fs::read(concat!(env!("CARGO_MANIFEST_DIR"), "/deploy_param.params"))?; + println!("param bytes: {}", params.len()); - let barr = ByteArray::from(¶ms); - // load the parameters - load_param_fn.invoke(vec![(&barr).into()]).unwrap(); - // get the set_input function - let ref set_input_fn = graph_runtime_module - .get_function("set_input", false) - .unwrap(); - set_input_fn - .invoke(vec!["data".into(), input.into()]) - .unwrap(); + graph_rt.load_params(¶ms)?; + graph_rt.set_input("data", input)?; + graph_rt.run()?; - // get `run` function from runtime module - let ref run_fn = graph_runtime_module.get_function("run", false).unwrap(); - // execute the run function. Note that it has no argument - run_fn.invoke(vec![]).unwrap(); // prepare to get the output let output_shape = &[1, 1000]; let output = NDArray::empty(output_shape, Context::cpu(0), DataType::float(32, 1)); - // get the `get_output` function from runtime module - let ref get_output_fn = graph_runtime_module - .get_function("get_output", false) - .unwrap(); - // execute the get output function - get_output_fn - .invoke(vec![0.into(), (&output).into()]) - .unwrap(); + graph_rt.get_output_into(0, output.clone())?; + // flatten the output as Vec - let output = output.to_vec::().unwrap(); + let output = output.to_vec::()?; + // find the maximum entry in the output and its index let (argmax, max_prob) = output .iter() @@ -132,24 +105,18 @@ fn main() { .unwrap(); // create a hash map of (class id, class name) - let mut synset: HashMap = HashMap::new(); - let file = File::open("synset.csv").unwrap(); - let mut rdr = csv::ReaderBuilder::new() - .has_headers(true) - .from_reader(file); - - for result in rdr.records() { - let record = result.unwrap(); - let id: usize = record[0].parse().unwrap(); - let cls = record[1].to_string(); - synset.insert(id, cls); - } + let file = File::open("synset.txt").context("failed to open synset")?; + let synset: Vec = BufReader::new(file) + .lines() + .into_iter() + .map(|x| x.expect("readline failed")) + .collect(); - let label = synset - .get(&argmax) - .expect("cannot find the class id for argmax"); + let label = &synset[argmax]; println!( "input image belongs to the class `{}` with probability {}", label, max_prob ); + + Ok(()) } diff --git a/rust/tvm/src/runtime/graph_rt.rs b/rust/tvm/src/runtime/graph_rt.rs new file mode 100644 index 000000000000..835ac66b8aaf --- /dev/null +++ b/rust/tvm/src/runtime/graph_rt.rs @@ -0,0 +1,78 @@ +use std::convert::TryInto; + +use crate::runtime::Function; +use crate::{runtime::function::Result, runtime::ByteArray, Context, Module, NDArray}; + +/// An instance of the C++ graph runtime. +/// +/// An efficient and light weight runtime for static deep learning models. +pub struct GraphRt { + /// The backing graph runtime module which exposes a set of packed functions + /// which can be invoked by a client. + /// + /// In the graph runtime module, it exposes create, load_params, set_input, get_output, and run. + module: Module, +} + +impl GraphRt { + /// Create a graph runtime from the deprecated graph, lib, ctx triple. + pub fn create_from_parts(graph: &str, lib: Module, ctx: Context) -> Result { + let runtime_create_fn = Function::get("tvm.graph_runtime.create").unwrap(); + + let runtime_create_fn_ret = runtime_create_fn.invoke(vec![ + graph.into(), + (&lib).into(), + (&ctx.device_type).into(), + // NOTE you must pass the device id in as i32 because that's what TVM expects + (ctx.device_id as i32).into(), + ]); + let graph_runtime_module: Module = runtime_create_fn_ret?.try_into()?; + Ok(Self { + module: graph_runtime_module, + }) + } + + /// Load the parameters of the model into the runtime. + pub fn load_params

(&mut self, params: P) -> Result<()> + where + P: Into, + { + let load_param_fn = self.module.get_function("load_params", false)?; + + let params: ByteArray = params.into(); + + load_param_fn.invoke(vec![(¶ms).into()])?; + + Ok(()) + } + + /// Set the input with name `name` with the value of `input`. + pub fn set_input(&mut self, name: &str, input: NDArray) -> Result<()> { + let ref set_input_fn = self.module.get_function("set_input", false)?; + + set_input_fn.invoke(vec![name.into(), input.into()])?; + Ok(()) + } + + /// Run the graph module, once setting parameters and inputs. + pub fn run(&mut self) -> Result<()> { + let ref run_fn = self.module.get_function("run", false)?; + + // execute the run function. Note that it has no argument + run_fn.invoke(vec![])?; + Ok(()) + } + + /// Extract the ith output from the graph runtime and returns it. + pub fn get_output(&mut self, i: i64) -> Result { + let get_output_fn = self.module.get_function("get_output", false)?; + get_output_fn.invoke(vec![i.into()])?.try_into() + } + + /// Extract the ith output from the graph runtime and write the results into output. + pub fn get_output_into(&mut self, i: i64, output: NDArray) -> Result<()> { + let get_output_fn = self.module.get_function("get_output", false)?; + get_output_fn.invoke(vec![i.into(), output.into()])?; + Ok(()) + } +} diff --git a/rust/tvm/src/runtime/mod.rs b/rust/tvm/src/runtime/mod.rs index 69fbb371824a..84da186557f7 100644 --- a/rust/tvm/src/runtime/mod.rs +++ b/rust/tvm/src/runtime/mod.rs @@ -18,3 +18,5 @@ */ pub use tvm_rt::*; + +pub mod graph_rt; From 3b6edf9ec0b6b3ab6a91174e7e2aa321cd8ec9b2 Mon Sep 17 00:00:00 2001 From: Jared Roesch Date: Fri, 25 Sep 2020 18:18:50 -0700 Subject: [PATCH 33/50] Add ASF header --- rust/tvm/src/ir/attrs.rs | 19 +++++++++++++++++++ rust/tvm/src/ir/expr.rs | 19 +++++++++++++++++++ rust/tvm/src/ir/function.rs | 19 +++++++++++++++++++ rust/tvm/src/ir/module.rs | 19 +++++++++++++++++++ rust/tvm/src/ir/op.rs | 19 +++++++++++++++++++ rust/tvm/src/ir/relay/attrs/mod.rs | 19 +++++++++++++++++++ rust/tvm/src/ir/relay/attrs/nn.rs | 19 +++++++++++++++++++ rust/tvm/src/ir/relay/attrs/transform.rs | 19 +++++++++++++++++++ rust/tvm/src/ir/ty.rs | 19 +++++++++++++++++++ rust/tvm/src/python.rs | 19 +++++++++++++++++++ rust/tvm/src/runtime/graph_rt.rs | 19 +++++++++++++++++++ 11 files changed, 209 insertions(+) diff --git a/rust/tvm/src/ir/attrs.rs b/rust/tvm/src/ir/attrs.rs index 883ee7d699e1..5bd027ab4b4c 100644 --- a/rust/tvm/src/ir/attrs.rs +++ b/rust/tvm/src/ir/attrs.rs @@ -1,3 +1,22 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + use crate::runtime::Object; use tvm_macros::Object; diff --git a/rust/tvm/src/ir/expr.rs b/rust/tvm/src/ir/expr.rs index a8a188e39ae2..91c42f0edbcf 100644 --- a/rust/tvm/src/ir/expr.rs +++ b/rust/tvm/src/ir/expr.rs @@ -1,3 +1,22 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + use super::relay; use crate::runtime::String as TString; use crate::runtime::{self, external, IsObject, IsObjectRef, Object, ObjectPtr, ObjectRef}; diff --git a/rust/tvm/src/ir/function.rs b/rust/tvm/src/ir/function.rs index e6a1d3d9d620..3043bf9e7cff 100644 --- a/rust/tvm/src/ir/function.rs +++ b/rust/tvm/src/ir/function.rs @@ -1,3 +1,22 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + use crate::ir::relay::ExprNode; use crate::runtime::{IsObject, IsObjectRef, ObjectRef}; diff --git a/rust/tvm/src/ir/module.rs b/rust/tvm/src/ir/module.rs index 365680160c33..e0444b3101da 100644 --- a/rust/tvm/src/ir/module.rs +++ b/rust/tvm/src/ir/module.rs @@ -1,3 +1,22 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + use crate::runtime::array::Array; use crate::runtime::function::Result; use crate::runtime::map::Map; diff --git a/rust/tvm/src/ir/op.rs b/rust/tvm/src/ir/op.rs index 4ab74c4c6625..d81d6a69c1eb 100644 --- a/rust/tvm/src/ir/op.rs +++ b/rust/tvm/src/ir/op.rs @@ -1,3 +1,22 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + use crate::ir::relay::ExprNode; use crate::runtime::array::Array; use crate::runtime::ObjectRef; diff --git a/rust/tvm/src/ir/relay/attrs/mod.rs b/rust/tvm/src/ir/relay/attrs/mod.rs index 459cef0ed76b..d1bcc0009657 100644 --- a/rust/tvm/src/ir/relay/attrs/mod.rs +++ b/rust/tvm/src/ir/relay/attrs/mod.rs @@ -1,2 +1,21 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + pub mod nn; pub mod transform; diff --git a/rust/tvm/src/ir/relay/attrs/nn.rs b/rust/tvm/src/ir/relay/attrs/nn.rs index 1e2a9bffc20b..f743534e5f61 100644 --- a/rust/tvm/src/ir/relay/attrs/nn.rs +++ b/rust/tvm/src/ir/relay/attrs/nn.rs @@ -1,3 +1,22 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + use crate::ir::attrs::BaseAttrsNode; use crate::ir::PrimExpr; use crate::runtime::array::Array; diff --git a/rust/tvm/src/ir/relay/attrs/transform.rs b/rust/tvm/src/ir/relay/attrs/transform.rs index c7c90cd30682..863f07617778 100644 --- a/rust/tvm/src/ir/relay/attrs/transform.rs +++ b/rust/tvm/src/ir/relay/attrs/transform.rs @@ -1,3 +1,22 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + use crate::ir::attrs::BaseAttrsNode; use tvm_macros::Object; diff --git a/rust/tvm/src/ir/ty.rs b/rust/tvm/src/ir/ty.rs index 80cb11c4b965..71bafc998b81 100644 --- a/rust/tvm/src/ir/ty.rs +++ b/rust/tvm/src/ir/ty.rs @@ -1,3 +1,22 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + use crate::runtime::{Object, ObjectRef}; use tvm_macros::Object; use tvm_rt::{array::Array, DataType}; diff --git a/rust/tvm/src/python.rs b/rust/tvm/src/python.rs index 87cc6cd2be79..89558af733b3 100644 --- a/rust/tvm/src/python.rs +++ b/rust/tvm/src/python.rs @@ -1,3 +1,22 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + use pyo3::prelude::*; /// Load the Python interpreter into the address space. diff --git a/rust/tvm/src/runtime/graph_rt.rs b/rust/tvm/src/runtime/graph_rt.rs index 835ac66b8aaf..8b26ebb4ca22 100644 --- a/rust/tvm/src/runtime/graph_rt.rs +++ b/rust/tvm/src/runtime/graph_rt.rs @@ -1,3 +1,22 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + use std::convert::TryInto; use crate::runtime::Function; From f0af06ed6e7c1750fe5a8f7d5fd51aab378a17ed Mon Sep 17 00:00:00 2001 From: Gus Smith Date: Tue, 29 Sep 2020 10:56:29 -0700 Subject: [PATCH 34/50] Format --- src/relay/transforms/combine_parallel_conv2d.cc | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/src/relay/transforms/combine_parallel_conv2d.cc b/src/relay/transforms/combine_parallel_conv2d.cc index 68520efe2bbd..54aec99f46fb 100644 --- a/src/relay/transforms/combine_parallel_conv2d.cc +++ b/src/relay/transforms/combine_parallel_conv2d.cc @@ -196,7 +196,8 @@ class ParallelConv2DCombiner : public ParallelOpCombiner { auto channels = GetConv2DSuperChannelsDim(conv2d); num_filters += channels; } - auto index = branches[0][0]->attrs.as()->kernel_layout.operator std::string().find('O'); + auto index = + branches[0][0]->attrs.as()->kernel_layout.operator std::string().find('O'); CHECK_NE(index, std::string::npos); return std::make_tuple(MakeConcatenate(Tuple(weights), index), tir::make_const(DataType::Int(32), num_filters)); From 70e8a3e22875fd372aa48a488bcd192c508c0521 Mon Sep 17 00:00:00 2001 From: Gus Smith Date: Tue, 29 Sep 2020 11:08:56 -0700 Subject: [PATCH 35/50] Format --- include/tvm/relay/attrs/nn.h | 1 + 1 file changed, 1 insertion(+) diff --git a/include/tvm/relay/attrs/nn.h b/include/tvm/relay/attrs/nn.h index e1d7a18eb8be..b2555de6d35e 100644 --- a/include/tvm/relay/attrs/nn.h +++ b/include/tvm/relay/attrs/nn.h @@ -28,6 +28,7 @@ #include #include + #include "tvm/runtime/container.h" namespace tvm { From 3e9648438a630086344842da7cab1b8b5cd3ebc8 Mon Sep 17 00:00:00 2001 From: Max Willsey Date: Tue, 29 Sep 2020 11:22:20 -0700 Subject: [PATCH 36/50] Format resnet rust python script --- rust/tvm/examples/resnet/src/build_resnet.py | 63 +++++++++++++------- 1 file changed, 40 insertions(+), 23 deletions(-) diff --git a/rust/tvm/examples/resnet/src/build_resnet.py b/rust/tvm/examples/resnet/src/build_resnet.py index 5a9f9d829c15..324bb52e08a9 100644 --- a/rust/tvm/examples/resnet/src/build_resnet.py +++ b/rust/tvm/examples/resnet/src/build_resnet.py @@ -34,18 +34,24 @@ from tvm.contrib.download import download_testdata from mxnet.gluon.model_zoo.vision import get_model -logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(name)s - %(levelname)s - %(message)s') +logging.basicConfig( + level=logging.INFO, format="%(asctime)s - %(name)s - %(levelname)s - %(message)s" +) logger = logging.getLogger(__name__) -parser = argparse.ArgumentParser(description='Resnet build example') +parser = argparse.ArgumentParser(description="Resnet build example") aa = parser.add_argument -aa('--build-dir', type=str, required=True, help='directory to put the build artifacts') -aa('--batch-size', type=int, default=1, help='input image batch size') -aa('--opt-level', type=int, default=3, - help='level of optimization. 0 is unoptimized and 3 is the highest level') -aa('--target', type=str, default='llvm', help='target context for compilation') -aa('--image-shape', type=str, default='3,224,224', help='input image dimensions') -aa('--image-name', type=str, default='cat.png', help='name of input image to download') +aa("--build-dir", type=str, required=True, help="directory to put the build artifacts") +aa("--batch-size", type=int, default=1, help="input image batch size") +aa( + "--opt-level", + type=int, + default=3, + help="level of optimization. 0 is unoptimized and 3 is the highest level", +) +aa("--target", type=str, default="llvm", help="target context for compilation") +aa("--image-shape", type=str, default="3,224,224", help="input image dimensions") +aa("--image-name", type=str, default="cat.png", help="name of input image to download") args = parser.parse_args() build_dir = args.build_dir @@ -55,54 +61,62 @@ image_shape = tuple(map(int, args.image_shape.split(","))) data_shape = (batch_size,) + image_shape + def build(target_dir): """ Compiles resnet18 with TVM""" # Download the pretrained model in MxNet's format. block = get_model("resnet18_v1", pretrained=True) - shape_dict = {"data": (1, 3, 224, 224) } + shape_dict = {"data": (1, 3, 224, 224)} mod, params = relay.frontend.from_mxnet(block, shape_dict) # Add softmax to do classification in last layer. func = mod["main"] - func = relay.Function(func.params, relay.nn.softmax(func.body), None, func.type_params, func.attrs) + func = relay.Function( + func.params, relay.nn.softmax(func.body), None, func.type_params, func.attrs + ) target = "llvm" with tvm.transform.PassContext(opt_level=3): graph, lib, params = relay.build(func, target, params=params) # save the model artifacts - deploy_lib = osp.join(target_dir, 'deploy_lib.o') + deploy_lib = osp.join(target_dir, "deploy_lib.o") lib.save(deploy_lib) - cc.create_shared(osp.join(target_dir, "deploy_lib.so"), - [osp.join(target_dir, "deploy_lib.o")]) + cc.create_shared(osp.join(target_dir, "deploy_lib.so"), [osp.join(target_dir, "deploy_lib.o")]) with open(osp.join(target_dir, "deploy_graph.json"), "w") as fo: fo.write(graph) - with open(osp.join(target_dir,"deploy_param.params"), "wb") as fo: + with open(osp.join(target_dir, "deploy_param.params"), "wb") as fo: fo.write(relay.save_param_dict(params)) + def download_img_labels(): """ Download an image and imagenet1k class labels for test""" from mxnet.gluon.utils import download - synset_url = ''.join(['https://gist.githubusercontent.com/zhreshold/', - '4d0b62f3d01426887599d4f7ede23ee5/raw/', - '596b27d23537e5a1b5751d2b0481ef172f58b539/', - 'imagenet1000_clsid_to_human.txt']) - synset_name = 'synset.txt' + synset_url = "".join( + [ + "https://gist.githubusercontent.com/zhreshold/", + "4d0b62f3d01426887599d4f7ede23ee5/raw/", + "596b27d23537e5a1b5751d2b0481ef172f58b539/", + "imagenet1000_clsid_to_human.txt", + ] + ) + synset_name = "synset.txt" synset_path = download_testdata(synset_url, synset_name, module="data") with open(synset_path) as fin: synset = eval(fin.read()) - with open(synset_name, 'w') as f: + with open(synset_name, "w") as f: for key in synset: f.write(synset[key]) f.write("\n") return synset + def transform_image(image): image = np.array(image) - np.array([123.0, 117.0, 104.0]) image /= np.array([58.395, 57.12, 57.375]) @@ -110,6 +124,7 @@ def transform_image(image): image = image[np.newaxis, :] return image + def get_cat_image(): img_url = "https://github.com/dmlc/mxnet.js/blob/master/data/cat.png?raw=true" img_path = download_testdata(img_url, "cat.png", module="data") @@ -117,11 +132,12 @@ def get_cat_image(): img = Image.open(img_path).resize((224, 224)) return transform_image(img) + def test_build(build_dir): """ Sanity check with the cat image we download.""" graph = open(osp.join(build_dir, "deploy_graph.json")).read() lib = tvm.runtime.load_module(osp.join(build_dir, "deploy_lib.so")) - params = bytearray(open(osp.join(build_dir,"deploy_param.params"), "rb").read()) + params = bytearray(open(osp.join(build_dir, "deploy_param.params"), "rb").read()) input_data = get_cat_image() ctx = tvm.cpu() module = graph_runtime.create(graph, lib, ctx) @@ -132,7 +148,8 @@ def test_build(build_dir): synset = download_img_labels() print("TVM prediction top-1:", top1, synset[top1]) -if __name__ == '__main__': + +if __name__ == "__main__": logger.info("Compiling the model to graph runtime.") build(build_dir) logger.info("Testing the model's predication on test data.") From e893b5748f4136ff6d9a9a11320e437d07cc5384 Mon Sep 17 00:00:00 2001 From: Jared Roesch Date: Tue, 29 Sep 2020 15:21:03 -0700 Subject: [PATCH 37/50] Add type files and refactor span --- rust/tvm-macros/Cargo.toml | 1 + rust/tvm/src/ir/mod.rs | 1 + rust/tvm/src/ir/span.rs | 3 + rust/tvm/src/ir/ty.rs | 192 ++++++++++++++++++++++++++++++++++++- 4 files changed, 195 insertions(+), 2 deletions(-) create mode 100644 rust/tvm/src/ir/span.rs diff --git a/rust/tvm-macros/Cargo.toml b/rust/tvm-macros/Cargo.toml index a9ac09e6fa68..63b84727c525 100644 --- a/rust/tvm-macros/Cargo.toml +++ b/rust/tvm-macros/Cargo.toml @@ -34,3 +34,4 @@ goblin = "^0.2" proc-macro2 = "^1.0" quote = "^1.0" syn = { version = "1.0.17", features = ["full", "extra-traits"] } +proc-macro-error = "^1.0" diff --git a/rust/tvm/src/ir/mod.rs b/rust/tvm/src/ir/mod.rs index 2379e12df3fb..126d0faccabb 100644 --- a/rust/tvm/src/ir/mod.rs +++ b/rust/tvm/src/ir/mod.rs @@ -24,6 +24,7 @@ pub mod function; pub mod module; pub mod op; pub mod relay; +pub mod span; pub mod tir; pub mod ty; diff --git a/rust/tvm/src/ir/span.rs b/rust/tvm/src/ir/span.rs new file mode 100644 index 000000000000..6e2c35b6c1de --- /dev/null +++ b/rust/tvm/src/ir/span.rs @@ -0,0 +1,3 @@ +use crate::runtime::ObjectRef; + +pub type Span = ObjectRef; diff --git a/rust/tvm/src/ir/ty.rs b/rust/tvm/src/ir/ty.rs index 71bafc998b81..fb6358910d37 100644 --- a/rust/tvm/src/ir/ty.rs +++ b/rust/tvm/src/ir/ty.rs @@ -17,7 +17,8 @@ * under the License. */ -use crate::runtime::{Object, ObjectRef}; +use crate::runtime::{Object, ObjectPtr, IsObject}; +use super::span::Span; use tvm_macros::Object; use tvm_rt::{array::Array, DataType}; @@ -29,7 +30,174 @@ use super::PrimExpr; #[type_key = "Type"] pub struct TypeNode { pub base: Object, - pub span: ObjectRef, + pub span: Span, +} + +impl TypeNode { + fn base(span: Span) -> Self { + TypeNode { base: Object::base_object::(), span } + } +} + +/* + * \brief Primitive data types used in the low-level IR. + * + * PrimType represents POD-values and handles that are + * not automatically managed by the runtime. + * + * \sa PrimType + */ + #[repr(C)] + #[derive(Object)] + #[ref_name = "PrimType"] + #[type_key = "PrimType"] + pub struct PrimTypeNode { + pub base: TypeNode, + /// The corresponding dtype field. + pub dtype: DataType, + } + + +/* + *! + * \brief Low-level raw pointer type. + * + * PointerType represents type hints in the TIR to be + * passed to the final code generator. + * + * PointerType should not occur in the high-level analysis. + * + * \sa PointerType + */ + + #[repr(C)] + #[derive(Object)] + #[ref_name = "PointerType"] + #[type_key = "PointerType"] +pub struct PointerTypeNode { + pub base: TypeNode, + /// The type of the element which the pointer points to. + pub element_type: Type, +} +/// Possible kinds of type variables. +pub enum TypeKind { + Type = 0, + /// Template variable in shape expression. + ShapeVar = 1, + kConstraint = 4, + kAdtHandle = 5, + kTypeData = 6 +} + +/* + * \brief Type parameter in functions. + * + * A type variable can be viewed as template parameter in c++ template function. + * + * For example, in the following pesudo code, + * the TypeVar of f is TypeVar("n", kind=kShapeVar). + * This function can take in a Tensor with shape=(3, 3) and + * returns a Tensor with shape=(9,) + * + * \code + * + * template + * f(x : Tensor[i32, (n, n)]) -> Tensor[i32, (n * n)] + * + * \endcode + * \sa TypeVar, TypeKind + */ +#[repr(C)] +#[derive(Object)] +#[ref_name = "TypeVar"] +#[type_key = "TypeVar"] +pub struct TypeVarNode { + pub base: TypeNode, + pub name_hint: String, + pub kind: TypeKind, +} + +/// A global type variable that is used for defining new types or type aliases. +#[repr(C)] +#[derive(Object)] +#[ref_name = "GlobalTypeVar"] +#[type_key = "GlobalTypeVar"] +pub struct GlobalTypeVarNode { + pub base: TypeNode, + pub name_hint: String, + pub kind: TypeKind, +} + +#[repr(C)] +#[derive(Object)] +#[ref_name = "TupleType"] +#[type_key = "TupleType"] +pub struct TupleTypeNode { + pub base: TypeNode, + pub fields: Array, +} + +impl TupleType { + fn empty() -> TupleType { + todo!() + } +} + +#[repr(C)] +#[derive(Object)] +#[ref_name = "TypeConstraint"] +#[type_key = "TypeConstraint"] +pub struct TypeConstraintNode { + pub base: TypeNode, +} + +/// The representation of a polymoprhic function type. +#[repr(C)] +#[derive(Object)] +#[ref_name = "FuncType"] +#[type_key = "FuncType"] +pub struct FuncTypeNode { + pub base: TypeNode, + /// The type of arguments. + pub arg_types: Array, + /// The return type of the function. + pub ret_type: Type, + /// ... + pub type_params: Array, + /// Type constraints that must hold when + /// calling this function. + pub type_constraints: Array, +} + +/* + * \brief Intermediate values that is used to indicate incomplete type + * during type inference. + * + * If we view the type relations as "computational graph of types", + * then IncompleteType represents intermediate values of the graph, + * TypeVar represents the input to the graph. + */ +#[repr(C)] +#[derive(Object)] +#[ref_name = "IncompleteType"] +#[type_key = "IncompleteType"] +pub struct IncompleteTypeNode { + pub base: TypeNode, + pub kind: TypeKind, +} + +/* + * \brief Reference Type High-level Relay IR. + * + * \sa RelayRefType. + */ +#[repr(C)] +#[derive(Object)] +#[ref_name = "RefType"] +#[type_key = "relay.RefType"] +pub struct RelayRefTypeNode { + pub base: TypeNode, + pub value: Type, } #[repr(C)] @@ -49,3 +217,23 @@ pub struct TensorTypeNode { pub shape: Array, pub dtype: DataType, } + +impl TensorType { + pub fn new(shape: Array, dtype: DataType, span: Span) -> TensorType { + let node = TensorTypeNode { + base: TypeNode::base::(span), + shape, + dtype + }; + ObjectPtr::new(node).into() + } +} +// TODO(@jroesch): implement these in future. +// +// using TypeCall = tvm::TypeCall; +// using TypeCallNode = tvm::TypeCallNode; +// using TypeRelation = tvm::TypeRelation; +// using TypeRelationNode = tvm::TypeRelationNode; +// using TypeRelationFn = tvm::TypeRelationFn; +// using TypeReporter = tvm::TypeReporter; +// using TypeReporterNode = tvm::TypeReporterNode; From b6f3962a4cabcfc725d6b913d2c004f5998f8c99 Mon Sep 17 00:00:00 2001 From: Jared Roesch Date: Tue, 29 Sep 2020 15:50:22 -0700 Subject: [PATCH 38/50] Format --- rust/tvm/src/ir/ty.rs | 34 ++++++++++++++++++---------------- 1 file changed, 18 insertions(+), 16 deletions(-) diff --git a/rust/tvm/src/ir/ty.rs b/rust/tvm/src/ir/ty.rs index fb6358910d37..be4520bfd571 100644 --- a/rust/tvm/src/ir/ty.rs +++ b/rust/tvm/src/ir/ty.rs @@ -17,8 +17,8 @@ * under the License. */ -use crate::runtime::{Object, ObjectPtr, IsObject}; use super::span::Span; +use crate::runtime::{IsObject, Object, ObjectPtr}; use tvm_macros::Object; use tvm_rt::{array::Array, DataType}; @@ -35,7 +35,10 @@ pub struct TypeNode { impl TypeNode { fn base(span: Span) -> Self { - TypeNode { base: Object::base_object::(), span } + TypeNode { + base: Object::base_object::(), + span, + } } } @@ -47,19 +50,18 @@ impl TypeNode { * * \sa PrimType */ - #[repr(C)] - #[derive(Object)] - #[ref_name = "PrimType"] - #[type_key = "PrimType"] - pub struct PrimTypeNode { +#[repr(C)] +#[derive(Object)] +#[ref_name = "PrimType"] +#[type_key = "PrimType"] +pub struct PrimTypeNode { pub base: TypeNode, /// The corresponding dtype field. pub dtype: DataType, - } - +} /* - *! + *! * \brief Low-level raw pointer type. * * PointerType represents type hints in the TIR to be @@ -70,10 +72,10 @@ impl TypeNode { * \sa PointerType */ - #[repr(C)] - #[derive(Object)] - #[ref_name = "PointerType"] - #[type_key = "PointerType"] +#[repr(C)] +#[derive(Object)] +#[ref_name = "PointerType"] +#[type_key = "PointerType"] pub struct PointerTypeNode { pub base: TypeNode, /// The type of the element which the pointer points to. @@ -86,7 +88,7 @@ pub enum TypeKind { ShapeVar = 1, kConstraint = 4, kAdtHandle = 5, - kTypeData = 6 + kTypeData = 6, } /* @@ -223,7 +225,7 @@ impl TensorType { let node = TensorTypeNode { base: TypeNode::base::(span), shape, - dtype + dtype, }; ObjectPtr::new(node).into() } From ed326e81039756fee119831d103f9ce90f8d927c Mon Sep 17 00:00:00 2001 From: Jared Roesch Date: Tue, 29 Sep 2020 15:52:44 -0700 Subject: [PATCH 39/50] Format --- rust/tvm/src/ir/ty.rs | 11 ++++++----- 1 file changed, 6 insertions(+), 5 deletions(-) diff --git a/rust/tvm/src/ir/ty.rs b/rust/tvm/src/ir/ty.rs index be4520bfd571..b6a47f553da4 100644 --- a/rust/tvm/src/ir/ty.rs +++ b/rust/tvm/src/ir/ty.rs @@ -81,14 +81,15 @@ pub struct PointerTypeNode { /// The type of the element which the pointer points to. pub element_type: Type, } + /// Possible kinds of type variables. pub enum TypeKind { Type = 0, /// Template variable in shape expression. ShapeVar = 1, - kConstraint = 4, - kAdtHandle = 5, - kTypeData = 6, + Constraint = 4, + AdtHandle = 5, + TypeData = 6, } /* @@ -140,7 +141,7 @@ pub struct TupleTypeNode { } impl TupleType { - fn empty() -> TupleType { + pub fn empty() -> TupleType { todo!() } } @@ -153,7 +154,7 @@ pub struct TypeConstraintNode { pub base: TypeNode, } -/// The representation of a polymoprhic function type. +/// The representation of a polymorphic function type. #[repr(C)] #[derive(Object)] #[ref_name = "FuncType"] From 49a42baabe3764cb8aa2874de93a19ef528fe5b0 Mon Sep 17 00:00:00 2001 From: Gus Smith Date: Tue, 29 Sep 2020 16:20:59 -0700 Subject: [PATCH 40/50] Change types from std::string to tvm::String in packed function --- src/relay/op/nn/convolution.cc | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/relay/op/nn/convolution.cc b/src/relay/op/nn/convolution.cc index 438500f45e5e..2b9103b9709a 100644 --- a/src/relay/op/nn/convolution.cc +++ b/src/relay/op/nn/convolution.cc @@ -431,8 +431,8 @@ weight transformation in advance. TVM_REGISTER_GLOBAL("relay.op.nn._make.contrib_conv2d_gemm_without_weight_transform") .set_body_typed([](Expr data, Expr weight, Array strides, Array padding, Array dilation, int groups, IndexExpr channels, - Array kernel_size, std::string data_layout, - std::string kernel_layout, std::string out_layout, DataType out_dtype) { + Array kernel_size, tvm::String data_layout, + tvm::String kernel_layout, tvm::String out_layout, DataType out_dtype) { return MakeConvGemm( data, weight, strides, padding, dilation, groups, channels, kernel_size, data_layout, kernel_layout, out_layout, out_dtype, "nn.contrib_conv2d_gemm_without_weight_transform"); From 54ed9b17fb362fb4356ce622f0919a7f30e47c38 Mon Sep 17 00:00:00 2001 From: Gus Smith Date: Tue, 29 Sep 2020 16:29:04 -0700 Subject: [PATCH 41/50] Add ASF header --- rust/tvm/src/ir/span.rs | 19 +++++++++++++++++++ 1 file changed, 19 insertions(+) diff --git a/rust/tvm/src/ir/span.rs b/rust/tvm/src/ir/span.rs index 6e2c35b6c1de..d2e19a25a950 100644 --- a/rust/tvm/src/ir/span.rs +++ b/rust/tvm/src/ir/span.rs @@ -1,3 +1,22 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + use crate::runtime::ObjectRef; pub type Span = ObjectRef; From 644b74631a2132ce46c07d6005c16e77c9651b95 Mon Sep 17 00:00:00 2001 From: Gus Smith Date: Wed, 30 Sep 2020 07:42:56 -0700 Subject: [PATCH 42/50] Fix test w/ ndarray's API change --- rust/tvm-rt/src/ndarray.rs | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/rust/tvm-rt/src/ndarray.rs b/rust/tvm-rt/src/ndarray.rs index 9d4954889085..585aaa5af96b 100644 --- a/rust/tvm-rt/src/ndarray.rs +++ b/rust/tvm-rt/src/ndarray.rs @@ -453,7 +453,7 @@ mod tests { assert_eq!(ndarray.shape(), shape); assert_eq!(ndarray.to_vec::().unwrap(), data); assert_eq!(ndarray.ndim(), 1); - assert!(ndarray.is_contiguous().is_ok()); + assert!(ndarray.is_contiguous()); assert_eq!(ndarray.byte_offset(), 0); let shape = vec![4]; let e = NDArray::empty( From 5be20630a7d70062cbf447199bb788df518280d1 Mon Sep 17 00:00:00 2001 From: Max Willsey Date: Wed, 30 Sep 2020 11:26:10 -0700 Subject: [PATCH 43/50] Fix array test --- rust/tvm/tests/callback/src/bin/array.rs | 17 +++++++---------- 1 file changed, 7 insertions(+), 10 deletions(-) diff --git a/rust/tvm/tests/callback/src/bin/array.rs b/rust/tvm/tests/callback/src/bin/array.rs index f9db91881e1e..2f1848ec6471 100644 --- a/rust/tvm/tests/callback/src/bin/array.rs +++ b/rust/tvm/tests/callback/src/bin/array.rs @@ -36,31 +36,28 @@ use tvm::{ fn main() { fn sum(args: Vec>) -> Result { - let mut ret = 0f32; - let shape = &[2]; + let mut ret = 0.0; for arg in args { - let e = NDArray::empty(shape, Context::cpu(0), DataType::float(32, 1)); let arg: NDArray = arg.try_into()?; - let arr = arg.copy_to_ndarray(e)?; - let rnd: ArrayD = ArrayD::try_from(&arr)?; + let rnd: ArrayD = ArrayD::try_from(&arg)?; ret += rnd.scalar_sum(); } Ok(RetValue::from(ret)) } - let shape = &mut [2]; - let mut data = vec![3f32, 4.0]; + let shape = &[2]; + let data = vec![3.0, 4.0]; let mut arr = NDArray::empty(shape, Context::cpu(0), DataType::float(32, 1)); - arr.copy_from_buffer(data.as_mut_slice()); + arr.copy_from_buffer(data.as_slice()); register_untyped(sum, "sum", true).unwrap(); let func = Function::get("sum").expect("function registered"); let ret: f32 = func - .invoke(vec![(&arr).into(), (&arr).into()]) + .invoke(vec![(&arr).into()]) .unwrap() .try_into() .expect("call should succeed"); - assert_eq!(ret, 7f32); + assert_eq!(ret, 7.0); } From 83eb87f93117791e0b1485a965561d273e473083 Mon Sep 17 00:00:00 2001 From: Max Willsey Date: Wed, 30 Sep 2020 11:26:20 -0700 Subject: [PATCH 44/50] Fix anyhow import --- rust/tvm/examples/resnet/Cargo.toml | 2 -- 1 file changed, 2 deletions(-) diff --git a/rust/tvm/examples/resnet/Cargo.toml b/rust/tvm/examples/resnet/Cargo.toml index fd10569869d5..350f412b1b35 100644 --- a/rust/tvm/examples/resnet/Cargo.toml +++ b/rust/tvm/examples/resnet/Cargo.toml @@ -28,6 +28,4 @@ ndarray = "0.12" tvm = { path = "../../" } image = "0.20" csv = "1.1" - -[build-dependencies] anyhow = "^1.0" From 72bce31d9df896b3b1fcdcb7a559ba702c9c211d Mon Sep 17 00:00:00 2001 From: Max Willsey Date: Wed, 30 Sep 2020 12:03:51 -0700 Subject: [PATCH 45/50] Put back some anyhow stuff --- rust/tvm/examples/resnet/Cargo.toml | 3 +++ rust/tvm/examples/resnet/build.rs | 9 ++++++--- 2 files changed, 9 insertions(+), 3 deletions(-) diff --git a/rust/tvm/examples/resnet/Cargo.toml b/rust/tvm/examples/resnet/Cargo.toml index 350f412b1b35..646385a6373e 100644 --- a/rust/tvm/examples/resnet/Cargo.toml +++ b/rust/tvm/examples/resnet/Cargo.toml @@ -29,3 +29,6 @@ tvm = { path = "../../" } image = "0.20" csv = "1.1" anyhow = "^1.0" + +[build-dependencies] +anyhow = "1.0" diff --git a/rust/tvm/examples/resnet/build.rs b/rust/tvm/examples/resnet/build.rs index d1f710567e04..1e5d8a98736d 100644 --- a/rust/tvm/examples/resnet/build.rs +++ b/rust/tvm/examples/resnet/build.rs @@ -17,18 +17,19 @@ * under the License. */ +use anyhow::{Context, Result}; use std::{io::Write, path::Path, process::Command}; -fn main() { +fn main() -> Result<()> { let output = Command::new("python3") .arg(concat!(env!("CARGO_MANIFEST_DIR"), "/src/build_resnet.py")) .arg(&format!("--build-dir={}", env!("CARGO_MANIFEST_DIR"))) .output() - .expect("Failed to execute command"); + .with_context(|| anyhow::anyhow!("failed to run python3"))?; if !output.status.success() { std::io::stdout() .write_all(&output.stderr) - .expect("Failed to write error"); + .context("Failed to write error")?; panic!("Failed to execute build script"); } assert!( @@ -45,4 +46,6 @@ fn main() { "cargo:rustc-link-search=native={}", env!("CARGO_MANIFEST_DIR") ); + + Ok(()) } From 9e580aab45b98ccf993acb4c00f93469b8fd50bc Mon Sep 17 00:00:00 2001 From: Jared Roesch Date: Thu, 1 Oct 2020 13:31:03 -0700 Subject: [PATCH 46/50] Clean up --- rust/tvm-rt/src/ndarray.rs | 13 ++++++++++++- rust/tvm/examples/resnet/src/main.rs | 1 + 2 files changed, 13 insertions(+), 1 deletion(-) diff --git a/rust/tvm-rt/src/ndarray.rs b/rust/tvm-rt/src/ndarray.rs index 585aaa5af96b..ed280ccc2d80 100644 --- a/rust/tvm-rt/src/ndarray.rs +++ b/rust/tvm-rt/src/ndarray.rs @@ -369,6 +369,16 @@ impl NDArray { .map(|o| o.downcast().expect("this should never fail")); NDArray(ptr) } + + pub fn zeroed(self) -> NDArray { + unsafe { + let dltensor = self.as_raw_dltensor(); + let bytes_ptr: *mut u8 = std::mem::transmute((*dltensor).data); + println!("size {}", self.size()); + std::ptr::write_bytes(bytes_ptr, 0, self.size()); + self + } + } } macro_rules! impl_from_ndarray_rustndarray { @@ -447,7 +457,7 @@ mod tests { let shape = &[4]; let data = vec![1i32, 2, 3, 4]; let ctx = Context::cpu(0); - let mut ndarray = NDArray::empty(shape, ctx, DataType::from_str("int32").unwrap()); + let mut ndarray = NDArray::empty(shape, ctx, DataType::int(32, 1)).zeroed(); assert_eq!(ndarray.to_vec::().unwrap(), vec![0, 0, 0, 0]); ndarray.copy_from_buffer(&data); assert_eq!(ndarray.shape(), shape); @@ -466,6 +476,7 @@ mod tests { assert_eq!(nd.unwrap().to_vec::().unwrap(), data); } + /// This occasionally panics on macOS: https://github.com/rust-lang/rust/issues/71397 #[test] #[should_panic(expected = "called `Result::unwrap()` on an `Err`")] fn copy_wrong_dtype() { diff --git a/rust/tvm/examples/resnet/src/main.rs b/rust/tvm/examples/resnet/src/main.rs index 039d5a1844f8..f24c358ab52a 100644 --- a/rust/tvm/examples/resnet/src/main.rs +++ b/rust/tvm/examples/resnet/src/main.rs @@ -33,6 +33,7 @@ use tvm::*; fn main() -> anyhow::Result<()> { let ctx = Context::cpu(0); println!("{}", concat!(env!("CARGO_MANIFEST_DIR"), "/cat.png")); + let img = image::open(concat!(env!("CARGO_MANIFEST_DIR"), "/cat.png")) .context("Failed to open cat.png")?; From 778f3ba55a7af99c53adb55abe54e0cc8512349f Mon Sep 17 00:00:00 2001 From: Jared Roesch Date: Fri, 2 Oct 2020 23:30:59 -0700 Subject: [PATCH 47/50] Try and fix tests/scripts/task_rust.sh --- tests/scripts/task_rust.sh | 7 +++++-- 1 file changed, 5 insertions(+), 2 deletions(-) diff --git a/tests/scripts/task_rust.sh b/tests/scripts/task_rust.sh index d7b9a5b74406..27395a1eb0a5 100755 --- a/tests/scripts/task_rust.sh +++ b/tests/scripts/task_rust.sh @@ -20,10 +20,13 @@ set -e set -u export TVM_HOME="$(git rev-parse --show-toplevel)" - +echo "Using TVM_HOME=$TVM_HOME" export LD_LIBRARY_PATH="$TVM_HOME/lib:$TVM_HOME/build:${LD_LIBRARY_PATH:-}" -export PYTHONPATH="$TVM_HOME/python" +echo "Using LD_LIBRARY_PATH=$LD_LIBRARY_PATH" +export PYTHONPATH="$TVM_HOME/python:${PYTHONPATH}" +echo "Using PYTHONPATH=$PYTHONPATH" export RUST_DIR="$TVM_HOME/rust" +echo "Using RUST_DIR=$RUST_DIR" export LLVM_CONFIG_DEFAULT=`which llvm-config-10` From 2266dddaaade412cfe5c96f54ef59158881ebc3a Mon Sep 17 00:00:00 2001 From: Jared Roesch Date: Mon, 5 Oct 2020 12:47:44 -0700 Subject: [PATCH 48/50] Disable ResNet for now --- tests/scripts/task_rust.sh | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/tests/scripts/task_rust.sh b/tests/scripts/task_rust.sh index 27395a1eb0a5..18361feb03ee 100755 --- a/tests/scripts/task_rust.sh +++ b/tests/scripts/task_rust.sh @@ -110,6 +110,8 @@ cargo run --bin array cargo run --bin string cd - -cd examples/resnet -cargo build +# TODO(@jroesch): we need to renable MxNet in ci-cpu image +# https://github.com/apache/incubator-tvm/pull/6563 +# cd examples/resnet +# cargo build cd - From d93d13439db31bd042cfd10d0b2942977edc7722 Mon Sep 17 00:00:00 2001 From: Jared Roesch Date: Mon, 5 Oct 2020 17:04:56 -0700 Subject: [PATCH 49/50] Turn off building of Rust docs until we update CI --- tests/scripts/task_python_docs.sh | 8 +++++--- 1 file changed, 5 insertions(+), 3 deletions(-) diff --git a/tests/scripts/task_python_docs.sh b/tests/scripts/task_python_docs.sh index 71bb92250a00..1d6953fc736a 100755 --- a/tests/scripts/task_python_docs.sh +++ b/tests/scripts/task_python_docs.sh @@ -68,10 +68,12 @@ npm install npm run typedoc cd .. +# TODO(@jroesch): add Rust to CI container +# see: https://github.com/apache/incubator-tvm/issues/6628 # Rust doc -cd rust -cargo doc --workspace --no-deps -cd .. +# cd rust +# cargo doc --workspace --no-deps +# cd .. # Prepare the doc dir rm -rf _docs From 422e970cd749b14cd69801920b72bedb9003c248 Mon Sep 17 00:00:00 2001 From: Jared Roesch Date: Mon, 5 Oct 2020 21:54:24 -0700 Subject: [PATCH 50/50] Actually disable --- tests/scripts/task_python_docs.sh | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/tests/scripts/task_python_docs.sh b/tests/scripts/task_python_docs.sh index 1d6953fc736a..98dac93ac98f 100755 --- a/tests/scripts/task_python_docs.sh +++ b/tests/scripts/task_python_docs.sh @@ -82,7 +82,8 @@ rm -f _docs/.buildinfo mkdir -p _docs/api mv docs/doxygen/html _docs/api/doxygen mv jvm/core/target/site/apidocs _docs/api/javadoc -mv rust/target/doc _docs/api/rust +# See above TODO +# mv rust/target/doc _docs/api/rust mv web/dist/docs _docs/api/typedoc echo "Start creating the docs tarball.."