Skip to content

Commit

Permalink
fix(server): do not block connection handler task during client auth (#…
Browse files Browse the repository at this point in the history
  • Loading branch information
M0dEx authored Jan 8, 2024
1 parent 13e7f48 commit 42f05c0
Show file tree
Hide file tree
Showing 3 changed files with 85 additions and 63 deletions.
24 changes: 14 additions & 10 deletions src/auth/server.rs
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
use std::{net::IpAddr, sync::Arc, time::Duration};

use crate::constants::{AUTH_FAILED_MESSAGE, AUTH_MESSAGE_BUFFER_SIZE, AUTH_TIMEOUT_MESSAGE};
use crate::server::address_pool::AddressPool;
use anyhow::{anyhow, Context, Result};
use bytes::BytesMut;
use ipnet::IpNet;
Expand All @@ -21,28 +22,28 @@ pub enum AuthServerMessage {
/// Represents an authentication server handling initial authentication and session management.
pub struct AuthServer<'a> {
user_database: &'a UserDatabase,
client_address: IpNet,
address_pool: &'a AddressPool,
connection: Arc<Connection>,
auth_timeout: Duration,
}

impl<'a> AuthServer<'a> {
pub fn new(
user_database: &'a UserDatabase,
address_pool: &'a AddressPool,
connection: Arc<Connection>,
client_address: IpNet,
auth_timeout: Duration,
) -> Self {
Self {
user_database,
client_address,
address_pool,
connection,
auth_timeout,
}
}

/// Handles authentication for a client.
pub async fn handle_authentication(&self) -> Result<String> {
pub async fn handle_authentication(&self) -> Result<(String, IpNet)> {
let (send_stream, mut recv_stream) =
match timeout(self.auth_timeout, self.connection.accept_bi()).await {
Ok(Ok(streams)) => streams,
Expand Down Expand Up @@ -77,7 +78,7 @@ impl<'a> AuthServer<'a> {
mut send_stream: SendStream,
username: String,
password: String,
) -> Result<String> {
) -> Result<(String, IpNet)> {
let auth_result = self.user_database.authenticate(&username, password).await;

if auth_result.is_err() {
Expand All @@ -87,14 +88,17 @@ impl<'a> AuthServer<'a> {
.expect_err("Handle failure always returns an error"));
}

let response = AuthServerMessage::Authenticated(
self.client_address.addr(),
self.client_address.netmask(),
);
let client_address = self
.address_pool
.next_available_address()
.ok_or_else(|| anyhow!("Could not find an available address for client"))?;

let response =
AuthServerMessage::Authenticated(client_address.addr(), client_address.netmask());

Self::send_message(&mut send_stream, response).await?;

Ok(username)
Ok((username, client_address))
}

/// Handles a failure during authentication.
Expand Down
79 changes: 44 additions & 35 deletions src/server/connection.rs
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@ use anyhow::{anyhow, Error, Result};
use bytes::Bytes;
use ipnet::IpNet;

use crate::server::address_pool::AddressPool;
use quinn::Connection;
use std::sync::Arc;
use std::time::Duration;
Expand All @@ -14,8 +15,8 @@ use tracing::{debug, info, warn};
#[derive(Clone)]
pub struct QuincyConnection {
connection: Arc<Connection>,
pub username: Option<String>,
pub client_address: IpNet,
username: Option<String>,
client_address: Option<IpNet>,
ingress_queue: UnboundedSender<Bytes>,
}

Expand All @@ -25,55 +26,51 @@ impl QuincyConnection {
/// ### Arguments
/// - `connection` - the underlying QUIC connection
/// - `tun_queue` - the queue to send data to the TUN interface
/// - `user_database` - the user database
/// - `auth_timeout` - the authentication timeout
/// - `client_address` - the assigned client address
pub fn new(
connection: Arc<Connection>,
client_address: IpNet,
tun_queue: UnboundedSender<Bytes>,
) -> Self {
pub fn new(connection: Connection, tun_queue: UnboundedSender<Bytes>) -> Self {
Self {
connection,
connection: Arc::new(connection),
username: None,
client_address,
client_address: None,
ingress_queue: tun_queue,
}
}

/// Attempts to authenticate the client.
pub async fn authenticate(
&mut self,
mut self,
user_database: &UserDatabase,
address_pool: &AddressPool,
connection_timeout: Duration,
) -> Result<()> {
) -> Result<Self> {
let auth_server = AuthServer::new(
user_database,
address_pool,
self.connection.clone(),
self.client_address,
connection_timeout,
);

let username = auth_server.handle_authentication().await?;
let (username, client_address) = auth_server.handle_authentication().await?;

info!(
"Connection established: user = {}, client address = {}, remote address = {}",
username,
self.client_address.addr(),
client_address.addr(),
self.connection.remote_address().ip(),
);

self.username = Some(username);
self.client_address = Some(client_address);

Ok(())
Ok(self)
}

/// Starts the tasks for this instance of Quincy connection.
pub async fn run(self, egress_queue: UnboundedReceiver<Bytes>) -> (Self, Error) {
if self.username.is_none() {
let client_address = self.client_address.addr();
let client_address = self.connection.remote_address();
return (
self,
anyhow!("Client '{client_address}' is not authenticated"),
anyhow!("Client '{}' is not authenticated", client_address.ip()),
);
}

Expand All @@ -87,12 +84,16 @@ impl QuincyConnection {
outgoing_data_err = outgoing_data_task => outgoing_data_err,
incoming_data_err = incoming_data_task => incoming_data_err,
}
.expect("Joining tasks never fails")
.expect_err("Connection tasks always return an error");
.expect("joining tasks never fails")
.expect_err("connection tasks always return an error");

(self, err)
}

/// Processes outgoing data and sends it to the QUIC connection.
///
/// ### Arguments
/// - `egress_queue` - the queue to receive data from the TUN interface
async fn process_outgoing_data(
self: Arc<Self>,
mut egress_queue: UnboundedReceiver<Bytes>,
Expand All @@ -101,14 +102,12 @@ impl QuincyConnection {
let data = egress_queue
.recv()
.await
.ok_or_else(|| anyhow!("Egress queue has been closed"))?;
.ok_or(anyhow!("Egress queue has been closed"))?;

let max_datagram_size = self.connection.max_datagram_size().ok_or_else(|| {
anyhow!(
"Client {} failed to provide maximum datagram size",
self.connection.remote_address().ip()
)
})?;
let max_datagram_size = self.connection.max_datagram_size().ok_or(anyhow!(
"Client {} failed to provide maximum datagram size",
self.client_address()?.addr()
))?;

debug!("Maximum QUIC datagram size: {max_datagram_size}");

Expand All @@ -124,29 +123,39 @@ impl QuincyConnection {
debug!(
"Sending {} bytes to {:?}",
data.len(),
self.client_address.addr()
self.client_address()?.addr()
);

self.connection.send_datagram(data)?;
}
}

/// Processes incoming data and sends it to the TUN interface queue.
///
/// ### Arguments
/// - `connection` - a reference to the underlying QUIC connection
/// - `tun_queue` - a sender of an unbounded queue used by the tunnel worker to receive data
async fn process_incoming_data(self: Arc<Self>) -> Result<()> {
loop {
let data = self.connection.read_datagram().await?;

debug!(
"Received {} bytes from {:?}",
data.len(),
self.client_address.addr()
self.client_address()?.addr()
);

self.ingress_queue.send(data)?;
}
}

/// Returns the username associated with this connection.
pub fn username(&self) -> Result<&str> {
self.username
.as_deref()
.ok_or(anyhow!("Connection is unauthenticated"))
}

/// Returns the client address associated with this connection.
pub fn client_address(&self) -> Result<&IpNet> {
self.client_address
.as_ref()
.ok_or(anyhow!("Connection is unauthenticated"))
}
}
45 changes: 27 additions & 18 deletions src/server/tunnel.rs
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@ use crate::config::{ConnectionConfig, TunnelConfig};
use crate::server::address_pool::AddressPool;
use crate::server::connection::QuincyConnection;
use crate::utils::socket::bind_socket;
use anyhow::{anyhow, Result};
use anyhow::Result;
use bytes::Bytes;
use dashmap::DashMap;
use etherparse::{IpHeader, PacketHeaders};
Expand Down Expand Up @@ -117,46 +117,55 @@ impl QuincyTunnel {
endpoint: Endpoint,
) -> Result<()> {
info!(
"Listening for incoming connections: {}",
"Starting connection handler: {}",
endpoint.local_addr().expect("Endpoint has a local address")
);

let mut authentication_tasks = FuturesUnordered::new();
let mut connection_tasks = FuturesUnordered::new();

loop {
tokio::select! {
// New connections
Some(handshake) = endpoint.accept() => {
debug!(
"Received incoming connection from '{}'",
handshake.remote_address().ip()
);

let client_tun_ip = self.address_pool
.next_available_address()
.ok_or_else(|| anyhow!("Could not find an available address for client"))?;
let quic_connection = handshake.await?;

let quic_connection = Arc::new(handshake.await?);
let (connection_sender, connection_receiver) = mpsc::unbounded_channel();

let mut connection = QuincyConnection::new(
quic_connection.clone(),
client_tun_ip,
let connection = QuincyConnection::new(
quic_connection,
ingress_queue.clone(),
);

if let Err(e) = connection.authenticate(&self.user_database, self.connection_config.timeout).await {
warn!("Failed to authenticate client {client_tun_ip}: {e}");
self.address_pool.release_address(&client_tun_ip.addr());
continue;
}
authentication_tasks.push(
connection.authenticate(&self.user_database, &self.address_pool, self.connection_config.timeout)
);
}

// Authentication tasks
Some(connection) = authentication_tasks.next() => {
let connection = match connection {
Ok(connection) => connection,
Err(e) => {
warn!("Failed to authenticate client: {e}");
continue;
}
};

let client_address = connection.client_address()?.addr();
let (connection_sender, connection_receiver) = mpsc::unbounded_channel();

connection_tasks.push(tokio::spawn(connection.run(connection_receiver)));
self.connection_queues.insert(client_tun_ip.addr(), connection_sender);
self.connection_queues.insert(client_address, connection_sender);
}

// Connection tasks
Some(connection) = connection_tasks.next() => {
let (connection, err) = connection?;
let client_address = &connection.client_address.addr();
let client_address = &connection.client_address()?.addr();

self.connection_queues.remove(client_address);
self.address_pool.release_address(client_address);
Expand Down

0 comments on commit 42f05c0

Please sign in to comment.