diff --git a/tests/varflags.rs b/tests/varflags.rs index 850b7f0..4772b00 100644 --- a/tests/varflags.rs +++ b/tests/varflags.rs @@ -4,8 +4,8 @@ use varflags::varflags; // Required attributes (added manually by user). #[derive(Clone, Copy, PartialEq, Eq, Debug)] -#[varflags] -enum TestInput { +#[varflags(Clone, Copy, Hash)] +enum TestInput2 { // Representation of the unspecified bits will be calculated A, B, @@ -31,48 +31,52 @@ enum TestInput { fn example() -> Result<(), Box> { use bitworks::prelude::*; - let a = TestInput::A; - let b = TestInput::B; + let a = TestInput2::A; + let b = TestInput2::B; - assert_eq!(TestInput::D as u8, 0b00010000); - assert_eq!(TestInput::E as u8, 0b10000000); - assert_eq!(TestInput::F as u8, 0b01000000); + assert_eq!(TestInput2::D as u8, 0b00010000); + assert_eq!(TestInput2::E as u8, 0b10000000); + assert_eq!(TestInput2::F as u8, 0b01000000); - let c = a | b | TestInput::D; - // EFHDGCBA - assert_eq!(c, TestInputVarflags(Bitset8::new(0b00010011))); + let c = a | b | TestInput2::D; + // EFHDGCBA + assert_eq!(c, TestInput2Varflags(Bitset8::new(0b00010011))); - assert!(c.contains(&TestInput::A)); - assert!(!c.contains(&TestInput::H)); + assert!(c.contains(&TestInput2::A)); + assert!(!c.contains(&TestInput2::H)); - let d = TestInput::A | TestInput::B; - let e = TestInput::A | TestInput::C; + let d = TestInput2::A | TestInput2::B; + let e = TestInput2::A | TestInput2::C; assert!(c.super_set(&d)); assert!(!c.super_set(&e)); - let f = TestInput::F | TestInput::H; + let f = TestInput2::F | TestInput2::H; assert!(c.intersects(&e)); assert!(!c.intersects(&f)); - let x = TestInputVarflags::all(); + let x = TestInput2Varflags::all(); let mut iter = x.variants(); - assert_eq!(iter.next(), Some(TestInput::A)); - assert_eq!(iter.next(), Some(TestInput::B)); - assert_eq!(iter.next(), Some(TestInput::C)); - assert_eq!(iter.next(), Some(TestInput::G)); - assert_eq!(iter.next(), Some(TestInput::D)); - assert_eq!(iter.next(), Some(TestInput::H)); - assert_eq!(iter.next(), Some(TestInput::F)); - assert_eq!(iter.next(), Some(TestInput::E)); + assert_eq!(iter.next(), Some(TestInput2::A)); + assert_eq!(iter.next(), Some(TestInput2::B)); + assert_eq!(iter.next(), Some(TestInput2::C)); + assert_eq!(iter.next(), Some(TestInput2::G)); + assert_eq!(iter.next(), Some(TestInput2::D)); + assert_eq!(iter.next(), Some(TestInput2::H)); + assert_eq!(iter.next(), Some(TestInput2::F)); + assert_eq!(iter.next(), Some(TestInput2::E)); assert_eq!(iter.next(), None); let iter = c.variants(); - let c: TestInputVarflags = iter.collect(); - // EFHDGCBA - assert_eq!(c, TestInputVarflags(Bitset8::new(0b00010011))); + let c: TestInput2Varflags = iter.collect(); + // EFHDGCBA + assert_eq!(c, TestInput2Varflags(Bitset8::new(0b00010011))); + + println!("{c}"); + + println!("{c:?}"); Ok(()) } diff --git a/tests/varflags_spec.rs b/tests/varflags_spec.rs index 01e2dd5..dfa3317 100644 --- a/tests/varflags_spec.rs +++ b/tests/varflags_spec.rs @@ -1,6 +1,6 @@ use std::error::Error; -// Specs for Bitflags attribute +// Specs for Varflags attribute #[rustfmt::skip] @@ -42,10 +42,10 @@ mod test_input_varflags { // Pick appropriate Bitfield and generate Repr depending on the choice. use bitworks::prelude::Bitset8 as Inner; type Repr = u8; - + // Use the enum. use super::TestInput as E; - + // Generated based on number of variants const VAR_COUNT: usize = 8; @@ -98,14 +98,37 @@ mod test_input_varflags { 0b01000000 => Ok(E::F), 0b00001000 => Ok(E::G), 0b00100000 => Ok(E::H), - _ => Err(ConvError::new(ConvTarget::Index(Inner::BYTE_SIZE), ConvTarget::Enum(VAR_COUNT))), + _ => Err(ConvError::new( + ConvTarget::Index(Inner::BYTE_SIZE), + ConvTarget::Enum(VAR_COUNT), + )), } } } - // This struct will be generated with Bitflags appended to enum's name. + // Generate Display for input + impl core::fmt::Display for E { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + write!( + f, + "{}", + match *self { + E::A => "A", + E::B => "B", + E::C => "C", + E::D => "D", + E::E => "E", + E::F => "F", + E::G => "G", + E::H => "H", + } + ) + } + } + + // This struct will be generated with Varflags appended to enum's name. // Should derive Debug, PartialEq and Eq. - #[derive(Debug, PartialEq, Eq)] + #[derive(Clone, Copy, PartialEq, Eq, Hash)] pub struct TestInputVarflags( // Attribute should calculate, that Inner is to be used here. // Can't use ByteField here. Should be private. @@ -247,12 +270,52 @@ mod test_input_varflags { } } - // Generate implementation of FromIterator for Bitflags + // Generate implementation of FromIterator for Varflags impl FromIterator for TestInputVarflags { fn from_iter>(iter: T) -> Self { iter.into_iter().fold(Self::none(), |acc, v| acc | v) } } + + // Generate Debug for Varflags + impl core::fmt::Debug for TestInputVarflags { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + let count = self.variants().count(); + write!( + f, + "TestInputVarflags{{{}}}", + self.variants() + .enumerate() + .fold("".to_owned(), |mut acc, (i, v)| { + acc.push_str(&v.to_string()); + if i != count - 1 { + acc.push_str(", ") + } + acc + }) + ) + } + } + + // Generate Display for Varflags + impl core::fmt::Display for TestInputVarflags { + fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result { + let count = self.variants().count(); + write!( + f, + "{{{}}}", + self.variants() + .enumerate() + .fold("".to_owned(), |mut acc, (i, v)| { + acc.push_str(&v.to_string()); + if i != count - 1 { + acc.push_str(", ") + } + acc + }) + ) + } + } } // Reexport the struct locally. use test_input_varflags::TestInputVarflags; @@ -264,8 +327,12 @@ fn example() -> Result<(), Box> { let a = TestInput::A; let b = TestInput::B; + assert_eq!(TestInput::D as u8, 0b00010000); + assert_eq!(TestInput::E as u8, 0b10000000); + assert_eq!(TestInput::F as u8, 0b01000000); + let c = a | b | TestInput::D; - // EFHDGCBA + // EFHDGCBA assert_eq!(c, TestInputVarflags(Bitset8::new(0b00010011))); assert!(c.contains(&TestInput::A)); @@ -297,8 +364,12 @@ fn example() -> Result<(), Box> { let iter = c.variants(); let c: TestInputVarflags = iter.collect(); - // EFHDGCBA + // EFHDGCBA assert_eq!(c, TestInputVarflags(Bitset8::new(0b00010011))); + println!("{c}"); + + println!("{c:?}"); + Ok(()) } diff --git a/varflags_attribute/src/lib.rs b/varflags_attribute/src/lib.rs index 53761c6..6b942d5 100644 --- a/varflags_attribute/src/lib.rs +++ b/varflags_attribute/src/lib.rs @@ -1,5 +1,7 @@ extern crate proc_macro; +use std::collections::HashSet; + use bitworks::{bitset::Bitset, bitset128::Bitset128}; use proc_macro2::Ident; use quote::quote; @@ -27,17 +29,21 @@ static INVALID_DECL_ERROR: &'static str = "INTERNAL ERROR: variant declaration s static INVALID_MATC_ERROR: &'static str = "INTERNAL ERROR: variant match should be valid"; static INVALID_MOD_ERROR: &'static str = "INTERNAL ERROR: mod name should be valid"; static INVALID_STRUCT_NAME: &'static str = "INTERNAL ERROR: struct name should be valid"; +static MAX_3_ARGS: &'static str = "too many arguments, max 3"; +static ONLY_NAME_ARGS: &'static str = "argument should be a single word"; +static BAD_ARG: &'static str = "bad argument, expected one of: \"Clone\", \"Copy\" or \"Hash\""; +static BAD_ARG_TYPE: &'static str = "wrong argument type, expected name of trait to be derived"; /// Attribute #[proc_macro_attribute] pub fn varflags( - _: proc_macro::TokenStream, + args: proc_macro::TokenStream, item: proc_macro::TokenStream, ) -> proc_macro::TokenStream { - varflags_impl(item) + varflags_impl(args, item) } -fn varflags_impl(item: proc_macro::TokenStream) -> proc_macro::TokenStream { +fn varflags_impl(args: proc_macro::TokenStream, item: proc_macro::TokenStream) -> proc_macro::TokenStream { let ItemEnum { vis, ident, @@ -68,17 +74,21 @@ fn varflags_impl(item: proc_macro::TokenStream) -> proc_macro::TokenStream { let variant_data = parse_variants(variants, count); let variant_declaration = variant_declaration(variant_data.clone(), count); let mod_name = make_mod_name(ident.clone()); - let max_discriminant = *variant_data.discriminants.iter().flatten().max().expect("should have one max discriminant"); + let max_discriminant = *variant_data.discriminants.iter().flatten().max().expect("enum variants should have one biggest discriminant"); let (bitset, repr) = bitset_repr(max_discriminant); - let count_tok: proc_macro2::TokenStream = count.to_string().parse().unwrap(); + let count_tok: proc_macro2::TokenStream = count.to_string().parse().expect("count should be a valid numeric literal token"); let struct_name = make_struct_name(ident.clone()); - let variant_match = variant_match(variant_data, count); + let try_from_match = try_from_match(variant_data.clone(), count); + let struct_name_string = struct_name.to_string(); + let display_match = display_match(variant_data, count); + let args = parse_macro_input!(args with Punctuated::::parse_terminated); + let additional_derives = additional_derives(args); #[allow(unused_mut)] let mut serde_impl = proc_macro2::TokenStream::new(); #[cfg(feature = "serde")] { - serde_impl = "#[derive(serde::Serialize, serde::Deserialize)]".parse().unwrap(); + serde_impl = "#[derive(serde::Serialize, serde::Deserialize)]".parse().expect("derive for serde should be a valid derive attribute token"); } quote! { @@ -138,13 +148,21 @@ fn varflags_impl(item: proc_macro::TokenStream) -> proc_macro::TokenStream { fn try_from(value: Index) -> ConvResult { let n: Repr = 1 << value.into_inner(); match n { - #variant_match + #try_from_match _ => Err(ConvError::new(ConvTarget::Index(Inner::BYTE_SIZE), ConvTarget::Enum(VAR_COUNT))), } } } - #[derive(Clone, Copy, Debug, PartialEq, Eq, Hash)] + impl core::fmt::Display for E { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + write!(f, "{}", match *self { + #display_match + }) + } + } + + #[derive(PartialEq, Eq, #additional_derives)] #serde_impl pub struct #struct_name(pub Inner); @@ -278,6 +296,32 @@ fn varflags_impl(item: proc_macro::TokenStream) -> proc_macro::TokenStream { iter.into_iter().fold(Self::none(), |acc, v| acc | v) } } + + impl core::fmt::Debug for #struct_name { + fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result { + let count = self.variants().count(); + write!(f, "{}{{{}}}", #struct_name_string, self.variants().enumerate().fold("".to_owned(), |mut acc, (i, v)| { + acc.push_str(&v.to_string()); + if i != count - 1 { + acc.push_str(", ") + } + acc + })) + } + } + + impl core::fmt::Display for #struct_name { + fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result { + let count = self.variants().count(); + write!(f, "{{{}}}", self.variants().enumerate().fold("".to_owned(), |mut acc, (i, v)| { + acc.push_str(&v.to_string()); + if i != count - 1 { + acc.push_str(", ") + } + acc + })) + } + } } #vis use #mod_name::#struct_name; @@ -456,12 +500,12 @@ fn bitset_repr(max_discriminant: u128) -> (proc_macro2::TokenStream, proc_macro2 U16_MAX_PLUS_1..=U32_MAX => 32, U32_MAX_PLUS_1..=U64_MAX => 64, U64_MAX_PLUS_1..=u128::MAX => 128, - 0 => panic!("bitset_repr: {NO_ZERO_VAR}"), + 0 => panic!("{NO_ZERO_VAR}"), }; let bitset = format!("Bitset{n}"); let repr = format!("u{n}"); - (bitset.parse().unwrap(), repr.parse().unwrap()) + (bitset.parse().expect("bitset should be a valid type token"), repr.parse().expect("repr should be a valid type token")) } fn make_struct_name(enum_name: Ident) -> proc_macro2::TokenStream { @@ -469,7 +513,7 @@ fn make_struct_name(enum_name: Ident) -> proc_macro2::TokenStream { s.parse().expect(INVALID_STRUCT_NAME) } -fn variant_match(data: VariantData, count: usize) -> proc_macro2::TokenStream { +fn try_from_match(data: VariantData, count: usize) -> proc_macro2::TokenStream { let mut s = "".to_owned(); for i in 0..count { s.push_str(&format!( @@ -480,3 +524,40 @@ fn variant_match(data: VariantData, count: usize) -> proc_macro2::TokenStream { } s.parse().expect(INVALID_MATC_ERROR) } + +fn display_match(data: VariantData, count: usize) -> proc_macro2::TokenStream { + let mut s = "".to_owned(); + for i in 0..count { + s.push_str(&format!( + "E::{} => \"{}\",", + data.idents[i], + data.idents[i] + )); + } + s.parse().expect(INVALID_MATC_ERROR) +} + +fn additional_derives(args: Punctuated) -> proc_macro2::TokenStream { + if args.len() > 3 { + panic!("{MAX_3_ARGS}"); + } + + let mut possible_args = HashSet::<&str>::from(["Clone", "Copy", "Hash"]); + let mut derives = String::new(); + for arg in args { + match arg { + Meta::Path(p) => { + let ident = p.get_ident().expect(ONLY_NAME_ARGS).to_string(); + if possible_args.contains(ident.as_str()) { + possible_args.remove(ident.as_str()); + derives.push_str(&ident); + derives.push_str(", "); + } else { + panic!("{BAD_ARG}"); + } + }, + _ => panic!("{BAD_ARG_TYPE}"), + } + } + derives.parse().expect("derives should be a valid comma separated meta token list") +}