diff --git a/src/libstd/sync/mpsc/mpsc_queue.rs b/src/libstd/sync/mpsc/mpsc_queue.rs
index 296773d20f614..4c75ab4c1af7b 100644
--- a/src/libstd/sync/mpsc/mpsc_queue.rs
+++ b/src/libstd/sync/mpsc/mpsc_queue.rs
@@ -114,6 +114,19 @@ impl<T> Queue<T> {
             if self.head.load(Ordering::Acquire) == tail {Empty} else {Inconsistent}
         }
     }
+
+    pub fn can_pop(&self) -> bool {
+        unsafe {
+            let tail = *self.tail.get();
+            let next = (*tail).next.load(Ordering::Acquire);
+
+            if !next.is_null() {
+                true
+            } else {
+                self.head.load(Ordering::Acquire) != tail
+            }
+        }
+    }
 }
 
 impl<T> Drop for Queue<T> {
diff --git a/src/libstd/sync/mpsc/shared.rs b/src/libstd/sync/mpsc/shared.rs
index f9e0290416432..bf1d45c2addc4 100644
--- a/src/libstd/sync/mpsc/shared.rs
+++ b/src/libstd/sync/mpsc/shared.rs
@@ -20,13 +20,11 @@ pub use self::Failure::*;
 
-use core::cmp;
 use core::intrinsics::abort;
 use core::isize;
+use core::usize;
 
-use cell::UnsafeCell;
-use ptr;
-use sync::atomic::{AtomicUsize, AtomicIsize, AtomicBool, Ordering};
+use sync::atomic::{AtomicUsize, AtomicIsize, Ordering};
 use sync::mpsc::blocking::{self, SignalToken};
 use sync::mpsc::mpsc_queue as mpsc;
 use sync::mpsc::select::StartResult::*;
@@ -35,26 +33,97 @@ use sync::{Mutex, MutexGuard};
 use thread;
 use time::Instant;
 
-const DISCONNECTED: isize = isize::MIN;
-const FUDGE: isize = 1024;
+const DISCONNECTED: usize = usize::MAX;
 const MAX_REFCOUNT: usize = (isize::MAX) as usize;
-#[cfg(test)]
-const MAX_STEALS: isize = 5;
-#[cfg(not(test))]
-const MAX_STEALS: isize = 1 << 20;
+
+struct CellDisconnected;
+
+struct SignalTokenCell {
+    // Atomic holder of 0, DISCONNECTED, or a SignalToken pointer
+    token: AtomicUsize,
+}
+
+impl Drop for SignalTokenCell {
+    fn drop(&mut self) {
+        self.take_token();
+    }
+}
+
+impl SignalTokenCell {
+    fn new() -> SignalTokenCell {
+        SignalTokenCell {
+            token: AtomicUsize::new(0)
+        }
+    }
+
+    fn load_is_disconnected(&self) -> bool {
+        self.token.load(Ordering::Relaxed) == DISCONNECTED
+    }
+
+    /// Does not overwrite DISCONNECTED or another token
+    fn store_if_empty(&self, token: SignalToken) {
+        let ptr = unsafe { token.cast_to_usize() };
+        if self.token.compare_and_swap(0, ptr, Ordering::SeqCst) != 0 {
+            unsafe { SignalToken::cast_from_usize(ptr); }
+        }
+    }
+
+    /// Stores the token unless the cell is disconnected, overwriting any previous token
+    fn store_unless_disconnected(&self, token: SignalToken) -> Result<(), CellDisconnected> {
+        let ptr = unsafe { token.cast_to_usize() };
+        let mut curr = self.token.load(Ordering::Relaxed);
+        loop {
+            if curr == DISCONNECTED {
+                unsafe { SignalToken::cast_from_usize(ptr); }
+                return Err(CellDisconnected);
+            }
+            let prev = self.token.compare_and_swap(curr, ptr, Ordering::SeqCst);
+            if prev == curr {
+                if prev != 0 {
+                    unsafe { SignalToken::cast_from_usize(prev); }
+                }
+                return Ok(());
+            }
+            curr = prev;
+        }
+    }
+
+    fn store_disconnected(&self) -> Option<SignalToken> {
+        let ptr = self.token.swap(DISCONNECTED, Ordering::SeqCst);
+        if ptr != 0 && ptr != DISCONNECTED {
+            Some(unsafe { SignalToken::cast_from_usize(ptr) })
+        } else {
+            None
+        }
+    }
+
+    fn take_token(&self) -> Option<SignalToken> {
+        let mut curr = self.token.load(Ordering::SeqCst);
+        loop {
+            if curr == 0 || curr == DISCONNECTED {
+                return None;
+            }
+
+            let prev = self.token.compare_and_swap(curr, 0, Ordering::SeqCst);
+            if prev == curr {
+                return Some(unsafe { SignalToken::cast_from_usize(curr) })
+            }
+
+            curr = prev;
+        }
+    }
+}
+
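A note for readers of this hunk, since the diff never states it in one place: SignalTokenCell packs three states into a single AtomicUsize — 0 means "no waiter registered", usize::MAX (DISCONNECTED) means the receiving side is gone, and any other value is the pointer bits of a SignalToken. Below is a minimal, self-contained sketch of that pointer-or-sentinel pattern with Box<u32> standing in for SignalToken; the names are illustrative and this is not part of the patch.

use std::sync::atomic::{AtomicUsize, Ordering};

const EMPTY: usize = 0;
const DISCONNECTED: usize = usize::MAX;

// Holds nothing, a "disconnected" sentinel, or a heap pointer smuggled through a usize.
struct PtrCell {
    slot: AtomicUsize,
}

impl PtrCell {
    fn new() -> PtrCell {
        PtrCell { slot: AtomicUsize::new(EMPTY) }
    }

    // Publish a value only if the cell is currently empty (compare store_if_empty above).
    fn store_if_empty(&self, value: Box<u32>) {
        let ptr = Box::into_raw(value) as usize;
        if self.slot.compare_and_swap(EMPTY, ptr, Ordering::SeqCst) != EMPTY {
            // Lost the race: reclaim the allocation instead of leaking it.
            unsafe { drop(Box::from_raw(ptr as *mut u32)); }
        }
    }

    // Mark the cell disconnected, handing back whatever was stored (compare store_disconnected).
    fn store_disconnected(&self) -> Option<Box<u32>> {
        let prev = self.slot.swap(DISCONNECTED, Ordering::SeqCst);
        if prev != EMPTY && prev != DISCONNECTED {
            Some(unsafe { Box::from_raw(prev as *mut u32) })
        } else {
            None
        }
    }

    // Take the stored value out; the sentinels stay in place (compare take_token).
    fn take(&self) -> Option<Box<u32>> {
        let mut curr = self.slot.load(Ordering::SeqCst);
        loop {
            if curr == EMPTY || curr == DISCONNECTED {
                return None;
            }
            let prev = self.slot.compare_and_swap(curr, EMPTY, Ordering::SeqCst);
            if prev == curr {
                return Some(unsafe { Box::from_raw(curr as *mut u32) });
            }
            curr = prev;
        }
    }
}

fn main() {
    let cell = PtrCell::new();
    cell.store_if_empty(Box::new(7));
    assert_eq!(cell.take(), Some(Box::new(7)));
    assert_eq!(cell.take(), None);
}

A production version would also reclaim a still-stored pointer on drop, which is exactly what the Drop impl of SignalTokenCell above does.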
 pub struct Packet<T> {
     queue: mpsc::Queue<T>,
-    cnt: AtomicIsize, // How many items are on this channel
-    steals: UnsafeCell<isize>, // How many times has a port received without
-                               // blocking?
-    to_wake: AtomicUsize, // SignalToken for wake up
+    to_wake: SignalTokenCell, // SignalToken for wake up
 
     // The number of channels which are currently using this packet.
     channels: AtomicUsize,
 
     // See the discussion in Port::drop and the channel send methods for what
     // these are used for
-    port_dropped: AtomicBool,
     sender_drain: AtomicIsize,
 
     // this lock protects various portions of this implementation during
@@ -73,11 +142,8 @@ impl<T> Packet<T> {
     pub fn new() -> Packet<T> {
         Packet {
             queue: mpsc::Queue::new(),
-            cnt: AtomicIsize::new(0),
-            steals: UnsafeCell::new(0),
-            to_wake: AtomicUsize::new(0),
+            to_wake: SignalTokenCell::new(),
             channels: AtomicUsize::new(2),
-            port_dropped: AtomicBool::new(false),
             sender_drain: AtomicIsize::new(0),
             select_lock: Mutex::new(()),
         }
@@ -101,30 +167,9 @@ impl<T> Packet<T> {
                            token: Option<SignalToken>,
                            guard: MutexGuard<()>) {
         token.map(|token| {
-            assert_eq!(self.cnt.load(Ordering::SeqCst), 0);
-            assert_eq!(self.to_wake.load(Ordering::SeqCst), 0);
-            self.to_wake.store(unsafe { token.cast_to_usize() }, Ordering::SeqCst);
-            self.cnt.store(-1, Ordering::SeqCst);
-
-            // This store is a little sketchy. What's happening here is that
-            // we're transferring a blocker from a oneshot or stream channel to
-            // this shared channel. In doing so, we never spuriously wake them
-            // up and rather only wake them up at the appropriate time. This
-            // implementation of shared channels assumes that any blocking
-            // recv() will undo the increment of steals performed in try_recv()
-            // once the recv is complete. This thread that we're inheriting,
-            // however, is not in the middle of recv. Hence, the first time we
-            // wake them up, they're going to wake up from their old port, move
-            // on to the upgraded port, and then call the block recv() function.
-            //
-            // When calling this function, they'll find there's data immediately
-            // available, counting it as a steal. This in fact wasn't a steal
-            // because we appropriately blocked them waiting for data.
-            //
-            // To offset this bad increment, we initially set the steal count to
-            // -1. You'll find some special code in abort_selection() as well to
-            // ensure that this -1 steal count doesn't escape too far.
-            unsafe { *self.steals.get() = -1; }
+            // Don't overwrite a signal token that was installed after the
+            // receiver timed out and started waiting again.
+            self.to_wake.store_if_empty(token);
         });
 
         // When the shared packet is constructed, we grabbed this lock. The
@@ -135,86 +180,58 @@ impl<T> Packet<T> {
         drop(guard);
     }
 
+    fn drain_queue_after_disconnected(&self) {
+        assert!(self.to_wake.load_is_disconnected());
+
+        // In this case, we have possibly failed to send our data, and
+        // we need to consider re-popping the data in order to fully
+        // destroy it. We must arbitrate among the multiple senders,
+        // however, because the queues that we're using are
+        // single-consumer queues. In order to do this, all exiting
+        // pushers will use an atomic count in order to count those
+        // flowing through. Pushers who see 0 are required to drain as
+        // much as possible, and then can only exit when they are the
+        // only pusher (otherwise they must try again).
+        if self.sender_drain.fetch_add(1, Ordering::SeqCst) == 0 {
+            loop {
+                // drain the queue, for info on the thread yield see the
+                // discussion in try_recv
+                loop {
+                    match self.queue.pop() {
+                        mpsc::Data(..) => {}
+                        mpsc::Empty => break,
+                        mpsc::Inconsistent => thread::yield_now(),
+                    }
+                }
+                // maybe we're done, if we're not the last ones
+                // here, then we need to go try again.
+                if self.sender_drain.fetch_sub(1, Ordering::SeqCst) == 1 {
+                    break
+                }
+            }
+
+            // At this point, there may still be data on the queue,
+            // but only if the count hasn't been incremented and
+            // some other sender hasn't finished pushing data just
+            // yet. That sender in question will drain its own data.
+        }
+    }
+
     pub fn send(&self, t: T) -> Result<(), T> {
-        // See Port::drop for what's going on
-        if self.port_dropped.load(Ordering::SeqCst) { return Err(t) }
-
-        // Note that the multiple sender case is a little trickier
-        // semantically than the single sender case. The logic for
-        // incrementing is "add and if disconnected store disconnected".
-        // This could end up leading some senders to believe that there
-        // wasn't a disconnect if in fact there was a disconnect. This means
-        // that while one thread is attempting to re-store the disconnected
-        // states, other threads could walk through merrily incrementing
-        // this very-negative disconnected count. To prevent senders from
-        // spuriously attempting to send when the channels is actually
-        // disconnected, the count has a ranged check here.
-        //
-        // This is also done for another reason. Remember that the return
-        // value of this function is:
-        //
-        //   `true` == the data *may* be received, this essentially has no
-        //             meaning
-        //   `false` == the data will *never* be received, this has a lot of
-        //              meaning
-        //
-        // In the SPSC case, we have a check of 'queue.is_empty()' to see
-        // whether the data was actually received, but this same condition
-        // means nothing in a multi-producer context. As a result, this
-        // preflight check serves as the definitive "this will never be
-        // received". Once we get beyond this check, we have permanently
-        // entered the realm of "this may be received"
-        if self.cnt.load(Ordering::SeqCst) < DISCONNECTED + FUDGE {
-            return Err(t)
+        if self.to_wake.load_is_disconnected() {
+            return Err(t);
         }
 
         self.queue.push(t);
-        match self.cnt.fetch_add(1, Ordering::SeqCst) {
-            -1 => {
-                self.take_to_wake().signal();
-            }
-
-            // In this case, we have possibly failed to send our data, and
-            // we need to consider re-popping the data in order to fully
-            // destroy it. We must arbitrate among the multiple senders,
-            // however, because the queues that we're using are
-            // single-consumer queues. In order to do this, all exiting
-            // pushers will use an atomic count in order to count those
-            // flowing through. Pushers who see 0 are required to drain as
-            // much as possible, and then can only exit when they are the
-            // only pusher (otherwise they must try again).
-            n if n < DISCONNECTED + FUDGE => {
-                // see the comment in 'try' for a shared channel for why this
-                // window of "not disconnected" is ok.
-                self.cnt.store(DISCONNECTED, Ordering::SeqCst);
-
-                if self.sender_drain.fetch_add(1, Ordering::SeqCst) == 0 {
-                    loop {
-                        // drain the queue, for info on the thread yield see the
-                        // discussion in try_recv
-                        loop {
-                            match self.queue.pop() {
-                                mpsc::Data(..) => {}
-                                mpsc::Empty => break,
-                                mpsc::Inconsistent => thread::yield_now(),
-                            }
-                        }
-                        // maybe we're done, if we're not the last ones
-                        // here, then we need to go try again.
-                        if self.sender_drain.fetch_sub(1, Ordering::SeqCst) == 1 {
-                            break
-                        }
-                    }
-
-                    // At this point, there may still be data on the queue,
-                    // but only if the count hasn't been incremented and
-                    // some other sender hasn't finished pushing data just
-                    // yet. That sender in question will drain its own data.
-                }
-            }
+        if let Some(token) = self.to_wake.take_token() {
+            token.signal();
+        }
 
-            // Can't make any assumptions about this case like in the SPSC case.
-            _ => {}
+        // Disconnected means the receiver has been dropped just now.
+        // It will not recv any more, but it can still drain under the same lock.
+        if self.to_wake.load_is_disconnected() {
+            self.drain_queue_after_disconnected();
         }
 
         Ok(())
@@ -229,50 +246,38 @@ impl<T> Packet<T> {
         }
 
         let (wait_token, signal_token) = blocking::tokens();
-        if self.decrement(signal_token) == Installed {
-            if let Some(deadline) = deadline {
-                let timed_out = !wait_token.wait_max_until(deadline);
-                if timed_out {
-                    self.abort_selection(false);
-                }
-            } else {
-                wait_token.wait();
-            }
-        }
+
+        // Ignore the disconnected case here; it is checked by the next try_recv.
+        drop(self.to_wake.store_unless_disconnected(signal_token));
 
         match self.try_recv() {
-            data @ Ok(..) => unsafe { *self.steals.get() -= 1; data },
-            data => data,
+            Err(Empty) => {}
+            data => return data,
         }
-    }
-
-    // Essentially the exact same thing as the stream decrement function.
-    // Returns true if blocking should proceed.
-    fn decrement(&self, token: SignalToken) -> StartResult {
-        unsafe {
-            assert_eq!(self.to_wake.load(Ordering::SeqCst), 0);
-            let ptr = token.cast_to_usize();
-            self.to_wake.store(ptr, Ordering::SeqCst);
-
-            let steals = ptr::replace(self.steals.get(), 0);
-
-            match self.cnt.fetch_sub(1 + steals, Ordering::SeqCst) {
-                DISCONNECTED => { self.cnt.store(DISCONNECTED, Ordering::SeqCst); }
-                // If we factor in our steals and notice that the channel has no
-                // data, we successfully sleep
-                n => {
-                    assert!(n >= 0);
-                    if n - steals <= 0 { return Installed }
-                }
-            }
-            self.to_wake.store(0, Ordering::SeqCst);
-            drop(SignalToken::cast_from_usize(ptr));
-            Abort
+        match deadline {
+            Some(deadline) => {
+                wait_token.wait_max_until(deadline);
+            },
+            None => wait_token.wait(),
         }
+
+        // Release memory: drop the token if it is still installed.
+        self.to_wake.take_token();
+
+        self.try_recv()
     }
 
     pub fn try_recv(&self) -> Result<T, Failure> {
+        // The disconnected flag must be loaded before the queue pop to properly
+        // handle a race like this:
+        //
+        //   recv: queue.pop -> empty
+        //   send: queue.push
+        //   send: drop_chan
+        //   recv: check disconnected flag
+        let disconnected = self.to_wake.load_is_disconnected();
+
         let ret = match self.queue.pop() {
             mpsc::Data(t) => Some(t),
             mpsc::Empty => None,
@@ -304,39 +309,12 @@ impl<T> Packet<T> {
             }
         };
 
         match ret {
-            // See the discussion in the stream implementation for why we
-            // might decrement steals.
-            Some(data) => unsafe {
-                if *self.steals.get() > MAX_STEALS {
-                    match self.cnt.swap(0, Ordering::SeqCst) {
-                        DISCONNECTED => {
-                            self.cnt.store(DISCONNECTED, Ordering::SeqCst);
-                        }
-                        n => {
-                            let m = cmp::min(n, *self.steals.get());
-                            *self.steals.get() -= m;
-                            self.bump(n - m);
-                        }
-                    }
-                    assert!(*self.steals.get() >= 0);
-                }
-                *self.steals.get() += 1;
-                Ok(data)
-            },
-
-            // See the discussion in the stream implementation for why we try
-            // again.
+            Some(data) => Ok(data),
             None => {
-                match self.cnt.load(Ordering::SeqCst) {
-                    n if n != DISCONNECTED => Err(Empty),
-                    _ => {
-                        match self.queue.pop() {
-                            mpsc::Data(t) => Ok(t),
-                            mpsc::Empty => Err(Disconnected),
-
-                            // with no senders, an inconsistency is impossible.
-                            mpsc::Inconsistent => unreachable!(),
-                        }
-                    }
+                if disconnected {
+                    Err(Disconnected)
+                } else {
+                    Err(Empty)
                 }
             }
         }
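The ordering comment added to try_recv above is the crux of the new disconnect handling: the flag is sampled before the queue is popped. A rough sketch of the same idea on top of plain std types (a Mutex<VecDeque> instead of the lock-free queue; hypothetical names, not the real implementation) shows why that order matters:

use std::collections::VecDeque;
use std::sync::atomic::{AtomicBool, Ordering};
use std::sync::Mutex;

struct Chan<T> {
    queue: Mutex<VecDeque<T>>,
    disconnected: AtomicBool,
}

enum TryRecv<T> {
    Data(T),
    Empty,
    Disconnected,
}

impl<T> Chan<T> {
    fn try_recv(&self) -> TryRecv<T> {
        // Sample the flag first ...
        let disconnected = self.disconnected.load(Ordering::SeqCst);
        // ... then look at the queue. Every push by a sender happens before
        // that sender's disconnect, so "flag set + queue empty" really means
        // no message can ever arrive.
        match self.queue.lock().unwrap().pop_front() {
            Some(t) => TryRecv::Data(t),
            None if disconnected => TryRecv::Disconnected,
            None => TryRecv::Empty,
        }
    }
}

fn main() {
    let chan = Chan { queue: Mutex::new(VecDeque::new()), disconnected: AtomicBool::new(false) };
    chan.queue.lock().unwrap().push_back(1);
    chan.disconnected.store(true, Ordering::SeqCst);
    // A message pushed before the disconnect is still delivered first.
    match chan.try_recv() {
        TryRecv::Data(v) => assert_eq!(v, 1),
        _ => unreachable!(),
    }
}

With the opposite order — pop first, then read the flag — a sender's push followed by a disconnect could land in between, and the receiver would report Disconnected even though a message is already sitting in the queue.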
@@ -365,39 +343,18 @@ impl<T> Packet<T> {
             n => panic!("bad number of channels left {}", n),
         }
 
-        match self.cnt.swap(DISCONNECTED, Ordering::SeqCst) {
-            -1 => { self.take_to_wake().signal(); }
-            DISCONNECTED => {}
-            n => { assert!(n >= 0); }
+        if let Some(signal) = self.to_wake.store_disconnected() {
+            signal.signal();
         }
     }
 
     // See the long discussion inside of stream.rs for why the queue is drained,
     // and why it is done in this fashion.
     pub fn drop_port(&self) {
-        self.port_dropped.store(true, Ordering::SeqCst);
-        let mut steals = unsafe { *self.steals.get() };
-        while {
-            let cnt = self.cnt.compare_and_swap(steals, DISCONNECTED, Ordering::SeqCst);
-            cnt != DISCONNECTED && cnt != steals
-        } {
-            // See the discussion in 'try_recv' for why we yield
-            // control of this thread.
-            loop {
-                match self.queue.pop() {
-                    mpsc::Data(..) => { steals += 1; }
-                    mpsc::Empty | mpsc::Inconsistent => break,
-                }
-            }
-        }
-    }
+        self.to_wake.store_disconnected();
 
-    // Consumes ownership of the 'to_wake' field.
-    fn take_to_wake(&self) -> SignalToken {
-        let ptr = self.to_wake.load(Ordering::SeqCst);
-        self.to_wake.store(0, Ordering::SeqCst);
-        assert!(ptr != 0);
-        unsafe { SignalToken::cast_from_usize(ptr) }
+        // Must drain under lock, because a sender may also drain in `send`.
+        self.drain_queue_after_disconnected();
     }
 
     ////////////////////////////////////////////////////////////////////////////
@@ -410,19 +367,7 @@ impl<T> Packet<T> {
     // This is different than the stream version because there's no need to peek
     // at the queue, we can just look at the local count.
     pub fn can_recv(&self) -> bool {
-        let cnt = self.cnt.load(Ordering::SeqCst);
-        cnt == DISCONNECTED || cnt - unsafe { *self.steals.get() } > 0
-    }
-
-    // increment the count on the channel (used for selection)
-    fn bump(&self, amt: isize) -> isize {
-        match self.cnt.fetch_add(amt, Ordering::SeqCst) {
-            DISCONNECTED => {
-                self.cnt.store(DISCONNECTED, Ordering::SeqCst);
-                DISCONNECTED
-            }
-            n => n
-        }
+        self.queue.can_pop()
     }
 
     // Inserts the signal token for selection on this port, returning true if
@@ -431,21 +376,18 @@ impl<T> Packet<T> {
     // The code here is the same as in stream.rs, except that it doesn't need to
     // peek at the channel to see if an upgrade is pending.
     pub fn start_selection(&self, token: SignalToken) -> StartResult {
-        match self.decrement(token) {
-            Installed => Installed,
-            Abort => {
-                let prev = self.bump(1);
-                assert!(prev == DISCONNECTED || prev >= 0);
-                Abort
+        if self.can_recv() {
+            StartResult::Abort
+        } else {
+            match self.to_wake.store_unless_disconnected(token) {
+                Ok(()) => Installed,
+                Err(CellDisconnected) => StartResult::Abort,
             }
         }
     }
 
     // Cancels a previous thread waiting on this port, returning whether there's
     // data on the port.
-    //
-    // This is similar to the stream implementation (hence fewer comments), but
-    // uses a different value for the "steals" variable.
    pub fn abort_selection(&self, _was_upgrade: bool) -> bool {
         // Before we do anything else, we bounce on this lock. The reason for
         // doing this is to ensure that any upgrade-in-progress is gone and
@@ -456,40 +398,9 @@ impl<T> Packet<T> {
             let _guard = self.select_lock.lock().unwrap();
         }
 
-        // Like the stream implementation, we want to make sure that the count
-        // on the channel goes non-negative. We don't know how negative the
-        // stream currently is, so instead of using a steal value of 1, we load
-        // the channel count and figure out what we should do to make it
-        // positive.
-        let steals = {
-            let cnt = self.cnt.load(Ordering::SeqCst);
-            if cnt < 0 && cnt != DISCONNECTED {-cnt} else {0}
-        };
-        let prev = self.bump(steals + 1);
+        self.to_wake.take_token();
 
-        if prev == DISCONNECTED {
-            assert_eq!(self.to_wake.load(Ordering::SeqCst), 0);
-            true
-        } else {
-            let cur = prev + steals + 1;
-            assert!(cur >= 0);
-            if prev < 0 {
-                drop(self.take_to_wake());
-            } else {
-                while self.to_wake.load(Ordering::SeqCst) != 0 {
-                    thread::yield_now();
-                }
-            }
-            unsafe {
-                // if the number of steals is -1, it was the pre-emptive -1 steal
-                // count from when we inherited a blocker. This is fine because
-                // we're just going to overwrite it with a real value.
-                let old = self.steals.get();
-                assert!(*old == 0 || *old == -1);
-                *old = steals;
-                prev >= 0
-            }
-        }
+        self.can_recv()
     }
 }
 
@@ -499,8 +410,6 @@ impl<T> Drop for Packet<T> {
         // disconnection, but also a proper fence before the read of
         // `to_wake`, so this assert cannot be removed with also removing
        // the `to_wake` assert.
-        assert_eq!(self.cnt.load(Ordering::SeqCst), DISCONNECTED);
-        assert_eq!(self.to_wake.load(Ordering::SeqCst), 0);
         assert_eq!(self.channels.load(Ordering::SeqCst), 0);
     }
 }
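The new test file below rolls its own spin barrier ("not using mutex/condvar for precision") so the racing threads leave the rendezvous while still on-CPU and hit the narrow race windows as close together as possible. For contrast, the same two-thread rendezvous written with the blocking std::sync::Barrier would look roughly like this (a sketch only, not what the test uses):

use std::sync::{Arc, Barrier};
use std::thread;

fn main() {
    let barrier = Arc::new(Barrier::new(2));
    let b = barrier.clone();
    let t = thread::spawn(move || {
        b.wait();           // both threads are released here...
        // ...racing work for thread 2 goes here
    });
    barrier.wait();         // ...but wake-up latency is scheduler-dependent
    // racing work for thread 1 goes here
    t.join().unwrap();
}

The blocking version is simpler, but it wakes threads through the scheduler, which makes the races these stress tests are hunting much harder to reproduce.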
diff --git a/src/test/run-pass/mpsc_stress.rs b/src/test/run-pass/mpsc_stress.rs
new file mode 100644
index 0000000000000..b70e749e47fd4
--- /dev/null
+++ b/src/test/run-pass/mpsc_stress.rs
@@ -0,0 +1,163 @@
+// Copyright 2017 The Rust Project Developers. See the COPYRIGHT
+// file at the top-level directory of this distribution and at
+// http://rust-lang.org/COPYRIGHT.
+//
+// Licensed under the Apache License, Version 2.0 <LICENSE-APACHE or
+// http://www.apache.org/licenses/LICENSE-2.0> or the MIT license
+// <LICENSE-MIT or http://opensource.org/licenses/MIT>, at your
+// option. This file may not be copied, modified, or distributed
+// except according to those terms.
+
+// compile-flags:--test
+// ignore-emscripten
+
+use std::sync::mpsc::channel;
+use std::sync::mpsc::TryRecvError;
+use std::sync::mpsc::RecvError;
+use std::sync::mpsc::RecvTimeoutError;
+use std::sync::Arc;
+use std::sync::atomic::AtomicUsize;
+use std::sync::atomic::Ordering;
+
+use std::thread;
+use std::time::Duration;
+
+
+/// Simple thread synchronization utility
+struct Barrier {
+    // Not using mutex/condvar for precision
+    shared: Arc<AtomicUsize>,
+    count: usize,
+}
+
+impl Barrier {
+    fn new(count: usize) -> Vec<Barrier> {
+        let shared = Arc::new(AtomicUsize::new(0));
+        (0..count).map(|_| Barrier { shared: shared.clone(), count: count }).collect()
+    }
+
+    fn new2() -> (Barrier, Barrier) {
+        let mut v = Barrier::new(2);
+        (v.pop().unwrap(), v.pop().unwrap())
+    }
+
+    /// Returns when `count` threads enter `wait`
+    fn wait(self) {
+        self.shared.fetch_add(1, Ordering::SeqCst);
+        while self.shared.load(Ordering::SeqCst) != self.count {
+        }
+    }
+}
+
+
+fn shared_close_sender_does_not_lose_messages_iter() {
+    let (tb, rb) = Barrier::new2();
+
+    let (tx, rx) = channel();
+    tx.clone(); // convert to shared
+
+    thread::spawn(move || {
+        tb.wait();
+        tx.send(17).expect("send");
+        drop(tx);
+    });
+
+    let i = rx.into_iter();
+    rb.wait();
+    // Make sure it doesn't return disconnected before returning an element
+    assert_eq!(vec![17], i.collect::<Vec<_>>());
+}
+
+#[test]
+fn shared_close_sender_does_not_lose_messages() {
+    for _ in 0..10000 {
+        shared_close_sender_does_not_lose_messages_iter();
+    }
+}
+
+
+// https://github.com/rust-lang/rust/issues/39364
+fn concurrent_recv_timeout_and_upgrade_iter() {
+    // 1 us
+    let sleep = Duration::new(0, 1_000);
+
+    let (a, b) = Barrier::new2();
+    let (tx, rx) = channel();
+    let th = thread::spawn(move || {
+        a.wait();
+        loop {
+            match rx.recv_timeout(sleep) {
+                Ok(_) => {
+                    break;
+                },
+                Err(_) => {},
+            }
+        }
+    });
+    b.wait();
+    thread::sleep(sleep);
+    tx.clone().send(()).expect("send");
+    th.join().unwrap();
+}
+
+#[test]
+fn concurrent_recv_timeout_and_upgrade() {
+    for _ in 0..10000 {
+        concurrent_recv_timeout_and_upgrade_iter();
+    }
+}
+
+
+fn concurrent_writes_iter() {
+    const THREADS: usize = 4;
+    const PER_THR: usize = 100;
+
+    let mut bs = Barrier::new(THREADS + 1);
+    let (tx, rx) = channel();
+
+    let mut threads = Vec::new();
+    for j in 0..THREADS {
+        let tx = tx.clone();
+        let b = bs.pop().unwrap();
+        threads.push(thread::spawn(move || {
+            b.wait();
+            for i in 0..PER_THR {
+                tx.send(j * 1000 + i).expect("send");
+            }
+        }));
+    }
+
+    let b = bs.pop().unwrap();
+    b.wait();
+
+    let mut v: Vec<_> = rx.iter().take(THREADS * PER_THR).collect();
+    v.sort();
+
+    for j in 0..THREADS {
+        for i in 0..PER_THR {
+            assert_eq!(j * 1000 + i, v[j * PER_THR + i]);
+        }
+    }
+
+    for t in threads {
+        t.join().unwrap();
+    }
+
+    let one_us = Duration::new(0, 1000);
+
+    assert_eq!(TryRecvError::Empty, rx.try_recv().unwrap_err());
+    assert_eq!(RecvTimeoutError::Timeout, rx.recv_timeout(one_us).unwrap_err());
+
+    drop(tx);
+
+    assert_eq!(RecvError, rx.recv().unwrap_err());
+    assert_eq!(RecvTimeoutError::Disconnected, rx.recv_timeout(one_us).unwrap_err());
+    assert_eq!(TryRecvError::Disconnected, rx.try_recv().unwrap_err());
+}
+
+#[test]
+fn concurrent_writes() {
+    for _ in 0..100 {
+        concurrent_writes_iter();
+    }
+}
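For context on the second test above: issue #39364 is, from the user's point of view, nothing more than recv_timeout running concurrently with a Sender::clone that upgrades the channel to the shared flavour. A minimal reproduction in ordinary application code might look like the following sketch, which mirrors concurrent_recv_timeout_and_upgrade_iter without the barrier and the repeated iterations; before this change that pattern could trip an internal assertion (the panic reported in the issue), afterwards it simply receives the value:

use std::sync::mpsc::channel;
use std::thread;
use std::time::Duration;

fn main() {
    let (tx, rx) = channel::<u32>();
    let sender = thread::spawn(move || {
        // Cloning upgrades the channel to the multi-producer flavour
        // while the receiver may be blocked in recv_timeout.
        tx.clone().send(42).expect("send");
    });
    loop {
        match rx.recv_timeout(Duration::new(0, 1_000)) {
            Ok(v) => { assert_eq!(v, 42); break; }
            Err(_) => {}
        }
    }
    sender.join().unwrap();
}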