auto-select background data
mayer79 committed Aug 5, 2024
1 parent ad10ea4 commit 16a6961
Showing 10 changed files with 214 additions and 129 deletions.
2 changes: 1 addition & 1 deletion DESCRIPTION
@@ -1,6 +1,6 @@
Package: kernelshap
Title: Kernel SHAP
Version: 0.6.1
Version: 0.7.0
Authors@R: c(
person("Michael", "Mayer", , "[email protected]", role = c("aut", "cre"),
comment = c(ORCID = "0009-0007-2540-9629")),
16 changes: 13 additions & 3 deletions NEWS.md
Original file line number Diff line number Diff line change
@@ -1,11 +1,21 @@
# kernelshap 0.6.1
# kernelshap 0.7.0

This release is intended to be the last before stable version 1.0.0.

## Major change

Passing a background dataset `bg_X` is now optional.

If the explanation data `X` is sufficiently large (>= 50 rows), `bg_X` is derived as a random sample of `bg_n = 200` rows from `X`. If `X` has fewer than `bg_n` rows, then simply `bg_X = X`. If `X` has too few rows (< 50), you will have to pass an explicit `bg_X`.
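
The selection rule above can be sketched in a few lines of base R. This is only an illustration of the behavior described in this entry, not the package's internal code: the helper name `select_background()` and its `min_rows` argument are hypothetical.

```r
# Hypothetical helper illustrating the selection rule described above.
# It is NOT the package's implementation; name and arguments are assumptions.
select_background <- function(X, bg_X = NULL, bg_n = 200L, min_rows = 50L) {
  if (!is.null(bg_X)) {
    return(bg_X)                       # an explicit background dataset always wins
  }
  n <- nrow(X)
  if (n < min_rows) {
    stop("X is too small to act as background data. Please pass 'bg_X'.")
  }
  if (n <= bg_n) {
    return(X)                          # X has at most bg_n rows: use all of them
  }
  X[sample(n, bg_n), , drop = FALSE]   # large X: random sample of bg_n rows
}

set.seed(1)
bg_small <- select_background(iris[-1])       # 150 rows, fewer than bg_n
big_X <- iris[rep(seq_len(150), 2), -1]       # 300 rows
bg_large <- select_background(big_X)          # sampled down to bg_n rows
c(nrow(bg_small), nrow(bg_large))
```

Passing `bg_X` explicitly bypasses the rule entirely, which is what the entry asks for when `X` has fewer than 50 rows.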

## Minor changes

- `ranger()` survival models now also work out-of-the-box without passing a tailored prediction function. Use the new argument `survival = "chf"` in `kernelshap()` and `permshap()` to choose between cumulative hazards (the default) and survival probabilities per time point.
- The resulting object of `kernelshap()` and `permshap()` now contains the `bg_X` and `bg_w` used to calculate the SHAP values.
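
As a rough usage sketch of the new return elements: the block below assumes that kernelshap 0.7.0 or later is installed and does nothing otherwise. The model and data mirror the package examples in this commit.

```r
# Assumes kernelshap >= 0.7.0 is installed; otherwise the block is skipped.
has_kernelshap <- requireNamespace("kernelshap", quietly = TRUE) &&
  utils::packageVersion("kernelshap") >= "0.7.0"

if (has_kernelshap) {
  fit <- lm(Sepal.Length ~ ., data = iris)
  X_explain <- iris[-1]   # 150 rows, feature columns only

  # No bg_X passed: the background data is derived from X_explain
  s <- kernelshap::kernelshap(fit, X_explain, verbose = FALSE)

  print(dim(s$bg_X))      # background data actually used
  print(s$bg_w)           # background case weights, if any
}
```

Storing `bg_X` in the result makes it possible to check after the fact which rows served as background, which matters now that they may be sampled.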

# kernelshap 0.6.0

This release is intended to be the last before stable version 1.0.0.

## Major changes

- Factor-valued predictions are not supported anymore.
89 changes: 52 additions & 37 deletions R/kernelshap.R
@@ -5,6 +5,7 @@
#' For up to \eqn{p=8} features, the resulting Kernel SHAP values are exact regarding
#' the selected background data. For larger \eqn{p}, an almost exact
#' hybrid algorithm involving iterative sampling is used, see Details.
#' For up to eight features, however, we recommend using [permshap()].
#'
#' Pure iterative Kernel SHAP sampling as in Covert and Lee (2021) works like this:
#'
@@ -63,10 +64,11 @@
#' The columns should only represent model features, not the response
#' (but see `feature_names` on how to overrule this).
#' @param bg_X Background data used to integrate out "switched off" features,
#' often a subset of the training data (typically 50 to 500 rows)
#' It should contain the same columns as `X`.
#' often a subset of the training data (typically 50 to 500 rows).
#' In cases with a natural "off" value (like MNIST digits),
#' this can also be a single row with all values set to the off value.
#' If no `bg_X` is passed (the default) and if `X` is sufficiently large,
#' a random sample of `bg_n` rows from `X` serves as background data.
#' @param pred_fun Prediction function of the form `function(object, X, ...)`,
#' providing \eqn{K \ge 1} predictions per row. Its first argument
#' represents the model `object`, its second argument a data structure like `X`.
@@ -76,6 +78,8 @@
#' SHAP values. By default, this equals `colnames(X)`. Not supported if `X`
#' is a matrix.
#' @param bg_w Optional vector of case weights for each row of `bg_X`.
#' If `bg_X = NULL`, it must contain one weight per row of `X`. Set to `NULL` for no weights.
#' @param bg_n If `bg_X = NULL`: Size of background data to be sampled from `X`.
#' @param exact If `TRUE`, the algorithm will produce exact Kernel SHAP values
#' with respect to the background data. In this case, the arguments `hybrid_degree`,
#' `m`, `paired_sampling`, `tol`, and `max_iter` are ignored.
@@ -130,6 +134,8 @@
#' - `X`: Same as input argument `X`.
#' - `baseline`: Vector of length K representing the average prediction on the
#' background data.
#' - `bg_X`: The background data.
#' - `bg_w`: The background case weights.
#' - `SE`: Standard errors corresponding to `S` (and organized like `S`).
#' - `n_iter`: Integer vector of length n providing the number of iterations
#' per row of `X`.
@@ -155,28 +161,25 @@
#' @examples
#' # MODEL ONE: Linear regression
#' fit <- lm(Sepal.Length ~ ., data = iris)
#'
#'
#' # Select rows to explain (only feature columns)
#' X_explain <- iris[1:2, -1]
#'
#' # Select small background dataset (could use all rows here because iris is small)
#' set.seed(1)
#' bg_X <- iris[sample(nrow(iris), 100), ]
#'
#' X_explain <- iris[-1]
#'
#' # Calculate SHAP values
#' s <- kernelshap(fit, X_explain, bg_X = bg_X)
#' s <- kernelshap(fit, X_explain)
#' s
#'
#'
#' # MODEL TWO: Multi-response linear regression
#' fit <- lm(as.matrix(iris[, 1:2]) ~ Petal.Length + Petal.Width + Species, data = iris)
#' s <- kernelshap(fit, iris[1:4, 3:5], bg_X = bg_X)
#' summary(s)
#'
#' # Non-feature columns can be dropped via 'feature_names'
#' s <- kernelshap(fit, iris[3:5])
#' s
#'
#' # Note 1: Feature columns can also be selected via 'feature_names'
#' # Note 2: Especially when X is small, pass a sufficiently large background dataset bg_X
#' s <- kernelshap(
#' fit,
#' fit,
#' iris[1:4, ],
#' bg_X = bg_X,
#' bg_X = iris,
#' feature_names = c("Petal.Length", "Petal.Width", "Species")
#' )
#' s
@@ -189,10 +192,11 @@ kernelshap <- function(object, ...){
kernelshap.default <- function(
object,
X,
bg_X,
bg_X = NULL,
pred_fun = stats::predict,
feature_names = colnames(X),
bg_w = NULL,
bg_n = 200L,
exact = length(feature_names) <= 8L,
hybrid_degree = 1L + length(feature_names) %in% 4:16,
paired_sampling = TRUE,
@@ -204,24 +208,24 @@ kernelshap.default <- function(
verbose = TRUE,
...
) {
basic_checks(X = X, bg_X = bg_X, feature_names = feature_names, pred_fun = pred_fun)
p <- length(feature_names)
basic_checks(X = X, feature_names = feature_names, pred_fun = pred_fun)
stopifnot(
exact %in% c(TRUE, FALSE),
p == 1L || exact || hybrid_degree %in% 0:(p / 2),
paired_sampling %in% c(TRUE, FALSE),
"m must be even" = trunc(m / 2) == m / 2
)
n <- nrow(X)
prep_bg <- prepare_bg(X = X, bg_X = bg_X, bg_n = bg_n, bg_w = bg_w, verbose = verbose)
bg_X <- prep_bg$bg_X
bg_w <- prep_bg$bg_w
bg_n <- nrow(bg_X)
if (!is.null(bg_w)) {
bg_w <- prep_w(bg_w, bg_n = bg_n)
}
n <- nrow(X)

# Calculate v1 and v0
v1 <- align_pred(pred_fun(object, X, ...)) # Predictions on X: n x K
bg_preds <- align_pred(pred_fun(object, bg_X[, colnames(X), drop = FALSE], ...))
bg_preds <- align_pred(pred_fun(object, bg_X, ...))
v0 <- wcolMeans(bg_preds, bg_w) # Average pred of bg data: 1 x K
v1 <- align_pred(pred_fun(object, X, ...)) # Predictions on X: n x K

# For p = 1, exact Shapley values are returned
if (p == 1L) {
@@ -231,18 +235,25 @@ kernelshap.default <- function(
return(out)
}

txt <- summarize_strategy(p, exact = exact, deg = hybrid_degree)
if (verbose) {
message(txt)
}

# Drop unnecessary columns in bg_X. If X is a matrix, the column order is relevant as well
# In what follows, predictions will never be applied directly to bg_X anymore
if (!identical(colnames(bg_X), feature_names)) {
bg_X <- bg_X[, feature_names, drop = FALSE]
}

# Precalculations for the real Kernel SHAP
# Precalculations that are identical for each row to be explained
if (exact || hybrid_degree >= 1L) {
if (exact) {
precalc <- input_exact(p, feature_names = feature_names)
} else {
precalc <- input_partly_exact(p, deg = hybrid_degree, feature_names = feature_names)
precalc <- input_partly_exact(
p, deg = hybrid_degree, feature_names = feature_names
)
}
m_exact <- nrow(precalc[["Z"]])
prop_exact <- sum(precalc[["w"]])
@@ -256,11 +267,6 @@ kernelshap.default <- function(
precalc[["bg_X_m"]] <- rep_rows(bg_X, rep.int(seq_len(bg_n), m))
}

# Some infos
txt <- summarize_strategy(p, exact = exact, deg = hybrid_degree)
if (verbose) {
message(txt)
}
if (max(m, m_exact) * bg_n > 2e5) {
warning_burden(max(m, m_exact), bg_n = bg_n)
}
@@ -319,11 +325,18 @@ kernelshap.default <- function(
if (verbose && !all(converged)) {
warning("\nNon-convergence for ", sum(!converged), " rows.")
}

if (verbose) {
cat("\n")
}

out <- list(
S = reorganize_list(lapply(res, `[[`, "beta")),
X = X,
baseline = as.vector(v0),
SE = reorganize_list(lapply(res, `[[`, "sigma")),
S = reorganize_list(lapply(res, `[[`, "beta")),
X = X,
baseline = as.vector(v0),
bg_X = bg_X,
bg_w = bg_w,
SE = reorganize_list(lapply(res, `[[`, "sigma")),
n_iter = vapply(res, `[[`, "n_iter", FUN.VALUE = integer(1L)),
converged = converged,
m = m,
@@ -343,10 +356,11 @@ kernelshap.default <- function(
kernelshap.ranger <- function(
object,
X,
bg_X,
bg_X = NULL,
pred_fun = NULL,
feature_names = colnames(X),
bg_w = NULL,
bg_n = 200L,
exact = length(feature_names) <= 8L,
hybrid_degree = 1L + length(feature_names) %in% 4:16,
paired_sampling = TRUE,
@@ -371,6 +385,7 @@ kernelshap.ranger <- function(
pred_fun = pred_fun,
feature_names = feature_names,
bg_w = bg_w,
bg_n = bg_n,
exact = exact,
hybrid_degree = hybrid_degree,
paired_sampling = paired_sampling,
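
The defaults `exact = length(feature_names) <= 8L` and `hybrid_degree = 1L + length(feature_names) %in% 4:16` in the signatures above depend on R's operator precedence. The sketch below re-creates the two expressions for a few feature counts; the helper `default_strategy()` is hypothetical and exists only for illustration.

```r
# Hypothetical helper: evaluates the two default expressions from the
# signature, with p standing in for length(feature_names).
default_strategy <- function(p) {
  data.frame(
    p = p,
    exact = p <= 8L,
    # ':' binds tighter than %in%, and %in% binds tighter than '+',
    # so this reads as 1L + (p %in% (4:16))
    hybrid_degree = 1L + p %in% 4:16
  )
}

res <- default_strategy(c(3L, 8L, 12L, 20L))
print(res)
```

With these defaults, `hybrid_degree` is 2 for 4 to 16 features and 1 otherwise. According to the argument documentation above, it is ignored whenever `exact = TRUE`.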
54 changes: 32 additions & 22 deletions R/permshap.R
@@ -2,6 +2,7 @@
#'
#' Exact permutation SHAP algorithm with respect to a background dataset,
#' see Strumbelj and Kononenko. The function works for up to 14 features.
#' For eight or more features, we recommend switching to [kernelshap()].
#'
#' @inheritParams kernelshap
#' @returns
@@ -11,6 +12,8 @@
#' - `X`: Same as input argument `X`.
#' - `baseline`: Vector of length K representing the average prediction on the
#' background data.
#' - `bg_X`: The background data.
#' - `bg_w`: The background case weights.
#' - `m_exact`: Integer providing the effective number of exact on-off vectors used.
#' - `exact`: Logical flag indicating whether calculations are exact or not
#' (currently `TRUE`).
@@ -26,26 +29,23 @@
#' fit <- lm(Sepal.Length ~ ., data = iris)
#'
#' # Select rows to explain (only feature columns)
#' X_explain <- iris[1:2, -1]
#'
#' # Select small background dataset (could use all rows here because iris is small)
#' set.seed(1)
#' bg_X <- iris[sample(nrow(iris), 100), ]
#' X_explain <- iris[-1]
#'
#' # Calculate SHAP values
#' s <- permshap(fit, X_explain, bg_X = bg_X)
#' s <- permshap(fit, X_explain)
#' s
#'
#' # MODEL TWO: Multi-response linear regression
#' fit <- lm(as.matrix(iris[, 1:2]) ~ Petal.Length + Petal.Width + Species, data = iris)
#' s <- permshap(fit, iris[1:4, 3:5], bg_X = bg_X)
#' s <- permshap(fit, iris[3:5])
#' s
#'
#' # Non-feature columns can be dropped via 'feature_names'
#' # Note 1: Feature columns can also be selected via 'feature_names'
#' # Note 2: Especially when X is small, pass a sufficiently large background dataset bg_X
#' s <- permshap(
#' fit,
#' iris[1:4, ],
#' bg_X = bg_X,
#' bg_X = iris,
#' feature_names = c("Petal.Length", "Petal.Width", "Species")
#' )
#' s
@@ -58,37 +58,40 @@ permshap <- function(object, ...) {
permshap.default <- function(
object,
X,
bg_X,
bg_X = NULL,
pred_fun = stats::predict,
feature_names = colnames(X),
bg_w = NULL,
bg_n = 200L,
parallel = FALSE,
parallel_args = NULL,
verbose = TRUE,
...
) {
basic_checks(X = X, bg_X = bg_X, feature_names = feature_names, pred_fun = pred_fun)
p <- length(feature_names)
if (p <= 1L) {
stop("Case p = 1 not implemented. Use kernelshap() instead.")
}
if (p > 14L) {
stop("Permutation SHAP only supported for up to 14 features")
}
n <- nrow(X)
bg_n <- nrow(bg_X)
if (!is.null(bg_w)) {
bg_w <- prep_w(bg_w, bg_n = bg_n)
}

txt <- "Exact permutation SHAP"
if (verbose) {
message(txt)
}

basic_checks(X = X, feature_names = feature_names, pred_fun = pred_fun)
prep_bg <- prepare_bg(X = X, bg_X = bg_X, bg_n = bg_n, bg_w = bg_w, verbose = verbose)
bg_X <- prep_bg$bg_X
bg_w <- prep_bg$bg_w
bg_n <- nrow(bg_X)
n <- nrow(X)

# Baseline and predictions on explanation data
bg_preds <- align_pred(pred_fun(object, bg_X[, colnames(X), drop = FALSE], ...))
v0 <- wcolMeans(bg_preds, bg_w) # Average pred of bg data: 1 x K
v1 <- align_pred(pred_fun(object, X, ...)) # Predictions on X: n x K
bg_preds <- align_pred(pred_fun(object, bg_X, ...))
v0 <- wcolMeans(bg_preds, w = bg_w) # Average pred of bg data: 1 x K
v1 <- align_pred(pred_fun(object, X, ...)) # Predictions on X: n x K

# Drop unnecessary columns in bg_X. If X is a matrix, the column order is relevant as well
# Predictions will never be applied directly to bg_X anymore
@@ -143,10 +146,15 @@ permshap.default <- function(
}
}
}
if (verbose) {
cat("\n")
}
out <- list(
S = reorganize_list(res),
X = X,
S = reorganize_list(res),
X = X,
baseline = as.vector(v0),
bg_X = bg_X,
bg_w = bg_w,
m_exact = m_exact,
exact = TRUE,
txt = txt,
@@ -162,10 +170,11 @@ permshap.default <- function(
permshap.ranger <- function(
object,
X,
bg_X,
bg_X = NULL,
pred_fun = NULL,
feature_names = colnames(X),
bg_w = NULL,
bg_n = 200L,
parallel = FALSE,
parallel_args = NULL,
verbose = TRUE,
@@ -184,6 +193,7 @@ permshap.ranger <- function(
pred_fun = pred_fun,
feature_names = feature_names,
bg_w = bg_w,
bg_n = bg_n,
parallel = parallel,
parallel_args = parallel_args,
verbose = verbose,
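
Both `kernelshap()` and `permshap()` compute the baseline `v0` as the (optionally weighted) average prediction on the background data. The following base R sketch shows what such a weighted column mean might look like. `weighted_col_means()` is a hypothetical stand-in and makes no claim about how the package's internal `wcolMeans()` is written.

```r
# Hypothetical stand-in for a weighted column mean returning a 1 x K matrix.
weighted_col_means <- function(x, w = NULL) {
  x <- as.matrix(x)
  if (is.null(w)) {
    return(rbind(colMeans(x)))           # unweighted average per column
  }
  stopifnot(length(w) == nrow(x), all(w >= 0), sum(w) > 0)
  # x * w recycles w down the columns, so row i is scaled by w[i]
  rbind(colSums(x * w) / sum(w))
}

# Multi-response linear model as in the package examples (K = 2 predictions per row)
fit <- lm(as.matrix(iris[, 1:2]) ~ Petal.Length + Petal.Width + Species, data = iris)
bg_X <- iris[3:5]
bg_preds <- predict(fit, bg_X)           # 150 x 2 matrix of predictions

v0_unweighted <- weighted_col_means(bg_preds)
v0_weighted <- weighted_col_means(bg_preds, w = rep(c(1, 2), times = 75))
```

Because the model is fitted by least squares with an intercept and the background data equals the training data, the unweighted baseline here coincides with the column means of the two responses.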