diff --git a/plugin/sycl/data.h b/plugin/sycl/data.h
index 8f4bb2516f05..c2501d652cb2 100644
--- a/plugin/sycl/data.h
+++ b/plugin/sycl/data.h
@@ -139,6 +139,17 @@ class USMVector {
     }
   }
 
+  /* Resize without preserving the data; the contents are undefined after growth */
+  void ResizeNoCopy(::sycl::queue* qu, size_t size_new) {
+    if (size_new <= capacity_) {
+      size_ = size_new;
+    } else {
+      size_ = size_new;
+      capacity_ = size_new;
+      data_ = allocate_memory_(qu, size_);
+    }
+  }
+
   void Resize(::sycl::queue* qu, size_t size_new, T v) {
     if (size_new <= size_) {
       size_ = size_new;
diff --git a/plugin/sycl/tree/hist_updater.cc b/plugin/sycl/tree/hist_updater.cc
index daf043e4e055..7a53d8d1f447 100644
--- a/plugin/sycl/tree/hist_updater.cc
+++ b/plugin/sycl/tree/hist_updater.cc
@@ -7,6 +7,8 @@
 #include <oneapi/dpl/random>
 
+#include <functional>
+
 #include "../common/hist_util.h"
 #include "../../src/collective/allreduce.h"
 
@@ -14,6 +16,10 @@ namespace xgboost {
 namespace sycl {
 namespace tree {
 
+using ::sycl::ext::oneapi::plus;
+using ::sycl::ext::oneapi::minimum;
+using ::sycl::ext::oneapi::maximum;
+
 template <typename GradientSumT>
 void HistUpdater<GradientSumT>::SetHistSynchronizer(
     HistSynchronizer<GradientSumT> *sync) {
@@ -126,6 +132,10 @@ void HistUpdater<GradientSumT>::InitData(
   builder_monitor_.Start("InitData");
   const auto& info = fmat.Info();
 
+  if (!column_sampler_) {
+    column_sampler_ = xgboost::common::MakeColumnSampler(ctx_);
+  }
+
   // initialize the row set
   {
     row_set_collection_.Clear();
@@ -213,6 +223,9 @@ void HistUpdater<GradientSumT>::InitData(
     }
   }
 
+  column_sampler_->Init(ctx_, info.num_col_, info.feature_weights.ConstHostVector(),
+                        param_.colsample_bynode, param_.colsample_bylevel,
+                        param_.colsample_bytree);
   if (data_layout_ == kDenseDataZeroBased || data_layout_ == kDenseDataOneBased) {
     /* specialized code for dense data:
        choose the column that has a least positive number of discrete bins.
@@ -309,6 +322,148 @@ void HistUpdater<GradientSumT>::InitNewNode(int nid,
   builder_monitor_.Stop("InitNewNode");
 }
 
+// nodes_set - the set of nodes to be processed in parallel
+template <typename GradientSumT>
+void HistUpdater<GradientSumT>::EvaluateSplits(
+    const std::vector<ExpandEntry>& nodes_set,
+    const common::GHistIndexMatrix& gmat,
+    const RegTree& tree) {
+  builder_monitor_.Start("EvaluateSplits");
+
+  const size_t n_nodes_in_set = nodes_set.size();
+
+  using FeatureSetType = std::shared_ptr<HostDeviceVector<bst_feature_t>>;
+
+  // Generate a feature set for each tree node
+  size_t pos = 0;
+  for (size_t nid_in_set = 0; nid_in_set < n_nodes_in_set; ++nid_in_set) {
+    const bst_node_t nid = nodes_set[nid_in_set].nid;
+    FeatureSetType features_set = column_sampler_->GetFeatureSet(tree.GetDepth(nid));
+    for (size_t idx = 0; idx < features_set->Size(); idx++) {
+      const size_t fid = features_set->ConstHostVector()[idx];
+      if (interaction_constraints_.Query(nid, fid)) {
+        auto this_hist = hist_[nid].DataConst();
+        if (pos < split_queries_host_.size()) {
+          split_queries_host_[pos] = SplitQuery{nid, fid, this_hist};
+        } else {
+          split_queries_host_.push_back({nid, fid, this_hist});
+        }
+        ++pos;
+      }
+    }
+  }
+  const size_t total_features = pos;
+
+  split_queries_device_.Resize(&qu_, total_features);
+  auto event = qu_.memcpy(split_queries_device_.Data(), split_queries_host_.data(),
+                          total_features * sizeof(SplitQuery));
+
+  auto evaluator = tree_evaluator_.GetEvaluator();
+  SplitQuery* split_queries_device = split_queries_device_.Data();
+  const uint32_t* cut_ptr = gmat.cut_device.Ptrs().DataConst();
+  const bst_float* cut_val = gmat.cut_device.Values().DataConst();
+  const bst_float* cut_minval = gmat.cut_device.MinValues().DataConst();
+
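+  // Stage the per-node statistics in device memory, so the split kernel can read
+  // the parent gradient sums and root gain alongside the histograms.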
+  snode_device_.ResizeNoCopy(&qu_, snode_host_.size());
+  event = qu_.memcpy(snode_device_.Data(), snode_host_.data(),
+                     snode_host_.size() * sizeof(NodeEntry<GradientSumT>), event);
+  const NodeEntry<GradientSumT>* snode = snode_device_.Data();
+
+  const float min_child_weight = param_.min_child_weight;
+
+  best_splits_device_.ResizeNoCopy(&qu_, total_features);
+  if (best_splits_host_.size() < total_features) best_splits_host_.resize(total_features);
+  SplitEntry<GradientSumT>* best_splits = best_splits_device_.Data();
+
+  event = qu_.submit([&](::sycl::handler& cgh) {
+    cgh.depends_on(event);
+    cgh.parallel_for<>(::sycl::nd_range<2>(::sycl::range<2>(total_features, sub_group_size_),
+                                           ::sycl::range<2>(1, sub_group_size_)),
+                       [=](::sycl::nd_item<2> pid) {
+      int i = pid.get_global_id(0);
+      auto sg = pid.get_sub_group();
+      int nid = split_queries_device[i].nid;
+      int fid = split_queries_device[i].fid;
+      const GradientPairT* hist_data = split_queries_device[i].hist;
+
+      best_splits[i] = snode[nid].best;
+      EnumerateSplit(sg, cut_ptr, cut_val, hist_data, snode[nid],
+                     &(best_splits[i]), fid, nid, evaluator, min_child_weight);
+    });
+  });
+  event = qu_.memcpy(best_splits_host_.data(), best_splits,
+                     total_features * sizeof(SplitEntry<GradientSumT>), event);
+
+  qu_.wait();
+  for (size_t i = 0; i < total_features; i++) {
+    int nid = split_queries_host_[i].nid;
+    snode_host_[nid].best.Update(best_splits_host_[i]);
+  }
+
+  builder_monitor_.Stop("EvaluateSplits");
+}
+
+// Enumerate the split values of a specific feature fid
+// and update *p_best if a better split is found.
+template <typename GradientSumT>
+void HistUpdater<GradientSumT>::EnumerateSplit(
+    const ::sycl::sub_group& sg,
+    const uint32_t* cut_ptr,
+    const bst_float* cut_val,
+    const GradientPairT* hist_data,
+    const NodeEntry<GradientSumT>& snode,
+    SplitEntry<GradientSumT>* p_best,
+    bst_uint fid,
+    bst_uint nodeID,
+    typename TreeEvaluator<GradientSumT>::SplitEvaluator const &evaluator,
+    float min_child_weight) {
+  SplitEntry<GradientSumT> best;
+
+  int32_t ibegin = static_cast<int32_t>(cut_ptr[fid]);
+  int32_t iend = static_cast<int32_t>(cut_ptr[fid + 1]);
+
+  GradStats<GradientSumT> sum(0, 0);
+
+  int32_t sub_group_size = sg.get_local_range().size();
+  const size_t local_id = sg.get_local_id()[0];
+
+  /* TODO(razdoburdin)
+   * Currently the first additions are fast and the last ones are slow.
+   * Computing the reduction over the group in a separate kernel
+   * and reusing it here may be faster.
+   */
+  for (int32_t i = ibegin + local_id; i < iend; i += sub_group_size) {
+    sum.Add(::sycl::inclusive_scan_over_group(sg, hist_data[i].GetGrad(), std::plus<>()),
+            ::sycl::inclusive_scan_over_group(sg, hist_data[i].GetHess(), std::plus<>()));
+
+    if (sum.GetHess() >= min_child_weight) {
+      GradStats<GradientSumT> c = snode.stats - sum;
+      if (c.GetHess() >= min_child_weight) {
+        bst_float loss_chg = evaluator.CalcSplitGain(nodeID, fid, sum, c) - snode.root_gain;
+        bst_float split_pt = cut_val[i];
+        best.Update(loss_chg, fid, split_pt, false, sum, c);
+      }
+    }
+
+    const bool last_iter = i + sub_group_size >= iend;
+    if (!last_iter) {
+      size_t end = i - local_id + sub_group_size;
+      if (end > iend) end = iend;
+      for (size_t j = i + 1; j < end; ++j) {
+        sum.Add(hist_data[j].GetGrad(), hist_data[j].GetHess());
+      }
+    }
+  }
+
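+  // Combine the per-work-item candidates over the sub-group: the maximal loss_chg
+  // wins, with ties broken towards the smallest split index so the result does not
+  // depend on which work-item found it. Only the winning work-item updates p_best.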
+  bst_float total_loss_chg = ::sycl::reduce_over_group(sg, best.loss_chg, maximum<>());
+  bst_feature_t total_split_index = ::sycl::reduce_over_group(sg,
+                                                              best.loss_chg == total_loss_chg ?
+                                                              best.SplitIndex() :
+                                                              (1U << 31) - 1U, minimum<>());
+  if (best.loss_chg == total_loss_chg &&
+      best.SplitIndex() == total_split_index) p_best->Update(best);
+}
+
 template class HistUpdater<float>;
 template class HistUpdater<double>;
diff --git a/plugin/sycl/tree/hist_updater.h b/plugin/sycl/tree/hist_updater.h
index 538f2fe5f707..4515b24a1cb3 100644
--- a/plugin/sycl/tree/hist_updater.h
+++ b/plugin/sycl/tree/hist_updater.h
@@ -20,6 +20,7 @@
 #include "hist_synchronizer.h"
 #include "hist_row_adder.h"
+#include "../../src/common/random.h"
 #include "../data.h"
 
 namespace xgboost {
@@ -62,6 +63,9 @@ class HistUpdater {
       p_last_tree_(nullptr), p_last_fmat_(fmat) {
     builder_monitor_.Init("SYCL::Quantile::HistUpdater");
     kernel_monitor_.Init("SYCL::Quantile::HistUpdater");
+    if (param.max_depth > 0) {
+      snode_device_.Resize(&qu, 1u << (param.max_depth + 1));
+    }
     const auto sub_group_sizes =
       qu_.get_device().get_info<::sycl::info::device::sub_group_sizes>();
     sub_group_size_ = sub_group_sizes.back();
@@ -74,9 +78,28 @@ class HistUpdater {
   friend class BatchHistSynchronizer<GradientSumT>;
   friend class BatchHistRowsAdder<GradientSumT>;
 
+  struct SplitQuery {
+    bst_node_t nid;
+    size_t fid;
+    const GradientPairT* hist;
+  };
+
   void InitSampling(const USMVector<GradientPair, MemoryType::on_device> &gpair,
                     USMVector<size_t, MemoryType::on_device>* row_indices);
 
+  void EvaluateSplits(const std::vector<ExpandEntry>& nodes_set,
+                      const common::GHistIndexMatrix& gmat,
+                      const RegTree& tree);
+
+  // Enumerate the split values of a specific feature fid
+  // and update *p_best if a better split is found.
+  static void EnumerateSplit(const ::sycl::sub_group& sg,
+      const uint32_t* cut_ptr, const bst_float* cut_val, const GradientPairT* hist_data,
+      const NodeEntry<GradientSumT> &snode, SplitEntry<GradientSumT>* p_best, bst_uint fid,
+      bst_uint nodeID,
+      typename TreeEvaluator<GradientSumT>::SplitEvaluator const &evaluator,
+      float min_child_weight);
+
   void InitData(const common::GHistIndexMatrix& gmat,
                 const USMVector<GradientPair, MemoryType::on_device> &gpair,
@@ -118,6 +141,14 @@ class HistUpdater {
   common::RowSetCollection row_set_collection_;
   const xgboost::tree::TrainParam& param_;
 
+  std::shared_ptr<xgboost::common::ColumnSampler> column_sampler_;
+
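+  // Scratch buffers for split evaluation; kept as members so host and device
+  // allocations are reused across EvaluateSplits calls instead of reallocated.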
+  std::vector<SplitQuery> split_queries_host_;
+  USMVector<SplitQuery, MemoryType::on_device> split_queries_device_;
+
+  USMVector<SplitEntry<GradientSumT>, MemoryType::on_device> best_splits_device_;
+  std::vector<SplitEntry<GradientSumT>> best_splits_host_;
+
   TreeEvaluator<GradientSumT> tree_evaluator_;
   std::unique_ptr<TreeUpdater> pruner_;
   FeatureInteractionConstraintHost interaction_constraints_;
@@ -137,6 +168,7 @@ class HistUpdater {
 
   /*! \brief TreeNode Data: statistics for each constructed node */
   std::vector<NodeEntry<GradientSumT>> snode_host_;
+  USMVector<NodeEntry<GradientSumT>, MemoryType::on_device> snode_device_;
 
   xgboost::common::Monitor builder_monitor_;
   xgboost::common::Monitor kernel_monitor_;
diff --git a/tests/ci_build/conda_env/linux_sycl_test.yml b/tests/ci_build/conda_env/linux_sycl_test.yml
index 7335b7f20fd5..e82a6bed62f5 100644
--- a/tests/ci_build/conda_env/linux_sycl_test.yml
+++ b/tests/ci_build/conda_env/linux_sycl_test.yml
@@ -1,7 +1,7 @@
 name: linux_sycl_test
 channels:
 - conda-forge
-- intel
+- https://software.repos.intel.com/python/conda/
 dependencies:
 - python=3.8
 - cmake
diff --git a/tests/cpp/plugin/test_sycl_hist_updater.cc b/tests/cpp/plugin/test_sycl_hist_updater.cc
index 1ef771a0c7ec..325769fe8a9a 100644
--- a/tests/cpp/plugin/test_sycl_hist_updater.cc
+++ b/tests/cpp/plugin/test_sycl_hist_updater.cc
@@ -54,6 +54,13 @@ class TestHistUpdater : public HistUpdater<GradientSumT> {
     HistUpdater<GradientSumT>::InitNewNode(nid, gmat, gpair, fmat, tree);
     return HistUpdater<GradientSumT>::snode_host_[nid];
   }
+
+  auto TestEvaluateSplits(const std::vector<ExpandEntry>& nodes_set,
+                          const common::GHistIndexMatrix& gmat,
+                          const RegTree& tree) {
+    HistUpdater<GradientSumT>::EvaluateSplits(nodes_set, gmat, tree);
+    return HistUpdater<GradientSumT>::snode_host_;
+  }
 };
 
 void GenerateRandomGPairs(::sycl::queue* qu, GradientPair* gpair_ptr, size_t num_rows, bool has_neg_hess) {
@@ -307,6 +314,84 @@ void TestHistUpdaterInitNewNode(const xgboost::tree::TrainParam& param, float sp
   EXPECT_NEAR(snode.stats.GetHess(), grad_stat.GetHess(), 1e-6 * grad_stat.GetHess());
 }
 
+template <typename GradientSumT>
+void TestHistUpdaterEvaluateSplits(const xgboost::tree::TrainParam& param) {
+  const size_t num_rows = 1u << 8;
+  const size_t num_columns = 2;
+  const size_t n_bins = 32;
+
+  Context ctx;
+  ctx.UpdateAllowUnknown(Args{{"device", "sycl"}});
+
+  DeviceManager device_manager;
+  auto qu = device_manager.GetQueue(ctx.Device());
+  ObjInfo task{ObjInfo::kRegression};
+
+  auto p_fmat = RandomDataGenerator{num_rows, num_columns, 0.0f}.GenerateDMatrix();
+
+  FeatureInteractionConstraintHost int_constraints;
+  std::unique_ptr<TreeUpdater> pruner{TreeUpdater::Create("prune", &ctx, &task)};
+
+  TestHistUpdater<GradientSumT> updater(&ctx, qu, param, std::move(pruner), int_constraints, p_fmat.get());
+  updater.SetHistSynchronizer(new BatchHistSynchronizer<GradientSumT>());
+  updater.SetHistRowsAdder(new BatchHistRowsAdder<GradientSumT>());
+
+  USMVector<GradientPair, MemoryType::on_device> gpair(&qu, num_rows);
+  auto* gpair_ptr = gpair.Data();
+  GenerateRandomGPairs(&qu, gpair_ptr, num_rows, false);
+
+  DeviceMatrix dmat;
+  dmat.Init(qu, p_fmat.get());
+  common::GHistIndexMatrix gmat;
+  gmat.Init(qu, &ctx, dmat, n_bins);
+
+  RegTree tree;
+  tree.ExpandNode(0, 0, 0, false, 0, 0, 0, 0, 0, 0, 0);
+  ExpandEntry node(ExpandEntry::kRootNid, tree.GetDepth(ExpandEntry::kRootNid));
+
+  auto* row_set_collection = updater.TestInitData(gmat, gpair, *p_fmat, tree);
+  auto& row_idxs = row_set_collection->Data();
+  const size_t* row_idxs_ptr = row_idxs.DataConst();
+  const auto* hist = updater.TestBuildHistogramsLossGuide(node, gmat, &tree, gpair);
+  const auto snode_init = updater.TestInitNewNode(ExpandEntry::kRootNid, gmat, gpair, *p_fmat, tree);
+
+  const auto snode_updated = updater.TestEvaluateSplits({node}, gmat, tree);
+  auto best_loss_chg = snode_updated[0].best.loss_chg;
+  auto stats = snode_init.stats;
+  auto root_gain = snode_init.root_gain;
+
+  // Check all splits manually. Save the best one and compare it with the answer.
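+  // The reference result is recomputed inside a single_task kernel: the histogram
+  // rows are USM allocations that may not be accessible from the host.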
+  TreeEvaluator<GradientSumT> tree_evaluator(qu, param, num_columns);
+  auto evaluator = tree_evaluator.GetEvaluator();
+  const uint32_t* cut_ptr = gmat.cut_device.Ptrs().DataConst();
+  const size_t size = gmat.cut_device.Ptrs().Size();
+  int n_better_splits = 0;
+  const auto* hist_ptr = (*hist)[0].DataConst();
+  std::vector<bst_float> best_loss_chg_des(1, -1);
+  {
+    ::sycl::buffer<bst_float> best_loss_chg_buff(best_loss_chg_des.data(), 1);
+    qu.submit([&](::sycl::handler& cgh) {
+      auto best_loss_chg_acc = best_loss_chg_buff.template get_access<::sycl::access::mode::read_write>(cgh);
+      cgh.single_task<>([=]() {
+        for (size_t i = 1; i < size; ++i) {
+          GradStats<GradientSumT> left(0, 0);
+          GradStats<GradientSumT> right = stats - left;
+          for (size_t j = cut_ptr[i-1]; j < cut_ptr[i]; ++j) {
+            auto loss_change = evaluator.CalcSplitGain(0, i - 1, left, right) - root_gain;
+            if (loss_change > best_loss_chg_acc[0]) {
+              best_loss_chg_acc[0] = loss_change;
+            }
+            left.Add(hist_ptr[j].GetGrad(), hist_ptr[j].GetHess());
+            right = stats - left;
+          }
+        }
+      });
+    }).wait();
+  }
+
+  ASSERT_NEAR(best_loss_chg_des[0], best_loss_chg, 1e-6);
+}
+
 TEST(SyclHistUpdater, Sampling) {
   xgboost::tree::TrainParam param;
   param.UpdateAllowUnknown(Args{{"subsample", "0.7"}});
@@ -346,4 +431,12 @@ TEST(SyclHistUpdater, InitNewNode) {
   TestHistUpdaterInitNewNode<double>(param, 0.5);
 }
 
+TEST(SyclHistUpdater, EvaluateSplits) {
+  xgboost::tree::TrainParam param;
+  param.UpdateAllowUnknown(Args{{"max_depth", "3"}});
+
+  TestHistUpdaterEvaluateSplits<float>(param);
+  TestHistUpdaterEvaluateSplits<double>(param);
+}
+
 }  // namespace xgboost::sycl::tree