Skip to content

Add barplots to sv_interaction() #169

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 4 commits into from
Jun 20, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion DESCRIPTION
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
Package: shapviz
Title: SHAP Visualizations
Version: 0.9.8
Version: 0.10.0
Authors@R: c(
person("Michael", "Mayer", , "[email protected]", role = c("aut", "cre")),
person("Adrian", "Stando", , "[email protected]", role = "ctb")
Expand Down
12 changes: 11 additions & 1 deletion NEWS.md
Original file line number Diff line number Diff line change
@@ -1,4 +1,14 @@
# shapviz 0.9.8
# shapviz 0.10.0

### New feature

`sv_interaction()`: New `kind = "bar"` to show mean absolute SHAP interactions/main effects as barplots.
Modify via `fill` and `bar_width` arguments [#169](https://github.com/ModelOriented/shapviz/pull/169).

### User-visible changes

- `sv_interaction()`: If applied to a "mshapviz" object, we use {patchwork} functionality to collect guides and axis titles.


### Maintenance

Expand Down
6 changes: 4 additions & 2 deletions R/shapviz-package.R
Original file line number Diff line number Diff line change
Expand Up @@ -4,8 +4,10 @@
# Suppress R CMD check note
#' @importFrom xgboost xgb.train

globalVariables(c("from", "i", "id", "label", "to", "x", "shap", "SHAP",
"feature", "value", "color", "Var2", "Var3", "S", "ind", "values"))
globalVariables(c(
"from", "i", "id", "label", "to", "x", "shap", "SHAP",
"feature", "value", "color", "Var1", "Var2", "Var3", "S", "ind", "values"
))

.onLoad <- function(libname, pkgname) {
op <- options()
Expand Down
102 changes: 77 additions & 25 deletions R/sv_interaction.R
Original file line number Diff line number Diff line change
@@ -1,11 +1,18 @@
#' SHAP Interaction Plot
#'
#' Plots a beeswarm plot for each feature pair. Diagonals represent the main effects,
#' while off-diagonals show interactions (multiplied by two due to symmetry).
#' The colors on the beeswarm plots represent min-max scaled feature values.
#' @description
#' Creates a beeswarm plot or a barplot of SHAP interaction values/main effects.
#'
#' In the beeswarm plot (`kind = "beeswarm"`), diagonals represent the main effects,
#' while off-diagonals show SHAP interactions (multiplied by two due to symmetry).
#' The color axis represent min-max scaled feature values.
#' Non-numeric features are transformed to numeric by calling [data.matrix()] first.
#' The features are sorted in decreasing order of usual SHAP importance.
#'
#' The barplot (`kind = "bar"`) shows average absolute SHAP interaction values
#' and main effects for each feature pair.
#' Again, due to symmetry, the interaction values are multiplied by two.
#'
#' @param object An object of class "(m)shapviz" containing element `S_inter`.
#' @param kind Set to "no" to return the matrix of average absolute SHAP
#' interactions (or a list of such matrices in case of object of class "mshapviz").
Expand All @@ -19,12 +26,14 @@
#' absolute SHAP values (or a list of such matrices in case of "mshapviz" object).
#' @examples
#' dtrain <- xgboost::xgb.DMatrix(
#' data.matrix(iris[, -1]), label = iris[, 1], nthread = 1
#' data.matrix(iris[, -1]),
#' label = iris[, 1], nthread = 1
#' )
#' fit <- xgboost::xgb.train(data = dtrain, nrounds = 10, nthread = 1)
#' x <- shapviz(fit, X_pred = dtrain, X = iris, interactions = TRUE)
#' sv_interaction(x, kind = "no")
#' sv_interaction(x, max_display = 2, size = 3)
#' sv_interaction(x, kind = "bar")
#' @seealso [sv_importance()]
#' @export
sv_interaction <- function(object, ...) {
Expand All @@ -41,47 +50,78 @@ sv_interaction.default <- function(object, ...) {
#' @describeIn sv_interaction
#' SHAP interaction plot for an object of class "shapviz".
#' @export
sv_interaction.shapviz <- function(object, kind = c("beeswarm", "no"),
max_display = 7L, alpha = 0.3,
bee_width = 0.3, bee_adjust = 0.5,
viridis_args = getOption("shapviz.viridis_args"),
color_bar_title = "Row feature value",
sort_features = TRUE, ...) {
sv_interaction.shapviz <- function(
object,
kind = c("beeswarm", "bar", "no"),
max_display = 15L - 8 * (kind == "beeswarm"),
alpha = 0.3,
bee_width = 0.3,
bee_adjust = 0.5,
viridis_args = getOption("shapviz.viridis_args"),
color_bar_title = "Row feature value",
sort_features = TRUE,
fill = "#fca50a",
bar_width = 2 / 3,
...) {
kind <- match.arg(kind)
if (is.null(get_shap_interactions(object))) {
stop("No SHAP interaction values available.")
}

# Sort features by SHAP importance first (irrelevant for kind = "bee")
ord <- names(.get_imp(get_shap_values(object), sort_features = sort_features))
object <- object[, ord]

# Calculate average absolute SHAP interactions
M <- apply(abs(get_shap_interactions(object)), MARGIN = 2:3, FUN = mean)
M <- M + t(M) - diag(diag(M)) # Off-diagonals twice

if (kind == "no") {
mat <- apply(abs(get_shap_interactions(object)), 2:3, mean)
off_diag <- row(mat) != col(mat)
mat[off_diag] <- mat[off_diag] * 2 # compensate symmetry
return(mat)
return(M)
}

if (kind == "bar") {
# Turn to long format and make feature pair names
imp_df <- transform(
as.data.frame.table(M, responseName = "value"),
feature = ifelse(Var1 == Var2, as.character(Var1), paste(Var1, Var2, sep = ":"))
)
if (sort_features) {
imp_df <- imp_df[order(imp_df$value, decreasing = TRUE), ]
imp_df <- transform(imp_df, feature = factor(feature, levels = rev(feature)))
}
if (nrow(imp_df) > max_display) {
imp_df <- imp_df[seq_len(max_display), ]
}

p <- ggplot2::ggplot(imp_df, ggplot2::aes(x = value, y = feature)) +
ggplot2::geom_bar(fill = fill, width = bar_width, stat = "identity", ...) +
ggplot2::labs(x = "mean(|SHAP interaction value|)", y = ggplot2::element_blank())

return(p)
}

# kind == "bee"
if (ncol(object) > max_display) {
ord <- ord[seq_len(max_display)]
object <- object[, ord]
}

# Prepare data.frame for beeswarm
S_inter <- get_shap_interactions(object)
X <- .scale_X(get_feature_values(object))
X_long <- as.data.frame.table(X)
df <- transform(
as.data.frame.table(S_inter, responseName = "value"),
Variable1 = factor(Var2, levels = ord),
Variable2 = factor(Var3, levels = ord),
color = X_long$Freq # Correctly recycled along the third dimension of S_inter
color = X_long$Freq # Correctly recycled along the third dimension of S_inter
)

# Compensate symmetry
mask <- df[["Variable1"]] != df[["Variable2"]]
df[mask, "value"] <- 2 * df[mask, "value"]

ggplot2::ggplot(df, ggplot2::aes(x = value, y = "1")) +
p <- ggplot2::ggplot(df, ggplot2::aes(x = value, y = "1")) +
ggplot2::geom_vline(xintercept = 0, color = "darkgray") +
ggplot2::geom_point(
ggplot2::aes(color = color),
Expand All @@ -104,17 +144,25 @@ sv_interaction.shapviz <- function(object, kind = c("beeswarm", "no"),
axis.ticks.y = ggplot2::element_blank(),
axis.text.y = ggplot2::element_blank()
)
return(p)
}

#' @describeIn sv_interaction
#' SHAP interaction plot for an object of class "mshapviz".
#' @export
sv_interaction.mshapviz <- function(object, kind = c("beeswarm", "no"),
max_display = 7L, alpha = 0.3,
bee_width = 0.3, bee_adjust = 0.5,
viridis_args = getOption("shapviz.viridis_args"),
color_bar_title = "Row feature value",
sort_features = TRUE, ...) {
sv_interaction.mshapviz <- function(
object,
kind = c("beeswarm", "bar", "no"),
max_display = 7L,
alpha = 0.3,
bee_width = 0.3,
bee_adjust = 0.5,
viridis_args = getOption("shapviz.viridis_args"),
color_bar_title = "Row feature value",
sort_features = TRUE,
fill = "#fca50a",
bar_width = 2 / 3,
...) {
kind <- match.arg(kind)

plot_list <- lapply(
Expand All @@ -129,11 +177,15 @@ sv_interaction.mshapviz <- function(object, kind = c("beeswarm", "no"),
viridis_args = viridis_args,
color_bar_title = color_bar_title,
sort_features = sort_features,
fill = fill,
bar_width = bar_width,
...
)
if (kind == "no") {
return(plot_list)
}
plot_list <- add_titles(plot_list, nms = names(object)) # see sv_waterfall()
patchwork::wrap_plots(plot_list)
plot_list <- add_titles(plot_list, nms = names(object)) # see sv_waterfall()
p <- patchwork::wrap_plots(plot_list, axis_titles = "collect", guides = "collect")

return(p)
}
2 changes: 1 addition & 1 deletion README.md
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@

- `sv_importance()`: Importance plot (bar/beeswarm).
- `sv_dependence()` and `sv_dependence2D()`: Dependence plots to study feature effects and interactions.
- `sv_interaction()`: Interaction plot (beeswarm).
- `sv_interaction()`: Interaction plot (beeswarm/bar).
- `sv_waterfall()`: Waterfall plot to study single or average predictions.
- `sv_force()`: Force plot as alternative to waterfall plot.

Expand Down
30 changes: 23 additions & 7 deletions man/sv_interaction.Rd

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

4 changes: 2 additions & 2 deletions packaging.R
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@ library(usethis)
use_description(
fields = list(
Title = "SHAP Visualizations",
Version = "0.9.8",
Version = "0.10.0",
Description = "Visualizations for SHAP (SHapley Additive exPlanations),
such as waterfall plots, force plots, various types of importance plots,
dependence plots, and interaction plots.
Expand Down Expand Up @@ -43,7 +43,7 @@ use_package("ggplot2", "Imports", min_version = "3.4.0")
use_package("gggenes", "Imports")
use_package("ggfittext", "Imports", min_version = "0.8.0")
use_package("ggrepel", "Imports")
use_package("patchwork", "Imports")
use_package("patchwork", "Imports", min_version = "1.3.0")
use_package("xgboost", "Imports")

use_package("fastshap", "Enhances")
Expand Down
10 changes: 6 additions & 4 deletions tests/testthat/test-plots-mshapviz.R
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
dtrain <- xgboost::xgb.DMatrix(
data.matrix(iris[, -1L]), label = iris[, 1L], nthread = 1
data.matrix(iris[, -1L]),
label = iris[, 1L], nthread = 1
)
fit <- xgboost::xgb.train(params = list(nthread = 1L), data = dtrain, nrounds = 1L)
x <- shapviz(fit, X_pred = dtrain, X = iris[, -1L])
Expand Down Expand Up @@ -73,13 +74,15 @@ test_that("main effect plots equal case color_var = v", {
expect_equal(
sv_dependence(x_inter, "Petal.Length", color_var = NULL, interactions = TRUE),
sv_dependence(
x_inter, "Petal.Length", color_var = "Petal.Length", interactions = TRUE
x_inter, "Petal.Length",
color_var = "Petal.Length", interactions = TRUE
)
)
})

test_that("Interaction plots provide patchwork object", {
expect_s3_class(sv_interaction(x_inter), "patchwork")
expect_s3_class(sv_interaction(x_inter, kind = "bee"), "patchwork")
expect_s3_class(sv_interaction(x_inter, kind = "bar"), "patchwork")
})

# Non-standard name
Expand Down Expand Up @@ -143,4 +146,3 @@ test_that("sv_dependence() does not work with multiple v", {
expect_error(sv_dependence2D(x, x = c("Species", "Sepal.Width"), y = "Petal.Width"))
expect_error(sv_dependence2D(x, x = "Petal.Width", y = c("Species", "Sepal.Width")))
})

Loading