diff --git a/examples/Cargo.toml b/examples/Cargo.toml
index 6edecbed6..531ca9109 100644
--- a/examples/Cargo.toml
+++ b/examples/Cargo.toml
@@ -139,7 +139,7 @@ name = "hyper-warp-multiplex-server"
 path = "src/hyper_warp_multiplex/server.rs"
 
 [dependencies]
-tonic = { path = "../tonic", features = ["tls"] }
+tonic = { path = "../tonic", features = ["tls", "gzip"] }
 prost = "0.6"
 tokio = { version = "0.2", features = ["rt-threaded", "time", "stream", "fs", "macros", "uds"] }
 futures = { version = "0.3", default-features = false, features = ["alloc"] }
diff --git a/tonic/Cargo.toml b/tonic/Cargo.toml
index 39b0276cc..75166e55e 100644
--- a/tonic/Cargo.toml
+++ b/tonic/Cargo.toml
@@ -36,6 +36,7 @@ transport = [
 tls = ["transport", "tokio-rustls"]
 tls-roots = ["tls", "rustls-native-certs"]
 prost = ["prost1", "prost-derive"]
+gzip = ["flate2"]
 
 # [[bench]]
 # name = "bench_main"
@@ -48,6 +49,8 @@ futures-util = { version = "0.3", default-features = false }
 tracing = "0.1"
 http = "0.2"
 base64 = "0.12"
+flate2 = { version = "1.0", optional = true }
+once_cell = "1.0"
 percent-encoding = "2.0"
 tower-service = "0.3"
 
@@ -89,4 +92,3 @@ rustdoc-args = ["--cfg", "docsrs"]
 [[bench]]
 name = "decode"
 harness = false
-
diff --git a/tonic/benches/decode.rs b/tonic/benches/decode.rs
index 529e69377..a4e84ff0d 100644
--- a/tonic/benches/decode.rs
+++ b/tonic/benches/decode.rs
@@ -6,24 +6,29 @@ use std::{
     pin::Pin,
     task::{Context, Poll},
 };
-use tonic::{codec::DecodeBuf, codec::Decoder, Status, Streaming};
+use tonic::{codec::DecodeBuf, codec::Decoder, codec::Decompression, Status, Streaming};
 
 macro_rules! bench {
     ($name:ident, $message_size:expr, $chunk_size:expr, $message_count:expr) => {
+        bench!($name, $message_size, $chunk_size, $message_count, None);
+    };
+    ($name:ident, $message_size:expr, $chunk_size:expr, $message_count:expr, $encoding:expr) => {
         fn $name(b: &mut Bencher) {
             let mut rt = tokio::runtime::Builder::new()
                 .basic_scheduler()
                 .build()
                 .expect("runtime");
 
-            let payload = make_payload($message_size, $message_count);
+            let payload = make_payload($message_size, $message_count, $encoding);
             let body = MockBody::new(payload, $chunk_size);
             b.bytes = body.len() as u64;
 
             b.iter(|| {
                 rt.block_on(async {
                     let decoder = MockDecoder::new($message_size);
-                    let mut stream = Streaming::new_request(decoder, body.clone());
+
+                    let decompression = Decompression::from_encoding($encoding);
+                    let mut stream = Streaming::new_request(decoder, body.clone(), decompression);
 
                     let mut count = 0;
                     while let Some(msg) = stream.message().await.unwrap() {
@@ -108,15 +113,38 @@ impl Decoder for MockDecoder {
     }
 }
 
-fn make_payload(message_length: usize, message_count: usize) -> Bytes {
+fn make_payload(message_length: usize, message_count: usize, encoding: Option<&str>) -> Bytes {
     let mut buf = BytesMut::new();
+    let raw_msg = vec![97u8; message_length];
+
+    let msg_buf = match encoding {
+        #[cfg(feature = "gzip")]
+        Some(encoding) if encoding == "gzip" => {
+            use bytes::buf::BufMutExt;
+            let mut reader =
+                flate2::read::GzEncoder::new(&raw_msg[..], flate2::Compression::best());
+            let mut writer = BytesMut::new().writer();
+
+            std::io::copy(&mut reader, &mut writer).expect("copy");
+            writer.into_inner()
+        }
+        None => {
+            let mut msg_buf = BytesMut::new();
+            msg_buf.put(&raw_msg[..]);
+            msg_buf
+        }
+        Some(encoding) => panic!("Encoding {} isn't supported", encoding),
+    };
+
     for _ in 0..message_count {
-        let msg = vec![97u8; message_length];
-        buf.reserve(msg.len() + 5);
-        buf.put_u8(0);
-        buf.put_u32(msg.len() as u32);
-        buf.put(&msg[..]);
+        buf.reserve(msg_buf.len() + 5);
+        buf.put_u8(match encoding {
+            Some(_) => 1,
+            None => 0,
+        });
+        buf.put_u32(msg_buf.len() as u32);
+        buf.put(&msg_buf[..]);
     }
 
     buf.freeze()
 }
 
@@ -137,6 +165,21 @@ bench!(message_count_1, 500, 505, 1);
 bench!(message_count_10, 500, 505, 10);
 bench!(message_count_20, 500, 505, 20);
 
+// gzip change body chunk size only
+bench!(chunk_size_100_gzip, 1_000, 100, 1, Some("gzip"));
+bench!(chunk_size_500_gzip, 1_000, 500, 1, Some("gzip"));
+bench!(chunk_size_1005_gzip, 1_000, 1_005, 1, Some("gzip"));
+
+// gzip change message size only
+bench!(message_size_1k_gzip, 1_000, 1_005, 2, Some("gzip"));
+bench!(message_size_5k_gzip, 5_000, 1_005, 2, Some("gzip"));
+bench!(message_size_10k_gzip, 10_000, 1_005, 2, Some("gzip"));
+
+// gzip change message count only
+bench!(message_count_1_gzip, 500, 505, 1, Some("gzip"));
+bench!(message_count_10_gzip, 500, 505, 10, Some("gzip"));
+bench!(message_count_20_gzip, 500, 505, 20, Some("gzip"));
+
 benchmark_group!(chunk_size, chunk_size_100, chunk_size_500, chunk_size_1005);
 
 benchmark_group!(
@@ -153,4 +196,36 @@ benchmark_group!(
     message_count_20
 );
 
+benchmark_group!(
+    chunk_size_gzip,
+    chunk_size_100_gzip,
+    chunk_size_500_gzip,
+    chunk_size_1005_gzip
+);
+
+benchmark_group!(
+    message_size_gzip,
+    message_size_1k_gzip,
+    message_size_5k_gzip,
+    message_size_10k_gzip
+);
+
+benchmark_group!(
+    message_count_gzip,
+    message_count_1_gzip,
+    message_count_10_gzip,
+    message_count_20_gzip
+);
+
+#[cfg(feature = "gzip")]
+benchmark_main!(
+    chunk_size,
+    message_size,
+    message_count,
+    chunk_size_gzip,
+    message_size_gzip,
+    message_count_gzip
+);
+
+#[cfg(not(feature = "gzip"))]
 benchmark_main!(chunk_size, message_size, message_count);
diff --git a/tonic/src/client/grpc.rs b/tonic/src/client/grpc.rs
index 6a42b4022..d1554abae 100644
--- a/tonic/src/client/grpc.rs
+++ b/tonic/src/client/grpc.rs
@@ -1,7 +1,7 @@
 use crate::{
     body::{Body, BoxBody},
     client::GrpcService,
-    codec::{encode_client, Codec, Streaming},
+    codec::{encode_client, Codec, Compression, Decompression, Streaming},
     interceptor::Interceptor,
     Code, Request, Response, Status,
 };
@@ -159,11 +159,13 @@ impl<T> Grpc<T>
         let uri = Uri::from_parts(parts).expect("path_and_query only is valid Uri");
 
+        let compression = Compression::disabled();
         let request = request
-            .map(|s| encode_client(codec.encoder(), s))
+            .map(|s| encode_client(codec.encoder(), s, compression.clone()))
            .map(BoxBody::new);
 
         let mut request = request.into_http(uri);
+        compression.set_headers(request.headers_mut(), true);
 
         // Add the gRPC related HTTP headers
         request
@@ -196,9 +198,10 @@ impl<T> Grpc<T>
             true
         };
 
+        let decompression = Decompression::from_headers(response.headers());
         let response = response.map(|body| {
             if expect_additional_trailers {
-                Streaming::new_response(codec.decoder(), body, status_code)
+                Streaming::new_response(codec.decoder(), body, status_code, decompression)
             } else {
                 Streaming::new_empty(codec.decoder(), body)
             }
diff --git a/tonic/src/codec/compression/bufwriter.rs b/tonic/src/codec/compression/bufwriter.rs
new file mode 100644
index 000000000..4c4e0cd67
--- /dev/null
+++ b/tonic/src/codec/compression/bufwriter.rs
@@ -0,0 +1,27 @@
+use bytes::BufMut;
+
+use std::{cmp, io};
+
+/// A `BufMut` adapter which implements `io::Write` for the inner value.
+#[derive(Debug)]
+pub(crate) struct Writer<'a, B> {
+    buf: &'a mut B,
+}
+
+#[cfg(feature = "gzip")]
+pub(crate) fn new<'a, B>(buf: &'a mut B) -> Writer<'a, B> {
+    Writer { buf }
+}
+
+impl<'a, B: BufMut + Sized> io::Write for Writer<'a, B> {
+    fn write(&mut self, src: &[u8]) -> io::Result<usize> {
+        let n = cmp::min(self.buf.remaining_mut(), src.len());
+
+        self.buf.put(&src[0..n]);
+        Ok(n)
+    }
+
+    fn flush(&mut self) -> io::Result<()> {
+        Ok(())
+    }
+}
diff --git a/tonic/src/codec/compression/compression.rs b/tonic/src/codec/compression/compression.rs
new file mode 100644
index 000000000..cca8a5149
--- /dev/null
+++ b/tonic/src/codec/compression/compression.rs
@@ -0,0 +1,97 @@
+use super::{
+    compressors::{self, IDENTITY},
+    errors::CompressionError,
+    Compressor, ACCEPT_ENCODING_HEADER, ENCODING_HEADER,
+};
+use crate::metadata::MetadataMap;
+use bytes::{Buf, BytesMut};
+use http::HeaderValue;
+use std::fmt::Debug;
+use tracing::debug;
+
+pub(crate) const BUFFER_SIZE: usize = 8 * 1024;
+
+#[derive(Clone)]
+pub(crate) struct Compression {
+    compressor: Option<&'static Box<dyn Compressor>>,
+}
+
+impl Debug for Compression {
+    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
+        f.debug_struct("Compression")
+            .field(
+                "compressor",
+                &self.compressor.map(|c| c.name()).unwrap_or(IDENTITY),
+            )
+            .finish()
+    }
+}
+
+impl Compression {
+    /// Create an instance of compression that doesn't compress anything
+    pub(crate) fn disabled() -> Compression {
+        Compression { compressor: None }
+    }
+
+    /// Create an instance of compression from gRPC metadata
+    pub(crate) fn response_from_metadata(request_metadata: &MetadataMap) -> Compression {
+        // This is deliberately conservative, like the Go gRPC implementation: instead of looking at
+        // `grpc-accept-encoding` and potentially compressing the response with a different compressor
+        // than the one used by the request, the response reuses the request's compressor.
+        let request_compressor = request_metadata
+            .get(ENCODING_HEADER)
+            .and_then(|v| v.to_str().ok())
+            .and_then(compressors::get);
+
+        Compression {
+            compressor: request_compressor,
+        }
+    }
+
+    /// Get if compression is enabled
+    pub(crate) fn is_enabled(&self) -> bool {
+        self.compressor.is_some()
+    }
+
+    /// Compress `len` bytes from `in_buffer` into `out_buffer`
+    pub(crate) fn compress(
+        &self,
+        in_buffer: &mut BytesMut,
+        out_buffer: &mut BytesMut,
+        len: usize,
+    ) -> Result<(), CompressionError> {
+        let capacity = ((len / BUFFER_SIZE) + 1) * BUFFER_SIZE;
+        out_buffer.reserve(capacity);
+
+        let compressor = self.compressor.ok_or(CompressionError::NoCompression)?;
+        compressor.compress(in_buffer, out_buffer, len)?;
+        in_buffer.advance(len);
+
+        debug!(
+            "Compressed {} bytes into {} bytes using {:?}",
+            len,
+            out_buffer.len(),
+            compressor.name()
+        );
+
+        Ok(())
+    }
+
+    /// Set the `grpc-encoding` header with the compressor name
+    pub(crate) fn set_headers(&self, headers: &mut http::HeaderMap, set_accept_encoding: bool) {
+        if set_accept_encoding {
+            headers.insert(
+                ACCEPT_ENCODING_HEADER,
+                HeaderValue::from_str(&compressors::get_accept_encoding_header())
+                    .expect("All encoding names should be ASCII"),
+            );
+        }
+
+        match self.compressor {
+            None => {}
+            Some(compressor) => {
+                headers.insert(ENCODING_HEADER, HeaderValue::from_static(compressor.name()));
+            }
+        }
+    }
+}
diff --git a/tonic/src/codec/compression/compressors.rs b/tonic/src/codec/compression/compressors.rs
new file mode 100644
index 000000000..8da32861b
--- /dev/null
+++ b/tonic/src/codec/compression/compressors.rs
@@ -0,0 +1,69 @@
+use bytes::BytesMut;
+use once_cell::sync::Lazy;
+use std::{collections::HashMap, io};
+
+pub(crate) const IDENTITY: &str = "identity";
+
+/// List of known compressors
+static COMPRESSORS: Lazy<HashMap<String, Box<dyn Compressor>>> = Lazy::new(|| {
+    #[cfg(feature = "gzip")]
+    {
+        let mut m = HashMap::new();
+
+        let mut add = |compressor: Box<dyn Compressor>| {
+            m.insert(compressor.name().to_string(), compressor);
+        };
+
+        add(Box::new(super::gzip::GZipCompressor::default()));
+
+        m
+    }
+
+    #[cfg(not(feature = "gzip"))]
+    HashMap::new()
+});
+
+/// Get a compressor from its name
+pub(crate) fn get(name: impl AsRef<str>) -> Option<&'static Box<dyn Compressor>> {
+    COMPRESSORS.get(name.as_ref())
+}
+
+/// Get all the known compressors
+pub(crate) fn names() -> Vec<String> {
+    COMPRESSORS.keys().map(|n| n.clone()).collect()
+}
+
+/// A compressor implements compression and decompression of gRPC frames
+pub(crate) trait Compressor: Sync + Send {
+    /// Get the name of this compressor as present in http headers
+    fn name(&self) -> &'static str;
+
+    /// Decompress `len` bytes from `in_buffer` into `out_buffer`
+    fn decompress(
+        &self,
+        in_buffer: &BytesMut,
+        out_buffer: &mut BytesMut,
+        len: usize,
+    ) -> io::Result<()>;
+
+    /// Compress `len` bytes from `in_buffer` into `out_buffer`
+    fn compress(
+        &self,
+        in_buffer: &BytesMut,
+        out_buffer: &mut BytesMut,
+        len: usize,
+    ) -> io::Result<()>;
+
+    /// Estimate the space necessary to decompress `compressed_len` bytes of compressed data
+    fn estimate_decompressed_len(&self, compressed_len: usize) -> usize {
+        compressed_len * 2
+    }
+}
+
+pub(crate) fn get_accept_encoding_header() -> String {
+    COMPRESSORS
+        .keys()
+        .map(|s| &**s)
+        .collect::<Vec<_>>()
+        .join(",")
+}
diff --git a/tonic/src/codec/compression/decompression.rs b/tonic/src/codec/compression/decompression.rs
new file mode 100644
index 000000000..3cb2d64a9
--- /dev/null
+++ b/tonic/src/codec/compression/decompression.rs
@@ -0,0 +1,94 @@
+use bytes::{Buf, BytesMut};
+use std::fmt::Debug;
+use tracing::debug;
+
+use super::{
+    compressors::{self, IDENTITY},
+    Compressor, DecompressionError, ENCODING_HEADER,
+};
+
+const BUFFER_SIZE: usize = 8 * 1024;
+
+/// Information related to the decompression of a request or response
+pub struct Decompression {
+    encoding: Option<String>,
+    compressor: Option<&'static Box<dyn Compressor>>,
+}
+
+impl Debug for Decompression {
+    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
+        let encoding = self.encoding.as_ref().map(|e| &e[..]).unwrap_or("");
+        f.debug_struct("Decompression")
+            .field("encoding", &encoding)
+            .field(
+                "compressor",
+                &self.compressor.map(|c| c.name()).unwrap_or(""),
+            )
+            .finish()
+    }
+}
+
+impl Decompression {
+    /// Create a `Decompression` structure from an encoding name
+    pub fn from_encoding(encoding: Option<&str>) -> Decompression {
+        let compressor = encoding.and_then(compressors::get);
+
+        Decompression {
+            encoding: encoding.map(|v| v.to_string()),
+            compressor,
+        }
+    }
+
+    /// Create a `Decompression` structure from http headers
+    pub fn from_headers(metadata: &http::HeaderMap) -> Decompression {
+        let encoding = metadata
+            .get(ENCODING_HEADER)
+            .and_then(|v| v.to_str().ok())
+            .and_then(|v| if v == IDENTITY { None } else { Some(v) });
+
+        Decompression::from_encoding(encoding)
+    }
+
+    /// Decompress `len` bytes from `in_buffer` into `out_buffer`
+    pub fn decompress(
+        &self,
+        in_buffer: &mut BytesMut,
+        out_buffer: &mut BytesMut,
+        len: usize,
+    ) -> Result<(), DecompressionError> {
+        let compressor = self.compressor.ok_or_else(|| {
+            match &self.encoding {
+                // Asked to decompress but no compression was specified
+                None => DecompressionError::NoCompression,
+                // Asked to decompress but the decompressor wasn't found
+                Some(encoding) => DecompressionError::NotFound {
+                    requested: encoding.clone(),
+                    known: compressors::names(),
+                },
+            }
+        })?;
+
+        let capacity =
+            ((compressor.estimate_decompressed_len(len) / BUFFER_SIZE) + 1) * BUFFER_SIZE;
+        out_buffer.reserve(capacity);
+        compressor.decompress(in_buffer, out_buffer, len)?;
+        in_buffer.advance(len);
+
+        debug!(
+            "Decompressed {} bytes into {} bytes using {:?}",
+            len,
+            out_buffer.len(),
+            compressor.name()
+        );
+        Ok(())
+    }
+}
+
+impl Default for Decompression {
+    fn default() -> Self {
+        Decompression {
+            encoding: None,
+            compressor: None,
+        }
+    }
+}
diff --git a/tonic/src/codec/compression/errors.rs b/tonic/src/codec/compression/errors.rs
new file mode 100644
index 000000000..93203fdd1
--- /dev/null
+++ b/tonic/src/codec/compression/errors.rs
@@ -0,0 +1,76 @@
+#[derive(Debug)]
+pub enum DecompressionError {
+    NotFound {
+        requested: String,
+        known: Vec<String>,
+    },
+    NoCompression,
+    Failed(std::io::Error),
+}
+
+impl std::fmt::Display for DecompressionError {
+    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
+        match &self {
+            DecompressionError::NotFound { requested, known } => {
+                let known_joined = known.join(", ");
+                write!(
+                    f,
+                    "Compressor for '{}' not found. Known compressors: {}",
+                    requested, known_joined
+                )
+            }
+            DecompressionError::Failed(error) => write!(f, "Decompression failed: {}", error),
+            DecompressionError::NoCompression => {
+                write!(f, "Compressed flag set with identity or empty encoding")
+            }
+        }
+    }
+}
+
+impl std::error::Error for DecompressionError {
+    fn source(&self) -> Option<&(dyn std::error::Error + 'static)> {
+        match &self {
+            DecompressionError::NoCompression => None,
+            DecompressionError::NotFound { .. } => None,
+            DecompressionError::Failed(err) => Some(err),
+        }
+    }
+}
+
+impl From<std::io::Error> for DecompressionError {
+    fn from(error: std::io::Error) -> Self {
+        DecompressionError::Failed(error)
+    }
+}
+
+#[derive(Debug)]
+pub(crate) enum CompressionError {
+    NoCompression,
+    Failed(std::io::Error),
+}
+
+impl std::fmt::Display for CompressionError {
+    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
+        match &self {
+            CompressionError::Failed(error) => write!(f, "Compression failed: {}", error),
+            CompressionError::NoCompression => {
+                write!(f, "Compression attempted without being configured")
+            }
+        }
+    }
+}
+
+impl std::error::Error for CompressionError {
+    fn source(&self) -> Option<&(dyn std::error::Error + 'static)> {
+        match &self {
+            CompressionError::NoCompression => None,
+            CompressionError::Failed(err) => Some(err),
+        }
+    }
+}
+
+impl From<std::io::Error> for CompressionError {
+    fn from(error: std::io::Error) -> Self {
+        CompressionError::Failed(error)
+    }
+}
diff --git a/tonic/src/codec/compression/gzip.rs b/tonic/src/codec/compression/gzip.rs
new file mode 100644
index 000000000..fa671c1de
--- /dev/null
+++ b/tonic/src/codec/compression/gzip.rs
@@ -0,0 +1,57 @@
+use std::io;
+
+use super::{bufwriter, Compressor};
+use bytes::BytesMut;
+use flate2::read::{GzDecoder, GzEncoder};
+
+/// Compress using GZIP
+#[derive(Debug)]
+pub(crate) struct GZipCompressor {
+    compression_level: flate2::Compression,
+}
+
+impl GZipCompressor {
+    fn new(compression_level: flate2::Compression) -> GZipCompressor {
+        GZipCompressor { compression_level }
+    }
+}
+
+impl Default for GZipCompressor {
+    fn default() -> Self {
+        Self::new(flate2::Compression::new(6))
+    }
+}
+
+impl Compressor for GZipCompressor {
+    fn name(&self) -> &'static str {
+        "gzip"
+    }
+
+    fn decompress(
+        &self,
+        in_buffer: &BytesMut,
+        out_buffer: &mut BytesMut,
+        len: usize,
+    ) -> io::Result<()> {
+        let mut gzip_decoder = GzDecoder::new(&in_buffer[0..len]);
+        let mut out_writer = bufwriter::new(out_buffer);
+
+        std::io::copy(&mut gzip_decoder, &mut out_writer)?;
+
+        Ok(())
+    }
+
+    fn compress(
+        &self,
+        in_buffer: &BytesMut,
+        out_buffer: &mut BytesMut,
+        len: usize,
+    ) -> io::Result<()> {
+        let mut gzip_encoder = GzEncoder::new(&in_buffer[0..len], self.compression_level);
+        let mut out_writer = bufwriter::new(out_buffer);
+
+        std::io::copy(&mut gzip_encoder, &mut out_writer)?;
+
+        Ok(())
+    }
+}
diff --git a/tonic/src/codec/compression/mod.rs b/tonic/src/codec/compression/mod.rs
new file mode 100644
index 000000000..dc41f7402
--- /dev/null
+++ b/tonic/src/codec/compression/mod.rs
@@ -0,0 +1,19 @@
+mod bufwriter;
+mod compression;
+mod compressors;
+mod decompression;
+mod errors;
+
+pub(crate) const ENCODING_HEADER: &str = "grpc-encoding";
+pub(crate) const ACCEPT_ENCODING_HEADER: &str = "grpc-accept-encoding";
+
+#[cfg(feature = "gzip")]
+mod gzip;
+
+pub(crate) use self::compressors::Compressor;
+
+#[doc(hidden)]
+pub use self::decompression::Decompression;
+pub(crate) use self::errors::DecompressionError;
+
+pub(crate) use self::compression::Compression;
diff --git a/tonic/src/codec/decode.rs b/tonic/src/codec/decode.rs
index 157a27093..80d4eb645 100644
--- a/tonic/src/codec/decode.rs
+++ b/tonic/src/codec/decode.rs
@@ -1,4 +1,4 @@
-use super::{DecodeBuf, Decoder};
+use super::{DecodeBuf, Decoder, Decompression};
 use crate::{body::BoxBody, metadata::MetadataMap, Code, Status};
 use bytes::{Buf, BufMut, BytesMut};
 use futures_core::Stream;
@@ -24,6 +24,8 @@ pub struct Streaming<T> {
     state: State,
     direction: Direction,
     buf: BytesMut,
+    decompression: Decompression,
+    decompress_buf: BytesMut,
     trailers: Option<MetadataMap>,
 }
 
@@ -43,13 +45,23 @@ enum Direction {
 }
 
 impl<T> Streaming<T> {
-    pub(crate) fn new_response<B, D>(decoder: D, body: B, status_code: StatusCode) -> Self
+    pub(crate) fn new_response<B, D>(
+        decoder: D,
+        body: B,
+        status_code: StatusCode,
+        decompression: Decompression,
+    ) -> Self
     where
         B: Body + Send + Sync + 'static,
        B::Error: Into<crate::Error>,
         D: Decoder<Item = T, Error = Status> + Send + Sync + 'static,
     {
-        Self::new(decoder, body, Direction::Response(status_code))
+        Self::new(
+            decoder,
+            body,
+            Direction::Response(status_code),
+            decompression,
+        )
     }
 
     pub(crate) fn new_empty<B, D>(decoder: D, body: B) -> Self
@@ -58,20 +70,25 @@ impl<T> Streaming<T> {
         B::Error: Into<crate::Error>,
         D: Decoder<Item = T, Error = Status> + Send + Sync + 'static,
     {
-        Self::new(decoder, body, Direction::EmptyResponse)
+        Self::new(
+            decoder,
+            body,
+            Direction::EmptyResponse,
+            Decompression::default(),
+        )
     }
 
     #[doc(hidden)]
-    pub fn new_request<B, D>(decoder: D, body: B) -> Self
+    pub fn new_request<B, D>(decoder: D, body: B, decompression: Decompression) -> Self
     where
         B: Body + Send + Sync + 'static,
         B::Error: Into<crate::Error>,
         D: Decoder<Item = T, Error = Status> + Send + Sync + 'static,
     {
-        Self::new(decoder, body, Direction::Request)
+        Self::new(decoder, body, Direction::Request, decompression)
     }
 
-    fn new<B, D>(decoder: D, body: B, direction: Direction) -> Self
+    fn new<B, D>(decoder: D, body: B, direction: Direction, decompression: Decompression) -> Self
     where
         B: Body + Send + Sync + 'static,
         B::Error: Into<crate::Error>,
@@ -83,6 +100,8 @@
             state: State::ReadHeader,
             direction,
             buf: BytesMut::with_capacity(BUFFER_SIZE),
+            decompress_buf: BytesMut::new(),
+            decompression,
             trailers: None,
         }
     }
@@ -159,13 +178,7 @@
 
         let is_compressed = match self.buf.get_u8() {
             0 => false,
-            1 => {
-                trace!("message compressed, compression not supported yet");
-                return Err(Status::new(
-                    Code::Unimplemented,
-                    "Message compressed, compression not supported yet.".to_string(),
-                ));
-            }
+            1 => true,
             f => {
                 trace!("unexpected compression flag");
                 let message = if let Direction::Response(status) = self.direction {
@@ -187,17 +200,40 @@
             }
         }
 
-        if let State::ReadBody { len, .. } = &self.state {
+        if let State::ReadBody { len, compression } = &self.state {
             // if we haven't read enough of the message then return and keep
             // reading
             if self.buf.remaining() < *len || self.buf.len() < *len {
                 return Ok(None);
             }
 
-            return match self
-                .decoder
-                .decode(&mut DecodeBuf::new(&mut self.buf, *len))
-            {
+            let decode_result = if *compression {
+                if let Err(err) =
+                    self.decompression
+                        .decompress(&mut self.buf, &mut self.decompress_buf, *len)
+                {
+                    trace!(error = ?err, "Error decompressing: {}", err);
+                    let message = if let Direction::Response(status) = self.direction {
+                        format!(
+                            "Error decompressing: {}, while receiving response with status: {}",
+                            err, status
+                        )
+                    } else {
+                        format!("Error decompressing: {}, while sending request", err)
+                    };
+                    return Err(Status::new(Code::Internal, message));
+                }
+                let uncompressed_len = self.decompress_buf.len();
+                self.decoder.decode(&mut DecodeBuf::new(
+                    &mut self.decompress_buf,
+                    uncompressed_len,
+                ))
+            } else {
+                self.decoder
+                    .decode(&mut DecodeBuf::new(&mut self.buf, *len))
+            };
+
+            return match decode_result {
                 Ok(Some(msg)) => {
                     self.state = State::ReadHeader;
                     Ok(Some(msg))
diff --git a/tonic/src/codec/encode.rs b/tonic/src/codec/encode.rs
index 58f4f99da..b698adcb4 100644
--- a/tonic/src/codec/encode.rs
+++ b/tonic/src/codec/encode.rs
@@ -1,4 +1,4 @@
-use super::{EncodeBuf, Encoder};
+use super::{compression::Compression, EncodeBuf, Encoder};
 use crate::{Code, Status};
 use bytes::{BufMut, Bytes, BytesMut};
 use futures_core::{Stream, TryStream};
@@ -16,36 +16,46 @@
 pub(crate) fn encode_server<T, U>(
     encoder: T,
     source: U,
+    compression: Compression,
 ) -> EncodeBody<impl Stream<Item = Result<Bytes, Status>>>
 where
     T: Encoder<Error = Status> + Send + Sync + 'static,
     T::Item: Send + Sync,
     U: Stream<Item = Result<T::Item, Status>> + Send + Sync + 'static,
 {
-    let stream = encode(encoder, source).into_stream();
+    let stream = encode(encoder, source, compression).into_stream();
     EncodeBody::new_server(stream)
 }
 
 pub(crate) fn encode_client<T, U>(
     encoder: T,
     source: U,
+    compression: Compression,
 ) -> EncodeBody<impl Stream<Item = Result<Bytes, Status>>>
 where
     T: Encoder<Error = Status> + Send + Sync + 'static,
     T::Item: Send + Sync,
     U: Stream<Item = T::Item> + Send + Sync + 'static,
 {
-    let stream = encode(encoder, source.map(Ok)).into_stream();
+    let stream = encode(encoder, source.map(Ok), compression).into_stream();
     EncodeBody::new_client(stream)
 }
 
-fn encode<T, U>(mut encoder: T, source: U) -> impl TryStream<Ok = Bytes, Error = Status>
+fn encode<T, U>(
+    mut encoder: T,
+    source: U,
+    compression: Compression,
+) -> impl TryStream<Ok = Bytes, Error = Status>
 where
     T: Encoder<Error = Status>,
     U: Stream<Item = Result<T::Item, Status>>,
 {
+    let compression_enabled = compression.is_enabled();
+    let compressed_u8 = if compression_enabled { 1 } else { 0 };
+
     async_stream::stream! {
         let mut buf = BytesMut::with_capacity(BUFFER_SIZE);
+        let mut compress_buf = if compression_enabled { BytesMut::with_capacity(BUFFER_SIZE) } else { BytesMut::new() };
         futures_util::pin_mut!(source);
 
         loop {
@@ -55,14 +65,21 @@
                     unsafe {
                         buf.advance_mut(5);
                     }
-                    encoder.encode(item, &mut EncodeBuf::new(&mut buf)).map_err(drop).unwrap();
+                    if compression_enabled {
+                        compress_buf.clear();
+                        encoder.encode(item, &mut EncodeBuf::new(&mut compress_buf)).map_err(drop).unwrap();
+                        let uncompressed_len = compress_buf.len();
+                        compression.compress(&mut compress_buf, &mut buf, uncompressed_len).map_err(drop).unwrap();
+                    } else {
+                        encoder.encode(item, &mut EncodeBuf::new(&mut buf)).map_err(drop).unwrap();
+                    }
 
                     // now that we know length, we can write the header
                     let len = buf.len() - 5;
                     assert!(len <= std::u32::MAX as usize);
 
                     {
                         let mut buf = &mut buf[..5];
-                        buf.put_u8(0); // byte must be 0, reserve doesn't auto-zero
+                        buf.put_u8(compressed_u8);
                         buf.put_u32(len as u32);
                     }
diff --git a/tonic/src/codec/mod.rs b/tonic/src/codec/mod.rs
index e100556c3..3daa59c3a 100644
--- a/tonic/src/codec/mod.rs
+++ b/tonic/src/codec/mod.rs
@@ -4,6 +4,7 @@
 //! and a protobuf codec based on prost.
 
 mod buffer;
+mod compression;
 mod decode;
 mod encode;
 #[cfg(feature = "prost")]
@@ -11,6 +12,9 @@ mod prost;
 
 use std::io;
 
+pub(crate) use self::compression::Compression;
+#[doc(hidden)]
+pub use self::compression::Decompression;
 pub use self::decode::Streaming;
 pub(crate) use self::encode::{encode_client, encode_server};
 #[cfg(feature = "prost")]
diff --git a/tonic/src/codec/prost.rs b/tonic/src/codec/prost.rs
index 2135b8211..0c7aad835 100644
--- a/tonic/src/codec/prost.rs
+++ b/tonic/src/codec/prost.rs
@@ -121,7 +121,7 @@ mod tests {
         let messages = std::iter::repeat(Ok::<_, Status>(msg)).take(10000);
         let source = futures_util::stream::iter(messages);
 
-        let body = encode_server(encoder, source);
+        let body = encode_server(encoder, source, Compression::disabled());
 
         futures_util::pin_mut!(body);
diff --git a/tonic/src/server/grpc.rs b/tonic/src/server/grpc.rs
index 7852d7fe7..df2294ed6 100644
--- a/tonic/src/server/grpc.rs
+++ b/tonic/src/server/grpc.rs
@@ -1,6 +1,6 @@
 use crate::{
     body::BoxBody,
-    codec::{encode_server, Codec, Streaming},
+    codec::{encode_server, Codec, Compression, Decompression, Streaming},
     interceptor::Interceptor,
     server::{ClientStreamingService, ServerStreamingService, StreamingService, UnaryService},
     Code, Request, Status,
@@ -71,20 +71,22 @@ where
             Ok(r) => r,
             Err(status) => {
                 return self
-                    .map_response::<stream::Once<future::Ready<Result<T::Encode, Status>>>>(Err(
-                        status,
-                    ));
+                    .map_response::<stream::Once<future::Ready<Result<T::Encode, Status>>>>(
+                        Err(status),
+                        Compression::disabled(),
+                    );
             }
         };
 
         let request = t!(self.intercept_request(request));
+        let compression = Compression::response_from_metadata(request.metadata());
 
         let response = service
             .call(request)
             .await
             .map(|r| r.map(|m| stream::once(future::ok(m))));
-        self.map_response(response)
+        self.map_response(response, compression)
     }
 
     /// Handle a server side streaming request.
@@ -102,15 +104,16 @@ where
         let request = match self.map_request_unary(req).await {
             Ok(r) => r,
             Err(status) => {
-                return self.map_response::<S::ResponseStream>(Err(status));
+                return self
+                    .map_response::<S::ResponseStream>(Err(status), Compression::disabled());
             }
         };
 
         let request = t!(self.intercept_request(request));
-
+        let compression = Compression::response_from_metadata(request.metadata());
         let response = service.call(request).await;
 
-        self.map_response(response)
+        self.map_response(response, compression)
     }
 
     /// Handle a client side streaming gRPC request.
@@ -126,11 +129,12 @@
     {
         let request = self.map_request_streaming(req);
         let request = t!(self.intercept_request(request));
+        let compression = Compression::response_from_metadata(request.metadata());
         let response = service
             .call(request)
             .await
             .map(|r| r.map(|m| stream::once(future::ok(m))));
-        self.map_response(response)
+        self.map_response(response, compression)
     }
 
     /// Handle a bi-directional streaming gRPC request.
@@ -147,8 +151,9 @@
     {
         let request = self.map_request_streaming(req);
         let request = t!(self.intercept_request(request));
+        let compression = Compression::response_from_metadata(request.metadata());
         let response = service.call(request).await;
-        self.map_response(response)
+        self.map_response(response, compression)
     }
 
     async fn map_request_unary<B>(
@@ -160,7 +165,8 @@
         B: Body + Send + Sync + 'static,
         B::Error: Into<crate::Error> + Send,
     {
         let (parts, body) = request.into_parts();
-        let stream = Streaming::new_request(self.codec.decoder(), body);
+        let decompression = Decompression::from_headers(&parts.headers);
+        let stream = Streaming::new_request(self.codec.decoder(), body, decompression);
 
         futures_util::pin_mut!(stream);
@@ -186,12 +192,16 @@
         B: Body + Send + Sync + 'static,
         B::Error: Into<crate::Error> + Send,
     {
-        Request::from_http(request.map(|body| Streaming::new_request(self.codec.decoder(), body)))
+        let decompression = Decompression::from_headers(request.headers());
+        Request::from_http(
+            request.map(|body| Streaming::new_request(self.codec.decoder(), body, decompression)),
+        )
     }
 
     fn map_response<B>(
         &mut self,
         response: Result<crate::Response<B>, Status>,
+        compression: Compression,
     ) -> http::Response<BoxBody>
     where
         B: TryStream<Ok = T::Encode, Error = Status> + Send + Sync + 'static,
@@ -205,8 +215,9 @@
             http::header::CONTENT_TYPE,
             http::header::HeaderValue::from_static("application/grpc"),
         );
+        compression.set_headers(&mut parts.headers, false);
 
-        let body = encode_server(self.codec.encoder(), body.into_stream());
+        let body = encode_server(self.codec.encoder(), body.into_stream(), compression);
 
         http::Response::from_parts(parts, BoxBody::new(body))
     }
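
Illustrative sketch (not part of the patch above): a round-trip through the new gzip compressor, assuming the sketch lives inside tonic's `codec::compression` module so the `pub(crate)` items (`GZipCompressor` and the `Compressor` trait) are in scope, and that the `gzip` feature is enabled. The payload and buffer sizes are arbitrary.

    use bytes::{BufMut, BytesMut};

    // Bring the trait into scope so `compress`/`decompress` can be called on GZipCompressor.
    use super::{gzip::GZipCompressor, Compressor};

    fn gzip_roundtrip() -> std::io::Result<()> {
        let compressor = GZipCompressor::default();

        // Highly repetitive payload, so the compressed form is visibly smaller.
        let mut plain = BytesMut::new();
        plain.put(&[b'a'; 4096][..]);

        // Compress all `plain.len()` bytes into `compressed`.
        let mut compressed = BytesMut::new();
        compressor.compress(&plain, &mut compressed, plain.len())?;
        assert!(compressed.len() < plain.len());

        // Decompress and check that the original bytes come back.
        let mut restored = BytesMut::new();
        compressor.decompress(&compressed, &mut restored, compressed.len())?;
        assert_eq!(&restored[..], &plain[..]);

        Ok(())
    }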