diff --git a/Cargo.lock b/Cargo.lock index 332764c36f3..245be67f331 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -2725,6 +2725,7 @@ dependencies = [ "libp2p-identify", "libp2p-identity", "libp2p-noise", + "libp2p-protocol-utils", "libp2p-swarm", "libp2p-swarm-test", "libp2p-yamux", @@ -2956,6 +2957,10 @@ dependencies = [ "tracing", ] +[[package]] +name = "libp2p-protocol-utils" +version = "0.1.0" + [[package]] name = "libp2p-quic" version = "0.10.1" @@ -3000,6 +3005,7 @@ dependencies = [ "libp2p-identity", "libp2p-ping", "libp2p-plaintext", + "libp2p-protocol-utils", "libp2p-swarm", "libp2p-swarm-test", "libp2p-yamux", @@ -3060,6 +3066,7 @@ dependencies = [ "libp2p-core", "libp2p-identity", "libp2p-noise", + "libp2p-protocol-utils", "libp2p-swarm", "libp2p-swarm-test", "libp2p-tcp", diff --git a/Cargo.toml b/Cargo.toml index 35439a1a696..5bb6a6bb8c2 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -25,6 +25,7 @@ members = [ "misc/memory-connection-limits", "misc/metrics", "misc/multistream-select", + "misc/protocol-utils", "misc/quick-protobuf-codec", "misc/quickcheck-ext", "misc/rw-stream-sink", @@ -60,8 +61,8 @@ members = [ "transports/webrtc", "transports/webrtc-websys", "transports/websocket", - "transports/webtransport-websys", "transports/websocket-websys", + "transports/webtransport-websys", "wasm-tests/webtransport-tests", ] resolver = "2" @@ -94,6 +95,7 @@ libp2p-perf = { version = "0.3.0", path = "protocols/perf" } libp2p-ping = { version = "0.44.0", path = "protocols/ping" } libp2p-plaintext = { version = "0.41.0", path = "transports/plaintext" } libp2p-pnet = { version = "0.24.0", path = "transports/pnet" } +libp2p-protocol-utils = { version = "0.1.0", path = "misc/protocol-utils" } libp2p-quic = { version = "0.10.1", path = "transports/quic" } libp2p-relay = { version = "0.17.1", path = "protocols/relay" } libp2p-rendezvous = { version = "0.14.0", path = "protocols/rendezvous" } diff --git a/misc/protocol-utils/CHANGELOG.md b/misc/protocol-utils/CHANGELOG.md new file mode 100644 index 00000000000..ac65462846e --- /dev/null +++ b/misc/protocol-utils/CHANGELOG.md @@ -0,0 +1,4 @@ +## 0.1.0 - unreleased + +- Initial release, offering `InflightProtocolDataQueue`. + See [PR 4834](https://github.com/libp2p/rust-libp2p/pull/4834). diff --git a/misc/protocol-utils/Cargo.toml b/misc/protocol-utils/Cargo.toml new file mode 100644 index 00000000000..2c92dfa5557 --- /dev/null +++ b/misc/protocol-utils/Cargo.toml @@ -0,0 +1,14 @@ +[package] +name = "libp2p-protocol-utils" +version = "0.1.0" +edition = "2021" +description = "Utilities for implementing protocols for libp2p, via `NetworkBehaviour` and `ConnectionHandler`." +rust-version.workspace = true +license = "MIT" +repository = "https://github.com/libp2p/rust-libp2p" +keywords = [] +categories = ["data-structures"] +publish = false # Temporary until we actually publish it. + +[lints] +workspace = true diff --git a/misc/protocol-utils/src/ipd_queue.rs b/misc/protocol-utils/src/ipd_queue.rs new file mode 100644 index 00000000000..97f6330efa4 --- /dev/null +++ b/misc/protocol-utils/src/ipd_queue.rs @@ -0,0 +1,61 @@ +use std::collections::VecDeque; + +/// Manages associated data of request-response protocols whilst they are in-flight. +/// +/// The [`InflightProtocolDataQueue`] ensures that for each in-flight protocol, there is a corresponding piece of associated data. +/// We process the associated data in a FIFO order based on the incoming responses. +/// In other words, we assume that requests and their responses are either temporally ordered or it doesn't matter, which piece of data is paired with a particular response. +pub struct InflightProtocolDataQueue { + data_of_inflight_requests: VecDeque, + pending_requests: VecDeque, + received_responses: VecDeque, +} + +impl Default for InflightProtocolDataQueue { + fn default() -> Self { + Self { + pending_requests: Default::default(), + received_responses: Default::default(), + data_of_inflight_requests: Default::default(), + } + } +} + +impl InflightProtocolDataQueue { + /// Enqueues a new request along-side with the associated data. + /// + /// The request will be returned again from [`InflightProtocolDataQueue::next_request`]. + pub fn enqueue_request(&mut self, request: Req, data: D) { + self.pending_requests.push_back(request); + self.data_of_inflight_requests.push_back(data); + } + + /// Submits a response to the queue. + /// + /// A pair of response and data will be returned from [`InflightProtocolDataQueue::next_completed`]. + pub fn submit_response(&mut self, res: Res) { + debug_assert!( + self.data_of_inflight_requests.len() > self.received_responses.len(), + "Expect to not provide more responses than requests were started" + ); + self.received_responses.push_back(res); + } + + /// How many protocols are currently in-flight. + pub fn num_inflight(&self) -> usize { + self.data_of_inflight_requests.len() - self.received_responses.len() + } + + pub fn next_completed(&mut self) -> Option<(Res, D)> { + let res = self.received_responses.pop_front()?; + let data = self.data_of_inflight_requests.pop_front()?; + + Some((res, data)) + } + + pub fn next_request(&mut self) -> Option { + let req = self.pending_requests.pop_front()?; + + Some(req) + } +} diff --git a/misc/protocol-utils/src/lib.rs b/misc/protocol-utils/src/lib.rs new file mode 100644 index 00000000000..9d9021ae875 --- /dev/null +++ b/misc/protocol-utils/src/lib.rs @@ -0,0 +1,3 @@ +mod ipd_queue; + +pub use ipd_queue::InflightProtocolDataQueue; diff --git a/protocols/kad/Cargo.toml b/protocols/kad/Cargo.toml index 04101d51026..e889aacbcb4 100644 --- a/protocols/kad/Cargo.toml +++ b/protocols/kad/Cargo.toml @@ -19,6 +19,7 @@ asynchronous-codec = { workspace = true } futures = "0.3.29" libp2p-core = { workspace = true } libp2p-swarm = { workspace = true } +libp2p-protocol-utils = { workspace = true } quick-protobuf = "0.8" quick-protobuf-codec = { workspace = true } libp2p-identity = { workspace = true, features = ["rand"] } diff --git a/protocols/kad/src/handler.rs b/protocols/kad/src/handler.rs index adfb076541c..2ba5a543c5c 100644 --- a/protocols/kad/src/handler.rs +++ b/protocols/kad/src/handler.rs @@ -29,16 +29,15 @@ use futures::prelude::*; use futures::stream::SelectAll; use libp2p_core::{upgrade, ConnectedPoint}; use libp2p_identity::PeerId; -use libp2p_swarm::handler::{ - ConnectionEvent, DialUpgradeError, FullyNegotiatedInbound, FullyNegotiatedOutbound, -}; +use libp2p_protocol_utils::InflightProtocolDataQueue; +use libp2p_swarm::handler::{ConnectionEvent, FullyNegotiatedInbound}; use libp2p_swarm::{ ConnectionHandler, ConnectionHandlerEvent, Stream, StreamUpgradeError, SubstreamProtocol, SupportedProtocols, }; -use std::collections::VecDeque; use std::task::Waker; use std::{error, fmt, io, marker::PhantomData, pin::Pin, task::Context, task::Poll}; +use void::Void; const MAX_NUM_SUBSTREAMS: usize = 32; @@ -62,12 +61,13 @@ pub struct Handler { /// List of active outbound substreams with the state they are in. outbound_substreams: SelectAll, - /// Number of outbound streams being upgraded right now. - num_requested_outbound_streams: usize, - /// List of outbound substreams that are waiting to become active next. /// Contains the request we want to send, and the user data if we expect an answer. - pending_messages: VecDeque<(KadRequestMsg, Option)>, + pending_streams: InflightProtocolDataQueue< + (KadRequestMsg, Option), + ProtocolConfig, + Result, StreamUpgradeError>, + >, /// List of active inbound substreams with the state they are in. inbound_substreams: SelectAll, @@ -293,7 +293,7 @@ pub enum HandlerEvent { #[derive(Debug)] pub enum HandlerQueryErr { /// Error while trying to perform the query. - Upgrade(StreamUpgradeError), + Upgrade(StreamUpgradeError), /// Received an answer that doesn't correspond to the request. UnexpectedMessage, /// I/O error in the substream. @@ -329,8 +329,8 @@ impl error::Error for HandlerQueryErr { } } -impl From> for HandlerQueryErr { - fn from(err: StreamUpgradeError) -> Self { +impl From> for HandlerQueryErr { + fn from(err: StreamUpgradeError) -> Self { HandlerQueryErr::Upgrade(err) } } @@ -481,40 +481,12 @@ impl Handler { next_connec_unique_id: UniqueConnecId(0), inbound_substreams: Default::default(), outbound_substreams: Default::default(), - num_requested_outbound_streams: 0, - pending_messages: Default::default(), + pending_streams: Default::default(), protocol_status: None, remote_supported_protocols: Default::default(), } } - fn on_fully_negotiated_outbound( - &mut self, - FullyNegotiatedOutbound { protocol, info: () }: FullyNegotiatedOutbound< - ::OutboundProtocol, - ::OutboundOpenInfo, - >, - ) { - if let Some((msg, query_id)) = self.pending_messages.pop_front() { - self.outbound_substreams - .push(OutboundSubstreamState::PendingSend(protocol, msg, query_id)); - } else { - debug_assert!(false, "Requested outbound stream without message") - } - - self.num_requested_outbound_streams -= 1; - - if self.protocol_status.is_none() { - // Upon the first successfully negotiated substream, we know that the - // remote is configured with the same protocol name and we want - // the behaviour to add this peer to the routing table, if possible. - self.protocol_status = Some(ProtocolStatus { - supported: true, - reported: false, - }); - } - } - fn on_fully_negotiated_inbound( &mut self, FullyNegotiatedInbound { protocol, .. }: FullyNegotiatedInbound< @@ -572,26 +544,6 @@ impl Handler { substream: protocol, }); } - - fn on_dial_upgrade_error( - &mut self, - DialUpgradeError { - info: (), error, .. - }: DialUpgradeError< - ::OutboundOpenInfo, - ::OutboundProtocol, - >, - ) { - // TODO: cache the fact that the remote doesn't support kademlia at all, so that we don't - // continue trying - - if let Some((_, Some(query_id))) = self.pending_messages.pop_front() { - self.outbound_substreams - .push(OutboundSubstreamState::ReportError(error.into(), query_id)); - } - - self.num_requested_outbound_streams -= 1; - } } impl ConnectionHandler for Handler { @@ -626,16 +578,20 @@ impl ConnectionHandler for Handler { } } HandlerIn::FindNodeReq { key, query_id } => { - let msg = KadRequestMsg::FindNode { key }; - self.pending_messages.push_back((msg, Some(query_id))); + self.pending_streams.enqueue_request( + self.protocol_config.clone(), + (KadRequestMsg::FindNode { key }, Some(query_id)), + ); } HandlerIn::FindNodeRes { closer_peers, request_id, } => self.answer_pending_request(request_id, KadResponseMsg::FindNode { closer_peers }), HandlerIn::GetProvidersReq { key, query_id } => { - let msg = KadRequestMsg::GetProviders { key }; - self.pending_messages.push_back((msg, Some(query_id))); + self.pending_streams.enqueue_request( + self.protocol_config.clone(), + (KadRequestMsg::GetProviders { key }, Some(query_id)), + ); } HandlerIn::GetProvidersRes { closer_peers, @@ -649,16 +605,22 @@ impl ConnectionHandler for Handler { }, ), HandlerIn::AddProvider { key, provider } => { - let msg = KadRequestMsg::AddProvider { key, provider }; - self.pending_messages.push_back((msg, None)); + self.pending_streams.enqueue_request( + self.protocol_config.clone(), + (KadRequestMsg::AddProvider { key, provider }, None), + ); } HandlerIn::GetRecord { key, query_id } => { - let msg = KadRequestMsg::GetValue { key }; - self.pending_messages.push_back((msg, Some(query_id))); + self.pending_streams.enqueue_request( + self.protocol_config.clone(), + (KadRequestMsg::GetValue { key }, Some(query_id)), + ); } HandlerIn::PutRecord { record, query_id } => { - let msg = KadRequestMsg::PutValue { record }; - self.pending_messages.push_back((msg, Some(query_id))); + self.pending_streams.enqueue_request( + self.protocol_config.clone(), + (KadRequestMsg::PutValue { record }, Some(query_id)), + ); } HandlerIn::GetRecordRes { record, @@ -712,44 +674,67 @@ impl ConnectionHandler for Handler { ) -> Poll< ConnectionHandlerEvent, > { - match &mut self.protocol_status { - Some(status) if !status.reported => { - status.reported = true; - let event = if status.supported { - HandlerEvent::ProtocolConfirmed { - endpoint: self.endpoint.clone(), - } - } else { - HandlerEvent::ProtocolNotSupported { - endpoint: self.endpoint.clone(), - } - }; + loop { + match &mut self.protocol_status { + Some(status) if !status.reported => { + status.reported = true; + let event = if status.supported { + HandlerEvent::ProtocolConfirmed { + endpoint: self.endpoint.clone(), + } + } else { + HandlerEvent::ProtocolNotSupported { + endpoint: self.endpoint.clone(), + } + }; - return Poll::Ready(ConnectionHandlerEvent::NotifyBehaviour(event)); + return Poll::Ready(ConnectionHandlerEvent::NotifyBehaviour(event)); + } + _ => {} } - _ => {} - } - if let Poll::Ready(Some(event)) = self.outbound_substreams.poll_next_unpin(cx) { - return Poll::Ready(event); - } + if let Poll::Ready(Some(event)) = self.outbound_substreams.poll_next_unpin(cx) { + return Poll::Ready(event); + } - if let Poll::Ready(Some(event)) = self.inbound_substreams.poll_next_unpin(cx) { - return Poll::Ready(event); - } + if let Poll::Ready(Some(event)) = self.inbound_substreams.poll_next_unpin(cx) { + return Poll::Ready(event); + } - let num_in_progress_outbound_substreams = - self.outbound_substreams.len() + self.num_requested_outbound_streams; - if num_in_progress_outbound_substreams < MAX_NUM_SUBSTREAMS - && self.num_requested_outbound_streams < self.pending_messages.len() - { - self.num_requested_outbound_streams += 1; - return Poll::Ready(ConnectionHandlerEvent::OutboundSubstreamRequest { - protocol: SubstreamProtocol::new(self.protocol_config.clone(), ()), - }); - } + match self.pending_streams.next_completed() { + Some((Ok(stream), (message, query_id))) => { + self.outbound_substreams + .push(OutboundSubstreamState::PendingSend( + stream, message, query_id, + )); + continue; + } + // TODO: Check if the remote doesn't support kademlia and stop trying if it doesn't + Some((Err(error), (_, Some(query_id)))) => { + self.outbound_substreams + .push(OutboundSubstreamState::ReportError(error.into(), query_id)); + continue; + } + Some((Err(error), (message, None))) => { + tracing::debug!(?message, "Failed to establish stream: {error}"); + continue; + } + None => {} + } + + let num_in_progress_outbound_substreams = + self.outbound_substreams.len() + self.pending_streams.num_inflight(); - Poll::Pending + if num_in_progress_outbound_substreams < MAX_NUM_SUBSTREAMS { + if let Some(next) = self.pending_streams.next_request() { + return Poll::Ready(ConnectionHandlerEvent::OutboundSubstreamRequest { + protocol: SubstreamProtocol::new(next, ()), + }); + } + } + + return Poll::Pending; + } } fn on_connection_event( @@ -762,14 +747,24 @@ impl ConnectionHandler for Handler { >, ) { match event { - ConnectionEvent::FullyNegotiatedOutbound(fully_negotiated_outbound) => { - self.on_fully_negotiated_outbound(fully_negotiated_outbound) + ConnectionEvent::FullyNegotiatedOutbound(ev) => { + self.pending_streams.submit_response(Ok(ev.protocol)); + + if self.protocol_status.is_none() { + // Upon the first successfully negotiated substream, we know that the + // remote is configured with the same protocol name and we want + // the behaviour to add this peer to the routing table, if possible. + self.protocol_status = Some(ProtocolStatus { + supported: true, + reported: false, + }); + } } ConnectionEvent::FullyNegotiatedInbound(fully_negotiated_inbound) => { self.on_fully_negotiated_inbound(fully_negotiated_inbound) } - ConnectionEvent::DialUpgradeError(dial_upgrade_error) => { - self.on_dial_upgrade_error(dial_upgrade_error) + ConnectionEvent::DialUpgradeError(ev) => { + self.pending_streams.submit_response(Err(ev.error)); } ConnectionEvent::RemoteProtocolsChange(change) => { let dirty = self.remote_supported_protocols.on_protocols_change(change); diff --git a/protocols/kad/src/protocol.rs b/protocols/kad/src/protocol.rs index 7fe2d1130b1..0c990ba19f0 100644 --- a/protocols/kad/src/protocol.rs +++ b/protocols/kad/src/protocol.rs @@ -40,6 +40,7 @@ use std::marker::PhantomData; use std::{convert::TryFrom, time::Duration}; use std::{io, iter}; use tracing::debug; +use void::Void; /// The protocol name used for negotiating with multistream-select. pub(crate) const DEFAULT_PROTO_NAME: StreamProtocol = StreamProtocol::new("/ipfs/kad/1.0.0"); @@ -220,8 +221,8 @@ where C: AsyncRead + AsyncWrite + Unpin, { type Output = KadInStreamSink; - type Future = future::Ready>; - type Error = io::Error; + type Future = future::Ready>; + type Error = Void; fn upgrade_inbound(self, incoming: C, _: Self::Info) -> Self::Future { let codec = Codec::new(self.max_packet_size); @@ -235,8 +236,8 @@ where C: AsyncRead + AsyncWrite + Unpin, { type Output = KadOutStreamSink; - type Future = future::Ready>; - type Error = io::Error; + type Future = future::Ready>; + type Error = Void; fn upgrade_outbound(self, incoming: C, _: Self::Info) -> Self::Future { let codec = Codec::new(self.max_packet_size); diff --git a/protocols/relay/Cargo.toml b/protocols/relay/Cargo.toml index 2ef9ed41019..ad4a290d51f 100644 --- a/protocols/relay/Cargo.toml +++ b/protocols/relay/Cargo.toml @@ -28,6 +28,7 @@ static_assertions = "1" thiserror = "1.0" tracing = "0.1.37" void = "1" +libp2p-protocol-utils = { workspace = true } [dev-dependencies] libp2p-identity = { workspace = true, features = ["rand"] } diff --git a/protocols/relay/src/priv_client/handler.rs b/protocols/relay/src/priv_client/handler.rs index 1925d6f6ab4..662d63cc742 100644 --- a/protocols/relay/src/priv_client/handler.rs +++ b/protocols/relay/src/priv_client/handler.rs @@ -18,9 +18,12 @@ // FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER // DEALINGS IN THE SOFTWARE. +use crate::client::Connection; use crate::priv_client::transport; +use crate::priv_client::transport::ToListenerMsg; use crate::protocol::{self, inbound_stop, outbound_hop}; use crate::{priv_client, proto, HOP_PROTOCOL_NAME, STOP_PROTOCOL_NAME}; +use futures::channel::mpsc::Sender; use futures::channel::{mpsc, oneshot}; use futures::future::FutureExt; use futures_timer::Delay; @@ -28,17 +31,16 @@ use libp2p_core::multiaddr::Protocol; use libp2p_core::upgrade::ReadyUpgrade; use libp2p_core::Multiaddr; use libp2p_identity::PeerId; -use libp2p_swarm::handler::{ - ConnectionEvent, DialUpgradeError, FullyNegotiatedInbound, FullyNegotiatedOutbound, -}; +use libp2p_swarm::handler::{ConnectionEvent, FullyNegotiatedInbound}; use libp2p_swarm::{ - ConnectionHandler, ConnectionHandlerEvent, StreamProtocol, StreamUpgradeError, + ConnectionHandler, ConnectionHandlerEvent, Stream, StreamProtocol, StreamUpgradeError, SubstreamProtocol, }; use std::collections::VecDeque; use std::task::{Context, Poll}; use std::time::Duration; use std::{fmt, io}; +use void::Void; /// The maximum number of circuits being denied concurrently. /// @@ -104,8 +106,7 @@ pub struct Handler { >, >, - /// We issue a stream upgrade for each pending request. - pending_requests: VecDeque, + pending_streams: VecDeque>>>, inflight_reserve_requests: futures_bounded::FuturesTupleSet< Result, @@ -133,7 +134,7 @@ impl Handler { remote_peer_id, remote_addr, queued_events: Default::default(), - pending_requests: Default::default(), + pending_streams: Default::default(), inflight_reserve_requests: futures_bounded::FuturesTupleSet::new( STREAM_TIMEOUT, MAX_CONCURRENT_STREAMS_PER_CONNECTION, @@ -154,57 +155,6 @@ impl Handler { } } - fn on_dial_upgrade_error( - &mut self, - DialUpgradeError { error, .. }: DialUpgradeError< - ::OutboundOpenInfo, - ::OutboundProtocol, - >, - ) { - let pending_request = self - .pending_requests - .pop_front() - .expect("got a stream error without a pending request"); - - match pending_request { - PendingRequest::Reserve { mut to_listener } => { - let error = match error { - StreamUpgradeError::Timeout => { - outbound_hop::ReserveError::Io(io::ErrorKind::TimedOut.into()) - } - StreamUpgradeError::Apply(never) => void::unreachable(never), - StreamUpgradeError::NegotiationFailed => { - outbound_hop::ReserveError::Unsupported - } - StreamUpgradeError::Io(e) => outbound_hop::ReserveError::Io(e), - }; - - if let Err(e) = - to_listener.try_send(transport::ToListenerMsg::Reservation(Err(error))) - { - tracing::debug!("Unable to send error to listener: {}", e.into_send_error()) - } - self.reservation.failed(); - } - PendingRequest::Connect { - to_dial: send_back, .. - } => { - let error = match error { - StreamUpgradeError::Timeout => { - outbound_hop::ConnectError::Io(io::ErrorKind::TimedOut.into()) - } - StreamUpgradeError::NegotiationFailed => { - outbound_hop::ConnectError::Unsupported - } - StreamUpgradeError::Io(e) => outbound_hop::ConnectError::Io(e), - StreamUpgradeError::Apply(v) => void::unreachable(v), - }; - - let _ = send_back.send(Err(error)); - } - } - } - fn insert_to_deny_futs(&mut self, circuit: inbound_stop::Circuit) { let src_peer_id = circuit.src_peer_id(); @@ -219,6 +169,62 @@ impl Handler { ) } } + + fn make_new_reservation(&mut self, to_listener: Sender) { + let (sender, receiver) = oneshot::channel(); + + self.pending_streams.push_back(sender); + self.queued_events + .push_back(ConnectionHandlerEvent::OutboundSubstreamRequest { + protocol: SubstreamProtocol::new(ReadyUpgrade::new(HOP_PROTOCOL_NAME), ()), + }); + let result = self.inflight_reserve_requests.try_push( + async move { + let stream = receiver + .await + .map_err(|_| io::Error::from(io::ErrorKind::BrokenPipe))? + .map_err(into_reserve_error)?; + + let reservation = outbound_hop::make_reservation(stream).await?; + + Ok(reservation) + }, + to_listener, + ); + + if result.is_err() { + tracing::warn!("Dropping in-flight reservation request because we are at capacity"); + } + } + + fn establish_new_circuit( + &mut self, + to_dial: oneshot::Sender>, + dst_peer_id: PeerId, + ) { + let (sender, receiver) = oneshot::channel(); + + self.pending_streams.push_back(sender); + self.queued_events + .push_back(ConnectionHandlerEvent::OutboundSubstreamRequest { + protocol: SubstreamProtocol::new(ReadyUpgrade::new(HOP_PROTOCOL_NAME), ()), + }); + let result = self.inflight_outbound_connect_requests.try_push( + async move { + let stream = receiver + .await + .map_err(|_| io::Error::from(io::ErrorKind::BrokenPipe))? + .map_err(into_connect_error)?; + + outbound_hop::open_circuit(stream, dst_peer_id).await + }, + to_dial, + ); + + if result.is_err() { + tracing::warn!("Dropping in-flight connect request because we are at capacity") + } + } } impl ConnectionHandler for Handler { @@ -236,25 +242,13 @@ impl ConnectionHandler for Handler { fn on_behaviour_event(&mut self, event: Self::FromBehaviour) { match event { In::Reserve { to_listener } => { - self.pending_requests - .push_back(PendingRequest::Reserve { to_listener }); - self.queued_events - .push_back(ConnectionHandlerEvent::OutboundSubstreamRequest { - protocol: SubstreamProtocol::new(ReadyUpgrade::new(HOP_PROTOCOL_NAME), ()), - }); + self.make_new_reservation(to_listener); } In::EstablishCircuit { - to_dial: send_back, + to_dial, dst_peer_id, } => { - self.pending_requests.push_back(PendingRequest::Connect { - dst_peer_id, - to_dial: send_back, - }); - self.queued_events - .push_back(ConnectionHandlerEvent::OutboundSubstreamRequest { - protocol: SubstreamProtocol::new(ReadyUpgrade::new(HOP_PROTOCOL_NAME), ()), - }); + self.establish_new_circuit(to_dial, dst_peer_id); } } } @@ -402,12 +396,8 @@ impl ConnectionHandler for Handler { } if let Poll::Ready(Some(to_listener)) = self.reservation.poll(cx) { - self.pending_requests - .push_back(PendingRequest::Reserve { to_listener }); - - return Poll::Ready(ConnectionHandlerEvent::OutboundSubstreamRequest { - protocol: SubstreamProtocol::new(ReadyUpgrade::new(HOP_PROTOCOL_NAME), ()), - }); + self.make_new_reservation(to_listener); + continue; } // Deny incoming circuit requests. @@ -450,42 +440,16 @@ impl ConnectionHandler for Handler { tracing::warn!("Dropping inbound stream because we are at capacity") } } - ConnectionEvent::FullyNegotiatedOutbound(FullyNegotiatedOutbound { - protocol: stream, - .. - }) => { - let pending_request = self.pending_requests.pop_front().expect( - "opened a stream without a pending connection command or a reserve listener", - ); - match pending_request { - PendingRequest::Reserve { to_listener } => { - if self - .inflight_reserve_requests - .try_push(outbound_hop::make_reservation(stream), to_listener) - .is_err() - { - tracing::warn!("Dropping outbound stream because we are at capacity"); - } - } - PendingRequest::Connect { - dst_peer_id, - to_dial: send_back, - } => { - if self - .inflight_outbound_connect_requests - .try_push(outbound_hop::open_circuit(stream, dst_peer_id), send_back) - .is_err() - { - tracing::warn!("Dropping outbound stream because we are at capacity"); - } - } + ConnectionEvent::FullyNegotiatedOutbound(ev) => { + if let Some(next) = self.pending_streams.pop_front() { + let _ = next.send(Ok(ev.protocol)); } } - ConnectionEvent::ListenUpgradeError(listen_upgrade_error) => { - void::unreachable(listen_upgrade_error.error) - } - ConnectionEvent::DialUpgradeError(dial_upgrade_error) => { - self.on_dial_upgrade_error(dial_upgrade_error) + ConnectionEvent::ListenUpgradeError(ev) => void::unreachable(ev.error), + ConnectionEvent::DialUpgradeError(ev) => { + if let Some(next) = self.pending_streams.pop_front() { + let _ = next.send(Err(ev.error)); + } } _ => {} } @@ -614,14 +578,24 @@ impl Reservation { } } -pub(crate) enum PendingRequest { - Reserve { - /// A channel into the [`Transport`](priv_client::Transport). - to_listener: mpsc::Sender, - }, - Connect { - dst_peer_id: PeerId, - /// A channel into the future returned by [`Transport::dial`](libp2p_core::Transport::dial). - to_dial: oneshot::Sender>, - }, +fn into_reserve_error(e: StreamUpgradeError) -> outbound_hop::ReserveError { + match e { + StreamUpgradeError::Timeout => { + outbound_hop::ReserveError::Io(io::ErrorKind::TimedOut.into()) + } + StreamUpgradeError::Apply(never) => void::unreachable(never), + StreamUpgradeError::NegotiationFailed => outbound_hop::ReserveError::Unsupported, + StreamUpgradeError::Io(e) => outbound_hop::ReserveError::Io(e), + } +} + +fn into_connect_error(e: StreamUpgradeError) -> outbound_hop::ConnectError { + match e { + StreamUpgradeError::Timeout => { + outbound_hop::ConnectError::Io(io::ErrorKind::TimedOut.into()) + } + StreamUpgradeError::Apply(never) => void::unreachable(never), + StreamUpgradeError::NegotiationFailed => outbound_hop::ConnectError::Unsupported, + StreamUpgradeError::Io(e) => outbound_hop::ConnectError::Io(e), + } } diff --git a/protocols/relay/src/protocol/outbound_hop.rs b/protocols/relay/src/protocol/outbound_hop.rs index e5f9a6a0a52..3ae824be167 100644 --- a/protocols/relay/src/protocol/outbound_hop.rs +++ b/protocols/relay/src/protocol/outbound_hop.rs @@ -47,7 +47,7 @@ pub enum ConnectError { #[error("Remote does not support the `{HOP_PROTOCOL_NAME}` protocol")] Unsupported, #[error("IO error")] - Io(#[source] io::Error), + Io(#[from] io::Error), #[error("Protocol error")] Protocol(#[from] ProtocolViolation), } @@ -61,7 +61,7 @@ pub enum ReserveError { #[error("Remote does not support the `{HOP_PROTOCOL_NAME}` protocol")] Unsupported, #[error("IO error")] - Io(#[source] io::Error), + Io(#[from] io::Error), #[error("Protocol error")] Protocol(#[from] ProtocolViolation), } diff --git a/protocols/request-response/Cargo.toml b/protocols/request-response/Cargo.toml index 1bfd03e1520..cec43cc3ee6 100644 --- a/protocols/request-response/Cargo.toml +++ b/protocols/request-response/Cargo.toml @@ -18,6 +18,7 @@ instant = "0.1.12" libp2p-core = { workspace = true } libp2p-swarm = { workspace = true } libp2p-identity = { workspace = true } +libp2p-protocol-utils = { workspace = true } rand = "0.8" serde = { version = "1.0", optional = true} serde_json = { version = "1.0.108", optional = true } diff --git a/protocols/request-response/src/handler.rs b/protocols/request-response/src/handler.rs index 2d45e0d7dc3..a759022b1cb 100644 --- a/protocols/request-response/src/handler.rs +++ b/protocols/request-response/src/handler.rs @@ -28,10 +28,8 @@ use crate::{InboundRequestId, OutboundRequestId, EMPTY_QUEUE_SHRINK_THRESHOLD}; use futures::channel::mpsc; use futures::{channel::oneshot, prelude::*}; -use libp2p_swarm::handler::{ - ConnectionEvent, DialUpgradeError, FullyNegotiatedInbound, FullyNegotiatedOutbound, - ListenUpgradeError, -}; +use libp2p_protocol_utils::InflightProtocolDataQueue; +use libp2p_swarm::handler::{ConnectionEvent, FullyNegotiatedInbound}; use libp2p_swarm::{ handler::{ConnectionHandler, ConnectionHandlerEvent, StreamUpgradeError}, SubstreamProtocol, @@ -47,6 +45,7 @@ use std::{ task::{Context, Poll}, time::Duration, }; +use void::Void; /// A connection handler for a request response [`Behaviour`](super::Behaviour) protocol. pub struct Handler @@ -59,10 +58,13 @@ where codec: TCodec, /// Queue of events to emit in `poll()`. pending_events: VecDeque>, - /// Outbound upgrades waiting to be emitted as an `OutboundSubstreamRequest`. - pending_outbound: VecDeque>, - requested_outbound: VecDeque>, + pending_streams: InflightProtocolDataQueue< + (OutboundRequestId, TCodec::Request), + SmallVec<[TCodec::Protocol; 2]>, + Result<(libp2p_swarm::Stream, TCodec::Protocol), StreamUpgradeError>, + >, + /// A channel for receiving inbound requests. inbound_receiver: mpsc::Receiver<( InboundRequestId, @@ -102,8 +104,7 @@ where Self { inbound_protocols, codec, - pending_outbound: VecDeque::new(), - requested_outbound: Default::default(), + pending_streams: InflightProtocolDataQueue::default(), inbound_receiver, inbound_sender, pending_events: VecDeque::new(), @@ -167,92 +168,6 @@ where tracing::warn!("Dropping inbound stream because we are at capacity") } } - - fn on_fully_negotiated_outbound( - &mut self, - FullyNegotiatedOutbound { - protocol: (mut stream, protocol), - info: (), - }: FullyNegotiatedOutbound< - ::OutboundProtocol, - ::OutboundOpenInfo, - >, - ) { - let message = self - .requested_outbound - .pop_front() - .expect("negotiated a stream without a pending message"); - - let mut codec = self.codec.clone(); - let request_id = message.request_id; - - let send = async move { - let write = codec.write_request(&protocol, &mut stream, message.request); - write.await?; - stream.close().await?; - let read = codec.read_response(&protocol, &mut stream); - let response = read.await?; - - Ok(Event::Response { - request_id, - response, - }) - }; - - if self - .worker_streams - .try_push(RequestId::Outbound(request_id), send.boxed()) - .is_err() - { - tracing::warn!("Dropping outbound stream because we are at capacity") - } - } - - fn on_dial_upgrade_error( - &mut self, - DialUpgradeError { error, info: () }: DialUpgradeError< - ::OutboundOpenInfo, - ::OutboundProtocol, - >, - ) { - let message = self - .requested_outbound - .pop_front() - .expect("negotiated a stream without a pending message"); - - match error { - StreamUpgradeError::Timeout => { - self.pending_events - .push_back(Event::OutboundTimeout(message.request_id)); - } - StreamUpgradeError::NegotiationFailed => { - // The remote merely doesn't support the protocol(s) we requested. - // This is no reason to close the connection, which may - // successfully communicate with other protocols already. - // An event is reported to permit user code to react to the fact that - // the remote peer does not support the requested protocol(s). - self.pending_events - .push_back(Event::OutboundUnsupportedProtocols(message.request_id)); - } - StreamUpgradeError::Apply(e) => void::unreachable(e), - StreamUpgradeError::Io(e) => { - tracing::debug!( - "outbound stream for request {} failed: {e}, retrying", - message.request_id - ); - self.requested_outbound.push_back(message); - } - } - } - fn on_listen_upgrade_error( - &mut self, - ListenUpgradeError { error, .. }: ListenUpgradeError< - ::InboundOpenInfo, - ::InboundProtocol, - >, - ) { - void::unreachable(error) - } } /// The events emitted by the [`Handler`]. @@ -382,7 +297,14 @@ where } fn on_behaviour_event(&mut self, request: Self::FromBehaviour) { - self.pending_outbound.push_back(request); + let OutboundMessage { + request_id, + request, + protocols, + } = request; + + self.pending_streams + .enqueue_request(protocols, (request_id, request)); } #[tracing::instrument(level = "trace", name = "ConnectionHandler::poll", skip(self, cx))] @@ -390,74 +312,114 @@ where &mut self, cx: &mut Context<'_>, ) -> Poll, (), Self::ToBehaviour>> { - match self.worker_streams.poll_unpin(cx) { - Poll::Ready((_, Ok(Ok(event)))) => { - return Poll::Ready(ConnectionHandlerEvent::NotifyBehaviour(event)); - } - Poll::Ready((RequestId::Inbound(id), Ok(Err(e)))) => { - return Poll::Ready(ConnectionHandlerEvent::NotifyBehaviour( - Event::InboundStreamFailed { - request_id: id, - error: e, - }, - )); - } - Poll::Ready((RequestId::Outbound(id), Ok(Err(e)))) => { - return Poll::Ready(ConnectionHandlerEvent::NotifyBehaviour( - Event::OutboundStreamFailed { - request_id: id, - error: e, - }, - )); - } - Poll::Ready((RequestId::Inbound(id), Err(futures_bounded::Timeout { .. }))) => { - return Poll::Ready(ConnectionHandlerEvent::NotifyBehaviour( - Event::InboundTimeout(id), - )); - } - Poll::Ready((RequestId::Outbound(id), Err(futures_bounded::Timeout { .. }))) => { - return Poll::Ready(ConnectionHandlerEvent::NotifyBehaviour( - Event::OutboundTimeout(id), - )); + loop { + match self.worker_streams.poll_unpin(cx) { + Poll::Ready((_, Ok(Ok(event)))) => { + return Poll::Ready(ConnectionHandlerEvent::NotifyBehaviour(event)); + } + Poll::Ready((RequestId::Inbound(id), Ok(Err(e)))) => { + return Poll::Ready(ConnectionHandlerEvent::NotifyBehaviour( + Event::InboundStreamFailed { + request_id: id, + error: e, + }, + )); + } + Poll::Ready((RequestId::Outbound(id), Ok(Err(e)))) => { + return Poll::Ready(ConnectionHandlerEvent::NotifyBehaviour( + Event::OutboundStreamFailed { + request_id: id, + error: e, + }, + )); + } + Poll::Ready((RequestId::Inbound(id), Err(futures_bounded::Timeout { .. }))) => { + return Poll::Ready(ConnectionHandlerEvent::NotifyBehaviour( + Event::InboundTimeout(id), + )); + } + Poll::Ready((RequestId::Outbound(id), Err(futures_bounded::Timeout { .. }))) => { + return Poll::Ready(ConnectionHandlerEvent::NotifyBehaviour( + Event::OutboundTimeout(id), + )); + } + Poll::Pending => {} } - Poll::Pending => {} - } - - // Drain pending events that were produced by `worker_streams`. - if let Some(event) = self.pending_events.pop_front() { - return Poll::Ready(ConnectionHandlerEvent::NotifyBehaviour(event)); - } else if self.pending_events.capacity() > EMPTY_QUEUE_SHRINK_THRESHOLD { - self.pending_events.shrink_to_fit(); - } - - // Check for inbound requests. - if let Poll::Ready(Some((id, rq, rs_sender))) = self.inbound_receiver.poll_next_unpin(cx) { - // We received an inbound request. - return Poll::Ready(ConnectionHandlerEvent::NotifyBehaviour(Event::Request { - request_id: id, - request: rq, - sender: rs_sender, - })); - } + // Drain pending events that were produced by `worker_streams`. + if let Some(event) = self.pending_events.pop_front() { + return Poll::Ready(ConnectionHandlerEvent::NotifyBehaviour(event)); + } else if self.pending_events.capacity() > EMPTY_QUEUE_SHRINK_THRESHOLD { + self.pending_events.shrink_to_fit(); + } - // Emit outbound requests. - if let Some(request) = self.pending_outbound.pop_front() { - let protocols = request.protocols.clone(); - self.requested_outbound.push_back(request); + // Check for inbound requests. + if let Poll::Ready(Some((id, rq, rs_sender))) = + self.inbound_receiver.poll_next_unpin(cx) + { + // We received an inbound request. + + return Poll::Ready(ConnectionHandlerEvent::NotifyBehaviour(Event::Request { + request_id: id, + request: rq, + sender: rs_sender, + })); + } - return Poll::Ready(ConnectionHandlerEvent::OutboundSubstreamRequest { - protocol: SubstreamProtocol::new(Protocol { protocols }, ()), - }); - } + match self.pending_streams.next_completed() { + Some((Ok((mut stream, protocol)), (request_id, request))) => { + let mut codec = self.codec.clone(); + + let send = async move { + let write = codec.write_request(&protocol, &mut stream, request); + write.await?; + stream.close().await?; + let read = codec.read_response(&protocol, &mut stream); + let response = read.await?; + + Ok(Event::Response { + request_id, + response, + }) + }; + + if self + .worker_streams + .try_push(RequestId::Outbound(request_id), send.boxed()) + .is_err() + { + tracing::warn!("Dropping outbound stream because we are at capacity") + } + continue; + } + Some((Err(StreamUpgradeError::Timeout), (request_id, _))) => { + return Poll::Ready(ConnectionHandlerEvent::NotifyBehaviour( + Event::OutboundTimeout(request_id), + )); + } + Some((Err(StreamUpgradeError::NegotiationFailed), (request_id, _))) => { + return Poll::Ready(ConnectionHandlerEvent::NotifyBehaviour( + Event::OutboundUnsupportedProtocols(request_id), + )); + } + Some((Err(StreamUpgradeError::Io(error)), (request_id, _))) => { + return Poll::Ready(ConnectionHandlerEvent::NotifyBehaviour( + Event::OutboundStreamFailed { request_id, error }, + )); + } + Some((Err(StreamUpgradeError::Apply(void)), _)) => void::unreachable(void), + None => {} + } - debug_assert!(self.pending_outbound.is_empty()); + // Emit outbound requests. + if let Some(protocols) = self.pending_streams.next_request() { + return Poll::Ready(ConnectionHandlerEvent::OutboundSubstreamRequest { + protocol: SubstreamProtocol::new(Protocol { protocols }, ()), + }); + } - if self.pending_outbound.capacity() > EMPTY_QUEUE_SHRINK_THRESHOLD { - self.pending_outbound.shrink_to_fit(); + return Poll::Pending; } - - Poll::Pending } fn on_connection_event( @@ -473,15 +435,13 @@ where ConnectionEvent::FullyNegotiatedInbound(fully_negotiated_inbound) => { self.on_fully_negotiated_inbound(fully_negotiated_inbound) } - ConnectionEvent::FullyNegotiatedOutbound(fully_negotiated_outbound) => { - self.on_fully_negotiated_outbound(fully_negotiated_outbound) - } - ConnectionEvent::DialUpgradeError(dial_upgrade_error) => { - self.on_dial_upgrade_error(dial_upgrade_error) + ConnectionEvent::FullyNegotiatedOutbound(ev) => { + self.pending_streams.submit_response(Ok(ev.protocol)); } - ConnectionEvent::ListenUpgradeError(listen_upgrade_error) => { - self.on_listen_upgrade_error(listen_upgrade_error) + ConnectionEvent::DialUpgradeError(ev) => { + self.pending_streams.submit_response(Err(ev.error)); } + ConnectionEvent::ListenUpgradeError(ev) => void::unreachable(ev.error), _ => {} } }