diff --git a/src/cookies/middleware.rs b/src/cookies/middleware.rs index 7220e731a..633b35acb 100644 --- a/src/cookies/middleware.rs +++ b/src/cookies/middleware.rs @@ -37,7 +37,7 @@ impl CookiesMiddleware { impl Middleware for CookiesMiddleware { fn handle<'a>( &'a self, - mut ctx: Request, + mut ctx: Request, next: Next<'a, State>, ) -> BoxFuture<'a, crate::Result> { Box::pin(async move { @@ -117,7 +117,7 @@ impl LazyJar { } impl CookieData { - pub(crate) fn from_request(req: &Request) -> Self { + pub(crate) fn from_request(req: &Request) -> Self { let jar = if let Some(cookie_headers) = req.header(&headers::COOKIE) { let mut jar = CookieJar::new(); for cookie_header in cookie_headers { diff --git a/src/endpoint.rs b/src/endpoint.rs index f97b9da17..0971505cb 100644 --- a/src/endpoint.rs +++ b/src/endpoint.rs @@ -46,7 +46,7 @@ use crate::{Middleware, Request, Response}; /// Tide routes will also accept endpoints with `Fn` signatures of this form, but using the `async` keyword has better ergonomics. pub trait Endpoint: Send + Sync + 'static { /// Invoke the endpoint within the given context - fn call<'a>(&'a self, req: Request) -> BoxFuture<'a, crate::Result>; + fn call<'a>(&'a self, req: Request, state: State) -> BoxFuture<'a, crate::Result>; } pub(crate) type DynEndpoint = dyn Endpoint; @@ -54,11 +54,11 @@ pub(crate) type DynEndpoint = dyn Endpoint; impl Endpoint for F where State: Send + Sync + 'static, - F: Send + Sync + 'static + Fn(Request) -> Fut, + F: Send + Sync + 'static + Fn(Request) -> Fut, Fut: Future> + Send + 'static, Res: Into, { - fn call<'a>(&'a self, req: Request) -> BoxFuture<'a, crate::Result> { + fn call<'a>(&'a self, req: Request, _: State) -> BoxFuture<'a, crate::Result> { let fut = (self)(req); Box::pin(async move { let res = fut.await?; @@ -67,6 +67,22 @@ where } } +impl Endpoint for F +where + State: Send + Sync + 'static, + F: Send + Sync + 'static + Fn(Request, State) -> Fut, + Fut: Future> + Send + 'static, + Res: Into, +{ + fn call<'a>(&'a self, req: Request, state: State) -> BoxFuture<'a, crate::Result> { + let fut = (self)(req, state); + Box::pin(async move { + let res = fut.await?; + Ok(res.into()) + }) + } +} + pub struct MiddlewareEndpoint { endpoint: E, middleware: Vec>>, @@ -109,7 +125,7 @@ where State: Send + Sync + 'static, E: Endpoint, { - fn call<'a>(&'a self, req: Request) -> BoxFuture<'a, crate::Result> { + fn call<'a>(&'a self, req: Request, _: State) -> BoxFuture<'a, crate::Result> { let next = Next { endpoint: &self.endpoint, next_middleware: &self.middleware, diff --git a/src/fs/serve_dir.rs b/src/fs/serve_dir.rs index 953f07529..242fe0fa4 100644 --- a/src/fs/serve_dir.rs +++ b/src/fs/serve_dir.rs @@ -23,7 +23,7 @@ impl Endpoint for ServeDir where State: Send + Sync + 'static, { - fn call<'a>(&'a self, req: Request) -> BoxFuture<'a, Result> { + fn call<'a>(&'a self, req: Request) -> BoxFuture<'a, Result> { let path = req.url().path(); let path = path.trim_start_matches(&self.prefix); let path = path.trim_start_matches('/'); @@ -81,7 +81,7 @@ mod test { }) } - fn request(path: &str) -> crate::Request<()> { + fn request(path: &str) -> crate::Request { let request = crate::http::Request::get( crate::http::Url::parse(&format!("http://localhost/{}", path)).unwrap(), ); diff --git a/src/log/middleware.rs b/src/log/middleware.rs index 5b8d787c7..a544cf8b2 100644 --- a/src/log/middleware.rs +++ b/src/log/middleware.rs @@ -27,7 +27,7 @@ impl LogMiddleware { /// Log a request and a response. async fn log<'a, State: Send + Sync + 'static>( &'a self, - ctx: Request, + ctx: Request, next: Next<'a, State>, ) -> crate::Result { let path = ctx.url().path().to_owned(); @@ -78,7 +78,7 @@ impl LogMiddleware { impl Middleware for LogMiddleware { fn handle<'a>( &'a self, - ctx: Request, + ctx: Request, next: Next<'a, State>, ) -> BoxFuture<'a, crate::Result> { Box::pin(async move { self.log(ctx, next).await }) diff --git a/src/middleware.rs b/src/middleware.rs index 138a9dd31..708b7ef32 100644 --- a/src/middleware.rs +++ b/src/middleware.rs @@ -11,7 +11,7 @@ pub trait Middleware: Send + Sync + 'static { /// Asynchronously handle the request, and return a response. fn handle<'a>( &'a self, - request: Request, + request: Request, next: Next<'a, State>, ) -> BoxFuture<'a, crate::Result>; @@ -26,17 +26,35 @@ where F: Send + Sync + 'static - + for<'a> Fn(Request, Next<'a, State>) -> BoxFuture<'a, crate::Result>, + + for<'a> Fn(Request, Next<'a, State>) -> BoxFuture<'a, crate::Result>, { fn handle<'a>( &'a self, - req: Request, + req: Request, + _: State, next: Next<'a, State>, ) -> BoxFuture<'a, crate::Result> { (self)(req, next) } } +impl Middleware for F +where + F: Send + + Sync + + 'static + + for<'a> Fn(Request, State, Next<'a, State>) -> BoxFuture<'a, crate::Result>, +{ + fn handle<'a>( + &'a self, + req: Request, + state: State, + next: Next<'a, State>, + ) -> BoxFuture<'a, crate::Result> { + (self)(req, state, next) + } +} + /// The remainder of a middleware chain, including the endpoint. #[allow(missing_debug_implementations)] pub struct Next<'a, State> { @@ -47,7 +65,7 @@ pub struct Next<'a, State> { impl<'a, State: Send + Sync + 'static> Next<'a, State> { /// Asynchronously execute the remaining middleware chain. #[must_use] - pub fn run(mut self, req: Request) -> BoxFuture<'a, Response> { + pub fn run(mut self, req: Request) -> BoxFuture<'a, Response> { Box::pin(async move { if let Some((current, next)) = self.next_middleware.split_first() { self.next_middleware = next; diff --git a/src/redirect.rs b/src/redirect.rs index e83301686..49bacc5bc 100644 --- a/src/redirect.rs +++ b/src/redirect.rs @@ -91,7 +91,7 @@ where State: Send + Sync + 'static, T: AsRef + Send + Sync + 'static, { - fn call<'a>(&'a self, _req: Request) -> BoxFuture<'a, crate::Result> { + fn call<'a>(&'a self, _req: Request) -> BoxFuture<'a, crate::Result> { let res = self.into(); Box::pin(async move { Ok(res) }) } diff --git a/src/request.rs b/src/request.rs index 332121c82..7ed34ae49 100644 --- a/src/request.rs +++ b/src/request.rs @@ -4,7 +4,7 @@ use route_recognizer::Params; use std::ops::Index; use std::pin::Pin; -use std::{fmt, str::FromStr, sync::Arc}; +use std::{fmt, str::FromStr}; use crate::cookies::CookieData; use crate::http::cookies::Cookie; @@ -20,8 +20,7 @@ use crate::Response; /// Requests also provide *extensions*, a type map primarily used for low-level /// communication between middleware and endpoints. #[derive(Debug)] -pub struct Request { - pub(crate) state: Arc, +pub struct Request { pub(crate) req: http::Request, pub(crate) route_params: Vec, } @@ -43,15 +42,13 @@ impl fmt::Display for ParamError { impl std::error::Error for ParamError {} -impl Request { +impl Request { /// Create a new `Request`. pub(crate) fn new( - state: Arc, req: http_types::Request, route_params: Vec, ) -> Self { Self { - state, req, route_params, } @@ -266,12 +263,6 @@ impl Request { self.req.ext_mut().insert(val) } - #[must_use] - /// Access application scoped state. - pub fn state(&self) -> &State { - &self.state - } - /// Extract and parse a route parameter by name. /// /// Returns the results of parsing the parameter according to the inferred @@ -524,31 +515,31 @@ impl Request { } } -impl AsRef for Request { +impl AsRef for Request { fn as_ref(&self) -> &http::Request { &self.req } } -impl AsMut for Request { +impl AsMut for Request { fn as_mut(&mut self) -> &mut http::Request { &mut self.req } } -impl AsRef for Request { +impl AsRef for Request { fn as_ref(&self) -> &http::Headers { self.req.as_ref() } } -impl AsMut for Request { +impl AsMut for Request { fn as_mut(&mut self) -> &mut http::Headers { self.req.as_mut() } } -impl Read for Request { +impl Read for Request { fn poll_read( mut self: Pin<&mut Self>, cx: &mut Context<'_>, @@ -558,7 +549,7 @@ impl Read for Request { } } -impl Into for Request { +impl Into for Request { fn into(self) -> http::Request { self.req } @@ -566,7 +557,7 @@ impl Into for Request { // NOTE: From cannot be implemented for this conversion because `State` needs to // be constrained by a type. -impl Into for Request { +impl Into for Request { fn into(mut self) -> Response { let mut res = Response::new(StatusCode::Ok); res.set_body(self.take_body()); @@ -574,7 +565,7 @@ impl Into for Request { } } -impl IntoIterator for Request { +impl IntoIterator for Request { type Item = (HeaderName, HeaderValues); type IntoIter = http_types::headers::IntoIter; @@ -585,7 +576,7 @@ impl IntoIterator for Request { } } -impl<'a, State> IntoIterator for &'a Request { +impl<'a> IntoIterator for &'a Request { type Item = (&'a HeaderName, &'a HeaderValues); type IntoIter = http_types::headers::Iter<'a>; @@ -595,7 +586,7 @@ impl<'a, State> IntoIterator for &'a Request { } } -impl<'a, State> IntoIterator for &'a mut Request { +impl<'a> IntoIterator for &'a mut Request { type Item = (&'a HeaderName, &'a mut HeaderValues); type IntoIter = http_types::headers::IterMut<'a>; @@ -605,7 +596,7 @@ impl<'a, State> IntoIterator for &'a mut Request { } } -impl Index for Request { +impl Index for Request { type Output = HeaderValues; /// Returns a reference to the value corresponding to the supplied name. @@ -619,7 +610,7 @@ impl Index for Request { } } -impl Index<&str> for Request { +impl Index<&str> for Request { type Output = HeaderValues; /// Returns a reference to the value corresponding to the supplied name. diff --git a/src/route.rs b/src/route.rs index c1df54f0c..db99e8054 100644 --- a/src/route.rs +++ b/src/route.rs @@ -279,9 +279,8 @@ where State: Send + Sync + 'static, E: Endpoint, { - fn call<'a>(&'a self, req: crate::Request) -> BoxFuture<'a, crate::Result> { + fn call<'a>(&'a self, req: crate::Request, _: State) -> BoxFuture<'a, crate::Result> { let crate::Request { - state, mut req, route_params, } = req; @@ -290,7 +289,6 @@ where req.url_mut().set_path(&rest); self.0.call(crate::Request { - state, req, route_params, }) diff --git a/src/router.rs b/src/router.rs index 3b04d2a0b..ddbcfb12c 100644 --- a/src/router.rs +++ b/src/router.rs @@ -82,14 +82,14 @@ impl Router { } } -fn not_found_endpoint( - _req: Request, +fn not_found_endpoint( + _req: Request, ) -> BoxFuture<'static, crate::Result> { Box::pin(async { Ok(Response::new(StatusCode::NotFound)) }) } -fn method_not_allowed( - _req: Request, +fn method_not_allowed( + _req: Request, ) -> BoxFuture<'static, crate::Result> { Box::pin(async { Ok(Response::new(StatusCode::MethodNotAllowed)) }) } diff --git a/src/security/cors.rs b/src/security/cors.rs index 50f4a06d3..29ddf62d3 100644 --- a/src/security/cors.rs +++ b/src/security/cors.rs @@ -134,7 +134,7 @@ impl CorsMiddleware { } impl Middleware for CorsMiddleware { - fn handle<'a>(&'a self, req: Request, next: Next<'a, State>) -> BoxFuture<'a, Result> { + fn handle<'a>(&'a self, req: Request, next: Next<'a, State>) -> BoxFuture<'a, Result> { Box::pin(async move { // TODO: how should multiple origin values be handled? let origins = req.header(&headers::ORIGIN).cloned(); diff --git a/src/server.rs b/src/server.rs index e5f52c208..77380521f 100644 --- a/src/server.rs +++ b/src/server.rs @@ -442,7 +442,7 @@ impl Clone for Server { impl Endpoint for Server { - fn call<'a>(&'a self, req: Request) -> BoxFuture<'a, crate::Result> { + fn call<'a>(&'a self, req: Request) -> BoxFuture<'a, crate::Result> { let Request { req, mut route_params, diff --git a/src/sse/endpoint.rs b/src/sse/endpoint.rs index dac364490..d957c4ff8 100644 --- a/src/sse/endpoint.rs +++ b/src/sse/endpoint.rs @@ -15,7 +15,7 @@ use std::sync::Arc; pub fn endpoint(handler: F) -> SseEndpoint where State: Send + Sync + 'static, - F: Fn(Request, Sender) -> Fut + Send + Sync + 'static, + F: Fn(Request, Sender) -> Fut + Send + Sync + 'static, Fut: Future> + Send + Sync + 'static, { SseEndpoint { @@ -30,7 +30,7 @@ where pub struct SseEndpoint where State: Send + Sync + 'static, - F: Fn(Request, Sender) -> Fut + Send + Sync + 'static, + F: Fn(Request, Sender) -> Fut + Send + Sync + 'static, Fut: Future> + Send + Sync + 'static, { handler: Arc, @@ -41,10 +41,10 @@ where impl Endpoint for SseEndpoint where State: Send + Sync + 'static, - F: Fn(Request, Sender) -> Fut + Send + Sync + 'static, + F: Fn(Request, Sender) -> Fut + Send + Sync + 'static, Fut: Future> + Send + Sync + 'static, { - fn call<'a>(&'a self, req: Request) -> BoxFuture<'a, Result> { + fn call<'a>(&'a self, req: Request) -> BoxFuture<'a, Result> { let handler = self.handler.clone(); Box::pin(async move { let (sender, encoder) = async_sse::encode(); diff --git a/src/sse/upgrade.rs b/src/sse/upgrade.rs index 9eff4b864..74667bcc6 100644 --- a/src/sse/upgrade.rs +++ b/src/sse/upgrade.rs @@ -9,10 +9,10 @@ use async_std::io::BufReader; use async_std::task; /// Upgrade an existing HTTP connection to an SSE connection. -pub fn upgrade(req: Request, handler: F) -> Response +pub fn upgrade(req: Request, handler: F) -> Response where State: Send + Sync + 'static, - F: Fn(Request, Sender) -> Fut + Send + Sync + 'static, + F: Fn(Request, Sender) -> Fut + Send + Sync + 'static, Fut: Future> + Send + Sync + 'static, { let (sender, encoder) = async_sse::encode(); diff --git a/src/utils.rs b/src/utils.rs index 0a43b6816..5b0577968 100644 --- a/src/utils.rs +++ b/src/utils.rs @@ -31,12 +31,12 @@ pub struct Before(pub F); impl Middleware for Before where State: Send + Sync + 'static, - F: Fn(Request) -> Fut + Send + Sync + 'static, - Fut: std::future::Future> + Send + Sync, + F: Fn(Request) -> Fut + Send + Sync + 'static, + Fut: std::future::Future + Send + Sync, { fn handle<'a>( &'a self, - request: Request, + request: Request, next: Next<'a, State>, ) -> BoxFuture<'a, crate::Result> { Box::pin(async move { @@ -45,6 +45,24 @@ where }) } } +impl Middleware for Before +where + State: Send + Sync + 'static, + F: Fn(Request, State) -> Fut + Send + Sync + 'static, + Fut: std::future::Future + Send + Sync, +{ + fn handle<'a>( + &'a self, + request: Request, + state: State, + next: Next<'a, State>, + ) -> BoxFuture<'a, crate::Result> { + Box::pin(async move { + let request = (self.0)(request).await; + Ok(next.run(request, state).await) + }) + } +} /// Define a middleware that operates on outgoing responses. /// @@ -75,7 +93,7 @@ where { fn handle<'a>( &'a self, - request: Request, + request: Request, next: Next<'a, State>, ) -> BoxFuture<'a, crate::Result> { Box::pin(async move { @@ -84,3 +102,21 @@ where }) } } +impl Middleware for After +where + State: Send + Sync + 'static, + F: Fn(Response, State) -> Fut + Send + Sync + 'static, + Fut: std::future::Future + Send + Sync, +{ + fn handle<'a>( + &'a self, + request: Request, + state: State, + next: Next<'a, State>, + ) -> BoxFuture<'a, crate::Result> { + Box::pin(async move { + let response = next.run(request).await; + (self.0)(response, state).await + }) + } +}