Skip to content

Commit

Permalink
[R] use class names in importance outputs (#11100)
Browse files Browse the repository at this point in the history
  • Loading branch information
david-cortes authored Dec 15, 2024
1 parent dea5753 commit 9191727
Show file tree
Hide file tree
Showing 3 changed files with 33 additions and 4 deletions.
16 changes: 13 additions & 3 deletions R-package/R/xgb.importance.R
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,9 @@
#' For a linear model:
#' - `Features`: Names of the features used in the model.
#' - `Weight`: Linear coefficient of this feature.
#' - `Class`: Class label (only for multiclass models).
#' - `Class`: Class label (only for multiclass models). For objects of class `xgboost` (as
#' produced by [xgboost()]), it will be a `factor`, while for objects of class `xgb.Booster`
#' (as produced by [xgb.train()]), it will be a zero-based integer vector.
#'
#' If `feature_names` is not provided and `model` doesn't have `feature_names`,
#' the index of the features will be used instead. Because the index is extracted from the model dump
Expand Down Expand Up @@ -146,11 +148,19 @@ xgb.importance <- function(model = NULL, feature_names = getinfo(model, "feature
n_classes <- 0
}
importance <- if (n_classes == 0) {
data.table(Feature = results$features, Weight = results$weight)[order(-abs(Weight))]
return(data.table(Feature = results$features, Weight = results$weight)[order(-abs(Weight))])
} else {
data.table(
out <- data.table(
Feature = rep(results$features, each = n_classes), Weight = results$weight, Class = seq_len(n_classes) - 1
)[order(Class, -abs(Weight))]
if (inherits(model, "xgboost") && NROW(attributes(model)$metadata$y_levels)) {
class_vec <- out$Class
class_vec <- as.integer(class_vec) + 1L
attributes(class_vec)$levels <- attributes(model)$metadata$y_levels
attributes(class_vec)$class <- "factor"
out[, Class := class_vec]
}
return(out[])
}
} else {
concatenated <- list()
Expand Down
4 changes: 3 additions & 1 deletion R-package/man/xgb.importance.Rd

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

17 changes: 17 additions & 0 deletions R-package/tests/testthat/test_xgboost.R
Original file line number Diff line number Diff line change
Expand Up @@ -1013,3 +1013,20 @@ test_that("'eval_set' as fraction works", {
expect_true(hasName(evaluation_log, "eval_mlogloss"))
expect_equal(length(attributes(model)$metadata$y_levels), 3L)
})

test_that("Linear booster importance uses class names", {
y <- iris$Species
x <- iris[, -5L]
model <- xgboost(
x,
y,
nthreads = 1L,
nrounds = 4L,
verbosity = 0L,
booster = "gblinear",
learning_rate = 0.2
)
imp <- xgb.importance(model)
expect_true(is.factor(imp$Class))
expect_equal(levels(imp$Class), levels(y))
})

0 comments on commit 9191727

Please sign in to comment.