diff --git a/pyo3-macros-backend/src/pyclass.rs b/pyo3-macros-backend/src/pyclass.rs index e913a35724c..a99c04d2249 100644 --- a/pyo3-macros-backend/src/pyclass.rs +++ b/pyo3-macros-backend/src/pyclass.rs @@ -621,6 +621,7 @@ struct PyClassEnumVariantNamedField<'a> { } /// `#[pyo3()]` options for pyclass enum variants +#[derive(Default)] struct EnumVariantPyO3Options { name: Option, constructor: Option, @@ -646,31 +647,33 @@ impl Parse for EnumVariantPyO3Option { impl EnumVariantPyO3Options { fn take_pyo3_options(attrs: &mut Vec) -> Result { - let mut options = EnumVariantPyO3Options { - name: None, - constructor: None, - }; + let mut options = EnumVariantPyO3Options::default(); - for option in take_pyo3_options(attrs)? { - match option { - EnumVariantPyO3Option::Name(name) => { - ensure_spanned!( - options.name.is_none(), - name.span() => "`name` may only be specified once" - ); - options.name = Some(name); - } - EnumVariantPyO3Option::Constructor(constructor) => { + take_pyo3_options(attrs)? + .into_iter() + .try_for_each(|option| options.set_option(option))?; + + Ok(options) + } + + fn set_option(&mut self, option: EnumVariantPyO3Option) -> syn::Result<()> { + macro_rules! set_option { + ($key:ident) => { + { ensure_spanned!( - options.constructor.is_none(), - constructor.span() => "`constructor` may only be specified once" + self.$key.is_none(), + $key.span() => concat!("`", stringify!($key), "` may only be specified once") ); - options.constructor = Some(constructor); + self.$key = Some($key); } - } + }; } - Ok(options) + match option { + EnumVariantPyO3Option::Constructor(constructor) => set_option!(constructor), + EnumVariantPyO3Option::Name(name) => set_option!(name), + } + Ok(()) } } @@ -704,18 +707,24 @@ fn impl_simple_enum( let variants = simple_enum.variants; let pytypeinfo = impl_pytypeinfo(cls, args, None, ctx); + for variant in &variants { + ensure_spanned!(variant.options.constructor.is_none(), variant.options.constructor.span() => "`constructor` can't be used on a simple enum variant"); + } + let (default_repr, default_repr_slot) = { - let variants_repr = variants.iter().map(|variant| { - ensure_spanned!(variant.options.constructor.is_none(), variant.options.constructor.span() => "`constructor` can't be used on a simple enum variant"); - let variant_name = variant.ident; - // Assuming all variants are unit variants because they are the only type we support. - let repr = format!( - "{}.{}", - get_class_python_name(cls, args), - variant.get_python_name(args), - ); - Ok(quote! { #cls::#variant_name => #repr, }) - }).collect::>()?; + let variants_repr = variants + .iter() + .map(|variant| { + let variant_name = variant.ident; + // Assuming all variants are unit variants because they are the only type we support. + let repr = format!( + "{}.{}", + get_class_python_name(cls, args), + variant.get_python_name(args), + ); + Ok(quote! { #cls::#variant_name => #repr, }) + }) + .collect::>()?; let mut repr_impl: syn::ImplItemFn = syn::parse_quote! { fn __pyo3__repr__(&self) -> &'static str { match self {