diff --git a/pilota-build/src/codegen/thrift/ty.rs b/pilota-build/src/codegen/thrift/ty.rs index 5fb9ba36..672f1bdc 100644 --- a/pilota-build/src/codegen/thrift/ty.rs +++ b/pilota-build/src/codegen/thrift/ty.rs @@ -349,12 +349,13 @@ impl ThriftBackend { let read_el = self.codegen_decode_ty(helper, ty); format! { r#" - {{ + unsafe {{ let list_ident = {read_list_begin}; let mut val = Vec::with_capacity(list_ident.size); - for _ in 0..list_ident.size {{ - val.push({read_el}); + for i in 0..list_ident.size {{ + *val.get_unchecked_mut(i) = {read_el}; }}; + val.set_len(list_ident.size); {read_list_end}; val }} diff --git a/pilota-build/test_data/thrift/wrapper_arc.rs b/pilota-build/test_data/thrift/wrapper_arc.rs index aee6ec4f..c6b24a66 100644 --- a/pilota-build/test_data/thrift/wrapper_arc.rs +++ b/pilota-build/test_data/thrift/wrapper_arc.rs @@ -616,22 +616,24 @@ pub mod wrapper_arc { id = Some(protocol.read_faststr()?); } Some(2) if field_ident.field_type == ::pilota::thrift::TType::List => { - name2 = Some({ + name2 = Some(unsafe { let list_ident = protocol.read_list_begin()?; let mut val = Vec::with_capacity(list_ident.size); - for _ in 0..list_ident.size { - val.push({ + for i in 0..list_ident.size { + *val.get_unchecked_mut(i) = unsafe { let list_ident = protocol.read_list_begin()?; let mut val = Vec::with_capacity(list_ident.size); - for _ in 0..list_ident.size { - val.push(::std::sync::Arc::new( + for i in 0..list_ident.size { + *val.get_unchecked_mut(i) = ::std::sync::Arc::new( ::pilota::thrift::Message::decode(protocol)?, - )); + ); } + val.set_len(list_ident.size); protocol.read_list_end()?; val - }); + }; } + val.set_len(list_ident.size); protocol.read_list_end()?; val }); @@ -642,14 +644,15 @@ pub mod wrapper_arc { let mut val = ::std::collections::HashMap::with_capacity(map_ident.size); for _ in 0..map_ident.size { - val.insert(protocol.read_i32()?, { + val.insert(protocol.read_i32()?, unsafe { let list_ident = protocol.read_list_begin()?; let mut val = Vec::with_capacity(list_ident.size); - for _ in 0..list_ident.size { - val.push(::std::sync::Arc::new( + for i in 0..list_ident.size { + *val.get_unchecked_mut(i) = ::std::sync::Arc::new( ::pilota::thrift::Message::decode(protocol)?, - )); + ); } + val.set_len(list_ident.size); protocol.read_list_end()?; val }); @@ -736,25 +739,27 @@ pub mod wrapper_arc { id = Some(protocol.read_faststr().await?); } Some(2) if field_ident.field_type == ::pilota::thrift::TType::List => { - name2 = Some({ + name2 = Some(unsafe { let list_ident = protocol.read_list_begin().await?; let mut val = Vec::with_capacity(list_ident.size); - for _ in 0..list_ident.size { - val.push({ + for i in 0..list_ident.size { + *val.get_unchecked_mut(i) = unsafe { let list_ident = protocol.read_list_begin().await?; let mut val = Vec::with_capacity(list_ident.size); - for _ in 0..list_ident.size { - val.push(::std::sync::Arc::new( + for i in 0..list_ident.size { + *val.get_unchecked_mut(i) = ::std::sync::Arc::new( ::pilota::thrift::Message::decode_async( protocol, ) .await?, - )); + ); } + val.set_len(list_ident.size); protocol.read_list_end().await?; val - }); + }; } + val.set_len(list_ident.size); protocol.read_list_end().await?; val }); @@ -765,17 +770,18 @@ pub mod wrapper_arc { let mut val = ::std::collections::HashMap::with_capacity(map_ident.size); for _ in 0..map_ident.size { - val.insert(protocol.read_i32().await?, { + val.insert(protocol.read_i32().await?, unsafe { let list_ident = protocol.read_list_begin().await?; let mut val = Vec::with_capacity(list_ident.size); - for _ in 0..list_ident.size { - val.push(::std::sync::Arc::new( + for i in 0..list_ident.size { + *val.get_unchecked_mut(i) = ::std::sync::Arc::new( ::pilota::thrift::Message::decode_async( protocol, ) .await?, - )); + ); } + val.set_len(list_ident.size); protocol.read_list_end().await?; val }); diff --git a/pilota/Cargo.toml b/pilota/Cargo.toml index f31d51b0..5b4dce28 100644 --- a/pilota/Cargo.toml +++ b/pilota/Cargo.toml @@ -44,3 +44,7 @@ rand = "0.8" [[bench]] name = "faststr" harness = false + +[[bench]] +name = "thrift_binary" +harness = false diff --git a/pilota/benches/thrift_binary.rs b/pilota/benches/thrift_binary.rs new file mode 100644 index 00000000..dc95a0fa --- /dev/null +++ b/pilota/benches/thrift_binary.rs @@ -0,0 +1,262 @@ +#![allow(clippy::redundant_clone)] + +use bytes::BytesMut; +use criterion::{black_box, criterion_group, criterion_main}; +use pilota::thrift::{TInputProtocol, TOutputProtocol}; +use rand::{self, Rng}; + +fn binary_bench(c: &mut criterion::Criterion) { + let size = std::env::var("SIZE") + .unwrap_or("10000".to_string()) + .parse() + .unwrap(); + let mut group = c.benchmark_group("Bench Thrift Binary"); + let mut v: Vec = Vec::with_capacity(size); + for _ in 0..size { + v.push(rand::thread_rng().gen()); + } + let mut buf = BytesMut::new(); + + let mut p = pilota::thrift::binary::TBinaryProtocol::new(&mut buf, true); + for i in &v { + p.write_i64(*i).unwrap(); + } + drop(p); + assert_eq!(buf.len(), 8 * size); + + let mut buf_le = BytesMut::new(); + + let mut p = pilota::thrift::binary_le::TBinaryProtocol::new(&mut buf_le, true); + for i in &v { + p.write_i64(*i).unwrap(); + } + drop(p); + assert_eq!(buf_le.len(), 8 * size); + + let b = buf_le.clone(); + let mut v2: Vec = Vec::with_capacity(size); + let src = b.as_ptr(); + let dst = v2.as_mut_ptr(); + unsafe { + std::ptr::copy_nonoverlapping(src, dst as *mut u8, size * 8); + v2.set_len(size); + } + assert_eq!(v, v2); + + group.bench_function("big endian decode vec i64", |b| { + b.iter(|| { + black_box({ + let b = buf.clone(); + black_box(read_be(b, size)); + }); + }) + }); + + group.bench_function("big endian decode vec i64 unsafe", |b| { + b.iter(|| { + black_box({ + let b = buf.clone(); + black_box(read_be_unsafe(b, size)); + }); + }) + }); + + group.bench_function("big endian decode vec i64 unsafe vec", |b| { + b.iter(|| { + black_box({ + let b = buf.clone(); + black_box(read_be_unsafe_vec(b, size)); + }); + }) + }); + + group.bench_function("big endian decode vec i64 unsafe optimized", |b| { + b.iter(|| { + black_box({ + let b = buf.clone(); + black_box(read_be_unsafe_optimized(b, size)); + }); + }) + }); + + group.bench_function("big endian encode vec i64", |b| { + b.iter(|| { + black_box({ + let mut b = BytesMut::with_capacity(8 * size); + black_box(write_be(&mut b, &v, size)); + }); + }) + }); + + group.bench_function("big endian encode vec i64 unsafe", |b| { + b.iter(|| { + black_box({ + let mut b = BytesMut::with_capacity(8 * size); + black_box(write_be_unsafe(&mut b, &v, size)); + }); + }) + }); + + group.bench_function("little endian decode vec i64", |b| { + b.iter(|| { + black_box({ + let b = buf_le.clone(); + black_box(read_le(b, size)); + }); + }) + }); + group.bench_function("little endian decode vec i64 unsafe optimized", |b| { + b.iter(|| { + black_box({ + let b = buf_le.clone(); + black_box(read_le_unsafe_optimized(b, size)); + }); + }) + }); + group.bench_function("little endian decode vec i64 optimized", |b| { + b.iter(|| { + black_box({ + let b = buf_le.clone(); + black_box(read_le_optimized(b, size)); + }); + }) + }); + + group.bench_function("alloc vec", |b| { + b.iter(|| { + let mut b = buf_le.clone(); + let _p = pilota::thrift::binary_le::TBinaryProtocol::new(&mut b, true); + let _: Vec = black_box(Vec::with_capacity(size)); + }) + }); + + group.finish(); +} + +#[inline(never)] +fn read_be(mut b: BytesMut, size: usize) -> Vec { + let mut p = pilota::thrift::binary::TBinaryProtocol::new(&mut b, true); + let mut v = Vec::with_capacity(size); + for _ in 0..size { + v.push(p.read_i64().unwrap()); + } + v +} + +#[inline(never)] +fn read_be_unsafe(mut b: BytesMut, size: usize) -> Vec { + unsafe { + let s = std::slice::from_raw_parts_mut(b.as_mut_ptr(), b.len()); + let mut p = pilota::thrift::binary_unsafe::TBinaryProtocol::new(&mut b, s, true); + let mut v = Vec::with_capacity(size); + for _ in 0..size { + v.push(p.read_i64().unwrap()); + } + v + } +} + +#[inline(never)] +fn read_be_unsafe_vec(mut b: BytesMut, size: usize) -> Vec { + unsafe { + let s = std::slice::from_raw_parts_mut(b.as_mut_ptr(), b.len()); + let mut p = pilota::thrift::binary_unsafe::TBinaryProtocol::new(&mut b, s, true); + let mut v = Vec::with_capacity(size); + for i in 0..size { + *v.get_unchecked_mut(i) = p.read_i64().unwrap(); + } + v + } +} + +#[inline(never)] +fn read_be_unsafe_optimized(b: BytesMut, size: usize) -> Vec { + unsafe { + let buf: &[u8] = b.as_ref(); + assert!(buf.len() >= size * 8); + let mut index = 0; + + let mut v = Vec::with_capacity(size); + for i in 0..size { + *v.get_unchecked_mut(i) = i64::from_be_bytes( + buf.get_unchecked(index..index + 8) + .try_into() + .unwrap_unchecked(), + ); + index += 8; + } + v.set_len(size); + v + } +} + +#[inline(never)] +fn write_be(b: &mut BytesMut, v: &Vec, size: usize) { + let mut p = pilota::thrift::binary::TBinaryProtocol::new(b, true); + for el in v { + p.write_i64(*el).unwrap(); + } +} + +#[inline(never)] +fn write_be_unsafe(b: &mut BytesMut, v: &Vec, size: usize) { + unsafe { + let s = std::slice::from_raw_parts_mut(b.as_mut_ptr(), b.len()); + let mut p = pilota::thrift::binary_unsafe::TBinaryProtocol::new(b, s, true); + for el in v { + p.write_i64(*el).unwrap(); + } + } +} + +#[inline(never)] +fn read_le(mut b: BytesMut, size: usize) -> Vec { + let mut p = pilota::thrift::binary_le::TBinaryProtocol::new(&mut b, true); + + let mut v = Vec::with_capacity(size); + for _ in 0..size { + v.push(p.read_i64().unwrap()); + } + v +} + +// cargo asm -p pilota --bench thrift_binary --native --full-name --keep-labels +// --simplify --rust +#[inline(never)] +fn read_le_unsafe_optimized(b: BytesMut, size: usize) -> Vec { + unsafe { + let buf: &[u8] = b.as_ref(); + assert!(buf.len() >= size * 8); + let mut index = 0; + + let mut v = Vec::with_capacity(size); + for i in 0..size { + *v.get_unchecked_mut(i) = i64::from_le_bytes( + buf.get_unchecked(index..index + 8) + .try_into() + .unwrap_unchecked(), + ); + index += 8; + } + v.set_len(size); + v + } +} + +#[inline(never)] +fn read_le_optimized(mut b: BytesMut, size: usize) -> Vec { + let _p = pilota::thrift::binary_le::TBinaryProtocol::new(&mut b, true); + let mut v: Vec = Vec::with_capacity(size); + let _ = black_box({ + let src = b.as_ptr(); + let dst = v.as_mut_ptr(); + unsafe { + std::ptr::copy_nonoverlapping(src, dst as *mut u8, size * 8); + v.set_len(size); + } + }); + v +} + +criterion_group!(benches, binary_bench); +criterion_main!(benches); diff --git a/pilota/src/thrift/binary.rs b/pilota/src/thrift/binary.rs index 90668b17..7fafdf60 100644 --- a/pilota/src/thrift/binary.rs +++ b/pilota/src/thrift/binary.rs @@ -2,7 +2,6 @@ use std::{convert::TryInto, str}; use bytes::{Bytes, BytesMut}; use faststr::FastStr; -use lazy_static::__Deref; use linkedbytes::LinkedBytes; use tokio::io::{AsyncRead, AsyncReadExt}; @@ -355,16 +354,6 @@ impl TOutputProtocol for TBinaryProtocol<&mut BytesMut> { Ok(()) } - #[inline] - fn reserve(&mut self, size: usize) { - self.trans.reserve(size) - } - - #[inline] - fn buf_mut(&mut self) -> &mut BytesMut { - self.trans - } - #[inline] fn write_bytes_vec(&mut self, b: &[u8]) -> Result<(), EncodeError> { self.write_i32(b.len() as i32)?; @@ -541,17 +530,6 @@ impl TOutputProtocol for TBinaryProtocol<&mut LinkedBytes> { fn flush(&mut self) -> Result<(), EncodeError> { Ok(()) } - - #[inline] - fn reserve(&mut self, size: usize) { - self.trans.reserve(size) - } - - #[inline] - fn buf_mut(&mut self) -> &mut LinkedBytes { - self.trans - } - #[inline] fn write_bytes_vec(&mut self, b: &[u8]) -> Result<(), EncodeError> { self.write_i32(b.len() as i32)?; @@ -903,11 +881,15 @@ impl TInputProtocol for TBinaryProtocol<&mut BytesMut> { #[inline] fn read_faststr(&mut self) -> Result { let len = self.trans.read_i32()? as usize; - let bytes = self.trans.split_to(len).freeze(); if len >= ZERO_COPY_THRESHOLD { + let bytes = self.trans.split_to(len).freeze(); unsafe { return Ok(FastStr::from_bytes_unchecked(bytes)) }; } - unsafe { Ok(FastStr::new(str::from_utf8_unchecked(bytes.deref()))) } + unsafe { + Ok(FastStr::new(str::from_utf8_unchecked( + self.trans.get(..len).unwrap(), + ))) + } } #[inline] @@ -952,11 +934,6 @@ impl TInputProtocol for TBinaryProtocol<&mut BytesMut> { Ok(self.trans.read_u8()?) } - #[inline] - fn buf_mut(&mut self) -> &mut Self::Buf { - self.trans - } - #[inline] fn read_bytes_vec(&mut self) -> Result, DecodeError> { let len = self.trans.read_i32()? as usize; diff --git a/pilota/src/thrift/binary_le.rs b/pilota/src/thrift/binary_le.rs new file mode 100644 index 00000000..1e9826e8 --- /dev/null +++ b/pilota/src/thrift/binary_le.rs @@ -0,0 +1,943 @@ +use std::{convert::TryInto, str}; + +use bytes::{Bytes, BytesMut}; +use faststr::FastStr; +use linkedbytes::LinkedBytes; +use tokio::io::{AsyncRead, AsyncReadExt}; + +use super::{ + error::ProtocolErrorKind, + new_protocol_error, + rw_ext::{ReadExt, WriteExt}, + DecodeError, DecodeErrorKind, EncodeError, ProtocolError, TAsyncInputProtocol, + TFieldIdentifier, TInputProtocol, TLengthProtocol, TListIdentifier, TMapIdentifier, + TMessageIdentifier, TMessageType, TOutputProtocol, TSetIdentifier, TStructIdentifier, TType, + ZERO_COPY_THRESHOLD, +}; + +const VERSION_LE: u32 = 0x88880000; +const VERSION_MASK: u32 = 0xffff0000; + +pub struct TBinaryProtocol { + pub(crate) trans: T, + + zero_copy: bool, + zero_copy_len: usize, +} + +impl TBinaryProtocol { + /// `zero_copy` only takes effect when `T` is [`BytesMut`] for input and + /// [`LinkedBytes`] for output. + #[inline] + pub fn new(trans: T, zero_copy: bool) -> Self { + Self { + trans, + zero_copy, + zero_copy_len: 0, + } + } +} + +#[inline] +fn field_type_from_u8(ttype: u8) -> Result { + let ttype: TType = ttype.try_into().map_err(|_| { + new_protocol_error( + ProtocolErrorKind::InvalidData, + format!("invalid ttype {}", ttype), + ) + })?; + + Ok(ttype) +} + +impl TLengthProtocol for TBinaryProtocol { + #[inline] + fn write_message_begin_len(&mut self, identifier: &TMessageIdentifier) -> usize { + self.write_i32_len(0) + self.write_faststr_len(&identifier.name) + self.write_i32_len(0) + } + + #[inline] + fn write_message_end_len(&mut self) -> usize { + 0 + } + + #[inline] + fn write_struct_begin_len(&mut self, _identifier: &TStructIdentifier) -> usize { + 0 + } + + #[inline] + fn write_struct_end_len(&mut self) -> usize { + 0 + } + + #[inline] + fn write_field_begin_len(&mut self, _field_type: TType, _id: Option) -> usize { + self.write_byte_len(0) + self.write_i16_len(0) + } + + #[inline] + fn write_field_end_len(&mut self) -> usize { + 0 + } + + #[inline] + fn write_field_stop_len(&mut self) -> usize { + self.write_byte_len(0) + } + + #[inline] + fn write_bool_len(&mut self, _b: bool) -> usize { + self.write_i8_len(0) + } + + #[inline] + fn write_bytes_len(&mut self, b: &[u8]) -> usize { + if self.zero_copy && b.len() >= ZERO_COPY_THRESHOLD { + self.zero_copy_len += b.len(); + } + self.write_i32_len(0) + b.len() + } + + #[inline] + fn write_byte_len(&mut self, _b: u8) -> usize { + 1 + } + + #[inline] + fn write_uuid_len(&mut self, _u: [u8; 16]) -> usize { + 16 + } + + #[inline] + fn write_i8_len(&mut self, _i: i8) -> usize { + 1 + } + + #[inline] + fn write_i16_len(&mut self, _i: i16) -> usize { + 2 + } + + #[inline] + fn write_i32_len(&mut self, _i: i32) -> usize { + 4 + } + + #[inline] + fn write_i64_len(&mut self, _i: i64) -> usize { + 8 + } + + #[inline] + fn write_double_len(&mut self, _d: f64) -> usize { + 8 + } + + fn write_string_len(&mut self, s: &str) -> usize { + self.write_i32_len(0) + s.len() + } + + #[inline] + fn write_faststr_len(&mut self, s: &FastStr) -> usize { + if self.zero_copy && s.len() >= ZERO_COPY_THRESHOLD { + self.zero_copy_len += s.len(); + } + self.write_i32_len(0) + s.len() + } + + #[inline] + fn write_list_begin_len(&mut self, _identifier: TListIdentifier) -> usize { + self.write_byte_len(0) + self.write_i32_len(0) + } + + #[inline] + fn write_list_end_len(&mut self) -> usize { + 0 + } + + #[inline] + fn write_set_begin_len(&mut self, _identifier: TSetIdentifier) -> usize { + self.write_byte_len(0) + self.write_i32_len(0) + } + + #[inline] + fn write_set_end_len(&mut self) -> usize { + 0 + } + + #[inline] + fn write_map_begin_len(&mut self, _identifier: TMapIdentifier) -> usize { + self.write_byte_len(0) + self.write_byte_len(0) + self.write_i32_len(0) + } + + #[inline] + fn write_map_end_len(&mut self) -> usize { + 0 + } + + #[inline] + fn write_bytes_vec_len(&mut self, b: &[u8]) -> usize { + self.write_i32_len(0) + b.len() + } + + #[inline] + fn zero_copy_len(&mut self) -> usize { + self.zero_copy_len + } + + #[inline] + fn reset(&mut self) { + self.zero_copy_len = 0; + } +} + +impl TOutputProtocol for TBinaryProtocol<&mut BytesMut> { + type BufMut = BytesMut; + + #[inline] + fn write_message_begin(&mut self, identifier: &TMessageIdentifier) -> Result<(), EncodeError> { + let msg_type_u8: u8 = identifier.message_type.into(); + let version = (VERSION_LE | msg_type_u8 as u32) as i32; + self.write_i32(version)?; + self.write_faststr(identifier.name.clone())?; + self.write_i32(identifier.sequence_number)?; + Ok(()) + } + + #[inline] + fn write_message_end(&mut self) -> Result<(), EncodeError> { + Ok(()) + } + + #[inline] + fn write_struct_begin(&mut self, _: &TStructIdentifier) -> Result<(), EncodeError> { + Ok(()) + } + + #[inline] + fn write_struct_end(&mut self) -> Result<(), EncodeError> { + Ok(()) + } + + #[inline] + fn write_field_begin(&mut self, field_type: TType, id: i16) -> Result<(), EncodeError> { + let mut data: [u8; 3] = [0; 3]; + data[0] = field_type as u8; + let id = id.to_le_bytes(); + data[1] = id[0]; + data[2] = id[1]; + self.trans.write_slice(&data)?; + Ok(()) + } + + #[inline] + fn write_field_end(&mut self) -> Result<(), EncodeError> { + Ok(()) + } + + #[inline] + fn write_field_stop(&mut self) -> Result<(), EncodeError> { + self.write_byte(TType::Stop as u8) + } + + #[inline] + fn write_bool(&mut self, b: bool) -> Result<(), EncodeError> { + if b { + self.write_i8(1) + } else { + self.write_i8(0) + } + } + + #[inline] + fn write_bytes(&mut self, b: Bytes) -> Result<(), EncodeError> { + self.write_i32(b.len() as i32)?; + self.trans.write_slice(&b)?; + Ok(()) + } + + #[inline] + fn write_byte(&mut self, b: u8) -> Result<(), EncodeError> { + self.trans.write_u8(b)?; + Ok(()) + } + + #[inline] + fn write_uuid(&mut self, u: [u8; 16]) -> Result<(), EncodeError> { + self.trans.write_slice(&u)?; + Ok(()) + } + + #[inline] + fn write_i8(&mut self, i: i8) -> Result<(), EncodeError> { + self.trans.write_i8(i)?; + Ok(()) + } + + #[inline] + fn write_i16(&mut self, i: i16) -> Result<(), EncodeError> { + self.trans.write_i16_le(i)?; + Ok(()) + } + + #[inline] + fn write_i32(&mut self, i: i32) -> Result<(), EncodeError> { + self.trans.write_i32_le(i)?; + Ok(()) + } + + #[inline] + fn write_i64(&mut self, i: i64) -> Result<(), EncodeError> { + self.trans.write_i64_le(i)?; + Ok(()) + } + + #[inline] + fn write_double(&mut self, d: f64) -> Result<(), EncodeError> { + self.trans.write_f64_le(d)?; + Ok(()) + } + + #[inline] + fn write_string(&mut self, s: &str) -> Result<(), EncodeError> { + self.write_i32(s.len() as i32)?; + self.trans.write_slice(s.as_bytes())?; + Ok(()) + } + + #[inline] + fn write_faststr(&mut self, s: FastStr) -> Result<(), EncodeError> { + self.write_i32(s.len() as i32)?; + self.trans.write_slice(s.as_ref())?; + Ok(()) + } + + #[inline] + fn write_list_begin(&mut self, identifier: TListIdentifier) -> Result<(), EncodeError> { + self.write_byte(identifier.element_type.into())?; + self.write_i32(identifier.size as i32) + } + + #[inline] + fn write_list_end(&mut self) -> Result<(), EncodeError> { + Ok(()) + } + + #[inline] + fn write_set_begin(&mut self, identifier: TSetIdentifier) -> Result<(), EncodeError> { + self.write_byte(identifier.element_type.into())?; + self.write_i32(identifier.size as i32) + } + + #[inline] + fn write_set_end(&mut self) -> Result<(), EncodeError> { + Ok(()) + } + + #[inline] + fn write_map_begin(&mut self, identifier: TMapIdentifier) -> Result<(), EncodeError> { + let key_type = identifier.key_type; + self.write_byte(key_type.into())?; + let val_type = identifier.value_type; + self.write_byte(val_type.into())?; + self.write_i32(identifier.size as i32) + } + + #[inline] + fn write_map_end(&mut self) -> Result<(), EncodeError> { + Ok(()) + } + + #[inline] + fn flush(&mut self) -> Result<(), EncodeError> { + Ok(()) + } + + #[inline] + fn write_bytes_vec(&mut self, b: &[u8]) -> Result<(), EncodeError> { + self.write_i32(b.len() as i32)?; + self.trans.write_slice(b)?; + Ok(()) + } +} + +impl TOutputProtocol for TBinaryProtocol<&mut LinkedBytes> { + type BufMut = LinkedBytes; + + #[inline] + fn write_message_begin(&mut self, identifier: &TMessageIdentifier) -> Result<(), EncodeError> { + let msg_type_u8: u8 = identifier.message_type.into(); + let version = (VERSION_LE | msg_type_u8 as u32) as i32; + self.write_i32(version)?; + self.write_faststr(identifier.name.clone())?; + self.write_i32(identifier.sequence_number)?; + Ok(()) + } + + #[inline] + fn write_message_end(&mut self) -> Result<(), EncodeError> { + Ok(()) + } + + #[inline] + fn write_struct_begin(&mut self, _: &TStructIdentifier) -> Result<(), EncodeError> { + Ok(()) + } + + #[inline] + fn write_struct_end(&mut self) -> Result<(), EncodeError> { + Ok(()) + } + + #[inline] + fn write_field_begin(&mut self, field_type: TType, id: i16) -> Result<(), EncodeError> { + let mut data: [u8; 3] = [0; 3]; + data[0] = field_type as u8; + let id = id.to_le_bytes(); + data[1] = id[0]; + data[2] = id[1]; + self.trans.bytes_mut().write_slice(&data)?; + Ok(()) + } + + #[inline] + fn write_field_end(&mut self) -> Result<(), EncodeError> { + Ok(()) + } + + #[inline] + fn write_field_stop(&mut self) -> Result<(), EncodeError> { + self.write_byte(TType::Stop as u8) + } + + #[inline] + fn write_bool(&mut self, b: bool) -> Result<(), EncodeError> { + if b { + self.write_i8(1) + } else { + self.write_i8(0) + } + } + + #[inline] + fn write_bytes(&mut self, b: Bytes) -> Result<(), EncodeError> { + self.write_i32(b.len() as i32)?; + if self.zero_copy && b.len() >= ZERO_COPY_THRESHOLD { + self.trans.insert(b); + return Ok(()); + } + self.trans.bytes_mut().write_slice(&b)?; + Ok(()) + } + + #[inline] + fn write_byte(&mut self, b: u8) -> Result<(), EncodeError> { + self.trans.bytes_mut().write_u8(b)?; + Ok(()) + } + + #[inline] + fn write_uuid(&mut self, u: [u8; 16]) -> Result<(), EncodeError> { + self.trans.bytes_mut().write_slice(&u)?; + Ok(()) + } + + #[inline] + fn write_i8(&mut self, i: i8) -> Result<(), EncodeError> { + self.trans.bytes_mut().write_i8(i)?; + Ok(()) + } + + #[inline] + fn write_i16(&mut self, i: i16) -> Result<(), EncodeError> { + self.trans.bytes_mut().write_i16_le(i)?; + Ok(()) + } + + #[inline] + fn write_i32(&mut self, i: i32) -> Result<(), EncodeError> { + self.trans.bytes_mut().write_i32_le(i)?; + Ok(()) + } + + #[inline] + fn write_i64(&mut self, i: i64) -> Result<(), EncodeError> { + self.trans.bytes_mut().write_i64_le(i)?; + Ok(()) + } + + #[inline] + fn write_double(&mut self, d: f64) -> Result<(), EncodeError> { + self.trans.bytes_mut().write_f64_le(d)?; + Ok(()) + } + + #[inline] + fn write_string(&mut self, s: &str) -> Result<(), EncodeError> { + self.write_i32(s.len() as i32)?; + self.trans.bytes_mut().write_slice(s.as_bytes())?; + Ok(()) + } + #[inline] + fn write_faststr(&mut self, s: FastStr) -> Result<(), EncodeError> { + self.write_i32(s.len() as i32)?; + if self.zero_copy && s.len() >= ZERO_COPY_THRESHOLD { + self.trans.insert_faststr(s); + return Ok(()); + } + self.trans.bytes_mut().write_slice(s.as_ref())?; + Ok(()) + } + + #[inline] + fn write_list_begin(&mut self, identifier: TListIdentifier) -> Result<(), EncodeError> { + self.write_byte(identifier.element_type.into())?; + self.write_i32(identifier.size as i32) + } + + #[inline] + fn write_list_end(&mut self) -> Result<(), EncodeError> { + Ok(()) + } + + #[inline] + fn write_set_begin(&mut self, identifier: TSetIdentifier) -> Result<(), EncodeError> { + self.write_byte(identifier.element_type.into())?; + self.write_i32(identifier.size as i32) + } + + #[inline] + fn write_set_end(&mut self) -> Result<(), EncodeError> { + Ok(()) + } + + #[inline] + fn write_map_begin(&mut self, identifier: TMapIdentifier) -> Result<(), EncodeError> { + let key_type = identifier.key_type; + self.write_byte(key_type.into())?; + let val_type = identifier.value_type; + self.write_byte(val_type.into())?; + self.write_i32(identifier.size as i32) + } + + #[inline] + fn write_map_end(&mut self) -> Result<(), EncodeError> { + Ok(()) + } + + #[inline] + fn flush(&mut self) -> Result<(), EncodeError> { + Ok(()) + } + + #[inline] + fn write_bytes_vec(&mut self, b: &[u8]) -> Result<(), EncodeError> { + self.write_i32(b.len() as i32)?; + self.trans.bytes_mut().write_slice(b)?; + Ok(()) + } +} + +pub struct TAsyncBinaryProtocol { + reader: R, +} + +#[async_trait::async_trait] +impl TAsyncInputProtocol for TAsyncBinaryProtocol +where + R: AsyncRead + Unpin + Send, +{ + // https://github.com/apache/thrift/blob/master/doc/specs/thrift-binary-protocol.md + async fn read_message_begin(&mut self) -> Result { + let size = self.reader.read_i32_le().await?; + if size > 0 { + return Err(DecodeError::new( + DecodeErrorKind::BadVersion, + "Missing version in ReadMessageBegin".to_string(), + )); + } + + let type_u8 = (size & 0xf) as u8; + + let message_type = TMessageType::try_from(type_u8).map_err(|_| { + DecodeError::new( + DecodeErrorKind::InvalidData, + format!("invalid message type {}", type_u8), + ) + })?; + + let version = size & (VERSION_MASK as i32); + if version != (VERSION_LE as i32) { + return Err(DecodeError::new( + DecodeErrorKind::BadVersion, + "Bad version in ReadMessageBegin", + )); + } + + let name = self.read_faststr().await?; + + let sequence_number = self.read_i32().await?; + Ok(TMessageIdentifier::new(name, message_type, sequence_number)) + } + + #[inline] + async fn read_message_end(&mut self) -> Result<(), DecodeError> { + Ok(()) + } + + #[inline] + async fn read_struct_begin(&mut self) -> Result, DecodeError> { + Ok(None) + } + + #[inline] + async fn read_struct_end(&mut self) -> Result<(), DecodeError> { + Ok(()) + } + + #[inline] + async fn read_field_begin(&mut self) -> Result { + let field_type_byte = self.read_byte().await?; + let field_type = field_type_byte.try_into().map_err(|_| { + DecodeError::new( + DecodeErrorKind::InvalidData, + format!("invalid ttype {}", field_type_byte), + ) + })?; + let id = match field_type { + TType::Stop => Ok(0), + _ => self.read_i16().await, + }?; + Ok(TFieldIdentifier::new::, i16>( + None, field_type, id, + )) + } + + #[inline] + async fn read_field_end(&mut self) -> Result<(), DecodeError> { + Ok(()) + } + + #[inline] + async fn read_bool(&mut self) -> Result { + let b = self.read_i8().await?; + match b { + 0 => Ok(false), + _ => Ok(true), + } + } + + #[inline] + async fn read_bytes(&mut self) -> Result { + self.read_bytes_vec().await.map(Bytes::from) + } + + #[inline] + async fn read_bytes_vec(&mut self) -> Result, DecodeError> { + let len = self.reader.read_i32_le().await? as usize; + // FIXME: use maybe_uninit? + let mut v = vec![0; len]; + self.reader.read_exact(&mut v).await?; + Ok(v) + } + + #[inline] + async fn read_uuid(&mut self) -> Result<[u8; 16], DecodeError> { + let mut uuid = [0; 16]; + self.reader.read_exact(&mut uuid).await?; + Ok(uuid) + } + + #[inline] + async fn read_string(&mut self) -> Result { + let len = self.reader.read_i32_le().await? as usize; + // FIXME: use maybe_uninit? + let mut v = vec![0; len]; + self.reader.read_exact(&mut v).await?; + Ok(unsafe { String::from_utf8_unchecked(v) }) + } + + #[inline] + async fn read_faststr(&mut self) -> Result { + self.read_string().await.map(FastStr::from_string) + } + + #[inline] + async fn read_byte(&mut self) -> Result { + Ok(self.reader.read_u8().await?) + } + + #[inline] + async fn read_i8(&mut self) -> Result { + Ok(self.reader.read_i8().await?) + } + + #[inline] + async fn read_i16(&mut self) -> Result { + Ok(self.reader.read_i16_le().await?) + } + + #[inline] + async fn read_i32(&mut self) -> Result { + Ok(self.reader.read_i32_le().await?) + } + + #[inline] + async fn read_i64(&mut self) -> Result { + Ok(self.reader.read_i64_le().await?) + } + + #[inline] + async fn read_double(&mut self) -> Result { + Ok(self.reader.read_f64_le().await?) + } + + #[inline] + async fn read_list_begin(&mut self) -> Result { + let element_type: TType = self + .read_byte() + .await + .and_then(|n| Ok(field_type_from_u8(n)?))?; + let size = self.read_i32().await?; + Ok(TListIdentifier::new(element_type, size as usize)) + } + + #[inline] + async fn read_list_end(&mut self) -> Result<(), DecodeError> { + Ok(()) + } + + #[inline] + async fn read_set_begin(&mut self) -> Result { + let element_type: TType = self + .read_byte() + .await + .and_then(|n| Ok(field_type_from_u8(n)?))?; + let size = self.read_i32().await?; + Ok(TSetIdentifier::new(element_type, size as usize)) + } + + #[inline] + async fn read_set_end(&mut self) -> Result<(), DecodeError> { + Ok(()) + } + + #[inline] + async fn read_map_begin(&mut self) -> Result { + let key_type: TType = self + .read_byte() + .await + .and_then(|n| Ok(field_type_from_u8(n)?))?; + let value_type: TType = self + .read_byte() + .await + .and_then(|n| Ok(field_type_from_u8(n)?))?; + let size = self.read_i32().await?; + Ok(TMapIdentifier::new(key_type, value_type, size as usize)) + } + + #[inline] + async fn read_map_end(&mut self) -> Result<(), DecodeError> { + Ok(()) + } +} + +impl TAsyncBinaryProtocol +where + R: AsyncRead + Unpin + Send, +{ + pub fn new(reader: R) -> Self { + Self { reader } + } +} + +impl TInputProtocol for TBinaryProtocol<&mut BytesMut> { + type Buf = BytesMut; + + fn read_message_begin(&mut self) -> Result { + let size = self.trans.read_i32_le()?; + + if size > 0 { + return Err(DecodeError::new( + DecodeErrorKind::BadVersion, + "Missing version in ReadMessageBegin".to_string(), + )); + } + let type_u8 = (size & 0xf) as u8; + + let message_type = TMessageType::try_from(type_u8).map_err(|_| { + DecodeError::new( + DecodeErrorKind::InvalidData, + format!("invalid message type {}", type_u8), + ) + })?; + + let version = size & (VERSION_MASK as i32); + if version != (VERSION_LE as i32) { + return Err(DecodeError::new( + DecodeErrorKind::BadVersion, + "Bad version in ReadMessageBegin", + )); + } + + let name = self.read_faststr()?; + + let sequence_number = self.read_i32()?; + Ok(TMessageIdentifier::new(name, message_type, sequence_number)) + } + + #[inline] + fn read_message_end(&mut self) -> Result<(), DecodeError> { + Ok(()) + } + + #[inline] + fn read_struct_begin(&mut self) -> Result, DecodeError> { + Ok(None) + } + + #[inline] + fn read_struct_end(&mut self) -> Result<(), DecodeError> { + Ok(()) + } + + #[inline] + fn read_field_begin(&mut self) -> Result { + let field_type_byte = self.read_byte()?; + let field_type = field_type_byte.try_into().map_err(|_| { + DecodeError::new( + DecodeErrorKind::InvalidData, + format!("invalid ttype {}", field_type_byte), + ) + })?; + let id = match field_type { + TType::Stop => Ok(0), + _ => self.read_i16(), + }?; + Ok(TFieldIdentifier::new::, i16>( + None, field_type, id, + )) + } + + #[inline] + fn read_field_end(&mut self) -> Result<(), DecodeError> { + Ok(()) + } + + #[inline] + fn read_bool(&mut self) -> Result { + let b = self.read_i8()?; + match b { + 0 => Ok(false), + _ => Ok(true), + } + } + + #[inline] + fn read_bytes(&mut self) -> Result { + let len = self.trans.read_i32_le()?; + // split and freeze it + Ok(self.trans.split_to(len as usize).freeze()) + } + + #[inline] + fn read_uuid(&mut self) -> Result<[u8; 16], DecodeError> { + let mut u = [0; 16]; + self.trans.read_to_slice(&mut u)?; + Ok(u) + } + + #[inline] + fn read_i8(&mut self) -> Result { + Ok(self.trans.read_i8()?) + } + + #[inline] + fn read_i16(&mut self) -> Result { + Ok(self.trans.read_i16_le()?) + } + + #[inline] + fn read_i32(&mut self) -> Result { + Ok(self.trans.read_i32_le()?) + } + + #[inline] + fn read_i64(&mut self) -> Result { + Ok(self.trans.read_i64_le()?) + } + + #[inline] + fn read_double(&mut self) -> Result { + Ok(self.trans.read_f64_le()?) + } + + #[inline] + fn read_string(&mut self) -> Result { + let len = self.trans.read_i32_le()?; + Ok(self.trans.read_to_string(len as usize)?) + } + + #[inline] + fn read_faststr(&mut self) -> Result { + let len = self.trans.read_i32_le()? as usize; + if len >= ZERO_COPY_THRESHOLD { + let bytes = self.trans.split_to(len).freeze(); + unsafe { return Ok(FastStr::from_bytes_unchecked(bytes)) }; + } + unsafe { + Ok(FastStr::new(str::from_utf8_unchecked( + self.trans.get(..len).unwrap(), + ))) + } + } + + #[inline] + fn read_list_begin(&mut self) -> Result { + let element_type: TType = self.read_byte().and_then(|n| Ok(field_type_from_u8(n)?))?; + let size = self.read_i32()?; + Ok(TListIdentifier::new(element_type, size as usize)) + } + + #[inline] + fn read_list_end(&mut self) -> Result<(), DecodeError> { + Ok(()) + } + + #[inline] + fn read_set_begin(&mut self) -> Result { + let element_type: TType = self.read_byte().and_then(|n| Ok(field_type_from_u8(n)?))?; + let size = self.read_i32()?; + Ok(TSetIdentifier::new(element_type, size as usize)) + } + + #[inline] + fn read_set_end(&mut self) -> Result<(), DecodeError> { + Ok(()) + } + + #[inline] + fn read_map_begin(&mut self) -> Result { + let key_type: TType = self.read_byte().and_then(|n| Ok(field_type_from_u8(n)?))?; + let value_type: TType = self.read_byte().and_then(|n| Ok(field_type_from_u8(n)?))?; + let size = self.read_i32()?; + Ok(TMapIdentifier::new(key_type, value_type, size as usize)) + } + + #[inline] + fn read_map_end(&mut self) -> Result<(), DecodeError> { + Ok(()) + } + + #[inline] + fn read_byte(&mut self) -> Result { + Ok(self.trans.read_u8()?) + } + + #[inline] + fn read_bytes_vec(&mut self) -> Result, DecodeError> { + let len = self.trans.read_i32_le()? as usize; + Ok(self.trans.split_to(len).into()) + } +} diff --git a/pilota/src/thrift/binary_unsafe.rs b/pilota/src/thrift/binary_unsafe.rs new file mode 100644 index 00000000..79d71e80 --- /dev/null +++ b/pilota/src/thrift/binary_unsafe.rs @@ -0,0 +1,1190 @@ +use std::{convert::TryInto, ptr, slice, str}; + +use bytes::{Buf, BufMut, Bytes, BytesMut}; +use faststr::FastStr; +use linkedbytes::LinkedBytes; +use tokio::io::{AsyncRead, AsyncReadExt}; + +use super::{ + error::ProtocolErrorKind, new_protocol_error, DecodeError, DecodeErrorKind, EncodeError, + ProtocolError, TAsyncInputProtocol, TFieldIdentifier, TInputProtocol, TLengthProtocol, + TListIdentifier, TMapIdentifier, TMessageIdentifier, TMessageType, TOutputProtocol, + TSetIdentifier, TStructIdentifier, TType, ZERO_COPY_THRESHOLD, +}; + +static VERSION_1: u32 = 0x80010000; +static VERSION_MASK: u32 = 0xffff0000; + +pub struct TBinaryProtocol { + pub(crate) trans: T, + pub(crate) buf: &'static mut [u8], + pub(crate) index: usize, + + zero_copy: bool, + zero_copy_len: usize, +} + +impl TBinaryProtocol { + /// `zero_copy` only takes effect when `T` is [`BytesMut`] for input and + /// [`LinkedBytes`] for output. + /// + /// # Safety + /// + /// The 'buf' MUST point to the same area of trans, this is a + /// self-referencial struct. + /// + /// The 'trans' MUST have enough capacity to read from or write to. + #[inline] + pub unsafe fn new(trans: T, buf: &'static mut [u8], zero_copy: bool) -> Self { + Self { + trans, + buf, + index: 0, + zero_copy, + zero_copy_len: 0, + } + } +} + +#[inline] +fn field_type_from_u8(ttype: u8) -> Result { + let ttype: TType = ttype.try_into().map_err(|_| { + new_protocol_error( + ProtocolErrorKind::InvalidData, + format!("invalid ttype {}", ttype), + ) + })?; + + Ok(ttype) +} + +impl TLengthProtocol for TBinaryProtocol { + #[inline] + fn write_message_begin_len(&mut self, identifier: &TMessageIdentifier) -> usize { + self.write_i32_len(0) + self.write_faststr_len(&identifier.name) + self.write_i32_len(0) + } + + #[inline] + fn write_message_end_len(&mut self) -> usize { + 0 + } + + #[inline] + fn write_struct_begin_len(&mut self, _identifier: &TStructIdentifier) -> usize { + 0 + } + + #[inline] + fn write_struct_end_len(&mut self) -> usize { + 0 + } + + #[inline] + fn write_field_begin_len(&mut self, _field_type: TType, _id: Option) -> usize { + self.write_byte_len(0) + self.write_i16_len(0) + } + + #[inline] + fn write_field_end_len(&mut self) -> usize { + 0 + } + + #[inline] + fn write_field_stop_len(&mut self) -> usize { + self.write_byte_len(0) + } + + #[inline] + fn write_bool_len(&mut self, _b: bool) -> usize { + self.write_i8_len(0) + } + + #[inline] + fn write_bytes_len(&mut self, b: &[u8]) -> usize { + if self.zero_copy && b.len() >= ZERO_COPY_THRESHOLD { + self.zero_copy_len += b.len(); + } + self.write_i32_len(0) + b.len() + } + + #[inline] + fn write_byte_len(&mut self, _b: u8) -> usize { + 1 + } + + #[inline] + fn write_uuid_len(&mut self, _u: [u8; 16]) -> usize { + 16 + } + + #[inline] + fn write_i8_len(&mut self, _i: i8) -> usize { + 1 + } + + #[inline] + fn write_i16_len(&mut self, _i: i16) -> usize { + 2 + } + + #[inline] + fn write_i32_len(&mut self, _i: i32) -> usize { + 4 + } + + #[inline] + fn write_i64_len(&mut self, _i: i64) -> usize { + 8 + } + + #[inline] + fn write_double_len(&mut self, _d: f64) -> usize { + 8 + } + + fn write_string_len(&mut self, s: &str) -> usize { + self.write_i32_len(0) + s.len() + } + + #[inline] + fn write_faststr_len(&mut self, s: &FastStr) -> usize { + if self.zero_copy && s.len() >= ZERO_COPY_THRESHOLD { + self.zero_copy_len += s.len(); + } + self.write_i32_len(0) + s.len() + } + + #[inline] + fn write_list_begin_len(&mut self, _identifier: TListIdentifier) -> usize { + self.write_byte_len(0) + self.write_i32_len(0) + } + + #[inline] + fn write_list_end_len(&mut self) -> usize { + 0 + } + + #[inline] + fn write_set_begin_len(&mut self, _identifier: TSetIdentifier) -> usize { + self.write_byte_len(0) + self.write_i32_len(0) + } + + #[inline] + fn write_set_end_len(&mut self) -> usize { + 0 + } + + #[inline] + fn write_map_begin_len(&mut self, _identifier: TMapIdentifier) -> usize { + self.write_byte_len(0) + self.write_byte_len(0) + self.write_i32_len(0) + } + + #[inline] + fn write_map_end_len(&mut self) -> usize { + 0 + } + + #[inline] + fn write_bytes_vec_len(&mut self, b: &[u8]) -> usize { + self.write_i32_len(0) + b.len() + } + + #[inline] + fn zero_copy_len(&mut self) -> usize { + self.zero_copy_len + } + + #[inline] + fn reset(&mut self) { + self.zero_copy_len = 0; + } +} + +impl TOutputProtocol for TBinaryProtocol<&mut BytesMut> { + type BufMut = BytesMut; + + #[inline] + fn write_message_begin(&mut self, identifier: &TMessageIdentifier) -> Result<(), EncodeError> { + let msg_type_u8: u8 = identifier.message_type.into(); + let version = (VERSION_1 | msg_type_u8 as u32) as i32; + self.write_i32(version)?; + self.write_faststr(identifier.name.clone())?; + self.write_i32(identifier.sequence_number)?; + Ok(()) + } + + #[inline] + fn write_message_end(&mut self) -> Result<(), EncodeError> { + Ok(()) + } + + #[inline] + fn write_struct_begin(&mut self, _: &TStructIdentifier) -> Result<(), EncodeError> { + Ok(()) + } + + #[inline] + fn write_struct_end(&mut self) -> Result<(), EncodeError> { + Ok(()) + } + + #[inline] + fn write_field_begin(&mut self, field_type: TType, id: i16) -> Result<(), EncodeError> { + unsafe { + *self.buf.get_unchecked_mut(self.index) = field_type as u8; + let buf: &mut [u8; 2] = self + .buf + .get_unchecked_mut(self.index + 1..self.index + 3) + .try_into() + .unwrap_unchecked(); + *buf = id.to_be_bytes(); + self.index += 3; + } + Ok(()) + } + + #[inline] + fn write_field_end(&mut self) -> Result<(), EncodeError> { + Ok(()) + } + + #[inline] + fn write_field_stop(&mut self) -> Result<(), EncodeError> { + self.write_byte(TType::Stop as u8) + } + + #[inline] + fn write_bool(&mut self, b: bool) -> Result<(), EncodeError> { + if b { + self.write_i8(1) + } else { + self.write_i8(0) + } + } + + #[inline] + fn write_bytes(&mut self, b: Bytes) -> Result<(), EncodeError> { + self.write_i32(b.len() as i32)?; + unsafe { + ptr::copy_nonoverlapping( + b.as_ptr(), + self.buf.as_mut_ptr().offset(self.index as isize), + b.len(), + ); + self.index += b.len(); + } + Ok(()) + } + + #[inline] + fn write_byte(&mut self, b: u8) -> Result<(), EncodeError> { + unsafe { + *self.buf.get_unchecked_mut(self.index) = b; + self.index += 1; + } + Ok(()) + } + + #[inline] + fn write_uuid(&mut self, u: [u8; 16]) -> Result<(), EncodeError> { + unsafe { + let buf: &mut [u8; 16] = self + .buf + .get_unchecked_mut(self.index..self.index + 16) + .try_into() + .unwrap_unchecked(); + *buf = u; + self.index += 16; + } + Ok(()) + } + + #[inline] + fn write_i8(&mut self, i: i8) -> Result<(), EncodeError> { + unsafe { + *self.buf.get_unchecked_mut(self.index) = *i.to_be_bytes().get_unchecked(0); + self.index += 1; + } + Ok(()) + } + + #[inline] + fn write_i16(&mut self, i: i16) -> Result<(), EncodeError> { + unsafe { + let buf: &mut [u8; 2] = self + .trans + .get_unchecked_mut(self.index..self.index + 2) + .try_into() + .unwrap_unchecked(); + *buf = i.to_be_bytes(); + self.index += 2; + } + Ok(()) + } + + #[inline] + fn write_i32(&mut self, i: i32) -> Result<(), EncodeError> { + unsafe { + let buf: &mut [u8; 4] = self + .trans + .get_unchecked_mut(self.index..self.index + 4) + .try_into() + .unwrap_unchecked(); + *buf = i.to_be_bytes(); + self.index += 4; + } + Ok(()) + } + + #[inline] + fn write_i64(&mut self, i: i64) -> Result<(), EncodeError> { + unsafe { + let buf: &mut [u8; 8] = self + .trans + .get_unchecked_mut(self.index..self.index + 8) + .try_into() + .unwrap_unchecked(); + *buf = i.to_be_bytes(); + self.index += 8; + } + Ok(()) + } + + #[inline] + fn write_double(&mut self, d: f64) -> Result<(), EncodeError> { + unsafe { + let buf: &mut [u8; 8] = self + .trans + .get_unchecked_mut(self.index..self.index + 8) + .try_into() + .unwrap_unchecked(); + *buf = d.to_bits().to_be_bytes(); + self.index += 8; + } + Ok(()) + } + + #[inline] + fn write_string(&mut self, s: &str) -> Result<(), EncodeError> { + self.write_i32(s.len() as i32)?; + unsafe { + ptr::copy_nonoverlapping( + s.as_ptr(), + self.buf.as_mut_ptr().offset(self.index as isize), + s.len(), + ); + self.index += s.len(); + } + Ok(()) + } + + #[inline] + fn write_faststr(&mut self, s: FastStr) -> Result<(), EncodeError> { + self.write_i32(s.len() as i32)?; + unsafe { + ptr::copy_nonoverlapping( + s.as_ptr(), + self.buf.as_mut_ptr().offset(self.index as isize), + s.len(), + ); + self.index += s.len(); + } + Ok(()) + } + + #[inline] + fn write_list_begin(&mut self, identifier: TListIdentifier) -> Result<(), EncodeError> { + self.write_byte(identifier.element_type.into())?; + self.write_i32(identifier.size as i32) + } + + #[inline] + fn write_list_end(&mut self) -> Result<(), EncodeError> { + Ok(()) + } + + #[inline] + fn write_set_begin(&mut self, identifier: TSetIdentifier) -> Result<(), EncodeError> { + self.write_byte(identifier.element_type.into())?; + self.write_i32(identifier.size as i32) + } + + #[inline] + fn write_set_end(&mut self) -> Result<(), EncodeError> { + Ok(()) + } + + #[inline] + fn write_map_begin(&mut self, identifier: TMapIdentifier) -> Result<(), EncodeError> { + let key_type = identifier.key_type; + self.write_byte(key_type.into())?; + let val_type = identifier.value_type; + self.write_byte(val_type.into())?; + self.write_i32(identifier.size as i32) + } + + #[inline] + fn write_map_end(&mut self) -> Result<(), EncodeError> { + Ok(()) + } + + #[inline] + fn flush(&mut self) -> Result<(), EncodeError> { + Ok(()) + } + + #[inline] + fn write_bytes_vec(&mut self, b: &[u8]) -> Result<(), EncodeError> { + self.write_i32(b.len() as i32)?; + unsafe { + ptr::copy_nonoverlapping( + b.as_ptr(), + self.buf.as_mut_ptr().offset(self.index as isize), + b.len(), + ); + self.index += b.len(); + } + Ok(()) + } +} + +impl TOutputProtocol for TBinaryProtocol<&mut LinkedBytes> { + type BufMut = LinkedBytes; + + #[inline] + fn write_message_begin(&mut self, identifier: &TMessageIdentifier) -> Result<(), EncodeError> { + let msg_type_u8: u8 = identifier.message_type.into(); + let version = (VERSION_1 | msg_type_u8 as u32) as i32; + self.write_i32(version)?; + self.write_faststr(identifier.name.clone())?; + self.write_i32(identifier.sequence_number)?; + Ok(()) + } + + #[inline] + fn write_message_end(&mut self) -> Result<(), EncodeError> { + Ok(()) + } + + #[inline] + fn write_struct_begin(&mut self, _: &TStructIdentifier) -> Result<(), EncodeError> { + Ok(()) + } + + #[inline] + fn write_struct_end(&mut self) -> Result<(), EncodeError> { + Ok(()) + } + + #[inline] + fn write_field_begin(&mut self, field_type: TType, id: i16) -> Result<(), EncodeError> { + unsafe { + *self.buf.get_unchecked_mut(self.index) = field_type as u8; + let buf: &mut [u8; 2] = self + .buf + .get_unchecked_mut(self.index + 1..self.index + 3) + .try_into() + .unwrap_unchecked(); + *buf = id.to_be_bytes(); + self.index += 3; + } + Ok(()) + } + + #[inline] + fn write_field_end(&mut self) -> Result<(), EncodeError> { + Ok(()) + } + + #[inline] + fn write_field_stop(&mut self) -> Result<(), EncodeError> { + self.write_byte(TType::Stop as u8) + } + + #[inline] + fn write_bool(&mut self, b: bool) -> Result<(), EncodeError> { + if b { + self.write_i8(1) + } else { + self.write_i8(0) + } + } + + #[inline] + fn write_bytes(&mut self, b: Bytes) -> Result<(), EncodeError> { + self.write_i32(b.len() as i32)?; + if self.zero_copy && b.len() >= ZERO_COPY_THRESHOLD { + unsafe { + self.trans.bytes_mut().advance_mut(self.index); + self.index = 0; + } + self.trans.insert(b); + self.buf = unsafe { + slice::from_raw_parts_mut( + self.trans.bytes_mut().as_mut_ptr(), + self.trans.bytes_mut().len(), + ) + }; + return Ok(()); + } + unsafe { + ptr::copy_nonoverlapping( + b.as_ptr(), + self.buf.as_mut_ptr().offset(self.index as isize), + b.len(), + ); + self.index += b.len(); + } + Ok(()) + } + + #[inline] + fn write_byte(&mut self, b: u8) -> Result<(), EncodeError> { + unsafe { + *self.buf.get_unchecked_mut(self.index) = b; + self.index += 1; + } + Ok(()) + } + + #[inline] + fn write_uuid(&mut self, u: [u8; 16]) -> Result<(), EncodeError> { + unsafe { + let buf: &mut [u8; 16] = self + .trans + .bytes_mut() + .get_unchecked_mut(self.index..self.index + 16) + .try_into() + .unwrap_unchecked(); + *buf = u; + self.index += 16; + } + Ok(()) + } + + #[inline] + fn write_i8(&mut self, i: i8) -> Result<(), EncodeError> { + unsafe { + *self.buf.get_unchecked_mut(self.index) = *i.to_be_bytes().get_unchecked(0); + self.index += 1; + } + Ok(()) + } + + #[inline] + fn write_i16(&mut self, i: i16) -> Result<(), EncodeError> { + unsafe { + let buf: &mut [u8; 2] = self + .trans + .bytes_mut() + .get_unchecked_mut(self.index..self.index + 2) + .try_into() + .unwrap_unchecked(); + *buf = i.to_be_bytes(); + self.index += 2; + } + Ok(()) + } + + #[inline] + fn write_i32(&mut self, i: i32) -> Result<(), EncodeError> { + unsafe { + let buf: &mut [u8; 4] = self + .trans + .bytes_mut() + .get_unchecked_mut(self.index..self.index + 4) + .try_into() + .unwrap_unchecked(); + *buf = i.to_be_bytes(); + self.index += 4; + } + Ok(()) + } + + #[inline] + fn write_i64(&mut self, i: i64) -> Result<(), EncodeError> { + unsafe { + let buf: &mut [u8; 8] = self + .trans + .bytes_mut() + .get_unchecked_mut(self.index..self.index + 8) + .try_into() + .unwrap_unchecked(); + *buf = i.to_be_bytes(); + self.index += 8; + } + Ok(()) + } + + #[inline] + fn write_double(&mut self, d: f64) -> Result<(), EncodeError> { + unsafe { + let buf: &mut [u8; 8] = self + .trans + .bytes_mut() + .get_unchecked_mut(self.index..self.index + 8) + .try_into() + .unwrap_unchecked(); + *buf = d.to_bits().to_be_bytes(); + self.index += 8; + } + Ok(()) + } + + #[inline] + fn write_string(&mut self, s: &str) -> Result<(), EncodeError> { + self.write_i32(s.len() as i32)?; + unsafe { + ptr::copy_nonoverlapping( + s.as_ptr(), + self.buf.as_mut_ptr().offset(self.index as isize), + s.len(), + ); + self.index += s.len(); + } + Ok(()) + } + + #[inline] + fn write_faststr(&mut self, s: FastStr) -> Result<(), EncodeError> { + self.write_i32(s.len() as i32)?; + if self.zero_copy && s.len() >= ZERO_COPY_THRESHOLD { + unsafe { + self.trans.bytes_mut().advance_mut(self.index); + self.index = 0; + } + self.trans.insert_faststr(s); + self.buf = unsafe { + slice::from_raw_parts_mut( + self.trans.bytes_mut().as_mut_ptr(), + self.trans.bytes_mut().len(), + ) + }; + return Ok(()); + } + unsafe { + ptr::copy_nonoverlapping( + s.as_ptr(), + self.buf.as_mut_ptr().offset(self.index as isize), + s.len(), + ); + self.index += s.len(); + } + Ok(()) + } + + #[inline] + fn write_list_begin(&mut self, identifier: TListIdentifier) -> Result<(), EncodeError> { + self.write_byte(identifier.element_type.into())?; + self.write_i32(identifier.size as i32) + } + + #[inline] + fn write_list_end(&mut self) -> Result<(), EncodeError> { + Ok(()) + } + + #[inline] + fn write_set_begin(&mut self, identifier: TSetIdentifier) -> Result<(), EncodeError> { + self.write_byte(identifier.element_type.into())?; + self.write_i32(identifier.size as i32) + } + + #[inline] + fn write_set_end(&mut self) -> Result<(), EncodeError> { + Ok(()) + } + + #[inline] + fn write_map_begin(&mut self, identifier: TMapIdentifier) -> Result<(), EncodeError> { + let key_type = identifier.key_type; + self.write_byte(key_type.into())?; + let val_type = identifier.value_type; + self.write_byte(val_type.into())?; + self.write_i32(identifier.size as i32) + } + + #[inline] + fn write_map_end(&mut self) -> Result<(), EncodeError> { + Ok(()) + } + + #[inline] + fn flush(&mut self) -> Result<(), EncodeError> { + Ok(()) + } + + #[inline] + fn write_bytes_vec(&mut self, b: &[u8]) -> Result<(), EncodeError> { + self.write_i32(b.len() as i32)?; + unsafe { + ptr::copy_nonoverlapping( + b.as_ptr(), + self.buf.as_mut_ptr().offset(self.index as isize), + b.len(), + ); + self.index += b.len(); + } + Ok(()) + } +} + +pub struct TAsyncBinaryProtocol { + reader: R, +} + +#[async_trait::async_trait] +impl TAsyncInputProtocol for TAsyncBinaryProtocol +where + R: AsyncRead + Unpin + Send, +{ + // https://github.com/apache/thrift/blob/master/doc/specs/thrift-binary-protocol.md + async fn read_message_begin(&mut self) -> Result { + let size = self.reader.read_i32().await?; + if size > 0 { + return Err(DecodeError::new( + DecodeErrorKind::BadVersion, + "Missing version in ReadMessageBegin".to_string(), + )); + } + + let type_u8 = (size & 0xf) as u8; + + let message_type = TMessageType::try_from(type_u8).map_err(|_| { + DecodeError::new( + DecodeErrorKind::InvalidData, + format!("invalid message type {}", type_u8), + ) + })?; + + let version = size & (VERSION_MASK as i32); + if version != (VERSION_1 as i32) { + return Err(DecodeError::new( + DecodeErrorKind::BadVersion, + "Bad version in ReadMessageBegin", + )); + } + + let name = self.read_faststr().await?; + + let sequence_number = self.read_i32().await?; + Ok(TMessageIdentifier::new(name, message_type, sequence_number)) + } + + #[inline] + async fn read_message_end(&mut self) -> Result<(), DecodeError> { + Ok(()) + } + + #[inline] + async fn read_struct_begin(&mut self) -> Result, DecodeError> { + Ok(None) + } + + #[inline] + async fn read_struct_end(&mut self) -> Result<(), DecodeError> { + Ok(()) + } + + #[inline] + async fn read_field_begin(&mut self) -> Result { + let field_type_byte = self.read_byte().await?; + let field_type = field_type_byte.try_into().map_err(|_| { + DecodeError::new( + DecodeErrorKind::InvalidData, + format!("invalid ttype {}", field_type_byte), + ) + })?; + let id = match field_type { + TType::Stop => Ok(0), + _ => self.read_i16().await, + }?; + Ok(TFieldIdentifier::new::, i16>( + None, field_type, id, + )) + } + + #[inline] + async fn read_field_end(&mut self) -> Result<(), DecodeError> { + Ok(()) + } + + #[inline] + async fn read_bool(&mut self) -> Result { + let b = self.read_i8().await?; + match b { + 0 => Ok(false), + _ => Ok(true), + } + } + + #[inline] + async fn read_bytes(&mut self) -> Result { + self.read_bytes_vec().await.map(Bytes::from) + } + + #[inline] + async fn read_bytes_vec(&mut self) -> Result, DecodeError> { + let len = self.reader.read_i32().await? as usize; + // FIXME: use maybe_uninit? + let mut v = vec![0; len]; + self.reader.read_exact(&mut v).await?; + Ok(v) + } + + #[inline] + async fn read_uuid(&mut self) -> Result<[u8; 16], DecodeError> { + let mut uuid = [0; 16]; + self.reader.read_exact(&mut uuid).await?; + Ok(uuid) + } + + #[inline] + async fn read_string(&mut self) -> Result { + let len = self.reader.read_i32().await? as usize; + // FIXME: use maybe_uninit? + let mut v = vec![0; len]; + self.reader.read_exact(&mut v).await?; + Ok(unsafe { String::from_utf8_unchecked(v) }) + } + + #[inline] + async fn read_faststr(&mut self) -> Result { + self.read_string().await.map(FastStr::from_string) + } + + #[inline] + async fn read_byte(&mut self) -> Result { + Ok(self.reader.read_u8().await?) + } + + #[inline] + async fn read_i8(&mut self) -> Result { + Ok(self.reader.read_i8().await?) + } + + #[inline] + async fn read_i16(&mut self) -> Result { + Ok(self.reader.read_i16().await?) + } + + #[inline] + async fn read_i32(&mut self) -> Result { + Ok(self.reader.read_i32().await?) + } + + #[inline] + async fn read_i64(&mut self) -> Result { + Ok(self.reader.read_i64().await?) + } + + #[inline] + async fn read_double(&mut self) -> Result { + Ok(self.reader.read_f64().await?) + } + + #[inline] + async fn read_list_begin(&mut self) -> Result { + let element_type: TType = self + .read_byte() + .await + .and_then(|n| Ok(field_type_from_u8(n)?))?; + let size = self.read_i32().await?; + Ok(TListIdentifier::new(element_type, size as usize)) + } + + #[inline] + async fn read_list_end(&mut self) -> Result<(), DecodeError> { + Ok(()) + } + + #[inline] + async fn read_set_begin(&mut self) -> Result { + let element_type: TType = self + .read_byte() + .await + .and_then(|n| Ok(field_type_from_u8(n)?))?; + let size = self.read_i32().await?; + Ok(TSetIdentifier::new(element_type, size as usize)) + } + + #[inline] + async fn read_set_end(&mut self) -> Result<(), DecodeError> { + Ok(()) + } + + #[inline] + async fn read_map_begin(&mut self) -> Result { + let key_type: TType = self + .read_byte() + .await + .and_then(|n| Ok(field_type_from_u8(n)?))?; + let value_type: TType = self + .read_byte() + .await + .and_then(|n| Ok(field_type_from_u8(n)?))?; + let size = self.read_i32().await?; + Ok(TMapIdentifier::new(key_type, value_type, size as usize)) + } + + #[inline] + async fn read_map_end(&mut self) -> Result<(), DecodeError> { + Ok(()) + } +} + +impl TAsyncBinaryProtocol +where + R: AsyncRead + Unpin + Send, +{ + pub fn new(reader: R) -> Self { + Self { reader } + } +} + +impl TInputProtocol for TBinaryProtocol<&mut BytesMut> { + type Buf = BytesMut; + + fn read_message_begin(&mut self) -> Result { + let size = self.read_i32()?; + + if size > 0 { + return Err(DecodeError::new( + DecodeErrorKind::BadVersion, + "Missing version in ReadMessageBegin".to_string(), + )); + } + let type_u8 = (size & 0xf) as u8; + + let message_type = TMessageType::try_from(type_u8).map_err(|_| { + DecodeError::new( + DecodeErrorKind::InvalidData, + format!("invalid message type {}", type_u8), + ) + })?; + + let version = size & (VERSION_MASK as i32); + if version != (VERSION_1 as i32) { + return Err(DecodeError::new( + DecodeErrorKind::BadVersion, + "Bad version in ReadMessageBegin", + )); + } + + let name = self.read_faststr()?; + + let sequence_number = self.read_i32()?; + Ok(TMessageIdentifier::new(name, message_type, sequence_number)) + } + + #[inline] + fn read_message_end(&mut self) -> Result<(), DecodeError> { + Ok(()) + } + + #[inline] + fn read_struct_begin(&mut self) -> Result, DecodeError> { + Ok(None) + } + + #[inline] + fn read_struct_end(&mut self) -> Result<(), DecodeError> { + Ok(()) + } + + #[inline] + fn read_field_begin(&mut self) -> Result { + let field_type_byte = self.read_byte()?; + let field_type = field_type_byte.try_into().map_err(|_| { + DecodeError::new( + DecodeErrorKind::InvalidData, + format!("invalid ttype {}", field_type_byte), + ) + })?; + let id = match field_type { + TType::Stop => Ok(0), + _ => self.read_i16(), + }?; + Ok(TFieldIdentifier::new::, i16>( + None, field_type, id, + )) + } + + #[inline] + fn read_field_end(&mut self) -> Result<(), DecodeError> { + Ok(()) + } + + #[inline] + fn read_bool(&mut self) -> Result { + let b = self.read_i8()?; + match b { + 0 => Ok(false), + _ => Ok(true), + } + } + + #[inline] + fn read_bytes(&mut self) -> Result { + let len = self.read_i32()?; + self.trans.advance(self.index); + self.index = 0; + // split and freeze it + let val = self.trans.split_to(len as usize).freeze(); + self.buf = unsafe { slice::from_raw_parts_mut(self.trans.as_mut_ptr(), self.trans.len()) }; + Ok(val) + } + + #[inline] + fn read_uuid(&mut self) -> Result<[u8; 16], DecodeError> { + let u; + unsafe { + u = self + .trans + .get_unchecked_mut(self.index..self.index + 16) + .try_into() + .unwrap_unchecked(); + self.index += 16; + } + Ok(u) + } + + #[inline] + fn read_i8(&mut self) -> Result { + unsafe { + let val = *self.buf.get_unchecked(self.index) as i8; + self.index += 1; + Ok(val) + } + } + + #[inline] + fn read_i16(&mut self) -> Result { + unsafe { + let val = self.buf.get_unchecked(self.index..self.index + 2); + self.index += 2; + Ok(i16::from_be_bytes(val.try_into().unwrap_unchecked())) + } + } + + #[inline] + fn read_i32(&mut self) -> Result { + unsafe { + let val = self.buf.get_unchecked(self.index..self.index + 4); + self.index += 4; + Ok(i32::from_be_bytes(val.try_into().unwrap_unchecked())) + } + } + + #[inline] + fn read_i64(&mut self) -> Result { + unsafe { + let val = self.buf.get_unchecked(self.index..self.index + 8); + self.index += 8; + Ok(i64::from_be_bytes(val.try_into().unwrap_unchecked())) + } + } + + #[inline] + fn read_double(&mut self) -> Result { + unsafe { + let val = self.buf.get_unchecked(self.index..self.index + 8); + self.index += 8; + Ok(f64::from_bits(u64::from_be_bytes( + val.try_into().unwrap_unchecked(), + ))) + } + } + + #[inline] + fn read_string(&mut self) -> Result { + unsafe { + let len = self.read_i32().unwrap_unchecked(); + let val = str::from_utf8_unchecked( + self.buf + .get_unchecked(self.index..self.index + len as usize), + ) + .to_string(); + self.index += len as usize; + Ok(val) + } + } + + #[inline] + fn read_faststr(&mut self) -> Result { + unsafe { + let len = self.read_i32().unwrap_unchecked() as usize; + if len >= ZERO_COPY_THRESHOLD { + self.trans.advance(self.index); + self.index = 0; + let bytes = self.trans.split_to(len).freeze(); + self.buf = slice::from_raw_parts_mut(self.trans.as_mut_ptr(), self.trans.len()); + return Ok(FastStr::from_bytes_unchecked(bytes)); + } + + let val = FastStr::new(str::from_utf8_unchecked( + self.buf.get_unchecked(self.index..self.index + len), + )); + + self.index += len; + + Ok(val) + } + } + + #[inline] + fn read_list_begin(&mut self) -> Result { + let element_type: TType = self.read_byte().and_then(|n| Ok(field_type_from_u8(n)?))?; + let size = self.read_i32()?; + Ok(TListIdentifier::new(element_type, size as usize)) + } + + #[inline] + fn read_list_end(&mut self) -> Result<(), DecodeError> { + Ok(()) + } + + #[inline] + fn read_set_begin(&mut self) -> Result { + let element_type: TType = self.read_byte().and_then(|n| Ok(field_type_from_u8(n)?))?; + let size = self.read_i32()?; + Ok(TSetIdentifier::new(element_type, size as usize)) + } + + #[inline] + fn read_set_end(&mut self) -> Result<(), DecodeError> { + Ok(()) + } + + #[inline] + fn read_map_begin(&mut self) -> Result { + let key_type: TType = self.read_byte().and_then(|n| Ok(field_type_from_u8(n)?))?; + let value_type: TType = self.read_byte().and_then(|n| Ok(field_type_from_u8(n)?))?; + let size = self.read_i32()?; + Ok(TMapIdentifier::new(key_type, value_type, size as usize)) + } + + #[inline] + fn read_map_end(&mut self) -> Result<(), DecodeError> { + Ok(()) + } + + #[inline] + fn read_byte(&mut self) -> Result { + unsafe { + let val = *self.buf.get_unchecked(self.index); + self.index += 1; + Ok(val) + } + } + + #[inline] + fn read_bytes_vec(&mut self) -> Result, DecodeError> { + let len = self.read_i32()? as usize; + self.trans.advance(self.index); + self.index = 0; + let val = self.trans.split_to(len).into(); + self.buf = unsafe { slice::from_raw_parts_mut(self.trans.as_mut_ptr(), self.trans.len()) }; + Ok(val) + } +} diff --git a/pilota/src/thrift/compact.rs b/pilota/src/thrift/compact.rs index 6e5f76a0..7db1d76a 100644 --- a/pilota/src/thrift/compact.rs +++ b/pilota/src/thrift/compact.rs @@ -654,16 +654,6 @@ impl TOutputProtocol for TCompactOutputProtocol<&mut BytesMut> { Ok(()) } - #[inline] - fn reserve(&mut self, size: usize) { - self.trans.reserve(size) - } - - #[inline] - fn buf_mut(&mut self) -> &mut Self::BufMut { - self.trans - } - #[inline] fn write_bytes_vec(&mut self, b: &[u8]) -> Result<(), EncodeError> { // length is strictly positive as per the spec, so @@ -928,16 +918,6 @@ impl TOutputProtocol for TCompactOutputProtocol<&mut LinkedBytes> { Ok(()) } - #[inline] - fn reserve(&mut self, size: usize) { - self.trans.reserve(size) - } - - #[inline] - fn buf_mut(&mut self) -> &mut Self::BufMut { - self.trans - } - #[inline] fn write_bytes_vec(&mut self, b: &[u8]) -> Result<(), EncodeError> { // length is strictly positive as per the spec, so @@ -1500,10 +1480,6 @@ impl TInputProtocol for TCompactInputProtocol<&mut BytesMut> { Ok(()) } - fn buf_mut(&mut self) -> &mut Self::Buf { - self.trans - } - #[inline] fn read_bytes_vec(&mut self) -> Result, DecodeError> { let size = self.read_varint::()? as usize; diff --git a/pilota/src/thrift/mod.rs b/pilota/src/thrift/mod.rs index 92aa9834..6f0a626b 100644 --- a/pilota/src/thrift/mod.rs +++ b/pilota/src/thrift/mod.rs @@ -1,4 +1,6 @@ pub mod binary; +pub mod binary_le; +pub mod binary_unsafe; pub mod compact; pub mod error; pub mod rw_ext; @@ -44,10 +46,12 @@ pub trait Message: Sized + Send { #[async_trait::async_trait] impl Message for Box { + #[inline] fn encode(&self, protocol: &mut T) -> Result<(), EncodeError> { self.deref().encode(protocol) } + #[inline] fn decode(protocol: &mut T) -> Result { Ok(Box::new(M::decode(protocol)?)) } @@ -56,6 +60,7 @@ impl Message for Box { Ok(Box::new(M::decode_async(protocol).await?)) } + #[inline] fn size(&self, protocol: &mut T) -> usize { self.deref().size(protocol) } @@ -63,10 +68,12 @@ impl Message for Box { #[async_trait::async_trait] impl Message for Arc { + #[inline] fn encode(&self, protocol: &mut T) -> Result<(), EncodeError> { self.deref().encode(protocol) } + #[inline] fn decode(protocol: &mut T) -> Result { Ok(Arc::new(M::decode(protocol)?)) } @@ -75,6 +82,7 @@ impl Message for Arc { Ok(Arc::new(M::decode_async(protocol).await?)) } + #[inline] fn size(&self, protocol: &mut T) -> usize { self.deref().size(protocol) } @@ -128,6 +136,7 @@ pub trait TInputProtocol { fn read_map_end(&mut self) -> Result<(), DecodeError>; /// Skip a field with type `field_type` recursively until the default /// maximum skip depth is reached. + #[inline] fn skip(&mut self, field_type: TType) -> Result<(), DecodeError> { self.skip_till_depth(field_type, MAXIMUM_SKIP_DEPTH) } @@ -201,8 +210,6 @@ pub trait TInputProtocol { /// Read a Vec. fn read_bytes_vec(&mut self) -> Result, DecodeError>; - - fn buf_mut(&mut self) -> &mut Self::Buf; } macro_rules! write_field_len { @@ -632,10 +639,6 @@ pub trait TOutputProtocol { fn write_map_end(&mut self) -> Result<(), EncodeError>; /// Flush buffered bytes to the underlying transport. fn flush(&mut self) -> Result<(), EncodeError>; - - fn reserve(&mut self, size: usize); - - fn buf_mut(&mut self) -> &mut Self::BufMut; } #[async_trait::async_trait] diff --git a/pilota/src/thrift/rw_ext.rs b/pilota/src/thrift/rw_ext.rs index 524ce7e3..d690e18c 100644 --- a/pilota/src/thrift/rw_ext.rs +++ b/pilota/src/thrift/rw_ext.rs @@ -129,6 +129,8 @@ pub trait WriteExt { fn write_f32_le(&mut self, n: f32) -> Result<(), IOError>; fn write_f64(&mut self, n: f64) -> Result<(), IOError>; + + fn write_f64_le(&mut self, n: f64) -> Result<(), IOError>; } impl WriteExt for BytesMut { @@ -264,6 +266,11 @@ impl WriteExt for BytesMut { fn write_f64(&mut self, n: f64) -> Result<(), IOError> { self.write_u64(n.to_bits()) } + + #[inline] + fn write_f64_le(&mut self, n: f64) -> Result<(), IOError> { + self.write_u64_le(n.to_bits()) + } } pub trait ReadExt {