Skip to content

Commit

Permalink
feat(transport): support customizing Channel's async executor (#935)
Browse files Browse the repository at this point in the history
  • Loading branch information
davidpdrsn authored Mar 15, 2022
1 parent 01e5be5 commit 0859d82
Show file tree
Hide file tree
Showing 6 changed files with 96 additions and 10 deletions.
18 changes: 17 additions & 1 deletion tonic/src/transport/channel/endpoint.rs
Original file line number Diff line number Diff line change
Expand Up @@ -4,16 +4,19 @@ use super::Channel;
use super::ClientTlsConfig;
#[cfg(feature = "tls")]
use crate::transport::service::TlsConnector;
use crate::transport::Error;
use crate::transport::{service::SharedExec, Error, Executor};
use bytes::Bytes;
use http::{uri::Uri, HeaderValue};
use std::{
convert::{TryFrom, TryInto},
fmt,
future::Future,
pin::Pin,
str::FromStr,
time::Duration,
};
use tower::make::MakeConnection;
// use crate::transport::E

/// Channel builder.
///
Expand All @@ -37,6 +40,7 @@ pub struct Endpoint {
pub(crate) http2_keep_alive_while_idle: Option<bool>,
pub(crate) connect_timeout: Option<Duration>,
pub(crate) http2_adaptive_window: Option<bool>,
pub(crate) executor: SharedExec,
}

impl Endpoint {
Expand Down Expand Up @@ -263,6 +267,17 @@ impl Endpoint {
}
}

/// Sets the executor used to spawn async tasks.
///
/// Uses `tokio::spawn` by default.
pub fn executor<E>(mut self, executor: E) -> Self
where
E: Executor<Pin<Box<dyn Future<Output = ()> + Send>>> + Send + Sync + 'static,
{
self.executor = SharedExec::new(executor);
self
}

/// Create a channel from this config.
pub async fn connect(&self) -> Result<Channel, Error> {
let mut http = hyper::client::connect::HttpConnector::new();
Expand Down Expand Up @@ -396,6 +411,7 @@ impl From<Uri> for Endpoint {
http2_keep_alive_while_idle: None,
connect_timeout: None,
http2_adaptive_window: None,
executor: SharedExec::tokio(),
}
}
}
Expand Down
35 changes: 29 additions & 6 deletions tonic/src/transport/channel/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -9,8 +9,9 @@ pub use endpoint::Endpoint;
#[cfg(feature = "tls")]
pub use tls::ClientTlsConfig;

