From 97fa58ca4412408110e39db9f98ecc59d70120ed Mon Sep 17 00:00:00 2001 From: VendettaReborn Date: Wed, 15 Jan 2025 20:45:30 +0800 Subject: [PATCH 1/3] feat: add StickySession load balance strategy --- clash_lib/src/app/dns/mod.rs | 62 +++++++++++++++++++ .../src/proxy/group/loadbalance/helpers.rs | 4 +- clash_lib/src/proxy/mod.rs | 47 ++++++++++++++ 3 files changed, 111 insertions(+), 2 deletions(-) diff --git a/clash_lib/src/app/dns/mod.rs b/clash_lib/src/app/dns/mod.rs index 31a116a10..24d77f452 100644 --- a/clash_lib/src/app/dns/mod.rs +++ b/clash_lib/src/app/dns/mod.rs @@ -33,6 +33,7 @@ pub trait Client: Sync + Send + Debug { type ThreadSafeDNSClient = Arc; pub enum ResolverKind { + Noop, Clash, System, } @@ -77,3 +78,64 @@ pub trait ClashResolver: Sync + Send { fn kind(&self) -> ResolverKind; } + +pub struct NoopResolver; + +#[async_trait] +impl ClashResolver for NoopResolver { + async fn resolve( + &self, + _host: &str, + _enhanced: bool, + ) -> anyhow::Result> { + Ok(None) + } + + async fn resolve_v4( + &self, + _host: &str, + _enhanced: bool, + ) -> anyhow::Result> { + Ok(None) + } + + async fn resolve_v6( + &self, + _host: &str, + _enhanced: bool, + ) -> anyhow::Result> { + Ok(None) + } + + async fn cached_for(&self, _ip: std::net::IpAddr) -> Option { + None + } + + /// Used for DNS Server + async fn exchange(&self, _message: &op::Message) -> anyhow::Result { + Err(anyhow::anyhow!("unsupported")) + } + + /// Only used for look up fake IP + async fn reverse_lookup(&self, _ip: std::net::IpAddr) -> Option { + None + } + + async fn is_fake_ip(&self, _ip: std::net::IpAddr) -> bool { + false + } + + fn fake_ip_enabled(&self) -> bool { + false + } + + fn ipv6(&self) -> bool { + false + } + + fn set_ipv6(&self, _enable: bool) {} + + fn kind(&self) -> ResolverKind { + ResolverKind::Noop + } +} diff --git a/clash_lib/src/proxy/group/loadbalance/helpers.rs b/clash_lib/src/proxy/group/loadbalance/helpers.rs index 230516d48..646fe0899 100644 --- a/clash_lib/src/proxy/group/loadbalance/helpers.rs +++ b/clash_lib/src/proxy/group/loadbalance/helpers.rs @@ -177,8 +177,8 @@ mod tests { use super::*; use crate::{ - app::remote_content_manager::ProxyManager, - proxy::utils::test_utils::noop::{NoopOutboundHandler, NoopResolver}, + app::{dns::NoopResolver, remote_content_manager::ProxyManager}, + proxy::NoopOutboundHandler, session::SocksAddr, }; diff --git a/clash_lib/src/proxy/mod.rs b/clash_lib/src/proxy/mod.rs index 4a02e8772..323fbc121 100644 --- a/clash_lib/src/proxy/mod.rs +++ b/clash_lib/src/proxy/mod.rs @@ -248,3 +248,50 @@ pub trait DialWithConnector { /// this must be called before the outbound handler is used async fn register_connector(&self, _: Arc) {} } + +#[derive(Debug)] +pub struct NoopOutboundHandler { + pub name: String, +} + +#[async_trait] +impl DialWithConnector for NoopOutboundHandler { + fn support_dialer(&self) -> Option<&str> { + None + } +} + +#[async_trait] +impl OutboundHandler for NoopOutboundHandler { + fn name(&self) -> &str { + &self.name + } + + fn proto(&self) -> OutboundType { + OutboundType::Direct + } + + async fn support_udp(&self) -> bool { + false + } + + async fn connect_stream( + &self, + _sess: &Session, + _resolver: ThreadSafeDNSResolver, + ) -> io::Result { + Err(io::Error::new(io::ErrorKind::Other, "noop")) + } + + async fn connect_datagram( + &self, + _sess: &Session, + _resolver: ThreadSafeDNSResolver, + ) -> io::Result { + Err(io::Error::new(io::ErrorKind::Other, "noop")) + } + + async fn support_connector(&self) -> ConnectorType { + ConnectorType::None + } +} From f94affae303caa1f3ea15f516d29bef9ba4b42b5 Mon Sep 17 00:00:00 2001 From: VendettaReborn Date: Thu, 16 Jan 2025 10:34:54 +0800 Subject: [PATCH 2/3] move noopxxx to test_utils --- clash_lib/src/app/dns/mod.rs | 62 ------------------- .../src/proxy/group/loadbalance/helpers.rs | 4 +- clash_lib/src/proxy/mod.rs | 47 -------------- 3 files changed, 2 insertions(+), 111 deletions(-) diff --git a/clash_lib/src/app/dns/mod.rs b/clash_lib/src/app/dns/mod.rs index 24d77f452..31a116a10 100644 --- a/clash_lib/src/app/dns/mod.rs +++ b/clash_lib/src/app/dns/mod.rs @@ -33,7 +33,6 @@ pub trait Client: Sync + Send + Debug { type ThreadSafeDNSClient = Arc; pub enum ResolverKind { - Noop, Clash, System, } @@ -78,64 +77,3 @@ pub trait ClashResolver: Sync + Send { fn kind(&self) -> ResolverKind; } - -pub struct NoopResolver; - -#[async_trait] -impl ClashResolver for NoopResolver { - async fn resolve( - &self, - _host: &str, - _enhanced: bool, - ) -> anyhow::Result> { - Ok(None) - } - - async fn resolve_v4( - &self, - _host: &str, - _enhanced: bool, - ) -> anyhow::Result> { - Ok(None) - } - - async fn resolve_v6( - &self, - _host: &str, - _enhanced: bool, - ) -> anyhow::Result> { - Ok(None) - } - - async fn cached_for(&self, _ip: std::net::IpAddr) -> Option { - None - } - - /// Used for DNS Server - async fn exchange(&self, _message: &op::Message) -> anyhow::Result { - Err(anyhow::anyhow!("unsupported")) - } - - /// Only used for look up fake IP - async fn reverse_lookup(&self, _ip: std::net::IpAddr) -> Option { - None - } - - async fn is_fake_ip(&self, _ip: std::net::IpAddr) -> bool { - false - } - - fn fake_ip_enabled(&self) -> bool { - false - } - - fn ipv6(&self) -> bool { - false - } - - fn set_ipv6(&self, _enable: bool) {} - - fn kind(&self) -> ResolverKind { - ResolverKind::Noop - } -} diff --git a/clash_lib/src/proxy/group/loadbalance/helpers.rs b/clash_lib/src/proxy/group/loadbalance/helpers.rs index 646fe0899..230516d48 100644 --- a/clash_lib/src/proxy/group/loadbalance/helpers.rs +++ b/clash_lib/src/proxy/group/loadbalance/helpers.rs @@ -177,8 +177,8 @@ mod tests { use super::*; use crate::{ - app::{dns::NoopResolver, remote_content_manager::ProxyManager}, - proxy::NoopOutboundHandler, + app::remote_content_manager::ProxyManager, + proxy::utils::test_utils::noop::{NoopOutboundHandler, NoopResolver}, session::SocksAddr, }; diff --git a/clash_lib/src/proxy/mod.rs b/clash_lib/src/proxy/mod.rs index 323fbc121..4a02e8772 100644 --- a/clash_lib/src/proxy/mod.rs +++ b/clash_lib/src/proxy/mod.rs @@ -248,50 +248,3 @@ pub trait DialWithConnector { /// this must be called before the outbound handler is used async fn register_connector(&self, _: Arc) {} } - -#[derive(Debug)] -pub struct NoopOutboundHandler { - pub name: String, -} - -#[async_trait] -impl DialWithConnector for NoopOutboundHandler { - fn support_dialer(&self) -> Option<&str> { - None - } -} - -#[async_trait] -impl OutboundHandler for NoopOutboundHandler { - fn name(&self) -> &str { - &self.name - } - - fn proto(&self) -> OutboundType { - OutboundType::Direct - } - - async fn support_udp(&self) -> bool { - false - } - - async fn connect_stream( - &self, - _sess: &Session, - _resolver: ThreadSafeDNSResolver, - ) -> io::Result { - Err(io::Error::new(io::ErrorKind::Other, "noop")) - } - - async fn connect_datagram( - &self, - _sess: &Session, - _resolver: ThreadSafeDNSResolver, - ) -> io::Result { - Err(io::Error::new(io::ErrorKind::Other, "noop")) - } - - async fn support_connector(&self) -> ConnectorType { - ConnectorType::None - } -} From b07befa06cfd1de157c619ac3be6bbc8b0755680 Mon Sep 17 00:00:00 2001 From: VendettaReborn Date: Sat, 18 Jan 2025 18:53:22 +0800 Subject: [PATCH 3/3] style: make poll_read_exact a general trait --- clash_lib/src/common/io.rs | 55 +++++++++ .../src/proxy/group/loadbalance/helpers.rs | 105 ------------------ .../proxy/shadowsocks/shadow_tls/stream.rs | 71 ++---------- .../src/proxy/vmess/vmess_impl/stream.rs | 53 ++------- 4 files changed, 73 insertions(+), 211 deletions(-) diff --git a/clash_lib/src/common/io.rs b/clash_lib/src/common/io.rs index cdc45d9ce..123de6794 100644 --- a/clash_lib/src/common/io.rs +++ b/clash_lib/src/common/io.rs @@ -2,11 +2,13 @@ use std::future::Future; use std::{ io, + mem::MaybeUninit, pin::Pin, task::{Context, Poll}, time::Duration, }; +use bytes::BytesMut; use futures::ready; use tokio::io::{AsyncRead, AsyncWrite, ReadBuf}; @@ -352,3 +354,56 @@ where } .await } + +pub trait ReadExactBase { + /// inner stream to be polled + type I: AsyncRead + Unpin; + /// prepare the inner stream, read buffer and read position + fn decompose(&mut self) -> (&mut Self::I, &mut BytesMut, &mut usize); +} + +pub trait ReadExt: ReadExactBase { + fn poll_read_exact( + &mut self, + cx: &mut std::task::Context, + size: usize, + ) -> Poll>; +} + +impl ReadExt for T { + fn poll_read_exact( + &mut self, + cx: &mut std::task::Context, + size: usize, + ) -> Poll> { + let (raw, read_buf, read_pos) = self.decompose(); + read_buf.reserve(size); + // # safety: read_buf has reserved `size` + unsafe { read_buf.set_len(size) } + loop { + if *read_pos < size { + // # safety: read_pos]) + }; + let mut buf = ReadBuf::uninit(dst); + let ptr = buf.filled().as_ptr(); + ready!(Pin::new(&mut *raw).poll_read(cx, &mut buf))?; + assert_eq!(ptr, buf.filled().as_ptr()); + if buf.filled().is_empty() { + return Poll::Ready(Err(std::io::Error::new( + std::io::ErrorKind::UnexpectedEof, + "unexpected eof", + ))); + } + *read_pos += buf.filled().len(); + } else { + assert!(*read_pos == size); + *read_pos = 0; + return Poll::Ready(Ok(())); + } + } + } +} diff --git a/clash_lib/src/proxy/group/loadbalance/helpers.rs b/clash_lib/src/proxy/group/loadbalance/helpers.rs index 230516d48..753f64a85 100644 --- a/clash_lib/src/proxy/group/loadbalance/helpers.rs +++ b/clash_lib/src/proxy/group/loadbalance/helpers.rs @@ -167,108 +167,3 @@ pub fn strategy_sticky_session(proxy_manager: ProxyManager) -> StrategyFn { }) }) } - -#[cfg(test)] -mod tests { - use std::{ - net::{Ipv4Addr, SocketAddr, SocketAddrV4}, - sync::Arc, - }; - - use super::*; - use crate::{ - app::remote_content_manager::ProxyManager, - proxy::utils::test_utils::noop::{NoopOutboundHandler, NoopResolver}, - session::SocksAddr, - }; - - macro_rules! assert_cache_state { - ($state:expr) => { - assert_eq!( - TEST_LRU_STATE.load(std::sync::atomic::Ordering::Relaxed), - $state - ); - }; - } - - #[tokio::test] - async fn test_sticky_session() { - let resolver = Arc::new(NoopResolver); - let proxies = vec![ - Arc::new(NoopOutboundHandler { - name: "a".to_string(), - }) as _, - Arc::new(NoopOutboundHandler { - name: "b".to_string(), - }) as _, - Arc::new(NoopOutboundHandler { - name: "c".to_string(), - }) as _, - ]; - let manager = ProxyManager::new(resolver); - // if the proxy alive state isn't set, will return true by default - // so we need to clear the alive states first - manager.report_alive("a", false).await; - manager.report_alive("b", false).await; - manager.report_alive("c", false).await; - - let mut strategy_fn = strategy_sticky_session(manager.clone()); - - // all proxies is not alive since we have not setup the proxy manager - let res = strategy_fn(proxies.clone(), &Session::default()).await; - assert!(res.is_err()); - assert_cache_state!(CACHE_MISS); - - manager.report_alive("a", true).await; - manager.report_alive("b", true).await; - manager.report_alive("c", true).await; - - let mut session1 = Session::default(); - let src1 = Ipv4Addr::new(127, 0, 0, 1); - let dst1 = Ipv4Addr::new(1, 1, 1, 1); - session1.source = SocketAddr::V4(SocketAddrV4::new(src1, 8964)); - session1.destination = - SocksAddr::Ip(SocketAddr::V4(SocketAddrV4::new(dst1, 80))); - - // 1.1 first time, cache miss & update - let res = strategy_fn(proxies.clone(), &session1).await; - assert_cache_state!(CACHE_UPDATE); - let session1_outbound_name_1 = res.unwrap().name().to_owned(); - - // 1.2 second time, cache hit - let res = strategy_fn(proxies.clone(), &session1).await; - assert_eq!(res.unwrap().name(), session1_outbound_name_1); - assert_cache_state!(CACHE_HIT); - // 1.3 third time, cache hit - let res = strategy_fn(proxies.clone(), &session1).await; - assert_eq!(res.unwrap().name(), session1_outbound_name_1); - assert_cache_state!(CACHE_HIT); - - // 1.4 change the source address, cache miss & update - let dst1_2 = Ipv4Addr::new(8, 8, 8, 8); - session1.destination = - SocksAddr::Ip(SocketAddr::V4(SocketAddrV4::new(dst1_2, 80))); - let res = strategy_fn(proxies.clone(), &session1).await; - assert_cache_state!(CACHE_UPDATE); - let session1_outbound_name_2 = res.unwrap().name().to_owned(); - - // 1.5 cache hit - let res = strategy_fn(proxies.clone(), &session1).await; - assert_eq!(res.unwrap().name(), session1_outbound_name_2); - assert_cache_state!(CACHE_HIT); - - for i in 1..100 { - // 1.6 change the src address, cache miss & update - let src1_new = Ipv4Addr::new(192, 168, 2, i); - session1.source = SocketAddr::V4(SocketAddrV4::new(src1_new, 8964)); - let res = strategy_fn(proxies.clone(), &session1).await; - assert_cache_state!(CACHE_UPDATE); - let session1_outbound_name_new = res.unwrap().name().to_owned(); - - // 1.6 cache hit - let res = strategy_fn(proxies.clone(), &session1).await; - assert_eq!(res.unwrap().name(), session1_outbound_name_new); - assert_cache_state!(CACHE_HIT); - } - } -} diff --git a/clash_lib/src/proxy/shadowsocks/shadow_tls/stream.rs b/clash_lib/src/proxy/shadowsocks/shadow_tls/stream.rs index 9f441339c..4a0373eda 100644 --- a/clash_lib/src/proxy/shadowsocks/shadow_tls/stream.rs +++ b/clash_lib/src/proxy/shadowsocks/shadow_tls/stream.rs @@ -1,5 +1,4 @@ use std::{ - mem::MaybeUninit, pin::Pin, ptr::{copy, copy_nonoverlapping}, task::{ready, Poll}, @@ -7,7 +6,9 @@ use std::{ use byteorder::{BigEndian, WriteBytesExt}; use bytes::{BufMut, BytesMut}; -use tokio::io::{AsyncRead, AsyncWrite, ReadBuf}; +use tokio::io::{AsyncRead, AsyncWrite}; + +use crate::common::io::{ReadExactBase, ReadExt}; use super::utils::{prelude::*, *}; @@ -26,60 +27,6 @@ pub enum WriteState { FlushingData(usize, usize, usize), } -pub trait AsyncReadUnpin: AsyncRead + Unpin {} - -impl AsyncReadUnpin for T {} - -pub trait ReadExtBase { - fn prepare(&mut self) -> (&mut dyn AsyncReadUnpin, &mut BytesMut, &mut usize); -} - -pub trait ReadExt { - fn poll_read_exact( - &mut self, - cx: &mut std::task::Context, - size: usize, - ) -> Poll>; -} - -impl ReadExt for T { - fn poll_read_exact( - &mut self, - cx: &mut std::task::Context, - size: usize, - ) -> Poll> { - let (raw, read_buf, read_pos) = self.prepare(); - read_buf.reserve(size); - // # safety: read_buf has reserved `size` - unsafe { read_buf.set_len(size) } - loop { - if *read_pos < size { - // # safety: read_pos]) - }; - let mut buf = ReadBuf::uninit(dst); - let ptr = buf.filled().as_ptr(); - ready!(Pin::new(&mut *raw).poll_read(cx, &mut buf))?; - assert_eq!(ptr, buf.filled().as_ptr()); - if buf.filled().is_empty() { - return Poll::Ready(Err(std::io::Error::new( - std::io::ErrorKind::UnexpectedEof, - "unexpected eof", - ))); - } - *read_pos += buf.filled().len(); - } else { - assert!(*read_pos == size); - *read_pos = 0; - return Poll::Ready(Ok(())); - } - } - } -} - #[derive(Clone, Debug)] pub struct Certs { pub(crate) server_random: [u8; TLS_RANDOM_SIZE], @@ -139,8 +86,10 @@ impl ProxyTlsStream { } } -impl ReadExtBase for ProxyTlsStream { - fn prepare(&mut self) -> (&mut dyn AsyncReadUnpin, &mut BytesMut, &mut usize) { +impl ReadExactBase for ProxyTlsStream { + type I = S; + + fn decompose(&mut self) -> (&mut Self::I, &mut BytesMut, &mut usize) { (&mut self.raw, &mut self.read_buf, &mut self.read_pos) } } @@ -334,8 +283,10 @@ impl VerifiedStream { } } -impl ReadExtBase for VerifiedStream { - fn prepare(&mut self) -> (&mut dyn AsyncReadUnpin, &mut BytesMut, &mut usize) { +impl ReadExactBase for VerifiedStream { + type I = S; + + fn decompose(&mut self) -> (&mut Self::I, &mut BytesMut, &mut usize) { (&mut self.raw, &mut self.read_buf, &mut self.read_pos) } } diff --git a/clash_lib/src/proxy/vmess/vmess_impl/stream.rs b/clash_lib/src/proxy/vmess/vmess_impl/stream.rs index 87b81bc61..16b592dab 100644 --- a/clash_lib/src/proxy/vmess/vmess_impl/stream.rs +++ b/clash_lib/src/proxy/vmess/vmess_impl/stream.rs @@ -1,4 +1,4 @@ -use std::{fmt::Debug, mem::MaybeUninit, pin::Pin, task::Poll, time::SystemTime}; +use std::{fmt::Debug, pin::Pin, task::Poll, time::SystemTime}; use aes_gcm::Aes128Gcm; use bytes::{BufMut, BytesMut}; @@ -6,7 +6,7 @@ use chacha20poly1305::ChaCha20Poly1305; use futures::ready; use md5::Md5; -use tokio::io::{AsyncRead, AsyncWrite, AsyncWriteExt, ReadBuf}; +use tokio::io::{AsyncRead, AsyncWrite, AsyncWriteExt}; use crate::{ common::{ @@ -78,52 +78,13 @@ enum WriteState { FlushingData(usize, (usize, usize)), } -pub trait ReadExt { - fn poll_read_exact( - &mut self, - cx: &mut std::task::Context, - size: usize, - ) -> Poll>; - #[allow(unused)] - fn get_data(&self) -> &[u8]; -} +use crate::common::io::{ReadExactBase, ReadExt}; -impl ReadExt for VmessStream { - // Read exactly `size` bytes into `read_buf`, starting from position 0. - fn poll_read_exact( - &mut self, - cx: &mut std::task::Context, - size: usize, - ) -> Poll> { - self.read_buf.reserve(size); - unsafe { self.read_buf.set_len(size) } - loop { - if self.read_pos < size { - let dst = unsafe { - &mut *((&mut self.read_buf[self.read_pos..size]) as *mut _ - as *mut [MaybeUninit]) - }; - let mut buf = ReadBuf::uninit(dst); - let ptr = buf.filled().as_ptr(); - ready!(Pin::new(&mut self.stream).poll_read(cx, &mut buf))?; - assert_eq!(ptr, buf.filled().as_ptr()); - if buf.filled().is_empty() { - return Poll::Ready(Err(std::io::Error::new( - std::io::ErrorKind::UnexpectedEof, - "unexpected eof", - ))); - } - self.read_pos += buf.filled().len(); - } else { - assert!(self.read_pos == size); - self.read_pos = 0; - return Poll::Ready(Ok(())); - } - } - } +impl ReadExactBase for VmessStream { + type I = S; - fn get_data(&self) -> &[u8] { - self.read_buf.as_ref() + fn decompose(&mut self) -> (&mut Self::I, &mut BytesMut, &mut usize) { + (&mut self.stream, &mut self.read_buf, &mut self.read_pos) } }