Skip to content

Commit

Permalink
feat: implement eth handshake disconnects (#1494)
Browse files Browse the repository at this point in the history
  • Loading branch information
Rjected authored Feb 22, 2023
1 parent 0fc9f67 commit c168ef4
Show file tree
Hide file tree
Showing 7 changed files with 137 additions and 41 deletions.
1 change: 1 addition & 0 deletions Cargo.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

4 changes: 3 additions & 1 deletion crates/net/eth-wire/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -15,19 +15,22 @@ serde = { version = "1", optional = true }
# reth
reth-codecs = { path = "../../storage/codecs" }
reth-primitives = { path = "../../primitives" }
reth-ecies = { path = "../ecies" }
reth-rlp = { path = "../../rlp", features = ["alloc", "derive", "std", "ethereum-types", "smol_str"] }

# used for Chain and builders
ethers-core = { git = "https://github.com/gakonst/ethers-rs", default-features = false }

tokio = { version = "1.21.2", features = ["full"] }
tokio-util = { version = "0.7.4", features = ["io", "codec"] }
futures = "0.3.24"
tokio-stream = "0.1.11"
pin-project = "1.0"
tracing = "0.1.37"
snap = "1.0.5"
smol_str = "0.1"
metrics = "0.20.1"
async-trait = "0.1"

# arbitrary utils
arbitrary = { version = "1.1.7", features = ["derive"], optional = true }
Expand All @@ -36,7 +39,6 @@ proptest-derive = { version = "0.3", optional = true }

[dev-dependencies]
reth-primitives = { path = "../../primitives", features = ["arbitrary"] }
reth-ecies = { path = "../ecies" }
reth-tracing = { path = "../../tracing" }
ethers-core = { git = "https://github.com/gakonst/ethers-rs", default-features = false }

Expand Down
44 changes: 44 additions & 0 deletions crates/net/eth-wire/src/disconnect.rs
Original file line number Diff line number Diff line change
@@ -1,10 +1,15 @@
//! Disconnect
use bytes::Bytes;
use futures::{Sink, SinkExt};
use reth_codecs::derive_arbitrary;
use reth_ecies::stream::ECIESStream;
use reth_primitives::bytes::{Buf, BufMut};
use reth_rlp::{Decodable, DecodeError, Encodable, Header};
use std::fmt::Display;
use thiserror::Error;
use tokio::io::AsyncWrite;
use tokio_util::codec::{Encoder, Framed};

#[cfg(feature = "serde")]
use serde::{Deserialize, Serialize};
Expand Down Expand Up @@ -143,6 +148,45 @@ impl Decodable for DisconnectReason {
}
}

/// This trait is meant to allow higher level protocols like `eth` to disconnect from a peer, using
/// lower-level disconnect functions (such as those that exist in the `p2p` protocol) if the
/// underlying stream supports it.
#[async_trait::async_trait]
pub trait CanDisconnect<T>: Sink<T> + Unpin + Sized {
/// Disconnects from the underlying stream, using a [`DisconnectReason`] as disconnect
/// information if the stream implements a protocol that can carry the additional disconnect
/// metadata.
async fn disconnect(
&mut self,
reason: DisconnectReason,
) -> Result<(), <Self as Sink<T>>::Error>;
}

// basic impls for things like Framed<TcpStream, etc>
#[async_trait::async_trait]
impl<T, I, U> CanDisconnect<I> for Framed<T, U>
where
T: AsyncWrite + Unpin + Send,
U: Encoder<I> + Send,
{
async fn disconnect(
&mut self,
_reason: DisconnectReason,
) -> Result<(), <Self as Sink<I>>::Error> {
self.close().await
}
}

#[async_trait::async_trait]
impl<S> CanDisconnect<Bytes> for ECIESStream<S>
where
S: AsyncWrite + Unpin + Send,
{
async fn disconnect(&mut self, _reason: DisconnectReason) -> Result<(), std::io::Error> {
self.close().await
}
}