use super::service::{Connection, DynamicServiceStream};
use super::service::{Connection, DynamicServiceStream, SharedExec};
use crate::body::BoxBody;
use crate::transport::Executor;
use bytes::Bytes;
use http::{
uri::{InvalidUri, Uri},
Expand Down Expand Up @@ -124,10 +125,26 @@ impl Channel {
pub fn balance_channel<K>(capacity: usize) -> (Self, Sender<Change<K, Endpoint>>)
where
K: Hash + Eq + Send + Clone + 'static,
{
Self::balance_channel_with_executor(capacity, SharedExec::tokio())
}

/// Balance a list of [`Endpoint`]'s.
///
/// This creates a [`Channel`] that will listen to a stream of change events and will add or remove provided endpoints.
///
/// The [`Channel`] will use the given executor to spawn async tasks.
pub fn balance_channel_with_executor<K, E>(
capacity: usize,
executor: E,
) -> (Self, Sender<Change<K, Endpoint>>)
where
K: Hash + Eq + Send + Clone + 'static,
E: Executor<Pin<Box<dyn Future<Output = ()> + Send>>> + Send + Sync + 'static,
{
let (tx, rx) = channel(capacity);
let list = DynamicServiceStream::new(rx);
(Self::balance(list, DEFAULT_BUFFER_SIZE), tx)
(Self::balance(list, DEFAULT_BUFFER_SIZE, executor), tx)
}

pub(crate) fn new<C>(connector: C, endpoint: Endpoint) -> Self
Expand All @@ -138,9 +155,11 @@ impl Channel {
C::Response: AsyncRead + AsyncWrite + HyperConnection + Unpin + Send + 'static,
{
let buffer_size = endpoint.buffer_size.unwrap_or(DEFAULT_BUFFER_SIZE);
let executor = endpoint.executor.clone();

let svc = Connection::lazy(connector, endpoint);
let svc = Buffer::new(Either::A(svc), buffer_size);
let (svc, worker) = Buffer::pair(Either::A(svc), buffer_size);
executor.execute(Box::pin(worker));

Channel { svc }
}
Expand All @@ -153,25 +172,29 @@ impl Channel {
C::Response: AsyncRead + AsyncWrite + HyperConnection + Unpin + Send + 'static,
{
let buffer_size = endpoint.buffer_size.unwrap_or(DEFAULT_BUFFER_SIZE);
let executor = endpoint.executor.clone();

let svc = Connection::connect(connector, endpoint)
.await
.map_err(super::Error::from_source)?;
let svc = Buffer::new(Either::A(svc), buffer_size);
let (svc, worker) = Buffer::pair(Either::A(svc), buffer_size);
executor.execute(Box::pin(worker));

Ok(Channel { svc })
}

pub(crate) fn balance<D>(discover: D, buffer_size: usize) -> Self
pub(crate) fn balance<D, E>(discover: D, buffer_size: usize, executor: E) -> Self
where
D: Discover<Service = Connection> + Unpin + Send + 'static,
D::Error: Into<crate::Error>,
D::Key: Hash + Send + Clone,
E: Executor<futures_core::future::BoxFuture<'static, ()>> + Send + Sync + 'static,
{
let svc = Balance::new(discover);

let svc = BoxService::new(svc);
let svc = Buffer::new(Either::B(svc), buffer_size);
let (svc, worker) = Buffer::pair(Either::B(svc), buffer_size);
executor.execute(Box::pin(worker));

Channel { svc }
}
Expand Down
4 changes: 3 additions & 1 deletion tonic/src/transport/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -99,10 +99,12 @@ pub use self::error::Error;
#[doc(inline)]
pub use self::server::{NamedService, Server};
#[doc(inline)]
pub use self::service::TimeoutExpired;
pub use self::service::grpc_timeout::TimeoutExpired;
pub use self::tls::Certificate;
pub use hyper::{Body, Uri};

pub(crate) use self::service::executor::Executor;

#[cfg(feature = "tls")]
#[cfg_attr(docsrs, doc(cfg(feature = "tls")))]
pub use self::channel::ClientTlsConfig;
Expand Down
1 change: 1 addition & 0 deletions tonic/src/transport/service/connection.rs
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,7 @@ impl Connection {
.http2_initial_connection_window_size(endpoint.init_connection_window_size)
.http2_only(true)
.http2_keep_alive_interval(endpoint.http2_keep_alive_interval)
.executor(endpoint.executor.clone())
.clone();

if let Some(val) = endpoint.http2_keep_alive_timeout {
Expand Down
43 changes: 43 additions & 0 deletions tonic/src/transport/service/executor.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,43 @@
use futures_core::future::BoxFuture;
use std::{future::Future, sync::Arc};

pub(crate) use hyper::rt::Executor;

#[derive(Copy, Clone)]
struct TokioExec;

impl<F> Executor<F> for TokioExec
where
F: Future + Send + 'static,
F::Output: Send + 'static,
{
fn execute(&self, fut: F) {
tokio::spawn(fut);
}
}

#[derive(Clone)]
pub(crate) struct SharedExec {
inner: Arc<dyn Executor<BoxFuture<'static, ()>> + Send + Sync + 'static>,
}

impl SharedExec {
pub(crate) fn new<E>(exec: E) -> Self
where
E: Executor<BoxFuture<'static, ()>> + Send + Sync + 'static,
{
Self {
inner: Arc::new(exec),
}
}

pub(crate) fn tokio() -> Self {
Self::new(TokioExec)
}
}

impl Executor<BoxFuture<'static, ()>> for SharedExec {
fn execute(&self, fut: BoxFuture<'static, ()>) {
self.inner.execute(fut)
}
}
5 changes: 3 additions & 2 deletions tonic/src/transport/service/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,8 @@ mod add_origin;
mod connection;
mod connector;
mod discover;
mod grpc_timeout;
pub(crate) mod executor;
pub(crate) mod grpc_timeout;
mod io;
mod reconnect;
mod router;
Expand All @@ -14,11 +15,11 @@ pub(crate) use self::add_origin::AddOrigin;
pub(crate) use self::connection::Connection;
pub(crate) use self::connector::connector;
pub(crate) use self::discover::DynamicServiceStream;
pub(crate) use self::executor::SharedExec;
pub(crate) use self::grpc_timeout::GrpcTimeout;
pub(crate) use self::io::ServerIo;
#[cfg(feature = "tls")]
pub(crate) use self::tls::{TlsAcceptor, TlsConnector};
pub(crate) use self::user_agent::UserAgent;

pub use self::grpc_timeout::TimeoutExpired;
pub use self::router::Routes;

0 comments on commit 0859d82

Please sign in to comment.