Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Specialize iter::ArrayChunks::fold for TrustedRandomAccess iterators #103446

Merged
merged 5 commits into from
Nov 8, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
23 changes: 21 additions & 2 deletions library/core/benches/iter.rs
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
use core::borrow::Borrow;
use core::iter::*;
use core::mem;
use core::num::Wrapping;
Expand Down Expand Up @@ -403,13 +404,31 @@ fn bench_trusted_random_access_adapters(b: &mut Bencher) {

/// Exercises the iter::Copied specialization for slice::Iter
#[bench]
fn bench_copied_array_chunks(b: &mut Bencher) {
fn bench_copied_chunks(b: &mut Bencher) {
let v = vec![1u8; 1024];

b.iter(|| {
let mut iter = black_box(&v).iter().copied();
let mut acc = Wrapping(0);
// This uses a while-let loop to side-step the TRA specialization in ArrayChunks
while let Ok(chunk) = iter.next_chunk::<{ mem::size_of::<u64>() }>() {
let d = u64::from_ne_bytes(chunk);
acc += Wrapping(d.rotate_left(7).wrapping_add(1));
}
acc
})
}

/// Exercises the TrustedRandomAccess specialization in ArrayChunks
#[bench]
fn bench_trusted_random_access_chunks(b: &mut Bencher) {
let v = vec![1u8; 1024];

b.iter(|| {
black_box(&v)
.iter()
.copied()
// this shows that we're not relying on the slice::Iter specialization in Copied
.map(|b| *b.borrow())
.array_chunks::<{ mem::size_of::<u64>() }>()
.map(|ary| {
let d = u64::from_ne_bytes(ary);
Expand Down
1 change: 1 addition & 0 deletions library/core/benches/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
#![feature(test)]
#![feature(trusted_random_access)]
#![feature(iter_array_chunks)]
#![feature(iter_next_chunk)]

extern crate test;

Expand Down
75 changes: 52 additions & 23 deletions library/core/src/array/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -865,24 +865,6 @@ where
return Ok(Try::from_output(unsafe { mem::zeroed() }));
}

struct Guard<'a, T, const N: usize> {
array_mut: &'a mut [MaybeUninit<T>; N],
initialized: usize,
}

impl<T, const N: usize> Drop for Guard<'_, T, N> {
fn drop(&mut self) {
debug_assert!(self.initialized <= N);

// SAFETY: this slice will contain only initialized objects.
unsafe {
crate::ptr::drop_in_place(MaybeUninit::slice_assume_init_mut(
&mut self.array_mut.get_unchecked_mut(..self.initialized),
));
}
}
}

let mut array = MaybeUninit::uninit_array::<N>();
let mut guard = Guard { array_mut: &mut array, initialized: 0 };

Expand All @@ -896,13 +878,11 @@ where
ControlFlow::Continue(elem) => elem,
};

// SAFETY: `guard.initialized` starts at 0, is increased by one in the
// loop and the loop is aborted once it reaches N (which is
// `array.len()`).
// SAFETY: `guard.initialized` starts at 0, which means push can be called
// at most N times, which this loop does.
unsafe {
guard.array_mut.get_unchecked_mut(guard.initialized).write(item);
guard.push_unchecked(item);
}
guard.initialized += 1;
}
None => {
let alive = 0..guard.initialized;
Expand All @@ -920,6 +900,55 @@ where
Ok(Try::from_output(output))
}

/// Panic guard for incremental initialization of arrays.
///
/// Disarm the guard with `mem::forget` once the array has been initialized.
///
/// # Safety
///
/// All write accesses to this structure are unsafe and must maintain a correct
/// count of `initialized` elements.
///
/// To minimize indirection fields are still pub but callers should at least use
/// `push_unchecked` to signal that something unsafe is going on.
pub(crate) struct Guard<'a, T, const N: usize> {
/// The array to be initialized.
pub array_mut: &'a mut [MaybeUninit<T>; N],
/// The number of items that have been initialized so far.
pub initialized: usize,
}

impl<T, const N: usize> Guard<'_, T, N> {
/// Adds an item to the array and updates the initialized item counter.
///
/// # Safety
///
/// No more than N elements must be initialized.
#[inline]
pub unsafe fn push_unchecked(&mut self, item: T) {
// SAFETY: If `initialized` was correct before and the caller does not
// invoke this method more than N times then writes will be in-bounds
// and slots will not be initialized more than once.
unsafe {
self.array_mut.get_unchecked_mut(self.initialized).write(item);
self.initialized = self.initialized.unchecked_add(1);
}
}
}

