Skip to content
This repository has been archived by the owner on Jun 21, 2022. It is now read-only.

Commit

Permalink
Implement StreamMuxer trait for Connection (#2)
Browse files Browse the repository at this point in the history
* start impl StreamMuxer trait

* implement poll_outbound fn

* implement poll fns

* impl close and flush_all fns

* wrap fields in arc and clone it in open_outbound

Fixes:

```
error[E0759]: `self` has an anonymous lifetime `'_` but it needs to
satisfy a `'static` lifetime requirement
```

* implement poll_event fn

* add some comments

* close `incoming_data_channels_rx` in close

to prevent adding new data channels

* add comment to destroy_outbound

* drop locks early

* reduce the number of locks

* bring back noise and start smoke tests

* reserve 1 for the initial substream

* smoke tests: add webrtc fingerprint to dialing addr

* use tokio executor in swarm

* do not assume fingerprint to be last

actually it's supposed to be 3rd (ip, port, fingerprint, peer_id)
but let's not assume anything

* increase read buf size to 8192

and reduce one in smoke tests

* verify peer_id and close initial stream (used for noise)

* allow changing read buf cap in Connection

* use Default::default when constructing RTCDataChannelInit

* remove unnecessary field from RTCDataChannelInit

* remove next_outbound_channel_id

It turned out we actually don't need it. webrtc-rs lib will increase
channel ID for us (on both ends on the connection).

Also, extract `register_data_channel_open_handler` fn to reduce the
amount of duplicated code.

* fix some clippy warnings

* exchange TLS fingerprint certificates inside noise

Fixes
#2 (comment)

* remove udp_mux files and update PollDataChannel

based on new ice, sctp and data crates

* include old code
melekes authored Jun 17, 2022

Verified

This commit was created on GitHub.com and signed with GitHub’s verified signature.
1 parent f043c3d commit 0d85a4a
Showing 11 changed files with 1,442 additions and 1,059 deletions.
183 changes: 172 additions & 11 deletions Cargo.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

27 changes: 19 additions & 8 deletions Cargo.toml
Original file line number Diff line number Diff line change
@@ -9,25 +9,36 @@ edition = "2021"
publish = false

[dependencies]
async-trait = "0.1.52"
bytes = "1"
env_logger = "0.9.0"
fnv = "1.0"
futures = "0.3.17"
futures-lite = "1.12.0"
futures-timer = "3.0"
hex = "0.4"
if-watch = "0.2.2"
libp2p-core = { version = "0.32.0", default-features = false, git = "https://github.com/melekes/rust-libp2p", branch = "anton/x-webrtc" }
libp2p-noise = { version = "0.35.0", git = "https://github.com/melekes/rust-libp2p", branch = "anton/x-webrtc" }
log = "0.4.14"
serde = { version = "1.0", features = ["derive"] }
stun = "0.4.2"
thiserror = "1"
tinytemplate = "1.2"
tokio-crate = { package = "tokio", version = "1.17.0", default-features = false, features = ["net"]}
tokio-crate = { package = "tokio", version = "1.18.2", features = ["net"]}
webrtc = { version = "0.4.0", git = "https://github.com/melekes/webrtc", branch = "anton/168-allow-persistent-certificates" }
webrtc-ice = "0.6.6"
webrtc-data = "0.3.3"
webrtc-sctp = "0.4.3"
if-watch = "0.2.2"
futures-timer = "3.0"
stun = "0.4.2"
webrtc-ice = "0.7.0"
webrtc-sctp = "0.5.0"
webrtc-util = { version = "0.5.3", default-features = false, features = ["conn", "vnet", "sync"] }
async-trait = "0.1.52"

[dev-dependencies]
env_logger = "0.9.0"
anyhow = "1.0.41"
rand = "0.8.4"
rand_core = "0.5.1"
rcgen = "0.8.14"
libp2p-swarm = { version = "0.35.0", git = "https://github.com/melekes/rust-libp2p", branch = "anton/x-webrtc" }
libp2p-request-response = { version = "0.17.0", git = "https://github.com/melekes/rust-libp2p", branch = "anton/x-webrtc" }

[patch.crates-io]
webrtc-data = { version = "0.3.3", git = "https://github.com/melekes/webrtc-data.git", branch = "anton/async-read-write-for-data-channel" }
315 changes: 286 additions & 29 deletions src/connection.rs
Original file line number Diff line number Diff line change
@@ -20,59 +20,316 @@

mod poll_data_channel;

use futures::prelude::*;
use fnv::FnvHashMap;
use futures::channel::mpsc;
use futures::channel::oneshot::{self, Sender};
use futures::lock::Mutex as FutMutex;
use futures::{future::BoxFuture, prelude::*, ready};
use futures_lite::stream::StreamExt;
use libp2p_core::muxing::{StreamMuxer, StreamMuxerEvent};
use log::{debug, error, trace};
use webrtc::data_channel::RTCDataChannel;
use webrtc::peer_connection::RTCPeerConnection;
use webrtc_data::data_channel::DataChannel;
use webrtc_data::data_channel::DataChannel as DetachedDataChannel;

use std::io;
use std::pin::Pin;
use std::sync::Arc;
use std::sync::{Arc, Mutex as StdMutex};
use std::task::{Context, Poll};

use poll_data_channel::PollDataChannel;
pub(crate) use poll_data_channel::PollDataChannel;

/// A WebRTC connection over a single data channel. See lib documentation for
/// the reasoning as to why a single data channel is being used.
pub struct Connection<'a> {
/// A WebRTC connection, wrapping [`RTCPeerConnection`] and implementing [`StreamMuxer`] trait.
pub struct Connection {
connection_inner: Arc<FutMutex<ConnectionInner>>, // uses futures mutex because used in async code (see open_outbound)
data_channels_inner: StdMutex<DataChannelsInner>,
}

struct ConnectionInner {
/// `RTCPeerConnection` to the remote peer.
pub inner: RTCPeerConnection,
/// A data channel.
pub data_channel: PollDataChannel<'a>,
rtc_conn: RTCPeerConnection,
}

struct DataChannelsInner {
/// A map of data channels.
map: FnvHashMap<u16, PollDataChannel>,
/// Channel onto which incoming data channels are put.
incoming_data_channels_rx: mpsc::Receiver<Arc<DetachedDataChannel>>,
/// Temporary read buffer's capacity (equal for all data channels).
/// See [`PollDataChannel`] `read_buf_cap`.
read_buf_cap: Option<usize>,
}

impl Connection<'_> {
pub fn new(peer_conn: RTCPeerConnection, data_channel: Arc<DataChannel>) -> Self {
impl Connection {
/// Creates a new connection.
pub async fn new(rtc_conn: RTCPeerConnection) -> Self {
let (data_channel_tx, data_channel_rx) = mpsc::channel(10);

Connection::register_incoming_data_channels_handler(&rtc_conn, data_channel_tx).await;

Self {
inner: peer_conn,
data_channel: PollDataChannel::new(data_channel),
connection_inner: Arc::new(FutMutex::new(ConnectionInner { rtc_conn })),
data_channels_inner: StdMutex::new(DataChannelsInner {
map: FnvHashMap::default(),
incoming_data_channels_rx: data_channel_rx,
read_buf_cap: None,
}),
}
}

/// Set the capacity of a data channel's temporary read buffer (equal for all data channels; default: 8192).
pub fn set_data_channels_read_buf_capacity(&mut self, cap: usize) {
let mut data_channels_inner = self.data_channels_inner.lock().unwrap();
data_channels_inner.read_buf_cap = Some(cap);
}

/// Registers a handler for incoming data channels.
async fn register_incoming_data_channels_handler(
rtc_conn: &RTCPeerConnection,
tx: mpsc::Sender<Arc<DetachedDataChannel>>,
) {
rtc_conn
.on_data_channel(Box::new(move |data_channel: Arc<RTCDataChannel>| {
debug!(
"Incoming data channel '{}'-'{}'",
data_channel.label(),
data_channel.id()
);

let data_channel = data_channel.clone();
let mut tx = tx.clone();

Box::pin(async move {
data_channel
.on_open({
let data_channel = data_channel.clone();
Box::new(move || {
debug!(
"Data channel '{}'-'{}' open",
data_channel.label(),
data_channel.id()
);

Box::pin(async move {
let data_channel = data_channel.clone();
match data_channel.detach().await {
Ok(detached) => {
if let Err(e) = tx.try_send(detached) {
// This can happen if the client is not reading
// events (using `poll_event`) fast enough, which
// generally shouldn't be the case.
error!("Can't send data channel: {}", e);
}
},
Err(e) => {
error!("Can't detach data channel: {}", e);
},
};
})
})
})
.await;
})
}))
.await;
}
}

impl AsyncRead for Connection<'_> {
fn poll_read(
mut self: Pin<&mut Self>,
impl<'a> StreamMuxer for Connection {
type Substream = PollDataChannel;
type OutboundSubstream = BoxFuture<'static, Result<Arc<DetachedDataChannel>, Self::Error>>;
type Error = io::Error;

fn poll_event(
&self,
cx: &mut Context<'_>,
) -> Poll<Result<StreamMuxerEvent<Self::Substream>, Self::Error>> {
let mut data_channels_inner = self.data_channels_inner.lock().unwrap();
match ready!(data_channels_inner.incoming_data_channels_rx.poll_next(cx)) {
Some(detached) => {
trace!("Incoming substream {}", detached.stream_identifier());

let ch = PollDataChannel::new(detached);
// if let Some(cap) = data_channels_inner.read_buf_cap {
// ch.set_read_buf_capacity(cap);
// }

data_channels_inner
.map
.insert(ch.stream_identifier(), ch.clone());

Poll::Ready(Ok(StreamMuxerEvent::InboundSubstream(ch)))
},
None => Poll::Ready(Err(io::Error::new(
io::ErrorKind::Other,
"incoming_data_channels_rx is closed (no messages left)",
))),
}
}

fn open_outbound(&self) -> Self::OutboundSubstream {
let connection_inner = self.connection_inner.clone();

Box::pin(async move {
let connection_inner = connection_inner.lock().await;

// Create a datachannel with label 'data'
let data_channel = connection_inner
.rtc_conn
.create_data_channel("data", None)
.map_err(|e| io::Error::new(io::ErrorKind::Other, format!("webrtc error: {}", e)))
.await?;

trace!("Opening outbound substream {}", data_channel.id());

// No need to hold the lock during the DTLS handshake.
drop(connection_inner);

let (tx, rx) = oneshot::channel::<Arc<DetachedDataChannel>>();

// Wait until the data channel is opened and detach it.
register_data_channel_open_handler(data_channel, tx).await;

// Wait until data channel is opened and ready to use
match rx.await {
Ok(detached) => Ok(detached),
Err(e) => Err(io::Error::new(io::ErrorKind::Other, e.to_string())),
}
})
}

fn poll_outbound(
&self,
cx: &mut Context<'_>,
s: &mut Self::OutboundSubstream,
) -> Poll<Result<Self::Substream, Self::Error>> {
match ready!(s.as_mut().poll(cx)) {
Ok(detached) => {
let mut data_channels_inner = self.data_channels_inner.lock().unwrap();

let ch = PollDataChannel::new(detached);
// if let Some(cap) = data_channels_inner.read_buf_cap {
// ch.set_read_buf_capacity(cap);
// }

data_channels_inner
.map
.insert(ch.stream_identifier(), ch.clone());

Poll::Ready(Ok(ch))
},
Err(e) => Poll::Ready(Err(e)),
}
}

/// NOTE: `_s` might be waiting at one of the await points, and dropping the future will
/// abruptly interrupt the execution.
fn destroy_outbound(&self, _s: Self::OutboundSubstream) {}

fn read_substream(
&self,
cx: &mut Context<'_>,
s: &mut Self::Substream,
buf: &mut [u8],
) -> Poll<io::Result<usize>> {
Pin::new(&mut self.data_channel).poll_read(cx, buf)
) -> Poll<Result<usize, Self::Error>> {
Pin::new(s).poll_read(cx, buf)
}
}

impl AsyncWrite for Connection<'_> {
fn poll_write(
mut self: Pin<&mut Self>,
fn write_substream(
&self,
cx: &mut Context<'_>,
s: &mut Self::Substream,
buf: &[u8],
) -> Poll<io::Result<usize>> {
Pin::new(&mut self.data_channel).poll_write(cx, buf)
) -> Poll<Result<usize, Self::Error>> {
Pin::new(s).poll_write(cx, buf)
}

fn poll_flush(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
Pin::new(&mut self.data_channel).poll_flush(cx)
fn flush_substream(
&self,
cx: &mut Context<'_>,
s: &mut Self::Substream,
) -> Poll<Result<(), Self::Error>> {
Pin::new(s).poll_flush(cx)
}

fn poll_close(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
Pin::new(&mut self.data_channel).poll_close(cx)
fn shutdown_substream(
&self,
cx: &mut Context<'_>,
s: &mut Self::Substream,
) -> Poll<Result<(), Self::Error>> {
trace!("Closing substream {}", s.stream_identifier());
Pin::new(s).poll_close(cx)
}

fn destroy_substream(&self, s: Self::Substream) {
let mut data_channels_inner = self.data_channels_inner.lock().unwrap();
data_channels_inner.map.remove(&s.stream_identifier());
}

fn close(&self, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
debug!("Closing connection");

// First, flush all the buffered data.
match ready!(self.flush_all(cx)) {
Ok(_) => {
// Second, shutdown all the substreams.
let mut data_channels_inner = self.data_channels_inner.lock().unwrap();
for (_, ch) in &mut data_channels_inner.map {
match ready!(self.shutdown_substream(cx, ch)) {
Ok(_) => continue,
Err(e) => return Poll::Ready(Err(e)),
}
}

// Third, close `incoming_data_channels_rx`
data_channels_inner.incoming_data_channels_rx.close();

Poll::Ready(Ok(()))
},
Err(e) => Poll::Ready(Err(e)),
}
}

fn flush_all(&self, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
let mut data_channels_inner = self.data_channels_inner.lock().unwrap();
for (_, ch) in &mut data_channels_inner.map {
match ready!(self.flush_substream(cx, ch)) {
Ok(_) => continue,
Err(e) => return Poll::Ready(Err(e)),
}
}
Poll::Ready(Ok(()))
}
}

pub(crate) async fn register_data_channel_open_handler(
data_channel: Arc<RTCDataChannel>,
data_channel_tx: Sender<Arc<DetachedDataChannel>>,
) {
data_channel
.on_open({
let data_channel = data_channel.clone();
Box::new(move || {
debug!(
"Data channel '{}'-'{}' open",
data_channel.label(),
data_channel.id()
);

Box::pin(async move {
let data_channel = data_channel.clone();
match data_channel.detach().await {
Ok(detached) => {
if let Err(e) = data_channel_tx.send(detached) {
error!("Can't send data channel: {:?}", e);
}
},
Err(e) => {
error!("Can't detach data channel: {}", e);
},
};
})
})
})
.await;
}
556 changes: 344 additions & 212 deletions src/connection/poll_data_channel.rs

Large diffs are not rendered by default.

13 changes: 13 additions & 0 deletions src/error.rs
Original file line number Diff line number Diff line change
@@ -18,6 +18,7 @@
// FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER
// DEALINGS IN THE SOFTWARE.

use libp2p_core::PeerId;
use thiserror::Error;

/// Error in WebRTC.
@@ -29,6 +30,18 @@ pub enum Error {
WebRTC(#[from] webrtc::Error),
#[error("io error: {0}")]
IoError(#[from] std::io::Error),
#[error("noise error: {0}")]
Noise(#[from] libp2p_noise::NoiseError),

// Authentication errors.
#[error("invalid fingerprint (expected {expected:?}, got {got:?})")]
InvalidFingerprint { expected: String, got: String },
#[error("invalid peer ID (expected {expected:?}, got {got:?})")]
InvalidPeerID {
expected: Option<PeerId>,
got: PeerId,
},

#[error("internal error: {0} (see debug logs)")]
InternalError(String),
}
228 changes: 146 additions & 82 deletions src/transport.rs

Large diffs are not rendered by default.

151 changes: 74 additions & 77 deletions src/udp_mux.rs
Original file line number Diff line number Diff line change
@@ -25,22 +25,17 @@ use std::{
collections::{HashMap, HashSet},
io::ErrorKind,
net::SocketAddr,
sync::Arc,
sync::{Arc, Weak},
};

use futures::channel::mpsc;
use libp2p_core::multiaddr::{Multiaddr, Protocol};
use webrtc_ice::udp_mux::UDPMux;
use webrtc_ice::udp_mux::{UDPMux, UDPMuxConn, UDPMuxConnParams, UDPMuxWriter};
use webrtc_util::{sync::RwLock, Conn, Error};

use tokio_crate as tokio;
use tokio_crate::sync::{watch, Mutex};

mod socket_addr_ext;

mod udp_mux_conn;
use udp_mux_conn::{UDPMuxConn, UDPMuxConnParams};

use async_trait::async_trait;

use stun::{
@@ -127,56 +122,19 @@ impl UDPMuxNewAddr {
self.closed_watch_tx.lock().await.is_none()
}

async fn send_to(&self, buf: &[u8], target: &SocketAddr) -> Result<usize, Error> {
self.params
.conn
.send_to(buf, *target)
.await
.map_err(Into::into)
}

/// Create a muxed connection for a given ufrag.
async fn create_muxed_conn(self: &Arc<Self>, ufrag: &str) -> Result<UDPMuxConn, Error> {
let local_addr = self.params.conn.local_addr().await?;

let params = UDPMuxConnParams {
local_addr,
key: ufrag.into(),
udp_mux: Arc::clone(self),
udp_mux: Arc::downgrade(self) as Weak<dyn UDPMuxWriter + Sync + Send>,
};

Ok(UDPMuxConn::new(params))
}

async fn register_conn_for_address(&self, conn: &UDPMuxConn, addr: SocketAddr) {
if self.is_closed().await {
return;
}

let key = conn.key();
{
let mut addresses = self.address_map.write();

addresses
.entry(addr)
.and_modify(|e| {
if e.key() != key {
e.remove_address(&addr);
*e = conn.clone()
}
})
.or_insert_with(|| conn.clone());
}

// remove addr from new_addrs once conn is established
{
let mut new_addrs = self.new_addrs.write();
new_addrs.remove(&addr);
}

log::debug!("Registered {} for {}", addr, key);
}

async fn conn_from_stun_message(&self, buffer: &[u8], addr: &SocketAddr) -> Option<UDPMuxConn> {
match ufrag_from_stun_message(buffer, true) {
Ok(ufrag) => {
@@ -234,12 +192,12 @@ impl UDPMuxNewAddr {
let a = Multiaddr::empty()
.with(addr.ip().into())
.with(Protocol::Udp(addr.port()))
.with(Protocol::XWebRTC(hex_to_cow(&ufrag.replace(":", ""))));
.with(Protocol::XWebRTC(hex_to_cow(&ufrag.replace(':', ""))));
if let Err(err) = new_addr_tx.try_send(a) {
log::error!("Failed to send new address {}: {}", &addr, err);
} else {
let mut new_addrs = loop_self.new_addrs.write();
new_addrs.insert(addr.clone());
new_addrs.insert(addr);
};
}
Err(e) => {
@@ -291,7 +249,7 @@ impl UDPMux for UDPMuxNewAddr {
};

// NOTE: We don't wait for these closure to complete
for (_, conn) in old_conns.into_iter() {
for (_, conn) in old_conns {
conn.close();
}

@@ -364,6 +322,46 @@ impl UDPMux for UDPMuxNewAddr {
}
}

#[async_trait]
impl UDPMuxWriter for UDPMuxNewAddr {
async fn register_conn_for_address(&self, conn: &UDPMuxConn, addr: SocketAddr) {
if self.is_closed().await {
return;
}

let key = conn.key();
{
let mut addresses = self.address_map.write();

addresses
.entry(addr)
.and_modify(|e| {
if e.key() != key {
e.remove_address(&addr);
*e = conn.clone();
}
})
.or_insert_with(|| conn.clone());
}

// remove addr from new_addrs once conn is established
{
let mut new_addrs = self.new_addrs.write();
new_addrs.remove(&addr);
}

log::debug!("Registered {} for {}", addr, key);
}

async fn send_to(&self, buf: &[u8], target: &SocketAddr) -> Result<usize, Error> {
self.params
.conn
.send_to(buf, *target)
.await
.map_err(Into::into)
}
}

fn hex_to_cow<'a>(s: &str) -> Cow<'a, [u8; 32]> {
let mut buf = [0; 32];
hex::decode_to_slice(s, &mut buf).unwrap();
@@ -379,37 +377,36 @@ fn ufrag_from_stun_message(buffer: &[u8], local_ufrag: bool) -> Result<String, E
(m.unmarshal_binary(buffer), m)
};

match result {
Err(err) => Err(Error::Other(format!(
if let Err(err) = result {
Err(Error::Other(format!(
"failed to handle decode ICE: {}",
err
))),
Ok(_) => {
let (attr, found) = message.attributes.get(ATTR_USERNAME);
if !found {
return Err(Error::Other("no username attribute in STUN message".into()));
}
)))
} else {
let (attr, found) = message.attributes.get(ATTR_USERNAME);
if !found {
return Err(Error::Other("no username attribute in STUN message".into()));
}

match String::from_utf8(attr.value) {
// Per the RFC this shouldn't happen
// https://datatracker.ietf.org/doc/html/rfc5389#section-15.3
Err(err) => Err(Error::Other(format!(
"failed to decode USERNAME from STUN message as UTF-8: {}",
err
))),
Ok(s) => {
// s is a combination of the local_ufrag and the remote ufrag separated by `:`.
let res = if local_ufrag {
s.split(":").next()
} else {
s.split(":").last()
};
match res {
Some(s) => Ok(s.to_owned()),
None => Err(Error::Other("can't get ufrag from username".into())),
}
},
}
},
match String::from_utf8(attr.value) {
// Per the RFC this shouldn't happen
// https://datatracker.ietf.org/doc/html/rfc5389#section-15.3
Err(err) => Err(Error::Other(format!(
"failed to decode USERNAME from STUN message as UTF-8: {}",
err
))),
Ok(s) => {
// s is a combination of the local_ufrag and the remote ufrag separated by `:`.
let res = if local_ufrag {
s.split(':').next()
} else {
s.split(':').last()
};
match res {
Some(s) => Ok(s.to_owned()),
None => Err(Error::Other("can't get ufrag from username".into())),
}
},
}
}
}
267 changes: 0 additions & 267 deletions src/udp_mux/socket_addr_ext.rs

This file was deleted.

308 changes: 0 additions & 308 deletions src/udp_mux/udp_mux_conn.rs

This file was deleted.

148 changes: 83 additions & 65 deletions src/upgrade.rs
Original file line number Diff line number Diff line change
@@ -18,12 +18,17 @@
// FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER
// DEALINGS IN THE SOFTWARE.

use futures::channel::oneshot;
use futures::select;
use futures::FutureExt;
use futures::io::{AsyncReadExt, AsyncWriteExt};
use futures::{channel::oneshot, future, select, FutureExt, TryFutureExt};
use futures_timer::Delay;
use libp2p_core::multiaddr::{Multiaddr, Protocol};
use log::{debug, error, trace};
use libp2p_core::identity;
use libp2p_core::{
multiaddr::{Multiaddr, Protocol},
PeerId,
};
use libp2p_core::{InboundUpgrade, UpgradeInfo};
use libp2p_noise::{Keypair, NoiseConfig, NoiseError, RemoteIdentity, X25519Spec};
use log::{debug, trace};
use webrtc::api::APIBuilder;
use webrtc::data_channel::data_channel_init::RTCDataChannelInit;
use webrtc::dtls_transport::dtls_role::DTLSRole;
@@ -36,6 +41,7 @@ use std::sync::Arc;
use std::time::Duration;

use crate::connection::Connection;
use crate::connection::PollDataChannel;
use crate::error::Error;
use crate::sdp;
use crate::transport;
@@ -44,14 +50,15 @@ pub async fn webrtc(
udp_mux: Arc<dyn UDPMux + Send + Sync>,
config: RTCConfiguration,
addr: Multiaddr,
) -> Result<Connection<'static>, Error> {
id_keys: identity::Keypair,
) -> Result<(PeerId, Connection), Error> {
trace!("upgrading {}", addr);

let socket_addr = transport::multiaddr_to_socketaddr(&addr)
.ok_or_else(|| Error::InvalidMultiaddr(addr.clone()))?;
let fingerprint = transport::fingerprint_of_first_certificate(&config);
let our_fingerprint = transport::fingerprint_of_first_certificate(&config);

let mut se = transport::build_setting_engine(udp_mux, &socket_addr, &fingerprint);
let mut se = transport::build_setting_engine(udp_mux, &socket_addr, &our_fingerprint);
{
// Act as a lite ICE (ICE which does not send additional candidates).
se.set_lite(true);
@@ -65,63 +72,17 @@ pub async fn webrtc(
let api = APIBuilder::new().with_setting_engine(se).build();
let peer_connection = api.new_peer_connection(config).await?;

// Create a datachannel with label 'data'.
let data_channel = peer_connection
.create_data_channel(
"data",
Some(RTCDataChannelInit {
negotiated: Some(true),
id: Some(1),
ordered: None,
max_retransmits: None,
max_packet_life_time: None,
protocol: None,
}),
)
.await?;

let (data_channel_rx, mut data_channel_tx) = oneshot::channel::<Arc<DetachedDataChannel>>();

// Wait until the data channel is opened and detach it.
data_channel
.on_open({
let data_channel = data_channel.clone();
Box::new(move || {
debug!(
"Data channel '{}'-'{}' open.",
data_channel.label(),
data_channel.id()
);

Box::pin(async move {
let data_channel = data_channel.clone();
match data_channel.detach().await {
Ok(detached) => {
if let Err(_) = data_channel_rx.send(detached) {
error!("data_channel_tx dropped");
}
},
Err(e) => {
error!("Can't detach data channel: {}", e);
},
};
})
})
})
.await;

// Set the remote description to the predefined SDP.
let fingerprint = match addr.iter().last() {
Some(Protocol::XWebRTC(f)) => f,
_ => {
debug!("{} is not a WebRTC multiaddr", addr);
return Err(Error::InvalidMultiaddr(addr));
},
let remote_fingerprint = if let Some(Protocol::XWebRTC(f)) = addr.iter().last() {
transport::fingerprint_to_string(&f)
} else {
debug!("{} is not a WebRTC multiaddr", addr);
return Err(Error::InvalidMultiaddr(addr));
};
let client_session_description = transport::render_description(
sdp::CLIENT_SESSION_DESCRIPTION,
socket_addr,
&transport::fingerprint_to_string(&fingerprint),
&remote_fingerprint,
);
debug!("OFFER: {:?}", client_session_description);
let sdp = RTCSessionDescription::offer(client_session_description).unwrap();
@@ -133,14 +94,71 @@ pub async fn webrtc(
debug!("ANSWER: {:?}", answer.sdp);
peer_connection.set_local_description(answer).await?;

// Create a datachannel with label 'data'.
let data_channel = peer_connection
.create_data_channel(
"data",
Some(RTCDataChannelInit {
negotiated: Some(true),
id: Some(1),
..RTCDataChannelInit::default()
}),
)
.await?;

let (tx, mut rx) = oneshot::channel::<Arc<DetachedDataChannel>>();

// Wait until the data channel is opened and detach it.
// Wait until the data channel is opened and detach it.
crate::connection::register_data_channel_open_handler(data_channel, tx).await;

// Wait until data channel is opened and ready to use
select! {
res = data_channel_tx => match res {
Ok(dc) => Ok(Connection::new(peer_connection, dc)),
Err(e) => Err(Error::InternalError(e.to_string())),
let detached = select! {
res = rx => match res {
Ok(detached) => detached,
Err(e) => return Err(Error::InternalError(e.to_string())),
},
_ = Delay::new(Duration::from_secs(10)).fuse() => Err(Error::InternalError(
_ = Delay::new(Duration::from_secs(10)).fuse() => return Err(Error::InternalError(
"data channel opening took longer than 10 seconds (see logs)".into(),
))
};

trace!("noise handshake with {}", addr);
let dh_keys = Keypair::<X25519Spec>::new()
.into_authentic(&id_keys)
.unwrap();
let noise = NoiseConfig::xx(dh_keys);
let info = noise.protocol_info().next().unwrap();
let (peer_id, mut noise_io) = noise
.upgrade_inbound(PollDataChannel::new(detached.clone()), info)
.and_then(|(remote, io)| match remote {
RemoteIdentity::IdentityKey(pk) => future::ok((pk.to_peer_id(), io)),
_ => future::err(NoiseError::AuthenticationFailed),
})
.await
.map_err(Error::Noise)?;

// Exchange TLS certificate fingerprints to prevent MiM attacks.
trace!("exchanging TLS certificate fingerprints with {}", addr);
let n = noise_io.write(&our_fingerprint.into_bytes()).await?;
noise_io.flush().await?;
let mut buf = vec![0; n]; // ASSERT: fingerprint's format is the same.
noise_io.read_exact(buf.as_mut_slice()).await?;
let fingerprint_from_noise =
String::from_utf8(buf).map_err(|_| Error::Noise(NoiseError::AuthenticationFailed))?;
if fingerprint_from_noise != remote_fingerprint {
return Err(Error::InvalidFingerprint {
expected: remote_fingerprint,
got: fingerprint_from_noise,
});
}

// Close the initial data channel after noise handshake is done.
// https://github.com/webrtc-rs/sctp/pull/14
// detached
// .close()
// .await
// .map_err(|e| Error::WebRTC(e.into()))?;

Ok((peer_id, Connection::new(peer_connection).await))
}
305 changes: 305 additions & 0 deletions tests/smoke.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,305 @@
use anyhow::Result;
use async_trait::async_trait;
use futures::future::FutureExt;
use futures::io::{AsyncRead, AsyncWrite, AsyncWriteExt};
use futures::stream::StreamExt;
use libp2p_core::identity;
use libp2p_core::multiaddr::Protocol;
use libp2p_core::upgrade;
use libp2p_request_response::{
ProtocolName, ProtocolSupport, RequestResponse, RequestResponseCodec, RequestResponseConfig,
RequestResponseEvent, RequestResponseMessage,
};
use libp2p_swarm::{Swarm, SwarmBuilder, SwarmEvent};
use libp2p_webrtc_direct::transport::WebRTCDirectTransport;
use log::trace;
use rand::RngCore;
use rcgen::KeyPair;
use tokio_crate as tokio;
use webrtc::peer_connection::certificate::RTCCertificate;

use std::borrow::Cow;
use std::{io, iter};

fn generate_certificate() -> RTCCertificate {
let kp = KeyPair::generate(&rcgen::PKCS_ECDSA_P256_SHA256).expect("key pair");
RTCCertificate::from_key_pair(kp).expect("certificate")
}

fn generate_tls_keypair() -> identity::Keypair {
identity::Keypair::generate_ed25519()
}

async fn create_swarm() -> Result<(Swarm<RequestResponse<PingCodec>>, String)> {
let cert = generate_certificate();
let keypair = generate_tls_keypair();
let peer_id = keypair.public().to_peer_id();
let transport = WebRTCDirectTransport::new(cert, keypair, "127.0.0.1:0").await?;
let fingerprint = transport.cert_fingerprint();
let protocols = iter::once((PingProtocol(), ProtocolSupport::Full));
let cfg = RequestResponseConfig::default();
let behaviour = RequestResponse::new(PingCodec(), protocols, cfg);
trace!("{}", peer_id);
Ok((
SwarmBuilder::new(transport.boxed(), behaviour, peer_id)
.executor(Box::new(|fut| {
tokio::spawn(fut);
}))
.build(),
fingerprint,
))
}

#[tokio::test]
async fn smoke() -> Result<()> {
let _ = env_logger::builder().is_test(true).try_init();

let mut rng = rand::thread_rng();

let (mut a, a_fingerprint) = create_swarm().await?;
let (mut b, _b_fingerprint) = create_swarm().await?;

Swarm::listen_on(&mut a, "/ip4/127.0.0.1/udp/0".parse()?)?;

let addr = match a.next().await {
Some(SwarmEvent::NewListenAddr { address, .. }) => address,
e => panic!("{:?}", e),
};
let addr = addr.with(Protocol::XWebRTC(hex_to_cow(
&a_fingerprint.replace(":", ""),
)));

let mut data = vec![0; 4096];
rng.fill_bytes(&mut data);

b.behaviour_mut()
.add_address(&Swarm::local_peer_id(&a), addr);
b.behaviour_mut()
.send_request(&Swarm::local_peer_id(&a), Ping(data.clone()));

match b.next().await {
Some(SwarmEvent::Dialing(_)) => {},
e => panic!("{:?}", e),
}

match a.next().await {
Some(SwarmEvent::IncomingConnection { .. }) => {},
e => panic!("{:?}", e),
};

match b.next().await {
Some(SwarmEvent::ConnectionEstablished { .. }) => {},
e => panic!("{:?}", e),
};

match a.next().await {
Some(SwarmEvent::ConnectionEstablished { .. }) => {},
e => panic!("{:?}", e),
};

assert!(b.next().now_or_never().is_none());

match a.next().await {
Some(SwarmEvent::Behaviour(RequestResponseEvent::Message {
message:
RequestResponseMessage::Request {
request: Ping(ping),
channel,
..
},
..
})) => {
a.behaviour_mut()
.send_response(channel, Pong(ping))
.unwrap();
},
e => panic!("{:?}", e),
}

match a.next().await {
Some(SwarmEvent::Behaviour(RequestResponseEvent::ResponseSent { .. })) => {},
e => panic!("{:?}", e),
}

match b.next().await {
Some(SwarmEvent::Behaviour(RequestResponseEvent::Message {
message:
RequestResponseMessage::Response {
response: Pong(pong),
..
},
..
})) => assert_eq!(data, pong),
e => panic!("{:?}", e),
}

a.behaviour_mut().send_request(
&Swarm::local_peer_id(&b),
Ping(b"another substream".to_vec()),
);

assert!(a.next().now_or_never().is_none());

match b.next().await {
Some(SwarmEvent::Behaviour(RequestResponseEvent::Message {
message:
RequestResponseMessage::Request {
request: Ping(data),
channel,
..
},
..
})) => {
b.behaviour_mut()
.send_response(channel, Pong(data))
.unwrap();
},
e => panic!("{:?}", e),
}

match b.next().await {
Some(SwarmEvent::Behaviour(RequestResponseEvent::ResponseSent { .. })) => {},
e => panic!("{:?}", e),
}

match a.next().await {
Some(SwarmEvent::Behaviour(RequestResponseEvent::Message {
message:
RequestResponseMessage::Response {
response: Pong(data),
..
},
..
})) => assert_eq!(data, b"another substream".to_vec()),
e => panic!("{:?}", e),
}

Ok(())
}

#[derive(Debug, Clone)]
struct PingProtocol();

#[derive(Clone)]
struct PingCodec();

#[derive(Debug, Clone, PartialEq, Eq)]
struct Ping(Vec<u8>);

#[derive(Debug, Clone, PartialEq, Eq)]
struct Pong(Vec<u8>);

impl ProtocolName for PingProtocol {
fn protocol_name(&self) -> &[u8] {
"/ping/1".as_bytes()
}
}

#[async_trait]
impl RequestResponseCodec for PingCodec {
type Protocol = PingProtocol;
type Request = Ping;
type Response = Pong;

async fn read_request<T>(&mut self, _: &PingProtocol, io: &mut T) -> io::Result<Self::Request>
where
T: AsyncRead + Unpin + Send,
{
upgrade::read_length_prefixed(io, 4096)
.map(|res| match res {
Err(e) => Err(io::Error::new(io::ErrorKind::InvalidData, e)),
Ok(vec) if vec.is_empty() => Err(io::ErrorKind::UnexpectedEof.into()),
Ok(vec) => Ok(Ping(vec)),
})
.await
}

async fn read_response<T>(&mut self, _: &PingProtocol, io: &mut T) -> io::Result<Self::Response>
where
T: AsyncRead + Unpin + Send,
{
upgrade::read_length_prefixed(io, 4096)
.map(|res| match res {
Err(e) => Err(io::Error::new(io::ErrorKind::InvalidData, e)),
Ok(vec) if vec.is_empty() => Err(io::ErrorKind::UnexpectedEof.into()),
Ok(vec) => Ok(Pong(vec)),
})
.await
}

async fn write_request<T>(
&mut self,
_: &PingProtocol,
io: &mut T,
Ping(data): Ping,
) -> io::Result<()>
where
T: AsyncWrite + Unpin + Send,
{
upgrade::write_length_prefixed(io, data).await?;
io.close().await?;
Ok(())
}

async fn write_response<T>(
&mut self,
_: &PingProtocol,
io: &mut T,
Pong(data): Pong,
) -> io::Result<()>
where
T: AsyncWrite + Unpin + Send,
{
upgrade::write_length_prefixed(io, data).await?;
io.close().await?;
Ok(())
}
}

#[tokio::test]
async fn dial_failure() -> Result<()> {
let _ = env_logger::builder().is_test(true).try_init();

let (mut a, a_fingerprint) = create_swarm().await?;
let (mut b, _b_fingerprint) = create_swarm().await?;

Swarm::listen_on(&mut a, "/ip4/127.0.0.1/udp/0".parse()?)?;

let addr = match a.next().await {
Some(SwarmEvent::NewListenAddr { address, .. }) => address,
e => panic!("{:?}", e),
};
let addr = addr.with(Protocol::XWebRTC(hex_to_cow(
&a_fingerprint.replace(":", ""),
)));

let a_peer_id = &Swarm::local_peer_id(&a).clone();
drop(a); // stop a swarm so b can never reach it

b.behaviour_mut().add_address(a_peer_id, addr);
b.behaviour_mut()
.send_request(a_peer_id, Ping(b"hello world".to_vec()));

match b.next().await {
Some(SwarmEvent::Dialing(_)) => {},
e => panic!("{:?}", e),
}

match b.next().await {
Some(SwarmEvent::OutgoingConnectionError { .. }) => {},
e => panic!("{:?}", e),
};

match b.next().await {
Some(SwarmEvent::Behaviour(RequestResponseEvent::OutboundFailure { .. })) => {},
e => panic!("{:?}", e),
};

Ok(())
}

fn hex_to_cow<'a>(s: &str) -> Cow<'a, [u8; 32]> {
let mut buf = [0; 32];
hex::decode_to_slice(s, &mut buf).unwrap();
Cow::Owned(buf)
}

0 comments on commit 0d85a4a

Please sign in to comment.