Skip to content

Commit

Permalink
refactor(credssp): follow up to #260 (#287)
Browse files Browse the repository at this point in the history
  • Loading branch information
CBenoit authored Nov 17, 2023
1 parent 5530550 commit 8fc213e
Show file tree
Hide file tree
Showing 12 changed files with 172 additions and 206 deletions.
45 changes: 24 additions & 21 deletions crates/ironrdp-async/src/connector.rs
Original file line number Diff line number Diff line change
@@ -1,16 +1,14 @@
use ironrdp_connector::credssp::{CredsspProcessGenerator, CredsspSequence, KerberosConfig};
use ironrdp_connector::sspi::credssp::ClientState;
use ironrdp_connector::sspi::generator::GeneratorState;
use ironrdp_connector::{
credssp_sequence::{CredsspProcessGenerator, CredsspSequence},
custom_err,
sspi::{credssp::ClientState, generator::GeneratorState},
ClientConnector, ClientConnectorState, ConnectionResult, ConnectorResult, KerberosConfig, Sequence as _,
ServerName, State as _, Written,
custom_err, ClientConnector, ClientConnectorState, ConnectionResult, ConnectorError, ConnectorResult,
Sequence as _, ServerName, State as _,
};
use ironrdp_pdu::write_buf::WriteBuf;

use crate::{
framed::{Framed, FramedRead, FramedWrite},
AsyncNetworkClient,
};
use crate::framed::{Framed, FramedRead, FramedWrite};
use crate::AsyncNetworkClient;

#[non_exhaustive]
pub struct ShouldUpgrade;
Expand Down Expand Up @@ -50,10 +48,10 @@ pub fn mark_as_upgraded(_: ShouldUpgrade, connector: &mut ClientConnector) -> Up
pub async fn connect_finalize<S>(
_: Upgraded,
framed: &mut Framed<S>,
mut connector: ClientConnector,
server_name: ServerName,
server_public_key: Vec<u8>,
network_client: Option<&mut dyn AsyncNetworkClient>,
mut connector: ClientConnector,
kerberos_config: Option<KerberosConfig>,
) -> ConnectorResult<ConnectionResult>
where
Expand Down Expand Up @@ -92,20 +90,22 @@ async fn resolve_generator(
network_client: &mut dyn AsyncNetworkClient,
) -> ConnectorResult<ClientState> {
let mut state = generator.start();

loop {
match state {
GeneratorState::Suspended(request) => {
let response = network_client.send(&request).await?;
state = generator.resume(Ok(response));
}
GeneratorState::Completed(client_state) => {
break Ok(client_state.map_err(|e| custom_err!("cannot resolve generator state", e))?)
break client_state
.map_err(|e| ConnectorError::new("CredSSP", ironrdp_connector::ConnectorErrorKind::Credssp(e)))
}
}
}
}

