Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: implement eth handshake disconnects #1494

Merged
merged 6 commits into from
Feb 22, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions Cargo.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

4 changes: 3 additions & 1 deletion crates/net/eth-wire/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -15,19 +15,22 @@ serde = { version = "1", optional = true }
# reth
reth-codecs = { path = "../../storage/codecs" }
reth-primitives = { path = "../../primitives" }
reth-ecies = { path = "../ecies" }
reth-rlp = { path = "../../rlp", features = ["alloc", "derive", "std", "ethereum-types", "smol_str"] }

# used for Chain and builders
ethers-core = { git = "https://github.com/gakonst/ethers-rs", default-features = false }

tokio = { version = "1.21.2", features = ["full"] }
tokio-util = { version = "0.7.4", features = ["io", "codec"] }
futures = "0.3.24"
tokio-stream = "0.1.11"
pin-project = "1.0"
tracing = "0.1.37"
snap = "1.0.5"
smol_str = "0.1"
metrics = "0.20.1"
async-trait = "0.1"

# arbitrary utils
arbitrary = { version = "1.1.7", features = ["derive"], optional = true }
Expand All @@ -36,7 +39,6 @@ proptest-derive = { version = "0.3", optional = true }

[dev-dependencies]
reth-primitives = { path = "../../primitives", features = ["arbitrary"] }
reth-ecies = { path = "../ecies" }
reth-tracing = { path = "../../tracing" }
ethers-core = { git = "https://github.com/gakonst/ethers-rs", default-features = false }

Expand Down
44 changes: 44 additions & 0 deletions crates/net/eth-wire/src/disconnect.rs
Original file line number Diff line number Diff line change
@@ -1,10 +1,15 @@
//! Disconnect
use bytes::Bytes;
use futures::{Sink, SinkExt};
use reth_codecs::derive_arbitrary;
use reth_ecies::stream::ECIESStream;
use reth_primitives::bytes::{Buf, BufMut};
use reth_rlp::{Decodable, DecodeError, Encodable, Header};
use std::fmt::Display;
use thiserror::Error;
use tokio::io::AsyncWrite;
use tokio_util::codec::{Encoder, Framed};

#[cfg(feature = "serde")]
use serde::{Deserialize, Serialize};
Expand Down Expand Up @@ -143,6 +148,45 @@ impl Decodable for DisconnectReason {
}
}

/// This trait is meant to allow higher level protocols like `eth` to disconnect from a peer, using
/// lower-level disconnect functions (such as those that exist in the `p2p` protocol) if the
/// underlying stream supports it.
#[async_trait::async_trait]
pub trait CanDisconnect<T>: Sink<T> + Unpin + Sized {
/// Disconnects from the underlying stream, using a [`DisconnectReason`] as disconnect
/// information if the stream implements a protocol that can carry the additional disconnect
/// metadata.
async fn disconnect(
&mut self,
Comment on lines +159 to +160
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

using mut borrow is fine I guess, mimicks what TcpStream::shutdown does as well.

and we likely need this mut borrow if we want to use it in Sink/Stream impls

reason: DisconnectReason,
) -> Result<(), <Self as Sink<T>>::Error>;
}

// basic impls for things like Framed<TcpStream, etc>
#[async_trait::async_trait]
impl<T, I, U> CanDisconnect<I> for Framed<T, U>
where
T: AsyncWrite + Unpin + Send,
U: Encoder<I> + Send,
{
async fn disconnect(
&mut self,
_reason: DisconnectReason,
) -> Result<(), <Self as Sink<I>>::Error> {
self.close().await
}
}

#[async_trait::async_trait]
impl<S> CanDisconnect<Bytes> for ECIESStream<S>
where
S: AsyncWrite + Unpin + Send,
{
async fn disconnect(&mut self, _reason: DisconnectReason) -> Result<(), std::io::Error> {
self.close().await
}
}

