diff --git a/zenoh/src/api/querier.rs b/zenoh/src/api/querier.rs index 6f45124669..12bde1935c 100644 --- a/zenoh/src/api/querier.rs +++ b/zenoh/src/api/querier.rs @@ -219,3 +219,10 @@ impl Drop for Querier<'_> { } } } + +#[derive(Debug, Clone, Copy, PartialEq, Eq)] +pub(crate) enum QueriersKind { + Querier, + #[allow(dead_code)] + LivelinessQuerier, +} diff --git a/zenoh/src/api/session.rs b/zenoh/src/api/session.rs index 7819eb8e9d..3ff788eb1b 100644 --- a/zenoh/src/api/session.rs +++ b/zenoh/src/api/session.rs @@ -67,6 +67,7 @@ use zenoh_result::ZResult; use zenoh_shm::api::client_storage::ShmClientStorage; use zenoh_task::TaskController; +use super::querier::QueriersKind; #[cfg(feature = "unstable")] use crate::api::selector::ZenohParameters; #[cfg(feature = "unstable")] @@ -133,7 +134,6 @@ pub(crate) struct SessionState { pub(crate) remote_subscribers: HashMap>, pub(crate) publishers: HashMap, pub(crate) queriers: HashMap, - #[cfg(feature = "unstable")] pub(crate) liveliness_queriers: HashMap, #[cfg(feature = "unstable")] pub(crate) remote_tokens: HashMap>, @@ -171,7 +171,6 @@ impl SessionState { remote_subscribers: HashMap::new(), publishers: HashMap::new(), queriers: HashMap::new(), - #[cfg(feature = "unstable")] liveliness_queriers: HashMap::new(), #[cfg(feature = "unstable")] remote_tokens: HashMap::new(), @@ -311,6 +310,69 @@ impl SessionState { SubscriberKind::LivelinessSubscriber => &mut self.liveliness_subscribers, } } + + pub(crate) fn queriers(&self, kind: QueriersKind) -> &HashMap { + match kind { + QueriersKind::Querier => &self.queriers, + QueriersKind::LivelinessQuerier => &self.liveliness_queriers, + } + } + + pub(crate) fn queriers_mut(&mut self, kind: QueriersKind) -> &mut HashMap { + match kind { + QueriersKind::Querier => &mut self.queriers, + QueriersKind::LivelinessQuerier => &mut self.liveliness_queriers, + } + } + + fn register_querier<'a>( + &mut self, + id: EntityId, + key_expr: &'a KeyExpr, + destination: Locality, + kind: QueriersKind, + ) -> Option> { + let mut querier_state = QuerierState { + id, + remote_id: id, + key_expr: key_expr.clone().into_owned(), + destination, + }; + let aggregated_queriers: &[OwnedKeyExpr] = match kind { + QueriersKind::Querier => self.aggregated_queriers.as_slice(), + QueriersKind::LivelinessQuerier => &[] as &[OwnedKeyExpr; 0], + }; + + let declared_querier = (destination != Locality::SessionLocal) + .then( + || match aggregated_queriers.iter().find(|s| s.includes(key_expr)) { + Some(join_querier) => { + if let Some(joined_querier) = self.queriers(kind).values().find(|q| { + q.destination != Locality::SessionLocal + && join_querier.includes(&q.key_expr) + }) { + querier_state.remote_id = joined_querier.remote_id; + None + } else { + Some(join_querier.clone().into()) + } + } + None => { + if let Some(twin_querier) = self.queriers(kind).values().find(|p| { + p.destination != Locality::SessionLocal && &p.key_expr == key_expr + }) { + querier_state.remote_id = twin_querier.remote_id; + None + } else { + Some(key_expr.clone()) + } + } + }, + ) + .flatten(); + self.queriers_mut(kind).insert(id, querier_state); + declared_querier + } } impl fmt::Debug for SessionState { @@ -1325,63 +1387,29 @@ impl SessionInner { } } - pub(crate) fn declare_querier_inner( + fn _declare_querier_inner( &self, - key_expr: KeyExpr, + key_expr: &KeyExpr, destination: Locality, + kind: QueriersKind, ) -> ZResult { let mut state = zwrite!(self.state); - tracing::trace!("declare_querier({:?})", key_expr); let id = self.runtime.next_id(); - - let mut querier_state = QuerierState { - id, - remote_id: id, - key_expr: key_expr.clone().into_owned(), - destination, - }; - - let declared_querier = (destination != Locality::SessionLocal) - .then(|| { - match state - .aggregated_queriers - .iter() - .find(|s| s.includes(&key_expr)) - { - Some(join_querier) => { - if let Some(joined_querier) = state.queriers.values().find(|q| { - q.destination != Locality::SessionLocal - && join_querier.includes(&q.key_expr) - }) { - querier_state.remote_id = joined_querier.remote_id; - None - } else { - Some(join_querier.clone().into()) - } - } - None => { - if let Some(twin_querier) = state.queriers.values().find(|p| { - p.destination != Locality::SessionLocal && p.key_expr == key_expr - }) { - querier_state.remote_id = twin_querier.remote_id; - None - } else { - Some(key_expr.clone()) - } - } - } - }) - .flatten(); - - state.queriers.insert(id, querier_state); - + let declared_querier = + state.register_querier(id, key_expr, destination, QueriersKind::Querier); if let Some(res) = declared_querier { let primitives = state.primitives()?; drop(state); + let interest_options = match kind { + QueriersKind::Querier => InterestOptions::KEYEXPRS + InterestOptions::QUERYABLES, + QueriersKind::LivelinessQuerier => { + InterestOptions::KEYEXPRS + InterestOptions::TOKENS + } + }; primitives.send_interest(Interest { id, mode: InterestMode::CurrentFuture, - options: InterestOptions::KEYEXPRS + InterestOptions::QUERYABLES, + options: interest_options, wire_expr: Some(res.to_wire(self).to_owned()), ext_qos: network::ext::QoSType::DEFAULT, ext_tstamp: None, @@ -1391,17 +1419,27 @@ impl SessionInner { Ok(id) } - pub(crate) fn undeclare_querier_inner(&self, pid: Id) -> ZResult<()> { + pub(crate) fn declare_querier_inner( + &self, + key_expr: KeyExpr, + destination: Locality, + ) -> ZResult { + tracing::trace!("declare_querier({:?})", key_expr); + self._declare_querier_inner(&key_expr, destination, QueriersKind::Querier) + } + + fn _undeclare_querier_inner(&self, pid: Id, kind: QueriersKind) -> ZResult<()> { let mut state = zwrite!(self.state); let Ok(primitives) = state.primitives() else { return Ok(()); }; - if let Some(querier_state) = state.queriers.remove(&pid) { + let queriers = state.queriers_mut(kind); + if let Some(querier_state) = queriers.remove(&pid) { trace!("undeclare_querier({:?})", querier_state); if querier_state.destination != Locality::SessionLocal { // Note: there might be several queriers on the same KeyExpr. // Before calling forget_queriers(key_expr), check if this was the last one. - if !state.queriers.values().any(|p| { + if !queriers.values().any(|p| { p.destination != Locality::SessionLocal && p.remote_id == querier_state.remote_id }) { @@ -1423,6 +1461,10 @@ impl SessionInner { } } + pub(crate) fn undeclare_querier_inner(&self, pid: Id) -> ZResult<()> { + self._undeclare_querier_inner(pid, QueriersKind::Querier) + } + pub(crate) fn declare_subscriber_inner( self: &Arc, key_expr: &KeyExpr, @@ -1832,73 +1874,17 @@ impl SessionInner { #[cfg(feature = "unstable")] pub(crate) fn declare_liveliness_querier_inner(&self, key_expr: &KeyExpr) -> ZResult { - let mut state = zwrite!(self.state); trace!("declare_liveliness_querier({:?})", key_expr); - let id = self.runtime.next_id(); - - let mut querier_state = QuerierState { - id, - remote_id: id, - key_expr: key_expr.clone().into_owned(), - destination: Locality::default(), - }; - - let primitives = state.primitives()?; - let declared_querier = - if let Some(twin_querier) = state.queriers.values().find(|p| &p.key_expr == key_expr) { - querier_state.remote_id = twin_querier.remote_id; - None - } else { - Some(key_expr.clone()) - }; - state.liveliness_queriers.insert(id, querier_state); - drop(state); - - if let Some(res) = declared_querier { - primitives.send_interest(Interest { - id, - mode: InterestMode::CurrentFuture, - options: InterestOptions::KEYEXPRS + InterestOptions::TOKENS, - wire_expr: Some(res.to_wire(self).to_owned()), - ext_qos: declare::ext::QoSType::DECLARE, - ext_tstamp: None, - ext_nodeid: declare::ext::NodeIdType::DEFAULT, - }); - } - - Ok(id) + self._declare_querier_inner( + key_expr, + Locality::default(), + QueriersKind::LivelinessQuerier, + ) } #[cfg(feature = "unstable")] pub(crate) fn undeclare_liveliness_querier_inner(&self, pid: Id) -> ZResult<()> { - let mut state = zwrite!(self.state); - let Ok(primitives) = state.primitives() else { - return Ok(()); - }; - if let Some(querier_state) = state.liveliness_queriers.remove(&pid) { - trace!("undeclare_liveliness_querier({:?})", querier_state); - // Note: there might be several queriers on the same KeyExpr. - // Before calling forget_queriers(key_expr), check if this was the last one. - if !state - .liveliness_queriers - .values() - .any(|p| p.remote_id == querier_state.remote_id) - { - drop(state); - primitives.send_interest(Interest { - id: querier_state.remote_id, - mode: InterestMode::Final, - options: InterestOptions::empty(), - wire_expr: None, - ext_qos: declare::ext::QoSType::DEFAULT, - ext_tstamp: None, - ext_nodeid: declare::ext::NodeIdType::DEFAULT, - }); - } - Ok(()) - } else { - Err(zerror!("Unable to find liveliness querier").into()) - } + self._undeclare_querier_inner(pid, QueriersKind::LivelinessQuerier) } #[zenoh_macros::unstable]