Skip to content

Commit

Permalink
Fix max_encoded_len for Compact fields (#508)
Browse files Browse the repository at this point in the history
* Fix max_encoded_len for Compact fields

* fix

* Add missing test

* nit

* Apply suggestions from code review

Co-authored-by: Bastian Köcher <[email protected]>

* fix ui-test output

---------

Co-authored-by: Bastian Köcher <[email protected]>
  • Loading branch information
pgherveou and bkchr authored Sep 4, 2023
1 parent 1516bb9 commit ddf9439
Show file tree
Hide file tree
Showing 3 changed files with 65 additions and 31 deletions.
51 changes: 29 additions & 22 deletions derive/src/max_encoded_len.rs
Original file line number Diff line number Diff line change
Expand Up @@ -17,10 +17,10 @@

use crate::{
trait_bounds,
utils::{codec_crate_path, custom_mel_trait_bound, has_dumb_trait_bound, should_skip},
utils::{self, codec_crate_path, custom_mel_trait_bound, has_dumb_trait_bound, should_skip},
};
use quote::{quote, quote_spanned};
use syn::{parse_quote, spanned::Spanned, Data, DeriveInput, Fields, Type};
use syn::{parse_quote, spanned::Spanned, Data, DeriveInput, Field, Fields};

/// impl for `#[derive(MaxEncodedLen)]`
pub fn derive_max_encoded_len(input: proc_macro::TokenStream) -> proc_macro::TokenStream {
Expand All @@ -43,13 +43,13 @@ pub fn derive_max_encoded_len(input: proc_macro::TokenStream) -> proc_macro::Tok
parse_quote!(#crate_path::MaxEncodedLen),
None,
has_dumb_trait_bound(&input.attrs),
&crate_path
&crate_path,
) {
return e.to_compile_error().into()
}
let (impl_generics, ty_generics, where_clause) = input.generics.split_for_impl();

let data_expr = data_length_expr(&input.data);
let data_expr = data_length_expr(&input.data, &crate_path);

quote::quote!(
const _: () = {
Expand All @@ -64,22 +64,22 @@ pub fn derive_max_encoded_len(input: proc_macro::TokenStream) -> proc_macro::Tok
}

/// generate an expression to sum up the max encoded length from several fields
fn fields_length_expr(fields: &Fields) -> proc_macro2::TokenStream {
let type_iter: Box<dyn Iterator<Item = &Type>> = match fields {
Fields::Named(ref fields) => Box::new(
fields.named.iter().filter_map(|field| if should_skip(&field.attrs) {
fn fields_length_expr(fields: &Fields, crate_path: &syn::Path) -> proc_macro2::TokenStream {
let fields_iter: Box<dyn Iterator<Item = &Field>> = match fields {
Fields::Named(ref fields) => Box::new(fields.named.iter().filter_map(|field| {
if should_skip(&field.attrs) {
None
} else {
Some(&field.ty)
})
),
Fields::Unnamed(ref fields) => Box::new(
fields.unnamed.iter().filter_map(|field| if should_skip(&field.attrs) {
Some(field)
}
})),
Fields::Unnamed(ref fields) => Box::new(fields.unnamed.iter().filter_map(|field| {
if should_skip(&field.attrs) {
None
} else {
Some(&field.ty)
})
),
Some(field)
}
})),
Fields::Unit => Box::new(std::iter::empty()),
};
// expands to an expression like
Expand All @@ -92,9 +92,16 @@ fn fields_length_expr(fields: &Fields) -> proc_macro2::TokenStream {
// `max_encoded_len` call. This way, if one field's type doesn't implement
// `MaxEncodedLen`, the compiler's error message will underline which field
// caused the issue.
let expansion = type_iter.map(|ty| {
quote_spanned! {
ty.span() => .saturating_add(<#ty>::max_encoded_len())
let expansion = fields_iter.map(|field| {
let ty = &field.ty;
if utils::is_compact(&field) {
quote_spanned! {
ty.span() => .saturating_add(<#crate_path::Compact::<#ty> as #crate_path::MaxEncodedLen>::max_encoded_len())
}
} else {
quote_spanned! {
ty.span() => .saturating_add(<#ty as #crate_path::MaxEncodedLen>::max_encoded_len())
}
}
});
quote! {
Expand All @@ -103,9 +110,9 @@ fn fields_length_expr(fields: &Fields) -> proc_macro2::TokenStream {
}

// generate an expression to sum up the max encoded length of each field
fn data_length_expr(data: &Data) -> proc_macro2::TokenStream {
fn data_length_expr(data: &Data, crate_path: &syn::Path) -> proc_macro2::TokenStream {
match *data {
Data::Struct(ref data) => fields_length_expr(&data.fields),
Data::Struct(ref data) => fields_length_expr(&data.fields, crate_path),
Data::Enum(ref data) => {
// We need an expression expanded for each variant like
//
Expand All @@ -121,7 +128,7 @@ fn data_length_expr(data: &Data) -> proc_macro2::TokenStream {
// Each variant expression's sum is computed the way an equivalent struct's would be.

let expansion = data.variants.iter().map(|variant| {
let variant_expression = fields_length_expr(&variant.fields);
let variant_expression = fields_length_expr(&variant.fields, crate_path);
quote! {
.max(#variant_expression)
}
Expand Down
25 changes: 24 additions & 1 deletion tests/max_encoded_len.rs
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@
//! Tests for MaxEncodedLen derive macro
#![cfg(all(feature = "derive", feature = "max-encoded-len"))]

use parity_scale_codec::{MaxEncodedLen, Compact, Decode, Encode};
use parity_scale_codec::{Compact, Decode, Encode, MaxEncodedLen};

#[derive(Encode, MaxEncodedLen)]
struct Primitives {
Expand Down Expand Up @@ -64,6 +64,29 @@ fn generic_max_length() {
assert_eq!(Generic::<u32>::max_encoded_len(), u32::max_encoded_len() * 2);
}

#[derive(Encode, MaxEncodedLen)]
struct CompactField {
#[codec(compact)]
t: u64,
v: u64,
}

#[test]
fn compact_field_max_length() {
assert_eq!(
CompactField::max_encoded_len(),
Compact::<u64>::max_encoded_len() + u64::max_encoded_len()
);
}

#[derive(Encode, MaxEncodedLen)]
struct CompactStruct(#[codec(compact)] u64);

#[test]
fn compact_struct_max_length() {
assert_eq!(CompactStruct::max_encoded_len(), Compact::<u64>::max_encoded_len());
}

#[derive(Encode, MaxEncodedLen)]
struct TwoGenerics<T, U> {
t: T,
Expand Down
20 changes: 12 additions & 8 deletions tests/max_encoded_len_ui/unsupported_variant.stderr
Original file line number Diff line number Diff line change
@@ -1,12 +1,16 @@
error[E0599]: no function or associated item named `max_encoded_len` found for struct `NotMel` in the current scope
error[E0277]: the trait bound `NotMel: MaxEncodedLen` is not satisfied
--> tests/max_encoded_len_ui/unsupported_variant.rs:8:9
|
4 | struct NotMel;
| ------------- function or associated item `max_encoded_len` not found for this struct
...
8 | NotMel(NotMel),
| ^^^^^^ function or associated item not found in `NotMel`
| ^^^^^^ the trait `MaxEncodedLen` is not implemented for `NotMel`
|
= help: items from traits can only be used if the trait is implemented and in scope
= note: the following trait defines an item `max_encoded_len`, perhaps you need to implement it:
candidate #1: `MaxEncodedLen`
= help: the following other types implement trait `MaxEncodedLen`:
()
(TupleElement0, TupleElement1)
(TupleElement0, TupleElement1, TupleElement2)
(TupleElement0, TupleElement1, TupleElement2, TupleElement3)
(TupleElement0, TupleElement1, TupleElement2, TupleElement3, TupleElement4)
(TupleElement0, TupleElement1, TupleElement2, TupleElement3, TupleElement4, TupleElement5)
(TupleElement0, TupleElement1, TupleElement2, TupleElement3, TupleElement4, TupleElement5, TupleElement6)
(TupleElement0, TupleElement1, TupleElement2, TupleElement3, TupleElement4, TupleElement5, TupleElement6, TupleElement7)
and $N others

0 comments on commit ddf9439

Please sign in to comment.