From 5fe50da55dfdeb50158ec880986843b91ce263bc Mon Sep 17 00:00:00 2001 From: Bruno Mendes Date: Mon, 13 May 2024 01:42:14 +0100 Subject: [PATCH] Implement lockless hashing --- Cargo.lock | 7 ++ Cargo.toml | 1 + src/search/table.rs | 255 +++++++++----------------------------------- 3 files changed, 57 insertions(+), 206 deletions(-) diff --git a/Cargo.lock b/Cargo.lock index 150313be..0d130fc6 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -51,6 +51,7 @@ dependencies = [ "derive_more", "primitive_enum", "rand", + "sync-unsafe-cell", ] [[package]] @@ -624,6 +625,12 @@ dependencies = [ "unicode-ident", ] +[[package]] +name = "sync-unsafe-cell" +version = "0.1.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "8deaecba5382c095cb432cd1e21068dadb144208f057b13720e76bf89749beb4" + [[package]] name = "tinytemplate" version = "1.2.1" diff --git a/Cargo.toml b/Cargo.toml index a5600e07..b1eb53a1 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -9,6 +9,7 @@ primitive_enum = "1.2.0" derive_more = "0.99.17" rand = "0.8.5" ctor = "0.2.8" +sync-unsafe-cell = "0.1.1" [profile.dev] opt-level = 1 diff --git a/src/search/table.rs b/src/search/table.rs index 8b22c6c5..4e006906 100644 --- a/src/search/table.rs +++ b/src/search/table.rs @@ -2,23 +2,16 @@ use super::{Depth, MAX_DEPTH}; use crate::{ evaluation::{Score, ValueScore}, moves::Move, - position::Position, -}; -use std::{ - array, - mem::transmute, - sync::{ - atomic::{AtomicU16, AtomicU64, Ordering}, - RwLock, - }, + position::{board::ZobristHash, Position}, }; +use std::{array, sync::RwLock}; +use sync_unsafe_cell::SyncUnsafeCell; pub const MAX_TABLE_SIZE_MB: usize = 2048; pub const MIN_TABLE_SIZE_MB: usize = 1; pub const DEFAULT_TABLE_SIZE_MB: usize = 64; -const NULL_KILLER: u16 = u16::MAX; -const NULL_TT_ENTRY: u64 = u64::MAX; +type XoredZobristHash = ZobristHash; #[derive(Clone, Copy, Debug, PartialEq)] pub enum ScoreType { @@ -31,7 +24,9 @@ pub enum ScoreType { struct TableEntry { score: ValueScore, best_move: Move, - data: u32, // bits 0-1: score type, bits 2-7: depth, bit 8: age, bits 9-31: hash + score_type: ScoreType, + depth: Depth, + age: u8, } impl TableEntry { @@ -40,90 +35,58 @@ impl TableEntry { score_type: ScoreType, best_move: Move, depth: Depth, - hash: u64, - age: bool, + age: u8, ) -> Self { - TableEntry { - score, - best_move, - data: (score_type as u32) - | ((depth as u32 & 0x3F) << 2) - | ((age as u32) << 8) - | ((hash as u32) << 9), - } - } - - pub fn from_raw(bytes: u64) -> Self { - unsafe { transmute::(bytes) } - } - - pub fn raw(&self) -> u64 { - unsafe { transmute::(*self) } + TableEntry { score, best_move, score_type, depth, age } } pub fn shift_score(&self, shift: ValueScore) -> Self { TableEntry { score: self.score + shift, ..*self } } - fn score_type(&self) -> ScoreType { - match self.data & 3 { - 0 => ScoreType::Exact, - 1 => ScoreType::LowerBound, - 2 => ScoreType::UpperBound, - _ => panic!("Invalid score type"), - } - } - - fn depth(&self) -> Depth { - ((self.data >> 2) & 0x3F) as Depth - } - - fn same_hash(&self, hash: u64) -> bool { - // We store 23 bits of the hash, so we need to compare the first 23 bits. - (self.data >> 9) == (hash as u32 & 0x7F_FFFF) - } + pub fn data_raw(&self) -> u64 { + let score = self.score as u64; + let best_move = self.best_move.raw() as u64; + let score_type = self.score_type as u64; + let depth = self.depth as u64; + let age = self.age as u64; - fn age(&self) -> bool { - ((self.data >> 8) & 1) != 0 + score | (best_move << 16) | (score_type << 32) | (depth << 40) | (age << 48) } } struct TranspositionTable { - data: Vec, - age: bool, + data: Vec>>, + age: u8, } impl TranspositionTable { pub fn new(size_mb: usize) -> Self { let data_len = Self::calculate_data_len(size_mb); - Self { data: (0..data_len).map(|_| AtomicU64::new(NULL_TT_ENTRY)).collect(), age: false } + Self { data: (0..data_len).map(|_| SyncUnsafeCell::new(None)).collect(), age: 0 } } fn calculate_data_len(size_mb: usize) -> usize { - let element_size = std::mem::size_of::>(); + let element_size = std::mem::size_of::>(); let size = size_mb * 1024 * 1024; size / element_size } pub fn set_size(&mut self, size_mb: usize) { let data_len = Self::calculate_data_len(size_mb); - self.data = (0..data_len).map(|_| AtomicU64::new(NULL_TT_ENTRY)).collect(); + self.data = (0..data_len).map(|_| SyncUnsafeCell::new(None)).collect(); } pub fn hashfull_millis(&self) -> usize { - // The hash keys are disperse, so a small sample should suffice for a relevant statistic. - self.data - .iter() - .take(10000) - .filter(|entry| entry.load(Ordering::Relaxed) != NULL_TT_ENTRY) - .count() - / 10 + (0..10000).filter(|i| self.load_tt_entry(*i).is_some()).count() / 10 } pub fn get(&self, position: &Position) -> Option { let hash = position.zobrist_hash(); let entry = self.load_tt_entry(hash as usize % self.data.len()); - entry.filter(|entry| entry.same_hash(hash)) + entry + .filter(|(entry, entry_hash)| entry_hash ^ entry.data_raw() == hash) + .map(|(entry, _)| entry) } pub fn insert(&self, position: &Position, entry: TableEntry, force: bool) { @@ -131,52 +94,47 @@ impl TranspositionTable { let index = hash as usize % self.data.len(); if !force { - if let Some(old_entry) = self.load_tt_entry(index) { - if old_entry.depth() > entry.depth() && old_entry.age() == entry.age() { + if let Some((old_entry, _)) = self.load_tt_entry(index) { + if old_entry.depth > entry.depth && old_entry.age == entry.age { return; } } } - self.store_tt_entry(index, entry); + self.store_tt_entry(index, Some((entry, hash ^ entry.data_raw()))); } - fn load_tt_entry(&self, index: usize) -> Option { - let entry = self.data[index].load(Ordering::Relaxed); - if entry == NULL_TT_ENTRY { - None - } else { - Some(TableEntry::from_raw(entry)) - } + fn load_tt_entry(&self, index: usize) -> Option<(TableEntry, XoredZobristHash)> { + unsafe { *self.data[index].get() } } - fn store_tt_entry(&self, index: usize, entry: TableEntry) { - self.data[index].store(entry.raw(), Ordering::Relaxed) + fn store_tt_entry(&self, index: usize, entry: Option<(TableEntry, XoredZobristHash)>) { + unsafe { *self.data[index].get() = entry } } } pub struct SearchTable { transposition: RwLock, - killer_moves: [AtomicU16; 3 * (MAX_DEPTH + 1) as usize], + killer_moves: [SyncUnsafeCell>; 3 * (MAX_DEPTH + 1) as usize], } impl SearchTable { pub fn new(size_mb: usize) -> Self { Self { transposition: RwLock::new(TranspositionTable::new(size_mb)), - killer_moves: array::from_fn(|_| AtomicU16::new(NULL_KILLER)), + killer_moves: array::from_fn(|_| SyncUnsafeCell::new(None)), } } pub fn prepare_for_new_search(&self) { - // We flip the age bit to be able to replace all entries from previous searches. + // We add to the age to be able to replace all entries from previous searches. // This is both faster and more effective than clearing the table completely, // since we can profit from older entries that are still valid. let mut tt = self.transposition.write().unwrap(); - tt.age = !tt.age; + tt.age = tt.age.wrapping_add(1); // Killer moves are no longer at the same ply, so we clear them. - self.killer_moves.iter().for_each(|entry| entry.store(NULL_KILLER, Ordering::Relaxed)); + (0..self.killer_moves.len()).for_each(|index| self.store_killer(index, None)); } pub fn set_size(&self, size_mb: usize) { @@ -203,8 +161,8 @@ impl SearchTable { .unwrap() .get(position) .and_then(|entry| { - if entry.depth() >= depth { - Some((entry.score, entry.score_type())) + if entry.depth >= depth { + Some((entry.score, entry.score_type)) } else { None } @@ -231,14 +189,7 @@ impl SearchTable { is_root: bool, ) { let tt = self.transposition.read().unwrap(); - let entry = TableEntry::new( - score, - score_type, - best_move, - depth, - position.zobrist_hash(), - self.transposition.read().unwrap().age, - ); + let entry = TableEntry::new(score, score_type, best_move, depth, tt.age); // The score stored should be independent of the path from root to this node, // and only depend on the number of moves to mate. @@ -254,15 +205,12 @@ impl SearchTable { pub fn put_killer_move(&self, ply: Depth, mov: Move) { let index = 2 * ply as usize; if self.load_killer(index).is_none() { - self.store_killer(index, mov); + self.store_killer(index, Some(mov)); } else if self.load_killer(index + 1).is_none() { - self.store_killer(index + 1, mov); + self.store_killer(index + 1, Some(mov)); } else { - self.store_killer( - index, - self.load_killer(index + 1).unwrap_or(Move::new_raw(NULL_KILLER)), - ); - self.store_killer(index + 1, mov); + self.store_killer(0, self.load_killer(index + 1)); + self.store_killer(index + 1, Some(mov)); } } @@ -297,120 +245,15 @@ impl SearchTable { .unwrap() .data .iter_mut() - .for_each(|entry| *entry = AtomicU64::new(NULL_TT_ENTRY)); - self.killer_moves.iter().for_each(|entry| entry.store(NULL_KILLER, Ordering::Relaxed)); + .for_each(|entry| *entry = SyncUnsafeCell::new(None)); + (0..self.killer_moves.len()).for_each(|index| self.store_killer(index, None)); } fn load_killer(&self, index: usize) -> Option { - let killer = self.killer_moves[index].load(Ordering::Relaxed); - if killer == NULL_KILLER { - None - } else { - Some(Move::new_raw(killer)) - } - } - - fn store_killer(&self, index: usize, mov: Move) { - self.killer_moves[index].store(mov.raw(), Ordering::Relaxed); + unsafe { *self.killer_moves[index].get() } } -} - -#[cfg(test)] -mod tests { - use std::sync::atomic::Ordering; - - use super::{SearchTable, TableEntry, TranspositionTable}; - use crate::{ - moves::Move, - position::{ - fen::{FromFen, START_FEN}, - square::Square, - Position, - }, - search::{ - table::{ScoreType, NULL_KILLER, NULL_TT_ENTRY}, - MAX_DEPTH, - }, - }; - - #[test] - fn entry_packing() { - let entry1 = TableEntry::new(100, ScoreType::Exact, Move::new_raw(0), MAX_DEPTH, 0, true); - let entry2 = TableEntry::new(100, ScoreType::Exact, Move::new_raw(0), 0, 1, false); - let entry3 = TableEntry::new(100, ScoreType::Exact, Move::new_raw(0), 1, 2, true); - let entry4 = TableEntry::new(100, ScoreType::Exact, Move::new_raw(0), 2, 3, false); - - assert_eq!(entry1.depth(), MAX_DEPTH); - assert_eq!(entry2.depth(), 0); - assert_eq!(entry3.depth(), 1); - assert_eq!(entry4.depth(), 2); - - assert_eq!(entry1.score, 100); - assert_eq!(entry1.score_type(), ScoreType::Exact); - - assert!(entry1.age()); - assert!(!entry2.age()); - assert!(entry3.age()); - assert!(!entry4.age()); - } - - #[test] - fn entry_transmutation() { - let entry1 = TableEntry::new(100, ScoreType::Exact, Move::new_raw(0), MAX_DEPTH, 0, true); - let entry2 = TableEntry::new(100, ScoreType::Exact, Move::new_raw(0), 0, 1, false); - - assert!(entry1.raw() != entry2.raw()); - - assert_eq!(TableEntry::from_raw(entry1.raw()), entry1); - assert_eq!(TableEntry::from_raw(entry2.raw()), entry2); - } - - #[test] - fn tt_raw_contents() { - let table = TranspositionTable::new(1); - let position = Position::from_fen(START_FEN).unwrap(); - - assert_eq!(table.data[0].load(Ordering::Relaxed), NULL_TT_ENTRY); - assert_eq!(table.get(&position), None); - - let first_move = Move::new(Square::E2, Square::E4, crate::moves::MoveFlag::DoublePawnPush); - let first_move_entry = - TableEntry::new(100, ScoreType::Exact, first_move, 2, position.zobrist_hash(), true); - - table.insert(&position, first_move_entry, false); - - assert_eq!( - table.data[position.zobrist_hash() as usize % table.data.len()].load(Ordering::Relaxed), - first_move_entry.raw() - ); - assert_eq!(table.get(&position).unwrap().best_move, first_move); - } - - #[test] - fn killers_raw_contents() { - let table = SearchTable::new(1); - - assert_eq!(table.killer_moves[0].load(Ordering::Relaxed), NULL_KILLER); - assert_eq!(table.killer_moves[1].load(Ordering::Relaxed), NULL_KILLER); - assert_eq!(table.get_killers(0), [None, None]); - - let first_move = Move::new(Square::E2, Square::E4, crate::moves::MoveFlag::DoublePawnPush); - let second_move = Move::new(Square::D2, Square::D4, crate::moves::MoveFlag::DoublePawnPush); - let third_move = Move::new(Square::C2, Square::C4, crate::moves::MoveFlag::DoublePawnPush); - - table.put_killer_move(0, first_move); - assert_eq!(table.killer_moves[0].load(Ordering::Relaxed), first_move.raw()); - assert_eq!(table.killer_moves[1].load(Ordering::Relaxed), NULL_KILLER); - assert_eq!(table.get_killers(0), [Some(first_move), None]); - - table.put_killer_move(0, second_move); - assert_eq!(table.killer_moves[0].load(Ordering::Relaxed), first_move.raw()); - assert_eq!(table.killer_moves[1].load(Ordering::Relaxed), second_move.raw()); - assert_eq!(table.get_killers(0), [Some(first_move), Some(second_move)]); - table.put_killer_move(0, third_move); - assert_eq!(table.killer_moves[0].load(Ordering::Relaxed), second_move.raw()); - assert_eq!(table.killer_moves[1].load(Ordering::Relaxed), third_move.raw()); - assert_eq!(table.get_killers(0), [Some(second_move), Some(third_move)]); + fn store_killer(&self, index: usize, mov: Option) { + unsafe { *self.killer_moves[index].get() = mov } } }