Skip to content

Commit

Permalink
Add Sniffer::wait_for_message_type_with_remove..
Browse files Browse the repository at this point in the history
fn
  • Loading branch information
jbesraa committed Jan 9, 2025
1 parent aa3f30a commit 937609e
Show file tree
Hide file tree
Showing 2 changed files with 85 additions and 2 deletions.
48 changes: 48 additions & 0 deletions roles/tests-integration/lib/sniffer.rs
Original file line number Diff line number Diff line change
Expand Up @@ -457,6 +457,38 @@ impl Sniffer {
}
}

pub async fn wait_for_message_type_with_remove(
&self,
message_direction: MessageDirection,
message_type: u8,
) -> bool {
let now = std::time::Instant::now();
loop {
let has_message_type = match message_direction {
MessageDirection::ToDownstream => self
.messages_from_upstream
.has_message_type_with_remove(message_type),
MessageDirection::ToUpstream => self
.messages_from_downstream
.has_message_type_with_remove(message_type),
};

// ready to unblock test runtime
if has_message_type {
return true;
}

// 10 min timeout
// only for worst case, ideally should never be triggered
if now.elapsed().as_secs() > 10 * 60 {
panic!("Timeout waiting for message type");
}

// sleep to reduce async lock contention
sleep(Duration::from_secs(1)).await;
}
}

pub async fn includes_message_type(
&self,
message_direction: MessageDirection,
Expand Down Expand Up @@ -672,6 +704,22 @@ impl MessagesAggregator {
has_message
}

fn has_message_type_with_remove(&self, message_type: u8) -> bool {
self.messages
.safe_lock(|messages| {
let mut cloned_messages = messages.clone();
for (pos, (t, _)) in cloned_messages.iter().enumerate() {
if *t == message_type {
let drained = cloned_messages.drain(pos + 1..).collect();
*messages = drained;
return true;
}
}
false
})
.unwrap()
}

// The aggregator queues messages in FIFO order, so this function returns the oldest message in
// the queue.
//
Expand Down
39 changes: 37 additions & 2 deletions roles/tests-integration/tests/sniffer_integration.rs
Original file line number Diff line number Diff line change
@@ -1,4 +1,7 @@
use const_sv2::MESSAGE_TYPE_SETUP_CONNECTION_ERROR;
use const_sv2::{
MESSAGE_TYPE_SETUP_CONNECTION_ERROR, MESSAGE_TYPE_SETUP_CONNECTION_SUCCESS,
MESSAGE_TYPE_SET_NEW_PREV_HASH,
};
use integration_tests_sv2::*;
use roles_logic_sv2::{
common_messages_sv2::SetupConnectionError,
Expand All @@ -10,7 +13,6 @@ use std::convert::TryInto;
#[tokio::test]
async fn test_sniffer_interrupter() {
let (_tp, tp_addr) = start_template_provider(None).await;
use const_sv2::MESSAGE_TYPE_SETUP_CONNECTION_SUCCESS;
let message =
PoolMessages::Common(CommonMessages::SetupConnectionError(SetupConnectionError {
flags: 0,
Expand All @@ -33,3 +35,36 @@ async fn test_sniffer_interrupter() {
assert_common_message!(&sniffer.next_message_from_downstream(), SetupConnection);
assert_common_message!(&sniffer.next_message_from_upstream(), SetupConnectionError);
}

#[tokio::test]
async fn test_sniffer_wait_for_message_type_with_remove() {
let (_tp, tp_addr) = start_template_provider(None).await;
let (sniffer, sniffer_addr) = start_sniffer("".to_string(), tp_addr, false, None).await;
let _ = start_pool(Some(sniffer_addr)).await;
assert!(
sniffer
.wait_for_message_type_with_remove(
MessageDirection::ToDownstream,
MESSAGE_TYPE_SET_NEW_PREV_HASH,
)
.await
);
assert_eq!(
sniffer
.includes_message_type(
MessageDirection::ToDownstream,
MESSAGE_TYPE_SETUP_CONNECTION_SUCCESS
)
.await,
false
);
assert_eq!(
sniffer
.includes_message_type(
MessageDirection::ToDownstream,
MESSAGE_TYPE_SET_NEW_PREV_HASH
)
.await,
false
);
}

0 comments on commit 937609e

Please sign in to comment.