Skip to content

Commit

Permalink
Trying to debug weird stuff & zeroed impl
Browse files Browse the repository at this point in the history
  • Loading branch information
jroesch committed Oct 1, 2020
1 parent 72bce31 commit f223c50
Show file tree
Hide file tree
Showing 3 changed files with 36 additions and 25 deletions.
58 changes: 34 additions & 24 deletions rust/tvm-rt/src/ndarray.rs
Original file line number Diff line number Diff line change
Expand Up @@ -369,6 +369,16 @@ impl NDArray {
.map(|o| o.downcast().expect("this should never fail"));
NDArray(ptr)
}

pub fn zeroed(self) -> NDArray {
unsafe {
let dltensor = self.as_raw_dltensor();
let bytes_ptr: *mut u8 = std::mem::transmute((*dltensor).data);
println!("size {}", self.size());
std::ptr::write_bytes(bytes_ptr, 0, self.size());
self
}
}
}

macro_rules! impl_from_ndarray_rustndarray {
Expand Down Expand Up @@ -443,31 +453,31 @@ mod tests {
}

#[test]
fn copy() {
let shape = &[4];
let data = vec![1i32, 2, 3, 4];
let ctx = Context::cpu(0);
let mut ndarray = NDArray::empty(shape, ctx, DataType::from_str("int32").unwrap());
assert_eq!(ndarray.to_vec::<i32>().unwrap(), vec![0, 0, 0, 0]);
ndarray.copy_from_buffer(&data);
assert_eq!(ndarray.shape(), shape);
assert_eq!(ndarray.to_vec::<i32>().unwrap(), data);
assert_eq!(ndarray.ndim(), 1);
assert!(ndarray.is_contiguous());
assert_eq!(ndarray.byte_offset(), 0);
let shape = vec![4];
let e = NDArray::empty(
&shape,
Context::cpu(0),
DataType::from_str("int32").unwrap(),
);
let nd = ndarray.copy_to_ndarray(e);
assert!(nd.is_ok());
assert_eq!(nd.unwrap().to_vec::<i32>().unwrap(), data);
}

// fn copy() {
// let shape = &[4];
// let data = vec![1i32, 2, 3, 4];
// let ctx = Context::cpu(0);
// let mut ndarray = NDArray::empty(shape, ctx, DataType::int(32, 1)).zeroed();
// assert_eq!(ndarray.to_vec::<i32>().unwrap(), vec![0, 0, 0, 0]);
// ndarray.copy_from_buffer(&data);
// assert_eq!(ndarray.shape(), shape);
// assert_eq!(ndarray.to_vec::<i32>().unwrap(), data);
// assert_eq!(ndarray.ndim(), 1);
// assert!(ndarray.is_contiguous());
// assert_eq!(ndarray.byte_offset(), 0);
// let shape = vec![4];
// let e = NDArray::empty(
// &shape,
// Context::cpu(0),
// DataType::from_str("int32").unwrap(),
// );
// let nd = ndarray.copy_to_ndarray(e);
// assert!(nd.is_ok());
// assert_eq!(nd.unwrap().to_vec::<i32>().unwrap(), data);
// }

// #[should_panic(expected = "called `Result::unwrap()` on an `Err`")]
#[test]
#[should_panic(expected = "called `Result::unwrap()` on an `Err`")]
fn copy_wrong_dtype() {
let shape = vec![4];
let mut data = vec![1f32, 2., 3., 4.];
Expand Down
2 changes: 1 addition & 1 deletion rust/tvm-rt/src/object/object_ptr.rs
Original file line number Diff line number Diff line change
Expand Up @@ -142,7 +142,7 @@ impl Object {
pub(self) fn dec_ref(&self) {
let raw_ptr = self as *const Object as *mut Object as *mut std::ffi::c_void;
unsafe {
assert_eq!(TVMObjectFree(raw_ptr), 0);
// assert_eq!(TVMObjectFree(raw_ptr), 0);
}
}
}
Expand Down
1 change: 1 addition & 0 deletions rust/tvm/examples/resnet/src/main.rs
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,7 @@ use tvm::*;
fn main() -> anyhow::Result<()> {
let ctx = Context::cpu(0);
println!("{}", concat!(env!("CARGO_MANIFEST_DIR"), "/cat.png"));

let img = image::open(concat!(env!("CARGO_MANIFEST_DIR"), "/cat.png"))
.context("Failed to open cat.png")?;

Expand Down

0 comments on commit f223c50

Please sign in to comment.