diff --git a/api/tests/disconnect-body.rs b/api/tests/disconnect-body.rs index 4275963018..c1c994241f 100644 --- a/api/tests/disconnect-body.rs +++ b/api/tests/disconnect-body.rs @@ -100,20 +100,10 @@ async fn establish_server(handler: impl Handler) -> (ServerHandle, impl AsyncWri let handle = trillium_testing::config().with_port(0).spawn(handler); let info = handle.info().await; - let port = info.tcp_socket_addr().map_or_else( - || { - info.listener_description() - .split(":") - .nth(1) - .unwrap() - .parse() - .unwrap() - }, - |x| x.port(), - ); + let url = info.state::().unwrap(); let client = ArcedConnector::new(client_config()) - .connect(&format!("http://localhost:{port}").parse().unwrap()) + .connect(url) .await .unwrap(); (handle, client) diff --git a/async-std/src/runtime.rs b/async-std/src/runtime.rs index f79dee549e..b5872ea135 100644 --- a/async-std/src/runtime.rs +++ b/async-std/src/runtime.rs @@ -30,6 +30,14 @@ impl RuntimeTrait for AsyncStdRuntime { fn block_on(&self, fut: Fut) -> Fut::Output { async_std::task::block_on(fut) } + + #[cfg(unix)] + fn hook_signals( + &self, + signals: impl IntoIterator, + ) -> impl Stream + Send + 'static { + signal_hook_async_std::Signals::new(signals).unwrap() + } } impl AsyncStdRuntime { @@ -81,6 +89,6 @@ impl AsyncStdRuntime { impl From for Runtime { fn from(value: AsyncStdRuntime) -> Self { - Runtime::new(value) + Runtime::from_trait_impl(value) } } diff --git a/async-std/src/server/tcp.rs b/async-std/src/server/tcp.rs index 3117a2c0ca..3acb3530d0 100644 --- a/async-std/src/server/tcp.rs +++ b/async-std/src/server/tcp.rs @@ -1,6 +1,6 @@ use crate::{AsyncStdRuntime, AsyncStdTransport}; use async_std::net::{TcpListener, TcpStream}; -use std::{env, io::Result}; +use std::io::Result; use trillium::Info; use trillium_server_common::Server; @@ -22,24 +22,18 @@ impl Server for AsyncStdServer { type Runtime = AsyncStdRuntime; type Transport = AsyncStdTransport; - const DESCRIPTION: &'static str = concat!( - " (", - env!("CARGO_PKG_NAME"), - " v", - env!("CARGO_PKG_VERSION"), - ")" - ); - async fn accept(&mut self) -> Result { self.0.accept().await.map(|(t, _)| t.into()) } - fn listener_from_tcp(tcp: std::net::TcpListener) -> Self { + fn from_tcp(tcp: std::net::TcpListener) -> Self { Self(tcp.into()) } - fn info(&self) -> Info { - self.0.local_addr().unwrap().into() + fn init(&self, info: &mut Info) { + if let Ok(socket_addr) = self.0.local_addr() { + info.insert_state(socket_addr); + } } fn runtime() -> Self::Runtime { diff --git a/async-std/src/server/unix.rs b/async-std/src/server/unix.rs index e30ef2d1ff..a3d064e0be 100644 --- a/async-std/src/server/unix.rs +++ b/async-std/src/server/unix.rs @@ -2,13 +2,12 @@ use crate::{AsyncStdRuntime, AsyncStdTransport}; use async_std::{ net::{TcpListener, TcpStream}, os::unix::net::{UnixListener, UnixStream}, - stream::StreamExt, }; -use std::{env, io::Result}; +use std::io::Result; use trillium::{log_error, Info}; use trillium_server_common::{ Binding::{self, *}, - Server, Swansong, + Server, }; /// Tcp/Unix Trillium server adapter for Async-Std @@ -41,31 +40,6 @@ impl Server for AsyncStdServer { type Runtime = AsyncStdRuntime; type Transport = Binding, AsyncStdTransport>; - const DESCRIPTION: &'static str = concat!( - " (", - env!("CARGO_PKG_NAME"), - " v", - env!("CARGO_PKG_VERSION"), - ")" - ); - - async fn handle_signals(swansong: Swansong) { - use signal_hook::consts::signal::*; - use signal_hook_async_std::Signals; - - let signals = Signals::new([SIGINT, SIGTERM, SIGQUIT]).unwrap(); - let mut signals = signals.fuse(); - while signals.next().await.is_some() { - if swansong.state().is_shutting_down() { - eprintln!("\nSecond interrupt, shutting down harshly"); - std::process::exit(1); - } else { - println!("\nShutting down gracefully.\nControl-C again to force."); - swansong.shut_down(); - } - } - } - async fn accept(&mut self) -> Result { match &self.0 { Tcp(t) => t @@ -80,18 +54,27 @@ impl Server for AsyncStdServer { } } - fn listener_from_tcp(tcp: std::net::TcpListener) -> Self { + fn from_tcp(tcp: std::net::TcpListener) -> Self { Self(Tcp(tcp.into())) } - fn listener_from_unix(tcp: std::os::unix::net::UnixListener) -> Self { + fn from_unix(tcp: std::os::unix::net::UnixListener) -> Self { Self(Unix(tcp.into())) } - fn info(&self) -> Info { + fn init(&self, info: &mut Info) { match &self.0 { - Tcp(t) => t.local_addr().unwrap().into(), - Unix(u) => u.local_addr().unwrap().into(), + Tcp(t) => { + if let Ok(socket_addr) = t.local_addr() { + info.insert_state(socket_addr); + } + } + + Unix(u) => { + if let Ok(socket_addr) = u.local_addr() { + info.insert_state(socket_addr); + } + } } } diff --git a/aws-lambda/src/lib.rs b/aws-lambda/src/lib.rs index 1f2276e08d..19448c6c03 100644 --- a/aws-lambda/src/lib.rs +++ b/aws-lambda/src/lib.rs @@ -19,7 +19,7 @@ use lamedh_runtime::{Context, Handler as AwsHandler}; use std::{future::Future, pin::Pin, sync::Arc}; use tokio::runtime; use trillium::{Conn, Handler}; -use trillium_http::{Conn as HttpConn, Synthetic}; +use trillium_http::{Conn as HttpConn, ServerConfig, Synthetic}; mod context; pub use context::LambdaConnExt; @@ -32,14 +32,19 @@ mod response; use response::{AlbMultiHeadersResponse, AlbResponse, LambdaResponse}; #[derive(Debug)] -struct HandlerWrapper(Arc); +struct HandlerWrapper(Arc, Arc); impl AwsHandler for HandlerWrapper { type Error = std::io::Error; type Fut = Pin> + Send + 'static>>; fn call(&mut self, request: LambdaRequest, context: Context) -> Self::Fut { - Box::pin(handler_fn(request, context, Arc::clone(&self.0))) + Box::pin(handler_fn( + request, + context, + Arc::clone(&self.0), + Arc::clone(&self.1), + )) } } @@ -52,17 +57,18 @@ async fn handler_fn( request: LambdaRequest, context: Context, handler: Arc, + server_config: Arc, ) -> std::io::Result { match request { LambdaRequest::Alb(request) => { - let mut conn = request.into_conn().await; + let mut conn = request.into_conn().await.with_server_config(server_config); conn.state_mut().insert(LambdaContext::new(context)); let conn = run_handler(conn, handler).await; Ok(LambdaResponse::Alb(AlbResponse::from_conn(conn).await)) } LambdaRequest::AlbMultiHeaders(request) => { - let mut conn = request.into_conn().await; + let mut conn = request.into_conn().await.with_server_config(server_config); conn.state_mut().insert(LambdaContext::new(context)); let conn = run_handler(conn, handler).await; Ok(LambdaResponse::AlbMultiHeaders( @@ -75,9 +81,9 @@ async fn handler_fn( /// /// This function will poll pending until the server shuts down. pub async fn run_async(mut handler: impl Handler) { - let mut info = "aws lambda".into(); + let mut info = ServerConfig::default().into(); handler.init(&mut info).await; - lamedh_runtime::run(HandlerWrapper(Arc::new(handler))) + lamedh_runtime::run(HandlerWrapper(Arc::new(handler), Arc::new(info.into()))) .await .unwrap() } diff --git a/client/src/client.rs b/client/src/client.rs index 36f657df32..49fb7625a8 100644 --- a/client/src/client.rs +++ b/client/src/client.rs @@ -2,7 +2,7 @@ use crate::{Conn, IntoUrl, Pool, USER_AGENT}; use std::{fmt::Debug, sync::Arc, time::Duration}; use trillium_http::{ transport::BoxedTransport, HeaderName, HeaderValues, Headers, KnownHeaderName, Method, - ReceivedBodyState, Version::Http1_1, + ReceivedBodyState, TypeSet, Version::Http1_1, }; use trillium_server_common::{ url::{Origin, Url}, @@ -76,9 +76,9 @@ impl Client { method!(patch, Patch); /// builds a new client from this `Connector` - pub fn new(config: impl Connector) -> Self { + pub fn new(connector: impl Connector) -> Self { Self { - config: ArcedConnector::new(config), + config: ArcedConnector::new(connector), pool: None, base: None, default_headers: Arc::new(default_request_headers()), @@ -167,6 +167,7 @@ impl Client { timeout: self.timeout, http_version: Http1_1, max_head_length: 8 * 1024, + state: TypeSet::new(), } } diff --git a/client/src/conn.rs b/client/src/conn.rs index 33f2d9405c..7406a79307 100644 --- a/client/src/conn.rs +++ b/client/src/conn.rs @@ -1,27 +1,20 @@ -use crate::{pool::PoolEntry, util::encoding, Pool}; +use crate::{util::encoding, Pool}; use encoding_rs::Encoding; -use futures_lite::{future::poll_once, io, AsyncReadExt, AsyncWriteExt}; -use memchr::memmem::Finder; -use size::{Base, Size}; -use std::{ - fmt::{self, Debug, Display, Formatter}, - future::{Future, IntoFuture}, - io::{ErrorKind, Write}, - ops::{Deref, DerefMut}, - pin::Pin, - time::Duration, -}; +use std::{net::SocketAddr, time::Duration}; use trillium_http::{ - transport::{BoxedTransport, Transport}, - Body, Error, HeaderName, HeaderValues, Headers, - KnownHeaderName::{Connection, ContentLength, Expect, Host, TransferEncoding}, - Method, ReceivedBody, ReceivedBodyState, Result, Status, Upgrade, Version, + transport::BoxedTransport, Body, Buffer, HeaderName, HeaderValues, Headers, Method, + ReceivedBody, ReceivedBodyState, Status, TypeSet, Version, }; use trillium_server_common::{ url::{Origin, Url}, - ArcedConnector, Connector, + ArcedConnector, Transport, }; +mod implementation; +mod unexpected_status_error; + +pub use unexpected_status_error::UnexpectedStatusError; + /// A wrapper error for [`trillium_http::Error`] or /// [`serde_json::Error`]. Only available when the `json` crate feature is /// enabled. @@ -30,7 +23,7 @@ use trillium_server_common::{ pub enum ClientSerdeError { /// A [`trillium_http::Error`] #[error(transparent)] - HttpError(#[from] Error), + HttpError(#[from] trillium_http::Error), /// A [`serde_json::Error`] #[error(transparent)] @@ -49,37 +42,19 @@ pub struct Conn { pub(crate) status: Option, pub(crate) request_body: Option, pub(crate) pool: Option>, - pub(crate) buffer: trillium_http::Buffer, + pub(crate) buffer: Buffer, pub(crate) response_body_state: ReceivedBodyState, pub(crate) config: ArcedConnector, pub(crate) headers_finalized: bool, pub(crate) timeout: Option, pub(crate) http_version: Version, pub(crate) max_head_length: usize, + pub(crate) state: TypeSet, } /// default http user-agent header pub const USER_AGENT: &str = concat!("trillium-client/", env!("CARGO_PKG_VERSION")); -impl Debug for Conn { - fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result { - f.debug_struct("Conn") - .field("url", &self.url) - .field("method", &self.method) - .field("request_headers", &self.request_headers) - .field("response_headers", &self.response_headers) - .field("status", &self.status) - .field("request_body", &self.request_body) - .field("pool", &self.pool) - .field("buffer", &String::from_utf8_lossy(&self.buffer)) - .field("response_body_state", &self.response_body_state) - .field("config", &self.config) - .field("http_version", &self.http_version) - .field("max_head_length", &self.max_head_length) - .finish() - } -} - impl Conn { /// borrow the request headers pub fn request_headers(&self) -> &Headers { @@ -316,8 +291,8 @@ impl Conn { /// retrieves the url for this conn. /// ``` /// use trillium_client::Client; - /// use trillium_testing::client_config; - /// let client = Client::from(client_config()); + /// let client = Client::from(trillium_testing::client_config()); + /// /// let conn = client.get("http://localhost:9080"); /// /// let url = conn.url(); //<- @@ -379,7 +354,7 @@ impl Conn { /// Attempt to deserialize the response body. Note that this consumes the body content. #[cfg(feature = "json")] - pub async fn response_json(&mut self) -> std::result::Result + pub async fn response_json(&mut self) -> Result where T: serde::de::DeserializeOwned, { @@ -387,19 +362,6 @@ impl Conn { Ok(serde_json::from_str(&body)?) } - pub(crate) fn response_content_length(&self) -> Option { - if self.status == Some(Status::NoContent) - || self.status == Some(Status::NotModified) - || self.method == Method::Head - { - Some(0) - } else { - self.response_headers - .get_str(ContentLength) - .and_then(|c| c.parse().ok()) - } - } - /// returns the status code for this conn. if the conn has not yet /// been sent, this will be None. /// @@ -442,7 +404,7 @@ impl Conn { /// Ok(()) /// }); /// ``` - pub fn success(self) -> std::result::Result { + pub fn success(self) -> Result { match self.status() { Some(status) if status.is_success() => Ok(self), _ => Err(self.into()), @@ -460,7 +422,7 @@ impl Conn { } /// attempts to retrieve the connected peer address - pub fn peer_addr(&self) -> Option { + pub fn peer_addr(&self) -> Option { self.transport .as_ref() .and_then(|t| t.peer_addr().ok().flatten()) @@ -489,507 +451,29 @@ impl Conn { self.http_version } - // --- everything below here is private --- - - fn finalize_headers(&mut self) -> Result<()> { - if self.headers_finalized { - return Ok(()); - } - - let host = self.url.host_str().ok_or(Error::UnexpectedUriFormat)?; - - self.request_headers.try_insert_with(Host, || { - self.url - .port() - .map_or_else(|| host.to_string(), |port| format!("{host}:{port}")) - }); - - if self.pool.is_none() { - self.request_headers.try_insert(Connection, "close"); - } - - match self.body_len() { - Some(0) => {} - Some(len) => { - self.request_headers.insert(Expect, "100-continue"); - self.request_headers.insert(ContentLength, len.to_string()); - } - None => { - self.request_headers.insert(Expect, "100-continue"); - self.request_headers.insert(TransferEncoding, "chunked"); - } - } - - self.headers_finalized = true; - Ok(()) - } - - fn body_len(&self) -> Option { - if let Some(ref body) = self.request_body { - body.len() - } else { - Some(0) - } - } - - async fn find_pool_candidate(&self, head: &[u8]) -> Result> { - let mut byte = [0]; - if let Some(pool) = &self.pool { - for mut candidate in pool.candidates(&self.url.origin()) { - if poll_once(candidate.read(&mut byte)).await.is_none() - && candidate.write_all(head).await.is_ok() - { - return Ok(Some(candidate)); - } - } - } - Ok(None) - } - - async fn connect_and_send_head(&mut self) -> Result<()> { - if self.transport.is_some() { - return Err(Error::Io(std::io::Error::new( - ErrorKind::AlreadyExists, - "conn already connected", - ))); - } - - let head = self.build_head().await?; - - let transport = match self.find_pool_candidate(&head).await? { - Some(transport) => { - log::debug!("reusing connection to {:?}", transport.peer_addr()?); - transport - } - - None => { - let mut transport = self.config.connect(&self.url).await?; - log::debug!("opened new connection to {:?}", transport.peer_addr()?); - transport.write_all(&head).await?; - transport - } - }; - - self.transport = Some(transport); - Ok(()) - } - - async fn build_head(&mut self) -> Result> { - let mut buf = Vec::with_capacity(128); - let url = &self.url; - let method = self.method; - write!(buf, "{method} ")?; - - if method == Method::Connect { - let host = url.host_str().ok_or(Error::UnexpectedUriFormat)?; - - let port = url - .port_or_known_default() - .ok_or(Error::UnexpectedUriFormat)?; - - write!(buf, "{host}:{port}")?; - } else { - write!(buf, "{}", url.path())?; - if let Some(query) = url.query() { - write!(buf, "?{query}")?; - } - } - - write!(buf, " {}\r\n", self.http_version)?; - - for (name, values) in &self.request_headers { - if !name.is_valid() { - return Err(Error::InvalidHeaderName); - } - - for value in values { - if !value.is_valid() { - return Err(Error::InvalidHeaderValue(name.to_owned())); - } - write!(buf, "{name}: ")?; - buf.extend_from_slice(value.as_ref()); - write!(buf, "\r\n")?; - } - } - - write!(buf, "\r\n")?; - log::trace!( - "{}", - std::str::from_utf8(&buf).unwrap().replace("\r\n", "\r\n> ") - ); - - Ok(buf) - } - - fn transport(&mut self) -> &mut BoxedTransport { - self.transport.as_mut().unwrap() - } - - async fn read_head(&mut self) -> Result { - let Self { - buffer, - transport: Some(transport), - .. - } = self - else { - return Err(Error::Closed); - }; - - let mut len = buffer.len(); - let mut search_start = 0; - let finder = Finder::new(b"\r\n\r\n"); - - if len > 0 { - if let Some(index) = finder.find(buffer) { - return Ok(index + 4); - } - search_start = len.saturating_sub(3); - } - - loop { - buffer.expand(); - let bytes = transport.read(&mut buffer[len..]).await?; - len += bytes; - - let search = finder.find(&buffer[search_start..len]); - - if let Some(index) = search { - buffer.truncate(len); - return Ok(search_start + index + 4); - } - - search_start = len.saturating_sub(3); - - if bytes == 0 { - if len == 0 { - return Err(Error::Closed); - } else { - return Err(Error::InvalidHead); - } - } - - if len >= self.max_head_length { - return Err(Error::HeadersTooLong); - } - } - } - - #[cfg(not(feature = "parse"))] - async fn parse_head(&mut self) -> Result<()> { - const MAX_HEADERS: usize = 128; - use crate::HeaderValue; - use std::str::FromStr; - - let head_offset = self.read_head().await?; - let mut headers = [httparse::EMPTY_HEADER; MAX_HEADERS]; - let mut httparse_res = httparse::Response::new(&mut headers); - let parse_result = - httparse_res - .parse(&self.buffer[..head_offset]) - .map_err(|e| match e { - httparse::Error::HeaderName => Error::InvalidHeaderName, - httparse::Error::HeaderValue => Error::InvalidHeaderValue("unknown".into()), - httparse::Error::Status => Error::InvalidStatus, - httparse::Error::TooManyHeaders => Error::HeadersTooLong, - httparse::Error::Version => Error::InvalidVersion, - _ => Error::InvalidHead, - })?; - - match parse_result { - httparse::Status::Complete(n) if n == head_offset => {} - _ => return Err(Error::InvalidHead), - } - - self.status = httparse_res.code.map(|code| code.try_into().unwrap()); - - for header in httparse_res.headers { - let header_name = HeaderName::from_str(header.name)?; - let header_value = HeaderValue::from(header.value.to_owned()); - self.response_headers.append(header_name, header_value); - } - - self.buffer.ignore_front(head_offset); - - self.validate_response_headers()?; - Ok(()) - } - - #[cfg(feature = "parse")] - async fn parse_head(&mut self) -> Result<()> { - use std::str; - - let head_offset = self.read_head().await?; - - let space = memchr::memchr(b' ', &self.buffer[..head_offset]).ok_or(Error::InvalidHead)?; - self.http_version = str::from_utf8(&self.buffer[..space]) - .map_err(|_| Error::InvalidHead)? - .parse() - .map_err(|_| Error::InvalidHead)?; - self.status = Some(str::from_utf8(&self.buffer[space + 1..space + 4])?.parse()?); - let end_of_first_line = 2 + Finder::new("\r\n") - .find(&self.buffer[..head_offset]) - .ok_or(Error::InvalidHead)?; - - self.response_headers - .extend_parse(&self.buffer[end_of_first_line..head_offset]) - .map_err(|_| Error::InvalidHead)?; - - self.buffer.ignore_front(head_offset); - - self.validate_response_headers()?; - Ok(()) - } - - async fn send_body_and_parse_head(&mut self) -> Result<()> { - if self - .request_headers - .eq_ignore_ascii_case(Expect, "100-continue") - { - log::trace!("Expecting 100-continue"); - self.parse_head().await?; - if self.status == Some(Status::Continue) { - self.status = None; - log::trace!("Received 100-continue, sending request body"); - } else { - self.request_body.take(); - log::trace!( - "Received a status code other than 100-continue, not sending request body" - ); - return Ok(()); - } - } - - self.send_body().await?; - loop { - self.parse_head().await?; - if self.status == Some(Status::Continue) { - self.status = None; - } else { - break; - } - } - - Ok(()) - } - - async fn send_body(&mut self) -> Result<()> { - if let Some(mut body) = self.request_body.take() { - io::copy(&mut body, self.transport()).await?; - } - Ok(()) - } - - fn validate_response_headers(&self) -> Result<()> { - let content_length = self.response_headers.has_header(ContentLength); - - let transfer_encoding_chunked = self - .response_headers - .eq_ignore_ascii_case(TransferEncoding, "chunked"); - - if content_length && transfer_encoding_chunked { - Err(Error::UnexpectedHeader(ContentLength.into())) - } else { - Ok(()) - } - } - - fn is_keep_alive(&self) -> bool { - self.response_headers - .eq_ignore_ascii_case(Connection, "keep-alive") - } - - async fn finish_reading_body(&mut self) { - if self.response_body_state != ReceivedBodyState::End { - let body = self.response_body(); - match body.drain().await { - Ok(drain) => log::debug!( - "drained {}", - Size::from_bytes(drain).format().with_base(Base::Base10) - ), - Err(e) => log::warn!("failed to drain body, {:?}", e), - } - } - } - - async fn exec(&mut self) -> Result<()> { - self.finalize_headers()?; - self.connect_and_send_head().await?; - self.send_body_and_parse_head().await?; - Ok(()) - } -} - -impl Drop for Conn { - fn drop(&mut self) { - if !self.is_keep_alive() { - return; - } - - let Some(transport) = self.transport.take() else { - return; - }; - let Ok(Some(peer_addr)) = transport.peer_addr() else { - return; - }; - let Some(pool) = self.pool.take() else { return }; - - let origin = self.url.origin(); - - if self.response_body_state == ReceivedBodyState::End { - log::trace!( - "response body has been read to completion, checking transport back into pool for \ - {}", - &peer_addr - ); - pool.insert(origin, PoolEntry::new(transport, None)); - } else { - let content_length = self.response_content_length(); - let buffer = std::mem::take(&mut self.buffer); - let response_body_state = self.response_body_state; - let encoding = encoding(&self.response_headers); - self.config.runtime().spawn(async move { - let mut response_body = ReceivedBody::new( - content_length, - buffer, - transport, - response_body_state, - None, - encoding, - ); - - match io::copy(&mut response_body, io::sink()).await { - Ok(bytes) => { - let transport = response_body.take_transport().unwrap(); - log::trace!( - "read {} bytes in order to recycle conn for {}", - bytes, - &peer_addr - ); - pool.insert(origin, PoolEntry::new(transport, None)); - } - - Err(ioerror) => log::error!("unable to recycle conn due to {}", ioerror), - }; - }); - } - } -} - -impl From for Body { - fn from(conn: Conn) -> Body { - let received_body: ReceivedBody<'static, _> = conn.into(); - received_body.into() - } -} - -impl From for ReceivedBody<'static, BoxedTransport> { - fn from(mut conn: Conn) -> Self { - let _ = conn.finalize_headers(); - let origin = conn.url.origin(); - - let on_completion = - conn.pool - .take() - .map(|pool| -> Box { - Box::new(move |transport| { - pool.insert(origin.clone(), PoolEntry::new(transport, None)); - }) - }); - - ReceivedBody::new( - conn.response_content_length(), - std::mem::take(&mut conn.buffer), - conn.transport.take().unwrap(), - conn.response_body_state, - on_completion, - conn.response_encoding(), - ) - } -} - -impl From for Upgrade { - fn from(mut conn: Conn) -> Self { - Upgrade::new( - std::mem::take(&mut conn.request_headers), - conn.url.path().to_string(), - conn.method, - conn.transport.take().unwrap(), - std::mem::take(&mut conn.buffer), - ) - } -} - -impl IntoFuture for Conn { - type IntoFuture = Pin + Send + 'static>>; - type Output = Result; - - fn into_future(mut self) -> Self::IntoFuture { - Box::pin(async move { - if let Some(duration) = self.timeout { - self.config - .runtime() - .timeout(duration, self.exec()) - .await - .ok_or(Error::TimedOut("Conn", duration))??; - } else { - self.exec().await?; - } - Ok(self) - }) + /// add state to the client conn and return self + pub fn with_state(mut self, state: T) -> Self { + self.insert_state(state); + self } -} - -impl<'conn> IntoFuture for &'conn mut Conn { - type IntoFuture = Pin + Send + 'conn>>; - type Output = Result<()>; - fn into_future(self) -> Self::IntoFuture { - Box::pin(async move { - self.exec().await?; - Ok(()) - }) + /// add state to the client conn, returning any previously set state of this type + pub fn insert_state(&mut self, state: T) -> Option { + self.state.insert(state) } -} -/// An unexpected http status code was received. Transform this back -/// into the conn with [`From::from`]/[`Into::into`]. -/// -/// Currently only returned by [`Conn::success`] -#[derive(Debug)] -pub struct UnexpectedStatusError(Box); -impl From for UnexpectedStatusError { - fn from(value: Conn) -> Self { - Self(Box::new(value)) + /// borrow state + pub fn state(&self) -> Option<&T> { + self.state.get() } -} -impl From for Conn { - fn from(value: UnexpectedStatusError) -> Self { - *value.0 + /// borrow state mutably + pub fn state_mut(&mut self) -> Option<&mut T> { + self.state.get_mut() } -} - -impl Deref for UnexpectedStatusError { - type Target = Conn; - fn deref(&self) -> &Self::Target { - &self.0 - } -} -impl DerefMut for UnexpectedStatusError { - fn deref_mut(&mut self) -> &mut Self::Target { - &mut self.0 - } -} - -impl std::error::Error for UnexpectedStatusError {} -impl Display for UnexpectedStatusError { - fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result { - match self.status() { - Some(status) => f.write_fmt(format_args!( - "expected a success (2xx) status code, but got {status}" - )), - None => f.write_str("expected a status code to be set, but none was"), - } + /// take state + pub fn take_state(&mut self) -> Option { + self.state.take() } } diff --git a/client/src/conn/implementation.rs b/client/src/conn/implementation.rs new file mode 100644 index 0000000000..f2294d83f2 --- /dev/null +++ b/client/src/conn/implementation.rs @@ -0,0 +1,518 @@ +use super::Conn; +use crate::{pool::PoolEntry, util::encoding}; +use futures_lite::{future::poll_once, io, AsyncReadExt, AsyncWriteExt}; +use memchr::memmem::Finder; +use size::{Base, Size}; +use std::{ + fmt::{self, Debug, Formatter}, + future::{Future, IntoFuture}, + io::{ErrorKind, Write}, + pin::Pin, +}; +use trillium_http::{ + transport::BoxedTransport, + Body, Error, + KnownHeaderName::{Connection, ContentLength, Expect, Host, TransferEncoding}, + Method, ReceivedBody, ReceivedBodyState, Result, Status, TypeSet, Upgrade, +}; +use trillium_server_common::{Connector, Transport}; + +impl Conn { + fn finalize_headers(&mut self) -> Result<()> { + if self.headers_finalized { + return Ok(()); + } + + let host = self.url.host_str().ok_or(Error::UnexpectedUriFormat)?; + + self.request_headers.try_insert_with(Host, || { + self.url + .port() + .map_or_else(|| host.to_string(), |port| format!("{host}:{port}")) + }); + + if self.pool.is_none() { + self.request_headers.try_insert(Connection, "close"); + } + + match self.body_len() { + Some(0) => {} + Some(len) => { + self.request_headers.insert(Expect, "100-continue"); + self.request_headers.insert(ContentLength, len); + } + None => { + self.request_headers.insert(Expect, "100-continue"); + self.request_headers.insert(TransferEncoding, "chunked"); + } + } + + self.headers_finalized = true; + Ok(()) + } + + fn body_len(&self) -> Option { + if let Some(ref body) = self.request_body { + body.len() + } else { + Some(0) + } + } + + async fn find_pool_candidate(&self, head: &[u8]) -> Result> { + let mut byte = [0]; + if let Some(pool) = &self.pool { + for mut candidate in pool.candidates(&self.url.origin()) { + if poll_once(candidate.read(&mut byte)).await.is_none() + && candidate.write_all(head).await.is_ok() + { + return Ok(Some(candidate)); + } + } + } + Ok(None) + } + + async fn connect_and_send_head(&mut self) -> Result<()> { + if self.transport.is_some() { + return Err(Error::Io(std::io::Error::new( + ErrorKind::AlreadyExists, + "conn already connected", + ))); + } + + let head = self.build_head().await?; + + let transport = match self.find_pool_candidate(&head).await? { + Some(transport) => { + log::debug!("reusing connection to {:?}", transport.peer_addr()?); + transport + } + + None => { + let mut transport = self.config.connect(&self.url).await?; + log::debug!("opened new connection to {:?}", transport.peer_addr()?); + transport.write_all(&head).await?; + transport + } + }; + + self.transport = Some(transport); + Ok(()) + } + + async fn build_head(&mut self) -> Result> { + let mut buf = Vec::with_capacity(128); + let url = &self.url; + let method = self.method; + write!(buf, "{method} ")?; + + if method == Method::Connect { + let host = url.host_str().ok_or(Error::UnexpectedUriFormat)?; + + let port = url + .port_or_known_default() + .ok_or(Error::UnexpectedUriFormat)?; + + write!(buf, "{host}:{port}")?; + } else { + write!(buf, "{}", url.path())?; + if let Some(query) = url.query() { + write!(buf, "?{query}")?; + } + } + + write!(buf, " HTTP/1.1\r\n")?; + + for (name, values) in &self.request_headers { + if !name.is_valid() { + return Err(Error::InvalidHeaderName); + } + + for value in values { + if !value.is_valid() { + return Err(Error::InvalidHeaderValue(name.to_owned())); + } + write!(buf, "{name}: ")?; + buf.extend_from_slice(value.as_ref()); + write!(buf, "\r\n")?; + } + } + + write!(buf, "\r\n")?; + log::trace!( + "{}", + std::str::from_utf8(&buf).unwrap().replace("\r\n", "\r\n> ") + ); + + Ok(buf) + } + + fn transport(&mut self) -> &mut BoxedTransport { + self.transport.as_mut().unwrap() + } + + async fn read_head(&mut self) -> Result { + let Self { + buffer, + transport: Some(transport), + .. + } = self + else { + return Err(Error::Closed); + }; + + let mut len = buffer.len(); + let mut search_start = 0; + let finder = Finder::new(b"\r\n\r\n"); + + if len > 0 { + if let Some(index) = finder.find(buffer) { + return Ok(index + 4); + } + search_start = len.saturating_sub(3); + } + + loop { + buffer.expand(); + let bytes = transport.read(&mut buffer[len..]).await?; + len += bytes; + + let search = finder.find(&buffer[search_start..len]); + + if let Some(index) = search { + buffer.truncate(len); + return Ok(search_start + index + 4); + } + + search_start = len.saturating_sub(3); + + if bytes == 0 { + if len == 0 { + return Err(Error::Closed); + } else { + return Err(Error::InvalidHead); + } + } + + if len >= self.max_head_length { + return Err(Error::HeadersTooLong); + } + } + } + + #[cfg(not(feature = "parse"))] + async fn parse_head(&mut self) -> Result<()> { + const MAX_HEADERS: usize = 128; + use crate::{HeaderName, HeaderValue}; + use std::str::FromStr; + + let head_offset = self.read_head().await?; + let mut headers = [httparse::EMPTY_HEADER; MAX_HEADERS]; + let mut httparse_res = httparse::Response::new(&mut headers); + let parse_result = + httparse_res + .parse(&self.buffer[..head_offset]) + .map_err(|e| match e { + httparse::Error::HeaderName => Error::InvalidHeaderName, + httparse::Error::HeaderValue => Error::InvalidHeaderValue("unknown".into()), + httparse::Error::Status => Error::InvalidStatus, + httparse::Error::TooManyHeaders => Error::HeadersTooLong, + httparse::Error::Version => Error::InvalidVersion, + _ => Error::InvalidHead, + })?; + + match parse_result { + httparse::Status::Complete(n) if n == head_offset => {} + _ => return Err(Error::InvalidHead), + } + + self.status = httparse_res.code.map(|code| code.try_into().unwrap()); + + for header in httparse_res.headers { + let header_name = HeaderName::from_str(header.name)?; + let header_value = HeaderValue::from(header.value.to_owned()); + self.response_headers.append(header_name, header_value); + } + + self.buffer.ignore_front(head_offset); + + self.validate_response_headers()?; + Ok(()) + } + + #[cfg(feature = "parse")] + async fn parse_head(&mut self) -> Result<()> { + use std::str; + + let head_offset = self.read_head().await?; + + let space = memchr::memchr(b' ', &self.buffer[..head_offset]).ok_or(Error::InvalidHead)?; + self.http_version = str::from_utf8(&self.buffer[..space]) + .map_err(|_| Error::InvalidHead)? + .parse() + .map_err(|_| Error::InvalidHead)?; + self.status = Some(str::from_utf8(&self.buffer[space + 1..space + 4])?.parse()?); + let end_of_first_line = 2 + Finder::new("\r\n") + .find(&self.buffer[..head_offset]) + .ok_or(Error::InvalidHead)?; + + self.response_headers + .extend_parse(&self.buffer[end_of_first_line..head_offset]) + .map_err(|_| Error::InvalidHead)?; + + self.buffer.ignore_front(head_offset); + + self.validate_response_headers()?; + Ok(()) + } + + async fn send_body_and_parse_head(&mut self) -> Result<()> { + if self + .request_headers + .eq_ignore_ascii_case(Expect, "100-continue") + { + log::trace!("Expecting 100-continue"); + self.parse_head().await?; + if self.status == Some(Status::Continue) { + self.status = None; + log::trace!("Received 100-continue, sending request body"); + } else { + self.request_body.take(); + log::trace!( + "Received a status code other than 100-continue, not sending request body" + ); + return Ok(()); + } + } + + self.send_body().await?; + loop { + self.parse_head().await?; + if self.status == Some(Status::Continue) { + self.status = None; + } else { + break; + } + } + + Ok(()) + } + + async fn send_body(&mut self) -> Result<()> { + if let Some(mut body) = self.request_body.take() { + io::copy(&mut body, self.transport()).await?; + } + Ok(()) + } + + fn validate_response_headers(&self) -> Result<()> { + let content_length = self.response_headers.has_header(ContentLength); + + let transfer_encoding_chunked = self + .response_headers + .eq_ignore_ascii_case(TransferEncoding, "chunked"); + + if content_length && transfer_encoding_chunked { + Err(Error::UnexpectedHeader(ContentLength.into())) + } else { + Ok(()) + } + } + + pub(super) fn is_keep_alive(&self) -> bool { + self.response_headers + .eq_ignore_ascii_case(Connection, "keep-alive") + } + + pub(super) async fn finish_reading_body(&mut self) { + if self.response_body_state != ReceivedBodyState::End { + let body = self.response_body(); + match body.drain().await { + Ok(drain) => log::debug!( + "drained {}", + Size::from_bytes(drain).format().with_base(Base::Base10) + ), + Err(e) => log::warn!("failed to drain body, {:?}", e), + } + } + } + + async fn exec(&mut self) -> Result<()> { + self.finalize_headers()?; + self.connect_and_send_head().await?; + self.send_body_and_parse_head().await?; + Ok(()) + } + + pub(super) fn response_content_length(&self) -> Option { + if self.status == Some(Status::NoContent) + || self.status == Some(Status::NotModified) + || self.method == Method::Head + { + Some(0) + } else { + self.response_headers + .get_str(ContentLength) + .and_then(|c| c.parse().ok()) + } + } +} + +impl Drop for Conn { + fn drop(&mut self) { + if !self.is_keep_alive() { + return; + } + + let Some(transport) = self.transport.take() else { + return; + }; + let Ok(Some(peer_addr)) = transport.peer_addr() else { + return; + }; + let Some(pool) = self.pool.take() else { return }; + + let origin = self.url.origin(); + + if self.response_body_state == ReceivedBodyState::End { + log::trace!( + "response body has been read to completion, checking transport back into pool for \ + {}", + &peer_addr + ); + pool.insert(origin, PoolEntry::new(transport, None)); + } else { + let content_length = self.response_content_length(); + let buffer = std::mem::take(&mut self.buffer); + let response_body_state = self.response_body_state; + let encoding = encoding(&self.response_headers); + self.config.runtime().spawn(async move { + let mut response_body = ReceivedBody::new( + content_length, + buffer, + transport, + response_body_state, + None, + encoding, + ); + + match io::copy(&mut response_body, io::sink()).await { + Ok(bytes) => { + let transport = response_body.take_transport().unwrap(); + log::trace!( + "read {} bytes in order to recycle conn for {}", + bytes, + &peer_addr + ); + pool.insert(origin, PoolEntry::new(transport, None)); + } + + Err(ioerror) => log::error!("unable to recycle conn due to {}", ioerror), + }; + }); + } + } +} + +impl From for Body { + fn from(conn: Conn) -> Body { + let received_body: ReceivedBody<'static, _> = conn.into(); + received_body.into() + } +} + +impl From for ReceivedBody<'static, BoxedTransport> { + fn from(mut conn: Conn) -> Self { + let _ = conn.finalize_headers(); + let origin = conn.url.origin(); + + let on_completion = + conn.pool + .take() + .map(|pool| -> Box { + Box::new(move |transport| { + pool.insert(origin.clone(), PoolEntry::new(transport, None)); + }) + }); + + ReceivedBody::new( + conn.response_content_length(), + std::mem::take(&mut conn.buffer), + conn.transport.take().unwrap(), + conn.response_body_state, + on_completion, + conn.response_encoding(), + ) + } +} + +impl From for Upgrade { + fn from(mut conn: Conn) -> Self { + Upgrade::new( + std::mem::take(&mut conn.request_headers), + conn.url.path().to_string(), + conn.method, + conn.transport.take().unwrap(), + std::mem::take(&mut conn.buffer), + ) + } +} + +impl IntoFuture for Conn { + type IntoFuture = Pin + Send + 'static>>; + type Output = Result; + + fn into_future(mut self) -> Self::IntoFuture { + Box::pin(async move { (&mut self).await.map(|()| self) }) + } +} + +impl<'conn> IntoFuture for &'conn mut Conn { + type IntoFuture = Pin + Send + 'conn>>; + type Output = Result<()>; + + fn into_future(self) -> Self::IntoFuture { + Box::pin(async move { + if let Some(duration) = self.timeout { + self.config + .runtime() + .timeout(duration, self.exec()) + .await + .unwrap_or(Err(Error::TimedOut("Conn", duration)))?; + } else { + self.exec().await?; + } + Ok(()) + }) + } +} + +impl Debug for Conn { + fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result { + f.debug_struct("Conn") + .field("url", &self.url) + .field("method", &self.method) + .field("request_headers", &self.request_headers) + .field("response_headers", &self.response_headers) + .field("status", &self.status) + .field("request_body", &self.request_body) + .field("pool", &self.pool) + .field("buffer", &String::from_utf8_lossy(&self.buffer)) + .field("response_body_state", &self.response_body_state) + .field("config", &self.config) + .field("state", &self.state) + .finish() + } +} + +impl AsRef for Conn { + fn as_ref(&self) -> &TypeSet { + &self.state + } +} +impl AsMut for Conn { + fn as_mut(&mut self) -> &mut TypeSet { + &mut self.state + } +} diff --git a/client/src/conn/unexpected_status_error.rs b/client/src/conn/unexpected_status_error.rs new file mode 100644 index 0000000000..07bcc8e57b --- /dev/null +++ b/client/src/conn/unexpected_status_error.rs @@ -0,0 +1,48 @@ +use super::Conn; +use std::{ + error::Error, + fmt::{self, Debug, Display, Formatter}, + ops::{Deref, DerefMut}, +}; +/// An unexpected http status code was received. Transform this back +/// into the conn with [`From::from`]/[`Into::into`]. +/// +/// Currently only returned by [`Conn::success`] +#[derive(Debug)] +pub struct UnexpectedStatusError(Box); +impl From for UnexpectedStatusError { + fn from(value: Conn) -> Self { + Self(Box::new(value)) + } +} + +impl From for Conn { + fn from(value: UnexpectedStatusError) -> Self { + *value.0 + } +} + +impl Deref for UnexpectedStatusError { + type Target = Conn; + + fn deref(&self) -> &Self::Target { + &self.0 + } +} +impl DerefMut for UnexpectedStatusError { + fn deref_mut(&mut self) -> &mut Self::Target { + &mut self.0 + } +} + +impl Error for UnexpectedStatusError {} +impl Display for UnexpectedStatusError { + fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result { + match self.status() { + Some(status) => f.write_fmt(format_args!( + "expected a success (2xx) status code, but got {status}" + )), + None => f.write_str("expected a status code to be set, but none was"), + } + } +} diff --git a/client/tests/timeout.rs b/client/tests/timeout.rs index 6a6880a79c..104a099aad 100644 --- a/client/tests/timeout.rs +++ b/client/tests/timeout.rs @@ -1,10 +1,13 @@ use std::time::Duration; use trillium_client::Client; -use trillium_testing::{client_config, runtime, RuntimeTrait}; +use trillium_testing::{client_config, Runtime}; async fn handler(conn: trillium::Conn) -> trillium::Conn { if conn.path() == "/slow" { - runtime().delay(Duration::from_secs(5)).await; + conn.shared_state::() + .unwrap() + .delay(Duration::from_secs(5)) + .await; } conn.ok("ok") } diff --git a/forwarding/src/lib.rs b/forwarding/src/lib.rs index 216778b790..329e05c570 100644 --- a/forwarding/src/lib.rs +++ b/forwarding/src/lib.rs @@ -52,7 +52,9 @@ where } impl Debug for TrustFn { fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { - f.debug_tuple("TrustPredicate").field(&"..").finish() + f.debug_tuple("TrustPredicate") + .field(&format_args!("..")) + .finish() } } diff --git a/http/examples/conn-example.rs b/http/examples/conn-example.rs index eb7e541d7d..69c093ffe5 100644 --- a/http/examples/conn-example.rs +++ b/http/examples/conn-example.rs @@ -1,50 +1,43 @@ -fn main() -> trillium_http::Result<()> { - use async_net::{TcpListener, TcpStream}; - use futures_lite::StreamExt; - use swansong::Swansong; - use trillium_http::{Conn, Result}; +fn main() { + use smol::{net::TcpListener, stream::StreamExt}; + use std::sync::Arc; + use trillium_http::ServerConfig; smol::block_on(async { - let swansong = Swansong::new(); - - let server_swansong = swansong.clone(); - let server = smol::spawn(async move { - let listener = TcpListener::bind("localhost:8001").await?; - let mut incoming = server_swansong.interrupt(listener.incoming()); - - while let Some(Ok(stream)) = incoming.next().await { - let swansong = server_swansong.clone(); - smol::spawn(async move { - Conn::map(stream, swansong, |mut conn: Conn| async move { + let server_config = Arc::new(ServerConfig::default()); + let listener = TcpListener::bind("localhost:0").await.unwrap(); + let local_addr = listener.local_addr().unwrap(); + println!("listening on http://{local_addr}"); + + let server = smol::spawn({ + let server_config = server_config.clone(); + async move { + let mut incoming = server_config.swansong().interrupt(listener.incoming()); + + while let Some(Ok(stream)) = incoming.next().await { + smol::spawn(server_config.clone().run(stream, |mut conn| async move { conn.set_response_body("hello world"); conn.set_status(200); conn - }) - .await - }) - .detach() + })) + .detach() + } } - - Result::Ok(()) }); - // this example uses the trillium client - // please note that this api is still especially unstable. - // any other http client would work here too use trillium_client::Client; use trillium_smol::ClientConfig; - let client = Client::new(ClientConfig::default()); - let mut client_conn = client.get("http://localhost:8001").await?; + let client = Client::new(ClientConfig::default()).with_base(local_addr); + let mut client_conn = client.get("/").await.unwrap(); assert_eq!(client_conn.status().unwrap(), 200); assert_eq!( - client_conn.response_body().read_string().await?, + client_conn.response_body().read_string().await.unwrap(), "hello world" ); - swansong.shut_down(); // stop the server after one request - server.await?; // wait for the server to shut down + server.await; - Result::Ok(()) + // server_config.shut_down().await; // stop the server after one request }) } diff --git a/http/examples/http.rs b/http/examples/http.rs index 8ceff5cc81..64ffd39027 100644 --- a/http/examples/http.rs +++ b/http/examples/http.rs @@ -1,6 +1,7 @@ use async_net::{TcpListener, TcpStream}; use futures_lite::prelude::*; -use trillium_http::{Conn, Swansong}; +use std::sync::Arc; +use trillium_http::{Conn, ServerConfig}; async fn handler(mut conn: Conn) -> Conn { conn.set_status(200); @@ -12,18 +13,18 @@ pub fn main() { env_logger::init(); smol::block_on(async move { - let swansong = Swansong::new(); + let server_config = Arc::new(ServerConfig::new()); let port = std::env::var("PORT") .unwrap_or("8080".into()) .parse::() .unwrap(); let listener = TcpListener::bind(("0.0.0.0", port)).await.unwrap(); - let mut incoming = swansong.interrupt(listener.incoming()); + let mut incoming = server_config.swansong().interrupt(listener.incoming()); while let Some(Ok(stream)) = incoming.next().await { - let swansong = swansong.clone(); + let server_config = Arc::clone(&server_config); smol::spawn(async move { - match Conn::map(stream, swansong, handler).await { + match server_config.run(stream, handler).await { Ok(Some(_)) => log::info!("upgrade"), Ok(None) => log::info!("closing connection"), Err(e) => log::error!("{:?}", e), diff --git a/http/examples/tokio-http.rs b/http/examples/tokio-http.rs index 7400416025..32c83112e2 100644 --- a/http/examples/tokio-http.rs +++ b/http/examples/tokio-http.rs @@ -1,6 +1,7 @@ use async_compat::Compat; +use std::sync::Arc; use tokio::net::{TcpListener, TcpStream}; -use trillium_http::{Conn, Swansong}; +use trillium_http::{Conn, ServerConfig}; async fn handler(mut conn: Conn>) -> Conn> { let body = conn.request_body().await.read_string().await.unwrap(); @@ -15,14 +16,15 @@ async fn handler(mut conn: Conn>) -> Conn> { #[tokio::main] pub async fn main() { env_logger::init(); - let swansong = Swansong::new(); + let server_config = Arc::new(ServerConfig::new()); + let listener = TcpListener::bind("127.0.0.1:8081").await.unwrap(); loop { match listener.accept().await { Ok((stream, _)) => { - let swansong = swansong.clone(); + let server_config = server_config.clone(); tokio::spawn(async move { - match Conn::map(Compat::new(stream), swansong, handler).await { + match server_config.run(Compat::new(stream), handler).await { Ok(Some(_)) => log::info!("upgrade"), Ok(None) => log::info!("closing connection"), Err(e) => log::error!("{:?}", e), diff --git a/http/examples/unsend.rs b/http/examples/unsend.rs index a753fe4d69..68e503dd46 100644 --- a/http/examples/unsend.rs +++ b/http/examples/unsend.rs @@ -1,7 +1,7 @@ use async_net::{TcpListener, TcpStream}; use futures_lite::prelude::*; -use std::thread; -use trillium_http::{Conn, Swansong}; +use std::{sync::Arc, thread}; +use trillium_http::{Conn, ServerConfig, Swansong}; async fn handler(mut conn: Conn) -> Conn { let rc = std::rc::Rc::new(()); @@ -14,13 +14,15 @@ async fn handler(mut conn: Conn) -> Conn { pub fn main() { env_logger::init(); - let swansong = Swansong::new(); + let server_config = Arc::new(ServerConfig::new()); let (send, receive) = async_channel::unbounded(); let core_ids = core_affinity::get_core_ids().unwrap(); + + let swansong = Swansong::new(); let handles = core_ids .into_iter() .map(|id| { - let swansong = swansong.clone(); + let server_config = server_config.clone(); let receive = receive.clone(); thread::spawn(move || { if !core_affinity::set_for_current(id) { @@ -28,12 +30,11 @@ pub fn main() { } let executor = async_executor::LocalExecutor::new(); - futures_lite::future::block_on(executor.run(async { + async_io::block_on(executor.run(async { while let Ok(transport) = receive.recv().await { - let swansong = swansong.clone(); - + let server_config = server_config.clone(); let future = async move { - match Conn::map(transport, swansong, handler).await { + match server_config.run(transport, handler).await { Ok(_) => {} Err(e) => log::error!("{e}"), } diff --git a/http/src/body.rs b/http/src/body.rs index 2fd42cbcca..98ce660b1a 100644 --- a/http/src/body.rs +++ b/http/src/body.rs @@ -283,7 +283,7 @@ impl Debug for BodyType { .. } => f .debug_struct("BodyType::Streaming") - .field("async_read", &"..") + .field("async_read", &format_args!("..")) .field("len", &len) .field("done", &done) .field("progress", &progress) diff --git a/http/src/conn.rs b/http/src/conn.rs index 40bbda001b..ddc8128b7b 100644 --- a/http/src/conn.rs +++ b/http/src/conn.rs @@ -1,20 +1,17 @@ use crate::{ after_send::{AfterSend, SendStatus}, - copy, - http_config::DEFAULT_CONFIG, liveness::{CancelOnDisconnect, LivenessFut}, received_body::ReceivedBodyState, util::encoding, - Body, BufWriter, Buffer, ConnectionStatus, Error, Headers, HttpConfig, - KnownHeaderName::{Connection, ContentLength, Date, Expect, Host, Server, TransferEncoding}, - Method, ReceivedBody, Result, Status, Swansong, TypeSet, Upgrade, Version, + Body, Buffer, Headers, + KnownHeaderName::{Connection, ContentLength, Date, Host, TransferEncoding}, + Method, ReceivedBody, ServerConfig, Status, Swansong, TypeSet, Version, }; use encoding_rs::Encoding; use futures_lite::{ future, - io::{AsyncRead, AsyncReadExt, AsyncWrite, AsyncWriteExt}, + io::{AsyncRead, AsyncWrite}, }; -use memchr::memmem::Finder; use std::{ fmt::{self, Debug, Formatter}, future::Future, @@ -24,6 +21,7 @@ use std::{ sync::Arc, time::{Instant, SystemTime}, }; +mod implementation; /// Default Server header pub const SERVER: &str = concat!("trillium/", env!("CARGO_PKG_VERSION")); @@ -34,6 +32,7 @@ pub const SERVER: &str = concat!("trillium/", env!("CARGO_PKG_VERSION")); /// the request and the response, and holds the transport over which the /// response will be sent. pub struct Conn { + pub(crate) server_config: Arc, pub(crate) request_headers: Headers, pub(crate) response_headers: Headers, pub(crate) path: String, @@ -46,18 +45,15 @@ pub struct Conn { pub(crate) buffer: Buffer, pub(crate) request_body_state: ReceivedBodyState, pub(crate) secure: bool, - pub(crate) swansong: Swansong, pub(crate) after_send: AfterSend, pub(crate) start_time: Instant, pub(crate) peer_ip: Option, - pub(crate) http_config: HttpConfig, - pub(crate) shared_state: Option>, } impl Debug for Conn { fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result { f.debug_struct("Conn") - .field("http_config", &self.http_config) + .field("server_config", &self.server_config) .field("request_headers", &self.request_headers) .field("response_headers", &self.response_headers) .field("path", &self.path) @@ -65,14 +61,12 @@ impl Debug for Conn { .field("status", &self.status) .field("version", &self.version) .field("state", &self.state) - .field("shared_state", &self.shared_state) .field("response_body", &self.response_body) - .field("transport", &"..") - .field("buffer", &"..") + .field("transport", &format_args!("..")) + .field("buffer", &format_args!("..")) .field("request_body_state", &self.request_body_state) .field("secure", &self.secure) - .field("swansong", &self.swansong) - .field("after_send", &"..") + .field("after_send", &format_args!("..")) .field("start_time", &self.start_time) .field("peer_ip", &self.peer_ip) .finish() @@ -83,143 +77,6 @@ impl Conn where Transport: AsyncRead + AsyncWrite + Unpin + Send + Sync + 'static, { - /// read any number of new `Conn`s from the transport and call the - /// provided handler function until either the connection is closed or - /// an upgrade is requested. A return value of Ok(None) indicates a - /// closed connection, while a return value of Ok(Some(upgrade)) - /// represents an upgrade. - /// - /// Provides a default [`HttpConfig`] - /// - /// See the documentation for [`Conn`] for a full example. - /// - /// # Errors - /// - /// This will return an error variant if: - /// - /// * there is an io error when reading from the underlying transport - /// * headers are too long - /// * we are unable to parse some aspect of the request - /// * the request is an unsupported http version - /// * we cannot make sense of the headers, such as if there is a - /// `content-length` header as well as a `transfer-encoding: chunked` - /// header. - - pub async fn map( - transport: Transport, - swansong: Swansong, - handler: F, - ) -> Result>> - where - F: FnMut(Conn) -> Fut, - Fut: Future>, - { - Self::map_with_config(DEFAULT_CONFIG, transport, swansong, handler).await - } - - /// read any number of new `Conn`s from the transport and call the - /// provided handler function until either the connection is closed or - /// an upgrade is requested. A return value of Ok(None) indicates a - /// closed connection, while a return value of Ok(Some(upgrade)) - /// represents an upgrade. - /// - /// See the documentation for [`Conn`] for a full example. - /// - /// # Errors - /// - /// This will return an error variant if: - /// - /// * there is an io error when reading from the underlying transport - /// * headers are too long - /// * we are unable to parse some aspect of the request - /// * the request is an unsupported http version - /// * we cannot make sense of the headers, such as if there is a - /// `content-length` header as well as a `transfer-encoding: chunked` - /// header. - pub async fn map_with_config( - http_config: HttpConfig, - transport: Transport, - swansong: Swansong, - handler: F, - ) -> Result>> - where - F: FnMut(Conn) -> Fut, - Fut: Future>, - { - Self::map_with_config_and_shared_state(http_config, transport, swansong, None, handler) - .await - } - - /// read any number of new `Conn`s from the transport and call the - /// provided handler function until either the connection is closed or - /// an upgrade is requested. A return value of Ok(None) indicates a - /// closed connection, while a return value of Ok(Some(upgrade)) - /// represents an upgrade. - /// - /// The `shared_state` `Arc` is available provided to all Conns on this transport if - /// provided. - /// - /// See the documentation for [`Conn`] for a full example. - /// - /// # Errors - /// - /// This will return an error variant if: - /// - /// * there is an io error when reading from the underlying transport - /// * headers are too long - /// * we are unable to parse some aspect of the request - /// * the request is an unsupported http version - /// * we cannot make sense of the headers, such as if there is a - /// `content-length` header as well as a `transfer-encoding: chunked` - /// header. - pub async fn map_with_config_and_shared_state( - http_config: HttpConfig, - transport: Transport, - swansong: Swansong, - shared_state: Option>, - mut handler: F, - ) -> Result>> - where - F: FnMut(Conn) -> Fut, - Fut: Future>, - { - let mut conn = Conn::new_internal( - http_config, - transport, - Vec::with_capacity(http_config.request_buffer_initial_len).into(), - swansong, - shared_state, - ) - .await?; - - loop { - conn = match handler(conn).await.send().await? { - ConnectionStatus::Upgrade(upgrade) => return Ok(Some(upgrade)), - ConnectionStatus::Close => return Ok(None), - ConnectionStatus::Conn(next) => next, - } - } - } - - async fn send(mut self) -> Result> { - let mut output_buffer = Vec::with_capacity(self.http_config.response_buffer_len); - self.write_headers(&mut output_buffer)?; - - let mut bufwriter = BufWriter::new_with_buffer(output_buffer, &mut self.transport); - - if self.method != Method::Head - && !matches!(self.status, Some(Status::NotModified | Status::NoContent)) - { - if let Some(body) = self.response_body.take() { - copy(body, &mut bufwriter, self.http_config.copy_loops_per_yield).await?; - } - } - - bufwriter.flush().await?; - self.after_send.call(true.into()); - self.finish().await - } - /// returns a read-only reference to the [state /// typemap](TypeSet) for this conn /// @@ -232,21 +89,13 @@ where /// returns a mutable reference to the [state /// typemap](TypeSet) for this conn - /// - /// stability note: this is not unlikely to be removed at some - /// point, as this may end up being more of a trillium concern - /// than a `trillium_http` concern pub fn state_mut(&mut self) -> &mut TypeSet { &mut self.state } /// Returns the shared state on this conn, if set - /// - /// stability note: this is not unlikely to be removed at some - /// point, as this may end up being more of a trillium concern - /// than a `trillium_http` concern - pub fn shared_state(&self) -> Option<&TypeSet> { - self.shared_state.as_deref() + pub fn shared_state(&self) -> &TypeSet { + &self.server_config.shared_state } /// returns a reference to the request headers @@ -339,21 +188,6 @@ where self.request_headers.insert(Host, host); } - // pub fn url(&self) -> Result { - // let path = self.path(); - // let host = self.host().unwrap_or_else(|| String::from("_")); - // let method = self.method(); - // if path.starts_with("http://") || path.starts_with("https://") { - // Ok(Url::parse(path)?) - // } else if path.starts_with('/') { - // Ok(Url::parse(&format!("http://{}{}", host, path))?) - // } else if method == &Method::Connect { - // Ok(Url::parse(&format!("http://{}/", path))?) - // } else { - // Err(Error::UnexpectedUriFormat) - // } - // } - /// Sets the response body to anything that is [`impl Into`][Body]. /// /// ``` @@ -482,27 +316,6 @@ where future::poll_once(LivenessFut::new(self)).await.is_some() } - fn needs_100_continue(&self) -> bool { - self.request_body_state == ReceivedBodyState::Start - && self.version != Version::Http1_0 - && self - .request_headers - .eq_ignore_ascii_case(Expect, "100-continue") - } - - #[allow(clippy::needless_borrow, clippy::needless_borrows_for_generic_args)] - fn build_request_body(&mut self) -> ReceivedBody<'_, Transport> { - ReceivedBody::new_with_config( - self.request_content_length().ok().flatten(), - &mut self.buffer, - &mut self.transport, - &mut self.request_body_state, - None, - encoding(&self.request_headers), - &self.http_config, - ) - } - /// returns the [`encoding_rs::Encoding`] for this request, as /// determined from the mime-type charset, if available /// @@ -557,190 +370,7 @@ where /// this to gracefully stop long-running futures and streams /// inside of handler functions pub fn swansong(&self) -> Swansong { - self.swansong.clone() - } - - fn validate_headers(request_headers: &Headers) -> Result<()> { - let content_length = request_headers.has_header(ContentLength); - let transfer_encoding_chunked = - request_headers.eq_ignore_ascii_case(TransferEncoding, "chunked"); - - if content_length && transfer_encoding_chunked { - Err(Error::UnexpectedHeader(ContentLength.into())) - } else { - Ok(()) - } - } - - /// # Create a new `Conn` - /// - /// This function creates a new conn from the provided - /// [`Transport`][crate::transport::Transport], as well as any - /// bytes that have already been read from the transport, and a - /// [`Swansong`] instance that will be used to signal graceful - /// shutdown. - /// - /// # Errors - /// - /// This will return an error variant if: - /// - /// * there is an io error when reading from the underlying transport - /// * headers are too long - /// * we are unable to parse some aspect of the request - /// * the request is an unsupported http version - /// * we cannot make sense of the headers, such as if there is a - /// `content-length` header as well as a `transfer-encoding: chunked` - /// header. - pub async fn new(transport: Transport, bytes: Vec, swansong: Swansong) -> Result { - Self::new_internal(DEFAULT_CONFIG, transport, bytes.into(), swansong, None).await - } - - #[cfg(not(feature = "parse"))] - async fn new_internal( - http_config: HttpConfig, - mut transport: Transport, - mut buffer: Buffer, - swansong: Swansong, - shared_state: Option>, - ) -> Result { - use crate::{HeaderName, HeaderValue}; - use httparse::{Request, EMPTY_HEADER}; - use std::str::FromStr; - - let (head_size, start_time) = - Self::head(&mut transport, &mut buffer, &swansong, &http_config).await?; - - let mut headers = vec![EMPTY_HEADER; http_config.max_headers]; - let mut httparse_req = Request::new(&mut headers); - - let status = httparse_req.parse(&buffer[..]).map_err(|e| match e { - httparse::Error::HeaderName => Error::InvalidHeaderName, - httparse::Error::HeaderValue => Error::InvalidHeaderValue("unknown".into()), - httparse::Error::Status => Error::InvalidStatus, - httparse::Error::TooManyHeaders => Error::HeadersTooLong, - httparse::Error::Version => Error::InvalidVersion, - _ => Error::InvalidHead, - })?; - - if status.is_partial() { - return Err(Error::InvalidHead); - } - - let method = match httparse_req.method { - Some(method) => match method.parse() { - Ok(method) => method, - Err(_) => return Err(Error::UnrecognizedMethod(method.to_string())), - }, - None => return Err(Error::MissingMethod), - }; - - let version = match httparse_req.version { - Some(0) => Version::Http1_0, - Some(1) => Version::Http1_1, - _ => return Err(Error::InvalidVersion), - }; - - let mut request_headers = Headers::new(); - for header in httparse_req.headers { - let header_name = HeaderName::from_str(header.name)?; - let header_value = HeaderValue::from(header.value.to_owned()); - request_headers.append(header_name, header_value); - } - - Self::validate_headers(&request_headers)?; - - let path = httparse_req - .path - .ok_or(Error::RequestPathMissing)? - .to_owned(); - log::trace!("received:\n{method} {path} {version}\n{request_headers}"); - - let mut response_headers = Headers::new(); - response_headers.insert(Server, SERVER); - - buffer.ignore_front(head_size); - - Ok(Self { - transport, - request_headers, - method, - version, - path, - buffer, - response_headers, - status: None, - state: TypeSet::new(), - response_body: None, - request_body_state: ReceivedBodyState::Start, - secure: false, - swansong, - after_send: AfterSend::default(), - start_time, - peer_ip: None, - http_config, - shared_state, - }) - } - - #[cfg(feature = "parse")] - async fn new_internal( - http_config: HttpConfig, - mut transport: Transport, - mut buffer: Buffer, - swansong: Swansong, - shared_state: Option>, - ) -> Result { - let (head_size, start_time) = - Self::head(&mut transport, &mut buffer, &swansong, &http_config).await?; - - let first_line_index = Finder::new(b"\r\n") - .find(&buffer) - .ok_or(Error::InvalidHead)?; - - let mut spaces = memchr::memchr_iter(b' ', &buffer[..first_line_index]); - let first_space = spaces.next().ok_or(Error::MissingMethod)?; - let method = Method::parse(&buffer[0..first_space])?; - let second_space = spaces.next().ok_or(Error::RequestPathMissing)?; - let path = str::from_utf8(&buffer[first_space + 1..second_space]) - .map_err(|_| Error::RequestPathMissing)? - .to_string(); - if path.is_empty() { - return Err(Error::InvalidHead); - } - let version = Version::parse(&buffer[second_space + 1..first_line_index])?; - if !matches!(version, Version::Http1_1 | Version::Http1_0) { - return Err(Error::UnsupportedVersion(version)); - } - - let request_headers = Headers::parse(&buffer[first_line_index + 2..head_size])?; - - Self::validate_headers(&request_headers)?; - - let mut response_headers = Headers::new(); - response_headers.insert(Server, SERVER); - - buffer.ignore_front(head_size); - - Ok(Self { - transport, - request_headers, - method, - version, - path, - buffer, - response_headers, - status: None, - state: TypeSet::new(), - response_body: None, - request_body_state: ReceivedBodyState::Start, - secure: false, - swansong, - after_send: AfterSend::default(), - start_time, - peer_ip: None, - http_config, - shared_state, - }) + self.server_config.swansong.clone() } /// predicate function to indicate whether the connection is @@ -786,7 +416,7 @@ where } } - if self.swansong.state().is_shutting_down() { + if self.server_config.swansong.state().is_shutting_down() { self.response_headers.insert(Connection, "close"); } } @@ -812,229 +442,34 @@ where self.start_time } - async fn send_100_continue(&mut self) -> Result<()> { - log::trace!("sending 100-continue"); - Ok(self - .transport - .write_all(b"HTTP/1.1 100 Continue\r\n\r\n") - .await?) - } - - async fn head( - transport: &mut Transport, - buf: &mut Buffer, - swansong: &Swansong, - http_config: &HttpConfig, - ) -> Result<(usize, Instant)> { - let mut len = 0; - let mut start_with_read = buf.is_empty(); - let mut instant = None; - let finder = Finder::new(b"\r\n\r\n"); - loop { - if len >= http_config.head_max_len { - return Err(Error::HeadersTooLong); - } - - let bytes = if start_with_read { - buf.expand(); - if len == 0 { - swansong - .interrupt(transport.read(buf)) - .await - .ok_or(Error::Closed)?? - } else { - transport.read(&mut buf[len..]).await? - } - } else { - start_with_read = true; - buf.len() - }; - - if instant.is_none() { - instant = Some(Instant::now()); - } - - let search_start = len.max(3) - 3; - let search = finder.find(&buf[search_start..]); - - if let Some(index) = search { - buf.truncate(len + bytes); - return Ok((search_start + index + 4, instant.unwrap())); - } - - len += bytes; - - if bytes == 0 { - return if len == 0 { - Err(Error::Closed) - } else { - Err(Error::InvalidHead) - }; - } - } - } - - async fn next(mut self) -> Result { - if !self.needs_100_continue() || self.request_body_state != ReceivedBodyState::Start { - self.build_request_body().drain().await?; - } - Conn::new_internal( - self.http_config, - self.transport, - self.buffer, - self.swansong, - self.shared_state, - ) - .await - } - - fn should_close(&self) -> bool { - let request_connection = self.request_headers.get_lower(Connection); - let response_connection = self.response_headers.get_lower(Connection); - - match ( - request_connection.as_deref(), - response_connection.as_deref(), - ) { - (Some("keep-alive"), Some("keep-alive")) => false, - (Some("close"), _) | (_, Some("close")) => true, - _ => self.version == Version::Http1_0, - } - } - - fn should_upgrade(&self) -> bool { - (self.method() == Method::Connect && self.status == Some(Status::Ok)) - || self.status == Some(Status::SwitchingProtocols) - } - - async fn finish(self) -> Result> { - if self.should_close() { - Ok(ConnectionStatus::Close) - } else if self.should_upgrade() { - Ok(ConnectionStatus::Upgrade(self.into())) - } else { - match self.next().await { - Err(Error::Closed) => { - log::trace!("connection closed by client"); - Ok(ConnectionStatus::Close) - } - Err(e) => Err(e), - Ok(conn) => Ok(ConnectionStatus::Conn(conn)), - } - } - } - - fn request_content_length(&self) -> Result> { - if self - .request_headers - .eq_ignore_ascii_case(TransferEncoding, "chunked") - { - Ok(None) - } else if let Some(cl) = self.request_headers.get_str(ContentLength) { - cl.parse() - .map(Some) - .map_err(|_| Error::InvalidHeaderValue(ContentLength.into())) - } else { - Ok(Some(0)) - } - } - - fn body_len(&self) -> Option { - match self.response_body { - Some(ref body) => body.len(), - None => Some(0), - } - } - - fn write_headers(&mut self, output_buffer: &mut Vec) -> Result<()> { - use std::io::Write; - let status = self.status().unwrap_or(Status::NotFound); - - write!( - output_buffer, - "{} {} {}\r\n", - self.version, - status as u16, - status.canonical_reason() - )?; - - self.finalize_headers(); - - log::trace!( - "sending:\n{} {}\n{}", - self.version, - status, - &self.response_headers - ); - - for (name, values) in &self.response_headers { - if name.is_valid() { - for value in values { - if value.is_valid() { - write!(output_buffer, "{name}: ")?; - output_buffer.extend_from_slice(value.as_ref()); - write!(output_buffer, "\r\n")?; - } else { - log::error!("skipping invalid header value {value:?} for header {name}"); - } - } - } else { - log::error!("skipping invalid header with name {name:?}"); - } - } - - write!(output_buffer, "\r\n")?; - Ok(()) - } - /// applies a mapping function from one transport to another. This /// is particularly useful for boxing the transport. unless you're /// sure this is what you're looking for, you probably don't want /// to be using this - pub fn map_transport( + pub fn map_transport( self, - f: impl Fn(Transport) -> T, - ) -> Conn { - let Conn { - request_headers, - response_headers, - path, - status, - version, - state, - transport, - buffer, - request_body_state, - secure, - method, - response_body, - swansong, - after_send, - start_time, - peer_ip, - http_config, - shared_state, - } = self; - + f: impl Fn(Transport) -> NewTransport, + ) -> Conn + where + NewTransport: AsyncRead + AsyncWrite + Send + Sync + Unpin + 'static, + { Conn { - request_headers, - response_headers, - method, - response_body, - path, - status, - version, - state, - transport: f(transport), - buffer, - request_body_state, - secure, - swansong, - after_send, - start_time, - peer_ip, - http_config, - shared_state, + server_config: self.server_config, + request_headers: self.request_headers, + response_headers: self.response_headers, + method: self.method, + response_body: self.response_body, + path: self.path, + status: self.status, + version: self.version, + state: self.state, + transport: f(self.transport), + buffer: self.buffer, + request_body_state: self.request_body_state, + secure: self.secure, + after_send: self.after_send, + start_time: self.start_time, + peer_ip: self.peer_ip, } } diff --git a/http/src/conn/implementation.rs b/http/src/conn/implementation.rs new file mode 100644 index 0000000000..8ebe482264 --- /dev/null +++ b/http/src/conn/implementation.rs @@ -0,0 +1,404 @@ +use crate::{ + after_send::AfterSend, conn::ReceivedBodyState, copy, util::encoding, BufWriter, Buffer, Conn, + ConnectionStatus, Error, Headers, KnownHeaderName, Method, ReceivedBody, Result, ServerConfig, + Status, TypeSet, Version, SERVER, +}; +use futures_lite::{AsyncRead, AsyncReadExt, AsyncWrite, AsyncWriteExt}; +use memchr::memmem::Finder; +use std::{sync::Arc, time::Instant}; + +impl Conn +where + Transport: AsyncRead + AsyncWrite + Unpin + Send + Sync + 'static, +{ + pub(crate) async fn send(mut self) -> Result> { + let mut output_buffer = + Vec::with_capacity(self.server_config.http_config.response_buffer_len); + self.write_headers(&mut output_buffer)?; + + let mut bufwriter = BufWriter::new_with_buffer(output_buffer, &mut self.transport); + + if self.method != Method::Head + && !matches!(self.status, Some(Status::NotModified | Status::NoContent)) + { + if let Some(body) = self.response_body.take() { + copy( + body, + &mut bufwriter, + self.server_config.http_config.copy_loops_per_yield, + ) + .await?; + } + } + + bufwriter.flush().await?; + self.after_send.call(true.into()); + self.finish().await + } + + pub(super) fn needs_100_continue(&self) -> bool { + self.request_body_state == ReceivedBodyState::Start + && self.version != Version::Http1_0 + && self + .request_headers + .eq_ignore_ascii_case(KnownHeaderName::Expect, "100-continue") + } + + #[allow(clippy::needless_borrow, clippy::needless_borrows_for_generic_args)] + pub(super) fn build_request_body(&mut self) -> ReceivedBody<'_, Transport> { + ReceivedBody::new_with_config( + self.request_content_length().ok().flatten(), + &mut self.buffer, + &mut self.transport, + &mut self.request_body_state, + None, + encoding(&self.request_headers), + &self.server_config.http_config, + ) + } + + fn validate_headers(request_headers: &Headers) -> Result<()> { + let content_length = request_headers.has_header(KnownHeaderName::ContentLength); + let transfer_encoding_chunked = + request_headers.eq_ignore_ascii_case(KnownHeaderName::TransferEncoding, "chunked"); + + if content_length && transfer_encoding_chunked { + Err(Error::UnexpectedHeader( + KnownHeaderName::ContentLength.into(), + )) + } else { + Ok(()) + } + } + + // /// # Create a new `Conn` + // /// + // /// This function creates a new conn from the provided + // /// [`Transport`][crate::transport::Transport], as well as any + // /// bytes that have already been read from the transport, and a + // /// [`Swansong`] instance that will be used to signal graceful + // /// shutdown. + // /// + // /// # Errors + // /// + // /// This will return an error variant if: + // /// + // /// * there is an io error when reading from the underlying transport + // /// * headers are too long + // /// * we are unable to parse some aspect of the request + // /// * the request is an unsupported http version + // /// * we cannot make sense of the headers, such as if there is a + // /// `content-length` header as well as a `transfer-encoding: chunked` + // /// header. + // pub async fn new(transport: Transport, bytes: Vec, swansong: Swansong) -> Result { + // Self::new_internal(DEFAULT_CONFIG, transport, bytes.into(), swansong, None).await + // } + + #[cfg(not(feature = "parse"))] + pub(crate) async fn new_internal( + server_config: Arc, + mut transport: Transport, + mut buffer: Buffer, + ) -> Result { + use crate::{HeaderName, HeaderValue}; + use httparse::{Request, EMPTY_HEADER}; + use std::str::FromStr; + + let (head_size, start_time) = + Self::head(&mut transport, &mut buffer, &server_config).await?; + + let mut headers = vec![EMPTY_HEADER; server_config.http_config.max_headers]; + let mut httparse_req = Request::new(&mut headers); + + let status = httparse_req.parse(&buffer[..]).map_err(|e| match e { + httparse::Error::HeaderName => Error::InvalidHeaderName, + httparse::Error::HeaderValue => Error::InvalidHeaderValue("unknown".into()), + httparse::Error::Status => Error::InvalidStatus, + httparse::Error::TooManyHeaders => Error::HeadersTooLong, + httparse::Error::Version => Error::InvalidVersion, + _ => Error::InvalidHead, + })?; + + if status.is_partial() { + return Err(Error::InvalidHead); + } + + let method = match httparse_req.method { + Some(method) => match method.parse() { + Ok(method) => method, + Err(_) => return Err(Error::UnrecognizedMethod(method.to_string())), + }, + None => return Err(Error::MissingMethod), + }; + + let version = match httparse_req.version { + Some(0) => Version::Http1_0, + Some(1) => Version::Http1_1, + _ => return Err(Error::InvalidVersion), + }; + + let mut request_headers = Headers::new(); + for header in httparse_req.headers { + let header_name = HeaderName::from_str(header.name)?; + let header_value = HeaderValue::from(header.value.to_owned()); + request_headers.append(header_name, header_value); + } + + Self::validate_headers(&request_headers)?; + + let path = httparse_req + .path + .ok_or(Error::RequestPathMissing)? + .to_owned(); + log::trace!("received:\n{method} {path} {version}\n{request_headers}"); + + let mut response_headers = Headers::new(); + response_headers.insert(KnownHeaderName::Server, SERVER); + + buffer.ignore_front(head_size); + + Ok(Self { + transport, + request_headers, + method, + version, + path, + buffer, + response_headers, + status: None, + state: TypeSet::new(), + response_body: None, + request_body_state: ReceivedBodyState::Start, + secure: false, + after_send: AfterSend::default(), + start_time, + peer_ip: None, + server_config, + }) + } + + #[cfg(feature = "parse")] + pub(crate) async fn new_internal( + server_config: Arc, + mut transport: Transport, + mut buffer: Buffer, + ) -> Result { + let (head_size, start_time) = + Self::head(&mut transport, &mut buffer, &server_config).await?; + + let first_line_index = Finder::new(b"\r\n") + .find(&buffer) + .ok_or(Error::InvalidHead)?; + + let mut spaces = memchr::memchr_iter(b' ', &buffer[..first_line_index]); + let first_space = spaces.next().ok_or(Error::MissingMethod)?; + let method = Method::parse(&buffer[0..first_space])?; + let second_space = spaces.next().ok_or(Error::RequestPathMissing)?; + let path = std::str::from_utf8(&buffer[first_space + 1..second_space]) + .map_err(|_| Error::RequestPathMissing)? + .to_string(); + if path.is_empty() { + return Err(Error::InvalidHead); + } + let version = Version::parse(&buffer[second_space + 1..first_line_index])?; + if !matches!(version, Version::Http1_1 | Version::Http1_0) { + return Err(Error::UnsupportedVersion(version)); + } + + let request_headers = Headers::parse(&buffer[first_line_index + 2..head_size])?; + + Self::validate_headers(&request_headers)?; + + let mut response_headers = Headers::new(); + response_headers.insert(KnownHeaderName::Server, SERVER); + + buffer.ignore_front(head_size); + + Ok(Self { + server_config, + transport, + request_headers, + method, + version, + path, + buffer, + response_headers, + status: None, + state: TypeSet::new(), + response_body: None, + request_body_state: ReceivedBodyState::Start, + secure: false, + after_send: AfterSend::default(), + start_time, + peer_ip: None, + }) + } + + pub(super) async fn send_100_continue(&mut self) -> Result<()> { + log::trace!("sending 100-continue"); + Ok(self + .transport + .write_all(b"HTTP/1.1 100 Continue\r\n\r\n") + .await?) + } + + async fn head( + transport: &mut Transport, + buf: &mut Buffer, + server_config: &ServerConfig, + ) -> Result<(usize, Instant)> { + let mut len = 0; + let mut start_with_read = buf.is_empty(); + let mut instant = None; + let finder = Finder::new(b"\r\n\r\n"); + loop { + if len >= server_config.http_config.head_max_len { + return Err(Error::HeadersTooLong); + } + + let bytes = if start_with_read { + buf.expand(); + if len == 0 { + server_config + .swansong + .interrupt(transport.read(buf)) + .await + .ok_or(Error::Closed)?? + } else { + transport.read(&mut buf[len..]).await? + } + } else { + start_with_read = true; + buf.len() + }; + + if instant.is_none() { + instant = Some(Instant::now()); + } + + let search_start = len.max(3) - 3; + let search = finder.find(&buf[search_start..]); + + if let Some(index) = search { + buf.truncate(len + bytes); + return Ok((search_start + index + 4, instant.unwrap())); + } + + len += bytes; + + if bytes == 0 { + return if len == 0 { + Err(Error::Closed) + } else { + Err(Error::InvalidHead) + }; + } + } + } + + async fn next(mut self) -> Result { + if !self.needs_100_continue() || self.request_body_state != ReceivedBodyState::Start { + self.build_request_body().drain().await?; + } + Conn::new_internal(self.server_config, self.transport, self.buffer).await + } + + fn should_close(&self) -> bool { + let request_connection = self.request_headers.get_lower(KnownHeaderName::Connection); + let response_connection = self.response_headers.get_lower(KnownHeaderName::Connection); + + match ( + request_connection.as_deref(), + response_connection.as_deref(), + ) { + (Some("keep-alive"), Some("keep-alive")) => false, + (Some("close"), _) | (_, Some("close")) => true, + _ => self.version == Version::Http1_0, + } + } + + fn should_upgrade(&self) -> bool { + (self.method() == Method::Connect && self.status == Some(Status::Ok)) + || self.status == Some(Status::SwitchingProtocols) + } + + async fn finish(self) -> Result> { + if self.should_close() { + Ok(ConnectionStatus::Close) + } else if self.should_upgrade() { + Ok(ConnectionStatus::Upgrade(self.into())) + } else { + match self.next().await { + Err(Error::Closed) => { + log::trace!("connection closed by client"); + Ok(ConnectionStatus::Close) + } + Err(e) => Err(e), + Ok(conn) => Ok(ConnectionStatus::Conn(conn)), + } + } + } + + fn request_content_length(&self) -> Result> { + if self + .request_headers + .eq_ignore_ascii_case(KnownHeaderName::TransferEncoding, "chunked") + { + Ok(None) + } else if let Some(cl) = self.request_headers.get_str(KnownHeaderName::ContentLength) { + cl.parse() + .map(Some) + .map_err(|_| Error::InvalidHeaderValue(KnownHeaderName::ContentLength.into())) + } else { + Ok(Some(0)) + } + } + + pub(super) fn body_len(&self) -> Option { + match self.response_body { + Some(ref body) => body.len(), + None => Some(0), + } + } + + fn write_headers(&mut self, output_buffer: &mut Vec) -> Result<()> { + use std::io::Write; + let status = self.status().unwrap_or(Status::NotFound); + + write!( + output_buffer, + "{} {} {}\r\n", + self.version, + status as u16, + status.canonical_reason() + )?; + + self.finalize_headers(); + + log::trace!( + "sending:\n{} {}\n{}", + self.version, + status, + &self.response_headers + ); + + for (name, values) in &self.response_headers { + if name.is_valid() { + for value in values { + if value.is_valid() { + write!(output_buffer, "{name}: ")?; + output_buffer.extend_from_slice(value.as_ref()); + write!(output_buffer, "\r\n")?; + } else { + log::error!("skipping invalid header value {value:?} for header {name}"); + } + } + } else { + log::error!("skipping invalid header with name {name:?}"); + } + } + + write!(output_buffer, "\r\n")?; + Ok(()) + } +} diff --git a/http/src/lib.rs b/http/src/lib.rs index a63ddd2d4f..f4a13940b2 100644 --- a/http/src/lib.rs +++ b/http/src/lib.rs @@ -24,57 +24,53 @@ //! usable interface on top of `trillium_http`, at very little cost. //! //! ``` -//! # fn main() -> trillium_http::Result<()> { smol::block_on(async { -//! use async_net::{TcpListener, TcpStream}; -//! use futures_lite::StreamExt; -//! use trillium_http::{Conn, Result, Swansong}; +//! fn main() -> trillium_http::Result<()> { +//! smol::block_on(async { +//! use async_net::TcpListener; +//! use futures_lite::StreamExt; +//! use std::sync::Arc; +//! use trillium_http::ServerConfig; //! -//! let swansong = Swansong::new(); -//! let listener = TcpListener::bind(("localhost", 0)).await?; -//! let port = listener.local_addr()?.port(); +//! let server_config = Arc::new(ServerConfig::default()); +//! let listener = TcpListener::bind(("localhost", 0)).await?; +//! let local_addr = listener.local_addr().unwrap(); +//! let server_handle = smol::spawn({ +//! let server_config = server_config.clone(); +//! async move { +//! let mut incoming = server_config.swansong().interrupt(listener.incoming()); //! -//! let server_swansong = swansong.clone(); -//! let server_handle = smol::spawn(async move { -//! let mut incoming = server_swansong.interrupt(listener.incoming()); +//! while let Some(Ok(stream)) = incoming.next().await { +//! smol::spawn(server_config.clone().run(stream, |mut conn| async move { +//! conn.set_response_body("hello world"); +//! conn.set_status(200); +//! conn +//! })) +//! .detach() +//! } +//! } +//! }); //! -//! while let Some(Ok(stream)) = incoming.next().await { -//! let swansong = server_swansong.clone(); -//! smol::spawn(Conn::map( -//! stream, -//! swansong, -//! |mut conn: Conn| async move { -//! conn.set_response_body("hello world"); -//! conn.set_status(200); -//! conn -//! }, -//! )) -//! .detach() -//! } +//! // this example uses the trillium client +//! // any other http client would work here too +//! let client = trillium_client::Client::new(trillium_smol::ClientConfig::default()) +//! .with_base(local_addr); +//! let mut client_conn = client.get("/").await?; //! -//! Result::Ok(()) -//! }); +//! assert_eq!(client_conn.status().unwrap(), 200); +//! assert_eq!( +//! client_conn.response_headers().get_str("content-length"), +//! Some("11") +//! ); +//! assert_eq!( +//! client_conn.response_body().read_string().await?, +//! "hello world" +//! ); //! -//! // this example uses the trillium client -//! // any other http client would work here too -//! -//! let url = format!("http://localhost:{}/", port); -//! let client = trillium_client::Client::new(trillium_smol::ClientConfig::default()); -//! let mut client_conn = client.get(&*url).await?; -//! -//! assert_eq!(client_conn.status().unwrap(), 200); -//! assert_eq!( -//! client_conn.response_headers().get_str("content-length"), -//! Some("11") -//! ); -//! assert_eq!( -//! client_conn.response_body().read_string().await?, -//! "hello world" -//! ); -//! -//! swansong.shut_down(); // stop the server after one request -//! server_handle.await?; // wait for the server to shut down -//! // -//! # Result::Ok(()) }) } +//! server_config.shut_down().await; // stop the server after one request +//! server_handle.await; // wait for the server to shut down +//! Ok(()) +//! }) +//! } //! ``` mod received_body; @@ -160,3 +156,5 @@ pub use copy::copy; pub(crate) use copy::copy; mod liveness; +mod server_config; +pub use server_config::ServerConfig; diff --git a/http/src/received_body.rs b/http/src/received_body.rs index a647b7d268..ce35b27261 100644 --- a/http/src/received_body.rs +++ b/http/src/received_body.rs @@ -361,7 +361,7 @@ impl<'conn, Transport> Debug for ReceivedBody<'conn, Transport> { f.debug_struct("RequestBody") .field("state", &*self.state) .field("content_length", &self.content_length) - .field("buffer", &"..") + .field("buffer", &format_args!("..")) .field("on_completion", &self.on_completion.is_some()) .finish() } diff --git a/http/src/received_body/chunked.rs b/http/src/received_body/chunked.rs index 55a1662cb1..9a8056e82d 100644 --- a/http/src/received_body/chunked.rs +++ b/http/src/received_body/chunked.rs @@ -171,7 +171,8 @@ mod tests { use crate::{http_config::DEFAULT_CONFIG, Buffer, HttpConfig}; use encoding_rs::UTF_8; use futures_lite::{io::Cursor, AsyncRead, AsyncReadExt}; - use trillium_testing::block_on; + use test_harness::test; + use trillium_testing::harness; #[track_caller] fn assert_decoded( @@ -244,22 +245,20 @@ mod tests { decode_with_config(input, poll_size, &DEFAULT_CONFIG).await } - #[test] - fn test_full_decode() { - block_on(async { - for size in 1..50 { - let input = "5\r\n12345\r\n1\r\na\r\n2\r\nbc\r\n3\r\ndef\r\n0\r\n"; - let output = decode(input.into(), size).await.unwrap(); - assert_eq!(output, "12345abcdef", "size: {size}"); - - let input = "7\r\nMozilla\r\n9\r\nDeveloper\r\n7\r\nNetwork\r\n0\r\n\r\n"; - let output = decode(input.into(), size).await.unwrap(); - assert_eq!(output, "MozillaDeveloperNetwork", "size: {size}"); - - assert!(decode(String::new(), size).await.is_err()); - assert!(decode("fffffffffffffff0\r\n".into(), size).await.is_err()); - } - }); + #[test(harness)] + async fn test_full_decode() { + for size in 1..50 { + let input = "5\r\n12345\r\n1\r\na\r\n2\r\nbc\r\n3\r\ndef\r\n0\r\n"; + let output = decode(input.into(), size).await.unwrap(); + assert_eq!(output, "12345abcdef", "size: {size}"); + + let input = "7\r\nMozilla\r\n9\r\nDeveloper\r\n7\r\nNetwork\r\n0\r\n\r\n"; + let output = decode(input.into(), size).await.unwrap(); + assert_eq!(output, "MozillaDeveloperNetwork", "size: {size}"); + + assert!(decode(String::new(), size).await.is_err()); + assert!(decode("fffffffffffffff0\r\n".into(), size).await.is_err()); + } } async fn build_chunked_body(input: String) -> String { @@ -276,49 +275,46 @@ mod tests { String::from_utf8(output).unwrap() } - #[test] - fn test_read_buffer_short() { - block_on(async { - let input = "test ".repeat(50); - let chunked = build_chunked_body(input.clone()).await; - - for size in 1..10 { - assert_eq!( - &decode(chunked.clone(), size).await.unwrap(), - &input, - "size: {size}" - ); - } - }); + #[test(harness)] + async fn test_read_buffer_short() { + let input = "test ".repeat(50); + let chunked = build_chunked_body(input.clone()).await; + + for size in 1..10 { + assert_eq!( + &decode(chunked.clone(), size).await.unwrap(), + &input, + "size: {size}" + ); + } } - #[test] - fn test_max_len() { - block_on(async { - let input = build_chunked_body("test ".repeat(10)).await; - - for size in 4..10 { - assert!( - decode_with_config( - input.clone(), - size, - &HttpConfig::default().with_received_body_max_len(5) - ) + #[test(harness)] + async fn test_max_len() { + let input = build_chunked_body("test ".repeat(10)).await; + + for size in 4..10 { + assert!( + decode_with_config( + input.clone(), + size, + &HttpConfig::default().with_received_body_max_len(5) + ) + .await + .is_err() + ); + + assert!( + decode_with_config(input.clone(), size, &HttpConfig::default()) .await - .is_err() - ); - - assert!( - decode_with_config(input.clone(), size, &HttpConfig::default()) - .await - .is_ok() - ); - } - }); + .is_ok() + ); + } } #[test] fn test_chunk_start() { + let _ = env_logger::builder().is_test(true).try_init(); assert_decoded((0, "5\r\n12345\r\n"), (Some(0), "12345", "")); assert_decoded((0, "F\r\n1"), (Some(14 + 2), "1", "")); assert_decoded((0, "5\r\n123"), (Some(2 + 2), "123", "")); @@ -343,6 +339,8 @@ mod tests { #[test] fn test_chunk_start_with_ext() { + let _ = env_logger::builder().is_test(true).try_init(); + assert_decoded((0, "5;abcdefg\r\n12345\r\n"), (Some(0), "12345", "")); assert_decoded((0, "F;aaa\taaaaa\taaa aaa\r\n1"), (Some(14 + 2), "1", "")); assert_decoded((0, "5;;;;;;;;;;;;;;;;\r\n123"), (Some(2 + 2), "123", "")); @@ -368,63 +366,61 @@ mod tests { assert_decoded((7, "hello\r\n0;\r\n\r\n"), (None, "hello", "")); } - #[test] - fn read_string_and_read_bytes() { - block_on(async { - let content = build_chunked_body("test ".repeat(100)).await; - assert_eq!( - new_with_config(content.clone(), &DEFAULT_CONFIG) - .read_string() - .await - .unwrap() - .len(), - 500 - ); - - assert_eq!( - new_with_config(content.clone(), &DEFAULT_CONFIG) - .read_bytes() - .await - .unwrap() - .len(), - 500 - ); - - assert!( - new_with_config( - content.clone(), - &DEFAULT_CONFIG.with_received_body_max_len(400) - ) + #[test(harness)] + async fn read_string_and_read_bytes() { + let content = build_chunked_body("test ".repeat(100)).await; + assert_eq!( + new_with_config(content.clone(), &DEFAULT_CONFIG) .read_string() .await - .is_err() - ); + .unwrap() + .len(), + 500 + ); - assert!( - new_with_config( - content.clone(), - &DEFAULT_CONFIG.with_received_body_max_len(400) - ) + assert_eq!( + new_with_config(content.clone(), &DEFAULT_CONFIG) .read_bytes() .await - .is_err() - ); + .unwrap() + .len(), + 500 + ); - assert!( - new_with_config(content.clone(), &DEFAULT_CONFIG) - .with_max_len(400) - .read_bytes() - .await - .is_err() - ); + assert!( + new_with_config( + content.clone(), + &DEFAULT_CONFIG.with_received_body_max_len(400) + ) + .read_string() + .await + .is_err() + ); - assert!( - new_with_config(content.clone(), &DEFAULT_CONFIG) - .with_max_len(400) - .read_string() - .await - .is_err() - ); - }); + assert!( + new_with_config( + content.clone(), + &DEFAULT_CONFIG.with_received_body_max_len(400) + ) + .read_bytes() + .await + .is_err() + ); + + assert!( + new_with_config(content.clone(), &DEFAULT_CONFIG) + .with_max_len(400) + .read_bytes() + .await + .is_err() + ); + + assert!( + new_with_config(content.clone(), &DEFAULT_CONFIG) + .with_max_len(400) + .read_string() + .await + .is_err() + ); } } diff --git a/http/src/server_config.rs b/http/src/server_config.rs new file mode 100644 index 0000000000..52f9e0fc89 --- /dev/null +++ b/http/src/server_config.rs @@ -0,0 +1,118 @@ +use crate::{Conn, ConnectionStatus, HttpConfig, Result, TypeSet, Upgrade}; +use futures_lite::{AsyncRead, AsyncWrite}; +use std::{future::Future, sync::Arc}; +use swansong::{ShutdownCompletion, Swansong}; +/// This struct represents the shared configuration and context for a http server. +/// +/// This currently contains tunable parameters in a [`HttpConfig`], the [`Swansong`] graceful +/// shutdown control interface, and a shared [`TypeSet`] that contains application-specific +/// information about the running server +#[derive(Default, Debug)] +pub struct ServerConfig { + pub(crate) http_config: HttpConfig, + pub(crate) swansong: Swansong, + pub(crate) shared_state: TypeSet, +} +impl AsRef for ServerConfig { + fn as_ref(&self) -> &TypeSet { + &self.shared_state + } +} + +impl AsMut for ServerConfig { + fn as_mut(&mut self) -> &mut TypeSet { + &mut self.shared_state + } +} + +impl AsRef for ServerConfig { + fn as_ref(&self) -> &Swansong { + &self.swansong + } +} + +impl AsRef for ServerConfig { + fn as_ref(&self) -> &HttpConfig { + &self.http_config + } +} + +impl ServerConfig { + /// Modify the [`HttpConfig`] for this server. + pub fn http_config_mut(&mut self) -> &mut HttpConfig { + &mut self.http_config + } + + /// Replace the [`Swansong`] graceful shutdown control interface for this server. + pub fn set_swansong(&mut self, swansong: Swansong) { + self.swansong = swansong; + } + + /// Borrow the [`Swansong`] graceful shutdown control interface for this server. + pub fn swansong(&self) -> &Swansong { + &self.swansong + } + + /// Construct a new `ServerConfig` + pub fn new() -> Self { + Self::default() + } + + /// Borrow the shared state [`TypeSet`] for this server + pub fn shared_state(&self) -> &TypeSet { + &self.shared_state + } + + /// Mutate the shared state [`TypeSet`] for this server. + /// + /// Types added here will be immutably available on all [`Conn`]s handled by this server. + pub fn shared_state_mut(&mut self) -> &mut TypeSet { + &mut self.shared_state + } + + /// Perform HTTP on the provided transport, applying the provided `async Conn -> Conn` handler + /// function for every distinct http request-response. + /// + /// For any given invocation of `ServerConfig::run`, the handler function may run any number of + /// times, depending on whether the connection is reused by the client. + /// + /// This can only be called on an `Arc` because an arc clone is moved into the + /// Conn. + /// + /// # Errors + /// + /// This function will return an [`Error`] if any of the http requests is irrecoverably + /// malformed or otherwise noncompliant. + pub async fn run( + self: Arc, + transport: Transport, + mut handler: Handler, + ) -> Result>> + where + Transport: AsyncRead + AsyncWrite + Unpin + Send + Sync + 'static, + Handler: FnMut(Conn) -> Fut, + Fut: Future>, + { + let _guard = self.swansong.guard(); + let buffer = Vec::with_capacity(self.http_config.request_buffer_initial_len).into(); + + let mut conn = Conn::new_internal(self, transport, buffer).await?; + + loop { + conn = match handler(conn).await.send().await? { + ConnectionStatus::Upgrade(upgrade) => return Ok(Some(upgrade)), + ConnectionStatus::Close => return Ok(None), + ConnectionStatus::Conn(next) => next, + } + } + } + + /// Attempt graceful shutdown of this server. + /// + /// The returned [`ShutdownCompletion`] type can + /// either be awaited in an async context or blocked on with [`ShutdownCompletion::block`] in a + /// blocking context + pub fn shut_down(&self) -> ShutdownCompletion { + self.swansong.shut_down() + } +} diff --git a/http/src/synthetic.rs b/http/src/synthetic.rs index a216ef67f7..bb6c434678 100644 --- a/http/src/synthetic.rs +++ b/http/src/synthetic.rs @@ -1,10 +1,11 @@ use crate::{ after_send::AfterSend, http_config::DEFAULT_CONFIG, received_body::ReceivedBodyState, - transport::Transport, Conn, Headers, KnownHeaderName, Method, Swansong, TypeSet, Version, + transport::Transport, Conn, Headers, KnownHeaderName, Method, ServerConfig, TypeSet, Version, }; use futures_lite::io::{AsyncRead, AsyncWrite, Cursor, Result}; use std::{ pin::Pin, + sync::Arc, task::{Context, Poll}, time::Instant, }; @@ -134,6 +135,7 @@ impl Conn { request_headers.insert(KnownHeaderName::ContentLength, transport.len().to_string()); Self { + server_config: Arc::default(), transport, request_headers, response_headers: Headers::new(), @@ -146,15 +148,24 @@ impl Conn { buffer: Vec::with_capacity(DEFAULT_CONFIG.request_buffer_initial_len).into(), request_body_state: ReceivedBodyState::Start, secure: false, - swansong: Swansong::new(), after_send: AfterSend::default(), start_time: Instant::now(), peer_ip: None, - http_config: DEFAULT_CONFIG, - shared_state: None, } } + /// use a particular shared server config for this synthetic conn + pub fn set_server_config(&mut self, server_config: Arc) { + self.server_config = server_config; + } + + /// chainable setter for server config + #[must_use] + pub fn with_server_config(mut self, server_config: Arc) -> Self { + self.set_server_config(server_config); + self + } + /// simulate closing the transport pub fn close(&mut self) { self.transport.close(); diff --git a/http/src/upgrade.rs b/http/src/upgrade.rs index 9ee4ecf607..d810bab53c 100644 --- a/http/src/upgrade.rs +++ b/http/src/upgrade.rs @@ -1,4 +1,4 @@ -use crate::{received_body::read_buffered, Buffer, Conn, Headers, Method, Swansong, TypeSet}; +use crate::{received_body::read_buffered, Buffer, Conn, Headers, Method, ServerConfig, TypeSet}; use futures_lite::{AsyncRead, AsyncWrite}; use std::{ fmt::{self, Debug, Formatter}, @@ -6,6 +6,7 @@ use std::{ net::IpAddr, pin::Pin, str, + sync::Arc, task::{Context, Poll}, }; use trillium_macros::AsyncWrite; @@ -37,10 +38,8 @@ pub struct Upgrade { /// already. It is your responsibility to process these bytes /// before reading directly from the transport. pub buffer: Buffer, - /// A [`Swansong`] which can and should be used to gracefully shut - /// down any long running streams or futures associated with this - /// upgrade - pub swansong: Swansong, + /// The [`ServerConfig`] shared for this server + pub server_config: Arc, /// the ip address of the connection, if available pub peer_ip: Option, } @@ -61,7 +60,7 @@ impl Upgrade { transport, buffer, state: TypeSet::new(), - swansong: Swansong::new(), + server_config: Arc::default(), peer_ip: None, } } @@ -112,7 +111,7 @@ impl Upgrade { state: self.state, buffer: self.buffer, request_headers: self.request_headers, - swansong: self.swansong, + server_config: self.server_config, peer_ip: self.peer_ip, } } @@ -125,9 +124,9 @@ impl Debug for Upgrade { .field("path", &self.path) .field("method", &self.method) .field("buffer", &self.buffer) - .field("swansong", &self.swansong) + .field("server_config", &self.server_config) .field("state", &self.state) - .field("transport", &"..") + .field("transport", &format_args!("..")) .field("peer_ip", &self.peer_ip) .finish() } @@ -142,7 +141,7 @@ impl From> for Upgrade { state, transport, buffer, - swansong, + server_config, peer_ip, .. } = conn; @@ -154,7 +153,7 @@ impl From> for Upgrade { state, transport, buffer, - swansong, + server_config, peer_ip, } } diff --git a/http/tests/corpus.rs b/http/tests/corpus.rs index f657bc1469..8be831aeee 100644 --- a/http/tests/corpus.rs +++ b/http/tests/corpus.rs @@ -1,8 +1,8 @@ use indoc::formatdoc; use pretty_assertions::assert_str_eq; -use std::{env, net::Shutdown, path::PathBuf}; +use std::{env, net::Shutdown, path::PathBuf, sync::Arc}; use test_harness::test; -use trillium_http::{Conn, KnownHeaderName, Swansong}; +use trillium_http::{Conn, KnownHeaderName, ServerConfig, Swansong}; use trillium_testing::{harness, RuntimeTrait, TestTransport}; const TEST_DATE: &str = "Tue, 21 Nov 2023 21:27:21 GMT"; @@ -44,7 +44,6 @@ async fn handler(mut conn: Conn) -> Conn { #[test(harness)] async fn corpus_test() { - env_logger::init(); let runtime = trillium_testing::runtime(); let dir = PathBuf::from(env!("CARGO_MANIFEST_DIR")).join("tests/corpus"); let filter = env::var("CORPUS_TEST_FILTER").unwrap_or_default(); @@ -69,9 +68,10 @@ async fn corpus_test() { let (client, server) = TestTransport::new(); let swansong = Swansong::new(); + let server_config = Arc::new(ServerConfig::new()); let res = runtime.spawn({ - let swansong = swansong.clone(); - async move { Conn::map(server, swansong, handler).await } + let server_config = server_config.clone(); + async move { server_config.run(server, handler).await } }); client.write_all(request); diff --git a/http/tests/one_hundred_continue.rs b/http/tests/one_hundred_continue.rs index 87f8bb869f..dda282f811 100644 --- a/http/tests/one_hundred_continue.rs +++ b/http/tests/one_hundred_continue.rs @@ -1,7 +1,8 @@ use indoc::{formatdoc, indoc}; use pretty_assertions::assert_eq; +use std::sync::Arc; use test_harness::test; -use trillium_http::{Conn, KnownHeaderName, Swansong, SERVER}; +use trillium_http::{Conn, KnownHeaderName, ServerConfig, SERVER}; use trillium_testing::{harness, RuntimeTrait, TestResult, TestTransport}; const TEST_DATE: &str = "Tue, 21 Nov 2023 21:27:21 GMT"; @@ -21,7 +22,8 @@ async fn handler(mut conn: Conn) -> Conn { async fn one_hundred_continue() -> TestResult { let (client, server) = TestTransport::new(); let runtime = trillium_testing::runtime(); - let handle = runtime.spawn(async move { Conn::map(server, Swansong::new(), handler).await }); + let server_config = Arc::new(ServerConfig::default()); + let handle = runtime.spawn(server_config.run(server, handler)); client.write_all(indoc! {" POST / HTTP/1.1\r @@ -57,7 +59,8 @@ async fn one_hundred_continue() -> TestResult { async fn one_hundred_continue_http_one_dot_zero() -> TestResult { let (client, server) = TestTransport::new(); let runtime = trillium_testing::runtime(); - let handle = runtime.spawn(async move { Conn::map(server, Swansong::new(), handler).await }); + let server_config = Arc::new(ServerConfig::default()); + let handle = runtime.spawn(server_config.run(server, handler)); client.write_all(indoc! { " POST / HTTP/1.0\r diff --git a/http/tests/unsafe_headers.rs b/http/tests/unsafe_headers.rs index b4c3187e5f..ccf809ea9b 100644 --- a/http/tests/unsafe_headers.rs +++ b/http/tests/unsafe_headers.rs @@ -1,8 +1,8 @@ use indoc::{formatdoc, indoc}; use pretty_assertions::assert_eq; -use swansong::Swansong; +use std::sync::Arc; use test_harness::test; -use trillium_http::{Conn, KnownHeaderName, SERVER}; +use trillium_http::{Conn, KnownHeaderName, ServerConfig, SERVER}; use trillium_testing::{harness, RuntimeTrait, TestResult, TestTransport}; const TEST_DATE: &str = "Tue, 21 Nov 2023 21:27:21 GMT"; @@ -24,7 +24,8 @@ async fn handler(mut conn: Conn) -> Conn { async fn bad_headers() -> TestResult { let (client, server) = TestTransport::new(); let runtime = trillium_testing::runtime(); - let handle = runtime.spawn(async move { Conn::map(server, Swansong::new(), handler).await }); + let server_config = Arc::new(ServerConfig::default()); + let handle = runtime.spawn(server_config.run(server, handler)); client.write_all(indoc! {" GET / HTTP/1.1\r diff --git a/http/tests/use_cases.rs b/http/tests/use_cases.rs index 9d63f1a994..e5dfe6dc48 100644 --- a/http/tests/use_cases.rs +++ b/http/tests/use_cases.rs @@ -3,7 +3,7 @@ use std::{future::Future, marker::PhantomData, sync::Arc}; use test_harness::test; use trillium_client::{Client, Connector, Url}; -use trillium_http::{Conn, KnownHeaderName}; +use trillium_http::{Conn, KnownHeaderName, ServerConfig}; use trillium_testing::{harness, Runtime, TestResult, TestTransport}; #[test(harness)] @@ -22,6 +22,7 @@ pub struct ServerConnector { handler: Arc, fut: PhantomData, runtime: Runtime, + server_config: Arc, } impl ServerConnector @@ -33,6 +34,7 @@ where Self { handler: Arc::new(handler), fut: PhantomData, + server_config: ServerConfig::default().into(), runtime: trillium_testing::runtime().into(), } } @@ -50,12 +52,10 @@ where let (client_transport, server_transport) = TestTransport::new(); let handler = self.handler.clone(); + let server_config = self.server_config.clone(); - self.runtime.spawn(async move { - Conn::map(server_transport, Default::default(), &*handler) - .await - .unwrap(); - }); + self.runtime + .spawn(async move { server_config.run(server_transport, &*handler).await }); Ok(client_transport) } diff --git a/logger/Cargo.toml b/logger/Cargo.toml index 586f9eefb6..a11c98f999 100644 --- a/logger/Cargo.toml +++ b/logger/Cargo.toml @@ -16,6 +16,7 @@ log = "0.4.20" size = "0.4.1" time = { version = "0.3.31", features = ["local-offset", "formatting", "macros"] } trillium = { path = "../trillium", version = "0.2.20" } +url = "2.5.0" [dev-dependencies] access_log_parser = "0.8.0" diff --git a/logger/src/lib.rs b/logger/src/lib.rs index ea3b5e25b2..4a361a5da6 100644 --- a/logger/src/lib.rs +++ b/logger/src/lib.rs @@ -8,7 +8,11 @@ //! Welcome to the trillium logger! pub use crate::formatters::{apache_combined, apache_common, dev_formatter}; -use std::{fmt::Display, io::IsTerminal, sync::Arc}; +use std::{ + fmt::{Display, Write}, + io::IsTerminal, + sync::Arc, +}; use trillium::{Conn, Handler, Info}; /// Components with which common log formats can be constructed pub mod formatters; @@ -241,6 +245,21 @@ impl Logger { } } +/// An easily-named `Arc` that is stored in trillium shared state +#[derive(Clone)] +pub struct LogTarget(Arc); +impl Targetable for LogTarget { + fn write(&self, data: String) { + self.0.write(data); + } +} +impl LogTarget { + /// Emit a log message to the logging backend + pub fn write(&self, data: String) { + self.0.write(data); + } +} + struct LoggerWasRun; impl Handler for Logger @@ -248,18 +267,25 @@ where F: LogFormatter, { async fn init(&mut self, info: &mut Info) { - self.target.write(format!( - " -🌱🦀🌱 {} started -Listening at {}{} + let mut string = "\nTrillium started\n".to_string(); + + if let Some(url) = info.state::() { + writeln!(string, "✾ Listening at {}", url.as_str()).unwrap(); + } + + if let Some(tcp) = info.tcp_socket_addr() { + writeln!(string, "✾ Bound as tcp://{tcp}").unwrap(); + } + + #[cfg(unix)] + if let Some(unix) = info.unix_socket_addr().and_then(|unix| unix.as_pathname()) { + writeln!(string, "✾ Bound as unix://{}", unix.display()).unwrap(); + } + + writeln!(string, "Control-c to quit").unwrap(); -Control-C to quit", - info.server_description(), - info.listener_description(), - info.tcp_socket_addr() - .map(|s| format!(" (bound as tcp://{s})")) - .unwrap_or_default(), - )); + self.target.write(string); + info.insert_state(LogTarget(Arc::clone(&self.target))); } async fn run(&self, conn: Conn) -> Conn { @@ -267,7 +293,7 @@ Control-C to quit", } async fn before_send(&self, mut conn: Conn) -> Conn { - if conn.state::().is_some() { + if conn.as_ref().contains::() { let target = self.target.clone(); let output = self.format.format(&conn, self.color_mode.is_enabled()); conn.inner_mut() diff --git a/macros/tests/derive.rs b/macros/tests/derive.rs index 7063cf844f..aa95fed3c0 100644 --- a/macros/tests/derive.rs +++ b/macros/tests/derive.rs @@ -20,7 +20,7 @@ fn full_lifecycle() { async fn init(&mut self, info: &mut Info) { self.init = true; - *info.server_description_mut() = "inner handler took over".into(); + info.insert_state("inner handler took over"); } async fn before_send(&self, conn: Conn) -> Conn { @@ -40,7 +40,7 @@ fn full_lifecycle() { let mut handler = OuterHandler(InnerHandler { init: false }); handler.init(&mut info).await; - assert_eq!(info.server_description(), "inner handler took over"); + assert_eq!(info.state::<&str>().unwrap(), &"inner handler took over"); assert!(handler.0.init); assert_ok!(get("/").run_async(&handler).await, "run", "before-send" => "before-send"); assert_eq!(handler.name(), "OuterHandler (inner handler)"); diff --git a/native-tls/src/client.rs b/native-tls/src/client.rs index 2e4c8722a7..ec9bc5cdf1 100644 --- a/native-tls/src/client.rs +++ b/native-tls/src/client.rs @@ -43,7 +43,7 @@ impl Debug for NativeTlsConfig { fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result { f.debug_struct("NativeTlsConfig") .field("tcp_config", &self.tcp_config) - .field("tls_connector", &"..") + .field("tls_connector", &format_args!("..")) .finish() } } diff --git a/rustls/src/client.rs b/rustls/src/client.rs index c61cee44d9..b136dc3411 100644 --- a/rustls/src/client.rs +++ b/rustls/src/client.rs @@ -97,7 +97,7 @@ impl RustlsConfig { impl Debug for RustlsConfig { fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result { f.debug_struct("RustlsConfig") - .field("rustls_config", &"..") + .field("rustls_config", &format_args!("..")) .field("tcp_config", &self.tcp_config) .finish() } diff --git a/server-common/Cargo.toml b/server-common/Cargo.toml index d97e0154de..515cf068a3 100644 --- a/server-common/Cargo.toml +++ b/server-common/Cargo.toml @@ -14,6 +14,7 @@ categories = ["web-programming::http-server", "web-programming"] async-channel = "2.2.0" async_cell = "0.2.2" futures-lite = "2.1.0" +listenfd = "1.0.1" log = "0.4.20" pin-project-lite = "0.2.13" swansong = "0.3.0" diff --git a/server-common/src/acceptor.rs b/server-common/src/acceptor.rs index ed29e7c935..862c911d42 100644 --- a/server-common/src/acceptor.rs +++ b/server-common/src/acceptor.rs @@ -22,6 +22,11 @@ where &self, input: Input, ) -> impl Future> + Send; + + /// should conns be treated as secure? + fn is_secure(&self) -> bool { + true + } } impl Acceptor for () @@ -34,4 +39,8 @@ where async fn accept(&self, input: Input) -> Result { Ok(input) } + + fn is_secure(&self) -> bool { + false + } } diff --git a/server-common/src/config.rs b/server-common/src/config.rs index cacffd0dae..2ddf86cbf9 100644 --- a/server-common/src/config.rs +++ b/server-common/src/config.rs @@ -1,12 +1,10 @@ -use crate::{Acceptor, RuntimeTrait, Server, ServerHandle}; +use crate::{running_config::RunningConfig, Acceptor, RuntimeTrait, Server, ServerHandle}; use async_cell::sync::AsyncCell; -use std::{ - cell::OnceCell, - marker::PhantomData, - net::SocketAddr, - sync::{Arc, RwLock}, -}; -use trillium::{Handler, HttpConfig, Info, Swansong}; +use futures_lite::StreamExt; +use std::{cell::OnceCell, net::SocketAddr, pin::pin, sync::Arc}; +use trillium::{Handler, HttpConfig, Info, Swansong, TypeSet}; +use trillium_http::ServerConfig; +use url::Url; /// # Primary entrypoint for configuring and running a trillium server /// @@ -58,17 +56,15 @@ use trillium::{Handler, HttpConfig, Info, Swansong}; #[derive(Debug)] pub struct Config { pub(crate) acceptor: AcceptorType, - pub(crate) port: Option, + pub(crate) binding: Option, pub(crate) host: Option, + pub(crate) server_config_cell: Arc>>, + pub(crate) max_connections: Option, pub(crate) nodelay: bool, - pub(crate) swansong: Swansong, + pub(crate) port: Option, pub(crate) register_signals: bool, - pub(crate) max_connections: Option, - pub(crate) info: Arc>>, - pub(crate) binding: RwLock>, - pub(crate) server: PhantomData, - pub(crate) http_config: HttpConfig, pub(crate) runtime: ServerType::Runtime, + pub(crate) server_config: ServerConfig, } impl Config @@ -82,8 +78,8 @@ where /// outside of trillium's web server. For applications that embed a /// trillium server inside of an already-running async runtime, use /// [`Config::run_async`] - pub fn run(self, h: H) { - ServerType::run(self, h) + pub fn run(self, handler: impl Handler) { + self.runtime.clone().block_on(self.run_async(handler)); } /// Runs the provided handler with this config, in an @@ -91,10 +87,78 @@ where /// for an application that needs to spawn async tasks that are /// unrelated to the trillium application. If you do not need to spawn /// other tasks, [`Config::run`] is the preferred entrypoint - pub async fn run_async(self, handler: impl Handler) { - let swansong = self.swansong.clone(); - ServerType::run_async(self, handler).await; - swansong.shut_down().await; + pub async fn run_async(self, mut handler: impl Handler) { + let Self { + runtime, + acceptor, + max_connections, + nodelay, + binding, + host, + port, + register_signals, + server_config, + server_config_cell, + } = self; + let host = host + .or_else(|| std::env::var("HOST").ok()) + .unwrap_or_else(|| "localhost".into()); + let port = port + .or_else(|| { + std::env::var("PORT") + .ok() + .map(|x| x.parse().expect("PORT must be an unsigned integer")) + }) + .unwrap_or(8080); + + let listener = binding + .inspect(|_| log::debug!("taking prebound listener")) + .unwrap_or_else(|| ServerType::from_host_and_port(&host, port)); + + let swansong = server_config.swansong().clone(); + + let mut info = Info::from(server_config) + .with_state(runtime.clone().into()) + .with_state(runtime.clone()); + listener.init(&mut info); + insert_url(info.as_mut(), acceptor.is_secure()); + handler.init(&mut info).await; + + let server_config = Arc::new(ServerConfig::from(info)); + server_config_cell.set(server_config.clone()); + + if register_signals { + let runtime = runtime.clone(); + runtime.clone().spawn(async move { + let mut signals = pin!(runtime.hook_signals([2, 3, 15])); + while signals.next().await.is_some() { + let guard_count = swansong.guard_count(); + if swansong.state().is_shutting_down() { + eprintln!( + "\nSecond interrupt, shutting down harshly (dropping {guard_count} \ + guards)" + ); + std::process::exit(1); + } else { + println!( + "\nShutting down gracefully. Waiting for {guard_count} shutdown \ + guards to drop.\nControl-c again to force." + ); + swansong.shut_down(); + } + } + }); + } + + let running_config = Arc::new(RunningConfig { + acceptor, + max_connections, + server_config, + runtime, + nodelay, + }); + + running_config.run_async(listener, handler).await; } /// Spawns the server onto the async runtime, returning a @@ -111,9 +175,9 @@ where /// when spawning the server onto a runtime. pub fn handle(&self) -> ServerHandle { ServerHandle { - swansong: self.swansong.clone(), - info: self.info.clone(), - received_info: OnceCell::new(), + swansong: self.server_config.swansong().clone(), + server_config: self.server_config_cell.clone(), + received_server_config: OnceCell::new(), runtime: self.runtime().into(), } } @@ -180,20 +244,18 @@ where host: self.host, port: self.port, nodelay: self.nodelay, - server: PhantomData, - swansong: self.swansong, register_signals: self.register_signals, max_connections: self.max_connections, - info: self.info, + server_config_cell: self.server_config_cell, + server_config: self.server_config, binding: self.binding, - http_config: self.http_config, runtime: self.runtime, } } /// use the specific [`Swansong`] provided pub fn with_swansong(mut self, swansong: Swansong) -> Self { - self.swansong = swansong; + self.server_config.set_swansong(swansong); self } @@ -209,7 +271,7 @@ where /// /// See [`HttpConfig`] for documentation pub fn with_http_config(mut self, http_config: HttpConfig) -> Self { - self.http_config = http_config; + *self.server_config.http_config_mut() = http_config; self } @@ -239,21 +301,28 @@ where ); } - self.binding = RwLock::new(Some(server.into())); + self.binding = Some(server.into()); self } fn has_binding(&self) -> bool { - self.binding - .read() - .as_deref() - .map_or(false, Option::is_some) + self.binding.is_some() } /// retrieve the runtime pub fn runtime(&self) -> ServerType::Runtime { self.runtime.clone() } + + /// return the configured port + pub fn port(&self) -> Option { + self.port + } + + /// return the configured host + pub fn host(&self) -> Option<&str> { + self.host.as_deref() + } } impl Config { @@ -282,15 +351,28 @@ impl Default for Config { acceptor: (), port: None, host: None, - server: PhantomData, nodelay: false, - swansong: Swansong::new(), register_signals: cfg!(unix), max_connections, - info: AsyncCell::shared(), - binding: RwLock::new(None), - http_config: HttpConfig::default(), + server_config_cell: AsyncCell::shared(), + binding: None, runtime: ServerType::runtime(), + server_config: Default::default(), } } } + +fn insert_url(state: &mut TypeSet, secure: bool) -> Option<()> { + let socket_addr = state.get::().copied()?; + let vacant_entry = state.entry::().into_vacant()?; + let scheme = if secure { "https" } else { "http" }; + let url = Url::parse(&if socket_addr.ip().is_loopback() { + format!("{scheme}://localhost:{}/", socket_addr.port()) + } else { + format!("{scheme}://{socket_addr}/") + }) + .ok()?; + + vacant_entry.insert(url); + Some(()) +} diff --git a/server-common/src/config_ext.rs b/server-common/src/config_ext.rs deleted file mode 100644 index 491cdeadd4..0000000000 --- a/server-common/src/config_ext.rs +++ /dev/null @@ -1,232 +0,0 @@ -use crate::{Acceptor, Config, Server, Transport}; -use futures_lite::prelude::*; -use std::{ - io::ErrorKind, - net::{SocketAddr, TcpListener, ToSocketAddrs}, - sync::Arc, -}; -use trillium::Handler; -use trillium_http::{transport::BoxedTransport, Error, Swansong, SERVICE_UNAVAILABLE}; -/// # Server-implementer interfaces to Config -/// -/// These functions are intended for use by authors of trillium servers, -/// and should not be necessary to build an application. Please open -/// an issue if you find yourself using this trait directly in an -/// application. - -pub trait ConfigExt -where - ServerType: Server, -{ - /// resolve a port for this application, either directly - /// configured, from the environmental variable `PORT`, or a default - /// of `8080` - fn port(&self) -> u16; - - /// resolve the host for this application, either directly from - /// configuration, from the `HOST` env var, or `"localhost"` - fn host(&self) -> String; - - /// use the [`ConfigExt::port`] and [`ConfigExt::host`] to resolve - /// a vec of potential socket addrs - fn socket_addrs(&self) -> Vec; - - /// returns whether this server should register itself for - /// operating system signals. this flag does nothing aside from - /// communicating to the server implementer that this is - /// desired. defaults to true on `cfg(unix)` systems, and false - /// elsewhere. - fn should_register_signals(&self) -> bool; - - /// returns whether the server should set TCP_NODELAY on the - /// TcpListener, if that is applicable - fn nodelay(&self) -> bool; - - /// returns a clone of the [`Swansong`] associated with - /// this server, to be used in conjunction with signals or other - /// service interruption methods - fn swansong(&self) -> Swansong; - - /// returns the tls acceptor for this server - fn acceptor(&self) -> &AcceptorType; - - /// waits for all requests to complete - fn graceful_shutdown(&self) -> impl Future + Send; - - /// apply the provided handler to the transport, using - /// [`trillium_http`]'s http implementation. this is the default inner - /// loop for most trillium servers - fn handle_stream( - self: Arc, - stream: ServerType::Transport, - handler: impl Handler, - ) -> impl Future + Send; - - /// builds any type that is TryFrom and - /// configures it for use. most trillium servers should use this if - /// possible instead of using [`ConfigExt::port`], - /// [`ConfigExt::host`], or [`ConfigExt::socket_addrs`]. - /// - /// this function also contains logic that sets nonblocking to - /// true and on unix systems will build a tcp listener from the - /// `LISTEN_FD` env var. - fn build_listener(&self) -> Listener - where - Listener: TryFrom, - >::Error: std::fmt::Debug; - - /// determines if the server is currently responding to more than - /// the maximum number of connections set by - /// `Config::with_max_connections`. - fn over_capacity(&self) -> bool; -} - -impl ConfigExt - for Config -where - ServerType: Server + Send + ?Sized, - AcceptorType: Acceptor<::Transport>, -{ - fn port(&self) -> u16 { - self.port - .or_else(|| std::env::var("PORT").ok().and_then(|p| p.parse().ok())) - .unwrap_or(8080) - } - - fn host(&self) -> String { - self.host - .as_ref() - .map(String::from) - .or_else(|| std::env::var("HOST").ok()) - .unwrap_or_else(|| String::from("localhost")) - } - - fn socket_addrs(&self) -> Vec { - (self.host(), self.port()) - .to_socket_addrs() - .unwrap() - .collect() - } - - fn should_register_signals(&self) -> bool { - self.register_signals - } - - fn nodelay(&self) -> bool { - self.nodelay - } - - fn swansong(&self) -> Swansong { - self.swansong.clone() - } - - fn acceptor(&self) -> &AcceptorType { - &self.acceptor - } - - async fn graceful_shutdown(&self) { - self.swansong.shut_down().await - } - - async fn handle_stream( - self: Arc, - mut stream: ServerType::Transport, - handler: impl Handler, - ) { - if self.over_capacity() { - let mut byte = [0u8]; // wait for the client to start requesting - trillium::log_error!(stream.read(&mut byte).await); - trillium::log_error!(stream.write_all(SERVICE_UNAVAILABLE).await); - return; - } - - let guard = self.swansong.guard(); - - trillium::log_error!(stream.set_nodelay(self.nodelay)); - - let peer_ip = stream.peer_addr().ok().flatten().map(|addr| addr.ip()); - - let transport = match self.acceptor.accept(stream).await { - Ok(stream) => stream, - Err(e) => { - log::error!("acceptor error: {:?}", e); - return; - } - }; - - let handler = &handler; - let result = trillium_http::Conn::map_with_config( - self.http_config, - transport, - self.swansong.clone(), - |mut conn| async { - conn.set_peer_ip(peer_ip); - let conn = handler.run(conn.into()).await; - let conn = handler.before_send(conn).await; - - conn.into_inner() - }, - ) - .await; - - match result { - Ok(Some(upgrade)) => { - let upgrade = upgrade.map_transport(BoxedTransport::new); - if handler.has_upgrade(&upgrade) { - log::debug!("upgrading..."); - handler.upgrade(upgrade).await; - } else { - log::error!("upgrade specified but no upgrade handler provided"); - } - } - - Err(Error::Closed) | Ok(None) => { - log::debug!("closing connection"); - } - - Err(Error::Io(e)) - if e.kind() == ErrorKind::ConnectionReset || e.kind() == ErrorKind::BrokenPipe => - { - log::debug!("closing connection"); - } - - Err(e) => { - log::error!("http error: {:?}", e); - } - }; - - drop(guard); - } - - fn build_listener(&self) -> Listener - where - Listener: TryFrom, - >::Error: std::fmt::Debug, - { - #[cfg(unix)] - let listener = { - use std::os::unix::prelude::FromRawFd; - - if let Some(fd) = std::env::var("LISTEN_FD") - .ok() - .and_then(|fd| fd.parse().ok()) - { - log::debug!("using fd {} from LISTEN_FD", fd); - unsafe { TcpListener::from_raw_fd(fd) } - } else { - TcpListener::bind((self.host(), self.port())).unwrap() - } - }; - - #[cfg(not(unix))] - let listener = TcpListener::bind((self.host(), self.port())).unwrap(); - - listener.set_nonblocking(true).unwrap(); - listener.try_into().unwrap() - } - - fn over_capacity(&self) -> bool { - self.max_connections - .map_or(false, |m| self.swansong.guard_count() >= m) - } -} diff --git a/server-common/src/lib.rs b/server-common/src/lib.rs index 71f66b1acf..9b06db112e 100644 --- a/server-common/src/lib.rs +++ b/server-common/src/lib.rs @@ -1,3 +1,4 @@ +#![forbid(unsafe_code)] #![deny( clippy::dbg_macro, missing_copy_implementations, @@ -28,9 +29,6 @@ pub use url::{self, Url}; mod config; pub use config::Config; -mod config_ext; -pub use config_ext::ConfigExt; - mod server; pub use server::Server; @@ -52,3 +50,5 @@ pub use swansong::Swansong; mod runtime; pub use runtime::{DroppableFuture, Runtime, RuntimeTrait}; + +mod running_config; diff --git a/server-common/src/running_config.rs b/server-common/src/running_config.rs new file mode 100644 index 0000000000..ec2267c586 --- /dev/null +++ b/server-common/src/running_config.rs @@ -0,0 +1,131 @@ +use crate::{Acceptor, ArcHandler, RuntimeTrait, Server}; +use futures_lite::{AsyncReadExt, AsyncWriteExt}; +use std::{io::ErrorKind, sync::Arc}; +use trillium::Handler; +use trillium_http::{ + transport::{BoxedTransport, Transport}, + Error, ServerConfig, SERVICE_UNAVAILABLE, +}; + +#[derive(Debug)] +pub struct RunningConfig { + pub(crate) acceptor: AcceptorType, + pub(crate) max_connections: Option, + pub(crate) nodelay: bool, + pub(crate) runtime: ServerType::Runtime, + pub(crate) server_config: Arc, +} + +impl::Transport>> RunningConfig { + pub(crate) async fn run_async(self: Arc, mut listener: S, handler: impl Handler) { + let swansong = self.server_config.as_ref().swansong(); + let runtime = self.runtime.clone(); + let handler = ArcHandler::new(handler); + while let Some(transport) = swansong.interrupt(listener.accept()).await { + match transport { + Ok(stream) => { + runtime.spawn( + Arc::clone(&self).handle_stream(stream, ArcHandler::clone(&handler)), + ); + } + Err(e) => log::error!("tcp error: {}", e), + } + } + + self.server_config.swansong().shut_down().await; + listener.clean_up().await; + } + + async fn handle_stream(self: Arc, mut stream: S::Transport, handler: impl Handler) { + if self.over_capacity() { + let mut byte = [0u8]; // wait for the client to start requesting + trillium::log_error!(stream.read(&mut byte).await); + trillium::log_error!(stream.write_all(SERVICE_UNAVAILABLE).await); + return; + } + + trillium::log_error!(stream.set_nodelay(self.nodelay)); + + let peer_ip = stream.peer_addr().ok().flatten().map(|addr| addr.ip()); + + let transport = match self.acceptor.accept(stream).await { + Ok(stream) => stream, + Err(e) => { + log::error!("acceptor error: {:?}", e); + return; + } + }; + + let handler = &handler; + + let result = self + .server_config + .clone() + .run(transport, |mut conn| async { + conn.set_peer_ip(peer_ip); + let conn = handler.run(conn.into()).await; + let conn = handler.before_send(conn).await; + + conn.into_inner() + }) + .await; + + match result { + Ok(Some(upgrade)) => { + let upgrade = upgrade.map_transport(BoxedTransport::new); + if handler.has_upgrade(&upgrade) { + log::debug!("upgrading..."); + handler.upgrade(upgrade).await; + } else { + log::error!("upgrade specified but no upgrade handler provided"); + } + } + + Err(Error::Closed) | Ok(None) => { + log::debug!("closing connection"); + } + + Err(Error::Io(e)) + if e.kind() == ErrorKind::ConnectionReset || e.kind() == ErrorKind::BrokenPipe => + { + log::debug!("closing connection"); + } + + Err(e) => { + log::error!("http error: {:?}", e); + } + }; + } + + // fn build_listener(&self) -> Listener + // where + // Listener: TryFrom, + // >::Error: std::fmt::Debug, + // { + // #[cfg(unix)] + // let listener = { + // use std::os::unix::prelude::FromRawFd; + + // if let Some(fd) = std::env::var("LISTEN_FD") + // .ok() + // .and_then(|fd| fd.parse().ok()) + // { + // log::debug!("using fd {} from LISTEN_FD", fd); + // unsafe { TcpListener::from_raw_fd(fd) } + // } else { + // TcpListener::bind((self.host(), self.port())).unwrap() + // } + // }; + + // #[cfg(not(unix))] + // let listener = TcpListener::bind((self.host(), self.port())).unwrap(); + + // listener.set_nonblocking(true).unwrap(); + // listener.try_into().unwrap() + // } + + fn over_capacity(&self) -> bool { + self.max_connections + .map_or(false, |m| self.server_config.swansong().guard_count() >= m) + } +} diff --git a/server-common/src/runtime.rs b/server-common/src/runtime.rs index 52c39b0c7f..2a609162b2 100644 --- a/server-common/src/runtime.rs +++ b/server-common/src/runtime.rs @@ -22,16 +22,20 @@ pub struct Runtime(Arc); impl Debug for Runtime { fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result { - f.debug_tuple("Runtime").field(&"..").finish() + f.debug_tuple("Runtime").field(&format_args!("..")).finish() } } impl Runtime { /// Construct a new type-erased runtime object from any [`RuntimeTrait`] implementation. - /// - /// Prefer using [`from`][From::from]/[`into`][Into::into] if you don't have a concrete - /// `RuntimeTrait` in order to avoid double-arc-ing a Runtime. pub fn new(runtime: impl RuntimeTrait) -> Self { + runtime.into() // we avoid re-arcing a Runtime by using Into::into + } + + // in order to avoid re-arcing Runtime in new / into, we use this to actually construct the + // Runtime within From implementations on the runtime trait type + #[doc(hidden)] + pub fn from_trait_impl(runtime: impl RuntimeTrait) -> Self { Self(Arc::new(runtime)) } @@ -117,4 +121,11 @@ impl RuntimeTrait for Runtime { })); receive.recv().unwrap() } + + fn hook_signals( + &self, + signals: impl IntoIterator, + ) -> impl Stream + Send + 'static { + self.0.hook_signals(signals.into_iter().collect()) + } } diff --git a/server-common/src/runtime/object_safe_runtime.rs b/server-common/src/runtime/object_safe_runtime.rs index 5fe22c4dab..42b848952f 100644 --- a/server-common/src/runtime/object_safe_runtime.rs +++ b/server-common/src/runtime/object_safe_runtime.rs @@ -19,6 +19,8 @@ pub(super) trait ObjectSafeRuntime: Send + Sync + 'static { where 'runtime: 'fut, Self: 'fut; + + fn hook_signals(&self, signals: Vec) -> Pin + Send + 'static>>; } impl ObjectSafeRuntime for R @@ -54,4 +56,8 @@ where { RuntimeTrait::block_on(self, fut) } + + fn hook_signals(&self, signals: Vec) -> Pin + Send + 'static>> { + Box::pin(RuntimeTrait::hook_signals(self, signals)) + } } diff --git a/server-common/src/runtime/runtime_trait.rs b/server-common/src/runtime/runtime_trait.rs index 76f251484c..63d834db0e 100644 --- a/server-common/src/runtime/runtime_trait.rs +++ b/server-common/src/runtime/runtime_trait.rs @@ -51,4 +51,13 @@ pub trait RuntimeTrait: Into + Clone + Send + Sync + 'static { None }) } + + /// trap and return a [`Stream`] of signals that match the provided signals + fn hook_signals( + &self, + signals: impl IntoIterator, + ) -> impl Stream + Send + 'static { + let _ = signals; + futures_lite::stream::empty() + } } diff --git a/server-common/src/server.rs b/server-common/src/server.rs index a817decfb3..4ec313251f 100644 --- a/server-common/src/server.rs +++ b/server-common/src/server.rs @@ -1,6 +1,9 @@ -use crate::{Acceptor, ArcHandler, Config, ConfigExt, RuntimeTrait, Swansong, Transport}; -use std::{future::Future, io::Result, sync::Arc}; -use trillium::{Handler, Info}; +use crate::{RuntimeTrait, Swansong, Transport}; +use listenfd::ListenFd; +#[cfg(unix)] +use std::os::unix::net::UnixListener; +use std::{future::Future, io::Result, net::TcpListener}; +use trillium::Info; /// The server trait, for standard network-based server implementations. pub trait Server: Sized + Send + Sync + 'static { @@ -12,16 +15,15 @@ pub trait Server: Sized + Send + Sync + 'static { /// The [`RuntimeTrait`] for this `Server`. type Runtime: RuntimeTrait; - /// The description of this server, to be appended to the Info and potentially logged. - const DESCRIPTION: &'static str; - /// Asynchronously return a single `Self::Transport` from a /// `Self::Listener`. Must be implemented. fn accept(&mut self) -> impl Future> + Send; /// Build an [`Info`] from the Self::Listener type. See [`Info`] /// for more details. - fn info(&self) -> Info; + fn init(&self, info: &mut Info) { + let _ = info; + } /// After the server has shut down, perform any housekeeping, eg /// unlinking a unix socket. @@ -36,61 +38,59 @@ pub trait Server: Sized + Send + Sync + 'static { /// is described elsewhere. To override the default logic, server /// implementations could potentially implement this directly. To /// use this default logic, implement - /// [`Server::listener_from_tcp`] and - /// [`Server::listener_from_unix`]. - #[cfg(unix)] - fn build_listener(config: &Config) -> Self - where - A: Acceptor, - { - if let Some(listener) = config.binding.write().unwrap().take() { - log::debug!("taking prebound listener"); - return listener; + /// [`Server::from_tcp`] and + /// [`Server::from_unix`]. + fn from_host_and_port(host: &str, port: u16) -> Self { + #[cfg(unix)] + if host.starts_with(|c| c == '/' || c == '.' || c == '~') { + log::debug!("using unix listener at {host}"); + return UnixListener::bind(host) + .inspect(|unix_listener| { + log::debug!("listening at {:?}", unix_listener.local_addr().unwrap()); + }) + .map(Self::from_unix) + .unwrap(); } - use std::os::unix::prelude::FromRawFd; - let host = config.host(); - if host.starts_with(|c| c == '/' || c == '.' || c == '~') { - Self::listener_from_unix(std::os::unix::net::UnixListener::bind(host).unwrap()) - } else { - let tcp_listener = if let Some(fd) = std::env::var("LISTEN_FD") - .ok() - .and_then(|fd| fd.parse().ok()) - { - log::debug!("using fd {} from LISTEN_FD", fd); - unsafe { std::net::TcpListener::from_raw_fd(fd) } - } else { - std::net::TcpListener::bind((host, config.port())).unwrap() - }; + let mut listen_fd = ListenFd::from_env(); - tcp_listener.set_nonblocking(true).unwrap(); - Self::listener_from_tcp(tcp_listener) + #[cfg(unix)] + if let Ok(Some(unix_listener)) = listen_fd.take_unix_listener(0) { + log::debug!( + "using unix listener from systemfd environment {:?}", + unix_listener.local_addr().unwrap() + ); + return Self::from_unix(unix_listener); } - } - /// Build a listener from the config. The default logic for this - /// is described elsewhere. To override the default logic, server - /// implementations could potentially implement this directly. To - /// use this default logic, implement [`Server::listener_from_tcp`] - #[cfg(not(unix))] - fn build_listener(config: &Config) -> Self - where - A: Acceptor, - { - if let Some(listener) = config.binding.write().unwrap().take() { - log::debug!("taking prebound listener"); - return listener; - } + let tcp_listener = listen_fd + .take_tcp_listener(0) + .ok() + .flatten() + .inspect(|tcp_listener| { + log::debug!( + "using tcp listener from systemfd environment, listening at {:?}", + tcp_listener.local_addr() + ) + }) + .unwrap_or_else(|| { + log::debug!("using tcp listener at {host}:{port}"); + TcpListener::bind((host, port)) + .inspect(|tcp_listener| { + log::debug!("listening at {:?}", tcp_listener.local_addr().unwrap()) + }) + .unwrap() + }); - let tcp_listener = std::net::TcpListener::bind((config.host(), config.port())).unwrap(); tcp_listener.set_nonblocking(true).unwrap(); - Self::listener_from_tcp(tcp_listener) + Self::from_tcp(tcp_listener) } /// Build a Self::Listener from a tcp listener. This is called by /// the [`Server::build_listener`] default implementation, and /// is mandatory if the default implementation is used. - fn listener_from_tcp(_tcp: std::net::TcpListener) -> Self { + fn from_tcp(tcp_listener: TcpListener) -> Self { + let _ = tcp_listener; unimplemented!() } @@ -98,7 +98,8 @@ pub trait Server: Sized + Send + Sync + 'static { /// the [`Server::build_listener`] default implementation. You /// will want to tag an implementation of this with #[cfg(unix)]. #[cfg(unix)] - fn listener_from_unix(_tcp: std::os::unix::net::UnixListener) -> Self { + fn from_unix(unix_listener: UnixListener) -> Self { + let _ = unix_listener; unimplemented!() } @@ -108,58 +109,4 @@ pub trait Server: Sized + Send + Sync + 'static { fn handle_signals(_swansong: Swansong) -> impl Future + Send { async {} } - - /// Run a trillium application from a sync context - fn run(config: Config, handler: H) - where - A: Acceptor, - H: Handler, - { - config - .runtime - .clone() - .block_on(async move { Self::run_async(config, handler).await }); - } - - /// Run a trillium application from an async context. The default - /// implementation of this method contains the core logic of this - /// Trait. - fn run_async(config: Config, mut handler: H) -> impl Future + Send - where - A: Acceptor, - H: Handler, - { - async move { - let runtime = config.runtime.clone(); - if config.should_register_signals() { - #[cfg(unix)] - runtime.spawn(Self::handle_signals(config.swansong())); - - #[cfg(not(unix))] - log::error!("signals handling not supported on windows yet"); - } - let mut listener = Self::build_listener(&config); - let mut info = Self::info(&listener); - info.server_description_mut().push_str(Self::DESCRIPTION); - handler.init(&mut info).await; - config.info.set(Arc::new(info)); - let config = Arc::new(config); - let handler = ArcHandler::new(handler); - let swansong = &config.swansong; - - while let Some(transport) = swansong.interrupt(Self::accept(&mut listener)).await { - match transport { - Ok(stream) => { - let config = Arc::clone(&config); - let handler = ArcHandler::clone(&handler); - runtime.spawn(config.handle_stream(stream, handler)); - } - Err(e) => log::error!("tcp error: {}", e), - } - } - - config.graceful_shutdown().await; - Self::clean_up(listener).await; - } - } } diff --git a/server-common/src/server_handle.rs b/server-common/src/server_handle.rs index 5110155575..5a750c9f08 100644 --- a/server-common/src/server_handle.rs +++ b/server-common/src/server_handle.rs @@ -1,8 +1,8 @@ use crate::Runtime; use async_cell::sync::AsyncCell; -use std::{cell::OnceCell, future::IntoFuture, sync::Arc}; +use std::{cell::OnceCell, future::IntoFuture, net::SocketAddr, sync::Arc}; use swansong::{ShutdownCompletion, Swansong}; -use trillium::Info; +use trillium_http::ServerConfig; /// A handle for a spawned trillium server. Returned by /// [`Config::handle`][crate::Config::handle] and @@ -10,19 +10,48 @@ use trillium::Info; #[derive(Clone, Debug)] pub struct ServerHandle { pub(crate) swansong: Swansong, - pub(crate) info: Arc>>, - pub(crate) received_info: OnceCell>, + pub(crate) server_config: Arc>>, + pub(crate) received_server_config: OnceCell>, pub(crate) runtime: Runtime, } +#[derive(Debug)] +pub struct BoundInfo(Arc); + +impl BoundInfo { + /// Borrow a type from the [`TypeSet`] on this `BoundInfo`. + pub fn state(&self) -> Option<&T> { + self.0.shared_state().get() + } + + /// Returns the `local_addr` of a bound tcp listener, if such a thing exists for this server + pub fn tcp_socket_addr(&self) -> Option<&SocketAddr> { + self.state() + } + + pub fn url(&self) -> Option<&url::Url> { + self.state() + } + + /// Returns the `local_addr` of a bound unix listener, if such a thing exists for this server + #[cfg(unix)] + pub fn unix_socket_addr(&self) -> Option<&std::os::unix::net::SocketAddr> { + self.state() + } +} + impl ServerHandle { /// await server start and retrieve the server's [`Info`] - pub async fn info(&self) -> &Info { - if let Some(info) = self.received_info.get() { - return info; + pub async fn info(&self) -> BoundInfo { + if let Some(server_config) = self.received_server_config.get().cloned() { + return BoundInfo(server_config); } - let arc_info = self.info.get().await; - self.received_info.get_or_init(|| arc_info) + let arc_server_config = self.server_config.get().await; + let server_config = self + .received_server_config + .get_or_init(|| arc_server_config); + + BoundInfo(Arc::clone(server_config)) } /// stop server and return a future that can be awaited for it to shut down gracefully diff --git a/smol/examples/smol.rs b/smol/examples/smol.rs index 2d32e8ff1e..180617b5ce 100644 --- a/smol/examples/smol.rs +++ b/smol/examples/smol.rs @@ -1,11 +1,11 @@ use std::time::Duration; use trillium::{Conn, Handler}; use trillium_logger::Logger; -use trillium_smol::SmolRuntime; +use trillium_server_common::Runtime; pub fn app() -> impl Handler { (Logger::new(), |conn: Conn| async move { - let runtime = SmolRuntime::default(); + let runtime = conn.shared_state::().cloned().unwrap(); let response = runtime .clone() .spawn(async move { @@ -23,13 +23,3 @@ pub fn main() { env_logger::init(); trillium_smol::run(app()); } - -#[cfg(test)] -mod tests { - use trillium_testing::prelude::*; - #[test] - fn test() { - let app = super::app(); - assert_ok!(get("/").on(&app), "successfully spawned a task"); - } -} diff --git a/smol/src/runtime.rs b/smol/src/runtime.rs index 62f3595b21..20d26f0bf1 100644 --- a/smol/src/runtime.rs +++ b/smol/src/runtime.rs @@ -54,6 +54,14 @@ impl RuntimeTrait for SmolRuntime { fn block_on(&self, fut: Fut) -> Fut::Output { async_global_executor::block_on(fut) } + + #[cfg(unix)] + fn hook_signals( + &self, + signals: impl IntoIterator, + ) -> impl Stream + Send + 'static { + signal_hook_async_std::Signals::new(signals).unwrap() + } } impl SmolRuntime { @@ -105,6 +113,6 @@ impl SmolRuntime { impl From for Runtime { fn from(value: SmolRuntime) -> Self { - Runtime::new(value) + Runtime::from_trait_impl(value) } } diff --git a/smol/src/server/tcp.rs b/smol/src/server/tcp.rs index 3c3e07e104..3931d77343 100644 --- a/smol/src/server/tcp.rs +++ b/smol/src/server/tcp.rs @@ -1,8 +1,8 @@ use crate::{SmolRuntime, SmolTransport}; use async_net::{TcpListener, TcpStream}; -use std::{convert::TryInto, env, io::Result, net}; +use std::{convert::TryInto, io::Result, net}; use trillium::Info; -use trillium_server_common::{Server, Url}; +use trillium_server_common::Server; #[derive(Debug)] pub struct SmolTcpServer(TcpListener); @@ -16,29 +16,18 @@ impl Server for SmolTcpServer { type Runtime = SmolRuntime; type Transport = SmolTransport; - const DESCRIPTION: &'static str = concat!( - " (", - env!("CARGO_PKG_NAME"), - " v", - env!("CARGO_PKG_VERSION"), - ")" - ); - async fn accept(&mut self) -> Result { self.0.accept().await.map(|(t, _)| t.into()) } - fn listener_from_tcp(tcp: net::TcpListener) -> Self { + fn from_tcp(tcp: net::TcpListener) -> Self { Self(tcp.try_into().unwrap()) } - fn info(&self) -> Info { - let local_addr = self.0.local_addr().unwrap(); - let mut info = Info::from(local_addr); - if let Ok(url) = Url::parse(&format!("http://{local_addr}")) { - info.state_mut().insert(url); + fn init(&self, info: &mut Info) { + if let Ok(socket_addr) = self.0.local_addr() { + info.insert_state(socket_addr); } - info } fn runtime() -> Self::Runtime { diff --git a/smol/src/server/unix.rs b/smol/src/server/unix.rs index 6647213c56..d26bd94ee6 100644 --- a/smol/src/server/unix.rs +++ b/smol/src/server/unix.rs @@ -3,12 +3,11 @@ use async_net::{ unix::{UnixListener, UnixStream}, TcpListener, TcpStream, }; -use futures_lite::prelude::*; -use std::{env, io::Result}; +use std::io::Result; use trillium::{log_error, Info}; use trillium_server_common::{ Binding::{self, *}, - Server, Swansong, Url, + Server, }; #[derive(Debug, Clone)] @@ -29,35 +28,10 @@ impl Server for SmolServer { type Runtime = SmolRuntime; type Transport = Binding, SmolTransport>; - const DESCRIPTION: &'static str = concat!( - " (", - env!("CARGO_PKG_NAME"), - " v", - env!("CARGO_PKG_VERSION"), - ")" - ); - fn runtime() -> Self::Runtime { SmolRuntime::default() } - async fn handle_signals(swansong: Swansong) { - use signal_hook::consts::signal::*; - use signal_hook_async_std::Signals; - - let signals = Signals::new([SIGINT, SIGTERM, SIGQUIT]).unwrap(); - let mut signals = signals.fuse(); - while signals.next().await.is_some() { - if swansong.state().is_shutting_down() { - eprintln!("\nSecond interrupt, shutting down harshly"); - std::process::exit(1); - } else { - println!("\nShutting down gracefully.\nControl-C again to force."); - swansong.shut_down(); - } - } - } - async fn accept(&mut self) -> Result { match &self.0 { Tcp(t) => t.accept().await.map(|(t, _)| Tcp(SmolTransport::from(t))), @@ -65,25 +39,26 @@ impl Server for SmolServer { } } - fn listener_from_tcp(tcp: std::net::TcpListener) -> Self { + fn from_tcp(tcp: std::net::TcpListener) -> Self { Self(Tcp(tcp.try_into().unwrap())) } - fn listener_from_unix(tcp: std::os::unix::net::UnixListener) -> Self { + fn from_unix(tcp: std::os::unix::net::UnixListener) -> Self { Self(Unix(tcp.try_into().unwrap())) } - fn info(&self) -> Info { + fn init(&self, info: &mut Info) { match &self.0 { Tcp(t) => { - let local_addr = t.local_addr().unwrap(); - let mut info = Info::from(local_addr); - if let Ok(url) = Url::parse(&format!("http://{local_addr}")) { - info.state_mut().insert(url); + if let Ok(socket_addr) = t.local_addr() { + info.insert_state(socket_addr); + } + } + Unix(u) => { + if let Ok(socket_addr) = u.local_addr() { + info.insert_state(socket_addr); } - info } - Unix(u) => u.local_addr().unwrap().into(), } } diff --git a/static/src/handler.rs b/static/src/handler.rs index 76289a64ed..5a94f7fabe 100644 --- a/static/src/handler.rs +++ b/static/src/handler.rs @@ -78,7 +78,7 @@ impl StaticFileHandler { /// use trillium_testing::prelude::*; /// /// let mut handler = StaticFileHandler::new(crate_relative_path!("examples/files")); - /// # handler.init(&mut "testing".into()).await; + /// # init(&mut handler); /// /// assert_not_handled!(get("/").run_async(&handler).await); // no index file configured /// @@ -122,7 +122,7 @@ impl StaticFileHandler { /// /// let mut handler = StaticFileHandler::new(crate_relative_path!("examples/files")) /// .with_index_file("index.html"); - /// # handler.init(&mut "testing".into()).await; + /// # init(&mut handler); /// /// use trillium_testing::prelude::*; /// assert_ok!( diff --git a/static/src/lib.rs b/static/src/lib.rs index 19b3dc7ef6..9c333be47d 100644 --- a/static/src/lib.rs +++ b/static/src/lib.rs @@ -29,10 +29,9 @@ //! // └── subdir_with_no_index //! // └── plaintext.txt //! -//! use trillium::Handler; //! use trillium_testing::prelude::*; //! -//! handler.init(&mut "testing".into()).await; +//! init(&mut handler).await; //! //! assert_ok!( //! get("/").run_async(&handler).await, @@ -64,6 +63,8 @@ //! let plaintext_index = StaticFileHandler::new(crate_relative_path!("examples/files")) //! .with_index_file("plaintext.txt"); //! +//! init(&mut handler).await; +//! //! assert_not_handled!(get("/").run_async(&plaintext_index).await); //! assert_not_handled!(get("/subdir").run_async(&plaintext_index).await); //! assert_ok!( diff --git a/tera/src/tera_handler.rs b/tera/src/tera_handler.rs index d3470ebdcc..642174dd44 100644 --- a/tera/src/tera_handler.rs +++ b/tera/src/tera_handler.rs @@ -20,13 +20,13 @@ impl From<&str> for TeraHandler { impl From<&String> for TeraHandler { fn from(dir: &String) -> Self { - (**dir).into() + Tera::new(&dir).unwrap().into() } } impl From for TeraHandler { fn from(dir: String) -> Self { - dir.into() + Tera::new(&dir).unwrap().into() } } diff --git a/testing/Cargo.toml b/testing/Cargo.toml index 7ae2fb3c23..f2c229cf47 100644 --- a/testing/Cargo.toml +++ b/testing/Cargo.toml @@ -32,6 +32,7 @@ trillium-macros = { version = "0.0.6", path = "../macros" } dashmap = "5.5.3" once_cell = "1.19.0" fastrand = "2.0.1" +env_logger = "0.11.3" log = "0.4.21" [dependencies.trillium-smol] diff --git a/testing/src/lib.rs b/testing/src/lib.rs index 2af93c8f14..0954eafe7d 100644 --- a/testing/src/lib.rs +++ b/testing/src/lib.rs @@ -60,7 +60,7 @@ mod assertions; mod test_transport; -use std::{future::Future, process::Termination}; +use std::{future::Future, process::Termination, sync::Arc}; pub use test_transport::TestTransport; mod test_conn; @@ -76,7 +76,9 @@ pub mod prelude { pub use trillium::{Conn, Method, Status}; } +use trillium::{Handler, Info}; pub use trillium::{Method, Status}; +use trillium_http::ServerConfig; pub use url::Url; /// runs the future to completion on the current thread @@ -85,9 +87,12 @@ pub fn block_on(fut: Fut) -> Fut::Output { } /// initialize a handler -pub fn init(handler: &mut impl trillium::Handler) { - let mut info = "testing".into(); - block_on(async move { handler.init(&mut info).await }) +pub async fn init(handler: &mut impl Handler) -> Arc { + let mut info = Info::from(ServerConfig::default()); + info.insert_state(runtime()); + info.insert_state(runtime().into()); + handler.init(&mut info).await; + Arc::new(info.into()) } // these exports are used by macros @@ -185,5 +190,18 @@ where Fut: Future, Output: Termination, { + let _ = env_logger::builder().is_test(true).try_init(); block_on(test()) } + +/// a harness that includes the runtime +#[track_caller] +pub fn with_runtime(test: F) -> Output +where + F: FnOnce(Runtime) -> Fut, + Fut: Future, + Output: Termination, +{ + let runtime = runtime(); + runtime.clone().block_on(test(runtime.into())) +} diff --git a/testing/src/runtimeless/runtime.rs b/testing/src/runtimeless/runtime.rs index 98c50fc911..76643a2e63 100644 --- a/testing/src/runtimeless/runtime.rs +++ b/testing/src/runtimeless/runtime.rs @@ -51,7 +51,7 @@ impl RuntimeTrait for RuntimelessRuntime { } impl From for Runtime { fn from(value: RuntimelessRuntime) -> Self { - Runtime::new(value) + Runtime::from_trait_impl(value) } } impl RuntimelessRuntime { diff --git a/testing/src/runtimeless/server.rs b/testing/src/runtimeless/server.rs index 2b31dc046a..ef1c86d48e 100644 --- a/testing/src/runtimeless/server.rs +++ b/testing/src/runtimeless/server.rs @@ -3,7 +3,7 @@ use crate::{RuntimelessRuntime, TestTransport}; use async_channel::Receiver; use std::io::{Error, ErrorKind, Result}; use trillium::Info; -use trillium_server_common::{Acceptor, Config, ConfigExt, Server}; +use trillium_server_common::Server; use url::Url; /// A [`Server`] for testing that does not depend on any runtime @@ -30,8 +30,6 @@ impl Server for RuntimelessServer { type Runtime = RuntimelessRuntime; type Transport = TestTransport; - const DESCRIPTION: &'static str = "test server"; - fn runtime() -> Self::Runtime { RuntimelessRuntime::default() } @@ -43,29 +41,24 @@ impl Server for RuntimelessServer { .map_err(|e| Error::new(ErrorKind::Other, e.to_string())) } - fn build_listener(config: &Config) -> Self - where - A: Acceptor, - { - let mut port = config.port(); - let host = config.host(); + fn from_host_and_port(host: &str, mut port: u16) -> Self { if port == 0 { loop { port = fastrand::u16(..); - if !SERVERS.contains_key(&(host.clone(), port)) { + if !SERVERS.contains_key(&(host.to_string(), port)) { break; } } } let entry = SERVERS - .entry((host.clone(), port)) + .entry((host.to_string(), port)) .or_insert_with(async_channel::unbounded); let (_, channel) = entry.value(); Self { - host, + host: host.to_string(), channel: channel.clone(), port, } @@ -75,10 +68,7 @@ impl Server for RuntimelessServer { SERVERS.remove(&(self.host, self.port)); } - fn info(&self) -> Info { - let mut info = Info::from(&*format!("{}:{}", &self.host, &self.port)); - info.state_mut() - .insert(Url::parse(&format!("http://{}:{}", &self.host, self.port)).unwrap()); - info + fn init(&self, info: &mut Info) { + info.insert_state(Url::parse(&format!("http://{}:{}", &self.host, self.port)).unwrap()); } } diff --git a/testing/src/server_connector.rs b/testing/src/server_connector.rs index d4a66f0de6..d52e4ffbe8 100644 --- a/testing/src/server_connector.rs +++ b/testing/src/server_connector.rs @@ -1,7 +1,7 @@ use crate::{RuntimeType, TestTransport}; use std::{io, sync::Arc}; use trillium::Handler; -use trillium_http::Conn; +use trillium_http::ServerConfig; use trillium_server_common::Connector; use url::Url; @@ -10,6 +10,7 @@ use url::Url; pub struct ServerConnector { handler: Arc, runtime: RuntimeType, + server_config: Arc, } impl ServerConnector { @@ -18,27 +19,36 @@ impl ServerConnector { Self { handler: Arc::new(handler), runtime: RuntimeType::default(), + server_config: Arc::default(), } } + /// use a specific server config + pub fn with_server_config(mut self, server_config: ServerConfig) -> Self { + self.server_config = Arc::new(server_config); + self + } + /// opens a new connection to this virtual server, returning the client transport pub async fn connect(&self, secure: bool) -> TestTransport { let (client_transport, server_transport) = TestTransport::new(); let handler = Arc::clone(&self.handler); + let server_config = Arc::clone(&self.server_config); self.runtime.spawn(async move { - Conn::map(server_transport, Default::default(), |mut conn| { - let handler = Arc::clone(&handler); - async move { - conn.set_secure(secure); - let conn = handler.run(conn.into()).await; - let conn = handler.before_send(conn).await; - conn.into_inner() - } - }) - .await - .unwrap(); + server_config + .run(server_transport, |mut conn| { + let handler = Arc::clone(&handler); + async move { + conn.set_secure(secure); + let conn = handler.run(conn.into()).await; + let conn = handler.before_send(conn).await; + conn.into_inner() + } + }) + .await + .unwrap(); }); client_transport diff --git a/testing/src/test_conn.rs b/testing/src/test_conn.rs index 400bc3aceb..8ae80b376e 100644 --- a/testing/src/test_conn.rs +++ b/testing/src/test_conn.rs @@ -2,9 +2,10 @@ use std::{ fmt::Debug, net::IpAddr, ops::{Deref, DerefMut}, + sync::Arc, }; use trillium::{Conn, Handler, HeaderName, HeaderValues, Method}; -use trillium_http::{Conn as HttpConn, Synthetic}; +use trillium_http::{Conn as HttpConn, ServerConfig, Synthetic}; type SyntheticConn = HttpConn; @@ -31,6 +32,17 @@ impl TestConn { Self(HttpConn::new_synthetic(method.try_into().unwrap(), path.into(), body).into()) } + /// assigns a shared server config to this test conn + pub fn with_server_config(self, server_config: Arc) -> Self { + let inner = self + .0 + .into_inner::() + .with_server_config(server_config) + .into(); + + Self(inner) + } + /// chainable constructor to append a request header to the TestConn /// ``` /// use trillium_testing::TestConn; @@ -96,7 +108,7 @@ impl TestConn { /// use trillium_testing::prelude::*; /// /// async fn handler(conn: Conn) -> Conn { - /// conn.ok("hello trillium") + /// conn.ok("hello trillium") /// } /// /// let conn = get("/").run(&handler); @@ -113,12 +125,12 @@ impl TestConn { /// use trillium_testing::prelude::*; /// /// async fn handler(conn: Conn) -> Conn { - /// conn.ok("hello trillium") + /// conn.ok("hello trillium") /// } /// /// block_on(async move { - /// let conn = get("/").run_async(&handler).await; - /// assert_ok!(conn, "hello trillium", "content-length" => "14"); + /// let conn = get("/").run_async(&handler).await; + /// assert_ok!(conn, "hello trillium", "content-length" => "14"); /// }); /// ``` pub async fn run_async(self, handler: &impl Handler) -> Self { diff --git a/testing/src/with_server.rs b/testing/src/with_server.rs index 64f6c16a53..f1d3350813 100644 --- a/testing/src/with_server.rs +++ b/testing/src/with_server.rs @@ -24,7 +24,7 @@ where runtime.block_on(async move { let handle = config.spawn(handler); let info = handle.info().await; - let url = info.state().get::().cloned().unwrap_or_else(|| { + let url = info.state().cloned().unwrap_or_else(|| { let port = info.tcp_socket_addr().map(|t| t.port()).unwrap_or(0); format!("http://localhost:{port}").parse().unwrap() }); diff --git a/tokio/src/runtime.rs b/tokio/src/runtime.rs index da1130bacb..6e21cd2264 100644 --- a/tokio/src/runtime.rs +++ b/tokio/src/runtime.rs @@ -55,6 +55,14 @@ impl RuntimeTrait for TokioRuntime { Inner::Owned(runtime) => runtime.block_on(fut), } } + + #[cfg(unix)] + fn hook_signals( + &self, + signals: impl IntoIterator, + ) -> impl Stream + Send + 'static { + signal_hook_tokio::Signals::new(signals).unwrap() + } } impl TokioRuntime { @@ -112,6 +120,6 @@ impl TokioRuntime { impl From for Runtime { fn from(value: TokioRuntime) -> Self { - Runtime::new(value) + Runtime::from_trait_impl(value) } } diff --git a/tokio/src/server/tcp.rs b/tokio/src/server/tcp.rs index 7461932c86..29e4092967 100644 --- a/tokio/src/server/tcp.rs +++ b/tokio/src/server/tcp.rs @@ -19,14 +19,6 @@ impl Server for TokioServer { type Runtime = TokioRuntime; type Transport = TokioTransport>; - const DESCRIPTION: &'static str = concat!( - " (", - env!("CARGO_PKG_NAME"), - " v", - env!("CARGO_PKG_VERSION"), - ")" - ); - async fn accept(&mut self) -> io::Result { self.0 .accept() @@ -34,12 +26,14 @@ impl Server for TokioServer { .map(|(t, _)| TokioTransport(Compat::new(t))) } - fn listener_from_tcp(tcp: net::TcpListener) -> Self { + fn from_tcp(tcp: net::TcpListener) -> Self { Self(tcp.try_into().unwrap()) } - fn info(&self) -> Info { - self.0.local_addr().unwrap().into() + fn init(&self, info: &mut Info) { + if let Ok(socket_addr) = self.0.local_addr() { + info.insert_state(socket_addr); + } } fn runtime() -> Self::Runtime { diff --git a/tokio/src/server/unix.rs b/tokio/src/server/unix.rs index 674da9a4c3..a1208c8ad0 100644 --- a/tokio/src/server/unix.rs +++ b/tokio/src/server/unix.rs @@ -5,7 +5,7 @@ use tokio::net::{TcpListener, TcpStream, UnixListener, UnixStream}; use trillium::{log_error, Info}; use trillium_server_common::{ Binding::{self, *}, - Server, Swansong, + Server, }; /// Tcp/Unix Trillium server adapter for Tokio @@ -28,31 +28,6 @@ impl Server for TokioServer { type Runtime = TokioRuntime; type Transport = Binding>, TokioTransport>>; - const DESCRIPTION: &'static str = concat!( - " (", - env!("CARGO_PKG_NAME"), - " v", - env!("CARGO_PKG_VERSION"), - ")" - ); - - async fn handle_signals(swansong: Swansong) { - use signal_hook::consts::signal::*; - use signal_hook_tokio::Signals; - use tokio_stream::StreamExt; - let signals = Signals::new([SIGINT, SIGTERM, SIGQUIT]).unwrap(); - let mut signals = signals.fuse(); - while signals.next().await.is_some() { - if swansong.state().is_shutting_down() { - eprintln!("\nSecond interrupt, shutting down harshly"); - std::process::exit(1); - } else { - println!("\nShutting down gracefully.\nControl-C again to force."); - swansong.shut_down(); - } - } - } - async fn accept(&mut self) -> Result { match &mut self.0 { Tcp(t) => t @@ -67,19 +42,20 @@ impl Server for TokioServer { } } - fn info(&self) -> Info { + fn init(&self, info: &mut Info) { match &self.0 { - Tcp(t) => t.local_addr().unwrap().into(), - Unix(u) => (*format!("{:?}", u.local_addr().unwrap())).into(), - } - } - - fn listener_from_tcp(tcp: std::net::TcpListener) -> Self { - Self(Tcp(tcp.try_into().unwrap())) - } + Tcp(t) => { + if let Ok(socket_addr) = t.local_addr() { + info.insert_state(socket_addr); + } + } - fn listener_from_unix(unix: std::os::unix::net::UnixListener) -> Self { - Self(Unix(unix.try_into().unwrap())) + Unix(u) => { + if let Ok(socket_addr) = u.local_addr() { + info.insert_state(socket_addr); + } + } + } } async fn clean_up(self) { @@ -93,6 +69,14 @@ impl Server for TokioServer { } } + fn from_tcp(tcp_listener: std::net::TcpListener) -> Self { + TcpListener::from_std(tcp_listener).unwrap().into() + } + + fn from_unix(unix_listener: std::os::unix::net::UnixListener) -> Self { + UnixListener::from_std(unix_listener).unwrap().into() + } + fn runtime() -> Self::Runtime { TokioRuntime::default() } diff --git a/trillium/Cargo.toml b/trillium/Cargo.toml index 7c93497014..4605b426e7 100644 --- a/trillium/Cargo.toml +++ b/trillium/Cargo.toml @@ -25,10 +25,11 @@ futures-lite = "2.1.0" async-channel = "2.1.1" async-io = "2.3.1" fastrand = "2.0.1" -test-harness = "0.2.0" trillium-smol = { path = "../smol" } trillium-testing = { path = "../testing" } env_logger = "0.11.3" +trillium-client = { path = "../client" } +test-harness = "0.2.0" [package.metadata.cargo-udeps.ignore] development = ["trillium-testing"] diff --git a/trillium/examples/state.rs b/trillium/examples/state.rs index c3ef379429..6d91bb8475 100644 --- a/trillium/examples/state.rs +++ b/trillium/examples/state.rs @@ -37,13 +37,26 @@ mod conn_counter { } use conn_counter::{ConnCounterConnExt, ConnCounterHandler}; -use trillium::{Conn, Handler}; +use std::time::Instant; +use trillium::{Conn, Handler, Init}; + +struct ServerStart(Instant); fn handler() -> impl Handler { - (ConnCounterHandler::new(), |conn: Conn| async move { - let conn_number = conn.conn_number(); - conn.ok(format!("conn number was {conn_number}")) - }) + ( + Init::new(|info| async move { info.with_state(ServerStart(Instant::now())) }), + ConnCounterHandler::new(), + |conn: Conn| async move { + let uptime = conn + .shared_state() + .map(|ServerStart(instant)| instant.elapsed()) + .unwrap_or_default(); + let conn_number = conn.conn_number(); + conn.ok(format!( + "conn number was {conn_number}, server has been up {uptime:?}" + )) + }, + ) } fn main() { diff --git a/trillium/src/conn.rs b/trillium/src/conn.rs index 9aacc1cb61..d204f825bf 100644 --- a/trillium/src/conn.rs +++ b/trillium/src/conn.rs @@ -21,22 +21,22 @@ use trillium_http::{ /// that is named `with_{attribute}` will take ownership of the conn, set /// the attribute and return the conn, enabling chained calls like: /// -/// ```rust +/// ``` /// struct MyState(&'static str); /// async fn handler(mut conn: trillium::Conn) -> trillium::Conn { -/// conn.with_response_header("content-type", "text/plain") -/// .with_state(MyState("hello")) -/// .with_body("hey there") -/// .with_status(418) +/// conn.with_response_header("content-type", "text/plain") +/// .with_state(MyState("hello")) +/// .with_body("hey there") +/// .with_status(418) /// } /// /// use trillium_testing::prelude::*; /// /// assert_response!( -/// get("/").on(&handler), -/// Status::ImATeapot, -/// "hey there", -/// "content-type" => "text/plain" +/// get("/").on(&handler), +/// Status::ImATeapot, +/// "hey there", +/// "content-type" => "text/plain" /// ); /// ``` /// @@ -62,6 +62,7 @@ use trillium_http::{ /// so that application code can be written without transport /// generics. See [`Transport`](trillium_http::transport::Transport) for further /// reading on this. +/// pub struct Conn { inner: trillium_http::Conn, @@ -217,14 +218,6 @@ impl Conn { self.inner.state().get() } - /// Attempts to receive a &T from the shared state set - /// - /// Note that shared state may not currently be mutated after server start, so there is no - /// `shared_state_mut` or `shared_state_entry` - pub fn shared_state(&self) -> Option<&T> { - self.inner.shared_state().and_then(TypeSet::get) - } - /// Attempts to retrieve a &mut T from the state set pub fn state_mut(&mut self) -> Option<&mut T> { self.inner.state_mut().get_mut() @@ -252,14 +245,17 @@ impl Conn { self.inner.state_mut().take() } - /// Returns an [`Entry`] type that represents the presence or absence of a type in this state. - /// - /// Use this for chainable combinators like [`Entry::or_default`], [`Entry::or_insert`], - /// [`Entry::or_insert_with`], and [`Entry::and_modify`] as well as matching on it as an enum. + /// Returns an [`Entry`] for the state typeset that can be used with functions like + /// [`Entry::or_insert`], [`Entry::or_insert_with`], [`Entry::and_modify`], and others. pub fn state_entry(&mut self) -> Entry<'_, T> { self.inner.state_mut().entry() } + /// Attempts to borrow a T from the immutable shared state set + pub fn shared_state(&self) -> Option<&T> { + self.inner.shared_state().get() + } + /// Returns a [`ReceivedBody`] that references this `Conn`. The `Conn` /// retains all data and holds the singular transport, but the /// [`ReceivedBody`] provides an interface to read body content. @@ -327,6 +323,7 @@ impl Conn { /// /// assert_eq!(conn.method(), Method::Get); /// ``` + /// pub fn method(&self) -> Method { self.inner.method() } @@ -409,6 +406,7 @@ impl Conn { /// [`serde_qs`](https://docs.rs/serde_qs/) /// [`querystring`](https://docs.rs/querystring/) /// [`serde_querystring`](https://docs.rs/serde-querystring/latest/serde_querystring/) + /// pub fn querystring(&self) -> &str { self.inner.querystring() } @@ -457,7 +455,7 @@ impl Conn { self.inner.is_secure() } - /// The [`Instant`][std::time::Instant] that the first header bytes for this conn were + /// The [`Instant`] that the first header bytes for this conn were /// received, before any processing or parsing has been performed. pub fn start_time(&self) -> std::time::Instant { self.inner.start_time() @@ -494,7 +492,7 @@ impl Conn { /// /// # Panics /// - /// This will panic if you attempt to downcast to the wrong `Transport` type. + /// This will panic if you attempt to downcast to the wrong Transport type. pub fn into_inner(self) -> trillium_http::Conn { self.inner.map_transport(|t| { *t.downcast() diff --git a/trillium/src/handler.rs b/trillium/src/handler.rs index 260c3f6f16..cdb3216745 100644 --- a/trillium/src/handler.rs +++ b/trillium/src/handler.rs @@ -57,108 +57,72 @@ use std::{borrow::Cow, future::Future}; /// `run` is the only trait function that needs to be implemented. pub trait Handler: Send + Sync + 'static { - /// Executes this handler, performing any modifications to the - /// Conn that are desired. - fn run(&self, conn: Conn) -> impl Future + Send; - - /// Performs one-time async set up on a mutable borrow of the - /// Handler before the server starts accepting requests. This - /// allows a Handler to be defined in synchronous code but perform - /// async setup such as establishing a database connection or - /// fetching some state from an external source. This is optional, - /// and chances are high that you do not need this. - /// - /// It also receives a mutable borrow of the [`Info`] that represents - /// the current connection. + /// Executes this handler, performing any modifications to the Conn that are desired. + fn run(&self, conn: Conn) -> impl Future + Send { + async { conn } + } + + /// Performs one-time async set up on a mutable borrow of the Handler before the server starts + /// accepting requests. This allows a Handler to be defined in synchronous code but perform + /// async setup such as establishing a database connection or fetching some state from an + /// external source. This is optional, and chances are high that you do not need this. /// - /// **stability note:** This may go away at some point. Please open an - /// issue if you have a use case which requires it. - fn init(&mut self, _info: &mut Info) -> impl Future + Send { + /// It also receives a mutable borrow of the [`Info`] that represents the current connection. + fn init(&mut self, info: &mut Info) -> impl Future + Send { + let _ = info; std::future::ready(()) } - /// Performs any final modifications to this conn after all handlers - /// have been run. Although this is a slight deviation from the simple - /// conn->conn->conn chain represented by most Handlers, it provides - /// an easy way for libraries to effectively inject a second handler - /// into a response chain. This is useful for loggers that need to - /// record information both before and after other handlers have run, - /// as well as database transaction handlers and similar library code. + /// Performs any final modifications to this conn after all handlers have been run. Although + /// this is a slight deviation from the simple conn->conn->conn chain represented by most + /// Handlers, it provides an easy way for libraries to effectively inject a second handler into + /// a response chain. This is useful for loggers that need to record information both before and + /// after other handlers have run, as well as database transaction handlers and similar library + /// code. /// - /// ❗IMPORTANT NOTE FOR LIBRARY AUTHORS:** Please note that this - /// will run __whether or not the conn has was halted before - /// [`Handler::run`] was called on a given conn__. This means that if - /// you want to make your `before_send` callback conditional on - /// whether `run` was called, you need to put a unit type into the - /// conn's state and check for that. + /// **❗IMPORTANT NOTE FOR LIBRARY AUTHORS:** Please note that this will run __whether or not + /// the conn has was halted before [`Handler::run`] was called on a given conn__. This means + /// that if you want to make your `before_send` callback conditional on whether `run` was + /// called, you need to put a unit type into the conn's state and check for that. /// - /// stability note: I don't love this for the exact reason that it - /// breaks the simplicity of the conn->conn->model, but it is - /// currently the best compromise between that simplicity and - /// convenience for the application author, who should not have to add - /// two Handlers to achieve an "around" effect. + /// stability note: I don't love this for the exact reason that it breaks the simplicity of the + /// conn->conn->model, but it is currently the best compromise between that simplicity and + /// convenience for the application author, who should not have to add two Handlers to achieve + /// an "around" effect. fn before_send(&self, conn: Conn) -> impl Future + Send { std::future::ready(conn) } - /// predicate function answering the question of whether this Handler - /// would like to take ownership of the negotiated Upgrade. If this - /// returns true, you must implement [`Handler::upgrade`]. The first - /// handler that responds true to this will receive ownership of the - /// [`trillium::Upgrade`][crate::Upgrade] in a subsequent call to [`Handler::upgrade`] + /// predicate function answering the question of whether this Handler would like to take + /// ownership of the negotiated Upgrade. If this returns true, you must implement + /// [`Handler::upgrade`]. The first handler that responds true to this will receive + /// ownership of the [`trillium::Upgrade`][crate::Upgrade] in a subsequent call to + /// [`Handler::upgrade`] fn has_upgrade(&self, upgrade: &Upgrade) -> bool { let _ = upgrade; false } - /// This will only be called if the handler reponds true to - /// [`Handler::has_upgrade`] and will only be called once for this - /// upgrade. There is no return value, and this function takes - /// exclusive ownership of the underlying transport once this is - /// called. You can downcast the transport to whatever the source - /// transport type is and perform any non-http protocol communication - /// that has been negotiated. You probably don't want this unless - /// you're implementing something like websockets. Please note that - /// for many transports such as `TcpStreams`, dropping the transport - /// (and therefore the Upgrade) will hang up / disconnect. + /// This will only be called if the handler reponds true to [`Handler::has_upgrade`] and will + /// only be called once for this upgrade. There is no return value, and this function takes + /// exclusive ownership of the underlying transport once this is called. You can downcast + /// the transport to whatever the source transport type is and perform any non-http protocol + /// communication that has been negotiated. You probably don't want this unless you're + /// implementing something like websockets. Please note that for many transports such as + /// `TcpStreams`, dropping the transport (and therefore the Upgrade) will hang up / + /// disconnect. fn upgrade(&self, upgrade: Upgrade) -> impl Future + Send { let _ = upgrade; async { unimplemented!("if has_upgrade returns true, you must also implement upgrade") } } - /// Customize the name of your handler. This is used in Debug - /// implementations. The default is the type name of this handler. + /// Customize the name of your handler. This is used in Debug implementations. The default is + /// the type name of this handler. fn name(&self) -> Cow<'static, str> { std::any::type_name::().into() } } -// impl Handler for Box { -// async fn run(&self, conn: Conn) -> Conn { -// self.as_ref().run(conn).await -// } - -// async fn init(&mut self, info: &mut Info) { -// self.as_mut().init(info).await; -// } - -// async fn before_send(&self, conn: Conn) -> Conn { -// self.as_ref().before_send(conn).await -// } - -// fn name(&self) -> Cow<'static, str> { -// self.as_ref().name() -// } - -// fn has_upgrade(&self, upgrade: &Upgrade) -> bool { -// self.as_ref().has_upgrade(upgrade) -// } - -// async fn upgrade(&self, upgrade: Upgrade) { -// self.as_ref().upgrade(upgrade).await; -// } -// } - impl Handler for Status { async fn run(&self, conn: Conn) -> Conn { conn.with_status(*self) diff --git a/trillium/src/info.rs b/trillium/src/info.rs index 7506bf5eb5..c2659598dc 100644 --- a/trillium/src/info.rs +++ b/trillium/src/info.rs @@ -1,140 +1,78 @@ -use std::{ - fmt::{Display, Formatter, Result}, - net::SocketAddr, -}; -use trillium_http::TypeSet; - -const DEFAULT_SERVER_DESCRIPTION: &str = concat!("trillium v", env!("CARGO_PKG_VERSION")); +use std::net::SocketAddr; +use trillium_http::{type_set::entry::Entry, ServerConfig, Swansong, TypeSet}; /// This struct represents information about the currently connected /// server. /// /// It is passed to [`Handler::init`](crate::Handler::init). -#[derive(Debug)] -pub struct Info { - server_description: String, - listener_description: String, - tcp_socket_addr: Option, - state: TypeSet, -} - -impl Default for Info { - fn default() -> Self { - Self { - server_description: DEFAULT_SERVER_DESCRIPTION.into(), - listener_description: String::new(), - tcp_socket_addr: None, - state: TypeSet::new(), - } +#[derive(Debug, Default)] +pub struct Info(ServerConfig); +impl From for Info { + fn from(value: ServerConfig) -> Self { + Self(value) } } - -impl Display for Info { - fn fmt(&self, f: &mut Formatter<'_>) -> Result { - f.write_fmt(format_args!( - "{} listening on {}", - self.server_description(), - self.listener_description(), - )) +impl From for ServerConfig { + fn from(value: Info) -> Self { + value.0 } } -impl Info { - /// Returns a user-displayable description of the server. This - /// might be a string like "trillium x.y.z (trillium-tokio x.y.z)" or "my - /// special application". - pub fn server_description(&self) -> &str { - &self.server_description +impl AsRef for Info { + fn as_ref(&self) -> &TypeSet { + self.0.as_ref() } - - /// Returns a user-displayable string description of the location - /// or port the listener is bound to, potentially as a url. Do not - /// rely on the format of this string, as it will vary between - /// server implementations and is intended for user - /// display. Instead, use [`Info::tcp_socket_addr`] for any - /// processing. - pub fn listener_description(&self) -> &str { - &self.listener_description +} +impl AsMut for Info { + fn as_mut(&mut self) -> &mut TypeSet { + self.0.as_mut() } +} +impl Info { /// Returns the `local_addr` of a bound tcp listener, if such a /// thing exists for this server - pub const fn tcp_socket_addr(&self) -> Option<&SocketAddr> { - self.tcp_socket_addr.as_ref() - } - - /// obtain a mutable borrow of the server description, suitable - /// for appending information or replacing it - pub fn server_description_mut(&mut self) -> &mut String { - &mut self.server_description - } - - /// obtain a mutable borrow of the listener description, suitable - /// for appending information or replacing it - pub fn listener_description_mut(&mut self) -> &mut String { - &mut self.listener_description + pub fn tcp_socket_addr(&self) -> Option<&SocketAddr> { + self.state() } - /// borrow the [`TypeSet`] on this `Info`. This can be useful for passing initialization data - /// between handlers - #[allow(clippy::missing_const_for_fn)] // Info isn't useful in a const context - pub fn state(&self) -> &TypeSet { - &self.state + /// Returns the `local_addr` of a bound unix listener, if such a + /// thing exists for this server + #[cfg(unix)] + pub fn unix_socket_addr(&self) -> Option<&std::os::unix::net::SocketAddr> { + self.state() } - /// attempt to mutably borrow the [`TypeSet`] on this `Info`. - pub fn state_mut(&mut self) -> &mut TypeSet { - &mut self.state + /// Borrow a type from the shared state [`TypeSet`] on this `Info`. + pub fn state(&self) -> Option<&T> { + self.0.shared_state().get() } -} -impl AsRef for Info { - fn as_ref(&self) -> &TypeSet { - self.state() + /// Insert a type into the shared state typeset, returning the previous value if any + pub fn insert_state(&mut self, value: T) -> Option { + self.0.shared_state_mut().insert(value) } -} -impl AsMut for Info { - fn as_mut(&mut self) -> &mut TypeSet { - self.state_mut() + /// Mutate a type in the shared state typeset + pub fn state_mut(&mut self) -> Option<&mut T> { + self.0.shared_state_mut().get_mut() } -} -impl From<&str> for Info { - fn from(description: &str) -> Self { - Self { - server_description: String::from(DEFAULT_SERVER_DESCRIPTION), - listener_description: String::from(description), - ..Self::default() - } + /// Returns an [`Entry`] into the shared state typeset. + pub fn state_entry(&mut self) -> Entry<'_, T> { + self.0.shared_state_mut().entry() } -} - -impl From for Info { - fn from(socket_addr: SocketAddr) -> Self { - let mut info = Self { - server_description: String::from(DEFAULT_SERVER_DESCRIPTION), - listener_description: socket_addr.to_string(), - tcp_socket_addr: Some(socket_addr), - ..Self::default() - }; - info.state_mut().insert(socket_addr); - info + /// chainable interface to insert a type into the shared state typeset + #[must_use] + pub fn with_state(mut self, value: T) -> Self { + self.insert_state(value); + self } -} - -#[cfg(unix)] -impl From for Info { - fn from(s: std::os::unix::net::SocketAddr) -> Self { - let mut info = Self { - server_description: String::from(DEFAULT_SERVER_DESCRIPTION), - listener_description: format!("{s:?}"), - ..Self::default() - }; - info.state_mut().insert(s); - info + /// Borrow the [`Swansong`] graceful shutdown interface for this server + pub fn swansong(&self) -> &Swansong { + self.0.swansong() } } diff --git a/trillium/src/init.rs b/trillium/src/init.rs new file mode 100644 index 0000000000..cc96fe6cb8 --- /dev/null +++ b/trillium/src/init.rs @@ -0,0 +1,88 @@ +use crate::{Conn, Handler, Info}; +use std::{future::Future, mem}; + +/// Provides support for asynchronous initialization of a handler after +/// the server is started. +/// +/// ``` +/// use trillium::{Conn, Init, State}; +/// +/// #[derive(Debug, Clone)] +/// struct MyDatabaseConnection(String); +/// impl MyDatabaseConnection { +/// async fn connect(uri: &str) -> std::io::Result { +/// Ok(Self(uri.into())) +/// } +/// +/// async fn query(&self, query: &str) -> String { +/// format!("you queried `{}` against {}", query, &self.0) +/// } +/// } +/// +/// let mut handler = ( +/// Init::new(|mut info| async move { +/// let db = MyDatabaseConnection::connect("db://db").await.expect("1"); +/// info.with_state(db) +/// }), +/// |conn: Conn| async move { +/// dbg!(&conn); +/// let db = conn.shared_state::().expect("2"); +/// let response = db.query("select * from users limit 1").await; +/// conn.ok(response) +/// }, +/// ); +/// +/// use trillium_testing::prelude::*; +/// +/// block_on(async move { +/// let server_config = init(&mut handler).await; +/// assert_ok!( +/// get("/") +/// .with_server_config(server_config) +/// .run_async(&handler) +/// .await, +/// "you queried `select * from users limit 1` against db://db" +/// ); +/// }); +/// ``` +#[derive(Debug)] +pub struct Init(Option); + +impl Init +where + F: FnOnce(Info) -> Fut + Send + Sync + 'static, + Fut: Future + Send + 'static, +{ + /// Constructs a new Init handler with an async function that receives and returns [`Info`]. + #[must_use] + pub const fn new(init: F) -> Self { + Self(Some(init)) + } +} + +impl Handler for Init +where + F: FnOnce(Info) -> Fut + Send + Sync + 'static, + Fut: Future + Send + 'static, +{ + async fn run(&self, conn: Conn) -> Conn { + conn + } + + async fn init(&mut self, info: &mut Info) { + if let Some(init) = self.0.take() { + *info = init(mem::take(info)).await; + } else { + log::warn!("called init more than once"); + } + } +} + +/// alias for [`Init::new`] +pub const fn init(init: F) -> Init +where + F: FnOnce(Info) -> Fut + Send + Sync + 'static, + Fut: Future + Send + 'static, +{ + Init::new(init) +} diff --git a/trillium/src/lib.rs b/trillium/src/lib.rs index 36d27a9034..cbe43889d6 100644 --- a/trillium/src/lib.rs +++ b/trillium/src/lib.rs @@ -54,3 +54,6 @@ pub use info::Info; mod boxed_handler; pub use boxed_handler::BoxedHandler; + +mod init; +pub use init::{init, Init}; diff --git a/trillium/src/shared_state.rs b/trillium/src/shared_state.rs new file mode 100644 index 0000000000..a3adfa93a2 --- /dev/null +++ b/trillium/src/shared_state.rs @@ -0,0 +1,36 @@ +use crate::{Handler, Info}; +use std::{any::type_name, borrow::Cow}; + +/// This handler populates a type into the immutable server-shared state type-set. Note that unlike +/// [`State`], this handler does not require [`Clone`], as the single allocation provided to the +/// constructor is held in an Arc and shared with every Conn. +/// +#[derive(Debug)] +pub struct SharedState(Option); +impl SharedState +where + T: Send + Sync + 'static, +{ + /// Constructs a new State handler from any `Clone` + `Send` + `Sync` + + /// `'static` + pub const fn new(t: T) -> Self { + Self(Some(t)) + } +} + +/// Constructs a new [`SharedState`] handler from any Send + Sync + +/// 'static. Alias for [`SharedState::new`] +#[allow(clippy::missing_const_for_fn)] +pub fn shared_state(t: T) -> SharedState { + SharedState::new(t) +} + +impl Handler for SharedState { + async fn init(&mut self, info: &mut Info) { + info.insert_state(self.0.take().unwrap()); + } + + fn name(&self) -> Cow<'static, str> { + format!("SharedState<{}>", type_name::()).into() + } +} diff --git a/trillium/src/state.rs b/trillium/src/state.rs index 3a5f293807..8b6a13dea8 100644 --- a/trillium/src/state.rs +++ b/trillium/src/state.rs @@ -1,5 +1,5 @@ use crate::{Conn, Handler}; -use std::fmt::{self, Debug, Formatter}; +use std::fmt::Debug; /// # A handler for sharing state across an application. /// @@ -52,27 +52,8 @@ use std::fmt::{self, Debug, Formatter}; /// for your application. There will be one clones of the contained T type /// in memory for each http connection, and any locks should be held as /// briefly as possible so as to minimize impact on other conns. -/// -/// **Stability note:** This is a common enough pattern that it currently -/// exists in the public api, but may be removed at some point for -/// simplicity. - -pub struct State(T); - -impl Debug for State { - fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result { - f.debug_tuple("State").field(&self.0).finish() - } -} - -impl Default for State -where - T: Default + Clone + Send + Sync + 'static, -{ - fn default() -> Self { - Self::new(T::default()) - } -} +#[derive(Debug, Default)] +pub struct State(T); impl State where @@ -80,16 +61,14 @@ where { /// Constructs a new State handler from any `Clone` + `Send` + `Sync` + /// `'static` - #[allow(clippy::missing_const_for_fn)] - pub fn new(t: T) -> Self { + pub const fn new(t: T) -> Self { Self(t) } } /// Constructs a new [`State`] handler from any Clone + Send + Sync + /// 'static. Alias for [`State::new`] -#[allow(clippy::missing_const_for_fn)] -pub fn state(t: T) -> State { +pub const fn state(t: T) -> State { State::new(t) } diff --git a/trillium/tests/init.rs b/trillium/tests/init.rs new file mode 100644 index 0000000000..bdef474c20 --- /dev/null +++ b/trillium/tests/init.rs @@ -0,0 +1,56 @@ +use std::io; +use test_harness::test; +use trillium::Handler; +use trillium_client::Client; +use trillium_http::ServerConfig; +use trillium_testing::{harness, ServerConnector, TestResult}; + +async fn test_client(mut handler: impl Handler) -> Client { + let mut info = ServerConfig::default().into(); + handler.init(&mut info).await; + let connector = ServerConnector::new(handler).with_server_config(info.into()); + Client::new(connector).with_base("http://test.host") +} + +#[test(harness)] +async fn init_doctest() -> TestResult { + use trillium::{Conn, Init}; + + #[derive(Debug, Clone)] + struct MyDatabaseConnection(&'static str); + impl MyDatabaseConnection { + async fn connect(uri: &'static str) -> io::Result { + Ok(Self(uri)) + } + + async fn query(&self, query: &str) -> String { + format!("you queried `{}` against {}", query, &self.0) + } + } + + let client = test_client(( + Init::new(|info| async move { + let db = MyDatabaseConnection::connect("mydatabase://...") + .await + .unwrap(); + info.with_state(db) + }), + |conn: Conn| async move { + let Some(db) = conn.shared_state::() else { + return conn.with_status(500); + }; + let response = db.query("select * from users limit 1").await; + conn.ok(response) + }, + )) + .await; + + let mut conn = client.get("/").await?; + + assert_eq!( + conn.response_body().read_string().await?, + "you queried `select * from users limit 1` against mydatabase://..." + ); + + Ok(()) +} diff --git a/trillium/tests/liveness.rs b/trillium/tests/liveness.rs index f6312f2313..5bb752ce48 100644 --- a/trillium/tests/liveness.rs +++ b/trillium/tests/liveness.rs @@ -23,11 +23,8 @@ async fn infinitely_pending_task() -> TestResult { }); let info = handle.info().await; - - let url = format!("http://{}", info.listener_description()) - .parse() - .unwrap(); - let mut client = connector.connect(&url).await?; + let url = info.url().unwrap(); + let mut client = connector.connect(url).await?; client .write_all(b"GET / HTTP/1.1\r\nHost: example.com\r\n\r\n") @@ -48,7 +45,6 @@ async fn infinitely_pending_task() -> TestResult { #[test(harness)] async fn is_disconnected() -> TestResult { - let _ = env_logger::builder().is_test(true).try_init(); let connector = ArcedConnector::new(client_config()); let (delay_sender, delay_receiver) = async_channel::unbounded(); let (disconnected_sender, disconnected_receiver) = async_channel::unbounded(); @@ -71,10 +67,8 @@ async fn is_disconnected() -> TestResult { let info = handle.info().await; let runtime = handle.runtime(); - let url = format!("http://{}", info.listener_description()) - .parse() - .unwrap(); - let mut client = connector.connect(&url).await?; + let url = info.url().unwrap(); + let mut client = connector.connect(url).await?; client .write_all(b"GET / HTTP/1.1\r\nHost: example.com\r\n\r\n") @@ -88,7 +82,7 @@ async fn is_disconnected() -> TestResult { assert!(s.starts_with("HTTP/1.1 200 OK\r\n")); client.close().await?; - let mut client = connector.connect(&url).await?; + let mut client = connector.connect(url).await?; client .write_all(b"GET / HTTP/1.1\r\nHost: example.com\r\n\r\n") .await?; diff --git a/websockets/src/bidirectional_stream.rs b/websockets/src/bidirectional_stream.rs index 1ec60c0b8d..9ca77f6b57 100644 --- a/websockets/src/bidirectional_stream.rs +++ b/websockets/src/bidirectional_stream.rs @@ -23,7 +23,7 @@ impl Debug for BidirectionalStream { None => "None", }, ) - .field("outbound", &"..") + .field("outbound", &format_args!("..")) .finish() } } diff --git a/websockets/src/websocket_connection.rs b/websockets/src/websocket_connection.rs index 780f2d584a..4c7b8f982a 100644 --- a/websockets/src/websocket_connection.rs +++ b/websockets/src/websocket_connection.rs @@ -10,11 +10,12 @@ use futures_util::{ use std::{ net::IpAddr, pin::Pin, + sync::Arc, task::{Context, Poll}, }; use swansong::{Interrupt, Swansong}; use trillium::{Headers, Method, TypeSet, Upgrade}; -use trillium_http::{transport::BoxedTransport, type_set::entry::Entry}; +use trillium_http::{transport::BoxedTransport, type_set::entry::Entry, ServerConfig}; /// A struct that represents an specific websocket connection. /// @@ -32,7 +33,7 @@ pub struct WebSocketConn { method: Method, state: TypeSet, peer_ip: Option, - swansong: Swansong, + server_config: Arc, sink: SplitSink, stream: Option, } @@ -75,7 +76,7 @@ impl WebSocketConn { state, buffer, transport, - swansong, + server_config, peer_ip, .. } = upgrade; @@ -88,7 +89,7 @@ impl WebSocketConn { let (sink, stream) = wss.split(); let stream = Some(WStream { - stream: swansong.interrupt(stream), + stream: server_config.swansong().interrupt(stream), }); Self { @@ -99,13 +100,13 @@ impl WebSocketConn { peer_ip, sink, stream, - swansong, + server_config, } } /// retrieve a clone of the server's [`Swansong`] pub fn swansong(&self) -> Swansong { - self.swansong.clone() + self.server_config.swansong().clone() } /// close the websocket connection gracefully @@ -168,6 +169,12 @@ impl WebSocketConn { self.state.insert(state) } + /// Returns an [`Entry`] for the state typeset that can be used with functions like + /// [`Entry::or_insert`], [`Entry::or_insert_with`], [`Entry::and_modify`], and others. + pub fn state_entry(&mut self) -> Entry<'_, T> { + self.state.entry() + } + /// take some type T out of the state set that has been /// accumulated by trillium handlers run on the [`trillium::Conn`] /// before it became a websocket. see [`trillium::Conn::take_state`] @@ -176,12 +183,6 @@ impl WebSocketConn { self.state.take() } - /// Returns an [`Entry`] for the state typeset that can be used with functions like - /// [`Entry::or_insert`], [`Entry::or_insert_with`], [`Entry::and_modify`], and others. - pub fn state_entry(&mut self) -> Entry<'_, T> { - self.state.entry() - } - /// take the inbound Message stream from this conn pub fn take_inbound_stream(&mut self) -> Option> { self.stream.take()