Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add EncodedSize trait to calculate encoded sizes #609

Closed
wants to merge 1 commit into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
11 changes: 11 additions & 0 deletions derive/src/attribute.rs
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@ pub struct ContainerAttributes {
pub decode_bounds: Option<(String, Literal)>,
pub borrow_decode_bounds: Option<(String, Literal)>,
pub encode_bounds: Option<(String, Literal)>,
pub encoded_size_bounds: Option<(String, Literal)>,
}

impl Default for ContainerAttributes {
Expand All @@ -17,6 +18,7 @@ impl Default for ContainerAttributes {
decode_bounds: None,
encode_bounds: None,
borrow_decode_bounds: None,
encoded_size_bounds: None,
}
}
}
Expand Down Expand Up @@ -76,6 +78,15 @@ impl FromAttribute for ContainerAttributes {
return Err(Error::custom_at("Should be a literal str", val.span()));
}
}
ParsedAttribute::Property(key, val) if key.to_string() == "encoded_size_bounds" => {
let val_string = val.to_string();
if val_string.starts_with('"') && val_string.ends_with('"') {
result.encoded_size_bounds =
Some((val_string[1..val_string.len() - 1].to_string(), val));
} else {
return Err(Error::custom_at("Should be a literal str", val.span()));
}
}
ParsedAttribute::Tag(i) => {
return Err(Error::custom_at("Unknown field attribute", i.span()))
}
Expand Down
116 changes: 116 additions & 0 deletions derive/src/derive_enum.rs
Original file line number Diff line number Diff line change
Expand Up @@ -135,6 +135,122 @@ impl DeriveEnum {
Ok(())
}

pub fn generate_encoded_size(self, generator: &mut Generator) -> Result<()> {
let crate_name = self.attributes.crate_name.as_str();
generator
.impl_for(format!("{}::EncodedSize", crate_name))
.modify_generic_constraints(|generics, where_constraints| {
if let Some((bounds, lit)) =
(self.attributes.encoded_size_bounds.as_ref()).or(self.attributes.bounds.as_ref())
{
where_constraints.clear();
where_constraints
.push_parsed_constraint(bounds)
.map_err(|e| e.with_span(lit.span()))?;
} else {
for g in generics.iter_generics() {
where_constraints
.push_constraint(g, format!("{}::EncodedSize", crate_name))
.unwrap();
}
}
Ok(())
})?
.generate_fn("encoded_size")
.with_generic_deps("__C", [format!("{}::config::Config", crate_name)])
.with_self_arg(FnSelfArg::RefSelf)
.with_return_type(format!(
"core::result::Result<usize, {}::error::EncodeError>",
crate_name
))
.body(|fn_body| {
fn_body.ident_str("match");
fn_body.ident_str("self");
fn_body.group(Delimiter::Brace, |match_body| {
if self.variants.is_empty() {
self.encode_empty_enum_case(match_body)?;
}
for (variant_index, variant) in self.iter_fields() {
// Self::Variant
match_body.ident_str("Self");
match_body.puncts("::");
match_body.ident(variant.name.clone());

// if we have any fields, declare them here
// Self::Variant { a, b, c }
if let Some(delimiter) = variant.fields.delimiter() {
match_body.group(delimiter, |field_body| {
for (idx, field_name) in
variant.fields.names().into_iter().enumerate()
{
if idx != 0 {
field_body.punct(',');
}
field_body.push(
field_name.to_token_tree_with_prefix(TUPLE_FIELD_PREFIX),
);
}
Ok(())
})?;
}

// Arrow
// Self::Variant { a, b, c } =>
match_body.puncts("=>");

// Body of this variant
// Note that the fields are available as locals because of the match destructuring above
// {
// encoder.encode_u32(n)?;
// bincode::Encode::encode(a, encoder)?;
// bincode::Encode::encode(b, encoder)?;
// bincode::Encode::encode(c, encoder)?;
// }
match_body.group(Delimiter::Brace, |body| {
// variant index
body.push_parsed(format!("let mut __encoded_size = <u32 as {}::EncodedSize>::encoded_size::<__C>", crate_name))?;
body.group(Delimiter::Parenthesis, |args| {
args.punct('&');
args.group(Delimiter::Parenthesis, |num| {
num.extend(variant_index);
Ok(())
})?;
Ok(())
})?;
body.punct('?');
body.punct(';');
// If we have any fields, add up all their sizes them all one by one
for field_name in variant.fields.names() {
let attributes = field_name
.attributes()
.get_attribute::<FieldAttributes>()?
.unwrap_or_default();
if attributes.with_serde {
body.push_parsed(format!(
"__encoded_size += {0}::EncodedSize::encoded_size::<__C>(&{0}::serde::Compat({1}))?;",
crate_name,
field_name.to_string_with_prefix(TUPLE_FIELD_PREFIX),
))?;
} else {
body.push_parsed(format!(
"__encoded_size += {0}::EncodedSize::encoded_size::<__C>({1})?;",
crate_name,
field_name.to_string_with_prefix(TUPLE_FIELD_PREFIX),
))?;
}
}
body.push_parsed("Ok(__encoded_size)")?;
Ok(())
})?;
match_body.punct(',');
}
Ok(())
})?;
Ok(())
})?;
Ok(())
}

