Skip to content

Commit

Permalink
feat: ns key in map as Str (#342)
Browse files Browse the repository at this point in the history
  • Loading branch information
Totodore committed Jun 26, 2024
1 parent 04f7ab3 commit 8aa3f51
Show file tree
Hide file tree
Showing 2 changed files with 34 additions and 14 deletions.
15 changes: 13 additions & 2 deletions engineioxide/src/str.rs
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
use std::borrow::Cow;
use std::borrow::{Borrow, Cow};

use bytes::Bytes;

Expand Down Expand Up @@ -28,7 +28,13 @@ impl Str {
Str(Bytes::copy_from_slice(data.as_bytes()))
}
}

/// This custom Hash implementation as a [`str`] is made to match with the [`Borrow`]
/// implementation as [`str`]. Otherwise [`str`] and [`Str`] won't have the same hash.
impl std::hash::Hash for Str {
fn hash<H: std::hash::Hasher>(&self, state: &mut H) {
str::hash(self.as_str(), state);
}
}
impl std::ops::Deref for Str {
type Target = str;
fn deref(&self) -> &Self::Target {
Expand All @@ -40,6 +46,11 @@ impl std::fmt::Display for Str {
write!(f, "{}", self.as_str())
}
}
impl Borrow<str> for Str {
fn borrow(&self) -> &str {
self.as_str()
}
}
impl From<&'static str> for Str {
fn from(s: &'static str) -> Self {
Str(Bytes::from_static(s.as_bytes()))
Expand Down
33 changes: 21 additions & 12 deletions socketioxide/src/client.rs
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@ use crate::{ProtocolVersion, SocketIo};

pub struct Client<A: Adapter> {
pub(crate) config: SocketIoConfig,
ns: RwLock<HashMap<Cow<'static, str>, Arc<Namespace<A>>>>,
nsps: RwLock<HashMap<Str, Arc<Namespace<A>>>>,
router: RwLock<Router<NamespaceCtr<A>>>,

#[cfg(feature = "state")]
Expand All @@ -45,7 +45,7 @@ impl<A: Adapter> Client<A> {

Self {
config,
ns: RwLock::new(HashMap::new()),
nsps: RwLock::new(HashMap::new()),
router: RwLock::new(Router::new()),
#[cfg(feature = "state")]
state,
Expand Down Expand Up @@ -75,9 +75,9 @@ impl<A: Adapter> Client<A> {
if let Some(ns) = self.get_ns(&ns_path) {
tokio::spawn(connect(ns, esocket.clone()));
} else if let Ok(Match { value: ns_ctr, .. }) = self.router.read().unwrap().at(&ns_path) {
let path: Cow<'static, str> = Cow::Owned(ns_path.clone().into());
let ns = ns_ctr.get_new_ns(ns_path); //TODO: check memory leak here
self.ns.write().unwrap().insert(path, ns.clone());
let path = Str::copy_from_slice(&ns_path);
let ns = ns_ctr.get_new_ns(path.clone());
self.nsps.write().unwrap().insert(path, ns.clone());
tokio::spawn(connect(ns, esocket.clone()));
} else if protocol == ProtocolVersion::V4 && ns_path == "/" {
#[cfg(feature = "tracing")]
Expand Down Expand Up @@ -130,8 +130,9 @@ impl<A: Adapter> Client<A> {
{
#[cfg(feature = "tracing")]
tracing::debug!("adding namespace {}", path);
let ns = Namespace::new(Str::from(&path), callback);
self.ns.write().unwrap().insert(path, ns);
let path = Str::from(path);
let ns = Namespace::new(path.clone(), callback);
self.nsps.write().unwrap().insert(path, ns);
}

pub fn add_dyn_ns<C, T>(&self, path: String, callback: C) -> Result<(), matchit::InsertError>
Expand All @@ -155,22 +156,22 @@ impl<A: Adapter> Client<A> {

#[cfg(feature = "tracing")]
tracing::debug!("deleting namespace {}", path);
if let Some(ns) = self.ns.write().unwrap().remove(path) {
if let Some(ns) = self.nsps.write().unwrap().remove(path) {
ns.close(DisconnectReason::ServerNSDisconnect)
.now_or_never();
}
}

pub fn get_ns(&self, path: &str) -> Option<Arc<Namespace<A>>> {
self.ns.read().unwrap().get(path).cloned()
self.nsps.read().unwrap().get(path).cloned()
}

/// Closes all engine.io connections and all clients
#[cfg_attr(feature = "tracing", tracing::instrument(skip(self)))]
pub(crate) async fn close(&self) {
#[cfg(feature = "tracing")]
tracing::debug!("closing all namespaces");
let ns = { std::mem::take(&mut *self.ns.write().unwrap()) };
let ns = { std::mem::take(&mut *self.nsps.write().unwrap()) };
futures_util::future::join_all(
ns.values()
.map(|ns| ns.close(DisconnectReason::ClosingServer)),
Expand Down Expand Up @@ -232,7 +233,7 @@ impl<A: Adapter> EngineIoHandler for Client<A> {
#[cfg(feature = "tracing")]
tracing::debug!("eio socket disconnected");
let socks: Vec<_> = self
.ns
.nsps
.read()
.unwrap()
.values()
Expand Down Expand Up @@ -324,7 +325,7 @@ impl<A: Adapter> EngineIoHandler for Client<A> {
impl<A: Adapter> std::fmt::Debug for Client<A> {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
let mut f = f.debug_struct("Client");
f.field("config", &self.config).field("ns", &self.ns);
f.field("config", &self.config).field("nsps", &self.nsps);
#[cfg(feature = "state")]
let f = f.field("state", &self.state);
f.finish()
Expand Down Expand Up @@ -425,6 +426,14 @@ mod test {
Arc::new(client)
}

#[tokio::test]
async fn get_ns() {
let client = create_client();
let ns = Namespace::new(Str::from("/"), || {});
client.nsps.write().unwrap().insert(Str::from("/"), ns);
assert!(matches!(client.get_ns("/"), Some(_)));
}

#[tokio::test]
async fn io_should_always_be_set() {
let client = create_client();
Expand Down

0 comments on commit 8aa3f51

Please sign in to comment.