Skip to content

Commit

Permalink
feat: Add try_from attribute for FromRow (#1081)
Browse files Browse the repository at this point in the history
  • Loading branch information
zzhengzhuo authored Sep 7, 2022
1 parent 18a76fb commit ddffaa7
Show file tree
Hide file tree
Showing 4 changed files with 154 additions and 16 deletions.
26 changes: 26 additions & 0 deletions sqlx-core/src/from_row.rs
Original file line number Diff line number Diff line change
Expand Up @@ -149,6 +149,32 @@ use crate::row::Row;
/// }
/// }
/// ```
///
/// #### `try_from`
///
/// When your struct contains a field whose type is not matched with the database type,
/// if the field type has an implementation [`TryFrom`] for the database type,
/// you can use the `try_from` attribute to convert the database type to the field type.
/// For example:
///
/// ```rust,ignore
/// #[derive(sqlx::FromRow)]
/// struct User {
/// id: i32,
/// name: String,
/// #[sqlx(try_from = "i64")]
/// bigIntInMySql: u64
/// }
/// ```
///
/// Given a query such as:
///
/// ```sql
/// SELECT id, name, bigIntInMySql FROM users;
/// ```
///
/// In MySql, `BigInt` type matches `i64`, but you can convert it to `u64` by `try_from`.
///
pub trait FromRow<'r, R: Row>: Sized {
fn from_row(row: &'r R) -> Result<Self, Error>;
}
Expand Down
8 changes: 8 additions & 0 deletions sqlx-macros/src/derives/attributes.rs
Original file line number Diff line number Diff line change
Expand Up @@ -71,6 +71,7 @@ pub struct SqlxChildAttributes {
pub rename: Option<String>,
pub default: bool,
pub flatten: bool,
pub try_from: Option<Ident>,
}

pub fn parse_container_attributes(input: &[Attribute]) -> syn::Result<SqlxContainerAttributes> {
Expand Down Expand Up @@ -178,6 +179,7 @@ pub fn parse_container_attributes(input: &[Attribute]) -> syn::Result<SqlxContai
pub fn parse_child_attributes(input: &[Attribute]) -> syn::Result<SqlxChildAttributes> {
let mut rename = None;
let mut default = false;
let mut try_from = None;
let mut flatten = false;

for attr in input.iter().filter(|a| a.path.is_ident("sqlx")) {
Expand All @@ -194,6 +196,11 @@ pub fn parse_child_attributes(input: &[Attribute]) -> syn::Result<SqlxChildAttri
lit: Lit::Str(val),
..
}) if path.is_ident("rename") => try_set!(rename, val.value(), value),
Meta::NameValue(MetaNameValue {
path,
lit: Lit::Str(val),
..
}) if path.is_ident("try_from") => try_set!(try_from, val.parse()?, value),
Meta::Path(path) if path.is_ident("default") => default = true,
Meta::Path(path) if path.is_ident("flatten") => flatten = true,
u => fail!(u, "unexpected attribute"),
Expand All @@ -208,6 +215,7 @@ pub fn parse_child_attributes(input: &[Attribute]) -> syn::Result<SqlxChildAttri
rename,
default,
flatten,
try_from,
})
}

