Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat(volo-http): add client ip #535

Merged
merged 30 commits into from
Nov 18, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
30 commits
Select commit Hold shift + click to select a range
46f6c51
feat(http): add client ip
StellarisW Nov 15, 2024
2719019
feat(http): add client ip
StellarisW Nov 15, 2024
38767c8
feat(http): add client ip
StellarisW Nov 15, 2024
2d34efd
feat(http): add client ip
StellarisW Nov 15, 2024
3496959
feat(http): add client ip
StellarisW Nov 15, 2024
a678ea9
feat(http): add client ip
StellarisW Nov 15, 2024
48bc96a
feat(http): add client ip
StellarisW Nov 15, 2024
94691d7
feat(http): add client ip
StellarisW Nov 18, 2024
441ceb7
feat(http): add client ip
StellarisW Nov 18, 2024
4383afb
feat(http): add client ip
StellarisW Nov 18, 2024
963a860
Merge branch 'main' into feat(http)/client_ip
StellarisW Nov 18, 2024
83108dc
feat(http): add client ip
StellarisW Nov 18, 2024
015063b
feat(http): add client ip
StellarisW Nov 18, 2024
2893a9e
feat(http): add client ip
StellarisW Nov 18, 2024
d7b9634
feat(http): add client ip
StellarisW Nov 18, 2024
94e5692
feat(http): add client ip
StellarisW Nov 18, 2024
5b9e113
feat(http): add client ip
StellarisW Nov 18, 2024
1dd2624
feat(http): add client ip
StellarisW Nov 18, 2024
17a1c8f
feat(http): add client ip
StellarisW Nov 18, 2024
c6d9e92
feat(http): add client ip
StellarisW Nov 18, 2024
1b87816
feat(http): add client ip
StellarisW Nov 18, 2024
8a4fd59
feat(http): add client ip
StellarisW Nov 18, 2024
528a781
feat(http): add client ip
StellarisW Nov 18, 2024
e7f134a
feat(http): add client ip
StellarisW Nov 18, 2024
269df72
Merge branch 'main' into feat(http)/client_ip
StellarisW Nov 18, 2024
3bb42fc
feat(http): add client ip
StellarisW Nov 18, 2024
9d29b0b
feat(http): add client ip
StellarisW Nov 18, 2024
b24d32d
feat(http): add client ip
StellarisW Nov 18, 2024
590db06
feat(http): add client ip
StellarisW Nov 18, 2024
eba87b8
feat(http): add client ip
StellarisW Nov 18, 2024
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions Cargo.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

1 change: 1 addition & 0 deletions Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -69,6 +69,7 @@ http-body-util = "0.1"
hyper = "1"
hyper-timeout = "0.5"
hyper-util = "0.1"
ipnet = "2.10"
itertools = "0.13"
itoa = "1"
libc = "0.2"
Expand Down
1 change: 1 addition & 0 deletions volo-http/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,7 @@ http-body.workspace = true
http-body-util.workspace = true
hyper.workspace = true
hyper-util = { workspace = true, features = ["tokio"] }
ipnet.workspace = true
itoa.workspace = true
memchr.workspace = true
metainfo.workspace = true
Expand Down
18 changes: 18 additions & 0 deletions volo-http/src/server/extract.rs
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@ use crate::{
context::ServerContext,
error::server::{body_collection_error, ExtractBodyError},
request::{Request, RequestPartsExt},
server::utils::client_ip::ClientIP,
utils::macros::impl_deref_and_deref_mut,
};

Expand Down Expand Up @@ -290,6 +291,23 @@ impl FromContext for Method {
}
}

impl FromContext for ClientIP {
type Rejection = Infallible;

async fn from_context(
cx: &mut ServerContext,
_: &mut Parts,
) -> Result<ClientIP, Self::Rejection> {
Ok(ClientIP(
cx.rpc_info
.caller()
.tags
.get::<ClientIP>()
.and_then(|v| v.0),
))
}
}

#[cfg(feature = "query")]
impl<T> FromContext for Query<T>
where
Expand Down
2 changes: 1 addition & 1 deletion volo-http/src/server/layer/timeout.rs
Original file line number Diff line number Diff line change
Expand Up @@ -111,7 +111,7 @@ where
tokio::select! {
resp = fut_service => resp.map(IntoResponse::into_response),
_ = fut_timeout => {
Ok((self.handler.clone()).call(cx))
Ok(self.handler.clone().call(cx))
},
}
}
Expand Down
333 changes: 333 additions & 0 deletions volo-http/src/server/utils/client_ip.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,333 @@
//! Utilities for extracting original client ip
//!
//! See [`ClientIP`] for more details.
use std::{net::IpAddr, str::FromStr};

use http::{HeaderMap, HeaderName};
use ipnet::IpNet;
use motore::{layer::Layer, Service};
use volo::{context::Context, net::Address};

