Skip to content

Commit

Permalink
feat(net): return the actual number of bytes read in `TracerSocket::r…
Browse files Browse the repository at this point in the history
…ead` and `TracerSocket::recv_from` calls on Windows
  • Loading branch information
fujiapple852 committed Oct 15, 2023
1 parent 744279f commit cd6348b
Showing 1 changed file with 12 additions and 11 deletions.
23 changes: 12 additions & 11 deletions src/tracing/net/platform/windows.rs
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@ use crate::tracing::net::platform::windows::adapter::Adapters;
use crate::tracing::net::socket::Socket;
use itertools::Itertools;
use socket2::{Domain, Protocol, SockAddr, Type};
use std::cell::RefCell;
use std::ffi::c_void;
use std::io::{Error, ErrorKind, Result};
use std::mem::{size_of, zeroed};
Expand Down Expand Up @@ -110,6 +111,7 @@ pub struct SocketImpl {
ol: Box<OVERLAPPED>,
buf: Vec<u8>,
from: Box<SOCKADDR_STORAGE>,
bytes_read: Box<RefCell<u32>>,
}

#[allow(clippy::cast_possible_wrap)]
Expand All @@ -129,11 +131,13 @@ impl SocketImpl {
let from = Box::new(Self::new_sockaddr_storage());
let ol = Box::new(Self::new_overlapped());
let buf = vec![0u8; MAX_PACKET_SIZE];
let bytes_read = Box::new(RefCell::new(0));
Ok(Self {
inner,
ol,
buf,
from,
bytes_read,
})
}

Expand Down Expand Up @@ -247,22 +251,23 @@ impl SocketImpl {
}

#[instrument(skip(self))]
fn get_overlapped_result(&self) -> IoResult<(u32, u32)> {
let mut bytes = 0;
fn get_overlapped_result(&self) -> IoResult<()> {
let mut bytes_read = 0;
let mut flags = 0;
let ol = *self.ol;
syscall!(
WSAGetOverlappedResult(
self.inner.as_raw_socket() as _,
addr_of!(ol),
&mut bytes,
&mut bytes_read,
0,
&mut flags,
),
|res| { res == 0 }
)
.map_err(|err| IoError::Other(err, IoOperation::WSAGetOverlappedResult))?;
Ok((bytes, flags))
*self.bytes_read.borrow_mut() = bytes_read;
Ok(())
}

#[allow(unsafe_code)]
Expand Down Expand Up @@ -499,17 +504,13 @@ impl Socket for SocketImpl {
Ok((len, Some(SocketAddr::new(addr, 0))))
}

// TODO
// we always copy and claim to have returned MAX_PACKET_SIZE bytes, regardless of how many bytes
// we actually received. The callers currently ignore this and just try to parse a packet
// from the buffer which isn't ideal. Really we should record the actual number of bytes
// read in the `get_overlapped_result` call and return that here.
#[instrument(skip(self, buf), ret)]
fn read(&mut self, buf: &mut [u8]) -> IoResult<usize> {
buf.copy_from_slice(self.buf.as_slice());
tracing::debug!(buf = format!("{:02x?}", buf[..MAX_PACKET_SIZE].iter().format(" ")));
let bytes_read = *self.bytes_read.borrow() as usize;
tracing::debug!(buf = format!("{:02x?}", buf[..bytes_read].iter().format(" ")));
self.post_recv_from()?;
Ok(MAX_PACKET_SIZE)
Ok(bytes_read)
}

#[instrument(skip(self))]
Expand Down

0 comments on commit cd6348b

Please sign in to comment.