Merge branch 'detect-keepalive-timeout-when-not-receiving'
inetic committed Jan 12, 2024
2 parents ea56974 + a5b7d9f commit 58747dd
Showing 3 changed files with 109 additions and 39 deletions.
16 changes: 14 additions & 2 deletions lib/src/network/message.rs
@@ -13,7 +13,7 @@ use crate::{
repository::RepositoryId,
};
use serde::{Deserialize, Serialize};
use std::io::Write;
use std::{fmt, io::Write};

#[derive(Clone, PartialEq, Serialize, Deserialize, Debug)]
pub(crate) enum Request {
@@ -119,7 +119,7 @@ impl Header {
}
}

#[derive(Clone, Eq, PartialEq, Serialize, Deserialize, Debug)]
#[derive(Clone, Eq, PartialEq, Serialize, Deserialize)]
pub(crate) struct Message {
pub tag: Type,
pub channel: MessageChannelId,
@@ -147,6 +147,18 @@ impl Message {
}
}

impl fmt::Debug for Message {
fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
write!(
f,
"Message {{ tag: {:?}, channel: {:?}, content-hash: {:?} }}",
self.tag,
self.channel,
self.content.hash()
)
}
}

impl From<(Header, Vec<u8>)> for Message {
fn from(hdr_and_content: (Header, Vec<u8>)) -> Message {
let hdr = hdr_and_content.0;
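The hand-written `Debug` impl in this file exists because `content` is an opaque payload that can be large; formatting a digest instead of the raw bytes keeps log lines bounded. A minimal self-contained sketch of the same pattern (the field types and the checksum are placeholders, not ouisync's actual `Type`, `MessageChannelId`, or content hash):

```rust
use std::fmt;

// Placeholder message type; the real `Message` uses the `Type` enum,
// `MessageChannelId`, and a cryptographic content hash instead.
struct Message {
    tag: u8,
    channel: u64,
    content: Vec<u8>,
}

impl fmt::Debug for Message {
    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
        // Print a short digest of the payload rather than the payload itself,
        // so a multi-megabyte message still logs as one short line.
        let digest = self
            .content
            .iter()
            .fold(0u32, |acc, &b| acc.wrapping_mul(31).wrapping_add(b as u32));
        write!(
            f,
            "Message {{ tag: {:?}, channel: {:?}, content-hash: {:#010x} }}",
            self.tag, self.channel, digest
        )
    }
}

fn main() {
    let msg = Message { tag: 1, channel: 7, content: vec![0xab; 1 << 20] };
    println!("{:?}", msg); // one bounded line, regardless of content size
}
```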
9 changes: 5 additions & 4 deletions lib/src/network/message_broker.rs
@@ -27,10 +27,11 @@ use tokio::{
};
use tracing::{instrument::Instrument, Span};

/// Maintains one or more connections to a peer, listening on all of them at the same time. Note
/// that at present all the connections are UDP/QUIC based and so dropping some of them would
/// make sense. However, in the future we may also have other transports (e.g. TCP, Bluetooth)
/// and thus keeping them all may make sense because even if one is dropped, the others may still
/// function.
/// Maintains one or more connections to a single peer, listening on all of them at the same time.
/// Note that at present all the connections are UDP/QUIC based and so dropping some of them
/// would make sense. However, in the future we may also have other transports (e.g. TCP,
/// Bluetooth) and thus keeping them all may make sense because even if one is dropped, the
/// others may still function.
///
/// Once a message is received, it is determined whether it is a request or a response. Based on
/// that it either goes to the ClientStream or ServerStream for processing by the Client and Server
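The dispatch the comment above describes amounts to a single match per incoming message. A hypothetical sketch, under the assumption that the peer's requests go to the server half and the peer's responses to the client half (the names are illustrative, not the actual `MessageBroker` internals):

```rust
use tokio::sync::mpsc;

// Placeholder for the decoded message kind; the real broker distinguishes
// requests from responses via the message header.
enum Content {
    Request(Vec<u8>),
    Response(Vec<u8>),
}

// Route each incoming message to the half that processes it: the Server
// answers the peer's requests, the Client consumes the peer's responses.
async fn route(
    mut incoming: mpsc::Receiver<Content>,
    server_tx: mpsc::Sender<Vec<u8>>,
    client_tx: mpsc::Sender<Vec<u8>>,
) {
    while let Some(message) = incoming.recv().await {
        let sent = match message {
            Content::Request(payload) => server_tx.send(payload).await,
            Content::Response(payload) => client_tx.send(payload).await,
        };

        // Either half hanging up ends the routing loop.
        if sent.is_err() {
            break;
        }
    }
}
```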
123 changes: 90 additions & 33 deletions lib/src/network/message_dispatcher.rs
@@ -14,17 +14,21 @@ use crate::{
use async_trait::async_trait;
use deadlock::BlockingMutex;
use futures_util::{ready, stream::SelectAll, Sink, SinkExt, Stream, StreamExt};
use scoped_task::ScopedJoinHandle;
use std::{
future::Future,
pin::Pin,
sync::Arc,
sync::{
atomic::{AtomicBool, Ordering},
Arc,
},
task::{Context, Poll, Waker},
};
use tokio::{
runtime, select,
sync::mpsc::{self, UnboundedReceiver, UnboundedSender},
sync::Mutex as AsyncMutex,
time::Duration,
sync::{Mutex as AsyncMutex, Notify, Semaphore},
time::{self, Duration},
};

@@ -33,7 +37,9 @@
// Time after which the connection is dropped if no message is received.
const KEEP_ALIVE_RECV_INTERVAL: Duration = Duration::from_secs(60);
const KEEP_ALIVE_SEND_INTERVAL: Duration = Duration::from_secs(30);

/// Reads/writes messages from/to the underlying TCP or QUIC streams and dispatches them to
/// individual streams/sinks based on their ids.
/// individual streams/sinks based on their channel ids (in the MessageDispatcher's and
/// MessageBroker's contexts, there is a one-to-one relationship between the channel id and a
/// repository id).
#[derive(Clone)]
pub(super) struct MessageDispatcher {
recv: Arc<RecvState>,
@@ -80,12 +86,12 @@ impl MessageDispatcher {
}

pub async fn close(&self) {
self.recv.reader.close();
self.recv.multi_stream.close();
self.send.close().await;
}

pub fn is_closed(&self) -> bool {
self.recv.reader.is_empty() || self.send.is_empty()
self.recv.multi_stream.is_empty() || self.send.is_empty()
}
}

Expand All @@ -95,7 +101,7 @@ impl Drop for MessageDispatcher {
return;
}

self.recv.reader.close();
self.recv.multi_stream.close();

let send = self.send.clone();

@@ -138,12 +144,12 @@ impl ContentStream {

let mut lock = arc.lock().await;

if let Some((transport, data)) = lock.0.take() {
if let Some((transport, data)) = lock.parked_message.take() {
self.last_transport_id = Some(transport);
return Ok(data);
}

let (transport, data) = match self.state.recv(&mut lock.1).await {
let (transport, data) = match self.state.recv_on_queue(&mut lock.receiver).await {
Some((transport, data)) => (transport, data),
None => return Err(ContentStreamError::ChannelClosed),
};
@@ -153,7 +159,7 @@
Ok(data)
} else {
self.last_transport_id = Some(transport);
lock.0 = Some((transport, data));
lock.parked_message = Some((transport, data));
Err(ContentStreamError::TransportChanged)
}
} else {
@@ -247,7 +253,7 @@ pub(super) struct LiveConnectionInfoSet {
impl LiveConnectionInfoSet {
/// Returns the current infos.
pub fn iter(&self) -> impl Iterator<Item = ConnectionInfo> {
let recv = self.recv.reader.connection_infos();
let recv = self.recv.multi_stream.connection_infos();
let send = self.send.connection_infos();

IntoIntersection::new(recv, send)
@@ -261,28 +267,30 @@ struct ChannelQueue {
tx: ChannelQueueSender,
}

type ChannelQueueReceiver = (
Option<(PermitId, Vec<u8>)>,
UnboundedReceiver<(PermitId, Vec<u8>)>,
);
struct ChannelQueueReceiver {
parked_message: Option<(PermitId, Vec<u8>)>,
receiver: UnboundedReceiver<(PermitId, Vec<u8>)>,
}

type ChannelQueueSender = UnboundedSender<(PermitId, Vec<u8>)>;

struct RecvState {
reader: Arc<MultiStream>,
multi_stream: Arc<MultiStream>,
queues: Arc<BlockingMutex<HashMap<MessageChannelId, ChannelQueue>>>,
single_sorter: Semaphore,
}

impl RecvState {
fn new() -> Self {
Self {
reader: Arc::new(MultiStream::new()),
multi_stream: Arc::new(MultiStream::new()),
queues: Arc::new(BlockingMutex::new(HashMap::default())),
single_sorter: Semaphore::new(1),
}
}

fn add(&self, stream: PermittedStream) {
self.reader.add(stream);
self.multi_stream.add(stream);
}

fn add_channel(&self, channel_id: MessageChannelId) {
@@ -292,14 +300,17 @@ impl RecvState {
let (tx, rx) = mpsc::unbounded_channel();
entry.insert(ChannelQueue {
reference_count: 1,
rx: Arc::new(AsyncMutex::new((None, rx))),
rx: Arc::new(AsyncMutex::new(ChannelQueueReceiver {
parked_message: None,
receiver: rx,
})),
tx,
});
}
}
}

async fn recv(
async fn recv_on_queue(
&self,
queue_rx: &mut UnboundedReceiver<(PermitId, Vec<u8>)>,
) -> Option<(PermitId, Vec<u8>)> {
@@ -310,7 +321,11 @@ impl RecvState {
}

async fn sort_incoming_messages_into_queues(&self) {
while let Some((transport, message)) = self.reader.recv().await {
// The `recv_on_queue` function may be called multiple times (once per channel), but this
// function must never run more than once concurrently, otherwise messages could get
// reordered.
let _permit = self.single_sorter.acquire().await;

while let Some((transport, message)) = self.multi_stream.recv().await {
if let Some(queue) = self.queues.lock().unwrap().get_mut(&message.channel) {
queue.tx.send((transport, message.content)).unwrap_or(());
}
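The `single_sorter` semaphore acquired above is what enforces the at-most-one rule from the comment: two sorting loops running in parallel could pull messages from `multi_stream` in one order and enqueue them in another. A standalone sketch of the idiom, as an assumed simplification of the actual `RecvState`:

```rust
use std::sync::Arc;
use tokio::sync::Semaphore;

// At most one instance of this loop runs at a time; later callers park on
// `acquire` until the running one finishes, instead of interleaving with it.
async fn sort_incoming(single_sorter: Arc<Semaphore>) {
    // Hold the permit for the whole loop so concurrent calls serialize.
    let _permit = single_sorter.acquire().await.expect("semaphore closed");

    // ... pop messages from the shared stream and push each one into its
    // channel's queue, preserving arrival order ...
}
```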
@@ -411,37 +426,56 @@ impl Sink<Message> for PermittedSink {

// Stream that reads `Message`s from multiple underlying raw (byte) streams concurrently.
struct MultiStream {
inner: BlockingMutex<MultiStreamInner>,
explicitly_closed: AtomicBool,
rx: AsyncMutex<mpsc::Receiver<(PermitId, Message)>>,
inner: Arc<BlockingMutex<MultiStreamInner>>,
stream_added: Arc<Notify>,
_runner: ScopedJoinHandle<()>,
}

impl MultiStream {
fn new() -> Self {
const MAX_QUEUED_MESSAGES: usize = 32;

let inner = Arc::new(BlockingMutex::new(MultiStreamInner {
streams: SelectAll::new(),
waker: None,
}));

let stream_added = Arc::new(Notify::new());
let (tx, rx) = mpsc::channel(MAX_QUEUED_MESSAGES);

let _runner =
scoped_task::spawn(multi_stream_runner(inner.clone(), tx, stream_added.clone()));

Self {
inner: BlockingMutex::new(MultiStreamInner {
streams: SelectAll::new(),
waker: None,
}),
explicitly_closed: AtomicBool::new(false),
rx: AsyncMutex::new(rx),
inner,
stream_added,
_runner,
}
}

fn add(&self, stream: PermittedStream) {
let mut inner = self.inner.lock().unwrap();
inner.streams.push(stream);
inner.wake();
self.stream_added.notify_one();
}

// Receive next message from this stream. Equivalent to
//
// ```ignore
// async fn recv(&self) -> Option<Message>;
// ```
fn recv(&self) -> Recv {
Recv { inner: &self.inner }
// Receive next message from this stream.
async fn recv(&self) -> Option<(PermitId, Message)> {
if self.explicitly_closed.load(Ordering::Relaxed) {
return None;
}
self.rx.lock().await.recv().await
}

// Closes this stream. Any subsequent `recv` will immediately return `None` unless new
// streams are added first.
fn close(&self) {
self.explicitly_closed.store(true, Ordering::Relaxed);
let mut inner = self.inner.lock().unwrap();
inner.streams.clear();
inner.wake();
@@ -475,6 +509,29 @@ impl MultiStreamInner {
}
}

// We need this runner because we want to detect when a peer has disconnected even if the user has
// not called `MultiStream::recv`.
async fn multi_stream_runner(
inner: Arc<BlockingMutex<MultiStreamInner>>,
tx: mpsc::Sender<(PermitId, Message)>,
stream_added: Arc<Notify>,
) {
// Wait for at least one stream to be added.
stream_added.notified().await;

while let Some((permit_id, message)) = (Recv { inner: &inner }).await {
// Close the connection if the peer keeps sending messages faster than we can handle them
// in a reasonable time. Note that if we don't have some of the repositories that the peer
// has, they'll still send a small number of messages from their Barrier code. That's fine
// because that number does not exceed MAX_QUEUED_MESSAGES, so the `tx.send` below won't
// block for long.
match time::timeout(KEEP_ALIVE_RECV_INTERVAL, tx.send((permit_id, message))).await {
Ok(Ok(())) => (),
Err(_) | Ok(Err(_)) => break,
}
}
}

// Future returned from [`MultiStream::recv`].
struct Recv<'a> {
inner: &'a BlockingMutex<MultiStreamInner>,
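The new `multi_stream_runner` keeps polling the underlying streams even when nobody is awaiting `MultiStream::recv`, so a disconnected peer is noticed promptly, and the `time::timeout` around the bounded forward drops a peer whose messages are not handled in a reasonable time. A self-contained sketch of that shape, with a plain mpsc channel standing in for the actual network streams:

```rust
use std::time::Duration;
use tokio::{sync::mpsc, time};

const KEEP_ALIVE_RECV_INTERVAL: Duration = Duration::from_secs(60);
const MAX_QUEUED_MESSAGES: usize = 32;

// Pull messages off the wire as soon as they arrive and park them in a
// bounded queue. If the queue is not drained within the keep-alive interval,
// give up: dropping `tx` makes the consumer's `recv` return `None`, which is
// treated as the connection closing.
async fn runner(mut wire: mpsc::Receiver<Vec<u8>>, tx: mpsc::Sender<Vec<u8>>) {
    // `wire.recv()` returning `None` (peer disconnected) also ends the loop,
    // and it is noticed here even if no consumer is currently waiting.
    while let Some(message) = wire.recv().await {
        match time::timeout(KEEP_ALIVE_RECV_INTERVAL, tx.send(message)).await {
            Ok(Ok(())) => (),
            // Timed out, or the consumer hung up.
            Err(_) | Ok(Err(_)) => break,
        }
    }
}

#[tokio::main]
async fn main() {
    let (wire_tx, wire_rx) = mpsc::channel(1);
    let (tx, mut rx) = mpsc::channel(MAX_QUEUED_MESSAGES);
    tokio::spawn(runner(wire_rx, tx));

    wire_tx.send(b"ping".to_vec()).await.unwrap();
    assert_eq!(rx.recv().await.as_deref(), Some(&b"ping"[..]));
}
```

Dropping `tx` when the timeout fires is what propagates the failure: the consumer's next `recv` returns `None`, the same signal it gets on a normal close.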
