Skip to content

Commit

Permalink
Re-enabled diesel middleware and example
Browse files Browse the repository at this point in the history
  • Loading branch information
kevinastone committed Jan 4, 2020
1 parent 12e339c commit 48a4c9c
Show file tree
Hide file tree
Showing 6 changed files with 75 additions and 75 deletions.
5 changes: 2 additions & 3 deletions Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -7,8 +7,8 @@ members = [

## Middleware
"middleware/template",
"middleware/diesel",
# TODO: Re-enable middleware when their dependencies are updated
# "middleware/diesel",
# "middleware/jwt",

## Examples (these crates are not published)
Expand Down Expand Up @@ -66,8 +66,7 @@ members = [
"examples/static_assets",

# diesel
# TODO: Re-enable when the middleware is updated
# "examples/diesel",
"examples/diesel",

# openssl
# TODO: Re-enable when this example is updated
Expand Down
1 change: 1 addition & 0 deletions examples/diesel/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@ gotham = { path = "../../gotham/"}
gotham_derive = { path = "../../gotham_derive/" }
gotham_middleware_diesel = { path = "../../middleware/diesel"}
hyper = "0.13.1"
failure = "0.1"
futures = "0.3.1"
mime = "0.3"
log = "0.4"
Expand Down
80 changes: 44 additions & 36 deletions examples/diesel/src/main.rs
Original file line number Diff line number Diff line change
Expand Up @@ -16,8 +16,9 @@ use gotham::pipeline::{new_pipeline, single::single_pipeline};
use gotham::router::{builder::*, Router};
use gotham::state::{FromState, State};
use gotham_middleware_diesel::DieselMiddleware;
use hyper::{Body, StatusCode};
use hyper::{body, Body, StatusCode};
use serde_derive::Serialize;
use std::pin::Pin;
use std::str::from_utf8;

mod models;
Expand Down Expand Up @@ -45,43 +46,48 @@ struct RowsUpdated {

fn create_product_handler(mut state: State) -> Pin<Box<HandlerFuture>> {
let repo = Repo::borrow_from(&state).clone();
extract_json::<NewProduct>(&mut state)
.and_then(move |product| {
repo.run(move |conn| {
// Insert the `NewProduct` in the DB
async move {
let product = match extract_json::<NewProduct>(&mut state).await {
Ok(product) => product,
Err(e) => return Err((state, e)),
};

let rows = match repo
.run(move |conn| {
diesel::insert_into(products::table)
.values(&product)
.execute(&conn)
})
.map_err(|e| e.into_handler_error())
})
.then(|result| match result {
Ok(rows) => {
let body = serde_json::to_string(&RowsUpdated { rows })
.expect("Failed to serialise to json");
let res =
create_response(&state, StatusCode::CREATED, mime::APPLICATION_JSON, body);
future::ok((state, res))
}
Err(e) => future::err((state, e)),
})
.boxed()
.await
{
Ok(rows) => rows,
Err(e) => return Err((state, e.into_handler_error())),
};

let body =
serde_json::to_string(&RowsUpdated { rows }).expect("Failed to serialise to json");
let res = create_response(&state, StatusCode::CREATED, mime::APPLICATION_JSON, body);
Ok((state, res))
}
.boxed()
}

fn get_products_handler(state: State) -> Pin<Box<HandlerFuture>> {
use crate::schema::products::dsl::*;

let repo = Repo::borrow_from(&state).clone();
repo.run(move |conn| products.load::<Product>(&conn))
.then(|result| match result {
async move {
let result = repo.run(move |conn| products.load::<Product>(&conn)).await;
match result {
Ok(users) => {
let body = serde_json::to_string(&users).expect("Failed to serialize users.");
let res = create_response(&state, StatusCode::OK, mime::APPLICATION_JSON, body);
future::ok((state, res))
Ok((state, res))
}
Err(e) => future::err((state, e.into_handler_error())),
})
.boxed()
Err(e) => Err((state, e.into_handler_error())),
}
}
.boxed()
}

fn router(repo: Repo) -> Router {
Expand All @@ -103,19 +109,17 @@ where
e.into_handler_error().with_status(StatusCode::BAD_REQUEST)
}

fn extract_json<T>(state: &mut State) -> impl Future<Output = Result<T, HandlerError>>
async fn extract_json<T>(state: &mut State) -> Result<T, HandlerError>
where
T: serde::de::DeserializeOwned,
{
Body::take_from(state)
.concat2()
let body = body::to_bytes(Body::take_from(state))
.map_err(bad_request)
.await?;
let b = body.to_vec();
from_utf8(&b)
.map_err(bad_request)
.and_then(|body| {
let b = body.to_vec();
from_utf8(&b)
.map_err(bad_request)
.and_then(|s| serde_json::from_str::<T>(s).map_err(bad_request))
})
.and_then(|s| serde_json::from_str::<T>(s).map_err(bad_request))
}

/// Start a server and use a `Router` to dispatch requests
Expand All @@ -135,11 +139,11 @@ fn main() {
#[cfg(test)]
mod tests {
use super::*;
use failure::Fail;
use gotham::test::TestServer;
use gotham_middleware_diesel::Repo;
use hyper::StatusCode;
use std::str;
use tokio::runtime;

static DATABASE_URL: &str = ":memory:";

Expand All @@ -152,7 +156,9 @@ mod tests {
#[test]
fn get_empty_products() {
let repo = Repo::with_test_transactions(DATABASE_URL);
runtime::run(repo.run(|conn| embedded_migrations::run(&conn).map_err(|_| ())));
let mut runtime = tokio::runtime::Runtime::new().unwrap();
let _ = runtime
.block_on(repo.run(|conn| embedded_migrations::run(&conn).map_err(|e| e.compat())));
let test_server = TestServer::new(router(repo)).unwrap();
let response = test_server
.client()
Expand All @@ -171,7 +177,9 @@ mod tests {
#[test]
fn create_and_retrieve_product() {
let repo = Repo::with_test_transactions(DATABASE_URL);
runtime::run(repo.run(|conn| embedded_migrations::run(&conn).map_err(|_| ())));
let mut runtime = tokio::runtime::Runtime::new().unwrap();
let _ = runtime
.block_on(repo.run(|conn| embedded_migrations::run(&conn).map_err(|e| e.compat())));
let test_server = TestServer::new(router(repo)).unwrap();

// First we'll insert something into the DB with a post
Expand Down
2 changes: 1 addition & 1 deletion middleware/diesel/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@ gotham = "0.5.0-dev"
gotham_derive = "0.5.0-dev"
diesel = { version = "1.3", features = ["r2d2"] }
r2d2 = "0.8"
tokio-threadpool = "0.1"
tokio = { version = "0.2.6", features = ["full"] }
log = "0.4"

[dev-dependencies]
Expand Down
37 changes: 20 additions & 17 deletions middleware/diesel/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -19,8 +19,9 @@
//! # use gotham_middleware_diesel::{self, DieselMiddleware};
//! # use diesel::{RunQueryDsl, SqliteConnection};
//! # use hyper::StatusCode;
//! # use futures::future::Future;
//! # use futures::prelude::*;
//! # use gotham::test::TestServer;
//! # use std::pin::Pin;
//!
//! pub type Repo = gotham_middleware_diesel::Repo<SqliteConnection>;
//!
Expand All @@ -37,24 +38,25 @@
//! })
//! }
//!
//! fn handler(state: State) -> Box<HandlerFuture> {
//! fn handler(state: State) -> Pin<Box<HandlerFuture>> {
//! let repo = Repo::borrow_from(&state).clone();
//! // As an example, we perform the query:
//! // `SELECT 1`
//! let f = repo
//! .run(move |conn| {
//! async move {
//! let result = repo.run(move |conn| {
//! diesel::select(diesel::dsl::sql("1"))
//! .load::<i64>(&conn)
//! .map(|v| v.into_iter().next().expect("no results"))
//! }).then(|result| match result {
//! Ok(n) => {
//! let body = format!("result: {}", n);
//! let res = create_response(&state, StatusCode::OK, mime::TEXT_PLAIN, body);
//! Ok((state, res))
//! },
//! Err(e) => Err((state, e.into_handler_error())),
//! });
//! Box::new(f)
//! }).await;
//! match result {
//! Ok(n) => {
//! let body = format!("result: {}", n);
//! let res = create_response(&state, StatusCode::OK, mime::TEXT_PLAIN, body);
//! Ok((state, res))
//! },
//! Err(e) => Err((state, e.into_handler_error())),
//! }
//! }.boxed()
//! }
//!
//! # fn main() {
Expand All @@ -72,10 +74,11 @@
#![doc(test(no_crate_inject, attr(allow(unused_variables), deny(warnings))))]

use diesel::Connection;
use futures::future::{self, Future};
use futures::prelude::*;
use log::{error, trace};
use std::io;
use std::panic::{catch_unwind, AssertUnwindSafe};
use std::pin::Pin;
use std::process;

use gotham::handler::HandlerFuture;
Expand Down Expand Up @@ -149,9 +152,9 @@ impl<T> Middleware for DieselMiddleware<T>
where
T: Connection + 'static,
{
fn call<Chain>(self, mut state: State, chain: Chain) -> Box<HandlerFuture>
fn call<Chain>(self, mut state: State, chain: Chain) -> Pin<Box<HandlerFuture>>
where
Chain: FnOnce(State) -> Box<HandlerFuture> + 'static,
Chain: FnOnce(State) -> Pin<Box<HandlerFuture>> + 'static,
Self: Sized,
{
trace!("[{}] pre chain", request_id(&state));
Expand All @@ -163,6 +166,6 @@ where
}
future::ok((state, response))
});
Box::new(f)
f.boxed()
}
}
25 changes: 7 additions & 18 deletions middleware/diesel/src/repo.rs
Original file line number Diff line number Diff line change
@@ -1,11 +1,9 @@
use diesel::r2d2::ConnectionManager;
use diesel::Connection;
use futures::future;
use futures::future::{poll_fn, Future};
use gotham_derive::StateData;
use log::error;
use r2d2::{CustomizeConnection, Pool, PooledConnection};
use tokio_threadpool::blocking;
use tokio::task;

/// A database "repository", for running database workloads.
/// Manages a connection pool and running blocking tasks using
Expand Down Expand Up @@ -135,29 +133,20 @@ where
/// Runs the given closure in a way that is safe for blocking IO to the
/// database without blocking the tokio reactor.
/// The closure will be passed a `Connection` from the pool to use.
pub fn run<F, R, E>(&self, f: F) -> impl Future<Item = R, Error = E>
pub async fn run<F, R, E>(&self, f: F) -> Result<R, E>
where
F: FnOnce(PooledConnection<ConnectionManager<T>>) -> Result<R, E>
+ Send
+ std::marker::Unpin
+ 'static,
T: Send + 'static,
R: Send + 'static,
E: Send + 'static,
{
let pool = self.connection_pool.clone();
// `tokio_threadpool::blocking` returns a `Poll` which can be converted into a future
// using `poll_fn`.
// `f.take()` allows the borrow checker to be sure `f` is not moved into the inner closure
// multiple times if `poll_fn` is called multple times.
let mut f = Some(f);
poll_fn(move || blocking(|| (f.take().unwrap())(pool.get().unwrap()))).then(
|future_result| match future_result {
Ok(query_result) => match query_result {
Ok(result) => future::ok(result),
Err(error) => future::err(error),
},
Err(e) => panic!("Error running async database task: {:?}", e),
},
)
task::spawn_blocking(move || f(pool.get().unwrap()))

This comment has been minimized.

Copy link
@pksunkara

pksunkara Jan 4, 2020

Contributor

Wow, I didn't realise the solution is this simple. I would blame tokio for not having a proper changelog. 😄

.await
.unwrap_or_else(|e| panic!("Error running async database task: {:?}", e))
}
}

Expand Down

0 comments on commit 48a4c9c

Please sign in to comment.