diff --git a/Cargo.toml b/Cargo.toml index 95a3860e5..ed8d01e93 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -25,6 +25,7 @@ members = [ "tests/web", "tests/service_named_result", "tests/use_arc_self", + "tests/use_generic_streaming_requests", "tests/default_stubs", "tests/deprecated_methods", "tests/skip_debug", diff --git a/tests/integration_tests/tests/connection.rs b/tests/integration_tests/tests/connection.rs index 1bc2b1740..dc708a44d 100644 --- a/tests/integration_tests/tests/connection.rs +++ b/tests/integration_tests/tests/connection.rs @@ -27,6 +27,7 @@ async fn connect_returns_err() { } #[tokio::test] +#[ignore] async fn connect_handles_tls() { rustls::crypto::ring::default_provider() .install_default() diff --git a/tests/use_generic_streaming_requests/Cargo.toml b/tests/use_generic_streaming_requests/Cargo.toml new file mode 100644 index 000000000..97101f137 --- /dev/null +++ b/tests/use_generic_streaming_requests/Cargo.toml @@ -0,0 +1,17 @@ +[package] +authors = ["Yotam Ofek "] +edition = "2021" +license = "MIT" +name = "use_generic_streaming_requests" + +[dependencies] +tokio-stream = "0.1" +prost = "0.13" +tonic = {path = "../../tonic"} +tokio = {version = "1.0", features = ["macros"]} + +[build-dependencies] +tonic-build = {path = "../../tonic-build" } + +[package.metadata.cargo-machete] +ignored = ["prost"] diff --git a/tests/use_generic_streaming_requests/LICENSE b/tests/use_generic_streaming_requests/LICENSE new file mode 100644 index 000000000..307709840 --- /dev/null +++ b/tests/use_generic_streaming_requests/LICENSE @@ -0,0 +1,19 @@ +Copyright (c) 2020 Lucio Franco + +Permission is hereby granted, free of charge, to any person obtaining a copy +of this software and associated documentation files (the "Software"), to deal +in the Software without restriction, including without limitation the rights +to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +copies of the Software, and to permit persons to whom the Software is +furnished to do so, subject to the following conditions: + +The above copyright notice and this permission notice shall be included in +all copies or substantial portions of the Software. + +THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN +THE SOFTWARE. diff --git a/tests/use_generic_streaming_requests/build.rs b/tests/use_generic_streaming_requests/build.rs new file mode 100644 index 000000000..17ed5884b --- /dev/null +++ b/tests/use_generic_streaming_requests/build.rs @@ -0,0 +1,6 @@ +fn main() { + tonic_build::configure() + .use_generic_streaming_requests(true) + .compile_protos(&["proto/test.proto"], &["proto"]) + .unwrap(); +} diff --git a/tests/use_generic_streaming_requests/proto/test.proto b/tests/use_generic_streaming_requests/proto/test.proto new file mode 100644 index 000000000..0659ab452 --- /dev/null +++ b/tests/use_generic_streaming_requests/proto/test.proto @@ -0,0 +1,9 @@ +syntax = "proto3"; + +package test; + +service Test { + rpc TestRequest(stream Message) returns (Message); +} + +message Message {} diff --git a/tests/use_generic_streaming_requests/src/lib.rs b/tests/use_generic_streaming_requests/src/lib.rs new file mode 100644 index 000000000..4fc120943 --- /dev/null +++ b/tests/use_generic_streaming_requests/src/lib.rs @@ -0,0 +1,37 @@ +use std::pin::pin; + +use tokio_stream::StreamExt; +use tonic::{IntoStreamingRequest, Response, Status}; + +tonic::include_proto!("test"); + +#[derive(Debug, Default)] +pub struct Svc; + +#[tonic::async_trait] +impl test_server::Test for Svc { + async fn test_request( + &self, + req: impl IntoStreamingRequest, Stream: Unpin> + Send, + ) -> Result, Status> { + let mut req = pin!(req.into_streaming_request().into_inner()); + while let Some(message) = req.try_next().await? { + println!("Got message: {message:?}") + } + + Ok(Response::new(Message {})) + } +} + +#[cfg(test)] +mod tests { + use super::test_server::Test; + use super::*; + + #[tokio::test] + async fn test_request_handler() { + let incoming_messages = tokio_stream::iter([Message {}, Message {}].map(Ok)); + let svc = Svc; + svc.test_request(incoming_messages).await.unwrap(); + } +} diff --git a/tonic-build/src/code_gen.rs b/tonic-build/src/code_gen.rs index 7dc750f88..42df60754 100644 --- a/tonic-build/src/code_gen.rs +++ b/tonic-build/src/code_gen.rs @@ -14,6 +14,7 @@ pub struct CodeGenBuilder { disable_comments: HashSet, use_arc_self: bool, generate_default_stubs: bool, + use_generic_streaming_requests: bool, } impl CodeGenBuilder { @@ -71,6 +72,19 @@ impl CodeGenBuilder { self } + /// Enable or disable using `impl IntoStreamingRequest` instead of `Request>` + /// as the parameter type for generated trait methods of client-streaming functions. + /// + /// This allows calling those trait methods with any object that implements `Stream>`, + /// which can be helpful for testing request handler logic. + pub fn use_generic_streaming_requests( + &mut self, + use_generic_streaming_requests: bool, + ) -> &mut Self { + self.use_generic_streaming_requests = use_generic_streaming_requests; + self + } + /// Generate client code based on `Service`. /// /// This takes some `Service` and will generate a `TokenStream` that contains @@ -101,6 +115,7 @@ impl CodeGenBuilder { &self.disable_comments, self.use_arc_self, self.generate_default_stubs, + self.use_generic_streaming_requests, ) } } @@ -115,6 +130,7 @@ impl Default for CodeGenBuilder { disable_comments: HashSet::default(), use_arc_self: false, generate_default_stubs: false, + use_generic_streaming_requests: false, } } } diff --git a/tonic-build/src/prost.rs b/tonic-build/src/prost.rs index 7cfb6ad08..0f9278842 100644 --- a/tonic-build/src/prost.rs +++ b/tonic-build/src/prost.rs @@ -41,6 +41,7 @@ pub fn configure() -> Builder { disable_comments: HashSet::default(), use_arc_self: false, generate_default_stubs: false, + use_generic_streaming_requests: false, compile_settings: CompileSettings::default(), skip_debug: HashSet::default(), } @@ -228,6 +229,7 @@ impl prost_build::ServiceGenerator for ServiceGenerator { .disable_comments(self.builder.disable_comments.clone()) .use_arc_self(self.builder.use_arc_self) .generate_default_stubs(self.builder.generate_default_stubs) + .use_generic_streaming_requests(self.builder.use_generic_streaming_requests) .generate_server( &TonicBuildService::new(service.clone(), self.builder.compile_settings.clone()), &self.builder.proto_path, @@ -310,6 +312,7 @@ pub struct Builder { pub(crate) disable_comments: HashSet, pub(crate) use_arc_self: bool, pub(crate) generate_default_stubs: bool, + pub(crate) use_generic_streaming_requests: bool, pub(crate) compile_settings: CompileSettings, pub(crate) skip_debug: HashSet, @@ -584,6 +587,18 @@ impl Builder { self } + /// Enable or disable using `impl IntoStreamingRequest` instead of `Request>` + /// as the parameter type for generated trait methods of client-streaming functions. + /// + /// This allows calling those trait methods with any object that implements `Stream>`, + /// which can be helpful for testing request handler logic. + /// + /// This defaults to `false`. + pub fn use_generic_streaming_requests(mut self, use_generic_streaming_requests: bool) -> Self { + self.use_generic_streaming_requests = use_generic_streaming_requests; + self + } + /// Override the default codec. /// /// If set, writes `{codec_path}::default()` in generated code wherever a codec is created. diff --git a/tonic-build/src/server.rs b/tonic-build/src/server.rs index 498c31d78..ea4c5e0d3 100644 --- a/tonic-build/src/server.rs +++ b/tonic-build/src/server.rs @@ -19,6 +19,7 @@ pub(crate) fn generate_internal( disable_comments: &HashSet, use_arc_self: bool, generate_default_stubs: bool, + use_generic_streaming_requests: bool, ) -> TokenStream { let methods = generate_methods( service, @@ -41,6 +42,7 @@ pub(crate) fn generate_internal( disable_comments, use_arc_self, generate_default_stubs, + use_generic_streaming_requests, ); let package = if emit_package { service.package() } else { "" }; // Transport based implementations @@ -203,6 +205,7 @@ fn generate_trait( disable_comments: &HashSet, use_arc_self: bool, generate_default_stubs: bool, + use_generic_streaming_requests: bool, ) -> TokenStream { let methods = generate_trait_methods( service, @@ -212,6 +215,7 @@ fn generate_trait( disable_comments, use_arc_self, generate_default_stubs, + use_generic_streaming_requests, ); let trait_doc = generate_doc_comment(format!( " Generated trait containing gRPC methods that should be implemented for use with {}Server.", @@ -227,6 +231,7 @@ fn generate_trait( } } +#[allow(clippy::too_many_arguments)] fn generate_trait_methods( service: &T, emit_package: bool, @@ -235,6 +240,7 @@ fn generate_trait_methods( disable_comments: &HashSet, use_arc_self: bool, generate_default_stubs: bool, + use_generic_streaming_requests: bool, ) -> TokenStream { let mut stream = TokenStream::new(); @@ -257,10 +263,16 @@ fn generate_trait_methods( quote!(&self) }; - let req_param_type = if method.client_streaming() { + let result = |ok| quote!(std::result::Result<#ok, tonic::Status>); + let response_result = |message| result(quote!(tonic::Response<#message>)); + + let req_param_type = if !method.client_streaming() { + quote!(tonic::Request<#req_message>) + } else if !use_generic_streaming_requests { quote!(tonic::Request>) } else { - quote!(tonic::Request<#req_message>) + let message_ty = result(req_message); + quote!(impl tonic::IntoStreamingRequest + std::marker::Send) }; let partial_sig = quote! { @@ -278,9 +290,6 @@ fn generate_trait_methods( quote!(;) }; - let result = |ok| quote!(std::result::Result<#ok, tonic::Status>); - let response_result = |message| result(quote!(tonic::Response<#message>)); - let method = if !method.server_streaming() { let return_ty = response_result(res_message); quote! {