impl<T, const N: usize> Drop for Guard<'_, T, N> {
fn drop(&mut self) {
debug_assert!(self.initialized <= N);

// SAFETY: this slice will contain only initialized objects.
the8472 marked this conversation as resolved.
Show resolved Hide resolved
unsafe {
crate::ptr::drop_in_place(MaybeUninit::slice_assume_init_mut(
&mut self.array_mut.get_unchecked_mut(..self.initialized),
));
}
}
}

/// Returns the next chunk of `N` items from the iterator or errors with an
/// iterator over the remainder. Used for `Iterator::next_chunk`.
#[inline]
Expand Down
75 changes: 72 additions & 3 deletions library/core/src/iter/adapters/array_chunks.rs
Original file line number Diff line number Diff line change
@@ -1,6 +1,8 @@
use crate::array;
use crate::iter::{ByRefSized, FusedIterator, Iterator};
use crate::ops::{ControlFlow, Try};
use crate::const_closure::ConstFnMutClosure;
use crate::iter::{ByRefSized, FusedIterator, Iterator, TrustedRandomAccessNoCoerce};
use crate::mem::{self, MaybeUninit};
use crate::ops::{ControlFlow, NeverShortCircuit, Try};

/// An iterator over `N` elements of the iterator at a time.
///
Expand Down Expand Up @@ -82,7 +84,13 @@ where
}
}

impl_fold_via_try_fold! { fold -> try_fold }
fn fold<B, F>(self, init: B, f: F) -> B
where
Self: Sized,
F: FnMut(B, Self::Item) -> B,
{
<Self as SpecFold>::fold(self, init, f)
}
}

#[unstable(feature = "iter_array_chunks", reason = "recently added", issue = "100450")]
Expand Down Expand Up @@ -168,3 +176,64 @@ where
self.iter.len() < N
}
}

trait SpecFold: Iterator {
fn fold<B, F>(self, init: B, f: F) -> B
where
Self: Sized,
F: FnMut(B, Self::Item) -> B;
}

impl<I, const N: usize> SpecFold for ArrayChunks<I, N>
where
I: Iterator,
{
#[inline]
default fn fold<B, F>(mut self, init: B, mut f: F) -> B
where
Self: Sized,
F: FnMut(B, Self::Item) -> B,
{
let fold = ConstFnMutClosure::new(&mut f, NeverShortCircuit::wrap_mut_2_imp);
self.try_fold(init, fold).0
}
}

impl<I, const N: usize> SpecFold for ArrayChunks<I, N>
where
I: Iterator + TrustedRandomAccessNoCoerce,
{
#[inline]
fn fold<B, F>(mut self, init: B, mut f: F) -> B
where
Self: Sized,
F: FnMut(B, Self::Item) -> B,
{
let mut accum = init;
let inner_len = self.iter.size();
let mut i = 0;
// Use a while loop because (0..len).step_by(N) doesn't optimize well.
while inner_len - i >= N {
let mut chunk = MaybeUninit::uninit_array();
let mut guard = array::Guard { array_mut: &mut chunk, initialized: 0 };
while guard.initialized < N {
// SAFETY: The method consumes the iterator and the loop condition ensures that
// all accesses are in bounds and only happen once.
unsafe {
let idx = i + guard.initialized;
guard.push_unchecked(self.iter.__iterator_get_unchecked(idx));
}
}
mem::forget(guard);
// SAFETY: The loop above initialized all elements
let chunk = unsafe { MaybeUninit::array_assume_init(chunk) };
accum = f(accum, chunk);
i += N;
}

// unlike try_fold this method does not need to take care of the remainder
// since `self` will be dropped

accum
}
}
3 changes: 2 additions & 1 deletion library/core/tests/iter/adapters/array_chunks.rs
Original file line number Diff line number Diff line change
Expand Up @@ -139,7 +139,8 @@ fn test_iterator_array_chunks_fold() {
let result =
(0..10).map(|_| CountDrop::new(&count)).array_chunks::<3>().fold(0, |acc, _item| acc + 1);
assert_eq!(result, 3);
assert_eq!(count.get(), 10);
// fold impls may or may not process the remainder
assert!(count.get() <= 10 && count.get() >= 9);
}

#[test]
Expand Down