diff --git a/Cargo.toml b/Cargo.toml index 0031fa3fd60..cc99fc2af96 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -36,6 +36,7 @@ default = [ "websocket", "yamux", ] + autonat = ["dep:libp2p-autonat"] dcutr = ["dep:libp2p-dcutr", "libp2p-metrics?/dcutr"] deflate = ["dep:libp2p-deflate"] diff --git a/core/src/connection.rs b/core/src/connection.rs index 3a8d54d04a1..91008408fe2 100644 --- a/core/src/connection.rs +++ b/core/src/connection.rs @@ -43,25 +43,6 @@ impl std::ops::Add for ConnectionId { } } -/// The ID of a single listener. -#[derive(Copy, Clone, Debug, PartialEq, Eq, Hash, PartialOrd, Ord)] -pub struct ListenerId(u64); - -impl ListenerId { - /// Creates a `ListenerId` from a non-negative integer. - pub fn new(id: u64) -> Self { - Self(id) - } -} - -impl std::ops::Add for ListenerId { - type Output = Self; - - fn add(self, other: u64) -> Self { - Self(self.0 + other) - } -} - /// The endpoint roles associated with a peer-to-peer communication channel. #[derive(Debug, Copy, Clone, PartialEq, Eq, Hash)] pub enum Endpoint { diff --git a/core/src/either.rs b/core/src/either.rs index a9e51a47e79..bde66c98900 100644 --- a/core/src/either.rs +++ b/core/src/either.rs @@ -20,7 +20,7 @@ use crate::{ muxing::{StreamMuxer, StreamMuxerEvent}, - transport::{ListenerEvent, Transport, TransportError}, + transport::{ListenerId, Transport, TransportError, TransportEvent}, Multiaddr, ProtocolName, }; use futures::{ @@ -353,48 +353,6 @@ pub enum EitherOutbound { B(B::OutboundSubstream), } -/// Implements `Stream` and dispatches all method calls to either `First` or `Second`. -#[pin_project(project = EitherListenStreamProj)] -#[derive(Debug, Copy, Clone)] -#[must_use = "futures do nothing unless polled"] -pub enum EitherListenStream { - First(#[pin] A), - Second(#[pin] B), -} - -impl Stream - for EitherListenStream -where - AStream: TryStream, Error = AError>, - BStream: TryStream, Error = BError>, -{ - type Item = Result< - ListenerEvent, EitherError>, - EitherError, - >; - - fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { - match self.project() { - EitherListenStreamProj::First(a) => match TryStream::try_poll_next(a, cx) { - Poll::Pending => Poll::Pending, - Poll::Ready(None) => Poll::Ready(None), - Poll::Ready(Some(Ok(le))) => Poll::Ready(Some(Ok(le - .map(EitherFuture::First) - .map_err(EitherError::A)))), - Poll::Ready(Some(Err(err))) => Poll::Ready(Some(Err(EitherError::A(err)))), - }, - EitherListenStreamProj::Second(a) => match TryStream::try_poll_next(a, cx) { - Poll::Pending => Poll::Pending, - Poll::Ready(None) => Poll::Ready(None), - Poll::Ready(Some(Ok(le))) => Poll::Ready(Some(Ok(le - .map(EitherFuture::Second) - .map_err(EitherError::B)))), - Poll::Ready(Some(Err(err))) => Poll::Ready(Some(Err(EitherError::B(err)))), - }, - } - } -} - /// Implements `Future` and dispatches all method calls to either `First` or `Second`. #[pin_project(project = EitherFutureProj)] #[derive(Debug, Copy, Clone)] @@ -464,11 +422,12 @@ impl ProtocolName for EitherName { } } } - +#[pin_project(project = EitherTransportProj)] #[derive(Debug, Copy, Clone)] +#[must_use = "futures do nothing unless polled"] pub enum EitherTransport { - Left(A), - Right(B), + Left(#[pin] A), + Right(#[pin] B), } impl Transport for EitherTransport @@ -478,29 +437,54 @@ where { type Output = EitherOutput; type Error = EitherError; - type Listener = EitherListenStream; type ListenerUpgrade = EitherFuture; type Dial = EitherFuture; - fn listen_on( - &mut self, - addr: Multiaddr, - ) -> Result> { - use TransportError::*; - match self { - EitherTransport::Left(a) => match a.listen_on(addr) { - Ok(listener) => Ok(EitherListenStream::First(listener)), - Err(MultiaddrNotSupported(addr)) => Err(MultiaddrNotSupported(addr)), - Err(Other(err)) => Err(Other(EitherError::A(err))), + fn poll( + self: Pin<&mut Self>, + cx: &mut Context<'_>, + ) -> Poll> { + match self.project() { + EitherTransportProj::Left(a) => match a.poll(cx) { + Poll::Pending => Poll::Pending, + Poll::Ready(event) => Poll::Ready( + event + .map_upgrade(EitherFuture::First) + .map_err(EitherError::A), + ), }, - EitherTransport::Right(b) => match b.listen_on(addr) { - Ok(listener) => Ok(EitherListenStream::Second(listener)), - Err(MultiaddrNotSupported(addr)) => Err(MultiaddrNotSupported(addr)), - Err(Other(err)) => Err(Other(EitherError::B(err))), + EitherTransportProj::Right(b) => match b.poll(cx) { + Poll::Pending => Poll::Pending, + Poll::Ready(event) => Poll::Ready( + event + .map_upgrade(EitherFuture::Second) + .map_err(EitherError::B), + ), }, } } + fn remove_listener(&mut self, id: ListenerId) -> bool { + match self { + EitherTransport::Left(t) => t.remove_listener(id), + EitherTransport::Right(t) => t.remove_listener(id), + } + } + + fn listen_on(&mut self, addr: Multiaddr) -> Result> { + use TransportError::*; + match self { + EitherTransport::Left(a) => a.listen_on(addr).map_err(|e| match e { + MultiaddrNotSupported(addr) => MultiaddrNotSupported(addr), + Other(err) => Other(EitherError::A(err)), + }), + EitherTransport::Right(b) => b.listen_on(addr).map_err(|e| match e { + MultiaddrNotSupported(addr) => MultiaddrNotSupported(addr), + Other(err) => Other(EitherError::B(err)), + }), + } + } + fn dial(&mut self, addr: Multiaddr) -> Result> { use TransportError::*; match self { diff --git a/core/src/transport.rs b/core/src/transport.rs index a625e3b552c..d94056b33cc 100644 --- a/core/src/transport.rs +++ b/core/src/transport.rs @@ -25,10 +25,14 @@ //! any desired protocols. The rest of the module defines combinators for //! modifying a transport through composition with other transports or protocol upgrades. -use crate::connection::ConnectedPoint; use futures::prelude::*; use multiaddr::Multiaddr; -use std::{error::Error, fmt}; +use std::{ + error::Error, + fmt, + pin::Pin, + task::{Context, Poll}, +}; pub mod and_then; pub mod choice; @@ -42,6 +46,8 @@ pub mod upgrade; mod boxed; mod optional; +use crate::ConnectedPoint; + pub use self::boxed::Boxed; pub use self::choice::OrTransport; pub use self::memory::MemoryTransport; @@ -87,21 +93,8 @@ pub trait Transport { /// An error that occurred during connection setup. type Error: Error; - /// A stream of [`Output`](Transport::Output)s for inbound connections. - /// - /// An item should be produced whenever a connection is received at the lowest level of the - /// transport stack. The item must be a [`ListenerUpgrade`](Transport::ListenerUpgrade) future - /// that resolves to an [`Output`](Transport::Output) value once all protocol upgrades - /// have been applied. - /// - /// If this stream produces an error, it is considered fatal and the listener is killed. It - /// is possible to report non-fatal errors by producing a [`ListenerEvent::Error`]. - type Listener: Stream< - Item = Result, Self::Error>, - >; - /// A pending [`Output`](Transport::Output) for an inbound connection, - /// obtained from the [`Listener`](Transport::Listener) stream. + /// obtained from the [`Transport`] stream. /// /// After a connection has been accepted by the transport, it may need to go through /// asynchronous post-processing (i.e. protocol upgrade negotiations). Such @@ -115,15 +108,22 @@ pub trait Transport { /// obtained from [dialing](Transport::dial). type Dial: Future>; + // TODO: fix docs /// Listens on the given [`Multiaddr`], producing a stream of pending, inbound connections - /// and addresses this transport is listening on (cf. [`ListenerEvent`]). + /// and addresses this transport is listening on (cf. [`TransportEvent`]). /// /// Returning an error from the stream is considered fatal. The listener can also report - /// non-fatal errors by producing a [`ListenerEvent::Error`]. - fn listen_on(&mut self, addr: Multiaddr) -> Result> + /// non-fatal errors by producing a [`TransportEvent::Error`]. + fn listen_on(&mut self, addr: Multiaddr) -> Result> where Self: Sized; + /// Remove a listener. + /// + /// Return `true` if there was a listener with this Id, `false` + /// otherwise. + fn remove_listener(&mut self, id: ListenerId) -> bool; + /// Dials the given [`Multiaddr`], returning a future for a pending outbound connection. /// /// If [`TransportError::MultiaddrNotSupported`] is returned, it may be desirable to @@ -144,6 +144,14 @@ pub trait Transport { where Self: Sized; + // TODO: Add docs + fn poll( + self: Pin<&mut Self>, + cx: &mut Context<'_>, + ) -> Poll> + where + Self: Sized; + /// Performs a transport-specific mapping of an address `observed` by /// a remote onto a local `listen` address to yield an address for /// the local node that may be reachable for other peers. @@ -152,9 +160,8 @@ pub trait Transport { /// Boxes the transport, including custom transport errors. fn boxed(self) -> boxed::Boxed where - Self: Transport + Sized + Send + 'static, + Self: Transport + Sized + Send + Unpin + 'static, Self::Dial: Send + 'static, - Self::Listener: Send + 'static, Self::ListenerUpgrade: Send + 'static, Self::Error: Send + Sync, { @@ -221,78 +228,160 @@ pub trait Transport { } } -/// Event produced by [`Transport::Listener`]s. +/// The ID of a single listener. +#[derive(Copy, Clone, Debug, PartialEq, Eq, Hash)] +pub struct ListenerId(u64); + +impl ListenerId { + /// Creates a new `ListenerId`. + pub fn new() -> Self { + ListenerId(rand::random()) + } +} + +impl Default for ListenerId { + fn default() -> Self { + Self::new() + } +} + +/// Event produced by [`Transport`]s. /// -/// Transports are expected to produce `Upgrade` events only for +/// Transports are expected to produce `Incoming` events only for /// listen addresses which have previously been announced via /// a `NewAddress` event and which have not been invalidated by /// an `AddressExpired` event yet. -#[derive(Clone, Debug, PartialEq)] -pub enum ListenerEvent { - /// The transport is listening on a new additional [`Multiaddr`]. - NewAddress(Multiaddr), - /// An upgrade, consisting of the upgrade future, the listener address and the remote address. - Upgrade { - /// The upgrade. +pub enum TransportEvent { + /// A new address is being listened on. + NewAddress { + /// The listener that is listening on the new address. + listener_id: ListenerId, + /// The new address that is being listened on. + listen_addr: Multiaddr, + }, + /// An address is no longer being listened on. + AddressExpired { + /// The listener that is no longer listening on the address. + listener_id: ListenerId, + /// The new address that is being listened on. + listen_addr: Multiaddr, + }, + /// A connection is incoming on one of the listeners. + Incoming { + /// The listener that produced the upgrade. + listener_id: ListenerId, + /// The produced upgrade. upgrade: TUpgr, - /// The local address which produced this upgrade. + /// Local connection address. local_addr: Multiaddr, - /// The remote address which produced this upgrade. - remote_addr: Multiaddr, + /// Address used to send back data to the incoming client. + send_back_addr: Multiaddr, + }, + /// A listener closed. + ListenerClosed { + /// The ID of the listener that closed. + listener_id: ListenerId, + /// Reason for the closure. Contains `Ok(())` if the stream produced `None`, or `Err` + /// if the stream produced an error. + reason: Result<(), TErr>, }, - /// A [`Multiaddr`] is no longer used for listening. - AddressExpired(Multiaddr), - /// A non-fatal error has happened on the listener. + /// A listener errored. /// - /// This event should be generated in order to notify the user that something wrong has - /// happened. The listener, however, continues to run. - Error(TErr), + /// The listener will continue to be polled for new events and the event + /// is for informational purposes only. + Error { + /// The ID of the listener that errored. + listener_id: ListenerId, + /// The error value. + error: TErr, + }, } -impl ListenerEvent { - /// In case this [`ListenerEvent`] is an upgrade, apply the given function - /// to the upgrade and multiaddress and produce another listener event - /// based the the function's result. - pub fn map(self, f: impl FnOnce(TUpgr) -> U) -> ListenerEvent { +impl TransportEvent { + pub fn map_upgrade(self, map: impl FnOnce(TUpgr) -> U) -> TransportEvent { match self { - ListenerEvent::Upgrade { + TransportEvent::Incoming { + listener_id, upgrade, local_addr, - remote_addr, - } => ListenerEvent::Upgrade { - upgrade: f(upgrade), + send_back_addr, + } => TransportEvent::Incoming { + listener_id, + upgrade: map(upgrade), local_addr, - remote_addr, + send_back_addr, + }, + TransportEvent::NewAddress { + listen_addr, + listener_id, + } => TransportEvent::NewAddress { + listen_addr, + listener_id, + }, + TransportEvent::AddressExpired { + listen_addr, + listener_id, + } => TransportEvent::AddressExpired { + listen_addr, + listener_id, + }, + TransportEvent::Error { listener_id, error } => { + TransportEvent::Error { listener_id, error } + } + TransportEvent::ListenerClosed { + listener_id, + reason, + } => TransportEvent::ListenerClosed { + listener_id, + reason, }, - ListenerEvent::NewAddress(a) => ListenerEvent::NewAddress(a), - ListenerEvent::AddressExpired(a) => ListenerEvent::AddressExpired(a), - ListenerEvent::Error(e) => ListenerEvent::Error(e), } } - /// In case this [`ListenerEvent`] is an [`Error`](ListenerEvent::Error), - /// apply the given function to the error and produce another listener event based on the - /// function's result. - pub fn map_err(self, f: impl FnOnce(TErr) -> U) -> ListenerEvent { + pub fn map_err(self, map_err: impl FnOnce(TErr) -> E) -> TransportEvent { match self { - ListenerEvent::Upgrade { + TransportEvent::Incoming { + listener_id, upgrade, local_addr, - remote_addr, - } => ListenerEvent::Upgrade { + send_back_addr, + } => TransportEvent::Incoming { + listener_id, upgrade, local_addr, - remote_addr, + send_back_addr, + }, + TransportEvent::NewAddress { + listen_addr, + listener_id, + } => TransportEvent::NewAddress { + listen_addr, + listener_id, + }, + TransportEvent::AddressExpired { + listen_addr, + listener_id, + } => TransportEvent::AddressExpired { + listen_addr, + listener_id, + }, + TransportEvent::Error { listener_id, error } => TransportEvent::Error { + listener_id, + error: map_err(error), + }, + TransportEvent::ListenerClosed { + listener_id, + reason, + } => TransportEvent::ListenerClosed { + listener_id, + reason: reason.map_err(map_err), }, - ListenerEvent::NewAddress(a) => ListenerEvent::NewAddress(a), - ListenerEvent::AddressExpired(a) => ListenerEvent::AddressExpired(a), - ListenerEvent::Error(e) => ListenerEvent::Error(f(e)), } } /// Returns `true` if this is an `Upgrade` listener event. pub fn is_upgrade(&self) -> bool { - matches!(self, ListenerEvent::Upgrade { .. }) + matches!(self, TransportEvent::Incoming { .. }) } /// Try to turn this listener event into upgrade parts. @@ -300,13 +389,13 @@ impl ListenerEvent { /// Returns `None` if the event is not actually an upgrade, /// otherwise the upgrade and the remote address. pub fn into_upgrade(self) -> Option<(TUpgr, Multiaddr)> { - if let ListenerEvent::Upgrade { + if let TransportEvent::Incoming { upgrade, - remote_addr, + send_back_addr, .. } = self { - Some((upgrade, remote_addr)) + Some((upgrade, send_back_addr)) } else { None } @@ -314,7 +403,7 @@ impl ListenerEvent { /// Returns `true` if this is a `NewAddress` listener event. pub fn is_new_address(&self) -> bool { - matches!(self, ListenerEvent::NewAddress(_)) + matches!(self, TransportEvent::NewAddress { .. }) } /// Try to turn this listener event into the `NewAddress` part. @@ -322,8 +411,8 @@ impl ListenerEvent { /// Returns `None` if the event is not actually a `NewAddress`, /// otherwise the address. pub fn into_new_address(self) -> Option { - if let ListenerEvent::NewAddress(a) = self { - Some(a) + if let TransportEvent::NewAddress { listen_addr, .. } = self { + Some(listen_addr) } else { None } @@ -331,7 +420,7 @@ impl ListenerEvent { /// Returns `true` if this is an `AddressExpired` listener event. pub fn is_address_expired(&self) -> bool { - matches!(self, ListenerEvent::AddressExpired(_)) + matches!(self, TransportEvent::AddressExpired { .. }) } /// Try to turn this listener event into the `AddressExpired` part. @@ -339,8 +428,8 @@ impl ListenerEvent { /// Returns `None` if the event is not actually a `AddressExpired`, /// otherwise the address. pub fn into_address_expired(self) -> Option { - if let ListenerEvent::AddressExpired(a) = self { - Some(a) + if let TransportEvent::AddressExpired { listen_addr, .. } = self { + Some(listen_addr) } else { None } @@ -348,7 +437,7 @@ impl ListenerEvent { /// Returns `true` if this is an `Error` listener event. pub fn is_error(&self) -> bool { - matches!(self, ListenerEvent::Error(_)) + matches!(self, TransportEvent::Error { .. }) } /// Try to turn this listener event into the `Error` part. @@ -356,14 +445,59 @@ impl ListenerEvent { /// Returns `None` if the event is not actually a `Error`, /// otherwise the error. pub fn into_error(self) -> Option { - if let ListenerEvent::Error(err) = self { - Some(err) + if let TransportEvent::Error { error, .. } = self { + Some(error) } else { None } } } +impl fmt::Debug for TransportEvent { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> Result<(), fmt::Error> { + match self { + TransportEvent::NewAddress { + listener_id, + listen_addr, + } => f + .debug_struct("TransportEvent::NewAddress") + .field("listener_id", listener_id) + .field("listen_addr", listen_addr) + .finish(), + TransportEvent::AddressExpired { + listener_id, + listen_addr, + } => f + .debug_struct("TransportEvent::AddressExpired") + .field("listener_id", listener_id) + .field("listen_addr", listen_addr) + .finish(), + TransportEvent::Incoming { + listener_id, + local_addr, + .. + } => f + .debug_struct("TransportEvent::Incoming") + .field("listener_id", listener_id) + .field("local_addr", local_addr) + .finish(), + TransportEvent::ListenerClosed { + listener_id, + reason, + } => f + .debug_struct("TransportEvent::Closed") + .field("listener_id", listener_id) + .field("reason", reason) + .finish(), + TransportEvent::Error { listener_id, error } => f + .debug_struct("TransportEvent::Error") + .field("listener_id", listener_id) + .field("error", error) + .finish(), + } + } +} + /// An error during [dialing][Transport::dial] or [listening][Transport::listen_on] /// on a [`Transport`]. #[derive(Debug, Clone)] diff --git a/core/src/transport/and_then.rs b/core/src/transport/and_then.rs index f73a0caf8e6..b4f9c92186f 100644 --- a/core/src/transport/and_then.rs +++ b/core/src/transport/and_then.rs @@ -21,15 +21,17 @@ use crate::{ connection::{ConnectedPoint, Endpoint}, either::EitherError, - transport::{ListenerEvent, Transport, TransportError}, + transport::{ListenerId, Transport, TransportError, TransportEvent}, }; use futures::{future::Either, prelude::*}; use multiaddr::Multiaddr; use std::{error, marker::PhantomPinned, pin::Pin, task::Context, task::Poll}; /// See the `Transport::and_then` method. +#[pin_project::pin_project] #[derive(Debug, Clone)] pub struct AndThen { + #[pin] transport: T, fun: C, } @@ -49,27 +51,17 @@ where { type Output = O; type Error = EitherError; - type Listener = AndThenStream; type ListenerUpgrade = AndThenFuture; type Dial = AndThenFuture; - fn listen_on( - &mut self, - addr: Multiaddr, - ) -> Result> { - let listener = self - .transport + fn listen_on(&mut self, addr: Multiaddr) -> Result> { + self.transport .listen_on(addr) - .map_err(|err| err.map(EitherError::A))?; - // Try to negotiate the protocol. - // Note that failing to negotiate a protocol will never produce a future with an error. - // Instead the `stream` will produce `Ok(Err(...))`. - // `stream` can only produce an `Err` if `listening_stream` produces an `Err`. - let stream = AndThenStream { - stream: listener, - fun: self.fun.clone(), - }; - Ok(stream) + .map_err(|err| err.map(EitherError::A)) + } + + fn remove_listener(&mut self, id: ListenerId) -> bool { + self.transport.remove_listener(id) } fn dial(&mut self, addr: Multiaddr) -> Result> { @@ -116,68 +108,40 @@ where fn address_translation(&self, server: &Multiaddr, observed: &Multiaddr) -> Option { self.transport.address_translation(server, observed) } -} -/// Custom `Stream` to avoid boxing. -/// -/// Applies a function to every stream item. -#[pin_project::pin_project] -#[derive(Debug, Clone)] -pub struct AndThenStream { - #[pin] - stream: TListener, - fun: TMap, -} - -impl Stream - for AndThenStream -where - TListener: TryStream, Error = TTransErr>, - TListUpgr: TryFuture, - TMap: FnOnce(TTransOut, ConnectedPoint) -> TMapOut + Clone, - TMapOut: TryFuture, -{ - type Item = Result< - ListenerEvent< - AndThenFuture, - EitherError, - >, - EitherError, - >; - - fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { + fn poll( + self: Pin<&mut Self>, + cx: &mut Context<'_>, + ) -> Poll> { let this = self.project(); - match TryStream::try_poll_next(this.stream, cx) { - Poll::Ready(Some(Ok(event))) => { - let event = match event { - ListenerEvent::Upgrade { - upgrade, - local_addr, - remote_addr, - } => { - let point = ConnectedPoint::Listener { - local_addr: local_addr.clone(), - send_back_addr: remote_addr.clone(), - }; - ListenerEvent::Upgrade { - upgrade: AndThenFuture { - inner: Either::Left(Box::pin(upgrade)), - args: Some((this.fun.clone(), point)), - _marker: PhantomPinned, - }, - local_addr, - remote_addr, - } - } - ListenerEvent::NewAddress(a) => ListenerEvent::NewAddress(a), - ListenerEvent::AddressExpired(a) => ListenerEvent::AddressExpired(a), - ListenerEvent::Error(e) => ListenerEvent::Error(EitherError::A(e)), + match this.transport.poll(cx) { + Poll::Ready(TransportEvent::Incoming { + listener_id, + upgrade, + local_addr, + send_back_addr, + }) => { + let point = ConnectedPoint::Listener { + local_addr: local_addr.clone(), + send_back_addr: send_back_addr.clone(), }; - - Poll::Ready(Some(Ok(event))) + Poll::Ready(TransportEvent::Incoming { + listener_id, + upgrade: AndThenFuture { + inner: Either::Left(Box::pin(upgrade)), + args: Some((this.fun.clone(), point)), + _marker: PhantomPinned, + }, + local_addr, + send_back_addr, + }) + } + Poll::Ready(other) => { + let mapped = other + .map_upgrade(|_upgrade| unreachable!("case already matched")) + .map_err(EitherError::A); + Poll::Ready(mapped) } - Poll::Ready(Some(Err(err))) => Poll::Ready(Some(Err(EitherError::A(err)))), - Poll::Ready(None) => Poll::Ready(None), Poll::Pending => Poll::Pending, } } diff --git a/core/src/transport/boxed.rs b/core/src/transport/boxed.rs index 8a804aa40be..b2560c4a662 100644 --- a/core/src/transport/boxed.rs +++ b/core/src/transport/boxed.rs @@ -18,18 +18,22 @@ // FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER // DEALINGS IN THE SOFTWARE. -use crate::transport::{ListenerEvent, Transport, TransportError}; -use futures::prelude::*; +use crate::transport::{ListenerId, Transport, TransportError, TransportEvent}; +use futures::{prelude::*, stream::FusedStream}; use multiaddr::Multiaddr; -use std::{error::Error, fmt, io, pin::Pin}; +use std::{ + error::Error, + fmt, io, + pin::Pin, + task::{Context, Poll}, +}; /// Creates a new [`Boxed`] transport from the given transport. pub fn boxed(transport: T) -> Boxed where - T: Transport + Send + 'static, + T: Transport + Send + Unpin + 'static, T::Error: Send + Sync, T::Dial: Send + 'static, - T::Listener: Send + 'static, T::ListenerUpgrade: Send + 'static, { Boxed { @@ -41,19 +45,22 @@ where /// and `ListenerUpgrade` futures are `Box`ed and only the `Output` /// and `Error` types are captured in type variables. pub struct Boxed { - inner: Box + Send>, + inner: Box + Send + Unpin>, } type Dial = Pin> + Send>>; -type Listener = - Pin, io::Error>>> + Send>>; type ListenerUpgrade = Pin> + Send>>; trait Abstract { - fn listen_on(&mut self, addr: Multiaddr) -> Result, TransportError>; + fn listen_on(&mut self, addr: Multiaddr) -> Result>; + fn remove_listener(&mut self, id: ListenerId) -> bool; fn dial(&mut self, addr: Multiaddr) -> Result, TransportError>; fn dial_as_listener(&mut self, addr: Multiaddr) -> Result, TransportError>; fn address_translation(&self, server: &Multiaddr, observed: &Multiaddr) -> Option; + fn poll( + self: Pin<&mut Self>, + cx: &mut Context<'_>, + ) -> Poll, io::Error>>; } impl Abstract for T @@ -61,22 +68,14 @@ where T: Transport + 'static, T::Error: Send + Sync, T::Dial: Send + 'static, - T::Listener: Send + 'static, T::ListenerUpgrade: Send + 'static, { - fn listen_on(&mut self, addr: Multiaddr) -> Result, TransportError> { - let listener = Transport::listen_on(self, addr).map_err(|e| e.map(box_err))?; - let fut = listener - .map_ok(|event| { - event - .map(|upgrade| { - let up = upgrade.map_err(box_err); - Box::pin(up) as ListenerUpgrade - }) - .map_err(box_err) - }) - .map_err(box_err); - Ok(Box::pin(fut)) + fn listen_on(&mut self, addr: Multiaddr) -> Result> { + Transport::listen_on(self, addr).map_err(|e| e.map(box_err)) + } + + fn remove_listener(&mut self, id: ListenerId) -> bool { + Transport::remove_listener(self, id) } fn dial(&mut self, addr: Multiaddr) -> Result, TransportError> { @@ -96,6 +95,20 @@ where fn address_translation(&self, server: &Multiaddr, observed: &Multiaddr) -> Option { Transport::address_translation(self, server, observed) } + + fn poll( + self: Pin<&mut Self>, + cx: &mut Context<'_>, + ) -> Poll, io::Error>> { + self.poll(cx).map(|event| { + event + .map_upgrade(|upgrade| { + let up = upgrade.map_err(box_err); + Box::pin(up) as ListenerUpgrade + }) + .map_err(box_err) + }) + } } impl fmt::Debug for Boxed { @@ -107,17 +120,17 @@ impl fmt::Debug for Boxed { impl Transport for Boxed { type Output = O; type Error = io::Error; - type Listener = Listener; type ListenerUpgrade = ListenerUpgrade; type Dial = Dial; - fn listen_on( - &mut self, - addr: Multiaddr, - ) -> Result> { + fn listen_on(&mut self, addr: Multiaddr) -> Result> { self.inner.listen_on(addr) } + fn remove_listener(&mut self, id: ListenerId) -> bool { + self.inner.remove_listener(id) + } + fn dial(&mut self, addr: Multiaddr) -> Result> { self.inner.dial(addr) } @@ -132,6 +145,27 @@ impl Transport for Boxed { fn address_translation(&self, server: &Multiaddr, observed: &Multiaddr) -> Option { self.inner.address_translation(server, observed) } + + fn poll( + mut self: Pin<&mut Self>, + cx: &mut Context<'_>, + ) -> Poll> { + Pin::new(self.inner.as_mut()).poll(cx) + } +} + +impl Stream for Boxed { + type Item = TransportEvent, io::Error>; + + fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { + Transport::poll(self, cx).map(Some) + } +} + +impl FusedStream for Boxed { + fn is_terminated(&self) -> bool { + false + } } fn box_err(e: E) -> io::Error { diff --git a/core/src/transport/choice.rs b/core/src/transport/choice.rs index f1d21cfa30c..17528c1d4a8 100644 --- a/core/src/transport/choice.rs +++ b/core/src/transport/choice.rs @@ -18,13 +18,15 @@ // FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER // DEALINGS IN THE SOFTWARE. -use crate::either::{EitherError, EitherFuture, EitherListenStream, EitherOutput}; -use crate::transport::{Transport, TransportError}; +use crate::either::{EitherError, EitherFuture, EitherOutput}; +use crate::transport::{ListenerId, Transport, TransportError, TransportEvent}; use multiaddr::Multiaddr; +use std::{pin::Pin, task::Context, task::Poll}; /// Struct returned by `or_transport()`. #[derive(Debug, Copy, Clone)] -pub struct OrTransport(A, B); +#[pin_project::pin_project] +pub struct OrTransport(#[pin] A, #[pin] B); impl OrTransport { pub fn new(a: A, b: B) -> OrTransport { @@ -39,33 +41,27 @@ where { type Output = EitherOutput; type Error = EitherError; - type Listener = EitherListenStream; type ListenerUpgrade = EitherFuture; type Dial = EitherFuture; - fn listen_on( - &mut self, - addr: Multiaddr, - ) -> Result> { + fn listen_on(&mut self, addr: Multiaddr) -> Result> { let addr = match self.0.listen_on(addr) { - Ok(listener) => return Ok(EitherListenStream::First(listener)), Err(TransportError::MultiaddrNotSupported(addr)) => addr, - Err(TransportError::Other(err)) => { - return Err(TransportError::Other(EitherError::A(err))) - } + res => return res.map_err(|err| err.map(EitherError::A)), }; let addr = match self.1.listen_on(addr) { - Ok(listener) => return Ok(EitherListenStream::Second(listener)), Err(TransportError::MultiaddrNotSupported(addr)) => addr, - Err(TransportError::Other(err)) => { - return Err(TransportError::Other(EitherError::B(err))) - } + res => return res.map_err(|err| err.map(EitherError::B)), }; Err(TransportError::MultiaddrNotSupported(addr)) } + fn remove_listener(&mut self, id: ListenerId) -> bool { + self.0.remove_listener(id) || self.1.remove_listener(id) + } + fn dial(&mut self, addr: Multiaddr) -> Result> { let addr = match self.0.dial(addr) { Ok(connec) => return Ok(EitherFuture::First(connec)), @@ -116,4 +112,24 @@ where self.1.address_translation(server, observed) } } + + fn poll( + self: Pin<&mut Self>, + cx: &mut Context<'_>, + ) -> Poll> { + let this = self.project(); + match this.0.poll(cx) { + Poll::Ready(ev) => { + return Poll::Ready(ev.map_upgrade(EitherFuture::First).map_err(EitherError::A)) + } + Poll::Pending => {} + } + match this.1.poll(cx) { + Poll::Ready(ev) => { + return Poll::Ready(ev.map_upgrade(EitherFuture::Second).map_err(EitherError::B)) + } + Poll::Pending => {} + } + Poll::Pending + } } diff --git a/core/src/transport/dummy.rs b/core/src/transport/dummy.rs index 5862348b0d4..a7d1cab9089 100644 --- a/core/src/transport/dummy.rs +++ b/core/src/transport/dummy.rs @@ -18,7 +18,7 @@ // FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER // DEALINGS IN THE SOFTWARE. -use crate::transport::{ListenerEvent, Transport, TransportError}; +use crate::transport::{ListenerId, Transport, TransportError, TransportEvent}; use crate::Multiaddr; use futures::{prelude::*, task::Context, task::Poll}; use std::{fmt, io, marker::PhantomData, pin::Pin}; @@ -56,19 +56,17 @@ impl Clone for DummyTransport { impl Transport for DummyTransport { type Output = TOut; type Error = io::Error; - type Listener = futures::stream::Pending< - Result, Self::Error>, - >; type ListenerUpgrade = futures::future::Pending>; type Dial = futures::future::Pending>; - fn listen_on( - &mut self, - addr: Multiaddr, - ) -> Result> { + fn listen_on(&mut self, addr: Multiaddr) -> Result> { Err(TransportError::MultiaddrNotSupported(addr)) } + fn remove_listener(&mut self, _id: ListenerId) -> bool { + false + } + fn dial(&mut self, addr: Multiaddr) -> Result> { Err(TransportError::MultiaddrNotSupported(addr)) } @@ -83,6 +81,13 @@ impl Transport for DummyTransport { fn address_translation(&self, _server: &Multiaddr, _observed: &Multiaddr) -> Option { None } + + fn poll( + self: Pin<&mut Self>, + _: &mut Context<'_>, + ) -> Poll> { + Poll::Pending + } } /// Implementation of `AsyncRead` and `AsyncWrite`. Not meant to be instanciated. diff --git a/core/src/transport/map.rs b/core/src/transport/map.rs index 703e1ea430b..50f7b826d36 100644 --- a/core/src/transport/map.rs +++ b/core/src/transport/map.rs @@ -20,15 +20,19 @@ use crate::{ connection::{ConnectedPoint, Endpoint}, - transport::{ListenerEvent, Transport, TransportError}, + transport::{Transport, TransportError, TransportEvent}, }; use futures::prelude::*; use multiaddr::Multiaddr; use std::{pin::Pin, task::Context, task::Poll}; +use super::ListenerId; + /// See `Transport::map`. #[derive(Debug, Copy, Clone)] +#[pin_project::pin_project] pub struct Map { + #[pin] transport: T, fun: F, } @@ -54,19 +58,15 @@ where { type Output = D; type Error = T::Error; - type Listener = MapStream; type ListenerUpgrade = MapFuture; type Dial = MapFuture; - fn listen_on( - &mut self, - addr: Multiaddr, - ) -> Result> { - let stream = self.transport.listen_on(addr)?; - Ok(MapStream { - stream, - fun: self.fun.clone(), - }) + fn listen_on(&mut self, addr: Multiaddr) -> Result> { + self.transport.listen_on(addr) + } + + fn remove_listener(&mut self, id: ListenerId) -> bool { + self.transport.remove_listener(id) } fn dial(&mut self, addr: Multiaddr) -> Result> { @@ -99,58 +99,37 @@ where fn address_translation(&self, server: &Multiaddr, observed: &Multiaddr) -> Option { self.transport.address_translation(server, observed) } -} - -/// Custom `Stream` implementation to avoid boxing. -/// -/// Maps a function over every stream item. -#[pin_project::pin_project] -#[derive(Clone, Debug)] -pub struct MapStream { - #[pin] - stream: T, - fun: F, -} -impl Stream for MapStream -where - T: TryStream, Error = E>, - X: TryFuture, - F: FnOnce(A, ConnectedPoint) -> B + Clone, -{ - type Item = Result, E>, E>; - - fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { + fn poll( + self: Pin<&mut Self>, + cx: &mut Context<'_>, + ) -> Poll> { let this = self.project(); - match TryStream::try_poll_next(this.stream, cx) { - Poll::Ready(Some(Ok(event))) => { - let event = match event { - ListenerEvent::Upgrade { - upgrade, - local_addr, - remote_addr, - } => { - let point = ConnectedPoint::Listener { - local_addr: local_addr.clone(), - send_back_addr: remote_addr.clone(), - }; - ListenerEvent::Upgrade { - upgrade: MapFuture { - inner: upgrade, - args: Some((this.fun.clone(), point)), - }, - local_addr, - remote_addr, - } - } - ListenerEvent::NewAddress(a) => ListenerEvent::NewAddress(a), - ListenerEvent::AddressExpired(a) => ListenerEvent::AddressExpired(a), - ListenerEvent::Error(e) => ListenerEvent::Error(e), + match this.transport.poll(cx) { + Poll::Ready(TransportEvent::Incoming { + listener_id, + upgrade, + local_addr, + send_back_addr, + }) => { + let point = ConnectedPoint::Listener { + local_addr: local_addr.clone(), + send_back_addr: send_back_addr.clone(), }; - Poll::Ready(Some(Ok(event))) + Poll::Ready(TransportEvent::Incoming { + listener_id, + upgrade: MapFuture { + inner: upgrade, + args: Some((this.fun.clone(), point)), + }, + local_addr, + send_back_addr, + }) + } + Poll::Ready(other) => { + let mapped = other.map_upgrade(|_upgrade| unreachable!("case already matched")); + Poll::Ready(mapped) } - Poll::Ready(Some(Err(err))) => Poll::Ready(Some(Err(err))), - Poll::Ready(None) => Poll::Ready(None), Poll::Pending => Poll::Pending, } } diff --git a/core/src/transport/map_err.rs b/core/src/transport/map_err.rs index 6cc2c5c3662..99f2912447f 100644 --- a/core/src/transport/map_err.rs +++ b/core/src/transport/map_err.rs @@ -18,14 +18,16 @@ // FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER // DEALINGS IN THE SOFTWARE. -use crate::transport::{ListenerEvent, Transport, TransportError}; +use crate::transport::{ListenerId, Transport, TransportError, TransportEvent}; use futures::prelude::*; use multiaddr::Multiaddr; use std::{error, pin::Pin, task::Context, task::Poll}; /// See `Transport::map_err`. #[derive(Debug, Copy, Clone)] +#[pin_project::pin_project] pub struct MapErr { + #[pin] transport: T, map: F, } @@ -45,19 +47,16 @@ where { type Output = T::Output; type Error = TErr; - type Listener = MapErrListener; type ListenerUpgrade = MapErrListenerUpgrade; type Dial = MapErrDial; - fn listen_on( - &mut self, - addr: Multiaddr, - ) -> Result> { + fn listen_on(&mut self, addr: Multiaddr) -> Result> { let map = self.map.clone(); - match self.transport.listen_on(addr) { - Ok(stream) => Ok(MapErrListener { inner: stream, map }), - Err(err) => Err(err.map(map)), - } + self.transport.listen_on(addr).map_err(|err| err.map(map)) + } + + fn remove_listener(&mut self, id: ListenerId) -> bool { + self.transport.remove_listener(id) } fn dial(&mut self, addr: Multiaddr) -> Result> { @@ -88,41 +87,20 @@ where fn address_translation(&self, server: &Multiaddr, observed: &Multiaddr) -> Option { self.transport.address_translation(server, observed) } -} -/// Listening stream for `MapErr`. -#[pin_project::pin_project] -pub struct MapErrListener { - #[pin] - inner: T::Listener, - map: F, -} - -impl Stream for MapErrListener -where - T: Transport, - F: FnOnce(T::Error) -> TErr + Clone, - TErr: error::Error, -{ - type Item = Result, TErr>, TErr>; - - fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { + fn poll( + self: Pin<&mut Self>, + cx: &mut Context<'_>, + ) -> Poll> { let this = self.project(); - match TryStream::try_poll_next(this.inner, cx) { - Poll::Ready(Some(Ok(event))) => { - let map = &*this.map; - let event = event - .map(move |value| MapErrListenerUpgrade { - inner: value, - map: Some(map.clone()), - }) - .map_err(|err| (map.clone())(err)); - Poll::Ready(Some(Ok(event))) - } - Poll::Ready(None) => Poll::Ready(None), - Poll::Pending => Poll::Pending, - Poll::Ready(Some(Err(err))) => Poll::Ready(Some(Err((this.map.clone())(err)))), - } + let map = &*this.map; + this.transport.poll(cx).map(|ev| { + ev.map_upgrade(move |value| MapErrListenerUpgrade { + inner: value, + map: Some(map.clone()), + }) + .map_err(map.clone()) + }) } } diff --git a/core/src/transport/memory.rs b/core/src/transport/memory.rs index 40bc5d3da15..933ecb1b0df 100644 --- a/core/src/transport/memory.rs +++ b/core/src/transport/memory.rs @@ -18,10 +18,7 @@ // FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER // DEALINGS IN THE SOFTWARE. -use crate::{ - transport::{ListenerEvent, TransportError}, - Transport, -}; +use crate::transport::{ListenerId, Transport, TransportError, TransportEvent}; use fnv::FnvHashMap; use futures::{ channel::mpsc, @@ -34,7 +31,12 @@ use lazy_static::lazy_static; use multiaddr::{Multiaddr, Protocol}; use parking_lot::Mutex; use rw_stream_sink::RwStreamSink; -use std::{collections::hash_map::Entry, error, fmt, io, num::NonZeroU64, pin::Pin}; +use std::{ + collections::{hash_map::Entry, VecDeque}, + error, fmt, io, + num::NonZeroU64, + pin::Pin, +}; lazy_static! { static ref HUB: Hub = Hub(Mutex::new(FnvHashMap::default())); @@ -91,8 +93,16 @@ impl Hub { } /// Transport that supports `/memory/N` multiaddresses. -#[derive(Debug, Copy, Clone, Default)] -pub struct MemoryTransport; +#[derive(Default)] +pub struct MemoryTransport { + listeners: VecDeque>>, +} + +impl MemoryTransport { + pub fn new() -> Self { + Self::default() + } +} /// Connection to a `MemoryTransport` currently being opened. pub struct DialFuture { @@ -168,14 +178,10 @@ impl Future for DialFuture { impl Transport for MemoryTransport { type Output = Channel>; type Error = MemoryTransportError; - type Listener = Listener; type ListenerUpgrade = Ready>; type Dial = DialFuture; - fn listen_on( - &mut self, - addr: Multiaddr, - ) -> Result> { + fn listen_on(&mut self, addr: Multiaddr) -> Result> { let port = if let Ok(port) = parse_memory_addr(&addr) { port } else { @@ -187,14 +193,29 @@ impl Transport for MemoryTransport { None => return Err(TransportError::Other(MemoryTransportError::Unreachable)), }; + let id = ListenerId::new(); let listener = Listener { + id, port, addr: Protocol::Memory(port.get()).into(), receiver: rx, tell_listen_addr: true, }; + self.listeners.push_back(Box::pin(listener)); - Ok(listener) + Ok(id) + } + + fn remove_listener(&mut self, id: ListenerId) -> bool { + if let Some(index) = self.listeners.iter().position(|listener| listener.id == id) { + let listener = self.listeners.get_mut(index).unwrap(); + let val_in = HUB.unregister_port(&listener.port); + debug_assert!(val_in.is_some()); + listener.receiver.close(); + true + } else { + false + } } fn dial(&mut self, addr: Multiaddr) -> Result> { @@ -221,6 +242,56 @@ impl Transport for MemoryTransport { fn address_translation(&self, _server: &Multiaddr, _observed: &Multiaddr) -> Option { None } + + fn poll( + mut self: Pin<&mut Self>, + cx: &mut Context<'_>, + ) -> Poll> + where + Self: Sized, + { + let mut remaining = self.listeners.len(); + while let Some(mut listener) = self.listeners.pop_back() { + if listener.tell_listen_addr { + listener.tell_listen_addr = false; + let listen_addr = listener.addr.clone(); + let listener_id = listener.id; + self.listeners.push_front(listener); + return Poll::Ready(TransportEvent::NewAddress { + listen_addr, + listener_id, + }); + } + + let event = match Stream::poll_next(Pin::new(&mut listener.receiver), cx) { + Poll::Pending => None, + Poll::Ready(Some((channel, dial_port))) => Some(TransportEvent::Incoming { + listener_id: listener.id, + upgrade: future::ready(Ok(channel)), + local_addr: listener.addr.clone(), + send_back_addr: Protocol::Memory(dial_port.get()).into(), + }), + Poll::Ready(None) => { + // Listener was closed. + return Poll::Ready(TransportEvent::ListenerClosed { + listener_id: listener.id, + reason: Ok(()), + }); + } + }; + + self.listeners.push_front(listener); + if let Some(event) = event { + return Poll::Ready(event); + } else { + remaining -= 1; + if remaining == 0 { + break; + } + } + } + Poll::Pending + } } /// Error that can be produced from the `MemoryTransport`. @@ -245,51 +316,17 @@ impl error::Error for MemoryTransportError {} /// Listener for memory connections. pub struct Listener { + id: ListenerId, /// Port we're listening on. port: NonZeroU64, /// The address we are listening on. addr: Multiaddr, /// Receives incoming connections. receiver: ChannelReceiver, - /// Generate `ListenerEvent::NewAddress` to inform about our listen address. + /// Generate `TransportEvent::NewAddress` to inform about our listen address. tell_listen_addr: bool, } -impl Stream for Listener { - type Item = Result< - ListenerEvent>, MemoryTransportError>>, MemoryTransportError>, - MemoryTransportError, - >; - - fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { - if self.tell_listen_addr { - self.tell_listen_addr = false; - return Poll::Ready(Some(Ok(ListenerEvent::NewAddress(self.addr.clone())))); - } - - let (channel, dial_port) = match Stream::poll_next(Pin::new(&mut self.receiver), cx) { - Poll::Pending => return Poll::Pending, - Poll::Ready(None) => panic!("Alive listeners always have a sender."), - Poll::Ready(Some(v)) => v, - }; - - let event = ListenerEvent::Upgrade { - upgrade: future::ready(Ok(channel)), - local_addr: self.addr.clone(), - remote_addr: Protocol::Memory(dial_port.get()).into(), - }; - - Poll::Ready(Some(Ok(event))) - } -} - -impl Drop for Listener { - fn drop(&mut self) { - let val_in = HUB.unregister_port(&self.port); - debug_assert!(val_in.is_some()); - } -} - /// If the address is `/memory/n`, returns the value of `n`. fn parse_memory_addr(a: &Multiaddr) -> Result { let mut protocols = a.iter(); @@ -418,28 +455,34 @@ mod tests { #[test] fn listening_twice() { let mut transport = MemoryTransport::default(); - assert!(transport - .listen_on("/memory/1639174018481".parse().unwrap()) - .is_ok()); - assert!(transport - .listen_on("/memory/1639174018481".parse().unwrap()) - .is_ok()); - let _listener = transport - .listen_on("/memory/1639174018481".parse().unwrap()) - .unwrap(); - assert!(transport - .listen_on("/memory/1639174018481".parse().unwrap()) - .is_err()); - assert!(transport - .listen_on("/memory/1639174018481".parse().unwrap()) - .is_err()); - drop(_listener); - assert!(transport - .listen_on("/memory/1639174018481".parse().unwrap()) - .is_ok()); - assert!(transport - .listen_on("/memory/1639174018481".parse().unwrap()) - .is_ok()); + + let addr_1: Multiaddr = "/memory/1639174018481".parse().unwrap(); + let addr_2: Multiaddr = "/memory/8459375923478".parse().unwrap(); + + let listener_id_1 = transport.listen_on(addr_1.clone()).unwrap(); + assert!( + transport.remove_listener(listener_id_1), + "Listener doesn't exist." + ); + + let listener_id_2 = transport.listen_on(addr_1.clone()).unwrap(); + let listener_id_3 = transport.listen_on(addr_2.clone()).unwrap(); + + assert!(transport.listen_on(addr_1.clone()).is_err()); + assert!(transport.listen_on(addr_2.clone()).is_err()); + + assert!( + transport.remove_listener(listener_id_2), + "Listener doesn't exist." + ); + assert!(transport.listen_on(addr_1).is_ok()); + assert!(transport.listen_on(addr_2.clone()).is_err()); + + assert!( + transport.remove_listener(listener_id_3), + "Listener doesn't exist." + ); + assert!(transport.listen_on(addr_2).is_ok()); } #[test] @@ -456,6 +499,35 @@ mod tests { .is_ok()); } + #[test] + fn stop_listening() { + let rand_port = rand::random::().saturating_add(1); + let addr: Multiaddr = format!("/memory/{}", rand_port).parse().unwrap(); + + let mut transport = MemoryTransport::default().boxed(); + futures::executor::block_on(async { + let listener_id = transport.listen_on(addr.clone()).unwrap(); + let reported_addr = transport + .select_next_some() + .await + .into_new_address() + .expect("new address"); + assert_eq!(addr, reported_addr); + assert!(transport.remove_listener(listener_id)); + match transport.select_next_some().await { + TransportEvent::ListenerClosed { + listener_id: id, + reason, + } => { + assert_eq!(id, listener_id); + assert!(reason.is_ok()) + } + other => panic!("Unexpected transport event: {:?}", other), + } + assert!(!transport.remove_listener(listener_id)); + }) + } + #[test] fn communicating_between_dialer_and_listener() { let msg = [1, 2, 3]; @@ -466,16 +538,16 @@ mod tests { let t1_addr: Multiaddr = format!("/memory/{}", rand_port).parse().unwrap(); let cloned_t1_addr = t1_addr.clone(); - let mut t1 = MemoryTransport::default(); + let mut t1 = MemoryTransport::default().boxed(); let listener = async move { - let listener = t1.listen_on(t1_addr.clone()).unwrap(); - - let upgrade = listener - .filter_map(|ev| futures::future::ready(ListenerEvent::into_upgrade(ev.unwrap()))) - .next() - .await - .unwrap(); + t1.listen_on(t1_addr.clone()).unwrap(); + let upgrade = loop { + let event = t1.select_next_some().await; + if let Some(upgrade) = event.into_upgrade() { + break upgrade; + } + }; let mut socket = upgrade.0.await.unwrap(); @@ -504,14 +576,16 @@ mod tests { Protocol::Memory(rand::random::().saturating_add(1)).into(); let listener_addr_cloned = listener_addr.clone(); - let mut listener_transport = MemoryTransport::default(); + let mut listener_transport = MemoryTransport::default().boxed(); let listener = async move { - let mut listener = listener_transport.listen_on(listener_addr.clone()).unwrap(); - while let Some(ev) = listener.next().await { - if let ListenerEvent::Upgrade { remote_addr, .. } = ev.unwrap() { + listener_transport.listen_on(listener_addr.clone()).unwrap(); + loop { + if let TransportEvent::Incoming { send_back_addr, .. } = + listener_transport.select_next_some().await + { assert!( - remote_addr != listener_addr, + send_back_addr != listener_addr, "Expect dialer address not to equal listener address." ); return; @@ -539,14 +613,16 @@ mod tests { Protocol::Memory(rand::random::().saturating_add(1)).into(); let listener_addr_cloned = listener_addr.clone(); - let mut listener_transport = MemoryTransport::default(); + let mut listener_transport = MemoryTransport::default().boxed(); let listener = async move { - let mut listener = listener_transport.listen_on(listener_addr.clone()).unwrap(); - while let Some(ev) = listener.next().await { - if let ListenerEvent::Upgrade { remote_addr, .. } = ev.unwrap() { + listener_transport.listen_on(listener_addr.clone()).unwrap(); + loop { + if let TransportEvent::Incoming { send_back_addr, .. } = + listener_transport.select_next_some().await + { let dialer_port = - NonZeroU64::new(parse_memory_addr(&remote_addr).unwrap()).unwrap(); + NonZeroU64::new(parse_memory_addr(&send_back_addr).unwrap()).unwrap(); assert!( HUB.get(&dialer_port).is_some(), diff --git a/core/src/transport/optional.rs b/core/src/transport/optional.rs index cb10c35e133..2d93077659c 100644 --- a/core/src/transport/optional.rs +++ b/core/src/transport/optional.rs @@ -18,8 +18,9 @@ // FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER // DEALINGS IN THE SOFTWARE. -use crate::transport::{Transport, TransportError}; +use crate::transport::{ListenerId, Transport, TransportError, TransportEvent}; use multiaddr::Multiaddr; +use std::{pin::Pin, task::Context, task::Poll}; /// Transport that is possibly disabled. /// @@ -28,7 +29,8 @@ use multiaddr::Multiaddr; /// enabled (read: contains `Some`), then dialing and listening will be handled by the inner /// transport. #[derive(Debug, Copy, Clone)] -pub struct OptionalTransport(Option); +#[pin_project::pin_project] +pub struct OptionalTransport(#[pin] Option); impl OptionalTransport { /// Builds an `OptionalTransport` with the given transport in an enabled @@ -55,14 +57,10 @@ where { type Output = T::Output; type Error = T::Error; - type Listener = T::Listener; type ListenerUpgrade = T::ListenerUpgrade; type Dial = T::Dial; - fn listen_on( - &mut self, - addr: Multiaddr, - ) -> Result> { + fn listen_on(&mut self, addr: Multiaddr) -> Result> { if let Some(inner) = self.0.as_mut() { inner.listen_on(addr) } else { @@ -70,6 +68,14 @@ where } } + fn remove_listener(&mut self, id: ListenerId) -> bool { + if let Some(inner) = self.0.as_mut() { + inner.remove_listener(id) + } else { + false + } + } + fn dial(&mut self, addr: Multiaddr) -> Result> { if let Some(inner) = self.0.as_mut() { inner.dial(addr) @@ -96,4 +102,15 @@ where None } } + + fn poll( + self: Pin<&mut Self>, + cx: &mut Context<'_>, + ) -> Poll> { + if let Some(inner) = self.project().0.as_pin_mut() { + inner.poll(cx) + } else { + Poll::Pending + } + } } diff --git a/core/src/transport/timeout.rs b/core/src/transport/timeout.rs index bb413cf8909..5c3867b3c01 100644 --- a/core/src/transport/timeout.rs +++ b/core/src/transport/timeout.rs @@ -25,7 +25,7 @@ // TODO: add example use crate::{ - transport::{ListenerEvent, TransportError}, + transport::{ListenerId, TransportError, TransportEvent}, Multiaddr, Transport, }; use futures::prelude::*; @@ -38,7 +38,9 @@ use std::{error, fmt, io, pin::Pin, task::Context, task::Poll, time::Duration}; /// **Note**: `listen_on` is never subject to a timeout, only the setup of each /// individual accepted connection. #[derive(Debug, Copy, Clone)] +#[pin_project::pin_project] pub struct TransportTimeout { + #[pin] inner: InnerTrans, outgoing_timeout: Duration, incoming_timeout: Duration, @@ -80,25 +82,17 @@ where { type Output = InnerTrans::Output; type Error = TransportTimeoutError; - type Listener = TimeoutListener; type ListenerUpgrade = Timeout; type Dial = Timeout; - fn listen_on( - &mut self, - addr: Multiaddr, - ) -> Result> { - let listener = self - .inner + fn listen_on(&mut self, addr: Multiaddr) -> Result> { + self.inner .listen_on(addr) - .map_err(|err| err.map(TransportTimeoutError::Other))?; - - let listener = TimeoutListener { - inner: listener, - timeout: self.incoming_timeout, - }; + .map_err(|err| err.map(TransportTimeoutError::Other)) + } - Ok(listener) + fn remove_listener(&mut self, id: ListenerId) -> bool { + self.inner.remove_listener(id) } fn dial(&mut self, addr: Multiaddr) -> Result> { @@ -129,45 +123,21 @@ where fn address_translation(&self, server: &Multiaddr, observed: &Multiaddr) -> Option { self.inner.address_translation(server, observed) } -} -// TODO: can be removed and replaced with an `impl Stream` once impl Trait is fully stable -// in Rust (https://github.com/rust-lang/rust/issues/34511) -#[pin_project::pin_project] -pub struct TimeoutListener { - #[pin] - inner: InnerStream, - timeout: Duration, -} - -impl Stream for TimeoutListener -where - InnerStream: TryStream, Error = E>, -{ - type Item = - Result, TransportTimeoutError>, TransportTimeoutError>; - - fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { + fn poll( + self: Pin<&mut Self>, + cx: &mut Context<'_>, + ) -> Poll> { let this = self.project(); - - let poll_out = match TryStream::try_poll_next(this.inner, cx) { - Poll::Ready(Some(Err(err))) => { - return Poll::Ready(Some(Err(TransportTimeoutError::Other(err)))) - } - Poll::Ready(Some(Ok(v))) => v, - Poll::Ready(None) => return Poll::Ready(None), - Poll::Pending => return Poll::Pending, - }; - - let timeout = *this.timeout; - let event = poll_out - .map(move |inner_fut| Timeout { - inner: inner_fut, - timer: Delay::new(timeout), - }) - .map_err(TransportTimeoutError::Other); - - Poll::Ready(Some(Ok(event))) + let timeout = *this.incoming_timeout; + this.inner.poll(cx).map(|event| { + event + .map_upgrade(move |inner_fut| Timeout { + inner: inner_fut, + timer: Delay::new(timeout), + }) + .map_err(TransportTimeoutError::Other) + }) } } diff --git a/core/src/transport/upgrade.rs b/core/src/transport/upgrade.rs index 5508429754f..e2c1a9bc7e6 100644 --- a/core/src/transport/upgrade.rs +++ b/core/src/transport/upgrade.rs @@ -26,8 +26,8 @@ use crate::{ connection::ConnectedPoint, muxing::{StreamMuxer, StreamMuxerBox}, transport::{ - and_then::AndThen, boxed::boxed, timeout::TransportTimeout, ListenerEvent, Transport, - TransportError, + and_then::AndThen, boxed::boxed, timeout::TransportTimeout, ListenerId, Transport, + TransportError, TransportEvent, }, upgrade::{ self, apply_inbound, apply_outbound, InboundUpgrade, InboundUpgradeApply, OutboundUpgrade, @@ -287,16 +287,16 @@ where /// A authenticated and multiplexed transport, obtained from /// [`Authenticated::multiplex`]. #[derive(Clone)] -pub struct Multiplexed(T); +#[pin_project::pin_project] +pub struct Multiplexed(#[pin] T); impl Multiplexed { /// Boxes the authenticated, multiplexed transport, including /// the [`StreamMuxer`] and custom transport errors. pub fn boxed(self) -> super::Boxed<(PeerId, StreamMuxerBox)> where - T: Transport + Sized + Send + 'static, + T: Transport + Sized + Send + Unpin + 'static, T::Dial: Send + 'static, - T::Listener: Send + 'static, T::ListenerUpgrade: Send + 'static, T::Error: Send + Sync, M: StreamMuxer + Send + Sync + 'static, @@ -331,7 +331,6 @@ where { type Output = T::Output; type Error = T::Error; - type Listener = T::Listener; type ListenerUpgrade = T::ListenerUpgrade; type Dial = T::Dial; @@ -339,6 +338,10 @@ where self.0.dial(addr) } + fn remove_listener(&mut self, id: ListenerId) -> bool { + self.0.remove_listener(id) + } + fn dial_as_listener( &mut self, addr: Multiaddr, @@ -346,16 +349,20 @@ where self.0.dial_as_listener(addr) } - fn listen_on( - &mut self, - addr: Multiaddr, - ) -> Result> { + fn listen_on(&mut self, addr: Multiaddr) -> Result> { self.0.listen_on(addr) } fn address_translation(&self, server: &Multiaddr, observed: &Multiaddr) -> Option { self.0.address_translation(server, observed) } + + fn poll( + self: Pin<&mut Self>, + cx: &mut Context<'_>, + ) -> Poll> { + self.project().0.poll(cx) + } } /// An inbound or outbound upgrade. @@ -365,7 +372,9 @@ type EitherUpgrade = future::Either, OutboundUpg /// /// See [`Transport::upgrade`] #[derive(Debug, Copy, Clone)] +#[pin_project::pin_project] pub struct Upgrade { + #[pin] inner: T, upgrade: U, } @@ -387,7 +396,6 @@ where { type Output = (PeerId, D); type Error = TransportUpgradeError; - type Listener = ListenerStream; type ListenerUpgrade = ListenerUpgradeFuture; type Dial = DialUpgradeFuture; @@ -402,6 +410,10 @@ where }) } + fn remove_listener(&mut self, id: ListenerId) -> bool { + self.inner.remove_listener(id) + } + fn dial_as_listener( &mut self, addr: Multiaddr, @@ -416,23 +428,31 @@ where }) } - fn listen_on( - &mut self, - addr: Multiaddr, - ) -> Result> { - let stream = self - .inner + fn listen_on(&mut self, addr: Multiaddr) -> Result> { + self.inner .listen_on(addr) - .map_err(|err| err.map(TransportUpgradeError::Transport))?; - Ok(ListenerStream { - stream: Box::pin(stream), - upgrade: self.upgrade.clone(), - }) + .map_err(|err| err.map(TransportUpgradeError::Transport)) } fn address_translation(&self, server: &Multiaddr, observed: &Multiaddr) -> Option { self.inner.address_translation(server, observed) } + + fn poll( + self: Pin<&mut Self>, + cx: &mut Context<'_>, + ) -> Poll> { + let this = self.project(); + let upgrade = this.upgrade.clone(); + this.inner.poll(cx).map(|event| { + event + .map_upgrade(move |future| ListenerUpgradeFuture { + future: Box::pin(future), + upgrade: future::Either::Left(Some(upgrade)), + }) + .map_err(TransportUpgradeError::Transport) + }) + } } /// Errors produced by a transport upgrade. @@ -532,43 +552,6 @@ where { } -/// The [`Transport::Listener`] stream of an [`Upgrade`]d transport. -pub struct ListenerStream { - stream: Pin>, - upgrade: U, -} - -impl Stream for ListenerStream -where - S: TryStream, Error = E>, - F: TryFuture, - C: AsyncRead + AsyncWrite + Unpin, - U: InboundUpgrade, Output = D> + Clone, -{ - type Item = Result< - ListenerEvent, TransportUpgradeError>, - TransportUpgradeError, - >; - - fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { - match ready!(TryStream::try_poll_next(self.stream.as_mut(), cx)) { - Some(Ok(event)) => { - let event = event - .map(move |future| ListenerUpgradeFuture { - future: Box::pin(future), - upgrade: future::Either::Left(Some(self.upgrade.clone())), - }) - .map_err(TransportUpgradeError::Transport); - Poll::Ready(Some(Ok(event))) - } - Some(Err(err)) => Poll::Ready(Some(Err(TransportUpgradeError::Transport(err)))), - None => Poll::Ready(None), - } - } -} - -impl Unpin for ListenerStream {} - /// The [`Transport::ListenerUpgrade`] future of an [`Upgrade`]d transport. pub struct ListenerUpgradeFuture where diff --git a/core/tests/transport_upgrade.rs b/core/tests/transport_upgrade.rs index 9fd1e8eaabb..f52cf2cb3d6 100644 --- a/core/tests/transport_upgrade.rs +++ b/core/tests/transport_upgrade.rs @@ -95,7 +95,8 @@ fn upgrade_pipeline() { // Gracefully close the connection to allow protocol // negotiation to complete. util::CloseMuxer::new(mplex).map_ok(move |mplex| (peer, mplex)) - }); + }) + .boxed(); let dialer_keys = identity::Keypair::generate_ed25519(); let dialer_id = dialer_keys.public().to_peer_id(); @@ -113,17 +114,18 @@ fn upgrade_pipeline() { // Gracefully close the connection to allow protocol // negotiation to complete. util::CloseMuxer::new(mplex).map_ok(move |mplex| (peer, mplex)) - }); + }) + .boxed(); let listen_addr1 = Multiaddr::from(Protocol::Memory(random::())); let listen_addr2 = listen_addr1.clone(); - let mut listener = listener_transport.listen_on(listen_addr1).unwrap(); + listener_transport.listen_on(listen_addr1).unwrap(); let server = async move { loop { - let (upgrade, _remote_addr) = - match listener.next().await.unwrap().unwrap().into_upgrade() { + let (upgrade, _send_back_addr) = + match listener_transport.select_next_some().await.into_upgrade() { Some(u) => u, None => continue, }; diff --git a/examples/chat-tokio.rs b/examples/chat-tokio.rs index 2400c8a98a5..66c25205246 100644 --- a/examples/chat-tokio.rs +++ b/examples/chat-tokio.rs @@ -45,13 +45,14 @@ use libp2p::{ mplex, noise, swarm::{dial_opts::DialOpts, NetworkBehaviourEventProcess, SwarmBuilder, SwarmEvent}, - // `TokioTcpConfig` is available through the `tcp-tokio` feature. - tcp::TokioTcpConfig, + // `TokioTcpTransport` is available through the `tcp-tokio` feature. + tcp::TokioTcpTransport, Multiaddr, NetworkBehaviour, PeerId, Transport, }; +use libp2p_tcp::GenTcpConfig; use std::error::Error; use tokio::io::{self, AsyncBufReadExt}; @@ -72,8 +73,7 @@ async fn main() -> Result<(), Box> { // Create a tokio-based TCP transport use noise for authenticated // encryption and Mplex for multiplexing of substreams on a TCP stream. - let transport = TokioTcpConfig::new() - .nodelay(true) + let transport = TokioTcpTransport::new(GenTcpConfig::default().nodelay(true)) .upgrade(upgrade::Version::V1) .authenticate(noise::NoiseConfig::xx(noise_keys).into_authenticated()) .multiplex(mplex::MplexConfig::new()) diff --git a/examples/ipfs-private.rs b/examples/ipfs-private.rs index fdeed494141..113bdf988f2 100644 --- a/examples/ipfs-private.rs +++ b/examples/ipfs-private.rs @@ -44,10 +44,11 @@ use libp2p::{ noise, ping, pnet::{PnetConfig, PreSharedKey}, swarm::{NetworkBehaviourEventProcess, SwarmEvent}, - tcp::TcpConfig, + tcp::TcpTransport, yamux::YamuxConfig, Multiaddr, NetworkBehaviour, PeerId, Swarm, Transport, }; +use libp2p_tcp::GenTcpConfig; use std::{env, error::Error, fs, path::Path, str::FromStr, time::Duration}; /// Builds the transport that serves as a common ground for all connections. @@ -61,7 +62,7 @@ pub fn build_transport( let noise_config = noise::NoiseConfig::xx(noise_keys).into_authenticated(); let yamux_config = YamuxConfig::default(); - let base_transport = TcpConfig::new().nodelay(true); + let base_transport = TcpTransport::new(GenTcpConfig::default().nodelay(true)); let maybe_encrypted = match psk { Some(psk) => EitherTransport::Left( base_transport.and_then(move |socket, _| PnetConfig::new(psk).handshake(socket)), diff --git a/muxers/mplex/benches/split_send_size.rs b/muxers/mplex/benches/split_send_size.rs index ff52177529d..9c65ccc91cf 100644 --- a/muxers/mplex/benches/split_send_size.rs +++ b/muxers/mplex/benches/split_send_size.rs @@ -23,15 +23,16 @@ use async_std::task; use criterion::{black_box, criterion_group, criterion_main, Criterion, Throughput}; -use futures::channel::oneshot; use futures::future::poll_fn; use futures::prelude::*; +use futures::{channel::oneshot, future::join}; use libp2p_core::{ identity, multiaddr::multiaddr, muxing, transport, upgrade, Multiaddr, PeerId, StreamMuxer, Transport, }; use libp2p_mplex as mplex; use libp2p_plaintext::PlainText2Config; +use libp2p_tcp::GenTcpConfig; use std::time::Duration; type BenchTransport = transport::Boxed<(PeerId, muxing::StreamMuxerBox)>; @@ -57,11 +58,13 @@ fn prepare(c: &mut Criterion) { let tcp_addr = multiaddr![Ip4(std::net::Ipv4Addr::new(127, 0, 0, 1)), Tcp(0u16)]; for &size in BENCH_SIZES.iter() { tcp.throughput(Throughput::Bytes(payload.len() as u64)); - let mut trans = tcp_transport(size); + let mut receiver_trans = tcp_transport(size); + let mut sender_trans = tcp_transport(size); tcp.bench_function(format!("{}", size), |b| { b.iter(|| { run( - black_box(&mut trans), + black_box(&mut receiver_trans), + black_box(&mut sender_trans), black_box(&payload), black_box(&tcp_addr), ) @@ -74,11 +77,13 @@ fn prepare(c: &mut Criterion) { let mem_addr = multiaddr![Memory(0u64)]; for &size in BENCH_SIZES.iter() { mem.throughput(Throughput::Bytes(payload.len() as u64)); - let mut trans = mem_transport(size); + let mut receiver_trans = mem_transport(size); + let mut sender_trans = mem_transport(size); mem.bench_function(format!("{}", size), |b| { b.iter(|| { run( - black_box(&mut trans), + black_box(&mut receiver_trans), + black_box(&mut sender_trans), black_box(&payload), black_box(&mem_addr), ) @@ -89,20 +94,24 @@ fn prepare(c: &mut Criterion) { } /// Transfers the given payload between two nodes using the given transport. -fn run(transport: &mut BenchTransport, payload: &Vec, listen_addr: &Multiaddr) { - let mut listener = transport.listen_on(listen_addr.clone()).unwrap(); +fn run( + receiver_trans: &mut BenchTransport, + sender_trans: &mut BenchTransport, + payload: &Vec, + listen_addr: &Multiaddr, +) { + receiver_trans.listen_on(listen_addr.clone()).unwrap(); let (addr_sender, addr_receiver) = oneshot::channel(); let mut addr_sender = Some(addr_sender); let payload_len = payload.len(); - // Spawn the receiver. - let receiver = task::spawn(async move { + let receiver = async move { loop { - match listener.next().await.unwrap().unwrap() { - transport::ListenerEvent::NewAddress(a) => { - addr_sender.take().unwrap().send(a).unwrap(); + match receiver_trans.next().await.unwrap() { + transport::TransportEvent::NewAddress { listen_addr, .. } => { + addr_sender.take().unwrap().send(listen_addr).unwrap(); } - transport::ListenerEvent::Upgrade { upgrade, .. } => { + transport::TransportEvent::Incoming { upgrade, .. } => { let (_peer, conn) = upgrade.await.unwrap(); let mut s = poll_fn(|cx| conn.poll_event(cx)) .await @@ -127,12 +136,12 @@ fn run(transport: &mut BenchTransport, payload: &Vec, listen_addr: &Multiadd _ => panic!("Unexpected listener event"), } } - }); + }; // Spawn and block on the sender, i.e. until all data is sent. - task::block_on(async move { + let sender = async move { let addr = addr_receiver.await.unwrap(); - let (_peer, conn) = transport.dial(addr).unwrap().await.unwrap(); + let (_peer, conn) = sender_trans.dial(addr).unwrap().await.unwrap(); let mut handle = conn.open_outbound(); let mut stream = poll_fn(|cx| conn.poll_outbound(cx, &mut handle)) .await @@ -150,10 +159,10 @@ fn run(transport: &mut BenchTransport, payload: &Vec, listen_addr: &Multiadd return; } } - }); + }; // Wait for all data to be received. - task::block_on(receiver); + task::block_on(join(sender, receiver)); } fn tcp_transport(split_send_size: usize) -> BenchTransport { @@ -163,8 +172,7 @@ fn tcp_transport(split_send_size: usize) -> BenchTransport { let mut mplex = mplex::MplexConfig::default(); mplex.set_split_send_size(split_send_size); - libp2p_tcp::TcpConfig::new() - .nodelay(true) + libp2p_tcp::TcpTransport::new(GenTcpConfig::default().nodelay(true)) .upgrade(upgrade::Version::V1) .authenticate(PlainText2Config { local_public_key }) .multiplex(mplex) diff --git a/muxers/mplex/tests/async_write.rs b/muxers/mplex/tests/async_write.rs index 96b608a68ad..59ef10feef4 100644 --- a/muxers/mplex/tests/async_write.rs +++ b/muxers/mplex/tests/async_write.rs @@ -20,7 +20,7 @@ use futures::{channel::oneshot, prelude::*}; use libp2p_core::{muxing, upgrade, Transport}; -use libp2p_tcp::TcpConfig; +use libp2p_tcp::TcpTransport; use std::sync::Arc; #[test] @@ -32,28 +32,27 @@ fn async_write() { let bg_thread = async_std::task::spawn(async move { let mplex = libp2p_mplex::MplexConfig::new(); - let mut transport = TcpConfig::new() - .and_then(move |c, e| upgrade::apply(c, mplex, e, upgrade::Version::V1)); + let mut transport = TcpTransport::default() + .and_then(move |c, e| upgrade::apply(c, mplex, e, upgrade::Version::V1)) + .boxed(); - let mut listener = transport + transport .listen_on("/ip4/127.0.0.1/tcp/0".parse().unwrap()) .unwrap(); - let addr = listener + let addr = transport .next() .await .expect("some event") - .expect("no error") .into_new_address() .expect("listen address"); tx.send(addr).unwrap(); - let client = listener + let client = transport .next() .await - .unwrap() - .unwrap() + .expect("some event") .into_upgrade() .unwrap() .0 @@ -71,7 +70,7 @@ fn async_write() { async_std::task::block_on(async { let mplex = libp2p_mplex::MplexConfig::new(); - let mut transport = TcpConfig::new() + let mut transport = TcpTransport::default() .and_then(move |c, e| upgrade::apply(c, mplex, e, upgrade::Version::V1)); let client = Arc::new(transport.dial(rx.await.unwrap()).unwrap().await.unwrap()); diff --git a/muxers/mplex/tests/two_peers.rs b/muxers/mplex/tests/two_peers.rs index 77e1a09997b..72051f0632c 100644 --- a/muxers/mplex/tests/two_peers.rs +++ b/muxers/mplex/tests/two_peers.rs @@ -20,7 +20,7 @@ use futures::{channel::oneshot, prelude::*}; use libp2p_core::{muxing, upgrade, Transport}; -use libp2p_tcp::TcpConfig; +use libp2p_tcp::TcpTransport; use std::sync::Arc; #[test] @@ -32,28 +32,27 @@ fn client_to_server_outbound() { let bg_thread = async_std::task::spawn(async move { let mplex = libp2p_mplex::MplexConfig::new(); - let mut transport = TcpConfig::new() - .and_then(move |c, e| upgrade::apply(c, mplex, e, upgrade::Version::V1)); + let mut transport = TcpTransport::default() + .and_then(move |c, e| upgrade::apply(c, mplex, e, upgrade::Version::V1)) + .boxed(); - let mut listener = transport + transport .listen_on("/ip4/127.0.0.1/tcp/0".parse().unwrap()) .unwrap(); - let addr = listener + let addr = transport .next() .await .expect("some event") - .expect("no error") .into_new_address() .expect("listen address"); tx.send(addr).unwrap(); - let client = listener + let client = transport .next() .await - .unwrap() - .unwrap() + .expect("some event") .into_upgrade() .unwrap() .0 @@ -71,8 +70,9 @@ fn client_to_server_outbound() { async_std::task::block_on(async { let mplex = libp2p_mplex::MplexConfig::new(); - let mut transport = TcpConfig::new() - .and_then(move |c, e| upgrade::apply(c, mplex, e, upgrade::Version::V1)); + let mut transport = TcpTransport::default() + .and_then(move |c, e| upgrade::apply(c, mplex, e, upgrade::Version::V1)) + .boxed(); let client = Arc::new(transport.dial(rx.await.unwrap()).unwrap().await.unwrap()); let mut inbound = loop { @@ -100,29 +100,28 @@ fn client_to_server_inbound() { let bg_thread = async_std::task::spawn(async move { let mplex = libp2p_mplex::MplexConfig::new(); - let mut transport = TcpConfig::new() - .and_then(move |c, e| upgrade::apply(c, mplex, e, upgrade::Version::V1)); + let mut transport = TcpTransport::default() + .and_then(move |c, e| upgrade::apply(c, mplex, e, upgrade::Version::V1)) + .boxed(); - let mut listener = transport + transport .listen_on("/ip4/127.0.0.1/tcp/0".parse().unwrap()) .unwrap(); - let addr = listener + let addr = transport .next() .await .expect("some event") - .expect("no error") .into_new_address() .expect("listen address"); tx.send(addr).unwrap(); let client = Arc::new( - listener + transport .next() .await - .unwrap() - .unwrap() + .expect("some event") .into_upgrade() .unwrap() .0 @@ -147,8 +146,9 @@ fn client_to_server_inbound() { async_std::task::block_on(async { let mplex = libp2p_mplex::MplexConfig::new(); - let mut transport = TcpConfig::new() - .and_then(move |c, e| upgrade::apply(c, mplex, e, upgrade::Version::V1)); + let mut transport = TcpTransport::default() + .and_then(move |c, e| upgrade::apply(c, mplex, e, upgrade::Version::V1)) + .boxed(); let client = transport.dial(rx.await.unwrap()).unwrap().await.unwrap(); let mut outbound = muxing::outbound_from_ref_and_wrap(Arc::new(client)) @@ -168,28 +168,27 @@ fn protocol_not_match() { let _bg_thread = async_std::task::spawn(async move { let mplex = libp2p_mplex::MplexConfig::new(); - let mut transport = TcpConfig::new() - .and_then(move |c, e| upgrade::apply(c, mplex, e, upgrade::Version::V1)); + let mut transport = TcpTransport::default() + .and_then(move |c, e| upgrade::apply(c, mplex, e, upgrade::Version::V1)) + .boxed(); - let mut listener = transport + transport .listen_on("/ip4/127.0.0.1/tcp/0".parse().unwrap()) .unwrap(); - let addr = listener + let addr = transport .next() .await .expect("some event") - .expect("no error") .into_new_address() .expect("listen address"); tx.send(addr).unwrap(); - let client = listener + let client = transport .next() .await - .unwrap() - .unwrap() + .expect("some event") .into_upgrade() .unwrap() .0 @@ -209,8 +208,9 @@ fn protocol_not_match() { // Make sure they do not connect when protocols do not match let mut mplex = libp2p_mplex::MplexConfig::new(); mplex.set_protocol_name(b"/mplextest/1.0.0"); - let mut transport = TcpConfig::new() - .and_then(move |c, e| upgrade::apply(c, mplex, e, upgrade::Version::V1)); + let mut transport = TcpTransport::default() + .and_then(move |c, e| upgrade::apply(c, mplex, e, upgrade::Version::V1)) + .boxed(); match transport.dial(rx.await.unwrap()).unwrap().await { Ok(_) => { assert!(false, "Dialing should fail here as protocols do not match") diff --git a/protocols/autonat/src/behaviour.rs b/protocols/autonat/src/behaviour.rs index 86c65bb3416..f98bf5af532 100644 --- a/protocols/autonat/src/behaviour.rs +++ b/protocols/autonat/src/behaviour.rs @@ -29,9 +29,8 @@ pub use as_server::{InboundProbeError, InboundProbeEvent}; use futures_timer::Delay; use instant::Instant; use libp2p_core::{ - connection::{ConnectionId, ListenerId}, - multiaddr::Protocol, - ConnectedPoint, Endpoint, Multiaddr, PeerId, + connection::ConnectionId, multiaddr::Protocol, transport::ListenerId, ConnectedPoint, Endpoint, + Multiaddr, PeerId, }; use libp2p_request_response::{ handler::RequestResponseHandlerEvent, ProtocolSupport, RequestId, RequestResponse, diff --git a/protocols/dcutr/examples/client.rs b/protocols/dcutr/examples/client.rs index 94bed7c2d7f..dd73b7d3ac3 100644 --- a/protocols/dcutr/examples/client.rs +++ b/protocols/dcutr/examples/client.rs @@ -32,7 +32,7 @@ use libp2p::noise; use libp2p::ping::{Ping, PingConfig, PingEvent}; use libp2p::relay::v2::client::{self, Client}; use libp2p::swarm::{SwarmBuilder, SwarmEvent}; -use libp2p::tcp::TcpConfig; +use libp2p::tcp::{GenTcpConfig, TcpTransport}; use libp2p::Transport; use libp2p::{identity, NetworkBehaviour, PeerId}; use log::info; @@ -95,7 +95,10 @@ fn main() -> Result<(), Box> { let transport = OrTransport::new( relay_transport, - block_on(DnsConfig::system(TcpConfig::new().port_reuse(true))).unwrap(), + block_on(DnsConfig::system(TcpTransport::new( + GenTcpConfig::default().port_reuse(true), + ))) + .unwrap(), ) .upgrade(upgrade::Version::V1) .authenticate(noise::NoiseConfig::xx(noise_keys).into_authenticated()) diff --git a/protocols/identify/src/identify.rs b/protocols/identify/src/identify.rs index 5f5e468365f..a6665b36cd7 100644 --- a/protocols/identify/src/identify.rs +++ b/protocols/identify/src/identify.rs @@ -22,9 +22,8 @@ use crate::handler::{IdentifyHandler, IdentifyHandlerEvent, IdentifyPush}; use crate::protocol::{IdentifyInfo, ReplySubstream, UpgradeError}; use futures::prelude::*; use libp2p_core::{ - connection::{ConnectionId, ListenerId}, - multiaddr::Protocol, - ConnectedPoint, Multiaddr, PeerId, PublicKey, + connection::ConnectionId, multiaddr::Protocol, transport::ListenerId, ConnectedPoint, + Multiaddr, PeerId, PublicKey, }; use libp2p_swarm::{ dial_opts::{self, DialOpts}, @@ -517,7 +516,7 @@ mod tests { use libp2p_mplex::MplexConfig; use libp2p_noise as noise; use libp2p_swarm::{Swarm, SwarmEvent}; - use libp2p_tcp::TcpConfig; + use libp2p_tcp::{GenTcpConfig, TcpTransport}; fn transport() -> ( identity::PublicKey, @@ -528,8 +527,7 @@ mod tests { .into_authentic(&id_keys) .unwrap(); let pubkey = id_keys.public(); - let transport = TcpConfig::new() - .nodelay(true) + let transport = TcpTransport::new(GenTcpConfig::default().nodelay(true)) .upgrade(upgrade::Version::V1) .authenticate(noise::NoiseConfig::xx(noise_keys).into_authenticated()) .multiplex(MplexConfig::new()) diff --git a/protocols/identify/src/protocol.rs b/protocols/identify/src/protocol.rs index 0fd50460e2d..3b86baded51 100644 --- a/protocols/identify/src/protocol.rs +++ b/protocols/identify/src/protocol.rs @@ -292,7 +292,7 @@ mod tests { upgrade::{self, apply_inbound, apply_outbound}, Transport, }; - use libp2p_tcp::TcpConfig; + use libp2p_tcp::TcpTransport; #[test] fn correct_transfer() { @@ -304,26 +304,24 @@ mod tests { let (tx, rx) = oneshot::channel(); let bg_task = async_std::task::spawn(async move { - let mut transport = TcpConfig::new(); + let mut transport = TcpTransport::default().boxed(); - let mut listener = transport + transport .listen_on("/ip4/127.0.0.1/tcp/0".parse().unwrap()) .unwrap(); - let addr = listener + let addr = transport .next() .await .expect("some event") - .expect("no error") .into_new_address() .expect("listen address"); tx.send(addr).unwrap(); - let socket = listener + let socket = transport .next() .await - .unwrap() - .unwrap() + .expect("some event") .into_upgrade() .unwrap() .0 @@ -349,7 +347,7 @@ mod tests { }); async_std::task::block_on(async move { - let mut transport = TcpConfig::new(); + let mut transport = TcpTransport::default(); let socket = transport.dial(rx.await.unwrap()).unwrap().await.unwrap(); let info = apply_outbound(socket, IdentifyProtocol, upgrade::Version::V1) diff --git a/protocols/kad/src/behaviour.rs b/protocols/kad/src/behaviour.rs index d5b322e096a..59d63b36e9d 100644 --- a/protocols/kad/src/behaviour.rs +++ b/protocols/kad/src/behaviour.rs @@ -40,8 +40,7 @@ use crate::K_VALUE; use fnv::{FnvHashMap, FnvHashSet}; use instant::Instant; use libp2p_core::{ - connection::{ConnectionId, ListenerId}, - ConnectedPoint, Multiaddr, PeerId, + connection::ConnectionId, transport::ListenerId, ConnectedPoint, Multiaddr, PeerId, }; use libp2p_swarm::{ dial_opts::{self, DialOpts}, diff --git a/protocols/kad/src/protocol.rs b/protocols/kad/src/protocol.rs index 648f7fc9e07..656917b54f6 100644 --- a/protocols/kad/src/protocol.rs +++ b/protocols/kad/src/protocol.rs @@ -603,7 +603,7 @@ where mod tests { /*// TODO: restore - use self::libp2p_tcp::TcpConfig; + use self::libp2p_tcp::TcpTransport; use self::tokio::runtime::current_thread::Runtime; use futures::{Future, Sink, Stream}; use libp2p_core::{PeerId, PublicKey, Transport}; @@ -658,10 +658,10 @@ mod tests { let (tx, rx) = mpsc::channel(); let bg_thread = thread::spawn(move || { - let transport = TcpConfig::new().with_upgrade(KademliaProtocolConfig); + let transport = TcpTransport::default().with_upgrade(KademliaProtocolConfig); let (listener, addr) = transport - .listen_on("/ip4/127.0.0.1/tcp/0".parse().unwrap()) + .listen_on( "/ip4/127.0.0.1/tcp/0".parse().unwrap()) .unwrap(); tx.send(addr).unwrap(); @@ -678,7 +678,7 @@ mod tests { let _ = rt.block_on(future).unwrap(); }); - let transport = TcpConfig::new().with_upgrade(KademliaProtocolConfig); + let transport = TcpTransport::default().with_upgrade(KademliaProtocolConfig); let future = transport .dial(rx.recv().unwrap()) diff --git a/protocols/mdns/src/behaviour.rs b/protocols/mdns/src/behaviour.rs index c844f742af9..244b2b784dd 100644 --- a/protocols/mdns/src/behaviour.rs +++ b/protocols/mdns/src/behaviour.rs @@ -25,7 +25,7 @@ use crate::MdnsConfig; use async_io::Timer; use futures::prelude::*; use if_watch::{IfEvent, IfWatcher}; -use libp2p_core::connection::ListenerId; +use libp2p_core::transport::ListenerId; use libp2p_core::{Multiaddr, PeerId}; use libp2p_swarm::{ handler::DummyConnectionHandler, ConnectionHandler, NetworkBehaviour, NetworkBehaviourAction, diff --git a/protocols/ping/src/protocol.rs b/protocols/ping/src/protocol.rs index ae60f67a858..390823ec412 100644 --- a/protocols/ping/src/protocol.rs +++ b/protocols/ping/src/protocol.rs @@ -115,9 +115,10 @@ where #[cfg(test)] mod tests { use super::*; + use futures::StreamExt; use libp2p_core::{ multiaddr::multiaddr, - transport::{memory::MemoryTransport, ListenerEvent, Transport}, + transport::{memory::MemoryTransport, Transport}, }; use rand::{thread_rng, Rng}; use std::time::Duration; @@ -125,24 +126,28 @@ mod tests { #[test] fn ping_pong() { let mem_addr = multiaddr![Memory(thread_rng().gen::())]; - let mut listener = MemoryTransport.listen_on(mem_addr).unwrap(); + let mut transport = MemoryTransport::new().boxed(); + transport.listen_on(mem_addr).unwrap(); - let listener_addr = - if let Some(Some(Ok(ListenerEvent::NewAddress(a)))) = listener.next().now_or_never() { - a - } else { - panic!("MemoryTransport not listening on an address!"); - }; + let listener_addr = transport + .select_next_some() + .now_or_never() + .and_then(|ev| ev.into_new_address()) + .expect("MemoryTransport not listening on an address!"); async_std::task::spawn(async move { - let listener_event = listener.next().await.unwrap(); - let (listener_upgrade, _) = listener_event.unwrap().into_upgrade().unwrap(); + let transport_event = transport.next().await.unwrap(); + let (listener_upgrade, _) = transport_event.into_upgrade().unwrap(); let conn = listener_upgrade.await.unwrap(); recv_ping(conn).await.unwrap(); }); async_std::task::block_on(async move { - let c = MemoryTransport.dial(listener_addr).unwrap().await.unwrap(); + let c = MemoryTransport::new() + .dial(listener_addr) + .unwrap() + .await + .unwrap(); let (_, rtt) = send_ping(c).await.unwrap(); assert!(rtt > Duration::from_secs(0)); }); diff --git a/protocols/ping/tests/ping.rs b/protocols/ping/tests/ping.rs index dbde7db608d..ac45949ced7 100644 --- a/protocols/ping/tests/ping.rs +++ b/protocols/ping/tests/ping.rs @@ -31,7 +31,7 @@ use libp2p_mplex as mplex; use libp2p_noise as noise; use libp2p_ping as ping; use libp2p_swarm::{DummyBehaviour, KeepAlive, Swarm, SwarmEvent}; -use libp2p_tcp::TcpConfig; +use libp2p_tcp::{GenTcpConfig, TcpTransport}; use libp2p_yamux as yamux; use quickcheck::*; use rand::prelude::*; @@ -248,8 +248,7 @@ fn mk_transport(muxer: MuxerChoice) -> (PeerId, transport::Boxed<(PeerId, Stream .unwrap(); ( peer_id, - TcpConfig::new() - .nodelay(true) + TcpTransport::new(GenTcpConfig::default().nodelay(true)) .upgrade(upgrade::Version::V1) .authenticate(noise::NoiseConfig::xx(noise_keys).into_authenticated()) .multiplex(match muxer { diff --git a/protocols/relay/examples/relay_v2.rs b/protocols/relay/examples/relay_v2.rs index 8a4ee914fce..25d0bb7fc94 100644 --- a/protocols/relay/examples/relay_v2.rs +++ b/protocols/relay/examples/relay_v2.rs @@ -28,7 +28,7 @@ use libp2p::multiaddr::Protocol; use libp2p::ping::{Ping, PingConfig, PingEvent}; use libp2p::relay::v2::relay::{self, Relay}; use libp2p::swarm::{Swarm, SwarmEvent}; -use libp2p::tcp::TcpConfig; +use libp2p::tcp::TcpTransport; use libp2p::Transport; use libp2p::{identity, NetworkBehaviour, PeerId}; use libp2p::{noise, Multiaddr}; @@ -46,7 +46,7 @@ fn main() -> Result<(), Box> { let local_peer_id = PeerId::from(local_key.public()); println!("Local peer id: {:?}", local_peer_id); - let tcp_transport = TcpConfig::new(); + let tcp_transport = TcpTransport::default(); let noise_keys = noise::Keypair::::new() .into_authentic(&local_key) diff --git a/protocols/relay/src/v2/client/transport.rs b/protocols/relay/src/v2/client/transport.rs index 5414353786c..d5a71e54b3e 100644 --- a/protocols/relay/src/v2/client/transport.rs +++ b/protocols/relay/src/v2/client/transport.rs @@ -23,12 +23,13 @@ use crate::v2::client::RelayedConnection; use crate::v2::RequestId; use futures::channel::mpsc; use futures::channel::oneshot; -use futures::future::{ready, BoxFuture, Future, FutureExt, Ready}; +use futures::future::{ready, BoxFuture, FutureExt, Ready}; use futures::ready; use futures::sink::SinkExt; +use futures::stream::SelectAll; use futures::stream::{Stream, StreamExt}; use libp2p_core::multiaddr::{Multiaddr, Protocol}; -use libp2p_core::transport::{ListenerEvent, TransportError}; +use libp2p_core::transport::{ListenerId, TransportError, TransportEvent}; use libp2p_core::{PeerId, Transport}; use std::collections::VecDeque; use std::pin::Pin; @@ -85,9 +86,10 @@ use thiserror::Error; /// .with(Protocol::P2pCircuit); // Signal to listen via remote relay node. /// transport.listen_on(relay_addr).unwrap(); /// ``` -#[derive(Clone)] pub struct ClientTransport { to_behaviour: mpsc::Sender, + pending_to_behaviour: VecDeque, + listeners: SelectAll, } impl ClientTransport { @@ -112,22 +114,22 @@ impl ClientTransport { /// ``` pub(crate) fn new() -> (Self, mpsc::Receiver) { let (to_behaviour, from_transport) = mpsc::channel(0); - - (ClientTransport { to_behaviour }, from_transport) + let transport = ClientTransport { + to_behaviour, + pending_to_behaviour: VecDeque::new(), + listeners: SelectAll::new(), + }; + (transport, from_transport) } } impl Transport for ClientTransport { type Output = RelayedConnection; type Error = RelayError; - type Listener = RelayListener; type ListenerUpgrade = Ready>; type Dial = RelayedDial; - fn listen_on( - &mut self, - addr: Multiaddr, - ) -> Result> { + fn listen_on(&mut self, addr: Multiaddr) -> Result> { let (relay_peer_id, relay_addr) = match parse_relayed_multiaddr(addr)? { RelayedMultiaddr { relay_peer_id: None, @@ -147,25 +149,31 @@ impl Transport for ClientTransport { }; let (to_listener, from_behaviour) = mpsc::channel(0); - let mut to_behaviour = self.to_behaviour.clone(); - let msg_to_behaviour = Some( - async move { - to_behaviour - .send(TransportToBehaviourMsg::ListenReq { - relay_peer_id, - relay_addr, - to_listener, - }) - .await - } - .boxed(), - ); - - Ok(RelayListener { - queued_new_addresses: Default::default(), + self.pending_to_behaviour + .push_back(TransportToBehaviourMsg::ListenReq { + relay_peer_id, + relay_addr, + to_listener, + }); + + let listener_id = ListenerId::new(); + let listener = RelayListener { + listener_id, + queued_events: Default::default(), from_behaviour, - msg_to_behaviour, - }) + is_closed: false, + }; + self.listeners.push(listener); + Ok(listener_id) + } + + fn remove_listener(&mut self, id: ListenerId) -> bool { + if let Some(listener) = self.listeners.iter_mut().find(|l| l.listener_id == id) { + listener.close(Ok(())); + true + } else { + false + } } fn dial(&mut self, addr: Multiaddr) -> Result> { @@ -217,6 +225,35 @@ impl Transport for ClientTransport { fn address_translation(&self, _server: &Multiaddr, _observed: &Multiaddr) -> Option { None } + + fn poll( + mut self: Pin<&mut Self>, + cx: &mut Context<'_>, + ) -> Poll> + where + Self: Sized, + { + loop { + if !self.pending_to_behaviour.is_empty() { + match self.to_behaviour.poll_ready(cx) { + Poll::Ready(Ok(())) => { + let msg = self + .pending_to_behaviour + .pop_front() + .expect("Called !is_empty()."); + let _ = self.to_behaviour.start_send(msg); + continue; + } + Poll::Ready(Err(_)) => unreachable!("Receiver is never dropped."), + Poll::Pending => {} + } + } + match self.listeners.poll_next_unpin(cx) { + Poll::Ready(Some(event)) => return Poll::Ready(event), + _ => return Poll::Pending, + } + } + } } #[derive(Default)] @@ -282,64 +319,77 @@ fn parse_relayed_multiaddr( } pub struct RelayListener { - queued_new_addresses: VecDeque, + listener_id: ListenerId, + queued_events: VecDeque<::Item>, from_behaviour: mpsc::Receiver, - msg_to_behaviour: Option>>, + is_closed: bool, } -impl Unpin for RelayListener {} +impl RelayListener { + fn close(&mut self, reason: Result<(), RelayError>) { + self.queued_events + .push_back(TransportEvent::ListenerClosed { + listener_id: self.listener_id, + reason, + }); + self.is_closed = true; + } +} impl Stream for RelayListener { - type Item = - Result>, RelayError>, RelayError>; + type Item = TransportEvent<::ListenerUpgrade, RelayError>; fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { loop { - if let Some(msg) = &mut self.msg_to_behaviour { - match Future::poll(msg.as_mut(), cx) { - Poll::Ready(Ok(())) => self.msg_to_behaviour = None, - Poll::Ready(Err(e)) => return Poll::Ready(Some(Err(e.into()))), - Poll::Pending => {} - } + if let Some(event) = self.queued_events.pop_front() { + return Poll::Ready(Some(event)); } - if let Some(addr) = self.queued_new_addresses.pop_front() { - return Poll::Ready(Some(Ok(ListenerEvent::NewAddress(addr)))); + if self.is_closed { + return Poll::Ready(None); } let msg = match ready!(self.from_behaviour.poll_next_unpin(cx)) { Some(msg) => msg, None => { // Sender of `from_behaviour` has been dropped, signaling listener to close. - return Poll::Ready(None); + self.close(Ok(())); + continue; } }; - let result = match msg { + match msg { ToListenerMsg::Reservation(Ok(Reservation { addrs })) => { debug_assert!( - self.queued_new_addresses.is_empty(), + self.queued_events.is_empty(), "Assert empty due to previous `pop_front` attempt." ); // Returned as [`ListenerEvent::NewAddress`] in next iteration of loop. - self.queued_new_addresses = addrs.into(); - - continue; + self.queued_events = addrs + .into_iter() + .map(|listen_addr| TransportEvent::NewAddress { + listener_id: self.listener_id, + listen_addr, + }) + .collect(); } ToListenerMsg::IncomingRelayedConnection { stream, src_peer_id, relay_addr, relay_peer_id: _, - } => Ok(ListenerEvent::Upgrade { - upgrade: ready(Ok(stream)), - local_addr: relay_addr.with(Protocol::P2pCircuit), - remote_addr: Protocol::P2p(src_peer_id.into()).into(), - }), - ToListenerMsg::Reservation(Err(())) => Err(RelayError::Reservation), + } => { + let listener_id = self.listener_id; + + self.queued_events.push_back(TransportEvent::Incoming { + upgrade: ready(Ok(stream)), + listener_id, + local_addr: relay_addr.with(Protocol::P2pCircuit), + send_back_addr: Protocol::P2p(src_peer_id.into()).into(), + }) + } + ToListenerMsg::Reservation(Err(())) => self.close(Err(RelayError::Reservation)), }; - - return Poll::Ready(Some(result)); } } } diff --git a/protocols/request-response/tests/ping.rs b/protocols/request-response/tests/ping.rs index 6cd6a732d4e..8cbc06e7444 100644 --- a/protocols/request-response/tests/ping.rs +++ b/protocols/request-response/tests/ping.rs @@ -32,7 +32,7 @@ use libp2p_core::{ use libp2p_noise::{Keypair, NoiseConfig, X25519Spec}; use libp2p_request_response::*; use libp2p_swarm::{Swarm, SwarmEvent}; -use libp2p_tcp::TcpConfig; +use libp2p_tcp::{GenTcpConfig, TcpTransport}; use rand::{self, Rng}; use std::{io, iter}; @@ -300,8 +300,7 @@ fn mk_transport() -> (PeerId, transport::Boxed<(PeerId, StreamMuxerBox)>) { .unwrap(); ( peer_id, - TcpConfig::new() - .nodelay(true) + TcpTransport::new(GenTcpConfig::default().nodelay(true)) .upgrade(upgrade::Version::V1) .authenticate(NoiseConfig::xx(noise_keys).into_authenticated()) .multiplex(libp2p_yamux::YamuxConfig::default()) diff --git a/src/bandwidth.rs b/src/bandwidth.rs index 2e22d73b163..a58eec95ddb 100644 --- a/src/bandwidth.rs +++ b/src/bandwidth.rs @@ -20,7 +20,7 @@ use crate::{ core::{ - transport::{ListenerEvent, TransportError}, + transport::{TransportError, TransportEvent}, Transport, }, Multiaddr, @@ -31,6 +31,7 @@ use futures::{ prelude::*, ready, }; +use libp2p_core::transport::ListenerId; use std::{ convert::TryFrom as _, io, @@ -45,7 +46,9 @@ use std::{ /// Wraps around a `Transport` and counts the number of bytes that go through all the opened /// connections. #[derive(Clone)] +#[pin_project::pin_project] pub struct BandwidthLogging { + #[pin] inner: TInner, sinks: Arc, } @@ -73,18 +76,32 @@ where { type Output = BandwidthConnecLogging; type Error = TInner::Error; - type Listener = BandwidthListener; type ListenerUpgrade = BandwidthFuture; type Dial = BandwidthFuture; - fn listen_on( - &mut self, - addr: Multiaddr, - ) -> Result> { - let sinks = self.sinks.clone(); - self.inner - .listen_on(addr) - .map(move |inner| BandwidthListener { inner, sinks }) + fn poll( + self: Pin<&mut Self>, + cx: &mut Context<'_>, + ) -> Poll> { + let this = self.project(); + match this.inner.poll(cx) { + Poll::Ready(event) => { + let event = event.map_upgrade({ + let sinks = this.sinks.clone(); + |inner| BandwidthFuture { inner, sinks } + }); + Poll::Ready(event) + } + Poll::Pending => Poll::Pending, + } + } + + fn listen_on(&mut self, addr: Multiaddr) -> Result> { + self.inner.listen_on(addr) + } + + fn remove_listener(&mut self, id: ListenerId) -> bool { + self.inner.remove_listener(id) } fn dial(&mut self, addr: Multiaddr) -> Result> { @@ -109,39 +126,6 @@ where } } -/// Wraps around a `Stream` that produces connections. Wraps each connection around a bandwidth -/// counter. -#[pin_project::pin_project] -pub struct BandwidthListener { - #[pin] - inner: TInner, - sinks: Arc, -} - -impl Stream for BandwidthListener -where - TInner: TryStream, Error = TErr>, -{ - type Item = Result, TErr>, TErr>; - - fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { - let this = self.project(); - - let event = if let Some(event) = ready!(this.inner.try_poll_next(cx)?) { - event - } else { - return Poll::Ready(None); - }; - - let event = event.map({ - let sinks = this.sinks.clone(); - |inner| BandwidthFuture { inner, sinks } - }); - - Poll::Ready(Some(Ok(event))) - } -} - /// Wraps around a `Future` that produces a connection. Wraps the connection around a bandwidth /// counter. #[pin_project::pin_project] diff --git a/src/lib.rs b/src/lib.rs index 18492d4f1e2..8437fd65a82 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -206,9 +206,15 @@ pub async fn development_transport( keypair: identity::Keypair, ) -> std::io::Result> { let transport = { - let dns_tcp = dns::DnsConfig::system(tcp::TcpConfig::new().nodelay(true)).await?; + let dns_tcp = dns::DnsConfig::system(tcp::TcpTransport::new( + tcp::GenTcpConfig::new().nodelay(true), + )) + .await?; let ws_dns_tcp = websocket::WsConfig::new( - dns::DnsConfig::system(tcp::TcpConfig::new().nodelay(true)).await?, + dns::DnsConfig::system(tcp::TcpTransport::new( + tcp::GenTcpConfig::new().nodelay(true), + )) + .await?, ); dns_tcp.or_transport(ws_dns_tcp) }; @@ -264,9 +270,11 @@ pub fn tokio_development_transport( keypair: identity::Keypair, ) -> std::io::Result> { let transport = { - let dns_tcp = dns::TokioDnsConfig::system(tcp::TokioTcpConfig::new().nodelay(true))?; + let dns_tcp = dns::TokioDnsConfig::system(tcp::TokioTcpTransport::new( + tcp::GenTcpConfig::new().nodelay(true), + ))?; let ws_dns_tcp = websocket::WsConfig::new(dns::TokioDnsConfig::system( - tcp::TokioTcpConfig::new().nodelay(true), + tcp::TokioTcpTransport::new(tcp::GenTcpConfig::new().nodelay(true)), )?); dns_tcp.or_transport(ws_dns_tcp) }; diff --git a/swarm-derive/src/lib.rs b/swarm-derive/src/lib.rs index e81bd8ae4d1..1216add96c0 100644 --- a/swarm-derive/src/lib.rs +++ b/swarm-derive/src/lib.rs @@ -57,7 +57,7 @@ fn build_struct(ast: &DeriveInput, data_struct: &DataStruct) -> TokenStream { let connection_id = quote! {::libp2p::core::connection::ConnectionId}; let dial_errors = quote! {Option<&Vec<::libp2p::core::Multiaddr>>}; let connected_point = quote! {::libp2p::core::ConnectedPoint}; - let listener_id = quote! {::libp2p::core::connection::ListenerId}; + let listener_id = quote! {::libp2p::core::transport::ListenerId}; let dial_error = quote! {::libp2p::swarm::DialError}; let poll_parameters = quote! {::libp2p::swarm::PollParameters}; diff --git a/swarm/src/behaviour.rs b/swarm/src/behaviour.rs index d09427fbc6b..3dd6ddf9588 100644 --- a/swarm/src/behaviour.rs +++ b/swarm/src/behaviour.rs @@ -25,8 +25,7 @@ use crate::dial_opts::DialOpts; use crate::handler::{ConnectionHandler, IntoConnectionHandler}; use crate::{AddressRecord, AddressScore, DialError}; use libp2p_core::{ - connection::{ConnectionId, ListenerId}, - ConnectedPoint, Multiaddr, PeerId, + connection::ConnectionId, transport::ListenerId, ConnectedPoint, Multiaddr, PeerId, }; use std::{task::Context, task::Poll}; diff --git a/swarm/src/behaviour/either.rs b/swarm/src/behaviour/either.rs index 479534f6a8f..54e60e77b3a 100644 --- a/swarm/src/behaviour/either.rs +++ b/swarm/src/behaviour/either.rs @@ -25,8 +25,7 @@ use crate::{ }; use either::Either; use libp2p_core::{ - connection::{ConnectionId, ListenerId}, - ConnectedPoint, Multiaddr, PeerId, + connection::ConnectionId, transport::ListenerId, ConnectedPoint, Multiaddr, PeerId, }; use std::{task::Context, task::Poll}; diff --git a/swarm/src/behaviour/toggle.rs b/swarm/src/behaviour/toggle.rs index 25183b932c7..50ea6487770 100644 --- a/swarm/src/behaviour/toggle.rs +++ b/swarm/src/behaviour/toggle.rs @@ -29,8 +29,9 @@ use crate::{ }; use either::Either; use libp2p_core::{ - connection::{ConnectionId, ListenerId}, + connection::ConnectionId, either::{EitherError, EitherOutput}, + transport::ListenerId, upgrade::{DeniedUpgrade, EitherUpgrade}, ConnectedPoint, Multiaddr, PeerId, }; diff --git a/swarm/src/connection.rs b/swarm/src/connection.rs index e49c2ae52f6..87805c74d80 100644 --- a/swarm/src/connection.rs +++ b/swarm/src/connection.rs @@ -20,7 +20,6 @@ mod error; mod handler_wrapper; -mod listeners; mod substream; pub(crate) mod pool; @@ -29,7 +28,6 @@ pub use error::{ ConnectionError, PendingConnectionError, PendingInboundConnectionError, PendingOutboundConnectionError, }; -pub use listeners::{ListenersEvent, ListenersStream}; pub use pool::{ConnectionCounters, ConnectionLimits}; pub use pool::{EstablishedConnection, PendingConnection}; pub use substream::{Close, Substream, SubstreamEndpoint}; diff --git a/swarm/src/connection/listeners.rs b/swarm/src/connection/listeners.rs deleted file mode 100644 index 484a36dc15d..00000000000 --- a/swarm/src/connection/listeners.rs +++ /dev/null @@ -1,554 +0,0 @@ -// Copyright 2018 Parity Technologies (UK) Ltd. -// -// Permission is hereby granted, free of charge, to any person obtaining a -// copy of this software and associated documentation files (the "Software"), -// to deal in the Software without restriction, including without limitation -// the rights to use, copy, modify, merge, publish, distribute, sublicense, -// and/or sell copies of the Software, and to permit persons to whom the -// Software is furnished to do so, subject to the following conditions: -// -// The above copyright notice and this permission notice shall be included in -// all copies or substantial portions of the Software. -// -// THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS -// OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, -// FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE -// AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER -// LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING -// FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER -// DEALINGS IN THE SOFTWARE. - -//! Manage listening on multiple multiaddresses at once. - -use crate::{ - transport::{ListenerEvent, TransportError}, - Multiaddr, Transport, -}; -use futures::{prelude::*, task::Context, task::Poll}; -use libp2p_core::connection::ListenerId; -use log::debug; -use smallvec::SmallVec; -use std::{collections::VecDeque, fmt, mem, pin::Pin}; - -/// Implementation of `futures::Stream` that allows listening on multiaddresses. -/// -/// To start using a [`ListenersStream`], create one with [`ListenersStream::new`] by passing an -/// implementation of [`Transport`]. This [`Transport`] will be used to start listening, therefore -/// you want to pass a [`Transport`] that supports the protocols you wish you listen on. -/// -/// Then, call [`ListenersStream::listen_on`] for all addresses you want to start listening on. -/// -/// The [`ListenersStream`] never ends and never produces errors. If a listener errors or closes, an -/// event is generated on the stream and the listener is then dropped, but the [`ListenersStream`] -/// itself continues. -pub struct ListenersStream -where - TTrans: Transport, -{ - /// Transport used to spawn listeners. - transport: TTrans, - /// All the active listeners. - /// The `Listener` struct contains a stream that we want to be pinned. Since the `VecDeque` - /// can be resized, the only way is to use a `Pin>`. - listeners: VecDeque>>>, - /// The next listener ID to assign. - next_id: ListenerId, - /// Pending listeners events to return from [`ListenersStream::poll`]. - pending_events: VecDeque>, -} - -/// A single active listener. -#[pin_project::pin_project] -#[derive(Debug)] -struct Listener -where - TTrans: Transport, -{ - /// The ID of this listener. - id: ListenerId, - /// The object that actually listens. - #[pin] - listener: TTrans::Listener, - /// Addresses it is listening on. - addresses: SmallVec<[Multiaddr; 4]>, -} - -/// Event that can happen on the `ListenersStream`. -pub enum ListenersEvent -where - TTrans: Transport, -{ - /// A new address is being listened on. - NewAddress { - /// The listener that is listening on the new address. - listener_id: ListenerId, - /// The new address that is being listened on. - listen_addr: Multiaddr, - }, - /// An address is no longer being listened on. - AddressExpired { - /// The listener that is no longer listening on the address. - listener_id: ListenerId, - /// The new address that is being listened on. - listen_addr: Multiaddr, - }, - /// A connection is incoming on one of the listeners. - Incoming { - /// The listener that produced the upgrade. - listener_id: ListenerId, - /// The produced upgrade. - upgrade: TTrans::ListenerUpgrade, - /// Local connection address. - local_addr: Multiaddr, - /// Address used to send back data to the incoming client. - send_back_addr: Multiaddr, - }, - /// A listener closed. - Closed { - /// The ID of the listener that closed. - listener_id: ListenerId, - /// The addresses that the listener was listening on. - addresses: Vec, - /// Reason for the closure. Contains `Ok(())` if the stream produced `None`, or `Err` - /// if the stream produced an error. - reason: Result<(), TTrans::Error>, - }, - /// A listener errored. - /// - /// The listener will continue to be polled for new events and the event - /// is for informational purposes only. - Error { - /// The ID of the listener that errored. - listener_id: ListenerId, - /// The error value. - error: TTrans::Error, - }, -} - -impl ListenersStream -where - TTrans: Transport, -{ - /// Starts a new stream of listeners. - pub fn new(transport: TTrans) -> Self { - ListenersStream { - transport, - listeners: VecDeque::new(), - next_id: ListenerId::new(1), - pending_events: VecDeque::new(), - } - } - - /// Start listening on a multiaddress. - /// - /// Returns an error if the transport doesn't support the given multiaddress. - pub fn listen_on( - &mut self, - addr: Multiaddr, - ) -> Result> { - let listener = self.transport.listen_on(addr)?; - self.listeners.push_back(Box::pin(Listener { - id: self.next_id, - listener, - addresses: SmallVec::new(), - })); - let id = self.next_id; - self.next_id = self.next_id + 1; - Ok(id) - } - - /// Remove the listener matching the given `ListenerId`. - /// - /// Returns `true` if there was a listener with this ID, `false` - /// otherwise. - pub fn remove_listener(&mut self, id: ListenerId) -> bool { - if let Some(i) = self.listeners.iter().position(|l| l.id == id) { - let mut listener = self - .listeners - .remove(i) - .expect("Index can not be out of bounds."); - let listener_project = listener.as_mut().project(); - let addresses = mem::take(listener_project.addresses).into_vec(); - self.pending_events.push_back(ListenersEvent::Closed { - listener_id: *listener_project.id, - addresses, - reason: Ok(()), - }); - true - } else { - false - } - } - - /// Returns a reference to the transport passed when building this object. - pub fn transport(&self) -> &TTrans { - &self.transport - } - - /// Returns a mutable reference to the transport passed when building this object. - pub fn transport_mut(&mut self) -> &mut TTrans { - &mut self.transport - } - - /// Returns an iterator that produces the list of addresses we're listening on. - pub fn listen_addrs(&self) -> impl Iterator { - self.listeners.iter().flat_map(|l| l.addresses.iter()) - } - - /// Provides an API similar to `Stream`, except that it cannot end. - pub fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { - // Return pending events from closed listeners. - if let Some(event) = self.pending_events.pop_front() { - return Poll::Ready(event); - } - // We remove each element from `listeners` one by one and add them back. - let mut remaining = self.listeners.len(); - while let Some(mut listener) = self.listeners.pop_back() { - let mut listener_project = listener.as_mut().project(); - match TryStream::try_poll_next(listener_project.listener.as_mut(), cx) { - Poll::Pending => { - self.listeners.push_front(listener); - remaining -= 1; - if remaining == 0 { - break; - } - } - Poll::Ready(Some(Ok(ListenerEvent::Upgrade { - upgrade, - local_addr, - remote_addr, - }))) => { - let id = *listener_project.id; - self.listeners.push_front(listener); - return Poll::Ready(ListenersEvent::Incoming { - listener_id: id, - upgrade, - local_addr, - send_back_addr: remote_addr, - }); - } - Poll::Ready(Some(Ok(ListenerEvent::NewAddress(a)))) => { - if listener_project.addresses.contains(&a) { - debug!("Transport has reported address {} multiple times", a) - } else { - listener_project.addresses.push(a.clone()); - } - let id = *listener_project.id; - self.listeners.push_front(listener); - return Poll::Ready(ListenersEvent::NewAddress { - listener_id: id, - listen_addr: a, - }); - } - Poll::Ready(Some(Ok(ListenerEvent::AddressExpired(a)))) => { - listener_project.addresses.retain(|x| x != &a); - let id = *listener_project.id; - self.listeners.push_front(listener); - return Poll::Ready(ListenersEvent::AddressExpired { - listener_id: id, - listen_addr: a, - }); - } - Poll::Ready(Some(Ok(ListenerEvent::Error(error)))) => { - let id = *listener_project.id; - self.listeners.push_front(listener); - return Poll::Ready(ListenersEvent::Error { - listener_id: id, - error, - }); - } - Poll::Ready(None) => { - let addresses = mem::take(listener_project.addresses).into_vec(); - return Poll::Ready(ListenersEvent::Closed { - listener_id: *listener_project.id, - addresses, - reason: Ok(()), - }); - } - Poll::Ready(Some(Err(err))) => { - let addresses = mem::take(listener_project.addresses).into_vec(); - return Poll::Ready(ListenersEvent::Closed { - listener_id: *listener_project.id, - addresses, - reason: Err(err), - }); - } - } - } - - // We register the current task to be woken up if a new listener is added. - Poll::Pending - } -} - -impl Stream for ListenersStream -where - TTrans: Transport, -{ - type Item = ListenersEvent; - - fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { - ListenersStream::poll(self, cx).map(Option::Some) - } -} - -impl Unpin for ListenersStream where TTrans: Transport {} - -impl fmt::Debug for ListenersStream -where - TTrans: Transport + fmt::Debug, -{ - fn fmt(&self, f: &mut fmt::Formatter<'_>) -> Result<(), fmt::Error> { - f.debug_struct("ListenersStream") - .field("transport", &self.transport) - .field("listen_addrs", &self.listen_addrs().collect::>()) - .finish() - } -} - -impl fmt::Debug for ListenersEvent -where - TTrans: Transport, - TTrans::Error: fmt::Debug, -{ - fn fmt(&self, f: &mut fmt::Formatter<'_>) -> Result<(), fmt::Error> { - match self { - ListenersEvent::NewAddress { - listener_id, - listen_addr, - } => f - .debug_struct("ListenersEvent::NewAddress") - .field("listener_id", listener_id) - .field("listen_addr", listen_addr) - .finish(), - ListenersEvent::AddressExpired { - listener_id, - listen_addr, - } => f - .debug_struct("ListenersEvent::AddressExpired") - .field("listener_id", listener_id) - .field("listen_addr", listen_addr) - .finish(), - ListenersEvent::Incoming { - listener_id, - local_addr, - .. - } => f - .debug_struct("ListenersEvent::Incoming") - .field("listener_id", listener_id) - .field("local_addr", local_addr) - .finish(), - ListenersEvent::Closed { - listener_id, - addresses, - reason, - } => f - .debug_struct("ListenersEvent::Closed") - .field("listener_id", listener_id) - .field("addresses", addresses) - .field("reason", reason) - .finish(), - ListenersEvent::Error { listener_id, error } => f - .debug_struct("ListenersEvent::Error") - .field("listener_id", listener_id) - .field("error", error) - .finish(), - } - } -} - -#[cfg(test)] -mod tests { - use futures::{future::BoxFuture, stream::BoxStream}; - - use super::*; - use crate::transport; - - #[test] - fn incoming_event() { - async_std::task::block_on(async move { - let mut mem_transport = transport::MemoryTransport::default(); - - let mut listeners = ListenersStream::new(mem_transport); - listeners.listen_on("/memory/0".parse().unwrap()).unwrap(); - - let address = { - let event = listeners.next().await.unwrap(); - if let ListenersEvent::NewAddress { listen_addr, .. } = event { - listen_addr - } else { - panic!("Was expecting the listen address to be reported") - } - }; - - let address2 = address.clone(); - async_std::task::spawn(async move { - mem_transport.dial(address2).unwrap().await.unwrap(); - }); - - match listeners.next().await.unwrap() { - ListenersEvent::Incoming { - local_addr, - send_back_addr, - .. - } => { - assert_eq!(local_addr, address); - assert!(send_back_addr != address); - } - _ => panic!(), - } - }); - } - - #[test] - fn listener_event_error_isnt_fatal() { - // Tests that a listener continues to be polled even after producing - // a `ListenerEvent::Error`. - - #[derive(Clone)] - struct DummyTrans; - impl transport::Transport for DummyTrans { - type Output = (); - type Error = std::io::Error; - type Listener = BoxStream< - 'static, - Result, std::io::Error>, - >; - type ListenerUpgrade = BoxFuture<'static, Result>; - type Dial = BoxFuture<'static, Result>; - - fn listen_on( - &mut self, - _: Multiaddr, - ) -> Result> { - Ok(Box::pin(stream::unfold((), |()| async move { - Some(( - Ok(ListenerEvent::Error(std::io::Error::from( - std::io::ErrorKind::Other, - ))), - (), - )) - }))) - } - - fn dial( - &mut self, - _: Multiaddr, - ) -> Result> { - panic!() - } - - fn dial_as_listener( - &mut self, - _: Multiaddr, - ) -> Result> { - panic!() - } - - fn address_translation(&self, _: &Multiaddr, _: &Multiaddr) -> Option { - None - } - } - - async_std::task::block_on(async move { - let transport = DummyTrans; - let mut listeners = ListenersStream::new(transport); - listeners.listen_on("/memory/0".parse().unwrap()).unwrap(); - - for _ in 0..10 { - match listeners.next().await.unwrap() { - ListenersEvent::Error { .. } => {} - _ => panic!(), - } - } - }); - } - - #[test] - fn listener_error_is_fatal() { - // Tests that a listener stops after producing an error on the stream itself. - - #[derive(Clone)] - struct DummyTrans; - impl transport::Transport for DummyTrans { - type Output = (); - type Error = std::io::Error; - type Listener = BoxStream< - 'static, - Result, std::io::Error>, - >; - type ListenerUpgrade = BoxFuture<'static, Result>; - type Dial = BoxFuture<'static, Result>; - - fn listen_on( - &mut self, - _: Multiaddr, - ) -> Result> { - Ok(Box::pin(stream::unfold((), |()| async move { - Some((Err(std::io::Error::from(std::io::ErrorKind::Other)), ())) - }))) - } - - fn dial( - &mut self, - _: Multiaddr, - ) -> Result> { - panic!() - } - - fn dial_as_listener( - &mut self, - _: Multiaddr, - ) -> Result> { - panic!() - } - - fn address_translation(&self, _: &Multiaddr, _: &Multiaddr) -> Option { - None - } - } - - async_std::task::block_on(async move { - let transport = DummyTrans; - let mut listeners = ListenersStream::new(transport); - listeners.listen_on("/memory/0".parse().unwrap()).unwrap(); - - match listeners.next().await.unwrap() { - ListenersEvent::Closed { .. } => {} - _ => panic!(), - } - }); - } - - #[test] - fn listener_closed() { - async_std::task::block_on(async move { - let mem_transport = transport::MemoryTransport::default(); - - let mut listeners = ListenersStream::new(mem_transport); - let id = listeners.listen_on("/memory/0".parse().unwrap()).unwrap(); - - let event = listeners.next().await.unwrap(); - let addr; - if let ListenersEvent::NewAddress { listen_addr, .. } = event { - addr = listen_addr - } else { - panic!("Was expecting the listen address to be reported") - } - - assert!(listeners.remove_listener(id)); - - match listeners.next().await.unwrap() { - ListenersEvent::Closed { - listener_id, - addresses, - reason: Ok(()), - } => { - assert_eq!(listener_id, id); - assert!(addresses.contains(&addr)); - } - other => panic!("Unexpected listeners event: {:?}", other), - } - }); - } -} diff --git a/swarm/src/lib.rs b/swarm/src/lib.rs index 90b158dbf1e..f728634ae09 100644 --- a/swarm/src/lib.rs +++ b/swarm/src/lib.rs @@ -79,23 +79,23 @@ pub use handler::{ pub use registry::{AddAddressResult, AddressRecord, AddressScore}; use connection::pool::{Pool, PoolConfig, PoolEvent}; -use connection::{EstablishedConnection, IncomingInfo, ListenersEvent, ListenersStream, Substream}; +use connection::{EstablishedConnection, IncomingInfo, Substream}; use dial_opts::{DialOpts, PeerCondition}; use either::Either; use futures::{executor::ThreadPoolBuilder, prelude::*, stream::FusedStream}; use libp2p_core::connection::{ConnectionId, PendingPoint}; use libp2p_core::{ - connection::{ConnectedPoint, ListenerId}, + connection::ConnectedPoint, multiaddr::Protocol, multihash::Multihash, muxing::StreamMuxerBox, - transport::{self, TransportError}, + transport::{self, ListenerId, TransportError, TransportEvent}, upgrade::ProtocolName, Endpoint, Executor, Multiaddr, Negotiated, PeerId, Transport, }; use registry::{AddressIntoIter, Addresses}; use smallvec::SmallVec; -use std::collections::HashSet; +use std::collections::{HashMap, HashSet}; use std::iter; use std::num::{NonZeroU32, NonZeroU8, NonZeroUsize}; use std::{ @@ -258,7 +258,7 @@ where TBehaviour: NetworkBehaviour, { /// Listeners for incoming connections. - listeners: ListenersStream>, + transport: transport::Boxed<(PeerId, StreamMuxerBox)>, /// The nodes currently active. pool: Pool, transport::Boxed<(PeerId, StreamMuxerBox)>>, @@ -274,7 +274,7 @@ where supported_protocols: SmallVec<[Vec; 16]>, /// List of multiaddresses we're listening on. - listened_addrs: SmallVec<[Multiaddr; 8]>, + listened_addrs: HashMap>, /// List of multiaddresses we're listening on, after account for external IP addresses and /// similar mechanisms. @@ -326,7 +326,7 @@ where /// Listeners report their new listening addresses as [`SwarmEvent::NewListenAddr`]. /// Depending on the underlying transport, one listener may have multiple listening addresses. pub fn listen_on(&mut self, addr: Multiaddr) -> Result> { - let id = self.listeners.listen_on(addr)?; + let id = self.transport.listen_on(addr)?; self.behaviour.inject_new_listener(id); Ok(id) } @@ -335,8 +335,8 @@ where /// /// Returns `true` if there was a listener with this ID, `false` /// otherwise. - pub fn remove_listener(&mut self, id: ListenerId) -> bool { - self.listeners.remove_listener(id) + pub fn remove_listener(&mut self, listener_id: ListenerId) -> bool { + self.transport.remove_listener(listener_id) } /// Dial a known or unknown peer. @@ -445,8 +445,9 @@ where }; let mut unique_addresses = HashSet::new(); - addresses.retain(|a| { - !self.listened_addrs.contains(a) && unique_addresses.insert(a.clone()) + addresses.retain(|addr| { + !self.listened_addrs.values().flatten().any(|a| a == addr) + && unique_addresses.insert(addr.clone()) }); if addresses.is_empty() { @@ -506,11 +507,8 @@ where .map(|a| match p2p_addr(peer_id, a) { Ok(address) => { let dial = match role_override { - Endpoint::Dialer => self.listeners.transport_mut().dial(address.clone()), - Endpoint::Listener => self - .listeners - .transport_mut() - .dial_as_listener(address.clone()), + Endpoint::Dialer => self.transport.dial(address.clone()), + Endpoint::Listener => self.transport.dial_as_listener(address.clone()), }; match dial { Ok(fut) => fut @@ -545,7 +543,7 @@ where /// Returns an iterator that produces the list of addresses we're listening on. pub fn listeners(&self) -> impl Iterator { - self.listeners.listen_addrs() + self.listened_addrs.values().flatten() } /// Returns the peer ID of the swarm passed as parameter. @@ -829,12 +827,15 @@ where None } - fn handle_listeners_event( + fn handle_transport_event( &mut self, - event: ListenersEvent>, + event: TransportEvent< + as Transport>::ListenerUpgrade, + io::Error, + >, ) -> Option>> { match event { - ListenersEvent::Incoming { + TransportEvent::Incoming { listener_id: _, upgrade, local_addr, @@ -862,13 +863,14 @@ where } }; } - ListenersEvent::NewAddress { + TransportEvent::NewAddress { listener_id, listen_addr, } => { log::debug!("Listener {:?}; New address: {:?}", listener_id, listen_addr); - if !self.listened_addrs.contains(&listen_addr) { - self.listened_addrs.push(listen_addr.clone()) + let addrs = self.listened_addrs.entry(listener_id).or_default(); + if !addrs.contains(&listen_addr) { + addrs.push(listen_addr.clone()) } self.behaviour .inject_new_listen_addr(listener_id, &listen_addr); @@ -877,7 +879,7 @@ where address: listen_addr, }); } - ListenersEvent::AddressExpired { + TransportEvent::AddressExpired { listener_id, listen_addr, } => { @@ -886,7 +888,9 @@ where listener_id, listen_addr ); - self.listened_addrs.retain(|a| a != &listen_addr); + if let Some(addrs) = self.listened_addrs.get_mut(&listener_id) { + addrs.retain(|a| a != &listen_addr); + } self.behaviour .inject_expired_listen_addr(listener_id, &listen_addr); return Some(SwarmEvent::ExpiredListenAddr { @@ -894,13 +898,13 @@ where address: listen_addr, }); } - ListenersEvent::Closed { + TransportEvent::ListenerClosed { listener_id, - addresses, reason, } => { log::debug!("Listener {:?}; Closed by {:?}.", listener_id, reason); - for addr in addresses.iter() { + let addrs = self.listened_addrs.remove(&listener_id).unwrap_or_default(); + for addr in addrs.iter() { self.behaviour.inject_expired_listen_addr(listener_id, addr); } self.behaviour.inject_listener_closed( @@ -912,11 +916,11 @@ where ); return Some(SwarmEvent::ListenerClosed { listener_id, - addresses, + addresses: addrs.to_vec(), reason, }); } - ListenersEvent::Error { listener_id, error } => { + TransportEvent::Error { listener_id, error } => { self.behaviour.inject_listener_error(listener_id, &error); return Some(SwarmEvent::ListenerError { listener_id, error }); } @@ -973,11 +977,11 @@ where // // The translation is transport-specific. See [`Transport::address_translation`]. let translated_addresses = { - let transport = self.listeners.transport(); let mut addrs: Vec<_> = self - .listeners - .listen_addrs() - .filter_map(move |server| transport.address_translation(server, &address)) + .listened_addrs + .values() + .flatten() + .filter_map(|server| self.transport.address_translation(server, &address)) .collect(); // remove duplicates @@ -1059,7 +1063,7 @@ where let mut parameters = SwarmPollParameters { local_peer_id: &this.local_peer_id, supported_protocols: &this.supported_protocols, - listened_addrs: &this.listened_addrs, + listened_addrs: this.listened_addrs.values().flatten().collect(), external_addrs: &this.external_addrs, }; this.behaviour.poll(cx, &mut parameters) @@ -1092,10 +1096,10 @@ where }; // Poll the listener(s) for new connections. - match ListenersStream::poll(Pin::new(&mut this.listeners), cx) { + match Pin::new(&mut this.transport).poll(cx) { Poll::Pending => {} - Poll::Ready(listeners_event) => { - if let Some(swarm_event) = this.handle_listeners_event(listeners_event) { + Poll::Ready(transport_event) => { + if let Some(swarm_event) = this.handle_transport_event(transport_event) { return Poll::Ready(swarm_event); } @@ -1230,13 +1234,13 @@ where pub struct SwarmPollParameters<'a> { local_peer_id: &'a PeerId, supported_protocols: &'a [Vec], - listened_addrs: &'a [Multiaddr], + listened_addrs: Vec<&'a Multiaddr>, external_addrs: &'a Addresses, } impl<'a> PollParameters for SwarmPollParameters<'a> { type SupportedProtocolsIter = std::iter::Cloned>>; - type ListenedAddressesIter = std::iter::Cloned>; + type ListenedAddressesIter = std::iter::Cloned>; type ExternalAddressesIter = AddressIntoIter; fn supported_protocols(&self) -> Self::SupportedProtocolsIter { @@ -1244,7 +1248,7 @@ impl<'a> PollParameters for SwarmPollParameters<'a> { } fn listened_addresses(&self) -> Self::ListenedAddressesIter { - self.listened_addrs.iter().cloned() + self.listened_addrs.clone().into_iter().cloned() } fn external_addresses(&self) -> Self::ExternalAddressesIter { @@ -1400,11 +1404,11 @@ where Swarm { local_peer_id: self.local_peer_id, - listeners: ListenersStream::new(self.transport), + transport: self.transport, pool: Pool::new(self.local_peer_id, pool_config, self.connection_limits), behaviour: self.behaviour, supported_protocols, - listened_addrs: SmallVec::new(), + listened_addrs: HashMap::new(), external_addrs: Addresses::default(), banned_peers: HashSet::new(), banned_peer_connections: HashSet::new(), @@ -1617,7 +1621,7 @@ mod tests { use libp2p::plaintext; use libp2p::yamux; use libp2p_core::multiaddr::multiaddr; - use libp2p_core::transport::ListenerEvent; + use libp2p_core::transport::TransportEvent; use libp2p_core::Endpoint; use quickcheck::{quickcheck, Arbitrary, Gen, QuickCheck}; use rand::prelude::SliceRandom; @@ -2066,20 +2070,19 @@ mod tests { // `+ 2` to ensure a subset of addresses is dialed by network_2. let num_listen_addrs = concurrency_factor.0.get() + 2; let mut listen_addresses = Vec::new(); - let mut listeners = Vec::new(); + let mut transports = Vec::new(); for _ in 0..num_listen_addrs { - let mut listener = transport::MemoryTransport {} - .listen_on("/memory/0".parse().unwrap()) - .unwrap(); + let mut transport = transport::MemoryTransport::default().boxed(); + transport.listen_on("/memory/0".parse().unwrap()).unwrap(); - match listener.next().await.unwrap().unwrap() { - ListenerEvent::NewAddress(address) => { - listen_addresses.push(address); + match transport.select_next_some().await { + TransportEvent::NewAddress { listen_addr, .. } => { + listen_addresses.push(listen_addr); } _ => panic!("Expected `NewListenAddr` event."), } - listeners.push(listener); + transports.push(transport); } // Have swarm dial each listener and wait for each listener to receive the incoming @@ -2091,10 +2094,12 @@ mod tests { .build(), ) .unwrap(); - for mut listener in listeners.into_iter() { + for mut transport in transports.into_iter() { loop { - match futures::future::select(listener.next(), swarm.next()).await { - Either::Left((Some(Ok(ListenerEvent::Upgrade { .. })), _)) => { + match futures::future::select(transport.select_next_some(), swarm.next()) + .await + { + Either::Left((TransportEvent::Incoming { .. }, _)) => { break; } Either::Left(_) => { diff --git a/swarm/src/test.rs b/swarm/src/test.rs index e201432ec39..166e9185a47 100644 --- a/swarm/src/test.rs +++ b/swarm/src/test.rs @@ -23,9 +23,7 @@ use crate::{ PollParameters, }; use libp2p_core::{ - connection::{ConnectionId, ListenerId}, - multiaddr::Multiaddr, - ConnectedPoint, PeerId, + connection::ConnectionId, multiaddr::Multiaddr, transport::ListenerId, ConnectedPoint, PeerId, }; use std::collections::HashMap; use std::task::{Context, Poll}; diff --git a/transports/deflate/tests/test.rs b/transports/deflate/tests/test.rs index 743d464c1cf..e718fc0075d 100644 --- a/transports/deflate/tests/test.rs +++ b/transports/deflate/tests/test.rs @@ -21,7 +21,7 @@ use futures::{future, prelude::*}; use libp2p_core::{transport::Transport, upgrade}; use libp2p_deflate::DeflateConfig; -use libp2p_tcp::TcpConfig; +use libp2p_tcp::TcpTransport; use quickcheck::{QuickCheck, RngCore, TestResult}; #[test] @@ -44,37 +44,38 @@ fn lot_of_data() { } async fn run(message1: Vec) { - let mut transport = TcpConfig::new().and_then(|conn, endpoint| { - upgrade::apply( - conn, - DeflateConfig::default(), - endpoint, - upgrade::Version::V1, - ) - }); - - let mut listener = transport + let new_transport = || { + TcpTransport::default() + .and_then(|conn, endpoint| { + upgrade::apply( + conn, + DeflateConfig::default(), + endpoint, + upgrade::Version::V1, + ) + }) + .boxed() + }; + let mut listener_trans = new_transport(); + listener_trans .listen_on("/ip4/0.0.0.0/tcp/0".parse().expect("multiaddr")) .expect("listener"); - let listen_addr = listener - .by_ref() + let listen_addr = listener_trans .next() .await .expect("some event") - .expect("no error") .into_new_address() .expect("new address"); let message2 = message1.clone(); let listener_task = async_std::task::spawn(async move { - let mut conn = listener - .filter(|e| future::ready(e.as_ref().map(|e| e.is_upgrade()).unwrap_or(false))) + let mut conn = listener_trans + .filter(|e| future::ready(e.is_upgrade())) .next() .await .expect("some event") - .expect("no error") .into_upgrade() .expect("upgrade") .0 @@ -89,7 +90,8 @@ async fn run(message1: Vec) { conn.close().await.expect("close") }); - let mut conn = transport + let mut dialer_trans = new_transport(); + let mut conn = dialer_trans .dial(listen_addr) .expect("dialer") .await diff --git a/transports/dns/src/lib.rs b/transports/dns/src/lib.rs index 45806a7b772..0ee89f78373 100644 --- a/transports/dns/src/lib.rs +++ b/transports/dns/src/lib.rs @@ -60,15 +60,23 @@ use futures::{future::BoxFuture, prelude::*}; use libp2p_core::{ connection::Endpoint, multiaddr::{Multiaddr, Protocol}, - transport::{ListenerEvent, TransportError}, + transport::{ListenerId, TransportError, TransportEvent}, Transport, }; use parking_lot::Mutex; use smallvec::SmallVec; #[cfg(any(feature = "async-std", feature = "tokio"))] use std::io; -use std::sync::Arc; -use std::{convert::TryFrom, error, fmt, iter, net::IpAddr, str}; +use std::{ + convert::TryFrom, + error, fmt, iter, + net::IpAddr, + ops::DerefMut, + pin::Pin, + str, + sync::Arc, + task::{Context, Poll}, +}; #[cfg(any(feature = "async-std", feature = "tokio"))] use trust_dns_resolver::system_conf; use trust_dns_resolver::{proto::xfer::dns_handle::DnsHandle, AsyncResolver, ConnectionProvider}; @@ -174,7 +182,7 @@ where impl Transport for GenDnsConfig where - T: Transport + Send + 'static, + T: Transport + Send + Unpin + 'static, T::Error: Send, T::Dial: Send, C: DnsHandle, @@ -182,38 +190,21 @@ where { type Output = T::Output; type Error = DnsErr; - type Listener = stream::MapErr< - stream::MapOk< - T::Listener, - fn( - ListenerEvent, - ) -> ListenerEvent, - >, - fn(T::Error) -> Self::Error, - >; type ListenerUpgrade = future::MapErr Self::Error>; type Dial = future::Either< future::MapErr Self::Error>, BoxFuture<'static, Result>, >; - fn listen_on( - &mut self, - addr: Multiaddr, - ) -> Result> { - let listener = self - .inner + fn listen_on(&mut self, addr: Multiaddr) -> Result> { + self.inner .lock() .listen_on(addr) - .map_err(|err| err.map(DnsErr::Transport))?; - let listener = listener - .map_ok::<_, fn(_) -> _>(|event| { - event - .map(|upgr| upgr.map_err::<_, fn(_) -> _>(DnsErr::Transport)) - .map_err(DnsErr::Transport) - }) - .map_err::<_, fn(_) -> _>(DnsErr::Transport); - Ok(listener) + .map_err(|e| e.map(DnsErr::Transport)) + } + + fn remove_listener(&mut self, id: ListenerId) -> bool { + self.inner.lock().remove_listener(id) } fn dial(&mut self, addr: Multiaddr) -> Result> { @@ -230,11 +221,23 @@ where fn address_translation(&self, server: &Multiaddr, observed: &Multiaddr) -> Option { self.inner.lock().address_translation(server, observed) } + + fn poll( + self: Pin<&mut Self>, + cx: &mut Context<'_>, + ) -> Poll> { + let mut inner = self.inner.lock(); + Transport::poll(Pin::new(inner.deref_mut()), cx).map(|event| { + event + .map_upgrade(|upgr| upgr.map_err::<_, fn(_) -> _>(DnsErr::Transport)) + .map_err(DnsErr::Transport) + }) + } } impl GenDnsConfig where - T: Transport + Send + 'static, + T: Transport + Send + Unpin + 'static, T::Error: Send, T::Dial: Send, C: DnsHandle, @@ -571,11 +574,10 @@ fn invalid_data(e: impl Into>) -> io::E #[cfg(test)] mod tests { use super::*; - use futures::{future::BoxFuture, stream::BoxStream}; + use futures::future::BoxFuture; use libp2p_core::{ multiaddr::{Multiaddr, Protocol}, - transport::ListenerEvent, - transport::TransportError, + transport::{TransportError, TransportEvent}, PeerId, Transport, }; @@ -589,20 +591,20 @@ mod tests { impl Transport for CustomTransport { type Output = (); type Error = std::io::Error; - type Listener = BoxStream< - 'static, - Result, Self::Error>, - >; type ListenerUpgrade = BoxFuture<'static, Result>; type Dial = BoxFuture<'static, Result>; fn listen_on( &mut self, _: Multiaddr, - ) -> Result> { + ) -> Result> { unreachable!() } + fn remove_listener(&mut self, _: ListenerId) -> bool { + false + } + fn dial(&mut self, addr: Multiaddr) -> Result> { // Check that all DNS components have been resolved, i.e. replaced. assert!(!addr.iter().any(|p| match p { @@ -625,13 +627,20 @@ mod tests { fn address_translation(&self, _: &Multiaddr, _: &Multiaddr) -> Option { None } + + fn poll( + self: Pin<&mut Self>, + _: &mut Context<'_>, + ) -> Poll> { + unreachable!() + } } async fn run(mut transport: GenDnsConfig) where C: DnsHandle, P: ConnectionProvider, - T: Transport + Clone + Send + 'static, + T: Transport + Clone + Send + Unpin + 'static, T::Error: Send, T::Dial: Send, { diff --git a/transports/noise/src/lib.rs b/transports/noise/src/lib.rs index f4cc85dea4a..ee609fd028d 100644 --- a/transports/noise/src/lib.rs +++ b/transports/noise/src/lib.rs @@ -40,14 +40,14 @@ //! //! ``` //! use libp2p_core::{identity, Transport, upgrade}; -//! use libp2p_tcp::TcpConfig; +//! use libp2p_tcp::TcpTransport; //! use libp2p_noise::{Keypair, X25519Spec, NoiseConfig}; //! //! # fn main() { //! let id_keys = identity::Keypair::generate_ed25519(); //! let dh_keys = Keypair::::new().into_authentic(&id_keys).unwrap(); //! let noise = NoiseConfig::xx(dh_keys).into_authenticated(); -//! let builder = TcpConfig::new().upgrade(upgrade::Version::V1).authenticate(noise); +//! let builder = TcpTransport::default().upgrade(upgrade::Version::V1).authenticate(noise); //! // let transport = builder.multiplex(...); //! # } //! ``` diff --git a/transports/noise/tests/smoke.rs b/transports/noise/tests/smoke.rs index 5c745c9463c..0945105d211 100644 --- a/transports/noise/tests/smoke.rs +++ b/transports/noise/tests/smoke.rs @@ -24,12 +24,12 @@ use futures::{ prelude::*, }; use libp2p_core::identity; -use libp2p_core::transport::{ListenerEvent, Transport}; +use libp2p_core::transport::{self, Transport}; use libp2p_core::upgrade::{self, apply_inbound, apply_outbound, Negotiated}; use libp2p_noise::{ Keypair, NoiseConfig, NoiseError, NoiseOutput, RemoteIdentity, X25519Spec, X25519, }; -use libp2p_tcp::TcpConfig; +use libp2p_tcp::TcpTransport; use log::info; use quickcheck::QuickCheck; use std::{convert::TryInto, io, net::TcpStream}; @@ -41,7 +41,7 @@ fn core_upgrade_compat() { let id_keys = identity::Keypair::generate_ed25519(); let dh_keys = Keypair::::new().into_authentic(&id_keys).unwrap(); let noise = NoiseConfig::xx(dh_keys).into_authenticated(); - let _ = TcpConfig::new() + let _ = TcpTransport::default() .upgrade(upgrade::Version::V1) .authenticate(noise); } @@ -60,7 +60,7 @@ fn xx_spec() { let server_dh = Keypair::::new() .into_authentic(&server_id) .unwrap(); - let server_transport = TcpConfig::new() + let server_transport = TcpTransport::default() .and_then(move |output, endpoint| { upgrade::apply( output, @@ -69,12 +69,13 @@ fn xx_spec() { upgrade::Version::V1, ) }) - .and_then(move |out, _| expect_identity(out, &client_id_public)); + .and_then(move |out, _| expect_identity(out, &client_id_public)) + .boxed(); let client_dh = Keypair::::new() .into_authentic(&client_id) .unwrap(); - let client_transport = TcpConfig::new() + let client_transport = TcpTransport::default() .and_then(move |output, endpoint| { upgrade::apply( output, @@ -83,7 +84,8 @@ fn xx_spec() { upgrade::Version::V1, ) }) - .and_then(move |out, _| expect_identity(out, &server_id_public)); + .and_then(move |out, _| expect_identity(out, &server_id_public)) + .boxed(); run(server_transport, client_transport, messages); true @@ -105,7 +107,7 @@ fn xx() { let client_id_public = client_id.public(); let server_dh = Keypair::::new().into_authentic(&server_id).unwrap(); - let server_transport = TcpConfig::new() + let server_transport = TcpTransport::default() .and_then(move |output, endpoint| { upgrade::apply( output, @@ -114,10 +116,11 @@ fn xx() { upgrade::Version::V1, ) }) - .and_then(move |out, _| expect_identity(out, &client_id_public)); + .and_then(move |out, _| expect_identity(out, &client_id_public)) + .boxed(); let client_dh = Keypair::::new().into_authentic(&client_id).unwrap(); - let client_transport = TcpConfig::new() + let client_transport = TcpTransport::default() .and_then(move |output, endpoint| { upgrade::apply( output, @@ -126,7 +129,8 @@ fn xx() { upgrade::Version::V1, ) }) - .and_then(move |out, _| expect_identity(out, &server_id_public)); + .and_then(move |out, _| expect_identity(out, &server_id_public)) + .boxed(); run(server_transport, client_transport, messages); true @@ -148,7 +152,7 @@ fn ix() { let client_id_public = client_id.public(); let server_dh = Keypair::::new().into_authentic(&server_id).unwrap(); - let server_transport = TcpConfig::new() + let server_transport = TcpTransport::default() .and_then(move |output, endpoint| { upgrade::apply( output, @@ -157,10 +161,11 @@ fn ix() { upgrade::Version::V1, ) }) - .and_then(move |out, _| expect_identity(out, &client_id_public)); + .and_then(move |out, _| expect_identity(out, &client_id_public)) + .boxed(); let client_dh = Keypair::::new().into_authentic(&client_id).unwrap(); - let client_transport = TcpConfig::new() + let client_transport = TcpTransport::default() .and_then(move |output, endpoint| { upgrade::apply( output, @@ -169,7 +174,8 @@ fn ix() { upgrade::Version::V1, ) }) - .and_then(move |out, _| expect_identity(out, &server_id_public)); + .and_then(move |out, _| expect_identity(out, &server_id_public)) + .boxed(); run(server_transport, client_transport, messages); true @@ -192,7 +198,7 @@ fn ik_xx() { let server_dh = Keypair::::new().into_authentic(&server_id).unwrap(); let server_dh_public = server_dh.public().clone(); - let server_transport = TcpConfig::new() + let server_transport = TcpTransport::default() .and_then(move |output, endpoint| { if endpoint.is_listener() { Either::Left(apply_inbound(output, NoiseConfig::ik_listener(server_dh))) @@ -204,11 +210,12 @@ fn ik_xx() { )) } }) - .and_then(move |out, _| expect_identity(out, &client_id_public)); + .and_then(move |out, _| expect_identity(out, &client_id_public)) + .boxed(); let client_dh = Keypair::::new().into_authentic(&client_id).unwrap(); let server_id_public2 = server_id_public.clone(); - let client_transport = TcpConfig::new() + let client_transport = TcpTransport::default() .and_then(move |output, endpoint| { if endpoint.is_dialer() { Either::Left(apply_outbound( @@ -220,7 +227,8 @@ fn ik_xx() { Either::Right(apply_inbound(output, NoiseConfig::xx(client_dh))) } }) - .and_then(move |out, _| expect_identity(out, &server_id_public2)); + .and_then(move |out, _| expect_identity(out, &server_id_public2)) + .boxed(); run(server_transport, client_transport, messages); true @@ -232,34 +240,28 @@ fn ik_xx() { type Output = (RemoteIdentity, NoiseOutput>>); -fn run(mut server_transport: T, mut client_transport: U, messages: I) -where - T: Transport>, - T::Dial: Send + 'static, - T::Listener: Send + Unpin + 'static, - T::ListenerUpgrade: Send + 'static, - U: Transport>, - U::Dial: Send + 'static, - U::Listener: Send + 'static, - U::ListenerUpgrade: Send + 'static, +fn run( + mut server: transport::Boxed>, + mut client: transport::Boxed>, + messages: I, +) where I: IntoIterator + Clone, { futures::executor::block_on(async { - let mut server: T::Listener = server_transport + server .listen_on("/ip4/127.0.0.1/tcp/0".parse().unwrap()) .unwrap(); let server_address = server - .try_next() + .next() .await .expect("some event") - .expect("no error") .into_new_address() .expect("listen address"); let outbound_msgs = messages.clone(); let client_fut = async { - let mut client_session = client_transport + let mut client_session = client .dial(server_address.clone()) .unwrap() .await @@ -276,13 +278,12 @@ where let server_fut = async { let mut server_session = server - .try_next() + .next() .await .expect("some event") - .map(ListenerEvent::into_upgrade) - .expect("no error") - .map(|client| client.0) + .into_upgrade() .expect("listener upgrade") + .0 .await .map(|(_, session)| session) .expect("no error"); diff --git a/transports/plaintext/tests/smoke.rs b/transports/plaintext/tests/smoke.rs index ec20e8ff20e..0f14f4f2ae6 100644 --- a/transports/plaintext/tests/smoke.rs +++ b/transports/plaintext/tests/smoke.rs @@ -18,14 +18,11 @@ // FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER // DEALINGS IN THE SOFTWARE. -use futures::io::{AsyncReadExt, AsyncWriteExt}; -use futures::stream::TryStreamExt; -use libp2p_core::{ - identity, - multiaddr::Multiaddr, - transport::{ListenerEvent, Transport}, - upgrade, +use futures::{ + io::{AsyncReadExt, AsyncWriteExt}, + StreamExt, }; +use libp2p_core::{identity, multiaddr::Multiaddr, transport::Transport, upgrade}; use libp2p_plaintext::PlainText2Config; use log::debug; use quickcheck::QuickCheck; @@ -45,8 +42,8 @@ fn variable_msg_length() { let client_id_public = client_id.public(); futures::executor::block_on(async { - let mut server_transport = - libp2p_core::transport::MemoryTransport {}.and_then(move |output, endpoint| { + let mut server = libp2p_core::transport::MemoryTransport::new() + .and_then(move |output, endpoint| { upgrade::apply( output, PlainText2Config { @@ -55,10 +52,11 @@ fn variable_msg_length() { endpoint, libp2p_core::upgrade::Version::V1, ) - }); + }) + .boxed(); - let mut client_transport = - libp2p_core::transport::MemoryTransport {}.and_then(move |output, endpoint| { + let mut client = libp2p_core::transport::MemoryTransport::new() + .and_then(move |output, endpoint| { upgrade::apply( output, PlainText2Config { @@ -67,31 +65,28 @@ fn variable_msg_length() { endpoint, libp2p_core::upgrade::Version::V1, ) - }); + }) + .boxed(); let server_address: Multiaddr = format!("/memory/{}", std::cmp::Ord::max(1, rand::random::())) .parse() .unwrap(); - let mut server = server_transport.listen_on(server_address.clone()).unwrap(); + server.listen_on(server_address.clone()).unwrap(); // Ignore server listen address event. let _ = server - .try_next() + .next() .await .expect("some event") - .expect("no error") .into_new_address() .expect("listen address"); let client_fut = async { debug!("dialing {:?}", server_address); - let (received_server_id, mut client_channel) = client_transport - .dial(server_address) - .unwrap() - .await - .unwrap(); + let (received_server_id, mut client_channel) = + client.dial(server_address).unwrap().await.unwrap(); assert_eq!(received_server_id, server_id.public().to_peer_id()); debug!("Client: writing message."); @@ -105,13 +100,12 @@ fn variable_msg_length() { let server_fut = async { let mut server_channel = server - .try_next() + .next() .await .expect("some event") - .map(ListenerEvent::into_upgrade) + .into_upgrade() .expect("no error") - .map(|client| client.0) - .expect("listener upgrade xyz") + .0 .await .map(|(_, session)| session) .expect("no error"); diff --git a/transports/quic/Cargo.toml b/transports/quic/Cargo.toml index b2d5e1226e3..c2529f3856e 100644 --- a/transports/quic/Cargo.toml +++ b/transports/quic/Cargo.toml @@ -15,6 +15,7 @@ if-watch = "1.0.0" libp2p-core = { version = "0.33.0", path = "../../core" } parking_lot = "0.12.0" quinn-proto = { version = "0.8.2", default-features = false, features = ["tls-rustls"] } +rand = "0.8.5" rcgen = "0.9.2" ring = "0.16.20" rustls = { version = "0.20.2", default-features = false, features = ["dangerous_configuration"] } diff --git a/transports/quic/src/connection.rs b/transports/quic/src/connection.rs index f91cd9af127..0c9217b190a 100644 --- a/transports/quic/src/connection.rs +++ b/transports/quic/src/connection.rs @@ -44,7 +44,7 @@ use std::{ /// /// Contains everything needed to process a connection with a remote. /// Tied to a specific [`crate::Endpoint`]. -pub(crate) struct Connection { +pub struct Connection { /// Endpoint this connection belongs to. endpoint: Arc, /// Future whose job is to send a message to the endpoint. Only one at a time. @@ -54,7 +54,7 @@ pub(crate) struct Connection { from_endpoint: mpsc::Receiver, /// The QUIC state machine for this specific connection. - pub(crate) connection: quinn_proto::Connection, + pub connection: quinn_proto::Connection, /// Identifier for this connection according to the endpoint. Used when sending messages to /// the endpoint. connection_id: quinn_proto::ConnectionHandle, @@ -100,7 +100,7 @@ impl Connection { /// This function assumes that the [`quinn_proto::Connection`] is completely fresh and none of /// its methods has ever been called. Failure to comply might lead to logic errors and panics. // TODO: maybe abstract `to_endpoint` more and make it generic? dunno - pub(crate) fn from_quinn_connection( + pub fn from_quinn_connection( endpoint: Arc, connection: quinn_proto::Connection, connection_id: quinn_proto::ConnectionHandle, @@ -124,9 +124,9 @@ impl Connection { /// The local address which was used when the peer established the connection. /// /// Works for server connections only. - pub(crate) fn local_addr(&self) -> SocketAddr { + pub fn local_addr(&self) -> SocketAddr { debug_assert_eq!(self.connection.side(), quinn_proto::Side::Server); - let endpoint_addr = self.endpoint.local_addr; + let endpoint_addr = self.endpoint.socket_addr(); self.connection .local_ip() .map(|ip| SocketAddr::new(ip, endpoint_addr.port())) @@ -134,25 +134,25 @@ impl Connection { // In a normal case scenario this should not happen, because // we get want to get a local addr for a server connection only. tracing::error!("trying to get quinn::local_ip for a client"); - endpoint_addr + *endpoint_addr }) } /// Returns the address of the node we're connected to. // TODO: can change /!\ - pub(crate) fn remote_addr(&self) -> SocketAddr { + pub fn remote_addr(&self) -> SocketAddr { self.connection.remote_address() } /// Returns `true` if this connection is still pending. Returns `false` if we are connected to /// the remote or if the connection is closed. - pub(crate) fn is_handshaking(&self) -> bool { + pub fn is_handshaking(&self) -> bool { self.is_handshaking } /// Returns the address of the node we're connected to. /// Panics if the connection is still handshaking. - pub(crate) fn remote_peer_id(&self) -> PeerId { + pub fn remote_peer_id(&self) -> PeerId { debug_assert!(!self.is_handshaking()); let session = self.connection.crypto_session(); let identity = session @@ -171,7 +171,7 @@ impl Connection { /// Start closing the connection. A [`ConnectionEvent::ConnectionLost`] event will be /// produced in the future. - pub(crate) fn close(&mut self) { + pub fn close(&mut self) { // TODO: what if the user calls this multiple times? // We send a dummy `0` error code with no message, as the API of StreamMuxer doesn't // support this. @@ -187,7 +187,7 @@ impl Connection { /// /// If `None` is returned, then a [`ConnectionEvent::StreamAvailable`] event will later be /// produced when a substream is available. - pub(crate) fn pop_incoming_substream(&mut self) -> Option { + pub fn pop_incoming_substream(&mut self) -> Option { self.connection.streams().accept(quinn_proto::Dir::Bi) } @@ -198,7 +198,7 @@ impl Connection { /// /// If `None` is returned, then a [`ConnectionEvent::StreamOpened`] event will later be /// produced when a substream is available. - pub(crate) fn pop_outgoing_substream(&mut self) -> Option { + pub fn pop_outgoing_substream(&mut self) -> Option { self.connection.streams().open(quinn_proto::Dir::Bi) } @@ -210,7 +210,7 @@ impl Connection { /// On success, a [`quinn_proto::StreamEvent::Finished`] event will later be produced when the /// substream has been effectively closed. A [`ConnectionEvent::StreamStopped`] event can also /// be emitted. - pub(crate) fn shutdown_substream( + pub fn shutdown_substream( &mut self, id: quinn_proto::StreamId, ) -> Result<(), quinn_proto::FinishError> { @@ -220,7 +220,7 @@ impl Connection { } /// Polls the connection for an event that happend on it. - pub(crate) fn poll_event(&mut self, cx: &mut Context<'_>) -> Poll { + pub fn poll_event(&mut self, cx: &mut Context<'_>) -> Poll { // Nothing more can be done if the connection is closed. // Return `Pending` without registering the waker, essentially freezing the task forever. if self.closed.is_some() { @@ -399,7 +399,7 @@ impl Drop for Connection { /// Event generated by the [`Connection`]. #[derive(Debug)] -pub(crate) enum ConnectionEvent { +pub enum ConnectionEvent { /// Now connected to the remote and certificates are available. Connected, diff --git a/transports/quic/src/endpoint.rs b/transports/quic/src/endpoint.rs index 6d0d3ceaa94..59804a1bd96 100644 --- a/transports/quic/src/endpoint.rs +++ b/transports/quic/src/endpoint.rs @@ -28,24 +28,20 @@ //! the rest of the code only happens through channels. See the documentation of the //! [`background_task`] for a thorough description. -use crate::{connection::Connection, tls}; - -use std::net::{SocketAddr, UdpSocket}; +use crate::{connection::Connection, tls, transport}; use futures::{ channel::{mpsc, oneshot}, lock::Mutex, prelude::*, - stream::Stream, }; -use libp2p_core::multiaddr::Multiaddr; use quinn_proto::{ClientConfig as QuinnClientConfig, ServerConfig as QuinnServerConfig}; use std::{ collections::{HashMap, VecDeque}, - fmt, io, - pin::Pin, + fmt, + net::{Ipv4Addr, SocketAddr, SocketAddrV4, UdpSocket}, sync::{Arc, Weak}, - task::{Context, Poll}, + task::{Poll, Waker}, time::{Duration, Instant}, }; @@ -58,16 +54,11 @@ pub struct Config { server_config: Arc, /// The endpoint configuration to pass to `quinn_proto`. endpoint_config: Arc, - /// The [`Multiaddr`] to use to spawn the UDP socket. - multiaddr: Multiaddr, } impl Config { /// Creates a new configuration object with default values. - pub fn new( - keypair: &libp2p_core::identity::Keypair, - multiaddr: Multiaddr, - ) -> Result { + pub fn new(keypair: &libp2p_core::identity::Keypair) -> Result { let mut transport = quinn_proto::TransportConfig::default(); transport.max_concurrent_uni_streams(0u32.into()); // Can only panic if value is out of range. transport.datagram_receive_buffer_size(None); @@ -86,7 +77,6 @@ impl Config { client_config, server_config: Arc::new(server_config), endpoint_config: Default::default(), - multiaddr, }) } } @@ -100,60 +90,55 @@ pub struct Endpoint { /// See [`Endpoint::new_connections`] (just below) for a commentary about the mutex. to_endpoint: Mutex>, - /// Channel where new connections are being sent. - /// This is protected by a futures-friendly `Mutex`, meaning that receiving a connection is - /// done in two steps: locking this mutex, and grabbing the next element on the `Receiver`. - /// The only consequence of this `Mutex` is that multiple simultaneous calls to - /// [`Endpoint::poll_incoming`] are serialized. - new_connections: Mutex>, - /// Copy of [`Endpoint::to_endpoint`], except not behind a `Mutex`. Used if we want to be /// guaranteed a slot in the messages buffer. to_endpoint2: mpsc::Sender, - /// Socketaddr of the local UDP socket passed in the configuration at initialization after it - /// has potentially been modified to handle port number `0`. - pub(crate) local_addr: SocketAddr, + socket_addr: SocketAddr, } impl Endpoint { - /// Builds a new `Endpoint`. - pub fn new(config: Config) -> Result, io::Error> { - let local_socket_addr = match crate::transport::multiaddr_to_socketaddr(&config.multiaddr) { - Some(a) => a, - None => panic!(), // TODO: Err(TransportError::MultiaddrNotSupported(multiaddr)), - }; + /// Builds a new `Endpoint` that is listening on the [`SocketAddr`]. + pub fn new_bidirectional( + config: Config, + socket_addr: SocketAddr, + ) -> Result<(Arc, mpsc::Receiver), transport::Error> { + let (new_connections_tx, new_connections_rx) = mpsc::channel(1); + let endpoint = Self::new(config, socket_addr, Some(new_connections_tx))?; + Ok((endpoint, new_connections_rx)) + } - // NOT blocking, as per man:bind(2), as we pass an IP address. - let socket = std::net::UdpSocket::bind(&local_socket_addr)?; - // TODO: - /*let port_is_zero = local_socket_addr.port() == 0; - let local_socket_addr = socket.local_addr()?; - if port_is_zero { - assert_ne!(local_socket_addr.port(), 0); - assert_eq!(multiaddr.pop(), Some(Protocol::Quic)); - assert_eq!(multiaddr.pop(), Some(Protocol::Udp(0))); - multiaddr.push(Protocol::Udp(local_socket_addr.port())); - multiaddr.push(Protocol::Quic); - }*/ + /// Builds a new `Endpoint` that only supports outbound connections. + pub fn new_dialer(config: Config) -> Result, transport::Error> { + let socket_addr = SocketAddrV4::new(Ipv4Addr::LOCALHOST, 0); + Self::new(config, socket_addr.into(), None) + } + fn new( + config: Config, + socket_addr: SocketAddr, + new_connections: Option>, + ) -> Result, transport::Error> { + // NOT blocking, as per man:bind(2), as we pass an IP address. + let socket = std::net::UdpSocket::bind(&socket_addr)?; let (to_endpoint_tx, to_endpoint_rx) = mpsc::channel(32); let to_endpoint2 = to_endpoint_tx.clone(); - let (new_connections_tx, new_connections_rx) = mpsc::channel(1); let endpoint = Arc::new(Endpoint { to_endpoint: Mutex::new(to_endpoint_tx), to_endpoint2, - new_connections: Mutex::new(new_connections_rx), - local_addr: socket.local_addr()?, + socket_addr: socket.local_addr()?, }); + let server_config = new_connections.map(|c| (c, config.server_config.clone())); + // TODO: just for testing, do proper task spawning async_global_executor::spawn(background_task( - config, + config.endpoint_config, + config.client_config, + server_config, Arc::downgrade(&endpoint), async_io::Async::::new(socket)?, - new_connections_tx, to_endpoint_rx.fuse(), )) .detach(); @@ -161,14 +146,15 @@ impl Endpoint { Ok(endpoint) } + pub fn socket_addr(&self) -> &SocketAddr { + &self.socket_addr + } + /// Asks the endpoint to start dialing the given address. /// /// Note that this method only *starts* the dialing. `Ok` is returned as soon as possible, even /// when the remote might end up being unreachable. - pub(crate) async fn dial( - &self, - addr: SocketAddr, - ) -> Result { + pub async fn dial(&self, addr: SocketAddr) -> Result { // The two `expect`s below can panic if the background task has stopped. The background // task can stop only if the `Endpoint` is destroyed or if the task itself panics. In other // words, we panic here iff a panic has already happened somewhere else, which is a @@ -183,19 +169,12 @@ impl Endpoint { rx.await.expect("background task has crashed") } - /// Tries to pop a new incoming connection from the queue. - pub(crate) fn poll_incoming(&self, cx: &mut Context) -> Poll> { - let mut connections_lock = self.new_connections.lock(); - let mut guard = futures::ready!(Pin::new(&mut connections_lock).poll(cx)); - Pin::new(&mut *guard).poll_next(cx) - } - /// Asks the endpoint to send a UDP packet. /// /// Note that this method only queues the packet and returns as soon as the packet is in queue. /// There is no guarantee that the packet will actually be sent, but considering that this is /// a UDP packet, you cannot rely on the packet being delivered anyway. - pub(crate) async fn send_udp_packet(&self, destination: SocketAddr, data: impl Into>) { + pub async fn send_udp_packet(&self, destination: SocketAddr, data: impl Into>) { let _ = self .to_endpoint .lock() @@ -213,7 +192,7 @@ impl Endpoint { /// /// If `event.is_drained()` is true, the event indicates that the connection no longer exists. /// This must therefore be the last event sent using this [`quinn_proto::ConnectionHandle`]. - pub(crate) async fn report_quinn_event( + pub async fn report_quinn_event( &self, connection_id: quinn_proto::ConnectionHandle, event: quinn_proto::EndpointEvent, @@ -234,7 +213,7 @@ impl Endpoint { /// /// This method bypasses back-pressure mechanisms and is meant to be called only from /// destructors, where waiting is not advisable. - pub(crate) fn report_quinn_event_non_block( + pub fn report_quinn_event_non_block( &self, connection_id: quinn_proto::ConnectionHandle, event: quinn_proto::EndpointEvent, @@ -251,7 +230,6 @@ impl Endpoint { assert!(result.is_ok()); } } - /// Message sent to the endpoint background task. #[derive(Debug)] enum ToEndpoint { @@ -363,17 +341,20 @@ enum ToEndpoint { /// for as long as any QUIC connection is open. /// async fn background_task( - config: Config, + endpoint_config: Arc, + client_config: quinn_proto::ClientConfig, + server_config: Option<(mpsc::Sender, Arc)>, endpoint_weak: Weak, udp_socket: async_io::Async, - mut new_connections: mpsc::Sender, mut receiver: stream::Fuse>, ) { + let (mut new_connections, server_config) = match server_config { + Some((a, b)) => (Some(a), Some(b)), + None => (None, None), + }; + // The actual QUIC state machine. - let mut endpoint = quinn_proto::Endpoint::new( - config.endpoint_config.clone(), - Some(config.server_config.clone()), - ); + let mut endpoint = quinn_proto::Endpoint::new(endpoint_config.clone(), server_config); // List of all active connections, with a sender to notify them of events. let mut alive_connections = HashMap::>::new(); @@ -393,6 +374,8 @@ async fn background_task( // code below. let mut next_packet_out: Option<(SocketAddr, Vec)> = None; + let mut new_connection_waker: Option = None; + // Main loop of the task. loop { // Start by flushing `next_packet_out`. @@ -437,7 +420,7 @@ async fn background_task( // name. While we don't use domain names, the underlying rustls library // is based upon the assumption that we do. let (connection_id, connection) = - match endpoint.connect(config.client_config.clone(), addr, "l") { + match endpoint.connect(client_config.clone(), addr, "l") { Ok(c) => c, Err(err) => { let _ = result.send(Err(err)); @@ -502,8 +485,17 @@ async fn background_task( readiness = { let active = !queued_new_connections.is_empty(); let new_connections = &mut new_connections; + let new_connection_waker = &mut new_connection_waker; future::poll_fn(move |cx| { - if active { new_connections.poll_ready(cx) } else { Poll::Pending } + match new_connections.as_mut() { + Some(ref mut c) if active => { + c.poll_ready(cx) + } + _ => { + let _ = new_connection_waker.insert(cx.waker().clone()); + Poll::Pending + } + } }) .fuse() } => { @@ -515,6 +507,7 @@ async fn background_task( let elem = queued_new_connections.pop_front() .expect("if queue is empty, the future above is always Pending; qed"); + let new_connections = new_connections.as_mut().expect("in case of None, the future above is always Pending; qed"); new_connections.start_send(elem) .expect("future is waken up only if poll_ready returned Ready; qed"); //endpoint.accept(); @@ -565,6 +558,9 @@ async fn background_task( // to the `new_connections` channel. We call `endpoint.accept()` only once // the element has successfully been sent on `new_connections`. queued_new_connections.push_back(connection); + if let Some(waker) = new_connection_waker.take() { + waker.wake(); + } }, } } diff --git a/transports/quic/src/lib.rs b/transports/quic/src/lib.rs index de2487a75ca..69a68adb895 100644 --- a/transports/quic/src/lib.rs +++ b/transports/quic/src/lib.rs @@ -25,13 +25,14 @@ //! Example: //! //! ``` -//! use libp2p_quic::{Config, Endpoint}; -//! use libp2p_core::Multiaddr; +//! use libp2p_quic::{Config, QuicTransport}; +//! use libp2p_core::{Multiaddr, Transport}; //! //! let keypair = libp2p_core::identity::Keypair::generate_ed25519(); +//! let quic_config = Config::new(&keypair).expect("could not make config"); +//! let mut quic_transport = QuicTransport::new(quic_config); //! let addr = "/ip4/127.0.0.1/udp/12345/quic".parse().expect("bad address?"); -//! let quic_config = Config::new(&keypair, addr).expect("could not make config"); -//! let quic_endpoint = Endpoint::new(quic_config).expect("I/O error"); +//! quic_transport.listen_on(addr).expect("listen error."); //! ``` //! //! The `Endpoint` struct implements the `Transport` trait of the `core` library. See the @@ -61,7 +62,7 @@ mod upgrade; pub mod transport; -pub use endpoint::{Config, Endpoint}; +pub use endpoint::Config; pub use error::Error; pub use muxer::{QuicMuxer, Substream}; pub use transport::QuicTransport; diff --git a/transports/quic/src/transport.rs b/transports/quic/src/transport.rs index 6c06c38a6ff..71528be8499 100644 --- a/transports/quic/src/transport.rs +++ b/transports/quic/src/transport.rs @@ -22,20 +22,26 @@ //! //! Combines all the objects in the other modules to implement the trait. +use crate::connection::Connection; +use crate::Config; use crate::{endpoint::Endpoint, in_addr::InAddr, muxer::QuicMuxer, upgrade::Upgrade}; -use futures::prelude::*; use futures::stream::StreamExt; +use futures::{channel::mpsc, prelude::*, stream::SelectAll}; use if_watch::IfEvent; use libp2p_core::{ multiaddr::{Multiaddr, Protocol}, - transport::{ListenerEvent, TransportError}, + transport::{ListenerId, TransportError, TransportEvent}, PeerId, Transport, }; -use std::task::{Context, Poll}; -use std::{net::SocketAddr, pin::Pin, sync::Arc}; +use std::{ + net::SocketAddr, + pin::Pin, + sync::Arc, + task::{Context, Poll}, +}; // We reexport the errors that are exposed in the API. // All of these types use one another. @@ -45,25 +51,21 @@ pub use quinn_proto::{ TransportError as QuinnTransportError, TransportErrorCode, }; -/// Wraps around an `Arc` and implements the [`Transport`] trait. -/// -/// > **Note**: This type is necessary because Rust unfortunately forbids implementing the -/// > `Transport` trait directly on `Arc`. -#[derive(Debug, Clone)] +#[derive(Debug)] pub struct QuicTransport { - endpoint: Arc, - /// The IP addresses of network interfaces on which the listening socket - /// is accepting connections. - /// - /// If the listen socket listens on all interfaces, these may change over - /// time as interfaces become available or unavailable. - in_addr: InAddr, + config: Config, + listeners: SelectAll, + /// Endpoint to use if no listener exists. + dialer: Option>, } impl QuicTransport { - pub fn new(endpoint: Arc) -> Self { - let in_addr = InAddr::new(endpoint.local_addr.ip()); - Self { endpoint, in_addr } + pub fn new(config: Config) -> Self { + Self { + listeners: SelectAll::new(), + config, + dialer: None, + } } } @@ -76,48 +78,76 @@ pub enum Error { /// Error after the remote has been reached. #[error("{0}")] Established(Libp2pQuicConnectionError), - /// Error while working with IfWatcher. + #[error("{0}")] - IfWatcher(std::io::Error), + Io(#[from] std::io::Error), + + #[error("Background task crashed.")] + TaskCrashed, } impl Transport for QuicTransport { type Output = (PeerId, QuicMuxer); type Error = Error; - // type Listener = Pin< - // Box, Self::Error>> + Send>, - // >; - type Listener = Self; type ListenerUpgrade = Upgrade; type Dial = Pin> + Send>>; - fn listen_on( - &mut self, - addr: Multiaddr, - ) -> Result> { - multiaddr_to_socketaddr(&addr) - .ok_or_else(|| TransportError::MultiaddrNotSupported(addr))?; - Ok(self.clone()) + fn listen_on(&mut self, addr: Multiaddr) -> Result> { + let socket_addr = + multiaddr_to_socketaddr(&addr).ok_or(TransportError::MultiaddrNotSupported(addr))?; + let listener_id = ListenerId::new(); + let listener = Listener::new(listener_id, socket_addr, self.config.clone()) + .map_err(TransportError::Other)?; + self.listeners.push(listener); + // Drop reference to dialer endpoint so that the endpoint is dropped once the last + // connection that uses it is closed. + // New outbound connections will use a bidirectional (listener) endpoint. + let _ = self.dialer.take(); + Ok(listener_id) } - fn address_translation(&self, server: &Multiaddr, observed: &Multiaddr) -> Option { + fn remove_listener(&mut self, id: ListenerId) -> bool { + if let Some(listener) = self.listeners.iter_mut().find(|l| l.listener_id == id) { + listener.close(Ok(())); + true + } else { + false + } + } + + fn address_translation(&self, _server: &Multiaddr, observed: &Multiaddr) -> Option { Some(observed.clone()) } fn dial(&mut self, addr: Multiaddr) -> Result> { - let socket_addr = if let Some(socket_addr) = multiaddr_to_socketaddr(&addr) { - if socket_addr.port() == 0 || socket_addr.ip().is_unspecified() { - tracing::error!("multiaddr not supported"); - return Err(TransportError::MultiaddrNotSupported(addr)); - } - socket_addr - } else { + let socket_addr = multiaddr_to_socketaddr(&addr) + .ok_or_else(|| TransportError::MultiaddrNotSupported(addr.clone()))?; + if socket_addr.port() == 0 || socket_addr.ip().is_unspecified() { tracing::error!("multiaddr not supported"); return Err(TransportError::MultiaddrNotSupported(addr)); + } + let endpoint = if self.listeners.is_empty() { + match self.dialer.clone() { + Some(endpoint) => endpoint, + None => { + let endpoint = + Endpoint::new_dialer(self.config.clone()).map_err(TransportError::Other)?; + let _ = self.dialer.insert(endpoint.clone()); + endpoint + } + } + } else { + // Pick a random listener to use for dialing. + // TODO: Prefer listeners with same IP version. + let n = rand::random::() % self.listeners.len(); + let listener = self + .listeners + .iter_mut() + .nth(n) + .expect("Can not be out of bound."); + listener.endpoint.clone() }; - let endpoint = self.endpoint.clone(); - Ok(async move { let connection = endpoint.dial(socket_addr).await.map_err(Error::Reach)?; let final_connec = Upgrade::from_connection(connection).await?; @@ -136,37 +166,109 @@ impl Transport for QuicTransport { // https://github.com/libp2p/specs/blob/master/relay/DCUtR.md#the-protocol self.dial(addr) } + + fn poll( + mut self: Pin<&mut Self>, + cx: &mut Context<'_>, + ) -> Poll> { + match self.listeners.poll_next_unpin(cx) { + Poll::Ready(Some(ev)) => Poll::Ready(ev), + _ => Poll::Pending, + } + } } -impl Stream for QuicTransport { - type Item = Result, Error>; +#[derive(Debug)] +struct Listener { + endpoint: Arc, + + listener_id: ListenerId, - fn poll_next(self: Pin<&mut Self>, cx: &mut Context) -> Poll> { - let me = Pin::into_inner(self); - let endpoint = me.endpoint.as_ref(); + /// Channel where new connections are being sent. + new_connections_rx: mpsc::Receiver, - // Poll for a next IfEvent - match me.in_addr.poll_next_unpin(cx) { + /// The IP addresses of network interfaces on which the listening socket + /// is accepting connections. + /// + /// If the listen socket listens on all interfaces, these may change over + /// time as interfaces become available or unavailable. + in_addr: InAddr, + + /// Set to `Some` if this [`Listener`] should close. + /// Optionally contains a [`TransportEvent::ListenerClosed`] that should be + /// reported before the listener's stream is terminated. + report_closed: Option::Item>>, +} + +impl Listener { + fn new( + listener_id: ListenerId, + socket_addr: SocketAddr, + config: Config, + ) -> Result { + let in_addr = InAddr::new(socket_addr.ip()); + let (endpoint, new_connections_rx) = Endpoint::new_bidirectional(config, socket_addr)?; + Ok(Listener { + endpoint, + listener_id, + new_connections_rx, + in_addr, + report_closed: None, + }) + } + + /// Report the listener as closed in a [`TransportEvent::ListenerClosed`] and + /// terminate the stream. + fn close(&mut self, reason: Result<(), Error>) { + match self.report_closed { + Some(_) => tracing::debug!("Listener was already closed."), + None => { + // Report the listener event as closed. + let _ = self + .report_closed + .insert(Some(TransportEvent::ListenerClosed { + listener_id: self.listener_id, + reason, + })); + } + } + } + + /// Poll for a next If Event. + fn poll_if_addr(&mut self, cx: &mut Context<'_>) -> Option<::Item> { + match self.in_addr.poll_next_unpin(cx) { Poll::Ready(mut item) => { if let Some(item) = item.take() { // Consume all events for up/down interface changes. match item { Ok(IfEvent::Up(inet)) => { let ip = inet.addr(); - if endpoint.local_addr.is_ipv4() == ip.is_ipv4() { - let socket_addr = SocketAddr::new(ip, endpoint.local_addr.port()); + if self.endpoint.socket_addr().is_ipv4() == ip.is_ipv4() { + let socket_addr = + SocketAddr::new(ip, self.endpoint.socket_addr().port()); let ma = socketaddr_to_multiaddr(&socket_addr); tracing::debug!("New listen address: {}", ma); - return Poll::Ready(Some(Ok(ListenerEvent::NewAddress(ma)))); + Some(TransportEvent::NewAddress { + listener_id: self.listener_id, + listen_addr: ma, + }) + } else { + self.poll_if_addr(cx) } } Ok(IfEvent::Down(inet)) => { let ip = inet.addr(); - if endpoint.local_addr.is_ipv4() == ip.is_ipv4() { - let socket_addr = SocketAddr::new(ip, endpoint.local_addr.port()); + if self.endpoint.socket_addr().is_ipv4() == ip.is_ipv4() { + let socket_addr = + SocketAddr::new(ip, self.endpoint.socket_addr().port()); let ma = socketaddr_to_multiaddr(&socket_addr); tracing::debug!("Expired listen address: {}", ma); - return Poll::Ready(Some(Ok(ListenerEvent::AddressExpired(ma)))); + Some(TransportEvent::AddressExpired { + listener_id: self.listener_id, + listen_addr: ma, + }) + } else { + self.poll_if_addr(cx) } } Err(err) => { @@ -174,37 +276,56 @@ impl Stream for QuicTransport { "Failure polling interfaces: {:?}.", err }; - return Poll::Ready(Some(Ok(ListenerEvent::Error(Error::IfWatcher( - err, - ))))); + Some(TransportEvent::Error { + listener_id: self.listener_id, + error: err.into(), + }) } } + } else { + self.poll_if_addr(cx) } } - Poll::Pending => { - // continue polling endpoint - } + Poll::Pending => None, } + } +} - let connection = match endpoint.poll_incoming(cx) { - Poll::Ready(Some(connection)) => connection, - Poll::Ready(None) => return Poll::Ready(None), - Poll::Pending => return Poll::Pending, +impl Stream for Listener { + type Item = TransportEvent<::ListenerUpgrade, Error>; + fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { + if let Some(closed) = self.report_closed.as_mut() { + // Listener was closed. + // Report the transport event if there is one. On the next iteration, return + // `Poll::Ready(None)` to terminate the stream. + return Poll::Ready(closed.take()); + } + if let Some(event) = self.poll_if_addr(cx) { + return Poll::Ready(Some(event)); + } + let connection = match futures::ready!(self.new_connections_rx.poll_next_unpin(cx)) { + Some(c) => c, + None => { + self.close(Err(Error::TaskCrashed)); + return self.poll_next(cx); + } }; + let local_addr = socketaddr_to_multiaddr(&connection.local_addr()); - let remote_addr = socketaddr_to_multiaddr(&connection.remote_addr()); - let event = ListenerEvent::Upgrade { + let send_back_addr = socketaddr_to_multiaddr(&connection.remote_addr()); + let event = TransportEvent::Incoming { upgrade: Upgrade::from_connection(connection), local_addr, - remote_addr, + send_back_addr, + listener_id: self.listener_id, }; - Poll::Ready(Some(Ok(event))) + Poll::Ready(Some(event)) } } /// Tries to turn a QUIC multiaddress into a UDP [`SocketAddr`]. Returns None if the format /// of the multiaddr is wrong. -pub(crate) fn multiaddr_to_socketaddr(addr: &Multiaddr) -> Option { +pub fn multiaddr_to_socketaddr(addr: &Multiaddr) -> Option { let mut iter = addr.iter(); let proto1 = iter.next()?; let proto2 = iter.next()?; diff --git a/transports/quic/tests/smoke.rs b/transports/quic/tests/smoke.rs index 92ec9fccbe7..dbd66815406 100644 --- a/transports/quic/tests/smoke.rs +++ b/transports/quic/tests/smoke.rs @@ -2,22 +2,22 @@ use anyhow::Result; use async_trait::async_trait; use futures::future::FutureExt; use futures::io::{AsyncRead, AsyncWrite, AsyncWriteExt}; +use futures::select; use futures::stream::StreamExt; -use libp2p::core::upgrade; +use futures::task::Spawn; +use libp2p::core::multiaddr::Protocol; +use libp2p::core::muxing::StreamMuxerBox; +use libp2p::core::{upgrade, ConnectedPoint, Transport}; use libp2p::request_response::{ ProtocolName, ProtocolSupport, RequestResponse, RequestResponseCodec, RequestResponseConfig, RequestResponseEvent, RequestResponseMessage, }; -use libp2p::swarm::{Swarm, SwarmBuilder, SwarmEvent}; -use libp2p::{Multiaddr, Transport}; -use libp2p_core::muxing::StreamMuxerBox; -use libp2p_quic::{Config as QuicConfig, Endpoint as QuicEndpoint, QuicTransport}; +use libp2p::swarm::dial_opts::{DialOpts, PeerCondition}; +use libp2p::swarm::{DialError, Swarm, SwarmEvent}; +use libp2p_quic::{Config as QuicConfig, QuicTransport}; use rand::RngCore; -use std::{io, iter}; - -use futures::task::Spawn; use std::num::NonZeroU8; -use std::time::Duration; +use std::{io, iter}; fn generate_tls_keypair() -> libp2p::identity::Keypair { libp2p::identity::Keypair::generate_ed25519() @@ -27,10 +27,8 @@ fn generate_tls_keypair() -> libp2p::identity::Keypair { async fn create_swarm(keylog: bool) -> Result>> { let keypair = generate_tls_keypair(); let peer_id = keypair.public().to_peer_id(); - let addr: Multiaddr = "/ip4/127.0.0.1/udp/0/quic".parse()?; - let config = QuicConfig::new(&keypair, addr).unwrap(); - let endpoint = QuicEndpoint::new(config).unwrap(); - let transport = QuicTransport::new(endpoint); + let config = QuicConfig::new(&keypair).unwrap(); + let transport = QuicTransport::new(config); // TODO: // transport @@ -442,3 +440,105 @@ fn concurrent_connections_and_streams() { // QuickCheck::new().quickcheck(prop as fn(_, _) -> _); } + +#[async_std::test] +async fn endpoint_reuse() -> Result<()> { + setup_global_subscriber(); + + let mut swarm_a = create_swarm(false).await?; + let mut swarm_b = create_swarm(false).await?; + let b_peer_id = *swarm_b.local_peer_id(); + + swarm_a.listen_on("/ip4/127.0.0.1/udp/0/quic".parse()?)?; + let a_addr = match swarm_a.next().await { + Some(SwarmEvent::NewListenAddr { address, .. }) => address, + e => panic!("{:?}", e), + }; + + swarm_b.dial(a_addr.clone()).unwrap(); + let b_send_back_addr = loop { + select! { + ev = swarm_a.select_next_some() => match ev { + SwarmEvent::ConnectionEstablished { endpoint, .. } => { + break endpoint.get_remote_address().clone() + } + SwarmEvent::IncomingConnection { local_addr, ..} => { + assert!(swarm_a.listeners().any(|a| a == &local_addr)); + } + e => panic!("{:?}", e), + }, + ev = swarm_b.select_next_some() => match ev { + SwarmEvent::ConnectionEstablished { .. } => {}, + e => panic!("{:?}", e), + } + } + }; + + let dial_opts = DialOpts::peer_id(b_peer_id) + .addresses(vec![b_send_back_addr.clone()]) + .extend_addresses_through_behaviour() + .condition(PeerCondition::Always) + .build(); + swarm_a.dial(dial_opts).unwrap(); + + // Expect the dial to fail since b is not listening on an address. + loop { + select! { + ev = swarm_a.select_next_some() => match ev { + SwarmEvent::ConnectionEstablished { ..} => panic!("Unexpected dial success."), + SwarmEvent::OutgoingConnectionError {error, .. } => { + assert!(matches!(error, DialError::Transport(_))); + break + } + _ => {} + }, + _ = swarm_b.select_next_some() => {}, + } + } + swarm_b.listen_on("/ip4/127.0.0.1/udp/0/quic".parse()?)?; + let b_addr = match swarm_b.next().await { + Some(SwarmEvent::NewListenAddr { address, .. }) => address, + e => panic!("{:?}", e), + }; + + let dial_opts = DialOpts::peer_id(b_peer_id) + .addresses(vec![b_addr.clone(), b_send_back_addr]) + .condition(PeerCondition::Always) + .build(); + swarm_a.dial(dial_opts).unwrap(); + let expected_b_addr = b_addr.with(Protocol::P2p(b_peer_id.into())); + + let mut a_reported = false; + let mut b_reported = false; + while !a_reported || !b_reported { + select! { + ev = swarm_a.select_next_some() => match ev{ + SwarmEvent::ConnectionEstablished { endpoint, ..} => { + assert!(endpoint.is_dialer()); + assert_eq!(endpoint.get_remote_address(), &expected_b_addr); + a_reported = true; + } + SwarmEvent::OutgoingConnectionError {error, .. } => { + panic!("Unexpected error {:}", error) + } + _ => {} + }, + ev = swarm_b.select_next_some() => match ev{ + SwarmEvent::ConnectionEstablished { endpoint, ..} => { + match endpoint { + ConnectedPoint::Dialer{..} => panic!("Unexpected outbound connection"), + ConnectedPoint::Listener {send_back_addr, local_addr} => { + // Expect that the local listening endpoint was used for dialing. + assert!(swarm_b.listeners().any(|a| a == &local_addr)); + assert_eq!(send_back_addr, a_addr); + b_reported = true; + } + } + } + _ => {} + }, + } + } + + Ok(()) +} diff --git a/transports/tcp/Cargo.toml b/transports/tcp/Cargo.toml index 67e28b35ff7..c919d633fb5 100644 --- a/transports/tcp/Cargo.toml +++ b/transports/tcp/Cargo.toml @@ -20,6 +20,8 @@ ipnet = "2.0.0" libc = "0.2.80" libp2p-core = { version = "0.33.0", path = "../../core", default-features = false } log = "0.4.11" +pin-project = "1.0.0" +smallvec = "1.6.1" socket2 = { version = "0.4.0", features = ["all"] } tokio-crate = { package = "tokio", version = "1.0.1", default-features = false, features = ["net"], optional = true } diff --git a/transports/tcp/src/lib.rs b/transports/tcp/src/lib.rs index a2650f7d216..57255b9e044 100644 --- a/transports/tcp/src/lib.rs +++ b/transports/tcp/src/lib.rs @@ -22,7 +22,7 @@ //! //! # Usage //! -//! This crate provides a `TcpConfig` and `TokioTcpConfig`, depending on +//! This crate provides a `TcpTransport` and `TokioTcpTransport`, depending on //! the enabled features, which implement the `Transport` trait for use as a //! transport with `libp2p-core` or `libp2p-swarm`. @@ -31,16 +31,16 @@ mod provider; #[cfg(feature = "async-io")] pub use provider::async_io; -/// The type of a [`GenTcpConfig`] using the `async-io` implementation. +/// The type of a [`GenTcpTransport`] using the `async-io` implementation. #[cfg(feature = "async-io")] -pub type TcpConfig = GenTcpConfig; +pub type TcpTransport = GenTcpTransport; #[cfg(feature = "tokio")] pub use provider::tokio; -/// The type of a [`GenTcpConfig`] using the `tokio` implementation. +/// The type of a [`GenTcpTransport`] using the `tokio` implementation. #[cfg(feature = "tokio")] -pub type TokioTcpConfig = GenTcpConfig; +pub type TokioTcpTransport = GenTcpTransport; use futures::{ future::{self, BoxFuture, Ready}, @@ -51,11 +51,11 @@ use futures_timer::Delay; use libp2p_core::{ address_translation, multiaddr::{Multiaddr, Protocol}, - transport::{ListenerEvent, Transport, TransportError}, + transport::{ListenerId, Transport, TransportError, TransportEvent}, }; use socket2::{Domain, Socket, Type}; use std::{ - collections::HashSet, + collections::{HashSet, VecDeque}, io, net::{IpAddr, Ipv4Addr, Ipv6Addr, SocketAddr, TcpListener}, pin::Pin, @@ -67,18 +67,16 @@ use std::{ use provider::{IfEvent, Provider}; /// The configuration for a TCP/IP transport capability for libp2p. -#[derive(Debug)] -pub struct GenTcpConfig { - /// The type of the I/O provider. - _impl: std::marker::PhantomData, +#[derive(Clone, Debug)] +pub struct GenTcpConfig { /// TTL to set for opened sockets, or `None` to keep default. ttl: Option, /// `TCP_NODELAY` to set for opened sockets, or `None` to keep default. nodelay: Option, /// Size of the listen backlog for listen sockets. backlog: u32, - /// The configuration of port reuse when dialing. - port_reuse: PortReuse, + /// Whether port reuse should be enabled. + enable_port_reuse: bool, } type Port = u16; @@ -159,10 +157,7 @@ impl PortReuse { } } -impl GenTcpConfig -where - T: Provider + Send, -{ +impl GenTcpConfig { /// Creates a new configuration for a TCP/IP transport: /// /// * Nagle's algorithm, i.e. `TCP_NODELAY`, is _enabled_. @@ -178,8 +173,7 @@ where ttl: None, nodelay: None, backlog: 1024, - port_reuse: PortReuse::Disabled, - _impl: std::marker::PhantomData, + enable_port_reuse: false, } } @@ -238,29 +232,29 @@ where /// > a single outgoing connection to a particular address and port /// > of a peer per local listening socket address. /// - /// `GenTcpConfig` keeps track of the listen socket addresses as they - /// are reported by polling [`TcpListenStream`]s obtained from - /// [`GenTcpConfig::listen_on()`]. It is possible to listen on multiple + /// [`GenTcpTransport`] keeps track of the listen socket addresses as they + /// are reported by polling it. It is possible to listen on multiple /// addresses, enabling port reuse for each, knowing exactly which listen - /// address is reused when dialing with a specific `GenTcpConfig`, as in the + /// address is reused when dialing with a specific `GenTcpTransport`, as in the /// following example: /// /// ```no_run - /// # use libp2p_core::transport::ListenerEvent; + /// # use futures::StreamExt; + /// # use libp2p_core::transport::{ListenerId, TransportEvent}; /// # use libp2p_core::{Multiaddr, Transport}; - /// # use futures::stream::StreamExt; + /// # use std::pin::Pin; /// #[cfg(feature = "async-io")] /// #[async_std::main] /// async fn main() -> std::io::Result<()> { - /// use libp2p_tcp::TcpConfig; + /// use libp2p_tcp::{GenTcpConfig, TcpTransport}; /// /// let listen_addr1: Multiaddr = "/ip4/127.0.0.1/tcp/9001".parse().unwrap(); /// let listen_addr2: Multiaddr = "/ip4/127.0.0.1/tcp/9002".parse().unwrap(); /// - /// let mut tcp1 = TcpConfig::new().port_reuse(true); - /// let mut listener1 = tcp1.listen_on(listen_addr1.clone()).expect("listener"); - /// match listener1.next().await.expect("event")? { - /// ListenerEvent::NewAddress(listen_addr) => { + /// let mut tcp1 = TcpTransport::new(GenTcpConfig::new().port_reuse(true)).boxed(); + /// tcp1.listen_on( listen_addr1.clone()).expect("listener"); + /// match tcp1.select_next_some().await { + /// TransportEvent::NewAddress { listen_addr, .. } => { /// println!("Listening on {:?}", listen_addr); /// let mut stream = tcp1.dial(listen_addr2.clone()).unwrap().await?; /// // `stream` has `listen_addr1` as its local socket address. @@ -268,10 +262,10 @@ where /// _ => {} /// } /// - /// let mut tcp2 = TcpConfig::new().port_reuse(true); - /// let mut listener2 = tcp2.listen_on(listen_addr2).expect("listener"); - /// match listener2.next().await.expect("event")? { - /// ListenerEvent::NewAddress(listen_addr) => { + /// let mut tcp2 = TcpTransport::new(GenTcpConfig::new().port_reuse(true)).boxed(); + /// tcp2.listen_on( listen_addr2).expect("listener"); + /// match tcp2.select_next_some().await { + /// TransportEvent::NewAddress { listen_addr, .. } => { /// println!("Listening on {:?}", listen_addr); /// let mut socket = tcp2.dial(listen_addr1).unwrap().await?; /// // `stream` has `listen_addr2` as its local socket address. @@ -287,7 +281,7 @@ where /// case, one is chosen whose IP protocol version and loopback status is the /// same as that of the remote address. Consequently, for maximum control of /// the local listening addresses and ports that are used for outgoing - /// connections, a new `GenTcpConfig` should be created for each listening + /// connections, a new `GenTcpTransport` should be created for each listening /// socket, avoiding the use of wildcard addresses which bind a socket to /// all network interfaces. /// @@ -295,15 +289,50 @@ where /// option `SO_REUSEPORT` is set, if available, to permit /// reuse of listening ports for multiple sockets. pub fn port_reuse(mut self, port_reuse: bool) -> Self { - self.port_reuse = if port_reuse { + self.enable_port_reuse = port_reuse; + self + } +} + +impl Default for GenTcpConfig { + fn default() -> Self { + Self::new() + } +} + +pub struct GenTcpTransport +where + T: Provider + Send, +{ + config: GenTcpConfig, + + /// The configuration of port reuse when dialing. + port_reuse: PortReuse, + /// All the active listeners. + /// The `TcpListenStream` struct contains a stream that we want to be pinned. Since the `VecDeque` + /// can be resized, the only way is to use a `Pin>`. + listeners: VecDeque>>>, + /// Pending listeners events to return from [`GenTcpTransport::poll`]. + pending_events: VecDeque::ListenerUpgrade, io::Error>>, +} + +impl GenTcpTransport +where + T: Provider + Send, +{ + pub fn new(config: GenTcpConfig) -> Self { + let port_reuse = if config.enable_port_reuse { PortReuse::Enabled { listen_addrs: Arc::new(RwLock::new(HashSet::new())), } } else { PortReuse::Disabled }; - - self + GenTcpTransport { + config, + port_reuse, + ..Default::default() + } } fn create_socket(&self, socket_addr: &SocketAddr) -> io::Result { @@ -316,10 +345,10 @@ where if socket_addr.is_ipv6() { socket.set_only_v6(true)?; } - if let Some(ttl) = self.ttl { + if let Some(ttl) = self.config.ttl { socket.set_ttl(ttl)?; } - if let Some(nodelay) = self.nodelay { + if let Some(nodelay) = self.config.nodelay { socket.set_nodelay(nodelay)?; } socket.set_reuse_address(true)?; @@ -330,22 +359,42 @@ where Ok(socket) } - fn do_listen(&mut self, socket_addr: SocketAddr) -> io::Result> { + fn do_listen( + &mut self, + id: ListenerId, + socket_addr: SocketAddr, + ) -> io::Result> { let socket = self.create_socket(&socket_addr)?; socket.bind(&socket_addr.into())?; - socket.listen(self.backlog as _)?; + socket.listen(self.config.backlog as _)?; socket.set_nonblocking(true)?; - TcpListenStream::::new(socket.into(), self.port_reuse.clone()) + TcpListenStream::::new(id, socket.into(), self.port_reuse.clone()) } } -impl Default for GenTcpConfig { +impl Default for GenTcpTransport +where + T: Provider + Send, +{ fn default() -> Self { - Self::new() + let config = GenTcpConfig::default(); + let port_reuse = if config.enable_port_reuse { + PortReuse::Enabled { + listen_addrs: Arc::new(RwLock::new(HashSet::new())), + } + } else { + PortReuse::Disabled + }; + GenTcpTransport { + port_reuse, + config, + listeners: VecDeque::new(), + pending_events: VecDeque::new(), + } } } -impl Transport for GenTcpConfig +impl Transport for GenTcpTransport where T: Provider + Send + 'static, T::Listener: Unpin, @@ -355,20 +404,35 @@ where type Output = T::Stream; type Error = io::Error; type Dial = Pin> + Send>>; - type Listener = TcpListenStream; type ListenerUpgrade = Ready>; - fn listen_on( - &mut self, - addr: Multiaddr, - ) -> Result> { + fn listen_on(&mut self, addr: Multiaddr) -> Result> { let socket_addr = if let Ok(sa) = multiaddr_to_socketaddr(addr.clone()) { sa } else { return Err(TransportError::MultiaddrNotSupported(addr)); }; + let id = ListenerId::new(); log::debug!("listening on {}", socket_addr); - self.do_listen(socket_addr).map_err(TransportError::Other) + let listener = self + .do_listen(id, socket_addr) + .map_err(TransportError::Other)?; + self.listeners.push_back(Box::pin(listener)); + Ok(id) + } + + fn remove_listener(&mut self, id: ListenerId) -> bool { + if let Some(index) = self.listeners.iter().position(|l| l.listener_id != id) { + self.listeners.remove(index); + self.pending_events + .push_back(TransportEvent::ListenerClosed { + listener_id: id, + reason: Ok(()), + }); + true + } else { + false + } } fn dial(&mut self, addr: Multiaddr) -> Result> { @@ -441,9 +505,106 @@ where PortReuse::Enabled { .. } => Some(observed.clone()), } } + + // TODO: docs + fn poll( + mut self: Pin<&mut Self>, + cx: &mut Context<'_>, + ) -> Poll> { + // Return pending events from closed listeners. + if let Some(event) = self.pending_events.pop_front() { + return Poll::Ready(event); + } + // We remove each element from `listeners` one by one and add them back. + let mut remaining = self.listeners.len(); + while let Some(mut listener) = self.listeners.pop_back() { + match TryStream::try_poll_next(listener.as_mut(), cx) { + Poll::Pending => { + self.listeners.push_front(listener); + remaining -= 1; + if remaining == 0 { + break; + } + } + Poll::Ready(Some(Ok(TcpTransportEvent::Upgrade { + upgrade, + local_addr, + remote_addr, + }))) => { + let id = listener.listener_id; + self.listeners.push_front(listener); + return Poll::Ready(TransportEvent::Incoming { + listener_id: id, + upgrade, + local_addr, + send_back_addr: remote_addr, + }); + } + Poll::Ready(Some(Ok(TcpTransportEvent::NewAddress(a)))) => { + let id = listener.listener_id; + self.listeners.push_front(listener); + return Poll::Ready(TransportEvent::NewAddress { + listener_id: id, + listen_addr: a, + }); + } + Poll::Ready(Some(Ok(TcpTransportEvent::AddressExpired(a)))) => { + let id = listener.listener_id; + self.listeners.push_front(listener); + return Poll::Ready(TransportEvent::AddressExpired { + listener_id: id, + listen_addr: a, + }); + } + Poll::Ready(Some(Ok(TcpTransportEvent::Error(error)))) => { + let id = listener.listener_id; + self.listeners.push_front(listener); + return Poll::Ready(TransportEvent::Error { + listener_id: id, + error, + }); + } + Poll::Ready(None) => { + return Poll::Ready(TransportEvent::ListenerClosed { + listener_id: listener.listener_id, + reason: Ok(()), + }); + } + Poll::Ready(Some(Err(err))) => { + return Poll::Ready(TransportEvent::ListenerClosed { + listener_id: listener.listener_id, + reason: Err(err), + }); + } + } + } + + // We register the current task to be woken up if a new listener is added. + Poll::Pending + } } -type TcpListenerEvent = ListenerEvent>, io::Error>; +#[derive(Debug)] +pub enum TcpTransportEvent { + /// The transport is listening on a new additional [`Multiaddr`]. + NewAddress(Multiaddr), + /// An upgrade, consisting of the upgrade future, the listener address and the remote address. + Upgrade { + /// The upgrade. + upgrade: Ready>, + /// The local address which produced this upgrade. + local_addr: Multiaddr, + /// The remote address which produced this upgrade. + remote_addr: Multiaddr, + }, + /// A [`Multiaddr`] is no longer used for listening. + AddressExpired(Multiaddr), + /// A non-fatal error has happened on the listener. + /// + /// This event should be generated in order to notify the user that something wrong has + /// happened. The listener, however, continues to run. + Error(io::Error), +} enum IfWatch { Pending(BoxFuture<'static, io::Result>), @@ -469,6 +630,8 @@ pub struct TcpListenStream where T: Provider, { + /// The ID of this listener. + listener_id: ListenerId, /// The socket address that the listening socket is bound to, /// which may be a "wildcard address" like `INADDR_ANY` or `IN6ADDR_ANY` /// when listening on all interfaces for IPv4 respectively IPv6 connections. @@ -500,8 +663,12 @@ where T: Provider, { /// Constructs a `TcpListenStream` for incoming connections around - /// the given `TcpListener`. - fn new(listener: TcpListener, port_reuse: PortReuse) -> io::Result { + /// the given listener. + fn new( + listener_id: ListenerId, + listener: TcpListener, + port_reuse: PortReuse, + ) -> io::Result { let listen_addr = listener.local_addr()?; let in_addr = if match &listen_addr { @@ -526,6 +693,7 @@ where Ok(TcpListenStream { port_reuse, listener, + listener_id, listen_addr, in_addr, pause: None, @@ -569,7 +737,7 @@ where T::Stream: Unpin, T::IfWatcher: Unpin, { - type Item = Result, io::Error>; + type Item = Result, io::Error>; fn poll_next(self: Pin<&mut Self>, cx: &mut Context) -> Poll> { let me = Pin::into_inner(self); @@ -590,7 +758,7 @@ where }; *if_watch = IfWatch::Pending(T::if_watcher()); me.pause = Some(Delay::new(me.sleep_on_error)); - return Poll::Ready(Some(Ok(ListenerEvent::Error(err)))); + return Poll::Ready(Some(Ok(TcpTransportEvent::Error(err)))); } }, // Consume all events for up/down interface changes. @@ -604,9 +772,9 @@ where let ma = ip_to_multiaddr(ip, me.listen_addr.port()); log::debug!("New listen address: {}", ma); me.port_reuse.register(ip, me.listen_addr.port()); - return Poll::Ready(Some(Ok(ListenerEvent::NewAddress( - ma, - )))); + return Poll::Ready(Some(Ok( + TcpTransportEvent::NewAddress(ma), + ))); } } Ok(IfEvent::Down(inet)) => { @@ -617,7 +785,7 @@ where log::debug!("Expired listen address: {}", ma); me.port_reuse.unregister(ip, me.listen_addr.port()); return Poll::Ready(Some(Ok( - ListenerEvent::AddressExpired(ma), + TcpTransportEvent::AddressExpired(ma), ))); } } @@ -627,7 +795,7 @@ where err }; me.pause = Some(Delay::new(me.sleep_on_error)); - return Poll::Ready(Some(Ok(ListenerEvent::Error(err)))); + return Poll::Ready(Some(Ok(TcpTransportEvent::Error(err)))); } } } @@ -638,7 +806,7 @@ where InAddr::One { addr, out } => { if let Some(multiaddr) = out.take() { me.port_reuse.register(*addr, me.listen_addr.port()); - return Poll::Ready(Some(Ok(ListenerEvent::NewAddress(multiaddr)))); + return Poll::Ready(Some(Ok(TcpTransportEvent::NewAddress(multiaddr)))); } } } @@ -661,7 +829,7 @@ where // These errors are non-fatal for the listener stream. log::error!("error accepting incoming connection: {}", e); me.pause = Some(Delay::new(me.sleep_on_error)); - return Poll::Ready(Some(Ok(ListenerEvent::Error(e)))); + return Poll::Ready(Some(Ok(TcpTransportEvent::Error(e)))); } }; @@ -671,7 +839,7 @@ where log::debug!("Incoming connection from {} at {}", remote_addr, local_addr); - return Poll::Ready(Some(Ok(ListenerEvent::Upgrade { + return Poll::Ready(Some(Ok(TcpTransportEvent::Upgrade { upgrade: future::ok(incoming.stream), local_addr, remote_addr, @@ -718,7 +886,10 @@ fn ip_to_multiaddr(ip: IpAddr, port: u16) -> Multiaddr { #[cfg(test)] mod tests { use super::*; - use futures::channel::{mpsc, oneshot}; + use futures::{ + channel::{mpsc, oneshot}, + future::poll_fn, + }; #[test] fn multiaddr_to_tcp_conversion() { @@ -774,14 +945,14 @@ mod tests { env_logger::try_init().ok(); async fn listener(addr: Multiaddr, mut ready_tx: mpsc::Sender) { - let mut tcp = GenTcpConfig::::new(); - let mut listener = tcp.listen_on(addr).unwrap(); + let mut tcp = GenTcpTransport::::new(GenTcpConfig::new()).boxed(); + tcp.listen_on(addr).unwrap(); loop { - match listener.next().await.unwrap().unwrap() { - ListenerEvent::NewAddress(listen_addr) => { + match tcp.select_next_some().await { + TransportEvent::NewAddress { listen_addr, .. } => { ready_tx.send(listen_addr).await.unwrap(); } - ListenerEvent::Upgrade { upgrade, .. } => { + TransportEvent::Incoming { upgrade, .. } => { let mut upgrade = upgrade.await.unwrap(); let mut buf = [0u8; 3]; upgrade.read_exact(&mut buf).await.unwrap(); @@ -796,7 +967,7 @@ mod tests { async fn dialer(mut ready_rx: mpsc::Receiver) { let addr = ready_rx.next().await.unwrap(); - let mut tcp = GenTcpConfig::::new(); + let mut tcp = GenTcpTransport::::new(GenTcpConfig::new()); // Obtain a future socket through dialing let mut socket = tcp.dial(addr.clone()).unwrap().await.unwrap(); @@ -843,13 +1014,13 @@ mod tests { env_logger::try_init().ok(); async fn listener(addr: Multiaddr, mut ready_tx: mpsc::Sender) { - let mut tcp = GenTcpConfig::::new(); - let mut listener = tcp.listen_on(addr).unwrap(); + let mut tcp = GenTcpTransport::::new(GenTcpConfig::new()).boxed(); + tcp.listen_on(addr).unwrap(); loop { - match listener.next().await.unwrap().unwrap() { - ListenerEvent::NewAddress(a) => { - let mut iter = a.iter(); + match tcp.select_next_some().await { + TransportEvent::NewAddress { listen_addr, .. } => { + let mut iter = listen_addr.iter(); match iter.next().expect("ip address") { Protocol::Ip4(ip) => assert!(!ip.is_unspecified()), Protocol::Ip6(ip) => assert!(!ip.is_unspecified()), @@ -858,11 +1029,11 @@ mod tests { if let Protocol::Tcp(port) = iter.next().expect("port") { assert_ne!(0, port) } else { - panic!("No TCP port in address: {}", a) + panic!("No TCP port in address: {}", listen_addr) } - ready_tx.send(a).await.ok(); + ready_tx.send(listen_addr).await.ok(); } - ListenerEvent::Upgrade { .. } => { + TransportEvent::Incoming { .. } => { return; } _ => {} @@ -872,7 +1043,7 @@ mod tests { async fn dialer(mut ready_rx: mpsc::Receiver) { let dest_addr = ready_rx.next().await.unwrap(); - let mut tcp = GenTcpConfig::::new(); + let mut tcp = GenTcpTransport::::new(GenTcpConfig::new()); tcp.dial(dest_addr).unwrap().await.unwrap(); } @@ -916,22 +1087,22 @@ mod tests { mut ready_tx: mpsc::Sender, port_reuse_rx: oneshot::Receiver>, ) { - let mut tcp = GenTcpConfig::::new(); - let mut listener = tcp.listen_on(addr).unwrap(); + let mut tcp = GenTcpTransport::::new(GenTcpConfig::new()).boxed(); + tcp.listen_on(addr).unwrap(); loop { - match listener.next().await.unwrap().unwrap() { - ListenerEvent::NewAddress(listen_addr) => { + match tcp.select_next_some().await { + TransportEvent::NewAddress { listen_addr, .. } => { ready_tx.send(listen_addr).await.ok(); } - ListenerEvent::Upgrade { + TransportEvent::Incoming { upgrade, - local_addr: _, - mut remote_addr, + mut send_back_addr, + .. } => { // Receive the dialer tcp port reuse let remote_port_reuse = port_reuse_rx.await.unwrap(); // And check it is the same as the remote port used for upgrade - assert_eq!(remote_addr.pop().unwrap(), remote_port_reuse); + assert_eq!(send_back_addr.pop().unwrap(), remote_port_reuse); let mut upgrade = upgrade.await.unwrap(); let mut buf = [0u8; 3]; @@ -951,11 +1122,12 @@ mod tests { port_reuse_tx: oneshot::Sender>, ) { let dest_addr = ready_rx.next().await.unwrap(); - let mut tcp = GenTcpConfig::::new().port_reuse(true); - let mut listener = tcp.listen_on(addr).unwrap(); - match listener.next().await.unwrap().unwrap() { - ListenerEvent::NewAddress(_) => { + let mut tcp = GenTcpTransport::::new(GenTcpConfig::new().port_reuse(true)); + tcp.listen_on(addr).unwrap(); + match poll_fn(|cx| Pin::new(&mut tcp).poll(cx)).await { + TransportEvent::NewAddress { .. } => { // Check that tcp and listener share the same port reuse SocketAddr + let listener = tcp.listeners.front().unwrap(); let port_reuse_tcp = tcp.port_reuse.local_dial_addr(&listener.listen_addr.ip()); let port_reuse_listener = listener .port_reuse @@ -1018,11 +1190,13 @@ mod tests { env_logger::try_init().ok(); async fn listen_twice(addr: Multiaddr) { - let mut tcp = GenTcpConfig::::new().port_reuse(true); - let mut listener1 = tcp.listen_on(addr).unwrap(); - match listener1.next().await.unwrap().unwrap() { - ListenerEvent::NewAddress(addr1) => { - // Check that tcp and listener share the same port reuse SocketAddr + let mut tcp = GenTcpTransport::::new(GenTcpConfig::new().port_reuse(true)); + tcp.listen_on(addr).unwrap(); + match poll_fn(|cx| Pin::new(&mut tcp).poll(cx)).await { + TransportEvent::NewAddress { + listen_addr: addr1, .. + } => { + let listener1 = tcp.listeners.front().unwrap(); let port_reuse_tcp = tcp.port_reuse.local_dial_addr(&listener1.listen_addr.ip()); let port_reuse_listener1 = listener1 @@ -1032,9 +1206,11 @@ mod tests { assert_eq!(port_reuse_tcp, port_reuse_listener1); // Listen on the same address a second time. - let mut listener2 = tcp.listen_on(addr1.clone()).unwrap(); - match listener2.next().await.unwrap().unwrap() { - ListenerEvent::NewAddress(addr2) => { + tcp.listen_on(addr1.clone()).unwrap(); + match poll_fn(|cx| Pin::new(&mut tcp).poll(cx)).await { + TransportEvent::NewAddress { + listen_addr: addr2, .. + } => { assert_eq!(addr1, addr2); return; } @@ -1071,13 +1247,10 @@ mod tests { env_logger::try_init().ok(); async fn listen(addr: Multiaddr) -> Multiaddr { - GenTcpConfig::::new() - .listen_on(addr) - .unwrap() - .next() + let mut tcp = GenTcpTransport::::new(GenTcpConfig::new()).boxed(); + tcp.listen_on(addr).unwrap(); + tcp.select_next_some() .await - .expect("some event") - .expect("no error") .into_new_address() .expect("listen address") } @@ -1111,13 +1284,13 @@ mod tests { fn test(addr: Multiaddr) { #[cfg(feature = "async-io")] { - let mut tcp = TcpConfig::new(); + let mut tcp = TcpTransport::new(GenTcpConfig::new()); assert!(tcp.listen_on(addr.clone()).is_err()); } #[cfg(feature = "tokio")] { - let mut tcp = TokioTcpConfig::new(); + let mut tcp = TokioTcpTransport::new(GenTcpConfig::new()); assert!(tcp.listen_on(addr.clone()).is_err()); } } diff --git a/transports/uds/src/lib.rs b/transports/uds/src/lib.rs index 54d7a6f7ffa..e85d625ed1e 100644 --- a/transports/uds/src/lib.rs +++ b/transports/uds/src/lib.rs @@ -43,113 +43,194 @@ use futures::{ future::{BoxFuture, Ready}, prelude::*, }; +use libp2p_core::transport::ListenerId; use libp2p_core::{ multiaddr::{Multiaddr, Protocol}, - transport::{ListenerEvent, TransportError}, + transport::{TransportError, TransportEvent}, Transport, }; use log::debug; +use std::collections::VecDeque; +use std::pin::Pin; +use std::task::{Context, Poll}; use std::{io, path::PathBuf}; +pub type Listener = BoxStream< + 'static, + Result< + TransportEvent<::ListenerUpgrade, ::Error>, + Result<(), ::Error>, + >, +>; + macro_rules! codegen { ($feature_name:expr, $uds_config:ident, $build_listener:expr, $unix_stream:ty, $($mut_or_not:tt)*) => { + /// Represents the configuration for a Unix domain sockets transport capability for libp2p. + #[cfg_attr(docsrs, doc(cfg(feature = $feature_name)))] + pub struct $uds_config { + listeners: VecDeque<(ListenerId, Listener)>, + } -/// Represents the configuration for a Unix domain sockets transport capability for libp2p. -#[cfg_attr(docsrs, doc(cfg(feature = $feature_name)))] -#[derive(Debug, Clone)] -pub struct $uds_config { -} + impl $uds_config { + /// Creates a new configuration object for Unix domain sockets. + pub fn new() -> $uds_config { + $uds_config { + listeners: VecDeque::new(), + } + } + } -impl $uds_config { - /// Creates a new configuration object for Unix domain sockets. - pub fn new() -> $uds_config { - $uds_config {} - } -} + impl Default for $uds_config { + fn default() -> Self { + Self::new() + } + } -impl Default for $uds_config { - fn default() -> Self { - Self::new() - } -} + impl Transport for $uds_config { + type Output = $unix_stream; + type Error = io::Error; + type ListenerUpgrade = Ready>; + type Dial = BoxFuture<'static, Result>; -impl Transport for $uds_config { - type Output = $unix_stream; - type Error = io::Error; - type Listener = BoxStream<'static, Result, Self::Error>>; - type ListenerUpgrade = Ready>; - type Dial = BoxFuture<'static, Result>; + fn listen_on( + &mut self, + addr: Multiaddr, + ) -> Result> { + if let Ok(path) = multiaddr_to_path(&addr) { + let id = ListenerId::new(); + let listener = $build_listener(path) + .map_err(Err) + .map_ok(move |listener| { + stream::once({ + let addr = addr.clone(); + async move { + debug!("Now listening on {}", addr); + Ok(TransportEvent::NewAddress { + listener_id: id, + listen_addr: addr, + }) + } + }) + .chain(stream::unfold( + listener, + move |listener| { + let addr = addr.clone(); + async move { + let event = match listener.accept().await { + Ok((stream, _)) => { + debug!("incoming connection on {}", addr); + TransportEvent::Incoming { + upgrade: future::ok(stream), + local_addr: addr.clone(), + send_back_addr: addr.clone(), + listener_id: id, + } + } + Err(error) => TransportEvent::Error { + listener_id: id, + error, + }, + }; + Some((Ok(event), listener)) + } + }, + )) + }) + .try_flatten_stream() + .boxed(); + self.listeners.push_back((id, listener)); + Ok(id) + } else { + Err(TransportError::MultiaddrNotSupported(addr)) + } + } - fn listen_on(&mut self, addr: Multiaddr) -> Result> { - if let Ok(path) = multiaddr_to_path(&addr) { - Ok(async move { $build_listener(&path).await } - .map_ok(move |listener| { - stream::once({ - let addr = addr.clone(); - async move { - debug!("Now listening on {}", addr); - Ok(ListenerEvent::NewAddress(addr)) - } - }).chain(stream::unfold(listener, move |$($mut_or_not)* listener| { - let addr = addr.clone(); - async move { - let (stream, _) = match listener.accept().await { - Ok(v) => v, - Err(err) => return Some((Err(err), listener)) - }; - debug!("incoming connection on {}", addr); - let event = ListenerEvent::Upgrade { - upgrade: future::ok(stream), - local_addr: addr.clone(), - remote_addr: addr.clone() - }; - Some((Ok(event), listener)) - } - })) - }) - .try_flatten_stream() - .boxed()) - } else { - Err(TransportError::MultiaddrNotSupported(addr)) - } - } + fn remove_listener(&mut self, id: ListenerId) -> bool { + if let Some(index) = self + .listeners + .iter() + .position(|(listener_id, _)| listener_id == &id) + { + let listener_stream = self.listeners.get_mut(index).unwrap(); + let report_closed_stream = stream::once(async { Err(Ok(())) }).boxed(); + *listener_stream = (id, report_closed_stream); + true + } else { + false + } + } - fn dial(&mut self, addr: Multiaddr) -> Result> { - // TODO: Should we dial at all? - if let Ok(path) = multiaddr_to_path(&addr) { - debug!("Dialing {}", addr); - Ok(async move { <$unix_stream>::connect(&path).await }.boxed()) - } else { - Err(TransportError::MultiaddrNotSupported(addr)) - } - } + fn dial(&mut self, addr: Multiaddr) -> Result> { + // TODO: Should we dial at all? + if let Ok(path) = multiaddr_to_path(&addr) { + debug!("Dialing {}", addr); + Ok(async move { <$unix_stream>::connect(&path).await }.boxed()) + } else { + Err(TransportError::MultiaddrNotSupported(addr)) + } + } - fn dial_as_listener(&mut self, addr: Multiaddr) -> Result> { - self.dial(addr) - } + fn dial_as_listener( + &mut self, + addr: Multiaddr, + ) -> Result> { + self.dial(addr) + } - fn address_translation(&self, _server: &Multiaddr, _observed: &Multiaddr) -> Option { - None - } -} + fn address_translation( + &self, + _server: &Multiaddr, + _observed: &Multiaddr, + ) -> Option { + None + } -}; + fn poll( + mut self: Pin<&mut Self>, + cx: &mut Context<'_>, + ) -> Poll> { + let mut remaining = self.listeners.len(); + while let Some((id, mut listener)) = self.listeners.pop_back() { + let event = match Stream::poll_next(Pin::new(&mut listener), cx) { + Poll::Pending => None, + Poll::Ready(None) => panic!("Alive listeners always have a sender."), + Poll::Ready(Some(Ok(event))) => Some(event), + Poll::Ready(Some(Err(reason))) => { + return Poll::Ready(TransportEvent::ListenerClosed { + listener_id: id, + reason, + }) + } + }; + self.listeners.push_front((id, listener)); + if let Some(event) = event { + return Poll::Ready(event); + } else { + remaining -= 1; + if remaining == 0 { + break; + } + } + } + Poll::Pending + } + } + }; } #[cfg(feature = "async-std")] codegen!( "async-std", UdsConfig, - |addr| async move { async_std::os::unix::net::UnixListener::bind(addr).await }, + |addr| async move { async_std::os::unix::net::UnixListener::bind(&addr).await }, async_std::os::unix::net::UnixStream, ); #[cfg(feature = "tokio")] codegen!( "tokio", TokioUdsConfig, - |addr| async move { tokio::net::UnixListener::bind(addr) }, + |addr| async move { tokio::net::UnixListener::bind(&addr) }, tokio::net::UnixStream, - mut ); /// Turns a `Multiaddr` containing a single `Unix` component into a path. @@ -212,24 +293,22 @@ mod tests { let (tx, rx) = oneshot::channel(); async_std::task::spawn(async move { - let mut listener = UdsConfig::new().listen_on(addr).unwrap(); + let mut transport = UdsConfig::new().boxed(); + transport.listen_on(addr).unwrap(); - let listen_addr = listener - .try_next() + let listen_addr = transport + .select_next_some() .await - .unwrap() - .expect("some event") .into_new_address() .expect("listen address"); tx.send(listen_addr).unwrap(); - let (sock, _addr) = listener - .try_filter_map(|e| future::ok(e.into_upgrade())) - .try_next() + let (sock, _addr) = transport + .select_next_some() .await - .unwrap() - .expect("some event"); + .into_upgrade() + .expect("incoming stream"); let mut sock = sock.await.unwrap(); let mut buf = [0u8; 3]; diff --git a/transports/wasm-ext/src/lib.rs b/transports/wasm-ext/src/lib.rs index 64deb877858..a8e715291c1 100644 --- a/transports/wasm-ext/src/lib.rs +++ b/transports/wasm-ext/src/lib.rs @@ -32,10 +32,10 @@ //! module. //! -use futures::{future::Ready, prelude::*}; +use futures::{future::Ready, prelude::*, ready, stream::SelectAll}; use libp2p_core::{ connection::Endpoint, - transport::{ListenerEvent, TransportError}, + transport::{ListenerId, TransportError, TransportEvent}, Multiaddr, Transport, }; use parity_send_wrapper::SendWrapper; @@ -147,6 +147,7 @@ pub mod ffi { /// Implementation of `Transport` whose implementation is handled by some FFI. pub struct ExtTransport { inner: SendWrapper, + listeners: SelectAll, } impl ExtTransport { @@ -154,8 +155,10 @@ impl ExtTransport { pub fn new(transport: ffi::Transport) -> Self { ExtTransport { inner: SendWrapper::new(transport), + listeners: SelectAll::new(), } } + fn do_dial( &mut self, addr: Multiaddr, @@ -187,25 +190,13 @@ impl fmt::Debug for ExtTransport { } } -impl Clone for ExtTransport { - fn clone(&self) -> Self { - ExtTransport { - inner: SendWrapper::new(self.inner.clone().into()), - } - } -} - impl Transport for ExtTransport { type Output = Connection; type Error = JsErr; - type Listener = Listen; type ListenerUpgrade = Ready>; type Dial = Dial; - fn listen_on( - &mut self, - addr: Multiaddr, - ) -> Result> { + fn listen_on(&mut self, addr: Multiaddr) -> Result> { let iter = self.inner.listen_on(&addr.to_string()).map_err(|err| { if is_not_supported_error(&err) { TransportError::MultiaddrNotSupported(addr) @@ -213,12 +204,26 @@ impl Transport for ExtTransport { TransportError::Other(JsErr::from(err)) } })?; - - Ok(Listen { + let listener_id = ListenerId::new(); + let listen = Listen { + listener_id, iterator: SendWrapper::new(iter), next_event: None, pending_events: VecDeque::new(), - }) + is_closed: false, + }; + self.listeners.push(listen); + Ok(listener_id) + } + + fn remove_listener(&mut self, id: ListenerId) -> bool { + match self.listeners.iter_mut().find(|l| l.listener_id == id) { + Some(listener) => { + listener.close(Ok(())); + true + } + None => false, + } } fn dial(&mut self, addr: Multiaddr) -> Result> @@ -241,6 +246,16 @@ impl Transport for ExtTransport { fn address_translation(&self, _server: &Multiaddr, _observed: &Multiaddr) -> Option { None } + + fn poll( + mut self: Pin<&mut Self>, + cx: &mut Context<'_>, + ) -> Poll> { + match ready!(self.listeners.poll_next_unpin(cx)) { + Some(event) => return Poll::Ready(event), + None => Poll::Pending, + } + } } /// Future that dial a remote through an external transport. @@ -271,27 +286,47 @@ impl Future for Dial { /// Stream that listens for incoming connections through an external transport. #[must_use = "futures do nothing unless polled"] pub struct Listen { + listener_id: ListenerId, /// Iterator of `ListenEvent`s. iterator: SendWrapper, /// Promise that will yield the next `ListenEvent`. next_event: Option>, /// List of events that we are waiting to propagate. - pending_events: VecDeque>, JsErr>>, + pending_events: VecDeque<::Item>, + /// If the iterator is done close the listener. + is_closed: bool, +} + +impl Listen { + /// Report the listener as closed as terminate its stream. + fn close(&mut self, reason: Result<(), JsErr>) { + self.pending_events + .push_back(TransportEvent::ListenerClosed { + listener_id: self.listener_id, + reason, + }); + self.is_closed = true; + } } impl fmt::Debug for Listen { fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { - f.debug_tuple("Listen").finish() + f.debug_tuple("Listen").field(&self.listener_id).finish() } } impl Stream for Listen { - type Item = Result>, JsErr>, JsErr>; + type Item = TransportEvent<::ListenerUpgrade, JsErr>; fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { loop { if let Some(ev) = self.pending_events.pop_front() { - return Poll::Ready(Some(Ok(ev))); + return Poll::Ready(Some(ev)); + } + + if self.is_closed { + // Terminate the stream if the listener closed and all remaining events have been reported. + return Poll::Ready(None); } // Try to fill `self.next_event` if necessary and possible. If we fail, then @@ -309,30 +344,55 @@ impl Stream for Listen { let e = match Future::poll(Pin::new(&mut **next_event), cx) { Poll::Ready(Ok(ev)) => ffi::ListenEvent::from(ev), Poll::Pending => return Poll::Pending, - Poll::Ready(Err(err)) => return Poll::Ready(Some(Err(err.into()))), + Poll::Ready(Err(err)) => { + self.close(Err(err.into())); + continue; + } }; self.next_event = None; e } else { - return Poll::Ready(None); + self.close(Ok(())); + continue; }; + let listener_id = self.listener_id; + if let Some(addrs) = event.new_addrs() { for addr in addrs.iter() { - let addr = js_value_to_addr(addr)?; - self.pending_events - .push_back(ListenerEvent::NewAddress(addr)); + match js_value_to_addr(addr) { + Ok(addr) => self.pending_events.push_back(TransportEvent::NewAddress { + listener_id, + listen_addr: addr, + }), + Err(err) => self.pending_events.push_back(TransportEvent::Error { + listener_id, + error: err, + }), + }; } } if let Some(upgrades) = event.new_connections() { for upgrade in upgrades.iter().cloned() { let upgrade: ffi::ConnectionEvent = upgrade.into(); - self.pending_events.push_back(ListenerEvent::Upgrade { - local_addr: upgrade.local_addr().parse()?, - remote_addr: upgrade.observed_addr().parse()?, - upgrade: futures::future::ok(Connection::new(upgrade.connection())), - }); + match upgrade.local_addr().parse().and_then(|local| { + let observed = upgrade.observed_addr().parse()?; + Ok((local, observed)) + }) { + Ok((local_addr, send_back_addr)) => { + self.pending_events.push_back(TransportEvent::Incoming { + listener_id, + local_addr, + send_back_addr, + upgrade: futures::future::ok(Connection::new(upgrade.connection())), + }) + } + Err(err) => self.pending_events.push_back(TransportEvent::Error { + listener_id, + error: err.into(), + }), + } } } @@ -341,8 +401,14 @@ impl Stream for Listen { match js_value_to_addr(addr) { Ok(addr) => self .pending_events - .push_back(ListenerEvent::NewAddress(addr)), - Err(err) => self.pending_events.push_back(ListenerEvent::Error(err)), + .push_back(TransportEvent::AddressExpired { + listener_id, + listen_addr: addr, + }), + Err(err) => self.pending_events.push_back(TransportEvent::Error { + listener_id, + error: err, + }), } } } diff --git a/transports/websocket/Cargo.toml b/transports/websocket/Cargo.toml index e4bbe3297da..1e35f5b9a53 100644 --- a/transports/websocket/Cargo.toml +++ b/transports/websocket/Cargo.toml @@ -17,6 +17,7 @@ futures = "0.3.1" libp2p-core = { version = "0.33.0", path = "../../core", default-features = false } log = "0.4.8" parking_lot = "0.12.0" +pin-project = "1.0.10" quicksink = "0.1" rw-stream-sink = { version = "0.3.0", path = "../../misc/rw-stream-sink" } soketto = { version = "0.7.0", features = ["deflate"] } diff --git a/transports/websocket/src/framed.rs b/transports/websocket/src/framed.rs index c04e7354587..1261d46f1c3 100644 --- a/transports/websocket/src/framed.rs +++ b/transports/websocket/src/framed.rs @@ -26,7 +26,7 @@ use libp2p_core::{ connection::Endpoint, either::EitherOutput, multiaddr::{Multiaddr, Protocol}, - transport::{ListenerEvent, TransportError}, + transport::{ListenerId, TransportError, TransportEvent}, Transport, }; use log::{debug, trace}; @@ -36,7 +36,7 @@ use soketto::{ extension::deflate::Deflate, handshake, }; -use std::sync::Arc; +use std::{collections::HashMap, ops::DerefMut, sync::Arc}; use std::{convert::TryInto, fmt, io, mem, pin::Pin, task::Context, task::Poll}; use url::Url; @@ -53,18 +53,7 @@ pub struct WsConfig { tls_config: tls::Config, max_redirects: u8, use_deflate: bool, -} - -impl Clone for WsConfig { - fn clone(&self) -> Self { - Self { - transport: self.transport.clone(), - max_data_size: self.max_data_size, - tls_config: self.tls_config.clone(), - max_redirects: self.max_redirects, - use_deflate: self.use_deflate, - } - } + listener_protos: HashMap>, } impl WsConfig { @@ -76,6 +65,7 @@ impl WsConfig { tls_config: tls::Config::client(), max_redirects: 0, use_deflate: false, + listener_protos: HashMap::new(), } } @@ -118,149 +108,45 @@ type TlsOrPlain = EitherOutput, server::Tls impl Transport for WsConfig where - T: Transport + Send + 'static, + T: Transport + Send + Unpin + 'static, T::Error: Send + 'static, T::Dial: Send + 'static, - T::Listener: Send + 'static, T::ListenerUpgrade: Send + 'static, T::Output: AsyncRead + AsyncWrite + Unpin + Send + 'static, { type Output = Connection; type Error = Error; - type Listener = - BoxStream<'static, Result, Self::Error>>; type ListenerUpgrade = BoxFuture<'static, Result>; type Dial = BoxFuture<'static, Result>; - fn listen_on( - &mut self, - addr: Multiaddr, - ) -> Result> { + fn listen_on(&mut self, addr: Multiaddr) -> Result> { let mut inner_addr = addr.clone(); - - let (use_tls, proto) = match inner_addr.pop() { + let proto = match inner_addr.pop() { Some(p @ Protocol::Wss(_)) => { if self.tls_config.server.is_some() { - (true, p) + p } else { debug!("/wss address but TLS server support is not configured"); return Err(TransportError::MultiaddrNotSupported(addr)); } } - Some(p @ Protocol::Ws(_)) => (false, p), + Some(p @ Protocol::Ws(_)) => p, _ => { debug!("{} is not a websocket multiaddr", addr); return Err(TransportError::MultiaddrNotSupported(addr)); } }; + match self.transport.lock().listen_on(inner_addr) { + Ok(id) => { + self.listener_protos.insert(id, proto); + Ok(id) + } + Err(e) => Err(e.map(Error::Transport)), + } + } - let tls_config = self.tls_config.clone(); - let max_size = self.max_data_size; - let use_deflate = self.use_deflate; - let transport = self - .transport - .lock() - .listen_on(inner_addr) - .map_err(|e| e.map(Error::Transport))?; - let listen = transport - .map_err(Error::Transport) - .map_ok(move |event| match event { - ListenerEvent::NewAddress(mut a) => { - a = a.with(proto.clone()); - debug!("Listening on {}", a); - ListenerEvent::NewAddress(a) - } - ListenerEvent::AddressExpired(mut a) => { - a = a.with(proto.clone()); - ListenerEvent::AddressExpired(a) - } - ListenerEvent::Error(err) => ListenerEvent::Error(Error::Transport(err)), - ListenerEvent::Upgrade { - upgrade, - mut local_addr, - mut remote_addr, - } => { - local_addr = local_addr.with(proto.clone()); - remote_addr = remote_addr.with(proto.clone()); - let remote1 = remote_addr.clone(); // used for logging - let remote2 = remote_addr.clone(); // used for logging - let tls_config = tls_config.clone(); - - let upgrade = async move { - let stream = upgrade.map_err(Error::Transport).await?; - trace!("incoming connection from {}", remote1); - - let stream = if use_tls { - // begin TLS session - let server = tls_config - .server - .expect("for use_tls we checked server is not none"); - - trace!("awaiting TLS handshake with {}", remote1); - - let stream = server - .accept(stream) - .map_err(move |e| { - debug!("TLS handshake with {} failed: {}", remote1, e); - Error::Tls(tls::Error::from(e)) - }) - .await?; - - let stream: TlsOrPlain<_> = - EitherOutput::First(EitherOutput::Second(stream)); - - stream - } else { - // continue with plain stream - EitherOutput::Second(stream) - }; - - trace!("receiving websocket handshake request from {}", remote2); - - let mut server = handshake::Server::new(stream); - - if use_deflate { - server.add_extension(Box::new(Deflate::new(connection::Mode::Server))); - } - - let ws_key = { - let request = server - .receive_request() - .map_err(|e| Error::Handshake(Box::new(e))) - .await?; - request.key() - }; - - trace!("accepting websocket handshake request from {}", remote2); - - let response = handshake::server::Response::Accept { - key: ws_key, - protocol: None, - }; - - server - .send_response(&response) - .map_err(|e| Error::Handshake(Box::new(e))) - .await?; - - let conn = { - let mut builder = server.into_builder(); - builder.set_max_message_size(max_size); - builder.set_max_frame_size(max_size); - Connection::new(builder) - }; - - Ok(conn) - }; - - ListenerEvent::Upgrade { - upgrade: Box::pin(upgrade) as BoxFuture<'static, _>, - local_addr, - remote_addr, - } - } - }); - Ok(Box::pin(listen)) + fn remove_listener(&mut self, id: ListenerId) -> bool { + self.transport.lock().remove_listener(id) } fn dial(&mut self, addr: Multiaddr) -> Result> { @@ -277,14 +163,99 @@ where fn address_translation(&self, server: &Multiaddr, observed: &Multiaddr) -> Option { self.transport.lock().address_translation(server, observed) } + + fn poll( + mut self: Pin<&mut Self>, + cx: &mut Context<'_>, + ) -> Poll> { + let inner_event = { + let mut transport = self.transport.lock(); + match Transport::poll(Pin::new(transport.deref_mut()), cx) { + Poll::Ready(ev) => ev, + Poll::Pending => return Poll::Pending, + } + }; + let event = match inner_event { + TransportEvent::NewAddress { + listener_id, + mut listen_addr, + } => { + let proto = self + .listener_protos + .get(&listener_id) + .expect("Protocol was inserted in Transport::listen_on."); + listen_addr.push(proto.clone()); + debug!("Listening on {}", listen_addr); + TransportEvent::NewAddress { + listener_id, + listen_addr, + } + } + TransportEvent::AddressExpired { + listener_id, + mut listen_addr, + } => { + let proto = self + .listener_protos + .get(&listener_id) + .expect("Protocol was inserted in Transport::listen_on."); + listen_addr.push(proto.clone()); + TransportEvent::AddressExpired { + listener_id, + listen_addr, + } + } + TransportEvent::Error { listener_id, error } => TransportEvent::Error { + listener_id, + error: Error::Transport(error), + }, + TransportEvent::ListenerClosed { + listener_id, + reason, + } => { + self.listener_protos + .remove(&listener_id) + .expect("Protocol was inserted in Transport::listen_on."); + TransportEvent::ListenerClosed { + listener_id, + reason: reason.map_err(Error::Transport), + } + } + TransportEvent::Incoming { + listener_id, + upgrade, + mut local_addr, + mut send_back_addr, + } => { + let proto = self + .listener_protos + .get(&listener_id) + .expect("Protocol was inserted in Transport::listen_on."); + let use_tls = match proto { + Protocol::Wss(_) => true, + Protocol::Ws(_) => false, + _ => unreachable!("Map contains only ws and wss protocols."), + }; + local_addr.push(proto.clone()); + send_back_addr.push(proto.clone()); + let upgrade = self.map_upgrade(upgrade, send_back_addr.clone(), use_tls); + TransportEvent::Incoming { + listener_id, + upgrade, + local_addr, + send_back_addr, + } + } + }; + Poll::Ready(event) + } } impl WsConfig where - T: Transport + Send + 'static, + T: Transport + Send + Unpin + 'static, T::Error: Send + 'static, T::Dial: Send + 'static, - T::Listener: Send + 'static, T::ListenerUpgrade: Send + 'static, T::Output: AsyncRead + AsyncWrite + Unpin + Send + 'static, { @@ -304,13 +275,25 @@ where // We are looping here in order to follow redirects (if any): let mut remaining_redirects = self.max_redirects; - let mut this = self.clone(); + let transport = self.transport.clone(); + let tls_config = self.tls_config.clone(); + let use_deflate = self.use_deflate; + let max_redirects = self.max_redirects; + let future = async move { loop { - match this.dial_once(addr, role_override).await { + match Self::dial_once( + transport.clone(), + addr, + tls_config.clone(), + use_deflate, + role_override, + ) + .await + { Ok(Either::Left(redirect)) => { if remaining_redirects == 0 { - debug!("Too many redirects (> {})", this.max_redirects); + debug!("Too many redirects (> {})", max_redirects); return Err(Error::TooManyRedirects); } remaining_redirects -= 1; @@ -324,17 +307,20 @@ where Ok(Box::pin(future)) } + /// Attempts to dial the given address and perform a websocket handshake. async fn dial_once( - &mut self, + transport: Arc>, addr: WsAddress, + tls_config: tls::Config, + use_deflate: bool, role_override: Endpoint, ) -> Result>, Error> { trace!("Dialing websocket address: {:?}", addr); let dial = match role_override { - Endpoint::Dialer => self.transport.lock().dial(addr.tcp_addr), - Endpoint::Listener => self.transport.lock().dial_as_listener(addr.tcp_addr), + Endpoint::Dialer => transport.lock().dial(addr.tcp_addr), + Endpoint::Listener => transport.lock().dial_as_listener(addr.tcp_addr), } .map_err(|e| match e { TransportError::MultiaddrNotSupported(a) => Error::InvalidMultiaddr(a), @@ -350,8 +336,7 @@ where .dns_name .expect("for use_tls we have checked that dns_name is some"); trace!("Starting TLS handshake with {:?}", dns_name); - let stream = self - .tls_config + let stream = tls_config .client .connect(dns_name.clone(), stream) .map_err(|e| { @@ -371,7 +356,7 @@ where let mut client = handshake::Client::new(stream, &addr.host_port, addr.path.as_ref()); - if self.use_deflate { + if use_deflate { client.add_extension(Box::new(Deflate::new(connection::Mode::Client))); } @@ -400,6 +385,91 @@ where } } } + + fn map_upgrade( + &self, + upgrade: T::ListenerUpgrade, + remote_addr: Multiaddr, + use_tls: bool, + ) -> ::ListenerUpgrade { + let remote_addr2 = remote_addr.clone(); // used for logging + let tls_config = self.tls_config.clone(); + let max_size = self.max_data_size; + let use_deflate = self.use_deflate; + + async move { + let stream = upgrade.map_err(Error::Transport).await?; + trace!("incoming connection from {}", remote_addr); + + let stream = if use_tls { + // begin TLS session + let server = tls_config + .server + .expect("for use_tls we checked server is not none"); + + trace!("awaiting TLS handshake with {}", remote_addr); + + let stream = server + .accept(stream) + .map_err(move |e| { + debug!("TLS handshake with {} failed: {}", remote_addr, e); + Error::Tls(tls::Error::from(e)) + }) + .await?; + + let stream: TlsOrPlain<_> = EitherOutput::First(EitherOutput::Second(stream)); + + stream + } else { + // continue with plain stream + EitherOutput::Second(stream) + }; + + trace!( + "receiving websocket handshake request from {}", + remote_addr2 + ); + + let mut server = handshake::Server::new(stream); + + if use_deflate { + server.add_extension(Box::new(Deflate::new(connection::Mode::Server))); + } + + let ws_key = { + let request = server + .receive_request() + .map_err(|e| Error::Handshake(Box::new(e))) + .await?; + request.key() + }; + + trace!( + "accepting websocket handshake request from {}", + remote_addr2 + ); + + let response = handshake::server::Response::Accept { + key: ws_key, + protocol: None, + }; + + server + .send_response(&response) + .map_err(|e| Error::Handshake(Box::new(e))) + .await?; + + let conn = { + let mut builder = server.into_builder(); + builder.set_max_message_size(max_size); + builder.set_max_frame_size(max_size); + Connection::new(builder) + }; + + Ok(conn) + } + .boxed() + } } #[derive(Debug)] diff --git a/transports/websocket/src/lib.rs b/transports/websocket/src/lib.rs index 770e559a633..1897b64a2a0 100644 --- a/transports/websocket/src/lib.rs +++ b/transports/websocket/src/lib.rs @@ -26,14 +26,11 @@ pub mod tls; use error::Error; use framed::{Connection, Incoming}; -use futures::{future::BoxFuture, prelude::*, ready, stream::BoxStream}; +use futures::{future::BoxFuture, prelude::*, ready}; use libp2p_core::{ connection::ConnectedPoint, multiaddr::Multiaddr, - transport::{ - map::{MapFuture, MapStream}, - ListenerEvent, TransportError, - }, + transport::{map::MapFuture, ListenerId, TransportError, TransportEvent}, Transport, }; use rw_stream_sink::RwStreamSink; @@ -45,20 +42,21 @@ use std::{ /// A Websocket transport. #[derive(Debug)] +#[pin_project::pin_project] pub struct WsConfig where T: Transport, T::Output: AsyncRead + AsyncWrite + Send + Unpin + 'static, { + #[pin] transport: libp2p_core::transport::map::Map, WrapperFn>, } impl WsConfig where - T: Transport + Send + 'static, + T: Transport + Send + Unpin + 'static, T::Error: Send + 'static, T::Dial: Send + 'static, - T::Listener: Send + 'static, T::ListenerUpgrade: Send + 'static, T::Output: AsyncRead + AsyncWrite + Send + Unpin + 'static, { @@ -114,26 +112,25 @@ where impl Transport for WsConfig where - T: Transport + Send + 'static, + T: Transport + Send + Unpin + 'static, T::Error: Send + 'static, T::Dial: Send + 'static, - T::Listener: Send + 'static, T::ListenerUpgrade: Send + 'static, T::Output: AsyncRead + AsyncWrite + Unpin + Send + 'static, { type Output = RwStreamSink>; type Error = Error; - type Listener = MapStream, WrapperFn>; type ListenerUpgrade = MapFuture, WrapperFn>; type Dial = MapFuture, WrapperFn>; - fn listen_on( - &mut self, - addr: Multiaddr, - ) -> Result> { + fn listen_on(&mut self, addr: Multiaddr) -> Result> { self.transport.listen_on(addr) } + fn remove_listener(&mut self, id: ListenerId) -> bool { + self.transport.remove_listener(id) + } + fn dial(&mut self, addr: Multiaddr) -> Result> { self.transport.dial(addr) } @@ -148,11 +145,14 @@ where fn address_translation(&self, server: &Multiaddr, observed: &Multiaddr) -> Option { self.transport.address_translation(server, observed) } -} -/// Type alias corresponding to `framed::WsConfig::Listener`. -pub type InnerStream = - BoxStream<'static, Result, Error>, Error>>; + fn poll( + self: Pin<&mut Self>, + cx: &mut Context<'_>, + ) -> Poll> { + self.project().transport.poll(cx) + } +} /// Type alias corresponding to `framed::WsConfig::Dial` and `framed::WsConfig::ListenerUpgrade`. pub type InnerFuture = BoxFuture<'static, Result, Error>>; @@ -237,14 +237,15 @@ mod tests { } async fn connect(listen_addr: Multiaddr) { - let ws_config = || WsConfig::new(tcp::TcpConfig::new()); + let new_ws_config = + || WsConfig::new(tcp::TcpTransport::new(tcp::GenTcpConfig::default())).boxed(); - let mut listener = ws_config().listen_on(listen_addr).expect("listener"); + let mut ws_config = new_ws_config(); + ws_config.listen_on(listen_addr).expect("listener"); - let addr = listener - .try_next() + let addr = ws_config + .next() .await - .expect("some event") .expect("no error") .into_new_address() .expect("listen address"); @@ -253,16 +254,15 @@ mod tests { assert_ne!(Some(Protocol::Tcp(0)), addr.iter().nth(1)); let inbound = async move { - let (conn, _addr) = listener - .try_filter_map(|e| future::ready(Ok(e.into_upgrade()))) - .try_next() + let (conn, _addr) = ws_config + .select_next_some() + .map(|ev| ev.into_upgrade()) .await - .unwrap() .unwrap(); conn.await }; - let outbound = ws_config() + let outbound = new_ws_config() .dial(addr.with(Protocol::P2p(PeerId::random().into()))) .unwrap();