Skip to content

Commit

Permalink
feat(transport): Support timeouts with "grpc-timeout" header (#606)
Browse files Browse the repository at this point in the history
* transport: Support timeouts with "grpc-timeout" header

* Apply suggestions from code review

Co-authored-by: Lucio Franco <[email protected]>

* Timeout -> GrpcTimeout and export TimeoutExpired

* Clean up imports

* Give header name a more proper home

* Add fuzz tests for parsing header value into `grpc-timeout`

* Map `TimeoutExpired` to `cancelled` status

* Recover from timeout errors in the service

* Refactor tests

* Fix CI

* Fix CI, again

Co-authored-by: Lucio Franco <[email protected]>
  • Loading branch information
davidpdrsn and LucioFranco authored Apr 29, 2021
1 parent 4926c60 commit 9ff4f7b
Show file tree
Hide file tree
Showing 13 changed files with 494 additions and 14 deletions.
2 changes: 2 additions & 0 deletions .github/workflows/CI.yml
Original file line number Diff line number Diff line change
Expand Up @@ -49,6 +49,8 @@ jobs:

env:
RUSTFLAGS: "-D warnings"
# run a lot of quickcheck iterations
QUICKCHECK_TESTS: 1000

steps:
- uses: hecrj/setup-rust-action@master
Expand Down
1 change: 1 addition & 0 deletions tests/integration_tests/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@ bytes = "1.0"

[dev-dependencies]
tokio = { version = "1.0", features = ["macros", "rt-multi-thread", "net"] }
tokio-stream = { version = "0.1.5", features = ["net"] }

[build-dependencies]
tonic-build = { path = "../../tonic-build" }
92 changes: 92 additions & 0 deletions tests/integration_tests/tests/timeout.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,92 @@
use integration_tests::pb::{test_client, test_server, Input, Output};
use std::{net::SocketAddr, time::Duration};
use tokio::net::TcpListener;
use tonic::{transport::Server, Code, Request, Response, Status};

#[tokio::test]
async fn cancelation_on_timeout() {
let addr = run_service_in_background(Duration::from_secs(1), Duration::from_secs(100)).await;

let mut client = test_client::TestClient::connect(format!("http://{}", addr))
.await
.unwrap();

let mut req = Request::new(Input {});
req.metadata_mut()
// 500 ms
.insert("grpc-timeout", "500m".parse().unwrap());

let res = client.unary_call(req).await;

let err = res.unwrap_err();
assert!(err.message().contains("Timeout expired"));
assert_eq!(err.code(), Code::Cancelled);
}

#[tokio::test]
async fn picks_server_timeout_if_thats_sorter() {
let addr = run_service_in_background(Duration::from_secs(1), Duration::from_millis(100)).await;

let mut client = test_client::TestClient::connect(format!("http://{}", addr))
.await
.unwrap();

let mut req = Request::new(Input {});
req.metadata_mut()
// 10 hours
.insert("grpc-timeout", "10H".parse().unwrap());

let res = client.unary_call(req).await;
let err = res.unwrap_err();
assert!(err.message().contains("Timeout expired"));
assert_eq!(err.code(), Code::Cancelled);
}

#[tokio::test]
async fn picks_client_timeout_if_thats_sorter() {
let addr = run_service_in_background(Duration::from_secs(1), Duration::from_secs(100)).await;

let mut client = test_client::TestClient::connect(format!("http://{}", addr))
.await
.unwrap();

let mut req = Request::new(Input {});
req.metadata_mut()
// 100 ms
.insert("grpc-timeout", "100m".parse().unwrap());

let res = client.unary_call(req).await;
let err = res.unwrap_err();
assert!(err.message().contains("Timeout expired"));
assert_eq!(err.code(), Code::Cancelled);
}

async fn run_service_in_background(latency: Duration, server_timeout: Duration) -> SocketAddr {
struct Svc {
latency: Duration,
}

#[tonic::async_trait]
impl test_server::Test for Svc {
async fn unary_call(&self, _req: Request<Input>) -> Result<Response<Output>, Status> {
tokio::time::sleep(self.latency).await;
Ok(Response::new(Output {}))
}
}

let svc = test_server::TestServer::new(Svc { latency });

let listener = TcpListener::bind("127.0.0.1:0").await.unwrap();
let addr = listener.local_addr().unwrap();

tokio::spawn(async move {
Server::builder()
.timeout(server_timeout)
.add_service(svc)
.serve_with_incoming(tokio_stream::wrappers::TcpListenerStream::new(listener))
.await
.unwrap();
});

addr
}
7 changes: 5 additions & 2 deletions tonic/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,8 @@ transport = [
"tokio",
"tower",
"tracing-futures",
"tokio/macros"
"tokio/macros",
"tokio/time",
]
tls = ["transport", "tokio-rustls"]
tls-roots = ["tls", "rustls-native-certs"]
Expand Down Expand Up @@ -68,7 +69,7 @@ h2 = { version = "0.3", optional = true }
hyper = { version = "0.14.2", features = ["full"], optional = true }
tokio = { version = "1.0.1", features = ["net"], optional = true }
tokio-stream = "0.1"
tower = { version = "0.4.4", features = ["balance", "buffer", "discover", "limit", "load", "make", "timeout", "util"], optional = true }
tower = { version = "0.4.7", features = ["balance", "buffer", "discover", "limit", "load", "make", "timeout", "util"], optional = true }
tracing-futures = { version = "0.2", optional = true }

# rustls
Expand All @@ -80,6 +81,8 @@ tokio = { version = "1.0", features = ["rt", "macros"] }
static_assertions = "1.0"
rand = "0.8"
bencher = "0.1.5"
quickcheck = "1.0"
quickcheck_macros = "1.0"

[package.metadata.docs.rs]
all-features = true
Expand Down
6 changes: 4 additions & 2 deletions tonic/src/metadata/map.rs
Original file line number Diff line number Diff line change
Expand Up @@ -194,15 +194,17 @@ pub struct OccupiedEntry<'a, VE: ValueEncoding> {
phantom: PhantomData<VE>,
}