use crate::{context::ServerContext, request::Request, utils::macros::impl_deref_and_deref_mut};

/// [`Layer`] for extracting client ip
///
/// See [`ClientIP`] for more details.
#[derive(Clone, Default)]
pub struct ClientIPLayer {
config: ClientIPConfig,
}

impl ClientIPLayer {
/// Create a new [`ClientIPLayer`] with default config
pub fn new() -> Self {
Default::default()
}

/// Create a new [`ClientIPLayer`] with the given [`ClientIPConfig`]
pub fn with_config(self, config: ClientIPConfig) -> Self {
Self { config }
}
}

impl<S> Layer<S> for ClientIPLayer
where
S: Send + Sync + 'static,
{
type Service = ClientIPService<S>;

fn layer(self, inner: S) -> Self::Service {
ClientIPService {
service: inner,
config: self.config,
}
}
}

/// Config for extract client ip
#[derive(Clone, Debug)]
pub struct ClientIPConfig {
remote_ip_headers: Vec<HeaderName>,
trusted_cidrs: Vec<IpNet>,
}

impl Default for ClientIPConfig {
fn default() -> Self {
Self {
remote_ip_headers: vec![
HeaderName::from_static("x-real-ip"),
HeaderName::from_static("x-forwarded-for"),
],
trusted_cidrs: vec!["0.0.0.0/0".parse().unwrap(), "::/0".parse().unwrap()],
}
}
}

impl ClientIPConfig {
/// Create a new [`ClientIPConfig`] with default values
///
/// default remote ip headers: `["X-Real-IP", "X-Forwarded-For"]`
///
/// default trusted cidrs: `["0.0.0.0/0", "::/0"]`
pub fn new() -> Self {
Default::default()
}

/// Get Real Client IP by parsing the given headers.
///
/// See [`ClientIP`] for more details.
///
/// # Example
///
/// ```rust
/// use volo_http::server::utils::client_ip::ClientIPConfig;
///
/// let client_ip_config =
/// ClientIPConfig::new().with_remote_ip_headers(vec!["X-Real-IP", "X-Forwarded-For"]);
/// ```
pub fn with_remote_ip_headers<I>(
self,
headers: I,
) -> Result<Self, http::header::InvalidHeaderName>
where
I: IntoIterator,
I::Item: AsRef<str>,
{
let headers = headers.into_iter().collect::<Vec<_>>();
let mut remote_ip_headers = Vec::with_capacity(headers.len());
for header_str in headers {
let header_value = HeaderName::from_str(header_str.as_ref())?;
remote_ip_headers.push(header_value);
}

Ok(Self {
remote_ip_headers,
trusted_cidrs: self.trusted_cidrs,
})
}

/// Get Real Client IP if it is trusted, otherwise it will just return caller ip.
///
/// See [`ClientIP`] for more details.
///
/// # Example
///
/// ```rust
/// use volo_http::server::utils::client_ip::ClientIPConfig;
///
/// let client_ip_config = ClientIPConfig::new()
/// .with_trusted_cidrs(vec!["0.0.0.0/0".parse().unwrap(), "::/0".parse().unwrap()]);
/// ```
pub fn with_trusted_cidrs<H>(self, cidrs: H) -> Self
where
H: IntoIterator<Item = IpNet>,
{
Self {
remote_ip_headers: self.remote_ip_headers,
trusted_cidrs: cidrs.into_iter().collect(),
}
}
}

/// Return original client IP Address
///
/// If you want to get client IP by retrieving specific headers, you can use
/// [`with_remote_ip_headers`](ClientIPConfig::with_remote_ip_headers) to set the
/// headers.
///
/// If you want to get client IP that is trusted with specific cidrs, you can use
/// [`with_trusted_cidrs`](ClientIPConfig::with_trusted_cidrs) to set the cidrs.
///
/// # Example
///
/// ## Default config
///
/// default remote ip headers: `["X-Real-IP", "X-Forwarded-For"]`
///
/// default trusted cidrs: `["0.0.0.0/0", "::/0"]`
///
/// ```rust
/// ///
/// use volo_http::server::utils::client_ip::ClientIP;
/// use volo_http::server::{
/// route::{get, Router},
/// utils::client_ip::{ClientIPConfig, ClientIPLayer},
/// Server,
/// };
///
/// async fn handler(client_ip: ClientIP) -> String {
/// client_ip.unwrap().to_string()
/// }
///
/// let router: Router = Router::new()
/// .route("/", get(handler))
/// .layer(ClientIPLayer::new());
/// ```
///
/// ## With custom config
///
/// ```rust
/// use http::HeaderMap;
/// use volo_http::{
/// context::ServerContext,
/// server::{
/// route::{get, Router},
/// utils::client_ip::{ClientIP, ClientIPConfig, ClientIPLayer},
/// Server,
/// },
/// };
///
/// async fn handler(client_ip: ClientIP) -> String {
/// client_ip.unwrap().to_string()
/// }
///
/// let router: Router = Router::new().route("/", get(handler)).layer(
/// ClientIPLayer::new().with_config(
/// ClientIPConfig::new()
/// .with_remote_ip_headers(vec!["x-real-ip", "x-forwarded-for"])
/// .unwrap()
/// .with_trusted_cidrs(vec!["0.0.0.0/0".parse().unwrap(), "::/0".parse().unwrap()]),
/// ),
/// );
/// ```
pub struct ClientIP(pub Option<IpAddr>);

