Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

refactor(proxy): make poll_read_exact a general trait for vmess and other streams #675

Closed
wants to merge 4 commits into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
55 changes: 55 additions & 0 deletions clash_lib/src/common/io.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,11 +2,13 @@
use std::future::Future;
use std::{
io,
mem::MaybeUninit,
pin::Pin,
task::{Context, Poll},
time::Duration,
};

use bytes::BytesMut;
use futures::ready;
use tokio::io::{AsyncRead, AsyncWrite, ReadBuf};

Expand Down Expand Up @@ -352,3 +354,56 @@ where
}
.await
}

pub trait ReadExactBase {
/// inner stream to be polled
type I: AsyncRead + Unpin;
/// prepare the inner stream, read buffer and read position
fn decompose(&mut self) -> (&mut Self::I, &mut BytesMut, &mut usize);
}

pub trait ReadExt: ReadExactBase {
fn poll_read_exact(
&mut self,
cx: &mut std::task::Context,
size: usize,
) -> Poll<std::io::Result<()>>;
}

impl<T: ReadExactBase> ReadExt for T {
fn poll_read_exact(
&mut self,
cx: &mut std::task::Context,
size: usize,
) -> Poll<std::io::Result<()>> {
let (raw, read_buf, read_pos) = self.decompose();
read_buf.reserve(size);
// # safety: read_buf has reserved `size`
unsafe { read_buf.set_len(size) }
loop {
if *read_pos < size {
// # safety: read_pos<size==read_buf.len(), and
// read_buf[0..read_pos] is initialized
let dst = unsafe {
&mut *((&mut read_buf[*read_pos..size]) as *mut _
as *mut [MaybeUninit<u8>])
};
let mut buf = ReadBuf::uninit(dst);
let ptr = buf.filled().as_ptr();
ready!(Pin::new(&mut *raw).poll_read(cx, &mut buf))?;
assert_eq!(ptr, buf.filled().as_ptr());
if buf.filled().is_empty() {
return Poll::Ready(Err(std::io::Error::new(
std::io::ErrorKind::UnexpectedEof,
"unexpected eof",
)));
}
*read_pos += buf.filled().len();
} else {
assert!(*read_pos == size);
*read_pos = 0;
return Poll::Ready(Ok(()));
}
}
}
}
105 changes: 0 additions & 105 deletions clash_lib/src/proxy/group/loadbalance/helpers.rs
Original file line number Diff line number Diff line change
Expand Up @@ -167,108 +167,3 @@ pub fn strategy_sticky_session(proxy_manager: ProxyManager) -> StrategyFn {
})
})
}

#[cfg(test)]
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

are these intential?

