Skip to content

Commit

Permalink
Optimize prediction with QuantileDMatrix. (#9096)
Browse files Browse the repository at this point in the history
- Reduce overhead in `FVecDrop`.
- Reduce overhead caused by `HostVector()` calls.
  • Loading branch information
trivialfis authored Apr 27, 2023
1 parent fa267ad commit 0e470ef
Show file tree
Hide file tree
Showing 5 changed files with 49 additions and 35 deletions.
13 changes: 5 additions & 8 deletions include/xgboost/tree_model.h
Original file line number Diff line number Diff line change
Expand Up @@ -567,7 +567,7 @@ class RegTree : public Model {
* \brief drop the trace after fill, must be called after fill.
* \param inst The sparse instance to drop.
*/
void Drop(const SparsePage::Inst& inst);
void Drop();
/*!
* \brief returns the size of the feature vector
* \return the size of the feature vector
Expand Down Expand Up @@ -807,13 +807,10 @@ inline void RegTree::FVec::Fill(const SparsePage::Inst& inst) {
has_missing_ = data_.size() != feature_count;
}

inline void RegTree::FVec::Drop(const SparsePage::Inst& inst) {
for (auto const& entry : inst) {
if (entry.index >= data_.size()) {
continue;
}
data_[entry.index].flag = -1;
}
inline void RegTree::FVec::Drop() {
Entry e{};
e.flag = -1;
std::fill_n(data_.data(), data_.size(), e);
has_missing_ = true;
}

Expand Down
37 changes: 23 additions & 14 deletions src/data/gradient_index.cc
Original file line number Diff line number Diff line change
Expand Up @@ -166,6 +166,12 @@ float GHistIndexMatrix::GetFvalue(size_t ridx, size_t fidx, bool is_cat) const {
auto const &values = cut.Values();
auto const &mins = cut.MinValues();
auto const &ptrs = cut.Ptrs();
return this->GetFvalue(ptrs, values, mins, ridx, fidx, is_cat);
}

float GHistIndexMatrix::GetFvalue(std::vector<std::uint32_t> const &ptrs,
std::vector<float> const &values, std::vector<float> const &mins,
bst_row_t ridx, bst_feature_t fidx, bool is_cat) const {
if (is_cat) {
auto gidx = GetGindex(ridx, fidx);
if (gidx == -1) {
Expand All @@ -181,24 +187,27 @@ float GHistIndexMatrix::GetFvalue(size_t ridx, size_t fidx, bool is_cat) const {
}
return common::HistogramCuts::NumericBinValue(ptrs, values, mins, fidx, bin_idx);
};

if (columns_->GetColumnType(fidx) == common::kDenseColumn) {
if (columns_->AnyMissing()) {
return common::DispatchBinType(columns_->GetTypeSize(), [&](auto dtype) {
auto column = columns_->DenseColumn<decltype(dtype), true>(fidx);
return get_bin_val(column);
});
} else {
switch (columns_->GetColumnType(fidx)) {
case common::kDenseColumn: {
if (columns_->AnyMissing()) {
return common::DispatchBinType(columns_->GetTypeSize(), [&](auto dtype) {
auto column = columns_->DenseColumn<decltype(dtype), true>(fidx);
return get_bin_val(column);
});
} else {
return common::DispatchBinType(columns_->GetTypeSize(), [&](auto dtype) {
auto column = columns_->DenseColumn<decltype(dtype), false>(fidx);
auto bin_idx = column[ridx];
return common::HistogramCuts::NumericBinValue(ptrs, values, mins, fidx, bin_idx);
});
}
}
case common::kSparseColumn: {
return common::DispatchBinType(columns_->GetTypeSize(), [&](auto dtype) {
auto column = columns_->DenseColumn<decltype(dtype), false>(fidx);
auto column = columns_->SparseColumn<decltype(dtype)>(fidx, 0);
return get_bin_val(column);
});
}
} else {
return common::DispatchBinType(columns_->GetTypeSize(), [&](auto dtype) {
auto column = columns_->SparseColumn<decltype(dtype)>(fidx, 0);
return get_bin_val(column);
});
}

SPAN_CHECK(false);
Expand Down
3 changes: 3 additions & 0 deletions src/data/gradient_index.h
Original file line number Diff line number Diff line change
Expand Up @@ -239,6 +239,9 @@ class GHistIndexMatrix {
bst_bin_t GetGindex(size_t ridx, size_t fidx) const;

float GetFvalue(size_t ridx, size_t fidx, bool is_cat) const;
float GetFvalue(std::vector<std::uint32_t> const& ptrs, std::vector<float> const& values,
std::vector<float> const& mins, bst_row_t ridx, bst_feature_t fidx,
bool is_cat) const;

private:
std::unique_ptr<common::ColumnMatrix> columns_;
Expand Down
29 changes: 17 additions & 12 deletions src/predictor/cpu_predictor.cc
Original file line number Diff line number Diff line change
Expand Up @@ -75,7 +75,7 @@ bst_float PredValue(const SparsePage::Inst &inst,
psum += (*trees[i])[nidx].LeafValue();
}
}
p_feats->Drop(inst);
p_feats->Drop();
return psum;
}

Expand Down Expand Up @@ -172,13 +172,11 @@ void FVecFill(const size_t block_size, const size_t batch_offset, const int num_
}
}

template <typename DataView>
void FVecDrop(const size_t block_size, const size_t batch_offset, DataView *batch,
const size_t fvec_offset, std::vector<RegTree::FVec> *p_feats) {
/**
 * \brief Reset a contiguous range of thread-local feature vectors to the
 *        missing state after a block of rows has been predicted.
 *
 * \param block_size  Number of rows (and hence FVecs) in this block.
 * \param fvec_offset Index of the first FVec in *p_feats owned by this block.
 * \param p_feats     Thread-local pool of feature vectors to reset.
 *
 * RegTree::FVec::Drop() no longer needs the originating sparse instance, so
 * the data view and batch offset previously threaded through here are gone.
 */
void FVecDrop(std::size_t const block_size, std::size_t const fvec_offset,
              std::vector<RegTree::FVec> *p_feats) {
  for (std::size_t i = 0; i < block_size; ++i) {
    RegTree::FVec &feats = (*p_feats)[fvec_offset + i];
    feats.Drop();
  }
}

Expand All @@ -196,11 +194,15 @@ struct SparsePageView {
struct GHistIndexMatrixView {
private:
GHistIndexMatrix const &page_;
uint64_t n_features_;
std::uint64_t const n_features_;
common::Span<FeatureType const> ft_;
common::Span<Entry> workspace_;
std::vector<size_t> current_unroll_;

std::vector<std::uint32_t> const& ptrs_;
std::vector<float> const& mins_;
std::vector<float> const& values_;

public:
size_t base_rowid;

Expand All @@ -213,6 +215,9 @@ struct GHistIndexMatrixView {
ft_{ft},
workspace_{workplace},
current_unroll_(n_threads > 0 ? n_threads : 1, 0),
ptrs_{_page.cut.Ptrs()},
mins_{_page.cut.MinValues()},
values_{_page.cut.Values()},
base_rowid{_page.base_rowid} {}

SparsePage::Inst operator[](size_t r) {
Expand All @@ -221,7 +226,7 @@ struct GHistIndexMatrixView {
size_t non_missing{static_cast<std::size_t>(beg)};

for (bst_feature_t c = 0; c < n_features_; ++c) {
float f = page_.GetFvalue(r, c, common::IsCat(ft_, c));
float f = page_.GetFvalue(ptrs_, values_, mins_, r, c, common::IsCat(ft_, c));
if (!common::CheckNAN(f)) {
workspace_[non_missing] = Entry{c, f};
++non_missing;
Expand Down Expand Up @@ -301,7 +306,7 @@ void PredictBatchByBlockOfRowsKernel(DataView batch, gbm::GBTreeModel const &mod
// process block of rows through all trees to keep cache locality
PredictByAllTrees(model, tree_begin, tree_end, batch_offset + batch.base_rowid, thread_temp,
fvec_offset, block_size, out_predt);
FVecDrop(block_size, batch_offset, &batch, fvec_offset, p_thread_temp);
FVecDrop(block_size, fvec_offset, p_thread_temp);
});
}

Expand Down Expand Up @@ -529,7 +534,7 @@ class ColumnSplitHelper {

FVecFill(block_size, batch_offset, num_feature, &batch, fvec_offset, &feat_vecs_);
MaskAllTrees(batch_offset, fvec_offset, block_size);
FVecDrop(block_size, batch_offset, &batch, fvec_offset, &feat_vecs_);
FVecDrop(block_size, fvec_offset, &feat_vecs_);
});

AllreduceBitVectors();
Expand Down Expand Up @@ -780,7 +785,7 @@ class CPUPredictor : public Predictor {
}
preds[ridx * ntree_limit + j] = static_cast<bst_float>(nidx);
}
feats.Drop(page[i]);
feats.Drop();
});
}
}
Expand Down Expand Up @@ -853,7 +858,7 @@ class CPUPredictor : public Predictor {
(tree_weights == nullptr ? 1 : (*tree_weights)[j]);
}
}
feats.Drop(page[i]);
feats.Drop();
// add base margin to BIAS
if (base_margin.Size() != 0) {
CHECK_EQ(base_margin.Shape(1), ngroup);
Expand Down
2 changes: 1 addition & 1 deletion src/tree/updater_refresh.cc
Original file line number Diff line number Diff line change
Expand Up @@ -79,7 +79,7 @@ class TreeRefresher : public TreeUpdater {
dmlc::BeginPtr(stemp[tid]) + offset);
offset += tree->NumNodes();
}
feats.Drop(inst);
feats.Drop();
});
}
// aggregate the statistics
Expand Down

0 comments on commit 0e470ef

Please sign in to comment.