Skip to content

Commit

Permalink
Add support for #[row(crate = ...)]
Browse files Browse the repository at this point in the history
  • Loading branch information
YBoy-git committed Dec 14, 2024
1 parent 3e81b45 commit 91e8054
Show file tree
Hide file tree
Showing 3 changed files with 53 additions and 8 deletions.
45 changes: 42 additions & 3 deletions derive/src/lib.rs
Original file line number Diff line number Diff line change
@@ -1,11 +1,49 @@
use proc_macro2::TokenStream;
use quote::quote;
use quote::{quote, ToTokens};
use serde_derive_internals::{
attr::{Container, Default as SerdeDefault, Field},
Ctxt,
};
use syn::{parse_macro_input, Data, DataStruct, DeriveInput, Fields};

struct Attributes {
crate_name: syn::Path,
}

impl From<&[syn::Attribute]> for Attributes {
fn from(attrs: &[syn::Attribute]) -> Self {
const ATTRIBUTE_NAME: &str = "row";
const ATTRIBUTE_SYNTAX: &str = "#[row(crate = ...)]";

const CRATE_NAME: &str = "crate";
const DEFAULT_CRATE_NAME: &str = "clickhouse";

let mut crate_name = None;
for attr in attrs {
if attr.path().is_ident(ATTRIBUTE_NAME) {
let row = attr.parse_args::<syn::Expr>().unwrap();
let syn::Expr::Assign(syn::ExprAssign { left, right, .. }) = row else {
panic!("expected `{}`", ATTRIBUTE_SYNTAX);
};
if left.to_token_stream().to_string() != CRATE_NAME {
panic!("expected `{}`", ATTRIBUTE_SYNTAX);
}
let syn::Expr::Path(syn::ExprPath { path, .. }) = *right else {
panic!("expected `{}`", ATTRIBUTE_SYNTAX);
};
crate_name = Some(path);
}
}
let crate_name = crate_name.unwrap_or_else(|| {
syn::Path::from(syn::Ident::new(
DEFAULT_CRATE_NAME,
proc_macro2::Span::call_site(),
))
});
Self { crate_name }
}
}

fn column_names(data: &DataStruct, cx: &Ctxt, container: &Container) -> TokenStream {
match &data.fields {
Fields::Named(fields) => {
Expand Down Expand Up @@ -36,11 +74,12 @@ fn column_names(data: &DataStruct, cx: &Ctxt, container: &Container) -> TokenStr
// TODO: support wrappers `Wrapper(Inner)` and `Wrapper<T>(T)`.
// TODO: support the `nested` attribute.
// TODO: support the `crate` attribute.
#[proc_macro_derive(Row)]
#[proc_macro_derive(Row, attributes(row))]
pub fn row(input: proc_macro::TokenStream) -> proc_macro::TokenStream {
let input = parse_macro_input!(input as DeriveInput);

let cx = Ctxt::new();
let Attributes { crate_name } = Attributes::from(input.attrs.as_slice());
let container = Container::from_ast(&cx, &input);
let name = input.ident;

Expand All @@ -56,7 +95,7 @@ pub fn row(input: proc_macro::TokenStream) -> proc_macro::TokenStream {

let expanded = quote! {
#[automatically_derived]
impl #impl_generics ::clickhouse::Row for #name #ty_generics #where_clause {
impl #impl_generics #crate_name::Row for #name #ty_generics #where_clause {
const COLUMN_NAMES: &'static [&'static str] = #column_names;
}
};
Expand Down
12 changes: 9 additions & 3 deletions src/row.rs
Original file line number Diff line number Diff line change
Expand Up @@ -75,21 +75,21 @@ pub(crate) fn join_column_names<R: Row>() -> Option<String> {

#[cfg(test)]
mod tests {
// XXX: need for `derive(Row)`. Provide `row(crate = ..)` instead.
use crate as clickhouse;
use clickhouse::Row;
use crate::Row;

use super::*;

#[test]
fn it_grabs_simple_struct() {
#[derive(Row)]
#[row(crate = crate)]
#[allow(dead_code)]
struct Simple1 {
one: u32,
}

#[derive(Row)]
#[row(crate = crate)]
#[allow(dead_code)]
struct Simple2 {
one: u32,
Expand All @@ -103,6 +103,7 @@ mod tests {
#[test]
fn it_grabs_mix() {
#[derive(Row)]
#[row(crate = crate)]
struct SomeRow {
_a: u32,
}
Expand All @@ -115,6 +116,7 @@ mod tests {
use serde::Serialize;

#[derive(Row, Serialize)]
#[row(crate = crate)]
#[allow(dead_code)]
struct TopLevel {
#[serde(rename = "two")]
Expand All @@ -129,6 +131,7 @@ mod tests {
use serde::Serialize;

#[derive(Row, Serialize)]
#[row(crate = crate)]
#[allow(dead_code)]
struct TopLevel {
one: u32,
Expand All @@ -144,6 +147,7 @@ mod tests {
use serde::Deserialize;

#[derive(Row, Deserialize)]
#[row(crate = crate)]
#[allow(dead_code)]
struct TopLevel {
one: u32,
Expand All @@ -158,6 +162,7 @@ mod tests {
fn it_rejects_other() {
#[allow(dead_code)]
#[derive(Row)]
#[row(crate = crate)]
struct NamedTuple(u32, u32);

assert_eq!(join_column_names::<u32>(), None);
Expand All @@ -170,6 +175,7 @@ mod tests {
use serde::Serialize;

#[derive(Row, Serialize)]
#[row(crate = crate)]
#[allow(dead_code)]
struct MyRow {
r#type: u32,
Expand Down
4 changes: 2 additions & 2 deletions src/sql/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -149,19 +149,19 @@ impl SqlBuilder {
mod tests {
use super::*;

// XXX: need for `derive(Row)`. Provide `row(crate = ..)` instead.
use crate as clickhouse;
use clickhouse_derive::Row;

#[allow(unused)]
#[derive(Row)]
#[row(crate = crate)]
struct Row {
a: u32,
b: u32,
}

#[allow(unused)]
#[derive(Row)]
#[row(crate = crate)]
struct Unnamed(u32, u32);

#[test]
Expand Down

0 comments on commit 91e8054

Please sign in to comment.