Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix(quic): add support for reusing an existing socket for local dialing #4304

Merged
merged 10 commits into from
Aug 11, 2023
28 changes: 22 additions & 6 deletions transports/quic/src/transport.rs
Original file line number Diff line number Diff line change
Expand Up @@ -155,9 +155,14 @@ impl<P: Provider> GenTransport<P> {
if l.is_closed {
return false;
}
let listen_addr = l.socket_addr();
SocketFamily::is_same(&listen_addr.ip(), &socket_addr.ip())
&& listen_addr.ip().is_loopback() == socket_addr.ip().is_loopback()
SocketFamily::is_same(&l.socket_addr().ip(), &socket_addr.ip())
})
.filter(|l| {
if socket_addr.ip().is_loopback() {
l.supports_loopback
} else {
true
}
})
.collect();
match listeners.len() {
Expand Down Expand Up @@ -428,6 +433,9 @@ struct Listener<P: Provider> {

/// The stream must be awaken after it has been closed to deliver the last event.
close_listener_waker: Option<Waker>,

/// `true` if a listener supports loopback interface
supports_loopback: bool,
}

impl<P: Provider> Listener<P> {
Expand All @@ -440,13 +448,17 @@ impl<P: Provider> Listener<P> {
) -> Result<Self, Error> {
let if_watcher;
let pending_event;
let mut supports_loopback = false;
let local_addr = socket.local_addr()?;
if local_addr.ip().is_unspecified() {
if_watcher = Some(P::new_if_watcher()?);
pending_event = None;
} else {
if_watcher = None;
let ma = socketaddr_to_multiaddr(&local_addr, version);
if local_addr.ip().is_loopback() {
supports_loopback = true
}
pending_event = Some(TransportEvent::NewAddress {
listener_id,
listen_addr: ma,
Expand All @@ -467,6 +479,7 @@ impl<P: Provider> Listener<P> {
is_closed: false,
pending_event,
close_listener_waker: None,
supports_loopback,
})
}

Expand Down Expand Up @@ -513,7 +526,10 @@ impl<P: Provider> Listener<P> {
if let Some(listen_addr) =
ip_to_listenaddr(&endpoint_addr, inet.addr(), self.version)
{
log::debug!("New listen address: {}", listen_addr);
log::debug!("New listen address: {listen_addr}");
if inet.addr().is_loopback() {
self.supports_loopback = true;
}
return Poll::Ready(TransportEvent::NewAddress {
listener_id: self.listener_id,
listen_addr,
Expand All @@ -524,7 +540,7 @@ impl<P: Provider> Listener<P> {
if let Some(listen_addr) =
ip_to_listenaddr(&endpoint_addr, inet.addr(), self.version)
{
log::debug!("Expired listen address: {}", listen_addr);
log::debug!("Expired listen address: {listen_addr}");
mxinden marked this conversation as resolved.
Show resolved Hide resolved
return Poll::Ready(TransportEvent::AddressExpired {
listener_id: self.listener_id,
listen_addr,
Expand Down Expand Up @@ -730,7 +746,7 @@ fn socketaddr_to_multiaddr(socket_addr: &SocketAddr, version: ProtocolVersion) -

#[cfg(test)]
#[cfg(any(feature = "async-std", feature = "tokio"))]
mod test {
mod tests {
use futures::future::poll_fn;
use std::net::{IpAddr, Ipv4Addr, Ipv6Addr};

Expand Down
43 changes: 43 additions & 0 deletions transports/quic/tests/smoke.rs
Original file line number Diff line number Diff line change
Expand Up @@ -414,6 +414,49 @@ async fn write_after_peer_dropped_stream() {
stream_b.close().await.expect("Close failed.");
}

/// - A listens on 0.0.0.0:0
/// - B listens on 127.0.0.1:0
/// - A dials B
/// - Source port of A at B is the A's listen port
#[cfg(feature = "tokio")]
#[tokio::test]
async fn test_local_listener_reuse() {
mxinden marked this conversation as resolved.
Show resolved Hide resolved
let (_, mut a_transport) = create_default_transport::<quic::tokio::Provider>();
let (_, mut b_transport) = create_default_transport::<quic::tokio::Provider>();

a_transport
.listen_on(
ListenerId::next(),
"/ip4/0.0.0.0/udp/0/quic-v1".parse().unwrap(),
)
.unwrap();

// wait until a listener reports a loopback address
let a_addr = 'outer: loop {
arsenron marked this conversation as resolved.
Show resolved Hide resolved
let ev = a_transport.next().await.unwrap();
let listen_addr = ev.into_new_address().unwrap();
for proto in listen_addr.iter() {
if let Protocol::Ip4(ip4) = proto {
if ip4.is_loopback() {
break 'outer listen_addr;
}
}
}
};
// If we do not poll until the end, `NewAddress` events may be `Ready` and `connect` function
// below will panic due to an unexpected event.
poll_fn(|cx| {
let mut pinned = Pin::new(&mut a_transport);
while let Poll::Ready(e) = pinned.as_mut().poll(cx) {}
arsenron marked this conversation as resolved.
Show resolved Hide resolved
Poll::Ready(())
})
.await;

let b_addr = start_listening(&mut b_transport, "/ip4/127.0.0.1/udp/0/quic-v1").await;
let (_, send_back_addr, _) = connect(&mut b_transport, &mut a_transport, b_addr).await.0;
assert_eq!(send_back_addr, a_addr);
arsenron marked this conversation as resolved.
Show resolved Hide resolved
}

async fn smoke<P: Provider>() {
let _ = env_logger::try_init();

Expand Down