Skip to content

Commit

Permalink
Choose move based on normal distribution LCB.
Browse files Browse the repository at this point in the history
* Calculate node variance.
* Use normal distribution LCB to choose the played move.
* Cached student-t.
* Sort lz-analyze output according to LCB.
* Don't choose nodes with very few visits even if LCB is better.

Guard against NN misevaluations when top move has lot of visits.
Without this it's possible for move with few hundred visits to be picked
over a move with over ten thousand visits.

The problem is that the evaluation distribution isn't really normal
distribution. Evaluations correlate and the distribution can change
if deeper in the tree it finds a better alternative.

Pull request leela-zero#2290.
  • Loading branch information
Ttl authored and gcp committed Apr 2, 2019
1 parent aabfecc commit fd23877
Show file tree
Hide file tree
Showing 11 changed files with 168 additions and 22 deletions.
4 changes: 4 additions & 0 deletions src/GTP.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -89,6 +89,8 @@ float cfg_logconst;
float cfg_softmax_temp;
float cfg_fpu_reduction;
float cfg_fpu_root_reduction;
float cfg_ci_alpha;
float cfg_lcb_min_visit_ratio;
std::string cfg_weightsfile;
std::string cfg_logfile;
FILE* cfg_logfile_handle;
Expand Down Expand Up @@ -347,6 +349,8 @@ void GTP::setup_default_parameters() {
cfg_resignpct = -1;
cfg_noise = false;
cfg_fpu_root_reduction = cfg_fpu_reduction;
cfg_ci_alpha = 1e-5f;
cfg_lcb_min_visit_ratio = 0.10f;
cfg_random_cnt = 0;
cfg_random_min_visits = 1;
cfg_random_temp = 1.0f;
Expand Down
2 changes: 2 additions & 0 deletions src/GTP.h
Original file line number Diff line number Diff line change
Expand Up @@ -114,6 +114,8 @@ extern float cfg_logconst;
extern float cfg_softmax_temp;
extern float cfg_fpu_reduction;
extern float cfg_fpu_root_reduction;
extern float cfg_ci_alpha;
extern float cfg_lcb_min_visit_ratio;
extern std::string cfg_logfile;
extern std::string cfg_weightsfile;
extern FILE* cfg_logfile_handle;
Expand Down
6 changes: 6 additions & 0 deletions src/Leela.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -203,6 +203,7 @@ static void parse_commandline(int argc, char *argv[]) {
("logconst", po::value<float>())
("softmax_temp", po::value<float>())
("fpu_reduction", po::value<float>())
("ci_alpha", po::value<float>())
;
#endif
// These won't be shown, we use them to catch incorrect usage of the
Expand Down Expand Up @@ -278,6 +279,9 @@ static void parse_commandline(int argc, char *argv[]) {
if (vm.count("fpu_reduction")) {
cfg_fpu_reduction = vm["fpu_reduction"].as<float>();
}
if (vm.count("ci_alpha")) {
cfg_ci_alpha = vm["ci_alpha"].as<float>();
}
#endif

if (vm.count("logfile")) {
Expand Down Expand Up @@ -494,6 +498,8 @@ void init_global_objects() {
// improves reproducibility across platforms.
Random::get_Rng().seedrandom(cfg_rng_seed);

Utils::create_z_table();

initialize_network();
}

Expand Down
58 changes: 54 additions & 4 deletions src/UCTNode.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -203,8 +203,16 @@ void UCTNode::virtual_loss_undo() {
}

void UCTNode::update(float eval) {
// Cache values to avoid race conditions.
auto old_eval = static_cast<float>(m_blackevals);
auto old_visits = static_cast<int>(m_visits);
auto old_delta = old_visits > 0 ? eval - old_eval / old_visits : 0.0f;
m_visits++;
accumulate_eval(eval);
auto new_delta = eval - (old_eval + eval) / (old_visits + 1);
// Welford's online algorithm for calculating variance.
auto delta = old_delta * new_delta;
atomic_add(m_squared_eval_diff, delta);
}

bool UCTNode::has_children() const {
Expand All @@ -230,10 +238,29 @@ void UCTNode::set_policy(float policy) {
m_policy = policy;
}

float UCTNode::get_eval_variance(float default_var) const {
return m_visits > 1 ? m_squared_eval_diff / (m_visits - 1) : default_var;
}

int UCTNode::get_visits() const {
return m_visits;
}

float UCTNode::get_eval_lcb(int color) const {
// Lower confidence bound of winrate.
auto visits = get_visits();
if (visits < 2) {
// Return large negative value if not enough visits.
return -1e6f + visits;
}
auto mean = get_raw_eval(color);

auto stddev = std::sqrt(get_eval_variance(1.0f) / visits);
auto z = cached_t_quantile(visits - 1);

return mean - z * stddev;
}

float UCTNode::get_raw_eval(int tomove, int virtual_loss) const {
auto visits = get_visits() + virtual_loss;
assert(visits > 0);
Expand Down Expand Up @@ -327,7 +354,8 @@ UCTNode* UCTNode::uct_select_child(int color, bool is_root) {
class NodeComp : public std::binary_function<UCTNodePointer&,
UCTNodePointer&, bool> {
public:
NodeComp(int color) : m_color(color) {};
NodeComp(int color, float lcb_min_visits) : m_color(color),
m_lcb_min_visits(lcb_min_visits){};

// WARNING : on very unusual cases this can be called on multithread
// contexts (e.g., UCTSearch::get_pv()) so beware of race conditions
Expand All @@ -336,6 +364,22 @@ class NodeComp : public std::binary_function<UCTNodePointer&,
auto a_visit = a.get_visits();
auto b_visit = b.get_visits();

// Need at least 2 visits for LCB.
if (m_lcb_min_visits < 2) {
m_lcb_min_visits = 2;
}

// Calculate the lower confidence bound for each node.
if ((a_visit > m_lcb_min_visits) && (b_visit > m_lcb_min_visits)) {
auto a_lcb = a.get_eval_lcb(m_color);
auto b_lcb = b.get_eval_lcb(m_color);

// Sort on lower confidence bounds
if (a_lcb != b_lcb) {
return a_lcb < b_lcb;
}
}

// if visits are not same, sort on visits
if (a_visit != b_visit) {
return a_visit < b_visit;
Expand All @@ -351,19 +395,25 @@ class NodeComp : public std::binary_function<UCTNodePointer&,
}
private:
int m_color;
float m_lcb_min_visits;
};

void UCTNode::sort_children(int color) {
std::stable_sort(rbegin(m_children), rend(m_children), NodeComp(color));
void UCTNode::sort_children(int color, float lcb_min_visits) {
std::stable_sort(rbegin(m_children), rend(m_children), NodeComp(color, lcb_min_visits));
}

UCTNode& UCTNode::get_best_root_child(int color) {
wait_expanded();

assert(!m_children.empty());

auto max_visits = 0;
for (const auto& node : m_children) {
max_visits = std::max(max_visits, node.get_visits());
}

auto ret = std::max_element(begin(m_children), end(m_children),
NodeComp(color));
NodeComp(color, cfg_lcb_min_visit_ratio * max_visits));
ret->inflate();

return *(ret->get());
Expand Down
8 changes: 7 additions & 1 deletion src/UCTNode.h
Original file line number Diff line number Diff line change
Expand Up @@ -60,7 +60,7 @@ class UCTNode {
float min_psa_ratio = 0.0f);

const std::vector<UCTNodePointer>& get_children() const;
void sort_children(int color);
void sort_children(int color, float lcb_min_visits);
UCTNode& get_best_root_child(int color);
UCTNode* uct_select_child(int color, bool is_root);

Expand All @@ -76,12 +76,14 @@ class UCTNode {
int get_visits() const;
float get_policy() const;
void set_policy(float policy);
float get_eval_variance(float default_var = 0.0f) const;
float get_eval(int tomove) const;
float get_raw_eval(int tomove, int virtual_loss = 0) const;
float get_net_eval(int tomove) const;
void virtual_loss();
void virtual_loss_undo();
void update(float eval);
float get_eval_lcb(int color) const;

// Defined in UCTNodeRoot.cpp, only to be called on m_root in UCTSearch
void randomize_first_proportionally();
Expand Down Expand Up @@ -122,6 +124,10 @@ class UCTNode {
float m_policy;
// Original net eval for this node (not children).
float m_net_eval{0.0f};
// Variable used for calculating variance of evaluations.
// Initialized to small non-zero value to avoid accidental zero variances
// at low visits.
std::atomic<float> m_squared_eval_diff{1e-4f};
std::atomic<double> m_blackevals{0.0};
std::atomic<Status> m_status{ACTIVE};

Expand Down
6 changes: 6 additions & 0 deletions src/UCTNodePointer.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -137,6 +137,12 @@ float UCTNodePointer::get_policy() const {
return read_policy(v);
}

float UCTNodePointer::get_eval_lcb(int color) const {
assert(is_inflated());
auto v = m_data.load();
return read_ptr(v)->get_eval_lcb(color);
}

bool UCTNodePointer::active() const {
auto v = m_data.load();
if (is_inflated(v)) return read_ptr(v)->active();
Expand Down
3 changes: 2 additions & 1 deletion src/UCTNodePointer.h
Original file line number Diff line number Diff line change
Expand Up @@ -131,8 +131,9 @@ class UCTNodePointer {
float get_policy() const;
bool active() const;
int get_move() const;
// this can only be called if it is an inflated pointer
// these can only be called if it is an inflated pointer
float get_eval(int tomove) const;
float get_eval_lcb(int color) const;
};

#endif
Loading

0 comments on commit fd23877

Please sign in to comment.