Skip to content

Commit

Permalink
feat: clonable handler
Browse files Browse the repository at this point in the history
  • Loading branch information
Totodore committed Jun 24, 2024
1 parent c7a94b5 commit 9df042f
Show file tree
Hide file tree
Showing 25 changed files with 138 additions and 316 deletions.
2 changes: 1 addition & 1 deletion e2e/engineioxide/engineioxide.rs
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@ impl EngineIoHandler for MyHandler {
println!("socket disconnect {}: {:?}", socket.id, reason);
}

fn on_message(self: &Arc<Self>, msg: Str, socket: Arc<Socket<Self::Data>>) {
fn on_message(&self, msg: Str, socket: Arc<Socket<Self::Data>>) {
println!("Ping pong message {:?}", msg);
socket.emit(msg).ok();
}
Expand Down
15 changes: 9 additions & 6 deletions engineioxide/Readme.md
Original file line number Diff line number Diff line change
Expand Up @@ -3,8 +3,8 @@
You can still use engineioxide as a standalone crate to talk with an engine.io client.

### Supported Protocols
You can enable support for other engine.io protocol implementations through feature flags.
The latest protocol version (v4) is enabled by default.
You can enable support for other engine.io protocol implementations through feature flags.
The latest protocol version (v4) is enabled by default.

To add support for the `v3` protocol version, adjust your dependency configuration accordingly:

Expand All @@ -14,7 +14,7 @@ To add support for the `v3` protocol version, adjust your dependency configurati
engineioxide = { version = "0.3.0", features = ["v3"] }
```

## Feature flags :
## Feature flags :
* `v3`: Enable the engine.io v3 protocol
* `tracing`: Enable tracing logs with the `tracing` crate

Expand Down Expand Up @@ -42,15 +42,18 @@ struct SocketState {
impl EngineIoHandler for MyHandler {
type Data = SocketState;

fn on_connect(self: Arc<Self>, socket: Arc<Socket<SocketState>>) {
fn on_connect(self: Arc<Self>, socket: Arc<Socket<SocketState>>) {
let cnt = self.user_cnt.fetch_add(1, Ordering::Relaxed) + 1;
socket.emit(cnt.to_string()).ok();
}
fn on_disconnect(&self, socket: Arc<Socket<SocketState>>, reason: DisconnectReason) {
fn on_disconnect(&self,
socket: Arc<Socket<SocketState>>,
reason: DisconnectReason
) {
let cnt = self.user_cnt.fetch_sub(1, Ordering::Relaxed) - 1;
socket.emit(cnt.to_string()).ok();
}
fn on_message(self: &Arc<Self>, msg: Str, socket: Arc<Socket<SocketState>>) {
fn on_message(&self, msg: Str, socket: Arc<Socket<SocketState>>) {
*socket.data.id.lock().unwrap() = msg.into(); // bind a provided user id to a socket
}
fn on_binary(&self, data: Bytes, socket: Arc<Socket<SocketState>>) { }
Expand Down
4 changes: 2 additions & 2 deletions engineioxide/src/config.rs
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@
//! type Data = ();
//! fn on_connect(self: Arc<Self>, socket: Arc<Socket<()>>) { }
//! fn on_disconnect(&self, socket: Arc<Socket<()>>, reason: DisconnectReason) { }
//! fn on_message(self: &Arc<Self>, msg: Str, socket: Arc<Socket<()>>) { }
//! fn on_message(&self, msg: Str, socket: Arc<Socket<()>>) { }
//! fn on_binary(&self, data: Bytes, socket: Arc<Socket<()>>) { }
//! }
//!
Expand Down Expand Up @@ -150,7 +150,7 @@ impl EngineIoConfigBuilder {
/// println!("socket disconnect {}", socket.id);
/// }
///
/// fn on_message(self: &Arc<Self>, msg: Str, socket: Arc<Socket<()>>) {
/// fn on_message(&self, msg: Str, socket: Arc<Socket<()>>) {
/// println!("Ping pong message {:?}", msg);
/// socket.emit(msg).unwrap();
/// }
Expand Down
2 changes: 1 addition & 1 deletion engineioxide/src/engine.rs
Original file line number Diff line number Diff line change
Expand Up @@ -115,7 +115,7 @@ mod tests {
println!("socket disconnect {} {:?}", socket.id, reason);
}

fn on_message(self: &Arc<Self>, msg: Str, socket: Arc<Socket<Self::Data>>) {
fn on_message(&self, msg: Str, socket: Arc<Socket<Self::Data>>) {
println!("Ping pong message {:?}", msg);
socket.emit(msg).ok();
}
Expand Down
4 changes: 2 additions & 2 deletions engineioxide/src/handler.rs
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@
//! let cnt = self.user_cnt.fetch_sub(1, Ordering::Relaxed) - 1;
//! socket.emit(cnt.to_string()).ok();
//! }
//! fn on_message(self: &Arc<Self>, msg: Str, socket: Arc<Socket<SocketState>>) {
//! fn on_message(&self, msg: Str, socket: Arc<Socket<SocketState>>) {
//! *socket.data.id.lock().unwrap() = msg.into(); // bind a provided user id to a socket
//! }
//! fn on_binary(&self, data: Bytes, socket: Arc<Socket<SocketState>>) { }
Expand Down Expand Up @@ -60,7 +60,7 @@ pub trait EngineIoHandler: std::fmt::Debug + Send + Sync + 'static {
fn on_disconnect(&self, socket: Arc<Socket<Self::Data>>, reason: DisconnectReason);

/// Called when a message is received from the client.
fn on_message(self: &Arc<Self>, msg: Str, socket: Arc<Socket<Self::Data>>);
fn on_message(&self, msg: Str, socket: Arc<Socket<Self::Data>>);

/// Called when a binary message is received from the client.
fn on_binary(&self, data: Bytes, socket: Arc<Socket<Self::Data>>);
Expand Down
2 changes: 1 addition & 1 deletion engineioxide/src/layer.rs
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@
//! type Data = ();
//! fn on_connect(self: Arc<Self>, socket: Arc<Socket<()>>) { }
//! fn on_disconnect(&self, socket: Arc<Socket<()>>, reason: DisconnectReason) { }
//! fn on_message(self: &Arc<Self>, msg: Str, socket: Arc<Socket<()>>) { }
//! fn on_message(&self, msg: Str, socket: Arc<Socket<()>>) { }
//! fn on_binary(&self, data: Bytes, socket: Arc<Socket<()>>) { }
//! }
//! // Create a new engineio layer
Expand Down
2 changes: 1 addition & 1 deletion engineioxide/src/service/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@
//! type Data = ();
//! fn on_connect(self: Arc<Self>, socket: Arc<Socket<()>>) { }
//! fn on_disconnect(&self, socket: Arc<Socket<()>>, reason: DisconnectReason) { }
//! fn on_message(self: &Arc<Self>, msg: Str, socket: Arc<Socket<()>>) { }
//! fn on_message(&self, msg: Str, socket: Arc<Socket<()>>) { }
//! fn on_binary(&self, data: Bytes, socket: Arc<Socket<()>>) { }
//! }
//!
Expand Down
2 changes: 1 addition & 1 deletion engineioxide/src/socket.rs
Original file line number Diff line number Diff line change
Expand Up @@ -46,7 +46,7 @@
//! fn on_disconnect(&self, socket: Arc<Socket<SocketState>>, reason: DisconnectReason) {
//! let cnt = self.user_cnt.fetch_sub(1, Ordering::Relaxed) - 1;
//! }
//! fn on_message(self: &Arc<Self>, msg: Str, socket: Arc<Socket<SocketState>>) {
//! fn on_message(&self, msg: Str, socket: Arc<Socket<SocketState>>) {
//! *socket.data.id.lock().unwrap() = msg.into(); // bind a provided user id to a socket
//! }
//! fn on_binary(&self, data: Bytes, socket: Arc<Socket<SocketState>>) { }
Expand Down
2 changes: 1 addition & 1 deletion engineioxide/tests/disconnect_reason.rs
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,7 @@ impl EngineIoHandler for MyHandler {
self.disconnect_tx.try_send(reason).unwrap();
}

fn on_message(self: &Arc<Self>, msg: Str, socket: Arc<Socket<()>>) {
fn on_message(&self, msg: Str, socket: Arc<Socket<()>>) {
println!("Ping pong message {:?}", msg);
socket.emit(msg).ok();
}
Expand Down
5 changes: 0 additions & 5 deletions socketioxide/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -78,8 +78,3 @@ harness = false
name = "extensions"
path = "benches/extensions.rs"
harness = false

[[bench]]
name = "ns_routing"
path = "benches/ns_routing.rs"
harness = false
100 changes: 0 additions & 100 deletions socketioxide/benches/ns_routing.rs

This file was deleted.

52 changes: 24 additions & 28 deletions socketioxide/src/client.rs
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@ use engineioxide::Str;
use futures_util::{FutureExt, TryFutureExt};

use engineioxide::sid::Sid;
use matchit::Router;
use matchit::{Match, Router};
use tokio::sync::oneshot;

use crate::adapter::Adapter;
Expand Down Expand Up @@ -54,35 +54,31 @@ impl<A: Adapter> Client<A> {

/// Called when a socket connects to a new namespace
fn sock_connect(
self: Arc<Self>,
&self,
auth: Option<String>,
ns_path: Str,
esocket: Arc<engineioxide::Socket<SocketData<A>>>,
esocket: &Arc<engineioxide::Socket<SocketData<A>>>,
) {
#[cfg(feature = "tracing")]
tracing::debug!("auth: {:?}", auth);
let protocol: ProtocolVersion = esocket.protocol.into();
let esocket_clone = esocket.clone();
let connect = move |ns: Arc<Namespace<A>>| async move {
if ns
.connect(esocket_clone.id, esocket_clone.clone(), auth)
.await
.is_ok()
{
// cancel the connect timeout task for v5
if let Some(tx) = esocket_clone.data.connect_recv_tx.lock().unwrap().take() {
tx.send(()).ok();
let connect =
move |ns: Arc<Namespace<A>>, esocket: Arc<engineioxide::Socket<SocketData<A>>>| async move {
if ns.connect(esocket.id, esocket.clone(), auth).await.is_ok() {
// cancel the connect timeout task for v5
if let Some(tx) = esocket.data.connect_recv_tx.lock().unwrap().take() {
tx.send(()).ok();
}
}
}
};
};

if let Some(ns) = self.get_ns(&ns_path) {
tokio::spawn(connect(ns));
} else if let Ok(res) = self.router.read().unwrap().at(&ns_path) {
tokio::spawn(connect(ns, esocket.clone()));
} else if let Ok(Match { value: ns_ctr, .. }) = self.router.read().unwrap().at(&ns_path) {
let path: Cow<'static, str> = Cow::Owned(ns_path.clone().into());
let ns = res.value.get_new_ns(ns_path); //TODO: check memory leak here
let ns = ns_ctr.get_new_ns(ns_path); //TODO: check memory leak here
self.ns.write().unwrap().insert(path, ns.clone());
tokio::spawn(connect(ns));
tokio::spawn(connect(ns, esocket.clone()));
} else if protocol == ProtocolVersion::V4 && ns_path == "/" {
#[cfg(feature = "tracing")]
tracing::error!(
Expand Down Expand Up @@ -129,7 +125,7 @@ impl<A: Adapter> Client<A> {
/// Adds a new namespace handler
pub fn add_ns<C, T>(&self, path: Cow<'static, str>, callback: C)
where
C: ConnectHandler<A, T> + Clone,
C: ConnectHandler<A, T>,
T: Send + Sync + 'static,
{
#[cfg(feature = "tracing")]
Expand All @@ -140,7 +136,7 @@ impl<A: Adapter> Client<A> {

pub fn add_dyn_ns<C, T>(&self, path: String, callback: C) -> Result<(), matchit::InsertError>
where
C: ConnectHandler<A, T> + Clone,
C: ConnectHandler<A, T>,
T: Send + Sync + 'static,
{
#[cfg(feature = "tracing")]
Expand Down Expand Up @@ -176,8 +172,8 @@ impl<A: Adapter> Client<A> {
tracing::debug!("closing all namespaces");
let ns = { std::mem::take(&mut *self.ns.write().unwrap()) };
futures_util::future::join_all(
ns.iter()
.map(|(_, ns)| ns.close(DisconnectReason::ClosingServer)),
ns.values()
.map(|ns| ns.close(DisconnectReason::ClosingServer)),
)
.await;
#[cfg(feature = "tracing")]
Expand Down Expand Up @@ -225,7 +221,7 @@ impl<A: Adapter> EngineIoHandler for Client<A> {
ProtocolVersion::V4 => {
#[cfg(feature = "tracing")]
tracing::debug!("connecting to default namespace for v4");
self.sock_connect(None, Str::from("/"), socket);
self.sock_connect(None, Str::from("/"), &socket);
}
ProtocolVersion::V5 => self.spawn_connect_timeout_task(socket),
}
Expand All @@ -239,8 +235,8 @@ impl<A: Adapter> EngineIoHandler for Client<A> {
.ns
.read()
.unwrap()
.iter()
.filter_map(|(_, ns)| ns.get_socket(socket.id).ok())
.values()
.filter_map(|ns| ns.get_socket(socket.id).ok())
.collect();

let _res: Result<Vec<_>, _> = socks
Expand All @@ -259,7 +255,7 @@ impl<A: Adapter> EngineIoHandler for Client<A> {
}
}

fn on_message(self: &Arc<Self>, msg: Str, socket: Arc<EIoSocket<SocketData<A>>>) {
fn on_message(&self, msg: Str, socket: Arc<EIoSocket<SocketData<A>>>) {
#[cfg(feature = "tracing")]
tracing::debug!("Received message: {:?}", msg);
let packet = match Packet::try_from(msg) {
Expand All @@ -276,7 +272,7 @@ impl<A: Adapter> EngineIoHandler for Client<A> {

let res: Result<(), Error> = match packet.inner {
PacketData::Connect(auth) => {
self.clone().sock_connect(auth, packet.ns, socket.clone());
self.sock_connect(auth, packet.ns, &socket);
Ok(())
}
PacketData::BinaryEvent(_, _, _) | PacketData::BinaryAck(_, _) => {
Expand Down
2 changes: 2 additions & 0 deletions socketioxide/src/errors.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,8 @@ use engineioxide::{sid::Sid, socket::DisconnectReason as EIoDisconnectReason};
use std::fmt::{Debug, Display};
use tokio::{sync::mpsc::error::TrySendError, time::error::Elapsed};

pub use matchit::InsertError as NsInsertError;

/// Error type for socketio
#[derive(thiserror::Error, Debug)]
pub enum Error {
Expand Down
3 changes: 1 addition & 2 deletions socketioxide/src/extract/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,6 @@
//! * [`HttpExtension`]: extracts an http extension of the given type coming from the request.
//! (Similar to axum's [`extract::Extension`](https://docs.rs/axum/latest/axum/struct.Extension.html)
//! * [`MaybeHttpExtension`]: extracts an http extension of the given type if it exists or [`None`] otherwise.
//! * [`NsParam`]: extracts and deserialize the namespace path parameters. Works only for the [`ConnectHandler`] and [`ConnectMiddleware`].
//!
//! ### You can also implement your own Extractor with the [`FromConnectParts`], [`FromMessageParts`] and [`FromDisconnectParts`] traits
//! When implementing these traits, if you clone the [`Arc<Socket>`](crate::socket::Socket) make sure that it is dropped at least when the socket is disconnected.
Expand Down Expand Up @@ -59,7 +58,7 @@
//!
//! impl<A: Adapter> FromConnectParts<A> for UserId {
//! type Error = Infallible;
//! fn from_connect_parts(s: &Arc<Socket<A>>, _: &Option<String>, _: &NsParamBuff<'_>) -> Result<Self, Self::Error> {
//! fn from_connect_parts(s: &Arc<Socket<A>>, _: &Option<String>) -> Result<Self, Self::Error> {
//! // In a real app it would be better to parse the query params with a crate like `url`
//! let uri = &s.req_parts().uri;
//! let uid = uri
Expand Down
Loading

0 comments on commit 9df042f

Please sign in to comment.