From 91e8054269be74f47d8e683b7d604368b29b1f09 Mon Sep 17 00:00:00 2001 From: Anton Parfonov Date: Sat, 14 Dec 2024 04:03:00 +0200 Subject: [PATCH] Add support for #[row(crate = ...)] --- derive/src/lib.rs | 45 ++++++++++++++++++++++++++++++++++++++++++--- src/row.rs | 12 +++++++++--- src/sql/mod.rs | 4 ++-- 3 files changed, 53 insertions(+), 8 deletions(-) diff --git a/derive/src/lib.rs b/derive/src/lib.rs index 73f31a1..677b82d 100644 --- a/derive/src/lib.rs +++ b/derive/src/lib.rs @@ -1,11 +1,49 @@ use proc_macro2::TokenStream; -use quote::quote; +use quote::{quote, ToTokens}; use serde_derive_internals::{ attr::{Container, Default as SerdeDefault, Field}, Ctxt, }; use syn::{parse_macro_input, Data, DataStruct, DeriveInput, Fields}; +struct Attributes { + crate_name: syn::Path, +} + +impl From<&[syn::Attribute]> for Attributes { + fn from(attrs: &[syn::Attribute]) -> Self { + const ATTRIBUTE_NAME: &str = "row"; + const ATTRIBUTE_SYNTAX: &str = "#[row(crate = ...)]"; + + const CRATE_NAME: &str = "crate"; + const DEFAULT_CRATE_NAME: &str = "clickhouse"; + + let mut crate_name = None; + for attr in attrs { + if attr.path().is_ident(ATTRIBUTE_NAME) { + let row = attr.parse_args::().unwrap(); + let syn::Expr::Assign(syn::ExprAssign { left, right, .. }) = row else { + panic!("expected `{}`", ATTRIBUTE_SYNTAX); + }; + if left.to_token_stream().to_string() != CRATE_NAME { + panic!("expected `{}`", ATTRIBUTE_SYNTAX); + } + let syn::Expr::Path(syn::ExprPath { path, .. }) = *right else { + panic!("expected `{}`", ATTRIBUTE_SYNTAX); + }; + crate_name = Some(path); + } + } + let crate_name = crate_name.unwrap_or_else(|| { + syn::Path::from(syn::Ident::new( + DEFAULT_CRATE_NAME, + proc_macro2::Span::call_site(), + )) + }); + Self { crate_name } + } +} + fn column_names(data: &DataStruct, cx: &Ctxt, container: &Container) -> TokenStream { match &data.fields { Fields::Named(fields) => { @@ -36,11 +74,12 @@ fn column_names(data: &DataStruct, cx: &Ctxt, container: &Container) -> TokenStr // TODO: support wrappers `Wrapper(Inner)` and `Wrapper(T)`. // TODO: support the `nested` attribute. // TODO: support the `crate` attribute. -#[proc_macro_derive(Row)] +#[proc_macro_derive(Row, attributes(row))] pub fn row(input: proc_macro::TokenStream) -> proc_macro::TokenStream { let input = parse_macro_input!(input as DeriveInput); let cx = Ctxt::new(); + let Attributes { crate_name } = Attributes::from(input.attrs.as_slice()); let container = Container::from_ast(&cx, &input); let name = input.ident; @@ -56,7 +95,7 @@ pub fn row(input: proc_macro::TokenStream) -> proc_macro::TokenStream { let expanded = quote! { #[automatically_derived] - impl #impl_generics ::clickhouse::Row for #name #ty_generics #where_clause { + impl #impl_generics #crate_name::Row for #name #ty_generics #where_clause { const COLUMN_NAMES: &'static [&'static str] = #column_names; } }; diff --git a/src/row.rs b/src/row.rs index c5ca680..8efa4db 100644 --- a/src/row.rs +++ b/src/row.rs @@ -75,21 +75,21 @@ pub(crate) fn join_column_names() -> Option { #[cfg(test)] mod tests { - // XXX: need for `derive(Row)`. Provide `row(crate = ..)` instead. - use crate as clickhouse; - use clickhouse::Row; + use crate::Row; use super::*; #[test] fn it_grabs_simple_struct() { #[derive(Row)] + #[row(crate = crate)] #[allow(dead_code)] struct Simple1 { one: u32, } #[derive(Row)] + #[row(crate = crate)] #[allow(dead_code)] struct Simple2 { one: u32, @@ -103,6 +103,7 @@ mod tests { #[test] fn it_grabs_mix() { #[derive(Row)] + #[row(crate = crate)] struct SomeRow { _a: u32, } @@ -115,6 +116,7 @@ mod tests { use serde::Serialize; #[derive(Row, Serialize)] + #[row(crate = crate)] #[allow(dead_code)] struct TopLevel { #[serde(rename = "two")] @@ -129,6 +131,7 @@ mod tests { use serde::Serialize; #[derive(Row, Serialize)] + #[row(crate = crate)] #[allow(dead_code)] struct TopLevel { one: u32, @@ -144,6 +147,7 @@ mod tests { use serde::Deserialize; #[derive(Row, Deserialize)] + #[row(crate = crate)] #[allow(dead_code)] struct TopLevel { one: u32, @@ -158,6 +162,7 @@ mod tests { fn it_rejects_other() { #[allow(dead_code)] #[derive(Row)] + #[row(crate = crate)] struct NamedTuple(u32, u32); assert_eq!(join_column_names::(), None); @@ -170,6 +175,7 @@ mod tests { use serde::Serialize; #[derive(Row, Serialize)] + #[row(crate = crate)] #[allow(dead_code)] struct MyRow { r#type: u32, diff --git a/src/sql/mod.rs b/src/sql/mod.rs index 66330f6..465fc77 100644 --- a/src/sql/mod.rs +++ b/src/sql/mod.rs @@ -149,12 +149,11 @@ impl SqlBuilder { mod tests { use super::*; - // XXX: need for `derive(Row)`. Provide `row(crate = ..)` instead. - use crate as clickhouse; use clickhouse_derive::Row; #[allow(unused)] #[derive(Row)] + #[row(crate = crate)] struct Row { a: u32, b: u32, @@ -162,6 +161,7 @@ mod tests { #[allow(unused)] #[derive(Row)] + #[row(crate = crate)] struct Unnamed(u32, u32); #[test]