/// If we're encoding an empty enum, we need to add an empty case in the form of:
/// `_ => core::unreachable!(),`
fn encode_empty_enum_case(&self, builder: &mut StreamBuilder) -> Result {
Expand Down
53 changes: 53 additions & 0 deletions derive/src/derive_struct.rs
Original file line number Diff line number Diff line change
Expand Up @@ -62,6 +62,59 @@ impl DeriveStruct {
Ok(())
}

pub fn generate_encoded_size(self, generator: &mut Generator) -> Result<()> {
let crate_name = &self.attributes.crate_name;
generator
.impl_for(&format!("{}::EncodedSize", crate_name))
.modify_generic_constraints(|generics, where_constraints| {
if let Some((bounds, lit)) =
(self.attributes.encoded_size_bounds.as_ref()).or(self.attributes.bounds.as_ref())
{
where_constraints.clear();
where_constraints
.push_parsed_constraint(bounds)
.map_err(|e| e.with_span(lit.span()))?;
} else {
for g in generics.iter_generics() {
where_constraints
.push_constraint(g, format!("{}::EncodedSize", crate_name))
.unwrap();
}
}
Ok(())
})?
.generate_fn("encoded_size")
.with_generic_deps("__C", [format!("{}::config::Config", crate_name)])
.with_self_arg(virtue::generate::FnSelfArg::RefSelf)
.with_return_type(format!(
"core::result::Result<usize, {}::error::EncodeError>",
crate_name
))
.body(|fn_body| {
fn_body.push_parsed("let mut __encoded_size = 0;")?;
for field in self.fields.names() {
let attributes = field
.attributes()
.get_attribute::<FieldAttributes>()?
.unwrap_or_default();
if attributes.with_serde {
fn_body.push_parsed(format!(
"__encoded_size += {0}::EncodedSize::encoded_size::<__C>(&{0}::serde::Compat(&self.{1}))?;",
crate_name, field
))?;
} else {
fn_body.push_parsed(format!(
"__encoded_size += {}::EncodedSize::encoded_size::<__C>(&self.{})?;",
crate_name, field
))?;
}
}
fn_body.push_parsed("Ok(__encoded_size)")?;
Ok(())
})?;
Ok(())
}

pub fn generate_decode(self, generator: &mut Generator) -> Result<()> {
// Remember to keep this mostly in sync with generate_borrow_decode
let crate_name = &self.attributes.crate_name;
Expand Down
33 changes: 33 additions & 0 deletions derive/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,39 @@ fn derive_encode_inner(input: TokenStream) -> Result<TokenStream> {
generator.finish()
}

#[proc_macro_derive(EncodedSize, attributes(bincode))]
pub fn derive_encoded_size(input: proc_macro::TokenStream) -> proc_macro::TokenStream {
derive_encoded_size_inner(input).unwrap_or_else(|e| e.into_token_stream())
}

fn derive_encoded_size_inner(input: TokenStream) -> Result<TokenStream> {
let parse = Parse::new(input)?;
let (mut generator, attributes, body) = parse.into_generator();
let attributes = attributes
.get_attribute::<ContainerAttributes>()?
.unwrap_or_default();

match body {
Body::Struct(body) => {
derive_struct::DeriveStruct {
fields: body.fields,
attributes,
}
.generate_encoded_size(&mut generator)?;
}
Body::Enum(body) => {
derive_enum::DeriveEnum {
variants: body.variants,
attributes,
}
.generate_encoded_size(&mut generator)?;
}
}

generator.export_to_file("bincode", "EncodedSize");
generator.finish()
}

#[proc_macro_derive(Decode, attributes(bincode))]
pub fn derive_decode(input: proc_macro::TokenStream) -> proc_macro::TokenStream {
derive_decode_inner(input).unwrap_or_else(|e| e.into_token_stream())
Expand Down
4 changes: 2 additions & 2 deletions docs/migration_guide.md
Original file line number Diff line number Diff line change
Expand Up @@ -60,7 +60,7 @@ Then replace the following functions: (`Configuration` is `bincode::config::lega
|||
|`bincode::serialize(T)`|`bincode::serde::encode_to_vec(T, Configuration)`<br />`bincode::serde::encode_into_slice(T, &mut [u8], Configuration)`|
|`bincode::serialize_into(std::io::Write, T)`|`bincode::serde::encode_into_std_write(T, std::io::Write, Configuration)`|
|`bincode::serialized_size(T)`|Currently not implemented|
|`bincode::serialized_size(T)`|`bincode::serde::encoded_size(T, Configuration)`|

## Migrating to `bincode-derive`

Expand Down Expand Up @@ -98,7 +98,7 @@ Then replace the following functions: (`Configuration` is `bincode::config::lega
|||
|`bincode::serialize(T)`|`bincode::encode_to_vec(T, Configuration)`<br />`bincode::encode_into_slice(t: T, &mut [u8], Configuration)`|
|`bincode::serialize_into(std::io::Write, T)`|`bincode::encode_into_std_write(T, std::io::Write, Configuration)`|
|`bincode::serialized_size(T)`|Currently not implemented|
|`bincode::serialized_size(T)`|`bincode::encoded_size(T, Configuration)`|


### Bincode derive and libraries
Expand Down
6 changes: 6 additions & 0 deletions fuzz/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -31,3 +31,9 @@ name = "compat"
path = "fuzz_targets/compat.rs"
test = false
doc = false

[[bin]]
name = "encoded_size"
path = "fuzz_targets/encoded_size.rs"
test = false
doc = false
52 changes: 52 additions & 0 deletions fuzz/fuzz_targets/encoded_size.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,52 @@
#![no_main]
use libfuzzer_sys::fuzz_target;

use std::collections::{BTreeMap, BTreeSet, HashMap, HashSet, VecDeque};
use std::ffi::CString;
use std::net::{IpAddr, Ipv4Addr, Ipv6Addr, SocketAddr, SocketAddrV4, SocketAddrV6};
use std::num::{NonZeroI128, NonZeroI32, NonZeroU128, NonZeroU32};
use std::path::PathBuf;
use std::rc::Rc;
use std::sync::Arc;
use std::time::{Duration, SystemTime};

#[derive(bincode::Decode, bincode::Encode, bincode::EncodedSize, PartialEq, Debug)]
enum AllTypes {
BTreeMap(BTreeMap<u8, u8>),
HashMap(HashMap<u8, u8>),
HashSet(HashSet<u8>),
BTreeSet(BTreeSet<u8>),
VecDeque(VecDeque<AllTypes>),
Vec(Vec<AllTypes>),
String(String),
Box(Box<AllTypes>),
BoxSlice(Box<[AllTypes]>),
Rc(Rc<AllTypes>),
Arc(Arc<AllTypes>),
CString(CString),
SystemTime(SystemTime),
Duration(Duration),
PathBuf(PathBuf),
IpAddr(IpAddr),
Ipv4Addr(Ipv4Addr),
Ipv6Addr(Ipv6Addr),
SocketAddr(SocketAddr),
SocketAddrV4(SocketAddrV4),
SocketAddrV6(SocketAddrV6),
NonZeroU32(NonZeroU32),
NonZeroI32(NonZeroI32),
NonZeroU128(NonZeroU128),
NonZeroI128(NonZeroI128),
// Cow(Cow<'static, [u8]>), Blocked, see comment on decode
}

fuzz_target!(|data: &[u8]| {
let config = bincode::config::standard().with_limit::<1024>();
let result: Result<(AllTypes, _), _> = bincode::decode_from_slice(data, config);

if let Ok((value, _)) = result {
let encoded_size = bincode::encoded_size(&value, config).expect("encoded size");
let encoded = bincode::encode_to_vec(&value, config).expect("round trip");
assert_eq!(encoded_size, encoded.len());
}
});
Loading