Skip to content

Commit

Permalink
refactor: refactor parameters with full features
Browse files Browse the repository at this point in the history
  • Loading branch information
huster-zhangpeng committed Dec 3, 2024
1 parent fc5d8ad commit 56f3859
Show file tree
Hide file tree
Showing 23 changed files with 1,254 additions and 1,038 deletions.
18 changes: 11 additions & 7 deletions gm-quic/src/client.rs
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@ use std::{

use qbase::{
cid::ConnectionId,
param::{ClientParameters, Parameters},
param::{ClientParameters, CommonParameters},
sid::{handy::ConsistentConcurrency, ControlConcurrency},
token::{ArcTokenRegistry, TokenSink},
};
Expand All @@ -29,7 +29,8 @@ pub struct QuicClient {
_reuse_connection: bool, // TODO
_enable_happy_eyepballs: bool,
_prefer_versions: Vec<u32>,
parameters: Parameters,
parameters: ClientParameters,
remembered: Option<CommonParameters>,
tls_config: Arc<TlsClientConfig>,
streams_controller: Box<dyn Fn(u64, u64) -> Box<dyn ControlConcurrency> + Send + Sync>,
token_sink: Option<Arc<dyn TokenSink>>,
Expand Down Expand Up @@ -63,7 +64,7 @@ impl QuicClient {
reuse_connection: true,
enable_happy_eyepballs: false,
prefer_versions: vec![1],
parameters: Parameters::default(),
parameters: ClientParameters::default(),
tls_config: TlsClientConfig::builder_with_protocol_versions(&[&rustls::version::TLS13]),
streams_controller: Box::new(|bi, uni| Box::new(ConsistentConcurrency::new(bi, uni))),
token_sink: None,
Expand All @@ -80,7 +81,7 @@ impl QuicClient {
reuse_connection: true,
enable_happy_eyepballs: false,
prefer_versions: vec![1],
parameters: Parameters::default(),
parameters: ClientParameters::default(),
tls_config: TlsClientConfig::builder_with_provider(provider)
.with_protocol_versions(&[&rustls::version::TLS13])
.unwrap(),
Expand All @@ -99,7 +100,7 @@ impl QuicClient {
reuse_connection: true,
enable_happy_eyepballs: false,
prefer_versions: vec![1],
parameters: Parameters::default(),
parameters: ClientParameters::default(),
tls_config,
streams_controller: Box::new(|bi, uni| Box::new(ConsistentConcurrency::new(bi, uni))),
token_sink: None,
Expand Down Expand Up @@ -226,6 +227,7 @@ impl QuicClient {
initial_scid,
server_name,
self.parameters,
self.remembered,
streams_ctrl,
tls_config,
token_registry,
Expand All @@ -246,7 +248,7 @@ pub struct QuicClientBuilder<T> {
reuse_connection: bool,
enable_happy_eyepballs: bool,
prefer_versions: Vec<u32>,
parameters: Parameters,
parameters: ClientParameters,
tls_config: T,
streams_controller: Box<dyn Fn(u64, u64) -> Box<dyn ControlConcurrency> + Send + Sync>,
token_sink: Option<Arc<dyn TokenSink>>,
Expand Down Expand Up @@ -309,7 +311,7 @@ impl<T> QuicClientBuilder<T> {
///
/// [transport parameters](https://www.rfc-editor.org/rfc/rfc9000.html#name-transport-parameter-definit)
pub fn with_parameters(mut self, parameters: ClientParameters) -> Self {
self.parameters = parameters.into();
self.parameters = parameters;
self
}

Expand Down Expand Up @@ -495,6 +497,8 @@ impl QuicClientBuilder<TlsClientConfig> {
_enable_happy_eyepballs: self.enable_happy_eyepballs,
_prefer_versions: self.prefer_versions,
parameters: self.parameters,
// TODO: 要能加载上次连接的parameters
remembered: None,
tls_config: Arc::new(self.tls_config),
streams_controller: self.streams_controller,
token_sink: self.token_sink,
Expand Down
6 changes: 4 additions & 2 deletions gm-quic/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -45,7 +45,9 @@ pub struct QuicConnection {

impl QuicConnection {
#[inline]
pub async fn accept_bi_stream(&self) -> io::Result<(StreamId, (StreamReader, StreamWriter))> {
pub async fn accept_bi_stream(
&self,
) -> io::Result<Option<(StreamId, (StreamReader, StreamWriter))>> {
self.inner.accept_bi_stream().await
}

Expand All @@ -68,7 +70,7 @@ impl QuicConnection {
}

#[inline]
pub async fn datagram_writer(&self) -> io::Result<qunreliable::UnreliableWriter> {
pub async fn datagram_writer(&self) -> io::Result<Option<qunreliable::UnreliableWriter>> {
self.inner.datagram_writer().await
}

Expand Down
16 changes: 8 additions & 8 deletions gm-quic/src/server.rs
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@ use dashmap::DashMap;
use qbase::{
cid::ConnectionId,
packet::{header::GetScid, long, DataHeader, DataPacket, InitialHeader, RetryHeader},
param::{Parameters, ServerParameters},
param::ServerParameters,
sid::{handy::ConsistentConcurrency, ControlConcurrency},
token::{ArcTokenRegistry, TokenProvider},
};
Expand Down Expand Up @@ -59,7 +59,7 @@ pub struct QuicServer {
passive_listening: bool,
_supported_versions: Vec<u32>,
_load_balance: Arc<dyn Fn(InitialHeader) -> Option<RetryHeader> + Send + Sync + 'static>,
parameters: Parameters,
parameters: ServerParameters,
tls_config: Arc<TlsServerConfig>,
streams_controller:
Box<dyn Fn(u64, u64) -> Box<dyn ControlConcurrency> + Send + Sync + 'static>,
Expand All @@ -73,7 +73,7 @@ impl QuicServer {
passive_listening: false,
supported_versions: Vec::with_capacity(2),
load_balance: Arc::new(|_| None),
parameters: Parameters::default(),
parameters: ServerParameters::default(),
tls_config: TlsServerConfig::builder_with_protocol_versions(&[&rustls::version::TLS13]),
streams_controller: Box::new(|bi, uni| Box::new(ConsistentConcurrency::new(bi, uni))),
token_provider: None,
Expand All @@ -88,7 +88,7 @@ impl QuicServer {
passive_listening: false,
supported_versions: Vec::with_capacity(2),
load_balance: Arc::new(|_| None),
parameters: Parameters::default(),
parameters: ServerParameters::default(),
tls_config,
streams_controller: Box::new(|bi, uni| Box::new(ConsistentConcurrency::new(bi, uni))),
token_provider: None,
Expand All @@ -103,7 +103,7 @@ impl QuicServer {
passive_listening: false,
supported_versions: Vec::with_capacity(2),
load_balance: Arc::new(|_| None),
parameters: Parameters::default(),
parameters: ServerParameters::default(),
tls_config: TlsServerConfig::builder_with_provider(provider)
.with_protocol_versions(&[&rustls::version::TLS13])
.unwrap(),
Expand Down Expand Up @@ -249,7 +249,7 @@ pub struct QuicServerBuilder<T> {
supported_versions: Vec<u32>,
passive_listening: bool,
load_balance: Arc<dyn Fn(InitialHeader) -> Option<RetryHeader> + Send + Sync + 'static>,
parameters: Parameters,
parameters: ServerParameters,
tls_config: T,
streams_controller:
Box<dyn Fn(u64, u64) -> Box<dyn ControlConcurrency> + Send + Sync + 'static>,
Expand All @@ -262,7 +262,7 @@ pub struct QuicServerSniBuilder<T> {
passive_listening: bool,
load_balance: Arc<dyn Fn(InitialHeader) -> Option<RetryHeader> + Send + Sync + 'static>,
hosts: Arc<DashMap<String, Host>>,
parameters: Parameters,
parameters: ServerParameters,
tls_config: T,
streams_controller:
Box<dyn Fn(u64, u64) -> Box<dyn ControlConcurrency> + Send + Sync + 'static>,
Expand Down Expand Up @@ -315,7 +315,7 @@ impl<T> QuicServerBuilder<T> {
///
/// [transport parameters](https://www.rfc-editor.org/rfc/rfc9000.html#name-transport-parameter-definit)
pub fn with_parameters(mut self, parameters: ServerParameters) -> Self {
self.parameters = parameters.into();
self.parameters = parameters;
self
}

Expand Down
13 changes: 11 additions & 2 deletions h3-shim/src/conn.rs
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,12 @@ impl QuicConnection {
accpet_uni: AcceptUniStreams::new(conn.clone()),
open_bi: OpenBiStreams::new(conn.clone()),
open_uni: OpenUniStreams::new(conn.clone()),
send_datagram: SendDatagram(conn.datagram_writer().await.map_err(Into::into)),
send_datagram: SendDatagram(
conn.datagram_writer()
.await
.map(Option::unwrap)
.map_err(Into::into),
),
recv_datagram: RecvDatagram(conn.datagram_reader().map_err(Into::into)),
connection: conn,
}
Expand Down Expand Up @@ -247,7 +252,11 @@ struct AcceptBiStreams(BoxStream<Result<(StreamId, (StreamReader, StreamWriter))
impl AcceptBiStreams {
fn new(conn: Arc<gm_quic::QuicConnection>) -> Self {
let stream = futures::stream::unfold(conn, |conn| async {
let bidi = conn.accept_bi_stream().await.map_err(Into::into);
let bidi = conn
.accept_bi_stream()
.await
.map(Option::unwrap)
.map_err(Into::into);
if bidi.is_err() && !conn.is_active() {
return None;
}
Expand Down
1 change: 1 addition & 0 deletions h3-shim/src/tests.rs
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@ mod server_example {
}

#[tokio::test]
#[ignore]
async fn h3_test() {
tracing_subscriber::fmt()
.with_env_filter(tracing_subscriber::EnvFilter::from_default_env())
Expand Down
16 changes: 13 additions & 3 deletions qbase/src/cid/connection_id.rs
Original file line number Diff line number Diff line change
Expand Up @@ -38,13 +38,23 @@ impl ConnectionId {
/// See [`ConnectionId`].
pub fn be_connection_id(input: &[u8]) -> IResult<&[u8], ConnectionId> {
let (remain, len) = be_u8(input)?;
if len as usize > MAX_CID_SIZE {
be_connection_id_with_len(remain, len as usize)
}

/// Parse a given `len` connection ID from the input buffer,
/// [nom](https://docs.rs/nom/latest/nom/) parser style.
///
/// ## Note:
///
/// The connection ID length is limited to 20 bytes, or it will return an error.
pub fn be_connection_id_with_len(input: &[u8], len: usize) -> IResult<&[u8], ConnectionId> {
if len > MAX_CID_SIZE {
return Err(nom::Err::Error(nom::error::make_error(
remain,
input,
nom::error::ErrorKind::TooLarge,
)));
}
let (remain, bytes) = take(len as usize)(remain)?;
let (remain, bytes) = take(len)(input)?;
Ok((remain, ConnectionId::from_slice(bytes)))
}

Expand Down
24 changes: 9 additions & 15 deletions qbase/src/frame/error.rs
Original file line number Diff line number Diff line change
Expand Up @@ -29,21 +29,15 @@ impl From<Error> for TransportError {
fn from(e: Error) -> Self {
match e {
// An endpoint MUST treat receipt of a packet containing no frames as a connection error of type PROTOCOL_VIOLATION.
Error::NoFrames => Self::new(
TransportErrorKind::ProtocolViolation,
FrameType::Padding,
e.to_string(),
),
Error::IncompleteType(_) => Self::new(
TransportErrorKind::FrameEncoding,
FrameType::Padding,
e.to_string(),
),
Error::InvalidType(_) => Self::new(
TransportErrorKind::FrameEncoding,
FrameType::Padding,
e.to_string(),
),
Error::NoFrames => {
Self::with_default_fty(TransportErrorKind::ProtocolViolation, e.to_string())
}
Error::IncompleteType(_) => {
Self::with_default_fty(TransportErrorKind::FrameEncoding, e.to_string())
}
Error::InvalidType(_) => {
Self::with_default_fty(TransportErrorKind::FrameEncoding, e.to_string())
}
Error::WrongType(fty, _) => {
Self::new(TransportErrorKind::FrameEncoding, fty, e.to_string())
}
Expand Down
7 changes: 7 additions & 0 deletions qbase/src/handshake.rs
Original file line number Diff line number Diff line change
Expand Up @@ -195,6 +195,13 @@ where
}
}

pub fn done(&self) {
match self {
Handshake::Client(..) => (), /* for client, do nothing */
Handshake::Server(h) => h.done(),
}
}

/// Return the role of this handshake signal.
pub fn role(&self) -> Role {
match self {
Expand Down
46 changes: 10 additions & 36 deletions qbase/src/packet.rs
Original file line number Diff line number Diff line change
Expand Up @@ -397,33 +397,7 @@ mod tests {
use super::*;
use crate::frame::CryptoFrame;

struct TransparentKeys(rustls::quic::DirectionalKeys);

impl TransparentKeys {
pub fn new() -> Self {
rustls::crypto::ring::default_provider()
.install_default()
.expect("Failed to install ring crypto provider");
let suite = rustls::crypto::ring::default_provider()
.cipher_suites
.iter()
.find_map(|cs| match (cs.suite(), cs.tls13()) {
(rustls::CipherSuite::TLS13_AES_128_GCM_SHA256, Some(suite)) => {
Some(suite.quic_suite())
}
_ => None,
})
.flatten()
.unwrap();

let pk = suite.keys(
"meanless".as_bytes(),
rustls::Side::Client,
rustls::quic::Version::V1,
);
TransparentKeys(pk.local)
}
}
struct TransparentKeys;

impl rustls::quic::PacketKey for TransparentKeys {
fn decrypt_in_place<'a>(
Expand All @@ -445,15 +419,15 @@ mod tests {
}

fn confidentiality_limit(&self) -> u64 {
self.0.packet.confidentiality_limit()
0
}

fn integrity_limit(&self) -> u64 {
self.0.packet.integrity_limit()
0
}

fn tag_len(&self) -> usize {
self.0.packet.tag_len()
16
}
}

Expand All @@ -477,7 +451,7 @@ mod tests {
}

fn sample_len(&self) -> usize {
self.0.header.sample_len()
20
}
}

Expand All @@ -488,7 +462,7 @@ mod tests {
ConnectionId::from_slice("testdcid".as_bytes()),
ConnectionId::from_slice("testscid".as_bytes()),
)
.initial(Vec::with_capacity(0));
.initial(b"test_token".to_vec());

let pn = (0, PacketNumber::encode(0, 0));
let tag_len = 16;
Expand All @@ -503,11 +477,10 @@ mod tests {
assert!(writer.is_ack_eliciting());
assert!(writer.in_flight());

let transparent_keys = TransparentKeys::new();
let packet = writer.encrypt_long_packet(&transparent_keys, &transparent_keys);
let packet = writer.encrypt_long_packet(&TransparentKeys, &TransparentKeys);
assert!(packet.is_ack_eliciting());
assert!(packet.in_flight());
assert_eq!(packet.len(), 58);
assert_eq!(packet.len(), 68);
assert_eq!(
packet.deref(),
[
Expand All @@ -525,7 +498,8 @@ mod tests {
// source connection id, "testscid"
8, // scid length
b't', b'e', b's', b't', b's', b'c', b'i', b'd', // scid bytes
0, // token length, no token
10, // token length, no token
b't', b'e', b's', b't', b'_', b't', b'o', b'k', b'e', b'n', // token bytes
64, 16, // payload length, 2 bytes encoded varint
0, // encoded packet number
// crypto frame header
Expand Down
Loading

0 comments on commit 56f3859

Please sign in to comment.