From 712e39d3d25dbb3a6ef4dc64ccaeca150acdbc1a Mon Sep 17 00:00:00 2001 From: Jiaming Yuan Date: Sat, 11 Jan 2025 22:15:39 +0800 Subject: [PATCH] Cleanup CPU predict function. (#11139) --- include/xgboost/gbm.h | 27 +- include/xgboost/predictor.h | 58 ++--- include/xgboost/tree_model.h | 59 ++--- plugin/sycl/predictor/predictor.cc | 22 +- src/common/column_matrix.h | 6 +- src/common/hist_util.h | 5 +- src/common/ref_resource_view.h | 10 +- src/data/adapter.h | 8 +- src/data/gradient_index.cc | 20 +- src/data/gradient_index.h | 33 ++- src/gbm/gblinear.cc | 13 +- src/gbm/gbtree.cc | 12 +- src/gbm/gbtree.h | 9 +- src/metric/elementwise_metric.cu | 31 ++- src/predictor/cpu_predictor.cc | 341 ++++++++++---------------- src/predictor/gpu_predictor.cu | 82 +++---- src/predictor/predict_fn.h | 35 ++- tests/cpp/predictor/test_predictor.cc | 12 +- tests/python/test_quantile_dmatrix.py | 6 +- 19 files changed, 324 insertions(+), 465 deletions(-) diff --git a/include/xgboost/gbm.h b/include/xgboost/gbm.h index ae8652eee66d..3f4e8540efa5 100644 --- a/include/xgboost/gbm.h +++ b/include/xgboost/gbm.h @@ -1,5 +1,5 @@ /** - * Copyright 2014-2023 by XGBoost Contributors + * Copyright 2014-2025, XGBoost Contributors * \file gbm.h * \brief Interface of gradient booster, * that learns through gradient statistics. @@ -15,10 +15,8 @@ #include #include -#include #include #include -#include #include namespace xgboost { @@ -42,13 +40,13 @@ class GradientBooster : public Model, public Configurable { public: /*! \brief virtual destructor */ ~GradientBooster() override = default; - /*! - * \brief Set the configuration of gradient boosting. + /** + * @brief Set the configuration of gradient boosting. * User must call configure once before InitModel and Training. * - * \param cfg configurations on both training and model parameters. + * @param cfg configurations on both training and model parameters. */ - virtual void Configure(const std::vector >& cfg) = 0; + virtual void Configure(Args const& cfg) = 0; /*! * \brief load model from stream * \param fi input stream. @@ -117,21 +115,6 @@ class GradientBooster : public Model, public Configurable { bst_layer_t) const { LOG(FATAL) << "Inplace predict is not supported by the current booster."; } - /*! - * \brief online prediction function, predict score for one instance at a time - * NOTE: use the batch prediction interface if possible, batch prediction is usually - * more efficient than online prediction - * This function is NOT threadsafe, make sure you only call from one thread - * - * \param inst the instance you want to predict - * \param out_preds output vector to hold the predictions - * \param layer_begin Beginning of boosted tree layer used for prediction. - * \param layer_end End of booster layer. 0 means do not limit trees. - * \sa Predict - */ - virtual void PredictInstance(const SparsePage::Inst& inst, - std::vector* out_preds, - unsigned layer_begin, unsigned layer_end) = 0; /*! * \brief predict the leaf index of each tree, the output will be nsample * ntree vector * this is only valid in gbtree predictor diff --git a/include/xgboost/predictor.h b/include/xgboost/predictor.h index 62f0895e024c..f40abdf4faa6 100644 --- a/include/xgboost/predictor.h +++ b/include/xgboost/predictor.h @@ -1,5 +1,5 @@ /** - * Copyright 2017-2024, XGBoost Contributors + * Copyright 2017-2025, XGBoost Contributors * \file predictor.h * \brief Interface of predictor, * performs predictions for a gradient booster. 
@@ -28,7 +28,7 @@ namespace xgboost { */ struct PredictionCacheEntry { // A storage for caching prediction values - HostDeviceVector predictions; + HostDeviceVector predictions; // The version of current cache, corresponding number of layers of trees std::uint32_t version{0}; @@ -91,7 +91,7 @@ class Predictor { * \param out_predt Prediction vector to be initialized. * \param model Tree model used for prediction. */ - virtual void InitOutPredictions(const MetaInfo& info, HostDeviceVector* out_predt, + virtual void InitOutPredictions(const MetaInfo& info, HostDeviceVector* out_predt, const gbm::GBTreeModel& model) const; /** @@ -105,8 +105,8 @@ class Predictor { * \param tree_end The tree end index. */ virtual void PredictBatch(DMatrix* dmat, PredictionCacheEntry* out_preds, - const gbm::GBTreeModel& model, uint32_t tree_begin, - uint32_t tree_end = 0) const = 0; + gbm::GBTreeModel const& model, bst_tree_t tree_begin, + bst_tree_t tree_end = 0) const = 0; /** * \brief Inplace prediction. @@ -123,25 +123,7 @@ class Predictor { */ virtual bool InplacePredict(std::shared_ptr p_fmat, const gbm::GBTreeModel& model, float missing, PredictionCacheEntry* out_preds, - uint32_t tree_begin = 0, uint32_t tree_end = 0) const = 0; - /** - * \brief online prediction function, predict score for one instance at a time - * NOTE: use the batch prediction interface if possible, batch prediction is - * usually more efficient than online prediction This function is NOT - * threadsafe, make sure you only call from one thread. - * - * \param inst The instance to predict. - * \param [in,out] out_preds The output preds. - * \param model The model to predict from - * \param tree_end (Optional) The tree end index. - * \param is_column_split (Optional) If the data is split column-wise. - */ - - virtual void PredictInstance(const SparsePage::Inst& inst, - std::vector* out_preds, - const gbm::GBTreeModel& model, - unsigned tree_end = 0, - bool is_column_split = false) const = 0; + bst_tree_t tree_begin = 0, bst_tree_t tree_end = 0) const = 0; /** * \brief predict the leaf index of each tree, the output will be nsample * @@ -153,9 +135,8 @@ class Predictor { * \param tree_end (Optional) The tree end index. */ - virtual void PredictLeaf(DMatrix* dmat, HostDeviceVector* out_preds, - const gbm::GBTreeModel& model, - unsigned tree_end = 0) const = 0; + virtual void PredictLeaf(DMatrix* dmat, HostDeviceVector* out_preds, + gbm::GBTreeModel const& model, bst_tree_t tree_end = 0) const = 0; /** * \brief feature contributions to individual predictions; the output will be @@ -172,18 +153,17 @@ class Predictor { * \param condition_feature Feature to condition on (i.e. fix) during calculations. 
*/ - virtual void - PredictContribution(DMatrix *dmat, HostDeviceVector *out_contribs, - const gbm::GBTreeModel &model, unsigned tree_end = 0, - std::vector const *tree_weights = nullptr, - bool approximate = false, int condition = 0, - unsigned condition_feature = 0) const = 0; - - virtual void PredictInteractionContributions( - DMatrix *dmat, HostDeviceVector *out_contribs, - const gbm::GBTreeModel &model, unsigned tree_end = 0, - std::vector const *tree_weights = nullptr, - bool approximate = false) const = 0; + virtual void PredictContribution(DMatrix* dmat, HostDeviceVector* out_contribs, + gbm::GBTreeModel const& model, bst_tree_t tree_end = 0, + std::vector const* tree_weights = nullptr, + bool approximate = false, int condition = 0, + unsigned condition_feature = 0) const = 0; + + virtual void PredictInteractionContributions(DMatrix* dmat, HostDeviceVector* out_contribs, + gbm::GBTreeModel const& model, + bst_tree_t tree_end = 0, + std::vector const* tree_weights = nullptr, + bool approximate = false) const = 0; /** * \brief Creates a new Predictor*. diff --git a/include/xgboost/tree_model.h b/include/xgboost/tree_model.h index e9f661215653..921fc5a1ebc8 100644 --- a/include/xgboost/tree_model.h +++ b/include/xgboost/tree_model.h @@ -1,5 +1,5 @@ /** - * Copyright 2014-2024, XGBoost Contributors + * Copyright 2014-2025, XGBoost Contributors * \file tree_model.h * \brief model structure for tree * \author Tianqi Chen @@ -23,7 +23,6 @@ #include // for make_unique #include #include -#include #include namespace xgboost { @@ -562,7 +561,7 @@ class RegTree : public Model { * \brief fill the vector with sparse vector * \param inst The sparse instance to fill. */ - void Fill(const SparsePage::Inst& inst); + void Fill(SparsePage::Inst const& inst); /*! * \brief drop the trace after fill, must be called after fill. @@ -587,18 +586,17 @@ class RegTree : public Model { */ [[nodiscard]] bool IsMissing(size_t i) const; [[nodiscard]] bool HasMissing() const; + void HasMissing(bool has_missing) { this->has_missing_ = has_missing; } + [[nodiscard]] common::Span Data() { return data_; } private: - /*! - * \brief a union value of value and flag - * when flag == -1, this indicate the value is missing + /** + * @brief A dense vector for a single sample. + * + * It's nan if the value is missing. 
*/ - union Entry { - bst_float fvalue; - int flag; - }; - std::vector data_; + std::vector data_; bool has_missing_; }; @@ -793,46 +791,35 @@ class RegTree : public Model { }; inline void RegTree::FVec::Init(size_t size) { - Entry e; e.flag = -1; data_.resize(size); - std::fill(data_.begin(), data_.end(), e); + std::fill(data_.begin(), data_.end(), std::numeric_limits::quiet_NaN()); has_missing_ = true; } -inline void RegTree::FVec::Fill(const SparsePage::Inst& inst) { - size_t feature_count = 0; - for (auto const& entry : inst) { - if (entry.index >= data_.size()) { - continue; - } - data_[entry.index].fvalue = entry.fvalue; - ++feature_count; +inline void RegTree::FVec::Fill(SparsePage::Inst const& inst) { + auto p_data = inst.data(); + auto p_out = data_.data(); + + for (std::size_t i = 0, n = inst.size(); i < n; ++i) { + auto const& entry = p_data[i]; + p_out[entry.index] = entry.fvalue; } - has_missing_ = data_.size() != feature_count; + has_missing_ = data_.size() != inst.size(); } -inline void RegTree::FVec::Drop() { - Entry e{}; - e.flag = -1; - std::fill_n(data_.data(), data_.size(), e); - has_missing_ = true; -} +inline void RegTree::FVec::Drop() { this->Init(this->Size()); } inline size_t RegTree::FVec::Size() const { return data_.size(); } -inline bst_float RegTree::FVec::GetFvalue(size_t i) const { - return data_[i].fvalue; +inline float RegTree::FVec::GetFvalue(size_t i) const { + return data_[i]; } -inline bool RegTree::FVec::IsMissing(size_t i) const { - return data_[i].flag == -1; -} +inline bool RegTree::FVec::IsMissing(size_t i) const { return std::isnan(data_[i]); } -inline bool RegTree::FVec::HasMissing() const { - return has_missing_; -} +inline bool RegTree::FVec::HasMissing() const { return has_missing_; } // Multi-target tree not yet implemented error inline StringView MTNotImplemented() { diff --git a/plugin/sycl/predictor/predictor.cc b/plugin/sycl/predictor/predictor.cc index a92c820ab350..43356f64eb0b 100755 --- a/plugin/sycl/predictor/predictor.cc +++ b/plugin/sycl/predictor/predictor.cc @@ -201,8 +201,8 @@ class Predictor : public xgboost::Predictor { } void PredictBatch(DMatrix *dmat, PredictionCacheEntry *predts, - const gbm::GBTreeModel &model, uint32_t tree_begin, - uint32_t tree_end = 0) const override { + const gbm::GBTreeModel &model, bst_tree_t tree_begin, + bst_tree_t tree_end = 0) const override { auto* out_preds = &predts->predictions; out_preds->SetDevice(ctx_->Device()); if (tree_end == 0) { @@ -221,28 +221,20 @@ class Predictor : public xgboost::Predictor { bool InplacePredict(std::shared_ptr p_m, const gbm::GBTreeModel &model, float missing, - PredictionCacheEntry *out_preds, uint32_t tree_begin, - unsigned tree_end) const override { + PredictionCacheEntry *out_preds, bst_tree_t tree_begin, + bst_tree_t tree_end) const override { LOG(WARNING) << "InplacePredict is not yet implemented for SYCL. CPU Predictor is used."; return cpu_predictor->InplacePredict(p_m, model, missing, out_preds, tree_begin, tree_end); } - void PredictInstance(const SparsePage::Inst& inst, - std::vector* out_preds, - const gbm::GBTreeModel& model, unsigned ntree_limit, - bool is_column_split) const override { - LOG(WARNING) << "PredictInstance is not yet implemented for SYCL. 
CPU Predictor is used."; - cpu_predictor->PredictInstance(inst, out_preds, model, ntree_limit, is_column_split); - } - void PredictLeaf(DMatrix* p_fmat, HostDeviceVector* out_preds, - const gbm::GBTreeModel& model, unsigned ntree_limit) const override { + const gbm::GBTreeModel& model, bst_tree_t ntree_limit) const override { LOG(WARNING) << "PredictLeaf is not yet implemented for SYCL. CPU Predictor is used."; cpu_predictor->PredictLeaf(p_fmat, out_preds, model, ntree_limit); } void PredictContribution(DMatrix* p_fmat, HostDeviceVector* out_contribs, - const gbm::GBTreeModel& model, uint32_t ntree_limit, + const gbm::GBTreeModel& model, bst_tree_t ntree_limit, const std::vector* tree_weights, bool approximate, int condition, unsigned condition_feature) const override { @@ -252,7 +244,7 @@ class Predictor : public xgboost::Predictor { } void PredictInteractionContributions(DMatrix* p_fmat, HostDeviceVector* out_contribs, - const gbm::GBTreeModel& model, unsigned ntree_limit, + const gbm::GBTreeModel& model, bst_tree_t ntree_limit, const std::vector* tree_weights, bool approximate) const override { LOG(WARNING) << "PredictInteractionContributions is not yet implemented for SYCL. " diff --git a/src/common/column_matrix.h b/src/common/column_matrix.h index 2817face3e63..17f3ed4c6824 100644 --- a/src/common/column_matrix.h +++ b/src/common/column_matrix.h @@ -1,5 +1,5 @@ /** - * Copyright 2017-2024, XGBoost Contributors + * Copyright 2017-2025, XGBoost Contributors * \file column_matrix.h * \brief Utility for fast column-wise access * \author Philip Cho @@ -45,7 +45,7 @@ class Column { virtual ~Column() = default; [[nodiscard]] bst_bin_t GetGlobalBinIdx(size_t idx) const { - return index_base_ + static_cast(index_[idx]); + return index_base_ + static_cast(index_.data()[idx]); } /* returns number of elements in column */ @@ -53,7 +53,7 @@ class Column { private: /* bin indexes in range [0, max_bins - 1] */ - common::Span index_; + common::Span index_; /* bin index offset for specific feature */ bst_bin_t const index_base_; }; diff --git a/src/common/hist_util.h b/src/common/hist_util.h index efec6bd11b1b..dc2bc3fd6a89 100644 --- a/src/common/hist_util.h +++ b/src/common/hist_util.h @@ -83,10 +83,7 @@ class HistogramCuts { [[nodiscard]] bst_bin_t FeatureBins(bst_feature_t feature) const { return cut_ptrs_.ConstHostVector().at(feature + 1) - cut_ptrs_.ConstHostVector()[feature]; } - [[nodiscard]] bst_feature_t NumFeatures() const { - CHECK_EQ(this->min_vals_.Size(), this->cut_ptrs_.Size() - 1); - return this->min_vals_.Size(); - } + [[nodiscard]] bst_feature_t NumFeatures() const { return this->cut_ptrs_.Size() - 1; } std::vector const& Ptrs() const { return cut_ptrs_.ConstHostVector(); } std::vector const& Values() const { return cut_values_.ConstHostVector(); } diff --git a/src/common/ref_resource_view.h b/src/common/ref_resource_view.h index 81058d923d3b..3c33a839ab77 100644 --- a/src/common/ref_resource_view.h +++ b/src/common/ref_resource_view.h @@ -1,5 +1,5 @@ /** - * Copyright 2023-2024, XGBoost Contributors + * Copyright 2023-2025, XGBoost Contributors */ #ifndef XGBOOST_COMMON_REF_RESOURCE_VIEW_H_ #define XGBOOST_COMMON_REF_RESOURCE_VIEW_H_ @@ -88,6 +88,14 @@ class RefResourceView { [[nodiscard]] value_type& operator[](size_type i) { return ptr_[i]; } [[nodiscard]] value_type const& operator[](size_type i) const { return ptr_[i]; } + [[nodiscard]] value_type& at(size_type i) { // NOLINT + SPAN_LT(i, this->size_); + return ptr_[i]; + } + [[nodiscard]] value_type const& at(size_type i) 
const { // NOLINT + SPAN_LT(i, this->size_); + return ptr_[i]; + } /** * @brief Get the underlying resource. diff --git a/src/data/adapter.h b/src/data/adapter.h index 9259e54b6a00..0888a2f86b4f 100644 --- a/src/data/adapter.h +++ b/src/data/adapter.h @@ -1,5 +1,5 @@ /** - * Copyright 2019-2024, XGBoost Contributors + * Copyright 2019-2025, XGBoost Contributors * \file adapter.h */ #ifndef XGBOOST_DATA_ADAPTER_H_ @@ -546,12 +546,12 @@ class ColumnarAdapterBatch : public detail::NoMetaInfo { : columns_{columns}, ridx_{ridx} {} [[nodiscard]] std::size_t Size() const { return columns_.empty() ? 0 : columns_.size(); } - [[nodiscard]] COOTuple GetElement(std::size_t idx) const { - auto const& column = columns_[idx]; + [[nodiscard]] COOTuple GetElement(std::size_t fidx) const { + auto const& column = columns_.data()[fidx]; float value = column.valid.Data() == nullptr || column.valid.Check(ridx_) ? column(ridx_) : std::numeric_limits::quiet_NaN(); - return {ridx_, idx, value}; + return {ridx_, fidx, value}; } }; diff --git a/src/data/gradient_index.cc b/src/data/gradient_index.cc index 4f7f8be0c9b5..e9a1a7e329ff 100644 --- a/src/data/gradient_index.cc +++ b/src/data/gradient_index.cc @@ -1,5 +1,5 @@ /** - * Copyright 2017-2024, XGBoost Contributors + * Copyright 2017-2025, XGBoost Contributors * \brief Data type for fast histogram aggregation. */ #include "gradient_index.h" @@ -205,19 +205,11 @@ float GHistIndexMatrix::GetFvalue(size_t ridx, size_t fidx, bool is_cat) const { return this->GetFvalue(ptrs, values, mins, ridx, fidx, is_cat); } -float GHistIndexMatrix::GetFvalue(std::vector const &ptrs, - std::vector const &values, std::vector const &mins, - bst_idx_t ridx, bst_feature_t fidx, bool is_cat) const { - if (is_cat) { - auto gidx = GetGindex(ridx, fidx); - if (gidx == -1) { - return std::numeric_limits::quiet_NaN(); - } - return values[gidx]; - } - +float GetFvalueImpl(std::vector const &ptrs, std::vector const &values, + std::vector const &mins, bst_idx_t ridx, bst_feature_t fidx, + bst_idx_t base_rowid, std::unique_ptr const &columns_) { auto get_bin_val = [&](auto &column) { - auto bin_idx = column[ridx - this->base_rowid]; + auto bin_idx = column[ridx - base_rowid]; if (bin_idx == common::DenseColumnIter::kMissingId) { return std::numeric_limits::quiet_NaN(); } @@ -233,7 +225,7 @@ float GHistIndexMatrix::GetFvalue(std::vector const &ptrs, } else { return common::DispatchBinType(columns_->GetTypeSize(), [&](auto dtype) { auto column = columns_->DenseColumn(fidx); - auto bin_idx = column[ridx - this->base_rowid]; + auto bin_idx = column[ridx - base_rowid]; return common::HistogramCuts::NumericBinValue(ptrs, values, mins, fidx, bin_idx); }); } diff --git a/src/data/gradient_index.h b/src/data/gradient_index.h index d732e0a3c868..6560198093c1 100644 --- a/src/data/gradient_index.h +++ b/src/data/gradient_index.h @@ -1,5 +1,5 @@ /** - * Copyright 2017-2023 by XGBoost Contributors + * Copyright 2017-2025, XGBoost Contributors * \brief Data type for fast histogram aggregation. 
*/ #ifndef XGBOOST_DATA_GRADIENT_INDEX_H_ @@ -9,8 +9,9 @@ #include // for atomic #include // for size_t #include // for uint32_t +#include // for numeric_limits #include // for make_unique -#include +#include // for vector #include "../common/categorical.h" #include "../common/error_msg.h" // for InfInData @@ -29,6 +30,10 @@ class ColumnMatrix; class AlignedFileWriteStream; } // namespace common +float GetFvalueImpl(std::vector const& ptrs, std::vector const& values, + std::vector const& mins, bst_idx_t ridx, bst_feature_t fidx, + bst_idx_t base_rowid, std::unique_ptr const& columns_); + /** * @brief preprocessed global index matrix, in CSR format. * @@ -245,12 +250,14 @@ class GHistIndexMatrix { void SetDense(bool is_dense) { isDense_ = is_dense; } [[nodiscard]] bst_idx_t BaseRowId() const { return base_rowid; } /** - * @brief Get the local row index. + * @brief Get the local row index from the global row index. */ - [[nodiscard]] bst_idx_t RowIdx(bst_idx_t ridx) const { return row_ptr[ridx - this->base_rowid]; } + [[nodiscard]] bst_idx_t RowIdx(bst_idx_t gridx) const { + return row_ptr[gridx - this->base_rowid]; + } [[nodiscard]] bst_idx_t Size() const { return row_ptr.empty() ? 0 : row_ptr.size() - 1; } - [[nodiscard]] bst_feature_t Features() const { return cut.Ptrs().size() - 1; } + [[nodiscard]] bst_feature_t Features() const { return cut.NumFeatures(); } [[nodiscard]] bool ReadColumnPage(common::AlignedResourceReadStream* fi); [[nodiscard]] std::size_t WriteColumnPage(common::AlignedFileWriteStream* fo) const; @@ -262,7 +269,21 @@ class GHistIndexMatrix { [[nodiscard]] float GetFvalue(size_t ridx, size_t fidx, bool is_cat) const; [[nodiscard]] float GetFvalue(std::vector const& ptrs, std::vector const& values, std::vector const& mins, - bst_idx_t ridx, bst_feature_t fidx, bool is_cat) const; + bst_idx_t ridx, bst_feature_t fidx, bool is_cat) const { + if (is_cat) { + auto gidx = GetGindex(ridx, fidx); + if (gidx == -1) { + return std::numeric_limits::quiet_NaN(); + } + return values[gidx]; + } + if (this->IsDense()) { + auto begin = RowIdx(ridx); + auto bin_idx = this->index[begin + fidx]; + return common::HistogramCuts::NumericBinValue(ptrs, values, mins, fidx, bin_idx); + } + return GetFvalueImpl(ptrs, values, mins, ridx, fidx, this->base_rowid, this->columns_); + } [[nodiscard]] common::HistogramCuts& Cuts() { return cut; } [[nodiscard]] common::HistogramCuts const& Cuts() const { return cut; } diff --git a/src/gbm/gblinear.cc b/src/gbm/gblinear.cc index d9d48f00bd0d..2cacfe078b4b 100644 --- a/src/gbm/gblinear.cc +++ b/src/gbm/gblinear.cc @@ -1,5 +1,5 @@ /** - * Copyright 2014-2024, XGBoost Contributors + * Copyright 2014-2025, XGBoost Contributors * \file gblinear.cc * \brief Implementation of Linear booster, with L1/L2 regularization: Elastic Net * the update rule is parallel coordinate descent (shotgun) @@ -163,17 +163,6 @@ class GBLinear : public GradientBooster { this->PredictBatchInternal(p_fmat, &out_preds->HostVector()); monitor_.Stop("PredictBatch"); } - // add base margin - void PredictInstance(const SparsePage::Inst& inst, std::vector* out_preds, - uint32_t layer_begin, uint32_t) override { - LinearCheckLayer(layer_begin); - const int ngroup = model_.learner_model_param->num_output_group; - - auto base_score = learner_model_param_->BaseScore(ctx_); - for (int gid = 0; gid < ngroup; ++gid) { - this->Pred(inst, dmlc::BeginPtr(*out_preds), gid, base_score(0)); - } - } void PredictLeaf(DMatrix *, HostDeviceVector *, unsigned, unsigned) override { LOG(FATAL) << "gblinear 
does not support prediction of leaf index"; diff --git a/src/gbm/gbtree.cc b/src/gbm/gbtree.cc index 4fe4d73a24ec..8eb75323d00a 100644 --- a/src/gbm/gbtree.cc +++ b/src/gbm/gbtree.cc @@ -1,5 +1,5 @@ /** - * Copyright 2014-2024, XGBoost Contributors + * Copyright 2014-2025, XGBoost Contributors * \file gbtree.cc * \brief gradient boosted tree implementation. * \author Tianqi Chen @@ -880,16 +880,6 @@ class Dart : public GBTree { } } - void PredictInstance(const SparsePage::Inst &inst, - std::vector *out_preds, - unsigned layer_begin, unsigned layer_end) override { - DropTrees(false); - auto &predictor = this->GetPredictor(false); - uint32_t _, tree_end; - std::tie(_, tree_end) = detail::LayerToTree(model_, layer_begin, layer_end); - predictor->PredictInstance(inst, out_preds, model_, tree_end); - } - void PredictContribution(DMatrix* p_fmat, HostDeviceVector* out_contribs, bst_layer_t layer_begin, bst_layer_t layer_end, bool approximate) override { diff --git a/src/gbm/gbtree.h b/src/gbm/gbtree.h index e8765d5c5447..1fbf0ebdaf7f 100644 --- a/src/gbm/gbtree.h +++ b/src/gbm/gbtree.h @@ -1,5 +1,5 @@ /** - * Copyright 2014-2024, XGBoost Contributors + * Copyright 2014-2025, XGBoost Contributors * \file gbtree.cc * \brief gradient boosted tree implementation. * \author Tianqi Chen @@ -287,13 +287,6 @@ class GBTree : public GradientBooster { } } - void PredictInstance(const SparsePage::Inst& inst, std::vector* out_preds, - uint32_t layer_begin, uint32_t layer_end) override { - std::uint32_t _, tree_end; - std::tie(_, tree_end) = detail::LayerToTree(model_, layer_begin, layer_end); - cpu_predictor_->PredictInstance(inst, out_preds, model_, tree_end); - } - void PredictLeaf(DMatrix* p_fmat, HostDeviceVector* out_preds, uint32_t layer_begin, uint32_t layer_end) override { diff --git a/src/metric/elementwise_metric.cu b/src/metric/elementwise_metric.cu index 24d84c23673e..bc4ab4d8e811 100644 --- a/src/metric/elementwise_metric.cu +++ b/src/metric/elementwise_metric.cu @@ -1,5 +1,5 @@ /** - * Copyright 2015-2024, XGBoost Contributors + * Copyright 2015-2025, XGBoost Contributors * \file elementwise_metric.cu * \brief evaluation metrics for elementwise binary or regression. * \author Kailong Chen, Tianqi Chen @@ -75,16 +75,27 @@ PackedReduceResult Reduce(Context const* ctx, MetaInfo const& info, Fn&& loss) { // for approximation in distributed setting. For rmse: // - sqrt(1/w(sum_t0 + sum_t1 + ... + sum_tm)) // multi-target // - sqrt(avg_t0) + sqrt(avg_t1) + ... 
sqrt(avg_tm) // distributed - common::ParallelFor(info.labels.Size(), ctx->Threads(), [&](size_t i) { + + auto size = info.labels.Size(); + auto const kBlockSize = 2048; + auto n_blocks = size / kBlockSize + 1; + + common::ParallelFor(n_blocks, n_threads, [&](auto block_idx) { + const size_t begin = block_idx * kBlockSize; + const size_t end = std::min(size, begin + kBlockSize); + + double sum_score = 0, sum_weight = 0; + for (std::size_t i = begin; i < end; ++i) { + auto [sample_id, target_id] = linalg::UnravelIndex(i, labels.Shape()); + + auto [v, wt] = loss(i, sample_id, target_id); + sum_score += v; + sum_weight += wt; + } + auto t_idx = omp_get_thread_num(); - size_t sample_id; - size_t target_id; - std::tie(sample_id, target_id) = linalg::UnravelIndex(i, labels.Shape()); - - float v, wt; - std::tie(v, wt) = loss(i, sample_id, target_id); - score_tloc[t_idx] += v; - weight_tloc[t_idx] += wt; + score_tloc[t_idx] += sum_score; + weight_tloc[t_idx] += sum_weight; }); double residue_sum = std::accumulate(score_tloc.cbegin(), score_tloc.cend(), 0.0); double weights_sum = std::accumulate(weight_tloc.cbegin(), weight_tloc.cend(), 0.0); diff --git a/src/predictor/cpu_predictor.cc b/src/predictor/cpu_predictor.cc index 4ac39d802a2b..2bb1f375caaa 100644 --- a/src/predictor/cpu_predictor.cc +++ b/src/predictor/cpu_predictor.cc @@ -1,5 +1,5 @@ /** - * Copyright 2017-2024, XGBoost Contributors + * Copyright 2017-2025, XGBoost Contributors */ #include // for max, fill, min #include // for any, any_cast @@ -14,7 +14,6 @@ #include "../collective/communicator-inl.h" // for Allreduce, IsDistributed #include "../collective/allreduce.h" #include "../common/bitfield.h" // for RBitField8 -#include "../common/categorical.h" // for IsCat, Decision #include "../common/common.h" // for DivRoundUp #include "../common/error_msg.h" // for InplacePredictProxy #include "../common/math.h" // for CheckNAN @@ -56,33 +55,9 @@ bst_node_t GetLeafIndex(RegTree const &tree, const RegTree::FVec &feat, return nidx; } -bst_float PredValue(const SparsePage::Inst &inst, - const std::vector> &trees, - const std::vector &tree_info, std::int32_t bst_group, - RegTree::FVec *p_feats, std::uint32_t tree_begin, std::uint32_t tree_end) { - bst_float psum = 0.0f; - p_feats->Fill(inst); - for (size_t i = tree_begin; i < tree_end; ++i) { - if (tree_info[i] == bst_group) { - auto const &tree = *trees[i]; - bool has_categorical = tree.HasCategoricalSplit(); - auto cats = tree.GetCategoriesMatrix(); - bst_node_t nidx = -1; - if (has_categorical) { - nidx = GetLeafIndex(tree, *p_feats, cats); - } else { - nidx = GetLeafIndex(tree, *p_feats, cats); - } - psum += (*trees[i])[nidx].LeafValue(); - } - } - p_feats->Drop(); - return psum; -} - template -bst_float PredValueByOneTree(const RegTree::FVec &p_feats, RegTree const &tree, - RegTree::CategoricalSplitMatrix const &cats) { +[[nodiscard]] float PredValueByOneTree(const RegTree::FVec &p_feats, RegTree const &tree, + RegTree::CategoricalSplitMatrix const &cats) noexcept(true) { const bst_node_t leaf = p_feats.HasMissing() ? 
GetLeafIndex(tree, p_feats, cats) : GetLeafIndex(tree, p_feats, cats); @@ -96,7 +71,7 @@ bst_node_t GetLeafIndex(MultiTargetTree const &tree, const RegTree::FVec &feat, RegTree::CategoricalSplitMatrix const &cats) { bst_node_t nidx{0}; while (!tree.IsLeaf(nidx)) { - unsigned split_index = tree.SplitIndex(nidx); + bst_feature_t split_index = tree.SplitIndex(nidx); auto fvalue = feat.GetFvalue(split_index); nidx = GetNextNodeMulti( tree, nidx, fvalue, has_missing && feat.IsMissing(split_index), cats); @@ -161,15 +136,17 @@ void PredictByAllTrees(gbm::GBTreeModel const &model, std::uint32_t const tree_b } template -void FVecFill(const size_t block_size, const size_t batch_offset, const int num_feature, - DataView *batch, const size_t fvec_offset, std::vector *p_feats) { - for (size_t i = 0; i < block_size; ++i) { - RegTree::FVec &feats = (*p_feats)[fvec_offset + i]; +void FVecFill(std::size_t const block_size, std::size_t const batch_offset, + bst_feature_t n_features, DataView *p_batch, std::size_t const fvec_offset, + std::vector *p_feats) { + auto &feats_vec = *p_feats; + auto &batch = *p_batch; + for (std::size_t i = 0; i < block_size; ++i) { + RegTree::FVec &feats = feats_vec[fvec_offset + i]; if (feats.Size() == 0) { - feats.Init(num_feature); + feats.Init(n_features); } - const SparsePage::Inst inst = (*batch)[batch_offset + i]; - feats.Fill(inst); + batch.Fill(batch_offset + i, &feats); } } @@ -181,115 +158,117 @@ void FVecDrop(std::size_t const block_size, std::size_t const fvec_offset, } } -static std::size_t constexpr kUnroll = 8; +// Convert a single sample in batch view to FVec +template +struct DataToFeatVec { + void Fill(bst_idx_t ridx, RegTree::FVec *p_feats) const { + auto &feats = *p_feats; + auto n_valid = static_cast(this)->DoFill(ridx, feats.Data().data()); + feats.HasMissing(n_valid != feats.Size()); + } +}; -struct SparsePageView { +struct SparsePageView : public DataToFeatVec { bst_idx_t base_rowid; HostSparsePageView view; explicit SparsePageView(SparsePage const *p) : base_rowid{p->base_rowid} { view = p->GetView(); } - SparsePage::Inst operator[](size_t i) { return view[i]; } - [[nodiscard]] size_t Size() const { return view.Size(); } -}; + [[nodiscard]] std::size_t Size() const { return view.Size(); } -struct SingleInstanceView { - bst_idx_t base_rowid{}; - SparsePage::Inst const &inst; + [[nodiscard]] bst_idx_t DoFill(bst_idx_t ridx, float *out) const { + auto p_data = view[ridx].data(); - explicit SingleInstanceView(SparsePage::Inst const &instance) : inst{instance} {} - SparsePage::Inst operator[](size_t) { return inst; } - static size_t Size() { return 1; } + for (std::size_t i = 0, n = view[ridx].size(); i < n; ++i) { + auto const &entry = p_data[i]; + out[entry.index] = entry.fvalue; + } + + return view[ridx].size(); + } }; -struct GHistIndexMatrixView { +struct GHistIndexMatrixView : public DataToFeatVec { private: GHistIndexMatrix const &page_; - std::uint64_t const n_features_; common::Span ft_; - common::Span workspace_; - std::vector current_unroll_; - std::vector const& ptrs_; - std::vector const& mins_; - std::vector const& values_; + std::vector const &ptrs_; + std::vector const &mins_; + std::vector const &values_; public: - bst_idx_t base_rowid; + bst_idx_t const base_rowid; public: - GHistIndexMatrixView(GHistIndexMatrix const &_page, uint64_t n_feat, - common::Span ft, common::Span workplace, - int32_t n_threads) + GHistIndexMatrixView(GHistIndexMatrix const &_page, common::Span ft) : page_{_page}, - n_features_{n_feat}, ft_{ft}, - 
workspace_{workplace}, - current_unroll_(n_threads > 0 ? n_threads : 1, 0), ptrs_{_page.cut.Ptrs()}, mins_{_page.cut.MinValues()}, values_{_page.cut.Values()}, base_rowid{_page.base_rowid} {} - SparsePage::Inst operator[](size_t r) { - r += base_rowid; - auto t = omp_get_thread_num(); - auto const beg = (n_features_ * kUnroll * t) + (current_unroll_[t] * n_features_); - size_t non_missing{static_cast(beg)}; - - auto ws = workspace_.data(); - for (bst_feature_t c = 0; c < n_features_; ++c) { - float f = page_.GetFvalue(ptrs_, values_, mins_, r, c, common::IsCat(ft_, c)); - if (!common::CheckNAN(f)) { - ws[non_missing] = Entry{c, f}; - ++non_missing; + [[nodiscard]] bst_idx_t DoFill(bst_idx_t ridx, float* out) const { + auto gridx = ridx + this->base_rowid; + auto n_features = page_.Features(); + + bst_idx_t n_non_missings = 0; + if (page_.IsDense()) { + common::DispatchBinType(page_.index.GetBinTypeSize(), [&](auto t) { + using T = decltype(t); + auto ptr = page_.index.data(); + auto rbeg = page_.row_ptr[ridx]; + for (bst_feature_t fidx = 0; fidx < n_features; ++fidx) { + bst_bin_t bin_idx; + float fvalue; + if (common::IsCat(ft_, fidx)) { + bin_idx = page_.GetGindex(gridx, fidx); + fvalue = this->values_[bin_idx]; + } else { + bin_idx = ptr[rbeg + fidx] + page_.index.Offset()[fidx]; + fvalue = + common::HistogramCuts::NumericBinValue(this->ptrs_, values_, mins_, fidx, bin_idx); + } + out[fidx] = fvalue; + } + }); + n_non_missings += n_features; + } else { + for (bst_feature_t fidx = 0; fidx < n_features; ++fidx) { + float f = page_.GetFvalue(ptrs_, values_, mins_, gridx, fidx, common::IsCat(ft_, fidx)); + if (!common::CheckNAN(f)) { + out[fidx] = f; + n_non_missings++; + } } } - - auto ret = workspace_.subspan(beg, non_missing - beg); - current_unroll_[t]++; - if (current_unroll_[t] == kUnroll) { - current_unroll_[t] = 0; - } - return ret; + return n_non_missings; } - [[nodiscard]] size_t Size() const { return page_.Size(); } + + [[nodiscard]] auto Size() const { return page_.Size(); } }; template -class AdapterView { - Adapter* adapter_; +class AdapterView : public DataToFeatVec> { + Adapter const *adapter_; float missing_; - common::Span workspace_; - std::vector current_unroll_; public: - explicit AdapterView(Adapter *adapter, float missing, common::Span workplace, - int32_t n_threads) - : adapter_{adapter}, - missing_{missing}, - workspace_{workplace}, - current_unroll_(n_threads > 0 ? 
n_threads : 1, 0) {} - SparsePage::Inst operator[](size_t i) { - bst_feature_t columns = adapter_->NumColumns(); + explicit AdapterView(Adapter const *adapter, float missing) + : adapter_{adapter}, missing_{missing} {} + + [[nodiscard]] bst_idx_t DoFill(bst_idx_t ridx, float *out) const { auto const &batch = adapter_->Value(); - auto row = batch.GetLine(i); - auto t = omp_get_thread_num(); - auto const beg = (columns * kUnroll * t) + (current_unroll_[t] * columns); - size_t non_missing {beg}; + auto row = batch.GetLine(ridx); + bst_idx_t n_non_missings = 0; for (size_t c = 0; c < row.Size(); ++c) { auto e = row.GetElement(c); if (missing_ != e.value && !common::CheckNAN(e.value)) { - workspace_[non_missing] = - Entry{static_cast(e.column_idx), e.value}; - ++non_missing; + out[e.column_idx] = e.value; + n_non_missings++; } } - auto ret = workspace_.subspan(beg, non_missing - beg); - current_unroll_[t]++; - if (current_unroll_[t] == kUnroll) { - current_unroll_[t] = 0; - } - return ret; + return n_non_missings; } [[nodiscard]] size_t Size() const { return adapter_->NumRows(); } @@ -297,24 +276,26 @@ class AdapterView { bst_idx_t const static base_rowid = 0; // NOLINT }; -template +template void PredictBatchByBlockOfRowsKernel(DataView batch, gbm::GBTreeModel const &model, - std::uint32_t tree_begin, std::uint32_t tree_end, - std::vector *p_thread_temp, int32_t n_threads, + bst_tree_t tree_begin, bst_tree_t tree_end, + std::vector *p_thread_temp, + std::int32_t n_threads, linalg::TensorView out_predt) { auto &thread_temp = *p_thread_temp; - // parallel over local batch - const auto nsize = static_cast(batch.Size()); - const int num_feature = model.learner_model_param->num_feature; - omp_ulong n_blocks = common::DivRoundUp(nsize, block_of_rows_size); + // Parallel over local batches + auto const n_samples = batch.Size(); + auto const n_features = model.learner_model_param->num_feature; + auto const n_blocks = common::DivRoundUp(n_samples, kBlockOfRowsSize); - common::ParallelFor(n_blocks, n_threads, [&](bst_omp_uint block_id) { - const size_t batch_offset = block_id * block_of_rows_size; - const size_t block_size = std::min(nsize - batch_offset, block_of_rows_size); - const size_t fvec_offset = omp_get_thread_num() * block_of_rows_size; + common::ParallelFor(n_blocks, n_threads, [&](auto block_id) { + auto const batch_offset = block_id * kBlockOfRowsSize; + auto const block_size = + std::min(static_cast(n_samples - batch_offset), kBlockOfRowsSize); + auto const fvec_offset = omp_get_thread_num() * kBlockOfRowsSize; - FVecFill(block_size, batch_offset, num_feature, &batch, fvec_offset, p_thread_temp); + FVecFill(block_size, batch_offset, n_features, &batch, fvec_offset, p_thread_temp); // process block of rows through all trees to keep cache locality PredictByAllTrees(model, tree_begin, tree_end, batch_offset + batch.base_rowid, thread_temp, fvec_offset, block_size, out_predt); @@ -340,11 +321,11 @@ float FillNodeMeanValues(RegTree const *tree, bst_node_t nidx, std::vector* mean_values) { - size_t num_nodes = tree->NumNodes(); - if (mean_values->size() == num_nodes) { + auto n_nodes = tree->NumNodes(); + if (static_cast(mean_values->size()) == n_nodes) { return; } - mean_values->resize(num_nodes); + mean_values->resize(n_nodes); FillNodeMeanValues(tree, 0, mean_values); } @@ -421,14 +402,6 @@ class ColumnSplitHelper { } } - void PredictInstance(Context const *ctx, SparsePage::Inst const &inst, - std::vector *out_preds) { - CHECK(xgboost::collective::IsDistributed()) - << "column-split 
prediction is only supported for distributed training"; - - PredictBatchKernel(ctx, SingleInstanceView{inst}, out_preds); - } - void PredictLeaf(Context const* ctx, DMatrix *p_fmat, std::vector *out_preds) { CHECK(xgboost::collective::IsDistributed()) << "column-split prediction is only supported for distributed training"; @@ -650,8 +623,8 @@ class ColumnSplitHelper { class CPUPredictor : public Predictor { protected: - void PredictDMatrix(DMatrix *p_fmat, std::vector *out_preds, - gbm::GBTreeModel const &model, int32_t tree_begin, int32_t tree_end) const { + void PredictDMatrix(DMatrix *p_fmat, std::vector *out_preds, gbm::GBTreeModel const &model, + bst_tree_t tree_begin, bst_tree_t tree_end) const { if (p_fmat->Info().IsColumnSplit()) { CHECK(!model.learner_model_param->IsVectorLeaf()) << "Predict DMatrix with column split" << MTNotImplemented(); @@ -677,17 +650,16 @@ class CPUPredictor : public Predictor { auto out_predt = linalg::MakeTensorView(ctx_, *out_preds, n_samples, n_groups); if (!p_fmat->PageExists()) { - std::vector workspace(p_fmat->Info().num_col_ * kUnroll * n_threads); auto ft = p_fmat->Info().feature_types.ConstHostVector(); for (auto const &batch : p_fmat->GetBatches(ctx_, {})) { if (blocked) { PredictBatchByBlockOfRowsKernel( - GHistIndexMatrixView{batch, p_fmat->Info().num_col_, ft, workspace, n_threads}, model, - tree_begin, tree_end, &feat_vecs, n_threads, out_predt); + GHistIndexMatrixView{batch, ft}, model, tree_begin, tree_end, &feat_vecs, n_threads, + out_predt); } else { PredictBatchByBlockOfRowsKernel( - GHistIndexMatrixView{batch, p_fmat->Info().num_col_, ft, workspace, n_threads}, model, - tree_begin, tree_end, &feat_vecs, n_threads, out_predt); + GHistIndexMatrixView{batch, ft}, model, tree_begin, tree_end, &feat_vecs, n_threads, + out_predt); } } } else { @@ -707,14 +679,11 @@ class CPUPredictor : public Predictor { } template - void PredictContributionKernel(DataView batch, const MetaInfo& info, - const gbm::GBTreeModel& model, - const std::vector* tree_weights, - std::vector>* mean_values, - std::vector* feat_vecs, - std::vector* contribs, uint32_t ntree_limit, - bool approximate, int condition, - unsigned condition_feature) const { + void PredictContributionKernel( + DataView batch, const MetaInfo &info, const gbm::GBTreeModel &model, + const std::vector *tree_weights, std::vector> *mean_values, + std::vector *feat_vecs, std::vector *contribs, + bst_tree_t ntree_limit, bool approximate, int condition, unsigned condition_feature) const { const int num_feature = model.learner_model_param->num_feature; const int ngroup = model.learner_model_param->num_output_group; CHECK_NE(ngroup, 0); @@ -734,10 +703,10 @@ class CPUPredictor : public Predictor { std::vector this_tree_contribs(ncolumns); // loop over all classes for (int gid = 0; gid < ngroup; ++gid) { - bst_float* p_contribs = &(*contribs)[(row_idx * ngroup + gid) * ncolumns]; - feats.Fill(batch[i]); + bst_float *p_contribs = &(*contribs)[(row_idx * ngroup + gid) * ncolumns]; + batch.Fill(i, &feats); // calculate contributions - for (unsigned j = 0; j < ntree_limit; ++j) { + for (bst_tree_t j = 0; j < ntree_limit; ++j) { auto *tree_mean_values = &mean_values->at(j); std::fill(this_tree_contribs.begin(), this_tree_contribs.end(), 0); if (model.tree_info[j] != gid) { @@ -771,8 +740,8 @@ class CPUPredictor : public Predictor { public: explicit CPUPredictor(Context const *ctx) : Predictor::Predictor{ctx} {} - void PredictBatch(DMatrix *dmat, PredictionCacheEntry *predts, const gbm::GBTreeModel &model, - 
uint32_t tree_begin, uint32_t tree_end = 0) const override { + void PredictBatch(DMatrix *dmat, PredictionCacheEntry *predts, gbm::GBTreeModel const &model, + bst_tree_t tree_begin, bst_tree_t tree_end = 0) const override { auto *out_preds = &predts->predictions; // This is actually already handled in gbm, but large amount of tests rely on the // behaviour. @@ -785,8 +754,8 @@ class CPUPredictor : public Predictor { template void DispatchedInplacePredict(std::any const &x, std::shared_ptr p_m, const gbm::GBTreeModel &model, float missing, - PredictionCacheEntry *out_preds, uint32_t tree_begin, - uint32_t tree_end) const { + PredictionCacheEntry *out_preds, bst_tree_t tree_begin, + bst_tree_t tree_end) const { auto const n_threads = this->ctx_->Threads(); auto m = std::any_cast>(x); CHECK_EQ(m->NumColumns(), model.learner_model_param->num_feature) @@ -796,20 +765,19 @@ class CPUPredictor : public Predictor { CHECK_EQ(p_m->Info().num_col_, m->NumColumns()); this->InitOutPredictions(p_m->Info(), &(out_preds->predictions), model); - std::vector workspace(m->NumColumns() * kUnroll * n_threads); auto &predictions = out_preds->predictions.HostVector(); std::vector thread_temp; InitThreadTemp(n_threads * kBlockSize, &thread_temp); std::size_t n_groups = model.learner_model_param->OutputLength(); auto out_predt = linalg::MakeTensorView(ctx_, predictions, m->NumRows(), n_groups); PredictBatchByBlockOfRowsKernel, kBlockSize>( - AdapterView(m.get(), missing, common::Span{workspace}, n_threads), model, - tree_begin, tree_end, &thread_temp, n_threads, out_predt); + AdapterView(m.get(), missing), model, tree_begin, tree_end, &thread_temp, + n_threads, out_predt); } bool InplacePredict(std::shared_ptr p_m, const gbm::GBTreeModel &model, float missing, - PredictionCacheEntry *out_preds, uint32_t tree_begin, - unsigned tree_end) const override { + PredictionCacheEntry *out_preds, bst_tree_t tree_begin, + bst_tree_t tree_end) const override { auto proxy = dynamic_cast(p_m.get()); CHECK(proxy)<< error::InplacePredictProxy(); CHECK(!p_m->Info().IsColumnSplit()) @@ -836,44 +804,11 @@ class CPUPredictor : public Predictor { return true; } - void PredictInstance(const SparsePage::Inst &inst, std::vector *out_preds, - const gbm::GBTreeModel &model, unsigned ntree_limit, - bool is_column_split) const override { - CHECK(!model.learner_model_param->IsVectorLeaf()) << "predict instance" << MTNotImplemented(); - ntree_limit *= model.learner_model_param->num_output_group; - if (ntree_limit == 0 || ntree_limit > model.trees.size()) { - ntree_limit = static_cast(model.trees.size()); - } - out_preds->resize(model.learner_model_param->num_output_group); - - if (is_column_split) { - CHECK(!model.learner_model_param->IsVectorLeaf()) - << "Predict instance with column split" << MTNotImplemented(); - - ColumnSplitHelper helper(this->ctx_->Threads(), model, 0, ntree_limit); - helper.PredictInstance(ctx_, inst, out_preds); - return; - } - - std::vector feat_vecs; - feat_vecs.resize(1, RegTree::FVec()); - feat_vecs[0].Init(model.learner_model_param->num_feature); - auto base_score = model.learner_model_param->BaseScore(ctx_)(0); - // loop over output groups - for (uint32_t gid = 0; gid < model.learner_model_param->num_output_group; ++gid) { - (*out_preds)[gid] = scalar::PredValue(inst, model.trees, model.tree_info, gid, &feat_vecs[0], - 0, ntree_limit) + - base_score; - } - } - - void PredictLeaf(DMatrix *p_fmat, HostDeviceVector *out_preds, - const gbm::GBTreeModel &model, unsigned ntree_limit) const override { + void 
PredictLeaf(DMatrix *p_fmat, HostDeviceVector *out_preds, + gbm::GBTreeModel const &model, bst_tree_t ntree_limit) const override { auto const n_threads = this->ctx_->Threads(); // number of valid trees - if (ntree_limit == 0 || ntree_limit > model.trees.size()) { - ntree_limit = static_cast(model.trees.size()); - } + ntree_limit = GetTreeLimit(model.trees, ntree_limit); const MetaInfo &info = p_fmat->Info(); std::vector &preds = out_preds->HostVector(); preds.resize(info.num_row_ * ntree_limit); @@ -902,7 +837,7 @@ class CPUPredictor : public Predictor { feats.Init(num_feature); } feats.Fill(page[i]); - for (std::uint32_t j = 0; j < ntree_limit; ++j) { + for (bst_tree_t j = 0; j < ntree_limit; ++j) { auto const &tree = *model.trees[j]; auto const &cats = tree.GetCategoriesMatrix(); bst_node_t nidx; @@ -919,7 +854,7 @@ class CPUPredictor : public Predictor { } void PredictContribution(DMatrix *p_fmat, HostDeviceVector *out_contribs, - const gbm::GBTreeModel &model, uint32_t ntree_limit, + const gbm::GBTreeModel &model, bst_tree_t ntree_limit, std::vector const *tree_weights, bool approximate, int condition, unsigned condition_feature) const override { CHECK(!model.learner_model_param->IsVectorLeaf()) @@ -931,9 +866,7 @@ class CPUPredictor : public Predictor { InitThreadTemp(n_threads, &feat_vecs); const MetaInfo& info = p_fmat->Info(); // number of valid trees - if (ntree_limit == 0 || ntree_limit > model.trees.size()) { - ntree_limit = static_cast(model.trees.size()); - } + ntree_limit = GetTreeLimit(model.trees, ntree_limit); size_t const ncolumns = model.learner_model_param->num_feature + 1; // allocate space for (number of features + bias) times the number of rows std::vector& contribs = out_contribs->HostVector(); @@ -948,13 +881,11 @@ class CPUPredictor : public Predictor { }); // start collecting the contributions if (!p_fmat->PageExists()) { - std::vector workspace(info.num_col_ * kUnroll * n_threads); auto ft = p_fmat->Info().feature_types.ConstHostVector(); for (const auto &batch : p_fmat->GetBatches(ctx_, {})) { - PredictContributionKernel( - GHistIndexMatrixView{batch, info.num_col_, ft, workspace, n_threads}, - info, model, tree_weights, &mean_values, &feat_vecs, &contribs, ntree_limit, - approximate, condition, condition_feature); + PredictContributionKernel(GHistIndexMatrixView{batch, ft}, info, model, tree_weights, + &mean_values, &feat_vecs, &contribs, ntree_limit, approximate, + condition, condition_feature); } } else { for (const auto &batch : p_fmat->GetBatches()) { @@ -965,9 +896,9 @@ class CPUPredictor : public Predictor { } } - void PredictInteractionContributions(DMatrix *p_fmat, HostDeviceVector *out_contribs, - const gbm::GBTreeModel &model, unsigned ntree_limit, - std::vector const *tree_weights, + void PredictInteractionContributions(DMatrix *p_fmat, HostDeviceVector *out_contribs, + gbm::GBTreeModel const &model, bst_tree_t ntree_limit, + std::vector const *tree_weights, bool approximate) const override { CHECK(!model.learner_model_param->IsVectorLeaf()) << "Predict interaction contribution" << MTNotImplemented(); diff --git a/src/predictor/gpu_predictor.cu b/src/predictor/gpu_predictor.cu index 7cb472a4ef4a..d99f00cd35a4 100644 --- a/src/predictor/gpu_predictor.cu +++ b/src/predictor/gpu_predictor.cu @@ -1,5 +1,5 @@ /** - * Copyright 2017-2024, XGBoost Contributors + * Copyright 2017-2025, XGBoost Contributors */ #include #include @@ -39,7 +39,7 @@ struct TreeView { common::Span d_tree; XGBOOST_DEVICE - TreeView(size_t tree_begin, size_t tree_idx, common::Span 
d_nodes, + TreeView(bst_tree_t tree_begin, bst_tree_t tree_idx, common::Span d_nodes, common::Span d_tree_segments, common::Span d_tree_split_types, common::Span d_cat_tree_segments, @@ -252,7 +252,7 @@ PredictLeafKernel(Data data, common::Span d_nodes, common::Span d_cat_node_segments, common::Span d_categories, - size_t tree_begin, size_t tree_end, bst_feature_t num_features, + bst_tree_t tree_begin, bst_tree_t tree_end, bst_feature_t num_features, size_t num_rows, bool use_shared, float missing) { bst_idx_t ridx = blockDim.x * blockIdx.x + threadIdx.x; @@ -260,7 +260,7 @@ PredictLeafKernel(Data data, common::Span d_nodes, return; } Loader loader{data, use_shared, num_features, num_rows, missing}; - for (size_t tree_idx = tree_begin; tree_idx < tree_end; ++tree_idx) { + for (bst_tree_t tree_idx = tree_begin; tree_idx < tree_end; ++tree_idx) { TreeView d_tree{ tree_begin, tree_idx, d_nodes, d_tree_segments, d_tree_split_types, d_cat_tree_segments, @@ -285,8 +285,8 @@ PredictKernel(Data data, common::Span d_nodes, common::Span d_tree_split_types, common::Span d_cat_tree_segments, common::Span d_cat_node_segments, - common::Span d_categories, size_t tree_begin, - size_t tree_end, size_t num_features, size_t num_rows, + common::Span d_categories, bst_tree_t tree_begin, + bst_tree_t tree_end, size_t num_features, size_t num_rows, bool use_shared, int num_group, float missing) { bst_uint global_idx = blockDim.x * blockIdx.x + threadIdx.x; Loader loader(data, use_shared, num_features, num_rows, missing); @@ -294,7 +294,7 @@ PredictKernel(Data data, common::Span d_nodes, if (num_group == 1) { float sum = 0; - for (size_t tree_idx = tree_begin; tree_idx < tree_end; tree_idx++) { + for (bst_tree_t tree_idx = tree_begin; tree_idx < tree_end; tree_idx++) { TreeView d_tree{ tree_begin, tree_idx, d_nodes, d_tree_segments, d_tree_split_types, d_cat_tree_segments, @@ -304,7 +304,7 @@ PredictKernel(Data data, common::Span d_nodes, } d_out_predictions[global_idx] += sum; } else { - for (size_t tree_idx = tree_begin; tree_idx < tree_end; tree_idx++) { + for (bst_tree_t tree_idx = tree_begin; tree_idx < tree_end; tree_idx++) { int tree_group = d_tree_group[tree_idx]; TreeView d_tree{ tree_begin, tree_idx, d_nodes, @@ -474,7 +474,7 @@ struct ShapSplitCondition { struct PathInfo { int64_t leaf_position; // -1 not a leaf size_t length; - size_t tree_idx; + bst_tree_t tree_idx; }; // Transform model into path element form for GPUTreeShap @@ -494,8 +494,7 @@ void ExtractPaths(Context const* ctx, if (!n.IsLeaf() || n.IsDeleted()) { return PathInfo{-1, 0, 0}; } - size_t tree_idx = - dh::SegmentId(d_tree_segments.begin(), d_tree_segments.end(), idx); + bst_tree_t tree_idx = dh::SegmentId(d_tree_segments.begin(), d_tree_segments.end(), idx); size_t tree_offset = d_tree_segments[tree_idx]; size_t path_length = 1; while (!n.IsRoot()) { @@ -622,7 +621,7 @@ __global__ void MaskBitVectorKernel( common::Span d_cat_tree_segments, common::Span d_cat_node_segments, common::Span d_categories, BitVector decision_bits, BitVector missing_bits, - std::size_t tree_begin, std::size_t tree_end, bst_feature_t num_features, std::size_t num_rows, + bst_tree_t tree_begin, bst_tree_t tree_end, bst_feature_t num_features, std::size_t num_rows, std::size_t num_nodes, bool use_shared, float missing) { // This needs to be always instantiated since the data is loaded cooperatively by all threads. 
SparsePageLoader loader{data, use_shared, num_features, num_rows, missing}; @@ -695,7 +694,7 @@ __global__ void PredictByBitVectorKernel( common::Span d_cat_tree_segments, common::Span d_cat_node_segments, common::Span d_categories, BitVector decision_bits, BitVector missing_bits, - std::size_t tree_begin, std::size_t tree_end, std::size_t num_rows, std::size_t num_nodes, + bst_tree_t tree_begin, bst_tree_t tree_end, std::size_t num_rows, std::size_t num_nodes, std::uint32_t num_group) { auto const row_idx = blockIdx.x * blockDim.x + threadIdx.x; if (row_idx >= num_rows) { @@ -704,7 +703,7 @@ __global__ void PredictByBitVectorKernel( std::size_t tree_offset = 0; if constexpr (predict_leaf) { - for (size_t tree_idx = tree_begin; tree_idx < tree_end; ++tree_idx) { + for (auto tree_idx = tree_begin; tree_idx < tree_end; ++tree_idx) { TreeView d_tree{tree_begin, tree_idx, d_nodes, d_tree_segments, d_tree_split_types, d_cat_tree_segments, d_cat_node_segments, d_categories}; @@ -942,9 +941,8 @@ class GPUPredictor : public xgboost::Predictor { } } - void PredictBatch(DMatrix* dmat, PredictionCacheEntry* predts, - const gbm::GBTreeModel& model, uint32_t tree_begin, - uint32_t tree_end = 0) const override { + void PredictBatch(DMatrix* dmat, PredictionCacheEntry* predts, const gbm::GBTreeModel& model, + bst_tree_t tree_begin, bst_tree_t tree_end = 0) const override { CHECK(ctx_->Device().IsCUDA()) << "Set `device' to `cuda` for processing GPU data."; auto* out_preds = &predts->predictions; if (tree_end == 0) { @@ -956,8 +954,8 @@ class GPUPredictor : public xgboost::Predictor { template void DispatchedInplacePredict(std::any const& x, std::shared_ptr p_m, const gbm::GBTreeModel& model, float missing, - PredictionCacheEntry* out_preds, uint32_t tree_begin, - uint32_t tree_end) const { + PredictionCacheEntry* out_preds, bst_tree_t tree_begin, + bst_tree_t tree_end) const { uint32_t const output_groups = model.learner_model_param->num_output_group; auto m = std::any_cast>(x); @@ -996,9 +994,9 @@ class GPUPredictor : public xgboost::Predictor { tree_begin, tree_end, m->NumColumns(), m->NumRows(), use_shared, output_groups, missing); } - bool InplacePredict(std::shared_ptr p_m, const gbm::GBTreeModel& model, float missing, - PredictionCacheEntry* out_preds, uint32_t tree_begin, - unsigned tree_end) const override { + bool InplacePredict(std::shared_ptr p_m, gbm::GBTreeModel const& model, float missing, + PredictionCacheEntry* out_preds, bst_tree_t tree_begin, + bst_tree_t tree_end) const override { auto proxy = dynamic_cast(p_m.get()); CHECK(proxy) << error::InplacePredictProxy(); auto x = proxy->Adapter(); @@ -1016,11 +1014,9 @@ class GPUPredictor : public xgboost::Predictor { return true; } - void PredictContribution(DMatrix* p_fmat, - HostDeviceVector* out_contribs, - const gbm::GBTreeModel& model, unsigned tree_end, - std::vector const* tree_weights, - bool approximate, int, + void PredictContribution(DMatrix* p_fmat, HostDeviceVector* out_contribs, + const gbm::GBTreeModel& model, bst_tree_t tree_end, + std::vector const* tree_weights, bool approximate, int, unsigned) const override { std::string not_implemented{ "contribution is not implemented in the GPU predictor, use CPU instead."}; @@ -1034,9 +1030,7 @@ class GPUPredictor : public xgboost::Predictor { << "Predict contribution support for column-wise data split is not yet implemented."; dh::safe_cuda(cudaSetDevice(ctx_->Ordinal())); out_contribs->SetDevice(ctx_->Device()); - if (tree_end == 0 || tree_end > model.trees.size()) { - tree_end = 
static_cast(model.trees.size()); - } + tree_end = GetTreeLimit(model.trees, tree_end); const int ngroup = model.learner_model_param->num_output_group; CHECK_NE(ngroup, 0); @@ -1087,11 +1081,9 @@ class GPUPredictor : public xgboost::Predictor { }); } - void PredictInteractionContributions(DMatrix* p_fmat, - HostDeviceVector* out_contribs, - const gbm::GBTreeModel& model, - unsigned tree_end, - std::vector const* tree_weights, + void PredictInteractionContributions(DMatrix* p_fmat, HostDeviceVector* out_contribs, + gbm::GBTreeModel const& model, bst_tree_t tree_end, + std::vector const* tree_weights, bool approximate) const override { std::string not_implemented{"contribution is not implemented in GPU " "predictor, use `cpu_predictor` instead."}; @@ -1103,9 +1095,7 @@ class GPUPredictor : public xgboost::Predictor { } dh::safe_cuda(cudaSetDevice(ctx_->Ordinal())); out_contribs->SetDevice(ctx_->Device()); - if (tree_end == 0 || tree_end > model.trees.size()) { - tree_end = static_cast(model.trees.size()); - } + tree_end = GetTreeLimit(model.trees, tree_end); const int ngroup = model.learner_model_param->num_output_group; CHECK_NE(ngroup, 0); @@ -1162,24 +1152,14 @@ class GPUPredictor : public xgboost::Predictor { }); } - void PredictInstance(const SparsePage::Inst&, - std::vector*, - const gbm::GBTreeModel&, unsigned, bool) const override { - LOG(FATAL) << "[Internal error]: " << __func__ - << " is not implemented in GPU Predictor."; - } - - void PredictLeaf(DMatrix *p_fmat, HostDeviceVector *predictions, - const gbm::GBTreeModel &model, - unsigned tree_end) const override { + void PredictLeaf(DMatrix* p_fmat, HostDeviceVector* predictions, + gbm::GBTreeModel const& model, bst_tree_t tree_end) const override { dh::safe_cuda(cudaSetDevice(ctx_->Ordinal())); auto max_shared_memory_bytes = ConfigureDevice(ctx_->Device()); const MetaInfo& info = p_fmat->Info(); bst_idx_t num_rows = info.num_row_; - if (tree_end == 0 || tree_end > model.trees.size()) { - tree_end = static_cast(model.trees.size()); - } + tree_end = GetTreeLimit(model.trees, tree_end); predictions->SetDevice(ctx_->Device()); predictions->Resize(num_rows * tree_end); DeviceModel d_model; diff --git a/src/predictor/predict_fn.h b/src/predictor/predict_fn.h index 044832010ccb..e3be91d5fa3f 100644 --- a/src/predictor/predict_fn.h +++ b/src/predictor/predict_fn.h @@ -1,10 +1,14 @@ /** - * Copyright 2021-2023 by XGBoost Contributors + * Copyright 2021-2025, XGBoost Contributors */ #ifndef XGBOOST_PREDICTOR_PREDICT_FN_H_ #define XGBOOST_PREDICTOR_PREDICT_FN_H_ -#include "../common/categorical.h" -#include "xgboost/tree_model.h" + +#include // for unique_ptr +#include // for vector + +#include "../common/categorical.h" // for IsCat, Decision +#include "xgboost/tree_model.h" // for RegTree namespace xgboost::predictor { /** @brief Whether it should traverse to the left branch of a tree. 
*/ @@ -20,9 +24,9 @@ XGBOOST_DEVICE bool GetDecision(RegTree::Node const &node, bst_node_t nid, float } template -inline XGBOOST_DEVICE bst_node_t GetNextNode(const RegTree::Node &node, const bst_node_t nid, - float fvalue, bool is_missing, - RegTree::CategoricalSplitMatrix const &cats) { +XGBOOST_DEVICE bst_node_t GetNextNode(const RegTree::Node &node, const bst_node_t nid, float fvalue, + bool is_missing, + RegTree::CategoricalSplitMatrix const &cats) { if (has_missing && is_missing) { return node.DefaultChild(); } else { @@ -31,10 +35,9 @@ inline XGBOOST_DEVICE bst_node_t GetNextNode(const RegTree::Node &node, const bs } template -inline XGBOOST_DEVICE bst_node_t GetNextNodeMulti(MultiTargetTree const &tree, - bst_node_t const nidx, float fvalue, - bool is_missing, - RegTree::CategoricalSplitMatrix const &cats) { +XGBOOST_DEVICE bst_node_t GetNextNodeMulti(MultiTargetTree const &tree, bst_node_t const nidx, + float fvalue, bool is_missing, + RegTree::CategoricalSplitMatrix const &cats) { if (has_missing && is_missing) { return tree.DefaultChild(nidx); } else { @@ -49,5 +52,17 @@ inline XGBOOST_DEVICE bst_node_t GetNextNodeMulti(MultiTargetTree const &tree, } } +/** + * @brief Some old prediction methods accept the ntree_limit parameter and they use 0 to + * indicate no limit. + */ +inline bst_tree_t GetTreeLimit(std::vector> const &trees, + bst_tree_t ntree_limit) { + auto n_trees = static_cast(trees.size()); + if (ntree_limit == 0 || ntree_limit > n_trees) { + ntree_limit = n_trees; + } + return ntree_limit; +} } // namespace xgboost::predictor #endif // XGBOOST_PREDICTOR_PREDICT_FN_H_ diff --git a/tests/cpp/predictor/test_predictor.cc b/tests/cpp/predictor/test_predictor.cc index cbec1a8d6896..6e1f1301a530 100644 --- a/tests/cpp/predictor/test_predictor.cc +++ b/tests/cpp/predictor/test_predictor.cc @@ -1,5 +1,5 @@ /** - * Copyright 2020-2024, XGBoost Contributors + * Copyright 2020-2025, XGBoost Contributors */ #include "test_predictor.h" @@ -47,16 +47,6 @@ void TestBasic(DMatrix* dmat, Context const *ctx) { ASSERT_EQ(out_predictions_h[i], 1.5); } - // Test predict instance - auto const& batch = *dmat->GetBatches().begin(); - auto page = batch.GetView(); - for (size_t i = 0; i < batch.Size(); i++) { - std::vector instance_out_predictions; - predictor->PredictInstance(page[i], &instance_out_predictions, model, 0, - dmat->Info().IsColumnSplit()); - ASSERT_EQ(instance_out_predictions[0], 1.5); - } - // Test predict leaf HostDeviceVector leaf_out_predictions; predictor->PredictLeaf(dmat, &leaf_out_predictions, model); diff --git a/tests/python/test_quantile_dmatrix.py b/tests/python/test_quantile_dmatrix.py index e64212265212..cbfec64527e3 100644 --- a/tests/python/test_quantile_dmatrix.py +++ b/tests/python/test_quantile_dmatrix.py @@ -286,10 +286,10 @@ def run_ref_dmatrix(self, rng: Any, device: str, enable_cat: bool) -> None: def test_ref_quantile_cut(self) -> None: check_ref_quantile_cut("cpu") - def test_ref_dmatrix(self) -> None: + @pytest.mark.parametrize("enable_cat", [True, False]) + def test_ref_dmatrix(self, enable_cat: bool) -> None: rng = np.random.RandomState(1994) - self.run_ref_dmatrix(rng, "cpu", True) - self.run_ref_dmatrix(rng, "cpu", False) + self.run_ref_dmatrix(rng, "cpu", enable_cat) @pytest.mark.parametrize("sparsity", [0.0, 0.5]) def test_predict(self, sparsity: float) -> None:
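
---

For reference, the core behavioural change this patch makes to `RegTree::FVec` — replacing the old `Entry { float fvalue; int flag; }` union with a plain `float` buffer where missing values are encoded as NaN — can be summarised by the following standalone sketch. This is an illustration only, not the XGBoost sources: the class and helper names here (`DenseFeatureVector`, the pair-based `Fill`) are invented for the example.

```cpp
// Standalone sketch of the NaN-based dense feature vector introduced by this patch.
// Slots left untouched by Fill() stay NaN and are reported as missing via std::isnan.
#include <cmath>    // for std::isnan
#include <cstddef>  // for std::size_t
#include <iostream>
#include <limits>   // for std::numeric_limits
#include <utility>  // for std::pair
#include <vector>

class DenseFeatureVector {
 public:
  void Init(std::size_t n_features) {
    data_.assign(n_features, std::numeric_limits<float>::quiet_NaN());
    has_missing_ = true;
  }
  // Fill from (feature index, value) pairs of a sparse row; untouched slots stay NaN.
  void Fill(std::vector<std::pair<std::size_t, float>> const& row) {
    for (auto const& [idx, value] : row) {
      data_[idx] = value;
    }
    has_missing_ = row.size() != data_.size();
  }
  void Drop() { this->Init(data_.size()); }
  [[nodiscard]] bool IsMissing(std::size_t i) const { return std::isnan(data_[i]); }
  [[nodiscard]] bool HasMissing() const { return has_missing_; }
  [[nodiscard]] float GetFvalue(std::size_t i) const { return data_[i]; }

 private:
  std::vector<float> data_;
  bool has_missing_{true};
};

int main() {
  DenseFeatureVector fvec;
  fvec.Init(4);
  fvec.Fill({{0, 1.5f}, {2, -3.0f}});  // features 1 and 3 remain missing
  std::cout << std::boolalpha << fvec.IsMissing(1) << " " << fvec.GetFvalue(2) << "\n";
  return 0;
}
```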