diff --git a/Cargo.lock b/Cargo.lock index 5ce984b..045c75e 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -222,6 +222,15 @@ dependencies = [ "wyz", ] +[[package]] +name = "blake2" +version = "0.10.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "b9cf849ee05b2ee5fba5e36f97ff8ec2533916700fc0758d40d92136a42f3388" +dependencies = [ + "digest 0.10.3", +] + [[package]] name = "block-buffer" version = "0.9.0" @@ -309,6 +318,31 @@ version = "1.0.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "baf1de4339761588bc0619e3cbc0120ee582ebb74b53b4efbf79117bd2da40fd" +[[package]] +name = "chacha20" +version = "0.8.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "01b72a433d0cf2aef113ba70f62634c56fddb0f244e6377185c56a7cadbd8f91" +dependencies = [ + "cfg-if", + "cipher 0.3.0", + "cpufeatures", + "zeroize", +] + +[[package]] +name = "chacha20poly1305" +version = "0.9.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "3b84ed6d1d5f7aa9bdde921a5090e0ca4d934d250ea3b402a5fab3a994e28a2a" +dependencies = [ + "aead 0.4.3", + "chacha20", + "cipher 0.3.0", + "poly1305", + "zeroize", +] + [[package]] name = "chrono" version = "0.4.19" @@ -452,6 +486,19 @@ dependencies = [ "zeroize", ] +[[package]] +name = "curve25519-dalek" +version = "4.0.0-pre.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "12dc3116fe595d7847c701796ac1b189bd86b81f4f593c6f775f9d80fb2e29f4" +dependencies = [ + "byteorder", + "digest 0.10.3", + "rand_core 0.6.3", + "subtle", + "zeroize", +] + [[package]] name = "darling" version = "0.12.4" @@ -599,6 +646,7 @@ checksum = "f2fb860ca6fafa5552fb6d0e816a69c8e49f0908bf524e30a90d97c85892d506" dependencies = [ "block-buffer 0.10.2", "crypto-common", + "subtle", ] [[package]] @@ -638,7 +686,7 @@ version = "1.0.1" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "c762bae6dcaf24c4c84667b8579785430908723d5c889f469d76a41d59cc7a9d" dependencies = [ - "curve25519-dalek", + "curve25519-dalek 3.2.1", "ed25519", "rand 0.7.3", "serde", @@ -1141,19 +1189,84 @@ dependencies = [ "zeroize", ] +[[package]] +name = "libp2p-noise" +version = "0.35.0" +source = "git+https://github.com/melekes/rust-libp2p?branch=anton/x-webrtc#eecff4070011de91174a4980fac5067687df0e87" +dependencies = [ + "bytes", + "curve25519-dalek 3.2.1", + "futures", + "lazy_static", + "libp2p-core", + "log", + "prost", + "prost-build", + "rand 0.8.5", + "sha2 0.10.2", + "snow", + "static_assertions", + "x25519-dalek", + "zeroize", +] + +[[package]] +name = "libp2p-request-response" +version = "0.17.0" +source = "git+https://github.com/melekes/rust-libp2p?branch=anton/x-webrtc#eecff4070011de91174a4980fac5067687df0e87" +dependencies = [ + "async-trait", + "bytes", + "futures", + "instant", + "libp2p-core", + "libp2p-swarm", + "log", + "rand 0.7.3", + "smallvec", + "unsigned-varint", +] + +[[package]] +name = "libp2p-swarm" +version = "0.35.0" +source = "git+https://github.com/melekes/rust-libp2p?branch=anton/x-webrtc#eecff4070011de91174a4980fac5067687df0e87" +dependencies = [ + "either", + "fnv", + "futures", + "futures-timer", + "instant", + "libp2p-core", + "log", + "pin-project 1.0.10", + "rand 0.7.3", + "smallvec", + "thiserror", + "void", +] + [[package]] name = "libp2p-webrtc-direct" version = "0.1.0" dependencies = [ + "anyhow", "async-trait", "bytes", "env_logger", + "fnv", "futures", + "futures-lite", "futures-timer", "hex", "if-watch", "libp2p-core", + "libp2p-noise", + "libp2p-request-response", + "libp2p-swarm", "log", + "rand 0.8.5", + "rand_core 0.5.1", "rcgen", "serde", "stun", @@ -1587,6 +1700,17 @@ dependencies = [ "winapi", ] +[[package]] +name = "poly1305" +version = "0.7.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "048aeb476be11a4b6ca432ca569e375810de9294ae78f4774e78ea98a9246ede" +dependencies = [ + "cpufeatures", + "opaque-debug", + "universal-hash", +] + [[package]] name = "polyval" version = "0.4.5" @@ -1891,6 +2015,15 @@ dependencies = [ "webrtc-util", ] +[[package]] +name = "rustc_version" +version = "0.4.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "bfa0f585226d2e68097d4f95d113b15b83a82e819ab25717ec0590d9584ef366" +dependencies = [ + "semver", +] + [[package]] name = "rusticata-macros" version = "3.2.0" @@ -1973,6 +2106,12 @@ dependencies = [ "url", ] +[[package]] +name = "semver" +version = "1.0.9" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "8cb243bdfdb5936c8dc3c45762a19d12ab4550cdc753bc247637d4ec35a040fd" + [[package]] name = "serde" version = "1.0.136" @@ -2072,6 +2211,23 @@ version = "1.8.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "f2dd574626839106c320a323308629dcb1acfc96e32a8cba364ddc61ac23ee83" +[[package]] +name = "snow" +version = "0.9.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "774d05a3edae07ce6d68ea6984f3c05e9bba8927e3dd591e3b479e5b03213d0d" +dependencies = [ + "aes-gcm 0.9.4", + "blake2", + "chacha20poly1305", + "curve25519-dalek 4.0.0-pre.2", + "rand_core 0.6.3", + "ring", + "rustc_version", + "sha2 0.10.2", + "subtle", +] + [[package]] name = "socket2" version = "0.4.4" @@ -2233,9 +2389,9 @@ checksum = "cda74da7e1a664f795bb1f8a87ec406fb89a02522cf6e50620d016add6dbbf5c" [[package]] name = "tokio" -version = "1.17.0" +version = "1.18.2" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "2af73ac49756f3f7c01172e34a23e5d0216f6c32333757c2c61feb2bbff5a5ee" +checksum = "4903bf0427cf68dddd5aa6a93220756f8be0c34fcfa9f5e6191e103e15a31395" dependencies = [ "bytes", "libc", @@ -2337,6 +2493,10 @@ name = "unsigned-varint" version = "0.7.1" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "d86a8dc7f45e4c1b0d30e43038c38f274e77af056aa5f74b93c2cf9eb3c1c836" +dependencies = [ + "futures-io", + "futures-util", +] [[package]] name = "untrusted" @@ -2487,7 +2647,7 @@ dependencies = [ [[package]] name = "webrtc" version = "0.4.0" -source = "git+https://github.com/melekes/webrtc?branch=anton/168-allow-persistent-certificates#c3eaeb5e4e13b36f0b833c0bdd88bcf35c1f806f" +source = "git+https://github.com/melekes/webrtc?branch=anton/168-allow-persistent-certificates#3cbaaebacfcb808c7fd3d86b101b2d8ed48a764d" dependencies = [ "async-trait", "bytes", @@ -2525,8 +2685,7 @@ dependencies = [ [[package]] name = "webrtc-data" version = "0.3.3" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "b1c6e77d7f15d517a85b1ac10921a91b541d2358174313719cbd1ed3ee1a439d" +source = "git+https://github.com/melekes/webrtc-data.git?branch=anton/async-read-write-for-data-channel#fc7c5f75561abc8caffa122abaa0bc7395c278be" dependencies = [ "bytes", "derive_builder", @@ -2578,14 +2737,16 @@ dependencies = [ [[package]] name = "webrtc-ice" -version = "0.6.6" +version = "0.7.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "08f2a4a1ae47cdce72c648d94895131a80f0ec96d8c78d9187b37d777b9c18cf" +checksum = "d8c20824d5b1f223b71b5db70ae327015756eb83af42f07869409cb7473a168b" dependencies = [ "async-trait", "crc", "log", "rand 0.8.5", + "serde", + "serde_json", "stun", "thiserror", "tokio", @@ -2628,9 +2789,9 @@ dependencies = [ [[package]] name = "webrtc-sctp" -version = "0.4.3" +version = "0.5.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "d48f4c8601507ef48712acdfd5019fc68c4f42d6cf934ec6937b9737339ba030" +checksum = "c97ba5fad11c9109608e774e90250cb2b61f6b6ce52b3b20950fc83c8db9ee87" dependencies = [ "async-trait", "bytes", @@ -2794,7 +2955,7 @@ version = "1.2.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "2392b6b94a576b4e2bf3c5b2757d63f10ada8020a2e4d08ac849ebcf6ea8e077" dependencies = [ - "curve25519-dalek", + "curve25519-dalek 3.2.1", "rand_core 0.5.1", "zeroize", ] diff --git a/Cargo.toml b/Cargo.toml index 67bf5b5..143d511 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -9,25 +9,36 @@ edition = "2021" publish = false [dependencies] +async-trait = "0.1.52" bytes = "1" -env_logger = "0.9.0" +fnv = "1.0" futures = "0.3.17" +futures-lite = "1.12.0" +futures-timer = "3.0" hex = "0.4" +if-watch = "0.2.2" libp2p-core = { version = "0.32.0", default-features = false, git = "https://github.com/melekes/rust-libp2p", branch = "anton/x-webrtc" } +libp2p-noise = { version = "0.35.0", git = "https://github.com/melekes/rust-libp2p", branch = "anton/x-webrtc" } log = "0.4.14" serde = { version = "1.0", features = ["derive"] } +stun = "0.4.2" thiserror = "1" tinytemplate = "1.2" -tokio-crate = { package = "tokio", version = "1.17.0", default-features = false, features = ["net"]} +tokio-crate = { package = "tokio", version = "1.18.2", features = ["net"]} webrtc = { version = "0.4.0", git = "https://github.com/melekes/webrtc", branch = "anton/168-allow-persistent-certificates" } -webrtc-ice = "0.6.6" webrtc-data = "0.3.3" -webrtc-sctp = "0.4.3" -if-watch = "0.2.2" -futures-timer = "3.0" -stun = "0.4.2" +webrtc-ice = "0.7.0" +webrtc-sctp = "0.5.0" webrtc-util = { version = "0.5.3", default-features = false, features = ["conn", "vnet", "sync"] } -async-trait = "0.1.52" [dev-dependencies] +env_logger = "0.9.0" +anyhow = "1.0.41" +rand = "0.8.4" +rand_core = "0.5.1" rcgen = "0.8.14" +libp2p-swarm = { version = "0.35.0", git = "https://github.com/melekes/rust-libp2p", branch = "anton/x-webrtc" } +libp2p-request-response = { version = "0.17.0", git = "https://github.com/melekes/rust-libp2p", branch = "anton/x-webrtc" } + +[patch.crates-io] +webrtc-data = { version = "0.3.3", git = "https://github.com/melekes/webrtc-data.git", branch = "anton/async-read-write-for-data-channel" } diff --git a/src/connection.rs b/src/connection.rs index 58d0b12..eae50f8 100644 --- a/src/connection.rs +++ b/src/connection.rs @@ -20,59 +20,316 @@ mod poll_data_channel; -use futures::prelude::*; +use fnv::FnvHashMap; +use futures::channel::mpsc; +use futures::channel::oneshot::{self, Sender}; +use futures::lock::Mutex as FutMutex; +use futures::{future::BoxFuture, prelude::*, ready}; +use futures_lite::stream::StreamExt; +use libp2p_core::muxing::{StreamMuxer, StreamMuxerEvent}; +use log::{debug, error, trace}; +use webrtc::data_channel::RTCDataChannel; use webrtc::peer_connection::RTCPeerConnection; -use webrtc_data::data_channel::DataChannel; +use webrtc_data::data_channel::DataChannel as DetachedDataChannel; use std::io; use std::pin::Pin; -use std::sync::Arc; +use std::sync::{Arc, Mutex as StdMutex}; use std::task::{Context, Poll}; -use poll_data_channel::PollDataChannel; +pub(crate) use poll_data_channel::PollDataChannel; -/// A WebRTC connection over a single data channel. See lib documentation for -/// the reasoning as to why a single data channel is being used. -pub struct Connection<'a> { +/// A WebRTC connection, wrapping [`RTCPeerConnection`] and implementing [`StreamMuxer`] trait. +pub struct Connection { + connection_inner: Arc>, // uses futures mutex because used in async code (see open_outbound) + data_channels_inner: StdMutex, +} + +struct ConnectionInner { /// `RTCPeerConnection` to the remote peer. - pub inner: RTCPeerConnection, - /// A data channel. - pub data_channel: PollDataChannel<'a>, + rtc_conn: RTCPeerConnection, +} + +struct DataChannelsInner { + /// A map of data channels. + map: FnvHashMap, + /// Channel onto which incoming data channels are put. + incoming_data_channels_rx: mpsc::Receiver>, + /// Temporary read buffer's capacity (equal for all data channels). + /// See [`PollDataChannel`] `read_buf_cap`. + read_buf_cap: Option, } -impl Connection<'_> { - pub fn new(peer_conn: RTCPeerConnection, data_channel: Arc) -> Self { +impl Connection { + /// Creates a new connection. + pub async fn new(rtc_conn: RTCPeerConnection) -> Self { + let (data_channel_tx, data_channel_rx) = mpsc::channel(10); + + Connection::register_incoming_data_channels_handler(&rtc_conn, data_channel_tx).await; + Self { - inner: peer_conn, - data_channel: PollDataChannel::new(data_channel), + connection_inner: Arc::new(FutMutex::new(ConnectionInner { rtc_conn })), + data_channels_inner: StdMutex::new(DataChannelsInner { + map: FnvHashMap::default(), + incoming_data_channels_rx: data_channel_rx, + read_buf_cap: None, + }), } } + + /// Set the capacity of a data channel's temporary read buffer (equal for all data channels; default: 8192). + pub fn set_data_channels_read_buf_capacity(&mut self, cap: usize) { + let mut data_channels_inner = self.data_channels_inner.lock().unwrap(); + data_channels_inner.read_buf_cap = Some(cap); + } + + /// Registers a handler for incoming data channels. + async fn register_incoming_data_channels_handler( + rtc_conn: &RTCPeerConnection, + tx: mpsc::Sender>, + ) { + rtc_conn + .on_data_channel(Box::new(move |data_channel: Arc| { + debug!( + "Incoming data channel '{}'-'{}'", + data_channel.label(), + data_channel.id() + ); + + let data_channel = data_channel.clone(); + let mut tx = tx.clone(); + + Box::pin(async move { + data_channel + .on_open({ + let data_channel = data_channel.clone(); + Box::new(move || { + debug!( + "Data channel '{}'-'{}' open", + data_channel.label(), + data_channel.id() + ); + + Box::pin(async move { + let data_channel = data_channel.clone(); + match data_channel.detach().await { + Ok(detached) => { + if let Err(e) = tx.try_send(detached) { + // This can happen if the client is not reading + // events (using `poll_event`) fast enough, which + // generally shouldn't be the case. + error!("Can't send data channel: {}", e); + } + }, + Err(e) => { + error!("Can't detach data channel: {}", e); + }, + }; + }) + }) + }) + .await; + }) + })) + .await; + } } -impl AsyncRead for Connection<'_> { - fn poll_read( - mut self: Pin<&mut Self>, +impl<'a> StreamMuxer for Connection { + type Substream = PollDataChannel; + type OutboundSubstream = BoxFuture<'static, Result, Self::Error>>; + type Error = io::Error; + + fn poll_event( + &self, cx: &mut Context<'_>, + ) -> Poll, Self::Error>> { + let mut data_channels_inner = self.data_channels_inner.lock().unwrap(); + match ready!(data_channels_inner.incoming_data_channels_rx.poll_next(cx)) { + Some(detached) => { + trace!("Incoming substream {}", detached.stream_identifier()); + + let ch = PollDataChannel::new(detached); + // if let Some(cap) = data_channels_inner.read_buf_cap { + // ch.set_read_buf_capacity(cap); + // } + + data_channels_inner + .map + .insert(ch.stream_identifier(), ch.clone()); + + Poll::Ready(Ok(StreamMuxerEvent::InboundSubstream(ch))) + }, + None => Poll::Ready(Err(io::Error::new( + io::ErrorKind::Other, + "incoming_data_channels_rx is closed (no messages left)", + ))), + } + } + + fn open_outbound(&self) -> Self::OutboundSubstream { + let connection_inner = self.connection_inner.clone(); + + Box::pin(async move { + let connection_inner = connection_inner.lock().await; + + // Create a datachannel with label 'data' + let data_channel = connection_inner + .rtc_conn + .create_data_channel("data", None) + .map_err(|e| io::Error::new(io::ErrorKind::Other, format!("webrtc error: {}", e))) + .await?; + + trace!("Opening outbound substream {}", data_channel.id()); + + // No need to hold the lock during the DTLS handshake. + drop(connection_inner); + + let (tx, rx) = oneshot::channel::>(); + + // Wait until the data channel is opened and detach it. + register_data_channel_open_handler(data_channel, tx).await; + + // Wait until data channel is opened and ready to use + match rx.await { + Ok(detached) => Ok(detached), + Err(e) => Err(io::Error::new(io::ErrorKind::Other, e.to_string())), + } + }) + } + + fn poll_outbound( + &self, + cx: &mut Context<'_>, + s: &mut Self::OutboundSubstream, + ) -> Poll> { + match ready!(s.as_mut().poll(cx)) { + Ok(detached) => { + let mut data_channels_inner = self.data_channels_inner.lock().unwrap(); + + let ch = PollDataChannel::new(detached); + // if let Some(cap) = data_channels_inner.read_buf_cap { + // ch.set_read_buf_capacity(cap); + // } + + data_channels_inner + .map + .insert(ch.stream_identifier(), ch.clone()); + + Poll::Ready(Ok(ch)) + }, + Err(e) => Poll::Ready(Err(e)), + } + } + + /// NOTE: `_s` might be waiting at one of the await points, and dropping the future will + /// abruptly interrupt the execution. + fn destroy_outbound(&self, _s: Self::OutboundSubstream) {} + + fn read_substream( + &self, + cx: &mut Context<'_>, + s: &mut Self::Substream, buf: &mut [u8], - ) -> Poll> { - Pin::new(&mut self.data_channel).poll_read(cx, buf) + ) -> Poll> { + Pin::new(s).poll_read(cx, buf) } -} -impl AsyncWrite for Connection<'_> { - fn poll_write( - mut self: Pin<&mut Self>, + fn write_substream( + &self, cx: &mut Context<'_>, + s: &mut Self::Substream, buf: &[u8], - ) -> Poll> { - Pin::new(&mut self.data_channel).poll_write(cx, buf) + ) -> Poll> { + Pin::new(s).poll_write(cx, buf) } - fn poll_flush(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { - Pin::new(&mut self.data_channel).poll_flush(cx) + fn flush_substream( + &self, + cx: &mut Context<'_>, + s: &mut Self::Substream, + ) -> Poll> { + Pin::new(s).poll_flush(cx) } - fn poll_close(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { - Pin::new(&mut self.data_channel).poll_close(cx) + fn shutdown_substream( + &self, + cx: &mut Context<'_>, + s: &mut Self::Substream, + ) -> Poll> { + trace!("Closing substream {}", s.stream_identifier()); + Pin::new(s).poll_close(cx) } + + fn destroy_substream(&self, s: Self::Substream) { + let mut data_channels_inner = self.data_channels_inner.lock().unwrap(); + data_channels_inner.map.remove(&s.stream_identifier()); + } + + fn close(&self, cx: &mut Context<'_>) -> Poll> { + debug!("Closing connection"); + + // First, flush all the buffered data. + match ready!(self.flush_all(cx)) { + Ok(_) => { + // Second, shutdown all the substreams. + let mut data_channels_inner = self.data_channels_inner.lock().unwrap(); + for (_, ch) in &mut data_channels_inner.map { + match ready!(self.shutdown_substream(cx, ch)) { + Ok(_) => continue, + Err(e) => return Poll::Ready(Err(e)), + } + } + + // Third, close `incoming_data_channels_rx` + data_channels_inner.incoming_data_channels_rx.close(); + + Poll::Ready(Ok(())) + }, + Err(e) => Poll::Ready(Err(e)), + } + } + + fn flush_all(&self, cx: &mut Context<'_>) -> Poll> { + let mut data_channels_inner = self.data_channels_inner.lock().unwrap(); + for (_, ch) in &mut data_channels_inner.map { + match ready!(self.flush_substream(cx, ch)) { + Ok(_) => continue, + Err(e) => return Poll::Ready(Err(e)), + } + } + Poll::Ready(Ok(())) + } +} + +pub(crate) async fn register_data_channel_open_handler( + data_channel: Arc, + data_channel_tx: Sender>, +) { + data_channel + .on_open({ + let data_channel = data_channel.clone(); + Box::new(move || { + debug!( + "Data channel '{}'-'{}' open", + data_channel.label(), + data_channel.id() + ); + + Box::pin(async move { + let data_channel = data_channel.clone(); + match data_channel.detach().await { + Ok(detached) => { + if let Err(e) = data_channel_tx.send(detached) { + error!("Can't send data channel: {:?}", e); + } + }, + Err(e) => { + error!("Can't detach data channel: {}", e); + }, + }; + }) + }) + }) + .await; } diff --git a/src/connection/poll_data_channel.rs b/src/connection/poll_data_channel.rs index a05ce4f..f9d53ad 100644 --- a/src/connection/poll_data_channel.rs +++ b/src/connection/poll_data_channel.rs @@ -18,266 +18,398 @@ // FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER // DEALINGS IN THE SOFTWARE. -use bytes::Bytes; - use futures::prelude::*; -use futures::ready; use webrtc_data::data_channel::DataChannel; -use webrtc_data::Error; +use webrtc_data::data_channel::PollDataChannel as RTCPollDataChannel; -use std::fmt; use std::io; use std::pin::Pin; use std::sync::Arc; use std::task::{Context, Poll}; -/// Default capacity of the temporary read buffer used by [`PollStream`]. -const DEFAULT_READ_BUF_SIZE: usize = 4096; - -/// State of the read `Future` in [`PollStream`]. -enum ReadFut { - /// Nothing in progress. - Idle, - /// Reading data from the underlying stream. - Reading(Pin, Error>> + Send>>), - /// Finished reading, but there's unread data in the temporary buffer. - RemainingData(Vec), -} +#[derive(Debug)] +pub struct PollDataChannel(RTCPollDataChannel); -impl ReadFut { - /// Gets a mutable reference to the future stored inside `Reading(future)`. - /// - /// # Panics - /// - /// Panics if `ReadFut` variant is not `Reading`. - fn get_reading_mut( - &mut self, - ) -> &mut Pin, Error>> + Send>> { - match self { - ReadFut::Reading(ref mut fut) => fut, - _ => panic!("expected ReadFut to be Reading"), - } +impl PollDataChannel { + /// Constructs a new `PollDataChannel`. + pub fn new(data_channel: Arc) -> Self { + Self(RTCPollDataChannel::new(data_channel)) } -} -/// A wrapper around around [`DataChannel`], which implements [`AsyncRead`] and -/// [`AsyncWrite`]. -/// -/// Both `poll_read` and `poll_write` calls allocate temporary buffers, which results in an -/// additional overhead. -pub struct PollDataChannel<'a> { - data_channel: Arc, + /// Get back the inner data_channel. + pub fn into_inner(self) -> RTCPollDataChannel { + self.0 + } - read_fut: ReadFut, - write_fut: Option> + Send + 'a>>>, - shutdown_fut: Option> + Send + 'a>>>, + /// Obtain a clone of the inner data_channel. + pub fn clone_inner(&self) -> RTCPollDataChannel { + self.0.clone() + } - read_buf_cap: usize, -} + /// MessagesSent returns the number of messages sent + pub fn messages_sent(&self) -> usize { + self.0.messages_sent() + } -impl PollDataChannel<'_> { - /// Constructs a new [`PollDataChannel`]. - pub fn new(data_channel: Arc) -> Self { - Self { - data_channel, - read_fut: ReadFut::Idle, - write_fut: None, - shutdown_fut: None, - read_buf_cap: DEFAULT_READ_BUF_SIZE, - } + /// MessagesReceived returns the number of messages received + pub fn messages_received(&self) -> usize { + self.0.messages_received() } - /// Get back the inner data_channel. - pub fn into_inner(self) -> Arc { - self.data_channel + /// BytesSent returns the number of bytes sent + pub fn bytes_sent(&self) -> usize { + self.0.bytes_sent() } - /// Obtain a clone of the inner data_channel. - pub fn clone_inner(&self) -> Arc { - self.data_channel.clone() + /// BytesReceived returns the number of bytes received + pub fn bytes_received(&self) -> usize { + self.0.bytes_received() } - /// Set the capacity of the temporary read buffer (default: 4096). - pub fn set_read_buf_capacity(&mut self, capacity: usize) { - self.read_buf_cap = capacity + /// StreamIdentifier returns the Stream identifier associated to the stream. + pub fn stream_identifier(&self) -> u16 { + self.0.stream_identifier() } + + /// BufferedAmount returns the number of bytes of data currently queued to be + /// sent over this stream. + pub fn buffered_amount(&self) -> usize { + self.0.buffered_amount() + } + + /// BufferedAmountLowThreshold returns the number of bytes of buffered outgoing + /// data that is considered "low." Defaults to 0. + pub fn buffered_amount_low_threshold(&self) -> usize { + self.0.buffered_amount_low_threshold() + } + + // TODO + // Set the capacity of the temporary read buffer (default: 8192). + // pub fn set_read_buf_capacity(&mut self, capacity: usize) { + // self.0.set_read_buf_capacity(capacity) + // } } -impl AsyncRead for PollDataChannel<'_> { +impl AsyncRead for PollDataChannel { fn poll_read( mut self: Pin<&mut Self>, cx: &mut Context<'_>, buf: &mut [u8], ) -> Poll> { - let fut = match self.read_fut { - ReadFut::Idle => { - // Read into a temporary buffer because `buf` has an anonymous lifetime, which can - // be shorter than the lifetime of `read_fut`. - let dc = self.data_channel.clone(); - let mut temp_buf = vec![0; self.read_buf_cap]; - self.read_fut = ReadFut::Reading(Box::pin(async move { - dc.read(temp_buf.as_mut_slice()).await.map(|n| { - temp_buf.truncate(n); - temp_buf - }) - })); - self.read_fut.get_reading_mut() - }, - ReadFut::Reading(ref mut fut) => fut, - ReadFut::RemainingData(ref mut data) => { - let remaining = buf.len(); - let len = std::cmp::min(data.len(), remaining); - buf.copy_from_slice(&data[..len]); - if data.len() > remaining { - // ReadFut remains to be RemainingData - data.drain(0..len); - } else { - self.read_fut = ReadFut::Idle; - } - return Poll::Ready(Ok(len)); - }, - }; - - loop { - match ready!(fut.as_mut().poll(cx)) { - // Retry immediately upon empty data or incomplete chunks - // since there's no way to setup a waker. - Err(Error::Sctp(webrtc_sctp::Error::ErrTryAgain)) => {}, - // EOF has been reached => don't touch buf and just return Ok - Err(Error::Sctp(webrtc_sctp::Error::ErrEof)) => { - self.read_fut = ReadFut::Idle; - return Poll::Ready(Ok(0)); - }, - Err(e) => { - self.read_fut = ReadFut::Idle; - return Poll::Ready(Err(webrtc_error_to_io(e))); - }, - Ok(mut temp_buf) => { - let remaining = buf.len(); - let len = std::cmp::min(temp_buf.len(), remaining); - buf.copy_from_slice(&temp_buf[..len]); - if temp_buf.len() > remaining { - temp_buf.drain(0..len); - self.read_fut = ReadFut::RemainingData(temp_buf); - } else { - self.read_fut = ReadFut::Idle; - } - return Poll::Ready(Ok(len)); - }, - } - } + let mut read_buf = tokio_crate::io::ReadBuf::new(buf); + futures::ready!(tokio_crate::io::AsyncRead::poll_read( + Pin::new(&mut self.0), + cx, + &mut read_buf + ))?; + Poll::Ready(Ok(read_buf.filled().len())) } } -impl AsyncWrite for PollDataChannel<'_> { +impl AsyncWrite for PollDataChannel { fn poll_write( mut self: Pin<&mut Self>, cx: &mut Context<'_>, buf: &[u8], ) -> Poll> { - let (fut, fut_is_new) = match self.write_fut.as_mut() { - Some(fut) => (fut, false), - None => { - let dc = self.data_channel.clone(); - let bytes = Bytes::copy_from_slice(buf); - ( - self.write_fut - .get_or_insert(Box::pin(async move { dc.write(&bytes).await })), - true, - ) - }, - }; - - match fut.as_mut().poll(cx) { - Poll::Pending => { - // If it's the first time we're polling the future, `Poll::Pending` can't be - // returned because that would mean the `PollDataChannel` is not ready for writing. And - // this is not true since we've just created a future, which is going to write the - // buf to the underlying dc. - // - // It's okay to return `Poll::Ready` if the data is buffered (this is what the - // buffered writer and `File` do). - if fut_is_new { - Poll::Ready(Ok(buf.len())) - } else { - // If it's the subsequent poll, it's okay to return `Poll::Pending` as it - // indicates that the `PollDataChannel` is not ready for writing. Only one future - // can be in progress at the time. - Poll::Pending - } - }, - Poll::Ready(Err(e)) => { - self.write_fut = None; - Poll::Ready(Err(webrtc_error_to_io(e))) - }, - Poll::Ready(Ok(n)) => { - self.write_fut = None; - Poll::Ready(Ok(n)) - }, - } + tokio_crate::io::AsyncWrite::poll_write(Pin::new(&mut self.0), cx, buf) } fn poll_flush(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { - match self.write_fut.as_mut() { - Some(fut) => match ready!(fut.as_mut().poll(cx)) { - Err(e) => { - self.write_fut = None; - Poll::Ready(Err(webrtc_error_to_io(e))) - }, - Ok(_) => { - self.write_fut = None; - Poll::Ready(Ok(())) - }, - }, - None => Poll::Ready(Ok(())), - } + tokio_crate::io::AsyncWrite::poll_flush(Pin::new(&mut self.0), cx) } fn poll_close(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { - let fut = match self.shutdown_fut.as_mut() { - Some(fut) => fut, - None => { - let data_channel = self.data_channel.clone(); - self.shutdown_fut - .get_or_insert(Box::pin(async move { data_channel.close().await })) - }, - }; - - match ready!(fut.as_mut().poll(cx)) { - Err(e) => Poll::Ready(Err(webrtc_error_to_io(e))), - Ok(_) => Poll::Ready(Ok(())), - } + tokio_crate::io::AsyncWrite::poll_shutdown(Pin::new(&mut self.0), cx) } -} -impl<'a> Clone for PollDataChannel<'a> { - fn clone(&self) -> PollDataChannel<'a> { - PollDataChannel::new(self.clone_inner()) + fn poll_write_vectored( + mut self: Pin<&mut Self>, + cx: &mut Context<'_>, + bufs: &[io::IoSlice<'_>], + ) -> Poll> { + tokio_crate::io::AsyncWrite::poll_write_vectored(Pin::new(&mut self.0), cx, bufs) } } -impl fmt::Debug for PollDataChannel<'_> { - fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { - f.debug_struct("PollDataChannel") - .field("data_channel", &self.data_channel) - .finish() +impl Clone for PollDataChannel { + fn clone(&self) -> PollDataChannel { + PollDataChannel(self.clone_inner()) } } -impl AsRef for PollDataChannel<'_> { - fn as_ref(&self) -> &DataChannel { - &*self.data_channel - } -} +//use bytes::Bytes; -fn webrtc_error_to_io(error: Error) -> io::Error { - match error { - e @ Error::Sctp(webrtc_sctp::Error::ErrEof) => { - io::Error::new(io::ErrorKind::UnexpectedEof, e.to_string()) - }, - e @ Error::ErrStreamClosed => { - io::Error::new(io::ErrorKind::ConnectionAborted, e.to_string()) - }, - e => io::Error::new(io::ErrorKind::Other, e.to_string()), - } -} +//use futures::prelude::*; +//use futures::ready; +//use webrtc_data::data_channel::DataChannel; +//use webrtc_data::Error; + +//use std::fmt; +//use std::io; +//use std::pin::Pin; +//use std::sync::Arc; +//use std::task::{Context, Poll}; + +///// Default capacity of the temporary read buffer used by [`PollStream`]. +//const DEFAULT_READ_BUF_SIZE: usize = 8192; + +///// State of the read `Future` in [`PollStream`]. +//enum ReadFut { +// /// Nothing in progress. +// Idle, +// /// Reading data from the underlying stream. +// Reading(Pin, Error>> + Send>>), +// /// Finished reading, but there's unread data in the temporary buffer. +// RemainingData(Vec), +//} + +//impl ReadFut { +// /// Gets a mutable reference to the future stored inside `Reading(future)`. +// /// +// /// # Panics +// /// +// /// Panics if `ReadFut` variant is not `Reading`. +// fn get_reading_mut( +// &mut self, +// ) -> &mut Pin, Error>> + Send>> { +// match self { +// ReadFut::Reading(ref mut fut) => fut, +// _ => panic!("expected ReadFut to be Reading"), +// } +// } +//} + +///// A wrapper around around [`DataChannel`], which implements [`AsyncRead`] and +///// [`AsyncWrite`]. +///// +///// Both `poll_read` and `poll_write` calls allocate temporary buffers, which results in an +///// additional overhead. +//pub struct PollDataChannel { +// data_channel: Arc, + +// read_fut: ReadFut, +// write_fut: Option> + Send>>>, +// shutdown_fut: Option> + Send>>>, + +// read_buf_cap: usize, +//} + +//impl PollDataChannel { +// /// Constructs a new [`PollDataChannel`]. +// pub fn new(data_channel: Arc) -> Self { +// Self { +// data_channel, +// read_fut: ReadFut::Idle, +// write_fut: None, +// shutdown_fut: None, +// read_buf_cap: DEFAULT_READ_BUF_SIZE, +// } +// } + +// /// Get back the inner data_channel. +// pub fn into_inner(self) -> Arc { +// self.data_channel +// } + +// /// Obtain a clone of the inner data_channel. +// pub fn clone_inner(&self) -> Arc { +// self.data_channel.clone() +// } + +// /// Set the capacity of the temporary read buffer (default: 8192). +// pub fn set_read_buf_capacity(&mut self, cap: usize) { +// self.read_buf_cap = cap +// } + +// /// StreamIdentifier returns the Stream identifier associated to the stream. +// pub fn stream_identifier(&self) -> u16 { +// self.data_channel.stream_identifier() +// } +//} + +//impl AsyncRead for PollDataChannel { +// fn poll_read( +// mut self: Pin<&mut Self>, +// cx: &mut Context<'_>, +// buf: &mut [u8], +// ) -> Poll> { +// log::debug!("read buf BEFORE {:?}", buf); +// if buf.len() == 0 { +// return Poll::Ready(Ok(0)); +// } + +// let fut = match self.read_fut { +// ReadFut::Idle => { +// // Read into a temporary buffer because `buf` has an anonymous lifetime, which can +// // be shorter than the lifetime of `read_fut`. +// let dc = self.data_channel.clone(); +// let mut temp_buf = vec![0; self.read_buf_cap]; +// self.read_fut = ReadFut::Reading(Box::pin(async move { +// dc.read(temp_buf.as_mut_slice()).await.map(|n| { +// temp_buf.truncate(n); +// temp_buf +// }) +// })); +// self.read_fut.get_reading_mut() +// }, +// ReadFut::Reading(ref mut fut) => fut, +// ReadFut::RemainingData(ref mut data) => { +// let remaining = buf.len(); +// let len = std::cmp::min(data.len(), remaining); +// buf.copy_from_slice(&data[..len]); +// if data.len() > remaining { +// // ReadFut remains to be RemainingData +// data.drain(0..len); +// } else { +// self.read_fut = ReadFut::Idle; +// } +// log::debug!("read buf AFTER {:?}", buf); +// return Poll::Ready(Ok(len)); +// }, +// }; + +// loop { +// match ready!(fut.as_mut().poll(cx)) { +// // Retry immediately upon empty data or incomplete chunks +// // since there's no way to setup a waker. +// Err(Error::Sctp(webrtc_sctp::Error::ErrTryAgain)) => {}, +// // EOF has been reached => don't touch buf and just return Ok +// Err(Error::Sctp(webrtc_sctp::Error::ErrEof)) => { +// self.read_fut = ReadFut::Idle; +// return Poll::Ready(Ok(0)); +// }, +// Err(e) => { +// self.read_fut = ReadFut::Idle; +// return Poll::Ready(Err(webrtc_error_to_io(e))); +// }, +// Ok(mut temp_buf) => { +// let remaining = buf.len(); +// let len = std::cmp::min(temp_buf.len(), remaining); +// buf.copy_from_slice(&temp_buf[..len]); +// if temp_buf.len() > remaining { +// temp_buf.drain(0..len); +// self.read_fut = ReadFut::RemainingData(temp_buf); +// } else { +// self.read_fut = ReadFut::Idle; +// } +// log::debug!("read buf AFTER {:?}", buf); +// return Poll::Ready(Ok(len)); +// }, +// } +// } +// } +//} + +//impl AsyncWrite for PollDataChannel { +// fn poll_write( +// mut self: Pin<&mut Self>, +// cx: &mut Context<'_>, +// buf: &[u8], +// ) -> Poll> { +// let (fut, fut_is_new) = match self.write_fut.as_mut() { +// Some(fut) => (fut, false), +// None => { +// let dc = self.data_channel.clone(); +// let bytes = Bytes::copy_from_slice(buf); +// ( +// self.write_fut +// .get_or_insert(Box::pin(async move { dc.write(&bytes).await })), +// true, +// ) +// }, +// }; + +// match fut.as_mut().poll(cx) { +// Poll::Pending => { +// // If it's the first time we're polling the future, `Poll::Pending` can't be +// // returned because that would mean the `PollDataChannel` is not ready for writing. And +// // this is not true since we've just created a future, which is going to write the +// // buf to the underlying dc. +// // +// // It's okay to return `Poll::Ready` if the data is buffered (this is what the +// // buffered writer and `File` do). +// if fut_is_new { +// Poll::Ready(Ok(buf.len())) +// } else { +// // If it's the subsequent poll, it's okay to return `Poll::Pending` as it +// // indicates that the `PollDataChannel` is not ready for writing. Only one future +// // can be in progress at the time. +// Poll::Pending +// } +// }, +// Poll::Ready(Err(e)) => { +// self.write_fut = None; +// Poll::Ready(Err(webrtc_error_to_io(e))) +// }, +// Poll::Ready(Ok(n)) => { +// self.write_fut = None; +// Poll::Ready(Ok(n)) +// }, +// } +// } + +// fn poll_flush(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { +// match self.write_fut.as_mut() { +// Some(fut) => match ready!(fut.as_mut().poll(cx)) { +// Err(e) => { +// self.write_fut = None; +// Poll::Ready(Err(webrtc_error_to_io(e))) +// }, +// Ok(_) => { +// self.write_fut = None; +// Poll::Ready(Ok(())) +// }, +// }, +// None => Poll::Ready(Ok(())), +// } +// } + +// fn poll_close(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { +// let fut = match self.shutdown_fut.as_mut() { +// Some(fut) => fut, +// None => { +// let data_channel = self.data_channel.clone(); +// self.shutdown_fut +// .get_or_insert(Box::pin(async move { data_channel.close().await })) +// }, +// }; + +// match ready!(fut.as_mut().poll(cx)) { +// Err(e) => Poll::Ready(Err(webrtc_error_to_io(e))), +// Ok(_) => Poll::Ready(Ok(())), +// } +// } +//} + +//impl Clone for PollDataChannel { +// fn clone(&self) -> PollDataChannel { +// PollDataChannel::new(self.clone_inner()) +// } +//} + +//impl fmt::Debug for PollDataChannel { +// fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { +// f.debug_struct("PollDataChannel") +// .field("data_channel", &self.data_channel) +// .finish() +// } +//} + +//impl AsRef for PollDataChannel { +// fn as_ref(&self) -> &DataChannel { +// &*self.data_channel +// } +//} + +//fn webrtc_error_to_io(error: Error) -> io::Error { +// match error { +// e @ Error::Sctp(webrtc_sctp::Error::ErrEof) => { +// io::Error::new(io::ErrorKind::UnexpectedEof, e.to_string()) +// }, +// e @ Error::ErrStreamClosed => { +// io::Error::new(io::ErrorKind::ConnectionAborted, e.to_string()) +// }, +// e => io::Error::new(io::ErrorKind::Other, e.to_string()), +// } +//} diff --git a/src/error.rs b/src/error.rs index 4b4e680..7a73da6 100644 --- a/src/error.rs +++ b/src/error.rs @@ -18,6 +18,7 @@ // FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER // DEALINGS IN THE SOFTWARE. +use libp2p_core::PeerId; use thiserror::Error; /// Error in WebRTC. @@ -29,6 +30,18 @@ pub enum Error { WebRTC(#[from] webrtc::Error), #[error("io error: {0}")] IoError(#[from] std::io::Error), + #[error("noise error: {0}")] + Noise(#[from] libp2p_noise::NoiseError), + + // Authentication errors. + #[error("invalid fingerprint (expected {expected:?}, got {got:?})")] + InvalidFingerprint { expected: String, got: String }, + #[error("invalid peer ID (expected {expected:?}, got {got:?})")] + InvalidPeerID { + expected: Option, + got: PeerId, + }, + #[error("internal error: {0} (see debug logs)")] InternalError(String), } diff --git a/src/transport.rs b/src/transport.rs index 2ccc309..a25ed45 100644 --- a/src/transport.rs +++ b/src/transport.rs @@ -18,21 +18,26 @@ // FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER // DEALINGS IN THE SOFTWARE. -use libp2p_core::{ - multiaddr::{Multiaddr, Protocol}, - transport::{ListenerEvent, TransportError}, - Transport, -}; - +use futures::io::{AsyncReadExt, AsyncWriteExt}; use futures::{ channel::{mpsc, oneshot}, + future, future::BoxFuture, prelude::*, ready, select, TryFutureExt, }; use futures_timer::Delay; use if_watch::{IfEvent, IfWatcher}; -use log::{debug, error, trace}; +use libp2p_core::identity; +use libp2p_core::{ + multiaddr::{Multiaddr, Protocol}, + muxing::StreamMuxerBox, + transport::{Boxed, ListenerEvent, TransportError}, + PeerId, Transport, +}; +use libp2p_core::{OutboundUpgrade, UpgradeInfo}; +use libp2p_noise::{Keypair, NoiseConfig, NoiseError, RemoteIdentity, X25519Spec}; +use log::{debug, trace}; use tinytemplate::TinyTemplate; use tokio_crate::net::{ToSocketAddrs, UdpSocket}; use webrtc::api::setting_engine::SettingEngine; @@ -59,6 +64,7 @@ use crate::error::Error; use crate::sdp; use crate::connection::Connection; +use crate::connection::PollDataChannel; use crate::udp_mux::UDPMuxNewAddr; use crate::udp_mux::UDPMuxParams; use crate::upgrade; @@ -90,6 +96,9 @@ pub struct WebRTCDirectTransport { /// The receiver for new `SocketAddr` connecting to this peer. new_addr_rx: Arc>>, + + /// `Keypair` identifying this peer + id_keys: identity::Keypair, } impl WebRTCDirectTransport { @@ -98,6 +107,7 @@ impl WebRTCDirectTransport { /// Creates a UDP socket bound to `listen_addr`. pub async fn new( certificate: RTCCertificate, + id_keys: identity::Keypair, listen_addr: A, ) -> Result> { // Bind to `listen_addr` and construct a UDP mux. @@ -116,23 +126,32 @@ impl WebRTCDirectTransport { Ok(Self { config: RTCConfiguration { certificates: vec![certificate], - ..Default::default() + ..RTCConfiguration::default() }, udp_mux, udp_mux_addr, new_addr_rx: Arc::new(Mutex::new(new_addr_rx)), + id_keys, }) } /// Returns the SHA-256 fingerprint of the certificate in lowercase hex string as expressed /// utilizing the syntax of 'fingerprint' in . - fn cert_fingerprint(&self) -> String { + pub fn cert_fingerprint(&self) -> String { fingerprint_of_first_certificate(&self.config) } + + /// Creates a boxed libp2p transport. + pub fn boxed(self) -> Boxed<(PeerId, StreamMuxerBox)> { + Transport::map(self, |(peer_id, conn), _| { + (peer_id, StreamMuxerBox::new(conn)) + }) + .boxed() + } } impl Transport for WebRTCDirectTransport { - type Output = Connection<'static>; + type Output = (PeerId, Connection); type Error = Error; type Listener = WebRTCListenStream; type ListenerUpgrade = BoxFuture<'static, Result>; @@ -145,6 +164,7 @@ impl Transport for WebRTCDirectTransport { self.config.clone(), self.udp_mux.clone(), self.new_addr_rx.clone(), + self.id_keys.clone(), )) } @@ -187,6 +207,8 @@ pub struct WebRTCListenStream { udp_mux: Arc, /// The receiver for new `SocketAddr` connecting to this peer. new_addr_rx: Arc>>, + /// `Keypair` identifying this peer + id_keys: identity::Keypair, } impl WebRTCListenStream { @@ -197,6 +219,7 @@ impl WebRTCListenStream { config: RTCConfiguration, udp_mux: Arc, new_addr_rx: Arc>>, + id_keys: identity::Keypair, ) -> Self { // Check whether the listening IP is set or not. let in_addr = if match &listen_addr { @@ -222,13 +245,16 @@ impl WebRTCListenStream { config, udp_mux, new_addr_rx, + id_keys, } } } impl Stream for WebRTCListenStream { - type Item = - Result, Error>>, Error>, Error>; + type Item = Result< + ListenerEvent>, Error>, + Error, + >; fn poll_next(self: Pin<&mut Self>, cx: &mut Context) -> Poll> { let me = Pin::into_inner(self); @@ -320,8 +346,12 @@ impl Stream for WebRTCListenStream { Poll::Ready(Some(addr)) => Poll::Ready(Some(Ok(ListenerEvent::Upgrade { local_addr: ip_to_multiaddr(me.listen_addr.ip(), me.listen_addr.port()), remote_addr: addr.clone(), - upgrade: Box::pin(upgrade::webrtc(me.udp_mux.clone(), me.config.clone(), addr)) - as BoxFuture<'static, _>, + upgrade: Box::pin(upgrade::webrtc( + me.udp_mux.clone(), + me.config.clone(), + addr, + me.id_keys.clone(), + )) as BoxFuture<'static, _>, }))), Poll::Ready(None) => Poll::Ready(None), Poll::Pending => Poll::Pending, @@ -331,7 +361,7 @@ impl Stream for WebRTCListenStream { } impl WebRTCDirectTransport { - async fn do_dial(self, addr: Multiaddr) -> Result, Error> { + async fn do_dial(self, addr: Multiaddr) -> Result<(PeerId, Connection), Error> { let socket_addr = multiaddr_to_socketaddr(&addr).ok_or_else(|| Error::InvalidMultiaddr(addr.clone()))?; if socket_addr.port() == 0 || socket_addr.ip().is_unspecified() { @@ -343,8 +373,8 @@ impl WebRTCDirectTransport { let remote = addr.clone(); // used for logging trace!("dialing address: {:?}", remote); - let fingerprint = self.cert_fingerprint(); - let se = build_setting_engine(self.udp_mux.clone(), &socket_addr, &fingerprint); + let our_fingerprint = self.cert_fingerprint(); + let se = build_setting_engine(self.udp_mux.clone(), &socket_addr, &our_fingerprint); let api = APIBuilder::new().with_setting_engine(se).build(); let peer_connection = api @@ -352,51 +382,6 @@ impl WebRTCDirectTransport { .map_err(Error::WebRTC) .await?; - // Create a datachannel with label 'data' - let data_channel = peer_connection - .create_data_channel( - "data", - Some(RTCDataChannelInit { - negotiated: None, - id: Some(1), - ordered: None, - max_retransmits: None, - max_packet_life_time: None, - protocol: None, - }), - ) - .await?; - - let (data_channel_rx, mut data_channel_tx) = oneshot::channel::>(); - - // Wait until the data channel is opened and detach it. - data_channel - .on_open({ - let data_channel = data_channel.clone(); - Box::new(move || { - debug!( - "Data channel '{}'-'{}' open.", - data_channel.label(), - data_channel.id() - ); - - Box::pin(async move { - let data_channel = data_channel.clone(); - match data_channel.detach().await { - Ok(detached) => { - if let Err(_) = data_channel_rx.send(detached) { - error!("data_channel_tx dropped"); - } - }, - Err(e) => { - error!("Can't detach data channel: {}", e); - }, - }; - }) - }) - }) - .await; - let offer = peer_connection .create_offer(None) .map_err(Error::WebRTC) @@ -408,16 +393,14 @@ impl WebRTCDirectTransport { .await?; // Set the remote description to the predefined SDP. - let fingerprint = match addr.iter().last() { - Some(Protocol::XWebRTC(f)) => f, - _ => { - return Err(Error::InvalidMultiaddr(addr)); - }, + let remote_fingerprint = match fingerprint_from_addr(&addr) { + Some(f) => fingerprint_to_string(&f), + None => return Err(Error::InvalidMultiaddr(addr.clone())), }; let server_session_description = render_description( sdp::SERVER_SESSION_DESCRIPTION, socket_addr, - &fingerprint_to_string(&fingerprint), + &remote_fingerprint, ); debug!("ANSWER: {:?}", server_session_description); let sdp = RTCSessionDescription::answer(server_session_description).unwrap(); @@ -428,16 +411,80 @@ impl WebRTCDirectTransport { .map_err(Error::WebRTC) .await?; + // Create a datachannel with label 'data' + let data_channel = peer_connection + .create_data_channel( + "data", + Some(RTCDataChannelInit { + id: Some(1), + ..RTCDataChannelInit::default() + }), + ) + .await?; + + let (tx, mut rx) = oneshot::channel::>(); + + // Wait until the data channel is opened and detach it. + crate::connection::register_data_channel_open_handler(data_channel, tx).await; + // Wait until data channel is opened and ready to use - select! { - res = data_channel_tx => match res { - Ok(dc) => Ok(Connection::new(peer_connection, dc)), - Err(e) => Err(Error::InternalError(e.to_string())), + let detached = select! { + res = rx => match res { + Ok(detached) => detached, + Err(e) => return Err(Error::InternalError(e.to_string())), }, - _ = Delay::new(Duration::from_secs(10)).fuse() => Err(Error::InternalError( + _ = Delay::new(Duration::from_secs(10)).fuse() => return Err(Error::InternalError( "data channel opening took longer than 10 seconds (see logs)".into(), )) + }; + + trace!("noise handshake with {}", remote); + let dh_keys = Keypair::::new() + .into_authentic(&self.id_keys) + .unwrap(); + let noise = NoiseConfig::xx(dh_keys); + let info = noise.protocol_info().next().unwrap(); + let (peer_id, mut noise_io) = noise + .upgrade_outbound(PollDataChannel::new(detached), info) + .and_then(|(remote, io)| match remote { + RemoteIdentity::IdentityKey(pk) => future::ok((pk.to_peer_id(), io)), + _ => future::err(NoiseError::AuthenticationFailed), + }) + .await + .map_err(Error::Noise)?; + + // Exchange TLS certificate fingerprints to prevent MiM attacks. + trace!("exchanging TLS certificate fingerprints with {}", remote); + let n = noise_io.write(&our_fingerprint.into_bytes()).await?; + noise_io.flush().await?; + let mut buf = vec![0; n]; // ASSERT: fingerprint's format is the same. + noise_io.read_exact(buf.as_mut_slice()).await?; + let fingerprint_from_noise = + String::from_utf8(buf).map_err(|_| Error::Noise(NoiseError::AuthenticationFailed))?; + if fingerprint_from_noise != remote_fingerprint { + return Err(Error::InvalidFingerprint { + expected: remote_fingerprint, + got: fingerprint_from_noise, + }); + } + + trace!("verifying peer's identity {}", remote); + let peer_id_from_addr = PeerId::try_from_multiaddr(&addr); + if peer_id_from_addr.is_none() || peer_id_from_addr.unwrap() != peer_id { + return Err(Error::InvalidPeerID { + expected: peer_id_from_addr, + got: peer_id, + }); } + + // Close the initial data channel after noise handshake is done. + // https://github.com/webrtc-rs/sctp/pull/14 + // detached + // .close() + // .await + // .map_err(|e| Error::WebRTC(e.into()))?; + + Ok((peer_id, Connection::new(peer_connection).await)) } } @@ -451,7 +498,7 @@ pub(crate) fn render_description(description: &str, addr: SocketAddr, fingerprin let mut tt = TinyTemplate::new(); tt.add_template("description", description).unwrap(); - let f = fingerprint.to_owned().replace(":", ""); + let f = fingerprint.to_owned().replace(':', ""); let context = sdp::DescriptionContext { ip_version: { if addr.is_ipv4() { @@ -479,7 +526,7 @@ pub(crate) fn multiaddr_to_socketaddr(addr: &Multiaddr) -> Option { let proto2 = iter.next()?; let proto3 = iter.next()?; - while let Some(proto) = iter.next() { + for proto in iter { match proto { Protocol::P2p(_) => {}, // Ignore a `/p2p/...` prefix of possibly outer protocols, if present. _ => return None, @@ -516,7 +563,7 @@ pub(crate) fn fingerprint_of_first_certificate(config: &RTCConfiguration) -> Str .expect("at least one certificate") .get_fingerprints() .expect("fingerprints to succeed"); - fingerprints.first().unwrap().value.to_owned() + fingerprints.first().unwrap().value.clone() } /// Creates a new [`SettingEngine`] and configures it. @@ -528,7 +575,7 @@ pub(crate) fn build_setting_engine( let mut se = SettingEngine::default(); // Set both ICE user and password to fingerprint. // It will be checked by remote side when exchanging ICE messages. - let f = fingerprint.to_owned().replace(":", ""); + let f = fingerprint.to_owned().replace(':', ""); se.set_ice_credentials(f.clone(), f); se.set_udp_network(UDPNetwork::Muxed(udp_mux.clone())); // Allow detaching data channels. @@ -546,6 +593,17 @@ pub(crate) fn build_setting_engine( se } +fn fingerprint_from_addr<'a>(addr: &'a Multiaddr) -> Option> { + let iter = addr.iter(); + for proto in iter { + match proto { + Protocol::XWebRTC(f) => return Some(f), + _ => continue, + } + } + None +} + // Tests ////////////////////////////////////////////////////////////////////////////////////////// #[cfg(test)] @@ -638,7 +696,7 @@ mod tests { async fn dialer_connects_to_listener_ipv4() { let _ = env_logger::builder().is_test(true).try_init(); let a = "127.0.0.1:0".parse().unwrap(); - connect(a).await + connect(a).await; } #[tokio::test] @@ -649,10 +707,12 @@ mod tests { } async fn connect(listen_addr: SocketAddr) { + let id_keys = identity::Keypair::generate_ed25519(); + let t1_peer_id = PeerId::from_public_key(&id_keys.public()); let transport = { let kp = KeyPair::generate(&rcgen::PKCS_ECDSA_P256_SHA256).expect("key pair"); let cert = RTCCertificate::from_key_pair(kp).expect("certificate"); - WebRTCDirectTransport::new(cert, listen_addr) + WebRTCDirectTransport::new(cert, id_keys, listen_addr) .await .expect("transport") }; @@ -684,16 +744,20 @@ mod tests { let transport2 = { let kp = KeyPair::generate(&rcgen::PKCS_ECDSA_P256_SHA256).expect("key pair"); + let id_keys = identity::Keypair::generate_ed25519(); let cert = RTCCertificate::from_key_pair(kp).expect("certificate"); // okay to reuse `listen_addr` since the port is `0` (any). - WebRTCDirectTransport::new(cert, listen_addr) + WebRTCDirectTransport::new(cert, id_keys, listen_addr) .await .expect("transport") }; // TODO: make code cleaner wrt ":" - let f = &transport.cert_fingerprint().replace(":", ""); + let f = &transport.cert_fingerprint().replace(':', ""); let outbound = transport2 - .dial(addr.with(Protocol::XWebRTC(hex_to_cow(f)))) + .dial( + addr.with(Protocol::XWebRTC(hex_to_cow(f))) + .with(Protocol::P2p(t1_peer_id.into())), + ) .unwrap(); let (a, b) = futures::join!(inbound, outbound); diff --git a/src/udp_mux.rs b/src/udp_mux.rs index 11a3b7d..a993864 100644 --- a/src/udp_mux.rs +++ b/src/udp_mux.rs @@ -25,22 +25,17 @@ use std::{ collections::{HashMap, HashSet}, io::ErrorKind, net::SocketAddr, - sync::Arc, + sync::{Arc, Weak}, }; use futures::channel::mpsc; use libp2p_core::multiaddr::{Multiaddr, Protocol}; -use webrtc_ice::udp_mux::UDPMux; +use webrtc_ice::udp_mux::{UDPMux, UDPMuxConn, UDPMuxConnParams, UDPMuxWriter}; use webrtc_util::{sync::RwLock, Conn, Error}; use tokio_crate as tokio; use tokio_crate::sync::{watch, Mutex}; -mod socket_addr_ext; - -mod udp_mux_conn; -use udp_mux_conn::{UDPMuxConn, UDPMuxConnParams}; - use async_trait::async_trait; use stun::{ @@ -127,14 +122,6 @@ impl UDPMuxNewAddr { self.closed_watch_tx.lock().await.is_none() } - async fn send_to(&self, buf: &[u8], target: &SocketAddr) -> Result { - self.params - .conn - .send_to(buf, *target) - .await - .map_err(Into::into) - } - /// Create a muxed connection for a given ufrag. async fn create_muxed_conn(self: &Arc, ufrag: &str) -> Result { let local_addr = self.params.conn.local_addr().await?; @@ -142,41 +129,12 @@ impl UDPMuxNewAddr { let params = UDPMuxConnParams { local_addr, key: ufrag.into(), - udp_mux: Arc::clone(self), + udp_mux: Arc::downgrade(self) as Weak, }; Ok(UDPMuxConn::new(params)) } - async fn register_conn_for_address(&self, conn: &UDPMuxConn, addr: SocketAddr) { - if self.is_closed().await { - return; - } - - let key = conn.key(); - { - let mut addresses = self.address_map.write(); - - addresses - .entry(addr) - .and_modify(|e| { - if e.key() != key { - e.remove_address(&addr); - *e = conn.clone() - } - }) - .or_insert_with(|| conn.clone()); - } - - // remove addr from new_addrs once conn is established - { - let mut new_addrs = self.new_addrs.write(); - new_addrs.remove(&addr); - } - - log::debug!("Registered {} for {}", addr, key); - } - async fn conn_from_stun_message(&self, buffer: &[u8], addr: &SocketAddr) -> Option { match ufrag_from_stun_message(buffer, true) { Ok(ufrag) => { @@ -234,12 +192,12 @@ impl UDPMuxNewAddr { let a = Multiaddr::empty() .with(addr.ip().into()) .with(Protocol::Udp(addr.port())) - .with(Protocol::XWebRTC(hex_to_cow(&ufrag.replace(":", "")))); + .with(Protocol::XWebRTC(hex_to_cow(&ufrag.replace(':', "")))); if let Err(err) = new_addr_tx.try_send(a) { log::error!("Failed to send new address {}: {}", &addr, err); } else { let mut new_addrs = loop_self.new_addrs.write(); - new_addrs.insert(addr.clone()); + new_addrs.insert(addr); }; } Err(e) => { @@ -291,7 +249,7 @@ impl UDPMux for UDPMuxNewAddr { }; // NOTE: We don't wait for these closure to complete - for (_, conn) in old_conns.into_iter() { + for (_, conn) in old_conns { conn.close(); } @@ -364,6 +322,46 @@ impl UDPMux for UDPMuxNewAddr { } } +#[async_trait] +impl UDPMuxWriter for UDPMuxNewAddr { + async fn register_conn_for_address(&self, conn: &UDPMuxConn, addr: SocketAddr) { + if self.is_closed().await { + return; + } + + let key = conn.key(); + { + let mut addresses = self.address_map.write(); + + addresses + .entry(addr) + .and_modify(|e| { + if e.key() != key { + e.remove_address(&addr); + *e = conn.clone(); + } + }) + .or_insert_with(|| conn.clone()); + } + + // remove addr from new_addrs once conn is established + { + let mut new_addrs = self.new_addrs.write(); + new_addrs.remove(&addr); + } + + log::debug!("Registered {} for {}", addr, key); + } + + async fn send_to(&self, buf: &[u8], target: &SocketAddr) -> Result { + self.params + .conn + .send_to(buf, *target) + .await + .map_err(Into::into) + } +} + fn hex_to_cow<'a>(s: &str) -> Cow<'a, [u8; 32]> { let mut buf = [0; 32]; hex::decode_to_slice(s, &mut buf).unwrap(); @@ -379,37 +377,36 @@ fn ufrag_from_stun_message(buffer: &[u8], local_ufrag: bool) -> Result Err(Error::Other(format!( + if let Err(err) = result { + Err(Error::Other(format!( "failed to handle decode ICE: {}", err - ))), - Ok(_) => { - let (attr, found) = message.attributes.get(ATTR_USERNAME); - if !found { - return Err(Error::Other("no username attribute in STUN message".into())); - } + ))) + } else { + let (attr, found) = message.attributes.get(ATTR_USERNAME); + if !found { + return Err(Error::Other("no username attribute in STUN message".into())); + } - match String::from_utf8(attr.value) { - // Per the RFC this shouldn't happen - // https://datatracker.ietf.org/doc/html/rfc5389#section-15.3 - Err(err) => Err(Error::Other(format!( - "failed to decode USERNAME from STUN message as UTF-8: {}", - err - ))), - Ok(s) => { - // s is a combination of the local_ufrag and the remote ufrag separated by `:`. - let res = if local_ufrag { - s.split(":").next() - } else { - s.split(":").last() - }; - match res { - Some(s) => Ok(s.to_owned()), - None => Err(Error::Other("can't get ufrag from username".into())), - } - }, - } - }, + match String::from_utf8(attr.value) { + // Per the RFC this shouldn't happen + // https://datatracker.ietf.org/doc/html/rfc5389#section-15.3 + Err(err) => Err(Error::Other(format!( + "failed to decode USERNAME from STUN message as UTF-8: {}", + err + ))), + Ok(s) => { + // s is a combination of the local_ufrag and the remote ufrag separated by `:`. + let res = if local_ufrag { + s.split(':').next() + } else { + s.split(':').last() + }; + match res { + Some(s) => Ok(s.to_owned()), + None => Err(Error::Other("can't get ufrag from username".into())), + } + }, + } } } diff --git a/src/udp_mux/socket_addr_ext.rs b/src/udp_mux/socket_addr_ext.rs deleted file mode 100644 index d1d06ce..0000000 --- a/src/udp_mux/socket_addr_ext.rs +++ /dev/null @@ -1,267 +0,0 @@ -// MIT License -// -// Copyright (c) 2021 WebRTC.rs -// -// Permission is hereby granted, free of charge, to any person obtaining a copy -// of this software and associated documentation files (the "Software"), to deal -// in the Software without restriction, including without limitation the rights -// to use, copy, modify, merge, publish, distribute, sublicense, and/or sell -// copies of the Software, and to permit persons to whom the Software is -// furnished to do so, subject to the following conditions: -// -// The above copyright notice and this permission notice shall be included in all -// copies or substantial portions of the Software. -// -// THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR -// IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, -// FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE -// AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER -// LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, -// OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE -// SOFTWARE. - -use std::array::TryFromSliceError; -use std::convert::TryInto; -use std::net::SocketAddr; - -use webrtc_util::Error; - -pub(super) trait SocketAddrExt { - ///Encode a representation of `self` into the buffer and return the length of this encoded - ///version. - /// - /// The buffer needs to be at least 27 bytes in length. - fn encode(&self, buffer: &mut [u8]) -> Result; - - /// Decode a `SocketAddr` from a buffer. The encoding should have previously been done with - /// [`SocketAddrExt::encode`]. - fn decode(buffer: &[u8]) -> Result; -} - -const IPV4_MARKER: u8 = 4; -const IPV4_ADDRESS_SIZE: usize = 7; -const IPV6_MARKER: u8 = 6; -const IPV6_ADDRESS_SIZE: usize = 27; - -pub(super) const MAX_ADDR_SIZE: usize = IPV6_ADDRESS_SIZE; - -impl SocketAddrExt for SocketAddr { - fn encode(&self, buffer: &mut [u8]) -> Result { - use std::net::SocketAddr::*; - - if buffer.len() < MAX_ADDR_SIZE { - return Err(Error::ErrBufferShort); - } - - match self { - V4(addr) => { - let marker = IPV4_MARKER; - let ip: [u8; 4] = addr.ip().octets(); - let port: u16 = addr.port(); - - buffer[0] = marker; - buffer[1..5].copy_from_slice(&ip); - buffer[5..7].copy_from_slice(&port.to_le_bytes()); - - Ok(7) - }, - V6(addr) => { - let marker = IPV6_MARKER; - let ip: [u8; 16] = addr.ip().octets(); - let port: u16 = addr.port(); - let flowinfo = addr.flowinfo(); - let scope_id = addr.scope_id(); - - buffer[0] = marker; - buffer[1..17].copy_from_slice(&ip); - buffer[17..19].copy_from_slice(&port.to_le_bytes()); - buffer[19..23].copy_from_slice(&flowinfo.to_le_bytes()); - buffer[23..27].copy_from_slice(&scope_id.to_le_bytes()); - - Ok(MAX_ADDR_SIZE) - }, - } - } - - fn decode(buffer: &[u8]) -> Result { - use std::net::*; - - match buffer[0] { - IPV4_MARKER => { - if buffer.len() < IPV4_ADDRESS_SIZE { - return Err(Error::ErrBufferShort); - } - - let ip_parts = &buffer[1..5]; - let port = match &buffer[5..7].try_into() { - Err(_) => return Err(Error::ErrFailedToParseIpaddr), - Ok(input) => u16::from_le_bytes(*input), - }; - - let ip = Ipv4Addr::new(ip_parts[0], ip_parts[1], ip_parts[2], ip_parts[3]); - - Ok(SocketAddr::V4(SocketAddrV4::new(ip, port))) - }, - IPV6_MARKER => { - if buffer.len() < IPV6_ADDRESS_SIZE { - return Err(Error::ErrBufferShort); - } - - // Just to help the type system infer correctly - fn helper(b: &[u8]) -> Result<&[u8; 16], TryFromSliceError> { - b.try_into() - } - - let ip = match helper(&buffer[1..17]) { - Err(_) => return Err(Error::ErrFailedToParseIpaddr), - Ok(input) => Ipv6Addr::from(*input), - }; - let port = match &buffer[17..19].try_into() { - Err(_) => return Err(Error::ErrFailedToParseIpaddr), - Ok(input) => u16::from_le_bytes(*input), - }; - - let flowinfo = match &buffer[19..23].try_into() { - Err(_) => return Err(Error::ErrFailedToParseIpaddr), - Ok(input) => u32::from_le_bytes(*input), - }; - - let scope_id = match &buffer[23..27].try_into() { - Err(_) => return Err(Error::ErrFailedToParseIpaddr), - Ok(input) => u32::from_le_bytes(*input), - }; - - Ok(SocketAddr::V6(SocketAddrV6::new( - ip, port, flowinfo, scope_id, - ))) - }, - _ => Err(Error::ErrFailedToParseIpaddr), - } - } -} - -#[cfg(test)] -mod test { - use super::*; - use std::net::*; - - #[test] - fn test_ipv4() { - let ip = SocketAddr::V4(SocketAddrV4::new(Ipv4Addr::from([56, 128, 35, 5]), 0x1234)); - - let mut buffer = [0_u8; MAX_ADDR_SIZE]; - let encoded_len = ip.encode(&mut buffer); - - assert_eq!(encoded_len, Ok(7)); - assert_eq!( - &buffer[0..7], - &[IPV4_MARKER, 56, 128, 35, 5, 0x34, 0x12][..] - ); - - let decoded = SocketAddr::decode(&buffer); - - assert_eq!(decoded, Ok(ip)); - } - - #[test] - fn test_ipv6() { - let ip = SocketAddr::V6(SocketAddrV6::new( - Ipv6Addr::from([ - 92, 114, 235, 3, 244, 64, 38, 111, 20, 100, 199, 241, 19, 174, 220, 123, - ]), - 0x1234, - 0x12345678, - 0x87654321, - )); - - let mut buffer = [0_u8; MAX_ADDR_SIZE]; - let encoded_len = ip.encode(&mut buffer); - - assert_eq!(encoded_len, Ok(27)); - assert_eq!( - &buffer[0..27], - &[ - IPV6_MARKER, // marker - // Start of ipv6 address - 92, - 114, - 235, - 3, - 244, - 64, - 38, - 111, - 20, - 100, - 199, - 241, - 19, - 174, - 220, - 123, - // LE port - 0x34, - 0x12, - // LE flowinfo - 0x78, - 0x56, - 0x34, - 0x12, - // LE scope_id - 0x21, - 0x43, - 0x65, - 0x87, - ][..] - ); - - let decoded = SocketAddr::decode(&buffer); - - assert_eq!(decoded, Ok(ip)); - } - - #[test] - fn test_encode_ipv4_with_short_buffer() { - let mut buffer = vec![0u8; IPV4_ADDRESS_SIZE - 1]; - let ip = SocketAddr::V4(SocketAddrV4::new(Ipv4Addr::from([56, 128, 35, 5]), 0x1234)); - - let result = ip.encode(&mut buffer); - - assert_eq!(result, Err(Error::ErrBufferShort)); - } - - #[test] - fn test_encode_ipv6_with_short_buffer() { - let mut buffer = vec![0u8; MAX_ADDR_SIZE - 1]; - let ip = SocketAddr::V6(SocketAddrV6::new( - Ipv6Addr::from([ - 92, 114, 235, 3, 244, 64, 38, 111, 20, 100, 199, 241, 19, 174, 220, 123, - ]), - 0x1234, - 0x12345678, - 0x87654321, - )); - - let result = ip.encode(&mut buffer); - - assert_eq!(result, Err(Error::ErrBufferShort)); - } - - #[test] - fn test_decode_ipv4_with_short_buffer() { - let buffer = vec![IPV4_MARKER, 0]; - - let result = SocketAddr::decode(&buffer); - - assert_eq!(result, Err(Error::ErrBufferShort)); - } - - #[test] - fn test_decode_ipv6_with_short_buffer() { - let buffer = vec![IPV6_MARKER, 0]; - - let result = SocketAddr::decode(&buffer); - - assert_eq!(result, Err(Error::ErrBufferShort)); - } -} diff --git a/src/udp_mux/udp_mux_conn.rs b/src/udp_mux/udp_mux_conn.rs deleted file mode 100644 index 89d070b..0000000 --- a/src/udp_mux/udp_mux_conn.rs +++ /dev/null @@ -1,308 +0,0 @@ -// MIT License -// -// Copyright (c) 2021 WebRTC.rs -// -// Permission is hereby granted, free of charge, to any person obtaining a copy -// of this software and associated documentation files (the "Software"), to deal -// in the Software without restriction, including without limitation the rights -// to use, copy, modify, merge, publish, distribute, sublicense, and/or sell -// copies of the Software, and to permit persons to whom the Software is -// furnished to do so, subject to the following conditions: -// -// The above copyright notice and this permission notice shall be included in all -// copies or substantial portions of the Software. -// -// THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR -// IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, -// FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE -// AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER -// LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, -// OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE -// SOFTWARE. - -use std::convert::TryInto; -use std::{collections::HashSet, io, net::SocketAddr, sync::Arc}; - -use async_trait::async_trait; -use tokio_crate as tokio; -use tokio_crate::sync::watch; - -use webrtc_util::{sync::Mutex, Buffer, Conn, Error}; - -use super::socket_addr_ext::{SocketAddrExt, MAX_ADDR_SIZE}; -use super::{normalize_socket_addr, UDPMuxNewAddr, RECEIVE_MTU}; - -#[inline(always)] -/// Create a buffer of appropriate size to fit both a packet with max RECEIVE_MTU and the -/// additional metadata used for muxing. -fn make_buffer() -> Vec { - // The 4 extra bytes are used to encode the length of the data and address respectively. - // See [`write_packet`] for details. - vec![0u8; RECEIVE_MTU + MAX_ADDR_SIZE + 2 + 2] -} - -pub(crate) struct UDPMuxConnParams { - pub(super) local_addr: SocketAddr, - - pub(super) key: String, - - // NOTE: This Arc exists in both directions which is liable to cause a retain cycle. This is - // accounted for in [`UDPMuxNewAddr::close`], which makes sure to drop all Arcs referencing any - // `UDPMuxConn`. - pub(super) udp_mux: Arc, -} - -struct UDPMuxConnInner { - pub(super) params: UDPMuxConnParams, - - /// Close Sender. We'll send a value on this channel when we close - closed_watch_tx: Mutex>>, - - /// Remote addresses we've seen on this connection. - addresses: Mutex>, - - buffer: Buffer, -} - -impl UDPMuxConnInner { - // Sending/Recieving - async fn recv_from(&self, buf: &mut [u8]) -> ConnResult<(usize, SocketAddr)> { - // NOTE: Pion/ice uses Sync.Pool to optimise this. - let mut buffer = make_buffer(); - let mut offset = 0; - - let len = self.buffer.read(&mut buffer, None).await?; - // We always have at least. - // - // * 2 bytes for data len - // * 2 bytes for addr len - // * 7 bytes for an Ipv4 addr - if len < 11 { - return Err(Error::ErrBufferShort); - } - - let data_len: usize = buffer[..2] - .try_into() - .map(u16::from_le_bytes) - .map(From::from) - .unwrap(); - offset += 2; - - let total = 2 + data_len + 2 + 7; - if data_len > buf.len() || total > len { - return Err(Error::ErrBufferShort); - } - - buf[..data_len].copy_from_slice(&buffer[offset..offset + data_len]); - offset += data_len; - - let address_len: usize = buffer[offset..offset + 2] - .try_into() - .map(u16::from_le_bytes) - .map(From::from) - .unwrap(); - offset += 2; - - let addr = SocketAddr::decode(&buffer[offset..offset + address_len])?; - - Ok((data_len, addr)) - } - - async fn send_to(&self, buf: &[u8], target: &SocketAddr) -> ConnResult { - self.params.udp_mux.send_to(buf, target).await - } - - fn is_closed(&self) -> bool { - self.closed_watch_tx.lock().is_none() - } - - fn close(self: &Arc) { - let mut closed_tx = self.closed_watch_tx.lock(); - - if let Some(tx) = closed_tx.take() { - let _ = tx.send(true); - drop(closed_tx); - - let cloned_self = Arc::clone(self); - - { - let mut addresses = self.addresses.lock(); - *addresses = Default::default(); - } - - // NOTE: Alternatively we could wait on the buffer closing here so that - // our caller can wait for things to fully settle down - tokio::spawn(async move { - cloned_self.buffer.close().await; - }); - } - } - - fn local_addr(&self) -> SocketAddr { - self.params.local_addr - } - - // Address related methods - pub(super) fn get_addresses(&self) -> Vec { - let addresses = self.addresses.lock(); - - addresses.iter().cloned().collect() - } - - pub(super) fn add_address(self: &Arc, addr: SocketAddr) { - { - let mut addresses = self.addresses.lock(); - addresses.insert(addr); - } - } - - pub(super) fn remove_address(&self, addr: &SocketAddr) { - { - let mut addresses = self.addresses.lock(); - addresses.remove(addr); - } - } - - pub(super) fn contains_address(&self, addr: &SocketAddr) -> bool { - let addresses = self.addresses.lock(); - - addresses.contains(addr) - } -} - -#[derive(Clone)] -pub(crate) struct UDPMuxConn { - /// Close Receiver. A copy of this can be obtained via [`close_tx`]. - closed_watch_rx: watch::Receiver, - - inner: Arc, -} - -impl UDPMuxConn { - pub(crate) fn new(params: UDPMuxConnParams) -> Self { - let (closed_watch_tx, closed_watch_rx) = watch::channel(false); - - Self { - closed_watch_rx, - inner: Arc::new(UDPMuxConnInner { - params, - closed_watch_tx: Mutex::new(Some(closed_watch_tx)), - addresses: Default::default(), - buffer: Buffer::new(0, 0), - }), - } - } - - pub(crate) fn key(&self) -> &str { - &self.inner.params.key - } - - pub(crate) async fn write_packet(&self, data: &[u8], addr: SocketAddr) -> ConnResult<()> { - // NOTE: Pion/ice uses Sync.Pool to optimise this. - let mut buffer = make_buffer(); - let mut offset = 0; - - if (data.len() + MAX_ADDR_SIZE) > (RECEIVE_MTU + MAX_ADDR_SIZE) { - return Err(Error::ErrBufferShort); - } - - // Format of buffer: | data len(2) | data bytes(dn) | addr len(2) | addr bytes(an) | - // Where the number in parenthesis indicate the number of bytes used - // `dn` and `an` are the length in bytes of data and addr respectively. - - // SAFETY: `data.len()` is at most RECEIVE_MTU(8192) - MAX_ADDR_SIZE(27) - buffer[0..2].copy_from_slice(&(data.len() as u16).to_le_bytes()[..]); - offset += 2; - - buffer[offset..offset + data.len()].copy_from_slice(data); - offset += data.len(); - - let len = addr.encode(&mut buffer[offset + 2..])?; - buffer[offset..offset + 2].copy_from_slice(&(len as u16).to_le_bytes()[..]); - offset += 2 + len; - - self.inner.buffer.write(&buffer[..offset]).await?; - - Ok(()) - } - - pub(crate) fn is_closed(&self) -> bool { - self.inner.is_closed() - } - - /// Get a copy of the close [`tokio::sync::watch::Receiver`] that fires when this - /// connection is closed. - pub(crate) fn close_rx(&self) -> watch::Receiver { - self.closed_watch_rx.clone() - } - - /// Close this connection - pub(crate) fn close(&self) { - self.inner.close(); - } - - pub(super) fn get_addresses(&self) -> Vec { - self.inner.get_addresses() - } - - pub(super) async fn add_address(&self, addr: SocketAddr) { - self.inner.add_address(addr); - self.inner - .params - .udp_mux - .register_conn_for_address(self, addr) - .await; - } - - pub(super) fn remove_address(&self, addr: &SocketAddr) { - self.inner.remove_address(addr) - } - - pub(super) fn contains_address(&self, addr: &SocketAddr) -> bool { - self.inner.contains_address(addr) - } -} - -type ConnResult = Result; - -#[async_trait] -impl Conn for UDPMuxConn { - async fn connect(&self, _addr: SocketAddr) -> ConnResult<()> { - Err(io::Error::new(io::ErrorKind::Other, "Not applicable").into()) - } - - async fn recv(&self, _buf: &mut [u8]) -> ConnResult { - Err(io::Error::new(io::ErrorKind::Other, "Not applicable").into()) - } - - async fn recv_from(&self, buf: &mut [u8]) -> ConnResult<(usize, SocketAddr)> { - self.inner.recv_from(buf).await - } - - async fn send(&self, _buf: &[u8]) -> ConnResult { - Err(io::Error::new(io::ErrorKind::Other, "Not applicable").into()) - } - - async fn send_to(&self, buf: &[u8], target: SocketAddr) -> ConnResult { - let normalized_target = normalize_socket_addr(&target, &self.inner.params.local_addr); - - if !self.contains_address(&normalized_target) { - self.add_address(normalized_target).await; - } - - self.inner.send_to(buf, &normalized_target).await - } - - async fn local_addr(&self) -> ConnResult { - Ok(self.inner.local_addr()) - } - - async fn remote_addr(&self) -> Option { - None - } - async fn close(&self) -> ConnResult<()> { - self.inner.close(); - - Ok(()) - } -} diff --git a/src/upgrade.rs b/src/upgrade.rs index f801ee5..baf1a02 100644 --- a/src/upgrade.rs +++ b/src/upgrade.rs @@ -18,12 +18,17 @@ // FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER // DEALINGS IN THE SOFTWARE. -use futures::channel::oneshot; -use futures::select; -use futures::FutureExt; +use futures::io::{AsyncReadExt, AsyncWriteExt}; +use futures::{channel::oneshot, future, select, FutureExt, TryFutureExt}; use futures_timer::Delay; -use libp2p_core::multiaddr::{Multiaddr, Protocol}; -use log::{debug, error, trace}; +use libp2p_core::identity; +use libp2p_core::{ + multiaddr::{Multiaddr, Protocol}, + PeerId, +}; +use libp2p_core::{InboundUpgrade, UpgradeInfo}; +use libp2p_noise::{Keypair, NoiseConfig, NoiseError, RemoteIdentity, X25519Spec}; +use log::{debug, trace}; use webrtc::api::APIBuilder; use webrtc::data_channel::data_channel_init::RTCDataChannelInit; use webrtc::dtls_transport::dtls_role::DTLSRole; @@ -36,6 +41,7 @@ use std::sync::Arc; use std::time::Duration; use crate::connection::Connection; +use crate::connection::PollDataChannel; use crate::error::Error; use crate::sdp; use crate::transport; @@ -44,14 +50,15 @@ pub async fn webrtc( udp_mux: Arc, config: RTCConfiguration, addr: Multiaddr, -) -> Result, Error> { + id_keys: identity::Keypair, +) -> Result<(PeerId, Connection), Error> { trace!("upgrading {}", addr); let socket_addr = transport::multiaddr_to_socketaddr(&addr) .ok_or_else(|| Error::InvalidMultiaddr(addr.clone()))?; - let fingerprint = transport::fingerprint_of_first_certificate(&config); + let our_fingerprint = transport::fingerprint_of_first_certificate(&config); - let mut se = transport::build_setting_engine(udp_mux, &socket_addr, &fingerprint); + let mut se = transport::build_setting_engine(udp_mux, &socket_addr, &our_fingerprint); { // Act as a lite ICE (ICE which does not send additional candidates). se.set_lite(true); @@ -65,63 +72,17 @@ pub async fn webrtc( let api = APIBuilder::new().with_setting_engine(se).build(); let peer_connection = api.new_peer_connection(config).await?; - // Create a datachannel with label 'data'. - let data_channel = peer_connection - .create_data_channel( - "data", - Some(RTCDataChannelInit { - negotiated: Some(true), - id: Some(1), - ordered: None, - max_retransmits: None, - max_packet_life_time: None, - protocol: None, - }), - ) - .await?; - - let (data_channel_rx, mut data_channel_tx) = oneshot::channel::>(); - - // Wait until the data channel is opened and detach it. - data_channel - .on_open({ - let data_channel = data_channel.clone(); - Box::new(move || { - debug!( - "Data channel '{}'-'{}' open.", - data_channel.label(), - data_channel.id() - ); - - Box::pin(async move { - let data_channel = data_channel.clone(); - match data_channel.detach().await { - Ok(detached) => { - if let Err(_) = data_channel_rx.send(detached) { - error!("data_channel_tx dropped"); - } - }, - Err(e) => { - error!("Can't detach data channel: {}", e); - }, - }; - }) - }) - }) - .await; - // Set the remote description to the predefined SDP. - let fingerprint = match addr.iter().last() { - Some(Protocol::XWebRTC(f)) => f, - _ => { - debug!("{} is not a WebRTC multiaddr", addr); - return Err(Error::InvalidMultiaddr(addr)); - }, + let remote_fingerprint = if let Some(Protocol::XWebRTC(f)) = addr.iter().last() { + transport::fingerprint_to_string(&f) + } else { + debug!("{} is not a WebRTC multiaddr", addr); + return Err(Error::InvalidMultiaddr(addr)); }; let client_session_description = transport::render_description( sdp::CLIENT_SESSION_DESCRIPTION, socket_addr, - &transport::fingerprint_to_string(&fingerprint), + &remote_fingerprint, ); debug!("OFFER: {:?}", client_session_description); let sdp = RTCSessionDescription::offer(client_session_description).unwrap(); @@ -133,14 +94,71 @@ pub async fn webrtc( debug!("ANSWER: {:?}", answer.sdp); peer_connection.set_local_description(answer).await?; + // Create a datachannel with label 'data'. + let data_channel = peer_connection + .create_data_channel( + "data", + Some(RTCDataChannelInit { + negotiated: Some(true), + id: Some(1), + ..RTCDataChannelInit::default() + }), + ) + .await?; + + let (tx, mut rx) = oneshot::channel::>(); + + // Wait until the data channel is opened and detach it. + // Wait until the data channel is opened and detach it. + crate::connection::register_data_channel_open_handler(data_channel, tx).await; + // Wait until data channel is opened and ready to use - select! { - res = data_channel_tx => match res { - Ok(dc) => Ok(Connection::new(peer_connection, dc)), - Err(e) => Err(Error::InternalError(e.to_string())), + let detached = select! { + res = rx => match res { + Ok(detached) => detached, + Err(e) => return Err(Error::InternalError(e.to_string())), }, - _ = Delay::new(Duration::from_secs(10)).fuse() => Err(Error::InternalError( + _ = Delay::new(Duration::from_secs(10)).fuse() => return Err(Error::InternalError( "data channel opening took longer than 10 seconds (see logs)".into(), )) + }; + + trace!("noise handshake with {}", addr); + let dh_keys = Keypair::::new() + .into_authentic(&id_keys) + .unwrap(); + let noise = NoiseConfig::xx(dh_keys); + let info = noise.protocol_info().next().unwrap(); + let (peer_id, mut noise_io) = noise + .upgrade_inbound(PollDataChannel::new(detached.clone()), info) + .and_then(|(remote, io)| match remote { + RemoteIdentity::IdentityKey(pk) => future::ok((pk.to_peer_id(), io)), + _ => future::err(NoiseError::AuthenticationFailed), + }) + .await + .map_err(Error::Noise)?; + + // Exchange TLS certificate fingerprints to prevent MiM attacks. + trace!("exchanging TLS certificate fingerprints with {}", addr); + let n = noise_io.write(&our_fingerprint.into_bytes()).await?; + noise_io.flush().await?; + let mut buf = vec![0; n]; // ASSERT: fingerprint's format is the same. + noise_io.read_exact(buf.as_mut_slice()).await?; + let fingerprint_from_noise = + String::from_utf8(buf).map_err(|_| Error::Noise(NoiseError::AuthenticationFailed))?; + if fingerprint_from_noise != remote_fingerprint { + return Err(Error::InvalidFingerprint { + expected: remote_fingerprint, + got: fingerprint_from_noise, + }); } + + // Close the initial data channel after noise handshake is done. + // https://github.com/webrtc-rs/sctp/pull/14 + // detached + // .close() + // .await + // .map_err(|e| Error::WebRTC(e.into()))?; + + Ok((peer_id, Connection::new(peer_connection).await)) } diff --git a/tests/smoke.rs b/tests/smoke.rs new file mode 100644 index 0000000..00be9e7 --- /dev/null +++ b/tests/smoke.rs @@ -0,0 +1,305 @@ +use anyhow::Result; +use async_trait::async_trait; +use futures::future::FutureExt; +use futures::io::{AsyncRead, AsyncWrite, AsyncWriteExt}; +use futures::stream::StreamExt; +use libp2p_core::identity; +use libp2p_core::multiaddr::Protocol; +use libp2p_core::upgrade; +use libp2p_request_response::{ + ProtocolName, ProtocolSupport, RequestResponse, RequestResponseCodec, RequestResponseConfig, + RequestResponseEvent, RequestResponseMessage, +}; +use libp2p_swarm::{Swarm, SwarmBuilder, SwarmEvent}; +use libp2p_webrtc_direct::transport::WebRTCDirectTransport; +use log::trace; +use rand::RngCore; +use rcgen::KeyPair; +use tokio_crate as tokio; +use webrtc::peer_connection::certificate::RTCCertificate; + +use std::borrow::Cow; +use std::{io, iter}; + +fn generate_certificate() -> RTCCertificate { + let kp = KeyPair::generate(&rcgen::PKCS_ECDSA_P256_SHA256).expect("key pair"); + RTCCertificate::from_key_pair(kp).expect("certificate") +} + +fn generate_tls_keypair() -> identity::Keypair { + identity::Keypair::generate_ed25519() +} + +async fn create_swarm() -> Result<(Swarm>, String)> { + let cert = generate_certificate(); + let keypair = generate_tls_keypair(); + let peer_id = keypair.public().to_peer_id(); + let transport = WebRTCDirectTransport::new(cert, keypair, "127.0.0.1:0").await?; + let fingerprint = transport.cert_fingerprint(); + let protocols = iter::once((PingProtocol(), ProtocolSupport::Full)); + let cfg = RequestResponseConfig::default(); + let behaviour = RequestResponse::new(PingCodec(), protocols, cfg); + trace!("{}", peer_id); + Ok(( + SwarmBuilder::new(transport.boxed(), behaviour, peer_id) + .executor(Box::new(|fut| { + tokio::spawn(fut); + })) + .build(), + fingerprint, + )) +} + +#[tokio::test] +async fn smoke() -> Result<()> { + let _ = env_logger::builder().is_test(true).try_init(); + + let mut rng = rand::thread_rng(); + + let (mut a, a_fingerprint) = create_swarm().await?; + let (mut b, _b_fingerprint) = create_swarm().await?; + + Swarm::listen_on(&mut a, "/ip4/127.0.0.1/udp/0".parse()?)?; + + let addr = match a.next().await { + Some(SwarmEvent::NewListenAddr { address, .. }) => address, + e => panic!("{:?}", e), + }; + let addr = addr.with(Protocol::XWebRTC(hex_to_cow( + &a_fingerprint.replace(":", ""), + ))); + + let mut data = vec![0; 4096]; + rng.fill_bytes(&mut data); + + b.behaviour_mut() + .add_address(&Swarm::local_peer_id(&a), addr); + b.behaviour_mut() + .send_request(&Swarm::local_peer_id(&a), Ping(data.clone())); + + match b.next().await { + Some(SwarmEvent::Dialing(_)) => {}, + e => panic!("{:?}", e), + } + + match a.next().await { + Some(SwarmEvent::IncomingConnection { .. }) => {}, + e => panic!("{:?}", e), + }; + + match b.next().await { + Some(SwarmEvent::ConnectionEstablished { .. }) => {}, + e => panic!("{:?}", e), + }; + + match a.next().await { + Some(SwarmEvent::ConnectionEstablished { .. }) => {}, + e => panic!("{:?}", e), + }; + + assert!(b.next().now_or_never().is_none()); + + match a.next().await { + Some(SwarmEvent::Behaviour(RequestResponseEvent::Message { + message: + RequestResponseMessage::Request { + request: Ping(ping), + channel, + .. + }, + .. + })) => { + a.behaviour_mut() + .send_response(channel, Pong(ping)) + .unwrap(); + }, + e => panic!("{:?}", e), + } + + match a.next().await { + Some(SwarmEvent::Behaviour(RequestResponseEvent::ResponseSent { .. })) => {}, + e => panic!("{:?}", e), + } + + match b.next().await { + Some(SwarmEvent::Behaviour(RequestResponseEvent::Message { + message: + RequestResponseMessage::Response { + response: Pong(pong), + .. + }, + .. + })) => assert_eq!(data, pong), + e => panic!("{:?}", e), + } + + a.behaviour_mut().send_request( + &Swarm::local_peer_id(&b), + Ping(b"another substream".to_vec()), + ); + + assert!(a.next().now_or_never().is_none()); + + match b.next().await { + Some(SwarmEvent::Behaviour(RequestResponseEvent::Message { + message: + RequestResponseMessage::Request { + request: Ping(data), + channel, + .. + }, + .. + })) => { + b.behaviour_mut() + .send_response(channel, Pong(data)) + .unwrap(); + }, + e => panic!("{:?}", e), + } + + match b.next().await { + Some(SwarmEvent::Behaviour(RequestResponseEvent::ResponseSent { .. })) => {}, + e => panic!("{:?}", e), + } + + match a.next().await { + Some(SwarmEvent::Behaviour(RequestResponseEvent::Message { + message: + RequestResponseMessage::Response { + response: Pong(data), + .. + }, + .. + })) => assert_eq!(data, b"another substream".to_vec()), + e => panic!("{:?}", e), + } + + Ok(()) +} + +#[derive(Debug, Clone)] +struct PingProtocol(); + +#[derive(Clone)] +struct PingCodec(); + +#[derive(Debug, Clone, PartialEq, Eq)] +struct Ping(Vec); + +#[derive(Debug, Clone, PartialEq, Eq)] +struct Pong(Vec); + +impl ProtocolName for PingProtocol { + fn protocol_name(&self) -> &[u8] { + "/ping/1".as_bytes() + } +} + +#[async_trait] +impl RequestResponseCodec for PingCodec { + type Protocol = PingProtocol; + type Request = Ping; + type Response = Pong; + + async fn read_request(&mut self, _: &PingProtocol, io: &mut T) -> io::Result + where + T: AsyncRead + Unpin + Send, + { + upgrade::read_length_prefixed(io, 4096) + .map(|res| match res { + Err(e) => Err(io::Error::new(io::ErrorKind::InvalidData, e)), + Ok(vec) if vec.is_empty() => Err(io::ErrorKind::UnexpectedEof.into()), + Ok(vec) => Ok(Ping(vec)), + }) + .await + } + + async fn read_response(&mut self, _: &PingProtocol, io: &mut T) -> io::Result + where + T: AsyncRead + Unpin + Send, + { + upgrade::read_length_prefixed(io, 4096) + .map(|res| match res { + Err(e) => Err(io::Error::new(io::ErrorKind::InvalidData, e)), + Ok(vec) if vec.is_empty() => Err(io::ErrorKind::UnexpectedEof.into()), + Ok(vec) => Ok(Pong(vec)), + }) + .await + } + + async fn write_request( + &mut self, + _: &PingProtocol, + io: &mut T, + Ping(data): Ping, + ) -> io::Result<()> + where + T: AsyncWrite + Unpin + Send, + { + upgrade::write_length_prefixed(io, data).await?; + io.close().await?; + Ok(()) + } + + async fn write_response( + &mut self, + _: &PingProtocol, + io: &mut T, + Pong(data): Pong, + ) -> io::Result<()> + where + T: AsyncWrite + Unpin + Send, + { + upgrade::write_length_prefixed(io, data).await?; + io.close().await?; + Ok(()) + } +} + +#[tokio::test] +async fn dial_failure() -> Result<()> { + let _ = env_logger::builder().is_test(true).try_init(); + + let (mut a, a_fingerprint) = create_swarm().await?; + let (mut b, _b_fingerprint) = create_swarm().await?; + + Swarm::listen_on(&mut a, "/ip4/127.0.0.1/udp/0".parse()?)?; + + let addr = match a.next().await { + Some(SwarmEvent::NewListenAddr { address, .. }) => address, + e => panic!("{:?}", e), + }; + let addr = addr.with(Protocol::XWebRTC(hex_to_cow( + &a_fingerprint.replace(":", ""), + ))); + + let a_peer_id = &Swarm::local_peer_id(&a).clone(); + drop(a); // stop a swarm so b can never reach it + + b.behaviour_mut().add_address(a_peer_id, addr); + b.behaviour_mut() + .send_request(a_peer_id, Ping(b"hello world".to_vec())); + + match b.next().await { + Some(SwarmEvent::Dialing(_)) => {}, + e => panic!("{:?}", e), + } + + match b.next().await { + Some(SwarmEvent::OutgoingConnectionError { .. }) => {}, + e => panic!("{:?}", e), + }; + + match b.next().await { + Some(SwarmEvent::Behaviour(RequestResponseEvent::OutboundFailure { .. })) => {}, + e => panic!("{:?}", e), + }; + + Ok(()) +} + +fn hex_to_cow<'a>(s: &str) -> Cow<'a, [u8; 32]> { + let mut buf = [0; 32]; + hex::decode_to_slice(s, &mut buf).unwrap(); + Cow::Owned(buf) +}