Skip to content

Commit

Permalink
search: Do not use singleton
Browse files Browse the repository at this point in the history
  • Loading branch information
calcitem committed Jan 12, 2025
1 parent a054f8d commit 67ea4d4
Show file tree
Hide file tree
Showing 17 changed files with 129 additions and 138 deletions.
19 changes: 10 additions & 9 deletions src/engine_commands.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
using std::string;

extern ThreadPool Threads;

namespace EngineCommands {

// FEN string of the initial position, normal mill game
Expand Down Expand Up @@ -78,26 +79,26 @@ void init_start_fen()
// go() is called when engine receives the "go" UCI command. The function sets
// the thinking time and other parameters from the input string, then starts
// the search.
void go(Position *pos)
void go(SearchEngine &searchEngine, Position *pos)
{
#ifdef UCI_AUTO_RE_GO
begin:
#endif

uint64_t localId = SearchEngine::getInstance().beginNewSearch(pos);
uint64_t localId = searchEngine.beginNewSearch(pos);

Threads.submit([]() { SearchEngine::getInstance().runSearch(); });
Threads.submit([&searchEngine]() { searchEngine.runSearch(); });

const auto limit_ms = gameOptions.getMoveTime() * 1000;

if (limit_ms > 0) {
std::thread([limit_ms, localId]() {
std::thread([&searchEngine, limit_ms, localId]() {
std::this_thread::sleep_for(std::chrono::milliseconds(limit_ms));

if (SearchEngine::getInstance().currentSearchId.load(
std::memory_order_relaxed) == localId) {
SearchEngine::getInstance().searchAborted.store(
true, std::memory_order_relaxed);
if (searchEngine.currentSearchId.load(std::memory_order_relaxed) ==
localId) {
searchEngine.searchAborted.store(true,
std::memory_order_relaxed);
}
}).detach();
}
Expand All @@ -108,7 +109,7 @@ void go(Position *pos)
Threads.stop_all();

Threads.set(1);
go(pos);
go(searchEngine, pos);
#else
return;
#endif
Expand Down
3 changes: 2 additions & 1 deletion src/engine_commands.h
Original file line number Diff line number Diff line change
Expand Up @@ -7,11 +7,12 @@
#include <sstream>

class Position;
class SearchEngine;

namespace EngineCommands {

/// Handles the "go" UCI command to start the search.
void go(Position *pos);
void go(SearchEngine &searchEngine, Position *pos);

/// Handles the "position" UCI command to set up the board position.
void position(Position *pos, std::istringstream &is);
Expand Down
23 changes: 6 additions & 17 deletions src/engine_controller.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -10,12 +10,11 @@
#include "thread.h"
#include "search.h"
#include "misc.h"
#include "engine_commands.h" // Include for EngineCommands
#include "engine_commands.h"
#include "search_engine.h"

// Initialize the singleton instance
EngineController EngineController::instance;

EngineController::EngineController()
EngineController::EngineController(SearchEngine &searchEngine)
: searchEngine_(searchEngine)
{
// Constructor
}
Expand All @@ -25,11 +24,6 @@ EngineController::~EngineController()
// Destructor
}

EngineController &EngineController::getInstance()
{
return instance;
}

void EngineController::handleCommand(const std::string &cmd, Position *pos)
{
std::istringstream is(cmd);
Expand All @@ -38,10 +32,9 @@ void EngineController::handleCommand(const std::string &cmd, Position *pos)

if (token == "go") {
searchPos = *pos;
EngineCommands::go(&searchPos); // Call the EngineCommands::go function
EngineCommands::go(searchEngine_, &searchPos);
} else if (token == "position") {
EngineCommands::position(pos, is); // Call the EngineCommands::position
// function
EngineCommands::position(pos, is);
} else if (token == "ucinewgame") {
Search::clear(); // Clear the search state for a new game
// Additional custom non-UCI commands, mainly for debugging.
Expand All @@ -54,10 +47,6 @@ void EngineController::handleCommand(const std::string &cmd, Position *pos)
sync_cout << compiler_info() << sync_endl;
} else {
// Handle additional custom commands if necessary
// For example:
// else if (token == "customcommand") { ... }
// Currently, unknown commands are handled in UCI::loop, so you might
// not need to do anything here.
sync_cout << "Unknown command in EngineController: " << cmd
<< sync_endl;
}
Expand Down
12 changes: 5 additions & 7 deletions src/engine_controller.h
Original file line number Diff line number Diff line change
Expand Up @@ -6,29 +6,27 @@
#include <string>
#include "position.h"

class SearchEngine;

/// EngineController is responsible for handling commands from UciServer (or
/// UCI::loop).
class EngineController
{
public:
EngineController();
EngineController(SearchEngine &searchEngine);
~EngineController();

/// The main entry to handle a command.
/// We pass in the raw command string and a Position pointer
/// so we can call existing logic (like go(pos), position(pos, is)).
void handleCommand(const std::string &cmd, Position *pos);

/// Returns the singleton instance of EngineController
static EngineController &getInstance();

private:
// Singleton instance
static EngineController instance;

// Internal position
Position searchPos;

SearchEngine &searchEngine_;

// If needed, we could store references to Options, or keep an internal
// Position.
};
Expand Down
9 changes: 6 additions & 3 deletions src/mcts.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -15,11 +15,14 @@
#include "option.h"
#include "position.h"
#include "search.h"
#include "search_engine.h"
#include "types.h"
#include "uci.h"

using namespace std;

static SearchEngine searchEngine;

class ThreadSafeNodeVisits
{
public:
Expand Down Expand Up @@ -191,9 +194,9 @@ bool simulate(Node *node, Sanmill::Stack<Position> &ss)

Move bestMove {MOVE_NONE};

Value value = Search::search(pos, ss, ALPHA_BETA_DEPTH, ALPHA_BETA_DEPTH,
-VALUE_INFINITE, VALUE_INFINITE, bestMove);

Value value = Search::search(searchEngine, pos, ss, ALPHA_BETA_DEPTH,
ALPHA_BETA_DEPTH, -VALUE_INFINITE,
VALUE_INFINITE, bestMove);
return value > 0;
}

Expand Down
3 changes: 0 additions & 3 deletions src/position.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -678,9 +678,6 @@ bool Position::reset()

record[0] = '\0';

SearchEngine::getInstance().searchAborted.store(false,
std::memory_order_relaxed);

return true;
}

Expand Down
71 changes: 37 additions & 34 deletions src/search.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,8 @@ void Search::clear()
vector<Key> posKeyHistory;

// Quiescence Search
Value Search::qsearch(Position *pos, Sanmill::Stack<Position> &ss, Depth depth,
Value Search::qsearch(SearchEngine &searchEngine, Position *pos,
Sanmill::Stack<Position> &ss, Depth depth,
Depth originDepth, Value alpha, Value beta,
Move &bestMove)
{
Expand Down Expand Up @@ -91,10 +92,10 @@ Value Search::qsearch(Position *pos, Sanmill::Stack<Position> &ss, Depth depth,

// Recursively call qsearch
Value value = (after != before) ?
-qsearch(pos, ss, depth - 1, originDepth, -beta,
-alpha, bestMove) :
qsearch(pos, ss, depth - 1, originDepth, alpha, beta,
bestMove);
-qsearch(searchEngine, pos, ss, depth - 1,
originDepth, -beta, -alpha, bestMove) :
qsearch(searchEngine, pos, ss, depth - 1, originDepth,
alpha, beta, bestMove);

// Undo the move
pos->undo_move(ss);
Expand All @@ -112,8 +113,7 @@ Value Search::qsearch(Position *pos, Sanmill::Stack<Position> &ss, Depth depth,
}
}

if (SearchEngine().getInstance().searchAborted.load(
std::memory_order_relaxed)) {
if (searchEngine.searchAborted.load(std::memory_order_relaxed)) {
return alpha;
}
}
Expand All @@ -123,15 +123,15 @@ Value Search::qsearch(Position *pos, Sanmill::Stack<Position> &ss, Depth depth,
}

/// Search function that performs recursive search with alpha-beta pruning
Value Search::search(Position *pos, Sanmill::Stack<Position> &ss, Depth depth,
Value Search::search(SearchEngine &searchEngine, Position *pos,
Sanmill::Stack<Position> &ss, Depth depth,
Depth originDepth, Value alpha, Value beta, Move &bestMove)
{
Value bestValue = -VALUE_INFINITE;

// Check for terminal position or search abortion
if (unlikely(pos->phase == Phase::gameOver) ||
SearchEngine().getInstance().searchAborted.load(
std::memory_order_relaxed)) {
searchEngine.searchAborted.load(std::memory_order_relaxed)) {
bestValue = Eval::evaluate(*pos);

// Adjust evaluation to prefer quicker wins or slower losses
Expand All @@ -146,7 +146,8 @@ Value Search::search(Position *pos, Sanmill::Stack<Position> &ss, Depth depth,

if (depth <= 0) {
// Call quiescence search when depth limit is reached
return qsearch(pos, ss, depth, originDepth, alpha, beta, bestMove);
return qsearch(searchEngine, pos, ss, depth, originDepth, alpha, beta,
bestMove);
}

#ifdef RULE_50
Expand Down Expand Up @@ -292,10 +293,10 @@ Value Search::search(Position *pos, Sanmill::Stack<Position> &ss, Depth depth,

// Perform recursive search
value = (after != before) ?
-search(pos, ss, depth - 1 + epsilon, originDepth, -beta,
-alpha, bestMove) :
search(pos, ss, depth - 1 + epsilon, originDepth, alpha,
beta, bestMove);
-search(searchEngine, pos, ss, depth - 1 + epsilon,
originDepth, -beta, -alpha, bestMove) :
search(searchEngine, pos, ss, depth - 1 + epsilon,
originDepth, alpha, beta, bestMove);

// Undo the move
pos->undo_move(ss);
Expand All @@ -320,8 +321,7 @@ Value Search::search(Position *pos, Sanmill::Stack<Position> &ss, Depth depth,
}

// Check for search abortion
if (SearchEngine().getInstance().searchAborted.load(
std::memory_order_relaxed)) {
if (searchEngine.searchAborted.load(std::memory_order_relaxed)) {
return bestValue;
}
}
Expand Down Expand Up @@ -350,9 +350,9 @@ Value Search::search(Position *pos, Sanmill::Stack<Position> &ss, Depth depth,
}

/// MTDF function implementing the MTD(f) search algorithm
Value Search::MTDF(Position *pos, Sanmill::Stack<Position> &ss,
Value firstguess, Depth depth, Depth originDepth,
Move &bestMove)
Value Search::MTDF(SearchEngine &searchEngine, Position *pos,
Sanmill::Stack<Position> &ss, Value firstguess, Depth depth,
Depth originDepth, Move &bestMove)
{
Value g = firstguess;
Value lowerbound = -VALUE_INFINITE;
Expand All @@ -366,8 +366,8 @@ Value Search::MTDF(Position *pos, Sanmill::Stack<Position> &ss,
beta = g;
}

g = search(pos, ss, depth, originDepth, beta - VALUE_MTDF_WINDOW, beta,
bestMove);
g = search(searchEngine, pos, ss, depth, originDepth,
beta - VALUE_MTDF_WINDOW, beta, bestMove);

if (g < beta) {
upperbound = g; // Fail low
Expand All @@ -380,32 +380,35 @@ Value Search::MTDF(Position *pos, Sanmill::Stack<Position> &ss,
}

/// Function that performs Principal Variation Search (PVS)
Value Search::pvs(Position *pos, Sanmill::Stack<Position> &ss, Depth depth,
Depth originDepth, Value alpha, Value beta, Move &bestMove,
int i, const Color before, const Color after)
Value Search::pvs(SearchEngine &searchEngine, Position *pos,
Sanmill::Stack<Position> &ss, Depth depth, Depth originDepth,
Value alpha, Value beta, Move &bestMove, int i,
const Color before, const Color after)
{
Value value;

if (i == 0) {
// First move: full window search
value = (after != before) ?
-search(pos, ss, depth, originDepth, -beta, -alpha,
bestMove) :
search(pos, ss, depth, originDepth, alpha, beta, bestMove);
-search(searchEngine, pos, ss, depth, originDepth, -beta,
-alpha, bestMove) :
search(searchEngine, pos, ss, depth, originDepth, alpha,
beta, bestMove);
} else {
// Subsequent moves: null window search (PVS)
value = (after != before) ?
-search(pos, ss, depth, originDepth,
-search(searchEngine, pos, ss, depth, originDepth,
-alpha - VALUE_PVS_WINDOW, -alpha, bestMove) :
search(pos, ss, depth, originDepth, alpha,
search(searchEngine, pos, ss, depth, originDepth, alpha,
alpha + VALUE_PVS_WINDOW, bestMove);

// Re-search if the value is within the search window
if (value > alpha && value < beta) {
value = (after != before) ? -search(pos, ss, depth, originDepth,
-beta, -alpha, bestMove) :
search(pos, ss, depth, originDepth,
alpha, beta, bestMove);
value = (after != before) ?
-search(searchEngine, pos, ss, depth, originDepth,
-beta, -alpha, bestMove) :
search(searchEngine, pos, ss, depth, originDepth, alpha,
beta, bestMove);
}
}

Expand Down
24 changes: 15 additions & 9 deletions src/search.h
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,8 @@
#include "stopwatch.h"
#endif

class SearchEngine;

using std::vector;

namespace Search {
Expand All @@ -23,21 +25,25 @@ void init() noexcept;
void clear();

// Search algorithms
Value MTDF(Position *pos, Sanmill::Stack<Position> &ss, Value firstguess,
Depth depth, Depth originDepth, Move &bestMove);
Value MTDF(SearchEngine &searchEngine, Position *pos,
Sanmill::Stack<Position> &ss, Value firstguess, Depth depth,
Depth originDepth, Move &bestMove);

Value pvs(Position *pos, Sanmill::Stack<Position> &ss, Depth depth,
Depth originDepth, Value alpha, Value beta, Move &bestMove, int i,
const Color before, const Color after);
Value pvs(SearchEngine &searchEngine, Position *pos,
Sanmill::Stack<Position> &ss, Depth depth, Depth originDepth,
Value alpha, Value beta, Move &bestMove, int i, const Color before,
const Color after);

Value search(Position *pos, Sanmill::Stack<Position> &ss, Depth depth,
Depth originDepth, Value alpha, Value beta, Move &bestMove);
Value search(SearchEngine &searchEngine, Position *pos,
Sanmill::Stack<Position> &ss, Depth depth, Depth originDepth,
Value alpha, Value beta, Move &bestMove);

Value random_search(Position *pos, Move &bestMove);

// Quiescence Search
Value qsearch(Position *pos, Sanmill::Stack<Position> &ss, Depth depth,
Depth originDepth, Value alpha, Value beta, Move &bestMove);
Value qsearch(SearchEngine &searchEngine, Position *pos,
Sanmill::Stack<Position> &ss, Depth depth, Depth originDepth,
Value alpha, Value beta, Move &bestMove);

} // namespace Search

Expand Down
Loading

0 comments on commit 67ea4d4

Please sign in to comment.