diff --git a/rust/sbd-bench/examples/c_count_scale.rs b/rust/sbd-bench/examples/c_count_scale.rs new file mode 100644 index 0000000..9a4ea58 --- /dev/null +++ b/rust/sbd-bench/examples/c_count_scale.rs @@ -0,0 +1,4 @@ +#[tokio::main(flavor = "multi_thread")] +async fn main() { + sbd_bench::c_count_scale(usize::MAX).await; +} diff --git a/rust/sbd-bench/src/c_count_scale.rs b/rust/sbd-bench/src/c_count_scale.rs new file mode 100644 index 0000000..98776fb --- /dev/null +++ b/rust/sbd-bench/src/c_count_scale.rs @@ -0,0 +1,166 @@ +use super::*; +use std::sync::Mutex; + +pub struct Fail(Arc, bool); + +impl Clone for Fail { + fn clone(&self) -> Self { + Self(self.0.clone(), false) + } +} + +impl Drop for Fail { + fn drop(&mut self) { + if self.1 { + self.0.close(); + } + } +} + +impl Default for Fail { + fn default() -> Self { + Self(Arc::new(tokio::sync::Semaphore::new(0)), false) + } +} + +impl Fail { + fn set_fail_on_drop(&mut self, fail_on_drop: bool) { + self.1 = fail_on_drop; + } + + async fn fail(&self) { + let _ = self.0.acquire().await; + } +} + +struct Stats { + addr: std::net::SocketAddr, + last_ip: u32, + client_count: usize, + messages_sent: usize, +} + +impl std::fmt::Debug for Stats { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + f.debug_struct("Stats") + .field("client_count", &self.client_count) + .field("messages_sent", &self.messages_sent) + .finish() + } +} + +async fn create_client(this: &Mutex) -> (SbdClient, MsgRecv) { + let crypto = DefaultCrypto::default(); + + let (ip, addr) = { + let mut lock = this.lock().unwrap(); + let ip = lock.last_ip; + lock.last_ip += 1; + lock.client_count += 1; + (ip, lock.addr) + }; + + let ip = ip.to_be_bytes(); + let ip = std::net::Ipv4Addr::new(ip[0], ip[1], ip[2], ip[3]); + let ip = format!("{ip}"); + + let c = SbdClientConfig { + allow_plain_text: true, + headers: vec![("test-ip".to_string(), ip)], + ..Default::default() + }; + + let url = format!("ws://{}", addr); + + SbdClient::connect_config(&url, &crypto, c).await.unwrap() +} + +pub async fn c_count_scale(max: usize) -> ! { + let config = Arc::new(Config { + bind: vec!["127.0.0.1:0".to_string(), "[::1]:0".to_string()], + trusted_ip_header: Some("test-ip".to_string()), + ..Default::default() + }); + + let server = SbdServer::new(config).await.unwrap(); + + let addr = *server.bind_addrs().first().unwrap(); + + let stats = Arc::new(Mutex::new(Stats { + addr, + last_ip: u32::from_be_bytes([1, 1, 1, 1]), + client_count: 0, + messages_sent: 0, + })); + + let fail = Fail::default(); + + for _ in 0..16 { + let stats = stats.clone(); + let mut fail = fail.clone(); + tokio::task::spawn(async move { + loop { + fail.set_fail_on_drop(true); + + let client_count = stats.lock().unwrap().client_count; + if client_count > max { + fail.set_fail_on_drop(false); + + return; + } + + let (c_a, mut r_a) = create_client(&stats).await; + let (c_b, mut r_b) = create_client(&stats).await; + + let mut fail = fail.clone(); + let stats = stats.clone(); + tokio::task::spawn(async move { + fail.set_fail_on_drop(true); + + loop { + c_a.send(c_b.pub_key(), b"hello").await.unwrap(); + c_b.send(c_a.pub_key(), b"world").await.unwrap(); + let m = r_b.recv().await?; + assert_eq!(b"hello", m.message()); + let m = r_a.recv().await?; + assert_eq!(b"world", m.message()); + + stats.lock().unwrap().messages_sent += 2; + + tokio::time::sleep(std::time::Duration::from_secs(2)) + .await; + } + + #[allow(unreachable_code)] + Some(()) + }); + } + }); + } + + tokio::task::spawn(async move { + loop { + tokio::time::sleep(std::time::Duration::from_secs(3)).await; + + println!("{:?}", *stats.lock().unwrap()); + } + }); + + fail.fail().await; + + panic!("test failed") +} + +#[cfg(test)] +mod tests { + use super::*; + + #[tokio::test(flavor = "multi_thread")] + async fn c_count_scale_test() { + let _ = tokio::time::timeout( + std::time::Duration::from_secs(5), + c_count_scale(300), + ) + .await; + } +} diff --git a/rust/sbd-bench/src/lib.rs b/rust/sbd-bench/src/lib.rs index 56152a6..cf21fd0 100644 --- a/rust/sbd-bench/src/lib.rs +++ b/rust/sbd-bench/src/lib.rs @@ -18,6 +18,7 @@ async fn raw_connect( max_message_size: 1000, allow_plain_text: true, danger_disable_certificate_check: false, + headers: Vec::new(), }) .connect() .await @@ -33,3 +34,6 @@ pub use thru::*; mod c_turnover; pub use c_turnover::*; + +mod c_count_scale; +pub use c_count_scale::*; diff --git a/rust/sbd-client/src/lib.rs b/rust/sbd-client/src/lib.rs index 43c74e5..fd8d50d 100644 --- a/rust/sbd-client/src/lib.rs +++ b/rust/sbd-client/src/lib.rs @@ -190,7 +190,7 @@ impl MsgRecv { } /// Configuration for connecting an SbdClient. -#[derive(Clone, Copy)] +#[derive(Clone)] pub struct SbdClientConfig { /// Outgoing message buffer size. pub out_buffer_size: usize, @@ -202,6 +202,9 @@ pub struct SbdClientConfig { /// scheme. WARNING: this is a dangerous configuration and should not /// be used outside of testing (i.e. self-signed tls certificates). pub danger_disable_certificate_check: bool, + + /// Set any custom http headers to send with the websocket connect. + pub headers: Vec<(String, String)>, } impl Default for SbdClientConfig { @@ -210,6 +213,7 @@ impl Default for SbdClientConfig { out_buffer_size: MAX_MSG_SIZE * 8, allow_plain_text: false, danger_disable_certificate_check: false, + headers: Vec::new(), } } } @@ -259,6 +263,7 @@ impl SbdClient { allow_plain_text: config.allow_plain_text, danger_disable_certificate_check: config .danger_disable_certificate_check, + headers: config.headers, } .connect() .await?; diff --git a/rust/sbd-client/src/raw_client.rs b/rust/sbd-client/src/raw_client.rs index d9f5581..17db367 100644 --- a/rust/sbd-client/src/raw_client.rs +++ b/rust/sbd-client/src/raw_client.rs @@ -14,11 +14,13 @@ pub struct WsRawConnect { /// Setting this to `true` allows `ws://` scheme. pub allow_plain_text: bool, - #[allow(unused_variables)] /// Setting this to `true` disables certificate verification on `wss://` /// scheme. WARNING: this is a dangerous configuration and should not /// be used outside of testing (i.e. self-signed tls certificates). pub danger_disable_certificate_check: bool, + + /// Set any custom http headers to send with the websocket connect. + pub headers: Vec<(String, String)>, } impl WsRawConnect { @@ -29,9 +31,21 @@ impl WsRawConnect { max_message_size, allow_plain_text, danger_disable_certificate_check, + headers, } = self; - let request = tokio_tungstenite::tungstenite::client::IntoClientRequest::into_client_request(full_url).map_err(Error::other)?; + use tokio_tungstenite::tungstenite::client::IntoClientRequest; + let mut request = IntoClientRequest::into_client_request(full_url) + .map_err(Error::other)?; + + for (k, v) in headers { + use tokio_tungstenite::tungstenite::http::header::*; + let k = + HeaderName::from_bytes(k.as_bytes()).map_err(Error::other)?; + let v = + HeaderValue::from_bytes(v.as_bytes()).map_err(Error::other)?; + request.headers_mut().insert(k, v); + } let scheme_ws = request.uri().scheme_str() == Some("ws"); let scheme_wss = request.uri().scheme_str() == Some("wss"); diff --git a/rust/sbd-server/tests/rate_limit_enforced.rs b/rust/sbd-server/tests/rate_limit_enforced.rs index 7b08531..2e1f3b8 100644 --- a/rust/sbd-server/tests/rate_limit_enforced.rs +++ b/rust/sbd-server/tests/rate_limit_enforced.rs @@ -18,6 +18,7 @@ async fn get_client( max_message_size: 100, allow_plain_text: true, danger_disable_certificate_check: false, + headers: Vec::new(), }) .connect() .await