diff --git a/tentacle/src/service.rs b/tentacle/src/service.rs index 43587a68..20f5b3d6 100644 --- a/tentacle/src/service.rs +++ b/tentacle/src/service.rs @@ -215,9 +215,12 @@ where let _ignore = inner .handle_sender - .send(ServiceEventAndError::Event(ServiceEvent::ListenStarted { - address: listen_address.clone(), - })) + .send( + ServiceEvent::ListenStarted { + address: listen_address.clone(), + } + .into(), + ) .await; #[cfg(feature = "upnp")] if let Some(client) = inner.igd_client.as_mut() { @@ -270,8 +273,16 @@ where while let Some(s) = self.recv.next().await { match s { - ServiceEventAndError::Event(e) => { - self.handle.handle_event(&mut self.service_context, e).await + ServiceEventAndError::Event { + event, + wait_response, + } => { + self.handle + .handle_event(&mut self.service_context, event) + .await; + if let Some(tx) = wait_response { + let _ignore = tx.send(()); + } } ServiceEventAndError::Error(e) => { self.handle.handle_error(&mut self.service_context, e).await @@ -705,6 +716,18 @@ where self.future_task_sender.clone(), ); + // session open event must be notified first, and then the protocol is enabled + let (tx, rx) = futures::channel::oneshot::channel(); + let _ignore = self + .handle_sender + .send( + Into::::into(ServiceEvent::SessionOpen { session_context }) + .wait_response(tx), + ) + .await; + // Don't care about it's drop or response + let _ignore = rx.await; + if ty.is_outbound() { match target { TargetProtocol::All => { @@ -726,13 +749,6 @@ where } crate::runtime::spawn(session.for_each(|_| future::ready(()))); - - let _ignore = self - .handle_sender - .send(ServiceEventAndError::Event(ServiceEvent::SessionOpen { - session_context, - })) - .await; } /// Close the specified session, clean up the handle @@ -757,9 +773,12 @@ where // Service handle processing flow let _ignore = self .handle_sender - .send(ServiceEventAndError::Event(ServiceEvent::SessionClose { - session_context: session_control.inner, - })) + .send( + ServiceEvent::SessionClose { + session_context: session_control.inner, + } + .into(), + ) .await; } } @@ -857,12 +876,13 @@ where { let _ignore = self .handle_sender - .send(ServiceEventAndError::Error( + .send( ServiceError::ProtocolHandleError { proto_id: *proto_id, error: ProtocolHandleErrorKind::AbnormallyClosed(None), - }, - )) + } + .into(), + ) .await; error = true; } @@ -879,12 +899,13 @@ where error = true; let _ignore = self .handle_sender - .send(ServiceEventAndError::Error( + .send( ServiceError::ProtocolHandleError { proto_id: *proto_id, error: ProtocolHandleErrorKind::AbnormallyClosed(Some(*session_id)), - }, - )) + } + .into(), + ) .await; } } @@ -921,10 +942,13 @@ where self.dial_protocols.remove(&address); let _ignore = self .handle_sender - .send(ServiceEventAndError::Error(ServiceError::DialerError { - address, - error: DialerErrorKind::HandshakeError(error), - })) + .send( + ServiceError::DialerError { + address, + error: DialerErrorKind::HandshakeError(error), + } + .into(), + ) .await; } } @@ -935,12 +959,13 @@ where if let Some(session_control) = self.sessions.get(&id) { let _ignore = self .handle_sender - .send(ServiceEventAndError::Error( + .send( ServiceError::ProtocolSelectError { proto_name, session_context: Arc::clone(&session_control.inner), - }, - )) + } + .into(), + ) .await; } } @@ -951,11 +976,14 @@ where } => { let _ignore = self .handle_sender - .send(ServiceEventAndError::Error(ServiceError::ProtocolError { - id, - proto_id, - error, - })) + .send( + ServiceError::ProtocolError { + id, + proto_id, + error, + } + .into(), + ) .await; } SessionEvent::DialError { address, error } => { @@ -963,20 +991,26 @@ where self.dial_protocols.remove(&address); let _ignore = self .handle_sender - .send(ServiceEventAndError::Error(ServiceError::DialerError { - address, - error: DialerErrorKind::TransportError(error), - })) + .send( + ServiceError::DialerError { + address, + error: DialerErrorKind::TransportError(error), + } + .into(), + ) .await; } #[cfg(not(target_family = "wasm"))] SessionEvent::ListenError { address, error } => { let _ignore = self .handle_sender - .send(ServiceEventAndError::Error(ServiceError::ListenError { - address: address.clone(), - error: ListenErrorKind::TransportError(error), - })) + .send( + ServiceError::ListenError { + address: address.clone(), + error: ListenErrorKind::TransportError(error), + } + .into(), + ) .await; if self.listens.remove(&address) { #[cfg(feature = "upnp")] @@ -986,9 +1020,7 @@ where let _ignore = self .handle_sender - .send(ServiceEventAndError::Event(ServiceEvent::ListenClose { - address, - })) + .send(ServiceEvent::ListenClose { address }.into()) .await; } else { // try start listen error @@ -999,9 +1031,12 @@ where if let Some(session_control) = self.sessions.get(&id) { let _ignore = self .handle_sender - .send(ServiceEventAndError::Error(ServiceError::SessionTimeout { - session_context: Arc::clone(&session_control.inner), - })) + .send( + ServiceError::SessionTimeout { + session_context: Arc::clone(&session_control.inner), + } + .into(), + ) .await; } } @@ -1009,10 +1044,13 @@ where if let Some(session_control) = self.sessions.get(&id) { let _ignore = self .handle_sender - .send(ServiceEventAndError::Error(ServiceError::MuxerError { - session_context: Arc::clone(&session_control.inner), - error, - })) + .send( + ServiceError::MuxerError { + session_context: Arc::clone(&session_control.inner), + error, + } + .into(), + ) .await; } } @@ -1023,9 +1061,12 @@ where } => { let _ignore = self .handle_sender - .send(ServiceEventAndError::Event(ServiceEvent::ListenStarted { - address: listen_address.clone(), - })) + .send( + ServiceEvent::ListenStarted { + address: listen_address.clone(), + } + .into(), + ) .await; self.listens.insert(listen_address.clone()); self.state.decrease(); @@ -1039,9 +1080,7 @@ where SessionEvent::ProtocolHandleError { error, proto_id } => { let _ignore = self .handle_sender - .send(ServiceEventAndError::Error( - ServiceError::ProtocolHandleError { error, proto_id }, - )) + .send(ServiceError::ProtocolHandleError { error, proto_id }.into()) .await; // if handle panic, close service self.handle_service_task(ServiceTask::Shutdown(false), Priority::High) @@ -1051,9 +1090,12 @@ where if let Some(session) = self.sessions.get(&id) { let _ignore = self .handle_sender - .send(ServiceEventAndError::Error(ServiceError::SessionBlocked { - session_context: session.inner.clone(), - })) + .send( + ServiceError::SessionBlocked { + session_context: session.inner.clone(), + } + .into(), + ) .await; } } @@ -1077,10 +1119,13 @@ where if let Err(e) = self.dial_inner(address.clone(), target) { let _ignore = self .handle_sender - .send(ServiceEventAndError::Error(ServiceError::DialerError { - address, - error: DialerErrorKind::TransportError(e), - })) + .send( + ServiceError::DialerError { + address, + error: DialerErrorKind::TransportError(e), + } + .into(), + ) .await; } } @@ -1090,10 +1135,13 @@ where if let Err(e) = self.listen_inner(address.clone()) { let _ignore = self .handle_sender - .send(ServiceEventAndError::Error(ServiceError::ListenError { - address, - error: ListenErrorKind::TransportError(e), - })) + .send( + ServiceError::ListenError { + address, + error: ListenErrorKind::TransportError(e), + } + .into(), + ) .await; } } @@ -1176,9 +1224,7 @@ where let mut events = futures::stream::iter( self.listens .drain() - .map(|address| { - ServiceEventAndError::Event(ServiceEvent::ListenClose { address }) - }) + .map(|address| ServiceEvent::ListenClose { address }.into()) .collect::>(), ) .map(Ok); diff --git a/tentacle/src/service/event.rs b/tentacle/src/service/event.rs index ab40381d..8ff6b6f8 100644 --- a/tentacle/src/service/event.rs +++ b/tentacle/src/service/event.rs @@ -13,13 +13,44 @@ use bytes::Bytes; #[derive(Debug)] pub(crate) enum ServiceEventAndError { - Event(ServiceEvent), + Event { + event: ServiceEvent, + wait_response: Option>, + }, Error(ServiceError), Update { listen_addrs: Vec, }, } +impl ServiceEventAndError { + pub fn wait_response(self, tx: futures::channel::oneshot::Sender<()>) -> Self { + if let ServiceEventAndError::Event { event, .. } = self { + ServiceEventAndError::Event { + event, + wait_response: Some(tx), + } + } else { + self + } + } +} + +impl From for ServiceEventAndError { + fn from(event: ServiceEvent) -> Self { + ServiceEventAndError::Event { + event, + wait_response: None, + } + } +} + +impl From for ServiceEventAndError { + fn from(event: ServiceError) -> Self { + ServiceEventAndError::Error(event) + } +} + /// Error generated by the Service #[derive(Debug)] pub enum ServiceError {