Skip to content

Commit

Permalink
feat(atomic_mutex): Add synchronous try_lock method
Browse files Browse the repository at this point in the history
Add a method on `AtomicMutex` that can return an error if the lock is
already held.

We only want *one* instance to `triton_vm::prove` to run at a time, but
for some cases, it's OK to not produce the proof right now, in case
another task is running the prover. So we plan to add a field to
`GlobalStateLock` which is a Mutex that indicates if the prover is
already running or not. For this, we want both the option of waiting
until the prover is free and of checking if it's free and otherwise
aborting.
  • Loading branch information
Sword-Smith committed Oct 22, 2024
1 parent 7fc160b commit bd66d85
Show file tree
Hide file tree
Showing 2 changed files with 78 additions and 0 deletions.
76 changes: 76 additions & 0 deletions src/locks/tokio/atomic_mutex.rs
Original file line number Diff line number Diff line change
Expand Up @@ -221,6 +221,18 @@ impl<T> AtomicMutex<T> {
AtomicMutexGuard::new(guard, &self.lock_callback_info, LockAcquisition::Read)
}

/// Attempt to return a read lock and return an `AtomicMutextGuard`. Returns
/// an error if the lock is already held, otherwise returns Ok(lock).
pub fn try_lock_guard(&self) -> Result<AtomicMutexGuard<T>, tokio::sync::TryLockError> {
self.try_acquire_try_acquire();
let guard = self.inner.try_lock()?;
Ok(AtomicMutexGuard::new(
guard,
&self.lock_callback_info,
LockAcquisition::TryAcquire,
))
}

/// Acquire write lock and return an `AtomicMutexGuard`
///
/// # Examples
Expand Down Expand Up @@ -350,6 +362,15 @@ impl<T> AtomicMutex<T> {
f(&mut guard).await
}

fn try_acquire_try_acquire(&self) {
if let Some(cb) = self.lock_callback_info.lock_callback_fn {
cb(LockEvent::TryAcquire {
info: self.lock_callback_info.lock_info_owned.as_lock_info(),
acquisition: LockAcquisition::TryAcquire,
});
}
}

fn try_acquire_read_cb(&self) {
if let Some(cb) = self.lock_callback_info.lock_callback_fn {
cb(LockEvent::TryAcquire {
Expand Down Expand Up @@ -449,6 +470,7 @@ impl<T> Atomic<T> for AtomicMutex<T> {
#[cfg(test)]
mod tests {
use futures::future::FutureExt;
use tracing_test::traced_test;

use super::*;

Expand All @@ -463,6 +485,60 @@ mod tests {
atomic_name.lock_mut(|n| new_name = n.to_string()).await;
}

#[traced_test]
#[tokio::test]
async fn try_acquire_no_log() {
let unit = ();
let atomic_unit = AtomicMutex::<()>::from(unit);
assert!(
atomic_unit.try_lock_guard().is_ok(),
"Must succeed when no lock is held"
);

let _held_lock = atomic_unit.try_lock_guard().unwrap();
assert!(
atomic_unit.try_lock_guard().is_err(),
"Must fail when lock is held"
);
}

#[traced_test]
#[tokio::test]
async fn try_acquire_with_log() {
pub fn log_lock_event(lock_event: LockEvent) {
let (event, info, acquisition) = match lock_event {
LockEvent::TryAcquire { info, acquisition } => ("TryAcquire", info, acquisition),
LockEvent::Acquire { info, acquisition } => ("Acquire", info, acquisition),
LockEvent::Release { info, acquisition } => ("Release", info, acquisition),
};

println!(
"{} lock `{}` of type `{}` for `{}` by\n\t|-- thread {}, `{:?}`",
event,
info.name().unwrap_or("?"),
info.lock_type(),
acquisition,
std::thread::current().name().unwrap_or("?"),
std::thread::current().id(),
);
}

const LOG_LOCK_EVENT_CB: LockCallbackFn = log_lock_event;
let name = "Jim".to_string();
let atomic_name =
AtomicMutex::<String>::from((name, Some("name"), Some(LOG_LOCK_EVENT_CB)));
assert!(
atomic_name.try_lock_guard().is_ok(),
"Must succeed when no lock is held"
);

let _held_lock = atomic_name.lock_guard().await;
assert!(
atomic_name.try_lock_guard().is_err(),
"Must fail when lock is held"
);
}

#[tokio::test]
async fn lock_async() {
struct Car {
Expand Down
2 changes: 2 additions & 0 deletions src/locks/tokio/shared.rs
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@ impl std::fmt::Display for LockType {
pub enum LockAcquisition {
Read,
Write,
TryAcquire,
}

impl std::fmt::Display for LockAcquisition {
Expand All @@ -28,6 +29,7 @@ impl std::fmt::Display for LockAcquisition {
match self {
Self::Read => write!(f, "Read"),
Self::Write => write!(f, "Write"),
Self::TryAcquire => write!(f, "TryAcquire"),
}
}
}
Expand Down

0 comments on commit bd66d85

Please sign in to comment.