diff --git a/crates/bevy_asset/src/processor/mod.rs b/crates/bevy_asset/src/processor/mod.rs
index 625a484330078..9becd3e110b78 100644
--- a/crates/bevy_asset/src/processor/mod.rs
+++ b/crates/bevy_asset/src/processor/mod.rs
@@ -166,7 +166,7 @@ impl AssetProcessor {
         let start_time = std::time::Instant::now();
         debug!("Processing Assets");
         IoTaskPool::get().scope(|scope| {
-            scope.spawn(async move {
+            scope.spawn_async(async move {
                 self.initialize().await.unwrap();
                 for source in self.sources().iter_processed() {
                     self.process_assets_internal(scope, source, PathBuf::from(""))
@@ -316,7 +316,7 @@ impl AssetProcessor {
             error!("AddFolder event cannot be handled in single threaded mode (or WASM) yet.");
             #[cfg(all(not(target_arch = "wasm32"), feature = "multi-threaded"))]
             IoTaskPool::get().scope(|scope| {
-                scope.spawn(async move {
+                scope.spawn_async(async move {
                     self.process_assets_internal(scope, source, path)
                         .await
                         .unwrap();
@@ -445,7 +445,7 @@ impl AssetProcessor {
                     } else {
                         // Files without extensions are skipped
                         let processor = self.clone();
-                        scope.spawn(async move {
+                        scope.spawn_async(async move {
                             processor.process_asset(source, path).await;
                         });
                     }
@@ -461,7 +461,7 @@ impl AssetProcessor {
                 for path in check_reprocess_queue.drain(..) {
                     let processor = self.clone();
                     let source = self.get_source(path.source()).unwrap();
-                    scope.spawn(async move {
+                    scope.spawn_async(async move {
                         processor.process_asset(source, path.into()).await;
                     });
                 }
diff --git a/crates/bevy_asset/src/server/loaders.rs b/crates/bevy_asset/src/server/loaders.rs
index 65f21d6b9b52f..a7b6ec6d00066 100644
--- a/crates/bevy_asset/src/server/loaders.rs
+++ b/crates/bevy_asset/src/server/loaders.rs
@@ -79,7 +79,7 @@ impl AssetLoaders {
             MaybeAssetLoader::Ready(_) => unreachable!(),
             MaybeAssetLoader::Pending { sender, .. } => {
                 IoTaskPool::get()
-                    .spawn(async move {
+                    .spawn_async(async move {
                         let _ = sender.broadcast(loader).await;
                     })
                     .detach();
diff --git a/crates/bevy_asset/src/server/mod.rs b/crates/bevy_asset/src/server/mod.rs
index cc0825d3aaa57..f2cba427398d5 100644
--- a/crates/bevy_asset/src/server/mod.rs
+++ b/crates/bevy_asset/src/server/mod.rs
@@ -297,7 +297,7 @@ impl AssetServer {
         let owned_handle = Some(handle.clone().untyped());
         let server = self.clone();
         IoTaskPool::get()
-            .spawn(async move {
+            .spawn_async(async move {
                 if let Err(err) = server.load_internal(owned_handle, path, false, None).await {
                     error!("{}", err);
                 }
@@ -367,7 +367,7 @@ impl AssetServer {
 
         let server = self.clone();
         IoTaskPool::get()
-            .spawn(async move {
+            .spawn_async(async move {
                 let path_clone = path.clone();
                 match server.load_untyped_async(path).await {
                     Ok(handle) => server.send_asset_event(InternalAssetEvent::Loaded {
@@ -552,7 +552,7 @@ impl AssetServer {
         let server = self.clone();
         let path = path.into().into_owned();
         IoTaskPool::get()
-            .spawn(async move {
+            .spawn_async(async move {
                 let mut reloaded = false;
 
                 let requests = server
@@ -691,7 +691,7 @@ impl AssetServer {
         let path = path.into_owned();
         let server = self.clone();
         IoTaskPool::get()
-            .spawn(async move {
+            .spawn_async(async move {
                 let Ok(source) = server.get_source(path.source()) else {
                     error!(
                         "Failed to load {path}. AssetSource {:?} does not exist",
diff --git a/crates/bevy_ecs/src/query/state.rs b/crates/bevy_ecs/src/query/state.rs
index 3eb76d3d3df75..0c821aa9bbba0 100644
--- a/crates/bevy_ecs/src/query/state.rs
+++ b/crates/bevy_ecs/src/query/state.rs
@@ -1131,7 +1131,7 @@ impl<D: QueryData, F: QueryFilter> QueryState<D, F> {
     ) {
         // NOTE: If you are changing query iteration code, remember to update the following places, where relevant:
         // QueryIter, QueryIterationCursor, QueryManyIter, QueryCombinationIter, QueryState::for_each_unchecked_manual, QueryState::par_for_each_unchecked_manual
-        bevy_tasks::ComputeTaskPool::get().scope(|scope| {
+        let _: Vec<()> = bevy_tasks::ComputeTaskPool::get().scope(|scope| {
            if D::IS_DENSE && F::IS_DENSE {
                 // SAFETY: We only access table data that has been registered in `self.archetype_component_access`.
                 let tables = unsafe { &world.storages().tables };
@@ -1145,7 +1145,7 @@ impl<D: QueryData, F: QueryFilter> QueryState<D, F> {
                     while offset < table.entity_count() {
                         let mut func = func.clone();
                         let len = batch_size.min(table.entity_count() - offset);
-                        scope.spawn(async move {
+                        scope.spawn(move || {
                             #[cfg(feature = "trace")]
                             let _span = self.par_iter_span.enter();
                             let table = &world
@@ -1172,7 +1172,7 @@ impl<D: QueryData, F: QueryFilter> QueryState<D, F> {
                     while offset < archetype.len() {
                         let mut func = func.clone();
                         let len = batch_size.min(archetype.len() - offset);
-                        scope.spawn(async move {
+                        scope.spawn(move || {
                             #[cfg(feature = "trace")]
                             let _span = self.par_iter_span.enter();
                             let archetype =
diff --git a/crates/bevy_ecs/src/schedule/executor/multi_threaded.rs b/crates/bevy_ecs/src/schedule/executor/multi_threaded.rs
index 82a3e207e995f..04b86cc4bc4e1 100644
--- a/crates/bevy_ecs/src/schedule/executor/multi_threaded.rs
+++ b/crates/bevy_ecs/src/schedule/executor/multi_threaded.rs
@@ -610,7 +610,7 @@ impl ExecutorState {
         #[cfg(feature = "trace")]
         let system_span = system_meta.system_task_span.clone();
 
-        let task = async move {
+        let mut task = move || {
             let res = std::panic::catch_unwind(AssertUnwindSafe(|| {
                 #[cfg(feature = "trace")]
                 let _span = system_span.enter();
@@ -629,6 +629,14 @@ impl ExecutorState {
         if system_meta.is_send {
             context.scope.spawn(task);
         } else {
+            let task = async move { task() };
+            #[cfg(feature = "trace")]
+            let task = task.instrument(
+                self.system_task_metadata[system_index]
+                    .system_task_span
+                    .clone(),
+            );
+
             self.local_thread_running = true;
             context.scope.spawn_on_external(task);
         }
diff --git a/crates/bevy_ecs/src/schedule/mod.rs b/crates/bevy_ecs/src/schedule/mod.rs
index 8ac9a9d47bec5..3220cb578a90d 100644
--- a/crates/bevy_ecs/src/schedule/mod.rs
+++ b/crates/bevy_ecs/src/schedule/mod.rs
@@ -1139,6 +1139,7 @@ mod tests {
 
     /// verify the [`SimpleExecutor`] supports stepping
     #[test]
+    #[allow(deprecated)]
     fn simple_executor() {
         assert_executor_supports_stepping!(ExecutorKind::Simple);
     }
diff --git a/crates/bevy_gltf/src/loader.rs b/crates/bevy_gltf/src/loader.rs
index c6200a6339a06..9fdc559caa43e 100644
--- a/crates/bevy_gltf/src/loader.rs
+++ b/crates/bevy_gltf/src/loader.rs
@@ -354,7 +354,7 @@ async fn load_gltf<'a, 'b, 'c>(
             let parent_path = load_context.path().parent().unwrap();
             let linear_textures = &linear_textures;
             let buffer_data = &buffer_data;
-            scope.spawn(async move {
+            scope.spawn_async(async move {
                 load_image(
                     gltf_texture,
                     buffer_data,
diff --git a/crates/bevy_render/src/pipelined_rendering.rs b/crates/bevy_render/src/pipelined_rendering.rs
index 0688c90095f95..82ed350cce339 100644
--- a/crates/bevy_render/src/pipelined_rendering.rs
+++ b/crates/bevy_render/src/pipelined_rendering.rs
@@ -149,7 +149,7 @@ impl Plugin for PipelinedRenderingPlugin {
         // run a scope here to allow main world to use this thread while it's waiting for the render app
         let sent_app = compute_task_pool
             .scope(|s| {
-                s.spawn(async { app_to_render_receiver.recv().await });
+                s.spawn_async(async { app_to_render_receiver.recv().await });
             })
             .pop();
         let Some(Ok(mut render_app)) = sent_app else {
@@ -182,7 +182,7 @@ fn update_rendering(app_world: &mut World, _sub_app: &mut App) {
         // while we wait for the render world to be received.
         let mut render_app = ComputeTaskPool::get()
             .scope_with_executor(true, Some(&*main_thread_executor.0), |s| {
-                s.spawn(async { render_channels.recv().await });
+                s.spawn_async(async { render_channels.recv().await });
             })
             .pop()
             .unwrap();
diff --git a/crates/bevy_render/src/render_resource/pipeline_cache.rs b/crates/bevy_render/src/render_resource/pipeline_cache.rs
index 2e2f2eeafa78e..b5dfbeed1452a 100644
--- a/crates/bevy_render/src/render_resource/pipeline_cache.rs
+++ b/crates/bevy_render/src/render_resource/pipeline_cache.rs
@@ -955,7 +955,9 @@ fn create_pipeline_task(
     sync: bool,
 ) -> CachedPipelineState {
     if !sync {
-        return CachedPipelineState::Creating(bevy_tasks::AsyncComputeTaskPool::get().spawn(task));
+        return CachedPipelineState::Creating(
+            bevy_tasks::AsyncComputeTaskPool::get().spawn_async(task),
+        );
     }
 
     match futures_lite::future::block_on(task) {
diff --git a/crates/bevy_render/src/renderer/mod.rs b/crates/bevy_render/src/renderer/mod.rs
index fa4377a405c47..2cffd33df5d0b 100644
--- a/crates/bevy_render/src/renderer/mod.rs
+++ b/crates/bevy_render/src/renderer/mod.rs
@@ -405,7 +405,7 @@ impl<'w> RenderContext<'w> {
                 command_buffers
                     .push((i, command_buffer_generation_task(render_device)));
             } else {
-                task_pool.spawn(async move {
+                task_pool.spawn_async(async move {
                     (i, command_buffer_generation_task(render_device))
                 });
             }
diff --git a/crates/bevy_render/src/view/window/screenshot.rs b/crates/bevy_render/src/view/window/screenshot.rs
index c13a60f88524c..2cce7f05edee4 100644
--- a/crates/bevy_render/src/view/window/screenshot.rs
+++ b/crates/bevy_render/src/view/window/screenshot.rs
@@ -367,7 +367,7 @@ pub(crate) fn collect_screenshots(world: &mut World) {
                 ));
             };
 
-            AsyncComputeTaskPool::get().spawn(finish).detach();
+            AsyncComputeTaskPool::get().spawn_async(finish).detach();
         }
     }
 }
diff --git a/crates/bevy_tasks/Cargo.toml b/crates/bevy_tasks/Cargo.toml
index 3fb0e3c297acd..345ab52c2566a 100644
--- a/crates/bevy_tasks/Cargo.toml
+++ b/crates/bevy_tasks/Cargo.toml
@@ -14,10 +14,11 @@ multi-threaded = []
 
 [dependencies]
 futures-lite = "2.0.1"
 async-executor = "1.7.2"
-async-channel = "2.2.0"
 async-io = { version = "2.0.0", optional = true }
 async-task = "4.2.0"
 concurrent-queue = "2.0.0"
+rayon-core = "1.0"
+parking = "2.2"
 
 [target.'cfg(target_arch = "wasm32")'.dependencies]
 wasm-bindgen-futures = "0.4"
diff --git a/crates/bevy_tasks/src/iter/mod.rs b/crates/bevy_tasks/src/iter/mod.rs
index 6887fa05440a1..f6fd8b3eab105 100644
--- a/crates/bevy_tasks/src/iter/mod.rs
+++ b/crates/bevy_tasks/src/iter/mod.rs
@@ -36,7 +36,7 @@ where
     fn count(mut self, pool: &TaskPool) -> usize {
         pool.scope(|s| {
             while let Some(batch) = self.next_batch() {
-                s.spawn(async move { batch.count() });
+                s.spawn_async(async move { batch.count() });
             }
         })
         .iter()
@@ -108,7 +108,7 @@ where
         pool.scope(|s| {
             while let Some(batch) = self.next_batch() {
                 let newf = f.clone();
-                s.spawn(async move {
+                s.spawn_async(async move {
                     batch.for_each(newf);
                 });
             }
@@ -197,7 +197,7 @@ where
     {
         pool.scope(|s| {
             while let Some(batch) = self.next_batch() {
-                s.spawn(async move { batch.collect::<Vec<_>>() });
+                s.spawn_async(async move { batch.collect::<Vec<_>>() });
             }
         })
         .into_iter()
@@ -219,7 +219,7 @@ where
         pool.scope(|s| {
             while let Some(batch) = self.next_batch() {
                 let newf = f.clone();
-                s.spawn(async move { batch.partition::<Vec<_>, F>(newf) });
+                s.spawn_async(async move { batch.partition::<Vec<_>, F>(newf) });
             }
         })
         .into_iter()
@@ -246,7 +246,7 @@ where
             while let Some(batch) = self.next_batch() {
                 let newf = f.clone();
                 let newi = init.clone();
-                s.spawn(async move { batch.fold(newi, newf) });
+                s.spawn_async(async move { batch.fold(newi, newf) });
             }
         })
     }
@@ -263,7 +263,7 @@ where
         pool.scope(|s| {
             while let Some(mut batch) = self.next_batch() {
                 let newf = f.clone();
-                s.spawn(async move { batch.all(newf) });
+                s.spawn_async(async move { batch.all(newf) });
             }
         })
         .into_iter()
@@ -282,7 +282,7 @@ where
         pool.scope(|s| {
             while let Some(mut batch) = self.next_batch() {
                 let newf = f.clone();
-                s.spawn(async move { batch.any(newf) });
+                s.spawn_async(async move { batch.any(newf) });
             }
         })
         .into_iter()
@@ -302,7 +302,7 @@ where
         let poses = pool.scope(|s| {
             while let Some(batch) = self.next_batch() {
                 let mut newf = f.clone();
-                s.spawn(async move {
+                s.spawn_async(async move {
                     let mut len = 0;
                     let mut pos = None;
                     for item in batch {
@@ -334,7 +334,7 @@ where
     {
         pool.scope(|s| {
             while let Some(batch) = self.next_batch() {
-                s.spawn(async move { batch.max() });
+                s.spawn_async(async move { batch.max() });
             }
         })
         .into_iter()
@@ -351,7 +351,7 @@ where
     {
         pool.scope(|s| {
             while let Some(batch) = self.next_batch() {
-                s.spawn(async move { batch.min() });
+                s.spawn_async(async move { batch.min() });
             }
         })
         .into_iter()
@@ -371,7 +371,7 @@ where
         pool.scope(|s| {
             while let Some(batch) = self.next_batch() {
                 let newf = f.clone();
-                s.spawn(async move { batch.max_by_key(newf) });
+                s.spawn_async(async move { batch.max_by_key(newf) });
             }
         })
         .into_iter()
@@ -391,7 +391,7 @@ where
         pool.scope(|s| {
             while let Some(batch) = self.next_batch() {
                 let newf = f.clone();
-                s.spawn(async move { batch.max_by(newf) });
+                s.spawn_async(async move { batch.max_by(newf) });
             }
         })
         .into_iter()
@@ -411,7 +411,7 @@ where
         pool.scope(|s| {
             while let Some(batch) = self.next_batch() {
                 let newf = f.clone();
-                s.spawn(async move { batch.min_by_key(newf) });
+                s.spawn_async(async move { batch.min_by_key(newf) });
             }
         })
         .into_iter()
@@ -431,7 +431,7 @@ where
         pool.scope(|s| {
             while let Some(batch) = self.next_batch() {
                 let newf = f.clone();
-                s.spawn(async move { batch.min_by(newf) });
+                s.spawn_async(async move { batch.min_by(newf) });
             }
         })
         .into_iter()
@@ -484,7 +484,7 @@ where
     {
         pool.scope(|s| {
             while let Some(batch) = self.next_batch() {
-                s.spawn(async move { batch.sum() });
+                s.spawn_async(async move { batch.sum() });
             }
         })
         .into_iter()
@@ -501,7 +501,7 @@ where
     {
         pool.scope(|s| {
             while let Some(batch) = self.next_batch() {
-                s.spawn(async move { batch.product() });
+                s.spawn_async(async move { batch.product() });
             }
         })
         .into_iter()
diff --git a/crates/bevy_tasks/src/lib.rs b/crates/bevy_tasks/src/lib.rs
index 60b162dbed596..1611f67f81283 100644
--- a/crates/bevy_tasks/src/lib.rs
+++ b/crates/bevy_tasks/src/lib.rs
@@ -1,5 +1,9 @@
 #![doc = include_str!("../README.md")]
 
+use rayon_core::Yield;
+use std::future::Future;
+use std::task::{Context, Poll};
+
 mod slice;
 pub use slice::{ParallelSlice, ParallelSliceMut};
 
@@ -28,8 +32,6 @@ pub use thread_executor::{ThreadExecutor, ThreadExecutorTicker};
 
 #[cfg(feature = "async-io")]
 pub use async_io::block_on;
-#[cfg(not(feature = "async-io"))]
-pub use futures_lite::future::block_on;
 pub use futures_lite::future::poll_once;
 
 mod iter;
@@ -61,3 +63,71 @@ pub fn available_parallelism() -> usize {
         .map(NonZeroUsize::get)
         .unwrap_or(1)
 }
+
+/// Blocks the current thread on a future.
+///
+/// # Examples
+///
+/// ```
+/// use bevy_tasks::block_on;
+///
+/// let val = block_on(async {
+///     1 + 2
+/// });
+///
+/// assert_eq!(val, 3);
+/// ```
+#[cfg(not(feature = "async-io"))]
+pub fn block_on<T>(future: impl Future<Output = T>) -> T {
+    use core::cell::RefCell;
+    use core::task::Waker;
+
+    use parking::Parker;
+
+    // Pin the future on the stack.
+    futures_lite::pin!(future);
+
+    // Creates a parker and an associated waker that unparks it.
+    fn parker_and_waker() -> (Parker, Waker) {
+        let parker = Parker::new();
+        let unparker = parker.unparker();
+        let waker = Waker::from(unparker);
+        (parker, waker)
+    }
+
+    thread_local! {
+        // Cached parker and waker for efficiency.
+        static CACHE: RefCell<(Parker, Waker)> = RefCell::new(parker_and_waker());
+    }
+
+    CACHE.with(|cache| {
+        // Try grabbing the cached parker and waker.
+        let tmp_cached;
+        let tmp_fresh;
+        let (parker, waker) = match cache.try_borrow_mut() {
+            Ok(cache) => {
+                // Use the cached parker and waker.
+                tmp_cached = cache;
+                &*tmp_cached
+            }
+            Err(_) => {
+                // Looks like this is a recursive `block_on()` call.
+                // Create a fresh parker and waker.
+                tmp_fresh = parker_and_waker();
+                &tmp_fresh
+            }
+        };
+
+        let cx = &mut Context::from_waker(waker);
+        // Keep polling until the future is ready.
+        loop {
+            match future.as_mut().poll(cx) {
+                Poll::Ready(output) => return output,
+                Poll::Pending => match rayon_core::yield_now() {
+                    Some(Yield::Executed) => continue,
+                    Some(Yield::Idle) | None => parker.park(),
+                },
+            }
+        }
+    })
+}
diff --git a/crates/bevy_tasks/src/slice.rs b/crates/bevy_tasks/src/slice.rs
index 8410478322ee0..623925c5666d8 100644
--- a/crates/bevy_tasks/src/slice.rs
+++ b/crates/bevy_tasks/src/slice.rs
@@ -39,7 +39,7 @@ pub trait ParallelSlice<T: Sync>: AsRef<[T]> {
         let f = &f;
         task_pool.scope(|scope| {
             for chunk in slice.chunks(chunk_size) {
-                scope.spawn(async move { f(chunk) });
+                scope.spawn_async(async move { f(chunk) });
             }
         })
     }
@@ -136,7 +136,7 @@ pub trait ParallelSliceMut<T: Send>: AsMut<[T]> {
         let f = &f;
         task_pool.scope(|scope| {
             for chunk in slice.chunks_mut(chunk_size) {
-                scope.spawn(async move { f(chunk) });
+                scope.spawn_async(async move { f(chunk) });
             }
         })
     }
diff --git a/crates/bevy_tasks/src/task_pool.rs b/crates/bevy_tasks/src/task_pool.rs
index 551bb06311fd2..c5117487e55e0 100644
--- a/crates/bevy_tasks/src/task_pool.rs
+++ b/crates/bevy_tasks/src/task_pool.rs
@@ -1,15 +1,9 @@
-use std::{
-    future::Future,
-    marker::PhantomData,
-    mem,
-    panic::AssertUnwindSafe,
-    sync::Arc,
-    thread::{self, JoinHandle},
-};
+use std::{future::Future, marker::PhantomData, mem, panic::AssertUnwindSafe, sync::Arc};
 
-use async_task::FallibleTask;
+use async_task::{FallibleTask, Runnable};
 use concurrent_queue::ConcurrentQueue;
 use futures_lite::FutureExt;
+use rayon_core::{ThreadPool, ThreadPoolBuilder};
 
 use crate::{
     block_on,
@@ -31,17 +25,8 @@ impl Drop for CallOnDrop {
 #[derive(Default)]
 #[must_use]
 pub struct TaskPoolBuilder {
-    /// If set, we'll set up the thread pool to use at most `num_threads` threads.
-    /// Otherwise use the logical core count of the system
-    num_threads: Option<usize>,
-    /// If set, we'll use the given stack size rather than the system default
-    stack_size: Option<usize>,
-    /// Allows customizing the name of the threads - helpful for debugging. If set, threads will
-    /// be named <thread_name> (<thread_index>), i.e. "MyThreadPool (2)"
-    thread_name: Option<String>,
-
+    thread_pool_builder: ThreadPoolBuilder,
     on_thread_spawn: Option<Arc<dyn Fn() + Send + Sync + 'static>>,
-    on_thread_destroy: Option<Arc<dyn Fn() + Send + Sync + 'static>>,
 }
 
 impl TaskPoolBuilder {
@@ -53,20 +38,22 @@ impl TaskPoolBuilder {
     /// Override the number of threads created for the pool. If unset, we default to the number
     /// of logical cores of the system
     pub fn num_threads(mut self, num_threads: usize) -> Self {
-        self.num_threads = Some(num_threads);
+        self.thread_pool_builder = self.thread_pool_builder.num_threads(num_threads);
         self
     }
 
     /// Override the stack size of the threads created for the pool
     pub fn stack_size(mut self, stack_size: usize) -> Self {
-        self.stack_size = Some(stack_size);
+        self.thread_pool_builder = self.thread_pool_builder.stack_size(stack_size);
         self
     }
 
     /// Override the name of the threads created for the pool. If set, threads will
    /// be named `<thread_name> (<thread_index>)`, i.e. `MyThreadPool (2)`
     pub fn thread_name(mut self, thread_name: String) -> Self {
-        self.thread_name = Some(thread_name);
+        self.thread_pool_builder = self
+            .thread_pool_builder
+            .thread_name(move |idx| format!("{thread_name} ({idx})"));
         self
     }
 
@@ -84,7 +71,7 @@ impl TaskPoolBuilder {
     /// This is called on the thread itself and has access to all thread-local storage.
     /// This will block thread termination until the callback completes.
     pub fn on_thread_destroy(mut self, f: impl Fn() + Send + Sync + 'static) -> Self {
-        self.on_thread_destroy = Some(Arc::new(f));
+        self.thread_pool_builder = self.thread_pool_builder.exit_handler(move |_| f());
         self
     }
 
@@ -106,16 +93,7 @@ impl TaskPoolBuilder {
 /// will still execute a task, even if it is dropped.
 #[derive(Debug)]
 pub struct TaskPool {
-    /// The executor for the pool
-    ///
-    /// This has to be separate from TaskPoolInner because we have to create an `Arc<Executor>` to
-    /// pass into the worker threads, and we must create the worker threads before we can create
-    /// the `Vec<Task<T>>` contained within `TaskPoolInner`
-    executor: Arc<async_executor::Executor<'static>>,
-
-    /// Inner state of the pool
-    threads: Vec<JoinHandle<()>>,
-    shutdown_tx: async_channel::Sender<()>,
+    thread_pool: ThreadPool,
 }
 
 impl TaskPool {
@@ -135,72 +113,27 @@ impl TaskPool {
     }
 
     fn new_internal(builder: TaskPoolBuilder) -> Self {
-        let (shutdown_tx, shutdown_rx) = async_channel::unbounded::<()>();
-
-        let executor = Arc::new(async_executor::Executor::new());
-
-        let num_threads = builder
-            .num_threads
-            .unwrap_or_else(crate::available_parallelism);
-
-        let threads = (0..num_threads)
-            .map(|i| {
-                let ex = Arc::clone(&executor);
-                let shutdown_rx = shutdown_rx.clone();
-
-                let thread_name = if let Some(thread_name) = builder.thread_name.as_deref() {
-                    format!("{thread_name} ({i})")
-                } else {
-                    format!("TaskPool ({i})")
-                };
-                let mut thread_builder = thread::Builder::new().name(thread_name);
-
-                if let Some(stack_size) = builder.stack_size {
-                    thread_builder = thread_builder.stack_size(stack_size);
-                }
-
-                let on_thread_spawn = builder.on_thread_spawn.clone();
-                let on_thread_destroy = builder.on_thread_destroy.clone();
-
-                thread_builder
-                    .spawn(move || {
-                        TaskPool::LOCAL_EXECUTOR.with(|local_executor| {
-                            if let Some(on_thread_spawn) = on_thread_spawn {
-                                on_thread_spawn();
-                                drop(on_thread_spawn);
-                            }
-                            let _destructor = CallOnDrop(on_thread_destroy);
-                            loop {
-                                let res = std::panic::catch_unwind(|| {
-                                    let tick_forever = async move {
-                                        loop {
-                                            local_executor.tick().await;
-                                        }
-                                    };
-                                    block_on(ex.run(tick_forever.or(shutdown_rx.recv())))
-                                });
-                                if let Ok(value) = res {
-                                    // Use unwrap_err because we expect a Closed error
-                                    value.unwrap_err();
-                                    break;
-                                }
-                            }
-                        });
-                    })
-                    .expect("Failed to spawn thread.")
-            })
-            .collect();
-
         Self {
-            executor,
-            threads,
-            shutdown_tx,
+            thread_pool: builder
+                .thread_pool_builder
+                .spawn_handler(move |thread| {
+                    let on_thread_spawn = builder.on_thread_spawn.clone();
+                    std::thread::spawn(move || {
+                        if let Some(on_thread_spawn) = on_thread_spawn {
+                            on_thread_spawn();
+                        }
+                        thread.run()
+                    });
+                    Ok(())
+                })
+                .build()
+                .expect("Failed to spawn thread pool."),
         }
     }
 
     /// Return the number of threads owned by the task pool
     pub fn thread_num(&self) -> usize {
-        self.threads.len()
+        self.thread_pool.current_num_threads()
     }
 
     /// Allows spawning non-`'static` futures on the thread pool. The function takes a callback,
@@ -352,95 +285,94 @@ impl TaskPool {
         // transmute the lifetimes to 'env here to appease the compiler as it is unable to validate safety.
         // Any usages of the references passed into `Scope` must be accessed through
         // the transmuted reference for the rest of this function.
-        let executor: &async_executor::Executor = &self.executor;
-        // SAFETY: As above, all futures must complete in this function so we can change the lifetime
-        let executor: &'env async_executor::Executor = unsafe { mem::transmute(executor) };
-        // SAFETY: As above, all futures must complete in this function so we can change the lifetime
-        let external_executor: &'env ThreadExecutor<'env> =
-            unsafe { mem::transmute(external_executor) };
-        // SAFETY: As above, all futures must complete in this function so we can change the lifetime
-        let scope_executor: &'env ThreadExecutor<'env> = unsafe { mem::transmute(scope_executor) };
-        let spawned: ConcurrentQueue<FallibleTask<Result<T, Box<(dyn core::any::Any + Send + 'static)>>>> =
-            ConcurrentQueue::unbounded();
-        // shadow the variable so that the owned value cannot be used for the rest of the function
-        // SAFETY: As above, all futures must complete in this function so we can change the lifetime
-        let spawned: &'env ConcurrentQueue<
-            FallibleTask<Result<T, Box<(dyn core::any::Any + Send + 'static)>>>,
-        > = unsafe { mem::transmute(&spawned) };
-
-        let scope = Scope {
-            executor,
-            external_executor,
-            scope_executor,
-            spawned,
-            scope: PhantomData,
-            env: PhantomData,
-        };
-
-        // shadow the variable so that the owned value cannot be used for the rest of the function
-        // SAFETY: As above, all futures must complete in this function so we can change the lifetime
-        let scope: &'env Scope<'_, 'env, T> = unsafe { mem::transmute(&scope) };
-
-        f(scope);
-
-        if spawned.is_empty() {
-            Vec::new()
-        } else {
-            block_on(async move {
-                let get_results = async {
-                    let mut results = Vec::with_capacity(spawned.len());
-                    while let Ok(task) = spawned.pop() {
-                        if let Some(res) = task.await {
-                            match res {
-                                Ok(res) => results.push(res),
-                                Err(payload) => std::panic::resume_unwind(payload),
+        self.thread_pool.in_place_scope(|scope| {
+            // SAFETY: All tasks must complete in this function so we can change the lifetime
+            let thread_pool: &rayon_core::Scope<'env> = unsafe { mem::transmute(scope) };
+            // SAFETY: As above, all futures must complete in this function so we can change the lifetime
+            let external_executor: &'env ThreadExecutor<'env> =
+                unsafe { mem::transmute(external_executor) };
+            // SAFETY: As above, all futures must complete in this function so we can change the lifetime
+            let scope_executor: &'env ThreadExecutor<'env> =
+                unsafe { mem::transmute(scope_executor) };
+            let spawned: ConcurrentQueue<FallibleTask<Result<T, Box<(dyn core::any::Any + Send + 'static)>>>> =
+                ConcurrentQueue::unbounded();
+            // shadow the variable so that the owned value cannot be used for the rest of the function
+            // SAFETY: As above, all futures must complete in this function so we can change the lifetime
+            let spawned: &'env ConcurrentQueue<
+                FallibleTask<Result<T, Box<(dyn core::any::Any + Send + 'static)>>>,
+            > = unsafe { mem::transmute(&spawned) };
+
+            let scope = Scope {
+                thread_pool,
+                external_executor,
+                scope_executor,
+                spawned,
+                scope: PhantomData,
+                env: PhantomData,
+            };
+
+            // shadow the variable so that the owned value cannot be used for the rest of the function
+            // SAFETY: As above, all futures must complete in this function so we can change the lifetime
+            let scope: &'env Scope<'_, 'env, T> = unsafe { mem::transmute(&scope) };
+
+            f(scope);
+
+            if spawned.is_empty() {
+                Vec::new()
+            } else {
+                block_on(async move {
+                    let get_results = async {
+                        let mut results = Vec::with_capacity(spawned.len());
+                        while let Ok(task) = spawned.pop() {
+                            if let Some(res) = task.await {
+                                match res {
+                                    Ok(res) => results.push(res),
+                                    Err(payload) => std::panic::resume_unwind(payload),
+                                }
+                            } else {
+                                panic!("Failed to catch panic!");
+                            }
                         }
-                        } else {
-                            panic!("Failed to catch panic!");
                         }
-                    }
-                    results
-                };
-
-                let tick_task_pool_executor = tick_task_pool_executor || self.threads.is_empty();
-
-                // we get this from a thread local so we should always be on the scope executors thread.
-                // note: it is possible `scope_executor` and `external_executor` is the same executor,
-                // in that case, we should only tick one of them, otherwise, it may cause deadlock.
-                let scope_ticker = scope_executor.ticker().unwrap();
-                let external_ticker = if !external_executor.is_same(scope_executor) {
-                    external_executor.ticker()
-                } else {
-                    None
-                };
-
-                match (external_ticker, tick_task_pool_executor) {
-                    (Some(external_ticker), true) => {
-                        Self::execute_global_external_scope(
-                            executor,
-                            external_ticker,
-                            scope_ticker,
-                            get_results,
-                        )
-                        .await
-                    }
-                    (Some(external_ticker), false) => {
-                        Self::execute_external_scope(external_ticker, scope_ticker, get_results)
+                        results
+                    };
+
+                    let tick_task_pool_executor =
+                        tick_task_pool_executor || self.thread_pool.current_num_threads() == 0;
+
+                    // we get this from a thread local so we should always be on the scope executors thread.
+                    // note: it is possible `scope_executor` and `external_executor` is the same executor,
+                    // in that case, we should only tick one of them, otherwise, it may cause deadlock.
+                    let scope_ticker = scope_executor.ticker().unwrap();
+                    let external_ticker = if !external_executor.is_same(scope_executor) {
+                        external_executor.ticker()
+                    } else {
+                        None
+                    };
+
+                    match (external_ticker, tick_task_pool_executor) {
+                        (Some(external_ticker), true) => {
+                            Self::execute_global_external_scope(
+                                external_ticker,
+                                scope_ticker,
+                                get_results,
+                            )
+                            .await
+                        }
+                        (Some(external_ticker), false) => {
+                            Self::execute_external_scope(external_ticker, scope_ticker, get_results)
+                                .await
+                        }
+                        // either external_executor is none or it is same as scope_executor
+                        (None, true) => Self::execute_global_scope(scope_ticker, get_results).await,
+                        (None, false) => Self::execute_scope(scope_ticker, get_results).await,
                     }
-                    // either external_executor is none or it is same as scope_executor
-                    (None, true) => {
-                        Self::execute_global_scope(executor, scope_ticker, get_results).await
-                    }
-                    (None, false) => Self::execute_scope(scope_ticker, get_results).await,
-                }
-            })
-        }
+                })
+            }
+        })
     }
 
     #[inline]
     async fn execute_global_external_scope<'scope, 'ticker, T>(
-        executor: &'scope async_executor::Executor<'scope>,
         external_ticker: ThreadExecutorTicker<'scope, 'ticker>,
         scope_ticker: ThreadExecutorTicker<'scope, 'ticker>,
         get_results: impl Future<Output = Vec<T>>,
@@ -456,10 +388,7 @@ impl TaskPool {
             };
             // we don't care if it errors. If a scoped task errors it will propagate
             // to get_results
-            let _result = AssertUnwindSafe(executor.run(tick_forever))
-                .catch_unwind()
-                .await
-                .is_ok();
+            let _result = AssertUnwindSafe(tick_forever).catch_unwind().await.is_ok();
         }
     };
     execute_forever.or(get_results).await
@@ -486,7 +415,6 @@ impl TaskPool {
 
     #[inline]
     async fn execute_global_scope<'scope, 'ticker, T>(
-        executor: &'scope async_executor::Executor<'scope>,
         scope_ticker: ThreadExecutorTicker<'scope, 'ticker>,
         get_results: impl Future<Output = Vec<T>>,
     ) -> Vec<T> {
@@ -497,10 +425,7 @@ impl TaskPool {
                     scope_ticker.tick().await;
                 }
             };
-            let _result = AssertUnwindSafe(executor.run(tick_forever))
-                .catch_unwind()
-                .await
-                .is_ok();
+            let _result = AssertUnwindSafe(tick_forever).catch_unwind().await.is_ok();
         }
     };
     execute_forever.or(get_results).await
@@ -524,6 +449,10 @@ impl TaskPool {
         execute_forever.or(get_results).await
     }
 
+    /// Spawns a closure onto the rayon thread pool, fire-and-forget; no [`Task`] handle is returned.
+    pub fn spawn(&self, f: impl FnOnce() + Send + 'static) {
+        self.thread_pool.spawn(f);
+    }
+
     /// Spawns a static future onto the thread pool. The returned [`Task`] is a
     /// future that can be polled for the result. It can also be canceled and
     /// "detached", allowing the task to continue running even if dropped. In
@@ -532,11 +461,17 @@ impl TaskPool {
     ///
     /// If the provided future is non-`Send`, [`TaskPool::spawn_local`] should
     /// be used instead.
-    pub fn spawn<T>(&self, future: impl Future<Output = T> + Send + 'static) -> Task<T>
+    pub fn spawn_async<T>(&self, future: impl Future<Output = T> + Send + 'static) -> Task<T>
     where
         T: Send + 'static,
     {
-        Task::new(self.executor.spawn(future))
+        let (runnable, task) = async_task::spawn(future, |runnable: Runnable| {
+            rayon_core::spawn(move || {
+                runnable.run();
+            })
+        });
+        runnable.schedule();
+        Task::new(task)
     }
 
     /// Spawns a static future on the thread-local async executor for the
@@ -550,7 +485,7 @@ impl TaskPool {
     ///
     /// Users should generally prefer to use [`TaskPool::spawn`] instead,
     /// unless the provided future is not `Send`.
-    pub fn spawn_local<T>(&self, future: impl Future<Output = T> + 'static) -> Task<T>
+    pub fn spawn_async_local<T>(&self, future: impl Future<Output = T> + 'static) -> Task<T>
     where
         T: 'static,
     {
@@ -582,26 +517,12 @@ impl Default for TaskPool {
     }
 }
 
-impl Drop for TaskPool {
-    fn drop(&mut self) {
-        self.shutdown_tx.close();
-
-        let panicking = thread::panicking();
-        for join_handle in self.threads.drain(..) {
-            let res = join_handle.join();
-            if !panicking {
-                res.expect("Task thread panicked while executing.");
-            }
-        }
-    }
-}
-
 /// A [`TaskPool`] scope for running one or more non-`'static` futures.
 ///
 /// For more information, see [`TaskPool::scope`].
 #[derive(Debug)]
 pub struct Scope<'scope, 'env: 'scope, T> {
-    executor: &'scope async_executor::Executor<'scope>,
+    thread_pool: &'scope rayon_core::Scope<'scope>,
     external_executor: &'scope ThreadExecutor<'scope>,
     scope_executor: &'scope ThreadExecutor<'scope>,
     spawned: &'scope ConcurrentQueue<FallibleTask<Result<T, Box<(dyn core::any::Any + Send + 'static)>>>>,
@@ -611,6 +532,10 @@ pub struct Scope<'scope, 'env: 'scope, T> {
 }
 
 impl<'scope, 'env, T: Send + 'scope> Scope<'scope, 'env, T> {
+    /// Spawns a scoped closure onto the thread pool; its result is not collected by [`TaskPool::scope`].
+    pub fn spawn(&self, f: impl FnOnce() + Send + 'scope) {
+        self.thread_pool.spawn(|_| f());
+    }
+
     /// Spawns a scoped future onto the thread pool. The scope *must* outlive
     /// the provided future. The results of the future will be returned as a part of
     /// [`TaskPool::scope`]'s return value.
@@ -619,14 +544,19 @@ impl<'scope, 'env, T: Send + 'scope> Scope<'scope, 'env, T> {
     /// instead.
     ///
     /// For more information, see [`TaskPool::scope`].
-    pub fn spawn<Fut: Future<Output = T> + 'scope + Send>(&self, f: Fut) {
-        let task = self
-            .executor
-            .spawn(AssertUnwindSafe(f).catch_unwind())
-            .fallible();
+    pub fn spawn_async<Fut: Future<Output = T> + 'scope + Send>(&self, f: Fut) {
+        let thread_pool = self.thread_pool;
+        let (runnable, task) = unsafe {
+            async_task::spawn_unchecked(AssertUnwindSafe(f).catch_unwind(), |runnable: Runnable| {
+                thread_pool.spawn(|_| {
+                    runnable.run();
+                })
+            })
+        };
+        runnable.schedule();
         // ConcurrentQueue only errors when closed or full, but we never
         // close and use an unbounded queue, so it is safe to unwrap
-        self.spawned.push(task).unwrap();
+        self.spawned.push(task.fallible()).unwrap();
     }
 
     /// Spawns a scoped future onto the thread the scope is run on. The scope *must* outlive
@@ -697,7 +627,7 @@ mod tests {
         let outputs = pool.scope(|scope| {
             for _ in 0..100 {
                 let count_clone = count.clone();
-                scope.spawn(async move {
+                scope.spawn_async(async move {
                     if *foo != 42 {
                         panic!("not 42!?!?")
                     } else {
@@ -781,7 +711,7 @@ mod tests {
         for i in 0..100 {
             if i % 2 == 0 {
                 let count_clone = non_local_count.clone();
-                scope.spawn(async move {
+                scope.spawn_async(async move {
                     if *foo != 42 {
                         panic!("not 42!?!?")
                     } else {
@@ -827,7 +757,7 @@ mod tests {
             thread::spawn(move || {
                 inner_pool.scope(|scope| {
                     let inner_count_clone = count_clone.clone();
-                    scope.spawn(async move {
+                    scope.spawn_async(async move {
                         inner_count_clone.fetch_add(1, Ordering::Release);
                     });
                     let spawner = thread::current().id();
@@ -903,7 +833,7 @@ mod tests {
                 inner_pool.scope(|scope| {
                     let spawner = thread::current().id();
                     let inner_count_clone = count_clone.clone();
-                    scope.spawn(async move {
+                    scope.spawn_async(async move {
                         inner_count_clone.fetch_add(1, Ordering::Release);
 
                         // spawning on the scope from another thread runs the futures on the scope's thread
@@ -932,9 +862,9 @@ mod tests {
         let count = Arc::new(AtomicI32::new(0));
 
         pool.scope(|scope| {
-            scope.spawn(async {
+            scope.spawn_async(async {
                 pool.scope(|scope| {
-                    scope.spawn(async {
+                    scope.spawn_async(async {
                         count.fetch_add(1, Ordering::Relaxed);
                     });
                 });
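
A minimal usage sketch of the renamed API (editor's illustration, not part of the patch; the pool setup and values are arbitrary):

    use bevy_tasks::{block_on, TaskPool};

    fn main() {
        let pool = TaskPool::new();

        // `spawn` now takes a plain closure and runs it on the rayon thread
        // pool, fire-and-forget, with no `Task` handle returned.
        pool.spawn(|| println!("ran on a rayon worker"));

        // Futures go through `spawn_async`, which returns a pollable `Task`.
        let task = pool.spawn_async(async { 21 * 2 });
        assert_eq!(block_on(task), 42);

        // Scoped futures likewise use `spawn_async`; their results are
        // collected into the `Vec` returned by `TaskPool::scope`.
        let results = pool.scope(|scope| {
            scope.spawn_async(async { 1 });
            scope.spawn_async(async { 2 });
        });
        assert_eq!(results.iter().sum::<i32>(), 3);
    }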