From c902a13a76a553fa286ab1e9b629b8ba8a1deab2 Mon Sep 17 00:00:00 2001 From: Stepan Koltsov Date: Sun, 7 Jan 2024 19:07:06 +0000 Subject: [PATCH] QueryParIter::map_collect and similar operations --- crates/bevy_ecs/src/query/iter.rs | 48 -------- crates/bevy_ecs/src/query/par_iter.rs | 162 +++++++++++++++++++++++++- crates/bevy_ecs/src/query/state.rs | 18 ++- 3 files changed, 168 insertions(+), 60 deletions(-) diff --git a/crates/bevy_ecs/src/query/iter.rs b/crates/bevy_ecs/src/query/iter.rs index e79371c0d7a6ea..33198c71a120d1 100644 --- a/crates/bevy_ecs/src/query/iter.rs +++ b/crates/bevy_ecs/src/query/iter.rs @@ -40,54 +40,6 @@ impl<'w, 's, D: QueryData, F: QueryFilter> QueryIter<'w, 's, D, F> { } } - /// Executes the equivalent of [`Iterator::for_each`] over a contiguous segment - /// from an table. - /// - /// # Safety - /// - all `rows` must be in `[0, table.entity_count)`. - /// - `table` must match D and F - /// - Both `D::IS_DENSE` and `F::IS_DENSE` must be true. - #[inline] - #[cfg(all(not(target = "wasm32"), feature = "multi-threaded"))] - pub(super) unsafe fn for_each_in_table_range( - &mut self, - func: &mut Func, - table: &'w Table, - rows: Range, - ) where - Func: FnMut(D::Item<'w>), - { - // SAFETY: Caller assures that D::IS_DENSE and F::IS_DENSE are true, that table matches D and F - // and all indicies in rows are in range. - unsafe { - self.fold_over_table_range((), &mut |_, item| func(item), table, rows); - } - } - - /// Executes the equivalent of [`Iterator::for_each`] over a contiguous segment - /// from an archetype. - /// - /// # Safety - /// - all `indices` must be in `[0, archetype.len())`. - /// - `archetype` must match D and F - /// - Either `D::IS_DENSE` or `F::IS_DENSE` must be false. 
- #[inline] - #[cfg(all(not(target = "wasm32"), feature = "multi-threaded"))] - pub(super) unsafe fn for_each_in_archetype_range( - &mut self, - func: &mut Func, - archetype: &'w Archetype, - rows: Range, - ) where - Func: FnMut(D::Item<'w>), - { - // SAFETY: Caller assures that either D::IS_DENSE or F::IS_DENSE are false, that archetype matches D and F - // and all indices in rows are in range. - unsafe { - self.fold_over_archetype_range((), &mut |_, item| func(item), archetype, rows); - } - } - /// Executes the equivalent of [`Iterator::fold`] over a contiguous segment /// from an table. /// diff --git a/crates/bevy_ecs/src/query/par_iter.rs b/crates/bevy_ecs/src/query/par_iter.rs index 620e9175c91363..4721fe59d9c5e4 100644 --- a/crates/bevy_ecs/src/query/par_iter.rs +++ b/crates/bevy_ecs/src/query/par_iter.rs @@ -109,6 +109,114 @@ impl<'w, 's, D: QueryData, F: QueryFilter> QueryParIter<'w, 's, D, F> { /// [`ComputeTaskPool`]: bevy_tasks::ComputeTaskPool #[inline] pub fn for_each) + Send + Sync + Clone>(self, func: FN) { + self.fold_impl(move |(), x| func(x)); + } + + /// Run `func` on each query result in parallel, collecting the result. + /// + /// This function output is deterministic. The function may be a bit expensive because + /// it allocates a `Vec` for each batch. + /// + /// # Panics + /// If the [`ComputeTaskPool`] is not initialized. If using this from a query that is being + /// initialized and run from the ECS scheduler, this should never panic. + /// + /// [`ComputeTaskPool`]: bevy_tasks::ComputeTaskPool + pub fn map_collect(self, func: FN) -> C + where + R: Send + 'static, + C: FromIterator, + FN: Fn(QueryItem<'w, D>) -> R + Send + Sync + Clone, + { + self.filter_map_collect(move |x| Some(func(x))) + } + + /// Run `func` on each query result in parallel, collecting the result. + /// + /// This function output is deterministic. The function may be a bit expensive because + /// it allocates a `Vec` for each batch. 
+ /// + /// # Panics + /// If the [`ComputeTaskPool`] is not initialized. If using this from a query that is being + /// initialized and run from the ECS scheduler, this should never panic. + /// + /// [`ComputeTaskPool`]: bevy_tasks::ComputeTaskPool + pub fn filter_map_collect(self, func: FN) -> C + where + R: Send + 'static, + C: FromIterator, + FN: Fn(QueryItem<'w, D>) -> Option + Send + Sync + Clone, + { + self.flat_map_collect(func) + } + + /// Run `func` on each query result in parallel, flattening the iterators it returns and collecting their items. + /// + /// This function output is deterministic. The function may be a bit expensive because + /// it allocates a `Vec` for each batch. + /// + /// # Panics + /// If the [`ComputeTaskPool`] is not initialized. If using this from a query that is being + /// initialized and run from the ECS scheduler, this should never panic. + /// + /// [`ComputeTaskPool`]: bevy_tasks::ComputeTaskPool + pub fn flat_map_collect(self, func: FN) -> C + where + R: Send + 'static, + I: IntoIterator, + C: FromIterator, + FN: Fn(QueryItem<'w, D>) -> I + Send + Sync + Clone, + { + let vecs = self.fold_impl::, _>(move |mut acc, x| { + acc.extend(func(x)); + acc + }); + + // Compute total length. Because `collect` will likely want to reserve capacity, + // it is cheaper to sum the lengths here than to let `collect` grow on its own, which would + // resize multiple times and over-allocate. + let mut len = Some(0usize); + for vec in &vecs { + len = len.and_then(|len| len.checked_add(vec.len())); + } + + // Override the size hint.
+ struct IterWithSizeHint { + iter: I, + rem: usize, + } + + impl Iterator for IterWithSizeHint { + type Item = I::Item; + + fn next(&mut self) -> Option { + let next = self.iter.next()?; + self.rem -= 1; + Some(next) + } + + fn size_hint(&self) -> (usize, Option) { + (self.rem, Some(self.rem)) + } + } + + match len { + Some(len) => IterWithSizeHint { + iter: vecs.into_iter().flatten(), + rem: len, + } + .collect(), + None => vecs.into_iter().flatten().collect(), + } + } + + /// Common implementation of `for_each` and `flat_map_collect` (and, through it, `map_collect` and `filter_map_collect`). + #[inline] + fn fold_impl(self, func: FN) -> Vec + where + B: Default + Send + 'static, + FN: Fn(B, QueryItem<'w, D>) -> B + Send + Sync + Clone, + { #[cfg(any(target = "wasm32", not(feature = "multi-threaded")))] { // SAFETY: @@ -118,9 +226,11 @@ impl<'w, 's, D: QueryData, F: QueryFilter> QueryParIter<'w, 's, D, F> { // Query or a World, which ensures that multiple aliasing QueryParIters cannot exist // at the same time. unsafe { - self.state + let one = self + .state .iter_unchecked_manual(self.world, self.last_run, self.this_run) - .for_each(func); + .fold(B::default(), func); + vec![one] } } #[cfg(all(not(target = "wasm32"), feature = "multi-threaded"))] @@ -129,22 +239,24 @@ impl<'w, 's, D: QueryData, F: QueryFilter> QueryParIter<'w, 's, D, F> { if thread_count <= 1 { // SAFETY: See the safety comment above. unsafe { - self.state + let one = self + .state .iter_unchecked_manual(self.world, self.last_run, self.this_run) - .for_each(func); + .fold(B::default(), func); + vec![one] } } else { // Need a batch size of at least 1. let batch_size = self.get_batch_size(thread_count).max(1); // SAFETY: See the safety comment above.
unsafe { - self.state.par_for_each_unchecked_manual( + self.state.par_fold_unchecked_manual( self.world, batch_size, func, self.last_run, self.this_run, - ); + ) } } } @@ -188,3 +300,41 @@ impl<'w, 's, D: QueryData, F: QueryFilter> QueryParIter<'w, 's, D, F> { ) } } + +#[cfg(test)] +mod tests { + use crate as bevy_ecs; + use crate::entity::Entity; + use crate::query::With; + use crate::world::World; + use bevy_ecs_macros::Component; + use bevy_tasks::{ComputeTaskPool, TaskPool}; + + #[test] + fn test_map_collect() { + ComputeTaskPool::get_or_init(TaskPool::default); + + #[derive(Component)] + struct ComponentA(usize); + #[derive(Component)] + struct ComponentB(usize); + + let mut world = World::default(); + + for i in 0..100 { + if i % 2 == 0 { + world.spawn(ComponentA(i)); + } else { + world.spawn((ComponentA(i), ComponentB(i))); + } + } + + let mut state = world.query_filtered::>(); + let entities: Vec = state.iter(&mut world).collect(); + + let mut state = world.query_filtered::>(); + let par_entities: Vec = state.par_iter(&world).map_collect(|x| x); + + assert_eq!(par_entities, entities); + } +} diff --git a/crates/bevy_ecs/src/query/state.rs b/crates/bevy_ecs/src/query/state.rs index e459baf7a1a754..d53a54b92338fd 100644 --- a/crates/bevy_ecs/src/query/state.rs +++ b/crates/bevy_ecs/src/query/state.rs @@ -1101,9 +1101,10 @@ impl QueryState { /// /// [`ComputeTaskPool`]: bevy_tasks::ComputeTaskPool #[cfg(all(not(target = "wasm32"), feature = "multi-threaded"))] - pub(crate) unsafe fn par_for_each_unchecked_manual< + pub(crate) unsafe fn par_fold_unchecked_manual< 'w, - FN: Fn(D::Item<'w>) + Send + Sync + Clone, + B: Default + Send + 'static, + FN: Fn(B, D::Item<'w>) -> B + Send + Sync + Clone, >( &self, world: UnsafeWorldCell<'w>, @@ -1111,7 +1112,7 @@ impl QueryState { func: FN, last_run: Tick, this_run: Tick, - ) { + ) -> Vec { // NOTE: If you are changing query iteration code, remember to update the following places, where relevant: // QueryIter, 
QueryIterationCursor, QueryManyIter, QueryCombinationIter, QueryState::for_each_unchecked_manual, QueryState::par_for_each_unchecked_manual bevy_tasks::ComputeTaskPool::get().scope(|scope| { @@ -1138,7 +1139,7 @@ impl QueryState { .debug_checked_unwrap(); let batch = offset..offset + len; self.iter_unchecked_manual(world, last_run, this_run) - .for_each_in_table_range(&mut func, table, batch); + .fold_over_table_range(B::default(), &mut func, table, batch) }); offset += batch_size; } @@ -1162,13 +1163,18 @@ impl QueryState { world.archetypes().get(*archetype_id).debug_checked_unwrap(); let batch = offset..offset + len; self.iter_unchecked_manual(world, last_run, this_run) - .for_each_in_archetype_range(&mut func, archetype, batch); + .fold_over_archetype_range( + B::default(), + &mut func, + archetype, + batch, + ) }); offset += batch_size; } } } - }); + }) } /// Returns a single immutable query result when there is exactly one entity matching