Skip to content

Commit

Permalink
sync: implement try_recv for mpsc channels (#4113)
Browse files Browse the repository at this point in the history
  • Loading branch information
Darksonn authored Sep 18, 2021
1 parent 8e92f05 commit ddd33f2
Show file tree
Hide file tree
Showing 7 changed files with 349 additions and 8 deletions.
42 changes: 41 additions & 1 deletion tokio/src/sync/mpsc/bounded.rs
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
use crate::sync::batch_semaphore::{self as semaphore, TryAcquireError};
use crate::sync::mpsc::chan;
use crate::sync::mpsc::error::{SendError, TrySendError};
use crate::sync::mpsc::error::{SendError, TryRecvError, TrySendError};

cfg_time! {
use crate::sync::mpsc::error::SendTimeoutError;
Expand Down Expand Up @@ -187,6 +187,46 @@ impl<T> Receiver<T> {
poll_fn(|cx| self.chan.recv(cx)).await
}

/// Try to receive the next value for this receiver.
///
/// This method returns the [`Empty`] error if the channel is currently
/// empty, but there are still outstanding [senders] or [permits].
///
/// This method returns the [`Disconnected`] error if the channel is
/// currently empty, and there are no outstanding [senders] or [permits].
///
/// [`Empty`]: crate::sync::mpsc::error::TryRecvError::Empty
/// [`Disconnected`]: crate::sync::mpsc::error::TryRecvError::Disconnected
/// [senders]: crate::sync::mpsc::Sender
/// [permits]: crate::sync::mpsc::Permit
///
/// # Examples
///
/// ```
/// use tokio::sync::mpsc;
/// use tokio::sync::mpsc::error::TryRecvError;
///
/// #[tokio::main]
/// async fn main() {
/// let (tx, mut rx) = mpsc::channel(100);
///
/// tx.send("hello").await.unwrap();
///
/// assert_eq!(Ok("hello"), rx.try_recv());
/// assert_eq!(Err(TryRecvError::Empty), rx.try_recv());
///
/// tx.send("hello").await.unwrap();
/// // Drop the last sender, closing the channel.
/// drop(tx);
///
/// assert_eq!(Ok("hello"), rx.try_recv());
/// assert_eq!(Err(TryRecvError::Disconnected), rx.try_recv());
/// }
/// ```
pub fn try_recv(&mut self) -> Result<T, TryRecvError> {
self.chan.try_recv()
}

/// Blocking receive to call outside of asynchronous contexts.
///
/// This method returns `None` if the channel has been closed and there are
Expand Down
48 changes: 48 additions & 0 deletions tokio/src/sync/mpsc/chan.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,9 @@ use crate::loom::cell::UnsafeCell;
use crate::loom::future::AtomicWaker;
use crate::loom::sync::atomic::AtomicUsize;
use crate::loom::sync::Arc;
use crate::park::thread::CachedParkThread;
use crate::park::Park;
use crate::sync::mpsc::error::TryRecvError;
use crate::sync::mpsc::list;
use crate::sync::notify::Notify;

Expand Down Expand Up @@ -263,6 +266,51 @@ impl<T, S: Semaphore> Rx<T, S> {
}
})
}

/// Try to receive the next value.
pub(crate) fn try_recv(&mut self) -> Result<T, TryRecvError> {
use super::list::TryPopResult;

self.inner.rx_fields.with_mut(|rx_fields_ptr| {
let rx_fields = unsafe { &mut *rx_fields_ptr };

macro_rules! try_recv {
() => {
match rx_fields.list.try_pop(&self.inner.tx) {
TryPopResult::Ok(value) => {
self.inner.semaphore.add_permit();
return Ok(value);
}
TryPopResult::Closed => return Err(TryRecvError::Disconnected),
TryPopResult::Empty => return Err(TryRecvError::Empty),
TryPopResult::Busy => {} // fall through
}
};
}

try_recv!();

// If a previous `poll_recv` call has set a waker, we wake it here.
// This allows us to put our own CachedParkThread waker in the
// AtomicWaker slot instead.
//
// This is not a spurious wakeup to `poll_recv` since we just got a
// Busy from `try_pop`, which only happens if there are messages in
// the queue.
self.inner.rx_waker.wake();

// Park the thread until the problematic send has completed.
let mut park = CachedParkThread::new();
let waker = park.unpark().into_waker();
loop {
self.inner.rx_waker.register_by_ref(&waker);
// It is possible that the problematic send has now completed,
// so we have to check for messages again.
try_recv!();
park.park().expect("park failed");
}
})
}
}

