Skip to content

Commit

Permalink
Prevent race condition in ResponseTracker by using entry which holds …
Browse files Browse the repository at this point in the history
…the lock

This commit resolves the race condition in the ResponseTracker that occurs when creating a new
token with an existing correlation id while the very same correlation id is being completed
at the same time. The solution is to use the entry method of DashMap, which keeps the
lock for as long as the entry is not dropped.

This fixes #1531.
  • Loading branch information
tillrohrmann committed May 19, 2024
1 parent dbf2f3e commit e489a92
Showing 1 changed file with 79 additions and 24 deletions.
103 changes: 79 additions & 24 deletions crates/network/src/rpc_router.rs
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@

use std::sync::{Arc, Weak};

use dashmap::mapref::entry::Entry;
use dashmap::DashMap;
use futures::stream::BoxStream;
use futures::StreamExt;
Expand Down Expand Up @@ -134,31 +135,25 @@ where

/// Registers `correlation_id` as in-flight and returns the token used to await its response.
///
/// Returns None if an in-flight request holds the same correlation_id.
pub fn new_token(&self, correlation_id: T::CorrelationId) -> Option<RpcToken<T>> {
// NOTE(review): this listing is a flattened diff without +/- markers. The statements from
// here through `return None;` appear to be the pre-commit implementation: it inserts the new
// sender unconditionally and only afterwards tries to restore the previous one, which is the
// race the commit message describes.
let (sender, receiver) = oneshot::channel();
let existing = self
.inner
.in_flight
.insert(correlation_id.clone(), RpcTokenSender { sender });

if existing.is_some() {
// in this extraordinary case, we put the old token back even though it wouldn't really
// guarantee correctness since the response might have arrived by now, but we do it
// anyway as a best hope.
self.inner
.in_flight
.entry(correlation_id.clone())
.and_modify(|val| *val = existing.unwrap());
warn!(
"correlation id {:?} was already in-flight when this rpc was issued, this is an indicator that the correlation_id is not unique across RPC calls",
correlation_id
);
return None;
// NOTE(review): the `match` below appears to be the post-commit replacement. Per the commit
// message, the DashMap entry keeps the lock until it is dropped, so the occupied check and
// the insert into the vacant slot cannot interleave with a concurrent completion.
match self.inner.in_flight.entry(correlation_id.clone()) {
Entry::Occupied(_) => {
// An in-flight sender already exists for this id: leave it untouched and refuse the token.
warn!(
"correlation id {:?} was already in-flight when this rpc was issued, this is an indicator that the correlation_id is not unique across RPC calls",
correlation_id
);
None
}
Entry::Vacant(entry) => {
// The channel is only created once the slot is known to be free.
let (sender, receiver) = oneshot::channel();
entry.insert(RpcTokenSender { sender });

Some(RpcToken {
correlation_id,
router: Arc::downgrade(&self.inner),
receiver: Some(receiver),
})
}
}
// NOTE(review): the trailing `Some(RpcToken { .. })` below appears to be the success path of
// the pre-commit implementation (removed side of the diff).
Some(RpcToken {
correlation_id,
router: Arc::downgrade(&self.inner),
receiver: Some(receiver),
})
}

/// Returns None if an in-flight request holds the same correlation_id.
Expand Down Expand Up @@ -308,8 +303,10 @@ where
#[cfg(test)]
mod test {
use super::*;
use futures::future::join_all;
use restate_node_protocol::common::TargetName;
use restate_types::GenerationalNodeId;
use tokio::sync::Barrier;

/// Correlation id used by the tests in this module; a thin wrapper around a `u64`.
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
struct TestCorrelationId(u64);
Expand Down Expand Up @@ -420,4 +417,62 @@ mod test {
assert_eq!(GenerationalNodeId::new(1, 1), from);
assert_eq!("a very real message", msg.text);
}

/// Races `handle_message` against `new_token` for every correlation id that already has an
/// in-flight token, then checks that each original token received the response addressed to it.
#[tokio::test(flavor = "multi_thread")]
async fn concurrent_response_tracker_modifications() {
    let num_responses = 10000;
    let response_tracker = ResponseTracker::default();

    // Register one in-flight token per correlation id before any concurrent work starts.
    let mut rpc_tokens: Vec<RpcToken<TestResponse>> = Vec::new();
    for idx in 0..num_responses {
        let token = response_tracker
            .new_token(TestCorrelationId(idx))
            .expect("first time created");
        rpc_tokens.push(token);
    }

    // Two tasks per correlation id wait on the barrier so that they are all released together.
    let barrier = Arc::new(Barrier::new((2 * num_responses) as usize));

    for idx in 0..num_responses {
        // Task 1: deliver the response for `idx`.
        let completer = response_tracker.clone();
        let completer_barrier = Arc::clone(&barrier);
        tokio::spawn(async move {
            completer_barrier.wait().await;
            let response = TestResponse {
                text: idx.to_string(),
                correlation_id: TestCorrelationId(idx),
            };
            completer.handle_message(MessageEnvelope::new(
                GenerationalNodeId::new(0, 0),
                0,
                response,
            ));
        });

        // Task 2: concurrently try to register the very same correlation id again.
        let registrar = response_tracker.clone();
        let registrar_barrier = Arc::clone(&barrier);
        tokio::spawn(async move {
            registrar_barrier.wait().await;
            registrar.new_token(TestCorrelationId(idx));
        });
    }

    // Every token created up front must complete, regardless of how the race played out.
    let pending = rpc_tokens.into_iter().map(|rpc_token| async move {
        rpc_token
            .recv()
            .await
            .expect("should complete successfully")
    });
    let results = join_all(pending).await;

    for result in results {
        let (_, response) = result.split();
        let parsed_text: u64 = response.text.parse().expect("valid u64");

        // The payload text carries the id it was sent for, so the two must agree.
        assert_eq!(parsed_text, response.correlation_id.0);
    }
}
}

0 comments on commit e489a92

Please sign in to comment.