diff --git a/crates/bevy_ecs/src/query/fetch.rs b/crates/bevy_ecs/src/query/fetch.rs index 91f423be28725..6a07ae009e21f 100644 --- a/crates/bevy_ecs/src/query/fetch.rs +++ b/crates/bevy_ecs/src/query/fetch.rs @@ -238,6 +238,18 @@ pub struct ReadFetch { sparse_set: *const ComponentSparseSet, } +impl Clone for ReadFetch { + fn clone(&self) -> Self { + Self { + storage_type: self.storage_type, + table_components: self.table_components, + entity_table_rows: self.entity_table_rows, + entities: self.entities, + sparse_set: self.sparse_set, + } + } +} + /// SAFE: access is read only unsafe impl ReadOnlyFetch for ReadFetch {} @@ -340,6 +352,21 @@ pub struct WriteFetch { change_tick: u32, } +impl Clone for WriteFetch { + fn clone(&self) -> Self { + Self { + storage_type: self.storage_type, + table_components: self.table_components, + table_ticks: self.table_ticks, + entities: self.entities, + entity_table_rows: self.entity_table_rows, + sparse_set: self.sparse_set, + last_change_tick: self.last_change_tick, + change_tick: self.change_tick, + } + } +} + pub struct WriteState { component_id: ComponentId, storage_type: StorageType, diff --git a/crates/bevy_ecs/src/query/iter.rs b/crates/bevy_ecs/src/query/iter.rs index 4878d349b4352..104d0cad55c84 100644 --- a/crates/bevy_ecs/src/query/iter.rs +++ b/crates/bevy_ecs/src/query/iter.rs @@ -4,6 +4,7 @@ use crate::{ storage::{TableId, Tables}, world::World, }; +use std::mem::MaybeUninit; pub struct QueryIter<'w, 's, Q: WorldQuery, F: WorldQuery> where @@ -13,13 +14,7 @@ where archetypes: &'w Archetypes, query_state: &'s QueryState, world: &'w World, - table_id_iter: std::slice::Iter<'s, TableId>, - archetype_id_iter: std::slice::Iter<'s, ArchetypeId>, - fetch: Q::Fetch, - filter: F::Fetch, - is_dense: bool, - current_len: usize, - current_index: usize, + cursor: QueryIterationCursor<'s, Q, F>, } impl<'w, 's, Q: WorldQuery, F: WorldQuery> QueryIter<'w, 's, Q, F> @@ -32,30 +27,12 @@ where last_change_tick: u32, change_tick: u32, ) -> Self { - let fetch = ::init( - world, - &query_state.fetch_state, - last_change_tick, - change_tick, - ); - let filter = ::init( - world, - &query_state.filter_state, - last_change_tick, - change_tick, - ); QueryIter { - is_dense: fetch.is_dense() && filter.is_dense(), world, query_state, - fetch, - filter, tables: &world.storages().tables, archetypes: &world.archetypes, - table_id_iter: query_state.matched_table_ids.iter(), - archetype_id_iter: query_state.matched_archetype_ids.iter(), - current_len: 0, - current_index: 0, + cursor: QueryIterationCursor::init(world, query_state, last_change_tick, change_tick), } } } @@ -69,58 +46,133 @@ where #[inline] fn next(&mut self) -> Option { unsafe { - if self.is_dense { - loop { - if self.current_index == self.current_len { - let table_id = self.table_id_iter.next()?; - let table = &self.tables[*table_id]; - self.fetch.set_table(&self.query_state.fetch_state, table); - self.filter.set_table(&self.query_state.filter_state, table); - self.current_len = table.len(); - self.current_index = 0; - continue; - } + self.cursor + .next(&self.tables, &self.archetypes, &self.query_state) + } + } - if !self.filter.table_filter_fetch(self.current_index) { - self.current_index += 1; - continue; - } + // NOTE: For unfiltered Queries this should actually return a exact size hint, + // to fulfil the ExactSizeIterator invariant, but this isn't practical without specialization. + // For more information see Issue #1686. + fn size_hint(&self) -> (usize, Option) { + let max_size = self + .query_state + .matched_archetypes + .ones() + .map(|index| self.world.archetypes[ArchetypeId::new(index)].len()) + .sum(); + + (0, Some(max_size)) + } +} - let item = self.fetch.table_fetch(self.current_index); +pub struct QueryPermutationIter<'w, 's, Q: WorldQuery, F: WorldQuery, const K: usize> +where + F::Fetch: FilterFetch, +{ + tables: &'w Tables, + archetypes: &'w Archetypes, + query_state: &'s QueryState, + world: &'w World, + cursors: [QueryIterationCursor<'s, Q, F>; K], +} - self.current_index += 1; - return Some(item); - } - } else { - loop { - if self.current_index == self.current_len { - let archetype_id = self.archetype_id_iter.next()?; - let archetype = &self.archetypes[*archetype_id]; - self.fetch.set_archetype( - &self.query_state.fetch_state, - archetype, - self.tables, - ); - self.filter.set_archetype( - &self.query_state.filter_state, - archetype, - self.tables, - ); - self.current_len = archetype.len(); - self.current_index = 0; - continue; - } +impl<'w, 's, Q: WorldQuery, F: WorldQuery, const K: usize> QueryPermutationIter<'w, 's, Q, F, K> +where + F::Fetch: FilterFetch, +{ + pub(crate) unsafe fn new( + world: &'w World, + query_state: &'s QueryState, + last_change_tick: u32, + change_tick: u32, + ) -> Self { + // Initialize array with cursors. + // There is no FromIterator on arrays, so instead initialize it manually with MaybeUninit - if !self.filter.archetype_filter_fetch(self.current_index) { - self.current_index += 1; - continue; - } + // MaybeUninit::uninit_array is unstable + let mut cursors: [MaybeUninit>; K] = + MaybeUninit::uninit().assume_init(); + for (i, cursor) in cursors.iter_mut().enumerate() { + match i { + 0 => cursor.as_mut_ptr().write(QueryIterationCursor::init( + world, + query_state, + last_change_tick, + change_tick, + )), + _ => cursor.as_mut_ptr().write(QueryIterationCursor::init_empty( + world, + query_state, + last_change_tick, + change_tick, + )), + } + } - let item = self.fetch.archetype_fetch(self.current_index); - self.current_index += 1; - return Some(item); + // MaybeUninit::array_assume_init is unstable + let cursors: [QueryIterationCursor<'s, Q, F>; K] = + (&cursors as *const _ as *const [QueryIterationCursor<'s, Q, F>; K]).read(); + + QueryPermutationIter { + world, + query_state, + tables: &world.storages().tables, + archetypes: &world.archetypes, + cursors, + } + } +} + +impl<'w, 's, Q: WorldQuery, F: WorldQuery, const K: usize> Iterator + for QueryPermutationIter<'w, 's, Q, F, K> +where + F::Fetch: FilterFetch, + Q::Fetch: Clone, + F::Fetch: Clone, +{ + type Item = [>::Item; K]; + + #[inline] + fn next(&mut self) -> Option { + unsafe { + // first, iterate from last to first until next item is found + 'outer: for i in (0..K).rev() { + match self.cursors[i].next(&self.tables, &self.archetypes, &self.query_state) { + Some(_) => { + // walk forward up to last element, propagating cursor state forward + for j in (i + 1)..K { + self.cursors[j] = self.cursors[j - 1].clone(); + match self.cursors[j].next( + &self.tables, + &self.archetypes, + &self.query_state, + ) { + Some(_) => {} + None if i > 0 => continue 'outer, + None => return None, + } + } + break; + } + None if i > 0 => continue, + None => return None, } } + + // MaybeUninit::uninit_array is unstable + let mut values: [MaybeUninit<>::Item>; K] = + MaybeUninit::uninit().assume_init(); + + for (value, cursor) in values.iter_mut().zip(&mut self.cursors) { + value.as_mut_ptr().write(cursor.peek_last().unwrap()); + } + + // MaybeUninit::array_assume_init is unstable + let values: [>::Item; K] = + (&values as *const _ as *const [>::Item; K]).read(); + + Some(values) } } @@ -128,14 +180,19 @@ where // to fulfil the ExactSizeIterator invariant, but this isn't practical without specialization. // For more information see Issue #1686. fn size_hint(&self) -> (usize, Option) { - let max_size = self + let max_size: usize = self .query_state .matched_archetypes .ones() .map(|index| self.world.archetypes[ArchetypeId::new(index)].len()) .sum(); - (0, Some(max_size)) + // n! / k!(n-k)! = (n*n-1*...*n-k+1) / k! + let k_factorial: usize = (1..=K).product(); + let max_permutations = + (0..K).fold(1, |n, i| n * (max_size.saturating_sub(i))) / k_factorial; + + (0, Some(max_permutations)) } } @@ -153,3 +210,143 @@ impl<'w, 's, Q: WorldQuery> ExactSizeIterator for QueryIter<'w, 's, Q, ()> { .sum() } } + +struct QueryIterationCursor<'s, Q: WorldQuery, F: WorldQuery> { + table_id_iter: std::slice::Iter<'s, TableId>, + archetype_id_iter: std::slice::Iter<'s, ArchetypeId>, + fetch: Q::Fetch, + filter: F::Fetch, + current_len: usize, + current_index: usize, + is_dense: bool, +} + +impl<'s, Q: WorldQuery, F: WorldQuery> Clone for QueryIterationCursor<'s, Q, F> +where + Q::Fetch: Clone, + F::Fetch: Clone, +{ + fn clone(&self) -> Self { + Self { + table_id_iter: self.table_id_iter.clone(), + archetype_id_iter: self.archetype_id_iter.clone(), + fetch: self.fetch.clone(), + filter: self.filter.clone(), + current_len: self.current_len, + current_index: self.current_index, + is_dense: self.is_dense, + } + } +} + +impl<'s, Q: WorldQuery, F: WorldQuery> QueryIterationCursor<'s, Q, F> +where + F::Fetch: FilterFetch, +{ + unsafe fn init_empty( + world: &World, + query_state: &'s QueryState, + last_change_tick: u32, + change_tick: u32, + ) -> Self { + QueryIterationCursor { + table_id_iter: [].iter(), + archetype_id_iter: [].iter(), + ..Self::init(world, query_state, last_change_tick, change_tick) + } + } + + unsafe fn init( + world: &World, + query_state: &'s QueryState, + last_change_tick: u32, + change_tick: u32, + ) -> Self { + let fetch = ::init( + world, + &query_state.fetch_state, + last_change_tick, + change_tick, + ); + let filter = ::init( + world, + &query_state.filter_state, + last_change_tick, + change_tick, + ); + QueryIterationCursor { + is_dense: fetch.is_dense() && filter.is_dense(), + fetch, + filter, + table_id_iter: query_state.matched_table_ids.iter(), + archetype_id_iter: query_state.matched_archetype_ids.iter(), + current_len: 0, + current_index: 0, + } + } + + /// retreive last returned item again + #[inline] + unsafe fn peek_last<'w>(&mut self) -> Option<>::Item> { + if self.current_index > 0 { + Some(self.fetch.table_fetch(self.current_index - 1)) + } else { + None + } + } + + #[inline] + unsafe fn next<'w>( + &mut self, + tables: &'w Tables, + archetypes: &'w Archetypes, + query_state: &'s QueryState, + ) -> Option<>::Item> { + if self.is_dense { + loop { + if self.current_index == self.current_len { + let table_id = self.table_id_iter.next()?; + let table = &tables[*table_id]; + self.fetch.set_table(&query_state.fetch_state, table); + self.filter.set_table(&query_state.filter_state, table); + self.current_len = table.len(); + self.current_index = 0; + continue; + } + + if !self.filter.table_filter_fetch(self.current_index) { + self.current_index += 1; + continue; + } + + let item = self.fetch.table_fetch(self.current_index); + + self.current_index += 1; + return Some(item); + } + } else { + loop { + if self.current_index == self.current_len { + let archetype_id = self.archetype_id_iter.next()?; + let archetype = &archetypes[*archetype_id]; + self.fetch + .set_archetype(&query_state.fetch_state, archetype, tables); + self.filter + .set_archetype(&query_state.filter_state, archetype, tables); + self.current_len = archetype.len(); + self.current_index = 0; + continue; + } + + if !self.filter.archetype_filter_fetch(self.current_index) { + self.current_index += 1; + continue; + } + + let item = self.fetch.archetype_fetch(self.current_index); + self.current_index += 1; + return Some(item); + } + } + } +} diff --git a/crates/bevy_ecs/src/query/mod.rs b/crates/bevy_ecs/src/query/mod.rs index 95b4d7b6e8912..6a5d20b5a8c06 100644 --- a/crates/bevy_ecs/src/query/mod.rs +++ b/crates/bevy_ecs/src/query/mod.rs @@ -37,6 +37,64 @@ mod tests { assert_eq!(values, vec![&B(3)]); } + #[test] + fn query_k_iter() { + let mut world = World::new(); + world.spawn().insert_bundle((A(1), B(1))); + world.spawn().insert_bundle((A(2),)); + world.spawn().insert_bundle((A(3),)); + world.spawn().insert_bundle((A(4),)); + + let size = world.query::<&A>().k_iter::<2>(&world).size_hint(); + assert_eq!(size.1, Some(6)); + let values: Vec<[&A; 2]> = world.query::<&A>().k_iter(&world).collect(); + assert_eq!( + values, + vec![ + [&A(1), &A(2)], + [&A(1), &A(3)], + [&A(1), &A(4)], + [&A(2), &A(3)], + [&A(2), &A(4)], + [&A(3), &A(4)], + ] + ); + let size = world.query::<&A>().k_iter::<3>(&world).size_hint(); + assert_eq!(size.1, Some(4)); + let values: Vec<[&A; 3]> = world.query::<&A>().k_iter(&world).collect(); + assert_eq!( + values, + vec![ + [&A(1), &A(2), &A(3)], + [&A(1), &A(2), &A(4)], + [&A(1), &A(3), &A(4)], + [&A(2), &A(3), &A(4)], + ] + ); + + for [mut a, mut b, mut c] in world.query::<&mut A>().k_iter_mut(&mut world) { + a.0 += 10; + b.0 += 100; + c.0 += 1000; + } + + let values: Vec<[&A; 3]> = world.query::<&A>().k_iter(&world).collect(); + assert_eq!( + values, + vec![ + [&A(31), &A(212), &A(1203)], + [&A(31), &A(212), &A(3004)], + [&A(31), &A(1203), &A(3004)], + [&A(212), &A(1203), &A(3004)], + ] + ); + + let size = world.query::<&B>().k_iter::<2>(&world).size_hint(); + assert_eq!(size.1, Some(0)); + let values: Vec<[&B; 2]> = world.query::<&B>().k_iter(&world).collect(); + assert_eq!(values, Vec::<[&B; 2]>::new()); + } + #[test] fn multi_storage_query() { let mut world = World::new(); diff --git a/crates/bevy_ecs/src/query/state.rs b/crates/bevy_ecs/src/query/state.rs index 07295430105cd..7eae29adbf819 100644 --- a/crates/bevy_ecs/src/query/state.rs +++ b/crates/bevy_ecs/src/query/state.rs @@ -3,8 +3,8 @@ use crate::{ component::ComponentId, entity::Entity, query::{ - Access, Fetch, FetchState, FilterFetch, FilteredAccess, QueryIter, ReadOnlyFetch, - WorldQuery, + Access, Fetch, FetchState, FilterFetch, FilteredAccess, QueryIter, QueryPermutationIter, + ReadOnlyFetch, WorldQuery, }, storage::TableId, world::{World, WorldId}, @@ -190,6 +190,27 @@ where unsafe { self.iter_unchecked(world) } } + #[inline] + pub fn k_iter<'w, 's, const K: usize>( + &'s mut self, + world: &'w World, + ) -> QueryPermutationIter<'w, 's, Q, F, K> + where + Q::Fetch: ReadOnlyFetch, + { + // SAFE: query is read only + unsafe { self.k_iter_unchecked(world) } + } + + #[inline] + pub fn k_iter_mut<'w, 's, const K: usize>( + &'s mut self, + world: &'w mut World, + ) -> QueryPermutationIter<'w, 's, Q, F, K> { + // SAFE: query has unique world access + unsafe { self.k_iter_unchecked(world) } + } + /// # Safety /// This does not check for mutable query correctness. To be safe, make sure mutable queries /// have unique access to the components they query. @@ -202,6 +223,18 @@ where self.iter_unchecked_manual(world, world.last_change_tick(), world.read_change_tick()) } + /// # Safety + /// This does not check for mutable query correctness. To be safe, make sure mutable queries + /// have unique access to the components they query. + #[inline] + pub unsafe fn k_iter_unchecked<'w, 's, const K: usize>( + &'s mut self, + world: &'w World, + ) -> QueryPermutationIter<'w, 's, Q, F, K> { + self.validate_world_and_update_archetypes(world); + self.k_iter_unchecked_manual(world, world.last_change_tick(), world.read_change_tick()) + } + /// # Safety /// This does not check for mutable query correctness. To be safe, make sure mutable queries /// have unique access to the components they query. @@ -217,6 +250,21 @@ where QueryIter::new(world, self, last_change_tick, change_tick) } + /// # Safety + /// This does not check for mutable query correctness. To be safe, make sure mutable queries + /// have unique access to the components they query. + /// This does not validate that `world.id()` matches `self.world_id`. Calling this on a `world` + /// with a mismatched WorldId is unsafe. + #[inline] + pub(crate) unsafe fn k_iter_unchecked_manual<'w, 's, const K: usize>( + &'s self, + world: &'w World, + last_change_tick: u32, + change_tick: u32, + ) -> QueryPermutationIter<'w, 's, Q, F, K> { + QueryPermutationIter::new(world, self, last_change_tick, change_tick) + } + #[inline] pub fn for_each<'w>( &mut self, diff --git a/crates/bevy_ecs/src/system/query.rs b/crates/bevy_ecs/src/system/query.rs index 11d4ae2d45534..ebe4d5d9ff7f5 100644 --- a/crates/bevy_ecs/src/system/query.rs +++ b/crates/bevy_ecs/src/system/query.rs @@ -2,7 +2,8 @@ use crate::{ component::Component, entity::Entity, query::{ - Fetch, FilterFetch, QueryEntityError, QueryIter, QueryState, ReadOnlyFetch, WorldQuery, + Fetch, FilterFetch, QueryEntityError, QueryIter, QueryPermutationIter, QueryState, + ReadOnlyFetch, WorldQuery, }, world::{Mut, World}, }; @@ -57,6 +58,32 @@ where } } + /// Iterates over all possible combinations of `K` query results without repetition. + /// This can only be called for read-only queries + /// + /// When you ask for permutations of size K of query returning N results, you will get: + /// - if K == N: one result of all results + /// - if K < N: all possible subsets of N with size K, without repetition + /// - if K > N: empty set (no permutation of size K exist) + /// + /// ``` + /// + /// + /// + /// ``` + #[inline] + pub fn k_iter(&self) -> QueryPermutationIter<'_, '_, Q, F, K> + where + Q::Fetch: ReadOnlyFetch, + { + // SAFE: system runs without conflicts with other systems. + // same-system queries have runtime borrow checks when they conflict + unsafe { + self.state + .k_iter_unchecked_manual(self.world, self.last_change_tick, self.change_tick) + } + } + /// Iterates over the query results #[inline] pub fn iter_mut(&mut self) -> QueryIter<'_, '_, Q, F> { @@ -68,6 +95,18 @@ where } } + /// Iterates over all possible combinations of `K` query results without repetition. + /// See [`Query::k_iter`]. + #[inline] + pub fn k_iter_mut(&mut self) -> QueryPermutationIter<'_, '_, Q, F, K> { + // SAFE: system runs without conflicts with other systems. + // same-system queries have runtime borrow checks when they conflict + unsafe { + self.state + .k_iter_unchecked_manual(self.world, self.last_change_tick, self.change_tick) + } + } + /// Iterates over the query results /// /// # Safety @@ -81,6 +120,20 @@ where .iter_unchecked_manual(self.world, self.last_change_tick, self.change_tick) } + /// Iterates over all possible combinations of `K` query results without repetition. + /// See [`Query::k_iter`]. + /// + /// # Safety + /// This allows aliased mutability. You must make sure this call does not result in multiple + /// mutable references to the same component + #[inline] + pub unsafe fn k_iter_unsafe(&self) -> QueryPermutationIter<'_, '_, Q, F, K> { + // SEMI-SAFE: system runs without conflicts with other systems. + // same-system queries have runtime borrow checks when they conflict + self.state + .k_iter_unchecked_manual(self.world, self.last_change_tick, self.change_tick) + } + /// Runs `f` on each query result. This is faster than the equivalent iter() method, but cannot /// be chained like a normal iterator. This can only be called for read-only queries #[inline]