Skip to content

Commit

Permalink
fix(server): count http calls in connection guard (#1468)
Browse files Browse the repository at this point in the history
* fix(server): count http calls in connection guard

* fix clippy

* add comments

* fix test
  • Loading branch information
niklasad1 authored Oct 7, 2024
1 parent 1bfea40 commit 7ad8079
Show file tree
Hide file tree
Showing 3 changed files with 92 additions and 16 deletions.
17 changes: 12 additions & 5 deletions server/src/server.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1146,12 +1146,19 @@ where
RpcServiceCfg::OnlyCalls,
));

Box::pin(
http::call_with_service(request, batch_config, max_request_size, rpc_service, max_response_size)
.map(Ok),
)
Box::pin(async move {
let rp =
http::call_with_service(request, batch_config, max_request_size, rpc_service, max_response_size)
.await;
// NOTE: The `conn guard` must be held until the response is processed
// to respect the `max_connections` limit.
drop(conn);
Ok(rp)
})
} else {
Box::pin(async { http::response::denied() }.map(Ok))
// NOTE: the `conn guard` is dropped when this function which is fine
// because it doesn't rely on any async operations.
Box::pin(async { Ok(http::response::denied()) })
}
}
}
Expand Down
21 changes: 11 additions & 10 deletions server/src/tests/ws.rs
Original file line number Diff line number Diff line change
Expand Up @@ -569,16 +569,17 @@ async fn custom_subscription_id_works() {
let addr = server.local_addr().unwrap();
let mut module = RpcModule::new(());
module
.register_subscription("subscribe_hello", "subscribe_hello", "unsubscribe_hello", |_, sink, _, _| async {
let sink = sink.accept().await.unwrap();

assert!(matches!(sink.subscription_id(), SubscriptionId::Str(id) if id == "0xdeadbeef"));

loop {
let _ = &sink;
tokio::time::sleep(std::time::Duration::from_secs(30)).await;
}
})
.register_subscription::<(), _, _>(
"subscribe_hello",
"subscribe_hello",
"unsubscribe_hello",
|_, sink, _, _| async {
let sink = sink.accept().await.unwrap();
assert!(matches!(sink.subscription_id(), SubscriptionId::Str(id) if id == "0xdeadbeef"));
// Keep idle until it's unsubscribed.
futures_util::future::pending::<()>().await;
},
)
.unwrap();
let _handle = server.start(module);

Expand Down
70 changes: 69 additions & 1 deletion tests/tests/integration_tests.rs
Original file line number Diff line number Diff line change
Expand Up @@ -50,7 +50,7 @@ use jsonrpsee::core::server::SubscriptionMessage;
use jsonrpsee::core::{JsonValue, StringError};
use jsonrpsee::http_client::HttpClientBuilder;
use jsonrpsee::server::middleware::http::HostFilterLayer;
use jsonrpsee::server::{ServerBuilder, ServerHandle};
use jsonrpsee::server::{ConnectionGuard, ServerBuilder, ServerHandle};
use jsonrpsee::types::error::{ErrorObject, UNKNOWN_ERROR_CODE};
use jsonrpsee::ws_client::WsClientBuilder;
use jsonrpsee::{rpc_params, ResponsePayload, RpcModule};
Expand Down Expand Up @@ -1543,3 +1543,71 @@ async fn server_ws_low_api_works() {
Ok(local_addr)
}
}

#[tokio::test]
async fn http_connection_guard_works() {
use jsonrpsee::{server::ServerBuilder, RpcModule};
use tokio::sync::mpsc;

init_logger();

let (tx, mut rx) = mpsc::channel::<()>(1);

let server_url = {
let server = ServerBuilder::default().build("127.0.0.1:0").await.unwrap();
let server_url = format!("http://{}", server.local_addr().unwrap());
let mut module = RpcModule::new(tx);

module
.register_async_method("wait_until", |_, wait, _| async move {
wait.send(()).await.unwrap();
wait.closed().await;
true
})
.unwrap();

module
.register_async_method("connection_count", |_, _, ctx| async move {
let conn = ctx.get::<ConnectionGuard>().unwrap();
conn.max_connections() - conn.available_connections()
})
.unwrap();

let handle = server.start(module);

tokio::spawn(handle.stopped());

server_url
};

let waiting_calls: Vec<_> = (0..2)
.map(|_| {
let client = HttpClientBuilder::default().build(&server_url).unwrap();
tokio::spawn(async move {
let _ = client.request::<bool, ArrayParams>("wait_until", rpc_params!()).await;
})
})
.collect();

// Wait until both calls are ACK:ed by the server.
rx.recv().await.unwrap();
rx.recv().await.unwrap();

// Assert that two calls are waiting to be answered and the current one.
{
let client = HttpClientBuilder::default().build(&server_url).unwrap();
let conn_count = client.request::<usize, ArrayParams>("connection_count", rpc_params!()).await.unwrap();
assert_eq!(conn_count, 3);
}

// Complete the waiting calls.
drop(rx);
futures::future::join_all(waiting_calls).await;

// Assert that connection count is back to 1.
{
let client = HttpClientBuilder::default().build(&server_url).unwrap();
let conn_count = client.request::<usize, ArrayParams>("connection_count", rpc_params!()).await.unwrap();
assert_eq!(conn_count, 1);
}
}

0 comments on commit 7ad8079

Please sign in to comment.