Skip to content

Commit

Permalink
Prevent segmentation fault in sync and future caches (#34)
Browse files Browse the repository at this point in the history
- Add a lock to the rehash function of the concurrent hash table (`moka::cht`) to
  ensure only one thread can participate rehashing at a time.
- To prevent potential inconsistency issues in non x86 based systems, strengthen the
  memory ordering used for `compare_exchange_weak` (`Release` to `AcqRel`).
  • Loading branch information
tatsuya6502 committed Jul 19, 2022
1 parent 146fca8 commit fc044d6
Show file tree
Hide file tree
Showing 2 changed files with 74 additions and 34 deletions.
53 changes: 36 additions & 17 deletions src/cht/map/bucket.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,10 @@ use std::{
hash::{BuildHasher, Hash, Hasher},
mem::{self, MaybeUninit},
ptr,
sync::atomic::{self, AtomicUsize, Ordering},
sync::{
atomic::{self, AtomicUsize, Ordering},
Arc, Mutex, TryLockError,
},
};

#[cfg(feature = "unstable-debug-counters")]
Expand All @@ -16,6 +19,7 @@ pub(crate) struct BucketArray<K, V> {
pub(crate) buckets: Box<[Atomic<Bucket<K, V>>]>,
pub(crate) next: Atomic<BucketArray<K, V>>,
pub(crate) epoch: usize,
pub(crate) rehash_lock: Arc<Mutex<()>>,
pub(crate) tombstone_count: AtomicUsize,
}

Expand Down Expand Up @@ -49,6 +53,7 @@ impl<K, V> BucketArray<K, V> {
buckets,
next: Atomic::null(),
epoch,
rehash_lock: Arc::new(Mutex::new(())),
tombstone_count: Default::default(),
}
}
Expand Down Expand Up @@ -147,10 +152,10 @@ impl<'g, K: 'g + Eq, V: 'g> BucketArray<K, V> {

let new_bucket_ptr = this_bucket_ptr.with_tag(TOMBSTONE_TAG);

match this_bucket.compare_exchange(
match this_bucket.compare_exchange_weak(
this_bucket_ptr,
new_bucket_ptr,
Ordering::Release,
Ordering::AcqRel,
Ordering::Relaxed,
guard,
) {
Expand Down Expand Up @@ -201,10 +206,10 @@ impl<'g, K: 'g + Eq, V: 'g> BucketArray<K, V> {

let new_bucket = state.into_insert_bucket();

if let Err(CompareExchangeError { new, .. }) = this_bucket.compare_exchange(
if let Err(CompareExchangeError { new, .. }) = this_bucket.compare_exchange_weak(
this_bucket_ptr,
new_bucket,
Ordering::Release,
Ordering::AcqRel,
Ordering::Relaxed,
guard,
) {
Expand Down Expand Up @@ -267,10 +272,10 @@ impl<'g, K: 'g + Eq, V: 'g> BucketArray<K, V> {
(state.into_insert_bucket(), None)
};

if let Err(CompareExchangeError { new, .. }) = this_bucket.compare_exchange(
if let Err(CompareExchangeError { new, .. }) = this_bucket.compare_exchange_weak(
this_bucket_ptr,
new_bucket,
Ordering::Release,
Ordering::AcqRel,
Ordering::Relaxed,
guard,
) {
Expand Down Expand Up @@ -317,10 +322,10 @@ impl<'g, K: 'g + Eq, V: 'g> BucketArray<K, V> {
if this_bucket_ptr.is_null() && is_tombstone(bucket_ptr) {
ProbeLoopAction::Return(None)
} else if this_bucket
.compare_exchange(
.compare_exchange_weak(
this_bucket_ptr,
bucket_ptr,
Ordering::Release,
Ordering::AcqRel,
Ordering::Relaxed,
guard,
)
Expand Down Expand Up @@ -398,11 +403,24 @@ impl<'g, K: 'g, V: 'g> BucketArray<K, V> {
guard: &'g Guard,
build_hasher: &H,
rehash_op: RehashOp,
) -> &'g BucketArray<K, V>
) -> Option<&'g BucketArray<K, V>>
where
K: Hash + Eq,
H: BuildHasher,
{
// Ensure that the rehashing is not performed concurrently.
let lock;
match self.rehash_lock.try_lock() {
Ok(lk) => lock = lk,
Err(TryLockError::WouldBlock) => {
// Wait until the lock become available.
std::mem::drop(self.rehash_lock.lock());
// We need to return here to see if rehashing is still needed.
return None;
}
Err(e @ TryLockError::Poisoned(_)) => panic!("{:?}", e),
};

let next_array = self.next_array(guard, rehash_op);

for this_bucket in self.buckets.iter() {
Expand All @@ -424,10 +442,10 @@ impl<'g, K: 'g, V: 'g> BucketArray<K, V> {

while is_borrowed(next_bucket_ptr)
&& next_bucket
.compare_exchange(
.compare_exchange_weak(
next_bucket_ptr,
to_put_ptr,
Ordering::Release,
Ordering::AcqRel,
Ordering::Relaxed,
guard,
)
Expand All @@ -445,10 +463,10 @@ impl<'g, K: 'g, V: 'g> BucketArray<K, V> {
}

if this_bucket
.compare_exchange(
.compare_exchange_weak(
this_bucket_ptr,
Shared::null().with_tag(SENTINEL_TAG),
Ordering::Release,
Ordering::AcqRel,
Ordering::Relaxed,
guard,
)
Expand All @@ -466,8 +484,9 @@ impl<'g, K: 'g, V: 'g> BucketArray<K, V> {
}
}
}
std::mem::drop(lock);

next_array
Some(next_array)
}

fn next_array(&self, guard: &'g Guard, rehash_op: RehashOp) -> &'g BucketArray<K, V> {
Expand All @@ -485,10 +504,10 @@ impl<'g, K: 'g, V: 'g> BucketArray<K, V> {
Owned::new(BucketArray::with_length(self.epoch + 1, new_length))
});

match self.next.compare_exchange(
match self.next.compare_exchange_weak(
Shared::null(),
new_next,
Ordering::Release,
Ordering::AcqRel,
Ordering::Relaxed,
guard,
) {
Expand Down
55 changes: 38 additions & 17 deletions src/cht/map/bucket_array_ref.rs
Original file line number Diff line number Diff line change
Expand Up @@ -47,8 +47,11 @@ where
break;
}
Err(_) => {
bucket_array_ref =
bucket_array_ref.rehash(guard, self.build_hasher, RehashOp::Expand);
if let Some(r) =
bucket_array_ref.rehash(guard, self.build_hasher, RehashOp::Expand)
{
bucket_array_ref = r;
}
}
}
}
Expand Down Expand Up @@ -81,7 +84,9 @@ where
if rehash_op.is_skip() {
break;
}
bucket_array_ref = bucket_array_ref.rehash(guard, self.build_hasher, rehash_op);
if let Some(r) = bucket_array_ref.rehash(guard, self.build_hasher, rehash_op) {
bucket_array_ref = r;
}
}

match bucket_array_ref.remove_if(guard, hash, &mut eq, condition) {
Expand All @@ -106,8 +111,11 @@ where
}
Err(c) => {
condition = c;
bucket_array_ref =
bucket_array_ref.rehash(guard, self.build_hasher, RehashOp::Expand);
if let Some(r) =
bucket_array_ref.rehash(guard, self.build_hasher, RehashOp::Expand)
{
bucket_array_ref = r;
}
}
}
}
Expand Down Expand Up @@ -143,7 +151,9 @@ where
if rehash_op.is_skip() {
break;
}
bucket_array_ref = bucket_array_ref.rehash(guard, self.build_hasher, rehash_op);
if let Some(r) = bucket_array_ref.rehash(guard, self.build_hasher, rehash_op) {
bucket_array_ref = r;
}
}

