From 0b8c3ede608df533b0fd67c1b16e760c0e16c340 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Pawe=C5=82=20Grabarz?= Date: Mon, 17 May 2021 23:28:06 +0000 Subject: [PATCH] Add a method `iter_combinations` on query to iterate over combinations of query results (#1763) Related to [discussion on discord](https://discord.com/channels/691052431525675048/742569353878437978/824731187724681289) With const generics, it is now possible to write generic iterator over multiple entities at once. This enables patterns of query iterations like ```rust for [e1, e2, e3] in query.iter_combinations() { // do something with relation of all three entities } ``` The compiler is able to infer the correct iterator for given size of array, so either of those work ```rust for [e1, e2] in query.iter_combinations() { ... } for [e1, e2, e3] in query.iter_combinations() { ... } ``` This feature can be very useful for systems like collision detection. When you ask for permutations of size K of N entities: - if K == N, you get one result of all entities - if K < N, you get all possible subsets of N with size K, without repetition - if K > N, the result set is empty (no permutation of size K exist) Co-authored-by: Carter Anderson --- Cargo.toml | 4 + crates/bevy_ecs/src/query/fetch.rs | 27 +++ crates/bevy_ecs/src/query/iter.rs | 347 +++++++++++++++++++++++++++- crates/bevy_ecs/src/query/mod.rs | 129 +++++++++++ crates/bevy_ecs/src/query/state.rs | 56 ++++- crates/bevy_ecs/src/system/query.rs | 81 ++++++- examples/ecs/iter_combinations.rs | 158 +++++++++++++ 7 files changed, 793 insertions(+), 9 deletions(-) create mode 100644 examples/ecs/iter_combinations.rs diff --git a/Cargo.toml b/Cargo.toml index 95141021a8077c..898ab0bbc3b8fd 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -267,6 +267,10 @@ path = "examples/ecs/fixed_timestep.rs" name = "hierarchy" path = "examples/ecs/hierarchy.rs" +[[example]] +name = "iter_combinations" +path = "examples/ecs/iter_combinations.rs" + [[example]] name = "parallel_query" path = "examples/ecs/parallel_query.rs" diff --git a/crates/bevy_ecs/src/query/fetch.rs b/crates/bevy_ecs/src/query/fetch.rs index 480bf3e1361371..af858cb634cc28 100644 --- a/crates/bevy_ecs/src/query/fetch.rs +++ b/crates/bevy_ecs/src/query/fetch.rs @@ -279,6 +279,18 @@ pub struct ReadFetch { sparse_set: *const ComponentSparseSet, } +impl Clone for ReadFetch { + fn clone(&self) -> Self { + Self { + storage_type: self.storage_type, + table_components: self.table_components, + entity_table_rows: self.entity_table_rows, + entities: self.entities, + sparse_set: self.sparse_set, + } + } +} + /// SAFETY: access is read only unsafe impl ReadOnlyFetch for ReadFetch {} @@ -382,6 +394,21 @@ pub struct WriteFetch { change_tick: u32, } +impl Clone for WriteFetch { + fn clone(&self) -> Self { + Self { + storage_type: self.storage_type, + table_components: self.table_components, + table_ticks: self.table_ticks, + entities: self.entities, + entity_table_rows: self.entity_table_rows, + sparse_set: self.sparse_set, + last_change_tick: self.last_change_tick, + change_tick: self.change_tick, + } + } +} + /// The [`FetchState`] of `&mut T`. pub struct WriteState { component_id: ComponentId, diff --git a/crates/bevy_ecs/src/query/iter.rs b/crates/bevy_ecs/src/query/iter.rs index 524e4ff0f32e11..876f29d9fe0e7c 100644 --- a/crates/bevy_ecs/src/query/iter.rs +++ b/crates/bevy_ecs/src/query/iter.rs @@ -1,9 +1,10 @@ use crate::{ archetype::{ArchetypeId, Archetypes}, - query::{Fetch, FilterFetch, QueryState, WorldQuery}, + query::{Fetch, FilterFetch, QueryState, ReadOnlyFetch, WorldQuery}, storage::{TableId, Tables}, world::World, }; +use std::mem::MaybeUninit; /// An [`Iterator`] over query results of a [`Query`](crate::system::Query). /// @@ -21,15 +22,20 @@ where archetype_id_iter: std::slice::Iter<'s, ArchetypeId>, fetch: Q::Fetch, filter: F::Fetch, - is_dense: bool, current_len: usize, current_index: usize, + is_dense: bool, } impl<'w, 's, Q: WorldQuery, F: WorldQuery> QueryIter<'w, 's, Q, F> where F::Fetch: FilterFetch, { + /// # Safety + /// This does not check for mutable query correctness. To be safe, make sure mutable queries + /// have unique access to the components they query. + /// This does not validate that `world.id()` matches `query_state.world_id`. Calling this on a `world` + /// with a mismatched WorldId is unsound. pub(crate) unsafe fn new( world: &'w World, query_state: &'s QueryState, @@ -48,14 +54,15 @@ where last_change_tick, change_tick, ); + QueryIter { - is_dense: fetch.is_dense() && filter.is_dense(), world, query_state, - fetch, - filter, tables: &world.storages().tables, archetypes: &world.archetypes, + is_dense: fetch.is_dense() && filter.is_dense(), + fetch, + filter, table_id_iter: query_state.matched_table_ids.iter(), archetype_id_iter: query_state.matched_archetype_ids.iter(), current_len: 0, @@ -70,7 +77,9 @@ where { type Item = >::Item; - #[inline] + // NOTE: If you are changing QueryIter code, also update QueryIterationCursor code, when relevant. + // QueryIter doesn't use QueryIterationCursor for performance reasons. See #1763 for context. + #[inline(always)] fn next(&mut self) -> Option { unsafe { if self.is_dense { @@ -143,6 +152,186 @@ where } } +pub struct QueryCombinationIter<'w, 's, Q: WorldQuery, F: WorldQuery, const K: usize> +where + F::Fetch: FilterFetch, +{ + tables: &'w Tables, + archetypes: &'w Archetypes, + query_state: &'s QueryState, + world: &'w World, + cursors: [QueryIterationCursor<'s, Q, F>; K], +} + +impl<'w, 's, Q: WorldQuery, F: WorldQuery, const K: usize> QueryCombinationIter<'w, 's, Q, F, K> +where + F::Fetch: FilterFetch, +{ + /// # Safety + /// This does not check for mutable query correctness. To be safe, make sure mutable queries + /// have unique access to the components they query. + /// This does not validate that `world.id()` matches `query_state.world_id`. Calling this on a `world` + /// with a mismatched WorldId is unsound. + pub(crate) unsafe fn new( + world: &'w World, + query_state: &'s QueryState, + last_change_tick: u32, + change_tick: u32, + ) -> Self { + // Initialize array with cursors. + // There is no FromIterator on arrays, so instead initialize it manually with MaybeUninit + + // TODO: use MaybeUninit::uninit_array if it stabilizes + let mut cursors: [MaybeUninit>; K] = + MaybeUninit::uninit().assume_init(); + for (i, cursor) in cursors.iter_mut().enumerate() { + match i { + 0 => cursor.as_mut_ptr().write(QueryIterationCursor::init( + world, + query_state, + last_change_tick, + change_tick, + )), + _ => cursor.as_mut_ptr().write(QueryIterationCursor::init_empty( + world, + query_state, + last_change_tick, + change_tick, + )), + } + } + + // TODO: use MaybeUninit::array_assume_init if it stabilizes + let cursors: [QueryIterationCursor<'s, Q, F>; K] = + (&cursors as *const _ as *const [QueryIterationCursor<'s, Q, F>; K]).read(); + + QueryCombinationIter { + world, + query_state, + tables: &world.storages().tables, + archetypes: &world.archetypes, + cursors, + } + } + + /// Safety: + /// The lifetime here is not restrictive enough for Fetch with &mut access, + /// as calling `fetch_next_aliased_unchecked` multiple times can produce multiple + /// references to the same component, leading to unique reference aliasing. + ///. + /// It is always safe for shared access. + unsafe fn fetch_next_aliased_unchecked<'a>( + &mut self, + ) -> Option<[>::Item; K]> + where + Q::Fetch: Clone, + F::Fetch: Clone, + { + if K == 0 { + return None; + } + + // first, iterate from last to first until next item is found + 'outer: for i in (0..K).rev() { + match self.cursors[i].next(&self.tables, &self.archetypes, &self.query_state) { + Some(_) => { + // walk forward up to last element, propagating cursor state forward + for j in (i + 1)..K { + self.cursors[j] = self.cursors[j - 1].clone(); + match self.cursors[j].next( + &self.tables, + &self.archetypes, + &self.query_state, + ) { + Some(_) => {} + None if i > 0 => continue 'outer, + None => return None, + } + } + break; + } + None if i > 0 => continue, + None => return None, + } + } + + // TODO: use MaybeUninit::uninit_array if it stabilizes + let mut values: [MaybeUninit<>::Item>; K] = + MaybeUninit::uninit().assume_init(); + + for (value, cursor) in values.iter_mut().zip(&mut self.cursors) { + value.as_mut_ptr().write(cursor.peek_last().unwrap()); + } + + // TODO: use MaybeUninit::array_assume_init if it stabilizes + let values: [>::Item; K] = + (&values as *const _ as *const [>::Item; K]).read(); + + Some(values) + } + + /// Get next combination of queried components + #[inline] + pub fn fetch_next(&mut self) -> Option<[>::Item; K]> + where + Q::Fetch: Clone, + F::Fetch: Clone, + { + // safety: we are limiting the returned reference to self, + // making sure this method cannot be called multiple times without getting rid + // of any previously returned unique references first, thus preventing aliasing. + unsafe { self.fetch_next_aliased_unchecked() } + } +} + +// Iterator type is intentionally implemented only for read-only access. +// Doing so for mutable references would be unsound, because calling `next` +// multiple times would allow multiple owned references to the same data to exist. +impl<'w, 's, Q: WorldQuery, F: WorldQuery, const K: usize> Iterator + for QueryCombinationIter<'w, 's, Q, F, K> +where + Q::Fetch: Clone + ReadOnlyFetch, + F::Fetch: Clone + FilterFetch + ReadOnlyFetch, +{ + type Item = [>::Item; K]; + + #[inline] + fn next(&mut self) -> Option { + // Safety: it is safe to alias for ReadOnlyFetch + unsafe { QueryCombinationIter::fetch_next_aliased_unchecked(self) } + } + + // NOTE: For unfiltered Queries this should actually return a exact size hint, + // to fulfil the ExactSizeIterator invariant, but this isn't practical without specialization. + // For more information see Issue #1686. + fn size_hint(&self) -> (usize, Option) { + if K == 0 { + return (0, Some(0)); + } + + let max_size: usize = self + .query_state + .matched_archetypes + .ones() + .map(|index| self.world.archetypes[ArchetypeId::new(index)].len()) + .sum(); + + if max_size < K { + return (0, Some(0)); + } + + // n! / k!(n-k)! = (n*n-1*...*n-k+1) / k! + let max_combinations = (0..K) + .try_fold(1usize, |n, i| n.checked_mul(max_size - i)) + .map(|n| { + let k_factorial: usize = (1..=K).product(); + n / k_factorial + }); + + (0, max_combinations) + } +} + // NOTE: We can cheaply implement this for unfiltered Queries because we have: // (1) pre-computed archetype matches // (2) each archetype pre-computes length @@ -157,3 +346,149 @@ impl<'w, 's, Q: WorldQuery> ExactSizeIterator for QueryIter<'w, 's, Q, ()> { .sum() } } + +struct QueryIterationCursor<'s, Q: WorldQuery, F: WorldQuery> { + table_id_iter: std::slice::Iter<'s, TableId>, + archetype_id_iter: std::slice::Iter<'s, ArchetypeId>, + fetch: Q::Fetch, + filter: F::Fetch, + current_len: usize, + current_index: usize, + is_dense: bool, +} + +impl<'s, Q: WorldQuery, F: WorldQuery> Clone for QueryIterationCursor<'s, Q, F> +where + Q::Fetch: Clone, + F::Fetch: Clone, +{ + fn clone(&self) -> Self { + Self { + table_id_iter: self.table_id_iter.clone(), + archetype_id_iter: self.archetype_id_iter.clone(), + fetch: self.fetch.clone(), + filter: self.filter.clone(), + current_len: self.current_len, + current_index: self.current_index, + is_dense: self.is_dense, + } + } +} + +impl<'s, Q: WorldQuery, F: WorldQuery> QueryIterationCursor<'s, Q, F> +where + F::Fetch: FilterFetch, +{ + unsafe fn init_empty( + world: &World, + query_state: &'s QueryState, + last_change_tick: u32, + change_tick: u32, + ) -> Self { + QueryIterationCursor { + table_id_iter: [].iter(), + archetype_id_iter: [].iter(), + ..Self::init(world, query_state, last_change_tick, change_tick) + } + } + + unsafe fn init( + world: &World, + query_state: &'s QueryState, + last_change_tick: u32, + change_tick: u32, + ) -> Self { + let fetch = ::init( + world, + &query_state.fetch_state, + last_change_tick, + change_tick, + ); + let filter = ::init( + world, + &query_state.filter_state, + last_change_tick, + change_tick, + ); + QueryIterationCursor { + is_dense: fetch.is_dense() && filter.is_dense(), + fetch, + filter, + table_id_iter: query_state.matched_table_ids.iter(), + archetype_id_iter: query_state.matched_archetype_ids.iter(), + current_len: 0, + current_index: 0, + } + } + + /// retrieve item returned from most recent `next` call again. + #[inline] + unsafe fn peek_last<'w>(&mut self) -> Option<>::Item> { + if self.current_index > 0 { + if self.is_dense { + Some(self.fetch.table_fetch(self.current_index - 1)) + } else { + Some(self.fetch.archetype_fetch(self.current_index - 1)) + } + } else { + None + } + } + + // NOTE: If you are changing QueryIterationCursor code, also update QueryIter code, when relevant. + // QueryIter doesn't use QueryIterationCursor for performance reasons. See #1763 for context. + #[inline(always)] + unsafe fn next<'w>( + &mut self, + tables: &'w Tables, + archetypes: &'w Archetypes, + query_state: &'s QueryState, + ) -> Option<>::Item> { + if self.is_dense { + loop { + if self.current_index == self.current_len { + let table_id = self.table_id_iter.next()?; + let table = &tables[*table_id]; + self.fetch.set_table(&query_state.fetch_state, table); + self.filter.set_table(&query_state.filter_state, table); + self.current_len = table.len(); + self.current_index = 0; + continue; + } + + if !self.filter.table_filter_fetch(self.current_index) { + self.current_index += 1; + continue; + } + + let item = self.fetch.table_fetch(self.current_index); + + self.current_index += 1; + return Some(item); + } + } else { + loop { + if self.current_index == self.current_len { + let archetype_id = self.archetype_id_iter.next()?; + let archetype = &archetypes[*archetype_id]; + self.fetch + .set_archetype(&query_state.fetch_state, archetype, tables); + self.filter + .set_archetype(&query_state.filter_state, archetype, tables); + self.current_len = archetype.len(); + self.current_index = 0; + continue; + } + + if !self.filter.archetype_filter_fetch(self.current_index) { + self.current_index += 1; + continue; + } + + let item = self.fetch.archetype_fetch(self.current_index); + self.current_index += 1; + return Some(item); + } + } + } +} diff --git a/crates/bevy_ecs/src/query/mod.rs b/crates/bevy_ecs/src/query/mod.rs index 95b4d7b6e8912a..99e140846073c5 100644 --- a/crates/bevy_ecs/src/query/mod.rs +++ b/crates/bevy_ecs/src/query/mod.rs @@ -37,6 +37,135 @@ mod tests { assert_eq!(values, vec![&B(3)]); } + #[test] + fn query_iter_combinations() { + let mut world = World::new(); + + world.spawn().insert_bundle((A(1), B(1))); + world.spawn().insert_bundle((A(2),)); + world.spawn().insert_bundle((A(3),)); + world.spawn().insert_bundle((A(4),)); + + let mut a_query = world.query::<&A>(); + assert_eq!(a_query.iter_combinations::<0>(&world).count(), 0); + assert_eq!( + a_query.iter_combinations::<0>(&world).size_hint(), + (0, Some(0)) + ); + assert_eq!(a_query.iter_combinations::<1>(&world).count(), 4); + assert_eq!( + a_query.iter_combinations::<1>(&world).size_hint(), + (0, Some(4)) + ); + assert_eq!(a_query.iter_combinations::<2>(&world).count(), 6); + assert_eq!( + a_query.iter_combinations::<2>(&world).size_hint(), + (0, Some(6)) + ); + assert_eq!(a_query.iter_combinations::<3>(&world).count(), 4); + assert_eq!( + a_query.iter_combinations::<3>(&world).size_hint(), + (0, Some(4)) + ); + assert_eq!(a_query.iter_combinations::<4>(&world).count(), 1); + assert_eq!( + a_query.iter_combinations::<4>(&world).size_hint(), + (0, Some(1)) + ); + assert_eq!(a_query.iter_combinations::<5>(&world).count(), 0); + assert_eq!( + a_query.iter_combinations::<5>(&world).size_hint(), + (0, Some(0)) + ); + assert_eq!(a_query.iter_combinations::<1024>(&world).count(), 0); + assert_eq!( + a_query.iter_combinations::<1024>(&world).size_hint(), + (0, Some(0)) + ); + + let values: Vec<[&A; 2]> = world.query::<&A>().iter_combinations(&world).collect(); + assert_eq!( + values, + vec![ + [&A(1), &A(2)], + [&A(1), &A(3)], + [&A(1), &A(4)], + [&A(2), &A(3)], + [&A(2), &A(4)], + [&A(3), &A(4)], + ] + ); + let size = a_query.iter_combinations::<3>(&world).size_hint(); + assert_eq!(size.1, Some(4)); + let values: Vec<[&A; 3]> = a_query.iter_combinations(&world).collect(); + assert_eq!( + values, + vec![ + [&A(1), &A(2), &A(3)], + [&A(1), &A(2), &A(4)], + [&A(1), &A(3), &A(4)], + [&A(2), &A(3), &A(4)], + ] + ); + + let mut query = world.query::<&mut A>(); + let mut combinations = query.iter_combinations_mut(&mut world); + while let Some([mut a, mut b, mut c]) = combinations.fetch_next() { + a.0 += 10; + b.0 += 100; + c.0 += 1000; + } + + let values: Vec<[&A; 3]> = a_query.iter_combinations(&world).collect(); + assert_eq!( + values, + vec![ + [&A(31), &A(212), &A(1203)], + [&A(31), &A(212), &A(3004)], + [&A(31), &A(1203), &A(3004)], + [&A(212), &A(1203), &A(3004)] + ] + ); + + let mut b_query = world.query::<&B>(); + assert_eq!( + b_query.iter_combinations::<2>(&world).size_hint(), + (0, Some(0)) + ); + let values: Vec<[&B; 2]> = b_query.iter_combinations(&world).collect(); + assert_eq!(values, Vec::<[&B; 2]>::new()); + } + + #[test] + fn query_iter_combinations_sparse() { + let mut world = World::new(); + world + .register_component(ComponentDescriptor::new::(StorageType::SparseSet)) + .unwrap(); + + world.spawn_batch((1..=4).map(|i| (A(i),))); + + let mut query = world.query::<&mut A>(); + let mut combinations = query.iter_combinations_mut(&mut world); + while let Some([mut a, mut b, mut c]) = combinations.fetch_next() { + a.0 += 10; + b.0 += 100; + c.0 += 1000; + } + + let mut query = world.query::<&A>(); + let values: Vec<[&A; 3]> = query.iter_combinations(&world).collect(); + assert_eq!( + values, + vec![ + [&A(31), &A(212), &A(1203)], + [&A(31), &A(212), &A(3004)], + [&A(31), &A(1203), &A(3004)], + [&A(212), &A(1203), &A(3004)] + ] + ); + } + #[test] fn multi_storage_query() { let mut world = World::new(); diff --git a/crates/bevy_ecs/src/query/state.rs b/crates/bevy_ecs/src/query/state.rs index f4c0042a199c1b..8e4feccc355213 100644 --- a/crates/bevy_ecs/src/query/state.rs +++ b/crates/bevy_ecs/src/query/state.rs @@ -3,8 +3,8 @@ use crate::{ component::ComponentId, entity::Entity, query::{ - Access, Fetch, FetchState, FilterFetch, FilteredAccess, QueryIter, ReadOnlyFetch, - WorldQuery, + Access, Fetch, FetchState, FilterFetch, FilteredAccess, QueryCombinationIter, QueryIter, + ReadOnlyFetch, WorldQuery, }, storage::TableId, world::{World, WorldId}, @@ -205,6 +205,27 @@ where unsafe { self.iter_unchecked(world) } } + #[inline] + pub fn iter_combinations<'w, 's, const K: usize>( + &'s mut self, + world: &'w World, + ) -> QueryCombinationIter<'w, 's, Q, F, K> + where + Q::Fetch: ReadOnlyFetch, + { + // SAFE: query is read only + unsafe { self.iter_combinations_unchecked(world) } + } + + #[inline] + pub fn iter_combinations_mut<'w, 's, const K: usize>( + &'s mut self, + world: &'w mut World, + ) -> QueryCombinationIter<'w, 's, Q, F, K> { + // SAFE: query has unique world access + unsafe { self.iter_combinations_unchecked(world) } + } + /// # Safety /// /// This does not check for mutable query correctness. To be safe, make sure mutable queries @@ -222,6 +243,22 @@ where /// /// This does not check for mutable query correctness. To be safe, make sure mutable queries /// have unique access to the components they query. + #[inline] + pub unsafe fn iter_combinations_unchecked<'w, 's, const K: usize>( + &'s mut self, + world: &'w World, + ) -> QueryCombinationIter<'w, 's, Q, F, K> { + self.validate_world_and_update_archetypes(world); + self.iter_combinations_unchecked_manual( + world, + world.last_change_tick(), + world.read_change_tick(), + ) + } + + /// # Safety + /// This does not check for mutable query correctness. To be safe, make sure mutable queries + /// have unique access to the components they query. /// This does not validate that `world.id()` matches `self.world_id`. Calling this on a `world` /// with a mismatched WorldId is unsound. #[inline] @@ -234,6 +271,21 @@ where QueryIter::new(world, self, last_change_tick, change_tick) } + /// # Safety + /// This does not check for mutable query correctness. To be safe, make sure mutable queries + /// have unique access to the components they query. + /// This does not validate that `world.id()` matches `self.world_id`. Calling this on a `world` + /// with a mismatched WorldId is unsound. + #[inline] + pub(crate) unsafe fn iter_combinations_unchecked_manual<'w, 's, const K: usize>( + &'s self, + world: &'w World, + last_change_tick: u32, + change_tick: u32, + ) -> QueryCombinationIter<'w, 's, Q, F, K> { + QueryCombinationIter::new(world, self, last_change_tick, change_tick) + } + #[inline] pub fn for_each<'w>( &mut self, diff --git a/crates/bevy_ecs/src/system/query.rs b/crates/bevy_ecs/src/system/query.rs index a2f2e94be9a5ac..908651125ac2bc 100644 --- a/crates/bevy_ecs/src/system/query.rs +++ b/crates/bevy_ecs/src/system/query.rs @@ -2,7 +2,8 @@ use crate::{ component::Component, entity::Entity, query::{ - Fetch, FilterFetch, QueryEntityError, QueryIter, QueryState, ReadOnlyFetch, WorldQuery, + Fetch, FilterFetch, QueryCombinationIter, QueryEntityError, QueryIter, QueryState, + ReadOnlyFetch, WorldQuery, }, world::{Mut, World}, }; @@ -157,6 +158,29 @@ where } } + /// Returns an [`Iterator`] over all possible combinations of `K` query results without repetition. + /// This can only be called for read-only queries + /// + /// For permutations of size K of query returning N results, you will get: + /// - if K == N: one permutation of all query results + /// - if K < N: all possible K-sized combinations of query results, without repetition + /// - if K > N: empty set (no K-sized combinations exist) + #[inline] + pub fn iter_combinations(&self) -> QueryCombinationIter<'_, '_, Q, F, K> + where + Q::Fetch: ReadOnlyFetch, + { + // SAFE: system runs without conflicts with other systems. + // same-system queries have runtime borrow checks when they conflict + unsafe { + self.state.iter_combinations_unchecked_manual( + self.world, + self.last_change_tick, + self.change_tick, + ) + } + } + /// Returns an [`Iterator`] over the query results. #[inline] pub fn iter_mut(&mut self) -> QueryIter<'_, '_, Q, F> { @@ -168,6 +192,42 @@ where } } + /// Iterates over all possible combinations of `K` query results without repetition. + /// + /// The returned value is not an `Iterator`, because that would lead to aliasing of mutable references. + /// In order to iterate it, use `fetch_next` method with `while let Some(..)` loop pattern. + /// + /// ``` + /// # struct A; + /// # use bevy_ecs::prelude::*; + /// # fn some_system(mut query: Query<&mut A>) { + /// // iterate using `fetch_next` in while loop + /// let mut combinations = query.iter_combinations_mut(); + /// while let Some([mut a, mut b]) = combinations.fetch_next() { + /// // mutably access components data + /// } + /// # } + /// ``` + /// + /// There is no `for_each` method, because it cannot be safely implemented + /// due to a [compiler bug](https://github.com/rust-lang/rust/issues/62529). + /// + /// For immutable access see [`Query::iter_combinations`]. + #[inline] + pub fn iter_combinations_mut( + &mut self, + ) -> QueryCombinationIter<'_, '_, Q, F, K> { + // SAFE: system runs without conflicts with other systems. + // same-system queries have runtime borrow checks when they conflict + unsafe { + self.state.iter_combinations_unchecked_manual( + self.world, + self.last_change_tick, + self.change_tick, + ) + } + } + /// Returns an [`Iterator`] over the query results. /// /// # Safety @@ -182,6 +242,25 @@ where .iter_unchecked_manual(self.world, self.last_change_tick, self.change_tick) } + /// Iterates over all possible combinations of `K` query results without repetition. + /// See [`Query::iter_combinations`]. + /// + /// # Safety + /// This allows aliased mutability. You must make sure this call does not result in multiple + /// mutable references to the same component + #[inline] + pub unsafe fn iter_combinations_unsafe( + &self, + ) -> QueryCombinationIter<'_, '_, Q, F, K> { + // SEMI-SAFE: system runs without conflicts with other systems. + // same-system queries have runtime borrow checks when they conflict + self.state.iter_combinations_unchecked_manual( + self.world, + self.last_change_tick, + self.change_tick, + ) + } + /// Runs `f` on each query result. This is faster than the equivalent iter() method, but cannot /// be chained like a normal [`Iterator`]. /// diff --git a/examples/ecs/iter_combinations.rs b/examples/ecs/iter_combinations.rs new file mode 100644 index 00000000000000..97ed7b3280f487 --- /dev/null +++ b/examples/ecs/iter_combinations.rs @@ -0,0 +1,158 @@ +use bevy::{core::FixedTimestep, prelude::*}; +use rand::{thread_rng, Rng}; + +#[derive(Debug, Hash, PartialEq, Eq, Clone, StageLabel)] +struct FixedUpdateStage; + +const DELTA_TIME: f64 = 0.01; + +fn main() { + App::build() + .insert_resource(Msaa { samples: 4 }) + .add_plugins(DefaultPlugins) + .add_startup_system(generate_bodies.system()) + .add_stage_after( + CoreStage::Update, + FixedUpdateStage, + SystemStage::parallel() + .with_run_criteria(FixedTimestep::step(DELTA_TIME)) + .with_system(interact_bodies.system()) + .with_system(integrate.system()), + ) + .run(); +} + +const GRAVITY_CONSTANT: f32 = 0.001; +const SOFTENING: f32 = 0.01; +const NUM_BODIES: usize = 100; + +#[derive(Default)] +struct Mass(f32); +#[derive(Default)] +struct Acceleration(Vec3); +#[derive(Default)] +struct LastPos(Vec3); + +#[derive(Bundle, Default)] +struct BodyBundle { + #[bundle] + pbr: PbrBundle, + mass: Mass, + last_pos: LastPos, + acceleration: Acceleration, +} + +fn generate_bodies( + mut commands: Commands, + mut meshes: ResMut>, + mut materials: ResMut>, +) { + let mesh = meshes.add(Mesh::from(shape::Icosphere { + radius: 1.0, + subdivisions: 3, + })); + + let pos_range = 1.0..15.0; + let color_range = 0.5..1.0; + let vel_range = -0.5..0.5; + + let mut rng = thread_rng(); + for _ in 0..NUM_BODIES { + let mass_value_cube_root: f32 = rng.gen_range(0.5..4.0); + let mass_value: f32 = mass_value_cube_root * mass_value_cube_root * mass_value_cube_root; + + let position = Vec3::new( + rng.gen_range(-1.0..1.0), + rng.gen_range(-1.0..1.0), + rng.gen_range(-1.0..1.0), + ) + .normalize() + * rng.gen_range(pos_range.clone()); + + commands.spawn_bundle(BodyBundle { + pbr: PbrBundle { + transform: Transform { + translation: position, + scale: Vec3::splat(mass_value_cube_root * 0.1), + ..Default::default() + }, + mesh: mesh.clone(), + material: materials.add( + Color::rgb_linear( + rng.gen_range(color_range.clone()), + rng.gen_range(color_range.clone()), + rng.gen_range(color_range.clone()), + ) + .into(), + ), + ..Default::default() + }, + mass: Mass(mass_value), + acceleration: Acceleration(Vec3::ZERO), + last_pos: LastPos( + position + - Vec3::new( + rng.gen_range(vel_range.clone()), + rng.gen_range(vel_range.clone()), + rng.gen_range(vel_range.clone()), + ) * DELTA_TIME as f32, + ), + }); + } + + // add bigger "star" body in the center + commands + .spawn_bundle(BodyBundle { + pbr: PbrBundle { + transform: Transform { + scale: Vec3::splat(0.5), + ..Default::default() + }, + mesh: meshes.add(Mesh::from(shape::Icosphere { + radius: 1.0, + subdivisions: 5, + })), + material: materials.add((Color::ORANGE_RED * 10.0).into()), + ..Default::default() + }, + mass: Mass(1000.0), + ..Default::default() + }) + .insert(PointLight { + color: Color::ORANGE_RED, + ..Default::default() + }); + commands.spawn_bundle(PerspectiveCameraBundle { + transform: Transform::from_xyz(0.0, 10.5, -20.0).looking_at(Vec3::ZERO, Vec3::Y), + ..Default::default() + }); +} + +fn interact_bodies(mut query: Query<(&Mass, &GlobalTransform, &mut Acceleration)>) { + let mut iter = query.iter_combinations_mut(); + while let Some([(Mass(m1), transform1, mut acc1), (Mass(m2), transform2, mut acc2)]) = + iter.fetch_next() + { + let delta = transform2.translation - transform1.translation; + let distance_sq: f32 = delta.length_squared(); + + let f = GRAVITY_CONSTANT / (distance_sq * (distance_sq + SOFTENING).sqrt()); + let force_unit_mass = delta * f; + acc1.0 += force_unit_mass * *m2; + acc2.0 -= force_unit_mass * *m1; + } +} + +fn integrate(mut query: Query<(&mut Acceleration, &mut Transform, &mut LastPos)>) { + let dt_sq = (DELTA_TIME * DELTA_TIME) as f32; + for (mut acceleration, mut transform, mut last_pos) in query.iter_mut() { + // verlet integration + // x(t+dt) = 2x(t) - x(t-dt) + a(t)dt^2 + O(dt^4) + + let new_pos = + transform.translation + transform.translation - last_pos.0 + acceleration.0 * dt_sq; + acceleration.0 = Vec3::ZERO; + last_pos.0 = transform.translation; + transform.translation = new_pos; + } +}