Skip to content

Commit

Permalink
Recover from timeout errors in the service
Browse files Browse the repository at this point in the history
  • Loading branch information
davidpdrsn committed Apr 28, 2021
1 parent 3c5d8c5 commit a7e8f61
Show file tree
Hide file tree
Showing 6 changed files with 128 additions and 12 deletions.
48 changes: 44 additions & 4 deletions tests/integration_tests/tests/timeout.rs
Original file line number Diff line number Diff line change
Expand Up @@ -45,7 +45,7 @@ async fn cancelation_on_timeout() {
}

#[tokio::test]
async fn picks_the_shortest_timeout() {
async fn picks_server_timeout_if_thats_sorter() {
struct Svc;

#[tonic::async_trait]
Expand Down Expand Up @@ -80,10 +80,50 @@ async fn picks_the_shortest_timeout() {
// 10 hours
.insert("grpc-timeout", "10H".parse().unwrap());

// TODO(david): for some reason this fails with "h2 protocol error: protocol error: unexpected
// internal error encountered". Seems to be happening on `master` as well. Bug?
let res = client.unary_call(req).await;
dbg!(&res);
let err = res.unwrap_err();
assert!(err.message().contains("Timeout expired"));
assert_eq!(err.code(), Code::Cancelled);
}

#[tokio::test]
async fn picks_client_timeout_if_thats_sorter() {
struct Svc;

#[tonic::async_trait]
impl test_server::Test for Svc {
async fn unary_call(&self, _req: Request<Input>) -> Result<Response<Output>, Status> {
// Wait for a time longer than the timeout
tokio::time::sleep(Duration::from_secs(1)).await;
Ok(Response::new(Output {}))
}
}

let svc = test_server::TestServer::new(Svc);

let listener = TcpListener::bind("127.0.0.1:0").await.unwrap();
let addr = listener.local_addr().unwrap();

tokio::spawn(async move {
Server::builder()
.timeout(Duration::from_secs(9001))
.add_service(svc)
.serve_with_incoming(tokio_stream::wrappers::TcpListenerStream::new(listener))
.await
.unwrap();
});

let mut client = test_client::TestClient::connect(format!("http://{}", addr))
.await
.unwrap();

let mut req = Request::new(Input {});
req.metadata_mut()
// 100 ms
.insert("grpc-timeout", "100m".parse().unwrap());

let res = client.unary_call(req).await;
let err = res.unwrap_err();
assert!(err.message().contains("Timeout expired"));
assert_eq!(err.code(), Code::Cancelled);
}
2 changes: 1 addition & 1 deletion tonic/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -69,7 +69,7 @@ h2 = { version = "0.3", optional = true }
hyper = { version = "0.14.2", features = ["full"], optional = true }
tokio = { version = "1.0.1", features = ["net"], optional = true }
tokio-stream = "0.1"
tower = { version = "0.4.4", features = ["balance", "buffer", "discover", "limit", "load", "make", "timeout", "util"], optional = true }
tower = { version = "0.4.7", features = ["balance", "buffer", "discover", "limit", "load", "make", "timeout", "util"], optional = true }
tracing-futures = { version = "0.2", optional = true }

# rustls
Expand Down
2 changes: 1 addition & 1 deletion tonic/src/status.rs
Original file line number Diff line number Diff line change
Expand Up @@ -313,7 +313,7 @@ impl Status {
Status::try_from_error(err).unwrap_or_else(|| Status::new(Code::Unknown, err.to_string()))
}

fn try_from_error(err: &(dyn Error + 'static)) -> Option<Status> {
pub(crate) fn try_from_error(err: &(dyn Error + 'static)) -> Option<Status> {
let mut cause = Some(err);

while let Some(err) = cause {
Expand Down
12 changes: 6 additions & 6 deletions tonic/src/transport/server/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
mod conn;
mod incoming;
mod recover_error;
#[cfg(feature = "tls")]
#[cfg_attr(docsrs, doc(cfg(feature = "tls")))]
mod tls;
Expand All @@ -21,8 +22,9 @@ pub(crate) use tokio_rustls::server::TlsStream;
#[cfg(feature = "tls")]
use crate::transport::Error;

use self::recover_error::RecoverError;
use super::{
service::{Or, Routes, ServerIo},
service::{GrpcTimeout, Or, Routes, ServerIo},
BoxFuture,
};
use crate::{body::BoxBody, request::ConnectionInfo};
Expand All @@ -42,10 +44,7 @@ use std::{
time::Duration,
};
use tokio::io::{AsyncRead, AsyncWrite};
use tower::{
limit::concurrency::ConcurrencyLimitLayer, timeout::TimeoutLayer, util::Either, Service,
ServiceBuilder,
};
use tower::{limit::concurrency::ConcurrencyLimitLayer, util::Either, Service, ServiceBuilder};
use tracing_futures::{Instrument, Instrumented};

type BoxService = tower::util::BoxService<Request<Body>, Response<BoxBody>, crate::Error>;
Expand Down Expand Up @@ -655,8 +654,9 @@ where

Box::pin(async move {
let svc = ServiceBuilder::new()
.layer_fn(RecoverError::new)
.option_layer(concurrency_limit.map(ConcurrencyLimitLayer::new))
.option_layer(timeout.map(TimeoutLayer::new))
.layer_fn(|s| GrpcTimeout::new(s, timeout))
.service(svc);

let svc = BoxService::new(Svc {
Expand Down
75 changes: 75 additions & 0 deletions tonic/src/transport/server/recover_error.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,75 @@
use crate::{body::BoxBody, Status};
use futures_util::ready;
use http::Response;
use pin_project::pin_project;
use std::{
future::Future,
pin::Pin,
task::{Context, Poll},
};
use tower::Service;

/// Middleware that attempts to recover from service errors by turning them into a response built
/// from the `Status`.
#[derive(Debug, Clone)]
pub(crate) struct RecoverError<S> {
inner: S,
}

impl<S> RecoverError<S> {
pub(crate) fn new(inner: S) -> Self {
Self { inner }
}
}

impl<S, R> Service<R> for RecoverError<S>
where
S: Service<R, Response = Response<BoxBody>>,
S::Error: Into<crate::Error>,
{
type Response = Response<BoxBody>;
type Error = crate::Error;
type Future = ResponseFuture<S::Future>;

fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
self.inner.poll_ready(cx).map_err(Into::into)
}

fn call(&mut self, req: R) -> Self::Future {
ResponseFuture {
inner: self.inner.call(req),
}
}
}

#[pin_project]
pub(crate) struct ResponseFuture<F> {
#[pin]
inner: F,
}

impl<F, E> Future for ResponseFuture<F>
where
F: Future<Output = Result<Response<BoxBody>, E>>,
E: Into<crate::Error>,
{
type Output = Result<Response<BoxBody>, crate::Error>;

fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
let result: Result<Response<BoxBody>, crate::Error> =
ready!(self.project().inner.poll(cx)).map_err(Into::into);

match result {
Ok(res) => Poll::Ready(Ok(res)),
Err(err) => {
if let Some(status) = Status::try_from_error(&*err) {
let mut res = Response::new(BoxBody::empty());
status.add_header(res.headers_mut()).unwrap();
Poll::Ready(Ok(res))
} else {
Poll::Ready(Err(err))
}
}
}
}
}
1 change: 1 addition & 0 deletions tonic/src/transport/service/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@ pub(crate) use self::add_origin::AddOrigin;
pub(crate) use self::connection::Connection;
pub(crate) use self::connector::connector;
pub(crate) use self::discover::DynamicServiceStream;
pub(crate) use self::grpc_timeout::GrpcTimeout;
pub(crate) use self::io::ServerIo;
pub(crate) use self::router::{Or, Routes};
#[cfg(feature = "tls")]
Expand Down

0 comments on commit a7e8f61

Please sign in to comment.