Skip to content

Commit

Permalink
fix(maitake): handle spurious WaitCell polls (#453)
Browse files Browse the repository at this point in the history
This branch changes the `Wait` future for `maitake::sync::WaitCell` to
handle spurious polls correctly. Currently, a `wait_cell::Wait` future
assumes that if it's ever polled a second time, that means its waker was
woken. However, there might be other reasons that a stack of futures
containing a `Wait` is polled again, and the `Wait` future will
incorrectly complete immediately in that case.

This branch fixes this by replacing the `bool` field in `Wait` that's
set on first poll with an additional bit stored in the `WaitCell`'s
state field. This is set when the cell is actually woken, and only
unset by the `Wait` future when it's polled. If the `WOKEN` bit was
set, the `Wait` future completes, and if it was unset, the future
re-registers itself. This way, the `Wait` future only completes if it
was *woken by the waitcell*, rather than on *any* poll if the task was
woken by something else.

Fixes #449
  • Loading branch information
hawkw authored Jul 23, 2023
1 parent 20c3ec2 commit 46b1131
Showing 1 changed file with 35 additions and 24 deletions.
59 changes: 35 additions & 24 deletions maitake/src/sync/wait_cell.rs
Original file line number Diff line number Diff line change
Expand Up @@ -73,9 +73,6 @@ pub enum RegisterError {
pub struct Wait<'a> {
/// The [`WaitCell`] being waited on.
cell: &'a WaitCell,

/// Whether we have already polled once
registered: bool,
}

#[derive(Eq, PartialEq, Copy, Clone)]
Expand Down Expand Up @@ -125,16 +122,12 @@ impl WaitCell {
Err(actual) if test_dbg!(actual.is(State::CLOSED)) => {
return Err(RegisterError::Closed);
}
Err(actual) if test_dbg!(actual.is(State::WAKING)) => {
Err(actual)
if test_dbg!(actual.is(State::WAKING)) || test_dbg!(actual.is(State::WOKEN)) =>
{
return Err(RegisterError::Waking);
}

Err(actual) => {
debug_assert!(
actual == State::REGISTERING || actual == State::REGISTERING | State::WAKING
);
return Err(RegisterError::Registering);
}
Err(_) => return Err(RegisterError::Registering),
Ok(_) => {}
}

Expand Down Expand Up @@ -192,10 +185,7 @@ impl WaitCell {
/// **Note**: The calling task's [`Waker`] is not registered until AFTER the
/// first time the returned [`Wait`] future is polled.
pub fn wait(&self) -> Wait<'_> {
Wait {
cell: self,
registered: false,
}
Wait { cell: self }
}

/// Wake the [`Waker`] stored in this cell.
Expand Down Expand Up @@ -242,7 +232,7 @@ impl WaitCell {
// TODO(eliza): could probably be made a public API...
pub(crate) fn take_waker(&self, close: bool) -> Option<Waker> {
trace!(wait_cell = ?fmt::ptr(self), ?close, "notifying");
let mut bits = State::WAKING;
let mut bits = State::WAKING | State::WOKEN;
if close {
bits.0 |= State::CLOSED.0;
}
Expand Down Expand Up @@ -314,17 +304,15 @@ impl Drop for WaitCell {
impl Future for Wait<'_> {
type Output = Result<(), super::Closed>;

fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
if self.registered {
// We made it to "once", and got polled again, we must be ready!
fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
// Try to take the cell's `WOKEN` bit to see if we were previously
// waiting and then received a notification.
if test_dbg!(self.cell.fetch_and(!State::WOKEN, AcqRel)).is(State::WOKEN) {
return Poll::Ready(Ok(()));
}

match test_dbg!(self.cell.register_wait(cx.waker())) {
Ok(_) => {
self.registered = true;
Poll::Pending
}
Ok(_) => Poll::Pending,
Err(RegisterError::Registering) => {
// Cell was busy parking some other task, all we can do is try again later
cx.waker().wake_by_ref();
Expand All @@ -347,6 +335,7 @@ impl State {
const REGISTERING: Self = Self(0b01);
const WAKING: Self = Self(0b10);
const CLOSED: Self = Self(0b100);
const WOKEN: Self = Self(0b1000);

fn is(self, Self(state): Self) -> bool {
self.0 & state == state
Expand All @@ -373,7 +362,7 @@ impl fmt::Debug for State {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
let mut has_states = false;

fmt_bits!(self, f, has_states, REGISTERING, WAKING, CLOSED);
fmt_bits!(self, f, has_states, REGISTERING, WAKING, CLOSED, WOKEN);

if !has_states {
if *self == Self::WAITING {
Expand All @@ -395,9 +384,12 @@ mod tests {
use crate::scheduler::Scheduler;
use alloc::sync::Arc;

use tokio_test::{assert_pending, assert_ready_ok, task};

#[test]
fn wait_smoke() {
static COMPLETED: AtomicUsize = AtomicUsize::new(0);
let _trace = crate::util::test::trace_init();

let sched = Scheduler::new();
let wait = Arc::new(WaitCell::new());
Expand All @@ -417,6 +409,25 @@ mod tests {
assert_eq!(tick.completed, 1);
assert_eq!(COMPLETED.load(Ordering::Relaxed), 1);
}

/// Reproduces https://github.com/hawkw/mycelium/issues/449
#[test]
fn wait_spurious_poll() {
let _trace = crate::util::test::trace_init();

let cell = Arc::new(WaitCell::new());
let mut task = task::spawn({
let cell = cell.clone();
async move { cell.wait().await }
});

assert_pending!(task.poll(), "first poll should be pending");
assert_pending!(task.poll(), "second poll should be pending");

cell.wake();

assert_ready_ok!(task.poll(), "should have been woken");
}
}

#[cfg(test)]
Expand Down

0 comments on commit 46b1131

Please sign in to comment.