diff --git a/src/fallback.rs b/src/fallback.rs
index 949bbf3..645170e 100644
--- a/src/fallback.rs
+++ b/src/fallback.rs
@@ -22,10 +22,16 @@ use bytemuck::NoUninit;
 struct SpinLock(AtomicUsize);
 
 impl SpinLock {
-    fn lock(&self) {
+    fn lock(&self, order: Ordering) {
+        // If the corresponding atomic operation is `SeqCst`, acquire the lock
+        // with `SeqCst` ordering to ensure sequential consistency.
+        let success_order = match order {
+            Ordering::SeqCst => Ordering::SeqCst,
+            _ => Ordering::Acquire,
+        };
         while self
             .0
-            .compare_exchange_weak(0, 1, Ordering::Acquire, Ordering::Relaxed)
+            .compare_exchange_weak(0, 1, success_order, Ordering::Relaxed)
             .is_err()
         {
             while self.0.load(Ordering::Relaxed) != 0 {
@@ -34,8 +40,16 @@ impl SpinLock {
         }
     }
 
-    fn unlock(&self) {
-        self.0.store(0, Ordering::Release);
+    fn unlock(&self, order: Ordering) {
+        self.0.store(
+            0,
+            // As with acquiring the lock, release the lock with `SeqCst`
+            // ordering if the corresponding atomic operation was `SeqCst`.
+            match order {
+                Ordering::SeqCst => Ordering::SeqCst,
+                _ => Ordering::Release,
+            },
+        );
     }
 }
 
@@ -84,35 +98,43 @@ fn lock_for_addr(addr: usize) -> &'static SpinLock {
 }
 
 #[inline]
-fn lock(addr: usize) -> LockGuard {
+fn lock(addr: usize, order: Ordering) -> LockGuard {
     let lock = lock_for_addr(addr);
-    lock.lock();
-    LockGuard(lock)
+    lock.lock(order);
+    LockGuard {
+        lock,
+        order,
+    }
+}
+
+struct LockGuard {
+    lock: &'static SpinLock,
+    /// The ordering of the atomic operation for which the lock was obtained.
+    order: Ordering,
 }
 
-struct LockGuard(&'static SpinLock);
 impl Drop for LockGuard {
     #[inline]
     fn drop(&mut self) {
-        self.0.unlock();
+        self.lock.unlock(self.order);
    }
 }
 
 #[inline]
-pub unsafe fn atomic_load<T>(dst: *mut T) -> T {
-    let _l = lock(dst as usize);
+pub unsafe fn atomic_load<T>(dst: *mut T, order: Ordering) -> T {
+    let _l = lock(dst as usize, order);
     ptr::read(dst)
 }
 
 #[inline]
-pub unsafe fn atomic_store<T>(dst: *mut T, val: T) {
-    let _l = lock(dst as usize);
+pub unsafe fn atomic_store<T>(dst: *mut T, val: T, order: Ordering) {
+    let _l = lock(dst as usize, order);
     ptr::write(dst, val);
 }
 
 #[inline]
-pub unsafe fn atomic_swap<T>(dst: *mut T, val: T) -> T {
-    let _l = lock(dst as usize);
+pub unsafe fn atomic_swap<T>(dst: *mut T, val: T, order: Ordering) -> T {
+    let _l = lock(dst as usize, order);
     ptr::replace(dst, val)
 }
 
@@ -121,8 +143,10 @@ pub unsafe fn atomic_compare_exchange<T: NoUninit>(
     dst: *mut T,
     current: T,
     new: T,
+    success: Ordering,
+    failure: Ordering,
 ) -> Result<T, T> {
-    let _l = lock(dst as usize);
+    let mut l = lock(dst as usize, success);
     let result = ptr::read(dst);
     // compare_exchange compares with memcmp instead of Eq
     let a = bytemuck::bytes_of(&result);
@@ -131,67 +155,69 @@ pub unsafe fn atomic_compare_exchange<T: NoUninit>(
         ptr::write(dst, new);
         Ok(result)
     } else {
+        // Use the failure ordering instead in this case.
+        l.order = failure;
         Err(result)
     }
 }
 
 #[inline]
-pub unsafe fn atomic_add<T: Copy>(dst: *mut T, val: T) -> T
+pub unsafe fn atomic_add<T: Copy>(dst: *mut T, val: T, order: Ordering) -> T
 where
     Wrapping<T>: ops::Add<Output = Wrapping<T>>,
 {
-    let _l = lock(dst as usize);
+    let _l = lock(dst as usize, order);
     let result = ptr::read(dst);
     ptr::write(dst, (Wrapping(result) + Wrapping(val)).0);
     result
 }
 
 #[inline]
-pub unsafe fn atomic_sub<T: Copy>(dst: *mut T, val: T) -> T
+pub unsafe fn atomic_sub<T: Copy>(dst: *mut T, val: T, order: Ordering) -> T
 where
     Wrapping<T>: ops::Sub<Output = Wrapping<T>>,
 {
-    let _l = lock(dst as usize);
+    let _l = lock(dst as usize, order);
     let result = ptr::read(dst);
     ptr::write(dst, (Wrapping(result) - Wrapping(val)).0);
     result
 }
 
 #[inline]
-pub unsafe fn atomic_and<T: Copy + ops::BitAnd<Output = T>>(dst: *mut T, val: T) -> T {
-    let _l = lock(dst as usize);
+pub unsafe fn atomic_and<T: Copy + ops::BitAnd<Output = T>>(dst: *mut T, val: T, order: Ordering) -> T {
+    let _l = lock(dst as usize, order);
     let result = ptr::read(dst);
     ptr::write(dst, result & val);
     result
 }
 
 #[inline]
-pub unsafe fn atomic_or<T: Copy + ops::BitOr<Output = T>>(dst: *mut T, val: T) -> T {
-    let _l = lock(dst as usize);
+pub unsafe fn atomic_or<T: Copy + ops::BitOr<Output = T>>(dst: *mut T, val: T, order: Ordering) -> T {
+    let _l = lock(dst as usize, order);
     let result = ptr::read(dst);
     ptr::write(dst, result | val);
     result
 }
 
 #[inline]
-pub unsafe fn atomic_xor<T: Copy + ops::BitXor<Output = T>>(dst: *mut T, val: T) -> T {
-    let _l = lock(dst as usize);
+pub unsafe fn atomic_xor<T: Copy + ops::BitXor<Output = T>>(dst: *mut T, val: T, order: Ordering) -> T {
+    let _l = lock(dst as usize, order);
     let result = ptr::read(dst);
     ptr::write(dst, result ^ val);
     result
 }
 
 #[inline]
-pub unsafe fn atomic_min<T: Copy + cmp::Ord>(dst: *mut T, val: T) -> T {
-    let _l = lock(dst as usize);
+pub unsafe fn atomic_min<T: Copy + cmp::Ord>(dst: *mut T, val: T, order: Ordering) -> T {
+    let _l = lock(dst as usize, order);
     let result = ptr::read(dst);
     ptr::write(dst, cmp::min(result, val));
     result
 }
 
 #[inline]
-pub unsafe fn atomic_max<T: Copy + cmp::Ord>(dst: *mut T, val: T) -> T {
-    let _l = lock(dst as usize);
+pub unsafe fn atomic_max<T: Copy + cmp::Ord>(dst: *mut T, val: T, order: Ordering) -> T {
+    let _l = lock(dst as usize, order);
     let result = ptr::read(dst);
     ptr::write(dst, cmp::max(result, val));
     result
diff --git a/src/ops.rs b/src/ops.rs
index 1f54fa5..0315992 100644
--- a/src/ops.rs
+++ b/src/ops.rs
@@ -118,7 +118,7 @@ pub unsafe fn atomic_load<T>(dst: *mut T, order: Ordering) -> T {
         T,
         A,
         mem::transmute_copy(&(*(dst as *const A)).load(order)),
-        fallback::atomic_load(dst)
+        fallback::atomic_load(dst, order)
     )
 }
 
@@ -128,7 +128,7 @@ pub unsafe fn atomic_store<T>(dst: *mut T, val: T, order: Ordering) {
         T,
         A,
         (*(dst as *const A)).store(mem::transmute_copy(&val), order),
-        fallback::atomic_store(dst, val)
+        fallback::atomic_store(dst, val, order)
     )
 }
 
@@ -138,7 +138,7 @@ pub unsafe fn atomic_swap<T>(dst: *mut T, val: T, order: Ordering) -> T {
         T,
         A,
         mem::transmute_copy(&(*(dst as *const A)).swap(mem::transmute_copy(&val), order)),
-        fallback::atomic_swap(dst, val)
+        fallback::atomic_swap(dst, val, order)
     )
 }
 
@@ -167,7 +167,7 @@ pub unsafe fn atomic_compare_exchange<T: NoUninit>(
             success,
             failure,
         )),
-        fallback::atomic_compare_exchange(dst, current, new)
+        fallback::atomic_compare_exchange(dst, current, new, success, failure)
     )
 }
 
@@ -188,7 +188,7 @@ pub unsafe fn atomic_compare_exchange_weak<T: NoUninit>(
             success,
             failure,
         )),
-        fallback::atomic_compare_exchange(dst, current, new)
+        fallback::atomic_compare_exchange(dst, current, new, success, failure)
     )
 }
 
@@ -201,7 +201,7 @@ where
         T,
         A,
         mem::transmute_copy(&(*(dst as *const A)).fetch_add(mem::transmute_copy(&val), order),),
-        fallback::atomic_add(dst, val)
+        fallback::atomic_add(dst, val, order)
     )
 }
 
@@ -214,7 +214,7 @@ where
         T,
         A,
         mem::transmute_copy(&(*(dst as *const A)).fetch_sub(mem::transmute_copy(&val), order),),
-        fallback::atomic_sub(dst, val)
+        fallback::atomic_sub(dst, val, order)
     )
 }
 
@@ -228,7 +228,7 @@ pub unsafe fn atomic_and<T: Copy + ops::BitAnd<Output = T>>(
         T,
         A,
         mem::transmute_copy(&(*(dst as *const A)).fetch_and(mem::transmute_copy(&val), order),),
-        fallback::atomic_and(dst, val)
+        fallback::atomic_and(dst, val, order)
     )
 }
 
@@ -242,7 +242,7 @@ pub unsafe fn atomic_or<T: Copy + ops::BitOr<Output = T>>(
         T,
         A,
         mem::transmute_copy(&(*(dst as *const A)).fetch_or(mem::transmute_copy(&val), order),),
-        fallback::atomic_or(dst, val)
+        fallback::atomic_or(dst, val, order)
     )
 }
 
@@ -256,7 +256,7 @@ pub unsafe fn atomic_xor<T: Copy + ops::BitXor<Output = T>>(
         T,
         A,
         mem::transmute_copy(&(*(dst as *const A)).fetch_xor(mem::transmute_copy(&val), order),),
-        fallback::atomic_xor(dst, val)
+        fallback::atomic_xor(dst, val, order)
     )
 }
 
@@ -266,7 +266,7 @@ pub unsafe fn atomic_min(dst: *mut T, val: T, order: Ord
         T,
         A,
         mem::transmute_copy(&(*(dst as *const A)).fetch_min(mem::transmute_copy(&val), order),),
-        fallback::atomic_min(dst, val)
+        fallback::atomic_min(dst, val, order)
     )
 }
 
@@ -276,7 +276,7 @@ pub unsafe fn atomic_max(dst: *mut T, val: T, order: Ord
         T,
         A,
         mem::transmute_copy(&(*(dst as *const A)).fetch_max(mem::transmute_copy(&val), order),),
-        fallback::atomic_max(dst, val)
+        fallback::atomic_max(dst, val, order)
     )
 }
 
@@ -286,7 +286,7 @@ pub unsafe fn atomic_umin(dst: *mut T, val: T, order: Or
         T,
         A,
         mem::transmute_copy(&(*(dst as *const A)).fetch_min(mem::transmute_copy(&val), order),),
-        fallback::atomic_min(dst, val)
+        fallback::atomic_min(dst, val, order)
     )
 }
 
@@ -296,6 +296,6 @@ pub unsafe fn atomic_umax(dst: *mut T, val: T, order: Or
         T,
         A,
         mem::transmute_copy(&(*(dst as *const A)).fetch_max(mem::transmute_copy(&val), order),),
-        fallback::atomic_max(dst, val)
+        fallback::atomic_max(dst, val, order)
     )
 }
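
A minimal, self-contained sketch of the locking scheme this patch implements, for reviewers who want to see the ordering plumbing outside the diff. It is illustrative only: `OrderedSpinLock` and `with_lock` are hypothetical names that do not exist in the crate, and the real code threads the ordering through `LockGuard` (downgrading it to the failure ordering on a failed `compare_exchange`) rather than passing it to an explicit unlock call.

```rust
use std::sync::atomic::{AtomicUsize, Ordering};

/// Hypothetical stand-in for the patched `fallback::SpinLock`: the lock word is
/// acquired/released with `SeqCst` only when the emulated atomic operation is
/// `SeqCst`, and with plain `Acquire`/`Release` otherwise.
struct OrderedSpinLock(AtomicUsize);

impl OrderedSpinLock {
    const fn new() -> Self {
        OrderedSpinLock(AtomicUsize::new(0))
    }

    fn lock(&self, order: Ordering) {
        // Upgrade the acquire to SeqCst only for SeqCst operations.
        let success = match order {
            Ordering::SeqCst => Ordering::SeqCst,
            _ => Ordering::Acquire,
        };
        while self
            .0
            .compare_exchange_weak(0, 1, success, Ordering::Relaxed)
            .is_err()
        {
            while self.0.load(Ordering::Relaxed) != 0 {
                std::hint::spin_loop();
            }
        }
    }

    fn unlock(&self, order: Ordering) {
        // Mirror lock(): release with SeqCst only for SeqCst operations.
        let release = match order {
            Ordering::SeqCst => Ordering::SeqCst,
            _ => Ordering::Release,
        };
        self.0.store(0, release);
    }
}

/// Run `f` while holding the lock, using the ordering of the emulated atomic
/// operation. In the patch this pairing is handled by `LockGuard`'s `Drop` impl.
fn with_lock<R>(lock: &OrderedSpinLock, order: Ordering, f: impl FnOnce() -> R) -> R {
    lock.lock(order);
    let result = f();
    lock.unlock(order);
    result
}

fn main() {
    static LOCK: OrderedSpinLock = OrderedSpinLock::new();
    let mut value: u64 = 0;
    // Emulate a SeqCst store followed by a Relaxed load, as the fallback path would.
    with_lock(&LOCK, Ordering::SeqCst, || value = 42);
    let read = with_lock(&LOCK, Ordering::Relaxed, || value);
    assert_eq!(read, 42);
}
```

The split mirrors the comments in the patch: operations requested with `SeqCst` take and release the lock word with `SeqCst` so they keep participating in the single total order, while weaker operations only pay for the usual acquire/release on the lock word.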