impl_deref_and_deref_mut!(ClientIP, Option<IpAddr>, 0);

/// [`ClientIPLayer`] generated [`Service`]
///
/// See [`ClientIP`] for more details.
#[derive(Clone)]
pub struct ClientIPService<S> {
service: S,
config: ClientIPConfig,
}

impl<S> ClientIPService<S> {
fn get_client_ip(&self, cx: &ServerContext, headers: &HeaderMap) -> ClientIP {
let remote_ip = match &cx.rpc_info().caller().address {
Some(Address::Ip(socket_addr)) => Some(socket_addr.ip()),
#[cfg(target_family = "unix")]
Some(Address::Unix(_)) => None,
None => return ClientIP(None),
};

if let Some(remote_ip) = remote_ip {
if !self
.config
.trusted_cidrs
.iter()
.any(|cidr| cidr.contains(&IpNet::from(remote_ip)))
{
return ClientIP(None);
}
}

for remote_ip_header in self.config.remote_ip_headers.iter() {
let remote_ips = match headers
.get(remote_ip_header)
.and_then(|v| v.to_str().ok())
.map(|v| v.split(',').map(|s| s.trim()).collect::<Vec<_>>())
{
Some(remote_ips) => remote_ips,
None => continue,
};
for remote_ip in remote_ips.iter() {
if let Ok(remote_ip_addr) = IpAddr::from_str(remote_ip) {
if self
.config
.trusted_cidrs
.iter()
.any(|cidr| cidr.contains(&remote_ip_addr))
{
return ClientIP(Some(remote_ip_addr));
}
}
}
}

ClientIP(remote_ip)
}
}

impl<S, B> Service<ServerContext, Request<B>> for ClientIPService<S>
where
S: Service<ServerContext, Request<B>> + Send + Sync + 'static,
B: Send,
{
type Response = S::Response;
type Error = S::Error;

async fn call(
&self,
cx: &mut ServerContext,
req: Request<B>,
) -> Result<Self::Response, Self::Error> {
let client_ip = self.get_client_ip(cx, req.headers());
cx.rpc_info_mut().caller_mut().tags.insert(client_ip);

self.service.call(cx, req).await
}
}

#[cfg(test)]
mod client_ip_tests {
use std::{net::SocketAddr, str::FromStr};

use http::{HeaderValue, Method};
use motore::{layer::Layer, Service};
use volo::net::Address;

use crate::{
body::BodyConversion,
context::ServerContext,
server::{
route::{get, Route},
utils::client_ip::{ClientIP, ClientIPConfig, ClientIPLayer},
},
utils::test_helpers::simple_req,
};

#[tokio::test]
async fn test_client_ip() {
async fn handler(client_ip: ClientIP) -> String {
client_ip.unwrap().to_string()
}

let route: Route<&str> = Route::new(get(handler));
let service = ClientIPLayer::new()
.with_config(
ClientIPConfig::default().with_trusted_cidrs(vec!["10.0.0.0/8".parse().unwrap()]),
)
.layer(route);

let mut cx = ServerContext::new(Address::from(
SocketAddr::from_str("10.0.0.1:8080").unwrap(),
));

// Test case 1: no remote ip header
let req = simple_req(Method::GET, "/", "");
let resp = service.call(&mut cx, req).await.unwrap();
assert_eq!("10.0.0.1", resp.into_string().await.unwrap());

// Test case 2: with remote ip header
let mut req = simple_req(Method::GET, "/", "");
req.headers_mut()
.insert("X-Real-IP", HeaderValue::from_static("10.0.0.2"));
let resp = service.call(&mut cx, req).await.unwrap();
assert_eq!("10.0.0.2", resp.into_string().await.unwrap());

let mut req = simple_req(Method::GET, "/", "");
req.headers_mut()
.insert("X-Forwarded-For", HeaderValue::from_static("10.0.1.0"));
let resp = service.call(&mut cx, req).await.unwrap();
assert_eq!("10.0.1.0", resp.into_string().await.unwrap());

// Test case 3: with untrusted remote ip
let mut req = simple_req(Method::GET, "/", "");
req.headers_mut()
.insert("X-Real-IP", HeaderValue::from_static("11.0.0.1"));
let resp = service.call(&mut cx, req).await.unwrap();
assert_eq!("10.0.0.1", resp.into_string().await.unwrap());
}
}
Loading
Loading