#[instrument(level = "trace", skip(network_client, framed, buf, server_name, server_public_key))]
#[instrument(level = "trace", skip_all)]
async fn perform_credssp_step<S>(
framed: &mut Framed<S>,
connector: &mut ClientConnector,
Expand All @@ -119,10 +119,13 @@ where
S: FramedRead + FramedWrite,
{
assert!(connector.should_perform_credssp());

let mut credssp_sequence = CredsspSequence::new(connector, server_name, server_public_key, kerberos_config)?;

while !credssp_sequence.is_done() {
buf.clear();
let input = if let Some(next_pdu_hint) = credssp_sequence.next_pdu_hint() {

if let Some(next_pdu_hint) = credssp_sequence.next_pdu_hint() {
debug!(
connector.state = connector.state.name(),
hint = ?next_pdu_hint,
Expand All @@ -135,25 +138,23 @@ where
.map_err(|e| ironrdp_connector::custom_err!("read frame by hint", e))?;

trace!(length = pdu.len(), "PDU received");
Some(pdu.to_vec())
} else {
None
};

if credssp_sequence.wants_request_from_server() {
credssp_sequence.read_request_from_server(&input.unwrap_or_else(|| [].to_vec()))?;
credssp_sequence.read_request_from_server(&pdu)?;
}

let client_state = {
let mut generator = credssp_sequence.process();

if let Some(network_client_ref) = network_client.as_deref_mut() {
trace!("resolving network");
resolve_generator(&mut generator, network_client_ref).await?
} else {
generator
.resolve_to_result()
.map_err(|e| custom_err!(" cannot resolve generator without a network client", e))?
.map_err(|e| custom_err!("resolve without network client", e))?
}
}; // drop generator

let written = credssp_sequence.handle_process_result(client_state, buf)?;

if let Some(response_len) = written.size() {
Expand All @@ -165,7 +166,9 @@ where
.map_err(|e| ironrdp_connector::custom_err!("write all", e))?;
}
}

connector.mark_credssp_as_done();

Ok(())
}

Expand All @@ -179,7 +182,7 @@ where
{
buf.clear();

let written: Written = if let Some(next_pdu_hint) = connector.next_pdu_hint() {
let written = if let Some(next_pdu_hint) = connector.next_pdu_hint() {
debug!(
connector.state = connector.state.name(),
hint = ?next_pdu_hint,
Expand Down
44 changes: 23 additions & 21 deletions crates/ironrdp-blocking/src/connector.rs
Original file line number Diff line number Diff line change
@@ -1,10 +1,11 @@
use std::io::{Read, Write};

use ironrdp_connector::credssp::{CredsspProcessGenerator, CredsspSequence, KerberosConfig};
use ironrdp_connector::sspi::credssp::ClientState;
use ironrdp_connector::sspi::generator::GeneratorState;
use ironrdp_connector::sspi::network_client::NetworkClient;
use ironrdp_connector::{
credssp_sequence::{CredsspProcessGenerator, CredsspSequence},
custom_err,
sspi::{credssp::ClientState, generator::GeneratorState, network_client::NetworkClient},
ClientConnector, ClientConnectorState, ConnectionResult, ConnectorResult, KerberosConfig, Sequence as _,
ClientConnector, ClientConnectorState, ConnectionResult, ConnectorError, ConnectorResult, Sequence as _,
ServerName, State as _,
};
use ironrdp_pdu::write_buf::WriteBuf;
Expand Down Expand Up @@ -49,10 +50,10 @@ pub fn mark_as_upgraded(_: ShouldUpgrade, connector: &mut ClientConnector) -> Up
pub fn connect_finalize<S>(
_: Upgraded,
framed: &mut Framed<S>,
mut connector: ClientConnector,
server_name: ServerName,
server_public_key: Vec<u8>,
network_client: &mut impl NetworkClient,
mut connector: ClientConnector,
kerberos_config: Option<KerberosConfig>,
) -> ConnectorResult<ConnectionResult>
where
Expand Down Expand Up @@ -89,28 +90,27 @@ where
Ok(result)
}

#[instrument(level = "info", skip(generator, network_client))]
fn resolve_generator(
generator: &mut CredsspProcessGenerator<'_>,
network_client: &mut impl NetworkClient,
) -> ConnectorResult<ClientState> {
let mut state = generator.start();
let res = loop {

loop {
match state {
GeneratorState::Suspended(request) => {
let response = network_client.send(&request).unwrap();
state = generator.resume(Ok(response));
}
GeneratorState::Completed(client_state) => {
break client_state.map_err(|e| custom_err!("failed to resolve generator", e))?
break client_state
.map_err(|e| ConnectorError::new("CredSSP", ironrdp_connector::ConnectorErrorKind::Credssp(e)))
}
}
};
debug!("client state = {:?}", &res);
Ok(res)
}
}

#[instrument(level = "trace", skip(network_client, framed, buf, server_name, server_public_key))]
#[instrument(level = "trace", skip_all)]
fn perform_credssp_step<S>(
framed: &mut Framed<S>,
connector: &mut ClientConnector,
Expand All @@ -124,10 +124,13 @@ where
S: Read + Write,
{
assert!(connector.should_perform_credssp());

let mut credssp_sequence = CredsspSequence::new(connector, server_name, server_public_key, kerberos_config)?;

while !credssp_sequence.is_done() {
buf.clear();
let input = if let Some(next_pdu_hint) = credssp_sequence.next_pdu_hint() {

if let Some(next_pdu_hint) = credssp_sequence.next_pdu_hint() {
debug!(
connector.state = connector.state.name(),
hint = ?next_pdu_hint,
Expand All @@ -139,18 +142,15 @@ where
.map_err(|e| ironrdp_connector::custom_err!("read frame by hint", e))?;

trace!(length = pdu.len(), "PDU received");
Some(pdu.to_vec())
} else {
None
};

if credssp_sequence.wants_request_from_server() {
credssp_sequence.read_request_from_server(&input.unwrap_or_else(|| [].to_vec()))?;
credssp_sequence.read_request_from_server(&pdu)?;
}

let client_state = {
let mut generator = credssp_sequence.process();
resolve_generator(&mut generator, network_client)?
}; // drop generator

let written = credssp_sequence.handle_process_result(client_state, buf)?;

if let Some(response_len) = written.size() {
Expand All @@ -161,15 +161,17 @@ where
.map_err(|e| ironrdp_connector::custom_err!("write all", e))?;
}
}

connector.mark_credssp_as_done();

Ok(())
}

pub fn single_connect_step<S>(
framed: &mut Framed<S>,
connector: &mut ClientConnector,
buf: &mut WriteBuf,
) -> ConnectorResult<ironrdp_connector::Written>
) -> ConnectorResult<()>
where
S: Read + Write,
{
Expand Down Expand Up @@ -201,5 +203,5 @@ where
.map_err(|e| ironrdp_connector::custom_err!("write all", e))?;
}

Ok(written)
Ok(())
}
44 changes: 22 additions & 22 deletions crates/ironrdp-client/src/network_client.rs
Original file line number Diff line number Diff line change
@@ -1,17 +1,19 @@
use sspi::{Error, ErrorKind};
use std::future::Future;
use std::net::{IpAddr, Ipv4Addr};
use std::{future::Future, pin::Pin};
use tokio::io::{AsyncReadExt, AsyncWriteExt};
use tokio::net::{TcpStream, UdpSocket};
use std::pin::Pin;

use ironrdp::connector::{custom_err, general_err, ConnectorResult};
use ironrdp::connector::{custom_err, ConnectorResult};
use ironrdp_tokio::AsyncNetworkClient;
use reqwest::Client;
use sspi::{Error, ErrorKind};
use tokio::io::{AsyncReadExt, AsyncWriteExt};
use tokio::net::{TcpStream, UdpSocket};
use url::Url;

use ironrdp_tokio::AsyncNetworkClient;
pub(crate) struct ReqwestNetworkClient {
client: Option<Client>,
}

impl AsyncNetworkClient for ReqwestNetworkClient {
fn send<'a>(
&'a mut self,
Expand All @@ -38,22 +40,23 @@ impl ReqwestNetworkClient {
impl ReqwestNetworkClient {
async fn send_tcp(&self, url: &Url, data: &[u8]) -> ConnectorResult<Vec<u8>> {
let addr = format!("{}:{}", url.host_str().unwrap_or_default(), url.port().unwrap_or(88));

let mut stream = TcpStream::connect(addr)
.await
.map_err(|e| Error::new(ErrorKind::NoAuthenticatingAuthority, format!("{:?}", e)))
.map_err(|e| custom_err!("sending KDC request over TCP ", e))?;
.map_err(|e| custom_err!("failed to send KDC request over TCP", e))?;

stream
.write(data)
.await
.map_err(|e| Error::new(ErrorKind::NoAuthenticatingAuthority, format!("{:?}", e)))
.map_err(|e| custom_err!("Sending KDC request over TCP ", e))?;
.map_err(|e| custom_err!("failed to send KDC request over TCP", e))?;

let len = stream
.read_u32()
.await
.map_err(|e| Error::new(ErrorKind::NoAuthenticatingAuthority, format!("{:?}", e)))
.map_err(|e| custom_err!("Sending KDC request over TCP ", e))?;
.map_err(|e| custom_err!("failed to send KDC request over TCP", e))?;

let mut buf = vec![0; len as usize + 4];
buf[0..4].copy_from_slice(&(len.to_be_bytes()));
Expand All @@ -62,29 +65,30 @@ impl ReqwestNetworkClient {
.read_exact(&mut buf[4..])
.await
.map_err(|e| Error::new(ErrorKind::NoAuthenticatingAuthority, format!("{:?}", e)))
.map_err(|e| custom_err!("Sending KDC request over TCP ", e))?;
.map_err(|e| custom_err!("failed to send KDC request over TCP", e))?;

Ok(buf)
}

async fn send_udp(&self, url: &Url, data: &[u8]) -> ConnectorResult<Vec<u8>> {
let udp_socket = UdpSocket::bind((IpAddr::V4(Ipv4Addr::LOCALHOST), 0))
.await
.map_err(|e| custom_err!("Cannot bind udp socket", e))?;
.map_err(|e| custom_err!("cannot bind UDP socket", e))?;

let addr = format!("{}:{}", url.host_str().unwrap_or_default(), url.port().unwrap_or(88));

udp_socket
.send_to(data, addr)
.await
.map_err(|e| custom_err!("Error sending udp request", e))?;
.map_err(|e| custom_err!("failed to send UDP request", e))?;

// 48 000 bytes: default maximum token len in Windows
let mut buf = vec![0; 0xbb80];

let n = udp_socket
.recv(&mut buf)
.await
.map_err(|e| custom_err!("Error receiving UDP request", e))?;
.map_err(|e| custom_err!("failed to receive UDP request", e))?;

let mut reply_buf = Vec::with_capacity(n + 4);
reply_buf.extend_from_slice(&(n as u32).to_be_bytes());
Expand All @@ -94,21 +98,17 @@ impl ReqwestNetworkClient {
}

async fn send_http(&mut self, url: &Url, data: &[u8]) -> ConnectorResult<Vec<u8>> {
if self.client.is_none() {
self.client = Some(Client::new()); // dont drop the cllient, keep-alive
}
let result_bytes = self
.client
.as_ref()
.ok_or_else(|| general_err!("Missing HTTP client, should never happen"))?
let client = self.client.get_or_insert_with(Client::new);

let result_bytes = client
.post(url.clone())
.body(data.to_vec())
.send()
.await
.map_err(|e| custom_err!("Sending KDC request over proxy", e))?
.map_err(|e| custom_err!("failed to send KDC request over proxy", e))?
.bytes()
.await
.map_err(|e| custom_err!("Receving KDC response", e))?
.map_err(|e| custom_err!("failed to receive KDC response", e))?
.to_vec();

Ok(result_bytes)
Expand Down
2 changes: 1 addition & 1 deletion crates/ironrdp-client/src/rdp.rs
Original file line number Diff line number Diff line change
Expand Up @@ -135,10 +135,10 @@ async fn connect(
let connection_result = ironrdp_tokio::connect_finalize(
upgraded,
&mut upgraded_framed,
connector,
(&config.destination).into(),
server_public_key,
Some(&mut network_client),
connector,
None,
)
.await?;
Expand Down
Loading

0 comments on commit 8fc213e

Please sign in to comment.