Ensure sequential consistency in fallback implementation when requested #38

Open · wants to merge 2 commits into master
86 changes: 56 additions & 30 deletions src/fallback.rs
@@ -22,10 +22,16 @@ use bytemuck::NoUninit;
struct SpinLock(AtomicUsize);

impl SpinLock {
fn lock(&self) {
fn lock(&self, order: Ordering) {
// If the corresponding atomic operation is `SeqCst`, acquire the lock
// with `SeqCst` ordering to ensure sequential consistency.
let success_order = match order {
Ordering::SeqCst => Ordering::SeqCst,
_ => Ordering::Acquire,
};
while self
.0
.compare_exchange_weak(0, 1, Ordering::Acquire, Ordering::Relaxed)
.compare_exchange_weak(0, 1, success_order, Ordering::Relaxed)
.is_err()
{
while self.0.load(Ordering::Relaxed) != 0 {
@@ -34,8 +40,16 @@ impl SpinLock {
}
}

fn unlock(&self) {
self.0.store(0, Ordering::Release);
fn unlock(&self, order: Ordering) {
self.0.store(
0,
// As with acquiring the lock, release the lock with `SeqCst`
// ordering if the corresponding atomic operation was `SeqCst`.
match order {
Ordering::SeqCst => Ordering::SeqCst,
_ => Ordering::Release,
},
);
}
}
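
For readability, here is the ordering-aware `SpinLock` as it stands after this change, assembled from the two hunks above (the body of the inner spin loop is not part of the diff, so the spin hint there is an assumption):

```rust
use core::sync::atomic::{AtomicUsize, Ordering};

struct SpinLock(AtomicUsize);

impl SpinLock {
    fn lock(&self, order: Ordering) {
        // If the corresponding atomic operation is `SeqCst`, acquire the lock
        // with `SeqCst` ordering to ensure sequential consistency.
        let success_order = match order {
            Ordering::SeqCst => Ordering::SeqCst,
            _ => Ordering::Acquire,
        };
        while self
            .0
            .compare_exchange_weak(0, 1, success_order, Ordering::Relaxed)
            .is_err()
        {
            while self.0.load(Ordering::Relaxed) != 0 {
                // Loop body is elided in the diff; a spin hint is assumed here.
                core::hint::spin_loop();
            }
        }
    }

    fn unlock(&self, order: Ordering) {
        self.0.store(
            0,
            // As with acquiring the lock, release the lock with `SeqCst`
            // ordering if the corresponding atomic operation was `SeqCst`.
            match order {
                Ordering::SeqCst => Ordering::SeqCst,
                _ => Ordering::Release,
            },
        );
    }
}
```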

@@ -84,35 +98,43 @@ fn lock_for_addr(addr: usize) -> &'static SpinLock {
}

#[inline]
fn lock(addr: usize) -> LockGuard {
fn lock(addr: usize, order: Ordering) -> LockGuard {
let lock = lock_for_addr(addr);
lock.lock();
LockGuard(lock)
lock.lock(order);
LockGuard {
lock,
order,
}
}

struct LockGuard {
lock: &'static SpinLock,
/// The ordering of the atomic operation for which the lock was obtained.
order: Ordering,
}

struct LockGuard(&'static SpinLock);
impl Drop for LockGuard {
#[inline]
fn drop(&mut self) {
self.0.unlock();
self.lock.unlock(self.order);
}
}

#[inline]
pub unsafe fn atomic_load<T>(dst: *mut T) -> T {
let _l = lock(dst as usize);
pub unsafe fn atomic_load<T>(dst: *mut T, order: Ordering) -> T {
let _l = lock(dst as usize, order);
ptr::read(dst)
}

#[inline]
pub unsafe fn atomic_store<T>(dst: *mut T, val: T) {
let _l = lock(dst as usize);
pub unsafe fn atomic_store<T>(dst: *mut T, val: T, order: Ordering) {
let _l = lock(dst as usize, order);
ptr::write(dst, val);
}

#[inline]
pub unsafe fn atomic_swap<T>(dst: *mut T, val: T) -> T {
let _l = lock(dst as usize);
pub unsafe fn atomic_swap<T>(dst: *mut T, val: T, order: Ordering) -> T {
let _l = lock(dst as usize, order);
ptr::replace(dst, val)
}
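
Likewise, the lock entry point and the guard after this change, consolidated from the hunk above (`lock_for_addr` and `SpinLock` are as defined earlier in the file): the guard now records the ordering of the atomic operation it was taken for and hands it back to `unlock` when it drops.

```rust
fn lock(addr: usize, order: Ordering) -> LockGuard {
    let lock = lock_for_addr(addr);
    lock.lock(order);
    LockGuard { lock, order }
}

struct LockGuard {
    lock: &'static SpinLock,
    /// The ordering of the atomic operation for which the lock was obtained.
    order: Ordering,
}

impl Drop for LockGuard {
    #[inline]
    fn drop(&mut self) {
        // Release the lock with the ordering recorded at acquisition time.
        self.lock.unlock(self.order);
    }
}
```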

@@ -121,8 +143,10 @@ pub unsafe fn atomic_compare_exchange<T: NoUninit>(
dst: *mut T,
current: T,
new: T,
success: Ordering,
failure: Ordering,
) -> Result<T, T> {
let _l = lock(dst as usize);
let mut l = lock(dst as usize, success);
let result = ptr::read(dst);
// compare_exchange compares with memcmp instead of Eq
let a = bytemuck::bytes_of(&result);
@@ -131,67 +155,69 @@ ptr::write(dst, new);
ptr::write(dst, new);
Ok(result)
} else {
// Use the failure ordering instead in this case.
l.order = failure;
Err(result)
}
}

#[inline]
pub unsafe fn atomic_add<T: Copy>(dst: *mut T, val: T) -> T
pub unsafe fn atomic_add<T: Copy>(dst: *mut T, val: T, order: Ordering) -> T
where
Wrapping<T>: ops::Add<Output = Wrapping<T>>,
{
let _l = lock(dst as usize);
let _l = lock(dst as usize, order);
let result = ptr::read(dst);
ptr::write(dst, (Wrapping(result) + Wrapping(val)).0);
result
}

#[inline]
pub unsafe fn atomic_sub<T: Copy>(dst: *mut T, val: T) -> T
pub unsafe fn atomic_sub<T: Copy>(dst: *mut T, val: T, order: Ordering) -> T
where
Wrapping<T>: ops::Sub<Output = Wrapping<T>>,
{
let _l = lock(dst as usize);
let _l = lock(dst as usize, order);
let result = ptr::read(dst);
ptr::write(dst, (Wrapping(result) - Wrapping(val)).0);
result
}

#[inline]
pub unsafe fn atomic_and<T: Copy + ops::BitAnd<Output = T>>(dst: *mut T, val: T) -> T {
let _l = lock(dst as usize);
pub unsafe fn atomic_and<T: Copy + ops::BitAnd<Output = T>>(dst: *mut T, val: T, order: Ordering) -> T {
let _l = lock(dst as usize, order);
let result = ptr::read(dst);
ptr::write(dst, result & val);
result
}

#[inline]
pub unsafe fn atomic_or<T: Copy + ops::BitOr<Output = T>>(dst: *mut T, val: T) -> T {
let _l = lock(dst as usize);
pub unsafe fn atomic_or<T: Copy + ops::BitOr<Output = T>>(dst: *mut T, val: T, order: Ordering) -> T {
let _l = lock(dst as usize, order);
let result = ptr::read(dst);
ptr::write(dst, result | val);
result
}

#[inline]
pub unsafe fn atomic_xor<T: Copy + ops::BitXor<Output = T>>(dst: *mut T, val: T) -> T {
let _l = lock(dst as usize);
pub unsafe fn atomic_xor<T: Copy + ops::BitXor<Output = T>>(dst: *mut T, val: T, order: Ordering) -> T {
let _l = lock(dst as usize, order);
let result = ptr::read(dst);
ptr::write(dst, result ^ val);
result
}

#[inline]
pub unsafe fn atomic_min<T: Copy + cmp::Ord>(dst: *mut T, val: T) -> T {
let _l = lock(dst as usize);
pub unsafe fn atomic_min<T: Copy + cmp::Ord>(dst: *mut T, val: T, order: Ordering) -> T {
let _l = lock(dst as usize, order);
let result = ptr::read(dst);
ptr::write(dst, cmp::min(result, val));
result
}

#[inline]
pub unsafe fn atomic_max<T: Copy + cmp::Ord>(dst: *mut T, val: T) -> T {
let _l = lock(dst as usize);
pub unsafe fn atomic_max<T: Copy + cmp::Ord>(dst: *mut T, val: T, order: Ordering) -> T {
let _l = lock(dst as usize, order);
let result = ptr::read(dst);
ptr::write(dst, cmp::max(result, val));
result
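The one case that needs both orderings is `compare_exchange`: the guard is created with the `success` ordering, and when the comparison fails its recorded ordering is switched to `failure` before it drops, so the unlock matches the ordering the failed operation asked for. A consolidated sketch of the resulting function (the byte comparison against `current` sits in a part of the file this diff does not show, so that line is an assumption):

```rust
pub unsafe fn atomic_compare_exchange<T: NoUninit>(
    dst: *mut T,
    current: T,
    new: T,
    success: Ordering,
    failure: Ordering,
) -> Result<T, T> {
    let mut l = lock(dst as usize, success);
    let result = ptr::read(dst);
    // compare_exchange compares with memcmp instead of Eq.
    // (The comparison against `current`'s bytes is elided in the diff.)
    if bytemuck::bytes_of(&result) == bytemuck::bytes_of(&current) {
        ptr::write(dst, new);
        Ok(result)
    } else {
        // Use the failure ordering instead in this case.
        l.order = failure;
        Err(result)
    }
}
```
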
28 changes: 14 additions & 14 deletions src/ops.rs
@@ -118,7 +118,7 @@ pub unsafe fn atomic_load<T: NoUninit>(dst: *mut T, order: Ordering) -> T {
T,
A,
mem::transmute_copy(&(*(dst as *const A)).load(order)),
fallback::atomic_load(dst)
fallback::atomic_load(dst, order)
)
}
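
Each of these `ops.rs` call sites follows the same pattern: a selection macro (whose definition is outside this diff) expands to the native atomic operation on `A` when `T` fits one, and to the spinlock fallback otherwise, which now receives the caller's `order` as well. Roughly, ignoring the macro machinery and using `AtomicUsize` as the only native candidate, the shape is as follows (names here are illustrative, not from the crate):

```rust
use core::mem;
use core::sync::atomic::{AtomicUsize, Ordering};

// Illustrative sketch only: the real ops.rs dispatches over several native
// atomic widths via a macro; this shows the idea for a single width.
unsafe fn atomic_load_sketch<T: bytemuck::NoUninit>(dst: *mut T, order: Ordering) -> T {
    if mem::size_of::<T>() == mem::size_of::<AtomicUsize>()
        && mem::align_of::<T>() >= mem::align_of::<AtomicUsize>()
    {
        // Native path: reinterpret the location as an AtomicUsize.
        mem::transmute_copy(&(*(dst as *const AtomicUsize)).load(order))
    } else {
        // Lock-based path: the ordering is now forwarded to the fallback.
        fallback::atomic_load(dst, order)
    }
}
```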

@@ -128,7 +128,7 @@ pub unsafe fn atomic_store<T: NoUninit>(dst: *mut T, val: T, order: Ordering) {
T,
A,
(*(dst as *const A)).store(mem::transmute_copy(&val), order),
fallback::atomic_store(dst, val)
fallback::atomic_store(dst, val, order)
)
}

@@ -138,7 +138,7 @@ pub unsafe fn atomic_swap<T: NoUninit>(dst: *mut T, val: T, order: Ordering) ->
T,
A,
mem::transmute_copy(&(*(dst as *const A)).swap(mem::transmute_copy(&val), order)),
fallback::atomic_swap(dst, val)
fallback::atomic_swap(dst, val, order)
)
}

@@ -167,7 +167,7 @@ pub unsafe fn atomic_compare_exchange<T: NoUninit>(
success,
failure,
)),
fallback::atomic_compare_exchange(dst, current, new)
fallback::atomic_compare_exchange(dst, current, new, success, failure)
)
}

@@ -188,7 +188,7 @@ pub unsafe fn atomic_compare_exchange_weak<T: NoUninit>(
success,
failure,
)),
fallback::atomic_compare_exchange(dst, current, new)
fallback::atomic_compare_exchange(dst, current, new, success, failure)
)
}

@@ -201,7 +201,7 @@ where
T,
A,
mem::transmute_copy(&(*(dst as *const A)).fetch_add(mem::transmute_copy(&val), order),),
fallback::atomic_add(dst, val)
fallback::atomic_add(dst, val, order)
)
}

@@ -214,7 +214,7 @@ where
T,
A,
mem::transmute_copy(&(*(dst as *const A)).fetch_sub(mem::transmute_copy(&val), order),),
fallback::atomic_sub(dst, val)
fallback::atomic_sub(dst, val, order)
)
}

@@ -228,7 +228,7 @@ pub unsafe fn atomic_and<T: NoUninit + ops::BitAnd<Output = T>>(
T,
A,
mem::transmute_copy(&(*(dst as *const A)).fetch_and(mem::transmute_copy(&val), order),),
fallback::atomic_and(dst, val)
fallback::atomic_and(dst, val, order)
)
}

@@ -242,7 +242,7 @@ pub unsafe fn atomic_or<T: NoUninit + ops::BitOr<Output = T>>(
T,
A,
mem::transmute_copy(&(*(dst as *const A)).fetch_or(mem::transmute_copy(&val), order),),
fallback::atomic_or(dst, val)
fallback::atomic_or(dst, val, order)
)
}

@@ -256,7 +256,7 @@ pub unsafe fn atomic_xor<T: NoUninit + ops::BitXor<Output = T>>(
T,
A,
mem::transmute_copy(&(*(dst as *const A)).fetch_xor(mem::transmute_copy(&val), order),),
fallback::atomic_xor(dst, val)
fallback::atomic_xor(dst, val, order)
)
}

@@ -266,7 +266,7 @@ pub unsafe fn atomic_min<T: NoUninit + cmp::Ord>(dst: *mut T, val: T, order: Ord
T,
A,
mem::transmute_copy(&(*(dst as *const A)).fetch_min(mem::transmute_copy(&val), order),),
fallback::atomic_min(dst, val)
fallback::atomic_min(dst, val, order)
)
}

@@ -276,7 +276,7 @@ pub unsafe fn atomic_max<T: NoUninit + cmp::Ord>(dst: *mut T, val: T, order: Ord
T,
A,
mem::transmute_copy(&(*(dst as *const A)).fetch_max(mem::transmute_copy(&val), order),),
fallback::atomic_max(dst, val)
fallback::atomic_max(dst, val, order)
)
}

@@ -286,7 +286,7 @@ pub unsafe fn atomic_umin<T: NoUninit + cmp::Ord>(dst: *mut T, val: T, order: Or
T,
A,
mem::transmute_copy(&(*(dst as *const A)).fetch_min(mem::transmute_copy(&val), order),),
fallback::atomic_min(dst, val)
fallback::atomic_min(dst, val, order)
)
}

@@ -296,6 +296,6 @@ pub unsafe fn atomic_umax<T: NoUninit + cmp::Ord>(dst: *mut T, val: T, order: Or
T,
A,
mem::transmute_copy(&(*(dst as *const A)).fetch_max(mem::transmute_copy(&val), order),),
fallback::atomic_max(dst, val)
fallback::atomic_max(dst, val, order)
)
}
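
For background on why the lock operations themselves have to be upgraded, rather than only the guarded data access: with a per-address spinlock, plain Acquire/Release lock and unlock give each fallback operation acquire/release semantics, but operations on different addresses use different locks, so nothing places `SeqCst` accesses into a single total order. The standard store-buffering litmus test shows what `SeqCst` forbids. The snippet below is an independent illustration using native atomics, not code from this PR:

```rust
use std::sync::atomic::{AtomicUsize, Ordering};
use std::thread;

// Store-buffering litmus test: under SeqCst all four accesses belong to one
// total order, so at least one thread must observe the other's store and the
// outcome (r1, r2) == (0, 0) is impossible. With only Acquire/Release (or
// Relaxed) accesses, (0, 0) is allowed.
static X: AtomicUsize = AtomicUsize::new(0);
static Y: AtomicUsize = AtomicUsize::new(0);

fn main() {
    let t1 = thread::spawn(|| {
        X.store(1, Ordering::SeqCst);
        Y.load(Ordering::SeqCst)
    });
    let t2 = thread::spawn(|| {
        Y.store(1, Ordering::SeqCst);
        X.load(Ordering::SeqCst)
    });
    let (r1, r2) = (t1.join().unwrap(), t2.join().unwrap());
    assert!(r1 == 1 || r2 == 1, "forbidden outcome under SeqCst");
}
```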