Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix(server): do not block connection handler task during client auth #35

Merged
merged 1 commit into from
Jan 8, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
24 changes: 14 additions & 10 deletions src/auth/server.rs
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
use std::{net::IpAddr, sync::Arc, time::Duration};

use crate::constants::{AUTH_FAILED_MESSAGE, AUTH_MESSAGE_BUFFER_SIZE, AUTH_TIMEOUT_MESSAGE};
use crate::server::address_pool::AddressPool;
use anyhow::{anyhow, Context, Result};
use bytes::BytesMut;
use ipnet::IpNet;
Expand All @@ -21,28 +22,28 @@ pub enum AuthServerMessage {
/// Represents an authentication server handling initial authentication and session management.
pub struct AuthServer<'a> {
user_database: &'a UserDatabase,
client_address: IpNet,
address_pool: &'a AddressPool,
connection: Arc<Connection>,
auth_timeout: Duration,
}

impl<'a> AuthServer<'a> {
pub fn new(
user_database: &'a UserDatabase,
address_pool: &'a AddressPool,
connection: Arc<Connection>,
client_address: IpNet,
auth_timeout: Duration,
) -> Self {
Self {
user_database,
client_address,
address_pool,
connection,
auth_timeout,
}
}

/// Handles authentication for a client.
pub async fn handle_authentication(&self) -> Result<String> {
pub async fn handle_authentication(&self) -> Result<(String, IpNet)> {
let (send_stream, mut recv_stream) =
match timeout(self.auth_timeout, self.connection.accept_bi()).await {
Ok(Ok(streams)) => streams,
Expand Down Expand Up @@ -77,7 +78,7 @@ impl<'a> AuthServer<'a> {
mut send_stream: SendStream,
username: String,
password: String,
) -> Result<String> {
) -> Result<(String, IpNet)> {
let auth_result = self.user_database.authenticate(&username, password).await;

if auth_result.is_err() {
Expand All @@ -87,14 +88,17 @@ impl<'a> AuthServer<'a> {
.expect_err("Handle failure always returns an error"));
}

let response = AuthServerMessage::Authenticated(
self.client_address.addr(),
self.client_address.netmask(),
);
let client_address = self
.address_pool
.next_available_address()
.ok_or_else(|| anyhow!("Could not find an available address for client"))?;

let response =
AuthServerMessage::Authenticated(client_address.addr(), client_address.netmask());

Self::send_message(&mut send_stream, response).await?;

Ok(username)
Ok((username, client_address))
}

/// Handles a failure during authentication.
Expand Down
79 changes: 44 additions & 35 deletions src/server/connection.rs
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
use bytes::Bytes;
use ipnet::IpNet;

use crate::server::address_pool::AddressPool;
use quinn::Connection;
use std::sync::Arc;
use std::time::Duration;
Expand All @@ -14,8 +15,8 @@
#[derive(Clone)]
pub struct QuincyConnection {
connection: Arc<Connection>,
pub username: Option<String>,
pub client_address: IpNet,
username: Option<String>,
client_address: Option<IpNet>,
ingress_queue: UnboundedSender<Bytes>,
}

Expand All @@ -25,55 +26,51 @@
/// ### Arguments
/// - `connection` - the underlying QUIC connection
/// - `tun_queue` - the queue to send data to the TUN interface
/// - `user_database` - the user database
/// - `auth_timeout` - the authentication timeout
/// - `client_address` - the assigned client address
pub fn new(
connection: Arc<Connection>,
client_address: IpNet,
tun_queue: UnboundedSender<Bytes>,
) -> Self {
pub fn new(connection: Connection, tun_queue: UnboundedSender<Bytes>) -> Self {
Self {
connection,
connection: Arc::new(connection),
username: None,
client_address,
client_address: None,
ingress_queue: tun_queue,
}
}

/// Attempts to authenticate the client.
pub async fn authenticate(
&mut self,
mut self,
user_database: &UserDatabase,
address_pool: &AddressPool,
connection_timeout: Duration,
) -> Result<()> {
) -> Result<Self> {
let auth_server = AuthServer::new(
user_database,
address_pool,
self.connection.clone(),
self.client_address,
connection_timeout,
);

let username = auth_server.handle_authentication().await?;
let (username, client_address) = auth_server.handle_authentication().await?;

info!(
"Connection established: user = {}, client address = {}, remote address = {}",
username,
self.client_address.addr(),
client_address.addr(),

Check warning on line 57 in src/server/connection.rs

View check run for this annotation

Codecov / codecov/patch

src/server/connection.rs#L57

Added line #L57 was not covered by tests
self.connection.remote_address().ip(),
);

self.username = Some(username);
self.client_address = Some(client_address);

Ok(())
Ok(self)
}

/// Starts the tasks for this instance of Quincy connection.
pub async fn run(self, egress_queue: UnboundedReceiver<Bytes>) -> (Self, Error) {
if self.username.is_none() {
let client_address = self.client_address.addr();
let client_address = self.connection.remote_address();

Check warning on line 70 in src/server/connection.rs

View check run for this annotation

Codecov / codecov/patch

src/server/connection.rs#L70

Added line #L70 was not covered by tests
return (
self,
anyhow!("Client '{client_address}' is not authenticated"),
anyhow!("Client '{}' is not authenticated", client_address.ip()),

Check warning on line 73 in src/server/connection.rs

View check run for this annotation

Codecov / codecov/patch

src/server/connection.rs#L73

Added line #L73 was not covered by tests
);
}

Expand All @@ -87,12 +84,16 @@
outgoing_data_err = outgoing_data_task => outgoing_data_err,
incoming_data_err = incoming_data_task => incoming_data_err,
}
.expect("Joining tasks never fails")
.expect_err("Connection tasks always return an error");
.expect("joining tasks never fails")
.expect_err("connection tasks always return an error");

Check warning on line 88 in src/server/connection.rs

View check run for this annotation

Codecov / codecov/patch

src/server/connection.rs#L87-L88

Added lines #L87 - L88 were not covered by tests

(self, err)
}

/// Processes outgoing data and sends it to the QUIC connection.
///
/// ### Arguments
/// - `egress_queue` - the queue to receive data from the TUN interface
async fn process_outgoing_data(
self: Arc<Self>,
mut egress_queue: UnboundedReceiver<Bytes>,
Expand All @@ -101,14 +102,12 @@
let data = egress_queue
.recv()
.await
.ok_or_else(|| anyhow!("Egress queue has been closed"))?;
.ok_or(anyhow!("Egress queue has been closed"))?;

let max_datagram_size = self.connection.max_datagram_size().ok_or_else(|| {
anyhow!(
"Client {} failed to provide maximum datagram size",
self.connection.remote_address().ip()
)
})?;
let max_datagram_size = self.connection.max_datagram_size().ok_or(anyhow!(
"Client {} failed to provide maximum datagram size",
self.client_address()?.addr()
))?;

Check warning on line 110 in src/server/connection.rs

View check run for this annotation

Codecov / codecov/patch

src/server/connection.rs#L110

Added line #L110 was not covered by tests

debug!("Maximum QUIC datagram size: {max_datagram_size}");

Expand All @@ -124,29 +123,39 @@
debug!(
"Sending {} bytes to {:?}",
data.len(),
self.client_address.addr()
self.client_address()?.addr()

Check warning on line 126 in src/server/connection.rs

View check run for this annotation

Codecov / codecov/patch

src/server/connection.rs#L126

Added line #L126 was not covered by tests
);

self.connection.send_datagram(data)?;
}
}

/// Processes incoming data and sends it to the TUN interface queue.
///
/// ### Arguments
/// - `connection` - a reference to the underlying QUIC connection
/// - `tun_queue` - a sender of an unbounded queue used by the tunnel worker to receive data
async fn process_incoming_data(self: Arc<Self>) -> Result<()> {
loop {
let data = self.connection.read_datagram().await?;

debug!(
"Received {} bytes from {:?}",
data.len(),
self.client_address.addr()
self.client_address()?.addr()

Check warning on line 141 in src/server/connection.rs

View check run for this annotation

Codecov / codecov/patch

src/server/connection.rs#L141

Added line #L141 was not covered by tests
);

self.ingress_queue.send(data)?;
}
}

/// Returns the username associated with this connection.
pub fn username(&self) -> Result<&str> {
self.username
.as_deref()
.ok_or(anyhow!("Connection is unauthenticated"))
}

Check warning on line 153 in src/server/connection.rs

View check run for this annotation

Codecov / codecov/patch

src/server/connection.rs#L149-L153

Added lines #L149 - L153 were not covered by tests

/// Returns the client address associated with this connection.
pub fn client_address(&self) -> Result<&IpNet> {
self.client_address
.as_ref()
.ok_or(anyhow!("Connection is unauthenticated"))
}
}
45 changes: 27 additions & 18 deletions src/server/tunnel.rs
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@ use crate::config::{ConnectionConfig, TunnelConfig};
use crate::server::address_pool::AddressPool;
use crate::server::connection::QuincyConnection;
use crate::utils::socket::bind_socket;
use anyhow::{anyhow, Result};
use anyhow::Result;
use bytes::Bytes;
use dashmap::DashMap;
use etherparse::{IpHeader, PacketHeaders};
Expand Down Expand Up @@ -117,46 +117,55 @@ impl QuincyTunnel {
endpoint: Endpoint,
) -> Result<()> {
info!(
"Listening for incoming connections: {}",
"Starting connection handler: {}",
endpoint.local_addr().expect("Endpoint has a local address")
);

let mut authentication_tasks = FuturesUnordered::new();
let mut connection_tasks = FuturesUnordered::new();

loop {
tokio::select! {
// New connections
Some(handshake) = endpoint.accept() => {
debug!(
"Received incoming connection from '{}'",
handshake.remote_address().ip()
);

let client_tun_ip = self.address_pool
.next_available_address()
.ok_or_else(|| anyhow!("Could not find an available address for client"))?;
let quic_connection = handshake.await?;

let quic_connection = Arc::new(handshake.await?);
let (connection_sender, connection_receiver) = mpsc::unbounded_channel();

let mut connection = QuincyConnection::new(
quic_connection.clone(),
client_tun_ip,
let connection = QuincyConnection::new(
quic_connection,
ingress_queue.clone(),
);

if let Err(e) = connection.authenticate(&self.user_database, self.connection_config.timeout).await {
warn!("Failed to authenticate client {client_tun_ip}: {e}");
self.address_pool.release_address(&client_tun_ip.addr());
continue;
}
authentication_tasks.push(
connection.authenticate(&self.user_database, &self.address_pool, self.connection_config.timeout)
);
}

// Authentication tasks
Some(connection) = authentication_tasks.next() => {
let connection = match connection {
Ok(connection) => connection,
Err(e) => {
warn!("Failed to authenticate client: {e}");
continue;
}
};

let client_address = connection.client_address()?.addr();
let (connection_sender, connection_receiver) = mpsc::unbounded_channel();

connection_tasks.push(tokio::spawn(connection.run(connection_receiver)));
self.connection_queues.insert(client_tun_ip.addr(), connection_sender);
self.connection_queues.insert(client_address, connection_sender);
}

// Connection tasks
Some(connection) = connection_tasks.next() => {
let (connection, err) = connection?;
let client_address = &connection.client_address.addr();
let client_address = &connection.client_address()?.addr();

self.connection_queues.remove(client_address);
self.address_pool.release_address(client_address);
Expand Down
Loading