Skip to content

Commit

Permalink
Server: require State to be Clone
Browse files Browse the repository at this point in the history
Alternative to
#642

This approach is more flexible but requires the user ensure that their
state implements/derives `Clone`, or is wrapped in an `Arc`.

Co-authored-by: Jacob Rothstein <[email protected]>
  • Loading branch information
Fishrock123 and jbr committed Jul 16, 2020
1 parent be1c3a9 commit 50bc628
Show file tree
Hide file tree
Showing 27 changed files with 103 additions and 92 deletions.
1 change: 1 addition & 0 deletions Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,7 @@ route-recognizer = "0.2.0"
logtest = "2.0.0"
async-trait = "0.1.36"
futures-util = "0.3.5"
pin-project-lite = "0.1.7"

[dev-dependencies]
async-std = { version = "1.6.0", features = ["unstable", "attributes"] }
Expand Down
7 changes: 5 additions & 2 deletions examples/graphql.rs
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
use std::sync::Arc;

use async_std::task;
use juniper::{http::graphiql, http::GraphQLRequest, RootNode};
use std::sync::RwLock;
Expand Down Expand Up @@ -37,8 +39,9 @@ impl NewUser {
}
}

#[derive(Clone)]
pub struct State {
users: RwLock<Vec<User>>,
users: Arc<RwLock<Vec<User>>>,
}
impl juniper::Context for State {}