impl<T, S: Semaphore> Drop for Rx<T, S> {
Expand Down
24 changes: 24 additions & 0 deletions tokio/src/sync/mpsc/error.rs
Original file line number Diff line number Diff line change
Expand Up @@ -51,6 +51,30 @@ impl<T> From<SendError<T>> for TrySendError<T> {
}
}

// ===== TryRecvError =====

/// Error returned by `try_recv`.
#[derive(PartialEq, Eq, Clone, Copy, Debug)]
pub enum TryRecvError {
/// This **channel** is currently empty, but the **Sender**(s) have not yet
/// disconnected, so data may yet become available.
Empty,
/// The **channel**'s sending half has become disconnected, and there will
/// never be any more data received on it.
Disconnected,
}

impl fmt::Display for TryRecvError {
fn fmt(&self, fmt: &mut fmt::Formatter<'_>) -> fmt::Result {
match *self {
TryRecvError::Empty => "receiving on an empty channel".fmt(fmt),
TryRecvError::Disconnected => "receiving on a closed channel".fmt(fmt),
}
}
}

impl Error for TryRecvError {}

// ===== RecvError =====

/// Error returned by `Receiver`.
Expand Down
42 changes: 37 additions & 5 deletions tokio/src/sync/mpsc/list.rs
Original file line number Diff line number Diff line change
Expand Up @@ -13,23 +13,35 @@ pub(crate) struct Tx<T> {
/// Tail in the `Block` mpmc list.
block_tail: AtomicPtr<Block<T>>,

/// Position to push the next message. This reference a block and offset
/// Position to push the next message. This references a block and offset
/// into the block.
tail_position: AtomicUsize,
}

/// List queue receive handle
pub(crate) struct Rx<T> {
/// Pointer to the block being processed
/// Pointer to the block being processed.
head: NonNull<Block<T>>,

/// Next slot index to process
/// Next slot index to process.
index: usize,

/// Pointer to the next block pending release
/// Pointer to the next block pending release.
free_head: NonNull<Block<T>>,
}

/// Return value of `Rx::try_pop`.
pub(crate) enum TryPopResult<T> {
/// Successfully popped a value.
Ok(T),
/// The channel is empty.
Empty,
/// The channel is empty and closed.
Closed,
/// The channel is not empty, but the first value is being written.
Busy,
}

