Skip to content

Commit

Permalink
Merge pull request #75 from ModelOriented/release_candidate
Browse files Browse the repository at this point in the history
Release candidate
  • Loading branch information
mayer79 authored Apr 11, 2023
2 parents ed6a7d6 + 2be114a commit e561dd1
Show file tree
Hide file tree
Showing 24 changed files with 2,358 additions and 4,652 deletions.
6 changes: 3 additions & 3 deletions CRAN-SUBMISSION
Original file line number Diff line number Diff line change
@@ -1,3 +1,3 @@
Version: 0.6.0
Date: 2023-03-05 16:47:24 UTC
SHA: d9ee1aa8903020cb4ff34149f6e622a137be42d5
Version: 0.7.0
Date: 2023-04-10 16:20:03 UTC
SHA: 1d5cdf85049baa8819da6a36b1ee91cdf6637a69
35 changes: 29 additions & 6 deletions NEWS.md
Original file line number Diff line number Diff line change
@@ -1,15 +1,32 @@
# shapviz 0.7.0

## New features
## Milestone: Working with multiple 'shapviz' objects

- Multiple models: Use `c(xgb = s1, rf = s2, ...)` or `mshapviz(list(xgb = s1, rf = s2, ...))` to combine multiple "shapviz" objects to a "mshapviz" object. Their plots are glued together by the {patchwork} package and can modified, e.g., using `&` and other {patchwork} functionalities.
- Multiclass: Another way to create a "mshapviz" object is to call `shapviz()` to multiclass XGBoost/LightGBM/kernelshap objects.
Sometimes, you will find it necessary to work with several "shapviz" objects at the same time:

- To visualize SHAP values of a multiclass or multi-output model.
- To compare SHAP plots of different models.
- To compare SHAP plots between subgroups.

To simplify the workflow, {shapviz} introduces the "mshapviz" object ("m" like "multi"). You can create it in different ways:

- Use `shapviz()` on multiclass XGBoost or LightGBM models.
- Use `shapviz()` on "kernelshap" objects created from multiclass/multioutput models.
- Use `c(Mod_1 = s1, Mod_2 = s2, ...)` on "shapviz" objects `s1`, `s2`, ...
- Or `mshapviz(list(Mod_1 = s1, Mod_2 = s2, ...))`

The `sv_*()` functions use the {patchwork} package to glue the individual plots together.

See the new vignette for more info and specific examples.

## Other new features

- `sv_dependence()` now allows multiple `v` and/or `color_var` to be plotted (glued via {patchwork}).
- {DALEX}: Support for "predict_parts" objects from {DALEX}, thanks to Adrian Stando.
- Aggregated SHAP values: The argument `row_id` of `sv_waterfall()` and `sv_force()` now also allows a vector of integers or a logical vector. If more than one row is selected, SHAP values and predictions are averaged before plotting (*aggregated SHAP values* in {DALEX}).
- Row bind: "shapviz" objects `x1`, `x2` can now be concatenated in rowwise manner using `x1 + x2` or `rbind(x1, x2)`, again thanks to Adrian.
- `colnames()`: "shapviz" objects `x` have received a `dimnames()` function, so you can now, e.g., use `colnames(x)` to see the feature names.
- Subsetting: "shapviz" `x` can now be subsetted using `x[cond, features]`.
- New vignette on working with multiple "shapviz" objects.

## Maintenance

Expand All @@ -18,13 +35,19 @@
- Webpage created with "pgkdown"
- New dependency: {patchwork}

## Other changes and bug fixes
## Other changes

- Color guides are closer to the plot area. This affects `sv_dependence()`, `sv_importance(kind="bee")`, and `sv_interaction()`.
- The lengthy y axis title "SHAP interaction value" in `sv_dependence()` has been shortened to "SHAP interaction".
- As announced, the argument `show_other` of `sv_importance()` has been removed.
- Slightly less picky checks on `S_inter`.
- `sv_waterfall()`: Using `order_fun()` would not work as expected with `max_display`.
- `print.shapviz()` is much more compact, use `summary.shapviz()` for more info.

