Skip to content

Commit

Permalink
QueryParIter::map_collect and similar operations
Browse files Browse the repository at this point in the history
  • Loading branch information
stepancheg committed Jan 7, 2024
1 parent 101037d commit c902a13
Show file tree
Hide file tree
Showing 3 changed files with 168 additions and 60 deletions.
48 changes: 0 additions & 48 deletions crates/bevy_ecs/src/query/iter.rs
Original file line number Diff line number Diff line change
Expand Up @@ -40,54 +40,6 @@ impl<'w, 's, D: QueryData, F: QueryFilter> QueryIter<'w, 's, D, F> {
}
}

/// Executes the equivalent of [`Iterator::for_each`] over a contiguous segment
/// from an table.
///
/// # Safety
/// - all `rows` must be in `[0, table.entity_count)`.
/// - `table` must match D and F
/// - Both `D::IS_DENSE` and `F::IS_DENSE` must be true.
#[inline]
#[cfg(all(not(target = "wasm32"), feature = "multi-threaded"))]
pub(super) unsafe fn for_each_in_table_range<Func>(
&mut self,
func: &mut Func,
table: &'w Table,
rows: Range<usize>,
) where
Func: FnMut(D::Item<'w>),
{
// SAFETY: Caller assures that D::IS_DENSE and F::IS_DENSE are true, that table matches D and F
// and all indicies in rows are in range.
unsafe {
self.fold_over_table_range((), &mut |_, item| func(item), table, rows);
}
}

/// Executes the equivalent of [`Iterator::for_each`] over a contiguous segment
/// from an archetype.
///
/// # Safety
/// - all `indices` must be in `[0, archetype.len())`.
/// - `archetype` must match D and F
/// - Either `D::IS_DENSE` or `F::IS_DENSE` must be false.
#[inline]
#[cfg(all(not(target = "wasm32"), feature = "multi-threaded"))]
pub(super) unsafe fn for_each_in_archetype_range<Func>(
&mut self,
func: &mut Func,
archetype: &'w Archetype,
rows: Range<usize>,
) where
Func: FnMut(D::Item<'w>),
{
// SAFETY: Caller assures that either D::IS_DENSE or F::IS_DENSE are false, that archetype matches D and F
// and all indices in rows are in range.
unsafe {
self.fold_over_archetype_range((), &mut |_, item| func(item), archetype, rows);
}
}

/// Executes the equivalent of [`Iterator::fold`] over a contiguous segment
/// from an table.
///
Expand Down
162 changes: 156 additions & 6 deletions crates/bevy_ecs/src/query/par_iter.rs
Original file line number Diff line number Diff line change
Expand Up @@ -109,6 +109,114 @@ impl<'w, 's, D: QueryData, F: QueryFilter> QueryParIter<'w, 's, D, F> {
/// [`ComputeTaskPool`]: bevy_tasks::ComputeTaskPool
#[inline]
pub fn for_each<FN: Fn(QueryItem<'w, D>) + Send + Sync + Clone>(self, func: FN) {
self.fold_impl(move |(), x| func(x));
}

/// Run `func` on each query result in parallel, collecting the result.
///
/// This function output is deterministic. The function may be a bit expensive because
/// it allocates a `Vec` for each batch.
///
/// # Panics
/// If the [`ComputeTaskPool`] is not initialized. If using this from a query that is being
/// initialized and run from the ECS scheduler, this should never panic.
///
/// [`ComputeTaskPool`]: bevy_tasks::ComputeTaskPool
pub fn map_collect<R, C, FN>(self, func: FN) -> C
where
R: Send + 'static,
C: FromIterator<R>,
FN: Fn(QueryItem<'w, D>) -> R + Send + Sync + Clone,
{
self.filter_map_collect(move |x| Some(func(x)))
}

/// Run `func` on each query result in parallel, collecting the result.
///
/// This function output is deterministic. The function may be a bit expensive because
/// it allocates a `Vec` for each batch.
///
/// # Panics
/// If the [`ComputeTaskPool`] is not initialized. If using this from a query that is being
/// initialized and run from the ECS scheduler, this should never panic.
///
/// [`ComputeTaskPool`]: bevy_tasks::ComputeTaskPool
pub fn filter_map_collect<R, C, FN>(self, func: FN) -> C
where
R: Send + 'static,
C: FromIterator<R>,
FN: Fn(QueryItem<'w, D>) -> Option<R> + Send + Sync + Clone,
{
self.flat_map_collect(func)
}

/// Run `func` on each query result in parallel, collecting the result.
///
/// This function output is deterministic. The function may be a bit expensive because
/// it allocates a `Vec` for each batch.
///
/// # Panics
/// If the [`ComputeTaskPool`] is not initialized. If using this from a query that is being
/// initialized and run from the ECS scheduler, this should never panic.
///
/// [`ComputeTaskPool`]: bevy_tasks::ComputeTaskPool
pub fn flat_map_collect<R, I, C, FN>(self, func: FN) -> C
where
R: Send + 'static,
I: IntoIterator<Item = R>,
C: FromIterator<R>,
FN: Fn(QueryItem<'w, D>) -> I + Send + Sync + Clone,
{
let vecs = self.fold_impl::<Vec<R>, _>(move |mut acc, x| {
acc.extend(func(x));
acc
});

// Compute total length. Because collect will likely want to reserve capacity,
// it is cheaper to do sum the lengths here than collect which would have to
// resize multiple times and over-allocate.
let mut len = Some(0usize);
for vec in &vecs {
len = len.and_then(|len| len.checked_add(vec.len()));
}

// Override the size hint.
struct IterWithSizeHint<I: Iterator> {
iter: I,
rem: usize,
}

impl<I: Iterator> Iterator for IterWithSizeHint<I> {
type Item = I::Item;

fn next(&mut self) -> Option<Self::Item> {
let next = self.iter.next()?;
self.rem -= 1;
Some(next)
}

fn size_hint(&self) -> (usize, Option<usize>) {
(self.rem, Some(self.rem))
}
}

match len {
Some(len) => IterWithSizeHint {
iter: vecs.into_iter().flatten(),
rem: len,
}
.collect(),
None => vecs.into_iter().flatten().collect(),
}
}

/// Common implementation of `for_each` and `filter_map_collect`.
#[inline]
fn fold_impl<B, FN>(self, func: FN) -> Vec<B>
where
B: Default + Send + 'static,
FN: Fn(B, QueryItem<'w, D>) -> B + Send + Sync + Clone,
{
#[cfg(any(target = "wasm32", not(feature = "multi-threaded")))]
{
// SAFETY:
Expand All @@ -118,9 +226,11 @@ impl<'w, 's, D: QueryData, F: QueryFilter> QueryParIter<'w, 's, D, F> {
// Query or a World, which ensures that multiple aliasing QueryParIters cannot exist
// at the same time.
unsafe {
self.state
let one = self
.state
.iter_unchecked_manual(self.world, self.last_run, self.this_run)
.for_each(func);
.fold(B::default(), func);
vec![one]
}
}
#[cfg(all(not(target = "wasm32"), feature = "multi-threaded"))]
Expand All @@ -129,22 +239,24 @@ impl<'w, 's, D: QueryData, F: QueryFilter> QueryParIter<'w, 's, D, F> {
if thread_count <= 1 {
// SAFETY: See the safety comment above.
unsafe {
self.state
let one = self
.state
.iter_unchecked_manual(self.world, self.last_run, self.this_run)
.for_each(func);
.fold(B::default(), func);
vec![one]
}
} else {
// Need a batch size of at least 1.
let batch_size = self.get_batch_size(thread_count).max(1);
// SAFETY: See the safety comment above.
unsafe {
self.state.par_for_each_unchecked_manual(
self.state.par_fold_unchecked_manual(
self.world,
batch_size,
func,
self.last_run,
self.this_run,
);
)
}
}
}
Expand Down Expand Up @@ -188,3 +300,41 @@ impl<'w, 's, D: QueryData, F: QueryFilter> QueryParIter<'w, 's, D, F> {
)
}
}

#[cfg(test)]
mod tests {
use crate as bevy_ecs;
use crate::entity::Entity;
use crate::query::With;
use crate::world::World;
use bevy_ecs_macros::Component;
use bevy_tasks::{ComputeTaskPool, TaskPool};

#[test]
fn test_map_collect() {
ComputeTaskPool::get_or_init(TaskPool::default);

#[derive(Component)]
struct ComponentA(usize);
#[derive(Component)]
struct ComponentB(usize);

let mut world = World::default();

for i in 0..100 {
if i % 2 == 0 {
world.spawn(ComponentA(i));
} else {
world.spawn((ComponentA(i), ComponentB(i)));
}
}

let mut state = world.query_filtered::<Entity, With<ComponentA>>();
let entities: Vec<Entity> = state.iter(&mut world).collect();

let mut state = world.query_filtered::<Entity, With<ComponentA>>();
let par_entities: Vec<Entity> = state.par_iter(&world).map_collect(|x| x);

assert_eq!(par_entities, entities);
}
}
18 changes: 12 additions & 6 deletions crates/bevy_ecs/src/query/state.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1101,17 +1101,18 @@ impl<D: QueryData, F: QueryFilter> QueryState<D, F> {
///
/// [`ComputeTaskPool`]: bevy_tasks::ComputeTaskPool
#[cfg(all(not(target = "wasm32"), feature = "multi-threaded"))]
pub(crate) unsafe fn par_for_each_unchecked_manual<
pub(crate) unsafe fn par_fold_unchecked_manual<
'w,
FN: Fn(D::Item<'w>) + Send + Sync + Clone,
B: Default + Send + 'static,
FN: Fn(B, D::Item<'w>) -> B + Send + Sync + Clone,
>(
&self,
world: UnsafeWorldCell<'w>,
batch_size: usize,
func: FN,
last_run: Tick,
this_run: Tick,
) {
) -> Vec<B> {
// NOTE: If you are changing query iteration code, remember to update the following places, where relevant:
// QueryIter, QueryIterationCursor, QueryManyIter, QueryCombinationIter, QueryState::for_each_unchecked_manual, QueryState::par_for_each_unchecked_manual
bevy_tasks::ComputeTaskPool::get().scope(|scope| {
Expand All @@ -1138,7 +1139,7 @@ impl<D: QueryData, F: QueryFilter> QueryState<D, F> {
.debug_checked_unwrap();
let batch = offset..offset + len;
self.iter_unchecked_manual(world, last_run, this_run)
.for_each_in_table_range(&mut func, table, batch);
.fold_over_table_range(B::default(), &mut func, table, batch)
});
offset += batch_size;
}
Expand All @@ -1162,13 +1163,18 @@ impl<D: QueryData, F: QueryFilter> QueryState<D, F> {
world.archetypes().get(*archetype_id).debug_checked_unwrap();
let batch = offset..offset + len;
self.iter_unchecked_manual(world, last_run, this_run)
.for_each_in_archetype_range(&mut func, archetype, batch);
.fold_over_archetype_range(
B::default(),
&mut func,
archetype,
batch,
)
});
offset += batch_size;
}
}
}
});
})
}

/// Returns a single immutable query result when there is exactly one entity matching
Expand Down

0 comments on commit c902a13

Please sign in to comment.