diff --git a/include/xgboost/tree_model.h b/include/xgboost/tree_model.h index 61dd94302cbc..393dda59c2aa 100644 --- a/include/xgboost/tree_model.h +++ b/include/xgboost/tree_model.h @@ -567,7 +567,7 @@ class RegTree : public Model { * \brief drop the trace after fill, must be called after fill. * \param inst The sparse instance to drop. */ - void Drop(const SparsePage::Inst& inst); + void Drop(); /*! * \brief returns the size of the feature vector * \return the size of the feature vector @@ -807,13 +807,10 @@ inline void RegTree::FVec::Fill(const SparsePage::Inst& inst) { has_missing_ = data_.size() != feature_count; } -inline void RegTree::FVec::Drop(const SparsePage::Inst& inst) { - for (auto const& entry : inst) { - if (entry.index >= data_.size()) { - continue; - } - data_[entry.index].flag = -1; - } +inline void RegTree::FVec::Drop() { + Entry e{}; + e.flag = -1; + std::fill_n(data_.data(), data_.size(), e); has_missing_ = true; } diff --git a/src/data/gradient_index.cc b/src/data/gradient_index.cc index 0a606ecd534f..3b3323bb5694 100644 --- a/src/data/gradient_index.cc +++ b/src/data/gradient_index.cc @@ -166,6 +166,12 @@ float GHistIndexMatrix::GetFvalue(size_t ridx, size_t fidx, bool is_cat) const { auto const &values = cut.Values(); auto const &mins = cut.MinValues(); auto const &ptrs = cut.Ptrs(); + return this->GetFvalue(ptrs, values, mins, ridx, fidx, is_cat); +} + +float GHistIndexMatrix::GetFvalue(std::vector const &ptrs, + std::vector const &values, std::vector const &mins, + bst_row_t ridx, bst_feature_t fidx, bool is_cat) const { if (is_cat) { auto gidx = GetGindex(ridx, fidx); if (gidx == -1) { @@ -181,24 +187,27 @@ float GHistIndexMatrix::GetFvalue(size_t ridx, size_t fidx, bool is_cat) const { } return common::HistogramCuts::NumericBinValue(ptrs, values, mins, fidx, bin_idx); }; - - if (columns_->GetColumnType(fidx) == common::kDenseColumn) { - if (columns_->AnyMissing()) { - return common::DispatchBinType(columns_->GetTypeSize(), [&](auto dtype) { - auto column = columns_->DenseColumn(fidx); - return get_bin_val(column); - }); - } else { + switch (columns_->GetColumnType(fidx)) { + case common::kDenseColumn: { + if (columns_->AnyMissing()) { + return common::DispatchBinType(columns_->GetTypeSize(), [&](auto dtype) { + auto column = columns_->DenseColumn(fidx); + return get_bin_val(column); + }); + } else { + return common::DispatchBinType(columns_->GetTypeSize(), [&](auto dtype) { + auto column = columns_->DenseColumn(fidx); + auto bin_idx = column[ridx]; + return common::HistogramCuts::NumericBinValue(ptrs, values, mins, fidx, bin_idx); + }); + } + } + case common::kSparseColumn: { return common::DispatchBinType(columns_->GetTypeSize(), [&](auto dtype) { - auto column = columns_->DenseColumn(fidx); + auto column = columns_->SparseColumn(fidx, 0); return get_bin_val(column); }); } - } else { - return common::DispatchBinType(columns_->GetTypeSize(), [&](auto dtype) { - auto column = columns_->SparseColumn(fidx, 0); - return get_bin_val(column); - }); } SPAN_CHECK(false); diff --git a/src/data/gradient_index.h b/src/data/gradient_index.h index 3cb0709bd95d..4c35870db595 100644 --- a/src/data/gradient_index.h +++ b/src/data/gradient_index.h @@ -239,6 +239,9 @@ class GHistIndexMatrix { bst_bin_t GetGindex(size_t ridx, size_t fidx) const; float GetFvalue(size_t ridx, size_t fidx, bool is_cat) const; + float GetFvalue(std::vector const& ptrs, std::vector const& values, + std::vector const& mins, bst_row_t ridx, bst_feature_t fidx, + bool is_cat) const; private: std::unique_ptr columns_; diff --git a/src/predictor/cpu_predictor.cc b/src/predictor/cpu_predictor.cc index 2b7a96d9cbde..b3b4c5e80a3c 100644 --- a/src/predictor/cpu_predictor.cc +++ b/src/predictor/cpu_predictor.cc @@ -75,7 +75,7 @@ bst_float PredValue(const SparsePage::Inst &inst, psum += (*trees[i])[nidx].LeafValue(); } } - p_feats->Drop(inst); + p_feats->Drop(); return psum; } @@ -172,13 +172,11 @@ void FVecFill(const size_t block_size, const size_t batch_offset, const int num_ } } -template -void FVecDrop(const size_t block_size, const size_t batch_offset, DataView *batch, - const size_t fvec_offset, std::vector *p_feats) { +void FVecDrop(std::size_t const block_size, std::size_t const fvec_offset, + std::vector *p_feats) { for (size_t i = 0; i < block_size; ++i) { RegTree::FVec &feats = (*p_feats)[fvec_offset + i]; - const SparsePage::Inst inst = (*batch)[batch_offset + i]; - feats.Drop(inst); + feats.Drop(); } } @@ -196,11 +194,15 @@ struct SparsePageView { struct GHistIndexMatrixView { private: GHistIndexMatrix const &page_; - uint64_t n_features_; + std::uint64_t const n_features_; common::Span ft_; common::Span workspace_; std::vector current_unroll_; + std::vector const& ptrs_; + std::vector const& mins_; + std::vector const& values_; + public: size_t base_rowid; @@ -213,6 +215,9 @@ struct GHistIndexMatrixView { ft_{ft}, workspace_{workplace}, current_unroll_(n_threads > 0 ? n_threads : 1, 0), + ptrs_{_page.cut.Ptrs()}, + mins_{_page.cut.MinValues()}, + values_{_page.cut.Values()}, base_rowid{_page.base_rowid} {} SparsePage::Inst operator[](size_t r) { @@ -221,7 +226,7 @@ struct GHistIndexMatrixView { size_t non_missing{static_cast(beg)}; for (bst_feature_t c = 0; c < n_features_; ++c) { - float f = page_.GetFvalue(r, c, common::IsCat(ft_, c)); + float f = page_.GetFvalue(ptrs_, values_, mins_, r, c, common::IsCat(ft_, c)); if (!common::CheckNAN(f)) { workspace_[non_missing] = Entry{c, f}; ++non_missing; @@ -301,7 +306,7 @@ void PredictBatchByBlockOfRowsKernel(DataView batch, gbm::GBTreeModel const &mod // process block of rows through all trees to keep cache locality PredictByAllTrees(model, tree_begin, tree_end, batch_offset + batch.base_rowid, thread_temp, fvec_offset, block_size, out_predt); - FVecDrop(block_size, batch_offset, &batch, fvec_offset, p_thread_temp); + FVecDrop(block_size, fvec_offset, p_thread_temp); }); } @@ -529,7 +534,7 @@ class ColumnSplitHelper { FVecFill(block_size, batch_offset, num_feature, &batch, fvec_offset, &feat_vecs_); MaskAllTrees(batch_offset, fvec_offset, block_size); - FVecDrop(block_size, batch_offset, &batch, fvec_offset, &feat_vecs_); + FVecDrop(block_size, fvec_offset, &feat_vecs_); }); AllreduceBitVectors(); @@ -780,7 +785,7 @@ class CPUPredictor : public Predictor { } preds[ridx * ntree_limit + j] = static_cast(nidx); } - feats.Drop(page[i]); + feats.Drop(); }); } } @@ -853,7 +858,7 @@ class CPUPredictor : public Predictor { (tree_weights == nullptr ? 1 : (*tree_weights)[j]); } } - feats.Drop(page[i]); + feats.Drop(); // add base margin to BIAS if (base_margin.Size() != 0) { CHECK_EQ(base_margin.Shape(1), ngroup); diff --git a/src/tree/updater_refresh.cc b/src/tree/updater_refresh.cc index 17c5654907cd..448492de0688 100644 --- a/src/tree/updater_refresh.cc +++ b/src/tree/updater_refresh.cc @@ -79,7 +79,7 @@ class TreeRefresher : public TreeUpdater { dmlc::BeginPtr(stemp[tid]) + offset); offset += tree->NumNodes(); } - feats.Drop(inst); + feats.Drop(); }); } // aggregate the statistics