Skip to content

Commit

Permalink
- PolicyBook V1.02対応
Browse files Browse the repository at this point in the history
  • Loading branch information
yaneurao committed Dec 21, 2024
1 parent 9c18e28 commit 18441f8
Show file tree
Hide file tree
Showing 9 changed files with 117 additions and 34 deletions.
68 changes: 55 additions & 13 deletions source/book/policybook.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
#include "../thread.h"
#include "../usi.h"
#include "../book/book.h"
#include "../eval/deep/nn.h"

// freqの和がUINT16_MAXに収まるようにする。
u16 MoveFreq32Record::overflow_check()
Expand Down Expand Up @@ -37,6 +38,7 @@ void PolicyBookEntry::from_move_freq32rec(const MoveFreq32Record& mf32r)
move_freq[i].move16 = mf32r.move_freq32[i].move16 ;
move_freq[i].freq = u16(mf32r.move_freq32[i].freq );
}
value = mf32r.value;
}


Expand All @@ -59,7 +61,7 @@ Tools::Result PolicyBook::read_book_db(std::string path)
Position pos;
StateInfo si;
reader.ReadLine(sfen);
if (sfen != "#YANEURAOU-POLICY-DB2024 1.00")
if (sfen != POLICY_BOOK_HEADER)
{
sync_cout << "info string Error! invalid policy book header" << sync_endl;
return Tools::ResultCode::FileMismatch;
Expand All @@ -73,24 +75,46 @@ Tools::Result PolicyBook::read_book_db(std::string path)
reader.ReadLine(moves_str);
auto moves = StringExtension::split(moves_str);

// 末尾判定が面倒なので、delimiterをセットしておく。
moves.emplace_back("");
moves.emplace_back("");

// 7g7f 123 のように指し手と出現頻度が書いてあるので正規化する。
// 7g7f 123 value 0.50
// のように、その局面での手番側から見たvalueを付与することができる。
// 7g7f 123 eval 400
// のようにevalを付与することもできる。この場合、eval_coef=800として勝率に変換される。
// いまどきのソフトは800より小さいのだが、ここをあえて大きな値にすることで、valueとして
// (0.5からの乖離が)大きめの値にする。

// 出現頻度を保管しておく。
MoveFreq32Record mf32r;

for (size_t i = 0 , j = 0 ; i < POLICY_BOOK_NUM; ++i)
for (size_t i = 0 , j = 0 ; ; i += 2)
{
if (i * 2 + 1 < moves.size())
{
Move16 move16 = USI::to_move16(moves[i * 2 + 0]);
auto& movestr = moves[i];
if (movestr == "")
break;
if (movestr == "value")
mf32r.value = StringExtension::to_float(moves[i + 1], FLT_MAX);
else if (movestr == "eval")
mf32r.value = Eval::dlshogi::cp_to_value(StringExtension::to_int(moves[i + 1], 0), 800);
else {
// もうお腹いっぱい。ここPOLICY_BOOK_NUM以上
// 書いてもいいことにはなっているのでこの判定が必要。
// まだこのあとvalueかevalが来ることはあるのでループは続行する。
if (j >= POLICY_BOOK_NUM)
continue;

Move16 move16 = USI::to_move16(movestr);

// 不成の指し手があるなら、それを無視する。
// (これを計算に入れてしまうと、Policyの合計が100%にならなくなる)
if (!pos.pseudo_legal_s<false>(pos.to_move(move16)))
continue;

mf32r.move_freq32[j].move16 = move16;
mf32r.move_freq32[j].freq = StringExtension::to_int(moves[i * 2 + 1], 1);
mf32r.move_freq32[j].freq = StringExtension::to_int(moves[i + 1], 1);
j++;
}
}
Expand All @@ -113,7 +137,10 @@ Tools::Result PolicyBook::read_book_db(std::string path)
sync_cout << "info string read " << counter << "..done." << sync_endl;

// sortしないと二分探索できない。
sort_book();
// sort_book();

// 重複レコードがあるかも知れないのでgarbageしておく。
garbage_book();

/*
// テストコード
Expand Down Expand Up @@ -238,6 +265,16 @@ void PolicyBook::merge_book(const PolicyBook& book)
// 連結
book_body.insert(book_body.end(), book.book_body.begin(), book.book_body.end());

// お掃除
garbage_book();

// 最終的なレコード数を出力する。
sync_cout << "..done. " << book_body.size() << " records." << sync_endl;
}

// PolicyBookの重複レコードなどを掃除する。
void PolicyBook::garbage_book()
{
// sort
std::sort(book_body.begin(), book_body.end(), [](PolicyBookEntry& x, PolicyBookEntry& y)
{ return x.key < y.key; });
Expand All @@ -256,11 +293,17 @@ void PolicyBook::merge_book(const PolicyBook& book)
HASH_KEY key = book_body[read_cursor].key;
// ⇨ このkeyと一致するところを集計して一つにまとめる。

for ( ; read_cursor < book_body.size() && key == book_body[read_cursor].key; ++read_cursor)
// これも集計しないといけない。
float value = FLT_MAX;

for (; read_cursor < book_body.size() && key == book_body[read_cursor].key; ++read_cursor)
{
auto& mf = book_body[read_cursor].move_freq;
for (int k = 0; k < POLICY_BOOK_NUM && mf[k].move16 != Move16::none(); ++k)
counter[mf[k].move16] += mf[k].freq;
float v = book_body[read_cursor].value;
if (v != FLT_MAX)
value = v;
}
// summarize終わったので、book_body[i]に反映させる。

Expand All @@ -275,6 +318,7 @@ void PolicyBook::merge_book(const PolicyBook& book)

// book_body[write_cursor]に反映。(POLICY_BOOK_NUM個、entryを埋めるのを忘れずに)
book_body[write_cursor].key = key;
book_body[write_cursor].value = value;

MoveFreq32Record mf32r;
for (size_t k = 0; k < POLICY_BOOK_NUM; ++k)
Expand All @@ -284,7 +328,7 @@ void PolicyBook::merge_book(const PolicyBook& book)
mf32r.overflow_check();

book_body[write_cursor].from_move_freq32rec(mf32r);

// read_cursorは、keyが一致しないところ(今回集計していないところ)まで進んだ。
}
else {
Expand All @@ -298,9 +342,6 @@ void PolicyBook::merge_book(const PolicyBook& book)
}
// write_cursor - 1 までが有効なデータなので切り詰める。
book_body.resize(write_cursor);

// 最終的なレコード数を出力する。
sync_cout << "..done. " << book_body.size() << " records." << sync_endl;
}


Expand Down Expand Up @@ -349,7 +390,8 @@ void PolicyBook::append_sfen_to_db_bin(const std::string& sfen)

BookTools::feed_position_string(pos, sfen, si, [&](Position& p, Move m) {
// 最後の局面は、m==Move::none()が入ってくる。
if (m == Move::none())
// また、不成の指し手は無視しないと対局時のPolicyの確率の合計が100%にならなくなる。
if (m == Move::none() || !pos.pseudo_legal_s<false>(m))
return;
PolicyBookEntry entry;
entry.key = p.hash_key();
Expand Down
16 changes: 14 additions & 2 deletions source/book/policybook.h
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,9 @@ static_assert(HASH_KEY_BITS == 128 , "HASH_KEY_BITS must be 128");
#define POLICY_BOOK_DB_BIN_NAME "eval/policy_book.db.bin"
#define POLICY_BOOK_LEARN_DB_BIN_NAME "eval/policy_book-learn.db.bin"

// POLICY_BOOK_DB_NAME のDBファイルの先頭行。
#define POLICY_BOOK_HEADER "#YANEURAOU-POLICY-DB2024 1.01"

// ============================================================
// Policy Book
// ============================================================
Expand Down Expand Up @@ -43,13 +46,16 @@ struct MoveFreq32
// → これ用意すると、overflowしている値が代入されてしまう。
};

constexpr int POLICY_BOOK_NUM = 4;
constexpr int POLICY_BOOK_NUM = 3;

// MoveFreq32がPOLICY_BOOK_NUMだけある構造体(内部で集計にのみ使用する)
struct alignas(16) MoveFreq32Record
{
MoveFreq32 move_freq32[POLICY_BOOK_NUM];

// この局面でのvalue。FLT_MAXなら、不明。
float value = FLT_MAX;

// freqの和がUINT16_MAXに収まるようにする。
// 返し値 : freqの和
u16 overflow_check();
Expand All @@ -61,6 +67,9 @@ struct alignas(32) PolicyBookEntry
HASH_KEY key;
MoveFreq move_freq[POLICY_BOOK_NUM];

// この局面でのvalue。FLT_MAXなら、不明。
float value = FLT_MAX;

// MoveFreq32Record構造体の値をこの構造体のmove_freqに代入する。
void from_move_freq32rec(const MoveFreq32Record& mf32r);
};
Expand All @@ -83,7 +92,10 @@ class PolicyBook
Tools::Result write_book_db_bin(std::string path = POLICY_BOOK_DB_BIN_NAME);

// PolicyBook同士のmerge
void PolicyBook::merge_book(const PolicyBook& book);
void merge_book(const PolicyBook& book);

// PolicyBookの重複レコードなどを掃除する。
void garbage_book();

// "position "コマンドのposition以降の文字列を渡して、それを
// POLICY_BOOK_LEARN_DB_BIN_NAMEにappendで書き出す。
Expand Down
17 changes: 15 additions & 2 deletions source/config.h
Original file line number Diff line number Diff line change
Expand Up @@ -398,6 +398,14 @@
// ※ 次の子nodeに行くときに必ずevaluate()を呼び出さないといけないタイプの評価関数。
// #define USE_DIFF_EVAL

// PolicyBookを使うのか?
// TODO : ⇨ PolicyBookについて、記事を書く。
// #define USE_POLICY_BOOK

// PolicyBookの局後学習を有効化するのか?
// TODO : ⇨ PolicyBookの局後学習について、記事を書く。
// #define ENABLE_POLICY_BOOK_LEARN

// ===============================================================
// ここ以降では、↑↑↑で設定した内容に基づき必要なdefineを行う。
// ===============================================================
Expand Down Expand Up @@ -706,6 +714,12 @@ constexpr bool pretty_jp = true;
constexpr bool pretty_jp = false;
#endif

// --- PolicyBook

// PolicyBookを使うときは、hash keyを128bitにする。局面のhash keyが衝突してしまうとまずいので…。
#if defined(USE_POLICY_BOOK)
#define HASH_KEY_BITS 128
#endif

// --- hash key bits and TT_CLUSTER_SIZE

Expand Down Expand Up @@ -891,5 +905,4 @@ constexpr bool pretty_jp = false;
#define ADD_BOARD_EFFECT_REWIND(color_,sq_,e1_) { board_effect[color_].e[sq_] += (uint8_t)e1_; }
#define ADD_BOARD_EFFECT_BOTH_REWIND(color_,sq_,e1_,e2_) { board_effect[color_].e[sq_] += (uint8_t)e1_; board_effect[~color_].e[sq_] += (uint8_t)e2_; }

#endif // ifndef _CONFIG_H_INCLUDED

#endif // if !defined(CONFIG_H_INCLUDED)
5 changes: 5 additions & 0 deletions source/engine/dlshogi-engine/Node.h
Original file line number Diff line number Diff line change
Expand Up @@ -170,6 +170,11 @@ namespace dlshogi
// 展開した子ノード以外はnullptrのまま。
std::unique_ptr<std::unique_ptr<Node>[]> child_nodes;

#if defined(USE_POLICY_BOOK)
// PolicyBookから与えられたvalue
// なければ FLT_MAX
float policy_book_value = FLT_MAX;
#endif

// 詰み関連のフラグ
bool dfpn_checked; // df-pn調べ済み
Expand Down
14 changes: 13 additions & 1 deletion source/engine/dlshogi-engine/UctSearch.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -524,6 +524,13 @@ namespace dlshogi
//
float UctSearcher::UctSearch(Position* pos, ChildNode* parent , Node* current, NodeVisitor& visitor)
{
#if defined(USE_POLICY_BOOK)
// PolicyBookからvalueを与えられていたら、それをそのまま返す。
// ただし、root局面では、いまから探索しないといけないので、それはしない。
if (parent!=nullptr && current->policy_book_value != FLT_MAX)
return 1.0f - current->policy_book_value;
#endif

auto ds = grp->get_dlsearcher();
auto& options = ds->search_options;

Expand Down Expand Up @@ -983,6 +990,11 @@ namespace dlshogi
} else {
// Policy Bookに従う。

// 評価値の書かれている局面であるか?
float v = policy_book_entry->value;
if (v != FLT_MAX)
node->policy_book_value = v;

u32 total = 0;
size_t k1;
for (k1 = 0; k1 < POLICY_BOOK_NUM; ++k1)
Expand All @@ -1006,7 +1018,7 @@ namespace dlshogi
// 注意 : 電竜戦の開始4手の玉の屈伸の棋譜を利用したときに、あれをPolicyとされてしまうと困る。
// (PolicyBookを作るときに除外する必要がある)

float book_policy_ratio = 0.7f + 0.1f * std::clamp(0.0f, log10f(float(total)), 3.0f);
float book_policy_ratio = 0.7f + 0.1f * std::clamp(log10f(float(total)), 0.0f, 3.0f);

for (ChildNumType j = 0; j < child_num; j++) {

Expand Down
6 changes: 3 additions & 3 deletions source/engine/dlshogi-engine/UctSearch.h
Original file line number Diff line number Diff line change
Expand Up @@ -150,11 +150,11 @@ namespace dlshogi
// EvalNode()ごとにどのNodeとColorから呼び出されたのかを記録しておく構造体。
// NNから返し値がもらえた時に、ここに記録されているNodeについて、その情報を更新する。
struct BatchElement {
Node* node; // どのNodeに対するEvalNode()なのか。
Color color; // その時の手番
Node* node; // どのNodeに対するEvalNode()なのか。
Color color; // その時の手番

#if defined(USE_POLICY_BOOK)
HASH_KEY key; // この局面のhash key
HASH_KEY key; // この局面のhash key
#endif

// 通常の探索では、このポインターはNodeVisitor::value_win を指している。
Expand Down
13 changes: 0 additions & 13 deletions source/engine/dlshogi-engine/dlshogi_min.h
Original file line number Diff line number Diff line change
Expand Up @@ -59,19 +59,6 @@ namespace dlshogi {
u64 nodes_visited();
}

namespace Eval::dlshogi {

// 価値(勝率)を評価値[cp]に変換。
// USIではcp(centi-pawn)でやりとりするので、そのための変換に必要。
// eval_coef : 勝率を評価値に変換する時の定数。default = 756
//
// 返し値 :
// +29900は、評価値の最大値
// -29900は、評価値の最小値
// +30000,-30000は、(おそらく)詰みのスコア
Value value_to_cp(const float score, float eval_coef);
}

#endif // defined(YANEURAOU_ENGINE_DEEP)

#endif // ifndef __DLSHOGI_MIN_H_INCLUDED__
9 changes: 9 additions & 0 deletions source/misc.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1995,6 +1995,15 @@ namespace StringExtension
return result;
}

// 文字列をfloat化する。float化に失敗した場合はdefault_の値を返す。
float to_float(const std::string input, float default_)
{
std::istringstream ss(input);
float result = default_; // 失敗したときはこの値のままになる
ss >> result;
return result;
}

// スペース、タブなど空白に相当する文字で分割して返す。
std::vector<std::string> split(const std::string& input)
{
Expand Down
3 changes: 3 additions & 0 deletions source/misc.h
Original file line number Diff line number Diff line change
Expand Up @@ -1008,6 +1008,9 @@ namespace StringExtension
// 文字列をint化する。int化に失敗した場合はdefault_の値を返す。
int to_int(const std::string input, int default_);

// 文字列をfloat化する。float化に失敗した場合はdefault_の値を返す。
float to_float(const std::string input, float default_);

// スペース、タブなど空白に相当する文字で分割して返す。
std::vector<std::string> split(const std::string& input);

Expand Down

0 comments on commit 18441f8

Please sign in to comment.