Skip to content

Commit

Permalink
connection count scale test (#24)
Browse files Browse the repository at this point in the history
  • Loading branch information
neonphog authored May 18, 2024
1 parent 8f76229 commit 0bcb451
Show file tree
Hide file tree
Showing 6 changed files with 197 additions and 3 deletions.
4 changes: 4 additions & 0 deletions rust/sbd-bench/examples/c_count_scale.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,4 @@
#[tokio::main(flavor = "multi_thread")]
async fn main() {
sbd_bench::c_count_scale(usize::MAX).await;
}
166 changes: 166 additions & 0 deletions rust/sbd-bench/src/c_count_scale.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,166 @@
use super::*;
use std::sync::Mutex;

pub struct Fail(Arc<tokio::sync::Semaphore>, bool);

impl Clone for Fail {
fn clone(&self) -> Self {
Self(self.0.clone(), false)
}
}

impl Drop for Fail {
fn drop(&mut self) {
if self.1 {
self.0.close();
}
}
}

impl Default for Fail {
fn default() -> Self {
Self(Arc::new(tokio::sync::Semaphore::new(0)), false)
}
}

impl Fail {
fn set_fail_on_drop(&mut self, fail_on_drop: bool) {
self.1 = fail_on_drop;
}

async fn fail(&self) {
let _ = self.0.acquire().await;
}
}

struct Stats {
addr: std::net::SocketAddr,
last_ip: u32,
client_count: usize,
messages_sent: usize,
}

impl std::fmt::Debug for Stats {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
f.debug_struct("Stats")
.field("client_count", &self.client_count)
.field("messages_sent", &self.messages_sent)
.finish()
}
}

async fn create_client(this: &Mutex<Stats>) -> (SbdClient, MsgRecv) {
let crypto = DefaultCrypto::default();

let (ip, addr) = {
let mut lock = this.lock().unwrap();
let ip = lock.last_ip;
lock.last_ip += 1;
lock.client_count += 1;
(ip, lock.addr)
};

let ip = ip.to_be_bytes();
let ip = std::net::Ipv4Addr::new(ip[0], ip[1], ip[2], ip[3]);
let ip = format!("{ip}");

let c = SbdClientConfig {
allow_plain_text: true,
headers: vec![("test-ip".to_string(), ip)],
..Default::default()
};

let url = format!("ws://{}", addr);

SbdClient::connect_config(&url, &crypto, c).await.unwrap()
}

pub async fn c_count_scale(max: usize) -> ! {
let config = Arc::new(Config {
bind: vec!["127.0.0.1:0".to_string(), "[::1]:0".to_string()],
trusted_ip_header: Some("test-ip".to_string()),
..Default::default()
});

let server = SbdServer::new(config).await.unwrap();

let addr = *server.bind_addrs().first().unwrap();

let stats = Arc::new(Mutex::new(Stats {
addr,
last_ip: u32::from_be_bytes([1, 1, 1, 1]),
client_count: 0,
messages_sent: 0,
}));

let fail = Fail::default();

for _ in 0..16 {
let stats = stats.clone();
let mut fail = fail.clone();
tokio::task::spawn(async move {
loop {
fail.set_fail_on_drop(true);

let client_count = stats.lock().unwrap().client_count;
if client_count > max {
fail.set_fail_on_drop(false);

return;
}

let (c_a, mut r_a) = create_client(&stats).await;
let (c_b, mut r_b) = create_client(&stats).await;

let mut fail = fail.clone();
let stats = stats.clone();
tokio::task::spawn(async move {
fail.set_fail_on_drop(true);

loop {
c_a.send(c_b.pub_key(), b"hello").await.unwrap();
c_b.send(c_a.pub_key(), b"world").await.unwrap();
let m = r_b.recv().await?;
assert_eq!(b"hello", m.message());
let m = r_a.recv().await?;
assert_eq!(b"world", m.message());

stats.lock().unwrap().messages_sent += 2;

tokio::time::sleep(std::time::Duration::from_secs(2))
.await;
}

#[allow(unreachable_code)]
Some(())
});
}
});
}

tokio::task::spawn(async move {
loop {
tokio::time::sleep(std::time::Duration::from_secs(3)).await;

println!("{:?}", *stats.lock().unwrap());
}
});

fail.fail().await;

panic!("test failed")
}

