Skip to content

Collect axes, axis titles, and guides #171

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 10 commits into from
Jun 21, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 5 additions & 2 deletions NEWS.md
Original file line number Diff line number Diff line change
@@ -1,14 +1,17 @@
# shapviz 0.10.0

### New feature
### New visualization

`sv_interaction()`: New `kind = "bar"` to show mean absolute SHAP interactions/main effects as barplots.
Modify via `fill` and `bar_width` arguments [#169](https://github.com/ModelOriented/shapviz/pull/169).

### User-visible changes

- `sv_interaction()`: If applied to a "mshapviz" object, we use {patchwork} functionality to collect guides and axis titles.
- We are now (cautiously) collecting axes, axis titles, and color guides via {patchwork}. (Currently fails for `sv_force()`.)

### Minor API changes

- In `sv_dependence()`, passing the same variable for `v` and `color_var` does not suppress the color axis anymore, except when `interactions = TRUE`.

### Maintenance

Expand Down
295 changes: 200 additions & 95 deletions R/sv_dependence.R
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@
#' (Numeric variables are considered discrete if they have at most 7 unique values.)
#' Can be a vector/list if `v` is a vector.
#' @param interactions Should SHAP interaction values be plotted? Default is `FALSE`.
#' Requires SHAP interaction values. If `color_var = NULL` (or it is equal to `v`),
#' Requires SHAP interaction values. If `color_var = NULL` (or is equal to `v`),
#' the pure main effect of `v` is visualized. Otherwise, twice the SHAP interaction
#' values between `v` and the `color_var` are plotted.
#' @param ih_nbins,ih_color_num,ih_scale,ih_adjusted Interaction heuristic (ih)
Expand All @@ -40,7 +40,8 @@
#' @returns An object of class "ggplot" (or "patchwork") representing a dependence plot.
#' @examples
#' dtrain <- xgboost::xgb.DMatrix(
#' data.matrix(iris[, -1]), label = iris[, 1], nthread = 1
#' data.matrix(iris[, -1]),
#' label = iris[, 1], nthread = 1
#' )
#' fit <- xgboost::xgb.train(data = dtrain, nrounds = 10, nthread = 1)
#' x <- shapviz(fit, X_pred = dtrain, X = iris)
Expand All @@ -54,7 +55,12 @@
#' x2 <- shapviz(fit, X_pred = dtrain, X = iris, interactions = TRUE)
#' sv_dependence(x2, "Petal.Length", interactions = TRUE)
#' sv_dependence(
#' x2, c("Petal.Length", "Species"), color_var = NULL, interactions = TRUE
#' x2, c("Petal.Length", "Species"),
#' color_var = NULL, interactions = TRUE
#' )
#' sv_dependence(
#' x2, "Petal.Length",
#' color_var = colnames(iris[-1]), interactions = TRUE
#' )
#' @export
#' @seealso [potential_interactions()]
Expand All @@ -72,42 +78,168 @@ sv_dependence.default <- function(object, ...) {
#' @describeIn sv_dependence
#' SHAP dependence plot for "shapviz" object.
#' @export
sv_dependence.shapviz <- function(object, v, color_var = "auto", color = "#3b528b",
viridis_args = getOption("shapviz.viridis_args"),
jitter_width = NULL, interactions = FALSE,
ih_nbins = NULL, ih_color_num = TRUE,
ih_scale = FALSE, ih_adjusted = FALSE, ...) {
p <- length(v)
if (p > 1L || length(color_var) > 1L) {
if (is.null(color_var)) {
color_var <- replicate(p, NULL)
}
if (is.null(jitter_width)) {
jitter_width <- replicate(p, NULL)
}
plot_list <- mapply(
FUN = sv_dependence,
sv_dependence.shapviz <- function(
object,
v,
color_var = "auto",
color = "#3b528b",
viridis_args = getOption("shapviz.viridis_args"),
jitter_width = NULL,
interactions = FALSE,
ih_nbins = NULL,
ih_color_num = TRUE,
ih_scale = FALSE,
ih_adjusted = FALSE,
...) {
nv <- length(v)
if (nv == 1L && length(color_var) <= 1L) {
p <- .one_dependence_plot(
object = object,
v = v,
color_var = color_var,
color = color,
viridis_args = viridis_args,
jitter_width = jitter_width,
MoreArgs = list(
object = object,
viridis_args = viridis_args,
interactions = interactions,
ih_nbins = ih_nbins,
ih_color_num = ih_color_num,
ih_scale = ih_scale,
ih_adjusted = ih_adjusted,
...
),
SIMPLIFY = FALSE
)
nms <- if (length(v) > 1L) v
plot_list <- add_titles(plot_list, nms = nms) # see sv_waterfall()
return(patchwork::wrap_plots(plot_list))
interactions = interactions,
ih_nbins = ih_nbins,
ih_color_num = ih_color_num,
ih_scale = ih_scale,
ih_adjusted = ih_adjusted,
...
)$p
return(p)
}

if (is.null(color_var)) {
color_var <- replicate(nv, NULL)
}
if (is.null(jitter_width)) {
jitter_width <- replicate(nv, NULL)
}
out_list <- mapply(
FUN = .one_dependence_plot,
v = v,
color_var = color_var,
color = color,
jitter_width = jitter_width,
MoreArgs = list(
object = object,
viridis_args = viridis_args,
interactions = interactions,
ih_nbins = ih_nbins,
ih_color_num = ih_color_num,
ih_scale = ih_scale,
ih_adjusted = ih_adjusted,
...
),
SIMPLIFY = FALSE
)

# Reorganize output
plot_list <- lapply(out_list, `[[`, "p")
y_labs <- vapply(out_list, `[[`, "y_lab", FUN.VALUE = character(1L))
has_keys <- vapply(out_list, `[[`, "color_key", FUN.VALUE = logical(1L))
color_vars <- lapply(out_list, `[[`, "color_var") # Elements NULL <=> has_keys = FALSE

# Add titles if v varies
plot_list <- add_titles(plot_list, nms = if (nv > 1L) v) # see sv_waterfall()

# Which aspects can be collected?
nvu <- length(unique(v))
nlab <- length(unique(y_labs))

axis_titles <- axes <- guides <- "keep"
if (nvu == 1L && nlab == 1L) {
axis_titles <- "collect"
} else if (nvu == 1L) {
axis_titles <- "collect_x"
} else if (nlab == 1L) {
axis_titles <- "collect_y"
}
if (nvu == 1L) {
axes <- if (isFALSE(interactions)) "collect" else "collect_x"
}
if (isFALSE(interactions) && length(unique(color_vars[has_keys])) <= 1L) {
guides <- "collect"
}

p <- patchwork::wrap_plots(
plot_list,
axis_titles = axis_titles, axes = axes, guides = guides
)

return(p)
}


#' @describeIn sv_dependence
#' SHAP dependence plot for "mshapviz" object.
#' @export
sv_dependence.mshapviz <- function(
object,
v,
color_var = "auto",
color = "#3b528b",
viridis_args = getOption("shapviz.viridis_args"),
jitter_width = NULL,
interactions = FALSE,
ih_nbins = NULL,
ih_color_num = TRUE,
ih_scale = FALSE,
ih_adjusted = FALSE,
...) {
stopifnot(
length(v) == 1L,
length(color_var) <= 1L
)
out_list <- lapply(
object,
FUN = .one_dependence_plot,
# Argument list (simplify via match.call() or some rlang magic?)
v = v,
color_var = color_var,
color = color,
viridis_args = viridis_args,
jitter_width = jitter_width,
interactions = interactions,
ih_nbins = ih_nbins,
ih_color_num = ih_color_num,
ih_scale = ih_scale,
ih_adjusted = ih_adjusted,
...
)
plot_list <- lapply(out_list, `[[`, "p")
plot_list <- add_titles(plot_list, nms = names(object)) # see sv_waterfall()
p <- patchwork::wrap_plots(plot_list, axis_titles = "collect")

return(p)
}

# Helper functions

# Checks if z is discrete
.is_discrete <- function(z, n_unique) {
is.factor(z) || is.character(z) || is.logical(z) || (length(unique(z)) <= n_unique)
}

# Returns a list with the following elements:
# - p: the ggplot object
# - color_var: the feature used for coloring (or NULL)
# - color_key: whether a color key is present (TRUE/FALSE)
# - y_lab: the y-axis label
.one_dependence_plot <- function(
object,
v,
color_var,
color,
viridis_args,
jitter_width,
interactions,
ih_nbins,
ih_color_num,
ih_scale,
ih_adjusted,
...) {
S <- get_shap_values(object)
X <- get_feature_values(object)
S_inter <- get_shap_interactions(object)
Expand All @@ -116,7 +248,7 @@ sv_dependence.shapviz <- function(object, v, color_var = "auto", color = "#3b528
v %in% nms,
is.null(color_var) || (color_var %in% c("auto", nms))
)
if (interactions && is.null(S_inter)) {
if (isTRUE(interactions) && is.null(S_inter)) {
stop("No SHAP interaction values available in 'object'.")
}

Expand All @@ -137,87 +269,60 @@ sv_dependence.shapviz <- function(object, v, color_var = "auto", color = "#3b528
)
# 'scores' can be NULL, or a sorted vector like c(0.1, 0, -0.01, NA)
# Thus, let's take the first positive one (or NULL)
scores <- scores[!is.na(scores) & scores > 0] # NULL stays NULL
scores <- scores[!is.na(scores) & scores > 0] # NULL stays NULL
color_var <- if (length(scores) >= 1L) names(scores)[1L]
}
if (isTRUE(interactions)) {
if (is.null(color_var)) {
color_var <- v
if (!is.null(color_var) && color_var == v) {
color_var <- NULL
}
if (color_var == v) {
if (is.null(color_var)) { # we want to show the main effect
y_lab <- "SHAP main effect"
s <- S_inter[, v, v]
} else {
y_lab <- "SHAP interaction"
}
s <- S_inter[, v, color_var]
if (color_var != v) {
s <- 2 * s # Off-diagonals need to be multiplied by 2 for symmetry reasons
s <- S_inter[, v, color_var] + S_inter[, color_var, v] # symmetry
}
} else {
y_lab <- "SHAP value"
s <- S[, v]
}

# Create data.frame with SHAP values and features values of v (no color yet)
dat <- data.frame(s, X[[v]])
colnames(dat) <- c("shap", v)
if (is.null(color_var) || color_var == v) {

color_key <- !is.null(color_var)

# No color axis if color_var is NULL
if (!color_key) {
p <- ggplot2::ggplot(dat, ggplot2::aes(x = .data[[v]], y = shap)) +
ggplot2::geom_jitter(color = color, width = jitter_width, height = 0, ...) +
ggplot2::ylab(y_lab)
return(p)
}
dat[[color_var]] <- X[[color_var]]
if (.is_discrete(dat[[color_var]], n_unique = 0L)) { # only if non-numeric
vir <- ggplot2::scale_color_viridis_d
} else {
vir <- ggplot2::scale_color_viridis_c
}
if (is.null(viridis_args)) {
viridis_args <- list()
dat[[color_var]] <- X[[color_var]]
if (.is_discrete(dat[[color_var]], n_unique = 0L)) { # only if non-numeric
vir <- ggplot2::scale_color_viridis_d
} else {
vir <- ggplot2::scale_color_viridis_c
}
if (is.null(viridis_args)) {
viridis_args <- list()
}
p <- ggplot2::ggplot(
dat, ggplot2::aes(x = .data[[v]], y = shap, color = .data[[color_var]])
) +
ggplot2::geom_jitter(width = jitter_width, height = 0, ...) +
ggplot2::ylab(y_lab) +
do.call(vir, viridis_args) +
ggplot2::theme(legend.box.spacing = grid::unit(0, "pt"))
}
ggplot2::ggplot(
dat, ggplot2::aes(x = .data[[v]], y = shap, color = .data[[color_var]])
) +
ggplot2::geom_jitter(width = jitter_width, height = 0, ...) +
ggplot2::ylab(y_lab) +
do.call(vir, viridis_args) +
ggplot2::theme(legend.box.spacing = grid::unit(0, "pt"))
}

#' @describeIn sv_dependence
#' SHAP dependence plot for "mshapviz" object.
#' @export
sv_dependence.mshapviz <- function(object, v, color_var = "auto", color = "#3b528b",
viridis_args = getOption("shapviz.viridis_args"),
jitter_width = NULL, interactions = FALSE,
ih_nbins = NULL, ih_color_num = TRUE,
ih_scale = FALSE, ih_adjusted = FALSE, ...) {
stopifnot(
length(v) == 1L,
length(color_var) <= 1L
)
plot_list <- lapply(
object,
FUN = sv_dependence,
# Argument list (simplify via match.call() or some rlang magic?)
v = v,
out <- list(
p = p,
color_var = color_var,
color = color,
viridis_args = viridis_args,
jitter_width = jitter_width,
interactions = interactions,
ih_nbins = ih_nbins,
ih_color_num = ih_color_num,
ih_scale = ih_scale,
ih_adjusted = ih_adjusted,
...
color_key = color_key,
y_lab = y_lab
)
plot_list <- add_titles(plot_list, nms = names(object)) # see sv_waterfall()
patchwork::wrap_plots(plot_list)
}

# Helper functions

# Checks if z is discrete
.is_discrete <- function(z, n_unique) {
is.factor(z) || is.character(z) || is.logical(z) || (length(unique(z)) <= n_unique)
return(out)
}
Loading