match bucket_array_ref.insert_if_not_present(guard, hash, state) {
Expand Down Expand Up @@ -171,8 +181,11 @@ where
}
Err(s) => {
state = s;
bucket_array_ref =
bucket_array_ref.rehash(guard, self.build_hasher, RehashOp::Expand);
if let Some(r) =
bucket_array_ref.rehash(guard, self.build_hasher, RehashOp::Expand)
{
bucket_array_ref = r;
}
}
}
}
Expand Down Expand Up @@ -207,7 +220,9 @@ where
if rehash_op.is_skip() {
break;
}
bucket_array_ref = bucket_array_ref.rehash(guard, self.build_hasher, rehash_op);
if let Some(r) = bucket_array_ref.rehash(guard, self.build_hasher, rehash_op) {
bucket_array_ref = r;
}
}

match bucket_array_ref.insert_or_modify(guard, hash, state, on_modify) {
Expand Down Expand Up @@ -235,8 +250,11 @@ where
Err((s, f)) => {
state = s;
on_modify = f;
bucket_array_ref =
bucket_array_ref.rehash(guard, self.build_hasher, RehashOp::Expand);
if let Some(r) =
bucket_array_ref.rehash(guard, self.build_hasher, RehashOp::Expand)
{
bucket_array_ref = r;
}
}
}
}
Expand All @@ -260,8 +278,11 @@ where
break;
}
Err(_) => {
bucket_array_ref =
bucket_array_ref.rehash(guard, self.build_hasher, RehashOp::Expand);
if let Some(r) =
bucket_array_ref.rehash(guard, self.build_hasher, RehashOp::Expand)
{
bucket_array_ref = r;
}
}
}
}
Expand All @@ -286,10 +307,10 @@ impl<'a, 'g, K, V, S> BucketArrayRef<'a, K, V, S> {
let new_bucket_array =
maybe_new_bucket_array.unwrap_or_else(|| Owned::new(BucketArray::default()));

match self.bucket_array.compare_exchange(
match self.bucket_array.compare_exchange_weak(
Shared::null(),
new_bucket_array,
Ordering::Release,
Ordering::AcqRel,
Ordering::Relaxed,
guard,
) {
Expand All @@ -315,10 +336,10 @@ impl<'a, 'g, K, V, S> BucketArrayRef<'a, K, V, S> {
return;
}

match self.bucket_array.compare_exchange(
match self.bucket_array.compare_exchange_weak(
current_ptr,
min_ptr,
Ordering::Release,
Ordering::AcqRel,
Ordering::Relaxed,
guard,
) {
Expand Down

0 comments on commit fc044d6

Please sign in to comment.