#[cfg(feature = "transport")]
pub(crate) const GRPC_TIMEOUT_HEADER: &str = "grpc-timeout";

// ===== impl MetadataMap =====

impl MetadataMap {
// Headers reserved by the gRPC protocol.
pub(crate) const GRPC_RESERVED_HEADERS: [&'static str; 8] = [
pub(crate) const GRPC_RESERVED_HEADERS: [&'static str; 7] = [
"te",
"user-agent",
"content-type",
"grpc-timeout",
"grpc-message",
"grpc-encoding",
"grpc-message-type",
Expand Down
3 changes: 3 additions & 0 deletions tonic/src/metadata/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,9 @@ pub use self::value::AsciiMetadataValue;
pub use self::value::BinaryMetadataValue;
pub use self::value::MetadataValue;

#[cfg(feature = "transport")]
pub(crate) use self::map::GRPC_TIMEOUT_HEADER;

/// The metadata::errors module contains types for errors that can occur
/// while handling gRPC custom metadata.
pub mod errors {
Expand Down
6 changes: 5 additions & 1 deletion tonic/src/status.rs
Original file line number Diff line number Diff line change
Expand Up @@ -313,7 +313,7 @@ impl Status {
Status::try_from_error(err).unwrap_or_else(|| Status::new(Code::Unknown, err.to_string()))
}

fn try_from_error(err: &(dyn Error + 'static)) -> Option<Status> {
pub(crate) fn try_from_error(err: &(dyn Error + 'static)) -> Option<Status> {
let mut cause = Some(err);

while let Some(err) = cause {
Expand All @@ -331,6 +331,10 @@ impl Status {
if let Some(h2) = err.downcast_ref::<h2::Error>() {
return Some(Status::from_h2_error(h2));
}

if let Some(timeout) = err.downcast_ref::<crate::transport::TimeoutExpired>() {
return Some(Status::cancelled(timeout.to_string()));
}
}

cause = err.source();
Expand Down
2 changes: 2 additions & 0 deletions tonic/src/transport/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -98,6 +98,8 @@ pub use self::channel::{Channel, Endpoint};
pub use self::error::Error;
#[doc(inline)]
pub use self::server::{NamedService, Server};
#[doc(inline)]
pub use self::service::TimeoutExpired;
pub use self::tls::{Certificate, Identity};
pub use hyper::{Body, Uri};

Expand Down
12 changes: 6 additions & 6 deletions tonic/src/transport/server/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
mod conn;
mod incoming;
mod recover_error;
#[cfg(feature = "tls")]
#[cfg_attr(docsrs, doc(cfg(feature = "tls")))]
mod tls;
Expand All @@ -21,8 +22,9 @@ pub(crate) use tokio_rustls::server::TlsStream;
#[cfg(feature = "tls")]
use crate::transport::Error;

use self::recover_error::RecoverError;
use super::{
service::{Or, Routes, ServerIo},
service::{GrpcTimeout, Or, Routes, ServerIo},
BoxFuture,
};
use crate::{body::BoxBody, request::ConnectionInfo};
Expand All @@ -42,10 +44,7 @@ use std::{
time::Duration,
};
use tokio::io::{AsyncRead, AsyncWrite};
use tower::{
limit::concurrency::ConcurrencyLimitLayer, timeout::TimeoutLayer, util::Either, Service,
ServiceBuilder,
};
use tower::{limit::concurrency::ConcurrencyLimitLayer, util::Either, Service, ServiceBuilder};
use tracing_futures::{Instrument, Instrumented};

type BoxService = tower::util::BoxService<Request<Body>, Response<BoxBody>, crate::Error>;
Expand Down Expand Up @@ -655,8 +654,9 @@ where

Box::pin(async move {
let svc = ServiceBuilder::new()
.layer_fn(RecoverError::new)
.option_layer(concurrency_limit.map(ConcurrencyLimitLayer::new))
.option_layer(timeout.map(TimeoutLayer::new))
.layer_fn(|s| GrpcTimeout::new(s, timeout))
.service(svc);

let svc = BoxService::new(Svc {
Expand Down
75 changes: 75 additions & 0 deletions tonic/src/transport/server/recover_error.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,75 @@
use crate::{body::BoxBody, Status};
use futures_util::ready;
use http::Response;
use pin_project::pin_project;
use std::{
future::Future,
pin::Pin,
task::{Context, Poll},
};
use tower::Service;

/// Middleware that attempts to recover from service errors by turning them into a response built
/// from the `Status`.
#[derive(Debug, Clone)]
pub(crate) struct RecoverError<S> {
inner: S,
}

impl<S> RecoverError<S> {
pub(crate) fn new(inner: S) -> Self {
Self { inner }
}
}

impl<S, R> Service<R> for RecoverError<S>
where
S: Service<R, Response = Response<BoxBody>>,
S::Error: Into<crate::Error>,
{
type Response = Response<BoxBody>;
type Error = crate::Error;
type Future = ResponseFuture<S::Future>;

fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
self.inner.poll_ready(cx).map_err(Into::into)
}

fn call(&mut self, req: R) -> Self::Future {
ResponseFuture {
inner: self.inner.call(req),
}
}
}

#[pin_project]
pub(crate) struct ResponseFuture<F> {
#[pin]
inner: F,
}

impl<F, E> Future for ResponseFuture<F>
where
F: Future<Output = Result<Response<BoxBody>, E>>,
E: Into<crate::Error>,
{
type Output = Result<Response<BoxBody>, crate::Error>;

fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
let result: Result<Response<BoxBody>, crate::Error> =
ready!(self.project().inner.poll(cx)).map_err(Into::into);

match result {
Ok(res) => Poll::Ready(Ok(res)),
Err(err) => {
if let Some(status) = Status::try_from_error(&*err) {
let mut res = Response::new(BoxBody::empty());
status.add_header(res.headers_mut()).unwrap();
Poll::Ready(Ok(res))
} else {
Poll::Ready(Err(err))
}
}
}
}
}
5 changes: 2 additions & 3 deletions tonic/src/transport/service/connection.rs
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
use super::super::BoxFuture;
use super::{reconnect::Reconnect, AddOrigin, UserAgent};
use super::{grpc_timeout::GrpcTimeout, reconnect::Reconnect, AddOrigin, UserAgent};
use crate::{body::BoxBody, transport::Endpoint};
use http::Uri;
use hyper::client::conn::Builder;
Expand All @@ -14,7 +14,6 @@ use tower::load::Load;
use tower::{
layer::Layer,
limit::{concurrency::ConcurrencyLimitLayer, rate::RateLimitLayer},
timeout::TimeoutLayer,
util::BoxService,
ServiceBuilder, ServiceExt,
};
Expand Down Expand Up @@ -53,7 +52,7 @@ impl Connection {
let stack = ServiceBuilder::new()
.layer_fn(|s| AddOrigin::new(s, endpoint.uri.clone()))
.layer_fn(|s| UserAgent::new(s, endpoint.user_agent.clone()))
.option_layer(endpoint.timeout.map(TimeoutLayer::new))
.layer_fn(|s| GrpcTimeout::new(s, endpoint.timeout))
.option_layer(endpoint.concurrency_limit.map(ConcurrencyLimitLayer::new))
.option_layer(endpoint.rate_limit.map(|(l, d)| RateLimitLayer::new(l, d)))
.into_inner();
Expand Down
Loading

0 comments on commit 9ff4f7b

Please sign in to comment.