diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml
index 09112d01..8e5d4212 100644
--- a/.github/workflows/ci.yml
+++ b/.github/workflows/ci.yml
@@ -220,7 +220,9 @@ jobs:
       run: rustup show
     - uses: actions/checkout@v2
     - name: run tests
-      run: cargo test -p maitake --no-default-features
+      # don't run doctests with `no-default-features`, as some of them
+      # require liballoc.
+      run: cargo test -p maitake --no-default-features --tests --lib
 
   # run loom tests
   maitake_loom:
diff --git a/Cargo.lock b/Cargo.lock
index 4813961e..1f53f99b 100644
--- a/Cargo.lock
+++ b/Cargo.lock
@@ -418,6 +418,12 @@ dependencies = [
  "cfg-if",
 ]
 
+[[package]]
+name = "itoa"
+version = "1.0.2"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "112c678d4050afce233f4f2852bb2eb519230b3cf12f33585275537d7e41578d"
+
 [[package]]
 name = "json"
 version = "0.12.4"
@@ -476,6 +482,8 @@ dependencies = [
  "generator",
  "pin-utils",
  "scoped-tls",
+ "serde",
+ "serde_json",
  "tracing 0.1.34",
  "tracing-subscriber 0.3.11",
 ]
@@ -493,6 +501,7 @@ dependencies = [
  "pin-project",
  "tracing 0.1.34",
  "tracing 0.2.0",
+ "tracing-subscriber 0.3.0",
  "tracing-subscriber 0.3.11",
 ]
@@ -893,12 +902,49 @@ dependencies = [
  "wait-timeout",
 ]
 
+[[package]]
+name = "ryu"
+version = "1.0.10"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "f3f6f92acf49d1b98f7a81226834412ada05458b7364277387724a237f062695"
+
 [[package]]
 name = "scoped-tls"
 version = "1.0.0"
 source = "registry+https://github.com/rust-lang/crates.io-index"
 checksum = "ea6a9290e3c9cf0f18145ef7ffa62d68ee0bf5fcd651017e586dc7fd5da448c2"
 
+[[package]]
+name = "serde"
+version = "1.0.137"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "61ea8d54c77f8315140a05f4c7237403bf38b72704d031543aa1d16abbf517d1"
+dependencies = [
+ "serde_derive",
+]
+
+[[package]]
+name = "serde_derive"
+version = "1.0.137"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "1f26faba0c3959972377d3b2d306ee9f71faee9714294e41bb777f83f88578be"
+dependencies = [
+ "proc-macro2",
+ "quote",
+ "syn",
+]
+
+[[package]]
+name = "serde_json"
+version = "1.0.81"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "9b7ce2b32a1aed03c558dc61a5cd328f15aff2dbc17daad8fb8af04d2100e15c"
+dependencies = [
+ "itoa",
+ "ryu",
+ "serde",
+]
+
 [[package]]
 name = "sharded-slab"
 version = "0.1.4"
diff --git a/bin/loom b/bin/loom
index eb6ea345..0867cf85 100755
--- a/bin/loom
+++ b/bin/loom
@@ -3,7 +3,7 @@
 set -x
 
 RUSTFLAGS="--cfg loom ${RUSTFLAGS}" \
-LOOM_LOG="${LOOM_LOG:-info}" \
+LOOM_LOG="${LOOM_LOG:-debug}" \
 LOOM_LOCATION=true \
 cargo test \
     --profile loom \
diff --git a/bitfield/src/bitfield.rs b/bitfield/src/bitfield.rs
index db3af151..bbbc6776 100644
--- a/bitfield/src/bitfield.rs
+++ b/bitfield/src/bitfield.rs
@@ -234,11 +234,11 @@ macro_rules! bitfield {
         #[automatically_derived]
         impl core::fmt::Debug for $Name {
             fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
-                if f.alternate() {
-                    f.debug_tuple(stringify!($Name)).field(&format_args!("{}", self)).finish()
-                } else {
-                    f.debug_tuple(stringify!($Name)).field(&format_args!("{:#b}", self.0)).finish()
-                }
+                let mut dbg = f.debug_struct(stringify!($Name));
+                $(
+                    dbg.field(stringify!($Field), &self.get(Self::$Field));
+                )+
+                dbg.finish()
             }
         }
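The new expansion emits a `debug_struct` with one entry per declared field instead of a tuple wrapping the raw bits. A standalone sketch of the before/after behavior (the `ExampleReg` type and its fields are invented for illustration, not taken from this diff):

```rust
use core::fmt;

// Illustrative stand-in for a `bitfield!`-generated type: one bit for
// `ENABLED`, two bits for `KIND`.
struct ExampleReg(u8);

impl ExampleReg {
    fn enabled(&self) -> bool {
        self.0 & 0b1 != 0
    }

    fn kind(&self) -> u8 {
        (self.0 >> 1) & 0b11
    }
}

impl fmt::Debug for ExampleReg {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        // After this change: a `debug_struct` with each field's decoded value,
        // where the old impl printed the raw bits via `debug_tuple`.
        let mut dbg = f.debug_struct("ExampleReg");
        dbg.field("ENABLED", &self.enabled());
        dbg.field("KIND", &self.kind());
        dbg.finish()
    }
}

fn main() {
    // Prints `ExampleReg { ENABLED: true, KIND: 2 }` rather than
    // `ExampleReg(0b101)`.
    println!("{:?}", ExampleReg(0b101));
}
```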
diff --git a/maitake/Cargo.toml b/maitake/Cargo.toml
index 7bf8055a..27137df3 100644
--- a/maitake/Cargo.toml
+++ b/maitake/Cargo.toml
@@ -48,8 +48,11 @@ git = "https://github.com/tokio-rs/tracing"
 [dev-dependencies]
 futures-util = "0.3"
 
+[target.'cfg(not(loom))'.dev-dependencies]
+tracing-subscriber = { git = "https://github.com/tokio-rs/tracing", features = ["ansi", "fmt"] }
+
 [target.'cfg(loom)'.dev-dependencies]
-loom = { version = "0.5.5", features = ["futures"] }
+loom = { version = "0.5.5", features = ["futures", "checkpoint"] }
 tracing_01 = { package = "tracing", version = "0.1", default_features = false }
 tracing_subscriber_03 = { package = "tracing-subscriber", version = "0.3.11", features = ["fmt"] }
diff --git a/maitake/src/lib.rs b/maitake/src/lib.rs
index 9b04defa..d750bc67 100644
--- a/maitake/src/lib.rs
+++ b/maitake/src/lib.rs
@@ -1,6 +1,6 @@
 #![cfg_attr(docsrs, doc = include_str!("../README.md"))]
 #![cfg_attr(docsrs, feature(doc_cfg, doc_auto_cfg, doc_cfg_hide))]
-#![cfg_attr(docsrs, doc(cfg_hide(docsrs)))]
+#![cfg_attr(docsrs, doc(cfg_hide(docsrs, loom)))]
 #![cfg_attr(not(test), no_std)]
 #[cfg(feature = "alloc")]
 extern crate alloc;
diff --git a/maitake/src/loom.rs b/maitake/src/loom.rs
index b82aa9e4..05b2d99f 100644
--- a/maitake/src/loom.rs
+++ b/maitake/src/loom.rs
@@ -3,9 +3,42 @@ pub(crate) use self::inner::*;
 
 #[cfg(loom)]
 mod inner {
+    #![allow(dead_code)]
+
     #[cfg(feature = "alloc")]
-    pub use loom::alloc;
-    pub use loom::{cell, future, hint, model, sync, thread};
+    pub(crate) use loom::alloc;
+    pub(crate) use loom::{cell, future, hint, model, thread};
+
+    pub(crate) mod sync {
+        pub(crate) use loom::sync::*;
+
+        pub(crate) mod spin {
+            pub(crate) use loom::sync::MutexGuard;
+
+            /// Mock version of mycelium's spinlock, but using
+            /// `loom::sync::Mutex`. The API is slightly different, since the
+            /// mycelium mutex does not support poisoning.
+            #[derive(Debug)]
+            pub(crate) struct Mutex<T>(loom::sync::Mutex<T>);
+
+            impl<T> Mutex<T> {
+                #[track_caller]
+                pub(crate) fn new(t: T) -> Self {
+                    Self(loom::sync::Mutex::new(t))
+                }
+
+                #[track_caller]
+                pub fn try_lock(&self) -> Option<MutexGuard<'_, T>> {
+                    self.0.try_lock().ok()
+                }
+
+                #[track_caller]
+                pub fn lock(&self) -> MutexGuard<'_, T> {
+                    self.0.lock().expect("loom mutex will never poison")
+                }
+            }
+        }
+    }
 }
 
 #[cfg(not(loom))]
@@ -15,6 +48,8 @@ mod inner {
     #[cfg(feature = "alloc")]
     pub use alloc::sync::*;
     pub use core::sync::*;
+
+    pub use mycelium_util::sync::spin;
 }
 
 pub(crate) use core::sync::atomic;
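With this facade in place, the rest of the crate imports a single lock type and gets loom's model-checked mutex under `--cfg loom`, or mycelium's spinlock otherwise; the new `WaitQueue` consumes it exactly this way. A minimal sketch of the pattern (crate-internal code, not a standalone program; the `Counter` type is invented for illustration):

```rust
// Sketch: crate code names one mutex type, and cfg flags pick the backend.
use crate::loom::sync::spin::Mutex;

struct Counter {
    value: Mutex<usize>,
}

impl Counter {
    fn increment(&self) -> usize {
        // Under `--cfg loom`, this is the mock above (a `loom::sync::Mutex`
        // that panics rather than poisoning); otherwise it is
        // `mycelium_util::sync::spin::Mutex`. Either way, `lock()` returns a
        // guard rather than a `Result`, so call sites stay identical.
        let mut value = self.value.lock();
        *value += 1;
        *value
    }
}
```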
diff --git a/maitake/src/scheduler/tests.rs b/maitake/src/scheduler/tests.rs
index 0ebbb3bb..5804e69d 100644
--- a/maitake/src/scheduler/tests.rs
+++ b/maitake/src/scheduler/tests.rs
@@ -49,6 +49,8 @@ mod alloc {
         static SCHEDULER: Lazy<StaticScheduler> = Lazy::new(StaticScheduler::new);
         static IT_WORKED: AtomicBool = AtomicBool::new(false);
 
+        crate::util::trace_init();
+
         SCHEDULER.spawn(async {
             Yield::once().await;
             IT_WORKED.store(true, Ordering::Release);
@@ -69,6 +71,7 @@ mod alloc {
 
         const TASKS: usize = 10;
 
+        crate::util::trace_init();
         for _ in 0..TASKS {
             SCHEDULER.spawn(async {
                 Yield::once().await;
@@ -89,6 +92,7 @@ mod alloc {
         static SCHEDULER: Lazy<StaticScheduler> = Lazy::new(StaticScheduler::new);
         static COMPLETED: AtomicUsize = AtomicUsize::new(0);
 
+        crate::util::trace_init();
         let chan = Chan::new(1);
 
         SCHEDULER.spawn({
@@ -114,6 +118,7 @@ mod alloc {
         static SCHEDULER: Lazy<StaticScheduler> = Lazy::new(StaticScheduler::new);
         static COMPLETED: AtomicUsize = AtomicUsize::new(0);
 
+        crate::util::trace_init();
         let chan = Chan::new(1);
 
         SCHEDULER.spawn({
@@ -142,6 +147,8 @@ mod alloc {
 
         const TASKS: usize = 10;
 
+        crate::util::trace_init();
+
         for i in 0..TASKS {
             SCHEDULER.spawn(async {
                 Yield::new(i).await;
@@ -196,6 +203,8 @@ mod custom_storage {
         static SCHEDULER: StaticScheduler = unsafe { StaticScheduler::new_with_static_stub(&STUB) };
         static IT_WORKED: AtomicBool = AtomicBool::new(false);
 
+        crate::util::trace_init();
+
         MyBoxTask::spawn(&SCHEDULER, async {
             Yield::once().await;
             IT_WORKED.store(true, Ordering::Release);
@@ -217,6 +226,8 @@ mod custom_storage {
 
         const TASKS: usize = 10;
 
+        crate::util::trace_init();
+
         for _ in 0..TASKS {
             MyBoxTask::spawn(&SCHEDULER, async {
                 Yield::once().await;
@@ -238,6 +249,7 @@ mod custom_storage {
         static SCHEDULER: StaticScheduler = unsafe { StaticScheduler::new_with_static_stub(&STUB) };
         static COMPLETED: AtomicUsize = AtomicUsize::new(0);
 
+        crate::util::trace_init();
         let chan = Chan::new(1);
 
         MyBoxTask::spawn(&SCHEDULER, {
@@ -264,6 +276,7 @@ mod custom_storage {
         static SCHEDULER: StaticScheduler = unsafe { StaticScheduler::new_with_static_stub(&STUB) };
         static COMPLETED: AtomicUsize = AtomicUsize::new(0);
 
+        crate::util::trace_init();
         let chan = Chan::new(1);
 
         MyBoxTask::spawn(&SCHEDULER, {
@@ -291,6 +304,7 @@ mod custom_storage {
         static SCHEDULER: StaticScheduler = unsafe { StaticScheduler::new_with_static_stub(&STUB) };
         static COMPLETED: AtomicUsize = AtomicUsize::new(0);
 
+        crate::util::trace_init();
         const TASKS: usize = 10;
 
         for i in 0..TASKS {
diff --git a/maitake/src/util.rs b/maitake/src/util.rs
index 9b54cee8..02a31dd1 100644
--- a/maitake/src/util.rs
+++ b/maitake/src/util.rs
@@ -18,7 +18,12 @@ macro_rules! test_dbg {
     ($e:expr) => {
         match $e {
             e => {
-                crate::util::tracing::debug!("{} = {:?}", stringify!($e), &e);
+                crate::util::tracing::debug!(
+                    location = %core::panic::Location::caller(),
+                    "{} = {:?}",
+                    stringify!($e),
+                    &e
+                );
                 e
             }
         }
@@ -33,7 +38,10 @@ macro_rules! test_trace {
 #[cfg(test)]
 macro_rules! test_trace {
     ($($args:tt)+) => {
-        crate::util::tracing::debug!($($args)+);
+        crate::util::tracing::debug!(
+            location = %core::panic::Location::caller(),
+            $($args)+
+        );
     };
 }
@@ -87,3 +95,12 @@ pub(crate) unsafe fn non_null<T>(ptr: *mut T) -> NonNull<T> {
 unsafe fn non_null<T>(ptr: *mut T) -> NonNull<T> {
     NonNull::new_unchecked(ptr)
 }
+
+#[cfg(all(test, not(loom)))]
+pub(crate) fn trace_init() {
+    use tracing_subscriber::filter::LevelFilter;
+    let _ = tracing_subscriber::fmt()
+        .with_max_level(LevelFilter::TRACE)
+        .with_test_writer()
+        .try_init();
+}
diff --git a/maitake/src/wait.rs b/maitake/src/wait.rs
index b7309e52..186b7f08 100644
--- a/maitake/src/wait.rs
+++ b/maitake/src/wait.rs
@@ -1,16 +1,20 @@
 //! Waiter cells and queues to allow tasks to wait for notifications.
 //!
 //! This module implements two types of structure for waiting: a [`WaitCell`],
-//! which stores a *single* waiting task, and a wait *queue*, which
+//! which stores a *single* waiting task, and a [`WaitQueue`], which
 //! stores a queue of waiting tasks.
 pub(crate) mod cell;
-pub use cell::WaitCell;
+pub mod queue;
+
+pub use self::cell::WaitCell;
+#[doc(inline)]
+pub use self::queue::WaitQueue;
 
 use core::task::Poll;
 
-/// An error indicating that a [`WaitCell`] or queue was closed while attempting
-/// register a waiter.
-#[derive(Clone, Debug, PartialEq, Eq)]
+/// An error indicating that a [`WaitCell`] or [`WaitQueue`] was closed while
+/// attempting to register a waiter.
+#[derive(Copy, Clone, Debug, Eq, PartialEq)]
 pub struct Closed(());
 
 pub type WaitResult = Result<(), Closed>;
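Both macros now attach a `location` field captured via `core::panic::Location::caller()`. Because the call expands directly into the invoking code, the reported location is the macro's use site, so log lines point at the test that emitted them rather than at `util.rs`. A standalone sketch of the same trick, using `println!` in place of `tracing` (the `dbg_at!` macro is invented for illustration):

```rust
// A debugging macro that tags its output with the location of the call site:
// `Location::caller()` is expanded into the caller's code by the macro.
macro_rules! dbg_at {
    ($e:expr) => {
        match $e {
            e => {
                println!(
                    "[{}] {} = {:?}",
                    core::panic::Location::caller(),
                    stringify!($e),
                    &e
                );
                e
            }
        }
    };
}

fn main() {
    // Prints something like `[src/main.rs:22:15] 1 + 1 = 2`.
    let two = dbg_at!(1 + 1);
    assert_eq!(two, 2);
}
```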
diff --git a/maitake/src/wait/queue.rs b/maitake/src/wait/queue.rs
new file mode 100644
index 00000000..9dc1e21d
--- /dev/null
+++ b/maitake/src/wait/queue.rs
@@ -0,0 +1,927 @@
+use crate::{
+    loom::{
+        cell::UnsafeCell,
+        sync::{
+            atomic::{AtomicUsize, Ordering::*},
+            spin::Mutex,
+        },
+    },
+    util,
+    wait::{self, WaitResult},
+};
+use cordyceps::{
+    list::{self, List},
+    Linked,
+};
+use core::{
+    future::Future,
+    marker::PhantomPinned,
+    mem,
+    pin::Pin,
+    ptr::NonNull,
+    task::{Context, Poll, Waker},
+};
+use mycelium_bitfield::{bitfield, FromBits};
+#[cfg(test)]
+use mycelium_util::fmt;
+use mycelium_util::sync::CachePadded;
+use pin_project::{pin_project, pinned_drop};
+
+#[cfg(test)]
+mod tests;
+
+/// A queue of [`Waker`]s implemented using an [intrusive doubly-linked
+/// list][ilist].
+///
+/// A `WaitQueue` allows any number of tasks to [wait] asynchronously and be
+/// woken when some event occurs, either [individually][wake] in first-in,
+/// first-out order, or [all at once][wake_all]. This makes it a vital building
+/// block of runtime services (such as timers or I/O resources), where it may be
+/// used to wake a set of tasks when a timer completes or when a resource
+/// becomes available. It can be equally useful for implementing higher-level
+/// synchronization primitives: for example, a `WaitQueue` plus an
+/// [`UnsafeCell`] is essentially an entire implementation of a fair
+/// asynchronous mutex. Finally, a `WaitQueue` can be a useful synchronization
+/// primitive on its own: sometimes, you just need to have a bunch of tasks wait
+/// for something and then wake them all up.
+///
+/// # Examples
+///
+/// Waking a single task at a time by calling [`wake`][wake]:
+///
+/// ```
+/// use std::sync::Arc;
+/// use maitake::{scheduler::Scheduler, wait::WaitQueue};
+///
+/// const TASKS: usize = 10;
+///
+/// // In order to spawn tasks, we need a `Scheduler` instance.
+/// let scheduler = Scheduler::new();
+///
+/// // Construct a new `WaitQueue`.
+/// let q = Arc::new(WaitQueue::new());
+///
+/// // Spawn some tasks that will wait on the queue.
+/// for _ in 0..TASKS {
+///     let q = q.clone();
+///     scheduler.spawn(async move {
+///         // Wait to be woken by the queue.
+///         q.wait().await.expect("queue is not closed");
+///     });
+/// }
+///
+/// // Tick the scheduler once.
+/// let tick = scheduler.tick();
+///
+/// // No tasks should complete on this tick, as they are all waiting
+/// // to be woken by the queue.
+/// assert_eq!(tick.completed, 0, "no tasks have been woken");
+///
+/// let mut completed = 0;
+/// for i in 1..=TASKS {
+///     // Wake the next task from the queue.
+///     q.wake();
+///
+///     // Tick the scheduler.
+///     let tick = scheduler.tick();
+///
+///     // A single task should have completed on this tick.
+///     completed += tick.completed;
+///     assert_eq!(completed, i);
+/// }
+///
+/// assert_eq!(completed, TASKS, "all tasks should have completed");
+/// ```
+///
+/// Waking all tasks using [`wake_all`][wake_all]:
+///
+/// ```
+/// use std::sync::Arc;
+/// use maitake::{scheduler::Scheduler, wait::WaitQueue};
+///
+/// const TASKS: usize = 10;
+///
+/// // In order to spawn tasks, we need a `Scheduler` instance.
+/// let scheduler = Scheduler::new();
+///
+/// // Construct a new `WaitQueue`.
+/// let q = Arc::new(WaitQueue::new());
+///
+/// // Spawn some tasks that will wait on the queue.
+/// for _ in 0..TASKS {
+///     let q = q.clone();
+///     scheduler.spawn(async move {
+///         // Wait to be woken by the queue.
+///         q.wait().await.expect("queue is not closed");
+///     });
+/// }
+///
+/// // Tick the scheduler once.
+/// let tick = scheduler.tick();
+///
+/// // No tasks should complete on this tick, as they are all waiting
+/// // to be woken by the queue.
+/// assert_eq!(tick.completed, 0, "no tasks have been woken");
+///
+/// // Wake all tasks waiting for the queue.
+/// q.wake_all();
+///
+/// // Tick the scheduler again to run the woken tasks.
+/// let tick = scheduler.tick();
+///
+/// // All tasks have now completed, since they were woken by the
+/// // queue.
+/// assert_eq!(tick.completed, TASKS, "all tasks should have completed");
+/// ```
+///
+/// # Implementation Notes
+///
+/// The *[intrusive]* aspect of this list is important, as it means that it does
+/// not allocate memory. Instead, nodes in the linked list are stored in the
+/// futures of tasks trying to wait on the queue. This means that it is not
+/// necessary to allocate any heap memory for each task waiting to be woken.
+///
+/// However, the intrusive linked list introduces one new danger: because
+/// futures can be *cancelled*, and the linked list nodes live within the
+/// futures trying to wait on the queue, we *must* ensure that the node
+/// is unlinked from the list before dropping a cancelled future. Failure to do
+/// so would result in the list containing dangling pointers. Therefore, we must
+/// use a *doubly-linked* list, so that nodes can edit both the previous and
+/// next node when they have to remove themselves. This is kind of a bummer, as
+/// it means we can't use something nice like this [intrusive queue by Dmitry
+/// Vyukov][2], and there are not really practical designs for lock-free
+/// doubly-linked lists that don't rely on some kind of deferred reclamation
+/// scheme such as hazard pointers or QSBR.
+///
+/// Instead, we just stick a [`Mutex`] around the linked list, which must be
+/// acquired to pop nodes from it, or for nodes to remove themselves when
+/// futures are cancelled. This is a bit sad, but the critical sections for this
+/// mutex are short enough that we still get pretty good performance despite it.
+///
+/// [`Waker`]: core::task::Waker
+/// [wait]: WaitQueue::wait
+/// [wake]: WaitQueue::wake
+/// [wake_all]: WaitQueue::wake_all
+/// [`UnsafeCell`]: core::cell::UnsafeCell
+/// [ilist]: cordyceps::List
+/// [intrusive]: https://fuchsia.dev/fuchsia-src/development/languages/c-cpp/fbl_containers_guide/introduction
+/// [2]: https://www.1024cores.net/home/lock-free-algorithms/queues/intrusive-mpsc-node-based-queue
+#[derive(Debug)]
+pub struct WaitQueue {
+    /// The wait queue's state variable.
+    state: CachePadded<AtomicUsize>,
+
+    /// The linked list of waiters.
+    ///
+    /// # Safety
+    ///
+    /// This is protected by a mutex; the mutex *must* be acquired when
+    /// manipulating the linked list, OR when manipulating waiter nodes that may
+    /// be linked into the list. If a node is known to not be linked, it is safe
+    /// to modify that node (such as by waking the stored [`Waker`]) without
+    /// holding the lock; otherwise, it may be modified through the list, so the
+    /// lock must be held when modifying the node.
+    ///
+    /// A spinlock (from `mycelium_util`) is used here, in order to support
+    /// `no_std` platforms; when running `loom` tests, a `loom` mutex is used
+    /// instead to simulate the spinlock, because loom doesn't play nice with
+    /// real spinlocks.
+    queue: Mutex<List<Waiter>>,
+}
+
+/// Future returned from [`WaitQueue::wait()`].
+///
+/// This future is fused, so once it has completed, any future calls to poll
+/// will immediately return [`Poll::Ready`].
+#[derive(Debug)]
+#[pin_project(PinnedDrop)]
+#[must_use = "futures do nothing unless `.await`ed or `poll`ed"]
+pub struct Wait<'a> {
+    /// The [`WaitQueue`] being waited on.
+    queue: &'a WaitQueue,
+
+    /// Entry in the wait queue linked list.
+    #[pin]
+    waiter: Waiter,
+}
+
+/// A waiter node which may be linked into a wait queue.
+#[derive(Debug)]
+#[repr(C)]
+#[pin_project]
+struct Waiter {
+    /// The intrusive linked list node.
+    ///
+    /// This *must* be the first field in the struct in order for the `Linked`
+    /// implementation to be sound.
+    #[pin]
+    node: UnsafeCell<Node>,
+
+    /// The future's state.
+    state: WaitStateBits,
+}
+
+#[derive(Debug)]
+#[repr(C)]
+struct Node {
+    /// Intrusive linked list pointers.
+    ///
+    /// # Safety
+    ///
+    /// This *must* be the first field in the struct in order for the `Linked`
+    /// impl to be sound.
+    links: list::Links<Waiter>,
+
+    /// The node's waker
+    waker: Wakeup,
+
+    // This type is !Unpin due to the heuristic from:
+    //
+    _pin: PhantomPinned,
+}
+
+bitfield! {
+    #[derive(Eq, PartialEq)]
+    struct QueueState<usize> {
+        /// The queue's state.
+        const STATE: State;
+
+        /// The number of times [`WaitQueue::wake_all`] has been called.
+        const WAKE_ALLS = ..;
+    }
+}
+
+bitfield! {
+    #[derive(Eq, PartialEq)]
+    struct WaitStateBits<usize> {
+        /// The waiter's state.
+        const STATE: WaitState;
+
+        /// The number of times [`WaitQueue::wake_all`] has been called.
+        const WAKE_ALLS = ..;
+    }
+}
+
+/// The state of a [`Waiter`] node in a [`WaitQueue`].
+#[derive(Debug, Copy, Clone, Eq, PartialEq)]
+#[repr(u8)]
+enum WaitState {
+    /// The waiter has not yet been enqueued.
+    ///
+    /// The number of times [`WaitQueue::wake_all`] has been called is stored
+    /// when the node is created, in order to determine whether it was woken by
+    /// a stored wakeup when enqueueing.
+    ///
+    /// When in this state, the node is **not** part of the linked list, and
+    /// can be dropped without removing it from the list.
+    Start,
+
+    /// The waiter is waiting.
+    ///
+    /// When in this state, the node **is** part of the linked list. If the
+    /// node is dropped in this state, it **must** be removed from the list
+    /// before dropping it. Failure to ensure this will result in dangling
+    /// pointers in the linked list!
+    Waiting,
+
+    /// The waiter has been woken.
+    ///
+    /// When in this state, the node is **not** part of the linked list, and
+    /// can be dropped without removing it from the list.
+    Woken,
+}
+
+/// The queue's current state.
+#[derive(Debug, Copy, Clone, Eq, PartialEq)]
+#[repr(u8)]
+enum State {
+    /// No waiters are queued, and there is no pending notification.
+    /// Waiting while the queue is in this state will enqueue the waiter;
+    /// notifying while in this state will store a pending notification in the
+    /// queue, transitioning to [`State::Woken`].
+    Empty = 0b00,
+
+    /// There are one or more waiters in the queue. Waiting while
+    /// the queue is in this state will not transition the state. Waking while
+    /// in this state will wake the first waiter in the queue; if this empties
+    /// the queue, then the queue will transition to [`State::Empty`].
+    Waiting = 0b01,
+
+    /// The queue has a stored notification. Waiting while the queue
+    /// is in this state will consume the pending notification *without*
+    /// enqueueing the waiter and transition the queue to [`State::Empty`].
+    /// Waking while in this state will leave the queue in this state.
+    Woken = 0b10,
+
+    /// The queue is closed. Waiting while in this state will return
+    /// [`Closed`] without transitioning the queue's state.
+    ///
+    /// *Note*: This *must* correspond to all state bits being set, as it's set
+    /// via a [`fetch_or`].
+    ///
+    /// [`Closed`]: crate::wait::Closed
+    /// [`fetch_or`]: core::sync::atomic::AtomicUsize::fetch_or
+    Closed = 0b11,
+}
+
+#[derive(Clone, Debug)]
+enum Wakeup {
+    Empty,
+    Waiting(Waker),
+    One,
+    All,
+    Closed,
+}
+
+// === impl WaitQueue ===
+
+impl WaitQueue {
+    /// Returns a new `WaitQueue`.
+    #[must_use]
+    #[cfg(not(loom))]
+    pub const fn new() -> Self {
+        Self {
+            state: CachePadded::new(AtomicUsize::new(State::Empty.into_usize())),
+            queue: Mutex::new(List::new()),
+        }
+    }
+
+    /// Returns a new `WaitQueue`.
+    #[must_use]
+    #[cfg(loom)]
+    pub fn new() -> Self {
+        Self {
+            state: CachePadded::new(AtomicUsize::new(State::Empty.into_usize())),
+            queue: Mutex::new(List::new()),
+        }
+    }
+
+    /// Wake the next task in the queue.
+    ///
+    /// If the queue is empty, a wakeup is stored in the `WaitQueue`, and the
+    /// next call to [`wait`] will complete immediately.
+    ///
+    /// [`wait`]: WaitQueue::wait
+    #[inline]
+    pub fn wake(&self) {
+        // snapshot the queue's current state.
+        let mut state = self.load();
+
+        // check if any tasks are currently waiting on this queue. if there are
+        // no waiting tasks, store the wakeup to be consumed by the next call to
+        // `wait`.
+        loop {
+            match state.get(QueueState::STATE) {
+                // if the queue is closed, bail.
+                State::Closed => return,
+                // if there are waiting tasks, break out of the loop and wake one.
+                State::Waiting => break,
+                _ => {}
+            }
+
+            let next = state.with_state(State::Woken);
+            // advance the state to `Woken`, and return (if we did so
+            // successfully)
+            match self.compare_exchange(state, next) {
+                Ok(_) => return,
+                Err(actual) => state = actual,
+            }
+        }
+
+        // okay, there are tasks waiting on the queue; we must acquire the lock
+        // on the linked list and wake the next task from the queue.
+        let mut queue = self.queue.lock();
+        test_trace!("wake: -> locked");
+
+        // the queue's state may have changed while we were waiting to acquire
+        // the lock, so we need to acquire a new snapshot.
+        state = self.load();
+
+        if let Some(waker) = self.wake_locked(&mut *queue, state) {
+            drop(queue);
+            waker.wake();
+        }
+    }
+
+    /// Wake *all* tasks currently in the queue.
+    pub fn wake_all(&self) {
+        let mut queue = self.queue.lock();
+        let state = self.load();
+
+        match state.get(QueueState::STATE) {
+            // if the queue is closed, bail.
+            State::Closed => return,
+
+            // if there are no waiters in the queue, increment the number of
+            // `wake_all` calls and return.
+            State::Woken | State::Empty => {
+                self.state.fetch_add(QueueState::ONE_WAKE_ALL, SeqCst);
+                return;
+            }
+
+            State::Waiting => {}
+        }
+
+        // okay, we actually have to wake some stuff.
+
+        // TODO(eliza): wake outside the lock using an array, a la
+        // https://github.com/tokio-rs/tokio/blob/4941fbf7c43566a8f491c64af5a4cd627c99e5a6/tokio/src/sync/batch_semaphore.rs#L277-L303
+        while let Some(node) = queue.pop_back() {
+            let waker = Waiter::wake(node, &mut queue, Wakeup::All);
+            waker.wake()
+        }
+
+        // now that the queue has been drained, transition to the empty state,
+        // and increment the wake_all count.
+        let next_state = QueueState::new()
+            .with_state(State::Empty)
+            .with(QueueState::WAKE_ALLS, state.get(QueueState::WAKE_ALLS) + 1);
+        self.compare_exchange(state, next_state)
+            .expect("state should not have transitioned while locked");
+    }
+
+    /// Close the queue, indicating that it may no longer be used.
+    ///
+    /// Once a queue is closed, all [`wait`] calls (current or future) will
+    /// return an error.
+    ///
+    /// This method is generally used when implementing higher-level
+    /// synchronization primitives or resources: when an event makes a resource
+    /// permanently unavailable, the queue can be closed.
+    pub fn close(&self) {
+        let state = self.state.fetch_or(State::Closed.into_usize(), SeqCst);
+        let state = test_dbg!(QueueState::from_bits(state));
+        if state.get(QueueState::STATE) != State::Waiting {
+            return;
+        }
+
+        let mut queue = self.queue.lock();
+
+        // TODO(eliza): wake outside the lock using an array, a la
+        // https://github.com/tokio-rs/tokio/blob/4941fbf7c43566a8f491c64af5a4cd627c99e5a6/tokio/src/sync/batch_semaphore.rs#L277-L303
+        while let Some(node) = queue.pop_back() {
+            let waker = Waiter::wake(node, &mut queue, Wakeup::Closed);
+            waker.wake()
+        }
+    }
+
+    /// Wait to be woken up by this queue.
+    ///
+    /// This returns a [`Wait`] future that will complete when the task is
+    /// woken by a call to [`wake`] or [`wake_all`], or when the `WaitQueue` is
+    /// closed.
+    ///
+    /// [`wake`]: Self::wake
+    /// [`wake_all`]: Self::wake_all
+    pub fn wait(&self) -> Wait<'_> {
+        Wait {
+            queue: self,
+            waiter: self.waiter(),
+        }
+    }
+
+    /// Returns a [`Waiter`] entry in this queue.
+    ///
+    /// This is factored out into a separate function because it's used by both
+    /// [`WaitQueue::wait`] and [`WaitQueue::wait_owned`].
+    fn waiter(&self) -> Waiter {
+        // how many times has `wake_all` been called when this waiter is created?
+        let current_wake_alls = test_dbg!(self.load().get(QueueState::WAKE_ALLS));
+        let state = WaitStateBits::new()
+            .with(WaitStateBits::WAKE_ALLS, current_wake_alls)
+            .with(WaitStateBits::STATE, WaitState::Start);
+        Waiter {
+            state,
+            node: UnsafeCell::new(Node {
+                links: list::Links::new(),
+                waker: Wakeup::Empty,
+                _pin: PhantomPinned,
+            }),
+        }
+    }
+
+    #[cfg_attr(test, track_caller)]
+    fn load(&self) -> QueueState {
+        #[allow(clippy::let_and_return)]
+        let state = QueueState::from_bits(self.state.load(SeqCst));
+        test_trace!("state.load() = {state:?}");
+        state
+    }
+
+    #[cfg_attr(test, track_caller)]
+    fn store(&self, state: QueueState) {
+        test_trace!("state.store({state:?})");
+        self.state.store(state.0, SeqCst);
+    }
+
+    #[cfg_attr(test, track_caller)]
+    fn compare_exchange(
+        &self,
+        current: QueueState,
+        new: QueueState,
+    ) -> Result<QueueState, QueueState> {
+        #[allow(clippy::let_and_return)]
+        let res = self
+            .state
+            .compare_exchange(current.0, new.0, SeqCst, SeqCst)
+            .map(QueueState::from_bits)
+            .map_err(QueueState::from_bits);
+        test_trace!("state.compare_exchange({current:?}, {new:?}) = {res:?}");
+        res
+    }
+
+    #[cold]
+    #[inline(never)]
+    fn wake_locked(&self, queue: &mut List<Waiter>, curr: QueueState) -> Option<Waker> {
+        let state = curr.get(QueueState::STATE);
+
+        // is the queue still in the `Waiting` state? it is possible that we
+        // transitioned to a different state while locking the queue.
+        if test_dbg!(state) != State::Waiting {
+            // if there are no longer any queued tasks, try to store the
+            // wakeup in the queue and bail.
+            if let Err(actual) = self.compare_exchange(curr, curr.with_state(State::Woken)) {
+                debug_assert!(actual.get(QueueState::STATE) != State::Waiting);
+                self.store(actual.with_state(State::Woken));
+            }
+
+            return None;
+        }
+
+        // otherwise, we have to dequeue a task and wake it.
+        let node = queue
+            .pop_back()
+            .expect("if we are in the Waiting state, there must be waiters in the queue");
+        let waker = Waiter::wake(node, queue, Wakeup::One);
+
+        // if we took the final waiter currently in the queue, transition to the
+        // `Empty` state.
+        if test_dbg!(queue.is_empty()) {
+            self.store(curr.with_state(State::Empty));
+        }
+
+        Some(waker)
+    }
+}
+
+// === impl Waiter ===
+
+impl Waiter {
+    /// Wake the task that owns this `Waiter`.
+    ///
+    /// # Safety
+    ///
+    /// This is only safe to call while the list is locked. The `list`
+    /// parameter ensures this method is only called while holding the lock, so
+    /// this can be safe.
+    ///
+    /// Of course, that must be the *same* list that this waiter is a member of,
+    /// and currently, there is no way to ensure that...
+    #[inline(always)]
+    #[cfg_attr(loom, track_caller)]
+    fn wake(this: NonNull<Waiter>, list: &mut List<Waiter>, wakeup: Wakeup) -> Waker {
+        Waiter::with_node(this, list, |node| {
+            let waker = test_dbg!(mem::replace(&mut node.waker, wakeup));
+            match waker {
+                Wakeup::Waiting(waker) => waker,
+                _ => unreachable!("tried to wake a waiter in the {:?} state!", waker),
+            }
+        })
+    }
+
+    /// # Safety
+    ///
+    /// This is only safe to call while the list is locked. The dummy `_list`
+    /// parameter ensures this method is only called while holding the lock, so
+    /// this can be safe.
+    ///
+    /// Of course, that must be the *same* list that this waiter is a member of,
+    /// and currently, there is no way to ensure that...
+    #[inline(always)]
+    #[cfg_attr(loom, track_caller)]
+    fn with_node<T>(
+        mut this: NonNull<Waiter>,
+        _list: &mut List<Waiter>,
+        f: impl FnOnce(&mut Node) -> T,
+    ) -> T {
+        unsafe {
+            // safety: this is only called while holding the lock on the queue,
+            // so it's safe to mutate the waiter.
+            this.as_mut().node.with_mut(|node| f(&mut *node))
+        }
+    }
+
+    fn poll_wait(
+        mut self: Pin<&mut Self>,
+        queue: &WaitQueue,
+        cx: &mut Context<'_>,
+    ) -> Poll<WaitResult> {
+        test_trace!(ptr = ?fmt::ptr(self.as_mut()), "Waiter::poll_wait");
+        let mut this = self.as_mut().project();
+
+        match test_dbg!(this.state.get(WaitStateBits::STATE)) {
+            WaitState::Start => {
+                let mut queue_state = queue.load();
+
+                // can we consume a pending wakeup?
+                if queue
+                    .compare_exchange(
+                        queue_state.with_state(State::Woken),
+                        queue_state.with_state(State::Empty),
+                    )
+                    .is_ok()
+                {
+                    this.state.set(WaitStateBits::STATE, WaitState::Woken);
+                    return Poll::Ready(Ok(()));
+                }
+
+                // okay, no pending wakeups. try to wait...
+                test_trace!("poll_wait: locking...");
+                let mut waiters = queue.queue.lock();
+                test_trace!("poll_wait: -> locked");
+                queue_state = queue.load();
+
+                // the whole queue was woken while we were trying to acquire
+                // the lock!
+                if queue_state.get(QueueState::WAKE_ALLS)
+                    != this.state.get(WaitStateBits::WAKE_ALLS)
+                {
+                    this.state.set(WaitStateBits::STATE, WaitState::Woken);
+                    return Poll::Ready(Ok(()));
+                }
+
+                // transition the queue to the waiting state
+                'to_waiting: loop {
+                    match test_dbg!(queue_state.get(QueueState::STATE)) {
+                        // the queue is `Empty`, transition to `Waiting`
+                        State::Empty => {
+                            match queue.compare_exchange(
+                                queue_state,
+                                queue_state.with_state(State::Waiting),
+                            ) {
+                                Ok(_) => break 'to_waiting,
+                                Err(actual) => queue_state = actual,
+                            }
+                        }
+                        // the queue is already `Waiting`
+                        State::Waiting => break 'to_waiting,
+                        // the queue was woken, consume the wakeup.
+                        State::Woken => {
+                            match queue
+                                .compare_exchange(queue_state, queue_state.with_state(State::Empty))
+                            {
+                                Ok(_) => {
+                                    this.state.set(WaitStateBits::STATE, WaitState::Woken);
+                                    return Poll::Ready(Ok(()));
+                                }
+                                Err(actual) => queue_state = actual,
+                            }
+                        }
+                        State::Closed => return wait::closed(),
+                    }
+                }
+
+                // enqueue the node
+                this.state.set(WaitStateBits::STATE, WaitState::Waiting);
+                this.node.as_mut().with_mut(|node| {
+                    unsafe {
+                        // safety: we may mutate the node because we are
+                        // holding the lock.
+                        (*node).waker = Wakeup::Waiting(cx.waker().clone());
+                    }
+                });
+                let ptr = unsafe { NonNull::from(Pin::into_inner_unchecked(self)) };
+                waiters.push_front(ptr);
+
+                Poll::Pending
+            }
+            WaitState::Waiting => {
+                let mut _waiters = queue.queue.lock();
+                this.node.with_mut(|node| unsafe {
+                    // safety: we may mutate the node because we are
+                    // holding the lock.
+                    let node = &mut *node;
+                    match node.waker {
+                        Wakeup::Waiting(ref mut waker) => {
+                            if !waker.will_wake(cx.waker()) {
+                                *waker = cx.waker().clone();
+                            }
+                            Poll::Pending
+                        }
+                        Wakeup::All | Wakeup::One => {
+                            this.state.set(WaitStateBits::STATE, WaitState::Woken);
+                            Poll::Ready(Ok(()))
+                        }
+                        Wakeup::Closed => {
+                            this.state.set(WaitStateBits::STATE, WaitState::Woken);
+                            wait::closed()
+                        }
+                        Wakeup::Empty => unreachable!(),
+                    }
+                })
+            }
+            WaitState::Woken => Poll::Ready(Ok(())),
+        }
+    }
+
+    /// Release this `Waiter` from the queue.
+    ///
+    /// This is called from the `drop` implementation for the [`Wait`] and
+    /// [`WaitOwned`] futures.
+    fn release(mut self: Pin<&mut Self>, queue: &WaitQueue) {
+        let state = *(self.as_mut().project().state);
+        let ptr = NonNull::from(unsafe { Pin::into_inner_unchecked(self) });
+        test_trace!(self = ?fmt::ptr(ptr), ?state, ?queue, "Waiter::release");
+
+        // if we're not enqueued, we don't have to do anything else.
+        if state.get(WaitStateBits::STATE) != WaitState::Waiting {
+            return;
+        }
+
+        let mut waiters = queue.queue.lock();
+        let state = queue.load();
+
+        // remove the node
+        unsafe {
+            // safety: we have the lock on the queue, so this is safe.
+            waiters.remove(ptr);
+        };
+
+        // if we removed the last waiter from the queue, transition the state to
+        // `Empty`.
+        if test_dbg!(waiters.is_empty()) && state.get(QueueState::STATE) == State::Waiting {
+            queue.store(state.with_state(State::Empty));
+        }
+
+        // if the node has an unconsumed wakeup, it must be assigned to the next
+        // node in the queue.
+        if Waiter::with_node(ptr, &mut waiters, |node| matches!(&node.waker, Wakeup::One)) {
+            if let Some(waker) = queue.wake_locked(&mut waiters, state) {
+                drop(waiters);
+                waker.wake()
+            }
+        }
+    }
+}
+
+unsafe impl Linked<list::Links<Waiter>> for Waiter {
+    type Handle = NonNull<Waiter>;
+
+    fn into_ptr(r: Self::Handle) -> NonNull<Waiter> {
+        r
+    }
+
+    unsafe fn from_ptr(ptr: NonNull<Waiter>) -> Self::Handle {
+        ptr
+    }
+
+    unsafe fn links(ptr: NonNull<Waiter>) -> NonNull<list::Links<Waiter>> {
+        (*ptr.as_ptr())
+            .node
+            .with_mut(|node| util::non_null(node).cast::<list::Links<Waiter>>())
+    }
+}
+
+// === impl Wait ===
+
+impl Future for Wait<'_> {
+    type Output = WaitResult;
+
+    fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
+        let this = self.project();
+        this.waiter.poll_wait(this.queue, cx)
+    }
+}
+
+#[pinned_drop]
+impl PinnedDrop for Wait<'_> {
+    fn drop(mut self: Pin<&mut Self>) {
+        let this = self.project();
+        this.waiter.release(this.queue);
+    }
+}
+
+// === impl QueueState ===
+
+impl QueueState {
+    const ONE_WAKE_ALL: usize = Self::WAKE_ALLS.first_bit();
+
+    fn with_state(self, state: State) -> Self {
+        self.with(Self::STATE, state)
+    }
+}
+
+impl FromBits<usize> for State {
+    const BITS: u32 = 2;
+    type Error = core::convert::Infallible;
+
+    fn try_from_bits(bits: usize) -> Result<Self, Self::Error> {
+        Ok(match bits as u8 {
+            bits if bits == Self::Empty as u8 => Self::Empty,
+            bits if bits == Self::Waiting as u8 => Self::Waiting,
+            bits if bits == Self::Woken as u8 => Self::Woken,
+            bits if bits == Self::Closed as u8 => Self::Closed,
+            _ => unsafe {
+                mycelium_util::unreachable_unchecked!(
+                    "all potential 2-bit patterns should be covered!"
+                )
+            },
+        })
+    }
+
+    fn into_bits(self) -> usize {
+        self.into_usize()
+    }
+}
+
+impl State {
+    const fn into_usize(self) -> usize {
+        self as u8 as usize
+    }
+}
+
+// === impl WaitState ===
+
+impl FromBits<usize> for WaitState {
+    const BITS: u32 = 2;
+    type Error = &'static str;
+
+    fn try_from_bits(bits: usize) -> Result<Self, Self::Error> {
+        match bits as u8 {
+            bits if bits == Self::Start as u8 => Ok(Self::Start),
+            bits if bits == Self::Waiting as u8 => Ok(Self::Waiting),
+            bits if bits == Self::Woken as u8 => Ok(Self::Woken),
+            _ => Err("invalid `WaitState`; expected one of Start, Waiting, or Woken"),
+        }
+    }
+
+    fn into_bits(self) -> usize {
+        self as u8 as usize
+    }
+}
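The `QueueState` arithmetic is easier to see with the bits written out: the low two bits hold the `State`, `WAKE_ALLS` occupies the remaining bits, and `ONE_WAKE_ALL` is the counter's least-significant bit, so a plain `fetch_add` bumps the counter without disturbing the state. A standalone sketch with the layout restated as integer constants (illustrative only; the real code derives these from the `bitfield!` definitions):

```rust
// Sketch of QueueState's bit layout, using plain integers.
const STATE_MASK: usize = 0b11; // low two bits: the queue's `State`
const ONE_WAKE_ALL: usize = 0b100; // first bit of the `WAKE_ALLS` counter

const EMPTY: usize = 0b00;
const WAITING: usize = 0b01;
const WOKEN: usize = 0b10;
const CLOSED: usize = 0b11;

fn main() {
    // `close()` uses `fetch_or(CLOSED)`; because `Closed` is all state bits
    // set, OR-ing it in reaches `Closed` from *any* prior state...
    for prior in [EMPTY, WAITING, WOKEN] {
        assert_eq!((prior | CLOSED) & STATE_MASK, CLOSED);
    }

    // ...and adding `ONE_WAKE_ALL` increments the wake_all counter without
    // touching the state bits.
    let state = WAITING | (3 * ONE_WAKE_ALL); // state = Waiting, wake_alls = 3
    let bumped = state + ONE_WAKE_ALL;
    assert_eq!(bumped & STATE_MASK, WAITING);
    assert_eq!(bumped >> 2, 4); // the counter went from 3 to 4
}
```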
+// === impl WaitOwned ===
+
+feature! {
+    #![feature = "alloc"]
+
+    use alloc::sync::Arc;
+
+    /// Future returned from [`WaitQueue::wait_owned()`].
+    ///
+    /// This is identical to the [`Wait`] future, except that it takes an
+    /// [`Arc`] reference to the [`WaitQueue`], allowing the returned future to
+    /// live for the `'static` lifetime.
+    ///
+    /// This future is fused, so once it has completed, any future calls to poll
+    /// will immediately return [`Poll::Ready`].
+    #[derive(Debug)]
+    #[pin_project(PinnedDrop)]
+    pub struct WaitOwned {
+        /// The `WaitQueue` being waited on.
+        queue: Arc<WaitQueue>,
+
+        /// Entry in the wait queue.
+        #[pin]
+        waiter: Waiter,
+    }
+
+    impl WaitQueue {
+        /// Wait to be woken up by this queue, returning a future that's valid
+        /// for the `'static` lifetime.
+        ///
+        /// This returns a [`WaitOwned`] future that will complete when the task is
+        /// woken by a call to [`wake`] or [`wake_all`], or when the `WaitQueue` is
+        /// closed.
+        ///
+        /// This is identical to the [`wait`] method, except that it takes an
+        /// [`Arc`] reference to the [`WaitQueue`], allowing the returned future to
+        /// live for the `'static` lifetime.
+        ///
+        /// [`wake`]: Self::wake
+        /// [`wake_all`]: Self::wake_all
+        /// [`wait`]: Self::wait
+        pub fn wait_owned(self: &Arc<Self>) -> WaitOwned {
+            let waiter = self.waiter();
+            let queue = self.clone();
+            WaitOwned { queue, waiter }
+        }
+    }
+
+    impl Future for WaitOwned {
+        type Output = WaitResult;
+
+        fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
+            let this = self.project();
+            this.waiter.poll_wait(&*this.queue, cx)
+        }
+    }
+
+    #[pinned_drop]
+    impl PinnedDrop for WaitOwned {
+        fn drop(mut self: Pin<&mut Self>) {
+            let this = self.project();
+            this.waiter.release(&*this.queue);
+        }
+    }
+}
diff --git a/maitake/src/wait/queue/tests.rs b/maitake/src/wait/queue/tests.rs
new file mode 100644
index 00000000..5db6ebf3
--- /dev/null
+++ b/maitake/src/wait/queue/tests.rs
@@ -0,0 +1,298 @@
+use super::*;
+
+#[cfg(all(not(loom), feature = "alloc"))]
+mod alloc {
+    use super::*;
+    use crate::loom::sync::Arc;
+    use crate::scheduler::Scheduler;
+    use core::sync::atomic::{AtomicUsize, Ordering};
+
+    #[test]
+    fn wake_all() {
+        crate::util::trace_init();
+        static COMPLETED: AtomicUsize = AtomicUsize::new(0);
+
+        let scheduler = Scheduler::new();
+        let q = Arc::new(WaitQueue::new());
+
+        const TASKS: usize = 10;
+
+        for _ in 0..TASKS {
+            let q = q.clone();
+            scheduler.spawn(async move {
+                q.wait().await.unwrap();
+                COMPLETED.fetch_add(1, Ordering::SeqCst);
+            });
+        }
+
+        let tick = scheduler.tick();
+
+        assert_eq!(tick.completed, 0);
+        assert_eq!(COMPLETED.load(Ordering::SeqCst), 0);
+        assert!(!tick.has_remaining);
+
+        q.wake_all();
+
+        let tick = scheduler.tick();
+
+        assert_eq!(tick.completed, TASKS);
+        assert_eq!(COMPLETED.load(Ordering::SeqCst), TASKS);
+        assert!(!tick.has_remaining);
+    }
+
+    #[test]
+    fn close() {
+        crate::util::trace_init();
+        static COMPLETED: AtomicUsize = AtomicUsize::new(0);
+
+        let scheduler = Scheduler::new();
+        let q = Arc::new(WaitQueue::new());
+
+        const TASKS: usize = 10;
+
+        for _ in 0..TASKS {
+            let wait = q.wait_owned();
+            scheduler.spawn(async move {
+                wait.await.expect_err("dropping the queue must close it");
+                COMPLETED.fetch_add(1, Ordering::SeqCst);
+            });
+        }
+
+        let tick = scheduler.tick();
+
+        assert_eq!(tick.completed, 0);
+        assert_eq!(COMPLETED.load(Ordering::SeqCst), 0);
+        assert!(!tick.has_remaining);
+
+        q.close();
+
+        let tick = scheduler.tick();
+
+        assert_eq!(tick.completed, TASKS);
+        assert_eq!(COMPLETED.load(Ordering::SeqCst), TASKS);
+        assert!(!tick.has_remaining);
+    }
+
+    #[test]
+    fn wake_one() {
+        crate::util::trace_init();
+        static COMPLETED: AtomicUsize = AtomicUsize::new(0);
+
+        let scheduler = Scheduler::new();
+        let q = Arc::new(WaitQueue::new());
+
+        const TASKS: usize = 10;
+
+        for _ in 0..TASKS {
+            let q = q.clone();
+            scheduler.spawn(async move {
+                q.wait().await.unwrap();
+                COMPLETED.fetch_add(1, Ordering::SeqCst);
+                q.wake();
+            });
+        }
+
+        let tick = scheduler.tick();
+
+        assert_eq!(tick.completed, 0);
+        assert_eq!(COMPLETED.load(Ordering::SeqCst), 0);
+        assert!(!tick.has_remaining);
+
+        q.wake();
+
+        let tick = scheduler.tick();
+
+        assert_eq!(tick.completed, TASKS);
+        assert_eq!(COMPLETED.load(Ordering::SeqCst), TASKS);
+        assert!(!tick.has_remaining);
+    }
+}
+
+#[cfg(loom)]
+mod loom {
+    use super::*;
+    use crate::loom::{self, future, sync::Arc, thread};
+
+    #[test]
+    fn wake_one() {
+        loom::model(|| {
+            let q = Arc::new(WaitQueue::new());
+            let thread = thread::spawn({
+                let q = q.clone();
+                move || {
+                    future::block_on(async {
+                        q.wait().await.expect("queue must not be closed");
+                    });
+                }
+            });
+
+            q.wake();
+
+            thread.join().unwrap();
+        });
+    }
+
+    #[test]
+    fn wake_all_sequential() {
+        loom::model(|| {
+            let q = Arc::new(WaitQueue::new());
+            let wait1 = q.wait();
+            let wait2 = q.wait();
+
+            let thread = thread::spawn({
+                let q = q.clone();
+                move || {
+                    q.wake_all();
+                }
+            });
+
+            future::block_on(async {
+                wait1.await.unwrap();
+                wait2.await.unwrap();
+            });
+
+            thread.join().unwrap();
+        });
+    }
+
+    #[test]
+    fn wake_all_concurrent() {
+        use alloc::sync::Arc;
+
+        loom::model(|| {
+            let q = Arc::new(WaitQueue::new());
+            let wait1 = q.wait_owned();
+            let wait2 = q.wait_owned();
+
+            let thread1 =
+                thread::spawn(move || future::block_on(wait1).expect("wait1 must not fail"));
+            let thread2 =
+                thread::spawn(move || future::block_on(wait2).expect("wait2 must not fail"));
+
+            q.wake_all();
+
+            thread1.join().unwrap();
+            thread2.join().unwrap();
+        });
+    }
+
+    #[test]
+    fn wake_close() {
+        use alloc::sync::Arc;
+
+        loom::model(|| {
+            let q = Arc::new(WaitQueue::new());
+            let wait1 = q.wait_owned();
+            let wait2 = q.wait_owned();
+
+            let thread1 =
+                thread::spawn(move || future::block_on(wait1).expect_err("wait1 must be canceled"));
+            let thread2 =
+                thread::spawn(move || future::block_on(wait2).expect_err("wait2 must be canceled"));
+
+            q.close();
+
+            thread1.join().unwrap();
+            thread2.join().unwrap();
+        });
+    }
+
+    #[test]
+    fn wake_one_many() {
+        loom::model(|| {
+            let q = Arc::new(WaitQueue::new());
+
+            fn thread(q: &Arc<WaitQueue>) -> thread::JoinHandle<()> {
+                let q = q.clone();
+                thread::spawn(move || {
+                    future::block_on(async {
+                        q.wait().await.expect("queue must not be closed");
+                        q.wake();
+                    })
+                })
+            }
+
+            q.wake();
+
+            let thread1 = thread(&q);
+            let thread2 = thread(&q);
+
+            thread1.join().unwrap();
+            thread2.join().unwrap();
+
+            future::block_on(async {
+                q.wait().await.expect("queue must not be closed");
+            });
+        });
+    }
+
+    #[test]
+    fn wake_mixed() {
+        loom::model(|| {
+            let q = Arc::new(WaitQueue::new());
+
+            let thread1 = thread::spawn({
+                let q = q.clone();
+                move || {
+                    q.wake_all();
+                }
+            });
+
+            let thread2 = thread::spawn({
+                let q = q.clone();
+                move || {
+                    q.wake();
+                }
+            });
+
+            let thread3 = thread::spawn(move || {
+                future::block_on(q.wait()).unwrap();
+            });
+
+            thread1.join().unwrap();
+            thread2.join().unwrap();
+            thread3.join().unwrap();
+        });
+    }
+
+    #[test]
+    fn drop_wait_future() {
+        use futures_util::future::poll_fn;
+        use std::future::Future;
+        use std::task::Poll;
+
+        loom::model(|| {
+            let q = Arc::new(WaitQueue::new());
+
+            let thread1 = thread::spawn({
+                let q = q.clone();
+                move || {
+                    let mut wait = Box::pin(q.wait());
+
+                    future::block_on(poll_fn(|cx| {
+                        if wait.as_mut().poll(cx).is_ready() {
+                            q.wake();
+                        }
+                        Poll::Ready(())
+                    }));
+                }
+            });
+
+            let thread2 = thread::spawn({
+                let q = q.clone();
+                move || {
+                    future::block_on(async {
+                        q.wait().await.unwrap();
+                        // Trigger second notification
+                        q.wake();
+                        q.wait().await.unwrap();
+                    });
+                }
+            });
+
+            q.wake();
+
+            thread1.join().unwrap();
+            thread2.join().unwrap();
+        });
+    }
+}
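The `drop_wait_future` model exercises the cancel-safety requirement from the implementation notes: a `Wait` future that has been polled, and is therefore linked into the intrusive list, must unlink itself when dropped. The same scenario can be sketched without loom (a minimal sketch assuming `maitake` and `futures-util` as dependencies; `noop_waker` is from `futures-util`):

```rust
use std::{
    future::Future,
    task::{Context, Poll},
};

use futures_util::task::noop_waker;
use maitake::wait::WaitQueue;

fn main() {
    let q = WaitQueue::new();
    let waker = noop_waker();
    let mut cx = Context::from_waker(&waker);

    {
        // Poll once so the waiter links itself into the queue's intrusive list...
        let mut wait = Box::pin(q.wait());
        assert!(wait.as_mut().poll(&mut cx).is_pending());
        // ...then drop it while still linked. The `PinnedDrop` impl must
        // unlink the node here, or the list would hold a dangling pointer.
    }

    // The queue remains usable afterwards: with no waiters queued, `wake()`
    // stores a wakeup that the next `wait()` consumes immediately.
    q.wake();
    let mut wait = Box::pin(q.wait());
    assert_eq!(wait.as_mut().poll(&mut cx), Poll::Ready(Ok(())));
}
```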