#[cfg(test)]
mod tests {
use crate::{p2pstream::P2PMessage, DisconnectReason};
Expand Down
68 changes: 53 additions & 15 deletions crates/net/eth-wire/src/ethstream.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@ use crate::{
errors::{EthHandshakeError, EthStreamError},
message::{EthBroadcastMessage, ProtocolBroadcastMessage},
types::{EthMessage, ProtocolMessage, Status},
EthVersion,
CanDisconnect, DisconnectReason, EthVersion,
};
use futures::{ready, Sink, SinkExt, StreamExt};
use pin_project::pin_project;
Expand Down Expand Up @@ -43,8 +43,8 @@ impl<S> UnauthedEthStream<S> {

impl<S, E> UnauthedEthStream<S>
where
S: Stream<Item = Result<BytesMut, E>> + Sink<Bytes, Error = E> + Unpin,
EthStreamError: From<E>,
S: Stream<Item = Result<BytesMut, E>> + CanDisconnect<Bytes> + Unpin,
EthStreamError: From<E> + From<<S as Sink<Bytes>>::Error>,
{
/// Consumes the [`UnauthedEthStream`] and returns an [`EthStream`] after the `Status`
/// handshake is completed successfully. This also returns the `Status` message sent by the
Expand All @@ -67,13 +67,18 @@ where
self.inner.send(our_status_bytes).await?;

tracing::trace!("waiting for eth status from peer");
let their_msg = self
.inner
.next()
.await
.ok_or(EthStreamError::EthHandshakeError(EthHandshakeError::NoResponse))??;
let their_msg_res = self.inner.next().await;

let their_msg = match their_msg_res {
Some(msg) => msg,
None => {
self.inner.disconnect(DisconnectReason::DisconnectRequested).await?;
return Err(EthStreamError::EthHandshakeError(EthHandshakeError::NoResponse))
}
}?;

if their_msg.len() > MAX_MESSAGE_SIZE {
self.inner.disconnect(DisconnectReason::ProtocolBreach).await?;
return Err(EthStreamError::MessageTooBig(their_msg.len()))
}

Expand All @@ -82,6 +87,7 @@ where
Ok(m) => m,
Err(err) => {
tracing::debug!("decode error in eth handshake: msg={their_msg:x}");
self.inner.disconnect(DisconnectReason::DisconnectRequested).await?;
return Err(err)
}
};
Expand All @@ -95,6 +101,7 @@ where
"validating incoming eth status from peer"
);
if status.genesis != resp.genesis {
self.inner.disconnect(DisconnectReason::ProtocolBreach).await?;
return Err(EthHandshakeError::MismatchedGenesis {
expected: status.genesis,
got: resp.genesis,
Expand All @@ -103,6 +110,7 @@ where
}

if status.version != resp.version {
self.inner.disconnect(DisconnectReason::ProtocolBreach).await?;
return Err(EthHandshakeError::MismatchedProtocolVersion {
expected: status.version,
got: resp.version,
Expand All @@ -111,6 +119,7 @@ where
}

if status.chain != resp.chain {
self.inner.disconnect(DisconnectReason::ProtocolBreach).await?;
return Err(EthHandshakeError::MismatchedChain {
expected: status.chain,
got: resp.chain,
Expand All @@ -121,24 +130,33 @@ where
// TD at mainnet block #7753254 is 76 bits. If it becomes 100 million times
// larger, it will still fit within 100 bits
if status.total_difficulty.bit_len() > 100 {
self.inner.disconnect(DisconnectReason::ProtocolBreach).await?;
return Err(EthHandshakeError::TotalDifficultyBitLenTooLarge {
maximum: 100,
got: status.total_difficulty.bit_len(),
}
.into())
}

fork_filter.validate(resp.forkid).map_err(EthHandshakeError::InvalidFork)?;
if let Err(err) =
fork_filter.validate(resp.forkid).map_err(EthHandshakeError::InvalidFork)
{
self.inner.disconnect(DisconnectReason::ProtocolBreach).await?;
return Err(err.into())
}

// now we can create the `EthStream` because the peer has successfully completed
// the handshake
let stream = EthStream::new(version, self.inner);

Ok((stream, resp))
}
_ => Err(EthStreamError::EthHandshakeError(
EthHandshakeError::NonStatusMessageInHandshake,
)),
_ => {
self.inner.disconnect(DisconnectReason::ProtocolBreach).await?;
Err(EthStreamError::EthHandshakeError(
EthHandshakeError::NonStatusMessageInHandshake,
))
}
}
}
}
Expand Down Expand Up @@ -239,10 +257,10 @@ where
}
}

impl<S, E> Sink<EthMessage> for EthStream<S>
impl<S> Sink<EthMessage> for EthStream<S>
where
S: Sink<Bytes, Error = E> + Unpin,
EthStreamError: From<E>,
S: CanDisconnect<Bytes> + Unpin,
EthStreamError: From<<S as Sink<Bytes>>::Error>,
{
type Error = EthStreamError;

Expand All @@ -252,6 +270,15 @@ where

fn start_send(self: Pin<&mut Self>, item: EthMessage) -> Result<(), Self::Error> {
if matches!(item, EthMessage::Status(_)) {
// TODO: to disconnect here we would need to do something similar to P2PStream's
// start_disconnect, which would ideally be a part of the CanDisconnect trait, or at
// least similar.
//
// Other parts of reth do not need traits like CanDisconnect because they work
// exclusively with EthStream<P2PStream<S>>, where the inner P2PStream is accessible,
// allowing for its start_disconnect method to be called.
//
// self.project().inner.start_disconnect(DisconnectReason::ProtocolBreach);
return Err(EthStreamError::EthHandshakeError(EthHandshakeError::StatusNotInHandshake))
}

Expand All @@ -273,6 +300,17 @@ where
}
}

#[async_trait::async_trait]
impl<S> CanDisconnect<EthMessage> for EthStream<S>
where
S: CanDisconnect<Bytes> + Send,
EthStreamError: From<<S as Sink<Bytes>>::Error>,
{
async fn disconnect(&mut self, reason: DisconnectReason) -> Result<(), EthStreamError> {
self.inner.disconnect(reason).await.map_err(Into::into)
}
}

#[cfg(test)]
mod tests {
use super::UnauthedEthStream;
Expand Down
2 changes: 1 addition & 1 deletion crates/net/eth-wire/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@ pub use tokio_util::codec::{
};

pub use crate::{
disconnect::DisconnectReason,
disconnect::{CanDisconnect, DisconnectReason},
ethstream::{EthStream, UnauthedEthStream, MAX_MESSAGE_SIZE},
hello::HelloMessage,
p2pstream::{P2PMessage, P2PMessageID, P2PStream, ProtocolVersion, UnauthedP2PStream},
Expand Down
55 changes: 33 additions & 22 deletions crates/net/eth-wire/src/p2pstream.rs
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
#![allow(dead_code, unreachable_pub, missing_docs, unused_variables)]
use crate::{
capability::{Capability, SharedCapability},
disconnect::CanDisconnect,
errors::{P2PHandshakeError, P2PStreamError},
pinger::{Pinger, PingerEvent},
DisconnectReason, HelloMessage,
Expand Down Expand Up @@ -72,25 +73,6 @@ impl<S> UnauthedP2PStream<S> {
}
}

impl<S> UnauthedP2PStream<S>
where
S: Sink<Bytes, Error = io::Error> + Unpin,
{
/// Send a disconnect message during the handshake. This is sent without snappy compression.
pub async fn send_disconnect(
&mut self,
reason: DisconnectReason,
) -> Result<(), P2PStreamError> {
let mut buf = BytesMut::new();
P2PMessage::Disconnect(reason).encode(&mut buf);
tracing::trace!(
%reason,
"Sending disconnect message during the handshake",
);
self.inner.send(buf.freeze()).await.map_err(P2PStreamError::Io)
}
}

impl<S> UnauthedP2PStream<S>
where
S: Stream<Item = io::Result<BytesMut>> + Sink<Bytes, Error = io::Error> + Unpin,
Expand Down Expand Up @@ -180,6 +162,35 @@ where
}
}

impl<S> UnauthedP2PStream<S>
where
S: Sink<Bytes, Error = io::Error> + Unpin,
{
/// Send a disconnect message during the handshake. This is sent without snappy compression.
pub async fn send_disconnect(
&mut self,
reason: DisconnectReason,
) -> Result<(), P2PStreamError> {
let mut buf = BytesMut::new();
P2PMessage::Disconnect(reason).encode(&mut buf);
tracing::trace!(
%reason,
"Sending disconnect message during the handshake",
);
self.inner.send(buf.freeze()).await.map_err(P2PStreamError::Io)
}
}

#[async_trait::async_trait]
impl<S> CanDisconnect<Bytes> for P2PStream<S>
where
S: Sink<Bytes, Error = io::Error> + Unpin + Send + Sync,
{
async fn disconnect(&mut self, reason: DisconnectReason) -> Result<(), P2PStreamError> {
self.disconnect(reason).await
}
}

/// A P2PStream wraps over any `Stream` that yields bytes and makes it compatible with `p2p`
/// protocol messages.
#[pin_project]
Expand Down Expand Up @@ -284,13 +295,13 @@ impl<S> P2PStream<S> {

impl<S> P2PStream<S>
where
S: Sink<Bytes, Error = io::Error> + Unpin,
S: Sink<Bytes, Error = io::Error> + Unpin + Send,
{
/// Disconnects the connection by sending a disconnect message.
///
/// This future resolves once the disconnect message has been sent and the stream has been
/// closed.
pub async fn disconnect(mut self, reason: DisconnectReason) -> Result<(), P2PStreamError> {
pub async fn disconnect(&mut self, reason: DisconnectReason) -> Result<(), P2PStreamError> {
self.start_disconnect(reason)?;
self.close().await
}
Expand Down Expand Up @@ -821,7 +832,7 @@ mod tests {

let (server_hello, _) = eth_hello();

let (p2p_stream, _) =
let (mut p2p_stream, _) =
UnauthedP2PStream::new(stream).handshake(server_hello).await.unwrap();

p2p_stream.disconnect(expected_disconnect).await.unwrap();
Expand Down
4 changes: 2 additions & 2 deletions crates/net/network/src/session/active.rs
Original file line number Diff line number Diff line change
Expand Up @@ -753,9 +753,9 @@ mod tests {
&self,
local_addr: SocketAddr,
f: F,
) -> Pin<Box<dyn Future<Output = ()> + Send + Sync>>
) -> Pin<Box<dyn Future<Output = ()> + Send>>
where
F: FnOnce(EthStream<P2PStream<ECIESStream<TcpStream>>>) -> O + Send + Sync + 'static,
F: FnOnce(EthStream<P2PStream<ECIESStream<TcpStream>>>) -> O + Send + 'static,
O: Future<Output = ()> + Send + Sync,
{
let status = self.status;
Expand Down