Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add age to transposition table replacement scheme #23

Merged
merged 9 commits into from
May 24, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -65,6 +65,7 @@ public SearchResult search(Duration duration) {
.toList();
SearchResult result = selectResult(threads).get();
threads.forEach(thread -> thread.cancel(true));
transpositionTable.incrementGeneration();
return result;
} catch (InterruptedException | ExecutionException e) {
throw new RuntimeException(e);
Expand Down
131 changes: 111 additions & 20 deletions src/main/java/com/kelseyde/calvin/transposition/HashEntry.java
Original file line number Diff line number Diff line change
@@ -1,48 +1,139 @@
package com.kelseyde.calvin.transposition;

import com.kelseyde.calvin.board.Move;
import lombok.AllArgsConstructor;

/**
* Individual entry in the transposition table containing the 64-bit zobrist key, and a 64-bit encoding of the score,
* move, flag and depth:
* - score: 32 bits (-1000000-1000000, capturing negative -> positive checkmate score)
* - move: 16 bits (0-5 = start square, 6-11 = end square, 12-15 = special move flag, see {@link Move})
* - flag: 4 bits (0-2, capturing three possible flag values + 1 bit padding)
* - depth: 12 bits (0-265, max depth = 256 = 8 bits + 4 bit padding)
* Entry in the {@link TranspositionTable}. Contains a 64-bit key and a 64-bit value which encodes the relevant
* information about the position.
* </p>
*
* Key encoding:
* 0-47: 48 bits representing three-quarters of the zobrist hash. Used to verify that the position truly matches.
* 48-63: 16 bits representing the generation of the entry, i.e. how old it is. Used to gradually replace old entries.
* </p>
*
* Value encoding:
* 0-11: the depth to which this position was last searched.
* 12-15: the {@link HashFlag} indicating what type of node this is.
* 16-31: the {@link Move} start square, end square, and special move flag.
* 32-63: the eval of the position in centipawns.
*/
public record HashEntry(long key, long value) {
@AllArgsConstructor
public class HashEntry {

private static final long CLEAR_SCORE_MASK = 0xffffffffL;
private static final long ZOBRIST_PART_MASK = 0x0000ffffffffffffL;
private static final long GENERATION_MASK = 0xffff000000000000L;
private static final long SCORE_MASK = 0xffffffff00000000L;
private static final long MOVE_MASK = 0x00000000ffff0000L;
private static final long FLAG_MASK = 0x000000000000f000L;
private static final long DEPTH_MASK = 0x0000000000000fffL;

private long key;
private long value;

/**
* Extracts the 48-bits representing the zobrist part of the given zobrist key.
*/
public static long zobristPart(long zobrist) {
return zobrist & ZOBRIST_PART_MASK;
}

/**
* Returns the 48-bits representing zobrist part of the hash entry key.
*/
public long getZobristPart() {
return key & ZOBRIST_PART_MASK;
}

/**
* Gets the generation part of this entry's key.
*/
public int getGeneration() {
return (int) ((key & GENERATION_MASK) >>> 48);
}

/**
* Sets the generation part of this entry's key.
*/
public void setGeneration(int generation) {
key = (key &~ GENERATION_MASK) | (long) generation << 48;
}

/**
* Gets the score from this entry's value.
*/
public int getScore() {
long score = value >>> 32;
long score = (value & SCORE_MASK) >>> 32;
return (int) score;
}

/**
* Sets the score in this entry's value.
*/
public void setScore(int score) {
value = (value &~ SCORE_MASK) | (long) score << 32;
}

/**
* Creates a new {@link HashEntry} with the adjusted score.
*/
public HashEntry withAdjustedScore(int score) {
long newValue = (value &~ SCORE_MASK) | (long) score << 32;
return new HashEntry(key, newValue);
}

/**
* Sets the move in this entry's value.
*/
public void setMove(Move move) {
value = (value &~ MOVE_MASK) | (long) move.value() << 16;
}

/**
* Gets the move from this entry's value.
*/
public Move getMove() {
long move = (value >> 16) & 0xffff;
long move = (value & MOVE_MASK) >>> 16;
return move > 0 ? new Move((short) move) : null;
}

/**
* Gets the flag from this entry's value.
*/
public HashFlag getFlag() {
long flag = (value >>> 12) & 0xf;
long flag = (value & FLAG_MASK) >>> 12;
return HashFlag.valueOf((int) flag);
}

/**
* Gets the depth from this entry's value.
*/
public int getDepth() {
return (int) value & 0xfff;
return (int) (value & DEPTH_MASK);
}

public static HashEntry of(long zobristKey, int score, Move move, HashFlag flag, int depth) {
/**
* Creates a new {@link HashEntry} with the specified parameters.
*
* @param zobristKey the Zobrist key
* @param score the score
* @param move the move
* @param flag the flag
* @param depth the depth
* @param generation the generation
* @return a new {@link HashEntry}
*/
public static HashEntry of(long zobristKey, int score, Move move, HashFlag flag, int depth, int generation) {
// Build the key using 48 bits for the zobrist part and 16 bits for the generation part.
long key = (zobristKey & ZOBRIST_PART_MASK) | (long) generation << 48;
// Get the 16-bit encoded move
long moveValue = move != null ? move.value() : 0;
// Get the 3-bit encoded flag
long flagValue = HashFlag.value(flag);
// Combine the score, move, flag and depth to create the hash entry value
long value = (long) score << 32 | moveValue << 16 | flagValue << 12 | depth;
return new HashEntry(zobristKey, value);
}

public static HashEntry withScore(HashEntry entry, int score) {
long value = (entry.value() & CLEAR_SCORE_MASK) | (long) score << 32;
return new HashEntry(entry.key(), value);
return new HashEntry(key, value);
}

}
}
Original file line number Diff line number Diff line change
Expand Up @@ -29,44 +29,99 @@ public class TranspositionTable {

int tries;
int hits;
int generation;

/**
* Constructs a transposition table of the given size in megabytes.
*/
public TranspositionTable(int tableSizeMb) {
this.tableSize = (tableSizeMb * 1024 * 1024) / ENTRY_SIZE_BYTES;
entries = new HashEntry[tableSize];
tries = 0;
hits = 0;
generation = 0;
}

/**
* Retrieves an entry from the transposition table using the given zobrist key.
*
* @param zobristKey the zobrist key of the position.
* @param ply the current ply in the search (used to adjust mate scores).
*/
public HashEntry get(long zobristKey, int ply) {
int index = getIndex(zobristKey);
long zobristPart = HashEntry.zobristPart(zobristKey);
tries++;
for (int i = 0; i < 4; i++) {
HashEntry entry = entries[index + i];
if (entry != null && entry.key() == zobristKey) {
if (entry != null && entry.getZobristPart() == zobristPart) {
hits++;
entry.setGeneration(generation);
if (isMateScore(entry.getScore())) {
int score = retrieveMateScore(entry.getScore(), ply);
entry = HashEntry.withScore(entry, score);
return entry.withAdjustedScore(score);
}
return entry;
}
}
return null;
}

/**
* Puts an entry into the transposition table.
* </p>
* The transposition table is separated into buckets of 4 entries each. This method uses a replacement scheme that
* prefers to replace the least-valuable entry among the 4 candidates in the bucket. The order of preference
* for replacement is:
* <ol>
* <li>An empty entry.</li>
* <li>An entry with the same zobrist key and a depth less than or equal to the new entry.</li>
* <li>The oldest entry in the bucket, stored further back in the game and so less likely to be relevant.</li>
* <li>The entry with the lowest depth.</li>
* </ol>
*
* @param zobristKey the zobrist key of the position.
* @param flag the flag indicating the type of node (e.g., exact, upper bound, lower bound).
* @param depth the search depth of the entry.
* @param ply the current ply from root in the search.
* @param move the best move found at this position.
* @param score the score of the position.
*/
public void put(long zobristKey, HashFlag flag, int depth, int ply, Move move, int score) {

// Get the start index of the 4-item bucket.
int startIndex = getIndex(zobristKey);

// Get the 48 bits of the zobrist used to verify the signature of the bucket entry
long zobristPart = HashEntry.zobristPart(zobristKey);

// If the eval is checkmate, adjust the score to reflect the number of ply from the root position
if (isMateScore(score)) score = calculateMateScore(score, ply);
HashEntry newEntry = HashEntry.of(zobristKey, score, move, flag, depth);

// Construct the new entry to store in the hash table.
HashEntry newEntry = HashEntry.of(zobristKey, score, move, flag, depth, generation);

int replacedIndex = -1;
int minDepth = Integer.MAX_VALUE;
boolean replacedByAge = false;

// Iterate over the four items in the bucket
for (int i = startIndex; i < startIndex + 4; i++) {
HashEntry storedEntry = entries[i];

if (storedEntry == null || storedEntry.key() == zobristKey) {
if (storedEntry == null || depth >= storedEntry.getDepth()) {
if (newEntry.getMove() == null && storedEntry != null && storedEntry.getMove() != null) {
newEntry = HashEntry.of(newEntry.key(), newEntry.getScore(), storedEntry.getMove(), newEntry.getFlag(), newEntry.getDepth());
// First, always prefer an empty slot if it is available.
if (storedEntry == null) {
replacedIndex = i;
break;
}

// Then, if the stored entry matches the zobrist key and the depth is >= the stored depth, replace it.
// If the depth is < the store depth, don't replace it and exit (although this should never happen).
if (storedEntry.getZobristPart() == zobristPart) {
if (depth >= storedEntry.getDepth()) {
// If the stored entry has a recorded best move but the new entry does not, use the stored one.
if (newEntry.getMove() == null && storedEntry.getMove() != null) {
newEntry.setMove(storedEntry.getMove());
}
replacedIndex = i;
break;
Expand All @@ -75,45 +130,99 @@ public void put(long zobristKey, HashFlag flag, int depth, int ply, Move move, i
}
}

if (storedEntry.getDepth() < minDepth) {
// Next, prefer to replace entries from earlier on in the game, since they are now less likely to be relevant.
if (newEntry.getGeneration() > storedEntry.getGeneration()) {
replacedByAge = true;
replacedIndex = i;
}

// Finally, just replace the entry with the shallowest search depth.
if (!replacedByAge && storedEntry.getDepth() < minDepth) {
minDepth = storedEntry.getDepth();
replacedIndex = i;
}

}

// Store the new entry in the table at the chosen index.
if (replacedIndex != -1) {
entries[replacedIndex] = newEntry;
}
}

/**
* Increments the generation counter for the transposition table.
*/
public void incrementGeneration() {
generation++;
}

/**
* Clears the transposition table, resetting all entries and statistics.
*/
public void clear() {
printStatistics();
tries = 0;
hits = 0;
generation = 0;
entries = new HashEntry[tableSize];
}

/**
* Compresses the 64-bit zobrist key into a 32-bit key, to be used as an index in the hash table.
*
* @param zobristKey the zobrist key of the position.
* @return a compressed 32-bit index.
*/
private int getIndex(long zobristKey) {
// XOR the upper and lower halves of the zobrist key together, producing a pseudo-random 32-bit result.
// Then apply a mask ensuring the number is always positive, since it is to be used as an array index.
long index = (zobristKey ^ (zobristKey >>> 32)) & 0x7FFFFFFF;
// Modulo the result with the number of entries in the table, and align it with a multiple of 4,
// ensuring the entries are always divided into 4-sized buckets.
return (int) (index % (tableSize - 3)) & ~3;
}

/**
* Checks if the given score is a mate score.
*
* @param score the score to check.
* @return {@code true} if the score is a mate score, {@code false} otherwise.
*/
private boolean isMateScore(int score) {
return Math.abs(score) >= CHECKMATE_BOUND;
}

/**
* Calculates the mate score, adjusting it based on the ply from the root.
*
* @param score the score to adjust.
* @param plyFromRoot the ply from the root.
* @return the adjusted mate score.
*/
private int calculateMateScore(int score, int plyFromRoot) {
return score > 0 ? score - plyFromRoot : score + plyFromRoot;
}

/**
* Retrieves the mate score, adjusting it based on the ply from the root.
*
* @param score the score to adjust.
* @param plyFromRoot the ply from the root.
* @return the adjusted mate score.
*/
private int retrieveMateScore(int score, int plyFromRoot) {
return score > 0 ? score + plyFromRoot : score - plyFromRoot;
}

/**
* Prints the statistics of the transposition table.
*/
public void printStatistics() {
long fill = Arrays.stream(entries).filter(Objects::nonNull).count();
float fillPercentage = ((float) fill / tableSize) * 100;
float hitPercentage = ((float) hits / tries) * 100;
//System.out.printf("TT %s -- size: %s / %s (%s), tries: %s, hits: %s (%s)%n", this.hashCode(), fill, entries.length, fillPercentage, tries, hits, hitPercentage);
// System.out.printf("TT %s -- size: %s / %s (%s), tries: %s, hits: %s (%s)%n", this.hashCode(), fill, entries.length, fillPercentage, tries, hits, hitPercentage);
}

}
Loading