Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: FromStr derive could support setting the error type #380

Merged
merged 3 commits into from
Nov 25, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
24 changes: 24 additions & 0 deletions strum_macros/src/helpers/metadata.rs
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,8 @@ pub mod kw {
custom_keyword!(serialize_all);
custom_keyword!(use_phf);
custom_keyword!(prefix);
custom_keyword!(parse_err_ty);
custom_keyword!(parse_err_fn);

// enum discriminant metadata
custom_keyword!(derive);
Expand Down Expand Up @@ -51,6 +53,14 @@ pub enum EnumMeta {
kw: kw::prefix,
prefix: LitStr,
},
ParseErrTy {
kw: kw::parse_err_ty,
path: Path,
},
ParseErrFn {
kw: kw::parse_err_fn,
path: Path,
},
}

impl Parse for EnumMeta {
Expand Down Expand Up @@ -80,6 +90,20 @@ impl Parse for EnumMeta {
input.parse::<Token![=]>()?;
let prefix = input.parse()?;
Ok(EnumMeta::Prefix { kw, prefix })
} else if lookahead.peek(kw::parse_err_ty) {
let kw = input.parse::<kw::parse_err_ty>()?;
input.parse::<Token![=]>()?;
let path_str: LitStr = input.parse()?;
let path_tokens = parse_str(&path_str.value())?;
let path = parse2(path_tokens)?;
Ok(EnumMeta::ParseErrTy { kw, path })
} else if lookahead.peek(kw::parse_err_fn) {
let kw = input.parse::<kw::parse_err_fn>()?;
input.parse::<Token![=]>()?;
let path_str: LitStr = input.parse()?;
let path_tokens = parse_str(&path_str.value())?;
let path = parse2(path_tokens)?;
Ok(EnumMeta::ParseErrFn { kw, path })
} else {
Err(lookahead.error())
}
Expand Down
7 changes: 7 additions & 0 deletions strum_macros/src/helpers/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,13 @@ use proc_macro2::Span;
use quote::ToTokens;
use syn::spanned::Spanned;

pub fn missing_parse_err_attr_error() -> syn::Error {
syn::Error::new(
Span::call_site(),
"`parse_err_ty` and `parse_err_fn` attribute is both required.",

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
"`parse_err_ty` and `parse_err_fn` attribute is both required.",
"`parse_err_ty` and `parse_err_fn` attributes are both required.",

)
}

pub fn non_enum_error() -> syn::Error {
syn::Error::new(Span::call_site(), "This macro only supports enums.")
}
Expand Down
20 changes: 20 additions & 0 deletions strum_macros/src/helpers/type_props.rs
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,8 @@ pub trait HasTypeProperties {

#[derive(Clone, Default)]
pub struct StrumTypeProperties {
pub parse_err_ty: Option<Path>,
pub parse_err_fn: Option<Path>,
pub case_style: Option<CaseStyle>,
pub ascii_case_insensitive: bool,
pub crate_module_path: Option<Path>,
Expand All @@ -32,6 +34,8 @@ impl HasTypeProperties for DeriveInput {
let strum_meta = self.get_metadata()?;
let discriminants_meta = self.get_discriminants_metadata()?;

let mut parse_err_ty_kw = None;
let mut parse_err_fn_kw = None;
let mut serialize_all_kw = None;
let mut ascii_case_insensitive_kw = None;
let mut use_phf_kw = None;
Expand Down Expand Up @@ -82,6 +86,22 @@ impl HasTypeProperties for DeriveInput {
prefix_kw = Some(kw);
output.prefix = Some(prefix);
}
EnumMeta::ParseErrTy { path, kw } => {
if let Some(fst_kw) = parse_err_ty_kw {
return Err(occurrence_error(fst_kw, kw, "parse_err_ty"));
}

parse_err_ty_kw = Some(kw);
output.parse_err_ty = Some(path);
}
EnumMeta::ParseErrFn { path, kw } => {
if let Some(fst_kw) = parse_err_fn_kw {
return Err(occurrence_error(fst_kw, kw, "parse_err_fn"));
}

parse_err_fn_kw = Some(kw);
output.parse_err_fn = Some(path);
}
}
}

Expand Down
37 changes: 27 additions & 10 deletions strum_macros/src/macros/strings/from_string.rs
Original file line number Diff line number Diff line change
@@ -1,10 +1,10 @@
use proc_macro2::TokenStream;
use quote::quote;
use syn::{Data, DeriveInput, Fields};
use syn::{parse_quote, Data, DeriveInput, Fields, Path};

use crate::helpers::{
non_enum_error, occurrence_error, HasInnerVariantProperties, HasStrumVariantProperties,
HasTypeProperties,
missing_parse_err_attr_error, non_enum_error, occurrence_error, HasInnerVariantProperties,
HasStrumVariantProperties, HasTypeProperties,
};

pub fn from_string_inner(ast: &DeriveInput) -> syn::Result<TokenStream> {
Expand All @@ -19,9 +19,25 @@ pub fn from_string_inner(ast: &DeriveInput) -> syn::Result<TokenStream> {
let strum_module_path = type_properties.crate_module_path();

let mut default_kw = None;
let mut default =
quote! { ::core::result::Result::Err(#strum_module_path::ParseError::VariantNotFound) };

let (mut default_err_ty, mut default) = match (
type_properties.parse_err_ty,
type_properties.parse_err_fn,
) {
(None, None) => (
quote! { #strum_module_path::ParseError },
quote! { ::core::result::Result::Err(#strum_module_path::ParseError::VariantNotFound) },
),
(Some(ty), Some(f)) => {
let ty_path: Path = parse_quote!(#ty);
let fn_path: Path = parse_quote!(#f);

(
quote! { #ty_path },
quote! { ::core::result::Result::Err(#fn_path(s)) },
)
}
_ => return Err(missing_parse_err_attr_error()),
};
let mut phf_exact_match_arms = Vec::new();
let mut standard_match_arms = Vec::new();
for variant in variants {
Expand All @@ -47,6 +63,7 @@ pub fn from_string_inner(ast: &DeriveInput) -> syn::Result<TokenStream> {
}
}
default_kw = Some(kw);
default_err_ty = quote! { #strum_module_path::ParseError };
default = quote! {
::core::result::Result::Ok(#name::#ident(s.into()))
};
Expand Down Expand Up @@ -146,7 +163,7 @@ pub fn from_string_inner(ast: &DeriveInput) -> syn::Result<TokenStream> {
let from_str = quote! {
#[allow(clippy::use_self)]
impl #impl_generics ::core::str::FromStr for #name #ty_generics #where_clause {
type Err = #strum_module_path::ParseError;
type Err = #default_err_ty;

#[inline]
fn from_str(s: &str) -> ::core::result::Result< #name #ty_generics , <Self as ::core::str::FromStr>::Err> {
Expand All @@ -160,7 +177,7 @@ pub fn from_string_inner(ast: &DeriveInput) -> syn::Result<TokenStream> {
&impl_generics,
&ty_generics,
where_clause,
&strum_module_path,
&default_err_ty,
);

Ok(quote! {
Expand All @@ -186,12 +203,12 @@ fn try_from_str(
impl_generics: &syn::ImplGenerics,
ty_generics: &syn::TypeGenerics,
where_clause: Option<&syn::WhereClause>,
strum_module_path: &syn::Path,
default_err_ty: &TokenStream,
) -> TokenStream {
quote! {
#[allow(clippy::use_self)]
impl #impl_generics ::core::convert::TryFrom<&str> for #name #ty_generics #where_clause {
type Error = #strum_module_path::ParseError;
type Error = #default_err_ty;

#[inline]
fn try_from(s: &str) -> ::core::result::Result< #name #ty_generics , <Self as ::core::convert::TryFrom<&str>>::Error> {
Expand Down
33 changes: 33 additions & 0 deletions strum_tests/tests/from_str.rs
Original file line number Diff line number Diff line change
Expand Up @@ -229,3 +229,36 @@ fn color_default_with_white() {
}
}
}

#[derive(Debug, EnumString)]
#[strum(
parse_err_fn = "some_enum_not_found_err",
parse_err_ty = "CaseCustomParseErrorNotFoundError"
)]
enum CaseCustomParseErrorEnum {
#[strum(serialize = "red")]
Red,
#[strum(serialize = "blue")]
Blue,
}
#[derive(Debug, Clone, Eq, PartialEq, Hash)]
struct CaseCustomParseErrorNotFoundError(String);
impl std::fmt::Display for CaseCustomParseErrorNotFoundError {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
writeln!(f, "not found `{}`", self.0)
}
}
impl std::error::Error for CaseCustomParseErrorNotFoundError {}
fn some_enum_not_found_err(s: &str) -> CaseCustomParseErrorNotFoundError {
CaseCustomParseErrorNotFoundError(s.to_string())
}

#[test]
fn case_custom_parse_error() {
let r = "yellow".parse::<CaseCustomParseErrorEnum>();
assert!(r.is_err());
assert_eq!(
CaseCustomParseErrorNotFoundError("yellow".to_string()),
r.unwrap_err()
);
}