Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: add auto_encoding and improve send_compressed behavior #2058

Open
wants to merge 10 commits into
base: master
Choose a base branch
from
5 changes: 3 additions & 2 deletions Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -5,10 +5,10 @@ members = [
"tonic-health",
"tonic-types",
"tonic-reflection",
"tonic-web", # Non-published crates
"tonic-web", # Non-published crates
"examples",
"codegen",
"interop", # Tests
"interop", # Tests
"tests/disable_comments",
"tests/included_service",
"tests/same_name",
Expand All @@ -22,6 +22,7 @@ members = [
"tests/stream_conflict",
"tests/root-crate-path",
"tests/compression",
"tests/various_compression_formats",
"tests/web",
"tests/service_named_result",
"tests/use_arc_self",
Expand Down
13 changes: 13 additions & 0 deletions tests/various_compression_formats/Cargo.toml
Original file line number Diff line number Diff line change
@@ -0,0 +1,13 @@
[package]
name = "various_compression_formats"
version = "0.1.0"
edition = "2021"
publish = false

[dependencies]
prost = "0.13"
tonic = { path = "../../tonic", features = ["gzip","zstd"]}
tokio = { version = "1.36.2", features = ["macros", "rt-multi-thread"] }

[build-dependencies]
tonic-build = { path = "../../tonic-build" }
4 changes: 4 additions & 0 deletions tests/various_compression_formats/build.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,4 @@
fn main() -> Result<(), Box<dyn std::error::Error>> {
tonic_build::compile_protos("proto/proto_box.proto")?;
Ok(())
}
15 changes: 15 additions & 0 deletions tests/various_compression_formats/proto/proto_box.proto
Original file line number Diff line number Diff line change
@@ -0,0 +1,15 @@
syntax = "proto3";

package proto_box;

service ProtoService {
rpc Rpc(Input) returns (Output);
}

message Input {
string data = 1;
}

message Output {
string data = 1;
}
3 changes: 3 additions & 0 deletions tests/various_compression_formats/src/lib.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
pub mod proto_box {
tonic::include_proto!("proto_box");
}
206 changes: 206 additions & 0 deletions tests/various_compression_formats/tests/auto_encoding.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,206 @@
use std::error::Error;

use tokio::net::TcpListener;
use tokio::sync::oneshot;

use tonic::codegen::CompressionEncoding;
use tonic::transport::{server::TcpIncoming, Channel, Server};
use tonic::{Request, Response, Status};

use various_compression_formats::proto_box::{
proto_service_client::ProtoServiceClient,
proto_service_server::{ProtoService, ProtoServiceServer},
Input, Output,
};

const LOCALHOST: &str = "127.0.0.1:0";

#[derive(Default)]
pub struct ServerTest;

#[tonic::async_trait]
impl ProtoService for ServerTest {
async fn rpc(&self, request: Request<Input>) -> Result<Response<Output>, Status> {
println!("Server received request: {:?}", request);

Ok(Response::new(Output {
data: format!("Received: {}", request.into_inner().data),
}))
}
}

struct ClientWrapper {
client: ProtoServiceClient<Channel>,
}

impl ClientWrapper {
async fn new(
address: &str,
accept: Option<CompressionEncoding>,
) -> Result<Self, Box<dyn Error + Send + Sync>> {
let channel = Channel::from_shared(address.to_string())?.connect().await?;
let mut client = ProtoServiceClient::new(channel);

if let Some(encoding) = accept {
client = client.accept_compressed(encoding);
}

Ok(Self { client })
}

async fn send_request(
&mut self,
data: String,
) -> Result<Response<Output>, Box<dyn Error + Send + Sync>> {
let request = Request::new(Input { data });

println!("Client sending request: {:?}", request);

let response = self.client.rpc(request).await?;

println!("Client response headers: {:?}", response.metadata());

Ok(response)
}
}

async fn start_server(
listener: TcpListener,
send: Option<CompressionEncoding>,
auto: bool,
) -> oneshot::Sender<()> {
let (shutdown_tx, shutdown_rx) = oneshot::channel::<()>();
let srv = ServerTest::default();
let mut service = ProtoServiceServer::new(srv);

if let Some(encoding) = send {
service = service.send_compressed(encoding);
}

if auto {
service = service.auto_encoding();
}

let server = Server::builder()
.add_service(service)
.serve_with_incoming_shutdown(
TcpIncoming::from_listener(listener, true, None).unwrap(),
async {
shutdown_rx.await.ok();
},
);

tokio::spawn(async move {
server.await.expect("Server crashed");
});

shutdown_tx
}

async fn run_client_test(
address: &str,
client_accept: Option<CompressionEncoding>,
expected_encoding: Option<&str>,
data: &str,
) -> Result<(), Box<dyn Error + Send + Sync>> {
let mut client = ClientWrapper::new(address, client_accept).await?;
let response = client.send_request(data.to_string()).await?;

match expected_encoding {
Some(encoding) => {
let grpc_encoding = response
.metadata()
.get("grpc-encoding")
.expect("Missing 'grpc-encoding' header");
assert_eq!(grpc_encoding, encoding);
}
None => {
assert!(
!response.metadata().contains_key("grpc-encoding"),
"Expected no 'grpc-encoding' header"
);
}
}

Ok(())
}

#[tokio::test]
async fn test_compression_behavior() -> Result<(), Box<dyn Error + Send + Sync>> {
let listener = TcpListener::bind(LOCALHOST).await?;
let address = format!("http://{}", listener.local_addr().unwrap());

// The server is not specified to send data with any compression
let shutdown_tx = start_server(listener, None, false).await;

tokio::time::sleep(std::time::Duration::from_secs(1)).await;

tokio::try_join!(
// Client 1 can only accept gzip encoding or uncompressed,
// so all data must be returned uncompressed
run_client_test(&address, Some(CompressionEncoding::Gzip), None, "Client 1"),
// Client 2 can only accept non-compressed data,
// so all data must be returned uncompressed
run_client_test(&address, None, None, "Client 2")
)?;

shutdown_tx.send(()).unwrap();

let listener = TcpListener::bind(LOCALHOST).await?;
let address = format!("http://{}", listener.local_addr().unwrap());

// The server is specified to send data with zstd compression
let shutdown_tx = start_server(listener, Some(CompressionEncoding::Zstd), false).await;

tokio::time::sleep(std::time::Duration::from_secs(1)).await;

tokio::try_join!(
// Client 3 can only accept zstd encoding or uncompressed,
// so all data must be returned compressed with zstd
run_client_test(
&address,
Some(CompressionEncoding::Zstd),
Some("zstd"),
"Client 3"
),
// Client 4 can only accept Gzip encoding or uncompressed,
// so all data must be returned uncompressed
run_client_test(&address, Some(CompressionEncoding::Gzip), None, "Client 4")
)?;

shutdown_tx.send(()).unwrap();

Ok(())
}

#[tokio::test]
async fn test_auto_encoding_behavior() -> Result<(), Box<dyn Error + Send + Sync>> {
let listener = TcpListener::bind(LOCALHOST).await?;
let address = format!("http://{}", listener.local_addr().unwrap());

// The server returns in the compression format that the client prefers
let shutdown_tx = start_server(listener, Some(CompressionEncoding::Gzip), true).await;

tokio::time::sleep(std::time::Duration::from_secs(1)).await;

tokio::try_join!(
// Client 5 can accept gzip encoding or uncompressed, so all data must be returned compressed with gzip
run_client_test(
&address,
Some(CompressionEncoding::Gzip),
Some("gzip"),
"Client 5"
),
// Client 6 can accept zstd encoding or uncompressed, so all data must be returned compressed with zstd
run_client_test(
&address,
Some(CompressionEncoding::Zstd),
Some("zstd"),
"Client 6"
)
)?;

shutdown_tx.send(()).unwrap();

Ok(())
}
31 changes: 31 additions & 0 deletions tonic-build/src/server.rs
Original file line number Diff line number Diff line change
Expand Up @@ -70,6 +70,13 @@ pub(crate) fn generate_internal<T: Service>(
self.send_compression_encodings.enable(encoding);
self
}

/// Automatically determine the encoding to use based on the request headers.
#[must_use]
pub fn auto_encoding(mut self) -> Self {
self.auto_encoding = true;
self
}
};

let configure_max_message_size_methods = quote! {
Expand Down Expand Up @@ -117,6 +124,7 @@ pub(crate) fn generate_internal<T: Service>(
send_compression_encodings: EnabledCompressionEncodings,
max_decoding_message_size: Option<usize>,
max_encoding_message_size: Option<usize>,
auto_encoding: bool,
}

impl<T> #server_service<T> {
Expand All @@ -131,6 +139,7 @@ pub(crate) fn generate_internal<T: Service>(
send_compression_encodings: Default::default(),
max_decoding_message_size: None,
max_encoding_message_size: None,
auto_encoding: false,
}
}

Expand Down Expand Up @@ -184,6 +193,7 @@ pub(crate) fn generate_internal<T: Service>(
send_compression_encodings: self.send_compression_encodings,
max_decoding_message_size: self.max_decoding_message_size,
max_encoding_message_size: self.max_encoding_message_size,
auto_encoding: self.auto_encoding,
}
}
}
Expand Down Expand Up @@ -473,6 +483,7 @@ fn generate_unary<T: Method>(
let send_compression_encodings = self.send_compression_encodings;
let max_decoding_message_size = self.max_decoding_message_size;
let max_encoding_message_size = self.max_encoding_message_size;
let auto_encoding = self.auto_encoding;
let inner = self.inner.clone();
let fut = async move {
let method = #service_ident(inner);
Expand All @@ -482,6 +493,10 @@ fn generate_unary<T: Method>(
.apply_compression_config(accept_compression_encodings, send_compression_encodings)
.apply_max_message_size_config(max_decoding_message_size, max_encoding_message_size);

if auto_encoding {
grpc = grpc.auto_encoding();
}

let res = grpc.unary(method, req).await;
Ok(res)
};
Expand Down Expand Up @@ -540,6 +555,7 @@ fn generate_server_streaming<T: Method>(
let send_compression_encodings = self.send_compression_encodings;
let max_decoding_message_size = self.max_decoding_message_size;
let max_encoding_message_size = self.max_encoding_message_size;
let auto_encoding = self.auto_encoding;
let inner = self.inner.clone();
let fut = async move {
let method = #service_ident(inner);
Expand All @@ -549,6 +565,10 @@ fn generate_server_streaming<T: Method>(
.apply_compression_config(accept_compression_encodings, send_compression_encodings)
.apply_max_message_size_config(max_decoding_message_size, max_encoding_message_size);

if auto_encoding {
grpc = grpc.auto_encoding();
}

let res = grpc.server_streaming(method, req).await;
Ok(res)
};
Expand Down Expand Up @@ -598,6 +618,7 @@ fn generate_client_streaming<T: Method>(
let send_compression_encodings = self.send_compression_encodings;
let max_decoding_message_size = self.max_decoding_message_size;
let max_encoding_message_size = self.max_encoding_message_size;
let auto_encoding = self.auto_encoding;
let inner = self.inner.clone();
let fut = async move {
let method = #service_ident(inner);
Expand All @@ -607,6 +628,10 @@ fn generate_client_streaming<T: Method>(
.apply_compression_config(accept_compression_encodings, send_compression_encodings)
.apply_max_message_size_config(max_decoding_message_size, max_encoding_message_size);

if auto_encoding {
grpc = grpc.auto_encoding();
}

let res = grpc.client_streaming(method, req).await;
Ok(res)
};
Expand Down Expand Up @@ -666,7 +691,9 @@ fn generate_streaming<T: Method>(
let send_compression_encodings = self.send_compression_encodings;
let max_decoding_message_size = self.max_decoding_message_size;
let max_encoding_message_size = self.max_encoding_message_size;
let auto_encoding = self.auto_encoding;
let inner = self.inner.clone();

let fut = async move {
let method = #service_ident(inner);
let codec = #codec_name::default();
Expand All @@ -675,6 +702,10 @@ fn generate_streaming<T: Method>(
.apply_compression_config(accept_compression_encodings, send_compression_encodings)
.apply_max_message_size_config(max_decoding_message_size, max_encoding_message_size);

if auto_encoding {
grpc = grpc.auto_encoding();
}

let res = grpc.streaming(method, req).await;
Ok(res)
};
Expand Down
Loading