diff --git a/src/gbm/gbtree.cc b/src/gbm/gbtree.cc index 88a3651dacdc..9570cc716c7a 100644 --- a/src/gbm/gbtree.cc +++ b/src/gbm/gbtree.cc @@ -620,7 +620,7 @@ std::unique_ptr const& GBTree::GetPredictor(HostDeviceVector c auto on_device = is_ellpack || is_from_device; // Use GPU Predictor if data is already on device and gpu_id is set. - if (on_device && ctx_->gpu_id >= 0) { + if (on_device && ctx_->IsCUDA()) { #if defined(XGBOOST_USE_CUDA) CHECK_GE(common::AllVisibleGPUs(), 1) << "No visible GPU is found for XGBoost."; CHECK(gpu_predictor_); diff --git a/tests/cpp/helpers.cc b/tests/cpp/helpers.cc index 7c81b96f9d0f..49ff5e4127aa 100644 --- a/tests/cpp/helpers.cc +++ b/tests/cpp/helpers.cc @@ -395,6 +395,9 @@ std::shared_ptr RandomDataGenerator::GenerateDMatrix(bool with_label, b for (auto const& page : out->GetBatches()) { page.data.SetDevice(device_); page.offset.SetDevice(device_); + // pull to device + page.data.ConstDeviceSpan(); + page.offset.ConstDeviceSpan(); } } if (!ft_.empty()) { diff --git a/tests/cpp/predictor/test_predictor.cc b/tests/cpp/predictor/test_predictor.cc index 6e66ae5af4f5..8085fd83c8b8 100644 --- a/tests/cpp/predictor/test_predictor.cc +++ b/tests/cpp/predictor/test_predictor.cc @@ -365,11 +365,14 @@ void TestCategoricalPredictLeafColumnSplit(Context const *ctx) { void TestIterationRange(Context const* ctx) { size_t constexpr kRows = 1000, kCols = 20, kClasses = 4, kForest = 3, kIters = 10; - auto dmat = RandomDataGenerator(kRows, kCols, 0).GenerateDMatrix(true, true, kClasses); + auto dmat = RandomDataGenerator(kRows, kCols, 0) + .Device(ctx->gpu_id) + .GenerateDMatrix(true, true, kClasses); auto learner = LearnerForTest(ctx, dmat, kIters, kForest); bool bound = false; - std::unique_ptr sliced {learner->Slice(0, 3, 1, &bound)}; + bst_layer_t lend{3}; + std::unique_ptr sliced{learner->Slice(0, lend, 1, &bound)}; ASSERT_FALSE(bound); HostDeviceVector out_predt_sliced; @@ -377,11 +380,8 @@ void TestIterationRange(Context const* ctx) { // margin { - sliced->Predict(dmat, true, &out_predt_sliced, 0, 0, false, false, false, - false, false); - - learner->Predict(dmat, true, &out_predt_ranged, 0, 3, false, false, false, - false, false); + sliced->Predict(dmat, true, &out_predt_sliced, 0, 0, false, false, false, false, false); + learner->Predict(dmat, true, &out_predt_ranged, 0, lend, false, false, false, false, false); auto const &h_sliced = out_predt_sliced.HostVector(); auto const &h_range = out_predt_ranged.HostVector(); @@ -391,11 +391,8 @@ void TestIterationRange(Context const* ctx) { // SHAP { - sliced->Predict(dmat, false, &out_predt_sliced, 0, 0, false, false, - true, false, false); - - learner->Predict(dmat, false, &out_predt_ranged, 0, 3, false, false, true, - false, false); + sliced->Predict(dmat, false, &out_predt_sliced, 0, 0, false, false, true, false, false); + learner->Predict(dmat, false, &out_predt_ranged, 0, lend, false, false, true, false, false); auto const &h_sliced = out_predt_sliced.HostVector(); auto const &h_range = out_predt_ranged.HostVector(); @@ -405,10 +402,8 @@ void TestIterationRange(Context const* ctx) { // SHAP interaction { - sliced->Predict(dmat, false, &out_predt_sliced, 0, 0, false, false, - false, false, true); - learner->Predict(dmat, false, &out_predt_ranged, 0, 3, false, false, false, - false, true); + sliced->Predict(dmat, false, &out_predt_sliced, 0, 0, false, false, false, false, true); + learner->Predict(dmat, false, &out_predt_ranged, 0, lend, false, false, false, false, true); auto const &h_sliced = out_predt_sliced.HostVector(); auto const &h_range = out_predt_ranged.HostVector(); ASSERT_EQ(h_sliced.size(), h_range.size()); @@ -417,10 +412,8 @@ void TestIterationRange(Context const* ctx) { // Leaf { - sliced->Predict(dmat, false, &out_predt_sliced, 0, 0, false, true, - false, false, false); - learner->Predict(dmat, false, &out_predt_ranged, 0, 3, false, true, false, - false, false); + sliced->Predict(dmat, false, &out_predt_sliced, 0, 0, false, true, false, false, false); + learner->Predict(dmat, false, &out_predt_ranged, 0, lend, false, true, false, false, false); auto const &h_sliced = out_predt_sliced.HostVector(); auto const &h_range = out_predt_ranged.HostVector(); ASSERT_EQ(h_sliced.size(), h_range.size());