diff --git a/Cargo.toml b/Cargo.toml index 7fcce67..aa73af7 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -19,12 +19,20 @@ async-lock = "3.0.0" cfg-if = "1.0" event-listener = "4.0.0" futures-lite = "2.0.0" +tracing = { version = "0.1.40", default-features = false } [target.'cfg(unix)'.dependencies] async-io = "2.1.0" -async-signal = "0.2.3" rustix = { version = "0.38", default-features = false, features = ["std", "fs"] } +[target.'cfg(any(target_os = "linux", target_os = "android"))'.dependencies] +async-channel = "2.0.0" +async-task = "4.7.0" + +[target.'cfg(all(unix, not(any(target_os = "linux", target_os = "android"))))'.dependencies] +async-signal = "0.2.3" +rustix = { version = "0.38", default-features = false, features = ["std", "fs", "process"] } + [target.'cfg(windows)'.dependencies] async-channel = "2.0.0" blocking = "1.0.0" diff --git a/src/lib.rs b/src/lib.rs index 4988834..86ada29 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -80,8 +80,18 @@ use futures_lite::{future, io, prelude::*}; #[doc(no_inline)] pub use std::process::{ExitStatus, Output, Stdio}; -#[path = "reaper/signal.rs"] -mod reaper; +cfg_if::cfg_if! { + if #[cfg(any( + target_os = "linux", + target_os = "android" + ))] { + #[path = "reaper/wait.rs"] + mod reaper; + } else { + #[path = "reaper/signal.rs"] + mod reaper; + } +} #[cfg(unix)] pub mod unix; @@ -373,9 +383,7 @@ impl Child { self.stdin.take(); let child = self.child.clone(); - async move { - Reaper::get().sys.status(&child).await - } + async move { Reaper::get().sys.status(&child).await } } /// Drops the stdin handle and collects the output of the process. diff --git a/src/reaper/signal.rs b/src/reaper/signal.rs index e5d7c6c..fd31d3e 100644 --- a/src/reaper/signal.rs +++ b/src/reaper/signal.rs @@ -85,7 +85,10 @@ impl Reaper { } /// Wait for an event to occur for a child process. - pub(crate) async fn status(&'static self, child: &Mutex) -> io::Result { + pub(crate) async fn status( + &'static self, + child: &Mutex, + ) -> io::Result { let listener = EventListener::new(); futures_lite::pin!(listener); diff --git a/src/reaper/wait.rs b/src/reaper/wait.rs new file mode 100644 index 0000000..0a064c1 --- /dev/null +++ b/src/reaper/wait.rs @@ -0,0 +1,190 @@ +//! A version of the reaper that waits on some polling primitive. +//! +//! This uses: +//! +//! - pidfd on Linux/Android + +use async_channel::{Receiver, Sender}; +use async_task::Runnable; +use futures_lite::future; + +use std::io; +use std::sync::atomic::{AtomicU64, Ordering}; +use std::sync::Mutex; +use std::task::{Context, Poll}; + +pub(crate) type Lock = (); + +/// The zombie process reaper. +pub(crate) struct Reaper { + /// The channel for sending new runnables. + sender: Sender, + + /// The channel for receiving new runnables. + recv: Receiver, + + /// Number of zombie processes. + zombies: AtomicU64, +} + +impl Reaper { + /// Create a new reaper. + pub(crate) fn new() -> Self { + let (sender, recv) = async_channel::unbounded(); + Self { + sender, + recv, + zombies: AtomicU64::new(0), + } + } + + /// "Lock" the driver thread. + /// + /// Since multiple threads can drive the reactor at once, there is no need to + /// actually lock anything. So this function only exists for symmetry. + pub(crate) async fn lock(&self) {} + + /// Reap zombie processes forever. + pub(crate) async fn reap(&'static self, _: ()) -> ! { + loop { + // Fetch the next task. + let task = match self.recv.recv().await { + Ok(task) => task, + Err(_) => panic!("sender should never be closed"), + }; + + // Poll the task. + task.run(); + } + } + + /// Register a child into this reaper. + pub(crate) fn register(&'static self, child: std::process::Child) -> io::Result { + Ok(ChildGuard { + inner: Some(WaitableChild::new(child)?), + }) + } + + /// Wait for a child to complete. + pub(crate) async fn status( + &'static self, + child: &Mutex, + ) -> io::Result { + future::poll_fn(|cx| { + // Lock the child and poll it once. + child + .lock() + .unwrap() + .inner + .inner + .as_mut() + .unwrap() + .poll_wait(cx) + }) + .await + } + + /// Do we have any registered zombie processes? + pub(crate) fn has_zombies(&'static self) -> bool { + self.zombies.load(Ordering::SeqCst) > 0 + } +} + +/// The wrapper around the child. +pub(crate) struct ChildGuard { + inner: Option, +} + +impl ChildGuard { + /// Get a mutable reference to the inner child. + pub(crate) fn get_mut(&mut self) -> &mut std::process::Child { + self.inner.as_mut().unwrap().get_mut() + } + + /// Begin the reaping process for this child. + pub(crate) fn reap(&mut self, reaper: &'static Reaper) { + struct CallOnDrop(F); + + impl Drop for CallOnDrop { + fn drop(&mut self) { + (self.0)(); + } + } + + // Create a future for polling this child. + let future = { + let mut inner = self.inner.take().unwrap(); + async move { + // Increment the zombie count. + reaper.zombies.fetch_add(1, Ordering::Relaxed); + + // Decrement the zombie count once we are done. + let _guard = CallOnDrop(|| { + reaper.zombies.fetch_sub(1, Ordering::SeqCst); + }); + + // Wait on this child forever. + let result = future::poll_fn(|cx| inner.poll_wait(cx)).await; + if let Err(e) = result { + tracing::error!("error while polling zombie process: {}", e); + } + } + }; + + // Create a future for scheduling this future. + let schedule = move |runnable| { + reaper.sender.try_send(runnable).ok(); + }; + + // Spawn the task and run it forever. + let (runnable, task) = async_task::spawn(future, schedule); + task.detach(); + runnable.schedule(); + } +} + +cfg_if::cfg_if! { + if #[cfg(any( + target_os = "linux", + target_os = "android" + ))] { + use async_io::Async; + use rustix::process; + use std::os::unix::io::OwnedFd; + + /// Waitable version of `std::process::Child` + struct WaitableChild { + child: std::process::Child, + handle: Async, + } + + impl WaitableChild { + fn new(child: std::process::Child) -> io::Result { + let pidfd = process::pidfd_open( + process::Pid::from_child(&child), + process::PidfdFlags::empty() + )?; + + Ok(Self { + child, + handle: Async::new(pidfd)? + }) + } + + fn get_mut(&mut self) -> &mut std::process::Child { + &mut self.child + } + + fn poll_wait(&mut self, cx: &mut Context<'_>) -> Poll> { + loop { + if let Some(status) = self.child.try_wait()? { + return Poll::Ready(Ok(status)); + } + + // Wait for us to become readable. + futures_lite::ready!(self.handle.poll_readable(cx))?; + } + } + } + } +}