diff --git a/derive/src/max_encoded_len.rs b/derive/src/max_encoded_len.rs index bf8b20f7..f2db291b 100644 --- a/derive/src/max_encoded_len.rs +++ b/derive/src/max_encoded_len.rs @@ -17,10 +17,10 @@ use crate::{ trait_bounds, - utils::{codec_crate_path, custom_mel_trait_bound, has_dumb_trait_bound, should_skip}, + utils::{self, codec_crate_path, custom_mel_trait_bound, has_dumb_trait_bound, should_skip}, }; use quote::{quote, quote_spanned}; -use syn::{parse_quote, spanned::Spanned, Data, DeriveInput, Fields, Type}; +use syn::{parse_quote, spanned::Spanned, Data, DeriveInput, Field, Fields}; /// impl for `#[derive(MaxEncodedLen)]` pub fn derive_max_encoded_len(input: proc_macro::TokenStream) -> proc_macro::TokenStream { @@ -43,13 +43,13 @@ pub fn derive_max_encoded_len(input: proc_macro::TokenStream) -> proc_macro::Tok parse_quote!(#crate_path::MaxEncodedLen), None, has_dumb_trait_bound(&input.attrs), - &crate_path + &crate_path, ) { return e.to_compile_error().into() } let (impl_generics, ty_generics, where_clause) = input.generics.split_for_impl(); - let data_expr = data_length_expr(&input.data); + let data_expr = data_length_expr(&input.data, &crate_path); quote::quote!( const _: () = { @@ -64,22 +64,22 @@ pub fn derive_max_encoded_len(input: proc_macro::TokenStream) -> proc_macro::Tok } /// generate an expression to sum up the max encoded length from several fields -fn fields_length_expr(fields: &Fields) -> proc_macro2::TokenStream { - let type_iter: Box> = match fields { - Fields::Named(ref fields) => Box::new( - fields.named.iter().filter_map(|field| if should_skip(&field.attrs) { +fn fields_length_expr(fields: &Fields, crate_path: &syn::Path) -> proc_macro2::TokenStream { + let fields_iter: Box> = match fields { + Fields::Named(ref fields) => Box::new(fields.named.iter().filter_map(|field| { + if should_skip(&field.attrs) { None } else { - Some(&field.ty) - }) - ), - Fields::Unnamed(ref fields) => Box::new( - fields.unnamed.iter().filter_map(|field| if should_skip(&field.attrs) { + Some(field) + } + })), + Fields::Unnamed(ref fields) => Box::new(fields.unnamed.iter().filter_map(|field| { + if should_skip(&field.attrs) { None } else { - Some(&field.ty) - }) - ), + Some(field) + } + })), Fields::Unit => Box::new(std::iter::empty()), }; // expands to an expression like @@ -92,9 +92,16 @@ fn fields_length_expr(fields: &Fields) -> proc_macro2::TokenStream { // `max_encoded_len` call. This way, if one field's type doesn't implement // `MaxEncodedLen`, the compiler's error message will underline which field // caused the issue. - let expansion = type_iter.map(|ty| { - quote_spanned! { - ty.span() => .saturating_add(<#ty>::max_encoded_len()) + let expansion = fields_iter.map(|field| { + let ty = &field.ty; + if utils::is_compact(&field) { + quote_spanned! { + ty.span() => .saturating_add(<#crate_path::Compact::<#ty> as #crate_path::MaxEncodedLen>::max_encoded_len()) + } + } else { + quote_spanned! { + ty.span() => .saturating_add(<#ty as #crate_path::MaxEncodedLen>::max_encoded_len()) + } } }); quote! { @@ -103,9 +110,9 @@ fn fields_length_expr(fields: &Fields) -> proc_macro2::TokenStream { } // generate an expression to sum up the max encoded length of each field -fn data_length_expr(data: &Data) -> proc_macro2::TokenStream { +fn data_length_expr(data: &Data, crate_path: &syn::Path) -> proc_macro2::TokenStream { match *data { - Data::Struct(ref data) => fields_length_expr(&data.fields), + Data::Struct(ref data) => fields_length_expr(&data.fields, crate_path), Data::Enum(ref data) => { // We need an expression expanded for each variant like // @@ -121,7 +128,7 @@ fn data_length_expr(data: &Data) -> proc_macro2::TokenStream { // Each variant expression's sum is computed the way an equivalent struct's would be. let expansion = data.variants.iter().map(|variant| { - let variant_expression = fields_length_expr(&variant.fields); + let variant_expression = fields_length_expr(&variant.fields, crate_path); quote! { .max(#variant_expression) } diff --git a/tests/max_encoded_len.rs b/tests/max_encoded_len.rs index b34ec12e..09d71a6f 100644 --- a/tests/max_encoded_len.rs +++ b/tests/max_encoded_len.rs @@ -16,7 +16,7 @@ //! Tests for MaxEncodedLen derive macro #![cfg(all(feature = "derive", feature = "max-encoded-len"))] -use parity_scale_codec::{MaxEncodedLen, Compact, Decode, Encode}; +use parity_scale_codec::{Compact, Decode, Encode, MaxEncodedLen}; #[derive(Encode, MaxEncodedLen)] struct Primitives { @@ -64,6 +64,29 @@ fn generic_max_length() { assert_eq!(Generic::::max_encoded_len(), u32::max_encoded_len() * 2); } +#[derive(Encode, MaxEncodedLen)] +struct CompactField { + #[codec(compact)] + t: u64, + v: u64, +} + +#[test] +fn compact_field_max_length() { + assert_eq!( + CompactField::max_encoded_len(), + Compact::::max_encoded_len() + u64::max_encoded_len() + ); +} + +#[derive(Encode, MaxEncodedLen)] +struct CompactStruct(#[codec(compact)] u64); + +#[test] +fn compact_struct_max_length() { + assert_eq!(CompactStruct::max_encoded_len(), Compact::::max_encoded_len()); +} + #[derive(Encode, MaxEncodedLen)] struct TwoGenerics { t: T, diff --git a/tests/max_encoded_len_ui/unsupported_variant.stderr b/tests/max_encoded_len_ui/unsupported_variant.stderr index bc4acacc..c4a48c60 100644 --- a/tests/max_encoded_len_ui/unsupported_variant.stderr +++ b/tests/max_encoded_len_ui/unsupported_variant.stderr @@ -1,12 +1,16 @@ -error[E0599]: no function or associated item named `max_encoded_len` found for struct `NotMel` in the current scope +error[E0277]: the trait bound `NotMel: MaxEncodedLen` is not satisfied --> tests/max_encoded_len_ui/unsupported_variant.rs:8:9 | -4 | struct NotMel; - | ------------- function or associated item `max_encoded_len` not found for this struct -... 8 | NotMel(NotMel), - | ^^^^^^ function or associated item not found in `NotMel` + | ^^^^^^ the trait `MaxEncodedLen` is not implemented for `NotMel` | - = help: items from traits can only be used if the trait is implemented and in scope - = note: the following trait defines an item `max_encoded_len`, perhaps you need to implement it: - candidate #1: `MaxEncodedLen` + = help: the following other types implement trait `MaxEncodedLen`: + () + (TupleElement0, TupleElement1) + (TupleElement0, TupleElement1, TupleElement2) + (TupleElement0, TupleElement1, TupleElement2, TupleElement3) + (TupleElement0, TupleElement1, TupleElement2, TupleElement3, TupleElement4) + (TupleElement0, TupleElement1, TupleElement2, TupleElement3, TupleElement4, TupleElement5) + (TupleElement0, TupleElement1, TupleElement2, TupleElement3, TupleElement4, TupleElement5, TupleElement6) + (TupleElement0, TupleElement1, TupleElement2, TupleElement3, TupleElement4, TupleElement5, TupleElement6, TupleElement7) + and $N others