Skip to content

Commit

Permalink
fix(codec): Fix streaming reponses w/ many status (#689)
Browse files Browse the repository at this point in the history
Closes #681
  • Loading branch information
LucioFranco authored Jul 1, 2021
1 parent 2b60a00 commit 737ace3
Show file tree
Hide file tree
Showing 8 changed files with 100 additions and 6 deletions.
1 change: 1 addition & 0 deletions tests/integration_tests/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@ futures = "0.3"
tower = { version = "0.4", features = [] }
http-body = "0.4"
http = "0.2"
tracing-subscriber = "0.2"

[build-dependencies]
tonic-build = { path = "../../tonic-build" }
1 change: 1 addition & 0 deletions tests/integration_tests/build.rs
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
fn main() {
tonic_build::compile_protos("proto/test.proto").unwrap();
tonic_build::compile_protos("proto/stream.proto").unwrap();
}
10 changes: 10 additions & 0 deletions tests/integration_tests/proto/stream.proto
Original file line number Diff line number Diff line change
@@ -0,0 +1,10 @@
syntax = "proto3";

package stream;

service TestStream {
rpc StreamCall(InputStream) returns (stream OutputStream);
}

message InputStream {}
message OutputStream {}
1 change: 1 addition & 0 deletions tests/integration_tests/src/lib.rs
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
pub mod pb {
tonic::include_proto!("test");
tonic::include_proto!("stream");
}
63 changes: 62 additions & 1 deletion tests/integration_tests/tests/status.rs
Original file line number Diff line number Diff line change
@@ -1,6 +1,9 @@
use bytes::Bytes;
use futures_util::FutureExt;
use integration_tests::pb::{test_client, test_server, Input, Output};
use integration_tests::pb::{
test_client, test_server, test_stream_client, test_stream_server, Input, InputStream, Output,
OutputStream,
};
use std::time::Duration;
use tokio::sync::oneshot;
use tonic::metadata::{MetadataMap, MetadataValue};
Expand Down Expand Up @@ -117,3 +120,61 @@ async fn status_with_metadata() {

jh.await.unwrap();
}

type Stream<T> = std::pin::Pin<
Box<dyn futures::Stream<Item = std::result::Result<T, Status>> + Send + Sync + 'static>,
>;

#[tokio::test]
async fn status_from_server_stream() {
trace_init();

struct Svc;

#[tonic::async_trait]
impl test_stream_server::TestStream for Svc {
type StreamCallStream = Stream<OutputStream>;

async fn stream_call(
&self,
_: Request<InputStream>,
) -> Result<Response<Self::StreamCallStream>, Status> {
let s = futures::stream::iter(vec![
Err::<OutputStream, _>(Status::unavailable("foo")),
Err::<OutputStream, _>(Status::unavailable("bar")),
]);
Ok(Response::new(Box::pin(s) as Self::StreamCallStream))
}
}

let svc = test_stream_server::TestStreamServer::new(Svc);

tokio::spawn(async move {
Server::builder()
.add_service(svc)
.serve("127.0.0.1:1339".parse().unwrap())
.await
.unwrap();
});

tokio::time::sleep(Duration::from_millis(100)).await;

let mut client = test_stream_client::TestStreamClient::connect("http://127.0.0.1:1339")
.await
.unwrap();

let mut stream = client
.stream_call(InputStream {})
.await
.unwrap()
.into_inner();

assert_eq!(stream.message().await.unwrap_err().message(), "foo");
assert_eq!(stream.message().await.unwrap(), None);
}

fn trace_init() {
let _ = tracing_subscriber::FmtSubscriber::builder()
.with_env_filter(tracing_subscriber::EnvFilter::from_default_env())
.try_init();
}
6 changes: 5 additions & 1 deletion tonic/src/codec/decode.rs
Original file line number Diff line number Diff line change
Expand Up @@ -258,7 +258,11 @@ impl<T> Stream for Streaming<T> {
match ready!(Pin::new(&mut self.body).poll_trailers(cx)) {
Ok(trailer) => {
if let Err(e) = crate::status::infer_grpc_status(trailer.as_ref(), status) {
return Some(Err(e)).into();
if let Some(e) = e {
return Some(Err(e)).into();
} else {
return Poll::Ready(None);
}
} else {
self.trailers = trailer.map(MetadataMap::from_headers);
}
Expand Down
11 changes: 10 additions & 1 deletion tonic/src/codec/encode.rs
Original file line number Diff line number Diff line change
Expand Up @@ -88,6 +88,7 @@ pub(crate) struct EncodeBody<S> {
inner: S,
error: Option<Status>,
role: Role,
is_end_stream: bool,
}

impl<S> EncodeBody<S>
Expand All @@ -99,6 +100,7 @@ where
inner,
error: None,
role: Role::Client,
is_end_stream: false,
}
}

Expand All @@ -107,6 +109,7 @@ where
inner,
error: None,
role: Role::Server,
is_end_stream: false,
}
}
}
Expand All @@ -119,7 +122,7 @@ where
type Error = Status;

fn is_end_stream(&self) -> bool {
false
self.is_end_stream
}

fn poll_data(
Expand Down Expand Up @@ -148,7 +151,13 @@ where
Role::Client => Poll::Ready(Ok(None)),
Role::Server => {
let self_proj = self.project();

if *self_proj.is_end_stream {
return Poll::Ready(Ok(None));
}

let status = if let Some(status) = self_proj.error.take() {
*self_proj.is_end_stream = true;
status
} else {
Status::new(Code::Ok, "")
Expand Down
13 changes: 10 additions & 3 deletions tonic/src/status.rs
Original file line number Diff line number Diff line change
Expand Up @@ -657,13 +657,13 @@ impl Error for Status {
pub(crate) fn infer_grpc_status(
trailers: Option<&HeaderMap>,
status_code: http::StatusCode,
) -> Result<(), Status> {
) -> Result<(), Option<Status>> {
if let Some(trailers) = trailers {
if let Some(status) = Status::from_header_map(&trailers) {
if status.code() == Code::Ok {
return Ok(());
} else {
return Err(status);
return Err(status.into());
}
}
}
Expand All @@ -678,6 +678,13 @@ pub(crate) fn infer_grpc_status(
| http::StatusCode::BAD_GATEWAY
| http::StatusCode::SERVICE_UNAVAILABLE
| http::StatusCode::GATEWAY_TIMEOUT => Code::Unavailable,
// We got a 200 but no trailers, we can infer that this request is finished.
//
// This can happen when a streaming response sends two Status but
// gRPC requires that we end the stream after the first status.
//
// https://github.com/hyperium/tonic/issues/681
http::StatusCode::OK => return Err(None),
_ => Code::Unknown,
};

Expand All @@ -686,7 +693,7 @@ pub(crate) fn infer_grpc_status(
status_code.as_u16(),
);
let status = Status::new(code, msg);
Err(status)
Err(status.into())
}

// ===== impl Code =====
Expand Down

0 comments on commit 737ace3

Please sign in to comment.