Skip to content

Commit

Permalink
Fix.
Browse files Browse the repository at this point in the history
  • Loading branch information
trivialfis committed Jun 28, 2023
1 parent d2844e0 commit a93b80d
Show file tree
Hide file tree
Showing 3 changed files with 17 additions and 21 deletions.
2 changes: 1 addition & 1 deletion src/gbm/gbtree.cc
Original file line number Diff line number Diff line change
Expand Up @@ -620,7 +620,7 @@ std::unique_ptr<Predictor> const& GBTree::GetPredictor(HostDeviceVector<float> c
auto on_device = is_ellpack || is_from_device;

// Use GPU Predictor if data is already on device and gpu_id is set.
if (on_device && ctx_->gpu_id >= 0) {
if (on_device && ctx_->IsCUDA()) {
#if defined(XGBOOST_USE_CUDA)
CHECK_GE(common::AllVisibleGPUs(), 1) << "No visible GPU is found for XGBoost.";
CHECK(gpu_predictor_);
Expand Down
3 changes: 3 additions & 0 deletions tests/cpp/helpers.cc
Original file line number Diff line number Diff line change
Expand Up @@ -395,6 +395,9 @@ std::shared_ptr<DMatrix> RandomDataGenerator::GenerateDMatrix(bool with_label, b
for (auto const& page : out->GetBatches<SparsePage>()) {
page.data.SetDevice(device_);
page.offset.SetDevice(device_);
// pull to device
page.data.ConstDeviceSpan();
page.offset.ConstDeviceSpan();
}
}
if (!ft_.empty()) {
Expand Down
33 changes: 13 additions & 20 deletions tests/cpp/predictor/test_predictor.cc
Original file line number Diff line number Diff line change
Expand Up @@ -365,23 +365,23 @@ void TestCategoricalPredictLeafColumnSplit(Context const *ctx) {

void TestIterationRange(Context const* ctx) {
size_t constexpr kRows = 1000, kCols = 20, kClasses = 4, kForest = 3, kIters = 10;
auto dmat = RandomDataGenerator(kRows, kCols, 0).GenerateDMatrix(true, true, kClasses);
auto dmat = RandomDataGenerator(kRows, kCols, 0)
.Device(ctx->gpu_id)
.GenerateDMatrix(true, true, kClasses);
auto learner = LearnerForTest(ctx, dmat, kIters, kForest);

bool bound = false;
std::unique_ptr<Learner> sliced {learner->Slice(0, 3, 1, &bound)};
bst_layer_t lend{3};
std::unique_ptr<Learner> sliced{learner->Slice(0, lend, 1, &bound)};
ASSERT_FALSE(bound);

HostDeviceVector<float> out_predt_sliced;
HostDeviceVector<float> out_predt_ranged;

// margin
{
sliced->Predict(dmat, true, &out_predt_sliced, 0, 0, false, false, false,
false, false);

learner->Predict(dmat, true, &out_predt_ranged, 0, 3, false, false, false,
false, false);
sliced->Predict(dmat, true, &out_predt_sliced, 0, 0, false, false, false, false, false);
learner->Predict(dmat, true, &out_predt_ranged, 0, lend, false, false, false, false, false);

auto const &h_sliced = out_predt_sliced.HostVector();
auto const &h_range = out_predt_ranged.HostVector();
Expand All @@ -391,11 +391,8 @@ void TestIterationRange(Context const* ctx) {

// SHAP
{
sliced->Predict(dmat, false, &out_predt_sliced, 0, 0, false, false,
true, false, false);

learner->Predict(dmat, false, &out_predt_ranged, 0, 3, false, false, true,
false, false);
sliced->Predict(dmat, false, &out_predt_sliced, 0, 0, false, false, true, false, false);
learner->Predict(dmat, false, &out_predt_ranged, 0, lend, false, false, true, false, false);

auto const &h_sliced = out_predt_sliced.HostVector();
auto const &h_range = out_predt_ranged.HostVector();
Expand All @@ -405,10 +402,8 @@ void TestIterationRange(Context const* ctx) {

// SHAP interaction
{
sliced->Predict(dmat, false, &out_predt_sliced, 0, 0, false, false,
false, false, true);
learner->Predict(dmat, false, &out_predt_ranged, 0, 3, false, false, false,
false, true);
sliced->Predict(dmat, false, &out_predt_sliced, 0, 0, false, false, false, false, true);
learner->Predict(dmat, false, &out_predt_ranged, 0, lend, false, false, false, false, true);
auto const &h_sliced = out_predt_sliced.HostVector();
auto const &h_range = out_predt_ranged.HostVector();
ASSERT_EQ(h_sliced.size(), h_range.size());
Expand All @@ -417,10 +412,8 @@ void TestIterationRange(Context const* ctx) {

// Leaf
{
sliced->Predict(dmat, false, &out_predt_sliced, 0, 0, false, true,
false, false, false);
learner->Predict(dmat, false, &out_predt_ranged, 0, 3, false, true, false,
false, false);
sliced->Predict(dmat, false, &out_predt_sliced, 0, 0, false, true, false, false, false);
learner->Predict(dmat, false, &out_predt_ranged, 0, lend, false, true, false, false, false);
auto const &h_sliced = out_predt_sliced.HostVector();
auto const &h_range = out_predt_ranged.HostVector();
ASSERT_EQ(h_sliced.size(), h_range.size());
Expand Down

0 comments on commit a93b80d

Please sign in to comment.