Skip to content

Commit

Permalink
use thiserror instead of unwarp()
Browse files Browse the repository at this point in the history
Signed-off-by: 闹钟大魔王 <[email protected]>
  • Loading branch information
anti-entropy123 committed Sep 29, 2023
1 parent 3d0ff10 commit d816cd7
Show file tree
Hide file tree
Showing 4 changed files with 104 additions and 36 deletions.
58 changes: 48 additions & 10 deletions crates/libcontainer/src/channel.rs
Original file line number Diff line number Diff line change
Expand Up @@ -18,8 +18,10 @@ pub enum ChannelError {
Serde(#[from] serde_json::Error),
#[error("channel connection broken")]
BrokenChannel,
#[error("Unable to be closed")]
#[error("unable to be closed")]
Unclosed,
#[error("channel has been closed")]
ClosedChannel,
}
#[derive(Clone)]
pub struct Receiver<T> {
Expand Down Expand Up @@ -47,8 +49,13 @@ where
} else {
vec![]
};

let sender = match self.sender.as_ref() {
Some(sender) => sender,
None => Err(ChannelError::ClosedChannel)?,
};
socket::sendmsg::<UnixAddr>(
self.sender.as_ref().unwrap().as_raw_fd(),
sender.as_raw_fd(),
iov,
&cmsgs,
socket::MsgFlags::empty(),
Expand Down Expand Up @@ -91,8 +98,12 @@ where
}

pub fn close(&mut self) -> Result<(), ChannelError> {
let sender = match self.sender.as_ref() {
Some(sender) => sender,
None => Err(ChannelError::ClosedChannel)?,
};
// must ensure that the fd is closed immediately.
let count = Arc::strong_count(self.sender.as_ref().unwrap());
let count = Arc::strong_count(sender);
if count != 1 {
tracing::trace!(?count, "incorrect reference count value");
return Err(ChannelError::Unclosed)?;
Expand All @@ -110,10 +121,16 @@ where
/// `clone()` can cause a leak of references residing on the stack in the
/// childprocess. This function allows for manual adjustment of the counter
/// to correct such situations.
pub unsafe fn decrement_count(&self) {
let rc = Arc::into_raw(Arc::clone(self.sender.as_ref().unwrap()));
pub unsafe fn decrement_count(&self) -> Result<(), ChannelError> {
let sender = match self.sender.as_ref() {
Some(sender) => sender,
None => Err(ChannelError::ClosedChannel)?,
};
let rc = Arc::into_raw(Arc::clone(sender));
Arc::decrement_strong_count(rc);
Arc::from_raw(rc);

Ok(())
}
}

Expand All @@ -129,8 +146,13 @@ where
std::mem::size_of::<u64>(),
)
})];