mod tests {
use std::{
net::{Ipv4Addr, SocketAddr, SocketAddrV4},
sync::Arc,
};

use super::*;
use crate::{
app::remote_content_manager::ProxyManager,
proxy::utils::test_utils::noop::{NoopOutboundHandler, NoopResolver},
session::SocksAddr,
};

macro_rules! assert_cache_state {
($state:expr) => {
assert_eq!(
TEST_LRU_STATE.load(std::sync::atomic::Ordering::Relaxed),
$state
);
};
}

#[tokio::test]
async fn test_sticky_session() {
let resolver = Arc::new(NoopResolver);
let proxies = vec![
Arc::new(NoopOutboundHandler {
name: "a".to_string(),
}) as _,
Arc::new(NoopOutboundHandler {
name: "b".to_string(),
}) as _,
Arc::new(NoopOutboundHandler {
name: "c".to_string(),
}) as _,
];
let manager = ProxyManager::new(resolver);
// if the proxy alive state isn't set, will return true by default
// so we need to clear the alive states first
manager.report_alive("a", false).await;
manager.report_alive("b", false).await;
manager.report_alive("c", false).await;

let mut strategy_fn = strategy_sticky_session(manager.clone());

// all proxies is not alive since we have not setup the proxy manager
let res = strategy_fn(proxies.clone(), &Session::default()).await;
assert!(res.is_err());
assert_cache_state!(CACHE_MISS);

manager.report_alive("a", true).await;
manager.report_alive("b", true).await;
manager.report_alive("c", true).await;

let mut session1 = Session::default();
let src1 = Ipv4Addr::new(127, 0, 0, 1);
let dst1 = Ipv4Addr::new(1, 1, 1, 1);
session1.source = SocketAddr::V4(SocketAddrV4::new(src1, 8964));
session1.destination =
SocksAddr::Ip(SocketAddr::V4(SocketAddrV4::new(dst1, 80)));

// 1.1 first time, cache miss & update
let res = strategy_fn(proxies.clone(), &session1).await;
assert_cache_state!(CACHE_UPDATE);
let session1_outbound_name_1 = res.unwrap().name().to_owned();

// 1.2 second time, cache hit
let res = strategy_fn(proxies.clone(), &session1).await;
assert_eq!(res.unwrap().name(), session1_outbound_name_1);
assert_cache_state!(CACHE_HIT);
// 1.3 third time, cache hit
let res = strategy_fn(proxies.clone(), &session1).await;
assert_eq!(res.unwrap().name(), session1_outbound_name_1);
assert_cache_state!(CACHE_HIT);

// 1.4 change the source address, cache miss & update
let dst1_2 = Ipv4Addr::new(8, 8, 8, 8);
session1.destination =
SocksAddr::Ip(SocketAddr::V4(SocketAddrV4::new(dst1_2, 80)));
let res = strategy_fn(proxies.clone(), &session1).await;
assert_cache_state!(CACHE_UPDATE);
let session1_outbound_name_2 = res.unwrap().name().to_owned();

// 1.5 cache hit
let res = strategy_fn(proxies.clone(), &session1).await;
assert_eq!(res.unwrap().name(), session1_outbound_name_2);
assert_cache_state!(CACHE_HIT);

for i in 1..100 {
// 1.6 change the src address, cache miss & update
let src1_new = Ipv4Addr::new(192, 168, 2, i);
session1.source = SocketAddr::V4(SocketAddrV4::new(src1_new, 8964));
let res = strategy_fn(proxies.clone(), &session1).await;
assert_cache_state!(CACHE_UPDATE);
let session1_outbound_name_new = res.unwrap().name().to_owned();

// 1.6 cache hit
let res = strategy_fn(proxies.clone(), &session1).await;
assert_eq!(res.unwrap().name(), session1_outbound_name_new);
assert_cache_state!(CACHE_HIT);
}
}
}
71 changes: 11 additions & 60 deletions clash_lib/src/proxy/shadowsocks/shadow_tls/stream.rs
Original file line number Diff line number Diff line change
@@ -1,13 +1,14 @@
use std::{
mem::MaybeUninit,
pin::Pin,
ptr::{copy, copy_nonoverlapping},
task::{ready, Poll},
};

use byteorder::{BigEndian, WriteBytesExt};
use bytes::{BufMut, BytesMut};
use tokio::io::{AsyncRead, AsyncWrite, ReadBuf};
use tokio::io::{AsyncRead, AsyncWrite};

use crate::common::io::{ReadExactBase, ReadExt};

use super::utils::{prelude::*, *};

Expand All @@ -26,60 +27,6 @@ pub enum WriteState {
FlushingData(usize, usize, usize),
}

pub trait AsyncReadUnpin: AsyncRead + Unpin {}

impl<T: AsyncRead + Unpin> AsyncReadUnpin for T {}

pub trait ReadExtBase {
fn prepare(&mut self) -> (&mut dyn AsyncReadUnpin, &mut BytesMut, &mut usize);
}

pub trait ReadExt {
fn poll_read_exact(
&mut self,
cx: &mut std::task::Context,
size: usize,
) -> Poll<std::io::Result<()>>;
}

impl<T: ReadExtBase> ReadExt for T {
fn poll_read_exact(
&mut self,
cx: &mut std::task::Context,
size: usize,
) -> Poll<std::io::Result<()>> {
let (raw, read_buf, read_pos) = self.prepare();
read_buf.reserve(size);
// # safety: read_buf has reserved `size`
unsafe { read_buf.set_len(size) }
loop {
if *read_pos < size {
// # safety: read_pos<size==read_buf.len(), and
// read_buf[0..read_pos] is initialized
let dst = unsafe {
&mut *((&mut read_buf[*read_pos..size]) as *mut _
as *mut [MaybeUninit<u8>])
};
let mut buf = ReadBuf::uninit(dst);
let ptr = buf.filled().as_ptr();
ready!(Pin::new(&mut *raw).poll_read(cx, &mut buf))?;
assert_eq!(ptr, buf.filled().as_ptr());
if buf.filled().is_empty() {
return Poll::Ready(Err(std::io::Error::new(
std::io::ErrorKind::UnexpectedEof,
"unexpected eof",
)));
}
*read_pos += buf.filled().len();
} else {
assert!(*read_pos == size);
*read_pos = 0;
return Poll::Ready(Ok(()));
}
}
}
}

