diff --git a/R/kernelshap.R b/R/kernelshap.R index f5323e1..63b46f7 100644 --- a/R/kernelshap.R +++ b/R/kernelshap.R @@ -359,12 +359,11 @@ kernelshap.ranger <- function( survival = c("chf", "prob"), ... ) { - survival <- match.arg(survival) - + if (is.null(pred_fun)) { - pred_fun <- pred_ranger + pred_fun <- create_ranger_pred_fun(object$treetype, survival = match.arg(survival)) } - + kernelshap.default( object = object, X = X, @@ -381,7 +380,6 @@ kernelshap.ranger <- function( parallel = parallel, parallel_args = parallel_args, verbose = verbose, - survival = survival, ... ) } diff --git a/R/permshap.R b/R/permshap.R index 1b65fb2..0711f13 100644 --- a/R/permshap.R +++ b/R/permshap.R @@ -172,12 +172,11 @@ permshap.ranger <- function( survival = c("chf", "prob"), ... ) { - survival <- match.arg(survival) - + if (is.null(pred_fun)) { - pred_fun <- pred_ranger + pred_fun <- create_ranger_pred_fun(object$treetype, survival = match.arg(survival)) } - + permshap.default( object = object, X = X, @@ -188,7 +187,6 @@ permshap.ranger <- function( parallel = parallel, parallel_args = parallel_args, verbose = verbose, - survival = survival, ... ) } diff --git a/R/pred_fun.R b/R/pred_fun.R index 9586cf8..3d174b3 100644 --- a/R/pred_fun.R +++ b/R/pred_fun.R @@ -1,27 +1,33 @@ #' Predict Function for Ranger #' -#' Internal function that prepares the predictions of different types of ranger models, -#' including survival models. +#' Returns prediction function for different modes of ranger. #' #' @noRd #' @keywords internal -#' @param model Fitted ranger model. -#' @param newdata Data to predict on. +#' @param treetype The value of `fit$treetype` in a fitted ranger model. #' @param survival Cumulative hazards "chf" (default) or probabilities "prob" per time. -#' @param ... Additional arguments passed to ranger's predict function. #' -#' @returns A vector or matrix with predictions. -pred_ranger <- function(model, newdata, survival = c("chf", "prob"), ...) { +#' @returns A function with signature f(model, newdata, ...). +create_ranger_pred_fun <- function(treetype, survival = c("chf", "prob")) { survival <- match.arg(survival) - pred <- stats::predict(model, newdata, ...) + if (treetype != "Survival") { + pred_fun <- function(model, newdata, ...) { + stats::predict(model, newdata, ...)$predictions + } + return(pred_fun) + } + + if (survival == "prob") { + survival <- "survival" + } - if (model$treetype == "Survival") { - out <- if (survival == "chf") pred$chf else pred$survival + pred_fun <- function(model, newdata, ...) { + pred <- stats::predict(model, newdata, ...) + out <- pred[[survival]] colnames(out) <- paste0("t", pred$unique.death.times) - } else { - out <- pred$predictions + return(out) } - return(out) + return(pred_fun) } diff --git a/backlog/test_ranger.R b/backlog/test_ranger.R new file mode 100644 index 0000000..61d723d --- /dev/null +++ b/backlog/test_ranger.R @@ -0,0 +1,28 @@ +library(ranger) +library(survival) +library(kernelshap) + +set.seed(1) + +fit <- ranger(Surv(time, status) ~ ., data = veteran, num.trees = 20) +fit2 <- ranger(time ~ . - status, data = veteran, num.trees = 20) +fit3 <- ranger(time ~ . - status, data = veteran, quantreg = TRUE, num.trees = 20) +fit4 <- ranger(status ~ . - time, data = veteran, probability = TRUE, num.trees = 20) + +xvars <- setdiff(colnames(veteran), c("time", "status")) + +kernelshap(fit, head(veteran), feature_names = xvars, bg_X = veteran) +permshap(fit, head(veteran), feature_names = xvars, bg_X = veteran) + +kernelshap(fit, head(veteran), feature_names = xvars, bg_X = veteran, survival = "prob") +permshap(fit, head(veteran), feature_names = xvars, bg_X = veteran, survival = "prob") + +kernelshap(fit2, head(veteran), feature_names = xvars, bg_X = veteran) +permshap(fit2, head(veteran), feature_names = xvars, bg_X = veteran) + +kernelshap(fit3, head(veteran), feature_names = xvars, bg_X = veteran, type = "quantiles") +permshap(fit3, head(veteran), feature_names = xvars, bg_X = veteran, type = "quantiles") + +kernelshap(fit4, head(veteran), feature_names = xvars, bg_X = veteran) +permshap(fit4, head(veteran), feature_names = xvars, bg_X = veteran) +