Skip to content

Commit

Permalink
feat: add GrpcMethod extension into request for client (#1275)
Browse files Browse the repository at this point in the history
* feat: add GrpcMethod extension into request for client

* refactor: change GrpcMethod fields into private and expose methods instead

* refactor: hide GrpcMethod::new in doc

---------

Co-authored-by: Lucio Franco <[email protected]>
  • Loading branch information
linw1995 and LucioFranco authored Feb 23, 2023
1 parent 1547f96 commit 7a6b20d
Show file tree
Hide file tree
Showing 9 changed files with 207 additions and 85 deletions.
53 changes: 53 additions & 0 deletions tests/integration_tests/tests/interceptor.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,53 @@
use std::time::Duration;

use futures::{channel::oneshot, FutureExt};
use integration_tests::pb::{test_client::TestClient, test_server, Input, Output};
use tonic::{
transport::{Endpoint, Server},
GrpcMethod, Request, Response, Status,
};

#[tokio::test]
async fn interceptor_retrieves_grpc_method() {
use test_server::Test;

struct Svc;

#[tonic::async_trait]
impl Test for Svc {
async fn unary_call(&self, _: Request<Input>) -> Result<Response<Output>, Status> {
Ok(Response::new(Output {}))
}
}

let svc = test_server::TestServer::new(Svc);

let (tx, rx) = oneshot::channel();
// Start the server now, second call should succeed
let jh = tokio::spawn(async move {
Server::builder()
.add_service(svc)
.serve_with_shutdown("127.0.0.1:1340".parse().unwrap(), rx.map(drop))
.await
.unwrap();
});

let channel = Endpoint::from_static("http://127.0.0.1:1340").connect_lazy();

fn client_intercept(req: Request<()>) -> Result<Request<()>, Status> {
println!("Intercepting client request: {:?}", req);

let gm = req.extensions().get::<GrpcMethod>().unwrap();
assert_eq!(gm.service(), "test.Test");
assert_eq!(gm.method(), "UnaryCall");

Ok(req)
}
let mut client = TestClient::with_interceptor(channel, client_intercept);

tokio::time::sleep(Duration::from_millis(100)).await;
client.unary_call(Request::new(Input {})).await.unwrap();

tx.send(()).unwrap();
jh.await.unwrap();
}
122 changes: 76 additions & 46 deletions tonic-build/src/client.rs
Original file line number Diff line number Diff line change
@@ -1,7 +1,10 @@
use std::collections::HashSet;

use super::{Attributes, Method, Service};
use crate::{format_method_name, generate_doc_comments, naive_snake_case};
use crate::{
format_method_name, format_method_path, format_service_name, generate_doc_comments,
naive_snake_case,
};
use proc_macro2::TokenStream;
use quote::{format_ident, quote};

Expand Down Expand Up @@ -51,21 +54,16 @@ pub(crate) fn generate_internal<T: Service>(
let connect = generate_connect(&service_ident, build_transport);

let package = if emit_package { service.package() } else { "" };
let path = format!(
"{}{}{}",
package,
if package.is_empty() { "" } else { "." },
service.identifier()
);
let service_name = format_service_name(service, emit_package);

let service_doc = if disable_comments.contains(&path) {
let service_doc = if disable_comments.contains(&service_name) {
TokenStream::new()
} else {
generate_doc_comments(service.comment())
};

let mod_attributes = attributes.for_mod(package);
let struct_attributes = attributes.for_struct(&path);
let struct_attributes = attributes.for_struct(&service_name);

quote! {
/// Generated client implementations.
Expand Down Expand Up @@ -193,30 +191,41 @@ fn generate_methods<T: Service>(
disable_comments: &HashSet<String>,
) -> TokenStream {
let mut stream = TokenStream::new();
let package = if emit_package { service.package() } else { "" };

for method in service.methods() {
let path = format!(
"/{}{}{}/{}",
package,
if package.is_empty() { "" } else { "." },
service.identifier(),
method.identifier()
);

if !disable_comments.contains(&format_method_name(package, service, method)) {
if !disable_comments.contains(&format_method_name(service, method, emit_package)) {
stream.extend(generate_doc_comments(method.comment()));
}

let method = match (method.client_streaming(), method.server_streaming()) {
(false, false) => generate_unary(method, proto_path, compile_well_known_types, path),
(false, true) => {
generate_server_streaming(method, proto_path, compile_well_known_types, path)
}
(true, false) => {
generate_client_streaming(method, proto_path, compile_well_known_types, path)
}
(true, true) => generate_streaming(method, proto_path, compile_well_known_types, path),
(false, false) => generate_unary(
service,
method,
emit_package,
proto_path,
compile_well_known_types,
),
(false, true) => generate_server_streaming(
service,
method,
emit_package,
proto_path,
compile_well_known_types,
),
(true, false) => generate_client_streaming(
service,
method,
emit_package,
proto_path,
compile_well_known_types,
),
(true, true) => generate_streaming(
service,
method,
emit_package,
proto_path,
compile_well_known_types,
),
};

stream.extend(method);
Expand All @@ -225,15 +234,19 @@ fn generate_methods<T: Service>(
stream
}

fn generate_unary<T: Method>(
method: &T,
fn generate_unary<T: Service>(
service: &T,
method: &T::Method,
emit_package: bool,
proto_path: &str,
compile_well_known_types: bool,
path: String,
) -> TokenStream {
let codec_name = syn::parse_str::<syn::Path>(method.codec_path()).unwrap();
let ident = format_ident!("{}", method.name());
let (request, response) = method.request_response_name(proto_path, compile_well_known_types);
let service_name = format_service_name(service, emit_package);
let path = format_method_path(service, method, emit_package);
let method_name = method.identifier();

quote! {
pub async fn #ident(
Expand All @@ -245,21 +258,26 @@ fn generate_unary<T: Method>(
})?;
let codec = #codec_name::default();
let path = http::uri::PathAndQuery::from_static(#path);
self.inner.unary(request.into_request(), path, codec).await
let mut req = request.into_request();
req.extensions_mut().insert(GrpcMethod::new(#service_name, #method_name));
self.inner.unary(req, path, codec).await
}
}
}

fn generate_server_streaming<T: Method>(
method: &T,
fn generate_server_streaming<T: Service>(
service: &T,
method: &T::Method,
emit_package: bool,
proto_path: &str,
compile_well_known_types: bool,
path: String,
) -> TokenStream {
let codec_name = syn::parse_str::<syn::Path>(method.codec_path()).unwrap();
let ident = format_ident!("{}", method.name());

let (request, response) = method.request_response_name(proto_path, compile_well_known_types);
let service_name = format_service_name(service, emit_package);
let path = format_method_path(service, method, emit_package);
let method_name = method.identifier();

quote! {
pub async fn #ident(
Expand All @@ -271,21 +289,26 @@ fn generate_server_streaming<T: Method>(
})?;
let codec = #codec_name::default();
let path = http::uri::PathAndQuery::from_static(#path);
self.inner.server_streaming(request.into_request(), path, codec).await
let mut req = request.into_request();
req.extensions_mut().insert(GrpcMethod::new(#service_name, #method_name));
self.inner.server_streaming(req, path, codec).await
}
}
}

fn generate_client_streaming<T: Method>(
method: &T,
fn generate_client_streaming<T: Service>(
service: &T,
method: &T::Method,
emit_package: bool,
proto_path: &str,
compile_well_known_types: bool,
path: String,
) -> TokenStream {
let codec_name = syn::parse_str::<syn::Path>(method.codec_path()).unwrap();
let ident = format_ident!("{}", method.name());

let (request, response) = method.request_response_name(proto_path, compile_well_known_types);
let service_name = format_service_name(service, emit_package);
let path = format_method_path(service, method, emit_package);
let method_name = method.identifier();

quote! {
pub async fn #ident(
Expand All @@ -297,21 +320,26 @@ fn generate_client_streaming<T: Method>(
})?;
let codec = #codec_name::default();
let path = http::uri::PathAndQuery::from_static(#path);
self.inner.client_streaming(request.into_streaming_request(), path, codec).await
let mut req = request.into_streaming_request();
req.extensions_mut().insert(GrpcMethod::new(#service_name, #method_name));
self.inner.client_streaming(req, path, codec).await
}
}
}

fn generate_streaming<T: Method>(
method: &T,
fn generate_streaming<T: Service>(
service: &T,
method: &T::Method,
emit_package: bool,
proto_path: &str,
compile_well_known_types: bool,
path: String,
) -> TokenStream {
let codec_name = syn::parse_str::<syn::Path>(method.codec_path()).unwrap();
let ident = format_ident!("{}", method.name());

let (request, response) = method.request_response_name(proto_path, compile_well_known_types);
let service_name = format_service_name(service, emit_package);
let path = format_method_path(service, method, emit_package);
let method_name = method.identifier();

quote! {
pub async fn #ident(
Expand All @@ -323,7 +351,9 @@ fn generate_streaming<T: Method>(
})?;
let codec = #codec_name::default();
let path = http::uri::PathAndQuery::from_static(#path);
self.inner.streaming(request.into_streaming_request(), path, codec).await
let mut req = request.into_streaming_request();
req.extensions_mut().insert(GrpcMethod::new(#service_name,#method_name));
self.inner.streaming(req, path, codec).await
}
}
}
24 changes: 18 additions & 6 deletions tonic-build/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -197,16 +197,28 @@ impl Attributes {
}
}

fn format_method_name<T: Service>(
package: &str,
service: &T,
method: &<T as Service>::Method,
) -> String {
fn format_service_name<T: Service>(service: &T, emit_package: bool) -> String {
let package = if emit_package { service.package() } else { "" };
format!(
"{}{}{}.{}",
"{}{}{}",
package,
if package.is_empty() { "" } else { "." },
service.identifier(),
)
}

fn format_method_path<T: Service>(service: &T, method: &T::Method, emit_package: bool) -> String {
format!(
"/{}/{}",
format_service_name(service, emit_package),
method.identifier()
)
}

fn format_method_name<T: Service>(service: &T, method: &T::Method, emit_package: bool) -> String {
format!(
"{}.{}",
format_service_name(service, emit_package),
method.identifier()
)
}
Expand Down
Loading

0 comments on commit 7a6b20d

Please sign in to comment.