diff --git a/src/dimension/conversion.rs b/src/dimension/conversion.rs index 6b53a4eef..a37b93330 100644 --- a/src/dimension/conversion.rs +++ b/src/dimension/conversion.rs @@ -12,7 +12,7 @@ use num_traits::Zero; use std::ops::{Index, IndexMut}; use alloc::vec::Vec; -use crate::{Dim, Dimension, Ix, Ix1, IxDyn, IxDynImpl}; +use crate::{Dim, Dimension, Ix, Ix1, IxDyn, IxDynImpl, Ixs}; /// $m: macro callback /// $m is called with $arg and then the indices corresponding to the size argument @@ -41,11 +41,13 @@ macro_rules! index_item { /// Argument conversion a dimension. pub trait IntoDimension { type Dim: Dimension; + type Strides: IntoStrides; fn into_dimension(self) -> Self::Dim; } impl IntoDimension for Ix { type Dim = Ix1; + type Strides = Ixs; #[inline(always)] fn into_dimension(self) -> Ix1 { Ix1(self) @@ -57,6 +59,7 @@ where D: Dimension, { type Dim = D; + type Strides = D; #[inline(always)] fn into_dimension(self) -> Self { self @@ -65,6 +68,7 @@ where impl IntoDimension for IxDynImpl { type Dim = IxDyn; + type Strides = IxDyn; #[inline(always)] fn into_dimension(self) -> Self::Dim { Dim::new(self) @@ -73,6 +77,7 @@ impl IntoDimension for IxDynImpl { impl IntoDimension for Vec { type Dim = IxDyn; + type Strides = Vec; #[inline(always)] fn into_dimension(self) -> Self::Dim { Dim::new(IxDynImpl::from(self)) @@ -127,6 +132,7 @@ macro_rules! tuple_to_array { impl IntoDimension for [Ix; $n] { type Dim = Dim<[Ix; $n]>; + type Strides = [Ixs; $n]; #[inline(always)] fn into_dimension(self) -> Self::Dim { Dim::new(self) @@ -135,6 +141,7 @@ macro_rules! tuple_to_array { impl IntoDimension for index!(tuple_type [Ix] $n) { type Dim = Dim<[Ix; $n]>; + type Strides = index!(tuple_type [Ixs] $n); #[inline(always)] fn into_dimension(self) -> Self::Dim { Dim::new(index!(array_expr [self] $n)) @@ -171,3 +178,95 @@ macro_rules! tuple_to_array { } index_item!(tuple_to_array [] 7); + +/// Argument conversion strides. +pub trait IntoStrides { + type Dim: Dimension; + fn into_strides(self) -> Self::Dim; +} + +impl IntoStrides for Ixs { + type Dim = Ix1; + #[inline(always)] + fn into_strides(self) -> Ix1 { + Ix1(self as Ix) + } +} + +impl IntoStrides for D +where + D: Dimension, +{ + type Dim = D; + #[inline(always)] + fn into_strides(self) -> D { + self + } +} + +impl IntoStrides for Vec { + type Dim = IxDyn; + #[inline(always)] + fn into_strides(self) -> IxDyn { + let v: Vec = self.into_iter().map(|x| x as Ix).collect(); + Dim::new(IxDynImpl::from(v)) + } +} + +impl<'a> IntoStrides for &'a [Ixs] { + type Dim = IxDyn; + #[inline(always)] + fn into_strides(self) -> IxDyn { + let v: Vec = self.iter().map(|x| *x as Ix).collect(); + Dim::new(IxDynImpl::from(v)) + } +} + +macro_rules! index_item_ixs { + ($m:ident $arg:tt 0) => (); + ($m:ident $arg:tt 1) => ($m!($arg 0);); + ($m:ident $arg:tt 2) => ($m!($arg 0 1);); + ($m:ident $arg:tt 3) => ($m!($arg 0 1 2);); + ($m:ident $arg:tt 4) => ($m!($arg 0 1 2 3);); + ($m:ident $arg:tt 5) => ($m!($arg 0 1 2 3 4);); + ($m:ident $arg:tt 6) => ($m!($arg 0 1 2 3 4 5);); + ($m:ident $arg:tt 7) => ($m!($arg 0 1 2 3 4 5 6);); +} + +macro_rules! array_expr_ixs { + ([$self_:expr] $($index:tt)*) => ( + [$($self_[$index] as Ix, )*] + ) +} + +macro_rules! tuple_expr_ixs { + ([$self_:expr] $($index:tt)*) => ( + [$($self_.$index as Ix, )*] + ) +} + +macro_rules! tuple_to_strides { + ([] $($n:tt)*) => { + $( + impl IntoStrides for [Ixs; $n] { + type Dim = Dim<[Ix; $n]>; + #[inline(always)] + fn into_strides(self) -> Dim<[Ix; $n]> { + let self_: [Ix; $n] = index!(array_expr_ixs [self] $n); + Dim::new(self_) + } + } + + impl IntoStrides for index!(tuple_type [Ixs] $n) { + type Dim = Dim<[Ix; $n]>; + #[inline(always)] + fn into_strides(self) -> Dim<[Ix; $n]> { + Dim::new(index!(tuple_expr_ixs [self] $n)) + } + } + + )* + } +} + +index_item_ixs!(tuple_to_strides [] 7); diff --git a/src/dimension/mod.rs b/src/dimension/mod.rs index 3b14ea221..f2d675a4f 100644 --- a/src/dimension/mod.rs +++ b/src/dimension/mod.rs @@ -12,7 +12,7 @@ use num_integer::div_floor; pub use self::axes::{axes_of, Axes, AxisDescription}; pub use self::axis::Axis; -pub use self::conversion::IntoDimension; +pub use self::conversion::{IntoDimension, IntoStrides}; pub use self::dim::*; pub use self::dimension_trait::Dimension; pub use self::dynindeximpl::IxDynImpl; @@ -402,7 +402,7 @@ fn to_abs_slice(axis_len: usize, slice: Slice) -> (usize, usize, isize) { /// memory of the array. The result is always <= 0. pub fn offset_from_ptr_to_memory(dim: &D, strides: &D) -> isize { let offset = izip!(dim.slice(), strides.slice()).fold(0, |_offset, (d, s)| { - if (*s as isize) < 0 { + if (*s as isize) < 0 && *d != 0 { _offset + *s as isize * (*d as isize - 1) } else { _offset diff --git a/src/dimension/ndindex.rs b/src/dimension/ndindex.rs index d9bac1d94..8f1d0ba00 100644 --- a/src/dimension/ndindex.rs +++ b/src/dimension/ndindex.rs @@ -2,9 +2,7 @@ use std::fmt::Debug; use super::{stride_offset, stride_offset_checked}; use crate::itertools::zip; -use crate::{ - Dim, Dimension, IntoDimension, Ix, Ix0, Ix1, Ix2, Ix3, Ix4, Ix5, Ix6, IxDyn, IxDynImpl, -}; +use crate::{Dim, Dimension, IntoDimension, Ix, Ix0, Ix1, Ix2, Ix3, Ix4, Ix5, Ix6, IxDyn, IxDynImpl, Ixs}; /// Tuple or fixed size arrays that can be used to index an array. /// @@ -199,6 +197,7 @@ ndindex_with_array! { impl<'a> IntoDimension for &'a [Ix] { type Dim = IxDyn; + type Strides = &'a [Ixs]; fn into_dimension(self) -> Self::Dim { Dim(IxDynImpl::from(self)) } diff --git a/src/impl_views/constructors.rs b/src/impl_views/constructors.rs index c6e5f9988..f807adb8b 100644 --- a/src/impl_views/constructors.rs +++ b/src/impl_views/constructors.rs @@ -13,6 +13,7 @@ use crate::error::ShapeError; use crate::extension::nonnull::nonnull_debug_checked_from_ptr; use crate::imp_prelude::*; use crate::{is_aligned, StrideShape}; +use crate::dimension::offset_from_ptr_to_memory; /// Methods for read-only array views. impl<'a, A, D> ArrayView<'a, A, D> @@ -55,7 +56,7 @@ where let dim = shape.dim; dimension::can_index_slice_with_strides(xs, &dim, &shape.strides)?; let strides = shape.strides.strides_for_dim(&dim); - unsafe { Ok(Self::new_(xs.as_ptr(), dim, strides)) } + unsafe { Ok(Self::new_(xs.as_ptr().offset(-offset_from_ptr_to_memory(&dim, &strides)), dim, strides)) } } /// Create an `ArrayView` from shape information and a raw pointer to @@ -152,7 +153,7 @@ where let dim = shape.dim; dimension::can_index_slice_with_strides(xs, &dim, &shape.strides)?; let strides = shape.strides.strides_for_dim(&dim); - unsafe { Ok(Self::new_(xs.as_mut_ptr(), dim, strides)) } + unsafe { Ok(Self::new_(xs.as_mut_ptr().offset(-offset_from_ptr_to_memory(&dim, &strides)), dim, strides)) } } /// Create an `ArrayViewMut` from shape information and a diff --git a/src/lib.rs b/src/lib.rs index 11064d32d..98657e3ec 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -134,7 +134,7 @@ use std::marker::PhantomData; use alloc::sync::Arc; pub use crate::dimension::dim::*; -pub use crate::dimension::{Axis, AxisDescription, Dimension, IntoDimension, RemoveAxis}; +pub use crate::dimension::{Axis, AxisDescription, Dimension, IntoDimension, IntoStrides, RemoveAxis}; pub use crate::dimension::IxDynImpl; pub use crate::dimension::NdIndex; diff --git a/src/shape_builder.rs b/src/shape_builder.rs index 6fc99d0b2..f9d85b857 100644 --- a/src/shape_builder.rs +++ b/src/shape_builder.rs @@ -1,4 +1,4 @@ -use crate::dimension::IntoDimension; +use crate::dimension::{IntoDimension, IntoStrides}; use crate::Dimension; /// A contiguous array shape of n dimensions. @@ -111,7 +111,7 @@ where T: IntoDimension, { type Dim = T::Dim; - type Strides = T; + type Strides = T::Strides; fn into_shape(self) -> Shape { Shape { dim: self.into_dimension(), @@ -124,8 +124,8 @@ where fn set_f(self, is_f: bool) -> Shape { self.into_shape().set_f(is_f) } - fn strides(self, st: T) -> StrideShape { - self.into_shape().strides(st.into_dimension()) + fn strides(self, st: Self::Strides) -> StrideShape { + self.into_shape().strides(st.into_strides()) } } diff --git a/tests/array.rs b/tests/array.rs index 6581e572e..7fdf36999 100644 --- a/tests/array.rs +++ b/tests/array.rs @@ -1896,6 +1896,50 @@ fn test_shape() { assert_eq!(a.strides(), &[6, 3, 1]); assert_eq!(b.strides(), &[1, 1, 2]); assert_eq!(c.strides(), &[1, 3, 1]); + + // negative strides + let s = [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13].to_vec(); + let a= Array::from_shape_vec((2, 3, 2).strides((-1, -4, 2)),s.clone()).unwrap(); + assert_eq!(a, arr3(&[[[9, 11], [5, 7], [1, 3]], [[8, 10], [4, 6], [0, 2]]])); + assert_eq!(a.shape(), &[2, 3, 2]); + assert_eq!(a.strides(), &[-1, -4, 2]); + + // () + let b=Array::from_shape_vec(().strides(()),s.clone()).unwrap(); + assert_eq!(b,arr0(0)); + + let v = vec![5]; + let mut c = ArrayView2::::from_shape((1, 1).strides((-10, -1)), v.as_slice()).unwrap(); + assert_eq!(c, arr2(&[[5]])); + c.slice_collapse(s![1..1, ..]); + assert_eq!(c.shape(), &[0, 1]); + + // discontinuous + let d = Array3::from_shape_vec((3, 2, 2).strides((-8, -4, -2)), (0..24).collect()).unwrap(); + assert_eq!(d, arr3(&[[[22, 20], [18, 16]], [[14, 12], [10, 8]], [[6, 4], [2, 0]]])); + + // empty + let empty: [f32; 0] = []; + let e = Array::from_shape_vec(0.strides(-2), empty.to_vec()).unwrap(); + assert_eq!(e, arr1(&[])); + + let a = [1., 2., 3., 4., 5., 6.]; + let d = (2, 1, 1); + let s = (-2, 2, 1); + let b = ArrayView::from_shape(d.strides(s), &a).unwrap(); + assert_eq!(b, arr3(&[[[3.0]], [[1.0]]])); + + let d = (1, 2, 1); + let s = (2, -2, -1); + let b = Array::from_shape_vec(d.strides(s), a.to_vec()).unwrap(); + assert_eq!(b, arr3(&[[[3.0], [1.0]]])); + + let a: [f32; 0] = []; + // [[]] shape=[4, 0], strides=[0, 1] + let d = (4, 0); + let s = (0, -1); + let b = ArrayView::from_shape(d.strides(s), &a).unwrap(); + assert_eq!(b, arr2(&[[],[],[],[]])); } #[test] diff --git a/tests/raw_views.rs b/tests/raw_views.rs index bb39547e8..304121774 100644 --- a/tests/raw_views.rs +++ b/tests/raw_views.rs @@ -89,7 +89,7 @@ fn raw_view_negative_strides() { fn misaligned_deref(data: &[u16; 2]) -> ArrayView1<'_, u16> { let ptr: *const u16 = data.as_ptr(); unsafe { - let raw_view = RawArrayView::from_shape_ptr(1.strides((-1isize) as usize), ptr); + let raw_view = RawArrayView::from_shape_ptr(1.strides(-1), ptr); raw_view.deref_into_view() } }