Skip to content

Commit

Permalink
Merge pull request #114 from ModelOriented/mshap_consistency
Browse files Browse the repository at this point in the history
Add stricter mshapviz interface
  • Loading branch information
mayer79 authored Oct 18, 2023
2 parents 134a346 + 3e8f11b commit 84315c4
Show file tree
Hide file tree
Showing 7 changed files with 232 additions and 186 deletions.
4 changes: 4 additions & 0 deletions NEWS.md
Original file line number Diff line number Diff line change
@@ -1,5 +1,9 @@
# shapviz 0.9.3

## User-visible changes

- `mshapviz()` is more strict when combining multiple "shapviz" objects. These now need to have identical column names.

## Other changes

- Re-activate all unit tests.
Expand Down
17 changes: 14 additions & 3 deletions R/shapviz.R
Original file line number Diff line number Diff line change
Expand Up @@ -462,9 +462,9 @@ shapviz.H2OModel = function(object, X_pred, X = as.data.frame(X_pred),
)
}

#' Concatenates "shapviz" Objects
#' Combines compatible "shapviz" Objects
#'
#' This function combines a list of "shapviz" objects to an object of class
#' This function combines a list of compatible "shapviz" objects to an object of class
#' "mshapviz". The elements can be named.
#'
#' @param object List of "shapviz" objects to be concatenated.
Expand All @@ -479,10 +479,21 @@ shapviz.H2OModel = function(object, X_pred, X = as.data.frame(X_pred),
#' s <- mshapviz(c(shp1 = s1, shp2 = s2))
#' s
mshapviz <- function(object, ...) {
stopifnot("'object' must be a list of 'shapviz' objects" = is.list(object))
stopifnot("'object' must be a list" = is.list(object))
if (!all(vapply(object, is.shapviz, FUN.VALUE = logical(1)))) {
stop("Must pass list of 'shapviz' objects")
}
nms <- lapply(object, colnames)
if (!all(vapply(nms, identical, y = nms[[1L]], FUN.VALUE = logical(1)))) {
stop("'shapviz' objects need to have identical column names")
}
# Plot methods using interactions and do.call(rbind, ...) will fail, other plots are ok
# inter <- vapply(
# object, function(x) is.null(get_shap_interactions(x)), FUN.VALUE = logical(1)
# )
# if (!(all(inter) || !any(inter))) {
# stop("Some 'shapviz' objects have SHAP interactions, some not.")
# }
structure(object, class = "mshapviz")
}

Expand Down
4 changes: 2 additions & 2 deletions man/mshapviz.Rd

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

30 changes: 24 additions & 6 deletions tests/testthat/test-interface.R
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,10 @@ test_that("get_* functions work", {
expect_equal(4, get_baseline(mshp)[[1L]])
expect_equal(S, get_shap_values(mshp)[[1L]])
expect_equal(X, get_feature_values(mshp)[[1L]])

expect_error(get_baseline(3))
expect_error(get_shap_values("a"))
expect_error(get_feature_values(c(3, 9)))
})

test_that("dim, nrow, ncol, colnames work", {
Expand Down Expand Up @@ -176,6 +180,8 @@ mshp_inter <- c(shp1 = shp_inter, shp2 = shp_inter + shp_inter)

test_that("get_shap_interactions, +, rbind works for interactions", {
expect_equal(S_inter, get_shap_interactions(shp_inter))
expect_equal(length(get_shap_interactions(mshp_inter)), 2L)
expect_error(get_shap_interactions(4))
expect_equal(dim((shp_inter + shp_inter)$S_inter)[1L], 2 * dim(shp_inter$S_inter)[1L])
expect_equal(
dim(rbind(shp_inter, shp_inter, shp_inter)$S_inter)[1L],
Expand Down Expand Up @@ -211,7 +217,24 @@ test_that("mshapviz object contains original shapviz objects", {
expect_equal(mshp_inter[[2L]][1:nrow(shp_inter)], shp_inter)
})

# # Multiclass with XGBoost
test_that("shapviz objects with interactions can be rowbinded", {
expect_equal(dim(rbind(shp_inter, shp_inter)), dim(shp_inter) * (2:1))
expect_error(rbind(shp_inter, shp))
})

# Check on mshapviz
test_that("combining non-shapviz objects fails", {
expect_error(c(shp, 1))
expect_error(mshapviz(list(1, 2)))
})

test_that("combining incompatible shapviz objects fails", {
shp2 <- shp[, "x"]
expect_error(mshapviz(list(shp, shp2)))
expect_error(c(shp, shp2))
})

# Multiclass with XGBoost
X_pred <- data.matrix(iris[, -5L])
dtrain <- xgboost::xgb.DMatrix(X_pred, label = as.integer(iris[, 5L]) - 1L)
fit <- xgboost::xgb.train(
Expand Down Expand Up @@ -242,8 +265,3 @@ test_that("combining shapviz on classes 1, 2, 3 equal mshapviz", {
expect_equal(mshp, mshapviz(list(Class_1 = shp1, Class_2 = shp2, Class_3 = shp3)))
})

test_that("combining non-shapviz objects fails", {
expect_error(c(shp3, 1))
expect_error(mshapviz(1, 2))
})

13 changes: 13 additions & 0 deletions tests/testthat/test-plots-mshapviz.R
Original file line number Diff line number Diff line change
Expand Up @@ -55,6 +55,18 @@ test_that("dependence plots work for interactions = TRUE", {
)
})

test_that("shapviz objects w/o interactions can be combined and used for most things", {
x_temp1 <- shapviz(fit, X_pred = dtrain, X = iris[, -1L], interactions = TRUE)
x_temp2 <- shapviz(fit, X_pred = dtrain, X = iris[, -1L])
expect_no_error(x_inter_m <- c(x_temp1, x_temp2))
expect_error(do.call(rbind, x_inter_m))
expect_error(sv_interaction(x_inter_m))
expect_s3_class(sv_importance(x_inter_m), "ggplot")
expect_s3_class(sv_dependence(x_inter_m, "Sepal.Width"), "patchwork")
expect_error(sv_dependence(x_inter_m, "Sepal.Width", interactions = TRUE))
expect_equal(sapply(get_shap_interactions(x_inter_m), is.null), c(FALSE, TRUE))
})

test_that("main effect plots equal case color_var = v", {
expect_equal(
sv_dependence(x_inter, "Petal.Length", color_var = NULL, interactions = TRUE),
Expand Down Expand Up @@ -118,3 +130,4 @@ test_that("sv_dependence() does not work with multiple v", {
expect_error(sv_dependence2D(x, x = c("Species", "Sepal.Width"), y = "Petal.Width"))
expect_error(sv_dependence2D(x, x = "Petal.Width", y = c("Species", "Sepal.Width")))
})

Loading

0 comments on commit 84315c4

Please sign in to comment.