diff --git a/clash_lib/src/app/outbound/manager.rs b/clash_lib/src/app/outbound/manager.rs index 5033d55d7..07f857732 100644 --- a/clash_lib/src/app/outbound/manager.rs +++ b/clash_lib/src/app/outbound/manager.rs @@ -570,6 +570,7 @@ impl OutboundManager { ..Default::default() }, providers, + proxy_manager.clone(), ); handlers.insert(proto.name.clone(), Arc::new(load_balance)); diff --git a/clash_lib/src/config/internal/proxy.rs b/clash_lib/src/config/internal/proxy.rs index 46436ecf7..be7462f73 100644 --- a/clash_lib/src/config/internal/proxy.rs +++ b/clash_lib/src/config/internal/proxy.rs @@ -412,6 +412,8 @@ pub enum LoadBalanceStrategy { ConsistentHashing, #[serde(rename = "round-robin")] RoundRobin, + #[serde(rename = "sticky-session")] + StickySession, } #[derive(serde::Serialize, serde::Deserialize, Debug, Default, Clone)] diff --git a/clash_lib/src/proxy/group/loadbalance/helpers.rs b/clash_lib/src/proxy/group/loadbalance/helpers.rs index 54fcfc8df..230516d48 100644 --- a/clash_lib/src/proxy/group/loadbalance/helpers.rs +++ b/clash_lib/src/proxy/group/loadbalance/helpers.rs @@ -1,10 +1,18 @@ -use std::io::Cursor; +use std::{ + io::Cursor, + sync::Arc, + time::{SystemTime, UNIX_EPOCH}, +}; use futures::future::BoxFuture; use murmur3::murmur3_32; use public_suffix::{EffectiveTLDProvider, DEFAULT_PROVIDER}; +use tokio::sync::Mutex; -use crate::{proxy::AnyOutboundHandler, session::Session}; +use crate::{ + app::remote_content_manager::ProxyManager, proxy::AnyOutboundHandler, + session::Session, +}; pub type StrategyFn = Box< dyn FnMut( @@ -25,6 +33,15 @@ fn get_key(sess: &Session) -> String { } } +fn get_key_src_and_dst(sess: &Session) -> String { + let dst = get_key(sess); + let src = match &sess.source { + std::net::SocketAddr::V4(socket_addr_v4) => socket_addr_v4.ip().to_string(), + std::net::SocketAddr::V6(socket_addr_v6) => socket_addr_v6.ip().to_string(), + }; + format!("{}-{}", src, dst) +} + fn jump_hash(key: u64, buckets: i32) -> i32 { let mut key = key; let mut b = -1i64; @@ -63,3 +80,195 @@ pub fn strategy_consistent_hashring() -> StrategyFn { ))) }) } + +#[cfg(test)] +static TEST_LRU_STATE: std::sync::atomic::AtomicUsize = + std::sync::atomic::AtomicUsize::new(CACHE_MISS); +#[cfg(test)] +const CACHE_MISS: usize = 0; +#[cfg(test)] +const CACHE_HIT: usize = 1; +#[cfg(test)] +const CACHE_UPDATE: usize = 2; + +pub fn strategy_sticky_session(proxy_manager: ProxyManager) -> StrategyFn { + let max_retry = 5; + // 10 minutes, 1024 entries + let lru_cache: lru_time_cache::LruCache = + lru_time_cache::LruCache::with_expiry_duration_and_capacity( + std::time::Duration::from_secs(60 * 10), + 1024, + ); + let lru_cache = Arc::new(Mutex::new(lru_cache)); + Box::new(move |proxies, sess| { + let key_str = get_key_src_and_dst(sess); + let key = murmur3_32(&mut Cursor::new(&key_str), 0).unwrap() as u64; + let proxy_manager_clone = proxy_manager.clone(); + let lru_cache_clone = lru_cache.clone(); + let timestamp = || { + SystemTime::now() + .duration_since(UNIX_EPOCH) + .unwrap() + .as_nanos() as u64 + }; + + Box::pin(async move { + let buckets = proxies.len() as i32; + let (start_index, hit) = match lru_cache_clone.lock().await.get(&key) { + Some(&index) => { + #[cfg(test)] + { + TEST_LRU_STATE + .store(CACHE_HIT, std::sync::atomic::Ordering::Relaxed); + } + (index, true) + } + None => (jump_hash(key + timestamp(), buckets) as usize, false), + }; + + // use `do - while` since we have the cached result + let mut index = start_index; + for _ in 0..max_retry { + if let Some(proxy) = proxies.get(index) { + if proxy_manager_clone.alive(proxy.name()).await { + // now it's a valid proxy + // check if it's the same as the last one(likely) + // update the cache if: + // 1. the index is not the same as the start_index + // 2. the start_index is not fetched from the cache + if index != start_index || !hit { + lru_cache_clone.lock().await.insert(key, index); + #[cfg(test)] + { + TEST_LRU_STATE.store( + CACHE_UPDATE, + std::sync::atomic::Ordering::Relaxed, + ); + } + } + return Ok(proxy.clone()); + } + } + // the cached proxy is dead, change the key by a new timestamp and + // try again + index = jump_hash(key + timestamp(), buckets) as usize; + } + // TODO: if we should just remove the key from the cache? + lru_cache_clone.lock().await.insert(key, 0); + #[cfg(test)] + { + TEST_LRU_STATE + .store(CACHE_MISS, std::sync::atomic::Ordering::Relaxed); + } + Err(std::io::Error::new( + std::io::ErrorKind::Other, + "no proxy found", + )) + }) + }) +} + +#[cfg(test)] +mod tests { + use std::{ + net::{Ipv4Addr, SocketAddr, SocketAddrV4}, + sync::Arc, + }; + + use super::*; + use crate::{ + app::remote_content_manager::ProxyManager, + proxy::utils::test_utils::noop::{NoopOutboundHandler, NoopResolver}, + session::SocksAddr, + }; + + macro_rules! assert_cache_state { + ($state:expr) => { + assert_eq!( + TEST_LRU_STATE.load(std::sync::atomic::Ordering::Relaxed), + $state + ); + }; + } + + #[tokio::test] + async fn test_sticky_session() { + let resolver = Arc::new(NoopResolver); + let proxies = vec![ + Arc::new(NoopOutboundHandler { + name: "a".to_string(), + }) as _, + Arc::new(NoopOutboundHandler { + name: "b".to_string(), + }) as _, + Arc::new(NoopOutboundHandler { + name: "c".to_string(), + }) as _, + ]; + let manager = ProxyManager::new(resolver); + // if the proxy alive state isn't set, will return true by default + // so we need to clear the alive states first + manager.report_alive("a", false).await; + manager.report_alive("b", false).await; + manager.report_alive("c", false).await; + + let mut strategy_fn = strategy_sticky_session(manager.clone()); + + // all proxies is not alive since we have not setup the proxy manager + let res = strategy_fn(proxies.clone(), &Session::default()).await; + assert!(res.is_err()); + assert_cache_state!(CACHE_MISS); + + manager.report_alive("a", true).await; + manager.report_alive("b", true).await; + manager.report_alive("c", true).await; + + let mut session1 = Session::default(); + let src1 = Ipv4Addr::new(127, 0, 0, 1); + let dst1 = Ipv4Addr::new(1, 1, 1, 1); + session1.source = SocketAddr::V4(SocketAddrV4::new(src1, 8964)); + session1.destination = + SocksAddr::Ip(SocketAddr::V4(SocketAddrV4::new(dst1, 80))); + + // 1.1 first time, cache miss & update + let res = strategy_fn(proxies.clone(), &session1).await; + assert_cache_state!(CACHE_UPDATE); + let session1_outbound_name_1 = res.unwrap().name().to_owned(); + + // 1.2 second time, cache hit + let res = strategy_fn(proxies.clone(), &session1).await; + assert_eq!(res.unwrap().name(), session1_outbound_name_1); + assert_cache_state!(CACHE_HIT); + // 1.3 third time, cache hit + let res = strategy_fn(proxies.clone(), &session1).await; + assert_eq!(res.unwrap().name(), session1_outbound_name_1); + assert_cache_state!(CACHE_HIT); + + // 1.4 change the source address, cache miss & update + let dst1_2 = Ipv4Addr::new(8, 8, 8, 8); + session1.destination = + SocksAddr::Ip(SocketAddr::V4(SocketAddrV4::new(dst1_2, 80))); + let res = strategy_fn(proxies.clone(), &session1).await; + assert_cache_state!(CACHE_UPDATE); + let session1_outbound_name_2 = res.unwrap().name().to_owned(); + + // 1.5 cache hit + let res = strategy_fn(proxies.clone(), &session1).await; + assert_eq!(res.unwrap().name(), session1_outbound_name_2); + assert_cache_state!(CACHE_HIT); + + for i in 1..100 { + // 1.6 change the src address, cache miss & update + let src1_new = Ipv4Addr::new(192, 168, 2, i); + session1.source = SocketAddr::V4(SocketAddrV4::new(src1_new, 8964)); + let res = strategy_fn(proxies.clone(), &session1).await; + assert_cache_state!(CACHE_UPDATE); + let session1_outbound_name_new = res.unwrap().name().to_owned(); + + // 1.6 cache hit + let res = strategy_fn(proxies.clone(), &session1).await; + assert_eq!(res.unwrap().name(), session1_outbound_name_new); + assert_cache_state!(CACHE_HIT); + } + } +} diff --git a/clash_lib/src/proxy/group/loadbalance/mod.rs b/clash_lib/src/proxy/group/loadbalance/mod.rs index 1766d04d9..490019aa9 100644 --- a/clash_lib/src/proxy/group/loadbalance/mod.rs +++ b/clash_lib/src/proxy/group/loadbalance/mod.rs @@ -3,6 +3,7 @@ mod helpers; use std::{collections::HashMap, io, sync::Arc}; use erased_serde::Serialize; +use helpers::strategy_sticky_session; use tokio::sync::Mutex; use tracing::debug; @@ -10,7 +11,9 @@ use crate::{ app::{ dispatcher::{BoxedChainedDatagram, BoxedChainedStream}, dns::ThreadSafeDNSResolver, - remote_content_manager::providers::proxy_provider::ThreadSafeProxyProvider, + remote_content_manager::{ + providers::proxy_provider::ThreadSafeProxyProvider, ProxyManager, + }, }, config::internal::proxy::LoadBalanceStrategy, proxy::{ @@ -55,10 +58,14 @@ impl Handler { pub fn new( opts: HandlerOptions, providers: Vec, + proxy_manager: ProxyManager, ) -> Self { let strategy_fn = match opts.strategy { LoadBalanceStrategy::ConsistentHashing => strategy_consistent_hashring(), LoadBalanceStrategy::RoundRobin => strategy_rr(), + LoadBalanceStrategy::StickySession => { + strategy_sticky_session(proxy_manager) + } }; Self { diff --git a/clash_lib/src/proxy/utils/test_utils/mod.rs b/clash_lib/src/proxy/utils/test_utils/mod.rs index cb0bd1e22..96000df07 100644 --- a/clash_lib/src/proxy/utils/test_utils/mod.rs +++ b/clash_lib/src/proxy/utils/test_utils/mod.rs @@ -23,6 +23,7 @@ use self::docker_runner::RunAndCleanup; pub mod config_helper; pub mod consts; pub mod docker_runner; +pub mod noop; // TODO: add the throughput metrics pub async fn ping_pong_test( diff --git a/clash_lib/src/proxy/utils/test_utils/noop.rs b/clash_lib/src/proxy/utils/test_utils/noop.rs new file mode 100644 index 000000000..634bf0b28 --- /dev/null +++ b/clash_lib/src/proxy/utils/test_utils/noop.rs @@ -0,0 +1,121 @@ +use std::io; + +use async_trait::async_trait; +use hickory_client::op; + +use crate::{ + app::{ + dispatcher::{BoxedChainedDatagram, BoxedChainedStream}, + dns::{ClashResolver, ResolverKind, ThreadSafeDNSResolver}, + }, + proxy::{ConnectorType, DialWithConnector, OutboundHandler, OutboundType}, + session::Session, +}; + +pub struct NoopResolver; + +#[async_trait] +impl ClashResolver for NoopResolver { + async fn resolve( + &self, + _host: &str, + _enhanced: bool, + ) -> anyhow::Result> { + Ok(None) + } + + async fn resolve_v4( + &self, + _host: &str, + _enhanced: bool, + ) -> anyhow::Result> { + Ok(None) + } + + async fn resolve_v6( + &self, + _host: &str, + _enhanced: bool, + ) -> anyhow::Result> { + Ok(None) + } + + async fn cached_for(&self, _ip: std::net::IpAddr) -> Option { + None + } + + /// Used for DNS Server + async fn exchange(&self, _message: &op::Message) -> anyhow::Result { + Err(anyhow::anyhow!("unsupported")) + } + + /// Only used for look up fake IP + async fn reverse_lookup(&self, _ip: std::net::IpAddr) -> Option { + None + } + + async fn is_fake_ip(&self, _ip: std::net::IpAddr) -> bool { + false + } + + fn fake_ip_enabled(&self) -> bool { + false + } + + fn ipv6(&self) -> bool { + false + } + + fn set_ipv6(&self, _enable: bool) {} + + fn kind(&self) -> ResolverKind { + ResolverKind::Clash + } +} + +#[derive(Debug)] +pub struct NoopOutboundHandler { + pub name: String, +} + +#[async_trait] +impl DialWithConnector for NoopOutboundHandler { + fn support_dialer(&self) -> Option<&str> { + None + } +} + +#[async_trait] +impl OutboundHandler for NoopOutboundHandler { + fn name(&self) -> &str { + &self.name + } + + fn proto(&self) -> OutboundType { + OutboundType::Direct + } + + async fn support_udp(&self) -> bool { + false + } + + async fn connect_stream( + &self, + _sess: &Session, + _resolver: ThreadSafeDNSResolver, + ) -> io::Result { + Err(io::Error::new(io::ErrorKind::Other, "noop")) + } + + async fn connect_datagram( + &self, + _sess: &Session, + _resolver: ThreadSafeDNSResolver, + ) -> io::Result { + Err(io::Error::new(io::ErrorKind::Other, "noop")) + } + + async fn support_connector(&self) -> ConnectorType { + ConnectorType::None + } +}