diff --git a/tests/shape.rs b/tests/shape.rs index da44ea2..f3a7640 100644 --- a/tests/shape.rs +++ b/tests/shape.rs @@ -2,9 +2,15 @@ use thriller_core::{Dimension, Layout, Shape}; #[test] fn test_strides() { - let layout = Layout::RowMajor; - let shape = Shape::new(&[2, 3, 4], layout); - let strides = shape.get_strides(); + let layout_0 = Layout::RowMajor; + let shape_0 = Shape::new(&[2, 3, 4], layout_0); + let strides_0 = shape_0.get_strides(); - assert_eq!(strides.slice(), &[12, 4, 1]); + assert_eq!(strides_0.slice(), &[12, 4, 1]); + + let layout_1 = Layout::ColumnMajor; + let shape_1 = Shape::new(&[2, 3, 4], layout_1); + let strides_1 = shape_1.get_strides(); + + assert_eq!(strides_1.slice(), &[1, 2, 6]); } diff --git a/thriller-core/src/shape.rs b/thriller-core/src/shape.rs index 5c05b2b..be06ec6 100644 --- a/thriller-core/src/shape.rs +++ b/thriller-core/src/shape.rs @@ -118,8 +118,24 @@ impl Dimension for Dim { strides } + /// Returns the strides for a Fortran layout array with the given shape. fn fortran_strides(&self) -> Self { - todo!() + // Compute fortan array strides + // Shape (a, b, c) => Give strides (1, a, a * b) + let mut strides = Self::zeros(self.ndim); + // For empty arrays, use all zero strides + if self.slice().iter().all(|&d| d != 0) { + let mut it = strides.slice_mut().iter_mut(); + if let Some(rs) = it.next() { + *rs = 1; + } + let mut cum_prod = 1; + for (rs, dim) in it.zip(self.slice().iter()) { + cum_prod *= dim; + *rs = cum_prod; + } + } + strides } }