#[cfg(test)]
mod tests {
use super::*;

#[tokio::test(flavor = "multi_thread")]
async fn c_count_scale_test() {
let _ = tokio::time::timeout(
std::time::Duration::from_secs(5),
c_count_scale(300),
)
.await;
}
}
4 changes: 4 additions & 0 deletions rust/sbd-bench/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@ async fn raw_connect(
max_message_size: 1000,
allow_plain_text: true,
danger_disable_certificate_check: false,
headers: Vec::new(),
})
.connect()
.await
Expand All @@ -33,3 +34,6 @@ pub use thru::*;

mod c_turnover;
pub use c_turnover::*;

mod c_count_scale;
pub use c_count_scale::*;
7 changes: 6 additions & 1 deletion rust/sbd-client/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -190,7 +190,7 @@ impl MsgRecv {
}

/// Configuration for connecting an SbdClient.
#[derive(Clone, Copy)]
#[derive(Clone)]
pub struct SbdClientConfig {
/// Outgoing message buffer size.
pub out_buffer_size: usize,
Expand All @@ -202,6 +202,9 @@ pub struct SbdClientConfig {
/// scheme. WARNING: this is a dangerous configuration and should not
/// be used outside of testing (i.e. self-signed tls certificates).
pub danger_disable_certificate_check: bool,

/// Set any custom http headers to send with the websocket connect.
pub headers: Vec<(String, String)>,
}

impl Default for SbdClientConfig {
Expand All @@ -210,6 +213,7 @@ impl Default for SbdClientConfig {
out_buffer_size: MAX_MSG_SIZE * 8,
allow_plain_text: false,
danger_disable_certificate_check: false,
headers: Vec::new(),
}
}
}
Expand Down Expand Up @@ -259,6 +263,7 @@ impl SbdClient {
allow_plain_text: config.allow_plain_text,
danger_disable_certificate_check: config
.danger_disable_certificate_check,
headers: config.headers,
}
.connect()
.await?;
Expand Down
18 changes: 16 additions & 2 deletions rust/sbd-client/src/raw_client.rs
Original file line number Diff line number Diff line change
Expand Up @@ -14,11 +14,13 @@ pub struct WsRawConnect {
/// Setting this to `true` allows `ws://` scheme.
pub allow_plain_text: bool,

#[allow(unused_variables)]
/// Setting this to `true` disables certificate verification on `wss://`
/// scheme. WARNING: this is a dangerous configuration and should not
/// be used outside of testing (i.e. self-signed tls certificates).
pub danger_disable_certificate_check: bool,

/// Set any custom http headers to send with the websocket connect.
pub headers: Vec<(String, String)>,
}

impl WsRawConnect {
Expand All @@ -29,9 +31,21 @@ impl WsRawConnect {
max_message_size,
allow_plain_text,
danger_disable_certificate_check,
headers,
} = self;

let request = tokio_tungstenite::tungstenite::client::IntoClientRequest::into_client_request(full_url).map_err(Error::other)?;
use tokio_tungstenite::tungstenite::client::IntoClientRequest;
let mut request = IntoClientRequest::into_client_request(full_url)
.map_err(Error::other)?;

for (k, v) in headers {
use tokio_tungstenite::tungstenite::http::header::*;
let k =
HeaderName::from_bytes(k.as_bytes()).map_err(Error::other)?;
let v =
HeaderValue::from_bytes(v.as_bytes()).map_err(Error::other)?;
request.headers_mut().insert(k, v);
}

let scheme_ws = request.uri().scheme_str() == Some("ws");
let scheme_wss = request.uri().scheme_str() == Some("wss");
Expand Down
1 change: 1 addition & 0 deletions rust/sbd-server/tests/rate_limit_enforced.rs
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@ async fn get_client(
max_message_size: 100,
allow_plain_text: true,
danger_disable_certificate_check: false,
headers: Vec::new(),
})
.connect()
.await
Expand Down

0 comments on commit 0bcb451

Please sign in to comment.