Expand Down
55 changes: 39 additions & 16 deletions sqlx-macros/src/derives/row.rs
Original file line number Diff line number Diff line change
Expand Up @@ -72,22 +72,45 @@ fn expand_derive_from_row_struct(
let attributes = parse_child_attributes(&field.attrs).unwrap();
let ty = &field.ty;

let expr: Expr = if attributes.flatten {
predicates.push(parse_quote!(#ty: ::sqlx::FromRow<#lifetime, R>));
parse_quote!(<#ty as ::sqlx::FromRow<#lifetime, R>>::from_row(row))
} else {
predicates.push(parse_quote!(#ty: ::sqlx::decode::Decode<#lifetime, R::Database>));
predicates.push(parse_quote!(#ty: ::sqlx::types::Type<R::Database>));

let id_s = attributes
.rename
.or_else(|| Some(id.to_string().trim_start_matches("r#").to_owned()))
.map(|s| match container_attributes.rename_all {
Some(pattern) => rename_all(&s, pattern),
None => s,
})
.unwrap();
parse_quote!(row.try_get(#id_s))
let expr: Expr = match (attributes.flatten, attributes.try_from) {
(true, None) => {
predicates.push(parse_quote!(#ty: ::sqlx::FromRow<#lifetime, R>));
parse_quote!(<#ty as ::sqlx::FromRow<#lifetime, R>>::from_row(row))
}
(false, None) => {
predicates
.push(parse_quote!(#ty: ::sqlx::decode::Decode<#lifetime, R::Database>));
predicates.push(parse_quote!(#ty: ::sqlx::types::Type<R::Database>));

let id_s = attributes
.rename
.or_else(|| Some(id.to_string().trim_start_matches("r#").to_owned()))
.map(|s| match container_attributes.rename_all {
Some(pattern) => rename_all(&s, pattern),
None => s,
})
.unwrap();
parse_quote!(row.try_get(#id_s))
}
(true,Some(try_from)) => {
predicates.push(parse_quote!(#try_from: ::sqlx::FromRow<#lifetime, R>));
parse_quote!(<#try_from as ::sqlx::FromRow<#lifetime, R>>::from_row(row).and_then(|v| <#ty as ::std::convert::TryFrom::<#try_from>>::try_from(v).map_err(|e| ::sqlx::Error::ColumnNotFound("FromRow: try_from failed".to_string()))))
}
(false,Some(try_from)) => {
predicates
.push(parse_quote!(#try_from: ::sqlx::decode::Decode<#lifetime, R::Database>));
predicates.push(parse_quote!(#try_from: ::sqlx::types::Type<R::Database>));

let id_s = attributes
.rename
.or_else(|| Some(id.to_string().trim_start_matches("r#").to_owned()))
.map(|s| match container_attributes.rename_all {
Some(pattern) => rename_all(&s, pattern),
None => s,
})
.unwrap();
parse_quote!(row.try_get(#id_s).and_then(|v| <#ty as ::std::convert::TryFrom::<#try_from>>::try_from(v).map_err(|e| ::sqlx::Error::ColumnNotFound("FromRow: try_from failed".to_string()))))
}
};

if attributes.default {
Expand Down
81 changes: 81 additions & 0 deletions tests/mysql/macros.rs
Original file line number Diff line number Diff line change
Expand Up @@ -354,4 +354,85 @@ async fn test_column_override_exact_enum() -> anyhow::Result<()> {
Ok(())
}

#[sqlx_macros::test]
async fn test_try_from_attr_for_native_type() -> anyhow::Result<()> {
#[derive(sqlx::FromRow)]
struct Record {
#[sqlx(try_from = "i64")]
id: u64,
}

let mut conn = new::<MySql>().await?;
let (mut conn, id) = with_test_row(&mut conn).await?;

let record = sqlx::query_as::<_, Record>("select id from tweet")
.fetch_one(&mut conn)
.await?;

assert_eq!(record.id, id.0 as u64);

Ok(())
}

#[sqlx_macros::test]
async fn test_try_from_attr_for_custom_type() -> anyhow::Result<()> {
#[derive(sqlx::FromRow)]
struct Record {
#[sqlx(try_from = "i64")]
id: Id,
}

#[derive(Debug, PartialEq)]
struct Id(i64);
impl std::convert::TryFrom<i64> for Id {
type Error = std::io::Error;
fn try_from(value: i64) -> Result<Self, Self::Error> {
Ok(Id(value))
}
}

let mut conn = new::<MySql>().await?;
let (mut conn, id) = with_test_row(&mut conn).await?;

let record = sqlx::query_as::<_, Record>("select id from tweet")
.fetch_one(&mut conn)
.await?;

assert_eq!(record.id, Id(id.0));

Ok(())
}

#[sqlx_macros::test]
async fn test_try_from_attr_with_flatten() -> anyhow::Result<()> {
#[derive(sqlx::FromRow)]
struct Record {
#[sqlx(try_from = "Id", flatten)]
id: u64,
}

#[derive(Debug, PartialEq, sqlx::FromRow)]
struct Id {
id: i64,
}

impl std::convert::TryFrom<Id> for u64 {
type Error = std::io::Error;
fn try_from(value: Id) -> Result<Self, Self::Error> {
Ok(value.id as u64)
}
}

let mut conn = new::<MySql>().await?;
let (mut conn, id) = with_test_row(&mut conn).await?;

let record = sqlx::query_as::<_, Record>("select id from tweet")
.fetch_one(&mut conn)
.await?;

assert_eq!(record.id, id.0 as u64);

Ok(())
}

// we don't emit bind parameter type-checks for MySQL so testing the overrides is redundant

0 comments on commit ddffaa7

Please sign in to comment.