Skip to content

Commit

Permalink
feat(transport): Add remote_addr to Request on the server si… (#186)
Browse files Browse the repository at this point in the history
  • Loading branch information
LucioFranco authored Dec 14, 2019
1 parent 0505dff commit 3eb76ab
Show file tree
Hide file tree
Showing 9 changed files with 193 additions and 26 deletions.
2 changes: 1 addition & 1 deletion examples/src/helloworld/server.rs
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@ impl Greeter for MyGreeter {
&self,
request: Request<HelloRequest>,
) -> Result<Response<HelloReply>, Status> {
println!("Got a request: {:?}", request);
println!("Got a request from {:?}", request.remote_addr());

let reply = hello_world::HelloReply {
message: format!("Hello {}!", request.into_inner().name).into(),
Expand Down
52 changes: 48 additions & 4 deletions examples/src/uds/server.rs
Original file line number Diff line number Diff line change
@@ -1,6 +1,17 @@
use std::path::Path;
use tokio::net::UnixListener;
use tonic::{transport::Server, Request, Response, Status};
use futures::stream::TryStreamExt;
use std::{
path::Path,
pin::Pin,
task::{Context, Poll},
};
use tokio::{
io::{AsyncRead, AsyncWrite},
net::UnixListener,
};
use tonic::{
transport::{server::Connected, Server},
Request, Response, Status,
};

pub mod hello_world {
tonic::include_proto!("helloworld");
Expand Down Expand Up @@ -41,8 +52,41 @@ async fn main() -> Result<(), Box<dyn std::error::Error>> {

Server::builder()
.add_service(GreeterServer::new(greeter))
.serve_with_incoming(uds.incoming())
.serve_with_incoming(uds.incoming().map_ok(UnixStream))
.await?;

Ok(())
}

#[derive(Debug)]
struct UnixStream(tokio::net::UnixStream);

impl Connected for UnixStream {}

impl AsyncRead for UnixStream {
fn poll_read(
mut self: Pin<&mut Self>,
cx: &mut Context<'_>,
buf: &mut [u8],
) -> Poll<std::io::Result<usize>> {
Pin::new(&mut self.0).poll_read(cx, buf)
}
}

impl AsyncWrite for UnixStream {
fn poll_write(
mut self: Pin<&mut Self>,
cx: &mut Context<'_>,
buf: &[u8],
) -> Poll<std::io::Result<usize>> {
Pin::new(&mut self.0).poll_write(cx, buf)
}

fn poll_flush(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<std::io::Result<()>> {
Pin::new(&mut self.0).poll_flush(cx)
}

fn poll_shutdown(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<std::io::Result<()>> {
Pin::new(&mut self.0).poll_shutdown(cx)
}
}
25 changes: 25 additions & 0 deletions tonic/src/request.rs
Original file line number Diff line number Diff line change
@@ -1,11 +1,19 @@
use crate::metadata::MetadataMap;
use futures_core::Stream;
use http::Extensions;
use std::net::SocketAddr;

/// A gRPC request and metadata from an RPC call.
#[derive(Debug)]
pub struct Request<T> {
metadata: MetadataMap,
message: T,
extensions: Extensions,
}

#[derive(Clone)]
pub(crate) struct ConnectionInfo {
pub(crate) remote_addr: Option<SocketAddr>,
}

/// Trait implemented by RPC request types.
Expand Down Expand Up @@ -102,6 +110,7 @@ impl<T> Request<T> {
Request {
metadata: MetadataMap::new(),
message,
extensions: Extensions::default(),
}
}

Expand Down Expand Up @@ -134,6 +143,7 @@ impl<T> Request<T> {
Request {
metadata: MetadataMap::from_headers(parts.headers),
message,
extensions: parts.extensions,
}
}

Expand All @@ -150,6 +160,7 @@ impl<T> Request<T> {
*request.method_mut() = http::Method::POST;
*request.uri_mut() = uri;
*request.headers_mut() = self.metadata.into_sanitized_headers();
*request.extensions_mut() = self.extensions;

request
}
Expand All @@ -164,8 +175,22 @@ impl<T> Request<T> {
Request {
metadata: self.metadata,
message,
extensions: Extensions::default(),
}
}

/// Get the remote address of this connection.
///
/// This will return `None` if the `IO` type used
/// does not implement `Connected`. This currently,
/// only works on the server side.
pub fn remote_addr(&self) -> Option<SocketAddr> {
self.get::<ConnectionInfo>()?.remote_addr
}

pub(crate) fn get<I: Send + Sync + 'static>(&self) -> Option<&I> {
self.extensions.get::<I>()
}
}

impl<T> IntoRequest<T> for T {
Expand Down
30 changes: 30 additions & 0 deletions tonic/src/transport/server/conn.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,30 @@
use hyper::server::conn::AddrStream;
use std::net::SocketAddr;
#[cfg(feature = "tls")]
use tokio_rustls::TlsStream;

/// Trait that connected IO resources implement.
///
/// The goal for this trait is to allow users to implement
/// custom IO types that can still provide the same connection
/// metadata.
pub trait Connected {
/// Return the remote address this IO resource is connected too.
fn remote_addr(&self) -> Option<SocketAddr> {
None
}
}

impl Connected for AddrStream {
fn remote_addr(&self) -> Option<SocketAddr> {
Some(self.remote_addr())
}
}

#[cfg(feature = "tls")]
impl<T: Connected> Connected for TlsStream<T> {
fn remote_addr(&self) -> Option<SocketAddr> {
let (inner, _) = self.get_ref();
inner.remote_addr()
}
}
12 changes: 6 additions & 6 deletions tonic/src/transport/server/incoming.rs
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
use super::Server;
use crate::transport::service::BoxedIo;
use super::{Connected, Server};
use crate::transport::service::ServerIo;
use futures_core::Stream;
use futures_util::stream::TryStreamExt;
use hyper::server::{
Expand All @@ -20,9 +20,9 @@ use tracing::error;
pub(crate) fn tcp_incoming<IO, IE>(
incoming: impl Stream<Item = Result<IO, IE>>,
server: Server,
) -> impl Stream<Item = Result<BoxedIo, crate::Error>>
) -> impl Stream<Item = Result<ServerIo, crate::Error>>
where
IO: AsyncRead + AsyncWrite + Unpin + Send + 'static,
IO: AsyncRead + AsyncWrite + Connected + Unpin + Send + 'static,
IE: Into<crate::Error>,
{
async_stream::try_stream! {
Expand All @@ -39,12 +39,12 @@ where
continue
},
};
yield BoxedIo::new(io);
yield ServerIo::new(io);
continue;
}
}

yield BoxedIo::new(stream);
yield ServerIo::new(stream);
}
}
}
Expand Down
35 changes: 26 additions & 9 deletions tonic/src/transport/server/mod.rs
Original file line number Diff line number Diff line change
@@ -1,9 +1,11 @@
//! Server implementation and builder.
mod conn;
mod incoming;
#[cfg(feature = "tls")]
mod tls;