let receiver = match self.receiver.as_ref() {
Some(receiver) => receiver,
None => Err(ChannelError::ClosedChannel)?,
};
let _ = socket::recvmsg::<UnixAddr>(
self.receiver.as_ref().unwrap().as_raw_fd(),
receiver.as_raw_fd(),
&mut iov,
None,
socket::MsgFlags::MSG_PEEK,
Expand All @@ -149,8 +171,13 @@ where
F: Default + AsMut<[RawFd]>,
{
let mut cmsgspace = nix::cmsg_space!(F);

let receiver = match self.receiver.as_ref() {
Some(receiver) => receiver,
None => Err(ChannelError::ClosedChannel)?,
};
let msg = socket::recvmsg::<UnixAddr>(
self.receiver.as_ref().unwrap().as_raw_fd(),
receiver.as_raw_fd(),
iov,
Some(&mut cmsgspace),
socket::MsgFlags::MSG_CMSG_CLOEXEC,
Expand Down Expand Up @@ -223,8 +250,12 @@ where
}

pub fn close(&mut self) -> Result<(), ChannelError> {
let receiver = match self.receiver.as_ref() {
Some(receiver) => receiver,
None => Err(ChannelError::ClosedChannel)?,
};
// must ensure that the fd is closed immediately.
let count = Arc::strong_count(self.receiver.as_ref().unwrap());
let count = Arc::strong_count(receiver);
if count != 1 {
tracing::trace!(?count, "incorrect reference count value");
return Err(ChannelError::Unclosed)?;
Expand All @@ -238,10 +269,17 @@ where
///
/// # Safety
/// The reason for `unsafe` is same as `Sender::decrement_count()`.
pub unsafe fn decrement_count(&self) {
let rc = Arc::into_raw(Arc::clone(self.receiver.as_ref().unwrap()));
pub unsafe fn decrement_count(&self) -> Result<(), ChannelError> {
let receiver = match self.receiver.as_ref() {
Some(receiver) => receiver,
None => Err(ChannelError::ClosedChannel)?,
};

let rc = Arc::into_raw(Arc::clone(receiver));
Arc::decrement_strong_count(rc);
Arc::from_raw(rc);

Ok(())
}
}

Expand Down
30 changes: 20 additions & 10 deletions crates/libcontainer/src/process/channel.rs
Original file line number Diff line number Diff line change
Expand Up @@ -91,8 +91,10 @@ impl MainSender {
///
/// # Safety
/// The reason for `unsafe` is same as `Sender::decrement_count()`.
pub unsafe fn decrement_count(&self) {
self.sender.decrement_count()
pub unsafe fn decrement_count(&self) -> Result<(), ChannelError> {
self.sender
.decrement_count()
.map_err(ChannelError::BaseChannelError)
}
}

Expand Down Expand Up @@ -226,8 +228,10 @@ impl IntermediateSender {
///
/// # Safety
/// The reason for `unsafe` is same as `Sender::decrement_count()`.
pub unsafe fn decrement_count(&self) {
self.sender.decrement_count()
pub unsafe fn decrement_count(&self) -> Result<(), ChannelError> {
self.sender
.decrement_count()
.map_err(ChannelError::BaseChannelError)
}
}

Expand Down Expand Up @@ -266,8 +270,10 @@ impl IntermediateReceiver {
///
/// # Safety
/// The reason for `unsafe` is same as `Receiver::decrement_count()`.
pub unsafe fn decrement_count(&self) {
self.receiver.decrement_count()
pub unsafe fn decrement_count(&self) -> Result<(), ChannelError> {
self.receiver
.decrement_count()
.map_err(ChannelError::BaseChannelError)
}
}

Expand Down Expand Up @@ -298,8 +304,10 @@ impl InitSender {
///
/// # Safety
/// The reason for `unsafe` is same as `Sender::decrement_count()`.
pub unsafe fn decrement_count(&self) {
self.sender.decrement_count()
pub unsafe fn decrement_count(&self) -> Result<(), ChannelError> {
self.sender
.decrement_count()
.map_err(ChannelError::BaseChannelError)
}
}

Expand Down Expand Up @@ -337,8 +345,10 @@ impl InitReceiver {
///
/// # Safety
/// The reason for `unsafe` is same as `Receiver::decrement_count()`.
pub unsafe fn decrement_count(&self) {
self.receiver.decrement_count()
pub unsafe fn decrement_count(&self) -> Result<(), ChannelError> {
self.receiver
.decrement_count()
.map_err(ChannelError::BaseChannelError)
}
}

Expand Down
24 changes: 17 additions & 7 deletions crates/libcontainer/src/process/container_intermediate_process.rs
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
use crate::error::MissingSpecError;
use crate::process::channel::ChannelError;
use crate::{namespaces::Namespaces, process::channel, process::fork};
use libcgroups::common::CgroupManager;
use nix::unistd::{close, write};
Expand Down Expand Up @@ -124,13 +125,22 @@ pub fn container_intermediate_process(

// Must clean up reference counts that are located on the stack.
// Please refer to the explanation within `container_main_process()`.
unsafe {
args.decrement_count();
init_sender.decrement_count();
inter_sender.decrement_count();
main_sender.decrement_count();
init_receiver.decrement_count();
}
match (|| {
unsafe {
args.decrement_count();
init_sender.decrement_count()?;
inter_sender.decrement_count()?;
main_sender.decrement_count()?;
init_receiver.decrement_count()?;
}
Ok::<(), ChannelError>(())
})() {
Ok(_) => (),
Err(err) => {
tracing::error!(?err, "channel status error");
return -1;
}
};

// We are inside the forked process here. The first thing we have to do
// is to close any unused senders, since fork will make a dup for all
Expand Down
28 changes: 19 additions & 9 deletions crates/libcontainer/src/process/container_main_process.rs
Original file line number Diff line number Diff line change
@@ -1,7 +1,8 @@
use crate::{
process::{
args::ContainerArgs,
channel, container_intermediate_process,
channel::{self, ChannelError},
container_intermediate_process,
fork::{self, CloneCb},
intel_rdt::setup_intel_rdt,
},
Expand Down Expand Up @@ -64,14 +65,23 @@ pub fn container_main_process(container_args: &ContainerArgs) -> Result<(Pid, bo
// in turn, can lead to delayed closure of file descriptors. The
// following code is equivalent to executing a drop on those reference
// counters.
unsafe {
container_args.decrement_count();
main_sender.decrement_count();
inter_chan.0.decrement_count();
inter_chan.1.decrement_count();
init_chan.0.decrement_count();
init_chan.1.decrement_count();
}
match (|| {
unsafe {
container_args.decrement_count();
main_sender.decrement_count()?;
inter_chan.0.decrement_count()?;
inter_chan.1.decrement_count()?;
init_chan.0.decrement_count()?;
init_chan.1.decrement_count()?;
}
Ok::<(), ChannelError>(())
})() {
Ok(_) => (),
Err(err) => {
tracing::error!(?err, "channel status error");
return -1;
}
};

match container_intermediate_process::container_intermediate_process(
&container_args,
Expand Down

0 comments on commit d816cd7

Please sign in to comment.