#[derive(Clone, Debug)]
pub struct Certs {
pub(crate) server_random: [u8; TLS_RANDOM_SIZE],
Expand Down Expand Up @@ -139,8 +86,10 @@ impl<S> ProxyTlsStream<S> {
}
}

impl<S: AsyncReadUnpin> ReadExtBase for ProxyTlsStream<S> {
fn prepare(&mut self) -> (&mut dyn AsyncReadUnpin, &mut BytesMut, &mut usize) {
impl<S: AsyncRead + Unpin> ReadExactBase for ProxyTlsStream<S> {
type I = S;

fn decompose(&mut self) -> (&mut Self::I, &mut BytesMut, &mut usize) {
(&mut self.raw, &mut self.read_buf, &mut self.read_pos)
}
}
Expand Down Expand Up @@ -334,8 +283,10 @@ impl<S> VerifiedStream<S> {
}
}

impl<S: AsyncReadUnpin> ReadExtBase for VerifiedStream<S> {
fn prepare(&mut self) -> (&mut dyn AsyncReadUnpin, &mut BytesMut, &mut usize) {
impl<S: AsyncRead + Unpin> ReadExactBase for VerifiedStream<S> {
type I = S;

fn decompose(&mut self) -> (&mut Self::I, &mut BytesMut, &mut usize) {
(&mut self.raw, &mut self.read_buf, &mut self.read_pos)
}
}
Expand Down
53 changes: 7 additions & 46 deletions clash_lib/src/proxy/vmess/vmess_impl/stream.rs
Original file line number Diff line number Diff line change
@@ -1,12 +1,12 @@
use std::{fmt::Debug, mem::MaybeUninit, pin::Pin, task::Poll, time::SystemTime};
use std::{fmt::Debug, pin::Pin, task::Poll, time::SystemTime};

use aes_gcm::Aes128Gcm;
use bytes::{BufMut, BytesMut};
use chacha20poly1305::ChaCha20Poly1305;
use futures::ready;

use md5::Md5;
use tokio::io::{AsyncRead, AsyncWrite, AsyncWriteExt, ReadBuf};
use tokio::io::{AsyncRead, AsyncWrite, AsyncWriteExt};

use crate::{
common::{
Expand Down Expand Up @@ -78,52 +78,13 @@ enum WriteState {
FlushingData(usize, (usize, usize)),
}

pub trait ReadExt {
fn poll_read_exact(
&mut self,
cx: &mut std::task::Context,
size: usize,
) -> Poll<std::io::Result<()>>;
#[allow(unused)]
fn get_data(&self) -> &[u8];
}
use crate::common::io::{ReadExactBase, ReadExt};

impl<S: AsyncRead + Unpin> ReadExt for VmessStream<S> {
// Read exactly `size` bytes into `read_buf`, starting from position 0.
fn poll_read_exact(
&mut self,
cx: &mut std::task::Context,
size: usize,
) -> Poll<std::io::Result<()>> {
self.read_buf.reserve(size);
unsafe { self.read_buf.set_len(size) }
loop {
if self.read_pos < size {
let dst = unsafe {
&mut *((&mut self.read_buf[self.read_pos..size]) as *mut _
as *mut [MaybeUninit<u8>])
};
let mut buf = ReadBuf::uninit(dst);
let ptr = buf.filled().as_ptr();
ready!(Pin::new(&mut self.stream).poll_read(cx, &mut buf))?;
assert_eq!(ptr, buf.filled().as_ptr());
if buf.filled().is_empty() {
return Poll::Ready(Err(std::io::Error::new(
std::io::ErrorKind::UnexpectedEof,
"unexpected eof",
)));
}
self.read_pos += buf.filled().len();
} else {
assert!(self.read_pos == size);
self.read_pos = 0;
return Poll::Ready(Ok(()));
}
}
}
impl<S: AsyncRead + Unpin> ReadExactBase for VmessStream<S> {
type I = S;

fn get_data(&self) -> &[u8] {
self.read_buf.as_ref()
fn decompose(&mut self) -> (&mut Self::I, &mut BytesMut, &mut usize) {
(&mut self.stream, &mut self.read_buf, &mut self.read_pos)
}
}

Expand Down
Loading