Skip to content

Commit

Permalink
refactor(core): blanket implementation of connection upgrade
Browse files Browse the repository at this point in the history
Introduces blanket implementation of `{In,Out}boundConnectionUpgrade` and uses it in transport upgrade infrastructure.

Resolves: #4307.

Pull-Request: #4316.
  • Loading branch information
PanGan21 authored and thomaseizinger committed Sep 21, 2023
1 parent 52f11b9 commit 5d0c4c7
Show file tree
Hide file tree
Showing 3 changed files with 98 additions and 38 deletions.
52 changes: 26 additions & 26 deletions core/src/transport/upgrade.rs
Original file line number Diff line number Diff line change
Expand Up @@ -30,8 +30,8 @@ use crate::{
TransportError, TransportEvent,
},
upgrade::{
self, apply_inbound, apply_outbound, InboundUpgrade, InboundUpgradeApply, OutboundUpgrade,
OutboundUpgradeApply, UpgradeError,
self, apply_inbound, apply_outbound, InboundConnectionUpgrade, InboundUpgradeApply,
OutboundConnectionUpgrade, OutboundUpgradeApply, UpgradeError,
},
Negotiated,
};
Expand Down Expand Up @@ -101,8 +101,8 @@ where
T: Transport<Output = C>,
C: AsyncRead + AsyncWrite + Unpin,
D: AsyncRead + AsyncWrite + Unpin,
U: InboundUpgrade<Negotiated<C>, Output = (PeerId, D), Error = E>,
U: OutboundUpgrade<Negotiated<C>, Output = (PeerId, D), Error = E> + Clone,
U: InboundConnectionUpgrade<Negotiated<C>, Output = (PeerId, D), Error = E>,
U: OutboundConnectionUpgrade<Negotiated<C>, Output = (PeerId, D), Error = E> + Clone,
E: Error + 'static,
{
let version = self.version;
Expand All @@ -123,7 +123,7 @@ where
pub struct Authenticate<C, U>
where
C: AsyncRead + AsyncWrite + Unpin,
U: InboundUpgrade<Negotiated<C>> + OutboundUpgrade<Negotiated<C>>,
U: InboundConnectionUpgrade<Negotiated<C>> + OutboundConnectionUpgrade<Negotiated<C>>,
{
#[pin]
inner: EitherUpgrade<C, U>,
Expand All @@ -132,11 +132,11 @@ where
impl<C, U> Future for Authenticate<C, U>
where
C: AsyncRead + AsyncWrite + Unpin,
U: InboundUpgrade<Negotiated<C>>
+ OutboundUpgrade<
U: InboundConnectionUpgrade<Negotiated<C>>
+ OutboundConnectionUpgrade<
Negotiated<C>,
Output = <U as InboundUpgrade<Negotiated<C>>>::Output,
Error = <U as InboundUpgrade<Negotiated<C>>>::Error,
Output = <U as InboundConnectionUpgrade<Negotiated<C>>>::Output,
Error = <U as InboundConnectionUpgrade<Negotiated<C>>>::Error,
>,
{
type Output = <EitherUpgrade<C, U> as Future>::Output;
Expand All @@ -155,7 +155,7 @@ where
pub struct Multiplex<C, U>
where
C: AsyncRead + AsyncWrite + Unpin,
U: InboundUpgrade<Negotiated<C>> + OutboundUpgrade<Negotiated<C>>,
U: InboundConnectionUpgrade<Negotiated<C>> + OutboundConnectionUpgrade<Negotiated<C>>,
{
peer_id: Option<PeerId>,
#[pin]
Expand All @@ -165,8 +165,8 @@ where
impl<C, U, M, E> Future for Multiplex<C, U>
where
C: AsyncRead + AsyncWrite + Unpin,
U: InboundUpgrade<Negotiated<C>, Output = M, Error = E>,
U: OutboundUpgrade<Negotiated<C>, Output = M, Error = E>,
U: InboundConnectionUpgrade<Negotiated<C>, Output = M, Error = E>,
U: OutboundConnectionUpgrade<Negotiated<C>, Output = M, Error = E>,
{
type Output = Result<(PeerId, M), UpgradeError<E>>;

Expand Down Expand Up @@ -208,8 +208,8 @@ where
T: Transport<Output = (PeerId, C)>,
C: AsyncRead + AsyncWrite + Unpin,
D: AsyncRead + AsyncWrite + Unpin,
U: InboundUpgrade<Negotiated<C>, Output = D, Error = E>,
U: OutboundUpgrade<Negotiated<C>, Output = D, Error = E> + Clone,
U: InboundConnectionUpgrade<Negotiated<C>, Output = D, Error = E>,
U: OutboundConnectionUpgrade<Negotiated<C>, Output = D, Error = E> + Clone,
E: Error + 'static,
{
Authenticated(Builder::new(
Expand All @@ -236,8 +236,8 @@ where
T: Transport<Output = (PeerId, C)>,
C: AsyncRead + AsyncWrite + Unpin,
M: StreamMuxer,
U: InboundUpgrade<Negotiated<C>, Output = M, Error = E>,
U: OutboundUpgrade<Negotiated<C>, Output = M, Error = E> + Clone,
U: InboundConnectionUpgrade<Negotiated<C>, Output = M, Error = E>,
U: OutboundConnectionUpgrade<Negotiated<C>, Output = M, Error = E> + Clone,
E: Error + 'static,
{
let version = self.0.version;
Expand Down Expand Up @@ -269,8 +269,8 @@ where
T: Transport<Output = (PeerId, C)>,
C: AsyncRead + AsyncWrite + Unpin,
M: StreamMuxer,
U: InboundUpgrade<Negotiated<C>, Output = M, Error = E>,
U: OutboundUpgrade<Negotiated<C>, Output = M, Error = E> + Clone,
U: InboundConnectionUpgrade<Negotiated<C>, Output = M, Error = E>,
U: OutboundConnectionUpgrade<Negotiated<C>, Output = M, Error = E> + Clone,
E: Error + 'static,
F: for<'a> FnOnce(&'a PeerId, &'a ConnectedPoint) -> U + Clone,
{
Expand Down Expand Up @@ -395,8 +395,8 @@ where
T: Transport<Output = (PeerId, C)>,
T::Error: 'static,
C: AsyncRead + AsyncWrite + Unpin,
U: InboundUpgrade<Negotiated<C>, Output = D, Error = E>,
U: OutboundUpgrade<Negotiated<C>, Output = D, Error = E> + Clone,
U: InboundConnectionUpgrade<Negotiated<C>, Output = D, Error = E>,
U: OutboundConnectionUpgrade<Negotiated<C>, Output = D, Error = E> + Clone,
E: Error + 'static,
{
type Output = (PeerId, D);
Expand Down Expand Up @@ -502,7 +502,7 @@ where
/// The [`Transport::Dial`] future of an [`Upgrade`]d transport.
pub struct DialUpgradeFuture<F, U, C>
where
U: OutboundUpgrade<Negotiated<C>>,
U: OutboundConnectionUpgrade<Negotiated<C>>,
C: AsyncRead + AsyncWrite + Unpin,
{
future: Pin<Box<F>>,
Expand All @@ -513,7 +513,7 @@ impl<F, U, C, D> Future for DialUpgradeFuture<F, U, C>
where
F: TryFuture<Ok = (PeerId, C)>,
C: AsyncRead + AsyncWrite + Unpin,
U: OutboundUpgrade<Negotiated<C>, Output = D>,
U: OutboundConnectionUpgrade<Negotiated<C>, Output = D>,
U::Error: Error,
{
type Output = Result<(PeerId, D), TransportUpgradeError<F::Error, U::Error>>;
Expand Down Expand Up @@ -553,7 +553,7 @@ where

impl<F, U, C> Unpin for DialUpgradeFuture<F, U, C>
where
U: OutboundUpgrade<Negotiated<C>>,
U: OutboundConnectionUpgrade<Negotiated<C>>,
C: AsyncRead + AsyncWrite + Unpin,
{
}
Expand All @@ -562,7 +562,7 @@ where
pub struct ListenerUpgradeFuture<F, U, C>
where
C: AsyncRead + AsyncWrite + Unpin,
U: InboundUpgrade<Negotiated<C>>,
U: InboundConnectionUpgrade<Negotiated<C>>,
{
future: Pin<Box<F>>,
upgrade: future::Either<Option<U>, (PeerId, InboundUpgradeApply<C, U>)>,
Expand All @@ -572,7 +572,7 @@ impl<F, U, C, D> Future for ListenerUpgradeFuture<F, U, C>
where
F: TryFuture<Ok = (PeerId, C)>,
C: AsyncRead + AsyncWrite + Unpin,
U: InboundUpgrade<Negotiated<C>, Output = D>,
U: InboundConnectionUpgrade<Negotiated<C>, Output = D>,
U::Error: Error,
{
type Output = Result<(PeerId, D), TransportUpgradeError<F::Error, U::Error>>;
Expand Down Expand Up @@ -613,6 +613,6 @@ where
impl<F, U, C> Unpin for ListenerUpgradeFuture<F, U, C>
where
C: AsyncRead + AsyncWrite + Unpin,
U: InboundUpgrade<Negotiated<C>>,
U: InboundConnectionUpgrade<Negotiated<C>>,
{
}
60 changes: 60 additions & 0 deletions core/src/upgrade.rs
Original file line number Diff line number Diff line change
Expand Up @@ -125,3 +125,63 @@ pub trait OutboundUpgrade<C>: UpgradeInfo {
/// The `info` is the identifier of the protocol, as produced by `protocol_info`.
fn upgrade_outbound(self, socket: C, info: Self::Info) -> Self::Future;
}

/// Possible upgrade on an inbound connection
pub trait InboundConnectionUpgrade<T>: UpgradeInfo {
/// Output after the upgrade has been successfully negotiated and the handshake performed.
type Output;
/// Possible error during the handshake.
type Error;
/// Future that performs the handshake with the remote.
type Future: Future<Output = Result<Self::Output, Self::Error>>;

/// After we have determined that the remote supports one of the protocols we support, this
/// method is called to start the handshake.
///
/// The `info` is the identifier of the protocol, as produced by `protocol_info`.
fn upgrade_inbound(self, socket: T, info: Self::Info) -> Self::Future;
}

/// Possible upgrade on an outbound connection
pub trait OutboundConnectionUpgrade<T>: UpgradeInfo {
/// Output after the upgrade has been successfully negotiated and the handshake performed.
type Output;
/// Possible error during the handshake.
type Error;
/// Future that performs the handshake with the remote.
type Future: Future<Output = Result<Self::Output, Self::Error>>;

/// After we have determined that the remote supports one of the protocols we support, this
/// method is called to start the handshake.
///
/// The `info` is the identifier of the protocol, as produced by `protocol_info`.
fn upgrade_outbound(self, socket: T, info: Self::Info) -> Self::Future;
}

// Blanket implementation for InboundConnectionUpgrade based on InboundUpgrade for backwards compatibility
impl<U, T> InboundConnectionUpgrade<T> for U
where
U: InboundUpgrade<T>,
{
type Output = <U as InboundUpgrade<T>>::Output;
type Error = <U as InboundUpgrade<T>>::Error;
type Future = <U as InboundUpgrade<T>>::Future;

fn upgrade_inbound(self, socket: T, info: Self::Info) -> Self::Future {
self.upgrade_inbound(socket, info)
}
}

// Blanket implementation for OutboundConnectionUpgrade based on OutboundUpgrade for backwards compatibility
impl<U, T> OutboundConnectionUpgrade<T> for U
where
U: OutboundUpgrade<T>,
{
type Output = <U as OutboundUpgrade<T>>::Output;
type Error = <U as OutboundUpgrade<T>>::Error;
type Future = <U as OutboundUpgrade<T>>::Future;

fn upgrade_outbound(self, socket: T, info: Self::Info) -> Self::Future {
self.upgrade_outbound(socket, info)
}
}
24 changes: 12 additions & 12 deletions core/src/upgrade/apply.rs
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@
// FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER
// DEALINGS IN THE SOFTWARE.

use crate::upgrade::{InboundUpgrade, OutboundUpgrade, UpgradeError};
use crate::upgrade::{InboundConnectionUpgrade, OutboundConnectionUpgrade, UpgradeError};
use crate::{connection::ConnectedPoint, Negotiated};
use futures::{future::Either, prelude::*};
use log::debug;
Expand All @@ -37,7 +37,7 @@ pub(crate) fn apply<C, U>(
) -> Either<InboundUpgradeApply<C, U>, OutboundUpgradeApply<C, U>>
where
C: AsyncRead + AsyncWrite + Unpin,
U: InboundUpgrade<Negotiated<C>> + OutboundUpgrade<Negotiated<C>>,
U: InboundConnectionUpgrade<Negotiated<C>> + OutboundConnectionUpgrade<Negotiated<C>>,
{
match cp {
ConnectedPoint::Dialer { role_override, .. } if role_override.is_dialer() => {
Expand All @@ -51,7 +51,7 @@ where
pub(crate) fn apply_inbound<C, U>(conn: C, up: U) -> InboundUpgradeApply<C, U>
where
C: AsyncRead + AsyncWrite + Unpin,
U: InboundUpgrade<Negotiated<C>>,
U: InboundConnectionUpgrade<Negotiated<C>>,
{
InboundUpgradeApply {
inner: InboundUpgradeApplyState::Init {
Expand All @@ -65,7 +65,7 @@ where
pub(crate) fn apply_outbound<C, U>(conn: C, up: U, v: Version) -> OutboundUpgradeApply<C, U>
where
C: AsyncRead + AsyncWrite + Unpin,
U: OutboundUpgrade<Negotiated<C>>,
U: OutboundConnectionUpgrade<Negotiated<C>>,
{
OutboundUpgradeApply {
inner: OutboundUpgradeApplyState::Init {
Expand All @@ -79,7 +79,7 @@ where
pub struct InboundUpgradeApply<C, U>
where
C: AsyncRead + AsyncWrite + Unpin,
U: InboundUpgrade<Negotiated<C>>,
U: InboundConnectionUpgrade<Negotiated<C>>,
{
inner: InboundUpgradeApplyState<C, U>,
}
Expand All @@ -88,7 +88,7 @@ where
enum InboundUpgradeApplyState<C, U>
where
C: AsyncRead + AsyncWrite + Unpin,
U: InboundUpgrade<Negotiated<C>>,
U: InboundConnectionUpgrade<Negotiated<C>>,
{
Init {
future: ListenerSelectFuture<C, U::Info>,
Expand All @@ -104,14 +104,14 @@ where
impl<C, U> Unpin for InboundUpgradeApply<C, U>
where
C: AsyncRead + AsyncWrite + Unpin,
U: InboundUpgrade<Negotiated<C>>,
U: InboundConnectionUpgrade<Negotiated<C>>,
{
}

impl<C, U> Future for InboundUpgradeApply<C, U>
where
C: AsyncRead + AsyncWrite + Unpin,
U: InboundUpgrade<Negotiated<C>>,
U: InboundConnectionUpgrade<Negotiated<C>>,
{
type Output = Result<U::Output, UpgradeError<U::Error>>;

Expand Down Expand Up @@ -162,15 +162,15 @@ where
pub struct OutboundUpgradeApply<C, U>
where
C: AsyncRead + AsyncWrite + Unpin,
U: OutboundUpgrade<Negotiated<C>>,
U: OutboundConnectionUpgrade<Negotiated<C>>,
{
inner: OutboundUpgradeApplyState<C, U>,
}

enum OutboundUpgradeApplyState<C, U>
where
C: AsyncRead + AsyncWrite + Unpin,
U: OutboundUpgrade<Negotiated<C>>,
U: OutboundConnectionUpgrade<Negotiated<C>>,
{
Init {
future: DialerSelectFuture<C, <U::InfoIter as IntoIterator>::IntoIter>,
Expand All @@ -186,14 +186,14 @@ where
impl<C, U> Unpin for OutboundUpgradeApply<C, U>
where
C: AsyncRead + AsyncWrite + Unpin,
U: OutboundUpgrade<Negotiated<C>>,
U: OutboundConnectionUpgrade<Negotiated<C>>,
{
}

impl<C, U> Future for OutboundUpgradeApply<C, U>
where
C: AsyncRead + AsyncWrite + Unpin,
U: OutboundUpgrade<Negotiated<C>>,
U: OutboundConnectionUpgrade<Negotiated<C>>,
{
type Output = Result<U::Output, UpgradeError<U::Error>>;

Expand Down

0 comments on commit 5d0c4c7

Please sign in to comment.