#[cfg(test)]
mod tests {
use crate::{p2pstream::P2PMessage, DisconnectReason};
Expand Down
68 changes: 53 additions & 15 deletions crates/net/eth-wire/src/ethstream.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@ use crate::{
errors::{EthHandshakeError, EthStreamError},
message::{EthBroadcastMessage, ProtocolBroadcastMessage},
types::{EthMessage, ProtocolMessage, Status},
EthVersion,
CanDisconnect, DisconnectReason, EthVersion,
};
use futures::{ready, Sink, SinkExt, StreamExt};
use pin_project::pin_project;
Expand Down Expand Up @@ -43,8 +43,8 @@ impl<S> UnauthedEthStream<S> {

impl<S, E> UnauthedEthStream<S>
where
S: Stream<Item = Result<BytesMut, E>> + Sink<Bytes, Error = E> + Unpin,
EthStreamError: From<E>,
S: Stream<Item = Result<BytesMut, E>> + CanDisconnect<Bytes> + Unpin,
EthStreamError: From<E> + From<<S as Sink<Bytes>>::Error>,
{
/// Consumes the [`UnauthedEthStream`] and returns an [`EthStream`] after the `Status`
/// handshake is completed successfully. This also returns the `Status` message sent by the
Expand All @@ -67,13 +67,18 @@ where
self.inner.send(our_status_bytes).await?;

tracing::trace!("waiting for eth status from peer");
let their_msg = self
.inner
.next()
.await
.ok_or(EthStreamError::EthHandshakeError(EthHandshakeError::NoResponse))??;
let their_msg_res = self.inner.next().await;

let their_msg = match their_msg_res {
Some(msg) => msg,
None => {
self.inner.disconnect(DisconnectReason::DisconnectRequested).await?;
return Err(EthStreamError::EthHandshakeError(EthHandshakeError::NoResponse))
}
}?;

if their_msg.len() > MAX_MESSAGE_SIZE {
self.inner.disconnect(DisconnectReason::ProtocolBreach).await?;
return Err(EthStreamError::MessageTooBig(their_msg.len()))
}

Expand All @@ -82,6 +87,7 @@ where
Ok(m) => m,
Err(err) => {
tracing::debug!("decode error in eth handshake: msg={their_msg:x}");
self.inner.disconnect(DisconnectReason::DisconnectRequested).await?;
return Err(err)
}
};
Expand All @@ -95,6 +101,7 @@ where
"validating incoming eth status from peer"
);
if status.genesis != resp.genesis {
self.inner.disconnect(DisconnectReason::ProtocolBreach).await?;
return Err(EthHandshakeError::MismatchedGenesis {
expected: status.genesis,
got: resp.genesis,
Expand All @@ -103,6 +110,7 @@ where
}

if status.version != resp.version {
self.inner.disconnect(DisconnectReason::ProtocolBreach).await?;
return Err(EthHandshakeError::MismatchedProtocolVersion {
expected: status.version,
got: resp.version,
Expand All @@ -111,6 +119,7 @@ where
}

if status.chain != resp.chain {
self.inner.disconnect(DisconnectReason::ProtocolBreach).await?;
return Err(EthHandshakeError::MismatchedChain {
expected: status.chain,
got: resp.chain,
Expand All @@ -121,24 +130,33 @@ where
// TD at mainnet block #7753254 is 76 bits. If it becomes 100 million times
// larger, it will still fit within 100 bits
if status.total_difficulty.bit_len() > 100 {
self.inner.disconnect(DisconnectReason::ProtocolBreach).await?;
return Err(EthHandshakeError::TotalDifficultyBitLenTooLarge {
maximum: 100,
got: status.total_difficulty.bit_len(),
}
.into())
}

fork_filter.validate(resp.forkid).map_err(EthHandshakeError::InvalidFork)?;
if let Err(err) =
fork_filter.validate(resp.forkid).map_err(EthHandshakeError::InvalidFork)
{
self.inner.disconnect(DisconnectReason::ProtocolBreach).await?;
return Err(err.into())
}

// now we can create the `EthStream` because the peer has successfully completed
// the handshake
let stream = EthStream::new(version, self.inner);

Ok((stream, resp))
}
_ => Err(EthStreamError::EthHandshakeError(
EthHandshakeError::NonStatusMessageInHandshake,
)),
_ => {
self.inner.disconnect(DisconnectReason::ProtocolBreach).await?;
Err(EthStreamError::EthHandshakeError(
EthHandshakeError::NonStatusMessageInHandshake,
))
}
}
}
}
Expand Down Expand Up @@ -239,10 +257,10 @@ where
}
}

impl<S, E> Sink<EthMessage> for EthStream<S>
impl<S> Sink<EthMessage> for EthStream<S>
where
S: Sink<Bytes, Error = E> + Unpin,
EthStreamError: From<E>,
S: CanDisconnect<Bytes> + Unpin,
EthStreamError: From<<S as Sink<Bytes>>::Error>,
{
type Error = EthStreamError;

Expand All @@ -252,6 +270,15 @@ where

fn start_send(self: Pin<&mut Self>, item: EthMessage) -> Result<(), Self::Error> {
if matches!(item, EthMessage::Status(_)) {
// TODO: to disconnect here we would need to do something similar to P2PStream's
// start_disconnect, which would ideally be a part of the CanDisconnect trait, or at
// least similar.
//
// Other parts of reth do not need traits like CanDisconnect because they work
// exclusively with EthStream<P2PStream<S>>, where the inner P2PStream is accessible,
// allowing for its start_disconnect method to be called.
//
// self.project().inner.start_disconnect(DisconnectReason::ProtocolBreach);
return Err(EthStreamError::EthHandshakeError(EthHandshakeError::StatusNotInHandshake))
}

Expand All @@ -273,6 +300,17 @@ where
}
}

#[async_trait::async_trait]
impl<S> CanDisconnect<EthMessage> for EthStream<S>
where
S: CanDisconnect<Bytes> + Send,
EthStreamError: From<<S as Sink<Bytes>>::Error>,
{
async fn disconnect(&mut self, reason: DisconnectReason) -> Result<(), EthStreamError> {
self.inner.disconnect(reason).await.map_err(Into::into)
}
}

#[cfg(test)]
mod tests {
use super::UnauthedEthStream;
Expand Down
2 changes: 1 addition & 1 deletion crates/net/eth-wire/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@ pub use tokio_util::codec::{
};

pub use crate::{
disconnect::DisconnectReason,
disconnect::{CanDisconnect, DisconnectReason},
ethstream::{EthStream, UnauthedEthStream, MAX_MESSAGE_SIZE},
hello::HelloMessage,
p2pstream::{P2PMessage, P2PMessageID, P2PStream, ProtocolVersion, UnauthedP2PStream},
Expand Down
55 changes: 33 additions & 22 deletions crates/net/eth-wire/src/p2pstream.rs
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
#![allow(dead_code, unreachable_pub, missing_docs, unused_variables)]
use crate::{
capability::{Capability, SharedCapability},
disconnect::CanDisconnect,
errors::{P2PHandshakeError, P2PStreamError},
pinger::{Pinger, PingerEvent},
DisconnectReason, HelloMessage,
Expand Down Expand Up @@ -72,25 +73,6 @@ impl<S> UnauthedP2PStream<S> {
}
}

impl<S> UnauthedP2PStream<S>
where
S: Sink<Bytes, Error = io::Error> + Unpin,
{
/// Send a disconnect message during the handshake. This is sent without snappy compression.
pub async fn send_disconnect(
&mut self,
reason: DisconnectReason,
) -> Result<(), P2PStreamError> {
let mut buf = BytesMut::new();
P2PMessage::Disconnect(reason).encode(&mut buf);
tracing::trace!(
%reason,
"Sending disconnect message during the handshake",
);
self.inner.send(buf.freeze()).await.map_err(P2PStreamError::Io)
}
}

impl<S> UnauthedP2PStream<S>
where
S: Stream<Item = io::Result<BytesMut>> + Sink<Bytes, Error = io::Error> + Unpin,
Expand Down Expand Up @@ -180,6 +162,35 @@ where
}
}

impl<S> UnauthedP2PStream<S>
where
S: Sink<Bytes, Error = io::Error> + Unpin,
{
/// Send a disconnect message during the handshake. This is sent without snappy compression.
pub async fn send_disconnect(
&mut self,
reason: DisconnectReason,
) -> Result<(), P2PStreamError> {
let mut buf = BytesMut::new();
P2PMessage::Disconnect(reason).encode(&mut buf);
tracing::trace!(
%reason,
"Sending disconnect message during the handshake",
);
self.inner.send(buf.freeze()).await.map_err(P2PStreamError::Io)
}
}

#[async_trait::async_trait]
impl<S> CanDisconnect<Bytes> for P2PStream<S>
where
S: Sink<Bytes, Error = io::Error> + Unpin + Send + Sync,
{
async fn disconnect(&mut self, reason: DisconnectReason) -> Result<(), P2PStreamError> {
self.disconnect(reason).await
}
}

/// A P2PStream wraps over any `Stream` that yields bytes and makes it compatible with `p2p`
/// protocol messages.
#[pin_project]
Expand Down Expand Up @@ -284,13 +295,13 @@ impl<S> P2PStream<S> {

impl<S> P2PStream<S>
where
S: Sink<Bytes, Error = io::Error> + Unpin,
S: Sink<Bytes, Error = io::Error> + Unpin + Send,
{
/// Disconnects the connection by sending a disconnect message.
///
/// This future resolves once the disconnect message has been sent and the stream has been
/// closed.
pub async fn disconnect(mut self, reason: DisconnectReason) -> Result<(), P2PStreamError> {
pub async fn disconnect(&mut self, reason: DisconnectReason) -> Result<(), P2PStreamError> {
self.start_disconnect(reason)?;
self.close().await
}
Expand Down Expand Up @@ -821,7 +832,7 @@ mod tests {

let (server_hello, _) = eth_hello();

let (p2p_stream, _) =
let (mut p2p_stream, _) =
UnauthedP2PStream::new(stream).handshake(server_hello).await.unwrap();

p2p_stream.disconnect(expected_disconnect).await.unwrap();
Expand Down
4 changes: 2 additions & 2 deletions crates/net/network/src/session/active.rs
Original file line number Diff line number Diff line change
Expand Up @@ -753,9 +753,9 @@ mod tests {
&self,
local_addr: SocketAddr,
f: F,
) -> Pin<Box<dyn Future<Output = ()> + Send + Sync>>
) -> Pin<Box<dyn Future<Output = ()> + Send>>
where
F: FnOnce(EthStream<P2PStream<ECIESStream<TcpStream>>>) -> O + Send + Sync + 'static,
F: FnOnce(EthStream<P2PStream<ECIESStream<TcpStream>>>) -> O + Send + 'static,
O: Future<Output = ()> + Send + Sync,
{
let status = self.status;
Expand Down

0 comments on commit c168ef4

Please sign in to comment.