diff --git a/NEWS.md b/NEWS.md index 3b98bae..7b8a946 100644 --- a/NEWS.md +++ b/NEWS.md @@ -1,14 +1,17 @@ # shapviz 0.10.0 -### New feature +### New visualization `sv_interaction()`: New `kind = "bar"` to show mean absolute SHAP interactions/main effects as barplots. Modify via `fill` and `bar_width` arguments [#169](https://github.com/ModelOriented/shapviz/pull/169). ### User-visible changes -- `sv_interaction()`: If applied to a "mshapviz" object, we use {patchwork} functionality to collect guides and axis titles. +- We are now (cautiously) collecting axes, axis titles, and color guides via {patchwork}. (Currently fails for `sv_force()`.) +### Minor API changes + +- In `sv_dependence()`, passing the same variable for `v` and `color_var` does not suppress the color axis anymore, except when `interactions = TRUE`. ### Maintenance diff --git a/R/sv_dependence.R b/R/sv_dependence.R index f5fc22f..d3c89c6 100644 --- a/R/sv_dependence.R +++ b/R/sv_dependence.R @@ -30,7 +30,7 @@ #' (Numeric variables are considered discrete if they have at most 7 unique values.) #' Can be a vector/list if `v` is a vector. #' @param interactions Should SHAP interaction values be plotted? Default is `FALSE`. -#' Requires SHAP interaction values. If `color_var = NULL` (or it is equal to `v`), +#' Requires SHAP interaction values. If `color_var = NULL` (or is equal to `v`), #' the pure main effect of `v` is visualized. Otherwise, twice the SHAP interaction #' values between `v` and the `color_var` are plotted. #' @param ih_nbins,ih_color_num,ih_scale,ih_adjusted Interaction heuristic (ih) @@ -40,7 +40,8 @@ #' @returns An object of class "ggplot" (or "patchwork") representing a dependence plot. #' @examples #' dtrain <- xgboost::xgb.DMatrix( -#' data.matrix(iris[, -1]), label = iris[, 1], nthread = 1 +#' data.matrix(iris[, -1]), +#' label = iris[, 1], nthread = 1 #' ) #' fit <- xgboost::xgb.train(data = dtrain, nrounds = 10, nthread = 1) #' x <- shapviz(fit, X_pred = dtrain, X = iris) @@ -54,7 +55,12 @@ #' x2 <- shapviz(fit, X_pred = dtrain, X = iris, interactions = TRUE) #' sv_dependence(x2, "Petal.Length", interactions = TRUE) #' sv_dependence( -#' x2, c("Petal.Length", "Species"), color_var = NULL, interactions = TRUE +#' x2, c("Petal.Length", "Species"), +#' color_var = NULL, interactions = TRUE +#' ) +#' sv_dependence( +#' x2, "Petal.Length", +#' color_var = colnames(iris[-1]), interactions = TRUE #' ) #' @export #' @seealso [potential_interactions()] @@ -72,42 +78,168 @@ sv_dependence.default <- function(object, ...) { #' @describeIn sv_dependence #' SHAP dependence plot for "shapviz" object. #' @export -sv_dependence.shapviz <- function(object, v, color_var = "auto", color = "#3b528b", - viridis_args = getOption("shapviz.viridis_args"), - jitter_width = NULL, interactions = FALSE, - ih_nbins = NULL, ih_color_num = TRUE, - ih_scale = FALSE, ih_adjusted = FALSE, ...) { - p <- length(v) - if (p > 1L || length(color_var) > 1L) { - if (is.null(color_var)) { - color_var <- replicate(p, NULL) - } - if (is.null(jitter_width)) { - jitter_width <- replicate(p, NULL) - } - plot_list <- mapply( - FUN = sv_dependence, +sv_dependence.shapviz <- function( + object, + v, + color_var = "auto", + color = "#3b528b", + viridis_args = getOption("shapviz.viridis_args"), + jitter_width = NULL, + interactions = FALSE, + ih_nbins = NULL, + ih_color_num = TRUE, + ih_scale = FALSE, + ih_adjusted = FALSE, + ...) { + nv <- length(v) + if (nv == 1L && length(color_var) <= 1L) { + p <- .one_dependence_plot( + object = object, v = v, color_var = color_var, color = color, + viridis_args = viridis_args, jitter_width = jitter_width, - MoreArgs = list( - object = object, - viridis_args = viridis_args, - interactions = interactions, - ih_nbins = ih_nbins, - ih_color_num = ih_color_num, - ih_scale = ih_scale, - ih_adjusted = ih_adjusted, - ... - ), - SIMPLIFY = FALSE - ) - nms <- if (length(v) > 1L) v - plot_list <- add_titles(plot_list, nms = nms) # see sv_waterfall() - return(patchwork::wrap_plots(plot_list)) + interactions = interactions, + ih_nbins = ih_nbins, + ih_color_num = ih_color_num, + ih_scale = ih_scale, + ih_adjusted = ih_adjusted, + ... + )$p + return(p) } + if (is.null(color_var)) { + color_var <- replicate(nv, NULL) + } + if (is.null(jitter_width)) { + jitter_width <- replicate(nv, NULL) + } + out_list <- mapply( + FUN = .one_dependence_plot, + v = v, + color_var = color_var, + color = color, + jitter_width = jitter_width, + MoreArgs = list( + object = object, + viridis_args = viridis_args, + interactions = interactions, + ih_nbins = ih_nbins, + ih_color_num = ih_color_num, + ih_scale = ih_scale, + ih_adjusted = ih_adjusted, + ... + ), + SIMPLIFY = FALSE + ) + + # Reorganize output + plot_list <- lapply(out_list, `[[`, "p") + y_labs <- vapply(out_list, `[[`, "y_lab", FUN.VALUE = character(1L)) + has_keys <- vapply(out_list, `[[`, "color_key", FUN.VALUE = logical(1L)) + color_vars <- lapply(out_list, `[[`, "color_var") # Elements NULL <=> has_keys = FALSE + + # Add titles if v varies + plot_list <- add_titles(plot_list, nms = if (nv > 1L) v) # see sv_waterfall() + + # Which aspects can be collected? + nvu <- length(unique(v)) + nlab <- length(unique(y_labs)) + + axis_titles <- axes <- guides <- "keep" + if (nvu == 1L && nlab == 1L) { + axis_titles <- "collect" + } else if (nvu == 1L) { + axis_titles <- "collect_x" + } else if (nlab == 1L) { + axis_titles <- "collect_y" + } + if (nvu == 1L) { + axes <- if (isFALSE(interactions)) "collect" else "collect_x" + } + if (isFALSE(interactions) && length(unique(color_vars[has_keys])) <= 1L) { + guides <- "collect" + } + + p <- patchwork::wrap_plots( + plot_list, + axis_titles = axis_titles, axes = axes, guides = guides + ) + + return(p) +} + + +#' @describeIn sv_dependence +#' SHAP dependence plot for "mshapviz" object. +#' @export +sv_dependence.mshapviz <- function( + object, + v, + color_var = "auto", + color = "#3b528b", + viridis_args = getOption("shapviz.viridis_args"), + jitter_width = NULL, + interactions = FALSE, + ih_nbins = NULL, + ih_color_num = TRUE, + ih_scale = FALSE, + ih_adjusted = FALSE, + ...) { + stopifnot( + length(v) == 1L, + length(color_var) <= 1L + ) + out_list <- lapply( + object, + FUN = .one_dependence_plot, + # Argument list (simplify via match.call() or some rlang magic?) + v = v, + color_var = color_var, + color = color, + viridis_args = viridis_args, + jitter_width = jitter_width, + interactions = interactions, + ih_nbins = ih_nbins, + ih_color_num = ih_color_num, + ih_scale = ih_scale, + ih_adjusted = ih_adjusted, + ... + ) + plot_list <- lapply(out_list, `[[`, "p") + plot_list <- add_titles(plot_list, nms = names(object)) # see sv_waterfall() + p <- patchwork::wrap_plots(plot_list, axis_titles = "collect") + + return(p) +} + +# Helper functions + +# Checks if z is discrete +.is_discrete <- function(z, n_unique) { + is.factor(z) || is.character(z) || is.logical(z) || (length(unique(z)) <= n_unique) +} + +# Returns a list with the following elements: +# - p: the ggplot object +# - color_var: the feature used for coloring (or NULL) +# - color_key: whether a color key is present (TRUE/FALSE) +# - y_lab: the y-axis label +.one_dependence_plot <- function( + object, + v, + color_var, + color, + viridis_args, + jitter_width, + interactions, + ih_nbins, + ih_color_num, + ih_scale, + ih_adjusted, + ...) { S <- get_shap_values(object) X <- get_feature_values(object) S_inter <- get_shap_interactions(object) @@ -116,7 +248,7 @@ sv_dependence.shapviz <- function(object, v, color_var = "auto", color = "#3b528 v %in% nms, is.null(color_var) || (color_var %in% c("auto", nms)) ) - if (interactions && is.null(S_inter)) { + if (isTRUE(interactions) && is.null(S_inter)) { stop("No SHAP interaction values available in 'object'.") } @@ -137,87 +269,60 @@ sv_dependence.shapviz <- function(object, v, color_var = "auto", color = "#3b528 ) # 'scores' can be NULL, or a sorted vector like c(0.1, 0, -0.01, NA) # Thus, let's take the first positive one (or NULL) - scores <- scores[!is.na(scores) & scores > 0] # NULL stays NULL + scores <- scores[!is.na(scores) & scores > 0] # NULL stays NULL color_var <- if (length(scores) >= 1L) names(scores)[1L] } if (isTRUE(interactions)) { - if (is.null(color_var)) { - color_var <- v + if (!is.null(color_var) && color_var == v) { + color_var <- NULL } - if (color_var == v) { + if (is.null(color_var)) { # we want to show the main effect y_lab <- "SHAP main effect" + s <- S_inter[, v, v] } else { y_lab <- "SHAP interaction" - } - s <- S_inter[, v, color_var] - if (color_var != v) { - s <- 2 * s # Off-diagonals need to be multiplied by 2 for symmetry reasons + s <- S_inter[, v, color_var] + S_inter[, color_var, v] # symmetry } } else { y_lab <- "SHAP value" s <- S[, v] } + + # Create data.frame with SHAP values and features values of v (no color yet) dat <- data.frame(s, X[[v]]) colnames(dat) <- c("shap", v) - if (is.null(color_var) || color_var == v) { + + color_key <- !is.null(color_var) + + # No color axis if color_var is NULL + if (!color_key) { p <- ggplot2::ggplot(dat, ggplot2::aes(x = .data[[v]], y = shap)) + ggplot2::geom_jitter(color = color, width = jitter_width, height = 0, ...) + ggplot2::ylab(y_lab) - return(p) - } - dat[[color_var]] <- X[[color_var]] - if (.is_discrete(dat[[color_var]], n_unique = 0L)) { # only if non-numeric - vir <- ggplot2::scale_color_viridis_d } else { - vir <- ggplot2::scale_color_viridis_c - } - if (is.null(viridis_args)) { - viridis_args <- list() + dat[[color_var]] <- X[[color_var]] + if (.is_discrete(dat[[color_var]], n_unique = 0L)) { # only if non-numeric + vir <- ggplot2::scale_color_viridis_d + } else { + vir <- ggplot2::scale_color_viridis_c + } + if (is.null(viridis_args)) { + viridis_args <- list() + } + p <- ggplot2::ggplot( + dat, ggplot2::aes(x = .data[[v]], y = shap, color = .data[[color_var]]) + ) + + ggplot2::geom_jitter(width = jitter_width, height = 0, ...) + + ggplot2::ylab(y_lab) + + do.call(vir, viridis_args) + + ggplot2::theme(legend.box.spacing = grid::unit(0, "pt")) } - ggplot2::ggplot( - dat, ggplot2::aes(x = .data[[v]], y = shap, color = .data[[color_var]]) - ) + - ggplot2::geom_jitter(width = jitter_width, height = 0, ...) + - ggplot2::ylab(y_lab) + - do.call(vir, viridis_args) + - ggplot2::theme(legend.box.spacing = grid::unit(0, "pt")) -} - -#' @describeIn sv_dependence -#' SHAP dependence plot for "mshapviz" object. -#' @export -sv_dependence.mshapviz <- function(object, v, color_var = "auto", color = "#3b528b", - viridis_args = getOption("shapviz.viridis_args"), - jitter_width = NULL, interactions = FALSE, - ih_nbins = NULL, ih_color_num = TRUE, - ih_scale = FALSE, ih_adjusted = FALSE, ...) { - stopifnot( - length(v) == 1L, - length(color_var) <= 1L - ) - plot_list <- lapply( - object, - FUN = sv_dependence, - # Argument list (simplify via match.call() or some rlang magic?) - v = v, + out <- list( + p = p, color_var = color_var, - color = color, - viridis_args = viridis_args, - jitter_width = jitter_width, - interactions = interactions, - ih_nbins = ih_nbins, - ih_color_num = ih_color_num, - ih_scale = ih_scale, - ih_adjusted = ih_adjusted, - ... + color_key = color_key, + y_lab = y_lab ) - plot_list <- add_titles(plot_list, nms = names(object)) # see sv_waterfall() - patchwork::wrap_plots(plot_list) -} - -# Helper functions -# Checks if z is discrete -.is_discrete <- function(z, n_unique) { - is.factor(z) || is.character(z) || is.logical(z) || (length(unique(z)) <= n_unique) + return(out) } diff --git a/R/sv_dependence2D.R b/R/sv_dependence2D.R index 3b871ac..825d1e1 100644 --- a/R/sv_dependence2D.R +++ b/R/sv_dependence2D.R @@ -14,10 +14,8 @@ #' #' @inheritParams sv_dependence #' @inheritParams sv_importance -#' @param x Feature name for x axis. Can be a vector/list if `object` is -#' of class "shapviz". -#' @param y Feature name for y axis. Can be a vector/list if `object` is -#' of class "shapviz". +#' @param x Feature name for x axis. Can be a vector if `object` is of class "shapviz". +#' @param y Feature name for y axis. Can be a vector if `object` is of class "shapviz". #' @param jitter_height Similar to `jitter_width` for vertical scatter. #' @param interactions Should SHAP interaction values be plotted? The default (`FALSE`) #' will show the rowwise sum of the SHAP values of `x` and `y`. If `TRUE`, will @@ -30,7 +28,8 @@ #' @returns An object of class "ggplot" (or "patchwork") representing a dependence plot. #' @examples #' dtrain <- xgboost::xgb.DMatrix( -#' data.matrix(iris[, -1]), label = iris[, 1], nthread = 1 +#' data.matrix(iris[, -1]), +#' label = iris[, 1], nthread = 1 #' ) #' fit <- xgboost::xgb.train(data = dtrain, nrounds = 10, nthread = 1) #' sv <- shapviz(fit, X_pred = dtrain, X = iris) @@ -41,7 +40,8 @@ #' sv2 <- shapviz(fit, X_pred = dtrain, X = iris, interactions = TRUE) #' sv_dependence2D(sv2, x = "Petal.Length", y = "Species", interactions = TRUE) #' sv_dependence2D( -#' sv2, x = "Petal.Length", y = c("Species", "Petal.Width"), interactions = TRUE +#' sv2, +#' x = "Petal.Length", y = c("Species", "Petal.Width"), interactions = TRUE #' ) #' #' # mshapviz object @@ -63,37 +63,119 @@ sv_dependence2D.default <- function(object, ...) { #' @describeIn sv_dependence2D #' 2D SHAP dependence plot for "shapviz" object. #' @export -sv_dependence2D.shapviz <- function(object, x, y, - viridis_args = getOption("shapviz.viridis_args"), - jitter_width = NULL, jitter_height = NULL, - interactions = FALSE, add_vars = NULL, ...) { - p <- max(length(x), length(y)) - if (p > 1L) { - if (is.null(jitter_width)) { - jitter_width <- replicate(p, NULL) - } - if (is.null(jitter_height)) { - jitter_height <- replicate(p, NULL) - } - plot_list <- mapply( - FUN = sv_dependence2D, +sv_dependence2D.shapviz <- function( + object, + x, + y, + viridis_args = getOption("shapviz.viridis_args"), + jitter_width = NULL, + jitter_height = NULL, + interactions = FALSE, + add_vars = NULL, + ...) { + nx <- length(x) + ny <- length(y) + nplots <- max(nx, ny) + + if (nplots == 1L) { + p <- .one_dependence2D_plot( + object = object, x = x, y = y, + viridis_args = viridis_args, jitter_width = jitter_width, jitter_height = jitter_height, - MoreArgs = list( - object = object, - viridis_args = viridis_args, - interactions = interactions, - ... - ), - SIMPLIFY = FALSE + interactions = interactions, + add_vars = add_vars, + ... ) - return(patchwork::wrap_plots(plot_list)) + return(p) + } + if (is.null(jitter_width)) { + jitter_width <- replicate(nplots, NULL) + } + if (is.null(jitter_height)) { + jitter_height <- replicate(nplots, NULL) + } + plot_list <- mapply( + FUN = .one_dependence2D_plot, + x = x, + y = y, + jitter_width = jitter_width, + jitter_height = jitter_height, + MoreArgs = list( + object = object, + viridis_args = viridis_args, + interactions = interactions, + add_vars = add_vars, + ... + ), + SIMPLIFY = FALSE + ) + + # if nx == 1 and ny == 1, we can't reach here + if (nx == 1L) { + strategy <- "collect_x" + } else if (ny == 1L) { + strategy <- "collect_y" + } else { + strategy <- "keep" } + p <- patchwork::wrap_plots(plot_list, axis_titles = strategy, axes = strategy) + + return(p) +} + +#' @describeIn sv_dependence2D +#' 2D SHAP dependence plot for "mshapviz" object. +#' @export +sv_dependence2D.mshapviz <- function( + object, + x, + y, + viridis_args = getOption("shapviz.viridis_args"), + jitter_width = NULL, + jitter_height = NULL, + interactions = FALSE, + add_vars = NULL, + ...) { + stopifnot( + length(x) == 1L, + length(y) == 1L + ) + plot_list <- lapply( + object, + FUN = .one_dependence2D_plot, + # Argument list (simplify via match.call() or some rlang magic?) + x = x, + y = y, + viridis_args = viridis_args, + jitter_width = jitter_width, + jitter_height = jitter_height, + interactions = interactions, + add_vars = add_vars, + ... + ) + plot_list <- add_titles(plot_list, nms = names(object)) # see sv_waterfall() + p <- patchwork::wrap_plots(plot_list, axis_titles = "collect") + return(p) +} + +# Helper function +.one_dependence2D_plot <- function( + object, + x, + y, + viridis_args, + jitter_width, + jitter_height, + interactions, + add_vars, + ...) { S <- get_shap_values(object) X <- get_feature_values(object) + S_inter <- get_shap_interactions(object) nms <- colnames(object) stopifnot( @@ -114,50 +196,25 @@ sv_dependence2D.shapviz <- function(object, x, y, } # Color variable - if (!interactions) { - s <- rowSums(S[, unique(c(x, y, add_vars))]) # unique() if add_vars contains x or y + if (isFALSE(interactions)) { + s <- rowSums(S[, unique(c(x, y, add_vars))]) } else { - s <- S_inter[, x, y] - if (x != y) { - s <- 2 * s # Off-diagonals need to be multiplied by 2 for symmetry reasons - } + s <- S_inter[, x, y] + if (x != y) S_inter[, y, x] else 0 # symmetry } dat <- data.frame(SHAP = s, X[, c(x, y)], check.names = FALSE) vir <- ggplot2::scale_color_viridis_c if (is.null(viridis_args)) { viridis_args <- list() } - ggplot2::ggplot(dat, ggplot2::aes(x = .data[[x]], y = .data[[y]], color = SHAP)) + + p <- ggplot2::ggplot( + dat, ggplot2::aes(x = .data[[x]], y = .data[[y]], color = SHAP) + ) + ggplot2::geom_jitter(width = jitter_width, height = jitter_height, ...) + do.call(vir, viridis_args) + - ggplot2::theme(legend.box.spacing = grid::unit(0, "pt")) -} + ggplot2::theme( + legend.box.spacing = grid::unit(0, "pt"), + legend.key.width = grid::unit(12, "pt") + ) -#' @describeIn sv_dependence2D -#' 2D SHAP dependence plot for "mshapviz" object. -#' @export -sv_dependence2D.mshapviz <- function(object, x, y, - viridis_args = getOption("shapviz.viridis_args"), - jitter_width = NULL, jitter_height = NULL, - interactions = FALSE, add_vars = NULL, ...) { - stopifnot( - length(x) == 1L, - length(y) == 1L - ) - plot_list <- lapply( - object, - FUN = sv_dependence2D, - # Argument list (simplify via match.call() or some rlang magic?) - x = x, - y = y, - viridis_args = viridis_args, - jitter_width = jitter_width, - jitter_height = jitter_height, - interactions = interactions, - add_vars = add_vars, - ... - ) - plot_list <- add_titles(plot_list, nms = names(object)) # see sv_waterfall() - patchwork::wrap_plots(plot_list) + return(p) } - diff --git a/R/sv_force.R b/R/sv_force.R index a90ecc6..2bd7258 100644 --- a/R/sv_force.R +++ b/R/sv_force.R @@ -12,7 +12,8 @@ #' @returns An object of class "ggplot" (or "patchwork") representing a force plot. #' @examples #' dtrain <- xgboost::xgb.DMatrix( -#' data.matrix(iris[, -1]), label = iris[, 1], nthread = 1 +#' data.matrix(iris[, -1]), +#' label = iris[, 1], nthread = 1 #' ) #' fit <- xgboost::xgb.train(data = dtrain, nrounds = 20, nthread = 1) #' x <- shapviz(fit, X_pred = dtrain, X = iris[, -1]) @@ -37,12 +38,18 @@ sv_force.default <- function(object, ...) { #' @describeIn sv_force #' SHAP force plot for object of class "shapviz". #' @export -sv_force.shapviz <- function(object, row_id = 1L, max_display = 6L, - fill_colors = c("#f7d13d", "#a52c60"), - format_shap = getOption("shapviz.format_shap"), - format_feat = getOption("shapviz.format_feat"), - contrast = TRUE, bar_label_size = 3.2, - show_annotation = TRUE, annotation_size = 3.2, ...) { +sv_force.shapviz <- function( + object, + row_id = 1L, + max_display = 6L, + fill_colors = c("#f7d13d", "#a52c60"), + format_shap = getOption("shapviz.format_shap"), + format_feat = getOption("shapviz.format_feat"), + contrast = TRUE, + bar_label_size = 3.2, + show_annotation = TRUE, + annotation_size = 3.2, + ...) { stopifnot( "Exactly two fill colors must be passed" = length(fill_colors) == 2L, "format_shap must be a function" = is.function(format_shap), @@ -140,12 +147,18 @@ sv_force.shapviz <- function(object, row_id = 1L, max_display = 6L, #' @describeIn sv_force #' SHAP force plot for object of class "mshapviz". #' @export -sv_force.mshapviz <- function(object, row_id = 1L, max_display = 6L, - fill_colors = c("#f7d13d", "#a52c60"), - format_shap = getOption("shapviz.format_shap"), - format_feat = getOption("shapviz.format_feat"), - contrast = TRUE, bar_label_size = 3.2, - show_annotation = TRUE, annotation_size = 3.2, ...) { +sv_force.mshapviz <- function( + object, + row_id = 1L, + max_display = 6L, + fill_colors = c("#f7d13d", "#a52c60"), + format_shap = getOption("shapviz.format_shap"), + format_feat = getOption("shapviz.format_feat"), + contrast = TRUE, + bar_label_size = 3.2, + show_annotation = TRUE, + annotation_size = 3.2, + ...) { plot_list <- lapply( object, FUN = sv_force, @@ -161,7 +174,9 @@ sv_force.mshapviz <- function(object, row_id = 1L, max_display = 6L, annotation_size = annotation_size, ... ) - plot_list <- add_titles(plot_list, nms = names(object)) # see sv_waterfall() - patchwork::wrap_plots(plot_list) + - patchwork::plot_layout(ncol = 1L) + plot_list <- add_titles(plot_list, nms = names(object)) # see sv_waterfall() + + # Currently, collecting axes titles does not work (but sv_waterfall() is ok) + p <- patchwork::wrap_plots(plot_list, ncol = 1L, axis_titles = "collect_x") + return(p) } diff --git a/R/sv_importance.R b/R/sv_importance.R index f9d44a5..16099ac 100644 --- a/R/sv_importance.R +++ b/R/sv_importance.R @@ -70,13 +70,21 @@ sv_importance.default <- function(object, ...) { #' @describeIn sv_importance #' SHAP importance plot for an object of class "shapviz". #' @export -sv_importance.shapviz <- function(object, kind = c("bar", "beeswarm", "both", "no"), - max_display = 15L, fill = "#fca50a", bar_width = 2/3, - bee_width = 0.4, bee_adjust = 0.5, - viridis_args = getOption("shapviz.viridis_args"), - color_bar_title = "Feature value", - show_numbers = FALSE, format_fun = format_max, - number_size = 3.2, sort_features = TRUE, ...) { +sv_importance.shapviz <- function( + object, + kind = c("bar", "beeswarm", "both", "no"), + max_display = 15L, + fill = "#fca50a", + bar_width = 2 / 3, + bee_width = 0.4, + bee_adjust = 0.5, + viridis_args = getOption("shapviz.viridis_args"), + color_bar_title = "Feature value", + show_numbers = FALSE, + format_fun = format_max, + number_size = 3.2, + sort_features = TRUE, + ...) { stopifnot("format_fun must be a function" = is.function(format_fun)) kind <- match.arg(kind) imp <- .get_imp(get_shap_values(object), sort_features = sort_features) @@ -90,7 +98,7 @@ sv_importance.shapviz <- function(object, kind = c("bar", "beeswarm", "both", "n imp <- imp[seq_len(max_display)] } ord <- names(imp) - object <- object[, ord] # not required for kind = "bar" + object <- object[, ord] # not required for kind = "bar" # ggplot will need to work with data.frame imp_df <- data.frame(feature = factor(ord, rev(ord)), value = imp) @@ -126,7 +134,7 @@ sv_importance.shapviz <- function(object, kind = c("bar", "beeswarm", "both", "n .get_color_scale( viridis_args = viridis_args, bar = !is.null(color_bar_title), - ncol = length(unique(df$color)) # Special case of constant feature values + ncol = length(unique(df$color)) # Special case of constant feature values ) + ggplot2::labs( x = "SHAP value", y = ggplot2::element_blank(), color = color_bar_title @@ -138,15 +146,18 @@ sv_importance.shapviz <- function(object, kind = c("bar", "beeswarm", "both", "n ggplot2::geom_text( data = imp_df, ggplot2::aes( - x = if (is_bar) value + max(value) / 60 else - min(df$value) - diff(range(df$value)) / 20, + x = if (is_bar) { + value + max(value) / 60 + } else { + min(df$value) - diff(range(df$value)) / 20 + }, label = format_fun(value) ), hjust = !is_bar, size = number_size ) + ggplot2::scale_x_continuous( - expand = ggplot2::expansion(mult = 0.05 + c(0.12 *!is_bar, 0.09 * is_bar)) + expand = ggplot2::expansion(mult = 0.05 + c(0.12 * !is_bar, 0.09 * is_bar)) ) } p @@ -155,15 +166,22 @@ sv_importance.shapviz <- function(object, kind = c("bar", "beeswarm", "both", "n #' @describeIn sv_importance #' SHAP importance plot for an object of class "mshapviz". #' @export -sv_importance.mshapviz <- function(object, kind = c("bar", "beeswarm", "both", "no"), - max_display = 15L, fill = "#fca50a", - bar_width = 2/3, - bar_type = c("dodge", "stack", "facets", "separate"), - bee_width = 0.4, bee_adjust = 0.5, - viridis_args = getOption("shapviz.viridis_args"), - color_bar_title = "Feature value", - show_numbers = FALSE, format_fun = format_max, - number_size = 3.2, sort_features = TRUE, ...) { +sv_importance.mshapviz <- function( + object, + kind = c("bar", "beeswarm", "both", "no"), + max_display = 15L, + fill = "#fca50a", + bar_width = 2 / 3, + bar_type = c("dodge", "stack", "facets", "separate"), + bee_width = 0.4, + bee_adjust = 0.5, + viridis_args = getOption("shapviz.viridis_args"), + color_bar_title = "Feature value", + show_numbers = FALSE, + format_fun = format_max, + number_size = 3.2, + sort_features = TRUE, + ...) { kind <- match.arg(kind) bar_type <- match.arg(bar_type) @@ -197,7 +215,7 @@ sv_importance.mshapviz <- function(object, kind = c("bar", "beeswarm", "both", " ggplot2::labs(fill = ggplot2::element_blank()) + do.call(ggplot2::scale_fill_viridis_d, viridis_args) + ggplot2::guides(fill = ggplot2::guide_legend(reverse = TRUE)) - } else { # facets + } else { # facets p <- ggplot2::ggplot(imp_df, ggplot2::aes(x = values, y = feature)) + ggplot2::geom_bar(fill = fill, width = bar_width, stat = "identity", ...) + ggplot2::facet_wrap("ind") @@ -230,8 +248,9 @@ sv_importance.mshapviz <- function(object, kind = c("bar", "beeswarm", "both", " if (kind == "no") { return(plot_list) } - plot_list <- add_titles(plot_list, nms = names(object)) # see sv_waterfall() - patchwork::wrap_plots(plot_list) + plot_list <- add_titles(plot_list, nms = names(object)) # see sv_waterfall() + p <- patchwork::wrap_plots(plot_list, axis_titles = "collect_x", guides = "collect") + return(p) } # Helper functions @@ -242,7 +261,7 @@ sv_importance.mshapviz <- function(object, kind = c("bar", "beeswarm", "both", " z[!is.na(z)] <- 0.5 return(z) } - (z - r[1L]) /(r[2L] - r[1L]) + (z - r[1L]) / (r[2L] - r[1L]) } .get_imp <- function(z, sort_features = TRUE) { diff --git a/R/sv_interaction.R b/R/sv_interaction.R index fbaa356..672abe4 100644 --- a/R/sv_interaction.R +++ b/R/sv_interaction.R @@ -185,7 +185,7 @@ sv_interaction.mshapviz <- function( return(plot_list) } plot_list <- add_titles(plot_list, nms = names(object)) # see sv_waterfall() - p <- patchwork::wrap_plots(plot_list, axis_titles = "collect", guides = "collect") + p <- patchwork::wrap_plots(plot_list, axis_titles = "collect_x", guides = "collect") return(p) } diff --git a/R/sv_waterfall.R b/R/sv_waterfall.R index 173edf7..77b619b 100644 --- a/R/sv_waterfall.R +++ b/R/sv_waterfall.R @@ -36,7 +36,8 @@ #' @returns An object of class "ggplot" (or "patchwork") representing a waterfall plot. #' @examples #' dtrain <- xgboost::xgb.DMatrix( -#' data.matrix(iris[, -1]), label = iris[, 1], nthread = 1 +#' data.matrix(iris[, -1]), +#' label = iris[, 1], nthread = 1 #' ) #' fit <- xgboost::xgb.train(data = dtrain, nrounds = 20, nthread = 1) #' x <- shapviz(fit, X_pred = dtrain, X = iris[, -1]) @@ -45,7 +46,8 @@ #' #' # Ordered by colnames(x), combined with max_display #' sv_waterfall( -#' x[, sort(colnames(x))], order_fun = function(s) length(s):1, max_display = 3 +#' x[, sort(colnames(x))], +#' order_fun = function(s) length(s):1, max_display = 3 #' ) #' #' # Aggregate over all observations with Petal.Length == 1.4 @@ -66,13 +68,19 @@ sv_waterfall.default <- function(object, ...) { #' @describeIn sv_waterfall #' SHAP waterfall plot for an object of class "shapviz". #' @export -sv_waterfall.shapviz <- function(object, row_id = 1L, max_display = 10L, - order_fun = function(s) order(abs(s)), - fill_colors = c("#f7d13d", "#a52c60"), - format_shap = getOption("shapviz.format_shap"), - format_feat = getOption("shapviz.format_feat"), - contrast = TRUE, show_connection = TRUE, - show_annotation = TRUE, annotation_size = 3.2, ...) { +sv_waterfall.shapviz <- function( + object, + row_id = 1L, + max_display = 10L, + order_fun = function(s) order(abs(s)), + fill_colors = c("#f7d13d", "#a52c60"), + format_shap = getOption("shapviz.format_shap"), + format_feat = getOption("shapviz.format_feat"), + contrast = TRUE, + show_connection = TRUE, + show_annotation = TRUE, + annotation_size = 3.2, + ...) { stopifnot( "Exactly two fill colors must be passed" = length(fill_colors) == 2L, "format_shap must be a function" = is.function(format_shap), @@ -166,13 +174,19 @@ sv_waterfall.shapviz <- function(object, row_id = 1L, max_display = 10L, #' @describeIn sv_waterfall #' SHAP waterfall plot for an object of class "mshapviz". #' @export -sv_waterfall.mshapviz <- function(object, row_id = 1L, max_display = 10L, - order_fun = function(s) order(abs(s)), - fill_colors = c("#f7d13d", "#a52c60"), - format_shap = getOption("shapviz.format_shap"), - format_feat = getOption("shapviz.format_feat"), - contrast = TRUE, show_connection = TRUE, - show_annotation = TRUE, annotation_size = 3.2, ...) { +sv_waterfall.mshapviz <- function( + object, + row_id = 1L, + max_display = 10L, + order_fun = function(s) order(abs(s)), + fill_colors = c("#f7d13d", "#a52c60"), + format_shap = getOption("shapviz.format_shap"), + format_feat = getOption("shapviz.format_feat"), + contrast = TRUE, + show_connection = TRUE, + show_annotation = TRUE, + annotation_size = 3.2, + ...) { plot_list <- lapply( object, FUN = sv_waterfall, @@ -190,7 +204,8 @@ sv_waterfall.mshapviz <- function(object, row_id = 1L, max_display = 10L, ... ) plot_list <- add_titles(plot_list, nms = names(object)) - patchwork::wrap_plots(plot_list) + p <- patchwork::wrap_plots(plot_list, axis_titles = "collect_x") + return(p) } # Helper functions for sv_waterfall() and sv_force() diff --git a/man/sv_dependence.Rd b/man/sv_dependence.Rd index 50b2fc2..0d416ea 100644 --- a/man/sv_dependence.Rd +++ b/man/sv_dependence.Rd @@ -72,7 +72,7 @@ use a value of 0.2 in case \code{v} is discrete, and no jitter otherwise. Can be a vector/list if \code{v} is a vector.} \item{interactions}{Should SHAP interaction values be plotted? Default is \code{FALSE}. -Requires SHAP interaction values. If \code{color_var = NULL} (or it is equal to \code{v}), +Requires SHAP interaction values. If \code{color_var = NULL} (or is equal to \code{v}), the pure main effect of \code{v} is visualized. Otherwise, twice the SHAP interaction values between \code{v} and the \code{color_var} are plotted.} @@ -101,7 +101,8 @@ By default, the feature on the color scale is selected via SHAP interactions }} \examples{ dtrain <- xgboost::xgb.DMatrix( - data.matrix(iris[, -1]), label = iris[, 1], nthread = 1 + data.matrix(iris[, -1]), + label = iris[, 1], nthread = 1 ) fit <- xgboost::xgb.train(data = dtrain, nrounds = 10, nthread = 1) x <- shapviz(fit, X_pred = dtrain, X = iris) @@ -115,7 +116,12 @@ sv_dependence(x, "Petal.Width", color_var = c("Species", "Petal.Length")) x2 <- shapviz(fit, X_pred = dtrain, X = iris, interactions = TRUE) sv_dependence(x2, "Petal.Length", interactions = TRUE) sv_dependence( - x2, c("Petal.Length", "Species"), color_var = NULL, interactions = TRUE + x2, c("Petal.Length", "Species"), + color_var = NULL, interactions = TRUE +) +sv_dependence( + x2, "Petal.Length", + color_var = colnames(iris[-1]), interactions = TRUE ) } \seealso{ diff --git a/man/sv_dependence2D.Rd b/man/sv_dependence2D.Rd index 46473d9..3cf5393 100644 --- a/man/sv_dependence2D.Rd +++ b/man/sv_dependence2D.Rd @@ -40,11 +40,9 @@ sv_dependence2D(object, ...) \item{...}{Arguments passed to \code{\link[ggplot2:geom_jitter]{ggplot2::geom_jitter()}}.} -\item{x}{Feature name for x axis. Can be a vector/list if \code{object} is -of class "shapviz".} +\item{x}{Feature name for x axis. Can be a vector if \code{object} is of class "shapviz".} -\item{y}{Feature name for y axis. Can be a vector/list if \code{object} is -of class "shapviz".} +\item{y}{Feature name for y axis. Can be a vector if \code{object} is of class "shapviz".} \item{viridis_args}{List of viridis color scale arguments, see \code{?ggplot2::scale_color_viridis_c}. The default points to the global option @@ -95,7 +93,8 @@ has no effect. }} \examples{ dtrain <- xgboost::xgb.DMatrix( - data.matrix(iris[, -1]), label = iris[, 1], nthread = 1 + data.matrix(iris[, -1]), + label = iris[, 1], nthread = 1 ) fit <- xgboost::xgb.train(data = dtrain, nrounds = 10, nthread = 1) sv <- shapviz(fit, X_pred = dtrain, X = iris) @@ -106,7 +105,8 @@ sv_dependence2D(sv, x = c("Petal.Length", "Species"), y = "Sepal.Width") sv2 <- shapviz(fit, X_pred = dtrain, X = iris, interactions = TRUE) sv_dependence2D(sv2, x = "Petal.Length", y = "Species", interactions = TRUE) sv_dependence2D( - sv2, x = "Petal.Length", y = c("Species", "Petal.Width"), interactions = TRUE + sv2, + x = "Petal.Length", y = c("Species", "Petal.Width"), interactions = TRUE ) # mshapviz object diff --git a/man/sv_force.Rd b/man/sv_force.Rd index d0a420f..9272720 100644 --- a/man/sv_force.Rd +++ b/man/sv_force.Rd @@ -97,7 +97,8 @@ baseline SHAP value. }} \examples{ dtrain <- xgboost::xgb.DMatrix( - data.matrix(iris[, -1]), label = iris[, 1], nthread = 1 + data.matrix(iris[, -1]), + label = iris[, 1], nthread = 1 ) fit <- xgboost::xgb.train(data = dtrain, nrounds = 20, nthread = 1) x <- shapviz(fit, X_pred = dtrain, X = iris[, -1]) diff --git a/man/sv_waterfall.Rd b/man/sv_waterfall.Rd index cbafd93..f76a523 100644 --- a/man/sv_waterfall.Rd +++ b/man/sv_waterfall.Rd @@ -103,7 +103,8 @@ baseline SHAP value. }} \examples{ dtrain <- xgboost::xgb.DMatrix( - data.matrix(iris[, -1]), label = iris[, 1], nthread = 1 + data.matrix(iris[, -1]), + label = iris[, 1], nthread = 1 ) fit <- xgboost::xgb.train(data = dtrain, nrounds = 20, nthread = 1) x <- shapviz(fit, X_pred = dtrain, X = iris[, -1]) @@ -112,7 +113,8 @@ sv_waterfall(x, row_id = 123, max_display = 2, size = 9, fill_colors = 4:5) # Ordered by colnames(x), combined with max_display sv_waterfall( - x[, sort(colnames(x))], order_fun = function(s) length(s):1, max_display = 3 + x[, sort(colnames(x))], + order_fun = function(s) length(s):1, max_display = 3 ) # Aggregate over all observations with Petal.Length == 1.4