Skip to content

Commit

Permalink
refactor: use binary bits to determine the number of partitions (#919)
Browse files Browse the repository at this point in the history
## Related Issues
Prepare for #914 

## Detailed Changes
- Modify the type of `partitions` in `PartitionedMutex` and
`PartitionedRwLock`.
- Fix the bug that multiple partitions use the same lock.
## Test Plan 
Unit tests under the same file.
  • Loading branch information
tanruixiang authored May 24, 2023
1 parent 6051546 commit 0317948
Showing 1 changed file with 61 additions and 27 deletions.
88 changes: 61 additions & 27 deletions common_util/src/partitioned_lock.rs
Original file line number Diff line number Diff line change
Expand Up @@ -5,24 +5,27 @@
use std::{
collections::hash_map::DefaultHasher,
hash::{Hash, Hasher},
num::NonZeroUsize,
sync::{Arc, Mutex, MutexGuard, RwLock, RwLockReadGuard, RwLockWriteGuard},
sync::{Mutex, MutexGuard, RwLock, RwLockReadGuard, RwLockWriteGuard},
};

/// Simple partitioned `RwLock`
pub struct PartitionedRwLock<T> {
partitions: Vec<Arc<RwLock<T>>>,
partitions: Vec<RwLock<T>>,
partition_mask: usize,
}

impl<T> PartitionedRwLock<T> {
// TODO: we should get the nearest 2^n of `partition_num` as real
// `partition_num`. By doing so, we can use "&" to get partition rather than
// "%".
pub fn new(t: T, partition_num: NonZeroUsize) -> Self {
let partition_num = partition_num.get();
let locked_content = Arc::new(RwLock::new(t));
impl<T> PartitionedRwLock<T>
where
T: Clone,
{
pub fn new(t: T, partition_bit: usize) -> Self {
let partition_num = 1 << partition_bit;
let partitions = (0..partition_num)
.map(|_| RwLock::new(t.clone()))
.collect::<Vec<_>>();
Self {
partitions: vec![locked_content; partition_num],
partitions,
partition_mask: partition_num - 1,
}
}

Expand All @@ -41,26 +44,29 @@ impl<T> PartitionedRwLock<T> {
fn get_partition<K: Eq + Hash>(&self, key: &K) -> &RwLock<T> {
let mut hasher = DefaultHasher::new();
key.hash(&mut hasher);
let partition_num = self.partitions.len();

&self.partitions[(hasher.finish() as usize) % partition_num]
&self.partitions[(hasher.finish() as usize) & self.partition_mask]
}
}

/// Simple partitioned `Mutex`
pub struct PartitionedMutex<T> {
partitions: Vec<Arc<Mutex<T>>>,
partitions: Vec<Mutex<T>>,
partition_mask: usize,
}

impl<T> PartitionedMutex<T> {
// TODO: we should get the nearest 2^n of `partition_num` as real
// `partition_num`. By doing so, we can use "&" to get partition rather than
// "%".
pub fn new(t: T, partition_num: NonZeroUsize) -> Self {
let partition_num = partition_num.get();
let locked_content = Arc::new(Mutex::new(t));
impl<T> PartitionedMutex<T>
where
T: Clone,
{
pub fn new(t: T, partition_bit: usize) -> Self {
let partition_num = 1 << partition_bit;
let partitions = (0..partition_num)
.map(|_| Mutex::new(t.clone()))
.collect::<Vec<_>>();
Self {
partitions: vec![locked_content; partition_num],
partitions,
partition_mask: partition_num - 1,
}
}

Expand All @@ -73,9 +79,8 @@ impl<T> PartitionedMutex<T> {
fn get_partition<K: Eq + Hash>(&self, key: &K) -> &Mutex<T> {
let mut hasher = DefaultHasher::new();
key.hash(&mut hasher);
let partition_num = self.partitions.len();

&self.partitions[(hasher.finish() as usize) % partition_num]
&self.partitions[(hasher.finish() as usize) & self.partition_mask]
}
}

Expand All @@ -87,8 +92,7 @@ mod tests {

#[test]
fn test_partitioned_rwlock() {
let test_locked_map =
PartitionedRwLock::new(HashMap::new(), NonZeroUsize::new(10).unwrap());
let test_locked_map = PartitionedRwLock::new(HashMap::new(), 4);
let test_key = "test_key".to_string();
let test_value = "test_value".to_string();

Expand All @@ -105,7 +109,7 @@ mod tests {

#[test]
fn test_partitioned_mutex() {
let test_locked_map = PartitionedMutex::new(HashMap::new(), NonZeroUsize::new(10).unwrap());
let test_locked_map = PartitionedMutex::new(HashMap::new(), 4);
let test_key = "test_key".to_string();
let test_value = "test_value".to_string();

Expand All @@ -119,4 +123,34 @@ mod tests {
assert_eq!(map.get(&test_key).unwrap(), &test_value);
}
}

#[test]
fn test_partitioned_mutex_vis_different_partition() {
let tmp_vec: Vec<f32> = Vec::new();
let test_locked_map = PartitionedMutex::new(tmp_vec, 4);
let test_key_first = "test_key_first".to_string();
let mutex_first = test_locked_map.get_partition(&test_key_first);
let mut _tmp_data = mutex_first.lock().unwrap();
assert!(mutex_first.try_lock().is_err());

let test_key_second = "test_key_second".to_string();
let mutex_second = test_locked_map.get_partition(&test_key_second);
assert!(mutex_second.try_lock().is_ok());
assert!(mutex_first.try_lock().is_err());
}

#[test]
fn test_partitioned_rwmutex_vis_different_partition() {
let tmp_vec: Vec<f32> = Vec::new();
let test_locked_map = PartitionedRwLock::new(tmp_vec, 4);
let test_key_first = "test_key_first".to_string();
let mutex_first = test_locked_map.get_partition(&test_key_first);
let mut _tmp = mutex_first.write().unwrap();
assert!(mutex_first.try_write().is_err());

let test_key_second = "test_key_second".to_string();
let mutex_second_try_lock = test_locked_map.get_partition(&test_key_second);
assert!(mutex_second_try_lock.try_write().is_ok());
assert!(mutex_first.try_write().is_err());
}
}

0 comments on commit 0317948

Please sign in to comment.