## Bug fixes

- `sv_waterfall()`: Using `order_fun()` would not work as expected with `max_display`. This has been fixed.
- `sv_dependence()`: Passing `viridis_args = NULL` would hide the color guide title. This has been fixed. But please pass `viridis_args = list()` instead.

# shapviz 0.6.0

## Change in defaults
Expand Down
51 changes: 44 additions & 7 deletions R/sv_dependence.R
Original file line number Diff line number Diff line change
@@ -1,27 +1,31 @@
#' SHAP Dependence Plot
#'
#' Scatter plot of the SHAP values of a feature against its feature values.
#' Scatterplot of the SHAP values of a feature against its feature values.
#' If SHAP interaction values are available, setting \code{interactions = TRUE} allows
#' to focus on pure interaction effects (multiplied by two) or on pure main effects.
#'
#' @importFrom rlang .data
#' @param object An object of class "(m)shapviz".
#' @param v Column name of feature to be plotted.
#' Can be a vector/list if \code{object} is of class "shapviz".
#' @param color_var Feature name to be used on the color scale to investigate interactions.
#' The default ("auto") uses SHAP interaction values (if available) or a heuristic to
#' select the strongest interacting feature. Set to \code{NULL} to not use the color axis.
#' Can be a vector/list if \code{object} is of class "shapviz".
#' @param color Color to be used if \code{color_var = NULL}.
#' Can be a vector/list if \code{v} is a vector.
#' @param viridis_args List of viridis color scale arguments, see
#' \code{?ggplot2::scale_color_viridis_c()}. The default points to the global
#' option \code{shapviz.viridis_args}, which corresponds to
#' \code{list(begin = 0.25, end = 0.85, option = "inferno")}.
#' These values are passed to \code{ggplot2::scale_color_viridis_*()}.
#' For example, to switch to a standard viridis scale, you can either change the default
#' with \code{options(shapviz.viridis_args = NULL)} or set \code{viridis_args = NULL}.
#' with \code{options(shapviz.viridis_args = list())} or set \code{viridis_args = list()}.
#' Only relevant if \code{color_var} is not \code{NULL}.
#' @param jitter_width The amount of horizontal jitter. The default (\code{NULL}) will
#' use a value of 0.2 in case \code{v} is discrete, and no jitter otherwise.
#' (Numeric variables are considered discrete if they have at most 7 unique values.)
#' Can be a vector/list if \code{v} is a vector.
#' @param interactions Should SHAP interaction values be plotted? Default is \code{FALSE}.
#' Requires SHAP interaction values. If \code{color_var = NULL} (or it is equal to
#' \code{v}), the pure main effect of \code{v} is visualized. Otherwise, twice the SHAP
Expand All @@ -35,11 +39,13 @@
#' sv_dependence(x, "Petal.Length")
#' sv_dependence(x, "Petal.Length", color_var = "Species")
#' sv_dependence(x, "Petal.Length", color_var = NULL)
#' sv_dependence(x, c("Species", "Petal.Length"))
#' sv_dependence(x, "Petal.Width", color_var = c("Species", "Petal.Length"))
#'
#' # SHAP interaction values
#' x2 <- shapviz(fit, X_pred = dtrain, X = iris, interactions = TRUE)
#' sv_dependence(x2, "Petal.Length", interactions = TRUE)
#' sv_dependence(x2, "Petal.Length", color_var = NULL, interactions = TRUE)
#' sv_dependence(x2, c("Petal.Length", "Species"), color_var = NULL, interactions = TRUE)
#'
#' # Show main effect of "Petal.Length" for setosa and virginica separately
#' mx <- c(
Expand All @@ -64,12 +70,38 @@ sv_dependence.default <- function(object, ...) {
sv_dependence.shapviz <- function(object, v, color_var = "auto", color = "#3b528b",
viridis_args = getOption("shapviz.viridis_args"),
jitter_width = NULL, interactions = FALSE, ...) {
p <- length(v)
if (p > 1L || length(color_var) > 1L) {
if (is.null(color_var)) {
color_var <- replicate(p, NULL)
}
if (is.null(jitter_width)) {
jitter_width <- replicate(p, NULL)
}
plot_list <- mapply(
FUN = sv_dependence,
v = v,
color_var = color_var,
color = color,
jitter_width = jitter_width,
MoreArgs = list(
object = object,
viridis_args = viridis_args,
interactions = interactions,
...
),
SIMPLIFY = FALSE
)
nms <- if (length(v) > 1L) v
plot_list <- add_titles(plot_list, nms = nms) # see sv_waterfall()
return(patchwork::wrap_plots(plot_list))
}

S <- get_shap_values(object)
X <- get_feature_values(object)
S_inter <- get_shap_interactions(object)
nms <- colnames(object)
stopifnot(
length(v) == 1L,
v %in% nms,
is.null(color_var) || (color_var %in% c("auto", nms))
)
Expand All @@ -94,7 +126,7 @@ sv_dependence.shapviz <- function(object, v, color_var = "auto", color = "#3b528
if (color_var == v) {
y_lab <- "SHAP main effect"
} else {
y_lab <- "SHAP interaction value"
y_lab <- "SHAP interaction"
}
s <- S_inter[, v, color_var]
if (color_var != v) {
Expand All @@ -119,19 +151,24 @@ sv_dependence.shapviz <- function(object, v, color_var = "auto", color = "#3b528
vir <- scale_color_viridis_c
}
if (is.null(viridis_args)) {
viridis_args <- list(NULL)
viridis_args <- list()
}
ggplot(dat, aes(x = .data[[v]], y = shap, color = .data[[color_var]])) +
geom_jitter(width = jitter_width, height = 0, ...) +
ylab(y_lab) +
do.call(vir, viridis_args)
do.call(vir, viridis_args) +
theme(legend.box.spacing = grid::unit(0, "pt"))
}

#' @describeIn sv_dependence SHAP dependence plot for "mshapviz" object.
#' @export
sv_dependence.mshapviz <- function(object, v, color_var = "auto", color = "#3b528b",
viridis_args = getOption("shapviz.viridis_args"),
jitter_width = NULL, interactions = FALSE, ...) {
stopifnot(
length(v) == 1L,
length(color_var) <= 1L
)
plot_list <- lapply(
object,
FUN = sv_dependence,
Expand Down
5 changes: 3 additions & 2 deletions R/sv_importance.R
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@
#' corresponds to \code{list(begin = 0.25, end = 0.85, option = "inferno")}.
#' These values are passed to \code{ggplot2::scale_color_viridis_c()}.
#' For example, to switch to a standard viridis scale, you can either change the default
#' with \code{options(shapviz.viridis_args = NULL)} or set \code{viridis_args = NULL}.
#' with \code{options(shapviz.viridis_args = list())} or set \code{viridis_args = list()}.
#' @param color_bar_title Title of color bar of the beeswarm plot.
#' Set to \code{NULL} to hide the color bar altogether.
#' @param show_numbers Should SHAP feature importances be printed?
Expand Down Expand Up @@ -127,7 +127,8 @@ sv_importance.shapviz <- function(object, kind = c("bar", "beeswarm", "both", "n
bar = !is.null(color_bar_title),
ncol = length(unique(df$color)) # Special case of constant feature values
) +
labs(x = "SHAP value", y = element_blank(), color = color_bar_title)
labs(x = "SHAP value", y = element_blank(), color = color_bar_title) +
theme(legend.box.spacing = grid::unit(0, "pt"))
}
if (show_numbers) {
p <- p +
Expand Down
5 changes: 3 additions & 2 deletions R/sv_interaction.R
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@
#' corresponds to \code{list(begin = 0.25, end = 0.85, option = "inferno")}.
#' These values are passed to \code{ggplot2::scale_color_viridis_c()}.
#' For example, to switch to a standard viridis scale, you can either change the default
#' with \code{options(shapviz.viridis_args = NULL)} or set \code{viridis_args = NULL}.
#' with \code{options(shapviz.viridis_args = list())} or set \code{viridis_args = list()}.
#' @param color_bar_title Title of color bar of the beeswarm plot.
#' Set to \code{NULL} to hide the color bar altogether.
#' @param ... Arguments passed to \code{geom_point()}.
Expand Down Expand Up @@ -106,7 +106,8 @@ sv_interaction.shapviz <- function(object, kind = c("beeswarm", "no"),
ncol = length(unique(X_long$Freq))
) +
theme(
panel.spacing = unit(0.2, "lines"),
panel.spacing = grid::unit(0.2, "lines"),
legend.box.spacing = grid::unit(0, "pt"),
axis.ticks.y = element_blank(),
axis.text.y = element_blank()
)
Expand Down
61 changes: 31 additions & 30 deletions README.md
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
# shapviz <a href='https://github.com/mayer79/shapviz'><img src='man/figures/logo.png' align="right" height="139" /></a>
# {shapviz} <a href='https://github.com/ModelOriented/shapviz'><img src='man/figures/logo.png' align="right" height="139" /></a>

<!-- badges: start -->

Expand All @@ -11,7 +11,7 @@

<!-- badges: end -->

## Introduction
## Overview

SHAP (SHapley Additive exPlanations, [1]) is an ingenious way to study black box models. SHAP values decompose - as fair as possible - predictions into additive feature contributions. Crunching SHAP values requires clever algorithms by clever people. Analyzing them, however, is super easy with the right visualizations. {shapviz} offers the latter:

Expand Down Expand Up @@ -39,12 +39,12 @@ To further simplify the use of {shapviz}, we added direct connectors to:
- [`kernelshap`](https://CRAN.R-project.org/package=kernelshap)
- [`fastshap`](https://CRAN.R-project.org/package=fastshap)
- [`shapr`](https://CRAN.R-project.org/package=shapr)
- [`treeshap`](https://github.com/ModelOriented/treeshap)
- [`DALEX`](https://cran.r-project.org/web/packages/DALEX)
- [`treeshap`](https://github.com/ModelOriented/treeshap/)
- [`DALEX`](https://CRAN.R-project.org/package=DALEX)

For XGBoost, LightGBM, and H2O, the SHAP values are directly calculated from the fitted model.

[`CatBoost`](https://github.com/catboost) is not included, but see the vignette how to use its SHAP calculation backend with {shapviz}.
[`CatBoost`](https://github.com/catboost/) is not included, but see the vignette how to use its SHAP calculation backend with {shapviz}.

Multiple "shapviz" objects can be glued together, see Vignette "Multiple shapviz objects".

Expand All @@ -59,9 +59,9 @@ install.packages("shapviz")
devtools::install_github("mayer79/shapviz")
```

## Example
## Usage

Shiny diamonds... let's model their prices by four "c" variables with XGBoost:
Shiny diamonds... let's use XGBoost to model their prices by the four "C" variables:

### Model

Expand All @@ -72,18 +72,9 @@ library(xgboost)

set.seed(3653)

# Explanation data
dia_small <- diamonds[sample(nrow(diamonds), 2000L), ]

# XGBoost model
x <- c("carat", "cut", "color", "clarity")
dtrain <- xgb.DMatrix(data.matrix(diamonds[x]), label = diamonds$price)

fit <- xgb.train(
params = list(learning_rate = 0.1, objective = "reg:squarederror"),
data = dtrain,
nrounds = 65L
)
fit <- xgb.train(params = list(learning_rate = 0.1), data = dtrain, nrounds = 65L)
```

### Create "shapviz" object
Expand All @@ -93,6 +84,9 @@ One line of code creates a "shapviz" object. It contains SHAP values and feature
In this example, we construct the "shapviz" object directly from the fitted XGBoost model. Thus we also need to pass a corresponding prediction dataset `X_pred` used for calculating SHAP values by XGBoost.

``` r
# Explanation data
dia_small <- diamonds[sample(nrow(diamonds), 2000L), ]

shp <- shapviz(fit, X_pred = data.matrix(dia_small[x]), X = dia_small)
```

Expand Down Expand Up @@ -148,14 +142,6 @@ sv_importance(shp, kind = "beeswarm")

![](man/figures/README-imp2.png)

#### Or both combined

``` r
sv_importance(shp, kind = "both", show_numbers = TRUE, bee_width = 0.2)
```

![](man/figures/README-imp3.png)

### Dependence plot

A scatterplot of SHAP values of a feature like `color` against its observed values gives a great impression on the feature effect on the response. Vertical scatter gives additional info on interaction effects (using a heuristic to select the feature on the color axis).
Expand All @@ -166,24 +152,39 @@ sv_dependence(shp, v = "color")

![](man/figures/README-dep.svg)

Or multiple features together, using {patchwork}:

``` r
library(patchwork) # We need the & operator

sv_dependence(shp, v = x) &
theme_gray(base_size = 9) &
ylim(-5000, 15000)
```

![](man/figures/README-dep-multi.png)

### Interactions

If SHAP interaction values have been computed (via {xgboost} or {treeshap}), the dependence plot can focus on main effects or SHAP interaction effects (multiplied by two due to symmetry):
If SHAP interaction values have been computed (via {xgboost} or {treeshap}), the dependence plot can focus on main effects or SHAP interaction effects (multiplied by two due to symmetry).

``` r
shp_with_inter <- shapviz(
shp_i <- shapviz(
fit, X_pred = data.matrix(dia_small[x]), X = dia_small, interactions = TRUE
)

sv_dependence(shp_with_inter, v = "color", color_var = "cut", interactions = TRUE)
# Main effect of carat and its interactions
sv_dependence(
shp_i, v = "carat", color_var = x, interactions = TRUE) &
ylim(-6000, 13000)
```

![](man/figures/README-dep2.svg)
![](man/figures/README-dep2.png)

We can also study all interactions and main effects together using the following beeswarm visualization:

```{r}
sv_interaction(shp_with_inter) +
sv_interaction(shp_i) +
theme(axis.text.x = element_text(angle = 45, vjust = 1, hjust = 1))
```

Expand Down
23 changes: 19 additions & 4 deletions cran-comments.md
Original file line number Diff line number Diff line change
@@ -1,13 +1,27 @@
# Submission of shapviz 0.6.0
# Re-Submission of {shapviz} 0.7.0

Dear CRAN team. The dependence plot now uses better defaults. As they are user visible in some cases, the version jumps from 0.5.0 to 0.6.0.
This second resubmission removes non-standard file (sorry, my bad!)

# Re-Submission of {shapviz} 0.7.0

This re-submission fixes non-standard links in the new vignette (and the README).

## Original message

Dear CRAN team.

- {shapviz} can now deal with multiclass models or SHAP values of multiple models. Hurray ;).
- Many additional features
- New contributor
- Additional vignette
- New home: github/ModelOriented/shapviz

## Checks

### check(manual = TRUE, cran = TRUE)

-> WARNING
'qpdf' is needed for checks on size reduction of PDFs
- WARNING: 'qpdf' is needed for checks on size reduction of PDFs
- Note: unable to verify current time

### check_rhub()

Expand All @@ -21,3 +35,4 @@ Found the following files/directories:
### check_win_devel()

Status: OK

Binary file added man/figures/README-dep-multi.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Loading

0 comments on commit e561dd1

Please sign in to comment.