Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

refactor(core): blanket implementation of connection upgrade #4316

Merged
merged 6 commits into from
Sep 20, 2023
Merged
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
52 changes: 26 additions & 26 deletions core/src/transport/upgrade.rs
Original file line number Diff line number Diff line change
Expand Up @@ -30,8 +30,8 @@ use crate::{
TransportError, TransportEvent,
},
upgrade::{
self, apply_inbound, apply_outbound, InboundUpgrade, InboundUpgradeApply, OutboundUpgrade,
OutboundUpgradeApply, UpgradeError,
self, apply_inbound, apply_outbound, InboundConnectionUpgrade, InboundUpgradeApply,
OutboundConnectionUpgrade, OutboundUpgradeApply, UpgradeError,
},
Negotiated,
};
Expand Down Expand Up @@ -101,8 +101,8 @@ where
T: Transport<Output = C>,
C: AsyncRead + AsyncWrite + Unpin,
D: AsyncRead + AsyncWrite + Unpin,
U: InboundUpgrade<Negotiated<C>, Output = (PeerId, D), Error = E>,
U: OutboundUpgrade<Negotiated<C>, Output = (PeerId, D), Error = E> + Clone,
U: InboundConnectionUpgrade<Negotiated<C>, Output = (PeerId, D), Error = E>,
U: OutboundConnectionUpgrade<Negotiated<C>, Output = (PeerId, D), Error = E> + Clone,
E: Error + 'static,
{
let version = self.version;
Expand All @@ -123,7 +123,7 @@ where
pub struct Authenticate<C, U>
where
C: AsyncRead + AsyncWrite + Unpin,
U: InboundUpgrade<Negotiated<C>> + OutboundUpgrade<Negotiated<C>>,
U: InboundConnectionUpgrade<Negotiated<C>> + OutboundConnectionUpgrade<Negotiated<C>>,
{
#[pin]
inner: EitherUpgrade<C, U>,
Expand All @@ -132,11 +132,11 @@ where
impl<C, U> Future for Authenticate<C, U>
where
C: AsyncRead + AsyncWrite + Unpin,
U: InboundUpgrade<Negotiated<C>>
+ OutboundUpgrade<
U: InboundConnectionUpgrade<Negotiated<C>>
+ OutboundConnectionUpgrade<
Negotiated<C>,
Output = <U as InboundUpgrade<Negotiated<C>>>::Output,
Error = <U as InboundUpgrade<Negotiated<C>>>::Error,
Output = <U as InboundConnectionUpgrade<Negotiated<C>>>::Output,
Error = <U as InboundConnectionUpgrade<Negotiated<C>>>::Error,
>,
{
type Output = <EitherUpgrade<C, U> as Future>::Output;
Expand All @@ -155,7 +155,7 @@ where
pub struct Multiplex<C, U>
where
C: AsyncRead + AsyncWrite + Unpin,
U: InboundUpgrade<Negotiated<C>> + OutboundUpgrade<Negotiated<C>>,
U: InboundConnectionUpgrade<Negotiated<C>> + OutboundConnectionUpgrade<Negotiated<C>>,
{
peer_id: Option<PeerId>,
#[pin]
Expand All @@ -165,8 +165,8 @@ where
impl<C, U, M, E> Future for Multiplex<C, U>
where
C: AsyncRead + AsyncWrite + Unpin,
U: InboundUpgrade<Negotiated<C>, Output = M, Error = E>,
U: OutboundUpgrade<Negotiated<C>, Output = M, Error = E>,
U: InboundConnectionUpgrade<Negotiated<C>, Output = M, Error = E>,
U: OutboundConnectionUpgrade<Negotiated<C>, Output = M, Error = E>,
{
type Output = Result<(PeerId, M), UpgradeError<E>>;

Expand Down Expand Up @@ -208,8 +208,8 @@ where
T: Transport<Output = (PeerId, C)>,
C: AsyncRead + AsyncWrite + Unpin,
D: AsyncRead + AsyncWrite + Unpin,
U: InboundUpgrade<Negotiated<C>, Output = D, Error = E>,
U: OutboundUpgrade<Negotiated<C>, Output = D, Error = E> + Clone,
U: InboundConnectionUpgrade<Negotiated<C>, Output = D, Error = E>,
U: OutboundConnectionUpgrade<Negotiated<C>, Output = D, Error = E> + Clone,
E: Error + 'static,
{
Authenticated(Builder::new(
Expand All @@ -236,8 +236,8 @@ where
T: Transport<Output = (PeerId, C)>,
C: AsyncRead + AsyncWrite + Unpin,
M: StreamMuxer,
U: InboundUpgrade<Negotiated<C>, Output = M, Error = E>,
U: OutboundUpgrade<Negotiated<C>, Output = M, Error = E> + Clone,
U: InboundConnectionUpgrade<Negotiated<C>, Output = M, Error = E>,
U: OutboundConnectionUpgrade<Negotiated<C>, Output = M, Error = E> + Clone,
E: Error + 'static,
{
let version = self.0.version;
Expand Down Expand Up @@ -269,8 +269,8 @@ where
T: Transport<Output = (PeerId, C)>,
C: AsyncRead + AsyncWrite + Unpin,
M: StreamMuxer,
U: InboundUpgrade<Negotiated<C>, Output = M, Error = E>,
U: OutboundUpgrade<Negotiated<C>, Output = M, Error = E> + Clone,
U: InboundConnectionUpgrade<Negotiated<C>, Output = M, Error = E>,
U: OutboundConnectionUpgrade<Negotiated<C>, Output = M, Error = E> + Clone,
E: Error + 'static,
F: for<'a> FnOnce(&'a PeerId, &'a ConnectedPoint) -> U + Clone,
{
Expand Down Expand Up @@ -395,8 +395,8 @@ where
T: Transport<Output = (PeerId, C)>,
T::Error: 'static,
C: AsyncRead + AsyncWrite + Unpin,
U: InboundUpgrade<Negotiated<C>, Output = D, Error = E>,
U: OutboundUpgrade<Negotiated<C>, Output = D, Error = E> + Clone,
U: InboundConnectionUpgrade<Negotiated<C>, Output = D, Error = E>,
U: OutboundConnectionUpgrade<Negotiated<C>, Output = D, Error = E> + Clone,
E: Error + 'static,
{
type Output = (PeerId, D);
Expand Down Expand Up @@ -502,7 +502,7 @@ where
/// The [`Transport::Dial`] future of an [`Upgrade`]d transport.
pub struct DialUpgradeFuture<F, U, C>
where
U: OutboundUpgrade<Negotiated<C>>,
U: OutboundConnectionUpgrade<Negotiated<C>>,
C: AsyncRead + AsyncWrite + Unpin,
{
future: Pin<Box<F>>,
Expand All @@ -513,7 +513,7 @@ impl<F, U, C, D> Future for DialUpgradeFuture<F, U, C>
where
F: TryFuture<Ok = (PeerId, C)>,
C: AsyncRead + AsyncWrite + Unpin,
U: OutboundUpgrade<Negotiated<C>, Output = D>,
U: OutboundConnectionUpgrade<Negotiated<C>, Output = D>,
U::Error: Error,
{
type Output = Result<(PeerId, D), TransportUpgradeError<F::Error, U::Error>>;
Expand Down Expand Up @@ -553,7 +553,7 @@ where

impl<F, U, C> Unpin for DialUpgradeFuture<F, U, C>
where
U: OutboundUpgrade<Negotiated<C>>,
U: OutboundConnectionUpgrade<Negotiated<C>>,
C: AsyncRead + AsyncWrite + Unpin,
{
}
Expand All @@ -562,7 +562,7 @@ where
pub struct ListenerUpgradeFuture<F, U, C>
where
C: AsyncRead + AsyncWrite + Unpin,
U: InboundUpgrade<Negotiated<C>>,
U: InboundConnectionUpgrade<Negotiated<C>>,
{
future: Pin<Box<F>>,
upgrade: future::Either<Option<U>, (PeerId, InboundUpgradeApply<C, U>)>,
Expand All @@ -572,7 +572,7 @@ impl<F, U, C, D> Future for ListenerUpgradeFuture<F, U, C>
where
F: TryFuture<Ok = (PeerId, C)>,
C: AsyncRead + AsyncWrite + Unpin,
U: InboundUpgrade<Negotiated<C>, Output = D>,
U: InboundConnectionUpgrade<Negotiated<C>, Output = D>,
U::Error: Error,
{
type Output = Result<(PeerId, D), TransportUpgradeError<F::Error, U::Error>>;
Expand Down Expand Up @@ -613,6 +613,6 @@ where
impl<F, U, C> Unpin for ListenerUpgradeFuture<F, U, C>
where
C: AsyncRead + AsyncWrite + Unpin,
U: InboundUpgrade<Negotiated<C>>,
U: InboundConnectionUpgrade<Negotiated<C>>,
{
}
60 changes: 60 additions & 0 deletions core/src/upgrade.rs
Original file line number Diff line number Diff line change
Expand Up @@ -125,3 +125,63 @@ pub trait OutboundUpgrade<C>: UpgradeInfo {
/// The `info` is the identifier of the protocol, as produced by `protocol_info`.
fn upgrade_outbound(self, socket: C, info: Self::Info) -> Self::Future;
}

/// Possible upgrade on an inbound connection
pub trait InboundConnectionUpgrade<T>: UpgradeInfo {
/// Output after the upgrade has been successfully negotiated and the handshake performed.
type Output;
/// Possible error during the handshake.
type Error;
/// Future that performs the handshake with the remote.
type Future: Future<Output = Result<Self::Output, Self::Error>>;

/// After we have determined that the remote supports one of the protocols we support, this
/// method is called to start the handshake.
///
/// The `info` is the identifier of the protocol, as produced by `protocol_info`.
fn upgrade_inbound(self, socket: C, info: Self::Info) -> Self::Future;
PanGan21 marked this conversation as resolved.
Show resolved Hide resolved
}

/// Possible upgrade on an outbound connection
pub trait OutboundConnectionUpgrade<T>: UpgradeInfo {
/// Output after the upgrade has been successfully negotiated and the handshake performed.
type Output;
/// Possible error during the handshake.
type Error;
/// Future that performs the handshake with the remote.
type Future: Future<Output = Result<Self::Output, Self::Error>>;

/// After we have determined that the remote supports one of the protocols we support, this
/// method is called to start the handshake.
///
/// The `info` is the identifier of the protocol, as produced by `protocol_info`.
fn upgrade_outbound(self, socket: C, info: Self::Info) -> Self::Future;
PanGan21 marked this conversation as resolved.
Show resolved Hide resolved
}

// Blanket implementation for InboundConnectionUpgrade based on InboundUpgrade for backwards compatibility
impl<U, T> InboundConnectionUpgrade<T> for U
where
U: InboundUpgrade<T>,
{
type Output = <U as InboundUpgrade<T>>::Output;
type Error = <U as InboundUpgrade<T>>::Error;
type Future = <U as InboundUpgrade<T>>::Future;

fn upgrade_inbound(self, socket: T, info: Self::Info) -> Self::Future {
self.upgrade_inbound(socket, info)
}
}

// Blanket implementation for OutboundConnectionUpgrade based on OutboundUpgrade for backwards compatibility
impl<U, T> OutboundConnectionUpgrade<T> for U
where
U: OutboundUpgrade<T>,
{
type Output = <U as OutboundUpgrade<T>>::Output;
type Error = <U as OutboundUpgrade<T>>::Error;
type Future = <U as OutboundUpgrade<T>>::Future;

fn upgrade_outbound(self, socket: T, info: Self::Info) -> Self::Future {
self.upgrade_outbound(socket, info)
}
}
24 changes: 12 additions & 12 deletions core/src/upgrade/apply.rs
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@
// FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER
// DEALINGS IN THE SOFTWARE.

use crate::upgrade::{InboundUpgrade, OutboundUpgrade, UpgradeError};
use crate::upgrade::{InboundConnectionUpgrade, OutboundConnectionUpgrade, UpgradeError};
use crate::{connection::ConnectedPoint, Negotiated};
use futures::{future::Either, prelude::*};
use log::debug;
Expand All @@ -37,7 +37,7 @@ pub(crate) fn apply<C, U>(
) -> Either<InboundUpgradeApply<C, U>, OutboundUpgradeApply<C, U>>
where
C: AsyncRead + AsyncWrite + Unpin,
U: InboundUpgrade<Negotiated<C>> + OutboundUpgrade<Negotiated<C>>,
U: InboundConnectionUpgrade<Negotiated<C>> + OutboundConnectionUpgrade<Negotiated<C>>,
{
match cp {
ConnectedPoint::Dialer { role_override, .. } if role_override.is_dialer() => {
Expand All @@ -51,7 +51,7 @@ where
pub(crate) fn apply_inbound<C, U>(conn: C, up: U) -> InboundUpgradeApply<C, U>
where
C: AsyncRead + AsyncWrite + Unpin,
U: InboundUpgrade<Negotiated<C>>,
U: InboundConnectionUpgrade<Negotiated<C>>,
{
InboundUpgradeApply {
inner: InboundUpgradeApplyState::Init {
Expand All @@ -65,7 +65,7 @@ where
pub(crate) fn apply_outbound<C, U>(conn: C, up: U, v: Version) -> OutboundUpgradeApply<C, U>
where
C: AsyncRead + AsyncWrite + Unpin,
U: OutboundUpgrade<Negotiated<C>>,
U: OutboundConnectionUpgrade<Negotiated<C>>,
{
OutboundUpgradeApply {
inner: OutboundUpgradeApplyState::Init {
Expand All @@ -79,7 +79,7 @@ where
pub struct InboundUpgradeApply<C, U>
where
C: AsyncRead + AsyncWrite + Unpin,
U: InboundUpgrade<Negotiated<C>>,
U: InboundConnectionUpgrade<Negotiated<C>>,
{
inner: InboundUpgradeApplyState<C, U>,
}
Expand All @@ -88,7 +88,7 @@ where
enum InboundUpgradeApplyState<C, U>
where
C: AsyncRead + AsyncWrite + Unpin,
U: InboundUpgrade<Negotiated<C>>,
U: InboundConnectionUpgrade<Negotiated<C>>,
{
Init {
future: ListenerSelectFuture<C, U::Info>,
Expand All @@ -104,14 +104,14 @@ where
impl<C, U> Unpin for InboundUpgradeApply<C, U>
where
C: AsyncRead + AsyncWrite + Unpin,
U: InboundUpgrade<Negotiated<C>>,
U: InboundConnectionUpgrade<Negotiated<C>>,
{
}

impl<C, U> Future for InboundUpgradeApply<C, U>
where
C: AsyncRead + AsyncWrite + Unpin,
U: InboundUpgrade<Negotiated<C>>,
U: InboundConnectionUpgrade<Negotiated<C>>,
{
type Output = Result<U::Output, UpgradeError<U::Error>>;

Expand Down Expand Up @@ -162,15 +162,15 @@ where
pub struct OutboundUpgradeApply<C, U>
where
C: AsyncRead + AsyncWrite + Unpin,
U: OutboundUpgrade<Negotiated<C>>,
U: OutboundConnectionUpgrade<Negotiated<C>>,
{
inner: OutboundUpgradeApplyState<C, U>,
}

enum OutboundUpgradeApplyState<C, U>
where
C: AsyncRead + AsyncWrite + Unpin,
U: OutboundUpgrade<Negotiated<C>>,
U: OutboundConnectionUpgrade<Negotiated<C>>,
{
Init {
future: DialerSelectFuture<C, <U::InfoIter as IntoIterator>::IntoIter>,
Expand All @@ -186,14 +186,14 @@ where
impl<C, U> Unpin for OutboundUpgradeApply<C, U>
where
C: AsyncRead + AsyncWrite + Unpin,
U: OutboundUpgrade<Negotiated<C>>,
U: OutboundConnectionUpgrade<Negotiated<C>>,
{
}

impl<C, U> Future for OutboundUpgradeApply<C, U>
where
C: AsyncRead + AsyncWrite + Unpin,
U: OutboundUpgrade<Negotiated<C>>,
U: OutboundConnectionUpgrade<Negotiated<C>>,
{
type Output = Result<U::Output, UpgradeError<U::Error>>;

Expand Down