Skip to content

Commit

Permalink
wip: dynamic tls cert resolver
Browse files Browse the repository at this point in the history
  • Loading branch information
SergioBenitez committed Apr 12, 2024
1 parent 60f3cd5 commit bf60fe6
Show file tree
Hide file tree
Showing 20 changed files with 643 additions and 300 deletions.
1 change: 1 addition & 0 deletions core/lib/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -142,3 +142,4 @@ version_check = "0.9.1"
tokio = { version = "1", features = ["macros", "io-std"] }
figment = { version = "0.10", features = ["test"] }
pretty_assertions = "1"
arc-swap = "1.7"
3 changes: 0 additions & 3 deletions core/lib/src/config/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -137,9 +137,6 @@ mod secret_key;
#[cfg(unix)]
pub use crate::shutdown::Sig;

#[cfg(unix)]
pub use crate::listener::unix::UdsConfig;

#[cfg(feature = "secrets")]
pub use secret_key::SecretKey;

Expand Down
10 changes: 7 additions & 3 deletions core/lib/src/error.rs
Original file line number Diff line number Diff line change
Expand Up @@ -178,9 +178,13 @@ impl Error {
self.mark_handled();
match self.kind() {
ErrorKind::Bind(ref a, ref e) => {
match a {
Some(a) => error!("Binding to {} failed.", a.primary().underline()),
None => error!("Binding to network interface failed."),
if let Some(e) = e.downcast_ref::<Self>() {
e.pretty_print();
} else {
match a {
Some(a) => error!("Binding to {} failed.", a.primary().underline()),
None => error!("Binding to network interface failed."),
}
}

info_!("{}", e);
Expand Down
51 changes: 5 additions & 46 deletions core/lib/src/listener/bindable.rs
Original file line number Diff line number Diff line change
@@ -1,52 +1,11 @@
use std::io;
use futures::TryFutureExt;
use tokio::io::{AsyncRead, AsyncWrite};

use crate::listener::{Listener, Endpoint};

pub trait Bindable: Sized {
type Listener: Listener + 'static;
use crate::listener::{Endpoint, Listener};

pub trait Bind<T>: Listener<Connection: AsyncRead + AsyncWrite> + 'static {
type Error: std::error::Error + Send + 'static;

async fn bind(self) -> Result<Self::Listener, Self::Error>;

/// The endpoint that `self` binds on.
fn bind_endpoint(&self) -> io::Result<Endpoint>;
}

impl<L: Listener + 'static> Bindable for L {
type Listener = L;

type Error = std::convert::Infallible;

async fn bind(self) -> Result<Self::Listener, Self::Error> {
Ok(self)
}

fn bind_endpoint(&self) -> io::Result<Endpoint> {
L::endpoint(self)
}
}

impl<A: Bindable, B: Bindable> Bindable for either::Either<A, B> {
type Listener = tokio_util::either::Either<A::Listener, B::Listener>;

type Error = either::Either<A::Error, B::Error>;

async fn bind(self) -> Result<Self::Listener, Self::Error> {
match self {
either::Either::Left(a) => a.bind()
.map_ok(tokio_util::either::Either::Left)
.map_err(either::Either::Left)
.await,
either::Either::Right(b) => b.bind()
.map_ok(tokio_util::either::Either::Right)
.map_err(either::Either::Right)
.await,
}
}
async fn bind(to: T) -> Result<Self, Self::Error>;

fn bind_endpoint(&self) -> io::Result<Endpoint> {
either::for_both!(self, a => a.bind_endpoint())
}
fn bind_endpoint(to: &T) -> Option<Endpoint>;
}
119 changes: 71 additions & 48 deletions core/lib/src/listener/default.rs
Original file line number Diff line number Diff line change
@@ -1,64 +1,87 @@
use either::Either;
use serde::Deserialize;
use tokio_util::either::{Either, Either::{Left, Right}};
use futures::TryFutureExt;

use crate::listener::{Bindable, Endpoint};
use crate::error::{Error, ErrorKind};
use crate::error::ErrorKind;
use crate::{Ignite, Rocket};
use crate::listener::{Bind, Endpoint, tcp::TcpListener};

#[derive(serde::Deserialize)]
pub struct DefaultListener {
#[cfg(unix)] use crate::listener::unix::UnixListener;
#[cfg(feature = "tls")] use crate::tls::{TlsListener, TlsConfig};

mod private {
use super::{Either, TcpListener};

#[cfg(feature = "tls")] pub type TlsListener<T> = super::TlsListener<T>;
#[cfg(not(feature = "tls"))] pub type TlsListener<T> = T;
#[cfg(unix)] pub type UnixListener = super::UnixListener;
#[cfg(not(unix))] pub type UnixListener = super::TcpListener;

pub type Listener = Either<
Either<TlsListener<TcpListener>, TlsListener<UnixListener>>,
Either<TcpListener, UnixListener>,
>;
}

#[derive(Deserialize)]
struct Config {
#[serde(default)]
pub address: Endpoint,
pub port: Option<u16>,
pub reuse: Option<bool>,
address: Endpoint,
#[cfg(feature = "tls")]
pub tls: Option<crate::tls::TlsConfig>,
tls: Option<TlsConfig>,
}

#[cfg(not(unix))] type BaseBindable = Either<std::net::SocketAddr, std::net::SocketAddr>;
#[cfg(unix)] type BaseBindable = Either<std::net::SocketAddr, super::unix::UdsConfig>;
pub type DefaultListener = private::Listener;

#[cfg(not(feature = "tls"))] type TlsBindable<T> = Either<T, T>;
#[cfg(feature = "tls")] type TlsBindable<T> = Either<super::tls::TlsBindable<T>, T>;
impl<'r> Bind<&'r Rocket<Ignite>> for DefaultListener {
type Error = crate::Error;

impl DefaultListener {
pub(crate) fn base_bindable(&self) -> Result<BaseBindable, crate::Error> {
match &self.address {
Endpoint::Tcp(mut address) => {
if let Some(port) = self.port {
address.set_port(port);
}
async fn bind(rocket: &'r Rocket<Ignite>) -> Result<Self, Self::Error> {
let config: Config = rocket.figment().extract()?;

Ok(BaseBindable::Left(address))
},
#[cfg(unix)]
Endpoint::Unix(path) => {
let uds = super::unix::UdsConfig { path: path.clone(), reuse: self.reuse, };
Ok(BaseBindable::Right(uds))
},
#[cfg(not(unix))]
e@Endpoint::Unix(_) => {
let msg = "Unix domain sockets unavailable on non-unix platforms.";
let boxed = Box::<dyn std::error::Error + Send + Sync>::from(msg);
Err(Error::new(ErrorKind::Bind(Some(e.clone()), boxed)))
},
other => {
let msg = format!("unsupported default listener address: {other}");
let boxed = Box::<dyn std::error::Error + Send + Sync>::from(msg);
Err(Error::new(ErrorKind::Bind(Some(other.clone()), boxed)))
let endpoint = config.address;
match endpoint {
#[cfg(feature = "tls")]
Endpoint::Tcp(_) if config.tls.is_some() => {
let listener = <TlsListener<TcpListener> as Bind<_>>::bind(rocket)
.map_err(|e| ErrorKind::Bind(Some(endpoint), Box::new(e)))
.await?;

Ok(Left(Left(listener)))
}
}
}
Endpoint::Tcp(_) => {
let listener = <TcpListener as Bind<_>>::bind(rocket)
.map_err(|e| ErrorKind::Bind(Some(endpoint), Box::new(e)))
.await?;

pub(crate) fn tls_bindable<T>(&self, inner: T) -> TlsBindable<T> {
#[cfg(feature = "tls")]
if let Some(tls) = self.tls.clone() {
return TlsBindable::Left(super::tls::TlsBindable { inner, tls });
}
Ok(Right(Left(listener)))
}
#[cfg(all(unix, feature = "tls"))]
Endpoint::Unix(_) if config.tls.is_some() => {
let listener = <TlsListener<UnixListener> as Bind<_>>::bind(rocket)
.map_err(|e| ErrorKind::Bind(Some(endpoint), Box::new(e)))
.await?;

Ok(Left(Right(listener)))
}
#[cfg(unix)]
Endpoint::Unix(_) => {
let listener = <UnixListener as Bind<_>>::bind(rocket)
.map_err(|e| ErrorKind::Bind(Some(endpoint), Box::new(e)))
.await?;

TlsBindable::Right(inner)
Ok(Right(Right(listener)))
}
_ => {
let msg = format!("unsupported bind endpoint: {endpoint}");
let error = Box::<dyn std::error::Error + Send + Sync>::from(msg);
Err(ErrorKind::Bind(Some(endpoint), error).into())
}
}
}

pub fn bindable(&self) -> Result<impl Bindable, crate::Error> {
self.base_bindable()
.map(|b| b.map_either(|b| self.tls_bindable(b), |b| self.tls_bindable(b)))
fn bind_endpoint(rocket: &&'r Rocket<Ignite>) -> Option<Endpoint> {
let endpoint: Option<Endpoint> = rocket.figment().extract_inner("endpoint").ok()?;
endpoint
}
}
2 changes: 1 addition & 1 deletion core/lib/src/listener/listener.rs
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@ use tokio_util::either::Either;

use crate::listener::{Connection, Endpoint};

pub trait Listener: Send + Sync {
pub trait Listener: Sized + Send + Sync {
type Accept: Send;

type Connection: Connection;
Expand Down
3 changes: 0 additions & 3 deletions core/lib/src/listener/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -9,9 +9,6 @@ mod default;
#[cfg(unix)]
#[cfg_attr(nightly, doc(cfg(unix)))]
pub mod unix;
#[cfg(feature = "tls")]
#[cfg_attr(nightly, doc(cfg(feature = "tls")))]
pub mod tls;
pub mod tcp;
#[cfg(feature = "http3-preview")]
pub mod quic;
Expand Down
63 changes: 55 additions & 8 deletions core/lib/src/listener/tcp.rs
Original file line number Diff line number Diff line change
@@ -1,21 +1,68 @@
use std::io;
use std::net::SocketAddr;

use either::Either;
use serde::Deserialize;

#[doc(inline)]
pub use tokio::net::{TcpListener, TcpStream};

use crate::listener::{Listener, Bindable, Connection, Endpoint};
use crate::{Ignite, Rocket};
use crate::listener::{Bind, Connection, Endpoint, Listener};

#[derive(Deserialize)]
pub struct Config {
#[serde(default)]
address: Endpoint,
port: Option<u16>,
}

impl Bind<SocketAddr> for TcpListener {
type Error = std::io::Error;

async fn bind(addr: SocketAddr) -> Result<Self, Self::Error> {
Self::bind(addr).await
}

fn bind_endpoint(addr: &SocketAddr) -> Option<Endpoint> {
Some(Endpoint::Tcp(*addr))
}
}

impl Bind<Config> for TcpListener {
type Error = Either<figment::Error, io::Error>;

async fn bind(config: Config) -> Result<Self, Self::Error> {
let Some(Endpoint::Tcp(addr)) = Self::bind_endpoint(&config) else {
let msg = format!("invalid tcp endpoint: {}", config.address);
let err = figment::Error::from(msg).with_path("address");
return Err(Either::Left(err));
};

Self::bind(addr).await.map_err(Either::Right)
}

fn bind_endpoint(config: &Config) -> Option<Endpoint> {
if let (Some(mut addr), Some(port)) = (config.address.tcp(), config.port) {
addr.set_port(port);
return Some(Endpoint::Tcp(addr));
}

impl Bindable for std::net::SocketAddr {
type Listener = TcpListener;
Some(config.address.clone())
}
}

type Error = io::Error;
impl<'r> Bind<&'r Rocket<Ignite>> for TcpListener {
type Error = Either<figment::Error, io::Error>;

async fn bind(self) -> Result<Self::Listener, Self::Error> {
TcpListener::bind(self).await
async fn bind(rocket: &'r Rocket<Ignite>) -> Result<Self, Self::Error> {
let config: Config = rocket.figment().extract().map_err(Either::Left)?;
<Self as Bind<_>>::bind(config).await
}

fn bind_endpoint(&self) -> io::Result<Endpoint> {
Ok(Endpoint::Tcp(*self))
fn bind_endpoint(rocket: &&'r Rocket<Ignite>) -> Option<Endpoint> {
let config: Config = rocket.figment().extract().ok()?;
<Self as Bind<_>>::bind_endpoint(&config)
}
}

Expand Down
Loading

0 comments on commit bf60fe6

Please sign in to comment.