diff --git a/conjure-http/src/client.rs b/conjure-http/src/client.rs index 4dae12ba..3c425c87 100644 --- a/conjure-http/src/client.rs +++ b/conjure-http/src/client.rs @@ -328,7 +328,39 @@ pub trait AsyncDeserializeResponse { fn accept() -> Option; /// Deserializes the response. - async fn deserialize(response: Response) -> Result; + async fn deserialize(response: Response) -> Result + where + R: 'async_trait; +} + +/// A response deserializer which ignores the response and returns `()`. +pub enum UnitResponseDeserializer {} + +impl DeserializeResponse<(), R> for UnitResponseDeserializer { + fn accept() -> Option { + None + } + + fn deserialize(_: Response) -> Result<(), Error> { + Ok(()) + } +} + +#[async_trait] +impl AsyncDeserializeResponse<(), R> for UnitResponseDeserializer +where + R: Send, +{ + fn accept() -> Option { + None + } + + async fn deserialize(_: Response) -> Result<(), Error> + where + R: 'async_trait, + { + Ok(()) + } } /// A response deserializer which acts like a Conjure-generated client would. @@ -356,13 +388,16 @@ where impl AsyncDeserializeResponse for ConjureResponseDeserializer where T: DeserializeOwned, - R: Stream> + 'static + Send, + R: Stream> + Send, { fn accept() -> Option { Some(APPLICATION_JSON) } - async fn deserialize(response: Response) -> Result { + async fn deserialize(response: Response) -> Result + where + R: 'async_trait, + { if response.headers().get(CONTENT_TYPE) != Some(&APPLICATION_JSON) { return Err(Error::internal_safe("invalid response Content-Type")); } @@ -379,10 +414,21 @@ pub trait EncodeHeader { fn encode(value: T) -> Result, Error>; } -/// A header encoder which converts values via their `Display` implementation. -pub enum DisplayHeaderEncoder {} +/// A trait implemented by URL parameter encoders used by custom Conjure client trait +/// implementations. +pub trait EncodeParam { + /// Encodes the value into a sequence of parameters. + /// + /// When used with a path parameter, each returned string will be a separate path component. + /// When used with a query parameter, each returned string will be the value of a separate query + /// entry. + fn encode(value: T) -> Result, Error>; +} + +/// An encoder which converts values via their `Display` implementation. +pub enum DisplayEncoder {} -impl EncodeHeader for DisplayHeaderEncoder +impl EncodeHeader for DisplayEncoder where T: Display, { @@ -393,11 +439,20 @@ where } } -/// A header encoder which converts a sequence of values via their individual `Display` +impl EncodeParam for DisplayEncoder +where + T: Display, +{ + fn encode(value: T) -> Result, Error> { + Ok(vec![value.to_string()]) + } +} + +/// An encoder which converts a sequence of values via their individual `Display` /// implementations. -pub enum DisplaySeqHeaderEncoder {} +pub enum DisplaySeqEncoder {} -impl EncodeHeader for DisplaySeqHeaderEncoder +impl EncodeHeader for DisplaySeqEncoder where T: IntoIterator, U: Display, @@ -410,34 +465,7 @@ where } } -/// A trait implemented by URL parameter encoders used by custom Conjure client trait -/// implementations. -pub trait EncodeParam { - /// Encodes the value into a sequence of parameters. - /// - /// When used with a path parameter, each returned string will be a separate path component. - /// When used with a query parameter, each returned string will be the value of a separate query - /// entry. - fn encode(value: T) -> Result, Error>; -} - -/// A param encoder which converts values via their `Display` implementations. -pub enum DisplayParamEncoder {} - -impl EncodeParam for DisplayParamEncoder -where - T: Display, -{ - fn encode(value: T) -> Result, Error> { - Ok(vec![value.to_string()]) - } -} - -/// A param encoder which converts a sequence of values via their individual `Display` -/// implementations. -pub enum DisplaySeqParamEncoder {} - -impl EncodeParam for DisplaySeqParamEncoder +impl EncodeParam for DisplaySeqEncoder where T: IntoIterator, U: Display, diff --git a/conjure-http/src/private/mod.rs b/conjure-http/src/private/mod.rs index 470898c5..af7c325f 100644 --- a/conjure-http/src/private/mod.rs +++ b/conjure-http/src/private/mod.rs @@ -23,8 +23,10 @@ pub use http::{self, header, Extensions, Method, Request, Response}; pub use pin_utils::pin_mut; pub use std::borrow::Cow; pub use std::boxed::Box; +pub use std::env; pub use std::future::Future; pub use std::iter::Iterator; +pub use std::marker::{Send, Sync}; pub use std::option::Option; pub use std::pin::Pin; pub use std::result::Result; diff --git a/conjure-macros/src/client.rs b/conjure-macros/src/client.rs index ce2c176b..1d88de1a 100644 --- a/conjure-macros/src/client.rs +++ b/conjure-macros/src/client.rs @@ -12,16 +12,17 @@ // See the License for the specific language governing permissions and // limitations under the License. use crate::path::{self, PathComponent}; -use crate::Asyncness; +use crate::{Asyncness, Errors}; use http::HeaderName; use percent_encoding::AsciiSet; -use proc_macro2::{Ident, TokenStream}; +use proc_macro2::{Ident, Span, TokenStream}; use quote::quote; use std::collections::HashMap; use structmeta::StructMeta; +use syn::spanned::Spanned; use syn::{ - parse_macro_input, Error, FnArg, ItemTrait, LitStr, Meta, Pat, ReturnType, TraitItem, - TraitItemFn, Type, + parse_macro_input, Error, FnArg, GenericParam, Generics, ItemTrait, LitStr, Meta, Pat, PatType, + ReturnType, TraitItem, TraitItemFn, Type, Visibility, }; // https://url.spec.whatwg.org/#query-percent-encode-set @@ -56,8 +57,11 @@ pub fn generate( item: proc_macro::TokenStream, ) -> proc_macro::TokenStream { let mut item = parse_macro_input!(item as ItemTrait); - - let client = generate_client(&mut item); + let service = match Service::new(&mut item) { + Ok(service) => service, + Err(e) => return e.into_compile_error().into(), + }; + let client = generate_client(&service); quote! { #item @@ -67,42 +71,63 @@ pub fn generate( .into() } -fn generate_client(trait_: &mut ItemTrait) -> TokenStream { - let vis = &trait_.vis; - let trait_name = &trait_.ident; +fn generate_client(service: &Service) -> TokenStream { + let vis = &service.vis; + let trait_name = &service.name; let type_name = Ident::new(&format!("{}Client", trait_name), trait_name.span()); - let asyncness = match Asyncness::resolve(trait_) { - Ok(asyncness) => asyncness, - Err(e) => return e.into_compile_error(), - }; - - let service_trait = match asyncness { + let service_trait = match service.asyncness { Asyncness::Sync => quote!(Service), Asyncness::Async => quote!(AsyncService), }; - let impl_attrs = match asyncness { + let impl_attrs = match service.asyncness { Asyncness::Sync => quote!(), Asyncness::Async => quote!(#[conjure_http::private::async_trait]), }; - let where_ = match asyncness { - Asyncness::Sync => quote!(C: conjure_http::client::Client), - Asyncness::Async => quote! { - C: conjure_http::client::AsyncClient + Sync + Send, - C::ResponseBody: 'static + Send, - }, - }; + let (_, type_generics, _) = service.generics.split_for_impl(); - let methods = trait_ - .items - .iter_mut() - .filter_map(|item| match item { - TraitItem::Fn(meth) => Some(meth), - _ => None, + let mut impl_generics = service.generics.clone(); + + let client_param = quote!(__C); + impl_generics.params.push(syn::parse2(quote!(__C)).unwrap()); + + let where_clause = impl_generics.make_where_clause(); + let client_trait = match service.asyncness { + Asyncness::Sync => quote!(Client), + Asyncness::Async => quote!(AsyncClient), + }; + let mut client_bindings = vec![]; + if let Some(param) = &service.request_writer_param { + client_bindings.push(quote!(BodyWriter = #param)); + } + if let Some(param) = &service.response_body_param { + client_bindings.push(quote!(ResponseBody = #param)); + } + let extra_client_predicates = match service.asyncness { + Asyncness::Sync => quote!(), + Asyncness::Async => quote!(+ conjure_http::private::Sync + conjure_http::private::Send), + }; + where_clause.predicates.push( + syn::parse2(quote! { + #client_param: conjure_http::client::#client_trait<#(#client_bindings),*> #extra_client_predicates }) - .map(|m| generate_client_method(trait_name, asyncness, m)); + .unwrap(), + ); + if let Asyncness::Async = service.asyncness { + where_clause.predicates.push( + syn::parse2(quote!(#client_param::ResponseBody: 'static + conjure_http::private::Send)) + .unwrap(), + ); + } + + let (impl_generics, _, where_clause) = impl_generics.split_for_impl(); + + let methods = service + .endpoints + .iter() + .map(|endpoint| generate_client_method(&client_param, service, endpoint)); quote! { #vis struct #type_name { @@ -116,8 +141,8 @@ fn generate_client(trait_: &mut ItemTrait) -> TokenStream { } #impl_attrs - impl #trait_name for #type_name - where #where_ + impl #impl_generics #trait_name #type_generics for #type_name<#client_param> + #where_clause { #(#methods)* } @@ -125,72 +150,47 @@ fn generate_client(trait_: &mut ItemTrait) -> TokenStream { } fn generate_client_method( - trait_name: &Ident, - asyncness: Asyncness, - method: &mut TraitItemFn, + client_param: &TokenStream, + service: &Service, + endpoint: &Endpoint, ) -> TokenStream { - let mut endpoint_attrs = method - .attrs - .iter() - .filter(|attr| attr.path().is_ident("endpoint")); - let Some(endpoint_attr) = endpoint_attrs.next() else { - return Error::new_spanned(method, "missing #[endpoint] attribute").into_compile_error(); - }; - let endpoint = match endpoint_attr.parse_args::() { - Ok(endpoint) => endpoint, - Err(e) => return e.into_compile_error(), - }; - let duplicates = endpoint_attrs - .map(|a| Error::new_spanned(a, "duplicate #[endpoint] attribute").into_compile_error()) - .collect::>(); - if !duplicates.is_empty() { - return quote!(#(#duplicates)*); - } - - let async_ = match asyncness { + let async_ = match service.asyncness { Asyncness::Sync => quote!(), Asyncness::Async => quote!(async), }; - let client_trait = match asyncness { + let client_trait = match service.asyncness { Asyncness::Sync => quote!(Client), Asyncness::Async => quote!(AsyncClient), }; - let await_ = match asyncness { + let await_ = match service.asyncness { Asyncness::Sync => quote!(), Asyncness::Async => quote!(.await), }; - let request_args = match method - .sig - .inputs - .iter_mut() - .flat_map(|a| ArgType::new(a).transpose()) - .collect::, _>>() - { - Ok(request_args) => request_args, - Err(e) => return e.into_compile_error(), - }; - - let name = &method.sig.ident; - let args = &method.sig.inputs; - let ret = &method.sig.output; + let name = &endpoint.ident; + let args = endpoint.args.iter().map(|a| { + let ident = a.ident(); + let ty = a.ty(); + quote!(#ident: #ty) + }); + let ret_ty = &endpoint.ret_ty; let request = quote!(__request); let response = quote!(__response); - let http_method = &endpoint.method; + let http_method = &endpoint.params.method; - let create_request = create_request(asyncness, &request, &request_args); - let add_path = add_path(&request, &request_args, &endpoint); - let add_accept = add_accept(asyncness, &request, &endpoint, &method.sig.output); - let add_auth = add_auth(&request, &request_args); - let add_headers = add_headers(&request, &request_args); - let add_endpoint = add_endpoint(trait_name, method, &endpoint, &request); - let handle_response = handle_response(asyncness, &endpoint, &response); + let create_request = create_request(client_param, &request, service, endpoint); + let add_path = add_path(&request, endpoint); + let add_accept = add_accept(client_param, &request, service, endpoint); + let add_auth = add_auth(&request, endpoint); + let add_headers = add_headers(&request, endpoint); + let add_endpoint = add_endpoint(&request, service, endpoint); + let handle_response = handle_response(&response, service, endpoint); quote! { - #async_ fn #name(#args) #ret { + #async_ fn #name(&self #(, #args)*) -> #ret_ty { #create_request *#request.method_mut() = conjure_http::private::Method::#http_method; #add_path @@ -204,13 +204,18 @@ fn generate_client_method( } } -fn create_request(asyncness: Asyncness, request: &TokenStream, args: &[ArgType]) -> TokenStream { - let mut it = args.iter().filter_map(|a| match a { +fn create_request( + client_param: &TokenStream, + request: &TokenStream, + service: &Service, + endpoint: &Endpoint, +) -> TokenStream { + let arg = endpoint.args.iter().find_map(|a| match a { ArgType::Body(arg) => Some(arg), _ => None, }); - let Some(arg) = it.next() else { - let body = match asyncness { + let Some(arg) = arg else { + let body = match service.asyncness { Asyncness::Sync => quote!(RequestBody), Asyncness::Async => quote!(AsyncRequestBody), }; @@ -221,12 +226,7 @@ fn create_request(asyncness: Asyncness, request: &TokenStream, args: &[ArgType]) }; }; - if let Some(arg) = it.next() { - return Error::new_spanned(&arg.ident, "only one #[body] argument allowed") - .into_compile_error(); - } - - let trait_ = match asyncness { + let trait_ = match service.asyncness { Asyncness::Sync => quote!(SerializeRequest), Asyncness::Async => quote!(AsyncSerializeRequest), }; @@ -239,13 +239,13 @@ fn create_request(asyncness: Asyncness, request: &TokenStream, args: &[ArgType]) quote! { let __content_type = < - #serializer as conjure_http::client::#trait_<_, C::BodyWriter> + #serializer as conjure_http::client::#trait_<_, #client_param::BodyWriter> >::content_type(&#ident); let __content_length = < - #serializer as conjure_http::client::#trait_<_, C::BodyWriter> + #serializer as conjure_http::client::#trait_<_, #client_param::BodyWriter> >::content_length(&#ident); let __body = < - #serializer as conjure_http::client::#trait_<_, C::BodyWriter> + #serializer as conjure_http::client::#trait_<_, #client_param::BodyWriter> >::serialize(#ident)?; let mut #request = conjure_http::private::Request::new(__body); @@ -262,16 +262,13 @@ fn create_request(asyncness: Asyncness, request: &TokenStream, args: &[ArgType]) } } -fn add_path( - request: &TokenStream, - request_args: &[ArgType], - endpoint: &EndpointConfig, -) -> TokenStream { +fn add_path(request: &TokenStream, endpoint: &Endpoint) -> TokenStream { let builder = quote!(__path); - let path_writes = add_path_components(&endpoint.path, &builder, request_args); + let path_writes = add_path_components(&builder, endpoint); - let query_params = request_args + let query_params = endpoint + .args .iter() .filter_map(|arg| match arg { ArgType::Query(arg) => Some(arg), @@ -287,17 +284,9 @@ fn add_path( } } -fn add_path_components( - path_lit: &LitStr, - builder: &TokenStream, - request_args: &[ArgType], -) -> TokenStream { - let path = match path::parse(path_lit) { - Ok(path) => path, - Err(e) => return e.into_compile_error(), - }; - - let path_params = request_args +fn add_path_components(builder: &TokenStream, endpoint: &Endpoint) -> TokenStream { + let path_params = endpoint + .args .iter() .filter_map(|a| match a { ArgType::Path(param) => Some((param.ident.to_string(), param)), @@ -307,7 +296,7 @@ fn add_path_components( let mut path_writes = vec![]; let mut literal_buf = String::new(); - for component in path { + for component in &endpoint.path { match component { PathComponent::Literal(lit) => { literal_buf.push('/'); @@ -323,20 +312,11 @@ fn add_path_components( literal_buf = String::new(); } - let Some(param) = path_params.get(¶m) else { - path_writes.push( - Error::new_spanned( - path_lit, - format_args!("invalid path parameter `{param}`"), - ) - .into_compile_error(), - ); - continue; - }; + let param = path_params[param]; let ident = ¶m.ident; let encoder = param.attr.encoder.as_ref().map_or_else( - || quote!(conjure_http::client::DisplayParamEncoder), + || quote!(conjure_http::client::DisplayEncoder), |e| quote!(#e), ); @@ -366,7 +346,7 @@ fn add_query_arg(builder: &TokenStream, arg: &Arg) -> TokenStream { let name = percent_encoding::percent_encode(arg.attr.name.value().as_bytes(), COMPONENT).to_string(); let encoder = arg.attr.encoder.as_ref().map_or_else( - || quote!(conjure_http::client::DisplayParamEncoder), + || quote!(conjure_http::client::DisplayEncoder), |e| quote!(#e), ); @@ -379,29 +359,26 @@ fn add_query_arg(builder: &TokenStream, arg: &Arg) -> TokenStream { } fn add_accept( - asyncness: Asyncness, + client_param: &TokenStream, request: &TokenStream, - endpoint: &EndpointConfig, - ret_ty: &ReturnType, + service: &Service, + endpoint: &Endpoint, ) -> TokenStream { - let Some(accept) = &endpoint.accept else { + let Some(accept) = &endpoint.params.accept else { return quote!(); }; - let trait_ = match asyncness { + let trait_ = match service.asyncness { Asyncness::Sync => quote!(DeserializeResponse), Asyncness::Async => quote!(AsyncDeserializeResponse), }; - let ret = match ret_ty { - ReturnType::Default => quote!(()), - ReturnType::Type(_, ty) => quote!(#ty), - }; + let ret_ty = &endpoint.ret_ty; quote! { let __accept = <#accept as conjure_http::client::#trait_< - <#ret as conjure_http::private::ExtractOk>::Ok, - C::ResponseBody, + <#ret_ty as conjure_http::private::ExtractOk>::Ok, + #client_param::ResponseBody, >>::accept(); if let Some(__accept) = __accept { #request.headers_mut().insert(conjure_http::private::header::ACCEPT, __accept); @@ -409,8 +386,8 @@ fn add_accept( } } -fn add_auth(request: &TokenStream, args: &[ArgType]) -> TokenStream { - let mut it = args.iter().filter_map(|a| match a { +fn add_auth(request: &TokenStream, endpoint: &Endpoint) -> TokenStream { + let mut it = endpoint.args.iter().filter_map(|a| match a { ArgType::Auth(auth) => Some(auth), _ => None, }); @@ -438,8 +415,9 @@ fn add_auth(request: &TokenStream, args: &[ArgType]) -> TokenStream { } } -fn add_headers(request: &TokenStream, args: &[ArgType]) -> TokenStream { - let add_headers = args +fn add_headers(request: &TokenStream, endpoint: &Endpoint) -> TokenStream { + let add_headers = endpoint + .args .iter() .filter_map(|arg| match arg { ArgType::Header(arg) => Some(arg), @@ -453,15 +431,10 @@ fn add_headers(request: &TokenStream, args: &[ArgType]) -> TokenStream { } fn add_header(request: &TokenStream, arg: &Arg) -> TokenStream { - let header_name = match arg.attr.name.value().parse::() { - Ok(header_name) => header_name, - Err(e) => return Error::new_spanned(&arg.attr.name, e).into_compile_error(), - }; - let ident = &arg.ident; - let name = header_name.as_str(); + let name = arg.attr.name.value().to_ascii_lowercase(); let encoder = arg.attr.encoder.as_ref().map_or_else( - || quote!(conjure_http::client::DisplayHeaderEncoder), + || quote!(conjure_http::client::DisplayEncoder), |v| quote!(#v), ); @@ -476,52 +449,283 @@ fn add_header(request: &TokenStream, arg: &Arg) -> TokenStream { } } -fn add_endpoint( - trait_name: &Ident, - method: &TraitItemFn, - endpoint: &EndpointConfig, - request: &TokenStream, -) -> TokenStream { - let service = trait_name.to_string(); - let name = method.sig.ident.to_string(); - let path = &endpoint.path; +fn add_endpoint(request: &TokenStream, service: &Service, endpoint: &Endpoint) -> TokenStream { + let service = service.name.to_string(); + let name = endpoint.ident.to_string(); + let path = &endpoint.params.path; quote! { #request.extensions_mut().insert(conjure_http::client::Endpoint::new( #service, - std::option::Option::Some(std::env!("CARGO_PKG_VERSION")), + conjure_http::private::Option::Some(conjure_http::private::env!("CARGO_PKG_VERSION")), #name, #path, )); } } -fn handle_response( +fn handle_response(response: &TokenStream, service: &Service, endpoint: &Endpoint) -> TokenStream { + let accept = endpoint.params.accept.as_ref().map_or_else( + || quote!(conjure_http::client::UnitResponseDeserializer), + |t| quote!(#t), + ); + let trait_ = match service.asyncness { + Asyncness::Sync => quote!(DeserializeResponse), + Asyncness::Async => quote!(AsyncDeserializeResponse), + }; + let await_ = match service.asyncness { + Asyncness::Sync => quote!(), + Asyncness::Async => quote!(.await), + }; + + quote! { + <#accept as conjure_http::client::#trait_<_, _>>::deserialize(#response) #await_ + } +} + +struct Service { + vis: Visibility, + name: Ident, + generics: Generics, + request_writer_param: Option, + response_body_param: Option, asyncness: Asyncness, - endpoint: &EndpointConfig, - response: &TokenStream, -) -> TokenStream { - match &endpoint.accept { - Some(accept) => { - let trait_ = match asyncness { - Asyncness::Sync => quote!(DeserializeResponse), - Asyncness::Async => quote!(AsyncDeserializeResponse), + endpoints: Vec, +} + +impl Service { + fn new(trait_: &mut ItemTrait) -> Result { + let mut errors = Errors::new(); + let mut endpoints = vec![]; + for item in &trait_.items { + match Endpoint::new(item) { + Ok(endpoint) => endpoints.push(endpoint), + Err(e) => errors.push(e), + } + } + + let asyncness = match Asyncness::resolve(trait_) { + Ok(asyncness) => Some(asyncness), + Err(e) => { + errors.push(e); + None + } + }; + + let mut request_writer_param = None; + let mut response_body_param = None; + for param in &trait_.generics.params { + let GenericParam::Type(param) = param else { + errors.push(Error::new_spanned(param, "unexpected parameter")); + continue; }; - let await_ = match asyncness { - Asyncness::Sync => quote!(), - Asyncness::Async => quote!(.await), + + for attr in ¶m.attrs { + if attr.path().is_ident("request_writer") { + request_writer_param = Some(param.ident.clone()); + } else if attr.path().is_ident("response_body") { + response_body_param = Some(param.ident.clone()); + } + } + } + + strip_trait(trait_); + errors.build()?; + + Ok(Service { + vis: trait_.vis.clone(), + name: trait_.ident.clone(), + generics: trait_.generics.clone(), + request_writer_param, + response_body_param, + asyncness: asyncness.unwrap(), + endpoints, + }) + } +} + +// Rust doesn't support helper attributes in attribute macros so we need to manually strip them out +fn strip_trait(trait_: &mut ItemTrait) { + for param in &mut trait_.generics.params { + strip_param(param); + } + + for item in &mut trait_.items { + if let TraitItem::Fn(fn_) = item { + strip_fn(fn_); + } + } +} + +fn strip_param(param: &mut GenericParam) { + let GenericParam::Type(param) = param else { + return; + }; + + param.attrs.retain(|attr| { + !["request_writer", "response_body"] + .iter() + .any(|v| attr.path().is_ident(v)) + }); +} + +fn strip_fn(fn_: &mut TraitItemFn) { + for arg in &mut fn_.sig.inputs { + strip_arg(arg); + } +} + +fn strip_arg(arg: &mut FnArg) { + let FnArg::Typed(arg) = arg else { return }; + + arg.attrs.retain(|attr| { + !["path", "query", "header", "body", "auth"] + .iter() + .any(|v| attr.path().is_ident(v)) + }); +} + +struct Endpoint { + ident: Ident, + args: Vec, + ret_ty: Type, + params: EndpointParams, + path: Vec, +} + +impl Endpoint { + fn new(item: &TraitItem) -> Result { + let TraitItem::Fn(item) = item else { + return Err(Error::new_spanned( + item, + "Conjure traits may only contain methods", + )); + }; + + let mut errors = Errors::new(); + + let mut endpoint_attrs = item + .attrs + .iter() + .filter(|attr| attr.path().is_ident("endpoint")); + let params = endpoint_attrs + .next() + .ok_or_else(|| Error::new_spanned(item, "missing #[endpoint] attribute")) + .and_then(|a| a.parse_args::()); + let params = match params { + Ok(params) => Some(params), + Err(e) => { + errors.push(e); + None + } + }; + + let mut args = vec![]; + for arg in &item.sig.inputs { + // Ignore the self arg. + let FnArg::Typed(arg) = arg else { continue }; + + match ArgType::new(arg) { + Ok(arg) => args.push(arg), + Err(e) => errors.push(e), + } + } + + let ret_ty = match &item.sig.output { + ReturnType::Default => { + errors.push(Error::new_spanned( + &item.sig.output, + "expected a return type", + )); + None + } + ReturnType::Type(_, ty) => Some((**ty).clone()), + }; + + let path = match params.as_ref().map(|p| path::parse(&p.path)).transpose() { + Ok(path) => path, + Err(e) => { + errors.push(e); + None + } + }; + + if let Err(e) = validate_args(&args, params.as_ref().map(|p| &p.path), path.as_deref()) { + errors.push(e); + } + + errors.build()?; + + Ok(Endpoint { + ident: item.sig.ident.clone(), + args, + ret_ty: ret_ty.unwrap(), + params: params.unwrap(), + path: path.unwrap(), + }) + } +} + +fn validate_args( + args: &[ArgType], + path: Option<&LitStr>, + path_components: Option<&[PathComponent]>, +) -> Result<(), Error> { + let mut errors = Errors::new(); + + let mut body_args = args.iter().filter(|a| matches!(a, ArgType::Body(_))); + if body_args.next().is_some() { + for arg in body_args { + errors.push(Error::new(arg.span(), "duplicate `#[body]` arg")); + } + } + + let mut auth_args = args.iter().filter(|a| matches!(a, ArgType::Auth(_))); + if auth_args.next().is_some() { + for arg in auth_args { + errors.push(Error::new(arg.span(), "duplicate `#[auth]` arg")); + } + } + + for arg in args { + let ArgType::Header(arg) = arg else { continue }; + if let Err(e) = arg.attr.name.value().parse::() { + errors.push(Error::new(arg.span, e)); + } + } + + if let (Some(path), Some(path_components)) = (path, path_components) { + let mut path_params = args + .iter() + .filter_map(|a| match a { + ArgType::Path(arg) => Some((arg.ident.to_string(), arg.span)), + _ => None, + }) + .collect::>(); + + for component in path_components { + let PathComponent::Parameter(param) = component else { + continue; }; - quote! { - <#accept as conjure_http::client::#trait_<_, _>>::deserialize(#response) #await_ + if path_params.remove(param).is_none() { + errors.push(Error::new_spanned( + path, + format!("invalid path parameter `{param}`"), + )); } } - None => quote!(conjure_http::private::Result::Ok(())), + + for span in path_params.values() { + errors.push(Error::new(*span, "unused path parameter")); + } } + + errors.build() } #[derive(StructMeta)] -struct EndpointConfig { +struct EndpointParams { method: Ident, path: LitStr, accept: Option, @@ -535,8 +739,42 @@ enum ArgType { Body(Arg), } +impl ArgType { + fn ident(&self) -> &Ident { + match self { + ArgType::Path(arg) => &arg.ident, + ArgType::Query(arg) => &arg.ident, + ArgType::Header(arg) => &arg.ident, + ArgType::Auth(arg) => &arg.ident, + ArgType::Body(arg) => &arg.ident, + } + } + + fn ty(&self) -> &Type { + match self { + ArgType::Path(arg) => &arg.ty, + ArgType::Query(arg) => &arg.ty, + ArgType::Header(arg) => &arg.ty, + ArgType::Auth(arg) => &arg.ty, + ArgType::Body(arg) => &arg.ty, + } + } + + fn span(&self) -> Span { + match self { + ArgType::Path(arg) => arg.span, + ArgType::Query(arg) => arg.span, + ArgType::Header(arg) => arg.span, + ArgType::Auth(arg) => arg.span, + ArgType::Body(arg) => arg.span, + } + } +} + struct Arg { ident: Ident, + ty: Type, + span: Span, attr: T, } @@ -562,27 +800,17 @@ struct BodyAttr { } impl ArgType { - fn new(arg: &mut FnArg) -> syn::Result> { - // Ignore the self arg. - let FnArg::Typed(pat_type) = arg else { - return Ok(None); - }; - + fn new(arg: &PatType) -> syn::Result { // FIXME we should probably just rename the arguments in our impl? - let ident = match &*pat_type.pat { + let ident = match &*arg.pat { Pat::Ident(pat_ident) => &pat_ident.ident, - _ => { - return Err(Error::new_spanned( - &pat_type.pat, - "expected an ident pattern", - )) - } + _ => return Err(Error::new_spanned(&arg.pat, "expected an ident pattern")), }; let mut type_ = None; // FIXME detect multiple attrs - for attr in &pat_type.attrs { + for attr in &arg.attrs { if attr.path().is_ident("path") { let attr = match attr.meta { Meta::Path(_) => PathAttr { encoder: None }, @@ -590,16 +818,22 @@ impl ArgType { }; type_ = Some(ArgType::Path(Arg { ident: ident.clone(), + ty: (*arg.ty).clone(), + span: arg.span(), attr, })); } else if attr.path().is_ident("query") { type_ = Some(ArgType::Query(Arg { ident: ident.clone(), + ty: (*arg.ty).clone(), + span: arg.span(), attr: attr.parse_args()?, })); } else if attr.path().is_ident("header") { type_ = Some(ArgType::Header(Arg { ident: ident.clone(), + ty: (*arg.ty).clone(), + span: arg.span(), attr: attr.parse_args()?, })); } else if attr.path().is_ident("auth") { @@ -609,6 +843,8 @@ impl ArgType { }; type_ = Some(ArgType::Auth(Arg { ident: ident.clone(), + ty: (*arg.ty).clone(), + span: arg.span(), attr, })); } else if attr.path().is_ident("body") { @@ -618,27 +854,13 @@ impl ArgType { }; type_ = Some(ArgType::Body(Arg { ident: ident.clone(), + ty: (*arg.ty).clone(), + span: arg.span(), attr, })); } } - // Rust doesn't support "helper" attributes in attribute macros, so we need to strip out our - // helper attributes on arguments. - strip_arg_attrs(arg); - - type_ - .ok_or_else(|| Error::new_spanned(arg, "missing argument type annotation")) - .map(Some) + type_.ok_or_else(|| Error::new_spanned(arg, "missing parameter type attribute")) } } - -fn strip_arg_attrs(arg: &mut FnArg) { - let FnArg::Typed(arg) = arg else { return }; - - arg.attrs.retain(|attr| { - !["path", "query", "header", "body", "auth"] - .iter() - .any(|v| attr.path().is_ident(v)) - }); -} diff --git a/conjure-macros/src/endpoints.rs b/conjure-macros/src/endpoints.rs index 90120257..0aa53d6c 100644 --- a/conjure-macros/src/endpoints.rs +++ b/conjure-macros/src/endpoints.rs @@ -123,7 +123,9 @@ fn generate_endpoints(service: &Service) -> TokenStream { &self, runtime: &conjure_http::private::Arc, ) -> conjure_http::private::Vec + Sync + Send, + dyn conjure_http::server::#endpoint_trait<#request_body, #response_writer> + + conjure_http::private::Sync + + conjure_http::private::Send, >> { #(#endpoints)* @@ -171,8 +173,13 @@ fn impl_params(service: &Service) -> ImplParams { let where_clause = impl_generics.make_where_clause(); where_clause.predicates.push( - syn::parse2(quote!(#trait_impl: #trait_name #type_generics + 'static + Sync + Send)) - .unwrap(), + syn::parse2(quote! { + #trait_impl: #trait_name #type_generics + + 'static + + conjure_http::private::Sync + + conjure_http::private::Send + }) + .unwrap(), ); let input_bounds = input_bounds(service); where_clause @@ -197,7 +204,11 @@ fn input_bounds(service: &Service) -> TokenStream { match service.asyncness { Asyncness::Sync => quote!(conjure_http::private::Iterator), - Asyncness::Async => quote!(conjure_http::private::Stream + Sync + Send), + Asyncness::Async => quote! { + conjure_http::private::Stream + + conjure_http::private::Sync + + conjure_http::private::Send + }, } } diff --git a/conjure-macros/src/lib.rs b/conjure-macros/src/lib.rs index 0e2eabc9..655ed3b0 100644 --- a/conjure-macros/src/lib.rs +++ b/conjure-macros/src/lib.rs @@ -44,6 +44,11 @@ mod path; /// For a trait named `MyService`, the macro will create a type named `MyServiceClient` which /// implements the Conjure `Client` and `MyService` traits. /// +/// # Parameters +/// +/// The trait can optionally be declared generic over the request body and response writer types by +/// using the `#[request_writer]` and `#[response_body]` annotations on the type parameters. +/// /// # Endpoints /// /// Each method corresponds to a separate HTTP endpoint, and is expected to take `&self` and return @@ -90,9 +95,8 @@ mod path; /// /// # Async /// -/// Both blocking and async clients are supported. For technical reasons, async trait -/// implementations must put the `#[conjure_client]` annotation *above* the `#[async_trait]` -/// annotation. +/// Both blocking and async clients are supported. For technical reasons, async trait definitions +/// must put the `#[conjure_client]` annotation *above* the `#[async_trait]` annotation. /// /// # Examples /// @@ -101,10 +105,13 @@ mod path; /// use conjure_error::Error; /// use conjure_http::{conjure_client, endpoint}; /// use conjure_http::client::{ -/// AsyncClient, AsyncService, Client, ConjureResponseDeserializer, DisplaySeqParamEncoder, -/// Service, +/// AsyncClient, AsyncService, Client, ConjureResponseDeserializer, DeserializeResponse, +/// DisplaySeqEncoder, RequestBody, SerializeRequest, Service, WriteBody, /// }; /// use conjure_object::BearerToken; +/// use http::Response; +/// use http::header::HeaderValue; +/// use std::io::Write; /// /// #[conjure_client] /// trait MyService { @@ -115,7 +122,7 @@ mod path; /// fn create_yak( /// &self, /// #[auth] auth_token: &BearerToken, -/// #[query(name = "parentName", encoder = DisplaySeqParamEncoder)] parent_id: Option<&str>, +/// #[query(name = "parentName", encoder = DisplaySeqEncoder)] parent_id: Option<&str>, /// #[body] yak: &str, /// ) -> Result<(), Error>; /// } @@ -141,7 +148,7 @@ mod path; /// async fn create_yak( /// &self, /// #[auth] auth_token: &BearerToken, -/// #[query(name = "parentName", encoder = DisplaySeqParamEncoder)] parent_id: Option<&str>, +/// #[query(name = "parentName", encoder = DisplaySeqEncoder)] parent_id: Option<&str>, /// #[body] yak: &str, /// ) -> Result<(), Error>; /// } @@ -156,6 +163,64 @@ mod path; /// /// Ok(()) /// } +/// +/// #[conjure_client] +/// trait MyStreamingService<#[response_body] I, #[request_writer] O> +/// where +/// O: Write, +/// { +/// #[endpoint(method = POST, path = "/streamData")] +/// fn upload_stream( +/// &self, +/// #[body(serializer = StreamingRequestSerializer)] body: &mut StreamingRequest, +/// ) -> Result<(), Error>; +/// +/// #[endpoint(method = GET, path = "/streamData", accept = StreamingResponseDeserializer)] +/// fn download_stream(&self) -> Result; +/// } +/// +/// struct StreamingRequest; +/// +/// impl WriteBody for StreamingRequest +/// where +/// W: Write, +/// { +/// fn write_body(&mut self, w: &mut W) -> Result<(), Error> { +/// // ... +/// Ok(()) +/// } +/// +/// fn reset(&mut self) -> bool { +/// true +/// } +/// } +/// +/// enum StreamingRequestSerializer {} +/// +/// impl<'a, W> SerializeRequest<'a, &'a mut StreamingRequest, W> for StreamingRequestSerializer +/// where +/// W: Write, +/// { +/// fn content_type(_: &&mut StreamingRequest) -> HeaderValue { +/// HeaderValue::from_static("text/plain") +/// } +/// +/// fn serialize(value: &'a mut StreamingRequest) -> Result, Error> { +/// Ok(RequestBody::Streaming(value)) +/// } +/// } +/// +/// enum StreamingResponseDeserializer {} +/// +/// impl DeserializeResponse for StreamingResponseDeserializer { +/// fn accept() -> Option { +/// None +/// } +/// +/// fn deserialize(response: Response) -> Result { +/// Ok(response.into_body()) +/// } +/// } /// ``` #[proc_macro_attribute] pub fn conjure_client(attr: TokenStream, item: TokenStream) -> TokenStream { @@ -221,9 +286,8 @@ pub fn conjure_client(attr: TokenStream, item: TokenStream) -> TokenStream { /// /// # Async /// -/// Both blocking and async services are supported. For technical reasons, async trait -/// implementations must put the `#[conjure_endpoints]` annotation *above* the `#[async_trait]` -/// annotation. +/// Both blocking and async services are supported. For technical reasons, async trait definitions +/// must put the `#[conjure_endpoints]` annotation *above* the `#[async_trait]` annotation. /// /// # Examples /// diff --git a/conjure-test/src/test/clients.rs b/conjure-test/src/test/clients.rs index 2033d8e7..48315ee9 100644 --- a/conjure-test/src/test/clients.rs +++ b/conjure-test/src/test/clients.rs @@ -18,15 +18,16 @@ use async_trait::async_trait; use conjure_error::Error; use conjure_http::client::{ AsyncClient, AsyncRequestBody, AsyncService, AsyncWriteBody, Client, - ConjureResponseDeserializer, DisplaySeqHeaderEncoder, DisplaySeqParamEncoder, RequestBody, - Service, WriteBody, + ConjureResponseDeserializer, DeserializeResponse, DisplaySeqEncoder, RequestBody, + SerializeRequest, Service, WriteBody, }; use conjure_macros::{conjure_client, endpoint}; use conjure_object::{BearerToken, ResourceIdentifier}; use futures::executor; use http::header::CONTENT_TYPE; -use http::{HeaderMap, Method, Request, Response, StatusCode}; +use http::{HeaderMap, HeaderValue, Method, Request, Response, StatusCode}; use std::collections::{BTreeMap, BTreeSet}; +use std::io::Write; use std::pin::Pin; struct StreamingBody<'a>(&'a [u8]); @@ -221,21 +222,22 @@ trait CustomService { fn query_param( &self, #[query(name = "normal")] normal: &str, - #[query(name = "list", encoder = DisplaySeqParamEncoder)] list: &[i32], + #[query(name = "list", encoder = DisplaySeqEncoder)] list: &[i32], ) -> Result<(), Error>; #[endpoint(method = GET, path = "/test/pathParams/{foo}/raw/{multi}")] fn path_param( &self, #[path] foo: &str, - #[path(encoder = DisplaySeqParamEncoder)] multi: &[&str], + #[path(encoder = DisplaySeqEncoder)] multi: &[&str], ) -> Result<(), Error>; #[endpoint(method = GET, path = "/test/headers")] fn headers( &self, #[header(name = "Some-Custom-Header")] custom_header: &str, - #[header(name = "Some-Optional-Header", encoder = DisplaySeqHeaderEncoder)] optional_header: Option, + #[header(name = "Some-Optional-Header", encoder = DisplaySeqEncoder)] + optional_header: Option, ) -> Result<(), Error>; #[endpoint(method = POST, path = "/test/jsonRequest")] @@ -261,21 +263,22 @@ trait CustomServiceAsync { async fn query_param( &self, #[query(name = "normal")] normal: &str, - #[query(name = "list", encoder = DisplaySeqParamEncoder)] list: &[i32], + #[query(name = "list", encoder = DisplaySeqEncoder)] list: &[i32], ) -> Result<(), Error>; #[endpoint(method = GET, path = "/test/pathParams/{foo}/raw/{multi}")] async fn path_param( &self, #[path] foo: &str, - #[path(encoder = DisplaySeqParamEncoder)] multi: &[&str], + #[path(encoder = DisplaySeqEncoder)] multi: &[&str], ) -> Result<(), Error>; #[endpoint(method = GET, path = "/test/headers")] async fn headers( &self, #[header(name = "Some-Custom-Header")] custom_header: &str, - #[header(name = "Some-Optional-Header", encoder = DisplaySeqHeaderEncoder)] optional_header: Option, + #[header(name = "Some-Optional-Header", encoder = DisplaySeqEncoder)] + optional_header: Option, ) -> Result<(), Error>; #[endpoint(method = POST, path = "/test/jsonRequest")] @@ -629,3 +632,80 @@ fn cookie_auth() { client.cookie_auth(&BearerToken::new("fizzbuzz").unwrap()) ); } + +#[conjure_client] +trait CustomStreamingService<#[response_body] I, #[request_writer] O> +where + O: Write, +{ + #[endpoint(method = POST, path = "/test/streamingRequest")] + fn streaming_request( + &self, + #[body(serializer = RawRequestSerializer)] body: &mut RawRequest, + ) -> Result<(), Error>; + + #[endpoint(method = GET, path = "/test/streamingResponse", accept = RawResponseDeserializer)] + fn streaming_response(&self) -> Result; +} + +struct RawRequest; + +impl WriteBody for RawRequest +where + W: Write, +{ + fn write_body(&mut self, w: &mut W) -> Result<(), Error> { + w.write_all(b"hello world").map_err(Error::internal_safe) + } + + fn reset(&mut self) -> bool { + true + } +} + +enum RawRequestSerializer {} + +impl<'a, W> SerializeRequest<'a, &'a mut RawRequest, W> for RawRequestSerializer +where + W: Write, +{ + fn content_type(_: &&mut RawRequest) -> HeaderValue { + HeaderValue::from_static("text/plain") + } + + fn serialize(value: &'a mut RawRequest) -> Result, Error> { + Ok(RequestBody::Streaming(value)) + } +} + +enum RawResponseDeserializer {} + +impl DeserializeResponse for RawResponseDeserializer { + fn accept() -> Option { + None + } + + fn deserialize(response: Response) -> Result { + Ok(response.into_body()) + } +} + +#[test] +fn custom_streaming_request() { + let client = TestClient::new(Method::POST, "/test/streamingRequest") + .header("Content-Type", "text/plain") + .body(TestBody::Streaming(b"hello world".to_vec())); + + CustomStreamingServiceClient::new(&client) + .streaming_request(&mut RawRequest) + .unwrap(); +} + +#[test] +fn custom_streaming_response() { + let client = TestClient::new(Method::GET, "/test/streamingResponse"); + + CustomStreamingServiceClient::new(&client) + .streaming_response() + .unwrap(); +}