
prunes turbine QUIC connections #33663

Merged
1 change: 1 addition & 0 deletions core/src/validator.rs
@@ -1184,6 +1184,7 @@ impl Validator {
.expect("Operator must spin up node with valid QUIC TVU address")
.ip(),
turbine_quic_endpoint_sender,
bank_forks.clone(),
)
.unwrap();

158 changes: 146 additions & 12 deletions turbine/src/quic_endpoint.rs
@@ -10,15 +10,20 @@ use {
rcgen::RcgenError,
rustls::{Certificate, PrivateKey},
solana_quic_client::nonblocking::quic_client::SkipServerVerification,
solana_runtime::bank_forks::BankForks,
solana_sdk::{pubkey::Pubkey, signature::Keypair},
solana_streamer::{
quic::SkipClientVerification, tls_certificates::new_self_signed_tls_certificate,
},
std::{
cmp::Reverse,
collections::{hash_map::Entry, HashMap},
io::Error as IoError,
net::{IpAddr, SocketAddr, UdpSocket},
sync::Arc,
sync::{
atomic::{AtomicBool, Ordering},
Arc, RwLock,
},
},
thiserror::Error,
tokio::{
@@ -32,6 +37,7 @@ use {

const CLIENT_CHANNEL_BUFFER: usize = 1 << 14;
const ROUTER_CHANNEL_BUFFER: usize = 64;
const CONNECTION_CACHE_CAPACITY: usize = 3072;
Contributor:
Is this number chosen because it's roughly the size of mainnet currently? It seems reasonable, but I wanted to be sure.

Contributor Author:
Yeah, basically. I think it would be too slow to have to reconnect before sending the shred, so we effectively need to cache all the connections.
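
For reference, a minimal standalone sketch of the pruning policy this number feeds into (the types and the map contents are simplified stand-ins, not the PR's actual code): the cache may grow to twice CONNECTION_CACHE_CAPACITY, and is then pruned back down to the CONNECTION_CACHE_CAPACITY highest-staked peers.

use std::{cmp::Reverse, collections::HashMap};

// Simplified stand-ins for the real Pubkey / Connection types.
type Pubkey = u64;
type Connection = String;

const CONNECTION_CACHE_CAPACITY: usize = 3072;

// Prune the cache back down to the CONNECTION_CACHE_CAPACITY highest-staked
// peers once it has grown to twice that size.
fn prune_to_capacity(
    cache: &mut HashMap<Pubkey, Connection>,
    staked_nodes: &HashMap<Pubkey, u64>,
) {
    if cache.len() < CONNECTION_CACHE_CAPACITY.saturating_mul(2) {
        return; // Below the pruning threshold; nothing to do.
    }
    let mut entries: Vec<_> = cache
        .drain()
        .map(|(pubkey, connection)| {
            let stake = staked_nodes.get(&pubkey).copied().unwrap_or_default();
            (stake, (pubkey, connection))
        })
        .collect();
    // Partition so the CONNECTION_CACHE_CAPACITY highest-staked entries come
    // first; in the PR, everything past the pivot is closed with
    // CONNECTION_CLOSE_ERROR_CODE_PRUNED before being dropped.
    entries.select_nth_unstable_by_key(CONNECTION_CACHE_CAPACITY, |&(stake, _)| Reverse(stake));
    cache.extend(
        entries
            .into_iter()
            .take(CONNECTION_CACHE_CAPACITY)
            .map(|(_, entry)| entry),
    );
}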

const INITIAL_MAXIMUM_TRANSMISSION_UNIT: u16 = 1280;
const ALPN_TURBINE_PROTOCOL_ID: &[u8] = b"solana-turbine";
const CONNECT_SERVER_NAME: &str = "solana-turbine";
@@ -40,11 +46,13 @@ const CONNECTION_CLOSE_ERROR_CODE_SHUTDOWN: VarInt = VarInt::from_u32(1);
const CONNECTION_CLOSE_ERROR_CODE_DROPPED: VarInt = VarInt::from_u32(2);
const CONNECTION_CLOSE_ERROR_CODE_INVALID_IDENTITY: VarInt = VarInt::from_u32(3);
const CONNECTION_CLOSE_ERROR_CODE_REPLACED: VarInt = VarInt::from_u32(4);
const CONNECTION_CLOSE_ERROR_CODE_PRUNED: VarInt = VarInt::from_u32(5);

const CONNECTION_CLOSE_REASON_SHUTDOWN: &[u8] = b"SHUTDOWN";
const CONNECTION_CLOSE_REASON_DROPPED: &[u8] = b"DROPPED";
const CONNECTION_CLOSE_REASON_INVALID_IDENTITY: &[u8] = b"INVALID_IDENTITY";
const CONNECTION_CLOSE_REASON_REPLACED: &[u8] = b"REPLACED";
const CONNECTION_CLOSE_REASON_PRUNED: &[u8] = b"PRUNED";

pub type AsyncTryJoinHandle = TryJoin<JoinHandle<()>, JoinHandle<()>>;

@@ -75,6 +83,7 @@ pub fn new_quic_endpoint(
socket: UdpSocket,
address: IpAddr,
sender: Sender<(Pubkey, SocketAddr, Bytes)>,
bank_forks: Arc<RwLock<BankForks>>,
) -> Result<
(
Endpoint,
@@ -98,19 +107,24 @@
)?
};
endpoint.set_default_client_config(client_config);
let prune_cache_pending = Arc::<AtomicBool>::default();
let cache = Arc::<Mutex<HashMap<Pubkey, Connection>>>::default();
let router = Arc::<AsyncRwLock<HashMap<SocketAddr, AsyncSender<Bytes>>>>::default();
let (client_sender, client_receiver) = tokio::sync::mpsc::channel(CLIENT_CHANNEL_BUFFER);
let server_task = runtime.spawn(run_server(
endpoint.clone(),
sender.clone(),
bank_forks.clone(),
prune_cache_pending.clone(),
router.clone(),
cache.clone(),
));
let client_task = runtime.spawn(run_client(
endpoint.clone(),
client_receiver,
sender,
bank_forks,
prune_cache_pending,
router,
cache,
));
@@ -163,6 +177,8 @@ fn new_transport_config() -> TransportConfig {
async fn run_server(
endpoint: Endpoint,
sender: Sender<(Pubkey, SocketAddr, Bytes)>,
bank_forks: Arc<RwLock<BankForks>>,
prune_cache_pending: Arc<AtomicBool>,
router: Arc<AsyncRwLock<HashMap<SocketAddr, AsyncSender<Bytes>>>>,
cache: Arc<Mutex<HashMap<Pubkey, Connection>>>,
) {
@@ -171,6 +187,8 @@ async fn run_server(
endpoint.clone(),
connecting,
sender.clone(),
bank_forks.clone(),
prune_cache_pending.clone(),
router.clone(),
cache.clone(),
));
@@ -181,6 +199,8 @@ async fn run_client(
endpoint: Endpoint,
mut receiver: AsyncReceiver<(SocketAddr, Bytes)>,
sender: Sender<(Pubkey, SocketAddr, Bytes)>,
bank_forks: Arc<RwLock<BankForks>>,
prune_cache_pending: Arc<AtomicBool>,
router: Arc<AsyncRwLock<HashMap<SocketAddr, AsyncSender<Bytes>>>>,
cache: Arc<Mutex<HashMap<Pubkey, Connection>>>,
) {
@@ -203,6 +223,8 @@ async fn run_client(
remote_address,
sender.clone(),
receiver,
bank_forks.clone(),
prune_cache_pending.clone(),
router.clone(),
cache.clone(),
));
@@ -234,10 +256,22 @@ async fn handle_connecting_error(
endpoint: Endpoint,
connecting: Connecting,
sender: Sender<(Pubkey, SocketAddr, Bytes)>,
bank_forks: Arc<RwLock<BankForks>>,
prune_cache_pending: Arc<AtomicBool>,
router: Arc<AsyncRwLock<HashMap<SocketAddr, AsyncSender<Bytes>>>>,
cache: Arc<Mutex<HashMap<Pubkey, Connection>>>,
) {
if let Err(err) = handle_connecting(endpoint, connecting, sender, router, cache).await {
if let Err(err) = handle_connecting(
endpoint,
connecting,
sender,
bank_forks,
prune_cache_pending,
router,
cache,
)
.await
{
error!("handle_connecting: {err:?}");
}
}
@@ -246,6 +280,8 @@ async fn handle_connecting(
endpoint: Endpoint,
connecting: Connecting,
sender: Sender<(Pubkey, SocketAddr, Bytes)>,
bank_forks: Arc<RwLock<BankForks>>,
prune_cache_pending: Arc<AtomicBool>,
router: Arc<AsyncRwLock<HashMap<SocketAddr, AsyncSender<Bytes>>>>,
cache: Arc<Mutex<HashMap<Pubkey, Connection>>>,
) -> Result<(), Error> {
@@ -264,24 +300,37 @@
connection,
sender,
receiver,
bank_forks,
prune_cache_pending,
router,
cache,
)
.await;
Ok(())
}

#[allow(clippy::too_many_arguments)]
async fn handle_connection(
endpoint: Endpoint,
remote_address: SocketAddr,
remote_pubkey: Pubkey,
connection: Connection,
sender: Sender<(Pubkey, SocketAddr, Bytes)>,
receiver: AsyncReceiver<Bytes>,
bank_forks: Arc<RwLock<BankForks>>,
prune_cache_pending: Arc<AtomicBool>,
router: Arc<AsyncRwLock<HashMap<SocketAddr, AsyncSender<Bytes>>>>,
cache: Arc<Mutex<HashMap<Pubkey, Connection>>>,
) {
cache_connection(remote_pubkey, connection.clone(), &cache).await;
cache_connection(
remote_pubkey,
connection.clone(),
bank_forks,
prune_cache_pending,
router.clone(),
cache.clone(),
)
.await;
let send_datagram_task = tokio::task::spawn(send_datagram_task(connection.clone(), receiver));
let read_datagram_task = tokio::task::spawn(read_datagram_task(
endpoint,
@@ -351,11 +400,22 @@ async fn make_connection_task(
remote_address: SocketAddr,
sender: Sender<(Pubkey, SocketAddr, Bytes)>,
receiver: AsyncReceiver<Bytes>,
bank_forks: Arc<RwLock<BankForks>>,
prune_cache_pending: Arc<AtomicBool>,
router: Arc<AsyncRwLock<HashMap<SocketAddr, AsyncSender<Bytes>>>>,
cache: Arc<Mutex<HashMap<Pubkey, Connection>>>,
) {
if let Err(err) =
make_connection(endpoint, remote_address, sender, receiver, router, cache).await
if let Err(err) = make_connection(
endpoint,
remote_address,
sender,
receiver,
bank_forks,
prune_cache_pending,
router,
cache,
)
.await
{
error!("make_connection: {remote_address}, {err:?}");
}
@@ -366,6 +426,8 @@ async fn make_connection(
remote_address: SocketAddr,
sender: Sender<(Pubkey, SocketAddr, Bytes)>,
receiver: AsyncReceiver<Bytes>,
bank_forks: Arc<RwLock<BankForks>>,
prune_cache_pending: Arc<AtomicBool>,
router: Arc<AsyncRwLock<HashMap<SocketAddr, AsyncSender<Bytes>>>>,
cache: Arc<Mutex<HashMap<Pubkey, Connection>>>,
) -> Result<(), Error> {
@@ -379,6 +441,8 @@
connection,
sender,
receiver,
bank_forks,
prune_cache_pending,
router,
cache,
)
@@ -402,15 +466,32 @@ fn get_remote_pubkey(connection: &Connection) -> Result<Pubkey, Error> {
async fn cache_connection(
remote_pubkey: Pubkey,
connection: Connection,
cache: &Mutex<HashMap<Pubkey, Connection>>,
bank_forks: Arc<RwLock<BankForks>>,
prune_cache_pending: Arc<AtomicBool>,
router: Arc<AsyncRwLock<HashMap<SocketAddr, AsyncSender<Bytes>>>>,
cache: Arc<Mutex<HashMap<Pubkey, Connection>>>,
) {
let Some(old) = cache.lock().await.insert(remote_pubkey, connection) else {
return;
let (old, should_prune_cache) = {
let mut cache = cache.lock().await;
(
cache.insert(remote_pubkey, connection),
cache.len() >= CONNECTION_CACHE_CAPACITY.saturating_mul(2),
)
};
old.close(
CONNECTION_CLOSE_ERROR_CODE_REPLACED,
CONNECTION_CLOSE_REASON_REPLACED,
);
if let Some(old) = old {
old.close(
CONNECTION_CLOSE_ERROR_CODE_REPLACED,
CONNECTION_CLOSE_REASON_REPLACED,
);
}
if should_prune_cache && !prune_cache_pending.swap(true, Ordering::Relaxed) {
tokio::task::spawn(prune_connection_cache(
bank_forks,
prune_cache_pending,
router,
cache,
));
}
}

async fn drop_connection(
Expand All @@ -429,6 +510,50 @@ async fn drop_connection(
}
}

async fn prune_connection_cache(
bank_forks: Arc<RwLock<BankForks>>,
prune_cache_pending: Arc<AtomicBool>,
router: Arc<AsyncRwLock<HashMap<SocketAddr, AsyncSender<Bytes>>>>,
cache: Arc<Mutex<HashMap<Pubkey, Connection>>>,
) {
debug_assert!(prune_cache_pending.load(Ordering::Relaxed));
let staked_nodes = {
let root_bank = bank_forks.read().unwrap().root_bank();
root_bank.staked_nodes()
};
{
let mut cache = cache.lock().await;
if cache.len() < CONNECTION_CACHE_CAPACITY.saturating_mul(2) {
prune_cache_pending.store(false, Ordering::Relaxed);
return;
}
let mut connections: Vec<_> = cache
.drain()
.filter(|(_, connection)| connection.close_reason().is_none())
.map(|entry @ (pubkey, _)| {
let stake = staked_nodes.get(&pubkey).copied().unwrap_or_default();
(stake, entry)
})
.collect();
connections
.select_nth_unstable_by_key(CONNECTION_CACHE_CAPACITY, |&(stake, _)| Reverse(stake));
for (_, (_, connection)) in &connections[CONNECTION_CACHE_CAPACITY..] {
connection.close(
CONNECTION_CLOSE_ERROR_CODE_PRUNED,
CONNECTION_CLOSE_REASON_PRUNED,
);
}
cache.extend(
connections
.into_iter()
.take(CONNECTION_CACHE_CAPACITY)
.map(|(_, entry)| entry),
);
prune_cache_pending.store(false, Ordering::Relaxed);
}
router.write().await.retain(|_, sender| !sender.is_closed());
Contributor:
I've been trying to see when the receiver would be closed in order to prune these senders properly, and I want to make sure that this will work as intended.

The receiver side is plumbed down into send_datagram_task, which loops on receiver.recv(), so it would only error when it receives data and then tries to call connection.send_datagram, at which point it gets dropped. This would be bad because it won't get closed immediately.

On the other side, read_datagram_task will fail immediately on the next read after close, which will make both tasks terminate immediately since they're joined with try_join. This means that the receiver is dropped correctly, so that you can prune the senders on this side.

Do I have that right?

Contributor Author:
That is a good question and might have spotted a bug in this code. I think I need a patch like below to stop waiting on the receiver if the connection is already closed. I will do this in a separate PR because it is orthogonal to the cache pruning logic here.

https://github.com/solana-labs/solana/blob/c98c24bd6/turbine/src/quic_endpoint.rs#L343

diff --git a/turbine/src/quic_endpoint.rs b/turbine/src/quic_endpoint.rs
index 0f362fd1a3..a911aef5ad 100644
--- a/turbine/src/quic_endpoint.rs
+++ b/turbine/src/quic_endpoint.rs
@@ -389,10 +389,19 @@ async fn send_datagram_task(
     connection: Connection,
     mut receiver: AsyncReceiver<Bytes>,
 ) -> Result<(), Error> {
-    while let Some(bytes) = receiver.recv().await {
-        connection.send_datagram(bytes)?;
+    loop {
+        tokio::select! {
+            biased;
+
+            bytes = receiver.recv() => {
+                match bytes {
+                    None => return Ok(()),
+                    Some(bytes) => connection.send_datagram(bytes)?,
+                }
+            }
+            err = connection.closed() => return Err(Error::from(err)),
+        }
     }
-    Ok(())
 }
 
 async fn make_connection_task(

Contributor:
> I think I need a patch like below to stop waiting on receiver if the connection is already closed.

Yeah that makes sense to me, it'll make the code clearer about the error condition. The implicit cancellation of this task through the error in the other task with try_join was a little tricky to understand.

Feel free to do it in a follow-up PR!
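
As an aside, here is a minimal sketch (not part of the PR) of the property that the router.write().await.retain(|_, sender| !sender.is_closed()) line above relies on, assuming AsyncSender is a tokio mpsc sender as in this module: once the receiving half of the channel is dropped, Sender::is_closed() returns true, so the stale router entry can be pruned.

use std::collections::HashMap;

#[tokio::main]
async fn main() {
    // Stand-in for one router entry: the sender half that shred bytes are pushed into.
    let (sender, receiver) = tokio::sync::mpsc::channel::<Vec<u8>>(64);
    assert!(!sender.is_closed());

    // When send_datagram_task exits (e.g. after its connection is closed),
    // the receiver it owns is dropped, which closes the channel.
    drop(receiver);
    assert!(sender.is_closed());

    // The prune task can then drop the dead entry from the router map.
    let mut router = HashMap::from([("remote_address", sender)]);
    router.retain(|_, sender| !sender.is_closed());
    assert!(router.is_empty());
}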

}

impl<T> From<crossbeam_channel::SendError<T>> for Error {
fn from(_: crossbeam_channel::SendError<T>) -> Self {
Error::ChannelSendError
@@ -440,6 +565,8 @@ mod tests {
use {
super::*,
itertools::{izip, multiunzip},
solana_ledger::genesis_utils::{create_genesis_config, GenesisConfigInfo},
solana_runtime::bank::Bank,
solana_sdk::signature::Signer,
std::{iter::repeat_with, net::Ipv4Addr, time::Duration},
};
@@ -467,6 +594,12 @@ mod tests {
repeat_with(crossbeam_channel::unbounded::<(Pubkey, SocketAddr, Bytes)>)
.take(NUM_ENDPOINTS)
.unzip();
let bank_forks = {
let GenesisConfigInfo { genesis_config, .. } =
create_genesis_config(/*mint_lamports:*/ 100_000);
let bank = Bank::new_for_tests(&genesis_config);
Arc::new(RwLock::new(BankForks::new(bank)))
};
let (endpoints, senders, tasks): (Vec<_>, Vec<_>, Vec<_>) =
multiunzip(keypairs.iter().zip(sockets).zip(senders).map(
|((keypair, socket), sender)| {
@@ -476,6 +609,7 @@
socket,
IpAddr::V4(Ipv4Addr::LOCALHOST),
sender,
bank_forks.clone(),
)
.unwrap()
},