From c471ce88d3e4c931d2d231d7e76ec8ea12d63bfb Mon Sep 17 00:00:00 2001 From: Mark Juggurnauth-Thomas Date: Tue, 17 Jan 2023 08:20:49 +0000 Subject: [PATCH] Add EncodedSize trait to calculate encoded sizes For some storage back-ends it's necessary to know the size of the buffer to pre-allocate before encoding. The bincode 1.0 version had `bincode::serialized_size`, which would tell you the serialized size of data. We add a new `EncodedSize` trait which accomplishes the same objective, and implement that for all types. Implementations for custom structs and enums can be derived using `#[derive(bincode::EncodedSize)]`. --- derive/src/attribute.rs | 11 + derive/src/derive_enum.rs | 116 ++++++++ derive/src/derive_struct.rs | 53 ++++ derive/src/lib.rs | 33 +++ docs/migration_guide.md | 4 +- fuzz/Cargo.toml | 6 + fuzz/fuzz_targets/encoded_size.rs | 52 ++++ src/atomic.rs | 79 +++++- src/features/derive.rs | 2 +- src/features/impl_alloc.rs | 115 +++++++- src/features/impl_std.rs | 134 ++++++++++ src/features/serde/mod.rs | 14 + src/features/serde/size.rs | 425 ++++++++++++++++++++++++++++++ src/lib.rs | 10 + src/size/impl_tuples.rs | 389 +++++++++++++++++++++++++++ src/size/impls.rs | 384 +++++++++++++++++++++++++++ src/size/mod.rs | 58 ++++ src/varint/mod.rs | 10 + src/varint/size_signed.rs | 228 ++++++++++++++++ src/varint/size_unsigned.rs | 174 ++++++++++++ tests/derive.rs | 55 +++- tests/serde.rs | 42 +-- tests/utils.rs | 25 +- 23 files changed, 2385 insertions(+), 34 deletions(-) create mode 100644 fuzz/fuzz_targets/encoded_size.rs create mode 100644 src/features/serde/size.rs create mode 100644 src/size/impl_tuples.rs create mode 100644 src/size/impls.rs create mode 100644 src/size/mod.rs create mode 100644 src/varint/size_signed.rs create mode 100644 src/varint/size_unsigned.rs diff --git a/derive/src/attribute.rs b/derive/src/attribute.rs index 30f817b9..00df4e09 100644 --- a/derive/src/attribute.rs +++ b/derive/src/attribute.rs @@ -7,6 +7,7 @@ pub struct 
ContainerAttributes { pub decode_bounds: Option<(String, Literal)>, pub borrow_decode_bounds: Option<(String, Literal)>, pub encode_bounds: Option<(String, Literal)>, + pub encoded_size_bounds: Option<(String, Literal)>, } impl Default for ContainerAttributes { @@ -17,6 +18,7 @@ impl Default for ContainerAttributes { decode_bounds: None, encode_bounds: None, borrow_decode_bounds: None, + encoded_size_bounds: None, } } } @@ -76,6 +78,15 @@ impl FromAttribute for ContainerAttributes { return Err(Error::custom_at("Should be a literal str", val.span())); } } + ParsedAttribute::Property(key, val) if key.to_string() == "encoded_size_bounds" => { + let val_string = val.to_string(); + if val_string.starts_with('"') && val_string.ends_with('"') { + result.encoded_size_bounds = + Some((val_string[1..val_string.len() - 1].to_string(), val)); + } else { + return Err(Error::custom_at("Should be a literal str", val.span())); + } + } ParsedAttribute::Tag(i) => { return Err(Error::custom_at("Unknown field attribute", i.span())) } diff --git a/derive/src/derive_enum.rs b/derive/src/derive_enum.rs index 614453a9..c6a15d79 100644 --- a/derive/src/derive_enum.rs +++ b/derive/src/derive_enum.rs @@ -135,6 +135,122 @@ impl DeriveEnum { Ok(()) } + pub fn generate_encoded_size(self, generator: &mut Generator) -> Result<()> { + let crate_name = self.attributes.crate_name.as_str(); + generator + .impl_for(format!("{}::EncodedSize", crate_name)) + .modify_generic_constraints(|generics, where_constraints| { + if let Some((bounds, lit)) = + (self.attributes.encoded_size_bounds.as_ref()).or(self.attributes.bounds.as_ref()) + { + where_constraints.clear(); + where_constraints + .push_parsed_constraint(bounds) + .map_err(|e| e.with_span(lit.span()))?; + } else { + for g in generics.iter_generics() { + where_constraints + .push_constraint(g, format!("{}::EncodedSize", crate_name)) + .unwrap(); + } + } + Ok(()) + })? 
+ .generate_fn("encoded_size") + .with_generic_deps("__C", [format!("{}::config::Config", crate_name)]) + .with_self_arg(FnSelfArg::RefSelf) + .with_return_type(format!( + "core::result::Result", + crate_name + )) + .body(|fn_body| { + fn_body.ident_str("match"); + fn_body.ident_str("self"); + fn_body.group(Delimiter::Brace, |match_body| { + if self.variants.is_empty() { + self.encode_empty_enum_case(match_body)?; + } + for (variant_index, variant) in self.iter_fields() { + // Self::Variant + match_body.ident_str("Self"); + match_body.puncts("::"); + match_body.ident(variant.name.clone()); + + // if we have any fields, declare them here + // Self::Variant { a, b, c } + if let Some(delimiter) = variant.fields.delimiter() { + match_body.group(delimiter, |field_body| { + for (idx, field_name) in + variant.fields.names().into_iter().enumerate() + { + if idx != 0 { + field_body.punct(','); + } + field_body.push( + field_name.to_token_tree_with_prefix(TUPLE_FIELD_PREFIX), + ); + } + Ok(()) + })?; + } + + // Arrow + // Self::Variant { a, b, c } => + match_body.puncts("=>"); + + // Body of this variant + // Note that the fields are available as locals because of the match destructuring above + // { + // encoder.encode_u32(n)?; + // bincode::Encode::encode(a, encoder)?; + // bincode::Encode::encode(b, encoder)?; + // bincode::Encode::encode(c, encoder)?; + // } + match_body.group(Delimiter::Brace, |body| { + // variant index + body.push_parsed(format!("let mut __encoded_size = ::encoded_size::<__C>", crate_name))?; + body.group(Delimiter::Parenthesis, |args| { + args.punct('&'); + args.group(Delimiter::Parenthesis, |num| { + num.extend(variant_index); + Ok(()) + })?; + Ok(()) + })?; + body.punct('?'); + body.punct(';'); + // If we have any fields, add up all their sizes them all one by one + for field_name in variant.fields.names() { + let attributes = field_name + .attributes() + .get_attribute::()? 
+ .unwrap_or_default(); + if attributes.with_serde { + body.push_parsed(format!( + "__encoded_size += {0}::EncodedSize::encoded_size::<__C>(&{0}::serde::Compat({1}))?;", + crate_name, + field_name.to_string_with_prefix(TUPLE_FIELD_PREFIX), + ))?; + } else { + body.push_parsed(format!( + "__encoded_size += {0}::EncodedSize::encoded_size::<__C>({1})?;", + crate_name, + field_name.to_string_with_prefix(TUPLE_FIELD_PREFIX), + ))?; + } + } + body.push_parsed("Ok(__encoded_size)")?; + Ok(()) + })?; + match_body.punct(','); + } + Ok(()) + })?; + Ok(()) + })?; + Ok(()) + } + /// If we're encoding an empty enum, we need to add an empty case in the form of: /// `_ => core::unreachable!(),` fn encode_empty_enum_case(&self, builder: &mut StreamBuilder) -> Result { diff --git a/derive/src/derive_struct.rs b/derive/src/derive_struct.rs index 1b50cf8e..61c3fdb3 100644 --- a/derive/src/derive_struct.rs +++ b/derive/src/derive_struct.rs @@ -62,6 +62,59 @@ impl DeriveStruct { Ok(()) } + pub fn generate_encoded_size(self, generator: &mut Generator) -> Result<()> { + let crate_name = &self.attributes.crate_name; + generator + .impl_for(&format!("{}::EncodedSize", crate_name)) + .modify_generic_constraints(|generics, where_constraints| { + if let Some((bounds, lit)) = + (self.attributes.encoded_size_bounds.as_ref()).or(self.attributes.bounds.as_ref()) + { + where_constraints.clear(); + where_constraints + .push_parsed_constraint(bounds) + .map_err(|e| e.with_span(lit.span()))?; + } else { + for g in generics.iter_generics() { + where_constraints + .push_constraint(g, format!("{}::EncodedSize", crate_name)) + .unwrap(); + } + } + Ok(()) + })? 
+ .generate_fn("encoded_size") + .with_generic_deps("__C", [format!("{}::config::Config", crate_name)]) + .with_self_arg(virtue::generate::FnSelfArg::RefSelf) + .with_return_type(format!( + "core::result::Result", + crate_name + )) + .body(|fn_body| { + fn_body.push_parsed("let mut __encoded_size = 0;")?; + for field in self.fields.names() { + let attributes = field + .attributes() + .get_attribute::()? + .unwrap_or_default(); + if attributes.with_serde { + fn_body.push_parsed(format!( + "__encoded_size += {0}::EncodedSize::encoded_size::<__C>(&{0}::serde::Compat(&self.{1}))?;", + crate_name, field + ))?; + } else { + fn_body.push_parsed(format!( + "__encoded_size += {}::EncodedSize::encoded_size::<__C>(&self.{})?;", + crate_name, field + ))?; + } + } + fn_body.push_parsed("Ok(__encoded_size)")?; + Ok(()) + })?; + Ok(()) + } + pub fn generate_decode(self, generator: &mut Generator) -> Result<()> { // Remember to keep this mostly in sync with generate_borrow_decode let crate_name = &self.attributes.crate_name; diff --git a/derive/src/lib.rs b/derive/src/lib.rs index 166bb2fd..c4976f7f 100644 --- a/derive/src/lib.rs +++ b/derive/src/lib.rs @@ -38,6 +38,39 @@ fn derive_encode_inner(input: TokenStream) -> Result { generator.finish() } +#[proc_macro_derive(EncodedSize, attributes(bincode))] +pub fn derive_encoded_size(input: proc_macro::TokenStream) -> proc_macro::TokenStream { + derive_encoded_size_inner(input).unwrap_or_else(|e| e.into_token_stream()) +} + +fn derive_encoded_size_inner(input: TokenStream) -> Result { + let parse = Parse::new(input)?; + let (mut generator, attributes, body) = parse.into_generator(); + let attributes = attributes + .get_attribute::()? 
+ .unwrap_or_default(); + + match body { + Body::Struct(body) => { + derive_struct::DeriveStruct { + fields: body.fields, + attributes, + } + .generate_encoded_size(&mut generator)?; + } + Body::Enum(body) => { + derive_enum::DeriveEnum { + variants: body.variants, + attributes, + } + .generate_encoded_size(&mut generator)?; + } + } + + generator.export_to_file("bincode", "EncodedSize"); + generator.finish() +} + #[proc_macro_derive(Decode, attributes(bincode))] pub fn derive_decode(input: proc_macro::TokenStream) -> proc_macro::TokenStream { derive_decode_inner(input).unwrap_or_else(|e| e.into_token_stream()) diff --git a/docs/migration_guide.md b/docs/migration_guide.md index 04c49c72..4b62dec5 100644 --- a/docs/migration_guide.md +++ b/docs/migration_guide.md @@ -60,7 +60,7 @@ Then replace the following functions: (`Configuration` is `bincode::config::lega ||| |`bincode::serialize(T)`|`bincode::serde::encode_to_vec(T, Configuration)`
`bincode::serde::encode_into_slice(T, &mut [u8], Configuration)`| |`bincode::serialize_into(std::io::Write, T)`|`bincode::serde::encode_into_std_write(T, std::io::Write, Configuration)`| -|`bincode::serialized_size(T)`|Currently not implemented| +|`bincode::serialized_size(T)`|`bincode::serde::encoded_size(T, Configuration)`| ## Migrating to `bincode-derive` @@ -98,7 +98,7 @@ Then replace the following functions: (`Configuration` is `bincode::config::lega ||| |`bincode::serialize(T)`|`bincode::encode_to_vec(T, Configuration)`
`bincode::encode_into_slice(t: T, &mut [u8], Configuration)`| |`bincode::serialize_into(std::io::Write, T)`|`bincode::encode_into_std_write(T, std::io::Write, Configuration)`| -|`bincode::serialized_size(T)`|Currently not implemented| +|`bincode::serialized_size(T)`|`bincode::encoded_size(T, Configuration)`| ### Bincode derive and libraries diff --git a/fuzz/Cargo.toml b/fuzz/Cargo.toml index 44ecfac7..2c34502b 100644 --- a/fuzz/Cargo.toml +++ b/fuzz/Cargo.toml @@ -31,3 +31,9 @@ name = "compat" path = "fuzz_targets/compat.rs" test = false doc = false + +[[bin]] +name = "encoded_size" +path = "fuzz_targets/encoded_size.rs" +test = false +doc = false diff --git a/fuzz/fuzz_targets/encoded_size.rs b/fuzz/fuzz_targets/encoded_size.rs new file mode 100644 index 00000000..5c7707ef --- /dev/null +++ b/fuzz/fuzz_targets/encoded_size.rs @@ -0,0 +1,52 @@ +#![no_main] +use libfuzzer_sys::fuzz_target; + +use std::collections::{BTreeMap, BTreeSet, HashMap, HashSet, VecDeque}; +use std::ffi::CString; +use std::net::{IpAddr, Ipv4Addr, Ipv6Addr, SocketAddr, SocketAddrV4, SocketAddrV6}; +use std::num::{NonZeroI128, NonZeroI32, NonZeroU128, NonZeroU32}; +use std::path::PathBuf; +use std::rc::Rc; +use std::sync::Arc; +use std::time::{Duration, SystemTime}; + +#[derive(bincode::Decode, bincode::Encode, bincode::EncodedSize, PartialEq, Debug)] +enum AllTypes { + BTreeMap(BTreeMap), + HashMap(HashMap), + HashSet(HashSet), + BTreeSet(BTreeSet), + VecDeque(VecDeque), + Vec(Vec), + String(String), + Box(Box), + BoxSlice(Box<[AllTypes]>), + Rc(Rc), + Arc(Arc), + CString(CString), + SystemTime(SystemTime), + Duration(Duration), + PathBuf(PathBuf), + IpAddr(IpAddr), + Ipv4Addr(Ipv4Addr), + Ipv6Addr(Ipv6Addr), + SocketAddr(SocketAddr), + SocketAddrV4(SocketAddrV4), + SocketAddrV6(SocketAddrV6), + NonZeroU32(NonZeroU32), + NonZeroI32(NonZeroI32), + NonZeroU128(NonZeroU128), + NonZeroI128(NonZeroI128), + // Cow(Cow<'static, [u8]>), Blocked, see comment on decode +} + +fuzz_target!(|data: &[u8]| 
{ + let config = bincode::config::standard().with_limit::<1024>(); + let result: Result<(AllTypes, _), _> = bincode::decode_from_slice(data, config); + + if let Ok((value, _)) = result { + let encoded_size = bincode::encoded_size(&value, config).expect("encoded size"); + let encoded = bincode::encode_to_vec(&value, config).expect("round trip"); + assert_eq!(encoded_size, encoded.len()); + } +}); diff --git a/src/atomic.rs b/src/atomic.rs index 1b702520..7c75933c 100644 --- a/src/atomic.rs +++ b/src/atomic.rs @@ -1,4 +1,4 @@ -use crate::{de::Decode, enc::Encode, impl_borrow_decode}; +use crate::{de::Decode, enc::Encode, impl_borrow_decode, size::EncodedSize}; use core::sync::atomic::Ordering; #[cfg(target_has_atomic = "ptr")] @@ -26,6 +26,13 @@ impl Encode for AtomicBool { } } +#[cfg(target_has_atomic = "8")] +impl EncodedSize for AtomicBool { + fn encoded_size(&self) -> Result { + self.load(Ordering::SeqCst).encoded_size::() + } +} + #[cfg(target_has_atomic = "8")] impl Decode for AtomicBool { fn decode(decoder: &mut D) -> Result { @@ -45,6 +52,13 @@ impl Encode for AtomicU8 { } } +#[cfg(target_has_atomic = "8")] +impl EncodedSize for AtomicU8 { + fn encoded_size(&self) -> Result { + self.load(Ordering::SeqCst).encoded_size::() + } +} + #[cfg(target_has_atomic = "8")] impl Decode for AtomicU8 { fn decode(decoder: &mut D) -> Result { @@ -64,6 +78,13 @@ impl Encode for AtomicU16 { } } +#[cfg(target_has_atomic = "16")] +impl EncodedSize for AtomicU16 { + fn encoded_size(&self) -> Result { + self.load(Ordering::SeqCst).encoded_size::() + } +} + #[cfg(target_has_atomic = "16")] impl Decode for AtomicU16 { fn decode(decoder: &mut D) -> Result { @@ -83,6 +104,13 @@ impl Encode for AtomicU32 { } } +#[cfg(target_has_atomic = "32")] +impl EncodedSize for AtomicU32 { + fn encoded_size(&self) -> Result { + self.load(Ordering::SeqCst).encoded_size::() + } +} + #[cfg(target_has_atomic = "32")] impl Decode for AtomicU32 { fn decode(decoder: &mut D) -> Result { @@ -102,6 +130,13 
@@ impl Encode for AtomicU64 { } } +#[cfg(target_has_atomic = "64")] +impl EncodedSize for AtomicU64 { + fn encoded_size(&self) -> Result { + self.load(Ordering::SeqCst).encoded_size::() + } +} + #[cfg(target_has_atomic = "64")] impl Decode for AtomicU64 { fn decode(decoder: &mut D) -> Result { @@ -121,6 +156,13 @@ impl Encode for AtomicUsize { } } +#[cfg(target_has_atomic = "ptr")] +impl EncodedSize for AtomicUsize { + fn encoded_size(&self) -> Result { + self.load(Ordering::SeqCst).encoded_size::() + } +} + #[cfg(target_has_atomic = "ptr")] impl Decode for AtomicUsize { fn decode(decoder: &mut D) -> Result { @@ -140,6 +182,13 @@ impl Encode for AtomicI8 { } } +#[cfg(target_has_atomic = "8")] +impl EncodedSize for AtomicI8 { + fn encoded_size(&self) -> Result { + self.load(Ordering::SeqCst).encoded_size::() + } +} + #[cfg(target_has_atomic = "8")] impl Decode for AtomicI8 { fn decode(decoder: &mut D) -> Result { @@ -159,6 +208,13 @@ impl Encode for AtomicI16 { } } +#[cfg(target_has_atomic = "16")] +impl EncodedSize for AtomicI16 { + fn encoded_size(&self) -> Result { + self.load(Ordering::SeqCst).encoded_size::() + } +} + #[cfg(target_has_atomic = "16")] impl Decode for AtomicI16 { fn decode(decoder: &mut D) -> Result { @@ -178,6 +234,13 @@ impl Encode for AtomicI32 { } } +#[cfg(target_has_atomic = "32")] +impl EncodedSize for AtomicI32 { + fn encoded_size(&self) -> Result { + self.load(Ordering::SeqCst).encoded_size::() + } +} + #[cfg(target_has_atomic = "32")] impl Decode for AtomicI32 { fn decode(decoder: &mut D) -> Result { @@ -197,6 +260,13 @@ impl Encode for AtomicI64 { } } +#[cfg(target_has_atomic = "64")] +impl EncodedSize for AtomicI64 { + fn encoded_size(&self) -> Result { + self.load(Ordering::SeqCst).encoded_size::() + } +} + #[cfg(target_has_atomic = "64")] impl Decode for AtomicI64 { fn decode(decoder: &mut D) -> Result { @@ -216,6 +286,13 @@ impl Encode for AtomicIsize { } } +#[cfg(target_has_atomic = "ptr")] +impl EncodedSize for AtomicIsize { + fn 
encoded_size(&self) -> Result { + self.load(Ordering::SeqCst).encoded_size::() + } +} + #[cfg(target_has_atomic = "ptr")] impl Decode for AtomicIsize { fn decode(decoder: &mut D) -> Result { diff --git a/src/features/derive.rs b/src/features/derive.rs index 1d07ba11..cab18b20 100644 --- a/src/features/derive.rs +++ b/src/features/derive.rs @@ -1,2 +1,2 @@ #[cfg_attr(docsrs, doc(cfg(feature = "derive")))] -pub use bincode_derive::{BorrowDecode, Decode, Encode}; +pub use bincode_derive::{BorrowDecode, Decode, Encode, EncodedSize}; diff --git a/src/features/impl_alloc.rs b/src/features/impl_alloc.rs index b705de96..aedea223 100644 --- a/src/features/impl_alloc.rs +++ b/src/features/impl_alloc.rs @@ -2,7 +2,9 @@ use crate::{ de::{BorrowDecoder, Decode, Decoder}, enc::{self, Encode, Encoder}, error::{DecodeError, EncodeError}, - impl_borrow_decode, BorrowDecode, Config, + impl_borrow_decode, + size::EncodedSize, + BorrowDecode, Config, }; #[cfg(target_has_atomic = "ptr")] use alloc::sync::Arc; @@ -98,6 +100,19 @@ where } } +impl EncodedSize for BinaryHeap +where + T: EncodedSize + Ord, +{ + fn encoded_size(&self) -> Result { + let mut size = crate::size::size_slice_len::(self.len())?; + for val in self.iter() { + size += val.encoded_size::()?; + } + Ok(size) + } +} + impl Decode for BTreeMap where K: Decode + Ord, @@ -156,6 +171,21 @@ where } } +impl EncodedSize for BTreeMap +where + K: EncodedSize + Ord, + V: EncodedSize, +{ + fn encoded_size(&self) -> Result { + let mut size = crate::size::size_slice_len::(self.len())?; + for (key, val) in self.iter() { + size += key.encoded_size::()?; + size += val.encoded_size::()?; + } + Ok(size) + } +} + impl Decode for BTreeSet where T: Decode + Ord, @@ -208,6 +238,19 @@ where } } +impl EncodedSize for BTreeSet +where + T: EncodedSize + Ord, +{ + fn encoded_size(&self) -> Result { + let mut size = crate::size::size_slice_len::(self.len())?; + for item in self.iter() { + size += item.encoded_size::()?; + } + Ok(size) + } +} + impl 
Decode for VecDeque where T: Decode, @@ -260,6 +303,19 @@ where } } +impl EncodedSize for VecDeque +where + T: EncodedSize, +{ + fn encoded_size(&self) -> Result { + let mut size = crate::size::size_slice_len::(self.len())?; + for item in self.iter() { + size += item.encoded_size::()?; + } + Ok(size) + } +} + impl Decode for Vec where T: Decode, @@ -311,6 +367,19 @@ where } } +impl EncodedSize for Vec +where + T: EncodedSize, +{ + fn encoded_size(&self) -> Result { + let mut size = crate::size::size_slice_len::(self.len())?; + for item in self.iter() { + size += item.encoded_size::()?; + } + Ok(size) + } +} + impl Decode for String { fn decode(decoder: &mut D) -> Result { let bytes = Vec::::decode(decoder)?; @@ -334,6 +403,12 @@ impl Encode for String { } } +impl EncodedSize for String { + fn encoded_size(&self) -> Result { + self.as_bytes().encoded_size::() + } +} + impl Decode for Box where T: Decode, @@ -362,6 +437,15 @@ where } } +impl EncodedSize for Box +where + T: EncodedSize + ?Sized, +{ + fn encoded_size(&self) -> Result { + T::encoded_size::(self) + } +} + impl Decode for Box<[T]> where T: Decode, @@ -413,6 +497,16 @@ where } } +impl<'cow, T> EncodedSize for Cow<'cow, T> +where + T: ToOwned + ?Sized, + for<'a> &'a T: EncodedSize, +{ + fn encoded_size(&self) -> Result { + self.as_ref().encoded_size::() + } +} + impl Decode for Rc where T: Decode, @@ -442,6 +536,15 @@ where } } +impl EncodedSize for Rc +where + T: EncodedSize + ?Sized, +{ + fn encoded_size(&self) -> Result { + T::encoded_size::(self) + } +} + impl Decode for Rc<[T]> where T: Decode, @@ -510,6 +613,16 @@ where } } +#[cfg(target_has_atomic = "ptr")] +impl EncodedSize for Arc +where + T: EncodedSize + ?Sized, +{ + fn encoded_size(&self) -> Result { + T::encoded_size::(self) + } +} + #[cfg(target_has_atomic = "ptr")] impl Decode for Arc<[T]> where diff --git a/src/features/impl_std.rs b/src/features/impl_std.rs index 96721053..b8e404a0 100644 --- a/src/features/impl_std.rs +++ 
b/src/features/impl_std.rs @@ -4,6 +4,7 @@ use crate::{ enc::{write::Writer, Encode, Encoder, EncoderImpl}, error::{DecodeError, EncodeError}, impl_borrow_decode, + size::EncodedSize, }; use core::time::Duration; use std::{ @@ -134,12 +135,24 @@ impl<'a> Encode for &'a CStr { } } +impl<'a> EncodedSize for &'a CStr { + fn encoded_size(&self) -> Result { + self.to_bytes().encoded_size::() + } +} + impl Encode for CString { fn encode(&self, encoder: &mut E) -> Result<(), EncodeError> { self.as_bytes().encode(encoder) } } +impl EncodedSize for CString { + fn encoded_size(&self) -> Result { + self.as_bytes().encoded_size::() + } +} + impl Decode for CString { fn decode(decoder: &mut D) -> Result { let vec = std::vec::Vec::decode(decoder)?; @@ -162,6 +175,18 @@ where } } +impl EncodedSize for Mutex +where + T: EncodedSize, +{ + fn encoded_size(&self) -> Result { + let t = self.lock().map_err(|_| EncodeError::LockFailed { + type_name: core::any::type_name::>(), + })?; + t.encoded_size::() + } +} + impl Decode for Mutex where T: Decode, @@ -193,6 +218,18 @@ where } } +impl EncodedSize for RwLock +where + T: EncodedSize, +{ + fn encoded_size(&self) -> Result { + let t = self.read().map_err(|_| EncodeError::LockFailed { + type_name: core::any::type_name::>(), + })?; + t.encoded_size::() + } +} + impl Decode for RwLock where T: Decode, @@ -224,6 +261,18 @@ impl Encode for SystemTime { } } +impl EncodedSize for SystemTime { + fn encoded_size(&self) -> Result { + let duration = self.duration_since(SystemTime::UNIX_EPOCH).map_err(|e| { + EncodeError::InvalidSystemTime { + inner: e, + time: std::boxed::Box::new(*self), + } + })?; + duration.encoded_size::() + } +} + impl Decode for SystemTime { fn decode(decoder: &mut D) -> Result { let duration = Duration::decode(decoder)?; @@ -244,6 +293,15 @@ impl Encode for &'_ Path { } } +impl EncodedSize for &'_ Path { + fn encoded_size(&self) -> Result { + match self.to_str() { + Some(s) => s.encoded_size::(), + None => 
Err(EncodeError::InvalidPathCharacters), + } + } +} + impl<'de> BorrowDecode<'de> for &'de Path { fn borrow_decode>(decoder: &mut D) -> Result { let str = <&'de str>::borrow_decode(decoder)?; @@ -257,6 +315,12 @@ impl Encode for PathBuf { } } +impl EncodedSize for PathBuf { + fn encoded_size(&self) -> Result { + self.as_path().encoded_size::() + } +} + impl Decode for PathBuf { fn decode(decoder: &mut D) -> Result { let string = std::string::String::decode(decoder)?; @@ -280,6 +344,15 @@ impl Encode for IpAddr { } } +impl EncodedSize for IpAddr { + fn encoded_size(&self) -> Result { + match self { + IpAddr::V4(v4) => Ok(0u32.encoded_size::()? + v4.encoded_size::()?), + IpAddr::V6(v6) => Ok(1u32.encoded_size::()? + v6.encoded_size::()?), + } + } +} + impl Decode for IpAddr { fn decode(decoder: &mut D) -> Result { match u32::decode(decoder)? { @@ -301,6 +374,12 @@ impl Encode for Ipv4Addr { } } +impl EncodedSize for Ipv4Addr { + fn encoded_size(&self) -> Result { + Ok(self.octets().len()) + } +} + impl Decode for Ipv4Addr { fn decode(decoder: &mut D) -> Result { let mut buff = [0u8; 4]; @@ -316,6 +395,12 @@ impl Encode for Ipv6Addr { } } +impl EncodedSize for Ipv6Addr { + fn encoded_size(&self) -> Result { + Ok(self.octets().len()) + } +} + impl Decode for Ipv6Addr { fn decode(decoder: &mut D) -> Result { let mut buff = [0u8; 16]; @@ -340,6 +425,15 @@ impl Encode for SocketAddr { } } +impl EncodedSize for SocketAddr { + fn encoded_size(&self) -> Result { + match self { + SocketAddr::V4(v4) => Ok(0u32.encoded_size::()? + v4.encoded_size::()?), + SocketAddr::V6(v6) => Ok(1u32.encoded_size::()? + v6.encoded_size::()?), + } + } +} + impl Decode for SocketAddr { fn decode(decoder: &mut D) -> Result { match u32::decode(decoder)? { @@ -362,6 +456,12 @@ impl Encode for SocketAddrV4 { } } +impl EncodedSize for SocketAddrV4 { + fn encoded_size(&self) -> Result { + Ok(self.ip().encoded_size::()? + self.port().encoded_size::()?) 
+ } +} + impl Decode for SocketAddrV4 { fn decode(decoder: &mut D) -> Result { let ip = Ipv4Addr::decode(decoder)?; @@ -378,6 +478,12 @@ impl Encode for SocketAddrV6 { } } +impl EncodedSize for SocketAddrV6 { + fn encoded_size(&self) -> Result { + Ok(self.ip().encoded_size::()? + self.port().encoded_size::()?) + } +} + impl Decode for SocketAddrV6 { fn decode(decoder: &mut D) -> Result { let ip = Ipv6Addr::decode(decoder)?; @@ -421,6 +527,21 @@ where } } +impl EncodedSize for HashMap +where + K: EncodedSize, + V: EncodedSize, +{ + fn encoded_size(&self) -> Result { + let mut size = crate::size::size_slice_len::(self.len())?; + for (k, v) in self.iter() { + size += k.encoded_size::()?; + size += v.encoded_size::()?; + } + Ok(size) + } +} + impl Decode for HashMap where K: Decode + Eq + std::hash::Hash, @@ -523,3 +644,16 @@ where Ok(()) } } + +impl EncodedSize for HashSet +where + T: EncodedSize, +{ + fn encoded_size(&self) -> Result { + let mut size = crate::size::size_slice_len::(self.len())?; + for item in self.iter() { + size += item.encoded_size::()?; + } + Ok(size) + } +} diff --git a/src/features/serde/mod.rs b/src/features/serde/mod.rs index 04a50699..3ef3e531 100644 --- a/src/features/serde/mod.rs +++ b/src/features/serde/mod.rs @@ -58,10 +58,12 @@ mod de_borrowed; mod de_owned; mod ser; +mod size; pub use self::de_borrowed::*; pub use self::de_owned::*; pub use self::ser::*; +pub use self::size::*; /// A serde-specific error that occurred while decoding. #[derive(Debug)] @@ -222,6 +224,18 @@ where } } +impl crate::EncodedSize for Compat +where + T: serde::Serialize, +{ + fn encoded_size(&self) -> Result { + let mut encoded_size = 0; + let serializer = size::SerdeEncodedSize::<'_, C>::new(&mut encoded_size); + self.0.serialize(serializer)?; + Ok(encoded_size) + } +} + /// Wrapper struct that implements [BorrowDecode] and [Encode] on any type that implements serde's [Deserialize] and [Serialize] respectively. 
This is mostly used on `&[u8]` and `&str`, for other types consider using [Compat] instead. /// /// [BorrowDecode]: ../de/trait.BorrowDecode.html diff --git a/src/features/serde/size.rs b/src/features/serde/size.rs new file mode 100644 index 00000000..d84ab597 --- /dev/null +++ b/src/features/serde/size.rs @@ -0,0 +1,425 @@ +use super::EncodeError as SerdeEncodeError; +use crate::{config::Config, error::EncodeError, size::EncodedSize}; +use core::marker::PhantomData; +use serde::ser::*; + +/// Calculate the encoded size for the given value. +pub fn encoded_size(t: T, _config: C) -> Result +where + T: Serialize, + C: Config, +{ + if C::SKIP_FIXED_ARRAY_LENGTH { + return Err(SerdeEncodeError::SkipFixedArrayLengthNotSupported.into()); + } + let mut encoded_size: usize = 0; + let serializer = SerdeEncodedSize::<'_, C>::new(&mut encoded_size); + t.serialize(serializer)?; + Ok(encoded_size) +} + +pub(super) struct SerdeEncodedSize<'a, C: Config> { + encoded_size: &'a mut usize, + config: PhantomData, +} + +impl<'a, C> SerdeEncodedSize<'a, C> +where + C: Config, +{ + pub(super) fn new(encoded_size: &'a mut usize) -> Self { + SerdeEncodedSize { + encoded_size, + config: PhantomData, + } + } +} + +impl<'a, C> Serializer for SerdeEncodedSize<'a, C> +where + C: Config, +{ + type Ok = (); + + type Error = EncodeError; + + type SerializeSeq = Self; + type SerializeTuple = Self; + type SerializeTupleStruct = Self; + type SerializeTupleVariant = Self; + type SerializeMap = Self; + type SerializeStruct = Self; + type SerializeStructVariant = Self; + + fn serialize_bool(self, v: bool) -> Result { + *self.encoded_size += v.encoded_size::()?; + Ok(()) + } + + fn serialize_i8(self, v: i8) -> Result { + *self.encoded_size += v.encoded_size::()?; + Ok(()) + } + + fn serialize_i16(self, v: i16) -> Result { + *self.encoded_size += v.encoded_size::()?; + Ok(()) + } + + fn serialize_i32(self, v: i32) -> Result { + *self.encoded_size += v.encoded_size::()?; + Ok(()) + } + + fn 
serialize_i64(self, v: i64) -> Result { + *self.encoded_size += v.encoded_size::()?; + Ok(()) + } + + serde::serde_if_integer128! { + fn serialize_i128(self, v: i128) -> Result { + *self.encoded_size += v.encoded_size::()?; Ok(()) + } + } + + fn serialize_u8(self, v: u8) -> Result { + *self.encoded_size += v.encoded_size::()?; + Ok(()) + } + + fn serialize_u16(self, v: u16) -> Result { + *self.encoded_size += v.encoded_size::()?; + Ok(()) + } + + fn serialize_u32(self, v: u32) -> Result { + *self.encoded_size += v.encoded_size::()?; + Ok(()) + } + + fn serialize_u64(self, v: u64) -> Result { + *self.encoded_size += v.encoded_size::()?; + Ok(()) + } + + serde::serde_if_integer128! { + fn serialize_u128(self, v: u128) -> Result { + *self.encoded_size += v.encoded_size::()?; Ok(()) + } + } + + fn serialize_f32(self, v: f32) -> Result { + *self.encoded_size += v.encoded_size::()?; + Ok(()) + } + + fn serialize_f64(self, v: f64) -> Result { + *self.encoded_size += v.encoded_size::()?; + Ok(()) + } + + fn serialize_char(self, v: char) -> Result { + *self.encoded_size += v.encoded_size::()?; + Ok(()) + } + + fn serialize_str(self, v: &str) -> Result { + *self.encoded_size += v.encoded_size::()?; + Ok(()) + } + + fn serialize_bytes(self, v: &[u8]) -> Result { + *self.encoded_size += v.encoded_size::()?; + Ok(()) + } + + fn serialize_none(self) -> Result { + *self.encoded_size += 0u8.encoded_size::()?; + Ok(()) + } + + fn serialize_some(self, value: &T) -> Result + where + T: Serialize, + { + *self.encoded_size += 1u8.encoded_size::()?; + value.serialize(self) + } + + fn serialize_unit(self) -> Result { + Ok(()) + } + + fn serialize_unit_struct(self, _name: &'static str) -> Result { + Ok(()) + } + + fn serialize_unit_variant( + self, + _name: &'static str, + variant_index: u32, + _variant: &'static str, + ) -> Result { + *self.encoded_size += variant_index.encoded_size::()?; + Ok(()) + } + + fn serialize_newtype_struct( + self, + _name: &'static str, + value: &T, + ) -> 
Result + where + T: Serialize, + { + value.serialize(self) + } + + fn serialize_newtype_variant( + self, + _name: &'static str, + variant_index: u32, + _variant: &'static str, + value: &T, + ) -> Result + where + T: Serialize, + { + *self.encoded_size += variant_index.encoded_size::()?; + value.serialize(self) + } + + fn serialize_seq(self, len: Option) -> Result { + let len = len.ok_or_else(|| SerdeEncodeError::SequenceMustHaveLength.into())?; + *self.encoded_size += len.encoded_size::()?; + Ok(Compound { + encoded_size: self.encoded_size, + config: self.config, + }) + } + + fn serialize_tuple(self, _: usize) -> Result { + Ok(self) + } + + fn serialize_tuple_struct( + self, + _name: &'static str, + _len: usize, + ) -> Result { + Ok(self) + } + + fn serialize_tuple_variant( + self, + _name: &'static str, + variant_index: u32, + _variant: &'static str, + _len: usize, + ) -> Result { + *self.encoded_size += variant_index.encoded_size::()?; + Ok(Compound { + encoded_size: self.encoded_size, + config: self.config, + }) + } + + fn serialize_map(self, len: Option) -> Result { + let len = len.ok_or_else(|| SerdeEncodeError::SequenceMustHaveLength.into())?; + *self.encoded_size += len.encoded_size::()?; + Ok(Compound { + encoded_size: self.encoded_size, + config: self.config, + }) + } + + fn serialize_struct( + self, + _name: &'static str, + _len: usize, + ) -> Result { + Ok(Compound { + encoded_size: self.encoded_size, + config: self.config, + }) + } + + fn serialize_struct_variant( + self, + _name: &'static str, + variant_index: u32, + _variant: &'static str, + _len: usize, + ) -> Result { + *self.encoded_size += variant_index.encoded_size::()?; + Ok(Compound { + encoded_size: self.encoded_size, + config: self.config, + }) + } + + #[cfg(not(feature = "alloc"))] + fn collect_str(self, _: &T) -> Result + where + T: core::fmt::Display, + { + Err(SerdeEncodeError::CannotCollectStr.into()) + } + + fn is_human_readable(&self) -> bool { + false + } +} + +type Compound<'a, C> = 
SerdeEncodedSize<'a, C>; + +impl<'a, C: Config> SerializeSeq for Compound<'a, C> { + type Ok = (); + type Error = EncodeError; + + fn serialize_element(&mut self, value: &T) -> Result<(), Self::Error> + where + T: Serialize, + { + value.serialize(SerdeEncodedSize { + encoded_size: self.encoded_size, + config: self.config, + }) + } + + fn end(self) -> Result { + Ok(()) + } +} + +impl<'a, C: Config> SerializeTuple for Compound<'a, C> { + type Ok = (); + type Error = EncodeError; + + fn serialize_element(&mut self, value: &T) -> Result<(), Self::Error> + where + T: Serialize, + { + value.serialize(SerdeEncodedSize { + encoded_size: self.encoded_size, + config: self.config, + }) + } + + fn end(self) -> Result { + Ok(()) + } +} + +impl<'a, C: Config> SerializeTupleStruct for Compound<'a, C> { + type Ok = (); + type Error = EncodeError; + + fn serialize_field(&mut self, value: &T) -> Result<(), Self::Error> + where + T: Serialize, + { + value.serialize(SerdeEncodedSize { + encoded_size: self.encoded_size, + config: self.config, + }) + } + + fn end(self) -> Result { + Ok(()) + } +} + +impl<'a, C: Config> SerializeTupleVariant for Compound<'a, C> { + type Ok = (); + type Error = EncodeError; + + fn serialize_field(&mut self, value: &T) -> Result<(), Self::Error> + where + T: Serialize, + { + value.serialize(SerdeEncodedSize { + encoded_size: self.encoded_size, + config: self.config, + }) + } + + fn end(self) -> Result { + Ok(()) + } +} + +impl<'a, C: Config> SerializeMap for Compound<'a, C> { + type Ok = (); + type Error = EncodeError; + + fn serialize_key(&mut self, key: &T) -> Result<(), Self::Error> + where + T: Serialize, + { + key.serialize(SerdeEncodedSize { + encoded_size: self.encoded_size, + config: self.config, + }) + } + + fn serialize_value(&mut self, value: &T) -> Result<(), Self::Error> + where + T: Serialize, + { + value.serialize(SerdeEncodedSize { + encoded_size: self.encoded_size, + config: self.config, + }) + } + + fn end(self) -> Result { + Ok(()) + } 
+} + +impl<'a, C: Config> SerializeStruct for Compound<'a, C> { + type Ok = (); + type Error = EncodeError; + + fn serialize_field( + &mut self, + _key: &'static str, + value: &T, + ) -> Result<(), Self::Error> + where + T: Serialize, + { + value.serialize(SerdeEncodedSize { + encoded_size: self.encoded_size, + config: self.config, + }) + } + + fn end(self) -> Result { + Ok(()) + } +} + +impl<'a, C: Config> SerializeStructVariant for Compound<'a, C> { + type Ok = (); + type Error = EncodeError; + + fn serialize_field( + &mut self, + _key: &'static str, + value: &T, + ) -> Result<(), Self::Error> + where + T: Serialize, + { + value.serialize(SerdeEncodedSize { + encoded_size: self.encoded_size, + config: self.config, + }) + } + + fn end(self) -> Result { + Ok(()) + } +} diff --git a/src/lib.rs b/src/lib.rs index 5bc75860..10ae3326 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -94,10 +94,12 @@ pub mod config; pub mod de; pub mod enc; pub mod error; +pub mod size; pub use atomic::*; pub use de::{BorrowDecode, Decode}; pub use enc::Encode; +pub use size::EncodedSize; use config::Config; @@ -177,6 +179,14 @@ pub fn decode_from_reader( D::decode(&mut decoder) } +/// Determine the encoded size of a value. +pub fn encoded_size( + val: E, + _config: C, +) -> Result { + val.encoded_size::() +} + // TODO: Currently our doctests fail when trying to include the specs because the specs depend on `derive` and `alloc`. 
// But we want to have the specs in the docs always #[cfg(all(feature = "alloc", feature = "derive", doc))] diff --git a/src/size/impl_tuples.rs b/src/size/impl_tuples.rs new file mode 100644 index 00000000..2e58bb49 --- /dev/null +++ b/src/size/impl_tuples.rs @@ -0,0 +1,389 @@ +use super::EncodedSize; +use crate::config::Config; +use crate::error::EncodeError; + +impl EncodedSize for (A,) +where + A: EncodedSize, +{ + fn encoded_size<_C: Config>(&self) -> Result { + self.0.encoded_size::<_C>() + } +} + +impl EncodedSize for (A, B) +where + A: EncodedSize, + B: EncodedSize, +{ + fn encoded_size<_C: Config>(&self) -> Result { + Ok(self.0.encoded_size::<_C>()? + self.1.encoded_size::<_C>()?) + } +} + +impl EncodedSize for (A, B, C) +where + A: EncodedSize, + B: EncodedSize, + C: EncodedSize, +{ + fn encoded_size<_C: Config>(&self) -> Result { + Ok(self.0.encoded_size::<_C>()? + + self.1.encoded_size::<_C>()? + + self.2.encoded_size::<_C>()?) + } +} + +impl EncodedSize for (A, B, C, D) +where + A: EncodedSize, + B: EncodedSize, + C: EncodedSize, + D: EncodedSize, +{ + fn encoded_size<_C: Config>(&self) -> Result { + Ok(self.0.encoded_size::<_C>()? + + self.1.encoded_size::<_C>()? + + self.2.encoded_size::<_C>()? + + self.3.encoded_size::<_C>()?) + } +} + +impl EncodedSize for (A, B, C, D, E) +where + A: EncodedSize, + B: EncodedSize, + C: EncodedSize, + D: EncodedSize, + E: EncodedSize, +{ + fn encoded_size<_C: Config>(&self) -> Result { + Ok(self.0.encoded_size::<_C>()? + + self.1.encoded_size::<_C>()? + + self.2.encoded_size::<_C>()? + + self.3.encoded_size::<_C>()? + + self.4.encoded_size::<_C>()?) + } +} + +impl EncodedSize for (A, B, C, D, E, F) +where + A: EncodedSize, + B: EncodedSize, + C: EncodedSize, + D: EncodedSize, + E: EncodedSize, + F: EncodedSize, +{ + fn encoded_size<_C: Config>(&self) -> Result { + Ok(self.0.encoded_size::<_C>()? + + self.1.encoded_size::<_C>()? + + self.2.encoded_size::<_C>()? + + self.3.encoded_size::<_C>()? 
+ + self.4.encoded_size::<_C>()? + + self.5.encoded_size::<_C>()?) + } +} + +impl EncodedSize for (A, B, C, D, E, F, G) +where + A: EncodedSize, + B: EncodedSize, + C: EncodedSize, + D: EncodedSize, + E: EncodedSize, + F: EncodedSize, + G: EncodedSize, +{ + fn encoded_size<_C: Config>(&self) -> Result { + Ok(self.0.encoded_size::<_C>()? + + self.1.encoded_size::<_C>()? + + self.2.encoded_size::<_C>()? + + self.3.encoded_size::<_C>()? + + self.4.encoded_size::<_C>()? + + self.5.encoded_size::<_C>()? + + self.6.encoded_size::<_C>()?) + } +} + +impl EncodedSize for (A, B, C, D, E, F, G, H) +where + A: EncodedSize, + B: EncodedSize, + C: EncodedSize, + D: EncodedSize, + E: EncodedSize, + F: EncodedSize, + G: EncodedSize, + H: EncodedSize, +{ + fn encoded_size<_C: Config>(&self) -> Result { + Ok(self.0.encoded_size::<_C>()? + + self.1.encoded_size::<_C>()? + + self.2.encoded_size::<_C>()? + + self.3.encoded_size::<_C>()? + + self.4.encoded_size::<_C>()? + + self.5.encoded_size::<_C>()? + + self.6.encoded_size::<_C>()? + + self.7.encoded_size::<_C>()?) + } +} + +impl EncodedSize for (A, B, C, D, E, F, G, H, I) +where + A: EncodedSize, + B: EncodedSize, + C: EncodedSize, + D: EncodedSize, + E: EncodedSize, + F: EncodedSize, + G: EncodedSize, + H: EncodedSize, + I: EncodedSize, +{ + fn encoded_size<_C: Config>(&self) -> Result { + Ok(self.0.encoded_size::<_C>()? + + self.1.encoded_size::<_C>()? + + self.2.encoded_size::<_C>()? + + self.3.encoded_size::<_C>()? + + self.4.encoded_size::<_C>()? + + self.5.encoded_size::<_C>()? + + self.6.encoded_size::<_C>()? + + self.7.encoded_size::<_C>()? + + self.8.encoded_size::<_C>()?) + } +} + +impl EncodedSize for (A, B, C, D, E, F, G, H, I, J) +where + A: EncodedSize, + B: EncodedSize, + C: EncodedSize, + D: EncodedSize, + E: EncodedSize, + F: EncodedSize, + G: EncodedSize, + H: EncodedSize, + I: EncodedSize, + J: EncodedSize, +{ + fn encoded_size<_C: Config>(&self) -> Result { + Ok(self.0.encoded_size::<_C>()? 
+ + self.1.encoded_size::<_C>()? + + self.2.encoded_size::<_C>()? + + self.3.encoded_size::<_C>()? + + self.4.encoded_size::<_C>()? + + self.5.encoded_size::<_C>()? + + self.6.encoded_size::<_C>()? + + self.7.encoded_size::<_C>()? + + self.8.encoded_size::<_C>()? + + self.9.encoded_size::<_C>()?) + } +} + +impl EncodedSize for (A, B, C, D, E, F, G, H, I, J, K) +where + A: EncodedSize, + B: EncodedSize, + C: EncodedSize, + D: EncodedSize, + E: EncodedSize, + F: EncodedSize, + G: EncodedSize, + H: EncodedSize, + I: EncodedSize, + J: EncodedSize, + K: EncodedSize, +{ + fn encoded_size<_C: Config>(&self) -> Result { + Ok(self.0.encoded_size::<_C>()? + + self.1.encoded_size::<_C>()? + + self.2.encoded_size::<_C>()? + + self.3.encoded_size::<_C>()? + + self.4.encoded_size::<_C>()? + + self.5.encoded_size::<_C>()? + + self.6.encoded_size::<_C>()? + + self.7.encoded_size::<_C>()? + + self.8.encoded_size::<_C>()? + + self.9.encoded_size::<_C>()? + + self.10.encoded_size::<_C>()?) + } +} + +impl EncodedSize for (A, B, C, D, E, F, G, H, I, J, K, L) +where + A: EncodedSize, + B: EncodedSize, + C: EncodedSize, + D: EncodedSize, + E: EncodedSize, + F: EncodedSize, + G: EncodedSize, + H: EncodedSize, + I: EncodedSize, + J: EncodedSize, + K: EncodedSize, + L: EncodedSize, +{ + fn encoded_size<_C: Config>(&self) -> Result { + Ok(self.0.encoded_size::<_C>()? + + self.1.encoded_size::<_C>()? + + self.2.encoded_size::<_C>()? + + self.3.encoded_size::<_C>()? + + self.4.encoded_size::<_C>()? + + self.5.encoded_size::<_C>()? + + self.6.encoded_size::<_C>()? + + self.7.encoded_size::<_C>()? + + self.8.encoded_size::<_C>()? + + self.9.encoded_size::<_C>()? + + self.10.encoded_size::<_C>()? + + self.11.encoded_size::<_C>()?) 
+ } +} + +impl EncodedSize for (A, B, C, D, E, F, G, H, I, J, K, L, M) +where + A: EncodedSize, + B: EncodedSize, + C: EncodedSize, + D: EncodedSize, + E: EncodedSize, + F: EncodedSize, + G: EncodedSize, + H: EncodedSize, + I: EncodedSize, + J: EncodedSize, + K: EncodedSize, + L: EncodedSize, + M: EncodedSize, +{ + fn encoded_size<_C: Config>(&self) -> Result { + Ok(self.0.encoded_size::<_C>()? + + self.1.encoded_size::<_C>()? + + self.2.encoded_size::<_C>()? + + self.3.encoded_size::<_C>()? + + self.4.encoded_size::<_C>()? + + self.5.encoded_size::<_C>()? + + self.6.encoded_size::<_C>()? + + self.7.encoded_size::<_C>()? + + self.8.encoded_size::<_C>()? + + self.9.encoded_size::<_C>()? + + self.10.encoded_size::<_C>()? + + self.11.encoded_size::<_C>()? + + self.12.encoded_size::<_C>()?) + } +} + +impl EncodedSize + for (A, B, C, D, E, F, G, H, I, J, K, L, M, N) +where + A: EncodedSize, + B: EncodedSize, + C: EncodedSize, + D: EncodedSize, + E: EncodedSize, + F: EncodedSize, + G: EncodedSize, + H: EncodedSize, + I: EncodedSize, + J: EncodedSize, + K: EncodedSize, + L: EncodedSize, + M: EncodedSize, + N: EncodedSize, +{ + fn encoded_size<_C: Config>(&self) -> Result { + Ok(self.0.encoded_size::<_C>()? + + self.1.encoded_size::<_C>()? + + self.2.encoded_size::<_C>()? + + self.3.encoded_size::<_C>()? + + self.4.encoded_size::<_C>()? + + self.5.encoded_size::<_C>()? + + self.6.encoded_size::<_C>()? + + self.7.encoded_size::<_C>()? + + self.8.encoded_size::<_C>()? + + self.9.encoded_size::<_C>()? + + self.10.encoded_size::<_C>()? + + self.11.encoded_size::<_C>()? + + self.12.encoded_size::<_C>()? + + self.13.encoded_size::<_C>()?) 
+ } +} + +impl EncodedSize + for (A, B, C, D, E, F, G, H, I, J, K, L, M, N, O) +where + A: EncodedSize, + B: EncodedSize, + C: EncodedSize, + D: EncodedSize, + E: EncodedSize, + F: EncodedSize, + G: EncodedSize, + H: EncodedSize, + I: EncodedSize, + J: EncodedSize, + K: EncodedSize, + L: EncodedSize, + M: EncodedSize, + N: EncodedSize, + O: EncodedSize, +{ + fn encoded_size<_C: Config>(&self) -> Result { + Ok(self.0.encoded_size::<_C>()? + + self.1.encoded_size::<_C>()? + + self.2.encoded_size::<_C>()? + + self.3.encoded_size::<_C>()? + + self.4.encoded_size::<_C>()? + + self.5.encoded_size::<_C>()? + + self.6.encoded_size::<_C>()? + + self.7.encoded_size::<_C>()? + + self.8.encoded_size::<_C>()? + + self.9.encoded_size::<_C>()? + + self.10.encoded_size::<_C>()? + + self.11.encoded_size::<_C>()? + + self.12.encoded_size::<_C>()? + + self.13.encoded_size::<_C>()? + + self.14.encoded_size::<_C>()?) + } +} + +impl EncodedSize + for (A, B, C, D, E, F, G, H, I, J, K, L, M, N, O, P) +where + A: EncodedSize, + B: EncodedSize, + C: EncodedSize, + D: EncodedSize, + E: EncodedSize, + F: EncodedSize, + G: EncodedSize, + H: EncodedSize, + I: EncodedSize, + J: EncodedSize, + K: EncodedSize, + L: EncodedSize, + M: EncodedSize, + N: EncodedSize, + O: EncodedSize, + P: EncodedSize, +{ + fn encoded_size<_C: Config>(&self) -> Result { + Ok(self.0.encoded_size::<_C>()? + + self.1.encoded_size::<_C>()? + + self.2.encoded_size::<_C>()? + + self.3.encoded_size::<_C>()? + + self.4.encoded_size::<_C>()? + + self.5.encoded_size::<_C>()? + + self.6.encoded_size::<_C>()? + + self.7.encoded_size::<_C>()? + + self.8.encoded_size::<_C>()? + + self.9.encoded_size::<_C>()? + + self.10.encoded_size::<_C>()? + + self.11.encoded_size::<_C>()? + + self.12.encoded_size::<_C>()? + + self.13.encoded_size::<_C>()? + + self.14.encoded_size::<_C>()? + + self.15.encoded_size::<_C>()?) 
+ } +} diff --git a/src/size/impls.rs b/src/size/impls.rs new file mode 100644 index 00000000..d769b6a2 --- /dev/null +++ b/src/size/impls.rs @@ -0,0 +1,384 @@ +use super::EncodedSize; +use crate::{ + config::{Config, IntEncoding}, + error::EncodeError, +}; +use core::{ + cell::{Cell, RefCell}, + marker::PhantomData, + num::{ + NonZeroI128, NonZeroI16, NonZeroI32, NonZeroI64, NonZeroI8, NonZeroIsize, NonZeroU128, + NonZeroU16, NonZeroU32, NonZeroU64, NonZeroU8, NonZeroUsize, + }, + ops::{Bound, Range, RangeInclusive}, + time::Duration, +}; + +impl EncodedSize for () { + fn encoded_size(&self) -> Result { + Ok(0) + } +} + +impl EncodedSize for PhantomData { + fn encoded_size(&self) -> Result { + Ok(0) + } +} + +impl EncodedSize for bool { + fn encoded_size(&self) -> Result { + u8::from(*self).encoded_size::() + } +} + +impl EncodedSize for u8 { + fn encoded_size(&self) -> Result { + Ok(1) + } +} + +impl EncodedSize for NonZeroU8 { + fn encoded_size(&self) -> Result { + self.get().encoded_size::() + } +} + +impl EncodedSize for u16 { + fn encoded_size(&self) -> Result { + match C::INT_ENCODING { + IntEncoding::Variable => Ok(crate::varint::varint_size_u16(*self)), + IntEncoding::Fixed => Ok(std::mem::size_of::()), + } + } +} + +impl EncodedSize for NonZeroU16 { + fn encoded_size(&self) -> Result { + self.get().encoded_size::() + } +} + +impl EncodedSize for u32 { + fn encoded_size(&self) -> Result { + match C::INT_ENCODING { + IntEncoding::Variable => Ok(crate::varint::varint_size_u32(*self)), + IntEncoding::Fixed => Ok(std::mem::size_of::()), + } + } +} + +impl EncodedSize for NonZeroU32 { + fn encoded_size(&self) -> Result { + self.get().encoded_size::() + } +} + +impl EncodedSize for u64 { + fn encoded_size(&self) -> Result { + match C::INT_ENCODING { + IntEncoding::Variable => Ok(crate::varint::varint_size_u64(*self)), + IntEncoding::Fixed => Ok(std::mem::size_of::()), + } + } +} + +impl EncodedSize for NonZeroU64 { + fn encoded_size(&self) -> Result { + 
self.get().encoded_size::() + } +} + +impl EncodedSize for u128 { + fn encoded_size(&self) -> Result { + match C::INT_ENCODING { + IntEncoding::Variable => Ok(crate::varint::varint_size_u128(*self)), + IntEncoding::Fixed => Ok(std::mem::size_of::()), + } + } +} + +impl EncodedSize for NonZeroU128 { + fn encoded_size(&self) -> Result { + self.get().encoded_size::() + } +} + +impl EncodedSize for usize { + fn encoded_size(&self) -> Result { + match C::INT_ENCODING { + IntEncoding::Variable => Ok(crate::varint::varint_size_usize(*self)), + IntEncoding::Fixed => Ok(std::mem::size_of::()), + } + } +} + +impl EncodedSize for NonZeroUsize { + fn encoded_size(&self) -> Result { + self.get().encoded_size::() + } +} + +impl EncodedSize for i8 { + fn encoded_size(&self) -> Result { + Ok(1) + } +} + +impl EncodedSize for NonZeroI8 { + fn encoded_size(&self) -> Result { + self.get().encoded_size::() + } +} + +impl EncodedSize for i16 { + fn encoded_size(&self) -> Result { + match C::INT_ENCODING { + IntEncoding::Variable => Ok(crate::varint::varint_size_i16(*self)), + IntEncoding::Fixed => Ok(std::mem::size_of::()), + } + } +} + +impl EncodedSize for NonZeroI16 { + fn encoded_size(&self) -> Result { + self.get().encoded_size::() + } +} + +impl EncodedSize for i32 { + fn encoded_size(&self) -> Result { + match C::INT_ENCODING { + IntEncoding::Variable => Ok(crate::varint::varint_size_i32(*self)), + IntEncoding::Fixed => Ok(std::mem::size_of::()), + } + } +} + +impl EncodedSize for NonZeroI32 { + fn encoded_size(&self) -> Result { + self.get().encoded_size::() + } +} + +impl EncodedSize for i64 { + fn encoded_size(&self) -> Result { + match C::INT_ENCODING { + IntEncoding::Variable => Ok(crate::varint::varint_size_i64(*self)), + IntEncoding::Fixed => Ok(std::mem::size_of::()), + } + } +} + +impl EncodedSize for NonZeroI64 { + fn encoded_size(&self) -> Result { + self.get().encoded_size::() + } +} + +impl EncodedSize for i128 { + fn encoded_size(&self) -> Result { + match 
C::INT_ENCODING { + IntEncoding::Variable => Ok(crate::varint::varint_size_i128(*self)), + IntEncoding::Fixed => Ok(std::mem::size_of::()), + } + } +} + +impl EncodedSize for NonZeroI128 { + fn encoded_size(&self) -> Result { + self.get().encoded_size::() + } +} + +impl EncodedSize for isize { + fn encoded_size(&self) -> Result { + match C::INT_ENCODING { + IntEncoding::Variable => Ok(crate::varint::varint_size_isize(*self)), + IntEncoding::Fixed => Ok(std::mem::size_of::()), + } + } +} + +impl EncodedSize for NonZeroIsize { + fn encoded_size(&self) -> Result { + self.get().encoded_size::() + } +} + +impl EncodedSize for f32 { + fn encoded_size(&self) -> Result { + Ok(std::mem::size_of::()) + } +} + +impl EncodedSize for f64 { + fn encoded_size(&self) -> Result { + Ok(std::mem::size_of::()) + } +} + +impl EncodedSize for char { + fn encoded_size(&self) -> Result { + Ok(encoded_size_utf8(*self)) + } +} + +// BlockedTODO: https://github.com/rust-lang/rust/issues/37653 +// +// We'll want to implement encoding for both &[u8] and &[T: EncodedSizedSize], +// but those implementations overlap because u8 also implements EncodedSize +// impl EncodedSize for &'_ [u8] { +// fn encoded_size(&self) -> Result { +// encoder.writer().write(*self) +// } +// } + +impl EncodedSize for [T] +where + T: EncodedSize, +{ + fn encoded_size(&self) -> Result { + let mut size = super::size_slice_len::(self.len())?; + for item in self { + size += item.encoded_size::()?; + } + Ok(size) + } +} + +const MAX_ONE_B: u32 = 0x80; +const MAX_TWO_B: u32 = 0x800; +const MAX_THREE_B: u32 = 0x10000; + +fn encoded_size_utf8(c: char) -> usize { + let code = c as u32; + + if code < MAX_ONE_B { + 1 + } else if code < MAX_TWO_B { + 2 + } else if code < MAX_THREE_B { + 3 + } else { + 4 + } +} + +impl EncodedSize for str { + fn encoded_size(&self) -> Result { + self.as_bytes().encoded_size::() + } +} + +impl EncodedSize for [T; N] +where + T: EncodedSize, +{ + fn encoded_size(&self) -> Result { + let mut size = 
0; + if !C::SKIP_FIXED_ARRAY_LENGTH { + size += super::size_slice_len::(N)?; + } + for item in self.iter() { + size += item.encoded_size::()?; + } + Ok(size) + } +} + +impl EncodedSize for Option +where + T: EncodedSize, +{ + fn encoded_size(&self) -> Result { + let mut size = 1; + if let Some(val) = self { + size += val.encoded_size::()?; + } + Ok(size) + } +} + +impl EncodedSize for Result +where + T: EncodedSize, + U: EncodedSize, +{ + fn encoded_size(&self) -> Result { + match self { + Ok(val) => Ok(0u32.encoded_size::()? + val.encoded_size::()?), + Err(err) => Ok(1u32.encoded_size::()? + err.encoded_size::()?), + } + } +} + +impl EncodedSize for Cell +where + T: EncodedSize + Copy, +{ + fn encoded_size(&self) -> Result { + T::encoded_size::(&self.get()) + } +} + +impl EncodedSize for RefCell +where + T: EncodedSize + ?Sized, +{ + fn encoded_size(&self) -> Result { + let borrow_guard = self + .try_borrow() + .map_err(|e| EncodeError::RefCellAlreadyBorrowed { + inner: e, + type_name: core::any::type_name::>(), + })?; + T::encoded_size::(&borrow_guard) + } +} + +impl EncodedSize for Duration { + fn encoded_size(&self) -> Result { + Ok(self.as_secs().encoded_size::()? + self.subsec_nanos().encoded_size::()?) + } +} + +impl EncodedSize for Range +where + T: EncodedSize, +{ + fn encoded_size(&self) -> Result { + Ok(self.start.encoded_size::()? + self.end.encoded_size::()?) + } +} + +impl EncodedSize for RangeInclusive +where + T: EncodedSize, +{ + fn encoded_size(&self) -> Result { + Ok(self.start().encoded_size::()? + self.end().encoded_size::()?) + } +} + +impl EncodedSize for Bound +where + T: EncodedSize, +{ + fn encoded_size(&self) -> Result { + match self { + Self::Unbounded => 0u32.encoded_size::(), + Self::Included(val) => Ok(1u32.encoded_size::()? + val.encoded_size::()?), + Self::Excluded(val) => Ok(2u32.encoded_size::()? 
+ val.encoded_size::()?), + } + } +} + +impl<'a, T> EncodedSize for &'a T +where + T: EncodedSize + ?Sized, +{ + fn encoded_size(&self) -> Result { + T::encoded_size::(self) + } +} diff --git a/src/size/mod.rs b/src/size/mod.rs new file mode 100644 index 00000000..f8f28c03 --- /dev/null +++ b/src/size/mod.rs @@ -0,0 +1,58 @@ +//! Size determination structs and traits. + +mod impl_tuples; +mod impls; + +use crate::config::Config; +use crate::error::EncodeError; + +/// Any source whose size when encoded can be determined. +/// +/// This trait should be implemented for all types that you want to be able to determine the encoded size ahead of actual encoding. +/// +/// # Implementing this trait manually +/// +/// If you want to implement this trait for your type, the easiest way is to add a `#[derive(bincode::EncodedSize)]`, build and check your `target/generated/bincode/` folder. This should generate a `_EncodedSize.rs` file. +/// +/// For this struct: +/// +/// ``` +/// struct Entity { +/// pub x: f32, +/// pub y: f32, +/// } +/// ``` +/// It will look something like: +/// +/// ``` +/// # struct Entity { +/// # pub x: f32, +/// # pub y: f32, +/// # } +/// impl bincode::EncodedSize for Entity { +/// fn encoded_size( +/// &self, +/// ) -> core::result::Result { +/// let mut __encoded_size = 0; +/// __encoded_size += bincode::EncodedSize::encoded_size::(&self.x)?; +/// __encoded_size += bincode::EncodedSize::encoded_size::(&self.y)?; +/// Ok(__encoded_size) +/// } +/// } +/// ``` +/// +/// From here you can add/remove fields, or add custom logic. +/// +/// # Interior Mutability +/// +/// Types with interior mutability may be mutated between calls to `encoded_size` and one of the `encode` methods. If this happens, the encoded size may change. You must ensure that your encoded values are either not mutated between calls to `encoded_size` and `encode`, or handle the case where the actual encoded size is large than the value that `encoded_size` returns. 
+pub trait EncodedSize { + /// Determine the encoded size of a given type. + fn encoded_size(&self) -> Result; +} + +/// Returns the size of the encoded length of any slice, container, etc. +#[inline] +pub(crate) fn size_slice_len(len: usize) -> Result { + (len as u64).encoded_size::() +} diff --git a/src/varint/mod.rs b/src/varint/mod.rs index afadfb50..af7a73cd 100644 --- a/src/varint/mod.rs +++ b/src/varint/mod.rs @@ -2,6 +2,8 @@ mod decode_signed; mod decode_unsigned; mod encode_signed; mod encode_unsigned; +mod size_signed; +mod size_unsigned; pub use self::{ decode_signed::{ @@ -20,6 +22,12 @@ pub use self::{ varint_encode_u128, varint_encode_u16, varint_encode_u32, varint_encode_u64, varint_encode_usize, }, + size_signed::{ + varint_size_i128, varint_size_i16, varint_size_i32, varint_size_i64, varint_size_isize, + }, + size_unsigned::{ + varint_size_u128, varint_size_u16, varint_size_u32, varint_size_u64, varint_size_usize, + }, }; pub(self) const SINGLE_BYTE_MAX: u8 = 250; @@ -27,3 +35,5 @@ pub(self) const U16_BYTE: u8 = 251; pub(self) const U32_BYTE: u8 = 252; pub(self) const U64_BYTE: u8 = 253; pub(self) const U128_BYTE: u8 = 254; +pub(self) const SIGNED_SINGLE_BYTE_MIN: i8 = -125; +pub(self) const SIGNED_SINGLE_BYTE_MAX: i8 = 125; diff --git a/src/varint/size_signed.rs b/src/varint/size_signed.rs new file mode 100644 index 00000000..efdfe503 --- /dev/null +++ b/src/varint/size_signed.rs @@ -0,0 +1,228 @@ +use super::{SIGNED_SINGLE_BYTE_MAX, SIGNED_SINGLE_BYTE_MIN}; + +// Convenicence macro to specify a range with a specific type. +macro_rules! 
range { + ($min:path, $max:path as $t:ty) => { + (($min as $t) ..= ($max as $t)) + }; +} + +pub fn varint_size_i16(val: i16) -> usize { + if range!(SIGNED_SINGLE_BYTE_MIN, SIGNED_SINGLE_BYTE_MAX as i16).contains(&val) { + 1 + } else { + 1 + std::mem::size_of::() + } +} + +pub fn varint_size_i32(val: i32) -> usize { + if range!(SIGNED_SINGLE_BYTE_MIN, SIGNED_SINGLE_BYTE_MAX as i32).contains(&val) { + 1 + } else if range!(i16::MIN, i16::MAX as i32).contains(&val) { + 1 + std::mem::size_of::() + } else { + 1 + std::mem::size_of::() + } +} + +pub fn varint_size_i64(val: i64) -> usize { + if range!(SIGNED_SINGLE_BYTE_MIN, SIGNED_SINGLE_BYTE_MAX as i64).contains(&val) { + 1 + } else if range!(i16::MIN, i16::MAX as i64).contains(&val) { + 1 + std::mem::size_of::() + } else if range!(i32::MIN, i32::MAX as i64).contains(&val) { + 1 + std::mem::size_of::() + } else { + 1 + std::mem::size_of::() + } +} + +pub fn varint_size_i128(val: i128) -> usize { + if range!(SIGNED_SINGLE_BYTE_MIN, SIGNED_SINGLE_BYTE_MAX as i128).contains(&val) { + 1 + } else if range!(i16::MIN, i16::MAX as i128).contains(&val) { + 1 + std::mem::size_of::() + } else if range!(i32::MIN, i32::MAX as i128).contains(&val) { + 1 + std::mem::size_of::() + } else if range!(i64::MIN, i64::MAX as i128).contains(&val) { + 1 + std::mem::size_of::() + } else { + 1 + std::mem::size_of::() + } +} + +pub fn varint_size_isize(val: isize) -> usize { + // isize is being encoded as a i64 + varint_size_i64(val as i64) +} + +#[test] +fn test_size_i16() { + // these should all encode to a single byte + for i in range!(SIGNED_SINGLE_BYTE_MIN, SIGNED_SINGLE_BYTE_MAX as i16) { + assert_eq!(varint_size_i16(i), 1, "value: {}", i); + } + + // these values should encode in 3 bytes (leading byte + 2 bytes) + // Values chosen at random, add new cases as needed + for i in [ + i16::MIN, + -1000, + -200, + SIGNED_SINGLE_BYTE_MIN as i16 - 1, + SIGNED_SINGLE_BYTE_MAX as i16 + 1, + 222, + 1234, + i16::MAX, + ] { + 
assert_eq!(varint_size_i16(i), 3, "value: {}", i); + } +} + +#[test] +fn test_size_i32() { + // these should all encode to a single byte + for i in range!(SIGNED_SINGLE_BYTE_MIN, SIGNED_SINGLE_BYTE_MAX as i32) { + assert_eq!(varint_size_i32(i), 1, "value: {}", i); + } + + // these values should encode in 3 bytes (leading byte + 2 bytes) + // Values chosen at random, add new cases as needed + for i in [ + i16::MIN as i32, + -1000, + -200, + SIGNED_SINGLE_BYTE_MIN as i32 - 1, + SIGNED_SINGLE_BYTE_MAX as i32 + 1, + 222, + 1234, + i16::MAX as i32, + ] { + assert_eq!(varint_size_i32(i), 3, "value: {}", i); + } + + // these values should encode in 5 bytes (leading byte + 4 bytes) + // Values chosen at random, add new cases as needed + for i in [ + i32::MIN, + -1_000_000, + i16::MIN as i32 - 1, + i16::MAX as i32 + 1, + 100_000, + 1_000_000, + i32::MAX, + ] { + assert_eq!(varint_size_i32(i), 5, "value: {}", i); + } +} + +#[test] +fn test_size_i64() { + // these should all encode to a single byte + for i in range!(SIGNED_SINGLE_BYTE_MIN, SIGNED_SINGLE_BYTE_MAX as i64) { + assert_eq!(varint_size_i64(i), 1, "value: {}", i); + } + + // these values should encode in 3 bytes (leading byte + 2 bytes) + // Values chosen at random, add new cases as needed + for i in [ + i16::MIN as i64, + -1000, + -200, + SIGNED_SINGLE_BYTE_MIN as i64 - 1, + SIGNED_SINGLE_BYTE_MAX as i64 + 1, + 222, + 1234, + i16::MAX as i64, + ] { + assert_eq!(varint_size_i64(i), 3, "value: {}", i); + } + + // these values should encode in 5 bytes (leading byte + 4 bytes) + // Values chosen at random, add new cases as needed + for i in [ + i32::MIN as i64, + -1_000_000, + i16::MIN as i64 - 1, + i16::MAX as i64 + 1, + 100_000, + 1_000_000, + i32::MAX as i64, + ] { + assert_eq!(varint_size_i64(i), 5, "value: {}", i); + } + + // these values should encode in 9 bytes (leading byte + 8 bytes) + // Values chosen at random, add new cases as needed + for i in [ + i64::MIN, + -6_000_000_000, + i32::MIN as i64 - 1, + 
i32::MAX as i64 + 1, + 5_000_000_000, + i64::MAX, + ] { + assert_eq!(varint_size_i64(i), 9, "value: {}", i); + } +} + +#[test] +fn test_size_i128() { + // these should all encode to a single byte + for i in range!(SIGNED_SINGLE_BYTE_MIN, SIGNED_SINGLE_BYTE_MAX as i128) { + assert_eq!(varint_size_i128(i), 1, "value: {}", i); + } + + // these values should encode in 3 bytes (leading byte + 2 bytes) + // Values chosen at random, add new cases as needed + for i in [ + i16::MIN as i128, + -1000, + -200, + SIGNED_SINGLE_BYTE_MIN as i128 - 1, + SIGNED_SINGLE_BYTE_MAX as i128 + 1, + 222, + 1234, + i16::MAX as i128, + ] { + assert_eq!(varint_size_i128(i), 3, "value: {}", i); + } + + // these values should encode in 5 bytes (leading byte + 4 bytes) + // Values chosen at random, add new cases as needed + for i in [ + i32::MIN as i128, + -1_000_000, + i16::MIN as i128 - 1, + i16::MAX as i128 + 1, + 100_000, + 1_000_000, + i32::MAX as i128, + ] { + assert_eq!(varint_size_i128(i), 5, "value: {}", i); + } + + // these values should encode in 9 bytes (leading byte + 8 bytes) + // Values chosen at random, add new cases as needed + for i in [ + i64::MIN as i128, + -6_000_000_000, + i32::MIN as i128 - 1, + i32::MAX as i128 + 1, + 5_000_000_000, + i64::MAX as i128, + ] { + assert_eq!(varint_size_i128(i), 9, "value: {}", i); + } + + // these values should encode in 17 bytes (leading byte + 16 bytes) + // Values chosen at random, add new cases as needed + for i in [ + i128::MIN, + i64::MIN as i128 - 1, + i64::MAX as i128 + 1, + i128::MAX, + ] { + assert_eq!(varint_size_i128(i), 17, "value: {}", i); + } +} diff --git a/src/varint/size_unsigned.rs b/src/varint/size_unsigned.rs new file mode 100644 index 00000000..3aa35885 --- /dev/null +++ b/src/varint/size_unsigned.rs @@ -0,0 +1,174 @@ +use super::SINGLE_BYTE_MAX; + +pub fn varint_size_u16(val: u16) -> usize { + if val <= SINGLE_BYTE_MAX as _ { + 1 + } else { + 1 + std::mem::size_of::() + } +} + +pub fn varint_size_u32(val: u32) -> usize 
{ + if val <= SINGLE_BYTE_MAX as _ { + 1 + } else if val <= u16::MAX as _ { + 1 + std::mem::size_of::() + } else { + 1 + std::mem::size_of::() + } +} + +pub fn varint_size_u64(val: u64) -> usize { + if val <= SINGLE_BYTE_MAX as _ { + 1 + } else if val <= u16::MAX as _ { + 1 + std::mem::size_of::() + } else if val <= u32::MAX as _ { + 1 + std::mem::size_of::() + } else { + 1 + std::mem::size_of::() + } +} + +pub fn varint_size_u128(val: u128) -> usize { + if val <= SINGLE_BYTE_MAX as _ { + 1 + } else if val <= u16::MAX as _ { + 1 + std::mem::size_of::() + } else if val <= u32::MAX as _ { + 1 + std::mem::size_of::() + } else if val <= u64::MAX as _ { + 1 + std::mem::size_of::() + } else { + 1 + std::mem::size_of::() + } +} + +pub fn varint_size_usize(val: usize) -> usize { + // usize is being encoded as a u64 + varint_size_u64(val as u64) +} + +#[test] +fn test_size_u16() { + // these should all encode to a single byte + for i in 0u16..=SINGLE_BYTE_MAX as u16 { + assert_eq!(varint_size_u16(i), 1, "value: {}", i); + } + + // these values should encode in 3 bytes (leading byte + 2 bytes) + // Values chosen at random, add new cases as needed + for i in [ + SINGLE_BYTE_MAX as u16 + 1, + 300, + 500, + 700, + 888, + 1234, + u16::MAX, + ] { + assert_eq!(varint_size_u16(i), 3, "value: {}", i); + } +} + +#[test] +fn test_size_u32() { + // these should all encode to a single byte + for i in 0u32..=SINGLE_BYTE_MAX as u32 { + assert_eq!(varint_size_u32(i), 1, "value: {}", i); + } + + // these values should encode in 3 bytes (leading byte + 2 bytes) + // Values chosen at random, add new cases as needed + for i in [ + SINGLE_BYTE_MAX as u32 + 1, + 300, + 500, + 700, + 888, + 1234, + u16::MAX as u32, + ] { + assert_eq!(varint_size_u32(i), 3, "value: {}", i); + } + + // these values should encode in 5 bytes (leading byte + 4 bytes) + // Values chosen at random, add new cases as needed + for i in [u16::MAX as u32 + 1, 100_000, 1_000_000, u32::MAX] { + assert_eq!(varint_size_u32(i), 
5, "value: {}", i); + } +} + +#[test] +fn test_size_u64() { + // these should all encode to a single byte + for i in 0u64..=SINGLE_BYTE_MAX as u64 { + assert_eq!(varint_size_u64(i), 1, "value: {}", i); + } + + // these values should encode in 3 bytes (leading byte + 2 bytes) + // Values chosen at random, add new cases as needed + for i in [ + SINGLE_BYTE_MAX as u64 + 1, + 300, + 500, + 700, + 888, + 1234, + u16::MAX as u64, + ] { + assert_eq!(varint_size_u64(i), 3, "value: {}", i); + } + + // these values should encode in 5 bytes (leading byte + 4 bytes) + // Values chosen at random, add new cases as needed + for i in [u16::MAX as u64 + 1, 100_000, 1_000_000, u32::MAX as u64] { + assert_eq!(varint_size_u64(i), 5, "value: {}", i); + } + + // these values should encode in 9 bytes (leading byte + 8 bytes) + // Values chosen at random, add new cases as needed + for i in [u32::MAX as u64 + 1, 5_000_000_000, u64::MAX] { + assert_eq!(varint_size_u64(i), 9, "value: {}", i); + } +} + +#[test] +fn test_size_u128() { + // these should all encode to a single byte + for i in 0u128..=SINGLE_BYTE_MAX as u128 { + assert_eq!(varint_size_u128(i), 1, "value: {}", i); + } + + // these values should encode in 3 bytes (leading byte + 2 bytes) + // Values chosen at random, add new cases as needed + for i in [ + SINGLE_BYTE_MAX as u128 + 1, + 300, + 500, + 700, + 888, + 1234, + u16::MAX as u128, + ] { + assert_eq!(varint_size_u128(i), 3, "value: {}", i); + } + + // these values should encode in 5 bytes (leading byte + 4 bytes) + // Values chosen at random, add new cases as needed + for i in [u16::MAX as u128 + 1, 100_000, 1_000_000, u32::MAX as u128] { + assert_eq!(varint_size_u128(i), 5, "value: {}", i); + } + + // these values should encode in 9 bytes (leading byte + 8 bytes) + // Values chosen at random, add new cases as needed + for i in [u32::MAX as u128 + 1, 5_000_000_000, u64::MAX as u128] { + assert_eq!(varint_size_u128(i), 9, "value: {}", i); + } + + // these values should encode 
in 17 bytes (leading byte + 16 bytes) + // Values chosen at random, add new cases as needed + for i in [u64::MAX as u128 + 1, u128::MAX] { + assert_eq!(varint_size_u128(i), 17, "value: {}", i); + } +} diff --git a/tests/derive.rs b/tests/derive.rs index 009312b4..875f9603 100644 --- a/tests/derive.rs +++ b/tests/derive.rs @@ -2,7 +2,7 @@ use bincode::error::DecodeError; -#[derive(bincode::Encode, PartialEq, Debug)] +#[derive(bincode::Encode, bincode::EncodedSize, PartialEq, Debug)] pub(crate) struct Test { a: T, b: u32, @@ -17,9 +17,11 @@ fn test_encode() { c: 20u8, }; let mut slice = [0u8; 1024]; + let encoded_size = bincode::encoded_size(&start, bincode::config::standard()).unwrap(); let bytes_written = bincode::encode_into_slice(start, &mut slice, bincode::config::standard()).unwrap(); assert_eq!(bytes_written, 3); + assert_eq!(bytes_written, encoded_size); assert_eq!(&slice[..bytes_written], &[10, 10, 20]); } #[derive(PartialEq, Debug, Eq)] @@ -71,7 +73,7 @@ fn test_decode() { assert_eq!(len, 5); } -#[derive(bincode::BorrowDecode, bincode::Encode, PartialEq, Debug, Eq)] +#[derive(bincode::BorrowDecode, bincode::Encode, bincode::EncodedSize, PartialEq, Debug, Eq)] pub struct Test3<'a> { a: &'a str, b: u32, @@ -89,24 +91,28 @@ fn test_encode_decode_str() { }; let mut slice = [0u8; 100]; + let encoded_size = bincode::encoded_size(&start, bincode::config::standard()).unwrap(); let len = bincode::encode_into_slice(&start, &mut slice, bincode::config::standard()).unwrap(); assert_eq!(len, 21); + assert_eq!(len, encoded_size); let (end, len): (Test3, usize) = bincode::borrow_decode_from_slice(&slice[..len], bincode::config::standard()).unwrap(); assert_eq!(end, start); assert_eq!(len, 21); } -#[derive(bincode::Encode, bincode::Decode, PartialEq, Debug, Eq)] +#[derive(bincode::Encode, bincode::Decode, bincode::EncodedSize, PartialEq, Debug, Eq)] pub struct TestTupleStruct(u32, u32, u32); #[test] fn test_encode_tuple() { let start = TestTupleStruct(5, 10, 1024); let mut 
slice = [0u8; 1024]; + let encoded_size = bincode::encoded_size(&start, bincode::config::standard()).unwrap(); let bytes_written = bincode::encode_into_slice(start, &mut slice, bincode::config::standard()).unwrap(); assert_eq!(bytes_written, 5); + assert_eq!(bytes_written, encoded_size); assert_eq!(&slice[..bytes_written], &[5, 10, 251, 0, 4]); } @@ -120,7 +126,7 @@ fn test_decode_tuple() { assert_eq!(len, 5); } -#[derive(bincode::Encode, bincode::Decode, PartialEq, Debug, Eq)] +#[derive(bincode::Encode, bincode::Decode, bincode::EncodedSize, PartialEq, Debug, Eq)] pub enum TestEnum { Foo, Bar { name: u32 }, @@ -130,9 +136,11 @@ pub enum TestEnum { fn test_encode_enum_struct_variant() { let start = TestEnum::Bar { name: 5u32 }; let mut slice = [0u8; 1024]; + let encoded_size = bincode::encoded_size(&start, bincode::config::standard()).unwrap(); let bytes_written = bincode::encode_into_slice(start, &mut slice, bincode::config::standard()).unwrap(); assert_eq!(bytes_written, 2); + assert_eq!(bytes_written, encoded_size); assert_eq!(&slice[..bytes_written], &[1, 5]); } @@ -160,9 +168,11 @@ fn test_decode_enum_unit_variant() { fn test_encode_enum_unit_variant() { let start = TestEnum::Foo; let mut slice = [0u8; 1024]; + let encoded_size = bincode::encoded_size(&start, bincode::config::standard()).unwrap(); let bytes_written = bincode::encode_into_slice(start, &mut slice, bincode::config::standard()).unwrap(); assert_eq!(bytes_written, 1); + assert_eq!(bytes_written, encoded_size); assert_eq!(&slice[..bytes_written], &[0]); } @@ -170,9 +180,11 @@ fn test_encode_enum_unit_variant() { fn test_encode_enum_tuple_variant() { let start = TestEnum::Baz(5, 10, 1024); let mut slice = [0u8; 1024]; + let encoded_size = bincode::encoded_size(&start, bincode::config::standard()).unwrap(); let bytes_written = bincode::encode_into_slice(start, &mut slice, bincode::config::standard()).unwrap(); assert_eq!(bytes_written, 6); + assert_eq!(bytes_written, encoded_size); 
assert_eq!(&slice[..bytes_written], &[2, 5, 10, 251, 0, 4]); } @@ -186,7 +198,7 @@ fn test_decode_enum_tuple_variant() { assert_eq!(len, 6); } -#[derive(bincode::Encode, bincode::BorrowDecode, PartialEq, Debug, Eq)] +#[derive(bincode::Encode, bincode::EncodedSize, bincode::BorrowDecode, PartialEq, Debug, Eq)] pub enum TestEnum2<'a> { Foo, Bar { name: &'a str }, @@ -197,9 +209,11 @@ pub enum TestEnum2<'a> { fn test_encode_borrowed_enum_struct_variant() { let start = TestEnum2::Bar { name: "foo" }; let mut slice = [0u8; 1024]; + let encoded_size = bincode::encoded_size(&start, bincode::config::standard()).unwrap(); let bytes_written = bincode::encode_into_slice(start, &mut slice, bincode::config::standard()).unwrap(); assert_eq!(bytes_written, 5); + assert_eq!(bytes_written, encoded_size); assert_eq!(&slice[..bytes_written], &[1, 3, 102, 111, 111]); } @@ -227,9 +241,11 @@ fn test_decode_borrowed_enum_unit_variant() { fn test_encode_borrowed_enum_unit_variant() { let start = TestEnum2::Foo; let mut slice = [0u8; 1024]; + let encoded_size = bincode::encoded_size(&start, bincode::config::standard()).unwrap(); let bytes_written = bincode::encode_into_slice(start, &mut slice, bincode::config::standard()).unwrap(); assert_eq!(bytes_written, 1); + assert_eq!(bytes_written, encoded_size); assert_eq!(&slice[..bytes_written], &[0]); } @@ -237,9 +253,11 @@ fn test_encode_borrowed_enum_unit_variant() { fn test_encode_borrowed_enum_tuple_variant() { let start = TestEnum2::Baz(5, 10, 1024); let mut slice = [0u8; 1024]; + let encoded_size = bincode::encoded_size(&start, bincode::config::standard()).unwrap(); let bytes_written = bincode::encode_into_slice(start, &mut slice, bincode::config::standard()).unwrap(); assert_eq!(bytes_written, 6); + assert_eq!(bytes_written, encoded_size); assert_eq!(&slice[..bytes_written], &[2, 5, 10, 251, 0, 4]); } @@ -253,7 +271,7 @@ fn test_decode_borrowed_enum_tuple_variant() { assert_eq!(len, 6); } -#[derive(bincode::Decode, bincode::Encode, 
PartialEq, Eq, Debug)] +#[derive(bincode::Decode, bincode::Encode, bincode::EncodedSize, PartialEq, Eq, Debug)] enum CStyleEnum { A = -1, B = 2, @@ -266,9 +284,11 @@ enum CStyleEnum { fn test_c_style_enum() { fn ser(e: CStyleEnum) -> u8 { let mut slice = [0u8; 10]; + let encoded_size = bincode::encoded_size(&e, bincode::config::standard()).unwrap(); let bytes_written = bincode::encode_into_slice(e, &mut slice, bincode::config::standard()).unwrap(); assert_eq!(bytes_written, 1); + assert_eq!(bytes_written, encoded_size); slice[0] } @@ -312,7 +332,7 @@ fn test_c_style_enum() { macro_rules! macro_newtype { ($name:ident) => { - #[derive(bincode::Encode, bincode::Decode, PartialEq, Eq, Debug)] + #[derive(bincode::Encode, bincode::Decode, bincode::EncodedSize, PartialEq, Eq, Debug)] pub struct $name(pub usize); }; } @@ -322,10 +342,13 @@ macro_newtype!(MacroNewType); fn test_macro_newtype() { for val in [0, 100, usize::MAX] { let mut usize_slice = [0u8; 10]; + let usize_encoded_size = bincode::encoded_size(&val, bincode::config::standard()).unwrap(); let usize_len = bincode::encode_into_slice(val, &mut usize_slice, bincode::config::standard()).unwrap(); let mut newtype_slice = [0u8; 10]; + let newtype_encoded_size = + bincode::encoded_size(&MacroNewType(val), bincode::config::standard()).unwrap(); let newtype_len = bincode::encode_into_slice( MacroNewType(val), &mut newtype_slice, @@ -335,6 +358,8 @@ fn test_macro_newtype() { assert_eq!(usize_len, newtype_len); assert_eq!(usize_slice, newtype_slice); + assert_eq!(usize_len, usize_encoded_size); + assert_eq!(newtype_len, newtype_encoded_size); let (newtype, len) = bincode::decode_from_slice::( &newtype_slice, @@ -346,7 +371,7 @@ fn test_macro_newtype() { } } -#[derive(bincode::Encode, bincode::Decode, Debug)] +#[derive(bincode::Encode, bincode::Decode, bincode::EncodedSize, Debug)] pub enum EmptyEnum {} #[derive(bincode::Encode, bincode::BorrowDecode, Debug)] @@ -363,7 +388,7 @@ fn test_empty_enum_decode() { } }
-#[derive(bincode::Encode, bincode::Decode, PartialEq, Debug, Eq)] +#[derive(bincode::Encode, bincode::Decode, bincode::EncodedSize, PartialEq, Debug, Eq)] pub enum TestWithGeneric { Foo, Bar(T), @@ -373,6 +398,7 @@ pub enum TestWithGeneric { fn test_enum_with_generics_roundtrip() { let start = TestWithGeneric::Bar(1234); let mut slice = [0u8; 10]; + let encoded_size = bincode::encoded_size(&start, bincode::config::standard()).unwrap(); let bytes_written = bincode::encode_into_slice(&start, &mut slice, bincode::config::standard()).unwrap(); assert_eq!( @@ -383,6 +409,7 @@ fn test_enum_with_generics_roundtrip() { 210, 4 // 1234 ] ); + assert_eq!(bytes_written, encoded_size); let decoded: TestWithGeneric = bincode::decode_from_slice(&slice[..bytes_written], bincode::config::standard()) @@ -392,9 +419,11 @@ fn test_enum_with_generics_roundtrip() { let start = TestWithGeneric::<()>::Foo; let mut slice = [0u8; 10]; + let encoded_size = bincode::encoded_size(&start, bincode::config::standard()).unwrap(); let bytes_written = bincode::encode_into_slice(&start, &mut slice, bincode::config::standard()).unwrap(); assert_eq!(&slice[..bytes_written], &[0]); + assert_eq!(bytes_written, encoded_size); let decoded: TestWithGeneric<()> = bincode::decode_from_slice(&slice[..bytes_written], bincode::config::standard()) @@ -408,12 +437,12 @@ mod zoxide { extern crate alloc; use alloc::borrow::Cow; - use bincode::{Decode, Encode}; + use bincode::{Decode, Encode, EncodedSize}; pub type Rank = f64; pub type Epoch = u64; - #[derive(Encode, Decode)] + #[derive(Encode, Decode, EncodedSize)] pub struct Dir<'a> { pub path: Cow<'a, str>, pub rank: Rank, @@ -436,7 +465,11 @@ mod zoxide { ]; let config = bincode::config::standard(); + let encoded_size = bincode::encoded_size(dirs, config).unwrap(); let slice = bincode::encode_to_vec(dirs, config).unwrap(); + + assert_eq!(slice.len(), encoded_size); + let decoded: Vec = bincode::borrow_decode_from_slice(&slice, config).unwrap().0; 
assert_eq!(decoded.len(), 2); diff --git a/tests/serde.rs b/tests/serde.rs index 2adfcc4d..819b4306 100644 --- a/tests/serde.rs +++ b/tests/serde.rs @@ -5,7 +5,7 @@ extern crate alloc; use alloc::string::String; use serde_derive::{Deserialize, Serialize}; -#[derive(Serialize, Deserialize, bincode::Encode, bincode::Decode)] +#[derive(Serialize, Deserialize, bincode::Encode, bincode::Decode, bincode::EncodedSize)] pub struct SerdeRoundtrip { pub a: u32, #[serde(skip)] @@ -13,7 +13,9 @@ pub struct SerdeRoundtrip { pub c: TupleS, } -#[derive(Serialize, Deserialize, bincode::Encode, bincode::Decode, PartialEq, Debug)] +#[derive( + Serialize, Deserialize, bincode::Encode, bincode::Decode, bincode::EncodedSize, PartialEq, Debug, +)] pub struct TupleS(f32, f32, f32); #[test] @@ -32,16 +34,16 @@ fn test_serde_round_trip() { assert_eq!(result.b, 0); // validate bincode working - let bytes = bincode::serde::encode_to_vec( - SerdeRoundtrip { - a: 15, - b: 15, - c: TupleS(2.0, 3.0, 4.0), - }, - bincode::config::standard(), - ) - .unwrap(); + let start = SerdeRoundtrip { + a: 15, + b: 15, + c: TupleS(2.0, 3.0, 4.0), + }; + let encoded_size = bincode::serde::encoded_size(&start, bincode::config::standard()).unwrap(); + let bytes = bincode::serde::encode_to_vec(start, bincode::config::standard()).unwrap(); + assert_eq!(bytes.len(), encoded_size); assert_eq!(bytes, &[15, 0, 0, 0, 64, 0, 0, 64, 64, 0, 0, 128, 64]); + let (result, len): (SerdeRoundtrip, usize) = bincode::serde::decode_from_slice(&bytes, bincode::config::standard()).unwrap(); assert_eq!(result.a, 15); @@ -75,8 +77,10 @@ fn test_serialize_deserialize_borrowed_data() { ]; let mut result = [0u8; 20]; + let encoded_size = bincode::serde::encoded_size(&input, bincode::config::standard()).unwrap(); let len = bincode::serde::encode_into_slice(&input, &mut result, bincode::config::standard()) .unwrap(); + assert_eq!(len, encoded_size); let result = &result[..len]; assert_eq!(result, expected); @@ -120,8 +124,10 @@ fn 
test_serialize_deserialize_owned_data() { ]; let mut result = [0u8; 20]; + let encoded_size = bincode::serde::encoded_size(&input, bincode::config::standard()).unwrap(); let len = bincode::serde::encode_into_slice(&input, &mut result, bincode::config::standard()) .unwrap(); + assert_eq!(len, encoded_size); let result = &result[..len]; assert_eq!(result, expected); @@ -143,7 +149,7 @@ fn test_serialize_deserialize_owned_data() { #[cfg(feature = "derive")] mod derive { - use bincode::{Decode, Encode}; + use bincode::{Decode, Encode, EncodedSize}; use serde_derive::{Deserialize, Serialize}; #[derive(Serialize, Deserialize, PartialEq, Eq, Debug)] @@ -151,13 +157,13 @@ mod derive { pub a: u32, } - #[derive(Decode, Encode, PartialEq, Eq, Debug)] + #[derive(Decode, Encode, EncodedSize, PartialEq, Eq, Debug)] pub struct StructWithSerde { #[bincode(with_serde)] pub serde: SerdeType, } - #[derive(Decode, Encode, PartialEq, Eq, Debug)] + #[derive(Decode, Encode, EncodedSize, PartialEq, Eq, Debug)] pub enum EnumWithSerde { Unit(#[bincode(with_serde)] SerdeType), Struct { @@ -170,12 +176,18 @@ mod derive { fn test_serde_derive() { fn test_encode_decode(start: T, expected_len: usize) where - T: bincode::Encode + bincode::Decode + PartialEq + core::fmt::Debug, + T: bincode::Encode + + bincode::Decode + + bincode::EncodedSize + + PartialEq + + core::fmt::Debug, { let mut slice = [0u8; 100]; + let encoded_size = bincode::encoded_size(&start, bincode::config::standard()).unwrap(); let len = bincode::encode_into_slice(&start, &mut slice, bincode::config::standard()) .unwrap(); assert_eq!(len, expected_len); + assert_eq!(len, encoded_size); let slice = &slice[..len]; let (result, len): (T, usize) = bincode::decode_from_slice(slice, bincode::config::standard()).unwrap(); diff --git a/tests/utils.rs b/tests/utils.rs index 3672ebc0..c4e63ada 100644 --- a/tests/utils.rs +++ b/tests/utils.rs @@ -7,6 +7,7 @@ where CMP: Fn(&V, &V) -> bool, { let mut buffer = [0u8; 2048]; + let 
calculated_size = bincode::encoded_size(&element, config).unwrap(); let len = bincode::encode_into_slice(&element, &mut buffer, config).unwrap(); println!( "{:?} ({}): {:?} ({:?})", @@ -15,6 +16,11 @@ where &buffer[..len], core::any::type_name::() ); + assert_eq!( + calculated_size, len, + "Calculated encoded size does not match actual encoded size\nCalculated: {:?}\nActual: {:?}", + calculated_size, len, + ); let (decoded, decoded_len): (V, usize) = bincode::decode_from_slice(&buffer, config).unwrap(); assert!( @@ -141,13 +147,20 @@ where #[cfg(feature = "serde")] pub trait TheSameTrait: - bincode::Encode + bincode::Decode + serde::de::DeserializeOwned + serde::Serialize + Debug + 'static + bincode::Encode + + bincode::Decode + + bincode::EncodedSize + + serde::de::DeserializeOwned + + serde::Serialize + + Debug + + 'static { } #[cfg(feature = "serde")] impl TheSameTrait for T where T: bincode::Encode + bincode::Decode + + bincode::EncodedSize + serde::de::DeserializeOwned + serde::Serialize + Debug @@ -156,9 +169,15 @@ impl TheSameTrait for T where } #[cfg(not(feature = "serde"))] -pub trait TheSameTrait: bincode::Encode + bincode::Decode + Debug + 'static {} +pub trait TheSameTrait: + bincode::Encode + bincode::Decode + bincode::EncodedSize + Debug + 'static +{ +} #[cfg(not(feature = "serde"))] -impl TheSameTrait for T where T: bincode::Encode + bincode::Decode + Debug + 'static {} +impl TheSameTrait for T where + T: bincode::Encode + bincode::Decode + bincode::EncodedSize + Debug + 'static +{ +} #[allow(dead_code)] // This is not used in every test pub fn the_same(element: V) {