Skip to content

Commit

Permalink
feat(codec): Introduce Decoder/Encoder traits (#208)
Browse files Browse the repository at this point in the history
BREAKING CHANGE: Add new `Decoder/Encoder` traits and use `EncodeBuf/DecodeBuf` over `BytesMut` directly.
  • Loading branch information
alce authored and LucioFranco committed Jan 13, 2020
1 parent a41f55a commit 0fa2bf1
Show file tree
Hide file tree
Showing 7 changed files with 195 additions and 39 deletions.
23 changes: 10 additions & 13 deletions tonic/benches/decode.rs
Original file line number Diff line number Diff line change
@@ -1,16 +1,12 @@
extern crate bencher;

use std::fmt::{Error, Formatter};
use bencher::{benchmark_group, benchmark_main, Bencher};
use bytes::{Buf, BufMut, Bytes, BytesMut};
use http_body::Body;
use std::{
fmt::{Error, Formatter},
pin::Pin,
task::{Context, Poll},
};

use bencher::{benchmark_group, benchmark_main, Bencher};
use bytes::{Buf, BufMut, Bytes, BytesMut};
use http_body::Body;
use tokio_util::codec::Decoder;
use tonic::{Status, Streaming};
use tonic::{codec::DecodeBuf, codec::Decoder, Status, Streaming};

macro_rules! bench {
($name:ident, $message_size:expr, $chunk_size:expr, $message_count:expr) => {
Expand Down Expand Up @@ -102,12 +98,13 @@ impl MockDecoder {
}

impl Decoder for MockDecoder {
type Item = Bytes;
type Item = Vec<u8>;
type Error = Status;

fn decode(&mut self, buf: &mut BytesMut) -> Result<Option<Self::Item>, Self::Error> {
let item = buf.split_to(self.message_size).freeze();
Ok(Some(item))
fn decode(&mut self, buf: &mut DecodeBuf<'_>) -> Result<Option<Self::Item>, Self::Error> {
let out = Vec::from(buf.bytes());
buf.advance(self.message_size);
Ok(Some(out))
}
}

Expand Down
121 changes: 121 additions & 0 deletions tonic/src/codec/buffer.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,121 @@
use bytes::{Buf, BufMut, BytesMut};
use std::mem::MaybeUninit;

/// A specialized buffer to decode gRPC messages from.
#[derive(Debug)]
pub struct DecodeBuf<'a> {
buf: &'a mut BytesMut,
len: usize,
}

/// A specialized buffer to encode gRPC messages into.
#[derive(Debug)]
pub struct EncodeBuf<'a> {
buf: &'a mut BytesMut,
}

impl<'a> DecodeBuf<'a> {
pub(crate) fn new(buf: &'a mut BytesMut, len: usize) -> Self {
DecodeBuf { buf, len }
}
}

impl Buf for DecodeBuf<'_> {
#[inline]
fn remaining(&self) -> usize {
self.len
}

#[inline]
fn bytes(&self) -> &[u8] {
let ret = self.buf.bytes();

if ret.len() > self.len {
&ret[..self.len]
} else {
ret
}
}

#[inline]
fn advance(&mut self, cnt: usize) {
assert!(cnt <= self.len);
self.buf.advance(cnt);
self.len -= cnt;
}
}

impl<'a> EncodeBuf<'a> {
pub(crate) fn new(buf: &'a mut BytesMut) -> Self {
EncodeBuf { buf }
}
}

impl EncodeBuf<'_> {
/// Reserves capacity for at least `additional` more bytes to be inserted
/// into the buffer.
///
/// More than `additional` bytes may be reserved in order to avoid frequent
/// reallocations. A call to `reserve` may result in an allocation.
#[inline]
pub fn reserve(&mut self, additional: usize) {
self.buf.reserve(additional);
}
}

impl BufMut for EncodeBuf<'_> {
#[inline]
fn remaining_mut(&self) -> usize {
self.buf.remaining_mut()
}

#[inline]
unsafe fn advance_mut(&mut self, cnt: usize) {
self.buf.advance_mut(cnt)
}

#[inline]
fn bytes_mut(&mut self) -> &mut [MaybeUninit<u8>] {
self.buf.bytes_mut()
}
}