Expand Down Expand Up @@ -96,7 +99,7 @@ async fn handle_graphiql(_: Request<State>) -> tide::Result<impl Into<Response>>
fn main() -> std::io::Result<()> {
task::block_on(async {
let mut app = Server::with_state(State {
users: RwLock::new(Vec::new()),
users: Arc::new(RwLock::new(Vec::new())),
});
app.at("/").get(Redirect::permanent("/graphiql"));
app.at("/graphql").post(handle_graphql);
Expand Down
4 changes: 2 additions & 2 deletions examples/middleware.rs
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@ struct User {
name: String,
}

#[derive(Default, Debug)]
#[derive(Clone, Default, Debug)]
struct UserDatabase;
impl UserDatabase {
async fn find_user(&self) -> Option<User> {
Expand Down Expand Up @@ -62,7 +62,7 @@ impl RequestCounterMiddleware {
struct RequestCount(usize);

#[tide::utils::async_trait]
impl<State: Send + Sync + 'static> Middleware<State> for RequestCounterMiddleware {
impl<State: Clone + Send + Sync + 'static> Middleware<State> for RequestCounterMiddleware {
async fn handle(&self, mut req: Request<State>, next: Next<'_, State>) -> Result {
let count = self.requests_counted.fetch_add(1, Ordering::Relaxed);
tide::log::trace!("request counter", { count: count });
Expand Down
8 changes: 5 additions & 3 deletions examples/upload.rs
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
use std::sync::Arc;

use async_std::{fs::OpenOptions, io};
use tempfile::TempDir;
use tide::prelude::*;
Expand All @@ -6,15 +8,15 @@ use tide::{Body, Request, Response, StatusCode};
#[async_std::main]
async fn main() -> Result<(), std::io::Error> {
tide::log::start();
let mut app = tide::with_state(tempfile::tempdir()?);
let mut app = tide::with_state(Arc::new(tempfile::tempdir()?));

// To test this example:
// $ cargo run --example upload
// $ curl -T ./README.md locahost:8080 # this writes the file to a temp directory
// $ curl localhost:8080/README.md # this reads the file from the same temp directory

app.at(":file")
.put(|req: Request<TempDir>| async move {
.put(|req: Request<Arc<TempDir>>| async move {
let path: String = req.param("file")?;
let fs_path = req.state().path().join(path);

Expand All @@ -33,7 +35,7 @@ async fn main() -> Result<(), std::io::Error> {

Ok(json!({ "bytes": bytes_written }))
})
.get(|req: Request<TempDir>| async move {
.get(|req: Request<Arc<TempDir>>| async move {
let path: String = req.param("file")?;
let fs_path = req.state().path().join(path);

Expand Down
2 changes: 1 addition & 1 deletion src/cookies/middleware.rs
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,7 @@ impl CookiesMiddleware {
}

#[async_trait]
impl<State: Send + Sync + 'static> Middleware<State> for CookiesMiddleware {
impl<State: Clone + Send + Sync + 'static> Middleware<State> for CookiesMiddleware {
async fn handle(&self, mut ctx: Request<State>, next: Next<'_, State>) -> crate::Result {
let cookie_jar = if let Some(cookie_data) = ctx.ext::<CookieData>() {
cookie_data.content.clone()
Expand Down
8 changes: 4 additions & 4 deletions src/endpoint.rs
Original file line number Diff line number Diff line change
Expand Up @@ -45,7 +45,7 @@ use crate::{Middleware, Request, Response};
///
/// Tide routes will also accept endpoints with `Fn` signatures of this form, but using the `async` keyword has better ergonomics.
#[async_trait]
pub trait Endpoint<State: Send + Sync + 'static>: Send + Sync + 'static {
pub trait Endpoint<State: Clone + Send + Sync + 'static>: Send + Sync + 'static {
/// Invoke the endpoint within the given context
async fn call(&self, req: Request<State>) -> crate::Result;
}
Expand All @@ -55,7 +55,7 @@ pub(crate) type DynEndpoint<State> = dyn Endpoint<State>;
#[async_trait]
impl<State, F, Fut, Res> Endpoint<State> for F
where
State: Send + Sync + 'static,
State: Clone + Send + Sync + 'static,
F: Send + Sync + 'static + Fn(Request<State>) -> Fut,
Fut: Future<Output = Result<Res>> + Send + 'static,
Res: Into<Response> + 'static,
Expand Down Expand Up @@ -93,7 +93,7 @@ impl<E, State> std::fmt::Debug for MiddlewareEndpoint<E, State> {

impl<E, State> MiddlewareEndpoint<E, State>
where
State: Send + Sync + 'static,
State: Clone + Send + Sync + 'static,
E: Endpoint<State>,
{
pub fn wrap_with_middleware(ep: E, middleware: &[Arc<dyn Middleware<State>>]) -> Self {
Expand All @@ -107,7 +107,7 @@ where
#[async_trait]
impl<E, State> Endpoint<State> for MiddlewareEndpoint<E, State>
where
State: Send + Sync + 'static,
State: Clone + Send + Sync + 'static,
E: Endpoint<State>,
{
async fn call(&self, req: Request<State>) -> crate::Result {
Expand Down
6 changes: 2 additions & 4 deletions src/fs/serve_dir.rs
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@ impl ServeDir {
#[async_trait::async_trait]
impl<State> Endpoint<State> for ServeDir
where
State: Send + Sync + 'static,
State: Clone + Send + Sync + 'static,
{
async fn call(&self, req: Request<State>) -> Result {
let path = req.url().path();
Expand Down Expand Up @@ -60,8 +60,6 @@ where
mod test {
use super::*;

use async_std::sync::Arc;

use std::fs::{self, File};
use std::io::Write;

Expand All @@ -83,7 +81,7 @@ mod test {
let request = crate::http::Request::get(
crate::http::Url::parse(&format!("http://localhost/{}", path)).unwrap(),
);
crate::Request::new(Arc::new(()), request, vec![])
crate::Request::new((), request, vec![])
}

#[async_std::test]
Expand Down
3 changes: 2 additions & 1 deletion src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -259,6 +259,7 @@ pub fn new() -> server::Server<()> {
/// use tide::Request;
///
/// /// The shared application state.
/// #[derive(Clone)]
/// struct State {
/// name: String,
/// }
Expand All @@ -279,7 +280,7 @@ pub fn new() -> server::Server<()> {
/// ```
pub fn with_state<State>(state: State) -> server::Server<State>
where
State: Send + Sync + 'static,
State: Clone + Send + Sync + 'static,
{
Server::with_state(state)
}
Expand Down
4 changes: 2 additions & 2 deletions src/listener/concurrent_listener.rs
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,7 @@ use futures_util::stream::{futures_unordered::FuturesUnordered, StreamExt};
#[derive(Default)]
pub struct ConcurrentListener<State>(Vec<Box<dyn Listener<State>>>);

impl<State: Send + Sync + 'static> ConcurrentListener<State> {
impl<State: Clone + Send + Sync + 'static> ConcurrentListener<State> {
/// creates a new ConcurrentListener
pub fn new() -> Self {
Self(vec![])
Expand Down Expand Up @@ -78,7 +78,7 @@ impl<State: Send + Sync + 'static> ConcurrentListener<State> {
}

#[async_trait::async_trait]
impl<State: Send + Sync + 'static> Listener<State> for ConcurrentListener<State> {
impl<State: Clone + Send + Sync + 'static> Listener<State> for ConcurrentListener<State> {
async fn listen(&mut self, app: Server<State>) -> io::Result<()> {
let mut futures_unordered = FuturesUnordered::new();

Expand Down
4 changes: 2 additions & 2 deletions src/listener/failover_listener.rs
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,7 @@ use async_std::io;
#[derive(Default)]
pub struct FailoverListener<State>(Vec<Box<dyn Listener<State>>>);

impl<State: Send + Sync + 'static> FailoverListener<State> {
impl<State: Clone + Send + Sync + 'static> FailoverListener<State> {
/// creates a new FailoverListener
pub fn new() -> Self {
Self(vec![])
Expand Down Expand Up @@ -80,7 +80,7 @@ impl<State: Send + Sync + 'static> FailoverListener<State> {
}

#[async_trait::async_trait]
impl<State: Send + Sync + 'static> Listener<State> for FailoverListener<State> {
impl<State: Clone + Send + Sync + 'static> Listener<State> for FailoverListener<State> {
async fn listen(&mut self, app: Server<State>) -> io::Result<()> {
for listener in self.0.iter_mut() {
let app = app.clone();
Expand Down
2 changes: 1 addition & 1 deletion src/listener/parsed_listener.rs
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,7 @@ impl Display for ParsedListener {
}

#[async_trait::async_trait]
impl<State: Send + Sync + 'static> Listener<State> for ParsedListener {
impl<State: Clone + Send + Sync + 'static> Listener<State> for ParsedListener {
async fn listen(&mut self, app: Server<State>) -> io::Result<()> {
match self {
#[cfg(unix)]
Expand Down
4 changes: 2 additions & 2 deletions src/listener/tcp_listener.rs
Original file line number Diff line number Diff line change
Expand Up @@ -51,7 +51,7 @@ impl TcpListener {
}
}

fn handle_tcp<State: Send + Sync + 'static>(app: Server<State>, stream: TcpStream) {
fn handle_tcp<State: Clone + Send + Sync + 'static>(app: Server<State>, stream: TcpStream) {
task::spawn(async move {
let local_addr = stream.local_addr().ok();
let peer_addr = stream.peer_addr().ok();
Expand All @@ -69,7 +69,7 @@ fn handle_tcp<State: Send + Sync + 'static>(app: Server<State>, stream: TcpStrea
}

#[async_trait::async_trait]
impl<State: Send + Sync + 'static> Listener<State> for TcpListener {
impl<State: Clone + Send + Sync + 'static> Listener<State> for TcpListener {
async fn listen(&mut self, app: Server<State>) -> io::Result<()> {
self.connect().await?;
let listener = self.listener()?;
Expand Down
38 changes: 20 additions & 18 deletions src/listener/to_listener.rs
Original file line number Diff line number Diff line change
Expand Up @@ -52,7 +52,7 @@ use std::net::ToSocketAddrs;
/// # Other implementations
/// See below for additional provided implementations of ToListener.
pub trait ToListener<State: Send + Sync + 'static> {
pub trait ToListener<State: Clone + Send + Sync + 'static> {
type Listener: Listener<State>;
/// Transform self into a
/// [`Listener`](crate::listener::Listener). Unless self is
Expand All @@ -63,7 +63,7 @@ pub trait ToListener<State: Send + Sync + 'static> {
fn to_listener(self) -> io::Result<Self::Listener>;
}

impl<State: Send + Sync + 'static> ToListener<State> for Url {
impl<State: Clone + Send + Sync + 'static> ToListener<State> for Url {
type Listener = ParsedListener;

fn to_listener(self) -> io::Result<Self::Listener> {
Expand Down Expand Up @@ -106,14 +106,14 @@ impl<State: Send + Sync + 'static> ToListener<State> for Url {
}
}

impl<State: Send + Sync + 'static> ToListener<State> for String {
impl<State: Clone + Send + Sync + 'static> ToListener<State> for String {
type Listener = ParsedListener;
fn to_listener(self) -> io::Result<Self::Listener> {
ToListener::<State>::to_listener(self.as_str())
}
}

impl<State: Send + Sync + 'static> ToListener<State> for &str {
impl<State: Clone + Send + Sync + 'static> ToListener<State> for &str {
type Listener = ParsedListener;

fn to_listener(self) -> io::Result<Self::Listener> {
Expand All @@ -133,36 +133,36 @@ impl<State: Send + Sync + 'static> ToListener<State> for &str {
}

#[cfg(unix)]
impl<State: Send + Sync + 'static> ToListener<State> for async_std::path::PathBuf {
impl<State: Clone + Send + Sync + 'static> ToListener<State> for async_std::path::PathBuf {
type Listener = UnixListener;
fn to_listener(self) -> io::Result<Self::Listener> {
Ok(UnixListener::from_path(self))
}
}

#[cfg(unix)]
impl<State: Send + Sync + 'static> ToListener<State> for std::path::PathBuf {
impl<State: Clone + Send + Sync + 'static> ToListener<State> for std::path::PathBuf {
type Listener = UnixListener;
fn to_listener(self) -> io::Result<Self::Listener> {
Ok(UnixListener::from_path(self))
}
}

impl<State: Send + Sync + 'static> ToListener<State> for async_std::net::TcpListener {
impl<State: Clone + Send + Sync + 'static> ToListener<State> for async_std::net::TcpListener {
type Listener = TcpListener;
fn to_listener(self) -> io::Result<Self::Listener> {
Ok(TcpListener::from_listener(self))
}
}

impl<State: Send + Sync + 'static> ToListener<State> for std::net::TcpListener {
impl<State: Clone + Send + Sync + 'static> ToListener<State> for std::net::TcpListener {
type Listener = TcpListener;
fn to_listener(self) -> io::Result<Self::Listener> {
Ok(TcpListener::from_listener(self))
}
}

impl<State: Send + Sync + 'static> ToListener<State> for (&str, u16) {
impl<State: Clone + Send + Sync + 'static> ToListener<State> for (&str, u16) {
type Listener = TcpListener;

fn to_listener(self) -> io::Result<Self::Listener> {
Expand All @@ -171,65 +171,67 @@ impl<State: Send + Sync + 'static> ToListener<State> for (&str, u16) {
}

#[cfg(unix)]
impl<State: Send + Sync + 'static> ToListener<State> for async_std::os::unix::net::UnixListener {
impl<State: Clone + Send + Sync + 'static> ToListener<State>
for async_std::os::unix::net::UnixListener
{
type Listener = UnixListener;
fn to_listener(self) -> io::Result<Self::Listener> {
Ok(UnixListener::from_listener(self))
}
}

#[cfg(unix)]
impl<State: Send + Sync + 'static> ToListener<State> for std::os::unix::net::UnixListener {
impl<State: Clone + Send + Sync + 'static> ToListener<State> for std::os::unix::net::UnixListener {
type Listener = UnixListener;
fn to_listener(self) -> io::Result<Self::Listener> {
Ok(UnixListener::from_listener(self))
}
}

impl<State: Send + Sync + 'static> ToListener<State> for TcpListener {
impl<State: Clone + Send + Sync + 'static> ToListener<State> for TcpListener {
type Listener = Self;
fn to_listener(self) -> io::Result<Self::Listener> {
Ok(self)
}
}

#[cfg(unix)]
impl<State: Send + Sync + 'static> ToListener<State> for UnixListener {
impl<State: Clone + Send + Sync + 'static> ToListener<State> for UnixListener {
type Listener = Self;
fn to_listener(self) -> io::Result<Self::Listener> {
Ok(self)
}
}

impl<State: Send + Sync + 'static> ToListener<State> for ConcurrentListener<State> {
impl<State: Clone + Send + Sync + 'static> ToListener<State> for ConcurrentListener<State> {
type Listener = Self;
fn to_listener(self) -> io::Result<Self::Listener> {
Ok(self)
}
}

impl<State: Send + Sync + 'static> ToListener<State> for ParsedListener {
impl<State: Clone + Send + Sync + 'static> ToListener<State> for ParsedListener {
type Listener = Self;
fn to_listener(self) -> io::Result<Self::Listener> {
Ok(self)
}
}

impl<State: Send + Sync + 'static> ToListener<State> for FailoverListener<State> {
impl<State: Clone + Send + Sync + 'static> ToListener<State> for FailoverListener<State> {
type Listener = Self;
fn to_listener(self) -> io::Result<Self::Listener> {
Ok(self)
}
}

impl<State: Send + Sync + 'static> ToListener<State> for std::net::SocketAddr {
impl<State: Clone + Send + Sync + 'static> ToListener<State> for std::net::SocketAddr {
type Listener = TcpListener;
fn to_listener(self) -> io::Result<Self::Listener> {
Ok(TcpListener::from_addrs(vec![self]))
}
}

impl<TL: ToListener<State>, State: Send + Sync + 'static> ToListener<State> for Vec<TL> {
impl<TL: ToListener<State>, State: Clone + Send + Sync + 'static> ToListener<State> for Vec<TL> {
type Listener = ConcurrentListener<State>;
fn to_listener(self) -> io::Result<Self::Listener> {
let mut concurrent_listener = ConcurrentListener::new();
Expand Down
Loading

0 comments on commit 50bc628

Please sign in to comment.