Skip to content

Commit 560f411

Browse files
authored
Merge pull request #169 from ModelOriented/interaction-bar
Add barplots to sv_interaction()
2 parents 103cc61 + ee967ea commit 560f411

File tree

10 files changed

+138
-51
lines changed

10 files changed

+138
-51
lines changed

DESCRIPTION

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
Package: shapviz
22
Title: SHAP Visualizations
3-
Version: 0.9.8
3+
Version: 0.10.0
44
Authors@R: c(
55
person("Michael", "Mayer", , "[email protected]", role = c("aut", "cre")),
66
person("Adrian", "Stando", , "[email protected]", role = "ctb")

NEWS.md

Lines changed: 11 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,14 @@
1-
# shapviz 0.9.8
1+
# shapviz 0.10.0
2+
3+
### New feature
4+
5+
- `sv_interaction()`: New `kind = "bar"` to show mean absolute SHAP interactions/main effects as barplots.
6+
Modify via `fill` and `bar_width` arguments [#169](https://github.com/ModelOriented/shapviz/pull/169).
7+
8+
### User-visible changes
9+
10+
- `sv_interaction()`: If applied to a "mshapviz" object, we use {patchwork} functionality to collect guides and axis titles.
11+
212

313
### Maintenance
414

R/shapviz-package.R

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -4,8 +4,10 @@
44
# Suppress R CMD check note
55
#' @importFrom xgboost xgb.train
66

7-
globalVariables(c("from", "i", "id", "label", "to", "x", "shap", "SHAP",
8-
"feature", "value", "color", "Var2", "Var3", "S", "ind", "values"))
7+
globalVariables(c(
8+
"from", "i", "id", "label", "to", "x", "shap", "SHAP",
9+
"feature", "value", "color", "Var1", "Var2", "Var3", "S", "ind", "values"
10+
))
911

1012
.onLoad <- function(libname, pkgname) {
1113
op <- options()

R/sv_interaction.R

Lines changed: 77 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -1,11 +1,18 @@
11
#' SHAP Interaction Plot
22
#'
3-
#' Plots a beeswarm plot for each feature pair. Diagonals represent the main effects,
4-
#' while off-diagonals show interactions (multiplied by two due to symmetry).
5-
#' The colors on the beeswarm plots represent min-max scaled feature values.
3+
#' @description
4+
#' Creates a beeswarm plot or a barplot of SHAP interaction values/main effects.
5+
#'
6+
#' In the beeswarm plot (`kind = "beeswarm"`), diagonals represent the main effects,
7+
#' while off-diagonals show SHAP interactions (multiplied by two due to symmetry).
8+
#' The color axis represents min-max scaled feature values.
69
#' Non-numeric features are transformed to numeric by calling [data.matrix()] first.
710
#' The features are sorted in decreasing order of usual SHAP importance.
811
#'
12+
#' The barplot (`kind = "bar"`) shows average absolute SHAP interaction values
13+
#' and main effects for each feature pair.
14+
#' Again, due to symmetry, the interaction values are multiplied by two.
15+
#'
916
#' @param object An object of class "(m)shapviz" containing element `S_inter`.
1017
#' @param kind Set to "no" to return the matrix of average absolute SHAP
1118
#' interactions (or a list of such matrices in case of an object of class "mshapviz").
@@ -19,12 +26,14 @@
1926
#' absolute SHAP values (or a list of such matrices in case of an "mshapviz" object).
2027
#' @examples
2128
#' dtrain <- xgboost::xgb.DMatrix(
22-
#' data.matrix(iris[, -1]), label = iris[, 1], nthread = 1
29+
#' data.matrix(iris[, -1]),
30+
#' label = iris[, 1], nthread = 1
2331
#' )
2432
#' fit <- xgboost::xgb.train(data = dtrain, nrounds = 10, nthread = 1)
2533
#' x <- shapviz(fit, X_pred = dtrain, X = iris, interactions = TRUE)
2634
#' sv_interaction(x, kind = "no")
2735
#' sv_interaction(x, max_display = 2, size = 3)
36+
#' sv_interaction(x, kind = "bar")
2837
#' @seealso [sv_importance()]
2938
#' @export
3039
sv_interaction <- function(object, ...) {
@@ -41,47 +50,78 @@ sv_interaction.default <- function(object, ...) {
4150
#' @describeIn sv_interaction
4251
#' SHAP interaction plot for an object of class "shapviz".
4352
#' @export
44-
sv_interaction.shapviz <- function(object, kind = c("beeswarm", "no"),
45-
max_display = 7L, alpha = 0.3,
46-
bee_width = 0.3, bee_adjust = 0.5,
47-
viridis_args = getOption("shapviz.viridis_args"),
48-
color_bar_title = "Row feature value",
49-
sort_features = TRUE, ...) {
53+
sv_interaction.shapviz <- function(
54+
object,
55+
kind = c("beeswarm", "bar", "no"),
56+
max_display = 15L - 8 * (kind == "beeswarm"),
57+
alpha = 0.3,
58+
bee_width = 0.3,
59+
bee_adjust = 0.5,
60+
viridis_args = getOption("shapviz.viridis_args"),
61+
color_bar_title = "Row feature value",
62+
sort_features = TRUE,
63+
fill = "#fca50a",
64+
bar_width = 2 / 3,
65+
...) {
5066
kind <- match.arg(kind)
5167
if (is.null(get_shap_interactions(object))) {
5268
stop("No SHAP interaction values available.")
5369
}
70+
71+
# Sort features by SHAP importance first (irrelevant for kind = "bee")
5472
ord <- names(.get_imp(get_shap_values(object), sort_features = sort_features))
5573
object <- object[, ord]
5674

75+
# Calculate average absolute SHAP interactions
76+
M <- apply(abs(get_shap_interactions(object)), MARGIN = 2:3, FUN = mean)
77+
M <- M + t(M) - diag(diag(M)) # Off-diagonals twice
78+
5779
if (kind == "no") {
58-
mat <- apply(abs(get_shap_interactions(object)), 2:3, mean)
59-
off_diag <- row(mat) != col(mat)
60-
mat[off_diag] <- mat[off_diag] * 2 # compensate symmetry
61-
return(mat)
80+
return(M)
81+
}
82+
83+
if (kind == "bar") {
84+
# Turn to long format and make feature pair names
85+
imp_df <- transform(
86+
as.data.frame.table(M, responseName = "value"),
87+
feature = ifelse(Var1 == Var2, as.character(Var1), paste(Var1, Var2, sep = ":"))
88+
)
89+
if (sort_features) {
90+
imp_df <- imp_df[order(imp_df$value, decreasing = TRUE), ]
91+
imp_df <- transform(imp_df, feature = factor(feature, levels = rev(feature)))
92+
}
93+
if (nrow(imp_df) > max_display) {
94+
imp_df <- imp_df[seq_len(max_display), ]
95+
}
96+
97+
p <- ggplot2::ggplot(imp_df, ggplot2::aes(x = value, y = feature)) +
98+
ggplot2::geom_bar(fill = fill, width = bar_width, stat = "identity", ...) +
99+
ggplot2::labs(x = "mean(|SHAP interaction value|)", y = ggplot2::element_blank())
100+
101+
return(p)
62102
}
63103

104+
# kind == "bee"
64105
if (ncol(object) > max_display) {
65106
ord <- ord[seq_len(max_display)]
66107
object <- object[, ord]
67108
}
68109

69-
# Prepare data.frame for beeswarm
70110
S_inter <- get_shap_interactions(object)
71111
X <- .scale_X(get_feature_values(object))
72112
X_long <- as.data.frame.table(X)
73113
df <- transform(
74114
as.data.frame.table(S_inter, responseName = "value"),
75115
Variable1 = factor(Var2, levels = ord),
76116
Variable2 = factor(Var3, levels = ord),
77-
color = X_long$Freq # Correctly recycled along the third dimension of S_inter
117+
color = X_long$Freq # Correctly recycled along the third dimension of S_inter
78118
)
79119

80120
# Compensate symmetry
81121
mask <- df[["Variable1"]] != df[["Variable2"]]
82122
df[mask, "value"] <- 2 * df[mask, "value"]
83123

84-
ggplot2::ggplot(df, ggplot2::aes(x = value, y = "1")) +
124+
p <- ggplot2::ggplot(df, ggplot2::aes(x = value, y = "1")) +
85125
ggplot2::geom_vline(xintercept = 0, color = "darkgray") +
86126
ggplot2::geom_point(
87127
ggplot2::aes(color = color),
@@ -104,17 +144,25 @@ sv_interaction.shapviz <- function(object, kind = c("beeswarm", "no"),
104144
axis.ticks.y = ggplot2::element_blank(),
105145
axis.text.y = ggplot2::element_blank()
106146
)
147+
return(p)
107148
}
108149

109150
#' @describeIn sv_interaction
110151
#' SHAP interaction plot for an object of class "mshapviz".
111152
#' @export
112-
sv_interaction.mshapviz <- function(object, kind = c("beeswarm", "no"),
113-
max_display = 7L, alpha = 0.3,
114-
bee_width = 0.3, bee_adjust = 0.5,
115-
viridis_args = getOption("shapviz.viridis_args"),
116-
color_bar_title = "Row feature value",
117-
sort_features = TRUE, ...) {
153+
sv_interaction.mshapviz <- function(
154+
object,
155+
kind = c("beeswarm", "bar", "no"),
156+
max_display = 7L,
157+
alpha = 0.3,
158+
bee_width = 0.3,
159+
bee_adjust = 0.5,
160+
viridis_args = getOption("shapviz.viridis_args"),
161+
color_bar_title = "Row feature value",
162+
sort_features = TRUE,
163+
fill = "#fca50a",
164+
bar_width = 2 / 3,
165+
...) {
118166
kind <- match.arg(kind)
119167

120168
plot_list <- lapply(
@@ -129,11 +177,15 @@ sv_interaction.mshapviz <- function(object, kind = c("beeswarm", "no"),
129177
viridis_args = viridis_args,
130178
color_bar_title = color_bar_title,
131179
sort_features = sort_features,
180+
fill = fill,
181+
bar_width = bar_width,
132182
...
133183
)
134184
if (kind == "no") {
135185
return(plot_list)
136186
}
137-
plot_list <- add_titles(plot_list, nms = names(object)) # see sv_waterfall()
138-
patchwork::wrap_plots(plot_list)
187+
plot_list <- add_titles(plot_list, nms = names(object)) # see sv_waterfall()
188+
p <- patchwork::wrap_plots(plot_list, axis_titles = "collect", guides = "collect")
189+
190+
return(p)
139191
}

README.md

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,7 @@
1717

1818
- `sv_importance()`: Importance plot (bar/beeswarm).
1919
- `sv_dependence()` and `sv_dependence2D()`: Dependence plots to study feature effects and interactions.
20-
- `sv_interaction()`: Interaction plot (beeswarm).
20+
- `sv_interaction()`: Interaction plot (beeswarm/bar).
2121
- `sv_waterfall()`: Waterfall plot to study single or average predictions.
2222
- `sv_force()`: Force plot as alternative to waterfall plot.
2323

man/sv_interaction.Rd

Lines changed: 23 additions & 7 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

packaging.R

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,7 @@ library(usethis)
1515
use_description(
1616
fields = list(
1717
Title = "SHAP Visualizations",
18-
Version = "0.9.8",
18+
Version = "0.10.0",
1919
Description = "Visualizations for SHAP (SHapley Additive exPlanations),
2020
such as waterfall plots, force plots, various types of importance plots,
2121
dependence plots, and interaction plots.
@@ -43,7 +43,7 @@ use_package("ggplot2", "Imports", min_version = "3.4.0")
4343
use_package("gggenes", "Imports")
4444
use_package("ggfittext", "Imports", min_version = "0.8.0")
4545
use_package("ggrepel", "Imports")
46-
use_package("patchwork", "Imports")
46+
use_package("patchwork", "Imports", min_version = "1.3.0")
4747
use_package("xgboost", "Imports")
4848

4949
use_package("fastshap", "Enhances")

tests/testthat/test-plots-mshapviz.R

Lines changed: 6 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
dtrain <- xgboost::xgb.DMatrix(
2-
data.matrix(iris[, -1L]), label = iris[, 1L], nthread = 1
2+
data.matrix(iris[, -1L]),
3+
label = iris[, 1L], nthread = 1
34
)
45
fit <- xgboost::xgb.train(params = list(nthread = 1L), data = dtrain, nrounds = 1L)
56
x <- shapviz(fit, X_pred = dtrain, X = iris[, -1L])
@@ -73,13 +74,15 @@ test_that("main effect plots equal case color_var = v", {
7374
expect_equal(
7475
sv_dependence(x_inter, "Petal.Length", color_var = NULL, interactions = TRUE),
7576
sv_dependence(
76-
x_inter, "Petal.Length", color_var = "Petal.Length", interactions = TRUE
77+
x_inter, "Petal.Length",
78+
color_var = "Petal.Length", interactions = TRUE
7779
)
7880
)
7981
})
8082

8183
test_that("Interaction plots provide patchwork object", {
82-
expect_s3_class(sv_interaction(x_inter), "patchwork")
84+
expect_s3_class(sv_interaction(x_inter, kind = "bee"), "patchwork")
85+
expect_s3_class(sv_interaction(x_inter, kind = "bar"), "patchwork")
8386
})
8487

8588
# Non-standard name
@@ -143,4 +146,3 @@ test_that("sv_dependence() does not work with multiple v", {
143146
expect_error(sv_dependence2D(x, x = c("Species", "Sepal.Width"), y = "Petal.Width"))
144147
expect_error(sv_dependence2D(x, x = "Petal.Width", y = c("Species", "Sepal.Width")))
145148
})
146-

0 commit comments

Comments
 (0)