Skip to content

Commit

Permalink
feat: clonable handler
Browse files Browse the repository at this point in the history
  • Loading branch information
Totodore committed Jun 24, 2024
1 parent c7a94b5 commit 302e2a2
Show file tree
Hide file tree
Showing 17 changed files with 113 additions and 146 deletions.
2 changes: 1 addition & 1 deletion e2e/engineioxide/engineioxide.rs
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@ impl EngineIoHandler for MyHandler {
println!("socket disconnect {}: {:?}", socket.id, reason);
}

fn on_message(self: &Arc<Self>, msg: Str, socket: Arc<Socket<Self::Data>>) {
fn on_message(&self, msg: Str, socket: Arc<Socket<Self::Data>>) {
println!("Ping pong message {:?}", msg);
socket.emit(msg).ok();
}
Expand Down
2 changes: 1 addition & 1 deletion engineioxide/src/engine.rs
Original file line number Diff line number Diff line change
Expand Up @@ -115,7 +115,7 @@ mod tests {
println!("socket disconnect {} {:?}", socket.id, reason);
}

fn on_message(self: &Arc<Self>, msg: Str, socket: Arc<Socket<Self::Data>>) {
fn on_message(&self, msg: Str, socket: Arc<Socket<Self::Data>>) {
println!("Ping pong message {:?}", msg);
socket.emit(msg).ok();
}
Expand Down
2 changes: 1 addition & 1 deletion engineioxide/src/handler.rs
Original file line number Diff line number Diff line change
Expand Up @@ -60,7 +60,7 @@ pub trait EngineIoHandler: std::fmt::Debug + Send + Sync + 'static {
fn on_disconnect(&self, socket: Arc<Socket<Self::Data>>, reason: DisconnectReason);

/// Called when a message is received from the client.
fn on_message(self: &Arc<Self>, msg: Str, socket: Arc<Socket<Self::Data>>);
fn on_message(&self, msg: Str, socket: Arc<Socket<Self::Data>>);

/// Called when a binary message is received from the client.
fn on_binary(&self, data: Bytes, socket: Arc<Socket<Self::Data>>);
Expand Down
48 changes: 22 additions & 26 deletions socketioxide/src/client.rs
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@ use engineioxide::Str;
use futures_util::{FutureExt, TryFutureExt};

use engineioxide::sid::Sid;
use matchit::Router;
use matchit::{Match, Router};
use tokio::sync::oneshot;

use crate::adapter::Adapter;
Expand Down Expand Up @@ -54,35 +54,31 @@ impl<A: Adapter> Client<A> {

/// Called when a socket connects to a new namespace
fn sock_connect(
self: Arc<Self>,
&self,
auth: Option<String>,
ns_path: Str,
esocket: Arc<engineioxide::Socket<SocketData<A>>>,
) {
#[cfg(feature = "tracing")]
tracing::debug!("auth: {:?}", auth);
let protocol: ProtocolVersion = esocket.protocol.into();
let esocket_clone = esocket.clone();
let connect = move |ns: Arc<Namespace<A>>| async move {
if ns
.connect(esocket_clone.id, esocket_clone.clone(), auth)
.await
.is_ok()
{
// cancel the connect timeout task for v5
if let Some(tx) = esocket_clone.data.connect_recv_tx.lock().unwrap().take() {
tx.send(()).ok();
let connect =
move |ns: Arc<Namespace<A>>, esocket: Arc<engineioxide::Socket<SocketData<A>>>| async move {
if ns.connect(esocket.id, esocket.clone(), auth).await.is_ok() {
// cancel the connect timeout task for v5
if let Some(tx) = esocket.data.connect_recv_tx.lock().unwrap().take() {
tx.send(()).ok();
}
}
}
};
};

if let Some(ns) = self.get_ns(&ns_path) {
tokio::spawn(connect(ns));
} else if let Ok(res) = self.router.read().unwrap().at(&ns_path) {
tokio::spawn(connect(ns, esocket));
} else if let Ok(Match { value: ns_ctr, .. }) = self.router.read().unwrap().at(&ns_path) {
let path: Cow<'static, str> = Cow::Owned(ns_path.clone().into());
let ns = res.value.get_new_ns(ns_path); //TODO: check memory leak here
let ns = ns_ctr.get_new_ns(ns_path); //TODO: check memory leak here
self.ns.write().unwrap().insert(path, ns.clone());
tokio::spawn(connect(ns));
tokio::spawn(connect(ns, esocket));
} else if protocol == ProtocolVersion::V4 && ns_path == "/" {
#[cfg(feature = "tracing")]
tracing::error!(
Expand Down Expand Up @@ -129,7 +125,7 @@ impl<A: Adapter> Client<A> {
/// Adds a new namespace handler
pub fn add_ns<C, T>(&self, path: Cow<'static, str>, callback: C)
where
C: ConnectHandler<A, T> + Clone,
C: ConnectHandler<A, T>,
T: Send + Sync + 'static,
{
#[cfg(feature = "tracing")]
Expand All @@ -140,7 +136,7 @@ impl<A: Adapter> Client<A> {

pub fn add_dyn_ns<C, T>(&self, path: String, callback: C) -> Result<(), matchit::InsertError>
where
C: ConnectHandler<A, T> + Clone,
C: ConnectHandler<A, T>,
T: Send + Sync + 'static,
{
#[cfg(feature = "tracing")]
Expand Down Expand Up @@ -176,8 +172,8 @@ impl<A: Adapter> Client<A> {
tracing::debug!("closing all namespaces");
let ns = { std::mem::take(&mut *self.ns.write().unwrap()) };
futures_util::future::join_all(
ns.iter()
.map(|(_, ns)| ns.close(DisconnectReason::ClosingServer)),
ns.values()
.map(|ns| ns.close(DisconnectReason::ClosingServer)),
)
.await;
#[cfg(feature = "tracing")]
Expand Down Expand Up @@ -239,8 +235,8 @@ impl<A: Adapter> EngineIoHandler for Client<A> {
.ns
.read()
.unwrap()
.iter()
.filter_map(|(_, ns)| ns.get_socket(socket.id).ok())
.values()
.filter_map(|ns| ns.get_socket(socket.id).ok())
.collect();

let _res: Result<Vec<_>, _> = socks
Expand All @@ -259,7 +255,7 @@ impl<A: Adapter> EngineIoHandler for Client<A> {
}
}

fn on_message(self: &Arc<Self>, msg: Str, socket: Arc<EIoSocket<SocketData<A>>>) {
fn on_message(&self, msg: Str, socket: Arc<EIoSocket<SocketData<A>>>) {
#[cfg(feature = "tracing")]
tracing::debug!("Received message: {:?}", msg);
let packet = match Packet::try_from(msg) {
Expand All @@ -276,7 +272,7 @@ impl<A: Adapter> EngineIoHandler for Client<A> {

let res: Result<(), Error> = match packet.inner {
PacketData::Connect(auth) => {
self.clone().sock_connect(auth, packet.ns, socket.clone());
self.sock_connect(auth, packet.ns, socket.clone());
Ok(())
}
PacketData::BinaryEvent(_, _, _) | PacketData::BinaryAck(_, _) => {
Expand Down
2 changes: 2 additions & 0 deletions socketioxide/src/errors.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,8 @@ use engineioxide::{sid::Sid, socket::DisconnectReason as EIoDisconnectReason};
use std::fmt::{Debug, Display};
use tokio::{sync::mpsc::error::TrySendError, time::error::Elapsed};

pub use matchit::InsertError as NsInsertError;

/// Error type for socketio
#[derive(thiserror::Error, Debug)]
pub enum Error {
Expand Down
3 changes: 1 addition & 2 deletions socketioxide/src/extract/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,6 @@
//! * [`HttpExtension`]: extracts an http extension of the given type coming from the request.
//! (Similar to axum's [`extract::Extension`](https://docs.rs/axum/latest/axum/struct.Extension.html)
//! * [`MaybeHttpExtension`]: extracts an http extension of the given type if it exists or [`None`] otherwise.
//! * [`NsParam`]: extracts and deserialize the namespace path parameters. Works only for the [`ConnectHandler`] and [`ConnectMiddleware`].
//!
//! ### You can also implement your own Extractor with the [`FromConnectParts`], [`FromMessageParts`] and [`FromDisconnectParts`] traits
//! When implementing these traits, if you clone the [`Arc<Socket>`](crate::socket::Socket) make sure that it is dropped at least when the socket is disconnected.
Expand Down Expand Up @@ -59,7 +58,7 @@
//!
//! impl<A: Adapter> FromConnectParts<A> for UserId {
//! type Error = Infallible;
//! fn from_connect_parts(s: &Arc<Socket<A>>, _: &Option<String>, _: &NsParamBuff<'_>) -> Result<Self, Self::Error> {
//! fn from_connect_parts(s: &Arc<Socket<A>>, _: &Option<String>) -> Result<Self, Self::Error> {
//! // In a real app it would be better to parse the query params with a crate like `url`
//! let uri = &s.req_parts().uri;
//! let uid = uri
Expand Down
31 changes: 28 additions & 3 deletions socketioxide/src/handler/connect.rs
Original file line number Diff line number Diff line change
Expand Up @@ -158,7 +158,7 @@ pub trait FromConnectParts<A: Adapter>: Sized {
///
/// * See the [`connect`](super::connect) module doc for more details on connect middlewares.
/// * See the [`extract`](crate::extract) module doc for more details on available extractors.
pub trait ConnectMiddleware<A: Adapter, T>: Send + Sync + 'static {
pub trait ConnectMiddleware<A: Adapter, T>: Sized + Clone + Send + Sync + 'static {
/// Call the middleware with the given arguments.
fn call<'a>(
&'a self,
Expand All @@ -177,7 +177,7 @@ pub trait ConnectMiddleware<A: Adapter, T>: Send + Sync + 'static {
///
/// * See the [`connect`](super::connect) module doc for more details on connect handler.
/// * See the [`extract`](crate::extract) module doc for more details on available extractors.
pub trait ConnectHandler<A: Adapter, T>: Send + Sync + 'static {
pub trait ConnectHandler<A: Adapter, T>: Sized + Clone + Send + Sync + 'static {
/// Call the handler with the given arguments.
fn call(&self, s: Arc<Socket<A>>, auth: Option<String>);

Expand Down Expand Up @@ -236,7 +236,6 @@ pub trait ConnectHandler<A: Adapter, T>: Send + Sync + 'static {
/// ```
fn with<M, T1>(self, middleware: M) -> impl ConnectHandler<A, T>
where
Self: Sized,
M: ConnectMiddleware<A, T1> + Send + Sync + 'static,
T: Send + Sync + 'static,
T1: Send + Sync + 'static,
Expand Down Expand Up @@ -344,6 +343,32 @@ where
self.middleware.call(s, auth).await
}
}
impl<A, H, N, T, T1> Clone for LayeredConnectHandler<A, H, N, T, T1>
where
H: Clone,
N: Clone,
{
fn clone(&self) -> Self {
Self {
handler: self.handler.clone(),
middleware: self.middleware.clone(),
phantom: self.phantom,
}
}
}
impl<M, N, T, T1> Clone for ConnectMiddlewareLayer<M, N, T, T1>
where
M: Clone,
N: Clone,
{
fn clone(&self) -> Self {
Self {
middleware: self.middleware.clone(),
next: self.next.clone(),
phantom: self.phantom,
}
}
}

impl<A, M, N, T, T1> ConnectMiddleware<A, T> for ConnectMiddlewareLayer<M, N, T, T1>
where
Expand Down
41 changes: 17 additions & 24 deletions socketioxide/src/io.rs
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@ use engineioxide::{
config::{EngineIoConfig, EngineIoConfigBuilder},
service::NotFoundService,
sid::Sid,
Str, TransportType,
TransportType,
};

use crate::{
Expand Down Expand Up @@ -360,35 +360,28 @@ impl<A: Adapter> SocketIo<A> {
/// });
///
/// ```
///
/// #### Example with dynamic namespace:
/// ```
/// # use socketioxide::{SocketIo, extract::{NsParam, SocketRef}};
/// #[derive(Debug, serde::Deserialize)]
/// struct Params {
/// id: String,
/// user_id: String
/// }
///
/// let (_svc, io) = SocketIo::new_svc();
/// io.ns("/{id}/user/{user_id}", |s: SocketRef, NsParam(params): NsParam<Params>| {
/// println!("new socket with params: {:?}", params);
/// }).unwrap();
///
/// // You can specify any type that implements the `serde::Deserialize` trait.
/// io.ns("/{id}/admin/{role}", |s: SocketRef, NsParam(params): NsParam<(usize, String)>| {
/// println!("new socket with params: {:?}", params);
/// }).unwrap();
/// ```
#[inline]
pub fn ns<C, T>(&self, path: impl Into<Cow<'static, str>>, callback: C)
where
C: ConnectHandler<A, T> + Clone,
C: ConnectHandler<A, T>,
T: Send + Sync + 'static,
{
self.0.add_ns(path.into(), callback)
}

#[inline]
pub fn dyn_ns<C, T>(
&self,
path: impl Into<String>,
callback: C,
) -> Result<(), crate::NsInsertError>
where
C: ConnectHandler<A, T>,
T: Send + Sync + 'static,
{
self.0.add_dyn_ns(path.into(), callback)
}

/// Deletes the namespace with the given path.
///
/// This will disconnect all sockets connected to this
Expand Down Expand Up @@ -433,9 +426,9 @@ impl<A: Adapter> SocketIo<A> {
///
/// ## Example with a dynamic namespace
/// ```
/// # use socketioxide::{SocketIo, extract::{SocketRef, NsParam}};
/// # use socketioxide::{SocketIo, extract::{SocketRef}};
/// let (_, io) = SocketIo::new_svc();
/// io.ns("/{id}/{user_id}", |socket: SocketRef, NsParam(params): NsParam<(String, String)>| {
/// io.ns("/{id}/{user_id}", |socket: SocketRef| {
/// println!("Socket connected on {} namespace with params {:?}", socket.ns(), params);
/// });
///
Expand Down
10 changes: 6 additions & 4 deletions socketioxide/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -197,7 +197,7 @@
//!
//! Path parameters must be wrapped in curly braces `{}`:
//! ```
//! # use socketioxide::{SocketIo, extract::{NsParam, SocketRef}};
//! # use socketioxide::{SocketIo, extract::SocketRef};
//! #[derive(Debug, serde::Deserialize)]
//! struct Params {
//! id: String,
Expand All @@ -207,12 +207,12 @@
//! let (_svc, io) = SocketIo::new_svc();
//! io.ns("/{id}/user/{user_id}", |s: SocketRef, NsParam(params): NsParam<Params>| {
//! println!("new socket with params: {:?}", params);
//! }).unwrap();
//! });
//!
//! // You can specify any type that implements the `serde::Deserialize` trait.
//! io.ns("/{id}/admin/{role}", |s: SocketRef, NsParam(params): NsParam<(usize, String)>| {
//! println!("new socket with params: {:?}", params);
//! }).unwrap();
//! });
//! ```
//!
//! You can check the [`matchit`] crate for more details on the path parameters format.
Expand Down Expand Up @@ -332,7 +332,9 @@ pub mod service;
pub mod socket;

pub use engineioxide::TransportType;
pub use errors::{AckError, AdapterError, BroadcastError, DisconnectError, SendError, SocketError};
pub use errors::{
AckError, AdapterError, BroadcastError, DisconnectError, NsInsertError, SendError, SocketError,
};
pub use io::{SocketIo, SocketIoBuilder, SocketIoConfig};

mod client;
Expand Down
8 changes: 5 additions & 3 deletions socketioxide/src/ns.rs
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,9 @@ use crate::{
use crate::{client::SocketData, errors::AdapterError};
use engineioxide::{sid::Sid, Str};

/// A [`Namespace`] constructor used for dynamic namespaces
/// A namespace constructor only hold a common handler that will be cloned
/// to the instantiated namespaces.
pub struct NamespaceCtr<A: Adapter> {
handler: BoxedConnectHandler<A>,
}
Expand All @@ -27,7 +30,7 @@ pub struct Namespace<A: Adapter> {
impl<A: Adapter> NamespaceCtr<A> {
pub fn new<C, T>(handler: C) -> Self
where
C: ConnectHandler<A, T> + Clone + Send + Sync + 'static,
C: ConnectHandler<A, T> + Send + Sync + 'static,
T: Send + Sync + 'static,
{
Self {
Expand All @@ -47,7 +50,7 @@ impl<A: Adapter> NamespaceCtr<A> {
impl<A: Adapter> Namespace<A> {
pub fn new<C, T>(path: Str, handler: C) -> Arc<Self>
where
C: ConnectHandler<A, T> + Clone + Send + Sync + 'static,
C: ConnectHandler<A, T> + Send + Sync + 'static,
T: Send + Sync + 'static,
{
Arc::new_cyclic(|ns| Self {
Expand All @@ -70,7 +73,6 @@ impl<A: Adapter> Namespace<A> {
esocket: Arc<engineioxide::Socket<SocketData<A>>>,
auth: Option<String>,
) -> Result<(), ConnectFail> {
// deep-clone to avoid packet memory leak
let socket: Arc<Socket<A>> = Socket::new(sid, self.clone(), esocket.clone()).into();

if let Err(e) = self.handler.call_middleware(socket.clone(), &auth).await {
Expand Down
2 changes: 1 addition & 1 deletion socketioxide/src/operators.rs
Original file line number Diff line number Diff line change
Expand Up @@ -509,7 +509,7 @@ impl<A: Adapter> BroadcastOperators<A> {
},
}
}
pub(crate) fn from_sock(ns: Arc<Namespace<A>>, sid: Sid, ns_path: Str) -> Self {
pub(crate) fn from_sock(ns: Arc<Namespace<A>>, sid: Sid) -> Self {
Self {
binary: vec![],
timeout: None,
Expand Down
1 change: 1 addition & 0 deletions socketioxide/src/socket.rs
Original file line number Diff line number Diff line change
Expand Up @@ -844,6 +844,7 @@ impl<A: Adapter> Socket<A> {
#[cfg(test)]
mod test {
use super::*;
use engineioxide::Str;

#[tokio::test]
async fn send_with_ack_error() {
Expand Down
Loading

0 comments on commit 302e2a2

Please sign in to comment.