Skip to content

Commit

Permalink
Wait for routes update in WireguardMonitor::start
Browse files Browse the repository at this point in the history
  • Loading branch information
MarkusPettersson98 committed Jan 28, 2025
1 parent b3bb654 commit 00e7c49
Show file tree
Hide file tree
Showing 4 changed files with 47 additions and 24 deletions.
9 changes: 6 additions & 3 deletions talpid-routing/src/unix/android.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,13 +3,14 @@ use std::sync::Mutex;
use crate::imp::RouteManagerCommand;
use futures::{
channel::mpsc::{self, UnboundedReceiver, UnboundedSender},
stream::StreamExt,
stream::StreamExt, Stream,
};
use ipnetwork::IpNetwork;
use jnix::{
jni::{objects::JObject, JNIEnv},
FromJava, JnixEnv,
};
use talpid_types::android::{AndroidContext, NetworkState};
use talpid_types::android::NetworkState;

/// Stub error type for routing errors on Android.
/// Errors that occur while setting up VpnService tunnel.
Expand Down Expand Up @@ -86,7 +87,7 @@ impl RouteManagerImpl {
match command {
RouteManagerCommand::NewChangeListener(tx) => {
// register a listener for new route updates
let _ = tx.send(self.listen());
self.listeners.push(tx);
}
RouteManagerCommand::Shutdown(tx) => {
tx.send(()).map_err(|()| Error::Send)?; // TODO: Surely we can do better than this
Expand All @@ -101,6 +102,8 @@ impl RouteManagerImpl {
Ok(())
}

// pub fn wait_for_routes(&mut self, routes: Vec<IpNetwork>) -> impl Stream<Item = bool> { }

fn notify_change_listeners(&mut self, message: RoutesUpdate) {
self.listeners
.retain(|listener| listener.unbounded_send(message.clone()).is_ok());
Expand Down
45 changes: 30 additions & 15 deletions talpid-routing/src/unix/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@ use futures::channel::{
mpsc::{self, UnboundedSender},
oneshot,
};
use ipnetwork::IpNetwork;
use std::{collections::HashSet, sync::Arc};

#[cfg(any(target_os = "linux", target_os = "macos"))]
Expand Down Expand Up @@ -97,7 +98,7 @@ pub(crate) enum RouteManagerCommand {
#[cfg(target_os = "android")]
#[derive(Debug)]
pub(crate) enum RouteManagerCommand {
NewChangeListener(oneshot::Sender<mpsc::UnboundedReceiver<imp::RoutesUpdate>>),
NewChangeListener(mpsc::UnboundedSender<imp::RoutesUpdate>),
AddRoutes(
HashSet<RequiredRoute>,
oneshot::Sender<Result<(), PlatformError>>,
Expand Down Expand Up @@ -212,16 +213,30 @@ impl RouteManagerHandle {
.map_err(|_| Error::RouteManagerDown)
}

// TODO: Do we even want this?
// #[cfg(target_os = "android")]
// pub async fn wait_for_routes(&self, routes: Vec<Ipnetwork>) -> Result<imp::RouteResult, Error> {
// let (result_tx, result_rx) = oneshot::channel();
// let msg = RouteManagerCommand::WaitForRoutes(result_tx, routes);
// self.tx
// .unbounded_send(msg)
// .map_err(|_| Error::RouteManagerDown);
// result_rx.await.map_err(|_| Error::ManagerChannelDown)
// }
#[cfg(target_os = "android")]
#[allow(missing_docs)]
pub fn wait_for_routes(&self, routes: Vec<IpNetwork>) -> impl futures::Stream<Item = bool> {
use futures::StreamExt;

let (stream_tx, stream_rx) = mpsc::unbounded();
self.tx
.unbounded_send(RouteManagerCommand::NewChangeListener(stream_tx))
.map_err(|_| Error::RouteManagerDown).unwrap(); //?;

stream_rx.map(move |change| {
use std::collections::HashSet;

// Wait for NetworkState updates to check if it includes all necessary routes
let xs: HashSet<IpNetwork> = HashSet::from_iter(routes.iter().copied());
match change {
imp::RoutesUpdate::NewNetworkState(network_state) => network_state.routes.map(|new_routes| {
//new_routes.contains(routes)
let ys = HashSet::from_iter(new_routes.iter().map(|route_info| IpNetwork::new(route_info.destination.address, route_info.destination.prefix_length as u8 ).unwrap()));
xs.is_subset(&ys)
}).unwrap_or(false)
}
})
}

/// Listen for non-tunnel default route changes.
#[cfg(target_os = "macos")]
Expand Down Expand Up @@ -318,13 +333,13 @@ impl RouteManagerHandle {

// TODO: We might not want this
/// Listen for route changes.
// #[cfg(target_os = "android")]
// pub async fn change_listener(&self) -> Result<impl Stream<Item = imp::RoutesUpdate>, Error> {
// let (response_tx, response_rx) = oneshot::channel();
//#[cfg(target_os = "android")]
//pub async fn change_listener(&self) -> Result<impl futures::Stream<Item = imp::RoutesUpdate>, Error> {
// let (response_tx, response_rx) = mpsc::unbounded();
// self.tx
// .unbounded_send(RouteManagerCommand::NewChangeListener(response_tx))
// .map_err(|_| Error::RouteManagerDown)?;
// response_rx.await.map_err(|_| Error::ManagerChannelDown)
// response_rx.map_err(|_| Error::ManagerChannelDown)
// }

/// Listen for route changes.
Expand Down
6 changes: 3 additions & 3 deletions talpid-types/src/android/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -14,22 +14,22 @@ pub struct AndroidContext {
}

/// A Java-compatible variant of [IpNetwork]
#[derive(Clone, Debug, Eq, PartialEq, IntoJava, FromJava)]
#[derive(Clone, Debug, Eq, PartialEq, Hash, IntoJava, FromJava)]
#[jnix(package = "net.mullvad.talpid.model")]
pub struct InetNetwork {
pub address: IpAddr,
pub prefix_length: i16,
}

#[derive(Clone, Debug, Eq, PartialEq, IntoJava, FromJava)]
#[derive(Clone, Debug, Eq, PartialEq, Hash, IntoJava, FromJava)]
#[jnix(package = "net.mullvad.talpid.model")]
pub struct RouteInfo {
pub destination: InetNetwork,
pub gateway: Option<InetAddress>,
pub interface: Option<String>,
}

#[derive(Clone, Debug, Eq, PartialEq, IntoJava, FromJava)]
#[derive(Clone, Debug, Eq, PartialEq, Hash, IntoJava, FromJava)]
#[jnix(package = "net.mullvad.talpid.model")]
pub struct NetworkState {
pub network_handle: i64,
Expand Down
11 changes: 8 additions & 3 deletions talpid-wireguard/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -400,10 +400,11 @@ impl WireguardMonitor {
log_path: Option<&Path>,
args: TunnelArgs<'_>,
) -> Result<WireguardMonitor> {
use futures::StreamExt;

let desired_mtu = get_desired_mtu(params);
let mut config =
Config::from_parameters(params, desired_mtu).map_err(Error::WireguardConfigError)?;

let (close_obfs_sender, close_obfs_listener) = sync_mpsc::channel();
// Start obfuscation server and patch the WireGuard config to point the endpoint to it.
let obfuscator = args
Expand Down Expand Up @@ -431,6 +432,9 @@ impl WireguardMonitor {
)
.map_err(Error::ConnectivityMonitorError)?;

let route_to_wait_for = args.tun_provider.lock().unwrap().config_mut().routes.clone();
let route_updates = args.route_manager.wait_for_routes(route_to_wait_for.clone());

let tunnel = args.runtime.block_on(Self::open_wireguard_go_tunnel(
&config,
log_path,
Expand Down Expand Up @@ -473,8 +477,9 @@ impl WireguardMonitor {
.on_event(TunnelEvent::InterfaceUp(metadata.clone(), allowed_traffic))
.await;

// TODO: We might not want this
// let _ = route_change_listener.next().await?;
// Wait for routes to come up
// TODO: Time out (eventually) and return proper error
route_updates.any(|routes_are_correct| async move { routes_are_correct }).await;

if should_negotiate_ephemeral_peer {
let ephemeral_obfs_sender = close_obfs_sender.clone();
Expand Down

0 comments on commit 00e7c49

Please sign in to comment.