pub use conn::Connected;
#[cfg(feature = "tls")]
pub use tls::ServerTlsConfig;

Expand All @@ -12,8 +14,8 @@ use super::service::TlsAcceptor;

use incoming::TcpIncoming;

use super::service::{layer_fn, Or, Routes, ServiceBuilderExt};
use crate::body::BoxBody;
use super::service::{layer_fn, Or, Routes, ServerIo, ServiceBuilderExt};
use crate::{body::BoxBody, request::ConnectionInfo};
use futures_core::Stream;
use futures_util::{
future::{self, MapErr},
Expand Down Expand Up @@ -252,7 +254,7 @@ impl Server {
S::Future: Send + 'static,
S::Error: Into<crate::Error> + Send,
I: Stream<Item = Result<IO, IE>>,
IO: AsyncRead + AsyncWrite + Unpin + Send + 'static,
IO: AsyncRead + AsyncWrite + Connected + Unpin + Send + 'static,
IE: Into<crate::Error>,
F: Future<Output = ()>,
{
Expand Down Expand Up @@ -390,7 +392,7 @@ where
pub async fn serve_with_incoming<I, IO, IE>(self, incoming: I) -> Result<(), super::Error>
where
I: Stream<Item = Result<IO, IE>>,
IO: AsyncRead + AsyncWrite + Unpin + Send + 'static,
IO: AsyncRead + AsyncWrite + Connected + Unpin + Send + 'static,
IE: Into<crate::Error>,
{
self.server
Expand All @@ -412,6 +414,7 @@ impl fmt::Debug for Server {
struct Svc<S> {
inner: S,
span: Option<TraceInterceptor>,
conn_info: ConnectionInfo,
}

impl<S> Service<Request<Body>> for Svc<S>
Expand All @@ -427,13 +430,15 @@ where
self.inner.poll_ready(cx).map_err(Into::into)
}

fn call(&mut self, req: Request<Body>) -> Self::Future {
fn call(&mut self, mut req: Request<Body>) -> Self::Future {
let span = if let Some(trace_interceptor) = &self.span {
trace_interceptor(req.headers())
} else {
tracing::Span::none()
};

req.extensions_mut().insert(self.conn_info.clone());

self.inner.call(req).instrument(span).map_err(|e| e.into())
}
}
Expand All @@ -452,7 +457,7 @@ struct MakeSvc<S> {
span: Option<TraceInterceptor>,
}

impl<S, T> Service<T> for MakeSvc<S>
impl<S> Service<&ServerIo> for MakeSvc<S>
where
S: Service<Request<Body>, Response = Response<BoxBody>> + Clone + Send + 'static,
S::Future: Send + 'static,
Expand All @@ -467,7 +472,11 @@ where
Ok(()).into()
}

fn call(&mut self, _: T) -> Self::Future {
fn call(&mut self, io: &ServerIo) -> Self::Future {
let conn_info = crate::request::ConnectionInfo {
remote_addr: io.remote_addr(),
};

let interceptor = self.interceptor.clone();
let svc = self.inner.clone();
let concurrency_limit = self.concurrency_limit;
Expand All @@ -481,10 +490,18 @@ where
.service(svc);

let svc = if let Some(interceptor) = interceptor {
let layered = interceptor.layer(BoxService::new(Svc { inner: svc, span }));
let layered = interceptor.layer(BoxService::new(Svc {
inner: svc,
span,
conn_info,
}));
BoxService::new(layered)
} else {
BoxService::new(Svc { inner: svc, span })
BoxService::new(Svc {
inner: svc,
span,
conn_info,
})
};

Ok(svc)
Expand Down
57 changes: 54 additions & 3 deletions tonic/src/transport/service/io.rs
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
use hyper::client::connect::{Connected, Connection};
use crate::transport::server::Connected;
use hyper::client::connect::{Connected as HyperConnected, Connection};
use std::io;
use std::net::SocketAddr;
use std::pin::Pin;
use std::task::{Context, Poll};
use tokio::io::{AsyncRead, AsyncWrite};
Expand All @@ -20,11 +22,13 @@ impl BoxedIo {
}

impl Connection for BoxedIo {
fn connected(&self) -> Connected {
Connected::new()
fn connected(&self) -> HyperConnected {
HyperConnected::new()
}
}

impl Connected for BoxedIo {}

impl AsyncRead for BoxedIo {
fn poll_read(
mut self: Pin<&mut Self>,
Expand Down Expand Up @@ -52,3 +56,50 @@ impl AsyncWrite for BoxedIo {
Pin::new(&mut self.0).poll_shutdown(cx)
}
}

pub(in crate::transport) trait ConnectedIo: Io + Connected {}

impl<T> ConnectedIo for T where T: Io + Connected {}

pub(crate) struct ServerIo(Pin<Box<dyn ConnectedIo>>);

impl ServerIo {
pub(in crate::transport) fn new<I: ConnectedIo>(io: I) -> Self {
ServerIo(Box::pin(io))
}
}

impl Connected for ServerIo {
fn remote_addr(&self) -> Option<SocketAddr> {
let io = &*self.0;
io.remote_addr()
}
}

impl AsyncRead for ServerIo {
fn poll_read(
mut self: Pin<&mut Self>,
cx: &mut Context<'_>,
buf: &mut [u8],
) -> Poll<io::Result<usize>> {
Pin::new(&mut self.0).poll_read(cx, buf)
}
}

impl AsyncWrite for ServerIo {
fn poll_write(
mut self: Pin<&mut Self>,
cx: &mut Context<'_>,
buf: &[u8],
) -> Poll<io::Result<usize>> {
Pin::new(&mut self.0).poll_write(cx, buf)
}

fn poll_flush(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
Pin::new(&mut self.0).poll_flush(cx)
}

fn poll_shutdown(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
Pin::new(&mut self.0).poll_shutdown(cx)
}
}
2 changes: 1 addition & 1 deletion tonic/src/transport/service/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@ pub(crate) use self::add_origin::AddOrigin;
pub(crate) use self::connection::Connection;
pub(crate) use self::connector::connector;
pub(crate) use self::discover::ServiceList;
pub(crate) use self::io::BoxedIo;
pub(crate) use self::io::ServerIo;
pub(crate) use self::layer::{layer_fn, ServiceBuilderExt};
pub(crate) use self::router::{Or, Routes};
#[cfg(feature = "tls")]
Expand Down
Loading

0 comments on commit 3eb76ab

Please sign in to comment.