diff --git a/engineioxide/src/str.rs b/engineioxide/src/str.rs index 6c636a30..0105ce91 100644 --- a/engineioxide/src/str.rs +++ b/engineioxide/src/str.rs @@ -1,4 +1,4 @@ -use std::borrow::Cow; +use std::borrow::{Borrow, Cow}; use bytes::Bytes; @@ -28,7 +28,13 @@ impl Str { Str(Bytes::copy_from_slice(data.as_bytes())) } } - +/// This custom Hash implementation as a [`str`] is made to match with the [`Borrow`] +/// implementation as [`str`]. Otherwise [`str`] and [`Str`] won't have the same hash. +impl std::hash::Hash for Str { + fn hash(&self, state: &mut H) { + str::hash(self.as_str(), state); + } +} impl std::ops::Deref for Str { type Target = str; fn deref(&self) -> &Self::Target { @@ -40,6 +46,11 @@ impl std::fmt::Display for Str { write!(f, "{}", self.as_str()) } } +impl Borrow for Str { + fn borrow(&self) -> &str { + self.as_str() + } +} impl From<&'static str> for Str { fn from(s: &'static str) -> Self { Str(Bytes::from_static(s.as_bytes())) diff --git a/socketioxide/src/client.rs b/socketioxide/src/client.rs index 0cd7c475..8c8668c6 100644 --- a/socketioxide/src/client.rs +++ b/socketioxide/src/client.rs @@ -26,7 +26,7 @@ use crate::{ProtocolVersion, SocketIo}; pub struct Client { pub(crate) config: SocketIoConfig, - ns: RwLock, Arc>>>, + nsps: RwLock>>>, router: RwLock>>, #[cfg(feature = "state")] @@ -45,7 +45,7 @@ impl Client { Self { config, - ns: RwLock::new(HashMap::new()), + nsps: RwLock::new(HashMap::new()), router: RwLock::new(Router::new()), #[cfg(feature = "state")] state, @@ -75,9 +75,9 @@ impl Client { if let Some(ns) = self.get_ns(&ns_path) { tokio::spawn(connect(ns, esocket.clone())); } else if let Ok(Match { value: ns_ctr, .. }) = self.router.read().unwrap().at(&ns_path) { - let path: Cow<'static, str> = Cow::Owned(ns_path.clone().into()); - let ns = ns_ctr.get_new_ns(ns_path); //TODO: check memory leak here - self.ns.write().unwrap().insert(path, ns.clone()); + let path = Str::copy_from_slice(&ns_path); + let ns = ns_ctr.get_new_ns(path.clone()); + self.nsps.write().unwrap().insert(path, ns.clone()); tokio::spawn(connect(ns, esocket.clone())); } else if protocol == ProtocolVersion::V4 && ns_path == "/" { #[cfg(feature = "tracing")] @@ -130,8 +130,9 @@ impl Client { { #[cfg(feature = "tracing")] tracing::debug!("adding namespace {}", path); - let ns = Namespace::new(Str::from(&path), callback); - self.ns.write().unwrap().insert(path, ns); + let path = Str::from(path); + let ns = Namespace::new(path.clone(), callback); + self.nsps.write().unwrap().insert(path, ns); } pub fn add_dyn_ns(&self, path: String, callback: C) -> Result<(), matchit::InsertError> @@ -155,14 +156,14 @@ impl Client { #[cfg(feature = "tracing")] tracing::debug!("deleting namespace {}", path); - if let Some(ns) = self.ns.write().unwrap().remove(path) { + if let Some(ns) = self.nsps.write().unwrap().remove(path) { ns.close(DisconnectReason::ServerNSDisconnect) .now_or_never(); } } pub fn get_ns(&self, path: &str) -> Option>> { - self.ns.read().unwrap().get(path).cloned() + self.nsps.read().unwrap().get(path).cloned() } /// Closes all engine.io connections and all clients @@ -170,7 +171,7 @@ impl Client { pub(crate) async fn close(&self) { #[cfg(feature = "tracing")] tracing::debug!("closing all namespaces"); - let ns = { std::mem::take(&mut *self.ns.write().unwrap()) }; + let ns = { std::mem::take(&mut *self.nsps.write().unwrap()) }; futures_util::future::join_all( ns.values() .map(|ns| ns.close(DisconnectReason::ClosingServer)), @@ -232,7 +233,7 @@ impl EngineIoHandler for Client { #[cfg(feature = "tracing")] tracing::debug!("eio socket disconnected"); let socks: Vec<_> = self - .ns + .nsps .read() .unwrap() .values() @@ -324,7 +325,7 @@ impl EngineIoHandler for Client { impl std::fmt::Debug for Client { fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { let mut f = f.debug_struct("Client"); - f.field("config", &self.config).field("ns", &self.ns); + f.field("config", &self.config).field("nsps", &self.nsps); #[cfg(feature = "state")] let f = f.field("state", &self.state); f.finish() @@ -425,6 +426,14 @@ mod test { Arc::new(client) } + #[tokio::test] + async fn get_ns() { + let client = create_client(); + let ns = Namespace::new(Str::from("/"), || {}); + client.nsps.write().unwrap().insert(Str::from("/"), ns); + assert!(matches!(client.get_ns("/"), Some(_))); + } + #[tokio::test] async fn io_should_always_be_set() { let client = create_client();