#[cfg(test)]
mod tests {
use super::*;

#[test]
fn decode_buf() {
let mut payload = BytesMut::with_capacity(100);
payload.put(&vec![0u8; 50][..]);
let mut buf = DecodeBuf::new(&mut payload, 20);

assert_eq!(buf.len, 20);
assert_eq!(buf.remaining(), 20);
assert_eq!(buf.bytes().len(), 20);

buf.advance(10);
assert_eq!(buf.remaining(), 10);

let mut out = [0; 5];
buf.copy_to_slice(&mut out);
assert_eq!(buf.remaining(), 5);
assert_eq!(buf.bytes().len(), 5);

assert_eq!(buf.to_bytes().len(), 5);
assert!(!buf.has_remaining());
}

#[test]
fn encode_buf() {
let mut bytes = BytesMut::with_capacity(100);
let mut buf = EncodeBuf::new(&mut bytes);

let initial = buf.remaining_mut();
unsafe { buf.advance_mut(20) };
assert_eq!(buf.remaining_mut(), initial - 20);

buf.put_u8(b'a');
assert_eq!(buf.remaining_mut(), initial - 20 - 1);
}
}
24 changes: 13 additions & 11 deletions tonic/src/codec/decode.rs
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
use super::Decoder;
use super::{DecodeBuf, Decoder};
use crate::{body::BoxBody, metadata::MetadataMap, Code, Status};
use bytes::{Buf, BufMut, BytesMut};
use futures_core::Stream;
Expand Down Expand Up @@ -91,10 +91,11 @@ impl<T> Streaming<T> {
impl<T> Streaming<T> {
/// Fetch the next message from this stream.
/// ```rust
/// # use tonic::{Streaming, Status};
/// # use tonic::{Streaming, Status, codec::Decoder};
/// # use std::fmt::Debug;
/// # async fn next_message_ex<T>(mut request: Streaming<T>) -> Result<(), Status>
/// # where T: Debug
/// # async fn next_message_ex<T, D>(mut request: Streaming<T>) -> Result<(), Status>
/// # where T: Debug,
/// # D: Decoder<Item = T, Error = Status> + Send + Sync + 'static,
/// # {
/// if let Some(next_message) = request.message().await? {
/// println!("{:?}", next_message);
Expand Down Expand Up @@ -188,16 +189,17 @@ impl<T> Streaming<T> {
return Ok(None);
}

match self.decoder.decode(&mut self.buf) {
return match self
.decoder
.decode(&mut DecodeBuf::new(&mut self.buf, *len))
{
Ok(Some(msg)) => {
self.state = State::ReadHeader;
return Ok(Some(msg));
}
Ok(None) => return Ok(None),
Err(e) => {
return Err(e);
Ok(Some(msg))
}
}
Ok(None) => Ok(None),
Err(e) => Err(e),
};
}

Ok(None)
Expand Down
10 changes: 6 additions & 4 deletions tonic/src/codec/encode.rs
Original file line number Diff line number Diff line change
@@ -1,13 +1,15 @@
use super::{EncodeBuf, Encoder};
use crate::{Code, Status};
use bytes::{BufMut, Bytes, BytesMut};
use futures_core::{Stream, TryStream};
use futures_util::{ready, StreamExt, TryStreamExt};
use http::HeaderMap;
use http_body::Body;
use pin_project::pin_project;
use std::pin::Pin;
use std::task::{Context, Poll};
use tokio_util::codec::Encoder;
use std::{
pin::Pin,
task::{Context, Poll},
};

const BUFFER_SIZE: usize = 8 * 1024;

Expand Down Expand Up @@ -53,7 +55,7 @@ where
unsafe {
buf.advance_mut(5);
}
encoder.encode(item, &mut buf).map_err(drop).unwrap();
encoder.encode(item, &mut EncodeBuf::new(&mut buf)).map_err(drop).unwrap();

// now that we know length, we can write the header
let len = buf.len() - 5;
Expand Down
40 changes: 36 additions & 4 deletions tonic/src/codec/mod.rs
Original file line number Diff line number Diff line change
@@ -1,8 +1,9 @@
//! Generic encoding and decoding.
//!
//! This module contains the generic `Codec` trait and a protobuf codec
//! based on prost.
//! This module contains the generic `Codec`, `Encoder` and `Decoder` traits
//! and a protobuf codec based on prost.
mod buffer;
mod decode;
mod encode;
#[cfg(feature = "prost")]
Expand All @@ -11,14 +12,15 @@ mod prost;
#[cfg(test)]
mod tests;

use std::io;

pub use self::decode::Streaming;
pub(crate) use self::encode::{encode_client, encode_server};
#[cfg(feature = "prost")]
#[cfg_attr(docsrs, doc(cfg(feature = "prost")))]
pub use self::prost::ProstCodec;
pub use tokio_util::codec::{Decoder, Encoder};

use crate::Status;
pub use buffer::{DecodeBuf, EncodeBuf};

/// Trait that knows how to encode and decode gRPC messages.
pub trait Codec: Default {
Expand All @@ -37,3 +39,33 @@ pub trait Codec: Default {
/// Fetch the decoder.
fn decoder(&mut self) -> Self::Decoder;
}

/// Encodes gRPC message types
pub trait Encoder {
/// The type that is encoded.
type Item;

/// The type of encoding errors.
///
/// The type of unrecoverable frame encoding errors.
type Error: From<io::Error>;

/// Encodes a message into the provided buffer.
fn encode(&mut self, item: Self::Item, dst: &mut EncodeBuf<'_>) -> Result<(), Self::Error>;
}

/// Decodes gRPC message types
pub trait Decoder {
/// The type that is decoded.
type Item;

/// The type of unrecoverable frame decoding errors.
type Error: From<io::Error>;

/// Decode a message from the buffer.
///
/// The buffer will contain exactly the bytes of a full message. There
/// is no need to get the length from the bytes, gRPC framing is handled
/// for you.
fn decode(&mut self, src: &mut DecodeBuf<'_>) -> Result<Option<Self::Item>, Self::Error>;
}
8 changes: 4 additions & 4 deletions tonic/src/codec/prost.rs
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
use super::{Codec, Decoder, Encoder};
use super::{Codec, DecodeBuf, Decoder, Encoder};
use crate::codec::EncodeBuf;
use crate::{Code, Status};
use bytes::BytesMut;
use prost::Message;
use std::marker::PhantomData;

Expand Down Expand Up @@ -44,7 +44,7 @@ impl<T: Message> Encoder for ProstEncoder<T> {
type Item = T;
type Error = Status;

fn encode(&mut self, item: Self::Item, buf: &mut BytesMut) -> Result<(), Self::Error> {
fn encode(&mut self, item: Self::Item, buf: &mut EncodeBuf<'_>) -> Result<(), Self::Error> {
item.encode(buf)
.expect("Message only errors if not enough space");

Expand All @@ -60,7 +60,7 @@ impl<U: Message + Default> Decoder for ProstDecoder<U> {
type Item = U;
type Error = Status;

fn decode(&mut self, buf: &mut BytesMut) -> Result<Option<Self::Item>, Self::Error> {
fn decode(&mut self, buf: &mut DecodeBuf<'_>) -> Result<Option<Self::Item>, Self::Error> {
let item = Message::decode(buf)
.map(Option::Some)
.map_err(from_decode_error)?;
Expand Down
8 changes: 5 additions & 3 deletions tonic/src/codec/tests.rs
Original file line number Diff line number Diff line change
@@ -1,4 +1,6 @@
use super::{encode_server, Decoder, Encoder, Streaming};
use crate::codec::buffer::DecodeBuf;
use crate::codec::EncodeBuf;
use crate::Status;
use bytes::{Buf, BufMut, BytesMut};
use http_body::Body;
Expand Down Expand Up @@ -56,7 +58,7 @@ impl Encoder for MockEncoder {
type Item = Vec<u8>;
type Error = Status;

fn encode(&mut self, item: Self::Item, buf: &mut BytesMut) -> Result<(), Self::Error> {
fn encode(&mut self, item: Self::Item, buf: &mut EncodeBuf<'_>) -> Result<(), Self::Error> {
buf.put(&item[..]);
Ok(())
}
Expand All @@ -69,8 +71,8 @@ impl Decoder for MockDecoder {
type Item = Vec<u8>;
type Error = Status;

fn decode(&mut self, buf: &mut BytesMut) -> Result<Option<Self::Item>, Self::Error> {
let out = Vec::from(&buf[..LEN]);
fn decode(&mut self, buf: &mut DecodeBuf<'_>) -> Result<Option<Self::Item>, Self::Error> {
let out = Vec::from(buf.bytes());
buf.advance(LEN);
Ok(Some(out))
}
Expand Down

0 comments on commit 0fa2bf1

Please sign in to comment.