pub(crate) fn channel<T>() -> (Tx<T>, Rx<T>) {
// Create the initial block shared between the tx and rx halves.
let initial_block = Box::new(Block::new(0));
Expand Down Expand Up @@ -218,7 +230,7 @@ impl<T> fmt::Debug for Tx<T> {
}

impl<T> Rx<T> {
/// Pops the next value off the queue
/// Pops the next value off the queue.
pub(crate) fn pop(&mut self, tx: &Tx<T>) -> Option<block::Read<T>> {
// Advance `head`, if needed
if !self.try_advancing_head() {
Expand All @@ -240,6 +252,26 @@ impl<T> Rx<T> {
}
}

/// Pops the next value off the queue, detecting whether the block
/// is busy or empty on failure.
///
/// This function exists because `Rx::pop` can return `None` even if the
/// channel's queue contains a message that has been completely written.
/// This can happen if the fully delivered message is behind another message
/// that is in the middle of being written to the block, since the channel
/// can't return the messages out of order.
pub(crate) fn try_pop(&mut self, tx: &Tx<T>) -> TryPopResult<T> {
let tail_position = tx.tail_position.load(Acquire);
let result = self.pop(tx);

match result {
Some(block::Read::Value(t)) => TryPopResult::Ok(t),
Some(block::Read::Closed) => TryPopResult::Closed,
None if tail_position == self.index => TryPopResult::Empty,
None => TryPopResult::Busy,
}
}

/// Tries advancing the block pointer to the block referenced by `self.index`.
///
/// Returns `true` if successful, `false` if there is no next block to load.
Expand Down
42 changes: 41 additions & 1 deletion tokio/src/sync/mpsc/unbounded.rs
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
use crate::loom::sync::atomic::AtomicUsize;
use crate::sync::mpsc::chan;
use crate::sync::mpsc::error::SendError;
use crate::sync::mpsc::error::{SendError, TryRecvError};

use std::fmt;
use std::task::{Context, Poll};
Expand Down Expand Up @@ -129,6 +129,46 @@ impl<T> UnboundedReceiver<T> {
poll_fn(|cx| self.poll_recv(cx)).await
}

/// Try to receive the next value for this receiver.
///
/// This method returns the [`Empty`] error if the channel is currently
/// empty, but there are still outstanding [senders] or [permits].
///
/// This method returns the [`Disconnected`] error if the channel is
/// currently empty, and there are no outstanding [senders] or [permits].
///
/// [`Empty`]: crate::sync::mpsc::error::TryRecvError::Empty
/// [`Disconnected`]: crate::sync::mpsc::error::TryRecvError::Disconnected
/// [senders]: crate::sync::mpsc::Sender
/// [permits]: crate::sync::mpsc::Permit
///
/// # Examples
///
/// ```
/// use tokio::sync::mpsc;
/// use tokio::sync::mpsc::error::TryRecvError;
///
/// #[tokio::main]
/// async fn main() {
/// let (tx, mut rx) = mpsc::unbounded_channel();
///
/// tx.send("hello").unwrap();
///
/// assert_eq!(Ok("hello"), rx.try_recv());
/// assert_eq!(Err(TryRecvError::Empty), rx.try_recv());
///
/// tx.send("hello").unwrap();
/// // Drop the last sender, closing the channel.
/// drop(tx);
///
/// assert_eq!(Ok("hello"), rx.try_recv());
/// assert_eq!(Err(TryRecvError::Disconnected), rx.try_recv());
/// }
/// ```
pub fn try_recv(&mut self) -> Result<T, TryRecvError> {
self.chan.try_recv()
}

/// Blocking receive to call outside of asynchronous contexts.
///
/// # Panics
Expand Down
56 changes: 56 additions & 0 deletions tokio/src/sync/tests/loom_mpsc.rs
Original file line number Diff line number Diff line change
Expand Up @@ -132,3 +132,59 @@ fn dropping_unbounded_tx() {
assert!(v.is_none());
});
}

#[test]
fn try_recv() {
loom::model(|| {
use crate::sync::{mpsc, Semaphore};
use loom::sync::{Arc, Mutex};

const PERMITS: usize = 2;
const TASKS: usize = 2;
const CYCLES: usize = 1;

struct Context {
sem: Arc<Semaphore>,
tx: mpsc::Sender<()>,
rx: Mutex<mpsc::Receiver<()>>,
}

fn run(ctx: &Context) {
block_on(async {
let permit = ctx.sem.acquire().await;
assert_ok!(ctx.rx.lock().unwrap().try_recv());
crate::task::yield_now().await;
assert_ok!(ctx.tx.clone().try_send(()));
drop(permit);
});
}

let (tx, rx) = mpsc::channel(PERMITS);
let sem = Arc::new(Semaphore::new(PERMITS));
let ctx = Arc::new(Context {
sem,
tx,
rx: Mutex::new(rx),
});

for _ in 0..PERMITS {
assert_ok!(ctx.tx.clone().try_send(()));
}

let mut ths = Vec::new();

for _ in 0..TASKS {
let ctx = ctx.clone();

ths.push(thread::spawn(move || {
run(&ctx);
}));
}

run(&ctx);

for th in ths {
th.join().unwrap();
}
});
}
Loading

0 comments on commit ddd33f2

Please sign in to comment.