diff --git a/DESCRIPTION b/DESCRIPTION index f90f77f..01b3ad6 100644 --- a/DESCRIPTION +++ b/DESCRIPTION @@ -1,6 +1,6 @@ Package: shapviz Title: SHAP Visualizations -Version: 0.9.3 +Version: 0.9.4 Authors@R: c( person("Michael", "Mayer", , "mayermichael79@gmail.com", role = c("aut", "cre")), person("Adrian", "Stando", , "adrian.j.stando@gmail.com", role = "ctb") @@ -21,7 +21,7 @@ Depends: R (>= 3.6.0) Encoding: UTF-8 Roxygen: list(markdown = TRUE) -RoxygenNote: 7.2.3 +RoxygenNote: 7.3.1 Imports: ggfittext (>= 0.8.0), gggenes, diff --git a/NEWS.md b/NEWS.md index 233e3cc..4d07f96 100644 --- a/NEWS.md +++ b/NEWS.md @@ -1,3 +1,9 @@ +# shapviz 0.9.4 + +## Improvements + +- New argument `sort_features = TRUE` in `sv_importance()` and `sv_interaction()`. Set to `FALSE` to show the features as they appear in your SHAP matrix. In that case, the plots will show the *first* `max_display` features, not the *most important* features. Implements #136. + # shapviz 0.9.3 ## `sv_dependence()`: Control over automatic color feature selection diff --git a/R/sv_importance.R b/R/sv_importance.R index 64a58f1..f9d44a5 100644 --- a/R/sv_importance.R +++ b/R/sv_importance.R @@ -14,7 +14,7 @@ #' @param kind Should a "bar" plot (the default), a "beeswarm" plot, or "both" be shown? #' Set to "no" in order to suppress plotting. In that case, the sorted #' SHAP feature importances of all variables are returned. -#' @param max_display Maximum number of features (with highest importance) to plot. +#' @param max_display How many features should be plotted? #' Set to `Inf` to show all features. Has no effect if `kind = "no"`. #' @param fill Color used to fill the bars (only used if bars are shown). #' @param bar_width Relative width of the bars (only used if bars are shown). @@ -38,6 +38,7 @@ #' (only if `show_numbers = TRUE`). To change to scientific notation, use #' `function(x) = prettyNum(x, scientific = TRUE)`. #' @param number_size Text size of the numbers (if `show_numbers = TRUE`). 
+#' @param sort_features Should features be sorted or not? The default is `TRUE`. #' @param ... Arguments passed to [ggplot2::geom_bar()] (if `kind = "bar"`) or to #' [ggplot2::geom_point()] otherwise. For instance, passing `alpha = 0.2` will produce #' semi-transparent beeswarms, and setting `size = 3` will produce larger dots. @@ -75,10 +76,10 @@ sv_importance.shapviz <- function(object, kind = c("bar", "beeswarm", "both", "n viridis_args = getOption("shapviz.viridis_args"), color_bar_title = "Feature value", show_numbers = FALSE, format_fun = format_max, - number_size = 3.2, ...) { + number_size = 3.2, sort_features = TRUE, ...) { stopifnot("format_fun must be a function" = is.function(format_fun)) kind <- match.arg(kind) - imp <- .get_imp(get_shap_values(object)) + imp <- .get_imp(get_shap_values(object), sort_features = sort_features) if (kind == "no") { return(imp) @@ -162,13 +163,13 @@ sv_importance.mshapviz <- function(object, kind = c("bar", "beeswarm", "both", " viridis_args = getOption("shapviz.viridis_args"), color_bar_title = "Feature value", show_numbers = FALSE, format_fun = format_max, - number_size = 3.2, ...) { + number_size = 3.2, sort_features = TRUE, ...) { kind <- match.arg(kind) bar_type <- match.arg(bar_type) # All other cases are done via {patchwork} if (kind %in% c("bar", "no") && bar_type != "separate") { - imp <- .get_imp(get_shap_values(object)) + imp <- .get_imp(get_shap_values(object), sort_features = sort_features) if (kind == "no") { return(imp) } @@ -223,6 +224,7 @@ sv_importance.mshapviz <- function(object, kind = c("bar", "beeswarm", "both", " show_numbers = show_numbers, format_fun = format_fun, number_size = number_size, + sort_features = sort_features, ... 
) if (kind == "no") { @@ -243,13 +245,20 @@ sv_importance.mshapviz <- function(object, kind = c("bar", "beeswarm", "both", " (z - r[1L]) /(r[2L] - r[1L]) } -.get_imp <- function(z) { +.get_imp <- function(z, sort_features = TRUE) { if (is.matrix(z)) { - return(sort(colMeans(abs(z)), decreasing = TRUE)) + imp <- colMeans(abs(z)) + if (sort_features) { + imp <- sort(imp, decreasing = TRUE) + } + return(imp) } # list/mshapviz imp <- sapply(z, function(x) colMeans(abs(x))) - imp[order(-rowSums(imp)), ] + if (sort_features) { + imp <- imp[order(-rowSums(imp)), ] + } + return(imp) } .scale_X <- function(X) { diff --git a/R/sv_interaction.R b/R/sv_interaction.R index f205408..cada24b 100644 --- a/R/sv_interaction.R +++ b/R/sv_interaction.R @@ -45,12 +45,13 @@ sv_interaction.shapviz <- function(object, kind = c("beeswarm", "no"), max_display = 7L, alpha = 0.3, bee_width = 0.3, bee_adjust = 0.5, viridis_args = getOption("shapviz.viridis_args"), - color_bar_title = "Row feature value", ...) { + color_bar_title = "Row feature value", + sort_features = TRUE, ...) { kind <- match.arg(kind) if (is.null(get_shap_interactions(object))) { stop("No SHAP interaction values available.") } - ord <- names(.get_imp(get_shap_values(object))) + ord <- names(.get_imp(get_shap_values(object), sort_features = sort_features)) object <- object[, ord] if (kind == "no") { @@ -112,7 +113,8 @@ sv_interaction.mshapviz <- function(object, kind = c("beeswarm", "no"), max_display = 7L, alpha = 0.3, bee_width = 0.3, bee_adjust = 0.5, viridis_args = getOption("shapviz.viridis_args"), - color_bar_title = "Row feature value", ...) { + color_bar_title = "Row feature value", + sort_features = TRUE, ...) { kind <- match.arg(kind) plot_list <- lapply( @@ -126,6 +128,7 @@ sv_interaction.mshapviz <- function(object, kind = c("beeswarm", "no"), bee_adjust = bee_adjust, viridis_args = viridis_args, color_bar_title = color_bar_title, + sort_features = sort_features, ... 
) if (kind == "no") { diff --git a/man/shapviz-package.Rd b/man/shapviz-package.Rd index 64bfc6e..5237662 100644 --- a/man/shapviz-package.Rd +++ b/man/shapviz-package.Rd @@ -3,7 +3,6 @@ \docType{package} \name{shapviz-package} \alias{shapviz-package} -\alias{_PACKAGE} \title{shapviz: SHAP Visualizations} \description{ \if{html}{\figure{logo.png}{options: style='float: right' alt='logo' width='120'}} @@ -14,6 +13,7 @@ Visualizations for SHAP (SHapley Additive exPlanations), such as waterfall plots Useful links: \itemize{ \item \url{https://github.com/ModelOriented/shapviz} + \item \url{https://modeloriented.github.io/shapviz/} \item Report bugs at \url{https://github.com/ModelOriented/shapviz/issues} } diff --git a/man/sv_importance.Rd b/man/sv_importance.Rd index 5c8617f..1fc5ffd 100644 --- a/man/sv_importance.Rd +++ b/man/sv_importance.Rd @@ -24,6 +24,7 @@ sv_importance(object, ...) show_numbers = FALSE, format_fun = format_max, number_size = 3.2, + sort_features = TRUE, ... ) @@ -41,6 +42,7 @@ sv_importance(object, ...) show_numbers = FALSE, format_fun = format_max, number_size = 3.2, + sort_features = TRUE, ... ) } @@ -55,7 +57,7 @@ semi-transparent beeswarms, and setting \code{size = 3} will produce larger dots Set to "no" in order to suppress plotting. In that case, the sorted SHAP feature importances of all variables are returned.} -\item{max_display}{Maximum number of features (with highest importance) to plot. +\item{max_display}{How many features should be plotted? Set to \code{Inf} to show all features. Has no effect if \code{kind = "no"}.} \item{fill}{Color used to fill the bars (only used if bars are shown).} @@ -85,6 +87,8 @@ to hide the color bar altogether.} \item{number_size}{Text size of the numbers (if \code{show_numbers = TRUE}).} +\item{sort_features}{Should features be sorted or not? The default is \code{TRUE}.} + \item{bar_type}{For "mshapviz" objects with \code{kind = "bar"}: How should bars be represented? 
The default is "dodge" for dodged bars. Other options are "stack", "wrap", or "separate" (via "patchwork"). Note that "separate" is currently diff --git a/man/sv_interaction.Rd b/man/sv_interaction.Rd index a72d0c7..fb90b33 100644 --- a/man/sv_interaction.Rd +++ b/man/sv_interaction.Rd @@ -20,6 +20,7 @@ sv_interaction(object, ...) bee_adjust = 0.5, viridis_args = getOption("shapviz.viridis_args"), color_bar_title = "Row feature value", + sort_features = TRUE, ... ) @@ -32,6 +33,7 @@ sv_interaction(object, ...) bee_adjust = 0.5, viridis_args = getOption("shapviz.viridis_args"), color_bar_title = "Row feature value", + sort_features = TRUE, ... ) } @@ -45,7 +47,7 @@ passing \code{size = 1} will produce smaller dots.} interactions (or a list of such matrices in case of object of class "mshapviz"). Due to symmetry, off-diagonals are multiplied by two. The default is "beeswarm".} -\item{max_display}{Maximum number of features (with highest importance) to plot. +\item{max_display}{How many features should be plotted? Set to \code{Inf} to show all features. Has no effect if \code{kind = "no"}.} \item{alpha}{Transparency of the beeswarm dots. Defaults to 0.3.} @@ -64,6 +66,8 @@ either change the default with \code{options(shapviz.viridis_args = list())} or \item{color_bar_title}{Title of color bar of the beeswarm plot. Set to \code{NULL} to hide the color bar altogether.} + +\item{sort_features}{Should features be sorted or not? The default is \code{TRUE}.} } \value{ A "ggplot" (or "patchwork") object, or - if \code{kind = "no"} - a named diff --git a/packaging.R b/packaging.R index ab02ecb..bca7388 100644 --- a/packaging.R +++ b/packaging.R @@ -15,7 +15,7 @@ library(usethis) use_description( fields = list( Title = "SHAP Visualizations", - Version = "0.9.3", + Version = "0.9.4", Description = "Visualizations for SHAP (SHapley Additive exPlanations), such as waterfall plots, force plots, various types of importance plots, dependence plots, and interaction plots. 
diff --git a/tests/testthat/test-plots-mshapviz.R b/tests/testthat/test-plots-mshapviz.R index f4d8833..a7d7a48 100644 --- a/tests/testthat/test-plots-mshapviz.R +++ b/tests/testthat/test-plots-mshapviz.R @@ -108,13 +108,13 @@ test_that("plots work for non-syntactic column names", { ) }) -test_that("sv_importance() and sv_interaction() and kind = 'no' gives matrix", { - X_pred <- data.matrix(iris[, -1L]) - dtrain <- xgboost::xgb.DMatrix(X_pred, label = iris[, 1L], nthread = 1) - fit <- xgboost::xgb.train(params = list(nthread = 1L), data = dtrain, nrounds = 1L) - x <- shapviz(fit, X_pred = X_pred, interactions = TRUE) - x <- c(m1 = x, m2 = x) +X_pred <- data.matrix(iris[, -1L]) +dtrain <- xgboost::xgb.DMatrix(X_pred, label = iris[, 1L], nthread = 1) +fit <- xgboost::xgb.train(params = list(nthread = 1L), data = dtrain, nrounds = 1L) +x <- shapviz(fit, X_pred = X_pred, interactions = TRUE) +x <- c(m1 = x, m2 = x) +test_that("sv_importance() and sv_interaction() and kind = 'no' gives matrix", { imp <- sv_importance(x, kind = "no") expect_true(is.matrix(imp) && all(dim(imp) == c(4L, length(x)))) @@ -122,6 +122,17 @@ test_that("sv_importance() and sv_interaction() and kind = 'no' gives matrix", { expect_true(is.list(inter) && all(dim(inter[[1L]]) == rep(ncol(X_pred), 2L))) }) + +test_that("sv_importance() and sv_interaction() respect sort_features = FALSE", { + imp <- sv_importance(x, kind = "no", sort_features = FALSE) + expect_true(all(rownames(imp) == colnames(x$m1))) + + inter <- sv_interaction(x, kind = "no", sort_features = FALSE) + expect_true(all(rownames(inter$m1) == colnames(x$m1))) +}) + + + test_that("sv_dependence() does not work with multiple v", { X_pred <- data.matrix(iris[, -1L]) dtrain <- xgboost::xgb.DMatrix(X_pred, label = iris[, 1L], nthread = 1) diff --git a/tests/testthat/test-plots-shapviz.R b/tests/testthat/test-plots-shapviz.R index 2aea6bf..b255e34 100644 --- a/tests/testthat/test-plots-shapviz.R +++ b/tests/testthat/test-plots-shapviz.R @@ 
-173,3 +173,16 @@ test_that("sv_importance() and sv_interaction() and kind = 'no' gives numeric ou expect_true(is.numeric(inter) && all(dim(inter) == rep(ncol(X_pred), 2L))) }) +test_that("sv_importance() and sv_interaction() respect sort_features = FALSE", { + X_pred <- data.matrix(iris[, -1L]) + dtrain <- xgboost::xgb.DMatrix(X_pred, label = iris[, 1L], nthread = 1) + fit <- xgboost::xgb.train(params = list(nthread = 1L), data = dtrain, nrounds = 1L) + x <- shapviz(fit, X_pred = X_pred, interactions = TRUE) + + imp <- sv_importance(x, kind = "no", sort_features = FALSE) + expect_true(all(names(imp) == colnames(x))) + + inter <- sv_interaction(x, kind = "no", sort_features = FALSE) + expect_true(all(rownames(inter) == colnames(x))) +}) +