From f4a88c190157d1377ca6f701d5ea8cd488a2cfba Mon Sep 17 00:00:00 2001 From: Jorge Prendes Date: Fri, 29 Sep 2023 12:54:30 +0100 Subject: [PATCH] Remove complexity from wait API Signed-off-by: Jorge Prendes --- .../src/sandbox/instance.rs | 91 +--- .../containerd-shim-wasm/src/sandbox/mod.rs | 1 + .../containerd-shim-wasm/src/sandbox/shim.rs | 411 ++++++------------ .../containerd-shim-wasm/src/sandbox/sync.rs | 224 ++++++++++ .../src/sys/unix/container/instance.rs | 38 +- .../src/sys/windows/container/instance.rs | 15 +- crates/containerd-shim-wasm/src/testing.rs | 15 +- 7 files changed, 408 insertions(+), 387 deletions(-) create mode 100644 crates/containerd-shim-wasm/src/sandbox/sync.rs diff --git a/crates/containerd-shim-wasm/src/sandbox/instance.rs b/crates/containerd-shim-wasm/src/sandbox/instance.rs index d702ceb6a..b5c853768 100644 --- a/crates/containerd-shim-wasm/src/sandbox/instance.rs +++ b/crates/containerd-shim-wasm/src/sandbox/instance.rs @@ -1,16 +1,13 @@ //! Abstractions for running/managing a wasm/wasi instance. -use std::sync::mpsc::Sender; -use std::sync::{Arc, Condvar, Mutex}; -use std::thread; +use std::time::Duration; use chrono::{DateTime, Utc}; use super::error::Error; +use super::sync::WaitableCell; use crate::sys::signals::*; -pub type ExitCode = Arc<(Mutex)>>, Condvar)>; - /// Generic options builder for creating a wasm instance. /// This is passed to the `Instance::new` method. #[derive(Clone)] @@ -107,7 +104,9 @@ impl InstanceConfig { /// Represents a WASI module(s). /// Instance is a trait that gets implemented by consumers of this library. -pub trait Instance { +/// This trait requires that any type implementing it is `'static`, similar to `std::any::Any`. +/// This means that the type cannot contain a non-`'static` reference. +pub trait Instance: 'static { /// The WASI engine type type Engine: Send + Sync + Clone; @@ -128,61 +127,30 @@ pub trait Instance { /// This is called after the instance has exited. fn delete(&self) -> Result<(), Error>; - /// Set up waiting for the instance to exit - /// The Wait struct is used to send the exit code and time back to the - /// caller. The recipient is expected to call function - /// set_up_exit_code_wait() implemented by Wait to set up exit code - /// processing. Note that the "wait" function doesn't block, but - /// it sets up the waiting channel. - fn wait(&self, waiter: &Wait) -> Result<(), Error>; -} - -/// This is used for waiting for the container process to exit and deliver the exit code to the caller. -/// Since the shim needs to provide the caller the process exit code, this struct wraps the required -/// thread setup to make the shims simpler. -pub struct Wait { - tx: Sender<(u32, DateTime)>, -} - -impl Wait { - /// Create a new Wait struct with the provided sending endpoint of a channel. - pub fn new(sender: Sender<(u32, DateTime)>) -> Self { - Wait { tx: sender } + /// Waits for the instance to finish and retunrs its exit code + /// This is a blocking call. + fn wait(&self) -> (u32, DateTime) { + self.wait_timeout(None).unwrap() } - /// This is called by the shim to create the thread to wait for the exit - /// code. When the child process exits, the shim will use the ExitCode - /// to signal the exit status to the caller. This function returns so that - /// the wait() function in the shim implementation API would not block. - pub fn set_up_exit_code_wait(&self, exit_code: ExitCode) -> Result<(), Error> { - let sender = self.tx.clone(); - let code = Arc::clone(&exit_code); - thread::spawn(move || { - let (lock, cvar) = &*code; - let mut exit = lock.lock().unwrap(); - while (*exit).is_none() { - exit = cvar.wait(exit).unwrap(); - } - let ec = (*exit).unwrap(); - sender.send(ec).unwrap(); - }); - - Ok(()) - } + /// Waits for the instance to finish and retunrs its exit code + /// Returns None if the timeout is reached before the instance has finished. + /// This is a blocking call. + fn wait_timeout(&self, t: impl Into>) -> Option<(u32, DateTime)>; } /// This is used for the "pause" container with cri and is a no-op instance implementation. pub struct Nop { /// Since we are faking the container, we need to keep track of the "exit" code/time /// We'll just mark it as exited when kill is called. - exit_code: ExitCode, + exit_code: WaitableCell<(u32, DateTime)>, } impl Instance for Nop { type Engine = (); fn new(_id: String, _cfg: Option<&InstanceConfig>) -> Result { Ok(Nop { - exit_code: Arc::new((Mutex::new(None), Condvar::new())), + exit_code: WaitableCell::new(), }) } fn start(&self) -> Result { @@ -197,25 +165,20 @@ impl Instance for Nop { } }; - let exit_code = self.exit_code.clone(); - let (lock, cvar) = &*exit_code; - let mut lock = lock.lock().unwrap(); - *lock = Some((code, Utc::now())); - cvar.notify_all(); + let _ = self.exit_code.set((code, Utc::now())); Ok(()) } fn delete(&self) -> Result<(), Error> { Ok(()) } - fn wait(&self, waiter: &Wait) -> Result<(), Error> { - waiter.set_up_exit_code_wait(self.exit_code.clone()) + fn wait_timeout(&self, t: impl Into>) -> Option<(u32, DateTime)> { + self.exit_code.wait_timeout(t).copied() } } #[cfg(test)] mod noptests { - use std::sync::mpsc::channel; use std::time::Duration; use super::*; @@ -223,12 +186,10 @@ mod noptests { #[test] fn test_nop_kill_sigkill() -> Result<(), Error> { let nop = Nop::new("".to_string(), None)?; - let (tx, rx) = channel(); - let waiter = Wait::new(tx); - nop.wait(&waiter).unwrap(); nop.kill(SIGKILL as u32)?; - let ec = rx.recv_timeout(Duration::from_secs(3)).unwrap(); + + let ec = nop.wait_timeout(Duration::from_secs(3)).unwrap(); assert_eq!(ec.0, 137); Ok(()) } @@ -236,12 +197,10 @@ mod noptests { #[test] fn test_nop_kill_sigterm() -> Result<(), Error> { let nop = Nop::new("".to_string(), None)?; - let (tx, rx) = channel(); - let waiter = Wait::new(tx); - nop.wait(&waiter).unwrap(); nop.kill(SIGTERM as u32)?; - let ec = rx.recv_timeout(Duration::from_secs(3)).unwrap(); + + let ec = nop.wait_timeout(Duration::from_secs(3)).unwrap(); assert_eq!(ec.0, 0); Ok(()) } @@ -249,12 +208,10 @@ mod noptests { #[test] fn test_nop_kill_sigint() -> Result<(), Error> { let nop = Nop::new("".to_string(), None)?; - let (tx, rx) = channel(); - let waiter = Wait::new(tx); - nop.wait(&waiter).unwrap(); nop.kill(SIGINT as u32)?; - let ec = rx.recv_timeout(Duration::from_secs(3)).unwrap(); + + let ec = nop.wait_timeout(Duration::from_secs(3)).unwrap(); assert_eq!(ec.0, 0); Ok(()) } diff --git a/crates/containerd-shim-wasm/src/sandbox/mod.rs b/crates/containerd-shim-wasm/src/sandbox/mod.rs index e7bd2d1e0..5099437ef 100644 --- a/crates/containerd-shim-wasm/src/sandbox/mod.rs +++ b/crates/containerd-shim-wasm/src/sandbox/mod.rs @@ -9,6 +9,7 @@ pub mod instance_utils; pub mod manager; pub mod shim; pub mod stdio; +pub mod sync; pub use error::{Error, Result}; pub use instance::{Instance, InstanceConfig}; diff --git a/crates/containerd-shim-wasm/src/sandbox/shim.rs b/crates/containerd-shim-wasm/src/sandbox/shim.rs index be92ebb44..1637ed845 100644 --- a/crates/containerd-shim-wasm/src/sandbox/shim.rs +++ b/crates/containerd-shim-wasm/src/sandbox/shim.rs @@ -10,8 +10,9 @@ use std::ops::Not; use std::os::unix::fs::DirBuilderExt; use std::path::Path; use std::sync::mpsc::{channel, Receiver, Sender}; -use std::sync::{Arc, Condvar, Mutex, RwLock}; +use std::sync::{Arc, Mutex, RwLock}; use std::thread; +use std::time::Duration; use anyhow::Context as AnyhowContext; use chrono::{DateTime, Utc}; @@ -35,286 +36,175 @@ use shim::api::{StatsRequest, StatsResponse}; use shim::Flags; use ttrpc::context::Context; -use super::instance::{Instance, InstanceConfig, Nop, Wait}; +use super::instance::{Instance, InstanceConfig, Nop}; use super::{oci, Error, SandboxService}; use crate::sys::metrics::get_metrics; -type InstanceDataStatus = (Mutex)>>, Condvar); - -struct InstanceData { - instance: Option, - base: Option, - cfg: InstanceConfig, - pid: RwLock>, - status: Arc, - state: Arc>, +enum InstanceOption { + Instance(I), + Nop(Nop), } -type Result = std::result::Result; +impl Instance for InstanceOption { + type Engine = (); -impl InstanceData { - fn start(&self) -> Result { - let mut s = self.state.write().unwrap(); - let new_state = s.start()?; - *s = new_state.into(); - if self.instance.is_some() { - return self - .instance - .as_ref() - .unwrap() - .start() - .map(|pid| { - s.started().map(|new_state| { - *s = new_state.into(); - new_state - })?; - Ok(pid) - }) - .map_err(|err| { - let _ = s - .stop() - .map(|new_state| { - *s = new_state.into(); - new_state - }) - .map_err(|e| warn!("could not set exited state after failed start: {}", e)); - err - })?; - }; + fn new(_id: String, _cfg: Option<&InstanceConfig>) -> Result { + // this is never called + unimplemented!(); + } - return self.base.as_ref().unwrap().start().and_then(|pid| { - s.started() - .map(|new_state| { - *s = new_state.into(); - new_state - }) - .map_err(|err| { - let new_state = s - .stop() - .map_err(|e| warn!("could not set exited state after failed start: {}", e)); - if let Ok(ns) = new_state { - *s = ns.into(); - } - err - })?; - Ok(pid) - }); + fn start(&self) -> Result { + match self { + Self::Instance(i) => i.start(), + Self::Nop(i) => i.start(), + } } fn kill(&self, signal: u32) -> Result<()> { - let s = self.state.read().unwrap(); - s.kill()?; - if self.instance.is_some() { - return self.instance.as_ref().unwrap().kill(signal); + match self { + Self::Instance(i) => i.kill(signal), + Self::Nop(i) => i.kill(signal), } - self.base.as_ref().unwrap().kill(signal) } fn delete(&self) -> Result<()> { - let mut s = self.state.write().unwrap(); - let new_state = s.delete()?; - *s = new_state.into(); - if self.instance.is_some() { - return self.instance.as_ref().unwrap().delete().map_err(|err| { - let _ = s - .stop() - .map(|new_state| { - *s = new_state.into(); - new_state - }) - .map_err(|e| warn!("could not set reset state after failed delete: {}", e)); - err - }); + match self { + Self::Instance(i) => i.delete(), + Self::Nop(i) => i.delete(), } - self.base.as_ref().unwrap().delete().map_err(|err| { - let _ = s - .stop() - .map(|new_state| { - *s = new_state.into(); - new_state - }) - .map_err(|e| { - warn!("could not set exited state after failed delete: {}", e); - }); - err - }) } - fn wait(&self, waiter: &Wait) -> Result<()> { - if self.instance.is_some() { - return self.instance.as_ref().unwrap().wait(waiter); + fn wait_timeout(&self, t: impl Into>) -> Option<(u32, DateTime)> { + match self { + Self::Instance(i) => i.wait_timeout(t), + Self::Nop(i) => i.wait_timeout(t), } - self.base.as_ref().unwrap().wait(waiter) } } -type EventSender = Sender<(String, Box)>; +struct InstanceData { + instance: InstanceOption, + cfg: InstanceConfig, + pid: RwLock>, + state: Arc>, +} -#[derive(Debug, Clone, Copy)] -struct Created {} -#[derive(Debug, Clone, Copy)] -struct Starting {} -#[derive(Debug, Clone, Copy)] -struct Started {} -#[derive(Debug, Clone, Copy)] -struct Exited {} -#[derive(Debug, Clone, Copy)] -struct Deleting {} +type Result = std::result::Result; -#[derive(Debug, Clone, Copy)] -struct TaskState { - s: std::marker::PhantomData, -} +impl InstanceData { + fn start(&self) -> Result { + let mut s = self.state.write().unwrap(); + s.start()?; -#[derive(Debug, Clone, Copy)] -enum TaskStateWrapper { - Created(TaskState), - Starting(TaskState), - Started(TaskState), - Exited(TaskState), - Deleting(TaskState), -} + let res = self.instance.start(); -impl TaskStateWrapper { - fn start(self) -> Result> { - match self { - TaskStateWrapper::Created(s) => Ok(s.into()), - s => Err(Error::FailedPrecondition(format!( - "invalid state transition: ${:?} => Starting", - s, - ))), - } - } + // These state transitions are always `Ok(())` because + // we hold the lock since `s.start()` + let _ = match res { + Ok(_) => s.started(), + Err(_) => s.stop(), + }; - fn kill(self) -> Result<()> { - match self { - TaskStateWrapper::Started(_) => Ok(()), - s => Err(Error::FailedPrecondition(format!( - "cannot kill non-running container, current state: {:?}", - s, - ))), - } + res } - fn delete(self) -> Result> { - match self { - TaskStateWrapper::Created(s) => Ok(s.into()), - TaskStateWrapper::Exited(s) => Ok(s.into()), - s => Err(Error::FailedPrecondition(format!( - "cannot delete non-exted container, current state: {:?}", - s, - ))), - } - } + fn kill(&self, signal: u32) -> Result<()> { + let mut s = self.state.write().unwrap(); + s.kill()?; - fn started(self) -> Result> { - match self { - TaskStateWrapper::Starting(s) => Ok(s.into()), - s => Err(Error::FailedPrecondition(format!( - "invalid state transition: ${:?} => Started", - s, - ))), - } + self.instance.kill(signal) } - fn stop(self) -> Result> { - match self { - TaskStateWrapper::Started(s) => Ok(s.into()), - TaskStateWrapper::Starting(s) => Ok(s.into()), - TaskStateWrapper::Deleting(s) => Ok(s.into()), - s => Err(Error::FailedPrecondition(format!( - "invalid state transition: ${:?} => Exited", - s, - ))), - } - } -} + fn delete(&self) -> Result<()> { + let mut s = self.state.write().unwrap(); + s.delete()?; -impl From> for TaskStateWrapper { - fn from(s: TaskState) -> Self { - TaskStateWrapper::Created(s) - } -} + let res = self.instance.delete(); -impl From> for TaskStateWrapper { - fn from(s: TaskState) -> Self { - TaskStateWrapper::Starting(s) - } -} + if res.is_err() { + // Always `Ok(())` because we hold the lock since `s.delete()` + let _ = s.stop(); + } -impl From> for TaskStateWrapper { - fn from(s: TaskState) -> Self { - TaskStateWrapper::Started(s) + res } -} -impl From> for TaskStateWrapper { - fn from(s: TaskState) -> Self { - TaskStateWrapper::Exited(s) + fn wait(&self) -> (u32, DateTime) { + let res = self.instance.wait(); + let mut s = self.state.write().unwrap(); + *s = TaskState::Exited; + res } -} -impl From> for TaskStateWrapper { - fn from(s: TaskState) -> Self { - TaskStateWrapper::Deleting(s) + fn wait_timeout(&self, t: impl Into>) -> Option<(u32, DateTime)> { + let res = self.instance.wait_timeout(t); + if res.is_some() { + let mut s = self.state.write().unwrap(); + *s = TaskState::Exited; + } + res } } -impl From> for TaskState { - fn from(_val: TaskState) -> Self { - TaskState { - s: std::marker::PhantomData, - } - } +type EventSender = Sender<(String, Box)>; + +#[derive(Debug, Clone, Copy)] +enum TaskState { + Created, + Starting, + Started, + Exited, + Deleting, } -impl From> for TaskState { - fn from(_val: TaskState) -> TaskState { - TaskState { - s: std::marker::PhantomData, - } +impl TaskState { + pub fn start(&mut self) -> Result<()> { + *self = match self { + Self::Created => Ok(Self::Starting), + _ => state_transition_error(*self, Self::Starting), + }?; + Ok(()) } -} -impl From> for TaskState { - fn from(_val: TaskState) -> TaskState { - TaskState { - s: std::marker::PhantomData, - } + pub fn kill(&mut self) -> Result<()> { + *self = match self { + Self::Started => Ok(Self::Started), + _ => state_transition_error(*self, "Killing"), + }?; + Ok(()) } -} -impl From> for TaskState { - fn from(_val: TaskState) -> TaskState { - TaskState { - s: std::marker::PhantomData, - } + pub fn delete(&mut self) -> Result<()> { + *self = match self { + Self::Created | Self::Exited => Ok(Self::Deleting), + _ => state_transition_error(*self, Self::Deleting), + }?; + Ok(()) } -} -impl From> for TaskState { - fn from(_val: TaskState) -> TaskState { - TaskState { - s: std::marker::PhantomData, - } + pub fn started(&mut self) -> Result<()> { + *self = match self { + Self::Starting => Ok(Self::Started), + _ => state_transition_error(*self, Self::Started), + }?; + Ok(()) } -} -impl From> for TaskState { - fn from(_val: TaskState) -> TaskState { - TaskState { - s: std::marker::PhantomData, - } + pub fn stop(&mut self) -> Result<()> { + *self = match self { + Self::Started | Self::Starting => Ok(Self::Exited), + // This is for potential failure cases where we want delete to be able to be retried. + Self::Deleting => Ok(Self::Exited), + _ => state_transition_error(*self, Self::Exited), + }?; + Ok(()) } } -// This is for potential failure cases where we want delete to be able to be retried. -impl From> for TaskState { - fn from(_val: TaskState) -> TaskState { - TaskState { - s: std::marker::PhantomData, - } - } +fn state_transition_error(from: impl std::fmt::Debug, to: impl std::fmt::Debug) -> Result { + Err(Error::FailedPrecondition(format!( + "invalid state transition: {from:?} => {to:?}" + ))) } type LocalInstances = Arc>>>>; @@ -458,8 +348,7 @@ mod localtests { // A little janky since this is internal data, but check that this is seen as a sandbox container let i = local.get_instance("testbase")?; - assert!(i.base.is_some()); - assert!(i.instance.is_none()); + assert!(matches!(i.instance, InstanceOption::Nop(_))); local.task_start(api::StartRequest { id: "testbase".to_string(), @@ -502,8 +391,7 @@ mod localtests { // again, this is janky since it is internal data, but check that this is seen as a "real" container. // this is the inverse of the above test case. let i = local.get_instance("testinstance")?; - assert!(i.base.is_none()); - assert!(i.instance.is_some()); + assert!(matches!(i.instance, InstanceOption::Instance(_))); local.task_start(api::StartRequest { id: "testinstance".to_string(), @@ -745,16 +633,10 @@ impl Local { self.containerd_address.clone(), ); InstanceData { - instance: None, - base: Some(Nop::new(id, None).unwrap()), + instance: InstanceOption::Nop(Nop::new(id, None).unwrap()), cfg, pid: RwLock::new(None), - status: Arc::default(), - state: Arc::new(RwLock::new(TaskStateWrapper::Created( - TaskState:: { - s: std::marker::PhantomData, - }, - ))), + state: Arc::new(RwLock::new(TaskState::Created)), } } @@ -954,16 +836,10 @@ impl Local { self.instances.write().unwrap().insert( req.id().to_string(), Arc::new(InstanceData { - instance: Some(T::new(req.id().to_string(), Some(&builder))?), - base: None, + instance: InstanceOption::Instance(T::new(req.id().to_string(), Some(&builder))?), cfg: builder, pid: RwLock::new(None), - status: Arc::default(), - state: Arc::new(RwLock::new(TaskStateWrapper::Created( - TaskState:: { - s: std::marker::PhantomData, - }, - ))), + state: Arc::new(RwLock::new(TaskState::Created)), }), ); @@ -1010,37 +886,18 @@ impl Local { ..Default::default() }); - let (tx, rx) = channel::<(u32, DateTime)>(); - let waiter = Wait::new(tx); - i.wait(&waiter)?; - - let status = i.status.clone(); - let sender = self.events.clone(); - let id = req.id().to_string(); - let state = i.state.clone(); - let _ = thread::Builder::new() + thread::Builder::new() .name(format!("{}-wait", req.id())) .spawn(move || { - let ec = rx.recv().unwrap(); - - let mut s = state.write().unwrap(); - *s = TaskStateWrapper::Exited(TaskState:: { - s: std::marker::PhantomData, - }); - - let (lock, cvar) = &*status; - let mut status = lock.lock().unwrap(); - *status = Some(ec); - cvar.notify_all(); - drop(status); + let exit_code = i.wait(); let timestamp = new_timestamp().unwrap(); let event = TaskExit { container_id: id.clone(), - exit_status: ec.0, + exit_status: exit_code.0, exited_at: MessageField::some(timestamp), pid, id, @@ -1097,7 +954,7 @@ impl Local { ..Default::default() }; - if let Some(ec) = *i.status.0.lock().unwrap() { + if let Some(ec) = i.wait_timeout(Duration::ZERO) { event.exit_status = ec.0; resp.exit_status = ec.0; @@ -1123,17 +980,7 @@ impl Local { let i = self.get_instance(req.id())?; - let (lock, cvar) = &*i.status.clone(); - let mut status = lock.lock().unwrap(); - while (*status).is_none() { - status = cvar.wait(status).unwrap(); - } - - let (tx, rx) = channel::<(u32, DateTime)>(); - let waiter = Wait::new(tx); - i.wait(&waiter)?; - - let code = rx.recv().unwrap(); + let code = i.wait(); debug!("wait done: {:?}", req); let mut timestamp = Timestamp::new(); @@ -1172,7 +1019,7 @@ impl Local { state.set_pid(pid.unwrap()); - if let Some(c) = *i.status.0.lock().unwrap() { + if let Some(c) = i.wait_timeout(Duration::ZERO) { state.set_status(Status::STOPPED); let ec = c; state.exit_status = ec.0; diff --git a/crates/containerd-shim-wasm/src/sandbox/sync.rs b/crates/containerd-shim-wasm/src/sandbox/sync.rs new file mode 100644 index 000000000..9b8fe4823 --- /dev/null +++ b/crates/containerd-shim-wasm/src/sandbox/sync.rs @@ -0,0 +1,224 @@ +use std::cell::OnceCell; +use std::sync::{Arc, Condvar, Mutex}; +use std::time::Duration; + +/// A cell where we can wait (with timeout) for +/// a value to be set +pub struct WaitableCell { + inner: Arc>, +} + +struct WaitableCellImpl { + // Ideally we would just use a OnceLock, but it doesn't + // have the `wait` and `wait_timeout` methods, so we use + // a Condvar + Mutex pair instead. + // We can't guard the OnceCell **inside** the Mutex as + // that would produce ownership problems with returning + // `&T`. This is because the mutex doesn't know that we + // won't mutate the OnceCell once it's set. + mutex: Mutex<()>, + cvar: Condvar, + cell: OnceCell, +} + +// this is safe because access to cell guarded by the mutex +unsafe impl Send for WaitableCell {} +unsafe impl Sync for WaitableCell {} + +impl Default for WaitableCell { + fn default() -> Self { + Self { + inner: Arc::new(WaitableCellImpl { + mutex: Mutex::new(()), + cvar: Condvar::new(), + cell: OnceCell::new(), + }), + } + } +} + +impl Clone for WaitableCell { + fn clone(&self) -> Self { + let inner = self.inner.clone(); + Self { inner } + } +} + +impl WaitableCell { + /// Creates an empty WaitableCell. + pub fn new() -> Self { + Self::default() + } + + /// Sets a value to the WaitableCell. + /// This method has no effect if the WaitableCell already has a value. + pub fn set(&self, val: impl Into) -> Result<(), T> { + let val = val.into(); + let _guard = self.inner.mutex.lock().unwrap(); + let res = self.inner.cell.set(val); + self.inner.cvar.notify_all(); + res + } + + /// If the `WaitableCell` is empty when this guard is dropped, the cell will be set to result of `f`. + /// ``` + /// let cell = WaitableCell::::new(); + /// { + /// let _guard = cell.set_guard_with(|| 42); + /// } + /// assert_eq!(&42, cell.wait()); + /// ``` + /// + /// The operation is a no-op if the cell conbtains a value before the guard is dropped. + /// ``` + /// let cell = WaitableCell::::new(); + /// { + /// let _guard = cell.set_guard_with(|| 42); + /// let _ = cell.set(24); + /// } + /// assert_eq!(&24, cell.wait()); + /// ``` + /// + /// The function `f` will always be called, regardsless of whether the `WaitableCell` has a value or not. + /// The `WaitableCell` is going to be set even in the case of an unwind. In this case, ff the function `f` + /// panics it will cause an abort, so it's recomended to avoid any panics in `f`. + pub fn set_guard_with>(&self, f: impl FnOnce() -> R) -> impl Drop { + let cell = (*self).clone(); + WaitableCellSetGuard { f: Some(f), cell } + } + + /// Wait for the WaitableCell to be set a value. + pub fn wait(&self) -> &T { + let value = self.wait_timeout(None); + // safe because we waited with timeout `None` + unsafe { value.unwrap_unchecked() } + } + + /// Wait for the WaitableCell to be set a value, with timeout. + /// Retuns None if the timeout is reached with no value. + pub fn wait_timeout(&self, timeout: impl Into>) -> Option<&T> { + let timeout = timeout.into(); + let cvar = &self.inner.cvar; + let guard = self.inner.mutex.lock().unwrap(); + let _guard = match timeout { + None => cvar + .wait_while(guard, |_| self.inner.cell.get().is_none()) + .unwrap(), + Some(Duration::ZERO) => guard, + Some(dur) => cvar + .wait_timeout_while(guard, dur, |_| self.inner.cell.get().is_none()) + .map(|(guard, _)| guard) + .unwrap(), + }; + self.inner.cell.get() + } +} + +// This is the type returned by `WaitableCell::set_guard_with`. +// The public API has no visibility over this type, other than it implements `Drop` +// If the `WaitableCell` `cell`` is empty when this guard is dropped, it will set it's value with the result of `f`. +struct WaitableCellSetGuard, F: FnOnce() -> R> { + f: Option, + cell: WaitableCell, +} + +impl, F: FnOnce() -> R> Drop for WaitableCellSetGuard { + fn drop(&mut self) { + let _ = self.cell.set(self.f.take().unwrap()()); + } +} + +#[cfg(test)] +mod test { + use std::thread::{sleep, spawn}; + use std::time::Duration; + + use super::WaitableCell; + + #[test] + fn basic() { + let cell = WaitableCell::::new(); + cell.set(42).unwrap(); + assert_eq!(&42, cell.wait()); + } + + #[test] + fn basic_timeout_zero() { + let cell = WaitableCell::::new(); + cell.set(42).unwrap(); + assert_eq!(Some(&42), cell.wait_timeout(Duration::ZERO)); + } + + #[test] + fn basic_timeout_1ms() { + let cell = WaitableCell::::new(); + cell.set(42).unwrap(); + assert_eq!(Some(&42), cell.wait_timeout(Duration::from_secs(1))); + } + + #[test] + fn basic_timeout_none() { + let cell = WaitableCell::::new(); + cell.set(42).unwrap(); + assert_eq!(Some(&42), cell.wait_timeout(None)); + } + + #[test] + fn unset_timeout_zero() { + let cell = WaitableCell::::new(); + assert_eq!(None, cell.wait_timeout(Duration::ZERO)); + } + + #[test] + fn unset_timeout_1ms() { + let cell = WaitableCell::::new(); + assert_eq!(None, cell.wait_timeout(Duration::from_millis(1))); + } + + #[test] + fn clone() { + let cell = WaitableCell::::new(); + let cloned = cell.clone(); + let _ = cloned.set(42); + assert_eq!(&42, cell.wait()); + } + + #[test] + fn basic_threaded() { + let cell = WaitableCell::::new(); + { + let cell = cell.clone(); + spawn(move || { + sleep(Duration::from_millis(1)); + let _ = cell.set(42); + }); + } + assert_eq!(&42, cell.wait()); + } + + #[test] + fn basic_double_set() { + let cell = WaitableCell::::new(); + assert_eq!(Ok(()), cell.set(42)); + assert_eq!(Err(24), cell.set(24)); + } + + #[test] + fn guard() { + let cell = WaitableCell::::new(); + { + let _guard = cell.set_guard_with(|| 42); + } + assert_eq!(&42, cell.wait()); + } + + #[test] + fn guard_no_op() { + let cell = WaitableCell::::new(); + { + let _guard = cell.set_guard_with(|| 42); + let _ = cell.set(24); + } + assert_eq!(&24, cell.wait()); + } +} diff --git a/crates/containerd-shim-wasm/src/sys/unix/container/instance.rs b/crates/containerd-shim-wasm/src/sys/unix/container/instance.rs index a7b2b1c49..198cf1d4a 100644 --- a/crates/containerd-shim-wasm/src/sys/unix/container/instance.rs +++ b/crates/containerd-shim-wasm/src/sys/unix/container/instance.rs @@ -1,9 +1,10 @@ use std::marker::PhantomData; use std::path::{Path, PathBuf}; use std::thread; +use std::time::Duration; use anyhow::Context; -use chrono::Utc; +use chrono::{DateTime, Utc}; use libcontainer::container::builder::ContainerBuilder; use libcontainer::container::Container; use libcontainer::signal::Signal; @@ -13,15 +14,15 @@ use nix::sys::wait::{waitid, Id as WaitID, WaitPidFlag, WaitStatus}; use nix::unistd::Pid; use crate::container::Engine; -use crate::sandbox::instance::{ExitCode, Wait}; use crate::sandbox::instance_utils::{determine_rootdir, get_instance_root, instance_exists}; +use crate::sandbox::sync::WaitableCell; use crate::sandbox::{Error as SandboxError, Instance as SandboxInstance, InstanceConfig, Stdio}; use crate::sys::container::executor::Executor; static DEFAULT_CONTAINER_ROOT_DIR: &str = "/run/containerd"; pub struct Instance { - exit_code: ExitCode, + exit_code: WaitableCell<(u32, DateTime)>, rootdir: PathBuf, id: String, _phantom: PhantomData, @@ -48,7 +49,7 @@ impl SandboxInstance for Instance { Ok(Self { id, - exit_code: ExitCode::default(), + exit_code: WaitableCell::new(), rootdir, _phantom: Default::default(), }) @@ -59,6 +60,8 @@ impl SandboxInstance for Instance { /// Nothing internally should be using this ID, but it is returned to containerd where a user may want to use it. fn start(&self) -> Result { log::info!("starting instance: {}", self.id); + // make sure we have an exit code by the time we finish (even if there's a panic) + let guard = self.exit_code.set_guard_with(|| (137, Utc::now())); let container_root = get_instance_root(&self.rootdir, &self.id)?; let mut container = Container::load(container_root)?; @@ -68,7 +71,8 @@ impl SandboxInstance for Instance { let exit_code = self.exit_code.clone(); thread::spawn(move || { - let (lock, cvar) = &*exit_code; + // move the exit code guard into this thread + let _guard = guard; let status = match waitid(WaitID::Pid(Pid::from_raw(pid)), WaitPidFlag::WEXITED) { Ok(WaitStatus::Exited(_, status)) => status, @@ -78,12 +82,12 @@ impl SandboxInstance for Instance { log::info!("no child process"); 0 } - Err(e) => panic!("waitpid failed: {e}"), + Err(e) => { + log::error!("waitpid failed: {e}"); + 137 + } } as u32; - let mut ec = lock.lock().unwrap(); - *ec = Some((status, Utc::now())); - drop(ec); - cvar.notify_all(); + let _ = exit_code.set((status, Utc::now())); }); Ok(pid as u32) @@ -128,14 +132,10 @@ impl SandboxInstance for Instance { Ok(()) } - /// Set up waiting for the instance to exit - /// The Wait struct is used to send the exit code and time back to the - /// caller. The recipient is expected to call function - /// set_up_exit_code_wait() implemented by Wait to set up exit code - /// processing. Note that the "wait" function doesn't block, but - /// it sets up the waiting channel. - fn wait(&self, waiter: &Wait) -> Result<(), SandboxError> { - log::info!("waiting for instance: {}", self.id); - waiter.set_up_exit_code_wait(self.exit_code.clone()) + /// Waits for the instance to finish and retunrs its exit code + /// Returns None if the timeout is reached before the instance has finished. + /// This is a blocking call. + fn wait_timeout(&self, t: impl Into>) -> Option<(u32, DateTime)> { + self.exit_code.wait_timeout(t).copied() } } diff --git a/crates/containerd-shim-wasm/src/sys/windows/container/instance.rs b/crates/containerd-shim-wasm/src/sys/windows/container/instance.rs index 340b91aee..3ae418807 100644 --- a/crates/containerd-shim-wasm/src/sys/windows/container/instance.rs +++ b/crates/containerd-shim-wasm/src/sys/windows/container/instance.rs @@ -1,7 +1,9 @@ use std::marker::PhantomData; +use std::time::Duration; + +use chrono::{DateTime, Utc}; use crate::container::Engine; -use crate::sandbox::instance::Wait; use crate::sandbox::{Error as SandboxError, Instance as SandboxInstance, InstanceConfig}; pub struct Instance(PhantomData); @@ -31,13 +33,10 @@ impl SandboxInstance for Instance { todo!(); } - /// Set up waiting for the instance to exit - /// The Wait struct is used to send the exit code and time back to the - /// caller. The recipient is expected to call function - /// set_up_exit_code_wait() implemented by Wait to set up exit code - /// processing. Note that the "wait" function doesn't block, but - /// it sets up the waiting channel. - fn wait(&self, _waiter: &Wait) -> Result<(), SandboxError> { + /// Waits for the instance to finish and retunrs its exit code + /// Returns None if the timeout is reached before the instance has finished. + /// This is a blocking call. + fn wait_timeout(&self, _t: impl Into>) -> Option<(u32, DateTime)> { todo!(); } } diff --git a/crates/containerd-shim-wasm/src/testing.rs b/crates/containerd-shim-wasm/src/testing.rs index 62a6e77e9..a56e680b5 100644 --- a/crates/containerd-shim-wasm/src/testing.rs +++ b/crates/containerd-shim-wasm/src/testing.rs @@ -4,14 +4,12 @@ use std::collections::HashMap; use std::fs::{create_dir, read_to_string, write, File}; use std::marker::PhantomData; use std::ops::Add; -use std::sync::mpsc::channel; use std::time::Duration; use anyhow::{bail, Result}; pub use containerd_shim_wasm_test_modules as modules; use oci_spec::runtime::{ProcessBuilder, RootBuilder, SpecBuilder}; -use crate::sandbox::instance::Wait; use crate::sandbox::{Instance, InstanceConfig}; use crate::sys::signals::SIGKILL; @@ -165,16 +163,11 @@ where let dir = self.tempdir.path(); log::info!("waiting wasi test"); - - let (tx, rx) = channel(); - let waiter = Wait::new(tx); - self.instance.wait(&waiter).unwrap(); - - let (status, _) = match rx.recv_timeout(timeout) { - Ok(res) => res, - Err(e) => { + let (status, _) = match self.instance.wait_timeout(timeout) { + Some(res) => res, + None => { self.instance.kill(SIGKILL as u32)?; - bail!("error waiting for module to finish: {e}"); + bail!("timeout while waiting for module to finish"); } };