Skip to content

Commit

Permalink
fix(maitake): handle spurious WaitCell polls
Browse files Browse the repository at this point in the history
This branch changes the `Wait` future for `maitake::sync::WaitCell` to
handle spurious polls correctly. Currently, a `wait_cell::Wait` future
assumes that if it's ever polled a second time, that means its waker was
woken. However, there might be other reasons that a stack of futures
containing a `Wait` is polled again, and the `Wait` future will
incorrectly complete immediately in that case.

This branch fixes this by replacing the `bool` field in `Wait` that's
set on first poll with an "event count" stored in the remaining
`WaitCell` state bits. Now, when a `Wait` is created, it loads the
current event count, and calls to `wake()` and `close()` increment the
event count. The `Wait` future then checks if the event count has gone
up when it's polled, rather than just checking if it's ever been polled
before. This allows the `Wait` future to determine if it is being polled
because the `WaitCell` woke it up, or if it's being polled because some
other future decided to poll it. This *also* has the side benefit of
fixing racy scenarios where the `WaitCell` is woken between when the
`Wait` future is created and when it's polled for the first time.

Fixes #449
  • Loading branch information
hawkw committed Jul 22, 2023
1 parent a517eaa commit 659b917
Showing 1 changed file with 57 additions and 13 deletions.
70 changes: 57 additions & 13 deletions maitake/src/sync/wait_cell.rs
Original file line number Diff line number Diff line change
Expand Up @@ -74,8 +74,8 @@ pub struct Wait<'a> {
/// The [`WaitCell`] being waited on.
cell: &'a WaitCell,

/// Whether we have already polled once
registered: bool,
/// Initial event count
gen: usize,
}

#[derive(Eq, PartialEq, Copy, Clone)]
Expand Down Expand Up @@ -192,10 +192,8 @@ impl WaitCell {
/// **Note**: The calling task's [`Waker`] is not registered until AFTER the
/// first time the returned [`Wait`] future is polled.
pub fn wait(&self) -> Wait<'_> {
Wait {
cell: self,
registered: false,
}
let gen = self.current_state().gen();
Wait { cell: self, gen }
}

/// Wake the [`Waker`] stored in this cell.
Expand All @@ -205,6 +203,7 @@ impl WaitCell {
/// - `true` if a waiting task was woken.
/// - `false` if no task was woken (no [`Waker`] was stored in the cell)
pub fn wake(&self) -> bool {
self.fetch_add(State::GEN_ONE, Release);
if let Some(waker) = self.take_waker(false) {
waker.wake();
true
Expand All @@ -222,6 +221,7 @@ impl WaitCell {
/// [`wait`]: Self::wait
/// [`register_wait`]: Self::register_wait
pub fn close(&self) -> bool {
self.fetch_add(State::GEN_ONE, Release);
if let Some(waker) = self.take_waker(true) {
waker.wake();
true
Expand Down Expand Up @@ -275,6 +275,11 @@ impl WaitCell {
.map_err(State)
}

#[inline(always)]
fn fetch_add(&self, u: usize, order: Ordering) -> State {
State(self.lock.fetch_add(u, order))
}

#[inline(always)]
fn fetch_and(&self, State(state): State, order: Ordering) -> State {
State(self.lock.fetch_and(state, order))
Expand Down Expand Up @@ -314,17 +319,14 @@ impl Drop for WaitCell {
impl Future for Wait<'_> {
type Output = Result<(), super::Closed>;

fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
if self.registered {
fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
if test_dbg!(self.cell.current_state().gen()) != test_dbg!(self.gen) {
// We made it to "once", and got polled again, we must be ready!
return Poll::Ready(Ok(()));
}

match test_dbg!(self.cell.register_wait(cx.waker())) {
Ok(_) => {
self.registered = true;
Poll::Pending
}
Ok(_) => Poll::Pending,
Err(RegisterError::Registering) => {
// Cell was busy parking some other task, all we can do is try again later
cx.waker().wake_by_ref();
Expand All @@ -347,10 +349,17 @@ impl State {
const REGISTERING: Self = Self(0b01);
const WAKING: Self = Self(0b10);
const CLOSED: Self = Self(0b100);
const GEN_SHIFT: usize = 3;
const GEN_ONE: usize = 1 << Self::GEN_SHIFT;
const GEN_MASK: usize = usize::MAX << Self::GEN_SHIFT;

fn is(self, Self(state): Self) -> bool {
self.0 & state == state
}

fn gen(self) -> usize {
self.0 & Self::GEN_MASK
}
}

impl ops::BitOr for State {
Expand Down Expand Up @@ -395,9 +404,12 @@ mod tests {
use crate::scheduler::Scheduler;
use alloc::sync::Arc;

use tokio_test::{assert_pending, assert_ready_ok, task};

#[test]
fn wait_smoke() {
static COMPLETED: AtomicUsize = AtomicUsize::new(0);
let _trace = crate::util::test::trace_init();

let sched = Scheduler::new();
let wait = Arc::new(WaitCell::new());
Expand All @@ -421,7 +433,7 @@ mod tests {
/// Reproduces https://github.com/hawkw/mycelium/issues/449
#[test]
fn wait_spurious_poll() {
use tokio_test::{assert_pending, assert_ready_ok, task};
let _trace = crate::util::test::trace_init();

let cell = Arc::new(WaitCell::new());
let mut task = task::spawn({
Expand All @@ -436,6 +448,38 @@ mod tests {

assert_ready_ok!(task.poll(), "should have been woken");
}

/// Tests behavior when a `Wait` future is created and the `WaitCell` is
/// woken *between* the call to `wait()` and the first time the `Wait` future
/// is polled.
#[test]
fn wake_before_poll() {
let _trace = crate::util::test::trace_init();

let mut task = task::spawn(async move {
let cell = WaitCell::new();
let wait = cell.wait();
cell.wake();
wait.await
});

assert_ready_ok!(task.poll(), "should have been woken");
}

/// Like `wake_before_poll` but with `close()` rather than `wait()`.
#[test]
fn close_before_poll() {
let _trace = crate::util::test::trace_init();

let mut task = task::spawn(async move {
let cell = WaitCell::new();
let wait = cell.wait();
cell.wake();
wait.await
});

assert_ready_ok!(task.poll(), "should have been woken");
}
}

#[cfg(test)]
Expand Down

0 comments on commit 659b917

Please sign in to comment.