Skip to content

Commit

Permalink
Require leaf statistics when expanding tree
Browse files Browse the repository at this point in the history
  • Loading branch information
RAMitchell committed Dec 22, 2018
1 parent 9537a08 commit 2345918
Show file tree
Hide file tree
Showing 9 changed files with 97 additions and 62 deletions.
28 changes: 20 additions & 8 deletions include/xgboost/tree_model.h
Original file line number Diff line number Diff line change
Expand Up @@ -303,14 +303,22 @@ class RegTree {
}

/**
* \brief Expands a leaf node into two additional leaf nodes
* \brief Expands a leaf node into two additional leaf nodes.
*
* \param nid The node index to expand.
* \param split_index Feature index of the split.
* \param split_value The split condition.
* \param default_left True to default left.
* \param nid The node index to expand.
* \param split_index Feature index of the split.
* \param split_value The split condition.
* \param default_left True to default left.
* \param base_weight The base weight, before learning rate.
* \param left_leaf_weight The left leaf weight for prediction, modified by learning rate.
* \param right_leaf_weight The right leaf weight for prediction, modified by learning rate.
* \param loss_change The loss change.
* \param sum_hess The sum hess.
*/
void ExpandNode(int nid, unsigned split_index, bst_float split_value, bool default_left) {
void ExpandNode(int nid, unsigned split_index, bst_float split_value,
bool default_left, bst_float base_weight,
bst_float left_leaf_weight, bst_float right_leaf_weight,
bst_float loss_change, float sum_hess) {
int pleft = this->AllocNode();
int pright = this->AllocNode();
auto &node = nodes_[nid];
Expand All @@ -322,8 +330,12 @@ class RegTree {
node.SetSplit(split_index, split_value,
default_left);
// mark right child as 0, to indicate fresh leaf
nodes_[pleft].SetLeaf(0.0f, 0);
nodes_[pright].SetLeaf(0.0f, 0);
nodes_[pleft].SetLeaf(left_leaf_weight, 0);
nodes_[pright].SetLeaf(right_leaf_weight, 0);

this->Stat(nid).loss_chg = loss_change;
this->Stat(nid).base_weight = base_weight;
this->Stat(nid).sum_hess = sum_hess;
}

/*!
Expand Down
31 changes: 24 additions & 7 deletions src/tree/updater_colmaker.cc
Original file line number Diff line number Diff line change
Expand Up @@ -410,13 +410,15 @@ class ColMaker: public TreeUpdater {
loss_chg = static_cast<bst_float>(
spliteval_->ComputeSplitScore(nid, fid, c, e.stats) -
snode_[nid].root_gain);
e.best.Update(loss_chg, fid, (fvalue + e.last_fvalue) * 0.5f,
d_step == -1, c, e.stats);
} else {
loss_chg = static_cast<bst_float>(
spliteval_->ComputeSplitScore(nid, fid, e.stats, c) -
snode_[nid].root_gain);
e.best.Update(loss_chg, fid, (fvalue + e.last_fvalue) * 0.5f,
d_step == -1, e.stats, c);
}
e.best.Update(loss_chg, fid, (fvalue + e.last_fvalue) * 0.5f,
d_step == -1, e.stats, c);
}
}
// update the statistics
Expand Down Expand Up @@ -486,18 +488,21 @@ class ColMaker: public TreeUpdater {
if (e.stats.sum_hess >= param_.min_child_weight &&
c.sum_hess >= param_.min_child_weight) {
bst_float loss_chg;
const bst_float gap = std::abs(e.last_fvalue) + kRtEps;
const bst_float delta = d_step == +1 ? gap: -gap;
if (d_step == -1) {
loss_chg = static_cast<bst_float>(
spliteval_->ComputeSplitScore(nid, fid, c, e.stats) -
snode_[nid].root_gain);
e.best.Update(loss_chg, fid, e.last_fvalue + delta, d_step == -1, c,
e.stats);
} else {
loss_chg = static_cast<bst_float>(
spliteval_->ComputeSplitScore(nid, fid, e.stats, c) -
snode_[nid].root_gain);
e.best.Update(loss_chg, fid, e.last_fvalue + delta, d_step == -1,
e.stats, c);
}
const bst_float gap = std::abs(e.last_fvalue) + kRtEps;
const bst_float delta = d_step == +1 ? gap: -gap;
e.best.Update(loss_chg, fid, e.last_fvalue + delta, d_step == -1, e.stats, c);
}
}
}
Expand Down Expand Up @@ -545,12 +550,15 @@ class ColMaker: public TreeUpdater {
loss_chg = static_cast<bst_float>(
spliteval_->ComputeSplitScore(nid, fid, c, e.stats) -
snode_[nid].root_gain);
e.best.Update(loss_chg, fid, (fvalue + e.last_fvalue) * 0.5f,
d_step == -1, c, e.stats);
} else {
loss_chg = static_cast<bst_float>(
spliteval_->ComputeSplitScore(nid, fid, e.stats, c) -
snode_[nid].root_gain);
e.best.Update(loss_chg, fid, (fvalue + e.last_fvalue) * 0.5f,
d_step == -1, e.stats, c);
}
e.best.Update(loss_chg, fid, (fvalue + e.last_fvalue) * 0.5f, d_step == -1, e.stats, c);
}
}
// update the statistics
Expand Down Expand Up @@ -640,7 +648,16 @@ class ColMaker: public TreeUpdater {
NodeEntry &e = snode_[nid];
// now we know the solution in snode[nid], set split
if (e.best.loss_chg > kRtEps) {
p_tree->ExpandNode(nid, e.best.SplitIndex(), e.best.split_value, e.best.DefaultLeft());
bst_float left_leaf_weight =
spliteval_->ComputeWeight(nid, e.best.left_sum) *
param_.learning_rate;
bst_float right_leaf_weight =
spliteval_->ComputeWeight(nid, e.best.right_sum) *
param_.learning_rate;
p_tree->ExpandNode(nid, e.best.SplitIndex(), e.best.split_value,
e.best.DefaultLeft(), e.weight, left_leaf_weight,
right_leaf_weight, e.best.loss_chg,
e.stats.sum_hess);
} else {
(*p_tree)[nid].SetLeaf(e.weight * param_.learning_rate);
}
Expand Down
3 changes: 2 additions & 1 deletion src/tree/updater_gpu_common.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -296,7 +296,8 @@ inline void Dense2SparseTree(RegTree* p_tree,
for (int gpu_nid = 0; gpu_nid < h_nodes.size(); gpu_nid++) {
const DeviceNodeStats& n = h_nodes[gpu_nid];
if (!n.IsUnused() && !n.IsLeaf()) {
tree.ExpandNode(nid, n.fidx, n.fvalue, n.dir == kLeftDir);
tree.ExpandNode(nid, n.fidx, n.fvalue, n.dir == kLeftDir, n.weight, 0.0f,
0.0f, n.root_gain, n.sum_gradients.GetHess());
tree.Stat(nid).loss_chg = n.root_gain;
tree.Stat(nid).base_weight = n.weight;
tree.Stat(nid).sum_hess = n.sum_gradients.GetHess();
Expand Down
45 changes: 19 additions & 26 deletions src/tree/updater_gpu_hist.cu
Original file line number Diff line number Diff line change
Expand Up @@ -1182,42 +1182,35 @@ class GPUHistMakerSpecialised{
}

void ApplySplit(const ExpandEntry& candidate, RegTree* p_tree) {
// Add new leaves
RegTree& tree = *p_tree;
tree.ExpandNode(candidate.nid, candidate.split.findex, candidate.split.fvalue,
candidate.split.dir == kLeftDir);
auto& parent = tree[candidate.nid];
tree.Stat(candidate.nid).loss_chg = candidate.split.loss_chg;

// Set up child constraints
node_value_constraints_.resize(tree.GetNodes().size());
GradStats left_stats(param_);
left_stats.Add(candidate.split.left_sum);
GradStats right_stats(param_);
right_stats.Add(candidate.split.right_sum);
node_value_constraints_[candidate.nid].SetChild(
param_, parent.SplitIndex(), left_stats, right_stats,
&node_value_constraints_[parent.LeftChild()],
&node_value_constraints_[parent.RightChild()]);

// Configure left child
GradStats parent_sum(param_);
parent_sum.Add(left_stats);
parent_sum.Add(right_stats);
node_value_constraints_.resize(tree.GetNodes().size());
auto base_weight = node_value_constraints_[candidate.nid].CalcWeight(param_, parent_sum);
auto left_weight =
node_value_constraints_[parent.LeftChild()].CalcWeight(param_, left_stats);
tree[parent.LeftChild()].SetLeaf(left_weight * param_.learning_rate, 0);
tree.Stat(parent.LeftChild()).base_weight = left_weight;
tree.Stat(parent.LeftChild()).sum_hess = candidate.split.left_sum.GetHess();

// Configure right child
node_value_constraints_[candidate.nid].CalcWeight(param_, left_stats)*param_.learning_rate;
auto right_weight =
node_value_constraints_[parent.RightChild()].CalcWeight(param_, right_stats);
tree[parent.RightChild()].SetLeaf(right_weight * param_.learning_rate, 0);
tree.Stat(parent.RightChild()).base_weight = right_weight;
tree.Stat(parent.RightChild()).sum_hess = candidate.split.right_sum.GetHess();
node_value_constraints_[candidate.nid].CalcWeight(param_, right_stats)*param_.learning_rate;
tree.ExpandNode(candidate.nid, candidate.split.findex,
candidate.split.fvalue, candidate.split.dir == kLeftDir,
base_weight, left_weight, right_weight,
candidate.split.loss_chg, parent_sum.sum_hess);
// Set up child constraints
node_value_constraints_.resize(tree.GetNodes().size());
node_value_constraints_[candidate.nid].SetChild(
param_, tree[candidate.nid].SplitIndex(), left_stats, right_stats,
&node_value_constraints_[tree[candidate.nid].LeftChild()],
&node_value_constraints_[tree[candidate.nid].RightChild()]);

// Store sum gradients
for (auto& shard : shards_) {
shard->node_sum_gradients[parent.LeftChild()] = candidate.split.left_sum;
shard->node_sum_gradients[parent.RightChild()] = candidate.split.right_sum;
shard->node_sum_gradients[tree[candidate.nid].LeftChild()] = candidate.split.left_sum;
shard->node_sum_gradients[tree[candidate.nid].RightChild()] = candidate.split.right_sum;
}
}

Expand Down
14 changes: 12 additions & 2 deletions src/tree/updater_histmaker.cc
Original file line number Diff line number Diff line change
Expand Up @@ -206,7 +206,7 @@ class HistMaker: public BaseMaker {
c.SetSubstract(node_sum, s);
if (c.sum_hess >= param_.min_child_weight) {
double loss_chg = s.CalcGain(param_) + c.CalcGain(param_) - root_gain;
if (best->Update(static_cast<bst_float>(loss_chg), fid, hist.cut[i-1], true, s, c)) {
if (best->Update(static_cast<bst_float>(loss_chg), fid, hist.cut[i-1], true, c, s)) {
*left_sum = c;
}
}
Expand Down Expand Up @@ -244,8 +244,18 @@ class HistMaker: public BaseMaker {
p_tree->Stat(nid).loss_chg = best.loss_chg;
// now we know the solution in snode[nid], set split
if (best.loss_chg > kRtEps) {
bst_float base_weight = node_sum.CalcWeight(param_);
bst_float left_leaf_weight =
CalcWeight(param_, best.left_sum.sum_grad, best.left_sum.sum_hess) *
param_.learning_rate;
bst_float right_leaf_weight =
CalcWeight(param_, best.right_sum.sum_grad,
best.right_sum.sum_hess) *
param_.learning_rate;
p_tree->ExpandNode(nid, best.SplitIndex(), best.split_value,
best.DefaultLeft());
best.DefaultLeft(), base_weight, left_leaf_weight,
right_leaf_weight, best.loss_chg,
node_sum.sum_hess);
// right side sum
TStats right_sum;
right_sum.SetSubstract(node_sum, left_sum[wid]);
Expand Down
7 changes: 6 additions & 1 deletion src/tree/updater_quantile_hist.cc
Original file line number Diff line number Diff line change
Expand Up @@ -429,8 +429,13 @@ void QuantileHistMaker::Builder::ApplySplit(int nid,

/* 1. Create child nodes */
NodeEntry& e = snode_[nid];
bst_float left_leaf_weight =
spliteval_->ComputeWeight(nid, e.best.left_sum) * param_.learning_rate;
bst_float right_leaf_weight =
spliteval_->ComputeWeight(nid, e.best.right_sum) * param_.learning_rate;
p_tree->ExpandNode(nid, e.best.SplitIndex(), e.best.split_value,
e.best.DefaultLeft());
e.best.DefaultLeft(), e.weight, left_leaf_weight,
right_leaf_weight, e.best.loss_chg, e.stats.sum_hess);

/* 2. Categorize member rows */
const auto nthread = static_cast<bst_omp_uint>(this->nthread_);
Expand Down
13 changes: 11 additions & 2 deletions src/tree/updater_skmaker.cc
Original file line number Diff line number Diff line change
Expand Up @@ -281,12 +281,21 @@ class SketchMaker: public BaseMaker {
const int nid = qexpand_[wid];
const SplitEntry &best = sol[wid];
// set up the values
p_tree->Stat(nid).loss_chg = best.loss_chg;
this->SetStats(nid, node_stats_[nid], p_tree);
// now we know the solution in snode[nid], set split
if (best.loss_chg > kRtEps) {
bst_float base_weight = node_stats_[nid].CalcWeight(param_);
bst_float left_leaf_weight =
CalcWeight(param_, best.left_sum.sum_grad, best.left_sum.sum_hess) *
param_.learning_rate;
bst_float right_leaf_weight =
CalcWeight(param_, best.right_sum.sum_grad,
best.right_sum.sum_hess) *
param_.learning_rate;
p_tree->ExpandNode(nid, best.SplitIndex(), best.split_value,
best.DefaultLeft());
best.DefaultLeft(), base_weight, left_leaf_weight,
right_leaf_weight, best.loss_chg,
node_stats_[nid].sum_hess);
} else {
(*p_tree)[nid].SetLeaf(p_tree->Stat(nid).base_weight * param_.learning_rate);
}
Expand Down
13 changes: 2 additions & 11 deletions tests/cpp/tree/test_prune.cc
Original file line number Diff line number Diff line change
Expand Up @@ -38,22 +38,13 @@ TEST(Updater, Prune) {
pruner->Init(cfg);

// loss_chg < min_split_loss;
tree.ExpandNode(0, 0, 0, true);
int cleft = tree[0].LeftChild();
int cright = tree[0].RightChild();
tree[cleft].SetLeaf(0.3f, 0);
tree[cright].SetLeaf(0.4f, 0);
tree.ExpandNode(0, 0, 0, true, 0.0f, 0.3f, 0.4f, 0.0f, 0.0f);
pruner->Update(&gpair, dmat->get(), trees);

ASSERT_EQ(tree.NumExtraNodes(), 0);

// loss_chg > min_split_loss;
tree.ExpandNode(0, 0, 0, true);
cleft = tree[0].LeftChild();
cright = tree[0].RightChild();
tree[cleft].SetLeaf(0.3f, 0);
tree[cright].SetLeaf(0.4f, 0);
tree.Stat(0).loss_chg = 11;
tree.ExpandNode(0, 0, 0, true, 0.0f, 0.3f, 0.4f, 11.0f, 0.0f);
pruner->Update(&gpair, dmat->get(), trees);

ASSERT_EQ(tree.NumExtraNodes(), 2);
Expand Down
5 changes: 1 addition & 4 deletions tests/cpp/tree/test_refresh.cc
Original file line number Diff line number Diff line change
Expand Up @@ -29,12 +29,9 @@ TEST(Updater, Refresh) {
std::vector<RegTree*> trees {&tree};
std::unique_ptr<TreeUpdater> refresher(TreeUpdater::Create("refresh"));

tree.ExpandNode(0, 0, 0, true);
tree.ExpandNode(0, 2, 0.2f, false, 0.0, 0.2f, 0.8f, 0.0f, 0.0f);
int cleft = tree[0].LeftChild();
int cright = tree[0].RightChild();
tree[cleft].SetLeaf(0.2f, 0);
tree[cright].SetLeaf(0.8f, 0);
tree[0].SetSplit(2, 0.2f);

tree.Stat(cleft).base_weight = 1.2;
tree.Stat(cright).base_weight = 1.3;
Expand Down

0 comments on commit 2345918

Please sign in to comment.